Commit 855808f3 authored by ptrblck, committed by mcarilli

Replace type().scalarType() with scalar_type() (#272)

* change .type().scalarType() to .scalar_type(), and at::ScalarType::X to at::kX

* revert scalar_type() to type() for AT_DISPATCH_FLOATING_TYPES_AND_HALF

* revert scalar_type() to type() in AT_DISPATCH_FLOATING_TYPES

* revert scalar_type() to type() for AT_DISPATCH_FLOATING_TYPES_AND_HALF in welford.cu

* revert scalar_type() to type() in layer_norm_cuda_kernel.cu

* revert at::kType to at::ScalarType::Type

* use DISPATCH_FLOAT_AND_HALF to get rid of warnings

* add dispatch mechanisms for double+float and double+float+half (a usage sketch follows below)
parent 1c464b48
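For orientation, here is a minimal before/after sketch of a dispatch call site under the macros this commit uses. The tensor t, the kernel my_kernel, and the launch configuration are placeholders, not code from this commit; only the pattern mirrors the hunks below.

// Old style: dispatch keyed on at::Type, with the body wrapped in a lambda.
AT_DISPATCH_FLOATING_TYPES_AND_HALF(t.type(), "my_kernel", ([&] {
  my_kernel<scalar_t><<<blocks, threads, 0, stream>>>(t.data<scalar_t>());
}));

// New style: dispatch keyed on at::ScalarType; the LEVEL argument (0 here) is
// pasted into the type alias, so the body uses scalar_t_0, and the body is
// passed directly as the variadic arguments instead of a lambda.
DISPATCH_FLOAT_AND_HALF(t.scalar_type(), 0, "my_kernel",
  my_kernel<scalar_t_0><<<blocks, threads, 0, stream>>>(t.data<scalar_t_0>());)

The LEVEL argument exists so two dispatches can be nested without the type aliases colliding, as the multi_tensor_scale hunk further below does with levels 0 and 1.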
@@ -91,19 +91,19 @@ void fused_adam_cuda(
   }
   cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  if (g.type().scalarType() == at::ScalarType::Half) {
+  if (g.scalar_type() == at::ScalarType::Half) {
     //all other values should be fp32 for half gradients
-    AT_ASSERTM(p.type().scalarType() == at::ScalarType::Float, "expected parameter to be of float type");
+    AT_ASSERTM(p.scalar_type() == at::ScalarType::Float, "expected parameter to be of float type");
     //dispatch is done on the gradient type
     using namespace at; // prevents "toString is undefined" errors
-    AT_DISPATCH_FLOATING_TYPES_AND_HALF(g.type(), "adam_cuda_kernel", ([&] {
-      using accscalar_t = at::acc_type<scalar_t, true>;
-      adam_cuda_kernel<accscalar_t, scalar_t><<<blocks,threadsPerBlock, 0, stream>>>(
+    DISPATCH_FLOAT_AND_HALF(g.scalar_type(), 0, "adam_cuda_kernel",
+      using accscalar_t = at::acc_type<scalar_t_0, true>;
+      adam_cuda_kernel<accscalar_t, scalar_t_0><<<blocks,threadsPerBlock, 0, stream>>>(
         p.data<accscalar_t>(),
-        p_copy.numel() ? p_copy.data<scalar_t>() : NULL,
+        p_copy.numel() ? p_copy.data<scalar_t_0>() : NULL,
         m.data<accscalar_t>(),
         v.data<accscalar_t>(),
-        g.data<scalar_t>(),
+        g.data<scalar_t_0>(),
         beta1,
         beta2,
         eps,
@@ -112,16 +112,16 @@ void fused_adam_cuda(
         tsize,
         (adamMode_t) mode,
         decay);
-    }));
+    )
   } else {
     using namespace at;
-    AT_DISPATCH_FLOATING_TYPES(g.type(), "adam_cuda_kernel", ([&] {
-      adam_cuda_kernel<scalar_t, scalar_t><<<blocks,threadsPerBlock, 0, stream>>>(
-        p.data<scalar_t>(),
+    DISPATCH_DOUBLE_AND_FLOAT(g.scalar_type(), 0, "adam_cuda_kernel",
+      adam_cuda_kernel<scalar_t_0, scalar_t_0><<<blocks,threadsPerBlock, 0, stream>>>(
+        p.data<scalar_t_0>(),
         NULL, //don't output p_copy for fp32, it's wasted write
-        m.data<scalar_t>(),
-        v.data<scalar_t>(),
-        g.data<scalar_t>(),
+        m.data<scalar_t_0>(),
+        v.data<scalar_t_0>(),
+        g.data<scalar_t_0>(),
         beta1,
         beta2,
         eps,
@@ -130,7 +130,7 @@ void fused_adam_cuda(
         tsize,
         (adamMode_t) mode,
         decay);
-    }));
+    );
   }
   THCudaCheck(cudaGetLastError());
...
@@ -129,7 +129,7 @@ std::vector<at::Tensor> layer_norm(
   int n1,n2;
   check_args(input,normalized_shape,n1,n2);
   at::Tensor output = at::empty_like(input);
-  at::Tensor mean = at::empty({n1}, input.options().dtype(input.type().scalarType()==at::ScalarType::Half ? at::ScalarType::Float : input.type().scalarType()));
+  at::Tensor mean = at::empty({n1}, input.options().dtype(input.scalar_type()==at::ScalarType::Half ? at::ScalarType::Float : input.scalar_type()));
   at::Tensor invvar = at::empty_like(mean);
   cuda_layer_norm(&output,&mean,&invvar,&input,n1,n2,
       normalized_shape,NULL,NULL,epsilon);
@@ -151,7 +151,7 @@ std::vector<at::Tensor> layer_norm_affine(
   int n1,n2;
   check_args(input,normalized_shape,gamma,beta,n1,n2);
   at::Tensor output = at::empty_like(input);
-  at::Tensor mean = at::empty({n1}, input.options().dtype(input.type().scalarType()==at::ScalarType::Half ? at::ScalarType::Float : input.type().scalarType()));
+  at::Tensor mean = at::empty({n1}, input.options().dtype(input.scalar_type()==at::ScalarType::Half ? at::ScalarType::Float : input.scalar_type()));
   at::Tensor invvar = at::empty_like(mean);
   cuda_layer_norm(&output,&mean,&invvar,&input,n1,n2,
       normalized_shape,&gamma,&beta,epsilon);
...
@@ -685,18 +685,18 @@ void cuda_layer_norm(
     double epsilon)
 {
     using namespace at;
-    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input->type(), "layer_norm_cuda_kernel", ([&] {
-      using accscalar_t = at::acc_type<scalar_t, true>;
+    DISPATCH_DOUBLE_FLOAT_AND_HALF(input->scalar_type(), 0, "layer_norm_cuda_kernel",
+      using accscalar_t = at::acc_type<scalar_t_0, true>;
       HostApplyLayerNorm(
-        output->data<scalar_t>(),
+        output->data<scalar_t_0>(),
         mean->data<accscalar_t>(),
         invvar->data<accscalar_t>(),
-        input->data<scalar_t>(),
+        input->data<scalar_t_0>(),
         n1,n2,
         epsilon,
-        gamma != NULL ? gamma->data<scalar_t>() : NULL,
-        beta != NULL ? beta->data<scalar_t>() : NULL);
-    }));
+        gamma != NULL ? gamma->data<scalar_t_0>() : NULL,
+        beta != NULL ? beta->data<scalar_t_0>() : NULL);
+    )
 }

 template<typename T, typename U>
@@ -725,7 +725,7 @@ void HostLayerNormGradient(
     const int nshared2_a = 2 * sizeof(U) * threads2.y * threads2.y * (threads2.x + 1);
     const int nshared2_b = threads2.x * threads2.y * sizeof(U);
     const int nshared2 = nshared2_a > nshared2_b ? nshared2_a : nshared2_b;
-    at::Tensor part_grad_gamma = at::empty({part_size,n2}, input->options().dtype(input->type().scalarType()==at::ScalarType::Half ? at::ScalarType::Float : input->type().scalarType()));
+    at::Tensor part_grad_gamma = at::empty({part_size,n2}, input->options().dtype(input->scalar_type()==at::ScalarType::Half ? at::ScalarType::Float : input->scalar_type()));
     at::Tensor part_grad_beta = at::empty_like(part_grad_gamma);
     cuComputePartGradGammaBeta<<<blocks2, threads2, nshared2, stream>>>(
       dout,
@@ -787,19 +787,19 @@ void cuda_layer_norm_gradient(
     at::Tensor* grad_beta)
 {
     using namespace at;
-    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input->type(), "cuComputeGradInput", ([&] {
-      using accscalar_t = at::acc_type<scalar_t, true>;
+    DISPATCH_FLOAT_AND_HALF(input->scalar_type(), 0, "cuComputeGradInput",
+      using accscalar_t = at::acc_type<scalar_t_0, true>;
       HostLayerNormGradient(
-        dout->data<scalar_t>(),
+        dout->data<scalar_t_0>(),
         mean->data<accscalar_t>(),
         invvar->data<accscalar_t>(),
         input,
         n1,n2,
-        gamma->data<scalar_t>(),
-        beta->data<scalar_t>(),
+        gamma->data<scalar_t_0>(),
+        beta->data<scalar_t_0>(),
         epsilon,
-        grad_input->data<scalar_t>(),
-        grad_gamma->data<scalar_t>(),
-        grad_beta->data<scalar_t>());
-    }));
+        grad_input->data<scalar_t_0>(),
+        grad_gamma->data<scalar_t_0>(),
+        grad_beta->data<scalar_t_0>());
+    )
 }
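A note on the statistics dtype seen in the layer-norm hunks above: mean and invvar are allocated as float whenever the input is half, which matches at::acc_type<at::Half, true> resolving to float inside the dispatch body. The commit writes that choice as an inline ternary; the helper below is a hypothetical restatement of the same logic, not code from this commit.

// Hypothetical helper mirroring the inline ternary used above: half inputs
// accumulate their statistics in fp32, other floating types keep their own
// precision.
at::ScalarType stats_dtype_for(const at::Tensor& input) {
  return input.scalar_type() == at::ScalarType::Half
      ? at::ScalarType::Float
      : input.scalar_type();
}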
@@ -75,7 +75,7 @@ at::Tensor multi_tensor_l2norm_cuda(
   at::Tensor noop_flag,
   std::vector<std::vector<at::Tensor>> tensor_lists)
 {
-  auto output = at::zeros({320}, tensor_lists[0][0].options().dtype(at::ScalarType::Float));
+  auto output = at::zeros({320}, tensor_lists[0][0].options().dtype(at::kFloat));
   DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "multi_tensor_l2norm_cuda",
     multi_tensor_apply<1>(
...
@@ -86,39 +86,15 @@ void multi_tensor_scale_cuda(
   // If build times suffer, think about where to put this dispatch,
   // and what logic should be moved out of multi_tensor_apply.
-  AT_DISPATCH_FLOATING_TYPES_AND_HALF(tensor_lists[0][0].type(),
-                                      "multi_tensor_scale_cuda",
-                                      [&]
-    {
-      // using accscalar_t = acc_type<scalar_t, true>;
-      switch(tensor_lists[1][0].scalar_type())
-      {
-        case at::ScalarType::Half:
-          multi_tensor_apply<2>(
-            BLOCK_SIZE,
-            chunk_size,
-            noop_flag,
-            tensor_lists,
-            ScaleFunctor<scalar_t, at::Half>(),
-            scale);
-          break;
-        case at::ScalarType::Float:
-          multi_tensor_apply<2>(
-            BLOCK_SIZE,
-            chunk_size,
-            noop_flag,
-            tensor_lists,
-            ScaleFunctor<scalar_t, float>(),
-            scale);
-          break;
-        default:
-          std::stringstream ss;
-          ss << "multi_tensor_scale_cuda not implemented for output type = "
-             << tensor_lists[1][0].dtype();
-          AT_ERROR(ss.str().c_str());
-      }
-    });
+  DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "multi_tensor_scale_cuda",
+    DISPATCH_FLOAT_AND_HALF(tensor_lists[1][0].scalar_type(), 1, "multi_tensor_scale_cuda",
+      multi_tensor_apply<2>(
+        BLOCK_SIZE,
+        chunk_size,
+        noop_flag,
+        tensor_lists,
+        ScaleFunctor<scalar_t_0, scalar_t_1>(),
+        scale); ))
   AT_CUDA_CHECK(cudaGetLastError());
   // AT_CUDA_CHECK(cudaDeviceSynchronize());
...
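The hunk above is where the LEVEL argument earns its keep: the hand-written switch over the output dtype is replaced by nesting DISPATCH_FLOAT_AND_HALF at level 1 inside level 0, so the input element type is visible as scalar_t_0 and the output element type as scalar_t_1, and ScaleFunctor is instantiated for all four float/half combinations. A stripped-down sketch of the same nesting follows; in, out, and launch_scale are placeholder names, not part of the commit.

DISPATCH_FLOAT_AND_HALF(in.scalar_type(), 0, "scale_example",
  DISPATCH_FLOAT_AND_HALF(out.scalar_type(), 1, "scale_example",
    // scalar_t_0 = element type of in, scalar_t_1 = element type of out
    launch_scale<scalar_t_0, scalar_t_1>(in, out, scale); ))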
@@ -3,15 +3,15 @@
 // Forward/backward compatiblity hack around
 // https://github.com/pytorch/pytorch/commit/3aeb78079bcd68282fe9117088e138b77318e288
 // pending more future-proof guidance from upstream.
-struct TypeShim
-{
-  const at::Type& payload;
-  TypeShim(const at::Type& type) : payload(type) {}
-  // Enable trivial conversion to a const at::Type& for pre-3aeb78
-  operator const at::Type&(){ return payload; };
-  // Enable dispatch switch statements to take *this directly for post-3aeb78
-  operator at::ScalarType(){ return payload.scalarType(); };
-};
+// struct TypeShim
+// {
+//   const at::Type& payload;
+//   TypeShim(const at::Type& type) : payload(type) {}
+//   // Enable trivial conversion to a const at::Type& for pre-3aeb78
+//   operator const at::Type&(){ return payload; };
+//   // Enable dispatch switch statements to take *this directly for post-3aeb78
+//   //operator at::ScalarType(){ return payload.; };
+// };

 #define DISPATCH_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \
   switch(TYPE) \
@@ -33,6 +33,52 @@ struct TypeShim
   }

+#define DISPATCH_DOUBLE_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \
+  switch(TYPE) \
+  { \
+    case at::ScalarType::Double: \
+    { \
+      using scalar_t_##LEVEL = double; \
+      __VA_ARGS__; \
+      break; \
+    } \
+    case at::ScalarType::Float: \
+    { \
+      using scalar_t_##LEVEL = float; \
+      __VA_ARGS__; \
+      break; \
+    } \
+    case at::ScalarType::Half: \
+    { \
+      using scalar_t_##LEVEL = at::Half; \
+      __VA_ARGS__; \
+      break; \
+    } \
+    default: \
+      AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
+  }
+
+#define DISPATCH_DOUBLE_AND_FLOAT(TYPE, LEVEL, NAME, ...) \
+  switch(TYPE) \
+  { \
+    case at::ScalarType::Double: \
+    { \
+      using scalar_t_##LEVEL = double; \
+      __VA_ARGS__; \
+      break; \
+    } \
+    case at::ScalarType::Float: \
+    { \
+      using scalar_t_##LEVEL = float; \
+      __VA_ARGS__; \
+      break; \
+    } \
+    default: \
+      AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
+  }
+
 template<typename T>
 __device__ __forceinline__ T reduce_block_into_lanes
   (T *x,
...
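For reference, here is a hand-expanded sketch of what DISPATCH_DOUBLE_AND_FLOAT(x.scalar_type(), 1, "example", ...) produces, written out to show the ##LEVEL token pasting; x and the body comment are placeholders, and this is an illustration rather than compiler output.

switch (x.scalar_type())
{
  case at::ScalarType::Double:
  {
    using scalar_t_1 = double;   // LEVEL 1 pasted into the alias name
    /* variadic body, written in terms of scalar_t_1 */;
    break;
  }
  case at::ScalarType::Float:
  {
    using scalar_t_1 = float;
    /* variadic body, written in terms of scalar_t_1 */;
    break;
  }
  default:
    AT_ERROR("example", " not implemented for '", toString(x.scalar_type()), "'");
}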
(The diff for one additional file in this commit is collapsed and not shown here.)