Commit 393882bc authored by Tri Dao's avatar Tri Dao
Browse files

[LayerNorm] Implement LN with parallel residual, support dim 8k

parent 009a3e71
......@@ -18,6 +18,11 @@ try:
except ImportError:
dropout_add_layer_norm = None
try:
from flash_attn.ops.layer_norm import dropout_add_layer_norm_parallel_residual
except ImportError:
dropout_add_layer_norm_parallel_residual = None
class Block(nn.Module):
......@@ -64,7 +69,7 @@ class Block(nn.Module):
self.norm2 = norm_cls(dim)
if self.fused_dropout_add_ln:
assert dropout_add_layer_norm is not None, 'dropout_add_ln is not installed'
assert dropout_add_layer_norm is not None, 'dropout_layer_norm is not installed'
assert isinstance(self.norm1, nn.LayerNorm) and isinstance(self.dropout1, nn.Dropout)
# TD [2023-01-07]: TODO: During training, if sequence_parallel is False and dropout != 0.0,
......@@ -214,7 +219,6 @@ class ParallelBlock(nn.Module):
super().__init__()
self.tied_norm = tied_norm
self.fused_dropout_add_ln = fused_dropout_add_ln
assert not self.fused_dropout_add_ln, 'This is not implemented for ParallelBlock yet'
self.residual_in_fp32 = residual_in_fp32
if mixer_cls is None:
mixer_cls = partial(MHA, num_heads=dim // 64)
......@@ -229,7 +233,7 @@ class ParallelBlock(nn.Module):
self.norm2 = norm_cls(dim)
if self.fused_dropout_add_ln:
assert dropout_add_layer_norm is not None, 'dropout_add_ln is not installed'
assert dropout_add_layer_norm_parallel_residual is not None, 'dropout_layer_norm is not installed'
assert isinstance(self.norm1, nn.LayerNorm) and isinstance(self.dropout1, nn.Dropout)
# TD [2023-01-07]: TODO: During training, if sequence_parallel is False and dropout != 0.0,
......@@ -262,19 +266,30 @@ class ParallelBlock(nn.Module):
hidden_states2: the output of the previous MLP layer (if None, will use hidden_states1).
residual.
"""
dropped1 = self.dropout1(hidden_states1)
# For the very 1st block, we only want 1 dropout, not two different dropouts
if hidden_states2 is not None:
dropped2 = self.dropout2(hidden_states2)
residual = ((residual + dropped1 + dropped2)
if residual is not None else dropped1 + dropped2)
if not self.fused_dropout_add_ln:
dropped1 = self.dropout1(hidden_states1)
# For the very 1st block, we only want 1 dropout, not two different dropouts
if hidden_states2 is not None:
dropped2 = self.dropout2(hidden_states2)
residual = ((residual + dropped1 + dropped2)
if residual is not None else dropped1 + dropped2)
else:
residual = (residual + dropped1) if residual is not None else dropped1
hidden_states1 = self.norm1(residual.to(dtype=self.norm1.weight.dtype))
hidden_states2 = (self.norm2(residual.to(dtype=self.norm2.weight.dtype))
if not self.tied_norm else hidden_states1)
if self.residual_in_fp32:
residual = residual.to(torch.float32)
else:
residual = (residual + dropped1) if residual is not None else dropped1
hidden_states1 = self.norm1(residual.to(dtype=self.norm1.weight.dtype))
hidden_states2 = (self.norm2(residual.to(dtype=self.norm2.weight.dtype))
if not self.tied_norm else hidden_states1)
if self.residual_in_fp32:
residual = residual.to(torch.float32)
weight2, bias2 = ((self.norm2.weight, self.norm2.bias)
if not self.tied_norm else (None, None))
hidden_states1, hidden_states2, residual = dropout_add_layer_norm_parallel_residual(
hidden_states1, hidden_states2, residual, self.norm1.weight, self.norm1.bias,
weight2, bias2, self.dropout1.p if self.training else 0.0, self.norm1.eps,
prenorm=True, residual_in_fp32=self.residual_in_fp32
)
if self.tied_norm:
hidden_states2 = hidden_states1
if mixer_kwargs is None:
mixer_kwargs = {}
hidden_states1 = self.mixer(hidden_states1, **mixer_kwargs)
......
......@@ -99,6 +99,46 @@ def _dropout_add_layer_norm_subset_backward(dz, dx, x, x0, dmask, mu, rsigma, ga
return dx0mat, dresidualmat, dgamma, dbeta, dcolscale
def _dropout_add_layer_norm_parallel_residual_forward(
x0, x1, residual, gamma0, beta0, gamma1, beta1, dropout_p,
epsilon, residual_in_fp32=False, is_rms_norm=False
):
""" Assume that arguments are contiguous
"""
hidden_size = gamma0.numel()
x0mat = x0.view((-1, hidden_size))
x1mat = x1.view((-1, hidden_size)) if x1 is not None else None
residualmat = residual.view((-1, hidden_size)) if residual is not None else None
z0mat, z1mat, xmat, dmask0, dmask1, mu, rsigma = dropout_layer_norm.dropout_add_ln_parallel_residual_fwd(
x0mat, x1mat, residualmat, gamma0, beta0, gamma1, beta1, dropout_p, epsilon,
None, residual_in_fp32, is_rms_norm
)
# dmask0 and dmask1 are None if dropout_p == 0.0
# xmat is None if dropout_p == 0.0 and residual is None and residual_dtype != input_dtype
return z0mat, z1mat, xmat if xmat is not None else x0mat, dmask0, dmask1, mu, rsigma
def _dropout_add_layer_norm_parallel_residual_backward(
dz0, dz1, dx, x, dmask0, dmask1, mu, rsigma, gamma0, gamma1,
dropout_p, has_x1, has_residual, is_rms_norm=False
):
""" Assume that arguments are contiguous
dx == None means that it was a post-norm architecture
(x = drop(x0) + residual was not returned in the fwd).
"""
hidden_size = gamma0.numel()
xmat = x.view((-1, hidden_size))
dz0mat = dz0.view(xmat.shape)
dz1mat = dz1.view(xmat.shape) if dz1 is not None else None
dxmat = dx.view(xmat.shape) if dx is not None else None
dx0mat, dx1mat, dresidualmat, dgamma0, dbeta0, dgamma1, dbeta1, *rest = dropout_layer_norm.dropout_add_ln_parallel_residual_bwd(
dz0mat, dz1mat, dxmat, xmat, dmask0, dmask1, mu, rsigma, gamma0, gamma1,
dropout_p, has_x1, has_residual, is_rms_norm
)
# dresidualmat is None if not has_residual
return dx0mat, dx1mat, dresidualmat, dgamma0, dbeta0, dgamma1, dbeta1
class DropoutAddLayerNormFn(torch.autograd.Function):
@staticmethod
def forward(ctx, x0, residual, gamma, beta, rowscale, colscale, dropout_p, epsilon,
......@@ -115,7 +155,7 @@ class DropoutAddLayerNormFn(torch.autograd.Function):
)
# Only need to save x0 if we need to compute gradient wrt colscale
x0_saved = x0 if colscale is not None else None
ctx.save_for_backward(xmat.view(x0.shape), x0, dmask, gamma, mu, rsigma, rowscale, colscale)
ctx.save_for_backward(xmat.view(x0.shape), x0_saved, dmask, gamma, mu, rsigma, rowscale, colscale)
ctx.prenorm = prenorm
ctx.dropout_p = dropout_p
ctx.has_residual = residual is not None
......@@ -168,7 +208,7 @@ class DropoutAddLayerNormSubsetFn(torch.autograd.Function):
# Only need to save x0 if we need to compute gradient wrt colscale
x0_saved = x0 if colscale is not None else None
x_shape = (-1, *x0.shape[1:])
ctx.save_for_backward(xmat.view(x_shape), x0, dmask, gamma, mu, rsigma, colscale,
ctx.save_for_backward(xmat.view(x_shape), x0_saved, dmask, gamma, mu, rsigma, colscale,
x0_subset, out_subset)
ctx.prenorm = prenorm
ctx.dropout_p = dropout_p
......@@ -208,6 +248,60 @@ class DropoutAddLayerNormSubsetFn(torch.autograd.Function):
None, None, None, None, None, None, None, None)
class DropoutAddLayerNormParallelResidualFn(torch.autograd.Function):
@staticmethod
def forward(ctx, x0, x1, residual, gamma0, beta0, gamma1, beta1, dropout_p, epsilon,
residual_in_fp32=False, prenorm=False, is_rms_norm=False, return_dmask=False):
x0 = x0.contiguous()
x1 = x1.contiguous() if x1 is not None else None
residual = residual.contiguous() if residual is not None else None
gamma0 = gamma0.contiguous()
beta0 = beta0.contiguous() if beta0 is not None else None
gamma1 = gamma1.contiguous() if gamma1 is not None else None
beta1 = beta1.contiguous() if beta1 is not None else None
z0mat, z1mat, xmat, dmask0, dmask1, mu, rsigma = _dropout_add_layer_norm_parallel_residual_forward(
x0, x1, residual, gamma0, beta0, gamma1, beta1, dropout_p, epsilon,
residual_in_fp32, is_rms_norm
)
ctx.save_for_backward(xmat.view(x0.shape), dmask0, dmask1, gamma0, gamma1, mu, rsigma)
ctx.prenorm = prenorm
ctx.dropout_p = dropout_p
ctx.has_x1 = x1 is not None
ctx.has_residual = residual is not None
ctx.is_rms_norm = is_rms_norm
ctx.has_beta = beta0 is not None
z = (z0mat.view(x0.shape), z1mat.view(x0.shape) if z1mat is not None else None)
if not return_dmask:
return z if not prenorm else (*z, xmat.view(x0.shape))
else:
dmask0 = (dmask0.view(x0.shape) if dropout_p > 0.
else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device))
dmask1 = (dmask1.view(x0.shape) if dropout_p > 0. and x1 is not None
else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device))
ctx.mark_non_differentiable(dmask0)
ctx.mark_non_differentiable(dmask1)
return (*z, dmask0, dmask1) if not prenorm else (*z, xmat.view(x0.shape), dmask0, dmask1)
@staticmethod
def backward(ctx, dz0, dz1, *args):
dz0 = dz0.contiguous() # this happens!
dz1 = dz1.contiguous() if dz1 is not None else None
dx = args[0].contiguous() if ctx.prenorm else None
x, dmask0, dmask1, gamma0, gamma1, mu, rsigma = ctx.saved_tensors
dropout_p = ctx.dropout_p
has_x1 = ctx.has_x1
has_residual = ctx.has_residual
dx0mat, dx1mat, dresidualmat, dgamma0, dbeta0, dgamma1, dbeta1 = _dropout_add_layer_norm_parallel_residual_backward(
dz0, dz1, dx, x, dmask0, dmask1, mu, rsigma, gamma0, gamma1, dropout_p, has_x1,
has_residual, ctx.is_rms_norm
)
dx0 = dx0mat.view(x.shape)
dx1 = dx1mat.view(x.shape) if dx1mat is not None else None
dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None
return (dx0, dx1, dresidual, dgamma0, dbeta0 if ctx.has_beta else None, dgamma1,
dbeta1 if ctx.has_beta else None, None, None, None, None, None, None)
def layer_norm(x, weight, bias, epsilon):
return DropoutAddLayerNormFn.apply(x, None, weight, bias, None, None, 0.0, epsilon, False)
......@@ -237,6 +331,19 @@ def dropout_add_layer_norm_subset(x0, residual, weight, bias, dropout_p, epsilon
)
def dropout_add_layer_norm_parallel_residual(
x0, x1, residual, weight0, bias0, weight1, bias1, dropout_p, epsilon, prenorm=False,
residual_in_fp32=False, return_dropout_mask=False
):
"""residual_in_fp32 only has an effect if residual is None.
Otherwise residual dtype is residual.dtype.
"""
return DropoutAddLayerNormParallelResidualFn.apply(
x0, x1, residual, weight0, bias0, weight1, bias1, dropout_p, epsilon, residual_in_fp32, prenorm,
False, return_dropout_mask
)
class DropoutAddLayerNorm(torch.nn.Module):
def __init__(self, hidden_size, prenorm=False, p=0.0, eps=1e-5, residual_in_fp32=False,
device=None, dtype=None):
......
......@@ -5,6 +5,7 @@ import torch
from torch.nn import init
from flash_attn.ops.layer_norm import DropoutAddLayerNormFn, DropoutAddLayerNormSubsetFn
from flash_attn.ops.layer_norm import DropoutAddLayerNormParallelResidualFn
def rms_norm(x, weight, epsilon):
......@@ -37,6 +38,19 @@ def dropout_add_rms_norm_subset(x0, residual, weight, bias, dropout_p, epsilon,
)
def dropout_add_rms_norm_parallel_residual(
x0, x1, residual, weight0, bias0, weight1, bias1,
dropout_p, epsilon, prenorm=False, residual_in_fp32=False, return_dropout_mask=False
):
"""residual_in_fp32 only has an effect if residual is None.
Otherwise residual dtype is residual.dtype.
"""
return DropoutAddLayerNormParallelResidualFn.apply(
x0, x1, residual, weight0, bias0, weight1, bias1, dropout_p, epsilon, residual_in_fp32, prenorm,
True, return_dropout_mask
)
class DropoutAddRMSNorm(torch.nn.Module):
def __init__(self, hidden_size, prenorm=False, p=0.0, eps=1e-5, residual_in_fp32=False,
device=None, dtype=None):
......
......@@ -35,7 +35,7 @@ def test_gpt_neox_optimized(model_name):
config.use_flash_attn = True
config.fused_bias_fc = True
config.fused_mlp = True # GPT-NeoX-20B uses "gelu_fast"
config.fused_dropout_add_ln = False # We don't support parallel block yet
config.fused_dropout_add_ln = True
config.residual_in_fp32 = True
model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype)
......
......@@ -36,7 +36,7 @@ def test_gptj_optimized(model_name):
config.use_flash_attn = False # FlashAttention doesn't support hdim 256 yet
config.fused_bias_fc = True
config.fused_mlp = True
config.fused_dropout_add_ln = False # We don't support parallel block yet
config.fused_dropout_add_ln = True
config.residual_in_fp32 = True
model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype)
......
......@@ -10,11 +10,14 @@ from flash_attn.ops.layer_norm import DropoutAddLayerNorm, dropout_add_layer_nor
from flash_attn.ops.layer_norm import dropout_add_layer_norm_subset
from flash_attn.ops.rms_norm import DropoutAddRMSNorm, dropout_add_rms_norm
from flash_attn.ops.rms_norm import dropout_add_rms_norm_subset
from flash_attn.ops.layer_norm import dropout_add_layer_norm_parallel_residual
from flash_attn.ops.rms_norm import dropout_add_rms_norm_parallel_residual
try:
from apex.normalization import FusedRMSNorm
from apex.normalization.fused_layer_norm import fused_rms_norm_affine
except:
FusedRMSNorm = None
FusedRMSNorm, fused_rms_norm_affine = None, None
is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8
......@@ -35,8 +38,8 @@ is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8
(torch.float32, torch.float32)]
+ ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []))
# @pytest.mark.parametrize('input_dtype,residual_dtype', [(torch.float16, torch.float32)])
# @pytest.mark.parametrize('hidden_size', [192, 256, 384, 768, 1024, 1280, 1536, 1600, 2048, 2560, 3000, 3072, 4096, 5120, 6144])
@pytest.mark.parametrize('hidden_size', [256])
@pytest.mark.parametrize('hidden_size', [192, 256, 384, 768, 1024, 1280, 1536, 1600, 2048, 2560, 3000, 3072, 4096, 5120, 6144])
# @pytest.mark.parametrize('hidden_size', [256])
def test_dropout_layer_norm_training(hidden_size, input_dtype, residual_dtype, weight_dtype,
dropout_p, has_residual, has_rowscale, has_colscale, is_rms_norm):
if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:
......@@ -64,11 +67,11 @@ def test_dropout_layer_norm_training(hidden_size, input_dtype, residual_dtype, w
else:
colscale = None
if has_residual:
x1_pt = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True)
x1 = x1_pt.detach().clone().requires_grad_()
x1_ref = x1_pt.detach().clone().float().requires_grad_()
res_pt = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True)
res = res_pt.detach().clone().requires_grad_()
res_ref = res_pt.detach().clone().float().requires_grad_()
else:
x1 = None
res = None
if has_rowscale:
rowscale = torch.empty(batch_size, seqlen, device=device, dtype=input_dtype)
survival_rate = 0.87
......@@ -95,14 +98,14 @@ def test_dropout_layer_norm_training(hidden_size, input_dtype, residual_dtype, w
model.bias.copy_(model_pt.bias)
model_ref.bias.copy_(model_pt.bias)
residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
out, dmask = our_layer_norm_func(x0, x1, model.weight, model.bias, model.p,
out, dmask = our_layer_norm_func(x0, res, model.weight, model.bias, model.p,
model.epsilon, rowscale=rowscale, layerscale=colscale,
residual_in_fp32=residual_in_fp32, return_dropout_mask=True)
assert out.dtype == input_dtype
print(f'Actual dropout fraction: {1 - dmask.float().mean().item()}')
if has_residual:
residual_pt = ((x0_scaled_pt.float() * dmask.float()) / (1 - dropout_p) + x1_pt.float()).to(dtype=residual_dtype)
residual_ref = (x0_scaled_ref * dmask.float()) / (1 - dropout_p) + x1_ref
residual_pt = ((x0_scaled_pt.float() * dmask.float()) / (1 - dropout_p) + res_pt.float()).to(dtype=residual_dtype)
residual_ref = (x0_scaled_ref * dmask.float()) / (1 - dropout_p) + res_ref
else:
residual_pt = ((x0_scaled_pt.float() * dmask.float()) / (1 - dropout_p)).to(dtype=residual_dtype)
residual_ref = (x0_scaled_ref * dmask.float()) / (1 - dropout_p)
......@@ -116,8 +119,8 @@ def test_dropout_layer_norm_training(hidden_size, input_dtype, residual_dtype, w
out_ref.backward(g)
assert (x0.grad - x0_ref.grad).abs().max() <= 4 * (x0_pt.grad - x0_ref.grad).abs().max() + 1e-4
if has_residual:
assert (x1.grad - x1_ref.grad).abs().max() <= 4 * (x1_pt.grad - x1_ref.grad).abs().max() + 1e-4
assert (model.weight.grad - model_ref.weight.grad).abs().max() <= 2 * (model_pt.weight.grad - model_ref.weight.grad).abs().max() + 3e-5
assert (res.grad - res_ref.grad).abs().max() <= 4 * (res_pt.grad - res_ref.grad).abs().max() + 1e-4
assert (model.weight.grad - model_ref.weight.grad).abs().max() <= 3 * (model_pt.weight.grad - model_ref.weight.grad).abs().max() + 3e-5
if not is_rms_norm:
assert (model.bias.grad - model_ref.bias.grad).abs().max() <= 2 * (model_pt.bias.grad - model_ref.bias.grad).abs().max() + 3e-5
if has_colscale:
......@@ -145,9 +148,9 @@ def test_dropout_layer_norm_eval(hidden_size, input_dtype, residual_dtype, weigh
requires_grad=True)
x0 = x0_pt.detach().clone().requires_grad_()
x0_ref = x0_pt.detach().clone().float().requires_grad_()
x1_pt = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True)
x1 = x1_pt.detach().clone().requires_grad_()
x1_ref = x1_pt.detach().clone().float().requires_grad_()
res_pt = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True)
res = res_pt.detach().clone().requires_grad_()
res_ref = res_pt.detach().clone().float().requires_grad_()
model_pt = torch.nn.LayerNorm(hidden_size, device=device, dtype=weight_dtype)
torch.nn.init.normal_(model_pt.weight)
torch.nn.init.normal_(model_pt.bias)
......@@ -161,9 +164,9 @@ def test_dropout_layer_norm_eval(hidden_size, input_dtype, residual_dtype, weigh
model_pt.eval()
model.eval()
model_ref.eval()
out = model(x0, x1)
residual_pt = (x0_pt.float() + x1_pt.float()).to(dtype=residual_dtype)
residual_ref = x0_ref + x1_ref
out = model(x0, res)
residual_pt = (x0_pt.float() + res_pt.float()).to(dtype=residual_dtype)
residual_ref = x0_ref + res_ref
out_pt = model_pt(residual_pt.to(dtype=weight_dtype)).to(input_dtype)
out_ref = model_ref(residual_ref)
assert (out - out_ref).abs().max() <= 4 * (out_pt - out_ref).abs().max() + 1e-4
......@@ -215,11 +218,11 @@ def test_dropout_layer_norm_prenorm_training(hidden_size, input_dtype, residual_
else:
colscale = None
if has_residual:
x1_pt = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True)
x1 = x1_pt.detach().clone().requires_grad_()
x1_ref = x1_pt.detach().clone().float().requires_grad_()
res_pt = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True)
res = res_pt.detach().clone().requires_grad_()
res_ref = res_pt.detach().clone().float().requires_grad_()
else:
x1 = None
res = None
if has_rowscale:
rowscale = torch.empty(batch_size, seqlen, device=device, dtype=input_dtype)
survival_rate = 0.87
......@@ -247,15 +250,15 @@ def test_dropout_layer_norm_prenorm_training(hidden_size, input_dtype, residual_
model.bias.copy_(model_pt.bias)
model_ref.bias.copy_(model_pt.bias)
residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
out, residual, dmask = our_layer_norm_func(x0, x1, model.weight, model.bias, model.p,
out, residual, dmask = our_layer_norm_func(x0, res, model.weight, model.bias, model.p,
model.epsilon, rowscale=rowscale,
layerscale=colscale, prenorm=True,
residual_in_fp32=residual_in_fp32,
return_dropout_mask=True)
print(f'Actual dropout fraction: {1 - dmask.float().mean().item()}')
if has_residual:
residual_pt = ((x0_scaled_pt.float() * dmask.float()) / (1 - dropout_p) + x1_pt.float()).to(dtype=residual_dtype)
residual_ref = (x0_scaled_ref * dmask.float()) / (1 - dropout_p) + x1_ref
residual_pt = ((x0_scaled_pt.float() * dmask.float()) / (1 - dropout_p) + res_pt.float()).to(dtype=residual_dtype)
residual_ref = (x0_scaled_ref * dmask.float()) / (1 - dropout_p) + res_ref
else:
residual_pt = ((x0_scaled_pt.float() * dmask.float()) / (1 - dropout_p)).to(dtype=residual_dtype)
residual_ref = (x0_scaled_ref * dmask.float()) / (1 - dropout_p)
......@@ -272,7 +275,7 @@ def test_dropout_layer_norm_prenorm_training(hidden_size, input_dtype, residual_
(out_ref * F.sigmoid(residual_ref.to(dtype=residual_dtype))).backward(g)
assert (x0.grad - x0_ref.grad).abs().max() <= 4 * (x0_pt.grad - x0_ref.grad).abs().max() + 1e-4
if has_residual:
assert (x1.grad - x1_ref.grad).abs().max() <= 4 * (x1_pt.grad - x1_ref.grad).abs().max() + 1e-4
assert (res.grad - res_ref.grad).abs().max() <= 4 * (res_pt.grad - res_ref.grad).abs().max() + 1e-4
assert (model.weight.grad - model_ref.weight.grad).abs().max() <= 2 * (model_pt.weight.grad - model_ref.weight.grad).abs().max() + 2e-4
if not is_rms_norm:
assert (model.bias.grad - model_ref.bias.grad).abs().max() <= 2 * (model_pt.bias.grad - model_ref.bias.grad).abs().max() + 2e-4
......@@ -301,9 +304,9 @@ def test_dropout_layer_norm_prenorm_eval(hidden_size, input_dtype, residual_dtyp
requires_grad=True)
x0 = x0_pt.detach().clone().requires_grad_()
x0_ref = x0_pt.detach().clone().float().requires_grad_()
x1_pt = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True)
x1 = x1_pt.detach().clone().requires_grad_()
x1_ref = x1_pt.detach().clone().float().requires_grad_()
res_pt = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True)
res = res_pt.detach().clone().requires_grad_()
res_ref = res_pt.detach().clone().float().requires_grad_()
model_pt = torch.nn.LayerNorm(hidden_size, device=device, dtype=weight_dtype)
torch.nn.init.normal_(model_pt.weight)
torch.nn.init.normal_(model_pt.bias)
......@@ -318,9 +321,9 @@ def test_dropout_layer_norm_prenorm_eval(hidden_size, input_dtype, residual_dtyp
model_pt.eval()
model.eval()
model_ref.eval()
out, residual = model(x0, x1)
residual_pt = (x0_pt.float() + x1_pt.float()).to(dtype=residual_dtype)
residual_ref = x0_ref + x1_ref
out, residual = model(x0, res)
residual_pt = (x0_pt.float() + res_pt.float()).to(dtype=residual_dtype)
residual_ref = x0_ref + res_ref
out_pt = model_pt(residual_pt.to(dtype=weight_dtype)).to(input_dtype)
out_ref = model_ref(residual_ref)
assert (out - out_ref).abs().max() <= 4 * (out_pt - out_ref).abs().max() + 1e-4
......@@ -382,11 +385,11 @@ def test_dropout_layer_norm_subset_training(
else:
colscale = None
if has_residual:
x1_pt = torch.randn_like(x0_pt, dtype=residual_dtype, requires_grad=True)
x1 = x1_pt.detach().clone().requires_grad_()
x1_ref = x1_pt.detach().clone().float().requires_grad_()
res_pt = torch.randn_like(x0_pt, dtype=residual_dtype, requires_grad=True)
res = res_pt.detach().clone().requires_grad_()
res_ref = res_pt.detach().clone().float().requires_grad_()
else:
x1 = None
res = None
if has_colscale:
x0_scaled_pt = x0_pt * colscale_pt
......@@ -409,7 +412,7 @@ def test_dropout_layer_norm_subset_training(
residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
out, dmask = dropout_add_layer_norm_subset(
x0, x1, model.weight, model.bias, model.p, model.epsilon, layerscale=colscale,
x0, res, model.weight, model.bias, model.p, model.epsilon, layerscale=colscale,
x0_subset=x0_subset, out_subset=out_subset, rowscale_const=drop_path_scale,
out_numrows = out_numrows, prenorm=False, residual_in_fp32=residual_in_fp32,
return_dropout_mask=True)
......@@ -424,8 +427,8 @@ def test_dropout_layer_norm_subset_training(
dmask_expanded = torch.zeros_like(x0_pt, dtype=torch.uint8)
dmask_expanded[x0_mask_batch] = dmask
if has_residual:
residual_pt = ((x0_scaled_pt.float() * dmask_expanded.float()) / (1 - dropout_p) + x1_pt.float()).to(dtype=residual_dtype)
residual_ref = (x0_scaled_ref * dmask_expanded.float()) / (1 - dropout_p) + x1_ref
residual_pt = ((x0_scaled_pt.float() * dmask_expanded.float()) / (1 - dropout_p) + res_pt.float()).to(dtype=residual_dtype)
residual_ref = (x0_scaled_ref * dmask_expanded.float()) / (1 - dropout_p) + res_ref
else:
residual_pt = ((x0_scaled_pt.float() * dmask_expanded.float()) / (1 - dropout_p)).to(dtype=residual_dtype)
residual_ref = (x0_scaled_ref * dmask_expanded.float()) / (1 - dropout_p)
......@@ -440,7 +443,7 @@ def test_dropout_layer_norm_subset_training(
out_ref.backward(g)
assert (x0.grad - x0_ref.grad[x0_mask_batch]).abs().max() <= 4 * (x0_pt.grad - x0_ref.grad)[x0_mask_batch].abs().max() + 1e-4
if has_residual:
assert (x1.grad - x1_ref.grad).abs().max() <= 4 * (x1_pt.grad - x1_ref.grad).abs().max() + 1e-4
assert (res.grad - res_ref.grad).abs().max() <= 4 * (res_pt.grad - res_ref.grad).abs().max() + 1e-4
assert (model.weight.grad - model_ref.weight.grad).abs().max() <= 2 * (model_pt.weight.grad - model_ref.weight.grad).abs().max() + 2e-4
assert (model.bias.grad - model_ref.bias.grad).abs().max() <= 2 * (model_pt.bias.grad - model_ref.bias.grad).abs().max() + 2e-4
if has_colscale:
......@@ -502,11 +505,11 @@ def test_dropout_layer_norm_subset_prenorm_training(
else:
colscale = None
if has_residual:
x1_pt = torch.randn_like(x0_pt, dtype=residual_dtype, requires_grad=True)
x1 = x1_pt.detach().clone().requires_grad_()
x1_ref = x1_pt.detach().clone().float().requires_grad_()
res_pt = torch.randn_like(x0_pt, dtype=residual_dtype, requires_grad=True)
res = res_pt.detach().clone().requires_grad_()
res_ref = res_pt.detach().clone().float().requires_grad_()
else:
x1 = None
res = None
if has_colscale:
x0_scaled_pt = x0_pt * colscale_pt
......@@ -529,7 +532,7 @@ def test_dropout_layer_norm_subset_prenorm_training(
residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
out, residual, dmask = dropout_add_layer_norm_subset(
x0, x1, model.weight, model.bias, model.p, model.epsilon, layerscale=colscale,
x0, res, model.weight, model.bias, model.p, model.epsilon, layerscale=colscale,
x0_subset=x0_subset, out_subset=out_subset, rowscale_const=drop_path_scale,
out_numrows = out_numrows, prenorm=True, residual_in_fp32=residual_in_fp32,
return_dropout_mask=True)
......@@ -544,8 +547,8 @@ def test_dropout_layer_norm_subset_prenorm_training(
dmask_expanded = torch.zeros_like(x0_pt, dtype=torch.uint8)
dmask_expanded[x0_mask_batch] = dmask
if has_residual:
residual_pt = ((x0_scaled_pt.float() * dmask_expanded.float()) / (1 - dropout_p) + x1_pt.float()).to(dtype=residual_dtype)
residual_ref = (x0_scaled_ref * dmask_expanded.float()) / (1 - dropout_p) + x1_ref
residual_pt = ((x0_scaled_pt.float() * dmask_expanded.float()) / (1 - dropout_p) + res_pt.float()).to(dtype=residual_dtype)
residual_ref = (x0_scaled_ref * dmask_expanded.float()) / (1 - dropout_p) + res_ref
else:
residual_pt = ((x0_scaled_pt.float() * dmask_expanded.float()) / (1 - dropout_p)).to(dtype=residual_dtype)
residual_ref = (x0_scaled_ref * dmask_expanded.float()) / (1 - dropout_p)
......@@ -562,8 +565,301 @@ def test_dropout_layer_norm_subset_prenorm_training(
(out_ref * F.sigmoid(residual_ref[out_mask_batch].to(dtype=residual_dtype)) + residual_ref.mean(0, keepdim=True)).backward(g)
assert (x0.grad - x0_ref.grad[x0_mask_batch]).abs().max() <= 4 * (x0_pt.grad - x0_ref.grad)[x0_mask_batch].abs().max() + 1e-4
if has_residual:
assert (x1.grad - x1_ref.grad).abs().max() <= 4 * (x1_pt.grad - x1_ref.grad).abs().max() + 1e-4
assert (res.grad - res_ref.grad).abs().max() <= 4 * (res_pt.grad - res_ref.grad).abs().max() + 1e-4
assert (model.weight.grad - model_ref.weight.grad).abs().max() <= 2 * (model_pt.weight.grad - model_ref.weight.grad).abs().max() + 2e-4
assert (model.bias.grad - model_ref.bias.grad).abs().max() <= 2 * (model_pt.bias.grad - model_ref.bias.grad).abs().max() + 2e-4
if has_colscale:
assert (colscale.grad - colscale_ref.grad).abs().max() <= 2 * (colscale_pt.grad - colscale_ref.grad).abs().max() + 2e-4
@pytest.mark.parametrize('is_rms_norm', [False, True])
# @pytest.mark.parametrize('is_rms_norm', [False])
@pytest.mark.parametrize('tied_norm', [False, True])
# @pytest.mark.parametrize('tied_norm', [False])
@pytest.mark.parametrize('has_residual', [True, False])
# @pytest.mark.parametrize('has_residual', [False])
@pytest.mark.parametrize('has_x1', [True, False])
# @pytest.mark.parametrize('has_x1', [True])
@pytest.mark.parametrize('dropout_p', [0.37, 0.0])
# @pytest.mark.parametrize('dropout_p', [0.0])
@pytest.mark.parametrize('weight_dtype', [torch.float32, torch.float16])
# @pytest.mark.parametrize('weight_dtype', [torch.float16])
@pytest.mark.parametrize('input_dtype,residual_dtype',
[(torch.float16, torch.float16), (torch.float16, torch.float32),
(torch.float32, torch.float32)]
+ ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []))
# @pytest.mark.parametrize('input_dtype,residual_dtype', [(torch.float16, torch.float32)])
@pytest.mark.parametrize('hidden_size', [192, 256, 384, 768, 1024, 1280, 1536, 1600, 2048, 2560, 3000, 3072, 4096, 5120, 6144])
# @pytest.mark.parametrize('hidden_size', [256])
def test_dropout_layer_norm_parallel_residual_training(
hidden_size, input_dtype, residual_dtype, weight_dtype,
dropout_p, has_x1, has_residual, tied_norm, is_rms_norm
):
if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:
pytest.skip() # Not supported
if is_rms_norm and fused_rms_norm_affine is None:
pytest.skip() # We need Apex's FusedRMSNorm to test
our_layer_norm_func = (dropout_add_layer_norm_parallel_residual if not is_rms_norm
else dropout_add_rms_norm_parallel_residual)
device = 'cuda'
# rtol, atol = (1e-5, 1e-6) if input_dtype == torch.float32 else (1e-3, 1e-4)
rtol, atol = (1e-3, 1e-4)
# set seed
torch.random.manual_seed(0)
batch_size = 8
seqlen = 512
x0_pt = torch.randn(batch_size, seqlen, hidden_size, device=device, dtype=input_dtype,
requires_grad=True)
x0 = x0_pt.detach().clone().requires_grad_()
x0_ref = x0_pt.detach().clone().float().requires_grad_()
if has_x1:
x1_pt = torch.randn(batch_size, seqlen, hidden_size, device=device, dtype=input_dtype,
requires_grad=True)
x1 = x1_pt.detach().clone().requires_grad_()
x1_ref = x1_pt.detach().clone().float().requires_grad_()
else:
x1 = None
if has_residual:
res_pt = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True)
res = res_pt.detach().clone().requires_grad_()
res_ref = res_pt.detach().clone().float().requires_grad_()
else:
res = None
weight0 = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
bias0 = (torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
if not is_rms_norm else None)
weight0_pt = weight0.detach().clone().requires_grad_()
weight0_ref = weight0.detach().clone().float().requires_grad_()
bias0_pt = bias0.detach().clone().requires_grad_() if bias0 is not None else None
bias0_ref = bias0.detach().clone().float().requires_grad_() if bias0 is not None else None
if not tied_norm:
weight1 = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
bias1 = (torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
if not is_rms_norm else None)
weight1_pt = weight1.detach().clone().requires_grad_()
weight1_ref = weight1.detach().clone().float().requires_grad_()
bias1_pt = bias1.detach().clone().requires_grad_() if bias1 is not None else None
bias1_ref = bias1.detach().clone().float().requires_grad_() if bias1 is not None else None
else:
weight1, bias1 = None, None
epsilon = 1e-5
residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
out0, out1, dmask0, dmask1 = our_layer_norm_func(
x0, x1, res, weight0, bias0, weight1, bias1, dropout_p,
epsilon, residual_in_fp32=residual_in_fp32, return_dropout_mask=True
)
assert out0.dtype == input_dtype
if not tied_norm:
assert out1.dtype == input_dtype
print(f'Actual dropout fraction: {1 - dmask0.float().mean().item()}')
if has_residual:
if has_x1:
residual_pt = ((x0_pt.float() * dmask0.float()) / (1 - dropout_p)
+ (x1_pt.float() * dmask1.float()) / (1 - dropout_p)
+ res_pt.float()).to(dtype=residual_dtype)
residual_ref = ((x0_ref * dmask0.float()) / (1 - dropout_p)
+ (x1_ref * dmask1.float()) / (1 - dropout_p)) + res_ref
else:
residual_pt = ((x0_pt.float() * dmask0.float()) / (1 - dropout_p)
+ res_pt.float()).to(dtype=residual_dtype)
residual_ref = (x0_ref * dmask0.float()) / (1 - dropout_p) + res_ref
else:
if has_x1:
residual_pt = ((x0_pt.float() * dmask0.float()) / (1 - dropout_p)
+ (x1_pt.float() * dmask1.float()) / (1 - dropout_p)).to(dtype=residual_dtype)
residual_ref = ((x0_ref * dmask0.float()) / (1 - dropout_p)
+ (x1_ref * dmask1.float()) / (1 - dropout_p))
else:
residual_pt = ((x0_pt.float() * dmask0.float()) / (1 - dropout_p)).to(dtype=residual_dtype)
residual_ref = (x0_ref * dmask0.float()) / (1 - dropout_p)
if not is_rms_norm:
out0_pt = F.layer_norm(residual_pt.to(dtype=weight_dtype), (hidden_size,), weight0_pt, bias0_pt,
eps=epsilon).to(dtype=input_dtype)
out0_ref = F.layer_norm(residual_ref, (hidden_size,), weight0_ref, bias0_ref, eps=epsilon)
if not tied_norm:
out1_pt = F.layer_norm(residual_pt.to(dtype=weight_dtype), (hidden_size,), weight1_pt,
bias1_pt, eps=epsilon).to(dtype=input_dtype)
out1_ref = F.layer_norm(residual_ref, (hidden_size,), weight1_ref, bias1_ref, eps=epsilon)
else:
out0_pt = fused_rms_norm_affine(residual_pt.to(dtype=weight_dtype), weight0_pt, (hidden_size,),
eps=epsilon).to(dtype=input_dtype)
out0_ref = fused_rms_norm_affine(residual_ref, weight0_ref, (hidden_size,), eps=epsilon)
if not tied_norm:
out1_pt = fused_rms_norm_affine(residual_pt.to(dtype=weight_dtype), weight1_pt,
(hidden_size,), eps=epsilon).to(dtype=input_dtype)
out1_ref = fused_rms_norm_affine(residual_ref, weight1_ref, (hidden_size,), eps=epsilon)
assert (out0 - out0_ref).abs().max() <= 4 * (out0_pt - out0_ref).abs().max() + 1e-4
if not tied_norm:
assert (out1 - out1_ref).abs().max() <= 4 * (out1_pt - out1_ref).abs().max() + 1e-4
g0 = torch.randn_like(out0) / batch_size
if tied_norm:
out0.backward(g0)
out0_pt.backward(g0)
out0_ref.backward(g0)
else:
g1 = torch.randn_like(out1) / batch_size
(out0 * g0 + out1 * g1).sum().backward()
(out0_pt * g0 + out1_pt * g1).sum().backward()
(out0_ref * g0 + out1_ref * g1).sum().backward()
assert (x0.grad - x0_ref.grad).abs().max() <= 4 * (x0_pt.grad - x0_ref.grad).abs().max() + 1e-4
if has_x1:
assert (x1.grad - x1_ref.grad).abs().max() <= 4 * (x1_pt.grad - x1_ref.grad).abs().max() + 1e-4
if has_residual:
assert (res.grad - res_ref.grad).abs().max() <= 4 * (res_pt.grad - res_ref.grad).abs().max() + 1e-4
assert (weight0.grad - weight0_ref.grad).abs().max() <= 3 * (weight0_pt.grad - weight0_ref.grad).abs().max() + 3e-5
if not is_rms_norm:
assert (bias0.grad - bias0_ref.grad).abs().max() <= 2 * (bias0_pt.grad - bias0_ref.grad).abs().max() + 3e-5
if not tied_norm:
assert (weight1.grad - weight1_ref.grad).abs().max() <= 3 * (weight1_pt.grad - weight1_ref.grad).abs().max() + 3e-5
if not is_rms_norm:
assert (bias1.grad - bias1_ref.grad).abs().max() <= 2 * (bias1_pt.grad - bias1_ref.grad).abs().max() + 3e-5
@pytest.mark.parametrize('is_rms_norm', [False, True])
# @pytest.mark.parametrize('is_rms_norm', [False])
@pytest.mark.parametrize('tied_norm', [False, True])
# @pytest.mark.parametrize('tied_norm', [False])
@pytest.mark.parametrize('has_residual', [True, False])
# @pytest.mark.parametrize('has_residual', [False])
@pytest.mark.parametrize('has_x1', [True, False])
# @pytest.mark.parametrize('has_x1', [True])
@pytest.mark.parametrize('dropout_p', [0.37, 0.0])
# @pytest.mark.parametrize('dropout_p', [0.0])
@pytest.mark.parametrize('weight_dtype', [torch.float32, torch.float16])
# @pytest.mark.parametrize('weight_dtype', [torch.float16])
@pytest.mark.parametrize('input_dtype,residual_dtype',
[(torch.float16, torch.float16), (torch.float16, torch.float32),
(torch.float32, torch.float32)]
+ ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []))
# @pytest.mark.parametrize('input_dtype,residual_dtype', [(torch.float16, torch.float32)])
@pytest.mark.parametrize('hidden_size', [192, 256, 384, 768, 1024, 1280, 1536, 1600, 2048, 2560, 3000, 3072, 4096, 5120, 6144])
# @pytest.mark.parametrize('hidden_size', [256])
def test_dropout_layer_norm_parallel_residual_prenorm_training(
hidden_size, input_dtype, residual_dtype, weight_dtype,
dropout_p, has_x1, has_residual, tied_norm, is_rms_norm
):
if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:
pytest.skip() # Not supported
if is_rms_norm and fused_rms_norm_affine is None:
pytest.skip() # We need Apex's FusedRMSNorm to test
our_layer_norm_func = (dropout_add_layer_norm_parallel_residual if not is_rms_norm
else dropout_add_rms_norm_parallel_residual)
device = 'cuda'
# rtol, atol = (1e-5, 1e-6) if input_dtype == torch.float32 else (1e-3, 1e-4)
rtol, atol = (1e-3, 1e-4)
# set seed
torch.random.manual_seed(0)
batch_size = 8
seqlen = 512
x0_pt = torch.randn(batch_size, seqlen, hidden_size, device=device, dtype=input_dtype,
requires_grad=True)
x0 = x0_pt.detach().clone().requires_grad_()
x0_ref = x0_pt.detach().clone().float().requires_grad_()
if has_x1:
x1_pt = torch.randn(batch_size, seqlen, hidden_size, device=device, dtype=input_dtype,
requires_grad=True)
x1 = x1_pt.detach().clone().requires_grad_()
x1_ref = x1_pt.detach().clone().float().requires_grad_()
else:
x1 = None
if has_residual:
res_pt = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True)
res = res_pt.detach().clone().requires_grad_()
res_ref = res_pt.detach().clone().float().requires_grad_()
else:
res = None
weight0 = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
bias0 = (torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
if not is_rms_norm else None)
weight0_pt = weight0.detach().clone().requires_grad_()
weight0_ref = weight0.detach().clone().float().requires_grad_()
bias0_pt = bias0.detach().clone().requires_grad_() if bias0 is not None else None
bias0_ref = bias0.detach().clone().float().requires_grad_() if bias0 is not None else None
if not tied_norm:
weight1 = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
bias1 = (torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
if not is_rms_norm else None)
weight1_pt = weight1.detach().clone().requires_grad_()
weight1_ref = weight1.detach().clone().float().requires_grad_()
bias1_pt = bias1.detach().clone().requires_grad_() if bias1 is not None else None
bias1_ref = bias1.detach().clone().float().requires_grad_() if bias1 is not None else None
else:
weight1, bias1 = None, None
epsilon = 1e-5
residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
out0, out1, residual, dmask0, dmask1 = our_layer_norm_func(
x0, x1, res, weight0, bias0, weight1, bias1, dropout_p,
epsilon, prenorm=True, residual_in_fp32=residual_in_fp32, return_dropout_mask=True
)
assert out0.dtype == input_dtype
if not tied_norm:
assert out1.dtype == input_dtype
print(f'Actual dropout fraction: {1 - dmask0.float().mean().item()}')
if has_residual:
if has_x1:
residual_pt = ((x0_pt.float() * dmask0.float()) / (1 - dropout_p)
+ (x1_pt.float() * dmask1.float()) / (1 - dropout_p)
+ res_pt.float()).to(dtype=residual_dtype)
residual_ref = ((x0_ref * dmask0.float()) / (1 - dropout_p)
+ (x1_ref * dmask1.float()) / (1 - dropout_p)) + res_ref
else:
residual_pt = ((x0_pt.float() * dmask0.float()) / (1 - dropout_p)
+ res_pt.float()).to(dtype=residual_dtype)
residual_ref = (x0_ref * dmask0.float()) / (1 - dropout_p) + res_ref
else:
if has_x1:
residual_pt = ((x0_pt.float() * dmask0.float()) / (1 - dropout_p)
+ (x1_pt.float() * dmask1.float()) / (1 - dropout_p)).to(dtype=residual_dtype)
residual_ref = ((x0_ref * dmask0.float()) / (1 - dropout_p)
+ (x1_ref * dmask1.float()) / (1 - dropout_p))
else:
residual_pt = ((x0_pt.float() * dmask0.float()) / (1 - dropout_p)).to(dtype=residual_dtype)
residual_ref = (x0_ref * dmask0.float()) / (1 - dropout_p)
if not is_rms_norm:
out0_pt = F.layer_norm(residual_pt.to(dtype=weight_dtype), (hidden_size,), weight0_pt, bias0_pt,
eps=epsilon).to(dtype=input_dtype)
out0_ref = F.layer_norm(residual_ref, (hidden_size,), weight0_ref, bias0_ref, eps=epsilon)
if not tied_norm:
out1_pt = F.layer_norm(residual_pt.to(dtype=weight_dtype), (hidden_size,), weight1_pt,
bias1_pt, eps=epsilon).to(dtype=input_dtype)
out1_ref = F.layer_norm(residual_ref, (hidden_size,), weight1_ref, bias1_ref, eps=epsilon)
else:
out0_pt = fused_rms_norm_affine(residual_pt.to(dtype=weight_dtype), weight0_pt, (hidden_size,),
eps=epsilon).to(dtype=input_dtype)
out0_ref = fused_rms_norm_affine(residual_ref, weight0_ref, (hidden_size,), eps=epsilon)
if not tied_norm:
out1_pt = fused_rms_norm_affine(residual_pt.to(dtype=weight_dtype), weight1_pt,
(hidden_size,), eps=epsilon).to(dtype=input_dtype)
out1_ref = fused_rms_norm_affine(residual_ref, weight1_ref, (hidden_size,), eps=epsilon)
assert (out0 - out0_ref).abs().max() <= 4 * (out0_pt - out0_ref).abs().max() + 1e-4
if not tied_norm:
assert (out1 - out1_ref).abs().max() <= 4 * (out1_pt - out1_ref).abs().max() + 1e-4
assert (residual - residual_ref).abs().max() <= 4 * (residual_pt - residual_ref).abs().max() + 1e-4
g0 = torch.randn_like(out0) / batch_size
if tied_norm:
(out0 * F.sigmoid(residual)).backward(g0)
(out0_pt * F.sigmoid(residual_pt)).backward(g0)
(out0_ref * F.sigmoid(residual_ref)).backward(g0)
else:
g1 = torch.randn_like(out1) / batch_size
(out0 * F.sigmoid(residual) * g0 + out1 * g1).sum().backward()
(out0_pt * F.sigmoid(residual_pt) * g0 + out1_pt * g1).sum().backward()
(out0_ref * F.sigmoid(residual_ref) * g0 + out1_ref * g1).sum().backward()
assert (x0.grad - x0_ref.grad).abs().max() <= 4 * (x0_pt.grad - x0_ref.grad).abs().max() + 1e-4
if has_x1:
assert (x1.grad - x1_ref.grad).abs().max() <= 4 * (x1_pt.grad - x1_ref.grad).abs().max() + 1e-4
if has_residual:
assert (res.grad - res_ref.grad).abs().max() <= 4 * (res_pt.grad - res_ref.grad).abs().max() + 1e-4
assert (weight0.grad - weight0_ref.grad).abs().max() <= 3 * (weight0_pt.grad - weight0_ref.grad).abs().max() + 3e-5
if not is_rms_norm:
assert (bias0.grad - bias0_ref.grad).abs().max() <= 2 * (bias0_pt.grad - bias0_ref.grad).abs().max() + 3e-5
if not tied_norm:
assert (weight1.grad - weight1_ref.grad).abs().max() <= 3 * (weight1_pt.grad - weight1_ref.grad).abs().max() + 3e-5
if not is_rms_norm:
assert (bias1.grad - bias1_ref.grad).abs().max() <= 2 * (bias1_pt.grad - bias1_ref.grad).abs().max() + 3e-5
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment