# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Transformer-related Praxis modules
"""
from dataclasses import field
from functools import partial
from typing import Optional, Sequence, Tuple
import warnings

from praxis import pax_fiddle
from praxis.base_layer import WeightInit
from praxis.pytypes import JTensor

from .module import TransformerEngineBaseLayer
from ..flax.transformer import TransformerLayerType
from ..flax.transformer import DotProductAttention as flax_DotProductAttention
from ..flax.transformer import MultiHeadAttention as flax_MultiHeadAttention
from ..flax.transformer import RelativePositionBiases as flax_RelativePositionBiases
from ..flax.transformer import TransformerLayer as flax_TransformerLayer
from ..attention import AttnBiasType, AttnMaskType


class RelativePositionBiases(TransformerEngineBaseLayer):
    """RelativePositionBiases"""

    num_buckets: int = 32
    max_distance: int = 128
    num_attention_heads: int = 64
    embedding_init: Optional[WeightInit] = None
    embedding_axes: Tuple[str, ...] = ()

    @staticmethod
    def generate_embedding_init(init, num_attention_heads, num_buckets):
        """generate_embedding_init"""
        embedding_init = init
        if embedding_init is None:
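            # Default: Gaussian init with stddev (num_attention_heads * num_buckets) ** -0.5,
            # i.e. the inverse square root of the relative-embedding table size
            # (e.g. 64 heads x 32 buckets gives a stddev of about 0.022).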
            rb_stddev = (num_attention_heads * num_buckets) ** -0.5
            embedding_init = WeightInit.Gaussian(rb_stddev)
        return embedding_init

    def setup(self) -> None:
        """setup"""
        super().setup()

        embedding_init = RelativePositionBiases.generate_embedding_init(
            self.embedding_init, self.num_attention_heads, self.num_buckets
        )

        rpb_cls = partial(
            flax_RelativePositionBiases,
            num_buckets=self.num_buckets,
            max_distance=self.max_distance,
            num_attention_heads=self.num_attention_heads,
            embedding_init=TransformerEngineBaseLayer.generate_params_init(
                "rel_embedding", embedding_init
            ),
            embedding_axes=self.embedding_axes,
            dtype=self.dtype,
        )

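        # create_layer registers the configured flax module as a Praxis child layer,
        # exposed here as ``self.relative_position_bias``.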
        self.create_layer("relative_position_bias", rpb_cls)

    def __call__(self, q_seqlen: JTensor, k_seqlen: JTensor, bidirectional: bool = True) -> JTensor:
        """__call__"""
        return self.relative_position_bias(q_seqlen, k_seqlen, bidirectional)
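
# Example usage -- a minimal sketch, assuming the standard Praxis flow of building
# a pax_fiddle.Config template and instantiating it (names and sizes below are
# illustrative, not part of this module):
#
#   import jax
#   from praxis import base_layer
#
#   rpb_tpl = pax_fiddle.Config(
#       RelativePositionBiases,
#       name="rel_pos_bias",
#       num_buckets=32,
#       max_distance=128,
#       num_attention_heads=16,
#   )
#   rpb = base_layer.instantiate(rpb_tpl)
#   variables = rpb.init(jax.random.PRNGKey(0), 128, 128, bidirectional=True)
#   bias = rpb.apply(variables, 128, 128, bidirectional=True)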


class DotProductAttention(TransformerEngineBaseLayer):
    """DotProductAttention"""

    head_dim: int = 0
    num_attention_heads: int = 0
    num_gqa_groups: Optional[int] = None
    attention_dropout: float = 0.0
    attn_mask_type: AttnMaskType = "causal"
    attn_bias_type: Optional[AttnBiasType] = None
    dropout_rng_name: str = "dropout"
    float32_logits: bool = False
    qkv_layout: str = "bshd_bshd_bshd"
    scale_factor: Optional[float] = None
    transpose_batch_sequence: bool = True
    window_size: Optional[Tuple[int, int]] = None

    def setup(self) -> None:
        """setup"""
        super().setup()

        assert self.head_dim > 0, f"{self.head_dim=}"
        assert self.num_attention_heads > 0, f"{self.num_attention_heads=}"

        dpa_cls = partial(
            flax_DotProductAttention,
            head_dim=self.head_dim,
            num_attention_heads=self.num_attention_heads,
            num_gqa_groups=self.num_gqa_groups,
            attn_mask_type=self.attn_mask_type,
            attn_bias_type=self.attn_bias_type,
            attention_dropout=self.attention_dropout,
            dtype=self.dtype,
            dropout_rng_name=self.dropout_rng_name,
            float32_logits=self.float32_logits,
            qkv_layout=self.qkv_layout,
            scale_factor=self.scale_factor,
            transpose_batch_sequence=self.transpose_batch_sequence,
            window_size=self.window_size,
        )

        self.create_layer("dot_product_attention", dpa_cls)

    def __call__(
        self,
        query: JTensor,
        key: JTensor,
        value: JTensor,
        mask: Optional[JTensor] = None,
        bias: Optional[JTensor] = None,
        *,
        deterministic: bool = False,
    ) -> JTensor:
        """Apply dot-product attention over the query, key, and value tensors."""
        return self.dot_product_attention(
            query, key, value, mask, bias, deterministic=deterministic
        )
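
# Example usage -- a minimal sketch under the same Praxis assumptions; with the
# default transpose_batch_sequence=True, q/k/v are laid out [seq, batch, heads, dim]:
#
#   import jax
#   import jax.numpy as jnp
#   from praxis import base_layer
#
#   dpa_tpl = pax_fiddle.Config(
#       DotProductAttention,
#       name="dpa",
#       head_dim=64,
#       num_attention_heads=16,
#   )
#   dpa = base_layer.instantiate(dpa_tpl)
#   q = k = v = jnp.zeros((128, 4, 16, 64), jnp.float32)  # [s, b, h, d]
#   variables = dpa.init(jax.random.PRNGKey(0), q, k, v)
#   out = dpa.apply(variables, q, k, v, deterministic=True)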


class MultiHeadAttention(TransformerEngineBaseLayer):
    """MultiHeadAttention"""

    head_dim: int = 0
    num_attention_heads: int = 0
    num_gqa_groups: Optional[int] = None
    attention_dropout: float = 0.0
    dropout_rng_name: str = "dropout"
    input_layernorm: bool = True
    layernorm_type: str = "layernorm"
    layernorm_epsilon: float = 1e-6
    zero_centered_gamma: bool = False
    return_layernorm_output: bool = False
    use_bias: bool = False
    bias_init: WeightInit = field(  # pylint: disable=invalid-field-call
        default_factory=partial(WeightInit.Constant, scale=0.0)
    )
    attn_mask_type: str = "causal"
    attn_bias_type: Optional[str] = None
    enable_rotary_pos_emb: bool = False
    rotary_pos_emb_windows: Tuple[int, int] = (1, 10000)
    rotary_pos_emb_group_method: str = "consecutive"
    low_rank_adaptation_scope: str = "none"
    low_rank_adaptation_dim: int = 32
    low_rank_adaptation_alpha: Optional[float] = None
    fuse_qkv_params: bool = True
    transpose_batch_sequence: bool = True
    enable_sequence_parallel: bool = False
    scale_attn_logits: bool = False
    scaled_query_init: bool = True
    float32_logits: bool = False
    window_size: Optional[Tuple[int, int]] = None

    # Deprecated parameters
    num_heads: Optional[int] = None
    dropout_rate: Optional[float] = None
    output_layernorm: Optional[bool] = None
    apply_residual_connection_post_layernorm: Optional[bool] = None
    fuse_qkv: Optional[bool] = None

    def __post_init__(self):
        # Deal with the deprecated parameters
        if self.num_heads is not None:
            self.num_attention_heads = self.num_heads
            warnings.warn(
                f"{__class__}.num_heads is deprecated. It will be removed in a future release. "
                f"Please use {__class__}.num_attention_heads instead.",
                DeprecationWarning,
            )

        if self.dropout_rate is not None:
            self.attention_dropout = self.dropout_rate
            warnings.warn(
                f"{__class__}.dropout_rate is deprecated. It will be removed in a future release. "
                f"Please use {__class__}.attention_dropout instead.",
                DeprecationWarning,
            )

        if self.apply_residual_connection_post_layernorm is not None:
            warnings.warn(
                f"{__class__}.apply_residual_connection_post_layernorm is deprecated. It will be "
                f"removed in a future release. Please use {__class__}.return_layernorm_output instead.",
                DeprecationWarning,
            )

        if self.fuse_qkv is not None:
            warnings.warn(
                f"{__class__}.fuse_qkv is deprecated. It will be removed in a future release. "
                f"Please use {__class__}.fuse_qkv_params instead.",
                DeprecationWarning,
            )

        assert self.output_layernorm is None, (
            f"{__class__}.output_layernorm is deprecated. It will be removed in a future release. "
            f"Please use {__class__}.input_layernorm to control whether layernorm is applied."
        )

        if self.num_gqa_groups is None:
            # Fall back to num_attention_heads: the deprecated num_heads may be None
            # here, and any num_heads value was already copied into num_attention_heads.
            self.num_gqa_groups = self.num_attention_heads
        super().__post_init__()

    def setup(self) -> None:
        """setup"""
        super().setup()

        assert self.head_dim > 0, f"{self.head_dim=}"
        assert self.num_attention_heads > 0, f"{self.num_attention_heads=}"

        mha_cls = partial(
            flax_MultiHeadAttention,
            dtype=self.dtype,
            head_dim=self.head_dim,
            num_attention_heads=self.num_attention_heads,
            num_gqa_groups=self.num_gqa_groups,
            attention_dropout=self.attention_dropout,
            dropout_rng_name=self.dropout_rng_name,
            input_layernorm=self.input_layernorm,
            layernorm_type=self.layernorm_type,
            layernorm_epsilon=self.layernorm_epsilon,
            zero_centered_gamma=self.zero_centered_gamma,
            return_layernorm_output=self.return_layernorm_output,
            kernel_init=TransformerEngineBaseLayer.generate_params_init("kernel", self.params_init),
            use_bias=self.use_bias,
            bias_init=TransformerEngineBaseLayer.generate_params_init("bias", self.bias_init),
            attn_mask_type=self.attn_mask_type,
            attn_bias_type=self.attn_bias_type,
            enable_rotary_pos_emb=self.enable_rotary_pos_emb,
            rotary_pos_emb_windows=self.rotary_pos_emb_windows,
            rotary_pos_emb_group_method=self.rotary_pos_emb_group_method,
            low_rank_adaptation_scope=self.low_rank_adaptation_scope,
            low_rank_adaptation_dim=self.low_rank_adaptation_dim,
            low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
            fuse_qkv_params=self.fuse_qkv_params,
            transpose_batch_sequence=self.transpose_batch_sequence,
            enable_sequence_parallel=self.enable_sequence_parallel,
            scale_attn_logits=self.scale_attn_logits,
            scaled_query_init=self.scaled_query_init,
            float32_logits=self.float32_logits,
            window_size=self.window_size,
        )

        self.create_layer("multi_head_attn", mha_cls)

    def __call__(
        self,
        inputs_q: JTensor,
        inputs_kv: JTensor,
        mask: Optional[JTensor] = None,
        bias: Optional[JTensor] = None,
        *,
        decode: bool = False,
        deterministic: bool = False,
    ) -> JTensor:
        """Apply multi-head attention to the query input and key/value input."""
        return self.multi_head_attn(
            inputs_q, inputs_kv, mask, bias, decode=decode, deterministic=deterministic
        )
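
# Example usage -- a minimal sketch under the same Praxis assumptions; with the
# default transpose_batch_sequence=True, inputs are [seq, batch, hidden]:
#
#   import jax
#   import jax.numpy as jnp
#   from praxis import base_layer
#
#   mha_tpl = pax_fiddle.Config(
#       MultiHeadAttention,
#       name="mha",
#       head_dim=64,
#       num_attention_heads=16,
#   )
#   mha = base_layer.instantiate(mha_tpl)
#   x = jnp.zeros((128, 4, 1024), jnp.float32)  # self-attention: inputs_q == inputs_kv
#   variables = mha.init(jax.random.PRNGKey(0), x, x)
#   out = mha.apply(variables, x, x, deterministic=True)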


class TransformerLayer(TransformerEngineBaseLayer):
    """TransformerLayer"""

    hidden_size: int = 512
    mlp_hidden_size: int = 2048
    num_attention_heads: int = 8
    num_gqa_groups: Optional[int] = None
    layernorm_type: str = "layernorm"
    layernorm_epsilon: float = 1e-6
    zero_centered_gamma: bool = False
    hidden_dropout: float = 0.1
    hidden_dropout_dims: Sequence[int] = ()
    attention_dropout: float = 0.1
    intermediate_dropout: float = 0.1
    intermediate_dropout_dims: Sequence[int] = ()
    dropout_rng_name: str = "dropout"
    mlp_activations: Sequence[str] = ("relu",)
    use_bias: bool = False
    bias_init: WeightInit = field(  # pylint: disable=invalid-field-call
        default_factory=partial(WeightInit.Constant, scale=0.0)
    )
    apply_residual_connection_post_layernorm: bool = False
    output_layernorm: bool = False
    float32_attention_logits: bool = False
    layer_type: TransformerLayerType = TransformerLayerType.ENCODER
    self_attn_mask_type: str = "causal"
    self_attn_bias_type: Optional[str] = None
    enable_rotary_pos_emb: bool = False
    rotary_pos_emb_windows: Tuple[int, int] = (1, 10000)
    rotary_pos_emb_group_method: str = "consecutive"
    low_rank_adaptation_scope: str = "none"
    low_rank_adaptation_dim: int = 32
    low_rank_adaptation_alpha: Optional[float] = None
    enable_relative_embedding: bool = True
    relative_embedding: pax_fiddle.Config[RelativePositionBiases] = pax_fiddle.template_field(None)
    drop_path: float = 0.0
    fuse_qkv_params: bool = True
    transpose_batch_sequence: bool = False
    enable_sequence_parallel: bool = False
    scale_attn_logits: bool = False
    scaled_query_init: bool = True
    window_size: Optional[Tuple[int, int]] = None

    def __post_init__(self):
        if self.num_gqa_groups is None:
            self.num_gqa_groups = self.num_attention_heads
        super().__post_init__()

    def setup(self) -> None:
        """setup"""
        super().setup()

        relative_embedding_flax_module = None
        if self.enable_relative_embedding and self.relative_embedding is not None:
            assert self.relative_embedding.num_attention_heads == self.num_attention_heads, (
                "TransformerLayer.relative_embedding.num_attention_heads should be "
                "the same as TransformerLayer.num_attention_heads."
            )

            embedding_init = RelativePositionBiases.generate_embedding_init(
                self.relative_embedding.embedding_init,
                self.relative_embedding.num_attention_heads,
                self.relative_embedding.num_buckets,
            )

            relative_embedding_flax_module = flax_RelativePositionBiases(
                num_buckets=self.relative_embedding.num_buckets,
                max_distance=self.relative_embedding.max_distance,
                num_attention_heads=self.relative_embedding.num_attention_heads,
                embedding_init=TransformerEngineBaseLayer.generate_params_init(
                    "rel_embedding", embedding_init
                ),
                embedding_axes=self.relative_embedding.embedding_axes,
                dtype=self.relative_embedding.dtype,
            )

        transformerlayer_cls = partial(
            flax_TransformerLayer,
            dtype=self.dtype,
            hidden_size=self.hidden_size,
            mlp_hidden_size=self.mlp_hidden_size,
            num_attention_heads=self.num_attention_heads,
            num_gqa_groups=self.num_gqa_groups,
            layernorm_type=self.layernorm_type,
            layernorm_epsilon=self.layernorm_epsilon,
            zero_centered_gamma=self.zero_centered_gamma,
            hidden_dropout=self.hidden_dropout,
            hidden_dropout_dims=self.hidden_dropout_dims,
            attention_dropout=self.attention_dropout,
            intermediate_dropout=self.intermediate_dropout,
            intermediate_dropout_dims=self.intermediate_dropout_dims,
            dropout_rng_name=self.dropout_rng_name,
            mha_kernel_init=TransformerEngineBaseLayer.generate_params_init(
                "mha_kernel", self.params_init
            ),
            mlp_kernel_init=TransformerEngineBaseLayer.generate_params_init(
                "mlp_kernel", self.params_init
            ),
            mlp_activations=self.mlp_activations,
            use_bias=self.use_bias,
            bias_init=TransformerEngineBaseLayer.generate_params_init("bias", self.bias_init),
            apply_residual_connection_post_layernorm=self.apply_residual_connection_post_layernorm,
            output_layernorm=self.output_layernorm,
            float32_attention_logits=self.float32_attention_logits,
            layer_type=self.layer_type,
            self_attn_mask_type=self.self_attn_mask_type,
            self_attn_bias_type=self.self_attn_bias_type,
            enable_rotary_pos_emb=self.enable_rotary_pos_emb,
            rotary_pos_emb_windows=self.rotary_pos_emb_windows,
            rotary_pos_emb_group_method=self.rotary_pos_emb_group_method,
            low_rank_adaptation_scope=self.low_rank_adaptation_scope,
            low_rank_adaptation_dim=self.low_rank_adaptation_dim,
            low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
            enable_relative_embedding=self.enable_relative_embedding,
            relative_embedding=relative_embedding_flax_module,
            drop_path=self.drop_path,
            fuse_qkv_params=self.fuse_qkv_params,
            transpose_batch_sequence=self.transpose_batch_sequence,
            enable_sequence_parallel=self.enable_sequence_parallel,
            scale_attn_logits=self.scale_attn_logits,
            scaled_query_init=self.scaled_query_init,
            window_size=self.window_size,
        )

        self.create_layer("transformerlayer", transformerlayer_cls)

    def __call__(
        self,
        inputs: JTensor,
        encoded: Optional[JTensor] = None,
        attention_mask: Optional[JTensor] = None,
        encoder_decoder_mask: Optional[JTensor] = None,
        deterministic: bool = False,
        decode: bool = False,
        max_decode_length: Optional[int] = None,
    ) -> JTensor:
        """Apply the transformer layer to ``inputs`` (and ``encoded`` for decoder layers)."""
        return self.transformerlayer(
            inputs,
            encoded,
            attention_mask,
            encoder_decoder_mask,
            deterministic,
            decode,
            max_decode_length,
        )
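

# Example usage -- a minimal sketch of an encoder layer under the same Praxis
# assumptions; transpose_batch_sequence defaults to False here, so inputs are
# [batch, seq, hidden]:
#
#   import jax
#   import jax.numpy as jnp
#   from praxis import base_layer
#
#   layer_tpl = pax_fiddle.Config(
#       TransformerLayer,
#       name="encoder_layer",
#       hidden_size=512,
#       mlp_hidden_size=2048,
#       num_attention_heads=8,
#       layer_type=TransformerLayerType.ENCODER,
#   )
#   layer = base_layer.instantiate(layer_tpl)
#   x = jnp.zeros((4, 128, 512), jnp.float32)
#   variables = layer.init(jax.random.PRNGKey(0), x, deterministic=True)
#   out = layer.apply(variables, x, deterministic=True)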