Commit 1b9b65ca authored by Thor Johnsen's avatar Thor Johnsen Committed by mcarilli

[WIP] Fused layer norm cuda (#69)

* Pre-release of fused layer norm apex extension

* Remove half and __half2 specializations

* Code changes from review
parent 45f030db
@@ -101,6 +101,7 @@ python setup.py install [--cuda_ext] [--cpp_ext]
Currently, `--cuda_ext` enables
- Fused kernels that improve the performance and numerical stability of `apex.parallel.SyncBatchNorm`.
- Fused kernels required to use `apex.optimizers.FusedAdam`.
- Fused kernels required to use `apex.normalization.FusedLayerNorm`.
`--cpp_ext` enables
- C++-side flattening and unflattening utilities that reduce the CPU overhead of `apex.parallel.DistributedDataParallel`.
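
For example, a minimal usage sketch of the fused layer norm after building with `--cuda_ext` (shapes are illustrative; the module and input must live on the GPU, since `FusedLayerNorm` currently only runs on CUDA tensors):

    import torch
    import apex

    input = torch.randn(20, 5, 10, 10).cuda()
    m = apex.normalization.FusedLayerNorm([10, 10]).cuda()
    output = m(input)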
@@ -7,3 +7,8 @@ try:
    from . import optimizers
except ImportError:
    print("Warning: apex was installed without --cuda_ext. FusedAdam will be unavailable.")
try:
    from . import normalization
except ImportError:
    print("Warning: apex was installed without --cuda_ext. FusedLayerNorm will be unavailable.")
from .fused_layer_norm import FusedLayerNorm
#include <torch/extension.h>
#include <vector>
#include <cassert>
#include <sstream>

namespace {

void compute_n1_n2(
    at::Tensor input,
    at::IntList normalized_shape,
    int& n1,
    int& n2)
{
  int idiff = input.ndimension() - normalized_shape.size();
  n2 = 1;
  for (int i = 0; i < (int)normalized_shape.size(); ++i) {
    assert( input.sizes()[i+idiff] == normalized_shape[i] );
    n2 *= normalized_shape[i];
  }
  n1 = 1;
  for (int i = 0; i < idiff; ++i) {
    n1 *= input.sizes()[i];
  }
}
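
// Illustrative note (not in the original source): for an input of shape
// [20, 5, 10, 10] and normalized_shape = [10, 10], compute_n1_n2 above yields
// n1 = 20 * 5 = 100 independent rows and n2 = 10 * 10 = 100 elements per row.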
void check_args(
    at::IntList normalized_shape,
    at::Tensor gamma,
    at::Tensor beta
    )
{
  AT_CHECK(!gamma.defined() || gamma.sizes().equals(normalized_shape));
  AT_CHECK(!beta.defined() || beta.sizes().equals(normalized_shape));
}
void check_args(
    at::Tensor input,
    at::IntList normalized_shape,
    int& n1,
    int& n2
    )
{
  int64_t normalized_ndim = normalized_shape.size();

  if (normalized_ndim < 1) {
    std::stringstream ss;
    ss << "Expected normalized_shape to be at least 1-dimensional, i.e., "
       << "containing at least one element, but got normalized_shape="
       << normalized_shape;
    throw std::runtime_error(ss.str());
  }

  auto input_shape = input.sizes();
  auto input_ndim = input.dim();

  if (input_ndim < normalized_ndim ||
      !input_shape.slice(input_ndim - normalized_ndim).equals(normalized_shape)) {
    std::stringstream ss;
    ss << "Given normalized_shape=" << normalized_shape
       << ", expected input with shape [*";
    for (auto size : normalized_shape) {
      ss << ", " << size;
    }
    ss << "], but got input of size" << input_shape;
    throw std::runtime_error(ss.str());
  }

  compute_n1_n2(input,normalized_shape,n1,n2);
}
void check_args(
    at::Tensor input,
    at::IntList normalized_shape,
    at::Tensor gamma,
    at::Tensor beta,
    int& n1,
    int& n2
    )
{
  check_args(input,normalized_shape,n1,n2);
  check_args(normalized_shape,gamma,beta);
}

}  // anonymous namespace
void cuda_layer_norm(
    at::Tensor* output,
    at::Tensor* mean,
    at::Tensor* invvar,
    at::Tensor* input,
    int n1,
    int n2,
    at::IntList normalized_shape,
    at::Tensor* gamma,
    at::Tensor* beta,
    double epsilon);

#define CHECK_CUDA(x) AT_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
std::vector<at::Tensor> layer_norm(
    at::Tensor input,
    at::IntList normalized_shape,
    double epsilon) {
  CHECK_INPUT(input);
  int n1,n2;
  check_args(input,normalized_shape,n1,n2);
  at::Tensor output = at::empty_like(input);
  // Statistics are stored in fp32 even for half-precision inputs, for numerical stability.
  at::Tensor mean = at::empty({n1}, input.options().dtype(
      input.type().scalarType()==at::ScalarType::Half ? at::ScalarType::Float : input.type().scalarType()));
  at::Tensor invvar = at::empty_like(mean);
  cuda_layer_norm(&output,&mean,&invvar,&input,n1,n2,
      normalized_shape,NULL,NULL,epsilon);
  return {output, mean, invvar};
}
std::vector<at::Tensor> layer_norm_affine(
    at::Tensor input,
    at::IntList normalized_shape,
    at::Tensor gamma,
    at::Tensor beta,
    double epsilon) {
  CHECK_INPUT(input);
  CHECK_INPUT(gamma);
  CHECK_INPUT(beta);
  int n1,n2;
  check_args(input,normalized_shape,gamma,beta,n1,n2);
  at::Tensor output = at::empty_like(input);
  at::Tensor mean = at::empty({n1}, input.options().dtype(
      input.type().scalarType()==at::ScalarType::Half ? at::ScalarType::Float : input.type().scalarType()));
  at::Tensor invvar = at::empty_like(mean);
  cuda_layer_norm(&output,&mean,&invvar,&input,n1,n2,
      normalized_shape,&gamma,&beta,epsilon);
  return {output, mean, invvar};
}
void cuda_layer_norm_gradient(
    at::Tensor* dout,
    at::Tensor* mean,
    at::Tensor* invvar,
    at::Tensor* input,
    int n1,
    int n2,
    at::IntList normalized_shape,
    at::Tensor* gamma,
    at::Tensor* beta,
    double epsilon,
    at::Tensor* grad_input,
    at::Tensor* grad_gamma,
    at::Tensor* grad_beta
    );
at::Tensor layer_norm_gradient(
    at::Tensor dout,
    at::Tensor mean,
    at::Tensor invvar,
    at::Tensor input,
    at::IntList normalized_shape,
    double epsilon) {
  CHECK_INPUT(dout);
  CHECK_INPUT(mean);
  CHECK_INPUT(invvar);
  CHECK_INPUT(input);
  int n1,n2;
  check_args(input,normalized_shape,n1,n2);
  at::Tensor grad_input = at::empty_like(input);
  cuda_layer_norm_gradient(&dout,&mean,&invvar,&input,n1,n2,
      normalized_shape,NULL,NULL,epsilon,
      &grad_input,NULL,NULL);
  return grad_input;
}
std::vector<at::Tensor> layer_norm_gradient_affine(
    at::Tensor dout,
    at::Tensor mean,
    at::Tensor invvar,
    at::Tensor input,
    at::IntList normalized_shape,
    at::Tensor gamma,
    at::Tensor beta,
    double epsilon) {
  CHECK_INPUT(dout);
  CHECK_INPUT(mean);
  CHECK_INPUT(invvar);
  CHECK_INPUT(input);
  CHECK_INPUT(gamma);
  CHECK_INPUT(beta);
  int n1,n2;
  check_args(input,normalized_shape,gamma,beta,n1,n2);
  at::Tensor grad_input = at::empty_like(input);
  at::Tensor grad_gamma = at::empty_like(gamma);
  at::Tensor grad_beta = at::empty_like(beta);
  cuda_layer_norm_gradient(&dout,&mean,&invvar,&input,n1,n2,
      normalized_shape,&gamma,&beta,epsilon,
      &grad_input,&grad_gamma,&grad_beta);
  return {grad_input, grad_gamma, grad_beta};
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward_affine", &layer_norm_affine, "LayerNorm forward (CUDA)");
  m.def("forward", &layer_norm, "LayerNorm forward (CUDA)");
  m.def("backward_affine", &layer_norm_gradient_affine, "LayerNorm backward (CUDA)");
  m.def("backward", &layer_norm_gradient, "LayerNorm backward (CUDA)");
}
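
For reference, a minimal sketch of how the compiled `fused_layer_norm_cuda` bindings above might be called directly from Python; it mirrors what `FusedLayerNormAffineFunction` in `fused_layer_norm.py` does, and the shapes and values here are illustrative only:

    import torch
    import fused_layer_norm_cuda

    x = torch.randn(20, 5, 10, 10, device='cuda')
    weight = torch.ones(10, 10, device='cuda')
    bias = torch.zeros(10, 10, device='cuda')
    # forward_affine returns the normalized output plus the per-row mean and
    # invvar statistics that the backward bindings reuse.
    output, mean, invvar = fused_layer_norm_cuda.forward_affine(
        x, torch.Size([10, 10]), weight, bias, 1e-6)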
This diff is collapsed.
import math
import torch
import numbers
from torch.nn.parameter import Parameter
from torch.nn import init
import fused_layer_norm_cuda
class FusedLayerNormAffineFunction(torch.autograd.Function):
    def __init__(self, normalized_shape, eps=1e-6):
        self.normalized_shape = normalized_shape
        self.eps = eps

    def forward(self, input, weight, bias):
        input_ = input.contiguous()
        weight_ = weight.contiguous()
        bias_ = bias.contiguous()
        output, mean, invvar = fused_layer_norm_cuda.forward_affine(
            input_, self.normalized_shape, weight_, bias_, self.eps)
        self.save_for_backward(input_, weight_, bias_, mean, invvar)
        return output

    def backward(self, grad_output):
        input_, weight_, bias_, mean, invvar = self.saved_tensors
        grad_input = grad_weight = grad_bias = None
        grad_input, grad_weight, grad_bias = fused_layer_norm_cuda.backward_affine(
            grad_output.contiguous(), mean, invvar,
            input_, self.normalized_shape,
            weight_, bias_, self.eps)
        return grad_input, grad_weight, grad_bias
class FusedLayerNormFunction(torch.autograd.Function):
    def __init__(self, normalized_shape, eps=1e-6):
        self.normalized_shape = normalized_shape
        self.eps = eps

    def forward(self, input):
        input_ = input.contiguous()
        output, mean, invvar = fused_layer_norm_cuda.forward(
            input_, self.normalized_shape, self.eps)
        self.save_for_backward(input_, mean, invvar)
        return output

    def backward(self, grad_output):
        input_, mean, invvar = self.saved_tensors
        grad_input = None
        grad_input = fused_layer_norm_cuda.backward(
            grad_output.contiguous(), mean, invvar,
            input_, self.normalized_shape,
            self.eps)
        return grad_input
def fused_layer_norm_affine(input, normalized_shape, weight, bias, eps=1e-6):
    return FusedLayerNormAffineFunction(normalized_shape,eps)(input, weight, bias)

def fused_layer_norm(input, normalized_shape, eps=1e-6):
    return FusedLayerNormFunction(normalized_shape,eps)(input)
class FusedLayerNorm(torch.nn.Module):
    r"""Applies Layer Normalization over a mini-batch of inputs as described in
    the paper `Layer Normalization`_ .

    Currently only runs on cuda() tensors.

    .. math::
        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    The mean and standard-deviation are calculated separately over the last
    certain number of dimensions, which have to be of the shape specified by
    :attr:`normalized_shape`.
    :math:`\gamma` and :math:`\beta` are learnable affine transform parameters of
    :attr:`normalized_shape` if :attr:`elementwise_affine` is ``True``.

    .. note::
        Unlike Batch Normalization and Instance Normalization, which apply
        scalar scale and bias for each entire channel/plane with the
        :attr:`affine` option, Layer Normalization applies per-element scale and
        bias with :attr:`elementwise_affine`.

    This layer uses statistics computed from input data in both training and
    evaluation modes.

    Args:
        normalized_shape (int or list or torch.Size): input shape from an expected input
            of size

            .. math::
                [* \times \text{normalized\_shape}[0] \times \text{normalized\_shape}[1]
                    \times \ldots \times \text{normalized\_shape}[-1]]

            If a single integer is used, it is treated as a singleton list, and this module will
            normalize over the last dimension which is expected to be of that specific size.
        eps: a value added to the denominator for numerical stability. Default: 1e-5
        elementwise_affine: a boolean value that when set to ``True``, this module
            has learnable per-element affine parameters initialized to ones (for weights)
            and zeros (for biases). Default: ``True``.

    Shape:
        - Input: :math:`(N, *)`
        - Output: :math:`(N, *)` (same shape as input)

    Examples::

        >>> input = torch.randn(20, 5, 10, 10).cuda()
        >>> # With Learnable Parameters
        >>> m = apex.normalization.FusedLayerNorm(input.size()[1:]).cuda()
        >>> # Without Learnable Parameters
        >>> m = apex.normalization.FusedLayerNorm(input.size()[1:], elementwise_affine=False).cuda()
        >>> # Normalize over last two dimensions
        >>> m = apex.normalization.FusedLayerNorm([10, 10]).cuda()
        >>> # Normalize over last dimension of size 10
        >>> m = apex.normalization.FusedLayerNorm(10).cuda()
        >>> # Activating the module
        >>> output = m(input)

    .. _`Layer Normalization`: https://arxiv.org/abs/1607.06450
    """
    def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True):
        super(FusedLayerNorm, self).__init__()
        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        self.normalized_shape = torch.Size(normalized_shape)
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        if self.elementwise_affine:
            self.weight = Parameter(torch.Tensor(*normalized_shape))
            self.bias = Parameter(torch.Tensor(*normalized_shape))
        else:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        if self.elementwise_affine:
            init.ones_(self.weight)
            init.zeros_(self.bias)

    def forward(self, input):
        if self.elementwise_affine:
            return FusedLayerNormAffineFunction(self.normalized_shape,self.eps)(
                input, self.weight, self.bias)
        else:
            return FusedLayerNormFunction(self.normalized_shape,self.eps)(
                input)

    def extra_repr(self):
        return '{normalized_shape}, eps={eps}, ' \
            'elementwise_affine={elementwise_affine}'.format(**self.__dict__)
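
A quick sanity check one might run after installation; this is a sketch that assumes a CUDA build of apex and compares against the reference `torch.nn.LayerNorm`, with illustrative shapes and tolerance:

    import torch
    import apex

    x = torch.randn(20, 5, 10, 10, device='cuda')
    fused = apex.normalization.FusedLayerNorm([10, 10]).cuda()
    ref = torch.nn.LayerNorm([10, 10]).cuda()
    # Both modules default to eps=1e-5 with weights initialized to ones and
    # biases to zeros, so outputs should agree to within floating-point noise.
    print(torch.allclose(fused(x), ref(x), atol=1e-4))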
@@ -43,7 +43,14 @@ if "--cuda_ext" in sys.argv:
        CUDAExtension(name='syncbn',
                      sources=['csrc/syncbn.cpp',
                               'csrc/welford.cu']))
    ext_modules.append(
        CUDAExtension(name='fused_layer_norm_cuda',
                      sources=['apex/normalization/csrc/layer_norm_cuda.cpp',
                               'apex/normalization/csrc/layer_norm_cuda_kernel.cu'],
                      extra_compile_args={'cxx': ['-O3',],
                                          'nvcc':['-maxrregcount=50',
                                                  '-O3',
                                                  '--use_fast_math']}))
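# Note (illustrative, matching the README): the fused layer norm extension is
# only built when installing with the --cuda_ext flag, e.g.
#   python setup.py install --cuda_ext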
setup(
    name='apex',