Unverified commit c97ebfab authored by Hubert Lu, committed by GitHub

Enable FusedRMSNorm (#78)



* FusedRMSNorm/"T5LayerNorm" based on FusedLayerNorm (#1274)

* FusedRMSNorm based on FusedLayerNorm

* refactor duplicated kernels

* delete comments

* delete comments

* cleanup

* cleanup

* cleanup, fixed clobbering forward_affine_mixed_dtypes

* fix pybind naming and add MixedFused test

* undo skipping

* check elementwise_affine

* Update tests/L0/run_fused_layer_norm/test_fused_layer_norm.py

Oof, nice catch, thanks
Co-authored-by: Masaki Kozuki <masaki.kozuki.2014@gmail.com>

* fix and generate docs for FusedRMSNorm (#1285)

* [FusedRMSNorm doc] document where epsilon is added (#1295)

* [FusedRMSNorm doc] add epsilon to formula

* correct

* better wording

* Fix some bugs

* Optimize HostRMSNormGradient and HostApplyRMSNorm for AMD GPUs

* Fix NaN issues in FusedRMSNorm

* Update test_fused_layer_norm.py

* Skip test_fused_layer_norm.TestAutocastFusedRMSNorm on ROCm

* Use at::cuda::warp_size() instead of at::cuda::getCurrentDeviceProperties()->warpSize
Co-authored-by: eqy <eddiey@nvidia.com>
Co-authored-by: Masaki Kozuki <masaki.kozuki.2014@gmail.com>
Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
parent cf77e9b5
from .fused_layer_norm import FusedLayerNorm, MixedFusedLayerNorm
from .fused_layer_norm import FusedLayerNorm, MixedFusedLayerNorm, FusedRMSNorm, MixedFusedRMSNorm
......@@ -12,6 +12,23 @@ global fused_layer_norm_cuda
fused_layer_norm_cuda = None
# Reference implementation from Huggingface
def manual_rms_norm(input, normalized_shape, weight, eps):
# layer norm should always be calculated in float32
dims = tuple(i for i in range(-1, -len(normalized_shape)-1, -1))
variance = input.to(torch.float32).pow(2).mean(dims, keepdim=True)
input = input * torch.rsqrt(variance + eps)
if weight is None:
return input
# convert into half-precision if necessary
if weight.dtype in [torch.float16, torch.bfloat16]:
input = input.to(weight.dtype)
return weight * input
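As a quick sanity check, here is a minimal sketch (assuming the `weight.dtype` fix above; the import path and shapes are illustrative) that compares the reference implementation against a direct evaluation of the same formula:

import torch
from apex.normalization.fused_layer_norm import manual_rms_norm

x = torch.randn(4, 32, 16)
normalized_shape = (32, 16)
eps = 1e-5
weight = torch.ones(normalized_shape)

out = manual_rms_norm(x, normalized_shape, weight, eps)

# direct computation of y = x / sqrt(mean(x^2) + eps) * weight over the last two dims
rms = torch.sqrt(x.pow(2).mean(dim=(-2, -1), keepdim=True) + eps)
assert torch.allclose(out, x / rms * weight, atol=1e-6)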
class FusedLayerNormAffineFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, input, weight, bias, normalized_shape, eps):
......@@ -39,6 +56,31 @@ class FusedLayerNormAffineFunction(torch.autograd.Function):
return grad_input, grad_weight, grad_bias, None, None
class FusedRMSNormAffineFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, input, weight, normalized_shape, eps):
global fused_layer_norm_cuda
if fused_layer_norm_cuda is None:
fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda")
ctx.normalized_shape = normalized_shape
ctx.eps = eps
input_ = input.contiguous()
weight_ = weight.contiguous()
output, invvar = fused_layer_norm_cuda.rms_forward_affine(
input_, ctx.normalized_shape, weight_, ctx.eps)
ctx.save_for_backward(input_, weight_, invvar)
return output
@staticmethod
def backward(ctx, grad_output):
input_, weight_, invvar = ctx.saved_tensors
grad_input = grad_weight = None
grad_input, grad_weight = fused_layer_norm_cuda.rms_backward_affine(
grad_output.contiguous(), invvar, input_, ctx.normalized_shape, weight_, ctx.eps
)
return grad_input, grad_weight, None, None
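A minimal autograd-level sketch exercising the new Function directly (shapes and eps are illustrative; the functional wrappers below are the intended entry points):

import torch
from apex.normalization.fused_layer_norm import FusedRMSNormAffineFunction

x = torch.randn(8, 32, 16, device="cuda", requires_grad=True)
w = torch.ones(32, 16, device="cuda", requires_grad=True)

y = FusedRMSNormAffineFunction.apply(x, w, (32, 16), 1e-6)
y.sum().backward()  # fills x.grad and w.grad via rms_backward_affine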
class FusedLayerNormAffineMixedDtypesFunction(FusedLayerNormAffineFunction):
@staticmethod
......@@ -58,6 +100,25 @@ class FusedLayerNormAffineMixedDtypesFunction(FusedLayerNormAffineFunction):
return output
class FusedRMSNormAffineMixedDtypesFunction(FusedRMSNormAffineFunction):
@staticmethod
def forward(ctx, input, weight, normalized_shape, eps):
global fused_layer_norm_cuda
if fused_layer_norm_cuda is None:
fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda")
ctx.normalized_shape = normalized_shape
ctx.eps = eps
input_ = input.contiguous()
weight_ = weight.contiguous()
output, invvar = fused_layer_norm_cuda.rms_forward_affine_mixed_dtypes(
input_, ctx.normalized_shape, weight_, ctx.eps
)
ctx.save_for_backward(input_, weight_, invvar)
return output
class FusedLayerNormFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, input, normalized_shape, eps):
......@@ -81,6 +142,29 @@ class FusedLayerNormFunction(torch.autograd.Function):
return grad_input, None, None
class FusedRMSNormFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, input, normalized_shape, eps):
global fused_layer_norm_cuda
if fused_layer_norm_cuda is None:
fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda")
ctx.normalized_shape = normalized_shape
ctx.eps = eps
input_ = input.contiguous()
output, invvar = fused_layer_norm_cuda.rms_forward(input_, ctx.normalized_shape, ctx.eps)
ctx.save_for_backward(input_, invvar)
return output
@staticmethod
def backward(ctx, grad_output):
input_, invvar = ctx.saved_tensors
grad_input = None
grad_input = fused_layer_norm_cuda.rms_backward(
grad_output.contiguous(), invvar, input_, ctx.normalized_shape, ctx.eps
)
return grad_input, None, None
def fused_layer_norm_affine(input, weight, bias, normalized_shape, eps=1e-6):
args = _cast_if_autocast_enabled(input, weight, bias, normalized_shape, eps)
with torch.cuda.amp.autocast(enabled=False):
......@@ -99,6 +183,24 @@ def mixed_dtype_fused_layer_norm_affine(input, weight, bias, normalized_shape, e
return FusedLayerNormAffineMixedDtypesFunction.apply(*args)
def fused_rms_norm_affine(input, weight, normalized_shape, eps=1e-6):
args = _cast_if_autocast_enabled(input, weight, normalized_shape, eps)
with torch.cuda.amp.autocast(enabled=False):
return FusedRMSNormAffineFunction.apply(*args)
def fused_rms_norm(input, normalized_shape, eps=1e-6):
args = _cast_if_autocast_enabled(input, normalized_shape, eps)
with torch.cuda.amp.autocast(enabled=False):
return FusedRMSNormFunction.apply(*args)
def mixed_dtype_fused_rms_norm_affine(input, weight, normalized_shape, eps=1e-6):
args = _cast_if_autocast_enabled(input, weight, normalized_shape, eps)
with torch.cuda.amp.autocast(enabled=False):
return FusedRMSNormAffineMixedDtypesFunction.apply(*args)
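A minimal usage sketch for the functional entry points added above (shapes, device, and eps are illustrative):

import torch
from apex.normalization.fused_layer_norm import fused_rms_norm, fused_rms_norm_affine

x = torch.randn(8, 32, 16, device="cuda")
weight = torch.ones(32, 16, device="cuda")

y_affine = fused_rms_norm_affine(x, weight, (32, 16), eps=1e-6)  # learnable scale
y_plain = fused_rms_norm(x, (32, 16), eps=1e-6)                  # no affine parameters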
class FusedLayerNorm(torch.nn.Module):
r"""Applies Layer Normalization over a mini-batch of inputs as described in
the paper `Layer Normalization`_ .
......@@ -195,6 +297,100 @@ class FusedLayerNorm(torch.nn.Module):
return "{normalized_shape}, eps={eps}, " "elementwise_affine={elementwise_affine}".format(**self.__dict__)
class FusedRMSNorm(torch.nn.Module):
r"""Applies RMS Normalization over a mini-batch of inputs
Currently only runs on cuda() tensors.
.. math::
y = \frac{x}{\mathrm{RMS}[x]} * \gamma
The root-mean-square is calculated separately over the last
certain number of dimensions, which have to be of the shape specified by
:attr:`normalized_shape`.
:math:`\gamma` is a learnable affine transform parameter of
:attr:`normalized_shape` if :attr:`elementwise_affine` is ``True``.
`epsilon` is added to the mean square before the square root is taken.
.. note::
Unlike Batch Normalization and Instance Normalization, which apply
scalar scale and bias for each entire channel/plane with the
:attr:`affine` option, RMS Normalization applies per-element scale
with :attr:`elementwise_affine`.
This layer uses statistics computed from input data in both training and
evaluation modes.
Args:
normalized_shape (int or list or torch.Size): input shape from an expected input
of size
.. math::
[* \times \text{normalized}\_\text{shape}[0] \times \text{normalized}\_\text{shape}[1]
\times \ldots \times \text{normalized}\_\text{shape}[-1]]
If a single integer is used, it is treated as a singleton list, and this module will
normalize over the last dimension which is expected to be of that specific size.
eps: a value added to the denominator for numerical stability. Default: 1e-5
elementwise_affine: a boolean value that when set to ``True``, this module
has learnable per-element affine parameters (weights) initialized to ones.
Default: ``True``.
Shape:
- Input: :math:`(N, *)`
- Output: :math:`(N, *)` (same shape as input)
Examples::
>>> input = torch.randn(20, 5, 10, 10)
>>> # With Learnable Parameters
>>> m = apex.normalization.FusedRMSNorm(input.size()[1:])
>>> # Without Learnable Parameters
>>> m = apex.normalization.FusedRMSNorm(input.size()[1:], elementwise_affine=False)
>>> # Normalize over last two dimensions
>>> m = apex.normalization.FusedRMSNorm([10, 10])
>>> # Normalize over last dimension of size 10
>>> m = apex.normalization.FusedRMSNorm(10)
>>> # Activating the module
>>> output = m(input)
.. _`Root Mean Square Layer Normalization`: https://arxiv.org/pdf/1910.07467.pdf
"""
def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True):
super().__init__()
global fused_layer_norm_cuda
fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda")
if isinstance(normalized_shape, numbers.Integral):
normalized_shape = (normalized_shape,)
self.normalized_shape = torch.Size(normalized_shape)
self.eps = eps
self.elementwise_affine = elementwise_affine
if self.elementwise_affine:
self.weight = Parameter(torch.Tensor(*normalized_shape))
else:
self.register_parameter("weight", None)
self.reset_parameters()
def reset_parameters(self):
if self.elementwise_affine:
init.ones_(self.weight)
def forward(self, input):
if not input.is_cuda:
return manual_rms_norm(input, self.normalized_shape, self.weight, self.eps)
if self.elementwise_affine:
return fused_rms_norm_affine(input, self.weight, self.normalized_shape, self.eps)
else:
return fused_rms_norm(input, self.normalized_shape, self.eps)
def extra_repr(self):
return "{normalized_shape}, eps={eps}, " "elementwise_affine={elementwise_affine}".format(**self.__dict__)
# NOTE (mkozuki): Why "mixed"?
# MixedFusedLayerNorm differs from FusedLayerNorm in that this layer norm uses parameter's dtype
# as output tensor's dtype while FusedLayerNorm uses input tensor's dtype for output tensor's dtype.
......@@ -216,3 +412,26 @@ class MixedFusedLayerNorm(FusedLayerNorm):
if not input.is_cuda:
return F.layer_norm(input, self.normalized_shape, self.weight, self.bias, self.eps)
return mixed_dtype_fused_layer_norm_affine(input, self.weight, self.bias, self.normalized_shape, self.eps)
# MixedFusedRMSNorm differs from FusedRMSNorm in that this norm uses the parameter's dtype
# as the output tensor's dtype, while FusedRMSNorm uses the input tensor's dtype for the output tensor's dtype.
# See: `rms_norm_affine` and `rms_norm_affine_mixed_dtypes` in "csrc/layer_norm_cuda.cpp"
class MixedFusedRMSNorm(FusedRMSNorm):
def __init__(self, normalized_shape, eps=1e-5, **kwargs):
if "elementwise_affine" in kwargs:
import warnings
warnings.warn("MixedFusedRMSNorm does not support `elementwise_affine` argument")
elementwise_affine = kwargs.pop("elementwise_affine")
if not elementwise_affine:
raise RuntimeError("MixedFusedRMSNorm does not support `elementwise_affine = False`")
super().__init__(normalized_shape=normalized_shape, eps=eps, elementwise_affine=True)
def forward(self, input: torch.Tensor):
# NOTE (mkozuki): CPU path is here mainly for unittest sake.
if not input.is_cuda:
return manual_rms_norm(input, self.normalized_shape, self.weight, self.eps)
return mixed_dtype_fused_rms_norm_affine(input, self.weight, self.normalized_shape, self.eps)
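A minimal sketch of the `elementwise_affine` handling in `__init__` above (illustrative shapes):

from apex.normalization import MixedFusedRMSNorm

m = MixedFusedRMSNorm((32, 16))                            # affine weight is always created
m = MixedFusedRMSNorm((32, 16), elementwise_affine=True)   # accepted, but emits a warning
MixedFusedRMSNorm((32, 16), elementwise_affine=False)      # raises RuntimeError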
......@@ -40,6 +40,19 @@ void check_args(
TORCH_CHECK(!beta.defined() || beta.sizes().equals(normalized_shape));
}
void check_args(
#ifdef VERSION_GE_1_1
at::IntArrayRef normalized_shape,
#else
at::IntList normalized_shape,
#endif
at::Tensor gamma
)
{
TORCH_CHECK(!gamma.defined() || gamma.sizes().equals(normalized_shape));
}
void check_args(
at::Tensor input,
#ifdef VERSION_GE_1_1
......@@ -79,7 +92,6 @@ void check_args(
compute_n1_n2(input,normalized_shape,n1,n2);
}
void check_args(
at::Tensor input,
#ifdef VERSION_GE_1_1
......@@ -96,6 +108,22 @@ void check_args(
check_args(input,normalized_shape,n1,n2);
check_args(normalized_shape,gamma,beta);
}
void check_args(
at::Tensor input,
#ifdef VERSION_GE_1_1
at::IntArrayRef normalized_shape,
#else
at::IntList normalized_shape,
#endif
at::Tensor gamma,
int& n1,
int& n2
)
{
check_args(input,normalized_shape,n1,n2);
check_args(normalized_shape,gamma);
}
}
void cuda_layer_norm(
......@@ -256,6 +284,147 @@ std::vector<at::Tensor> layer_norm_gradient_affine(
return {grad_input, grad_gamma, grad_beta};
}
void cuda_rms_norm(
at::Tensor* output,
at::Tensor* invvar,
at::Tensor* input,
int n1,
int n2,
#ifdef VERSION_GE_1_1
at::IntArrayRef normalized_shape,
#else
at::IntList normalized_shape,
#endif
at::Tensor* gamma,
double epsilon);
#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
std::vector<at::Tensor> rms_norm(
at::Tensor input,
#ifdef VERSION_GE_1_1
at::IntArrayRef normalized_shape,
#else
at::IntList normalized_shape,
#endif
double epsilon) {
CHECK_INPUT(input);
int n1,n2;
check_args(input,normalized_shape,n1,n2);
at::Tensor output = at::empty_like(input);
at::Tensor invvar = at::empty({n1}, input.options().dtype(input.scalar_type()==at::ScalarType::Half || input.scalar_type()==at::ScalarType::BFloat16 ? at::ScalarType::Float : input.scalar_type()));
cuda_rms_norm(&output,&invvar,&input,n1,n2,
normalized_shape,NULL,epsilon);
return {output, invvar};
}
std::vector<at::Tensor> rms_norm_affine(
at::Tensor input,
#ifdef VERSION_GE_1_1
at::IntArrayRef normalized_shape,
#else
at::IntList normalized_shape,
#endif
at::Tensor gamma,
double epsilon) {
CHECK_INPUT(input);
CHECK_INPUT(gamma);
int n1,n2;
check_args(input,normalized_shape,gamma,n1,n2);
at::Tensor output = at::empty_like(input);
const auto stats_dtype = (input.scalar_type() == at::ScalarType::Half || input.scalar_type() == at::ScalarType::BFloat16) ? at::ScalarType::Float : input.scalar_type();
at::Tensor invvar = at::empty({n1}, input.options().dtype(stats_dtype));
cuda_rms_norm(&output,&invvar,&input,n1,n2,
normalized_shape,&gamma,epsilon);
return {output, invvar};
}
std::vector<at::Tensor> rms_norm_affine_mixed_dtypes(
at::Tensor input,
#ifdef VERSION_GE_1_1
at::IntArrayRef normalized_shape,
#else
at::IntList normalized_shape,
#endif
at::Tensor gamma,
double epsilon) {
CHECK_INPUT(input);
int n1, n2;
check_args(input, normalized_shape, n1, n2);
at::Tensor output = at::empty_like(input, gamma.options().dtype(gamma.scalar_type()));
at::Tensor invvar = at::empty({n1}, input.options().dtype(input.scalar_type() == at::ScalarType::Half || input.scalar_type() == at::ScalarType::BFloat16 ? at::ScalarType::Float : input.scalar_type()));
cuda_rms_norm(&output,&invvar, &input, n1, n2,
normalized_shape, &gamma,epsilon);
return {output,invvar};
}
void cuda_rms_norm_gradient(
at::Tensor* dout,
at::Tensor* invvar,
at::Tensor* input,
int n1,
int n2,
#ifdef VERSION_GE_1_1
at::IntArrayRef normalized_shape,
#else
at::IntList normalized_shape,
#endif
at::Tensor* gamma,
double epsilon,
at::Tensor* grad_input,
at::Tensor* grad_gamma);
at::Tensor rms_norm_gradient(
at::Tensor dout,
at::Tensor invvar,
at::Tensor input,
#ifdef VERSION_GE_1_1
at::IntArrayRef normalized_shape,
#else
at::IntList normalized_shape,
#endif
double epsilon) {
CHECK_INPUT(dout);
CHECK_INPUT(invvar);
CHECK_INPUT(input);
int n1,n2;
check_args(input,normalized_shape,n1,n2);
at::Tensor grad_input = at::empty_like(input);
cuda_rms_norm_gradient(&dout,&invvar,&input,n1,n2,
normalized_shape,NULL,epsilon,
&grad_input,NULL);
return grad_input;
}
std::vector<at::Tensor> rms_norm_gradient_affine(
at::Tensor dout,
at::Tensor invvar,
at::Tensor input,
#ifdef VERSION_GE_1_1
at::IntArrayRef normalized_shape,
#else
at::IntList normalized_shape,
#endif
at::Tensor gamma,
double epsilon) {
CHECK_INPUT(dout);
CHECK_INPUT(invvar);
CHECK_INPUT(input);
CHECK_INPUT(gamma);
int n1,n2;
check_args(input,normalized_shape,gamma,n1,n2);
at::Tensor grad_input = at::empty_like(input);
at::Tensor grad_gamma = at::empty_like(gamma);
cuda_rms_norm_gradient(&dout,&invvar,&input,n1,n2,
normalized_shape,&gamma,epsilon,
&grad_input,&grad_gamma);
return {grad_input, grad_gamma};
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward_affine", &layer_norm_affine, "LayerNorm forward (CUDA)");
m.def("forward", &layer_norm, "LayerNorm forward (CUDA)");
......@@ -263,5 +432,11 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("backward", &layer_norm_gradient, "LayerNorm backward (CUDA)");
m.def("forward_affine_mixed_dtypes", &layer_norm_affine_mixed_dtypes, "LayerNorm forward with mixed dtypes (CUDA) compatible with Megatron's implementation");
}
m.def("rms_forward_affine", &rms_norm_affine, "RMSNorm forward (CUDA)");
m.def("rms_forward", &rms_norm, "RMSNorm forward (CUDA)");
m.def("rms_backward_affine", &rms_norm_gradient_affine, "RMSNorm backward (CUDA)");
m.def("rms_backward", &rms_norm_gradient, "RMSNorm backward (CUDA)");
m.def("rms_forward_affine_mixed_dtypes", &rms_norm_affine_mixed_dtypes, "RMSNorm forward with mixed dtypes (CUDA) compatible with Megatron's implementation");
}
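A minimal sketch of calling the new bindings directly (this mirrors what `FusedRMSNormAffineFunction` does internally; it assumes the extension has been built and is importable, and the shapes/eps are illustrative):

import importlib
import torch

fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda")

x = torch.randn(8, 32, 16, device="cuda").contiguous()
weight = torch.ones(32, 16, device="cuda").contiguous()

output, invvar = fused_layer_norm_cuda.rms_forward_affine(x, (32, 16), weight, 1e-6)
grad_output = torch.ones_like(output).contiguous()
grad_input, grad_weight = fused_layer_norm_cuda.rms_backward_affine(grad_output, invvar, x, (32, 16), weight, 1e-6)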
......@@ -49,6 +49,23 @@ void cuChanOnlineSum(
}
}
template<typename U> __device__
void cuRMSOnlineSum(
const U curr,
U& sigma2)
{
sigma2 = sigma2 + curr * curr;
}
template<typename U> __device__
void cuChanRMSOnlineSum(
const U sigma2B,
U& sigma2)
{
sigma2 = sigma2 + sigma2B;
}
template<typename T, typename U> __device__
void cuWelfordMuSigma2(
const T* __restrict__ vals,
......@@ -58,7 +75,8 @@ void cuWelfordMuSigma2(
U& mu,
U& sigma2,
U* buf,
const int GPU_WARP_SIZE)
const int GPU_WARP_SIZE,
bool rms_only)
{
// Assumptions:
// 1) blockDim.x == warpSize
......@@ -80,20 +98,32 @@ void cuWelfordMuSigma2(
for (; l+3 < n2; l+=4*numx) {
for (int k = 0; k < 4; ++k) {
U curr = static_cast<U>(lvals[l+k]);
cuWelfordOnlineSum<U>(curr,mu,sigma2,count);
if (!rms_only) {
cuWelfordOnlineSum<U>(curr,mu,sigma2,count);
} else {
cuRMSOnlineSum<U>(curr, sigma2);
}
}
}
for (; l < n2; ++l) {
U curr = static_cast<U>(lvals[l]);
cuWelfordOnlineSum<U>(curr,mu,sigma2,count);
if (!rms_only) {
cuWelfordOnlineSum<U>(curr,mu,sigma2,count);
} else {
cuRMSOnlineSum<U>(curr, sigma2);
}
}
// intra-warp reductions
#pragma unroll
for (int stride = GPU_WARP_SIZE / 2; stride > 0; stride /= 2) {
U muB = WARP_SHFL_DOWN(mu, stride);
U countB = WARP_SHFL_DOWN(count, stride);
for (int stride = GPU_WARP_SIZE / 2; stride > 0; stride /= 2) {
U sigma2B = WARP_SHFL_DOWN(sigma2, stride);
cuChanOnlineSum<U>(muB, sigma2B, countB, mu, sigma2, count);
if (!rms_only) {
U muB = WARP_SHFL_DOWN(mu, stride);
U countB = WARP_SHFL_DOWN(count, stride);
cuChanOnlineSum<U>(muB, sigma2B, countB, mu, sigma2, count);
} else {
cuChanRMSOnlineSum<U>(sigma2B, sigma2);
}
}
// threadIdx.x == 0 has correct values for each warp
// inter-warp reductions
......@@ -104,32 +134,44 @@ void cuWelfordMuSigma2(
// upper half of warps write to shared
if (threadIdx.x == 0 && threadIdx.y >= offset && threadIdx.y < 2*offset) {
const int wrt_y = threadIdx.y - offset;
ubuf[2*wrt_y] = mu;
if (!rms_only) {
ubuf[2*wrt_y] = mu;
ibuf[wrt_y] = count;
}
ubuf[2*wrt_y+1] = sigma2;
ibuf[wrt_y] = count;
}
__syncthreads();
// lower half merges
if (threadIdx.x == 0 && threadIdx.y < offset) {
U muB = ubuf[2*threadIdx.y];
U sigma2B = ubuf[2*threadIdx.y+1];
U countB = ibuf[threadIdx.y];
cuChanOnlineSum<U>(muB,sigma2B,countB,mu,sigma2,count);
if (!rms_only) {
U muB = ubuf[2*threadIdx.y];
U countB = ibuf[threadIdx.y];
cuChanOnlineSum<U>(muB,sigma2B,countB,mu,sigma2,count);
} else {
cuChanRMSOnlineSum<U>(sigma2B,sigma2);
}
}
__syncthreads();
}
// threadIdx.x = 0 && threadIdx.y == 0 only thread that has correct values
if (threadIdx.x == 0 && threadIdx.y == 0) {
ubuf[0] = mu;
if (!rms_only) {
ubuf[0] = mu;
}
ubuf[1] = sigma2;
}
__syncthreads();
mu = ubuf[0];
if (!rms_only) {
mu = ubuf[0];
}
sigma2 = ubuf[1]/U(n2);
// don't care about final value of count, we know count == n2
} else {
mu = WARP_SHFL(mu, 0);
sigma2 = WARP_SHFL(sigma2 / U(n2), 0);
if (!rms_only) {
mu = WARP_SHFL(mu, 0);
}
sigma2 = WARP_SHFL(sigma2/U(n2), 0);
}
}
}
......@@ -143,7 +185,8 @@ void cuWelfordMuSigma2(
float& mu,
float& sigma2,
float* buf,
const int GPU_WARP_SIZE)
const int GPU_WARP_SIZE,
bool rms_only)
{
// Assumptions:
// 1) blockDim.x == warpSize
......@@ -167,7 +210,12 @@ void cuWelfordMuSigma2(
// first thread consumes first point
if (thrx == 0) {
float curr = static_cast<float>(lvals[0]);
cuWelfordOnlineSum(curr,mu,sigma2,count);
if (!rms_only) {
cuWelfordOnlineSum(curr,mu,sigma2,count);
} else {
cuRMSOnlineSum(curr, sigma2);
}
}
++l;
}
......@@ -175,21 +223,34 @@ void cuWelfordMuSigma2(
for (; l+7 < n2; l+=8*numx) {
for (int k = 0; k < 8; k+=2) {
float2 curr = __half22float2(*((__half2*)(lvals+l+k)));
cuWelfordOnlineSum(curr.x,mu,sigma2,count);
cuWelfordOnlineSum(curr.y,mu,sigma2,count);
if (!rms_only) {
cuWelfordOnlineSum(curr.x,mu,sigma2,count);
cuWelfordOnlineSum(curr.y,mu,sigma2,count);
} else {
cuRMSOnlineSum(curr.x, sigma2);
cuRMSOnlineSum(curr.y, sigma2);
}
}
}
for (; l < n2; ++l) {
float curr = static_cast<float>(lvals[l]);
cuWelfordOnlineSum(curr,mu,sigma2,count);
if (!rms_only) {
cuWelfordOnlineSum(curr,mu,sigma2,count);
} else {
cuRMSOnlineSum(curr, sigma2);
}
}
// intra-warp reductions
#pragma unroll
for (int stride = GPU_WARP_SIZE / 2; stride > 0; stride /= 2) { // TODO
float muB = WARP_SHFL_DOWN(mu, stride);
float countB = WARP_SHFL_DOWN(count, stride);
for (int stride = GPU_WARP_SIZE / 2; stride > 0; stride /= 2) {
float sigma2B = WARP_SHFL_DOWN(sigma2, stride);
cuChanOnlineSum(muB, sigma2B, countB, mu, sigma2, count);
if (!rms_only) {
float muB = WARP_SHFL_DOWN(mu, stride);
float countB = WARP_SHFL_DOWN(count, stride);
cuChanOnlineSum(muB, sigma2B, countB, mu, sigma2, count);
} else {
cuChanRMSOnlineSum(sigma2B, sigma2);
}
}
// threadIdx.x == 0 has correct values for each warp
// inter-warp reductions
......@@ -200,32 +261,44 @@ void cuWelfordMuSigma2(
// upper half of warps write to shared
if (threadIdx.x == 0 && threadIdx.y >= offset && threadIdx.y < 2*offset) {
const int wrt_y = threadIdx.y - offset;
ubuf[2*wrt_y] = mu;
ubuf[2*wrt_y+1] = sigma2;
ibuf[wrt_y] = count;
if (!rms_only) {
ubuf[2*wrt_y] = mu;
ibuf[wrt_y] = count;
}
}
__syncthreads();
// lower half merges
if (threadIdx.x == 0 && threadIdx.y < offset) {
float muB = ubuf[2*threadIdx.y];
float sigma2B = ubuf[2*threadIdx.y+1];
float countB = ibuf[threadIdx.y];
cuChanOnlineSum(muB,sigma2B,countB,mu,sigma2,count);
if (!rms_only) {
float muB = ubuf[2*threadIdx.y];
float countB = ibuf[threadIdx.y];
cuChanOnlineSum(muB,sigma2B,countB,mu,sigma2,count);
} else {
cuChanRMSOnlineSum(sigma2B, sigma2);
}
}
__syncthreads();
}
// threadIdx.x = 0 && threadIdx.y == 0 only thread that has correct values
if (threadIdx.x == 0 && threadIdx.y == 0) {
ubuf[0] = mu;
if (!rms_only) {
ubuf[0] = mu;
}
ubuf[1] = sigma2;
}
__syncthreads();
mu = ubuf[0];
if (!rms_only) {
mu = ubuf[0];
}
sigma2 = ubuf[1]/float(n2);
// don't care about final value of count, we know count == n2
} else {
mu = WARP_SHFL(mu, 0);
sigma2 = WARP_SHFL(sigma2 / float(n2), 0);
if (!rms_only) {
mu = WARP_SHFL(mu, 0);
}
sigma2 = WARP_SHFL(sigma2/float(n2), 0);
}
}
}
......@@ -296,8 +369,8 @@ void cuApplyLayerNorm_(
const U epsilon,
const V* __restrict__ gamma,
const V* __restrict__ beta,
const int GPU_WARP_SIZE
)
const int GPU_WARP_SIZE,
bool rms_only)
{
// Assumptions:
// 1) blockDim.x == warpSize
......@@ -307,25 +380,36 @@ void cuApplyLayerNorm_(
SharedMemory<U> shared;
U* buf = shared.getPointer();
U mu,sigma2;
cuWelfordMuSigma2(vals,n1,n2,i1,mu,sigma2,buf, GPU_WARP_SIZE);
cuWelfordMuSigma2(vals,n1,n2,i1,mu,sigma2,buf, GPU_WARP_SIZE, rms_only);
const T* lvals = vals + i1*n2;
V* ovals = output_vals + i1*n2;
U c_invvar = rsqrt(sigma2 + epsilon);
const int numx = blockDim.x * blockDim.y;
const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
if (gamma != NULL && beta != NULL) {
if (gamma != NULL && (beta != NULL || rms_only)) {
for (int i = thrx; i < n2; i+=numx) {
U curr = static_cast<U>(lvals[i]);
ovals[i] = gamma[i] * static_cast<V>(c_invvar * (curr - mu)) + beta[i];
if (!rms_only) {
ovals[i] = gamma[i] * static_cast<V>(c_invvar * (curr - mu)) + beta[i];
} else {
ovals[i] = gamma[i] * static_cast<V>(c_invvar * curr);
}
}
} else {
for (int i = thrx; i < n2; i+=numx) {
U curr = static_cast<U>(lvals[i]);
ovals[i] = static_cast<V>(c_invvar * (curr - mu));
if (!rms_only) {
ovals[i] = static_cast<V>(c_invvar * (curr - mu));
} else {
ovals[i] = static_cast<V>(c_invvar * curr);
}
}
}
if (threadIdx.x == 0 && threadIdx.y == 0) {
mean[i1] = mu;
if (!rms_only) {
mean[i1] = mu;
}
invvar[i1] = c_invvar;
}
__syncthreads();
......@@ -345,7 +429,21 @@ void cuApplyLayerNorm(
const V* __restrict__ beta,
const int warp_size)
{
cuApplyLayerNorm_<T, U, V>(output_vals, mean, invvar, vals, n1, n2, epsilon, gamma, beta, warp_size);
cuApplyLayerNorm_<T, U, V>(output_vals, mean, invvar, vals, n1, n2, epsilon, gamma, beta, warp_size, false);
}
template<typename T, typename U, typename V=T> __global__
void cuApplyRMSNorm(
V* __restrict__ output_vals,
U* __restrict__ invvar,
const T* __restrict__ vals,
const int n1,
const int n2,
const U epsilon,
const V* __restrict__ gamma,
const int warp_size)
{
cuApplyLayerNorm_<T, U, V>(output_vals, NULL, invvar, vals, n1, n2, epsilon, gamma, NULL, warp_size, true);
}
template<typename T, typename U, typename V> __device__
......@@ -362,12 +460,16 @@ void cuLoadWriteStridedInputs(
const int i1_end,
const int n2,
const U* __restrict__ mean,
const U* __restrict__ invvar
const U* __restrict__ invvar,
bool rms_only
)
{
int i1 = i1_block+thr_load_row_off;
if (i1 < i1_end) {
U curr_mean = mean[i1];
U curr_mean;
if (!rms_only) {
curr_mean = mean[i1];
}
U curr_invvar = invvar[i1];
for (int k = 0; k < blockDim.y; ++k) {
int i2 = i2_off + k;
......@@ -376,17 +478,25 @@ void cuLoadWriteStridedInputs(
if (i2<n2) {
U curr_input = static_cast<U>(input[load_idx]);
U curr_dout = static_cast<U>(dout[load_idx]);
warp_buf1[write_idx] = curr_dout;
warp_buf2[write_idx] = curr_dout * (curr_input - curr_mean) * curr_invvar;
if (!rms_only) {
warp_buf1[write_idx] = curr_dout;
warp_buf2[write_idx] = curr_dout * (curr_input - curr_mean) * curr_invvar;
} else {
warp_buf2[write_idx] = curr_dout * (curr_input) * curr_invvar;
}
} else {
warp_buf1[write_idx] = U(0);
if (!rms_only) {
warp_buf1[write_idx] = U(0);
}
warp_buf2[write_idx] = U(0);
}
}
} else {
for (int k = 0; k < blockDim.y; ++k) {
int write_idx = thr_load_row_off*row_stride+thr_load_col_off+k;
warp_buf1[write_idx] = U(0);
if (!rms_only) {
warp_buf1[write_idx] = U(0);
}
warp_buf2[write_idx] = U(0);
}
}
......@@ -405,12 +515,16 @@ void cuLoadAddStridedInputs(
const int i1_end,
const int n2,
const U* __restrict__ mean,
const U* __restrict__ invvar
const U* __restrict__ invvar,
bool rms_only
)
{
int i1 = i1_block+thr_load_row_off;
if (i1 < i1_end) {
U curr_mean = mean[i1];
U curr_mean;
if (!rms_only) {
curr_mean = mean[i1];
}
U curr_invvar = invvar[i1];
for (int k = 0; k < blockDim.y; ++k) {
int i2 = i2_off + k;
......@@ -419,13 +533,18 @@ void cuLoadAddStridedInputs(
if (i2<n2) {
U curr_input = static_cast<U>(input[load_idx]);
U curr_dout = static_cast<U>(dout[load_idx]);
warp_buf1[write_idx] += curr_dout;
warp_buf2[write_idx] += curr_dout * (curr_input - curr_mean) * curr_invvar;
if (!rms_only) {
warp_buf1[write_idx] += curr_dout;
warp_buf2[write_idx] += curr_dout * (curr_input - curr_mean) * curr_invvar;
} else {
warp_buf2[write_idx] += curr_dout * (curr_input) * curr_invvar;
}
}
}
}
}
template<typename T, typename U, typename V> __global__
void cuComputePartGradGammaBeta(
const V* __restrict__ dout,
......@@ -436,7 +555,8 @@ void cuComputePartGradGammaBeta(
const U* __restrict__ invvar,
U epsilon,
U* part_grad_gamma,
U* part_grad_beta)
U* part_grad_beta,
bool rms_only)
{
const int numsegs_n1 = (n1+blockDim.y*blockDim.y-1) / (blockDim.y*blockDim.y);
const int segs_per_block = (numsegs_n1 + gridDim.y - 1) / gridDim.y;
......@@ -453,9 +573,9 @@ void cuComputePartGradGammaBeta(
U* warp_buf2 = warp_buf1 + blockDim.y * blockDim.y * row_stride;
// compute partial sums from strided inputs
// do this to increase number of loads in flight
cuLoadWriteStridedInputs(i1_beg,thr_load_row_off,thr_load_col_off,i2_off,row_stride,warp_buf1,warp_buf2,input,dout,i1_end,n2,mean,invvar);
cuLoadWriteStridedInputs(i1_beg,thr_load_row_off,thr_load_col_off,i2_off,row_stride,warp_buf1,warp_buf2,input,dout,i1_end,n2,mean,invvar, rms_only);
for (int i1_block = i1_beg+blockDim.y*blockDim.y; i1_block < i1_end; i1_block+=blockDim.y*blockDim.y) {
cuLoadAddStridedInputs(i1_block,thr_load_row_off,thr_load_col_off,i2_off,row_stride,warp_buf1,warp_buf2,input,dout,i1_end,n2,mean,invvar);
cuLoadAddStridedInputs(i1_block,thr_load_row_off,thr_load_col_off,i2_off,row_stride,warp_buf1,warp_buf2,input,dout,i1_end,n2,mean,invvar, rms_only);
}
__syncthreads();
// inter-warp reductions
......@@ -465,10 +585,14 @@ void cuComputePartGradGammaBeta(
for (int k = 0; k < blockDim.y; ++k) {
int row1 = threadIdx.y + k*blockDim.y;
int idx1 = row1*row_stride + threadIdx.x;
acc1 += warp_buf1[idx1];
if (!rms_only) {
acc1 += warp_buf1[idx1];
}
acc2 += warp_buf2[idx1];
}
warp_buf1[threadIdx.y*row_stride+threadIdx.x] = acc1;
if (!rms_only) {
warp_buf1[threadIdx.y*row_stride+threadIdx.x] = acc1;
}
warp_buf2[threadIdx.y*row_stride+threadIdx.x] = acc2;
__syncthreads();
// sum all warps
......@@ -478,7 +602,9 @@ void cuComputePartGradGammaBeta(
int row2 = threadIdx.y + offset;
int idx1 = row1*row_stride + threadIdx.x;
int idx2 = row2*row_stride + threadIdx.x;
warp_buf1[idx1] += warp_buf1[idx2];
if (!rms_only) {
warp_buf1[idx1] += warp_buf1[idx2];
}
warp_buf2[idx1] += warp_buf2[idx2];
}
__syncthreads();
......@@ -489,7 +615,9 @@ void cuComputePartGradGammaBeta(
int row2 = threadIdx.y + 1;
int idx1 = row1*row_stride + threadIdx.x;
int idx2 = row2*row_stride + threadIdx.x;
part_grad_beta[blockIdx.y*n2+i2] = warp_buf1[idx1] + warp_buf1[idx2];
if (!rms_only) {
part_grad_beta[blockIdx.y*n2+i2] = warp_buf1[idx1] + warp_buf1[idx2];
}
part_grad_gamma[blockIdx.y*n2+i2] = warp_buf2[idx1] + warp_buf2[idx2];
}
}
......@@ -502,7 +630,8 @@ void cuComputeGradGammaBeta(
const int n1,
const int n2,
V* grad_gamma,
V* grad_beta)
V* grad_beta,
bool rms_only)
{
// sum partial gradients for gamma and beta
SharedMemory<U> shared;
......@@ -517,7 +646,9 @@ void cuComputeGradGammaBeta(
const U* part_grad_beta_ptr = part_grad_beta + threadIdx.y * num_warp_reductions * n2 + i2;
for (int warp_offset = 0; warp_offset < num_warp_reductions; ++warp_offset) {
sum_gamma += part_grad_gamma_ptr[warp_offset*n2];
sum_beta += part_grad_beta_ptr[warp_offset*n2];
if (!rms_only) {
sum_beta += part_grad_beta_ptr[warp_offset*n2];
}
}
// inter-warp reductions
const int nbsize3 = blockDim.x * blockDim.y / 2;
......@@ -526,25 +657,32 @@ void cuComputeGradGammaBeta(
if (threadIdx.y >= offset && threadIdx.y < 2*offset) {
const int write_idx = (threadIdx.y - offset) * blockDim.x + threadIdx.x;
buf[write_idx] = sum_gamma;
buf[write_idx+nbsize3] = sum_beta;
if (!rms_only) {
buf[write_idx+nbsize3] = sum_beta;
}
}
__syncthreads();
// bottom half sums
if (threadIdx.y < offset) {
const int read_idx = threadIdx.y * blockDim.x + threadIdx.x;
sum_gamma += buf[read_idx];
sum_beta += buf[read_idx+nbsize3];
if (!rms_only) {
sum_beta += buf[read_idx+nbsize3];
}
}
__syncthreads();
}
// write out fully summed gradients
if (threadIdx.y == 0) {
grad_gamma[i2] = sum_gamma;
grad_beta[i2] = sum_beta;
if (!rms_only) {
grad_beta[i2] = sum_beta;
}
}
}
}
template<typename T, typename U, typename V> __global__
void cuComputeGradInput(
const V* __restrict__ dout,
......@@ -555,12 +693,16 @@ void cuComputeGradInput(
const U* __restrict__ invvar,
U epsilon,
const V* gamma,
T* grad_input)
T* grad_input,
bool rms_only)
{
for (int i1=blockIdx.y; i1 < n1; i1 += gridDim.y) {
U sum_loss1 = U(0);
U sum_loss2 = U(0);
const U c_mean = mean[i1];
U c_mean;
if (!rms_only) {
c_mean = mean[i1];
}
const U c_invvar = invvar[i1];
const T* k_input = input + i1*n2;
const V* k_dout = dout + i1*n2;
......@@ -570,18 +712,27 @@ void cuComputeGradInput(
#ifndef __HIP_PLATFORM_HCC__
int l = 4*thrx;
for (; l+3 < n2; l+=4*numx) {
for (int k = 0; k < 4; ++k) {
for (int k = 0; k < 4; ++k) {
const U c_h = static_cast<U>(k_input[l+k]);
const U c_loss = static_cast<U>(k_dout[l+k]);
sum_loss1 += c_loss * gamma[l+k];
sum_loss2 += c_loss * gamma[l+k] * (c_h - c_mean) * c_invvar;
if (!rms_only) {
sum_loss1 += c_loss * gamma[l+k];
sum_loss2 += c_loss * gamma[l+k] * (c_h - c_mean) * c_invvar;
} else {
sum_loss2 += c_loss * gamma[l+k] * (c_h) * c_invvar;
}
}
}
for (; l < n2; ++l) {
const U c_h = static_cast<U>(k_input[l]);
const U c_loss = static_cast<U>(k_dout[l]);
sum_loss1 += c_loss * gamma[l];
sum_loss2 += c_loss * gamma[l] * (c_h - c_mean) * c_invvar;
if (!rms_only) {
sum_loss1 += c_loss * gamma[l];
sum_loss2 += c_loss * gamma[l] * (c_h - c_mean) * c_invvar;
} else {
sum_loss2 += c_loss * gamma[l] * (c_h) * c_invvar;
}
}
#else
// Optimization for ROCm MI100
......@@ -590,8 +741,12 @@ void cuComputeGradInput(
const U gamma_idx = static_cast<U>((idx<n2) ? gamma[idx] : V(0));
const U c_h = static_cast<U>((idx<n2) ? k_input[idx] : T(0));
const U c_loss = static_cast<U>((idx<n2) ? k_dout[idx] : V(0));
sum_loss1 += c_loss * gamma_idx;
sum_loss2 += c_loss * gamma_idx * (c_h - c_mean) * c_invvar;
if (!rms_only) {
sum_loss1 += c_loss * gamma_idx;
sum_loss2 += c_loss * gamma_idx * (c_h - c_mean) * c_invvar;
} else {
sum_loss2 += c_loss * gamma_idx * (c_h) * c_invvar;
}
}
#endif
} else {
......@@ -601,29 +756,43 @@ void cuComputeGradInput(
for (int k = 0; k < 4; ++k) {
const U c_h = static_cast<U>(k_input[l+k]);
const U c_loss = static_cast<U>(k_dout[l+k]);
sum_loss1 += c_loss;
sum_loss2 += c_loss * (c_h - c_mean) * c_invvar;
if (!rms_only) {
sum_loss1 += c_loss;
sum_loss2 += c_loss * (c_h - c_mean) * c_invvar;
} else {
sum_loss2 += c_loss * (c_h) * c_invvar;
}
}
}
for (; l < n2; ++l) {
const U c_h = static_cast<U>(k_input[l]);
const U c_loss = static_cast<U>(k_dout[l]);
sum_loss1 += c_loss;
sum_loss2 += c_loss * (c_h - c_mean) * c_invvar;
if (!rms_only) {
sum_loss1 += c_loss;
sum_loss2 += c_loss * (c_h - c_mean) * c_invvar;
} else {
sum_loss2 += c_loss * (c_h) * c_invvar;
}
}
#else
for( int l = 0; l < n2 ; l += numx) {
int idx = l + thrx;
const U c_h = static_cast<U>((idx<n2) ? k_input[idx] : T(0));
const U c_loss = static_cast<U>((idx<n2) ? k_dout[idx] : V(0));
sum_loss1 += c_loss;
sum_loss2 += c_loss * (c_h - c_mean) * c_invvar;
if (!rms_only) {
sum_loss1 += c_loss;
sum_loss2 += c_loss * (c_h - c_mean) * c_invvar;
} else {
sum_loss2 += c_loss * (c_h) * c_invvar;
}
}
#endif
}
// intra-warp reductions
for (int mask = blockDim.x / 2; mask > 0; mask /= 2) {
sum_loss1 += WARP_SHFL_XOR(sum_loss1, mask);
for (int mask = blockDim.x/2; mask > 0; mask /= 2) {
if (!rms_only) {
sum_loss1 += WARP_SHFL_XOR(sum_loss1, mask);
}
sum_loss2 += WARP_SHFL_XOR(sum_loss2, mask);
}
// inter-warp reductions
......@@ -634,25 +803,33 @@ void cuComputeGradInput(
// upper half of warps write to shared
if (threadIdx.y >= offset && threadIdx.y < 2*offset) {
const int wrt_i = (threadIdx.y - offset) * blockDim.x + threadIdx.x;
buf[2*wrt_i] = sum_loss1;
if (!rms_only) {
buf[2*wrt_i] = sum_loss1;
}
buf[2*wrt_i+1] = sum_loss2;
}
__syncthreads();
// lower half merges
if (threadIdx.y < offset) {
const int read_i = threadIdx.y * blockDim.x + threadIdx.x;
sum_loss1 += buf[2*read_i];
if (!rms_only) {
sum_loss1 += buf[2*read_i];
}
sum_loss2 += buf[2*read_i+1];
}
__syncthreads();
}
if (threadIdx.y == 0) {
buf[2*threadIdx.x] = sum_loss1;
if (!rms_only) {
buf[2*threadIdx.x] = sum_loss1;
}
buf[2*threadIdx.x+1] = sum_loss2;
}
__syncthreads();
if (threadIdx.y !=0) {
sum_loss1 = buf[2*threadIdx.x];
if (!rms_only) {
sum_loss1 = buf[2*threadIdx.x];
}
sum_loss2 = buf[2*threadIdx.x+1];
}
}
......@@ -665,8 +842,12 @@ void cuComputeGradInput(
const U c_h = static_cast<U>(k_input[l]);
const U c_loss = static_cast<U>(k_dout[l]);
U f_grad_input = fH * c_loss * gamma[l];
f_grad_input -= sum_loss1;
f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2;
if (!rms_only) {
f_grad_input -= sum_loss1;
f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2;
} else {
f_grad_input -= (c_h) * c_invvar * sum_loss2;
}
f_grad_input *= term1;
k_grad_input[l] = static_cast<T>(f_grad_input);
}
......@@ -675,8 +856,12 @@ void cuComputeGradInput(
const U c_h = static_cast<U>(k_input[l]);
const U c_loss = static_cast<U>(k_dout[l]);
U f_grad_input = fH * c_loss;
f_grad_input -= sum_loss1;
f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2;
if (!rms_only) {
f_grad_input -= sum_loss1;
f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2;
} else {
f_grad_input -= (c_h) * c_invvar * sum_loss2;
}
f_grad_input *= term1;
k_grad_input[l] = static_cast<T>(f_grad_input);
}
......@@ -686,6 +871,7 @@ void cuComputeGradInput(
}
}
template<typename T, typename U, typename V=T>
void HostApplyLayerNorm(
V* output,
......@@ -700,7 +886,7 @@ void HostApplyLayerNorm(
)
{
auto stream = at::cuda::getCurrentCUDAStream().stream();
const int warp_size = at::cuda::getCurrentDeviceProperties()->warpSize;
const int warp_size = at::cuda::warp_size();
dim3 threads(warp_size ,4, 1); // MI100 wavefront/warp = 64
#ifdef __HIP_PLATFORM_HCC__
// Optimization for ROCm MI100
......@@ -711,12 +897,40 @@ void HostApplyLayerNorm(
const dim3 blocks(1, std::min((uint64_t)n1, maxGridY), 1);
int nshared =
threads.y > 1 ?
threads.y*sizeof(U)+(threads.y/2)*sizeof(U) :
0;
threads.y*sizeof(U)+(threads.y/2)*sizeof(U) :
0;
cuApplyLayerNorm<<<blocks, threads, nshared, stream>>>(
output, mean, invvar, input, n1, n2, U(epsilon), gamma, beta, warp_size);
}
// Optimize HostRMSNormGradient for AMD GPUs: https://github.com/ROCmSoftwarePlatform/apex/pull/66/files
template<typename T, typename U, typename V=T>
void HostApplyRMSNorm(
V* output,
U* invvar,
const T* input,
int n1,
int n2,
double epsilon,
const V* gamma)
{
auto stream = at::cuda::getCurrentCUDAStream().stream();
const int warp_size = at::cuda::warp_size();
const uint64_t maxGridY = at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
const dim3 blocks(1, std::min((uint64_t)n1, maxGridY), 1);
dim3 threads(warp_size,4,1);
#ifdef __HIP_PLATFORM_HCC__
// Optimization for ROCm MI100
threads.y = 2;
#endif
int nshared =
threads.y > 1 ?
threads.y*sizeof(U)+(threads.y/2)*sizeof(U) :
0;
cuApplyRMSNorm<<<blocks, threads, nshared, stream>>>(
output, invvar, input, n1, n2, U(epsilon), gamma, warp_size);
}
void cuda_layer_norm(
at::Tensor* output,
at::Tensor* mean,
......@@ -739,7 +953,7 @@ void cuda_layer_norm(
using accscalar_t = at::acc_type<scalar_t_in, true>;
HostApplyLayerNorm<scalar_t_in, accscalar_t, scalar_t_out>(
output->DATA_PTR<scalar_t_out>(),
mean->DATA_PTR<accscalar_t>(),
mean->DATA_PTR<accscalar_t>(),
invvar->DATA_PTR<accscalar_t>(),
input->DATA_PTR<scalar_t_in>(),
n1,n2,
......@@ -749,6 +963,35 @@ void cuda_layer_norm(
)
}
void cuda_rms_norm(
at::Tensor* output,
at::Tensor* invvar,
at::Tensor* input,
int n1,
int n2,
#ifdef VERSION_GE_1_1
at::IntArrayRef normalized_shape,
#else
at::IntList normalized_shape,
#endif
at::Tensor* gamma,
double epsilon)
{
using namespace at;
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(
input->scalar_type(), output->scalar_type(), "rms_norm_cuda_kernel",
using accscalar_t = at::acc_type<scalar_t_in, true>;
HostApplyRMSNorm<scalar_t_in, accscalar_t, scalar_t_out>(
output->DATA_PTR<scalar_t_out>(),
invvar->DATA_PTR<accscalar_t>(),
input->DATA_PTR<scalar_t_in>(),
n1,n2,
epsilon,
gamma != NULL ? gamma->DATA_PTR<scalar_t_out>() : NULL);
)
}
template<typename T, typename U=float, typename V=T>
void HostLayerNormGradient(
const V* dout,
......@@ -766,10 +1009,11 @@ void HostLayerNormGradient(
)
{
auto stream = at::cuda::getCurrentCUDAStream().stream();
const int warp_size = at::cuda::getCurrentDeviceProperties()->warpSize;
const int warp_size = at::cuda::warp_size();
if (gamma != NULL && beta != NULL) {
// compute grad_gamma(j) and grad_beta(j)
// Optimize layer normalization for AMD GPUs: https://github.com/ROCmSoftwarePlatform/apex/pull/66/files
const int part_size = warp_size;
const dim3 threads2(warp_size, 4, 1);
const dim3 blocks2((n2+threads2.x-1) / threads2.x,part_size, 1);
......@@ -785,25 +1029,27 @@ void HostLayerNormGradient(
at::Tensor part_grad_gamma = at::empty({part_size,n2}, input->options().dtype(part_grad_dtype));
at::Tensor part_grad_beta = at::empty_like(part_grad_gamma);
cuComputePartGradGammaBeta<<<blocks2, threads2, nshared2, stream>>>(
dout,
input->DATA_PTR<T>(),
n1,n2,
mean,
invvar,
U(epsilon),
part_grad_gamma.DATA_PTR<U>(),
part_grad_beta.DATA_PTR<U>());
dout,
input->DATA_PTR<T>(),
n1,n2,
mean,
invvar,
U(epsilon),
part_grad_gamma.DATA_PTR<U>(),
part_grad_beta.DATA_PTR<U>(),
false);
const dim3 threads3(warp_size, 8, 1);
const dim3 blocks3((n2+threads2.x-1)/threads2.x,1,1);
const int nshared3 = threads3.x * threads3.y * sizeof(U);
cuComputeGradGammaBeta<<<blocks3, threads3, nshared3, stream>>>(
part_grad_gamma.DATA_PTR<U>(),
part_grad_beta.DATA_PTR<U>(),
part_size,
n1,n2,
grad_gamma,
grad_beta);
part_grad_gamma.DATA_PTR<U>(),
part_grad_beta.DATA_PTR<U>(),
part_size,
n1,n2,
grad_gamma,
grad_beta,
false);
}
// compute grad_input
......@@ -818,9 +1064,9 @@ void HostLayerNormGradient(
threads1.y = 2;
#endif
int nshared =
threads1.y > 1 ?
threads1.y*threads1.x*sizeof(U) :
0;
threads1.y > 1 ?
threads1.y*threads1.x*sizeof(U) :
0;
cuComputeGradInput<<<blocks1, threads1, nshared, stream>>>(
dout,
input->DATA_PTR<T>(),
......@@ -829,7 +1075,80 @@ void HostLayerNormGradient(
invvar,
U(epsilon),
gamma,
grad_input);
grad_input,
false);
}
// Optimize HostRMSNormGradient for AMD GPUs: https://github.com/ROCmSoftwarePlatform/apex/pull/66/files
template<typename T, typename U=float, typename V=T>
void HostRMSNormGradient(
const V* dout,
const U* invvar,
at::Tensor* input,
int n1,
int n2,
const V* gamma,
double epsilon,
T* grad_input,
V* grad_gamma)
{
auto stream = at::cuda::getCurrentCUDAStream().stream();
const int warp_size = at::cuda::warp_size();
if (gamma != NULL) {
const int part_size = warp_size;
const dim3 threads2(warp_size,4,1);
const dim3 blocks2((n2+threads2.x-1)/threads2.x,part_size,1);
const int nshared2_a = 2 * sizeof(U) * threads2.y * threads2.y * (threads2.x + 1);
const int nshared2_b = threads2.x * threads2.y * sizeof(U);
const int nshared2 = nshared2_a > nshared2_b ? nshared2_a : nshared2_b;
// note (mkozuki): I can hard code part_grad_gamma's dtype as float given that
// the `cuda_layer_norm_gradient` doesn't support double.
const auto part_grad_dtype =
(input->scalar_type() == at::ScalarType::Half || input->scalar_type() == at::ScalarType::BFloat16) ?
at::ScalarType::Float :
input->scalar_type();
at::Tensor part_grad_gamma = at::empty({part_size,n2}, input->options().dtype(part_grad_dtype));
cuComputePartGradGammaBeta<<<blocks2, threads2, nshared2, stream>>>(
dout,
input->DATA_PTR<T>(),
n1,n2,
invvar, // unused
invvar,
U(epsilon),
part_grad_gamma.DATA_PTR<U>(),
part_grad_gamma.DATA_PTR<U>(), /* unused */
true);
const dim3 threads3(warp_size,8,1);
const dim3 blocks3((n2+threads2.x-1)/threads2.x,1,1);
const int nshared3 = threads3.x * threads3.y * sizeof(U);
cuComputeGradGammaBeta<<<blocks3, threads3, nshared3, stream>>>(
part_grad_gamma.DATA_PTR<U>(),
part_grad_gamma.DATA_PTR<U>(), /* unused */
part_size,
n1,n2,
grad_gamma,
grad_gamma, /* unused */
true);
}
// compute grad_input
const uint64_t maxGridY = at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
const dim3 blocks1(1, std::min((uint64_t)n1, maxGridY), 1);
const dim3 threads1(warp_size,4,1);
int nshared =
threads1.y > 1 ?
threads1.y*threads1.x*sizeof(U) :
0;
cuComputeGradInput<<<blocks1, threads1, nshared, stream>>>(
dout,
input->DATA_PTR<T>(),
n1,n2,
invvar, /* unused */
invvar,
U(epsilon),
gamma,
grad_input,
true);
}
void cuda_layer_norm_gradient(
......@@ -873,3 +1192,40 @@ void cuda_layer_norm_gradient(
)
}
void cuda_rms_norm_gradient(
at::Tensor* dout,
at::Tensor* invvar,
at::Tensor* input,
int n1,
int n2,
#ifdef VERSION_GE_1_1
at::IntArrayRef normalized_shape,
#else
at::IntList normalized_shape,
#endif
at::Tensor* gamma,
double epsilon,
at::Tensor* grad_input,
at::Tensor* grad_gamma)
{
using namespace at;
// we can do away with `accscalar_t` as there're only three dtypes: fp32, fp16, bf16
// DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(
input->scalar_type(), gamma == NULL ? input->scalar_type() : gamma->scalar_type(), "cuComputeGradInputRMS",
using accscalar_t = at::acc_type<scalar_t_in, true>;
HostRMSNormGradient(
dout->DATA_PTR<scalar_t_out>(),
invvar->DATA_PTR<accscalar_t>(),
input,
n1,n2,
// TMJ pass NULL argument for gamma, beta, grad_gamma and grad_beta
// if gamma Tensor is NULL on input.
gamma != NULL ? gamma->DATA_PTR<scalar_t_out>() : NULL,
epsilon,
grad_input->DATA_PTR<scalar_t_in>(),
gamma != NULL ? grad_gamma->DATA_PTR<scalar_t_out>() : NULL);
)
}
......@@ -12,3 +12,6 @@ apex.normalization.fused_layer_norm
.. autoclass:: FusedLayerNorm
:members:
.. autoclass:: FusedRMSNorm
:members:
import itertools
import unittest
import os
import random
import torch
import apex
from torch.autograd import Variable
class TestFusedLayerNorm(unittest.TestCase):
dtype = torch.float
elementwise_affine = False
normalized_shape = [32, 16]
rtol, atol = None, None
fwd_thresholds = dict(rtol=None, atol=None)
bwd_thresholds = dict(rtol=None, atol=None)
mixed_fused = False
def setUp(self):
# bias and weight are set to 0 and 1 respectively, so no need to copy parameters from cpu module to the gpu one
self.module_cpu_ = apex.normalization.FusedLayerNorm(normalized_shape=[32, 16], elementwise_affine=False).cpu()
self.module_cuda_ = apex.normalization.FusedLayerNorm(normalized_shape=[32, 16], elementwise_affine=False).cuda()
if not self.mixed_fused:
self.module_cpu_ = apex.normalization.FusedLayerNorm(
normalized_shape=self.normalized_shape, elementwise_affine=self.elementwise_affine).cpu()
self.module_cuda_ = apex.normalization.FusedLayerNorm(
normalized_shape=self.normalized_shape, elementwise_affine=self.elementwise_affine).to(device="cuda", dtype=self.dtype)
else:
assert self.elementwise_affine
self.module_cpu_ = apex.normalization.MixedFusedLayerNorm(
normalized_shape=self.normalized_shape).cpu()
self.module_cuda_ = apex.normalization.MixedFusedLayerNorm(
normalized_shape=self.normalized_shape).to(device="cuda", dtype=self.dtype)
def _check_same_output(self, batch_size, contiguous):
torch.cuda.manual_seed(42)
if contiguous:
input_shape = [batch_size] + self.normalized_shape
input_ = torch.randn(input_shape, device="cpu").requires_grad_(True)
input_cuda_ = input_.to(device="cuda", dtype=self.dtype).detach().requires_grad_(True)
self.assertTrue(input_.is_contiguous())
self.assertTrue(input_cuda_.is_contiguous())
else:
input_shape = [batch_size] + self.normalized_shape
input_shape = [batch_size * 3] + [self.normalized_shape[0] * 5, self.normalized_shape[1] * 3]
input_src_ = torch.randn(input_shape, device="cpu")
input_ = input_src_[::3, ::5, ::3].detach().requires_grad_(True)
input_cuda_ = input_src_.to(device="cuda", dtype=self.dtype)[::3, ::5, ::3].detach().requires_grad_(True)
# make sure that tensors are NOT contiguous.
self.assertFalse(input_.is_contiguous())
self.assertFalse(input_cuda_.is_contiguous())
out_cpu_ = self.module_cpu_(input_)
gO = torch.rand_like(out_cpu_)
out_cpu_.backward(gO)
out_cuda_ = self.module_cuda_(input_cuda_)
gO = gO.to(device="cuda", dtype=self.dtype)
out_cuda_.backward(gO)
self.assertFalse(out_cpu_.is_cuda)
self.assertTrue(out_cuda_.is_cuda)
# TODO (mkozuki): `torch.testing.assert_allclose` is deprecated.
# Use `torch.testing.assert_close`.
# See https://github.com/pytorch/pytorch/issues/61844
torch.testing.assert_allclose(
out_cpu_.to(device="cuda", dtype=self.dtype), out_cuda_, **self.fwd_thresholds)
torch.testing.assert_allclose(
input_.grad.to(device="cuda", dtype=self.dtype), input_cuda_.grad, **self.bwd_thresholds)
def _test_same_output(self, batch_size):
for contiguous in (True, False):
with self.subTest(contiguous=contiguous):
self._check_same_output(batch_size, contiguous)
def test_layer_norm(self):
self._test_same_output(16)
def test_large_batch(self):
self._test_same_output(65536)
class TestFusedRMSNorm(unittest.TestCase):
dtype = torch.float
elementwise_affine = False
normalized_shape = [32, 16]
rtol, atol = None, None
fwd_thresholds = dict(rtol=None, atol=None)
bwd_thresholds = dict(rtol=None, atol=None)
mixed_fused = False
def setUp(self):
# weight is initialized to 1 (RMSNorm has no bias), so no need to copy parameters from the cpu module to the gpu one
if not self.mixed_fused:
self.module_cpu_ = apex.normalization.FusedRMSNorm(
normalized_shape=self.normalized_shape, elementwise_affine=self.elementwise_affine).cpu()
self.module_cuda_ = apex.normalization.FusedRMSNorm(
normalized_shape=self.normalized_shape, elementwise_affine=self.elementwise_affine).to(device="cuda", dtype=self.dtype)
else:
assert self.elementwise_affine
self.module_cpu_ = apex.normalization.MixedFusedRMSNorm(
normalized_shape=self.normalized_shape).cpu()
self.module_cuda_ = apex.normalization.MixedFusedRMSNorm(
normalized_shape=self.normalized_shape).to(device="cuda", dtype=self.dtype)
def _check_same_output(self, batch_size, contiguous):
torch.cuda.manual_seed(42)
self.input_ = torch.randn((batch_size, *self.module_cpu_.normalized_shape), device="cpu").requires_grad_(True)
self.input_cuda_ = self.input_.cuda().detach().requires_grad_(True)
out_cpu_ = self.module_cpu_(self.input_)
if contiguous:
input_shape = [batch_size] + self.normalized_shape
input_ = torch.randn(input_shape, device="cpu").requires_grad_(True)
input_cuda_ = input_.to(device="cuda", dtype=self.dtype).detach().requires_grad_(True)
self.assertTrue(input_.is_contiguous())
self.assertTrue(input_cuda_.is_contiguous())
else:
input_shape = [batch_size] + self.normalized_shape
input_shape = [batch_size * 3] + [self.normalized_shape[0] * 5, self.normalized_shape[1] * 3]
input_src_ = torch.randn(input_shape, device="cpu")
input_ = input_src_[::3, ::5, ::3].detach().requires_grad_(True)
input_cuda_ = input_src_.to(device="cuda", dtype=self.dtype)[::3, ::5, ::3].detach().requires_grad_(True)
# make sure that tensors are NOT contiguous.
self.assertFalse(input_.is_contiguous())
self.assertFalse(input_cuda_.is_contiguous())
out_cpu_ = self.module_cpu_(input_)
gO = torch.rand_like(out_cpu_)
out_cpu_.backward(gO)
out_cuda_ = self.module_cuda_(self.input_cuda_)
gO = gO.cuda()
out_cuda_ = self.module_cuda_(input_cuda_)
# TODO (mkozuki): `torch.testing.assert_allclose` is deprecated.
# Use `torch.testing.assert_close`.
# See https://github.com/pytorch/pytorch/issues/61844
torch.testing.assert_allclose(
out_cpu_.to(device="cuda", dtype=self.dtype), out_cuda_.clone().detach(), **self.fwd_thresholds)
gO = gO.to(device="cuda", dtype=self.dtype)
out_cuda_.backward(gO)
assert out_cpu_.is_cuda == False
assert out_cuda_.is_cuda == True
torch.testing.assert_allclose(out_cpu_, out_cuda_.cpu())
torch.testing.assert_allclose(self.input_.grad, self.input_cuda_.grad.cpu())
self.assertFalse(out_cpu_.is_cuda)
self.assertTrue(out_cuda_.is_cuda)
torch.testing.assert_allclose(
input_.grad.to(device="cuda", dtype=self.dtype), input_cuda_.grad, **self.bwd_thresholds)
if self.elementwise_affine:
torch.testing.assert_allclose(self.module_cpu_.weight.grad.to(device="cuda", dtype=self.dtype),
self.module_cuda_.weight.grad, **self.bwd_thresholds)
def _test_same_output(self, batch_size):
for contiguous in (True, False):
with self.subTest(contiguous=contiguous):
self._check_same_output(batch_size, contiguous)
def test_layer_norm(self):
self._test_same_output(16)
......@@ -38,6 +149,9 @@ class TestFusedLayerNorm(unittest.TestCase):
class TestFusedLayerNormElemWise(TestFusedLayerNorm):
elementwise_affine = True
class TestMixedFusedLayerNormElemWise(TestFusedLayerNorm):
elementwise_affine = True
mixed_fused = True
class TestFusedLayerNormElemWiseHalf(TestFusedLayerNormElemWise):
dtype = torch.half
......@@ -45,6 +159,34 @@ class TestFusedLayerNormElemWiseHalf(TestFusedLayerNormElemWise):
def test_large_batch(self):
self.skipTest("Skip to save time")
class TestFusedLayerNormElemWiseBFloat16(TestFusedLayerNormElemWise):
dtype = torch.bfloat16
# NOTE (mkozuki): [BFloat16 Layer Norm flakiness]
# Use thresholds larger than those used in pytorch, see
# https://github.com/pytorch/pytorch/blob/72274e2a2fd55019ec860e1743dbdc5b0c5a5624/torch/testing/_asserts.py#L26
fwd_thresholds = dict(rtol=1.6e-2, atol=3e-4)
bwd_thresholds = dict(rtol=1.6e-2, atol=3e-3)
def test_large_batch(self):
self.skipTest("Skip to save time")
class TestFusedRMSNormElemWise(TestFusedRMSNorm):
bwd_thresholds = dict(rtol=2e-3, atol=2e-4)
elementwise_affine = True
class TestMixedFusedRMSNormElemWise(TestFusedRMSNorm):
bwd_thresholds = dict(rtol=2e-3, atol=2e-4)
elementwise_affine = True
mixed_fused = True
class TestFusedRMSNormElemWiseHalf(TestFusedRMSNormElemWise):
dtype = torch.half
bwd_thresholds = dict(rtol=1.6e-2, atol=3e-3)
def test_large_batch(self):
self.skipTest("Skip to save time")
class TestFusedLayerNormElemWiseBFloat16(TestFusedLayerNormElemWise):
dtype = torch.bfloat16
......@@ -68,6 +210,16 @@ def _prep_layers(normalized_shape, elementwise_affine, dtype):
return native, fused
def _prep_rms_layers(normalized_shape, elementwise_affine, dtype):
native = apex.normalization.FusedRMSNorm(
normalized_shape=normalized_shape, elementwise_affine=elementwise_affine
)
fused = apex.normalization.FusedRMSNorm(
normalized_shape=normalized_shape, elementwise_affine=elementwise_affine
).cuda()
return native, fused
def _prep_inputs(batch_size, normalized_shape, dtype):
shape = (batch_size, *normalized_shape)
fused = torch.randn(shape).cuda().requires_grad_(True)
......@@ -75,12 +227,8 @@ def _prep_inputs(batch_size, normalized_shape, dtype):
native = fused.clone().to(dtype).requires_grad_(True)
return native, fused
TORCH_MAJOR, TORCH_MINOR = int(torch.__version__.split('.')[0]), int(torch.__version__.split('.')[1])
if (TORCH_MAJOR <= 1 and TORCH_MINOR < 10):
autocast_dtypes = (torch.half,)
else:
autocast_dtypes = (torch.half, torch.bfloat16) if torch.cuda.is_bf16_supported() else (torch.half,)
autocast_dtypes = (torch.half, torch.bfloat16) if torch.cuda.is_bf16_supported() else (torch.half,)
class TestAutocastFusedLayerNorm(unittest.TestCase):
bf16_fwd_thresholds = dict(rtol=1.6e-2, atol=3e-4)
......@@ -106,6 +254,43 @@ class TestAutocastFusedLayerNorm(unittest.TestCase):
expected.backward(g_native)
actual.backward(g_fused)
tols = {'rtol': None, 'atol': None} if dtype == torch.half else TestAutocastFusedLayerNorm.bf16_bwd_thresholds
torch.testing.assert_allclose(native_x.grad, fused_x.grad, **tols)
def test_autocast(self):
for (dtype, elementwise_affine) in itertools.product(autocast_dtypes, (True, False)):
with self.subTest(f"{dtype}-{elementwise_affine}"):
self._run_test(dtype, elementwise_affine)
@unittest.skip("Skipped on ROCm5.2 due to the failure of reproducing the issue locally. (Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!) Please refer to https://github.com/ROCmSoftwarePlatform/apex/pull/78")
class TestAutocastFusedRMSNorm(unittest.TestCase):
bf16_fwd_thresholds = dict(rtol=1.6e-2, atol=3e-4)
bf16_bwd_thresholds = dict(rtol=1.6e-2, atol=3e-3)
def setUp(self):
self.batch_size = 16
self.normalized_shape = [32, 16]
def _run_test(self, dtype, elementwise_affine):
native, fused = _prep_rms_layers(self.normalized_shape, elementwise_affine, dtype)
native_x, fused_x = _prep_inputs(self.batch_size, self.normalized_shape, dtype)
expected = native(native_x.cpu())
with torch.cuda.amp.autocast(dtype=dtype):
actual = fused(fused_x)
tols = {'rtol': None, 'atol': None} if dtype == torch.half else TestAutocastFusedRMSNorm.bf16_fwd_thresholds
torch.testing.assert_allclose(actual, expected.detach().clone().cuda(), **tols)
g_native = torch.rand_like(expected)
with torch.no_grad():
g_fused = g_native.detach().clone().cuda()
expected.backward(g_native)
actual.backward(g_fused)
tols = {'rtol': None, 'atol': None} if dtype == torch.half else TestAutocastFusedRMSNorm.bf16_bwd_thresholds
torch.testing.assert_allclose(native_x.grad.cuda(), fused_x.grad, **tols)
if __name__ == '__main__':
unittest.main()
def test_autocast(self):
for (dtype, elementwise_affine) in itertools.product(autocast_dtypes, (True, False)):
with self.subTest(f"{dtype}-{elementwise_affine}"):
self._run_test(dtype, elementwise_affine)