Commit 665b55e2 authored by Tri Dao

[LayerNorm] Implement parallel layer norm in Triton

parent aa5c6438
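
For reference, a minimal usage sketch of the API this commit extends (not part of the diff itself): layer_norm_fn gains x1, weight1, and bias1 arguments, and when weight1 is given it returns a second output computed from the same row statistics. The import path follows the test file below; shapes, dtypes, and the dropout rate are illustrative.

import torch
from flash_attn.ops.triton.layernorm import layer_norm_fn

batch, seqlen, hidden = 2, 512, 4096
x = torch.randn(batch, seqlen, hidden, device="cuda", dtype=torch.float16)
x1 = torch.randn_like(x)            # second input stream, normalized in parallel
weight = torch.randn(hidden, device="cuda", dtype=torch.float32)
weight1 = torch.randn_like(weight)  # second set of norm parameters

# Dropout (if any) is applied to x and x1 separately, the two streams are summed,
# and the sum is normalized once; out uses (weight, bias), out1 uses (weight1, bias1).
# With prenorm=True the pre-normalization sum is returned as well.
out, out1, residual_out = layer_norm_fn(
    x, weight, None,
    x1=x1, weight1=weight1, bias1=None,
    dropout_p=0.1, prenorm=True,
)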
...@@ -21,20 +21,28 @@ def layer_norm_ref(
weight,
bias,
residual=None,
x1=None,
weight1=None,
bias1=None,
eps=1e-6,
dropout_p=0.0,
rowscale=None,
prenorm=False,
dropout_mask=None,
dropout_mask1=None,
upcast=False,
):
dtype = x.dtype
if upcast:
x = x.float()
weight = weight.float()
bias = bias.float() if bias is not None else None
residual = residual.float() if residual is not None else residual
x1 = x1.float() if x1 is not None else None
weight1 = weight1.float() if weight1 is not None else None
bias1 = bias1.float() if bias1 is not None else None
if x1 is not None:
assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
if rowscale is not None:
x = x * rowscale[..., None]
if dropout_p > 0.0:
...@@ -42,12 +50,25 @@ def layer_norm_ref(
x = x.masked_fill(~dropout_mask, 0.0) / (1.0 - dropout_p)
else:
x = F.dropout(x, p=dropout_p)
if x1 is not None:
if dropout_mask1 is not None:
x1 = x1.masked_fill(~dropout_mask1, 0.0) / (1.0 - dropout_p)
else:
x1 = F.dropout(x1, p=dropout_p)
if x1 is not None:
x = x + x1
if residual is not None:
x = (x + residual).to(x.dtype)
out = F.layer_norm(x.to(weight.dtype), x.shape[-1:], weight=weight, bias=bias, eps=eps).to(
dtype
)
if weight1 is None:
return out if not prenorm else (out, x)
else:
out1 = F.layer_norm(
x.to(weight1.dtype), x.shape[-1:], weight=weight1, bias=bias1, eps=eps
).to(dtype)
return (out, out1) if not prenorm else (out, out1, x)
def rms_norm_ref(
...@@ -55,20 +76,28 @@ def rms_norm_ref(
weight,
bias,
residual=None,
x1=None,
weight1=None,
bias1=None,
eps=1e-6,
dropout_p=0.0,
rowscale=None,
prenorm=False,
dropout_mask=None,
dropout_mask1=None,
upcast=False,
):
dtype = x.dtype
if upcast:
x = x.float()
weight = weight.float()
bias = bias.float() if bias is not None else None
residual = residual.float() if residual is not None else residual
x1 = x1.float() if x1 is not None else None
weight1 = weight1.float() if weight1 is not None else None
bias1 = bias1.float() if bias1 is not None else None
if x1 is not None:
assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
if rowscale is not None:
x = x * rowscale[..., None]
if dropout_p > 0.0:
...@@ -76,12 +105,24 @@ def rms_norm_ref(
x = x.masked_fill(~dropout_mask, 0.0) / (1.0 - dropout_p)
else:
x = F.dropout(x, p=dropout_p)
if x1 is not None:
if dropout_mask1 is not None:
x1 = x1.masked_fill(~dropout_mask1, 0.0) / (1.0 - dropout_p)
else:
x1 = F.dropout(x1, p=dropout_p)
if x1 is not None:
x = x + x1
if residual is not None:
x = (x + residual).to(x.dtype)
rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps)
out = ((x * rstd * weight) + bias if bias is not None else (x * rstd * weight)).to(dtype)
if weight1 is None:
return out if not prenorm else (out, x)
else:
out1 = ((x * rstd * weight1) + bias1 if bias1 is not None else (x * rstd * weight1)).to(
dtype
)
return (out, out1) if not prenorm else (out, out1, x)
@triton.autotune(
...@@ -97,6 +138,9 @@ def rms_norm_ref(
)
# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
# @triton.heuristics({"HAS_RESIDUAL": lambda args: args["RESIDUAL"] is not None})
@triton.heuristics({"HAS_X1": lambda args: args["X1"] is not None})
@triton.heuristics({"HAS_W1": lambda args: args["W1"] is not None})
@triton.heuristics({"HAS_B1": lambda args: args["B1"] is not None})
@triton.jit
def _layer_norm_fwd_1pass_kernel(
X, # pointer to the input
...@@ -104,6 +148,10 @@ def _layer_norm_fwd_1pass_kernel(
W, # pointer to the weights
B, # pointer to the biases
RESIDUAL, # pointer to the residual
X1,
W1,
B1,
Y1,
RESIDUAL_OUT, # pointer to the residual
ROWSCALE,
SEEDS, # Dropout seeds for each row
...@@ -114,6 +162,9 @@ def _layer_norm_fwd_1pass_kernel(
stride_y_row,
stride_res_row,
stride_res_out_row,
stride_x1_row,
stride_y1_row,
M, # number of rows in X
N, # number of columns in X
eps, # epsilon to avoid division by zero
dropout_p, # Dropout probability
...@@ -125,6 +176,9 @@ def _layer_norm_fwd_1pass_kernel(
HAS_DROPOUT: tl.constexpr,
STORE_DROPOUT_MASK: tl.constexpr,
HAS_ROWSCALE: tl.constexpr,
HAS_X1: tl.constexpr,
HAS_W1: tl.constexpr,
HAS_B1: tl.constexpr,
):
# Map the program id to the row of X and Y it should compute.
row = tl.program_id(0)
...@@ -134,6 +188,10 @@ def _layer_norm_fwd_1pass_kernel(
RESIDUAL += row * stride_res_row
if STORE_RESIDUAL_OUT:
RESIDUAL_OUT += row * stride_res_out_row
if HAS_X1:
X1 += row * stride_x1_row
if HAS_W1:
Y1 += row * stride_y1_row
# Compute mean and variance
cols = tl.arange(0, BLOCK_N)
x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
...@@ -147,6 +205,21 @@ def _layer_norm_fwd_1pass_kernel(
x = tl.where(keep_mask, x / (1.0 - dropout_p), 0.0)
if STORE_DROPOUT_MASK:
tl.store(DROPOUT_MASK + row * N + cols, keep_mask, mask=cols < N)
if HAS_X1:
x1 = tl.load(X1 + cols, mask=cols < N, other=0.0).to(tl.float32)
if HAS_ROWSCALE:
rowscale = tl.load(ROWSCALE + M + row).to(tl.float32)
x1 *= rowscale
if HAS_DROPOUT:
# Compute dropout mask
# 7 rounds is good enough, and reduces register pressure
keep_mask = (
tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
)
x1 = tl.where(keep_mask, x1 / (1.0 - dropout_p), 0.0)
if STORE_DROPOUT_MASK:
tl.store(DROPOUT_MASK + (M + row) * N + cols, keep_mask, mask=cols < N)
x += x1
if HAS_RESIDUAL:
residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32)
x += residual
...@@ -171,6 +244,12 @@ def _layer_norm_fwd_1pass_kernel(
y = x_hat * w + b if HAS_BIAS else x_hat * w
# Write output
tl.store(Y + cols, y, mask=mask)
if HAS_W1:
w1 = tl.load(W1 + cols, mask=mask).to(tl.float32)
if HAS_B1:
b1 = tl.load(B1 + cols, mask=mask).to(tl.float32)
y1 = x_hat * w1 + b1 if HAS_B1 else x_hat * w1
tl.store(Y1 + cols, y1, mask=mask)
def _layer_norm_fwd(
...@@ -179,6 +258,9 @@ def _layer_norm_fwd(
bias,
eps,
residual=None,
x1=None,
weight1=None,
bias1=None,
dropout_p=0.0,
rowscale=None,
out_dtype=None,
...@@ -198,17 +280,33 @@ def _layer_norm_fwd(
if bias is not None:
assert bias.stride(-1) == 1
assert bias.shape == (N,)
if x1 is not None:
assert x1.shape == x.shape
assert rowscale is None
assert x1.stride(-1) == 1
if weight1 is not None:
assert weight1.shape == (N,)
assert weight1.stride(-1) == 1
if bias1 is not None:
assert bias1.shape == (N,)
assert bias1.stride(-1) == 1
if rowscale is not None:
assert rowscale.is_contiguous()
assert rowscale.shape == (M,)
# allocate output
y = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)
assert y.stride(-1) == 1
if weight1 is not None:
y1 = torch.empty_like(y)
assert y1.stride(-1) == 1
else:
y1 = None
if (
residual is not None
or (residual_dtype is not None and residual_dtype != x.dtype)
or dropout_p > 0.0
or rowscale is not None
or x1 is not None
):
residual_out = torch.empty(
M, N, device=x.device, dtype=residual_dtype if residual_dtype is not None else x.dtype
...@@ -219,11 +317,13 @@ def _layer_norm_fwd(
mean = torch.empty((M,), dtype=torch.float32, device="cuda") if not is_rms_norm else None
rstd = torch.empty((M,), dtype=torch.float32, device="cuda")
if dropout_p > 0.0:
seeds = torch.randint(
2**32, (M if x1 is None else 2 * M,), device=x.device, dtype=torch.int64
)
else:
seeds = None
if return_dropout_mask and dropout_p > 0.0:
dropout_mask = torch.empty(M if x1 is None else 2 * M, N, device=x.device, dtype=torch.bool)
else:
dropout_mask = None
# Less than 64KB per feature: enqueue fused kernel
...@@ -231,7 +331,6 @@ def _layer_norm_fwd(
BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
if N > BLOCK_N:
raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
# heuristics for number of warps
with torch.cuda.device(x.device.index):
_layer_norm_fwd_1pass_kernel[(M,)](
x,
...@@ -239,6 +338,10 @@ def _layer_norm_fwd(
weight,
bias,
residual,
x1,
weight1,
bias1,
y1,
residual_out,
rowscale,
seeds,
...@@ -249,6 +352,9 @@ def _layer_norm_fwd(
y.stride(0),
residual.stride(0) if residual is not None else 0,
residual_out.stride(0) if residual_out is not None else 0,
x1.stride(0) if x1 is not None else 0,
y1.stride(0) if y1 is not None else 0,
M,
N,
eps,
dropout_p,
...@@ -262,7 +368,20 @@ def _layer_norm_fwd(
rowscale is not None,
)
# residual_out is None if residual is None and residual_dtype == input_dtype and dropout_p == 0.0
if dropout_mask is not None and x1 is not None:
dropout_mask, dropout_mask1 = dropout_mask.tensor_split(2, dim=0)
else:
dropout_mask1 = None
return (
y,
y1,
mean,
rstd,
residual_out if residual_out is not None else x,
seeds,
dropout_mask,
dropout_mask1,
)
@triton.autotune(
...@@ -280,6 +399,9 @@ def _layer_norm_fwd(
# @triton.heuristics({"HAS_DRESIDUAL": lambda args: args["DRESIDUAL"] is not None})
# @triton.heuristics({"STORE_DRESIDUAL": lambda args: args["DRESIDUAL_IN"] is not None})
@triton.heuristics({"HAS_ROWSCALE": lambda args: args["ROWSCALE"] is not None})
@triton.heuristics({"HAS_DY1": lambda args: args["DY1"] is not None})
@triton.heuristics({"HAS_DX1": lambda args: args["DX1"] is not None})
@triton.heuristics({"HAS_B1": lambda args: args["DB1"] is not None})
@triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None})
@triton.jit
def _layer_norm_bwd_kernel(
...@@ -292,6 +414,11 @@ def _layer_norm_bwd_kernel(
DW, # pointer to the partial sum of weights gradient
DB, # pointer to the partial sum of biases gradient
DRESIDUAL,
W1,
DY1,
DX1,
DW1,
DB1,
DRESIDUAL_IN,
ROWSCALE,
SEEDS,
...@@ -302,6 +429,8 @@ def _layer_norm_bwd_kernel(
stride_dy_row,
stride_dx_row,
stride_dres_row,
stride_dy1_row,
stride_dx1_row,
stride_dres_in_row,
M, # number of rows in X
N, # number of columns in X
...@@ -315,6 +444,9 @@ def _layer_norm_bwd_kernel(
HAS_BIAS: tl.constexpr,
HAS_DROPOUT: tl.constexpr,
HAS_ROWSCALE: tl.constexpr,
HAS_DY1: tl.constexpr,
HAS_DX1: tl.constexpr,
HAS_B1: tl.constexpr,
RECOMPUTE_OUTPUT: tl.constexpr,
):
# Map the program id to the elements of X, DX, and DY it should compute.
...@@ -331,19 +463,31 @@ def _layer_norm_bwd_kernel(
DRESIDUAL_IN += row_start * stride_dres_in_row
DY += row_start * stride_dy_row
DX += row_start * stride_dx_row
if HAS_DY1:
DY1 += row_start * stride_dy1_row
if HAS_DX1:
DX1 += row_start * stride_dx1_row
if RECOMPUTE_OUTPUT:
Y += row_start * stride_y_row
w = tl.load(W + cols, mask=mask).to(tl.float32)
if RECOMPUTE_OUTPUT and HAS_BIAS:
b = tl.load(B + cols, mask=mask, other=0.0).to(tl.float32)
if HAS_DY1:
w1 = tl.load(W1 + cols, mask=mask).to(tl.float32)
dw = tl.zeros((BLOCK_N,), dtype=tl.float32)
if HAS_BIAS:
db = tl.zeros((BLOCK_N,), dtype=tl.float32)
if HAS_DY1:
dw1 = tl.zeros((BLOCK_N,), dtype=tl.float32)
if HAS_B1:
db1 = tl.zeros((BLOCK_N,), dtype=tl.float32)
row_end = min((row_block_id + 1) * rows_per_program, M)
for row in range(row_start, row_end):
# Load data to SRAM
x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)
dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)
if HAS_DY1:
dy1 = tl.load(DY1 + cols, mask=mask, other=0).to(tl.float32)
if not IS_RMS_NORM:
mean = tl.load(Mean + row)
rstd = tl.load(Rstd + row)
...@@ -357,6 +501,11 @@ def _layer_norm_bwd_kernel(
dw += dy * xhat
if HAS_BIAS:
db += dy
if HAS_DY1:
wdy += w1 * dy1
dw1 += dy1 * xhat
if HAS_B1:
db1 += dy1
if not IS_RMS_NORM:
c1 = tl.sum(xhat * wdy, axis=0) / N
c2 = tl.sum(wdy, axis=0) / N
...@@ -370,6 +519,15 @@ def _layer_norm_bwd_kernel(
# Write dx
if STORE_DRESIDUAL:
tl.store(DRESIDUAL_IN + cols, dx, mask=mask)
if HAS_DX1:
if HAS_DROPOUT:
keep_mask = (
tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
)
dx1 = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0)
else:
dx1 = dx
tl.store(DX1 + cols, dx1, mask=mask)
if HAS_DROPOUT:
keep_mask = tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
dx = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0)
...@@ -387,9 +545,17 @@ def _layer_norm_bwd_kernel(
Y += stride_y_row
DY += stride_dy_row
DX += stride_dx_row
if HAS_DY1:
DY1 += stride_dy1_row
if HAS_DX1:
DX1 += stride_dx1_row
tl.store(DW + row_block_id * N + cols, dw, mask=mask)
if HAS_BIAS:
tl.store(DB + row_block_id * N + cols, db, mask=mask)
if HAS_DY1:
tl.store(DW1 + row_block_id * N + cols, dw1, mask=mask)
if HAS_B1:
tl.store(DB1 + row_block_id * N + cols, db1, mask=mask)
def _layer_norm_bwd(
...@@ -401,10 +567,14 @@ def _layer_norm_bwd(
mean,
rstd,
dresidual=None,
dy1=None,
weight1=None,
bias1=None,
seeds=None,
dropout_p=0.0,
rowscale=None,
has_residual=False,
has_x1=False,
is_rms_norm=False,
x_dtype=None,
recompute_output=False,
...@@ -421,9 +591,19 @@ def _layer_norm_bwd(
if bias is not None:
assert bias.stride(-1) == 1
assert bias.shape == (N,)
if dy1 is not None:
assert weight1 is not None
assert dy1.shape == dy.shape
assert dy1.stride(-1) == 1
if weight1 is not None:
assert weight1.shape == (N,)
assert weight1.stride(-1) == 1
if bias1 is not None:
assert bias1.shape == (N,)
assert bias1.stride(-1) == 1
if seeds is not None:
assert seeds.is_contiguous()
assert seeds.shape == (M if not has_x1 else M * 2,)
if rowscale is not None:
assert rowscale.is_contiguous()
assert rowscale.shape == (M,)
...@@ -435,10 +615,14 @@ def _layer_norm_bwd(
)
dresidual_in = (
torch.empty_like(x)
if has_residual
and (dx.dtype != x.dtype or dropout_p > 0.0 or rowscale is not None or has_x1)
else None else None
) )
dx1 = torch.empty_like(dx) if (has_x1 and dropout_p > 0.0) else None
y = torch.empty(M, N, dtype=dy.dtype, device=dy.device) if recompute_output else None y = torch.empty(M, N, dtype=dy.dtype, device=dy.device) if recompute_output else None
if recompute_output:
assert weight1 is None, "recompute_output is not supported with parallel LayerNorm"
# Less than 64KB per feature: enqueue fused kernel
MAX_FUSED_SIZE = 65536 // x.element_size()
...@@ -452,6 +636,8 @@ def _layer_norm_bwd(
if bias is not None
else None
)
_dw1 = torch.empty_like(_dw) if weight1 is not None else None
_db1 = torch.empty_like(_db) if bias1 is not None else None
rows_per_program = math.ceil(M / sm_count)
grid = (sm_count,)
with torch.cuda.device(x.device.index):
...@@ -465,6 +651,11 @@ def _layer_norm_bwd(
_dw,
_db,
dresidual,
weight1,
dy1,
dx1,
_dw1,
_db1,
dresidual_in,
rowscale,
seeds,
...@@ -475,6 +666,8 @@ def _layer_norm_bwd(
dy.stride(0),
dx.stride(0),
dresidual.stride(0) if dresidual is not None else 0,
dy1.stride(0) if dy1 is not None else 0,
dx1.stride(0) if dx1 is not None else 0,
dresidual_in.stride(0) if dresidual_in is not None else 0,
M,
N,
...@@ -490,10 +683,18 @@ def _layer_norm_bwd(
)
dw = _dw.sum(0).to(weight.dtype)
db = _db.sum(0).to(bias.dtype) if bias is not None else None
dw1 = _dw1.sum(0).to(weight1.dtype) if weight1 is not None else None
db1 = _db1.sum(0).to(bias1.dtype) if bias1 is not None else None
# Don't need to compute dresidual_in separately in this case
if has_residual and dx.dtype == x.dtype and dropout_p == 0.0 and rowscale is None:
dresidual_in = dx
if has_x1 and dropout_p == 0.0:
dx1 = dx
return (
(dx, dw, db, dresidual_in, dx1, dw1, db1)
if not recompute_output
else (dx, dw, db, dresidual_in, dx1, dw1, db1, y)
)
class LayerNormFn(torch.autograd.Function):
...@@ -504,6 +705,9 @@ class LayerNormFn(torch.autograd.Function):
weight,
bias,
residual=None,
x1=None,
weight1=None,
bias1=None,
eps=1e-6,
dropout_p=0.0,
rowscale=None,
...@@ -522,9 +726,19 @@ class LayerNormFn(torch.autograd.Function):
residual = residual.reshape(-1, residual.shape[-1])
if residual.stride(-1) != 1:
residual = residual.contiguous()
if x1 is not None:
assert x1.shape == x_shape_og
assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
x1 = x1.reshape(-1, x1.shape[-1])
if x1.stride(-1) != 1:
x1 = x1.contiguous()
weight = weight.contiguous()
if bias is not None:
bias = bias.contiguous()
if weight1 is not None:
weight1 = weight1.contiguous()
if bias1 is not None:
bias1 = bias1.contiguous()
if rowscale is not None:
rowscale = rowscale.reshape(-1).contiguous()
residual_dtype = (
...@@ -532,41 +746,71 @@ class LayerNormFn(torch.autograd.Function):
if residual is not None
else (torch.float32 if residual_in_fp32 else None)
)
y, y1, mean, rstd, residual_out, seeds, dropout_mask, dropout_mask1 = _layer_norm_fwd(
x,
weight,
bias,
eps,
residual,
x1,
weight1,
bias1,
dropout_p=dropout_p,
rowscale=rowscale,
residual_dtype=residual_dtype,
is_rms_norm=is_rms_norm,
return_dropout_mask=return_dropout_mask,
)
ctx.save_for_backward(
residual_out, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd
)
ctx.x_shape_og = x_shape_og
ctx.eps = eps
ctx.dropout_p = dropout_p
ctx.is_rms_norm = is_rms_norm
ctx.has_residual = residual is not None
ctx.has_x1 = x1 is not None
ctx.prenorm = prenorm
ctx.x_dtype = x.dtype
y = y.reshape(x_shape_og)
y1 = y1.reshape(x_shape_og) if y1 is not None else None
residual_out = residual_out.reshape(x_shape_og) if residual_out is not None else None
dropout_mask = dropout_mask.reshape(x_shape_og) if dropout_mask is not None else None
dropout_mask1 = dropout_mask1.reshape(x_shape_og) if dropout_mask1 is not None else None
if not return_dropout_mask:
if weight1 is None:
return y if not prenorm else (y, residual_out)
else:
return (y, y1) if not prenorm else (y, y1, residual_out)
else:
if weight1 is None:
return (
(y, dropout_mask, dropout_mask1)
if not prenorm
else (y, residual_out, dropout_mask, dropout_mask1)
)
else:
return (
(y, y1, dropout_mask, dropout_mask1)
if not prenorm
else (y, y1, residual_out, dropout_mask, dropout_mask1)
)
@staticmethod
def backward(ctx, dy, *args):
x, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd = ctx.saved_tensors
dy = dy.reshape(-1, dy.shape[-1])
if dy.stride(-1) != 1:
dy = dy.contiguous()
assert dy.shape == x.shape
if weight1 is not None:
dy1, args = args[0], args[1:]
dy1 = dy1.reshape(-1, dy1.shape[-1])
if dy1.stride(-1) != 1:
dy1 = dy1.contiguous()
assert dy1.shape == x.shape
else:
dy1 = None
if ctx.prenorm:
dresidual = args[0]
dresidual = dresidual.reshape(-1, dresidual.shape[-1])
...@@ -575,7 +819,7 @@ class LayerNormFn(torch.autograd.Function):
assert dresidual.shape == x.shape
else:
dresidual = None
dx, dw, db, dresidual_in, dx1, dw1, db1 = _layer_norm_bwd(
dy,
x,
weight,
...@@ -584,10 +828,14 @@ class LayerNormFn(torch.autograd.Function): ...@@ -584,10 +828,14 @@ class LayerNormFn(torch.autograd.Function):
mean,
rstd,
dresidual,
dy1,
weight1,
bias1,
seeds,
ctx.dropout_p,
rowscale,
ctx.has_residual,
ctx.has_x1,
ctx.is_rms_norm,
x_dtype=ctx.x_dtype,
)
...@@ -596,6 +844,9 @@ class LayerNormFn(torch.autograd.Function):
dw,
db,
dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None,
dx1.reshape(ctx.x_shape_og) if dx1 is not None else None,
dw1,
db1,
None,
None,
None,
...@@ -611,6 +862,9 @@ def layer_norm_fn(
weight,
bias,
residual=None,
x1=None,
weight1=None,
bias1=None,
eps=1e-6,
dropout_p=0.0,
rowscale=None,
...@@ -624,6 +878,9 @@ def layer_norm_fn(
weight,
bias,
residual,
x1,
weight1,
bias1,
eps,
dropout_p,
rowscale,
...@@ -639,6 +896,9 @@ def rms_norm_fn(
weight,
bias,
residual=None,
x1=None,
weight1=None,
bias1=None,
eps=1e-6,
dropout_p=0.0,
rowscale=None,
...@@ -651,6 +911,9 @@ def rms_norm_fn(
weight,
bias,
residual,
x1,
weight1,
bias1,
eps,
dropout_p,
rowscale,
...@@ -662,11 +925,15 @@ def rms_norm_fn(
class RMSNorm(torch.nn.Module):
def __init__(self, hidden_size, eps=1e-5, dropout_p=0.0, device=None, dtype=None):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.eps = eps
if dropout_p > 0.0:
self.drop = torch.nn.Dropout(dropout_p)
else:
self.drop = None
self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
self.register_parameter("bias", None)
self.reset_parameters()
...@@ -681,7 +948,7 @@ class RMSNorm(torch.nn.Module):
self.bias,
residual=residual,
eps=self.eps,
dropout_p=self.drop.p if self.drop is not None and self.training else 0.0,
prenorm=prenorm,
residual_in_fp32=residual_in_fp32,
)
......
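
In reference form, the parallel path added above reduces to the sketch below (dropout, rowscale, and the optional residual omitted): the two input streams are summed, normalized once, and the same normalized activations are scaled by the two parameter sets. This is a restatement of layer_norm_ref for illustration; the helper name is made up and is not part of the commit.

import torch
import torch.nn.functional as F

def parallel_layer_norm_sketch(x, x1, weight, bias, weight1, bias1, eps=1e-6):
    # Both outputs share the statistics of the combined input.
    z = x + x1
    out = F.layer_norm(z, z.shape[-1:], weight=weight, bias=bias, eps=eps)
    out1 = F.layer_norm(z, z.shape[-1:], weight=weight1, bias=bias1, eps=eps)
    return out, out1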
...@@ -16,12 +16,16 @@ from flash_attn.ops.triton.layernorm import (
is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8
@pytest.mark.parametrize("has_weight1", [False, True])
# @pytest.mark.parametrize("has_weight1", [True])
@pytest.mark.parametrize("has_x1", [False, True])
# @pytest.mark.parametrize("has_x1", [False])
@pytest.mark.parametrize("has_rowscale", [False, True])
# @pytest.mark.parametrize("has_rowscale", [False])
@pytest.mark.parametrize("dropout_p", [0.0, 0.27])
# @pytest.mark.parametrize("dropout_p", [0.0])
@pytest.mark.parametrize("prenorm", [True, False])
# @pytest.mark.parametrize("prenorm", [False])
@pytest.mark.parametrize("is_rms_norm", [False, True])
# @pytest.mark.parametrize("is_rms_norm", [True])
@pytest.mark.parametrize("has_residual", [True, False])
...@@ -48,7 +52,11 @@ def test_layer_norm(
prenorm,
dropout_p,
has_rowscale,
has_x1,
has_weight1,
): ):
if has_rowscale and has_x1:
pytest.skip("Not supported")
device = "cuda"
if any(x == torch.bfloat16 for x in [input_dtype, residual_dtype, weight_dtype]):
atol = 5e-2
...@@ -62,9 +70,16 @@ def test_layer_norm(
seqlen = 512
layer_norm_ref_fn = layer_norm_ref if not is_rms_norm else rms_norm_ref
allclose = (
# Sometimes x0_pt.grad is NaN
lambda x, x_pt, x_ref, atol=atol: (x - x_ref).abs().max()
<= 2 * (x_pt[~x_pt.isnan()] - x_ref[~x_pt.isnan()]).abs().max() + atol
or (
# Sometimes x_pt and x_ref are the same (e.g. bfloat16), so we perturb them a bit
# by multiplying and dividing by 0.3
(x_pt[~x_pt.isnan()] - x_ref[~x_pt.isnan()]).abs().max() == 0.0
and (x - x_ref).abs().max()
<= 2 * (x_pt[~x_pt.isnan()] * 0.3 / 0.3 - x_ref[~x_pt.isnan()]).abs().max() + atol
)
)
x0 = torch.randn(
batch_size, seqlen, hidden_size, device=device, dtype=input_dtype, requires_grad=True
...@@ -86,8 +101,35 @@ def test_layer_norm(
weight_ref = weight.detach().clone().requires_grad_()
bias_pt = bias.detach().clone().requires_grad_() if bias is not None else None
bias_ref = bias.detach().clone().requires_grad_() if bias is not None else None
if has_x1:
x1 = torch.randn_like(x0, dtype=input_dtype, requires_grad=True)
x1_pt = x1.detach().clone().requires_grad_()
x1_ref = x1.detach().clone().requires_grad_()
else:
x1, x1_pt, x1_ref = None, None, None
if has_weight1:
weight1 = torch.randn(
hidden_size, device=device, dtype=weight_dtype, requires_grad=True
)
weight1_pt = weight1.detach().clone().requires_grad_()
weight1_ref = weight1.detach().clone().requires_grad_()
if not is_rms_norm:
bias1 = torch.randn(
hidden_size, device=device, dtype=weight_dtype, requires_grad=True
)
else:
bias1 = None
bias1_pt = bias1.detach().clone().requires_grad_() if bias1 is not None else None
bias1_ref = bias1.detach().clone().requires_grad_() if bias1 is not None else None
else:
weight1, weight1_pt, weight1_ref = None, None, None
bias1, bias1_pt, bias1_ref = None, None, None
rowscale = (
torch.randn(batch_size, seqlen, dtype=input_dtype, device=device)
if has_rowscale
else None
)
residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
out, *rest = layer_norm_fn(
...@@ -95,6 +137,9 @@ def test_layer_norm(
weight,
bias,
residual=res,
x1=x1,
weight1=weight1,
bias1=bias1,
eps=1e-6,
dropout_p=dropout_p,
rowscale=rowscale,
...@@ -103,44 +148,75 @@ def test_layer_norm(
is_rms_norm=is_rms_norm,
return_dropout_mask=True,
)
dropout_mask = rest[-2] if dropout_p > 0.0 else None
dropout_mask1 = rest[-1] if dropout_p > 0.0 and x1 is not None else None
out_pt = layer_norm_ref_fn(
x0_pt,
weight_pt,
bias_pt,
residual=res_pt,
x1=x1_pt,
weight1=weight1_pt,
bias1=bias1_pt,
eps=1e-6,
dropout_p=dropout_p,
rowscale=rowscale,
prenorm=prenorm,
dropout_mask=dropout_mask,
dropout_mask1=dropout_mask1,
)
out_ref = layer_norm_ref_fn(
x0_ref,
weight_ref,
bias_ref,
residual=res_ref,
x1=x1_ref,
weight1=weight1_ref,
bias1=bias1_ref,
eps=1e-6,
dropout_p=dropout_p,
rowscale=rowscale,
prenorm=prenorm,
dropout_mask=dropout_mask,
dropout_mask1=dropout_mask1,
upcast=True,
)
if not has_weight1:
if prenorm:
residual = rest[0]
out_pt, residual_pt = out_pt
out_ref, residual_ref = out_ref
out1, out1_pt, out1_ref = None, None, None
else:
out1 = rest.pop(0)
if prenorm:
residual = rest[0]
out_pt, out1_pt, residual_pt = out_pt
out_ref, out1_ref, residual_ref = out_ref
else:
out_pt, out1_pt = out_pt
out_ref, out1_ref = out_ref
assert out.dtype == input_dtype
if prenorm:
assert residual.dtype == residual_dtype
assert allclose(residual, residual_pt, residual_ref)
assert allclose(out, out_pt, out_ref)
if out1 is not None:
assert out1.dtype == input_dtype
assert allclose(out1, out1_pt, out1_ref)
if dropout_mask is not None:
dropout_fraction = 1.0 - dropout_mask.float().mean()
assert abs(dropout_fraction - dropout_p) < 0.01
if dropout_mask1 is not None:
dropout_fraction = 1.0 - dropout_mask1.float().mean()
assert abs(dropout_fraction - dropout_p) < 0.01
assert not torch.equal(dropout_mask, dropout_mask1)
g = torch.randn_like(out) / batch_size
if has_weight1:
out = out * F.gelu(out1)
out_pt = out_pt * F.gelu(out1_pt)
out_ref = out_ref * F.gelu(out1_ref)
if not prenorm:
out.backward(g)
out_pt.backward(g)
...@@ -152,9 +228,15 @@ def test_layer_norm(
assert allclose(x0.grad, x0_pt.grad, x0_ref.grad)
if has_residual:
assert allclose(res.grad, res_pt.grad, res_ref.grad)
if has_x1:
assert allclose(x1.grad, x1_pt.grad, x1_ref.grad)
assert allclose(weight.grad, weight_pt.grad, weight_ref.grad)
if bias is not None:
assert allclose(bias.grad, bias_pt.grad, bias_ref.grad)
if has_weight1:
assert allclose(weight1.grad, weight1_pt.grad, weight1_ref.grad)
if bias1 is not None:
assert allclose(bias1.grad, bias1_pt.grad, bias1_ref.grad)
@pytest.mark.parametrize("prenorm", [True, False])
......