Unverified commit 684c4733 authored by eqy, committed by GitHub

FusedRMSNorm/"T5LayerNorm" based on FusedLayerNorm (#1274)



* FusedRMSNorm based on FusedLayerNorm

* refactor duplicated kernels

* delete comments

* delete comments

* cleanup

* cleanup

* cleanup, fixed clobbering forward_affine_mixed_dtypes

* fix pybind naming and add MixedFused test

* undo skipping

* check elementwise_affine

* Update tests/L0/run_fused_layer_norm/test_fused_layer_norm.py

Oof, nice catch, thanks
Co-authored-by: Masaki Kozuki <masaki.kozuki.2014@gmail.com>
parent 89edb819
from .fused_layer_norm import FusedLayerNorm, MixedFusedLayerNorm, FusedRMSNorm, MixedFusedRMSNorm
@@ -12,6 +12,23 @@ global fused_layer_norm_cuda
fused_layer_norm_cuda = None
# Reference implementation from Huggingface
def manual_rms_norm(input, normalized_shape, weight, eps):
    # RMS statistics should always be calculated in float32
    dims = tuple(i for i in range(-1, -len(normalized_shape) - 1, -1))
    variance = input.to(torch.float32).pow(2).mean(dims, keepdim=True)
    input = input * torch.rsqrt(variance + eps)

    if weight is None:
        return input

    # convert back into half-precision if necessary
    if weight.dtype in [torch.float16, torch.bfloat16]:
        input = input.to(weight.dtype)

    return weight * input
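
This reference path is what FusedRMSNorm falls back to for CPU tensors. As a purely illustrative sanity check (not part of the patch; shapes and eps are arbitrary), it can be compared against the closed-form y = x / sqrt(mean(x^2) + eps) * weight:

import torch
from apex.normalization.fused_layer_norm import manual_rms_norm

x = torch.randn(4, 32, 16)
weight = torch.ones(32, 16)
eps = 1e-5

# mean of x^2 over the normalized dimensions, then scale by the reciprocal RMS
expected = x * torch.rsqrt(x.pow(2).mean((-2, -1), keepdim=True) + eps) * weight
torch.testing.assert_allclose(manual_rms_norm(x, (32, 16), weight, eps), expected)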
class FusedLayerNormAffineFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, weight, bias, normalized_shape, eps):
@@ -39,6 +56,31 @@ class FusedLayerNormAffineFunction(torch.autograd.Function):
        return grad_input, grad_weight, grad_bias, None, None
class FusedRMSNormAffineFunction(torch.autograd.Function):

    @staticmethod
    def forward(ctx, input, weight, normalized_shape, eps):
        global fused_layer_norm_cuda
        if fused_layer_norm_cuda is None:
            fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda")
        ctx.normalized_shape = normalized_shape
        ctx.eps = eps
        input_ = input.contiguous()
        weight_ = weight.contiguous()
        output, invvar = fused_layer_norm_cuda.rms_forward_affine(
            input_, ctx.normalized_shape, weight_, ctx.eps)
        ctx.save_for_backward(input_, weight_, invvar)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input_, weight_, invvar = ctx.saved_tensors
        grad_input = grad_weight = None
        grad_input, grad_weight = fused_layer_norm_cuda.rms_backward_affine(
            grad_output.contiguous(), invvar, input_, ctx.normalized_shape, weight_, ctx.eps
        )
        return grad_input, grad_weight, None, None
class FusedLayerNormAffineMixedDtypesFunction(FusedLayerNormAffineFunction):
    @staticmethod
@@ -58,6 +100,25 @@ class FusedLayerNormAffineMixedDtypesFunction(FusedLayerNormAffineFunction):
        return output
class FusedRMSNormAffineMixedDtypesFunction(FusedRMSNormAffineFunction):

    @staticmethod
    def forward(ctx, input, weight, normalized_shape, eps):
        global fused_layer_norm_cuda
        if fused_layer_norm_cuda is None:
            fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda")
        ctx.normalized_shape = normalized_shape
        ctx.eps = eps
        input_ = input.contiguous()
        weight_ = weight.contiguous()
        output, invvar = fused_layer_norm_cuda.rms_forward_affine_mixed_dtypes(
            input_, ctx.normalized_shape, weight_, ctx.eps
        )
        ctx.save_for_backward(input_, weight_, invvar)
        return output
class FusedLayerNormFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, normalized_shape, eps):
@@ -81,6 +142,29 @@ class FusedLayerNormFunction(torch.autograd.Function):
        return grad_input, None, None
class FusedRMSNormFunction(torch.autograd.Function):

    @staticmethod
    def forward(ctx, input, normalized_shape, eps):
        global fused_layer_norm_cuda
        if fused_layer_norm_cuda is None:
            fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda")
        ctx.normalized_shape = normalized_shape
        ctx.eps = eps
        input_ = input.contiguous()
        output, invvar = fused_layer_norm_cuda.rms_forward(input_, ctx.normalized_shape, ctx.eps)
        ctx.save_for_backward(input_, invvar)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input_, invvar = ctx.saved_tensors
        grad_input = None
        grad_input = fused_layer_norm_cuda.rms_backward(
            grad_output.contiguous(), invvar, input_, ctx.normalized_shape, ctx.eps
        )
        return grad_input, None, None
def fused_layer_norm_affine(input, weight, bias, normalized_shape, eps=1e-6):
    args = _cast_if_autocast_enabled(input, weight, bias, normalized_shape, eps)
    with torch.cuda.amp.autocast(enabled=False):
@@ -99,6 +183,24 @@ def mixed_dtype_fused_layer_norm_affine(input, weight, bias, normalized_shape, e
        return FusedLayerNormAffineMixedDtypesFunction.apply(*args)
def fused_rms_norm_affine(input, weight, normalized_shape, eps=1e-6):
    args = _cast_if_autocast_enabled(input, weight, normalized_shape, eps)
    with torch.cuda.amp.autocast(enabled=False):
        return FusedRMSNormAffineFunction.apply(*args)


def fused_rms_norm(input, normalized_shape, eps=1e-6):
    args = _cast_if_autocast_enabled(input, normalized_shape, eps)
    with torch.cuda.amp.autocast(enabled=False):
        return FusedRMSNormFunction.apply(*args)


def mixed_dtype_fused_rms_norm_affine(input, weight, normalized_shape, eps=1e-6):
    args = _cast_if_autocast_enabled(input, weight, normalized_shape, eps)
    with torch.cuda.amp.autocast(enabled=False):
        return FusedRMSNormAffineMixedDtypesFunction.apply(*args)
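
These wrappers cast their arguments up front via _cast_if_autocast_enabled and then disable autocast around the autograd Function, so they can be called directly inside an autocast region. A minimal usage sketch (not part of the patch; shapes, dtype, and eps are illustrative, and apex is assumed to be built with the CUDA extensions):

import torch
from apex.normalization.fused_layer_norm import fused_rms_norm, fused_rms_norm_affine

x = torch.randn(8, 32, 16, device="cuda", requires_grad=True)
weight = torch.ones(32, 16, device="cuda", requires_grad=True)

with torch.cuda.amp.autocast(dtype=torch.float16):
    # affine variant: learnable per-element scale
    y = fused_rms_norm_affine(x, weight, normalized_shape=(32, 16), eps=1e-6)
    # non-affine variant: normalization only
    y_no_affine = fused_rms_norm(x, normalized_shape=(32, 16), eps=1e-6)

y.float().sum().backward()  # gradients flow to both x and weight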
class FusedLayerNorm(torch.nn.Module):
    r"""Applies Layer Normalization over a mini-batch of inputs as described in
    the paper `Layer Normalization`_ .
@@ -195,6 +297,99 @@ class FusedLayerNorm(torch.nn.Module):
        return "{normalized_shape}, eps={eps}, " "elementwise_affine={elementwise_affine}".format(**self.__dict__)
class FusedRMSNorm(torch.nn.Module):
    r"""Applies RMS Normalization over a mini-batch of inputs

    Currently only runs on cuda() tensors.

    .. math::
        y = \frac{x}{\sqrt{\mathrm{E}[x^2] + \epsilon}} * \gamma

    The root-mean-square statistic is calculated separately over the last
    certain number of dimensions, which have to be of the shape specified by
    :attr:`normalized_shape`.
    :math:`\gamma` is a learnable affine transform parameter of
    :attr:`normalized_shape` if :attr:`elementwise_affine` is ``True``.

    .. note::
        Unlike Batch Normalization and Instance Normalization, which apply
        scalar scale and bias for each entire channel/plane with the
        :attr:`affine` option, RMS Normalization applies a per-element scale
        with :attr:`elementwise_affine`.

    This layer uses statistics computed from input data in both training and
    evaluation modes.

    Args:
        normalized_shape (int or list or torch.Size): input shape from an expected input
            of size

            .. math::
                [* \times \text{normalized}\_\text{shape}[0] \times \text{normalized}\_\text{shape}[1]
                    \times \ldots \times \text{normalized}\_\text{shape}[-1]]

            If a single integer is used, it is treated as a singleton list, and this module will
            normalize over the last dimension which is expected to be of that specific size.
        eps: a value added to the denominator for numerical stability. Default: 1e-5
        elementwise_affine: a boolean value that when set to ``True``, this module
            has learnable per-element affine parameters initialized to ones. Default: ``True``.

    Shape:
        - Input: :math:`(N, *)`
        - Output: :math:`(N, *)` (same shape as input)

    Examples::

        >>> input = torch.randn(20, 5, 10, 10)
        >>> # With Learnable Parameters
        >>> m = apex.normalization.FusedRMSNorm(input.size()[1:])
        >>> # Without Learnable Parameters
        >>> m = apex.normalization.FusedRMSNorm(input.size()[1:], elementwise_affine=False)
        >>> # Normalize over last two dimensions
        >>> m = apex.normalization.FusedRMSNorm([10, 10])
        >>> # Normalize over last dimension of size 10
        >>> m = apex.normalization.FusedRMSNorm(10)
        >>> # Activating the module
        >>> output = m(input)

    .. _`Layer Normalization`: https://arxiv.org/abs/1607.06450
    """

    def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True):
        super().__init__()

        global fused_layer_norm_cuda
        fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda")

        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        self.normalized_shape = torch.Size(normalized_shape)
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        if self.elementwise_affine:
            self.weight = Parameter(torch.Tensor(*normalized_shape))
        else:
            self.register_parameter("weight", None)
        self.reset_parameters()

    def reset_parameters(self):
        if self.elementwise_affine:
            init.ones_(self.weight)

    def forward(self, input):
        if not input.is_cuda:
            return manual_rms_norm(input, self.normalized_shape, self.weight, self.eps)

        if self.elementwise_affine:
            return fused_rms_norm_affine(input, self.weight, self.normalized_shape, self.eps)
        else:
            return fused_rms_norm(input, self.normalized_shape, self.eps)

    def extra_repr(self):
        return "{normalized_shape}, eps={eps}, " "elementwise_affine={elementwise_affine}".format(**self.__dict__)
# NOTE (mkozuki): Why "mixed"?
# MixedFusedLayerNorm differs from FusedLayerNorm in that this layer norm uses parameter's dtype
# as output tensor's dtype while FusedLayerNorm uses input tensor's dtype for output tensor's dtype.
@@ -216,3 +411,26 @@ class MixedFusedLayerNorm(FusedLayerNorm):
        if not input.is_cuda:
            return F.layer_norm(input, self.normalized_shape, self.weight, self.bias, self.eps)
        return mixed_dtype_fused_layer_norm_affine(input, self.weight, self.bias, self.normalized_shape, self.eps)
# MixedFusedRMSNorm differs from FusedRMSNorm in that this norm uses the parameter's dtype
# as the output tensor's dtype while FusedRMSNorm uses the input tensor's dtype for the output tensor's dtype.
# See: `rms_norm_affine` and `rms_norm_affine_mixed_dtypes` in "csrc/layer_norm_cuda.cpp"
class MixedFusedRMSNorm(FusedRMSNorm):

    def __init__(self, normalized_shape, eps=1e-5, **kwargs):
        if "elementwise_affine" in kwargs:
            import warnings
            warnings.warn("MixedFusedRMSNorm does not support `elementwise_affine` argument")
            elementwise_affine = kwargs.pop("elementwise_affine")
            if not elementwise_affine:
                raise RuntimeError("MixedFusedRMSNorm does not support `elementwise_affine = False`")

        super().__init__(normalized_shape=normalized_shape, eps=eps, elementwise_affine=True)

    def forward(self, input: torch.Tensor):
        # NOTE (mkozuki): CPU path is here mainly for unittest sake.
        if not input.is_cuda:
            return manual_rms_norm(input, self.normalized_shape, self.weight, self.eps)
        return mixed_dtype_fused_rms_norm_affine(input, self.weight, self.normalized_shape, self.eps)
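
To make the dtype contract above concrete, a small illustrative sketch (not part of the patch; shapes are arbitrary) of the documented behavior with float16 activations and float32 parameters:

import torch
from apex.normalization import FusedRMSNorm, MixedFusedRMSNorm

x = torch.randn(8, 32, 16, device="cuda", dtype=torch.half)

# FusedRMSNorm follows the input dtype: with half input and half parameters,
# the output comes back as half.
fused = FusedRMSNorm([32, 16]).to(device="cuda", dtype=torch.half)
assert fused(x).dtype == torch.half

# MixedFusedRMSNorm follows the parameter dtype: a float32 weight with half
# input yields a float32 output.
mixed = MixedFusedRMSNorm([32, 16]).cuda()  # weight stays float32
assert mixed(x).dtype == torch.float32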
@@ -40,6 +40,19 @@ void check_args(
    TORCH_CHECK(!beta.defined() || beta.sizes().equals(normalized_shape));
}
void check_args(
#ifdef VERSION_GE_1_1
    at::IntArrayRef normalized_shape,
#else
    at::IntList normalized_shape,
#endif
    at::Tensor gamma
    )
{
    TORCH_CHECK(!gamma.defined() || gamma.sizes().equals(normalized_shape));
}
void check_args(
    at::Tensor input,
#ifdef VERSION_GE_1_1
@@ -79,7 +92,6 @@ void check_args(
    compute_n1_n2(input,normalized_shape,n1,n2);
}
void check_args(
    at::Tensor input,
#ifdef VERSION_GE_1_1
@@ -96,6 +108,22 @@ void check_args(
    check_args(input,normalized_shape,n1,n2);
    check_args(normalized_shape,gamma,beta);
}
void check_args(
    at::Tensor input,
#ifdef VERSION_GE_1_1
    at::IntArrayRef normalized_shape,
#else
    at::IntList normalized_shape,
#endif
    at::Tensor gamma,
    int& n1,
    int& n2
    )
{
    check_args(input,normalized_shape,n1,n2);
    check_args(normalized_shape,gamma);
}
}

void cuda_layer_norm(
@@ -256,6 +284,147 @@ std::vector<at::Tensor> layer_norm_gradient_affine(
    return {grad_input, grad_gamma, grad_beta};
}
void cuda_rms_norm(
    at::Tensor* output,
    at::Tensor* invvar,
    at::Tensor* input,
    int n1,
    int n2,
#ifdef VERSION_GE_1_1
    at::IntArrayRef normalized_shape,
#else
    at::IntList normalized_shape,
#endif
    at::Tensor* gamma,
    double epsilon);

#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

std::vector<at::Tensor> rms_norm(
    at::Tensor input,
#ifdef VERSION_GE_1_1
    at::IntArrayRef normalized_shape,
#else
    at::IntList normalized_shape,
#endif
    double epsilon) {
    CHECK_INPUT(input);
    int n1, n2;
    check_args(input, normalized_shape, n1, n2);
    at::Tensor output = at::empty_like(input);
    at::Tensor invvar = at::empty({n1}, input.options().dtype(input.scalar_type()==at::ScalarType::Half || input.scalar_type()==at::ScalarType::BFloat16 ? at::ScalarType::Float : input.scalar_type()));
    cuda_rms_norm(&output, &invvar, &input, n1, n2,
        normalized_shape, NULL, epsilon);
    return {output, invvar};
}
std::vector<at::Tensor> rms_norm_affine(
    at::Tensor input,
#ifdef VERSION_GE_1_1
    at::IntArrayRef normalized_shape,
#else
    at::IntList normalized_shape,
#endif
    at::Tensor gamma,
    double epsilon) {
    CHECK_INPUT(input);
    CHECK_INPUT(gamma);
    int n1, n2;
    check_args(input, normalized_shape, gamma, n1, n2);
    at::Tensor output = at::empty_like(input);
    const auto stats_dtype = (input.scalar_type() == at::ScalarType::Half || input.scalar_type() == at::ScalarType::BFloat16) ? at::ScalarType::Float : input.scalar_type();
    at::Tensor invvar = at::empty({n1}, input.options().dtype(stats_dtype));
    cuda_rms_norm(&output, &invvar, &input, n1, n2,
        normalized_shape, &gamma, epsilon);
    return {output, invvar};
}
std::vector<at::Tensor> rms_norm_affine_mixed_dtypes(
    at::Tensor input,
#ifdef VERSION_GE_1_1
    at::IntArrayRef normalized_shape,
#else
    at::IntList normalized_shape,
#endif
    at::Tensor gamma,
    double epsilon) {
    CHECK_INPUT(input);
    int n1, n2;
    check_args(input, normalized_shape, n1, n2);
    at::Tensor output = at::empty_like(input, gamma.options().dtype(gamma.scalar_type()));
    at::Tensor invvar = at::empty({n1}, input.options().dtype(input.scalar_type() == at::ScalarType::Half || input.scalar_type() == at::ScalarType::BFloat16 ? at::ScalarType::Float : input.scalar_type()));
    cuda_rms_norm(&output, &invvar, &input, n1, n2,
        normalized_shape, &gamma, epsilon);
    return {output, invvar};
}
void cuda_rms_norm_gradient(
    at::Tensor* dout,
    at::Tensor* invvar,
    at::Tensor* input,
    int n1,
    int n2,
#ifdef VERSION_GE_1_1
    at::IntArrayRef normalized_shape,
#else
    at::IntList normalized_shape,
#endif
    at::Tensor* gamma,
    double epsilon,
    at::Tensor* grad_input,
    at::Tensor* grad_gamma);

at::Tensor rms_norm_gradient(
    at::Tensor dout,
    at::Tensor invvar,
    at::Tensor input,
#ifdef VERSION_GE_1_1
    at::IntArrayRef normalized_shape,
#else
    at::IntList normalized_shape,
#endif
    double epsilon) {
    CHECK_INPUT(dout);
    CHECK_INPUT(invvar);
    CHECK_INPUT(input);
    int n1, n2;
    check_args(input, normalized_shape, n1, n2);
    at::Tensor grad_input = at::empty_like(input);
    cuda_rms_norm_gradient(&dout, &invvar, &input, n1, n2,
        normalized_shape, NULL, epsilon,
        &grad_input, NULL);
    return grad_input;
}
std::vector<at::Tensor> rms_norm_gradient_affine(
    at::Tensor dout,
    at::Tensor invvar,
    at::Tensor input,
#ifdef VERSION_GE_1_1
    at::IntArrayRef normalized_shape,
#else
    at::IntList normalized_shape,
#endif
    at::Tensor gamma,
    double epsilon) {
    CHECK_INPUT(dout);
    CHECK_INPUT(invvar);
    CHECK_INPUT(input);
    CHECK_INPUT(gamma);
    int n1, n2;
    check_args(input, normalized_shape, gamma, n1, n2);
    at::Tensor grad_input = at::empty_like(input);
    at::Tensor grad_gamma = at::empty_like(gamma);
    cuda_rms_norm_gradient(&dout, &invvar, &input, n1, n2,
        normalized_shape, &gamma, epsilon,
        &grad_input, &grad_gamma);
    return {grad_input, grad_gamma};
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward_affine", &layer_norm_affine, "LayerNorm forward (CUDA)");
    m.def("forward", &layer_norm, "LayerNorm forward (CUDA)");
@@ -263,5 +432,11 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("backward", &layer_norm_gradient, "LayerNorm backward (CUDA)");
    m.def("forward_affine_mixed_dtypes", &layer_norm_affine_mixed_dtypes, "LayerNorm forward with mixed dtypes (CUDA) compatible with Megatron's implementation");
    m.def("rms_forward_affine", &rms_norm_affine, "RMSNorm forward (CUDA)");
    m.def("rms_forward", &rms_norm, "RMSNorm forward (CUDA)");
    m.def("rms_backward_affine", &rms_norm_gradient_affine, "RMSNorm backward (CUDA)");
    m.def("rms_backward", &rms_norm_gradient, "RMSNorm backward (CUDA)");
    m.def("rms_forward_affine_mixed_dtypes", &rms_norm_affine_mixed_dtypes, "RMSNorm forward with mixed dtypes (CUDA) compatible with Megatron's implementation");
}
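
The new bindings mirror the LayerNorm ones and follow the same call pattern used by FusedRMSNormAffineFunction above. A minimal illustrative sketch (not part of the patch; shapes and eps are arbitrary, and the fused_layer_norm_cuda extension is assumed to be built):

import importlib
import torch

fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda")

x = torch.randn(8, 32, 16, device="cuda").contiguous()
weight = torch.ones(32, 16, device="cuda").contiguous()

# rms_forward_affine returns the normalized output plus the saved inverse-RMS
# statistic ("invvar"), which the backward binding consumes.
output, invvar = fused_layer_norm_cuda.rms_forward_affine(x, (32, 16), weight, 1e-5)

grad_out = torch.randn_like(output).contiguous()
grad_input, grad_weight = fused_layer_norm_cuda.rms_backward_affine(
    grad_out, invvar, x, (32, 16), weight, 1e-5)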
@@ -13,13 +13,22 @@ class TestFusedLayerNorm(unittest.TestCase):
    rtol, atol = None, None
    fwd_thresholds = dict(rtol=None, atol=None)
    bwd_thresholds = dict(rtol=None, atol=None)
    mixed_fused = False

    def setUp(self):
        # bias and weight are set to 0 and 1 respectively, so no need to copy parameters from cpu module to the gpu one
        if not self.mixed_fused:
            self.module_cpu_ = apex.normalization.FusedLayerNorm(
                normalized_shape=self.normalized_shape, elementwise_affine=self.elementwise_affine).cpu()
            self.module_cuda_ = apex.normalization.FusedLayerNorm(
                normalized_shape=self.normalized_shape, elementwise_affine=self.elementwise_affine).to(device="cuda", dtype=self.dtype)
        else:
            assert self.elementwise_affine
            self.module_cpu_ = apex.normalization.MixedFusedLayerNorm(
                normalized_shape=self.normalized_shape).cpu()
            self.module_cuda_ = apex.normalization.MixedFusedLayerNorm(
                normalized_shape=self.normalized_shape).to(device="cuda", dtype=self.dtype)

    def _check_same_output(self, batch_size, contiguous):
        torch.cuda.manual_seed(42)
@@ -66,9 +75,83 @@ class TestFusedLayerNorm(unittest.TestCase):
        self._test_same_output(65536)
class TestFusedRMSNorm(unittest.TestCase):
    dtype = torch.float
    elementwise_affine = False
    normalized_shape = [32, 16]
    rtol, atol = None, None
    fwd_thresholds = dict(rtol=None, atol=None)
    bwd_thresholds = dict(rtol=None, atol=None)
    mixed_fused = False

    def setUp(self):
        # weight is initialized to 1, so no need to copy parameters from cpu module to the gpu one
        if not self.mixed_fused:
            self.module_cpu_ = apex.normalization.FusedRMSNorm(
                normalized_shape=self.normalized_shape, elementwise_affine=self.elementwise_affine).cpu()
            self.module_cuda_ = apex.normalization.FusedRMSNorm(
                normalized_shape=self.normalized_shape, elementwise_affine=self.elementwise_affine).to(device="cuda", dtype=self.dtype)
        else:
            assert self.elementwise_affine
            self.module_cpu_ = apex.normalization.MixedFusedRMSNorm(
                normalized_shape=self.normalized_shape).cpu()
            self.module_cuda_ = apex.normalization.MixedFusedRMSNorm(
                normalized_shape=self.normalized_shape).to(device="cuda", dtype=self.dtype)

    def _check_same_output(self, batch_size, contiguous):
        torch.cuda.manual_seed(42)
        if contiguous:
            input_shape = [batch_size] + self.normalized_shape
            input_ = torch.randn(input_shape, device="cpu").requires_grad_(True)
            input_cuda_ = input_.to(device="cuda", dtype=self.dtype).detach().requires_grad_(True)
            self.assertTrue(input_.is_contiguous())
            self.assertTrue(input_cuda_.is_contiguous())
        else:
            input_shape = [batch_size] + self.normalized_shape
            input_shape = [batch_size * 3] + [self.normalized_shape[0] * 5, self.normalized_shape[1] * 3]
            input_src_ = torch.randn(input_shape, device="cpu")
            input_ = input_src_[::3, ::5, ::3].detach().requires_grad_(True)
            input_cuda_ = input_src_.to(device="cuda", dtype=self.dtype)[::3, ::5, ::3].detach().requires_grad_(True)
            # make sure that tensors are NOT contiguous.
            self.assertFalse(input_.is_contiguous())
            self.assertFalse(input_cuda_.is_contiguous())
        out_cpu_ = self.module_cpu_(input_)
        gO = torch.rand_like(out_cpu_)
        out_cpu_.backward(gO)
        out_cuda_ = self.module_cuda_(input_cuda_)
        # TODO (mkozuki): `torch.testing.assert_allclose` is deprecated.
        # Use `torch.testing.assert_close`.
        # See https://github.com/pytorch/pytorch/issues/61844
        torch.testing.assert_allclose(
            out_cpu_.to(device="cuda", dtype=self.dtype), out_cuda_.clone().detach(), **self.fwd_thresholds)
        gO = gO.to(device="cuda", dtype=self.dtype)
        out_cuda_.backward(gO)
        self.assertFalse(out_cpu_.is_cuda)
        self.assertTrue(out_cuda_.is_cuda)
        torch.testing.assert_allclose(
            input_.grad.to(device="cuda", dtype=self.dtype), input_cuda_.grad, **self.bwd_thresholds)
        if self.elementwise_affine:
            torch.testing.assert_allclose(self.module_cpu_.weight.grad.to(device="cuda", dtype=self.dtype),
                                          self.module_cuda_.weight.grad, **self.bwd_thresholds)

    def _test_same_output(self, batch_size):
        for contiguous in (True, False):
            with self.subTest(contiguous=contiguous):
                self._check_same_output(batch_size, contiguous)

    def test_layer_norm(self):
        self._test_same_output(16)

    def test_large_batch(self):
        self._test_same_output(65536)
class TestFusedLayerNormElemWise(TestFusedLayerNorm):
    elementwise_affine = True


class TestMixedFusedLayerNormElemWise(TestFusedLayerNorm):
    elementwise_affine = True
    mixed_fused = True
class TestFusedLayerNormElemWiseHalf(TestFusedLayerNormElemWise):
    dtype = torch.half
@@ -76,6 +159,34 @@ class TestFusedLayerNormElemWiseHalf(TestFusedLayerNormElemWise):
    def test_large_batch(self):
        self.skipTest("Skip to save time")


class TestFusedLayerNormElemWiseBFloat16(TestFusedLayerNormElemWise):
    dtype = torch.bfloat16
    # NOTE (mkozuki): [BFloat16 Layer Norm flakiness]
    # Use thresholds larger than those used in pytorch, see
    # https://github.com/pytorch/pytorch/blob/72274e2a2fd55019ec860e1743dbdc5b0c5a5624/torch/testing/_asserts.py#L26
    fwd_thresholds = dict(rtol=1.6e-2, atol=3e-4)
    bwd_thresholds = dict(rtol=1.6e-2, atol=3e-3)

    def test_large_batch(self):
        self.skipTest("Skip to save time")


class TestFusedRMSNormElemWise(TestFusedRMSNorm):
    bwd_thresholds = dict(rtol=2e-3, atol=2e-4)
    elementwise_affine = True


class TestMixedFusedRMSNormElemWise(TestFusedRMSNorm):
    bwd_thresholds = dict(rtol=2e-3, atol=2e-4)
    elementwise_affine = True
    mixed_fused = True


class TestFusedRMSNormElemWiseHalf(TestFusedRMSNormElemWise):
    dtype = torch.half
    bwd_thresholds = dict(rtol=1.6e-2, atol=3e-3)

    def test_large_batch(self):
        self.skipTest("Skip to save time")


class TestFusedLayerNormElemWiseBFloat16(TestFusedLayerNormElemWise):
    dtype = torch.bfloat16
@@ -99,6 +210,16 @@ def _prep_layers(normalized_shape, elementwise_affine, dtype):
    return native, fused
def _prep_rms_layers(normalized_shape, elementwise_affine, dtype):
    native = apex.normalization.FusedRMSNorm(
        normalized_shape=normalized_shape, elementwise_affine=elementwise_affine
    )
    fused = apex.normalization.FusedRMSNorm(
        normalized_shape=normalized_shape, elementwise_affine=elementwise_affine
    ).cuda()
    return native, fused
def _prep_inputs(batch_size, normalized_shape, dtype):
    shape = (batch_size, *normalized_shape)
    fused = torch.randn(shape).cuda().requires_grad_(True)
@@ -109,7 +230,6 @@ def _prep_inputs(batch_size, normalized_shape, dtype):
autocast_dtypes = (torch.half, torch.bfloat16) if torch.cuda.is_bf16_supported() else (torch.half,)

class TestAutocastFusedLayerNorm(unittest.TestCase):
    bf16_fwd_thresholds = dict(rtol=1.6e-2, atol=3e-4)
    bf16_bwd_thresholds = dict(rtol=1.6e-2, atol=3e-3)
@@ -141,3 +261,35 @@ class TestAutocastFusedLayerNorm(unittest.TestCase):
        for (dtype, elementwise_affine) in itertools.product(autocast_dtypes, (True, False)):
            with self.subTest(f"{dtype}-{elementwise_affine}"):
                self._run_test(dtype, elementwise_affine)
class TestAutocastFusedRMSNorm(unittest.TestCase):
    bf16_fwd_thresholds = dict(rtol=1.6e-2, atol=3e-4)
    bf16_bwd_thresholds = dict(rtol=1.6e-2, atol=3e-3)

    def setUp(self):
        self.batch_size = 16
        self.normalized_shape = [32, 16]

    def _run_test(self, dtype, elementwise_affine):
        native, fused = _prep_rms_layers(self.normalized_shape, elementwise_affine, dtype)
        native_x, fused_x = _prep_inputs(self.batch_size, self.normalized_shape, dtype)

        expected = native(native_x.cpu())
        with torch.cuda.amp.autocast(dtype=dtype):
            actual = fused(fused_x)
        tols = {'rtol': None, 'atol': None} if dtype == torch.half else TestAutocastFusedRMSNorm.bf16_fwd_thresholds
        torch.testing.assert_allclose(actual, expected.detach().clone().cuda(), **tols)

        g_native = torch.rand_like(expected)
        with torch.no_grad():
            g_fused = g_native.detach().clone().cuda()
        expected.backward(g_native)
        actual.backward(g_fused)

        tols = {'rtol': None, 'atol': None} if dtype == torch.half else TestAutocastFusedRMSNorm.bf16_bwd_thresholds
        torch.testing.assert_allclose(native_x.grad.cuda(), fused_x.grad, **tols)

    def test_autocast(self):
        for (dtype, elementwise_affine) in itertools.product(autocast_dtypes, (True, False)):
            with self.subTest(f"{dtype}-{elementwise_affine}"):
                self._run_test(dtype, elementwise_affine)