Commit d175acb0 authored by Michael Carilli
Browse files

Removing instances of ScalarType, still need to change macros

parent d900e93c
...@@ -187,7 +187,7 @@ void fused_adam_cuda( ...@@ -187,7 +187,7 @@ void fused_adam_cuda(
AT_ASSERTM(p.scalar_type() == at::ScalarType::Float, "expected parameter to be of float type"); AT_ASSERTM(p.scalar_type() == at::ScalarType::Float, "expected parameter to be of float type");
//dispatch is done on the gradient type //dispatch is done on the gradient type
using namespace at; // prevents "toString is undefined" errors using namespace at; // prevents "toString is undefined" errors
DISPATCH_FLOAT_AND_HALF(g.scalar_type(), 0, "adam_cuda_kernel", DISPATCH_FLOAT_AND_HALF(g.scalar_type(), 0, "adam_cuda_kernel",
using accscalar_t = at::acc_type<scalar_t_0, true>; using accscalar_t = at::acc_type<scalar_t_0, true>;
adam_cuda_kernel<accscalar_t, scalar_t_0><<<blocks,threadsPerBlock, 0, stream>>>( adam_cuda_kernel<accscalar_t, scalar_t_0><<<blocks,threadsPerBlock, 0, stream>>>(
p.data<accscalar_t>(), p.data<accscalar_t>(),
...@@ -256,9 +256,9 @@ void fused_adam_cuda_mt( ...@@ -256,9 +256,9 @@ void fused_adam_cuda_mt(
size_t tl_sz = tensor_lists.size(); size_t tl_sz = tensor_lists.size();
AT_ASSERTM(tl_sz == 4 || tl_sz == 5, "expected tensor lists of size 4 or 5"); AT_ASSERTM(tl_sz == 4 || tl_sz == 5, "expected tensor lists of size 4 or 5");
if (tensor_lists[3][0].type().scalarType() == at::ScalarType::Half) { if (tensor_lists[3][0].scalar_type() == at::ScalarType::Half) {
//all other values should be fp32 for half gradients //all other values should be fp32 for half gradients
AT_ASSERTM(tensor_lists[0][0].type().scalarType() == at::ScalarType::Float, "expected parameter to be of float type"); AT_ASSERTM(tensor_lists[0][0].scalar_type() == at::ScalarType::Float, "expected parameter to be of float type");
//dispatch is done on the gradient type //dispatch is done on the gradient type
if (tl_sz == 5) { if (tl_sz == 5) {
AT_DISPATCH_FLOATING_TYPES_AND_HALF(tensor_lists[3][0].type(), "adam_cuda_mt_kernel", ([&] { AT_DISPATCH_FLOATING_TYPES_AND_HALF(tensor_lists[3][0].type(), "adam_cuda_mt_kernel", ([&] {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment