test_praxis_layers.py 51.9 KB
Newer Older
1
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
4
#
# See LICENSE for license information.

5
import os
6
7
8
from functools import partial
from typing import Dict

9
import flax
10
11
12
13
14
15
import jax
import jax.numpy as jnp
from praxis import pax_fiddle
from praxis.base_layer import WeightInit, DEFAULT_INIT_MUTABLE_LIST
import pytest

zlsh80826's avatar
zlsh80826 committed
16
17
from utils import assert_allclose

18
from transformer_engine_jax import get_device_compute_capability
19
20
21
22
23
24
from transformer_engine.common.recipe import DelayedScaling, Format
from transformer_engine.jax import fp8_autocast, update_fp8_metas, update_collections
from transformer_engine.jax.flax import DenseGeneral, LayerNormDenseGeneral
from transformer_engine.jax.flax import LayerNorm as flax_LayerNorm
from transformer_engine.jax.flax import LayerNormMLP as flax_LayerNormMLP
from transformer_engine.jax.flax import MultiHeadAttention as flax_MultiHeadAttention
25
from transformer_engine.jax.flax import DotProductAttention as flax_DotProductAttention
26
27
28
29
30
from transformer_engine.jax.flax import RelativePositionBiases as flax_RelativePositionBiases
from transformer_engine.jax.flax import TransformerLayer as flax_TransformerLayer
from transformer_engine.jax.flax.module import Softmax
from transformer_engine.jax.fp8 import FP8Helper, is_fp8_available
from transformer_engine.jax.praxis import LayerNorm
zlsh80826's avatar
zlsh80826 committed
31
from transformer_engine.jax.praxis import FusedSoftmax
32
from transformer_engine.jax.praxis import LayerNormLinear, LayerNormMLP, Linear
33
34
from transformer_engine.jax.praxis import DotProductAttention, MultiHeadAttention
from transformer_engine.jax.praxis import RelativePositionBiases, TransformerEngineBaseLayer
zlsh80826's avatar
zlsh80826 committed
35
from transformer_engine.jax.praxis import TransformerLayer, TransformerLayerType
36
37
38
39
40
41
42
43
44
45
from transformer_engine.jax.softmax import SoftmaxType

# Probe FP8 support once; `reason` is reused as the skip message of the FP8-only tests.
(is_fp8_supported, reason) = is_fp8_available()

# Shared parametrization grids for the layer tests below.
DATA_SHAPE = [
    (128, 32, 512),
    (512, 32, 512),
]
DTYPE = [
    jnp.float32,
    jnp.bfloat16,
]
ENABLE_FP8 = [
    False,
    True,
]
FP8_FORMATS = [
    Format.E4M3,
    Format.HYBRID,
]


46
47
48
49
50
51
52
53
54
55
56
57
58
@pytest.fixture(autouse=True, scope='module')
def enable_fused_attn():
    """
    Enable fused attn for hopper+ arch.
    Fused attn kernels on pre-hopper arch are not deterministic.

    Whatever value NVTE_FUSED_ATTN had before this module ran (including
    "unset") is restored on teardown, so a setting coming from the caller's
    environment is not silently dropped for later test modules.
    """
    env_name = "NVTE_FUSED_ATTN"
    previous = os.environ.get(env_name)
    if get_device_compute_capability(0) >= 90:
        os.environ[env_name] = "1"
    try:
        yield
    finally:
        if previous is None:
            os.environ.pop(env_name, None)
        else:
            # Restore the pre-existing value instead of deleting it.
            os.environ[env_name] = previous


59
60
61
62
63
64
65
66
67
68
@pytest.fixture(autouse=True, scope='function')
def clear_live_arrays():
    """
    After each test, delete every JAX array that is still alive so that
    resources held by one test are not carried into the next.
    """
    yield
    remaining = jax.live_arrays()
    for buffer in remaining:
        buffer.delete()


69
70
71
def compare_dict(ref_fd, test_fd, rtol=1e-05, atol=1e-08):
    """
    Recursively assert that ``test_fd`` matches ``ref_fd``.

    Every key of ``ref_fd`` must exist in ``test_fd`` with a value of the same
    type. Nested mappings (any ``collections.abc.Mapping``, e.g. plain dicts or
    flax ``FrozenDict``) are compared recursively; every other leaf is compared
    with ``assert_allclose``. Keys that appear only in ``test_fd`` are not checked.

    Args:
        ref_fd: reference (possibly nested) mapping.
        test_fd: mapping under test.
        rtol: relative tolerance forwarded to ``assert_allclose``.
        atol: absolute tolerance forwarded to ``assert_allclose``.

    Raises:
        AssertionError: on a missing key, a type mismatch, or values that are not close.
    """
    # Mapping rather than dict so immutable mappings recurse instead of being
    # handed to assert_allclose as if they were arrays.
    from collections.abc import Mapping

    for key, ref_val in ref_fd.items():
        assert key in test_fd, \
            f"{key} not found in test dict {test_fd}"
        test_val = test_fd[key]
        assert isinstance(test_val, type(ref_val)), \
            f"The data type does not match between ref and test dict on {key=}: " \
            f"{type(ref_val)} vs {type(test_val)}"
        if isinstance(ref_val, Mapping):
            compare_dict(ref_val, test_val, rtol, atol)
        else:
            assert_allclose(ref_val,
                            test_val,
                            rtol=rtol,
                            atol=atol,
                            err_msg=f"{key=} is not close")


class TestLayer:
    """
    Base harness comparing a Praxis layer against its Flax counterpart.

    Subclasses supply ``input_getter`` (the test inputs), ``get_layer_name`` (the
    key of the wrapped child inside the Praxis variable tree) and
    ``generate_praxis_p_and_flax_cls`` (the paired layer constructors).
    ``forward_backward_runner`` then checks that loss, input gradient and weight
    gradients of both layers agree.
    """

    @staticmethod
    def loss(inner_variables, *inner_inputs, module, mean_out=True):
        """
        Apply ``module`` and return the mean of its output, or the raw output
        when ``mean_out`` is False. If ``module`` returns a tuple, only its first
        element is used.
        """
        outs = module.apply(inner_variables, *inner_inputs)
        out = outs
        if isinstance(outs, tuple):
            # The first place of outs is the real output, others
            # are auxiliary values.
            out = outs[0]
        return jnp.mean(out) if mean_out else out

    @staticmethod
    def loss_and_grads(module, variables, *inputs):
        """
        Return ``(loss, wgrads, dgrad)``: the mean loss, its gradient w.r.t.
        ``variables`` and its gradient w.r.t. the first input. Under FP8 the
        weight gradients are post-processed with ``update_fp8_metas``.
        """
        grad_fn = jax.value_and_grad(TestLayer.loss, argnums=(0, 1))
        loss_val, (wgrads, dgrad) = grad_fn(variables, *inputs, module=module)
        if FP8Helper.is_fp8_enabled():
            wgrads = update_fp8_metas(wgrads)
        return loss_val, wgrads, dgrad

    def input_getter(self, shape, dtype):
        """Return the inputs fed to both layers (subclass hook)."""
        raise NotImplementedError

    def get_layer_name(self):
        """Return the child key used to locate variables in the Praxis tree (subclass hook)."""
        raise NotImplementedError

    def generate_praxis_p_and_flax_cls(self, dtype, attrs):
        """Return ``(praxis_config, flax_constructor)`` for ``attrs`` (subclass hook)."""
        raise NotImplementedError

    def sync_variables(self, praxis_variables, flax_variables):
        """
        Copy the Flax ``params`` into the Praxis tree at
        ``['params'][layer_name]['cld']`` so both layers start from identical
        weights. The Praxis tree is updated in place and returned together with
        the unchanged Flax variables.
        """
        synced_praxis_variables = praxis_variables

        lyr_name = self.get_layer_name()

        if 'params' in flax_variables:
            synced_praxis_variables['params'][lyr_name]['cld'] = \
                flax.core.unfreeze(flax_variables['params'])

        return synced_praxis_variables, flax_variables

    def sync_wgrads(self, praxis_wgrads, flax_wgrads):
        """
        Re-root the Praxis gradients at ``[layer_name]['cld']`` (for ``params``
        and, under FP8, the FP8 collection) so their structure matches the Flax
        gradients; the Flax gradients are returned unfrozen.
        """
        synced_praxis_grads = praxis_wgrads

        lyr_name = self.get_layer_name()

        if 'params' in synced_praxis_grads:
            synced_praxis_grads['params'] = \
                synced_praxis_grads['params'][lyr_name]['cld']

        if FP8Helper.is_fp8_enabled():
            synced_praxis_grads[FP8Helper.FP8_COLLECTION_NAME] = \
                synced_praxis_grads[FP8Helper.FP8_COLLECTION_NAME][lyr_name]['cld']

        return synced_praxis_grads, flax.core.unfreeze(flax_wgrads)

    def forward_backward_runner(self,
                                data_shape,
                                dtype,
                                praxis_p,
                                flax_cls,
                                rtol=1e-05,
                                atol=1e-08):
        """
        Initialise both layers from the same key, sync their weights, and assert
        that loss, input gradient and weight gradients agree within tolerance.
        """
        init_key = jax.random.PRNGKey(seed=1234)

        test_inputs = self.input_getter(data_shape, dtype)

        praxis_layer = praxis_p.Instantiate()
        # This is a workaround to correctly enable FP8 meta generation for Praxis.
        # TODO (Ming Huang): To come out a better solution.
        mutable_list = DEFAULT_INIT_MUTABLE_LIST + [FP8Helper.FP8_COLLECTION_NAME]
        praxis_variables = praxis_layer.init(init_key, *test_inputs, mutable=mutable_list)

        flax_layer = flax_cls()
        flax_variables = flax_layer.init(init_key, *test_inputs)
        if "params_axes" in flax_variables:
            flax_variables, _ = flax.core.pop(flax_variables, "params_axes")
        if FP8Helper.is_fp8_enabled():
            flax_variables, _ = flax.core.pop(flax_variables,
                                              FP8Helper.FP8_COLLECTION_NAME + "_axes")

        praxis_variables, flax_variables = self.sync_variables(praxis_variables, flax_variables)

        # Under FP8, run several steps that feed the non-param gradients back
        # through update_collections so both layers evolve their FP8 state in
        # lockstep. Without FP8 nothing would be updated, so those steps would
        # only recompute values that are discarded; skip them.
        warmup_iters = 5 if FP8Helper.is_fp8_enabled() else 0

        for _ in range(warmup_iters):
            _, praxis_wgrads, _ = \
                TestLayer.loss_and_grads(praxis_layer, praxis_variables, *test_inputs)
            _, flax_wgrads, _ = \
                TestLayer.loss_and_grads(flax_layer, flax_variables, *test_inputs)
            praxis_wgrads.pop('params')
            praxis_variables = update_collections(praxis_wgrads, praxis_variables)
            flax_wgrads, _ = flax.core.pop(flax_wgrads, 'params')
            flax_variables = update_collections(flax_wgrads, flax_variables)

        praxis_loss, praxis_wgrads, praxis_dgrad = \
            TestLayer.loss_and_grads(praxis_layer, praxis_variables, *test_inputs)
        flax_loss, flax_wgrads, flax_dgrad = \
            TestLayer.loss_and_grads(flax_layer, flax_variables, *test_inputs)

        assert_allclose(praxis_loss, flax_loss, rtol=rtol, atol=atol)
        assert_allclose(praxis_dgrad, flax_dgrad, rtol=rtol, atol=atol)

        praxis_wgrads, flax_wgrads = self.sync_wgrads(praxis_wgrads, flax_wgrads)
        compare_dict(praxis_wgrads, flax_wgrads, rtol=rtol, atol=atol)


class LayerNormAttr:
    """Parametrization grid for TestLayerNorm: normalization type x zero-centred gamma."""
    LN_TYPE = 'layernorm_type'
    ZERO_CEN = 'zero_centered_gamma'
    ATTRS = [
        {LN_TYPE: "layernorm", ZERO_CEN: False},
        {LN_TYPE: "layernorm", ZERO_CEN: True},
        {LN_TYPE: "rmsnorm", ZERO_CEN: False},
    ]


class TestLayerNorm(TestLayer):
    """Parity tests: Praxis ``LayerNorm`` vs. Flax ``LayerNorm``."""

    def input_getter(self, shape, dtype):
        """Return a single standard-normal tensor drawn with a fixed seed."""
        key = jax.random.PRNGKey(seed=1234)
        inputs = jax.random.normal(key, shape, dtype)
        return (inputs,)

    def get_layer_name(self):
        return 'layer_norm'

    def generate_praxis_p_and_flax_cls(self, dtype, attrs):
        """Build a matching Praxis config and Flax constructor for ``attrs``."""
        # Arguments spelled identically by both layer types.
        shared = dict(
            layernorm_type=attrs[LayerNormAttr.LN_TYPE],
            zero_centered_gamma=attrs[LayerNormAttr.ZERO_CEN],
            scale_init=None,
            transpose_batch_sequence=False,
        )
        bias_init = WeightInit.Constant(0.0)

        praxis_p = pax_fiddle.Config(LayerNorm,
                                     name='layer_norm',
                                     dtype=dtype,
                                     bias_init=bias_init,
                                     **shared)
        # The Flax layer receives the initializer produced by generate_params_init
        # from the same Praxis WeightInit.
        flax_bias_init = TransformerEngineBaseLayer.generate_params_init("ln_bias", bias_init)
        flax_cls = partial(flax_LayerNorm, bias_init=flax_bias_init, dtype=dtype, **shared)

        return praxis_p, flax_cls

    @pytest.mark.parametrize('data_shape', DATA_SHAPE)
    @pytest.mark.parametrize('dtype', DTYPE)
    @pytest.mark.parametrize('attrs', LayerNormAttr.ATTRS)
    def test_forward_backward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
        """Compare forward and backward results of both implementations."""
        layers = self.generate_praxis_p_and_flax_cls(dtype, attrs)
        self.forward_backward_runner(data_shape, dtype, *layers, rtol, atol)


class FusedSoftmaxAttr:
    """Parametrization grid for TestFusedSoftmax: one entry per softmax variant."""
    SCALE_FACTOR = 'scale_factor'
    ST_TYPE = 'softmax_type'
    # NOTE(review): every entry uses scale_factor=0.0. If the factor multiplies
    # the logits, all variants reduce to a uniform softmax; confirm this is intended.
    ATTRS = [
        {SCALE_FACTOR: 0.0, ST_TYPE: SoftmaxType.SCALED},
        {SCALE_FACTOR: 0.0, ST_TYPE: SoftmaxType.SCALED_MASKED},
        {SCALE_FACTOR: 0.0, ST_TYPE: SoftmaxType.SCALED_UPPER_TRIANG_MASKED},
    ]


class TestFusedSoftmax(TestLayer):
    """Parity tests: Praxis ``FusedSoftmax`` vs. Flax ``Softmax``."""

    def input_getter(self, shape, dtype):
        """Return standard-normal logits and an all-ones uint8 mask of the same shape."""
        data_key = jax.random.PRNGKey(seed=1234)
        return jax.random.normal(data_key, shape, dtype), \
               jnp.ones(shape, dtype=jnp.uint8) # Masks

    def generate_praxis_p_and_flax_cls(self, dtype, attrs):
        """Build a matching Praxis config and Flax constructor for ``attrs``."""
        scale_factor = attrs[FusedSoftmaxAttr.SCALE_FACTOR]
        softmax_type = attrs[FusedSoftmaxAttr.ST_TYPE]

        praxis_p = pax_fiddle.Config(FusedSoftmax,
                                     name='fused_softmax',
                                     scale_factor=scale_factor,
                                     softmax_type=softmax_type)
        flax_cls = partial(Softmax, scale_factor=scale_factor, softmax_type=softmax_type)

        return praxis_p, flax_cls

    def sync_variables(self, praxis_variables, flax_variables):
        # Pass both trees through unchanged: there are no weights to copy here.
        return praxis_variables, flax_variables

    def sync_wgrads(self, praxis_wgrads, flax_wgrads):
        # Gradients are compared as-is; no re-rooting under a child layer.
        return praxis_wgrads, flax_wgrads

    @pytest.mark.parametrize('data_shape', [(32, 1, 128, 128), (32, 1, 512, 128)])
    @pytest.mark.parametrize('dtype', DTYPE)
    @pytest.mark.parametrize('attrs', FusedSoftmaxAttr.ATTRS)
    def test_forward_backward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
        """Compare forward and backward results; unsupported combinations are skipped."""
        if (attrs[FusedSoftmaxAttr.ST_TYPE] == SoftmaxType.SCALED_UPPER_TRIANG_MASKED) and \
            (data_shape[-2] != data_shape[-1]):
            # Report as skipped rather than silently counting as a pass.
            pytest.skip("SCALED_UPPER_TRIANG_MASKED is not supported when the "
                        "last two dimensions differ")
        praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
        self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol)


class LinearAttr:
    """Parametrization grid for TestLinear: output width x bias on/off."""
    FEATURE = 'features'
    USE_BIAS = 'use_bias'
    ATTRS = [
        {FEATURE: 512, USE_BIAS: False},
        {FEATURE: 512, USE_BIAS: True},
        {FEATURE: 1024, USE_BIAS: False},
        {FEATURE: 1024, USE_BIAS: True},
    ]


class TestLinear(TestLayer):
    """Parity tests: Praxis ``Linear`` vs. Flax ``DenseGeneral`` (plain and FP8)."""

    def input_getter(self, shape, dtype):
        """Return a single standard-normal tensor drawn with a fixed seed."""
        key = jax.random.PRNGKey(seed=1234)
        inputs = jax.random.normal(key, shape, dtype)
        return (inputs,)

    def get_layer_name(self):
        return 'linear'

    def generate_praxis_p_and_flax_cls(self, dtype, attrs):
        """Build a matching Praxis config and Flax constructor for ``attrs``."""
        out_features = attrs[LinearAttr.FEATURE]
        kernel_init = WeightInit.Gaussian(1.0)
        bias_init = WeightInit.Constant(0.0)
        # Arguments spelled identically by both layer types.
        shared = dict(
            use_bias=attrs[LinearAttr.USE_BIAS],
            axis=-1,
            transpose_batch_sequence=False,
        )

        praxis_p = pax_fiddle.Config(Linear,
                                     name='linear',
                                     dtype=dtype,
                                     out_features=out_features,
                                     params_init=kernel_init,
                                     bias_init=bias_init,
                                     **shared)
        flax_kernel_init = TransformerEngineBaseLayer.generate_params_init("kernel", kernel_init)
        flax_bias_init = TransformerEngineBaseLayer.generate_params_init("bias", bias_init)
        flax_cls = partial(DenseGeneral,
                           features=out_features,
                           kernel_init=flax_kernel_init,
                           bias_init=flax_bias_init,
                           dtype=dtype,
                           **shared)

        return praxis_p, flax_cls

    @pytest.mark.parametrize('data_shape', DATA_SHAPE)
    @pytest.mark.parametrize('dtype', DTYPE)
    @pytest.mark.parametrize('attrs', LinearAttr.ATTRS)
    def test_forward_backward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
        """Compare forward and backward results of both implementations."""
        layers = self.generate_praxis_p_and_flax_cls(dtype, attrs)
        self.forward_backward_runner(data_shape, dtype, *layers, rtol, atol)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize('data_shape', DATA_SHAPE)
    @pytest.mark.parametrize('dtype', DTYPE)
    @pytest.mark.parametrize('attrs', LinearAttr.ATTRS)
    @pytest.mark.parametrize('fp8_format', FP8_FORMATS)
    def test_forward_backward_fp8(self,
                                  data_shape,
                                  dtype,
                                  attrs,
                                  fp8_format,
                                  rtol=1e-05,
                                  atol=1e-08):
        """Same comparison, with both layers built and run under ``fp8_autocast``."""
        recipe = DelayedScaling(fp8_format=fp8_format)
        with fp8_autocast(enabled=True, fp8_recipe=recipe):
            layers = self.generate_praxis_p_and_flax_cls(dtype, attrs)
            self.forward_backward_runner(data_shape, dtype, *layers, rtol, atol)


class LayerNormLinearAttr:
    """
    Parametrization grid for TestLayerNormLinear.

    Every entry is a distinct configuration; repeating an entry would only
    re-run identical parametrized cases.
    """
    FEATURE = 'features'
    USE_BIAS = 'use_bias'
    ENABLE_LN = 'enable_layernorm'
    LN_TYPE = 'layernorm_type'
    ZERO_CEN = 'zero_centered_gamma'
    ATTRS = [{
        FEATURE: 512,
        USE_BIAS: True,
        ENABLE_LN: True,
        LN_TYPE: 'layernorm',
        ZERO_CEN: False
    }, {
        FEATURE: 512,
        USE_BIAS: True,
        ENABLE_LN: True,
        LN_TYPE: 'layernorm',
        ZERO_CEN: True
    }, {
        FEATURE: 512,
        USE_BIAS: True,
        ENABLE_LN: True,
        LN_TYPE: 'rmsnorm',
        ZERO_CEN: False
    }, {
        FEATURE: 512,
        USE_BIAS: True,
        ENABLE_LN: False,
        LN_TYPE: 'layernorm',
        ZERO_CEN: False
    }]


class TestLayerNormLinear(TestLayer):
    """Parity tests: Praxis ``LayerNormLinear`` vs. Flax ``LayerNormDenseGeneral`` (plain and FP8)."""

    def input_getter(self, shape, dtype):
        """Return a single standard-normal tensor drawn with a fixed seed."""
        key = jax.random.PRNGKey(seed=1234)
        inputs = jax.random.normal(key, shape, dtype)
        return (inputs,)

    def get_layer_name(self):
        return 'ln_linear'

    def generate_praxis_p_and_flax_cls(self, dtype, attrs):
        """Build a matching Praxis config and Flax constructor for ``attrs``."""
        out_features = attrs[LayerNormLinearAttr.FEATURE]
        # Arguments spelled identically by both layer types.
        shared = dict(
            enable_layernorm=attrs[LayerNormLinearAttr.ENABLE_LN],
            layernorm_type=attrs[LayerNormLinearAttr.LN_TYPE],
            zero_centered_gamma=attrs[LayerNormLinearAttr.ZERO_CEN],
            use_bias=attrs[LayerNormLinearAttr.USE_BIAS],
            axis=-1,
            transpose_batch_sequence=False,
        )
        kernel_init = WeightInit.Gaussian(1.0)
        bias_init = WeightInit.Constant(0.0)

        praxis_p = pax_fiddle.Config(LayerNormLinear,
                                     name='ln_linear',
                                     dtype=dtype,
                                     out_features=out_features,
                                     params_init=kernel_init,
                                     bias_init=bias_init,
                                     **shared)
        flax_kernel_init = TransformerEngineBaseLayer.generate_params_init("kernel", kernel_init)
        flax_bias_init = TransformerEngineBaseLayer.generate_params_init("bias", bias_init)
        flax_cls = partial(LayerNormDenseGeneral,
                           features=out_features,
                           kernel_init=flax_kernel_init,
                           bias_init=flax_bias_init,
                           dtype=dtype,
                           **shared)

        return praxis_p, flax_cls

    @pytest.mark.parametrize('data_shape', DATA_SHAPE)
    @pytest.mark.parametrize('dtype', DTYPE)
    @pytest.mark.parametrize('attrs', LayerNormLinearAttr.ATTRS)
    def test_forward_backward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
        """Compare forward and backward results of both implementations."""
        layers = self.generate_praxis_p_and_flax_cls(dtype, attrs)
        self.forward_backward_runner(data_shape, dtype, *layers, rtol, atol)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize('data_shape', DATA_SHAPE)
    @pytest.mark.parametrize('dtype', DTYPE)
    @pytest.mark.parametrize('attrs', LayerNormLinearAttr.ATTRS)
    @pytest.mark.parametrize('fp8_format', FP8_FORMATS)
    def test_forward_backward_fp8(self,
                                  data_shape,
                                  dtype,
                                  attrs,
                                  fp8_format,
                                  rtol=1e-05,
                                  atol=1e-08):
        """Same comparison, with both layers built and run under ``fp8_autocast``."""
        recipe = DelayedScaling(fp8_format=fp8_format)
        with fp8_autocast(enabled=True, fp8_recipe=recipe):
            layers = self.generate_praxis_p_and_flax_cls(dtype, attrs)
            self.forward_backward_runner(data_shape, dtype, *layers, rtol, atol)


class LayerNormMLPAttr:
    """Parametrization grid for TestLayerNormMLP: norm type, bias and activation combinations."""
    INTERMEDIATE_DIM = 'intermediate_dim'
    USE_BIAS = 'use_bias'
    ENABLE_LN = 'enable_layernorm'
    LN_TYPE = 'layernorm_type'
    ZERO_CEN = 'zero_centered_gamma'
    ACTIVATION = 'activations'
    ATTRS = [
        {INTERMEDIATE_DIM: 2048, USE_BIAS: True, ENABLE_LN: True,
         LN_TYPE: 'layernorm', ZERO_CEN: False, ACTIVATION: ('relu',)},
        {INTERMEDIATE_DIM: 2048, USE_BIAS: True, ENABLE_LN: True,
         LN_TYPE: 'layernorm', ZERO_CEN: True, ACTIVATION: ('relu',)},
        {INTERMEDIATE_DIM: 2048, USE_BIAS: True, ENABLE_LN: True,
         LN_TYPE: 'rmsnorm', ZERO_CEN: False, ACTIVATION: ('relu',)},
        {INTERMEDIATE_DIM: 2048, USE_BIAS: True, ENABLE_LN: True,
         LN_TYPE: 'rmsnorm', ZERO_CEN: False, ACTIVATION: ('gelu', 'linear')},
        {INTERMEDIATE_DIM: 2048, USE_BIAS: False, ENABLE_LN: True,
         LN_TYPE: 'rmsnorm', ZERO_CEN: False, ACTIVATION: ('gelu', 'linear')},
        {INTERMEDIATE_DIM: 2048, USE_BIAS: True, ENABLE_LN: True,
         LN_TYPE: 'rmsnorm', ZERO_CEN: False, ACTIVATION: ('silu', 'linear')},
        {INTERMEDIATE_DIM: 2048, USE_BIAS: False, ENABLE_LN: True,
         LN_TYPE: 'rmsnorm', ZERO_CEN: False, ACTIVATION: ('silu', 'linear')},
    ]


class TestLayerNormMLP(TestLayer):
    """Parity tests: Praxis ``LayerNormMLP`` vs. Flax ``LayerNormMLP`` (plain and FP8)."""

    def input_getter(self, shape, dtype):
        """Return a single standard-normal tensor drawn with a fixed seed."""
        key = jax.random.PRNGKey(seed=1234)
        inputs = jax.random.normal(key, shape, dtype)
        return (inputs,)

    def get_layer_name(self):
        return 'ln_mlp'

    def generate_praxis_p_and_flax_cls(self, dtype, attrs):
        """Build a matching Praxis config and Flax constructor for ``attrs``."""
        # Arguments spelled identically by both layer types; dropout is disabled
        # so both sides are deterministic.
        shared = dict(
            intermediate_dim=attrs[LayerNormMLPAttr.INTERMEDIATE_DIM],
            enable_layernorm=attrs[LayerNormMLPAttr.ENABLE_LN],
            layernorm_type=attrs[LayerNormMLPAttr.LN_TYPE],
            zero_centered_gamma=attrs[LayerNormMLPAttr.ZERO_CEN],
            use_bias=attrs[LayerNormMLPAttr.USE_BIAS],
            activations=attrs[LayerNormMLPAttr.ACTIVATION],
            intermediate_dropout_rate=0.0,
            axis=-1,
            transpose_batch_sequence=False,
        )
        kernel_init = WeightInit.Gaussian(1.0)
        bias_init = WeightInit.Constant(0.0)

        praxis_p = pax_fiddle.Config(LayerNormMLP,
                                     name='ln_mlp',
                                     dtype=dtype,
                                     params_init=kernel_init,
                                     bias_init=bias_init,
                                     **shared)
        flax_kernel_init = TransformerEngineBaseLayer.generate_params_init("kernel", kernel_init)
        flax_bias_init = TransformerEngineBaseLayer.generate_params_init("bias", bias_init)
        flax_cls = partial(flax_LayerNormMLP,
                           kernel_init=flax_kernel_init,
                           bias_init=flax_bias_init,
                           dtype=dtype,
                           **shared)

        return praxis_p, flax_cls

    @pytest.mark.parametrize('data_shape', DATA_SHAPE)
    @pytest.mark.parametrize('dtype', DTYPE)
    @pytest.mark.parametrize('attrs', LayerNormMLPAttr.ATTRS)
    def test_forward_backward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
        """Compare forward and backward results of both implementations."""
        layers = self.generate_praxis_p_and_flax_cls(dtype, attrs)
        self.forward_backward_runner(data_shape, dtype, *layers, rtol, atol)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize('data_shape', DATA_SHAPE)
    @pytest.mark.parametrize('dtype', DTYPE)
    @pytest.mark.parametrize('attrs', LayerNormMLPAttr.ATTRS)
    @pytest.mark.parametrize('fp8_format', FP8_FORMATS)
    def test_forward_backward_fp8(self,
                                  data_shape,
                                  dtype,
                                  attrs,
                                  fp8_format,
                                  rtol=1e-05,
                                  atol=1e-08):
        """Same comparison, with both layers built and run under ``fp8_autocast``."""
        recipe = DelayedScaling(fp8_format=fp8_format)
        with fp8_autocast(enabled=True, fp8_recipe=recipe):
            layers = self.generate_praxis_p_and_flax_cls(dtype, attrs)
            self.forward_backward_runner(data_shape, dtype, *layers, rtol, atol)


class TestRelativePositionBias(TestLayer):
    """Forward-only parity test: Praxis vs. Flax ``RelativePositionBiases``."""

    def get_layer_name(self):
        return 'relative_position_bias'

    def generate_praxis_p_and_flax_cls(self, dtype, attrs):
        """Build a matching Praxis config and Flax constructor (``attrs`` is unused)."""
        shared = dict(num_buckets=32, max_distance=128, num_attention_heads=64)
        # Gaussian init with stddev (num_attention_heads * num_buckets) ** -0.5.
        stddev = (shared['num_attention_heads'] * shared['num_buckets'])**-0.5
        embedding_init = WeightInit.Gaussian(stddev)

        praxis_p = pax_fiddle.Config(RelativePositionBiases,
                                     name='relative_position_bias',
                                     dtype=dtype,
                                     embedding_init=embedding_init,
                                     **shared)
        flax_embedding_init = TransformerEngineBaseLayer.generate_params_init(
            "rel_embedding", embedding_init)
        flax_cls = partial(flax_RelativePositionBiases,
                           embedding_init=flax_embedding_init,
                           dtype=dtype,
                           **shared)

        return praxis_p, flax_cls

    @pytest.mark.parametrize('data_shape', DATA_SHAPE)
    @pytest.mark.parametrize('dtype', DTYPE)
    @pytest.mark.parametrize('attrs', [{}])
    def test_forward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
        """Initialise both layers from the same key and compare their raw outputs."""
        praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)

        init_key = jax.random.PRNGKey(seed=1234)

        # Positional call arguments for both layers; presumably
        # (q_seqlen, k_seqlen, bidirectional) — confirm against the layer signature.
        call_args_list = [(128, 128, True), (128, 128, False)]
        for call_args in call_args_list:
            praxis_layer = praxis_p.Instantiate()
            praxis_variables = praxis_layer.init(init_key, *call_args)

            flax_layer = flax_cls()
            flax_variables = flax_layer.init(init_key, *call_args)
            if "params_axes" in flax_variables:
                flax_variables, _ = flax.core.pop(flax_variables, "params_axes")
            if FP8Helper.is_fp8_enabled():
                flax_variables, _ = flax.core.pop(flax_variables,
                                                  FP8Helper.FP8_COLLECTION_NAME + "_axes")

            praxis_variables, flax_variables = self.sync_variables(praxis_variables, flax_variables)

            praxis_out = TestLayer.loss(praxis_variables,
                                        *call_args,
                                        module=praxis_layer,
                                        mean_out=False)
            flax_out = TestLayer.loss(flax_variables, *call_args, module=flax_layer, mean_out=False)

            assert_allclose(praxis_out, flax_out, rtol=rtol, atol=atol)


706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
class DotProductAttnAttr:
    """Parametrization grid for the DotProductAttention tests."""
    ATTN_MASK_TYPE = 'attn_mask_type'
    NUM_GQA_GROUPS = 'num_gqa_groups'  # not set by any entry below
    TRANSPOSE_BS = 'transpose_batch_sequence'
    SCALE_FACTOR = 'scale_factor'
    ATTRS = [
        {ATTN_MASK_TYPE: 'padding', TRANSPOSE_BS: True, SCALE_FACTOR: 0.125},
        {ATTN_MASK_TYPE: 'padding_causal', TRANSPOSE_BS: True, SCALE_FACTOR: 0.125},
        {ATTN_MASK_TYPE: 'causal', TRANSPOSE_BS: True, SCALE_FACTOR: 0.125},
        {ATTN_MASK_TYPE: 'padding', TRANSPOSE_BS: False, SCALE_FACTOR: 0.125},
        {ATTN_MASK_TYPE: 'padding_causal', TRANSPOSE_BS: False, SCALE_FACTOR: 2.},
        {ATTN_MASK_TYPE: 'causal', TRANSPOSE_BS: False, SCALE_FACTOR: 1.},
        {ATTN_MASK_TYPE: 'no_mask', TRANSPOSE_BS: False, SCALE_FACTOR: 1.},
    ]


class TestDotProductAttn(TestLayer):
    """Compare the praxis DotProductAttention wrapper against the flax DotProductAttention."""

    def input_getter(self, shape, dtype):
        """Return ``[query, key, value, mask]`` inputs for one test case.

        ``query``, ``key`` and ``value`` are random normals of ``shape``/``dtype`` drawn from
        independent PRNG keys. The mask is all zeros with shape ``(batch, 1, seqlen, seqlen)``;
        batch and sequence sizes are read from the first two axes of ``shape`` and swapped when
        ``transpose_batch_sequence`` is set in ``self.attrs`` (the test sets ``self.attrs``
        before running).
        """
        key = jax.random.PRNGKey(seed=1234)
        q_key, k_key, v_key = jax.random.split(key, 3)
        b, s, *_ = shape
        if self.attrs[DotProductAttnAttr.TRANSPOSE_BS]:
            # With transposed layout the leading two axes are (seqlen, batch).
            b, s = s, b
        mask = jnp.zeros((b, 1, s, s), dtype=jnp.uint8)
        return [
            *map(partial(jax.random.normal, shape=shape, dtype=dtype), [q_key, k_key, v_key]), mask
        ]

    def get_layer_name(self):
        """Return the layer name requested by the ``TestLayer`` base class."""
        return 'dot_product_attn'

    def generate_praxis_p_and_flax_cls(self, dtype, attrs):
        """Build a praxis config and a flax module constructor with identical settings.

        Args:
            dtype: compute dtype passed to both layers.
            attrs: one entry of ``DotProductAttnAttr.ATTRS``. ``num_gqa_groups`` defaults to the
                number of attention heads and ``scale_factor`` to ``None`` when absent.

        Returns:
            ``(praxis_p, flax_cls)``.
        """
        head_dim = 64
        num_attention_heads = 16
        num_gqa_groups = attrs.get(DotProductAttnAttr.NUM_GQA_GROUPS, num_attention_heads)
        attn_mask_type = attrs[DotProductAttnAttr.ATTN_MASK_TYPE]
        transpose_batch_sequence = attrs[DotProductAttnAttr.TRANSPOSE_BS]
        # SCALE_FACTOR varies per case in ATTRS; forward it so those cases are actually exercised.
        scale_factor = attrs.get(DotProductAttnAttr.SCALE_FACTOR)

        praxis_p = pax_fiddle.Config(DotProductAttention,
                                     name='mha',
                                     dtype=dtype,
                                     head_dim=head_dim,
                                     num_attention_heads=num_attention_heads,
                                     num_gqa_groups=num_gqa_groups,
                                     attn_mask_type=attn_mask_type,
                                     scale_factor=scale_factor,
                                     transpose_batch_sequence=transpose_batch_sequence)
        flax_cls = partial(flax_DotProductAttention,
                           dtype=dtype,
                           head_dim=head_dim,
                           num_attention_heads=num_attention_heads,
                           num_gqa_groups=num_gqa_groups,
                           attn_mask_type=attn_mask_type,
                           scale_factor=scale_factor,
                           transpose_batch_sequence=transpose_batch_sequence)

        return praxis_p, flax_cls

    @pytest.mark.parametrize('data_shape', [(32, 128, 16, 64)])
    @pytest.mark.parametrize('dtype', DTYPE)
    @pytest.mark.parametrize('attrs', DotProductAttnAttr.ATTRS)
    def test_forward_backward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
        # input_getter reads self.attrs, so it must be set before running.
        self.attrs = attrs
        praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
        self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol)


792
793
794
class MultiHeadAttnAttr:
    """Attribute keys and parameter sets for the MultiHeadAttention comparison tests.

    Every dict in ``ATTRS`` is one case fed to ``TestMultiHeadAttn`` through
    ``pytest.mark.parametrize``. ``NUM_ATTN_HEADS``, ``NUM_GQA_GROUPS`` and ``LORA_SCOPE`` are
    optional keys that appear only in some cases.
    """
    USE_BIAS = 'use_bias'
    LN_TYPE = 'layernorm_type'
    ATTN_MASK_TYPE = 'attn_mask_type'
    ZERO_CEN = 'zero_centered_gamma'
    NUM_ATTN_HEADS = 'num_attention_heads'
    NUM_GQA_GROUPS = 'num_gqa_groups'
    ENABLE_ROPE = 'enable_rotary_pos_emb'
    ROPE_GROUP_METHOD = 'rotary_pos_emb_group_method'
    LORA_SCOPE = 'low_rank_adaptation_scope'
    ATTRS = [{
        USE_BIAS: True,
        LN_TYPE: 'layernorm',
        ZERO_CEN: False,
        ENABLE_ROPE: False,
        ROPE_GROUP_METHOD: 'consecutive',
        ATTN_MASK_TYPE: 'padding'
    }, {
        USE_BIAS: True,
        LN_TYPE: 'layernorm',
        ZERO_CEN: True,
        ENABLE_ROPE: False,
        ROPE_GROUP_METHOD: 'consecutive',
        ATTN_MASK_TYPE: 'padding'
    }, {
        USE_BIAS: True,
        LN_TYPE: 'rmsnorm',
        ZERO_CEN: False,
        ENABLE_ROPE: False,
        ROPE_GROUP_METHOD: 'consecutive',
        ATTN_MASK_TYPE: 'padding'
    }, {
        USE_BIAS: True,
        LN_TYPE: 'layernorm',
        ZERO_CEN: False,
        ENABLE_ROPE: False,
        ROPE_GROUP_METHOD: 'consecutive',
        ATTN_MASK_TYPE: 'causal'
    }, {
        USE_BIAS: True,
        LN_TYPE: 'layernorm',
        ZERO_CEN: True,
        ENABLE_ROPE: False,
        ROPE_GROUP_METHOD: 'consecutive',
        ATTN_MASK_TYPE: 'causal'
    }, {
        USE_BIAS: True,
        LN_TYPE: 'rmsnorm',
        ZERO_CEN: False,
        ENABLE_ROPE: False,
        ROPE_GROUP_METHOD: 'consecutive',
        ATTN_MASK_TYPE: 'causal'
    }, {
        # Grouped-query attention: 8 heads sharing 4 KV groups.
        USE_BIAS: True,
        LN_TYPE: 'rmsnorm',
        ZERO_CEN: False,
        ENABLE_ROPE: False,
        ROPE_GROUP_METHOD: 'consecutive',
        NUM_ATTN_HEADS: 8,
        NUM_GQA_GROUPS: 4,
        ATTN_MASK_TYPE: 'causal'
    }, {
        USE_BIAS: True,
        LN_TYPE: 'rmsnorm',
        ZERO_CEN: False,
        ENABLE_ROPE: True,
        ROPE_GROUP_METHOD: 'consecutive',
        NUM_ATTN_HEADS: 8,
        NUM_GQA_GROUPS: 4,
        ATTN_MASK_TYPE: 'causal'
    }, {
        USE_BIAS: True,
        LN_TYPE: 'rmsnorm',
        ZERO_CEN: False,
        ENABLE_ROPE: True,
        ROPE_GROUP_METHOD: 'alternate',
        NUM_ATTN_HEADS: 8,
        NUM_GQA_GROUPS: 4,
        ATTN_MASK_TYPE: 'causal'
    }, {
        USE_BIAS: True,
        LN_TYPE: 'layernorm',
        ZERO_CEN: False,
        ENABLE_ROPE: False,
        ROPE_GROUP_METHOD: 'consecutive',
        ATTN_MASK_TYPE: 'padding',
        LORA_SCOPE: 'all'
    }, {
        USE_BIAS: True,
        LN_TYPE: 'layernorm',
        ZERO_CEN: False,
        ENABLE_ROPE: False,
        ROPE_GROUP_METHOD: 'consecutive',
        ATTN_MASK_TYPE: 'causal',
        LORA_SCOPE: 'all'
    }]


