Added support for memory format API(torch.channels_last) in GBN (#72)
* Added suuport for memory format API(torch.channels_last) in GBN
Group Batch Norm (GBN) is an NHWC operation. It assumes that the
underlying memory format of an input tensor is NHWC. It originally does
not support PyTorch's memory_format API.
To support PyTorch's memory_format API, i.e., .to(memory_format=...) or
.contiguous(memory_format=...), we add the torch_channels_last
flag to indicate whether the workload adopts the PyTorch memory_format
API by setting memory_format=torch.channels_last. This flag allows GBN
to handle memory formats of input tensors properly.
An example to use memory_format in GBN:
"""
from apex.contrib.groupbn.batch_norm import BatchNorm2d_NHWC
GBN = BatchNorm2d_NHWC(planes, fuse_relu=True, bn_group=1, torch_channels_last=True)
"""
The cases that GBN handles are as follows:
1. torch_channels_last=True and input tensor's
memory_format=torch.channels_last, GBN will generate the
torch.channels_last output tensor.
2. torch_channels_last=True and input tensor's
memory_format=torch.contiguous_format, GBN will convert the input tensor
to torch.channels_last and will generate the torch.channels_last output
tensor.
3. use_pytorch_channels_last=False and input tensor's
memory_format=torch.contiguous_format, GBN will generate the
torch.contiguous_format output tensor.
* Add GBN unit tests for channel_last memory format
Co-authored-by:
hubertlu-tw <hubertlu@amd.com>
Showing
Please register or sign in to comment