Unverified Commit 1ff54b8f authored by jjsjann123, committed by GitHub

[sync BN] (#792)

* [sync BN]

Support non-uniform batch sizes across the process group.

TODO: tests should be added once the code is cleaned up.

* updating unit tests

* new unit tests for different inputs

* cleaning
parent 43a6f9fe
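The forward-pass change in `SyncBatchnormFunction` below gathers a per-rank element count alongside the per-channel mean and biased variance, so ranks no longer need identical batch sizes. A minimal sketch of that idea with plain `torch.distributed` (helper name, shapes, and the eager-mode statistics here are illustrative, not apex's fused kernels):

```python
import torch
import torch.distributed as dist

def gather_local_stats(x, process_group=None):
    """Illustrative sketch: gather per-rank mean / biased var / element count.

    `x` is assumed to be an (N, C, ...) tensor; ranks may have different N.
    """
    group = process_group if process_group is not None else dist.group.WORLD
    world_size = dist.get_world_size(group)

    reduce_dims = [0] + list(range(2, x.dim()))
    mean = x.mean(dim=reduce_dims)                        # shape (C,)
    var_biased = x.var(dim=reduce_dims, unbiased=False)   # shape (C,)
    count = torch.tensor([x.numel() // x.size(1)],
                         dtype=torch.int32, device=x.device)

    # counts are gathered just like the statistics, since they can differ per rank
    mean_l = [torch.empty_like(mean) for _ in range(world_size)]
    var_l = [torch.empty_like(var_biased) for _ in range(world_size)]
    count_l = [torch.empty_like(count) for _ in range(world_size)]
    dist.all_gather(mean_l, mean, group)
    dist.all_gather(var_l, var_biased, group)
    dist.all_gather(count_l, count, group)
    return torch.stack(mean_l), torch.stack(var_l), torch.cat(count_l)
```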
......@@ -28,16 +28,24 @@ class SyncBatchnormFunction(Function):
if torch.distributed.is_initialized():
if not process_group:
process_group = torch.distributed.group.WORLD
device = mean.device
world_size = torch.distributed.get_world_size(process_group)
mean_all = torch.empty(world_size, mean.size(0), dtype=mean.dtype, device=mean.device)
var_all = torch.empty(world_size, var_biased.size(0), dtype=var_biased.dtype, device=var_biased.device)
mean_all = torch.empty(world_size, mean.size(0), dtype=mean.dtype, device=device)
var_all = torch.empty(world_size, var_biased.size(0), dtype=var_biased.dtype, device=device)
count_all = torch.cuda.IntTensor(world_size, device=device)
mean_l = [mean_all.narrow(0, i, 1) for i in range(world_size)]
var_l = [var_all.narrow(0, i, 1) for i in range(world_size)]
count_l = [count_all.narrow(0, i, 1) for i in range(world_size)]
torch.distributed.all_gather(mean_l, mean, process_group)
torch.distributed.all_gather(var_l, var_biased, process_group)
mean, var, inv_std = syncbn.welford_parallel(mean_all, var_all, count, eps)
# TODO(Jie): should do fp32 math instead!
torch.distributed.all_gather(
count_l,
torch.cuda.IntTensor([count], device=device),
process_group)
mean, var, inv_std = syncbn.welford_parallel(mean_all, var_all, count_all, eps)
else:
device = mean.device
count_all = torch.cuda.IntTensor([count], device=device)
inv_std = 1.0 / torch.sqrt(var_biased + eps)
var = var_biased * (count) / (count-1)
......@@ -52,7 +60,7 @@ class SyncBatchnormFunction(Function):
mean = running_mean.data
inv_std = 1.0 / torch.sqrt(running_variance.data + eps)
ctx.save_for_backward(input, weight, mean, inv_std, z, bias)
ctx.save_for_backward(input, weight, mean, inv_std, z, bias, count_all)
ctx.process_group = process_group
ctx.channel_last = channel_last
ctx.world_size = world_size
......@@ -71,7 +79,7 @@ class SyncBatchnormFunction(Function):
# mini batch mean & var are calculated by forward path.
# mu = 1./N*np.sum(h, axis = 0)
# var = 1./N*np.sum((h-mu)**2, axis = 0)
saved_input, weight, mean, inv_std, z, bias = ctx.saved_tensors
saved_input, weight, mean, inv_std, z, bias, count = ctx.saved_tensors
process_group = ctx.process_group
channel_last = ctx.channel_last
world_size = ctx.world_size
......@@ -83,26 +91,24 @@ class SyncBatchnormFunction(Function):
if isinstance(z, torch.Tensor) and ctx.needs_input_grad[1]:
grad_z = grad_output.clone()
# TODO(jie): why do I have to clone here? life time of grad_output?
# TODO: update kernel to not pre_divide by item_num
if channel_last:
mean_dy, mean_dy_xmu, grad_weight, grad_bias = syncbn.reduce_bn_c_last(grad_output, saved_input, mean, inv_std, weight)
sum_dy, sum_dy_xmu, grad_weight, grad_bias = syncbn.reduce_bn_c_last(grad_output, saved_input, mean, inv_std, weight)
else:
mean_dy, mean_dy_xmu, grad_weight, grad_bias = syncbn.reduce_bn(grad_output, saved_input, mean, inv_std, weight)
sum_dy, sum_dy_xmu, grad_weight, grad_bias = syncbn.reduce_bn(grad_output, saved_input, mean, inv_std, weight)
# calculate grad_input
if ctx.needs_input_grad[0]:
if torch.distributed.is_initialized():
torch.distributed.all_reduce(
mean_dy, ReduceOp.SUM, process_group)
mean_dy = mean_dy / world_size
sum_dy, ReduceOp.SUM, process_group)
torch.distributed.all_reduce(
mean_dy_xmu, ReduceOp.SUM, process_group)
mean_dy_xmu = mean_dy_xmu / world_size
sum_dy_xmu, ReduceOp.SUM, process_group)
if channel_last:
grad_input = syncbn.batchnorm_backward_c_last(grad_output, saved_input, mean, inv_std, weight, mean_dy, mean_dy_xmu)
grad_input = syncbn.batchnorm_backward_c_last(grad_output, saved_input, mean, inv_std, weight, sum_dy, sum_dy_xmu, count)
else:
grad_input = syncbn.batchnorm_backward(grad_output, saved_input, mean, inv_std, weight, mean_dy, mean_dy_xmu)
grad_input = syncbn.batchnorm_backward(grad_output, saved_input, mean, inv_std, weight, sum_dy, sum_dy_xmu, count)
if weight is None or not ctx.needs_input_grad[2]:
grad_weight = None
......
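On the backward side, the kernels now return raw `sum_dy` / `sum_dy_xmu` instead of pre-divided means, the all-reduce sums them across ranks, and the division uses the total element count saved from the forward pass rather than `world_size`. A rough sketch of that reduction (hypothetical helper in eager PyTorch, not the fused apex kernels):

```python
import torch
import torch.distributed as dist

def reduce_backward_stats(grad_output, x, mean, count_all, process_group=None):
    """Illustrative sketch: all-reduce raw sums, then divide by the global count.

    Dividing by world_size (the previous behaviour) is only correct when every
    rank contributes the same number of elements; summing the gathered counts
    handles non-uniform batch sizes.
    """
    group = process_group if process_group is not None else dist.group.WORLD
    reduce_dims = [0] + list(range(2, x.dim()))
    mean_b = mean.view(1, -1, *([1] * (x.dim() - 2)))     # broadcast per channel

    sum_dy = grad_output.sum(dim=reduce_dims)
    sum_dy_xmu = ((x - mean_b) * grad_output).sum(dim=reduce_dims)

    if dist.is_initialized():
        dist.all_reduce(sum_dy, dist.ReduceOp.SUM, group)
        dist.all_reduce(sum_dy_xmu, dist.ReduceOp.SUM, group)

    total = count_all.sum()            # total element count across all ranks
    return sum_dy / total, sum_dy_xmu / total
```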
......@@ -12,7 +12,7 @@ std::vector<at::Tensor> welford_mean_var_CUDA(const at::Tensor input);
// implemented using welford
std::vector<at::Tensor> welford_parallel_CUDA(const at::Tensor mean_feature_nodes,
const at::Tensor var_biased_feature_nodes,
int numel,
const at::Tensor numel,
const float eps);
// elementwise BN operation, returns output
......@@ -24,7 +24,7 @@ at::Tensor batchnorm_forward_CUDA(const at::Tensor input,
const at::optional<at::Tensor> weight,
const at::optional<at::Tensor> shift);
// backward BN operation, returns {mean_dy, mean_dy_xmu, grad_weight, grad_bias}
// backward BN operation, returns {sum_dy, sum_dy_xmu, grad_weight, grad_bias}
// grad_output/input should have identical data type;
// mean/inv_std have promoted data type (dtype==fp16?fp32:dtype)
// implemented using kahan summation
......@@ -36,14 +36,15 @@ std::vector<at::Tensor> reduce_bn_CUDA(const at::Tensor grad_output,
// elementwise backward BN operation, returns grad_input
// grad_output/input/weight precision could be fp16/fp32;
// mean/inv_std/mean_dy/mean_dy_xmu precision is fp32
// mean/inv_std/sum_dy/sum_dy_xmu precision is fp32
at::Tensor batchnorm_backward_CUDA(const at::Tensor grad_output,
const at::Tensor input,
const at::Tensor mean,
const at::Tensor inv_std,
const at::optional<at::Tensor> weight,
const at::Tensor mean_dy,
const at::Tensor mean_dy_xmu);
const at::Tensor sum_dy,
const at::Tensor sum_dy_xmu,
const at::Tensor count);
// returns {mean, biased_var}
// implemented using welford
......@@ -62,7 +63,7 @@ at::Tensor batchnorm_forward_c_last_CUDA(const at::Tensor input,
const at::optional<at::Tensor> shift,
const bool fuse_relu);
// backward BN operation, returns {mean_dy, mean_dy_xmu, grad_weight, grad_bias}
// backward BN operation, returns {sum_dy, sum_dy_xmu, grad_weight, grad_bias}
// grad_output/input should have identical data type;
// mean/inv_std have promoted data type (dtype==fp16?fp32:dtype)
// expect data to be in n+c format (channel last) and applies CUDNN_BATCHNORM_SPATIAL
......@@ -74,15 +75,16 @@ std::vector<at::Tensor> reduce_bn_c_last_CUDA(const at::Tensor grad_output,
// elementwise backward BN operation, returns grad_input
// grad_output/input/weight precision could be fp16/fp32;
// mean/inv_std/mean_dy/mean_dy_xmu precision is fp32
// mean/inv_std/sum_dy/sum_dy_xmu precision is fp32
// expect data to be in n+c format (channel last) and applies CUDNN_BATCHNORM_SPATIAL
at::Tensor batchnorm_backward_c_last_CUDA(const at::Tensor grad_output,
const at::Tensor input,
const at::Tensor mean,
const at::Tensor inv_std,
const at::optional<at::Tensor> weight,
const at::Tensor mean_dy,
const at::Tensor mean_dy_xmu);
const at::Tensor sum_dy,
const at::Tensor sum_dy_xmu,
const at::Tensor count);
at::Tensor relu_backward_c_last_CUDA(const at::Tensor grad_output,
const at::Tensor input,
......
......@@ -327,15 +327,15 @@ __global__ void reduce_bn_kernel(
const scalar_t* __restrict__ grad_output,
const accscalar_t* __restrict__ mean,
const accscalar_t* __restrict__ inv_std,
accscalar_t* __restrict__ mean_dy,
accscalar_t* __restrict__ mean_dy_xmu,
accscalar_t* __restrict__ sum_dy_o,
accscalar_t* __restrict__ sum_dy_xmu_o,
layerscalar_t* __restrict__ grad_weight,
layerscalar_t* __restrict__ grad_bias,
const int bs,
const int fs,
const int ss) {
static __shared__ int s_mem[64];
int total_item_num = bs * ss;
//int total_item_num = bs * ss;
int thread_id = threadIdx.y*blockDim.x + threadIdx.x;
......@@ -377,8 +377,10 @@ __global__ void reduce_bn_kernel(
if (grad_weight != NULL) {
grad_weight[blockIdx.x] = static_cast<layerscalar_t>(sum_dy_xmu * factor);
}
mean_dy[blockIdx.x] = sum_dy / total_item_num;
mean_dy_xmu[blockIdx.x] = sum_dy_xmu / total_item_num;
//mean_dy[blockIdx.x] = sum_dy / total_item_num;
//mean_dy_xmu[blockIdx.x] = sum_dy_xmu / total_item_num;
sum_dy_o[blockIdx.x] = sum_dy;
sum_dy_xmu_o[blockIdx.x] = sum_dy_xmu;
}
}
......@@ -390,16 +392,24 @@ __global__ void batchnorm_backward_kernel(
const accscalar_t* __restrict__ mean,
const accscalar_t* __restrict__ inv_std,
const layerscalar_t* __restrict__ weight,
const accscalar_t* __restrict__ mean_dy,
const accscalar_t* __restrict__ mean_dy_xmu,
const accscalar_t* __restrict__ sum_dy,
const accscalar_t* __restrict__ sum_dy_xmu,
const int* __restrict__ numel,
scalar_t* __restrict__ grad_input,
const int64_t world_size,
const int ss,
const int bs) {
int64_t div = 0;
for (int i = 0; i < world_size; i++) {
div += numel[i];
}
auto m_c = static_cast<accscalar_t>(mean[blockIdx.x]);
auto m_dy_c = static_cast<accscalar_t>(mean_dy[blockIdx.x]);
//auto m_dy_c = static_cast<accscalar_t>(mean_dy[blockIdx.x]);
auto m_dy_c = static_cast<accscalar_t>(sum_dy[blockIdx.x]) / div;
auto factor_1_c = inv_std[blockIdx.x];
auto factor_2_c = (weight == NULL ? accscalar_t(1.0) : static_cast<accscalar_t>(weight[blockIdx.x])) * factor_1_c;
factor_1_c = factor_1_c * factor_1_c * mean_dy_xmu[blockIdx.x];
//factor_1_c = factor_1_c * factor_1_c * mean_dy_xmu[blockIdx.x];
factor_1_c = factor_1_c * factor_1_c * sum_dy_xmu[blockIdx.x] / div;
for (int batch_offset = blockIdx.y*blockDim.y+threadIdx.y; batch_offset < bs; batch_offset += gridDim.y*blockDim.y) {
int address_base = blockIdx.x*ss + batch_offset*gridDim.x*ss;
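For reference, with `div` equal to the sum of the per-rank counts, the elementwise gradient computed in `batchnorm_backward_kernel` corresponds to the standard batch-norm backward formula; a NumPy restatement (illustrative helper, broadcasting over one channel's elements):

```python
import numpy as np

# Per-element backward formula the kernel implements, with N == div (the summed
# per-rank counts). x and dy are the elements of one channel; the remaining
# arguments are that channel's scalars.
def bn_backward_elementwise(dy, x, mean, inv_std, weight, sum_dy, sum_dy_xmu, N):
    m_dy_c = sum_dy / N                              # global mean of dy
    factor_1_c = inv_std * inv_std * sum_dy_xmu / N  # matches factor_1_c above
    factor_2_c = (1.0 if weight is None else weight) * inv_std
    return (dy - m_dy_c - (x - mean) * factor_1_c) * factor_2_c
```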
......@@ -559,13 +569,13 @@ template <typename scalar_t>
__global__ void welford_kernel_parallel(
const scalar_t* __restrict__ mean,
const scalar_t* __restrict__ var_biased,
const int* __restrict__ numel,
scalar_t* __restrict__ out_mean,
scalar_t* __restrict__ out_var,
scalar_t* __restrict__ inv_std,
const int world_size,
const int feature_size,
const float eps,
const int numel) {
const float eps) {
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < feature_size; i += gridDim.x * blockDim.x) {
// load data;
......@@ -574,7 +584,7 @@ __global__ void welford_kernel_parallel(
scalar_t m_2_n = 0;
int count = 0;
for (int j = 0; j < world_size; j++) {
welford_merge_element(count, x_mean, m_2_n, numel, mean[address], var_biased[address]*numel);
welford_merge_element(count, x_mean, m_2_n, numel[j], mean[address], var_biased[address]*numel[j]);
address += feature_size;
}
out_mean[i] = x_mean;
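The per-channel merge in `welford_kernel_parallel` is Chan's parallel update applied once per rank, now with each rank's own element count; `var_biased * count` recovers the sum of squared deviations that `welford_merge_element` accumulates. A NumPy sketch of the same merge (illustrative names):

```python
import numpy as np

def welford_merge(means, vars_biased, counts):
    """Merge per-rank (mean, biased var, count) into global per-channel stats."""
    count = 0
    mean = np.zeros_like(means[0])
    m2 = np.zeros_like(means[0])            # running sum of squared deviations
    for m_b, v_b, n_b in zip(means, vars_biased, counts):
        new_count = count + n_b
        delta = m_b - mean
        mean = mean + delta * (n_b / new_count)
        m2 = m2 + v_b * n_b + delta * delta * (count * n_b / new_count)
        count = new_count
    return mean, m2 / count                 # global mean, global biased variance

# two ranks with different element counts (4 vs. 12 samples of one channel)
m, v = welford_merge(np.array([[1.0], [3.0]]),
                     np.array([[0.5], [0.25]]),
                     np.array([4, 12]))
# m == [2.5]; v is the biased variance of all 16 samples combined
```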
......@@ -694,8 +704,8 @@ __global__ void reduce_bn_c_last_kernel(
const scalar_t* __restrict__ grad_output,
const accscalar_t* __restrict__ mean,
const accscalar_t* __restrict__ inv_std,
accscalar_t* __restrict__ mean_dy,
accscalar_t* __restrict__ mean_dy_xmu,
accscalar_t* __restrict__ sum_dy_o,
accscalar_t* __restrict__ sum_dy_xmu_o,
layerscalar_t* __restrict__ grad_weight,
layerscalar_t* __restrict__ grad_bias,
volatile accscalar_t* staging_data,
......@@ -814,8 +824,10 @@ __global__ void reduce_bn_c_last_kernel(
if (grad_weight != NULL) {
grad_weight[c_offset] = static_cast<layerscalar_t>(sum_dy_xmu_th * factor);
}
mean_dy[c_offset] = sum_dy_th / reduction_size;
mean_dy_xmu[c_offset] = sum_dy_xmu_th / reduction_size;
//mean_dy[c_offset] = sum_dy_th / reduction_size;
//mean_dy_xmu[c_offset] = sum_dy_xmu_th / reduction_size;
sum_dy_o[c_offset] = sum_dy_th;
sum_dy_xmu_o[c_offset] = sum_dy_xmu_th;
}
}
} else {
......@@ -826,8 +838,10 @@ __global__ void reduce_bn_c_last_kernel(
if (grad_weight != NULL) {
grad_weight[c_offset] = static_cast<layerscalar_t>(sum_dy_xmu_th * factor);
}
mean_dy[c_offset] = sum_dy_th / reduction_size;
mean_dy_xmu[c_offset] = sum_dy_xmu_th / reduction_size;
//mean_dy[c_offset] = sum_dy_th / reduction_size;
//mean_dy_xmu[c_offset] = sum_dy_xmu_th / reduction_size;
sum_dy_o[c_offset] = sum_dy_th;
sum_dy_xmu_o[c_offset] = sum_dy_xmu_th;
}
}
}
......@@ -844,11 +858,17 @@ __global__ void batchnorm_backward_c_last_kernel(
const accscalar_t* __restrict__ mean,
const accscalar_t* __restrict__ inv_std,
const layerscalar_t* __restrict__ weight,
const accscalar_t* __restrict__ mean_dy,
const accscalar_t* __restrict__ mean_dy_xmu,
const accscalar_t* __restrict__ sum_dy,
const accscalar_t* __restrict__ sum_dy_xmu,
const int* __restrict__ numel,
scalar_t* __restrict__ grad_input,
const int64_t world_size,
const int reduction_size,
const int stride) {
int64_t div = 0;
for (int i = 0; i < world_size; i++) {
div += numel[i];
}
// tensor dimension (m,c)
// loop along m dimension
int inner_loop_stride = blockDim.y * gridDim.y;
......@@ -858,10 +878,10 @@ __global__ void batchnorm_backward_c_last_kernel(
int c_offset = blockIdx.x * blockDim.x + threadIdx.x;
auto m_c = mean[c_offset];
auto m_dy_c = mean_dy[c_offset];
auto m_dy_c = sum_dy[c_offset] / div;
auto factor_1_c = inv_std[c_offset];
auto factor_2_c = (weight == NULL? accscalar_t(1.0) : static_cast<accscalar_t>(weight[c_offset])) * factor_1_c;
factor_1_c = factor_1_c * factor_1_c * mean_dy_xmu[c_offset];
factor_1_c = factor_1_c * factor_1_c * sum_dy_xmu[c_offset] / div;
int loop_count = 1 + (reduction_size - 1) / (inner_loop_stride * PARALLEL_LOADS);
int address_base = m_offset * stride + c_offset;
......@@ -986,8 +1006,8 @@ std::vector<at::Tensor> reduce_bn_CUDA(
auto scalar_type = promote_scalartype(input);
at::Tensor mean_dy = at::empty({feature_size}, mean.options());
at::Tensor mean_dy_xmu = at::empty({feature_size}, mean.options());
at::Tensor sum_dy = at::empty({feature_size}, mean.options());
at::Tensor sum_dy_xmu = at::empty({feature_size}, mean.options());
at::Tensor grad_weight;
at::Tensor grad_bias;
......@@ -1018,8 +1038,8 @@ std::vector<at::Tensor> reduce_bn_CUDA(
grad_output.DATA_PTR<scalar_t_0>(),
mean.DATA_PTR<accscalar_t>(),
inv_std.DATA_PTR<accscalar_t>(),
mean_dy.DATA_PTR<accscalar_t>(),
mean_dy_xmu.DATA_PTR<accscalar_t>(),
sum_dy.DATA_PTR<accscalar_t>(),
sum_dy_xmu.DATA_PTR<accscalar_t>(),
weight.has_value() ? grad_weight.DATA_PTR<accscalar_t>() : NULL,
weight.has_value() ? grad_bias.DATA_PTR<accscalar_t>() : NULL,
batch_size,
......@@ -1039,8 +1059,8 @@ std::vector<at::Tensor> reduce_bn_CUDA(
grad_output.DATA_PTR<scalar_t_0>(),
mean.DATA_PTR<accscalar_t>(),
inv_std.DATA_PTR<accscalar_t>(),
mean_dy.DATA_PTR<accscalar_t>(),
mean_dy_xmu.DATA_PTR<accscalar_t>(),
sum_dy.DATA_PTR<accscalar_t>(),
sum_dy_xmu.DATA_PTR<accscalar_t>(),
weight.has_value() ? grad_weight.DATA_PTR<scalar_t_0>() : NULL,
weight.has_value() ? grad_bias.DATA_PTR<scalar_t_0>() : NULL,
batch_size,
......@@ -1049,7 +1069,7 @@ std::vector<at::Tensor> reduce_bn_CUDA(
);
}
return {mean_dy, mean_dy_xmu, grad_weight, grad_bias};
return {sum_dy, sum_dy_xmu, grad_weight, grad_bias};
}
at::Tensor batchnorm_backward_CUDA(
......@@ -1058,8 +1078,9 @@ at::Tensor batchnorm_backward_CUDA(
const at::Tensor mean,
const at::Tensor inv_std,
const at::optional<at::Tensor> weight,
const at::Tensor mean_dy,
const at::Tensor mean_dy_xmu) {
const at::Tensor sum_dy,
const at::Tensor sum_dy_xmu,
const at::Tensor count) {
const auto batch_size = input.size(0);
const auto feature_size = input.size(1);
......@@ -1088,9 +1109,11 @@ at::Tensor batchnorm_backward_CUDA(
mean.DATA_PTR<accscalar_t>(),
inv_std.DATA_PTR<accscalar_t>(),
weight.has_value() ? weight.value().DATA_PTR<accscalar_t>() : NULL,
mean_dy.DATA_PTR<accscalar_t>(),
mean_dy_xmu.DATA_PTR<accscalar_t>(),
sum_dy.DATA_PTR<accscalar_t>(),
sum_dy_xmu.DATA_PTR<accscalar_t>(),
count.DATA_PTR<int>(),
grad_input.DATA_PTR<scalar_t_0>(),
count.numel(),
space_size,
batch_size);
);
......@@ -1108,9 +1131,11 @@ at::Tensor batchnorm_backward_CUDA(
mean.DATA_PTR<accscalar_t>(),
inv_std.DATA_PTR<accscalar_t>(),
weight.has_value() ? weight.value().DATA_PTR<scalar_t_0>() : NULL,
mean_dy.DATA_PTR<accscalar_t>(),
mean_dy_xmu.DATA_PTR<accscalar_t>(),
sum_dy.DATA_PTR<accscalar_t>(),
sum_dy_xmu.DATA_PTR<accscalar_t>(),
count.DATA_PTR<int>(),
grad_input.DATA_PTR<scalar_t_0>(),
count.numel(),
space_size,
batch_size);
);
......@@ -1121,7 +1146,7 @@ at::Tensor batchnorm_backward_CUDA(
std::vector<at::Tensor> welford_parallel_CUDA(const at::Tensor mean_feature_nodes,
const at::Tensor var_biased,
int numel,
const at::Tensor numel,
const float eps) {
const auto world_size = mean_feature_nodes.size(0);
const auto feature_size = mean_feature_nodes.size(1);
......@@ -1142,13 +1167,13 @@ std::vector<at::Tensor> welford_parallel_CUDA(const at::Tensor mean_feature_node
welford_kernel_parallel<scalar_t_0><<<grid, block, 0, stream>>>(
mean_feature_nodes.DATA_PTR<scalar_t_0>(),
var_biased.DATA_PTR<scalar_t_0>(),
numel.DATA_PTR<int>(),
out_mean.DATA_PTR<scalar_t_0>(),
out_var.DATA_PTR<scalar_t_0>(),
inv_std.DATA_PTR<scalar_t_0>(),
world_size,
feature_size,
eps,
numel);
eps);
);
}
......@@ -1270,8 +1295,8 @@ std::vector<at::Tensor> reduce_bn_c_last_CUDA(
const auto stride = input.size(input.ndimension()-1);
const auto reduction_size = input.numel() / stride;
at::Tensor mean_dy = at::empty({stride}, mean.options());
at::Tensor mean_dy_xmu = at::empty({stride}, mean.options());
at::Tensor sum_dy = at::empty({stride}, mean.options());
at::Tensor sum_dy_xmu = at::empty({stride}, mean.options());
at::Tensor grad_weight;
at::Tensor grad_bias;
......@@ -1310,8 +1335,8 @@ std::vector<at::Tensor> reduce_bn_c_last_CUDA(
grad_output.DATA_PTR<scalar_t_0>(),
mean.DATA_PTR<accscalar_t>(),
inv_std.DATA_PTR<accscalar_t>(),
mean_dy.DATA_PTR<accscalar_t>(),
mean_dy_xmu.DATA_PTR<accscalar_t>(),
sum_dy.DATA_PTR<accscalar_t>(),
sum_dy_xmu.DATA_PTR<accscalar_t>(),
weight.has_value() ? grad_weight.DATA_PTR<accscalar_t>() : NULL,
weight.has_value() ? grad_bias.DATA_PTR<accscalar_t>() : NULL,
staging_data_ptr,
......@@ -1335,8 +1360,8 @@ std::vector<at::Tensor> reduce_bn_c_last_CUDA(
grad_output.DATA_PTR<scalar_t_0>(),
mean.DATA_PTR<accscalar_t>(),
inv_std.DATA_PTR<accscalar_t>(),
mean_dy.DATA_PTR<accscalar_t>(),
mean_dy_xmu.DATA_PTR<accscalar_t>(),
sum_dy.DATA_PTR<accscalar_t>(),
sum_dy_xmu.DATA_PTR<accscalar_t>(),
weight.has_value() ? grad_weight.DATA_PTR<scalar_t_0>() : NULL,
weight.has_value() ? grad_bias.DATA_PTR<scalar_t_0>() : NULL,
staging_data_ptr,
......@@ -1346,7 +1371,7 @@ std::vector<at::Tensor> reduce_bn_c_last_CUDA(
);
}
return {mean_dy, mean_dy_xmu, grad_weight, grad_bias};
return {sum_dy, sum_dy_xmu, grad_weight, grad_bias};
}
at::Tensor batchnorm_backward_c_last_CUDA(
......@@ -1355,8 +1380,9 @@ at::Tensor batchnorm_backward_c_last_CUDA(
const at::Tensor mean,
const at::Tensor inv_std,
const at::optional<at::Tensor> weight,
const at::Tensor mean_dy,
const at::Tensor mean_dy_xmu) {
const at::Tensor sum_dy,
const at::Tensor sum_dy_xmu,
const at::Tensor count) {
const auto stride = input.size(input.ndimension()-1);
const auto reduction_size = input.numel() / stride;
......@@ -1380,9 +1406,11 @@ at::Tensor batchnorm_backward_c_last_CUDA(
mean.DATA_PTR<accscalar_t>(),
inv_std.DATA_PTR<accscalar_t>(),
weight.has_value() ? weight.value().DATA_PTR<accscalar_t>() : NULL,
mean_dy.DATA_PTR<accscalar_t>(),
mean_dy_xmu.DATA_PTR<accscalar_t>(),
sum_dy.DATA_PTR<accscalar_t>(),
sum_dy_xmu.DATA_PTR<accscalar_t>(),
count.DATA_PTR<int>(),
grad_input.DATA_PTR<scalar_t_0>(),
count.numel(),
reduction_size,
stride);
);
......@@ -1401,9 +1429,11 @@ at::Tensor batchnorm_backward_c_last_CUDA(
mean.DATA_PTR<accscalar_t>(),
inv_std.DATA_PTR<accscalar_t>(),
weight.has_value() ? weight.value().DATA_PTR<scalar_t_0>() : NULL,
mean_dy.DATA_PTR<accscalar_t>(),
mean_dy_xmu.DATA_PTR<accscalar_t>(),
sum_dy.DATA_PTR<accscalar_t>(),
sum_dy_xmu.DATA_PTR<accscalar_t>(),
count.DATA_PTR<int>(),
grad_input.DATA_PTR<scalar_t_0>(),
count.numel(),
reduction_size,
stride);
);
......
......@@ -35,6 +35,7 @@ inp = (np.random.randn(batch_size, feature_size, space_size, space_size)).astype
grad = (np.random.randn(batch_size, feature_size, space_size, space_size)).astype(dtype)
weight = (np.random.randn(feature_size)).astype(dtype)
bias = (np.random.randn(feature_size)).astype(dtype)
count = torch.cuda.IntTensor([batch_size*space_size**2])
type_tensor = torch.cuda.FloatTensor
ref_tensor = torch.cuda.DoubleTensor
......@@ -110,17 +111,19 @@ grad_output2_r = ref_tensor(grad)
grad_bias_r = grad_output_r.sum(1)
grad_weight_r = ((inp2_r - m.view(-1, 1, 1)) * torch.rsqrt(b_v.view(-1,1,1) + eps) * grad_output2_r).transpose(1,0).contiguous().view(feature_size, -1).sum(1)
sum_dy_r = grad_output_r.sum(1)
mean_dy_r = grad_output_r.mean(1)
sum_dy_xmu_r = ((inp2_r - m.view(-1, 1, 1)) * grad_output2_r).transpose(1,0).contiguous().view(feature_size, -1).sum(1)
mean_dy_xmu_r = ((inp2_r - m.view(-1, 1, 1)) * grad_output2_r).transpose(1,0).contiguous().view(feature_size, -1).mean(1)
grad_input_r = (grad_output2_r - mean_dy_r.view(-1, 1, 1) - (inp2_r - m.view(-1, 1, 1)) / (b_v.view(-1,1,1) + eps) * mean_dy_xmu_r.view(-1, 1, 1) ) * torch.rsqrt(b_v.view(-1,1,1) + eps) * weight_r.view(-1,1,1)
mean_dy, mean_dy_xmu, grad_weight, grad_bias = syncbn.reduce_bn(grad_output_t, inp_t, mean, inv_std, weight_t)
grad_input = syncbn.batchnorm_backward(grad_output_t, inp_t, mean, inv_std, weight_t, mean_dy, mean_dy_xmu)
sum_dy, sum_dy_xmu, grad_weight, grad_bias = syncbn.reduce_bn(grad_output_t, inp_t, mean, inv_std, weight_t)
grad_input = syncbn.batchnorm_backward(grad_output_t, inp_t, mean, inv_std, weight_t, sum_dy, sum_dy_xmu, count)
sbn_result = compare("comparing bias grad: ", grad_bias, grad_bias_r, error) and sbn_result
sbn_result = compare("comparing weight grad: ", grad_weight, grad_weight_r, error) and sbn_result
sbn_result = compare("comparing mean_dy grad: ", mean_dy, mean_dy_r, error) and sbn_result
sbn_result = compare("comparing mean_dy_xmu grad: ", mean_dy_xmu, mean_dy_xmu_r, error) and sbn_result
sbn_result = compare("comparing sum_dy grad: ", sum_dy, sum_dy_r, error) and sbn_result
sbn_result = compare("comparing sum_dy_xmu grad: ", sum_dy_xmu, sum_dy_xmu_r, error) and sbn_result
sbn_result = compare("comparing input grad: ", grad_input, grad_input_r, error) and sbn_result
compare("comparing bn input grad: ", inp_bn.grad, grad_input_r, error)
sbn_result = compare("comparing sbn input grad: ", inp_sbn.grad, grad_input_r, error) and sbn_result
......
import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from apex.parallel import SyncBatchNorm as ApexSyncBatchNorm
import argparse
import os
import numpy as np
var_batch = 16
def compare(desc, inp1, inp2, error= 1e-5):
a = inp1.clone().detach().cpu().numpy()
b = inp2.clone().detach().cpu().numpy()
close = np.allclose(a,b, error, error)
if not close:
print(desc, close)
z = a - b
index = (np.abs(z) >= error + error * np.abs(b)).nonzero()
print("dif : ", z[index])
print("inp1 : ", a[index])
print("inp2 : ", b[index])
return close
parser = argparse.ArgumentParser()
parser.add_argument('--local_rank', type=int, default=0)
parser.add_argument('--apex', action='store_true')
args = parser.parse_args()
torch.manual_seed(2809)
# Setup DDP
torch.cuda.set_device(args.local_rank)
device = torch.device('cuda:{}'.format(args.local_rank))
torch.distributed.init_process_group(
'nccl',
init_method='env://',
rank=args.local_rank,
)
# Setup model
if args.apex:
model = nn.Sequential(
nn.Conv2d(3, 6, 3, 1, 1),
ApexSyncBatchNorm(6)
)
else:
model = nn.Sequential(
nn.Conv2d(3, 6, 3, 1, 1),
nn.SyncBatchNorm(6)
)
# Setup reference model
model_reference = nn.Sequential(
nn.Conv2d(3, 6, 3, 1, 1),
nn.BatchNorm2d(6)
)
with torch.no_grad():
model_reference[0].weight.copy_(model[0].weight)
model_reference[0].bias.copy_(model[0].bias)
model_reference.to(device)
model = model.to(device)
model = DDP(model, device_ids=[args.local_rank], output_device=args.local_rank)
global_batch_size = var_batch + 8
# Create random data
if args.local_rank == 0:
data = torch.randn(var_batch, 3, 8, 8, device=device, dtype=torch.float) * 50.0
grad = torch.randint(0, 10, (var_batch, 6, 8, 8), device=device, dtype=torch.float) / 10.0
else:
data = torch.randn(8, 3, 8, 8, device=device)
grad = torch.randint(0, 10, (8, 6, 8, 8), device=device, dtype=torch.float) / 10.0
data.requires_grad_()
data.retain_grad()
weighted_gradient = True
# DDP forward/backward
output = model(data)
if weighted_gradient:
output.backward(grad * 2 / global_batch_size)
else:
output.backward(grad / output.size(0))
d_list = [torch.randn(8, 3, 8, 8, device=device) for i in range(int(os.environ['WORLD_SIZE']))]
y_list = [torch.randn(8, 6, 8, 8, device=device) for i in range(int(os.environ['WORLD_SIZE']))]
dgrad_list = [torch.randn(8, 3, 8, 8, device=device) for i in range(int(os.environ['WORLD_SIZE']))]
grad_list = [torch.randn(8, 6, 8, 8, device=device) for i in range(int(os.environ['WORLD_SIZE']))]
if args.local_rank == 0:
# placeholder, these random data will later be discarded.
torch.distributed.all_gather(d_list, torch.randn(8, 3, 8, 8, device=device))
torch.distributed.all_gather(y_list, torch.randn(8, 6, 8, 8, device=device))
torch.distributed.all_gather(dgrad_list, torch.randn(8, 3, 8, 8, device=device))
torch.distributed.all_gather(grad_list, torch.randn(8, 6, 8, 8, device=device))
else:
torch.distributed.all_gather(d_list, data)
torch.distributed.all_gather(y_list, output)
torch.distributed.all_gather(dgrad_list, data.grad)
torch.distributed.all_gather(grad_list, grad)
torch.distributed.barrier()
if args.local_rank == 0:
ref_tensor = d_list[1:]
ref_tensor.insert(0, data)
assert(ref_tensor[0].equal(data))
ref_tensor = torch.cat(ref_tensor, 0)
ref_tensor = ref_tensor.detach()
ref_tensor.requires_grad_()
ref_tensor.retain_grad()
# Reference forward/backward
output_reference = model_reference(ref_tensor)
grad_tensor = grad_list[1:]
grad_tensor.insert(0, grad)
assert(grad_tensor[0].equal(grad))
grad_tensor = torch.cat(grad_tensor, 0)
if weighted_gradient:
# the reference model always consumes the full global batch, so both cases
# scale the incoming gradient by the reference batch size
output_reference.backward(grad_tensor / output_reference.size(0))
else:
output_reference.backward(grad_tensor / output_reference.size(0))
dgrad_tensor = dgrad_list[1:]
dgrad_tensor.insert(0, data.grad)
dgrad_tensor = torch.cat(dgrad_tensor, 0)
# check output
output_tensor = y_list[1:]
output_tensor.insert(0, output)
output_tensor = torch.cat(output_tensor, 0)
passed = True
passed = passed and compare("check output",
output_tensor,
output_reference)
# check stats
passed = passed and compare("check running mean failed",
model_reference[1].running_mean,
model.module[1].running_mean)
passed = passed and compare("check running var failed",
model_reference[1].running_var,
model.module[1].running_var)
passed = passed and compare("bn wgrad check failed!",
model_reference[1].weight.grad,
model.module[1].weight.grad, 1e-6)
passed = passed and compare("conv wgrad check failed!",
model_reference[0].weight.grad,
model.module[0].weight.grad)
# can't really compare dgrad directly, as we need to scale it to account for
# DDP
# passed = passed and compare("dgrad check failed!", ref_tensor.grad, dgrad_tensor)
if passed:
print("====SBN two gpu with different batches test passed")
else:
assert False, "*failed two gpu with different batches tests*"
......@@ -114,6 +114,11 @@ grad_sbn = grad_output_t.clone().detach()
out_sbn = sbn(inp_sbn[start:finish])
out_sbn.backward(grad_sbn[start:finish])
count = [ space_size**2 * ( (i+1) * batch_size // args.world_size - i * batch_size // args.world_size ) for i in range(0, args.world_size)]
count = torch.cuda.IntTensor(count)
print("--- count : " , count)
sbn_result = True
bn_result = True
......@@ -136,18 +141,20 @@ grad_output2_r = ref_tensor(grad)
grad_bias_r = grad_output_r.sum(1)
grad_weight_r = ((inp2_r - m.view(-1, 1, 1)) * torch.rsqrt(b_v.view(-1,1,1) + eps) * grad_output2_r).transpose(1,0).contiguous().view(feature_size, -1).sum(1)
sum_dy_r = grad_output_r.sum(1)
mean_dy_r = grad_output_r.mean(1)
mean_dy_xmu_r = ((inp2_r - m.view(-1, 1, 1)) * grad_output2_r).transpose(1,0).contiguous().view(feature_size, -1).mean(1)
sum_dy_xmu_r = ((inp2_r - m.view(-1, 1, 1)) * grad_output2_r).transpose(1,0).contiguous().view(feature_size, -1).sum(1)
grad_input_r = (grad_output2_r - mean_dy_r.view(-1, 1, 1) - (inp2_r - m.view(-1, 1, 1)) / (b_v.view(-1,1,1) + eps) * mean_dy_xmu_r.view(-1, 1, 1) ) * torch.rsqrt(b_v.view(-1,1,1) + eps) * weight_r.view(-1,1,1)
mean_dy, mean_dy_xmu, grad_weight, grad_bias = syncbn.reduce_bn(grad_output_t, inp_t, mean, inv_std, weight_t)
grad_input = syncbn.batchnorm_backward(grad_output_t, inp_t, mean, inv_std, weight_t, mean_dy, mean_dy_xmu)
sum_dy, sum_dy_xmu, grad_weight, grad_bias = syncbn.reduce_bn(grad_output_t, inp_t, mean, inv_std, weight_t)
grad_input = syncbn.batchnorm_backward(grad_output_t, inp_t, mean, inv_std, weight_t, sum_dy, sum_dy_xmu, count)
if args.local_rank == 0:
sbn_result = compare("comparing bias grad: ", grad_bias, grad_bias_r, error) and sbn_result
sbn_result = compare("comparing weight grad: ", grad_weight, grad_weight_r, error) and sbn_result
sbn_result = compare("comparing mean_dy grad: ", mean_dy, mean_dy_r, error) and sbn_result
sbn_result = compare("comparing mean_dy_xmu grad: ", mean_dy_xmu, mean_dy_xmu_r, error) and sbn_result
sbn_result = compare("comparing sum_dy grad: ", sum_dy, sum_dy_r, error) and sbn_result
sbn_result = compare("comparing sum_dy_xmu grad: ", sum_dy_xmu, sum_dy_xmu_r, error) and sbn_result
sbn_result = compare("comparing input grad: ", grad_input, grad_input_r, error) and sbn_result
compare("comparing bn input grad: ", inp_bn.grad, grad_input_r, error)
......
......@@ -3,5 +3,6 @@ python single_gpu_unit_test.py
python test_batchnorm1d.py
python -m torch.distributed.launch --nproc_per_node=2 two_gpu_unit_test.py
python -m torch.distributed.launch --nproc_per_node=2 two_gpu_unit_test.py --fp16
python -m torch.distributed.launch --nproc_per_node=2 two_gpu_test_different_batch_size.py --apex
#beware, you need a system with at least 4 gpus to test group_size<world_size
#python -m torch.distributed.launch --nproc_per_node=4 test_groups.py --group_size=2