transformer.py 15.9 KB
Newer Older
1
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
4
5
6
7
#
# See LICENSE for license information.
"""
Praxis modules for Transformer-related layers.
"""
from functools import partial
8
from typing import Optional, Sequence, Tuple
9
import warnings
10
11
12
13
14
15

from praxis import pax_fiddle
from praxis.base_layer import WeightInit
from praxis.pytypes import JTensor

from .module import TransformerEngineBaseLayer
16
from ..flax.transformer import TransformerLayerType
17
from ..flax.transformer import DotProductAttention as flax_DotProductAttention
18
19
20
from ..flax.transformer import MultiHeadAttention as flax_MultiHeadAttention
from ..flax.transformer import RelativePositionBiases as flax_RelativePositionBiases
from ..flax.transformer import TransformerLayer as flax_TransformerLayer
21
from ..fused_attn import AttnBiasType, AttnMaskType
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64


class RelativePositionBiases(TransformerEngineBaseLayer):
    """Praxis layer that builds and delegates to the Flax ``RelativePositionBiases`` module."""

    num_buckets: int = 32
    max_distance: int = 128
    num_attention_heads: int = 64
    embedding_init: WeightInit = None
    embedding_axes: Tuple[str, ...] = ()

    @staticmethod
    def generate_embedding_init(init, num_attention_heads, num_buckets):
        """Return ``init`` as-is, or a Gaussian initializer when ``init`` is None.

        The fallback uses a standard deviation of
        ``(num_attention_heads * num_buckets) ** -0.5``.
        """
        if init is not None:
            return init
        fallback_stddev = (num_attention_heads * num_buckets)**-0.5
        return WeightInit.Gaussian(fallback_stddev)

    def setup(self) -> None:
        """Register the wrapped Flax module under the name ``relative_position_bias``."""
        super().setup()

        resolved_init = RelativePositionBiases.generate_embedding_init(
            self.embedding_init, self.num_attention_heads, self.num_buckets)

        # Collect the constructor arguments first, then bind them in one step.
        flax_kwargs = {
            "num_buckets": self.num_buckets,
            "max_distance": self.max_distance,
            "num_attention_heads": self.num_attention_heads,
            "embedding_init": TransformerEngineBaseLayer.generate_params_init(
                "rel_embedding", resolved_init),
            "embedding_axes": self.embedding_axes,
            "dtype": self.dtype,
        }
        layer_factory = partial(flax_RelativePositionBiases, **flax_kwargs)

        self.create_layer("relative_position_bias", layer_factory)

    def __call__(self, q_seqlen: JTensor, k_seqlen: JTensor, bidirectional: bool = True) -> JTensor:
        """Forward the arguments positionally to the wrapped module and return its result."""
        biases = self.relative_position_bias(q_seqlen, k_seqlen, bidirectional)
        return biases


65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
class DotProductAttention(TransformerEngineBaseLayer):
    """Praxis layer that builds and delegates to the Flax ``DotProductAttention`` module."""

    head_dim: int = 0
    num_attention_heads: int = 0
    num_gqa_groups: Optional[int] = None
    attention_dropout: float = 0.
    attn_mask_type: AttnMaskType = 'causal'
    attn_bias_type: AttnBiasType = None
    dropout_rng_name: str = 'dropout'
    float32_logits: bool = False
    qkv_layout: str = 'bshd_bshd_bshd'
    scale_factor: Optional[float] = None
    transpose_batch_sequence: bool = True

    def setup(self) -> None:
        """Validate the head configuration and register the wrapped Flax module.

        The child layer is registered under the name ``dot_product_attention``.
        """
        super().setup()

        # Both sizes default to 0, so they must be set explicitly by the user.
        assert self.head_dim > 0, f'{self.head_dim=}'
        assert self.num_attention_heads > 0, f'{self.num_attention_heads=}'

        flax_kwargs = {
            "head_dim": self.head_dim,
            "num_attention_heads": self.num_attention_heads,
            "num_gqa_groups": self.num_gqa_groups,
            "attn_mask_type": self.attn_mask_type,
            "attn_bias_type": self.attn_bias_type,
            "attention_dropout": self.attention_dropout,
            "dtype": self.dtype,
            "dropout_rng_name": self.dropout_rng_name,
            "float32_logits": self.float32_logits,
            "qkv_layout": self.qkv_layout,
            "scale_factor": self.scale_factor,
            "transpose_batch_sequence": self.transpose_batch_sequence,
        }
        layer_factory = partial(flax_DotProductAttention, **flax_kwargs)

        self.create_layer("dot_product_attention", layer_factory)

    def __call__(self,
                 query: JTensor,
                 key: JTensor,
                 value: JTensor,
                 mask: Optional[JTensor] = None,
                 bias: Optional[JTensor] = None,
                 *,
                 deterministic: bool = False) -> JTensor:
        """Forward all arguments to the wrapped module and return its result.

        ``deterministic`` is keyword-only and is passed through by keyword.
        """
        positional_args = (query, key, value, mask, bias)
        return self.dot_product_attention(*positional_args, deterministic=deterministic)


