Commit 90795025 authored by Thor Johnsen's avatar Thor Johnsen
Browse files

Bug fixes

parent 174abea7
......@@ -517,7 +517,7 @@ void fused_adam_cuda_no_overflow_check(
//Determine #threads and #blocks
const int threadsPerBlock = 512;
const dim3 blocks((tsize+threadsPerBlock-1)/threadsPerBlock);
AT_ASSERTM(at::cuda::detail::canUse32BitIndexMath(p), "parameter tensor is too large to be indexed with int32");
AT_ASSERTM(at::cuda::detail::canUse32BitIndexMath(p_in), "parameter tensor is too large to be indexed with int32");
//Constants
float step_size = 0;
if (bias_correction == 1) {
......@@ -559,13 +559,13 @@ void fused_adam_cuda_no_overflow_check(
using namespace at;
DISPATCH_DOUBLE_AND_FLOAT(g.scalar_type(), 0, "adam_cuda_kernel",
adam_cuda_no_overflow_check_kernel<scalar_t_0, scalar_t_0><<<blocks,threadsPerBlock, 0, stream>>>(
p_in.DATA_PTR<accscalar_t>(),
p_out.DATA_PTR<accscalar_t>(),
p_in.DATA_PTR<scalar_t_0>(),
p_out.DATA_PTR<scalar_t_0>(),
NULL, //don't output p_copy for fp32, it's wasted write
m_in.DATA_PTR<accscalar_t>(),
m_out.DATA_PTR<accscalar_t>(),
v_in.DATA_PTR<accscalar_t>(),
v_out.DATA_PTR<accscalar_t>(),
m_in.DATA_PTR<scalar_t_0>(),
m_out.DATA_PTR<scalar_t_0>(),
v_in.DATA_PTR<scalar_t_0>(),
v_out.DATA_PTR<scalar_t_0>(),
g.DATA_PTR<scalar_t_0>(),
beta1,
beta2,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment