transformer.py 15.5 KB
Newer Older
1
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
4
5
6
7
#
# See LICENSE for license information.
"""
Praxis Modules related Transformer
"""
from functools import partial
8
from typing import Optional, Sequence, Tuple
9
import warnings
10
11
12
13
14
15

from praxis import pax_fiddle
from praxis.base_layer import WeightInit
from praxis.pytypes import JTensor

from .module import TransformerEngineBaseLayer
16
from ..flax.transformer import TransformerLayerType
17
from ..flax.transformer import DotProductAttention as flax_DotProductAttention
18
19
20
from ..flax.transformer import MultiHeadAttention as flax_MultiHeadAttention
from ..flax.transformer import RelativePositionBiases as flax_RelativePositionBiases
from ..flax.transformer import TransformerLayer as flax_TransformerLayer
21
from ..attention import AttnBiasType, AttnMaskType
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37


class RelativePositionBiases(TransformerEngineBaseLayer):
    """
    Praxis wrapper around the Transformer Engine flax ``RelativePositionBiases``
    module (T5-style relative position bias lookup table, one bias per head
    and distance bucket).
    """

    # Number of buckets used to quantize relative distances.
    num_buckets: int = 32
    # Relative distances beyond this value all map into the last bucket.
    max_distance: int = 128
    # Number of attention heads; the embedding table has one row of biases per head.
    num_attention_heads: int = 64
    # Initializer for the relative-embedding table. None selects the default
    # Gaussian built by `generate_embedding_init` below.
    # (Annotation fixed: default is None, so the field is Optional.)
    embedding_init: Optional[WeightInit] = None
    # Logical sharding axes for the embedding table.
    embedding_axes: Tuple[str, ...] = ()

    @staticmethod
    def generate_embedding_init(init, num_attention_heads, num_buckets):
        """Return `init` unchanged if given; otherwise a Gaussian initializer
        with stddev (num_attention_heads * num_buckets) ** -0.5."""
        embedding_init = init
        if embedding_init is None:
            rb_stddev = (num_attention_heads * num_buckets) ** -0.5
            embedding_init = WeightInit.Gaussian(rb_stddev)
        return embedding_init

    def setup(self) -> None:
        """Create the wrapped flax relative-position-bias sublayer."""
        super().setup()

        embedding_init = RelativePositionBiases.generate_embedding_init(
            self.embedding_init, self.num_attention_heads, self.num_buckets
        )

        rpb_cls = partial(
            flax_RelativePositionBiases,
            num_buckets=self.num_buckets,
            max_distance=self.max_distance,
            num_attention_heads=self.num_attention_heads,
            embedding_init=TransformerEngineBaseLayer.generate_params_init(
                "rel_embedding", embedding_init
            ),
            embedding_axes=self.embedding_axes,
            dtype=self.dtype,
        )

        self.create_layer("relative_position_bias", rpb_cls)

    def __call__(self, q_seqlen: JTensor, k_seqlen: JTensor, bidirectional: bool = True) -> JTensor:
        """Return the relative-position bias for the given query/key sequence lengths."""
        return self.relative_position_bias(q_seqlen, k_seqlen, bidirectional)


69
70
71
72
73
74
class DotProductAttention(TransformerEngineBaseLayer):
    """
    Praxis wrapper around the Transformer Engine flax ``DotProductAttention``
    module. Holds only configuration; all computation is delegated to the
    created flax sublayer.
    """

    # Size of each attention head. Must be set to a positive value before setup().
    head_dim: int = 0
    # Number of query heads. Must be set to a positive value before setup().
    num_attention_heads: int = 0
    # Number of key/value groups for grouped-query attention; None lets the
    # flax module fall back to its own default.
    num_gqa_groups: Optional[int] = None
    attention_dropout: float = 0.0
    attn_mask_type: AttnMaskType = "causal"
    # Annotation fixed: default is None, so the field is Optional
    # (matches MultiHeadAttention.attn_bias_type below).
    attn_bias_type: Optional[AttnBiasType] = None
    dropout_rng_name: str = "dropout"
    # Compute attention logits in float32 regardless of input dtype.
    float32_logits: bool = False
    qkv_layout: str = "bshd_bshd_bshd"
    # Softmax scaling factor; None lets the flax module pick its default.
    scale_factor: Optional[float] = None
    transpose_batch_sequence: bool = True

    def setup(self) -> None:
        """Validate the configuration and create the wrapped flax sublayer."""
        super().setup()

        assert self.head_dim > 0, f"{self.head_dim=}"
        assert self.num_attention_heads > 0, f"{self.num_attention_heads=}"

        dpa_cls = partial(
            flax_DotProductAttention,
            head_dim=self.head_dim,
            num_attention_heads=self.num_attention_heads,
            num_gqa_groups=self.num_gqa_groups,
            attn_mask_type=self.attn_mask_type,
            attn_bias_type=self.attn_bias_type,
            attention_dropout=self.attention_dropout,
            dtype=self.dtype,
            dropout_rng_name=self.dropout_rng_name,
            float32_logits=self.float32_logits,
            qkv_layout=self.qkv_layout,
            scale_factor=self.scale_factor,
            transpose_batch_sequence=self.transpose_batch_sequence,
        )

        self.create_layer("dot_product_attention", dpa_cls)

    def __call__(
        self,
        query: JTensor,
        key: JTensor,
        value: JTensor,
        mask: Optional[JTensor] = None,
        bias: Optional[JTensor] = None,
        *,
        deterministic: bool = False,
    ) -> JTensor:
        """Run scaled dot-product attention on (query, key, value)."""
        return self.dot_product_attention(
            query, key, value, mask, bias, deterministic=deterministic
        )
123
124


125
126
127
class MultiHeadAttention(TransformerEngineBaseLayer):
    """
    Praxis wrapper around the Transformer Engine flax ``MultiHeadAttention``
    module (QKV projections + dot-product attention + output projection).

    Several fields are deprecated aliases kept for backward compatibility;
    `__post_init__` maps them onto the new names with a DeprecationWarning.
    """

    # Size of each attention head. Must be set to a positive value before setup().
    head_dim: int = 0
    # Number of query heads. Must be set to a positive value before setup().
    num_attention_heads: int = 0
    # Number of key/value groups for grouped-query attention;
    # None is resolved to num_attention_heads in __post_init__ (standard MHA).
    num_gqa_groups: Optional[int] = None
    attention_dropout: float = 0.0
    dropout_rng_name: str = "dropout"
    input_layernorm: bool = True
    layernorm_type: str = "layernorm"
    layernorm_epsilon: float = 1e-6
    zero_centered_gamma: bool = False
    return_layernorm_output: bool = False
    use_bias: bool = False
    bias_init: WeightInit = WeightInit.Constant(0.0)
    attn_mask_type: str = "causal"
    attn_bias_type: Optional[str] = None
    enable_rotary_pos_emb: bool = False
    rotary_pos_emb_windows: Tuple[int, int] = (1, 10000)
    rotary_pos_emb_group_method: str = "consecutive"
    low_rank_adaptation_scope: str = "none"
    low_rank_adaptation_dim: int = 32
    # Annotation fixed: default is None, so the field is Optional.
    low_rank_adaptation_alpha: Optional[float] = None
    fuse_qkv_params: bool = True
    transpose_batch_sequence: bool = True
    enable_sequence_parallel: bool = False
    scale_attn_logits: bool = False
    scaled_query_init: bool = True
    float32_logits: bool = False

    # Deprecated parameters
    num_heads: Optional[int] = None
    dropout_rate: Optional[float] = None
    output_layernorm: Optional[bool] = None
    apply_residual_connection_post_layernorm: Optional[bool] = None
    fuse_qkv: Optional[bool] = None

    def __post_init__(self):
        """Map deprecated parameters onto their replacements, then resolve defaults."""
        # Deal with the deprecated parameters
        if self.num_heads is not None:
            self.num_attention_heads = self.num_heads
            warnings.warn(
                f"{__class__}.num_heads is deprecated. It will be removed recently. "
                f"Please use {__class__}.num_attention_heads as the new API.",
                DeprecationWarning,
            )

        if self.dropout_rate is not None:
            self.attention_dropout = self.dropout_rate
            warnings.warn(
                f"{__class__}.dropout_rate is deprecated. It will be removed recently. "
                f"Please use {__class__}.attention_dropout as the new API.",
                DeprecationWarning,
            )

        if self.apply_residual_connection_post_layernorm is not None:
            warnings.warn(
                f"{__class__}.apply_residual_connection_post_layernorm is deprecated. "
                f"It will be removed recently, please use {__class__}.return_layernorm_output.",
                DeprecationWarning,
            )

        if self.fuse_qkv is not None:
            warnings.warn(
                f"{__class__}.fuse_qkv is deprecated. It will be removed recently. "
                f"Please use {__class__}.fuse_qkv_params as the new API.",
                DeprecationWarning,
            )

        assert self.output_layernorm is None, (
            f"{__class__}.output_layernorm is deprecated. It will be removed recently. "
            f"Please use {__class__}.input_layernorm for controlling whether to apply layernorm."
        )

        # Bug fix: default num_gqa_groups from the *resolved* head count, not the
        # deprecated `num_heads` field (which is None for new-API callers). This
        # matches TransformerLayer.__post_init__. Any num_heads alias has already
        # been folded into num_attention_heads above.
        if self.num_gqa_groups is None:
            self.num_gqa_groups = self.num_attention_heads
        super().__post_init__()

    def setup(self) -> None:
        """Validate the configuration and create the wrapped flax sublayer."""
        super().setup()

        assert self.head_dim > 0, f"{self.head_dim=}"
        assert self.num_attention_heads > 0, f"{self.num_attention_heads=}"

        mha_cls = partial(
            flax_MultiHeadAttention,
            dtype=self.dtype,
            head_dim=self.head_dim,
            num_attention_heads=self.num_attention_heads,
            num_gqa_groups=self.num_gqa_groups,
            attention_dropout=self.attention_dropout,
            dropout_rng_name=self.dropout_rng_name,
            input_layernorm=self.input_layernorm,
            layernorm_type=self.layernorm_type,
            layernorm_epsilon=self.layernorm_epsilon,
            zero_centered_gamma=self.zero_centered_gamma,
            return_layernorm_output=self.return_layernorm_output,
            kernel_init=TransformerEngineBaseLayer.generate_params_init("kernel", self.params_init),
            use_bias=self.use_bias,
            bias_init=TransformerEngineBaseLayer.generate_params_init("bias", self.bias_init),
            attn_mask_type=self.attn_mask_type,
            attn_bias_type=self.attn_bias_type,
            enable_rotary_pos_emb=self.enable_rotary_pos_emb,
            rotary_pos_emb_windows=self.rotary_pos_emb_windows,
            rotary_pos_emb_group_method=self.rotary_pos_emb_group_method,
            low_rank_adaptation_scope=self.low_rank_adaptation_scope,
            low_rank_adaptation_dim=self.low_rank_adaptation_dim,
            low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
            fuse_qkv_params=self.fuse_qkv_params,
            transpose_batch_sequence=self.transpose_batch_sequence,
            enable_sequence_parallel=self.enable_sequence_parallel,
            scale_attn_logits=self.scale_attn_logits,
            scaled_query_init=self.scaled_query_init,
            float32_logits=self.float32_logits,
        )

        self.create_layer("multi_head_attn", mha_cls)

    def __call__(
        self,
        inputs_q: JTensor,
        inputs_kv: JTensor,
        mask: Optional[JTensor] = None,
        bias: Optional[JTensor] = None,
        *,
        decode: bool = False,
        deterministic: bool = False,
    ) -> JTensor:
        """Run multi-head attention with inputs_q as queries and inputs_kv as keys/values."""
        return self.multi_head_attn(
            inputs_q, inputs_kv, mask, bias, decode=decode, deterministic=deterministic
        )
