transformer.py 41.2 KB
Newer Older
1
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
Przemek Tredak's avatar
Przemek Tredak committed
2
3
4
5
6
#
# See LICENSE for license information.

"""Transformer."""
import os
7
import warnings
Przemek Tredak's avatar
Przemek Tredak committed
8
from contextlib import nullcontext
9
from typing import Callable, List, Optional, Tuple, Union
Przemek Tredak's avatar
Przemek Tredak committed
10
11
12

import torch

13
from transformer_engine.pytorch import torch_version
14
from transformer_engine.pytorch.module import LayerNormMLP, LayerNorm, RMSNorm
15
from transformer_engine.debug.pytorch.debug_state import TEDebugState
16
17
from transformer_engine.pytorch.attention.multi_head_attention import MultiheadAttention
from transformer_engine.pytorch.attention.inference import InferenceParams
Przemek Tredak's avatar
Przemek Tredak committed
18
19
20
21
22
23
24
25
26
27
from transformer_engine.pytorch.jit import (
    set_jit_fusion_options,
    warmup_jit_bias_dropout_add_all_dtypes,
    get_bias_dropout_add,
    bias_dropout_add_fused_train,
    bias_dropout_add_fused_inference,
)
from transformer_engine.pytorch.utils import (
    cast_if_needed,
    get_default_init_method,
28
    torch_get_autocast_gpu_dtype,
Przemek Tredak's avatar
Przemek Tredak committed
29
30
31
32
33
34
)
from transformer_engine.pytorch.constants import (
    AttnMaskTypes,
    LayerTypes,
    dist_group_type,
)
35
from transformer_engine.pytorch.distributed import get_distributed_world_size
36
from transformer_engine.pytorch.export import is_in_onnx_export_mode
37
from transformer_engine.pytorch.module.base import TransformerEngineBaseModule
38

Przemek Tredak's avatar
Przemek Tredak committed
39

40
# Use the "module" action for DeprecationWarning: the first occurrence of a matching
# warning is reported once per issuing module. The filter applies to modules whose
# fully qualified name starts with "transformer".
warnings.filterwarnings("module", category=DeprecationWarning, module="transformer")
cyanguwa's avatar
cyanguwa committed
41
42


43
# Public API of this module: only `TransformerLayer` is exported by a star import.
__all__ = ["TransformerLayer"]
cyanguwa's avatar
cyanguwa committed
44

Przemek Tredak's avatar
Przemek Tredak committed
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71

class DropPath(torch.nn.Module):
    """Drop paths (Stochastic Depth) per sample
    (when applied in main path of residual blocks).

    Parameters
    ----------
    drop_prob : float, default = 0.0
               probability of zeroing out a whole sample (first dimension of the
               input). Expected to lie in the range [0, 1].
    """

    def __init__(self, drop_prob: float = 0.0) -> None:
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        """DropPath FWD

        Returns the input unchanged when `drop_prob` is 0 or the module is not in
        training mode. Otherwise every sample along dimension 0 is either zeroed
        or rescaled by `1 / (1 - drop_prob)`.
        """

        if self.drop_prob == 0.0 or not self.training:
            return hidden_state
        keep_prob = 1 - self.drop_prob
        # One random draw per sample, broadcast over all remaining dimensions, so
        # this works with tensors of any rank, not just 2D ConvNets.
        shape = (hidden_state.shape[0],) + (1,) * (hidden_state.ndim - 1)
        random_tensor = keep_prob + torch.rand(
            shape, dtype=hidden_state.dtype, device=hidden_state.device
        )
        random_tensor.floor_()  # binarize: 1 keeps the sample, 0 drops it
        # Skip the rescale when nothing can be kept (drop_prob == 1): dividing by
        # zero would produce inf, and inf * 0 is NaN instead of the intended 0.
        if keep_prob > 0.0:
            hidden_state = hidden_state.div(keep_prob)
        return hidden_state * random_tensor


