# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Transformer"""

from typing import Optional, Tuple, Union
import warnings

import paddle
from paddle.incubate.nn.layer.fused_dropout_add import FusedDropoutAdd

from transformer_engine.paddle.layer import LayerNormMLP, LayerNorm, MultiHeadAttention
from transformer_engine.paddle.constants import AttnMaskTypes, LayerTypes, dist_group_type
from transformer_engine.paddle.distributed import get_tp_group_and_world_size, track_rng_state


class TransformerLayer(paddle.nn.Layer):
    r"""
    TransformerLayer is made up of an attention block and a feedforward network (MLP).
    This standard layer is based on the paper "Attention Is All You Need".

    Parameters
    ----------
    hidden_size : int
                 size of each input sample.
    ffn_hidden_size : int
                     intermediate size to which input samples are projected.
    num_attention_heads : int
                         number of attention heads in the transformer layer.
    num_gqa_groups : Optional[int], default = `None`
                    number of GQA groups in the transformer layer.
                    Grouped Query Attention is described in
                    `this paper <https://arxiv.org/pdf/2305.13245.pdf>`_.
                    This only affects the keys and values, not the queries.
                    GQA-1 is equivalent to Multi-Query Attention
                    (`MQA <https://arxiv.org/pdf/1911.02150.pdf>`_), while GQA-H
                    is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`.
    layernorm_epsilon : float, default = 1e-5
                       a value added to the denominator of layer normalization
                       for numerical stability.
    hidden_dropout: float, default = 0.1
                   dropout probability for the fused dropout-add ops applied to the
                   output of the self-attention block, the cross-attention block
                   (decoder only) and the FC2 layer.
    attention_dropout: float, default = 0.1
                      dropout probability for the dropout op during multi-head attention.
    weight_attr: Union[paddle.ParamAttr, None], default = None
                optional `paddle.ParamAttr` for weight.
    bias_attr: Union[paddle.ParamAttr, None, bool], default = None
              optional `paddle.ParamAttr` for bias.
    max_sequence_length: Optional[int], default = `None`
                        maximum sequence length; forwarded unchanged to the attention
                        block(s).
    self_attn_mask_type: {'causal', 'padding'}, default = `causal`
                        type of attention mask passed into softmax operation.
    apply_residual_connection_post_layernorm : bool, default = `False`
                                              if set to `True`, residual connections are taken
                                              from the output of layer norm (default is taken
                                              from input of layer norm)
    output_layernorm: bool, default = `False`
                     if set to `True`, layer normalization is applied on the output side,
                     after the final dropout-add. default behavior is to apply layer
                     normalization on the input side, before the QKV transformation.
    layer_type: {'encoder', 'decoder'}, default = `encoder`
               if set to `decoder`, an additional cross-attn block is added after self-attn.
               This can be used for structures like `T5` Transformer in conjunction with the
               `encoder` option.
    normalization: {'LayerNorm', 'RMSNorm'}, default = `LayerNorm`
                  type of normalization used by the attention and MLP blocks. The
                  output-side normalization created when `output_layernorm=True` is
                  always a `LayerNorm`.
    zero_centered_gamma : bool, default = `False`
                         if set to `True`, gamma parameter in LayerNorm is initialized to 0 and
                         the LayerNorm formula changes to

                         .. math::
                            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} *
                            (1 + \gamma) + \beta
    activation : str, default = 'gelu'
                Type of activation used in MLP block.
                Options are: 'gelu', 'relu', 'reglu', 'geglu' and 'swiglu'.
    params_dtype : paddle.dtype, default = `paddle.get_default_dtype()`
                  it controls the type used to allocate the initial parameters. Useful when
                  the model is trained with lower precision and the original FP32 parameters
                  would not fit in GPU memory. It is forwarded to the attention block(s).
    backend: {'transformer_engine', 'paddle'}, default = 'transformer_engine'
             if set to 'paddle', a framework only no-FP8 path is executed with limited optimization.

    Parallelism parameters
    ----------------------
    set_parallel_mode : bool, default = `False`
                      if set to `True`, QKV and FC1 layers are used as Column Parallel
                      whereas PROJ and FC2 is used as Row Parallel as described
                      `here <https://arxiv.org/pdf/1909.08053.pdf>`_.
    sequence_parallel : bool, default = `False`
                       if set to `True`, uses sequence parallelism. It only takes effect
                       when tensor parallelism is active (tensor-parallel world size > 1).
    tp_group : ProcessGroup, default = `None`
              tensor parallel process group.
    attention_dropout_rng_state_name : str, default = `local_seed`
                   Controls the rng state used for dropout on attention probs. The
                   specified rng should be set different seeds for different TP ranks.
                   It will be ignored if `set_parallel_mode` is False.
    hidden_dropout_rng_state_name : str, default = `global_seed`
                   Controls the rng state used for dropout on hidden states. The
                   specified rng should be given the same seeds for different TP
                   ranks. It will be ignored if `set_parallel_mode` is False. The
                   specified name should be registered through
                   `paddle.distributed.fleet.meta_parallel.get_rng_state_tracker()
                   .add(rng_state_name, seed)`.
                   When sequence parallelism is active and this is `global_seed`, it is
                   replaced by `local_seed` (with a warning), because hidden dropout then
                   needs different RNG states across TP ranks.

    Optimization parameters
    -----------------------
    fuse_wgrad_accumulation : bool, default = `False`
                             if set to `True`, enables fusing of creation and accumulation of
                             the weight gradient. When enabled, it is assumed that the weights
                             have an additional `main_grad` attribute (used instead of the
                             regular `grad`) which is a pre-allocated buffer of the correct
                             size to accumulate gradients in.

    """

    def __init__(self,
                 hidden_size: int,
                 ffn_hidden_size: int,
                 num_attention_heads: int,
                 num_gqa_groups: Optional[int] = None,
                 layernorm_epsilon: float = 1e-5,
                 hidden_dropout: float = 0.1,
                 attention_dropout: float = 0.1,
                 weight_attr: Union[paddle.ParamAttr, None] = None,
                 bias_attr: Union[paddle.ParamAttr, None, bool] = None,
                 max_sequence_length: Optional[int] = None,
                 self_attn_mask_type: str = "causal",
                 params_dtype: Optional[paddle.dtype] = None,
                 apply_residual_connection_post_layernorm: bool = False,
                 output_layernorm: bool = False,
                 layer_type: str = "encoder",
                 normalization: str = "LayerNorm",
                 zero_centered_gamma: bool = False,
                 activation: str = 'gelu',
                 set_parallel_mode: bool = False,
                 sequence_parallel: bool = False,
                 tp_group: Optional[dist_group_type] = None,
                 fuse_wgrad_accumulation: bool = False,
                 attention_dropout_rng_state_name: str = 'local_seed',
                 hidden_dropout_rng_state_name: str = 'global_seed',
                 backend: str = 'transformer_engine') -> None:
        super().__init__()

        params_dtype = paddle.get_default_dtype() if params_dtype is None else params_dtype
        self.output_layernorm = output_layernorm
        self.layer_type = layer_type
        self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm
        self.self_attn_mask_type = self_attn_mask_type
        self.set_parallel_mode = set_parallel_mode
        self.tp_group, self.tp_size = get_tp_group_and_world_size(tp_group,
                                                                  enable_tp=set_parallel_mode)
        self.tensor_parallel = self.tp_size > 1
        # Sequence parallelism is only meaningful on top of tensor parallelism.
        self.sequence_parallel = self.tensor_parallel and sequence_parallel
        self.hidden_dropout_rng_state_name = hidden_dropout_rng_state_name

        # With SP, the hidden-dropout RNG state must differ across TP ranks, so a
        # 'global_seed' request is overridden with 'local_seed'.
        if self.sequence_parallel and self.hidden_dropout_rng_state_name == 'global_seed':
            warnings.warn(
                "RNG state for hidden dropout needs to be different across TP ranks. "
                "Forcing hidden_dropout_rng_state_name to 'local_seed'",
                stacklevel=2,
            )
            self.hidden_dropout_rng_state_name = 'local_seed'

        assert (self_attn_mask_type
                in AttnMaskTypes), f"self_attn_mask_type {self_attn_mask_type} not supported"
        assert layer_type in LayerTypes, f"layer_type {layer_type} not supported"

        # Positional and keyword arguments shared by the self- and cross-attention blocks.
        attention_args = (
            hidden_size,
            num_attention_heads,
            attention_dropout,
            layernorm_epsilon,
            weight_attr,
            bias_attr,
        )
        common_attention_kwargs = {
            "params_dtype": params_dtype,
            "return_layernorm_output": apply_residual_connection_post_layernorm,
            "normalization": normalization,
            "zero_centered_gamma": zero_centered_gamma,
            "set_parallel_mode": set_parallel_mode,
            "sequence_parallel": self.sequence_parallel,
            "max_sequence_length": max_sequence_length,
            "tp_group": tp_group,
            "num_gqa_groups": num_gqa_groups,
            "fuse_wgrad_accumulation": fuse_wgrad_accumulation,
            "rng_state_name": attention_dropout_rng_state_name,
            "backend": backend,
        }

        # Self-attention only has an input-side norm when the output-side norm is disabled.
        self.self_attention = MultiHeadAttention(
            *attention_args,
            **common_attention_kwargs,
            attn_mask_type=self_attn_mask_type,
            input_layernorm=not output_layernorm,
            attention_type="self",
        )

        if layer_type == "decoder":
            self.inter_attention = MultiHeadAttention(
                *attention_args,
                **common_attention_kwargs,
                attn_mask_type="padding",
                input_layernorm=True,
                attention_type="cross",
            )

        # NOTE(review): params_dtype is not forwarded to LayerNormMLP (nor to the output
        # LayerNorm below); presumably they allocate with paddle's default dtype — confirm.
        self.layernorm_mlp = LayerNormMLP(
            hidden_size,
            ffn_hidden_size,
            eps=layernorm_epsilon,
            weight_attr=weight_attr,
            bias_attr=bias_attr,
            normalization=normalization,
            activation=activation,
            return_layernorm_output=apply_residual_connection_post_layernorm,
            zero_centered_gamma=zero_centered_gamma,
            set_parallel_mode=set_parallel_mode,
            sequence_parallel=self.sequence_parallel,
            tp_group=tp_group,
            fuse_wgrad_accumulation=fuse_wgrad_accumulation,
            backend=backend,
        )

        self.hidden_dropout = hidden_dropout

        if self.output_layernorm:
            # NOTE(review): `normalization` is not forwarded here, so this is a LayerNorm
            # even when normalization='RMSNorm' — confirm this is intended.
            self.layernorm = LayerNorm(
                hidden_size,
                layernorm_epsilon,
                weight_attr,
                bias_attr,
                zero_centered_gamma=zero_centered_gamma,
                sequence_parallel=self.sequence_parallel,
                backend=backend,
            )

        # One fused dropout + residual-add op per sub-block output:
        # 1: after self-attention, 2: after cross-attention (decoder only), 3: after the MLP.
        self.fused_dropout_add1 = FusedDropoutAdd(self.hidden_dropout, mode="upscale_in_train")
        if self.layer_type == "decoder":
            self.fused_dropout_add2 = FusedDropoutAdd(self.hidden_dropout, mode="upscale_in_train")
        self.fused_dropout_add3 = FusedDropoutAdd(self.hidden_dropout, mode="upscale_in_train")

    def _dropout_add(
        self,
        fused_dropout_add: FusedDropoutAdd,
        x: paddle.Tensor,
        residual: paddle.Tensor,
    ) -> paddle.Tensor:
        """Return ``fused_dropout_add(x, residual)`` under the hidden-dropout RNG state.

        The named RNG state (``self.hidden_dropout_rng_state_name``) is tracked only
        when tensor parallelism is active (``self.tensor_parallel``).
        """
        with track_rng_state(enable=self.tensor_parallel, name=self.hidden_dropout_rng_state_name):
            return fused_dropout_add(x, residual)

    def forward(
        self,
        hidden_states: paddle.Tensor,
        attention_mask: Optional[paddle.Tensor] = None,
        encoder_output: Optional[paddle.Tensor] = None,
        enc_dec_attn_mask: Optional[paddle.Tensor] = None,
        rotary_pos_emb: Optional[Tuple[paddle.Tensor, paddle.Tensor]] = None,
        core_attention_bias_type: str = "no_bias",
        core_attention_bias: Optional[paddle.Tensor] = None,
        set_zero: bool = True,
        recompute_core_attention: bool = False,
        is_first_microbatch: Optional[bool] = None,
    ) -> paddle.Tensor:
        """
        Transformer Layer: attention block and a feedforward network (MLP)

        .. note::

            Argument :attr:`attention_mask` will be ignored when :attr:`self_attn_mask_type`
            is set to `"causal"`.

        Parameters
        ----------
        hidden_states : paddle.Tensor
             Input tensor.
        attention_mask : Optional[paddle.Tensor], default = `None`
             Boolean tensor used to mask out self-attention softmax input.
        encoder_output : Optional[paddle.Tensor], default = `None`
             Output of the encoder block to be fed into the decoder block if using
             `layer_type="decoder"`.
        enc_dec_attn_mask : Optional[paddle.Tensor], default = `None`
             Boolean tensor used to mask out inter-attention softmax input if using
             `layer_type="decoder"`.
        rotary_pos_emb : Optional[Tuple[paddle.Tensor, paddle.Tensor]], default = `None`
             Embeddings for query and key tensors for applying rotary position
             embedding. By default no input embedding is applied. Only passed to the
             self-attention block.
        core_attention_bias_type: str, default = `no_bias`
             Type of bias added to Q * K.T. Only `no_bias` is currently supported.
        core_attention_bias: Optional[paddle.Tensor], default = `None`
             Bias tensor for Q * K.T
        set_zero: bool, default = `True`
             Whether to set output tensors to 0 or not before use.
        recompute_core_attention: bool, default = `False`
             If true, forward activations for core attention are recomputed
             during the backward pass in order to save memory that would
             otherwise be occupied to store the forward activations until
             backprop.
        is_first_microbatch : {True, False, None}, default = None
             During training using either gradient accumulation or
             pipeline parallelism a minibatch of data is further split
             into microbatches. Between the microbatches of the same minibatch
             the model weights are not updated. Setting this parameter indicates
             whether the current microbatch is the first in a minibatch or not.
             When set, this parameter enables additional optimizations:

             * during FP8 training, it allows caching of the FP8 versions of
               the weights

        Returns
        -------
        paddle.Tensor
             Output hidden states of the layer.
        """

        if self.self_attn_mask_type != "causal" and attention_mask is not None:
            assert (attention_mask.dtype == paddle.bool), "Attention mask must be a boolean tensor"

        assert core_attention_bias_type in ['no_bias'], (
            "Only no_bias is supported currently, "
            f"but receive core_attention_bias_type = {core_attention_bias_type}")

        # Self attention.
        self_attention_outputs = self.self_attention(
            hidden_states,
            attention_mask,
            core_attention_bias_type=core_attention_bias_type,
            core_attention_bias=core_attention_bias,
            set_zero=set_zero,
            rotary_pos_emb=rotary_pos_emb,
            recompute_core_attention=recompute_core_attention,
            is_first_microbatch=is_first_microbatch,
        )

        # Self-attention is built with `input_layernorm=not output_layernorm`, so it is
        # only expected to return (output, layernorm_output) when output_layernorm is off.
        if self.apply_residual_connection_post_layernorm and not self.output_layernorm:
            attention_output, residual = self_attention_outputs
        else:
            attention_output = self_attention_outputs
            residual = hidden_states

        # Dropout + residual add.
        bda_output = self._dropout_add(self.fused_dropout_add1, attention_output, residual)

        # Cross attention.
        if self.layer_type == "decoder":
            inter_attention_outputs = self.inter_attention(
                bda_output,
                enc_dec_attn_mask,
                encoder_output=encoder_output,
                core_attention_bias_type=core_attention_bias_type,
                core_attention_bias=core_attention_bias,
                set_zero=set_zero,
                recompute_core_attention=recompute_core_attention,
                is_first_microbatch=is_first_microbatch,
            )
            # Cross-attention always has an input norm (input_layernorm=True).
            if self.apply_residual_connection_post_layernorm:
                attention_output, residual = inter_attention_outputs
            else:
                attention_output = inter_attention_outputs
                residual = bda_output

            bda_output = self._dropout_add(self.fused_dropout_add2, attention_output, residual)

        # MLP.
        mlp_outputs = self.layernorm_mlp(bda_output, is_first_microbatch=is_first_microbatch)
        if self.apply_residual_connection_post_layernorm:
            mlp_output, residual = mlp_outputs
        else:
            mlp_output = mlp_outputs
            residual = bda_output

        # Dropout + residual add.
        output = self._dropout_add(self.fused_dropout_add3, mlp_output, residual)

        # For BERT like architectures.
        if self.output_layernorm:
            output = self.layernorm(output)

        # output: [b, s, hidden]
        return output