Commit 90795025 authored by Thor Johnsen's avatar Thor Johnsen
Browse files

Bug fixes

parent 174abea7
...@@ -517,7 +517,7 @@ void fused_adam_cuda_no_overflow_check( ...@@ -517,7 +517,7 @@ void fused_adam_cuda_no_overflow_check(
//Determine #threads and #blocks //Determine #threads and #blocks
const int threadsPerBlock = 512; const int threadsPerBlock = 512;
const dim3 blocks((tsize+threadsPerBlock-1)/threadsPerBlock); const dim3 blocks((tsize+threadsPerBlock-1)/threadsPerBlock);
AT_ASSERTM(at::cuda::detail::canUse32BitIndexMath(p), "parameter tensor is too large to be indexed with int32"); AT_ASSERTM(at::cuda::detail::canUse32BitIndexMath(p_in), "parameter tensor is too large to be indexed with int32");
//Constants //Constants
float step_size = 0; float step_size = 0;
if (bias_correction == 1) { if (bias_correction == 1) {
...@@ -559,13 +559,13 @@ void fused_adam_cuda_no_overflow_check( ...@@ -559,13 +559,13 @@ void fused_adam_cuda_no_overflow_check(
using namespace at; using namespace at;
DISPATCH_DOUBLE_AND_FLOAT(g.scalar_type(), 0, "adam_cuda_kernel", DISPATCH_DOUBLE_AND_FLOAT(g.scalar_type(), 0, "adam_cuda_kernel",
adam_cuda_no_overflow_check_kernel<scalar_t_0, scalar_t_0><<<blocks,threadsPerBlock, 0, stream>>>( adam_cuda_no_overflow_check_kernel<scalar_t_0, scalar_t_0><<<blocks,threadsPerBlock, 0, stream>>>(
p_in.DATA_PTR<accscalar_t>(), p_in.DATA_PTR<scalar_t_0>(),
p_out.DATA_PTR<accscalar_t>(), p_out.DATA_PTR<scalar_t_0>(),
NULL, //don't output p_copy for fp32, it's wasted write NULL, //don't output p_copy for fp32, it's wasted write
m_in.DATA_PTR<accscalar_t>(), m_in.DATA_PTR<scalar_t_0>(),
m_out.DATA_PTR<accscalar_t>(), m_out.DATA_PTR<scalar_t_0>(),
v_in.DATA_PTR<accscalar_t>(), v_in.DATA_PTR<scalar_t_0>(),
v_out.DATA_PTR<accscalar_t>(), v_out.DATA_PTR<scalar_t_0>(),
g.DATA_PTR<scalar_t_0>(), g.DATA_PTR<scalar_t_0>(),
beta1, beta1,
beta2, beta2,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment