transformer.py 32.7 KB
Newer Older
1
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
Przemek Tredak's avatar
Przemek Tredak committed
2
3
4
5
6
#
# See LICENSE for license information.

"""Transformer."""
import os
7
import warnings
Przemek Tredak's avatar
Przemek Tredak committed
8
from contextlib import nullcontext
9
from typing import Callable, List, Optional, Tuple, Union
Przemek Tredak's avatar
Przemek Tredak committed
10
11
12

import torch

13
import transformer_engine_extensions as tex
14
from transformer_engine.pytorch.module import LayerNormMLP, LayerNorm, RMSNorm
15
16
17
18
19
from transformer_engine.pytorch.attention import (
    InferenceParams,
    MultiheadAttention,
    check_set_window_size,
)
Przemek Tredak's avatar
Przemek Tredak committed
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
from transformer_engine.pytorch.jit import (
    set_jit_fusion_options,
    warmup_jit_bias_dropout_add_all_dtypes,
    get_bias_dropout_add,
    bias_dropout_add_fused_train,
    bias_dropout_add_fused_inference,
)
from transformer_engine.pytorch.utils import (
    cast_if_needed,
    get_default_init_method,
)
from transformer_engine.pytorch.constants import (
    AttnMaskTypes,
    LayerTypes,
    dist_group_type,
)
36
37
from transformer_engine.pytorch.distributed import get_distributed_world_size

Przemek Tredak's avatar
Przemek Tredak committed
38

39
warnings.filterwarnings("module", category=DeprecationWarning, module="transformer")
cyanguwa's avatar
cyanguwa committed
40
41


42
__all__ = ["TransformerLayer"]
cyanguwa's avatar
cyanguwa committed
43

Przemek Tredak's avatar
Przemek Tredak committed
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70

class DropPath(torch.nn.Module):
    """Drop paths (Stochastic Depth) per sample
    (when applied in main path of residual blocks).

    In training mode, each slice along dim 0 of the input is independently
    zeroed with probability ``drop_prob``; surviving slices are scaled by
    ``1 / (1 - drop_prob)`` so the expected output equals the input. In eval
    mode, or when ``drop_prob == 0``, the input is returned unchanged.

    NOTE(review): dim 0 is treated as the "sample" dimension. For
    sequence-first ('sbhd') activations that is the sequence dimension rather
    than the batch dimension -- confirm this is the intended granularity.

    Parameters
    ----------
    drop_prob : float, default = 0.0
               probability of dropping the path for each sample; expected to
               lie in [0, 1].
    """

    def __init__(self, drop_prob: float = 0.0) -> None:
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        """DropPath FWD: randomly zero whole samples and rescale the rest."""

        if self.drop_prob == 0.0 or not self.training:
            return hidden_state
        keep_prob = 1 - self.drop_prob
        # One random draw per sample, broadcast over the remaining dims so this
        # works for tensors of any rank (not just 2D ConvNets).
        shape = (hidden_state.shape[0],) + (1,) * (hidden_state.ndim - 1)
        random_tensor = keep_prob + torch.rand(
            shape, dtype=hidden_state.dtype, device=hidden_state.device
        )
        random_tensor.floor_()  # binarize: 1 with probability keep_prob, else 0
        if keep_prob > 0.0:
            # Rescale survivors so the expected value of the output is unchanged.
            return hidden_state.div(keep_prob) * random_tensor
        # drop_prob == 1: every sample is dropped. Dividing by keep_prob == 0
        # would yield inf * 0 == NaN, so just apply the all-zero mask.
        return hidden_state * random_tensor


class TransformerLayer(torch.nn.Module):
71
    r"""
Przemek Tredak's avatar
Przemek Tredak committed
72
73
74
    TransformerLayer is made up of an attention block and a feedforward network (MLP).
    This standard layer is based on the paper "Attention Is All You Need".

75
    .. note::
76

77
78
        Argument :attr:`attention_mask` in the `forward` call is only used when
        :attr:`self_attn_mask_type` includes `"padding"` or `"arbitrary"`.
79

Przemek Tredak's avatar
Przemek Tredak committed
80
81
82
83
84
85
86
87
    Parameters
    ----------
    hidden_size : int
                 size of each input sample.
    ffn_hidden_size : int
                     intermediate size to which input samples are projected.
    num_attention_heads : int
                         number of attention heads in the transformer layer.
88
89
90
91
92
93
94
95
    num_gqa_groups : int, default = `None`
                         number of GQA groups in the transformer layer.
                         Grouped Query Attention is described in
                         `this paper <https://arxiv.org/pdf/2305.13245.pdf>`_.
                         This only affects the keys and values, not the queries.
                         GQA-1 is equivalent to Multi-Query Attention
                         (`MQA <https://arxiv.org/pdf/1911.02150.pdf>`_), while GQA-H
                         is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`.
Przemek Tredak's avatar
Przemek Tredak committed
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
    layernorm_epsilon : float, default = 1e-5
                       a value added to the denominator of layer normalization
                       for numerical stability.
    hidden_dropout: float, default = 0.1
                   dropout probability for the dropout op after FC2 layer.
    attention_dropout: float, default = 0.1
                      dropout probability for the dropout op during multi-head attention.
    init_method : Callable, default = `None`
                 used for initializing weights of QKV and FC1 weights in the following way:
                 `init_method(weight)`. When set to `None`, defaults to
                 `torch.nn.init.normal_(mean=0.0, std=0.023)`.
    output_layer_init_method : Callable, default = `None`
                              used for initializing weights of PROJ and FC2 in the following way:
                              `output_layer_init_method(weight)`. When set to `None`, defaults to
                              `torch.nn.init.normal_(mean=0.0, std=0.023)`.
    apply_residual_connection_post_layernorm : bool, default = `False`
                                              if set to `True`, residual connections are taken
                                              from the output of layer norm (default is taken
                                              from input of layer norm)
    layer_number: int, default = `None`
                 layer number of the current `TransformerLayer` when multiple such modules are
                 concatenated to form a transformer block.
    output_layernorm: bool, default = `False`
                     if set to `True`, layer normalization is applied on the output side,
                     after the final dropout-add. default behavior is to apply layer
                     normalization on the input side, before the QKV transformation.
122
123
124
125
126
    parallel_attention_mlp: bool, default = `False`
                           if set to `True`, self-attention and feedforward network are computed
                           based on the same input (in parallel) instead of sequentially.
                           Both blocks have an independent normalization.
                           This architecture is used in `Falcon` models.
Przemek Tredak's avatar
Przemek Tredak committed
127
128
129
130
131
132
133
    layer_type: {'encoder', 'decoder'}, default = `encoder`
               if set to `decoder`, an additional cross-attn block is added after self-attn.
               This can be used for structures like `T5` Transformer in conjunction with the
               `encoder` option.
    kv_channels: int, default = `None`
                number of key-value channels. defaults to
                :attr:`hidden_size` / :attr:`num_attention_heads` if `None`.
134
135
    self_attn_mask_type: {'no_mask', 'padding', 'causal', 'padding_causal', 'arbitrary'},
                        default = `causal`
136
137
138
139
140
                        type of attention mask passed into softmax operation. Overridden by
                        :attr:`self_attn_mask_type` in the `forward` method. The forward
                        arg is useful for dynamically changing mask types, e.g. a different
                        mask for training and inference. The init arg is useful for cases
                        involving compilation/tracing, e.g. ONNX export.
141
142
143
144
145
146
    window_size: Optional[Tuple[int, int]], default = `None`
                sliding window size for local attention, where query at position i attends to keys
                in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
                + window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding
                window and causal mask specifically. Similar to :attr:`self_attn_mask_type`, it can
                be overridden by :attr:`window_size` in `forward` as well.
147
148
149
150
151
152
153
    zero_centered_gamma : bool, default = 'False'
                         if set to 'True', gamma parameter in LayerNorm is initialized to 0 and
                         the LayerNorm formula changes to

                         .. math::
                            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} *
                            (1 + \gamma) + \beta
154
155
    normalization : { 'LayerNorm', 'RMSNorm' }, default = 'LayerNorm'
                   type of normalization applied.
156
157
158
159
160
161
    qkv_weight_interleaved : bool, default = `True`
                            if set to `False`, the QKV weight is interpreted as a concatenation of
                            query, key, and value weights along the `0th` dimension. The default
                            interpretation is that the individual `q`, `k`, and `v` weights for each
                            attention head are interleaved. This parameter is set to `False` when
                            using :attr:`fuse_qkv_params=False`.
ngoyal2707's avatar
ngoyal2707 committed
162
163
    bias : bool, default = `True`
          if set to `False`, the transformer layer will not learn any additive biases.
164
165
    activation : str, default = 'gelu'
          Type of activation used in MLP block.
166
          Options are: 'gelu', 'relu', 'reglu', 'geglu', 'swiglu' and 'qgelu'.
167
168
169
170
    device : Union[torch.device, str], default = "cuda"
          The device on which the parameters of the model will be allocated. It is the user's
          responsibility to ensure all parameters are moved to the GPU before running the
          forward pass.
171
172
173
174
175
176
177
178
    attn_input_format: {'sbhd', 'bshd'}, default = 'sbhd'
                         This controls whether the dimensions of the
                         intermediate hidden states is 'batch first' ('bshd') or
                         'sequence first' ('sbhd'). `s` stands for the sequence
                         length, `b` batch size, `h` the number of heads, `d`
                         head size. Note that these formats are very closely
                         related to the `qkv_format` in the `MultiHeadAttention`
                         and `DotProductAttention` modules.
ngoyal2707's avatar
ngoyal2707 committed
179

Przemek Tredak's avatar
Przemek Tredak committed
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
    Parallelism parameters
    ----------------------
    set_parallel_mode : bool, default = `False`
                      if set to `True`, QKV and FC1 layers are used as Column Parallel
                      whereas PROJ and FC2 is used as Row Parallel as described
                      `here <https://arxiv.org/pdf/1909.08053.pdf>`_.
    sequence_parallel : bool, default = `False`
                       if set to `True`, uses sequence parallelism.
    tp_group : ProcessGroup, default = `None`
              tensor parallel process group.
    tp_size : int, default = 1
             used as TP (tensor parallel) world size when TP groups are not formed during
             initialization. In this case, users must call the
             `set_tensor_parallel_group(tp_group)` method on the initialized module before the
             forward pass to supply the tensor parallel group needed for tensor and sequence
             parallel collectives.

    Optimization parameters
    -----------------------
    fuse_wgrad_accumulation : bool, default = 'False'
                             if set to `True`, enables fusing of creation and accumulation of
201
202
203
204
                             the weight gradient. When enabled, it is assumed that the weights
                             have an additional `main_grad` attribute (used instead of the
                             regular `grad`) which is a pre-allocated buffer of the correct
                             size to accumulate gradients in.
205
    params_dtype : torch.dtype, default = `torch.get_default_dtype()`
Przemek Tredak's avatar
Przemek Tredak committed
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
                  it controls the type used to allocate the initial parameters. Useful when
                  the model is trained with lower precision and the original FP32 parameters
                  would not fit in GPU memory.
    seq_length: int
               sequence length of input samples. Needed for JIT Warmup, a technique where jit
               fused functions are warmed up before training to ensure same kernels are used for
               forward propagation and activation recompute phase.
    micro_batch_size: int
                     batch size per training step. Needed for JIT Warmup, a technique where jit
                     fused functions are warmed up before training to ensure same kernels are
                     used for forward propagation and activation recompute phase.
    drop_path_rate: float, default = 0.0
                   when > 0.0, applies stochastic depth per sample in
                   the main path of the residual block.
    fuse_qkv_params: bool, default = 'False'
                    if set to `True`, `TransformerLayer` module exposes a single fused
                    parameter for query-key-value. This enables optimizations such as QKV
                    fusion without concatenations/splits and also enables the argument
                    `fuse_wgrad_accumulation`.
    """

    def __init__(
        self,
        hidden_size: int,
        ffn_hidden_size: int,
        num_attention_heads: int,
        num_gqa_groups: Optional[int] = None,
        layernorm_epsilon: float = 1e-5,
        hidden_dropout: float = 0.1,
        attention_dropout: float = 0.1,
        init_method: Optional[Callable] = None,
        output_layer_init_method: Optional[Callable] = None,
        layer_number: Optional[int] = None,
        kv_channels: Optional[int] = None,
        self_attn_mask_type: str = "causal",
        window_size: Optional[Tuple[int, int]] = None,
        tp_group: Optional[dist_group_type] = None,
        tp_size: int = 1,
        params_dtype: Optional[torch.dtype] = None,
        get_rng_state_tracker: Optional[Callable] = None,
        fuse_wgrad_accumulation: bool = False,
        seq_length: Optional[int] = None,
        micro_batch_size: Optional[int] = None,
        sequence_parallel: bool = False,
        apply_residual_connection_post_layernorm: bool = False,
        output_layernorm: bool = False,
        parallel_attention_mlp: bool = False,
        layer_type: str = "encoder",
        drop_path_rate: float = 0.0,
        set_parallel_mode: bool = False,
        fuse_qkv_params: bool = False,
        zero_centered_gamma: bool = False,
        qkv_weight_interleaved: bool = True,
        ub_tp_comm_overlap: bool = False,
        ub_bulk_wgrad: bool = True,
        ub_bulk_dgrad: bool = True,
        ub_overlap_ag: bool = True,
        ub_overlap_rs: bool = True,
        bias: bool = True,
        activation: str = "gelu",
        normalization: str = "LayerNorm",
        device: Union[torch.device, str] = "cuda",
        attn_input_format: str = "sbhd",
    ) -> None:
        """Build the self-attention, optional cross-attention and MLP sub-modules.

        All parameters are documented in the class docstring. Raises
        ``AssertionError`` for unsupported option combinations (e.g.
        ``parallel_attention_mlp`` with a decoder layer, or
        ``fuse_wgrad_accumulation`` without ``fuse_qkv_params``).
        """
        super().__init__()

        if ub_tp_comm_overlap:
            assert (
                tex.userbuf_comm_available()
            ), "Userbuffer communication backend not available."

        self.self_attn_mask_type = self_attn_mask_type
        # The default window size is checked against the default mask type;
        # forward() repeats this when a different mask type is passed in.
        self.window_size = check_set_window_size(self_attn_mask_type, window_size)
        params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype

        # The individual userbuffer overlap flags only take effect when the
        # global `ub_tp_comm_overlap` switch is enabled.
        ub_bulk_wgrad = ub_tp_comm_overlap and ub_bulk_wgrad
        ub_bulk_dgrad = ub_tp_comm_overlap and ub_bulk_dgrad
        ub_overlap_ag = ub_tp_comm_overlap and ub_overlap_ag
        ub_overlap_rs = ub_tp_comm_overlap and ub_overlap_rs

        # Bias+dropout+add fusion is enabled unless NVTE_BIAS_DROPOUT_FUSION=0.
        bias_dropout_fusion = bool(int(os.getenv("NVTE_BIAS_DROPOUT_FUSION", "1")))
        self.layer_number = layer_number
        self.output_layernorm = output_layernorm
        self.layer_type = layer_type
        self.apply_residual_connection_post_layernorm = (
            apply_residual_connection_post_layernorm
        )

        if parallel_attention_mlp:
            assert self.layer_type == "encoder", "parallel_attention requires layer_type='encoder'"
            assert not self.apply_residual_connection_post_layernorm, (
                "parallel_attention and apply_residual_connection_post_layernorm "
                "not supported simultaneously."
            )
            assert (
                not self.output_layernorm
            ), "parallel_attention and output_layernorm not supported simultaneously"

        self.parallel_attention_mlp = parallel_attention_mlp

        assert layer_type in LayerTypes, f"layer_type {layer_type} not supported"

        if not fuse_qkv_params:
            assert (
                not fuse_wgrad_accumulation
            ), "Gradient accumulation fusion requires single QKV parameter."
            # Separate Q/K/V parameters cannot use the interleaved layout.
            qkv_weight_interleaved = False

        self.kv_channels = (
            kv_channels if kv_channels else (hidden_size // num_attention_heads)
        )

        if init_method is None:
            init_method = get_default_init_method()
        if output_layer_init_method is None:
            output_layer_init_method = get_default_init_method()

        # When a TP group is supplied its world size overrides `tp_size`.
        self.tp_size = tp_size if tp_group is None else get_distributed_world_size(tp_group)
        # Sequence parallelism is only enabled together with tensor parallelism.
        self.sequence_parallel = (self.tp_size > 1) and sequence_parallel
        self.seq_length = seq_length

        self.get_rng_state_tracker = get_rng_state_tracker

        self.attn_input_format = attn_input_format

        # Arguments shared by the self-attention and cross-attention blocks.
        attention_args = (
            hidden_size,
            num_attention_heads,
            self.kv_channels,
            attention_dropout,
            layernorm_epsilon,
            init_method,
            output_layer_init_method,
        )
        common_attention_kwargs = {
            "layer_number": layer_number,
            "tp_group": tp_group,
            "tp_size": self.tp_size,
            "num_gqa_groups": num_gqa_groups,
            "fuse_wgrad_accumulation": fuse_wgrad_accumulation,
            "get_rng_state_tracker": get_rng_state_tracker,
            "sequence_parallel": self.sequence_parallel,
            "params_dtype": params_dtype,
            "return_layernorm_output": apply_residual_connection_post_layernorm,
            "set_parallel_mode": set_parallel_mode,
            "fuse_qkv_params": fuse_qkv_params,
            "zero_centered_gamma": zero_centered_gamma,
            "qkv_weight_interleaved": qkv_weight_interleaved,
            "ub_bulk_wgrad": ub_bulk_wgrad,
            "ub_bulk_dgrad": ub_bulk_dgrad,
            "ub_overlap_ag": ub_overlap_ag,
            "ub_overlap_rs": ub_overlap_rs,
            "qkv_format": self.attn_input_format,
        }

        # NOTE: sub-modules are created in a fixed order (self-attn, cross-attn,
        # MLP, output norm); changing it would change parameter registration order.
        self.self_attention = MultiheadAttention(
            *attention_args,
            **common_attention_kwargs,
            # With output_layernorm the norm is applied after the block instead.
            input_layernorm=not output_layernorm,
            attention_type="self",
            bias=bias,
            # In parallel attention+MLP mode the bias is not returned separately.
            return_bias=not self.parallel_attention_mlp,
            normalization=normalization,
            device=device,
        )

        if layer_type == "decoder":
            self.inter_attention = MultiheadAttention(
                *attention_args,
                **common_attention_kwargs,
                attn_mask_type="padding",
                input_layernorm=True,
                attention_type="cross",
                bias=bias,
                return_bias=True,
                normalization=normalization,
                device=device,
            )

        # LayerNorm -> activation(Linear + Bias) -> Linear
        # parallel_mode not supported for LayerNormMLP,
        # FC1 is CPL and FC2 is RPL
        # In the case of GLU activation, FC1 handles both
        # Linear layers before the activation
        self.layernorm_mlp = LayerNormMLP(
            hidden_size,
            ffn_hidden_size,
            eps=layernorm_epsilon,
            fuse_wgrad_accumulation=fuse_wgrad_accumulation,
            tp_group=tp_group,
            tp_size=self.tp_size,
            get_rng_state_tracker=get_rng_state_tracker,
            init_method=init_method,
            output_layer_init_method=output_layer_init_method,
            bias=bias,
            return_bias=not self.parallel_attention_mlp,
            sequence_parallel=self.sequence_parallel,
            params_dtype=params_dtype,
            return_layernorm_output=apply_residual_connection_post_layernorm,
            seq_length=seq_length,
            micro_batch_size=micro_batch_size,
            set_parallel_mode=set_parallel_mode,
            zero_centered_gamma=zero_centered_gamma,
            ub_bulk_wgrad=ub_bulk_wgrad,
            ub_bulk_dgrad=ub_bulk_dgrad,
            ub_overlap_rs=ub_overlap_rs,
            ub_overlap_ag=ub_overlap_ag,
            activation=activation,
            normalization=normalization,
            device=device,
        )

        self.hidden_dropout = hidden_dropout
        self.bias_dropout_fusion = bias_dropout_fusion
        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 else None

        # Set bias+dropout+add fusion grad_enable execution handler:
        # PyTorch >= 1.10 uses a no-op context, older versions force enable_grad.
        torch_major, torch_minor = (int(part) for part in torch.__version__.split(".")[:2])
        use_nvfuser = torch_major > 1 or (torch_major == 1 and torch_minor >= 10)
        self.bias_dropout_add_exec_handler = (
            nullcontext if use_nvfuser else torch.enable_grad
        )

        if self.bias_dropout_fusion:
            set_jit_fusion_options()
            # JIT warmup needs both shape hints; otherwise it is skipped.
            if seq_length and micro_batch_size:
                if self.sequence_parallel:
                    # Each TP rank only sees its shard of the sequence.
                    seq_length = seq_length // self.tp_size
                warmup_jit_bias_dropout_add_all_dtypes(
                    hidden_size, seq_length, micro_batch_size
                )

        norm_module = {
            "LayerNorm": LayerNorm,
            "RMSNorm": RMSNorm,
        }
        if self.output_layernorm:
            self.layernorm = norm_module[normalization](
                hidden_size,
                eps=layernorm_epsilon,
                sequence_parallel=self.sequence_parallel,
                params_dtype=params_dtype,
                zero_centered_gamma=zero_centered_gamma,
                device=device,
            )

    def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> None:
        """
        Set the tensor parallel group for the given
        module before executing the forward pass.

        The group is forwarded to every descendant sub-module that exposes a
        ``set_tensor_parallel_group`` method.

        Parameters
        ----------
        tp_group : ProcessGroup, default = `None`
                  tensor parallel process group.
        """
        # Deep iterate but skip self to avoid infinite recursion.
        for child in self.modules():
            if child is self:
                continue
            if hasattr(child, "set_tensor_parallel_group"):
                child.set_tensor_parallel_group(tp_group)

    def set_context_parallel_group(
        self,
        cp_group: Union[dist_group_type, None],
        cp_global_ranks: List[int],
        cp_stream: torch.cuda.Stream,
    ) -> None:
        """
        Set the context parallel attributes for the given
        module before executing the forward pass.

        The attributes are forwarded to every descendant sub-module that
        exposes a ``set_context_parallel_group`` method.

        Parameters
        ----------
        cp_group : ProcessGroup
                  context parallel process group.
        cp_global_ranks : List[int]
                         list of global ranks in the context group.
        cp_stream : torch.cuda.Stream
                   cuda stream for context parallel execution.
        """
        # Deep iterate but skip self to avoid infinite recursion.
        for child in self.modules():
            if child is self:
                continue
            if hasattr(child, "set_context_parallel_group"):
                child.set_context_parallel_group(cp_group, cp_global_ranks, cp_stream)

    def forward(
        self,
        hidden_states: torch.Tensor,
cyanguwa's avatar
cyanguwa committed
501
        attention_mask: Optional[torch.Tensor] = None,
502
        self_attn_mask_type: Optional[str] = None,
503
        window_size: Optional[Tuple[int, int]] = None,
Przemek Tredak's avatar
Przemek Tredak committed
504
        encoder_output: Optional[torch.Tensor] = None,
505
        enc_dec_attn_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
Przemek Tredak's avatar
Przemek Tredak committed
506
        is_first_microbatch: Optional[bool] = None,
cyanguwa's avatar
cyanguwa committed
507
        checkpoint_core_attention: bool = False,
508
        inference_params: Optional[InferenceParams] = None,
509
        rotary_pos_emb: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
510
511
        core_attention_bias_type: str = "no_bias",
        core_attention_bias: Optional[torch.Tensor] = None,
512
        alibi_slopes: Optional[torch.Tensor] = None,
513
        fast_zero_fill: bool = True,
Przemek Tredak's avatar
Przemek Tredak committed
514
515
516
517
    ) -> torch.Tensor:
        """
        Transformer Layer: attention block and a feedforward network (MLP)

518
519
        .. note::

520
521
            Argument :attr:`attention_mask` is only used when :attr:`self_attn_mask_type`
            includes `"padding"` or `"arbitrary"`.
522

Przemek Tredak's avatar
Przemek Tredak committed
523
524
525
526
        Parameters
        ----------
        hidden_states : torch.Tensor
             Input tensor.
527
        attention_mask : Optional[torch.Tensor], default = `None`
528
                        Boolean tensor used to mask out self-attention softmax input.
529
530
531
532
533
534
                        It should be in [batch_size, 1, 1, seqlen_q] for 'padding' mask,
                        and broadcastable to [batch_size, num_heads, max_seqlen_q, max_seqlen_kv]
                        for 'arbitrary'. It should be 'None' for 'causal' and 'no_mask'.
        self_attn_mask_type: {'no_mask', 'causal', 'padding', 'padding_causal', 'arbitrary'},
                            default = `causal`
                            Type of attention mask passed into softmax operation.
535
536
        window_size: Optional[Tuple[int, int]], default = `None`
                    sliding window size for local attention.
cyanguwa's avatar
cyanguwa committed
537
        encoder_output : Optional[torch.Tensor], default = `None`
Przemek Tredak's avatar
Przemek Tredak committed
538
539
             Output of the encoder block to be fed into the decoder block if using
             `layer_type="decoder"`.
540
541
542
543
544
545
        enc_dec_attn_mask : Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]],
             default = `None`. Boolean tensors used to mask out inter-attention softmax input if
             using `layer_type="decoder"`. It should be a tuple of two masks in
             [batch_size, 1, 1, seqlen_q] and [batch_size, 1, 1, seqlen_kv] for 'padding' mask.
             It should be broadcastable to [batch_size, num_heads, max_seqlen_q, max_seqlen_kv]
             for 'arbitrary' mask. It should be 'None' for 'causal' and 'no_mask'.
Przemek Tredak's avatar
Przemek Tredak committed
546
547
548
549
550
551
552
553
554
555
556
557
558
        is_first_microbatch : {True, False, None}, default = None
                             During training using either gradient accumulation or
                             pipeline parallelism a minibatch of data is further split
                             into microbatches. Between the microbatches of the same minibatch
                             the model weights are not updated. Setting this parameter indicates
                             whether the current microbatch is the first in a minibatch or not.
                             When set, this parameter enables additional optimizations:

                             * during FP8 training, it allows caching of the FP8 versions of
                               the weights
                             * it also allows skipping gradient accumulation during the
                               first microbatch (since it is the first gradient being
                               produced)
cyanguwa's avatar
cyanguwa committed
559
        checkpoint_core_attention: bool, default = `False`
Przemek Tredak's avatar
Przemek Tredak committed
560
561
562
563
                                  If true, forward activations for core attention are recomputed
                                  during the backward pass in order to save memory that would
                                  otherwise be occupied to store the forward activations until
                                  backprop.
564
565
566
        rotary_pos_emb: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], default = `None`
                       Embeddings for query and key tensors for applying rotary position
                       embedding. By default no input embedding is applied.
567
        core_attention_bias_type: str, default = `no_bias`
568
                    Bias type, {`no_bias`, `pre_scale_bias`, `post_scale_bias`, `alibi`}
569
570
        core_attention_bias: Optional[torch.Tensor], default = `None`
                    Bias tensor for Q * K.T
571
572
573
574
        alibi_slopes: Optional[torch.Tensor], default = `None`
                     ALiBi slopes in FP32 and shape [nheads] or [batch_size, nheads].
                     It adds a bias of (-alibi_slope * (i + seqlen_k - seqlen_q - j))
                     to the attention score of query i and key j.
575
576
        fast_zero_fill: bool, default = `True`
                    Whether to set output tensors to 0 or not before use.
577
578
579
        inference_params: InferenceParams, default = None
                         Inference parameters that are passed to the main model in order
                         to efficienly calculate and store the context during inference.
Przemek Tredak's avatar
Przemek Tredak committed
580
581
        """

582
583
        if self_attn_mask_type is not None:
            window_size = check_set_window_size(self_attn_mask_type, window_size)
584
        if self_attn_mask_type is None:
585
            self_attn_mask_type = self.self_attn_mask_type
586
587
        if window_size is None:
            window_size = self.window_size
588
589
590
591
592

        assert (
            self_attn_mask_type in AttnMaskTypes
        ), f"self_attn_mask_type {self_attn_mask_type} not supported"

593
594
        hidden_states = hidden_states.contiguous()

595
596
597
598
599
        if self.sequence_parallel and self.seq_length is not None:
            assert (
                hidden_states.shape[0] == self.seq_length // self.tp_size
            ), "Sequence dimension must be split across TP group when using sequence parallel."

600
601
602
        if (("padding" in self_attn_mask_type
            or self_attn_mask_type == "arbitrary")
            and attention_mask is not None):
603
604
605
606
            assert (
                attention_mask.dtype == torch.bool
            ), "Attention mask must be a boolean tensor"

Przemek Tredak's avatar
Przemek Tredak committed
607
608
609
610
611
612
613
614
615
        # For AMP
        if torch.is_autocast_enabled():
            hidden_states = cast_if_needed(
                hidden_states, torch.get_autocast_gpu_dtype()
            )

        # Self attention.
        self_attention_outputs = self.self_attention(
            hidden_states,
616
617
            attention_mask=attention_mask,
            attn_mask_type=self_attn_mask_type,
618
            window_size=window_size,
Przemek Tredak's avatar
Przemek Tredak committed
619
620
621
            inference_params=inference_params,
            is_first_microbatch=is_first_microbatch,
            checkpoint_core_attention=checkpoint_core_attention,
622
            rotary_pos_emb=rotary_pos_emb,
623
624
            core_attention_bias_type=core_attention_bias_type,
            core_attention_bias=core_attention_bias,
625
            alibi_slopes=alibi_slopes,
626
            fast_zero_fill=fast_zero_fill,
Przemek Tredak's avatar
Przemek Tredak committed
627
        )
ngoyal2707's avatar
ngoyal2707 committed
628

Przemek Tredak's avatar
Przemek Tredak committed
629
630
        if self.apply_residual_connection_post_layernorm and not self.output_layernorm:
            attention_output, attention_bias, residual = self_attention_outputs
631
632
633
634
            hidden_states = self._bias_dropout_add(
                attention_output, attention_bias, residual, self.drop_path
            )
        elif not self.parallel_attention_mlp:
Przemek Tredak's avatar
Przemek Tredak committed
635
            attention_output, attention_bias = self_attention_outputs
636
637
            hidden_states = self._bias_dropout_add(
                attention_output, attention_bias, hidden_states, self.drop_path
Przemek Tredak's avatar
Przemek Tredak committed
638
639
640
641
642
            )

        # Cross attention.
        if self.layer_type == "decoder":
            inter_attention_outputs = self.inter_attention(
643
                hidden_states,
644
                attention_mask=enc_dec_attn_mask,
645
                window_size=window_size,
Przemek Tredak's avatar
Przemek Tredak committed
646
647
648
                encoder_output=encoder_output,
                is_first_microbatch=is_first_microbatch,
                checkpoint_core_attention=checkpoint_core_attention,
649
650
                core_attention_bias_type=core_attention_bias_type,
                core_attention_bias=core_attention_bias,
651
                alibi_slopes=alibi_slopes,
652
                fast_zero_fill=fast_zero_fill,
Przemek Tredak's avatar
Przemek Tredak committed
653
654
655
656
657
            )
            if self.apply_residual_connection_post_layernorm:
                attention_output, attention_bias, residual = inter_attention_outputs
            else:
                attention_output, attention_bias = inter_attention_outputs
658
659
660
                residual = hidden_states

            hidden_states = self._bias_dropout_add(attention_output, attention_bias, residual)
Przemek Tredak's avatar
Przemek Tredak committed
661
662
663

        # MLP.
        mlp_outputs = self.layernorm_mlp(
664
            hidden_states, is_first_microbatch=is_first_microbatch
Przemek Tredak's avatar
Przemek Tredak committed
665
666
667
        )
        if self.apply_residual_connection_post_layernorm:
            mlp_output, mlp_bias, residual = mlp_outputs
668
669
670
671
672
            output = self._bias_dropout_add(mlp_output, mlp_bias, residual, self.drop_path)
        elif self.parallel_attention_mlp:
            output = self._bias_dropout_add(
                self_attention_outputs, mlp_outputs, hidden_states, self.drop_path
            )
Przemek Tredak's avatar
Przemek Tredak committed
673
674
        else:
            mlp_output, mlp_bias = mlp_outputs
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
            output = self._bias_dropout_add(mlp_output, mlp_bias, hidden_states, self.drop_path)

        # For BERT like architectures.
        if self.output_layernorm:
            output = self.layernorm(output)

        # output: [s, b, h]
        return output

    def _bias_dropout_add(self, hidden_state, bias, residual, drop_path=None):
        """Return ``residual + drop_path(dropout(hidden_state + bias))``.

        Dropout uses probability ``self.hidden_dropout``; whether it is active is
        selected by ``self.training``.

        Parameters
        ----------
        hidden_state : torch.Tensor
                      Output of the preceding sub-layer.
        bias : Optional[torch.Tensor]
              Tensor added to ``hidden_state`` before dropout. ``None`` or an
              empty tensor (``numel() == 0``) means "no bias". Note that
              ``forward`` passes ``mlp_outputs`` here in its
              ``parallel_attention_mlp`` branch, so this is not necessarily a
              bias vector.
        residual : torch.Tensor
                  Tensor added to the (dropped-out) result.
        drop_path : Optional[Callable], default = None
                   Optional module (e.g. ``DropPath``) applied after dropout.
                   When given, the eager (non-fused) path is always used.

        Returns
        -------
        torch.Tensor
            The combined output.
        """
        # Treat a missing bias exactly like an empty one instead of failing on
        # ``None.numel()``.
        has_bias = bias is not None and bias.numel() != 0

        if drop_path is None and has_bias:
            # Delegate bias + dropout + residual to a single helper; the
            # fused variants are used when ``bias_dropout_fusion`` is enabled.
            if self.bias_dropout_fusion:
                if self.training:
                    bias_dropout_add_func = bias_dropout_add_fused_train
                else:
                    bias_dropout_add_func = bias_dropout_add_fused_inference
            else:
                bias_dropout_add_func = get_bias_dropout_add(self.training)

            with self.bias_dropout_add_exec_handler():
                output = bias_dropout_add_func(
                    hidden_state, bias, residual, self.hidden_dropout
                )
        else:
            # Eager path: needed when there is no bias to feed the helper or
            # when drop_path must be applied between dropout and residual add.
            if has_bias:
                hidden_state = hidden_state + bias
            out = torch.nn.functional.dropout(
                hidden_state, p=self.hidden_dropout, training=self.training
            )
            if drop_path is not None:
                out = drop_path(out)
            output = residual + out

        return output