transformer.py 44.4 KB
Newer Older
1
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
Przemek Tredak's avatar
Przemek Tredak committed
2
3
4
5
6
#
# See LICENSE for license information.

"""Transformer."""
import os
7
import warnings
Przemek Tredak's avatar
Przemek Tredak committed
8
from contextlib import nullcontext
9
from typing import Callable, List, Optional, Tuple, Union
Przemek Tredak's avatar
Przemek Tredak committed
10
11
12

import torch

Paweł Gadziński's avatar
Paweł Gadziński committed
13
from transformer_engine.pytorch.torch_version import torch_version
14
from transformer_engine.pytorch.module import LayerNormMLP, LayerNorm, RMSNorm
15
from transformer_engine.debug.pytorch.debug_state import TEDebugState
16
17
from transformer_engine.pytorch.attention.multi_head_attention import MultiheadAttention
from transformer_engine.pytorch.attention.inference import InferenceParams
Przemek Tredak's avatar
Przemek Tredak committed
18
19
20
21
22
23
24
25
26
27
from transformer_engine.pytorch.jit import (
    set_jit_fusion_options,
    warmup_jit_bias_dropout_add_all_dtypes,
    get_bias_dropout_add,
    bias_dropout_add_fused_train,
    bias_dropout_add_fused_inference,
)
from transformer_engine.pytorch.utils import (
    cast_if_needed,
    get_default_init_method,
28
    torch_get_autocast_gpu_dtype,
Przemek Tredak's avatar
Przemek Tredak committed
29
30
31
32
33
34
)
from transformer_engine.pytorch.constants import (
    AttnMaskTypes,
    LayerTypes,
    dist_group_type,
)
35
from transformer_engine.pytorch.distributed import get_distributed_world_size
36
from transformer_engine.pytorch.export import is_in_onnx_export_mode
37
from transformer_engine.pytorch.module.base import TransformerEngineBaseModule
38

Przemek Tredak's avatar
Przemek Tredak committed
39

40
# Install a filter with the "module" action for DeprecationWarning, i.e. report the
# first occurrence per issuing module. The `module` argument is a regular expression
# matched against the start of the issuing module's name, so this filter applies to
# modules whose name begins with "transformer".
warnings.filterwarnings("module", category=DeprecationWarning, module="transformer")
cyanguwa's avatar
cyanguwa committed
41
42


43
# Names exported by a star-import of this module; only TransformerLayer is listed.
__all__ = ["TransformerLayer"]
cyanguwa's avatar
cyanguwa committed
44

Przemek Tredak's avatar
Przemek Tredak committed
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71

class DropPath(torch.nn.Module):
    """Drop paths (Stochastic Depth) per sample
    (when applied in main path of residual blocks).

    Parameters
    ----------
    drop_prob : float, default = 0.0
               probability with which each sample (index along dimension 0 of the
               input) is zeroed out during training. Surviving samples are scaled
               by ``1 / (1 - drop_prob)``.
    """

    def __init__(self, drop_prob: float = 0.0) -> None:
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        """DropPath FWD

        Returns :attr:`hidden_state` unchanged when ``drop_prob`` is 0 or the module
        is in eval mode. Otherwise returns a tensor of the same shape in which each
        slice along dimension 0 is either zeroed or scaled by ``1 / keep_prob``.
        """

        if self.drop_prob == 0.0 or not self.training:
            return hidden_state
        keep_prob = 1 - self.drop_prob
        if keep_prob <= 0.0:
            # Every sample is dropped. Falling through would divide by a zero keep
            # probability and multiply the resulting inf by a zero mask, giving NaN.
            return torch.zeros_like(hidden_state)
        # work with diff dim tensors, not just 2D ConvNets:
        # one random draw per sample, broadcast over all remaining dimensions
        shape = (hidden_state.shape[0],) + (1,) * (hidden_state.ndim - 1)
        random_tensor = keep_prob + torch.rand(
            shape, dtype=hidden_state.dtype, device=hidden_state.device
        )
        random_tensor.floor_()  # binarize: 1 with probability keep_prob, else 0
        output = hidden_state.div(keep_prob) * random_tensor
        return output


class TransformerLayer(torch.nn.Module):
72
    r"""
Przemek Tredak's avatar
Przemek Tredak committed
73
74
75
    TransformerLayer is made up of an attention block and a feedforward network (MLP).
    This standard layer is based on the paper "Attention Is All You Need".

76
    .. note::
77

Paweł Gadziński's avatar
Paweł Gadziński committed
78
79
        Argument :attr:`attention_mask` in the :meth:`forward` call is only used when
        :attr:`self_attn_mask_type` includes ``"padding"`` or ``"arbitrary"``.
80

Przemek Tredak's avatar
Przemek Tredak committed
81
82
83
84
85
86
87
88
    Parameters
    ----------
    hidden_size : int
                 size of each input sample.
    ffn_hidden_size : int
                     intermediate size to which input samples are projected.
    num_attention_heads : int
                         number of attention heads in the transformer layer.
Paweł Gadziński's avatar
Paweł Gadziński committed
89
    num_gqa_groups : int, default = None
90
91
92
93
94
95
                         number of GQA groups in the transformer layer.
                         Grouped Query Attention is described in
                         `this paper <https://arxiv.org/pdf/2305.13245.pdf>`_.
                         This only affects the keys and values, not the queries.
                         GQA-1 is equivalent to Multi-Query Attention
                         (`MQA <https://arxiv.org/pdf/1911.02150.pdf>`_), while GQA-H
Paweł Gadziński's avatar
Paweł Gadziński committed
96
                         is equivalent to MHA, i.e. ``num_gqa_groups = num_attention_heads``.
Przemek Tredak's avatar
Przemek Tredak committed
97
98
99
    layernorm_epsilon : float, default = 1e-5
                       a value added to the denominator of layer normalization
                       for numerical stability.
Paweł Gadziński's avatar
Paweł Gadziński committed
100
    hidden_dropout : float, default = 0.1
Przemek Tredak's avatar
Przemek Tredak committed
101
                   dropout probability for the dropout op after FC2 layer.
Paweł Gadziński's avatar
Paweł Gadziński committed
102
    attention_dropout : float, default = 0.1
Przemek Tredak's avatar
Przemek Tredak committed
103
                      dropout probability for the dropout op during multi-head attention.
Paweł Gadziński's avatar
Paweł Gadziński committed
104
    init_method : Callable, default = None
Przemek Tredak's avatar
Przemek Tredak committed
105
                 used for initializing weights of QKV and FC1 weights in the following way:
Paweł Gadziński's avatar
Paweł Gadziński committed
106
107
108
                 ``init_method(weight)``. When set to ``None``, defaults to
                 ``torch.nn.init.normal_(mean=0.0, std=0.023)``.
    output_layer_init_method : Callable, default = None
Przemek Tredak's avatar
Przemek Tredak committed
109
                              used for initializing weights of PROJ and FC2 in the following way:
Paweł Gadziński's avatar
Paweł Gadziński committed
110
111
112
113
                              ``output_layer_init_method(weight)``. When set to ``None``, defaults to
                              ``torch.nn.init.normal_(mean=0.0, std=0.023)``.
    apply_residual_connection_post_layernorm : bool, default = False
                                              if set to ``True``, residual connections are taken
Przemek Tredak's avatar
Przemek Tredak committed
114
115
                                              from the output of layer norm (default is taken
                                              from input of layer norm)
Paweł Gadziński's avatar
Paweł Gadziński committed
116
117
    layer_number : int, default = None
                 layer number of the current :class:`TransformerLayer` when multiple such modules are
Przemek Tredak's avatar
Przemek Tredak committed
118
                 concatenated to form a transformer block.
Paweł Gadziński's avatar
Paweł Gadziński committed
119
120
    output_layernorm : bool, default = False
                     if set to ``True``, layer normalization is applied on the output side,
Przemek Tredak's avatar
Przemek Tredak committed
121
122
                     after the final dropout-add. default behavior is to apply layer
                     normalization on the input side, before the QKV transformation.
Paweł Gadziński's avatar
Paweł Gadziński committed
123
124
    parallel_attention_mlp : bool, default = False
                           if set to ``True``, self-attention and feedforward network are computed
125
126
127
                           based on the same input (in parallel) instead of sequentially.
                           Both blocks have an independent normalization.
                           This architecture is used in `Falcon` models.
Paweł Gadziński's avatar
Paweł Gadziński committed
128
129
    layer_type : {'encoder', 'decoder'}, default = "encoder"
               if set to ``"decoder"``, an additional cross-attn block is added after self-attn.
Przemek Tredak's avatar
Przemek Tredak committed
130
               This can be used for structures like `T5` Transformer in conjunction with the
Paweł Gadziński's avatar
Paweł Gadziński committed
131
132
               ``"encoder"`` option.
    kv_channels : int, default = None
133
                number of query-key-value channels per attention head. defaults to
Paweł Gadziński's avatar
Paweł Gadziński committed
134
135
                :attr:`hidden_size` / :attr:`num_attention_heads` if ``None``.
    self_attn_mask_type : {'no_mask', 'padding', 'causal', 'padding_causal', 'causal_bottom_right',
136
                        'padding_causal_bottom_right', 'arbitrary'},
Paweł Gadziński's avatar
Paweł Gadziński committed
137
                        default = "causal"
138
                        type of attention mask passed into softmax operation for encoder.
Paweł Gadziński's avatar
Paweł Gadziński committed
139
140
141
                        Overridden by :attr:`self_attn_mask_type` in the :meth:`forward` method.
                        The :meth:`forward` arg is useful for dynamically changing mask types, e.g.
                        a different mask for training and inference. The :meth:`__init__` arg is useful
142
                        for cases involving compilation/tracing, e.g. ONNX export.
Paweł Gadziński's avatar
Paweł Gadziński committed
143
    window_size : Optional[Tuple[int, int]], default = None
144
                sliding window size for local attention in encoder, where query at position i
Paweł Gadziński's avatar
Paweł Gadziński committed
145
146
147
148
149
150
151
152
153
                attends to keys in ``[i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k
                - seqlen_q + window_size[1]]`` inclusive. Special cases ``(-1, -1)`` and ``(-1, 0)`` mean
                no sliding window and causal mask specifically. Both ``"causal"`` and
                ``"causal_bottom_right"`` masks map to :attr:`window_size` = ``(-1, 0)`` and Transformer Engine
                distinguishes them based on :attr:`self_attn_mask_type` or :attr:`enc_dec_attn_mask_type`.
                Similar to :attr:`self_attn_mask_type`, :attr:`window_size` can be overridden by
                :attr:`window_size` in :meth:`forward` as well.
    enc_dec_attn_mask_type : {'no_mask', 'causal', 'padding', 'padding_causal', 'arbitrary'},
                           default = "no_mask"
154
                           type of attention mask passed into softmax operation for decoder.
Paweł Gadziński's avatar
Paweł Gadziński committed
155
    enc_dec_window_size : Optional[Tuple[int, int]], default = None
156
                        sliding window size for local attention in decoder.
Paweł Gadziński's avatar
Paweł Gadziński committed
157
158
    zero_centered_gamma : bool, default = False
                         if set to ``True``, gamma parameter in LayerNorm is initialized to 0 and
159
160
161
162
163
                         the LayerNorm formula changes to

                         .. math::
                            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} *
                            (1 + \gamma) + \beta
164
165
    normalization : { 'LayerNorm', 'RMSNorm' }, default = 'LayerNorm'
                   type of normalization applied.
Paweł Gadziński's avatar
Paweł Gadziński committed
166
167
168
169
170
    qkv_weight_interleaved : bool, default = True
                            if set to ``False``, the QKV weight is interpreted as a concatenation of
                            query, key, and value weights along the ``0th`` dimension. The default
                            interpretation is that the individual ``q``, ``k``, and ``v`` weights for each
                            attention head are interleaved. This parameter is set to ``False`` when
171
                            using :attr:`fuse_qkv_params=False`.
Paweł Gadziński's avatar
Paweł Gadziński committed
172
    rotary_pos_interleaved : bool, default = False
173
                            whether to use interleaved rotary position embeddings.
Paweł Gadziński's avatar
Paweł Gadziński committed
174
175
    bias : bool, default = True
          if set to ``False``, the transformer layer will not learn any additive biases.
176
177
    activation : str, default = 'gelu'
          Type of activation used in MLP block.
Paweł Gadziński's avatar
Paweł Gadziński committed
178
179
180
          Options are: ``'gelu'``, ``'geglu'``, ``'qgelu'``, ``'qgeglu'``, ``'relu'``, ``'reglu'``, ``'srelu'``, ``'sreglu'``,
          ``'silu'``, ``'swiglu'``, and ``'clamped_swiglu'``.
    activation_params : Optional[dict], default = None
181
                        Additional parameters for the activation function.
Paweł Gadziński's avatar
Paweł Gadziński committed
182
183
184
                        At the moment, only used for ``'clamped_swiglu'`` activation which
                        supports ``'limit'`` and ``'alpha'`` parameters. You can set these as
                        ``activation_params={'limit': 7.0, 'alpha': 1.702}``.
185
    device : Union[torch.device, str], default = "cuda"
186
          The device on which the parameters of the model will be allocated. It is the user's
187
188
          responsibility to ensure all parameters are moved to the GPU before running the
          forward pass.
Paweł Gadziński's avatar
Paweł Gadziński committed
189
190
191
192
193
194
195
196
197
    attn_input_format : {'sbhd', 'bshd', 'thd'}, default = 'sbhd'
            This controls whether the dimensions of the
            intermediate hidden states is 'sequence first' (``'sbhd'``), 'batch first' (``'bshd'``),
            or 'token first' (``'thd'``). ``s`` stands for the sequence length, ``b`` batch size,
            ``t`` the total number of tokens, ``h`` the number of heads, ``d`` head size.
            Note that these formats are very closely
            related to the :attr:`qkv_format` parameter in the :class:`MultiheadAttention`
            and :class:`DotProductAttention` modules.
    name : str, default = None
198
        name of the module, currently used for debugging purposes.
Paweł Gadziński's avatar
Paweł Gadziński committed
199
200
    softmax_type : str = {'vanilla', 'off-by-one', 'learnable'}, default = 'vanilla'
                 Softmax type as described in the paper
201
202
                 `Efficient Streaming Language Models with Attention Sinks
                 <https://arxiv.org/pdf/2309.17453v3>`_.
Paweł Gadziński's avatar
Paweł Gadziński committed
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224

                 For a given attention score :math:`S = Q \cdot K^T`, of shape ``[b, h, s_q, s_kv]``:

                 * ``'vanilla'``:

                   .. math::
                      Softmax(S)_{:,:,:,i} = \frac{\exp(S_{:,:,:,i})}{\sum_j \exp(S_{:,:,:,j})}

                 * ``'off-by-one'``:

                   .. math::
                      Softmax(S)_{:,:,:,i} = \frac{\exp(S_{:,:,:,i})}{1 + \sum_j \exp(S_{:,:,:,j})}

                 * ``'learnable'``:

                   .. math::
                      Softmax(S)_{:,h,:,i} = \frac{\exp(S_{:,h,:,i})}{\exp(\alpha_h) + \sum_j \exp(S_{:,h,:,j})}

                   where :math:`\alpha` is a learnable parameter of shape ``[h]``.

                 ``'off-by-one'`` and ``'learnable'`` softmax types are also called sink attention
                 (``'zero sink'`` and ``'learnable sink'``).
ngoyal2707's avatar
ngoyal2707 committed
225

Przemek Tredak's avatar
Przemek Tredak committed
226
227
    Parallelism parameters
    ----------------------
Paweł Gadziński's avatar
Paweł Gadziński committed
228
229
    set_parallel_mode : bool, default = False
                      if set to ``True``, QKV and FC1 layers are used as Column Parallel
Przemek Tredak's avatar
Przemek Tredak committed
230
231
                      whereas PROJ and FC2 are used as Row Parallel as described
                      `here <https://arxiv.org/pdf/1909.08053.pdf>`_.
Paweł Gadziński's avatar
Paweł Gadziński committed
232
233
234
    sequence_parallel : bool, default = False
                       if set to ``True``, uses sequence parallelism.
    tp_group : ProcessGroup, default = None
Przemek Tredak's avatar
Przemek Tredak committed
235
236
237
238
              tensor parallel process group.
    tp_size : int, default = 1
             used as TP (tensor parallel) world size when TP groups are not formed during
             initialization. In this case, users must call the
Paweł Gadziński's avatar
Paweł Gadziński committed
239
             :meth:`set_tensor_parallel_group` method on the initialized module before the
Przemek Tredak's avatar
Przemek Tredak committed
240
241
242
243
244
             forward pass to supply the tensor parallel group needed for tensor and sequence
             parallel collectives.

    Optimization parameters
    -----------------------
Paweł Gadziński's avatar
Paweł Gadziński committed
245
246
    fuse_wgrad_accumulation : bool, default = False
                             if set to ``True``, enables fusing of creation and accumulation of
247
                             the weight gradient. When enabled, it is assumed that the weights
Paweł Gadziński's avatar
Paweł Gadziński committed
248
249
                             have an additional :attr:`main_grad` attribute (used instead of the
                             regular :attr:`grad`) which is a pre-allocated buffer of the correct
250
                             size to accumulate gradients in.
Paweł Gadziński's avatar
Paweł Gadziński committed
251
    params_dtype : torch.dtype, default = torch.get_default_dtype()
Przemek Tredak's avatar
Przemek Tredak committed
252
253
254
                  it controls the type used to allocate the initial parameters. Useful when
                  the model is trained with lower precision and the original FP32 parameters
                  would not fit in GPU memory.
Paweł Gadziński's avatar
Paweł Gadziński committed
255
    seq_length : int
Przemek Tredak's avatar
Przemek Tredak committed
256
257
258
               sequence length of input samples. Needed for JIT Warmup, a technique where jit
               fused functions are warmed up before training to ensure same kernels are used for
               forward propagation and activation recompute phase.
Paweł Gadziński's avatar
Paweł Gadziński committed
259
    micro_batch_size : int
Przemek Tredak's avatar
Przemek Tredak committed
260
261
262
                     batch size per training step. Needed for JIT Warmup, a technique where jit
                     fused functions are warmed up before training to ensure same kernels are
                     used for forward propagation and activation recompute phase.
Paweł Gadziński's avatar
Paweł Gadziński committed
263
    drop_path_rate : float, default = 0.0
Przemek Tredak's avatar
Przemek Tredak committed
264
265
                   when > 0.0, applies stochastic depth per sample in
                   the main path of the residual block.
Paweł Gadziński's avatar
Paweł Gadziński committed
266
267
    fuse_qkv_params : bool, default = False
                    if set to ``True``, :class:`TransformerLayer` module exposes a single fused
Przemek Tredak's avatar
Przemek Tredak committed
268
269
                    parameter for query-key-value. This enables optimizations such as QKV
                    fusion without concatenations/splits and also enables the argument
Paweł Gadziński's avatar
Paweł Gadziński committed
270
271
                    :attr:`fuse_wgrad_accumulation`.
    qk_norm_type : Optional[str], default = None
272
                    type of normalization to apply to query and key tensors.
Paweł Gadziński's avatar
Paweł Gadziński committed
273
274
275
276
                    Options: ``None``, ``'L2Normalization'``, ``'RMSNorm'``, ``'LayerNorm'``. When ``None``, no normalization is applied.
                    When ``'L2Normalization'``, L2 normalization is applied to query and key tensors.
                    When ``'RMSNorm'``, RMS normalization is applied to query and key tensors.
                    When ``'LayerNorm'``, layer normalization is applied to query and key tensors.
277
                    Normalization is applied after RoPE (if applicable) but before attention computation
Paweł Gadziński's avatar
Paweł Gadziński committed
278
                    when ``qk_norm_before_rope`` is ``False``. This follows the e.g. Llama4 approach for
279
                    QK normalization to improve training stability and model performance.
Paweł Gadziński's avatar
Paweł Gadziński committed
280
    qk_norm_eps : float, default = 1e-6
281
                    epsilon value for normalization of query and key tensors.
Paweł Gadziński's avatar
Paweł Gadziński committed
282
283
284
285
                    Only used when ``qk_norm_type`` is not ``None``.
    qk_norm_before_rope : bool, default = False
                    if set to ``True``, query and key normalization is applied before rotary position
                    embedding. When ``False`` (default), normalization is applied after RoPE.
286
287
                    This parameter allows supporting different architectural variants that apply
                    QK normalization at different points.
Przemek Tredak's avatar
Przemek Tredak committed
288
289
290
291
292
293
294
    """

    def __init__(
        self,
        hidden_size: int,
        ffn_hidden_size: int,
        num_attention_heads: int,
295
        num_gqa_groups: Optional[int] = None,
Przemek Tredak's avatar
Przemek Tredak committed
296
297
298
299
300
301
302
        layernorm_epsilon: float = 1e-5,
        hidden_dropout: float = 0.1,
        attention_dropout: float = 0.1,
        init_method: Optional[Callable] = None,
        output_layer_init_method: Optional[Callable] = None,
        layer_number: Optional[int] = None,
        kv_channels: Optional[int] = None,
303
        self_attn_mask_type: str = "causal",
304
        window_size: Optional[Tuple[int, int]] = None,
305
306
        enc_dec_attn_mask_type: str = "no_mask",
        enc_dec_window_size: Optional[Tuple[int, int]] = None,
Przemek Tredak's avatar
Przemek Tredak committed
307
308
        tp_group: Optional[dist_group_type] = None,
        tp_size: int = 1,
309
        params_dtype: Optional[torch.dtype] = None,
Przemek Tredak's avatar
Przemek Tredak committed
310
311
312
313
314
315
316
        get_rng_state_tracker: Optional[Callable] = None,
        fuse_wgrad_accumulation: bool = False,
        seq_length: Optional[int] = None,
        micro_batch_size: Optional[int] = None,
        sequence_parallel: bool = False,
        apply_residual_connection_post_layernorm: bool = False,
        output_layernorm: bool = False,
317
        parallel_attention_mlp: bool = False,
Przemek Tredak's avatar
Przemek Tredak committed
318
319
320
321
        layer_type: str = "encoder",
        drop_path_rate: float = 0.0,
        set_parallel_mode: bool = False,
        fuse_qkv_params: bool = False,
322
        rotary_pos_interleaved: bool = False,
323
        zero_centered_gamma: bool = False,
324
        qkv_weight_interleaved: bool = True,
325
        ub_tp_comm_overlap: bool = False,
326
327
        ub_overlap_ag: bool = True,
        ub_overlap_rs: bool = True,
Jaemin Choi's avatar
Jaemin Choi committed
328
        ub_overlap_rs_dgrad: bool = False,
329
330
        ub_bulk_dgrad: bool = True,
        ub_bulk_wgrad: bool = True,
ngoyal2707's avatar
ngoyal2707 committed
331
        bias: bool = True,
332
        activation: str = "gelu",
333
        activation_params: Optional[dict] = None,
334
        normalization: str = "LayerNorm",
335
        device: Union[torch.device, str] = "cuda",
336
        attn_input_format: str = "sbhd",
337
        name: str = None,
338
        qk_norm_type: Optional[str] = None,
339
        qk_norm_eps: float = 1e-6,
340
        qk_norm_before_rope: bool = False,
341
        softmax_type: str = "vanilla",
Przemek Tredak's avatar
Przemek Tredak committed
342
343
344
    ) -> None:
        super().__init__()

345
        self.self_attn_mask_type = self_attn_mask_type
346
        self.window_size = window_size
347
        self.enc_dec_attn_mask_type = enc_dec_attn_mask_type
348
        self.enc_dec_window_size = enc_dec_window_size
349
        params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype
350
351
        ub_bulk_wgrad = ub_tp_comm_overlap and ub_bulk_wgrad
        ub_bulk_dgrad = ub_tp_comm_overlap and ub_bulk_dgrad
352
353
        ub_overlap_ag = ub_tp_comm_overlap and ub_overlap_ag
        ub_overlap_rs = ub_tp_comm_overlap and ub_overlap_rs
Jaemin Choi's avatar
Jaemin Choi committed
354
        ub_overlap_rs_dgrad = ub_tp_comm_overlap and ub_overlap_rs_dgrad
355

Przemek Tredak's avatar
Przemek Tredak committed
356
357
358
359
        bias_dropout_fusion = bool(int(os.getenv("NVTE_BIAS_DROPOUT_FUSION", "1")))
        self.layer_number = layer_number
        self.output_layernorm = output_layernorm
        self.layer_type = layer_type
360
        self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm
361

362
363
        if parallel_attention_mlp:
            assert self.layer_type == "encoder", "parallel_attention requires layer_type='encoder'"
364
365
366
367
            assert not self.apply_residual_connection_post_layernorm, (
                "parallel_attention and apply_residual_connection_post_layernorm "
                "not supported simultaneously."
            )
368
369
370
371
372
373
            assert (
                not self.output_layernorm
            ), "parallel_attention and output_layernorm not supported simultaneously"

        self.parallel_attention_mlp = parallel_attention_mlp

Przemek Tredak's avatar
Przemek Tredak committed
374
375
376
377
378
379
380
        assert layer_type in LayerTypes, f"layer_type {layer_type} not supported"

        if not fuse_qkv_params:
            assert (
                not fuse_wgrad_accumulation
            ), "Gradient accumulation fusion requires single QKV parameter."

381
382
383
        if not fuse_qkv_params:
            qkv_weight_interleaved = False

384
        self.kv_channels = kv_channels if kv_channels else (hidden_size // num_attention_heads)
Przemek Tredak's avatar
Przemek Tredak committed
385
386
387
388
389
390

        if init_method is None:
            init_method = get_default_init_method()
        if output_layer_init_method is None:
            output_layer_init_method = get_default_init_method()

391
392
393
        self.tp_size = tp_size if tp_group is None else get_distributed_world_size(tp_group)
        self.sequence_parallel = (self.tp_size > 1) and sequence_parallel
        self.seq_length = seq_length
Przemek Tredak's avatar
Przemek Tredak committed
394
395
396

        self.get_rng_state_tracker = get_rng_state_tracker

397
        self.attn_input_format = attn_input_format
398
        self.softmax_type = softmax_type
399

400
401
        self.name = name

Przemek Tredak's avatar
Przemek Tredak committed
402
403
404
405
406
407
408
409
410
411
412
413
        attention_args = (
            hidden_size,
            num_attention_heads,
            self.kv_channels,
            attention_dropout,
            layernorm_epsilon,
            init_method,
            output_layer_init_method,
        )
        common_attention_kwargs = {
            "layer_number": layer_number,
            "tp_group": tp_group,
414
            "tp_size": self.tp_size,
415
            "num_gqa_groups": num_gqa_groups,
Przemek Tredak's avatar
Przemek Tredak committed
416
417
418
419
420
421
422
            "fuse_wgrad_accumulation": fuse_wgrad_accumulation,
            "get_rng_state_tracker": get_rng_state_tracker,
            "sequence_parallel": self.sequence_parallel,
            "params_dtype": params_dtype,
            "return_layernorm_output": apply_residual_connection_post_layernorm,
            "set_parallel_mode": set_parallel_mode,
            "fuse_qkv_params": fuse_qkv_params,
cyanguwa's avatar
cyanguwa committed
423
            "zero_centered_gamma": zero_centered_gamma,
424
            "qkv_weight_interleaved": qkv_weight_interleaved,
425
            "rotary_pos_interleaved": rotary_pos_interleaved,
426
427
428
429
430
431
            "ub_bulk_wgrad": ub_bulk_wgrad,
            "ub_bulk_dgrad": ub_bulk_dgrad,
            "ub_overlap_ag": ub_overlap_ag,
            "ub_overlap_rs": ub_overlap_rs,
            "ub_overlap_rs_dgrad": ub_overlap_rs_dgrad,
            "qkv_format": self.attn_input_format,
432
433
            "seq_length": seq_length,
            "micro_batch_size": micro_batch_size,
434
            "softmax_type": self.softmax_type,
Przemek Tredak's avatar
Przemek Tredak committed
435
436
        }

437
        self.self_attention = MultiheadAttention(
Przemek Tredak's avatar
Przemek Tredak committed
438
439
440
441
            *attention_args,
            **common_attention_kwargs,
            input_layernorm=not output_layernorm,
            attention_type="self",
ngoyal2707's avatar
ngoyal2707 committed
442
            bias=bias,
443
            return_bias=not self.parallel_attention_mlp,
444
            normalization=normalization,
445
            device=device,
446
            qk_norm_type=qk_norm_type,
447
            qk_norm_eps=qk_norm_eps,
448
            qk_norm_before_rope=qk_norm_before_rope,
449
            name=name + ".self_attention" if name is not None else None,
Przemek Tredak's avatar
Przemek Tredak committed
450
451
452
        )

        if layer_type == "decoder":
453
            self.inter_attention = MultiheadAttention(
Przemek Tredak's avatar
Przemek Tredak committed
454
455
                *attention_args,
                **common_attention_kwargs,
456
                attn_mask_type=enc_dec_attn_mask_type,
Przemek Tredak's avatar
Przemek Tredak committed
457
458
                input_layernorm=True,
                attention_type="cross",
ngoyal2707's avatar
ngoyal2707 committed
459
                bias=bias,
460
                return_bias=True,
461
                normalization=normalization,
462
                device=device,
463
                qk_norm_type=qk_norm_type,
464
                qk_norm_eps=qk_norm_eps,
465
                qk_norm_before_rope=qk_norm_before_rope,
466
                name=name + ".inter_attention" if name is not None else None,
Przemek Tredak's avatar
Przemek Tredak committed
467
468
            )

469
        # LayerNorm -> activation(Linear + Bias) -> Linear
Przemek Tredak's avatar
Przemek Tredak committed
470
471
        # parallel_mode not supported for LayerNormMLP,
        # FC1 is CPL and FC2 is RPL
472
473
        # In the case of GLU activation, FC1 handles both
        # Linear layers before the activation
Przemek Tredak's avatar
Przemek Tredak committed
474
475
476
477
478
479
        self.layernorm_mlp = LayerNormMLP(
            hidden_size,
            ffn_hidden_size,
            eps=layernorm_epsilon,
            fuse_wgrad_accumulation=fuse_wgrad_accumulation,
            tp_group=tp_group,
480
            tp_size=self.tp_size,
Przemek Tredak's avatar
Przemek Tredak committed
481
482
483
            get_rng_state_tracker=get_rng_state_tracker,
            init_method=init_method,
            output_layer_init_method=output_layer_init_method,
ngoyal2707's avatar
ngoyal2707 committed
484
            bias=bias,
485
            return_bias=not self.parallel_attention_mlp,
Przemek Tredak's avatar
Przemek Tredak committed
486
487
488
489
490
491
            sequence_parallel=self.sequence_parallel,
            params_dtype=params_dtype,
            return_layernorm_output=apply_residual_connection_post_layernorm,
            seq_length=seq_length,
            micro_batch_size=micro_batch_size,
            set_parallel_mode=set_parallel_mode,
492
            zero_centered_gamma=zero_centered_gamma,
493
494
            ub_bulk_wgrad=ub_bulk_wgrad,
            ub_bulk_dgrad=ub_bulk_dgrad,
Jaemin Choi's avatar
Jaemin Choi committed
495
            ub_overlap_rs_dgrad=ub_overlap_rs_dgrad,
496
497
            ub_overlap_rs=ub_overlap_rs,
            ub_overlap_ag=ub_overlap_ag,
498
            activation=activation,
499
            activation_params=activation_params,
500
            normalization=normalization,
501
            device=device,
502
            name=name + ".layernorm_mlp" if name is not None else None,
Przemek Tredak's avatar
Przemek Tredak committed
503
504
505
506
507
508
509
        )

        self.hidden_dropout = hidden_dropout
        self.bias_dropout_fusion = bias_dropout_fusion
        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 else None

        # Set bias+dropout+add fusion grad_enable execution handler.
510
        use_nvfuser = torch_version() >= (1, 10, 0) and torch_version() < (2, 2, 0)
511
        self.bias_dropout_add_exec_handler = nullcontext if use_nvfuser else torch.enable_grad
Przemek Tredak's avatar
Przemek Tredak committed
512
513
514
515
516

        if self.bias_dropout_fusion:
            set_jit_fusion_options()
            if seq_length and micro_batch_size:
                if self.sequence_parallel:
517
                    seq_length = seq_length // self.tp_size
518
                warmup_jit_bias_dropout_add_all_dtypes(hidden_size, seq_length, micro_batch_size)
Przemek Tredak's avatar
Przemek Tredak committed
519

520
        norm_module = {
521
522
            "LayerNorm": LayerNorm,
            "RMSNorm": RMSNorm,
523
        }
Przemek Tredak's avatar
Przemek Tredak committed
524
        if self.output_layernorm:
525
            self.layernorm = norm_module[normalization](
Przemek Tredak's avatar
Przemek Tredak committed
526
527
528
529
                hidden_size,
                eps=layernorm_epsilon,
                sequence_parallel=self.sequence_parallel,
                params_dtype=params_dtype,
530
531
                zero_centered_gamma=zero_centered_gamma,
                device=device,
Przemek Tredak's avatar
Przemek Tredak committed
532
533
534
            )

    def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> None:
        """
        Hand the tensor parallel group to every submodule
        ahead of the forward pass.

        Parameters
        ----------
        tp_group : ProcessGroup or None
                  tensor parallel process group, forwarded unchanged to each
                  submodule that exposes ``set_tensor_parallel_group``.
        """
        descendants = iter(self.modules())
        # The deep traversal starts with this layer itself; drop that first entry
        # so the call below never recurses back into this method.
        next(descendants, None)
        for submodule in descendants:
            if not hasattr(submodule, "set_tensor_parallel_group"):
                continue
            submodule.set_tensor_parallel_group(tp_group)

551
552
553
554
555
556
557
558
559
    def reset_fp8_meta_tensors(self) -> None:
        """Invoke ``reset_fp8_meta_tensors`` on every submodule that defines it."""
        descendants = iter(self.modules())
        # The deep traversal starts with this layer itself; drop that first entry
        # so the call below never recurses back into this method.
        next(descendants, None)
        for submodule in descendants:
            if not hasattr(submodule, "reset_fp8_meta_tensors"):
                continue
            submodule.reset_fp8_meta_tensors()

560
    def set_context_parallel_group(
        self,
        cp_group: Union[dist_group_type, List[dist_group_type], None],
        cp_global_ranks: List[int],
        cp_stream: torch.cuda.Stream,
        cp_comm_type: str = "p2p",
    ) -> None:
        r"""
        Hand the context parallel attributes to every submodule
        ahead of the forward pass.

        Parameters
        ----------
        cp_group : Union[ProcessGroup, List[ProcessGroup]]
                  context parallel process group.
                  A single ProcessGroup is used when cp_comm_type is ``"p2p"``,
                  ``"all_gather"`` or ``"a2a"``. A List[ProcessGroup] is used when
                  cp_comm_type is ``"a2a+p2p"``; ``cp_group[0]`` then serves the a2a
                  communication and ``cp_group[1]`` the p2p communication.
        cp_global_ranks : List[int]
                         list of global ranks in the context group.
        cp_stream : torch.cuda.Stream
                   cuda stream for context parallel execution.
        cp_comm_type : str, default = "p2p"
                      inter-gpu communication type for context parallelism.
                      One of ``"p2p"``, ``"all_gather"``, ``"a2a"`` or ``"a2a+p2p"``.

                      - ``"p2p"``: KV chunks are exchanged with P2P communications in a
                        ring topology. P2P is async and can be overlapped with attention
                        compute.
                      - ``"all_gather"``: the full KV sequence is all-gathered before
                        attention. The all-gather is not async, and cannot be overlapped.
                      - ``"a2a"``: like DeepSpeed Ulysses, attention heads are scattered
                        across the CP group, and gathered to get the full sequence of QKV.
                      - ``"a2a+p2p"``: hierarchical CP implementation. a2a is first applied
                        to QKV across each CP sub-group (e.g., via NVLink), then KV is
                        exchanged with p2p between sub-groups (e.g., via IBLink).
        """
        descendants = iter(self.modules())
        # The deep traversal starts with this layer itself; drop that first entry
        # so the call below never recurses back into this method.
        next(descendants, None)
        for submodule in descendants:
            if not hasattr(submodule, "set_context_parallel_group"):
                continue
            submodule.set_context_parallel_group(
                cp_group, cp_global_ranks, cp_stream, cp_comm_type
            )
602

Przemek Tredak's avatar
Przemek Tredak committed
603
604
605
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        self_attn_mask_type: Optional[str] = None,
        window_size: Optional[Tuple[int, int]] = None,
        encoder_output: Optional[torch.Tensor] = None,
        enc_dec_attn_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
        enc_dec_attn_mask_type: Optional[str] = None,
        enc_dec_window_size: Optional[Tuple[int, int]] = None,
        is_first_microbatch: Optional[bool] = None,
        checkpoint_core_attention: bool = False,
        inference_params: Optional[InferenceParams] = None,
        rotary_pos_emb: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
        core_attention_bias_type: str = "no_bias",
        core_attention_bias: Optional[torch.Tensor] = None,
        alibi_slopes: Optional[torch.Tensor] = None,
        cu_seqlens_q: Optional[torch.Tensor] = None,
        cu_seqlens_kv: Optional[torch.Tensor] = None,
        cu_seqlens_q_padded: Optional[torch.Tensor] = None,
        cu_seqlens_kv_padded: Optional[torch.Tensor] = None,
        max_seqlen_q: Optional[int] = None,
        max_seqlen_kv: Optional[int] = None,
        fast_zero_fill: bool = True,
        pad_between_seqs: Optional[bool] = None,
    ) -> torch.Tensor:
        r"""
        Run one transformer layer: an attention block followed by a feedforward
        network (MLP).

        .. note::

            :attr:`attention_mask` is only consulted when :attr:`self_attn_mask_type`
            includes ``"padding"`` or is ``"arbitrary"``.

        Parameters
        ----------
        hidden_states : torch.Tensor
            Input tensor.
        attention_mask : Optional[torch.Tensor], default = None
            Boolean tensor that masks out self-attention softmax input. Padding masks
            should be in ``[batch_size, 1, 1, seqlen_q]``; an ``"arbitrary"`` mask should
            be broadcastable to ``[batch_size, num_heads, max_seqlen_q, max_seqlen_kv]``.
            It should be ``None`` for causal masks and the ``"no_mask"`` type.
            ``True`` masks the corresponding position out, ``False`` lets that
            position participate in attention.
        self_attn_mask_type: {'no_mask', 'causal', 'padding', 'padding_causal',
            'causal_bottom_right', 'padding_causal_bottom_right','arbitrary'},
            default = None
            Type of attention mask passed into softmax operation for encoder.
            When ``None``, the mask type stored on the layer is used.
            Causal masks are aligned to the top left corner of the softmax matrix
            by default; mask types containing ``"bottom_right"`` align them to the
            bottom right corner instead.
        window_size: Optional[Tuple[int, int]], default = None
            Sliding window size for local attention in encoder.
            When ``None``, the window size stored on the layer is used.
        encoder_output : Optional[torch.Tensor], default = None
            Output of the encoder block to be fed into the decoder block if using
            :attr:`layer_type` = ``"decoder"``.
        enc_dec_attn_mask : Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]],
            default = None. Boolean tensors that mask out inter-attention softmax input if
            using :attr:`layer_type` = ``"decoder"``. For padding masks it should be a tuple
            of two masks in ``[batch_size, 1, 1, seqlen_q]`` and ``[batch_size, 1, 1, seqlen_kv]``.
            For an ``"arbitrary"`` mask it should be broadcastable to
            ``[batch_size, num_heads, max_seqlen_q, max_seqlen_kv]``. It should be ``None``
            for causal masks and ``"no_mask"``. ``True`` masks the corresponding position
            out, ``False`` lets that position participate in attention.
        enc_dec_attn_mask_type: {'no_mask', 'causal', 'padding', 'padding_causal', 'arbitrary'},
            default = None
            Type of attention mask passed into softmax operation for decoder.
            When ``None``, the mask type stored on the layer is used.
        enc_dec_window_size: Optional[Tuple[int, int]], default = None
            Sliding window size for local attention in decoder.
            When ``None``, the window size stored on the layer is used.
        is_first_microbatch : {True, False, None}, default = None
            During training using either gradient accumulation or
            pipeline parallelism a minibatch of data is further split
            into microbatches. Between the microbatches of the same minibatch
            the model weights are not updated. Setting this parameter indicates
            whether the current microbatch is the first in a minibatch or not.
            When set, this parameter enables additional optimizations:

            * during FP8 training, it allows caching of the FP8 versions of
              the weights
            * it also allows skipping gradient accumulation during the
              first microbatch (since it is the first gradient being
              produced)
        checkpoint_core_attention: bool, default = False
            If ``True``, forward activations for core attention are recomputed
            during the backward pass in order to save memory that would
            otherwise be occupied to store the forward activations until
            backprop.
        inference_params: InferenceParams, default = None
            Inference parameters that are passed to the main model in order
            to efficiently calculate and store the context during inference.
        rotary_pos_emb: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], default = None
            Embeddings for query and key tensors for applying rotary position
            embedding. By default no input embedding is applied.
        core_attention_bias_type: str, default = "no_bias"
            Bias type, {``"no_bias"``, ``"pre_scale_bias"``, ``"post_scale_bias"``, ``"alibi"``}
        core_attention_bias: Optional[torch.Tensor], default = None
            Bias tensor for :math:`Q \cdot K^T`
        alibi_slopes: Optional[torch.Tensor], default = None
            ALiBi slopes in FP32 and shape ``[nheads]`` or ``[batch_size, nheads]``.
            It adds a bias of :math:`(-\text{alibi_slope} \cdot (i + \text{seqlen_k} - \text{seqlen_q} - j))`
            to the attention score of query i and key j.
        cu_seqlens_q: Optional[torch.Tensor], default = None
            Cumulative sum of sequence lengths (without offset) in a batch for query layer,
            with shape ``[batch_size + 1]`` and dtype torch.int32.
            Used by encoders, or decoders' self-attention.
        cu_seqlens_kv: Optional[torch.Tensor], default = None
            Cumulative sum of sequence lengths (without offset) in a batch for key layer
            and value layer, with shape ``[batch_size + 1]`` and dtype torch.int32.
            Used by decoders' cross-attention.
        cu_seqlens_q_padded: Optional[torch.Tensor], default = None
            Cumulative sum of sequence lengths (with offset) in a batch for query layer,
            with shape ``[batch_size + 1]`` and dtype torch.int32. Set to :attr:`cu_seqlens_q` if ``None``.
            Used by encoders, or decoders' self-attention.
        cu_seqlens_kv_padded: Optional[torch.Tensor], default = None
            Cumulative sum of sequence lengths (with offset) in a batch for key layer
            and value layer, with shape ``[batch_size + 1]`` and dtype torch.int32.
            Set to :attr:`cu_seqlens_kv` if ``None``. Used by decoders' cross-attention.
        max_seqlen_q: Optional[int], default = None
            Maximum sequence length in query layer.
            Calculated from :attr:`cu_seqlens_q_padded` if not provided.
        max_seqlen_kv: Optional[int], default = None
            Maximum sequence length in key layer and value layer.
            Calculated from :attr:`cu_seqlens_kv_padded` if not provided.
        fast_zero_fill: bool, default = True
            Whether to set output tensors to 0 or not before use.
        pad_between_seqs: Optional[bool], default = None
            If ``None``, inferred from :attr:`qkv_format`, cu_seqlens and cu_seqlens_padded.
            If ``True``, there are padding tokens between individual sequences in a packed batch,
            i.e. :attr:`qkv_format` = ``'thd'``.

        Returns
        -------
        torch.Tensor
            Result of the MLP residual branch, passed through the output
            layernorm when the layer is configured with one.
        """

        # Arguments left as None fall back to the values stored on the layer.
        self_attn_mask_type = (
            self.self_attn_mask_type if self_attn_mask_type is None else self_attn_mask_type
        )
        window_size = self.window_size if window_size is None else window_size
        enc_dec_attn_mask_type = (
            self.enc_dec_attn_mask_type
            if enc_dec_attn_mask_type is None
            else enc_dec_attn_mask_type
        )
        enc_dec_window_size = (
            self.enc_dec_window_size if enc_dec_window_size is None else enc_dec_window_size
        )

        assert (
            self_attn_mask_type in AttnMaskTypes
        ), f"self_attn_mask_type {self_attn_mask_type} not supported"
        assert (
            enc_dec_attn_mask_type in AttnMaskTypes
        ), f"enc_dec_attn_mask_type {enc_dec_attn_mask_type} not supported"

        hidden_states = hidden_states.contiguous()

        if self.sequence_parallel and self.seq_length is not None:
            local_seq_length = self.seq_length // self.tp_size
            assert (
                hidden_states.shape[0] == local_seq_length
            ), "Sequence dimension must be split across TP group when using sequence parallel."

        # Masks are only type-checked for the mask types that actually consume them.
        if attention_mask is not None and (
            "padding" in self_attn_mask_type or self_attn_mask_type == "arbitrary"
        ):
            assert all(
                attention_mask[i].dtype == torch.bool for i in range(len(attention_mask))
            ), "Attention mask must be a boolean tensor or a list/tuple of two boolean tensors"
        if enc_dec_attn_mask is not None and (
            "padding" in enc_dec_attn_mask_type or enc_dec_attn_mask_type == "arbitrary"
        ):
            assert all(
                enc_dec_attn_mask[i].dtype == torch.bool for i in range(len(enc_dec_attn_mask))
            ), "Encoder-decoder attention mask must be boolean tensor(s)"

        if TEDebugState.debug_enabled:
            TransformerEngineBaseModule._validate_name(self)

        # For AMP
        if torch.is_autocast_enabled():
            hidden_states = cast_if_needed(hidden_states, torch_get_autocast_gpu_dtype())

        # Keyword arguments handed unchanged to both the self- and cross-attention blocks.
        shared_attention_kwargs = {
            "inference_params": inference_params,
            "is_first_microbatch": is_first_microbatch,
            "checkpoint_core_attention": checkpoint_core_attention,
            "rotary_pos_emb": rotary_pos_emb,
            "core_attention_bias_type": core_attention_bias_type,
            "core_attention_bias": core_attention_bias,
            "alibi_slopes": alibi_slopes,
            "cu_seqlens_q": cu_seqlens_q,
            "cu_seqlens_q_padded": cu_seqlens_q_padded,
            "max_seqlen_q": max_seqlen_q,
            "fast_zero_fill": fast_zero_fill,
            "pad_between_seqs": pad_between_seqs,
        }

        # Self attention. The key/value sequence descriptors reuse the query-side values.
        self_attention_outputs = self.self_attention(
            hidden_states,
            attention_mask=attention_mask,
            attn_mask_type=self_attn_mask_type,
            window_size=window_size,
            cu_seqlens_kv=cu_seqlens_q,
            cu_seqlens_kv_padded=cu_seqlens_q_padded,
            max_seqlen_kv=max_seqlen_q,
            **shared_attention_kwargs,
        )

        if self.apply_residual_connection_post_layernorm and not self.output_layernorm:
            # The attention block also returns the tensor to use as the residual.
            attn_out, attn_bias, attn_residual = self_attention_outputs
            hidden_states = self._bias_dropout_add(
                attn_out, attn_bias, attn_residual, self.drop_path
            )
        elif not self.parallel_attention_mlp:
            attn_out, attn_bias = self_attention_outputs
            hidden_states = self._bias_dropout_add(
                attn_out, attn_bias, hidden_states, self.drop_path
            )
        # Otherwise (parallel attention + MLP) the attention output is merged
        # together with the MLP output further below.

        # Cross attention.
        if self.layer_type == "decoder":
            inter_attention_outputs = self.inter_attention(
                hidden_states,
                attention_mask=enc_dec_attn_mask,
                attn_mask_type=enc_dec_attn_mask_type,
                window_size=enc_dec_window_size,
                encoder_output=encoder_output,
                cu_seqlens_kv=cu_seqlens_kv,
                cu_seqlens_kv_padded=cu_seqlens_kv_padded,
                max_seqlen_kv=max_seqlen_kv,
                **shared_attention_kwargs,
            )
            if self.apply_residual_connection_post_layernorm:
                cross_out, cross_bias, cross_residual = inter_attention_outputs
            else:
                cross_out, cross_bias = inter_attention_outputs
                cross_residual = hidden_states

            hidden_states = self._bias_dropout_add(cross_out, cross_bias, cross_residual)

        # MLP.
        mlp_outputs = self.layernorm_mlp(
            hidden_states,
            is_first_microbatch=is_first_microbatch,
        )
        if self.apply_residual_connection_post_layernorm:
            ffn_out, ffn_bias, ffn_residual = mlp_outputs
            output = self._bias_dropout_add(ffn_out, ffn_bias, ffn_residual, self.drop_path)
        elif self.parallel_attention_mlp:
            # The MLP output takes the "bias" slot, so attention output, MLP output
            # and the layer input are combined in a single call.
            output = self._bias_dropout_add(
                self_attention_outputs, mlp_outputs, hidden_states, self.drop_path
            )
        else:
            ffn_out, ffn_bias = mlp_outputs
            output = self._bias_dropout_add(ffn_out, ffn_bias, hidden_states, self.drop_path)

        # For BERT like architectures.
        if self.output_layernorm:
            output = self.layernorm(output)

        # output: [s, b, h]
        return output

    def _bias_dropout_add(self, hidden_state, bias, residual, drop_path=None):
        """
        Compute ``residual + drop_path(dropout(hidden_state + bias))``.

        The dedicated bias-dropout-add function is used when no ``drop_path`` is
        given, a non-empty ``bias`` is present and ONNX export mode is off.
        Otherwise the steps are applied one by one, skipping the bias addition
        when ``bias`` is ``None`` or empty and the drop-path step when
        ``drop_path`` is ``None``.
        """
        has_bias = bias is not None and bias.numel() != 0

        if drop_path is None and has_bias and not is_in_onnx_export_mode():
            if not self.bias_dropout_fusion:
                bias_dropout_add_func = get_bias_dropout_add(self.training)
            elif self.training:
                bias_dropout_add_func = bias_dropout_add_fused_train
            else:
                bias_dropout_add_func = bias_dropout_add_fused_inference

            with self.bias_dropout_add_exec_handler():
                return bias_dropout_add_func(hidden_state, bias, residual, self.hidden_dropout)

        # Step-by-step path.
        if has_bias:
            hidden_state = hidden_state + bias
        dropped = torch.nn.functional.dropout(
            hidden_state, p=self.hidden_dropout, training=self.training
        )
        if drop_path is not None:
            dropped = drop_path(dropped)
        return residual + dropped