120
121
122
class MultiHeadAttention(TransformerEngineBaseLayer):
    """Praxis layer that builds and delegates to the Flax ``MultiHeadAttention`` module.

    All non-deprecated fields are forwarded to the Flax module in :meth:`setup`.
    ``head_dim`` and ``num_attention_heads`` default to 0 and must be set to
    positive values before ``setup`` runs.
    """

    head_dim: int = 0
    num_attention_heads: int = 0
    # When left as None, __post_init__ sets this to num_attention_heads.
    num_gqa_groups: Optional[int] = None
    attention_dropout: float = 0.
    dropout_rng_name: str = 'dropout'
    input_layernorm: bool = True
    layernorm_type: str = "layernorm"
    layernorm_epsilon: float = 1e-6
    zero_centered_gamma: bool = False
    return_layernorm_output: bool = False
    use_bias: bool = False
    bias_init: WeightInit = WeightInit.Constant(0.0)
    attn_mask_type: str = 'causal'
    attn_bias_type: Optional[str] = None
    enable_rotary_pos_emb: bool = False
    rotary_pos_emb_windows: Tuple[int, int] = (1, 10000)
    rotary_pos_emb_group_method: str = 'consecutive'
    low_rank_adaptation_scope: str = 'none'
    low_rank_adaptation_dim: int = 32
    low_rank_adaptation_alpha: float = None
    fuse_qkv_params: bool = True
    transpose_batch_sequence: bool = True
    enable_sequence_parallel: bool = False
    scale_attn_logits: bool = False
    scaled_query_init: bool = True
    float32_logits: bool = False

    # Deprecated parameters
    num_heads: Optional[int] = None
    dropout_rate: Optional[float] = None
    output_layernorm: Optional[bool] = None
    apply_residual_connection_post_layernorm: Optional[bool] = None
    fuse_qkv: Optional[bool] = None

    def __post_init__(self):
        """Translate deprecated fields, then default ``num_gqa_groups``.

        ``num_heads`` and ``dropout_rate`` are copied onto their replacements
        with a ``DeprecationWarning``. Setting ``output_layernorm`` fails the
        assertion below.
        """
        # Deal with the deprecated parameters
        if self.num_heads is not None:
            self.num_attention_heads = self.num_heads
            warnings.warn(
                f"{__class__}.num_heads is deprecated. It will be removed recently. "
                f"Please uses {__class__}.num_attention_heads as the new API.", DeprecationWarning)
        if self.dropout_rate is not None:
            self.attention_dropout = self.dropout_rate
            warnings.warn(
                f"{__class__}.dropout_rate is deprecated. It will be removed recently. "
                f"Please use {__class__}.attention_dropout as the new API.", DeprecationWarning)
        # NOTE(review): the next two deprecated fields only emit a warning; their
        # values are not copied onto return_layernorm_output / fuse_qkv_params.
        # Confirm whether that is intended before changing it.
        if self.apply_residual_connection_post_layernorm is not None:
            warnings.warn(
                f"{__class__}.apply_residual_connection_post_layernorm is deprecated. "
                f"It will be removed recently, please use {__class__}.return_layernorm_output.",
                DeprecationWarning)
        if self.fuse_qkv is not None:
            warnings.warn(
                f"{__class__}.fuse_qkv is deprecated. It will be removed recently. "
                f"Please use {__class__}.fuse_qkv_params as the new API.", DeprecationWarning)
        assert self.output_layernorm is None, (
            f"{__class__}.output_layernorm is deprecated. It will be removed recently. "
            f"Please use {__class__}.input_layernorm for controlling whether to apply layernorm.")

        if self.num_gqa_groups is None:
            # Default from the current field, not the deprecated `num_heads`
            # alias: `num_heads` is None whenever the new API is used, which
            # would leave num_gqa_groups unset. At this point
            # num_attention_heads already reflects num_heads if it was given.
            self.num_gqa_groups = self.num_attention_heads
        super().__post_init__()

    def setup(self) -> None:
        """Validate the head configuration and register the wrapped Flax module.

        The child layer is registered under the name ``multi_head_attn``.
        """
        super().setup()

        assert self.head_dim > 0, f'{self.head_dim=}'
        assert self.num_attention_heads > 0, f'{self.num_attention_heads=}'

        mha_cls = partial(
            flax_MultiHeadAttention,
            dtype=self.dtype,
            head_dim=self.head_dim,
            num_attention_heads=self.num_attention_heads,
            num_gqa_groups=self.num_gqa_groups,
            attention_dropout=self.attention_dropout,
            dropout_rng_name=self.dropout_rng_name,
            input_layernorm=self.input_layernorm,
            layernorm_type=self.layernorm_type,
            layernorm_epsilon=self.layernorm_epsilon,
            zero_centered_gamma=self.zero_centered_gamma,
            return_layernorm_output=self.return_layernorm_output,
            kernel_init=TransformerEngineBaseLayer.generate_params_init("kernel", self.params_init),
            use_bias=self.use_bias,
            bias_init=TransformerEngineBaseLayer.generate_params_init("bias", self.bias_init),
            attn_mask_type=self.attn_mask_type,
            attn_bias_type=self.attn_bias_type,
            enable_rotary_pos_emb=self.enable_rotary_pos_emb,
            rotary_pos_emb_windows=self.rotary_pos_emb_windows,
            rotary_pos_emb_group_method=self.rotary_pos_emb_group_method,
            low_rank_adaptation_scope=self.low_rank_adaptation_scope,
            low_rank_adaptation_dim=self.low_rank_adaptation_dim,
            low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
            fuse_qkv_params=self.fuse_qkv_params,
            transpose_batch_sequence=self.transpose_batch_sequence,
            enable_sequence_parallel=self.enable_sequence_parallel,
            scale_attn_logits=self.scale_attn_logits,
            scaled_query_init=self.scaled_query_init,
            float32_logits=self.float32_logits)

        self.create_layer("multi_head_attn", mha_cls)

    def __call__(self,
                 inputs_q: JTensor,
                 inputs_kv: JTensor,
                 mask: Optional[JTensor] = None,
                 bias: Optional[JTensor] = None,
                 *,
                 decode: bool = False,
                 deterministic: bool = False) -> JTensor:
        """Forward all arguments to the wrapped module and return its result.

        ``decode`` and ``deterministic`` are keyword-only and are passed
        through by keyword.
        """
        return self.multi_head_attn(inputs_q,
                                    inputs_kv,
                                    mask,
                                    bias,
                                    decode=decode,
                                    deterministic=deterministic)


