# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Transformer."""
import os
import math
import warnings
from importlib.metadata import version
from distutils.version import LooseVersion
from contextlib import nullcontext
from typing import Any, Callable, Optional, Tuple, Union

import torch

from flash_attn.flash_attn_interface import flash_attn_unpadded_func

import transformer_engine_extensions as tex
from transformer_engine.pytorch.module import LayerNormLinear, Linear, LayerNormMLP, LayerNorm
from transformer_engine.pytorch.jit import (
    set_jit_fusion_options,
    warmup_jit_bias_dropout_add_all_dtypes,
    get_bias_dropout_add,
    bias_dropout_add_fused_train,
    bias_dropout_add_fused_inference,
)
from transformer_engine.pytorch.utils import (
    divide,
    attention_mask_func,
    split_tensor_along_dim,
    cast_if_needed,
    get_default_init_method,
    get_device_compute_capability,
)
from transformer_engine.pytorch.constants import (
    AttnMaskTypes,
    AttnTypes,
    LayerTypes,
    dist_group_type,
)
from transformer_engine.pytorch.softmax import FusedScaleMaskSoftmax
from transformer_engine.pytorch.distributed import (
    get_distributed_world_size,
    checkpoint,
)
from transformer_engine.pytorch.export import is_in_onnx_export_mode

_flash_attn_version = LooseVersion(version("flash-attn"))
_flash_attn_version_required = LooseVersion("1.0.2")
warnings.filterwarnings("module", category=DeprecationWarning, module="transformer")


__all__ = ["DotProductAttention", "TransformerLayer"]


class DropPath(torch.nn.Module):
    """Drop paths (Stochastic Depth) per sample
    (when applied in main path of residual blocks).
    """

    def __init__(self, drop_prob: float = 0.0) -> None:
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        """DropPath FWD"""

        if self.drop_prob == 0.0 or not self.training:
            return hidden_state
        keep_prob = 1 - self.drop_prob
        # work with diff dim tensors, not just 2D ConvNets
        shape = (hidden_state.shape[0],) + (1,) * (hidden_state.ndim - 1)
        random_tensor = keep_prob + torch.rand(
            shape, dtype=hidden_state.dtype, device=hidden_state.device
        )
        random_tensor.floor_()  # binarize
        output = hidden_state.div(keep_prob) * random_tensor
        return output
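
# A minimal usage sketch (shapes are arbitrary): DropPath keeps each sample along the
# leading dimension with probability 1 - drop_prob and rescales the survivors by
# 1 / keep_prob so the expected activation magnitude is unchanged, e.g.
#
#     drop_path = DropPath(drop_prob=0.1)
#     x = torch.randn(8, 128, 1024)  # leading dim is treated as the sample dim
#     y = drop_path(x)               # in eval mode (or drop_prob=0.0), y is x unchanged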


class UnfusedDotProductAttention(torch.nn.Module):
    """Parallel attention w/o QKV and Proj Gemms
    BMM1 -> softmax + dropout -> BMM2
    """

    def __init__(
        self,
        norm_factor: float,
        attention_dropout: float = 0.0,
        attention_dropout_ctx: Optional[Callable] = nullcontext,
        attn_mask_type: str = "causal",
        layer_number: Optional[int] = None,
    ) -> None:
        super().__init__()

        assert (
            attn_mask_type in AttnMaskTypes
        ), f"attn_mask_type {attn_mask_type} not supported"

        self.norm_factor = norm_factor
        self.attention_dropout_ctx = attention_dropout_ctx
        self.layer_number = layer_number

        self.scale_mask_softmax = FusedScaleMaskSoftmax(
            attn_mask_type,
            attention_mask_func,
        )

        # Dropout. Note that for a single iteration, this layer will generate
        # different outputs on different number of parallel partitions but
        # on average it should not be partition dependent.
        self.attention_dropout = torch.nn.Dropout(attention_dropout)

    def forward(
        self,
        query_layer: torch.Tensor,
        key_layer: torch.Tensor,
        value_layer: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """core attention fprop"""
        batch_size, seqlen = query_layer.shape[1], query_layer.shape[0]
        apply_qk_layer_scaling = self.layer_number is not None and key_layer.dtype == torch.float16

        # [b, np, sq, sk]
        output_size = (
            query_layer.size(1),
            query_layer.size(2),
            query_layer.size(0),
            key_layer.size(0),
        )

        # [sq, b, np, hn] -> [sq, b * np, hn]
        query_layer = query_layer.reshape(
            output_size[2], output_size[0] * output_size[1], -1
        )
        # [sk, b, np, hn] -> [sk, b * np, hn]
        key_layer = key_layer.reshape(output_size[3], output_size[0] * output_size[1], -1)

        # preallocating result tensor: [b * np, sq, sk]
        matmul_result = torch.empty(
            output_size[0] * output_size[1],
            output_size[2],
            output_size[3],
            dtype=query_layer.dtype,
            device=torch.cuda.current_device(),
        )

        scale = self.norm_factor
        if apply_qk_layer_scaling:
            scale *= self.layer_number
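        # QK layer scaling: the extra 1/layer_number factor applied to the QK^T GEMM below
        # is undone inside the fused softmax (see softmax_scale further down), keeping the
        # FP16 attention scores in a safer numerical range without changing the result.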

        # Raw attention scores. [b * np, sq, sk]
        matmul_result = torch.baddbmm(
            matmul_result,
            query_layer.transpose(0, 1),  # [b * np, sq, hn]
            key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
            beta=0.0,
            alpha=(1.0 / scale),
        )

        # change view to [b, np, sq, sk]
        attention_scores = matmul_result.view(*output_size)

        # attention scores and attention mask [b, np, sq, sk]
        softmax_scale = self.layer_number if apply_qk_layer_scaling else None
        attention_probs = self.scale_mask_softmax(attention_scores, attention_mask, softmax_scale)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        with self.attention_dropout_ctx():
            attention_probs = self.attention_dropout(attention_probs)

        # value_layer -> context layer.
        # [sk, b, np, hn] --> [b, np, sq, hn]
        output_size = (
            value_layer.size(1),
            value_layer.size(2),
            query_layer.size(0),
            value_layer.size(3),
        )

        # change view [sk, b * np, hn]
        value_layer = value_layer.reshape(
            value_layer.size(0), output_size[0] * output_size[1], -1
        )

        # change view [b * np, sq, sk]
        attention_probs = attention_probs.view(
            output_size[0] * output_size[1], output_size[2], -1
        )

        # matmul: [b * np, sq, hn]
        context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))

        # change view [b, np, sq, hn]
        context_layer = context_layer.view(*output_size)

        # [b, np, sq, hn] --> [sq, b, np, hn]
        context_layer = context_layer.permute(2, 0, 1, 3).contiguous()

        # [sq, b, np, hn] --> [sq, b, hp]
        context_layer = context_layer.view(seqlen, batch_size, -1)

        return context_layer


class FlashAttention(torch.nn.Module):
    """Dot product attention implementation by using the flash-attn package.
    """

    def __init__(
        self,
        norm_factor: float,
        attention_dropout: float = 0.0,
        attention_dropout_ctx: Optional[Callable] = nullcontext,
        attn_mask_type: str = "causal",
    ) -> None:
        super().__init__()

        assert (
            _flash_attn_version >= _flash_attn_version_required
        ), f"FlashAttention minimum version {_flash_attn_version_required} is required."
        assert (
            attn_mask_type == "causal"
        ), 'FlashAttention currently only supports causal attention mask.'

        self.attn_causal_mask = attn_mask_type == "causal"
        self.norm_factor = norm_factor
        self.attention_dropout_ctx = attention_dropout_ctx
        self.attention_dropout = attention_dropout
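        # Non-deterministic flash-attn kernels are allowed by default; setting the
        # environment variable NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 selects the
        # deterministic (typically slower) backward instead.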
        self.deterministic = not bool(int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1")))

    def forward(
        self,
        query_layer: torch.Tensor,
        key_layer: torch.Tensor,
        value_layer: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """flash-attn fprop"""

        assert (
            (query_layer.dtype in [torch.float16, torch.bfloat16])
            and (key_layer.dtype in [torch.float16, torch.bfloat16])
            and (value_layer.dtype in [torch.float16, torch.bfloat16])
            ), 'FlashAttention currently only supports FP16 and BF16.'
        assert (
            query_layer.is_cuda and key_layer.is_cuda and value_layer.is_cuda
            ), 'FlashAttention currently only supports CUDA tensors.'
        assert (
            attention_mask is None
        ), 'FlashAttention currently does not support external attention mask.'

        query_layer, key_layer, value_layer = [
            x.transpose(0, 1).contiguous()
            for x in (query_layer, key_layer, value_layer)
        ]

        batch_size, seqlen = query_layer.shape[0], query_layer.shape[1]

        # [b, sq, np, hn] -> [(b * sq), np, hn]
        query_layer, key_layer, value_layer = [
            x.view(x.shape[0] * x.shape[1], *x.shape[2:])
            for x in [query_layer, key_layer, value_layer]
        ]

        max_seqlen = seqlen
        cu_seqlens = torch.arange(
            0,
            (batch_size + 1) * seqlen,
            step=seqlen,
            dtype=torch.int32,
            device=query_layer.device)
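        # cu_seqlens holds cumulative sequence lengths marking where each sequence starts
        # in the flattened [(b * sq), np, hn] layout; e.g. batch_size=2, seqlen=4 gives
        # tensor([0, 4, 8], dtype=torch.int32).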

        with self.attention_dropout_ctx():
            output = flash_attn_unpadded_func(
                query_layer, key_layer, value_layer, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen,
                self.attention_dropout if self.training else 0.0,
                softmax_scale=1.0/self.norm_factor, causal=self.attn_causal_mask,
                deterministic=self.deterministic,
            )

        # [(b sq), np, hn] -> [sq, b, (np hn)]
        return output.view(batch_size, seqlen, -1).transpose(0, 1).contiguous()


class DotProductAttention(torch.nn.Module):
    """Allows the model to jointly attend to information from different
    representation subspaces as described in the paper:
    `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_.

    .. note::

        Argument :attr:`attention_mask` will be ignored in the `forward` call when
        :attr:`attn_mask_type` is set to `"causal"`.

    .. warning::

        For the default attention mechanism, this module executes a non-deterministic version of
        `flash-attn <https://github.com/ksivaman/flash-attention>`_ whenever possible in order to
        achieve optimal performance. To observe deterministic behavior, set the environment
        variable :attr:`NVTE_ALLOW_NONDETERMINISTIC_ALGO=0`. In order to disable
        `flash-attn` entirely, set :attr:`NVTE_FLASH_ATTN=0`.

    Parameters
    ----------
    num_attention_heads : int
                         number of attention heads in the transformer layer.
    kv_channels : int
                number of key-value channels.
    attention_dropout: float, default = 0.0
                      dropout probability for the dropout op during multi-head attention.
    attn_mask_type: {'causal', 'padding'}, default = `causal`
                   type of attention mask passed into softmax operation.
    layer_number: int, default = `None`
                 layer number of the current `DotProductAttention` when multiple such modules
                 are concatenated, for instance in consecutive transformer blocks.

    Parallelism parameters
    ----------------------
    sequence_parallel : bool, default = `False`
                       if set to `True`, uses sequence parallelism.
    tp_size : int, default = 1
             tensor parallel world size.
    tp_group : ProcessGroup, default = `None`
              tensor parallel process group.
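
    Examples
    --------
    A minimal usage sketch (sizes are illustrative); inputs follow the
    [sequence_length, batch_size, num_attention_heads, kv_channels] layout expected
    by the forward pass:

    .. code-block:: python

        attn = DotProductAttention(num_attention_heads=16, kv_channels=64)
        q = torch.randn(128, 2, 16, 64, dtype=torch.float16, device="cuda")
        k = torch.randn(128, 2, 16, 64, dtype=torch.float16, device="cuda")
        v = torch.randn(128, 2, 16, 64, dtype=torch.float16, device="cuda")
        out = attn(q, k, v)  # [128, 2, 16 * 64]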
    """

    def __init__(
        self,
        num_attention_heads: int,
        kv_channels: int,
        attention_dropout: float = 0.0,
        attn_mask_type: str = "causal",
        sequence_parallel: bool = False,
        tp_size: int = 1,
        get_rng_state_tracker: Optional[Callable] = None,
        tp_group: Optional[dist_group_type] = None,
        layer_number: Optional[int] = None,
    ) -> None:
        super().__init__()

        tp_size = tp_size if tp_group is None else get_distributed_world_size(tp_group)
        self.tp_group = tp_group
        self.get_rng_state_tracker = get_rng_state_tracker

        projection_size = kv_channels * num_attention_heads
        self.hidden_size_per_partition = divide(projection_size, tp_size)
        self.hidden_size_per_attention_head = divide(
            projection_size, num_attention_heads
        )

        if sequence_parallel or get_rng_state_tracker is None:
            attention_dropout_ctx = nullcontext
        else:
            attention_dropout_ctx = get_rng_state_tracker().fork

        norm_factor = math.sqrt(self.hidden_size_per_attention_head)

        self.device_compute_capability = get_device_compute_capability()
        self.use_flash_attention = (
            int(os.getenv("NVTE_FLASH_ATTN", "1"))
            and attn_mask_type == "causal"
            and self.device_compute_capability >= 8.0
        )

        attn_kwargs = {
            "attention_dropout": attention_dropout,
            "attention_dropout_ctx": attention_dropout_ctx,
            "attn_mask_type": attn_mask_type,
        }

        if self.use_flash_attention:
            self.flash_attention = FlashAttention(norm_factor, **attn_kwargs)
        # Instantiating both types since use of flash-attn
        # might be ruled out due to forward inputs.
        self.unfused_attention = UnfusedDotProductAttention(
            norm_factor, **attn_kwargs, layer_number=layer_number)

    def _checkpointed_attention_forward(
        self,
        attention_func: Callable,
        *forward_args: Tuple[torch.Tensor, ...],
    ) -> torch.Tensor:
        """Forward method with activation checkpointing."""

        def custom_forward(*inputs):
            return attention_func(*inputs)

        hidden_states = checkpoint(
            custom_forward,
            False,
            self.get_rng_state_tracker,
            self.tp_group,
            *forward_args,
        )

        return hidden_states

    def forward(
        self,
        query_layer: torch.Tensor,
        key_layer: torch.Tensor,
        value_layer: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        checkpoint_core_attention: bool = False,
    ) -> torch.Tensor:
        """
        Dot Product Attention Layer.

        .. note::

            Argument :attr:`attention_mask` will be ignored when :attr:`attn_mask_type`
            is set to `"causal"`.

        .. note::

            Input tensors :attr:`query_layer`, :attr:`key_layer`, and :attr:`value_layer`
            must each be of shape (:attr:`sequence_length`, :attr:`batch_size`,
            :attr:`num_attention_heads`, :attr:`kv_channels`). Output of shape
            (:attr:`sequence_length`, :attr:`batch_size`, :attr:`num_attention_heads`
            * :attr:`kv_channels`) is returned.

        Parameters
        ----------
        query_layer : torch.Tensor
                     Query tensor.
        key_layer : torch.Tensor
                   Key tensor.
        value_layer : torch.Tensor
                     Value tensor.
        attention_mask : Optional[torch.Tensor], default = `None`
                        Boolean tensor used to mask out softmax input when not using flash-attn.
        checkpoint_core_attention : bool, default = `False`
                                   If true, forward activations for attention are recomputed
                                   during the backward pass in order to save memory that would
                                   otherwise be occupied to store the forward activations until
                                   backprop.
        """

        use_flash_attention = self.use_flash_attention
        if (query_layer.dtype not in [torch.bfloat16, torch.float16]
            or key_layer.dtype not in [torch.bfloat16, torch.float16]
            or value_layer.dtype not in [torch.bfloat16, torch.float16]
            or (self.device_compute_capability == 8.6 and key_layer.shape[-1] > 64)
        ):
            use_flash_attention = False

        if is_in_onnx_export_mode():
            use_flash_attention = False

        if use_flash_attention:
            if checkpoint_core_attention:
                return self._checkpointed_attention_forward(self.flash_attention,
                                                            query_layer,
                                                            key_layer,
                                                            value_layer)
            return self.flash_attention(query_layer, key_layer, value_layer)

        if checkpoint_core_attention:
            return self._checkpointed_attention_forward(
                self.unfused_attention,
                query_layer,
                key_layer,
                value_layer,
                attention_mask,
            )
        return self.unfused_attention(query_layer, key_layer, value_layer, attention_mask)


class MultiHeadAttention(torch.nn.Module):
    """Parallel attention w/o QKV and Proj Gemms
    BMM1 -> softmax + dropout -> BMM2
    """

    def __init__(
        self,
        hidden_size: int,
        num_attention_heads: int,
        kv_channels: int,
        attention_dropout: float,
        layernorm_epsilon: float,
        init_method: Callable,
        output_layer_init_method: Callable,
        layer_number: Optional[int] = None,
        attn_mask_type: str = "causal",
        tp_group: Optional[dist_group_type] = None,
        tp_size: int = 1,
        fuse_wgrad_accumulation: bool = False,
        get_rng_state_tracker: Optional[Callable] = None,
        sequence_parallel: bool = False,
        params_dtype: torch.dtype = torch.float32,
        return_layernorm_output: bool = False,
        input_layernorm: bool = False,
        attention_type: str = "self",
        set_parallel_mode: bool = False,
        fuse_qkv_params: bool = False,
        zero_centered_gamma: bool = False,
        qkv_weight_interleaved: bool = True,
        ub_bulk_wgrad: bool = False,
        ub_bulk_dgrad: bool = False,
        ub_split_rs: bool = False,
        ub_split_ag: bool = False,
        bias: bool = True,
    ) -> None:
        super().__init__()
        self.layer_number = layer_number
        self.input_layernorm = input_layernorm
        self.attention_type = attention_type
        self.get_rng_state_tracker = get_rng_state_tracker
        self.tp_group = tp_group
        self.return_layernorm_output = return_layernorm_output
        self.params_dtype = params_dtype
        self.init_method = init_method
        self.attn_mask_type = attn_mask_type

        if not fuse_qkv_params:
            qkv_weight_interleaved = False
        self.qkv_weight_interleaved = qkv_weight_interleaved

        assert (
            attention_type in AttnTypes
        ), f"attention_type {attention_type} not supported"

        tp_size = tp_size if tp_group is None else get_distributed_world_size(tp_group)
        self.tp_size = tp_size
        self.sequence_parallel = (tp_size > 1) and sequence_parallel

        self.hidden_size_per_attention_head = kv_channels
        self.num_attention_heads_per_partition = divide(num_attention_heads, tp_size)

        common_gemm_kwargs = {
            "fuse_wgrad_accumulation": fuse_wgrad_accumulation,
            "tp_group": tp_group,
            "tp_size": tp_size,
            "get_rng_state_tracker": get_rng_state_tracker,
            "sequence_parallel": sequence_parallel,
            "params_dtype": params_dtype,
        }

        qkv_parallel_mode = "column" if set_parallel_mode else None

        if self.attention_type == "self":
            if self.input_layernorm:
                self.layernorm_qkv = LayerNormLinear(
                    hidden_size,
                    3 * hidden_size,
                    eps=layernorm_epsilon,
                    init_method=init_method,
                    bias=bias,
                    return_bias=False,
                    parallel_mode=qkv_parallel_mode,
                    return_layernorm_output=return_layernorm_output,
                    parameters_split=("query_", "key_", "value_") if not fuse_qkv_params else None,
                    zero_centered_gamma=zero_centered_gamma,
                    ub_bulk_wgrad=ub_bulk_wgrad,
                    ub_bulk_dgrad=ub_bulk_dgrad,
                    ub_split_ag=ub_split_ag,
                    **common_gemm_kwargs,
                )
            else:
                self.qkv = Linear(
                    hidden_size,
                    3 * hidden_size,
                    init_method=init_method,
                    bias=bias,
                    return_bias=False,
                    parallel_mode=qkv_parallel_mode,
                    parameters_split=("query_", "key_", "value_") if not fuse_qkv_params else None,
576
577
                    **common_gemm_kwargs,
                )
        else:
            if self.input_layernorm:
                self.layernorm_query = LayerNormLinear(
                    hidden_size,
                    hidden_size,
                    eps=layernorm_epsilon,
                    init_method=init_method,
                    bias=bias,
                    return_bias=False,
                    parallel_mode=qkv_parallel_mode,
                    return_layernorm_output=return_layernorm_output,
                    zero_centered_gamma=zero_centered_gamma,
                    ub_bulk_wgrad=ub_bulk_wgrad,
                    ub_bulk_dgrad=ub_bulk_dgrad,
                    ub_split_ag=ub_split_ag,
                    **common_gemm_kwargs,
                )
            else:
                self.query_layer = Linear(
                    hidden_size,
                    hidden_size,
                    init_method=init_method,
                    bias=bias,
                    return_bias=False,
                    parallel_mode=qkv_parallel_mode,
                    **common_gemm_kwargs,
                )
            self.key_value = Linear(
                hidden_size,
                2 * hidden_size,
                init_method=init_method,
                bias=bias,
                return_bias=False,
                parallel_mode=qkv_parallel_mode,
                parameters_split=("key_", "value_") if not fuse_qkv_params else None,
                **common_gemm_kwargs,
            )

        # Attention.
        self.core_attention = DotProductAttention(
            num_attention_heads,
            kv_channels,
            attention_dropout,
            tp_size=tp_size,
            get_rng_state_tracker=get_rng_state_tracker,
            attn_mask_type=attn_mask_type,
            sequence_parallel=sequence_parallel,
            tp_group=tp_group,
            layer_number=layer_number,
        )

        # Linear
        self.proj = Linear(
            hidden_size,
            hidden_size,
            init_method=output_layer_init_method,
            bias=bias,
            return_bias=True,
            parallel_mode="row" if set_parallel_mode else None,
            ub_split_rs=ub_split_rs,
            ub_split_ag=ub_split_ag,
            **common_gemm_kwargs,
        )


    def _allocate_memory(
        self, inference_max_sequence_len: int, batch_size: int
    ) -> torch.Tensor:
        return torch.empty(
            inference_max_sequence_len,
            batch_size,
            self.num_attention_heads_per_partition,
            self.hidden_size_per_attention_head,
            dtype=self.params_dtype,
            device=torch.cuda.current_device(),
        )

    def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> None:
        """Set TP group"""
        self.tp_group = tp_group

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_output: Optional[torch.Tensor] = None,
        is_first_microbatch: Optional[bool] = None,
        checkpoint_core_attention: bool = False,
        inference_params: Optional[Any] = None,
    ) -> Tuple[Union[torch.Tensor, None], ...]:
        """MultiHeadAttention FWD"""
        # hidden_states: [sq, b, h]

        if self.attn_mask_type != "causal" and attention_mask is not None:
            assert (
                attention_mask.dtype == torch.bool
            ), "Attention mask must be a boolean tensor"

        # =================================================
        # Pre-allocate memory for key-values for inference.
        # =================================================

        if inference_params and self.layer_number is not None:
            if self.layer_number not in inference_params.key_value_memory_dict:
                inf_max_seq_len = inference_params.max_sequence_len
                inf_max_batch_size = inference_params.max_batch_size
                inference_key_memory = self._allocate_memory(
                    inf_max_seq_len, inf_max_batch_size
                )
                inference_value_memory = self._allocate_memory(
                    inf_max_seq_len, inf_max_batch_size
                )
                inference_params.key_value_memory_dict[self.layer_number] = (
                    inference_key_memory,
                    inference_value_memory,
                )
            else:
                (
                    inference_key_memory,
                    inference_value_memory,
                ) = inference_params.key_value_memory_dict[self.layer_number]

        # =====================
        # Query, Key, and Value
        # =====================

        if self.attention_type == "self":
            # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
            if self.input_layernorm:
                layernorm_qkv_outputs = self.layernorm_qkv(
                    hidden_states,
                    is_first_microbatch=is_first_microbatch,
                )
                if self.return_layernorm_output:
                    mixed_x_layer, layernorm_output = layernorm_qkv_outputs
                else:
                    mixed_x_layer = layernorm_qkv_outputs
            else:
                mixed_x_layer = self.qkv(
                    hidden_states,
                    is_first_microbatch=is_first_microbatch,
                )

            if self.qkv_weight_interleaved:
                # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
                new_tensor_shape = mixed_x_layer.size()[:-1] + (
                    self.num_attention_heads_per_partition,
                    3 * self.hidden_size_per_attention_head,
                )
                # split along last dimension
                split_dim = -1
            else:
                # [sq, b, (np * 3 * hn)] --> [sq, b, 3 * np, hn]
                new_tensor_shape = mixed_x_layer.size()[:-1] + (
                    3 * self.num_attention_heads_per_partition,
                    self.hidden_size_per_attention_head,
                )
                # split along second last dimension
                split_dim = -2
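            # Layout sketch (illustrative, with np=2 heads): the interleaved fused dim packs
            # [q0 k0 v0 | q1 k1 v1] and is split on the last dim, while the non-interleaved
            # layout packs [q0 q1 | k0 k1 | v0 v1] and is split on the second-to-last dim;
            # either way the split below yields three [sq, b, np, hn] tensors.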

            mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)

            # mixed_x_layer --> 3 [sq, b, np, hn]
            query_layer, key_layer, value_layer = split_tensor_along_dim(
                mixed_x_layer, split_dim, 3
            )
        else:
            # Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)]
            mixed_kv_layer = self.key_value(
                encoder_output,
                is_first_microbatch=is_first_microbatch,
            )

            if self.qkv_weight_interleaved:
                # [sq, b, (np * 2 * hn)] --> [sq, b, np, 2 * hn]
                new_tensor_shape = mixed_kv_layer.size()[:-1] + (
                    self.num_attention_heads_per_partition,
                    2 * self.hidden_size_per_attention_head,
                )
                # split along last dimension
                split_dim = -1
            else:
                # [sq, b, (np * 2 * hn)] --> [sq, b, 2 * np, hn]
                new_tensor_shape = mixed_kv_layer.size()[:-1] + (
                    2 * self.num_attention_heads_per_partition,
                    self.hidden_size_per_attention_head,
                )
                # split along second last dimension
                split_dim = -2

            mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape)

            # mixed_kv_layer --> 2 [sk, b, np, hn]
            key_layer, value_layer = split_tensor_along_dim(mixed_kv_layer, split_dim, 2)

            # Attention head [sq, b, h] --> [sq, b, hp]
            if self.input_layernorm:
                layernorm_query_outputs = self.layernorm_query(
                    hidden_states,
                    is_first_microbatch=is_first_microbatch,
                )
                if self.return_layernorm_output:
                    query_layer, layernorm_output = layernorm_query_outputs
                else:
                    query_layer = layernorm_query_outputs
            else:
                query_layer = self.query_layer(
                    hidden_states,
                    is_first_microbatch=is_first_microbatch,
                )

            # [sq, b, hp] --> [sq, b, np, hn]
            new_tensor_shape = query_layer.size()[:-1] + (
                self.num_attention_heads_per_partition,
                self.hidden_size_per_attention_head,
            )
            query_layer = query_layer.view(*new_tensor_shape)

        # ==================================
        # Adjust key and value for inference
        # ==================================

        if inference_params and self.layer_number is not None:
            batch_start = inference_params.batch_size_offset
            batch_end = batch_start + key_layer.size(1)
            assert batch_end <= inference_key_memory.size(1)
            sequence_start = inference_params.sequence_len_offset
            sequence_end = sequence_start + key_layer.size(0)
            assert sequence_end <= inference_key_memory.size(0)
            # Copy key and values.
            inference_key_memory[
                sequence_start:sequence_end, batch_start:batch_end, ...
            ] = key_layer
            inference_value_memory[
                sequence_start:sequence_end, batch_start:batch_end, ...
            ] = value_layer
            key_layer = inference_key_memory[:sequence_end, batch_start:batch_end, ...]
            value_layer = inference_value_memory[
                :sequence_end, batch_start:batch_end, ...
            ]
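            # After the copy, key_layer/value_layer span the whole cached prefix
            # [0, sequence_end), so each incremental decoding step only projects the new
            # tokens while attention still runs over all previously cached positions.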

        # ==================================
        # core attention computation
        # ==================================

        context_layer = self.core_attention(
            query_layer,
            key_layer,
            value_layer,
            attention_mask,
            checkpoint_core_attention=checkpoint_core_attention,
        )

        # =================
        # Output. [sq, b, h]
        # =================

        attention_output, attention_bias = self.proj(
            context_layer, is_first_microbatch=is_first_microbatch
        )

        if self.input_layernorm and self.return_layernorm_output:
            return attention_output, attention_bias, layernorm_output
        return attention_output, attention_bias


class TransformerLayer(torch.nn.Module):
    r"""
    TransformerLayer is made up of an attention block and a feedforward network (MLP).
    This standard layer is based on the paper "Attention Is All You Need".

    .. warning::

        Arguments :attr:`attention_softmax_in_fp32` and :attr:`apply_query_key_layer_scaling`
        are deprecated and will be fully removed in future releases.

    .. note::

        Argument :attr:`attention_mask` will be ignored in the `forward` call when
        :attr:`self_attn_mask_type` is set to `"causal"`.

    Parameters
    ----------
    hidden_size : int
                 size of each input sample.
    ffn_hidden_size : int
                     intermediate size to which input samples are projected.
    num_attention_heads : int
                         number of attention heads in the transformer layer.
    layernorm_epsilon : float, default = 1e-5
                       a value added to the denominator of layer normalization
                       for numerical stability.
    hidden_dropout: float, default = 0.1
                   dropout probability for the dropout op after FC2 layer.
    attention_dropout: float, default = 0.1
                      dropout probability for the dropout op during multi-head attention.
    init_method : Callable, default = `None`
                 used for initializing QKV and FC1 weights in the following way:
                 `init_method(weight)`. When set to `None`, defaults to
                 `torch.nn.init.normal_(mean=0.0, std=0.023)`.
    output_layer_init_method : Callable, default = `None`
                              used for initializing weights of PROJ and FC2 in the following way:
                              `output_layer_init_method(weight)`. When set to `None`, defaults to
                              `torch.nn.init.normal_(mean=0.0, std=0.023)`.
    apply_residual_connection_post_layernorm : bool, default = `False`
                                              if set to `True`, residual connections are taken
                                              from the output of layer norm (default is taken
                                              from input of layer norm)
    layer_number: int, default = `None`
                 layer number of the current `TransformerLayer` when multiple such modules are
                 concatenated to form a transformer block.
    output_layernorm: bool, default = `False`
                     if set to `True`, layer normalization is applied on the output side,
                     after the final dropout-add. default behavior is to apply layer
                     normalization on the input side, before the QKV transformation.
    layer_type: {'encoder', 'decoder'}, default = `encoder`
               if set to `decoder`, an additional cross-attn block is added after self-attn.
               This can be used for structures like `T5` Transformer in conjunction with the
               `encoder` option.
    kv_channels: int, default = `None`
                number of key-value channels. defaults to
                :attr:`hidden_size` / :attr:`num_attention_heads` if `None`.
    self_attn_mask_type: {'causal', 'padding'}, default = `causal`
                        type of attention mask passed into softmax operation.
    zero_centered_gamma : bool, default = 'False'
                         if set to 'True', gamma parameter in LayerNorm is initialized to 0 and
                         the LayerNorm formula changes to

                         .. math::
                            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} *
                            (1 + \gamma) + \beta
    qkv_weight_interleaved : bool, default = `True`
                            if set to `False`, the QKV weight is interpreted as a concatenation of
                            query, key, and value weights along the `0th` dimension. The default
                            interpretation is that the individual `q`, `k`, and `v` weights for each
                            attention head are interleaved. This parameter is set to `False` when
                            using :attr:`fuse_qkv_params=False`.
    bias : bool, default = `True`
          if set to `False`, the transformer layer will not learn any additive biases.

    Parallelism parameters
    ----------------------
    set_parallel_mode : bool, default = `False`
                      if set to `True`, QKV and FC1 layers are used as Column Parallel
                      whereas PROJ and FC2 are used as Row Parallel as described
                      `here <https://arxiv.org/pdf/1909.08053.pdf>`_.
    sequence_parallel : bool, default = `False`
                       if set to `True`, uses sequence parallelism.
    tp_group : ProcessGroup, default = `None`
              tensor parallel process group.
    tp_size : int, default = 1
             used as TP (tensor parallel) world size when TP groups are not formed during
             initialization. In this case, users must call the
             `set_tensor_parallel_group(tp_group)` method on the initialized module before the
             forward pass to supply the tensor parallel group needed for tensor and sequence
             parallel collectives.

    Optimization parameters
    -----------------------
    fuse_wgrad_accumulation : bool, default = 'False'
                             if set to `True`, enables fusing of creation and accumulation of
                             the weight gradient.
    params_dtype : torch.dtype, default = `torch.float32`
                  it controls the type used to allocate the initial parameters. Useful when
                  the model is trained with lower precision and the original FP32 parameters
                  would not fit in GPU memory.
    seq_length: int
               sequence length of input samples. Needed for JIT Warmup, a technique where jit
               fused functions are warmed up before training to ensure same kernels are used for
               forward propagation and activation recompute phase.
    micro_batch_size: int
                     batch size per training step. Needed for JIT Warmup, a technique where jit
                     fused functions are warmed up before training to ensure same kernels are
                      used for forward propagation and activation recompute phase.
    drop_path_rate: float, default = 0.0
                   when > 0.0, applies stochastic depth per sample in
                   the main path of the residual block.
    fuse_qkv_params: bool, default = 'False'
                    if set to `True`, `TransformerLayer` module exposes a single fused
                    parameter for query-key-value. This enables optimizations such as QKV
                    fusion without concatenations/splits and also enables the argument
                    `fuse_wgrad_accumulation`.
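
    Examples
    --------
    A minimal single-GPU usage sketch (sizes are illustrative); the input follows the
    [sequence_length, batch_size, hidden_size] layout:

    .. code-block:: python

        layer = TransformerLayer(hidden_size=1024, ffn_hidden_size=4096,
                                 num_attention_heads=16).cuda()
        x = torch.randn(128, 2, 1024, device="cuda")
        y = layer(x)  # [128, 2, 1024]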
    """

    def __init__(
        self,
        hidden_size: int,
        ffn_hidden_size: int,
        num_attention_heads: int,
        layernorm_epsilon: float = 1e-5,
        hidden_dropout: float = 0.1,
        attention_dropout: float = 0.1,
        init_method: Optional[Callable] = None,
        output_layer_init_method: Optional[Callable] = None,
        layer_number: Optional[int] = None,
        kv_channels: Optional[int] = None,
        self_attn_mask_type: str = "causal",
        tp_group: Optional[dist_group_type] = None,
        tp_size: int = 1,
        params_dtype: torch.dtype = torch.float32,
        get_rng_state_tracker: Optional[Callable] = None,
        fuse_wgrad_accumulation: bool = False,
        apply_query_key_layer_scaling: bool = False, # pylint: disable=unused-argument
        attention_softmax_in_fp32: bool = True, # pylint: disable=unused-argument
        seq_length: Optional[int] = None,
        micro_batch_size: Optional[int] = None,
        sequence_parallel: bool = False,
        apply_residual_connection_post_layernorm: bool = False,
        output_layernorm: bool = False,
        layer_type: str = "encoder",
        drop_path_rate: float = 0.0,
        set_parallel_mode: bool = False,
        fuse_qkv_params: bool = False,
        zero_centered_gamma: bool = False,
        qkv_weight_interleaved: bool = True,
        ub_tp_comm_overlap: bool = False,
        bias: bool = True,
    ) -> None:
        super().__init__()

        warnings.warn(
            "Arguments `attention_softmax_in_fp32` and `apply_query_key_layer_scaling`"
            "are deprecated and will be fully removed in future releases.",
            category=DeprecationWarning,
        )

        if ub_tp_comm_overlap:
            assert (
                tex.userbuf_comm_available()
            ), "Userbuffer communication backend not available."

        ub_tp_comm_overlap = ub_tp_comm_overlap and bool(int(os.getenv("NVTE_UB_OVERLAP", "1")))
        ub_bulk_wgrad = ub_tp_comm_overlap and bool(int(os.getenv("NVTE_UB_BULK_WGRAD", "1")))
        ub_bulk_dgrad = ub_tp_comm_overlap and bool(int(os.getenv("NVTE_UB_BULK_DGRAD", "1")))
        ub_split_ag = ub_tp_comm_overlap and bool(int(os.getenv("NVTE_UB_SPLIT_AG", "1")))
        ub_split_rs = ub_tp_comm_overlap and bool(int(os.getenv("NVTE_UB_SPLIT_RS", "1")))
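        # Each overlap path defaults to enabled once ub_tp_comm_overlap is requested and can
        # be disabled individually, e.g. NVTE_UB_SPLIT_RS=0 turns off only the GEMM +
        # reduce-scatter split overlap.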
        bias_dropout_fusion = bool(int(os.getenv("NVTE_BIAS_DROPOUT_FUSION", "1")))
        self.layer_number = layer_number
        self.output_layernorm = output_layernorm
        self.layer_type = layer_type
        self.apply_residual_connection_post_layernorm = (
            apply_residual_connection_post_layernorm
        )
        self.self_attn_mask_type = self_attn_mask_type
        assert (
            self_attn_mask_type in AttnMaskTypes
        ), f"self_attn_mask_type {self_attn_mask_type} not supported"
        assert layer_type in LayerTypes, f"layer_type {layer_type} not supported"

        if not fuse_qkv_params:
            assert (
                not fuse_wgrad_accumulation
            ), "Gradient accumulation fusion requires single QKV parameter."

        if not fuse_qkv_params:
            qkv_weight_interleaved = False

        self.kv_channels = (
            kv_channels if kv_channels else (hidden_size // num_attention_heads)
        )

        if init_method is None:
            init_method = get_default_init_method()
        if output_layer_init_method is None:
            output_layer_init_method = get_default_init_method()

        tp_size = tp_size if tp_group is None else get_distributed_world_size(tp_group)
        self.sequence_parallel = (tp_size > 1) and sequence_parallel

        self.get_rng_state_tracker = get_rng_state_tracker

        attention_args = (
            hidden_size,
            num_attention_heads,
            self.kv_channels,
            attention_dropout,
            layernorm_epsilon,
            init_method,
            output_layer_init_method,
        )
        common_attention_kwargs = {
            "layer_number": layer_number,
            "tp_group": tp_group,
            "tp_size": tp_size,
            "fuse_wgrad_accumulation": fuse_wgrad_accumulation,
            "get_rng_state_tracker": get_rng_state_tracker,
            "sequence_parallel": self.sequence_parallel,
            "params_dtype": params_dtype,
            "return_layernorm_output": apply_residual_connection_post_layernorm,
            "set_parallel_mode": set_parallel_mode,
            "fuse_qkv_params": fuse_qkv_params,
            "zero_centered_gamma": zero_centered_gamma,
            "qkv_weight_interleaved" : qkv_weight_interleaved,
            "ub_bulk_wgrad" : ub_bulk_wgrad,
            "ub_bulk_dgrad" : ub_bulk_dgrad,
            "ub_split_ag" : ub_split_ag,
            "ub_split_rs" : ub_split_rs,
        }

        self.self_attention = MultiHeadAttention(
            *attention_args,
            **common_attention_kwargs,
            attn_mask_type=self_attn_mask_type,
            input_layernorm=not output_layernorm,
            attention_type="self",
            bias=bias,
        )

        if layer_type == "decoder":
            self.inter_attention = MultiHeadAttention(
                *attention_args,
                **common_attention_kwargs,
                attn_mask_type="padding",
                input_layernorm=True,
                attention_type="cross",
                bias=bias,
            )

        # LayerNorm -> gelu(Linear + Bias) -> Linear
        # parallel_mode not supported for LayerNormMLP,
        # FC1 is CPL and FC2 is RPL
        self.layernorm_mlp = LayerNormMLP(
            hidden_size,
            ffn_hidden_size,
            eps=layernorm_epsilon,
            fuse_wgrad_accumulation=fuse_wgrad_accumulation,
            tp_group=tp_group,
            tp_size=tp_size,
            get_rng_state_tracker=get_rng_state_tracker,
            init_method=init_method,
            output_layer_init_method=output_layer_init_method,
            bias=bias,
            return_bias=True,
            sequence_parallel=self.sequence_parallel,
            params_dtype=params_dtype,
            return_layernorm_output=apply_residual_connection_post_layernorm,
            seq_length=seq_length,
            micro_batch_size=micro_batch_size,
            set_parallel_mode=set_parallel_mode,
            zero_centered_gamma=zero_centered_gamma,
            ub_bulk_wgrad=ub_bulk_wgrad,
            ub_bulk_dgrad=ub_bulk_dgrad,
            ub_split_rs=ub_split_rs,
            ub_split_ag=ub_split_ag,
        )

        self.hidden_dropout = hidden_dropout
        self.bias_dropout_fusion = bias_dropout_fusion
        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 else None

        # Set bias+dropout+add fusion grad_enable execution handler.
        TORCH_MAJOR = int(torch.__version__.split(".")[0])
        TORCH_MINOR = int(torch.__version__.split(".")[1])
        use_nvfuser = TORCH_MAJOR > 1 or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10)
        self.bias_dropout_add_exec_handler = (
            nullcontext if use_nvfuser else torch.enable_grad
        )

        if self.bias_dropout_fusion:
            set_jit_fusion_options()
            if seq_length and micro_batch_size:
                if self.sequence_parallel:
                    seq_length = seq_length // tp_size
                warmup_jit_bias_dropout_add_all_dtypes(
                    hidden_size, seq_length, micro_batch_size
                )

        if self.output_layernorm:
            self.layernorm = LayerNorm(
                hidden_size,
                eps=layernorm_epsilon,
                sequence_parallel=self.sequence_parallel,
                params_dtype=params_dtype,
                zero_centered_gamma=zero_centered_gamma
            )

    def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> None:
        """Set TP group"""
        # Deep iterate but skip self to avoid infinite recursion.
        for index, child in enumerate(self.modules()):
            if index == 0:
                continue
            if hasattr(child, "set_tensor_parallel_group"):
                child.set_tensor_parallel_group(tp_group)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_output: Optional[torch.Tensor] = None,
        enc_dec_attn_mask: Optional[torch.Tensor] = None,
        is_first_microbatch: Optional[bool] = None,
        checkpoint_core_attention: bool = False,
        inference_params: Optional[Any] = None,
    ) -> torch.Tensor:
        """
        Transformer Layer: attention block and a feedforward network (MLP)

        .. note::

            Argument :attr:`attention_mask` will be ignored when :attr:`self_attn_mask_type`
            is set to `"causal"`.

        Parameters
        ----------
        hidden_states : torch.Tensor
             Input tensor.
        attention_mask : Optional[torch.Tensor], default = `None`
             Boolean tensor used to mask out self-attention softmax input.
        encoder_output : Optional[torch.Tensor], default = `None`
             Output of the encoder block to be fed into the decoder block if using
             `layer_type="decoder"`.
        enc_dec_attn_mask : Optional[torch.Tensor], default = `None`
             Boolean tensor used to mask out inter-attention softmax input if using
             `layer_type="decoder"`.
        is_first_microbatch : {True, False, None}, default = None
                             When training with either gradient accumulation or
                             pipeline parallelism, a minibatch of data is further
                             split into microbatches. The model weights are not
                             updated between microbatches of the same minibatch.
                             This parameter indicates whether the current microbatch
                             is the first in a minibatch (see the usage sketch at
                             the end of this docstring).
                             When set, this parameter enables additional optimizations:

                             * during FP8 training, it allows caching of the FP8 versions of
                               the weights
                             * it also allows skipping gradient accumulation during the
                               first microbatch (since it is the first gradient being
                               produced)
        checkpoint_core_attention: bool, default = `False`
                                  If `True`, the forward activations of the core
                                  attention block are not stored; they are recomputed
                                  during the backward pass, trading extra compute for
                                  reduced memory use.
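
        .. note::

            A minimal usage sketch (the tensor shapes and the number of
            microbatches below are illustrative assumptions, not requirements
            of this API)::

                import torch
                import transformer_engine.pytorch as te

                layer = te.TransformerLayer(1024, 4096, 16).cuda()
                # hidden_states: [sequence_length, batch_size, hidden_size]
                microbatches = [
                    torch.randn(128, 4, 1024, device="cuda") for _ in range(4)
                ]
                for i, mb in enumerate(microbatches):
                    out = layer(mb, is_first_microbatch=(i == 0))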
        """

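        # Make the input contiguous for the kernels that follow.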
        hidden_states = hidden_states.contiguous()

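        # The attention mask is only consulted for non-causal masking; with
        # `self_attn_mask_type="causal"` the `attention_mask` argument is ignored.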
        if self.self_attn_mask_type != "causal" and attention_mask is not None:
            assert (
                attention_mask.dtype == torch.bool
            ), "Attention mask must be a boolean tensor"

        # For AMP: cast the input to the autocast dtype if needed.
        if torch.is_autocast_enabled():
            hidden_states = cast_if_needed(
                hidden_states, torch.get_autocast_gpu_dtype()
            )

        # Self attention.
        self_attention_outputs = self.self_attention(
            hidden_states,
            attention_mask,
            inference_params=inference_params,
            is_first_microbatch=is_first_microbatch,
            checkpoint_core_attention=checkpoint_core_attention,
        )

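        # When `apply_residual_connection_post_layernorm` is set (and no output
        # LayerNorm is used), the attention module also returns its LayerNorm
        # output, which serves as the residual.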
        if self.apply_residual_connection_post_layernorm and not self.output_layernorm:
            attention_output, attention_bias, residual = self_attention_outputs
        else:
            attention_output, attention_bias = self_attention_outputs
            residual = hidden_states

        # Select the bias-dropout-add (BDA) function: fused JIT kernel or unfused fallback.
        if self.bias_dropout_fusion:
            if self.training:
                bias_dropout_add_func = bias_dropout_add_fused_train
            else:
                bias_dropout_add_func = bias_dropout_add_fused_inference
        else:
            bias_dropout_add_func = get_bias_dropout_add(self.training)

        # Bias dropout add.
        if self.drop_path is None and attention_bias.numel() != 0:
            with self.bias_dropout_add_exec_handler():
                bda_output = bias_dropout_add_func(
                    attention_output, attention_bias, residual, self.hidden_dropout
                )
        else:
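            # Unfused fallback: add the bias explicitly, apply dropout (and
            # DropPath, if enabled), then add the residual connection.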
            if attention_bias.numel() != 0:
                attention_output = attention_output + attention_bias
            out = torch.nn.functional.dropout(
                attention_output,
                p=self.hidden_dropout,
                training=self.training,
            )
            if self.drop_path is not None:
                out = self.drop_path(out)
            bda_output = residual + out

        # Cross attention.
        if self.layer_type == "decoder":
            inter_attention_outputs = self.inter_attention(
                bda_output,
                enc_dec_attn_mask,
                encoder_output=encoder_output,
                is_first_microbatch=is_first_microbatch,
                checkpoint_core_attention=checkpoint_core_attention,
            )
            if self.apply_residual_connection_post_layernorm:
                attention_output, attention_bias, residual = inter_attention_outputs
            else:
                attention_output, attention_bias = inter_attention_outputs
                residual = bda_output

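            # Bias dropout add after cross attention.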
            if attention_bias.numel() != 0:
                with self.bias_dropout_add_exec_handler():
                    bda_output = bias_dropout_add_func(
                        attention_output, attention_bias, residual, self.hidden_dropout
                    )
            else:
                out = torch.nn.functional.dropout(
                    attention_output,
                    p=self.hidden_dropout,
                    training=self.training,
                )
                bda_output = residual + out
        # MLP block (LayerNorm + feedforward).
        mlp_outputs = self.layernorm_mlp(
            bda_output, is_first_microbatch=is_first_microbatch
        )
        if self.apply_residual_connection_post_layernorm:
            mlp_output, mlp_bias, residual = mlp_outputs
        else:
            mlp_output, mlp_bias = mlp_outputs
            residual = bda_output

        # Bias dropout add.
        if self.drop_path is None and mlp_bias.numel() != 0:
            with self.bias_dropout_add_exec_handler():
                output = bias_dropout_add_func(
                    mlp_output, mlp_bias, residual, self.hidden_dropout
                )
        else:
            if mlp_bias.numel() != 0:
                mlp_output = mlp_output + mlp_bias
            out = torch.nn.functional.dropout(
                mlp_output, p=self.hidden_dropout, training=self.training
            )
            if self.drop_path is not None:
                out = self.drop_path(out)
            output = residual + out

        # For BERT-like architectures, apply the final output LayerNorm.
        if self.output_layernorm:
            output = self.layernorm(output)

        # output: [b, s, h]
        return output