Commit 42180bd9 authored by Michael Carilli

Forward/backward compatibility around pytorch 3aeb78, to fix #191

parent 975ed322
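
Editor's note (illustration only, not part of the commit): per the comments in the new type_shim.h below, the referenced upstream PyTorch commit changes what the AT_DISPATCH_* macros consume — a const at::Type& before 3aeb78, an at::ScalarType after it. This commit therefore wraps the dispatch argument in a small TypeShim adapter that converts implicitly to either form. The call pattern used throughout the diff looks roughly like this:

    // Illustrative sketch of the changed call sites.
    // TypeShim(g.type()) can be consumed as a const at::Type& (older PyTorch)
    // or as an at::ScalarType (PyTorch builds that include 3aeb78).
    using namespace at; // prevents "toString is undefined" errors (see below)
    AT_DISPATCH_FLOATING_TYPES_AND_HALF(TypeShim(g.type()), "adam_cuda_kernel", ([&] {
        using accscalar_t = at::acc_type<scalar_t, true>;
        // ... launch the CUDA kernel templated on scalar_t / accscalar_t ...
    }));
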
@@ -10,6 +10,8 @@
 #include "ATen/AccumulateType.h"
 #include <THC/THCGeneral.h>
+#include "type_shim.h"
+
 typedef enum{
   ADAM_MODE_0 =0, // eps under square root
   ADAM_MODE_1 =1  // eps outside square root
@@ -29,8 +31,8 @@ __global__ void adam_cuda_kernel(
         const float step_size,
         const size_t tsize,
         adamMode_t mode,
-        const float decay) {
+        const float decay)
+{
         //Assuming 2D grids and 2D blocks
         const int blockId = gridDim.x * blockIdx.y + blockIdx.x;
         const int threadsPerBlock = blockDim.x * blockDim.y;
@@ -67,7 +69,9 @@ void fused_adam_cuda(
         int step,
         int mode,
         int bias_correction,
-        float decay) {
+        float decay)
+{
+//      using namespace at;
         //Get tensor size
         int tsize = p.numel();
@@ -91,7 +95,8 @@ void fused_adam_cuda(
         //all other values should be fp32 for half gradients
         AT_ASSERTM(p.type().scalarType() == at::ScalarType::Float, "expected parameter to be of float type");
         //dispatch is done on the gradient type
-        AT_DISPATCH_FLOATING_TYPES_AND_HALF(g.type(), "adam_cuda_kernel", ([&] {
+        using namespace at; // prevents "toString is undefined" errors
+        AT_DISPATCH_FLOATING_TYPES_AND_HALF(TypeShim(g.type()), "adam_cuda_kernel", ([&] {
           using accscalar_t = at::acc_type<scalar_t, true>;
           adam_cuda_kernel<accscalar_t, scalar_t><<<blocks,threadsPerBlock, 0, stream>>>(
             p.data<accscalar_t>(),
@@ -109,7 +114,8 @@ void fused_adam_cuda(
             decay);
         }));
     } else {
-        AT_DISPATCH_FLOATING_TYPES(g.type(), "adam_cuda_kernel", ([&] {
+        using namespace at;
+        AT_DISPATCH_FLOATING_TYPES(TypeShim(g.type()), "adam_cuda_kernel", ([&] {
          adam_cuda_kernel<scalar_t, scalar_t><<<blocks,threadsPerBlock, 0, stream>>>(
             p.data<scalar_t>(),
             NULL, //don't output p_copy for fp32, it's wasted write
...
@@ -6,6 +6,8 @@
 #include <cuda.h>
 #include <cuda_runtime.h>
+#include "type_shim.h"
+
 template<typename U> __device__
 void cuWelfordOnlineSum(
   const U curr,
@@ -675,7 +677,8 @@ void cuda_layer_norm(
     at::Tensor* beta,
     double epsilon)
 {
-    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input->type(), "layer_norm_cuda_kernel", ([&] {
+    using namespace at;
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(TypeShim(input->type()), "layer_norm_cuda_kernel", ([&] {
       using accscalar_t = at::acc_type<scalar_t, true>;
       HostApplyLayerNorm(
         output->data<scalar_t>(),
@@ -772,7 +775,8 @@ void cuda_layer_norm_gradient(
     at::Tensor* grad_gamma,
     at::Tensor* grad_beta)
 {
-    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input->type(), "cuComputeGradInput", ([&] {
+    using namespace at;
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(TypeShim(input->type()), "cuComputeGradInput", ([&] {
       using accscalar_t = at::acc_type<scalar_t, true>;
       HostLayerNormGradient(
         dout->data<scalar_t>(),
...
@@ -14,7 +14,7 @@
 constexpr int depth_to_max_tensors[5] = {110, 64, 48, 36, 30};
 constexpr int depth_to_max_blocks[5] = {320, 320, 320, 320, 320};
-template<int n> struct TensorList
+template<int n> struct TensorListMetadata
 {
   void* addresses[n][depth_to_max_tensors[n-1]];
   int sizes[depth_to_max_tensors[n-1]];
@@ -62,7 +62,7 @@ void multi_tensor_apply(
   int ntensors = tensor_lists[0].size();
-  TensorList<depth> tl;
+  TensorListMetadata<depth> tl;
   auto stream = at::cuda::getCurrentCUDAStream();
...
@@ -2,9 +2,15 @@
 #include <ATen/AccumulateType.h>
 #include <ATen/cuda/CUDAContext.h>
 #include <ATen/cuda/Exceptions.h>
-#include "multi_tensor_apply.cuh"
+// Another possibility:
+// #include <torch/all.h>
 #include <assert.h>
+// Stringstream is a big hammer, but I want to rely on operator<< for dtype.
+#include <sstream>
+#include "type_shim.h"
+#include "multi_tensor_apply.cuh"
 #define BLOCK_SIZE 512
 #define ILP 4
@@ -15,7 +21,7 @@ struct ScaleFunctor
   __device__ __forceinline__ void operator()(
     int chunk_size,
     volatile int* noop_gmem,
-    TensorList<2>& tl,
+    TensorListMetadata<2>& tl,
     float scale)
   {
     __shared__ int noop_smem;
@@ -87,15 +93,17 @@ void multi_tensor_scale_cuda(
   std::vector<std::vector<at::Tensor>> tensor_lists,
   float scale)
 {
+  using namespace at;
   // The output (downscaled) type is always float.
   // If build times suffer, think about where to put this dispatch,
   // and what logic should be moved out of multi_tensor_apply.
-  AT_DISPATCH_FLOATING_TYPES_AND_HALF(tensor_lists[0][0].type(),
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(TypeShim(tensor_lists[0][0].type()),
     "multi_tensor_scale_cuda",
     [&]
     {
       // using accscalar_t = acc_type<scalar_t, true>;
-      switch(tensor_lists[1][0].type().scalarType())
+      switch(tensor_lists[1][0].scalar_type())
       {
         case at::ScalarType::Half:
           multi_tensor_apply<2>(
@@ -116,8 +124,10 @@ void multi_tensor_scale_cuda(
             scale);
           break;
         default:
-          AT_ERROR("multi_tensor_scale_cuda not implemented for output type = ",
-            tensor_lists[1][0].type().toString());
+          std::stringstream ss;
+          ss << "multi_tensor_scale_cuda not implemented for output type = "
+             << tensor_lists[1][0].dtype();
+          AT_ERROR(ss.str().c_str());
       }
     });
...
+#include <ATen/ATen.h>
+
+// Forward/backward compatibility hack around
+// https://github.com/pytorch/pytorch/commit/3aeb78079bcd68282fe9117088e138b77318e288
+// pending more future-proof guidance from upstream.
+struct TypeShim
+{
+  const at::Type& payload;
+  TypeShim(const at::Type& type) : payload(type) {}
+  // Enable trivial conversion to a const at::Type& for pre-3aeb78
+  operator const at::Type&(){ return payload; };
+  // Enable dispatch switch statements to take *this directly for post-3aeb78
+  operator at::ScalarType(){ return payload.scalarType(); };
+};
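
Editor's note (minimal sketch, not part of the commit): the two conversion operators let the same TypeShim temporary be consumed wherever either a const at::Type& or an at::ScalarType is expected, which is what lets the dispatch call sites in this commit compile against PyTorch builds from both before and after 3aeb78.

    #include <ATen/ATen.h>
    // #include "type_shim.h"  // the header added by this commit

    void demo(const at::Tensor& t)
    {
      TypeShim shim(t.type());
      const at::Type& as_type  = shim;  // conversion used by pre-3aeb78 dispatch macros
      at::ScalarType as_scalar = shim;  // conversion used by post-3aeb78 dispatch switches
      (void)as_type; (void)as_scalar;
    }
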
@@ -3,13 +3,13 @@
 #include <ATen/AccumulateType.h>
 #include <ATen/cuda/CUDAContext.h>
 #include <cuda.h>
 #include <cuda_runtime.h>
 #include <vector>
+#include "type_shim.h"
 __device__ __forceinline__ int lastpow2(int n)
 {
@@ -844,16 +844,19 @@ std::vector<at::Tensor> welford_mean_var_CUDA(const at::Tensor input) {
   auto stream = at::cuda::getCurrentCUDAStream();
-  AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "welford_mean_var_kernel", ([&] {
-    using accscalar_t = at::acc_type<scalar_t, true>;
-    welford_kernel<scalar_t, accscalar_t, accscalar_t><<<grid, block, 0, stream>>>(
-      input.data<scalar_t>(),
-      out_mean.data<accscalar_t>(),
-      out_var_biased.data<accscalar_t>(),
-      batch_size,
-      feature_size,
-      space_size);
-  }));
+  {
+    using namespace at;
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(TypeShim(input.type()), "welford_mean_var_kernel", ([&] {
+      using accscalar_t = at::acc_type<scalar_t, true>;
+      welford_kernel<scalar_t, accscalar_t, accscalar_t><<<grid, block, 0, stream>>>(
+        input.data<scalar_t>(),
+        out_mean.data<accscalar_t>(),
+        out_var_biased.data<accscalar_t>(),
+        batch_size,
+        feature_size,
+        space_size);
+    }));
+  }
   return {out_mean, out_var_biased};
 }
@@ -881,7 +884,8 @@ at::Tensor batchnorm_forward_CUDA(
   if (input.type().scalarType() == at::ScalarType::Half
       && weight.has_value() &&
       weight.value().type().scalarType() == at::ScalarType::Float) {
-    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_forward", ([&] {
+    using namespace at;
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(TypeShim(input.type()), "batchnorm_forward", ([&] {
       using accscalar_t = at::acc_type<scalar_t, true>;
       batchnorm_forward_kernel<scalar_t, accscalar_t, accscalar_t><<<grid, block, 0, stream>>>(
         input.data<scalar_t>(),
@@ -898,7 +902,8 @@ at::Tensor batchnorm_forward_CUDA(
     AT_CHECK(input.type().scalarType() == weight.value().type().scalarType(),
       "input.type().scalarType() is not supported with weight.type().scalarType()");
   }
-  AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_forward", ([&] {
+  using namespace at;
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(TypeShim(input.type()), "batchnorm_forward", ([&] {
     using accscalar_t = at::acc_type<scalar_t, true>;
     batchnorm_forward_kernel<scalar_t, accscalar_t, scalar_t><<<grid, block, 0, stream>>>(
       input.data<scalar_t>(),
@@ -950,7 +955,8 @@ std::vector<at::Tensor> reduce_bn_CUDA(
   if (input.type().scalarType() == at::ScalarType::Half
       && weight.has_value() &&
       weight.value().type().scalarType() == at::ScalarType::Float) {
-    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_backward_reduce", ([&] {
+    using namespace at;
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(TypeShim(input.type()), "batchnorm_backward_reduce", ([&] {
       using accscalar_t = at::acc_type<scalar_t, true>;
       reduce_bn_kernel<scalar_t, accscalar_t, accscalar_t><<<grid, block, 0, stream>>>(
         input.data<scalar_t>(),
@@ -970,7 +976,8 @@ std::vector<at::Tensor> reduce_bn_CUDA(
     AT_CHECK(input.type().scalarType() == weight.value().type().scalarType(),
       "input.type().scalarType() is not supported with weight.type().scalarType()");
   }
-  AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_backward_reduce", ([&] {
+  using namespace at;
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(TypeShim(input.type()), "batchnorm_backward_reduce", ([&] {
     using accscalar_t = at::acc_type<scalar_t, true>;
     reduce_bn_kernel<scalar_t, accscalar_t, scalar_t><<<grid, block, 0, stream>>>(
       input.data<scalar_t>(),
@@ -1017,7 +1024,8 @@ at::Tensor batchnorm_backward_CUDA(
   if (input.type().scalarType() == at::ScalarType::Half
       && weight.has_value() &&
       weight.value().type().scalarType() == at::ScalarType::Float) {
-    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_backward", ([&] {
+    using namespace at;
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(TypeShim(input.type()), "batchnorm_backward", ([&] {
       using accscalar_t = at::acc_type<scalar_t, true>;
       batchnorm_backward_kernel<scalar_t, accscalar_t, accscalar_t><<<grid, block, 0, stream>>>(
         grad_output.data<scalar_t>(),
@@ -1036,7 +1044,8 @@ at::Tensor batchnorm_backward_CUDA(
     AT_CHECK(input.type().scalarType() == weight.value().type().scalarType(),
       "input.type().scalarType() is not supported with weight.type().scalarType()");
   }
-  AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_backward", ([&] {
+  using namespace at;
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(TypeShim(input.type()), "batchnorm_backward", ([&] {
     using accscalar_t = at::acc_type<scalar_t, true>;
     batchnorm_backward_kernel<scalar_t, accscalar_t, scalar_t><<<grid, block, 0, stream>>>(
       grad_output.data<scalar_t>(),
@@ -1072,18 +1081,21 @@ std::vector<at::Tensor> welford_parallel_CUDA(const at::Tensor mean_feature_nodes
   auto stream = at::cuda::getCurrentCUDAStream();
-  AT_DISPATCH_FLOATING_TYPES_AND_HALF(mean_feature_nodes.type(), "welford_parallel_kernel", ([&] {
-    welford_kernel_parallel<scalar_t><<<grid, block, 0, stream>>>(
-      mean_feature_nodes.data<scalar_t>(),
-      var_biased.data<scalar_t>(),
-      out_mean.data<scalar_t>(),
-      out_var.data<scalar_t>(),
-      inv_std.data<scalar_t>(),
-      world_size,
-      feature_size,
-      eps,
-      numel);
-  }));
+  {
+    using namespace at;
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(TypeShim(mean_feature_nodes.type()), "welford_parallel_kernel", ([&] {
+      welford_kernel_parallel<scalar_t><<<grid, block, 0, stream>>>(
+        mean_feature_nodes.data<scalar_t>(),
+        var_biased.data<scalar_t>(),
+        out_mean.data<scalar_t>(),
+        out_var.data<scalar_t>(),
+        inv_std.data<scalar_t>(),
+        world_size,
+        feature_size,
+        eps,
+        numel);
+    }));
+  }
   return {out_mean, out_var, inv_std};
 }
@@ -1111,21 +1123,23 @@ std::vector<at::Tensor> welford_mean_var_c_last_CUDA(const at::Tensor input) {
   auto stream = at::cuda::getCurrentCUDAStream();
-  AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "welford_mean_var_c_last", ([&] {
-    using accscalar_t = at::acc_type<scalar_t, true>;
-    accscalar_t* staging_data_ptr = grid.y > 1 ? staging_data.data<accscalar_t>() : nullptr;
-    int* semaphores_ptr = grid.y > 1 ? semaphores.data<int>() : nullptr;
-    welford_kernel_c_last<scalar_t, accscalar_t, accscalar_t, ELEMENTS_PER_ITER>
-      <<<grid, block, 0, stream>>>(
-      input.data<scalar_t>(),
-      out_mean.data<accscalar_t>(),
-      out_var_biased.data<accscalar_t>(),
-      staging_data_ptr,
-      semaphores_ptr,
-      reduction_size,
-      stride);
-  }));
+  {
+    using namespace at;
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(TypeShim(input.type()), "welford_mean_var_c_last", ([&] {
+      using accscalar_t = at::acc_type<scalar_t, true>;
+      accscalar_t* staging_data_ptr = grid.y > 1 ? staging_data.data<accscalar_t>() : nullptr;
+      int* semaphores_ptr = grid.y > 1 ? semaphores.data<int>() : nullptr;
+      welford_kernel_c_last<scalar_t, accscalar_t, accscalar_t, ELEMENTS_PER_ITER>
+        <<<grid, block, 0, stream>>>(
+        input.data<scalar_t>(),
+        out_mean.data<accscalar_t>(),
+        out_var_biased.data<accscalar_t>(),
+        staging_data_ptr,
+        semaphores_ptr,
+        reduction_size,
+        stride);
+    }));
+  }
   return {out_mean, out_var_biased};
 }
@@ -1149,7 +1163,8 @@ at::Tensor batchnorm_forward_c_last_CUDA(
   if (input.type().scalarType() == at::ScalarType::Half
       && weight.has_value() && weight.value().type().scalarType() == at::ScalarType::Float) {
-    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_forward", ([&] {
+    using namespace at;
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(TypeShim(input.type()), "batchnorm_forward", ([&] {
      using accscalar_t = at::acc_type<scalar_t, true>;
      batchnorm_forward_c_last_kernel<scalar_t, accscalar_t, accscalar_t, ELEMENTS_PER_ITER>
        <<<grid, block, 0, stream>>>(
@@ -1167,7 +1182,8 @@ at::Tensor batchnorm_forward_c_last_CUDA(
     AT_CHECK(input.type().scalarType() == weight.value().type().scalarType(),
       "input.type().scalarType() is not supported with weight.type().scalarType()");
   }
-  AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_forward", ([&] {
+  using namespace at;
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(TypeShim(input.type()), "batchnorm_forward", ([&] {
    using accscalar_t = at::acc_type<scalar_t, true>;
    batchnorm_forward_c_last_kernel<scalar_t, accscalar_t, scalar_t, ELEMENTS_PER_ITER>
      <<<grid, block, 0, stream>>>(
@@ -1222,7 +1238,8 @@ std::vector<at::Tensor> reduce_bn_c_last_CUDA(
   if (input.type().scalarType() == at::ScalarType::Half
       && weight.has_value()
       && weight.value().type().scalarType() == at::ScalarType::Float) {
-    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_backward_reduce", ([&] {
+    using namespace at;
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(TypeShim(input.type()), "batchnorm_backward_reduce", ([&] {
      using accscalar_t = at::acc_type<scalar_t, true>;
      accscalar_t* staging_data_ptr = grid.y > 1 ? staging_data.data<accscalar_t>() : nullptr;
      int* semaphores_ptr = grid.y > 1 ? semaphores.data<int>() : nullptr;
@@ -1246,7 +1263,8 @@ std::vector<at::Tensor> reduce_bn_c_last_CUDA(
     AT_CHECK(input.type().scalarType() == weight.value().type().scalarType(),
       "input.type().scalarType() is not supported with weight.type().scalarType()");
   }
-  AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_backward_reduce", ([&] {
+  using namespace at;
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(TypeShim(input.type()), "batchnorm_backward_reduce", ([&] {
    using accscalar_t = at::acc_type<scalar_t, true>;
    accscalar_t* staging_data_ptr = grid.y > 1 ? staging_data.data<accscalar_t>() : nullptr;
    int* semaphores_ptr = grid.y > 1 ? semaphores.data<int>() : nullptr;
@@ -1291,7 +1309,8 @@ at::Tensor batchnorm_backward_c_last_CUDA(
   if (input.type().scalarType() == at::ScalarType::Half
       && weight.has_value() && weight.value().type().scalarType() == at::ScalarType::Float) {
-    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_forward", ([&] {
+    using namespace at;
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(TypeShim(input.type()), "batchnorm_forward", ([&] {
      using accscalar_t = at::acc_type<scalar_t, true>;
      batchnorm_backward_c_last_kernel<scalar_t, accscalar_t, accscalar_t, ELEMENTS_PER_ITER>
        <<<grid, block, 0, stream>>>(
@@ -1311,7 +1330,8 @@ at::Tensor batchnorm_backward_c_last_CUDA(
     AT_CHECK(input.type().scalarType() == weight.value().type().scalarType(),
       "input.type().scalarType() is not supported with weight.type().scalarType()");
   }
-  AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_forward", ([&] {
+  using namespace at;
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(TypeShim(input.type()), "batchnorm_forward", ([&] {
    using accscalar_t = at::acc_type<scalar_t, true>;
    batchnorm_backward_c_last_kernel<scalar_t, accscalar_t, scalar_t, ELEMENTS_PER_ITER>
      <<<grid, block, 0, stream>>>(
...
@@ -26,7 +26,7 @@ override the defaults established by the ``opt_level``.
 Example::
-    # Declare model and optimizer as usual
+    # Declare model and optimizer as usual, with default (FP32) precision
     model = torch.nn.Linear(D_in, D_out).cuda()
     optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
...
@@ -55,8 +55,8 @@ if "--cuda_ext" in sys.argv:
                                         '--use_fast_math']}))
     ext_modules.append(
         CUDAExtension(name='fused_adam_cuda',
-                      sources=['apex/optimizers/csrc/fused_adam_cuda.cpp',
-                               'apex/optimizers/csrc/fused_adam_cuda_kernel.cu'],
+                      sources=['csrc/fused_adam_cuda.cpp',
+                               'csrc/fused_adam_cuda_kernel.cu'],
                       extra_compile_args={'cxx': ['-O3',],
                                           'nvcc':['-O3',
                                                   '--use_fast_math']}))
@@ -66,8 +66,8 @@ if "--cuda_ext" in sys.argv:
                                 'csrc/welford.cu']))
     ext_modules.append(
         CUDAExtension(name='fused_layer_norm_cuda',
-                      sources=['apex/normalization/csrc/layer_norm_cuda.cpp',
-                               'apex/normalization/csrc/layer_norm_cuda_kernel.cu'],
+                      sources=['csrc/layer_norm_cuda.cpp',
+                               'csrc/layer_norm_cuda_kernel.cu'],
                       extra_compile_args={'cxx': ['-O3'] + version_ge_1_1,
                                           'nvcc':['-maxrregcount=50',
                                                   '-O3',
...