Commit 665b55e2 authored by Tri Dao

[LayerNorm] Implement parallel layer norm in Triton

parent aa5c6438
(Part of this commit's diff is collapsed; the changes to the layer norm test are shown below.)
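The test below exercises the new parallel LayerNorm path: layer_norm_fn and layer_norm_ref now accept a second input branch x1 and a second set of affine parameters weight1 / bias1, apply dropout to each branch separately (hence the extra dropout mask), add both branches into a single residual stream, and return a second normalized output out1. As a reading aid, here is a minimal PyTorch sketch of the reference semantics the test compares against; it is inferred from the test's call sites rather than copied from the Triton implementation, so the helper name parallel_layer_norm_ref, the plain float32 math, and the edge-case handling are assumptions.

import torch
import torch.nn.functional as F


def parallel_layer_norm_ref(
    x0,
    weight,
    bias,
    residual=None,
    x1=None,
    weight1=None,
    bias1=None,
    eps=1e-6,
    dropout_p=0.0,
    rowscale=None,
    prenorm=False,
    dropout_mask=None,
    dropout_mask1=None,
):
    """Sketch of the reference computation (float32 math, LayerNorm variant only)."""
    # The test skips the rowscale + x1 combination, so the sketch does too.
    assert not (rowscale is not None and x1 is not None)
    x = x0.float()
    if rowscale is not None:
        x = x * rowscale[..., None].float()
    if dropout_p > 0.0:
        if dropout_mask is not None:
            # Replay a recorded dropout mask, as the test does for the PyTorch/reference paths
            x = x.masked_fill(~dropout_mask, 0.0) / (1.0 - dropout_p)
        else:
            x = F.dropout(x, p=dropout_p)
    if x1 is not None:
        x1 = x1.float()
        if dropout_p > 0.0:
            if dropout_mask1 is not None:
                x1 = x1.masked_fill(~dropout_mask1, 0.0) / (1.0 - dropout_p)
            else:
                x1 = F.dropout(x1, p=dropout_p)
        x = x + x1  # both branches feed one residual stream
    if residual is not None:
        x = x + residual.float()
    out = F.layer_norm(x, x.shape[-1:], weight.float(), bias.float() if bias is not None else None, eps)
    if weight1 is None:
        return out if not prenorm else (out, x)
    # Parallel branch: the same normalized row with a second weight/bias
    out1 = F.layer_norm(x, x.shape[-1:], weight1.float(), bias1.float() if bias1 is not None else None, eps)
    return (out, out1) if not prenorm else (out, out1, x)

Because out and out1 normalize the same row, they share one set of mean/variance (or RMS) statistics, which is presumably what lets the Triton kernel produce both outputs in a single pass; the test later gates them as out * F.gelu(out1).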
@@ -16,12 +16,16 @@ from flash_attn.ops.triton.layernorm import (
is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8
@pytest.mark.parametrize("has_weight1", [False, True])
# @pytest.mark.parametrize("has_weight1", [True])
@pytest.mark.parametrize("has_x1", [False, True])
# @pytest.mark.parametrize("has_x1", [False])
@pytest.mark.parametrize("has_rowscale", [False, True])
# @pytest.mark.parametrize("has_rowscale", [True])
# @pytest.mark.parametrize("has_rowscale", [False])
@pytest.mark.parametrize("dropout_p", [0.0, 0.27])
# @pytest.mark.parametrize("dropout_p", [0.0])
@pytest.mark.parametrize("prenorm", [True, False])
# @pytest.mark.parametrize("prenorm", [True])
# @pytest.mark.parametrize("prenorm", [False])
@pytest.mark.parametrize("is_rms_norm", [False, True])
# @pytest.mark.parametrize("is_rms_norm", [True])
@pytest.mark.parametrize("has_residual", [True, False])
@@ -48,7 +52,11 @@ def test_layer_norm(
prenorm,
dropout_p,
has_rowscale,
has_x1,
has_weight1,
):
if has_rowscale and has_x1:
pytest.skip("Not supported")
device = "cuda"
if any(x == torch.bfloat16 for x in [input_dtype, residual_dtype, weight_dtype]):
atol = 5e-2
@@ -62,9 +70,16 @@ def test_layer_norm(
seqlen = 512
layer_norm_ref_fn = layer_norm_ref if not is_rms_norm else rms_norm_ref
allclose = (
- lambda x, x_pt, x_ref, atol=atol: (x - x_ref).abs().max()
# Sometimes x0_pt.grad is NaN
lambda x, x_pt, x_ref, atol=atol: (x - x_ref).abs().max()
<= 2 * (x_pt[~x_pt.isnan()] - x_ref[~x_pt.isnan()]).abs().max() + atol
or (
# Sometimes x_pt and x_ref are the same (e.g. bfloat16) so we want to perturb it a bit
# by multiplying and dividing by 0.3
(x_pt[~x_pt.isnan()] - x_ref[~x_pt.isnan()]).abs().max() == 0.0
and (x - x_ref).abs().max()
<= 2 * (x_pt[~x_pt.isnan()] * 0.3 / 0.3 - x_ref[~x_pt.isnan()]).abs().max() + atol
)
)
x0 = torch.randn(
batch_size, seqlen, hidden_size, device=device, dtype=input_dtype, requires_grad=True
@@ -86,8 +101,35 @@ def test_layer_norm(
weight_ref = weight.detach().clone().requires_grad_()
bias_pt = bias.detach().clone().requires_grad_() if bias is not None else None
bias_ref = bias.detach().clone().requires_grad_() if bias is not None else None
if has_x1:
x1 = torch.randn_like(x0, dtype=input_dtype, requires_grad=True)
x1_pt = x1.detach().clone().requires_grad_()
x1_ref = x1.detach().clone().requires_grad_()
else:
x1, x1_pt, x1_ref = None, None, None
if has_weight1:
weight1 = torch.randn(
hidden_size, device=device, dtype=weight_dtype, requires_grad=True
)
weight1_pt = weight1.detach().clone().requires_grad_()
weight1_ref = weight1.detach().clone().requires_grad_()
if not is_rms_norm:
bias1 = torch.randn(
hidden_size, device=device, dtype=weight_dtype, requires_grad=True
)
else:
bias1 = None
bias1_pt = bias1.detach().clone().requires_grad_() if bias1 is not None else None
bias1_ref = bias1.detach().clone().requires_grad_() if bias1 is not None else None
else:
weight1, weight1_pt, weight1_ref = None, None, None
bias1, bias1_pt, bias1_ref = None, None, None
- rowscale = torch.randn(batch_size, seqlen, dtype=input_dtype, device=device) if has_rowscale else None
rowscale = (
torch.randn(batch_size, seqlen, dtype=input_dtype, device=device)
if has_rowscale
else None
)
residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
out, *rest = layer_norm_fn(
@@ -95,6 +137,9 @@ def test_layer_norm(
weight,
bias,
residual=res,
x1=x1,
weight1=weight1,
bias1=bias1,
eps=1e-6,
dropout_p=dropout_p,
rowscale=rowscale,
@@ -103,44 +148,75 @@ def test_layer_norm(
is_rms_norm=is_rms_norm,
return_dropout_mask=True,
)
- dropout_mask = rest[-1] if dropout_p > 0.0 else None
dropout_mask = rest[-2] if dropout_p > 0.0 else None
dropout_mask1 = rest[-1] if dropout_p > 0.0 and x1 is not None else None
out_pt = layer_norm_ref_fn(
x0_pt,
weight_pt,
bias_pt,
residual=res_pt,
x1=x1_pt,
weight1=weight1_pt,
bias1=bias1_pt,
eps=1e-6,
dropout_p=dropout_p,
rowscale=rowscale,
prenorm=prenorm,
dropout_mask=dropout_mask,
dropout_mask1=dropout_mask1,
)
out_ref = layer_norm_ref_fn(
x0_ref,
weight_ref,
bias_ref,
residual=res_ref,
x1=x1_ref,
weight1=weight1_ref,
bias1=bias1_ref,
eps=1e-6,
dropout_p=dropout_p,
rowscale=rowscale,
prenorm=prenorm,
dropout_mask=dropout_mask,
dropout_mask1=dropout_mask1,
upcast=True,
)
- if prenorm:
- residual = rest[0]
- out_pt, residual_pt = out_pt
- out_ref, residual_ref = out_ref
if not has_weight1:
if prenorm:
residual = rest[0]
out_pt, residual_pt = out_pt
out_ref, residual_ref = out_ref
out1, out1_pt, out1_ref = None, None, None
else:
out1 = rest.pop(0)
if prenorm:
residual = rest[0]
out_pt, out1_pt, residual_pt = out_pt
out_ref, out1_ref, residual_ref = out_ref
else:
out_pt, out1_pt = out_pt
out_ref, out1_ref = out_ref
assert out.dtype == input_dtype
if prenorm:
assert residual.dtype == residual_dtype
assert allclose(residual, residual_pt, residual_ref)
assert allclose(out, out_pt, out_ref)
if out1 is not None:
assert out1.dtype == input_dtype
assert allclose(out1, out1_pt, out1_ref)
if dropout_mask is not None:
dropout_fraction = 1.0 - dropout_mask.float().mean()
assert abs(dropout_fraction - dropout_p) < 0.01
if dropout_mask1 is not None:
dropout_fraction = 1.0 - dropout_mask1.float().mean()
assert abs(dropout_fraction - dropout_p) < 0.01
assert not torch.equal(dropout_mask, dropout_mask1)
g = torch.randn_like(out) / batch_size
if has_weight1:
out = out * F.gelu(out1)
out_pt = out_pt * F.gelu(out1_pt)
out_ref = out_ref * F.gelu(out1_ref)
if not prenorm:
out.backward(g)
out_pt.backward(g)
@@ -152,9 +228,15 @@ def test_layer_norm(
assert allclose(x0.grad, x0_pt.grad, x0_ref.grad)
if has_residual:
assert allclose(res.grad, res_pt.grad, res_ref.grad)
if has_x1:
assert allclose(x1.grad, x1_pt.grad, x1_ref.grad)
assert allclose(weight.grad, weight_pt.grad, weight_ref.grad)
if bias is not None:
assert allclose(bias.grad, bias_pt.grad, bias_ref.grad)
if has_weight1:
assert allclose(weight1.grad, weight1_pt.grad, weight1_ref.grad)
if bias1 is not None:
assert allclose(bias1.grad, bias1_pt.grad, bias1_ref.grad)
@pytest.mark.parametrize("prenorm", [True, False])
......
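For completeness, a hedged usage sketch of the new keyword arguments, mirroring the call pattern and unpacking order used in the test above; the shapes and dtypes are illustrative, and the return ordering (out, out1, residual, dropout_mask, dropout_mask1) is assumed from how the test indexes rest rather than from separate documentation.

import torch
import torch.nn.functional as F

from flash_attn.ops.triton.layernorm import layer_norm_fn

batch, seqlen, hidden = 4, 512, 1024          # illustrative sizes
device, dtype = "cuda", torch.float16

x0 = torch.randn(batch, seqlen, hidden, device=device, dtype=dtype)
x1 = torch.randn_like(x0)                     # parallel branch input
weight = torch.randn(hidden, device=device)   # float32 affine params, as in the test
bias = torch.randn(hidden, device=device)
weight1 = torch.randn(hidden, device=device)
bias1 = torch.randn(hidden, device=device)

out, *rest = layer_norm_fn(
    x0,
    weight,
    bias,
    x1=x1,
    weight1=weight1,
    bias1=bias1,
    eps=1e-6,
    dropout_p=0.1,
    prenorm=True,
    return_dropout_mask=True,
)
out1 = rest.pop(0)                                  # second normalized output (weight1 / bias1 branch)
residual = rest[0]                                  # prenorm residual: dropped-out x0 + x1 (+ input residual, if any)
dropout_mask, dropout_mask1 = rest[-2], rest[-1]    # the masks applied to x0 and x1

gated = out * F.gelu(out1)                          # GLU-style gating, as done at the end of the test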