class TransformerLayer(torch.nn.Module):
    r"""
    TransformerLayer is made up of an attention block and a feedforward network (MLP).
    This standard layer is based on the paper "Attention Is All You Need".

    .. note::

        Argument :attr:`attention_mask` in the `forward` call is only used when
        :attr:`self_attn_mask_type` includes `"padding"` or `"arbitrary"`.

    Parameters
    ----------
    hidden_size : int
                 size of each input sample.
    ffn_hidden_size : int
                     intermediate size to which input samples are projected.
    num_attention_heads : int
                         number of attention heads in the transformer layer.
    num_gqa_groups : int, default = `None`
                         number of GQA groups in the transformer layer.
                         Grouped Query Attention is described in
                         `this paper <https://arxiv.org/pdf/2305.13245.pdf>`_.
                         This only affects the keys and values, not the queries.
                         GQA-1 is equivalent to Multi-Query Attention
                         (`MQA <https://arxiv.org/pdf/1911.02150.pdf>`_), while GQA-H
                         is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`.
    layernorm_epsilon : float, default = 1e-5
                       a value added to the denominator of layer normalization
                       for numerical stability.
    hidden_dropout: float, default = 0.1
                   dropout probability for the dropout op after FC2 layer.
    attention_dropout: float, default = 0.1
                      dropout probability for the dropout op during multi-head attention.
    init_method : Callable, default = `None`
                 used for initializing weights of QKV and FC1 weights in the following way:
                 `init_method(weight)`. When set to `None`, defaults to
                 `torch.nn.init.normal_(mean=0.0, std=0.023)`.
    output_layer_init_method : Callable, default = `None`
                              used for initializing weights of PROJ and FC2 in the following way:
                              `output_layer_init_method(weight)`. When set to `None`, defaults to
                              `torch.nn.init.normal_(mean=0.0, std=0.023)`.
    apply_residual_connection_post_layernorm : bool, default = `False`
                                              if set to `True`, residual connections are taken
                                              from the output of layer norm (default is taken
                                              from input of layer norm)
    layer_number: int, default = `None`
                 layer number of the current `TransformerLayer` when multiple such modules are
                 concatenated to form a transformer block.
    output_layernorm: bool, default = `False`
                     if set to `True`, layer normalization is applied on the output side,
                     after the final dropout-add. default behavior is to apply layer
                     normalization on the input side, before the QKV transformation.
    parallel_attention_mlp: bool, default = `False`
                           if set to `True`, self-attention and feedforward network are computed
                           based on the same input (in parallel) instead of sequentially.
                           Both blocks have an independent normalization.
                           This architecture is used in `Falcon` models.
    layer_type: {'encoder', 'decoder'}, default = `encoder`
               if set to `decoder`, an additional cross-attn block is added after self-attn.
               This can be used for structures like `T5` Transformer in conjunction with the
               `encoder` option.
    kv_channels: int, default = `None`
                number of query-key-value channels per attention head. defaults to
                :attr:`hidden_size` / :attr:`num_attention_heads` if `None`.
    self_attn_mask_type: {'no_mask', 'padding', 'causal', 'padding_causal', 'causal_bottom_right',
                        'padding_causal_bottom_right', 'arbitrary'},
                        default = `causal`
                        type of attention mask passed into softmax operation for encoder.
                        Overridden by :attr:`self_attn_mask_type` in the `forward` method.
                        The forward arg is useful for dynamically changing mask types, e.g.
                        a different mask for training and inference. The init arg is useful
                        for cases involving compilation/tracing, e.g. ONNX export.
    window_size: Optional[Tuple[int, int]], default = `None`
                sliding window size for local attention in encoder, where query at position i
                attends to keys in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k
                - seqlen_q + window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean
                no sliding window and causal mask specifically. Both `causal` and
                `causal_bottom_right` masks map to `window_size = (-1, 0)` and Transformer Engine
                distinguishes them based on `self_attn_mask_type` or `enc_dec_attn_mask_type`.
                Similar to :attr:`self_attn_mask_type`, `window_size` can be overridden by
                :attr:`window_size` in `forward` as well.
    enc_dec_attn_mask_type: {'no_mask', 'causal', 'padding', 'padding_causal', 'arbitrary'},
                           default = `no_mask`
                           type of attention mask passed into softmax operation for decoder.
    enc_dec_window_size: Optional[Tuple[int, int]], default = `None`
                        sliding window size for local attention in decoder.
    zero_centered_gamma : bool, default = 'False'
                         if set to 'True', gamma parameter in LayerNorm is initialized to 0 and
                         the LayerNorm formula changes to

                         .. math::
                            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} *
                            (1 + \gamma) + \beta
    normalization : { 'LayerNorm', 'RMSNorm' }, default = 'LayerNorm'
                   type of normalization applied.
    qkv_weight_interleaved : bool, default = `True`
                            if set to `False`, the QKV weight is interpreted as a concatenation of
                            query, key, and value weights along the `0th` dimension. The default
                            interpretation is that the individual `q`, `k`, and `v` weights for each
                            attention head are interleaved. This parameter is set to `False` when
                            using :attr:`fuse_qkv_params=False`.
    rotary_pos_interleaved : bool, default = `False`
                            whether to use interleaved rotary position embeddings.
    bias : bool, default = `True`
          if set to `False`, the transformer layer will not learn any additive biases.
    activation : str, default = 'gelu'
          Type of activation used in MLP block.
          Options are: 'gelu', 'relu', 'reglu', 'geglu', 'swiglu', 'qgelu' and 'srelu'.
    device : Union[torch.device, str], default = "cuda"
          The device on which the parameters of the model will be allocated. It is the user's
          responsibility to ensure all parameters are moved to the GPU before running the
          forward pass.
    attn_input_format: {'sbhd', 'bshd', 'thd'}, default = 'sbhd'
                         This controls whether the dimensions of the
                         intermediate hidden states are 'sequence first' ('sbhd'), 'batch first'
                         ('bshd'), or 'token first' ('thd'). `s` stands for the sequence length,
                         `b` batch size, `t` the total number of tokens, `h` the number of heads,
                         `d` head size. Note that these formats are very closely
                         related to the `qkv_format` in the `MultiHeadAttention`
                         and `DotProductAttention` modules.
    name: str, default = `None`
        name of the module, currently used for debugging purposes.

    Parallelism parameters
    ----------------------
    set_parallel_mode : bool, default = `False`
                      if set to `True`, QKV and FC1 layers are used as Column Parallel
                      whereas PROJ and FC2 are used as Row Parallel as described
                      `here <https://arxiv.org/pdf/1909.08053.pdf>`_.
    sequence_parallel : bool, default = `False`
                       if set to `True`, uses sequence parallelism.
    tp_group : ProcessGroup, default = `None`
              tensor parallel process group.
    tp_size : int, default = 1
             used as TP (tensor parallel) world size when TP groups are not formed during
             initialization. In this case, users must call the
             `set_tensor_parallel_group(tp_group)` method on the initialized module before the
             forward pass to supply the tensor parallel group needed for tensor and sequence
             parallel collectives.

    Optimization parameters
    -----------------------
    fuse_wgrad_accumulation : bool, default = 'False'
                             if set to `True`, enables fusing of creation and accumulation of
                             the weight gradient. When enabled, it is assumed that the weights
                             have an additional `main_grad` attribute (used instead of the
                             regular `grad`) which is a pre-allocated buffer of the correct
                             size to accumulate gradients in.
    params_dtype : torch.dtype, default = `torch.get_default_dtype()`
                  it controls the type used to allocate the initial parameters. Useful when
                  the model is trained with lower precision and the original FP32 parameters
                  would not fit in GPU memory.
    seq_length: int
               sequence length of input samples. Needed for JIT Warmup, a technique where jit
               fused functions are warmed up before training to ensure same kernels are used for
               forward propagation and activation recompute phase.
    micro_batch_size: int
                     batch size per training step. Needed for JIT Warmup, a technique where jit
                     fused functions are warmed up before training to ensure same kernels are
                     used for forward propagation and activation recompute phase.
    drop_path_rate: float, default = 0.0
                   when > 0.0, applies stochastic depth per sample in
                   the main path of the residual block.
    fuse_qkv_params: bool, default = 'False'
                    if set to `True`, `TransformerLayer` module exposes a single fused
                    parameter for query-key-value. This enables optimizations such as QKV
                    fusion without concatenations/splits and also enables the argument
                    `fuse_wgrad_accumulation`.
    use_qk_norm: bool, default = 'False'
                    if set to `True`, L2 normalization is applied to query and key tensors
                    after RoPE (if applicable) but before attention computation.
                    This follows the Llama4 approach for QK normalization to improve
                    training stability and model performance.
    qk_norm_eps: float, default = 1e-6
                    epsilon value for L2 normalization of query and key tensors.
                    Only used when `use_qk_norm` is True.
    """

    def __init__(
        self,
        hidden_size: int,
        ffn_hidden_size: int,
        num_attention_heads: int,
        num_gqa_groups: Optional[int] = None,
        layernorm_epsilon: float = 1e-5,
        hidden_dropout: float = 0.1,
        attention_dropout: float = 0.1,
        init_method: Optional[Callable] = None,
        output_layer_init_method: Optional[Callable] = None,
        layer_number: Optional[int] = None,
        kv_channels: Optional[int] = None,
        self_attn_mask_type: str = "causal",
        window_size: Optional[Tuple[int, int]] = None,
        enc_dec_attn_mask_type: str = "no_mask",
        enc_dec_window_size: Optional[Tuple[int, int]] = None,
        tp_group: Optional[dist_group_type] = None,
        tp_size: int = 1,
        params_dtype: Optional[torch.dtype] = None,
        get_rng_state_tracker: Optional[Callable] = None,
        fuse_wgrad_accumulation: bool = False,
        seq_length: Optional[int] = None,
        micro_batch_size: Optional[int] = None,
        sequence_parallel: bool = False,
        apply_residual_connection_post_layernorm: bool = False,
        output_layernorm: bool = False,
        parallel_attention_mlp: bool = False,
        layer_type: str = "encoder",
        drop_path_rate: float = 0.0,
        set_parallel_mode: bool = False,
        fuse_qkv_params: bool = False,
        rotary_pos_interleaved: bool = False,
        zero_centered_gamma: bool = False,
        qkv_weight_interleaved: bool = True,
        ub_tp_comm_overlap: bool = False,
        ub_overlap_ag: bool = True,
        ub_overlap_rs: bool = True,
        ub_overlap_rs_dgrad: bool = False,
        ub_bulk_dgrad: bool = True,
        ub_bulk_wgrad: bool = True,
        bias: bool = True,
        activation: str = "gelu",
        normalization: str = "LayerNorm",
        device: Union[torch.device, str] = "cuda",
        attn_input_format: str = "sbhd",
        name: Optional[str] = None,
        use_qk_norm: bool = False,
        qk_norm_eps: float = 1e-6,
    ) -> None:
        """Build the self-attention block, the optional cross-attention block
        (for `layer_type="decoder"`), the MLP and the optional output normalization.

        The arguments are described in the class docstring.
        """
        super().__init__()

        self.self_attn_mask_type = self_attn_mask_type
        self.window_size = window_size
        self.enc_dec_attn_mask_type = enc_dec_attn_mask_type
        self.enc_dec_window_size = enc_dec_window_size
        params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype

        # The individual userbuffer overlap switches only take effect when
        # `ub_tp_comm_overlap` is enabled.
        ub_bulk_wgrad = ub_tp_comm_overlap and ub_bulk_wgrad
        ub_bulk_dgrad = ub_tp_comm_overlap and ub_bulk_dgrad
        ub_overlap_ag = ub_tp_comm_overlap and ub_overlap_ag
        ub_overlap_rs = ub_tp_comm_overlap and ub_overlap_rs
        ub_overlap_rs_dgrad = ub_tp_comm_overlap and ub_overlap_rs_dgrad

        # Bias-dropout fusion is on unless NVTE_BIAS_DROPOUT_FUSION is set to 0.
        bias_dropout_fusion = bool(int(os.getenv("NVTE_BIAS_DROPOUT_FUSION", "1")))
        self.layer_number = layer_number
        self.output_layernorm = output_layernorm
        self.layer_type = layer_type
        self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm

        if parallel_attention_mlp:
            assert self.layer_type == "encoder", "parallel_attention requires layer_type='encoder'"
            assert not self.apply_residual_connection_post_layernorm, (
                "parallel_attention and apply_residual_connection_post_layernorm "
                "not supported simultaneously."
            )
            assert (
                not self.output_layernorm
            ), "parallel_attention and output_layernorm not supported simultaneously"

        self.parallel_attention_mlp = parallel_attention_mlp

        assert layer_type in LayerTypes, f"layer_type {layer_type} not supported"

        if not fuse_qkv_params:
            assert (
                not fuse_wgrad_accumulation
            ), "Gradient accumulation fusion requires single QKV parameter."
            # Separate Q/K/V parameters cannot use the interleaved weight layout.
            qkv_weight_interleaved = False

        self.kv_channels = kv_channels if kv_channels else (hidden_size // num_attention_heads)

        if init_method is None:
            init_method = get_default_init_method()
        if output_layer_init_method is None:
            output_layer_init_method = get_default_init_method()

        # An explicit TP group takes precedence over the `tp_size` argument.
        self.tp_size = tp_size if tp_group is None else get_distributed_world_size(tp_group)
        self.sequence_parallel = (self.tp_size > 1) and sequence_parallel
        self.seq_length = seq_length

        self.get_rng_state_tracker = get_rng_state_tracker

        self.attn_input_format = attn_input_format

        self.name = name

        # Arguments shared by the self-attention and cross-attention blocks.
        attention_args = (
            hidden_size,
            num_attention_heads,
            self.kv_channels,
            attention_dropout,
            layernorm_epsilon,
            init_method,
            output_layer_init_method,
        )
        common_attention_kwargs = {
            "layer_number": layer_number,
            "tp_group": tp_group,
            "tp_size": self.tp_size,
            "num_gqa_groups": num_gqa_groups,
            "fuse_wgrad_accumulation": fuse_wgrad_accumulation,
            "get_rng_state_tracker": get_rng_state_tracker,
            "sequence_parallel": self.sequence_parallel,
            "params_dtype": params_dtype,
            "return_layernorm_output": apply_residual_connection_post_layernorm,
            "set_parallel_mode": set_parallel_mode,
            "fuse_qkv_params": fuse_qkv_params,
            "zero_centered_gamma": zero_centered_gamma,
            "qkv_weight_interleaved": qkv_weight_interleaved,
            "rotary_pos_interleaved": rotary_pos_interleaved,
            "ub_bulk_wgrad": ub_bulk_wgrad,
            "ub_bulk_dgrad": ub_bulk_dgrad,
            "ub_overlap_ag": ub_overlap_ag,
            "ub_overlap_rs": ub_overlap_rs,
            "ub_overlap_rs_dgrad": ub_overlap_rs_dgrad,
            "qkv_format": self.attn_input_format,
            "seq_length": seq_length,
            "micro_batch_size": micro_batch_size,
        }

        self.self_attention = MultiheadAttention(
            *attention_args,
            **common_attention_kwargs,
            input_layernorm=not output_layernorm,
            attention_type="self",
            bias=bias,
            return_bias=not self.parallel_attention_mlp,
            normalization=normalization,
            device=device,
            use_qk_norm=use_qk_norm,
            qk_norm_eps=qk_norm_eps,
            name=name + ".self_attention" if name is not None else None,
        )

        if layer_type == "decoder":
            self.inter_attention = MultiheadAttention(
                *attention_args,
                **common_attention_kwargs,
                attn_mask_type=enc_dec_attn_mask_type,
                input_layernorm=True,
                attention_type="cross",
                bias=bias,
                return_bias=True,
                normalization=normalization,
                device=device,
                use_qk_norm=use_qk_norm,
                qk_norm_eps=qk_norm_eps,
                name=name + ".inter_attention" if name is not None else None,
            )

        # LayerNorm -> activation(Linear + Bias) -> Linear
        # parallel_mode not supported for LayerNormMLP,
        # FC1 is CPL and FC2 is RPL
        # In the case of GLU activation, FC1 handles both
        # Linear layers before the activation
        self.layernorm_mlp = LayerNormMLP(
            hidden_size,
            ffn_hidden_size,
            eps=layernorm_epsilon,
            fuse_wgrad_accumulation=fuse_wgrad_accumulation,
            tp_group=tp_group,
            tp_size=self.tp_size,
            get_rng_state_tracker=get_rng_state_tracker,
            init_method=init_method,
            output_layer_init_method=output_layer_init_method,
            bias=bias,
            return_bias=not self.parallel_attention_mlp,
            sequence_parallel=self.sequence_parallel,
            params_dtype=params_dtype,
            return_layernorm_output=apply_residual_connection_post_layernorm,
            seq_length=seq_length,
            micro_batch_size=micro_batch_size,
            set_parallel_mode=set_parallel_mode,
            zero_centered_gamma=zero_centered_gamma,
            ub_bulk_wgrad=ub_bulk_wgrad,
            ub_bulk_dgrad=ub_bulk_dgrad,
            ub_overlap_rs_dgrad=ub_overlap_rs_dgrad,
            ub_overlap_rs=ub_overlap_rs,
            ub_overlap_ag=ub_overlap_ag,
            activation=activation,
            normalization=normalization,
            device=device,
            name=name + ".layernorm_mlp" if name is not None else None,
        )

        self.hidden_dropout = hidden_dropout
        self.bias_dropout_fusion = bias_dropout_fusion
        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 else None

        # Set bias+dropout+add fusion grad_enable execution handler.
        use_nvfuser = torch_version() >= (1, 10, 0) and torch_version() < (2, 2, 0)
        self.bias_dropout_add_exec_handler = nullcontext if use_nvfuser else torch.enable_grad

        if self.bias_dropout_fusion:
            set_jit_fusion_options()
            if seq_length and micro_batch_size:
                # With sequence parallelism the warmup uses the sequence length
                # divided by the TP world size.
                warmup_seq_length = seq_length
                if self.sequence_parallel:
                    warmup_seq_length = seq_length // self.tp_size
                warmup_jit_bias_dropout_add_all_dtypes(
                    hidden_size, warmup_seq_length, micro_batch_size
                )

        norm_module = {
            "LayerNorm": LayerNorm,
            "RMSNorm": RMSNorm,
        }
        if self.output_layernorm:
            self.layernorm = norm_module[normalization](
                hidden_size,
                eps=layernorm_epsilon,
                sequence_parallel=self.sequence_parallel,
                params_dtype=params_dtype,
                zero_centered_gamma=zero_centered_gamma,
                device=device,
            )

    def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> None:
        """
        Set the tensor parallel group for the given
        module before executing the forward pass.

        Parameters
        ----------
        tp_group : ProcessGroup, default = `None`
                  tensor parallel process group.
        """
        # Deep iterate but skip self to avoid infinite recursion.
        for index, child in enumerate(self.modules()):
            if index == 0:
                continue
            # Only submodules that define the setter receive the group.
            if hasattr(child, "set_tensor_parallel_group"):
                child.set_tensor_parallel_group(tp_group)

502
503
504
505
506
507
508
509
510
    def reset_fp8_meta_tensors(self) -> None:
        """Call `reset_fp8_meta_tensors` on every submodule that defines it."""
        # The first entry yielded by `modules()` is skipped (it is this module
        # itself), so the walk below never re-enters this method on `self`.
        descendants = iter(self.modules())
        next(descendants, None)
        for submodule in descendants:
            if not hasattr(submodule, "reset_fp8_meta_tensors"):
                continue
            submodule.reset_fp8_meta_tensors()

511
    def set_context_parallel_group(
        self,
        cp_group: Union[dist_group_type, List[dist_group_type], None],
        cp_global_ranks: List[int],
        cp_stream: torch.cuda.Stream,
        cp_comm_type: str = "p2p",
    ) -> None:
        """
        Set the context parallel attributes for the given
        module before executing the forward pass.

        Parameters
        ----------
        cp_group : Union[ProcessGroup, List[ProcessGroup]]
                  context parallel process group.
                  ProcessGroup is for cp_comm_type of "p2p", "all_gather", and "a2a".
                  List[ProcessGroup] is for cp_comm_type of "a2a+p2p", where cp_group[0]
                  and cp_group[1] are for a2a and p2p communications respectively.
        cp_global_ranks : List[int]
                         list of global ranks in the context group.
        cp_stream : torch.cuda.Stream
                   cuda stream for context parallel execution.
        cp_comm_type : str, default = `p2p`
                      inter-gpu communication type for context parallelism.
                      Can be "p2p" or "all_gather" or "a2a", or "a2a+p2p".
                      "p2p": Exchange KV chunks with P2P communications in ring topology.
                             P2P is async and can be overlapped with attention compute.
                      "all_gather": All-gather to get full sequence of KV before attention.
                                    The all-gather is not async, and cannot be overlapped.
                      "a2a": Like DeepSpeed Ulysses, scatter attention heads across the CP
                             group, and gather to get full sequence of QKV.
                      "a2a+p2p": hierarchical CP implementation. First applying a2a to QKV
                      across each CP sub-group (e.g., via NVLink), then exchanging KV with
                      p2p between sub-groups (e.g., via IBLink).
        """
        # Deep iterate but skip self to avoid infinite recursion.
        for index, child in enumerate(self.modules()):
            if index == 0:
                continue
            # Forward the settings unchanged to every submodule that defines the setter.
            if hasattr(child, "set_context_parallel_group"):
                child.set_context_parallel_group(cp_group, cp_global_ranks, cp_stream, cp_comm_type)
552

Przemek Tredak's avatar
Przemek Tredak committed
553
554
555
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        self_attn_mask_type: Optional[str] = None,
        window_size: Optional[Tuple[int, int]] = None,
        encoder_output: Optional[torch.Tensor] = None,
        enc_dec_attn_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
        enc_dec_attn_mask_type: Optional[str] = None,
        enc_dec_window_size: Optional[Tuple[int, int]] = None,
        is_first_microbatch: Optional[bool] = None,
        checkpoint_core_attention: bool = False,
        inference_params: Optional[InferenceParams] = None,
        rotary_pos_emb: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
        core_attention_bias_type: str = "no_bias",
        core_attention_bias: Optional[torch.Tensor] = None,
        alibi_slopes: Optional[torch.Tensor] = None,
        cu_seqlens_q: Optional[torch.Tensor] = None,
        cu_seqlens_kv: Optional[torch.Tensor] = None,
        cu_seqlens_q_padded: Optional[torch.Tensor] = None,
        cu_seqlens_kv_padded: Optional[torch.Tensor] = None,
        max_seqlen_q: Optional[int] = None,
        max_seqlen_kv: Optional[int] = None,
        fast_zero_fill: bool = True,
        pad_between_seqs: Optional[bool] = None,
    ) -> torch.Tensor:
        """
        Transformer Layer: attention block and a feedforward network (MLP)

        .. note::

            Argument :attr:`attention_mask` is only used when :attr:`self_attn_mask_type`
            includes `"padding"` or `"arbitrary"`.

        Parameters
        ----------
        hidden_states : torch.Tensor
            Input tensor.
        attention_mask : Optional[torch.Tensor], default = `None`
            Boolean tensor used to mask out self-attention softmax input. It should be
            in [batch_size, 1, 1, seqlen_q] for padding masks, and broadcastable
            to [batch_size, num_heads, max_seqlen_q, max_seqlen_kv] for "`arbitrary`"
            mask. It should be `None` for causal masks and "`no_mask`" type.
            A `True` value means the corresponding position is masked out and
            a `False` means that position is allowed to participate in attention.
        self_attn_mask_type: {'no_mask', 'causal', 'padding', 'padding_causal',
            'causal_bottom_right', 'padding_causal_bottom_right','arbitrary'},
            default = `None`
            Type of attention mask passed into softmax operation for encoder.
            If `None`, the `self_attn_mask_type` configured on the layer is used.
            By default, causal masks are aligned to the top left corner of
            the softmax matrix. When "`bottom_right`" is specified in the mask type,
            causal masks are aligned to the bottom right corner.
        window_size: Optional[Tuple[int, int]], default = `None`
            Sliding window size for local attention in encoder.
            If `None`, the `window_size` configured on the layer is used.
        encoder_output : Optional[torch.Tensor], default = `None`
            Output of the encoder block to be fed into the decoder block if using
            `layer_type="decoder"`.
        enc_dec_attn_mask : Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]],
            default = `None`. Boolean tensors used to mask out inter-attention softmax input if
            using `layer_type="decoder"`. It should be a tuple of two masks in
            [batch_size, 1, 1, seqlen_q] and [batch_size, 1, 1, seqlen_kv] for padding masks.
            It should be broadcastable to [batch_size, num_heads, max_seqlen_q, max_seqlen_kv]
            for "`arbitrary`" mask. It should be `None` for causal masks and "`no_mask`".
            A `True` value means the corresponding position is masked out and a `False`
            means that position is allowed to participate in attention.
        enc_dec_attn_mask_type: {'no_mask', 'causal', 'padding', 'padding_causal', 'arbitrary'},
            default = `None`
            Type of attention mask passed into softmax operation for decoder.
            If `None`, the `enc_dec_attn_mask_type` configured on the layer is used.
        enc_dec_window_size: Optional[Tuple[int, int]], default = `None`
            Sliding window size for local attention in decoder.
            If `None`, the `enc_dec_window_size` configured on the layer is used.
        is_first_microbatch : {True, False, None}, default = None
            During training using either gradient accumulation or
            pipeline parallelism a minibatch of data is further split
            into microbatches. Between the microbatches of the same minibatch
            the model weights are not updated. Setting this parameter indicates
            whether the current microbatch is the first in a minibatch or not.
            When set, this parameter enables additional optimizations:

            * during FP8 training, it allows caching of the FP8 versions of
              the weights
            * it also allows skipping gradient accumulation during the
              first microbatch (since it is the first gradient being
              produced)
        checkpoint_core_attention: bool, default = `False`
            If true, forward activations for core attention are recomputed
            during the backward pass in order to save memory that would
            otherwise be occupied to store the forward activations until
            backprop.
        rotary_pos_emb: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], default = `None`
            Embeddings for query and key tensors for applying rotary position
            embedding. By default no input embedding is applied.
        core_attention_bias_type: str, default = `no_bias`
            Bias type, {`no_bias`, `pre_scale_bias`, `post_scale_bias`, `alibi`}
        core_attention_bias: Optional[torch.Tensor], default = `None`
            Bias tensor for Q * K.T
        alibi_slopes: Optional[torch.Tensor], default = `None`
            ALiBi slopes in FP32 and shape [nheads] or [batch_size, nheads].
            It adds a bias of (-alibi_slope * (i + seqlen_k - seqlen_q - j))
            to the attention score of query i and key j.
        cu_seqlens_q: Optional[torch.Tensor], default = `None`
            Cumulative sum of sequence lengths (without offset) in a batch for `query_layer`,
            with shape [batch_size + 1] and dtype torch.int32.
            Used by encoders, or decoders' self-attention.
        cu_seqlens_kv: Optional[torch.Tensor], default = `None`
            Cumulative sum of sequence lengths (without offset) in a batch for `key_layer`
            and `value_layer`, with shape [batch_size + 1] and dtype torch.int32.
            Used by decoders' cross-attention.
        cu_seqlens_q_padded: Optional[torch.Tensor], default = `None`
            Cumulative sum of sequence lengths (with offset) in a batch for `query_layer`,
            with shape [batch_size + 1] and dtype torch.int32. Set to `cu_seqlens_q` if None.
            Used by encoders, or decoders' self-attention.
        cu_seqlens_kv_padded: Optional[torch.Tensor], default = `None`
            Cumulative sum of sequence lengths (with offset) in a batch for `key_layer`
            and `value_layer`, with shape [batch_size + 1] and dtype torch.int32.
            Set to `cu_seqlens_kv` if None. Used by decoders' cross-attention.
        max_seqlen_q: Optional[int], default = `None`
            Maximum sequence length in `query_layer`.
            Calculated from `cu_seqlens_q_padded` if not provided.
        max_seqlen_kv: Optional[int], default = `None`
            Maximum sequence length in `key_layer` and `value_layer`.
            Calculated from `cu_seqlens_kv_padded` if not provided.
        fast_zero_fill: bool, default = `True`
            Whether to set output tensors to 0 or not before use.
        inference_params: InferenceParams, default = None
            Inference parameters that are passed to the main model in order
            to efficiently calculate and store the context during inference.
        pad_between_seqs: Optional[bool], default = `None`
            If None, inferred from qkv_format, cu_seqlens and cu_seqlens_padded.
            If true, there are padding tokens between individual sequences in a packed batch,
            i.e. qkv_format = 'thd'.

        Returns
        -------
        torch.Tensor
            Output of the layer (annotated in the code as `[s, b, h]`), after the
            optional output layernorm.
        """

        # Per-call overrides win; anything left as `None` falls back to the
        # value stored on the layer.
        if self_attn_mask_type is None:
            self_attn_mask_type = self.self_attn_mask_type
        if window_size is None:
            window_size = self.window_size
        if enc_dec_attn_mask_type is None:
            enc_dec_attn_mask_type = self.enc_dec_attn_mask_type
        if enc_dec_window_size is None:
            enc_dec_window_size = self.enc_dec_window_size

        assert (
            self_attn_mask_type in AttnMaskTypes
        ), f"self_attn_mask_type {self_attn_mask_type} not supported"
        assert (
            enc_dec_attn_mask_type in AttnMaskTypes
        ), f"enc_dec_attn_mask_type {enc_dec_attn_mask_type} not supported"

        hidden_states = hidden_states.contiguous()

        if self.sequence_parallel and self.seq_length is not None:
            assert (
                hidden_states.shape[0] == self.seq_length // self.tp_size
            ), "Sequence dimension must be split across TP group when using sequence parallel."

        # Mask dtype validation. A mask may be a single tensor or a list/tuple of
        # tensors; a single tensor is checked once instead of being sliced along
        # its leading dimension for every element.
        if (
            "padding" in self_attn_mask_type or self_attn_mask_type == "arbitrary"
        ) and attention_mask is not None:
            self_attn_masks = (
                (attention_mask,) if isinstance(attention_mask, torch.Tensor) else attention_mask
            )
            assert all(
                mask.dtype == torch.bool for mask in self_attn_masks
            ), "Attention mask must be a boolean tensor or a list/tuple of two boolean tensors"
        if (
            "padding" in enc_dec_attn_mask_type or enc_dec_attn_mask_type == "arbitrary"
        ) and enc_dec_attn_mask is not None:
            cross_attn_masks = (
                (enc_dec_attn_mask,)
                if isinstance(enc_dec_attn_mask, torch.Tensor)
                else enc_dec_attn_mask
            )
            assert all(
                mask.dtype == torch.bool for mask in cross_attn_masks
            ), "Encoder-decoder attention mask must be boolean tensor(s)"

        if TEDebugState.debug_enabled:
            TransformerEngineBaseModule._validate_name(self)

        # For AMP
        if torch.is_autocast_enabled():
            hidden_states = cast_if_needed(hidden_states, torch_get_autocast_gpu_dtype())

        # Self attention.
        # The query-side sequence metadata (cu_seqlens, padded cu_seqlens and
        # max_seqlen) is deliberately reused for the key/value side here.
        self_attention_outputs = self.self_attention(
            hidden_states,
            attention_mask=attention_mask,
            attn_mask_type=self_attn_mask_type,
            window_size=window_size,
            inference_params=inference_params,
            is_first_microbatch=is_first_microbatch,
            checkpoint_core_attention=checkpoint_core_attention,
            rotary_pos_emb=rotary_pos_emb,
            core_attention_bias_type=core_attention_bias_type,
            core_attention_bias=core_attention_bias,
            alibi_slopes=alibi_slopes,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_kv=cu_seqlens_q,
            cu_seqlens_q_padded=cu_seqlens_q_padded,
            cu_seqlens_kv_padded=cu_seqlens_q_padded,
            max_seqlen_q=max_seqlen_q,
            max_seqlen_kv=max_seqlen_q,
            fast_zero_fill=fast_zero_fill,
            pad_between_seqs=pad_between_seqs,
        )

        if self.apply_residual_connection_post_layernorm and not self.output_layernorm:
            # The attention module also hands back the residual to add.
            attention_output, attention_bias, residual = self_attention_outputs
            hidden_states = self._bias_dropout_add(
                attention_output, attention_bias, residual, self.drop_path
            )
        elif not self.parallel_attention_mlp:
            attention_output, attention_bias = self_attention_outputs
            hidden_states = self._bias_dropout_add(
                attention_output, attention_bias, hidden_states, self.drop_path
            )
        # Otherwise (parallel attention + MLP) `hidden_states` is left untouched
        # and the attention output is folded in after the MLP below.

        # Cross attention.
        if self.layer_type == "decoder":
            inter_attention_outputs = self.inter_attention(
                hidden_states,
                attention_mask=enc_dec_attn_mask,
                attn_mask_type=enc_dec_attn_mask_type,
                window_size=enc_dec_window_size,
                encoder_output=encoder_output,
                inference_params=inference_params,
                is_first_microbatch=is_first_microbatch,
                checkpoint_core_attention=checkpoint_core_attention,
                rotary_pos_emb=rotary_pos_emb,
                core_attention_bias_type=core_attention_bias_type,
                core_attention_bias=core_attention_bias,
                alibi_slopes=alibi_slopes,
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_kv=cu_seqlens_kv,
                cu_seqlens_q_padded=cu_seqlens_q_padded,
                cu_seqlens_kv_padded=cu_seqlens_kv_padded,
                max_seqlen_q=max_seqlen_q,
                max_seqlen_kv=max_seqlen_kv,
                fast_zero_fill=fast_zero_fill,
                pad_between_seqs=pad_between_seqs,
            )
            if self.apply_residual_connection_post_layernorm:
                attention_output, attention_bias, residual = inter_attention_outputs
            else:
                attention_output, attention_bias = inter_attention_outputs
                residual = hidden_states

            # NOTE(review): unlike the self-attention and MLP residuals, no
            # `drop_path` is passed here -- confirm this is intentional.
            hidden_states = self._bias_dropout_add(attention_output, attention_bias, residual)

        # MLP.
        mlp_outputs = self.layernorm_mlp(
            hidden_states,
            is_first_microbatch=is_first_microbatch,
        )
        if self.apply_residual_connection_post_layernorm:
            mlp_output, mlp_bias, residual = mlp_outputs
            output = self._bias_dropout_add(mlp_output, mlp_bias, residual, self.drop_path)
        elif self.parallel_attention_mlp:
            # The attention output takes the `hidden_state` slot and the MLP
            # output takes the `bias` slot, so both are combined with the layer
            # input in a single bias-dropout-add step.
            output = self._bias_dropout_add(
                self_attention_outputs, mlp_outputs, hidden_states, self.drop_path
            )
        else:
            mlp_output, mlp_bias = mlp_outputs
            output = self._bias_dropout_add(mlp_output, mlp_bias, hidden_states, self.drop_path)

        # For BERT like architectures.
        if self.output_layernorm:
            output = self.layernorm(output)

        # output: [s, b, h]
        return output

    def _bias_dropout_add(self, hidden_state, bias, residual, drop_path=None):
        """Combine a sub-layer output with its residual: bias add, dropout, residual add.

        Parameters
        ----------
        hidden_state : torch.Tensor
            Sub-layer output to which bias and dropout are applied.
        bias : Optional[torch.Tensor]
            Bias to add to `hidden_state`. `None` or an empty tensor
            (`numel() == 0`) means no bias is added.
        residual : torch.Tensor
            Tensor added to the result after dropout (and `drop_path`, if given).
        drop_path : Optional[Callable], default = `None`
            If given, it is applied to the dropout output before the residual add.
            Supplying it always selects the unfused path.

        Returns
        -------
        torch.Tensor
            `residual` plus the bias-added, dropped-out `hidden_state`.
        """
        has_bias = bias is not None and bias.numel() != 0

        # Unfused path: taken when there is a drop_path, nothing to fuse (no
        # bias), or when exporting to ONNX. The ONNX check comes last so it is
        # only consulted when the fused path would otherwise be eligible.
        if drop_path is not None or not has_bias or is_in_onnx_export_mode():
            biased = (hidden_state + bias) if has_bias else hidden_state
            dropped = torch.nn.functional.dropout(
                biased, p=self.hidden_dropout, training=self.training
            )
            if drop_path is not None:
                dropped = drop_path(dropped)
            return residual + dropped

        # Fused path: pick the implementation that matches the fusion setting
        # and the train/eval mode.
        if not self.bias_dropout_fusion:
            bias_dropout_add_func = get_bias_dropout_add(self.training)
        elif self.training:
            bias_dropout_add_func = bias_dropout_add_fused_train
        else:
            bias_dropout_add_func = bias_dropout_add_fused_inference

        with self.bias_dropout_add_exec_handler():
            output = bias_dropout_add_func(hidden_state, bias, residual, self.hidden_dropout)

        return output