Commit 855808f3 authored by ptrblck, committed by mcarilli

Replace type().ScalarType() with scalar_type() (#272)

* change .type().scalarType() to .scalar_type() + at::ScalarType::X to at::kX

* revert scalar_type() to type() for AT_DISPATCH_FLOATING_TYPES_AND_HALF

* revert scalar_type() to type() in AT_DISPATCH_FLOATING_TYPES

* revert scalar_type() to type() for AT_DISPATCH_FLOATING_TYPES_AND_HALF in welford.cu

* revert scalar_type() to type() in layer_norm_cuda_kernel.cu

* revert at::kType  to at::ScalarType::Type

* use DISPATCH_FLOAT_AND_HALF to get rid of warnings

* add dispatch mechanisms for double+float and double+float+half
parent 1c464b48
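Before the diff: the whole change revolves around two spellings of the same query. A minimal sketch (not part of the commit, assuming an ATen version where Tensor::scalar_type() is available) contrasting them:

#include <ATen/ATen.h>

// Deprecated spelling removed by this commit: goes through the at::Type object.
bool is_half_old(const at::Tensor& t) {
  return t.type().scalarType() == at::ScalarType::Half;
}

// Spelling used after this commit; at::kHalf is shorthand for at::ScalarType::Half.
bool is_half_new(const at::Tensor& t) {
  return t.scalar_type() == at::kHalf;
}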
@@ -91,19 +91,19 @@ void fused_adam_cuda(
   }
   cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  if (g.type().scalarType() == at::ScalarType::Half) {
+  if (g.scalar_type() == at::ScalarType::Half) {
     //all other values should be fp32 for half gradients
-    AT_ASSERTM(p.type().scalarType() == at::ScalarType::Float, "expected parameter to be of float type");
+    AT_ASSERTM(p.scalar_type() == at::ScalarType::Float, "expected parameter to be of float type");
     //dispatch is done on the gradient type
     using namespace at; // prevents "toString is undefined" errors
-    AT_DISPATCH_FLOATING_TYPES_AND_HALF(g.type(), "adam_cuda_kernel", ([&] {
-      using accscalar_t = at::acc_type<scalar_t, true>;
-      adam_cuda_kernel<accscalar_t, scalar_t><<<blocks,threadsPerBlock, 0, stream>>>(
+    DISPATCH_FLOAT_AND_HALF(g.scalar_type(), 0, "adam_cuda_kernel",
+      using accscalar_t = at::acc_type<scalar_t_0, true>;
+      adam_cuda_kernel<accscalar_t, scalar_t_0><<<blocks,threadsPerBlock, 0, stream>>>(
         p.data<accscalar_t>(),
-        p_copy.numel() ? p_copy.data<scalar_t>() : NULL,
+        p_copy.numel() ? p_copy.data<scalar_t_0>() : NULL,
         m.data<accscalar_t>(),
         v.data<accscalar_t>(),
-        g.data<scalar_t>(),
+        g.data<scalar_t_0>(),
         beta1,
         beta2,
         eps,
@@ -112,16 +112,16 @@ void fused_adam_cuda(
         tsize,
         (adamMode_t) mode,
         decay);
-    }));
+      )
   } else {
     using namespace at;
-    AT_DISPATCH_FLOATING_TYPES(g.type(), "adam_cuda_kernel", ([&] {
-      adam_cuda_kernel<scalar_t, scalar_t><<<blocks,threadsPerBlock, 0, stream>>>(
-        p.data<scalar_t>(),
+    DISPATCH_DOUBLE_AND_FLOAT(g.scalar_type(), 0, "adam_cuda_kernel",
+      adam_cuda_kernel<scalar_t_0, scalar_t_0><<<blocks,threadsPerBlock, 0, stream>>>(
+        p.data<scalar_t_0>(),
         NULL, //don't output p_copy for fp32, it's wasted write
-        m.data<scalar_t>(),
-        v.data<scalar_t>(),
-        g.data<scalar_t>(),
+        m.data<scalar_t_0>(),
+        v.data<scalar_t_0>(),
+        g.data<scalar_t_0>(),
         beta1,
         beta2,
         eps,
@@ -130,7 +130,7 @@ void fused_adam_cuda(
         tsize,
         (adamMode_t) mode,
         decay);
-    }));
+      );
   }
   THCudaCheck(cudaGetLastError());
......
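For orientation, the Half branch of the DISPATCH_FLOAT_AND_HALF call above behaves roughly like the following sketch. This illustrates the macro pattern defined in type_shim.h further down, not the literal preprocessor output:

switch (g.scalar_type()) {
  case at::ScalarType::Half: {
    using scalar_t_0 = at::Half;                          // LEVEL 0 alias picked by the macro
    using accscalar_t = at::acc_type<scalar_t_0, true>;   // promotes Half to float on CUDA
    adam_cuda_kernel<accscalar_t, scalar_t_0><<<blocks, threadsPerBlock, 0, stream>>>(/* ... */);
    break;
  }
  // a Float case and a default: AT_ERROR(...) follow the same shape
}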
@@ -129,7 +129,7 @@ std::vector<at::Tensor> layer_norm(
   int n1,n2;
   check_args(input,normalized_shape,n1,n2);
   at::Tensor output = at::empty_like(input);
-  at::Tensor mean = at::empty({n1}, input.options().dtype(input.type().scalarType()==at::ScalarType::Half ? at::ScalarType::Float : input.type().scalarType()));
+  at::Tensor mean = at::empty({n1}, input.options().dtype(input.scalar_type()==at::ScalarType::Half ? at::ScalarType::Float : input.scalar_type()));
   at::Tensor invvar = at::empty_like(mean);
   cuda_layer_norm(&output,&mean,&invvar,&input,n1,n2,
       normalized_shape,NULL,NULL,epsilon);
@@ -151,7 +151,7 @@ std::vector<at::Tensor> layer_norm_affine(
   int n1,n2;
   check_args(input,normalized_shape,gamma,beta,n1,n2);
   at::Tensor output = at::empty_like(input);
-  at::Tensor mean = at::empty({n1}, input.options().dtype(input.type().scalarType()==at::ScalarType::Half ? at::ScalarType::Float : input.type().scalarType()));
+  at::Tensor mean = at::empty({n1}, input.options().dtype(input.scalar_type()==at::ScalarType::Half ? at::ScalarType::Float : input.scalar_type()));
   at::Tensor invvar = at::empty_like(mean);
   cuda_layer_norm(&output,&mean,&invvar,&input,n1,n2,
       normalized_shape,&gamma,&beta,epsilon);
......
@@ -685,18 +685,18 @@ void cuda_layer_norm(
     double epsilon)
 {
     using namespace at;
-    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input->type(), "layer_norm_cuda_kernel", ([&] {
-      using accscalar_t = at::acc_type<scalar_t, true>;
+    DISPATCH_DOUBLE_FLOAT_AND_HALF(input->scalar_type(), 0, "layer_norm_cuda_kernel",
+      using accscalar_t = at::acc_type<scalar_t_0, true>;
       HostApplyLayerNorm(
-        output->data<scalar_t>(),
+        output->data<scalar_t_0>(),
         mean->data<accscalar_t>(),
         invvar->data<accscalar_t>(),
-        input->data<scalar_t>(),
+        input->data<scalar_t_0>(),
         n1,n2,
         epsilon,
-        gamma != NULL ? gamma->data<scalar_t>() : NULL,
-        beta != NULL ? beta->data<scalar_t>() : NULL);
-    }));
+        gamma != NULL ? gamma->data<scalar_t_0>() : NULL,
+        beta != NULL ? beta->data<scalar_t_0>() : NULL);
+      )
 }
 template<typename T, typename U>
@@ -725,7 +725,7 @@ void HostLayerNormGradient(
     const int nshared2_a = 2 * sizeof(U) * threads2.y * threads2.y * (threads2.x + 1);
     const int nshared2_b = threads2.x * threads2.y * sizeof(U);
     const int nshared2 = nshared2_a > nshared2_b ? nshared2_a : nshared2_b;
-    at::Tensor part_grad_gamma = at::empty({part_size,n2}, input->options().dtype(input->type().scalarType()==at::ScalarType::Half ? at::ScalarType::Float : input->type().scalarType()));
+    at::Tensor part_grad_gamma = at::empty({part_size,n2}, input->options().dtype(input->scalar_type()==at::ScalarType::Half ? at::ScalarType::Float : input->scalar_type()));
     at::Tensor part_grad_beta = at::empty_like(part_grad_gamma);
     cuComputePartGradGammaBeta<<<blocks2, threads2, nshared2, stream>>>(
       dout,
@@ -787,19 +787,19 @@ void cuda_layer_norm_gradient(
     at::Tensor* grad_beta)
 {
     using namespace at;
-    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input->type(), "cuComputeGradInput", ([&] {
-      using accscalar_t = at::acc_type<scalar_t, true>;
+    DISPATCH_FLOAT_AND_HALF(input->scalar_type(), 0, "cuComputeGradInput",
+      using accscalar_t = at::acc_type<scalar_t_0, true>;
       HostLayerNormGradient(
-        dout->data<scalar_t>(),
+        dout->data<scalar_t_0>(),
         mean->data<accscalar_t>(),
         invvar->data<accscalar_t>(),
         input,
         n1,n2,
-        gamma->data<scalar_t>(),
-        beta->data<scalar_t>(),
+        gamma->data<scalar_t_0>(),
+        beta->data<scalar_t_0>(),
         epsilon,
-        grad_input->data<scalar_t>(),
-        grad_gamma->data<scalar_t>(),
-        grad_beta->data<scalar_t>());
-    }));
+        grad_input->data<scalar_t_0>(),
+        grad_gamma->data<scalar_t_0>(),
+        grad_beta->data<scalar_t_0>());
+      )
 }
@@ -75,7 +75,7 @@ at::Tensor multi_tensor_l2norm_cuda(
   at::Tensor noop_flag,
   std::vector<std::vector<at::Tensor>> tensor_lists)
 {
-  auto output = at::zeros({320}, tensor_lists[0][0].options().dtype(at::ScalarType::Float));
+  auto output = at::zeros({320}, tensor_lists[0][0].options().dtype(at::kFloat));
   DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "multi_tensor_l2norm_cuda",
     multi_tensor_apply<1>(
......
@@ -86,39 +86,15 @@ void multi_tensor_scale_cuda(
   // If build times suffer, think about where to put this dispatch,
   // and what logic should be moved out of multi_tensor_apply.
-  AT_DISPATCH_FLOATING_TYPES_AND_HALF(tensor_lists[0][0].type(),
-    "multi_tensor_scale_cuda",
-    [&]
-    {
-      // using accscalar_t = acc_type<scalar_t, true>;
-      switch(tensor_lists[1][0].scalar_type())
-      {
-        case at::ScalarType::Half:
-          multi_tensor_apply<2>(
-            BLOCK_SIZE,
-            chunk_size,
-            noop_flag,
-            tensor_lists,
-            ScaleFunctor<scalar_t, at::Half>(),
-            scale);
-          break;
-        case at::ScalarType::Float:
-          multi_tensor_apply<2>(
-            BLOCK_SIZE,
-            chunk_size,
-            noop_flag,
-            tensor_lists,
-            ScaleFunctor<scalar_t, float>(),
-            scale);
-          break;
-        default:
-          std::stringstream ss;
-          ss << "multi_tensor_scale_cuda not implemented for output type = "
-             << tensor_lists[1][0].dtype();
-          AT_ERROR(ss.str().c_str());
-      }
-    });
+  DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "multi_tensor_scale_cuda",
+    DISPATCH_FLOAT_AND_HALF(tensor_lists[1][0].scalar_type(), 1, "multi_tensor_scale_cuda",
+      multi_tensor_apply<2>(
+        BLOCK_SIZE,
+        chunk_size,
+        noop_flag,
+        tensor_lists,
+        ScaleFunctor<scalar_t_0, scalar_t_1>(),
+        scale); ))
   AT_CUDA_CHECK(cudaGetLastError());
   // AT_CUDA_CHECK(cudaDeviceSynchronize());
......
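The nested DISPATCH_FLOAT_AND_HALF calls above replace the hand-written switch on the output tensor's dtype. Because the macro splices LEVEL into the alias name, the inner dispatch introduces scalar_t_1 without shadowing the outer scalar_t_0, so all four input/output type combinations are instantiated. A hedged usage sketch; the tensor names and launch_scale_kernel are made up for illustration, only the macro and the scalar_t_N aliases come from type_shim.h:

DISPATCH_FLOAT_AND_HALF(in.scalar_type(), 0, "scale_example",
  DISPATCH_FLOAT_AND_HALF(out.scalar_type(), 1, "scale_example",
    // scalar_t_0 = input element type, scalar_t_1 = output element type,
    // covering float->float, float->half, half->float and half->half.
    launch_scale_kernel<scalar_t_0, scalar_t_1>(in, out, scale); ))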
@@ -3,15 +3,15 @@
 // Forward/backward compatiblity hack around
 // https://github.com/pytorch/pytorch/commit/3aeb78079bcd68282fe9117088e138b77318e288
 // pending more future-proof guidance from upstream.
-struct TypeShim
-{
-  const at::Type& payload;
-  TypeShim(const at::Type& type) : payload(type) {}
-  // Enable trivial conversion to a const at::Type& for pre-3aeb78
-  operator const at::Type&(){ return payload; };
-  // Enable dispatch switch statements to take *this directly for post-3aeb78
-  operator at::ScalarType(){ return payload.scalarType(); };
-};
+// struct TypeShim
+// {
+//   const at::Type& payload;
+//   TypeShim(const at::Type& type) : payload(type) {}
+//   // Enable trivial conversion to a const at::Type& for pre-3aeb78
+//   operator const at::Type&(){ return payload; };
+//   // Enable dispatch switch statements to take *this directly for post-3aeb78
+//   //operator at::ScalarType(){ return payload.; };
+// };
 #define DISPATCH_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \
   switch(TYPE) \
@@ -33,6 +33,52 @@ struct TypeShim
   }
+#define DISPATCH_DOUBLE_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \
+  switch(TYPE) \
+  { \
+    case at::ScalarType::Double: \
+    { \
+      using scalar_t_##LEVEL = double; \
+      __VA_ARGS__; \
+      break; \
+    } \
+    case at::ScalarType::Float: \
+    { \
+      using scalar_t_##LEVEL = float; \
+      __VA_ARGS__; \
+      break; \
+    } \
+    case at::ScalarType::Half: \
+    { \
+      using scalar_t_##LEVEL = at::Half; \
+      __VA_ARGS__; \
+      break; \
+    } \
+    default: \
+      AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
+  }
+
+#define DISPATCH_DOUBLE_AND_FLOAT(TYPE, LEVEL, NAME, ...) \
+  switch(TYPE) \
+  { \
+    case at::ScalarType::Double: \
+    { \
+      using scalar_t_##LEVEL = double; \
+      __VA_ARGS__; \
+      break; \
+    } \
+    case at::ScalarType::Float: \
+    { \
+      using scalar_t_##LEVEL = float; \
+      __VA_ARGS__; \
+      break; \
+    } \
+    default: \
+      AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
+  }
 template<typename T>
 __device__ __forceinline__ T reduce_block_into_lanes
   (T *x,
......
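The two new macros mirror DISPATCH_FLOAT_AND_HALF: TYPE is an at::ScalarType, LEVEL is the numeric suffix stitched onto the scalar_t_ alias, NAME only feeds the AT_ERROR message, and everything after it becomes the dispatched body. A minimal usage sketch, assuming the usual ATen/CUDA includes; the axpy kernel and launch configuration are invented for illustration and are not part of the commit:

template <typename T>
__global__ void axpy_kernel(T* y, const T* x, T alpha, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) y[i] += alpha * x[i];
}

void axpy_example(at::Tensor& y, const at::Tensor& x, double alpha) {
  const int n = x.numel();
  const int block = 256;
  const int grid = (n + block - 1) / block;
  // Double and Float are handled; any other dtype falls into the AT_ERROR default.
  DISPATCH_DOUBLE_AND_FLOAT(x.scalar_type(), 0, "axpy_example",
    axpy_kernel<scalar_t_0><<<grid, block>>>(
      y.data<scalar_t_0>(),
      x.data<scalar_t_0>(),
      static_cast<scalar_t_0>(alpha),
      n); )
}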
@@ -182,14 +182,14 @@ __host__ int get_tensor_spatial_size(const at::Tensor& input)
 // promote accumulation scalar type. promote half to float.
 __host__ at::ScalarType promote_scalartype(const at::Tensor& input)
 {
-  return input.type().scalarType() == at::ScalarType::Half ?
-    at::ScalarType::Float : input.type().scalarType();
+  return input.scalar_type() == at::ScalarType::Half ?
+    at::ScalarType::Float : input.scalar_type();
 }
 // return single element size, optional accumulation type promotion.
 __host__ size_t get_element_data_size(const at::Tensor& input, bool accumulation = false)
 {
-  auto scalar_type = accumulation ? promote_scalartype(input) : input.type().scalarType();
+  auto scalar_type = accumulation ? promote_scalartype(input) : input.scalar_type();
   return at::elementSize(scalar_type);
 }
@@ -846,16 +846,16 @@ std::vector<at::Tensor> welford_mean_var_CUDA(const at::Tensor input) {
   {
     using namespace at;
-    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "welford_mean_var_kernel", ([&] {
-      using accscalar_t = at::acc_type<scalar_t, true>;
-      welford_kernel<scalar_t, accscalar_t, accscalar_t><<<grid, block, 0, stream>>>(
-        input.data<scalar_t>(),
+    DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "welford_mean_var_kernel",
+      using accscalar_t = at::acc_type<scalar_t_0, true>;
+      welford_kernel<scalar_t_0, accscalar_t, accscalar_t><<<grid, block, 0, stream>>>(
+        input.data<scalar_t_0>(),
         out_mean.data<accscalar_t>(),
         out_var_biased.data<accscalar_t>(),
         batch_size,
         feature_size,
         space_size);
-    }));
+      );
   }
   return {out_mean, out_var_biased};
@@ -881,40 +881,40 @@ at::Tensor batchnorm_forward_CUDA(
   const dim3 grid(feature_size, batch_group_size, grid_z);
   auto stream = at::cuda::getCurrentCUDAStream();
-  if (input.type().scalarType() == at::ScalarType::Half
+  if (input.scalar_type() == at::ScalarType::Half
       && weight.has_value() &&
-      weight.value().type().scalarType() == at::ScalarType::Float) {
+      weight.value().scalar_type() == at::ScalarType::Float) {
     using namespace at;
-    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_forward", ([&] {
-      using accscalar_t = at::acc_type<scalar_t, true>;
-      batchnorm_forward_kernel<scalar_t, accscalar_t, accscalar_t><<<grid, block, 0, stream>>>(
-        input.data<scalar_t>(),
+    DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "batchnorm_forward",
+      using accscalar_t = at::acc_type<scalar_t_0, true>;
+      batchnorm_forward_kernel<scalar_t_0, accscalar_t, accscalar_t><<<grid, block, 0, stream>>>(
+        input.data<scalar_t_0>(),
         mean.data<accscalar_t>(),
         inv_std.data<accscalar_t>(),
         weight.has_value() ? weight.value().data<accscalar_t>() : NULL,
         shift.has_value() ? shift.value().data<accscalar_t>() : NULL,
-        out.data<scalar_t>(),
+        out.data<scalar_t_0>(),
         space_size,
         batch_size);
-    }));
+      );
   } else {
     if (weight.has_value()) {
-      AT_CHECK(input.type().scalarType() == weight.value().type().scalarType(),
-        "input.type().scalarType() is not supported with weight.type().scalarType()");
+      AT_CHECK(input.scalar_type() == weight.value().scalar_type(),
+        "input.scalar_type() is not supported with weight.scalar_type()");
     }
     using namespace at;
-    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_forward", ([&] {
-      using accscalar_t = at::acc_type<scalar_t, true>;
-      batchnorm_forward_kernel<scalar_t, accscalar_t, scalar_t><<<grid, block, 0, stream>>>(
-        input.data<scalar_t>(),
+    DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "batchnorm_forward",
+      using accscalar_t = at::acc_type<scalar_t_0, true>;
+      batchnorm_forward_kernel<scalar_t_0, accscalar_t, scalar_t_0><<<grid, block, 0, stream>>>(
+        input.data<scalar_t_0>(),
         mean.data<accscalar_t>(),
         inv_std.data<accscalar_t>(),
-        weight.has_value() ? weight.value().data<scalar_t>() : NULL,
-        shift.has_value() ? shift.value().data<scalar_t>() : NULL,
-        out.data<scalar_t>(),
+        weight.has_value() ? weight.value().data<scalar_t_0>() : NULL,
+        shift.has_value() ? shift.value().data<scalar_t_0>() : NULL,
+        out.data<scalar_t_0>(),
         space_size,
         batch_size);
-    }));
+      );
   }
   return out;
 }
@@ -952,15 +952,15 @@ std::vector<at::Tensor> reduce_bn_CUDA(
   const dim3 grid(feature_size);
   auto stream = at::cuda::getCurrentCUDAStream();
-  if (input.type().scalarType() == at::ScalarType::Half
+  if (input.scalar_type() == at::ScalarType::Half
       && weight.has_value() &&
-      weight.value().type().scalarType() == at::ScalarType::Float) {
+      weight.value().scalar_type() == at::ScalarType::Float) {
     using namespace at;
-    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_backward_reduce", ([&] {
-      using accscalar_t = at::acc_type<scalar_t, true>;
-      reduce_bn_kernel<scalar_t, accscalar_t, accscalar_t><<<grid, block, 0, stream>>>(
-        input.data<scalar_t>(),
-        grad_output.data<scalar_t>(),
+    DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "batchnorm_backward_reduce",
+      using accscalar_t = at::acc_type<scalar_t_0, true>;
+      reduce_bn_kernel<scalar_t_0, accscalar_t, accscalar_t><<<grid, block, 0, stream>>>(
+        input.data<scalar_t_0>(),
+        grad_output.data<scalar_t_0>(),
         mean.data<accscalar_t>(),
         inv_std.data<accscalar_t>(),
         mean_dy.data<accscalar_t>(),
@@ -970,28 +970,28 @@ std::vector<at::Tensor> reduce_bn_CUDA(
         batch_size,
         feature_size,
         space_size);
-    }));
+      );
   } else {
     if (weight.has_value()) {
-      AT_CHECK(input.type().scalarType() == weight.value().type().scalarType(),
-        "input.type().scalarType() is not supported with weight.type().scalarType()");
+      AT_CHECK(input.scalar_type() == weight.value().scalar_type(),
+        "input.scalar_type() is not supported with weight.scalar_type()");
     }
     using namespace at;
-    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_backward_reduce", ([&] {
-      using accscalar_t = at::acc_type<scalar_t, true>;
-      reduce_bn_kernel<scalar_t, accscalar_t, scalar_t><<<grid, block, 0, stream>>>(
-        input.data<scalar_t>(),
-        grad_output.data<scalar_t>(),
+    DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "batchnorm_backward_reduce",
+      using accscalar_t = at::acc_type<scalar_t_0, true>;
+      reduce_bn_kernel<scalar_t_0, accscalar_t, scalar_t_0><<<grid, block, 0, stream>>>(
+        input.data<scalar_t_0>(),
+        grad_output.data<scalar_t_0>(),
         mean.data<accscalar_t>(),
         inv_std.data<accscalar_t>(),
         mean_dy.data<accscalar_t>(),
         mean_dy_xmu.data<accscalar_t>(),
-        weight.has_value() ? grad_weight.data<scalar_t>() : NULL,
-        weight.has_value() ? grad_bias.data<scalar_t>() : NULL,
+        weight.has_value() ? grad_weight.data<scalar_t_0>() : NULL,
+        weight.has_value() ? grad_bias.data<scalar_t_0>() : NULL,
         batch_size,
         feature_size,
         space_size);
-    }));
+      );
   }
   return {mean_dy, mean_dy_xmu, grad_weight, grad_bias};
@@ -1021,44 +1021,44 @@ at::Tensor batchnorm_backward_CUDA(
   auto stream = at::cuda::getCurrentCUDAStream();
-  if (input.type().scalarType() == at::ScalarType::Half
+  if (input.scalar_type() == at::ScalarType::Half
       && weight.has_value() &&
-      weight.value().type().scalarType() == at::ScalarType::Float) {
+      weight.value().scalar_type() == at::ScalarType::Float) {
     using namespace at;
-    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_backward", ([&] {
-      using accscalar_t = at::acc_type<scalar_t, true>;
-      batchnorm_backward_kernel<scalar_t, accscalar_t, accscalar_t><<<grid, block, 0, stream>>>(
-        grad_output.data<scalar_t>(),
-        input.data<scalar_t>(),
+    DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "batchnorm_backward",
+      using accscalar_t = at::acc_type<scalar_t_0, true>;
+      batchnorm_backward_kernel<scalar_t_0, accscalar_t, accscalar_t><<<grid, block, 0, stream>>>(
+        grad_output.data<scalar_t_0>(),
+        input.data<scalar_t_0>(),
         mean.data<accscalar_t>(),
         inv_std.data<accscalar_t>(),
         weight.has_value() ? weight.value().data<accscalar_t>() : NULL,
         mean_dy.data<accscalar_t>(),
         mean_dy_xmu.data<accscalar_t>(),
-        grad_input.data<scalar_t>(),
+        grad_input.data<scalar_t_0>(),
         space_size,
         batch_size);
-    }));
+      );
   } else {
     if (weight.has_value()) {
-      AT_CHECK(input.type().scalarType() == weight.value().type().scalarType(),
-        "input.type().scalarType() is not supported with weight.type().scalarType()");
+      AT_CHECK(input.scalar_type() == weight.value().scalar_type(),
+        "input.scalar_type() is not supported with weight.scalar_type()");
     }
     using namespace at;
-    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_backward", ([&] {
-      using accscalar_t = at::acc_type<scalar_t, true>;
-      batchnorm_backward_kernel<scalar_t, accscalar_t, scalar_t><<<grid, block, 0, stream>>>(
-        grad_output.data<scalar_t>(),
-        input.data<scalar_t>(),
+    DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "batchnorm_backward",
+      using accscalar_t = at::acc_type<scalar_t_0, true>;
+      batchnorm_backward_kernel<scalar_t_0, accscalar_t, scalar_t_0><<<grid, block, 0, stream>>>(
+        grad_output.data<scalar_t_0>(),
+        input.data<scalar_t_0>(),
         mean.data<accscalar_t>(),
         inv_std.data<accscalar_t>(),
-        weight.has_value() ? weight.value().data<scalar_t>() : NULL,
+        weight.has_value() ? weight.value().data<scalar_t_0>() : NULL,
         mean_dy.data<accscalar_t>(),
         mean_dy_xmu.data<accscalar_t>(),
-        grad_input.data<scalar_t>(),
+        grad_input.data<scalar_t_0>(),
         space_size,
         batch_size);
-    }));
+      );
   }
   return grad_input;
@@ -1083,18 +1083,18 @@ std::vector<at::Tensor> welford_parallel_CUDA(const at::Tensor mean_feature_node
   {
     using namespace at;
-    AT_DISPATCH_FLOATING_TYPES_AND_HALF(mean_feature_nodes.type(), "welford_parallel_kernel", ([&] {
-      welford_kernel_parallel<scalar_t><<<grid, block, 0, stream>>>(
-        mean_feature_nodes.data<scalar_t>(),
-        var_biased.data<scalar_t>(),
-        out_mean.data<scalar_t>(),
-        out_var.data<scalar_t>(),
-        inv_std.data<scalar_t>(),
+    DISPATCH_FLOAT_AND_HALF(mean_feature_nodes.scalar_type(), 0, "welford_parallel_kernel",
+      welford_kernel_parallel<scalar_t_0><<<grid, block, 0, stream>>>(
+        mean_feature_nodes.data<scalar_t_0>(),
+        var_biased.data<scalar_t_0>(),
+        out_mean.data<scalar_t_0>(),
+        out_var.data<scalar_t_0>(),
+        inv_std.data<scalar_t_0>(),
         world_size,
         feature_size,
         eps,
         numel);
-    }));
+      );
   }
   return {out_mean, out_var, inv_std};
@@ -1118,27 +1118,27 @@ std::vector<at::Tensor> welford_mean_var_c_last_CUDA(const at::Tensor input) {
   at::Tensor semaphores;
   if (grid.y > 1) {
     staging_data = at::empty({4*stride*grid.y}, option);
-    semaphores = at::zeros({grid.x}, input.options().dtype(at::ScalarType::Int));
+    semaphores = at::zeros({grid.x}, input.options().dtype(at::kInt));
   }
   auto stream = at::cuda::getCurrentCUDAStream();
   {
     using namespace at;
-    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "welford_mean_var_c_last", ([&] {
-      using accscalar_t = at::acc_type<scalar_t, true>;
+    DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "welford_mean_var_c_last",
+      using accscalar_t = at::acc_type<scalar_t_0, true>;
       accscalar_t* staging_data_ptr = grid.y > 1 ? staging_data.data<accscalar_t>() : nullptr;
       int* semaphores_ptr = grid.y > 1 ? semaphores.data<int>() : nullptr;
-      welford_kernel_c_last<scalar_t, accscalar_t, accscalar_t, ELEMENTS_PER_ITER>
+      welford_kernel_c_last<scalar_t_0, accscalar_t, accscalar_t, ELEMENTS_PER_ITER>
         <<<grid, block, 0, stream>>>(
-        input.data<scalar_t>(),
+        input.data<scalar_t_0>(),
         out_mean.data<accscalar_t>(),
         out_var_biased.data<accscalar_t>(),
         staging_data_ptr,
         semaphores_ptr,
         reduction_size,
         stride);
-    }));
+      );
   }
   return {out_mean, out_var_biased};
@@ -1161,41 +1161,41 @@ at::Tensor batchnorm_forward_c_last_CUDA(
   auto stream = at::cuda::getCurrentCUDAStream();
-  if (input.type().scalarType() == at::ScalarType::Half
-      && weight.has_value() && weight.value().type().scalarType() == at::ScalarType::Float) {
+  if (input.scalar_type() == at::ScalarType::Half
+      && weight.has_value() && weight.value().scalar_type() == at::ScalarType::Float) {
     using namespace at;
-    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_forward", ([&] {
-      using accscalar_t = at::acc_type<scalar_t, true>;
-      batchnorm_forward_c_last_kernel<scalar_t, accscalar_t, accscalar_t, ELEMENTS_PER_ITER>
+    DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "batchnorm_forward",
+      using accscalar_t = at::acc_type<scalar_t_0, true>;
+      batchnorm_forward_c_last_kernel<scalar_t_0, accscalar_t, accscalar_t, ELEMENTS_PER_ITER>
         <<<grid, block, 0, stream>>>(
-        input.data<scalar_t>(),
+        input.data<scalar_t_0>(),
         mean.data<accscalar_t>(),
         inv_std.data<accscalar_t>(),
         weight.has_value() ? weight.value().data<accscalar_t>() : NULL,
         shift.has_value() ? shift.value().data<accscalar_t>(): NULL,
-        out.data<scalar_t>(),
+        out.data<scalar_t_0>(),
         reduction_size,
         stride);
-    }));
+      );
   } else {
     if (weight.has_value()) {
-      AT_CHECK(input.type().scalarType() == weight.value().type().scalarType(),
-        "input.type().scalarType() is not supported with weight.type().scalarType()");
+      AT_CHECK(input.scalar_type() == weight.value().scalar_type(),
+        "input.scalar_type() is not supported with weight.scalar_type()");
     }
     using namespace at;
-    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_forward", ([&] {
-      using accscalar_t = at::acc_type<scalar_t, true>;
-      batchnorm_forward_c_last_kernel<scalar_t, accscalar_t, scalar_t, ELEMENTS_PER_ITER>
+    DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "batchnorm_forward",
+      using accscalar_t = at::acc_type<scalar_t_0, true>;
+      batchnorm_forward_c_last_kernel<scalar_t_0, accscalar_t, scalar_t_0, ELEMENTS_PER_ITER>
        <<<grid, block, 0, stream>>>(
-        input.data<scalar_t>(),
+        input.data<scalar_t_0>(),
         mean.data<accscalar_t>(),
         inv_std.data<accscalar_t>(),
-        weight.has_value() ? weight.value().data<scalar_t>() : NULL,
-        shift.has_value() ? shift.value().data<scalar_t>(): NULL,
-        out.data<scalar_t>(),
+        weight.has_value() ? weight.value().data<scalar_t_0>() : NULL,
+        shift.has_value() ? shift.value().data<scalar_t_0>(): NULL,
+        out.data<scalar_t_0>(),
         reduction_size,
         stride);
-    }));
+      );
   }
   return out;
 }
@@ -1231,22 +1231,22 @@ std::vector<at::Tensor> reduce_bn_c_last_CUDA(
   at::Tensor semaphores;
   if (grid.y > 1) {
     staging_data = at::empty({2*stride*grid.y}, mean.options());
-    semaphores = at::zeros({grid.x}, input.options().dtype(at::ScalarType::Int));
+    semaphores = at::zeros({grid.x}, input.options().dtype(at::kInt));
   }
   auto stream = at::cuda::getCurrentCUDAStream();
-  if (input.type().scalarType() == at::ScalarType::Half
+  if (input.scalar_type() == at::ScalarType::Half
       && weight.has_value()
-      && weight.value().type().scalarType() == at::ScalarType::Float) {
+      && weight.value().scalar_type() == at::ScalarType::Float) {
     using namespace at;
-    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_backward_reduce", ([&] {
-      using accscalar_t = at::acc_type<scalar_t, true>;
+    DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "batchnorm_backward_reduce",
+      using accscalar_t = at::acc_type<scalar_t_0, true>;
       accscalar_t* staging_data_ptr = grid.y > 1 ? staging_data.data<accscalar_t>() : nullptr;
       int* semaphores_ptr = grid.y > 1 ? semaphores.data<int>() : nullptr;
-      reduce_bn_c_last_kernel<scalar_t, accscalar_t, accscalar_t, ELEMENTS_PER_ITER>
+      reduce_bn_c_last_kernel<scalar_t_0, accscalar_t, accscalar_t, ELEMENTS_PER_ITER>
         <<<grid, block, 0, stream>>>(
-        input.data<scalar_t>(),
-        grad_output.data<scalar_t>(),
+        input.data<scalar_t_0>(),
+        grad_output.data<scalar_t_0>(),
         mean.data<accscalar_t>(),
         inv_std.data<accscalar_t>(),
         mean_dy.data<accscalar_t>(),
@@ -1257,32 +1257,32 @@ std::vector<at::Tensor> reduce_bn_c_last_CUDA(
         semaphores_ptr,
         reduction_size,
         stride);
-    }));
+      );
   } else {
     if (weight.has_value()) {
-      AT_CHECK(input.type().scalarType() == weight.value().type().scalarType(),
-        "input.type().scalarType() is not supported with weight.type().scalarType()");
+      AT_CHECK(input.scalar_type() == weight.value().scalar_type(),
+        "input.scalar_type() is not supported with weight.scalar_type()");
     }
     using namespace at;
-    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_backward_reduce", ([&] {
-      using accscalar_t = at::acc_type<scalar_t, true>;
+    DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "batchnorm_backward_reduce",
+      using accscalar_t = at::acc_type<scalar_t_0, true>;
       accscalar_t* staging_data_ptr = grid.y > 1 ? staging_data.data<accscalar_t>() : nullptr;
       int* semaphores_ptr = grid.y > 1 ? semaphores.data<int>() : nullptr;
-      reduce_bn_c_last_kernel<scalar_t, accscalar_t, scalar_t, ELEMENTS_PER_ITER>
+      reduce_bn_c_last_kernel<scalar_t_0, accscalar_t, scalar_t_0, ELEMENTS_PER_ITER>
         <<<grid, block, 0, stream>>>(
-        input.data<scalar_t>(),
-        grad_output.data<scalar_t>(),
+        input.data<scalar_t_0>(),
+        grad_output.data<scalar_t_0>(),
         mean.data<accscalar_t>(),
         inv_std.data<accscalar_t>(),
         mean_dy.data<accscalar_t>(),
         mean_dy_xmu.data<accscalar_t>(),
-        weight.has_value() ? grad_weight.data<scalar_t>() : NULL,
-        weight.has_value() ?grad_bias.data<scalar_t>() : NULL,
+        weight.has_value() ? grad_weight.data<scalar_t_0>() : NULL,
+        weight.has_value() ?grad_bias.data<scalar_t_0>() : NULL,
         staging_data_ptr,
         semaphores_ptr,
         reduction_size,
         stride);
-    }));
+      );
   }
   return {mean_dy, mean_dy_xmu, grad_weight, grad_bias};
@@ -1307,45 +1307,45 @@ at::Tensor batchnorm_backward_c_last_CUDA(
   auto stream = at::cuda::getCurrentCUDAStream();
-  if (input.type().scalarType() == at::ScalarType::Half
-      && weight.has_value() && weight.value().type().scalarType() == at::ScalarType::Float) {
+  if (input.scalar_type() == at::ScalarType::Half
+      && weight.has_value() && weight.value().scalar_type() == at::ScalarType::Float) {
     using namespace at;
-    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_forward", ([&] {
-      using accscalar_t = at::acc_type<scalar_t, true>;
-      batchnorm_backward_c_last_kernel<scalar_t, accscalar_t, accscalar_t, ELEMENTS_PER_ITER>
+    DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "batchnorm_forward",
+      using accscalar_t = at::acc_type<scalar_t_0, true>;
+      batchnorm_backward_c_last_kernel<scalar_t_0, accscalar_t, accscalar_t, ELEMENTS_PER_ITER>
         <<<grid, block, 0, stream>>>(
-        grad_output.data<scalar_t>(),
-        input.data<scalar_t>(),
+        grad_output.data<scalar_t_0>(),
+        input.data<scalar_t_0>(),
         mean.data<accscalar_t>(),
         inv_std.data<accscalar_t>(),
         weight.has_value() ? weight.value().data<accscalar_t>() : NULL,
         mean_dy.data<accscalar_t>(),
         mean_dy_xmu.data<accscalar_t>(),
-        grad_input.data<scalar_t>(),
+        grad_input.data<scalar_t_0>(),
         reduction_size,
         stride);
-    }));
+      );
   } else {
     if (weight.has_value()) {
-      AT_CHECK(input.type().scalarType() == weight.value().type().scalarType(),
-        "input.type().scalarType() is not supported with weight.type().scalarType()");
+      AT_CHECK(input.scalar_type() == weight.value().scalar_type(),
+        "input.scalar_type() is not supported with weight.scalar_type()");
     }
     using namespace at;
-    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_forward", ([&] {
-      using accscalar_t = at::acc_type<scalar_t, true>;
-      batchnorm_backward_c_last_kernel<scalar_t, accscalar_t, scalar_t, ELEMENTS_PER_ITER>
+    DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "batchnorm_forward",
+      using accscalar_t = at::acc_type<scalar_t_0, true>;
+      batchnorm_backward_c_last_kernel<scalar_t_0, accscalar_t, scalar_t_0, ELEMENTS_PER_ITER>
        <<<grid, block, 0, stream>>>(
-        grad_output.data<scalar_t>(),
-        input.data<scalar_t>(),
+        grad_output.data<scalar_t_0>(),
+        input.data<scalar_t_0>(),
         mean.data<accscalar_t>(),
         inv_std.data<accscalar_t>(),
-        weight.has_value() ? weight.value().data<scalar_t>() : NULL,
+        weight.has_value() ? weight.value().data<scalar_t_0>() : NULL,
         mean_dy.data<accscalar_t>(),
         mean_dy_xmu.data<accscalar_t>(),
-        grad_input.data<scalar_t>(),
+        grad_input.data<scalar_t_0>(),
         reduction_size,
         stride);
-    }));
+      );
   }
   return grad_input;
......