Commit 81eef1ef authored by jjsjann123, committed by mcarilli

[syncBN] (#48)

* [syncBN]
  added syncBN as a native pure-Python apex implementation
  added fused CUDA kernels for sync BN, using Welford for mean/var
    optional installation using 'python setup.py install --cuda_ext'
  added unit test with a side-by-side comparison between apex sync BN and
    PyTorch BN. Note that the PyTorch BN output will be slightly off due to
    numerical issues in its mean/var computation.

* [syncBN PR]
  added fp16 support
  addressed review comments:
    1. updated last pow 2
    2. catch ImportError when importing the syncBN kernel

* [syncBN PR]
  added convert function to insert SyncBatchNorm
  refactored some kernel code

* fixed type issues (fp16/fp32/fp64)
added Kahan summation
edited unit test to use PyTorch primitive ops in double precision; it now passes reasonable tests

* updating tensor creation calls

* fixing all_reduce to operate on contiguous tensors

* transposed all reduce results

* [syncBN]
support fp16 input & fp32 layer for apex fp16
partially fixing launch configs
enabling imagenet example to run with --sync_bn

* [syncBN PR]
Documentation added

* adjusting README

* adjusting again

* added some doc to imagenet example

* [syncBN]
  warp-level reduction
  bug fix: updated warp reduction logic; check for dummy elements to avoid NaN.
  improved launch config for better reduction kernels. A further improvement
would be to increase the grid size.

* [syncBN]
  fixing undefined behavior in __shfl_down_sync caused by divergent threads in
warp reduction.
  changing at::native::empty to at::empty (per upstream review comments)
parent e12c1ec3
......@@ -55,6 +55,18 @@ optimized for NVIDIA's NCCL communication library.
The [Imagenet with FP16_Optimizer](https://github.com/NVIDIA/apex/tree/master/examples/imagenet)
mixed precision examples also demonstrate `apex.parallel.DistributedDataParallel`.
### Synchronized Batch Normalization
`apex.parallel.SyncBatchNorm` extends `torch.nn.modules.batchnorm._BatchNorm` to
support synchronized BN.
It reduces stats across processes during multiprocess distributed data parallel
training.
Synchronous Batch Normalization has been used in cases where only a very small
mini-batch fits on each GPU.
All-reduced stats boost the effective batch size of the sync BN layer to the
total mini-batch size summed across all processes.
It has improved the converged accuracy in some of our research models.
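A minimal sketch of enabling it in a script that already uses `apex.parallel.DistributedDataParallel` (assuming the process group and `model` are set up as in the Imagenet example):
```
import apex
from apex.parallel import DistributedDataParallel as DDP

# replace every torch.nn.BatchNorm*N*d in model with apex.parallel.SyncBatchNorm,
# then wrap the model for multiprocess data parallel training
model = apex.parallel.convert_syncbn_model(model).cuda()
model = DDP(model)
```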
# Requirements
Python 3
......@@ -81,8 +93,18 @@ To use the extension
import apex
```
### CUDA/C++ extension
To build Apex with the CUDA/C++ extension, follow the Linux instructions with
the `--cuda_ext` option enabled:
```
python setup.py install --cuda_ext
```
The CUDA/C++ extension provides custom synchronized Batch Normalization kernels
that offer better performance and numerical accuracy.
### Windows support
Windows support is experimental, and Linux is recommended. However, since Apex is Python-only, there's a good chance it "just works" the same way as Linux. If you installed Pytorch in a Conda environment, make sure to install Apex in that same environment.
Windows support is experimental, and Linux is recommended. However, since Apex can be installed Python-only, there's a good chance the Python-only features "just work" the same way as on Linux. If you installed Pytorch in a Conda environment, make sure to install Apex in that same environment.
<!--
reparametrization and RNN API under construction
......@@ -91,5 +113,3 @@ Current version of apex contains:
3. Reparameterization function that allows you to recursively apply reparameterization to an entire module (including children modules).
4. An experimental and in development flexible RNN API.
-->
## Distributed Data Parallel
distributed.py contains the source code for `apex.parallel.DistributedDataParallel`, a module wrapper that enables multi-process multi-GPU data parallel training optimized for NVIDIA's NCCL communication library.
`apex.parallel.DistributedDataParallel` achieves high performance by overlapping communication with
......@@ -14,4 +16,51 @@ multiproc.py contains the source code for `apex.parallel.multiproc`, a launch ut
#### [Simple example with FP16_Optimizer](https://github.com/NVIDIA/apex/tree/master/examples/FP16_Optimizer_simple/distributed_apex)
### Synchronized Batch Normalization
`apex.parallel.SyncBatchNorm` has an API similar to `torch.nn.BatchNorm*N*d`.
It computes stats per channel (reducing over all other dimensions) and accepts
inputs with arbitrary spatial dimensions.
#### Installation
Apex provides two sync BN implementations:
1. The Python-only implementation, which is the default when installing with
`python setup.py install`.
It uses PyTorch primitive operations and the distributed communication package
from `torch.distributed`.
- _The Python-only implementation requires the input tensor to have the same
data type as the layer._
2. A kernel implementation with improved performance, provided through the
CUDA/C++ extension. We are experimenting with Welford's algorithm and Kahan
summation for the reductions, hoping to get better accuracy.
To use the kernel implementation, users need to install Apex with the CUDA
extension enabled: `python setup.py install --cuda_ext`.
- _The custom kernel implementation supports fp16 input with an fp32 layer, as
cuDNN does. This is required to run the imagenet example in fp16._
- _Currently the kernel implementation only supports GPU._
A quick way to check which implementation is active is sketched below.
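A minimal check (mirroring the import logic in `apex/parallel/__init__.py`): if the CUDA extension was built, the `syncbn` module is importable and the fused path is used.
```
try:
    import syncbn  # built only with 'python setup.py install --cuda_ext'
    print("fused syncBN kernels available")
except ImportError:
    print("falling back to the Python-only sync BN implementation")
```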
#### HowTo
1. Users can use `apex.parallel.SyncBatchNorm` by building their module with
the layer explicitly.
```
import apex
input_t = torch.randn(3, 5, 20).cuda()
sbn = apex.parallel.SyncBatchNorm(5).cuda()
output_t = sbn(input_t)
```
2. Users can also take a constructed `torch.nn.Module` and replace all its `torch.nn.BatchNorm*N*d` modules with `apex.parallel.SyncBatchNorm` through the utility function `apex.parallel.convert_syncbn_model`.
```
# model is an instance of torch.nn.Module
import apex
sync_bn_model = apex.parallel.convert_syncbn_model(model)
```
import torch
from .distributed import DistributedDataParallel, Reducer
try:
import syncbn
print("using fused syncBN")
from .optimized_sync_batchnorm import SyncBatchNorm
except ImportError:
print("using non-fused syncBN, try install apex with 'python setup.py install --cuda_ext' to enable fused syncBN for better performance")
from .sync_batchnorm import SyncBatchNorm
def convert_syncbn_model(module):
'''
Recursively traverse module and its children to replace all
`torch.nn.modules.batchnorm._BatchNorm` with `apex.parallel.SyncBatchNorm`
All `torch.nn.BatchNorm*N*d` wrap around
`torch.nn.modules.batchnorm._BatchNorm`, so this function lets you easily switch
to sync BN.
Args:
module: input module `torch.nn.Module`
Examples::
>>> # model is an instance of torch.nn.Module
>>> import apex
>>> sync_bn_model = apex.parallel.convert_syncbn_model(model)
'''
mod = module
if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
mod = SyncBatchNorm(module.num_features, module.eps, module.momentum, module.affine, module.track_running_stats)
mod.running_mean = module.running_mean
mod.running_var = module.running_var
if module.affine:
mod.weight.data = module.weight.data.clone().detach()
mod.bias.data = module.bias.data.clone().detach()
for name, child in module.named_children():
mod.add_module(name, convert_syncbn_model(child))
# TODO(jie) should I delete model explicitly?
del module
return mod
import torch
from torch.nn.modules.batchnorm import _BatchNorm
from torch.nn import functional as F
from .optimized_sync_batchnorm_kernel import SyncBatchnormFunction
class SyncBatchNorm(_BatchNorm):
"""
synchronized batch normalization module extended from `torch.nn.BatchNormNd`
with the added stats reduction across multiple processes.
:class:`apex.parallel.SyncBatchNorm` is designed to work with
`DistributedDataParallel`.
When running in training mode, the layer reduces stats across all processes
to increase the effective batch size of the normalization layer. This is useful
in applications where the per-process batch size is small, which would otherwise
diminish the converged accuracy of the model. The layer uses the collective
communication package from `torch.distributed`.
When running in evaluation mode, the layer falls back to
`torch.nn.functional.batch_norm`
Args:
num_features: :math:`C` from an expected input of size
:math:`(N, C, L)` or :math:`L` from input of size :math:`(N, L)`
eps: a value added to the denominator for numerical stability.
Default: 1e-5
momentum: the value used for the running_mean and running_var
computation. Can be set to ``None`` for cumulative moving average
(i.e. simple average). Default: 0.1
affine: a boolean value that when set to ``True``, this module has
learnable affine parameters. Default: ``True``
track_running_stats: a boolean value that when set to ``True``, this
module tracks the running mean and variance, and when set to ``False``,
this module does not track such statistics and always uses batch
statistics in both training and eval modes. Default: ``True``
Examples::
>>> sbn = apex.parallel.SyncBatchNorm(100).cuda()
>>> inp = torch.randn(10, 100, 14, 14).cuda()
>>> out = sbn(inp)
>>> inp = torch.randn(3, 100, 20).cuda()
>>> out = sbn(inp)
"""
def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True):
super(SyncBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine, track_running_stats=track_running_stats)
def forward(self, input):
if not self.training and self.track_running_stats:
# fall back to pytorch implementation for inference
return F.batch_norm(input, self.running_mean, self.running_var, self.weight, self.bias, False, 0.0, self.eps)
else:
self.num_batches_tracked += 1
return SyncBatchnormFunction.apply(input, self.weight, self.bias, self.running_mean, self.running_var, self.eps, self.track_running_stats, self.momentum)
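A small single-process usage sketch (illustrative, not part of this commit) of the behavior described in the docstring above: batch stats are used in training mode, and the layer falls back to `F.batch_norm` with the running stats in eval mode. With one process and no `torch.distributed` initialization, no cross-process reduction actually happens.
```
import torch
import apex

sbn = apex.parallel.SyncBatchNorm(100).cuda()
inp = torch.randn(10, 100, 14, 14).cuda()

out_train = sbn(inp)   # training mode: stats computed from the batch
                       # (and all-reduced when torch.distributed is initialized)
sbn.eval()
out_eval = sbn(inp)    # eval mode: falls back to F.batch_norm with running stats
```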
import torch
from torch.autograd.function import Function
import syncbn
class SyncBatchnormFunction(Function):
@staticmethod
def forward(ctx, input, weight, bias, running_mean, running_variance, eps, track_running_stats = True, momentum = 1.0):
torch.cuda.nvtx.range_push("sync_BN_fw")
input = input.contiguous()
if track_running_stats:
mean, var, var_biased = syncbn.welford_mean_var(input)
if torch.distributed.is_initialized():
world_size = torch.distributed.get_world_size()
mean_all = torch.empty(world_size, mean.size(0), dtype=mean.dtype, device=mean.device)
var_all = torch.empty(world_size, var.size(0), dtype=var.dtype, device=var.device)
mean_l = [mean_all.narrow(0, i, 1) for i in range(world_size)]
var_l = [var_all.narrow(0, i, 1) for i in range(world_size)]
torch.distributed.all_gather(mean_l, mean)
torch.distributed.all_gather(var_l, var_biased)
mean, var, var_biased = syncbn.welford_parallel(mean_all.transpose(1,0).contiguous(), var_all.transpose(1,0).contiguous(), int(input.numel()/input.size(1)))
# TODO(Jie): should do fp32 math instead!
r_m_inc = mean if running_mean.dtype != torch.float16 else mean.half()
r_v_inc = var if running_variance.dtype != torch.float16 else var.half()
running_mean.data = running_mean.data * (1-momentum) + momentum*r_m_inc
running_variance.data = running_variance.data * (1-momentum) + momentum*r_v_inc
else:
mean = running_mean.data
var_biased = running_variance.data
ctx.save_for_backward(input, weight, mean, var_biased)
ctx.eps = eps
out = syncbn.batchnorm_forward(input, mean, var_biased, weight, bias, eps)
torch.cuda.nvtx.range_pop()
return out
@staticmethod
def backward(ctx, grad_output):
grad_output = grad_output.contiguous()
torch.cuda.nvtx.range_push("sync_BN_bw")
# mini batch mean & var are calculated by forward path.
# mu = 1./N*np.sum(h, axis = 0)
# var = 1./N*np.sum((h-mu)**2, axis = 0)
saved_input, weight, running_mean, running_variance = ctx.saved_tensors
eps = ctx.eps
grad_input = grad_weight = grad_bias = None
# TODO(jie): why do I have to clone here? life time of grad_output?
mean_dy, mean_dy_xmu, grad_weight, grad_bias = syncbn.reduce_bn(grad_output, saved_input, running_mean, running_variance, weight, eps)
# calculate grad_input
if ctx.needs_input_grad[0]:
if torch.distributed.is_initialized():
torch.distributed.all_reduce(
mean_dy, op=torch.distributed.reduce_op.SUM)
mean_dy = mean_dy / torch.distributed.get_world_size()
torch.distributed.all_reduce(
mean_dy_xmu, op=torch.distributed.reduce_op.SUM)
mean_dy_xmu = mean_dy_xmu / torch.distributed.get_world_size()
grad_input = syncbn.batchnorm_backward(grad_output, saved_input, running_mean, running_variance, weight, mean_dy, mean_dy_xmu, eps)
if weight is None or not ctx.needs_input_grad[1]:
grad_weight = None
if weight is None or not ctx.needs_input_grad[2]:
grad_bias = None
torch.cuda.nvtx.range_pop()
return grad_input, grad_weight, grad_bias, None, None, None, None, None
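For reference, a pure-Python sketch (illustrative only, not part of this commit) of the reduction that `syncbn.welford_parallel` performs on the gathered per-process statistics in the forward pass above. Here `mean_all` and `var_biased_all` are laid out `[world_size, C]` (the kernel receives the transpose) and `n` is the per-process element count per channel, `input.numel() / input.size(1)`.
```
import torch

def merge_welford(mean_all, var_biased_all, n):
    # mean_all, var_biased_all: [world_size, C]; every process contributes n elements
    world_size = mean_all.size(0)
    total = world_size * n
    mean = mean_all.mean(0)
    # combine per-process second moments (Chan et al. style merge); each process's
    # m2n is var_biased * n, plus a correction for the shift of its local mean
    m2n = (var_biased_all * n).sum(0) + ((mean_all - mean) ** 2 * n).sum(0)
    return mean, m2n / (total - 1), m2n / total   # mean, unbiased var, biased var
```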
import torch
from torch.nn.modules.batchnorm import _BatchNorm
from torch.nn import functional as F
from .sync_batchnorm_kernel import SyncBatchnormFunction
class SyncBatchNorm(_BatchNorm):
"""
synchronized batch normalization module extended from `torch.nn.BatchNormNd`
with the added stats reduction across multiple processes.
:class:`apex.parallel.SyncBatchNorm` is designed to work with
`DistributedDataParallel`.
When running in training mode, the layer reduces stats across all processes
to increase the effective batch size of the normalization layer. This is useful
in applications where the per-process batch size is small, which would otherwise
diminish the converged accuracy of the model. The layer uses the collective
communication package from `torch.distributed`.
When running in evaluation mode, the layer falls back to
`torch.nn.functional.batch_norm`
Args:
num_features: :math:`C` from an expected input of size
:math:`(N, C, L)` or :math:`L` from input of size :math:`(N, L)`
eps: a value added to the denominator for numerical stability.
Default: 1e-5
momentum: the value used for the running_mean and running_var
computation. Can be set to ``None`` for cumulative moving average
(i.e. simple average). Default: 0.1
affine: a boolean value that when set to ``True``, this module has
learnable affine parameters. Default: ``True``
track_running_stats: a boolean value that when set to ``True``, this
module tracks the running mean and variance, and when set to ``False``,
this module does not track such statistics and always uses batch
statistics in both training and eval modes. Default: ``True``
Examples::
>>> sbn = apex.parallel.SyncBatchNorm(100).cuda()
>>> inp = torch.randn(10, 100, 14, 14).cuda()
>>> out = sbn(inp)
>>> inp = torch.randn(3, 100, 20).cuda()
>>> out = sbn(inp)
"""
def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True):
super(SyncBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine, track_running_stats=track_running_stats)
def forward(self, input):
torch.cuda.nvtx.range_push("sync_bn_fw_with_mean_var")
mean = None
var = None
if not self.training and self.track_running_stats:
# fall back to pytorch implementation for inference
torch.cuda.nvtx.range_pop()
return F.batch_norm(input, self.running_mean, self.running_var, self.weight, self.bias, False, 0.0, self.eps)
else:
self.num_batches_tracked += 1
with torch.no_grad():
channel_first_input = input.transpose(0, 1).contiguous()
squashed_input_tensor_view = channel_first_input.view(
channel_first_input.size(0), -1)
# total number of data points for each variance entry. Used to calculate unbiased variance estimate
m = None
local_m = float(squashed_input_tensor_view.size()[1])
local_mean = torch.mean(squashed_input_tensor_view, 1)
local_sqr_mean = torch.pow(
squashed_input_tensor_view, 2).mean(1)
if torch.distributed.is_initialized():
torch.distributed.all_reduce(
local_mean, op=torch.distributed.reduce_op.SUM)
mean = local_mean / torch.distributed.get_world_size()
torch.distributed.all_reduce(
local_sqr_mean, op=torch.distributed.reduce_op.SUM)
sqr_mean = local_sqr_mean / torch.distributed.get_world_size()
m = local_m * torch.distributed.get_world_size()
else:
m = local_m
mean = local_mean
sqr_mean = local_sqr_mean
# var(x) = E (( x - mean_x ) ** 2)
# = 1 / N * sum ( x - mean_x ) ** 2
# = 1 / N * sum (x**2) - mean_x**2
var = sqr_mean - mean.pow(2)
if self.running_mean is not None:
self.running_mean = self.momentum * mean + \
(1 - self.momentum) * self.running_mean
if self.running_var is not None:
# as noted by the paper, we used unbiased variance estimate of the mini-batch
# Var[x] = m / (m-1) * Eb (sample_variance)
self.running_var = m / \
(m-1) * self.momentum * var + \
(1 - self.momentum) * self.running_var
torch.cuda.nvtx.range_pop()
return SyncBatchnormFunction.apply(input, self.weight, self.bias, mean, var, self.eps)
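A small sanity-check sketch (single process, fp64 to sidestep the numerical issues this identity is prone to) of the `E[x^2] - E[x]^2` variance computation used in `forward` above:
```
import torch

x = torch.randn(8, 5, 7, 7, dtype=torch.float64)
c_first = x.transpose(0, 1).contiguous().view(5, -1)   # channel-first, squashed

mean = c_first.mean(1)
sqr_mean = c_first.pow(2).mean(1)
var = sqr_mean - mean.pow(2)                           # E[x^2] - E[x]^2

print(torch.allclose(var, c_first.var(1, unbiased=False)))  # expect True
```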
import torch
from torch.autograd.function import Function
class SyncBatchnormFunction(Function):
@staticmethod
def forward(ctx, input, weight, bias, running_mean, running_variance, eps):
torch.cuda.nvtx.range_push("sync_BN_fw")
# transpose it to channel last to support broadcasting for input with different rank
c_last_input = input.transpose(1, -1).contiguous().clone()
ctx.save_for_backward(c_last_input, weight, bias,
running_mean, running_variance)
ctx.eps = eps
c_last_input = (c_last_input - running_mean) / \
torch.sqrt(running_variance + eps)
if weight is not None:
c_last_input = c_last_input * weight
if bias is not None:
c_last_input = c_last_input + bias
torch.cuda.nvtx.range_pop()
return c_last_input.transpose(1, -1).contiguous().clone()
@staticmethod
def backward(ctx, grad_output):
torch.cuda.nvtx.range_push("sync_BN_bw")
# mini batch mean & var are calculated by forward path.
# mu = 1./N*np.sum(h, axis = 0)
# var = 1./N*np.sum((h-mu)**2, axis = 0)
c_last_input, weight, bias, running_mean, running_variance = ctx.saved_tensors
eps = ctx.eps
grad_input = grad_weight = grad_bias = None
num_features = running_mean.size()[0]
# transpose it to channel last to support broadcasting for input with different rank
torch.cuda.nvtx.range_push("carilli field")
c_last_grad = grad_output.transpose(1, -1).contiguous()
# squash non-channel dimension so we can easily calculate mean
c_grad = c_last_grad.view(-1, num_features).contiguous()
torch.cuda.nvtx.range_pop()
# calculate grad_input
if ctx.needs_input_grad[0]:
# dh = gamma * (var + eps)**(-1. / 2.) * (dy - np.mean(dy, axis=0)
# - (h - mu) * (var + eps)**(-1.0) * np.mean(dy * (h - mu), axis=0))
mean_dy = c_grad.mean(0)
mean_dy_xmu = (c_last_grad * (c_last_input -
running_mean)).view(-1, num_features).mean(0)
if torch.distributed.is_initialized():
torch.distributed.all_reduce(
mean_dy, op=torch.distributed.reduce_op.SUM)
mean_dy = mean_dy / torch.distributed.get_world_size()
torch.distributed.all_reduce(
mean_dy_xmu, op=torch.distributed.reduce_op.SUM)
mean_dy_xmu = mean_dy_xmu / torch.distributed.get_world_size()
c_last_grad_input = (c_last_grad - mean_dy - (c_last_input - running_mean) / (
running_variance + eps) * mean_dy_xmu) / torch.sqrt(running_variance + eps)
if weight is not None:
c_last_grad_input.mul_(weight)
grad_input = c_last_grad_input.transpose(1, -1).contiguous()
# calculate grad_weight
grad_weight = None
if weight is not None and ctx.needs_input_grad[1]:
# dgamma = np.sum((h - mu) * (var + eps)**(-1. / 2.) * dy, axis=0)
grad_weight = ((c_last_input - running_mean) / torch.sqrt(
running_variance + eps) * c_last_grad).view(-1, num_features).sum(0)
# calculate grad_bias
grad_bias = None
if bias is not None and ctx.needs_input_grad[2]:
# dbeta = np.sum(dy, axis=0)
grad_bias = c_grad.sum(0)
torch.cuda.nvtx.range_pop()
return grad_input, grad_weight, grad_bias, None, None, None
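As an illustration (not part of this commit), the `dh` formula in the comments above can be cross-checked against autograd for ordinary single-process batch norm:
```
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(4, 3, 5, 5, dtype=torch.float64, requires_grad=True)
weight = torch.randn(3, dtype=torch.float64, requires_grad=True)
bias = torch.randn(3, dtype=torch.float64, requires_grad=True)
dy = torch.randn(4, 3, 5, 5, dtype=torch.float64)
eps = 1e-5

out = F.batch_norm(x, None, None, weight, bias, training=True, eps=eps)
out.backward(dy)

# manual grad_input following the dh formula in the comments
c_x = x.detach().transpose(1, -1)                      # channel-last
c_dy = dy.transpose(1, -1)
mu = c_x.reshape(-1, 3).mean(0)
var = c_x.reshape(-1, 3).var(0, unbiased=False)
mean_dy = c_dy.reshape(-1, 3).mean(0)
mean_dy_xmu = (c_dy * (c_x - mu)).reshape(-1, 3).mean(0)
dx = (c_dy - mean_dy - (c_x - mu) / (var + eps) * mean_dy_xmu) \
     / torch.sqrt(var + eps) * weight.detach()
print(torch.allclose(x.grad, dx.transpose(1, -1)))     # expect True
```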
#include <torch/torch.h>
#include <ATen/ATen.h>
#include <vector>
// returns {mean,unbiased_var,biased_var}
// implemented using welford
std::vector<at::Tensor> welford_mean_var_CUDA(const at::Tensor input);
// reduces array of mean/var across processes
// returns global {mean,unbiased_var,biased_var}
// implemented using welford
std::vector<at::Tensor> welford_parallel_CUDA(const at::Tensor mean_feature_nodes, const at::Tensor var_biased_feature_nodes, int numel);
// elementwise BN operation, returns output
// input/weight/shift should have identical data type;
// mean/var have promoted data type (dtype==fp16?fp32:dtype)
at::Tensor batchnorm_forward_CUDA(const at::Tensor input,
const at::Tensor mean,
const at::Tensor var,
const at::Tensor weight,
const at::Tensor shift,
const float eps);
// backward BN operation, returns {mean_dy, mean_dy_xmu, grad_weight, grad_bias}
// grad_output/input should have identical data type;
// mean/var have promoted data type (dtype==fp16?fp32:dtype)
// implemented using kahan summation
std::vector<at::Tensor> reduce_bn_CUDA(const at::Tensor grad_output,
const at::Tensor input,
const at::Tensor mean,
const at::Tensor var,
const at::Tensor weight,
const float eps);
// elementwise backward BN operation, returns grad_input
// grad_output/input/weight precision could be fp16/fp32;
// mean/var/mean_dy/mean_dy_xmu precision is fp32
at::Tensor batchnorm_backward_CUDA(const at::Tensor grad_output,
const at::Tensor input,
const at::Tensor mean,
const at::Tensor var,
const at::Tensor weight,
const at::Tensor mean_dy,
const at::Tensor mean_dy_xmu,
const float eps);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("welford_mean_var", &welford_mean_var_CUDA, "welford mean variance");
m.def("welford_parallel", &welford_parallel_CUDA, "welford parallel reduce mean variance");
m.def("batchnorm_forward", &batchnorm_forward_CUDA, "batchnorm forward");
m.def("reduce_bn", &reduce_bn_CUDA, "batchnorm backward reduce grad sum and bias/weight gradient");
m.def("batchnorm_backward", &batchnorm_backward_CUDA, "batchnorm backward dgrad");
}
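The comment on `reduce_bn_CUDA` above notes that the reduction uses Kahan summation. A minimal Python sketch of that compensation scheme (illustrative only; the actual implementation is the CUDA kernel in welford.cu below):
```
def kahan_sum(values):
    total = 0.0
    c = 0.0                      # running compensation for lost low-order bits
    for v in values:
        y = v - c
        t = total + y
        c = (t - total) - y      # recover the part of y that was rounded away
        total = t
    return total

# the same pattern appears per thread in reduce_bn_kernel for sum_dy / sum_dy_xmu
```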
#include <iostream>
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <vector>
__device__ __forceinline__ int lastpow2(int n)
{
int out = 1 << (31 - __clz(n));
if(n == out)
out >>= 1;
return out;
}
__host__ __forceinline__ int h_next_pow2(unsigned int n) {
n |= (n >> 1);
n |= (n >> 2);
n |= (n >> 4);
n |= (n >> 8);
n |= (n >> 16);
return n + 1;
}
__host__ __forceinline__ int h_last_pow2(unsigned int n) {
n |= (n >> 1);
n |= (n >> 2);
n |= (n >> 4);
n |= (n >> 8);
n |= (n >> 16);
return n - (n >> 1);
}
#define WARP_SIZE 32
template<typename T>
__device__ __forceinline__ T warp_reduce_sum(T val)
{
#pragma unroll
for(int i = WARP_SIZE/2; i > 0; i >>= 1)
val = val + __shfl_down_sync(0xffffffff, val, i);
return val;
}
template<typename T>
__device__ __forceinline__ T reduce_block(T *x, T val)
{
int tid = threadIdx.y*blockDim.x + threadIdx.x;
int blockSize = blockDim.x * blockDim.y;
if (blockSize > 32) {
val = warp_reduce_sum(val);
if (tid % WARP_SIZE == 0)
x[tid/WARP_SIZE] = val;
__syncthreads();
val = (tid < blockSize / WARP_SIZE? x[tid%WARP_SIZE] : T(0));
}
if(tid/WARP_SIZE==0) val = warp_reduce_sum(val);
return val;
}
#define TILE_W 32
#define MAX_BLOCK_SIZE 256
template<typename T>
__device__ __forceinline__ void warp_reduce_mean_m2n(T &mean, T &m2n, int &num)
{
#pragma unroll
for(int i = WARP_SIZE/2; i > 0; i >>= 1) {
auto num_new = __shfl_down_sync(0xffffffff, num, i);
auto mean_new = __shfl_down_sync(0xffffffff, mean, i);
auto m2n_new = __shfl_down_sync(0xffffffff, m2n, i);
if (num_new != 0) {
auto dif_mean = mean - mean_new;
mean = (mean_new * num_new + mean * num) / (num + num_new);
m2n += m2n_new + dif_mean*dif_mean*num*num_new/(num_new+num);
num += num_new;
}
}
}
template <typename T>
__device__ void welford_reduce_mean_m2n(
T* __restrict__ x,
int* __restrict__ count,
T &mean,
T &m2n,
int &num,
int block_size,
int thread_id)
{
int lane = thread_id % WARP_SIZE;
int wid = thread_id / WARP_SIZE;
if (block_size > 32) {
warp_reduce_mean_m2n(mean, m2n, num);
if (lane == 0) {
x[wid*2] = mean;
x[wid*2+1] = m2n;
count[wid] = num;
}
__syncthreads();
if (wid == 0) {
mean = (thread_id < block_size / WARP_SIZE)? x[lane*2] : T(0);
m2n = (thread_id < block_size / WARP_SIZE)? x[lane*2+1] : T(0);
num = (thread_id < block_size / WARP_SIZE)? count[lane] : int(0);
}
}
if (wid==0) warp_reduce_mean_m2n(mean, m2n, num);
return;
}
// return spatial size for NC+ Tensors
__host__ int get_tensor_spatial_size(const at::Tensor& input)
{
auto space_size = input.size(2);
for (int i = 3; i < input.ndimension(); i++) {
space_size *= input.size(i);
}
return space_size;
}
// promote accumulation scalar type. promote half to float.
__host__ at::ScalarType promote_scalartype(const at::Tensor& input)
{
return input.type().scalarType() == at::ScalarType::Half ?
at::ScalarType::Float : input.type().scalarType();
}
// return single element size, optional accumulation type promotion.
__host__ size_t get_element_data_size(const at::Tensor& input, bool accumulation = false)
{
auto scalar_type = accumulation ? promote_scalartype(input) : input.type().scalarType();
return at::elementSize(scalar_type);
}
// welford kernel calculating mean/biased_variance/unbiased_variance
template <typename scalar_t, typename accscalar_t, typename outscalar_t>
__global__ void welford_kernel(
const scalar_t* __restrict__ input,
outscalar_t* __restrict__ out_mean,
outscalar_t* __restrict__ out_var,
outscalar_t* __restrict__ out_var_biased,
const int bs,
const int fs,
const int ss) {
static __shared__ int s_mem[160];
int block_size = blockDim.x * blockDim.y;
accscalar_t* s_mem_ac = (accscalar_t*) &s_mem[32];
int count = 0;
accscalar_t x_mean = accscalar_t(0);
accscalar_t m_2_n = accscalar_t(0);
int thread_id = threadIdx.y*blockDim.x + threadIdx.x;
for (int batch_id = threadIdx.y; batch_id < bs; batch_id += blockDim.y) {
int input_base = blockIdx.x*ss + batch_id*ss*fs;
// sequential welford
for (int offset = threadIdx.x; offset < ss ; offset += blockDim.x) {
count++;
auto x_n = static_cast<accscalar_t>(input[offset+input_base]);
auto x_mean_new = x_mean + (x_n - x_mean) / count;
m_2_n = m_2_n + (x_n - x_mean_new) * (x_n - x_mean);
x_mean = x_mean_new;
}
}
welford_reduce_mean_m2n<accscalar_t>(s_mem_ac, s_mem, x_mean, m_2_n, count, block_size, thread_id);
if (thread_id == 0) {
out_mean[blockIdx.x] = static_cast<outscalar_t>(x_mean);
out_var[blockIdx.x] = static_cast<outscalar_t>(m_2_n/(count-1));
out_var_biased[blockIdx.x] = static_cast<outscalar_t>(m_2_n/count);
}
}
// elementwise BN kernel
template <typename scalar_t, typename accscalar_t, typename layerscalar_t>
__global__ void batchnorm_forward_kernel(
const scalar_t* __restrict__ input,
const accscalar_t* __restrict__ mean,
const accscalar_t* __restrict__ var,
const layerscalar_t* __restrict__ weight,
const layerscalar_t* __restrict__ shift,
scalar_t* __restrict__ out,
const int ss,
const float eps) {
int address_base = blockIdx.x*ss + blockIdx.y*gridDim.x*ss;
auto m_c = mean[blockIdx.x];
auto inv_std_c = static_cast<accscalar_t>(rsqrt(var[blockIdx.x] + eps));
auto w_c = static_cast<accscalar_t>(weight[blockIdx.x]);
auto s_c = static_cast<accscalar_t>(shift[blockIdx.x]);
for (int offset = threadIdx.x; offset < ss ; offset+= blockDim.x) {
out[address_base+offset] = static_cast<scalar_t>(w_c * (static_cast<accscalar_t>(input[address_base+offset]) - m_c ) * inv_std_c + s_c);
}
}
// Backward BN kernel: calculates grad_bias, grad_weight as well as intermediate
// results for calculating grad_input.
// grad_input is broken into two steps to support sync BN, which requires an
// all-reduce of the intermediate results across processes.
template <typename scalar_t, typename accscalar_t, typename layerscalar_t>
__global__ void reduce_bn_kernel(
const scalar_t* __restrict__ input,
const scalar_t* __restrict__ grad_output,
const accscalar_t* __restrict__ mean,
const accscalar_t* __restrict__ var,
accscalar_t* __restrict__ mean_dy,
accscalar_t* __restrict__ mean_dy_xmu,
layerscalar_t* __restrict__ grad_weight,
layerscalar_t* __restrict__ grad_bias,
const int bs,
const int fs,
const int ss,
const float eps) {
static __shared__ int s_mem[64];
int total_item_num = bs * ss;
int thread_id = threadIdx.y*blockDim.x + threadIdx.x;
auto r_mean = mean[blockIdx.x];
auto factor = accscalar_t(1.0) / (accscalar_t)sqrt(var[blockIdx.x] + eps);
// Kahan sum
accscalar_t sum_dy = 0.0;
accscalar_t sum_dy_xmu = 0.0;
accscalar_t sum_dy_c = 0.0;
accscalar_t sum_dy_xmu_c = 0.0;
for (int batch_id = threadIdx.y; batch_id < bs; batch_id += blockDim.y) {
int input_base = blockIdx.x*ss + batch_id*ss*fs;
for (int offset = threadIdx.x; offset < ss ; offset += blockDim.x) {
auto e_grad = static_cast<accscalar_t>(grad_output[offset+input_base]);
auto e_input = static_cast<accscalar_t>(input[offset+input_base]);
// calculating sum_dy
auto sum_dy_y = e_grad - sum_dy_c;
auto sum_dy_t = sum_dy + sum_dy_y;
sum_dy_c = (sum_dy_t - sum_dy) - sum_dy_y;
sum_dy = sum_dy_t;
// calculating sum_dy_xmu
auto sum_dy_xmu_y = e_grad * (e_input - r_mean) - sum_dy_xmu_c;
auto sum_dy_xmu_t = sum_dy_xmu + sum_dy_xmu_y;
sum_dy_xmu_c = (sum_dy_xmu_t - sum_dy_xmu) - sum_dy_xmu_y;
sum_dy_xmu = sum_dy_xmu_t;
}
}
sum_dy = reduce_block((accscalar_t*)s_mem, sum_dy);
__syncthreads();
sum_dy_xmu = reduce_block((accscalar_t*)s_mem, sum_dy_xmu);
if (thread_id == 0) {
grad_bias[blockIdx.x] = static_cast<layerscalar_t>(sum_dy);
grad_weight[blockIdx.x] = static_cast<layerscalar_t>(sum_dy_xmu * factor);
mean_dy[blockIdx.x] = sum_dy / total_item_num;
mean_dy_xmu[blockIdx.x] = sum_dy_xmu / total_item_num;
}
}
// elementwise backward BN kernel
template <typename scalar_t, typename accscalar_t, typename layerscalar_t>
__global__ void batchnorm_backward_kernel(
const scalar_t* __restrict__ grad_output,
const scalar_t* __restrict__ input,
const accscalar_t* __restrict__ mean,
const accscalar_t* __restrict__ var,
const layerscalar_t* __restrict__ weight,
const accscalar_t* __restrict__ mean_dy,
const accscalar_t* __restrict__ mean_dy_xmu,
scalar_t* __restrict__ grad_input,
const int ss,
const float eps) {
int address_base = blockIdx.x*ss + blockIdx.y*gridDim.x*ss;
auto m_c = static_cast<accscalar_t>(mean[blockIdx.x]);
auto m_dy_c = static_cast<accscalar_t>(mean_dy[blockIdx.x]);
auto factor_1_c = static_cast<accscalar_t>(var[blockIdx.x]) + eps;
auto factor_2_c = static_cast<accscalar_t>(weight[blockIdx.x]) / sqrt(factor_1_c);
factor_1_c /= static_cast<accscalar_t>(mean_dy_xmu[blockIdx.x]);
for (int offset = threadIdx.x; offset < ss ; offset+= blockDim.x) {
grad_input[address_base+offset] = (static_cast<accscalar_t>(grad_output[address_base+offset]) - m_dy_c - (static_cast<accscalar_t>(input[address_base+offset]) - m_c) / factor_1_c) * factor_2_c;
}
}
// parallel welford kernel to further reduce mean / biased_var / unbiased_var
// across multiple processes.
template <typename scalar_t, typename accscalar_t>
__global__ void welford_kernel_parallel(
const scalar_t* __restrict__ mean,
const scalar_t* __restrict__ var_biased,
scalar_t* __restrict__ out_mean,
scalar_t* __restrict__ out_var,
scalar_t* __restrict__ out_var_biased,
const int ns,
const int fs,
const int numel) {
static __shared__ int s_mem[160];
int block_size = blockDim.x;
accscalar_t* s_mem_ac = (accscalar_t*) &s_mem[32];
int input_base = blockIdx.x*ns + threadIdx.x;
int thread_id = threadIdx.x;
// load data;
auto x_mean = static_cast<accscalar_t>(mean[input_base]);
auto m_2_n = static_cast<accscalar_t>(var_biased[input_base]) * numel;
auto count = numel;
__syncthreads();
welford_reduce_mean_m2n<accscalar_t>(s_mem_ac, s_mem, x_mean, m_2_n, count, block_size, thread_id);
if (thread_id == 0) {
out_mean[blockIdx.x] = static_cast<scalar_t>(x_mean);
out_var[blockIdx.x] = static_cast<scalar_t>(m_2_n/(count-1));
out_var_biased[blockIdx.x] = static_cast<scalar_t>(m_2_n/count);
}
}
std::vector<at::Tensor> welford_mean_var_CUDA(const at::Tensor input) {
const auto batch_size = input.size(0);
const auto feature_size = input.size(1);
auto space_size = get_tensor_spatial_size(input);
auto scalar_type = promote_scalartype(input);
at::Tensor out_var = at::empty({feature_size}, input.options().dtype(scalar_type));
at::Tensor out_var_biased = at::empty({feature_size}, input.options().dtype(scalar_type));
at::Tensor out_mean = at::empty({feature_size}, input.options().dtype(scalar_type));
int block_x = TILE_W;
int block_y = min(h_last_pow2(batch_size), int(MAX_BLOCK_SIZE / block_x));
const dim3 block(block_x, block_y);
const dim3 grid(feature_size);
// shared memory used for reduce on mean, var, num_elements;
auto stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "welford_mean_var_kernel", ([&] {
using accscalar_t = at::acc_type<scalar_t, true>;
welford_kernel<scalar_t, accscalar_t, accscalar_t><<<grid, block, 0, stream>>>(
input.data<scalar_t>(),
out_mean.data<accscalar_t>(),
out_var.data<accscalar_t>(),
out_var_biased.data<accscalar_t>(),
batch_size,
feature_size,
space_size);
}));
return {out_mean, out_var, out_var_biased};
}
at::Tensor batchnorm_forward_CUDA(
const at::Tensor input,
const at::Tensor mean,
const at::Tensor var,
const at::Tensor weight,
const at::Tensor shift,
const float eps) {
const auto batch_size = input.size(0);
const auto feature_size = input.size(1);
at::Tensor out = at::empty_like(input);
auto space_size = get_tensor_spatial_size(input);
int block = min(MAX_BLOCK_SIZE, h_next_pow2(space_size)/4);
// TODO(jie): should I do 1 block per feature?
const dim3 grid(feature_size, batch_size);
auto stream = at::cuda::getCurrentCUDAStream();
if (input.type().scalarType() == at::ScalarType::Half && weight.type().scalarType() == at::ScalarType::Float) {
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_forward", ([&] {
using accscalar_t = at::acc_type<scalar_t, true>;
batchnorm_forward_kernel<scalar_t, accscalar_t, accscalar_t><<<grid, block, 0, stream>>>(
input.data<scalar_t>(),
mean.data<accscalar_t>(),
var.data<accscalar_t>(),
weight.data<accscalar_t>(),
shift.data<accscalar_t>(),
out.data<scalar_t>(),
space_size,
eps);
}));
} else {
AT_CHECK(input.type().scalarType() == weight.type().scalarType(), "input.type().scalarType() is not supported with weight.type().scalarType()");
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_forward", ([&] {
using accscalar_t = at::acc_type<scalar_t, true>;
batchnorm_forward_kernel<scalar_t, accscalar_t, scalar_t><<<grid, block, 0, stream>>>(
input.data<scalar_t>(),
mean.data<accscalar_t>(),
var.data<accscalar_t>(),
weight.data<scalar_t>(),
shift.data<scalar_t>(),
out.data<scalar_t>(),
space_size,
eps);
}));
}
return out;
}
std::vector<at::Tensor> reduce_bn_CUDA(
const at::Tensor grad_output,
const at::Tensor input,
const at::Tensor mean,
const at::Tensor var,
const at::Tensor weight,
const float eps)
{
const auto batch_size = input.size(0);
const auto feature_size = input.size(1);
auto scalar_type = promote_scalartype(input);
at::Tensor mean_dy = at::empty({feature_size}, mean.options());
at::Tensor mean_dy_xmu = at::empty({feature_size}, mean.options());
at::Tensor grad_weight = at::empty({feature_size}, weight.options());
at::Tensor grad_bias = at::empty({feature_size}, weight.options());
auto space_size = get_tensor_spatial_size(input);
int block_x = TILE_W;
int block_y = min(h_last_pow2(batch_size), int(MAX_BLOCK_SIZE / block_x));
const dim3 block(block_x, block_y);
const dim3 grid(feature_size);
// shared memory used for reduce on sum_dy, sum_dy_xmu;
auto stream = at::cuda::getCurrentCUDAStream();
if (input.type().scalarType() == at::ScalarType::Half && weight.type().scalarType() == at::ScalarType::Float) {
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_backward_reduce", ([&] {
using accscalar_t = at::acc_type<scalar_t, true>;
reduce_bn_kernel<scalar_t, accscalar_t, accscalar_t><<<grid, block, 0, stream>>>(
input.data<scalar_t>(),
grad_output.data<scalar_t>(),
mean.data<accscalar_t>(),
var.data<accscalar_t>(),
mean_dy.data<accscalar_t>(),
mean_dy_xmu.data<accscalar_t>(),
grad_weight.data<accscalar_t>(),
grad_bias.data<accscalar_t>(),
batch_size,
feature_size,
space_size,
eps);
}));
} else {
AT_CHECK(input.type().scalarType() == weight.type().scalarType(), "input.type().scalarType() is not supported with weight.type().scalarType()");
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_backward_reduce", ([&] {
using accscalar_t = at::acc_type<scalar_t, true>;
reduce_bn_kernel<scalar_t, accscalar_t, scalar_t><<<grid, block, 0, stream>>>(
input.data<scalar_t>(),
grad_output.data<scalar_t>(),
mean.data<accscalar_t>(),
var.data<accscalar_t>(),
mean_dy.data<accscalar_t>(),
mean_dy_xmu.data<accscalar_t>(),
grad_weight.data<scalar_t>(),
grad_bias.data<scalar_t>(),
batch_size,
feature_size,
space_size,
eps);
}));
}
return {mean_dy, mean_dy_xmu, grad_weight, grad_bias};
}
at::Tensor batchnorm_backward_CUDA(
const at::Tensor grad_output,
const at::Tensor input,
const at::Tensor mean,
const at::Tensor var,
const at::Tensor weight,
const at::Tensor mean_dy,
const at::Tensor mean_dy_xmu,
const float eps) {
const auto batch_size = input.size(0);
const auto feature_size = input.size(1);
at::Tensor grad_input = at::empty_like(input);
auto space_size = get_tensor_spatial_size(input);
int block = min(MAX_BLOCK_SIZE, h_next_pow2(space_size)/4);
// TODO(jie): should I do 1 block per feature?
const dim3 grid(feature_size, batch_size);
auto stream = at::cuda::getCurrentCUDAStream();
if (input.type().scalarType() == at::ScalarType::Half && weight.type().scalarType() == at::ScalarType::Float) {
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_backward", ([&] {
using accscalar_t = at::acc_type<scalar_t, true>;
batchnorm_backward_kernel<scalar_t, accscalar_t, accscalar_t><<<grid, block, 0, stream>>>(
grad_output.data<scalar_t>(),
input.data<scalar_t>(),
mean.data<accscalar_t>(),
var.data<accscalar_t>(),
weight.data<accscalar_t>(),
mean_dy.data<accscalar_t>(),
mean_dy_xmu.data<accscalar_t>(),
grad_input.data<scalar_t>(),
space_size,
eps);
}));
} else {
AT_CHECK(input.type().scalarType() == weight.type().scalarType(), "input.type().scalarType() is not supported with weight.type().scalarType()");
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_backward", ([&] {
using accscalar_t = at::acc_type<scalar_t, true>;
batchnorm_backward_kernel<scalar_t, accscalar_t, scalar_t><<<grid, block, 0, stream>>>(
grad_output.data<scalar_t>(),
input.data<scalar_t>(),
mean.data<accscalar_t>(),
var.data<accscalar_t>(),
weight.data<scalar_t>(),
mean_dy.data<accscalar_t>(),
mean_dy_xmu.data<accscalar_t>(),
grad_input.data<scalar_t>(),
space_size,
eps);
}));
}
return grad_input;
}
std::vector<at::Tensor> welford_parallel_CUDA(const at::Tensor mean_feature_nodes, const at::Tensor var_biased, int numel) {
const auto feature_size = mean_feature_nodes.size(0);
const auto world_size = mean_feature_nodes.size(1);
at::Tensor out_var = at::empty({feature_size}, var_biased.options());
at::Tensor out_var_biased = at::empty_like(out_var);
at::Tensor out_mean = at::empty_like(out_var);
// TODO(jie): tile this for memory coalescing!
const dim3 block(world_size);
const dim3 grid(feature_size);
// shared memory used for reduce on mean, var, num_elements;
auto stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(mean_feature_nodes.type(), "welford_parallel_kernel", ([&] {
using accscalar_t = at::acc_type<scalar_t, true>;
welford_kernel_parallel<scalar_t, accscalar_t><<<grid, block, 0, stream>>>(
mean_feature_nodes.data<scalar_t>(),
var_biased.data<scalar_t>(),
out_mean.data<scalar_t>(),
out_var.data<scalar_t>(),
out_var_biased.data<scalar_t>(),
world_size,
feature_size,
numel);
}));
return {out_mean, out_var, out_var_biased};
}
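For readers unfamiliar with the algorithm, a short Python sketch (illustrative, not part of the extension) of the sequential Welford update that each thread performs in `welford_kernel` before the warp/block merge:
```
def welford(values):
    count = 0
    mean = 0.0
    m2n = 0.0                        # sum of squared deviations from the running mean
    for x in values:
        count += 1
        new_mean = mean + (x - mean) / count
        m2n += (x - new_mean) * (x - mean)
        mean = new_mean
    return mean, m2n / (count - 1), m2n / count   # mean, unbiased var, biased var
```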
......@@ -15,3 +15,11 @@ apex.parallel
.. autoclass:: Reducer
:members:
.. autoclass:: SyncBatchNorm
:members:
Utility functions
----------------------------------
.. autofunction:: convert_syncbn_model
......@@ -40,6 +40,9 @@ python -m torch.distributed.launch --nproc_per_node=NUM_GPUS main.py args...
```
`NUM_GPUS` should be less than or equal to the number of visible GPU devices on the node.
Optionally, one can run imagenet with synchronized batch normalization by adding
`--sync_bn` to `args...`.
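For example (the launch command above with the new flag added):
```
python -m torch.distributed.launch --nproc_per_node=NUM_GPUS main.py --sync_bn args...
```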
## Example commands
(note: batch size `--b 256` assumes your GPUs have >=16GB of onboard memory)
......
......@@ -67,6 +67,8 @@ parser.add_argument('--prof', dest='prof', action='store_true',
help='Only run 10 iterations for profiling.')
parser.add_argument("--local_rank", default=0, type=int)
parser.add_argument('--sync_bn', action='store_true',
help='enabling apex sync BN.')
cudnn.benchmark = True
......@@ -118,6 +120,11 @@ def main():
print("=> creating model '{}'".format(args.arch))
model = models.__dict__[args.arch]()
if args.sync_bn:
import apex
print("using apex synced BN")
model = apex.parallel.convert_syncbn_model(model)
model = model.cuda()
if args.fp16:
model = network_to_half(model)
......
......@@ -70,6 +70,8 @@ parser.add_argument('--prof', dest='prof', action='store_true',
help='Only run 10 iterations for profiling.')
parser.add_argument("--local_rank", default=0, type=int)
parser.add_argument('--sync_bn', action='store_true',
help='enabling apex sync BN.')
cudnn.benchmark = True
......@@ -123,6 +125,10 @@ def main():
else:
print("=> creating model '{}'".format(args.arch))
model = models.__dict__[args.arch]()
if args.sync_bn:
import apex
print("using apex synced BN")
model = apex.parallel.convert_syncbn_model(model)
model = model.cuda()
if args.fp16:
......
......@@ -67,6 +67,8 @@ parser.add_argument('--prof', dest='prof', action='store_true',
help='Only run 10 iterations for profiling.')
parser.add_argument("--local_rank", default=0, type=int)
parser.add_argument('--sync_bn', action='store_true',
help='enabling apex sync BN.')
cudnn.benchmark = True
......@@ -117,6 +119,11 @@ def main():
print("=> creating model '{}'".format(args.arch))
model = models.__dict__[args.arch]()
if args.sync_bn:
import apex
print("using apex synced BN")
model = apex.parallel.convert_syncbn_model(model)
model = model.cuda()
if args.fp16:
model = network_to_half(model)
......
import torch
from setuptools import setup, find_packages
import sys
if not torch.cuda.is_available():
print("Warning: Torch did not find available GPUs on this system.\n",
"If your intention is to cross-compile, this is not an error.")
......@@ -13,6 +15,19 @@ if TORCH_MAJOR == 0 and TORCH_MINOR < 4:
raise RuntimeError("APEx requires Pytorch 0.4 or newer.\n" +
"The latest stable release can be obtained from https://pytorch.org/")
cmdclass = {}
ext_modules = []
if "--cuda_ext" in sys.argv:
from torch.utils.cpp_extension import CUDAExtension, BuildExtension
sys.argv.remove("--cuda_ext")
cmdclass['build_ext'] = BuildExtension
ext_modules.append(CUDAExtension('syncbn',[
'csrc/syncbn.cpp',
'csrc/welford.cu'
]))
setup(
name='apex',
version='0.1',
......@@ -26,4 +41,6 @@ setup(
'examples',
'apex.egg-info',)),
description='PyTorch Extensions written by NVIDIA',
ext_modules=ext_modules,
cmdclass=cmdclass,
)
import torch
import numpy as np
import apex
if True:
print("using setup tools")
import syncbn
else:
print("using jit")
from torch.utils.cpp_extension import load
syncbn = load(name='syncbn', sources=['../../csrc/syncbn.cpp', '../../csrc/welford.cu'])
def compare(desc, inp1, inp2, error):
a = inp1.clone().detach().cpu().numpy()
b = inp2.clone().detach().cpu().numpy()
close = np.allclose(a,b, error, error)
if not close:
print(desc, close)
z = a - b
index = (np.abs(z) >= error + error * np.abs(b)).nonzero()
print("dif : ", z[index])
print("inp1 : ", a[index])
print("inp2 : ", b[index])
return close
feature_size = 10
space_size = 16
batch_size = 5
error = 1e-5
np.random.seed(1)
dtype = np.float32
inp = (np.random.randn(batch_size, feature_size, space_size, space_size)).astype(dtype)
grad = (np.random.randn(batch_size, feature_size, space_size, space_size)).astype(dtype)
weight = (np.random.randn(feature_size)).astype(dtype)
bias = (np.random.randn(feature_size)).astype(dtype)
type_tensor = torch.cuda.FloatTensor
ref_tensor = torch.cuda.DoubleTensor
inp_t = type_tensor(inp)
weight_t = type_tensor(weight)
bias_t = type_tensor(bias)
inp_r = ref_tensor(inp.transpose(1, 0, 2, 3).reshape(feature_size, -1))
inp2_r = ref_tensor(inp)
weight_r = ref_tensor(weight).view(-1, 1, 1)
bias_r = ref_tensor(bias).view(-1, 1, 1)
grad_output_t = type_tensor(grad)
m = inp_r.mean(1)
b_v = inp_r.var(1, unbiased=False)
unb_v = inp_r.var(1, unbiased=True)
mean, var, var_biased = syncbn.welford_mean_var(inp_t)
bn = torch.nn.BatchNorm2d(feature_size).cuda()
bn.momentum = 1.0
bn.weight.data = weight_t.clone()
bn.bias.data = bias_t.clone()
inp_bn = inp_t.clone().requires_grad_()
grad_bn = grad_output_t.clone().detach()
out_bn = bn(inp_bn)
out_bn.backward(grad_bn)
sbn = apex.parallel.SyncBatchNorm(feature_size).cuda()
sbn.momentum = 1.0
sbn.weight.data = weight_t.clone()
sbn.bias.data = bias_t.clone()
inp_sbn = inp_t.clone().requires_grad_()
grad_sbn = grad_output_t.clone().detach()
out_sbn = sbn(inp_sbn)
out_sbn.backward(grad_sbn)
sbn_result = True
bn_result = True
sbn_result = compare("comparing mean: ", mean, m, error) and sbn_result
sbn_result = compare("comparing variance: ", var, unb_v, error) and sbn_result
sbn_result = compare("comparing biased variance: ", var_biased, b_v, error) and sbn_result
eps = 1e-5
out = syncbn.batchnorm_forward(inp_t, mean, var_biased, weight_t, bias_t, eps)
out_r = weight_r * (inp2_r - m.view(-1, 1, 1)) * torch.rsqrt(b_v.view(-1,1,1) + eps) + bias_r
sbn_result = compare("comparing output: ", out, out_r, error) and sbn_result
compare("comparing bn output: ", out_bn, out_r, error)
grad_output_t = type_tensor(grad)
grad_output_r = ref_tensor(grad.transpose(1, 0, 2, 3).reshape(feature_size, -1))
grad_output2_r = ref_tensor(grad)
grad_bias_r = grad_output_r.sum(1)
grad_weight_r = ((inp2_r - m.view(-1, 1, 1)) * torch.rsqrt(b_v.view(-1,1,1) + eps) * grad_output2_r).transpose(1,0).contiguous().view(feature_size, -1).sum(1)
mean_dy_r = grad_output_r.mean(1)
mean_dy_xmu_r = ((inp2_r - m.view(-1, 1, 1)) * grad_output2_r).transpose(1,0).contiguous().view(feature_size, -1).mean(1)
grad_input_r = (grad_output2_r - mean_dy_r.view(-1, 1, 1) - (inp2_r - m.view(-1, 1, 1)) / (b_v.view(-1,1,1) + eps) * mean_dy_xmu_r.view(-1, 1, 1) ) * torch.rsqrt(b_v.view(-1,1,1) + eps) * weight_r.view(-1,1,1)
mean_dy, mean_dy_xmu, grad_weight, grad_bias = syncbn.reduce_bn(grad_output_t, inp_t, mean, var_biased, weight_t, eps)
grad_input = syncbn.batchnorm_backward(grad_output_t, inp_t, mean, var_biased, weight_t, mean_dy, mean_dy_xmu, eps)
sbn_result = compare("comparing bias grad: ", grad_bias, grad_bias_r, error) and sbn_result
sbn_result = compare("comparing weight grad: ", grad_weight, grad_weight_r, error) and sbn_result
sbn_result = compare("comparing mean_dy grad: ", mean_dy, mean_dy_r, error) and sbn_result
sbn_result = compare("comparing mean_dy_xmu grad: ", mean_dy_xmu, mean_dy_xmu_r, error) and sbn_result
sbn_result = compare("comparing input grad: ", grad_input, grad_input_r, error) and sbn_result
compare("comparing bn input grad: ", inp_bn.grad, grad_input_r, error)
sbn_result = compare("comparing sbn input grad: ", inp_sbn.grad, grad_input_r, error) and sbn_result
compare("comparing output: ", out_bn, out_sbn, error)
sbn_result = compare("comparing running_mean: ", bn.running_mean.data, sbn.running_mean.data, error) and sbn_result
sbn_result = compare("comparing running_variance: ", bn.running_var.data, sbn.running_var.data, error) and sbn_result
compare("comparing grad_input: ", inp_bn.grad, inp_sbn.grad, error)
compare("comparing grad_bias: ", bn.bias.grad, sbn.bias.grad, error)
compare("comparing grad_bias bn to ref: ", bn.bias.grad, grad_bias_r, error)
sbn_result = compare("comparing grad_bias sbn to ref: ", sbn.bias.grad, grad_bias_r, error) and sbn_result
compare("comparing grad_weight: ", bn.weight.grad, sbn.weight.grad, error)
compare("comparing grad_weight bn to ref: ", bn.weight.grad, grad_weight_r, error)
sbn_result = compare("comparing grad_weight sbn to ref: ", sbn.weight.grad, grad_weight_r, error) and sbn_result
if sbn_result:
print("====SBN single gpu passed tests")
else:
print("*SBN single gpu failed*")
import torch
import numpy as np
import apex
import syncbn
import os
import argparse
import torch.optim as optim
def compare(desc, inp1, inp2, error):
a = inp1.clone().detach().cpu().numpy()
b = inp2.clone().detach().cpu().numpy()
close = np.allclose(a,b, error, error)
if not close:
print(desc, close)
z = a - b
index = (np.abs(z) >= error + error * np.abs(b)).nonzero()
print("dif : ", z[index])
print("inp1 : ", a[index])
print("inp2 : ", b[index])
return close
feature_size = 10
space_size = 40
batch_size = 32
from apex.parallel import DistributedDataParallel as DDP
parser = argparse.ArgumentParser()
parser.add_argument("--local_rank", default=0, type=int)
parser.add_argument("--fp16", action='store_true', default=False)
parser.add_argument("--fp64", action='store_true', default=False)
args = parser.parse_args()
args.world_size = int(os.environ['WORLD_SIZE'])
torch.cuda.set_device(args.local_rank)
torch.distributed.init_process_group(backend='nccl', init_method='env://')
start = args.local_rank * batch_size//args.world_size
finish = (args.local_rank + 1) * batch_size//args.world_size
error = 1e-5
dtype = np.float32
if args.fp16:
error = 1e-3
dtype = np.float16
elif args.fp64:
error = 1e-8
dtype = np.float64
np.random.seed(18)
inp = np.random.randn(batch_size, feature_size, space_size, space_size).astype(dtype)
grad = np.random.randn(batch_size, feature_size, space_size, space_size).astype(dtype)
weight = np.random.randn(feature_size).astype(dtype)
bias = np.random.randn(feature_size).astype(dtype)
type_tensor = torch.cuda.FloatTensor
if args.fp16:
type_tensor = torch.cuda.HalfTensor
if args.fp64:
type_tensor = torch.cuda.DoubleTensor
ref_tensor = torch.cuda.DoubleTensor
inp_t = type_tensor(inp)
weight_t = type_tensor(weight)
bias_t = type_tensor(bias)
inp_r = ref_tensor(inp.transpose(1, 0, 2, 3).reshape(feature_size, -1))
inp2_r = ref_tensor(inp)
weight_r = ref_tensor(weight).view(-1, 1, 1)
bias_r = ref_tensor(bias).view(-1, 1, 1)
grad_output_t = type_tensor(grad)
m = inp_r.mean(1)
b_v = inp_r.var(1, unbiased=False)
unb_v = inp_r.var(1, unbiased=True)
mean, var, var_biased = syncbn.welford_mean_var(inp_t)
bn = torch.nn.BatchNorm2d(feature_size).cuda()
bn.momentum = 1.0
bn.weight.data = weight_t.clone()
bn.bias.data = bias_t.clone()
if args.fp16:
bn.half()
if args.fp64:
bn.double()
inp_bn = inp_t.clone().requires_grad_()
grad_bn = grad_output_t.clone().detach()
out_bn = bn(inp_bn)
out_bn.backward(grad_bn)
bn_opt = optim.SGD(bn.parameters(), lr=1.0)
sbn = apex.parallel.SyncBatchNorm(feature_size).cuda()
sbn.momentum = 1.0
sbn.weight.data = weight_t.clone()
sbn.bias.data = bias_t.clone()
if args.fp16:
sbn.half()
if args.fp64:
sbn.double()
sbn = DDP(sbn)
sbn_opt = optim.SGD(sbn.parameters(), lr=1.0*args.world_size)
inp_sbn = inp_t.clone().requires_grad_()
grad_sbn = grad_output_t.clone().detach()
out_sbn = sbn(inp_sbn[start:finish])
out_sbn.backward(grad_sbn[start:finish])
sbn_result = True
bn_result = True
if args.local_rank == 0:
sbn_result = compare("comparing mean: ", mean, m, error) and sbn_result
sbn_result = compare("comparing variance: ", var, unb_v, error) and sbn_result
sbn_result = compare("comparing biased variance: ", var_biased, b_v, error) and sbn_result
eps = 1e-5
out = syncbn.batchnorm_forward(inp_t, mean, var_biased, weight_t, bias_t, eps)
out_r = weight_r * (inp2_r - m.view(-1, 1, 1)) * torch.rsqrt(b_v.view(-1,1,1) + eps) + bias_r
if args.local_rank == 0:
sbn_result = compare("comparing output: ", out, out_r, error) and sbn_result
compare("comparing bn output: ", out_bn, out_r, error)
grad_output_t = type_tensor(grad)
grad_output_r = ref_tensor(grad.transpose(1, 0, 2, 3).reshape(feature_size, -1))
grad_output2_r = ref_tensor(grad)
grad_bias_r = grad_output_r.sum(1)
grad_weight_r = ((inp2_r - m.view(-1, 1, 1)) * torch.rsqrt(b_v.view(-1,1,1) + eps) * grad_output2_r).transpose(1,0).contiguous().view(feature_size, -1).sum(1)
mean_dy_r = grad_output_r.mean(1)
mean_dy_xmu_r = ((inp2_r - m.view(-1, 1, 1)) * grad_output2_r).transpose(1,0).contiguous().view(feature_size, -1).mean(1)
grad_input_r = (grad_output2_r - mean_dy_r.view(-1, 1, 1) - (inp2_r - m.view(-1, 1, 1)) / (b_v.view(-1,1,1) + eps) * mean_dy_xmu_r.view(-1, 1, 1) ) * torch.rsqrt(b_v.view(-1,1,1) + eps) * weight_r.view(-1,1,1)
mean_dy, mean_dy_xmu, grad_weight, grad_bias = syncbn.reduce_bn(grad_output_t, inp_t, mean, var_biased, weight_t, eps)
grad_input = syncbn.batchnorm_backward(grad_output_t, inp_t, mean, var_biased, weight_t, mean_dy, mean_dy_xmu, eps)
if args.local_rank == 0:
sbn_result = compare("comparing bias grad: ", grad_bias, grad_bias_r, error) and sbn_result
sbn_result = compare("comparing weight grad: ", grad_weight, grad_weight_r, error) and sbn_result
sbn_result = compare("comparing mean_dy grad: ", mean_dy, mean_dy_r, error) and sbn_result
sbn_result = compare("comparing mean_dy_xmu grad: ", mean_dy_xmu, mean_dy_xmu_r, error) and sbn_result
sbn_result = compare("comparing input grad: ", grad_input, grad_input_r, error) and sbn_result
compare("comparing bn input grad: ", inp_bn.grad, grad_input_r, error)
if args.local_rank == 0:
sbn_result = compare("comparing running_mean: ", bn.running_mean.data, sbn.module.running_mean.data, error) and sbn_result
sbn_result = compare("comparing running_variance: ", bn.running_var.data, sbn.module.running_var.data, error) and sbn_result
# execute by both
compare("comparing layers output: ", out_bn[start:finish], out_sbn, error) and sbn_result
compare("comparing layers grad_input: ", inp_bn.grad[start:finish], inp_sbn.grad[start:finish], error) and sbn_result
bn_opt.step()
sbn_opt.step()
if args.local_rank == 0:
compare("comparing bn vs sbn bias: ", bn.bias, sbn.module.bias, error)
compare("comparing bn vs ref bias: ", bn.bias, bias_r.view(-1) - grad_bias_r, error)
sbn_result = compare("comparing sbn vs ref bias: ", sbn.module.bias, bias_r.view(-1) - grad_bias_r, error) and sbn_result
compare("comparing bn vs sbn weight: ", bn.weight, sbn.module.weight, error)
compare("comparing bn vs ref weight: ", bn.weight, (weight_r.view(-1) - grad_weight_r), error)
sbn_result = compare("comparing sbn vs ref weight: ", sbn.module.weight, (weight_r.view(-1) - grad_weight_r), error) and sbn_result
if sbn_result:
print("====SBN two gpu passed tests")
else:
print("*SBN two gpu failed*")
python single_gpu_unit_test.py
python -m torch.distributed.launch --nproc_per_node=2 two_gpu_unit_test.py
python -m torch.distributed.launch --nproc_per_node=2 two_gpu_unit_test.py --fp64