Commit c7dcb0e1 authored by Michael Carilli

Merge branch 'master' into testing_cache_fix

parents c619fe6e 06e11bd3
......@@ -112,3 +112,6 @@ class NoOpHandle(object):
@property
def verbose(self):
return False
def _deactivate(self):
pass
......@@ -19,7 +19,7 @@ except ImportError:
warned_syncbn = True
from .sync_batchnorm import SyncBatchNorm
def convert_syncbn_model(module, process_group=None):
def convert_syncbn_model(module, process_group=None, channel_last=False):
'''
Recursively traverse module and its children to replace all
`torch.nn.modules.batchnorm._BatchNorm` with `apex.parallel.SyncBatchNorm`
......@@ -38,14 +38,16 @@ def convert_syncbn_model(module, process_group=None):
'''
mod = module
if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
mod = SyncBatchNorm(module.num_features, module.eps, module.momentum, module.affine, module.track_running_stats, process_group)
mod = SyncBatchNorm(module.num_features, module.eps, module.momentum, module.affine, module.track_running_stats, process_group, channel_last=channel_last)
mod.running_mean = module.running_mean
mod.running_var = module.running_var
if module.affine:
mod.weight.data = module.weight.data.clone().detach()
mod.bias.data = module.bias.data.clone().detach()
for name, child in module.named_children():
mod.add_module(name, convert_syncbn_model(child))
mod.add_module(name, convert_syncbn_model(child,
process_group=process_group,
channel_last=channel_last))
# TODO(jie) should I delete model explicitly?
del module
return mod
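
A minimal usage sketch of the updated `convert_syncbn_model` signature (the toy model below is illustrative, not part of this commit):

```python
import torch
import apex

# Toy model; any nn.Module tree containing _BatchNorm layers is handled the same way.
model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 100, kernel_size=3, padding=1),
    torch.nn.BatchNorm2d(100),
    torch.nn.ReLU(),
).cuda()

# Recursively swaps every _BatchNorm child for apex.parallel.SyncBatchNorm,
# forwarding the new process_group and channel_last arguments to each replacement.
model = apex.parallel.convert_syncbn_model(model, channel_last=False)
```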
......@@ -38,26 +38,43 @@ class SyncBatchNorm(_BatchNorm):
process_group: pass in a process group within which the stats of the
mini-batch are being synchronized. ``None`` to use the default process
group
channel_last: a boolean value that, when set to ``True``, tells this module
to treat the last dimension of the input tensor as the channel
dimension. Default: ``False``
Examples::
>>> # channel first tensor
>>> sbn = apex.parallel.SyncBatchNorm(100).cuda()
>>> inp = torch.randn(10, 100, 14, 14).cuda()
>>> out = sbn(inp)
>>> inp = torch.randn(3, 100, 20).cuda()
>>> out = sbn(inp)
>>> # channel last tensor
>>> sbn = apex.parallel.SyncBatchNorm(100, channel_last=True).cuda()
>>> inp = torch.randn(10, 14, 14, 100).cuda()
>>> out = sbn(inp)
"""
def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True, process_group=None):
def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True, process_group=None, channel_last=False):
super(SyncBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine, track_running_stats=track_running_stats)
self.process_group = process_group
self.channel_last = channel_last
def _specify_process_group(self, process_group):
self.process_group = process_group
def _specify_channel_last(self, channel_last):
self.channel_last = channel_last
def forward(self, input):
if not self.training and self.track_running_stats:
if not self.training and self.track_running_stats and not self.channel_last:
# fall back to pytorch implementation for inference
return F.batch_norm(input, self.running_mean, self.running_var, self.weight, self.bias, False, 0.0, self.eps)
else:
self.num_batches_tracked += 1
return SyncBatchnormFunction.apply(input, self.weight, self.bias, self.running_mean, self.running_var, self.eps, self.track_running_stats, self.momentum, self.process_group)
exponential_average_factor = 0.0
if self.training and self.track_running_stats:
self.num_batches_tracked += 1
if self.momentum is None:
exponential_average_factor = 1.0 / float(self.num_batches_tracked)
else:
exponential_average_factor = self.momentum
return SyncBatchnormFunction.apply(input, self.weight, self.bias, self.running_mean, self.running_var, self.eps, self.training or not self.track_running_stats, exponential_average_factor, self.process_group, self.channel_last)
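
The running-stat bookkeeping above follows `torch.nn.modules.batchnorm._BatchNorm`: with `momentum=None` the update factor becomes a cumulative average over `num_batches_tracked`, otherwise the fixed momentum is used. A small standalone sketch of that rule (plain Python, no Apex dependency):

```python
def exponential_average_factor(momentum, num_batches_tracked):
    # Cumulative moving average when momentum is None,
    # fixed exponential factor otherwise (mirrors forward() above).
    if momentum is None:
        return 1.0 / float(num_batches_tracked)
    return momentum

assert exponential_average_factor(None, 4) == 0.25   # cumulative average after 4 batches
assert exponential_average_factor(0.1, 4) == 0.1     # fixed momentum
```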
......@@ -7,26 +7,40 @@ from apex.parallel import ReduceOp
class SyncBatchnormFunction(Function):
@staticmethod
def forward(ctx, input, weight, bias, running_mean, running_variance, eps, track_running_stats = True, momentum = 1.0, process_group = None):
def forward(ctx, input, weight, bias, running_mean, running_variance, eps, track_running_stats=True, momentum=1.0, process_group=None, channel_last=False):
torch.cuda.nvtx.range_push("sync_BN_fw")
input = input.contiguous()
world_size = 0
mean = None
var_biased = None
inv_std = None
var = None
out = None
count = None
if track_running_stats:
mean, var, var_biased = syncbn.welford_mean_var(input)
if channel_last:
count = int(input.numel()/input.size(-1))
mean, var_biased = syncbn.welford_mean_var_c_last(input)
else:
count = int(input.numel()/input.size(1))
mean, var_biased = syncbn.welford_mean_var(input)
if torch.distributed.is_initialized():
if not process_group:
process_group = torch.distributed.group.WORLD
world_size = torch.distributed.get_world_size(process_group)
mean_all = torch.empty(world_size, mean.size(0), dtype=mean.dtype, device=mean.device)
var_all = torch.empty(world_size, var.size(0), dtype=var.dtype, device=var.device)
var_all = torch.empty(world_size, var_biased.size(0), dtype=var_biased.dtype, device=var_biased.device)
mean_l = [mean_all.narrow(0, i, 1) for i in range(world_size)]
var_l = [var_all.narrow(0, i, 1) for i in range(world_size)]
torch.distributed.all_gather(mean_l, mean, process_group)
torch.distributed.all_gather(var_l, var_biased, process_group)
mean, var, var_biased = syncbn.welford_parallel(mean_all.transpose(1,0).contiguous(), var_all.transpose(1,0).contiguous(), int(input.numel()/input.size(1)))
mean, var, inv_std = syncbn.welford_parallel(mean_all, var_all, count, eps)
# TODO(Jie): should do fp32 math instead!
else:
inv_std = 1.0 / torch.sqrt(var_biased + eps)
var = var_biased * (count) / (count-1)
r_m_inc = mean if running_mean.dtype != torch.float16 else mean.half()
r_v_inc = var if running_variance.dtype != torch.float16 else var.half()
......@@ -34,14 +48,17 @@ class SyncBatchnormFunction(Function):
running_variance.data = running_variance.data * (1-momentum) + momentum*r_v_inc
else:
mean = running_mean.data
var_biased = running_var.data
inv_std = 1.0 / torch.sqrt(running_variance.data + eps)
ctx.save_for_backward(input, weight, mean, var_biased)
ctx.eps = eps
ctx.save_for_backward(input, weight, mean, inv_std)
ctx.process_group = process_group
ctx.channel_last = channel_last
ctx.world_size = world_size
out = syncbn.batchnorm_forward(input, mean, var_biased, weight, bias, eps)
if channel_last:
out = syncbn.batchnorm_forward_c_last(input, mean, inv_std, weight, bias)
else:
out = syncbn.batchnorm_forward(input, mean, inv_std, weight, bias)
torch.cuda.nvtx.range_pop()
return out
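
For reference, the per-channel statistics produced by the kernels can be reproduced in plain PyTorch. This is a sketch of the single-process, channel-first path only (not the CUDA implementation, and it assumes fp32 input):

```python
import torch

eps = 1e-5
x = torch.randn(8, 4, 16, 16)                      # N, C, H, W (channel-first)
count = x.numel() // x.size(1)                     # elements per channel, as in forward()

x_flat = x.transpose(0, 1).reshape(x.size(1), -1)  # C x (N*H*W)
mean = x_flat.mean(1)
var_biased = x_flat.var(1, unbiased=False)

inv_std = 1.0 / torch.sqrt(var_biased + eps)       # saved for backward, used to normalize
var = var_biased * count / (count - 1)             # unbiased estimate for running_var

out_ref = (x - mean.view(1, -1, 1, 1)) * inv_std.view(1, -1, 1, 1)
```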
......@@ -53,14 +70,17 @@ class SyncBatchnormFunction(Function):
# mini batch mean & var are calculated by forward path.
# mu = 1./N*np.sum(h, axis = 0)
# var = 1./N*np.sum((h-mu)**2, axis = 0)
saved_input, weight, running_mean, running_variance = ctx.saved_tensors
eps = ctx.eps
saved_input, weight, mean, inv_std = ctx.saved_tensors
process_group = ctx.process_group
channel_last = ctx.channel_last
world_size = ctx.world_size
grad_input = grad_weight = grad_bias = None
# TODO(jie): why do I have to clone here? life time of grad_output?
mean_dy, mean_dy_xmu, grad_weight, grad_bias = syncbn.reduce_bn(grad_output, saved_input, running_mean, running_variance, weight, eps)
if channel_last:
mean_dy, mean_dy_xmu, grad_weight, grad_bias = syncbn.reduce_bn_c_last(grad_output, saved_input, mean, inv_std, weight)
else:
mean_dy, mean_dy_xmu, grad_weight, grad_bias = syncbn.reduce_bn(grad_output, saved_input, mean, inv_std, weight)
# calculate grad_input
if ctx.needs_input_grad[0]:
......@@ -72,7 +92,10 @@ class SyncBatchnormFunction(Function):
torch.distributed.all_reduce(
mean_dy_xmu, ReduceOp.SUM, process_group)
mean_dy_xmu = mean_dy_xmu / world_size
grad_input = syncbn.batchnorm_backward(grad_output, saved_input, running_mean, running_variance, weight, mean_dy, mean_dy_xmu, eps)
if channel_last:
grad_input = syncbn.batchnorm_backward_c_last(grad_output, saved_input, mean, inv_std, weight, mean_dy, mean_dy_xmu)
else:
grad_input = syncbn.batchnorm_backward(grad_output, saved_input, mean, inv_std, weight, mean_dy, mean_dy_xmu)
if weight is None or not ctx.needs_input_grad[1]:
grad_weight = None
......@@ -81,4 +104,4 @@ class SyncBatchnormFunction(Function):
grad_bias = None
torch.cuda.nvtx.range_pop()
return grad_input, grad_weight, grad_bias, None, None, None, None, None, None
return grad_input, grad_weight, grad_bias, None, None, None, None, None, None, None
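
The reduced quantities correspond to the reference math used in the unit tests further down. A pure-PyTorch sketch for a channel-first tensor on a single process (no all-reduce), offered as a reference under those assumptions rather than as the kernel itself:

```python
import torch

def syncbn_backward_reference(grad_output, x, mean, inv_std, weight):
    # x, grad_output: N,C,H,W; mean, inv_std, weight: C
    dims = (0, 2, 3)
    xmu = x - mean.view(1, -1, 1, 1)
    mean_dy = grad_output.mean(dims)                 # per-channel mean of dy
    mean_dy_xmu = (grad_output * xmu).mean(dims)     # per-channel mean of dy*(x-mu)
    grad_bias = grad_output.sum(dims)
    grad_weight = (grad_output * xmu).sum(dims) * inv_std
    grad_input = (grad_output
                  - mean_dy.view(1, -1, 1, 1)
                  - xmu * (inv_std ** 2).view(1, -1, 1, 1) * mean_dy_xmu.view(1, -1, 1, 1)
                  ) * inv_std.view(1, -1, 1, 1) * weight.view(1, -1, 1, 1)
    return mean_dy, mean_dy_xmu, grad_weight, grad_bias, grad_input
```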
......@@ -3,52 +3,93 @@
#include <vector>
// returns {mean,unbiased_var,biased_var}
// returns {mean,biased_var}
// implemented using welford
std::vector<at::Tensor> welford_mean_var_CUDA(const at::Tensor input);
// reduces array of mean/var across processes
// returns global {mean,unbiased_var,biased_var}
// returns global {mean, var, inv_std}
// implemented using welford
std::vector<at::Tensor> welford_parallel_CUDA(const at::Tensor mean_feature_nodes, const at::Tensor var_biased_feature_nodes, int numel);
std::vector<at::Tensor> welford_parallel_CUDA(const at::Tensor mean_feature_nodes,
const at::Tensor var_biased_feature_nodes,
int numel,
const float eps);
// elementwise BN operation, returns output
// input/weight/shift should have identical data type;
// mean/var have promoted data type (dtype==fp16?fp32:dtype)
// mean/inv_std have promoted data type (dtype==fp16?fp32:dtype)
at::Tensor batchnorm_forward_CUDA(const at::Tensor input,
const at::Tensor mean,
const at::Tensor var,
const at::Tensor inv_std,
const at::Tensor weight,
const at::Tensor shift,
const float eps);
const at::Tensor shift);
// backward BN operation, returns {mean_dy, mean_dy_xmu, grad_weight, grad_bias}
// grad_output/input should have identical data type;
// mean/var have promoted data type (dtype==fp16?fp32:dtype)
// mean/inv_std have promoted data type (dtype==fp16?fp32:dtype)
// implemented using kahan summation
std::vector<at::Tensor> reduce_bn_CUDA(const at::Tensor grad_output,
const at::Tensor input,
const at::Tensor mean,
const at::Tensor var,
const at::Tensor weight,
const float eps);
const at::Tensor inv_std,
const at::Tensor weight);
// elementwise backward BN operation, returns grad_input
// grad_output/input/weight precision could be fp16/fp32;
// mean/var/mean_dy/mean_dy_xmu precision is fp32
// mean/inv_std/mean_dy/mean_dy_xmu precision is fp32
at::Tensor batchnorm_backward_CUDA(const at::Tensor grad_output,
const at::Tensor input,
const at::Tensor mean,
const at::Tensor var,
const at::Tensor inv_std,
const at::Tensor weight,
const at::Tensor mean_dy,
const at::Tensor mean_dy_xmu,
const float eps);
const at::Tensor mean_dy_xmu);
// returns {mean, biased_var}
// implemented using welford
// expects data to be in n+c format (channel last) and applies CUDNN_BATCHNORM_SPATIAL
std::vector<at::Tensor> welford_mean_var_c_last_CUDA(const at::Tensor input);
// elementwise BN operation, returns output
// input/weight/shift should have identical data type;
// mean/inv_std have promoted data type (dtype==fp16?fp32:dtype)
// expects data to be in n+c format (channel last) and applies CUDNN_BATCHNORM_SPATIAL
at::Tensor batchnorm_forward_c_last_CUDA(const at::Tensor input,
const at::Tensor mean,
const at::Tensor inv_std,
const at::Tensor weight,
const at::Tensor shift);
// backward BN operation, returns {mean_dy, mean_dy_xmu, grad_weight, grad_bias}
// grad_output/input should have identical data type;
// mean/inv_std have promoted data type (dtype==fp16?fp32:dtype)
// expects data to be in n+c format (channel last) and applies CUDNN_BATCHNORM_SPATIAL
std::vector<at::Tensor> reduce_bn_c_last_CUDA(const at::Tensor grad_output,
const at::Tensor input,
const at::Tensor mean,
const at::Tensor inv_std,
const at::Tensor weight);
// elementwise backward BN operation, returns grad_input
// grad_output/input/weight precision could be fp16/fp32;
// mean/inv_std/mean_dy/mean_dy_xmu precision is fp32
// expects data to be in n+c format (channel last) and applies CUDNN_BATCHNORM_SPATIAL
at::Tensor batchnorm_backward_c_last_CUDA(const at::Tensor grad_output,
const at::Tensor input,
const at::Tensor mean,
const at::Tensor inv_std,
const at::Tensor weight,
const at::Tensor mean_dy,
const at::Tensor mean_dy_xmu);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("welford_mean_var", &welford_mean_var_CUDA, "welford mean variance");
m.def("welford_parallel", &welford_parallel_CUDA, "welford parallel reduce mean variance");
m.def("batchnorm_forward", &batchnorm_forward_CUDA, "batchnorm forward");
m.def("reduce_bn", &reduce_bn_CUDA, "batchnorm backward reduce grad sum and bias/weight gradient");
m.def("reduce_bn", &reduce_bn_CUDA, "batchnorm backward reduce grad sum and bias/weight grad");
m.def("batchnorm_backward", &batchnorm_backward_CUDA, "batchnorm backward dgrad");
m.def("welford_mean_var_c_last", &welford_mean_var_c_last_CUDA, "welford mean variance nhwc");
m.def("batchnorm_forward_c_last", &batchnorm_forward_c_last_CUDA, "batchnorm forward nhwc");
m.def("reduce_bn_c_last", &reduce_bn_c_last_CUDA, "batchnorm backwards reduce grad sum and bias/weight grad nhwc");
m.def("batchnorm_backward_c_last", &batchnorm_backward_c_last_CUDA, "batchnorm backward dgrad nhwc");
}
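
A sketch of how the new channel-last entry points would be called from Python, assuming the `syncbn` extension declared above has been built (e.g. via `python setup.py install --cuda_ext`) and a CUDA device is available:

```python
import torch
import syncbn  # compiled extension exposing the bindings declared above

eps = 1e-5
x = torch.randn(8, 14, 14, 100, device="cuda")   # NHWC: channel-last layout
weight = torch.ones(100, device="cuda")
bias = torch.zeros(100, device="cuda")

# eps is no longer passed to the elementwise kernels; it is folded into inv_std here.
mean, var_biased = syncbn.welford_mean_var_c_last(x)
inv_std = 1.0 / torch.sqrt(var_biased + eps)
out = syncbn.batchnorm_forward_c_last(x, mean, inv_std, weight, bias)
```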
# Base image must at least have pytorch and CUDA installed.
ARG BASE_IMAGE=nvcr.io/nvidia/pytorch:18.04-py3
ARG BASE_IMAGE=nvcr.io/nvidia/pytorch:18.12-py3
FROM $BASE_IMAGE
ARG BASE_IMAGE
RUN echo "Installing Apex on top of ${BASE_IMAGE}"
......
......@@ -21,8 +21,8 @@ Currently, Pytorch's default non-devel image on Dockerhub
## Option 2: Install Apex in a running container
Instead of building a new container, it is also a viable option to clone Apex on bare metal, mount the Apex repo into your container at launch by running, for example,
Instead of building a new container, it is also a viable option to `git clone https://github.com/NVIDIA/apex.git` on bare metal, then mount the Apex repo into your container at launch by running, for example,
```
docker run --runtime=nvidia -it --rm --ipc=host -v /bare/metal/apex:/apex/in/container <base image>
```
then go to /apex/in/container within the running container and `python setup.py install`.
then go to /apex/in/container within the running container and run `python setup.py install [--cuda_ext] [--cpp_ext]`.
......@@ -54,7 +54,11 @@ m = inp_r.mean(1)
b_v = inp_r.var(1, unbiased=False)
unb_v = inp_r.var(1, unbiased=True)
mean, var, var_biased = syncbn.welford_mean_var(inp_t)
eps = 1e-5
#mean, var, var_biased = syncbn.welford_mean_var(inp_t)
mean, var_biased = syncbn.welford_mean_var(inp_t)
inv_std = 1.0 / torch.sqrt(var_biased + eps)
bn = torch.nn.BatchNorm2d(feature_size).cuda()
bn.momentum = 1.0
......@@ -74,16 +78,25 @@ grad_sbn = grad_output_t.clone().detach()
out_sbn = sbn(inp_sbn)
out_sbn.backward(grad_sbn)
sbn_c_last = apex.parallel.SyncBatchNorm(feature_size, channel_last=True).cuda()
sbn_c_last.momentum = 1.0
sbn_c_last.weight.data = weight_t.clone()
sbn_c_last.bias.data = bias_t.clone()
inp_sbn_c_last = inp_t.clone().transpose(-1, 1).contiguous().requires_grad_()
grad_sbn_c_last = grad_output_t.clone().transpose(-1, 1).contiguous().detach()
out_sbn_c_last = sbn_c_last(inp_sbn_c_last)
out_sbn_c_last.backward(grad_sbn_c_last)
sbn_result = True
sbn_result_c_last = True
bn_result = True
sbn_result = compare("comparing mean: ", mean, m, error) and sbn_result
sbn_result = compare("comparing variance: ", var, unb_v, error) and sbn_result
#sbn_result = compare("comparing variance: ", var, unb_v, error) and sbn_result
sbn_result = compare("comparing biased variance: ", var_biased, b_v, error) and sbn_result
eps = 1e-5
out = syncbn.batchnorm_forward(inp_t, mean, var_biased, weight_t, bias_t, eps)
out = syncbn.batchnorm_forward(inp_t, mean, inv_std, weight_t, bias_t)
out_r = weight_r * (inp2_r - m.view(-1, 1, 1)) * torch.rsqrt(b_v.view(-1,1,1) + eps) + bias_r
sbn_result = compare("comparing output: ", out, out_r, error) and sbn_result
......@@ -102,8 +115,8 @@ mean_dy_xmu_r = ((inp2_r - m.view(-1, 1, 1)) * grad_output2_r).transpose(1,0).co
grad_input_r = (grad_output2_r - mean_dy_r.view(-1, 1, 1) - (inp2_r - m.view(-1, 1, 1)) / (b_v.view(-1,1,1) + eps) * mean_dy_xmu_r.view(-1, 1, 1) ) * torch.rsqrt(b_v.view(-1,1,1) + eps) * weight_r.view(-1,1,1)
mean_dy, mean_dy_xmu, grad_weight, grad_bias = syncbn.reduce_bn(grad_output_t, inp_t, mean, var_biased, weight_t, eps)
grad_input = syncbn.batchnorm_backward(grad_output_t, inp_t, mean, var_biased, weight_t, mean_dy, mean_dy_xmu, eps)
mean_dy, mean_dy_xmu, grad_weight, grad_bias = syncbn.reduce_bn(grad_output_t, inp_t, mean, inv_std, weight_t)
grad_input = syncbn.batchnorm_backward(grad_output_t, inp_t, mean, inv_std, weight_t, mean_dy, mean_dy_xmu)
sbn_result = compare("comparing bias grad: ", grad_bias, grad_bias_r, error) and sbn_result
sbn_result = compare("comparing weight grad: ", grad_weight, grad_weight_r, error) and sbn_result
sbn_result = compare("comparing mean_dy grad: ", mean_dy, mean_dy_r, error) and sbn_result
......@@ -112,7 +125,7 @@ sbn_result = compare("comparing input grad: ", grad_input, grad_input_r, error)
compare("comparing bn input grad: ", inp_bn.grad, grad_input_r, error)
sbn_result = compare("comparing sbn input grad: ", inp_sbn.grad, grad_input_r, error) and sbn_result
compare("comparing output: ", out_bn, out_sbn, error)
compare("comparing bn/sbn output: ", out_bn, out_sbn, error)
sbn_result = compare("comparing running_mean: ", bn.running_mean.data, sbn.running_mean.data, error) and sbn_result
sbn_result = compare("comparing running_variance: ", bn.running_var.data, sbn.running_var.data, error) and sbn_result
compare("comparing grad_input: ", inp_bn.grad, inp_sbn.grad, error)
......@@ -123,7 +136,21 @@ compare("comparing grad_weight: ", bn.weight.grad, sbn.weight.grad, error)
compare("comparing grad_weight bn to ref: ", bn.weight.grad, grad_weight_r, error)
sbn_result = compare("comparing grad_weight sbn to ref: ", sbn.weight.grad, grad_weight_r, error) and sbn_result
compare("comparing channel last bn/sbn output: ", out_bn, out_sbn_c_last.transpose(-1, 1).contiguous(), error)
sbn_result_c_last = compare("comparing channel last running_mean: ", bn.running_mean.data, sbn_c_last.running_mean.data, error) and sbn_result_c_last
sbn_result_c_last = compare("comparing channel last running_variance: ", bn.running_var.data, sbn_c_last.running_var.data, error) and sbn_result_c_last
compare("comparing channel last grad_input: ", inp_bn.grad, inp_sbn_c_last.grad.transpose(-1, 1).contiguous(), error)
compare("comparing channel last grad_bias: ", bn.bias.grad, sbn_c_last.bias.grad, error)
sbn_result_c_last = compare("comparing channel last grad_bias sbn to ref: ", sbn_c_last.bias.grad, grad_bias_r, error) and sbn_result_c_last
compare("comparing channel last grad_weight: ", bn.weight.grad, sbn_c_last.weight.grad, error)
sbn_result_c_last = compare("comparing channel last grad_weight sbn to ref: ", sbn_c_last.weight.grad, grad_weight_r, error) and sbn_result_c_last
if sbn_result:
print("====SBN single gpu passed tests")
else:
print("*SBN single gpu failed*")
if sbn_result_c_last:
print("====SBN channel last single gpu passed tests")
else:
print("*SBN channel last single gpu failed*")
......@@ -75,7 +75,10 @@ m = inp_r.mean(1)
b_v = inp_r.var(1, unbiased=False)
unb_v = inp_r.var(1, unbiased=True)
mean, var, var_biased = syncbn.welford_mean_var(inp_t)
eps = 1e-5
mean, var_biased = syncbn.welford_mean_var(inp_t)
inv_std = 1.0 / torch.sqrt(var_biased + eps)
bn = torch.nn.BatchNorm2d(feature_size).cuda()
bn.momentum = 1.0
......@@ -111,12 +114,9 @@ bn_result = True
if args.local_rank == 0:
sbn_result = compare("comparing mean: ", mean, m, error) and sbn_result
sbn_result = compare("comparing variance: ", var, unb_v, error) and sbn_result
sbn_result = compare("comparing biased variance: ", var_biased, b_v, error) and sbn_result
eps = 1e-5
out = syncbn.batchnorm_forward(inp_t, mean, var_biased, weight_t, bias_t, eps)
out = syncbn.batchnorm_forward(inp_t, mean, inv_std, weight_t, bias_t)
out_r = weight_r * (inp2_r - m.view(-1, 1, 1)) * torch.rsqrt(b_v.view(-1,1,1) + eps) + bias_r
if args.local_rank == 0:
......@@ -136,8 +136,8 @@ mean_dy_xmu_r = ((inp2_r - m.view(-1, 1, 1)) * grad_output2_r).transpose(1,0).co
grad_input_r = (grad_output2_r - mean_dy_r.view(-1, 1, 1) - (inp2_r - m.view(-1, 1, 1)) / (b_v.view(-1,1,1) + eps) * mean_dy_xmu_r.view(-1, 1, 1) ) * torch.rsqrt(b_v.view(-1,1,1) + eps) * weight_r.view(-1,1,1)
mean_dy, mean_dy_xmu, grad_weight, grad_bias = syncbn.reduce_bn(grad_output_t, inp_t, mean, var_biased, weight_t, eps)
grad_input = syncbn.batchnorm_backward(grad_output_t, inp_t, mean, var_biased, weight_t, mean_dy, mean_dy_xmu, eps)
mean_dy, mean_dy_xmu, grad_weight, grad_bias = syncbn.reduce_bn(grad_output_t, inp_t, mean, inv_std, weight_t)
grad_input = syncbn.batchnorm_backward(grad_output_t, inp_t, mean, inv_std, weight_t, mean_dy, mean_dy_xmu)
if args.local_rank == 0:
sbn_result = compare("comparing bias grad: ", grad_bias, grad_bias_r, error) and sbn_result
sbn_result = compare("comparing weight grad: ", grad_weight, grad_weight_r, error) and sbn_result
......