Commit f3528d99 authored by Michael Carilli
Browse files

Converting dispatch macros in fused_adam_cuda_kernel.cu

parent d0505433
......@@ -203,7 +203,7 @@ void fused_adam_cuda(
tsize,
(adamMode_t) mode,
decay);
)
);
} else {
using namespace at;
DISPATCH_DOUBLE_AND_FLOAT(g.scalar_type(), 0, "adam_cuda_kernel",
......@@ -261,14 +261,14 @@ void fused_adam_cuda_mt(
AT_ASSERTM(tensor_lists[0][0].scalar_type() == at::ScalarType::Float, "expected parameter to be of float type");
//dispatch is done on the gradient type
if (tl_sz == 5) {
AT_DISPATCH_FLOATING_TYPES_AND_HALF(tensor_lists[3][0].type(), "adam_cuda_mt_kernel", ([&] {
using accscalar_t = at::acc_type<scalar_t, true>;
DISPATCH_FLOAT_AND_HALF(tensor_lists[3][0].scalar_type(), 0, "adam_cuda_mt_kernel",
using accscalar_t = at::acc_type<scalar_t_0, true>;
multi_tensor_apply<5>(
BLOCK_SIZE,
chunk_size,
noop_flag,
tensor_lists,
AdamFunctor<5, accscalar_t, scalar_t>(),
AdamFunctor<5, accscalar_t, scalar_t_0>(),
beta1,
beta2,
eps,
......@@ -276,16 +276,16 @@ void fused_adam_cuda_mt(
step_size,
(adamMode_t) mode,
decay);
}));
);
} else {
AT_DISPATCH_FLOATING_TYPES_AND_HALF(tensor_lists[3][0].type(), "adam_cuda_mt_kernel", ([&] {
using accscalar_t = at::acc_type<scalar_t, true>;
DISPATCH_FLOAT_AND_HALF(tensor_lists[3][0].scalar_type(), 0, "adam_cuda_mt_kernel",
using accscalar_t = at::acc_type<scalar_t_0, true>;
multi_tensor_apply<4>(
BLOCK_SIZE,
chunk_size,
noop_flag,
tensor_lists,
AdamFunctor<4, accscalar_t, scalar_t>(),
AdamFunctor<4, accscalar_t, scalar_t_0>(),
beta1,
beta2,
eps,
......@@ -293,17 +293,17 @@ void fused_adam_cuda_mt(
step_size,
(adamMode_t) mode,
decay);
}));
);
}
} else {
if (tl_sz == 5) {
AT_DISPATCH_FLOATING_TYPES(tensor_lists[3][0].type(), "adam_cuda_mt_kernel", ([&] {
DISPATCH_DOUBLE_AND_FLOAT(tensor_lists[3][0].scalar_type(), 0, "adam_cuda_mt_kernel",
multi_tensor_apply<5>(
BLOCK_SIZE,
chunk_size,
noop_flag,
tensor_lists,
AdamFunctor<5, scalar_t, scalar_t>(),
AdamFunctor<5, scalar_t_0, scalar_t_0>(),
beta1,
beta2,
eps,
......@@ -311,15 +311,15 @@ void fused_adam_cuda_mt(
step_size,
(adamMode_t) mode,
decay);
}));
);
} else {
AT_DISPATCH_FLOATING_TYPES(tensor_lists[3][0].type(), "adam_cuda_mt_kernel", ([&] {
DISPATCH_DOUBLE_AND_FLOAT(tensor_lists[3][0].scalar_type(), 0, "adam_cuda_mt_kernel",
multi_tensor_apply<4>(
BLOCK_SIZE,
chunk_size,
noop_flag,
tensor_lists,
AdamFunctor<4, scalar_t, scalar_t>(),
AdamFunctor<4, scalar_t_0, scalar_t_0>(),
beta1,
beta2,
eps,
......@@ -327,7 +327,7 @@ void fused_adam_cuda_mt(
step_size,
(adamMode_t) mode,
decay);
}));
);
}
}
THCudaCheck(cudaGetLastError());
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment