transformer.py 40.2 KB
Newer Older
1
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
Przemek Tredak's avatar
Przemek Tredak committed
2
3
4
5
6
#
# See LICENSE for license information.

"""Transformer."""
import os
7
import warnings
Przemek Tredak's avatar
Przemek Tredak committed
8
from contextlib import nullcontext
9
from typing import Callable, List, Optional, Tuple, Union
Przemek Tredak's avatar
Przemek Tredak committed
10
11
12

import torch

13
from transformer_engine.pytorch import torch_version
14
from transformer_engine.pytorch.module import LayerNormMLP, LayerNorm, RMSNorm
15
from transformer_engine.debug.pytorch.debug_state import TEDebugState
16
17
from transformer_engine.pytorch.attention.multi_head_attention import MultiheadAttention
from transformer_engine.pytorch.attention.inference import InferenceParams
Przemek Tredak's avatar
Przemek Tredak committed
18
19
20
21
22
23
24
25
26
27
from transformer_engine.pytorch.jit import (
    set_jit_fusion_options,
    warmup_jit_bias_dropout_add_all_dtypes,
    get_bias_dropout_add,
    bias_dropout_add_fused_train,
    bias_dropout_add_fused_inference,
)
from transformer_engine.pytorch.utils import (
    cast_if_needed,
    get_default_init_method,
28
    torch_get_autocast_gpu_dtype,
Przemek Tredak's avatar
Przemek Tredak committed
29
30
31
32
33
34
)
from transformer_engine.pytorch.constants import (
    AttnMaskTypes,
    LayerTypes,
    dist_group_type,
)
35
from transformer_engine.pytorch.distributed import get_distributed_world_size
36
from transformer_engine.pytorch.module.base import TransformerEngineBaseModule
37

# Show each matching DeprecationWarning only once per module that issues it.
# Note: the ``module`` argument is a regex matched against the *start* of the
# warning's module name, so this applies to every module whose fully qualified
# name begins with "transformer".
warnings.filterwarnings("module", category=DeprecationWarning, module="transformer")


__all__ = ["TransformerLayer"]


class DropPath(torch.nn.Module):
    """Drop paths (Stochastic Depth) per sample
    (when applied in main path of residual blocks).

    In training mode, each sample along dim 0 of the input is zeroed with
    probability ``drop_prob``. Surviving samples are rescaled by
    ``1 / (1 - drop_prob)``. In eval mode, or when ``drop_prob == 0``, the input
    is returned unchanged (the same tensor object).

    Parameters
    ----------
    drop_prob : float, default = 0.0
               probability of dropping a sample's path; expected to lie in [0, 1].
    """

    def __init__(self, drop_prob: float = 0.0) -> None:
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        """DropPath FWD"""

        if self.drop_prob == 0.0 or not self.training:
            return hidden_state
        keep_prob = 1.0 - self.drop_prob
        # One draw per sample (dim 0), broadcast over every remaining dim so
        # this works with tensors of any rank, not just 2D ConvNets.
        shape = (hidden_state.shape[0],) + (1,) * (hidden_state.ndim - 1)
        # Sample the keep-mask directly instead of floor(keep_prob + U[0, 1)):
        # in fp16/bf16 that sum can round up to 1.0 and bias the keep rate.
        mask = hidden_state.new_empty(shape).bernoulli_(keep_prob)
        if keep_prob > 0.0:
            # Rescale survivors. Skipped when every path is dropped, because
            # dividing by zero would turn the all-zero output into NaN (inf * 0).
            hidden_state = hidden_state.div(keep_prob)
        return hidden_state * mask


class TransformerLayer(torch.nn.Module):
    r"""
    TransformerLayer is made up of an attention block and a feedforward network (MLP).
    This standard layer is based on the paper "Attention Is All You Need".

    .. note::

        Argument :attr:`attention_mask` in the `forward` call is only used when
        :attr:`self_attn_mask_type` includes `"padding"` or `"arbitrary"`.

    Parameters
    ----------
    hidden_size : int
                 size of each input sample.
    ffn_hidden_size : int
                     intermediate size to which input samples are projected.
    num_attention_heads : int
                         number of attention heads in the transformer layer.
    num_gqa_groups : int, default = `None`
                         number of GQA groups in the transformer layer.
                         Grouped Query Attention is described in
                         `this paper <https://arxiv.org/pdf/2305.13245.pdf>`_.
                         This only affects the keys and values, not the queries.
                         GQA-1 is equivalent to Multi-Query Attention
                         (`MQA <https://arxiv.org/pdf/1911.02150.pdf>`_), while GQA-H
                         is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`.
    layernorm_epsilon : float, default = 1e-5
                       a value added to the denominator of layer normalization
                       for numerical stability.
    hidden_dropout: float, default = 0.1
                   dropout probability for the dropout op after FC2 layer.
    attention_dropout: float, default = 0.1
                      dropout probability for the dropout op during multi-head attention.
    init_method : Callable, default = `None`
                 used for initializing the QKV and FC1 weights in the following way:
                 `init_method(weight)`. When set to `None`, defaults to
                 `torch.nn.init.normal_(mean=0.0, std=0.023)`.
    output_layer_init_method : Callable, default = `None`
                              used for initializing the PROJ and FC2 weights in the following
                              way: `output_layer_init_method(weight)`. When set to `None`,
                              defaults to `torch.nn.init.normal_(mean=0.0, std=0.023)`.
    apply_residual_connection_post_layernorm : bool, default = `False`
                                              if set to `True`, residual connections are taken
                                              from the output of layer norm (by default they
                                              are taken from the input of layer norm).
    layer_number: int, default = `None`
                 layer number of the current `TransformerLayer` when multiple such modules are
                 concatenated to form a transformer block.
    output_layernorm: bool, default = `False`
                     if set to `True`, layer normalization is applied on the output side,
                     after the final dropout-add. The default behavior is to apply layer
                     normalization on the input side, before the QKV transformation.
    parallel_attention_mlp: bool, default = `False`
                           if set to `True`, self-attention and feedforward network are computed
                           based on the same input (in parallel) instead of sequentially.
                           Both blocks have an independent normalization.
                           This architecture is used in `Falcon` models.
                           Requires `layer_type='encoder'` and cannot be combined with
                           `apply_residual_connection_post_layernorm` or `output_layernorm`.
    layer_type: {'encoder', 'decoder'}, default = `encoder`
               if set to `decoder`, an additional cross-attn block is added after self-attn.
               This can be used for structures like `T5` Transformer in conjunction with the
               `encoder` option.
    kv_channels: int, default = `None`
                number of query-key-value channels per attention head. Defaults to
                :attr:`hidden_size` // :attr:`num_attention_heads` (integer division)
                if `None`.
    self_attn_mask_type: {'no_mask', 'padding', 'causal', 'padding_causal', 'causal_bottom_right',
                        'padding_causal_bottom_right', 'arbitrary'},
                        default = `causal`
                        type of attention mask passed into softmax operation for encoder.
                        Can be overridden by :attr:`self_attn_mask_type` in the `forward`
                        method. The forward arg is useful for dynamically changing mask types,
                        e.g. a different mask for training and inference. The init arg is
                        useful for cases involving compilation/tracing, e.g. ONNX export.
    window_size: Optional[Tuple[int, int]], default = `None`
                sliding window size for local attention in encoder, where query at position i
                attends to keys in [i + seqlen_k - seqlen_q - window_size[0],
                i + seqlen_k - seqlen_q + window_size[1]] inclusive. Special cases (-1, -1)
                and (-1, 0) mean no sliding window and causal mask, respectively. Both `causal`
                and `causal_bottom_right` masks map to `window_size = (-1, 0)` and Transformer
                Engine distinguishes them based on `self_attn_mask_type` or
                `enc_dec_attn_mask_type`. Similar to :attr:`self_attn_mask_type`, `window_size`
                can be overridden by :attr:`window_size` in `forward` as well.
    enc_dec_attn_mask_type: {'no_mask', 'causal', 'padding', 'padding_causal', 'arbitrary'},
                           default = `no_mask`
                           type of attention mask passed into softmax operation for decoder.
    enc_dec_window_size: Optional[Tuple[int, int]], default = `None`
                        sliding window size for local attention in decoder.
    zero_centered_gamma : bool, default = `False`
                         if set to `True`, gamma parameter in LayerNorm is initialized to 0 and
                         the LayerNorm formula changes to

                         .. math::
                            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} *
                            (1 + \gamma) + \beta
    normalization : { 'LayerNorm', 'RMSNorm' }, default = 'LayerNorm'
                   type of normalization applied.
    qkv_weight_interleaved : bool, default = `True`
                            if set to `False`, the QKV weight is interpreted as a concatenation of
                            query, key, and value weights along the `0th` dimension. The default
                            interpretation is that the individual `q`, `k`, and `v` weights for each
                            attention head are interleaved. This parameter is set to `False` when
                            using :attr:`fuse_qkv_params=False`.
    rotary_pos_interleaved : bool, default = `False`
                            whether to use interleaved rotary position embeddings.
    bias : bool, default = `True`
          if set to `False`, the transformer layer will not learn any additive biases.
    activation : str, default = 'gelu'
          Type of activation used in MLP block.
          Options are: 'gelu', 'relu', 'reglu', 'geglu', 'swiglu', 'qgelu' and 'srelu'.
    device : Union[torch.device, str], default = "cuda"
          The device on which the parameters of the model will be allocated. It is the user's
          responsibility to ensure all parameters are moved to the GPU before running the
          forward pass.
    attn_input_format: {'sbhd', 'bshd', 'thd'}, default = 'sbhd'
                         This controls whether the dimensions of the intermediate hidden
                         states are 'sequence first' ('sbhd'), 'batch first' ('bshd'), or
                         'token first' ('thd'). `s` stands for the sequence length, `b` batch
                         size, `t` the total number of tokens, `h` the number of heads, `d`
                         head size. Note that these formats are very closely related to the
                         `qkv_format` in the `MultiheadAttention` and `DotProductAttention`
                         modules.
    name: str, default = `None`
        name of the module, currently used for debugging purposes.

    Parallelism parameters
    ----------------------
    set_parallel_mode : bool, default = `False`
                      if set to `True`, QKV and FC1 layers are used as Column Parallel
                      whereas PROJ and FC2 are used as Row Parallel as described
                      `here <https://arxiv.org/pdf/1909.08053.pdf>`_.
    sequence_parallel : bool, default = `False`
                       if set to `True`, uses sequence parallelism. Only takes effect when the
                       tensor parallel world size is greater than 1.
    tp_group : ProcessGroup, default = `None`
              tensor parallel process group.
    tp_size : int, default = 1
             used as TP (tensor parallel) world size when TP groups are not formed during
             initialization. In this case, users must call the
             `set_tensor_parallel_group(tp_group)` method on the initialized module before the
             forward pass to supply the tensor parallel group needed for tensor and sequence
             parallel collectives.

    Optimization parameters
    -----------------------
    fuse_wgrad_accumulation : bool, default = `False`
                             if set to `True`, enables fusing of creation and accumulation of
                             the weight gradient. When enabled, it is assumed that the weights
                             have an additional `main_grad` attribute (used instead of the
                             regular `grad`) which is a pre-allocated buffer of the correct
                             size to accumulate gradients in. Requires `fuse_qkv_params=True`.
    params_dtype : torch.dtype, default = `torch.get_default_dtype()`
                  it controls the type used to allocate the initial parameters. Useful when
                  the model is trained with lower precision and the original FP32 parameters
                  would not fit in GPU memory.
    seq_length: int, default = `None`
               sequence length of input samples. Needed for JIT Warmup, a technique where jit
               fused functions are warmed up before training to ensure same kernels are used for
               forward propagation and activation recompute phase. The warmup only runs when
               both `seq_length` and `micro_batch_size` are given and bias-dropout fusion is
               enabled (environment variable `NVTE_BIAS_DROPOUT_FUSION` not set to 0).
    micro_batch_size: int, default = `None`
                     batch size per training step. Needed for JIT Warmup, a technique where jit
                     fused functions are warmed up before training to ensure same kernels are
                     used for forward propagation and activation recompute phase.
    drop_path_rate: float, default = 0.0
                   when > 0.0, applies stochastic depth per sample in
                   the main path of the residual block.
    fuse_qkv_params: bool, default = `False`
                    if set to `True`, `TransformerLayer` module exposes a single fused
                    parameter for query-key-value. This enables optimizations such as QKV
                    fusion without concatenations/splits and also enables the argument
                    `fuse_wgrad_accumulation`.
    """

    def __init__(
        self,
        hidden_size: int,
        ffn_hidden_size: int,
        num_attention_heads: int,
        num_gqa_groups: Optional[int] = None,
        layernorm_epsilon: float = 1e-5,
        hidden_dropout: float = 0.1,
        attention_dropout: float = 0.1,
        init_method: Optional[Callable] = None,
        output_layer_init_method: Optional[Callable] = None,
        layer_number: Optional[int] = None,
        kv_channels: Optional[int] = None,
        self_attn_mask_type: str = "causal",
        window_size: Optional[Tuple[int, int]] = None,
        enc_dec_attn_mask_type: str = "no_mask",
        enc_dec_window_size: Optional[Tuple[int, int]] = None,
        tp_group: Optional[dist_group_type] = None,
        tp_size: int = 1,
        params_dtype: Optional[torch.dtype] = None,
        get_rng_state_tracker: Optional[Callable] = None,
        fuse_wgrad_accumulation: bool = False,
        seq_length: Optional[int] = None,
        micro_batch_size: Optional[int] = None,
        sequence_parallel: bool = False,
        apply_residual_connection_post_layernorm: bool = False,
        output_layernorm: bool = False,
        parallel_attention_mlp: bool = False,
        layer_type: str = "encoder",
        drop_path_rate: float = 0.0,
        set_parallel_mode: bool = False,
        fuse_qkv_params: bool = False,
        rotary_pos_interleaved: bool = False,
        zero_centered_gamma: bool = False,
        qkv_weight_interleaved: bool = True,
        ub_tp_comm_overlap: bool = False,
        ub_overlap_ag: bool = True,
        ub_overlap_rs: bool = True,
        ub_overlap_rs_dgrad: bool = False,
        ub_bulk_dgrad: bool = True,
        ub_bulk_wgrad: bool = True,
        bias: bool = True,
        activation: str = "gelu",
        normalization: str = "LayerNorm",
        device: Union[torch.device, str] = "cuda",
        attn_input_format: str = "sbhd",
        name: Optional[str] = None,
    ) -> None:
        """Build the attention block(s), the LayerNorm+MLP block and, optionally,
        the output normalization. Arguments are documented on the class."""
        super().__init__()

        # Init-time mask/window configuration. ``forward`` accepts arguments with
        # the same names, which the class docstring describes as overriding these.
        self.self_attn_mask_type = self_attn_mask_type
        self.window_size = window_size
        self.enc_dec_attn_mask_type = enc_dec_attn_mask_type
        self.enc_dec_window_size = enc_dec_window_size
        params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype

        # ``ub_tp_comm_overlap`` is the master switch: each individual userbuffer
        # overlap option only stays enabled when the master switch is on.
        ub_bulk_wgrad = ub_tp_comm_overlap and ub_bulk_wgrad
        ub_bulk_dgrad = ub_tp_comm_overlap and ub_bulk_dgrad
        ub_overlap_ag = ub_tp_comm_overlap and ub_overlap_ag
        ub_overlap_rs = ub_tp_comm_overlap and ub_overlap_rs
        ub_overlap_rs_dgrad = ub_tp_comm_overlap and ub_overlap_rs_dgrad

        # Bias-dropout-add fusion is on unless NVTE_BIAS_DROPOUT_FUSION is set to 0.
        bias_dropout_fusion = bool(int(os.getenv("NVTE_BIAS_DROPOUT_FUSION", "1")))
        self.layer_number = layer_number
        self.output_layernorm = output_layernorm
        self.layer_type = layer_type
        self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm

        if parallel_attention_mlp:
            assert self.layer_type == "encoder", "parallel_attention requires layer_type='encoder'"
            assert not self.apply_residual_connection_post_layernorm, (
                "parallel_attention and apply_residual_connection_post_layernorm "
                "not supported simultaneously."
            )
            assert (
                not self.output_layernorm
            ), "parallel_attention and output_layernorm not supported simultaneously"

        self.parallel_attention_mlp = parallel_attention_mlp

        assert layer_type in LayerTypes, f"layer_type {layer_type} not supported"

        if not fuse_qkv_params:
            assert (
                not fuse_wgrad_accumulation
            ), "Gradient accumulation fusion requires single QKV parameter."
            # Per the class docstring, QKV interleaving is disabled without a
            # single fused QKV parameter.
            qkv_weight_interleaved = False

        # Falsy ``kv_channels`` (None or 0) falls back to an even split of the
        # hidden size across heads.
        self.kv_channels = kv_channels if kv_channels else (hidden_size // num_attention_heads)

        if init_method is None:
            init_method = get_default_init_method()
        if output_layer_init_method is None:
            output_layer_init_method = get_default_init_method()

        # An explicit process group determines the TP world size; otherwise the
        # user-supplied ``tp_size`` is used until a group is set later.
        self.tp_size = tp_size if tp_group is None else get_distributed_world_size(tp_group)
        # Sequence parallelism is only enabled with more than one TP rank.
        self.sequence_parallel = (self.tp_size > 1) and sequence_parallel
        self.seq_length = seq_length

        self.get_rng_state_tracker = get_rng_state_tracker

        self.attn_input_format = attn_input_format

        self.name = name

        # Positional and keyword arguments shared by the self- and cross-attention blocks.
        attention_args = (
            hidden_size,
            num_attention_heads,
            self.kv_channels,
            attention_dropout,
            layernorm_epsilon,
            init_method,
            output_layer_init_method,
        )
        common_attention_kwargs = {
            "layer_number": layer_number,
            "tp_group": tp_group,
            "tp_size": self.tp_size,
            "num_gqa_groups": num_gqa_groups,
            "fuse_wgrad_accumulation": fuse_wgrad_accumulation,
            "get_rng_state_tracker": get_rng_state_tracker,
            "sequence_parallel": self.sequence_parallel,
            "params_dtype": params_dtype,
            "return_layernorm_output": apply_residual_connection_post_layernorm,
            "set_parallel_mode": set_parallel_mode,
            "fuse_qkv_params": fuse_qkv_params,
            "zero_centered_gamma": zero_centered_gamma,
            "qkv_weight_interleaved": qkv_weight_interleaved,
            "rotary_pos_interleaved": rotary_pos_interleaved,
            "ub_bulk_wgrad": ub_bulk_wgrad,
            "ub_bulk_dgrad": ub_bulk_dgrad,
            "ub_overlap_ag": ub_overlap_ag,
            "ub_overlap_rs": ub_overlap_rs,
            "ub_overlap_rs_dgrad": ub_overlap_rs_dgrad,
            "qkv_format": self.attn_input_format,
        }

        # With output_layernorm the normalization moves to the output side, so
        # the attention block's own input norm is disabled.
        self.self_attention = MultiheadAttention(
            *attention_args,
            **common_attention_kwargs,
            input_layernorm=not output_layernorm,
            attention_type="self",
            bias=bias,
            return_bias=not self.parallel_attention_mlp,
            normalization=normalization,
            device=device,
            name=name + ".self_attention" if name is not None else None,
        )

        if layer_type == "decoder":
            # Decoder layers add a cross-attention block over the encoder output.
            self.inter_attention = MultiheadAttention(
                *attention_args,
                **common_attention_kwargs,
                attn_mask_type=enc_dec_attn_mask_type,
                input_layernorm=True,
                attention_type="cross",
                bias=bias,
                return_bias=True,
                normalization=normalization,
                device=device,
                name=name + ".inter_attention" if name is not None else None,
            )

        # LayerNorm -> activation(Linear + Bias) -> Linear
        # parallel_mode not supported for LayerNormMLP,
        # FC1 is CPL and FC2 is RPL
        # In the case of GLU activation, FC1 handles both
        # Linear layers before the activation
        self.layernorm_mlp = LayerNormMLP(
            hidden_size,
            ffn_hidden_size,
            eps=layernorm_epsilon,
            fuse_wgrad_accumulation=fuse_wgrad_accumulation,
            tp_group=tp_group,
            tp_size=self.tp_size,
            get_rng_state_tracker=get_rng_state_tracker,
            init_method=init_method,
            output_layer_init_method=output_layer_init_method,
            bias=bias,
            return_bias=not self.parallel_attention_mlp,
            sequence_parallel=self.sequence_parallel,
            params_dtype=params_dtype,
            return_layernorm_output=apply_residual_connection_post_layernorm,
            seq_length=seq_length,
            micro_batch_size=micro_batch_size,
            set_parallel_mode=set_parallel_mode,
            zero_centered_gamma=zero_centered_gamma,
            ub_bulk_wgrad=ub_bulk_wgrad,
            ub_bulk_dgrad=ub_bulk_dgrad,
            ub_overlap_rs_dgrad=ub_overlap_rs_dgrad,
            ub_overlap_rs=ub_overlap_rs,
            ub_overlap_ag=ub_overlap_ag,
            activation=activation,
            normalization=normalization,
            device=device,
            name=name + ".layernorm_mlp" if name is not None else None,
        )

        self.hidden_dropout = hidden_dropout
        self.bias_dropout_fusion = bias_dropout_fusion
        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 else None

        # Set bias+dropout+add fusion grad_enable execution handler: a no-op
        # context for torch versions in [1.10, 2.2), ``torch.enable_grad`` otherwise.
        use_nvfuser = torch_version() >= (1, 10, 0) and torch_version() < (2, 2, 0)
        self.bias_dropout_add_exec_handler = nullcontext if use_nvfuser else torch.enable_grad

        if self.bias_dropout_fusion:
            set_jit_fusion_options()
            if seq_length and micro_batch_size:
                # With sequence parallelism, warm up for the per-rank sequence
                # length (seq_length // tp_size).
                if self.sequence_parallel:
                    seq_length = seq_length // self.tp_size
                warmup_jit_bias_dropout_add_all_dtypes(hidden_size, seq_length, micro_batch_size)

        norm_module = {
            "LayerNorm": LayerNorm,
            "RMSNorm": RMSNorm,
        }
        if self.output_layernorm:
            # Output-side norm (applied after the final dropout-add per the class
            # docstring). An unsupported ``normalization`` raises KeyError here.
            self.layernorm = norm_module[normalization](
                hidden_size,
                eps=layernorm_epsilon,
                sequence_parallel=self.sequence_parallel,
                params_dtype=params_dtype,
                zero_centered_gamma=zero_centered_gamma,
                device=device,
            )

    def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> None:
469
470
471
472
473
474
475
476
477
        """
        Set the tensor parallel group for the given
        module before executing the forward pass.

        Parameters
        ----------
        tp_group : ProcessGroup, default = `None`
                  tensor parallel process group.
        """
Przemek Tredak's avatar
Przemek Tredak committed
478
479
480
481
482
483
484
        # Deep iterate but skip self to avoid infinite recursion.
        for index, child in enumerate(self.modules()):
            if index == 0:
                continue
            if hasattr(child, "set_tensor_parallel_group"):
                child.set_tensor_parallel_group(tp_group)

485
486
487
488
489
490
491
492
493
    def reset_fp8_meta_tensors(self) -> None:
        """Reset the FP8 meta tensors of all submodules.

        Calls ``reset_fp8_meta_tensors()`` on every descendant module (not this
        layer itself) that defines it.
        """
        # Deep iterate but skip self (the first module yielded by ``modules()``)
        # to avoid infinite recursion.
        for index, child in enumerate(self.modules()):
            if index == 0:
                continue
            if hasattr(child, "reset_fp8_meta_tensors"):
                child.reset_fp8_meta_tensors()

    def set_context_parallel_group(
495
        self,
496
        cp_group: Union[dist_group_type, List[dist_group_type], None],
497
        cp_global_ranks: List[int],
498
        cp_stream: torch.cuda.Stream,
499
        cp_comm_type: str = "p2p",
500
    ) -> None:
501
502
503
504
505
506
        """
        Set the context parallel attributes for the given
        module before executing the forward pass.

        Parameters
        ----------
507
        cp_group : Union[ProcessGroup, List[ProcessGroup]]
508
                  context parallel process group.
509
510
511
                  ProcessGroup is for cp_comm_type of "p2p", "all_gather", and "a2a".
                  List[ProcessGroup] is for cp_comm_type of "a2a+p2p", where cp_group[0]
                  and cp_group[1] are for a2a and p2p communications respectively.
512
513
514
515
        cp_global_ranks : List[int]
                         list of global ranks in the context group.
        cp_stream : torch.cuda.Stream
                   cuda stream for context parallel execution.
516
        cp_comm_type : str, default = `p2p`
517
                      inter-gpu communication type for context parallelism.
518
                      Can be "p2p" or "all_gather" or "a2a", or "a2a+p2p".
519
520
521
522
523
524
                      "p2p": Exchange KV chunks with P2P communications in ring topology.
                             P2P is async and can be overlapped with attention compute.
                      "all_gather": All-gather to get full sequence of KV before attention.
                                    The all-gather is not async, and cannot be overlapped.
                      "a2a": Like DeepSpeed Ulysses, scatter attention heads across the CP
                             group, and gather to get full sequence of QKV.
525
526
527
                      "a2a+p2p": hierarchical CP implementation. First applying a2a to QKV
                      across each CP sub-group (e.g., via NVLink), then exchanging KV with
                      p2p between sub-groups (e.g., via IBLink).
528
        """
529
530
531
532
        # Deep iterate but skip self to avoid infinite recursion.
        for index, child in enumerate(self.modules()):
            if index == 0:
                continue
533
            if hasattr(child, "set_context_parallel_group"):
534
                child.set_context_parallel_group(cp_group, cp_global_ranks, cp_stream, cp_comm_type)
535

Przemek Tredak's avatar
Przemek Tredak committed
536
537
538
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        self_attn_mask_type: Optional[str] = None,
        window_size: Optional[Tuple[int, int]] = None,
        encoder_output: Optional[torch.Tensor] = None,
        enc_dec_attn_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
        enc_dec_attn_mask_type: Optional[str] = None,
        enc_dec_window_size: Optional[Tuple[int, int]] = None,
        is_first_microbatch: Optional[bool] = None,
        checkpoint_core_attention: bool = False,
        inference_params: Optional[InferenceParams] = None,
        rotary_pos_emb: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
        core_attention_bias_type: str = "no_bias",
        core_attention_bias: Optional[torch.Tensor] = None,
        alibi_slopes: Optional[torch.Tensor] = None,
        cu_seqlens_q: Optional[torch.Tensor] = None,
        cu_seqlens_kv: Optional[torch.Tensor] = None,
        cu_seqlens_q_padded: Optional[torch.Tensor] = None,
        cu_seqlens_kv_padded: Optional[torch.Tensor] = None,
        max_seqlen_q: Optional[int] = None,
        max_seqlen_kv: Optional[int] = None,
        fast_zero_fill: bool = True,
        pad_between_seqs: Optional[bool] = None,
    ) -> torch.Tensor:
        """
        Transformer Layer: attention block and a feedforward network (MLP)

        .. note::

            Argument :attr:`attention_mask` is only used when :attr:`self_attn_mask_type`
            includes `"padding"` or `"arbitrary"`.

        Parameters
        ----------
        hidden_states : torch.Tensor
            Input tensor.
        attention_mask : Optional[torch.Tensor], default = `None`
            Boolean tensor used to mask out self-attention softmax input. It should be
            in [batch_size, 1, 1, seqlen_q] for padding masks, and broadcastable
            to [batch_size, num_heads, max_seqlen_q, max_seqlen_kv] for "`arbitrary`"
            mask. It should be `None` for causal masks and "`no_mask`" type.
            A `True` value means the corresponding position is masked out and
            a `False` means that position is allowed to participate in attention.
        self_attn_mask_type: {'no_mask', 'causal', 'padding', 'padding_causal',
            'causal_bottom_right', 'padding_causal_bottom_right', 'arbitrary'},
            default = `None`
            Type of attention mask passed into softmax operation for encoder.
            If `None`, the mask type configured on this layer
            (``self.self_attn_mask_type``) is used.
            By default, causal masks are aligned to the top left corner of
            the softmax matrix. When "`bottom_right`" is specified in the mask type,
            causal masks are aligned to the bottom right corner.
        window_size: Optional[Tuple[int, int]], default = `None`
            Sliding window size for local attention in encoder. If `None`, the
            window size configured on this layer (``self.window_size``) is used.
        encoder_output : Optional[torch.Tensor], default = `None`
            Output of the encoder block to be fed into the decoder block if using
            `layer_type="decoder"`.
        enc_dec_attn_mask : Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]],
            default = `None`. Boolean tensors used to mask out inter-attention softmax input if
            using `layer_type="decoder"`. It should be a tuple of two masks in
            [batch_size, 1, 1, seqlen_q] and [batch_size, 1, 1, seqlen_kv] for padding masks.
            It should be broadcastable to [batch_size, num_heads, max_seqlen_q, max_seqlen_kv]
            for "`arbitrary`" mask. It should be `None` for causal masks and "`no_mask`".
            A `True` value means the corresponding position is masked out and a `False`
            means that position is allowed to participate in attention.
        enc_dec_attn_mask_type: {'no_mask', 'causal', 'padding', 'padding_causal', 'arbitrary'},
            default = `None`
            Type of attention mask passed into softmax operation for decoder.
            If `None`, ``self.enc_dec_attn_mask_type`` is used.
        enc_dec_window_size: Optional[Tuple[int, int]], default = `None`
            Sliding window size for local attention in decoder. If `None`,
            ``self.enc_dec_window_size`` is used.
        is_first_microbatch : {True, False, None}, default = None
            During training using either gradient accumulation or
            pipeline parallelism a minibatch of data is further split
            into microbatches. Between the microbatches of the same minibatch
            the model weights are not updated. Setting this parameter indicates
            whether the current microbatch is the first in a minibatch or not.
            When set, this parameter enables additional optimizations:

            * during FP8 training, it allows caching of the FP8 versions of
              the weights
            * it also allows skipping gradient accumulation during the
              first microbatch (since it is the first gradient being
              produced)
        checkpoint_core_attention: bool, default = `False`
            If true, forward activations for core attention are recomputed
            during the backward pass in order to save memory that would
            otherwise be occupied to store the forward activations until
            backprop.
        inference_params: InferenceParams, default = None
            Inference parameters that are passed to the main model in order
            to efficiently calculate and store the context during inference.
        rotary_pos_emb: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], default = `None`
            Embeddings for query and key tensors for applying rotary position
            embedding. By default no input embedding is applied.
        core_attention_bias_type: str, default = `no_bias`
            Bias type, {`no_bias`, `pre_scale_bias`, `post_scale_bias`, `alibi`}
        core_attention_bias: Optional[torch.Tensor], default = `None`
            Bias tensor for Q * K.T
        alibi_slopes: Optional[torch.Tensor], default = `None`
            ALiBi slopes in FP32 and shape [nheads] or [batch_size, nheads].
            It adds a bias of (-alibi_slope * (i + seqlen_k - seqlen_q - j))
            to the attention score of query i and key j.
        cu_seqlens_q: Optional[torch.Tensor], default = `None`
            Cumulative sum of sequence lengths (without offset) in a batch for `query_layer`,
            with shape [batch_size + 1] and dtype torch.int32.
            Used by encoders, or decoders' self-attention.
        cu_seqlens_kv: Optional[torch.Tensor], default = `None`
            Cumulative sum of sequence lengths (without offset) in a batch for `key_layer`
            and `value_layer`, with shape [batch_size + 1] and dtype torch.int32.
            Used by decoders' cross-attention.
        cu_seqlens_q_padded: Optional[torch.Tensor], default = `None`
            Cumulative sum of sequence lengths (with offset) in a batch for `query_layer`,
            with shape [batch_size + 1] and dtype torch.int32. Set to `cu_seqlens_q` if None.
            Used by encoders, or decoders' self-attention.
        cu_seqlens_kv_padded: Optional[torch.Tensor], default = `None`
            Cumulative sum of sequence lengths (with offset) in a batch for `key_layer`
            and `value_layer`, with shape [batch_size + 1] and dtype torch.int32.
            Set to `cu_seqlens_kv` if None. Used by decoders' cross-attention.
        max_seqlen_q: Optional[int], default = `None`
            Maximum sequence length in `query_layer`.
            Calculated from `cu_seqlens_q_padded` if not provided.
        max_seqlen_kv: Optional[int], default = `None`
            Maximum sequence length in `key_layer` and `value_layer`.
            Calculated from `cu_seqlens_kv_padded` if not provided.
            Only used by decoders' cross-attention.
        fast_zero_fill: bool, default = `True`
            Whether to set output tensors to 0 or not before use.
        pad_between_seqs: Optional[bool], default = `None`
            If None, inferred from qkv_format, cu_seqlens and cu_seqlens_padded.
            If true, there are padding tokens between individual sequences in a packed batch,
            i.e. qkv_format = 'thd'.

        Returns
        -------
        torch.Tensor
            Output of the transformer layer (after the optional output layernorm).
        """

        # Per-call arguments override the configuration stored on the layer.
        if self_attn_mask_type is None:
            self_attn_mask_type = self.self_attn_mask_type
        if window_size is None:
            window_size = self.window_size
        if enc_dec_attn_mask_type is None:
            enc_dec_attn_mask_type = self.enc_dec_attn_mask_type
        if enc_dec_window_size is None:
            enc_dec_window_size = self.enc_dec_window_size

        assert (
            self_attn_mask_type in AttnMaskTypes
        ), f"self_attn_mask_type {self_attn_mask_type} not supported"
        assert (
            enc_dec_attn_mask_type in AttnMaskTypes
        ), f"enc_dec_attn_mask_type {enc_dec_attn_mask_type} not supported"

        hidden_states = hidden_states.contiguous()

        if self.sequence_parallel and self.seq_length is not None:
            assert (
                hidden_states.shape[0] == self.seq_length // self.tp_size
            ), "Sequence dimension must be split across TP group when using sequence parallel."

        # Masks may be a single tensor or a list/tuple of tensors. A single tensor
        # is checked directly: indexing it would iterate over its first (batch)
        # dimension, which is wasteful and raises TypeError for a 0-d tensor.
        if (
            "padding" in self_attn_mask_type or self_attn_mask_type == "arbitrary"
        ) and attention_mask is not None:
            self_masks = (
                (attention_mask,)
                if isinstance(attention_mask, torch.Tensor)
                else attention_mask
            )
            assert all(
                mask.dtype == torch.bool for mask in self_masks
            ), "Attention mask must be a boolean tensor or a list/tuple of two boolean tensors"

        if (
            "padding" in enc_dec_attn_mask_type or enc_dec_attn_mask_type == "arbitrary"
        ) and enc_dec_attn_mask is not None:
            enc_dec_masks = (
                (enc_dec_attn_mask,)
                if isinstance(enc_dec_attn_mask, torch.Tensor)
                else enc_dec_attn_mask
            )
            assert all(
                mask.dtype == torch.bool for mask in enc_dec_masks
            ), "Encoder-decoder attention mask must be boolean tensor(s)"

        if TEDebugState.debug_enabled:
            TransformerEngineBaseModule._validate_name(self)

        # For AMP
        if torch.is_autocast_enabled():
            hidden_states = cast_if_needed(hidden_states, torch_get_autocast_gpu_dtype())

        # Self attention. Queries and keys/values come from the same sequence, so
        # the q-side sequence metadata is deliberately reused for the kv side.
        self_attention_outputs = self.self_attention(
            hidden_states,
            attention_mask=attention_mask,
            attn_mask_type=self_attn_mask_type,
            window_size=window_size,
            inference_params=inference_params,
            is_first_microbatch=is_first_microbatch,
            checkpoint_core_attention=checkpoint_core_attention,
            rotary_pos_emb=rotary_pos_emb,
            core_attention_bias_type=core_attention_bias_type,
            core_attention_bias=core_attention_bias,
            alibi_slopes=alibi_slopes,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_kv=cu_seqlens_q,
            cu_seqlens_q_padded=cu_seqlens_q_padded,
            cu_seqlens_kv_padded=cu_seqlens_q_padded,
            max_seqlen_q=max_seqlen_q,
            max_seqlen_kv=max_seqlen_q,
            fast_zero_fill=fast_zero_fill,
            pad_between_seqs=pad_between_seqs,
        )

        # The number of returned values depends on the layer configuration: a
        # third value (the residual) is unpacked when
        # apply_residual_connection_post_layernorm is set. With
        # parallel_attention_mlp, the attention output is kept as-is and folded in
        # together with the MLP output below.
        if self.apply_residual_connection_post_layernorm and not self.output_layernorm:
            attention_output, attention_bias, residual = self_attention_outputs
            hidden_states = self._bias_dropout_add(
                attention_output, attention_bias, residual, self.drop_path
            )
        elif not self.parallel_attention_mlp:
            attention_output, attention_bias = self_attention_outputs
            hidden_states = self._bias_dropout_add(
                attention_output, attention_bias, hidden_states, self.drop_path
            )

        # Cross attention.
        if self.layer_type == "decoder":
            inter_attention_outputs = self.inter_attention(
                hidden_states,
                attention_mask=enc_dec_attn_mask,
                attn_mask_type=enc_dec_attn_mask_type,
                window_size=enc_dec_window_size,
                encoder_output=encoder_output,
                inference_params=inference_params,
                is_first_microbatch=is_first_microbatch,
                checkpoint_core_attention=checkpoint_core_attention,
                rotary_pos_emb=rotary_pos_emb,
                core_attention_bias_type=core_attention_bias_type,
                core_attention_bias=core_attention_bias,
                alibi_slopes=alibi_slopes,
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_kv=cu_seqlens_kv,
                cu_seqlens_q_padded=cu_seqlens_q_padded,
                cu_seqlens_kv_padded=cu_seqlens_kv_padded,
                max_seqlen_q=max_seqlen_q,
                max_seqlen_kv=max_seqlen_kv,
                fast_zero_fill=fast_zero_fill,
                pad_between_seqs=pad_between_seqs,
            )
            if self.apply_residual_connection_post_layernorm:
                attention_output, attention_bias, residual = inter_attention_outputs
            else:
                attention_output, attention_bias = inter_attention_outputs
                residual = hidden_states

            # NOTE(review): unlike the self-attention and MLP branches, no drop_path
            # is passed here, so stochastic depth never applies to cross-attention.
            # Behavior preserved; confirm this is intentional.
            hidden_states = self._bias_dropout_add(attention_output, attention_bias, residual)

        # MLP.
        mlp_outputs = self.layernorm_mlp(
            hidden_states,
            is_first_microbatch=is_first_microbatch,
        )
        if self.apply_residual_connection_post_layernorm:
            mlp_output, mlp_bias, residual = mlp_outputs
            output = self._bias_dropout_add(mlp_output, mlp_bias, residual, self.drop_path)
        elif self.parallel_attention_mlp:
            # Parallel layout: attention and MLP outputs are both added onto the
            # same input residual (the MLP output takes the "bias" slot).
            output = self._bias_dropout_add(
                self_attention_outputs, mlp_outputs, hidden_states, self.drop_path
            )
        else:
            mlp_output, mlp_bias = mlp_outputs
            output = self._bias_dropout_add(mlp_output, mlp_bias, hidden_states, self.drop_path)

        # For BERT like architectures.
        if self.output_layernorm:
            output = self.layernorm(output)

        # output: [s, b, h]
        return output

    def _bias_dropout_add(self, hidden_state, bias, residual, drop_path=None):
801
        if drop_path is None and bias is not None and bias.numel() != 0:
802
803
804
805
806
807
808
            if self.bias_dropout_fusion:
                if self.training:
                    bias_dropout_add_func = bias_dropout_add_fused_train
                else:
                    bias_dropout_add_func = bias_dropout_add_fused_inference
            else:
                bias_dropout_add_func = get_bias_dropout_add(self.training)
Przemek Tredak's avatar
Przemek Tredak committed
809
810

            with self.bias_dropout_add_exec_handler():
811
                output = bias_dropout_add_func(hidden_state, bias, residual, self.hidden_dropout)
Przemek Tredak's avatar
Przemek Tredak committed
812
        else:
813
            if bias is not None and bias.numel() != 0:
814
                hidden_state = hidden_state + bias
Przemek Tredak's avatar
Przemek Tredak committed
815
            out = torch.nn.functional.dropout(
816
                hidden_state, p=self.hidden_dropout, training=self.training
Przemek Tredak's avatar
Przemek Tredak committed
817
            )
818
819
            if drop_path is not None:
                out = drop_path(out)
ngoyal2707's avatar
ngoyal2707 committed
820
            output = residual + out
Przemek Tredak's avatar
Przemek Tredak committed
821
822

        return output