Unverified Commit dd584a59 authored by mahathis's avatar mahathis Committed by GitHub
Browse files

Added support for memory format API(torch.channels_last) in GBN (#72)



* Added support for memory format API(torch.channels_last) in GBN

Group Batch Norm (GBN) is an NHWC operation.  It assumes that the
underlying memory format of an input tensor is NHWC.  It originally does
not support PyTorch's memory_format API.

To support PyTorch's memory_format API, i.e., .to(memory_format=...) or
.contiguous(memory_format=...), we add the torch_channels_last
flag to indicate whether the workload adopts the PyTorch memory_format
API by setting memory_format=torch.channels_last.  This flag allows GBN
to handle memory formats of input tensors properly.

An example to use memory_format in GBN:

"""
from apex.contrib.groupbn.batch_norm import BatchNorm2d_NHWC

GBN = BatchNorm2d_NHWC(planes, fuse_relu=True, bn_group=1, torch_channels_last=True)

"""

The cases that GBN handles are as follows:

1. torch_channels_last=True and input tensor's
memory_format=torch.channels_last, GBN will generate the
torch.channels_last output tensor.

2. torch_channels_last=True and input tensor's
memory_format=torch.contiguous_format, GBN will convert the input tensor
to torch.channels_last and will generate the torch.channels_last output
tensor.

3. torch_channels_last=False and input tensor's
memory_format=torch.contiguous_format, GBN will generate the
torch.contiguous_format output tensor.

* Add GBN unit tests for channel_last memory format
Co-authored-by: hubertlu-tw <hubertlu@amd.com>
parent 29b36315
...@@ -63,17 +63,19 @@ at::Tensor nhwc_bn_fwd_train( ...@@ -63,17 +63,19 @@ at::Tensor nhwc_bn_fwd_train(
const int grid_dim_x, const int grid_dim_x,
const bool coop) { const bool coop) {
auto memory_format = x.suggest_memory_format();
const bool check_channels_last = x.is_contiguous(at::MemoryFormat::ChannelsLast);
const int N = x.size(0); const int N = x.size(0);
const int H = x.size(1); const int H = check_channels_last ? x.size(2) : x.size(1);
const int W = x.size(2); const int W = check_channels_last ? x.size(3) : x.size(2);
const int C = x.size(3); const int C = check_channels_last ? x.size(1) : x.size(3);
// generating new magic number and use that for sync // generating new magic number and use that for sync
int* magic = magic_tensor.DATA_PTR<int>(); int* magic = magic_tensor.DATA_PTR<int>();
*magic = (*magic + 1) & 0xff; *magic = (*magic + 1) & 0xff;
// Allocate output tensor // Allocate output tensor
at::Tensor y = at::empty({N, H, W, C}, x.options()); at::Tensor y = check_channels_last ? at::empty({N, C, H, W}, x.options().memory_format(memory_format)) : at::empty({N, H, W, C}, x.options());
// Create wrapper // Create wrapper
NhwcBatchNorm *bn = new NhwcBatchNorm(); NhwcBatchNorm *bn = new NhwcBatchNorm();
...@@ -84,9 +86,9 @@ at::Tensor nhwc_bn_fwd_train( ...@@ -84,9 +86,9 @@ at::Tensor nhwc_bn_fwd_train(
bn->setConstants(momentum, epsilon); bn->setConstants(momentum, epsilon);
// set pointers within the wrapper // set pointers within the wrapper
bn->setInputOutputPointers(x.contiguous().DATA_PTR<at::Half>(), bn->setInputOutputPointers(x.contiguous(memory_format).DATA_PTR<at::Half>(),
nullptr, nullptr,
y.DATA_PTR<at::Half>(), y.contiguous(memory_format).DATA_PTR<at::Half>(),
nullptr); nullptr);
bn->setWeightPointers({scale.contiguous().DATA_PTR<float>(), bn->setWeightPointers({scale.contiguous().DATA_PTR<float>(),
...@@ -132,7 +134,7 @@ at::Tensor nhwc_bn_fwd_train( ...@@ -132,7 +134,7 @@ at::Tensor nhwc_bn_fwd_train(
// Don't fuse in ReLU for now at least // Don't fuse in ReLU for now at least
bn->fwd(stream, fuse_relu, my_data, pair_data, pair_data2, pair_data3, bn_group, *magic, occupancy, grid_dim_x, coop); bn->fwd(stream, fuse_relu, my_data, pair_data, pair_data2, pair_data3, bn_group, *magic, occupancy, grid_dim_x, coop);
return y; return y.contiguous(memory_format);
} }
at::Tensor nhwc_bn_fwd_eval( at::Tensor nhwc_bn_fwd_eval(
...@@ -147,13 +149,15 @@ at::Tensor nhwc_bn_fwd_eval( ...@@ -147,13 +149,15 @@ at::Tensor nhwc_bn_fwd_eval(
const float epsilon, const float epsilon,
const bool fuse_relu) { const bool fuse_relu) {
const bool check_channels_last = x.is_contiguous(at::MemoryFormat::ChannelsLast);
auto memory_format = x.suggest_memory_format();
const int N = x.size(0); const int N = x.size(0);
const int H = x.size(1); const int H = check_channels_last ? x.size(2) : x.size(1);
const int W = x.size(2); const int W = check_channels_last ? x.size(3) : x.size(2);
const int C = x.size(3); const int C = check_channels_last ? x.size(1) : x.size(3);
// Allocate output tensor // Allocate output tensor
at::Tensor y = at::empty({N, H, W, C}, x.options()); at::Tensor y = check_channels_last ? at::empty({N, C, H, W}, x.options().memory_format(memory_format)) : at::empty({N, H, W, C}, x.options());
// Create wrapper // Create wrapper
NhwcBatchNorm *bn = new NhwcBatchNorm(); NhwcBatchNorm *bn = new NhwcBatchNorm();
...@@ -164,9 +168,9 @@ at::Tensor nhwc_bn_fwd_eval( ...@@ -164,9 +168,9 @@ at::Tensor nhwc_bn_fwd_eval(
bn->setConstants(momentum, epsilon); bn->setConstants(momentum, epsilon);
// set pointers within the wrapper // set pointers within the wrapper
bn->setInputOutputPointers(x.contiguous().DATA_PTR<at::Half>(), bn->setInputOutputPointers(x.contiguous(memory_format).DATA_PTR<at::Half>(),
nullptr, nullptr,
y.DATA_PTR<at::Half>(), y.contiguous(memory_format).DATA_PTR<at::Half>(),
nullptr); nullptr);
bn->setWeightPointers({scale.contiguous().DATA_PTR<float>(), bn->setWeightPointers({scale.contiguous().DATA_PTR<float>(),
...@@ -212,7 +216,7 @@ at::Tensor nhwc_bn_fwd_eval( ...@@ -212,7 +216,7 @@ at::Tensor nhwc_bn_fwd_eval(
// Don't fuse in ReLU for now at least // Don't fuse in ReLU for now at least
bn->fwdInference(stream, fuse_relu); bn->fwdInference(stream, fuse_relu);
return y; return y.contiguous(memory_format);
} }
...@@ -239,10 +243,12 @@ std::vector<at::Tensor> nhwc_bn_bwd( ...@@ -239,10 +243,12 @@ std::vector<at::Tensor> nhwc_bn_bwd(
const int grid_dim_x, const int grid_dim_x,
const bool coop) { const bool coop) {
// shape // shape
const bool check_channels_last = x.is_contiguous(at::MemoryFormat::ChannelsLast);
auto memory_format = x.suggest_memory_format();
const int N = x.size(0); const int N = x.size(0);
const int H = x.size(1); const int H = check_channels_last ? x.size(2) : x.size(1);
const int W = x.size(2); const int W = check_channels_last ? x.size(3) : x.size(2);
const int C = x.size(3); const int C = check_channels_last ? x.size(1) : x.size(3);
// generating new magic number and use that for sync // generating new magic number and use that for sync
int* magic = magic_tensor.DATA_PTR<int>(); int* magic = magic_tensor.DATA_PTR<int>();
...@@ -252,7 +258,7 @@ std::vector<at::Tensor> nhwc_bn_bwd( ...@@ -252,7 +258,7 @@ std::vector<at::Tensor> nhwc_bn_bwd(
at::Tensor x_grad, scale_grad, bias_grad; at::Tensor x_grad, scale_grad, bias_grad;
// Allocate outputs // Allocate outputs
x_grad = at::empty_like(x); x_grad = check_channels_last ? at::empty({N, C, H, W}, dy.options().memory_format(memory_format)) : at::empty_like(x);
scale_grad = at::empty_like(scale); scale_grad = at::empty_like(scale);
bias_grad = at::empty_like(bias); bias_grad = at::empty_like(bias);
...@@ -265,10 +271,10 @@ std::vector<at::Tensor> nhwc_bn_bwd( ...@@ -265,10 +271,10 @@ std::vector<at::Tensor> nhwc_bn_bwd(
bn->setConstants(momentum, epsilon); bn->setConstants(momentum, epsilon);
// set pointers within the wrapper // set pointers within the wrapper
bn->setInputOutputPointers(x.contiguous().DATA_PTR<at::Half>(), bn->setInputOutputPointers(x.contiguous(memory_format).DATA_PTR<at::Half>(),
x_grad.DATA_PTR<at::Half>(), x_grad.contiguous(memory_format).DATA_PTR<at::Half>(),
nullptr, nullptr,
dy.contiguous().DATA_PTR<at::Half>()); dy.contiguous(memory_format).DATA_PTR<at::Half>());
bn->setWeightPointers({scale.contiguous().DATA_PTR<float>(), bn->setWeightPointers({scale.contiguous().DATA_PTR<float>(),
bias.contiguous().DATA_PTR<float>()}, bias.contiguous().DATA_PTR<float>()},
...@@ -314,7 +320,7 @@ std::vector<at::Tensor> nhwc_bn_bwd( ...@@ -314,7 +320,7 @@ std::vector<at::Tensor> nhwc_bn_bwd(
bn->dgrad(stream, fuse_relu, my_data, pair_data, pair_data2, pair_data3, bn_group, *magic, occupancy, grid_dim_x, coop); bn->dgrad(stream, fuse_relu, my_data, pair_data, pair_data2, pair_data3, bn_group, *magic, occupancy, grid_dim_x, coop);
return std::vector<at::Tensor>{x_grad, scale_grad, bias_grad}; return std::vector<at::Tensor>{x_grad.contiguous(memory_format), scale_grad, bias_grad};
} }
int nhwc_bn_fwd_occupancy() { int nhwc_bn_fwd_occupancy() {
......
...@@ -65,17 +65,19 @@ at::Tensor nhwc_bn_addrelu_fwd_train( ...@@ -65,17 +65,19 @@ at::Tensor nhwc_bn_addrelu_fwd_train(
const int grid_dim_x, const int grid_dim_x,
const bool coop) { const bool coop) {
auto memory_format = x.suggest_memory_format();
const bool check_channels_last = x.is_contiguous(at::MemoryFormat::ChannelsLast);
const int N = x.size(0); const int N = x.size(0);
const int H = x.size(1); const int H = check_channels_last ? x.size(2) : x.size(1);
const int W = x.size(2); const int W = check_channels_last ? x.size(3) : x.size(2);
const int C = x.size(3); const int C = check_channels_last ? x.size(1) : x.size(3);
// generating new magic number and use that for sync // generating new magic number and use that for sync
int* magic = magic_tensor.DATA_PTR<int>(); int* magic = magic_tensor.DATA_PTR<int>();
*magic = (*magic + 1) & 0xff; *magic = (*magic + 1) & 0xff;
// Allocate output tensor // Allocate output tensor
at::Tensor y = at::empty({N, H, W, C}, x.options()); at::Tensor y = check_channels_last? at::empty({N, C, H, W}, x.options().memory_format(memory_format)) : at::empty({N, H, W, C}, x.options());
// Create wrapper // Create wrapper
NhwcBatchNormAddRelu *bn = new NhwcBatchNormAddRelu(); NhwcBatchNormAddRelu *bn = new NhwcBatchNormAddRelu();
...@@ -86,11 +88,11 @@ at::Tensor nhwc_bn_addrelu_fwd_train( ...@@ -86,11 +88,11 @@ at::Tensor nhwc_bn_addrelu_fwd_train(
bn->setConstants(momentum, epsilon); bn->setConstants(momentum, epsilon);
// set pointers within the wrapper // set pointers within the wrapper
bn->setInputOutputPointers(x.contiguous().DATA_PTR<at::Half>(), bn->setInputOutputPointers(x.contiguous(memory_format).DATA_PTR<at::Half>(),
nullptr, nullptr,
y.DATA_PTR<at::Half>(), y.contiguous(memory_format).DATA_PTR<at::Half>(),
nullptr, nullptr,
z.contiguous().DATA_PTR<at::Half>(), z.contiguous(memory_format).DATA_PTR<at::Half>(),
nullptr); nullptr);
bn->setWeightPointers({scale.contiguous().DATA_PTR<float>(), bn->setWeightPointers({scale.contiguous().DATA_PTR<float>(),
...@@ -138,7 +140,7 @@ at::Tensor nhwc_bn_addrelu_fwd_train( ...@@ -138,7 +140,7 @@ at::Tensor nhwc_bn_addrelu_fwd_train(
// Don't fuse in ReLU for now at least // Don't fuse in ReLU for now at least
bn->fwd(stream, my_data, pair_data, pair_data2, pair_data3, bn_group, *magic, occupancy, grid_dim_x, coop); bn->fwd(stream, my_data, pair_data, pair_data2, pair_data3, bn_group, *magic, occupancy, grid_dim_x, coop);
return y; return y.contiguous(memory_format);
} }
at::Tensor nhwc_bn_addrelu_fwd_eval( at::Tensor nhwc_bn_addrelu_fwd_eval(
...@@ -153,13 +155,15 @@ at::Tensor nhwc_bn_addrelu_fwd_eval( ...@@ -153,13 +155,15 @@ at::Tensor nhwc_bn_addrelu_fwd_eval(
const float momentum, const float momentum,
const float epsilon) { const float epsilon) {
auto memory_format = x.suggest_memory_format();
const bool check_channels_last = x.is_contiguous(at::MemoryFormat::ChannelsLast);
const int N = x.size(0); const int N = x.size(0);
const int H = x.size(1); const int H = check_channels_last ? x.size(2) : x.size(1);
const int W = x.size(2); const int W = check_channels_last ? x.size(3) : x.size(2);
const int C = x.size(3); const int C = check_channels_last ? x.size(1) : x.size(3);
// Allocate output tensor // Allocate output tensor
at::Tensor y = at::empty({N, H, W, C}, x.options()); at::Tensor y = check_channels_last? at::empty({N, C, H, W}, x.options().memory_format(memory_format)): at::empty({N, H, W, C}, x.options());
// Create wrapper // Create wrapper
NhwcBatchNormAddRelu *bn = new NhwcBatchNormAddRelu(); NhwcBatchNormAddRelu *bn = new NhwcBatchNormAddRelu();
...@@ -170,11 +174,11 @@ at::Tensor nhwc_bn_addrelu_fwd_eval( ...@@ -170,11 +174,11 @@ at::Tensor nhwc_bn_addrelu_fwd_eval(
bn->setConstants(momentum, epsilon); bn->setConstants(momentum, epsilon);
// set pointers within the wrapper // set pointers within the wrapper
bn->setInputOutputPointers(x.contiguous().DATA_PTR<at::Half>(), bn->setInputOutputPointers(x.contiguous(memory_format).DATA_PTR<at::Half>(),
nullptr, nullptr,
y.DATA_PTR<at::Half>(), y.contiguous(memory_format).DATA_PTR<at::Half>(),
nullptr, nullptr,
z.contiguous().DATA_PTR<at::Half>(), z.contiguous(memory_format).DATA_PTR<at::Half>(),
nullptr); nullptr);
bn->setWeightPointers({scale.contiguous().DATA_PTR<float>(), bn->setWeightPointers({scale.contiguous().DATA_PTR<float>(),
...@@ -221,7 +225,7 @@ at::Tensor nhwc_bn_addrelu_fwd_eval( ...@@ -221,7 +225,7 @@ at::Tensor nhwc_bn_addrelu_fwd_eval(
// Don't fuse in ReLU for now at least // Don't fuse in ReLU for now at least
bn->fwdInference(stream); bn->fwdInference(stream);
return y; return y.contiguous(memory_format);
} }
...@@ -248,10 +252,12 @@ std::vector<at::Tensor> nhwc_bn_addrelu_bwd( ...@@ -248,10 +252,12 @@ std::vector<at::Tensor> nhwc_bn_addrelu_bwd(
const int grid_dim_x, const int grid_dim_x,
const bool coop) { const bool coop) {
// shape // shape
auto memory_format = x.suggest_memory_format();
const bool check_channels_last = x.is_contiguous(at::MemoryFormat::ChannelsLast);
const int N = x.size(0); const int N = x.size(0);
const int H = x.size(1); const int H = check_channels_last ? x.size(2) : x.size(1);
const int W = x.size(2); const int W = check_channels_last ? x.size(3) : x.size(2);
const int C = x.size(3); const int C = check_channels_last ? x.size(1) : x.size(3);
// generating new magic number and use that for sync // generating new magic number and use that for sync
int* magic = magic_tensor.DATA_PTR<int>(); int* magic = magic_tensor.DATA_PTR<int>();
...@@ -261,8 +267,8 @@ std::vector<at::Tensor> nhwc_bn_addrelu_bwd( ...@@ -261,8 +267,8 @@ std::vector<at::Tensor> nhwc_bn_addrelu_bwd(
at::Tensor x_grad, z_grad, scale_grad, bias_grad; at::Tensor x_grad, z_grad, scale_grad, bias_grad;
// Allocate outputs // Allocate outputs
x_grad = at::empty_like(x); x_grad = check_channels_last ? at::empty({N, C, H, W}, dy.options().memory_format(memory_format)) : at::empty_like(x);
z_grad = at::empty_like(x); z_grad = check_channels_last ? at::empty({N, C, H, W}, dy.options().memory_format(memory_format)) : at::empty_like(x);
scale_grad = at::empty_like(scale); scale_grad = at::empty_like(scale);
bias_grad = at::empty_like(bias); bias_grad = at::empty_like(bias);
...@@ -275,12 +281,12 @@ std::vector<at::Tensor> nhwc_bn_addrelu_bwd( ...@@ -275,12 +281,12 @@ std::vector<at::Tensor> nhwc_bn_addrelu_bwd(
bn->setConstants(momentum, epsilon); bn->setConstants(momentum, epsilon);
// set pointers within the wrapper // set pointers within the wrapper
bn->setInputOutputPointers(x.contiguous().DATA_PTR<at::Half>(), bn->setInputOutputPointers(x.contiguous(memory_format).DATA_PTR<at::Half>(),
x_grad.DATA_PTR<at::Half>(), x_grad.contiguous(memory_format).DATA_PTR<at::Half>(),
nullptr, nullptr,
dy.contiguous().DATA_PTR<at::Half>(), dy.contiguous(memory_format).DATA_PTR<at::Half>(),
nullptr, nullptr,
z_grad.DATA_PTR<at::Half>()); z_grad.contiguous(memory_format).DATA_PTR<at::Half>());
bn->setWeightPointers({scale.contiguous().DATA_PTR<float>(), bn->setWeightPointers({scale.contiguous().DATA_PTR<float>(),
bias.contiguous().DATA_PTR<float>()}, bias.contiguous().DATA_PTR<float>()},
...@@ -326,7 +332,7 @@ std::vector<at::Tensor> nhwc_bn_addrelu_bwd( ...@@ -326,7 +332,7 @@ std::vector<at::Tensor> nhwc_bn_addrelu_bwd(
bn->dgrad(stream, my_data, pair_data, pair_data2, pair_data3, bn_group, *magic, occupancy, grid_dim_x, coop); bn->dgrad(stream, my_data, pair_data, pair_data2, pair_data3, bn_group, *magic, occupancy, grid_dim_x, coop);
return std::vector<at::Tensor>{x_grad, z_grad, scale_grad, bias_grad}; return std::vector<at::Tensor>{x_grad.contiguous(memory_format), z_grad.contiguous(memory_format), scale_grad, bias_grad};
} }
int nhwc_bn_addrelu_fwd_occupancy() { int nhwc_bn_addrelu_fwd_occupancy() {
......
...@@ -14,11 +14,20 @@ def check_if_rocm_pytorch(): ...@@ -14,11 +14,20 @@ def check_if_rocm_pytorch():
IS_ROCM_PYTORCH = check_if_rocm_pytorch() IS_ROCM_PYTORCH = check_if_rocm_pytorch()
def check_and_convert_channels_last(tensor, torch_channels_last):
if torch_channels_last:
channels_last = tensor.is_contiguous(memory_format = torch.channels_last)
if not channels_last:
tensor = tensor.to(memory_format = torch.channels_last)
return tensor
class bn_NHWC_impl(torch.autograd.Function): class bn_NHWC_impl(torch.autograd.Function):
@staticmethod @staticmethod
def forward(ctx, x, s, b, rm, riv, mini_m, mini_riv, ret_cta, mom, epsilon, fuse_relu, is_train, bn_group, my_data, pair_data, magic, pair_data2, pair_data3, fwd_occup, fwd_grid_x, bwd_occup, bwd_grid_x, multi_stream): def forward(ctx, x, s, b, rm, riv, mini_m, mini_riv, ret_cta, mom, epsilon, fuse_relu, is_train, torch_channels_last, bn_group, my_data, pair_data, magic, pair_data2, pair_data3, fwd_occup, fwd_grid_x, bwd_occup, bwd_grid_x, multi_stream):
x = check_and_convert_channels_last(x, torch_channels_last)
if is_train: if is_train:
ctx.save_for_backward(x, s, b, rm, riv, mini_m, mini_riv) ctx.save_for_backward(x, s, b, rm, riv, mini_m, mini_riv)
ctx.torch_channels_last = torch_channels_last
ctx.epsilon = epsilon ctx.epsilon = epsilon
ctx.momentum = mom ctx.momentum = mom
ctx.ret_cta = ret_cta ctx.ret_cta = ret_cta
...@@ -41,6 +50,8 @@ class bn_NHWC_impl(torch.autograd.Function): ...@@ -41,6 +50,8 @@ class bn_NHWC_impl(torch.autograd.Function):
@staticmethod @staticmethod
def backward(ctx, grad_y): def backward(ctx, grad_y):
x, s, b, rm, riv, mini_m, mini_riv = ctx.saved_variables x, s, b, rm, riv, mini_m, mini_riv = ctx.saved_variables
grad_y = check_and_convert_channels_last(grad_y, ctx.torch_channels_last)
x = check_and_convert_channels_last(x, ctx.torch_channels_last)
epsilon = ctx.epsilon epsilon = ctx.epsilon
mom = ctx.momentum mom = ctx.momentum
ret_cta = ctx.ret_cta ret_cta = ctx.ret_cta
...@@ -57,20 +68,26 @@ class bn_NHWC_impl(torch.autograd.Function): ...@@ -57,20 +68,26 @@ class bn_NHWC_impl(torch.autograd.Function):
dx, dscale, dbias = bnp.bn_bwd_nhwc(x, grad_y, s, b, rm, riv, mini_m, mini_riv, ret_cta, mom, epsilon, fuse_relu, my_data, pair_data, pair_data2, pair_data3, bn_group, magic, bwd_occup, bwd_grid_x, multi_stream) dx, dscale, dbias = bnp.bn_bwd_nhwc(x, grad_y, s, b, rm, riv, mini_m, mini_riv, ret_cta, mom, epsilon, fuse_relu, my_data, pair_data, pair_data2, pair_data3, bn_group, magic, bwd_occup, bwd_grid_x, multi_stream)
return dx, dscale, dbias, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None return dx, dscale, dbias, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None
class bn_addrelu_NHWC_impl(torch.autograd.Function): class bn_addrelu_NHWC_impl(torch.autograd.Function):
@staticmethod @staticmethod
def forward(ctx, x, z, s, b, rm, riv, mini_m, mini_riv, grid_dim_y, ret_cta, mom, epsilon, is_train, bn_group, my_data, pair_data, magic, pair_data2, pair_data3, fwd_occup, fwd_grid_x, bwd_occup, bwd_grid_x, multi_stream): def forward(ctx, x, z, s, b, rm, riv, mini_m, mini_riv, grid_dim_y, ret_cta, mom, epsilon, is_train, torch_channels_last, bn_group, my_data, pair_data, magic, pair_data2, pair_data3, fwd_occup, fwd_grid_x, bwd_occup, bwd_grid_x, multi_stream):
x = check_and_convert_channels_last(x, torch_channels_last)
z = check_and_convert_channels_last(z, torch_channels_last)
if is_train: if is_train:
if IS_ROCM_PYTORCH: if IS_ROCM_PYTORCH:
if torch_channels_last:
nhw = x.shape[0] * x.shape[2] * x.shape[3]
else:
nhw = x.shape[0] * x.shape[1] * x.shape[2] nhw = x.shape[0] * x.shape[1] * x.shape[2]
shape = int(((nhw + 3) & ~3) * grid_dim_y) shape = int(((nhw + 3) & ~3) * grid_dim_y)
bitmask = torch.cuda.LongTensor(shape) bitmask = torch.cuda.LongTensor(shape)
else: else:
bitmask = torch.cuda.IntTensor(((x.numel()+31)//32) * 2 * grid_dim_y) bitmask = torch.cuda.IntTensor(((x.numel()+31)//32) * 2 * grid_dim_y)
ctx.save_for_backward(x, s, b, rm, riv, mini_m, mini_riv, bitmask) ctx.save_for_backward(x, s, b, rm, riv, mini_m, mini_riv, bitmask)
ctx.torch_channels_last = torch_channels_last
ctx.epsilon = epsilon ctx.epsilon = epsilon
ctx.momentum = mom ctx.momentum = mom
ctx.ret_cta = ret_cta ctx.ret_cta = ret_cta
...@@ -92,6 +109,8 @@ class bn_addrelu_NHWC_impl(torch.autograd.Function): ...@@ -92,6 +109,8 @@ class bn_addrelu_NHWC_impl(torch.autograd.Function):
@staticmethod @staticmethod
def backward(ctx, grad_y): def backward(ctx, grad_y):
x, s, b, rm, riv, mini_m, mini_riv, bitmask = ctx.saved_variables x, s, b, rm, riv, mini_m, mini_riv, bitmask = ctx.saved_variables
grad_y = check_and_convert_channels_last(grad_y, ctx.torch_channels_last)
x = check_and_convert_channels_last(x, ctx.torch_channels_last)
epsilon = ctx.epsilon epsilon = ctx.epsilon
mom = ctx.momentum mom = ctx.momentum
ret_cta = ctx.ret_cta ret_cta = ctx.ret_cta
...@@ -107,7 +126,7 @@ class bn_addrelu_NHWC_impl(torch.autograd.Function): ...@@ -107,7 +126,7 @@ class bn_addrelu_NHWC_impl(torch.autograd.Function):
dx, dz, dscale, dbias = bnp.bn_addrelu_bwd_nhwc(x, grad_y, s, b, rm, riv, mini_m, mini_riv, bitmask, ret_cta, mom, epsilon, my_data, pair_data, pair_data2, pair_data3, bn_group, magic, bwd_occup, bwd_grid_x, multi_stream) dx, dz, dscale, dbias = bnp.bn_addrelu_bwd_nhwc(x, grad_y, s, b, rm, riv, mini_m, mini_riv, bitmask, ret_cta, mom, epsilon, my_data, pair_data, pair_data2, pair_data3, bn_group, magic, bwd_occup, bwd_grid_x, multi_stream)
return dx, dz, dscale, dbias, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None return dx, dz, dscale, dbias, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None
...@@ -115,10 +134,11 @@ class bn_addrelu_NHWC_impl(torch.autograd.Function): ...@@ -115,10 +134,11 @@ class bn_addrelu_NHWC_impl(torch.autograd.Function):
class BatchNorm2d_NHWC(_BatchNorm): class BatchNorm2d_NHWC(_BatchNorm):
# if using BatchNorm2d_NHWC simultaneously with multiple streams set multi_stream to True # if using BatchNorm2d_NHWC simultaneously with multiple streams set multi_stream to True
def __init__(self, num_features, fuse_relu=False, bn_group=1, max_cta_per_sm=2, cta_launch_margin=12, multi_stream=False): def __init__(self, num_features, fuse_relu=False, bn_group=1, torch_channels_last=False,max_cta_per_sm=2, cta_launch_margin=12, multi_stream=False):
super(BatchNorm2d_NHWC, self).__init__(num_features) super(BatchNorm2d_NHWC, self).__init__(num_features)
self.fuse_relu = fuse_relu self.fuse_relu = fuse_relu
self.torch_channels_last = torch_channels_last
self.multi_stream = multi_stream self.multi_stream = multi_stream
self.minibatch_mean = torch.cuda.FloatTensor(num_features) self.minibatch_mean = torch.cuda.FloatTensor(num_features)
...@@ -216,7 +236,7 @@ class BatchNorm2d_NHWC(_BatchNorm): ...@@ -216,7 +236,7 @@ class BatchNorm2d_NHWC(_BatchNorm):
self.running_mean, self.running_var, self.running_mean, self.running_var,
self.minibatch_mean, self.minibatch_riv, self.grid_dim_y, self.ret_cta, self.minibatch_mean, self.minibatch_riv, self.grid_dim_y, self.ret_cta,
self.momentum, self.momentum,
self.eps, self.training, self.bn_group, self.my_data, self.pair_data, (self.magic), self.pair_data2, self.pair_data3, self.eps, self.training, self.torch_channels_last, self.bn_group, self.my_data, self.pair_data, (self.magic), self.pair_data2, self.pair_data3,
self.addrelu_fwd_occupancy, self.addrelu_fwd_grid_dim_x, self.addrelu_fwd_occupancy, self.addrelu_fwd_grid_dim_x,
self.addrelu_bwd_occupancy, self.addrelu_bwd_grid_dim_x, self.addrelu_bwd_occupancy, self.addrelu_bwd_grid_dim_x,
self.multi_stream) self.multi_stream)
...@@ -226,7 +246,7 @@ class BatchNorm2d_NHWC(_BatchNorm): ...@@ -226,7 +246,7 @@ class BatchNorm2d_NHWC(_BatchNorm):
self.running_mean, self.running_var, self.running_mean, self.running_var,
self.minibatch_mean, self.minibatch_riv, self.ret_cta, self.minibatch_mean, self.minibatch_riv, self.ret_cta,
self.momentum, self.momentum,
self.eps, self.fuse_relu, self.training, self.bn_group, self.my_data, self.pair_data, (self.magic), self.pair_data2, self.pair_data3, self.eps, self.fuse_relu, self.training, self.torch_channels_last, self.bn_group, self.my_data, self.pair_data, (self.magic), self.pair_data2, self.pair_data3,
self.fwd_occupancy, self.fwd_grid_dim_x, self.fwd_occupancy, self.fwd_grid_dim_x,
self.bwd_occupancy, self.bwd_grid_dim_x, self.bwd_occupancy, self.bwd_grid_dim_x,
self.multi_stream) self.multi_stream)
......
...@@ -72,10 +72,8 @@ class TestGroupBN(unittest.TestCase): ...@@ -72,10 +72,8 @@ class TestGroupBN(unittest.TestCase):
print('Running {}'.format(mode)) print('Running {}'.format(mode))
tensor_sizes = [ tensor_sizes = [
(120, 64, 150, 150),
(120, 64, 75, 75), (120, 64, 75, 75),
(120, 128, 38, 38), (120, 128, 38, 38)]
(120, 256, 38, 38)]
for i in range(len(tensor_sizes)): for i in range(len(tensor_sizes)):
tensor_size = tensor_sizes[i] tensor_size = tensor_sizes[i]
...@@ -103,7 +101,7 @@ class TestGroupBN(unittest.TestCase): ...@@ -103,7 +101,7 @@ class TestGroupBN(unittest.TestCase):
# Create models # Create models
batchnorm_model = Bn(num_channels, mode).cuda() batchnorm_model = Bn(num_channels, mode).cuda()
group_batchnorm = BatchNorm2d_NHWC(num_channels, fuse_relu=fuse_relu, bn_group=1).cuda() group_batchnorm = BatchNorm2d_NHWC(num_channels, fuse_relu=fuse_relu, bn_group=1,torch_channels_last=False).cuda()
# Run reference forward # Run reference forward
bn_output = batchnorm_model(input_data, residual_data) bn_output = batchnorm_model(input_data, residual_data)
......
import torch
import unittest
import numpy as np
import random
from apex.contrib.groupbn.batch_norm import BatchNorm2d_NHWC
def generate_uniform_tensor(size, np_dtype, pyt_dtype, device):
    """Draw a uniform [-1, 1) tensor of the given size on the given device.

    Samples with NumPy (np_dtype), then converts to a torch tensor of
    pyt_dtype.  Redraws until the sample is NaN-free — defensive, since
    np.random.uniform should never yield NaN, but it keeps the guarantee
    explicit.
    """
    samples = None
    while samples is None or np.isnan(samples).any():
        samples = np.random.uniform(low=-1.0, high=1.0, size=size).astype(np_dtype)
    return torch.from_numpy(samples).to(device).to(pyt_dtype)
def to_channels_last(tensor):
    """Return `tensor` converted to torch.channels_last (NHWC) memory format.

    Values and logical shape are unchanged; only the underlying strides differ.
    (Equivalent explicit form: tensor.permute(0, 2, 3, 1).contiguous() followed
    by a permute back — the memory_format API keeps the NCHW logical view.)
    """
    nhwc_tensor = tensor.to(memory_format=torch.channels_last)
    return nhwc_tensor
def to_channels_first(tensor):
    """Return `tensor` converted back to torch.contiguous_format (NCHW).

    Inverse of to_channels_last; values and logical shape are unchanged.
    """
    nchw_tensor = tensor.to(memory_format=torch.contiguous_format)
    return nchw_tensor
class Bn(torch.nn.BatchNorm2d):
    """Reference BatchNorm2d with optional fused residual-add and ReLU.

    `mode` selects the epilogue: 'bn' (plain batch norm), 'bn_relu'
    (batch norm + ReLU), or 'bn_add_relu' (batch norm + residual add + ReLU).
    Used as the ground-truth model the GBN tests compare against.
    """

    def __init__(self, planes, mode):
        super(Bn, self).__init__(planes, eps=1e-05, momentum=0.1, affine=True,
                                 track_running_stats=True)
        self.mode = mode

    def forward(self, x, z=None):
        normalized = super().forward(x)
        add_residual = self.mode == 'bn_add_relu'
        apply_relu = self.mode != 'bn'  # both 'bn_relu' and 'bn_add_relu'
        if add_residual:
            normalized = normalized.add_(z)  # in-place residual add
        if apply_relu:
            normalized = normalized.relu_()  # in-place ReLU
        return normalized
def bn_nhwc_bwd_ref(grad_y, x, mu, ivar, gamma):
    """Reference NHWC batch-norm backward pass (fp32 math, fp16 in/out for dx).

    Args:
        grad_y: upstream gradient, NCHW layout, half precision.
        x: forward-pass input, NCHW layout, half precision.
        mu: per-channel batch mean, float32, shape [C].
        ivar: per-channel inverse std 1/sqrt(var + eps), float32, shape [C].
        gamma: per-channel scale (weight), float32, shape [C].

    Returns:
        (dx, dscale, dbias): dx in NCHW half precision; dscale, dbias float32 [C].
    """
    # Work in NHWC so the channel dimension is last and broadcasts directly
    # against the per-channel statistics mu / ivar / gamma.
    grad_y = grad_y.permute(0, 2, 3, 1).contiguous()
    x = x.permute(0, 2, 3, 1).contiguous()
    sum_dim_c = (0, 1, 2)  # reduce over N, H, W -> one value per channel

    grad_y_f32 = grad_y.float()
    x_f32 = x.float()
    N = x.shape[0] * x.shape[1] * x.shape[2]  # nhw: elements per channel

    xmu = x_f32 - mu
    xhat = xmu * ivar  # normalized input
    dbias = torch.sum(grad_y_f32, dim=sum_dim_c)
    dscale = torch.sum(grad_y_f32 * xhat, dim=sum_dim_c)

    # Standard batch-norm input gradient:
    #   dx = (gamma * ivar / N) * (N * dy - sum(dy) - xhat * sum(dy * xhat))
    # Fix: the original multiplied dbias by an explicit ones tensor created with
    # device='cuda', which broke this reference on CPU tensors; broadcasting of
    # dbias over the NHWC grad is equivalent and device-agnostic.
    dx1 = (gamma * ivar) / N
    dx2 = (N * grad_y_f32) - dbias
    dx3 = -xhat * dscale
    dx23 = dx2 + dx3
    dx = dx1 * (dx23)

    dx = dx.half()
    # Return the input gradient in the caller's NCHW layout.
    dx = dx.permute(0, 3, 1, 2).contiguous()
    return dx, dscale, dbias
class TestGroupBNChannelLast(unittest.TestCase):
    """Compares GBN (BatchNorm2d_NHWC with torch_channels_last=True) against a
    reference torch.nn.BatchNorm2d forward and a hand-written NHWC backward,
    with channels_last-formatted inputs.  Requires CUDA and apex.
    """

    def setUp(self, seed=5, verbose=False):
        # Fix every RNG so GBN and the reference model see identical data.
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        np.random.seed(seed)
        self.verbose = verbose

    def test_bn_channel_last(self):
        self.run_group_bn_channel_last('bn')

    def test_bn_relu_channel_last(self):
        self.run_group_bn_channel_last('bn_relu')

    def test_bn_add_relu_channel_last(self):
        self.run_group_bn_channel_last('bn_add_relu')

    def run_group_bn_channel_last(self, mode):
        # Drives one forward+backward comparison per tensor size for the
        # given mode ('bn', 'bn_relu', or 'bn_add_relu').
        if self.verbose:
            print('Running {}'.format(mode))
        tensor_sizes = [
            (120, 64, 75, 75),
            (120, 128, 38, 38)]
        for i in range(len(tensor_sizes)):
            tensor_size = tensor_sizes[i]
            num_channels = tensor_size[1]
            # Create input data.  The input is round-tripped through a .npy
            # file so the reference model and GBN start from bitwise-identical
            # values in independent tensors.
            input_data = generate_uniform_tensor(tensor_size, np.float16, torch.half, 'cuda')
            np.save('input.npy', input_data.detach().cpu().numpy())
            input_data.requires_grad = True
            gbn_input = torch.from_numpy(np.load('input.npy')).cuda().half()
            gbn_input.requires_grad = True
            residual_data = None
            gbn_residual_data = None
            if mode == 'bn':
                fuse_relu = False
            else:
                fuse_relu = True
            if mode == 'bn_add_relu':
                # Residual branch; GBN consumes it in channels_last format.
                residual_data = generate_uniform_tensor(tensor_size, np.float16, torch.half, 'cuda')
                gbn_residual_data = to_channels_last(residual_data)
            bn_grad = generate_uniform_tensor(input_data.shape, np.float16, torch.half, 'cuda')
            # Create models
            batchnorm_model = Bn(num_channels, mode).cuda()
            group_batchnorm = BatchNorm2d_NHWC(num_channels, fuse_relu=fuse_relu, bn_group=1, torch_channels_last=True).cuda()
            # Run reference forward
            bn_output = batchnorm_model(input_data, residual_data)
            # Run GBN forward (channels_last input exercises the new
            # memory_format support).
            gbn_input_data = to_channels_last(gbn_input)
            #gbn_input_data = gbn_input
            gbn_output = group_batchnorm(gbn_input_data, gbn_residual_data)
            torch.cuda.synchronize()
            # Run reference backward
            # (Use the same input and parameters as GBN)
            gbn_grad = to_channels_last(bn_grad)
            #gbn_grad = bn_grad
            grad = gbn_grad.clone().detach()
            input_data = torch.from_numpy(np.load('input.npy')).cuda().half()
            input_data = to_channels_last(input_data)
            if mode != 'bn':
                # ReLU backward: zero the gradient where the activation was
                # clamped to zero in the forward pass.
                grad[gbn_output <= 0] = 0
            bn_output_grad, _, _ = bn_nhwc_bwd_ref( \
                grad,
                input_data,
                group_batchnorm.minibatch_mean,
                group_batchnorm.minibatch_riv,
                group_batchnorm.weight)
            bn_output_grad = to_channels_first(bn_output_grad)
            # Run GBN backward
            gbn_output.backward(gbn_grad)
            torch.cuda.synchronize()
            gbn_output = to_channels_first(gbn_output)
            gbn_output_grad = gbn_input.grad.detach().clone().cpu()
            ########################## Validate results ##########################
            if self.verbose:
                print('Validate activation')
            self.validate(bn_output.shape, bn_output, gbn_output)
            if self.verbose:
                print('Validate grad')
            self.validate(bn_output_grad.shape, bn_output_grad, gbn_output_grad, is_grad=True)

    def validate(self, tensors, output_ref, output_test, is_grad=False):
        # Passes when outputs are bitwise equal, allclose within 1e-3, or
        # (for gradients only) when the mean absolute difference is < 1e-4.
        output_ref = output_ref.detach().cpu().numpy()
        output_test = output_test.detach().cpu().numpy()
        if self.verbose:
            print('>>> tensor_size\t{}'.format(tensors))
            print("sum_output_ref {}, isnan {}, max {}, min {}".format(
                np.sum(output_ref, dtype=float), np.isnan(output_ref).any(), np.max(output_ref), np.min(output_ref)))
            print("sum_output_test {}, isnan {}, max {}, min {}".format(
                np.sum(output_test, dtype=float), np.isnan(output_test).any(), np.max(output_test), np.min(output_test)))
        ret = np.array_equal(output_ref, output_test)
        if not ret:
            ret_allclose = np.allclose(
                output_ref, output_test, rtol=1e-3, atol=1e-3, equal_nan=True)
            if self.verbose:
                print('{}\tshape {}\tidentical {}\tclose {}'.format('cpu/gpu', tensors, ret, ret_allclose))
        output_ref = output_ref.flatten()
        output_test = output_test.flatten()
        if not ret:
            # Diagnostics: absolute / relative error statistics for debugging.
            sub = np.absolute(output_ref - output_test)
            norm_diff = np.average(sub)
            rel = np.divide(sub, np.absolute(output_ref))
            rel[rel == np.inf] = 0
            max_abs_idx = np.argmax(sub)
            max_rel_idx = np.argmax(rel)
            if self.verbose:
                print('max_diff {}, max_rel_diff {}, norm_diff {}'.format(np.max(sub), np.max(rel), np.average(sub)))
                print('max_abs pair [{}] {} {}'.format(max_abs_idx, output_ref[max_abs_idx], output_test[max_abs_idx]))
                print('max_rel pair [{}] {} {}'.format(max_rel_idx, output_ref[max_rel_idx], output_test[max_rel_idx]))
        # NOTE: ret_allclose / norm_diff are only bound when ret is False;
        # short-circuit evaluation of `or` keeps this safe when ret is True.
        result = ret or ret_allclose or (is_grad and norm_diff < 1e-4)
        if self.verbose:
            print("Result {}".format("PASS" if result else "FAIL"))
        self.assertTrue(result)
if __name__ == '__main__':
    # Run the channels_last GBN test suite when executed directly.
    unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment