# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

5
import os
6
7
8
from functools import partial
from typing import Dict

9
import flax
10
11
12
13
14
15
import jax
import jax.numpy as jnp
from praxis import pax_fiddle
from praxis.base_layer import WeightInit, DEFAULT_INIT_MUTABLE_LIST
import pytest

zlsh80826's avatar
zlsh80826 committed
16
17
from utils import assert_allclose

18
from transformer_engine_jax import get_device_compute_capability
19
20
21
22
23
24
from transformer_engine.common.recipe import DelayedScaling, Format
from transformer_engine.jax import fp8_autocast, update_fp8_metas, update_collections
from transformer_engine.jax.flax import DenseGeneral, LayerNormDenseGeneral
from transformer_engine.jax.flax import LayerNorm as flax_LayerNorm
from transformer_engine.jax.flax import LayerNormMLP as flax_LayerNormMLP
from transformer_engine.jax.flax import MultiHeadAttention as flax_MultiHeadAttention
25
from transformer_engine.jax.flax import DotProductAttention as flax_DotProductAttention
26
27
28
29
30
from transformer_engine.jax.flax import RelativePositionBiases as flax_RelativePositionBiases
from transformer_engine.jax.flax import TransformerLayer as flax_TransformerLayer
from transformer_engine.jax.flax.module import Softmax
from transformer_engine.jax.fp8 import FP8Helper, is_fp8_available
from transformer_engine.jax.praxis import LayerNorm
zlsh80826's avatar
zlsh80826 committed
31
from transformer_engine.jax.praxis import FusedSoftmax
32
from transformer_engine.jax.praxis import LayerNormLinear, LayerNormMLP, Linear
33
34
from transformer_engine.jax.praxis import DotProductAttention, MultiHeadAttention
from transformer_engine.jax.praxis import RelativePositionBiases, TransformerEngineBaseLayer
zlsh80826's avatar
zlsh80826 committed
35
from transformer_engine.jax.praxis import TransformerLayer, TransformerLayerType
36
37
38
39
from transformer_engine.jax.softmax import SoftmaxType

# Queried once at import; `reason` is reused as the skip message of the FP8 tests.
is_fp8_supported, reason = is_fp8_available()

DATA_SHAPE = [(32, 128, 512), (32, 512, 512)]    # (B, S, H)
DTYPE = [jnp.float32, jnp.bfloat16]
ENABLE_FP8 = [False, True]
FP8_FORMATS = [Format.E4M3, Format.HYBRID]


@pytest.fixture(autouse=True, scope='module')
def enable_fused_attn():
    """
    Enable fused attn for hopper+ arch.
    Fused attn kernels on pre-hopper arch are not deterministic.

    Whatever value NVTE_FUSED_ATTN had before the module started (including
    "unset") is restored on teardown, so a user-provided setting is not lost.
    """
    prev_value = os.environ.get("NVTE_FUSED_ATTN")
    if get_device_compute_capability(0) >= 90:
        os.environ["NVTE_FUSED_ATTN"] = "1"
    yield
    if prev_value is None:
        os.environ.pop("NVTE_FUSED_ATTN", None)
    else:
        os.environ["NVTE_FUSED_ATTN"] = prev_value


def compare_dict(ref_fd, test_fd, rtol=1e-05, atol=1e-08):
    """Recursively assert that ``test_fd`` matches ``ref_fd``.

    Every key of ``ref_fd`` must exist in ``test_fd`` with a value whose type
    is an instance of the reference value's type. Nested mappings are compared
    recursively; all other values are compared with ``assert_allclose``. Keys
    present only in ``test_fd`` are not checked.

    Args:
        ref_fd: Reference (possibly nested) mapping.
        test_fd: Mapping under test.
        rtol: Relative tolerance forwarded to ``assert_allclose``.
        atol: Absolute tolerance forwarded to ``assert_allclose``.

    Raises:
        AssertionError: On a missing key, a type mismatch, or values that are
            not close.
    """
    # ``Mapping`` (rather than ``dict``) also recognizes immutable mappings such
    # as flax's FrozenDict, which would otherwise be handed to assert_allclose.
    from collections.abc import Mapping

    for key in ref_fd:
        assert key in test_fd, \
            f"{key} not found in test dict {test_fd}"
        assert isinstance(test_fd[key], type(ref_fd[key])), \
            f"The data type is not match between ref and test " \
            f" Dict on {key=}"
        if isinstance(ref_fd[key], Mapping):
            compare_dict(ref_fd[key], test_fd[key], rtol, atol)
        else:
            assert_allclose(ref_fd[key],
                            test_fd[key],
                            rtol=rtol,
                            atol=atol,
                            err_msg=f"{key=} is not close")


class TestLayer:
    """Shared harness comparing a Praxis layer against its Flax counterpart.

    Subclasses provide ``input_getter``, ``get_layer_name`` and
    ``generate_praxis_p_and_flax_cls``. ``forward_backward_runner`` then
    initializes both layers with the same PRNG key and inputs, copies the Flax
    parameters into the Praxis variable tree, and asserts that the loss, the
    gradient w.r.t. the first input, and the weight gradients all agree.
    """

    @staticmethod
    def loss(inner_variables, *inner_inputs, module, mean_out=True):
        """Apply ``module`` and return the mean of its primary output.

        When ``module.apply`` returns a tuple, only its first element is used.
        With ``mean_out=False`` the primary output is returned unreduced.
        """
        outs = module.apply(inner_variables, *inner_inputs)
        out = outs
        if isinstance(outs, tuple):
            # The first place of outs is the real output, others
            # are auxiliary values.
            out = outs[0]
        return jnp.mean(out) if mean_out else out

    @staticmethod
    def loss_and_grads(module, variables, *inputs):
        """Return ``(loss, wgrads, dgrad)`` for ``module``.

        ``wgrads`` is the gradient w.r.t. ``variables`` and ``dgrad`` the
        gradient w.r.t. the first positional input only. Under FP8,
        ``wgrads`` is post-processed with ``update_fp8_metas``.
        """
        grad_fn = jax.value_and_grad(TestLayer.loss, argnums=(0, 1))
        loss_val, (wgrads, dgrad) = grad_fn(variables, *inputs, module=module)
        if FP8Helper.is_fp8_enabled():
            wgrads = update_fp8_metas(wgrads)
        return loss_val, wgrads, dgrad

    def input_getter(self, shape, dtype):
        """Return the positional inputs fed to both layers."""
        raise NotImplementedError

    def get_layer_name(self):
        """Return the key under which the Praxis variable tree nests the layer."""
        raise NotImplementedError

    def generate_praxis_p_and_flax_cls(self, dtype, attrs):
        """Return ``(praxis_config, flax_constructor)`` built from ``attrs``."""
        raise NotImplementedError

    def sync_variables(self, praxis_variables, flax_variables):
        """Copy the Flax parameters into the Praxis variable tree.

        The Flax ``params`` are written to ``params[<layer name>]['cld']`` of
        ``praxis_variables``, which is modified in place and returned together
        with the untouched ``flax_variables``.
        """
        synced_praxis_variables = praxis_variables

        lyr_name = self.get_layer_name()

        if 'params' in flax_variables:
            synced_praxis_variables['params'][lyr_name]['cld'] = \
                flax.core.unfreeze(flax_variables['params'])

        return synced_praxis_variables, flax_variables

    def sync_wgrads(self, praxis_wgrads, flax_wgrads):
        """Strip the Praxis ``[<layer name>]['cld']`` nesting from the gradients.

        Applied to the ``params`` collection and, under FP8, to the FP8 meta
        collection. ``praxis_wgrads`` is modified in place; ``flax_wgrads`` is
        returned unfrozen.
        """
        synced_praxis_grads = praxis_wgrads

        lyr_name = self.get_layer_name()

        if 'params' in synced_praxis_grads:
            synced_praxis_grads['params'] = \
                synced_praxis_grads['params'][lyr_name]['cld']

        if FP8Helper.is_fp8_enabled():
            synced_praxis_grads[FP8Helper.FP8_COLLECTION_NAME] = \
                synced_praxis_grads[FP8Helper.FP8_COLLECTION_NAME][lyr_name]['cld']

        return synced_praxis_grads, flax.core.unfreeze(flax_wgrads)

    def forward_backward_runner(self,
                                data_shape,
                                dtype,
                                praxis_p,
                                flax_cls,
                                rtol=1e-05,
                                atol=1e-08):
        """Assert that the Praxis and Flax layers agree on loss and gradients.

        Under FP8, a few warm-up steps first feed the non-``params`` gradient
        collections back into both variable trees via ``update_collections``.
        """
        init_key = jax.random.PRNGKey(seed=1234)

        test_inputs = self.input_getter(data_shape, dtype)

        praxis_layer = praxis_p.Instantiate()
        # This is a workaround to correctly enable FP8 meta generation for Praxis.
        # TODO (Ming Huang): To come out a better solution.
        mutable_list = DEFAULT_INIT_MUTABLE_LIST + [FP8Helper.FP8_COLLECTION_NAME]
        praxis_variables = praxis_layer.init(init_key, *test_inputs, mutable=mutable_list)

        flax_layer = flax_cls()
        flax_variables = flax_layer.init(init_key, *test_inputs)
        # Axis-annotation collections are dropped before comparing with Praxis.
        if "params_axes" in flax_variables:
            flax_variables, _ = flax.core.pop(flax_variables, "params_axes")
        if FP8Helper.is_fp8_enabled():
            flax_variables, _ = flax.core.pop(flax_variables,
                                              FP8Helper.FP8_COLLECTION_NAME + "_axes")

        praxis_variables, flax_variables = self.sync_variables(praxis_variables, flax_variables)

        # Warm-up steps only matter under FP8, where they update the collections
        # fed back below. Without FP8 a pass here would just be recomputed and
        # discarded, so none is run.
        iter_times = 5 if FP8Helper.is_fp8_enabled() else 0

        for _step in range(iter_times):
            _, praxis_wgrads, _ = \
                TestLayer.loss_and_grads(praxis_layer, praxis_variables, *test_inputs)
            _, flax_wgrads, _ = \
                TestLayer.loss_and_grads(flax_layer, flax_variables, *test_inputs)
            praxis_wgrads.pop('params')
            praxis_variables = update_collections(praxis_wgrads, praxis_variables)
            flax_wgrads, _ = flax.core.pop(flax_wgrads, 'params')
            flax_variables = update_collections(flax_wgrads, flax_variables)

        praxis_loss, praxis_wgrads, praxis_dgrad = \
            TestLayer.loss_and_grads(praxis_layer, praxis_variables, *test_inputs)
        flax_loss, flax_wgrads, flax_dgrad = \
            TestLayer.loss_and_grads(flax_layer, flax_variables, *test_inputs)

        assert_allclose(praxis_loss, flax_loss, rtol=rtol, atol=atol)
        assert_allclose(praxis_dgrad, flax_dgrad, rtol=rtol, atol=atol)

        praxis_wgrads, flax_wgrads = self.sync_wgrads(praxis_wgrads, flax_wgrads)
        compare_dict(praxis_wgrads, flax_wgrads, rtol=rtol, atol=atol)