class TransformerLayer(TransformerEngineBaseLayer):
    """Praxis layer that builds and delegates to the Flax ``TransformerLayer`` module.

    The fields are forwarded to the Flax module in :meth:`setup`. When
    ``enable_relative_embedding`` is true and ``relative_embedding`` is set,
    a Flax ``RelativePositionBiases`` module is built from that config and
    passed along as well.
    """

    hidden_size: int = 512
    mlp_hidden_size: int = 2048
    num_attention_heads: int = 8
    # When left as None, __post_init__ sets this to num_attention_heads.
    num_gqa_groups: Optional[int] = None
    layernorm_type: str = 'layernorm'
    layernorm_epsilon: float = 1e-6
    zero_centered_gamma: bool = False
    hidden_dropout: float = 0.1
    hidden_dropout_dims: Sequence[int] = ()
    attention_dropout: float = 0.1
    intermediate_dropout: float = 0.1
    intermediate_dropout_dims: Sequence[int] = ()
    dropout_rng_name: str = 'dropout'
    mlp_activations: Sequence[str] = ('relu',)
    use_bias: bool = False
    bias_init: WeightInit = WeightInit.Constant(0.0)
    apply_residual_connection_post_layernorm: bool = False
    output_layernorm: bool = False
    float32_attention_logits: bool = False
    layer_type: TransformerLayerType = TransformerLayerType.ENCODER
    self_attn_mask_type: str = 'causal'
    self_attn_bias_type: Optional[str] = None
    enable_rotary_pos_emb: bool = False
    rotary_pos_emb_windows: Tuple[int, int] = (1, 10000)
    rotary_pos_emb_group_method: str = 'consecutive'
    low_rank_adaptation_scope: str = 'none'
    low_rank_adaptation_dim: int = 32
    low_rank_adaptation_alpha: float = None
    enable_relative_embedding: bool = True
    relative_embedding: pax_fiddle.Config[RelativePositionBiases] = pax_fiddle.template_field(None)
    drop_path: float = 0.0
    fuse_qkv_params: bool = True
    transpose_batch_sequence: bool = False
    enable_sequence_parallel: bool = False
    scale_attn_logits: bool = False
    scaled_query_init: bool = True

    def __post_init__(self):
        """Default ``num_gqa_groups`` to ``num_attention_heads`` when it is None."""
        if self.num_gqa_groups is None:
            self.num_gqa_groups = self.num_attention_heads
        super().__post_init__()

    def setup(self) -> None:
        """Build the optional relative-embedding module and register the Flax layer.

        The child layer is registered under the name ``transformerlayer``.
        """
        super().setup()

        # Stays None unless relative embedding is enabled AND a config is given.
        relative_embedding_flax_module = None
        if self.enable_relative_embedding and self.relative_embedding is not None:
            assert self.relative_embedding.num_attention_heads == \
                    self.num_attention_heads, \
                "TransformerLayer.relative_embedding.num_attention_heads should be " \
                "the same as TransformerLayer.num_attention_heads."

            embedding_init = RelativePositionBiases.generate_embedding_init(
                self.relative_embedding.embedding_init, self.relative_embedding.num_attention_heads,
                self.relative_embedding.num_buckets)

            relative_embedding_flax_module = flax_RelativePositionBiases(
                num_buckets=self.relative_embedding.num_buckets,
                max_distance=self.relative_embedding.max_distance,
                num_attention_heads=self.relative_embedding.num_attention_heads,
                embedding_init=TransformerEngineBaseLayer.generate_params_init(
                    "rel_embedding", embedding_init),
                embedding_axes=self.relative_embedding.embedding_axes,
                dtype=self.relative_embedding.dtype)

        transformerlayer_cls = partial(
            flax_TransformerLayer,
            dtype=self.dtype,
            hidden_size=self.hidden_size,
            mlp_hidden_size=self.mlp_hidden_size,
            num_attention_heads=self.num_attention_heads,
            num_gqa_groups=self.num_gqa_groups,
            layernorm_type=self.layernorm_type,
            layernorm_epsilon=self.layernorm_epsilon,
            zero_centered_gamma=self.zero_centered_gamma,
            hidden_dropout=self.hidden_dropout,
            hidden_dropout_dims=self.hidden_dropout_dims,
            attention_dropout=self.attention_dropout,
            intermediate_dropout=self.intermediate_dropout,
            intermediate_dropout_dims=self.intermediate_dropout_dims,
            dropout_rng_name=self.dropout_rng_name,
            # The attention and MLP kernels both derive from the same
            # params_init, under different names.
            mha_kernel_init=TransformerEngineBaseLayer.generate_params_init(
                "mha_kernel", self.params_init),
            mlp_kernel_init=TransformerEngineBaseLayer.generate_params_init(
                "mlp_kernel", self.params_init),
            mlp_activations=self.mlp_activations,
            use_bias=self.use_bias,
            bias_init=TransformerEngineBaseLayer.generate_params_init("bias", self.bias_init),
            apply_residual_connection_post_layernorm=self.apply_residual_connection_post_layernorm,
            output_layernorm=self.output_layernorm,
            float32_attention_logits=self.float32_attention_logits,
            layer_type=self.layer_type,
            self_attn_mask_type=self.self_attn_mask_type,
            self_attn_bias_type=self.self_attn_bias_type,
            enable_rotary_pos_emb=self.enable_rotary_pos_emb,
            rotary_pos_emb_windows=self.rotary_pos_emb_windows,
            rotary_pos_emb_group_method=self.rotary_pos_emb_group_method,
            low_rank_adaptation_scope=self.low_rank_adaptation_scope,
            low_rank_adaptation_dim=self.low_rank_adaptation_dim,
            low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
            enable_relative_embedding=self.enable_relative_embedding,
            relative_embedding=relative_embedding_flax_module,
            drop_path=self.drop_path,
            fuse_qkv_params=self.fuse_qkv_params,
            transpose_batch_sequence=self.transpose_batch_sequence,
            enable_sequence_parallel=self.enable_sequence_parallel,
            scale_attn_logits=self.scale_attn_logits,
            scaled_query_init=self.scaled_query_init)

        self.create_layer("transformerlayer", transformerlayer_cls)

    def __call__(self,
                 inputs: JTensor,
                 encoded: Optional[JTensor] = None,
                 attention_mask: Optional[JTensor] = None,
                 encoder_decoder_mask: Optional[JTensor] = None,
                 deterministic: bool = False,
                 decode: bool = False,
                 max_decode_length: Optional[int] = None) -> JTensor:
        """Forward all arguments positionally to the wrapped module.

        NOTE(review): ``max_decode_length`` was annotated ``bool``; it is
        presumably a length, so it is annotated ``Optional[int]`` here.
        Confirm against the Flax ``TransformerLayer`` signature.
        """
        return self.transformerlayer(inputs, encoded, attention_mask, encoder_decoder_mask,
                                     deterministic, decode, max_decode_length)