Commit d900e93c authored by Michael Carilli

Merging in master

parents c978bda5 855808f3
@@ -111,7 +111,7 @@ def check_optimizers(optimizers):
         if isinstance(optim, FP16_Optimizer_for_fused):
             bad_optim_type = "apex.optimizers.FP16_Optimizer"
         if bad_optim_type is not None:
-            raise RuntimeError("An incoming optimizer is an instance of {}. ".format(optim_type) +
+            raise RuntimeError("An incoming optimizer is an instance of {}. ".format(bad_optim_type) +
                                "The optimizer(s) passed to amp.initialize() must be bare \n"
                                "instances of either ordinary Pytorch optimizers, or Apex fused \n"
                                "optimizers (FusedAdam or FusedSGD). \n"
@@ -132,7 +132,9 @@ def _initialize(models, optimizers, properties, num_losses=1, cast_model_outputs
        optimizers = []
    elif isinstance(optimizers, list):
        optimizers_was_list = True
+        check_optimizers(optimizers)
    else:
+        check_optimizers([optimizers])
        raise TypeError("optimizers must be either a single optimizer or a list of optimizers.")

    if isinstance(models, torch.nn.Module):
@@ -148,8 +150,6 @@ def _initialize(models, optimizers, properties, num_losses=1, cast_model_outputs
    if not _amp_state.allow_incoming_model_not_fp32:
        check_params_fp32(models)
-    check_optimizers(optimizers)
    # In the future, when FP16_Optimizer can be deprecated and master weights can
    # become an attribute, remember to stash master weights before casting the model.
...
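For context, the check above enforces that amp.initialize() receives a bare torch.optim optimizer (or an Apex FusedAdam/FusedSGD) rather than one pre-wrapped in FP16_Optimizer. A minimal, illustrative usage sketch (not code from this commit; model and hyperparameters are placeholders):

# Illustrative sketch only: pass a *bare* optimizer to amp.initialize()
# and let amp manage master weights and loss scaling.
import torch
from apex import amp

model = torch.nn.Linear(10, 10).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)  # bare optimizer

# amp.initialize returns the (possibly cast) model and a patched optimizer.
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

loss = model(torch.randn(4, 10).cuda()).sum()
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()
optimizer.step()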
@@ -182,19 +182,19 @@ void fused_adam_cuda(
  }
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
- if (g.type().scalarType() == at::ScalarType::Half) {
+ if (g.scalar_type() == at::ScalarType::Half) {
    //all other values should be fp32 for half gradients
-   AT_ASSERTM(p.type().scalarType() == at::ScalarType::Float, "expected parameter to be of float type");
+   AT_ASSERTM(p.scalar_type() == at::ScalarType::Float, "expected parameter to be of float type");
    //dispatch is done on the gradient type
    using namespace at; // prevents "toString is undefined" errors
-   AT_DISPATCH_FLOATING_TYPES_AND_HALF(g.type(), "adam_cuda_kernel", ([&] {
-     using accscalar_t = at::acc_type<scalar_t, true>;
-     adam_cuda_kernel<accscalar_t, scalar_t><<<blocks,threadsPerBlock, 0, stream>>>(
+   DISPATCH_FLOAT_AND_HALF(g.scalar_type(), 0, "adam_cuda_kernel",
+     using accscalar_t = at::acc_type<scalar_t_0, true>;
+     adam_cuda_kernel<accscalar_t, scalar_t_0><<<blocks,threadsPerBlock, 0, stream>>>(
       p.data<accscalar_t>(),
-       p_copy.numel() ? p_copy.data<scalar_t>() : NULL,
+       p_copy.numel() ? p_copy.data<scalar_t_0>() : NULL,
       m.data<accscalar_t>(),
       v.data<accscalar_t>(),
-       g.data<scalar_t>(),
+       g.data<scalar_t_0>(),
       beta1,
       beta2,
       eps,
@@ -203,16 +203,16 @@ void fused_adam_cuda(
       tsize,
       (adamMode_t) mode,
       decay);
-   }));
+     )
  } else {
    using namespace at;
-   AT_DISPATCH_FLOATING_TYPES(g.type(), "adam_cuda_kernel", ([&] {
-     adam_cuda_kernel<scalar_t, scalar_t><<<blocks,threadsPerBlock, 0, stream>>>(
-       p.data<scalar_t>(),
+   DISPATCH_DOUBLE_AND_FLOAT(g.scalar_type(), 0, "adam_cuda_kernel",
+     adam_cuda_kernel<scalar_t_0, scalar_t_0><<<blocks,threadsPerBlock, 0, stream>>>(
+       p.data<scalar_t_0>(),
       NULL, //don't output p_copy for fp32, it's wasted write
-       m.data<scalar_t>(),
-       v.data<scalar_t>(),
-       g.data<scalar_t>(),
+       m.data<scalar_t_0>(),
+       v.data<scalar_t_0>(),
+       g.data<scalar_t_0>(),
       beta1,
       beta2,
       eps,
@@ -221,7 +221,7 @@ void fused_adam_cuda(
       tsize,
       (adamMode_t) mode,
       decay);
-   }));
+     );
  }
  THCudaCheck(cudaGetLastError());
...
@@ -129,7 +129,7 @@ std::vector<at::Tensor> layer_norm(
  int n1,n2;
  check_args(input,normalized_shape,n1,n2);
  at::Tensor output = at::empty_like(input);
- at::Tensor mean = at::empty({n1}, input.options().dtype(input.type().scalarType()==at::ScalarType::Half ? at::ScalarType::Float : input.type().scalarType()));
+ at::Tensor mean = at::empty({n1}, input.options().dtype(input.scalar_type()==at::ScalarType::Half ? at::ScalarType::Float : input.scalar_type()));
  at::Tensor invvar = at::empty_like(mean);
  cuda_layer_norm(&output,&mean,&invvar,&input,n1,n2,
      normalized_shape,NULL,NULL,epsilon);
@@ -151,7 +151,7 @@ std::vector<at::Tensor> layer_norm_affine(
  int n1,n2;
  check_args(input,normalized_shape,gamma,beta,n1,n2);
  at::Tensor output = at::empty_like(input);
- at::Tensor mean = at::empty({n1}, input.options().dtype(input.type().scalarType()==at::ScalarType::Half ? at::ScalarType::Float : input.type().scalarType()));
+ at::Tensor mean = at::empty({n1}, input.options().dtype(input.scalar_type()==at::ScalarType::Half ? at::ScalarType::Float : input.scalar_type()));
  at::Tensor invvar = at::empty_like(mean);
  cuda_layer_norm(&output,&mean,&invvar,&input,n1,n2,
      normalized_shape,&gamma,&beta,epsilon);
...
@@ -685,18 +685,18 @@ void cuda_layer_norm(
    double epsilon)
{
    using namespace at;
-   AT_DISPATCH_FLOATING_TYPES_AND_HALF(input->type(), "layer_norm_cuda_kernel", ([&] {
-     using accscalar_t = at::acc_type<scalar_t, true>;
+   DISPATCH_DOUBLE_FLOAT_AND_HALF(input->scalar_type(), 0, "layer_norm_cuda_kernel",
+     using accscalar_t = at::acc_type<scalar_t_0, true>;
      HostApplyLayerNorm(
-       output->data<scalar_t>(),
+       output->data<scalar_t_0>(),
       mean->data<accscalar_t>(),
       invvar->data<accscalar_t>(),
-       input->data<scalar_t>(),
+       input->data<scalar_t_0>(),
       n1,n2,
       epsilon,
-       gamma != NULL ? gamma->data<scalar_t>() : NULL,
-       beta != NULL ? beta->data<scalar_t>() : NULL);
-   }));
+       gamma != NULL ? gamma->data<scalar_t_0>() : NULL,
+       beta != NULL ? beta->data<scalar_t_0>() : NULL);
+     )
}

template<typename T, typename U>
@@ -725,7 +725,7 @@ void HostLayerNormGradient(
    const int nshared2_a = 2 * sizeof(U) * threads2.y * threads2.y * (threads2.x + 1);
    const int nshared2_b = threads2.x * threads2.y * sizeof(U);
    const int nshared2 = nshared2_a > nshared2_b ? nshared2_a : nshared2_b;
-   at::Tensor part_grad_gamma = at::empty({part_size,n2}, input->options().dtype(input->type().scalarType()==at::ScalarType::Half ? at::ScalarType::Float : input->type().scalarType()));
+   at::Tensor part_grad_gamma = at::empty({part_size,n2}, input->options().dtype(input->scalar_type()==at::ScalarType::Half ? at::ScalarType::Float : input->scalar_type()));
    at::Tensor part_grad_beta = at::empty_like(part_grad_gamma);
    cuComputePartGradGammaBeta<<<blocks2, threads2, nshared2, stream>>>(
      dout,
@@ -787,19 +787,19 @@ void cuda_layer_norm_gradient(
    at::Tensor* grad_beta)
{
    using namespace at;
-   AT_DISPATCH_FLOATING_TYPES_AND_HALF(input->type(), "cuComputeGradInput", ([&] {
-     using accscalar_t = at::acc_type<scalar_t, true>;
+   DISPATCH_FLOAT_AND_HALF(input->scalar_type(), 0, "cuComputeGradInput",
+     using accscalar_t = at::acc_type<scalar_t_0, true>;
      HostLayerNormGradient(
-       dout->data<scalar_t>(),
+       dout->data<scalar_t_0>(),
       mean->data<accscalar_t>(),
       invvar->data<accscalar_t>(),
       input,
       n1,n2,
-       gamma->data<scalar_t>(),
-       beta->data<scalar_t>(),
+       gamma->data<scalar_t_0>(),
+       beta->data<scalar_t_0>(),
       epsilon,
-       grad_input->data<scalar_t>(),
-       grad_gamma->data<scalar_t>(),
-       grad_beta->data<scalar_t>());
-   }));
+       grad_input->data<scalar_t_0>(),
+       grad_gamma->data<scalar_t_0>(),
+       grad_beta->data<scalar_t_0>());
+     )
}
@@ -75,7 +75,7 @@ at::Tensor multi_tensor_l2norm_cuda(
  at::Tensor noop_flag,
  std::vector<std::vector<at::Tensor>> tensor_lists)
{
- auto output = at::zeros({320}, tensor_lists[0][0].options().dtype(at::ScalarType::Float));
+ auto output = at::zeros({320}, tensor_lists[0][0].options().dtype(at::kFloat));
  DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "multi_tensor_l2norm_cuda",
    multi_tensor_apply<1>(
...
@@ -86,39 +86,15 @@ void multi_tensor_scale_cuda(
  // If build times suffer, think about where to put this dispatch,
  // and what logic should be moved out of multi_tensor_apply.
- AT_DISPATCH_FLOATING_TYPES_AND_HALF(tensor_lists[0][0].type(),
-   "multi_tensor_scale_cuda",
-   [&]
-   {
-     // using accscalar_t = acc_type<scalar_t, true>;
-     switch(tensor_lists[1][0].scalar_type())
-     {
-       case at::ScalarType::Half:
-         multi_tensor_apply<2>(
-           BLOCK_SIZE,
-           chunk_size,
-           noop_flag,
-           tensor_lists,
-           ScaleFunctor<scalar_t, at::Half>(),
-           scale);
-         break;
-       case at::ScalarType::Float:
-         multi_tensor_apply<2>(
-           BLOCK_SIZE,
-           chunk_size,
-           noop_flag,
-           tensor_lists,
-           ScaleFunctor<scalar_t, float>(),
-           scale);
-         break;
-       default:
-         std::stringstream ss;
-         ss << "multi_tensor_scale_cuda not implemented for output type = "
-            << tensor_lists[1][0].dtype();
-         AT_ERROR(ss.str().c_str());
-     }
-   });
+ DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "multi_tensor_scale_cuda",
+   DISPATCH_FLOAT_AND_HALF(tensor_lists[1][0].scalar_type(), 1, "multi_tensor_scale_cuda",
+     multi_tensor_apply<2>(
+       BLOCK_SIZE,
+       chunk_size,
+       noop_flag,
+       tensor_lists,
+       ScaleFunctor<scalar_t_0, scalar_t_1>(),
+       scale); ))
  AT_CUDA_CHECK(cudaGetLastError());
  // AT_CUDA_CHECK(cudaDeviceSynchronize());
...
@@ -3,15 +3,15 @@
 // Forward/backward compatiblity hack around
 // https://github.com/pytorch/pytorch/commit/3aeb78079bcd68282fe9117088e138b77318e288
 // pending more future-proof guidance from upstream.
-struct TypeShim
-{
-  const at::Type& payload;
-  TypeShim(const at::Type& type) : payload(type) {}
-  // Enable trivial conversion to a const at::Type& for pre-3aeb78
-  operator const at::Type&(){ return payload; };
-  // Enable dispatch switch statements to take *this directly for post-3aeb78
-  operator at::ScalarType(){ return payload.scalarType(); };
-};
+// struct TypeShim
+// {
+//   const at::Type& payload;
+//   TypeShim(const at::Type& type) : payload(type) {}
+//   // Enable trivial conversion to a const at::Type& for pre-3aeb78
+//   operator const at::Type&(){ return payload; };
+//   // Enable dispatch switch statements to take *this directly for post-3aeb78
+//   //operator at::ScalarType(){ return payload.; };
+// };

 #define DISPATCH_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \
   switch(TYPE) \
@@ -33,6 +33,52 @@ struct TypeShim
   }

+#define DISPATCH_DOUBLE_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \
+  switch(TYPE) \
+  { \
+    case at::ScalarType::Double: \
+    { \
+      using scalar_t_##LEVEL = double; \
+      __VA_ARGS__; \
+      break; \
+    } \
+    case at::ScalarType::Float: \
+    { \
+      using scalar_t_##LEVEL = float; \
+      __VA_ARGS__; \
+      break; \
+    } \
+    case at::ScalarType::Half: \
+    { \
+      using scalar_t_##LEVEL = at::Half; \
+      __VA_ARGS__; \
+      break; \
+    } \
+    default: \
+      AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
+  }
+
+#define DISPATCH_DOUBLE_AND_FLOAT(TYPE, LEVEL, NAME, ...) \
+  switch(TYPE) \
+  { \
+    case at::ScalarType::Double: \
+    { \
+      using scalar_t_##LEVEL = double; \
+      __VA_ARGS__; \
+      break; \
+    } \
+    case at::ScalarType::Float: \
+    { \
+      using scalar_t_##LEVEL = float; \
+      __VA_ARGS__; \
+      break; \
+    } \
+    default: \
+      AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
+  }

 template<typename T>
 __device__ __forceinline__ T reduce_block_into_lanes
 (T *x,
...
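The LEVEL argument lets these DISPATCH_* macros nest: each level stamps out its own scalar_t_<LEVEL> alias, so two tensors can be dispatched on independent dtypes, which is how the multi_tensor_scale change above handles separate input and output types. A minimal illustrative sketch, assuming type_shim.h is included; my_cast_kernel and cast_cuda are hypothetical names, not part of this commit:

// Illustrative only: host-side launcher built on the macros defined above.
#include <ATen/ATen.h>
#include "type_shim.h"  // assumed include path for the DISPATCH_* macros

template <typename in_t, typename out_t>
__global__ void my_cast_kernel(const in_t* in, out_t* out, int n)
{
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n)
    out[i] = static_cast<out_t>(in[i]);  // elementwise dtype conversion
}

void cast_cuda(at::Tensor in, at::Tensor out)
{
  int n = in.numel();
  const int threads = 256;
  const int blocks = (n + threads - 1) / threads;
  // Outer dispatch binds scalar_t_0 to the input dtype, inner dispatch binds
  // scalar_t_1 to the output dtype; the kernel launch sees both typedefs.
  DISPATCH_FLOAT_AND_HALF(in.scalar_type(), 0, "cast_cuda",
    DISPATCH_FLOAT_AND_HALF(out.scalar_type(), 1, "cast_cuda",
      my_cast_kernel<scalar_t_0, scalar_t_1>
        <<<blocks, threads>>>(
          in.data<scalar_t_0>(),
          out.data<scalar_t_1>(),
          n); ))
}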
@@ -182,14 +182,14 @@ __host__ int get_tensor_spatial_size(const at::Tensor& input)
 // promote accumulation scalar type. promote half to float.
 __host__ at::ScalarType promote_scalartype(const at::Tensor& input)
 {
-  return input.type().scalarType() == at::ScalarType::Half ?
-      at::ScalarType::Float : input.type().scalarType();
+  return input.scalar_type() == at::ScalarType::Half ?
+      at::ScalarType::Float : input.scalar_type();
 }

 // return single element size, optional accumulation type promotion.
 __host__ size_t get_element_data_size(const at::Tensor& input, bool accumulation = false)
 {
-  auto scalar_type = accumulation ? promote_scalartype(input) : input.type().scalarType();
+  auto scalar_type = accumulation ? promote_scalartype(input) : input.scalar_type();
   return at::elementSize(scalar_type);
 }
@@ -846,16 +846,16 @@ std::vector<at::Tensor> welford_mean_var_CUDA(const at::Tensor input) {
  {
    using namespace at;
-   AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "welford_mean_var_kernel", ([&] {
-     using accscalar_t = at::acc_type<scalar_t, true>;
-     welford_kernel<scalar_t, accscalar_t, accscalar_t><<<grid, block, 0, stream>>>(
-       input.data<scalar_t>(),
+   DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "welford_mean_var_kernel",
+     using accscalar_t = at::acc_type<scalar_t_0, true>;
+     welford_kernel<scalar_t_0, accscalar_t, accscalar_t><<<grid, block, 0, stream>>>(
+       input.data<scalar_t_0>(),
       out_mean.data<accscalar_t>(),
       out_var_biased.data<accscalar_t>(),
       batch_size,
       feature_size,
       space_size);
-   }));
+     );
  }
  return {out_mean, out_var_biased};
@@ -881,40 +881,40 @@ at::Tensor batchnorm_forward_CUDA(
  const dim3 grid(feature_size, batch_group_size, grid_z);
  auto stream = at::cuda::getCurrentCUDAStream();
- if (input.type().scalarType() == at::ScalarType::Half
+ if (input.scalar_type() == at::ScalarType::Half
      && weight.has_value() &&
-     weight.value().type().scalarType() == at::ScalarType::Float) {
+     weight.value().scalar_type() == at::ScalarType::Float) {
    using namespace at;
-   AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_forward", ([&] {
-     using accscalar_t = at::acc_type<scalar_t, true>;
-     batchnorm_forward_kernel<scalar_t, accscalar_t, accscalar_t><<<grid, block, 0, stream>>>(
-       input.data<scalar_t>(),
+   DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "batchnorm_forward",
+     using accscalar_t = at::acc_type<scalar_t_0, true>;
+     batchnorm_forward_kernel<scalar_t_0, accscalar_t, accscalar_t><<<grid, block, 0, stream>>>(
+       input.data<scalar_t_0>(),
       mean.data<accscalar_t>(),
       inv_std.data<accscalar_t>(),
       weight.has_value() ? weight.value().data<accscalar_t>() : NULL,
       shift.has_value() ? shift.value().data<accscalar_t>() : NULL,
-       out.data<scalar_t>(),
+       out.data<scalar_t_0>(),
       space_size,
       batch_size);
-   }));
+     );
  } else {
    if (weight.has_value()) {
-     AT_CHECK(input.type().scalarType() == weight.value().type().scalarType(),
-         "input.type().scalarType() is not supported with weight.type().scalarType()");
+     AT_CHECK(input.scalar_type() == weight.value().scalar_type(),
+         "input.scalar_type() is not supported with weight.scalar_type()");
    }
    using namespace at;
-   AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_forward", ([&] {
-     using accscalar_t = at::acc_type<scalar_t, true>;
-     batchnorm_forward_kernel<scalar_t, accscalar_t, scalar_t><<<grid, block, 0, stream>>>(
-       input.data<scalar_t>(),
+   DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "batchnorm_forward",
+     using accscalar_t = at::acc_type<scalar_t_0, true>;
+     batchnorm_forward_kernel<scalar_t_0, accscalar_t, scalar_t_0><<<grid, block, 0, stream>>>(
+       input.data<scalar_t_0>(),
       mean.data<accscalar_t>(),
       inv_std.data<accscalar_t>(),
-       weight.has_value() ? weight.value().data<scalar_t>() : NULL,
-       shift.has_value() ? shift.value().data<scalar_t>() : NULL,
-       out.data<scalar_t>(),
+       weight.has_value() ? weight.value().data<scalar_t_0>() : NULL,
+       shift.has_value() ? shift.value().data<scalar_t_0>() : NULL,
+       out.data<scalar_t_0>(),
       space_size,
       batch_size);
-   }));
+     );
  }
  return out;
}
@@ -952,15 +952,15 @@ std::vector<at::Tensor> reduce_bn_CUDA(
  const dim3 grid(feature_size);
  auto stream = at::cuda::getCurrentCUDAStream();
- if (input.type().scalarType() == at::ScalarType::Half
+ if (input.scalar_type() == at::ScalarType::Half
      && weight.has_value() &&
-     weight.value().type().scalarType() == at::ScalarType::Float) {
+     weight.value().scalar_type() == at::ScalarType::Float) {
    using namespace at;
-   AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_backward_reduce", ([&] {
-     using accscalar_t = at::acc_type<scalar_t, true>;
-     reduce_bn_kernel<scalar_t, accscalar_t, accscalar_t><<<grid, block, 0, stream>>>(
-       input.data<scalar_t>(),
-       grad_output.data<scalar_t>(),
+   DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "batchnorm_backward_reduce",
+     using accscalar_t = at::acc_type<scalar_t_0, true>;
+     reduce_bn_kernel<scalar_t_0, accscalar_t, accscalar_t><<<grid, block, 0, stream>>>(
+       input.data<scalar_t_0>(),
+       grad_output.data<scalar_t_0>(),
       mean.data<accscalar_t>(),
       inv_std.data<accscalar_t>(),
       mean_dy.data<accscalar_t>(),
@@ -970,28 +970,28 @@ std::vector<at::Tensor> reduce_bn_CUDA(
       batch_size,
       feature_size,
       space_size);
-   }));
+     );
  } else {
    if (weight.has_value()) {
-     AT_CHECK(input.type().scalarType() == weight.value().type().scalarType(),
-         "input.type().scalarType() is not supported with weight.type().scalarType()");
+     AT_CHECK(input.scalar_type() == weight.value().scalar_type(),
+         "input.scalar_type() is not supported with weight.scalar_type()");
    }
    using namespace at;
-   AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_backward_reduce", ([&] {
-     using accscalar_t = at::acc_type<scalar_t, true>;
-     reduce_bn_kernel<scalar_t, accscalar_t, scalar_t><<<grid, block, 0, stream>>>(
-       input.data<scalar_t>(),
-       grad_output.data<scalar_t>(),
+   DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "batchnorm_backward_reduce",
+     using accscalar_t = at::acc_type<scalar_t_0, true>;
+     reduce_bn_kernel<scalar_t_0, accscalar_t, scalar_t_0><<<grid, block, 0, stream>>>(
+       input.data<scalar_t_0>(),
+       grad_output.data<scalar_t_0>(),
       mean.data<accscalar_t>(),
       inv_std.data<accscalar_t>(),
       mean_dy.data<accscalar_t>(),
       mean_dy_xmu.data<accscalar_t>(),
-       weight.has_value() ? grad_weight.data<scalar_t>() : NULL,
-       weight.has_value() ? grad_bias.data<scalar_t>() : NULL,
+       weight.has_value() ? grad_weight.data<scalar_t_0>() : NULL,
+       weight.has_value() ? grad_bias.data<scalar_t_0>() : NULL,
       batch_size,
       feature_size,
       space_size);
-   }));
+     );
  }
  return {mean_dy, mean_dy_xmu, grad_weight, grad_bias};
@@ -1021,44 +1021,44 @@ at::Tensor batchnorm_backward_CUDA(
  auto stream = at::cuda::getCurrentCUDAStream();
- if (input.type().scalarType() == at::ScalarType::Half
+ if (input.scalar_type() == at::ScalarType::Half
      && weight.has_value() &&
-     weight.value().type().scalarType() == at::ScalarType::Float) {
+     weight.value().scalar_type() == at::ScalarType::Float) {
    using namespace at;
-   AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_backward", ([&] {
-     using accscalar_t = at::acc_type<scalar_t, true>;
-     batchnorm_backward_kernel<scalar_t, accscalar_t, accscalar_t><<<grid, block, 0, stream>>>(
-       grad_output.data<scalar_t>(),
-       input.data<scalar_t>(),
+   DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "batchnorm_backward",
+     using accscalar_t = at::acc_type<scalar_t_0, true>;
+     batchnorm_backward_kernel<scalar_t_0, accscalar_t, accscalar_t><<<grid, block, 0, stream>>>(
+       grad_output.data<scalar_t_0>(),
+       input.data<scalar_t_0>(),
       mean.data<accscalar_t>(),
       inv_std.data<accscalar_t>(),
       weight.has_value() ? weight.value().data<accscalar_t>() : NULL,
       mean_dy.data<accscalar_t>(),
       mean_dy_xmu.data<accscalar_t>(),
-       grad_input.data<scalar_t>(),
+       grad_input.data<scalar_t_0>(),
       space_size,
       batch_size);
-   }));
+     );
  } else {
    if (weight.has_value()) {
-     AT_CHECK(input.type().scalarType() == weight.value().type().scalarType(),
-         "input.type().scalarType() is not supported with weight.type().scalarType()");
+     AT_CHECK(input.scalar_type() == weight.value().scalar_type(),
+         "input.scalar_type() is not supported with weight.scalar_type()");
    }
    using namespace at;
-   AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_backward", ([&] {
-     using accscalar_t = at::acc_type<scalar_t, true>;
-     batchnorm_backward_kernel<scalar_t, accscalar_t, scalar_t><<<grid, block, 0, stream>>>(
-       grad_output.data<scalar_t>(),
-       input.data<scalar_t>(),
+   DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "batchnorm_backward",
+     using accscalar_t = at::acc_type<scalar_t_0, true>;
+     batchnorm_backward_kernel<scalar_t_0, accscalar_t, scalar_t_0><<<grid, block, 0, stream>>>(
+       grad_output.data<scalar_t_0>(),
+       input.data<scalar_t_0>(),
       mean.data<accscalar_t>(),
       inv_std.data<accscalar_t>(),
-       weight.has_value() ? weight.value().data<scalar_t>() : NULL,
+       weight.has_value() ? weight.value().data<scalar_t_0>() : NULL,
       mean_dy.data<accscalar_t>(),
       mean_dy_xmu.data<accscalar_t>(),
-       grad_input.data<scalar_t>(),
+       grad_input.data<scalar_t_0>(),
       space_size,
       batch_size);
-   }));
+     );
  }
  return grad_input;
@@ -1083,18 +1083,18 @@ std::vector<at::Tensor> welford_parallel_CUDA(const at::Tensor mean_feature_node
  {
    using namespace at;
-   AT_DISPATCH_FLOATING_TYPES_AND_HALF(mean_feature_nodes.type(), "welford_parallel_kernel", ([&] {
-     welford_kernel_parallel<scalar_t><<<grid, block, 0, stream>>>(
-       mean_feature_nodes.data<scalar_t>(),
-       var_biased.data<scalar_t>(),
-       out_mean.data<scalar_t>(),
-       out_var.data<scalar_t>(),
-       inv_std.data<scalar_t>(),
+   DISPATCH_FLOAT_AND_HALF(mean_feature_nodes.scalar_type(), 0, "welford_parallel_kernel",
+     welford_kernel_parallel<scalar_t_0><<<grid, block, 0, stream>>>(
+       mean_feature_nodes.data<scalar_t_0>(),
+       var_biased.data<scalar_t_0>(),
+       out_mean.data<scalar_t_0>(),
+       out_var.data<scalar_t_0>(),
+       inv_std.data<scalar_t_0>(),
       world_size,
       feature_size,
       eps,
       numel);
-   }));
+     );
  }
  return {out_mean, out_var, inv_std};
@@ -1118,27 +1118,27 @@ std::vector<at::Tensor> welford_mean_var_c_last_CUDA(const at::Tensor input) {
  at::Tensor semaphores;
  if (grid.y > 1) {
    staging_data = at::empty({4*stride*grid.y}, option);
-   semaphores = at::zeros({grid.x}, input.options().dtype(at::ScalarType::Int));
+   semaphores = at::zeros({grid.x}, input.options().dtype(at::kInt));
  }
  auto stream = at::cuda::getCurrentCUDAStream();
  {
    using namespace at;
-   AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "welford_mean_var_c_last", ([&] {
-     using accscalar_t = at::acc_type<scalar_t, true>;
+   DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "welford_mean_var_c_last",
+     using accscalar_t = at::acc_type<scalar_t_0, true>;
      accscalar_t* staging_data_ptr = grid.y > 1 ? staging_data.data<accscalar_t>() : nullptr;
      int* semaphores_ptr = grid.y > 1 ? semaphores.data<int>() : nullptr;
-     welford_kernel_c_last<scalar_t, accscalar_t, accscalar_t, ELEMENTS_PER_ITER>
+     welford_kernel_c_last<scalar_t_0, accscalar_t, accscalar_t, ELEMENTS_PER_ITER>
        <<<grid, block, 0, stream>>>(
-       input.data<scalar_t>(),
+       input.data<scalar_t_0>(),
       out_mean.data<accscalar_t>(),
       out_var_biased.data<accscalar_t>(),
       staging_data_ptr,
       semaphores_ptr,
       reduction_size,
       stride);
-   }));
+     );
  }
  return {out_mean, out_var_biased};
@@ -1161,41 +1161,41 @@ at::Tensor batchnorm_forward_c_last_CUDA(
  auto stream = at::cuda::getCurrentCUDAStream();
- if (input.type().scalarType() == at::ScalarType::Half
-     && weight.has_value() && weight.value().type().scalarType() == at::ScalarType::Float) {
+ if (input.scalar_type() == at::ScalarType::Half
+     && weight.has_value() && weight.value().scalar_type() == at::ScalarType::Float) {
    using namespace at;
-   AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_forward", ([&] {
-     using accscalar_t = at::acc_type<scalar_t, true>;
-     batchnorm_forward_c_last_kernel<scalar_t, accscalar_t, accscalar_t, ELEMENTS_PER_ITER>
+   DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "batchnorm_forward",
+     using accscalar_t = at::acc_type<scalar_t_0, true>;
+     batchnorm_forward_c_last_kernel<scalar_t_0, accscalar_t, accscalar_t, ELEMENTS_PER_ITER>
        <<<grid, block, 0, stream>>>(
-       input.data<scalar_t>(),
+       input.data<scalar_t_0>(),
       mean.data<accscalar_t>(),
       inv_std.data<accscalar_t>(),
       weight.has_value() ? weight.value().data<accscalar_t>() : NULL,
       shift.has_value() ? shift.value().data<accscalar_t>(): NULL,
-       out.data<scalar_t>(),
+       out.data<scalar_t_0>(),
       reduction_size,
       stride);
-   }));
+     );
  } else {
    if (weight.has_value()) {
-     AT_CHECK(input.type().scalarType() == weight.value().type().scalarType(),
-         "input.type().scalarType() is not supported with weight.type().scalarType()");
+     AT_CHECK(input.scalar_type() == weight.value().scalar_type(),
+         "input.scalar_type() is not supported with weight.scalar_type()");
    }
    using namespace at;
-   AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_forward", ([&] {
-     using accscalar_t = at::acc_type<scalar_t, true>;
-     batchnorm_forward_c_last_kernel<scalar_t, accscalar_t, scalar_t, ELEMENTS_PER_ITER>
+   DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "batchnorm_forward",
+     using accscalar_t = at::acc_type<scalar_t_0, true>;
+     batchnorm_forward_c_last_kernel<scalar_t_0, accscalar_t, scalar_t_0, ELEMENTS_PER_ITER>
        <<<grid, block, 0, stream>>>(
-       input.data<scalar_t>(),
+       input.data<scalar_t_0>(),
       mean.data<accscalar_t>(),
       inv_std.data<accscalar_t>(),
-       weight.has_value() ? weight.value().data<scalar_t>() : NULL,
-       shift.has_value() ? shift.value().data<scalar_t>(): NULL,
-       out.data<scalar_t>(),
+       weight.has_value() ? weight.value().data<scalar_t_0>() : NULL,
+       shift.has_value() ? shift.value().data<scalar_t_0>(): NULL,
+       out.data<scalar_t_0>(),
       reduction_size,
       stride);
-   }));
+     );
  }
  return out;
}
@@ -1231,22 +1231,22 @@ std::vector<at::Tensor> reduce_bn_c_last_CUDA(
  at::Tensor semaphores;
  if (grid.y > 1) {
    staging_data = at::empty({2*stride*grid.y}, mean.options());
-   semaphores = at::zeros({grid.x}, input.options().dtype(at::ScalarType::Int));
+   semaphores = at::zeros({grid.x}, input.options().dtype(at::kInt));
  }
  auto stream = at::cuda::getCurrentCUDAStream();
- if (input.type().scalarType() == at::ScalarType::Half
+ if (input.scalar_type() == at::ScalarType::Half
      && weight.has_value()
-     && weight.value().type().scalarType() == at::ScalarType::Float) {
+     && weight.value().scalar_type() == at::ScalarType::Float) {
    using namespace at;
-   AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_backward_reduce", ([&] {
-     using accscalar_t = at::acc_type<scalar_t, true>;
+   DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "batchnorm_backward_reduce",
+     using accscalar_t = at::acc_type<scalar_t_0, true>;
      accscalar_t* staging_data_ptr = grid.y > 1 ? staging_data.data<accscalar_t>() : nullptr;
      int* semaphores_ptr = grid.y > 1 ? semaphores.data<int>() : nullptr;
-     reduce_bn_c_last_kernel<scalar_t, accscalar_t, accscalar_t, ELEMENTS_PER_ITER>
+     reduce_bn_c_last_kernel<scalar_t_0, accscalar_t, accscalar_t, ELEMENTS_PER_ITER>
        <<<grid, block, 0, stream>>>(
-       input.data<scalar_t>(),
-       grad_output.data<scalar_t>(),
+       input.data<scalar_t_0>(),
+       grad_output.data<scalar_t_0>(),
       mean.data<accscalar_t>(),
       inv_std.data<accscalar_t>(),
       mean_dy.data<accscalar_t>(),
@@ -1257,32 +1257,32 @@ std::vector<at::Tensor> reduce_bn_c_last_CUDA(
       semaphores_ptr,
       reduction_size,
       stride);
-   }));
+     );
  } else {
    if (weight.has_value()) {
-     AT_CHECK(input.type().scalarType() == weight.value().type().scalarType(),
-         "input.type().scalarType() is not supported with weight.type().scalarType()");
+     AT_CHECK(input.scalar_type() == weight.value().scalar_type(),
+         "input.scalar_type() is not supported with weight.scalar_type()");
    }
    using namespace at;
-   AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_backward_reduce", ([&] {
-     using accscalar_t = at::acc_type<scalar_t, true>;
+   DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "batchnorm_backward_reduce",
+     using accscalar_t = at::acc_type<scalar_t_0, true>;
      accscalar_t* staging_data_ptr = grid.y > 1 ? staging_data.data<accscalar_t>() : nullptr;
      int* semaphores_ptr = grid.y > 1 ? semaphores.data<int>() : nullptr;
-     reduce_bn_c_last_kernel<scalar_t, accscalar_t, scalar_t, ELEMENTS_PER_ITER>
+     reduce_bn_c_last_kernel<scalar_t_0, accscalar_t, scalar_t_0, ELEMENTS_PER_ITER>
        <<<grid, block, 0, stream>>>(
-       input.data<scalar_t>(),
-       grad_output.data<scalar_t>(),
+       input.data<scalar_t_0>(),
+       grad_output.data<scalar_t_0>(),
       mean.data<accscalar_t>(),
       inv_std.data<accscalar_t>(),
       mean_dy.data<accscalar_t>(),
       mean_dy_xmu.data<accscalar_t>(),
-       weight.has_value() ? grad_weight.data<scalar_t>() : NULL,
-       weight.has_value() ?grad_bias.data<scalar_t>() : NULL,
+       weight.has_value() ? grad_weight.data<scalar_t_0>() : NULL,
+       weight.has_value() ?grad_bias.data<scalar_t_0>() : NULL,
       staging_data_ptr,
       semaphores_ptr,
       reduction_size,
       stride);
-   }));
+     );
  }
  return {mean_dy, mean_dy_xmu, grad_weight, grad_bias};
@@ -1307,45 +1307,45 @@ at::Tensor batchnorm_backward_c_last_CUDA(
  auto stream = at::cuda::getCurrentCUDAStream();
- if (input.type().scalarType() == at::ScalarType::Half
-     && weight.has_value() && weight.value().type().scalarType() == at::ScalarType::Float) {
+ if (input.scalar_type() == at::ScalarType::Half
+     && weight.has_value() && weight.value().scalar_type() == at::ScalarType::Float) {
    using namespace at;
-   AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_forward", ([&] {
-     using accscalar_t = at::acc_type<scalar_t, true>;
-     batchnorm_backward_c_last_kernel<scalar_t, accscalar_t, accscalar_t, ELEMENTS_PER_ITER>
+   DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "batchnorm_forward",
+     using accscalar_t = at::acc_type<scalar_t_0, true>;
+     batchnorm_backward_c_last_kernel<scalar_t_0, accscalar_t, accscalar_t, ELEMENTS_PER_ITER>
        <<<grid, block, 0, stream>>>(
-       grad_output.data<scalar_t>(),
-       input.data<scalar_t>(),
+       grad_output.data<scalar_t_0>(),
+       input.data<scalar_t_0>(),
       mean.data<accscalar_t>(),
       inv_std.data<accscalar_t>(),
       weight.has_value() ? weight.value().data<accscalar_t>() : NULL,
       mean_dy.data<accscalar_t>(),
       mean_dy_xmu.data<accscalar_t>(),
-       grad_input.data<scalar_t>(),
+       grad_input.data<scalar_t_0>(),
       reduction_size,
       stride);
-   }));
+     );
  } else {
    if (weight.has_value()) {
-     AT_CHECK(input.type().scalarType() == weight.value().type().scalarType(),
-         "input.type().scalarType() is not supported with weight.type().scalarType()");
+     AT_CHECK(input.scalar_type() == weight.value().scalar_type(),
+         "input.scalar_type() is not supported with weight.scalar_type()");
    }
    using namespace at;
-   AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_forward", ([&] {
-     using accscalar_t = at::acc_type<scalar_t, true>;
-     batchnorm_backward_c_last_kernel<scalar_t, accscalar_t, scalar_t, ELEMENTS_PER_ITER>
+   DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "batchnorm_forward",
+     using accscalar_t = at::acc_type<scalar_t_0, true>;
+     batchnorm_backward_c_last_kernel<scalar_t_0, accscalar_t, scalar_t_0, ELEMENTS_PER_ITER>
        <<<grid, block, 0, stream>>>(
-       grad_output.data<scalar_t>(),
-       input.data<scalar_t>(),
+       grad_output.data<scalar_t_0>(),
+       input.data<scalar_t_0>(),
       mean.data<accscalar_t>(),
       inv_std.data<accscalar_t>(),
-       weight.has_value() ? weight.value().data<scalar_t>() : NULL,
+       weight.has_value() ? weight.value().data<scalar_t_0>() : NULL,
       mean_dy.data<accscalar_t>(),
       mean_dy_xmu.data<accscalar_t>(),
-       grad_input.data<scalar_t>(),
+       grad_input.data<scalar_t_0>(),
       reduction_size,
       stride);
-   }));
+     );
  }
  return grad_input;
...