Commit 6738d947 authored by Tri Dao

[LayerNorm] Implement RMS Norm

parent a1f49a2b
......@@ -2,6 +2,7 @@ This CUDA extension implements fused dropout + residual + LayerNorm, building on
Apex's [FastLayerNorm](https://github.com/NVIDIA/apex/tree/master/apex/contrib/layer_norm).
We add dropout and residual, and make it work for both pre-norm and post-norm architectures.
We also make it work for more hidden dimensions (all dimensions divisible by 8, up to 6144).
We also implement RMSNorm as an option.
If you want to use it for dimensions larger than 6k, please file an issue.
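For reference, a minimal usage sketch of the new RMSNorm path via the Python wrappers this commit adds in flash_attn.ops.rms_norm; the shapes, dtypes, and dropout probability below are illustrative, and the CUDA extension must be compiled first:

import torch
from flash_attn.ops.rms_norm import DropoutAddRMSNorm, dropout_add_rms_norm

# Module form: RMSNorm has no bias, so weight is the only parameter.
norm = DropoutAddRMSNorm(1024, prenorm=True, p=0.1, device='cuda', dtype=torch.float16)
x0 = torch.randn(8, 512, 1024, device='cuda', dtype=torch.float16)  # current branch
x1 = torch.randn_like(x0)                                           # residual stream
out, residual = norm(x0, x1)   # with prenorm=True the updated residual is also returned

# Functional form: pass None for the bias to get RMSNorm semantics.
out = dropout_add_rms_norm(x0, x1, norm.weight, None, 0.1, 1e-5)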
......
......@@ -44,6 +44,7 @@ struct ParamsBase {
, colscale(nullptr)
, dropout_keep_p(1.f)
, dropout_scale(1.f)
, is_rms_norm(false)
, workspace(nullptr)
, barrier(nullptr)
{
......@@ -75,6 +76,8 @@ struct ParamsBase {
float dropout_scale;
float rowscale_const;
bool is_rms_norm;
// Multi-CTA workspace in gmem.
void *workspace;
......
......@@ -83,7 +83,7 @@ layer_norm::BwdFunction & get_bwd_launcher(torch::Dtype wtype, torch::Dtype ityp
std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0, // Input: BxSxhidden_size
c10::optional<const at::Tensor> &x1_, // Residual: BxSxhidden_size
const at::Tensor &gamma, // hidden_size
const at::Tensor &beta, // hidden_size
c10::optional<const at::Tensor> &beta_, // hidden_size
c10::optional<const at::Tensor> &rowscale_, // BxS
c10::optional<const at::Tensor> &colscale_, // hidden_size
c10::optional<const at::Tensor> &x0_subset_, // BxS
......@@ -93,7 +93,8 @@ std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0, // Input:
const float rowscale_const,
const int64_t z_numrows,
c10::optional<at::Generator> gen_,
bool residual_in_fp32
bool residual_in_fp32=false,
bool is_rms_norm=false
) {
auto itype = x0.scalar_type();
auto rtype = x1_.has_value()
......@@ -104,11 +105,8 @@ std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0, // Input:
auto ctype = torch::kFloat32;
auto mtype = torch::kUInt8;
TORCH_CHECK(beta.dtype() == wtype);
TORCH_CHECK(x0.is_cuda())
TORCH_CHECK(gamma.is_cuda())
TORCH_CHECK(beta.is_cuda())
TORCH_CHECK(x0.is_contiguous());
// c10::IntArrayRef does not own the storage, so we need to construct a vector.
......@@ -123,6 +121,14 @@ std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0, // Input:
const int cols = sizes[1];
auto hidden_size = gamma.numel();
if (beta_.has_value()) {
auto beta = beta_.value();
TORCH_CHECK(beta.dtype() == wtype);
TORCH_CHECK(beta.is_cuda())
TORCH_CHECK(beta.is_contiguous());
TORCH_CHECK(gamma.sizes() == beta.sizes());
}
if (x1_.has_value()) {
auto x1 = x1_.value();
TORCH_CHECK(x1.is_cuda())
......@@ -161,7 +167,6 @@ std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0, // Input:
TORCH_CHECK(z_subset.dtype() == torch::kInt32);
}
TORCH_CHECK(gamma.sizes() == beta.sizes());
TORCH_CHECK(hidden_size == cols);
TORCH_CHECK((hidden_size % 8 == 0) && (hidden_size <= 6144));
......@@ -218,12 +223,13 @@ std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0, // Input:
params.mu = mu.data_ptr();
params.rs = rsigma.data_ptr();
params.gamma = gamma.data_ptr();
params.beta = beta.data_ptr();
params.beta = beta_.has_value() ? beta_.value().data_ptr() : nullptr;
params.z = z.data_ptr();
params.epsilon = epsilon;
params.dropout_scale = 1.f / (1.f - dropout_p);
params.inverse_cols = 1.f / float(params.cols);
params.rowscale_const = rowscale_const;
params.is_rms_norm = is_rms_norm;
if (dropout_p > 0.f) {
// number of times random will be generated per thread, to offset philox counter in thc random
......@@ -268,7 +274,8 @@ std::vector<at::Tensor> dropout_add_ln_bwd(const at::Tensor &dz, // BxSxhidd
const float dropout_p,
const float rowscale_const,
const int64_t x0_numrows,
const bool has_residual
const bool has_residual,
bool is_rms_norm=false
) {
auto itype = dz.scalar_type();
......@@ -431,6 +438,7 @@ std::vector<at::Tensor> dropout_add_ln_bwd(const at::Tensor &dz, // BxSxhidd
params.dropout_scale = 1.f / (1.f - dropout_p);
params.inverse_cols = 1.f / float(params.cols);
params.rowscale_const = rowscale_const;
params.is_rms_norm = is_rms_norm;
if( launch_params.barrier_size > 0 ) {
// TODO Any way to avoid this?
......@@ -453,6 +461,14 @@ std::vector<at::Tensor> dropout_add_ln_bwd(const at::Tensor &dz, // BxSxhidd
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.doc() = "CUDA DropoutAddLayerNorm";
m.def("dropout_add_ln_fwd", &dropout_add_ln_fwd, "Run Dropout + Add + LayerNorm forward kernel");
m.def("dropout_add_ln_bwd", &dropout_add_ln_bwd, "Run Dropout + Add + LayerNorm backward kernel");
m.def("dropout_add_ln_fwd", &dropout_add_ln_fwd, "Run Dropout + Add + LayerNorm forward kernel",
py::arg("x0"), py::arg("x1"), py::arg("gamma"), py::arg("beta"),
py::arg("rowscale_"), py::arg("colscale_"), py::arg("x0_subset_"), py::arg("z_subset_"),
py::arg("dropout_p"), py::arg("epsilon"), py::arg("rowscale_const"), py::arg("z_numrows"),
py::arg("gen_"), py::arg("residual_in_fp32")=false, py::arg("is_rms_norm")=false);
m.def("dropout_add_ln_bwd", &dropout_add_ln_bwd, "Run Dropout + Add + LayerNorm backward kernel",
py::arg("dz"), py::arg("dx_"), py::arg("x"), py::arg("x0_"), py::arg("dmask_"), py::arg("mu"),
py::arg("rsigma"), py::arg("gamma"), py::arg("rowscale_"), py::arg("colscale_"),
py::arg("x0_subset_"), py::arg("z_subset_"), py::arg("dropout_p"), py::arg("rowscale_const"),
py::arg("x0_numrows"), py::arg("has_residual"), py::arg("is_rms_norm")=false);
}
......@@ -125,7 +125,7 @@ void ln_bwd_kernel(layer_norm::BwdParams params) {
#pragma unroll
for( int jt = 0; jt < NUM_ELTS; jt++ ) {
compute_t x_tmp = x.data.elt[jt];
compute_t y_tmp = rs_r * (x_tmp - mu_r);
compute_t y_tmp = rs_r * (x_tmp - (!params.is_rms_norm ? mu_r : 0.f));
compute_t dy_tmp = compute_t(gamma[it].data.elt[jt]) * compute_t(dz.data.elt[jt]);
compute_t dz_tmp = dz.data.elt[jt];
......@@ -173,7 +173,7 @@ void ln_bwd_kernel(layer_norm::BwdParams params) {
if (load_dz) {
compute_t dy_tmp = dy[it * NUM_ELTS + jt];
compute_t y_tmp = y[it * NUM_ELTS + jt];
compute_t dx_tmp = rs_r * (dy_tmp - (mdyy_local * y_tmp + mdy_local));
compute_t dx_tmp = rs_r * (dy_tmp - (mdyy_local * y_tmp + (!params.is_rms_norm ? mdy_local : 0.f)));
dx_tmp_res = prenorm ? dx_tmp + compute_t(dx[it].data.elt[jt]) : dx_tmp;
} else {
dx_tmp_res = prenorm ? compute_t(dx[it].data.elt[jt]) : 0.f;
......
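For reference, these backward hunks follow the standard fused LayerNorm backward, with the per-row mean written as an overline and \(\widehat{dy} = \gamma \odot dz\); the only RMSNorm-specific change is that the mean terms (mu_r and mdy_local in the code) are zeroed:

\[
\text{LayerNorm:}\quad y = \frac{x-\mu}{\sigma},\qquad
dx = \frac{1}{\sigma}\Bigl(\widehat{dy} - \overline{\widehat{dy}} - y\,\overline{\widehat{dy}\,y}\Bigr)
\]
\[
\text{RMSNorm:}\quad y = \frac{x}{r},\quad r = \sqrt{\tfrac{1}{n}\sum_i x_i^2 + \epsilon},\qquad
dx = \frac{1}{r}\Bigl(\widehat{dy} - y\,\overline{\widehat{dy}\,y}\Bigr)
\]

Reading off the code, mdy_local and mdyy_local play the roles of \(\overline{\widehat{dy}}\) and \(\overline{\widehat{dy}\,y}\), and rs_r is \(1/\sigma\) (LayerNorm) or \(1/r\) (RMSNorm).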
......@@ -89,7 +89,11 @@ void ln_fwd_kernel(FwdParams params) {
for( int it = 0; it < LDGS; it++ ) {
if (Is_even_cols || (it < num_valid_ldgs)) {
gamma[it].load_from(params.gamma, idx);
beta[it].load_from(params.beta, idx);
if (params.beta != nullptr) {
beta[it].load_from(params.beta, idx);
} else {
beta[it].zero_();
}
if (Has_colscale) { colscale[it].load_from(params.colscale, idx); }
idx += VEC_COLS_PER_LDG;
}
......@@ -159,7 +163,7 @@ void ln_fwd_kernel(FwdParams params) {
mu_ptr[row] = mu;
}
compute_t rs = rsqrtf(m2 * params.inverse_cols + params.epsilon);
compute_t rs = rsqrtf(m2 * params.inverse_cols + params.epsilon + (!params.is_rms_norm ? 0.f : mu * mu));
if( bidn == 0 && warp_n == 0 && lane == 0 ) {
rs_ptr[row] = rs;
......@@ -174,7 +178,7 @@ void ln_fwd_kernel(FwdParams params) {
Ovec z;
#pragma unroll
for( int jt = 0; jt < NUM_ELTS; jt++ ) {
compute_t y_ij = compute_t(rs * (xf[it * NUM_ELTS + jt] - mu));
compute_t y_ij = compute_t(rs * (xf[it * NUM_ELTS + jt] - (!params.is_rms_norm ? mu : 0.f)));
compute_t g_ij = gamma[it].data.elt[jt];
compute_t b_ij = beta[it].data.elt[jt];
z.data.elt[jt] = output_t(g_ij * y_ij + b_ij);
......
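The forward-kernel change above reuses the statistics already computed for LayerNorm: with m2 = sum_i (x_i - mu)^2, one has mean(x^2) = m2/n + mu^2, so the is_rms_norm path obtains the RMSNorm denominator by adding mu * mu under the rsqrt and then dropping mu from the (x - mu) term. A small NumPy check of that identity (illustrative, not part of the extension):

import numpy as np

x = np.random.randn(4, 256)   # a few rows of a hypothetical hidden dimension
eps = 1e-5
n = x.shape[-1]

mu = x.mean(-1, keepdims=True)
m2 = ((x - mu) ** 2).sum(-1, keepdims=True)

rs_layernorm = 1.0 / np.sqrt(m2 / n + eps)                  # 1 / sqrt(var + eps)
rs_rmsnorm = 1.0 / np.sqrt(m2 / n + eps + mu ** 2)          # what the kernel computes
rs_rmsnorm_direct = 1.0 / np.sqrt((x ** 2).mean(-1, keepdims=True) + eps)
assert np.allclose(rs_rmsnorm, rs_rmsnorm_direct)

y_rmsnorm = rs_rmsnorm * x    # RMSNorm output before scaling by gamma (no mean subtraction)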
......@@ -308,6 +308,13 @@ struct Vec {
}
}
inline __device__ void zero_() {
#pragma unroll
for( int it = 0; it < NUM_ELT; it++ ) {
this->data.elt[it] = Elt_type(0.f);
}
}
inline __device__ void load_from(const void *base_ptr, const size_t idx) {
this->data.vec = static_cast<const Vec_type *>(base_ptr)[idx];
}
......
......@@ -8,7 +8,7 @@ import dropout_layer_norm
def _dropout_add_layer_norm_forward(x0, x1, gamma, beta, rowscale, colscale, dropout_p, epsilon,
residual_in_fp32):
residual_in_fp32=False, is_rms_norm=False):
""" Assume that arguments are contiguous
"""
hidden_size = gamma.numel()
......@@ -17,7 +17,7 @@ def _dropout_add_layer_norm_forward(x0, x1, gamma, beta, rowscale, colscale, dro
rowscale = rowscale.view(-1) if rowscale is not None else None
zmat, xmat, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd(
x0mat, x1mat, gamma, beta, rowscale, colscale, None, None, dropout_p, epsilon,
1.0, 0, None, residual_in_fp32
1.0, 0, None, residual_in_fp32, is_rms_norm
)
# dmask is None if dropout_p == 0.0
# xmat is None if dropout_p == 0.0 and x1 is None and residual_dtype != input_dtype
......@@ -25,7 +25,7 @@ def _dropout_add_layer_norm_forward(x0, x1, gamma, beta, rowscale, colscale, dro
def _dropout_add_layer_norm_backward(dz, dx, x, x0, dmask, mu, rsigma, gamma, rowscale, colscale,
dropout_p, has_residual):
dropout_p, has_residual, is_rms_norm=False):
""" Assume that arguments are contiguous
dx == None means that it was a post-norm architecture
(x = drop(x0) + x1 was not returned in the fwd).
......@@ -41,7 +41,7 @@ def _dropout_add_layer_norm_backward(dz, dx, x, x0, dmask, mu, rsigma, gamma, ro
assert x0 is not None, 'x0 is required to compute the gradient of colscale'
dx0mat, dx1mat, dgamma, dbeta, _, _, *rest = dropout_layer_norm.dropout_add_ln_bwd(
dzmat, dxmat, xmat, x0mat, dmask, mu, rsigma, gamma, rowscale, colscale, None, None,
dropout_p, 1.0, 0, has_residual
dropout_p, 1.0, 0, has_residual, is_rms_norm
)
# dx1mat is None if not has_residual
if colscale is None:
......@@ -53,7 +53,7 @@ def _dropout_add_layer_norm_backward(dz, dx, x, x0, dmask, mu, rsigma, gamma, ro
def _dropout_add_layer_norm_subset_forward(x0, x1, gamma, beta, colscale, x0_subset, out_subset,
dropout_p, epsilon, rowscale_const, out_numrows,
residual_in_fp32):
residual_in_fp32=False, is_rms_norm=False):
""" Assume that arguments are contiguous
"""
hidden_size = gamma.numel()
......@@ -63,7 +63,7 @@ def _dropout_add_layer_norm_subset_forward(x0, x1, gamma, beta, colscale, x0_sub
out_subset = out_subset.view(-1) if out_subset is not None else None
zmat, xmat, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd(
x0mat, x1mat, gamma, beta, None, colscale, x0_subset, out_subset, dropout_p, epsilon,
rowscale_const, out_numrows, None, residual_in_fp32
rowscale_const, out_numrows, None, residual_in_fp32, is_rms_norm
)
# dmask is None if dropout_p == 0.0
# xmat is None if dropout_p == 0.0 and x1 is None and residual_dtype != input_dtype
......@@ -72,7 +72,7 @@ def _dropout_add_layer_norm_subset_forward(x0, x1, gamma, beta, colscale, x0_sub
def _dropout_add_layer_norm_subset_backward(dz, dx, x, x0, dmask, mu, rsigma, gamma, colscale,
x0_subset, out_subset, dropout_p, rowscale_const,
x0_numrows, has_residual):
x0_numrows, has_residual, is_rms_norm=False):
""" Assume that arguments are contiguous
dx == None means that it was a post-norm architecture
(x = drop(x0) + x1 was not returned in the fwd).
......@@ -89,7 +89,7 @@ def _dropout_add_layer_norm_subset_backward(dz, dx, x, x0, dmask, mu, rsigma, ga
assert x0 is not None, 'x0 is required to compute the gradient of colscale'
dx0mat, dx1mat, dgamma, dbeta, _, _, *rest = dropout_layer_norm.dropout_add_ln_bwd(
dzmat, dxmat, xmat, x0mat, dmask, mu, rsigma, gamma, None, colscale, x0_subset, out_subset,
dropout_p, rowscale_const, x0_numrows, has_residual
dropout_p, rowscale_const, x0_numrows, has_residual, is_rms_norm
)
# dx1mat is None if not has_residual
if colscale is None:
......@@ -101,16 +101,17 @@ def _dropout_add_layer_norm_subset_backward(dz, dx, x, x0, dmask, mu, rsigma, ga
class DropoutAddLayerNormFn(torch.autograd.Function):
@staticmethod
def forward(ctx, x0, x1, gamma, beta, rowscale, colscale, dropout_p, epsilon, residual_in_fp32,
prenorm=False, return_dmask=False):
def forward(ctx, x0, x1, gamma, beta, rowscale, colscale, dropout_p, epsilon,
residual_in_fp32=False, prenorm=False, is_rms_norm=False, return_dmask=False):
x0 = x0.contiguous()
x1 = x1.contiguous() if x1 is not None else None
gamma = gamma.contiguous()
beta = beta.contiguous()
beta = beta.contiguous() if beta is not None else None
rowscale = rowscale.contiguous() if rowscale is not None else None
colscale = colscale.contiguous() if colscale is not None else None
zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_forward(
x0, x1, gamma, beta, rowscale, colscale, dropout_p, epsilon, residual_in_fp32
x0, x1, gamma, beta, rowscale, colscale, dropout_p, epsilon,
residual_in_fp32, is_rms_norm
)
# Only need to save x0 if we need to compute gradient wrt colscale
x0_saved = x0 if colscale is not None else None
......@@ -118,6 +119,8 @@ class DropoutAddLayerNormFn(torch.autograd.Function):
ctx.prenorm = prenorm
ctx.dropout_p = dropout_p
ctx.has_residual = x1 is not None
ctx.is_rms_norm = is_rms_norm
ctx.has_beta = beta is not None
if not return_dmask:
return (zmat.view(x0.shape) if not prenorm
else (zmat.view(x0.shape), xmat.view(x0.shape)))
......@@ -138,26 +141,29 @@ class DropoutAddLayerNormFn(torch.autograd.Function):
dropout_p = ctx.dropout_p
has_residual = ctx.has_residual
dx0mat, dx1mat, dgamma, dbeta, *rest = _dropout_add_layer_norm_backward(
dz, dx, x, x0, dmask, mu, rsigma, gamma, rowscale, colscale, dropout_p, has_residual
dz, dx, x, x0, dmask, mu, rsigma, gamma, rowscale, colscale, dropout_p, has_residual,
ctx.is_rms_norm
)
dx0 = dx0mat.view(x.shape)
dx1 = dx1mat.view(x.shape) if dx1mat is not None else None
dcolscale = rest[0] if colscale is not None else None
return dx0, dx1, dgamma, dbeta, None, dcolscale, None, None, None, None, None
return (dx0, dx1, dgamma, dbeta if ctx.has_beta else None, None, dcolscale, None, None,
None, None, None, None)
class DropoutAddLayerNormSubsetFn(torch.autograd.Function):
@staticmethod
def forward(ctx, x0, x1, gamma, beta, colscale, x0_subset, out_subset, dropout_p, epsilon,
rowscale_const, out_numrows, residual_in_fp32, prenorm=False, return_dmask=False):
rowscale_const, out_numrows, residual_in_fp32=False,
prenorm=False, is_rms_norm=False, return_dmask=False):
x0 = x0.contiguous()
x1 = x1.contiguous() if x1 is not None else None
gamma = gamma.contiguous()
beta = beta.contiguous()
beta = beta.contiguous() if beta is not None else None
colscale = colscale.contiguous() if colscale is not None else None
zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_subset_forward(
x0, x1, gamma, beta, colscale, x0_subset, out_subset, dropout_p, epsilon,
rowscale_const, out_numrows, residual_in_fp32
rowscale_const, out_numrows, residual_in_fp32, is_rms_norm
)
# Only need to save x0 if we need to compute gradient wrt colscale
x0_saved = x0 if colscale is not None else None
......@@ -169,6 +175,8 @@ class DropoutAddLayerNormSubsetFn(torch.autograd.Function):
ctx.rowscale_const = rowscale_const
ctx.x0_numrows = x0.shape[:-1].numel()
ctx.has_residual = x1 is not None
ctx.is_rms_norm = is_rms_norm
ctx.has_beta = beta is not None
z_shape = (-1, *x0.shape[1:])
if not return_dmask:
return (zmat.view(z_shape) if not prenorm
......@@ -191,13 +199,13 @@ class DropoutAddLayerNormSubsetFn(torch.autograd.Function):
has_residual = ctx.has_residual
dx0mat, dx1mat, dgamma, dbeta, *rest = _dropout_add_layer_norm_subset_backward(
dz, dx, x, x0, dmask, mu, rsigma, gamma, colscale, x0_subset, out_subset, dropout_p,
ctx.rowscale_const, ctx.x0_numrows, has_residual
ctx.rowscale_const, ctx.x0_numrows, has_residual, ctx.is_rms_norm
)
dx0 = dx0mat.view(-1, *x.shape[1:])
dx1 = dx1mat.view(x.shape) if dx1mat is not None else None
dcolscale = rest[0] if colscale is not None else None
return (dx0, dx1, dgamma, dbeta, dcolscale, None, None, None, None, None, None, None,
None, None)
return (dx0, dx1, dgamma, dbeta if ctx.has_beta else None, dcolscale, None, None, None,
None, None, None, None, None, None, None)
def layer_norm(x, weight, bias, epsilon):
......@@ -212,7 +220,7 @@ def dropout_add_layer_norm(x0, x1, weight, bias, dropout_p, epsilon, rowscale=No
"""
return DropoutAddLayerNormFn.apply(
x0, x1, weight, bias, rowscale, layerscale, dropout_p, epsilon, residual_in_fp32, prenorm,
return_dropout_mask
False, return_dropout_mask
)
......@@ -225,7 +233,7 @@ def dropout_add_layer_norm_subset(x0, x1, weight, bias, dropout_p, epsilon, laye
"""
return DropoutAddLayerNormSubsetFn.apply(
x0, x1, weight, bias, layerscale, x0_subset, out_subset, dropout_p, epsilon,
rowscale_const, out_numrows, residual_in_fp32, prenorm, return_dropout_mask
rowscale_const, out_numrows, residual_in_fp32, prenorm, False, return_dropout_mask
)
......
# Copyright (c) 2022, Tri Dao.
# Adapted from https://github.com/NVIDIA/apex/blob/master/apex/contrib/layer_norm/layer_norm.py
import torch
from torch.nn import init
from flash_attn.ops.layer_norm import DropoutAddLayerNormFn, DropoutAddLayerNormSubsetFn
def rms_norm(x, weight, epsilon):
return DropoutAddLayerNormFn.apply(x, None, weight, None, None, None, 0.0, epsilon, False,
False, True)
def dropout_add_rms_norm(x0, x1, weight, bias, dropout_p, epsilon, rowscale=None, layerscale=None,
prenorm=False, residual_in_fp32=False, return_dropout_mask=False):
"""residual_in_fp32 only has an effect if x1 is None.
Otherwise residual dtype is x1.dtype.
"""
return DropoutAddLayerNormFn.apply(
x0, x1, weight, bias, rowscale, layerscale, dropout_p, epsilon, residual_in_fp32, prenorm,
True, return_dropout_mask
)
def dropout_add_rms_norm_subset(x0, x1, weight, bias, dropout_p, epsilon, layerscale=None,
x0_subset=None, out_subset=None, rowscale_const=1.0,
out_numrows=0, prenorm=False, residual_in_fp32=False,
return_dropout_mask=False):
"""residual_in_fp32 only has an effect if x1 is None.
Otherwise residual dtype is x1.dtype.
"""
return DropoutAddLayerNormSubsetFn.apply(
x0, x1, weight, bias, layerscale, x0_subset, out_subset, dropout_p, epsilon,
rowscale_const, out_numrows, residual_in_fp32, prenorm, True, return_dropout_mask
)
class DropoutAddRMSNorm(torch.nn.Module):
def __init__(self, hidden_size, prenorm=False, p=0.0, eps=1e-5, residual_in_fp32=False,
device=None, dtype=None):
factory_kwargs = {'device': device, 'dtype': dtype}
super().__init__()
self.prenorm = prenorm
self.p = p
self.epsilon = eps
self.residual_in_fp32 = residual_in_fp32
self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
self.register_parameter('bias', None)
self.reset_parameters()
def reset_parameters(self):
init.ones_(self.weight)
def forward(self, x0, x1=None):
return dropout_add_rms_norm(x0, x1, self.weight, None,
self.p if self.training else 0.0, self.epsilon,
prenorm=self.prenorm, residual_in_fp32=self.residual_in_fp32)
......@@ -8,11 +8,20 @@ from einops import rearrange, repeat
from flash_attn.ops.layer_norm import DropoutAddLayerNorm, dropout_add_layer_norm
from flash_attn.ops.layer_norm import dropout_add_layer_norm_subset
from flash_attn.ops.rms_norm import DropoutAddRMSNorm, dropout_add_rms_norm
from flash_attn.ops.rms_norm import dropout_add_rms_norm_subset
try:
from apex.normalization import FusedRMSNorm
except:
FusedRMSNorm = None
is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8
@pytest.mark.parametrize('is_rms_norm', [False, True])
@pytest.mark.parametrize('has_colscale', [True, False])
# @pytest.mark.parametrize('has_colscale', [False])
@pytest.mark.parametrize('has_rowscale', [True, False])
# @pytest.mark.parametrize('has_rowscale', [True])
@pytest.mark.parametrize('has_residual', [True, False])
......@@ -26,11 +35,17 @@ is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8
(torch.float32, torch.float32)]
+ ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []))
# @pytest.mark.parametrize('input_dtype,residual_dtype', [(torch.float16, torch.float32)])
@pytest.mark.parametrize('hidden_size', [192, 256, 384, 768, 1024, 1280, 1536, 1600, 2048, 2560, 3000, 3072, 4096, 5120, 6144])
# @pytest.mark.parametrize('hidden_size', [192, 256, 384, 768, 1024, 1280, 1536, 1600, 2048, 2560, 3000, 3072, 4096, 5120, 6144])
@pytest.mark.parametrize('hidden_size', [256])
def test_dropout_layer_norm_training(hidden_size, input_dtype, residual_dtype, weight_dtype,
dropout_p, has_residual, has_rowscale, has_colscale):
dropout_p, has_residual, has_rowscale, has_colscale, is_rms_norm):
if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:
pytest.skip() # Not supported
if is_rms_norm and FusedRMSNorm is None:
pytest.skip() # We need Apex's FusedRMSNorm to test
layer_norm_cls = torch.nn.LayerNorm if not is_rms_norm else FusedRMSNorm
our_layer_norm_cls = DropoutAddLayerNorm if not is_rms_norm else DropoutAddRMSNorm
our_layer_norm_func = dropout_add_layer_norm if not is_rms_norm else dropout_add_rms_norm
device = 'cuda'
# rtol, atol = (1e-5, 1e-6) if input_dtype == torch.float32 else (1e-3, 1e-4)
rtol, atol = (1e-3, 1e-4)
......@@ -67,20 +82,22 @@ def test_dropout_layer_norm_training(hidden_size, input_dtype, residual_dtype, w
if has_colscale:
x0_scaled_pt = x0_scaled_pt * colscale_pt
x0_scaled_ref = x0_scaled_ref * colscale_ref
model_pt = torch.nn.LayerNorm(hidden_size, device=device, dtype=weight_dtype)
model_pt = layer_norm_cls(hidden_size).to(device=device, dtype=weight_dtype)
torch.nn.init.normal_(model_pt.weight)
torch.nn.init.normal_(model_pt.bias)
model_ref = torch.nn.LayerNorm(hidden_size, device=device, dtype=torch.float32)
model = DropoutAddLayerNorm(hidden_size, p=dropout_p, device=device, dtype=weight_dtype)
if not is_rms_norm:
torch.nn.init.normal_(model_pt.bias)
model_ref = layer_norm_cls(hidden_size).to(device=device, dtype=torch.float32)
model = our_layer_norm_cls(hidden_size, p=dropout_p, device=device, dtype=weight_dtype)
with torch.no_grad():
model.weight.copy_(model_pt.weight)
model.bias.copy_(model_pt.bias)
model_ref.weight.copy_(model_pt.weight)
model_ref.bias.copy_(model_pt.bias)
if not is_rms_norm:
model.bias.copy_(model_pt.bias)
model_ref.bias.copy_(model_pt.bias)
residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
out, dmask = dropout_add_layer_norm(x0, x1, model.weight, model.bias, model.p,
model.epsilon, rowscale=rowscale, layerscale=colscale,
residual_in_fp32=residual_in_fp32, return_dropout_mask=True)
out, dmask = our_layer_norm_func(x0, x1, model.weight, model.bias, model.p,
model.epsilon, rowscale=rowscale, layerscale=colscale,
residual_in_fp32=residual_in_fp32, return_dropout_mask=True)
assert out.dtype == input_dtype
print(f'Actual dropout fraction: {1 - dmask.float().mean().item()}')
if has_residual:
......@@ -101,7 +118,8 @@ def test_dropout_layer_norm_training(hidden_size, input_dtype, residual_dtype, w
if has_residual:
assert (x1.grad - x1_ref.grad).abs().max() <= 4 * (x1_pt.grad - x1_ref.grad).abs().max() + 1e-4
assert (model.weight.grad - model_ref.weight.grad).abs().max() <= 2 * (model_pt.weight.grad - model_ref.weight.grad).abs().max() + 3e-5
assert (model.bias.grad - model_ref.bias.grad).abs().max() <= 2 * (model_pt.bias.grad - model_ref.bias.grad).abs().max() + 3e-5
if not is_rms_norm:
assert (model.bias.grad - model_ref.bias.grad).abs().max() <= 2 * (model_pt.bias.grad - model_ref.bias.grad).abs().max() + 3e-5
if has_colscale:
assert (colscale.grad - colscale_ref.grad).abs().max() <= 2 * (colscale_pt.grad - colscale_ref.grad).abs().max() + 2e-4
......@@ -151,27 +169,34 @@ def test_dropout_layer_norm_eval(hidden_size, input_dtype, residual_dtype, weigh
assert (out - out_ref).abs().max() <= 4 * (out_pt - out_ref).abs().max() + 1e-4
# @pytest.mark.parametrize('has_colscale', [True, False])
# @pytest.mark.parametrize('has_rowscale', [True, False])
# @pytest.mark.parametrize('has_residual', [True, False])
# @pytest.mark.parametrize('dropout_p', [0.37, 0.0])
# @pytest.mark.parametrize('weight_dtype', [torch.float32, torch.float16])
# @pytest.mark.parametrize('input_dtype,residual_dtype',
# [(torch.float16, torch.float16), (torch.float16, torch.float32),
# (torch.float32, torch.float32)]
# + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []))
@pytest.mark.parametrize('has_colscale', [True])
@pytest.mark.parametrize('has_rowscale', [False])
@pytest.mark.parametrize('has_residual', [True])
@pytest.mark.parametrize('dropout_p', [0.0])
@pytest.mark.parametrize('weight_dtype', [torch.float32])
@pytest.mark.parametrize('input_dtype,residual_dtype', [(torch.float32, torch.float32)])
# @pytest.mark.parametrize('hidden_size', [192, 256, 384, 768, 1024, 1280, 1536, 1600, 2048, 2560, 3000, 3072, 4096, 5120, 6144])
@pytest.mark.parametrize('hidden_size', [256])
@pytest.mark.parametrize('is_rms_norm', [False, True])
@pytest.mark.parametrize('has_colscale', [True, False])
@pytest.mark.parametrize('has_rowscale', [True, False])
@pytest.mark.parametrize('has_residual', [True, False])
@pytest.mark.parametrize('dropout_p', [0.37, 0.0])
@pytest.mark.parametrize('weight_dtype', [torch.float32, torch.float16])
@pytest.mark.parametrize('input_dtype,residual_dtype',
[(torch.float16, torch.float16), (torch.float16, torch.float32),
(torch.float32, torch.float32)]
+ ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []))
# @pytest.mark.parametrize('has_colscale', [True])
# @pytest.mark.parametrize('has_rowscale', [False])
# @pytest.mark.parametrize('has_residual', [True])
# @pytest.mark.parametrize('dropout_p', [0.0])
# @pytest.mark.parametrize('weight_dtype', [torch.float32])
# @pytest.mark.parametrize('input_dtype,residual_dtype', [(torch.float32, torch.float32)])
@pytest.mark.parametrize('hidden_size', [192, 256, 384, 768, 1024, 1280, 1536, 1600, 2048, 2560, 3000, 3072, 4096, 5120, 6144])
# @pytest.mark.parametrize('hidden_size', [256])
def test_dropout_layer_norm_prenorm_training(hidden_size, input_dtype, residual_dtype, weight_dtype,
dropout_p, has_residual, has_rowscale, has_colscale):
dropout_p, has_residual, has_rowscale, has_colscale,
is_rms_norm):
if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:
pytest.skip() # Not supported
if is_rms_norm and FusedRMSNorm is None:
pytest.skip() # We need Apex's FusedRMSNorm to test
layer_norm_cls = torch.nn.LayerNorm if not is_rms_norm else FusedRMSNorm
our_layer_norm_cls = DropoutAddLayerNorm if not is_rms_norm else DropoutAddRMSNorm
our_layer_norm_func = dropout_add_layer_norm if not is_rms_norm else dropout_add_rms_norm
device = 'cuda'
# rtol, atol = (1e-5, 1e-6) if input_dtype == torch.float32 else (1e-3, 1e-4)
rtol, atol = (1e-3, 2e-4)
......@@ -208,23 +233,25 @@ def test_dropout_layer_norm_prenorm_training(hidden_size, input_dtype, residual_
if has_colscale:
x0_scaled_pt = x0_scaled_pt * colscale_pt
x0_scaled_ref = x0_scaled_ref * colscale_ref
model_pt = torch.nn.LayerNorm(hidden_size, device=device, dtype=weight_dtype)
model_pt = layer_norm_cls(hidden_size).to(device=device, dtype=weight_dtype)
torch.nn.init.normal_(model_pt.weight)
torch.nn.init.normal_(model_pt.bias)
model_ref = torch.nn.LayerNorm(hidden_size, device=device, dtype=torch.float32)
model = DropoutAddLayerNorm(hidden_size, prenorm=True, p=dropout_p, device=device,
dtype=weight_dtype)
if not is_rms_norm:
torch.nn.init.normal_(model_pt.bias)
model_ref = layer_norm_cls(hidden_size).to(device=device, dtype=torch.float32)
model = our_layer_norm_cls(hidden_size, prenorm=True, p=dropout_p, device=device,
dtype=weight_dtype)
with torch.no_grad():
model.weight.copy_(model_pt.weight)
model.bias.copy_(model_pt.bias)
model_ref.weight.copy_(model_pt.weight)
model_ref.bias.copy_(model_pt.bias)
if not is_rms_norm:
model.bias.copy_(model_pt.bias)
model_ref.bias.copy_(model_pt.bias)
residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
out, residual, dmask = dropout_add_layer_norm(x0, x1, model.weight, model.bias, model.p,
model.epsilon, rowscale=rowscale,
layerscale=colscale, prenorm=True,
residual_in_fp32=residual_in_fp32,
return_dropout_mask=True)
out, residual, dmask = our_layer_norm_func(x0, x1, model.weight, model.bias, model.p,
model.epsilon, rowscale=rowscale,
layerscale=colscale, prenorm=True,
residual_in_fp32=residual_in_fp32,
return_dropout_mask=True)
print(f'Actual dropout fraction: {1 - dmask.float().mean().item()}')
if has_residual:
residual_pt = ((x0_scaled_pt.float() * dmask.float()) / (1 - dropout_p) + x1_pt.float()).to(dtype=residual_dtype)
......@@ -247,7 +274,8 @@ def test_dropout_layer_norm_prenorm_training(hidden_size, input_dtype, residual_
if has_residual:
assert (x1.grad - x1_ref.grad).abs().max() <= 4 * (x1_pt.grad - x1_ref.grad).abs().max() + 1e-4
assert (model.weight.grad - model_ref.weight.grad).abs().max() <= 2 * (model_pt.weight.grad - model_ref.weight.grad).abs().max() + 2e-4
assert (model.bias.grad - model_ref.bias.grad).abs().max() <= 2 * (model_pt.bias.grad - model_ref.bias.grad).abs().max() + 2e-4
if not is_rms_norm:
assert (model.bias.grad - model_ref.bias.grad).abs().max() <= 2 * (model_pt.bias.grad - model_ref.bias.grad).abs().max() + 2e-4
if has_colscale:
assert (colscale.grad - colscale_ref.grad).abs().max() <= 2 * (colscale_pt.grad - colscale_ref.grad).abs().max() + 2e-4
......