class LayerNormAttr:
    """Parametrization table for the LayerNorm comparison tests."""
    LN_TYPE = 'layernorm_type'
    ZERO_CEN = 'zero_centered_gamma'
    # (normalization flavour, zero-centered gamma) combinations under test.
    ATTRS = [
        {LN_TYPE: "layernorm", ZERO_CEN: False},
        {LN_TYPE: "layernorm", ZERO_CEN: True},
        {LN_TYPE: "rmsnorm", ZERO_CEN: False},
    ]


class TestLayerNorm(TestLayer):
    """Checks Praxis ``LayerNorm`` against Flax ``LayerNorm``."""

    def input_getter(self, shape, dtype):
        """Return a one-element tuple holding a normal random tensor of ``shape``."""
        rng = jax.random.PRNGKey(seed=1234)
        sample = jax.random.normal(rng, shape, dtype)
        return (sample,)

    def get_layer_name(self):
        return 'layer_norm'

    def generate_praxis_p_and_flax_cls(self, dtype, attrs):
        """Return a Praxis config and a Flax constructor configured from ``attrs``."""
        # Options accepted under identical names by both implementations.
        shared = {
            'layernorm_type': attrs[LayerNormAttr.LN_TYPE],
            'zero_centered_gamma': attrs[LayerNormAttr.ZERO_CEN],
            'scale_init': None,
            'transpose_batch_sequence': False,
        }
        beta_init = WeightInit.Constant(0.0)

        praxis_p = pax_fiddle.Config(LayerNorm,
                                     name='layer_norm',
                                     dtype=dtype,
                                     bias_init=beta_init,
                                     **shared)
        # The Flax side receives the initializer produced by generate_params_init.
        flax_cls = partial(flax_LayerNorm,
                           bias_init=TransformerEngineBaseLayer.generate_params_init(
                               "ln_bias", beta_init),
                           dtype=dtype,
                           **shared)
        return praxis_p, flax_cls

    @pytest.mark.parametrize('data_shape', DATA_SHAPE)
    @pytest.mark.parametrize('dtype', DTYPE)
    @pytest.mark.parametrize('attrs', LayerNormAttr.ATTRS)
    def test_forward_backward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
        praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
        self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls,
                                     rtol=rtol, atol=atol)


class FusedSoftmaxAttr:
    """Parametrization table for the FusedSoftmax comparison tests."""
    SCALE_FACTOR = 'scale_factor'
    ST_TYPE = 'softmax_type'
    # One entry per softmax variant, all sharing the same scale factor.
    ATTRS = [
        {SCALE_FACTOR: 0.0, ST_TYPE: SoftmaxType.SCALED},
        {SCALE_FACTOR: 0.0, ST_TYPE: SoftmaxType.SCALED_MASKED},
        {SCALE_FACTOR: 0.0, ST_TYPE: SoftmaxType.SCALED_UPPER_TRIANG_MASKED},
    ]


class TestFusedSoftmax(TestLayer):
    """Checks Praxis ``FusedSoftmax`` against the Flax ``Softmax`` module."""

    def input_getter(self, shape, dtype):
        """Return ``(logits, mask)``; the mask is an all-ones uint8 tensor of ``shape``."""
        data_key = jax.random.PRNGKey(seed=1234)
        return jax.random.normal(data_key, shape, dtype), \
               jnp.ones(shape, dtype=jnp.uint8) # Masks

    def generate_praxis_p_and_flax_cls(self, dtype, attrs):
        """Return a Praxis config and a Flax constructor configured from ``attrs``."""
        scale_factor = attrs[FusedSoftmaxAttr.SCALE_FACTOR]
        softmax_type = attrs[FusedSoftmaxAttr.ST_TYPE]

        praxis_p = pax_fiddle.Config(FusedSoftmax,
                                     name='fused_softmax',
                                     scale_factor=scale_factor,
                                     softmax_type=softmax_type)
        flax_cls = partial(Softmax, scale_factor=scale_factor, softmax_type=softmax_type)

        return praxis_p, flax_cls

    def sync_variables(self, praxis_variables, flax_variables):
        # Nothing is copied across: both trees are returned unchanged.
        return praxis_variables, flax_variables

    def sync_wgrads(self, praxis_wgrads, flax_wgrads):
        # Gradients are compared as-is, without re-rooting.
        return praxis_wgrads, flax_wgrads

    @pytest.mark.parametrize('data_shape', [(32, 1, 128, 128), (32, 1, 512, 128)])
    @pytest.mark.parametrize('dtype', DTYPE)
    @pytest.mark.parametrize('attrs', FusedSoftmaxAttr.ATTRS)
    def test_forward_backward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
        # Report the unsupported combination as SKIPPED rather than silently
        # passing without exercising anything.
        if (attrs[FusedSoftmaxAttr.ST_TYPE] == SoftmaxType.SCALED_UPPER_TRIANG_MASKED) and \
                (data_shape[-2] != data_shape[-1]):
            pytest.skip("SCALED_UPPER_TRIANG_MASKED is not supported when "
                        "data_shape[-2] != data_shape[-1]")
        praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
        self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol)


class LinearAttr:
    """Parametrization table for the Linear / DenseGeneral comparison tests."""
    FEATURE = 'features'
    USE_BIAS = 'use_bias'
    # Every (output width, bias) combination, width-major.
    ATTRS = [
        {FEATURE: 512, USE_BIAS: False},
        {FEATURE: 512, USE_BIAS: True},
        {FEATURE: 1024, USE_BIAS: False},
        {FEATURE: 1024, USE_BIAS: True},
    ]


