Commit c7dcb0e1 authored by Michael Carilli

Merge branch 'master' into testing_cache_fix

parents c619fe6e 06e11bd3
......@@ -112,3 +112,6 @@ class NoOpHandle(object):
@property
def verbose(self):
return False
def _deactivate(self):
pass
......@@ -19,7 +19,7 @@ except ImportError:
warned_syncbn = True
from .sync_batchnorm import SyncBatchNorm
def convert_syncbn_model(module, process_group=None):
def convert_syncbn_model(module, process_group=None, channel_last=False):
'''
Recursively traverse module and its children to replace all
`torch.nn.modules.batchnorm._BatchNorm` with `apex.parallel.SyncBatchNorm`
......@@ -38,14 +38,16 @@ def convert_syncbn_model(module, process_group=None):
'''
mod = module
if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
mod = SyncBatchNorm(module.num_features, module.eps, module.momentum, module.affine, module.track_running_stats, process_group)
mod = SyncBatchNorm(module.num_features, module.eps, module.momentum, module.affine, module.track_running_stats, process_group, channel_last=channel_last)
mod.running_mean = module.running_mean
mod.running_var = module.running_var
if module.affine:
mod.weight.data = module.weight.data.clone().detach()
mod.bias.data = module.bias.data.clone().detach()
for name, child in module.named_children():
mod.add_module(name, convert_syncbn_model(child))
mod.add_module(name, convert_syncbn_model(child,
process_group=process_group,
channel_last=channel_last))
# TODO(jie) should I delete model explicitly?
del module
return mod
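
A minimal usage sketch of the converter, assuming `torch.distributed` has already been initialized and a CUDA device is selected; the toy model here is purely illustrative:

```python
import torch
import apex

# toy model; any module tree containing _BatchNorm children works
model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 100, 3),
    torch.nn.BatchNorm2d(100),
    torch.nn.ReLU(),
).cuda()

# every _BatchNorm child is replaced by apex.parallel.SyncBatchNorm;
# channel_last=True selects the NHWC kernels added in this commit
model = apex.parallel.convert_syncbn_model(model, channel_last=False)
```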
......@@ -38,26 +38,43 @@ class SyncBatchNorm(_BatchNorm):
process_group: pass in a process group within which the stats of the
mini-batch are synchronized. ``None`` for using the default process
group
channel_last: a boolean value that when set to ``True``, this module
takes the last dimension of the input tensor to be the channel
dimension. Default: False
Examples::
>>> # channel first tensor
>>> sbn = apex.parallel.SyncBatchNorm(100).cuda()
>>> inp = torch.randn(10, 100, 14, 14).cuda()
>>> out = sbn(inp)
>>> inp = torch.randn(3, 100, 20).cuda()
>>> out = sbn(inp)
>>> # channel last tensor
>>> sbn = apex.parallel.SyncBatchNorm(100, channel_last=True).cuda()
>>> inp = torch.randn(10, 14, 14, 100).cuda()
>>> out = sbn(inp)
"""
def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True, process_group=None):
def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True, process_group=None, channel_last = False):
super(SyncBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine, track_running_stats=track_running_stats)
self.process_group = process_group
self.channel_last = channel_last
def _specify_process_group(self, process_group):
self.process_group = process_group
def _specify_channel_last(self, channel_last):
self.channel_last = channel_last
def forward(self, input):
if not self.training and self.track_running_stats:
if not self.training and self.track_running_stats and not self.channel_last:
# fall back to pytorch implementation for inference
return F.batch_norm(input, self.running_mean, self.running_var, self.weight, self.bias, False, 0.0, self.eps)
else:
self.num_batches_tracked += 1
return SyncBatchnormFunction.apply(input, self.weight, self.bias, self.running_mean, self.running_var, self.eps, self.track_running_stats, self.momentum, self.process_group)
exponential_average_factor = 0.0
if self.training and self.track_running_stats:
self.num_batches_tracked += 1
if self.momentum is None:
exponential_average_factor = 1.0 / float(self.num_batches_tracked)
else:
exponential_average_factor = self.momentum
return SyncBatchnormFunction.apply(input, self.weight, self.bias, self.running_mean, self.running_var, self.eps, self.training or not self.track_running_stats, exponential_average_factor, self.process_group, self.channel_last)
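
The running-stat blending factor computed in the training branch above follows the usual `torch.nn` BatchNorm convention; a small sketch of just that rule:

```python
def update_factor(momentum, num_batches_tracked):
    """Blending factor used to update running_mean/running_var."""
    if momentum is None:
        return 1.0 / float(num_batches_tracked)  # cumulative moving average
    return momentum                              # exponential moving average

assert update_factor(None, 4) == 0.25
assert update_factor(0.1, 4) == 0.1
```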
......@@ -7,26 +7,40 @@ from apex.parallel import ReduceOp
class SyncBatchnormFunction(Function):
@staticmethod
def forward(ctx, input, weight, bias, running_mean, running_variance, eps, track_running_stats = True, momentum = 1.0, process_group = None):
def forward(ctx, input, weight, bias, running_mean, running_variance, eps, track_running_stats = True, momentum = 1.0, process_group = None, channel_last = False):
torch.cuda.nvtx.range_push("sync_BN_fw")
input = input.contiguous()
world_size = 0
mean = None
var_biased = None
inv_std = None
var = None
out = None
count = None
if track_running_stats:
mean, var, var_biased = syncbn.welford_mean_var(input)
if channel_last:
count = int(input.numel()/input.size(-1))
mean, var_biased = syncbn.welford_mean_var_c_last(input)
else :
count = int(input.numel()/input.size(1))
mean, var_biased = syncbn.welford_mean_var(input)
if torch.distributed.is_initialized():
if not process_group:
process_group = torch.distributed.group.WORLD
world_size = torch.distributed.get_world_size(process_group)
mean_all = torch.empty(world_size, mean.size(0), dtype=mean.dtype, device=mean.device)
var_all = torch.empty(world_size, var.size(0), dtype=var.dtype, device=var.device)
var_all = torch.empty(world_size, var_biased.size(0), dtype=var_biased.dtype, device=var_biased.device)
mean_l = [mean_all.narrow(0, i, 1) for i in range(world_size)]
var_l = [var_all.narrow(0, i, 1) for i in range(world_size)]
torch.distributed.all_gather(mean_l, mean, process_group)
torch.distributed.all_gather(var_l, var_biased, process_group)
mean, var, var_biased = syncbn.welford_parallel(mean_all.transpose(1,0).contiguous(), var_all.transpose(1,0).contiguous(), int(input.numel()/input.size(1)))
mean, var, inv_std = syncbn.welford_parallel(mean_all, var_all, count, eps)
# TODO(Jie): should do fp32 math instead!
else:
inv_std = 1.0 / torch.sqrt(var_biased + eps)
var = var_biased * (count) / (count-1)
r_m_inc = mean if running_mean.dtype != torch.float16 else mean.half()
r_v_inc = var if running_variance.dtype != torch.float16 else var.half()
......@@ -34,14 +48,17 @@ class SyncBatchnormFunction(Function):
running_variance.data = running_variance.data * (1-momentum) + momentum*r_v_inc
else:
mean = running_mean.data
var_biased = running_var.data
inv_std = 1.0 / torch.sqrt(running_variance.data + eps)
ctx.save_for_backward(input, weight, mean, var_biased)
ctx.eps = eps
ctx.save_for_backward(input, weight, mean, inv_std)
ctx.process_group = process_group
ctx.channel_last = channel_last
ctx.world_size = world_size
out = syncbn.batchnorm_forward(input, mean, var_biased, weight, bias, eps)
if channel_last:
out = syncbn.batchnorm_forward_c_last(input, mean, inv_std, weight, bias)
else:
out = syncbn.batchnorm_forward(input, mean, inv_std, weight, bias)
torch.cuda.nvtx.range_pop()
return out
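
For the single-process path above, only the biased (population) variance is needed; a short reference sketch of how `inv_std` and the Bessel-corrected variance used for `running_var` are derived from it:

```python
import torch

def stats_from_biased_var(var_biased, count, eps=1e-5):
    inv_std = 1.0 / torch.sqrt(var_biased + eps)      # used to normalize
    var_unbiased = var_biased * count / (count - 1)   # used for running_var
    return inv_std, var_unbiased
```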
......@@ -53,14 +70,17 @@ class SyncBatchnormFunction(Function):
# mini batch mean & var are calculated by forward path.
# mu = 1./N*np.sum(h, axis = 0)
# var = 1./N*np.sum((h-mu)**2, axis = 0)
saved_input, weight, running_mean, running_variance = ctx.saved_tensors
eps = ctx.eps
saved_input, weight, mean, inv_std = ctx.saved_tensors
process_group = ctx.process_group
channel_last = ctx.channel_last
world_size = ctx.world_size
grad_input = grad_weight = grad_bias = None
# TODO(jie): why do I have to clone here? life time of grad_output?
mean_dy, mean_dy_xmu, grad_weight, grad_bias = syncbn.reduce_bn(grad_output, saved_input, running_mean, running_variance, weight, eps)
if channel_last:
mean_dy, mean_dy_xmu, grad_weight, grad_bias = syncbn.reduce_bn_c_last(grad_output, saved_input, mean, inv_std, weight)
else:
mean_dy, mean_dy_xmu, grad_weight, grad_bias = syncbn.reduce_bn(grad_output, saved_input, mean, inv_std, weight)
# calculate grad_input
if ctx.needs_input_grad[0]:
......@@ -72,7 +92,10 @@ class SyncBatchnormFunction(Function):
torch.distributed.all_reduce(
mean_dy_xmu, ReduceOp.SUM, process_group)
mean_dy_xmu = mean_dy_xmu / world_size
grad_input = syncbn.batchnorm_backward(grad_output, saved_input, running_mean, running_variance, weight, mean_dy, mean_dy_xmu, eps)
if channel_last:
grad_input = syncbn.batchnorm_backward_c_last(grad_output, saved_input, mean, inv_std, weight, mean_dy, mean_dy_xmu)
else:
grad_input = syncbn.batchnorm_backward(grad_output, saved_input, mean, inv_std, weight, mean_dy, mean_dy_xmu)
if weight is None or not ctx.needs_input_grad[1]:
grad_weight = None
......@@ -81,4 +104,4 @@ class SyncBatchnormFunction(Function):
grad_bias = None
torch.cuda.nvtx.range_pop()
return grad_input, grad_weight, grad_bias, None, None, None, None, None, None
return grad_input, grad_weight, grad_bias, None, None, None, None, None, None, None
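
For reference, a sketch of the grad_input formula the `batchnorm_backward*` kernels implement, written for NCHW layout (the broadcast helper is only illustrative); it matches the `grad_input_r` reference computed in the unit tests further down:

```python
import torch

def _bcast(t):
    # hypothetical helper: reshape a per-channel (C,) stat to broadcast over NCHW
    return t.view(1, -1, 1, 1)

def syncbn_grad_input_nchw(grad_out, x, mean, inv_std, weight, mean_dy, mean_dy_xmu):
    return (grad_out - _bcast(mean_dy)
            - (x - _bcast(mean)) * _bcast(inv_std) ** 2 * _bcast(mean_dy_xmu)) \
           * _bcast(weight) * _bcast(inv_std)
```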
......@@ -3,52 +3,93 @@
#include <vector>
// returns {mean,unbiased_var,biased_var}
// returns {mean,biased_var}
// implemented using welford
std::vector<at::Tensor> welford_mean_var_CUDA(const at::Tensor input);
// reduces array of mean/var across processes
// returns global {mean,unbiased_var,biased_var}
// returns global {mean,inv_std,biased_var}
// implemented using welford
std::vector<at::Tensor> welford_parallel_CUDA(const at::Tensor mean_feature_nodes, const at::Tensor var_biased_feature_nodes, int numel);
std::vector<at::Tensor> welford_parallel_CUDA(const at::Tensor mean_feature_nodes,
const at::Tensor var_biased_feature_nodes,
int numel,
const float eps);
// elementwise BN operation, returns output
// input/weight/shift should have identical data type;
// mean/var have promoted data type (dtype==fp16?fp32:dtype)
// mean/inv_std have promoted data type (dtype==fp16?fp32:dtype)
at::Tensor batchnorm_forward_CUDA(const at::Tensor input,
const at::Tensor mean,
const at::Tensor var,
const at::Tensor inv_std,
const at::Tensor weight,
const at::Tensor shift,
const float eps);
const at::Tensor shift);
// backward BN operation, returns {mean_dy, mean_dy_xmu, grad_weight, grad_bias}
// grad_output/input should have identical data type;
// mean/var have promoted data type (dtype==fp16?fp32:dtype)
// mean/inv_std have promoted data type (dtype==fp16?fp32:dtype)
// implemented using kahan summation
std::vector<at::Tensor> reduce_bn_CUDA(const at::Tensor grad_output,
const at::Tensor input,
const at::Tensor mean,
const at::Tensor var,
const at::Tensor weight,
const float eps);
const at::Tensor inv_std,
const at::Tensor weight);
// elementwise backward BN operation, returns grad_input
// grad_output/input/weight precision could be fp16/fp32;
// mean/var/mean_dy/mean_dy_xmu precision is fp32
// mean/inv_std/mean_dy/mean_dy_xmu precision is fp32
at::Tensor batchnorm_backward_CUDA(const at::Tensor grad_output,
const at::Tensor input,
const at::Tensor mean,
const at::Tensor var,
const at::Tensor inv_std,
const at::Tensor weight,
const at::Tensor mean_dy,
const at::Tensor mean_dy_xmu,
const float eps);
const at::Tensor mean_dy_xmu);
// returns {mean, biased_var}
// implemented using welford
// expects data to be in n+c format (channel last) and applies CUDNN_BATCHNORM_SPATIAL
std::vector<at::Tensor> welford_mean_var_c_last_CUDA(const at::Tensor input);
// elementwise BN operation, returns output
// input/weight/shift should have identical data type;
// mean/inv_std have promoted data type (dtype==fp16?fp32:dtype)
// expects data to be in n+c format (channel last) and applies CUDNN_BATCHNORM_SPATIAL
at::Tensor batchnorm_forward_c_last_CUDA(const at::Tensor input,
const at::Tensor mean,
const at::Tensor inv_std,
const at::Tensor weight,
const at::Tensor shift);
// backward BN operation, returns {mean_dy, mean_dy_xmu, grad_weight, grad_bias}
// grad_output/input should have identical data type;
// mean/inv_std have promoted data type (dtype==fp16?fp32:dtype)
// expects data to be in n+c format (channel last) and applies CUDNN_BATCHNORM_SPATIAL
std::vector<at::Tensor> reduce_bn_c_last_CUDA(const at::Tensor grad_output,
const at::Tensor input,
const at::Tensor mean,
const at::Tensor inv_std,
const at::Tensor weight);
// elementwise backward BN operation, returns grad_input
// grad_output/input/weight precision could be fp16/fp32;
// mean/inv_std/mean_dy/mean_dy_xmu precision is fp32
// expects data to be in n+c format (channel last) and applies CUDNN_BATCHNORM_SPATIAL
at::Tensor batchnorm_backward_c_last_CUDA(const at::Tensor grad_output,
const at::Tensor input,
const at::Tensor mean,
const at::Tensor inv_std,
const at::Tensor weight,
const at::Tensor mean_dy,
const at::Tensor mean_dy_xmu);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("welford_mean_var", &welford_mean_var_CUDA, "welford mean variance");
m.def("welford_parallel", &welford_parallel_CUDA, "welford parallel reduce mean variance");
m.def("batchnorm_forward", &batchnorm_forward_CUDA, "batchnorm forward");
m.def("reduce_bn", &reduce_bn_CUDA, "batchnorm backward reduce grad sum and bias/weight gradient");
m.def("reduce_bn", &reduce_bn_CUDA, "batchnorm backward reduce grad sum and bias/weight grad");
m.def("batchnorm_backward", &batchnorm_backward_CUDA, "batchnorm backward dgrad");
m.def("welford_mean_var_c_last", &welford_mean_var_c_last_CUDA, "welford mean variance nhwc");
m.def("batchnorm_forward_c_last", &batchnorm_forward_c_last_CUDA, "batchnorm forward nhwc");
m.def("reduce_bn_c_last", &reduce_bn_c_last_CUDA, "batchnorm backwards reduce grad sum and bias/weight grad nhwc");
m.def("batchnorm_backward_c_last", &batchnorm_backward_c_last_CUDA, "batchnorm backward dgrad nhwc");
}
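
These bindings are what `SyncBatchnormFunction` calls from Python. A sketch of exercising the new channel-last entry points directly, in the same way the unit tests further down do (assumes apex was built with `--cuda_ext` so the `syncbn` extension imports):

```python
import torch
import syncbn

x = torch.randn(8, 14, 14, 100, device="cuda")        # NHWC: channels last
mean, var_biased = syncbn.welford_mean_var_c_last(x)
inv_std = 1.0 / torch.sqrt(var_biased + 1e-5)
weight = torch.ones(100, device="cuda")
bias = torch.zeros(100, device="cuda")
out = syncbn.batchnorm_forward_c_last(x, mean, inv_std, weight, bias)
```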
......@@ -71,8 +71,57 @@ __device__ __forceinline__ T reduce_block(T *x, T val)
return val;
}
#define TILE_W 32
#define MAX_BLOCK_SIZE 1024
#define ELEMENTS_PER_ITER 4 // enables concurrency within each thread to hide latency
#define ELEMENTS_PER_THREAD 16
#define OPTIMAL_TILE_W 32
#define MAX_H_BLOCK 128
#define MAX_BLOCK_SIZE 512
__host__ int div_ru(int x, int y) {
return h_last_pow2(1 + (x-1)/y);
}
__host__ void flexible_launch_configs(
const int reduction,
const int stride,
dim3 &block,
dim3 &grid,
const bool coop_flag = false) {
int block_x = std::min(h_last_pow2(stride), OPTIMAL_TILE_W);
int block_y = std::min(h_last_pow2(div_ru(reduction , ELEMENTS_PER_THREAD)),
MAX_BLOCK_SIZE / block_x);
if (block_x * block_y != MAX_BLOCK_SIZE) {
block_x = std::min(h_last_pow2(stride), MAX_BLOCK_SIZE / block_y);
}
int grid_x = div_ru(stride, block_x);
int grid_y = std::min(div_ru(reduction, block_y * ELEMENTS_PER_THREAD), MAX_H_BLOCK);
if (coop_flag) {
// it's not worth having a grid reduction if the reduction dimension is not big enough
grid_y = grid_y < 8 ? 1 : grid_y;
}
block.x = block_x;
block.y = block_y;
block.z = 1;
grid.x = grid_x;
grid.y = grid_y;
grid.z = 1;
}
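
A Python mirror of `flexible_launch_configs`, handy for reasoning about the block/grid shapes it picks. `h_last_pow2` is defined earlier in this file (outside this hunk); it is assumed here to return the largest power of two not exceeding its argument:

```python
OPTIMAL_TILE_W, MAX_H_BLOCK, MAX_BLOCK_SIZE, ELEMENTS_PER_THREAD = 32, 128, 512, 16

def h_last_pow2(n):
    return 1 << (n.bit_length() - 1) if n > 0 else 0

def div_ru(x, y):
    return h_last_pow2(1 + (x - 1) // y)

def flexible_launch_configs(reduction, stride, coop=False):
    block_x = min(h_last_pow2(stride), OPTIMAL_TILE_W)
    block_y = min(h_last_pow2(div_ru(reduction, ELEMENTS_PER_THREAD)),
                  MAX_BLOCK_SIZE // block_x)
    if block_x * block_y != MAX_BLOCK_SIZE:
        block_x = min(h_last_pow2(stride), MAX_BLOCK_SIZE // block_y)
    grid_x = div_ru(stride, block_x)
    grid_y = min(div_ru(reduction, block_y * ELEMENTS_PER_THREAD), MAX_H_BLOCK)
    if coop and grid_y < 8:   # skip the grid reduction for small reductions
        grid_y = 1
    return (block_x, block_y), (grid_x, grid_y)
```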
template<typename T, typename C>
__device__ __forceinline__ void welford_merge_element(C& count,
T& mean,
T& m2n,
const C& num_new,
const T& mean_new,
const T& m2n_new) {
T factor = T(1.0) / max(1, (count + num_new));
T delta0 = mean - mean_new;
mean = (mean_new * num_new + mean * count) * factor;
m2n += m2n_new + delta0 * delta0 * num_new * count * factor;
count += num_new;
}
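
A quick Python check of the merge rule above (Chan et al.'s parallel variance update): merging two partial (count, mean, M2) triples agrees with a direct computation over the concatenated data:

```python
def welford_merge(count, mean, m2n, num_new, mean_new, m2n_new):
    factor = 1.0 / max(1, count + num_new)
    delta = mean - mean_new
    merged_mean = (mean_new * num_new + mean * count) * factor
    merged_m2n = m2n + m2n_new + delta * delta * num_new * count * factor
    return count + num_new, merged_mean, merged_m2n

def stats(xs):
    m = sum(xs) / len(xs)
    return len(xs), m, sum((x - m) ** 2 for x in xs)

a, b = [1.0, 2.0, 3.0], [10.0, 11.0]
n, mean, m2n = welford_merge(*stats(a), *stats(b))
_, mean_ref, m2n_ref = stats(a + b)
assert abs(mean - mean_ref) < 1e-12 and abs(m2n - m2n_ref) < 1e-9
```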
template<typename T>
__device__ __forceinline__ void warp_reduce_mean_m2n(T &mean, T &m2n, int &num)
......@@ -82,11 +131,7 @@ __device__ __forceinline__ void warp_reduce_mean_m2n(T &mean, T &m2n, int &num)
auto num_new = __shfl_down_sync(0xffffffff, num, i);
auto mean_new = __shfl_down_sync(0xffffffff, mean, i);
auto m2n_new = __shfl_down_sync(0xffffffff, m2n, i);
T factor = 1.0 / max(1, (num+num_new));
auto dif_mean = mean - mean_new;
mean = (mean_new * num_new + mean * num)*factor;
m2n += m2n_new + dif_mean*dif_mean*num*num_new*factor;
num += num_new;
welford_merge_element(num, mean, m2n, num_new, mean_new, m2n_new);
}
}
......@@ -148,13 +193,71 @@ __host__ size_t get_element_data_size(const at::Tensor& input, bool accumulation
return at::elementSize(scalar_type);
}
template<typename T, typename C>
__device__ __forceinline__ void welford_merge_block_vertical(C& count,
T& mean,
T& m2n,
C* shmem_count,
T* shmem_mean,
T* shmem_m2n) {
// write to shared memory
auto address_base = threadIdx.x + threadIdx.y * blockDim.x;
shmem_mean[address_base] = mean;
shmem_m2n[address_base] = m2n;
shmem_count[address_base] = count;
#pragma unroll
for (int offset = blockDim.y/2; offset > 0; offset >>= 1) {
__syncthreads();
if (threadIdx.y < offset && threadIdx.y + offset < blockDim.y) {
auto address = address_base + offset * blockDim.x;
// read shared memory back to register for reduction
auto num_new = shmem_count[address];
auto mean_new = shmem_mean[address];
auto m2n_new = shmem_m2n[address];
welford_merge_element(count, mean, m2n, num_new, mean_new, m2n_new);
// last write is not necessary
shmem_mean[address_base] = mean;
shmem_m2n[address_base] = m2n;
shmem_count[address_base] = count;
}
}
}
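
The loop above is a standard halving-offset tree reduction along `blockDim.y`; `merge_block_vertical` below reuses the same shape for plain sums. Sketched on a plain list:

```python
def fold_halves(vals, merge):
    # stands in for a blockDim.y tree reduction (len(vals) a power of two);
    # element 0 ends up holding the block-wide result
    n = len(vals)
    offset = n // 2
    while offset > 0:
        for i in range(offset):
            if i + offset < n:
                vals[i] = merge(vals[i], vals[i + offset])
        offset //= 2
    return vals[0]

assert fold_halves([1, 2, 3, 4, 5, 6, 7, 8], lambda a, b: a + b) == 36
```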
template<typename T>
__device__ __forceinline__ void merge_block_vertical(T& sum_dy,
T& sum_dy_xmu,
T* shmem_sum_dy,
T* shmem_sum_dy_xmu) {
// write to shared memory
auto address_base = threadIdx.x + threadIdx.y * blockDim.x;
shmem_sum_dy[address_base] = sum_dy;
shmem_sum_dy_xmu[address_base] = sum_dy_xmu;
#pragma unroll
for (int offset = blockDim.y/2; offset > 0; offset >>= 1) {
__syncthreads();
if (threadIdx.y < offset && threadIdx.y + offset < blockDim.y) {
auto address = address_base + offset * blockDim.x;
sum_dy += shmem_sum_dy[address];
sum_dy_xmu += shmem_sum_dy_xmu[address];
// last write is not necessary
shmem_sum_dy[address_base] = sum_dy;
shmem_sum_dy_xmu[address_base] = sum_dy_xmu;
}
}
}
// welford kernel calculating mean/biased_variance/unbiased_variance
template <typename scalar_t, typename accscalar_t, typename outscalar_t>
__global__ void welford_kernel(
const scalar_t* __restrict__ input,
outscalar_t* __restrict__ out_mean,
outscalar_t* __restrict__ out_var,
outscalar_t* __restrict__ out_var_biased,
const int bs,
const int fs,
......@@ -185,7 +288,6 @@ __global__ void welford_kernel(
if (thread_id == 0) {
out_mean[blockIdx.x] = static_cast<outscalar_t>(x_mean);
out_var[blockIdx.x] = static_cast<outscalar_t>(m_2_n/(count-1));
out_var_biased[blockIdx.x] = static_cast<outscalar_t>(m_2_n/count);
}
}
......@@ -195,15 +297,14 @@ template <typename scalar_t, typename accscalar_t, typename layerscalar_t>
__global__ void batchnorm_forward_kernel(
const scalar_t* __restrict__ input,
const accscalar_t* __restrict__ mean,
const accscalar_t* __restrict__ var,
const accscalar_t* __restrict__ inv_std,
const layerscalar_t* __restrict__ weight,
const layerscalar_t* __restrict__ shift,
scalar_t* __restrict__ out,
const int ss,
const int bs,
const float eps) {
const int bs) {
auto m_c = mean[blockIdx.x];
auto inv_std_c = static_cast<accscalar_t>(rsqrt(var[blockIdx.x] + eps));
auto inv_std_c = inv_std[blockIdx.x];
auto w_c = static_cast<accscalar_t>(weight[blockIdx.x]);
auto s_c = static_cast<accscalar_t>(shift[blockIdx.x]);
......@@ -224,22 +325,21 @@ __global__ void reduce_bn_kernel(
const scalar_t* __restrict__ input,
const scalar_t* __restrict__ grad_output,
const accscalar_t* __restrict__ mean,
const accscalar_t* __restrict__ var,
const accscalar_t* __restrict__ inv_std,
accscalar_t* __restrict__ mean_dy,
accscalar_t* __restrict__ mean_dy_xmu,
layerscalar_t* __restrict__ grad_weight,
layerscalar_t* __restrict__ grad_bias,
const int bs,
const int fs,
const int ss,
const float eps) {
const int ss) {
static __shared__ int s_mem[64];
int total_item_num = bs * ss;
int thread_id = threadIdx.y*blockDim.x + threadIdx.x;
auto r_mean = mean[blockIdx.x];
auto factor = accscalar_t(1.0) / (accscalar_t)sqrt(var[blockIdx.x] + eps);
auto factor = inv_std[blockIdx.x];
// Kahan sum
accscalar_t sum_dy = 0.0;
......@@ -283,64 +383,437 @@ __global__ void batchnorm_backward_kernel(
const scalar_t* __restrict__ grad_output,
const scalar_t* __restrict__ input,
const accscalar_t* __restrict__ mean,
const accscalar_t* __restrict__ var,
const accscalar_t* __restrict__ inv_std,
const layerscalar_t* __restrict__ weight,
const accscalar_t* __restrict__ mean_dy,
const accscalar_t* __restrict__ mean_dy_xmu,
scalar_t* __restrict__ grad_input,
const int ss,
const int bs,
const float eps) {
const int bs) {
auto m_c = static_cast<accscalar_t>(mean[blockIdx.x]);
auto m_dy_c = static_cast<accscalar_t>(mean_dy[blockIdx.x]);
auto factor_1_c = static_cast<accscalar_t>(var[blockIdx.x]) + eps;
auto factor_2_c = static_cast<accscalar_t>(weight[blockIdx.x]) / sqrt(factor_1_c);
factor_1_c /= static_cast<accscalar_t>(mean_dy_xmu[blockIdx.x]);
auto factor_1_c = inv_std[blockIdx.x];
auto factor_2_c = static_cast<accscalar_t>(weight[blockIdx.x]) * factor_1_c;
factor_1_c = factor_1_c * factor_1_c * mean_dy_xmu[blockIdx.x];
for (int batch_offset = blockIdx.y*blockDim.y+threadIdx.y; batch_offset < bs; batch_offset += gridDim.y*blockDim.y) {
int address_base = blockIdx.x*ss + batch_offset*gridDim.x*ss;
for (int offset = threadIdx.x + blockIdx.z*blockDim.x; offset < ss ; offset+= gridDim.z*blockDim.x) {
grad_input[address_base+offset] = (static_cast<accscalar_t>(grad_output[address_base+offset]) - m_dy_c - (static_cast<accscalar_t>(input[address_base+offset]) - m_c) / factor_1_c) * factor_2_c;
grad_input[address_base+offset] = (static_cast<accscalar_t>(grad_output[address_base+offset]) - m_dy_c - (static_cast<accscalar_t>(input[address_base+offset]) - m_c) * factor_1_c) * factor_2_c;
}
}
}
// parallel welford kernel to further reduce mean / biased_var / unbiased_var
// across multiple processes.
template <typename scalar_t, typename accscalar_t>
// welford kernel for c last tensor calculating mean/biased_variance/unbiased_variance
template
<typename scalar_t,
typename accscalar_t,
typename outscalar_t,
int PARALLEL_LOADS>
__global__ void
welford_kernel_c_last(
const scalar_t* __restrict__ input,
outscalar_t* __restrict__ out_mean,
outscalar_t* __restrict__ out_var_biased,
volatile accscalar_t* staging_data,
int* semaphores,
const int reduction_size,
const int stride) {
// hide latency with concurrency
accscalar_t x_mean[PARALLEL_LOADS];
accscalar_t m_2_n[PARALLEL_LOADS];
int count[PARALLEL_LOADS];
#pragma unroll
for (int i = 0; i < PARALLEL_LOADS; i++) {
x_mean[i] = accscalar_t(0);
m_2_n[i] = accscalar_t(0);
count[i] = accscalar_t(0);
}
// tensor dimension (m,c)
// loop along m dimension
int inner_loop_stride = blockDim.y * gridDim.y;
// offset along m dimension
int m_offset = blockIdx.y * blockDim.y + threadIdx.y;
int c_offset = blockIdx.x * blockDim.x + threadIdx.x;
int loop_count = 1 + (reduction_size - 1) / (inner_loop_stride * PARALLEL_LOADS);
int address_base = m_offset * stride + c_offset;
int address_increment = inner_loop_stride * stride;
for (int i = 0; i < loop_count; i++) {
accscalar_t x_math[PARALLEL_LOADS];
accscalar_t x_count_inv[PARALLEL_LOADS];
accscalar_t is_valid[PARALLEL_LOADS];
// load multiple data in
#pragma unroll
for (int j = 0; j < PARALLEL_LOADS; j++) {
if (c_offset < stride && m_offset < reduction_size) {
x_math[j] = input[address_base];
count[j]++;
x_count_inv[j] = accscalar_t(1) / count[j];
is_valid[j] = accscalar_t(1);
} else {
x_math[j] = accscalar_t(0);
x_count_inv[j] = accscalar_t(0);
is_valid[j] = accscalar_t(0);
}
m_offset += inner_loop_stride;
address_base += address_increment;
}
// calculate mean/m2n with welford
#pragma unroll
for (int j = 0; j < PARALLEL_LOADS; j++) {
accscalar_t delta0 = x_math[j] - x_mean[j];
x_mean[j] += delta0 * x_count_inv[j];
accscalar_t delta1 = x_math[j] - x_mean[j];
m_2_n[j] += delta0 * delta1 * is_valid[j];
}
}
// thread reduction to accumulate mean/m_2_n/count between PARALLEL_LOADS
#pragma unroll
for (int j = 1; j < PARALLEL_LOADS; j++) {
welford_merge_element(count[0], x_mean[0], m_2_n[0], count[j], x_mean[j], m_2_n[j]);
}
// release x_mean / m_2_n
auto mean_th = x_mean[0];
auto m2_th = m_2_n[0];
auto count_th = count[0];
// block-wise reduction with shared memory (since reduction cannot be done within a warp)
static __shared__ accscalar_t shmem_mean[MAX_BLOCK_SIZE];
static __shared__ accscalar_t shmem_m2n[MAX_BLOCK_SIZE];
static __shared__ int shmem_count[MAX_BLOCK_SIZE];
welford_merge_block_vertical(count_th, mean_th, m2_th, shmem_count, shmem_mean, shmem_m2n);
// grid reduction if needed (coop launch was used in the first place)
if (gridDim.y > 1) {
volatile accscalar_t* staging_mean = staging_data;
volatile accscalar_t* staging_m2n = &staging_data[stride*gridDim.y];
volatile int* staging_count = reinterpret_cast<volatile int*>(&staging_m2n[stride*gridDim.y]);
address_base = c_offset + blockIdx.y * stride;
// write data to staging_data;
if (threadIdx.y == 0 && c_offset < stride) {
staging_mean[address_base] = mean_th;
staging_m2n[address_base] = m2_th;
staging_count[address_base] = count_th;
}
__threadfence();
__syncthreads(); // ensuring writes to staging_ are visible to all blocks
__shared__ bool is_last_block_done;
// mark block done
if (threadIdx.x == 0 && threadIdx.y == 0) {
int old = atomicAdd(&semaphores[blockIdx.x], 1);
is_last_block_done = (old == (gridDim.y-1));
}
__syncthreads();
// check that all data is now available in global memory
if (is_last_block_done) {
count_th = 0;
mean_th = accscalar_t(0.0);
m2_th = accscalar_t(0.0);
for (int y = threadIdx.y; y < gridDim.y; y += blockDim.y) {
address_base = c_offset + y * stride;
int num_new = c_offset < stride ? staging_count[address_base] : 0;
accscalar_t mean_new = c_offset < stride ? staging_mean[address_base] : accscalar_t(0.0);
accscalar_t m2n_new = c_offset < stride ? staging_m2n[address_base] : accscalar_t(0.0);
welford_merge_element(count_th, mean_th, m2_th, num_new, mean_new, m2n_new);
}
welford_merge_block_vertical(count_th, mean_th, m2_th, shmem_count, shmem_mean, shmem_m2n);
if (threadIdx.y == 0 && c_offset < stride) {
out_mean[c_offset] = static_cast<outscalar_t>(mean_th);
out_var_biased[c_offset] = static_cast<outscalar_t>(m2_th / count_th);
}
}
} else {
if (blockIdx.y == 0 && threadIdx.y == 0 && c_offset < stride) {
out_mean[c_offset] = static_cast<outscalar_t>(mean_th);
out_var_biased[c_offset] = static_cast<outscalar_t>(m2_th / count_th);
}
}
}
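
For sanity-checking the channel-last kernel above, the quantities it produces are simply the per-channel mean and biased (population) variance over every dimension except the last; a short reference:

```python
import torch

def channel_last_stats(x):
    flat = x.reshape(-1, x.size(-1)).float()   # (m, c) view of an NHWC tensor
    return flat.mean(dim=0), flat.var(dim=0, unbiased=False)

x = torch.randn(4, 7, 7, 16)
mean, var_biased = channel_last_stats(x)
```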
// parallel welford kernel to further reduce mean / biased_var
// into mean / unbiased_var / inv_std across multiple processes.
template <typename scalar_t>
__global__ void welford_kernel_parallel(
const scalar_t* __restrict__ mean,
const scalar_t* __restrict__ var_biased,
scalar_t* __restrict__ out_mean,
scalar_t* __restrict__ out_var,
scalar_t* __restrict__ out_var_biased,
const int ns,
const int fs,
scalar_t* __restrict__ inv_std,
const int world_size,
const int feature_size,
const float eps,
const int numel) {
static __shared__ int s_mem[160];
int block_size = blockDim.x;
accscalar_t* s_mem_ac = (accscalar_t*) &s_mem[32];
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < feature_size; i += gridDim.x * blockDim.x) {
// load data;
int address = i;
scalar_t x_mean = 0;
scalar_t m_2_n = 0;
int count = 0;
for (int j = 0; j < world_size; j++) {
welford_merge_element(count, x_mean, m_2_n, numel, mean[address], var_biased[address]*numel);
address += feature_size;
}
out_mean[i] = x_mean;
out_var[i] = m_2_n/ (count - 1);
inv_std[i] = scalar_t(1) / sqrt(m_2_n/count + eps);
}
}
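
A Python sketch of the per-feature merge this kernel performs: each rank contributes `(numel, mean, biased_var)`, and `biased_var * numel` recovers the M2 term so the Welford merge rule applies exactly; the merged statistics then yield the unbiased variance and `inv_std`:

```python
import math

def merge_across_ranks(means, biased_vars, numel, eps=1e-5):
    count, mean, m2n = 0, 0.0, 0.0
    for m_i, v_i in zip(means, biased_vars):
        n_new = numel
        factor = 1.0 / max(1, count + n_new)
        delta = mean - m_i
        mean = (m_i * n_new + mean * count) * factor
        m2n += v_i * n_new + delta * delta * n_new * count * factor
        count += n_new
    var_unbiased = m2n / (count - 1)
    inv_std = 1.0 / math.sqrt(m2n / count + eps)
    return mean, var_unbiased, inv_std
```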
int input_base = blockIdx.x*ns + threadIdx.x;
int thread_id = threadIdx.x;
// elementwise BN kernel
template <
typename scalar_t,
typename accscalar_t,
typename layerscalar_t,
int PARALLEL_LOADS>
__global__ void batchnorm_forward_c_last_kernel(
const scalar_t* __restrict__ input,
const accscalar_t* __restrict__ mean,
const accscalar_t* __restrict__ inv_std,
const layerscalar_t* __restrict__ weight,
const layerscalar_t* __restrict__ shift,
scalar_t* __restrict__ out,
const int reduction_size,
const int stride) {
// tensor dimension (m,c)
// loop along m dimension
int inner_loop_stride = blockDim.y * gridDim.y;
// offset along m dimension
int m_offset = blockIdx.y * blockDim.y + threadIdx.y;
int c_offset = blockIdx.x * blockDim.x + threadIdx.x;
auto m_c = mean[c_offset];
auto inv_std_c = static_cast<accscalar_t>(inv_std[c_offset]);
auto w_c = static_cast<accscalar_t>(weight[c_offset]);
auto s_c = static_cast<accscalar_t>(shift[c_offset]);
int loop_count = 1 + (reduction_size - 1) / (inner_loop_stride * PARALLEL_LOADS);
int address_base = m_offset * stride + c_offset;
int address_increment = inner_loop_stride * stride;
for (int i = 0; i < loop_count; i++) {
#pragma unroll
for (int j = 0; j < PARALLEL_LOADS; j++) {
if (c_offset < stride && m_offset < reduction_size) {
out[address_base] = static_cast<scalar_t>(
w_c * (static_cast<accscalar_t>(input[address_base]) - m_c ) * inv_std_c + s_c
);
}
m_offset += inner_loop_stride;
address_base += address_increment;
}
}
}
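
The elementwise math above is the ordinary affine batch-norm transform; in channel-last layout the per-channel `(C,)` statistics broadcast along the last dimension with no reshaping, e.g.:

```python
import torch

def bn_forward_channel_last(x, mean, inv_std, weight, shift):
    # x is NHWC; mean/inv_std/weight/shift all have shape (C,)
    return weight * (x - mean) * inv_std + shift
```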
// load data;
auto x_mean = static_cast<accscalar_t>(mean[input_base]);
auto m_2_n = static_cast<accscalar_t>(var_biased[input_base]) * numel;
auto count = numel;
// batchnorm backward kernel for c last tensor
template
<typename scalar_t,
typename accscalar_t,
typename layerscalar_t,
int PARALLEL_LOADS>
__global__ void reduce_bn_c_last_kernel(
const scalar_t* __restrict__ input,
const scalar_t* __restrict__ grad_output,
const accscalar_t* __restrict__ mean,
const accscalar_t* __restrict__ inv_std,
accscalar_t* __restrict__ mean_dy,
accscalar_t* __restrict__ mean_dy_xmu,
layerscalar_t* __restrict__ grad_weight,
layerscalar_t* __restrict__ grad_bias,
volatile accscalar_t* staging_data,
int* semaphores,
const int reduction_size,
const int stride) {
// hide latency with concurrency
accscalar_t sum_dy[PARALLEL_LOADS];
accscalar_t sum_dy_xmu[PARALLEL_LOADS];
#pragma unroll
for (int i = 0; i < PARALLEL_LOADS; i++) {
sum_dy[i] = accscalar_t(0);
sum_dy_xmu[i] = accscalar_t(0);
}
// tensor dimension (m,c)
// loop along m dimension
int inner_loop_stride = blockDim.y * gridDim.y;
// offset along m dimension
int m_offset = blockIdx.y * blockDim.y + threadIdx.y;
int c_offset = blockIdx.x * blockDim.x + threadIdx.x;
int loop_count = 1 + (reduction_size - 1) / (inner_loop_stride * PARALLEL_LOADS);
int address_base = m_offset * stride + c_offset;
int address_increment = inner_loop_stride * stride;
auto r_mean = mean[c_offset];
auto factor = inv_std[c_offset];
for (int i = 0; i < loop_count; i++) {
accscalar_t x_input[PARALLEL_LOADS];
accscalar_t x_grad_output[PARALLEL_LOADS];
// load multiple data in
#pragma unroll
for (int j = 0; j < PARALLEL_LOADS; j++) {
if (c_offset < stride && m_offset < reduction_size) {
x_input[j] = input[address_base];
x_grad_output[j] = grad_output[address_base];
} else {
x_input[j] = accscalar_t(0);
x_grad_output[j] = accscalar_t(0);
}
m_offset += inner_loop_stride;
address_base += address_increment;
}
__syncthreads();
// calculate sum_dy / sum_dy_xmu
#pragma unroll
for (int j = 0; j < PARALLEL_LOADS; j++) {
sum_dy[j] += x_grad_output[j];
sum_dy_xmu[j] += x_grad_output[j] * (x_input[j] - r_mean);
}
}
welford_reduce_mean_m2n<accscalar_t>(s_mem_ac, s_mem, x_mean, m_2_n, count, block_size, thread_id);
// thread reduction to accumulate sum_dy / sum_dy_xmu between PARALLEL_LOADS
#pragma unroll
for (int j = 1; j < PARALLEL_LOADS; j++) {
sum_dy[0] += sum_dy[j];
sum_dy_xmu[0] += sum_dy_xmu[j];
}
if (thread_id == 0) {
out_mean[blockIdx.x] = static_cast<scalar_t>(x_mean);
out_var[blockIdx.x] = static_cast<scalar_t>(m_2_n/(count-1));
out_var_biased[blockIdx.x] = static_cast<scalar_t>(m_2_n/count);
// release array of registers
auto sum_dy_th = sum_dy[0];
auto sum_dy_xmu_th = sum_dy_xmu[0];
// block-wise reduction with shared memory (since reduction cannot be done within a warp)
static __shared__ accscalar_t shmem_sum_dy[MAX_BLOCK_SIZE];
static __shared__ accscalar_t shmem_sum_dy_xmu[MAX_BLOCK_SIZE];
merge_block_vertical(sum_dy_th, sum_dy_xmu_th, shmem_sum_dy, shmem_sum_dy_xmu);
// grid reduction if needed (coop launch was used in the first place)
if (gridDim.y > 1) {
volatile accscalar_t* staging_sum_dy = staging_data;
volatile accscalar_t* staging_sum_dy_xmu = &staging_data[stride*gridDim.y];
address_base = c_offset + blockIdx.y * stride;
// write data to staging_data;
if (threadIdx.y == 0 && c_offset < stride) {
staging_sum_dy[address_base] = sum_dy_th;
staging_sum_dy_xmu[address_base] = sum_dy_xmu_th;
}
__threadfence();
__syncthreads(); // ensuring writes to staging_ are visible to all blocks
__shared__ bool is_last_block_done;
// mark block done
if (threadIdx.x == 0 && threadIdx.y == 0) {
int old = atomicAdd(&semaphores[blockIdx.x], 1);
is_last_block_done = (old == (gridDim.y-1));
}
__syncthreads();
// check that all data is now available in global memory
if (is_last_block_done) {
sum_dy_th = accscalar_t(0.0);
sum_dy_xmu_th = accscalar_t(0.0);
for (int y = threadIdx.y; y < gridDim.y; y += blockDim.y) {
address_base = c_offset + y * stride;
sum_dy_th += (c_offset < stride ? staging_sum_dy[address_base] : accscalar_t(0.0));
sum_dy_xmu_th += (c_offset < stride ? staging_sum_dy_xmu[address_base] : accscalar_t(0.0));
}
merge_block_vertical(sum_dy_th, sum_dy_xmu_th, shmem_sum_dy, shmem_sum_dy_xmu);
if (threadIdx.y == 0 && c_offset < stride) {
grad_bias[c_offset] = static_cast<layerscalar_t>(sum_dy_th);
grad_weight[c_offset] = static_cast<layerscalar_t>(sum_dy_xmu_th * factor);
mean_dy[c_offset] = sum_dy_th / reduction_size;
mean_dy_xmu[c_offset] = sum_dy_xmu_th / reduction_size;
}
}
} else {
if (blockIdx.y == 0 && threadIdx.y == 0 && c_offset < stride) {
grad_bias[c_offset] = static_cast<layerscalar_t>(sum_dy_th);
grad_weight[c_offset] = static_cast<layerscalar_t>(sum_dy_xmu_th * factor);
mean_dy[c_offset] = sum_dy_th / reduction_size;
mean_dy_xmu[c_offset] = sum_dy_xmu_th / reduction_size;
}
}
}
// elementwise BN kernel
template <
typename scalar_t,
typename accscalar_t,
typename layerscalar_t,
int PARALLEL_LOADS>
__global__ void batchnorm_backward_c_last_kernel(
const scalar_t* __restrict__ grad_output,
const scalar_t* __restrict__ input,
const accscalar_t* __restrict__ mean,
const accscalar_t* __restrict__ inv_std,
const layerscalar_t* __restrict__ weight,
const accscalar_t* __restrict__ mean_dy,
const accscalar_t* __restrict__ mean_dy_xmu,
scalar_t* __restrict__ grad_input,
const int reduction_size,
const int stride) {
// tensor dimension (m,c)
// loop along m dimension
int inner_loop_stride = blockDim.y * gridDim.y;
// offset along m dimension
int m_offset = blockIdx.y * blockDim.y + threadIdx.y;
int c_offset = blockIdx.x * blockDim.x + threadIdx.x;
auto m_c = mean[c_offset];
auto m_dy_c = mean_dy[c_offset];
auto factor_1_c = inv_std[c_offset];
auto factor_2_c = static_cast<accscalar_t>(weight[c_offset]) * factor_1_c;
factor_1_c = factor_1_c * factor_1_c * mean_dy_xmu[c_offset];
int loop_count = 1 + (reduction_size - 1) / (inner_loop_stride * PARALLEL_LOADS);
int address_base = m_offset * stride + c_offset;
int address_increment = inner_loop_stride * stride;
for (int i = 0; i < loop_count; i++) {
#pragma unroll
for (int j = 0; j < PARALLEL_LOADS; j++) {
if (c_offset < stride && m_offset < reduction_size) {
grad_input[address_base] = static_cast<scalar_t>(
(static_cast<accscalar_t>(grad_output[address_base]) - m_dy_c -
(static_cast<accscalar_t>(input[address_base]) - m_c) * factor_1_c)
* factor_2_c);
}
m_offset += inner_loop_stride;
address_base += address_increment;
}
}
}
std::vector<at::Tensor> welford_mean_var_CUDA(const at::Tensor input) {
const auto batch_size = input.size(0);
......@@ -349,7 +822,6 @@ std::vector<at::Tensor> welford_mean_var_CUDA(const at::Tensor input) {
auto space_size = get_tensor_spatial_size(input);
auto scalar_type = promote_scalartype(input);
at::Tensor out_var = at::empty({feature_size}, input.options().dtype(scalar_type));
at::Tensor out_var_biased = at::empty({feature_size}, input.options().dtype(scalar_type));
at::Tensor out_mean = at::empty({feature_size}, input.options().dtype(scalar_type));
......@@ -358,7 +830,6 @@ std::vector<at::Tensor> welford_mean_var_CUDA(const at::Tensor input) {
const dim3 block(block_x, block_y);
const dim3 grid(feature_size);
// shared memory used for reduce on mean, var, num_elements;
auto stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "welford_mean_var_kernel", ([&] {
......@@ -366,23 +837,21 @@ std::vector<at::Tensor> welford_mean_var_CUDA(const at::Tensor input) {
welford_kernel<scalar_t, accscalar_t, accscalar_t><<<grid, block, 0, stream>>>(
input.data<scalar_t>(),
out_mean.data<accscalar_t>(),
out_var.data<accscalar_t>(),
out_var_biased.data<accscalar_t>(),
batch_size,
feature_size,
space_size);
}));
return {out_mean, out_var, out_var_biased};
return {out_mean, out_var_biased};
}
at::Tensor batchnorm_forward_CUDA(
const at::Tensor input,
const at::Tensor mean,
const at::Tensor var,
const at::Tensor inv_std,
const at::Tensor weight,
const at::Tensor shift,
const float eps) {
const at::Tensor shift) {
const auto batch_size = input.size(0);
const auto feature_size = input.size(1);
at::Tensor out = at::empty_like(input);
......@@ -403,13 +872,12 @@ at::Tensor batchnorm_forward_CUDA(
batchnorm_forward_kernel<scalar_t, accscalar_t, accscalar_t><<<grid, block, 0, stream>>>(
input.data<scalar_t>(),
mean.data<accscalar_t>(),
var.data<accscalar_t>(),
inv_std.data<accscalar_t>(),
weight.data<accscalar_t>(),
shift.data<accscalar_t>(),
out.data<scalar_t>(),
space_size,
batch_size,
eps);
batch_size);
}));
} else {
AT_CHECK(input.type().scalarType() == weight.type().scalarType(), "input.type().scalarType() is not supported with weight.type().scalarType()");
......@@ -418,13 +886,12 @@ at::Tensor batchnorm_forward_CUDA(
batchnorm_forward_kernel<scalar_t, accscalar_t, scalar_t><<<grid, block, 0, stream>>>(
input.data<scalar_t>(),
mean.data<accscalar_t>(),
var.data<accscalar_t>(),
inv_std.data<accscalar_t>(),
weight.data<scalar_t>(),
shift.data<scalar_t>(),
out.data<scalar_t>(),
space_size,
batch_size,
eps);
batch_size);
}));
}
return out;
......@@ -434,9 +901,8 @@ std::vector<at::Tensor> reduce_bn_CUDA(
const at::Tensor grad_output,
const at::Tensor input,
const at::Tensor mean,
const at::Tensor var,
const at::Tensor weight,
const float eps)
const at::Tensor inv_std,
const at::Tensor weight)
{
const auto batch_size = input.size(0);
const auto feature_size = input.size(1);
......@@ -463,15 +929,14 @@ std::vector<at::Tensor> reduce_bn_CUDA(
input.data<scalar_t>(),
grad_output.data<scalar_t>(),
mean.data<accscalar_t>(),
var.data<accscalar_t>(),
inv_std.data<accscalar_t>(),
mean_dy.data<accscalar_t>(),
mean_dy_xmu.data<accscalar_t>(),
grad_weight.data<accscalar_t>(),
grad_bias.data<accscalar_t>(),
batch_size,
feature_size,
space_size,
eps);
space_size);
}));
} else {
AT_CHECK(input.type().scalarType() == weight.type().scalarType(), "input.type().scalarType() is not supported with weight.type().scalarType()");
......@@ -481,15 +946,14 @@ std::vector<at::Tensor> reduce_bn_CUDA(
input.data<scalar_t>(),
grad_output.data<scalar_t>(),
mean.data<accscalar_t>(),
var.data<accscalar_t>(),
inv_std.data<accscalar_t>(),
mean_dy.data<accscalar_t>(),
mean_dy_xmu.data<accscalar_t>(),
grad_weight.data<scalar_t>(),
grad_bias.data<scalar_t>(),
batch_size,
feature_size,
space_size,
eps);
space_size);
}));
}
......@@ -500,11 +964,10 @@ at::Tensor batchnorm_backward_CUDA(
const at::Tensor grad_output,
const at::Tensor input,
const at::Tensor mean,
const at::Tensor var,
const at::Tensor inv_std,
const at::Tensor weight,
const at::Tensor mean_dy,
const at::Tensor mean_dy_xmu,
const float eps) {
const at::Tensor mean_dy_xmu) {
const auto batch_size = input.size(0);
const auto feature_size = input.size(1);
......@@ -528,14 +991,13 @@ at::Tensor batchnorm_backward_CUDA(
grad_output.data<scalar_t>(),
input.data<scalar_t>(),
mean.data<accscalar_t>(),
var.data<accscalar_t>(),
inv_std.data<accscalar_t>(),
weight.data<accscalar_t>(),
mean_dy.data<accscalar_t>(),
mean_dy_xmu.data<accscalar_t>(),
grad_input.data<scalar_t>(),
space_size,
batch_size,
eps);
batch_size);
}));
} else {
AT_CHECK(input.type().scalarType() == weight.type().scalarType(), "input.type().scalarType() is not supported with weight.type().scalarType()");
......@@ -545,46 +1007,273 @@ at::Tensor batchnorm_backward_CUDA(
grad_output.data<scalar_t>(),
input.data<scalar_t>(),
mean.data<accscalar_t>(),
var.data<accscalar_t>(),
inv_std.data<accscalar_t>(),
weight.data<scalar_t>(),
mean_dy.data<accscalar_t>(),
mean_dy_xmu.data<accscalar_t>(),
grad_input.data<scalar_t>(),
space_size,
batch_size,
eps);
batch_size);
}));
}
return grad_input;
}
std::vector<at::Tensor> welford_parallel_CUDA(const at::Tensor mean_feature_nodes, const at::Tensor var_biased, int numel) {
const auto feature_size = mean_feature_nodes.size(0);
const auto world_size = mean_feature_nodes.size(1);
std::vector<at::Tensor> welford_parallel_CUDA(const at::Tensor mean_feature_nodes,
const at::Tensor var_biased,
int numel,
const float eps) {
const auto world_size = mean_feature_nodes.size(0);
const auto feature_size = mean_feature_nodes.size(1);
at::Tensor out_var = at::empty({feature_size}, var_biased.options());
at::Tensor out_var_biased = at::empty_like(out_var);
at::Tensor inv_std = at::empty_like(out_var);
at::Tensor out_mean = at::empty_like(out_var);
// TODO(jie): tile this for memory coalescing!
const dim3 block(world_size);
const dim3 grid(feature_size);
// shared memory used for reduce on mean, var, num_elements;
const int block = std::min(h_last_pow2(feature_size), MAX_BLOCK_SIZE);
const int grid = std::max<int>(1, feature_size / block);
auto stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(mean_feature_nodes.type(), "welford_parallel_kernel", ([&] {
using accscalar_t = at::acc_type<scalar_t, true>;
welford_kernel_parallel<scalar_t, accscalar_t><<<grid, block, 0, stream>>>(
welford_kernel_parallel<scalar_t><<<grid, block, 0, stream>>>(
mean_feature_nodes.data<scalar_t>(),
var_biased.data<scalar_t>(),
out_mean.data<scalar_t>(),
out_var.data<scalar_t>(),
out_var_biased.data<scalar_t>(),
inv_std.data<scalar_t>(),
world_size,
feature_size,
eps,
numel);
}));
return {out_mean, out_var, out_var_biased};
return {out_mean, out_var, inv_std};
}
std::vector<at::Tensor> welford_mean_var_c_last_CUDA(const at::Tensor input) {
const auto stride = input.size(input.ndimension()-1);
const auto reduction_size = input.numel() / stride;
auto scalar_type = promote_scalartype(input);
auto option = input.options().dtype(scalar_type);
at::Tensor out_var_biased = at::empty({stride}, option);
at::Tensor out_mean = at::empty({stride}, option);
dim3 block;
dim3 grid;
flexible_launch_configs(reduction_size, stride, block, grid, true);
at::Tensor staging_data;
at::Tensor semaphores;
if (grid.y > 1) {
staging_data = at::empty({4*stride*grid.y}, option);
semaphores = at::zeros({grid.x}, input.options().dtype(at::ScalarType::Int));
}
auto stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "welford_mean_var_c_last", ([&] {
using accscalar_t = at::acc_type<scalar_t, true>;
accscalar_t* staging_data_ptr = grid.y > 1 ? staging_data.data<accscalar_t>() : nullptr;
int* semaphores_ptr = grid.y > 1 ? semaphores.data<int>() : nullptr;
welford_kernel_c_last<scalar_t, accscalar_t, accscalar_t, ELEMENTS_PER_ITER>
<<<grid, block, 0, stream>>>(
input.data<scalar_t>(),
out_mean.data<accscalar_t>(),
out_var_biased.data<accscalar_t>(),
staging_data_ptr,
semaphores_ptr,
reduction_size,
stride);
}));
return {out_mean, out_var_biased};
}
at::Tensor batchnorm_forward_c_last_CUDA(
const at::Tensor input,
const at::Tensor mean,
const at::Tensor inv_std,
const at::Tensor weight,
const at::Tensor shift) {
const auto stride = input.size(input.ndimension()-1);
const auto reduction_size = input.numel() / stride;
at::Tensor out = at::empty_like(input);
dim3 block;
dim3 grid;
flexible_launch_configs(reduction_size, stride, block, grid);
auto stream = at::cuda::getCurrentCUDAStream();
if (input.type().scalarType() == at::ScalarType::Half
&& weight.type().scalarType() == at::ScalarType::Float) {
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_forward", ([&] {
using accscalar_t = at::acc_type<scalar_t, true>;
batchnorm_forward_c_last_kernel<scalar_t, accscalar_t, accscalar_t, ELEMENTS_PER_ITER>
<<<grid, block, 0, stream>>>(
input.data<scalar_t>(),
mean.data<accscalar_t>(),
inv_std.data<accscalar_t>(),
weight.data<accscalar_t>(),
shift.data<accscalar_t>(),
out.data<scalar_t>(),
reduction_size,
stride);
}));
} else {
AT_CHECK(input.type().scalarType() == weight.type().scalarType(),
"input.type().scalarType() is not supported with weight.type().scalarType()");
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_forward", ([&] {
using accscalar_t = at::acc_type<scalar_t, true>;
batchnorm_forward_c_last_kernel<scalar_t, accscalar_t, scalar_t, ELEMENTS_PER_ITER>
<<<grid, block, 0, stream>>>(
input.data<scalar_t>(),
mean.data<accscalar_t>(),
inv_std.data<accscalar_t>(),
weight.data<scalar_t>(),
shift.data<scalar_t>(),
out.data<scalar_t>(),
reduction_size,
stride);
}));
}
return out;
}
std::vector<at::Tensor> reduce_bn_c_last_CUDA(
const at::Tensor grad_output,
const at::Tensor input,
const at::Tensor mean,
const at::Tensor inv_std,
const at::Tensor weight) {
const auto stride = input.size(input.ndimension()-1);
const auto reduction_size = input.numel() / stride;
at::Tensor mean_dy = at::empty({stride}, mean.options());
at::Tensor mean_dy_xmu = at::empty({stride}, mean.options());
at::Tensor grad_weight = at::empty({stride}, weight.options());
at::Tensor grad_bias = at::empty({stride}, weight.options());
dim3 block;
dim3 grid;
flexible_launch_configs(reduction_size, stride, block, grid, true);
at::Tensor staging_data;
at::Tensor semaphores;
if (grid.y > 1) {
staging_data = at::empty({2*stride*grid.y}, mean.options());
semaphores = at::zeros({grid.x}, input.options().dtype(at::ScalarType::Int));
}
auto stream = at::cuda::getCurrentCUDAStream();
if (input.type().scalarType() == at::ScalarType::Half && weight.type().scalarType() == at::ScalarType::Float) {
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_backward_reduce", ([&] {
using accscalar_t = at::acc_type<scalar_t, true>;
accscalar_t* staging_data_ptr = grid.y > 1 ? staging_data.data<accscalar_t>() : nullptr;
int* semaphores_ptr = grid.y > 1 ? semaphores.data<int>() : nullptr;
reduce_bn_c_last_kernel<scalar_t, accscalar_t, accscalar_t, ELEMENTS_PER_ITER>
<<<grid, block, 0, stream>>>(
input.data<scalar_t>(),
grad_output.data<scalar_t>(),
mean.data<accscalar_t>(),
inv_std.data<accscalar_t>(),
mean_dy.data<accscalar_t>(),
mean_dy_xmu.data<accscalar_t>(),
grad_weight.data<accscalar_t>(),
grad_bias.data<accscalar_t>(),
staging_data_ptr,
semaphores_ptr,
reduction_size,
stride);
}));
} else {
AT_CHECK(input.type().scalarType() == weight.type().scalarType(), "input.type().scalarType() is not supported with weight.type().scalarType()");
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_backward_reduce", ([&] {
using accscalar_t = at::acc_type<scalar_t, true>;
accscalar_t* staging_data_ptr = grid.y > 1 ? staging_data.data<accscalar_t>() : nullptr;
int* semaphores_ptr = grid.y > 1 ? semaphores.data<int>() : nullptr;
reduce_bn_c_last_kernel<scalar_t, accscalar_t, scalar_t, ELEMENTS_PER_ITER>
<<<grid, block, 0, stream>>>(
input.data<scalar_t>(),
grad_output.data<scalar_t>(),
mean.data<accscalar_t>(),
inv_std.data<accscalar_t>(),
mean_dy.data<accscalar_t>(),
mean_dy_xmu.data<accscalar_t>(),
grad_weight.data<scalar_t>(),
grad_bias.data<scalar_t>(),
staging_data_ptr,
semaphores_ptr,
reduction_size,
stride);
}));
}
return {mean_dy, mean_dy_xmu, grad_weight, grad_bias};
}
at::Tensor batchnorm_backward_c_last_CUDA(
const at::Tensor grad_output,
const at::Tensor input,
const at::Tensor mean,
const at::Tensor inv_std,
const at::Tensor weight,
const at::Tensor mean_dy,
const at::Tensor mean_dy_xmu) {
const auto stride = input.size(input.ndimension()-1);
const auto reduction_size = input.numel() / stride;
at::Tensor grad_input = at::empty_like(input);
dim3 block;
dim3 grid;
flexible_launch_configs(reduction_size, stride, block, grid);
auto stream = at::cuda::getCurrentCUDAStream();
if (input.type().scalarType() == at::ScalarType::Half
&& weight.type().scalarType() == at::ScalarType::Float) {
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_forward", ([&] {
using accscalar_t = at::acc_type<scalar_t, true>;
batchnorm_backward_c_last_kernel<scalar_t, accscalar_t, accscalar_t, ELEMENTS_PER_ITER>
<<<grid, block, 0, stream>>>(
grad_output.data<scalar_t>(),
input.data<scalar_t>(),
mean.data<accscalar_t>(),
inv_std.data<accscalar_t>(),
weight.data<accscalar_t>(),
mean_dy.data<accscalar_t>(),
mean_dy_xmu.data<accscalar_t>(),
grad_input.data<scalar_t>(),
reduction_size,
stride);
}));
} else {
AT_CHECK(input.type().scalarType() == weight.type().scalarType(),
"input.type().scalarType() is not supported with weight.type().scalarType()");
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_forward", ([&] {
using accscalar_t = at::acc_type<scalar_t, true>;
batchnorm_backward_c_last_kernel<scalar_t, accscalar_t, scalar_t, ELEMENTS_PER_ITER>
<<<grid, block, 0, stream>>>(
grad_output.data<scalar_t>(),
input.data<scalar_t>(),
mean.data<accscalar_t>(),
inv_std.data<accscalar_t>(),
weight.data<scalar_t>(),
mean_dy.data<accscalar_t>(),
mean_dy_xmu.data<accscalar_t>(),
grad_input.data<scalar_t>(),
reduction_size,
stride);
}));
}
return grad_input;
}
# Base image must at least have pytorch and CUDA installed.
ARG BASE_IMAGE=nvcr.io/nvidia/pytorch:18.04-py3
ARG BASE_IMAGE=nvcr.io/nvidia/pytorch:18.12-py3
FROM $BASE_IMAGE
ARG BASE_IMAGE
RUN echo "Installing Apex on top of ${BASE_IMAGE}"
......
......@@ -21,8 +21,8 @@ Currently, Pytorch's default non-devel image on Dockerhub
## Option 2: Install Apex in a running container
Instead of building a new container, it is also a viable option to clone Apex on bare metal, mount the Apex repo into your container at launch by running, for example,
Instead of building a new container, it is also a viable option to `git clone https://github.com/NVIDIA/apex.git` on bare metal, mount the Apex repo into your container at launch by running, for example,
```
docker run --runtime=nvidia -it --rm --ipc=host -v /bare/metal/apex:/apex/in/container <base image>
```
then go to /apex/in/container within the running container and `python setup.py install`.
then go to /apex/in/container within the running container and `python setup.py install [--cuda_ext] [--cpp_ext]`.
......@@ -54,7 +54,11 @@ m = inp_r.mean(1)
b_v = inp_r.var(1, unbiased=False)
unb_v = inp_r.var(1, unbiased=True)
mean, var, var_biased = syncbn.welford_mean_var(inp_t)
eps = 1e-5
#mean, var, var_biased = syncbn.welford_mean_var(inp_t)
mean, var_biased = syncbn.welford_mean_var(inp_t)
inv_std = 1.0 / torch.sqrt(var_biased + eps)
bn = torch.nn.BatchNorm2d(feature_size).cuda()
bn.momentum = 1.0
......@@ -74,16 +78,25 @@ grad_sbn = grad_output_t.clone().detach()
out_sbn = sbn(inp_sbn)
out_sbn.backward(grad_sbn)
sbn_c_last = apex.parallel.SyncBatchNorm(feature_size, channel_last=True).cuda()
sbn_c_last.momentum = 1.0
sbn_c_last.weight.data = weight_t.clone()
sbn_c_last.bias.data = bias_t.clone()
inp_sbn_c_last = inp_t.clone().transpose(-1, 1).contiguous().requires_grad_()
grad_sbn_c_last = grad_output_t.clone().transpose(-1, 1).contiguous().detach()
out_sbn_c_last = sbn_c_last(inp_sbn_c_last)
out_sbn_c_last.backward(grad_sbn_c_last)
sbn_result = True
sbn_result_c_last = True
bn_result = True
sbn_result = compare("comparing mean: ", mean, m, error) and sbn_result
sbn_result = compare("comparing variance: ", var, unb_v, error) and sbn_result
#sbn_result = compare("comparing variance: ", var, unb_v, error) and sbn_result
sbn_result = compare("comparing biased variance: ", var_biased, b_v, error) and sbn_result
eps = 1e-5
out = syncbn.batchnorm_forward(inp_t, mean, var_biased, weight_t, bias_t, eps)
out = syncbn.batchnorm_forward(inp_t, mean, inv_std, weight_t, bias_t)
out_r = weight_r * (inp2_r - m.view(-1, 1, 1)) * torch.rsqrt(b_v.view(-1,1,1) + eps) + bias_r
sbn_result = compare("comparing output: ", out, out_r, error) and sbn_result
......@@ -102,8 +115,8 @@ mean_dy_xmu_r = ((inp2_r - m.view(-1, 1, 1)) * grad_output2_r).transpose(1,0).co
grad_input_r = (grad_output2_r - mean_dy_r.view(-1, 1, 1) - (inp2_r - m.view(-1, 1, 1)) / (b_v.view(-1,1,1) + eps) * mean_dy_xmu_r.view(-1, 1, 1) ) * torch.rsqrt(b_v.view(-1,1,1) + eps) * weight_r.view(-1,1,1)
mean_dy, mean_dy_xmu, grad_weight, grad_bias = syncbn.reduce_bn(grad_output_t, inp_t, mean, var_biased, weight_t, eps)
grad_input = syncbn.batchnorm_backward(grad_output_t, inp_t, mean, var_biased, weight_t, mean_dy, mean_dy_xmu, eps)
mean_dy, mean_dy_xmu, grad_weight, grad_bias = syncbn.reduce_bn(grad_output_t, inp_t, mean, inv_std, weight_t)
grad_input = syncbn.batchnorm_backward(grad_output_t, inp_t, mean, inv_std, weight_t, mean_dy, mean_dy_xmu)
sbn_result = compare("comparing bias grad: ", grad_bias, grad_bias_r, error) and sbn_result
sbn_result = compare("comparing weight grad: ", grad_weight, grad_weight_r, error) and sbn_result
sbn_result = compare("comparing mean_dy grad: ", mean_dy, mean_dy_r, error) and sbn_result
......@@ -112,7 +125,7 @@ sbn_result = compare("comparing input grad: ", grad_input, grad_input_r, error)
compare("comparing bn input grad: ", inp_bn.grad, grad_input_r, error)
sbn_result = compare("comparing sbn input grad: ", inp_sbn.grad, grad_input_r, error) and sbn_result
compare("comparing output: ", out_bn, out_sbn, error)
compare("comparing bn/sbn output: ", out_bn, out_sbn, error)
sbn_result = compare("comparing running_mean: ", bn.running_mean.data, sbn.running_mean.data, error) and sbn_result
sbn_result = compare("comparing running_variance: ", bn.running_var.data, sbn.running_var.data, error) and sbn_result
compare("comparing grad_input: ", inp_bn.grad, inp_sbn.grad, error)
......@@ -123,7 +136,21 @@ compare("comparing grad_weight: ", bn.weight.grad, sbn.weight.grad, error)
compare("comparing grad_weight bn to ref: ", bn.weight.grad, grad_weight_r, error)
sbn_result = compare("comparing grad_weight sbn to ref: ", sbn.weight.grad, grad_weight_r, error) and sbn_result
compare("comparing channel last bn/sbn output: ", out_bn, out_sbn_c_last.transpose(-1, 1).contiguous(), error)
sbn_result_c_last = compare("comparing channel last running_mean: ", bn.running_mean.data, sbn_c_last.running_mean.data, error) and sbn_result_c_last
sbn_result_c_last = compare("comparing channel last running_variance: ", bn.running_var.data, sbn_c_last.running_var.data, error) and sbn_result_c_last
compare("comparing channel last grad_input: ", inp_bn.grad, inp_sbn_c_last.grad.transpose(-1, 1).contiguous(), error)
compare("comparing channel last grad_bias: ", bn.bias.grad, sbn_c_last.bias.grad, error)
sbn_result_c_last = compare("comparing channel last grad_bias sbn to ref: ", sbn_c_last.bias.grad, grad_bias_r, error) and sbn_result_c_last
compare("comparing channel last grad_weight: ", bn.weight.grad, sbn_c_last.weight.grad, error)
sbn_result_c_last = compare("comparing channel last grad_weight sbn to ref: ", sbn_c_last.weight.grad, grad_weight_r, error) and sbn_result_c_last
if sbn_result:
print("====SBN single gpu passed tests")
else:
print("*SBN single gpu failed*")
if sbn_result_c_last:
print("====SBN channel last single gpu passed tests")
else:
print("*SBN channel last single gpu failed*")
......@@ -75,7 +75,10 @@ m = inp_r.mean(1)
b_v = inp_r.var(1, unbiased=False)
unb_v = inp_r.var(1, unbiased=True)
mean, var, var_biased = syncbn.welford_mean_var(inp_t)
eps = 1e-5
mean, var_biased = syncbn.welford_mean_var(inp_t)
inv_std = 1.0 / torch.sqrt(var_biased + eps)
bn = torch.nn.BatchNorm2d(feature_size).cuda()
bn.momentum = 1.0
......@@ -111,12 +114,9 @@ bn_result = True
if args.local_rank == 0:
sbn_result = compare("comparing mean: ", mean, m, error) and sbn_result
sbn_result = compare("comparing variance: ", var, unb_v, error) and sbn_result
sbn_result = compare("comparing biased variance: ", var_biased, b_v, error) and sbn_result
eps = 1e-5
out = syncbn.batchnorm_forward(inp_t, mean, var_biased, weight_t, bias_t, eps)
out = syncbn.batchnorm_forward(inp_t, mean, inv_std, weight_t, bias_t)
out_r = weight_r * (inp2_r - m.view(-1, 1, 1)) * torch.rsqrt(b_v.view(-1,1,1) + eps) + bias_r
if args.local_rank == 0:
......@@ -136,8 +136,8 @@ mean_dy_xmu_r = ((inp2_r - m.view(-1, 1, 1)) * grad_output2_r).transpose(1,0).co
grad_input_r = (grad_output2_r - mean_dy_r.view(-1, 1, 1) - (inp2_r - m.view(-1, 1, 1)) / (b_v.view(-1,1,1) + eps) * mean_dy_xmu_r.view(-1, 1, 1) ) * torch.rsqrt(b_v.view(-1,1,1) + eps) * weight_r.view(-1,1,1)
mean_dy, mean_dy_xmu, grad_weight, grad_bias = syncbn.reduce_bn(grad_output_t, inp_t, mean, var_biased, weight_t, eps)
grad_input = syncbn.batchnorm_backward(grad_output_t, inp_t, mean, var_biased, weight_t, mean_dy, mean_dy_xmu, eps)
mean_dy, mean_dy_xmu, grad_weight, grad_bias = syncbn.reduce_bn(grad_output_t, inp_t, mean, inv_std, weight_t)
grad_input = syncbn.batchnorm_backward(grad_output_t, inp_t, mean, inv_std, weight_t, mean_dy, mean_dy_xmu)
if args.local_rank == 0:
sbn_result = compare("comparing bias grad: ", grad_bias, grad_bias_r, error) and sbn_result
sbn_result = compare("comparing weight grad: ", grad_weight, grad_weight_r, error) and sbn_result
......