class TestMultiHeadAttn(TestLayer):
    """Compare the praxis MultiHeadAttention wrapper against the flax MultiHeadAttention."""

    def input_getter(self, shape, dtype):
        """Return ``[query_input, kv_input, mask]`` inputs for one test case.

        Both attention inputs are random normals of ``shape``/``dtype`` from independent PRNG
        keys. The layers built below always use ``transpose_batch_sequence=True``, so the first
        two axes of ``shape`` are read as ``(seqlen, batch)`` when sizing the all-zero
        ``(batch, 1, seqlen, seqlen)`` mask.
        """
        key = jax.random.PRNGKey(seed=1234)
        q_key, kv_key = jax.random.split(key, 2)
        s, b, *_ = shape
        mask = jnp.zeros((b, 1, s, s), dtype=jnp.uint8)
        return [*map(partial(jax.random.normal, shape=shape, dtype=dtype), [q_key, kv_key]), mask]

    def get_layer_name(self):
        """Return the layer name requested by the ``TestLayer`` base class."""
        return 'multi_head_attn'

    def generate_praxis_p_and_flax_cls(self, dtype, attrs):
        """Build a praxis config and a flax module constructor with identical settings.

        Args:
            dtype: compute dtype passed to both layers.
            attrs: one entry of ``MultiHeadAttnAttr.ATTRS``. Optional keys default as follows:
                ``num_attention_heads`` -> 16, ``num_gqa_groups`` -> ``None``,
                ``low_rank_adaptation_scope`` -> ``'none'``.

        Returns:
            ``(praxis_p, flax_cls)``.
        """
        head_dim = 64
        # Honour the per-case head count (the GQA cases request 8 heads) instead of ignoring it.
        num_attention_heads = attrs.get(MultiHeadAttnAttr.NUM_ATTN_HEADS, 16)
        num_gqa_groups = attrs.get(MultiHeadAttnAttr.NUM_GQA_GROUPS)
        layernorm_type = attrs[MultiHeadAttnAttr.LN_TYPE]
        zero_centered_gamma = attrs[MultiHeadAttnAttr.ZERO_CEN]
        kernel_init = WeightInit.Gaussian(1.0)
        use_bias = attrs[MultiHeadAttnAttr.USE_BIAS]
        bias_init = WeightInit.Constant(0.0)
        input_layernorm = False
        return_layernorm_output = False
        attn_mask_type = attrs[MultiHeadAttnAttr.ATTN_MASK_TYPE]
        enable_rotary_pos_emb = attrs[MultiHeadAttnAttr.ENABLE_ROPE]
        rotary_pos_emb_group_method = attrs[MultiHeadAttnAttr.ROPE_GROUP_METHOD]
        low_rank_adaptation_scope = attrs.get(MultiHeadAttnAttr.LORA_SCOPE, 'none')
        fuse_qkv_params = True
        # input_getter assumes this layout when sizing the mask.
        transpose_batch_sequence = True
        scale_attn_logits = False
        scaled_query_init = True
        float32_logits = False

        praxis_p = pax_fiddle.Config(MultiHeadAttention,
                                     name='mha',
                                     dtype=dtype,
                                     head_dim=head_dim,
                                     num_attention_heads=num_attention_heads,
                                     num_gqa_groups=num_gqa_groups,
                                     layernorm_type=layernorm_type,
                                     zero_centered_gamma=zero_centered_gamma,
                                     params_init=kernel_init,
                                     use_bias=use_bias,
                                     bias_init=bias_init,
                                     return_layernorm_output=return_layernorm_output,
                                     input_layernorm=input_layernorm,
                                     attn_mask_type=attn_mask_type,
                                     enable_rotary_pos_emb=enable_rotary_pos_emb,
                                     rotary_pos_emb_group_method=rotary_pos_emb_group_method,
                                     low_rank_adaptation_scope=low_rank_adaptation_scope,
                                     fuse_qkv_params=fuse_qkv_params,
                                     transpose_batch_sequence=transpose_batch_sequence,
                                     scale_attn_logits=scale_attn_logits,
                                     scaled_query_init=scaled_query_init,
                                     float32_logits=float32_logits)
        flax_cls = partial(
            flax_MultiHeadAttention,
            dtype=dtype,
            head_dim=head_dim,
            num_attention_heads=num_attention_heads,
            num_gqa_groups=num_gqa_groups,
            layernorm_type=layernorm_type,
            zero_centered_gamma=zero_centered_gamma,
            kernel_init=TransformerEngineBaseLayer.generate_params_init("kernel", kernel_init),
            use_bias=use_bias,
            bias_init=TransformerEngineBaseLayer.generate_params_init("bias", bias_init),
            return_layernorm_output=return_layernorm_output,
            input_layernorm=input_layernorm,
            attn_mask_type=attn_mask_type,
            enable_rotary_pos_emb=enable_rotary_pos_emb,
            rotary_pos_emb_group_method=rotary_pos_emb_group_method,
            low_rank_adaptation_scope=low_rank_adaptation_scope,
            fuse_qkv_params=fuse_qkv_params,
            transpose_batch_sequence=transpose_batch_sequence,
            scale_attn_logits=scale_attn_logits,
            scaled_query_init=scaled_query_init,
            float32_logits=float32_logits)

        return praxis_p, flax_cls

    @pytest.mark.parametrize('data_shape', DATA_SHAPE)
    @pytest.mark.parametrize('dtype', DTYPE)
    @pytest.mark.parametrize('attrs', MultiHeadAttnAttr.ATTRS)
    def test_forward_backward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
        praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
        self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize('data_shape', DATA_SHAPE)
    @pytest.mark.parametrize('dtype', DTYPE)
    @pytest.mark.parametrize('attrs', MultiHeadAttnAttr.ATTRS)
    @pytest.mark.parametrize('fp8_format', FP8_FORMATS)
    def test_forward_backward_fp8(self,
                                  data_shape,
                                  dtype,
                                  attrs,
                                  fp8_format,
                                  rtol=1e-05,
                                  atol=1e-08):

        ds = DelayedScaling(fp8_format=fp8_format)
        with fp8_autocast(enabled=True, fp8_recipe=ds):
            praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
            self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol)


class TransformerLayerAttr:
    """Attribute keys and parameter sets for the TransformerLayer comparison tests.

    Every dict in ``ATTRS`` is one case fed to ``TestTransformer`` through
    ``pytest.mark.parametrize``. ``LORA_SCOPE`` is optional and appears only in some cases.
    """
    USE_BIAS = 'use_bias'
    LN_TYPE = 'layernorm_type'
    ACTIVATION = 'activations'
    LYR_TYPE = 'layer_type'
    ZERO_CEN = 'zero_centered_gamma'
    TRANSPOSE_BS = 'transpose_batch_sequence'
    ENABLE_ROPE = 'enable_rotary_pos_emb'
    ROPE_GROUP_METHOD = 'rotary_pos_emb_group_method'
    LORA_SCOPE = 'low_rank_adaptation_scope'
    ATTRS = [{
        USE_BIAS: True,
        LN_TYPE: 'layernorm',
        ZERO_CEN: False,
        ACTIVATION: ('relu',),
        LYR_TYPE: TransformerLayerType.ENCODER,
        ENABLE_ROPE: False,
        ROPE_GROUP_METHOD: 'consecutive',
        TRANSPOSE_BS: True
    }, {
        USE_BIAS: True,
        LN_TYPE: 'layernorm',
        ZERO_CEN: False,
        ACTIVATION: ('relu',),
        LYR_TYPE: TransformerLayerType.ENCODER,
        ENABLE_ROPE: False,
        ROPE_GROUP_METHOD: 'consecutive',
        TRANSPOSE_BS: False
    }, {
        USE_BIAS: True,
        LN_TYPE: 'layernorm',
        ZERO_CEN: True,
        ACTIVATION: ('relu',),
        LYR_TYPE: TransformerLayerType.ENCODER,
        ENABLE_ROPE: False,
        ROPE_GROUP_METHOD: 'consecutive',
        TRANSPOSE_BS: True
    }, {
        USE_BIAS: True,
        LN_TYPE: 'layernorm',
        ZERO_CEN: True,
        ACTIVATION: ('relu',),
        LYR_TYPE: TransformerLayerType.ENCODER,
        ENABLE_ROPE: False,
        ROPE_GROUP_METHOD: 'consecutive',
        TRANSPOSE_BS: False
    }, {
        USE_BIAS: True,
        LN_TYPE: 'rmsnorm',
        ZERO_CEN: False,
        ACTIVATION: ('relu',),
        LYR_TYPE: TransformerLayerType.ENCODER,
        ENABLE_ROPE: False,
        ROPE_GROUP_METHOD: 'consecutive',
        TRANSPOSE_BS: True
    }, {
        USE_BIAS: True,
        LN_TYPE: 'rmsnorm',
        ZERO_CEN: False,
        ACTIVATION: ('relu',),
        LYR_TYPE: TransformerLayerType.ENCODER,
        ENABLE_ROPE: False,
        ROPE_GROUP_METHOD: 'consecutive',
        TRANSPOSE_BS: False
    }, {
        USE_BIAS: True,
        LN_TYPE: 'layernorm',
        ZERO_CEN: True,
        ACTIVATION: ('relu',),
        LYR_TYPE: TransformerLayerType.DECODER,
        ENABLE_ROPE: False,
        ROPE_GROUP_METHOD: 'consecutive',
        TRANSPOSE_BS: True
    }, {
        USE_BIAS: True,
        LN_TYPE: 'layernorm',
        ZERO_CEN: True,
        ACTIVATION: ('relu',),
        LYR_TYPE: TransformerLayerType.DECODER,
        ENABLE_ROPE: False,
        ROPE_GROUP_METHOD: 'consecutive',
        TRANSPOSE_BS: False
    }, {
        USE_BIAS: True,
        LN_TYPE: 'layernorm',
        ZERO_CEN: False,
        ACTIVATION: ('relu',),
        LYR_TYPE: TransformerLayerType.DECODER,
        ENABLE_ROPE: False,
        ROPE_GROUP_METHOD: 'consecutive',
        TRANSPOSE_BS: True
    }, {
        USE_BIAS: True,
        LN_TYPE: 'layernorm',
        ZERO_CEN: False,
        ACTIVATION: ('relu',),
        LYR_TYPE: TransformerLayerType.DECODER,
        ENABLE_ROPE: False,
        ROPE_GROUP_METHOD: 'consecutive',
        TRANSPOSE_BS: False
    }, {
        USE_BIAS: True,
        LN_TYPE: 'rmsnorm',
        ZERO_CEN: False,
        ACTIVATION: ('relu',),
        LYR_TYPE: TransformerLayerType.DECODER,
        ENABLE_ROPE: False,
        ROPE_GROUP_METHOD: 'consecutive',
        TRANSPOSE_BS: True
    }, {
        USE_BIAS: True,
        LN_TYPE: 'rmsnorm',
        ZERO_CEN: False,
        ACTIVATION: ('relu',),
        LYR_TYPE: TransformerLayerType.DECODER,
        ENABLE_ROPE: False,
        ROPE_GROUP_METHOD: 'consecutive',
        TRANSPOSE_BS: False
    }, {
        USE_BIAS: True,
        LN_TYPE: 'layernorm',
        ZERO_CEN: False,
        ACTIVATION: ('gelu', 'linear'),
        LYR_TYPE: TransformerLayerType.ENCODER,
        ENABLE_ROPE: False,
        ROPE_GROUP_METHOD: 'consecutive',
        TRANSPOSE_BS: True
    }, {
        USE_BIAS: True,
        LN_TYPE: 'layernorm',
        ZERO_CEN: False,
        ACTIVATION: ('gelu', 'linear'),
        LYR_TYPE: TransformerLayerType.ENCODER,
        ENABLE_ROPE: False,
        ROPE_GROUP_METHOD: 'consecutive',
        TRANSPOSE_BS: False
    }, {
        USE_BIAS: True,
        LN_TYPE: 'rmsnorm',
        ZERO_CEN: False,
        ACTIVATION: ('gelu', 'linear'),
        LYR_TYPE: TransformerLayerType.ENCODER,
        ENABLE_ROPE: False,
        ROPE_GROUP_METHOD: 'consecutive',
        TRANSPOSE_BS: True
    }, {
        USE_BIAS: True,
        LN_TYPE: 'rmsnorm',
        ZERO_CEN: False,
        ACTIVATION: ('gelu', 'linear'),
        LYR_TYPE: TransformerLayerType.ENCODER,
        ENABLE_ROPE: False,
        ROPE_GROUP_METHOD: 'consecutive',
        TRANSPOSE_BS: False
    }, {
        USE_BIAS: True,
        LN_TYPE: 'layernorm',
        ZERO_CEN: False,
        ACTIVATION: ('gelu',),
        LYR_TYPE: TransformerLayerType.ENCODER,
        ENABLE_ROPE: False,
        ROPE_GROUP_METHOD: 'consecutive',
        TRANSPOSE_BS: False,
        LORA_SCOPE: 'all'
    }, {
        USE_BIAS: True,
        LN_TYPE: 'layernorm',
        ZERO_CEN: False,
        ACTIVATION: ('gelu', 'linear'),
        LYR_TYPE: TransformerLayerType.DECODER,
        ENABLE_ROPE: False,
        ROPE_GROUP_METHOD: 'consecutive',
        TRANSPOSE_BS: True
    }, {
        USE_BIAS: True,
        LN_TYPE: 'layernorm',
        ZERO_CEN: False,
        ACTIVATION: ('gelu', 'linear'),
        LYR_TYPE: TransformerLayerType.DECODER,
        ENABLE_ROPE: False,
        ROPE_GROUP_METHOD: 'consecutive',
        TRANSPOSE_BS: False
    }, {
        USE_BIAS: True,
        LN_TYPE: 'rmsnorm',
        ZERO_CEN: False,
        ACTIVATION: ('gelu', 'linear'),
        LYR_TYPE: TransformerLayerType.DECODER,
        ENABLE_ROPE: False,
        ROPE_GROUP_METHOD: 'consecutive',
        TRANSPOSE_BS: True
    }, {
        USE_BIAS: True,
        LN_TYPE: 'rmsnorm',
        ZERO_CEN: False,
        ACTIVATION: ('gelu', 'linear'),
        LYR_TYPE: TransformerLayerType.DECODER,
        ENABLE_ROPE: False,
        ROPE_GROUP_METHOD: 'consecutive',
        TRANSPOSE_BS: False
    }, {
        USE_BIAS: True,
        LN_TYPE: 'layernorm',
        ZERO_CEN: True,
        ACTIVATION: ('gelu',),
        LYR_TYPE: TransformerLayerType.ENCODER,
        ENABLE_ROPE: True,
        ROPE_GROUP_METHOD: 'alternate',
        TRANSPOSE_BS: False
    }, {
        USE_BIAS: True,
        LN_TYPE: 'layernorm',
        ZERO_CEN: True,
        ACTIVATION: ('gelu',),
        LYR_TYPE: TransformerLayerType.DECODER,
        ENABLE_ROPE: True,
        ROPE_GROUP_METHOD: 'alternate',
        TRANSPOSE_BS: False
    }, {
        USE_BIAS: True,
        LN_TYPE: 'layernorm',
        ZERO_CEN: True,
        ACTIVATION: ('gelu',),
        LYR_TYPE: TransformerLayerType.ENCODER,
        ENABLE_ROPE: True,
        ROPE_GROUP_METHOD: 'consecutive',
        TRANSPOSE_BS: False
    }, {
        USE_BIAS: True,
        LN_TYPE: 'layernorm',
        ZERO_CEN: True,
        ACTIVATION: ('gelu',),
        LYR_TYPE: TransformerLayerType.DECODER,
        ENABLE_ROPE: True,
        ROPE_GROUP_METHOD: 'consecutive',
        TRANSPOSE_BS: False
    }, {
        USE_BIAS: True,
        LN_TYPE: 'layernorm',
        ZERO_CEN: False,
        ACTIVATION: ('gelu',),
        LYR_TYPE: TransformerLayerType.DECODER,
        ENABLE_ROPE: False,
        ROPE_GROUP_METHOD: 'consecutive',
        TRANSPOSE_BS: False,
        LORA_SCOPE: 'all'
    }]


