# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Transformer."""
import os
import warnings
from contextlib import nullcontext
from typing import Callable, List, Optional, Tuple, Union

import torch

from transformer_engine.pytorch.module import LayerNormMLP, LayerNorm, RMSNorm
from transformer_engine.pytorch.attention import (
    InferenceParams,
    MultiheadAttention,
    check_set_window_size,
)
from transformer_engine.pytorch.jit import (
    set_jit_fusion_options,
    warmup_jit_bias_dropout_add_all_dtypes,
    get_bias_dropout_add,
    bias_dropout_add_fused_train,
    bias_dropout_add_fused_inference,
)
from transformer_engine.pytorch.utils import (
    cast_if_needed,
    get_default_init_method,
)
from transformer_engine.pytorch.constants import (
    AttnMaskTypes,
    LayerTypes,
    dist_group_type,
)
from transformer_engine.pytorch.distributed import get_distributed_world_size

# Report each DeprecationWarning once per issuing module. NOTE: the `module`
# argument of `warnings.filterwarnings` is a regex matched against the *start*
# of the fully-qualified module name, so "transformer" also matches other
# modules whose names begin with "transformer" (e.g. "transformer_engine.*").
warnings.filterwarnings("module", category=DeprecationWarning, module="transformer")


__all__ = ["TransformerLayer"]


class DropPath(torch.nn.Module):
    """Drop paths (Stochastic Depth) per sample
    (when applied in main path of residual blocks).

    In training mode, the whole input of each sample (indexed along dim 0) is
    zeroed with probability ``drop_prob``. Surviving samples are rescaled by
    ``1 / (1 - drop_prob)``. In eval mode, or when ``drop_prob == 0``, the
    input is returned unchanged.

    Parameters
    ----------
    drop_prob : float, default = 0.0
               probability of dropping a sample's path; must lie in [0, 1].

    Raises
    ------
    ValueError
        if ``drop_prob`` is outside [0, 1].
    """

    def __init__(self, drop_prob: float = 0.0) -> None:
        super().__init__()
        # Values outside [0, 1] would give a negative/invalid keep probability
        # and silently corrupt the output, so reject them up front.
        if not 0.0 <= drop_prob <= 1.0:
            raise ValueError(f"drop_prob must be in [0, 1], got {drop_prob}")
        self.drop_prob = drop_prob

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        """DropPath FWD"""

        if self.drop_prob == 0.0 or not self.training:
            return hidden_state
        keep_prob = 1 - self.drop_prob
        if keep_prob == 0.0:
            # Every path is dropped. Rescaling by 1 / keep_prob would divide by
            # zero and produce NaN. Multiplying by zero keeps the autograd graph
            # connected, so upstream parameters get zero (not missing) grads.
            return hidden_state * 0.0
        # One mask value per sample, broadcast over all remaining dims, so this
        # works for tensors of any rank (not just 2D ConvNets).
        shape = (hidden_state.shape[0],) + (1,) * (hidden_state.ndim - 1)
        random_tensor = keep_prob + torch.rand(
            shape, dtype=hidden_state.dtype, device=hidden_state.device
        )
        random_tensor.floor_()  # binarize: 1 with probability keep_prob, else 0
        output = hidden_state.div(keep_prob) * random_tensor
        return output


