Commit aa5c6438 authored by Tri Dao

[LayerNorm] Implement rowscale in Triton layernorm

parent 386e3911
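
In short: the fused dropout + residual + layer_norm / rms_norm path gains an optional per-row scale. Each row of the (flattened) input is multiplied by one scalar before dropout, the residual add, and normalization. A minimal usage sketch based on the new layer_norm_fn signature below (shapes and values are illustrative, not part of the commit):

import torch
from flash_attn.ops.triton.layernorm import layer_norm_fn

batch, seqlen, hidden = 2, 128, 1024
x = torch.randn(batch, seqlen, hidden, device="cuda", dtype=torch.float16)
weight = torch.ones(hidden, device="cuda", dtype=torch.float16)
bias = torch.zeros(hidden, device="cuda", dtype=torch.float16)
rowscale = torch.rand(batch, seqlen, device="cuda", dtype=torch.float16)  # one scalar per token

# rowscale is flattened internally to one scalar per row of the (batch * seqlen, hidden) input
out = layer_norm_fn(x, weight, bias, rowscale=rowscale, eps=1e-6, dropout_p=0.0)
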
-# Copyright (c) 2023, Tri Dao.
+# Copyright (c) 2024, Tri Dao.
 # Implement dropout + residual + layer_norm / rms_norm.
 # Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
@@ -23,6 +23,7 @@ def layer_norm_ref(
     residual=None,
     eps=1e-6,
     dropout_p=0.0,
+    rowscale=None,
     prenorm=False,
     dropout_mask=None,
     upcast=False,
@@ -34,6 +35,8 @@ def layer_norm_ref(
     if upcast:
         x = x.float()
         residual = residual.float() if residual is not None else residual
+    if rowscale is not None:
+        x = x * rowscale[..., None]
     if dropout_p > 0.0:
         if dropout_mask is not None:
             x = x.masked_fill(~dropout_mask, 0.0) / (1.0 - dropout_p)
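
With dropout disabled, the rowscale handling added to the reference implementation is just a per-row multiply applied before the residual add and the normalization; the same result can be written in plain PyTorch as a quick sanity sketch (names and shapes are illustrative):

import torch
import torch.nn.functional as F

x = torch.randn(4, 8, 16)        # (batch, seqlen, hidden)
rowscale = torch.randn(4, 8)     # one scalar per row / token
residual = torch.randn(4, 8, 16)
weight, bias = torch.ones(16), torch.zeros(16)

scaled = x * rowscale[..., None] + residual             # scale first, then add the residual
ref = F.layer_norm(scaled, (16,), weight, bias, eps=1e-6)
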
@@ -54,6 +57,7 @@ def rms_norm_ref(
     residual=None,
     eps=1e-6,
     dropout_p=0.0,
+    rowscale=None,
     prenorm=False,
     dropout_mask=None,
     upcast=False,
@@ -65,6 +69,8 @@ def rms_norm_ref(
     if upcast:
         x = x.float()
         residual = residual.float() if residual is not None else residual
+    if rowscale is not None:
+        x = x * rowscale[..., None]
     if dropout_p > 0.0:
         if dropout_mask is not None:
             x = x.masked_fill(~dropout_mask, 0.0) / (1.0 - dropout_p)
@@ -99,6 +105,7 @@ def _layer_norm_fwd_1pass_kernel(
     B,  # pointer to the biases
     RESIDUAL,  # pointer to the residual
     RESIDUAL_OUT,  # pointer to the residual
+    ROWSCALE,
     SEEDS,  # Dropout seeds for each row
     DROPOUT_MASK,
     Mean,  # pointer to the mean
@@ -117,6 +124,7 @@ def _layer_norm_fwd_1pass_kernel(
     HAS_BIAS: tl.constexpr,
     HAS_DROPOUT: tl.constexpr,
     STORE_DROPOUT_MASK: tl.constexpr,
+    HAS_ROWSCALE: tl.constexpr,
 ):
     # Map the program id to the row of X and Y it should compute.
     row = tl.program_id(0)
@@ -129,6 +137,9 @@ def _layer_norm_fwd_1pass_kernel(
     # Compute mean and variance
     cols = tl.arange(0, BLOCK_N)
     x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
+    if HAS_ROWSCALE:
+        rowscale = tl.load(ROWSCALE + row).to(tl.float32)
+        x *= rowscale
     if HAS_DROPOUT:
         # Compute dropout mask
         # 7 rounds is good enough, and reduces register pressure
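
In the one-pass forward kernel each program owns a single row, so the scale is one scalar load (ROWSCALE + row) broadcast across the row, applied in fp32 before dropout, the residual add, and the mean/variance computation. A per-row sketch of that ordering in plain PyTorch (illustrative helper, not the kernel):

import torch

def fwd_row(x_r, rowscale_r=None, keep_mask_r=None, dropout_p=0.0, residual_r=None):
    # Mirrors the kernel's order of operations for one row.
    x_r = x_r.float()
    if rowscale_r is not None:
        x_r = x_r * float(rowscale_r)   # one scalar per row
    if dropout_p > 0.0:
        x_r = torch.where(keep_mask_r, x_r / (1.0 - dropout_p), torch.zeros_like(x_r))
    if residual_r is not None:
        x_r = x_r + residual_r.float()
    return x_r  # mean/rstd and the affine transform are computed from this
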
@@ -169,6 +180,7 @@ def _layer_norm_fwd(
     eps,
     residual=None,
     dropout_p=0.0,
+    rowscale=None,
     out_dtype=None,
     residual_dtype=None,
     is_rms_norm=False,
@@ -186,6 +198,9 @@ def _layer_norm_fwd(
     if bias is not None:
         assert bias.stride(-1) == 1
         assert bias.shape == (N,)
+    if rowscale is not None:
+        assert rowscale.is_contiguous()
+        assert rowscale.shape == (M,)
     # allocate output
     y = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)
     assert y.stride(-1) == 1
@@ -193,6 +208,7 @@ def _layer_norm_fwd(
         residual is not None
         or (residual_dtype is not None and residual_dtype != x.dtype)
         or dropout_p > 0.0
+        or rowscale is not None
     ):
         residual_out = torch.empty(
             M, N, device=x.device, dtype=residual_dtype if residual_dtype is not None else x.dtype
@@ -224,6 +240,7 @@ def _layer_norm_fwd(
             bias,
             residual,
             residual_out,
+            rowscale,
             seeds,
             dropout_mask,
             mean,
@@ -242,6 +259,7 @@ def _layer_norm_fwd(
             bias is not None,
             dropout_p > 0.0,
             dropout_mask is not None,
+            rowscale is not None,
         )
     # residual_out is None if residual is None and residual_dtype == input_dtype and dropout_p == 0.0
     return y, mean, rstd, residual_out if residual_out is not None else x, seeds, dropout_mask
@@ -261,6 +279,7 @@ def _layer_norm_fwd(
 # @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
 # @triton.heuristics({"HAS_DRESIDUAL": lambda args: args["DRESIDUAL"] is not None})
 # @triton.heuristics({"STORE_DRESIDUAL": lambda args: args["DRESIDUAL_IN"] is not None})
+@triton.heuristics({"HAS_ROWSCALE": lambda args: args["ROWSCALE"] is not None})
 @triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None})
 @triton.jit
 def _layer_norm_bwd_kernel(
@@ -274,6 +293,7 @@ def _layer_norm_bwd_kernel(
     DB,  # pointer to the partial sum of biases gradient
     DRESIDUAL,
     DRESIDUAL_IN,
+    ROWSCALE,
     SEEDS,
     Mean,  # pointer to the mean
     Rstd,  # pointer to the 1/std
@@ -294,11 +314,14 @@ def _layer_norm_bwd_kernel(
     STORE_DRESIDUAL: tl.constexpr,
     HAS_BIAS: tl.constexpr,
     HAS_DROPOUT: tl.constexpr,
+    HAS_ROWSCALE: tl.constexpr,
     RECOMPUTE_OUTPUT: tl.constexpr,
 ):
     # Map the program id to the elements of X, DX, and DY it should compute.
     row_block_id = tl.program_id(0)
     row_start = row_block_id * rows_per_program
+    if row_start >= M:
+        return
     cols = tl.arange(0, BLOCK_N)
     mask = cols < N
     X += row_start * stride_x_row
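
The new early return protects programs whose starting row lies past the end of the batch. Assuming the backward grid is sized independently of M (e.g. by SM count rather than by row count), rows_per_program = ceil(M / num_programs), so trailing programs may start at row_start >= M and must exit before touching any pointers. With made-up numbers:

import math

M, num_programs = 10, 16                        # hypothetical: fewer rows than programs
rows_per_program = math.ceil(M / num_programs)  # = 1
starts = [pid * rows_per_program for pid in range(num_programs)]
# programs with starts[pid] >= M (here pids 10..15) have no rows and return immediately
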
@@ -350,6 +373,9 @@ def _layer_norm_bwd_kernel(
         if HAS_DROPOUT:
             keep_mask = tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
             dx = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0)
+        if HAS_ROWSCALE:
+            rowscale = tl.load(ROWSCALE + row).to(tl.float32)
+            dx *= rowscale
         tl.store(DX + cols, dx, mask=mask)
         X += stride_x_row
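
The backward mirrors the forward ordering: the forward computes norm(dropout(rowscale * x) + residual), so by the chain rule the gradient with respect to x picks up the same per-row factor after the dropout mask has been applied, which is exactly the dx *= rowscale above. A per-row sketch (illustrative helper):

import torch

def bwd_row(dz_r, rowscale_r=None, keep_mask_r=None, dropout_p=0.0):
    # dz_r is the gradient w.r.t. the normalization input for this row.
    dx_r = dz_r
    if dropout_p > 0.0:
        dx_r = torch.where(keep_mask_r, dx_r / (1.0 - dropout_p), torch.zeros_like(dx_r))
    if rowscale_r is not None:
        dx_r = dx_r * float(rowscale_r)   # chain rule for x -> rowscale * x
    return dx_r
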
@@ -377,6 +403,7 @@ def _layer_norm_bwd(
     dresidual=None,
     seeds=None,
     dropout_p=0.0,
+    rowscale=None,
     has_residual=False,
     is_rms_norm=False,
     x_dtype=None,
@@ -397,6 +424,9 @@ def _layer_norm_bwd(
     if seeds is not None:
         assert seeds.is_contiguous()
         assert seeds.shape == (M,)
+    if rowscale is not None:
+        assert rowscale.is_contiguous()
+        assert rowscale.shape == (M,)
     # allocate output
     dx = (
         torch.empty_like(x)
@@ -404,7 +434,9 @@ def _layer_norm_bwd(
         else torch.empty(M, N, dtype=x_dtype, device=x.device)
     )
     dresidual_in = (
-        torch.empty_like(x) if has_residual and (dx.dtype != x.dtype or dropout_p > 0.0) else None
+        torch.empty_like(x)
+        if has_residual and (dx.dtype != x.dtype or dropout_p > 0.0 or rowscale is not None)
+        else None
     )
     y = torch.empty(M, N, dtype=dy.dtype, device=dy.device) if recompute_output else None
@@ -434,6 +466,7 @@ def _layer_norm_bwd(
             _db,
             dresidual,
             dresidual_in,
+            rowscale,
             seeds,
             mean,
             rstd,
@@ -458,7 +491,7 @@ def _layer_norm_bwd(
     dw = _dw.sum(0).to(weight.dtype)
     db = _db.sum(0).to(bias.dtype) if bias is not None else None
     # Don't need to compute dresidual_in separately in this case
-    if has_residual and dx.dtype == x.dtype and dropout_p == 0.0:
+    if has_residual and dx.dtype == x.dtype and dropout_p == 0.0 and rowscale is None:
         dresidual_in = dx
     return (dx, dw, db, dresidual_in) if not recompute_output else (dx, dw, db, dresidual_in, y)
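
The condition change above is why a separate dresidual_in buffer is allocated whenever rowscale is set: the gradient flowing into the residual branch is the plain dz, while the gradient into x additionally carries the per-row factor (and the dropout mask), so the two can no longer share storage. Schematically (illustrative shapes):

import torch

dz = torch.randn(4, 16)            # gradient of the normalization input, one row per token
rowscale = torch.rand(4)
dresidual = dz                     # residual branch: unscaled
dx = dz * rowscale[:, None]        # x branch: carries the per-row factor, so dx != dresidual in general
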
@@ -473,6 +506,7 @@ class LayerNormFn(torch.autograd.Function):
         residual=None,
         eps=1e-6,
         dropout_p=0.0,
+        rowscale=None,
         prenorm=False,
         residual_in_fp32=False,
         is_rms_norm=False,
@@ -491,6 +525,8 @@ class LayerNormFn(torch.autograd.Function):
         weight = weight.contiguous()
         if bias is not None:
             bias = bias.contiguous()
+        if rowscale is not None:
+            rowscale = rowscale.reshape(-1).contiguous()
         residual_dtype = (
             residual.dtype
             if residual is not None
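
Note that the autograd wrapper works on rows: x is viewed as (M, N) with M = batch * seqlen, and a rowscale passed with the same leading dimensions as x is flattened to shape (M,) here, matching the (M,) assertions in _layer_norm_fwd / _layer_norm_bwd. A small shape sketch (illustrative values):

import torch

batch, seqlen, hidden = 2, 128, 1024                 # M = batch * seqlen, N = hidden
rowscale = torch.randn(batch, seqlen)
rowscale_flat = rowscale.reshape(-1).contiguous()    # what the forward does before launching the kernel
assert rowscale_flat.shape == (batch * seqlen,)
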
@@ -503,11 +539,12 @@ class LayerNormFn(torch.autograd.Function):
             eps,
             residual,
             dropout_p=dropout_p,
+            rowscale=rowscale,
             residual_dtype=residual_dtype,
             is_rms_norm=is_rms_norm,
             return_dropout_mask=return_dropout_mask,
         )
-        ctx.save_for_backward(residual_out, weight, bias, seeds, mean, rstd)
+        ctx.save_for_backward(residual_out, weight, bias, rowscale, seeds, mean, rstd)
         ctx.x_shape_og = x_shape_og
         ctx.eps = eps
         ctx.dropout_p = dropout_p
@@ -525,7 +562,7 @@ class LayerNormFn(torch.autograd.Function):
     @staticmethod
     def backward(ctx, dy, *args):
-        x, weight, bias, seeds, mean, rstd = ctx.saved_tensors
+        x, weight, bias, rowscale, seeds, mean, rstd = ctx.saved_tensors
         dy = dy.reshape(-1, dy.shape[-1])
         if dy.stride(-1) != 1:
             dy = dy.contiguous()
@@ -549,6 +586,7 @@ class LayerNormFn(torch.autograd.Function):
             dresidual,
             seeds,
             ctx.dropout_p,
+            rowscale,
             ctx.has_residual,
             ctx.is_rms_norm,
             x_dtype=ctx.x_dtype,
@@ -564,6 +602,7 @@ class LayerNormFn(torch.autograd.Function):
             None,
             None,
             None,
+            None,
         )
@@ -574,6 +613,7 @@ def layer_norm_fn(
     residual=None,
     eps=1e-6,
     dropout_p=0.0,
+    rowscale=None,
     prenorm=False,
     residual_in_fp32=False,
     is_rms_norm=False,
@@ -586,6 +626,7 @@ def layer_norm_fn(
         residual,
         eps,
         dropout_p,
+        rowscale,
         prenorm,
         residual_in_fp32,
         is_rms_norm,
@@ -600,6 +641,7 @@ def rms_norm_fn(
     residual=None,
     eps=1e-6,
     dropout_p=0.0,
+    rowscale=None,
     prenorm=False,
     residual_in_fp32=False,
     return_dropout_mask=False,
@@ -611,6 +653,7 @@ def rms_norm_fn(
         residual,
         eps,
         dropout_p,
+        rowscale,
         prenorm,
         residual_in_fp32,
         True,
...
-# Copyright (c) 2023, Tri Dao.
+# Copyright (c) 2024, Tri Dao.
 import pytest
 import torch
@@ -16,14 +16,16 @@ from flash_attn.ops.triton.layernorm import (
 is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8
+@pytest.mark.parametrize("has_rowscale", [False, True])
+# @pytest.mark.parametrize("has_rowscale", [True])
 @pytest.mark.parametrize("dropout_p", [0.0, 0.27])
-# @pytest.mark.parametrize("dropout_p", [0.27])
+# @pytest.mark.parametrize("dropout_p", [0.0])
 @pytest.mark.parametrize("prenorm", [True, False])
-# @pytest.mark.parametrize("prenorm", [False])
+# @pytest.mark.parametrize("prenorm", [True])
 @pytest.mark.parametrize("is_rms_norm", [False, True])
 # @pytest.mark.parametrize("is_rms_norm", [True])
 @pytest.mark.parametrize("has_residual", [True, False])
-# @pytest.mark.parametrize("has_residual", [True])
+# @pytest.mark.parametrize("has_residual", [False])
 @pytest.mark.parametrize(
     "weight_dtype", [torch.float32, torch.float16] + ([torch.bfloat16] if is_sm8x else [])
 )
@@ -45,6 +47,7 @@ def test_layer_norm(
     is_rms_norm,
     prenorm,
     dropout_p,
+    has_rowscale,
 ):
     device = "cuda"
     if any(x == torch.bfloat16 for x in [input_dtype, residual_dtype, weight_dtype]):
@@ -60,7 +63,8 @@ def test_layer_norm(
     layer_norm_ref_fn = layer_norm_ref if not is_rms_norm else rms_norm_ref
     allclose = (
         lambda x, x_pt, x_ref, atol=atol: (x - x_ref).abs().max()
-        <= 2 * (x_pt - x_ref).abs().max() + atol
+        # Sometimes x0_pt.grad is NaN
+        <= 2 * (x_pt[~x_pt.isnan()] - x_ref[~x_pt.isnan()]).abs().max() + atol
     )
     x0 = torch.randn(
         batch_size, seqlen, hidden_size, device=device, dtype=input_dtype, requires_grad=True
@@ -83,6 +87,8 @@ def test_layer_norm(
     bias_pt = bias.detach().clone().requires_grad_() if bias is not None else None
     bias_ref = bias.detach().clone().requires_grad_() if bias is not None else None
+    rowscale = torch.randn(batch_size, seqlen, dtype=input_dtype, device=device) if has_rowscale else None
     residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
     out, *rest = layer_norm_fn(
         x0,
@@ -91,6 +97,7 @@ def test_layer_norm(
         residual=res,
         eps=1e-6,
         dropout_p=dropout_p,
+        rowscale=rowscale,
         prenorm=prenorm,
         residual_in_fp32=residual_in_fp32,
         is_rms_norm=is_rms_norm,
@@ -104,6 +111,7 @@ def test_layer_norm(
         residual=res_pt,
         eps=1e-6,
         dropout_p=dropout_p,
+        rowscale=rowscale,
         prenorm=prenorm,
         dropout_mask=dropout_mask,
     )
@@ -114,6 +122,7 @@ def test_layer_norm(
         residual=res_ref,
         eps=1e-6,
         dropout_p=dropout_p,
+        rowscale=rowscale,
         prenorm=prenorm,
         dropout_mask=dropout_mask,
         upcast=True,
...