# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Praxis modules related to Transformer layers
"""
from functools import partial
8
from typing import Optional, Sequence, Tuple
9
10
11
12
13
14

from praxis import pax_fiddle
from praxis.base_layer import WeightInit
from praxis.pytypes import JTensor

from .module import TransformerEngineBaseLayer
15
from ..flax.transformer import TransformerLayerType
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
from ..flax.transformer import MultiHeadAttention as flax_MultiHeadAttention
from ..flax.transformer import RelativePositionBiases as flax_RelativePositionBiases
from ..flax.transformer import TransformerLayer as flax_TransformerLayer


class RelativePositionBiases(TransformerEngineBaseLayer):
    """Praxis wrapper around the Flax ``RelativePositionBiases`` module.

    ``setup`` builds the Flax module from the fields below and registers it
    as the child layer ``relative_position_bias``; ``__call__`` forwards to it.
    """

    num_buckets: int = 32
    max_distance: int = 128
    num_attention_heads: int = 64
    # ``None`` selects the default Gaussian initializer produced by
    # ``generate_embedding_init``.
    embedding_init: Optional[WeightInit] = None
    embedding_axes: Tuple[str, ...] = ()

    @staticmethod
    def generate_embedding_init(init, num_attention_heads, num_buckets):
        """Return the embedding initializer to use.

        Args:
            init: A user-supplied ``WeightInit``, or ``None`` for the default.
            num_attention_heads: Used (with ``num_buckets``) to scale the
                default stddev.
            num_buckets: Used (with ``num_attention_heads``) to scale the
                default stddev.

        Returns:
            ``init`` unchanged when it is not ``None``; otherwise
            ``WeightInit.Gaussian((num_attention_heads * num_buckets) ** -0.5)``.
        """
        if init is not None:
            return init
        rb_stddev = (num_attention_heads * num_buckets)**-0.5
        return WeightInit.Gaussian(rb_stddev)

    def setup(self) -> None:
        """Build the Flax ``RelativePositionBiases`` child layer."""
        super().setup()

        embedding_init = RelativePositionBiases.generate_embedding_init(
            self.embedding_init, self.num_attention_heads, self.num_buckets)

        rpb_cls = partial(flax_RelativePositionBiases,
                          num_buckets=self.num_buckets,
                          max_distance=self.max_distance,
                          num_attention_heads=self.num_attention_heads,
                          # Presumably adapts the Praxis WeightInit for the Flax
                          # module; see TransformerEngineBaseLayer.generate_params_init.
                          embedding_init=TransformerEngineBaseLayer.generate_params_init(
                              "rel_embedding", embedding_init),
                          embedding_axes=self.embedding_axes,
                          dtype=self.dtype)

        self.create_layer("relative_position_bias", rpb_cls)

    def __call__(self, q_seqlen: JTensor, k_seqlen: JTensor, bidirectional: bool = True) -> JTensor:
        """Forward ``q_seqlen``, ``k_seqlen`` and ``bidirectional`` to the Flax module.

        Returns:
            Whatever the wrapped Flax ``RelativePositionBiases`` returns.
        """
        return self.relative_position_bias(q_seqlen, k_seqlen, bidirectional)


class MultiHeadAttention(TransformerEngineBaseLayer):
    """Praxis wrapper around the Flax ``MultiHeadAttention`` module.

    Every field below is forwarded by name to the Flax module built in
    ``setup``. ``params_init`` (inherited) and ``bias_init`` are first passed
    through ``TransformerEngineBaseLayer.generate_params_init`` and forwarded
    as ``kernel_init`` / ``bias_init``.
    """

    head_dim: int = 64
    num_heads: int = 16
    # Forwarded as ``num_gqa_groups``; ``None`` is replaced by ``num_heads``
    # in ``__post_init__``. ``Optional`` (not ``int | None``) keeps the class
    # importable on Python < 3.10 and matches the rest of this file.
    num_gqa_groups: Optional[int] = None
    dropout_rate: float = 0.0
    dropout_rng_name: str = 'dropout'
    layernorm_type: str = "layernorm"
    layernorm_epsilon: float = 1e-6
    zero_centered_gamma: bool = False
    use_bias: bool = False
    bias_init: WeightInit = WeightInit.Constant(0.0)
    apply_residual_connection_post_layernorm: bool = False
    output_layernorm: bool = False
    attn_mask_type: str = 'causal'
    fuse_qkv: bool = True
    transpose_batch_sequence: bool = True
    enable_sequence_parallel: bool = False
    scale_attn_logits: bool = False
    scaled_query_init: bool = True
    float32_logits: bool = False

    def __post_init__(self):
        """Default ``num_gqa_groups`` to ``num_heads`` before base post-init runs."""
        if self.num_gqa_groups is None:
            self.num_gqa_groups = self.num_heads
        super().__post_init__()

    def setup(self) -> None:
        """Build the Flax ``MultiHeadAttention`` child layer ``multi_head_attn``."""
        super().setup()

        mha_cls = partial(
            flax_MultiHeadAttention,
            dtype=self.dtype,
            head_dim=self.head_dim,
            num_heads=self.num_heads,
            num_gqa_groups=self.num_gqa_groups,
            dropout_rate=self.dropout_rate,
            dropout_rng_name=self.dropout_rng_name,
            layernorm_type=self.layernorm_type,
            layernorm_epsilon=self.layernorm_epsilon,
            zero_centered_gamma=self.zero_centered_gamma,
            # Presumably adapts the Praxis WeightInit for the Flax module;
            # see TransformerEngineBaseLayer.generate_params_init.
            kernel_init=TransformerEngineBaseLayer.generate_params_init("kernel", self.params_init),
            use_bias=self.use_bias,
            bias_init=TransformerEngineBaseLayer.generate_params_init("bias", self.bias_init),
            apply_residual_connection_post_layernorm=self.apply_residual_connection_post_layernorm,
            output_layernorm=self.output_layernorm,
            attn_mask_type=self.attn_mask_type,
            fuse_qkv=self.fuse_qkv,
            transpose_batch_sequence=self.transpose_batch_sequence,
            enable_sequence_parallel=self.enable_sequence_parallel,
            scale_attn_logits=self.scale_attn_logits,
            scaled_query_init=self.scaled_query_init,
            float32_logits=self.float32_logits)

        self.create_layer("multi_head_attn", mha_cls)

    def __call__(self,
                 inputs_q: JTensor,
                 inputs_kv: JTensor,
                 mask: Optional[JTensor] = None,
                 bias: Optional[JTensor] = None,
                 *,
                 decode: bool = False,
                 deterministic: bool = False) -> JTensor:
        """Run the wrapped Flax ``MultiHeadAttention``.

        ``inputs_q``, ``inputs_kv``, ``mask`` and ``bias`` are forwarded
        positionally; ``decode`` and ``deterministic`` by keyword.

        Returns:
            Whatever the wrapped Flax module returns.
        """
        return self.multi_head_attn(inputs_q,
                                    inputs_kv,
                                    mask,
                                    bias,
                                    decode=decode,
                                    deterministic=deterministic)


class TransformerLayer(TransformerEngineBaseLayer):
    """Praxis wrapper around the Flax ``TransformerLayer`` module.

    Fields below are forwarded by name to the Flax module built in ``setup``.
    ``params_init`` (inherited) is routed through
    ``TransformerEngineBaseLayer.generate_params_init`` twice, as
    ``mha_kernel_init`` ("mha_kernel") and ``mlp_kernel_init`` ("mlp_kernel").
    An optional ``relative_embedding`` config is materialized into a Flax
    ``RelativePositionBiases`` instance and passed through.
    """

    hidden_size: int = 512
    mlp_hidden_size: int = 2048
    num_attention_heads: int = 8
    # Forwarded as ``num_gqa_groups``; ``None`` is replaced by
    # ``num_attention_heads`` in ``__post_init__``.
    num_gqa_groups: Optional[int] = None
    layernorm_type: str = 'layernorm'
    layernorm_epsilon: float = 1e-6
    zero_centered_gamma: bool = False
    hidden_dropout: float = 0.1
    hidden_dropout_dims: Sequence[int] = ()
    attention_dropout: float = 0.1
    intermediate_dropout: float = 0.1
    intermediate_dropout_dims: Sequence[int] = ()
    dropout_rng_name: str = 'dropout'
    mlp_activations: Sequence[str] = ('relu',)
    use_bias: bool = False
    bias_init: WeightInit = WeightInit.Constant(0.0)
    apply_residual_connection_post_layernorm: bool = False
    output_layernorm: bool = False
    float32_attention_logits: bool = False
    layer_type: TransformerLayerType = TransformerLayerType.ENCODER
    self_attn_mask_type: str = 'causal'
    enable_rotary_pos_emb: bool = False
    rotary_pos_emb_windows: Tuple[int, int] = (1, 10000)
    enable_relative_embedding: bool = True
    # Materialized in ``setup`` only when ``enable_relative_embedding`` is True
    # and this config is not None; otherwise ``None`` is forwarded.
    relative_embedding: pax_fiddle.Config[RelativePositionBiases] = pax_fiddle.template_field(None)
    drop_path: float = 0.0
    fuse_qkv_params: bool = True
    transpose_batch_sequence: bool = False
    enable_sequence_parallel: bool = False
    scale_attn_logits: bool = False
    scaled_query_init: bool = True

    def __post_init__(self):
        """Default ``num_gqa_groups`` to ``num_attention_heads`` before base post-init."""
        if self.num_gqa_groups is None:
            self.num_gqa_groups = self.num_attention_heads
        super().__post_init__()

    def setup(self) -> None:
        """Build the Flax ``TransformerLayer`` child layer ``transformerlayer``.

        Raises:
            AssertionError: if a relative-embedding config is supplied whose
                ``num_attention_heads`` differs from this layer's.
        """
        super().setup()

        relative_embedding_flax_module = None
        if self.enable_relative_embedding and self.relative_embedding is not None:
            rel_cfg = self.relative_embedding
            # The original message concatenated "shoule be" + "the same"
            # into "shoule bethe same"; report both values instead.
            assert rel_cfg.num_attention_heads == self.num_attention_heads, (
                "TransformerLayer.relative_embedding.num_attention_heads "
                f"({rel_cfg.num_attention_heads}) should be the same as "
                f"TransformerLayer.num_attention_heads ({self.num_attention_heads}).")

            embedding_init = RelativePositionBiases.generate_embedding_init(
                rel_cfg.embedding_init, rel_cfg.num_attention_heads, rel_cfg.num_buckets)

            relative_embedding_flax_module = flax_RelativePositionBiases(
                num_buckets=rel_cfg.num_buckets,
                max_distance=rel_cfg.max_distance,
                num_attention_heads=rel_cfg.num_attention_heads,
                embedding_init=TransformerEngineBaseLayer.generate_params_init(
                    "rel_embedding", embedding_init),
                embedding_axes=rel_cfg.embedding_axes,
                dtype=rel_cfg.dtype)

        transformerlayer_cls = partial(
            flax_TransformerLayer,
            dtype=self.dtype,
            hidden_size=self.hidden_size,
            mlp_hidden_size=self.mlp_hidden_size,
            num_attention_heads=self.num_attention_heads,
            num_gqa_groups=self.num_gqa_groups,
            layernorm_type=self.layernorm_type,
            layernorm_epsilon=self.layernorm_epsilon,
            zero_centered_gamma=self.zero_centered_gamma,
            hidden_dropout=self.hidden_dropout,
            hidden_dropout_dims=self.hidden_dropout_dims,
            attention_dropout=self.attention_dropout,
            intermediate_dropout=self.intermediate_dropout,
            intermediate_dropout_dims=self.intermediate_dropout_dims,
            dropout_rng_name=self.dropout_rng_name,
            # The same params_init feeds both kernel inits, under distinct names.
            mha_kernel_init=TransformerEngineBaseLayer.generate_params_init(
                "mha_kernel", self.params_init),
            mlp_kernel_init=TransformerEngineBaseLayer.generate_params_init(
                "mlp_kernel", self.params_init),
            mlp_activations=self.mlp_activations,
            use_bias=self.use_bias,
            bias_init=TransformerEngineBaseLayer.generate_params_init("bias", self.bias_init),
            apply_residual_connection_post_layernorm=self.apply_residual_connection_post_layernorm,
            output_layernorm=self.output_layernorm,
            float32_attention_logits=self.float32_attention_logits,
            layer_type=self.layer_type,
            self_attn_mask_type=self.self_attn_mask_type,
            enable_rotary_pos_emb=self.enable_rotary_pos_emb,
            rotary_pos_emb_windows=self.rotary_pos_emb_windows,
            enable_relative_embedding=self.enable_relative_embedding,
            relative_embedding=relative_embedding_flax_module,
            drop_path=self.drop_path,
            fuse_qkv_params=self.fuse_qkv_params,
            transpose_batch_sequence=self.transpose_batch_sequence,
            enable_sequence_parallel=self.enable_sequence_parallel,
            scale_attn_logits=self.scale_attn_logits,
            scaled_query_init=self.scaled_query_init)

        self.create_layer("transformerlayer", transformerlayer_cls)

    def __call__(self,
                 inputs: JTensor,
                 encoded: Optional[JTensor] = None,
                 attention_mask: Optional[JTensor] = None,
                 encoder_decoder_mask: Optional[JTensor] = None,
                 deterministic: bool = False,
                 decode: bool = False,
                 max_decode_length: Optional[int] = None) -> JTensor:
        """Run the wrapped Flax ``TransformerLayer``.

        All arguments are forwarded positionally, in the order listed, to the
        Flax module. ``max_decode_length`` is presumably a sequence length
        used when ``decode`` is True (it was previously mis-annotated as
        ``bool``) — confirm against the Flax layer.

        Returns:
            Whatever the wrapped Flax module returns.
        """
        return self.transformerlayer(inputs, encoded, attention_mask, encoder_decoder_mask,
                                     deterministic, decode, max_decode_length)