class TransformerLayer(torch.nn.Module):
    r"""
    TransformerLayer is made up of an attention block and a feedforward network (MLP).
    This standard layer is based on the paper "Attention Is All You Need".

    .. note::

        Argument :attr:`attention_mask` in the `forward` call is only used when
        :attr:`self_attn_mask_type` includes `"padding"` or `"arbitrary"`.

    Parameters
    ----------
    hidden_size : int
                 size of each input sample.
    ffn_hidden_size : int
                     intermediate size to which input samples are projected.
    num_attention_heads : int
                         number of attention heads in the transformer layer.
    num_gqa_groups : int, default = `None`
                         number of GQA groups in the transformer layer.
                         Grouped Query Attention is described in
                         `this paper <https://arxiv.org/pdf/2305.13245.pdf>`_.
                         This only affects the keys and values, not the queries.
                         GQA-1 is equivalent to Multi-Query Attention
                         (`MQA <https://arxiv.org/pdf/1911.02150.pdf>`_), while GQA-H
                         is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`.
    layernorm_epsilon : float, default = 1e-5
                       a value added to the denominator of layer normalization
                       for numerical stability.
    hidden_dropout: float, default = 0.1
                   dropout probability for the dropout op after FC2 layer.
    attention_dropout: float, default = 0.1
                      dropout probability for the dropout op during multi-head attention.
    init_method : Callable, default = `None`
                 used for initializing weights of QKV and FC1 weights in the following way:
                 `init_method(weight)`. When set to `None`, defaults to
                 `torch.nn.init.normal_(mean=0.0, std=0.023)`.
    output_layer_init_method : Callable, default = `None`
                              used for initializing weights of PROJ and FC2 in the following way:
                              `output_layer_init_method(weight)`. When set to `None`, defaults to
                              `torch.nn.init.normal_(mean=0.0, std=0.023)`.
    apply_residual_connection_post_layernorm : bool, default = `False`
                                              if set to `True`, residual connections are taken
                                              from the output of layer norm (default is taken
                                              from input of layer norm)
    layer_number: int, default = `None`
                 layer number of the current `TransformerLayer` when multiple such modules are
                 concatenated to form a transformer block.
    output_layernorm: bool, default = `False`
                     if set to `True`, layer normalization is applied on the output side,
                     after the final dropout-add. default behavior is to apply layer
                     normalization on the input side, before the QKV transformation.
    parallel_attention_mlp: bool, default = `False`
                           if set to `True`, self-attention and feedforward network are computed
                           based on the same input (in parallel) instead of sequentially.
                           Both blocks have an independent normalization.
                           This architecture is used in `Falcon` models.
    layer_type: {'encoder', 'decoder'}, default = `encoder`
               if set to `decoder`, an additional cross-attn block is added after self-attn.
               This can be used for structures like `T5` Transformer in conjunction with the
               `encoder` option.
    kv_channels: int, default = `None`
                number of query-key-value channels per attention head. Defaults to
                `hidden_size // num_attention_heads` if `None`.
    self_attn_mask_type: {'no_mask', 'padding', 'causal', 'padding_causal', 'causal_bottom_right',
                        'padding_causal_bottom_right', 'arbitrary'},
                        default = `causal`
                        type of attention mask passed into softmax operation for encoder.
                        Overridden by :attr:`self_attn_mask_type` in the `forward` method.
                        The forward arg is useful for dynamically changing mask types, e.g.
                        a different mask for training and inference. The init arg is useful
                        for cases involving compilation/tracing, e.g. ONNX export.
    window_size: Optional[Tuple[int, int]], default = `None`
                sliding window size for local attention in encoder, where query at position i
                attends to keys in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k
                - seqlen_q + window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean
                no sliding window and causal mask specifically. Both `causal` and
                `causal_bottom_right` masks map to `window_size = (-1, 0)` and Transformer Engine
                distinguishes them based on `self_attn_mask_type` or `enc_dec_attn_mask_type`.
                Similar to :attr:`self_attn_mask_type`, `window_size` can be overridden by
                :attr:`window_size` in `forward` as well.
    enc_dec_attn_mask_type: {'no_mask', 'causal', 'padding', 'padding_causal', 'arbitrary'},
                           default = `no_mask`
                           type of attention mask passed into softmax operation for decoder.
    enc_dec_window_size: Optional[Tuple[int, int]], default = `None`
                        sliding window size for local attention in decoder.
    zero_centered_gamma : bool, default = `False`
                         if set to `True`, gamma parameter in LayerNorm is initialized to 0 and
                         the LayerNorm formula changes to

                         .. math::
                            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} *
                            (1 + \gamma) + \beta
    normalization : { 'LayerNorm', 'RMSNorm' }, default = 'LayerNorm'
                   type of normalization applied.
    qkv_weight_interleaved : bool, default = `True`
                            if set to `False`, the QKV weight is interpreted as a concatenation of
                            query, key, and value weights along the `0th` dimension. The default
                            interpretation is that the individual `q`, `k`, and `v` weights for each
                            attention head are interleaved. This parameter is set to `False` when
                            using :attr:`fuse_qkv_params=False`.
    bias : bool, default = `True`
          if set to `False`, the transformer layer will not learn any additive biases.
    activation : str, default = 'gelu'
          Type of activation used in MLP block.
          Options are: 'gelu', 'relu', 'reglu', 'geglu', 'swiglu', 'qgelu' and 'srelu'.
    device : Union[torch.device, str], default = "cuda"
          The device on which the parameters of the model will be allocated. It is the user's
          responsibility to ensure all parameters are moved to the GPU before running the
          forward pass.
    attn_input_format: {'sbhd', 'bshd'}, default = 'sbhd'
                         This controls whether the dimensions of the
                         intermediate hidden states are 'batch first' ('bshd') or
                         'sequence first' ('sbhd'). `s` stands for the sequence
                         length, `b` batch size, `h` the number of heads, `d`
                         head size. Note that these formats are very closely
                         related to the `qkv_format` in the `MultiheadAttention`
                         and `DotProductAttention` modules.

    Parallelism parameters
    ----------------------
    set_parallel_mode : bool, default = `False`
                      if set to `True`, QKV and FC1 layers are used as Column Parallel
                      whereas PROJ and FC2 are used as Row Parallel as described
                      `here <https://arxiv.org/pdf/1909.08053.pdf>`_.
    sequence_parallel : bool, default = `False`
                       if set to `True`, uses sequence parallelism. It only takes effect when
                       the tensor parallel world size is greater than 1.
    tp_group : ProcessGroup, default = `None`
              tensor parallel process group.
    tp_size : int, default = 1
             used as TP (tensor parallel) world size when TP groups are not formed during
             initialization. In this case, users must call the
             `set_tensor_parallel_group(tp_group)` method on the initialized module before the
             forward pass to supply the tensor parallel group needed for tensor and sequence
             parallel collectives. Ignored when `tp_group` is provided; the world size of
             `tp_group` is used instead.

    Optimization parameters
    -----------------------
    fuse_wgrad_accumulation : bool, default = `False`
                             if set to `True`, enables fusing of creation and accumulation of
                             the weight gradient. When enabled, it is assumed that the weights
                             have an additional `main_grad` attribute (used instead of the
                             regular `grad`) which is a pre-allocated buffer of the correct
                             size to accumulate gradients in. Requires `fuse_qkv_params=True`.
    params_dtype : torch.dtype, default = `torch.get_default_dtype()`
                  it controls the type used to allocate the initial parameters. Useful when
                  the model is trained with lower precision and the original FP32 parameters
                  would not fit in GPU memory.
    seq_length: int, default = `None`
               sequence length of input samples. Needed for JIT Warmup, a technique where jit
               fused functions are warmed up before training to ensure same kernels are used for
               forward propagation and activation recompute phase. Warmup only runs when both
               `seq_length` and `micro_batch_size` are provided.
    micro_batch_size: int, default = `None`
                     batch size per training step. Needed for JIT Warmup, a technique where jit
                     fused functions are warmed up before training to ensure same kernels are
                     used for forward propagation and activation recompute phase.
    drop_path_rate: float, default = 0.0
                   when > 0.0, applies stochastic depth per sample in
                   the main path of the residual block.
    fuse_qkv_params: bool, default = `False`
                    if set to `True`, `TransformerLayer` module exposes a single fused
                    parameter for query-key-value. This enables optimizations such as QKV
                    fusion without concatenations/splits and also enables the argument
                    `fuse_wgrad_accumulation`.
    """

    def __init__(
        self,
        hidden_size: int,
        ffn_hidden_size: int,
        num_attention_heads: int,
        num_gqa_groups: Optional[int] = None,
        layernorm_epsilon: float = 1e-5,
        hidden_dropout: float = 0.1,
        attention_dropout: float = 0.1,
        init_method: Optional[Callable] = None,
        output_layer_init_method: Optional[Callable] = None,
        layer_number: Optional[int] = None,
        kv_channels: Optional[int] = None,
        self_attn_mask_type: str = "causal",
        window_size: Optional[Tuple[int, int]] = None,
        enc_dec_attn_mask_type: str = "no_mask",
        enc_dec_window_size: Optional[Tuple[int, int]] = None,
        tp_group: Optional[dist_group_type] = None,
        tp_size: int = 1,
        params_dtype: Optional[torch.dtype] = None,
        get_rng_state_tracker: Optional[Callable] = None,
        fuse_wgrad_accumulation: bool = False,
        seq_length: Optional[int] = None,
        micro_batch_size: Optional[int] = None,
        sequence_parallel: bool = False,
        apply_residual_connection_post_layernorm: bool = False,
        output_layernorm: bool = False,
        parallel_attention_mlp: bool = False,
        layer_type: str = "encoder",
        drop_path_rate: float = 0.0,
        set_parallel_mode: bool = False,
        fuse_qkv_params: bool = False,
        zero_centered_gamma: bool = False,
        qkv_weight_interleaved: bool = True,
        ub_tp_comm_overlap: bool = False,
        ub_overlap_ag: bool = True,
        ub_overlap_rs: bool = True,
        ub_overlap_rs_dgrad: bool = False,
        ub_bulk_dgrad: bool = True,
        ub_bulk_wgrad: bool = True,
        bias: bool = True,
        activation: str = "gelu",
        normalization: str = "LayerNorm",
        device: Union[torch.device, str] = "cuda",
        attn_input_format: str = "sbhd",
    ) -> None:
        """Build the self-attention, optional cross-attention, MLP and optional
        output-normalization sub-modules. Parameters are described in the
        class docstring.

        Raises
        ------
        AssertionError
            on unsupported option combinations: `parallel_attention_mlp` with a
            non-encoder layer, with `apply_residual_connection_post_layernorm`,
            or with `output_layernorm`. Also when `layer_type` is not in
            `LayerTypes`, or when `fuse_wgrad_accumulation` is set without
            `fuse_qkv_params`.
        """
        super().__init__()

        # Init-time mask types / window sizes; `forward` may override them.
        self.self_attn_mask_type = self_attn_mask_type
        self.window_size = check_set_window_size(self_attn_mask_type, window_size)
        self.enc_dec_attn_mask_type = enc_dec_attn_mask_type
        self.enc_dec_window_size = check_set_window_size(
            enc_dec_attn_mask_type, enc_dec_window_size
        )
        params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype

        # Individual userbuffer overlap flags are only honored when TP comm
        # overlap is enabled as a whole.
        ub_bulk_wgrad = ub_tp_comm_overlap and ub_bulk_wgrad
        ub_bulk_dgrad = ub_tp_comm_overlap and ub_bulk_dgrad
        ub_overlap_ag = ub_tp_comm_overlap and ub_overlap_ag
        ub_overlap_rs = ub_tp_comm_overlap and ub_overlap_rs
        ub_overlap_rs_dgrad = ub_tp_comm_overlap and ub_overlap_rs_dgrad

        # Fused bias+dropout+add is on by default; disable with NVTE_BIAS_DROPOUT_FUSION=0.
        bias_dropout_fusion = bool(int(os.getenv("NVTE_BIAS_DROPOUT_FUSION", "1")))
        self.layer_number = layer_number
        self.output_layernorm = output_layernorm
        self.layer_type = layer_type
        self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm

        if parallel_attention_mlp:
            assert self.layer_type == "encoder", "parallel_attention requires layer_type='encoder'"
            assert not self.apply_residual_connection_post_layernorm, (
                "parallel_attention and apply_residual_connection_post_layernorm "
                "not supported simultaneously."
            )
            assert (
                not self.output_layernorm
            ), "parallel_attention and output_layernorm not supported simultaneously"

        self.parallel_attention_mlp = parallel_attention_mlp

        assert layer_type in LayerTypes, f"layer_type {layer_type} not supported"

        if not fuse_qkv_params:
            assert (
                not fuse_wgrad_accumulation
            ), "Gradient accumulation fusion requires single QKV parameter."
            # Interleaving only applies to a single fused QKV weight.
            qkv_weight_interleaved = False

        self.kv_channels = kv_channels if kv_channels else (hidden_size // num_attention_heads)

        if init_method is None:
            init_method = get_default_init_method()
        if output_layer_init_method is None:
            output_layer_init_method = get_default_init_method()

        # An explicit TP group takes precedence over the `tp_size` hint.
        self.tp_size = tp_size if tp_group is None else get_distributed_world_size(tp_group)
        # Sequence parallelism is only meaningful with more than one TP rank.
        self.sequence_parallel = (self.tp_size > 1) and sequence_parallel
        self.seq_length = seq_length

        self.get_rng_state_tracker = get_rng_state_tracker

        self.attn_input_format = attn_input_format

        # Positional and keyword arguments shared by self- and cross-attention.
        attention_args = (
            hidden_size,
            num_attention_heads,
            self.kv_channels,
            attention_dropout,
            layernorm_epsilon,
            init_method,
            output_layer_init_method,
        )
        common_attention_kwargs = {
            "layer_number": layer_number,
            "tp_group": tp_group,
            "tp_size": self.tp_size,
            "num_gqa_groups": num_gqa_groups,
            "fuse_wgrad_accumulation": fuse_wgrad_accumulation,
            "get_rng_state_tracker": get_rng_state_tracker,
            "sequence_parallel": self.sequence_parallel,
            "params_dtype": params_dtype,
            "return_layernorm_output": apply_residual_connection_post_layernorm,
            "set_parallel_mode": set_parallel_mode,
            "fuse_qkv_params": fuse_qkv_params,
            "zero_centered_gamma": zero_centered_gamma,
            "qkv_weight_interleaved": qkv_weight_interleaved,
            "ub_bulk_wgrad": ub_bulk_wgrad,
            "ub_bulk_dgrad": ub_bulk_dgrad,
            "ub_overlap_ag": ub_overlap_ag,
            "ub_overlap_rs": ub_overlap_rs,
            "ub_overlap_rs_dgrad": ub_overlap_rs_dgrad,
            "qkv_format": self.attn_input_format,
        }

        # Input-side normalization is fused into self-attention unless the
        # layer normalizes on the output side instead.
        self.self_attention = MultiheadAttention(
            *attention_args,
            **common_attention_kwargs,
            input_layernorm=not output_layernorm,
            attention_type="self",
            bias=bias,
            return_bias=not self.parallel_attention_mlp,
            normalization=normalization,
            device=device,
        )

        if layer_type == "decoder":
            self.inter_attention = MultiheadAttention(
                *attention_args,
                **common_attention_kwargs,
                attn_mask_type=enc_dec_attn_mask_type,
                input_layernorm=True,
                attention_type="cross",
                bias=bias,
                return_bias=True,
                normalization=normalization,
                device=device,
            )

        # LayerNorm -> activation(Linear + Bias) -> Linear
        # parallel_mode not supported for LayerNormMLP,
        # FC1 is CPL and FC2 is RPL
        # In the case of GLU activation, FC1 handles both
        # Linear layers before the activation
        self.layernorm_mlp = LayerNormMLP(
            hidden_size,
            ffn_hidden_size,
            eps=layernorm_epsilon,
            fuse_wgrad_accumulation=fuse_wgrad_accumulation,
            tp_group=tp_group,
            tp_size=self.tp_size,
            get_rng_state_tracker=get_rng_state_tracker,
            init_method=init_method,
            output_layer_init_method=output_layer_init_method,
            bias=bias,
            return_bias=not self.parallel_attention_mlp,
            sequence_parallel=self.sequence_parallel,
            params_dtype=params_dtype,
            return_layernorm_output=apply_residual_connection_post_layernorm,
            seq_length=seq_length,
            micro_batch_size=micro_batch_size,
            set_parallel_mode=set_parallel_mode,
            zero_centered_gamma=zero_centered_gamma,
            ub_bulk_wgrad=ub_bulk_wgrad,
            ub_bulk_dgrad=ub_bulk_dgrad,
            ub_overlap_rs_dgrad=ub_overlap_rs_dgrad,
            ub_overlap_rs=ub_overlap_rs,
            ub_overlap_ag=ub_overlap_ag,
            activation=activation,
            normalization=normalization,
            device=device,
        )

        self.hidden_dropout = hidden_dropout
        self.bias_dropout_fusion = bias_dropout_fusion
        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 else None

        # Set bias+dropout+add fusion grad_enable execution handler.
        # On torch >= 1.10 no context is needed; older versions wrap the fused
        # op in `torch.enable_grad`.
        TORCH_MAJOR = int(torch.__version__.split(".")[0])
        TORCH_MINOR = int(torch.__version__.split(".")[1])
        use_nvfuser = TORCH_MAJOR > 1 or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10)
        self.bias_dropout_add_exec_handler = nullcontext if use_nvfuser else torch.enable_grad

        if self.bias_dropout_fusion:
            set_jit_fusion_options()
            # Warm up the fused kernels only when the input sizes are known.
            if seq_length and micro_batch_size:
                if self.sequence_parallel:
                    # The sequence dimension is split across TP ranks.
                    seq_length = seq_length // self.tp_size
                warmup_jit_bias_dropout_add_all_dtypes(hidden_size, seq_length, micro_batch_size)

        norm_module = {
            "LayerNorm": LayerNorm,
            "RMSNorm": RMSNorm,
        }
        # Output-side normalization, applied after the final dropout-add.
        if self.output_layernorm:
            self.layernorm = norm_module[normalization](
                hidden_size,
                eps=layernorm_epsilon,
                sequence_parallel=self.sequence_parallel,
                params_dtype=params_dtype,
                zero_centered_gamma=zero_centered_gamma,
                device=device,
            )

    def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> None:
        """
        Set the tensor parallel group for the given
        module before executing the forward pass.

        The group is forwarded to every descendant sub-module that defines
        ``set_tensor_parallel_group``.

        Parameters
        ----------
        tp_group : ProcessGroup or None
                  tensor parallel process group.
        """
        # Deep iterate but skip self to avoid infinite recursion.
        for index, child in enumerate(self.modules()):
            if index == 0:
                continue
            if hasattr(child, "set_tensor_parallel_group"):
                child.set_tensor_parallel_group(tp_group)

    def reset_fp8_meta_tensors(self) -> None:
        """Reset the FP8 meta tensors of all sub-modules.

        Calls ``reset_fp8_meta_tensors()`` on every descendant sub-module that
        defines it.
        """
        # Deep iterate but skip self to avoid infinite recursion.
        for index, child in enumerate(self.modules()):
            if index == 0:
                continue
            if hasattr(child, "reset_fp8_meta_tensors"):
                child.reset_fp8_meta_tensors()

    def set_context_parallel_group(
        self,
        cp_group: Union[dist_group_type, List[dist_group_type], None],
        cp_global_ranks: List[int],
        cp_stream: torch.cuda.Stream,
        cp_comm_type: str = "p2p",
    ) -> None:
        """
        Set the context parallel attributes for the given
        module before executing the forward pass.

        The arguments are forwarded unchanged to every descendant sub-module
        that defines ``set_context_parallel_group``.

        Parameters
        ----------
        cp_group : Union[ProcessGroup, List[ProcessGroup]]
                  context parallel process group.
                  ProcessGroup is for cp_comm_type of "p2p", "all_gather", and "a2a".
                  List[ProcessGroup] is for cp_comm_type of "a2a+p2p", where cp_group[0]
                  and cp_group[1] are for a2a and p2p communications respectively.
        cp_global_ranks : List[int]
                         list of global ranks in the context group.
        cp_stream : torch.cuda.Stream
                   cuda stream for context parallel execution.
        cp_comm_type : str, default = `p2p`
                      inter-gpu communication type for context parallelism.
                      Can be "p2p" or "all_gather" or "a2a", or "a2a+p2p".
                      "p2p": Exchange KV chunks with P2P communications in ring topology.
                             P2P is async and can be overlapped with attention compute.
                      "all_gather": All-gather to get full sequence of KV before attention.
                                    The all-gather is not async, and cannot be overlapped.
                      "a2a": Like DeepSpeed Ulysses, scatter attention heads across the CP
                             group, and gather to get full sequence of QKV.
                      "a2a+p2p": hierarchical CP implementation. First applying a2a to QKV
                      across each CP sub-group (e.g., via NVLink), then exchanging KV with
                      p2p between sub-groups (e.g., via IBLink).
        """
        # Deep iterate but skip self to avoid infinite recursion.
        for index, child in enumerate(self.modules()):
            if index == 0:
                continue
            if hasattr(child, "set_context_parallel_group"):
                child.set_context_parallel_group(cp_group, cp_global_ranks, cp_stream, cp_comm_type)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        self_attn_mask_type: Optional[str] = None,
        window_size: Optional[Tuple[int, int]] = None,
        encoder_output: Optional[torch.Tensor] = None,
        enc_dec_attn_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
        enc_dec_attn_mask_type: Optional[str] = None,
        enc_dec_window_size: Optional[Tuple[int, int]] = None,
        is_first_microbatch: Optional[bool] = None,
        checkpoint_core_attention: bool = False,
        inference_params: Optional[InferenceParams] = None,
        rotary_pos_emb: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
        core_attention_bias_type: str = "no_bias",
        core_attention_bias: Optional[torch.Tensor] = None,
        alibi_slopes: Optional[torch.Tensor] = None,
        cu_seqlens_q: Optional[torch.Tensor] = None,
        cu_seqlens_kv: Optional[torch.Tensor] = None,
        max_seqlen_q: Optional[int] = None,
        max_seqlen_kv: Optional[int] = None,
        fast_zero_fill: bool = True,
    ) -> torch.Tensor:
        """
        Transformer Layer: attention block and a feedforward network (MLP)

        .. note::

            Argument :attr:`attention_mask` is only used when :attr:`self_attn_mask_type`
            includes `"padding"` or `"arbitrary"`.

        Parameters
        ----------
        hidden_states : torch.Tensor
             Input tensor. Its first dimension is treated as the sequence dimension:
             under sequence parallelism it must equal `seq_length // tp_size`.
        attention_mask : Optional[torch.Tensor], default = `None`
                        Boolean tensor used to mask out self-attention softmax input. It should be
                        in [batch_size, 1, 1, seqlen_q] for padding masks, and broadcastable
                        to [batch_size, num_heads, max_seqlen_q, max_seqlen_kv] for "`arbitrary`"
                        mask. It should be `None` for causal masks and "`no_mask`" type.
                        A `True` value means the corresponding position is masked out and
                        a `False` means that position is allowed to participate in attention.
                        Its dtype must be `torch.bool` when the mask type includes
                        "`padding`" or is "`arbitrary`".
        self_attn_mask_type: {'no_mask', 'causal', 'padding', 'padding_causal',
                            'causal_bottom_right', 'padding_causal_bottom_right','arbitrary'},
                            default = `None`
                            Type of attention mask passed into softmax operation for encoder.
                            If `None`, the mask type configured on this layer is used.
                            By default, causal masks are aligned to the top left corner of
                            the softmax matrix. When "`bottom_right`" is specified in the mask type,
                            causal masks are aligned to the bottom right corner.
        window_size: Optional[Tuple[int, int]], default = `None`
                    Sliding window size for local attention in encoder. If `None`, the
                    window size configured on this layer is used. The value is then passed
                    through `check_set_window_size` together with the self-attention mask type.
        encoder_output : Optional[torch.Tensor], default = `None`
             Output of the encoder block to be fed into the decoder block if using
             `layer_type="decoder"`.
        enc_dec_attn_mask : Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]],
             default = `None`
             Boolean tensor(s) used to mask out inter-attention softmax input if
             using `layer_type="decoder"`. For padding masks it should be a tuple of two masks
             in [batch_size, 1, 1, seqlen_q] and [batch_size, 1, 1, seqlen_kv]. For
             "`arbitrary`" mask it should be broadcastable to
             [batch_size, num_heads, max_seqlen_q, max_seqlen_kv]. It should be `None` for
             causal masks and "`no_mask`". A `True` value means the corresponding position is
             masked out and a `False` means that position is allowed to participate in attention.
        enc_dec_attn_mask_type: {'no_mask', 'causal', 'padding', 'padding_causal', 'arbitrary'},
                               default = `None`
                               Type of attention mask passed into softmax operation for decoder.
                               If `None`, the mask type configured on this layer is used.
        enc_dec_window_size: Optional[Tuple[int, int]], default = `None`
                            Sliding window size for local attention in decoder. If `None`, the
                            window size configured on this layer is used, then passed through
                            `check_set_window_size` with the decoder mask type.
        is_first_microbatch : {True, False, None}, default = None
                             During training using either gradient accumulation or
                             pipeline parallelism a minibatch of data is further split
                             into microbatches. Between the microbatches of the same minibatch
                             the model weights are not updated. Setting this parameter indicates
                             whether the current microbatch is the first in a minibatch or not.
                             When set, this parameter enables additional optimizations:

                             * during FP8 training, it allows caching of the FP8 versions of
                               the weights
                             * it also allows skipping gradient accumulation during the
                               first microbatch (since it is the first gradient being
                               produced)
        checkpoint_core_attention: bool, default = `False`
                                  If true, forward activations for core attention are recomputed
                                  during the backward pass in order to save memory that would
                                  otherwise be occupied to store the forward activations until
                                  backprop.
        inference_params: InferenceParams, default = None
                         Inference parameters that are passed to the main model in order
                         to efficiently calculate and store the context during inference.
        rotary_pos_emb: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]],
                       default = `None`
                       Embeddings for query and key tensors for applying rotary position
                       embedding. By default no input embedding is applied.
        core_attention_bias_type: str, default = `no_bias`
                    Bias type, {`no_bias`, `pre_scale_bias`, `post_scale_bias`, `alibi`}
        core_attention_bias: Optional[torch.Tensor], default = `None`
                    Bias tensor for Q * K.T
        alibi_slopes: Optional[torch.Tensor], default = `None`
                     ALiBi slopes in FP32 and shape [nheads] or [batch_size, nheads].
                     It adds a bias of (-alibi_slope * (i + seqlen_k - seqlen_q - j))
                     to the attention score of query i and key j.
        cu_seqlens_q: Optional[torch.Tensor], default = `None`
                   Cumulative sum of sequence lengths (without offset) in a batch for `query_layer`,
                   with shape [batch_size + 1] and dtype torch.int32.
                   Only forwarded to self-attention.
        cu_seqlens_kv: Optional[torch.Tensor], default = `None`
                   Cumulative sum of sequence lengths (without offset) in a batch for `key_layer`
                   and `value_layer`, with shape [batch_size + 1] and dtype torch.int32.
                   Only forwarded to self-attention.
        max_seqlen_q: Optional[int], default = `None`
                      Maximum sequence length in `query_layer`.
                      Calculated from `cu_seqlens_q` if not provided.
        max_seqlen_kv: Optional[int], default = `None`
                       Maximum sequence length in `key_layer` and `value_layer`.
                       Calculated from `cu_seqlens_kv` if not provided.
        fast_zero_fill: bool, default = `True`
                    Whether to set output tensors to 0 or not before use.

        Returns
        -------
        torch.Tensor
            Output of the layer (annotated in the code below as [s, b, h]).
        """

        # Resolve per-call overrides, falling back to the layer's configuration.
        if self_attn_mask_type is None:
            self_attn_mask_type = self.self_attn_mask_type
        if window_size is None:
            window_size = self.window_size
        window_size = check_set_window_size(self_attn_mask_type, window_size)
        if enc_dec_attn_mask_type is None:
            enc_dec_attn_mask_type = self.enc_dec_attn_mask_type
        if enc_dec_window_size is None:
            enc_dec_window_size = self.enc_dec_window_size
        enc_dec_window_size = check_set_window_size(enc_dec_attn_mask_type, enc_dec_window_size)

        assert (
            self_attn_mask_type in AttnMaskTypes
        ), f"self_attn_mask_type {self_attn_mask_type} not supported"
        assert (
            enc_dec_attn_mask_type in AttnMaskTypes
        ), f"enc_dec_attn_mask_type {enc_dec_attn_mask_type} not supported"

        hidden_states = hidden_states.contiguous()

        # With sequence parallelism each TP rank holds only its slice of the sequence.
        if self.sequence_parallel and self.seq_length is not None:
            assert (
                hidden_states.shape[0] == self.seq_length // self.tp_size
            ), "Sequence dimension must be split across TP group when using sequence parallel."

        if (
            "padding" in self_attn_mask_type or self_attn_mask_type == "arbitrary"
        ) and attention_mask is not None:
            assert attention_mask.dtype == torch.bool, "Attention mask must be a boolean tensor"
        if (
            "padding" in enc_dec_attn_mask_type or enc_dec_attn_mask_type == "arbitrary"
        ) and enc_dec_attn_mask is not None:
            # The mask may be a single tensor or a (q, kv) tuple of tensors. Check a single
            # tensor's dtype directly instead of iterating over its leading (batch)
            # dimension in Python, which also failed on 0-d tensors.
            enc_dec_masks = (
                (enc_dec_attn_mask,)
                if isinstance(enc_dec_attn_mask, torch.Tensor)
                else enc_dec_attn_mask
            )
            assert all(
                mask.dtype == torch.bool for mask in enc_dec_masks
            ), "Encoder-decoder attention mask must be boolean tensor(s)"

        # For AMP
        if torch.is_autocast_enabled():
            hidden_states = cast_if_needed(hidden_states, torch.get_autocast_gpu_dtype())

        # Self attention.
        self_attention_outputs = self.self_attention(
            hidden_states,
            attention_mask=attention_mask,
            attn_mask_type=self_attn_mask_type,
            window_size=window_size,
            inference_params=inference_params,
            is_first_microbatch=is_first_microbatch,
            checkpoint_core_attention=checkpoint_core_attention,
            rotary_pos_emb=rotary_pos_emb,
            core_attention_bias_type=core_attention_bias_type,
            core_attention_bias=core_attention_bias,
            alibi_slopes=alibi_slopes,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_kv=cu_seqlens_kv,
            max_seqlen_q=max_seqlen_q,
            max_seqlen_kv=max_seqlen_kv,
            fast_zero_fill=fast_zero_fill,
        )

        if self.apply_residual_connection_post_layernorm and not self.output_layernorm:
            # The attention module also returns its layernorm output, used as the residual.
            attention_output, attention_bias, residual = self_attention_outputs
            hidden_states = self._bias_dropout_add(
                attention_output, attention_bias, residual, self.drop_path
            )
        elif not self.parallel_attention_mlp:
            attention_output, attention_bias = self_attention_outputs
            hidden_states = self._bias_dropout_add(
                attention_output, attention_bias, hidden_states, self.drop_path
            )
        # With parallel_attention_mlp, hidden_states is intentionally left unchanged here:
        # the MLP below consumes the same input, and the attention output is combined
        # with the MLP output in the final residual add.

        # Cross attention.
        if self.layer_type == "decoder":
            inter_attention_outputs = self.inter_attention(
                hidden_states,
                attention_mask=enc_dec_attn_mask,
                attn_mask_type=enc_dec_attn_mask_type,
                window_size=enc_dec_window_size,
                encoder_output=encoder_output,
                is_first_microbatch=is_first_microbatch,
                checkpoint_core_attention=checkpoint_core_attention,
                core_attention_bias_type=core_attention_bias_type,
                core_attention_bias=core_attention_bias,
                alibi_slopes=alibi_slopes,
                fast_zero_fill=fast_zero_fill,
            )
            if self.apply_residual_connection_post_layernorm:
                attention_output, attention_bias, residual = inter_attention_outputs
            else:
                attention_output, attention_bias = inter_attention_outputs
                residual = hidden_states

            # NOTE(review): unlike the self-attention and MLP residuals, drop_path is not
            # applied to the cross-attention branch; confirm this is intended.
            hidden_states = self._bias_dropout_add(attention_output, attention_bias, residual)

        # MLP.
        mlp_outputs = self.layernorm_mlp(
            hidden_states,
            is_first_microbatch=is_first_microbatch,
        )
        if self.apply_residual_connection_post_layernorm:
            mlp_output, mlp_bias, residual = mlp_outputs
            output = self._bias_dropout_add(mlp_output, mlp_bias, residual, self.drop_path)
        elif self.parallel_attention_mlp:
            # The MLP output is passed in the "bias" slot so that
            # output = residual + drop_path(dropout(attention_output + mlp_output)).
            # Presumably both sub-modules return a single tensor (no bias) in this mode;
            # verify against the module configuration in __init__.
            output = self._bias_dropout_add(
                self_attention_outputs, mlp_outputs, hidden_states, self.drop_path
            )
        else:
            mlp_output, mlp_bias = mlp_outputs
            output = self._bias_dropout_add(mlp_output, mlp_bias, hidden_states, self.drop_path)

        # For BERT like architectures.
        if self.output_layernorm:
            output = self.layernorm(output)

        # output: [s, b, h]
        return output

    def _bias_dropout_add(self, hidden_state, bias, residual, drop_path=None):
        """Combine a sub-layer output with its residual stream.

        Computes ``residual + drop_path(dropout(hidden_state + bias))``, with dropout
        probability ``self.hidden_dropout`` active only while ``self.training``.

        When a non-empty ``bias`` is given and no ``drop_path`` is requested, the work is
        delegated to a bias-dropout-add function from ``transformer_engine.pytorch.jit``
        (the fused train/inference variants if ``self.bias_dropout_fusion`` is set,
        otherwise the one returned by ``get_bias_dropout_add``), run under
        ``self.bias_dropout_add_exec_handler()``. Otherwise the computation is done
        step by step with plain PyTorch ops.

        Parameters
        ----------
        hidden_state : torch.Tensor
            Output of the sub-layer (attention or MLP).
        bias : Optional[torch.Tensor]
            Tensor added to ``hidden_state`` before dropout. ``None`` or an empty tensor
            means "no bias". Callers may pass any addable tensor here (e.g. the MLP
            output in the parallel attention/MLP path).
        residual : torch.Tensor
            Tensor added to the (dropped-out) branch.
        drop_path : Optional[Callable], default = None
            Stochastic-depth module applied to the branch before the residual add.

        Returns
        -------
        torch.Tensor
            The combined tensor.
        """
        has_bias = bias is not None and bias.numel() != 0

        # Plain PyTorch path: needed whenever drop_path must run between dropout and
        # the residual add, or when there is no bias to fuse.
        if drop_path is not None or not has_bias:
            branch = hidden_state + bias if has_bias else hidden_state
            branch = torch.nn.functional.dropout(
                branch, p=self.hidden_dropout, training=self.training
            )
            if drop_path is not None:
                branch = drop_path(branch)
            return residual + branch

        # Fused path: choose the bias-dropout-add implementation.
        if not self.bias_dropout_fusion:
            fused_fn = get_bias_dropout_add(self.training)
        elif self.training:
            fused_fn = bias_dropout_add_fused_train
        else:
            fused_fn = bias_dropout_add_fused_inference

        with self.bias_dropout_add_exec_handler():
            return fused_fn(hidden_state, bias, residual, self.hidden_dropout)