Commit 393882bc authored by Tri Dao

[LayerNorm] Implement LN with parallel residual, support dim 8k

parent 009a3e71
@@ -18,6 +18,11 @@ try:
except ImportError:
    dropout_add_layer_norm = None
+try:
+    from flash_attn.ops.layer_norm import dropout_add_layer_norm_parallel_residual
+except ImportError:
+    dropout_add_layer_norm_parallel_residual = None

class Block(nn.Module):
@@ -64,7 +69,7 @@ class Block(nn.Module):
        self.norm2 = norm_cls(dim)
        if self.fused_dropout_add_ln:
-            assert dropout_add_layer_norm is not None, 'dropout_add_ln is not installed'
+            assert dropout_add_layer_norm is not None, 'dropout_layer_norm is not installed'
            assert isinstance(self.norm1, nn.LayerNorm) and isinstance(self.dropout1, nn.Dropout)
            # TD [2023-01-07]: TODO: During training, if sequence_parallel is False and dropout != 0.0,
@@ -214,7 +219,6 @@ class ParallelBlock(nn.Module):
        super().__init__()
        self.tied_norm = tied_norm
        self.fused_dropout_add_ln = fused_dropout_add_ln
-        assert not self.fused_dropout_add_ln, 'This is not implemented for ParallelBlock yet'
        self.residual_in_fp32 = residual_in_fp32
        if mixer_cls is None:
            mixer_cls = partial(MHA, num_heads=dim // 64)
@@ -229,7 +233,7 @@ class ParallelBlock(nn.Module):
        self.norm2 = norm_cls(dim)
        if self.fused_dropout_add_ln:
-            assert dropout_add_layer_norm is not None, 'dropout_add_ln is not installed'
+            assert dropout_add_layer_norm_parallel_residual is not None, 'dropout_layer_norm is not installed'
            assert isinstance(self.norm1, nn.LayerNorm) and isinstance(self.dropout1, nn.Dropout)
            # TD [2023-01-07]: TODO: During training, if sequence_parallel is False and dropout != 0.0,
@@ -262,19 +266,30 @@ class ParallelBlock(nn.Module):
            hidden_states2: the output of the previous MLP layer (if None, will use hidden_states1).
            residual.
        """
-        dropped1 = self.dropout1(hidden_states1)
-        # For the very 1st block, we only want 1 dropout, not two different dropouts
-        if hidden_states2 is not None:
-            dropped2 = self.dropout2(hidden_states2)
-            residual = ((residual + dropped1 + dropped2)
-                        if residual is not None else dropped1 + dropped2)
+        if not self.fused_dropout_add_ln:
+            dropped1 = self.dropout1(hidden_states1)
+            # For the very 1st block, we only want 1 dropout, not two different dropouts
+            if hidden_states2 is not None:
+                dropped2 = self.dropout2(hidden_states2)
+                residual = ((residual + dropped1 + dropped2)
+                            if residual is not None else dropped1 + dropped2)
+            else:
+                residual = (residual + dropped1) if residual is not None else dropped1
+            hidden_states1 = self.norm1(residual.to(dtype=self.norm1.weight.dtype))
+            hidden_states2 = (self.norm2(residual.to(dtype=self.norm2.weight.dtype))
+                              if not self.tied_norm else hidden_states1)
+            if self.residual_in_fp32:
+                residual = residual.to(torch.float32)
        else:
-            residual = (residual + dropped1) if residual is not None else dropped1
-        hidden_states1 = self.norm1(residual.to(dtype=self.norm1.weight.dtype))
-        hidden_states2 = (self.norm2(residual.to(dtype=self.norm2.weight.dtype))
-                          if not self.tied_norm else hidden_states1)
-        if self.residual_in_fp32:
-            residual = residual.to(torch.float32)
+            weight2, bias2 = ((self.norm2.weight, self.norm2.bias)
+                              if not self.tied_norm else (None, None))
+            hidden_states1, hidden_states2, residual = dropout_add_layer_norm_parallel_residual(
+                hidden_states1, hidden_states2, residual, self.norm1.weight, self.norm1.bias,
+                weight2, bias2, self.dropout1.p if self.training else 0.0, self.norm1.eps,
+                prenorm=True, residual_in_fp32=self.residual_in_fp32
+            )
+            if self.tied_norm:
+                hidden_states2 = hidden_states1
        if mixer_kwargs is None:
            mixer_kwargs = {}
        hidden_states1 = self.mixer(hidden_states1, **mixer_kwargs)
...
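For reference, the fused call in the else-branch above computes the same thing as the non-fused branch. A rough standalone sketch of those semantics in plain PyTorch follows; the function name and argument names are illustrative, not part of the library, and the fused kernel itself keeps the running residual in fp32 internally rather than casting after the fact:

import torch
import torch.nn.functional as F

def parallel_residual_ln_reference(x0, x1, residual, weight0, bias0, weight1, bias1,
                                   dropout_p, eps, residual_in_fp32=False):
    # Dropout both branch outputs, then fold them into the running residual.
    dropped0 = F.dropout(x0, p=dropout_p)
    dropped1 = F.dropout(x1, p=dropout_p) if x1 is not None else None
    total = dropped0 if dropped1 is None else dropped0 + dropped1
    residual = total if residual is None else residual + total
    if residual_in_fp32:
        residual = residual.to(torch.float32)
    # Two LayerNorms over the same residual; with tied norms weight1/bias1 are None
    # and the second output simply reuses the first.
    hidden = weight0.numel()
    out0 = F.layer_norm(residual.to(weight0.dtype), (hidden,), weight0, bias0, eps)
    out1 = (F.layer_norm(residual.to(weight1.dtype), (hidden,), weight1, bias1, eps)
            if weight1 is not None else out0)
    return out0, out1, residual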
@@ -99,6 +99,46 @@ def _dropout_add_layer_norm_subset_backward(dz, dx, x, x0, dmask, mu, rsigma, ga
    return dx0mat, dresidualmat, dgamma, dbeta, dcolscale
+def _dropout_add_layer_norm_parallel_residual_forward(
+    x0, x1, residual, gamma0, beta0, gamma1, beta1, dropout_p,
+    epsilon, residual_in_fp32=False, is_rms_norm=False
+):
+    """ Assume that arguments are contiguous
+    """
+    hidden_size = gamma0.numel()
+    x0mat = x0.view((-1, hidden_size))
+    x1mat = x1.view((-1, hidden_size)) if x1 is not None else None
+    residualmat = residual.view((-1, hidden_size)) if residual is not None else None
+    z0mat, z1mat, xmat, dmask0, dmask1, mu, rsigma = dropout_layer_norm.dropout_add_ln_parallel_residual_fwd(
+        x0mat, x1mat, residualmat, gamma0, beta0, gamma1, beta1, dropout_p, epsilon,
+        None, residual_in_fp32, is_rms_norm
+    )
+    # dmask0 and dmask1 are None if dropout_p == 0.0
+    # xmat is None if dropout_p == 0.0 and residual is None and residual_dtype != input_dtype
+    return z0mat, z1mat, xmat if xmat is not None else x0mat, dmask0, dmask1, mu, rsigma
+
+
+def _dropout_add_layer_norm_parallel_residual_backward(
+    dz0, dz1, dx, x, dmask0, dmask1, mu, rsigma, gamma0, gamma1,
+    dropout_p, has_x1, has_residual, is_rms_norm=False
+):
+    """ Assume that arguments are contiguous
+    dx == None means that it was a post-norm architecture
+    (x = drop(x0) + residual was not returned in the fwd).
+    """
+    hidden_size = gamma0.numel()
+    xmat = x.view((-1, hidden_size))
+    dz0mat = dz0.view(xmat.shape)
+    dz1mat = dz1.view(xmat.shape) if dz1 is not None else None
+    dxmat = dx.view(xmat.shape) if dx is not None else None
+    dx0mat, dx1mat, dresidualmat, dgamma0, dbeta0, dgamma1, dbeta1, *rest = dropout_layer_norm.dropout_add_ln_parallel_residual_bwd(
+        dz0mat, dz1mat, dxmat, xmat, dmask0, dmask1, mu, rsigma, gamma0, gamma1,
+        dropout_p, has_x1, has_residual, is_rms_norm
+    )
+    # dresidualmat is None if not has_residual
+    return dx0mat, dx1mat, dresidualmat, dgamma0, dbeta0, dgamma1, dbeta1
class DropoutAddLayerNormFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x0, residual, gamma, beta, rowscale, colscale, dropout_p, epsilon,
@@ -115,7 +155,7 @@ class DropoutAddLayerNormFn(torch.autograd.Function):
        )
        # Only need to save x0 if we need to compute gradient wrt colscale
        x0_saved = x0 if colscale is not None else None
-        ctx.save_for_backward(xmat.view(x0.shape), x0, dmask, gamma, mu, rsigma, rowscale, colscale)
+        ctx.save_for_backward(xmat.view(x0.shape), x0_saved, dmask, gamma, mu, rsigma, rowscale, colscale)
        ctx.prenorm = prenorm
        ctx.dropout_p = dropout_p
        ctx.has_residual = residual is not None
@@ -168,7 +208,7 @@ class DropoutAddLayerNormSubsetFn(torch.autograd.Function):
        # Only need to save x0 if we need to compute gradient wrt colscale
        x0_saved = x0 if colscale is not None else None
        x_shape = (-1, *x0.shape[1:])
-        ctx.save_for_backward(xmat.view(x_shape), x0, dmask, gamma, mu, rsigma, colscale,
+        ctx.save_for_backward(xmat.view(x_shape), x0_saved, dmask, gamma, mu, rsigma, colscale,
                              x0_subset, out_subset)
        ctx.prenorm = prenorm
        ctx.dropout_p = dropout_p
@@ -208,6 +248,60 @@ class DropoutAddLayerNormSubsetFn(torch.autograd.Function):
                None, None, None, None, None, None, None, None)
+class DropoutAddLayerNormParallelResidualFn(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, x0, x1, residual, gamma0, beta0, gamma1, beta1, dropout_p, epsilon,
+                residual_in_fp32=False, prenorm=False, is_rms_norm=False, return_dmask=False):
+        x0 = x0.contiguous()
+        x1 = x1.contiguous() if x1 is not None else None
+        residual = residual.contiguous() if residual is not None else None
+        gamma0 = gamma0.contiguous()
+        beta0 = beta0.contiguous() if beta0 is not None else None
+        gamma1 = gamma1.contiguous() if gamma1 is not None else None
+        beta1 = beta1.contiguous() if beta1 is not None else None
+        z0mat, z1mat, xmat, dmask0, dmask1, mu, rsigma = _dropout_add_layer_norm_parallel_residual_forward(
+            x0, x1, residual, gamma0, beta0, gamma1, beta1, dropout_p, epsilon,
+            residual_in_fp32, is_rms_norm
+        )
+        ctx.save_for_backward(xmat.view(x0.shape), dmask0, dmask1, gamma0, gamma1, mu, rsigma)
+        ctx.prenorm = prenorm
+        ctx.dropout_p = dropout_p
+        ctx.has_x1 = x1 is not None
+        ctx.has_residual = residual is not None
+        ctx.is_rms_norm = is_rms_norm
+        ctx.has_beta = beta0 is not None
+        z = (z0mat.view(x0.shape), z1mat.view(x0.shape) if z1mat is not None else None)
+        if not return_dmask:
+            return z if not prenorm else (*z, xmat.view(x0.shape))
+        else:
+            dmask0 = (dmask0.view(x0.shape) if dropout_p > 0.
+                      else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device))
+            dmask1 = (dmask1.view(x0.shape) if dropout_p > 0. and x1 is not None
+                      else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device))
+            ctx.mark_non_differentiable(dmask0)
+            ctx.mark_non_differentiable(dmask1)
+            return (*z, dmask0, dmask1) if not prenorm else (*z, xmat.view(x0.shape), dmask0, dmask1)
+
+    @staticmethod
+    def backward(ctx, dz0, dz1, *args):
+        dz0 = dz0.contiguous()  # this happens!
+        dz1 = dz1.contiguous() if dz1 is not None else None
+        dx = args[0].contiguous() if ctx.prenorm else None
+        x, dmask0, dmask1, gamma0, gamma1, mu, rsigma = ctx.saved_tensors
+        dropout_p = ctx.dropout_p
+        has_x1 = ctx.has_x1
+        has_residual = ctx.has_residual
+        dx0mat, dx1mat, dresidualmat, dgamma0, dbeta0, dgamma1, dbeta1 = _dropout_add_layer_norm_parallel_residual_backward(
+            dz0, dz1, dx, x, dmask0, dmask1, mu, rsigma, gamma0, gamma1, dropout_p, has_x1,
+            has_residual, ctx.is_rms_norm
+        )
+        dx0 = dx0mat.view(x.shape)
+        dx1 = dx1mat.view(x.shape) if dx1mat is not None else None
+        dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None
+        return (dx0, dx1, dresidual, dgamma0, dbeta0 if ctx.has_beta else None, dgamma1,
+                dbeta1 if ctx.has_beta else None, None, None, None, None, None, None)
def layer_norm(x, weight, bias, epsilon):
    return DropoutAddLayerNormFn.apply(x, None, weight, bias, None, None, 0.0, epsilon, False)
@@ -237,6 +331,19 @@ def dropout_add_layer_norm_subset(x0, residual, weight, bias, dropout_p, epsilon
    )
+def dropout_add_layer_norm_parallel_residual(
+    x0, x1, residual, weight0, bias0, weight1, bias1, dropout_p, epsilon, prenorm=False,
+    residual_in_fp32=False, return_dropout_mask=False
+):
+    """residual_in_fp32 only has an effect if residual is None.
+    Otherwise residual dtype is residual.dtype.
+    """
+    return DropoutAddLayerNormParallelResidualFn.apply(
+        x0, x1, residual, weight0, bias0, weight1, bias1, dropout_p, epsilon, residual_in_fp32, prenorm,
+        False, return_dropout_mask
+    )
class DropoutAddLayerNorm(torch.nn.Module):
    def __init__(self, hidden_size, prenorm=False, p=0.0, eps=1e-5, residual_in_fp32=False,
                 device=None, dtype=None):
...
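A minimal usage sketch for the new dropout_add_layer_norm_parallel_residual, assuming the dropout_layer_norm CUDA extension is built and a CUDA device is available; shapes and values are illustrative:

import torch
from flash_attn.ops.layer_norm import dropout_add_layer_norm_parallel_residual

batch, seqlen, dim = 2, 1024, 4096
device, dtype = 'cuda', torch.float16
x0 = torch.randn(batch, seqlen, dim, device=device, dtype=dtype)  # e.g. attention branch output
x1 = torch.randn(batch, seqlen, dim, device=device, dtype=dtype)  # e.g. MLP branch output
weight0 = torch.ones(dim, device=device, dtype=dtype)
bias0 = torch.zeros(dim, device=device, dtype=dtype)
weight1 = torch.ones(dim, device=device, dtype=dtype)
bias1 = torch.zeros(dim, device=device, dtype=dtype)

# First block: residual is None. With prenorm=True the updated residual is also
# returned (in fp32 here because residual_in_fp32=True) to be fed to the next block.
out0, out1, residual = dropout_add_layer_norm_parallel_residual(
    x0, x1, None, weight0, bias0, weight1, bias1, 0.1, 1e-5,
    prenorm=True, residual_in_fp32=True
)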
@@ -5,6 +5,7 @@ import torch
from torch.nn import init

from flash_attn.ops.layer_norm import DropoutAddLayerNormFn, DropoutAddLayerNormSubsetFn
+from flash_attn.ops.layer_norm import DropoutAddLayerNormParallelResidualFn

def rms_norm(x, weight, epsilon):
@@ -37,6 +38,19 @@ def dropout_add_rms_norm_subset(x0, residual, weight, bias, dropout_p, epsilon,
    )
+def dropout_add_rms_norm_parallel_residual(
+    x0, x1, residual, weight0, bias0, weight1, bias1,
+    dropout_p, epsilon, prenorm=False, residual_in_fp32=False, return_dropout_mask=False
+):
+    """residual_in_fp32 only has an effect if residual is None.
+    Otherwise residual dtype is residual.dtype.
+    """
+    return DropoutAddLayerNormParallelResidualFn.apply(
+        x0, x1, residual, weight0, bias0, weight1, bias1, dropout_p, epsilon, residual_in_fp32, prenorm,
+        True, return_dropout_mask
+    )
class DropoutAddRMSNorm(torch.nn.Module):
    def __init__(self, hidden_size, prenorm=False, p=0.0, eps=1e-5, residual_in_fp32=False,
                 device=None, dtype=None):
...
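The RMSNorm wrapper above has the same calling convention; since RMSNorm carries no bias, bias0 and bias1 are typically passed as None. A brief sketch under the same assumptions as the previous example (CUDA extension built, illustrative shapes):

import torch
from flash_attn.ops.rms_norm import dropout_add_rms_norm_parallel_residual

dim = 4096
x0 = torch.randn(2, 512, dim, device='cuda', dtype=torch.float16)
x1 = torch.randn_like(x0)
weight0 = torch.ones(dim, device='cuda', dtype=torch.float16)
weight1 = torch.ones_like(weight0)
# bias0/bias1 are None for RMSNorm.
out0, out1, residual = dropout_add_rms_norm_parallel_residual(
    x0, x1, None, weight0, None, weight1, None, 0.1, 1e-6,
    prenorm=True, residual_in_fp32=True
)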
@@ -35,7 +35,7 @@ def test_gpt_neox_optimized(model_name):
    config.use_flash_attn = True
    config.fused_bias_fc = True
    config.fused_mlp = True  # GPT-NeoX-20B uses "gelu_fast"
-    config.fused_dropout_add_ln = False  # We don't support parallel block yet
+    config.fused_dropout_add_ln = True
    config.residual_in_fp32 = True
    model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype)
...
@@ -36,7 +36,7 @@ def test_gptj_optimized(model_name):
    config.use_flash_attn = False  # FlashAttention doesn't support hdim 256 yet
    config.fused_bias_fc = True
    config.fused_mlp = True
-    config.fused_dropout_add_ln = False  # We don't support parallel block yet
+    config.fused_dropout_add_ln = True
    config.residual_in_fp32 = True
    model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype)
...
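With the fused path now enabled in these tests, one quick sanity check (not part of this commit; assumes the CUDA extension is built) is to compare the fused op against plain PyTorch with dropout disabled, which should agree up to fp16 rounding:

import torch
import torch.nn.functional as F
from flash_attn.ops.layer_norm import dropout_add_layer_norm_parallel_residual

torch.manual_seed(0)
dim = 4096
x0 = torch.randn(2, 512, dim, device='cuda', dtype=torch.float16)
x1 = torch.randn_like(x0)
w0 = torch.ones(dim, device='cuda', dtype=torch.float16)
b0 = torch.zeros_like(w0)
w1, b1 = w0.clone(), b0.clone()

# Fused op with dropout_p=0.0 so the comparison is deterministic.
out0, out1, res = dropout_add_layer_norm_parallel_residual(
    x0, x1, None, w0, b0, w1, b1, 0.0, 1e-5, prenorm=True, residual_in_fp32=True
)
# Plain PyTorch reference of the same computation.
res_ref = x0.float() + x1.float()
out0_ref = F.layer_norm(res_ref, (dim,), w0.float(), b0.float(), 1e-5)
print(torch.allclose(res, res_ref, atol=1e-3),
      torch.allclose(out0.float(), out0_ref, atol=1e-2))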