transformer.py 35.5 KB
Newer Older
1
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
Przemek Tredak's avatar
Przemek Tredak committed
2
3
4
5
6
#
# See LICENSE for license information.

"""Transformer."""
import os
7
import warnings
Przemek Tredak's avatar
Przemek Tredak committed
8
from contextlib import nullcontext
9
from typing import Callable, List, Optional, Tuple, Union
Przemek Tredak's avatar
Przemek Tredak committed
10
11
12

import torch

13
from transformer_engine.pytorch.module import LayerNormMLP, LayerNorm, RMSNorm
14
15
16
17
18
from transformer_engine.pytorch.attention import (
    InferenceParams,
    MultiheadAttention,
    check_set_window_size,
)
Przemek Tredak's avatar
Przemek Tredak committed
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
from transformer_engine.pytorch.jit import (
    set_jit_fusion_options,
    warmup_jit_bias_dropout_add_all_dtypes,
    get_bias_dropout_add,
    bias_dropout_add_fused_train,
    bias_dropout_add_fused_inference,
)
from transformer_engine.pytorch.utils import (
    cast_if_needed,
    get_default_init_method,
)
from transformer_engine.pytorch.constants import (
    AttnMaskTypes,
    LayerTypes,
    dist_group_type,
)
35
36
from transformer_engine.pytorch.distributed import get_distributed_world_size

Przemek Tredak's avatar
Przemek Tredak committed
37

38
# Install a filter with the "module" action for DeprecationWarning: matching
# warnings are reported on their first occurrence per issuing module. The
# `module` argument is a regex matched against the start of the issuing
# module's name, so only modules whose name begins with "transformer" are affected.
warnings.filterwarnings("module", category=DeprecationWarning, module="transformer")
cyanguwa's avatar
cyanguwa committed
39
40


41
__all__ = ["TransformerLayer"]
cyanguwa's avatar
cyanguwa committed
42

Przemek Tredak's avatar
Przemek Tredak committed
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69

class DropPath(torch.nn.Module):
    """Drop paths (Stochastic Depth) per sample
    (when applied in main path of residual blocks).

    Parameters
    ----------
    drop_prob : float, default = 0.0
               probability of dropping the whole path for a given sample
               (index along dimension 0). Surviving samples are rescaled by
               `1 / (1 - drop_prob)`. With `drop_prob >= 1.0` every sample is
               dropped and the output is all zeros.
    """

    def __init__(self, drop_prob: float = 0.0) -> None:
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        """DropPath FWD

        Returns `hidden_state` unchanged when `drop_prob` is 0 or the module is
        in eval mode; otherwise returns a tensor of the same shape in which each
        sample along dimension 0 is either zeroed or rescaled.
        """

        if self.drop_prob == 0.0 or not self.training:
            return hidden_state

        keep_prob = 1 - self.drop_prob
        if keep_prob <= 0.0:
            # Every path is dropped. Dividing by keep_prob here would compute
            # (x / 0) * 0 == NaN, so emit zeros directly; multiplying (rather
            # than allocating a fresh tensor) keeps the autograd graph intact.
            return hidden_state * 0.0

        # work with diff dim tensors, not just 2D ConvNets:
        # one mask value per sample, broadcast over all remaining dims.
        shape = (hidden_state.shape[0],) + (1,) * (hidden_state.ndim - 1)
        random_tensor = keep_prob + torch.rand(
            shape, dtype=hidden_state.dtype, device=hidden_state.device
        )
        random_tensor.floor_()  # binarize: 1 -> keep sample, 0 -> drop sample
        # Rescale survivors so the expected value of the output matches the input.
        return hidden_state.div(keep_prob) * random_tensor


