Commit 2c8e1c86 authored by Michael Carilli

Anticipating upstream #17996

parent a730f38f
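Each dispatch call site below is wrapped in a TypeShim(...) adapter so the same code can compile both before and after upstream pytorch/pytorch#17996, which changes the AT_DISPATCH_* macros from taking an at::Type to taking an at::ScalarType. The shim itself is not part of this diff; a minimal sketch of the idea, assuming it lives in a small header (e.g. a type_shim.h), might look like:

#include <ATen/ATen.h>

// Hypothetical sketch of the compatibility shim (not shown in this diff):
// wraps a const at::Type& and converts implicitly to whichever form the
// surrounding AT_DISPATCH_* macro expects.
struct TypeShim
{
  const at::Type& payload;
  TypeShim(const at::Type& type) : payload(type) {}
  // Pre-#17996 macros bind a const at::Type&.
  operator const at::Type&() { return payload; }
  // Post-#17996 macros take an at::ScalarType.
  operator at::ScalarType() { return payload.scalarType(); }
};

With this in place, a call such as AT_DISPATCH_FLOATING_TYPES_AND_HALF(TypeShim(g.type()), ...) resolves to the right conversion under either ATen version.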
@@ -96,7 +96,7 @@ void fused_adam_cuda(
        AT_ASSERTM(p.type().scalarType() == at::ScalarType::Float, "expected parameter to be of float type");
        //dispatch is done on the gradient type
        using namespace at; // prevents "toString is undefined" errors
-       AT_DISPATCH_FLOATING_TYPES_AND_HALF(g.type(), "adam_cuda_kernel", ([&] {
+       AT_DISPATCH_FLOATING_TYPES_AND_HALF(TypeShim(g.type()), "adam_cuda_kernel", ([&] {
            using accscalar_t = at::acc_type<scalar_t, true>;
            adam_cuda_kernel<accscalar_t, scalar_t><<<blocks,threadsPerBlock, 0, stream>>>(
                p.data<accscalar_t>(),
@@ -115,7 +115,7 @@ void fused_adam_cuda(
        }));
    } else {
        using namespace at;
-       AT_DISPATCH_FLOATING_TYPES(g.type(), "adam_cuda_kernel", ([&] {
+       AT_DISPATCH_FLOATING_TYPES(TypeShim(g.type()), "adam_cuda_kernel", ([&] {
            adam_cuda_kernel<scalar_t, scalar_t><<<blocks,threadsPerBlock, 0, stream>>>(
                p.data<scalar_t>(),
                NULL, //don't output p_copy for fp32, it's wasted write
...
@@ -678,7 +678,7 @@ void cuda_layer_norm(
    double epsilon)
{
    using namespace at;
-   AT_DISPATCH_FLOATING_TYPES_AND_HALF(input->type(), "layer_norm_cuda_kernel", ([&] {
+   AT_DISPATCH_FLOATING_TYPES_AND_HALF(TypeShim(input->type()), "layer_norm_cuda_kernel", ([&] {
        using accscalar_t = at::acc_type<scalar_t, true>;
        HostApplyLayerNorm(
            output->data<scalar_t>(),
@@ -776,7 +776,7 @@ void cuda_layer_norm_gradient(
    at::Tensor* grad_beta)
{
    using namespace at;
-   AT_DISPATCH_FLOATING_TYPES_AND_HALF(input->type(), "cuComputeGradInput", ([&] {
+   AT_DISPATCH_FLOATING_TYPES_AND_HALF(TypeShim(input->type()), "cuComputeGradInput", ([&] {
        using accscalar_t = at::acc_type<scalar_t, true>;
        HostLayerNormGradient(
            dout->data<scalar_t>(),
...
@@ -98,7 +98,7 @@ void multi_tensor_scale_cuda(
    // If build times suffer, think about where to put this dispatch,
    // and what logic should be moved out of multi_tensor_apply.
-   AT_DISPATCH_FLOATING_TYPES_AND_HALF(tensor_lists[0][0].type(),
+   AT_DISPATCH_FLOATING_TYPES_AND_HALF(TypeShim(tensor_lists[0][0].type()),
        "multi_tensor_scale_cuda",
        [&]
        {
...
@@ -846,7 +846,7 @@ std::vector<at::Tensor> welford_mean_var_CUDA(const at::Tensor input) {
    {
        using namespace at;
-       AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "welford_mean_var_kernel", ([&] {
+       AT_DISPATCH_FLOATING_TYPES_AND_HALF(TypeShim(input.type()), "welford_mean_var_kernel", ([&] {
            using accscalar_t = at::acc_type<scalar_t, true>;
            welford_kernel<scalar_t, accscalar_t, accscalar_t><<<grid, block, 0, stream>>>(
                input.data<scalar_t>(),
@@ -885,7 +885,7 @@ at::Tensor batchnorm_forward_CUDA(
        && weight.has_value() &&
        weight.value().type().scalarType() == at::ScalarType::Float) {
        using namespace at;
-       AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_forward", ([&] {
+       AT_DISPATCH_FLOATING_TYPES_AND_HALF(TypeShim(input.type()), "batchnorm_forward", ([&] {
            using accscalar_t = at::acc_type<scalar_t, true>;
            batchnorm_forward_kernel<scalar_t, accscalar_t, accscalar_t><<<grid, block, 0, stream>>>(
                input.data<scalar_t>(),
@@ -903,7 +903,7 @@ at::Tensor batchnorm_forward_CUDA(
            "input.type().scalarType() is not supported with weight.type().scalarType()");
    }
    using namespace at;
-   AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_forward", ([&] {
+   AT_DISPATCH_FLOATING_TYPES_AND_HALF(TypeShim(input.type()), "batchnorm_forward", ([&] {
        using accscalar_t = at::acc_type<scalar_t, true>;
        batchnorm_forward_kernel<scalar_t, accscalar_t, scalar_t><<<grid, block, 0, stream>>>(
            input.data<scalar_t>(),
@@ -956,7 +956,7 @@ std::vector<at::Tensor> reduce_bn_CUDA(
        && weight.has_value() &&
        weight.value().type().scalarType() == at::ScalarType::Float) {
        using namespace at;
-       AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_backward_reduce", ([&] {
+       AT_DISPATCH_FLOATING_TYPES_AND_HALF(TypeShim(input.type()), "batchnorm_backward_reduce", ([&] {
            using accscalar_t = at::acc_type<scalar_t, true>;
            reduce_bn_kernel<scalar_t, accscalar_t, accscalar_t><<<grid, block, 0, stream>>>(
                input.data<scalar_t>(),
@@ -977,7 +977,7 @@ std::vector<at::Tensor> reduce_bn_CUDA(
            "input.type().scalarType() is not supported with weight.type().scalarType()");
    }
    using namespace at;
-   AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_backward_reduce", ([&] {
+   AT_DISPATCH_FLOATING_TYPES_AND_HALF(TypeShim(input.type()), "batchnorm_backward_reduce", ([&] {
        using accscalar_t = at::acc_type<scalar_t, true>;
        reduce_bn_kernel<scalar_t, accscalar_t, scalar_t><<<grid, block, 0, stream>>>(
            input.data<scalar_t>(),
@@ -1025,7 +1025,7 @@ at::Tensor batchnorm_backward_CUDA(
        && weight.has_value() &&
        weight.value().type().scalarType() == at::ScalarType::Float) {
        using namespace at;
-       AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_backward", ([&] {
+       AT_DISPATCH_FLOATING_TYPES_AND_HALF(TypeShim(input.type()), "batchnorm_backward", ([&] {
            using accscalar_t = at::acc_type<scalar_t, true>;
            batchnorm_backward_kernel<scalar_t, accscalar_t, accscalar_t><<<grid, block, 0, stream>>>(
                grad_output.data<scalar_t>(),
@@ -1045,7 +1045,7 @@ at::Tensor batchnorm_backward_CUDA(
            "input.type().scalarType() is not supported with weight.type().scalarType()");
    }
    using namespace at;
-   AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_backward", ([&] {
+   AT_DISPATCH_FLOATING_TYPES_AND_HALF(TypeShim(input.type()), "batchnorm_backward", ([&] {
        using accscalar_t = at::acc_type<scalar_t, true>;
        batchnorm_backward_kernel<scalar_t, accscalar_t, scalar_t><<<grid, block, 0, stream>>>(
            grad_output.data<scalar_t>(),
@@ -1083,7 +1083,7 @@ std::vector<at::Tensor> welford_parallel_CUDA(const at::Tensor mean_feature_nodes
    {
        using namespace at;
-       AT_DISPATCH_FLOATING_TYPES_AND_HALF(mean_feature_nodes.type(), "welford_parallel_kernel", ([&] {
+       AT_DISPATCH_FLOATING_TYPES_AND_HALF(TypeShim(mean_feature_nodes.type()), "welford_parallel_kernel", ([&] {
            welford_kernel_parallel<scalar_t><<<grid, block, 0, stream>>>(
                mean_feature_nodes.data<scalar_t>(),
                var_biased.data<scalar_t>(),
@@ -1125,7 +1125,7 @@ std::vector<at::Tensor> welford_mean_var_c_last_CUDA(const at::Tensor input) {
    {
        using namespace at;
-       AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "welford_mean_var_c_last", ([&] {
+       AT_DISPATCH_FLOATING_TYPES_AND_HALF(TypeShim(input.type()), "welford_mean_var_c_last", ([&] {
            using accscalar_t = at::acc_type<scalar_t, true>;
            accscalar_t* staging_data_ptr = grid.y > 1 ? staging_data.data<accscalar_t>() : nullptr;
            int* semaphores_ptr = grid.y > 1 ? semaphores.data<int>() : nullptr;
@@ -1164,7 +1164,7 @@ at::Tensor batchnorm_forward_c_last_CUDA(
    if (input.type().scalarType() == at::ScalarType::Half
        && weight.has_value() && weight.value().type().scalarType() == at::ScalarType::Float) {
        using namespace at;
-       AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_forward", ([&] {
+       AT_DISPATCH_FLOATING_TYPES_AND_HALF(TypeShim(input.type()), "batchnorm_forward", ([&] {
            using accscalar_t = at::acc_type<scalar_t, true>;
            batchnorm_forward_c_last_kernel<scalar_t, accscalar_t, accscalar_t, ELEMENTS_PER_ITER>
                <<<grid, block, 0, stream>>>(
@@ -1183,7 +1183,7 @@ at::Tensor batchnorm_forward_c_last_CUDA(
            "input.type().scalarType() is not supported with weight.type().scalarType()");
    }
    using namespace at;
-   AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_forward", ([&] {
+   AT_DISPATCH_FLOATING_TYPES_AND_HALF(TypeShim(input.type()), "batchnorm_forward", ([&] {
        using accscalar_t = at::acc_type<scalar_t, true>;
        batchnorm_forward_c_last_kernel<scalar_t, accscalar_t, scalar_t, ELEMENTS_PER_ITER>
            <<<grid, block, 0, stream>>>(
@@ -1239,7 +1239,7 @@ std::vector<at::Tensor> reduce_bn_c_last_CUDA(
        && weight.has_value()
        && weight.value().type().scalarType() == at::ScalarType::Float) {
        using namespace at;
-       AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_backward_reduce", ([&] {
+       AT_DISPATCH_FLOATING_TYPES_AND_HALF(TypeShim(input.type()), "batchnorm_backward_reduce", ([&] {
            using accscalar_t = at::acc_type<scalar_t, true>;
            accscalar_t* staging_data_ptr = grid.y > 1 ? staging_data.data<accscalar_t>() : nullptr;
            int* semaphores_ptr = grid.y > 1 ? semaphores.data<int>() : nullptr;
@@ -1264,7 +1264,7 @@ std::vector<at::Tensor> reduce_bn_c_last_CUDA(
            "input.type().scalarType() is not supported with weight.type().scalarType()");
    }
    using namespace at;
-   AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_backward_reduce", ([&] {
+   AT_DISPATCH_FLOATING_TYPES_AND_HALF(TypeShim(input.type()), "batchnorm_backward_reduce", ([&] {
        using accscalar_t = at::acc_type<scalar_t, true>;
        accscalar_t* staging_data_ptr = grid.y > 1 ? staging_data.data<accscalar_t>() : nullptr;
        int* semaphores_ptr = grid.y > 1 ? semaphores.data<int>() : nullptr;
@@ -1310,7 +1310,7 @@ at::Tensor batchnorm_backward_c_last_CUDA(
    if (input.type().scalarType() == at::ScalarType::Half
        && weight.has_value() && weight.value().type().scalarType() == at::ScalarType::Float) {
        using namespace at;
-       AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_forward", ([&] {
+       AT_DISPATCH_FLOATING_TYPES_AND_HALF(TypeShim(input.type()), "batchnorm_forward", ([&] {
            using accscalar_t = at::acc_type<scalar_t, true>;
            batchnorm_backward_c_last_kernel<scalar_t, accscalar_t, accscalar_t, ELEMENTS_PER_ITER>
                <<<grid, block, 0, stream>>>(
@@ -1331,7 +1331,7 @@ at::Tensor batchnorm_backward_c_last_CUDA(
            "input.type().scalarType() is not supported with weight.type().scalarType()");
    }
    using namespace at;
-   AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_forward", ([&] {
+   AT_DISPATCH_FLOATING_TYPES_AND_HALF(TypeShim(input.type()), "batchnorm_forward", ([&] {
        using accscalar_t = at::acc_type<scalar_t, true>;
        batchnorm_backward_c_last_kernel<scalar_t, accscalar_t, scalar_t, ELEMENTS_PER_ITER>
            <<<grid, block, 0, stream>>>(
...