Commit 855808f3 authored by ptrblck, committed by mcarilli

Replace type().ScalarType() with scalar_type() (#272)

* change .type().ScalarType() to .scalar_type() + at::ScalarType::X to at::kX

* revert scalar_type() to type() for AT_DISPATCH_FLOATING_TYPES_AND_HALF

* revert scalar_type() to type() in AT_DISPATCH_FLOATING_TYPES

* revert scalar_type() to type() for AT_DISPATCH_FLOATING_TYPES_AND_HALF in welford.cu

* revert scalar_type() to type() in layer_norm_cuda_kernel.cu

* revert at::kType  to at::ScalarType::Type

* use DISPATCH_FLOAT_AND_HALF to get rid of warnings

* add dispatch mechanisms for double+float and double+float+half
parent 1c464b48
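
For reference, a minimal sketch (not part of this commit) of how the new dispatch macros are meant to be used: the LEVEL argument is appended to the typedef the macro introduces (scalar_t_0, scalar_t_1, ...), so dispatches over two tensors of potentially different types can be nested without the typedefs shadowing each other, as multi_tensor_scale_cuda does below. The tensor names and the cast_copy_impl helper here are hypothetical.

// Hedged example, assuming the type_shim.h added in this commit is included.
#include <ATen/ATen.h>
#include "type_shim.h"

// Hypothetical kernel launcher templated on input and output element types.
template <typename in_t, typename out_t>
void cast_copy_impl(const at::Tensor& in, at::Tensor& out);

void cast_copy(const at::Tensor& in, at::Tensor& out) {
  // LEVEL 0: the macro defines `using scalar_t_0 = float;` or `at::Half`.
  DISPATCH_FLOAT_AND_HALF(in.scalar_type(), 0, "cast_copy",
    // LEVEL 1: the nested macro defines scalar_t_1 for the output tensor.
    DISPATCH_FLOAT_AND_HALF(out.scalar_type(), 1, "cast_copy",
      cast_copy_impl<scalar_t_0, scalar_t_1>(in, out); ))
}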
@@ -91,19 +91,19 @@ void fused_adam_cuda(
}
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
if (g.type().scalarType() == at::ScalarType::Half) {
if (g.scalar_type() == at::ScalarType::Half) {
//all other values should be fp32 for half gradients
AT_ASSERTM(p.type().scalarType() == at::ScalarType::Float, "expected parameter to be of float type");
AT_ASSERTM(p.scalar_type() == at::ScalarType::Float, "expected parameter to be of float type");
//dispatch is done on the gradient type
using namespace at; // prevents "toString is undefined" errors
AT_DISPATCH_FLOATING_TYPES_AND_HALF(g.type(), "adam_cuda_kernel", ([&] {
using accscalar_t = at::acc_type<scalar_t, true>;
adam_cuda_kernel<accscalar_t, scalar_t><<<blocks,threadsPerBlock, 0, stream>>>(
DISPATCH_FLOAT_AND_HALF(g.scalar_type(), 0, "adam_cuda_kernel",
using accscalar_t = at::acc_type<scalar_t_0, true>;
adam_cuda_kernel<accscalar_t, scalar_t_0><<<blocks,threadsPerBlock, 0, stream>>>(
p.data<accscalar_t>(),
p_copy.numel() ? p_copy.data<scalar_t>() : NULL,
p_copy.numel() ? p_copy.data<scalar_t_0>() : NULL,
m.data<accscalar_t>(),
v.data<accscalar_t>(),
g.data<scalar_t>(),
g.data<scalar_t_0>(),
beta1,
beta2,
eps,
@@ -112,16 +112,16 @@ void fused_adam_cuda(
tsize,
(adamMode_t) mode,
decay);
}));
)
} else {
using namespace at;
AT_DISPATCH_FLOATING_TYPES(g.type(), "adam_cuda_kernel", ([&] {
adam_cuda_kernel<scalar_t, scalar_t><<<blocks,threadsPerBlock, 0, stream>>>(
p.data<scalar_t>(),
DISPATCH_DOUBLE_AND_FLOAT(g.scalar_type(), 0, "adam_cuda_kernel",
adam_cuda_kernel<scalar_t_0, scalar_t_0><<<blocks,threadsPerBlock, 0, stream>>>(
p.data<scalar_t_0>(),
NULL, //don't output p_copy for fp32, it's wasted write
m.data<scalar_t>(),
v.data<scalar_t>(),
g.data<scalar_t>(),
m.data<scalar_t_0>(),
v.data<scalar_t_0>(),
g.data<scalar_t_0>(),
beta1,
beta2,
eps,
@@ -130,7 +130,7 @@ void fused_adam_cuda(
tsize,
(adamMode_t) mode,
decay);
}));
);
}
THCudaCheck(cudaGetLastError());
@@ -129,7 +129,7 @@ std::vector<at::Tensor> layer_norm(
int n1,n2;
check_args(input,normalized_shape,n1,n2);
at::Tensor output = at::empty_like(input);
at::Tensor mean = at::empty({n1}, input.options().dtype(input.type().scalarType()==at::ScalarType::Half ? at::ScalarType::Float : input.type().scalarType()));
at::Tensor mean = at::empty({n1}, input.options().dtype(input.scalar_type()==at::ScalarType::Half ? at::ScalarType::Float : input.scalar_type()));
at::Tensor invvar = at::empty_like(mean);
cuda_layer_norm(&output,&mean,&invvar,&input,n1,n2,
normalized_shape,NULL,NULL,epsilon);
@@ -151,7 +151,7 @@ std::vector<at::Tensor> layer_norm_affine(
int n1,n2;
check_args(input,normalized_shape,gamma,beta,n1,n2);
at::Tensor output = at::empty_like(input);
at::Tensor mean = at::empty({n1}, input.options().dtype(input.type().scalarType()==at::ScalarType::Half ? at::ScalarType::Float : input.type().scalarType()));
at::Tensor mean = at::empty({n1}, input.options().dtype(input.scalar_type()==at::ScalarType::Half ? at::ScalarType::Float : input.scalar_type()));
at::Tensor invvar = at::empty_like(mean);
cuda_layer_norm(&output,&mean,&invvar,&input,n1,n2,
normalized_shape,&gamma,&beta,epsilon);
@@ -685,18 +685,18 @@ void cuda_layer_norm(
double epsilon)
{
using namespace at;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input->type(), "layer_norm_cuda_kernel", ([&] {
using accscalar_t = at::acc_type<scalar_t, true>;
DISPATCH_DOUBLE_FLOAT_AND_HALF(input->scalar_type(), 0, "layer_norm_cuda_kernel",
using accscalar_t = at::acc_type<scalar_t_0, true>;
HostApplyLayerNorm(
output->data<scalar_t>(),
output->data<scalar_t_0>(),
mean->data<accscalar_t>(),
invvar->data<accscalar_t>(),
input->data<scalar_t>(),
input->data<scalar_t_0>(),
n1,n2,
epsilon,
gamma != NULL ? gamma->data<scalar_t>() : NULL,
beta != NULL ? beta->data<scalar_t>() : NULL);
}));
gamma != NULL ? gamma->data<scalar_t_0>() : NULL,
beta != NULL ? beta->data<scalar_t_0>() : NULL);
)
}
template<typename T, typename U>
@@ -725,7 +725,7 @@ void HostLayerNormGradient(
const int nshared2_a = 2 * sizeof(U) * threads2.y * threads2.y * (threads2.x + 1);
const int nshared2_b = threads2.x * threads2.y * sizeof(U);
const int nshared2 = nshared2_a > nshared2_b ? nshared2_a : nshared2_b;
at::Tensor part_grad_gamma = at::empty({part_size,n2}, input->options().dtype(input->type().scalarType()==at::ScalarType::Half ? at::ScalarType::Float : input->type().scalarType()));
at::Tensor part_grad_gamma = at::empty({part_size,n2}, input->options().dtype(input->scalar_type()==at::ScalarType::Half ? at::ScalarType::Float : input->scalar_type()));
at::Tensor part_grad_beta = at::empty_like(part_grad_gamma);
cuComputePartGradGammaBeta<<<blocks2, threads2, nshared2, stream>>>(
dout,
@@ -787,19 +787,19 @@ void cuda_layer_norm_gradient(
at::Tensor* grad_beta)
{
using namespace at;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input->type(), "cuComputeGradInput", ([&] {
using accscalar_t = at::acc_type<scalar_t, true>;
DISPATCH_FLOAT_AND_HALF(input->scalar_type(), 0, "cuComputeGradInput",
using accscalar_t = at::acc_type<scalar_t_0, true>;
HostLayerNormGradient(
dout->data<scalar_t>(),
dout->data<scalar_t_0>(),
mean->data<accscalar_t>(),
invvar->data<accscalar_t>(),
input,
n1,n2,
gamma->data<scalar_t>(),
beta->data<scalar_t>(),
gamma->data<scalar_t_0>(),
beta->data<scalar_t_0>(),
epsilon,
grad_input->data<scalar_t>(),
grad_gamma->data<scalar_t>(),
grad_beta->data<scalar_t>());
}));
grad_input->data<scalar_t_0>(),
grad_gamma->data<scalar_t_0>(),
grad_beta->data<scalar_t_0>());
)
}
@@ -75,7 +75,7 @@ at::Tensor multi_tensor_l2norm_cuda(
at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists)
{
auto output = at::zeros({320}, tensor_lists[0][0].options().dtype(at::ScalarType::Float));
auto output = at::zeros({320}, tensor_lists[0][0].options().dtype(at::kFloat));
DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "multi_tensor_l2norm_cuda",
multi_tensor_apply<1>(
@@ -86,39 +86,15 @@ void multi_tensor_scale_cuda(
// If build times suffer, think about where to put this dispatch,
// and what logic should be moved out of multi_tensor_apply.
AT_DISPATCH_FLOATING_TYPES_AND_HALF(tensor_lists[0][0].type(),
"multi_tensor_scale_cuda",
[&]
{
// using accscalar_t = acc_type<scalar_t, true>;
switch(tensor_lists[1][0].scalar_type())
{
case at::ScalarType::Half:
multi_tensor_apply<2>(
BLOCK_SIZE,
chunk_size,
noop_flag,
tensor_lists,
ScaleFunctor<scalar_t, at::Half>(),
scale);
break;
case at::ScalarType::Float:
multi_tensor_apply<2>(
BLOCK_SIZE,
chunk_size,
noop_flag,
tensor_lists,
ScaleFunctor<scalar_t, float>(),
scale);
break;
default:
std::stringstream ss;
ss << "multi_tensor_scale_cuda not implemented for output type = "
<< tensor_lists[1][0].dtype();
AT_ERROR(ss.str().c_str());
}
});
DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "multi_tensor_scale_cuda",
DISPATCH_FLOAT_AND_HALF(tensor_lists[1][0].scalar_type(), 1, "multi_tensor_scale_cuda",
multi_tensor_apply<2>(
BLOCK_SIZE,
chunk_size,
noop_flag,
tensor_lists,
ScaleFunctor<scalar_t_0, scalar_t_1>(),
scale); ))
AT_CUDA_CHECK(cudaGetLastError());
// AT_CUDA_CHECK(cudaDeviceSynchronize());
@@ -3,15 +3,15 @@
// Forward/backward compatiblity hack around
// https://github.com/pytorch/pytorch/commit/3aeb78079bcd68282fe9117088e138b77318e288
// pending more future-proof guidance from upstream.
struct TypeShim
{
const at::Type& payload;
TypeShim(const at::Type& type) : payload(type) {}
// Enable trivial conversion to a const at::Type& for pre-3aeb78
operator const at::Type&(){ return payload; };
// Enable dispatch switch statements to take *this directly for post-3aeb78
operator at::ScalarType(){ return payload.scalarType(); };
};
// struct TypeShim
// {
// const at::Type& payload;
// TypeShim(const at::Type& type) : payload(type) {}
// // Enable trivial conversion to a const at::Type& for pre-3aeb78
// operator const at::Type&(){ return payload; };
// // Enable dispatch switch statements to take *this directly for post-3aeb78
// //operator at::ScalarType(){ return payload.; };
// };
#define DISPATCH_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \
switch(TYPE) \
@@ -33,6 +33,52 @@ struct TypeShim
}
#define DISPATCH_DOUBLE_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \
switch(TYPE) \
{ \
case at::ScalarType::Double: \
{ \
using scalar_t_##LEVEL = double; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Float: \
{ \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: \
{ \
using scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_DOUBLE_AND_FLOAT(TYPE, LEVEL, NAME, ...) \
switch(TYPE) \
{ \
case at::ScalarType::Double: \
{ \
using scalar_t_##LEVEL = double; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Float: \
{ \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
template<typename T>
__device__ __forceinline__ T reduce_block_into_lanes
(T *x,
@@ -182,14 +182,14 @@ __host__ int get_tensor_spatial_size(const at::Tensor& input)
// promote accumulation scalar type. promote half to float.
__host__ at::ScalarType promote_scalartype(const at::Tensor& input)
{
return input.type().scalarType() == at::ScalarType::Half ?
at::ScalarType::Float : input.type().scalarType();
return input.scalar_type() == at::ScalarType::Half ?
at::ScalarType::Float : input.scalar_type();
}
// return single element size, optional accumulation type promotion.
__host__ size_t get_element_data_size(const at::Tensor& input, bool accumulation = false)
{
auto scalar_type = accumulation ? promote_scalartype(input) : input.type().scalarType();
auto scalar_type = accumulation ? promote_scalartype(input) : input.scalar_type();
return at::elementSize(scalar_type);
}
@@ -846,16 +846,16 @@ std::vector<at::Tensor> welford_mean_var_CUDA(const at::Tensor input) {
{
using namespace at;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "welford_mean_var_kernel", ([&] {
using accscalar_t = at::acc_type<scalar_t, true>;
welford_kernel<scalar_t, accscalar_t, accscalar_t><<<grid, block, 0, stream>>>(
input.data<scalar_t>(),
DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "welford_mean_var_kernel",
using accscalar_t = at::acc_type<scalar_t_0, true>;
welford_kernel<scalar_t_0, accscalar_t, accscalar_t><<<grid, block, 0, stream>>>(
input.data<scalar_t_0>(),
out_mean.data<accscalar_t>(),
out_var_biased.data<accscalar_t>(),
batch_size,
feature_size,
space_size);
}));
);
}
return {out_mean, out_var_biased};
@@ -881,40 +881,40 @@ at::Tensor batchnorm_forward_CUDA(
const dim3 grid(feature_size, batch_group_size, grid_z);
auto stream = at::cuda::getCurrentCUDAStream();
if (input.type().scalarType() == at::ScalarType::Half
if (input.scalar_type() == at::ScalarType::Half
&& weight.has_value() &&
weight.value().type().scalarType() == at::ScalarType::Float) {
weight.value().scalar_type() == at::ScalarType::Float) {
using namespace at;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_forward", ([&] {
using accscalar_t = at::acc_type<scalar_t, true>;
batchnorm_forward_kernel<scalar_t, accscalar_t, accscalar_t><<<grid, block, 0, stream>>>(
input.data<scalar_t>(),
DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "batchnorm_forward",
using accscalar_t = at::acc_type<scalar_t_0, true>;
batchnorm_forward_kernel<scalar_t_0, accscalar_t, accscalar_t><<<grid, block, 0, stream>>>(
input.data<scalar_t_0>(),
mean.data<accscalar_t>(),
inv_std.data<accscalar_t>(),
weight.has_value() ? weight.value().data<accscalar_t>() : NULL,
shift.has_value() ? shift.value().data<accscalar_t>() : NULL,
out.data<scalar_t>(),
out.data<scalar_t_0>(),
space_size,
batch_size);
}));
);
} else {
if (weight.has_value()) {
AT_CHECK(input.type().scalarType() == weight.value().type().scalarType(),
"input.type().scalarType() is not supported with weight.type().scalarType()");
AT_CHECK(input.scalar_type() == weight.value().scalar_type(),
"input.scalar_type() is not supported with weight.scalar_type()");
}
using namespace at;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_forward", ([&] {
using accscalar_t = at::acc_type<scalar_t, true>;
batchnorm_forward_kernel<scalar_t, accscalar_t, scalar_t><<<grid, block, 0, stream>>>(
input.data<scalar_t>(),
DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "batchnorm_forward",
using accscalar_t = at::acc_type<scalar_t_0, true>;
batchnorm_forward_kernel<scalar_t_0, accscalar_t, scalar_t_0><<<grid, block, 0, stream>>>(
input.data<scalar_t_0>(),
mean.data<accscalar_t>(),
inv_std.data<accscalar_t>(),
weight.has_value() ? weight.value().data<scalar_t>() : NULL,
shift.has_value() ? shift.value().data<scalar_t>() : NULL,
out.data<scalar_t>(),
weight.has_value() ? weight.value().data<scalar_t_0>() : NULL,
shift.has_value() ? shift.value().data<scalar_t_0>() : NULL,
out.data<scalar_t_0>(),
space_size,
batch_size);
}));
);
}
return out;
}
@@ -952,15 +952,15 @@ std::vector<at::Tensor> reduce_bn_CUDA(
const dim3 grid(feature_size);
auto stream = at::cuda::getCurrentCUDAStream();
if (input.type().scalarType() == at::ScalarType::Half
if (input.scalar_type() == at::ScalarType::Half
&& weight.has_value() &&
weight.value().type().scalarType() == at::ScalarType::Float) {
weight.value().scalar_type() == at::ScalarType::Float) {
using namespace at;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_backward_reduce", ([&] {
using accscalar_t = at::acc_type<scalar_t, true>;
reduce_bn_kernel<scalar_t, accscalar_t, accscalar_t><<<grid, block, 0, stream>>>(
input.data<scalar_t>(),
grad_output.data<scalar_t>(),
DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "batchnorm_backward_reduce",
using accscalar_t = at::acc_type<scalar_t_0, true>;
reduce_bn_kernel<scalar_t_0, accscalar_t, accscalar_t><<<grid, block, 0, stream>>>(
input.data<scalar_t_0>(),
grad_output.data<scalar_t_0>(),
mean.data<accscalar_t>(),
inv_std.data<accscalar_t>(),
mean_dy.data<accscalar_t>(),
......@@ -970,28 +970,28 @@ std::vector<at::Tensor> reduce_bn_CUDA(
batch_size,
feature_size,
space_size);
}));
);
} else {
if (weight.has_value()) {
AT_CHECK(input.type().scalarType() == weight.value().type().scalarType(),
"input.type().scalarType() is not supported with weight.type().scalarType()");
AT_CHECK(input.scalar_type() == weight.value().scalar_type(),
"input.scalar_type() is not supported with weight.scalar_type()");
}
using namespace at;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_backward_reduce", ([&] {
using accscalar_t = at::acc_type<scalar_t, true>;
reduce_bn_kernel<scalar_t, accscalar_t, scalar_t><<<grid, block, 0, stream>>>(
input.data<scalar_t>(),
grad_output.data<scalar_t>(),
DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "batchnorm_backward_reduce",
using accscalar_t = at::acc_type<scalar_t_0, true>;
reduce_bn_kernel<scalar_t_0, accscalar_t, scalar_t_0><<<grid, block, 0, stream>>>(
input.data<scalar_t_0>(),
grad_output.data<scalar_t_0>(),
mean.data<accscalar_t>(),
inv_std.data<accscalar_t>(),
mean_dy.data<accscalar_t>(),
mean_dy_xmu.data<accscalar_t>(),
weight.has_value() ? grad_weight.data<scalar_t>() : NULL,
weight.has_value() ? grad_bias.data<scalar_t>() : NULL,
weight.has_value() ? grad_weight.data<scalar_t_0>() : NULL,
weight.has_value() ? grad_bias.data<scalar_t_0>() : NULL,
batch_size,
feature_size,
space_size);
}));
);
}
return {mean_dy, mean_dy_xmu, grad_weight, grad_bias};
@@ -1021,44 +1021,44 @@ at::Tensor batchnorm_backward_CUDA(
auto stream = at::cuda::getCurrentCUDAStream();
if (input.type().scalarType() == at::ScalarType::Half
if (input.scalar_type() == at::ScalarType::Half
&& weight.has_value() &&
weight.value().type().scalarType() == at::ScalarType::Float) {
weight.value().scalar_type() == at::ScalarType::Float) {
using namespace at;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_backward", ([&] {
using accscalar_t = at::acc_type<scalar_t, true>;
batchnorm_backward_kernel<scalar_t, accscalar_t, accscalar_t><<<grid, block, 0, stream>>>(
grad_output.data<scalar_t>(),
input.data<scalar_t>(),
DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "batchnorm_backward",
using accscalar_t = at::acc_type<scalar_t_0, true>;
batchnorm_backward_kernel<scalar_t_0, accscalar_t, accscalar_t><<<grid, block, 0, stream>>>(
grad_output.data<scalar_t_0>(),
input.data<scalar_t_0>(),
mean.data<accscalar_t>(),
inv_std.data<accscalar_t>(),
weight.has_value() ? weight.value().data<accscalar_t>() : NULL,
mean_dy.data<accscalar_t>(),
mean_dy_xmu.data<accscalar_t>(),
grad_input.data<scalar_t>(),
grad_input.data<scalar_t_0>(),
space_size,
batch_size);
}));
);
} else {
if (weight.has_value()) {
AT_CHECK(input.type().scalarType() == weight.value().type().scalarType(),
"input.type().scalarType() is not supported with weight.type().scalarType()");
AT_CHECK(input.scalar_type() == weight.value().scalar_type(),
"input.scalar_type() is not supported with weight.scalar_type()");
}
using namespace at;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_backward", ([&] {
using accscalar_t = at::acc_type<scalar_t, true>;
batchnorm_backward_kernel<scalar_t, accscalar_t, scalar_t><<<grid, block, 0, stream>>>(
grad_output.data<scalar_t>(),
input.data<scalar_t>(),
DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "batchnorm_backward",
using accscalar_t = at::acc_type<scalar_t_0, true>;
batchnorm_backward_kernel<scalar_t_0, accscalar_t, scalar_t_0><<<grid, block, 0, stream>>>(
grad_output.data<scalar_t_0>(),
input.data<scalar_t_0>(),
mean.data<accscalar_t>(),
inv_std.data<accscalar_t>(),
weight.has_value() ? weight.value().data<scalar_t>() : NULL,
weight.has_value() ? weight.value().data<scalar_t_0>() : NULL,
mean_dy.data<accscalar_t>(),
mean_dy_xmu.data<accscalar_t>(),
grad_input.data<scalar_t>(),
grad_input.data<scalar_t_0>(),
space_size,
batch_size);
}));
);
}
return grad_input;
@@ -1083,18 +1083,18 @@ std::vector<at::Tensor> welford_parallel_CUDA(const at::Tensor mean_feature_node
{
using namespace at;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(mean_feature_nodes.type(), "welford_parallel_kernel", ([&] {
welford_kernel_parallel<scalar_t><<<grid, block, 0, stream>>>(
mean_feature_nodes.data<scalar_t>(),
var_biased.data<scalar_t>(),
out_mean.data<scalar_t>(),
out_var.data<scalar_t>(),
inv_std.data<scalar_t>(),
DISPATCH_FLOAT_AND_HALF(mean_feature_nodes.scalar_type(), 0, "welford_parallel_kernel",
welford_kernel_parallel<scalar_t_0><<<grid, block, 0, stream>>>(
mean_feature_nodes.data<scalar_t_0>(),
var_biased.data<scalar_t_0>(),
out_mean.data<scalar_t_0>(),
out_var.data<scalar_t_0>(),
inv_std.data<scalar_t_0>(),
world_size,
feature_size,
eps,
numel);
}));
);
}
return {out_mean, out_var, inv_std};
@@ -1118,27 +1118,27 @@ std::vector<at::Tensor> welford_mean_var_c_last_CUDA(const at::Tensor input) {
at::Tensor semaphores;
if (grid.y > 1) {
staging_data = at::empty({4*stride*grid.y}, option);
semaphores = at::zeros({grid.x}, input.options().dtype(at::ScalarType::Int));
semaphores = at::zeros({grid.x}, input.options().dtype(at::kInt));
}
auto stream = at::cuda::getCurrentCUDAStream();
{
using namespace at;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "welford_mean_var_c_last", ([&] {
using accscalar_t = at::acc_type<scalar_t, true>;
DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "welford_mean_var_c_last",
using accscalar_t = at::acc_type<scalar_t_0, true>;
accscalar_t* staging_data_ptr = grid.y > 1 ? staging_data.data<accscalar_t>() : nullptr;
int* semaphores_ptr = grid.y > 1 ? semaphores.data<int>() : nullptr;
welford_kernel_c_last<scalar_t, accscalar_t, accscalar_t, ELEMENTS_PER_ITER>
welford_kernel_c_last<scalar_t_0, accscalar_t, accscalar_t, ELEMENTS_PER_ITER>
<<<grid, block, 0, stream>>>(
input.data<scalar_t>(),
input.data<scalar_t_0>(),
out_mean.data<accscalar_t>(),
out_var_biased.data<accscalar_t>(),
staging_data_ptr,
semaphores_ptr,
reduction_size,
stride);
}));
);
}
return {out_mean, out_var_biased};
@@ -1161,41 +1161,41 @@ at::Tensor batchnorm_forward_c_last_CUDA(
auto stream = at::cuda::getCurrentCUDAStream();
if (input.type().scalarType() == at::ScalarType::Half
&& weight.has_value() && weight.value().type().scalarType() == at::ScalarType::Float) {
if (input.scalar_type() == at::ScalarType::Half
&& weight.has_value() && weight.value().scalar_type() == at::ScalarType::Float) {
using namespace at;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_forward", ([&] {
using accscalar_t = at::acc_type<scalar_t, true>;
batchnorm_forward_c_last_kernel<scalar_t, accscalar_t, accscalar_t, ELEMENTS_PER_ITER>
DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "batchnorm_forward",
using accscalar_t = at::acc_type<scalar_t_0, true>;
batchnorm_forward_c_last_kernel<scalar_t_0, accscalar_t, accscalar_t, ELEMENTS_PER_ITER>
<<<grid, block, 0, stream>>>(
input.data<scalar_t>(),
input.data<scalar_t_0>(),
mean.data<accscalar_t>(),
inv_std.data<accscalar_t>(),
weight.has_value() ? weight.value().data<accscalar_t>() : NULL,
shift.has_value() ? shift.value().data<accscalar_t>(): NULL,
out.data<scalar_t>(),
out.data<scalar_t_0>(),
reduction_size,
stride);
}));
);
} else {
if (weight.has_value()) {
AT_CHECK(input.type().scalarType() == weight.value().type().scalarType(),
"input.type().scalarType() is not supported with weight.type().scalarType()");
AT_CHECK(input.scalar_type() == weight.value().scalar_type(),
"input.scalar_type() is not supported with weight.scalar_type()");
}
using namespace at;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_forward", ([&] {
using accscalar_t = at::acc_type<scalar_t, true>;
batchnorm_forward_c_last_kernel<scalar_t, accscalar_t, scalar_t, ELEMENTS_PER_ITER>
DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "batchnorm_forward",
using accscalar_t = at::acc_type<scalar_t_0, true>;
batchnorm_forward_c_last_kernel<scalar_t_0, accscalar_t, scalar_t_0, ELEMENTS_PER_ITER>
<<<grid, block, 0, stream>>>(
input.data<scalar_t>(),
input.data<scalar_t_0>(),
mean.data<accscalar_t>(),
inv_std.data<accscalar_t>(),
weight.has_value() ? weight.value().data<scalar_t>() : NULL,
shift.has_value() ? shift.value().data<scalar_t>(): NULL,
out.data<scalar_t>(),
weight.has_value() ? weight.value().data<scalar_t_0>() : NULL,
shift.has_value() ? shift.value().data<scalar_t_0>(): NULL,
out.data<scalar_t_0>(),
reduction_size,
stride);
}));
);
}
return out;
}
@@ -1231,22 +1231,22 @@ std::vector<at::Tensor> reduce_bn_c_last_CUDA(
at::Tensor semaphores;
if (grid.y > 1) {
staging_data = at::empty({2*stride*grid.y}, mean.options());
semaphores = at::zeros({grid.x}, input.options().dtype(at::ScalarType::Int));
semaphores = at::zeros({grid.x}, input.options().dtype(at::kInt));
}
auto stream = at::cuda::getCurrentCUDAStream();
if (input.type().scalarType() == at::ScalarType::Half
if (input.scalar_type() == at::ScalarType::Half
&& weight.has_value()
&& weight.value().type().scalarType() == at::ScalarType::Float) {
&& weight.value().scalar_type() == at::ScalarType::Float) {
using namespace at;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_backward_reduce", ([&] {
using accscalar_t = at::acc_type<scalar_t, true>;
DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "batchnorm_backward_reduce",
using accscalar_t = at::acc_type<scalar_t_0, true>;
accscalar_t* staging_data_ptr = grid.y > 1 ? staging_data.data<accscalar_t>() : nullptr;
int* semaphores_ptr = grid.y > 1 ? semaphores.data<int>() : nullptr;
reduce_bn_c_last_kernel<scalar_t, accscalar_t, accscalar_t, ELEMENTS_PER_ITER>
reduce_bn_c_last_kernel<scalar_t_0, accscalar_t, accscalar_t, ELEMENTS_PER_ITER>
<<<grid, block, 0, stream>>>(
input.data<scalar_t>(),
grad_output.data<scalar_t>(),
input.data<scalar_t_0>(),
grad_output.data<scalar_t_0>(),
mean.data<accscalar_t>(),
inv_std.data<accscalar_t>(),
mean_dy.data<accscalar_t>(),
......@@ -1257,32 +1257,32 @@ std::vector<at::Tensor> reduce_bn_c_last_CUDA(
semaphores_ptr,
reduction_size,
stride);
}));
);
} else {
if (weight.has_value()) {
AT_CHECK(input.type().scalarType() == weight.value().type().scalarType(),
"input.type().scalarType() is not supported with weight.type().scalarType()");
AT_CHECK(input.scalar_type() == weight.value().scalar_type(),
"input.scalar_type() is not supported with weight.scalar_type()");
}
using namespace at;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_backward_reduce", ([&] {
using accscalar_t = at::acc_type<scalar_t, true>;
DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "batchnorm_backward_reduce",
using accscalar_t = at::acc_type<scalar_t_0, true>;
accscalar_t* staging_data_ptr = grid.y > 1 ? staging_data.data<accscalar_t>() : nullptr;
int* semaphores_ptr = grid.y > 1 ? semaphores.data<int>() : nullptr;
reduce_bn_c_last_kernel<scalar_t, accscalar_t, scalar_t, ELEMENTS_PER_ITER>
reduce_bn_c_last_kernel<scalar_t_0, accscalar_t, scalar_t_0, ELEMENTS_PER_ITER>
<<<grid, block, 0, stream>>>(
input.data<scalar_t>(),
grad_output.data<scalar_t>(),
input.data<scalar_t_0>(),
grad_output.data<scalar_t_0>(),
mean.data<accscalar_t>(),
inv_std.data<accscalar_t>(),
mean_dy.data<accscalar_t>(),
mean_dy_xmu.data<accscalar_t>(),
weight.has_value() ? grad_weight.data<scalar_t>() : NULL,
weight.has_value() ?grad_bias.data<scalar_t>() : NULL,
weight.has_value() ? grad_weight.data<scalar_t_0>() : NULL,
weight.has_value() ?grad_bias.data<scalar_t_0>() : NULL,
staging_data_ptr,
semaphores_ptr,
reduction_size,
stride);
}));
);
}
return {mean_dy, mean_dy_xmu, grad_weight, grad_bias};
@@ -1307,45 +1307,45 @@ at::Tensor batchnorm_backward_c_last_CUDA(
auto stream = at::cuda::getCurrentCUDAStream();
if (input.type().scalarType() == at::ScalarType::Half
&& weight.has_value() && weight.value().type().scalarType() == at::ScalarType::Float) {
if (input.scalar_type() == at::ScalarType::Half
&& weight.has_value() && weight.value().scalar_type() == at::ScalarType::Float) {
using namespace at;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_forward", ([&] {
using accscalar_t = at::acc_type<scalar_t, true>;
batchnorm_backward_c_last_kernel<scalar_t, accscalar_t, accscalar_t, ELEMENTS_PER_ITER>
DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "batchnorm_forward",
using accscalar_t = at::acc_type<scalar_t_0, true>;
batchnorm_backward_c_last_kernel<scalar_t_0, accscalar_t, accscalar_t, ELEMENTS_PER_ITER>
<<<grid, block, 0, stream>>>(
grad_output.data<scalar_t>(),
input.data<scalar_t>(),
grad_output.data<scalar_t_0>(),
input.data<scalar_t_0>(),
mean.data<accscalar_t>(),
inv_std.data<accscalar_t>(),
weight.has_value() ? weight.value().data<accscalar_t>() : NULL,
mean_dy.data<accscalar_t>(),
mean_dy_xmu.data<accscalar_t>(),
grad_input.data<scalar_t>(),
grad_input.data<scalar_t_0>(),
reduction_size,
stride);
}));
);
} else {
if (weight.has_value()) {
AT_CHECK(input.type().scalarType() == weight.value().type().scalarType(),
"input.type().scalarType() is not supported with weight.type().scalarType()");
AT_CHECK(input.scalar_type() == weight.value().scalar_type(),
"input.scalar_type() is not supported with weight.scalar_type()");
}
using namespace at;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_forward", ([&] {
using accscalar_t = at::acc_type<scalar_t, true>;
batchnorm_backward_c_last_kernel<scalar_t, accscalar_t, scalar_t, ELEMENTS_PER_ITER>
DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "batchnorm_forward",
using accscalar_t = at::acc_type<scalar_t_0, true>;
batchnorm_backward_c_last_kernel<scalar_t_0, accscalar_t, scalar_t_0, ELEMENTS_PER_ITER>
<<<grid, block, 0, stream>>>(
grad_output.data<scalar_t>(),
input.data<scalar_t>(),
grad_output.data<scalar_t_0>(),
input.data<scalar_t_0>(),
mean.data<accscalar_t>(),
inv_std.data<accscalar_t>(),
weight.has_value() ? weight.value().data<scalar_t>() : NULL,
weight.has_value() ? weight.value().data<scalar_t_0>() : NULL,
mean_dy.data<accscalar_t>(),
mean_dy_xmu.data<accscalar_t>(),
grad_input.data<scalar_t>(),
grad_input.data<scalar_t_0>(),
reduction_size,
stride);
}));
);
}
return grad_input;