class TransformerLayer(torch.nn.Module):
70
    r"""
Przemek Tredak's avatar
Przemek Tredak committed
71
72
73
    TransformerLayer is made up of an attention block and a feedforward network (MLP).
    This standard layer is based on the paper "Attention Is All You Need".

74
    .. note::
75

76
77
        Argument :attr:`attention_mask` in the `forward` call is only used when
        :attr:`self_attn_mask_type` includes `"padding"` or `"arbitrary"`.
78

Przemek Tredak's avatar
Przemek Tredak committed
79
80
81
82
83
84
85
86
    Parameters
    ----------
    hidden_size : int
                 size of each input sample.
    ffn_hidden_size : int
                     intermediate size to which input samples are projected.
    num_attention_heads : int
                         number of attention heads in the transformer layer.
87
88
89
90
91
92
93
94
    num_gqa_groups : int, default = `None`
                         number of GQA groups in the transformer layer.
                         Grouped Query Attention is described in
                         `this paper <https://arxiv.org/pdf/2305.13245.pdf>`_.
                         This only affects the keys and values, not the queries.
                         GQA-1 is equivalent to Multi-Query Attention
                         (`MQA <https://arxiv.org/pdf/1911.02150.pdf>`_), while GQA-H
                         is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`.
Przemek Tredak's avatar
Przemek Tredak committed
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
    layernorm_epsilon : float, default = 1e-5
                       a value added to the denominator of layer normalization
                       for numerical stability.
    hidden_dropout: float, default = 0.1
                   dropout probability for the dropout op after FC2 layer.
    attention_dropout: float, default = 0.1
                      dropout probability for the dropout op during multi-head attention.
    init_method : Callable, default = `None`
                 used for initializing weights of QKV and FC1 weights in the following way:
                 `init_method(weight)`. When set to `None`, defaults to
                 `torch.nn.init.normal_(mean=0.0, std=0.023)`.
    output_layer_init_method : Callable, default = `None`
                              used for initializing weights of PROJ and FC2 in the following way:
                              `output_layer_init_method(weight)`. When set to `None`, defaults to
                              `torch.nn.init.normal_(mean=0.0, std=0.023)`.
    apply_residual_connection_post_layernorm : bool, default = `False`
                                              if set to `True`, residual connections are taken
                                              from the output of layer norm (default is taken
                                              from input of layer norm)
    layer_number: int, default = `None`
                 layer number of the current `TransformerLayer` when multiple such modules are
                 concatenated to form a transformer block.
    output_layernorm: bool, default = `False`
                     if set to `True`, layer normalization is applied on the output side,
                     after the final dropout-add. default behavior is to apply layer
                     normalization on the input side, before the QKV transformation.
121
122
123
124
125
    parallel_attention_mlp: bool, default = `False`
                           if set to `True`, self-attention and feedforward network are computed
                           based on the same input (in parallel) instead of sequentially.
                           Both blocks have an independent normalization.
                           This architecture is used in `Falcon` models.
Przemek Tredak's avatar
Przemek Tredak committed
126
127
128
129
130
    layer_type: {'encoder', 'decoder'}, default = `encoder`
               if set to `decoder`, an additional cross-attn block is added after self-attn.
               This can be used for structures like `T5` Transformer in conjunction with the
               `encoder` option.
    kv_channels: int, default = `None`
131
                number of query-key-value channels per attention head. defaults to
Przemek Tredak's avatar
Przemek Tredak committed
132
                :attr:`hidden_size` / :attr:`num_attention_heads` if `None`.
133
134
    self_attn_mask_type: {'no_mask', 'padding', 'causal', 'padding_causal', 'causal_bottom_right',
                        'padding_causal_bottom_right', 'arbitrary'},
135
                        default = `causal`
136
137
138
139
140
                        type of attention mask passed into softmax operation for encoder.
                        Overridden by :attr:`self_attn_mask_type` in the `forward` method.
                        The forward arg is useful for dynamically changing mask types, e.g.
                        a different mask for training and inference. The init arg is useful
                        for cases involving compilation/tracing, e.g. ONNX export.
141
    window_size: Optional[Tuple[int, int]], default = `None`
142
143
144
145
146
147
148
149
150
151
152
                sliding window size for local attention in encoder, where query at position i
                attends to keys in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k
                - seqlen_q + window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean
                no sliding window and "`causal`" mask specifically. Similar to
                :attr:`self_attn_mask_type`, it can be overridden by :attr:`window_size`
                in `forward` as well.
    enc_dec_attn_mask_type: {'no_mask', 'causal', 'padding', 'padding_causal', 'arbitrary'},
                           default = `no_mask`
                           type of attention mask passed into softmax operation for decoder.
    enc_dec_window_size: Optional[Tuple[int, int]], default = `None`
                        sliding window size for local attention in decoder.
153
154
155
156
157
158
159
    zero_centered_gamma : bool, default = 'False'
                         if set to 'True', gamma parameter in LayerNorm is initialized to 0 and
                         the LayerNorm formula changes to

                         .. math::
                            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} *
                            (1 + \gamma) + \beta
160
161
    normalization : { 'LayerNorm', 'RMSNorm' }, default = 'LayerNorm'
                   type of normalization applied.
162
163
164
165
166
167
    qkv_weight_interleaved : bool, default = `True`
                            if set to `False`, the QKV weight is interpreted as a concatenation of
                            query, key, and value weights along the `0th` dimension. The default
                            interpretation is that the individual `q`, `k`, and `v` weights for each
                            attention head are interleaved. This parameter is set to `False` when
                            using :attr:`fuse_qkv_params=False`.
ngoyal2707's avatar
ngoyal2707 committed
168
169
    bias : bool, default = `True`
          if set to `False`, the transformer layer will not learn any additive biases.
170
171
    activation : str, default = 'gelu'
          Type of activation used in MLP block.
172
          Options are: 'gelu', 'relu', 'reglu', 'geglu', 'swiglu', 'qgelu' and 'srelu'.
173
174
175
176
    device : Union[torch.device, str], default = "cuda"
          The device on which the parameters of the model will be allocated. It is the user's
          responsibility to ensure all parameters are moved to the GPU before running the
          forward pass.
177
178
179
180
181
182
183
184
    attn_input_format: {'sbhd', 'bshd'}, default = 'sbhd'
                         This controls whether the dimensions of the
                         intermediate hidden states is 'batch first' ('bshd') or
                         'sequence first' ('sbhd'). `s` stands for the sequence
                         length, `b` batch size, `h` the number of heads, `d`
                         head size. Note that these formats are very closely
                         related to the `qkv_format` in the `MultiheadAttention`
                         and `DotProductAttention` modules.
ngoyal2707's avatar
ngoyal2707 committed
185

Przemek Tredak's avatar
Przemek Tredak committed
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
    Parallelism parameters
    ----------------------
    set_parallel_mode : bool, default = `False`
                      if set to `True`, QKV and FC1 layers are used as Column Parallel
                      whereas PROJ and FC2 is used as Row Parallel as described
                      `here <https://arxiv.org/pdf/1909.08053.pdf>`_.
    sequence_parallel : bool, default = `False`
                       if set to `True`, uses sequence parallelism.
    tp_group : ProcessGroup, default = `None`
              tensor parallel process group.
    tp_size : int, default = 1
             used as TP (tensor parallel) world size when TP groups are not formed during
             initialization. In this case, users must call the
             `set_tensor_parallel_group(tp_group)` method on the initialized module before the
             forward pass to supply the tensor parallel group needed for tensor and sequence
             parallel collectives.

    Optimization parameters
    -----------------------
    fuse_wgrad_accumulation : bool, default = 'False'
                             if set to `True`, enables fusing of creation and accumulation of
207
208
209
210
                             the weight gradient. When enabled, it is assumed that the weights
                             have an additional `main_grad` attribute (used instead of the
                             regular `grad`) which is a pre-allocated buffer of the correct
                             size to accumulate gradients in.
211
    params_dtype : torch.dtype, default = `torch.get_default_dtype()`
Przemek Tredak's avatar
Przemek Tredak committed
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
                  it controls the type used to allocate the initial parameters. Useful when
                  the model is trained with lower precision and the original FP32 parameters
                  would not fit in GPU memory.
    seq_length: int
               sequence length of input samples. Needed for JIT Warmup, a technique where jit
               fused functions are warmed up before training to ensure same kernels are used for
               forward propagation and activation recompute phase.
    micro_batch_size: int
                     batch size per training step. Needed for JIT Warmup, a technique where jit
                     fused functions are warmed up before training to ensure same kernels are
                     used for forward propagation and activation recompute phase.
    drop_path_rate: float, default = 0.0
                   when > 0.0, applies stochastic depth per sample in
                   the main path of the residual block.
    fuse_qkv_params: bool, default = 'False'
                    if set to `True`, `TransformerLayer` module exposes a single fused
                    parameter for query-key-value. This enables optimizations such as QKV
                    fusion without concatenations/splits and also enables the argument
                    `fuse_wgrad_accumulation`.
    """

    def __init__(
        self,
        hidden_size: int,
        ffn_hidden_size: int,
        num_attention_heads: int,
        num_gqa_groups: Optional[int] = None,
        layernorm_epsilon: float = 1e-5,
        hidden_dropout: float = 0.1,
        attention_dropout: float = 0.1,
        init_method: Optional[Callable] = None,
        output_layer_init_method: Optional[Callable] = None,
        layer_number: Optional[int] = None,
        kv_channels: Optional[int] = None,
        self_attn_mask_type: str = "causal",
        window_size: Optional[Tuple[int, int]] = None,
        enc_dec_attn_mask_type: str = "no_mask",
        enc_dec_window_size: Optional[Tuple[int, int]] = None,
        tp_group: Optional[dist_group_type] = None,
        tp_size: int = 1,
        params_dtype: Optional[torch.dtype] = None,
        get_rng_state_tracker: Optional[Callable] = None,
        fuse_wgrad_accumulation: bool = False,
        seq_length: Optional[int] = None,
        micro_batch_size: Optional[int] = None,
        sequence_parallel: bool = False,
        apply_residual_connection_post_layernorm: bool = False,
        output_layernorm: bool = False,
        parallel_attention_mlp: bool = False,
        layer_type: str = "encoder",
        drop_path_rate: float = 0.0,
        set_parallel_mode: bool = False,
        fuse_qkv_params: bool = False,
        zero_centered_gamma: bool = False,
        qkv_weight_interleaved: bool = True,
        ub_tp_comm_overlap: bool = False,
        ub_bulk_wgrad: bool = True,
        ub_bulk_dgrad: bool = True,
        ub_overlap_ag: bool = True,
        ub_overlap_rs: bool = True,
        ub_overlap_rs_dgrad: bool = False,
        bias: bool = True,
        activation: str = "gelu",
        normalization: str = "LayerNorm",
        device: Union[torch.device, str] = "cuda",
        attn_input_format: str = "sbhd",
    ) -> None:
        """Validate the configuration and build the layer's submodules.

        Creates the self-attention block, the cross-attention block (only for
        `layer_type="decoder"`), the LayerNormMLP block, the optional DropPath
        and the optional output normalization. Most arguments are described in
        the class docstring.
        """
        super().__init__()

        # Default mask types / sliding windows for self- and cross-attention.
        self.self_attn_mask_type = self_attn_mask_type
        self.window_size = check_set_window_size(self_attn_mask_type, window_size)
        self.enc_dec_attn_mask_type = enc_dec_attn_mask_type
        self.enc_dec_window_size = check_set_window_size(
            enc_dec_attn_mask_type, enc_dec_window_size
        )

        if params_dtype is None:
            params_dtype = torch.get_default_dtype()

        # The individual ub_* switches are only honoured when
        # `ub_tp_comm_overlap` itself is enabled.
        ub_bulk_wgrad = ub_tp_comm_overlap and ub_bulk_wgrad
        ub_bulk_dgrad = ub_tp_comm_overlap and ub_bulk_dgrad
        ub_overlap_ag = ub_tp_comm_overlap and ub_overlap_ag
        ub_overlap_rs = ub_tp_comm_overlap and ub_overlap_rs
        ub_overlap_rs_dgrad = ub_tp_comm_overlap and ub_overlap_rs_dgrad

        # Bias+dropout+add fusion is on unless NVTE_BIAS_DROPOUT_FUSION is set to 0.
        fusion_flag = os.getenv("NVTE_BIAS_DROPOUT_FUSION", "1")
        bias_dropout_fusion = bool(int(fusion_flag))

        self.layer_number = layer_number
        self.output_layernorm = output_layernorm
        self.layer_type = layer_type
        self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm

        # Parallel attention/MLP is restricted to plain pre-norm encoder layers.
        if parallel_attention_mlp:
            assert self.layer_type == "encoder", "parallel_attention requires layer_type='encoder'"
            assert not self.apply_residual_connection_post_layernorm, (
                "parallel_attention and apply_residual_connection_post_layernorm "
                "not supported simultaneously."
            )
            assert (
                not self.output_layernorm
            ), "parallel_attention and output_layernorm not supported simultaneously"
        self.parallel_attention_mlp = parallel_attention_mlp

        assert layer_type in LayerTypes, f"layer_type {layer_type} not supported"

        if not fuse_qkv_params:
            assert (
                not fuse_wgrad_accumulation
            ), "Gradient accumulation fusion requires single QKV parameter."
            # Separate q/k/v parameters are never interleaved.
            qkv_weight_interleaved = False

        # Fall back to hidden_size / num_attention_heads channels per head.
        self.kv_channels = kv_channels or (hidden_size // num_attention_heads)

        if init_method is None:
            init_method = get_default_init_method()
        if output_layer_init_method is None:
            output_layer_init_method = get_default_init_method()

        # An explicit process group takes precedence over the `tp_size` hint.
        if tp_group is None:
            self.tp_size = tp_size
        else:
            self.tp_size = get_distributed_world_size(tp_group)
        self.sequence_parallel = (self.tp_size > 1) and sequence_parallel
        self.seq_length = seq_length

        self.get_rng_state_tracker = get_rng_state_tracker
        self.attn_input_format = attn_input_format

        # Positional and keyword arguments shared by both attention blocks.
        attention_args = (
            hidden_size,
            num_attention_heads,
            self.kv_channels,
            attention_dropout,
            layernorm_epsilon,
            init_method,
            output_layer_init_method,
        )
        common_attention_kwargs = dict(
            layer_number=layer_number,
            tp_group=tp_group,
            tp_size=self.tp_size,
            num_gqa_groups=num_gqa_groups,
            fuse_wgrad_accumulation=fuse_wgrad_accumulation,
            get_rng_state_tracker=get_rng_state_tracker,
            sequence_parallel=self.sequence_parallel,
            params_dtype=params_dtype,
            return_layernorm_output=apply_residual_connection_post_layernorm,
            set_parallel_mode=set_parallel_mode,
            fuse_qkv_params=fuse_qkv_params,
            zero_centered_gamma=zero_centered_gamma,
            qkv_weight_interleaved=qkv_weight_interleaved,
            ub_bulk_wgrad=ub_bulk_wgrad,
            ub_bulk_dgrad=ub_bulk_dgrad,
            ub_overlap_ag=ub_overlap_ag,
            ub_overlap_rs=ub_overlap_rs,
            ub_overlap_rs_dgrad=ub_overlap_rs_dgrad,
            qkv_format=self.attn_input_format,
        )

        self.self_attention = MultiheadAttention(
            *attention_args,
            **common_attention_kwargs,
            input_layernorm=not output_layernorm,
            attention_type="self",
            bias=bias,
            return_bias=not self.parallel_attention_mlp,
            normalization=normalization,
            device=device,
        )

        # Decoder layers get an additional cross-attention block.
        if layer_type == "decoder":
            self.inter_attention = MultiheadAttention(
                *attention_args,
                **common_attention_kwargs,
                attn_mask_type=enc_dec_attn_mask_type,
                input_layernorm=True,
                attention_type="cross",
                bias=bias,
                return_bias=True,
                normalization=normalization,
                device=device,
            )

        # LayerNorm -> activation(Linear + Bias) -> Linear
        # parallel_mode not supported for LayerNormMLP,
        # FC1 is CPL and FC2 is RPL
        # In the case of GLU activation, FC1 handles both
        # Linear layers before the activation
        self.layernorm_mlp = LayerNormMLP(
            hidden_size,
            ffn_hidden_size,
            eps=layernorm_epsilon,
            fuse_wgrad_accumulation=fuse_wgrad_accumulation,
            tp_group=tp_group,
            tp_size=self.tp_size,
            get_rng_state_tracker=get_rng_state_tracker,
            init_method=init_method,
            output_layer_init_method=output_layer_init_method,
            bias=bias,
            return_bias=not self.parallel_attention_mlp,
            sequence_parallel=self.sequence_parallel,
            params_dtype=params_dtype,
            return_layernorm_output=apply_residual_connection_post_layernorm,
            seq_length=seq_length,
            micro_batch_size=micro_batch_size,
            set_parallel_mode=set_parallel_mode,
            zero_centered_gamma=zero_centered_gamma,
            ub_bulk_wgrad=ub_bulk_wgrad,
            ub_bulk_dgrad=ub_bulk_dgrad,
            ub_overlap_rs_dgrad=ub_overlap_rs_dgrad,
            ub_overlap_rs=ub_overlap_rs,
            ub_overlap_ag=ub_overlap_ag,
            activation=activation,
            normalization=normalization,
            device=device,
        )

        self.hidden_dropout = hidden_dropout
        self.bias_dropout_fusion = bias_dropout_fusion
        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 else None

        # Set bias+dropout+add fusion grad_enable execution handler.
        version_fields = torch.__version__.split(".")
        torch_major = int(version_fields[0])
        torch_minor = int(version_fields[1])
        use_nvfuser = (torch_major, torch_minor) >= (1, 10)
        if use_nvfuser:
            self.bias_dropout_add_exec_handler = nullcontext
        else:
            self.bias_dropout_add_exec_handler = torch.enable_grad

        if self.bias_dropout_fusion:
            set_jit_fusion_options()
            if seq_length and micro_batch_size:
                # With sequence parallelism the warmup uses the sequence
                # length divided by the tensor-parallel size.
                warmup_seq_length = (
                    seq_length // self.tp_size if self.sequence_parallel else seq_length
                )
                warmup_jit_bias_dropout_add_all_dtypes(
                    hidden_size, warmup_seq_length, micro_batch_size
                )

        if self.output_layernorm:
            norm_classes = {
                "LayerNorm": LayerNorm,
                "RMSNorm": RMSNorm,
            }
            self.layernorm = norm_classes[normalization](
                hidden_size,
                eps=layernorm_epsilon,
                sequence_parallel=self.sequence_parallel,
                params_dtype=params_dtype,
                zero_centered_gamma=zero_centered_gamma,
                device=device,
            )

    def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> None:
        """
        Forward the tensor parallel group to every submodule that exposes
        `set_tensor_parallel_group`. Must be called before the forward pass.

        Parameters
        ----------
        tp_group : ProcessGroup, default = `None`
                  tensor parallel process group.
        """
        # Walk all nested modules; the first entry yielded is this layer
        # itself and is discarded so the call does not recurse into itself.
        descendants = iter(self.modules())
        next(descendants, None)
        for submodule in descendants:
            if not hasattr(submodule, "set_tensor_parallel_group"):
                continue
            submodule.set_tensor_parallel_group(tp_group)

474
475
476
477
478
479
480
481
482
    def reset_fp8_meta_tensors(self) -> None:
        """Call `reset_fp8_meta_tensors` on every submodule that defines it."""
        # Walk all nested modules; the first entry yielded is this layer
        # itself and is discarded so the call does not recurse into itself.
        descendants = iter(self.modules())
        next(descendants, None)
        for submodule in descendants:
            if not hasattr(submodule, "reset_fp8_meta_tensors"):
                continue
            submodule.reset_fp8_meta_tensors()

483
    def set_context_parallel_group(
        self,
        cp_group: Union[dist_group_type, None],
        cp_global_ranks: List[int],
        cp_stream: torch.cuda.Stream,
    ) -> None:
        """
        Forward the context parallel attributes to every submodule that exposes
        `set_context_parallel_group`. Must be called before the forward pass.

        Parameters
        ----------
        cp_group : ProcessGroup
                  context parallel process group.
        cp_global_ranks : List[int]
                         list of global ranks in the context group.
        cp_stream : torch.cuda.Stream
                   cuda stream for context parallel execution.
        """
        # Walk all nested modules; the first entry yielded is this layer
        # itself and is discarded so the call does not recurse into itself.
        descendants = iter(self.modules())
        next(descendants, None)
        for submodule in descendants:
            if not hasattr(submodule, "set_context_parallel_group"):
                continue
            submodule.set_context_parallel_group(cp_group, cp_global_ranks, cp_stream)
508

Przemek Tredak's avatar
Przemek Tredak committed
509
510
511
    def forward(
        self,
        hidden_states: torch.Tensor,
cyanguwa's avatar
cyanguwa committed
512
        attention_mask: Optional[torch.Tensor] = None,
513
        self_attn_mask_type: Optional[str] = None,
514
        window_size: Optional[Tuple[int, int]] = None,
Przemek Tredak's avatar
Przemek Tredak committed
515
        encoder_output: Optional[torch.Tensor] = None,
516
        enc_dec_attn_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
517
518
        enc_dec_attn_mask_type: Optional[str] = None,
        enc_dec_window_size: Optional[Tuple[int, int]] = None,
Przemek Tredak's avatar
Przemek Tredak committed
519
        is_first_microbatch: Optional[bool] = None,
cyanguwa's avatar
cyanguwa committed
520
        checkpoint_core_attention: bool = False,
521
        inference_params: Optional[InferenceParams] = None,
522
        rotary_pos_emb: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
523
524
        core_attention_bias_type: str = "no_bias",
        core_attention_bias: Optional[torch.Tensor] = None,
525
        alibi_slopes: Optional[torch.Tensor] = None,
526
        fast_zero_fill: bool = True,
Przemek Tredak's avatar
Przemek Tredak committed
527
528
529
530
    ) -> torch.Tensor:
        """
        Transformer Layer: attention block and a feedforward network (MLP)

531
532
        .. note::

533
534
            Argument :attr:`attention_mask` is only used when :attr:`self_attn_mask_type`
            includes `"padding"` or `"arbitrary"`.
535

Przemek Tredak's avatar
Przemek Tredak committed
536
537
538
539
        Parameters
        ----------
        hidden_states : torch.Tensor
             Input tensor.
540
        attention_mask : Optional[torch.Tensor], default = `None`
541
542
543
544
                        Boolean tensor used to mask out self-attention softmax input. It should be
                        in [batch_size, 1, 1, seqlen_q] for padding masks, and broadcastable
                        to [batch_size, num_heads, max_seqlen_q, max_seqlen_kv] for "`arbitrary`"
                        mask. It should be `None` for causal masks and "`no_mask`" type.
545
546
                        A `True` value means the corresponding position is masked out and
                        a `False` means that position is allowed to participate in attention.
547
548
        self_attn_mask_type: {'no_mask', 'causal', 'padding', 'padding_causal',
                            'causal_bottom_right', 'padding_causal_bottom_right','arbitrary'},
549
                            default = `causal`
550
551
552
553
                            Type of attention mask passed into softmax operation for encoder.
                            By default, causal masks are aligned to the top left corner of
                            the softmax matrix. When "`bottom_right`" is specified in the mask type,
                            causal masks are aligned to the bottom right corner.
554
        window_size: Optional[Tuple[int, int]], default = `None`
555
                    Sliding window size for local attention in encoder.
cyanguwa's avatar
cyanguwa committed
556
        encoder_output : Optional[torch.Tensor], default = `None`
Przemek Tredak's avatar
Przemek Tredak committed
557
558
             Output of the encoder block to be fed into the decoder block if using
             `layer_type="decoder"`.
559
560
561
        enc_dec_attn_mask : Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]],
             default = `None`. Boolean tensors used to mask out inter-attention softmax input if
             using `layer_type="decoder"`. It should be a tuple of two masks in
562
             [batch_size, 1, 1, seqlen_q] and [batch_size, 1, 1, seqlen_kv] for padding masks.
563
             It should be broadcastable to [batch_size, num_heads, max_seqlen_q, max_seqlen_kv]
564
565
566
567
568
569
570
571
             for "`arbitrary`" mask. It should be `None` for causal masks and "`no_mask`".
             A `True` value means the corresponding position is masked out and a `False`
             means that position is allowed to participate in attention.
        enc_dec_attn_mask_type: {'no_mask', 'causal', 'padding', 'padding_causal', 'arbitrary'},
                               default = `None`
                               Type of attention mask passed into softmax operation for decoder.
        enc_dec_window_size: Optional[Tuple[int, int]], default = `None`
                            Sliding window size for local attention in decoder.
Przemek Tredak's avatar
Przemek Tredak committed
572
573
574
575
576
577
578
579
580
581
582
583
584
        is_first_microbatch : {True, False, None}, default = None
                             During training using either gradient accumulation or
                             pipeline parallelism a minibatch of data is further split
                             into microbatches. Between the microbatches of the same minibatch
                             the model weights are not updated. Setting this parameter indicates
                             whether the current microbatch is the first in a minibatch or not.
                             When set, this parameter enables additional optimizations:

                             * during FP8 training, it allows caching of the FP8 versions of
                               the weights
                             * it also allows skipping gradient accumulation during the
                               first microbatch (since it is the first gradient being
                               produced)
cyanguwa's avatar
cyanguwa committed
585
        checkpoint_core_attention: bool, default = `False`
Przemek Tredak's avatar
Przemek Tredak committed
586
587
588
589
                                  If true, forward activations for core attention are recomputed
                                  during the backward pass in order to save memory that would
                                  otherwise be occupied to store the forward activations until
                                  backprop.
590
591
592
        rotary_pos_emb: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], default = `None`
                       Embeddings for query and key tensors for applying rotary position
                       embedding. By default no input embedding is applied.
593
        core_attention_bias_type: str, default = `no_bias`
594
                    Bias type, {`no_bias`, `pre_scale_bias`, `post_scale_bias`, `alibi`}
595
596
        core_attention_bias: Optional[torch.Tensor], default = `None`
                    Bias tensor for Q * K.T
597
598
599
600
        alibi_slopes: Optional[torch.Tensor], default = `None`
                     ALiBi slopes in FP32 and shape [nheads] or [batch_size, nheads].
                     It adds a bias of (-alibi_slope * (i + seqlen_k - seqlen_q - j))
                     to the attention score of query i and key j.
601
602
        fast_zero_fill: bool, default = `True`
                    Whether to set output tensors to 0 or not before use.
603
604
605
        inference_params: InferenceParams, default = None
                         Inference parameters that are passed to the main model in order
                         to efficiently calculate and store the context during inference.
Przemek Tredak's avatar
Przemek Tredak committed
606
607
        """

608
        if self_attn_mask_type is None:
609
            self_attn_mask_type = self.self_attn_mask_type
610
611
        if window_size is None:
            window_size = self.window_size
612
613
614
615
616
617
        window_size = check_set_window_size(self_attn_mask_type, window_size)
        if enc_dec_attn_mask_type is None:
            enc_dec_attn_mask_type = self.enc_dec_attn_mask_type
        if enc_dec_window_size is None:
            enc_dec_window_size = self.enc_dec_window_size
        enc_dec_window_size = check_set_window_size(enc_dec_attn_mask_type, enc_dec_window_size)
618
619
620
621

        assert (
            self_attn_mask_type in AttnMaskTypes
        ), f"self_attn_mask_type {self_attn_mask_type} not supported"
622
623
624
        assert (
            enc_dec_attn_mask_type in AttnMaskTypes
        ), f"enc_dec_attn_mask_type {enc_dec_attn_mask_type} not supported"
625

626
627
        hidden_states = hidden_states.contiguous()

628
629
630
631
632
        if self.sequence_parallel and self.seq_length is not None:
            assert (
                hidden_states.shape[0] == self.seq_length // self.tp_size
            ), "Sequence dimension must be split across TP group when using sequence parallel."

633
634
635
636
        if (
            "padding" in self_attn_mask_type or self_attn_mask_type == "arbitrary"
        ) and attention_mask is not None:
            assert attention_mask.dtype == torch.bool, "Attention mask must be a boolean tensor"
637
638
639
640
641
642
        if (
            "padding" in enc_dec_attn_mask_type or enc_dec_attn_mask_type == "arbitrary"
        ) and enc_dec_attn_mask is not None:
            assert all(
                enc_dec_attn_mask[i].dtype == torch.bool for i in range(len(enc_dec_attn_mask))
            ), "Encoder-decoder attention mask must be boolean tensor(s)"
643

Przemek Tredak's avatar
Przemek Tredak committed
644
645
        # For AMP
        if torch.is_autocast_enabled():
646
            hidden_states = cast_if_needed(hidden_states, torch.get_autocast_gpu_dtype())
Przemek Tredak's avatar
Przemek Tredak committed
647
648
649
650

        # Self attention.
        self_attention_outputs = self.self_attention(
            hidden_states,
651
652
            attention_mask=attention_mask,
            attn_mask_type=self_attn_mask_type,
653
            window_size=enc_dec_window_size,
Przemek Tredak's avatar
Przemek Tredak committed
654
655
656
            inference_params=inference_params,
            is_first_microbatch=is_first_microbatch,
            checkpoint_core_attention=checkpoint_core_attention,
657
            rotary_pos_emb=rotary_pos_emb,
658
659
            core_attention_bias_type=core_attention_bias_type,
            core_attention_bias=core_attention_bias,
660
            alibi_slopes=alibi_slopes,
661
            fast_zero_fill=fast_zero_fill,
Przemek Tredak's avatar
Przemek Tredak committed
662
        )
ngoyal2707's avatar
ngoyal2707 committed
663

Przemek Tredak's avatar
Przemek Tredak committed
664
665
        if self.apply_residual_connection_post_layernorm and not self.output_layernorm:
            attention_output, attention_bias, residual = self_attention_outputs
666
667
668
669
            hidden_states = self._bias_dropout_add(
                attention_output, attention_bias, residual, self.drop_path
            )
        elif not self.parallel_attention_mlp:
Przemek Tredak's avatar
Przemek Tredak committed
670
            attention_output, attention_bias = self_attention_outputs
671
672
            hidden_states = self._bias_dropout_add(
                attention_output, attention_bias, hidden_states, self.drop_path
Przemek Tredak's avatar
Przemek Tredak committed
673
674
675
676
677
            )

        # Cross attention.
        if self.layer_type == "decoder":
            inter_attention_outputs = self.inter_attention(
678
                hidden_states,
679
                attention_mask=enc_dec_attn_mask,
Przemek Tredak's avatar
Przemek Tredak committed
680
681
682
                encoder_output=encoder_output,
                is_first_microbatch=is_first_microbatch,
                checkpoint_core_attention=checkpoint_core_attention,
683
684
                core_attention_bias_type=core_attention_bias_type,
                core_attention_bias=core_attention_bias,
685
                alibi_slopes=alibi_slopes,
686
                fast_zero_fill=fast_zero_fill,
Przemek Tredak's avatar
Przemek Tredak committed
687
688
689
690
691
            )
            if self.apply_residual_connection_post_layernorm:
                attention_output, attention_bias, residual = inter_attention_outputs
            else:
                attention_output, attention_bias = inter_attention_outputs
692
693
694
                residual = hidden_states

            hidden_states = self._bias_dropout_add(attention_output, attention_bias, residual)
Przemek Tredak's avatar
Przemek Tredak committed
695
696
697

        # MLP.
        mlp_outputs = self.layernorm_mlp(
698
699
            hidden_states,
            is_first_microbatch=is_first_microbatch,
Przemek Tredak's avatar
Przemek Tredak committed
700
701
702
        )
        if self.apply_residual_connection_post_layernorm:
            mlp_output, mlp_bias, residual = mlp_outputs
703
704
705
706
707
            output = self._bias_dropout_add(mlp_output, mlp_bias, residual, self.drop_path)
        elif self.parallel_attention_mlp:
            output = self._bias_dropout_add(
                self_attention_outputs, mlp_outputs, hidden_states, self.drop_path
            )
Przemek Tredak's avatar
Przemek Tredak committed
708
709
        else:
            mlp_output, mlp_bias = mlp_outputs
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
            output = self._bias_dropout_add(mlp_output, mlp_bias, hidden_states, self.drop_path)

        # For BERT like architectures.
        if self.output_layernorm:
            output = self.layernorm(output)

        # output: [s, b, h]
        return output

    def _bias_dropout_add(self, hidden_state, bias, residual, drop_path=None):
        if drop_path is None and bias.numel() != 0:
            if self.bias_dropout_fusion:
                if self.training:
                    bias_dropout_add_func = bias_dropout_add_fused_train
                else:
                    bias_dropout_add_func = bias_dropout_add_fused_inference
            else:
                bias_dropout_add_func = get_bias_dropout_add(self.training)
Przemek Tredak's avatar
Przemek Tredak committed
728
729

            with self.bias_dropout_add_exec_handler():
730
                output = bias_dropout_add_func(hidden_state, bias, residual, self.hidden_dropout)
Przemek Tredak's avatar
Przemek Tredak committed
731
        else:
732
733
            if bias.numel() != 0:
                hidden_state = hidden_state + bias
Przemek Tredak's avatar
Przemek Tredak committed
734
            out = torch.nn.functional.dropout(
735
                hidden_state, p=self.hidden_dropout, training=self.training
Przemek Tredak's avatar
Przemek Tredak committed
736
            )
737
738
            if drop_path is not None:
                out = drop_path(out)
ngoyal2707's avatar
ngoyal2707 committed
739
            output = residual + out
Przemek Tredak's avatar
Przemek Tredak committed
740
741

        return output