Commit aa5c6438 authored by Tri Dao

[LayerNorm] Implement rowscale in Triton layernorm

parent 386e3911
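For context: `rowscale` is a new optional argument that multiplies the input by one scalar per row (token) before dropout and the residual add. The first file below is flash_attn/ops/triton/layernorm.py (the path the updated test imports from). A minimal usage sketch against the new API, with illustrative shapes and dtypes mirroring the test at the end of this diff (and assuming the default non-prenorm call returns just the normalized output):

import torch
from flash_attn.ops.triton.layernorm import layer_norm_fn

batch, seqlen, hidden = 2, 128, 512
x = torch.randn(batch, seqlen, hidden, device="cuda", dtype=torch.float16)
weight = torch.ones(hidden, device="cuda", dtype=torch.float16)
# One scalar per row; layer_norm_fn flattens this to shape (batch * seqlen,).
rowscale = torch.randn(batch, seqlen, device="cuda", dtype=torch.float16)

out = layer_norm_fn(x, weight, None, eps=1e-6, rowscale=rowscale)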
# Copyright (c) 2023, Tri Dao.
# Copyright (c) 2024, Tri Dao.
# Implement dropout + residual + layer_norm / rms_norm.
# Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
@@ -23,6 +23,7 @@ def layer_norm_ref(
residual=None,
eps=1e-6,
dropout_p=0.0,
rowscale=None,
prenorm=False,
dropout_mask=None,
upcast=False,
@@ -34,6 +35,8 @@ def layer_norm_ref(
if upcast:
x = x.float()
residual = residual.float() if residual is not None else residual
if rowscale is not None:
x = x * rowscale[..., None]
if dropout_p > 0.0:
if dropout_mask is not None:
x = x.masked_fill(~dropout_mask, 0.0) / (1.0 - dropout_p)
@@ -54,6 +57,7 @@ def rms_norm_ref(
residual=None,
eps=1e-6,
dropout_p=0.0,
rowscale=None,
prenorm=False,
dropout_mask=None,
upcast=False,
@@ -65,6 +69,8 @@ def rms_norm_ref(
if upcast:
x = x.float()
residual = residual.float() if residual is not None else residual
if rowscale is not None:
x = x * rowscale[..., None]
if dropout_p > 0.0:
if dropout_mask is not None:
x = x.masked_fill(~dropout_mask, 0.0) / (1.0 - dropout_p)
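Both reference implementations treat `rowscale` identically: scale first, then dropout, then the residual add, then the norm. A consolidated standalone sketch of that order of operations (hypothetical helper, not part of this commit):

import torch
import torch.nn.functional as F

def rowscale_norm_sketch(x, weight, bias=None, rowscale=None,
                         residual=None, dropout_p=0.0, eps=1e-6):
    # Mirrors layer_norm_ref's op order: rowscale -> dropout -> residual -> norm.
    dtype = x.dtype
    if rowscale is not None:
        x = x * rowscale[..., None]  # one scalar per row, broadcast over features
    if dropout_p > 0.0:
        x = F.dropout(x, p=dropout_p)
    if residual is not None:
        x = x + residual
    out = F.layer_norm(x.float(), (x.shape[-1],), weight.float(),
                       bias.float() if bias is not None else None, eps)
    return out.to(dtype)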
@@ -99,6 +105,7 @@ def _layer_norm_fwd_1pass_kernel(
B, # pointer to the biases
RESIDUAL, # pointer to the residual
RESIDUAL_OUT, # pointer to the residual output
ROWSCALE, # pointer to the per-row scale factors
SEEDS, # Dropout seeds for each row
DROPOUT_MASK,
Mean, # pointer to the mean
@@ -117,6 +124,7 @@ def _layer_norm_fwd_1pass_kernel(
HAS_BIAS: tl.constexpr,
HAS_DROPOUT: tl.constexpr,
STORE_DROPOUT_MASK: tl.constexpr,
HAS_ROWSCALE: tl.constexpr,
):
# Map the program id to the row of X and Y it should compute.
row = tl.program_id(0)
@@ -129,6 +137,9 @@ def _layer_norm_fwd_1pass_kernel(
# Compute mean and variance
cols = tl.arange(0, BLOCK_N)
x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
if HAS_ROWSCALE:
rowscale = tl.load(ROWSCALE + row).to(tl.float32)
x *= rowscale
if HAS_DROPOUT:
# Compute dropout mask
# 7 rounds is good enough, and reduces register pressure
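In the forward kernel the cost is one extra scalar load per program: each program owns a row, reads its factor from ROWSCALE + row, and broadcasts it across the row before the dropout mask is drawn, matching the reference order. A stripped-down Triton sketch of just that load/broadcast pattern (a standalone toy kernel over a contiguous (M, N) input, not the fused kernel above):

import triton
import triton.language as tl

@triton.jit
def _rowscale_only_kernel(X, ROWSCALE, Y, N, BLOCK_N: tl.constexpr):
    row = tl.program_id(0)  # one program per row
    cols = tl.arange(0, BLOCK_N)
    mask = cols < N
    x = tl.load(X + row * N + cols, mask=mask, other=0.0).to(tl.float32)
    scale = tl.load(ROWSCALE + row).to(tl.float32)  # single fp32 scalar for this row
    tl.store(Y + row * N + cols, x * scale, mask=mask)

# launch (illustrative): _rowscale_only_kernel[(M,)](x, rowscale, y, N, BLOCK_N=triton.next_power_of_2(N))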
@@ -169,6 +180,7 @@ def _layer_norm_fwd(
eps,
residual=None,
dropout_p=0.0,
rowscale=None,
out_dtype=None,
residual_dtype=None,
is_rms_norm=False,
@@ -186,6 +198,9 @@ def _layer_norm_fwd(
if bias is not None:
assert bias.stride(-1) == 1
assert bias.shape == (N,)
if rowscale is not None:
assert rowscale.is_contiguous()
assert rowscale.shape == (M,)
# allocate output
y = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)
assert y.stride(-1) == 1
@@ -193,6 +208,7 @@ def _layer_norm_fwd(
residual is not None
or (residual_dtype is not None and residual_dtype != x.dtype)
or dropout_p > 0.0
or rowscale is not None
):
residual_out = torch.empty(
M, N, device=x.device, dtype=residual_dtype if residual_dtype is not None else x.dtype
@@ -224,6 +240,7 @@ def _layer_norm_fwd(
bias,
residual,
residual_out,
rowscale,
seeds,
dropout_mask,
mean,
@@ -242,6 +259,7 @@ def _layer_norm_fwd(
bias is not None,
dropout_p > 0.0,
dropout_mask is not None,
rowscale is not None,
)
# residual_out is None if residual is None and residual_dtype == input_dtype and dropout_p == 0.0 and rowscale is None
return y, mean, rstd, residual_out if residual_out is not None else x, seeds, dropout_mask
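Note the allocation change above: residual_out is now materialized whenever rowscale is given, even without a residual, because the prenorm output and the backward recompute path need the post-scale input rather than the original x. A quick check of that invariant, continuing the usage sketch from the top (no dropout or residual, so the prenorm second output should equal the scaled input; assumes prenorm=True returns an (out, residual_out) pair):

out, x_scaled = layer_norm_fn(x, weight, None, eps=1e-6,
                              rowscale=rowscale, prenorm=True)
torch.testing.assert_close(x_scaled.float(), (x * rowscale[..., None]).float(),
                           rtol=1e-2, atol=1e-2)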
@@ -261,6 +279,7 @@ def _layer_norm_fwd(
# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
# @triton.heuristics({"HAS_DRESIDUAL": lambda args: args["DRESIDUAL"] is not None})
# @triton.heuristics({"STORE_DRESIDUAL": lambda args: args["DRESIDUAL_IN"] is not None})
@triton.heuristics({"HAS_ROWSCALE": lambda args: args["ROWSCALE"] is not None})
@triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None})
@triton.jit
def _layer_norm_bwd_kernel(
@@ -274,6 +293,7 @@ def _layer_norm_bwd_kernel(
DB, # pointer to the partial sum of biases gradient
DRESIDUAL,
DRESIDUAL_IN,
ROWSCALE, # pointer to the per-row scale factors
SEEDS,
Mean, # pointer to the mean
Rstd, # pointer to the 1/std
@@ -294,11 +314,14 @@ def _layer_norm_bwd_kernel(
STORE_DRESIDUAL: tl.constexpr,
HAS_BIAS: tl.constexpr,
HAS_DROPOUT: tl.constexpr,
HAS_ROWSCALE: tl.constexpr,
RECOMPUTE_OUTPUT: tl.constexpr,
):
# Map the program id to the elements of X, DX, and DY it should compute.
row_block_id = tl.program_id(0)
row_start = row_block_id * rows_per_program
if row_start >= M:
return
cols = tl.arange(0, BLOCK_N)
mask = cols < N
X += row_start * stride_x_row
@@ -350,6 +373,9 @@ def _layer_norm_bwd_kernel(
if HAS_DROPOUT:
keep_mask = tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
dx = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0)
if HAS_ROWSCALE:
rowscale = tl.load(ROWSCALE + row).to(tl.float32)
dx *= rowscale
tl.store(DX + cols, dx, mask=mask)
X += stride_x_row
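The backward change is just the chain rule for the per-row scale: the forward computed rowscale * x, so dx picks up the same factor, applied after the dropout scaling is undone (the reverse of the forward order). A small fp64 autograd sanity check of that identity, in plain PyTorch and independent of the kernel:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(4, 16, dtype=torch.float64, requires_grad=True)
r = torch.randn(4, dtype=torch.float64)
w = torch.randn(16, dtype=torch.float64)

F.layer_norm(x * r[:, None], (16,), w).sum().backward()

# Same computation, but differentiate w.r.t. the already-scaled input.
x_scaled = (x.detach() * r[:, None]).requires_grad_()
F.layer_norm(x_scaled, (16,), w).sum().backward()

assert torch.allclose(x.grad, r[:, None] * x_scaled.grad)  # dx = rowscale * d(scaled input)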
@@ -377,6 +403,7 @@ def _layer_norm_bwd(
dresidual=None,
seeds=None,
dropout_p=0.0,
rowscale=None,
has_residual=False,
is_rms_norm=False,
x_dtype=None,
@@ -397,6 +424,9 @@ def _layer_norm_bwd(
if seeds is not None:
assert seeds.is_contiguous()
assert seeds.shape == (M,)
if rowscale is not None:
assert rowscale.is_contiguous()
assert rowscale.shape == (M,)
# allocate output
dx = (
torch.empty_like(x)
@@ -404,7 +434,9 @@ def _layer_norm_bwd(
else torch.empty(M, N, dtype=x_dtype, device=x.device)
)
dresidual_in = (
torch.empty_like(x) if has_residual and (dx.dtype != x.dtype or dropout_p > 0.0) else None
torch.empty_like(x)
if has_residual and (dx.dtype != x.dtype or dropout_p > 0.0 or rowscale is not None)
else None
)
y = torch.empty(M, N, dtype=dy.dtype, device=dy.device) if recompute_output else None
@@ -434,6 +466,7 @@ def _layer_norm_bwd(
_db,
dresidual,
dresidual_in,
rowscale,
seeds,
mean,
rstd,
@@ -458,7 +491,7 @@ def _layer_norm_bwd(
dw = _dw.sum(0).to(weight.dtype)
db = _db.sum(0).to(bias.dtype) if bias is not None else None
# Don't need to compute dresidual_in separately in this case
if has_residual and dx.dtype == x.dtype and dropout_p == 0.0:
if has_residual and dx.dtype == x.dtype and dropout_p == 0.0 and rowscale is None:
dresidual_in = dx
return (dx, dw, db, dresidual_in) if not recompute_output else (dx, dw, db, dresidual_in, y)
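The aliasing change above follows from the same chain rule: with y = norm(rowscale * x + residual), the residual receives the raw upstream gradient while dx carries the extra rowscale factor, so dx can no longer double as dresidual_in and a separate buffer is needed. Illustrated in plain PyTorch:

import torch
import torch.nn.functional as F

x = torch.randn(3, 8, dtype=torch.float64, requires_grad=True)
res = torch.randn(3, 8, dtype=torch.float64, requires_grad=True)
r = torch.randn(3, dtype=torch.float64)

F.layer_norm(x * r[:, None] + res, (8,)).sum().backward()

# The two gradients differ by the per-row factor, hence the separate buffer.
assert torch.allclose(x.grad, r[:, None] * res.grad)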
@@ -473,6 +506,7 @@ class LayerNormFn(torch.autograd.Function):
residual=None,
eps=1e-6,
dropout_p=0.0,
rowscale=None,
prenorm=False,
residual_in_fp32=False,
is_rms_norm=False,
@@ -491,6 +525,8 @@ class LayerNormFn(torch.autograd.Function):
weight = weight.contiguous()
if bias is not None:
bias = bias.contiguous()
if rowscale is not None:
rowscale = rowscale.reshape(-1).contiguous()
residual_dtype = (
residual.dtype
if residual is not None
@@ -503,11 +539,12 @@ class LayerNormFn(torch.autograd.Function):
eps,
residual,
dropout_p=dropout_p,
rowscale=rowscale,
residual_dtype=residual_dtype,
is_rms_norm=is_rms_norm,
return_dropout_mask=return_dropout_mask,
)
ctx.save_for_backward(residual_out, weight, bias, seeds, mean, rstd)
ctx.save_for_backward(residual_out, weight, bias, rowscale, seeds, mean, rstd)
ctx.x_shape_og = x_shape_og
ctx.eps = eps
ctx.dropout_p = dropout_p
@@ -525,7 +562,7 @@ class LayerNormFn(torch.autograd.Function):
@staticmethod
def backward(ctx, dy, *args):
x, weight, bias, seeds, mean, rstd = ctx.saved_tensors
x, weight, bias, rowscale, seeds, mean, rstd = ctx.saved_tensors
dy = dy.reshape(-1, dy.shape[-1])
if dy.stride(-1) != 1:
dy = dy.contiguous()
@@ -549,6 +586,7 @@ class LayerNormFn(torch.autograd.Function):
dresidual,
seeds,
ctx.dropout_p,
rowscale,
ctx.has_residual,
ctx.is_rms_norm,
x_dtype=ctx.x_dtype,
@@ -564,6 +602,7 @@ class LayerNormFn(torch.autograd.Function):
None,
None,
None,
None,
)
@@ -574,6 +613,7 @@ def layer_norm_fn(
residual=None,
eps=1e-6,
dropout_p=0.0,
rowscale=None,
prenorm=False,
residual_in_fp32=False,
is_rms_norm=False,
@@ -586,6 +626,7 @@ def layer_norm_fn(
residual,
eps,
dropout_p,
rowscale,
prenorm,
residual_in_fp32,
is_rms_norm,
@@ -600,6 +641,7 @@ def rms_norm_fn(
residual=None,
eps=1e-6,
dropout_p=0.0,
rowscale=None,
prenorm=False,
residual_in_fp32=False,
return_dropout_mask=False,
@@ -611,6 +653,7 @@ def rms_norm_fn(
residual,
eps,
dropout_p,
rowscale,
prenorm,
residual_in_fp32,
True,
......
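The rms_norm_fn wrapper forwards rowscale the same way (it calls LayerNormFn with is_rms_norm=True), so usage mirrors the earlier sketch:

from flash_attn.ops.triton.layernorm import rms_norm_fn

out = rms_norm_fn(x, weight, None, eps=1e-6, rowscale=rowscale)

The second file in the commit extends the layer norm test to cover the new argument: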
# Copyright (c) 2023, Tri Dao.
# Copyright (c) 2024, Tri Dao.
import pytest
import torch
@@ -16,14 +16,16 @@ from flash_attn.ops.triton.layernorm import (
is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8
@pytest.mark.parametrize("has_rowscale", [False, True])
# @pytest.mark.parametrize("has_rowscale", [True])
@pytest.mark.parametrize("dropout_p", [0.0, 0.27])
# @pytest.mark.parametrize("dropout_p", [0.27])
# @pytest.mark.parametrize("dropout_p", [0.0])
@pytest.mark.parametrize("prenorm", [True, False])
# @pytest.mark.parametrize("prenorm", [False])
# @pytest.mark.parametrize("prenorm", [True])
@pytest.mark.parametrize("is_rms_norm", [False, True])
# @pytest.mark.parametrize("is_rms_norm", [True])
@pytest.mark.parametrize("has_residual", [True, False])
# @pytest.mark.parametrize("has_residual", [True])
# @pytest.mark.parametrize("has_residual", [False])
@pytest.mark.parametrize(
"weight_dtype", [torch.float32, torch.float16] + ([torch.bfloat16] if is_sm8x else [])
)
@@ -45,6 +47,7 @@ def test_layer_norm(
is_rms_norm,
prenorm,
dropout_p,
has_rowscale,
):
device = "cuda"
if any(x == torch.bfloat16 for x in [input_dtype, residual_dtype, weight_dtype]):
@@ -60,7 +63,8 @@ def test_layer_norm(
layer_norm_ref_fn = layer_norm_ref if not is_rms_norm else rms_norm_ref
allclose = (
lambda x, x_pt, x_ref, atol=atol: (x - x_ref).abs().max()
<= 2 * (x_pt - x_ref).abs().max() + atol
# Sometimes x0_pt.grad is NaN
<= 2 * (x_pt[~x_pt.isnan()] - x_ref[~x_pt.isnan()]).abs().max() + atol
)
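Since this hunk replaces the middle line of the lambda, the stripped diff above reads like a chained comparison; with the removed line dropped, the helper after the change is:

allclose = (
    lambda x, x_pt, x_ref, atol=atol: (x - x_ref).abs().max()
    # Sometimes x0_pt.grad is NaN, so those entries are masked out of the baseline
    <= 2 * (x_pt[~x_pt.isnan()] - x_ref[~x_pt.isnan()]).abs().max() + atol
)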
x0 = torch.randn(
batch_size, seqlen, hidden_size, device=device, dtype=input_dtype, requires_grad=True
@@ -83,6 +87,8 @@ def test_layer_norm(
bias_pt = bias.detach().clone().requires_grad_() if bias is not None else None
bias_ref = bias.detach().clone().requires_grad_() if bias is not None else None
rowscale = torch.randn(batch_size, seqlen, dtype=input_dtype, device=device) if has_rowscale else None
residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
out, *rest = layer_norm_fn(
x0,
@@ -91,6 +97,7 @@ def test_layer_norm(
residual=res,
eps=1e-6,
dropout_p=dropout_p,
rowscale=rowscale,
prenorm=prenorm,
residual_in_fp32=residual_in_fp32,
is_rms_norm=is_rms_norm,
@@ -104,6 +111,7 @@ def test_layer_norm(
residual=res_pt,
eps=1e-6,
dropout_p=dropout_p,
rowscale=rowscale,
prenorm=prenorm,
dropout_mask=dropout_mask,
)
@@ -114,6 +122,7 @@ def test_layer_norm(
residual=res_ref,
eps=1e-6,
dropout_p=dropout_p,
rowscale=rowscale,
prenorm=prenorm,
dropout_mask=dropout_mask,
upcast=True,
......