Unverified commit 74c06d87, authored by mcarilli, committed by GitHub
Browse files

Merge pull request #205 from NVIDIA/remove_type_shim

Conforming with upstream #17996
parents 8ab5e470 2c8e1c86
......@@ -96,7 +96,7 @@ void fused_adam_cuda(
AT_ASSERTM(p.type().scalarType() == at::ScalarType::Float, "expected parameter to be of float type");
//dispatch is done on the gradient type
using namespace at; // prevents "toString is undefined" errors
AT_DISPATCH_FLOATING_TYPES_AND_HALF(TypeShim(g.type()), "adam_cuda_kernel", ([&] {
AT_DISPATCH_FLOATING_TYPES_AND_HALF(g.type(), "adam_cuda_kernel", ([&] {
using accscalar_t = at::acc_type<scalar_t, true>;
adam_cuda_kernel<accscalar_t, scalar_t><<<blocks,threadsPerBlock, 0, stream>>>(
p.data<accscalar_t>(),
......@@ -115,7 +115,7 @@ void fused_adam_cuda(
}));
} else {
using namespace at;
AT_DISPATCH_FLOATING_TYPES(TypeShim(g.type()), "adam_cuda_kernel", ([&] {
AT_DISPATCH_FLOATING_TYPES(g.type(), "adam_cuda_kernel", ([&] {
adam_cuda_kernel<scalar_t, scalar_t><<<blocks,threadsPerBlock, 0, stream>>>(
p.data<scalar_t>(),
NULL, //don't output p_copy for fp32, it's wasted write
......
......@@ -678,7 +678,7 @@ void cuda_layer_norm(
double epsilon)
{
using namespace at;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(TypeShim(input->type()), "layer_norm_cuda_kernel", ([&] {
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input->type(), "layer_norm_cuda_kernel", ([&] {
using accscalar_t = at::acc_type<scalar_t, true>;
HostApplyLayerNorm(
output->data<scalar_t>(),
......@@ -776,7 +776,7 @@ void cuda_layer_norm_gradient(
at::Tensor* grad_beta)
{
using namespace at;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(TypeShim(input->type()), "cuComputeGradInput", ([&] {
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input->type(), "cuComputeGradInput", ([&] {
using accscalar_t = at::acc_type<scalar_t, true>;
HostLayerNormGradient(
dout->data<scalar_t>(),
......
......@@ -98,7 +98,7 @@ void multi_tensor_scale_cuda(
// If build times suffer, think about where to put this dispatch,
// and what logic should be moved out of multi_tensor_apply.
AT_DISPATCH_FLOATING_TYPES_AND_HALF(TypeShim(tensor_lists[0][0].type()),
AT_DISPATCH_FLOATING_TYPES_AND_HALF(tensor_lists[0][0].type(),
"multi_tensor_scale_cuda",
[&]
{
......
......@@ -846,7 +846,7 @@ std::vector<at::Tensor> welford_mean_var_CUDA(const at::Tensor input) {
{
using namespace at;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(TypeShim(input.type()), "welford_mean_var_kernel", ([&] {
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "welford_mean_var_kernel", ([&] {
using accscalar_t = at::acc_type<scalar_t, true>;
welford_kernel<scalar_t, accscalar_t, accscalar_t><<<grid, block, 0, stream>>>(
input.data<scalar_t>(),
......@@ -885,7 +885,7 @@ at::Tensor batchnorm_forward_CUDA(
&& weight.has_value() &&
weight.value().type().scalarType() == at::ScalarType::Float) {
using namespace at;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(TypeShim(input.type()), "batchnorm_forward", ([&] {
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_forward", ([&] {
using accscalar_t = at::acc_type<scalar_t, true>;
batchnorm_forward_kernel<scalar_t, accscalar_t, accscalar_t><<<grid, block, 0, stream>>>(
input.data<scalar_t>(),
......@@ -903,7 +903,7 @@ at::Tensor batchnorm_forward_CUDA(
"input.type().scalarType() is not supported with weight.type().scalarType()");
}
using namespace at;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(TypeShim(input.type()), "batchnorm_forward", ([&] {
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_forward", ([&] {
using accscalar_t = at::acc_type<scalar_t, true>;
batchnorm_forward_kernel<scalar_t, accscalar_t, scalar_t><<<grid, block, 0, stream>>>(
input.data<scalar_t>(),
......@@ -956,7 +956,7 @@ std::vector<at::Tensor> reduce_bn_CUDA(
&& weight.has_value() &&
weight.value().type().scalarType() == at::ScalarType::Float) {
using namespace at;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(TypeShim(input.type()), "batchnorm_backward_reduce", ([&] {
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_backward_reduce", ([&] {
using accscalar_t = at::acc_type<scalar_t, true>;
reduce_bn_kernel<scalar_t, accscalar_t, accscalar_t><<<grid, block, 0, stream>>>(
input.data<scalar_t>(),
......@@ -977,7 +977,7 @@ std::vector<at::Tensor> reduce_bn_CUDA(
"input.type().scalarType() is not supported with weight.type().scalarType()");
}
using namespace at;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(TypeShim(input.type()), "batchnorm_backward_reduce", ([&] {
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_backward_reduce", ([&] {
using accscalar_t = at::acc_type<scalar_t, true>;
reduce_bn_kernel<scalar_t, accscalar_t, scalar_t><<<grid, block, 0, stream>>>(
input.data<scalar_t>(),
......@@ -1025,7 +1025,7 @@ at::Tensor batchnorm_backward_CUDA(
&& weight.has_value() &&
weight.value().type().scalarType() == at::ScalarType::Float) {
using namespace at;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(TypeShim(input.type()), "batchnorm_backward", ([&] {
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_backward", ([&] {
using accscalar_t = at::acc_type<scalar_t, true>;
batchnorm_backward_kernel<scalar_t, accscalar_t, accscalar_t><<<grid, block, 0, stream>>>(
grad_output.data<scalar_t>(),
......@@ -1045,7 +1045,7 @@ at::Tensor batchnorm_backward_CUDA(
"input.type().scalarType() is not supported with weight.type().scalarType()");
}
using namespace at;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(TypeShim(input.type()), "batchnorm_backward", ([&] {
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_backward", ([&] {
using accscalar_t = at::acc_type<scalar_t, true>;
batchnorm_backward_kernel<scalar_t, accscalar_t, scalar_t><<<grid, block, 0, stream>>>(
grad_output.data<scalar_t>(),
......@@ -1083,7 +1083,7 @@ std::vector<at::Tensor> welford_parallel_CUDA(const at::Tensor mean_feature_node
{
using namespace at;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(TypeShim(mean_feature_nodes.type()), "welford_parallel_kernel", ([&] {
AT_DISPATCH_FLOATING_TYPES_AND_HALF(mean_feature_nodes.type(), "welford_parallel_kernel", ([&] {
welford_kernel_parallel<scalar_t><<<grid, block, 0, stream>>>(
mean_feature_nodes.data<scalar_t>(),
var_biased.data<scalar_t>(),
......@@ -1125,7 +1125,7 @@ std::vector<at::Tensor> welford_mean_var_c_last_CUDA(const at::Tensor input) {
{
using namespace at;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(TypeShim(input.type()), "welford_mean_var_c_last", ([&] {
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "welford_mean_var_c_last", ([&] {
using accscalar_t = at::acc_type<scalar_t, true>;
accscalar_t* staging_data_ptr = grid.y > 1 ? staging_data.data<accscalar_t>() : nullptr;
int* semaphores_ptr = grid.y > 1 ? semaphores.data<int>() : nullptr;
......@@ -1164,7 +1164,7 @@ at::Tensor batchnorm_forward_c_last_CUDA(
if (input.type().scalarType() == at::ScalarType::Half
&& weight.has_value() && weight.value().type().scalarType() == at::ScalarType::Float) {
using namespace at;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(TypeShim(input.type()), "batchnorm_forward", ([&] {
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_forward", ([&] {
using accscalar_t = at::acc_type<scalar_t, true>;
batchnorm_forward_c_last_kernel<scalar_t, accscalar_t, accscalar_t, ELEMENTS_PER_ITER>
<<<grid, block, 0, stream>>>(
......@@ -1183,7 +1183,7 @@ at::Tensor batchnorm_forward_c_last_CUDA(
"input.type().scalarType() is not supported with weight.type().scalarType()");
}
using namespace at;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(TypeShim(input.type()), "batchnorm_forward", ([&] {
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_forward", ([&] {
using accscalar_t = at::acc_type<scalar_t, true>;
batchnorm_forward_c_last_kernel<scalar_t, accscalar_t, scalar_t, ELEMENTS_PER_ITER>
<<<grid, block, 0, stream>>>(
......@@ -1239,7 +1239,7 @@ std::vector<at::Tensor> reduce_bn_c_last_CUDA(
&& weight.has_value()
&& weight.value().type().scalarType() == at::ScalarType::Float) {
using namespace at;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(TypeShim(input.type()), "batchnorm_backward_reduce", ([&] {
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_backward_reduce", ([&] {
using accscalar_t = at::acc_type<scalar_t, true>;
accscalar_t* staging_data_ptr = grid.y > 1 ? staging_data.data<accscalar_t>() : nullptr;
int* semaphores_ptr = grid.y > 1 ? semaphores.data<int>() : nullptr;
......@@ -1264,7 +1264,7 @@ std::vector<at::Tensor> reduce_bn_c_last_CUDA(
"input.type().scalarType() is not supported with weight.type().scalarType()");
}
using namespace at;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(TypeShim(input.type()), "batchnorm_backward_reduce", ([&] {
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_backward_reduce", ([&] {
using accscalar_t = at::acc_type<scalar_t, true>;
accscalar_t* staging_data_ptr = grid.y > 1 ? staging_data.data<accscalar_t>() : nullptr;
int* semaphores_ptr = grid.y > 1 ? semaphores.data<int>() : nullptr;
......@@ -1310,7 +1310,7 @@ at::Tensor batchnorm_backward_c_last_CUDA(
if (input.type().scalarType() == at::ScalarType::Half
&& weight.has_value() && weight.value().type().scalarType() == at::ScalarType::Float) {
using namespace at;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(TypeShim(input.type()), "batchnorm_forward", ([&] {
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_forward", ([&] {
using accscalar_t = at::acc_type<scalar_t, true>;
batchnorm_backward_c_last_kernel<scalar_t, accscalar_t, accscalar_t, ELEMENTS_PER_ITER>
<<<grid, block, 0, stream>>>(
......@@ -1331,7 +1331,7 @@ at::Tensor batchnorm_backward_c_last_CUDA(
"input.type().scalarType() is not supported with weight.type().scalarType()");
}
using namespace at;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(TypeShim(input.type()), "batchnorm_forward", ([&] {
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_forward", ([&] {
using accscalar_t = at::acc_type<scalar_t, true>;
batchnorm_backward_c_last_kernel<scalar_t, accscalar_t, scalar_t, ELEMENTS_PER_ITER>
<<<grid, block, 0, stream>>>(
......
Markdown is supported
Attach a file by drag & drop or click to upload.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.