class TestLinear(TestLayer):
    """Checks Praxis ``Linear`` against Flax ``DenseGeneral``."""

    def input_getter(self, shape, dtype):
        """Return a one-element tuple holding a normal random tensor of ``shape``."""
        rng = jax.random.PRNGKey(seed=1234)
        sample = jax.random.normal(rng, shape, dtype)
        return (sample,)

    def get_layer_name(self):
        return 'linear'

    def generate_praxis_p_and_flax_cls(self, dtype, attrs):
        """Return a Praxis config and a Flax constructor configured from ``attrs``."""
        width = attrs[LinearAttr.FEATURE]
        # Options accepted under identical names by both implementations.
        shared = {
            'use_bias': attrs[LinearAttr.USE_BIAS],
            'axis': -1,
            'transpose_batch_sequence': False,
        }
        w_init = WeightInit.Gaussian(1.0)
        b_init = WeightInit.Constant(0.0)

        # Praxis calls the output width `out_features` and takes WeightInits directly.
        praxis_p = pax_fiddle.Config(Linear,
                                     name='linear',
                                     dtype=dtype,
                                     out_features=width,
                                     params_init=w_init,
                                     bias_init=b_init,
                                     **shared)
        # Flax calls it `features`; initializers come from generate_params_init.
        flax_cls = partial(
            DenseGeneral,
            features=width,
            kernel_init=TransformerEngineBaseLayer.generate_params_init("kernel", w_init),
            bias_init=TransformerEngineBaseLayer.generate_params_init("bias", b_init),
            dtype=dtype,
            **shared)
        return praxis_p, flax_cls

    @pytest.mark.parametrize('data_shape', DATA_SHAPE)
    @pytest.mark.parametrize('dtype', DTYPE)
    @pytest.mark.parametrize('attrs', LinearAttr.ATTRS)
    def test_forward_backward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
        praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
        self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls,
                                     rtol=rtol, atol=atol)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize('data_shape', DATA_SHAPE)
    @pytest.mark.parametrize('dtype', DTYPE)
    @pytest.mark.parametrize('attrs', LinearAttr.ATTRS)
    @pytest.mark.parametrize('fp8_format', FP8_FORMATS)
    def test_forward_backward_fp8(self,
                                  data_shape,
                                  dtype,
                                  attrs,
                                  fp8_format,
                                  rtol=1e-05,
                                  atol=1e-08):
        recipe = DelayedScaling(fp8_format=fp8_format)
        # Both layers are built and run inside the FP8 autocast scope.
        with fp8_autocast(enabled=True, fp8_recipe=recipe):
            praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
            self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls,
                                         rtol=rtol, atol=atol)


class LayerNormLinearAttr:
    """Parametrization table for the LayerNormLinear comparison tests.

    Each configuration appears exactly once, so no identical test case is
    generated twice by ``pytest.mark.parametrize``.
    """
    FEATURE = 'features'
    USE_BIAS = 'use_bias'
    ENABLE_LN = 'enable_layernorm'
    LN_TYPE = 'layernorm_type'
    ZERO_CEN = 'zero_centered_gamma'
    ATTRS = [{
        FEATURE: 512,
        USE_BIAS: True,
        ENABLE_LN: True,
        LN_TYPE: 'layernorm',
        ZERO_CEN: False
    }, {
        FEATURE: 512,
        USE_BIAS: True,
        ENABLE_LN: True,
        LN_TYPE: 'layernorm',
        ZERO_CEN: True
    }, {
        FEATURE: 512,
        USE_BIAS: True,
        ENABLE_LN: True,
        LN_TYPE: 'rmsnorm',
        ZERO_CEN: False
    }, {
        FEATURE: 512,
        USE_BIAS: True,
        ENABLE_LN: False,
        LN_TYPE: 'layernorm',
        ZERO_CEN: False
    }]


class TestLayerNormLinear(TestLayer):
    """Checks Praxis ``LayerNormLinear`` against Flax ``LayerNormDenseGeneral``."""

    def input_getter(self, shape, dtype):
        """Return a one-element tuple holding a normal random tensor of ``shape``."""
        rng = jax.random.PRNGKey(seed=1234)
        sample = jax.random.normal(rng, shape, dtype)
        return (sample,)

    def get_layer_name(self):
        return 'ln_linear'

    def generate_praxis_p_and_flax_cls(self, dtype, attrs):
        """Return a Praxis config and a Flax constructor configured from ``attrs``."""
        width = attrs[LayerNormLinearAttr.FEATURE]
        # Options accepted under identical names by both implementations.
        shared = {
            'enable_layernorm': attrs[LayerNormLinearAttr.ENABLE_LN],
            'layernorm_type': attrs[LayerNormLinearAttr.LN_TYPE],
            'zero_centered_gamma': attrs[LayerNormLinearAttr.ZERO_CEN],
            'use_bias': attrs[LayerNormLinearAttr.USE_BIAS],
            'axis': -1,
            'transpose_batch_sequence': False,
        }
        w_init = WeightInit.Gaussian(1.0)
        b_init = WeightInit.Constant(0.0)

        # Praxis calls the output width `out_features` and takes WeightInits directly.
        praxis_p = pax_fiddle.Config(LayerNormLinear,
                                     name='ln_linear',
                                     dtype=dtype,
                                     out_features=width,
                                     params_init=w_init,
                                     bias_init=b_init,
                                     **shared)
        # Flax calls it `features`; initializers come from generate_params_init.
        flax_cls = partial(
            LayerNormDenseGeneral,
            features=width,
            kernel_init=TransformerEngineBaseLayer.generate_params_init("kernel", w_init),
            bias_init=TransformerEngineBaseLayer.generate_params_init("bias", b_init),
            dtype=dtype,
            **shared)
        return praxis_p, flax_cls

    @pytest.mark.parametrize('data_shape', DATA_SHAPE)
    @pytest.mark.parametrize('dtype', DTYPE)
    @pytest.mark.parametrize('attrs', LayerNormLinearAttr.ATTRS)
    def test_forward_backward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
        praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
        self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls,
                                     rtol=rtol, atol=atol)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize('data_shape', DATA_SHAPE)
    @pytest.mark.parametrize('dtype', DTYPE)
    @pytest.mark.parametrize('attrs', LayerNormLinearAttr.ATTRS)
    @pytest.mark.parametrize('fp8_format', FP8_FORMATS)
    def test_forward_backward_fp8(self,
                                  data_shape,
                                  dtype,
                                  attrs,
                                  fp8_format,
                                  rtol=1e-05,
                                  atol=1e-08):
        recipe = DelayedScaling(fp8_format=fp8_format)
        # Both layers are built and run inside the FP8 autocast scope.
        with fp8_autocast(enabled=True, fp8_recipe=recipe):
            praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
            self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls,
                                         rtol=rtol, atol=atol)


class LayerNormMLPAttr:
    """Parametrization table for the LayerNormMLP comparison tests."""
    INTERMEDIATE_DIM = 'intermediate_dim'
    USE_BIAS = 'use_bias'
    ENABLE_LN = 'enable_layernorm'
    LN_TYPE = 'layernorm_type'
    ZERO_CEN = 'zero_centered_gamma'
    ACTIVATION = 'activations'
    ATTRS = [{
        INTERMEDIATE_DIM: 2048,
        USE_BIAS: True,
        ENABLE_LN: True,
        LN_TYPE: 'layernorm',
        ZERO_CEN: False,
        ACTIVATION: ('relu',)
    }, {
        INTERMEDIATE_DIM: 2048,
        USE_BIAS: True,
        ENABLE_LN: True,
        LN_TYPE: 'layernorm',
        ZERO_CEN: True,
        ACTIVATION: ('relu',)
    }, {
        INTERMEDIATE_DIM: 2048,
        USE_BIAS: True,
        ENABLE_LN: True,
        LN_TYPE: 'rmsnorm',
        ZERO_CEN: False,
        ACTIVATION: ('relu',)
    }, {
        INTERMEDIATE_DIM: 2048,
        USE_BIAS: True,
        ENABLE_LN: True,
        LN_TYPE: 'rmsnorm',
        ZERO_CEN: False,
        ACTIVATION: ('gelu', 'linear')
    }, {
        INTERMEDIATE_DIM: 2048,
        USE_BIAS: False,
        ENABLE_LN: True,
        LN_TYPE: 'rmsnorm',
        ZERO_CEN: False,
        ACTIVATION: ('gelu', 'linear')
    }, {
        INTERMEDIATE_DIM: 2048,
        USE_BIAS: True,
        ENABLE_LN: True,
        LN_TYPE: 'rmsnorm',
        ZERO_CEN: False,
        ACTIVATION: ('silu', 'linear')
    }, {
        INTERMEDIATE_DIM: 2048,
        USE_BIAS: False,
        ENABLE_LN: True,
        LN_TYPE: 'rmsnorm',
        ZERO_CEN: False,
        ACTIVATION: ('silu', 'linear')
    }]


class TestLayerNormMLP(TestLayer):
    """Checks Praxis ``LayerNormMLP`` against Flax ``LayerNormMLP``."""

    def input_getter(self, shape, dtype):
        """Return a one-element tuple holding a normal random tensor of ``shape``."""
        rng = jax.random.PRNGKey(seed=1234)
        sample = jax.random.normal(rng, shape, dtype)
        return (sample,)

    def get_layer_name(self):
        return 'ln_mlp'

    def generate_praxis_p_and_flax_cls(self, dtype, attrs):
        """Return a Praxis config and a Flax constructor configured from ``attrs``."""
        # Options accepted under identical names by both implementations.
        shared = {
            'intermediate_dim': attrs[LayerNormMLPAttr.INTERMEDIATE_DIM],
            'enable_layernorm': attrs[LayerNormMLPAttr.ENABLE_LN],
            'layernorm_type': attrs[LayerNormMLPAttr.LN_TYPE],
            'zero_centered_gamma': attrs[LayerNormMLPAttr.ZERO_CEN],
            'use_bias': attrs[LayerNormMLPAttr.USE_BIAS],
            'activations': attrs[LayerNormMLPAttr.ACTIVATION],
            'intermediate_dropout_rate': 0.0,
            'axis': -1,
            'transpose_batch_sequence': False,
        }
        w_init = WeightInit.Gaussian(1.0)
        b_init = WeightInit.Constant(0.0)

        praxis_p = pax_fiddle.Config(LayerNormMLP,
                                     name='ln_mlp',
                                     dtype=dtype,
                                     params_init=w_init,
                                     bias_init=b_init,
                                     **shared)
        # The Flax side receives initializers produced by generate_params_init.
        flax_cls = partial(
            flax_LayerNormMLP,
            kernel_init=TransformerEngineBaseLayer.generate_params_init("kernel", w_init),
            bias_init=TransformerEngineBaseLayer.generate_params_init("bias", b_init),
            dtype=dtype,
            **shared)
        return praxis_p, flax_cls

    @pytest.mark.parametrize('data_shape', DATA_SHAPE)
    @pytest.mark.parametrize('dtype', DTYPE)
    @pytest.mark.parametrize('attrs', LayerNormMLPAttr.ATTRS)
    def test_forward_backward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
        praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
        self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls,
                                     rtol=rtol, atol=atol)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize('data_shape', DATA_SHAPE)
    @pytest.mark.parametrize('dtype', DTYPE)
    @pytest.mark.parametrize('attrs', LayerNormMLPAttr.ATTRS)
    @pytest.mark.parametrize('fp8_format', FP8_FORMATS)
    def test_forward_backward_fp8(self,
                                  data_shape,
                                  dtype,
                                  attrs,
                                  fp8_format,
                                  rtol=1e-05,
                                  atol=1e-08):
        recipe = DelayedScaling(fp8_format=fp8_format)
        # Both layers are built and run inside the FP8 autocast scope.
        with fp8_autocast(enabled=True, fp8_recipe=recipe):
            praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
            self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls,
                                         rtol=rtol, atol=atol)


class TestRelativePositionBias(TestLayer):
    """Forward-only comparison of Praxis and Flax ``RelativePositionBiases``."""

    def get_layer_name(self):
        return 'relative_position_bias'

    def generate_praxis_p_and_flax_cls(self, dtype, attrs):
        """Return a matching Praxis config and Flax constructor (``attrs`` is unused)."""
        num_buckets = 32
        max_distance = 128
        num_attention_heads = 64
        rb_stddev = (num_attention_heads * num_buckets)**-0.5
        embedding_init = WeightInit.Gaussian(rb_stddev)

        praxis_p = pax_fiddle.Config(RelativePositionBiases,
                                     name='relative_position_bias',
                                     dtype=dtype,
                                     num_buckets=num_buckets,
                                     max_distance=max_distance,
                                     num_attention_heads=num_attention_heads,
                                     embedding_init=embedding_init)
        flax_cls = partial(flax_RelativePositionBiases,
                           num_buckets=num_buckets,
                           max_distance=max_distance,
                           num_attention_heads=num_attention_heads,
                           embedding_init=TransformerEngineBaseLayer.generate_params_init(
                               "rel_embedding", embedding_init),
                           dtype=dtype)

        return praxis_p, flax_cls

    @pytest.mark.parametrize('data_shape', DATA_SHAPE)
    @pytest.mark.parametrize('dtype', DTYPE)
    @pytest.mark.parametrize('attrs', [{}])
    def test_forward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
        """Compare the un-reduced outputs of both layers.

        Each tuple in ``test_inputs`` is passed positionally to ``init`` and
        ``apply``; ``data_shape`` is not used by this body.
        """
        praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)

        init_key = jax.random.PRNGKey(seed=1234)

        # NOTE(review): presumably (q_seqlen, k_seqlen, bidirectional) — confirm
        # against RelativePositionBiases.__call__.
        test_inputs = [(128, 128, True), (128, 128, False)]
        for test_input in test_inputs:
            praxis_layer = praxis_p.Instantiate()
            praxis_variables = praxis_layer.init(init_key, *test_input)

            flax_layer = flax_cls()
            flax_variables = flax_layer.init(init_key, *test_input)
            if "params_axes" in flax_variables:
                flax_variables, _ = flax.core.pop(flax_variables, "params_axes")
            if FP8Helper.is_fp8_enabled():
                flax_variables, _ = flax.core.pop(flax_variables,
                                                  FP8Helper.FP8_COLLECTION_NAME + "_axes")

            praxis_variables, flax_variables = self.sync_variables(praxis_variables, flax_variables)

            praxis_loss = \
                TestLayer.loss(praxis_variables, *test_input, module=praxis_layer, mean_out=False)
            flax_loss = \
                TestLayer.loss(flax_variables, *test_input, module=flax_layer, mean_out=False)

            assert_allclose(praxis_loss, flax_loss, rtol=rtol, atol=atol)


class DotProductAttnAttr:
    """Parametrization table for the DotProductAttention comparison tests."""
    ATTN_MASK_TYPE = 'attn_mask_type'
    NUM_GQA_GROUPS = 'num_gqa_groups'
    TRANSPOSE_BS = 'transpose_batch_sequence'
    SCALE_FACTOR = 'scale_factor'
    # Mask types crossed with both batch/sequence layouts and a few scales.
    ATTRS = [
        {ATTN_MASK_TYPE: 'padding', TRANSPOSE_BS: True, SCALE_FACTOR: 0.125},
        {ATTN_MASK_TYPE: 'padding_causal', TRANSPOSE_BS: True, SCALE_FACTOR: 0.125},
        {ATTN_MASK_TYPE: 'causal', TRANSPOSE_BS: True, SCALE_FACTOR: 0.125},
        {ATTN_MASK_TYPE: 'padding', TRANSPOSE_BS: False, SCALE_FACTOR: 0.125},
        {ATTN_MASK_TYPE: 'padding_causal', TRANSPOSE_BS: False, SCALE_FACTOR: 2.0},
        {ATTN_MASK_TYPE: 'causal', TRANSPOSE_BS: False, SCALE_FACTOR: 1.0},
        {ATTN_MASK_TYPE: 'no_mask', TRANSPOSE_BS: False, SCALE_FACTOR: 1.0},
    ]


