"vscode:/vscode.git/clone" did not exist on "9d173c93e67213bb87c7c4286a5543867bd22bdf"
Unverified commit 4e33a69e authored by Marks101, committed by GitHub

[PyTorch] TransformerLayer: add support for Falcon architecture (#513)



* [PyTorch] TransformerLayer: add parallel_attention_mlp to support Falcon models
Signed-off-by: Markus Schnoes <markus.schnoes@gmx.de>

* [PyTorch] add test for parallel_attention_mlp to test_numerics
Signed-off-by: Markus Schnoes <markus.schnoes@gmx.de>

* [PyTorch] TorchGPT: fix dropout for parallel_attention_mlp

TorchGPT now uses nn.functional.dropout because, depending on the path, there are one or two dropouts.
Signed-off-by: Markus Schnoes <markus.schnoes@gmx.de>

* Apply suggestions from code review
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>

* [PyTorch] test_gpt_accuracy: fix spelling in construction of TorchGPT
Signed-off-by: Markus Schnoes <markus.schnoes@gmx.de>

---------
Signed-off-by: Markus Schnoes <markus.schnoes@gmx.de>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
parent c898ab1b
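
For reference, the Falcon-style block wired up by this change applies self-attention and the MLP to the same input and folds both results into the residual stream with a single dropout, instead of chaining the two sub-blocks. A minimal sketch of the two residual layouts (illustrative only; the module arguments are placeholders, not the Transformer Engine implementation):

import torch.nn as nn

def sequential_block(x, attn_ln, attn, mlp_with_ln, p=0.1, training=True):
    # GPT-style: the MLP consumes the stream already updated by attention,
    # so two separate dropout-add steps are needed.
    x = x + nn.functional.dropout(attn(attn_ln(x)), p=p, training=training)
    x = x + nn.functional.dropout(mlp_with_ln(x), p=p, training=training)
    return x

def parallel_block(x, attn_ln, attn, mlp_with_ln, p=0.1, training=True):
    # Falcon-style: attention and MLP both read the same input x (each behind
    # its own normalization) and a single dropout is applied to their sum.
    branches = attn(attn_ln(x)) + mlp_with_ln(x)
    return x + nn.functional.dropout(branches, p=p, training=training)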
@@ -318,13 +318,12 @@ class TorchLayerNormMLP(nn.Module):
 class TorchGPT(nn.Module):
-    def __init__(self, hidden_size: int, eps: float, num_attention_heads: int):
+    def __init__(self, hidden_size: int, eps: float, num_attention_heads: int, parallel_attention_mlp: bool):
         super().__init__()
         self.ln = nn.LayerNorm(hidden_size, eps=eps)
         self.causal_attn = TorchMHA(hidden_size, num_attention_heads)
         self.ln_mlp = TorchLayerNormMLP(hidden_size, 4 * hidden_size, eps)
-        self.resid_attn_dropout = nn.Dropout(0.1)
-        self.resid_mlp_dropout = nn.Dropout(0.1)
+        self.parallel_attention_mlp = parallel_attention_mlp

     def forward(
         self,
@@ -333,12 +332,17 @@ class TorchGPT(nn.Module):
     ) -> torch.Tensor:
         a = self.ln(x)
         b = self.causal_attn(a, attn_mask)
-        x = x + self.resid_attn_dropout(b)
-        n = self.ln_mlp(x)
-        x = x + self.resid_mlp_dropout(n)
+        if self.parallel_attention_mlp:
+            n = self.ln_mlp(x)
+            x = x + nn.functional.dropout(b + n, p=0.1, training=self.training)
+        else:
+            x = x + nn.functional.dropout(b, p=0.1, training=self.training)
+            n = self.ln_mlp(x)
+            x = x + nn.functional.dropout(n, p=0.1, training=self.training)
         return x

 def _test_e2e_selective_recompute(bs, dtype, config, fp8, fp8_model_params=False, recompute=False):
     reset_rng_states()
     FP8GlobalStateManager.reset()
@@ -619,7 +623,8 @@ def _test_e2e_gpt_accuracy(block, bs, dtype, config):
 @pytest.mark.parametrize("dtype", param_types)
 @pytest.mark.parametrize("bs", batch_sizes)
 @pytest.mark.parametrize("model", model_configs.keys())
-def test_gpt_accuracy(dtype, bs, model):
+@pytest.mark.parametrize("parallel_attention_mlp", all_boolean)
+def test_gpt_accuracy(dtype, bs, model, parallel_attention_mlp):
     config = model_configs[model]

     te_gpt = (
@@ -632,6 +637,7 @@ def test_gpt_accuracy(dtype, bs, model):
             hidden_dropout=0.1,
             fuse_qkv_params=True,
             qkv_weight_interleaved=False,
+            parallel_attention_mlp=parallel_attention_mlp,
         )
         .to(dtype=dtype)
         .cuda()
@@ -643,6 +649,7 @@ def test_gpt_accuracy(dtype, bs, model):
             config.hidden_size,
             config.eps,
             config.num_attention_heads,
+            parallel_attention_mlp=parallel_attention_mlp,
         )
         .to(dtype=dtype)
         .cuda()
...
@@ -441,9 +441,10 @@ def test_sanity_layernorm_mlp(dtype, bs, fp8_recipe, model, skip_wgrad,
 @pytest.mark.parametrize("bias", all_boolean)
 @pytest.mark.parametrize("activation", all_activations)
 @pytest.mark.parametrize("normalization", all_normalizations)
+@pytest.mark.parametrize("parallel_attention_mlp", all_boolean)
 def test_sanity_gpt(dtype, bs, fp8_recipe, model, skip_wgrad,
                     zero_centered_gamma, bias, activation,
-                    normalization):
+                    normalization, parallel_attention_mlp):
     if fp8_recipe is not None and not fp8_available:
         pytest.skip(reason_for_no_fp8)
@@ -473,6 +474,7 @@ def test_sanity_gpt(dtype, bs, fp8_recipe, model, skip_wgrad,
             bias=bias,
             activation=activation,
             normalization=normalization,
+            parallel_attention_mlp=parallel_attention_mlp,
         )
         .to(dtype=dtype)
         .cuda()
...
@@ -115,6 +115,11 @@ class TransformerLayer(torch.nn.Module):
                      if set to `True`, layer normalization is applied on the output side,
                      after the final dropout-add. default behavior is to apply layer
                      normalization on the input side, before the QKV transformation.
+    parallel_attention_mlp: bool, default = `False`
+                           if set to `True`, self-attention and feedforward network are computed
+                           based on the same input (in parallel) instead of sequentially.
+                           Both blocks have an independent normalization.
+                           This architecture is used in `Falcon` models.
     layer_type: {'encoder', 'decoder'}, default = `encoder`
                if set to `decoder`, an additional cross-attn block is added after self-attn.
                This can be used for structures like `T5` Transformer in conjunction with the
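
As a usage note, a Falcon-style layer can be requested directly through this flag. A minimal sketch (the sizes, dtype, and input shape below are made-up example values, not values from this PR):

import torch
import transformer_engine.pytorch as te

# Hypothetical sizes for illustration; parallel_attention_mlp=True selects the
# parallel attention/MLP block added in this PR.
layer = te.TransformerLayer(
    hidden_size=1024,
    ffn_hidden_size=4096,
    num_attention_heads=16,
    hidden_dropout=0.1,
    parallel_attention_mlp=True,
).to(dtype=torch.bfloat16).cuda()

x = torch.randn(128, 4, 1024, dtype=torch.bfloat16, device="cuda")  # [s, b, h]
y = layer(x)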
@@ -224,6 +229,7 @@ class TransformerLayer(torch.nn.Module):
                  sequence_parallel: bool = False,
                  apply_residual_connection_post_layernorm: bool = False,
                  output_layernorm: bool = False,
+                 parallel_attention_mlp: bool = False,
                  layer_type: str = "encoder",
                  drop_path_rate: float = 0.0,
                  set_parallel_mode: bool = False,
@@ -274,6 +280,18 @@ class TransformerLayer(torch.nn.Module):
             apply_residual_connection_post_layernorm
         )

+        if parallel_attention_mlp:
+            assert self.layer_type == "encoder", "parallel_attention requires layer_type='encoder'"
+            assert (
+                not self.apply_residual_connection_post_layernorm
+            ), "parallel_attention and apply_residual_connection_post_layernorm "\
+               "not supported simultaneously."
+            assert (
+                not self.output_layernorm
+            ), "parallel_attention and output_layernorm not supported simultaneously"
+        self.parallel_attention_mlp = parallel_attention_mlp
+
         assert layer_type in LayerTypes, f"layer_type {layer_type} not supported"

         if not fuse_qkv_params:
@@ -336,7 +354,7 @@ class TransformerLayer(torch.nn.Module):
             input_layernorm=not output_layernorm,
             attention_type="self",
             bias=bias,
-            return_bias=True,
+            return_bias=not self.parallel_attention_mlp,
             normalization=normalization,
             device=device,
         )
@@ -370,7 +388,7 @@ class TransformerLayer(torch.nn.Module):
             init_method=init_method,
             output_layer_init_method=output_layer_init_method,
             bias=bias,
-            return_bias=True,
+            return_bias=not self.parallel_attention_mlp,
             sequence_parallel=self.sequence_parallel,
             params_dtype=params_dtype,
             return_layernorm_output=apply_residual_connection_post_layernorm,
@@ -578,41 +596,19 @@ class TransformerLayer(torch.nn.Module):
         if self.apply_residual_connection_post_layernorm and not self.output_layernorm:
             attention_output, attention_bias, residual = self_attention_outputs
-        else:
-            attention_output, attention_bias = self_attention_outputs
-            residual = hidden_states
-
-        # Set BDA func.
-        if self.bias_dropout_fusion:
-            if self.training:
-                bias_dropout_add_func = bias_dropout_add_fused_train
-            else:
-                bias_dropout_add_func = bias_dropout_add_fused_inference
-        else:
-            bias_dropout_add_func = get_bias_dropout_add(self.training)
-
-        # Bias dropoout add.
-        if self.drop_path is None and attention_bias.numel() != 0:
-            with self.bias_dropout_add_exec_handler():
-                bda_output = bias_dropout_add_func(
-                    attention_output, attention_bias, residual, self.hidden_dropout
-                )
-        else:
-            if attention_bias.numel() != 0:
-                attention_output = attention_output + attention_bias
-            out = torch.nn.functional.dropout(
-                attention_output,
-                p=self.hidden_dropout,
-                training=self.training,
-            )
-            if self.drop_path is not None:
-                out = self.drop_path(out)
-            bda_output = residual + out
+            hidden_states = self._bias_dropout_add(
+                attention_output, attention_bias, residual, self.drop_path
+            )
+        elif not self.parallel_attention_mlp:
+            attention_output, attention_bias = self_attention_outputs
+            hidden_states = self._bias_dropout_add(
+                attention_output, attention_bias, hidden_states, self.drop_path
+            )

         # Cross attention.
         if self.layer_type == "decoder":
             inter_attention_outputs = self.inter_attention(
-                bda_output,
+                hidden_states,
                 attention_mask=enc_dec_attn_mask,
                 attn_mask_type=self_attn_mask_type,
                 encoder_output=encoder_output,
@@ -626,49 +622,54 @@ class TransformerLayer(torch.nn.Module):
                 attention_output, attention_bias, residual = inter_attention_outputs
             else:
                 attention_output, attention_bias = inter_attention_outputs
-                residual = bda_output
+                residual = hidden_states

-            if attention_bias.numel() != 0:
-                with self.bias_dropout_add_exec_handler():
-                    bda_output = bias_dropout_add_func(
-                        attention_output, attention_bias, residual, self.hidden_dropout
-                    )
-            else:
-                out = torch.nn.functional.dropout(
-                    attention_output,
-                    p=self.hidden_dropout,
-                    training=self.training,
-                )
-                bda_output = residual + out
+            hidden_states = self._bias_dropout_add(attention_output, attention_bias, residual)

         # MLP.
         mlp_outputs = self.layernorm_mlp(
-            bda_output, is_first_microbatch=is_first_microbatch
+            hidden_states, is_first_microbatch=is_first_microbatch
         )
         if self.apply_residual_connection_post_layernorm:
             mlp_output, mlp_bias, residual = mlp_outputs
+            output = self._bias_dropout_add(mlp_output, mlp_bias, residual, self.drop_path)
+        elif self.parallel_attention_mlp:
+            output = self._bias_dropout_add(
+                self_attention_outputs, mlp_outputs, hidden_states, self.drop_path
+            )
         else:
             mlp_output, mlp_bias = mlp_outputs
-            residual = bda_output
+            output = self._bias_dropout_add(mlp_output, mlp_bias, hidden_states, self.drop_path)

-        # Bias dropoout add.
-        if self.drop_path is None and mlp_bias.numel() != 0:
+        # For BERT like architectures.
+        if self.output_layernorm:
+            output = self.layernorm(output)
+
+        # output: [s, b, h]
+        return output
+
+    def _bias_dropout_add(self, hidden_state, bias, residual, drop_path=None):
+        if drop_path is None and bias.numel() != 0:
+            if self.bias_dropout_fusion:
+                if self.training:
+                    bias_dropout_add_func = bias_dropout_add_fused_train
+                else:
+                    bias_dropout_add_func = bias_dropout_add_fused_inference
+            else:
+                bias_dropout_add_func = get_bias_dropout_add(self.training)
+
             with self.bias_dropout_add_exec_handler():
                 output = bias_dropout_add_func(
-                    mlp_output, mlp_bias, residual, self.hidden_dropout
+                    hidden_state, bias, residual, self.hidden_dropout
                 )
         else:
-            if mlp_bias.numel() != 0:
-                mlp_output = mlp_output + mlp_bias
+            if bias.numel() != 0:
+                hidden_state = hidden_state + bias
             out = torch.nn.functional.dropout(
-                mlp_output, p=self.hidden_dropout, training=self.training
+                hidden_state, p=self.hidden_dropout, training=self.training
             )
-            if self.drop_path is not None:
-                out = self.drop_path(out)
+            if drop_path is not None:
+                out = drop_path(out)
             output = residual + out

-        # For BERT like architectures.
-        if self.output_layernorm:
-            output = self.layernorm(output)
-
-        # output: [s, b, h]
         return output
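
One detail worth tracing: in the parallel branch, `return_bias=not self.parallel_attention_mlp` evaluates to `False`, so `self_attention_outputs` and `mlp_outputs` are plain tensors and `_bias_dropout_add` is reused to compute `hidden_states + dropout(attn_out + mlp_out)`, matching the parallel path of the TorchGPT reference above. An unfused sketch of that branch (names are illustrative; the real code dispatches to the fused bias-dropout-add helpers when possible):

import torch

def parallel_residual(attn_out: torch.Tensor,
                      mlp_out: torch.Tensor,
                      residual: torch.Tensor,
                      hidden_dropout: float = 0.1,
                      training: bool = True) -> torch.Tensor:
    # Single dropout over the summed attention and MLP outputs, then the residual add.
    out = torch.nn.functional.dropout(attn_out + mlp_out, p=hidden_dropout, training=training)
    return residual + out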