254
255
256
257
258
259
260
261


class TransformerLayer(TransformerEngineBaseLayer):
    """
    Praxis wrapper around the Transformer Engine flax ``TransformerLayer``
    module (attention + MLP, encoder or decoder variant), optionally with a
    T5-style relative position bias sublayer.
    """

    hidden_size: int = 512
    mlp_hidden_size: int = 2048
    num_attention_heads: int = 8
    # Number of key/value groups for grouped-query attention;
    # None is resolved to num_attention_heads in __post_init__.
    num_gqa_groups: Optional[int] = None
    layernorm_type: str = "layernorm"
    layernorm_epsilon: float = 1e-6
    zero_centered_gamma: bool = False
    hidden_dropout: float = 0.1
    hidden_dropout_dims: Sequence[int] = ()
    attention_dropout: float = 0.1
    intermediate_dropout: float = 0.1
    intermediate_dropout_dims: Sequence[int] = ()
    dropout_rng_name: str = "dropout"
    mlp_activations: Sequence[str] = ("relu",)
    use_bias: bool = False
    bias_init: WeightInit = WeightInit.Constant(0.0)
    apply_residual_connection_post_layernorm: bool = False
    output_layernorm: bool = False
    float32_attention_logits: bool = False
    layer_type: TransformerLayerType = TransformerLayerType.ENCODER
    self_attn_mask_type: str = "causal"
    self_attn_bias_type: Optional[str] = None
    enable_rotary_pos_emb: bool = False
    rotary_pos_emb_windows: Tuple[int, int] = (1, 10000)
    rotary_pos_emb_group_method: str = "consecutive"
    low_rank_adaptation_scope: str = "none"
    low_rank_adaptation_dim: int = 32
    # Annotation fixed: default is None, so the field is Optional.
    low_rank_adaptation_alpha: Optional[float] = None
    enable_relative_embedding: bool = True
    # Template config for the relative position bias sublayer; only used when
    # enable_relative_embedding is True.
    relative_embedding: pax_fiddle.Config[RelativePositionBiases] = pax_fiddle.template_field(None)
    drop_path: float = 0.0
    fuse_qkv_params: bool = True
    transpose_batch_sequence: bool = False
    enable_sequence_parallel: bool = False
    scale_attn_logits: bool = False
    scaled_query_init: bool = True

    def __post_init__(self):
        """Resolve defaults that depend on other fields."""
        if self.num_gqa_groups is None:
            self.num_gqa_groups = self.num_attention_heads
        super().__post_init__()

    def setup(self) -> None:
        """Create the wrapped flax transformer layer (and optional relative
        position bias module)."""
        super().setup()

        relative_embedding_flax_module = None
        if self.enable_relative_embedding and self.relative_embedding is not None:
            # Message fixed: was "shoule be" + missing space between the two
            # concatenated string literals ("...shoule bethe same as...").
            assert self.relative_embedding.num_attention_heads == self.num_attention_heads, (
                "TransformerLayer.relative_embedding.num_attention_heads should be "
                "the same as TransformerLayer.num_attention_heads."
            )

            embedding_init = RelativePositionBiases.generate_embedding_init(
                self.relative_embedding.embedding_init,
                self.relative_embedding.num_attention_heads,
                self.relative_embedding.num_buckets,
            )

            relative_embedding_flax_module = flax_RelativePositionBiases(
                num_buckets=self.relative_embedding.num_buckets,
                max_distance=self.relative_embedding.max_distance,
                num_attention_heads=self.relative_embedding.num_attention_heads,
                embedding_init=TransformerEngineBaseLayer.generate_params_init(
                    "rel_embedding", embedding_init
                ),
                embedding_axes=self.relative_embedding.embedding_axes,
                dtype=self.relative_embedding.dtype,
            )

        transformerlayer_cls = partial(
            flax_TransformerLayer,
            dtype=self.dtype,
            hidden_size=self.hidden_size,
            mlp_hidden_size=self.mlp_hidden_size,
            num_attention_heads=self.num_attention_heads,
            num_gqa_groups=self.num_gqa_groups,
            layernorm_type=self.layernorm_type,
            layernorm_epsilon=self.layernorm_epsilon,
            zero_centered_gamma=self.zero_centered_gamma,
            hidden_dropout=self.hidden_dropout,
            hidden_dropout_dims=self.hidden_dropout_dims,
            attention_dropout=self.attention_dropout,
            intermediate_dropout=self.intermediate_dropout,
            intermediate_dropout_dims=self.intermediate_dropout_dims,
            dropout_rng_name=self.dropout_rng_name,
            mha_kernel_init=TransformerEngineBaseLayer.generate_params_init(
                "mha_kernel", self.params_init
            ),
            mlp_kernel_init=TransformerEngineBaseLayer.generate_params_init(
                "mlp_kernel", self.params_init
            ),
            mlp_activations=self.mlp_activations,
            use_bias=self.use_bias,
            bias_init=TransformerEngineBaseLayer.generate_params_init("bias", self.bias_init),
            apply_residual_connection_post_layernorm=self.apply_residual_connection_post_layernorm,
            output_layernorm=self.output_layernorm,
            float32_attention_logits=self.float32_attention_logits,
            layer_type=self.layer_type,
            self_attn_mask_type=self.self_attn_mask_type,
            self_attn_bias_type=self.self_attn_bias_type,
            enable_rotary_pos_emb=self.enable_rotary_pos_emb,
            rotary_pos_emb_windows=self.rotary_pos_emb_windows,
            rotary_pos_emb_group_method=self.rotary_pos_emb_group_method,
            low_rank_adaptation_scope=self.low_rank_adaptation_scope,
            low_rank_adaptation_dim=self.low_rank_adaptation_dim,
            low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
            enable_relative_embedding=self.enable_relative_embedding,
            relative_embedding=relative_embedding_flax_module,
            drop_path=self.drop_path,
            fuse_qkv_params=self.fuse_qkv_params,
            transpose_batch_sequence=self.transpose_batch_sequence,
            enable_sequence_parallel=self.enable_sequence_parallel,
            scale_attn_logits=self.scale_attn_logits,
            scaled_query_init=self.scaled_query_init,
        )

        self.create_layer("transformerlayer", transformerlayer_cls)

    def __call__(
        self,
        inputs: JTensor,
        encoded: Optional[JTensor] = None,
        attention_mask: Optional[JTensor] = None,
        encoder_decoder_mask: Optional[JTensor] = None,
        deterministic: bool = False,
        decode: bool = False,
        # Annotation fixed: was `bool = None`; this is a decode-length cap.
        max_decode_length: Optional[int] = None,
    ) -> JTensor:
        """Apply the transformer layer to `inputs` (with `encoded` as the
        cross-attention source for decoder layers)."""
        return self.transformerlayer(
            inputs,
            encoded,
            attention_mask,
            encoder_decoder_mask,
            deterministic,
            decode,
            max_decode_length,
        )