# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Praxis Modules related to Transformer
"""
from functools import partial
from typing import Optional, Sequence, Tuple
import warnings

from praxis import pax_fiddle
from praxis.base_layer import WeightInit
from praxis.pytypes import JTensor

from .module import TransformerEngineBaseLayer
from ..flax.transformer import TransformerLayerType
from ..flax.transformer import DotProductAttention as flax_DotProductAttention
from ..flax.transformer import MultiHeadAttention as flax_MultiHeadAttention
from ..flax.transformer import RelativePositionBiases as flax_RelativePositionBiases
from ..flax.transformer import TransformerLayer as flax_TransformerLayer
from ..attention import AttnBiasType, AttnMaskType


class RelativePositionBiases(TransformerEngineBaseLayer):
    """RelativePositionBiases"""

    num_buckets: int = 32
    max_distance: int = 128
    num_attention_heads: int = 64
    embedding_init: Optional[WeightInit] = None
    embedding_axes: Tuple[str, ...] = ()

    @staticmethod
    def generate_embedding_init(init, num_attention_heads, num_buckets):
        """generate_embedding_init"""
        embedding_init = init
        if embedding_init is None:
            rb_stddev = (num_attention_heads * num_buckets) ** -0.5
            embedding_init = WeightInit.Gaussian(rb_stddev)
        return embedding_init

    def setup(self) -> None:
        """setup"""
        super().setup()

        embedding_init = RelativePositionBiases.generate_embedding_init(
            self.embedding_init, self.num_attention_heads, self.num_buckets
        )

        rpb_cls = partial(
            flax_RelativePositionBiases,
            num_buckets=self.num_buckets,
            max_distance=self.max_distance,
            num_attention_heads=self.num_attention_heads,
            embedding_init=TransformerEngineBaseLayer.generate_params_init(
                "rel_embedding", embedding_init
            ),
            embedding_axes=self.embedding_axes,
            dtype=self.dtype,
        )

        self.create_layer("relative_position_bias", rpb_cls)

    def __call__(self, q_seqlen: JTensor, k_seqlen: JTensor, bidirectional: bool = True) -> JTensor:
        """__call__"""
        return self.relative_position_bias(q_seqlen, k_seqlen, bidirectional)
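
# Usage sketch (comment only, not executed here): this assumes the usual Praxis
# flow of building a pax_fiddle.Config and instantiating it with
# praxis.base_layer.instantiate; the template name, PRNG key, and sequence
# lengths below are illustrative, not part of this module.
#
#   rpb_tpl = pax_fiddle.Config(
#       RelativePositionBiases, name="relpos_bias",
#       num_buckets=32, max_distance=128, num_attention_heads=16)
#   rpb = base_layer.instantiate(rpb_tpl)
#   variables = rpb.init(jax.random.PRNGKey(0), 128, 128, bidirectional=True)
#   bias = rpb.apply(variables, 128, 128, bidirectional=True)  # per-head relative bias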


class DotProductAttention(TransformerEngineBaseLayer):
    """DotProductAttention"""

    head_dim: int = 0
    num_attention_heads: int = 0
    num_gqa_groups: Optional[int] = None
    attention_dropout: float = 0.0
    attn_mask_type: AttnMaskType = "causal"
    attn_bias_type: Optional[AttnBiasType] = None
    dropout_rng_name: str = "dropout"
    float32_logits: bool = False
    qkv_layout: str = "bshd_bshd_bshd"
    scale_factor: Optional[float] = None
    transpose_batch_sequence: bool = True
    window_size: Optional[Tuple[int, int]] = None

    def setup(self) -> None:
        """setup"""
        super().setup()

        assert self.head_dim > 0, f"{self.head_dim=}"
        assert self.num_attention_heads > 0, f"{self.num_attention_heads=}"

        dpa_cls = partial(
            flax_DotProductAttention,
            head_dim=self.head_dim,
            num_attention_heads=self.num_attention_heads,
            num_gqa_groups=self.num_gqa_groups,
            attn_mask_type=self.attn_mask_type,
            attn_bias_type=self.attn_bias_type,
            attention_dropout=self.attention_dropout,
            dtype=self.dtype,
            dropout_rng_name=self.dropout_rng_name,
            float32_logits=self.float32_logits,
            qkv_layout=self.qkv_layout,
            scale_factor=self.scale_factor,
            transpose_batch_sequence=self.transpose_batch_sequence,
            window_size=self.window_size,
        )

        self.create_layer("dot_product_attention", dpa_cls)

    def __call__(
        self,
        query: JTensor,
        key: JTensor,
        value: JTensor,
        mask: Optional[JTensor] = None,
        bias: Optional[JTensor] = None,
        *,
        deterministic: bool = False,
    ) -> JTensor:
        """__call__"""
        return self.dot_product_attention(
            query, key, value, mask, bias, deterministic=deterministic
        )
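
# Usage sketch (comment only, not executed here): `q`, `k`, `v` are hypothetical
# arrays whose layout must match `qkv_layout` and `transpose_batch_sequence`
# above; instantiation follows the same assumed Praxis flow as the sketch for
# RelativePositionBiases.
#
#   dpa_tpl = pax_fiddle.Config(
#       DotProductAttention, name="dpa",
#       head_dim=64, num_attention_heads=16, attn_mask_type="causal")
#   dpa = base_layer.instantiate(dpa_tpl)
#   variables = dpa.init(jax.random.PRNGKey(0), q, k, v, deterministic=True)
#   out = dpa.apply(variables, q, k, v, deterministic=True)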


class MultiHeadAttention(TransformerEngineBaseLayer):
    """MultiHeadAttention"""

    head_dim: int = 0
    num_attention_heads: int = 0
    num_gqa_groups: Optional[int] = None
    attention_dropout: float = 0.0
    dropout_rng_name: str = "dropout"
    input_layernorm: bool = True
    layernorm_type: str = "layernorm"
    layernorm_epsilon: float = 1e-6
    zero_centered_gamma: bool = False
    return_layernorm_output: bool = False
    use_bias: bool = False
    bias_init: WeightInit = WeightInit.Constant(0.0)
    attn_mask_type: str = "causal"
    attn_bias_type: Optional[str] = None
    enable_rotary_pos_emb: bool = False
    rotary_pos_emb_windows: Tuple[int, int] = (1, 10000)
    rotary_pos_emb_group_method: str = "consecutive"
    low_rank_adaptation_scope: str = "none"
    low_rank_adaptation_dim: int = 32
    low_rank_adaptation_alpha: Optional[float] = None
    fuse_qkv_params: bool = True
    transpose_batch_sequence: bool = True
    enable_sequence_parallel: bool = False
    scale_attn_logits: bool = False
    scaled_query_init: bool = True
    float32_logits: bool = False
    window_size: Optional[Tuple[int, int]] = None

    # Deprecated parameters
    num_heads: Optional[int] = None
    dropout_rate: Optional[float] = None
    output_layernorm: Optional[bool] = None
    apply_residual_connection_post_layernorm: Optional[bool] = None
    fuse_qkv: Optional[bool] = None

    def __post_init__(self):
        # Deal with the deprecated parameters
        if self.num_heads is not None:
            self.num_attention_heads = self.num_heads
            warnings.warn(
                f"{__class__}.num_heads is deprecated. It will be removed recently. "
171
172
173
                f"Please uses {__class__}.num_attention_heads as the new API.",
                DeprecationWarning,
            )
        if self.dropout_rate is not None:
            self.attention_dropout = self.dropout_rate
            warnings.warn(
                f"{__class__}.dropout_rate is deprecated. It will be removed recently. "
178
179
180
                f"Please use {__class__}.attention_dropout as the new API.",
                DeprecationWarning,
            )
        if self.apply_residual_connection_post_layernorm is not None:
            warnings.warn(
                f"{__class__}.apply_residual_connection_post_layernorm is deprecated. "
                f"It will be removed recently, please use {__class__}.return_layernorm_output.",
185
186
                DeprecationWarning,
            )
        if self.fuse_qkv is not None:
            warnings.warn(
                f"{__class__}.fuse_qkv is deprecated. It will be removed recently. "
190
191
192
                f"Please use {__class__}.fuse_qkv_params as the new API.",
                DeprecationWarning,
            )
        assert self.output_layernorm is None, (
            f"{__class__}.output_layernorm is deprecated. It will be removed recently. "
195
196
            f"Please use {__class__}.input_layernorm for controlling whether to apply layernorm."
        )

        if self.num_gqa_groups is None:
            self.num_gqa_groups = self.num_attention_heads
        super().__post_init__()

    def setup(self) -> None:
        """setup"""
        super().setup()

        assert self.head_dim > 0, f"{self.head_dim=}"
        assert self.num_attention_heads > 0, f"{self.num_attention_heads=}"

        mha_cls = partial(
            flax_MultiHeadAttention,
            dtype=self.dtype,
            head_dim=self.head_dim,
            num_attention_heads=self.num_attention_heads,
            num_gqa_groups=self.num_gqa_groups,
            attention_dropout=self.attention_dropout,
            dropout_rng_name=self.dropout_rng_name,
            input_layernorm=self.input_layernorm,
            layernorm_type=self.layernorm_type,
            layernorm_epsilon=self.layernorm_epsilon,
            zero_centered_gamma=self.zero_centered_gamma,
            return_layernorm_output=self.return_layernorm_output,
            kernel_init=TransformerEngineBaseLayer.generate_params_init("kernel", self.params_init),
            use_bias=self.use_bias,
            bias_init=TransformerEngineBaseLayer.generate_params_init("bias", self.bias_init),
            attn_mask_type=self.attn_mask_type,
            attn_bias_type=self.attn_bias_type,
            enable_rotary_pos_emb=self.enable_rotary_pos_emb,
            rotary_pos_emb_windows=self.rotary_pos_emb_windows,
            rotary_pos_emb_group_method=self.rotary_pos_emb_group_method,
            low_rank_adaptation_scope=self.low_rank_adaptation_scope,
            low_rank_adaptation_dim=self.low_rank_adaptation_dim,
            low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
            fuse_qkv_params=self.fuse_qkv_params,
            transpose_batch_sequence=self.transpose_batch_sequence,
            enable_sequence_parallel=self.enable_sequence_parallel,
            scale_attn_logits=self.scale_attn_logits,
            scaled_query_init=self.scaled_query_init,
            float32_logits=self.float32_logits,
            window_size=self.window_size,
        )

        self.create_layer("multi_head_attn", mha_cls)

    def __call__(
        self,
        inputs_q: JTensor,
        inputs_kv: JTensor,
        mask: Optional[JTensor] = None,
        bias: Optional[JTensor] = None,
        *,
        decode: bool = False,
        deterministic: bool = False,
    ) -> JTensor:
        """__call__"""
        return self.multi_head_attn(
            inputs_q, inputs_kv, mask, bias, decode=decode, deterministic=deterministic
        )
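
# Usage sketch (comment only, not executed here): for self-attention pass the
# same tensor as inputs_q and inputs_kv. `x` and the sizes below are
# illustrative; with transpose_batch_sequence=True inputs are sequence-first.
#
#   mha_tpl = pax_fiddle.Config(
#       MultiHeadAttention, name="mha",
#       head_dim=64, num_attention_heads=16)
#   mha = base_layer.instantiate(mha_tpl)
#   variables = mha.init(jax.random.PRNGKey(0), x, x, deterministic=True)
#   out = mha.apply(variables, x, x, deterministic=True)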


class TransformerLayer(TransformerEngineBaseLayer):
    """TransformerLayer"""

    hidden_size: int = 512
    mlp_hidden_size: int = 2048
    num_attention_heads: int = 8
    num_gqa_groups: Optional[int] = None
    layernorm_type: str = "layernorm"
    layernorm_epsilon: float = 1e-6
    zero_centered_gamma: bool = False
    hidden_dropout: float = 0.1
    hidden_dropout_dims: Sequence[int] = ()
    attention_dropout: float = 0.1
    intermediate_dropout: float = 0.1
    intermediate_dropout_dims: Sequence[int] = ()
    dropout_rng_name: str = "dropout"
    mlp_activations: Sequence[str] = ("relu",)
    use_bias: bool = False
    bias_init: WeightInit = WeightInit.Constant(0.0)
    apply_residual_connection_post_layernorm: bool = False
    output_layernorm: bool = False
    float32_attention_logits: bool = False
    layer_type: TransformerLayerType = TransformerLayerType.ENCODER
    self_attn_mask_type: str = "causal"
    self_attn_bias_type: Optional[str] = None
    enable_rotary_pos_emb: bool = False
    rotary_pos_emb_windows: Tuple[int, int] = (1, 10000)
    rotary_pos_emb_group_method: str = "consecutive"
    low_rank_adaptation_scope: str = "none"
    low_rank_adaptation_dim: int = 32
    low_rank_adaptation_alpha: Optional[float] = None
    enable_relative_embedding: bool = True
    relative_embedding: pax_fiddle.Config[RelativePositionBiases] = pax_fiddle.template_field(None)
    drop_path: float = 0.0
    fuse_qkv_params: bool = True
    transpose_batch_sequence: bool = False
    enable_sequence_parallel: bool = False
    scale_attn_logits: bool = False
    scaled_query_init: bool = True
    window_size: Optional[Tuple[int, int]] = None

    def __post_init__(self):
        if self.num_gqa_groups is None:
            self.num_gqa_groups = self.num_attention_heads
        super().__post_init__()

    def setup(self) -> None:
        """setup"""
        super().setup()

        relative_embedding_flax_module = None
        if self.enable_relative_embedding and self.relative_embedding is not None:
            assert self.relative_embedding.num_attention_heads == self.num_attention_heads, (
                "TransformerLayer.relative_embedding.num_attention_heads should be "
                "the same as TransformerLayer.num_attention_heads."
            )

            embedding_init = RelativePositionBiases.generate_embedding_init(
                self.relative_embedding.embedding_init,
                self.relative_embedding.num_attention_heads,
                self.relative_embedding.num_buckets,
            )

            relative_embedding_flax_module = flax_RelativePositionBiases(
                num_buckets=self.relative_embedding.num_buckets,
                max_distance=self.relative_embedding.max_distance,
                num_attention_heads=self.relative_embedding.num_attention_heads,
                embedding_init=TransformerEngineBaseLayer.generate_params_init(
                    "rel_embedding", embedding_init
                ),
                embedding_axes=self.relative_embedding.embedding_axes,
                dtype=self.relative_embedding.dtype,
            )

        transformerlayer_cls = partial(
            flax_TransformerLayer,
            dtype=self.dtype,
            hidden_size=self.hidden_size,
            mlp_hidden_size=self.mlp_hidden_size,
            num_attention_heads=self.num_attention_heads,
            num_gqa_groups=self.num_gqa_groups,
            layernorm_type=self.layernorm_type,
            layernorm_epsilon=self.layernorm_epsilon,
            zero_centered_gamma=self.zero_centered_gamma,
            hidden_dropout=self.hidden_dropout,
            hidden_dropout_dims=self.hidden_dropout_dims,
            attention_dropout=self.attention_dropout,
            intermediate_dropout=self.intermediate_dropout,
            intermediate_dropout_dims=self.intermediate_dropout_dims,
            dropout_rng_name=self.dropout_rng_name,
            mha_kernel_init=TransformerEngineBaseLayer.generate_params_init(
                "mha_kernel", self.params_init
            ),
            mlp_kernel_init=TransformerEngineBaseLayer.generate_params_init(
                "mlp_kernel", self.params_init
            ),
            mlp_activations=self.mlp_activations,
            use_bias=self.use_bias,
            bias_init=TransformerEngineBaseLayer.generate_params_init("bias", self.bias_init),
            apply_residual_connection_post_layernorm=self.apply_residual_connection_post_layernorm,
            output_layernorm=self.output_layernorm,
            float32_attention_logits=self.float32_attention_logits,
            layer_type=self.layer_type,
            self_attn_mask_type=self.self_attn_mask_type,
            self_attn_bias_type=self.self_attn_bias_type,
            enable_rotary_pos_emb=self.enable_rotary_pos_emb,
            rotary_pos_emb_windows=self.rotary_pos_emb_windows,
            rotary_pos_emb_group_method=self.rotary_pos_emb_group_method,
            low_rank_adaptation_scope=self.low_rank_adaptation_scope,
            low_rank_adaptation_dim=self.low_rank_adaptation_dim,
            low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
            enable_relative_embedding=self.enable_relative_embedding,
            relative_embedding=relative_embedding_flax_module,
            drop_path=self.drop_path,
            fuse_qkv_params=self.fuse_qkv_params,
            transpose_batch_sequence=self.transpose_batch_sequence,
            enable_sequence_parallel=self.enable_sequence_parallel,
            scale_attn_logits=self.scale_attn_logits,
            scaled_query_init=self.scaled_query_init,
            window_size=self.window_size,
        )

        self.create_layer("transformerlayer", transformerlayer_cls)

    def __call__(
        self,
        inputs: JTensor,
        encoded: Optional[JTensor] = None,
        attention_mask: Optional[JTensor] = None,
        encoder_decoder_mask: Optional[JTensor] = None,
        deterministic: bool = False,
        decode: bool = False,
        max_decode_length: Optional[int] = None,
    ) -> JTensor:
        """__call__"""
        return self.transformerlayer(
            inputs,
            encoded,
            attention_mask,
            encoder_decoder_mask,
            deterministic,
            decode,
            max_decode_length,
        )
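
# Usage sketch (comment only, not executed here): an encoder layer needs only
# `inputs` and an optional attention_mask; decoder layers additionally take
# `encoded` and `encoder_decoder_mask`. `x`, `mask`, and the sizes below are
# illustrative; with the default transpose_batch_sequence=False inputs are
# batch-first.
#
#   layer_tpl = pax_fiddle.Config(
#       TransformerLayer, name="encoder_layer",
#       hidden_size=512, mlp_hidden_size=2048, num_attention_heads=8,
#       layer_type=TransformerLayerType.ENCODER)
#   layer = base_layer.instantiate(layer_tpl)
#   variables = layer.init(jax.random.PRNGKey(0), x, attention_mask=mask,
#                          deterministic=True)
#   out = layer.apply(variables, x, attention_mask=mask, deterministic=True)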