class TestTransformer(TestLayer):
    """Compare the praxis TransformerLayer wrapper against the flax TransformerLayer."""

    def input_getter(self, shape, dtype):
        """Return ``[x, y, mask, mask]`` inputs for one test case.

        ``x`` and ``y`` are random normals of ``shape``/``dtype`` from independent PRNG keys
        (presumably the layer input and the encoder output for decoder layers — confirm against
        the layer's call signature). The same all-zero ``(batch, 1, seqlen, seqlen)`` mask is
        passed twice. Batch and sequence sizes are read from the first two axes of ``shape`` and
        swapped when ``transpose_batch_sequence`` is set in ``self.attrs``; the tests set
        ``self.attrs`` before running.
        """
        key = jax.random.PRNGKey(seed=1234)
        q_key, kv_key = jax.random.split(key, 2)
        b, s, *_ = shape
        if self.attrs[TransformerLayerAttr.TRANSPOSE_BS]:
            # With transposed layout the leading two axes are (seqlen, batch).
            b, s = s, b
        mask = jnp.zeros((b, 1, s, s), dtype=jnp.uint8)
        return [
            *map(partial(jax.random.normal, shape=shape, dtype=dtype), [q_key, kv_key]), mask, mask
        ]

    def get_layer_name(self):
        """Return the layer name requested by the ``TestLayer`` base class."""
        return 'transformerlayer'

    def generate_praxis_p_and_flax_cls(self, dtype, attrs):
        """Build a praxis config and a flax module constructor with identical settings.

        Args:
            dtype: compute dtype passed to both layers.
            attrs: one entry of ``TransformerLayerAttr.ATTRS``. ``low_rank_adaptation_scope``
                defaults to ``'none'`` and ``zero_centered_gamma`` to ``False`` when absent.

        Returns:
            ``(praxis_p, flax_cls)``.
        """
        hidden_size = 512
        mlp_hidden_size = 2048
        num_attention_heads = 8
        layernorm_type = attrs[TransformerLayerAttr.LN_TYPE]
        # ZERO_CEN varies per case in ATTRS; forward it (as TestMultiHeadAttn does) so those cases
        # are not silent duplicates of their zero_centered_gamma=False twins.
        zero_centered_gamma = attrs.get(TransformerLayerAttr.ZERO_CEN, False)
        hidden_dropout = 0.0
        attention_dropout = 0.0
        intermediate_dropout = 0.0
        mlp_activations = attrs[TransformerLayerAttr.ACTIVATION]
        kernel_init = WeightInit.Gaussian(1.0)
        use_bias = attrs[TransformerLayerAttr.USE_BIAS]
        bias_init = WeightInit.Constant(0.0)
        layer_type = attrs[TransformerLayerAttr.LYR_TYPE]
        enable_rotary_pos_emb = attrs[TransformerLayerAttr.ENABLE_ROPE]
        rotary_pos_emb_group_method = attrs[TransformerLayerAttr.ROPE_GROUP_METHOD]
        low_rank_adaptation_scope = attrs.get(TransformerLayerAttr.LORA_SCOPE, 'none')
        enable_relative_embedding = True
        relative_embedding = pax_fiddle.Config(RelativePositionBiases,
                                               dtype=dtype,
                                               num_attention_heads=num_attention_heads)
        drop_path = 0.0
        transpose_batch_sequence = attrs[TransformerLayerAttr.TRANSPOSE_BS]

        # Build the flax relative-position module from the praxis config's own fields so both
        # sides use the same bucketing and embedding initialization.
        rel_embedding_init = RelativePositionBiases.generate_embedding_init(
            relative_embedding.embedding_init, relative_embedding.num_attention_heads,
            relative_embedding.num_buckets)

        relative_embedding_flax_module = flax_RelativePositionBiases(
            num_buckets=relative_embedding.num_buckets,
            max_distance=relative_embedding.max_distance,
            num_attention_heads=relative_embedding.num_attention_heads,
            embedding_init=TransformerEngineBaseLayer.generate_params_init(
                "rel_embedding", rel_embedding_init),
            embedding_axes=relative_embedding.embedding_axes,
            dtype=relative_embedding.dtype)

        praxis_p = pax_fiddle.Config(TransformerLayer,
                                     name='transformer_layer',
                                     params_init=kernel_init,
                                     dtype=dtype,
                                     hidden_size=hidden_size,
                                     mlp_hidden_size=mlp_hidden_size,
                                     num_attention_heads=num_attention_heads,
                                     layernorm_type=layernorm_type,
                                     zero_centered_gamma=zero_centered_gamma,
                                     hidden_dropout=hidden_dropout,
                                     attention_dropout=attention_dropout,
                                     intermediate_dropout=intermediate_dropout,
                                     mlp_activations=mlp_activations,
                                     use_bias=use_bias,
                                     bias_init=bias_init,
                                     layer_type=layer_type,
                                     enable_relative_embedding=enable_relative_embedding,
                                     enable_rotary_pos_emb=enable_rotary_pos_emb,
                                     rotary_pos_emb_group_method=rotary_pos_emb_group_method,
                                     low_rank_adaptation_scope=low_rank_adaptation_scope,
                                     relative_embedding=relative_embedding,
                                     drop_path=drop_path,
                                     transpose_batch_sequence=transpose_batch_sequence)
        flax_cls = partial(flax_TransformerLayer,
                           dtype=dtype,
                           hidden_size=hidden_size,
                           mlp_hidden_size=mlp_hidden_size,
                           num_attention_heads=num_attention_heads,
                           layernorm_type=layernorm_type,
                           zero_centered_gamma=zero_centered_gamma,
                           hidden_dropout=hidden_dropout,
                           attention_dropout=attention_dropout,
                           intermediate_dropout=intermediate_dropout,
                           mlp_activations=mlp_activations,
                           mha_kernel_init=TransformerEngineBaseLayer.generate_params_init(
                               "mha_kernel", kernel_init),
                           mlp_kernel_init=TransformerEngineBaseLayer.generate_params_init(
                               "mlp_kernel", kernel_init),
                           use_bias=use_bias,
                           bias_init=TransformerEngineBaseLayer.generate_params_init(
                               "bias", bias_init),
                           layer_type=layer_type,
                           enable_rotary_pos_emb=enable_rotary_pos_emb,
                           rotary_pos_emb_group_method=rotary_pos_emb_group_method,
                           enable_relative_embedding=enable_relative_embedding,
                           relative_embedding=relative_embedding_flax_module,
                           low_rank_adaptation_scope=low_rank_adaptation_scope,
                           drop_path=drop_path,
                           transpose_batch_sequence=transpose_batch_sequence)

        return praxis_p, flax_cls

    @pytest.mark.parametrize('data_shape', DATA_SHAPE)
    @pytest.mark.parametrize('dtype', DTYPE)
    @pytest.mark.parametrize('attrs', TransformerLayerAttr.ATTRS)
    def test_forward_backward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
        # input_getter reads self.attrs, so it must be set before running.
        self.attrs = attrs
        praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
        self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize('data_shape', DATA_SHAPE)
    @pytest.mark.parametrize('dtype', DTYPE)
    @pytest.mark.parametrize('attrs', TransformerLayerAttr.ATTRS)
    @pytest.mark.parametrize('fp8_format', FP8_FORMATS)
    def test_forward_backward_fp8(self,
                                  data_shape,
                                  dtype,
                                  attrs,
                                  fp8_format,
                                  rtol=1e-05,
                                  atol=1e-08):
        # input_getter reads self.attrs, so it must be set before running.
        self.attrs = attrs
        ds = DelayedScaling(fp8_format=fp8_format)
        with fp8_autocast(enabled=True, fp8_recipe=ds):
            praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
            self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol)