Commit 174abea7 authored by Thor Johnsen

Bug fixes

parent dde13741
@@ -513,7 +513,7 @@ void fused_adam_cuda_no_overflow_check(
 // using namespace at;
 //Get tensor size
-int tsize = p.numel();
+int tsize = p_in.numel();
 //Determine #threads and #blocks
 const int threadsPerBlock = 512;
 const dim3 blocks((tsize+threadsPerBlock-1)/threadsPerBlock);
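Side note on the launch configuration above: the block count is computed with ceiling division so that a final, partially filled block still covers the tail of the tensor. The standalone sketch below (hypothetical `scale_kernel` in a toy `main`, not part of this commit or of apex) shows the same `(tsize + threadsPerBlock - 1) / threadsPerBlock` pattern together with the bounds check that makes it safe:

```cpp
#include <cuda_runtime.h>

// Toy elementwise kernel; the bounds check handles the final,
// partially filled block.
__global__ void scale_kernel(float* p, int n, float alpha) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) p[i] *= alpha;
}

int main() {
    const int tsize = 1000;                 // stand-in for p_in.numel()
    const int threadsPerBlock = 512;
    // Ceiling division: 1000 elements at 512 threads per block -> 2 blocks.
    const dim3 blocks((tsize + threadsPerBlock - 1) / threadsPerBlock);

    float* d_p = nullptr;
    cudaMalloc(&d_p, tsize * sizeof(float));
    scale_kernel<<<blocks, threadsPerBlock>>>(d_p, tsize, 0.5f);
    cudaDeviceSynchronize();
    cudaFree(d_p);
    return 0;
}
```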
@@ -532,12 +532,12 @@
 if (g.scalar_type() == at::ScalarType::Half) {
 //all other values should be fp32 for half gradients
-AT_ASSERTM(p.scalar_type() == at::ScalarType::Float, "expected parameter to be of float type");
+AT_ASSERTM(p_in.scalar_type() == at::ScalarType::Float, "expected parameter to be of float type");
 //dispatch is done on the gradient type
 using namespace at; // prevents "toString is undefined" errors
 DISPATCH_FLOAT_AND_HALF(g.scalar_type(), 0, "adam_cuda_kernel",
 using accscalar_t = at::acc_type<scalar_t_0, true>;
-adam_cuda_kernel_no_overflow_check<accscalar_t, scalar_t_0><<<blocks,threadsPerBlock, 0, stream>>>(
+adam_cuda_no_overflow_check_kernel<accscalar_t, scalar_t_0><<<blocks,threadsPerBlock, 0, stream>>>(
 p_in.DATA_PTR<accscalar_t>(),
 p_out.DATA_PTR<accscalar_t>(),
 p_copy.numel() ? p_copy.DATA_PTR<scalar_t_0>() : NULL,
@@ -558,7 +558,7 @@ void fused_adam_cuda_no_overflow_check(
 } else {
 using namespace at;
 DISPATCH_DOUBLE_AND_FLOAT(g.scalar_type(), 0, "adam_cuda_kernel",
-adam_cuda_kernel_no_overflow_check<scalar_t_0, scalar_t_0><<<blocks,threadsPerBlock, 0, stream>>>(
+adam_cuda_no_overflow_check_kernel<scalar_t_0, scalar_t_0><<<blocks,threadsPerBlock, 0, stream>>>(
 p_in.DATA_PTR<accscalar_t>(),
 p_out.DATA_PTR<accscalar_t>(),
 NULL, //don't output p_copy for fp32, it's wasted write
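For readers unfamiliar with the dispatch pattern touched by this commit: the macros above select the kernel's template parameters from the gradient dtype at runtime, pairing fp16 gradients with fp32 accumulation/parameters (hence the AT_ASSERTM on p_in). The sketch below is a simplified assumption: hypothetical names `adam_like_kernel`, `launch_for_dtype`, and `GradType`, with a plain switch standing in for apex's DISPATCH_FLOAT_AND_HALF / DISPATCH_DOUBLE_AND_FLOAT macros; it illustrates the idea, not this file's code.

```cpp
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <stdexcept>

// Trivial stand-in for the real Adam update: accumulate in accscalar_t
// (float here) even when the gradient type scalar_t is half precision.
template <typename accscalar_t, typename scalar_t>
__global__ void adam_like_kernel(accscalar_t* p, const scalar_t* g, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) p[i] -= static_cast<accscalar_t>(g[i]);
}

enum class GradType { Half, Float };

// A plain switch standing in for the dispatch macros: the gradient dtype
// chooses which template instantiation gets launched.
void launch_for_dtype(GradType g_type, float* p, const void* g, int n,
                      dim3 blocks, int threadsPerBlock, cudaStream_t stream) {
    switch (g_type) {
        case GradType::Half:   // fp16 grads, fp32 ("master") params
            adam_like_kernel<float, __half>
                <<<blocks, threadsPerBlock, 0, stream>>>(
                    p, static_cast<const __half*>(g), n);
            break;
        case GradType::Float:  // fp32 grads, fp32 params
            adam_like_kernel<float, float>
                <<<blocks, threadsPerBlock, 0, stream>>>(
                    p, static_cast<const float*>(g), n);
            break;
        default:
            throw std::runtime_error("unsupported gradient dtype");
    }
}
```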