class TestDotProductAttn(TestLayer):
    """Checks Praxis ``DotProductAttention`` against its Flax counterpart."""

    def input_getter(self, shape, dtype):
        """Return ``[query, key, value, mask]`` derived from ``shape``.

        The first two entries of ``shape`` are unpacked as ``(b, s)``; when
        ``self.attrs`` (set by ``test_forward_backward``) enables
        ``transpose_batch_sequence``, those two dims of q/k/v are swapped.
        The mask is an all-zeros uint8 tensor of shape ``(b, 1, s, s)``.
        """
        key = jax.random.PRNGKey(seed=1234)
        q_key, k_key, v_key = jax.random.split(key, 3)
        b, s, *_ = shape
        if self.attrs[DotProductAttnAttr.TRANSPOSE_BS]:
            shape = (shape[1], shape[0]) + shape[2:]
        mask = jnp.zeros((b, 1, s, s), dtype=jnp.uint8)
        return [
            *map(partial(jax.random.normal, shape=shape, dtype=dtype), [q_key, k_key, v_key]), mask
        ]

    def get_layer_name(self):
        return 'dot_product_attn'

    def generate_praxis_p_and_flax_cls(self, dtype, attrs):
        """Return a Praxis config and a Flax constructor configured from ``attrs``.

        ``scale_factor`` is forwarded to both layers so the SCALE_FACTOR values
        in ``DotProductAttnAttr.ATTRS`` are actually exercised.
        """
        head_dim = 64
        num_attention_heads = 16
        num_gqa_groups = num_attention_heads
        attn_mask_type = attrs[DotProductAttnAttr.ATTN_MASK_TYPE]
        transpose_batch_sequence = attrs[DotProductAttnAttr.TRANSPOSE_BS]
        scale_factor = attrs[DotProductAttnAttr.SCALE_FACTOR]

        praxis_p = pax_fiddle.Config(DotProductAttention,
                                     name='mha',
                                     dtype=dtype,
                                     head_dim=head_dim,
                                     num_attention_heads=num_attention_heads,
                                     num_gqa_groups=num_gqa_groups,
                                     attn_mask_type=attn_mask_type,
                                     scale_factor=scale_factor,
                                     transpose_batch_sequence=transpose_batch_sequence)
        flax_cls = partial(flax_DotProductAttention,
                           dtype=dtype,
                           head_dim=head_dim,
                           num_attention_heads=num_attention_heads,
                           num_gqa_groups=num_gqa_groups,
                           attn_mask_type=attn_mask_type,
                           scale_factor=scale_factor,
                           transpose_batch_sequence=transpose_batch_sequence)

        return praxis_p, flax_cls

    @pytest.mark.parametrize('data_shape', [(32, 128, 16, 64)])
    @pytest.mark.parametrize('dtype', DTYPE)
    @pytest.mark.parametrize('attrs', DotProductAttnAttr.ATTRS)
    def test_forward_backward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
        # input_getter reads the current parametrization from self.attrs.
        self.attrs = attrs
        praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
        self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol)


class MultiHeadAttnAttr:
    USE_BIAS = 'use_bias'
    LN_TYPE = 'layernorm_type'
785
    ATTN_MASK_TYPE = 'attn_mask_type'
786
    ZERO_CEN = 'zero_centered_gamma'
zlsh80826's avatar
zlsh80826 committed
787
788
    NUM_ATTN_HEADS = 'num_attention_heads'
    NUM_GQA_GROUPS = 'num_gqa_groups'
789
    TRANSPOSE_BS = 'transpose_batch_sequence'
790
791
    ENABLE_ROPE = 'enable_rotary_pos_emb'
    ROPE_GROUP_METHOD = 'rotary_pos_emb_group_method'
792
    LORA_SCOPE = 'low_rank_adaptation_scope'
793
794
795
796
    ATTRS = [{
        USE_BIAS: True,
        LN_TYPE: 'layernorm',
        ZERO_CEN: False,
797
798
        ENABLE_ROPE: False,
        ROPE_GROUP_METHOD: 'consecutive',
799
800
        ATTN_MASK_TYPE: 'padding',
        TRANSPOSE_BS: True,
801
802
803
804
    }, {
        USE_BIAS: True,
        LN_TYPE: 'layernorm',
        ZERO_CEN: True,
805
806
        ENABLE_ROPE: False,
        ROPE_GROUP_METHOD: 'consecutive',
807
808
        ATTN_MASK_TYPE: 'padding',
        TRANSPOSE_BS: False,
809
810
811
812
    }, {
        USE_BIAS: True,
        LN_TYPE: 'rmsnorm',
        ZERO_CEN: False,
813
814
        ENABLE_ROPE: False,
        ROPE_GROUP_METHOD: 'consecutive',
815
816
        ATTN_MASK_TYPE: 'padding',
        TRANSPOSE_BS: True,
817
818
819
820
    }, {
        USE_BIAS: True,
        LN_TYPE: 'layernorm',
        ZERO_CEN: False,
821
822
        ENABLE_ROPE: False,
        ROPE_GROUP_METHOD: 'consecutive',
823
824
        ATTN_MASK_TYPE: 'causal',
        TRANSPOSE_BS: False,
825
826
827
828
    }, {
        USE_BIAS: True,
        LN_TYPE: 'layernorm',
        ZERO_CEN: True,
829
830
        ENABLE_ROPE: False,
        ROPE_GROUP_METHOD: 'consecutive',
831
832
        ATTN_MASK_TYPE: 'causal',
        TRANSPOSE_BS: True,
833
834
835
836
    }, {
        USE_BIAS: True,
        LN_TYPE: 'rmsnorm',
        ZERO_CEN: False,
837
838
        ENABLE_ROPE: False,
        ROPE_GROUP_METHOD: 'consecutive',
839
840
        ATTN_MASK_TYPE: 'causal',
        TRANSPOSE_BS: False,
zlsh80826's avatar
zlsh80826 committed
841
842
843
844
    }, {
        USE_BIAS: True,
        LN_TYPE: 'rmsnorm',
        ZERO_CEN: False,
845
846
847
848
        ENABLE_ROPE: False,
        ROPE_GROUP_METHOD: 'consecutive',
        NUM_ATTN_HEADS: 8,
        NUM_GQA_GROUPS: 4,
849
850
        ATTN_MASK_TYPE: 'causal',
        TRANSPOSE_BS: True,
851
852
853
854
855
856
857
858
    }, {
        USE_BIAS: True,
        LN_TYPE: 'rmsnorm',
        ZERO_CEN: False,
        ENABLE_ROPE: True,
        ROPE_GROUP_METHOD: 'consecutive',
        NUM_ATTN_HEADS: 8,
        NUM_GQA_GROUPS: 4,
859
860
        ATTN_MASK_TYPE: 'causal',
        TRANSPOSE_BS: False,
861
862
863
864
865
866
    }, {
        USE_BIAS: True,
        LN_TYPE: 'rmsnorm',
        ZERO_CEN: False,
        ENABLE_ROPE: True,
        ROPE_GROUP_METHOD: 'alternate',
zlsh80826's avatar
zlsh80826 committed
867
868
        NUM_ATTN_HEADS: 8,
        NUM_GQA_GROUPS: 4,
869
870
        ATTN_MASK_TYPE: 'causal',
        TRANSPOSE_BS: True,
871
872
873
874
875
876
877
    }, {
        USE_BIAS: True,
        LN_TYPE: 'layernorm',
        ZERO_CEN: False,
        ENABLE_ROPE: False,
        ROPE_GROUP_METHOD: 'consecutive',
        ATTN_MASK_TYPE: 'padding',
878
879
        LORA_SCOPE: 'all',
        TRANSPOSE_BS: False,
880
881
882
883
884
885
886
    }, {
        USE_BIAS: True,
        LN_TYPE: 'layernorm',
        ZERO_CEN: False,
        ENABLE_ROPE: False,
        ROPE_GROUP_METHOD: 'consecutive',
        ATTN_MASK_TYPE: 'causal',
887
888
        LORA_SCOPE: 'all',
        TRANSPOSE_BS: True,
889
890
891
892
893
894
    }]


