Commit f3528d99 authored by Michael Carilli

Converting dispatch macros in fused_adam_cuda_kernel.cu

parent d0505433
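
This commit replaces ATen's lambda-based dispatch macros (AT_DISPATCH_FLOATING_TYPES_AND_HALF and AT_DISPATCH_FLOATING_TYPES) with apex's switch-based DISPATCH_FLOAT_AND_HALF and DISPATCH_DOUBLE_AND_FLOAT. The ATen macros wrap the body in a lambda and define a single scalar_t alias, so two dispatches cannot nest cleanly when gradients and parameters carry different types; the apex macros take a numeric LEVEL argument and define scalar_t_<LEVEL> instead. A minimal sketch of the new macro, modeled on apex's csrc/type_shim.h (the exact definition at this commit may differ):

    #include <ATen/ATen.h>

    // Switch-based dispatch: LEVEL is token-pasted into the type alias name,
    // so nested invocations (LEVEL 0, 1, ...) do not shadow each other.
    #define DISPATCH_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...)                   \
      switch(TYPE)                                                            \
      {                                                                       \
        case at::ScalarType::Float:                                           \
        {                                                                     \
          using scalar_t_##LEVEL = float;                                     \
          __VA_ARGS__;                                                        \
          break;                                                              \
        }                                                                     \
        case at::ScalarType::Half:                                            \
        {                                                                     \
          using scalar_t_##LEVEL = at::Half;                                  \
          __VA_ARGS__;                                                        \
          break;                                                              \
        }                                                                     \
        default:                                                              \
          AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'");     \
      }
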
@@ -203,7 +203,7 @@ void fused_adam_cuda(
                         tsize,
                         (adamMode_t) mode,
                         decay);
-                )
+                );
         } else {
             using namespace at;
             DISPATCH_DOUBLE_AND_FLOAT(g.scalar_type(), 0, "adam_cuda_kernel",
@@ -261,14 +261,14 @@ void fused_adam_cuda_mt(
         AT_ASSERTM(tensor_lists[0][0].scalar_type() == at::ScalarType::Float, "expected parameter to be of float type");
         //dispatch is done on the gradient type
         if (tl_sz == 5) {
-            AT_DISPATCH_FLOATING_TYPES_AND_HALF(tensor_lists[3][0].type(), "adam_cuda_mt_kernel", ([&] {
-                using accscalar_t = at::acc_type<scalar_t, true>;
+            DISPATCH_FLOAT_AND_HALF(tensor_lists[3][0].scalar_type(), 0, "adam_cuda_mt_kernel",
+                using accscalar_t = at::acc_type<scalar_t_0, true>;
                 multi_tensor_apply<5>(
                     BLOCK_SIZE,
                     chunk_size,
                     noop_flag,
                     tensor_lists,
-                    AdamFunctor<5, accscalar_t, scalar_t>(),
+                    AdamFunctor<5, accscalar_t, scalar_t_0>(),
                     beta1,
                     beta2,
                     eps,
@@ -276,16 +276,16 @@ void fused_adam_cuda_mt(
                     step_size,
                     (adamMode_t) mode,
                     decay);
-            }));
+            );
         } else {
-            AT_DISPATCH_FLOATING_TYPES_AND_HALF(tensor_lists[3][0].type(), "adam_cuda_mt_kernel", ([&] {
-                using accscalar_t = at::acc_type<scalar_t, true>;
+            DISPATCH_FLOAT_AND_HALF(tensor_lists[3][0].scalar_type(), 0, "adam_cuda_mt_kernel",
+                using accscalar_t = at::acc_type<scalar_t_0, true>;
                 multi_tensor_apply<4>(
                     BLOCK_SIZE,
                     chunk_size,
                     noop_flag,
                     tensor_lists,
-                    AdamFunctor<4, accscalar_t, scalar_t>(),
+                    AdamFunctor<4, accscalar_t, scalar_t_0>(),
                     beta1,
                     beta2,
                     eps,
@@ -293,17 +293,17 @@ void fused_adam_cuda_mt(
                     step_size,
                     (adamMode_t) mode,
                     decay);
-            }));
+            );
         }
     } else {
         if (tl_sz == 5) {
-            AT_DISPATCH_FLOATING_TYPES(tensor_lists[3][0].type(), "adam_cuda_mt_kernel", ([&] {
+            DISPATCH_DOUBLE_AND_FLOAT(tensor_lists[3][0].scalar_type(), 0, "adam_cuda_mt_kernel",
                 multi_tensor_apply<5>(
                     BLOCK_SIZE,
                     chunk_size,
                     noop_flag,
                     tensor_lists,
-                    AdamFunctor<5, scalar_t, scalar_t>(),
+                    AdamFunctor<5, scalar_t_0, scalar_t_0>(),
                     beta1,
                     beta2,
                     eps,
@@ -311,15 +311,15 @@ void fused_adam_cuda_mt(
                     step_size,
                     (adamMode_t) mode,
                     decay);
-            }));
+            );
         } else {
-            AT_DISPATCH_FLOATING_TYPES(tensor_lists[3][0].type(), "adam_cuda_mt_kernel", ([&] {
+            DISPATCH_DOUBLE_AND_FLOAT(tensor_lists[3][0].scalar_type(), 0, "adam_cuda_mt_kernel",
                 multi_tensor_apply<4>(
                     BLOCK_SIZE,
                     chunk_size,
                     noop_flag,
                     tensor_lists,
-                    AdamFunctor<4, scalar_t, scalar_t>(),
+                    AdamFunctor<4, scalar_t_0, scalar_t_0>(),
                     beta1,
                     beta2,
                     eps,
@@ -327,7 +327,7 @@ void fused_adam_cuda_mt(
                     step_size,
                     (adamMode_t) mode,
                     decay);
-            }));
+            );
         }
     }
     THCudaCheck(cudaGetLastError());
...
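
The net effect at each call site, shown before and after for the tl_sz == 5 half-precision branch (a condensed sketch of the diff above; the ellipses stand for the unchanged argument list):

    // Before: lambda-based ATen macro; deprecated .type(); one shared scalar_t alias
    AT_DISPATCH_FLOATING_TYPES_AND_HALF(tensor_lists[3][0].type(), "adam_cuda_mt_kernel", ([&] {
        using accscalar_t = at::acc_type<scalar_t, true>;
        multi_tensor_apply<5>(..., AdamFunctor<5, accscalar_t, scalar_t>(), ...);
    }));

    // After: switch-based apex macro at nesting level 0; .scalar_type(); scalar_t_0 alias
    DISPATCH_FLOAT_AND_HALF(tensor_lists[3][0].scalar_type(), 0, "adam_cuda_mt_kernel",
        using accscalar_t = at::acc_type<scalar_t_0, true>;
        multi_tensor_apply<5>(..., AdamFunctor<5, accscalar_t, scalar_t_0>(), ...);
    );

Because each level owns its alias (scalar_t_0, scalar_t_1, ...), a second dispatch can be nested inside the first without shadowing, and moving from Tensor::type() to scalar_type() drops a deprecated accessor.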