Unverified commit 438f6f9f authored by mcarilli, committed by GitHub

Merge pull request #125 from NVIDIA/nhwc_sbn_pr

[sync BN nhwc]
parents 3c7a0e44 a62b87ea
@@ -19,7 +19,7 @@ except ImportError:
         warned_syncbn = True
     from .sync_batchnorm import SyncBatchNorm

-def convert_syncbn_model(module, process_group=None):
+def convert_syncbn_model(module, process_group=None, channel_last=False):
     '''
     Recursively traverse module and its children to replace all
     `torch.nn.modules.batchnorm._BatchNorm` with `apex.parallel.SyncBatchNorm`
@@ -38,14 +38,16 @@ def convert_syncbn_model(module, process_group=None):
     '''
     mod = module
     if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
-        mod = SyncBatchNorm(module.num_features, module.eps, module.momentum, module.affine, module.track_running_stats, process_group)
+        mod = SyncBatchNorm(module.num_features, module.eps, module.momentum, module.affine, module.track_running_stats, process_group, channel_last=channel_last)
         mod.running_mean = module.running_mean
         mod.running_var = module.running_var
         if module.affine:
            mod.weight.data = module.weight.data.clone().detach()
            mod.bias.data = module.bias.data.clone().detach()
    for name, child in module.named_children():
-        mod.add_module(name, convert_syncbn_model(child))
+        mod.add_module(name, convert_syncbn_model(child,
+                                                  process_group=process_group,
+                                                  channel_last=channel_last))
    # TODO(jie) should I delete model explicitly?
    del module
    return mod
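For reference, a minimal usage sketch of the extended converter; the toy model below is purely illustrative and not part of the patch:

    import torch
    import apex

    # Any module tree containing torch.nn BatchNorm layers works here.
    model = torch.nn.Sequential(
        torch.nn.Conv2d(3, 100, kernel_size=3),
        torch.nn.BatchNorm2d(100),
        torch.nn.ReLU(),
    ).cuda()

    # Recursively swaps every BatchNorm for apex.parallel.SyncBatchNorm and now
    # forwards both process_group and the new channel_last flag to the children.
    model = apex.parallel.convert_syncbn_model(model, channel_last=True)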
@@ -38,26 +38,43 @@ class SyncBatchNorm(_BatchNorm):
         process_group: pass in a process group within which the stats of the
             mini-batch is being synchronized. ``None`` for using default process
             group
+        channel_last: a boolean value that when set to ``True``, this module
+            takes the last dimension of the input tensor to be the channel
+            dimension. Default: False

     Examples::
+        >>> # channel first tensor
         >>> sbn = apex.parallel.SyncBatchNorm(100).cuda()
         >>> inp = torch.randn(10, 100, 14, 14).cuda()
         >>> out = sbn(inp)
         >>> inp = torch.randn(3, 100, 20).cuda()
         >>> out = sbn(inp)
+        >>> # channel last tensor
+        >>> sbn = apex.parallel.SyncBatchNorm(100, channel_last=True).cuda()
+        >>> inp = torch.randn(10, 14, 14, 100).cuda()
     """

-    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True, process_group=None):
+    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True, process_group=None, channel_last = False):
         super(SyncBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine, track_running_stats=track_running_stats)
         self.process_group = process_group
+        self.channel_last = channel_last

     def _specify_process_group(self, process_group):
         self.process_group = process_group

+    def _specify_channel_last(self, channel_last):
+        self.channel_last = channel_last
+
     def forward(self, input):
-        if not self.training and self.track_running_stats:
+        if not self.training and self.track_running_stats and not self.channel_last:
             # fall back to pytorch implementation for inference
             return F.batch_norm(input, self.running_mean, self.running_var, self.weight, self.bias, False, 0.0, self.eps)
         else:
-            self.num_batches_tracked += 1
-            return SyncBatchnormFunction.apply(input, self.weight, self.bias, self.running_mean, self.running_var, self.eps, self.track_running_stats, self.momentum, self.process_group)
+            exponential_average_factor = 0.0
+            if self.training and self.track_running_stats:
+                self.num_batches_tracked += 1
+                if self.momentum is None:
+                    exponential_average_factor = 1.0 / float(self.num_batches_tracked)
+                else:
+                    exponential_average_factor = self.momentum
+            return SyncBatchnormFunction.apply(input, self.weight, self.bias, self.running_mean, self.running_var, self.eps, self.training or not self.track_running_stats, exponential_average_factor, self.process_group, self.channel_last)
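The running-stat factor introduced above mirrors torch.nn batch-norm semantics. Restated on its own as a standalone sketch (the helper name is hypothetical, not part of the patch):

    def exponential_average_factor(momentum, num_batches_tracked):
        # momentum=None means "cumulative moving average": weight the newest
        # batch statistics by 1/num_batches_tracked; otherwise use the fixed
        # momentum as the update factor.
        if momentum is None:
            return 1.0 / float(num_batches_tracked)
        return momentum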
@@ -7,26 +7,40 @@ from apex.parallel import ReduceOp

 class SyncBatchnormFunction(Function):

     @staticmethod
-    def forward(ctx, input, weight, bias, running_mean, running_variance, eps, track_running_stats = True, momentum = 1.0, process_group = None):
+    def forward(ctx, input, weight, bias, running_mean, running_variance, eps, track_running_stats = True, momentum = 1.0, process_group = None, channel_last = False):
         torch.cuda.nvtx.range_push("sync_BN_fw")
         input = input.contiguous()
         world_size = 0

+        mean = None
+        var_biased = None
+        inv_std = None
+        var = None
+        out = None
+        count = None
         if track_running_stats:
-            mean, var, var_biased = syncbn.welford_mean_var(input)
+            if channel_last:
+                count = int(input.numel()/input.size(-1))
+                mean, var_biased = syncbn.welford_mean_var_c_last(input)
+            else :
+                count = int(input.numel()/input.size(1))
+                mean, var_biased = syncbn.welford_mean_var(input)

             if torch.distributed.is_initialized():
                 if not process_group:
                     process_group = torch.distributed.group.WORLD
                 world_size = torch.distributed.get_world_size(process_group)
                 mean_all = torch.empty(world_size, mean.size(0), dtype=mean.dtype, device=mean.device)
-                var_all = torch.empty(world_size, var.size(0), dtype=var.dtype, device=var.device)
+                var_all = torch.empty(world_size, var_biased.size(0), dtype=var_biased.dtype, device=var_biased.device)
                 mean_l = [mean_all.narrow(0, i, 1) for i in range(world_size)]
                 var_l = [var_all.narrow(0, i, 1) for i in range(world_size)]
                 torch.distributed.all_gather(mean_l, mean, process_group)
                 torch.distributed.all_gather(var_l, var_biased, process_group)
-                mean, var, var_biased = syncbn.welford_parallel(mean_all.transpose(1,0).contiguous(), var_all.transpose(1,0).contiguous(), int(input.numel()/input.size(1)))
+                mean, var, inv_std = syncbn.welford_parallel(mean_all, var_all, count, eps)
                 # TODO(Jie): should do fp32 math instead!
+            else:
+                inv_std = 1.0 / torch.sqrt(var_biased + eps)
+                var = var_biased * (count) / (count-1)

             r_m_inc = mean if running_mean.dtype != torch.float16 else mean.half()
             r_v_inc = var if running_variance.dtype != torch.float16 else var.half()
@@ -34,14 +48,17 @@ class SyncBatchnormFunction(Function):
             running_variance.data = running_variance.data * (1-momentum) + momentum*r_v_inc
         else:
             mean = running_mean.data
-            var_biased = running_var.data
+            inv_std = 1.0 / torch.sqrt(running_var.data + eps)

-        ctx.save_for_backward(input, weight, mean, var_biased)
-        ctx.eps = eps
+        ctx.save_for_backward(input, weight, mean, inv_std)
         ctx.process_group = process_group
+        ctx.channel_last = channel_last
         ctx.world_size = world_size

-        out = syncbn.batchnorm_forward(input, mean, var_biased, weight, bias, eps)
+        if channel_last:
+            out = syncbn.batchnorm_forward_c_last(input, mean, inv_std, weight, bias)
+        else:
+            out = syncbn.batchnorm_forward(input, mean, inv_std, weight, bias)

         torch.cuda.nvtx.range_pop()
         return out
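For reference, a pure-PyTorch sketch of what the syncbn.welford_parallel reduction above computes, assuming every rank contributes the same per-channel element count (the kernel's single `count` argument suggests this); variable names match the Python code above:

    # mean_all, var_all: (world_size, C) tensors gathered from all ranks;
    # var_all holds each rank's *biased* per-channel variance.
    mean = mean_all.mean(0)
    # Law of total variance with equal counts: average within-rank variance
    # plus the (biased) variance of the per-rank means.
    var_biased = var_all.mean(0) + mean_all.var(0, unbiased=False)
    inv_std = torch.rsqrt(var_biased + eps)
    # Bessel-corrected variance over the global element count, used only for
    # the running_var update.
    var = var_biased * (world_size * count) / (world_size * count - 1)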
@@ -53,14 +70,17 @@ class SyncBatchnormFunction(Function):
         # mini batch mean & var are calculated by forward path.
         # mu = 1./N*np.sum(h, axis = 0)
         # var = 1./N*np.sum((h-mu)**2, axis = 0)
-        saved_input, weight, running_mean, running_variance = ctx.saved_tensors
-        eps = ctx.eps
+        saved_input, weight, mean, inv_std = ctx.saved_tensors
         process_group = ctx.process_group
+        channel_last = ctx.channel_last
         world_size = ctx.world_size
         grad_input = grad_weight = grad_bias = None

         # TODO(jie): why do I have to clone here? life time of grad_output?
-        mean_dy, mean_dy_xmu, grad_weight, grad_bias = syncbn.reduce_bn(grad_output, saved_input, running_mean, running_variance, weight, eps)
+        if channel_last:
+            mean_dy, mean_dy_xmu, grad_weight, grad_bias = syncbn.reduce_bn_c_last(grad_output, saved_input, mean, inv_std, weight)
+        else:
+            mean_dy, mean_dy_xmu, grad_weight, grad_bias = syncbn.reduce_bn(grad_output, saved_input, mean, inv_std, weight)

         # calculate grad_input
         if ctx.needs_input_grad[0]:
@@ -72,7 +92,10 @@ class SyncBatchnormFunction(Function):
                 torch.distributed.all_reduce(
                     mean_dy_xmu, ReduceOp.SUM, process_group)
                 mean_dy_xmu = mean_dy_xmu / world_size
-            grad_input = syncbn.batchnorm_backward(grad_output, saved_input, running_mean, running_variance, weight, mean_dy, mean_dy_xmu, eps)
+            if channel_last:
+                grad_input = syncbn.batchnorm_backward_c_last(grad_output, saved_input, mean, inv_std, weight, mean_dy, mean_dy_xmu)
+            else:
+                grad_input = syncbn.batchnorm_backward(grad_output, saved_input, mean, inv_std, weight, mean_dy, mean_dy_xmu)

         if weight is None or not ctx.needs_input_grad[1]:
             grad_weight = None
@@ -81,4 +104,4 @@ class SyncBatchnormFunction(Function):
             grad_bias = None

         torch.cuda.nvtx.range_pop()
-        return grad_input, grad_weight, grad_bias, None, None, None, None, None, None
+        return grad_input, grad_weight, grad_bias, None, None, None, None, None, None, None
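The dgrad produced by syncbn.batchnorm_backward follows the standard batch-norm backward, now written in terms of the saved mean/inv_std instead of variance and eps. A sketch for an NCHW input, matching the reference check in the single-GPU unit test (broadcast shapes here are illustrative):

    c = mean.view(1, -1, 1, 1)
    s = inv_std.view(1, -1, 1, 1)
    # (x - mean) * inv_std**2 replaces the old (x - mean) / (var + eps) term.
    grad_input = (grad_output
                  - mean_dy.view(1, -1, 1, 1)
                  - (saved_input - c) * s * s * mean_dy_xmu.view(1, -1, 1, 1)
                  ) * s * weight.view(1, -1, 1, 1)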
@@ -3,52 +3,93 @@
 #include <vector>

-// returns {mean,unbiased_var,biased_var}
+// returns {mean,biased_var}
 // implemented using welford
 std::vector<at::Tensor> welford_mean_var_CUDA(const at::Tensor input);

 // reduces array of mean/var across processes
-// returns global {mean,unbiased_var,biased_var}
+// returns global {mean,inv_std,biased_var}
 // implemented using welford
-std::vector<at::Tensor> welford_parallel_CUDA(const at::Tensor mean_feature_nodes, const at::Tensor var_biased_feature_nodes, int numel);
+std::vector<at::Tensor> welford_parallel_CUDA(const at::Tensor mean_feature_nodes,
+                                              const at::Tensor var_biased_feature_nodes,
+                                              int numel,
+                                              const float eps);

 // elementwise BN operation, returns output
 // input/weight/shift should have identical data type;
-// mean/var have promoted data type (dtype==fp16?fp32:dtype)
+// mean/inv_std have promoted data type (dtype==fp16?fp32:dtype)
 at::Tensor batchnorm_forward_CUDA(const at::Tensor input,
                                   const at::Tensor mean,
-                                  const at::Tensor var,
+                                  const at::Tensor inv_std,
                                   const at::Tensor weight,
-                                  const at::Tensor shift,
-                                  const float eps);
+                                  const at::Tensor shift);

 // backward BN operation, returns {mean_dy, mean_dy_xmu, grad_weight, grad_bias}
 // grad_output/input should have identical data type;
-// mean/var have promoted data type (dtype==fp16?fp32:dtype)
+// mean/inv_std have promoted data type (dtype==fp16?fp32:dtype)
 // implemented using kahan summation
 std::vector<at::Tensor> reduce_bn_CUDA(const at::Tensor grad_output,
                                        const at::Tensor input,
                                        const at::Tensor mean,
-                                       const at::Tensor var,
-                                       const at::Tensor weight,
-                                       const float eps);
+                                       const at::Tensor inv_std,
+                                       const at::Tensor weight);

 // elementwise backward BN operation, returns grad_input
 // grad_output/input/weight precision could be fp16/fp32;
-// mean/var/mean_dy/mean_dy_xmu precision is fp32
+// mean/inv_std/mean_dy/mean_dy_xmu precision is fp32
 at::Tensor batchnorm_backward_CUDA(const at::Tensor grad_output,
                                    const at::Tensor input,
                                    const at::Tensor mean,
-                                   const at::Tensor var,
+                                   const at::Tensor inv_std,
                                    const at::Tensor weight,
                                    const at::Tensor mean_dy,
-                                   const at::Tensor mean_dy_xmu,
-                                   const float eps);
+                                   const at::Tensor mean_dy_xmu);

+// returns {mean, biased_var}
+// implemented using welford
+// expect data to be in n+c format (channel last) and applies CUDNN_BATCHNORM_SPATIAL
+std::vector<at::Tensor> welford_mean_var_c_last_CUDA(const at::Tensor input);
+
+// elementwise BN operation, returns output
+// input/weight/shift should have identical data type;
+// mean/inv_std have promoted data type (dtype==fp16?fp32:dtype)
+// expect data to be in n+c format (channel last) and applies CUDNN_BATCHNORM_SPATIAL
+at::Tensor batchnorm_forward_c_last_CUDA(const at::Tensor input,
+                                         const at::Tensor mean,
+                                         const at::Tensor inv_std,
+                                         const at::Tensor weight,
+                                         const at::Tensor shift);
+
+// backward BN operation, returns {mean_dy, mean_dy_xmu, grad_weight, grad_bias}
+// grad_output/input should have identical data type;
+// mean/inv_std have promoted data type (dtype==fp16?fp32:dtype)
+// expect data to be in n+c format (channel last) and applies CUDNN_BATCHNORM_SPATIAL
+std::vector<at::Tensor> reduce_bn_c_last_CUDA(const at::Tensor grad_output,
+                                              const at::Tensor input,
+                                              const at::Tensor mean,
+                                              const at::Tensor inv_std,
+                                              const at::Tensor weight);
+
+// elementwise backward BN operation, returns grad_input
+// grad_output/input/weight precision could be fp16/fp32;
+// mean/inv_std/mean_dy/mean_dy_xmu precision is fp32
+// expect data to be in n+c format (channel last) and applies CUDNN_BATCHNORM_SPATIAL
+at::Tensor batchnorm_backward_c_last_CUDA(const at::Tensor grad_output,
+                                          const at::Tensor input,
+                                          const at::Tensor mean,
+                                          const at::Tensor inv_std,
+                                          const at::Tensor weight,
+                                          const at::Tensor mean_dy,
+                                          const at::Tensor mean_dy_xmu);
+
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("welford_mean_var", &welford_mean_var_CUDA, "welford mean variance");
   m.def("welford_parallel", &welford_parallel_CUDA, "welford parallel reduce mean variance");
   m.def("batchnorm_forward", &batchnorm_forward_CUDA, "batchnorm forward");
-  m.def("reduce_bn", &reduce_bn_CUDA, "batchnorm backward reduce grad sum and bias/weight gradient");
+  m.def("reduce_bn", &reduce_bn_CUDA, "batchnorm backward reduce grad sum and bias/weight grad");
   m.def("batchnorm_backward", &batchnorm_backward_CUDA, "batchnorm backward dgrad");
+  m.def("welford_mean_var_c_last", &welford_mean_var_c_last_CUDA, "welford mean variance nhwc");
+  m.def("batchnorm_forward_c_last", &batchnorm_forward_c_last_CUDA, "batchnorm forward nhwc");
+  m.def("reduce_bn_c_last", &reduce_bn_c_last_CUDA, "batchnorm backwards reduce grad sum and bias/weight grad nhwc");
+  m.def("batchnorm_backward_c_last", &batchnorm_backward_c_last_CUDA, "batchnorm backward dgrad nhwc");
 }
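Taken together, the single-process NHWC forward path through the new bindings looks like this. A sketch assuming the compiled syncbn extension is importable and x is an (N, H, W, C) CUDA tensor; the helper name is hypothetical:

    import torch
    import syncbn  # apex's compiled extension

    def nhwc_bn_forward(x, weight, bias, eps=1e-5):
        # Per-channel statistics are reduced over the N*H*W elements.
        mean, var_biased = syncbn.welford_mean_var_c_last(x)
        inv_std = torch.rsqrt(var_biased + eps)
        return syncbn.batchnorm_forward_c_last(x, mean, inv_std, weight, bias)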
This diff is collapsed.
@@ -54,7 +54,11 @@ m = inp_r.mean(1)
 b_v = inp_r.var(1, unbiased=False)
 unb_v = inp_r.var(1, unbiased=True)

-mean, var, var_biased = syncbn.welford_mean_var(inp_t)
+eps = 1e-5
+#mean, var, var_biased = syncbn.welford_mean_var(inp_t)
+mean, var_biased = syncbn.welford_mean_var(inp_t)
+inv_std = 1.0 / torch.sqrt(var_biased + eps)

 bn = torch.nn.BatchNorm2d(feature_size).cuda()
 bn.momentum = 1.0
@@ -74,16 +78,25 @@ grad_sbn = grad_output_t.clone().detach()
 out_sbn = sbn(inp_sbn)
 out_sbn.backward(grad_sbn)

+sbn_c_last = apex.parallel.SyncBatchNorm(feature_size, channel_last=True).cuda()
+sbn_c_last.momentum = 1.0
+sbn_c_last.weight.data = weight_t.clone()
+sbn_c_last.bias.data = bias_t.clone()
+inp_sbn_c_last = inp_t.clone().transpose(-1, 1).contiguous().requires_grad_()
+grad_sbn_c_last = grad_output_t.clone().transpose(-1, 1).contiguous().detach()
+out_sbn_c_last = sbn_c_last(inp_sbn_c_last)
+out_sbn_c_last.backward(grad_sbn_c_last)
+
 sbn_result = True
+sbn_result_c_last = True
 bn_result = True

 sbn_result = compare("comparing mean: ", mean, m, error) and sbn_result
-sbn_result = compare("comparing variance: ", var, unb_v, error) and sbn_result
+#sbn_result = compare("comparing variance: ", var, unb_v, error) and sbn_result
 sbn_result = compare("comparing biased variance: ", var_biased, b_v, error) and sbn_result

-eps = 1e-5
-
-out = syncbn.batchnorm_forward(inp_t, mean, var_biased, weight_t, bias_t, eps)
+out = syncbn.batchnorm_forward(inp_t, mean, inv_std, weight_t, bias_t)
 out_r = weight_r * (inp2_r - m.view(-1, 1, 1)) * torch.rsqrt(b_v.view(-1,1,1) + eps) + bias_r

 sbn_result = compare("comparing output: ", out, out_r, error) and sbn_result
@@ -102,8 +115,8 @@ mean_dy_xmu_r = ((inp2_r - m.view(-1, 1, 1)) * grad_output2_r).transpose(1,0).co
 grad_input_r = (grad_output2_r - mean_dy_r.view(-1, 1, 1) - (inp2_r - m.view(-1, 1, 1)) / (b_v.view(-1,1,1) + eps) * mean_dy_xmu_r.view(-1, 1, 1) ) * torch.rsqrt(b_v.view(-1,1,1) + eps) * weight_r.view(-1,1,1)

-mean_dy, mean_dy_xmu, grad_weight, grad_bias = syncbn.reduce_bn(grad_output_t, inp_t, mean, var_biased, weight_t, eps)
-grad_input = syncbn.batchnorm_backward(grad_output_t, inp_t, mean, var_biased, weight_t, mean_dy, mean_dy_xmu, eps)
+mean_dy, mean_dy_xmu, grad_weight, grad_bias = syncbn.reduce_bn(grad_output_t, inp_t, mean, inv_std, weight_t)
+grad_input = syncbn.batchnorm_backward(grad_output_t, inp_t, mean, inv_std, weight_t, mean_dy, mean_dy_xmu)

 sbn_result = compare("comparing bias grad: ", grad_bias, grad_bias_r, error) and sbn_result
 sbn_result = compare("comparing weight grad: ", grad_weight, grad_weight_r, error) and sbn_result
 sbn_result = compare("comparing mean_dy grad: ", mean_dy, mean_dy_r, error) and sbn_result
@@ -112,7 +125,7 @@ sbn_result = compare("comparing input grad: ", grad_input, grad_input_r, error)
 compare("comparing bn input grad: ", inp_bn.grad, grad_input_r, error)
 sbn_result = compare("comparing sbn input grad: ", inp_sbn.grad, grad_input_r, error) and sbn_result

-compare("comparing output: ", out_bn, out_sbn, error)
+compare("comparing bn/sbn output: ", out_bn, out_sbn, error)
 sbn_result = compare("comparing running_mean: ", bn.running_mean.data, sbn.running_mean.data, error) and sbn_result
 sbn_result = compare("comparing running_variance: ", bn.running_var.data, sbn.running_var.data, error) and sbn_result
 compare("comparing grad_input: ", inp_bn.grad, inp_sbn.grad, error)
@@ -123,7 +136,21 @@ compare("comparing grad_weight: ", bn.weight.grad, sbn.weight.grad, error)
 compare("comparing grad_weight bn to ref: ", bn.weight.grad, grad_weight_r, error)
 sbn_result = compare("comparing grad_weight sbn to ref: ", sbn.weight.grad, grad_weight_r, error) and sbn_result

+compare("comparing channel last bn/sbn output: ", out_bn, out_sbn_c_last.transpose(-1, 1).contiguous(), error)
+sbn_result_c_last = compare("comparing channel last running_mean: ", bn.running_mean.data, sbn_c_last.running_mean.data, error) and sbn_result_c_last
+sbn_result_c_last = compare("comparing channel last running_variance: ", bn.running_var.data, sbn_c_last.running_var.data, error) and sbn_result_c_last
+compare("comparing channel last grad_input: ", inp_bn.grad, inp_sbn_c_last.grad.transpose(-1, 1).contiguous(), error)
+compare("comparing channel last grad_bias: ", bn.bias.grad, sbn_c_last.bias.grad, error)
+sbn_result_c_last = compare("comparing channel last grad_bias sbn to ref: ", sbn_c_last.bias.grad, grad_bias_r, error) and sbn_result_c_last
+compare("comparing channel last grad_weight: ", bn.weight.grad, sbn_c_last.weight.grad, error)
+sbn_result_c_last = compare("comparing channel last grad_weight sbn to ref: ", sbn_c_last.weight.grad, grad_weight_r, error) and sbn_result_c_last
+
 if sbn_result:
     print("====SBN single gpu passed tests")
 else:
     print("*SBN single gpu failed*")
+
+if sbn_result_c_last:
+    print("====SBN channel last single gpu passed tests")
+else:
+    print("*SBN channel last single gpu failed*")
@@ -75,7 +75,10 @@ m = inp_r.mean(1)
 b_v = inp_r.var(1, unbiased=False)
 unb_v = inp_r.var(1, unbiased=True)

-mean, var, var_biased = syncbn.welford_mean_var(inp_t)
+eps = 1e-5
+mean, var_biased = syncbn.welford_mean_var(inp_t)
+inv_std = 1.0 / torch.sqrt(var_biased + eps)

 bn = torch.nn.BatchNorm2d(feature_size).cuda()
 bn.momentum = 1.0
@@ -111,12 +114,9 @@ bn_result = True
 if args.local_rank == 0:
     sbn_result = compare("comparing mean: ", mean, m, error) and sbn_result
-    sbn_result = compare("comparing variance: ", var, unb_v, error) and sbn_result
     sbn_result = compare("comparing biased variance: ", var_biased, b_v, error) and sbn_result

-eps = 1e-5
-
-out = syncbn.batchnorm_forward(inp_t, mean, var_biased, weight_t, bias_t, eps)
+out = syncbn.batchnorm_forward(inp_t, mean, inv_std, weight_t, bias_t)
 out_r = weight_r * (inp2_r - m.view(-1, 1, 1)) * torch.rsqrt(b_v.view(-1,1,1) + eps) + bias_r

 if args.local_rank == 0:
@@ -136,8 +136,8 @@ mean_dy_xmu_r = ((inp2_r - m.view(-1, 1, 1)) * grad_output2_r).transpose(1,0).co
 grad_input_r = (grad_output2_r - mean_dy_r.view(-1, 1, 1) - (inp2_r - m.view(-1, 1, 1)) / (b_v.view(-1,1,1) + eps) * mean_dy_xmu_r.view(-1, 1, 1) ) * torch.rsqrt(b_v.view(-1,1,1) + eps) * weight_r.view(-1,1,1)

-mean_dy, mean_dy_xmu, grad_weight, grad_bias = syncbn.reduce_bn(grad_output_t, inp_t, mean, var_biased, weight_t, eps)
-grad_input = syncbn.batchnorm_backward(grad_output_t, inp_t, mean, var_biased, weight_t, mean_dy, mean_dy_xmu, eps)
+mean_dy, mean_dy_xmu, grad_weight, grad_bias = syncbn.reduce_bn(grad_output_t, inp_t, mean, inv_std, weight_t)
+grad_input = syncbn.batchnorm_backward(grad_output_t, inp_t, mean, inv_std, weight_t, mean_dy, mean_dy_xmu)

 if args.local_rank == 0:
     sbn_result = compare("comparing bias grad: ", grad_bias, grad_bias_r, error) and sbn_result
     sbn_result = compare("comparing weight grad: ", grad_weight, grad_weight_r, error) and sbn_result
...