Commit 665b55e2 authored by Tri Dao

[LayerNorm] Implement parallel layer norm in Triton

parent aa5c6438
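For orientation (an editorial sketch, not part of the commit): the diff below adds a second input branch (x1) and a second set of affine parameters (weight1, bias1) to the fused Triton LayerNorm, so one normalization can produce two outputs. A minimal usage sketch follows; the shapes, dtypes, and the gated combination at the end are illustrative assumptions, while the keyword arguments mirror the new signature and the gating mirrors what the updated test does with the two outputs.

import torch
import torch.nn.functional as F
from flash_attn.ops.triton.layernorm import layer_norm_fn

# Hypothetical shapes and dtypes, chosen only for illustration.
batch, seqlen, hidden = 2, 512, 4096
x0 = torch.randn(batch, seqlen, hidden, device="cuda", dtype=torch.float16)
x1 = torch.randn_like(x0)  # second input branch, added to x0 before the norm
weight = torch.randn(hidden, device="cuda", dtype=torch.float32)
bias = torch.randn(hidden, device="cuda", dtype=torch.float32)
weight1 = torch.randn(hidden, device="cuda", dtype=torch.float32)  # second affine transform
bias1 = torch.randn(hidden, device="cuda", dtype=torch.float32)

# With weight1 given, the same normalized activations are projected twice,
# so the call returns two outputs instead of one.
out, out1 = layer_norm_fn(
    x0,
    weight,
    bias,
    x1=x1,
    weight1=weight1,
    bias1=bias1,
    eps=1e-6,
    dropout_p=0.0,
)
gated = out * F.gelu(out1)  # e.g. a gated combination, as exercised in the test below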
@@ -21,20 +21,28 @@ def layer_norm_ref(
weight,
bias,
residual=None,
x1=None,
weight1=None,
bias1=None,
eps=1e-6,
dropout_p=0.0,
rowscale=None,
prenorm=False,
dropout_mask=None,
dropout_mask1=None,
upcast=False,
):
dtype = x.dtype
if upcast:
x = x.float()
weight = weight.float()
bias = bias.float() if bias is not None else None
if upcast:
x = x.float()
residual = residual.float() if residual is not None else residual
x1 = x1.float() if x1 is not None else None
weight1 = weight1.float() if weight1 is not None else None
bias1 = bias1.float() if bias1 is not None else None
if x1 is not None:
assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
if rowscale is not None:
x = x * rowscale[..., None]
if dropout_p > 0.0:
@@ -42,12 +50,25 @@ def layer_norm_ref(
x = x.masked_fill(~dropout_mask, 0.0) / (1.0 - dropout_p)
else:
x = F.dropout(x, p=dropout_p)
if x1 is not None:
if dropout_mask1 is not None:
x1 = x1.masked_fill(~dropout_mask1, 0.0) / (1.0 - dropout_p)
else:
x1 = F.dropout(x1, p=dropout_p)
if x1 is not None:
x = x + x1
if residual is not None:
x = (x + residual).to(x.dtype)
out = F.layer_norm(x.to(weight.dtype), x.shape[-1:], weight=weight, bias=bias, eps=eps).to(
dtype
)
return out if not prenorm else (out, x)
if weight1 is None:
return out if not prenorm else (out, x)
else:
out1 = F.layer_norm(
x.to(weight1.dtype), x.shape[-1:], weight=weight1, bias=bias1, eps=eps
).to(dtype)
return (out, out1) if not prenorm else (out, out1, x)
def rms_norm_ref(
@@ -55,20 +76,28 @@ def rms_norm_ref(
weight,
bias,
residual=None,
x1=None,
weight1=None,
bias1=None,
eps=1e-6,
dropout_p=0.0,
rowscale=None,
prenorm=False,
dropout_mask=None,
dropout_mask1=None,
upcast=False,
):
dtype = x.dtype
if upcast:
x = x.float()
weight = weight.float()
bias = bias.float() if bias is not None else None
if upcast:
x = x.float()
residual = residual.float() if residual is not None else residual
x1 = x1.float() if x1 is not None else None
weight1 = weight1.float() if weight1 is not None else None
bias1 = bias1.float() if bias1 is not None else None
if x1 is not None:
assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
if rowscale is not None:
x = x * rowscale[..., None]
if dropout_p > 0.0:
@@ -76,12 +105,24 @@ def rms_norm_ref(
x = x.masked_fill(~dropout_mask, 0.0) / (1.0 - dropout_p)
else:
x = F.dropout(x, p=dropout_p)
if x1 is not None:
if dropout_mask1 is not None:
x1 = x1.masked_fill(~dropout_mask1, 0.0) / (1.0 - dropout_p)
else:
x1 = F.dropout(x1, p=dropout_p)
if x1 is not None:
x = x + x1
if residual is not None:
x = (x + residual).to(x.dtype)
rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps)
out = (x * rstd * weight) + bias if bias is not None else (x * rstd * weight)
out = out.to(dtype)
return out if not prenorm else (out, x)
out = ((x * rstd * weight) + bias if bias is not None else (x * rstd * weight)).to(dtype)
if weight1 is None:
return out if not prenorm else (out, x)
else:
out1 = ((x * rstd * weight1) + bias1 if bias1 is not None else (x * rstd * weight1)).to(
dtype
)
return (out, out1) if not prenorm else (out, out1, x)
@triton.autotune(
@@ -97,6 +138,9 @@ def rms_norm_ref(
)
# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
# @triton.heuristics({"HAS_RESIDUAL": lambda args: args["RESIDUAL"] is not None})
@triton.heuristics({"HAS_X1": lambda args: args["X1"] is not None})
@triton.heuristics({"HAS_W1": lambda args: args["W1"] is not None})
@triton.heuristics({"HAS_B1": lambda args: args["B1"] is not None})
@triton.jit
def _layer_norm_fwd_1pass_kernel(
X, # pointer to the input
@@ -104,6 +148,10 @@ def _layer_norm_fwd_1pass_kernel(
W, # pointer to the weights
B, # pointer to the biases
RESIDUAL, # pointer to the residual
X1,
W1,
B1,
Y1,
RESIDUAL_OUT, # pointer to the residual
ROWSCALE,
SEEDS, # Dropout seeds for each row
@@ -114,6 +162,9 @@ def _layer_norm_fwd_1pass_kernel(
stride_y_row,
stride_res_row,
stride_res_out_row,
stride_x1_row,
stride_y1_row,
M, # number of rows in X
N, # number of columns in X
eps, # epsilon to avoid division by zero
dropout_p, # Dropout probability
@@ -125,6 +176,9 @@ def _layer_norm_fwd_1pass_kernel(
HAS_DROPOUT: tl.constexpr,
STORE_DROPOUT_MASK: tl.constexpr,
HAS_ROWSCALE: tl.constexpr,
HAS_X1: tl.constexpr,
HAS_W1: tl.constexpr,
HAS_B1: tl.constexpr,
):
# Map the program id to the row of X and Y it should compute.
row = tl.program_id(0)
@@ -134,6 +188,10 @@ def _layer_norm_fwd_1pass_kernel(
RESIDUAL += row * stride_res_row
if STORE_RESIDUAL_OUT:
RESIDUAL_OUT += row * stride_res_out_row
if HAS_X1:
X1 += row * stride_x1_row
if HAS_W1:
Y1 += row * stride_y1_row
# Compute mean and variance
cols = tl.arange(0, BLOCK_N)
x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
@@ -147,6 +205,21 @@ def _layer_norm_fwd_1pass_kernel(
x = tl.where(keep_mask, x / (1.0 - dropout_p), 0.0)
if STORE_DROPOUT_MASK:
tl.store(DROPOUT_MASK + row * N + cols, keep_mask, mask=cols < N)
if HAS_X1:
x1 = tl.load(X1 + cols, mask=cols < N, other=0.0).to(tl.float32)
if HAS_ROWSCALE:
rowscale = tl.load(ROWSCALE + M + row).to(tl.float32)
x1 *= rowscale
if HAS_DROPOUT:
# Compute dropout mask
# 7 rounds is good enough, and reduces register pressure
keep_mask = (
tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
)
x1 = tl.where(keep_mask, x1 / (1.0 - dropout_p), 0.0)
if STORE_DROPOUT_MASK:
tl.store(DROPOUT_MASK + (M + row) * N + cols, keep_mask, mask=cols < N)
x += x1
if HAS_RESIDUAL:
residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32)
x += residual
@@ -171,6 +244,12 @@ def _layer_norm_fwd_1pass_kernel(
y = x_hat * w + b if HAS_BIAS else x_hat * w
# Write output
tl.store(Y + cols, y, mask=mask)
if HAS_W1:
w1 = tl.load(W1 + cols, mask=mask).to(tl.float32)
if HAS_B1:
b1 = tl.load(B1 + cols, mask=mask).to(tl.float32)
y1 = x_hat * w1 + b1 if HAS_B1 else x_hat * w1
tl.store(Y1 + cols, y1, mask=mask)
def _layer_norm_fwd(
@@ -179,6 +258,9 @@ def _layer_norm_fwd(
bias,
eps,
residual=None,
x1=None,
weight1=None,
bias1=None,
dropout_p=0.0,
rowscale=None,
out_dtype=None,
@@ -198,17 +280,33 @@ def _layer_norm_fwd(
if bias is not None:
assert bias.stride(-1) == 1
assert bias.shape == (N,)
if x1 is not None:
assert x1.shape == x.shape
assert rowscale is None
assert x1.stride(-1) == 1
if weight1 is not None:
assert weight1.shape == (N,)
assert weight1.stride(-1) == 1
if bias1 is not None:
assert bias1.shape == (N,)
assert bias1.stride(-1) == 1
if rowscale is not None:
assert rowscale.is_contiguous()
assert rowscale.shape == (M,)
# allocate output
y = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)
assert y.stride(-1) == 1
if weight1 is not None:
y1 = torch.empty_like(y)
assert y1.stride(-1) == 1
else:
y1 = None
if (
residual is not None
or (residual_dtype is not None and residual_dtype != x.dtype)
or dropout_p > 0.0
or rowscale is not None
or x1 is not None
):
residual_out = torch.empty(
M, N, device=x.device, dtype=residual_dtype if residual_dtype is not None else x.dtype
@@ -219,11 +317,13 @@ def _layer_norm_fwd(
mean = torch.empty((M,), dtype=torch.float32, device="cuda") if not is_rms_norm else None
rstd = torch.empty((M,), dtype=torch.float32, device="cuda")
if dropout_p > 0.0:
seeds = torch.randint(2**32, (M,), device=x.device, dtype=torch.int64)
seeds = torch.randint(
2**32, (M if x1 is None else 2 * M,), device=x.device, dtype=torch.int64
)
else:
seeds = None
if return_dropout_mask and dropout_p > 0.0:
dropout_mask = torch.empty_like(x, dtype=torch.bool)
dropout_mask = torch.empty(M if x1 is None else 2 * M, N, device=x.device, dtype=torch.bool)
else:
dropout_mask = None
# Less than 64KB per feature: enqueue fused kernel
@@ -231,7 +331,6 @@ def _layer_norm_fwd(
BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
if N > BLOCK_N:
raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
# heuristics for number of warps
with torch.cuda.device(x.device.index):
_layer_norm_fwd_1pass_kernel[(M,)](
x,
@@ -239,6 +338,10 @@ def _layer_norm_fwd(
weight,
bias,
residual,
x1,
weight1,
bias1,
y1,
residual_out,
rowscale,
seeds,
@@ -249,6 +352,9 @@ def _layer_norm_fwd(
y.stride(0),
residual.stride(0) if residual is not None else 0,
residual_out.stride(0) if residual_out is not None else 0,
x1.stride(0) if x1 is not None else 0,
y1.stride(0) if y1 is not None else 0,
M,
N,
eps,
dropout_p,
@@ -262,7 +368,20 @@ def _layer_norm_fwd(
rowscale is not None,
)
# residual_out is None if residual is None and residual_dtype == input_dtype and dropout_p == 0.0
return y, mean, rstd, residual_out if residual_out is not None else x, seeds, dropout_mask
if dropout_mask is not None and x1 is not None:
dropout_mask, dropout_mask1 = dropout_mask.tensor_split(2, dim=0)
else:
dropout_mask1 = None
return (
y,
y1,
mean,
rstd,
residual_out if residual_out is not None else x,
seeds,
dropout_mask,
dropout_mask1,
)
@triton.autotune(
@@ -280,6 +399,9 @@ def _layer_norm_fwd(
# @triton.heuristics({"HAS_DRESIDUAL": lambda args: args["DRESIDUAL"] is not None})
# @triton.heuristics({"STORE_DRESIDUAL": lambda args: args["DRESIDUAL_IN"] is not None})
@triton.heuristics({"HAS_ROWSCALE": lambda args: args["ROWSCALE"] is not None})
@triton.heuristics({"HAS_DY1": lambda args: args["DY1"] is not None})
@triton.heuristics({"HAS_DX1": lambda args: args["DX1"] is not None})
@triton.heuristics({"HAS_B1": lambda args: args["DB1"] is not None})
@triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None})
@triton.jit
def _layer_norm_bwd_kernel(
@@ -292,6 +414,11 @@ def _layer_norm_bwd_kernel(
DW, # pointer to the partial sum of weights gradient
DB, # pointer to the partial sum of biases gradient
DRESIDUAL,
W1,
DY1,
DX1,
DW1,
DB1,
DRESIDUAL_IN,
ROWSCALE,
SEEDS,
@@ -302,6 +429,8 @@ def _layer_norm_bwd_kernel(
stride_dy_row,
stride_dx_row,
stride_dres_row,
stride_dy1_row,
stride_dx1_row,
stride_dres_in_row,
M, # number of rows in X
N, # number of columns in X
@@ -315,6 +444,9 @@ def _layer_norm_bwd_kernel(
HAS_BIAS: tl.constexpr,
HAS_DROPOUT: tl.constexpr,
HAS_ROWSCALE: tl.constexpr,
HAS_DY1: tl.constexpr,
HAS_DX1: tl.constexpr,
HAS_B1: tl.constexpr,
RECOMPUTE_OUTPUT: tl.constexpr,
):
# Map the program id to the elements of X, DX, and DY it should compute.
@@ -331,19 +463,31 @@ def _layer_norm_bwd_kernel(
DRESIDUAL_IN += row_start * stride_dres_in_row
DY += row_start * stride_dy_row
DX += row_start * stride_dx_row
if HAS_DY1:
DY1 += row_start * stride_dy1_row
if HAS_DX1:
DX1 += row_start * stride_dx1_row
if RECOMPUTE_OUTPUT:
Y += row_start * stride_y_row
w = tl.load(W + cols, mask=mask).to(tl.float32)
if RECOMPUTE_OUTPUT and HAS_BIAS:
b = tl.load(B + cols, mask=mask, other=0.0).to(tl.float32)
if HAS_DY1:
w1 = tl.load(W1 + cols, mask=mask).to(tl.float32)
dw = tl.zeros((BLOCK_N,), dtype=tl.float32)
if HAS_BIAS:
db = tl.zeros((BLOCK_N,), dtype=tl.float32)
if HAS_DY1:
dw1 = tl.zeros((BLOCK_N,), dtype=tl.float32)
if HAS_B1:
db1 = tl.zeros((BLOCK_N,), dtype=tl.float32)
row_end = min((row_block_id + 1) * rows_per_program, M)
for row in range(row_start, row_end):
# Load data to SRAM
x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)
dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)
if HAS_DY1:
dy1 = tl.load(DY1 + cols, mask=mask, other=0).to(tl.float32)
if not IS_RMS_NORM:
mean = tl.load(Mean + row)
rstd = tl.load(Rstd + row)
@@ -357,6 +501,11 @@ def _layer_norm_bwd_kernel(
dw += dy * xhat
if HAS_BIAS:
db += dy
if HAS_DY1:
wdy += w1 * dy1
dw1 += dy1 * xhat
if HAS_B1:
db1 += dy1
if not IS_RMS_NORM:
c1 = tl.sum(xhat * wdy, axis=0) / N
c2 = tl.sum(wdy, axis=0) / N
@@ -370,6 +519,15 @@ def _layer_norm_bwd_kernel(
# Write dx
if STORE_DRESIDUAL:
tl.store(DRESIDUAL_IN + cols, dx, mask=mask)
if HAS_DX1:
if HAS_DROPOUT:
keep_mask = (
tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
)
dx1 = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0)
else:
dx1 = dx
tl.store(DX1 + cols, dx1, mask=mask)
if HAS_DROPOUT:
keep_mask = tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
dx = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0)
@@ -387,9 +545,17 @@ def _layer_norm_bwd_kernel(
Y += stride_y_row
DY += stride_dy_row
DX += stride_dx_row
if HAS_DY1:
DY1 += stride_dy1_row
if HAS_DX1:
DX1 += stride_dx1_row
tl.store(DW + row_block_id * N + cols, dw, mask=mask)
if HAS_BIAS:
tl.store(DB + row_block_id * N + cols, db, mask=mask)
if HAS_DY1:
tl.store(DW1 + row_block_id * N + cols, dw1, mask=mask)
if HAS_B1:
tl.store(DB1 + row_block_id * N + cols, db1, mask=mask)
def _layer_norm_bwd(
@@ -401,10 +567,14 @@ def _layer_norm_bwd(
mean,
rstd,
dresidual=None,
dy1=None,
weight1=None,
bias1=None,
seeds=None,
dropout_p=0.0,
rowscale=None,
has_residual=False,
has_x1=False,
is_rms_norm=False,
x_dtype=None,
recompute_output=False,
@@ -421,9 +591,19 @@ def _layer_norm_bwd(
if bias is not None:
assert bias.stride(-1) == 1
assert bias.shape == (N,)
if dy1 is not None:
assert weight1 is not None
assert dy1.shape == dy.shape
assert dy1.stride(-1) == 1
if weight1 is not None:
assert weight1.shape == (N,)
assert weight1.stride(-1) == 1
if bias1 is not None:
assert bias1.shape == (N,)
assert bias1.stride(-1) == 1
if seeds is not None:
assert seeds.is_contiguous()
assert seeds.shape == (M,)
assert seeds.shape == (M if not has_x1 else M * 2,)
if rowscale is not None:
assert rowscale.is_contiguous()
assert rowscale.shape == (M,)
@@ -435,10 +615,14 @@ def _layer_norm_bwd(
)
dresidual_in = (
torch.empty_like(x)
if has_residual and (dx.dtype != x.dtype or dropout_p > 0.0 or rowscale is not None)
if has_residual
and (dx.dtype != x.dtype or dropout_p > 0.0 or rowscale is not None or has_x1)
else None
)
dx1 = torch.empty_like(dx) if (has_x1 and dropout_p > 0.0) else None
y = torch.empty(M, N, dtype=dy.dtype, device=dy.device) if recompute_output else None
if recompute_output:
assert weight1 is None, "recompute_output is not supported with parallel LayerNorm"
# Less than 64KB per feature: enqueue fused kernel
MAX_FUSED_SIZE = 65536 // x.element_size()
@@ -452,6 +636,8 @@ def _layer_norm_bwd(
if bias is not None
else None
)
_dw1 = torch.empty_like(_dw) if weight1 is not None else None
_db1 = torch.empty_like(_db) if bias1 is not None else None
rows_per_program = math.ceil(M / sm_count)
grid = (sm_count,)
with torch.cuda.device(x.device.index):
@@ -465,6 +651,11 @@ def _layer_norm_bwd(
_dw,
_db,
dresidual,
weight1,
dy1,
dx1,
_dw1,
_db1,
dresidual_in,
rowscale,
seeds,
@@ -475,6 +666,8 @@ def _layer_norm_bwd(
dy.stride(0),
dx.stride(0),
dresidual.stride(0) if dresidual is not None else 0,
dy1.stride(0) if dy1 is not None else 0,
dx1.stride(0) if dx1 is not None else 0,
dresidual_in.stride(0) if dresidual_in is not None else 0,
M,
N,
@@ -490,10 +683,18 @@ def _layer_norm_bwd(
)
dw = _dw.sum(0).to(weight.dtype)
db = _db.sum(0).to(bias.dtype) if bias is not None else None
dw1 = _dw1.sum(0).to(weight1.dtype) if weight1 is not None else None
db1 = _db1.sum(0).to(bias1.dtype) if bias1 is not None else None
# Don't need to compute dresidual_in separately in this case
if has_residual and dx.dtype == x.dtype and dropout_p == 0.0 and rowscale is None:
dresidual_in = dx
return (dx, dw, db, dresidual_in) if not recompute_output else (dx, dw, db, dresidual_in, y)
if has_x1 and dropout_p == 0.0:
dx1 = dx
return (
(dx, dw, db, dresidual_in, dx1, dw1, db1)
if not recompute_output
else (dx, dw, db, dresidual_in, dx1, dw1, db1, y)
)
class LayerNormFn(torch.autograd.Function):
@@ -504,6 +705,9 @@ class LayerNormFn(torch.autograd.Function):
weight,
bias,
residual=None,
x1=None,
weight1=None,
bias1=None,
eps=1e-6,
dropout_p=0.0,
rowscale=None,
@@ -522,9 +726,19 @@ class LayerNormFn(torch.autograd.Function):
residual = residual.reshape(-1, residual.shape[-1])
if residual.stride(-1) != 1:
residual = residual.contiguous()
if x1 is not None:
assert x1.shape == x_shape_og
assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
x1 = x1.reshape(-1, x1.shape[-1])
if x1.stride(-1) != 1:
x1 = x1.contiguous()
weight = weight.contiguous()
if bias is not None:
bias = bias.contiguous()
if weight1 is not None:
weight1 = weight1.contiguous()
if bias1 is not None:
bias1 = bias1.contiguous()
if rowscale is not None:
rowscale = rowscale.reshape(-1).contiguous()
residual_dtype = (
@@ -532,41 +746,71 @@ class LayerNormFn(torch.autograd.Function):
if residual is not None
else (torch.float32 if residual_in_fp32 else None)
)
y, mean, rstd, residual_out, seeds, dropout_mask = _layer_norm_fwd(
y, y1, mean, rstd, residual_out, seeds, dropout_mask, dropout_mask1 = _layer_norm_fwd(
x,
weight,
bias,
eps,
residual,
x1,
weight1,
bias1,
dropout_p=dropout_p,
rowscale=rowscale,
residual_dtype=residual_dtype,
is_rms_norm=is_rms_norm,
return_dropout_mask=return_dropout_mask,
)
ctx.save_for_backward(residual_out, weight, bias, rowscale, seeds, mean, rstd)
ctx.save_for_backward(
residual_out, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd
)
ctx.x_shape_og = x_shape_og
ctx.eps = eps
ctx.dropout_p = dropout_p
ctx.is_rms_norm = is_rms_norm
ctx.has_residual = residual is not None
ctx.has_x1 = x1 is not None
ctx.prenorm = prenorm
ctx.x_dtype = x.dtype
y = y.reshape(x_shape_og)
y1 = y1.reshape(x_shape_og) if y1 is not None else None
residual_out = residual_out.reshape(x_shape_og) if residual_out is not None else None
dropout_mask = dropout_mask.reshape(x_shape_og) if dropout_mask is not None else None
dropout_mask1 = dropout_mask1.reshape(x_shape_og) if dropout_mask1 is not None else None
if not return_dropout_mask:
return y if not prenorm else (y, residual_out)
if weight1 is None:
return y if not prenorm else (y, residual_out)
else:
return (y, y1) if not prenorm else (y, y1, residual_out)
else:
return (y, dropout_mask) if not prenorm else (y, residual_out, dropout_mask)
if weight1 is None:
return (
(y, dropout_mask, dropout_mask1)
if not prenorm
else (y, residual_out, dropout_mask, dropout_mask1)
)
else:
return (
(y, y1, dropout_mask, dropout_mask1)
if not prenorm
else (y, y1, residual_out, dropout_mask, dropout_mask1)
)
@staticmethod
def backward(ctx, dy, *args):
x, weight, bias, rowscale, seeds, mean, rstd = ctx.saved_tensors
x, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd = ctx.saved_tensors
dy = dy.reshape(-1, dy.shape[-1])
if dy.stride(-1) != 1:
dy = dy.contiguous()
assert dy.shape == x.shape
if weight1 is not None:
dy1, args = args[0], args[1:]
dy1 = dy1.reshape(-1, dy1.shape[-1])
if dy1.stride(-1) != 1:
dy1 = dy1.contiguous()
assert dy1.shape == x.shape
else:
dy1 = None
if ctx.prenorm:
dresidual = args[0]
dresidual = dresidual.reshape(-1, dresidual.shape[-1])
@@ -575,7 +819,7 @@ class LayerNormFn(torch.autograd.Function):
assert dresidual.shape == x.shape
else:
dresidual = None
dx, dw, db, dresidual_in = _layer_norm_bwd(
dx, dw, db, dresidual_in, dx1, dw1, db1 = _layer_norm_bwd(
dy,
x,
weight,
@@ -584,10 +828,14 @@ class LayerNormFn(torch.autograd.Function):
mean,
rstd,
dresidual,
dy1,
weight1,
bias1,
seeds,
ctx.dropout_p,
rowscale,
ctx.has_residual,
ctx.has_x1,
ctx.is_rms_norm,
x_dtype=ctx.x_dtype,
)
@@ -596,6 +844,9 @@ class LayerNormFn(torch.autograd.Function):
dw,
db,
dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None,
dx1.reshape(ctx.x_shape_og) if dx1 is not None else None,
dw1,
db1,
None,
None,
None,
@@ -611,6 +862,9 @@ def layer_norm_fn(
weight,
bias,
residual=None,
x1=None,
weight1=None,
bias1=None,
eps=1e-6,
dropout_p=0.0,
rowscale=None,
@@ -624,6 +878,9 @@ def layer_norm_fn(
weight,
bias,
residual,
x1,
weight1,
bias1,
eps,
dropout_p,
rowscale,
@@ -639,6 +896,9 @@ def rms_norm_fn(
weight,
bias,
residual=None,
x1=None,
weight1=None,
bias1=None,
eps=1e-6,
dropout_p=0.0,
rowscale=None,
@@ -651,6 +911,9 @@ def rms_norm_fn(
weight,
bias,
residual,
x1,
weight1,
bias1,
eps,
dropout_p,
rowscale,
@@ -662,11 +925,15 @@ def rms_norm_fn(
class RMSNorm(torch.nn.Module):
def __init__(self, hidden_size, eps=1e-5, dropout_p=0.0, device=None, dtype=None):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.eps = eps
self.dropout_p = dropout_p
if dropout_p > 0.0:
self.drop = torch.nn.Dropout(dropout_p)
else:
self.drop = None
self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
self.register_parameter("bias", None)
self.reset_parameters()
@@ -681,7 +948,7 @@ class RMSNorm(torch.nn.Module):
self.bias,
residual=residual,
eps=self.eps,
dropout_p=self.dropout_p if self.training else 0.0,
dropout_p=self.drop.p if self.drop is not None and self.training else 0.0,
prenorm=prenorm,
residual_in_fp32=residual_in_fp32,
)
......
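The changes to flash_attn/ops/triton/layernorm.py end here; the test changes follow. As an orientation aid (again an editorial sketch, not part of the commit), the parallel path condenses to the reference below, derived from layer_norm_ref above with dropout, rowscale, prenorm, and the dtype casts omitted; the helper name is made up for illustration.

import torch.nn.functional as F

def parallel_layer_norm_sketch(x, x1, weight, bias, weight1, bias1, residual=None, eps=1e-6):
    # The two branches (and the optional residual) are summed once ...
    z = x + x1 if x1 is not None else x
    if residual is not None:
        z = z + residual
    # ... and the same normalized activations feed two affine transforms.
    out = F.layer_norm(z, z.shape[-1:], weight=weight, bias=bias, eps=eps)
    out1 = F.layer_norm(z, z.shape[-1:], weight=weight1, bias=bias1, eps=eps)
    return out, out1

# rms_norm_ref is analogous, replacing F.layer_norm with x * rstd * weight (+ bias),
# where rstd = 1 / sqrt(mean(x**2) + eps).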
@@ -16,12 +16,16 @@ from flash_attn.ops.triton.layernorm import (
is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8
@pytest.mark.parametrize("has_weight1", [False, True])
# @pytest.mark.parametrize("has_weight1", [True])
@pytest.mark.parametrize("has_x1", [False, True])
# @pytest.mark.parametrize("has_x1", [False])
@pytest.mark.parametrize("has_rowscale", [False, True])
# @pytest.mark.parametrize("has_rowscale", [True])
# @pytest.mark.parametrize("has_rowscale", [False])
@pytest.mark.parametrize("dropout_p", [0.0, 0.27])
# @pytest.mark.parametrize("dropout_p", [0.0])
@pytest.mark.parametrize("prenorm", [True, False])
# @pytest.mark.parametrize("prenorm", [True])
# @pytest.mark.parametrize("prenorm", [False])
@pytest.mark.parametrize("is_rms_norm", [False, True])
# @pytest.mark.parametrize("is_rms_norm", [True])
@pytest.mark.parametrize("has_residual", [True, False])
@@ -48,7 +52,11 @@ def test_layer_norm(
prenorm,
dropout_p,
has_rowscale,
has_x1,
has_weight1,
):
if has_rowscale and has_x1:
pytest.skip("Not supported")
device = "cuda"
if any(x == torch.bfloat16 for x in [input_dtype, residual_dtype, weight_dtype]):
atol = 5e-2
@@ -62,9 +70,16 @@ def test_layer_norm(
seqlen = 512
layer_norm_ref_fn = layer_norm_ref if not is_rms_norm else rms_norm_ref
allclose = (
lambda x, x_pt, x_ref, atol=atol: (x - x_ref).abs().max()
# Sometimes x0_pt.grad is NaN
lambda x, x_pt, x_ref, atol=atol: (x - x_ref).abs().max()
<= 2 * (x_pt[~x_pt.isnan()] - x_ref[~x_pt.isnan()]).abs().max() + atol
or (
# Sometimes x_pt and x_ref are the same (e.g. bfloat16) so we want to perturb it a bit
# by multiplying and dividing by 0.3
(x_pt[~x_pt.isnan()] - x_ref[~x_pt.isnan()]).abs().max() == 0.0
and (x - x_ref).abs().max()
<= 2 * (x_pt[~x_pt.isnan()] * 0.3 / 0.3 - x_ref[~x_pt.isnan()]).abs().max() + atol
)
)
x0 = torch.randn(
batch_size, seqlen, hidden_size, device=device, dtype=input_dtype, requires_grad=True
@@ -86,8 +101,35 @@ def test_layer_norm(
weight_ref = weight.detach().clone().requires_grad_()
bias_pt = bias.detach().clone().requires_grad_() if bias is not None else None
bias_ref = bias.detach().clone().requires_grad_() if bias is not None else None
if has_x1:
x1 = torch.randn_like(x0, dtype=input_dtype, requires_grad=True)
x1_pt = x1.detach().clone().requires_grad_()
x1_ref = x1.detach().clone().requires_grad_()
else:
x1, x1_pt, x1_ref = None, None, None
if has_weight1:
weight1 = torch.randn(
hidden_size, device=device, dtype=weight_dtype, requires_grad=True
)
weight1_pt = weight1.detach().clone().requires_grad_()
weight1_ref = weight1.detach().clone().requires_grad_()
if not is_rms_norm:
bias1 = torch.randn(
hidden_size, device=device, dtype=weight_dtype, requires_grad=True
)
else:
bias1 = None
bias1_pt = bias1.detach().clone().requires_grad_() if bias1 is not None else None
bias1_ref = bias1.detach().clone().requires_grad_() if bias1 is not None else None
else:
weight1, weight1_pt, weight1_ref = None, None, None
bias1, bias1_pt, bias1_ref = None, None, None
rowscale = torch.randn(batch_size, seqlen, dtype=input_dtype, device=device) if has_rowscale else None
rowscale = (
torch.randn(batch_size, seqlen, dtype=input_dtype, device=device)
if has_rowscale
else None
)
residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
out, *rest = layer_norm_fn(
@@ -95,6 +137,9 @@ def test_layer_norm(
weight,
bias,
residual=res,
x1=x1,
weight1=weight1,
bias1=bias1,
eps=1e-6,
dropout_p=dropout_p,
rowscale=rowscale,
@@ -103,44 +148,75 @@ def test_layer_norm(
is_rms_norm=is_rms_norm,
return_dropout_mask=True,
)
dropout_mask = rest[-1] if dropout_p > 0.0 else None
dropout_mask = rest[-2] if dropout_p > 0.0 else None
dropout_mask1 = rest[-1] if dropout_p > 0.0 and x1 is not None else None
out_pt = layer_norm_ref_fn(
x0_pt,
weight_pt,
bias_pt,
residual=res_pt,
x1=x1_pt,
weight1=weight1_pt,
bias1=bias1_pt,
eps=1e-6,
dropout_p=dropout_p,
rowscale=rowscale,
prenorm=prenorm,
dropout_mask=dropout_mask,
dropout_mask1=dropout_mask1,
)
out_ref = layer_norm_ref_fn(
x0_ref,
weight_ref,
bias_ref,
residual=res_ref,
x1=x1_ref,
weight1=weight1_ref,
bias1=bias1_ref,
eps=1e-6,
dropout_p=dropout_p,
rowscale=rowscale,
prenorm=prenorm,
dropout_mask=dropout_mask,
dropout_mask1=dropout_mask1,
upcast=True,
)
if prenorm:
residual = rest[0]
out_pt, residual_pt = out_pt
out_ref, residual_ref = out_ref
if not has_weight1:
if prenorm:
residual = rest[0]
out_pt, residual_pt = out_pt
out_ref, residual_ref = out_ref
out1, out1_pt, out1_ref = None, None, None
else:
out1 = rest.pop(0)
if prenorm:
residual = rest[0]
out_pt, out1_pt, residual_pt = out_pt
out_ref, out1_ref, residual_ref = out_ref
else:
out_pt, out1_pt = out_pt
out_ref, out1_ref = out_ref
assert out.dtype == input_dtype
if prenorm:
assert residual.dtype == residual_dtype
assert allclose(residual, residual_pt, residual_ref)
assert allclose(out, out_pt, out_ref)
if out1 is not None:
assert out1.dtype == input_dtype
assert allclose(out1, out1_pt, out1_ref)
if dropout_mask is not None:
dropout_fraction = 1.0 - dropout_mask.float().mean()
assert abs(dropout_fraction - dropout_p) < 0.01
if dropout_mask1 is not None:
dropout_fraction = 1.0 - dropout_mask1.float().mean()
assert abs(dropout_fraction - dropout_p) < 0.01
assert not torch.equal(dropout_mask, dropout_mask1)
g = torch.randn_like(out) / batch_size
if has_weight1:
out = out * F.gelu(out1)
out_pt = out_pt * F.gelu(out1_pt)
out_ref = out_ref * F.gelu(out1_ref)
if not prenorm:
out.backward(g)
out_pt.backward(g)
@@ -152,9 +228,15 @@ def test_layer_norm(
assert allclose(x0.grad, x0_pt.grad, x0_ref.grad)
if has_residual:
assert allclose(res.grad, res_pt.grad, res_ref.grad)
if has_x1:
assert allclose(x1.grad, x1_pt.grad, x1_ref.grad)
assert allclose(weight.grad, weight_pt.grad, weight_ref.grad)
if bias is not None:
assert allclose(bias.grad, bias_pt.grad, bias_ref.grad)
if has_weight1:
assert allclose(weight1.grad, weight1_pt.grad, weight1_ref.grad)
if bias1 is not None:
assert allclose(bias1.grad, bias1_pt.grad, bias1_ref.grad)
@pytest.mark.parametrize("prenorm", [True, False])
......
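One last editorial sketch (not part of the commit), showing how the updated test replays the kernel's dropout pattern in the reference implementation: with return_dropout_mask=True the forward pass also returns the keep-masks for x0 and x1, and layer_norm_ref accepts them via dropout_mask / dropout_mask1, so both implementations drop exactly the same elements. Shapes and dtypes are assumptions.

import torch
from flash_attn.ops.triton.layernorm import layer_norm_fn, layer_norm_ref

batch, seqlen, hidden = 2, 512, 4096
x0 = torch.randn(batch, seqlen, hidden, device="cuda", dtype=torch.float16)
x1 = torch.randn_like(x0)
weight = torch.randn(hidden, device="cuda", dtype=torch.float32)
weight1 = torch.randn(hidden, device="cuda", dtype=torch.float32)

# The Triton path returns its dropout keep-masks ...
out, out1, dropout_mask, dropout_mask1 = layer_norm_fn(
    x0, weight, None, x1=x1, weight1=weight1, dropout_p=0.27, return_dropout_mask=True
)
# ... which the reference path reuses, so the two results agree up to numerical tolerance.
out_ref, out1_ref = layer_norm_ref(
    x0, weight, None, x1=x1, weight1=weight1,
    dropout_p=0.27, dropout_mask=dropout_mask, dropout_mask1=dropout_mask1,
)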