# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Transformer."""
import os
import re
import warnings
from contextlib import nullcontext
from typing import Any, Callable, Optional, Union

import torch

import transformer_engine_extensions as tex
from transformer_engine.pytorch.module import LayerNormMLP, LayerNorm
from transformer_engine.pytorch.attention import MultiHeadAttention
from transformer_engine.pytorch.jit import (
    set_jit_fusion_options,
    warmup_jit_bias_dropout_add_all_dtypes,
    get_bias_dropout_add,
    bias_dropout_add_fused_train,
    bias_dropout_add_fused_inference,
)
from transformer_engine.pytorch.utils import (
    cast_if_needed,
    get_default_init_method,
)
from transformer_engine.pytorch.constants import (
    AttnMaskTypes,
    LayerTypes,
    dist_group_type,
)
from transformer_engine.pytorch.distributed import get_distributed_world_size


# DeprecationWarnings are hidden by default outside ``__main__``. Show the ones
# raised from this module once per module ("module" action).
# The `module` argument is a regex matched against the *start* of the module name,
# so it is anchored to this module's exact name. A bare "transformer" would also
# affect unrelated modules such as the third-party ``transformers`` package.
warnings.filterwarnings(
    "module", category=DeprecationWarning, module=re.escape(__name__) + r"\Z"
)


__all__ = ["TransformerLayer"]


class DropPath(torch.nn.Module):
    """Drop paths (Stochastic Depth) per sample
    (when applied in main path of residual blocks).

    In training mode, each sample (each index along dim 0) is kept with
    probability ``1 - drop_prob`` and zeroed otherwise. Kept samples are rescaled
    by ``1 / (1 - drop_prob)``. In eval mode, or when ``drop_prob == 0``, the input
    is returned unchanged.
    """

    def __init__(self, drop_prob: float = 0.0) -> None:
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        """DropPath FWD.

        Parameters
        ----------
        hidden_state : torch.Tensor
            Input of any rank; dim 0 indexes the samples that are dropped as a unit.

        Returns
        -------
        torch.Tensor
            Tensor with the same shape as ``hidden_state``.
        """
        if self.drop_prob == 0.0 or not self.training:
            return hidden_state
        keep_prob = 1 - self.drop_prob
        # One draw per sample, broadcast over all remaining dims, so tensors of
        # any rank are supported (not just 2D ConvNets).
        shape = (hidden_state.shape[0],) + (1,) * (hidden_state.ndim - 1)
        random_tensor = keep_prob + torch.rand(
            shape, dtype=hidden_state.dtype, device=hidden_state.device
        )
        random_tensor.floor_()  # binarize: 1 with probability keep_prob, else 0
        if keep_prob == 0.0:
            # drop_prob == 1: every path is dropped. Dividing by keep_prob here
            # would produce x / 0 * 0 == NaN, so apply the (all-zero) mask directly.
            return hidden_state * random_tensor
        return hidden_state.div(keep_prob) * random_tensor




class TransformerLayer(torch.nn.Module):
    r"""
    TransformerLayer is made up of an attention block and a feedforward network (MLP).
    This standard layer is based on the paper "Attention Is All You Need".

    .. warning::

        Arguments :attr:`attention_softmax_in_fp32` and :attr:`apply_query_key_layer_scaling`
        are deprecated and will be fully removed in future releases.

    .. note::

        Argument :attr:`attention_mask` will be ignored in the `forward` call when
        :attr:`self_attn_mask_type` is set to `"causal"`.

    Parameters
    ----------
    hidden_size : int
                 size of each input sample.
    ffn_hidden_size : int
                     intermediate size to which input samples are projected.
    num_attention_heads : int
                         number of attention heads in the transformer layer.
    layernorm_epsilon : float, default = 1e-5
                       a value added to the denominator of layer normalization
                       for numerical stability.
    hidden_dropout: float, default = 0.1
                   dropout probability for the dropout op applied to the output of each
                   attention block and of FC2, before the residual connection is added.
    attention_dropout: float, default = 0.1
                      dropout probability for the dropout op during multi-head attention.
    init_method : Callable, default = `None`
                 used for initializing weights of QKV and FC1 weights in the following way:
                 `init_method(weight)`. When set to `None`, defaults to
                 `torch.nn.init.normal_(mean=0.0, std=0.023)`.
    output_layer_init_method : Callable, default = `None`
                              used for initializing weights of PROJ and FC2 in the following way:
                              `output_layer_init_method(weight)`. When set to `None`, defaults to
                              `torch.nn.init.normal_(mean=0.0, std=0.023)`.
    apply_residual_connection_post_layernorm : bool, default = `False`
                                              if set to `True`, residual connections are taken
                                              from the output of layer norm (default is taken
                                              from input of layer norm)
    layer_number: int, default = `None`
                 layer number of the current `TransformerLayer` when multiple such modules are
                 concatenated to form a transformer block.
    output_layernorm: bool, default = `False`
                     if set to `True`, layer normalization is applied on the output side,
                     after the final dropout-add. default behavior is to apply layer
                     normalization on the input side, before the QKV transformation.
    layer_type: {'encoder', 'decoder'}, default = `encoder`
               if set to `decoder`, an additional cross-attn block is added after self-attn.
               This can be used for structures like `T5` Transformer in conjunction with the
               `encoder` option.
    kv_channels: int, default = `None`
                number of key-value channels. defaults to
                :attr:`hidden_size` // :attr:`num_attention_heads` if `None`.
    self_attn_mask_type: {'causal', 'padding'}, default = `causal`
                        type of attention mask passed into softmax operation.
    zero_centered_gamma : bool, default = `False`
                         if set to `True`, gamma parameter in LayerNorm is initialized to 0 and
                         the LayerNorm formula changes to

                         .. math::
                            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} *
                            (1 + \gamma) + \beta
    qkv_weight_interleaved : bool, default = `True`
                            if set to `False`, the QKV weight is interpreted as a concatenation of
                            query, key, and value weights along the `0th` dimension. The default
                            interpretation is that the individual `q`, `k`, and `v` weights for each
                            attention head are interleaved. This parameter is set to `False` when
                            using :attr:`fuse_qkv_params=False`.
    bias : bool, default = `True`
          if set to `False`, the transformer layer will not learn any additive biases.
    apply_query_key_layer_scaling : bool, default = `False`
                                   deprecated and ignored; passing `True` emits a
                                   `DeprecationWarning`.
    attention_softmax_in_fp32 : bool, default = `True`
                               deprecated and ignored; passing `False` emits a
                               `DeprecationWarning`.

    Parallelism parameters
    ----------------------
    set_parallel_mode : bool, default = `False`
                      if set to `True`, QKV and FC1 layers are used as Column Parallel
                      whereas PROJ and FC2 is used as Row Parallel as described
                      `here <https://arxiv.org/pdf/1909.08053.pdf>`_.
    sequence_parallel : bool, default = `False`
                       if set to `True`, uses sequence parallelism. Only takes effect when
                       the tensor parallel world size is greater than 1.
    tp_group : ProcessGroup, default = `None`
              tensor parallel process group.
    tp_size : int, default = 1
             used as TP (tensor parallel) world size when TP groups are not formed during
             initialization. In this case, users must call the
             `set_tensor_parallel_group(tp_group)` method on the initialized module before the
             forward pass to supply the tensor parallel group needed for tensor and sequence
             parallel collectives.
    get_rng_state_tracker : Callable, default = `None`
                           forwarded unchanged to the attention and MLP submodules; see their
                           documentation for the expected contract.
    ub_tp_comm_overlap : bool, default = `False`
                        if set to `True`, enables the userbuffer-based tensor parallel
                        communication overlaps (bulk wgrad, bulk dgrad, split all-gather,
                        split reduce-scatter) in the attention and MLP submodules. Requires
                        the userbuffer backend to be available. Each overlap can be disabled
                        individually by setting `NVTE_UB_BULK_WGRAD`, `NVTE_UB_BULK_DGRAD`,
                        `NVTE_UB_SPLIT_AG` or `NVTE_UB_SPLIT_RS` to `0`, or all of them by
                        setting `NVTE_UB_OVERLAP=0`.

    Optimization parameters
    -----------------------
    fuse_wgrad_accumulation : bool, default = `False`
                             if set to `True`, enables fusing of creation and accumulation of
                             the weight gradient. Requires :attr:`fuse_qkv_params=True`.
    params_dtype : torch.dtype, default = `torch.float32`
                  it controls the type used to allocate the initial parameters. Useful when
                  the model is trained with lower precision and the original FP32 parameters
                  would not fit in GPU memory.
    seq_length: int
               sequence length of input samples. Needed for JIT Warmup, a technique where jit
               fused functions are warmed up before training to ensure same kernels are used for
               forward propagation and activation recompute phase. When sequence parallelism is
               active, it is also used in `forward` to check that the input is already split
               along the sequence dimension.
    micro_batch_size: int
                     batch size per training step. Needed for JIT Warmup, a technique where jit
                     fused functions are warmed up before training to ensure same kernels are
                     used for forward propagation and activation recompute phase.
    drop_path_rate: float, default = 0.0
                   when > 0.0, applies stochastic depth per sample in
                   the main path of the residual block.
    fuse_qkv_params: bool, default = `False`
                    if set to `True`, `TransformerLayer` module exposes a single fused
                    parameter for query-key-value. This enables optimizations such as QKV
                    fusion without concatenations/splits and also enables the argument
                    `fuse_wgrad_accumulation`.
    """

    def __init__(
        self,
        hidden_size: int,
        ffn_hidden_size: int,
        num_attention_heads: int,
        layernorm_epsilon: float = 1e-5,
        hidden_dropout: float = 0.1,
        attention_dropout: float = 0.1,
        init_method: Optional[Callable] = None,
        output_layer_init_method: Optional[Callable] = None,
        layer_number: Optional[int] = None,
        kv_channels: Optional[int] = None,
        self_attn_mask_type: str = "causal",
        tp_group: Optional[dist_group_type] = None,
        tp_size: int = 1,
        params_dtype: torch.dtype = torch.float32,
        get_rng_state_tracker: Optional[Callable] = None,
        fuse_wgrad_accumulation: bool = False,
        apply_query_key_layer_scaling: bool = False,
        attention_softmax_in_fp32: bool = True,
        seq_length: Optional[int] = None,
        micro_batch_size: Optional[int] = None,
        sequence_parallel: bool = False,
        apply_residual_connection_post_layernorm: bool = False,
        output_layernorm: bool = False,
        layer_type: str = "encoder",
        drop_path_rate: float = 0.0,
        set_parallel_mode: bool = False,
        fuse_qkv_params: bool = False,
        zero_centered_gamma: bool = False,
        qkv_weight_interleaved: bool = True,
        ub_tp_comm_overlap: bool = False,
        bias: bool = True,
    ) -> None:
        super().__init__()

        # The two deprecated arguments have no effect on the layer. Only warn
        # callers that actually asked for non-default behavior.
        if apply_query_key_layer_scaling or not attention_softmax_in_fp32:
            warnings.warn(
                "Arguments `attention_softmax_in_fp32` and `apply_query_key_layer_scaling` "
                "are deprecated and will be fully removed in future releases.",
                category=DeprecationWarning,
            )

        if ub_tp_comm_overlap:
            assert (
                tex.userbuf_comm_available()
            ), "Userbuffer communication backend not available."

        # Each overlap is on by default once `ub_tp_comm_overlap` is requested, and
        # can be switched off individually through its environment variable.
        ub_tp_comm_overlap = ub_tp_comm_overlap and bool(int(os.getenv("NVTE_UB_OVERLAP", "1")))
        ub_bulk_wgrad = ub_tp_comm_overlap and bool(int(os.getenv("NVTE_UB_BULK_WGRAD", "1")))
        ub_bulk_dgrad = ub_tp_comm_overlap and bool(int(os.getenv("NVTE_UB_BULK_DGRAD", "1")))
        ub_split_ag = ub_tp_comm_overlap and bool(int(os.getenv("NVTE_UB_SPLIT_AG", "1")))
        ub_split_rs = ub_tp_comm_overlap and bool(int(os.getenv("NVTE_UB_SPLIT_RS", "1")))

        # NVTE_BIAS_DROPOUT_FUSION=0 falls back to the unfused bias-dropout-add.
        bias_dropout_fusion = bool(int(os.getenv("NVTE_BIAS_DROPOUT_FUSION", "1")))
        self.layer_number = layer_number
        self.output_layernorm = output_layernorm
        self.layer_type = layer_type
        self.apply_residual_connection_post_layernorm = (
            apply_residual_connection_post_layernorm
        )
        self.self_attn_mask_type = self_attn_mask_type

        assert (
            self_attn_mask_type in AttnMaskTypes
        ), f"self_attn_mask_type {self_attn_mask_type} not supported"
        assert layer_type in LayerTypes, f"layer_type {layer_type} not supported"

        if not fuse_qkv_params:
            assert (
                not fuse_wgrad_accumulation
            ), "Gradient accumulation fusion requires single QKV parameter."

        # Interleaving only applies to a single fused QKV weight.
        if not fuse_qkv_params:
            qkv_weight_interleaved = False

        self.kv_channels = (
            kv_channels if kv_channels else (hidden_size // num_attention_heads)
        )

        if init_method is None:
            init_method = get_default_init_method()
        if output_layer_init_method is None:
            output_layer_init_method = get_default_init_method()

        # A provided TP group overrides the `tp_size` argument.
        self.tp_size = tp_size if tp_group is None else get_distributed_world_size(tp_group)
        self.sequence_parallel = (self.tp_size > 1) and sequence_parallel
        self.seq_length = seq_length

        self.get_rng_state_tracker = get_rng_state_tracker

        # Arguments shared by the self- and cross-attention blocks.
        attention_args = (
            hidden_size,
            num_attention_heads,
            self.kv_channels,
            attention_dropout,
            layernorm_epsilon,
            init_method,
            output_layer_init_method,
        )
        common_attention_kwargs = {
            "layer_number": layer_number,
            "tp_group": tp_group,
            "tp_size": self.tp_size,
            "fuse_wgrad_accumulation": fuse_wgrad_accumulation,
            "get_rng_state_tracker": get_rng_state_tracker,
            "sequence_parallel": self.sequence_parallel,
            "params_dtype": params_dtype,
            "return_layernorm_output": apply_residual_connection_post_layernorm,
            "set_parallel_mode": set_parallel_mode,
            "fuse_qkv_params": fuse_qkv_params,
            "zero_centered_gamma": zero_centered_gamma,
            "qkv_weight_interleaved": qkv_weight_interleaved,
            "ub_bulk_wgrad": ub_bulk_wgrad,
            "ub_bulk_dgrad": ub_bulk_dgrad,
            "ub_split_ag": ub_split_ag,
            "ub_split_rs": ub_split_rs,
        }

        # With `output_layernorm`, layer norm is applied at the end of the layer
        # instead of on the self-attention input.
        self.self_attention = MultiHeadAttention(
            *attention_args,
            **common_attention_kwargs,
            attn_mask_type=self_attn_mask_type,
            input_layernorm=not output_layernorm,
            attention_type="self",
            bias=bias,
        )

        if layer_type == "decoder":
            self.inter_attention = MultiHeadAttention(
                *attention_args,
                **common_attention_kwargs,
                attn_mask_type="padding",
                input_layernorm=True,
                attention_type="cross",
                bias=bias,
            )

        # LayerNorm -> gelu(Linear + Bias) -> Linear
        # With `set_parallel_mode`, FC1 is column-parallel and FC2 is row-parallel.
        self.layernorm_mlp = LayerNormMLP(
            hidden_size,
            ffn_hidden_size,
            eps=layernorm_epsilon,
            fuse_wgrad_accumulation=fuse_wgrad_accumulation,
            tp_group=tp_group,
            tp_size=self.tp_size,
            get_rng_state_tracker=get_rng_state_tracker,
            init_method=init_method,
            output_layer_init_method=output_layer_init_method,
            bias=bias,
            return_bias=True,
            sequence_parallel=self.sequence_parallel,
            params_dtype=params_dtype,
            return_layernorm_output=apply_residual_connection_post_layernorm,
            seq_length=seq_length,
            micro_batch_size=micro_batch_size,
            set_parallel_mode=set_parallel_mode,
            zero_centered_gamma=zero_centered_gamma,
            ub_bulk_wgrad=ub_bulk_wgrad,
            ub_bulk_dgrad=ub_bulk_dgrad,
            ub_split_rs=ub_split_rs,
            ub_split_ag=ub_split_ag,
        )

        self.hidden_dropout = hidden_dropout
        self.bias_dropout_fusion = bias_dropout_fusion
        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 else None

        # Set bias+dropout+add fusion grad_enable execution handler: with torch >= 1.10
        # (nvfuser) no special grad mode is needed; older versions run the fused
        # bias-dropout-add under torch.enable_grad.
        TORCH_MAJOR = int(torch.__version__.split(".")[0])
        TORCH_MINOR = int(torch.__version__.split(".")[1])
        use_nvfuser = TORCH_MAJOR > 1 or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10)
        self.bias_dropout_add_exec_handler = (
            nullcontext if use_nvfuser else torch.enable_grad
        )

        if self.bias_dropout_fusion:
            set_jit_fusion_options()
            if seq_length and micro_batch_size:
                # With sequence parallelism each rank only sees its shard of the
                # sequence, so warm up with the per-rank length.
                if self.sequence_parallel:
                    seq_length = seq_length // self.tp_size
                warmup_jit_bias_dropout_add_all_dtypes(
                    hidden_size, seq_length, micro_batch_size
                )

        if self.output_layernorm:
            self.layernorm = LayerNorm(
                hidden_size,
                eps=layernorm_epsilon,
                sequence_parallel=self.sequence_parallel,
                params_dtype=params_dtype,
                zero_centered_gamma=zero_centered_gamma,
            )

    def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> None:
        """Set the TP group on every submodule that supports it."""
        # Deep iterate but skip self (index 0 of `modules()`) to avoid infinite recursion.
        for index, child in enumerate(self.modules()):
            if index == 0:
                continue
            if hasattr(child, "set_tensor_parallel_group"):
                child.set_tensor_parallel_group(tp_group)

    def _bias_dropout_add(
        self,
        output: torch.Tensor,
        bias: torch.Tensor,
        residual: torch.Tensor,
        bias_dropout_add_func: Callable,
        drop_path: Optional[torch.nn.Module],
    ) -> torch.Tensor:
        """Compute ``residual + drop_path(dropout(output + bias))``.

        Uses the (optionally fused) ``bias_dropout_add_func`` when there is no
        drop-path and ``bias`` is non-empty. Otherwise it applies the ops one by
        one, skipping the bias add for an empty ``bias`` and the drop-path when
        ``drop_path`` is ``None``.
        """
        if drop_path is None and bias.numel() != 0:
            with self.bias_dropout_add_exec_handler():
                return bias_dropout_add_func(output, bias, residual, self.hidden_dropout)
        if bias.numel() != 0:
            output = output + bias
        out = torch.nn.functional.dropout(
            output,
            p=self.hidden_dropout,
            training=self.training,
        )
        if drop_path is not None:
            out = drop_path(out)
        return residual + out

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_output: Optional[torch.Tensor] = None,
        enc_dec_attn_mask: Optional[torch.Tensor] = None,
        is_first_microbatch: Optional[bool] = None,
        checkpoint_core_attention: bool = False,
        inference_params: Optional[Any] = None,
    ) -> torch.Tensor:
        """
        Transformer Layer: attention block and a feedforward network (MLP)

        .. note::

            Argument :attr:`attention_mask` will be ignored when :attr:`self_attn_mask_type`
            is set to `"causal"`.

        Parameters
        ----------
        hidden_states : torch.Tensor
             Input tensor.
        attention_mask : Optional[torch.Tensor], default = `None`
             Boolean tensor used to mask out self-attention softmax input.
        encoder_output : Optional[torch.Tensor], default = `None`
             Output of the encoder block to be fed into the decoder block if using
             `layer_type="decoder"`.
        enc_dec_attn_mask : Optional[torch.Tensor], default = `None`
             Boolean tensor used to mask out inter-attention softmax input if using
             `layer_type="decoder"`.
        is_first_microbatch : {True, False, None}, default = None
                             During training using either gradient accumulation or
                             pipeline parallelism a minibatch of data is further split
                             into microbatches. Between the microbatches of the same minibatch
                             the model weights are not updated. Setting this parameter indicates
                             whether the current microbatch is the first in a minibatch or not.
                             When set, this parameter enables additional optimizations:

                             * during FP8 training, it allows caching of the FP8 versions of
                               the weights
                             * it also allows skipping gradient accumulation during the
                               first microbatch (since it is the first gradient being
                               produced)
        checkpoint_core_attention: bool, default = `False`
                                  If true, forward activations for core attention are recomputed
                                  during the backward pass in order to save memory that would
                                  otherwise be occupied to store the forward activations until
                                  backprop.
        inference_params : Optional[Any], default = `None`
             Forwarded unchanged to the self-attention block (not to cross-attention).

        Returns
        -------
        torch.Tensor
             Output of the layer, with the same shape as `hidden_states`.
        """

        hidden_states = hidden_states.contiguous()

        # Under sequence parallelism dim 0 is the (already sharded) sequence dimension.
        if self.sequence_parallel and self.seq_length is not None:
            assert (
                hidden_states.shape[0] == self.seq_length // self.tp_size
            ), "Sequence dimension must be split across TP group when using sequence parallel."

        if self.self_attn_mask_type != "causal" and attention_mask is not None:
            assert (
                attention_mask.dtype == torch.bool
            ), "Attention mask must be a boolean tensor"

        # For AMP
        if torch.is_autocast_enabled():
            hidden_states = cast_if_needed(
                hidden_states, torch.get_autocast_gpu_dtype()
            )

        # Self attention.
        self_attention_outputs = self.self_attention(
            hidden_states,
            attention_mask,
            inference_params=inference_params,
            is_first_microbatch=is_first_microbatch,
            checkpoint_core_attention=checkpoint_core_attention,
        )

        # The layer-norm output is only returned (and used as residual) when
        # self-attention applies an input layer norm.
        if self.apply_residual_connection_post_layernorm and not self.output_layernorm:
            attention_output, attention_bias, residual = self_attention_outputs
        else:
            attention_output, attention_bias = self_attention_outputs
            residual = hidden_states

        # Select the bias-dropout-add implementation.
        if self.bias_dropout_fusion:
            if self.training:
                bias_dropout_add_func = bias_dropout_add_fused_train
            else:
                bias_dropout_add_func = bias_dropout_add_fused_inference
        else:
            bias_dropout_add_func = get_bias_dropout_add(self.training)

        # Bias dropout add.
        bda_output = self._bias_dropout_add(
            attention_output, attention_bias, residual, bias_dropout_add_func, self.drop_path
        )

        # Cross attention.
        if self.layer_type == "decoder":
            inter_attention_outputs = self.inter_attention(
                bda_output,
                enc_dec_attn_mask,
                encoder_output=encoder_output,
                is_first_microbatch=is_first_microbatch,
                checkpoint_core_attention=checkpoint_core_attention,
            )
            if self.apply_residual_connection_post_layernorm:
                attention_output, attention_bias, residual = inter_attention_outputs
            else:
                attention_output, attention_bias = inter_attention_outputs
                residual = bda_output

            # Drop-path is not applied to the cross-attention branch.
            bda_output = self._bias_dropout_add(
                attention_output, attention_bias, residual, bias_dropout_add_func, None
            )

        # MLP.
        mlp_outputs = self.layernorm_mlp(
            bda_output, is_first_microbatch=is_first_microbatch
        )
        if self.apply_residual_connection_post_layernorm:
            mlp_output, mlp_bias, residual = mlp_outputs
        else:
            mlp_output, mlp_bias = mlp_outputs
            residual = bda_output

        # Bias dropout add.
        output = self._bias_dropout_add(
            mlp_output, mlp_bias, residual, bias_dropout_add_func, self.drop_path
        )

        # For BERT like architectures.
        if self.output_layernorm:
            output = self.layernorm(output)

        # Output keeps the layout of `hidden_states` (the sequence-parallel check
        # above treats dim 0 as the sequence dimension).
        return output