# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Praxis Modules
"""
7
from dataclasses import field
8
9
10
11
12
from functools import partial
from typing import Callable, Iterable, Sequence, Tuple, Union

from praxis import pax_fiddle
from praxis.base_layer import init_var
13
from praxis.base_layer import BaseLayer, WeightInit, WeightHParams, WeightHParamsCollection
14
15
16
from praxis.layers import flax_adapter
from praxis.pytypes import JTensor

17
from ..fp8 import FP8Helper
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
from ..flax.module import DenseGeneral, LayerNormDenseGeneral
from ..flax.module import LayerNorm as flax_LayerNorm
from ..flax.module import LayerNormMLP as flax_LayerNormMLP
from ..flax.module import Softmax
from ..softmax import SoftmaxType


def _generate_ln_scale_init(scale_init):
    if scale_init is not None:
        return TransformerEngineBaseLayer.generate_params_init("scale", scale_init)
    return scale_init


class TransformerEngineBaseLayer(BaseLayer):
    """Base class for Praxis layers that wrap Transformer Engine Flax modules.

    Subclasses build a factory for a Flax module in ``setup`` and register it
    through ``create_layer``. The registered child is a
    ``flax_adapter.FlaxModuleAdapter`` configured with this layer's logical
    axis rules, mesh settings and FP8 variable-collection tags.
    """

    # Forwarded to the FlaxModuleAdapter as ``logical_axes_rules``.
    logical_axes_rules: Tuple[Tuple, ...] = None

    @staticmethod
    def generate_params_init(name: str, initializer: WeightInit):
        """Adapt a Praxis ``WeightInit`` into a Flax-style initializer.

        Args:
          name: Variable name handed to ``praxis.base_layer.init_var``.
          initializer: The Praxis weight initializer to apply.

        Returns:
          A callable ``(key, shape, dtype)`` that builds a ``WeightHParams``
          for the requested shape/dtype and initializes it with ``init_var``.
        """

        def kernel_init(key, shape, dtype):
            wp = WeightHParams(shape=shape, init=initializer, dtype=dtype)
            return init_var(wp, key, name)

        return kernel_init

    def create_layer(self, name, flax_module_cls):
        """Register ``flax_module_cls`` as a child layer called ``name``.

        Args:
          name: Child name; subclasses call the wrapped module as
            ``self.<name>``.
          flax_module_cls: Factory passed to the adapter as
            ``module_factory_method`` (in this file, a ``functools.partial``
            of a Transformer Engine Flax module).
        """
        # Tag every variable in the FP8 meta collection. Judging by the tag
        # names (confirm against Praxis docs): skip LP regularization, replace
        # the value with what the backward pass produces instead of applying
        # an optimizer update, and never down-cast to bfloat16.
        fp8_collection_map = {
            FP8Helper.FP8_COLLECTION_NAME: [
                WeightHParamsCollection.SKIP_LP_REGULARIZATION,
                WeightHParamsCollection.OVERWRITE_WITH_GRADIENT,
                WeightHParamsCollection.DISALLOW_BFLOAT16_CONVERSION,
            ]
        }

        flax_module_p = pax_fiddle.Config(
            flax_adapter.FlaxModuleAdapter,
            module_factory_method=flax_module_cls,
            logical_axes_rules=self.logical_axes_rules,
            var_collection_map=fp8_collection_map,
            ici_mesh_shape=self.ici_mesh_shape,
            dcn_mesh_shape=self.dcn_mesh_shape,
            mesh_axis_names=self.mesh_axis_names,
        )

        self.create_child(name, flax_module_p.clone())


class LayerNorm(TransformerEngineBaseLayer):
    """Praxis wrapper around the Transformer Engine Flax ``LayerNorm`` module.

    Every hyper-parameter is forwarded to the Flax module. Praxis
    ``WeightInit`` objects for scale and bias are converted to Flax-style
    initializers before being passed on.
    """

    epsilon: float = 1e-6
    # Forwarded as-is; presumably selects the normalization variant
    # (e.g. "layernorm" vs "rmsnorm") -- confirm against the Flax module.
    layernorm_type: str = "layernorm"
    zero_centered_gamma: bool = False
    # None is forwarded unchanged; otherwise wrapped as the "scale" initializer.
    scale_init: WeightInit = None
    scale_axes: Tuple[str, ...] = ()
    # Initializer for the normalization bias (registered as "ln_bias").
    bias_init: WeightInit = field(  # pylint: disable=invalid-field-call
        default_factory=partial(WeightInit.Constant, scale=0.0)
    )
    bias_axes: Tuple[str, ...] = ()
    transpose_batch_sequence: bool = False

    def setup(self) -> None:
        """Build the wrapped Flax LayerNorm and register it as ``layer_norm``."""
        super().setup()

        ln_cls = partial(
            flax_LayerNorm,
            epsilon=self.epsilon,
            layernorm_type=self.layernorm_type,
            zero_centered_gamma=self.zero_centered_gamma,
            scale_init=_generate_ln_scale_init(self.scale_init),
            scale_axes=self.scale_axes,
            bias_init=TransformerEngineBaseLayer.generate_params_init("ln_bias", self.bias_init),
            bias_axes=self.bias_axes,
            dtype=self.dtype,
            transpose_batch_sequence=self.transpose_batch_sequence,
        )

        self.create_layer("layer_norm", ln_cls)

    def __call__(self, x: JTensor) -> JTensor:
        """Normalize ``x`` with the wrapped Flax LayerNorm."""
        return self.layer_norm(x)


class FusedSoftmax(TransformerEngineBaseLayer):
    """Praxis wrapper around the Transformer Engine Flax ``Softmax`` module."""

    # Forwarded to the Flax Softmax as ``scale_factor``.
    scale_factor: float = 1.0
    # Forwarded to the Flax Softmax; selects which softmax variant is used.
    softmax_type: SoftmaxType = SoftmaxType.SCALED

    def setup(self) -> None:
        """Build the wrapped Flax Softmax and register it as ``fused_softmax``."""
        super().setup()

        fused_softmax_cls = partial(
            Softmax, scale_factor=self.scale_factor, softmax_type=self.softmax_type
        )

        self.create_layer("fused_softmax", fused_softmax_cls)

    def __call__(self, x: JTensor, mask: JTensor = None, bias: JTensor = None) -> JTensor:
        """Apply the wrapped softmax to ``x``.

        Args:
          x: Input logits.
          mask: Optional mask forwarded positionally to the Flax module.
          bias: Optional bias forwarded positionally to the Flax module.
        """
        return self.fused_softmax(x, mask, bias)


class Linear(TransformerEngineBaseLayer):
    """Praxis wrapper around the Transformer Engine Flax ``DenseGeneral`` module.

    Projects the input to ``out_features``. The kernel is initialized from
    this layer's ``params_init`` and the bias from ``bias_init``, both
    converted with ``generate_params_init``.
    """

    out_features: int = 512
    kernel_axes: Tuple[str, ...] = ()
    use_bias: bool = True
    bias_init: WeightInit = field(  # pylint: disable=invalid-field-call
        default_factory=partial(WeightInit.Constant, scale=0.0)
    )
    bias_axes: Tuple[str, ...] = ()
    # Low-rank adaptation settings, forwarded unchanged to DenseGeneral.
    enable_low_rank_adaptation: bool = False
    low_rank_adaptation_dim: int = 32
    low_rank_adaptation_alpha: float = None
    # Input axis/axes contracted by the projection (forwarded as ``axis``).
    axis: Union[Iterable[int], int] = -1
    transpose_batch_sequence: bool = False

    def setup(self) -> None:
        """Build the wrapped Flax DenseGeneral and register it as ``linear``."""
        super().setup()

        dense_general_cls = partial(
            DenseGeneral,
            features=self.out_features,
            kernel_init=TransformerEngineBaseLayer.generate_params_init("kernel", self.params_init),
            kernel_axes=self.kernel_axes,
            use_bias=self.use_bias,
            bias_init=TransformerEngineBaseLayer.generate_params_init("bias", self.bias_init),
            bias_axes=self.bias_axes,
            enable_low_rank_adaptation=self.enable_low_rank_adaptation,
            low_rank_adaptation_dim=self.low_rank_adaptation_dim,
            low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
            axis=self.axis,
            dtype=self.dtype,
            transpose_batch_sequence=self.transpose_batch_sequence,
        )

        self.create_layer("linear", dense_general_cls)

    def __call__(self, x: JTensor) -> JTensor:
        """Apply the wrapped dense projection to ``x``."""
        return self.linear(x)


class LayerNormLinear(TransformerEngineBaseLayer):
    """Praxis wrapper around Transformer Engine's Flax ``LayerNormDenseGeneral``.

    Applies an optional normalization (``enable_layernorm``) followed by a
    dense projection to ``out_features``. All hyper-parameters are forwarded
    to the Flax module. Praxis ``WeightInit`` objects are converted with
    ``generate_params_init``, and the kernel uses this layer's
    ``params_init``.
    """

    out_features: int = 512
    enable_layernorm: bool = True
    # Forwarded as-is; presumably selects the normalization variant
    # (e.g. "layernorm" vs "rmsnorm") -- confirm against the Flax module.
    layernorm_type: str = "layernorm"
    epsilon: float = 1e-6
    zero_centered_gamma: bool = False
    # None is forwarded unchanged; otherwise wrapped as the "scale" initializer.
    scale_init: WeightInit = None
    scale_axes: Tuple[str, ...] = ()
    # NOTE(review): defaults to Constant(1.0), while LayerNorm.bias_init in this
    # module defaults to Constant(0.0). Kept for backward compatibility;
    # confirm a unit-initialized normalization bias is intended.
    ln_bias_init: WeightInit = field(  # pylint: disable=invalid-field-call
        default_factory=partial(WeightInit.Constant, scale=1.0)
    )
    ln_bias_axes: Tuple[str, ...] = ()
    kernel_axes: Tuple[str, ...] = ()
    use_bias: bool = False
    bias_init: WeightInit = field(  # pylint: disable=invalid-field-call
        default_factory=partial(WeightInit.Constant, scale=0.0)
    )
    bias_axes: Tuple[str, ...] = ()
    # Low-rank adaptation settings, forwarded unchanged.
    enable_low_rank_adaptation: bool = False
    low_rank_adaptation_dim: int = 32
    low_rank_adaptation_alpha: float = None
    return_layernorm_output: bool = True
    axis: Union[Iterable[int], int] = -1
    transpose_batch_sequence: bool = False
    depth_scaling: float = None

    def setup(self) -> None:
        """Build the wrapped Flax module and register it as ``ln_linear``."""
        super().setup()

        ln_dense_general_cls = partial(
            LayerNormDenseGeneral,
            features=self.out_features,
            enable_layernorm=self.enable_layernorm,
            layernorm_type=self.layernorm_type,
            epsilon=self.epsilon,
            zero_centered_gamma=self.zero_centered_gamma,
            scale_init=_generate_ln_scale_init(self.scale_init),
            scale_axes=self.scale_axes,
            ln_bias_init=TransformerEngineBaseLayer.generate_params_init(
                "ln_bias", self.ln_bias_init
            ),
            ln_bias_axes=self.ln_bias_axes,
            kernel_init=TransformerEngineBaseLayer.generate_params_init("kernel", self.params_init),
            kernel_axes=self.kernel_axes,
            use_bias=self.use_bias,
            bias_init=TransformerEngineBaseLayer.generate_params_init("bias", self.bias_init),
            bias_axes=self.bias_axes,
            enable_low_rank_adaptation=self.enable_low_rank_adaptation,
            low_rank_adaptation_dim=self.low_rank_adaptation_dim,
            low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
            return_layernorm_output=self.return_layernorm_output,
            axis=self.axis,
            dtype=self.dtype,
            transpose_batch_sequence=self.transpose_batch_sequence,
            depth_scaling=self.depth_scaling,
        )

        self.create_layer("ln_linear", ln_dense_general_cls)

    def __call__(self, x: JTensor) -> JTensor:
        """Apply normalization plus projection to ``x``.

        NOTE(review): annotated as returning a single ``JTensor``. With
        ``return_layernorm_output`` set, the Flax module may also return the
        normalized input; confirm the actual return structure.
        """
        return self.ln_linear(x)


class LayerNormMLP(TransformerEngineBaseLayer):
    """Praxis wrapper around Transformer Engine's Flax ``LayerNormMLP`` module.

    Applies an optional normalization followed by an MLP with hidden size
    ``intermediate_dim``. The ``*_1`` and ``*_2`` axis settings are forwarded
    for the MLP's two kernels and biases. Praxis ``WeightInit`` objects are
    converted with ``generate_params_init``, and both kernels share this
    layer's ``params_init``.
    """

    intermediate_dim: int = 2048
    enable_layernorm: bool = True
    # Forwarded as-is; presumably selects the normalization variant
    # (e.g. "layernorm" vs "rmsnorm") -- confirm against the Flax module.
    layernorm_type: str = "layernorm"
    epsilon: float = 1e-6
    zero_centered_gamma: bool = False
    # None is forwarded unchanged; otherwise wrapped as the "scale" initializer.
    scale_init: WeightInit = None
    scale_axes: Tuple[str, ...] = ()
    # NOTE(review): defaults to Constant(1.0), while LayerNorm.bias_init in this
    # module defaults to Constant(0.0). Kept for backward compatibility;
    # confirm a unit-initialized normalization bias is intended.
    ln_bias_init: WeightInit = field(  # pylint: disable=invalid-field-call
        default_factory=partial(WeightInit.Constant, scale=1.0)
    )
    ln_bias_axes: Tuple[str, ...] = ()
    kernel_axes_1: Tuple[str, ...] = ()
    kernel_axes_2: Tuple[str, ...] = ()
    use_bias: bool = False
    bias_init: WeightInit = field(  # pylint: disable=invalid-field-call
        default_factory=partial(WeightInit.Constant, scale=0.0)
    )
    bias_axes_1: Tuple[str, ...] = ()
    bias_axes_2: Tuple[str, ...] = ()
    # Low-rank adaptation settings, forwarded unchanged.
    enable_low_rank_adaptation: bool = False
    low_rank_adaptation_dim: int = 32
    low_rank_adaptation_alpha: float = None
    return_layernorm_output: bool = True
    # Activation names or callables forwarded to the Flax module.
    activations: Sequence[Union[str, Callable]] = ("relu",)
    intermediate_dropout_rate: float = 0.1
    intermediate_hidden_dropout_dims: Sequence[int] = ()
    axis: Union[Iterable[int], int] = -1
    transpose_batch_sequence: bool = False

    def setup(self) -> None:
        """Build the wrapped Flax LayerNormMLP and register it as ``ln_mlp``."""
        super().setup()

        ln_mlp_cls = partial(
            flax_LayerNormMLP,
            intermediate_dim=self.intermediate_dim,
            enable_layernorm=self.enable_layernorm,
            layernorm_type=self.layernorm_type,
            epsilon=self.epsilon,
            zero_centered_gamma=self.zero_centered_gamma,
            scale_init=_generate_ln_scale_init(self.scale_init),
            scale_axes=self.scale_axes,
            ln_bias_init=TransformerEngineBaseLayer.generate_params_init(
                "ln_bias", self.ln_bias_init
            ),
            ln_bias_axes=self.ln_bias_axes,
            kernel_init=TransformerEngineBaseLayer.generate_params_init("kernel", self.params_init),
            kernel_axes_1=self.kernel_axes_1,
            kernel_axes_2=self.kernel_axes_2,
            use_bias=self.use_bias,
            bias_init=TransformerEngineBaseLayer.generate_params_init("bias", self.bias_init),
            bias_axes_1=self.bias_axes_1,
            bias_axes_2=self.bias_axes_2,
            enable_low_rank_adaptation=self.enable_low_rank_adaptation,
            low_rank_adaptation_dim=self.low_rank_adaptation_dim,
            low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
            return_layernorm_output=self.return_layernorm_output,
            activations=self.activations,
            intermediate_dropout_rate=self.intermediate_dropout_rate,
            intermediate_hidden_dropout_dims=self.intermediate_hidden_dropout_dims,
            axis=self.axis,
            dtype=self.dtype,
            transpose_batch_sequence=self.transpose_batch_sequence,
        )

        self.create_layer("ln_mlp", ln_mlp_cls)

    def __call__(self, x: JTensor, deterministic: bool = False) -> JTensor:
        """Run the wrapped LayerNormMLP on ``x``.

        Args:
          x: Input tensor.
          deterministic: Forwarded positionally to the Flax module; presumably
            disables the intermediate dropout when True -- confirm.
        """
        return self.ln_mlp(x, deterministic)