class TestMultiHeadAttn(TestLayer):
    """Check the Praxis ``MultiHeadAttention`` wrapper against ``flax_MultiHeadAttention``."""

    def input_getter(self, shape, dtype):
        """Return ``[q_input, kv_input, mask]`` for one forward call.

        ``shape[0]`` and ``shape[1]`` size the all-zero ``uint8`` mask of shape
        ``(batch, 1, seq, seq)``. When the active config (``self.attrs``) sets
        ``TRANSPOSE_BS``, the two leading axes of the two random inputs are swapped.
        """
        key = jax.random.PRNGKey(seed=1234)
        q_key, kv_key = jax.random.split(key, 2)
        b, s, *_ = shape
        if self.attrs[MultiHeadAttnAttr.TRANSPOSE_BS]:
            shape = (shape[1], shape[0]) + shape[2:]
        mask = jnp.zeros((b, 1, s, s), dtype=jnp.uint8)
        return [*map(partial(jax.random.normal, shape=shape, dtype=dtype), [q_key, kv_key]), mask]

    def get_layer_name(self):
        """Return the identifier string of the layer exercised by this test class."""
        return 'multi_head_attn'

    def generate_praxis_p_and_flax_cls(self, dtype, attrs):
        """Build a Praxis MultiHeadAttention config and an equivalent Flax constructor.

        Args:
            dtype: dtype forwarded to both layers.
            attrs: one entry of ``MultiHeadAttnAttr.ATTRS``.

        Returns:
            ``(praxis_p, flax_cls)``: a ``pax_fiddle.Config`` for the Praxis
            ``MultiHeadAttention`` and a ``functools.partial`` of the Flax module,
            configured identically (the Flax initializers are derived from the
            Praxis ``WeightInit`` objects via ``generate_params_init``).
        """
        head_dim = 64
        # Honour the head count a config declares (the GQA entries set 8 heads
        # with 4 groups); configs that do not declare one keep the default of 16.
        num_attention_heads = attrs.get(MultiHeadAttnAttr.NUM_ATTN_HEADS, 16)
        # None when the config has no NUM_GQA_GROUPS entry.
        num_gqa_groups = attrs.get(MultiHeadAttnAttr.NUM_GQA_GROUPS)
        layernorm_type = attrs[MultiHeadAttnAttr.LN_TYPE]
        zero_centered_gamma = attrs[MultiHeadAttnAttr.ZERO_CEN]
        kernel_init = WeightInit.Gaussian(1.0)
        use_bias = attrs[MultiHeadAttnAttr.USE_BIAS]
        bias_init = WeightInit.Constant(0.0)
        input_layernorm = False
        return_layernorm_output = False
        attn_mask_type = attrs[MultiHeadAttnAttr.ATTN_MASK_TYPE]
        enable_rotary_pos_emb = attrs[MultiHeadAttnAttr.ENABLE_ROPE]
        rotary_pos_emb_group_method = attrs[MultiHeadAttnAttr.ROPE_GROUP_METHOD]
        low_rank_adaptation_scope = attrs.get(MultiHeadAttnAttr.LORA_SCOPE, 'none')
        fuse_qkv_params = True
        transpose_batch_sequence = attrs[MultiHeadAttnAttr.TRANSPOSE_BS]
        scale_attn_logits = False
        scaled_query_init = True
        float32_logits = False

        praxis_p = pax_fiddle.Config(MultiHeadAttention,
                                     name='mha',
                                     dtype=dtype,
                                     head_dim=head_dim,
                                     num_attention_heads=num_attention_heads,
                                     num_gqa_groups=num_gqa_groups,
                                     layernorm_type=layernorm_type,
                                     zero_centered_gamma=zero_centered_gamma,
                                     params_init=kernel_init,
                                     use_bias=use_bias,
                                     bias_init=bias_init,
                                     return_layernorm_output=return_layernorm_output,
                                     input_layernorm=input_layernorm,
                                     attn_mask_type=attn_mask_type,
                                     enable_rotary_pos_emb=enable_rotary_pos_emb,
                                     rotary_pos_emb_group_method=rotary_pos_emb_group_method,
                                     low_rank_adaptation_scope=low_rank_adaptation_scope,
                                     fuse_qkv_params=fuse_qkv_params,
                                     transpose_batch_sequence=transpose_batch_sequence,
                                     scale_attn_logits=scale_attn_logits,
                                     scaled_query_init=scaled_query_init,
                                     float32_logits=float32_logits)
        flax_cls = partial(
            flax_MultiHeadAttention,
            dtype=dtype,
            head_dim=head_dim,
            num_attention_heads=num_attention_heads,
            num_gqa_groups=num_gqa_groups,
            layernorm_type=layernorm_type,
            zero_centered_gamma=zero_centered_gamma,
            kernel_init=TransformerEngineBaseLayer.generate_params_init("kernel", kernel_init),
            use_bias=use_bias,
            bias_init=TransformerEngineBaseLayer.generate_params_init("bias", bias_init),
            return_layernorm_output=return_layernorm_output,
            input_layernorm=input_layernorm,
            attn_mask_type=attn_mask_type,
            enable_rotary_pos_emb=enable_rotary_pos_emb,
            rotary_pos_emb_group_method=rotary_pos_emb_group_method,
            low_rank_adaptation_scope=low_rank_adaptation_scope,
            fuse_qkv_params=fuse_qkv_params,
            transpose_batch_sequence=transpose_batch_sequence,
            scale_attn_logits=scale_attn_logits,
            scaled_query_init=scaled_query_init,
            float32_logits=float32_logits)

        return praxis_p, flax_cls

    @pytest.mark.parametrize('data_shape', DATA_SHAPE)
    @pytest.mark.parametrize('dtype', DTYPE)
    @pytest.mark.parametrize('attrs', MultiHeadAttnAttr.ATTRS)
    def test_forward_backward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
        """Compare Praxis and Flax MultiHeadAttention forward/backward results."""
        # input_getter reads self.attrs to choose the input layout.
        self.attrs = attrs
        praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
        self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize('data_shape', DATA_SHAPE)
    @pytest.mark.parametrize('dtype', DTYPE)
    @pytest.mark.parametrize('attrs', MultiHeadAttnAttr.ATTRS)
    @pytest.mark.parametrize('fp8_format', FP8_FORMATS)
    def test_forward_backward_fp8(self,
                                  data_shape,
                                  dtype,
                                  attrs,
                                  fp8_format,
                                  rtol=1e-05,
                                  atol=1e-08):
        """Same comparison as ``test_forward_backward`` inside an FP8 autocast region.

        Both layers are built and run under ``fp8_autocast`` with a
        ``DelayedScaling`` recipe of the given ``fp8_format``.
        """
        # input_getter reads self.attrs to choose the input layout.
        self.attrs = attrs
        ds = DelayedScaling(fp8_format=fp8_format)
        with fp8_autocast(enabled=True, fp8_recipe=ds):
            praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
            self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol)


class TransformerLayerAttr:
    """Attribute keys and parametrization table for the TransformerLayer tests.

    Each dict in ``ATTRS`` is one test configuration keyed by the class-level
    constants below. ``LORA_SCOPE`` appears only in some entries, so readers must
    use ``dict.get`` (or a membership test) for it.
    """
    USE_BIAS = 'use_bias'
    LN_TYPE = 'layernorm_type'
    ACTIVATION = 'activations'
    LYR_TYPE = 'layer_type'
    ZERO_CEN = 'zero_centered_gamma'
    TRANSPOSE_BS = 'transpose_batch_sequence'
    ENABLE_ROPE = 'enable_rotary_pos_emb'
    ROPE_GROUP_METHOD = 'rotary_pos_emb_group_method'
    LORA_SCOPE = 'low_rank_adaptation_scope'
    ATTRS = [{
        USE_BIAS: True,
        LN_TYPE: 'layernorm',
        ZERO_CEN: False,
        ACTIVATION: ('relu',),
        LYR_TYPE: TransformerLayerType.ENCODER,
        ENABLE_ROPE: False,
        ROPE_GROUP_METHOD: 'consecutive',
        TRANSPOSE_BS: True,
    }, {
        USE_BIAS: True,
        LN_TYPE: 'layernorm',
        ZERO_CEN: False,
        ACTIVATION: ('relu',),
        LYR_TYPE: TransformerLayerType.ENCODER,
        ENABLE_ROPE: False,
        ROPE_GROUP_METHOD: 'consecutive',
        TRANSPOSE_BS: False,
    }, {
        USE_BIAS: True,
        LN_TYPE: 'layernorm',
        ZERO_CEN: True,
        ACTIVATION: ('relu',),
        LYR_TYPE: TransformerLayerType.ENCODER,
        ENABLE_ROPE: False,
        ROPE_GROUP_METHOD: 'consecutive',
        TRANSPOSE_BS: True,
    }, {
        USE_BIAS: True,
        LN_TYPE: 'layernorm',
        ZERO_CEN: True,
        ACTIVATION: ('relu',),
        LYR_TYPE: TransformerLayerType.ENCODER,
        ENABLE_ROPE: False,
        ROPE_GROUP_METHOD: 'consecutive',
        TRANSPOSE_BS: False,
    }, {
        USE_BIAS: True,
        LN_TYPE: 'rmsnorm',
        ZERO_CEN: False,
        ACTIVATION: ('relu',),
        LYR_TYPE: TransformerLayerType.ENCODER,
        ENABLE_ROPE: False,
        ROPE_GROUP_METHOD: 'consecutive',
        TRANSPOSE_BS: True,
    }, {
        USE_BIAS: True,
        LN_TYPE: 'rmsnorm',
        ZERO_CEN: False,
        ACTIVATION: ('relu',),
        LYR_TYPE: TransformerLayerType.ENCODER,
        ENABLE_ROPE: False,
        ROPE_GROUP_METHOD: 'consecutive',
        TRANSPOSE_BS: False,
    }, {
        USE_BIAS: True,
        LN_TYPE: 'layernorm',
        ZERO_CEN: True,
        ACTIVATION: ('relu',),
        LYR_TYPE: TransformerLayerType.DECODER,
        ENABLE_ROPE: False,
        ROPE_GROUP_METHOD: 'consecutive',
        TRANSPOSE_BS: True,
    }, {
        USE_BIAS: True,
        LN_TYPE: 'layernorm',
        ZERO_CEN: True,
        ACTIVATION: ('relu',),
        LYR_TYPE: TransformerLayerType.DECODER,
        ENABLE_ROPE: False,
        ROPE_GROUP_METHOD: 'consecutive',
        TRANSPOSE_BS: False,
    }, {
        USE_BIAS: True,
        LN_TYPE: 'layernorm',
        ZERO_CEN: False,
        ACTIVATION: ('relu',),
        LYR_TYPE: TransformerLayerType.DECODER,
        ENABLE_ROPE: False,
        ROPE_GROUP_METHOD: 'consecutive',
        TRANSPOSE_BS: True,
    }, {
        USE_BIAS: True,
        LN_TYPE: 'layernorm',
        ZERO_CEN: False,
        ACTIVATION: ('relu',),
        LYR_TYPE: TransformerLayerType.DECODER,
        ENABLE_ROPE: False,
        ROPE_GROUP_METHOD: 'consecutive',
        TRANSPOSE_BS: False,
    }, {
        USE_BIAS: True,
        LN_TYPE: 'rmsnorm',
        ZERO_CEN: False,
        ACTIVATION: ('relu',),
        LYR_TYPE: TransformerLayerType.DECODER,
        ENABLE_ROPE: False,
        ROPE_GROUP_METHOD: 'consecutive',
        TRANSPOSE_BS: True,
    }, {
        USE_BIAS: True,
        LN_TYPE: 'rmsnorm',
        ZERO_CEN: False,
        ACTIVATION: ('relu',),
        LYR_TYPE: TransformerLayerType.DECODER,
        ENABLE_ROPE: False,
        ROPE_GROUP_METHOD: 'consecutive',
        TRANSPOSE_BS: False,
    }, {
        USE_BIAS: True,
        LN_TYPE: 'layernorm',
        ZERO_CEN: False,
        ACTIVATION: ('gelu', 'linear'),
        LYR_TYPE: TransformerLayerType.ENCODER,
        ENABLE_ROPE: False,
        ROPE_GROUP_METHOD: 'consecutive',
        TRANSPOSE_BS: True,
    }, {
        USE_BIAS: True,
        LN_TYPE: 'layernorm',
        ZERO_CEN: False,
        ACTIVATION: ('gelu', 'linear'),
        LYR_TYPE: TransformerLayerType.ENCODER,
        ENABLE_ROPE: False,
        ROPE_GROUP_METHOD: 'consecutive',
        TRANSPOSE_BS: False,
    }, {
        USE_BIAS: True,
        LN_TYPE: 'rmsnorm',
        ZERO_CEN: False,
        ACTIVATION: ('gelu', 'linear'),
        LYR_TYPE: TransformerLayerType.ENCODER,
        ENABLE_ROPE: False,
        ROPE_GROUP_METHOD: 'consecutive',
        TRANSPOSE_BS: True,
    }, {
        USE_BIAS: True,
        LN_TYPE: 'rmsnorm',
        ZERO_CEN: False,
        ACTIVATION: ('gelu', 'linear'),
        LYR_TYPE: TransformerLayerType.ENCODER,
        ENABLE_ROPE: False,
        ROPE_GROUP_METHOD: 'consecutive',
        TRANSPOSE_BS: False,
    }, {
        USE_BIAS: True,
        LN_TYPE: 'layernorm',
        ZERO_CEN: False,
        ACTIVATION: ('gelu',),
        LYR_TYPE: TransformerLayerType.ENCODER,
        ENABLE_ROPE: False,
        ROPE_GROUP_METHOD: 'consecutive',
        TRANSPOSE_BS: False,
        LORA_SCOPE: 'all',
    }, {
        USE_BIAS: True,
        LN_TYPE: 'layernorm',
        ZERO_CEN: False,
        ACTIVATION: ('gelu', 'linear'),
        LYR_TYPE: TransformerLayerType.DECODER,
        ENABLE_ROPE: False,
        ROPE_GROUP_METHOD: 'consecutive',
        TRANSPOSE_BS: True,
    }, {
        USE_BIAS: True,
        LN_TYPE: 'layernorm',
        ZERO_CEN: False,
        ACTIVATION: ('gelu', 'linear'),
        LYR_TYPE: TransformerLayerType.DECODER,
        ENABLE_ROPE: False,
        ROPE_GROUP_METHOD: 'consecutive',
        TRANSPOSE_BS: False,
    }, {
        USE_BIAS: True,
        LN_TYPE: 'rmsnorm',
        ZERO_CEN: False,
        ACTIVATION: ('gelu', 'linear'),
        LYR_TYPE: TransformerLayerType.DECODER,
        ENABLE_ROPE: False,
        ROPE_GROUP_METHOD: 'consecutive',
        TRANSPOSE_BS: True,
    }, {
        USE_BIAS: True,
        LN_TYPE: 'rmsnorm',
        ZERO_CEN: False,
        ACTIVATION: ('gelu', 'linear'),
        LYR_TYPE: TransformerLayerType.DECODER,
        ENABLE_ROPE: False,
        ROPE_GROUP_METHOD: 'consecutive',
        TRANSPOSE_BS: False,
    }, {
        USE_BIAS: True,
        LN_TYPE: 'layernorm',
        ZERO_CEN: True,
        ACTIVATION: ('gelu',),
        LYR_TYPE: TransformerLayerType.ENCODER,
        ENABLE_ROPE: True,
        ROPE_GROUP_METHOD: 'alternate',
        TRANSPOSE_BS: False,
    }, {
        USE_BIAS: True,
        LN_TYPE: 'layernorm',
        ZERO_CEN: True,
        ACTIVATION: ('gelu',),
        LYR_TYPE: TransformerLayerType.DECODER,
        ENABLE_ROPE: True,
        ROPE_GROUP_METHOD: 'alternate',
        TRANSPOSE_BS: False,
    }, {
        USE_BIAS: True,
        LN_TYPE: 'layernorm',
        ZERO_CEN: True,
        ACTIVATION: ('gelu',),
        LYR_TYPE: TransformerLayerType.ENCODER,
        ENABLE_ROPE: True,
        ROPE_GROUP_METHOD: 'consecutive',
        TRANSPOSE_BS: False,
    }, {
        USE_BIAS: True,
        LN_TYPE: 'layernorm',
        ZERO_CEN: True,
        ACTIVATION: ('gelu',),
        LYR_TYPE: TransformerLayerType.DECODER,
        ENABLE_ROPE: True,
        ROPE_GROUP_METHOD: 'consecutive',
        TRANSPOSE_BS: False,
    }, {
        USE_BIAS: True,
        LN_TYPE: 'layernorm',
        ZERO_CEN: False,
        ACTIVATION: ('gelu',),
        LYR_TYPE: TransformerLayerType.DECODER,
        ENABLE_ROPE: False,
        ROPE_GROUP_METHOD: 'consecutive',
        TRANSPOSE_BS: False,
        LORA_SCOPE: 'all',
    }]


class TestTransformer(TestLayer):
    """Check the Praxis ``TransformerLayer`` wrapper against ``flax_TransformerLayer``."""

    def input_getter(self, shape, dtype):
        """Return ``[q_input, kv_input, mask, mask]`` for one forward call.

        ``shape[0]`` and ``shape[1]`` size the all-zero ``uint8`` mask of shape
        ``(batch, 1, seq, seq)``. When the active config (``self.attrs``) sets
        ``TRANSPOSE_BS``, the two leading axes of the two random inputs are swapped.
        The same mask array is supplied twice. NOTE(review): presumably once for
        self-attention and once for encoder-decoder attention; confirm against the
        layer's call signature.
        """
        key = jax.random.PRNGKey(seed=1234)
        q_key, kv_key = jax.random.split(key, 2)
        b, s, *_ = shape
        if self.attrs[TransformerLayerAttr.TRANSPOSE_BS]:
            shape = (shape[1], shape[0]) + shape[2:]
        mask = jnp.zeros((b, 1, s, s), dtype=jnp.uint8)
        return [
            *map(partial(jax.random.normal, shape=shape, dtype=dtype), [q_key, kv_key]), mask, mask
        ]

    def get_layer_name(self):
        """Return the identifier string of the layer exercised by this test class."""
        return 'transformerlayer'

    def generate_praxis_p_and_flax_cls(self, dtype, attrs):
        """Build a Praxis TransformerLayer config and an equivalent Flax constructor.

        Args:
            dtype: dtype forwarded to both layers.
            attrs: one entry of ``TransformerLayerAttr.ATTRS``.

        Returns:
            ``(praxis_p, flax_cls)``: a ``pax_fiddle.Config`` for the Praxis
            ``TransformerLayer`` and a ``functools.partial`` of the Flax module,
            configured identically.
        """
        hidden_size = 512
        mlp_hidden_size = 2048
        num_attention_heads = 8
        layernorm_type = attrs[TransformerLayerAttr.LN_TYPE]
        # Every ATTRS entry sets ZERO_CEN; it must be forwarded to both layers,
        # otherwise the zero-centered-gamma configurations silently run with the
        # layers' default setting and duplicate the other configurations.
        zero_centered_gamma = attrs[TransformerLayerAttr.ZERO_CEN]
        # Dropout is zeroed, presumably so the two implementations' outputs can be
        # compared directly.
        hidden_dropout = 0.0
        attention_dropout = 0.0
        intermediate_dropout = 0.0
        mlp_activations = attrs[TransformerLayerAttr.ACTIVATION]
        kernel_init = WeightInit.Gaussian(1.0)
        use_bias = attrs[TransformerLayerAttr.USE_BIAS]
        bias_init = WeightInit.Constant(0.0)
        layer_type = attrs[TransformerLayerAttr.LYR_TYPE]
        enable_rotary_pos_emb = attrs[TransformerLayerAttr.ENABLE_ROPE]
        rotary_pos_emb_group_method = attrs[TransformerLayerAttr.ROPE_GROUP_METHOD]
        low_rank_adaptation_scope = attrs.get(TransformerLayerAttr.LORA_SCOPE, 'none')
        enable_relative_embedding = True
        relative_embedding = pax_fiddle.Config(RelativePositionBiases,
                                               dtype=dtype,
                                               num_attention_heads=num_attention_heads)
        drop_path = 0.0
        transpose_batch_sequence = attrs[TransformerLayerAttr.TRANSPOSE_BS]

        # The Flax layer takes a concrete RelativePositionBiases module, so build
        # one from the Praxis sub-config's settings, with an initializer derived
        # from the same embedding_init.
        rel_embedding_init = RelativePositionBiases.generate_embedding_init(
            relative_embedding.embedding_init, relative_embedding.num_attention_heads,
            relative_embedding.num_buckets)

        relative_embedding_flax_module = flax_RelativePositionBiases(
            num_buckets=relative_embedding.num_buckets,
            max_distance=relative_embedding.max_distance,
            num_attention_heads=relative_embedding.num_attention_heads,
            embedding_init=TransformerEngineBaseLayer.generate_params_init(
                "rel_embedding", rel_embedding_init),
            embedding_axes=relative_embedding.embedding_axes,
            dtype=relative_embedding.dtype)

        praxis_p = pax_fiddle.Config(TransformerLayer,
                                     name='transformer_layer',
                                     params_init=kernel_init,
                                     dtype=dtype,
                                     hidden_size=hidden_size,
                                     mlp_hidden_size=mlp_hidden_size,
                                     num_attention_heads=num_attention_heads,
                                     layernorm_type=layernorm_type,
                                     zero_centered_gamma=zero_centered_gamma,
                                     hidden_dropout=hidden_dropout,
                                     attention_dropout=attention_dropout,
                                     intermediate_dropout=intermediate_dropout,
                                     mlp_activations=mlp_activations,
                                     use_bias=use_bias,
                                     bias_init=bias_init,
                                     layer_type=layer_type,
                                     enable_relative_embedding=enable_relative_embedding,
                                     enable_rotary_pos_emb=enable_rotary_pos_emb,
                                     rotary_pos_emb_group_method=rotary_pos_emb_group_method,
                                     low_rank_adaptation_scope=low_rank_adaptation_scope,
                                     relative_embedding=relative_embedding,
                                     drop_path=drop_path,
                                     transpose_batch_sequence=transpose_batch_sequence)
        flax_cls = partial(flax_TransformerLayer,
                           dtype=dtype,
                           hidden_size=hidden_size,
                           mlp_hidden_size=mlp_hidden_size,
                           num_attention_heads=num_attention_heads,
                           layernorm_type=layernorm_type,
                           zero_centered_gamma=zero_centered_gamma,
                           hidden_dropout=hidden_dropout,
                           attention_dropout=attention_dropout,
                           intermediate_dropout=intermediate_dropout,
                           mlp_activations=mlp_activations,
                           mha_kernel_init=TransformerEngineBaseLayer.generate_params_init(
                               "mha_kernel", kernel_init),
                           mlp_kernel_init=TransformerEngineBaseLayer.generate_params_init(
                               "mlp_kernel", kernel_init),
                           use_bias=use_bias,
                           bias_init=TransformerEngineBaseLayer.generate_params_init(
                               "bias", bias_init),
                           layer_type=layer_type,
                           enable_rotary_pos_emb=enable_rotary_pos_emb,
                           rotary_pos_emb_group_method=rotary_pos_emb_group_method,
                           enable_relative_embedding=enable_relative_embedding,
                           relative_embedding=relative_embedding_flax_module,
                           low_rank_adaptation_scope=low_rank_adaptation_scope,
                           drop_path=drop_path,
                           transpose_batch_sequence=transpose_batch_sequence)

        return praxis_p, flax_cls

    @pytest.mark.parametrize('data_shape', DATA_SHAPE)
    @pytest.mark.parametrize('dtype', DTYPE)
    @pytest.mark.parametrize('attrs', TransformerLayerAttr.ATTRS)
    def test_forward_backward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
        """Compare Praxis and Flax TransformerLayer forward/backward results."""
        # input_getter reads self.attrs to choose the input layout.
        self.attrs = attrs
        praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
        self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize('data_shape', DATA_SHAPE)
    @pytest.mark.parametrize('dtype', DTYPE)
    @pytest.mark.parametrize('attrs', TransformerLayerAttr.ATTRS)
    @pytest.mark.parametrize('fp8_format', FP8_FORMATS)
    def test_forward_backward_fp8(self,
                                  data_shape,
                                  dtype,
                                  attrs,
                                  fp8_format,
                                  rtol=1e-05,
                                  atol=1e-08):
        """Same comparison as ``test_forward_backward`` inside an FP8 autocast region.

        Both layers are built and run under ``fp8_autocast`` with a
        ``DelayedScaling`` recipe of the given ``fp8_format``.
        """
        # input_getter reads self.attrs to choose the input layout.
        self.attrs = attrs
        ds = DelayedScaling(fp8_format=fp8_format)
        with fp8_autocast(enabled=True, fp8_recipe=ds):
            praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
            self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol)