Unverified Commit 8aa2da17 authored by Kirthi Shankar Sivamani, committed by GitHub

PyTorch MultiheadAttention API (#387)



* PyTorch MultiheadAttention API

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix ONNX export tests

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Expose MultiheadAttention for import

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Expand mask type and add no mask numerical test

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent f29efb77
@@ -22,6 +22,9 @@ pyTorch
 .. autoapiclass:: transformer_engine.pytorch.DotProductAttention(num_attention_heads, kv_channels, **kwargs)
   :members: forward

+.. autoapiclass:: transformer_engine.pytorch.MultiheadAttention(hidden_size, num_attention_heads, **kwargs)
+  :members: forward
+
 .. autoapiclass:: transformer_engine.pytorch.TransformerLayer(hidden_size, ffn_hidden_size, num_attention_heads, **kwargs)
   :members: forward
......
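The entry added above publishes `MultiheadAttention` in the pyTorch API reference. A minimal usage sketch of the documented signature (the sizes and keyword values here are illustrative, not part of this commit):

```python
import torch
import transformer_engine.pytorch as te

# hidden_size=1024 and num_attention_heads=16 are illustrative values.
mha = te.MultiheadAttention(1024, 16, attn_mask_type="causal").cuda()

# Transformer Engine modules expect [seq_len, batch, hidden] layout.
x = torch.randn(128, 2, 1024, device="cuda")
y = mha(x)  # with attn_mask_type="causal", no explicit attention_mask is needed
```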
@@ -21,7 +21,8 @@ from transformer_engine.pytorch.utils import (
     attention_mask_func,
 )
 from transformer_engine.pytorch import (
-    DotProductAttention, Linear, LayerNormLinear, LayerNormMLP, TransformerLayer, RMSNorm
+    DotProductAttention, LayerNormLinear, LayerNormMLP, Linear,
+    MultiheadAttention, RMSNorm, TransformerLayer
 )
 from transformer_engine.pytorch.distributed import checkpoint as te_checkpoint
@@ -60,6 +61,9 @@ all_activations = ["gelu", "relu", "reglu", "geglu", "swiglu"]
 all_normalizations = ["LayerNorm", "RMSNorm"]

+mask_types = ["causal", "no_mask"]
+
 def get_causal_attn_mask(sq: int) -> torch.Tensor:
     return torch.triu(torch.ones(sq, sq, device="cuda"), diagonal=1).bool()
@@ -320,6 +324,7 @@ class TorchDotProductAttention(torch.nn.Module):
         return context_layer

 # Adapted from https://github.com/bzhangGo/rmsnorm/blob/c6691f20ec0af4128c8159c903071f7575404295/rmsnorm_torch.py
 class TorchRMSNorm(nn.Module):
     def __init__(self, in_features, eps=1e-5):
@@ -341,6 +346,7 @@ class TorchRMSNorm(nn.Module):
         return (self.weight.float() * x_normed).to(x.dtype)

 class TorchLayerNormLinear(nn.Module):
     def __init__(self, in_features: int, out_features: int,
                  eps: float, bias: bool = True,
@@ -371,7 +377,11 @@ class TorchMHA(nn.Module):
         )

     def forward(self, x, attn_mask=None):
-        return self.mhsa(x, x, x, attn_mask=attn_mask, need_weights=False)
+        output = self.mhsa(x, x, x, attn_mask=attn_mask, need_weights=False)
+        if isinstance(output, tuple):
+            output = output[0]
+        return output

 _supported_act = {'geglu' : nn.GELU(approximate="tanh"),
                   'gelu' : nn.GELU(approximate="tanh"),
@@ -379,6 +389,7 @@ _supported_act = {'geglu' : nn.GELU(approximate="tanh"),
                   'relu' : nn.ReLU(),
                   'swiglu' : nn.SiLU()}

 class TorchGLU(nn.Module):
     def __init__(self, activation: str):
         super().__init__()
@@ -391,6 +402,7 @@ class TorchGLU(nn.Module):
         a = self.act(a)
         return a * b

 class TorchLayerNormMLP(nn.Module):
     def __init__(self, hidden_size: int, ffn_hidden_size: int,
                  eps: float = 1e-5, activation = 'gelu',
@@ -431,7 +443,7 @@ class TorchGPT(nn.Module):
                 attn_mask: Optional[torch.Tensor] = None,
                 ) -> torch.Tensor:
         a = self.ln(x)
-        b, _ = self.causal_attn(a, attn_mask)
+        b = self.causal_attn(a, attn_mask)
         x = x + self.resid_attn_dropout(b)
         n = self.ln_mlp(x)
         x = x + self.resid_mlp_dropout(n)
@@ -754,6 +766,75 @@ def test_gpt_accuracy(dtype, bs, model):
         assert_allclose(te_outputs[0], torch_outputs[0], 5e-2)

+def _test_mha_accuracy(block, bs, dtype, config, mask_type):
+    reset_rng_states()
+
+    inp_hidden_states = torch.randn(
+        config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True
+    ).cuda()
+    inp_hidden_states.retain_grad()
+    inp_attn_mask = get_causal_attn_mask(config.seq_len) if mask_type == "causal" else None
+
+    out = block(inp_hidden_states, inp_attn_mask)
+    loss = out.sum()
+    loss.backward()
+
+    torch.cuda.synchronize()
+    outputs = [out, inp_hidden_states.grad]
+    for p in block.parameters():
+        if p.requires_grad:
+            outputs.append(p.grad)
+    return outputs
+
+
+@pytest.mark.parametrize("dtype", param_types)
+@pytest.mark.parametrize("bs", batch_sizes)
+@pytest.mark.parametrize("model", model_configs.keys())
+@pytest.mark.parametrize("mask_type", mask_types)
+def test_mha_accuracy(dtype, bs, model, mask_type):
+    config = model_configs[model]
+
+    te_mha = (
+        MultiheadAttention(
+            config.hidden_size,
+            config.num_attention_heads,
+            fuse_qkv_params=True,
+            qkv_weight_interleaved=False,
+            input_layernorm=False,
+            attn_mask_type=mask_type,
+        )
+        .to(dtype=dtype)
+        .cuda()
+        .eval()
+    )
+
+    torch_mha = (
+        TorchMHA(
+            config.hidden_size,
+            config.num_attention_heads,
+        )
+        .to(dtype=dtype)
+        .cuda()
+        .eval()
+    )
+
+    # Share params
+    with torch.no_grad():
+        torch_mha.mhsa.in_proj_weight = Parameter(te_mha.qkv.weight.clone())
+        torch_mha.mhsa.in_proj_bias = Parameter(te_mha.qkv.bias.clone())
+        torch_mha.mhsa.out_proj.weight = Parameter(te_mha.proj.weight.clone())
+        torch_mha.mhsa.out_proj.bias = Parameter(te_mha.proj.bias.clone())
+
+    te_outputs = _test_mha_accuracy(te_mha, bs, dtype, config, mask_type)
+    torch_outputs = _test_mha_accuracy(torch_mha, bs, dtype, config, mask_type)
+
+    # Check output.
+    if dtype == torch.float32:
+        assert_allclose(te_outputs[0], torch_outputs[0], 5e-3)
+    else:
+        assert_allclose(te_outputs[0], torch_outputs[0], 5e-2)
+
+
 def _test_granular_accuracy(block, bs, dtype, config):
     reset_rng_states()
......
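The numerical test above shares parameters by cloning the TE fused QKV weight directly into `torch.nn.MultiheadAttention.in_proj_weight`. A small sanity sketch of why that mapping works (sizes assumed for illustration): with `fuse_qkv_params=True` and `qkv_weight_interleaved=False`, the fused weight is a plain `[q; k; v]` concatenation along dim 0, which is exactly the layout torch uses for `in_proj_weight`.

```python
import torch
from transformer_engine.pytorch import MultiheadAttention

hidden_size, num_heads = 1024, 16  # illustrative values
te_mha = MultiheadAttention(
    hidden_size, num_heads, fuse_qkv_params=True, qkv_weight_interleaved=False
).cuda()

# One fused parameter holding query, key, and value stacked along dim 0.
assert te_mha.qkv.weight.shape == torch.Size([3 * hidden_size, hidden_size])
```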
@@ -1267,7 +1267,7 @@ def test_export_multihead_attention(
     input_ln_str = "_input-ln" if input_layernorm else ""
     fname = f"te.multihead_attention{fp8_str}{attn_mask_str}{attn_type_str}{input_ln_str}{fuse_qkv_str}{dtype_str}.onnx"

-    model = te.attention.MultiHeadAttention(
+    model = te.MultiheadAttention(
         *attention_args,
         attn_mask_type=attn_mask_type,
         params_dtype=precision,
@@ -1275,6 +1275,7 @@ def test_export_multihead_attention(
         input_layernorm=input_layernorm,
         attention_type=attention_type,
         fuse_qkv_params=fuse_qkv_params,
+        return_bias=True,
     ).to(device='cuda')

     inp_context = (hidden_states_context, attention_mask, encoder_output)
......
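The ONNX test now builds the module via the public `te.MultiheadAttention` name and passes `return_bias=True` to keep the previous tuple-returning behavior, since the default changed in this PR. A hedged export sketch under those assumptions (`te.onnx_export` is the export context re-exported in the `__init__.py` hunk below; shapes and opset here are illustrative):

```python
import torch
import transformer_engine.pytorch as te

model = te.MultiheadAttention(1024, 16, return_bias=True).cuda().eval()
inp = torch.randn(128, 2, 1024, device="cuda")

# Enable TE's ONNX-export mode for the duration of the trace.
with te.onnx_export(True):
    torch.onnx.export(model, (inp,), "mha.onnx", opset_version=15)
```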
@@ -9,6 +9,7 @@ from .module import LayerNormMLP
 from .module import LayerNorm
 from .module import RMSNorm
 from .attention import DotProductAttention
+from .attention import MultiheadAttention
 from .transformer import TransformerLayer
 from .fp8 import fp8_autocast
 from .export import onnx_export
......
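With the import added above, the class becomes reachable from the package root. A one-line sanity sketch:

```python
import transformer_engine.pytorch as te
from transformer_engine.pytorch import MultiheadAttention

# Both import paths resolve to the same class after this change.
assert MultiheadAttention is te.MultiheadAttention
```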
@@ -30,6 +30,7 @@ from transformer_engine.pytorch.utils import (
     attention_mask_func,
     split_tensor_along_dim,
     get_device_compute_capability,
+    get_default_init_method,
 )
 from transformer_engine.pytorch.constants import (
     AttnMaskTypes,
@@ -56,7 +57,7 @@ else:
     from flash_attn.flash_attn_interface import flash_attn_unpadded_func as flash_attn_forward_func # pylint: disable=no-name-in-module,ungrouped-imports

-__all__ = ["DotProductAttention"]
+__all__ = ["DotProductAttention", "MultiheadAttention"]

 def _rotate_half(x: torch.Tensor) -> torch.Tensor:
@@ -1181,20 +1182,132 @@ class DotProductAttention(torch.nn.Module):
         )

-class MultiHeadAttention(torch.nn.Module):
-    """Parallel attention w/o QKV and Proj Gemms
-    BMM1 -> softmax + dropout -> BMM2
+class MultiheadAttention(torch.nn.Module):
+    r"""
+    Multi-head Attention (MHA), including Query,
+    Key, Value and Output projection.
+
+    .. note::
+
+        Argument :attr:`attention_mask` will be ignored in the `forward` call when
+        :attr:`attn_mask_type` is set to `"causal"`.
+
+    Parameters
+    ----------
+    hidden_size : int
+                 size of each input sample.
+    num_attention_heads : int
+                         number of attention heads in the transformer layer.
+    kv_channels : int, default = `None`
+                 number of key-value channels. Defaults to
+                 :attr:`hidden_size` / :attr:`num_attention_heads` if `None`.
+    attention_dropout : float, default = 0.1
+                       dropout probability for the dropout op during multi-head attention.
+    layernorm_epsilon : float, default = 1e-5
+                       a value added to the denominator of layer normalization
+                       for numerical stability.
+    init_method : Callable, default = `None`
+                 used for initializing weights of QKV and FC1 weights in the following way:
+                 `init_method(weight)`. When set to `None`, defaults to
+                 `torch.nn.init.normal_(mean=0.0, std=0.023)`.
+    output_layer_init_method : Callable, default = `None`
+                              used for initializing weights of PROJ and FC2 in the following way:
+                              `output_layer_init_method(weight)`. When set to `None`, defaults to
+                              `torch.nn.init.normal_(mean=0.0, std=0.023)`.
+    layer_number : int, default = `None`
+                  layer number of the current `TransformerLayer` when multiple such modules are
+                  concatenated to form a transformer block.
+    attn_mask_type : {'causal', 'padding', 'no_mask'}, default = `causal`
+                    type of attention mask passed into softmax operation.
+    num_gqa_groups : int, default = `None`
+                    number of GQA groups in the transformer layer.
+                    Grouped Query Attention is described in
+                    `this paper <https://arxiv.org/pdf/2305.13245.pdf>`_.
+                    This only affects the keys and values, not the queries.
+                    GQA-1 is equivalent to Multi-Query Attention
+                    (`MQA <https://arxiv.org/pdf/1911.02150.pdf>`_), while GQA-H
+                    is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`.
+    return_layernorm_output : bool, default = `False`
+                             if set to `True`, output of layernorm is returned from the forward
+                             together with the output of the linear transformation.
+                             Example use case: residual connection for transformer module is
+                             taken post layernorm.
+    input_layernorm : bool, default = `True`
+                     if set to `False`, layer normalization to the input is not applied.
+    attention_type : { 'self', 'cross' }, default = 'self'
+                    type of attention applied.
+    zero_centered_gamma : bool, default = 'False'
+                         if set to 'True', the gamma parameter in LayerNorm is initialized to 0
+                         and the LayerNorm formula changes to
+
+                         .. math::
+                            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} *
+                            (1 + \gamma) + \beta
+    normalization : { 'LayerNorm', 'RMSNorm' }, default = 'LayerNorm'
+                   type of normalization applied.
+    qkv_weight_interleaved : bool, default = `True`
+                            if set to `False`, the QKV weight is interpreted as a concatenation of
+                            query, key, and value weights along the `0th` dimension. The default
+                            interpretation is that the individual `q`, `k`, and `v` weights for each
+                            attention head are interleaved. This parameter is set to `False` when
+                            using :attr:`fuse_qkv_params=False`.
+    bias : bool, default = `True`
+          if set to `False`, the transformer layer will not learn any additive biases.
+    device : Union[torch.device, str], default = "cuda"
+            the device on which the parameters of the model will be allocated. It is the
+            user's responsibility to ensure all parameters are moved to the GPU before
+            running the forward pass.
+
+    Parallelism parameters
+    ----------------------
+    set_parallel_mode : bool, default = `False`
+                       if set to `True`, QKV and FC1 layers are used as Column Parallel
+                       whereas PROJ and FC2 is used as Row Parallel as described
+                       `here <https://arxiv.org/pdf/1909.08053.pdf>`_.
+    sequence_parallel : bool, default = `False`
+                       if set to `True`, uses sequence parallelism.
+    tp_group : ProcessGroup, default = `None`
+              tensor parallel process group.
+    tp_size : int, default = 1
+             used as TP (tensor parallel) world size when TP groups are not formed during
+             initialization. In this case, users must call the
+             `set_tensor_parallel_group(tp_group)` method on the initialized module before the
+             forward pass to supply the tensor parallel group needed for tensor and sequence
+             parallel collectives.
+
+    Optimization parameters
+    -----------------------
+    fuse_wgrad_accumulation : bool, default = 'False'
+                             if set to `True`, enables fusing of creation and accumulation of
+                             the weight gradient. When enabled, it is assumed that the weights
+                             have an additional `main_grad` attribute (used instead of the
+                             regular `grad`) which is a pre-allocated buffer of the correct
+                             size to accumulate gradients in.
+    params_dtype : torch.dtype, default = `torch.get_default_dtype()`
+                  controls the type used to allocate the initial parameters. Useful when
+                  the model is trained with lower precision and the original FP32 parameters
+                  would not fit in GPU memory.
+    return_bias : bool, default = `False`
+                 when set to `True`, this module will not apply the additive bias itself, but
+                 instead return the bias value during the forward pass together with the
+                 output of the linear transformation :math:`y = xA^T`. This is useful when
+                 the bias addition can be fused to subsequent operations.
+    fuse_qkv_params : bool, default = 'False'
+                     if set to `True`, the `TransformerLayer` module exposes a single fused
+                     parameter for query-key-value. This enables optimizations such as QKV
+                     fusion without concatenations/splits and also enables the argument
+                     `fuse_wgrad_accumulation`.
     """
     def __init__(
         self,
         hidden_size: int,
         num_attention_heads: int,
-        kv_channels: int,
-        attention_dropout: float,
-        layernorm_epsilon: float,
-        init_method: Callable,
-        output_layer_init_method: Callable,
+        kv_channels: Optional[int] = None,
+        attention_dropout: float = 0.1,
+        layernorm_epsilon: float = 1e-5,
+        init_method: Optional[Callable] = None,
+        output_layer_init_method: Optional[Callable] = None,
         layer_number: Optional[int] = None,
         attn_mask_type: str = "causal",
         tp_group: Optional[dist_group_type] = None,
@@ -1204,6 +1317,7 @@ class MultiHeadAttention(torch.nn.Module):
         get_rng_state_tracker: Optional[Callable] = None,
         sequence_parallel: bool = False,
         params_dtype: Optional[torch.dtype] = None,
+        return_bias: bool = False,
         return_layernorm_output: bool = False,
         input_layernorm: bool = False,
         attention_type: str = "self",
@@ -1227,9 +1341,16 @@ class MultiHeadAttention(torch.nn.Module):
         self.tp_group = tp_group
         self.return_layernorm_output = return_layernorm_output
         self.params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype
+        self.init_method = init_method
         self.attn_mask_type = attn_mask_type
         self.num_attention_heads = num_attention_heads
+        self.return_bias = return_bias
+        kv_channels = kv_channels if kv_channels else (hidden_size // num_attention_heads)
+        if init_method is None:
+            init_method = get_default_init_method()
+        if output_layer_init_method is None:
+            output_layer_init_method = get_default_init_method()

         if not fuse_qkv_params:
             qkv_weight_interleaved = False
@@ -1358,7 +1479,7 @@ class MultiHeadAttention(torch.nn.Module):
             hidden_size,
             init_method=output_layer_init_method,
             bias=bias,
-            return_bias=True,
+            return_bias=return_bias,
             parallel_mode="row" if set_parallel_mode else None,
             ub_split_rs=ub_split_rs,
             ub_split_ag=ub_split_ag,
@@ -1395,10 +1516,54 @@ class MultiHeadAttention(torch.nn.Module):
         core_attention_bias: Optional[torch.Tensor] = None,
         fast_zero_fill: bool = True,
     ) -> Tuple[Union[torch.Tensor, None], ...]:
-        """MultiHeadAttention FWD"""
+        """
+        Forward propagation for MultiheadAttention layer.
+
+        .. note::
+
+            Argument :attr:`attention_mask` will be ignored when :attr:`attn_mask_type`
+            is set to `"causal"`.
+
+        Parameters
+        ----------
+        hidden_states : torch.Tensor
+                       Input tensor.
+        attention_mask : Optional[torch.Tensor], default = `None`
+                        Boolean tensor used to mask out self-attention softmax input.
+        encoder_output : Optional[torch.Tensor], default = `None`
+                        Output of the encoder block to be fed into the decoder block if using
+                        `layer_type="decoder"`.
+        is_first_microbatch : {True, False, None}, default = None
+                             During training using either gradient accumulation or
+                             pipeline parallelism, a minibatch of data is further split
+                             into microbatches. Between the microbatches of the same minibatch
+                             the model weights are not updated. Setting this parameter indicates
+                             whether the current microbatch is the first in a minibatch or not.
+                             When set, this parameter enables additional optimizations:
+
+                             * during FP8 training, it allows caching of the FP8 versions of
+                               the weights
+                             * it also allows skipping gradient accumulation during the
+                               first microbatch (since it is the first gradient being
+                               produced)
+        checkpoint_core_attention : bool, default = `False`
+                                   If true, forward activations for core attention are recomputed
+                                   during the backward pass in order to save memory that would
+                                   otherwise be occupied to store the forward activations until
+                                   backprop.
+        rotary_pos_emb : Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], default = `None`
+                        Embeddings for query and key tensors for applying rotary position
+                        embedding. By default no input embedding is applied.
+        core_attention_bias_type : str, default = `no_bias`
+                                  Bias type, {`no_bias`, `pre_scale_bias`, `post_scale_bias`}
+        core_attention_bias : Optional[torch.Tensor], default = `None`
+                             Bias tensor for Q * K.T
+        fast_zero_fill : bool, default = `True`
+                        Whether to set output tensors to 0 or not before use.
+        """
+
         # hidden_states: [sq, b, h]

-        if self.attn_mask_type != "causal" and attention_mask is not None:
+        if self.attn_mask_type == "padding" and attention_mask is not None:
             assert (
                 attention_mask.dtype == torch.bool
             ), "Attention mask must be a boolean tensor"
@@ -1604,20 +1769,28 @@ class MultiHeadAttention(torch.nn.Module):
                 key_layer,
                 value_layer,
                 attention_mask,
-                checkpoint_core_attention = checkpoint_core_attention,
-                core_attention_bias_type = core_attention_bias_type,
-                core_attention_bias = core_attention_bias,
-                fast_zero_fill = fast_zero_fill,
+                checkpoint_core_attention=checkpoint_core_attention,
+                core_attention_bias_type=core_attention_bias_type,
+                core_attention_bias=core_attention_bias,
+                fast_zero_fill=fast_zero_fill,
             )

         # =================
         # Output. [sq, b, h]
         # =================

-        attention_output, attention_bias = self.proj(
+        projection_output = self.proj(
             context_layer, is_first_microbatch=is_first_microbatch
         )

+        if self.return_bias:
+            attention_output, attention_bias = projection_output
+        else:
+            attention_output, attention_bias = projection_output, None
+
+        outputs = (attention_output,)
+        if self.return_bias:
+            outputs += (attention_bias,)
         if self.input_layernorm and self.return_layernorm_output:
-            return attention_output, attention_bias, layernorm_output
+            outputs += (layernorm_output,)

-        return attention_output, attention_bias
+        return outputs if len(outputs) > 1 else outputs[0]
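The hunk above changes the forward return contract: instead of always returning `(attention_output, attention_bias)`, the module now returns a single tensor by default and appends the bias and layernorm output only when the corresponding flags are set. A hedged sketch of the new contract (sizes illustrative):

```python
import torch
import transformer_engine.pytorch as te

x = torch.randn(128, 2, 1024, device="cuda")

# Default flags: a single tensor comes back.
mha = te.MultiheadAttention(1024, 16).cuda()
out = mha(x)
assert isinstance(out, torch.Tensor)

# return_bias=True: the unapplied projection bias is appended to the output.
mha_b = te.MultiheadAttention(1024, 16, return_bias=True).cuda()
out, bias = mha_b(x)
```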
@@ -12,7 +12,7 @@ import torch
 import transformer_engine_extensions as tex
 from transformer_engine.pytorch.module import LayerNormMLP, LayerNorm, RMSNorm
-from transformer_engine.pytorch.attention import MultiHeadAttention
+from transformer_engine.pytorch.attention import MultiheadAttention
 from transformer_engine.pytorch.jit import (
     set_jit_fusion_options,
     warmup_jit_bias_dropout_add_all_dtypes,
@@ -323,25 +323,27 @@ class TransformerLayer(torch.nn.Module):
             "ub_split_rs" : ub_split_rs,
         }

-        self.self_attention = MultiHeadAttention(
+        self.self_attention = MultiheadAttention(
             *attention_args,
             **common_attention_kwargs,
             attn_mask_type=self_attn_mask_type,
             input_layernorm=not output_layernorm,
             attention_type="self",
             bias=bias,
+            return_bias=True,
             normalization=normalization,
             device=device,
         )

         if layer_type == "decoder":
-            self.inter_attention = MultiHeadAttention(
+            self.inter_attention = MultiheadAttention(
                 *attention_args,
                 **common_attention_kwargs,
                 attn_mask_type="padding",
                 input_layernorm=True,
                 attention_type="cross",
                 bias=bias,
+                return_bias=True,
                 normalization=normalization,
                 device=device,
             )
......
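`TransformerLayer` opts back into the tuple-returning behavior by passing `return_bias=True`, so the unapplied projection bias can be folded into its residual dropout-add path. A standalone sketch of that consumption pattern (this mirrors, but is not, the actual `TransformerLayer` internals; sizes are illustrative):

```python
import torch
import transformer_engine.pytorch as te

mha = te.MultiheadAttention(1024, 16, return_bias=True).cuda()
x = torch.randn(128, 2, 1024, device="cuda")

# With return_bias=True the projection bias comes back unapplied.
attn_out, attn_bias = mha(x)

# Epilogue in the style of a fused bias-dropout-add on the residual path.
residual = x
out = residual + torch.nn.functional.dropout(attn_out + attn_bias, p=0.1, training=True)
```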