Unverified Commit 3dd36070 authored by Asit, committed by GitHub

Merge pull request #1 from NVIDIA/master

Updating my repo
parents 02a33875 3104fd59
......@@ -40,3 +40,7 @@ def scalar_python_val(x):
return x.data[0]
else:
return x[0]
# Accounts for the possibility that some ops may be removed from a namespace.
def filter_attrs(module, attrs):
return list(attrname for attrname in attrs if hasattr(module, attrname))
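# A minimal usage sketch (illustrative, not part of the commit): filter_attrs() keeps
# the amp override lists import-safe when a torch release drops an op. 'fake_op' is a
# made-up attribute name used only to show the filtering behaviour.
import torch
safe = filter_attrs(torch.Tensor, ['__matmul__', 'fake_op'])
assert safe == ['__matmul__']  # the missing 'fake_op' is silently dropped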
......@@ -11,20 +11,20 @@ MODULE = torch.Tensor
# MODULE = torch.autograd.Variable
FP16_FUNCS = [
FP16_FUNCS = compat.filter_attrs(MODULE, [
'__matmul__',
]
])
FP32_FUNCS = [
FP32_FUNCS = compat.filter_attrs(MODULE, [
'__ipow__',
'__pow__',
'__rpow__',
# Cast to fp32 before transfer to CPU
'cpu',
]
])
CASTS = [
CASTS = compat.filter_attrs(MODULE, [
'__add__',
'__div__',
'__eq__',
......@@ -46,7 +46,7 @@ CASTS = [
'__rtruediv__',
'__sub__',
'__truediv__',
]
])
# None of these, but here to make code cleaner.
SEQUENCE_CASTS = []
......
import math
import torch
from torch import nn
from torch.nn import Parameter
......@@ -76,7 +78,11 @@ class EncdecMultiheadAttn(nn.Module):
def reset_parameters(self):
nn.init.xavier_uniform_(self.in_proj_weight_q)
nn.init.xavier_uniform_(self.in_proj_weight_kv)
# in_proj_weight_kv has shape [2 * hidden, hidden] but it should be
# initialized like a [hidden, hidden] matrix.
# sqrt(6 / (hidden + hidden)) / sqrt(6 / (2 * hidden + hidden)) = sqrt(1.5)
# therefore xavier_uniform gain should be set to sqrt(1.5).
nn.init.xavier_uniform_(self.in_proj_weight_kv, gain=math.sqrt(1.5))
nn.init.xavier_uniform_(self.out_proj_weight)
if self.bias:
nn.init.constant_(self.in_proj_bias_q, 0.)
......
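# A small numerical sanity check of the comment above (illustrative, not part of the
# commit): xavier_uniform_ draws from U(-b, b) with b = gain * sqrt(6 / (fan_in + fan_out)).
# With gain = sqrt(1.5), the stacked [2*hidden, hidden] projection gets the same bound
# as a plain [hidden, hidden] matrix would with gain = 1. 'hidden' is an arbitrary size.
import math
hidden = 512
bound_square = math.sqrt(6.0 / (hidden + hidden))                        # [hidden, hidden], gain = 1
bound_stacked = math.sqrt(1.5) * math.sqrt(6.0 / (hidden + 2 * hidden))  # [2*hidden, hidden], gain = sqrt(1.5)
assert abs(bound_square - bound_stacked) < 1e-12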
import math
import torch
from torch import nn
from torch.nn import Parameter
......@@ -98,7 +100,11 @@ class SelfMultiheadAttn(nn.Module):
nn.init.xavier_uniform_(self.k_weight)
nn.init.xavier_uniform_(self.v_weight)
else:
nn.init.xavier_uniform_(self.in_proj_weight)
# in_proj_weight has shape [3 * hidden, hidden] but it should be
# initialized like a [hidden, hidden] matrix.
# sqrt(6 / (hidden + hidden)) / sqrt(6 / (3 * hidden + hidden)) = sqrt(2)
# therefore xavier_uniform gain should be set to sqrt(2).
nn.init.xavier_uniform_(self.in_proj_weight, gain=math.sqrt(2))
nn.init.xavier_uniform_(self.out_proj_weight)
if self.bias:
if self.separate_qkv_params:
......
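# The same reasoning generalizes (illustrative note, not part of the commit): a stacked
# [k*hidden, hidden] projection that should look like k independent [hidden, hidden]
# matrices needs gain = sqrt(6/(h+h)) / sqrt(6/(k*h+h)) = sqrt((k+1)/2), i.e. sqrt(1.5)
# for the k=2 encoder-decoder KV projection and sqrt(2) for the k=3 fused QKV projection.
import math
for k, gain in [(2, math.sqrt(1.5)), (3, math.sqrt(2.0))]:
    assert abs(gain - math.sqrt((k + 1) / 2.0)) < 1e-12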
......@@ -6,7 +6,7 @@ torchvision_imported=True
try:
import torchvision
except ImportError:
print("[ASP][Warning] torchvision cannot be imported, may infuence functionality of MaskRCNN/KeypointRCNN network from torchvision.")
print("[ASP][Warning] torchvision cannot be imported.")
torchvision_imported=False
def eligible_modules(model, whitelist_layer_types, allowed_layer_names, disallowed_layer_names):
......@@ -78,7 +78,7 @@ class ASP:
# Function to extract variables that will be sparsified.
# The idea is that you will add one of these for each module type that can be sparsified.
if torchvision_imported:
print("[ASP] torchvision is imported, can work smoothly with the MaskRCNN/KeypointRCNN from torchvision.")
print("[ASP] torchvision is imported, can work with the MaskRCNN/KeypointRCNN from torchvision.")
sparse_parameter_list = {torch.nn.Linear: ['weight'], torch.nn.Conv1d: ['weight'], torch.nn.Conv2d: ['weight'], torch.nn.Conv3d: ['weight'], torchvision.ops.misc.Conv2d: ['weight']}
else:
sparse_parameter_list = {torch.nn.Linear: ['weight'], torch.nn.Conv1d: ['weight'], torch.nn.Conv2d: ['weight'], torch.nn.Conv3d: ['weight']}
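# A hedged sketch of the extension point described above (illustrative only): register an
# additional prunable module type by mapping the class to the parameter names that should
# receive a sparsity mask. 'MyCustomLinear' and 'extra_sparse_parameters' are hypothetical
# names used only for illustration; in practice the entry would be merged into
# sparse_parameter_list above.
import torch
class MyCustomLinear(torch.nn.Linear):
    pass
extra_sparse_parameters = {MyCustomLinear: ['weight']}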
......@@ -102,7 +102,7 @@ class ASP:
print("[ASP] Sparsifying %s::%s of size=%s and type=%s for sparsity" % (module_name, p_name, str(p.size()), str(p.dtype)))
mask = torch.ones_like(p).bool()
buffname = name.split(".")[-1] # buffer names cannot contain "."
buffname = p_name.split(".")[-1] # buffer names cannot contain "."
module.register_buffer('__%s_mma_mask' % buffname, mask)
if allow_recompute_mask:
pruned = torch.zeros_like(p).cpu()
......
......@@ -5,7 +5,7 @@ import collections
from itertools import permutations
""" compute density (helper fn to compute % NNZs in a tensor)"""
""" compute density (helper fn to compute % NNZs in a tensor) """
def fill(x):
return float(x.nonzero().size(0))/torch.numel(x)
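# Illustrative usage (not part of the commit): fill() reports the fraction of non-zero
# entries, so a tensor pruned to 2:4 sparsity has a density of 0.5.
import torch
t = torch.tensor([1.0, 0.0, 2.0, 0.0])
assert fill(t) == 0.5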
......@@ -20,7 +20,7 @@ def reshape_1d(matrix, m):
else:
return matrix.view(-1,m), matrix.shape
""" return all possible m:n patterns in a 1d vector. """
""" return all possible m:n patterns in a 1d vector """
valid_m4n2_1d_patterns = None
def compute_valid_1d_patterns(m,n):
# Early exit if patterns was already created.
......@@ -49,8 +49,21 @@ def mn_1d_best(matrix, m, n):
def m4n2_1d(mat, density):
return mn_1d_best(mat, 4, 2)
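# A small sketch of what "all possible m:n patterns" means for the 4:2 case used by
# m4n2_1d above (illustrative only): a 2:4 pattern keeps exactly 2 of every 4 consecutive
# weights, so there are C(4,2) = 6 valid masks per group of four.
from itertools import permutations
patterns = set(permutations([0, 0, 1, 1]))
assert len(patterns) == 6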
""" Comment: Following 2d masking related code (for training) can be removed or marked experimental (78 LOC) """
""" m:n 2d structured greedy """
"""
Below 2d-masking related code is targeted more for training (from scratch).
2d-pruning of a weight tensor is done to accelerate DGRAD step during backprop
phase of training algorithm. Acceleration comes from using SpMMA instructions in
Tensor Cores of NVIDIA Ampere GPU Architecture
(note: this code does not do the acceleration, GPU kernels are required for this).
1d pruning of weight tensor helps speed up FPROP step by pruning in 2:4 pattern
along the horizontal (logical) direction.
During DGRAD step, weight tensor is transposed. 2d pruning functions below, mask
weight tensor such that their transposed versions are also 2:4 sparse along the
horizontal (logical) direction. Thus, with 2d pruning, weight tensors are
2:4 sparse along row and column directions.
"""
""" m:n 2d structured pruning: greedy method to select mask """
def mn_2d_greedy(matrix, m, n):
# Convert to numpy
mat = matrix.cpu().detach().numpy()
......@@ -105,7 +118,7 @@ def compute_valid_2d_patterns(m,n):
if m == 4 and n == 2: valid_m4n2_2d_patterns = valid_patterns
return valid_patterns
""" m:n 2d structured best """
""" m:n 2d structured pruning: exhaustive method to select best mask """
def mn_2d_best(matrix, m, n):
# Find all possible patterns.
patterns = compute_valid_2d_patterns(m,n).cuda()
......@@ -127,6 +140,7 @@ def mn_2d_best(matrix, m, n):
def m4n2_2d_best(mat, density):
return mn_2d_best(mat, 4, 2)
""" returns a sparse mask """
def create_mask(tensor, pattern="m4n2_1d", density=0.5):
# Reshape tensor and mask.
......
......@@ -28,16 +28,24 @@ class SyncBatchnormFunction(Function):
if torch.distributed.is_initialized():
if not process_group:
process_group = torch.distributed.group.WORLD
device = mean.device
world_size = torch.distributed.get_world_size(process_group)
mean_all = torch.empty(world_size, mean.size(0), dtype=mean.dtype, device=mean.device)
var_all = torch.empty(world_size, var_biased.size(0), dtype=var_biased.dtype, device=var_biased.device)
mean_all = torch.empty(world_size, mean.size(0), dtype=mean.dtype, device=device)
var_all = torch.empty(world_size, var_biased.size(0), dtype=var_biased.dtype, device=device)
count_all = torch.cuda.IntTensor(world_size, device=device)
mean_l = [mean_all.narrow(0, i, 1) for i in range(world_size)]
var_l = [var_all.narrow(0, i, 1) for i in range(world_size)]
count_l = [count_all.narrow(0, i, 1) for i in range(world_size)]
torch.distributed.all_gather(mean_l, mean, process_group)
torch.distributed.all_gather(var_l, var_biased, process_group)
mean, var, inv_std = syncbn.welford_parallel(mean_all, var_all, count, eps)
# TODO(Jie): should do fp32 math instead!
torch.distributed.all_gather(
count_l,
torch.cuda.IntTensor([count], device=device),
process_group)
mean, var, inv_std = syncbn.welford_parallel(mean_all, var_all, count_all, eps)
else:
device = mean.device
count_all = torch.cuda.IntTensor([count], device=device)
inv_std = 1.0 / torch.sqrt(var_biased + eps)
var = var_biased * (count) / (count-1)
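# A minimal Python reference for the per-rank statistics merge that welford_parallel now
# performs with the gathered counts (illustrative; the real work is done by the CUDA kernel
# welford_kernel_parallel). With unequal per-rank batch sizes the merge must be weighted by
# each rank's element count instead of assuming one shared count. 'merge_welford' is a
# hypothetical helper name used only for this sketch.
import torch

def merge_welford(means, vars_biased, counts):
    mean, m2, n = 0.0, 0.0, 0
    for m_b, v_b, n_b in zip(means, vars_biased, counts):
        delta = m_b - mean
        new_n = n + n_b
        mean = mean + delta * n_b / new_n                        # weighted mean update
        m2 = m2 + v_b * n_b + delta * delta * n * n_b / new_n    # Chan et al. parallel merge
        n = new_n
    return mean, m2 / n                                          # global mean, global biased variance

# sanity check against a direct computation over the concatenated data
a, b = torch.randn(16), torch.randn(8)                           # two "ranks" with different counts
m, v = merge_welford([a.mean().item(), b.mean().item()],
                     [a.var(unbiased=False).item(), b.var(unbiased=False).item()],
                     [a.numel(), b.numel()])
full = torch.cat([a, b])
assert abs(m - full.mean().item()) < 1e-5
assert abs(v - full.var(unbiased=False).item()) < 1e-5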
......@@ -52,7 +60,7 @@ class SyncBatchnormFunction(Function):
mean = running_mean.data
inv_std = 1.0 / torch.sqrt(running_variance.data + eps)
ctx.save_for_backward(input, weight, mean, inv_std, z, bias)
ctx.save_for_backward(input, weight, mean, inv_std, z, bias, count_all)
ctx.process_group = process_group
ctx.channel_last = channel_last
ctx.world_size = world_size
......@@ -71,7 +79,7 @@ class SyncBatchnormFunction(Function):
# mini batch mean & var are calculated by forward path.
# mu = 1./N*np.sum(h, axis = 0)
# var = 1./N*np.sum((h-mu)**2, axis = 0)
saved_input, weight, mean, inv_std, z, bias = ctx.saved_tensors
saved_input, weight, mean, inv_std, z, bias, count = ctx.saved_tensors
process_group = ctx.process_group
channel_last = ctx.channel_last
world_size = ctx.world_size
......@@ -83,26 +91,24 @@ class SyncBatchnormFunction(Function):
if isinstance(z, torch.Tensor) and ctx.needs_input_grad[1]:
grad_z = grad_output.clone()
# TODO(jie): why do I have to clone here? Lifetime of grad_output?
# TODO: update kernel to not pre_divide by item_num
if channel_last:
mean_dy, mean_dy_xmu, grad_weight, grad_bias = syncbn.reduce_bn_c_last(grad_output, saved_input, mean, inv_std, weight)
sum_dy, sum_dy_xmu, grad_weight, grad_bias = syncbn.reduce_bn_c_last(grad_output, saved_input, mean, inv_std, weight)
else:
mean_dy, mean_dy_xmu, grad_weight, grad_bias = syncbn.reduce_bn(grad_output, saved_input, mean, inv_std, weight)
sum_dy, sum_dy_xmu, grad_weight, grad_bias = syncbn.reduce_bn(grad_output, saved_input, mean, inv_std, weight)
# calculate grad_input
if ctx.needs_input_grad[0]:
if torch.distributed.is_initialized():
torch.distributed.all_reduce(
mean_dy, ReduceOp.SUM, process_group)
mean_dy = mean_dy / world_size
sum_dy, ReduceOp.SUM, process_group)
torch.distributed.all_reduce(
mean_dy_xmu, ReduceOp.SUM, process_group)
mean_dy_xmu = mean_dy_xmu / world_size
sum_dy_xmu, ReduceOp.SUM, process_group)
if channel_last:
grad_input = syncbn.batchnorm_backward_c_last(grad_output, saved_input, mean, inv_std, weight, mean_dy, mean_dy_xmu)
grad_input = syncbn.batchnorm_backward_c_last(grad_output, saved_input, mean, inv_std, weight, sum_dy, sum_dy_xmu, count)
else:
grad_input = syncbn.batchnorm_backward(grad_output, saved_input, mean, inv_std, weight, mean_dy, mean_dy_xmu)
grad_input = syncbn.batchnorm_backward(grad_output, saved_input, mean, inv_std, weight, sum_dy, sum_dy_xmu, count)
if weight is None or not ctx.needs_input_grad[2]:
grad_weight = None
......
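# A hedged Python reference for what the updated backward computes per channel (the CUDA
# kernels batchnorm_backward / batchnorm_backward_c_last implement the same math): the
# all-reduced sums are divided by the true global element count N = sum(count_all), which
# is what makes unequal per-rank batch sizes come out correctly. 'bn_backward_reference'
# is a hypothetical name; the formula mirrors the reference computation in the unit tests.
import torch

def bn_backward_reference(dy, x, mean, inv_std, weight, sum_dy, sum_dy_xmu, total_count):
    # all per-channel quantities are assumed broadcastable against dy and x
    m_dy = sum_dy / total_count
    m_dy_xmu = sum_dy_xmu / total_count
    return (dy - m_dy - (x - mean) * inv_std * inv_std * m_dy_xmu) * inv_std * weight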
......@@ -12,7 +12,7 @@ std::vector<at::Tensor> welford_mean_var_CUDA(const at::Tensor input);
// implemented using welford
std::vector<at::Tensor> welford_parallel_CUDA(const at::Tensor mean_feature_nodes,
const at::Tensor var_biased_feature_nodes,
int numel,
const at::Tensor numel,
const float eps);
// elementwise BN operation, returns output
......@@ -24,7 +24,7 @@ at::Tensor batchnorm_forward_CUDA(const at::Tensor input,
const at::optional<at::Tensor> weight,
const at::optional<at::Tensor> shift);
// backward BN operation, returns {mean_dy, mean_dy_xmu, grad_weight, grad_bias}
// backward BN operation, returns {sum_dy, sum_dy_xmu, grad_weight, grad_bias}
// grad_output/input should have identical data type;
// mean/inv_std have promoted data type (dtype==fp16?fp32:dtype)
// implemented using kahan summation
......@@ -36,14 +36,15 @@ std::vector<at::Tensor> reduce_bn_CUDA(const at::Tensor grad_output,
// elementwise backward BN operation, returns grad_input
// grad_output/input/weight precision could be fp16/fp32;
// mean/inv_std/mean_dy/mean_dy_xmu precision is fp32
// mean/inv_std/sum_dy/sum_dy_xmu precision is fp32
at::Tensor batchnorm_backward_CUDA(const at::Tensor grad_output,
const at::Tensor input,
const at::Tensor mean,
const at::Tensor inv_std,
const at::optional<at::Tensor> weight,
const at::Tensor mean_dy,
const at::Tensor mean_dy_xmu);
const at::Tensor sum_dy,
const at::Tensor sum_dy_xmu,
const at::Tensor count);
// returns {mean, biased_var}
// implemented using welford
......@@ -62,7 +63,7 @@ at::Tensor batchnorm_forward_c_last_CUDA(const at::Tensor input,
const at::optional<at::Tensor> shift,
const bool fuse_relu);
// backward BN operation, returns {mean_dy, mean_dy_xmu, grad_weight, grad_bias}
// backward BN operation, returns {sum_dy, sum_dy_xmu, grad_weight, grad_bias}
// grad_output/input should have identical data type;
// mean/inv_std have promoted data type (dtype==fp16?fp32:dtype)
// expect data to be in n+c format (channel last) and applies CUDNN_BATCHNORM_SPATIAL
......@@ -74,15 +75,16 @@ std::vector<at::Tensor> reduce_bn_c_last_CUDA(const at::Tensor grad_output,
// elementwise backward BN operation, returns grad_input
// grad_output/input/weight precision could be fp16/fp32;
// mean/inv_std/mean_dy/mean_dy_xmu precision is fp32
// mean/inv_std/sum_dy/sum_dy_xmu precision is fp32
// expect data to be in n+c format (channel last) and applies CUDNN_BATCHNORM_SPATIAL
at::Tensor batchnorm_backward_c_last_CUDA(const at::Tensor grad_output,
const at::Tensor input,
const at::Tensor mean,
const at::Tensor inv_std,
const at::optional<at::Tensor> weight,
const at::Tensor mean_dy,
const at::Tensor mean_dy_xmu);
const at::Tensor sum_dy,
const at::Tensor sum_dy_xmu,
const at::Tensor count);
at::Tensor relu_backward_c_last_CUDA(const at::Tensor grad_output,
const at::Tensor input,
......
......@@ -327,15 +327,15 @@ __global__ void reduce_bn_kernel(
const scalar_t* __restrict__ grad_output,
const accscalar_t* __restrict__ mean,
const accscalar_t* __restrict__ inv_std,
accscalar_t* __restrict__ mean_dy,
accscalar_t* __restrict__ mean_dy_xmu,
accscalar_t* __restrict__ sum_dy_o,
accscalar_t* __restrict__ sum_dy_xmu_o,
layerscalar_t* __restrict__ grad_weight,
layerscalar_t* __restrict__ grad_bias,
const int bs,
const int fs,
const int ss) {
static __shared__ int s_mem[64];
int total_item_num = bs * ss;
//int total_item_num = bs * ss;
int thread_id = threadIdx.y*blockDim.x + threadIdx.x;
......@@ -377,8 +377,10 @@ __global__ void reduce_bn_kernel(
if (grad_weight != NULL) {
grad_weight[blockIdx.x] = static_cast<layerscalar_t>(sum_dy_xmu * factor);
}
mean_dy[blockIdx.x] = sum_dy / total_item_num;
mean_dy_xmu[blockIdx.x] = sum_dy_xmu / total_item_num;
//mean_dy[blockIdx.x] = sum_dy / total_item_num;
//mean_dy_xmu[blockIdx.x] = sum_dy_xmu / total_item_num;
sum_dy_o[blockIdx.x] = sum_dy;
sum_dy_xmu_o[blockIdx.x] = sum_dy_xmu;
}
}
......@@ -390,16 +392,24 @@ __global__ void batchnorm_backward_kernel(
const accscalar_t* __restrict__ mean,
const accscalar_t* __restrict__ inv_std,
const layerscalar_t* __restrict__ weight,
const accscalar_t* __restrict__ mean_dy,
const accscalar_t* __restrict__ mean_dy_xmu,
const accscalar_t* __restrict__ sum_dy,
const accscalar_t* __restrict__ sum_dy_xmu,
const int* __restrict__ numel,
scalar_t* __restrict__ grad_input,
const int64_t world_size,
const int ss,
const int bs) {
int64_t div = 0;
for (int i = 0; i < world_size; i++) {
div += numel[i];
}
auto m_c = static_cast<accscalar_t>(mean[blockIdx.x]);
auto m_dy_c = static_cast<accscalar_t>(mean_dy[blockIdx.x]);
//auto m_dy_c = static_cast<accscalar_t>(mean_dy[blockIdx.x]);
auto m_dy_c = static_cast<accscalar_t>(sum_dy[blockIdx.x]) / div;
auto factor_1_c = inv_std[blockIdx.x];
auto factor_2_c = (weight == NULL ? accscalar_t(1.0) : static_cast<accscalar_t>(weight[blockIdx.x])) * factor_1_c;
factor_1_c = factor_1_c * factor_1_c * mean_dy_xmu[blockIdx.x];
//factor_1_c = factor_1_c * factor_1_c * mean_dy_xmu[blockIdx.x];
factor_1_c = factor_1_c * factor_1_c * sum_dy_xmu[blockIdx.x] / div;
for (int batch_offset = blockIdx.y*blockDim.y+threadIdx.y; batch_offset < bs; batch_offset += gridDim.y*blockDim.y) {
int address_base = blockIdx.x*ss + batch_offset*gridDim.x*ss;
......@@ -559,13 +569,13 @@ template <typename scalar_t>
__global__ void welford_kernel_parallel(
const scalar_t* __restrict__ mean,
const scalar_t* __restrict__ var_biased,
const int* __restrict__ numel,
scalar_t* __restrict__ out_mean,
scalar_t* __restrict__ out_var,
scalar_t* __restrict__ inv_std,
const int world_size,
const int feature_size,
const float eps,
const int numel) {
const float eps) {
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < feature_size; i += gridDim.x * blockDim.x) {
// load data;
......@@ -574,7 +584,7 @@ __global__ void welford_kernel_parallel(
scalar_t m_2_n = 0;
int count = 0;
for (int j = 0; j < world_size; j++) {
welford_merge_element(count, x_mean, m_2_n, numel, mean[address], var_biased[address]*numel);
welford_merge_element(count, x_mean, m_2_n, numel[j], mean[address], var_biased[address]*numel[j]);
address += feature_size;
}
out_mean[i] = x_mean;
......@@ -694,8 +704,8 @@ __global__ void reduce_bn_c_last_kernel(
const scalar_t* __restrict__ grad_output,
const accscalar_t* __restrict__ mean,
const accscalar_t* __restrict__ inv_std,
accscalar_t* __restrict__ mean_dy,
accscalar_t* __restrict__ mean_dy_xmu,
accscalar_t* __restrict__ sum_dy_o,
accscalar_t* __restrict__ sum_dy_xmu_o,
layerscalar_t* __restrict__ grad_weight,
layerscalar_t* __restrict__ grad_bias,
volatile accscalar_t* staging_data,
......@@ -814,8 +824,10 @@ __global__ void reduce_bn_c_last_kernel(
if (grad_weight != NULL) {
grad_weight[c_offset] = static_cast<layerscalar_t>(sum_dy_xmu_th * factor);
}
mean_dy[c_offset] = sum_dy_th / reduction_size;
mean_dy_xmu[c_offset] = sum_dy_xmu_th / reduction_size;
//mean_dy[c_offset] = sum_dy_th / reduction_size;
//mean_dy_xmu[c_offset] = sum_dy_xmu_th / reduction_size;
sum_dy_o[c_offset] = sum_dy_th;
sum_dy_xmu_o[c_offset] = sum_dy_xmu_th;
}
}
} else {
......@@ -826,8 +838,10 @@ __global__ void reduce_bn_c_last_kernel(
if (grad_weight != NULL) {
grad_weight[c_offset] = static_cast<layerscalar_t>(sum_dy_xmu_th * factor);
}
mean_dy[c_offset] = sum_dy_th / reduction_size;
mean_dy_xmu[c_offset] = sum_dy_xmu_th / reduction_size;
//mean_dy[c_offset] = sum_dy_th / reduction_size;
//mean_dy_xmu[c_offset] = sum_dy_xmu_th / reduction_size;
sum_dy_o[c_offset] = sum_dy_th;
sum_dy_xmu_o[c_offset] = sum_dy_xmu_th;
}
}
}
......@@ -844,11 +858,17 @@ __global__ void batchnorm_backward_c_last_kernel(
const accscalar_t* __restrict__ mean,
const accscalar_t* __restrict__ inv_std,
const layerscalar_t* __restrict__ weight,
const accscalar_t* __restrict__ mean_dy,
const accscalar_t* __restrict__ mean_dy_xmu,
const accscalar_t* __restrict__ sum_dy,
const accscalar_t* __restrict__ sum_dy_xmu,
const int* __restrict__ numel,
scalar_t* __restrict__ grad_input,
const int64_t world_size,
const int reduction_size,
const int stride) {
int64_t div = 0;
for (int i = 0; i < world_size; i++) {
div += numel[i];
}
// tensor dimension (m,c)
// loop along m dimension
int inner_loop_stride = blockDim.y * gridDim.y;
......@@ -858,10 +878,10 @@ __global__ void batchnorm_backward_c_last_kernel(
int c_offset = blockIdx.x * blockDim.x + threadIdx.x;
auto m_c = mean[c_offset];
auto m_dy_c = mean_dy[c_offset];
auto m_dy_c = sum_dy[c_offset] / div;
auto factor_1_c = inv_std[c_offset];
auto factor_2_c = (weight == NULL? accscalar_t(1.0) : static_cast<accscalar_t>(weight[c_offset])) * factor_1_c;
factor_1_c = factor_1_c * factor_1_c * mean_dy_xmu[c_offset];
factor_1_c = factor_1_c * factor_1_c * sum_dy_xmu[c_offset] / div;
int loop_count = 1 + (reduction_size - 1) / (inner_loop_stride * PARALLEL_LOADS);
int address_base = m_offset * stride + c_offset;
......@@ -986,8 +1006,8 @@ std::vector<at::Tensor> reduce_bn_CUDA(
auto scalar_type = promote_scalartype(input);
at::Tensor mean_dy = at::empty({feature_size}, mean.options());
at::Tensor mean_dy_xmu = at::empty({feature_size}, mean.options());
at::Tensor sum_dy = at::empty({feature_size}, mean.options());
at::Tensor sum_dy_xmu = at::empty({feature_size}, mean.options());
at::Tensor grad_weight;
at::Tensor grad_bias;
......@@ -1018,8 +1038,8 @@ std::vector<at::Tensor> reduce_bn_CUDA(
grad_output.DATA_PTR<scalar_t_0>(),
mean.DATA_PTR<accscalar_t>(),
inv_std.DATA_PTR<accscalar_t>(),
mean_dy.DATA_PTR<accscalar_t>(),
mean_dy_xmu.DATA_PTR<accscalar_t>(),
sum_dy.DATA_PTR<accscalar_t>(),
sum_dy_xmu.DATA_PTR<accscalar_t>(),
weight.has_value() ? grad_weight.DATA_PTR<accscalar_t>() : NULL,
weight.has_value() ? grad_bias.DATA_PTR<accscalar_t>() : NULL,
batch_size,
......@@ -1039,8 +1059,8 @@ std::vector<at::Tensor> reduce_bn_CUDA(
grad_output.DATA_PTR<scalar_t_0>(),
mean.DATA_PTR<accscalar_t>(),
inv_std.DATA_PTR<accscalar_t>(),
mean_dy.DATA_PTR<accscalar_t>(),
mean_dy_xmu.DATA_PTR<accscalar_t>(),
sum_dy.DATA_PTR<accscalar_t>(),
sum_dy_xmu.DATA_PTR<accscalar_t>(),
weight.has_value() ? grad_weight.DATA_PTR<scalar_t_0>() : NULL,
weight.has_value() ? grad_bias.DATA_PTR<scalar_t_0>() : NULL,
batch_size,
......@@ -1049,7 +1069,7 @@ std::vector<at::Tensor> reduce_bn_CUDA(
);
}
return {mean_dy, mean_dy_xmu, grad_weight, grad_bias};
return {sum_dy, sum_dy_xmu, grad_weight, grad_bias};
}
at::Tensor batchnorm_backward_CUDA(
......@@ -1058,8 +1078,9 @@ at::Tensor batchnorm_backward_CUDA(
const at::Tensor mean,
const at::Tensor inv_std,
const at::optional<at::Tensor> weight,
const at::Tensor mean_dy,
const at::Tensor mean_dy_xmu) {
const at::Tensor sum_dy,
const at::Tensor sum_dy_xmu,
const at::Tensor count) {
const auto batch_size = input.size(0);
const auto feature_size = input.size(1);
......@@ -1088,9 +1109,11 @@ at::Tensor batchnorm_backward_CUDA(
mean.DATA_PTR<accscalar_t>(),
inv_std.DATA_PTR<accscalar_t>(),
weight.has_value() ? weight.value().DATA_PTR<accscalar_t>() : NULL,
mean_dy.DATA_PTR<accscalar_t>(),
mean_dy_xmu.DATA_PTR<accscalar_t>(),
sum_dy.DATA_PTR<accscalar_t>(),
sum_dy_xmu.DATA_PTR<accscalar_t>(),
count.DATA_PTR<int>(),
grad_input.DATA_PTR<scalar_t_0>(),
count.numel(),
space_size,
batch_size);
);
......@@ -1108,9 +1131,11 @@ at::Tensor batchnorm_backward_CUDA(
mean.DATA_PTR<accscalar_t>(),
inv_std.DATA_PTR<accscalar_t>(),
weight.has_value() ? weight.value().DATA_PTR<scalar_t_0>() : NULL,
mean_dy.DATA_PTR<accscalar_t>(),
mean_dy_xmu.DATA_PTR<accscalar_t>(),
sum_dy.DATA_PTR<accscalar_t>(),
sum_dy_xmu.DATA_PTR<accscalar_t>(),
count.DATA_PTR<int>(),
grad_input.DATA_PTR<scalar_t_0>(),
count.numel(),
space_size,
batch_size);
);
......@@ -1121,7 +1146,7 @@ at::Tensor batchnorm_backward_CUDA(
std::vector<at::Tensor> welford_parallel_CUDA(const at::Tensor mean_feature_nodes,
const at::Tensor var_biased,
int numel,
const at::Tensor numel,
const float eps) {
const auto world_size = mean_feature_nodes.size(0);
const auto feature_size = mean_feature_nodes.size(1);
......@@ -1142,13 +1167,13 @@ std::vector<at::Tensor> welford_parallel_CUDA(const at::Tensor mean_feature_node
welford_kernel_parallel<scalar_t_0><<<grid, block, 0, stream>>>(
mean_feature_nodes.DATA_PTR<scalar_t_0>(),
var_biased.DATA_PTR<scalar_t_0>(),
numel.DATA_PTR<int>(),
out_mean.DATA_PTR<scalar_t_0>(),
out_var.DATA_PTR<scalar_t_0>(),
inv_std.DATA_PTR<scalar_t_0>(),
world_size,
feature_size,
eps,
numel);
eps);
);
}
......@@ -1270,8 +1295,8 @@ std::vector<at::Tensor> reduce_bn_c_last_CUDA(
const auto stride = input.size(input.ndimension()-1);
const auto reduction_size = input.numel() / stride;
at::Tensor mean_dy = at::empty({stride}, mean.options());
at::Tensor mean_dy_xmu = at::empty({stride}, mean.options());
at::Tensor sum_dy = at::empty({stride}, mean.options());
at::Tensor sum_dy_xmu = at::empty({stride}, mean.options());
at::Tensor grad_weight;
at::Tensor grad_bias;
......@@ -1310,8 +1335,8 @@ std::vector<at::Tensor> reduce_bn_c_last_CUDA(
grad_output.DATA_PTR<scalar_t_0>(),
mean.DATA_PTR<accscalar_t>(),
inv_std.DATA_PTR<accscalar_t>(),
mean_dy.DATA_PTR<accscalar_t>(),
mean_dy_xmu.DATA_PTR<accscalar_t>(),
sum_dy.DATA_PTR<accscalar_t>(),
sum_dy_xmu.DATA_PTR<accscalar_t>(),
weight.has_value() ? grad_weight.DATA_PTR<accscalar_t>() : NULL,
weight.has_value() ? grad_bias.DATA_PTR<accscalar_t>() : NULL,
staging_data_ptr,
......@@ -1335,8 +1360,8 @@ std::vector<at::Tensor> reduce_bn_c_last_CUDA(
grad_output.DATA_PTR<scalar_t_0>(),
mean.DATA_PTR<accscalar_t>(),
inv_std.DATA_PTR<accscalar_t>(),
mean_dy.DATA_PTR<accscalar_t>(),
mean_dy_xmu.DATA_PTR<accscalar_t>(),
sum_dy.DATA_PTR<accscalar_t>(),
sum_dy_xmu.DATA_PTR<accscalar_t>(),
weight.has_value() ? grad_weight.DATA_PTR<scalar_t_0>() : NULL,
weight.has_value() ? grad_bias.DATA_PTR<scalar_t_0>() : NULL,
staging_data_ptr,
......@@ -1346,7 +1371,7 @@ std::vector<at::Tensor> reduce_bn_c_last_CUDA(
);
}
return {mean_dy, mean_dy_xmu, grad_weight, grad_bias};
return {sum_dy, sum_dy_xmu, grad_weight, grad_bias};
}
at::Tensor batchnorm_backward_c_last_CUDA(
......@@ -1355,8 +1380,9 @@ at::Tensor batchnorm_backward_c_last_CUDA(
const at::Tensor mean,
const at::Tensor inv_std,
const at::optional<at::Tensor> weight,
const at::Tensor mean_dy,
const at::Tensor mean_dy_xmu) {
const at::Tensor sum_dy,
const at::Tensor sum_dy_xmu,
const at::Tensor count) {
const auto stride = input.size(input.ndimension()-1);
const auto reduction_size = input.numel() / stride;
......@@ -1380,9 +1406,11 @@ at::Tensor batchnorm_backward_c_last_CUDA(
mean.DATA_PTR<accscalar_t>(),
inv_std.DATA_PTR<accscalar_t>(),
weight.has_value() ? weight.value().DATA_PTR<accscalar_t>() : NULL,
mean_dy.DATA_PTR<accscalar_t>(),
mean_dy_xmu.DATA_PTR<accscalar_t>(),
sum_dy.DATA_PTR<accscalar_t>(),
sum_dy_xmu.DATA_PTR<accscalar_t>(),
count.DATA_PTR<int>(),
grad_input.DATA_PTR<scalar_t_0>(),
count.numel(),
reduction_size,
stride);
);
......@@ -1401,9 +1429,11 @@ at::Tensor batchnorm_backward_c_last_CUDA(
mean.DATA_PTR<accscalar_t>(),
inv_std.DATA_PTR<accscalar_t>(),
weight.has_value() ? weight.value().DATA_PTR<scalar_t_0>() : NULL,
mean_dy.DATA_PTR<accscalar_t>(),
mean_dy_xmu.DATA_PTR<accscalar_t>(),
sum_dy.DATA_PTR<accscalar_t>(),
sum_dy_xmu.DATA_PTR<accscalar_t>(),
count.DATA_PTR<int>(),
grad_input.DATA_PTR<scalar_t_0>(),
count.numel(),
reduction_size,
stride);
);
......
import unittest
import os
import torch
from torch.optim import Optimizer
import apex
from apex.multi_tensor_apply import multi_tensor_applier
class RefLAMB(Optimizer):
r"""Implements Lamb algorithm.
It has been proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_.
Arguments:
params (iterable): iterable of parameters to optimize or dicts defining
parameter groups
lr (float, optional): learning rate (default: 1e-3)
betas (Tuple[float, float], optional): coefficients used for computing
running averages of gradient and its square (default: (0.9, 0.999))
eps (float, optional): term added to the denominator to improve
numerical stability (default: 1e-6)
weight_decay (float, optional): weight decay (L2 penalty) (default: 0.01)
.. _Large Batch Optimization for Deep Learning: Training BERT in 76 minutes:
https://arxiv.org/abs/1904.00962
"""
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6, weight_decay=0.01):
if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
if not 0.0 <= eps:
raise ValueError("Invalid epsilon value: {}".format(eps))
if not 0.0 <= betas[0] < 1.0:
raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
if not 0.0 <= betas[1] < 1.0:
raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
super(RefLAMB, self).__init__(params, defaults)
if multi_tensor_applier.available:
import amp_C
self.multi_tensor_l2norm=amp_C.multi_tensor_l2norm
# Skip buffer
self._dummy_overflow_buf = torch.cuda.IntTensor([0])
self.multi_tensor_lamb = amp_C.multi_tensor_lamb
else:
raise RuntimeError('apex.optimizers.FusedLAMB requires cuda extensions')
def step(self, closure=None):
"""Performs a single optimization step.
Arguments:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
"""
loss = None
if closure is not None:
loss = closure()
# create separate grad lists for fp32 and fp16 params
g_all_32, g_all_16 = [], []
for group in self.param_groups:
for p in group['params']:
if p.grad is None:
continue
if p.dtype == torch.float32:
g_all_32.append(p.grad.data)
elif p.dtype == torch.float16:
g_all_16.append(p.grad.data)
else:
raise RuntimeError('FusedLAMB only supports fp16 and fp32.')
g_norm_32, g_norm_16 = torch.zeros(1, device='cuda'), torch.zeros(1, device='cuda')
# compute grad norm for two lists
if len(g_all_32) > 0:
g_norm_32 = multi_tensor_applier(self.multi_tensor_l2norm,
self._dummy_overflow_buf,
[g_all_32], False)[0]
if len(g_all_16) > 0:
g_norm_16 = multi_tensor_applier(self.multi_tensor_l2norm,
self._dummy_overflow_buf,
[g_all_16], False)[0]
# blend two grad norms to get global grad norm
global_grad_norm = multi_tensor_applier(self.multi_tensor_l2norm,
self._dummy_overflow_buf,
[[g_norm_32, g_norm_16]],
False)[0]
max_grad_norm = 1.0
clipped_ratio = max_grad_norm / max(global_grad_norm, max_grad_norm)
for group in self.param_groups:
for p in group['params']:
if p.grad is None:
continue
p.grad.data *= clipped_ratio
grad = p.grad.data
if grad.is_sparse:
raise RuntimeError('Lamb does not support sparse gradients, consider SparseAdam instead.')
state = self.state[p]
# State initialization
if len(state) == 0:
state['step'] = 0
# Exponential moving average of gradient values
state['m'] = torch.zeros_like(p.data)
# Exponential moving average of squared gradient values
state['v'] = torch.zeros_like(p.data)
m_t, v_t = state['m'], state['v']
beta1, beta2 = group['betas']
state['step'] += 1
# m_t = beta1 * m + (1 - beta1) * g_t
m_t.mul_(beta1).add_(grad, alpha=1-beta1)
# v_t = beta2 * v + (1 - beta2) * (g_t * g_t)
v_t.mul_(beta2).addcmul_(grad, grad, value=1-beta2)
# Debiasing
m_t_hat = m_t / (1.0 - beta1 ** state['step'])
v_t_hat = v_t / (1.0 - beta2 ** state['step'])
update = m_t_hat / v_t_hat.sqrt().add(group['eps'])
if group['weight_decay'] != 0:
update.add_(p.data, alpha=group['weight_decay'])
trust_ratio = 1.0
w_norm = p.data.pow(2).sum().sqrt()
g_norm = update.pow(2).sum().sqrt()
if w_norm > 0 and g_norm > 0:
trust_ratio = w_norm / g_norm
state['w_norm'] = w_norm
state['g_norm'] = g_norm
state['trust_ratio'] = trust_ratio
step_size = group['lr']
p.data.add_(update, alpha=-step_size*trust_ratio)
return loss
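# A compact summary of the update rule implemented by RefLAMB.step() above (reference only):
#   1. globally clip gradients to max-norm 1.0: g *= 1 / max(||g||_global, 1)
#   2. m_t = beta1*m + (1-beta1)*g ;  v_t = beta2*v + (1-beta2)*g*g  (then bias-correct both)
#   3. update = m_t_hat / (sqrt(v_t_hat) + eps) + weight_decay * w
#   4. trust_ratio = ||w|| / ||update||  (1.0 if either norm is zero)
#   5. w <- w - lr * trust_ratio * update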
class TestFusedLAMB(unittest.TestCase):
def setUp(self, max_abs_diff=1e-3, max_rel_diff=1, iters=7):
self.max_abs_diff = max_abs_diff
self.max_rel_diff = max_rel_diff
self.iters = iters
torch.cuda.manual_seed(9876)
def tearDown(self):
pass
def gen_param_optim(self, tensors, lamb_option):
ref_param = []
tst_param = []
for tensor in tensors:
ref_param.append(torch.nn.Parameter(tensor.clone()))
tst_param.append(torch.nn.Parameter(tensor.clone()))
ref_optim = RefLAMB(ref_param, **lamb_option)
tst_optim = apex.optimizers.FusedLAMB(tst_param, use_nvlamb=True, **lamb_option)
return (ref_param, tst_param, ref_optim, tst_optim)
def gen_grad(self, ref_param, tst_param):
for p_ref, p_tst in zip(ref_param, tst_param):
p_ref.grad = torch.rand_like(p_ref)
p_tst.grad = p_ref.grad
def gen_mixed_grad(self, ref_param, tst_param, scale=1.0):
half_grads = []
for p_ref, _ in zip(ref_param, tst_param):
half_grads.append(torch.rand_like(p_ref).half())
p_ref.grad = half_grads[-1].float() / scale
return half_grads
def get_max_diff(self, ref_param, tst_param):
max_abs_diff = max_rel_diff = 0
for p_ref, p_tst in zip(ref_param, tst_param):
max_abs_diff_p = (p_ref - p_tst).abs().max().item()
max_rel_diff_p = ((p_ref - p_tst) / p_ref).abs().max().item()
if max_abs_diff_p > max_abs_diff: max_abs_diff = max_abs_diff_p
if max_rel_diff_p > max_rel_diff: max_rel_diff = max_rel_diff_p
return max_abs_diff, max_rel_diff
def gen_single_type_test(self, param_type=torch.float):
nelem = 278011
tensor = torch.rand(nelem, dtype=param_type, device='cuda')
weight_decay = [0, 0.01]
for wd in weight_decay:
lamb_option = {'lr':5e-4, 'betas':(0.9, 0.999), 'eps':1e-08, 'weight_decay':wd}
ref_param, tst_param, ref_optim, tst_optim = \
self.gen_param_optim([tensor], lamb_option)
for i in range(self.iters):
self.gen_grad(ref_param, tst_param)
ref_optim.step()
tst_optim.step()
max_abs_diff, max_rel_diff = self.get_max_diff(ref_param, tst_param)
self.assertLessEqual(max_abs_diff, self.max_abs_diff)
self.assertLessEqual(max_rel_diff, self.max_rel_diff)
def test_float(self):
self.gen_single_type_test(param_type=torch.float)
@unittest.skip("PyTorch optimizer is not numerically correct for fp16")
def test_half(self):
self.gen_single_type_test(param_type=torch.float16)
def test_multi_params(self):
sizes = [[4096, 1024], [4096], [4096, 2048], [32320, 1024], [1]]
weight_decay = [0, 0.01]
for wd in weight_decay:
lamb_option = {'lr':5e-4, 'betas':(0.9, 0.999), 'eps':1e-08, 'weight_decay':wd}
tensors = []
for size in sizes:
tensors.append(torch.rand(size, dtype=torch.float, device='cuda'))
ref_param, tst_param, ref_optim, tst_optim = \
self.gen_param_optim(tensors, lamb_option)
for i in range(self.iters):
self.gen_grad(ref_param, tst_param)
ref_optim.step()
tst_optim.step()
max_abs_diff, max_rel_diff = self.get_max_diff(ref_param, tst_param)
self.assertLessEqual(max_abs_diff, self.max_abs_diff)
self.assertLessEqual(max_rel_diff, self.max_rel_diff)
def test_lamb_option(self):
nelem = 1
tensor = torch.rand(nelem, dtype=torch.float, device='cuda')
weight_decay = [0, 0.01]
for wd in weight_decay:
lamb_option = {'lr':0.01, 'betas':(0.6, 0.9), 'eps':3e-06, 'weight_decay':wd}
ref_param, tst_param, ref_optim, tst_optim = \
self.gen_param_optim([tensor], lamb_option)
for i in range(self.iters):
self.gen_grad(ref_param, tst_param)
ref_optim.step()
tst_optim.step()
max_abs_diff, max_rel_diff = self.get_max_diff(ref_param, tst_param)
self.assertLessEqual(max_abs_diff, self.max_abs_diff)
self.assertLessEqual(max_rel_diff, self.max_rel_diff)
if __name__ == '__main__':
script_path = os.path.dirname(os.path.realpath(__file__))
unittest.main()
......@@ -35,6 +35,7 @@ inp = (np.random.randn(batch_size, feature_size, space_size, space_size)).astype
grad = (np.random.randn(batch_size, feature_size, space_size, space_size)).astype(dtype)
weight = (np.random.randn(feature_size)).astype(dtype)
bias = (np.random.randn(feature_size)).astype(dtype)
count = torch.cuda.IntTensor([batch_size*space_size**2])
type_tensor = torch.cuda.FloatTensor
ref_tensor = torch.cuda.DoubleTensor
......@@ -110,17 +111,19 @@ grad_output2_r = ref_tensor(grad)
grad_bias_r = grad_output_r.sum(1)
grad_weight_r = ((inp2_r - m.view(-1, 1, 1)) * torch.rsqrt(b_v.view(-1,1,1) + eps) * grad_output2_r).transpose(1,0).contiguous().view(feature_size, -1).sum(1)
sum_dy_r = grad_output_r.sum(1)
mean_dy_r = grad_output_r.mean(1)
sum_dy_xmu_r = ((inp2_r - m.view(-1, 1, 1)) * grad_output2_r).transpose(1,0).contiguous().view(feature_size, -1).sum(1)
mean_dy_xmu_r = ((inp2_r - m.view(-1, 1, 1)) * grad_output2_r).transpose(1,0).contiguous().view(feature_size, -1).mean(1)
grad_input_r = (grad_output2_r - mean_dy_r.view(-1, 1, 1) - (inp2_r - m.view(-1, 1, 1)) / (b_v.view(-1,1,1) + eps) * mean_dy_xmu_r.view(-1, 1, 1) ) * torch.rsqrt(b_v.view(-1,1,1) + eps) * weight_r.view(-1,1,1)
mean_dy, mean_dy_xmu, grad_weight, grad_bias = syncbn.reduce_bn(grad_output_t, inp_t, mean, inv_std, weight_t)
grad_input = syncbn.batchnorm_backward(grad_output_t, inp_t, mean, inv_std, weight_t, mean_dy, mean_dy_xmu)
sum_dy, sum_dy_xmu, grad_weight, grad_bias = syncbn.reduce_bn(grad_output_t, inp_t, mean, inv_std, weight_t)
grad_input = syncbn.batchnorm_backward(grad_output_t, inp_t, mean, inv_std, weight_t, sum_dy, sum_dy_xmu, count)
sbn_result = compare("comparing bias grad: ", grad_bias, grad_bias_r, error) and sbn_result
sbn_result = compare("comparing weight grad: ", grad_weight, grad_weight_r, error) and sbn_result
sbn_result = compare("comparing mean_dy grad: ", mean_dy, mean_dy_r, error) and sbn_result
sbn_result = compare("comparing mean_dy_xmu grad: ", mean_dy_xmu, mean_dy_xmu_r, error) and sbn_result
sbn_result = compare("comparing sum_dy grad: ", sum_dy, sum_dy_r, error) and sbn_result
sbn_result = compare("comparing sum_dy_xmu grad: ", sum_dy_xmu, sum_dy_xmu_r, error) and sbn_result
sbn_result = compare("comparing input grad: ", grad_input, grad_input_r, error) and sbn_result
compare("comparing bn input grad: ", inp_bn.grad, grad_input_r, error)
sbn_result = compare("comparing sbn input grad: ", inp_sbn.grad, grad_input_r, error) and sbn_result
......
import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from apex.parallel import SyncBatchNorm as ApexSyncBatchNorm
import argparse
import os
import numpy as np
var_batch = 16
def compare(desc, inp1, inp2, error= 1e-5):
a = inp1.clone().detach().cpu().numpy()
b = inp2.clone().detach().cpu().numpy()
close = np.allclose(a,b, error, error)
if not close:
print(desc, close)
z = a - b
index = (np.abs(z) >= error + error * np.abs(b)).nonzero()
print("dif : ", z[index])
print("inp1 : ", a[index])
print("inp2 : ", b[index])
return close
parser = argparse.ArgumentParser()
parser.add_argument('--local_rank', type=int, default=0)
parser.add_argument('--apex', action='store_true')
args = parser.parse_args()
torch.manual_seed(2809)
# Setup DDP
torch.cuda.set_device(args.local_rank)
device = torch.device('cuda:{}'.format(args.local_rank))
torch.distributed.init_process_group(
'nccl',
init_method='env://',
rank=args.local_rank,
)
# Setup model
if args.apex:
model = nn.Sequential(
nn.Conv2d(3, 6, 3, 1, 1),
ApexSyncBatchNorm(6)
)
else:
model = nn.Sequential(
nn.Conv2d(3, 6, 3, 1, 1),
nn.SyncBatchNorm(6)
)
# Setup reference model
model_reference = nn.Sequential(
nn.Conv2d(3, 6, 3, 1, 1),
nn.BatchNorm2d(6)
)
with torch.no_grad():
model_reference[0].weight.copy_(model[0].weight)
model_reference[0].bias.copy_(model[0].bias)
model_reference.to(device)
model = model.to(device)
model = DDP(model, device_ids=[args.local_rank], output_device=args.local_rank)
global_batch_size = var_batch + 8
# Create random data
if args.local_rank == 0:
data = torch.randn(var_batch, 3, 8, 8, device=device, dtype=torch.float) * 50.0
grad = torch.randint(0, 10, (var_batch, 6, 8, 8), device=device, dtype=torch.float) / 10.0
else:
data = torch.randn(8, 3, 8, 8, device=device)
grad = torch.randint(0, 10, (8, 6, 8, 8), device=device, dtype=torch.float) / 10.0
data.requires_grad_()
data.retain_grad()
weighted_gradient = True
# DDP forward/backward
output = model(data)
if weighted_gradient:
output.backward(grad * 2 / global_batch_size)
else:
output.backward(grad / output.size(0))
d_list = [torch.randn(8, 3, 8, 8, device=device) for i in range(int(os.environ['WORLD_SIZE']))]
y_list = [torch.randn(8, 6, 8, 8, device=device) for i in range(int(os.environ['WORLD_SIZE']))]
dgrad_list = [torch.randn(8, 3, 8, 8, device=device) for i in range(int(os.environ['WORLD_SIZE']))]
grad_list = [torch.randn(8, 6, 8, 8, device=device) for i in range(int(os.environ['WORLD_SIZE']))]
if args.local_rank == 0:
# placeholder, these random data will later be discarded.
torch.distributed.all_gather(d_list, torch.randn(8, 3, 8, 8, device=device))
torch.distributed.all_gather(y_list, torch.randn(8, 6, 8, 8, device=device))
torch.distributed.all_gather(dgrad_list, torch.randn(8, 3, 8, 8, device=device))
torch.distributed.all_gather(grad_list, torch.randn(8, 6, 8, 8, device=device))
else:
torch.distributed.all_gather(d_list, data)
torch.distributed.all_gather(y_list, output)
torch.distributed.all_gather(dgrad_list, data.grad)
torch.distributed.all_gather(grad_list, grad)
torch.distributed.barrier()
if args.local_rank == 0:
ref_tensor = d_list[1:]
ref_tensor.insert(0, data)
assert(ref_tensor[0].equal(data))
ref_tensor = torch.cat(ref_tensor, 0)
ref_tensor = ref_tensor.detach()
ref_tensor.requires_grad_()
ref_tensor.retain_grad()
# Reference forward/backward
output_reference = model_reference(ref_tensor)
grad_tensor = grad_list[1:]
grad_tensor.insert(0, grad)
assert(grad_tensor[0].equal(grad))
grad_tensor = torch.cat(grad_tensor, 0)
if weighted_gradient:
output_reference.backward(grad_tensor / output_reference.size(0))
else:
output_reference.backward(grad_tensor / output_reference.size(0))
dgrad_tensor = dgrad_list[1:]
dgrad_tensor.insert(0, data.grad)
dgrad_tensor = torch.cat(dgrad_tensor, 0)
# check output
output_tensor = y_list[1:]
output_tensor.insert(0, output)
output_tensor = torch.cat(output_tensor, 0)
passed = True
passed = passed and compare("check output",
output_tensor,
output_reference)
# check stats
passed = passed and compare("check running mean failed",
model_reference[1].running_mean,
model.module[1].running_mean)
passed = passed and compare("check running var failed",
model_reference[1].running_var,
model.module[1].running_var)
passed = passed and compare("bn wgrad check failed!",
model_reference[1].weight.grad,
model.module[1].weight.grad, 1e-6)
passed = passed and compare("conv wgrad check failed!",
model_reference[0].weight.grad,
model.module[0].weight.grad)
# can't really compare dgrad directly, as we need to scale it to account for
# DDP
# passed = passed and compare("dgrad check failed!", ref_tensor.grad, dgrad_tensor)
if passed:
print("====SBN two gpu with different batches test passed")
else:
assert("*failed two gpu with different batches tests*")
......@@ -114,6 +114,11 @@ grad_sbn = grad_output_t.clone().detach()
out_sbn = sbn(inp_sbn[start:finish])
out_sbn.backward(grad_sbn[start:finish])
count = [ space_size**2 * ( (i+1) * batch_size // args.world_size - i * batch_size // args.world_size ) for i in range(0, args.world_size)]
count = torch.cuda.IntTensor(count)
print("--- count : " , count)
sbn_result = True
bn_result = True
......@@ -136,18 +141,20 @@ grad_output2_r = ref_tensor(grad)
grad_bias_r = grad_output_r.sum(1)
grad_weight_r = ((inp2_r - m.view(-1, 1, 1)) * torch.rsqrt(b_v.view(-1,1,1) + eps) * grad_output2_r).transpose(1,0).contiguous().view(feature_size, -1).sum(1)
sum_dy_r = grad_output_r.sum(1)
mean_dy_r = grad_output_r.mean(1)
mean_dy_xmu_r = ((inp2_r - m.view(-1, 1, 1)) * grad_output2_r).transpose(1,0).contiguous().view(feature_size, -1).mean(1)
sum_dy_xmu_r = ((inp2_r - m.view(-1, 1, 1)) * grad_output2_r).transpose(1,0).contiguous().view(feature_size, -1).sum(1)
grad_input_r = (grad_output2_r - mean_dy_r.view(-1, 1, 1) - (inp2_r - m.view(-1, 1, 1)) / (b_v.view(-1,1,1) + eps) * mean_dy_xmu_r.view(-1, 1, 1) ) * torch.rsqrt(b_v.view(-1,1,1) + eps) * weight_r.view(-1,1,1)
mean_dy, mean_dy_xmu, grad_weight, grad_bias = syncbn.reduce_bn(grad_output_t, inp_t, mean, inv_std, weight_t)
grad_input = syncbn.batchnorm_backward(grad_output_t, inp_t, mean, inv_std, weight_t, mean_dy, mean_dy_xmu)
sum_dy, sum_dy_xmu, grad_weight, grad_bias = syncbn.reduce_bn(grad_output_t, inp_t, mean, inv_std, weight_t)
grad_input = syncbn.batchnorm_backward(grad_output_t, inp_t, mean, inv_std, weight_t, sum_dy, sum_dy_xmu, count)
if args.local_rank == 0:
sbn_result = compare("comparing bias grad: ", grad_bias, grad_bias_r, error) and sbn_result
sbn_result = compare("comparing weight grad: ", grad_weight, grad_weight_r, error) and sbn_result
sbn_result = compare("comparing mean_dy grad: ", mean_dy, mean_dy_r, error) and sbn_result
sbn_result = compare("comparing mean_dy_xmu grad: ", mean_dy_xmu, mean_dy_xmu_r, error) and sbn_result
sbn_result = compare("comparing sum_dy grad: ", sum_dy, sum_dy_r, error) and sbn_result
sbn_result = compare("comparing sum_dy_xmu grad: ", sum_dy_xmu, sum_dy_xmu_r, error) and sbn_result
sbn_result = compare("comparing input grad: ", grad_input, grad_input_r, error) and sbn_result
compare("comparing bn input grad: ", inp_bn.grad, grad_input_r, error)
......
......@@ -3,5 +3,6 @@ python single_gpu_unit_test.py
python test_batchnorm1d.py
python -m torch.distributed.launch --nproc_per_node=2 two_gpu_unit_test.py
python -m torch.distributed.launch --nproc_per_node=2 two_gpu_unit_test.py --fp16
python -m torch.distributed.launch --nproc_per_node=2 two_gpu_test_different_batch_size.py --apex
#beware, you need a system with at least 4 gpus to test group_size<world_size
#python -m torch.distributed.launch --nproc_per_node=4 test_groups.py --group_size=2