# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

5
import os
6
7
8
from functools import partial
from typing import Dict

9
import flax
10
11
12
13
14
15
import jax
import jax.numpy as jnp
from praxis import pax_fiddle
from praxis.base_layer import WeightInit, DEFAULT_INIT_MUTABLE_LIST
import pytest

16
17
from utils import assert_allclose

18
from transformer_engine.common.recipe import DelayedScaling, Format
19
from transformer_engine.jax import fp8_autocast, update_collections
20
21
22
23
from transformer_engine.jax.flax import DenseGeneral, LayerNormDenseGeneral
from transformer_engine.jax.flax import LayerNorm as flax_LayerNorm
from transformer_engine.jax.flax import LayerNormMLP as flax_LayerNormMLP
from transformer_engine.jax.flax import MultiHeadAttention as flax_MultiHeadAttention
24
from transformer_engine.jax.flax import DotProductAttention as flax_DotProductAttention
25
26
27
28
29
from transformer_engine.jax.flax import RelativePositionBiases as flax_RelativePositionBiases
from transformer_engine.jax.flax import TransformerLayer as flax_TransformerLayer
from transformer_engine.jax.flax.module import Softmax
from transformer_engine.jax.fp8 import FP8Helper, is_fp8_available
from transformer_engine.jax.praxis import LayerNorm
30
from transformer_engine.jax.praxis import FusedSoftmax
31
from transformer_engine.jax.praxis import LayerNormLinear, LayerNormMLP, Linear
32
33
from transformer_engine.jax.praxis import DotProductAttention, MultiHeadAttention
from transformer_engine.jax.praxis import RelativePositionBiases, TransformerEngineBaseLayer
34
from transformer_engine.jax.praxis import TransformerLayer, TransformerLayerType
35
36
37
38
from transformer_engine.jax.softmax import SoftmaxType

# (supported, reason) pair; ``reason`` is the skip message for the FP8 tests below.
is_fp8_supported, reason = is_fp8_available()

# Parametrization grids shared by several test classes below.
DATA_SHAPE = [(32, 128, 512), (32, 512, 512)]  # (B, S, H)
DTYPE = [jnp.float32, jnp.bfloat16]
ENABLE_FP8 = [False, True]
FP8_FORMATS = [Format.E4M3, Format.HYBRID]


def compare_dict(ref_fd, test_fd, rtol=1e-05, atol=1e-08):
    """Recursively assert that ``test_fd`` matches ``ref_fd``.

    Every key of ``ref_fd`` must exist in ``test_fd`` and hold a value of the
    same type. Nested dicts are compared recursively; every other value is
    compared with ``assert_allclose``. The check is one-directional: keys that
    appear only in ``test_fd`` are ignored.

    Args:
        ref_fd: Reference (possibly nested) dict.
        test_fd: Dict under test.
        rtol: Relative tolerance for leaf comparisons.
        atol: Absolute tolerance for leaf comparisons.

    Raises:
        AssertionError: On a missing key, a type mismatch, or leaf values that
            are not close.
    """
    for key, ref_val in ref_fd.items():
        assert key in test_fd, f"{key} not found in test dict {test_fd}"
        test_val = test_fd[key]
        assert isinstance(
            test_val, type(ref_val)
        ), f"The data type is not match between ref and test  Dict on {key=}"
        if isinstance(ref_val, dict):
            compare_dict(ref_val, test_val, rtol, atol)
        else:
            assert_allclose(
                ref_val, test_val, rtol=rtol, atol=atol, err_msg=f"{key=} is not close"
            )


class TestLayer:
    """Shared harness that checks a Praxis layer against its Flax counterpart.

    Subclasses supply the test inputs (``input_getter``), the key under which the
    Praxis wrapper stores the wrapped layer's variables (``get_layer_name``), and
    a matching ``(praxis config, flax constructor)`` pair
    (``generate_praxis_p_and_flax_cls``). ``forward_backward_runner`` then runs
    both layers and compares loss, input gradient and weight gradients.
    """

    @staticmethod
    def loss(inner_variables, *inner_inputs, module, mean_out=True):
        """Apply ``module`` and return the mean of its primary output.

        Args:
            inner_variables: Variable collections passed to ``module.apply``.
            *inner_inputs: Positional inputs forwarded to ``module.apply``.
            module: The Praxis or Flax module to run.
            mean_out: If True return ``jnp.mean(output)``, otherwise the raw output.
        """
        outs = module.apply(inner_variables, *inner_inputs)
        # When a tuple is returned, its first element is the real output and the
        # remaining elements are auxiliary values.
        out = outs[0] if isinstance(outs, tuple) else outs
        return jnp.mean(out) if mean_out else out

    @staticmethod
    def loss_and_grads(module, variables, *inputs):
        """Return ``(loss, grads w.r.t. variables, grad w.r.t. the first input)``."""
        # argnums=(0, 1): argument 0 is the variable tree, argument 1 is the
        # first positional input of ``loss``.
        grad_fn = jax.value_and_grad(TestLayer.loss, argnums=(0, 1))
        loss_val, (wgrads, dgrad) = grad_fn(variables, *inputs, module=module)
        return loss_val, wgrads, dgrad

    def input_getter(self, shape, dtype):
        """Return the positional inputs for both layers; implemented by subclasses."""
        raise NotImplementedError

    def get_layer_name(self):
        """Return the key used to index the Praxis variable/grad trees; subclass hook."""
        raise NotImplementedError

    def generate_praxis_p_and_flax_cls(self, dtype, attrs):
        """Return ``(praxis_config, flax_constructor)``; implemented by subclasses."""
        raise NotImplementedError

    def sync_variables(self, praxis_variables, flax_variables):
        """Copy the Flax ``params`` into the Praxis variable tree.

        The Flax params are written (unfrozen) to
        ``praxis_variables["params"][<layer name>]["cld"]``, so both layers run
        with the same weights. ``praxis_variables`` is modified in place.

        Returns:
            The ``(praxis_variables, flax_variables)`` pair to use for the run.
        """
        synced_praxis_variables = praxis_variables
        lyr_name = self.get_layer_name()

        if "params" in flax_variables:
            synced_praxis_variables["params"][lyr_name]["cld"] = flax.core.unfreeze(
                flax_variables["params"]
            )

        return synced_praxis_variables, flax_variables

    def sync_wgrads(self, praxis_wgrads, flax_wgrads):
        """Strip the Praxis nesting from the gradients so they line up with Flax's.

        The Praxis ``params`` gradients (and, with FP8 enabled, the FP8
        collection) are replaced in place by their ``[<layer name>]["cld"]``
        subtree.

        Returns:
            ``(praxis_wgrads, unfrozen flax_wgrads)``.
        """
        synced_praxis_grads = praxis_wgrads
        lyr_name = self.get_layer_name()

        if "params" in synced_praxis_grads:
            synced_praxis_grads["params"] = synced_praxis_grads["params"][lyr_name]["cld"]

        if FP8Helper.is_fp8_enabled():
            fp8_name = FP8Helper.FP8_COLLECTION_NAME
            synced_praxis_grads[fp8_name] = synced_praxis_grads[fp8_name][lyr_name]["cld"]

        return synced_praxis_grads, flax.core.unfreeze(flax_wgrads)

    def forward_backward_runner(
        self, data_shape, dtype, praxis_p, flax_cls, rtol=1e-05, atol=1e-08
    ):
        """Run the Praxis and Flax layers side by side and assert they agree.

        Both layers are initialized from the same PRNG key and inputs, and the
        Flax params are copied into the Praxis tree. Then the loss, the gradient
        w.r.t. the first input, and the weight gradients are compared.

        Args:
            data_shape: Shape forwarded to ``input_getter``.
            dtype: Dtype forwarded to ``input_getter``.
            praxis_p: Praxis config exposing ``Instantiate()``.
            flax_cls: Zero-argument callable returning the Flax module.
            rtol: Relative tolerance for all comparisons.
            atol: Absolute tolerance for all comparisons.
        """
        init_key = jax.random.PRNGKey(seed=1234)

        test_inputs = self.input_getter(data_shape, dtype)

        praxis_layer = praxis_p.Instantiate()
        # This is a workaround to correctly enable FP8 meta generation for Praxis.
        # TODO (Ming Huang): To come out a better solution.
        mutable_list = DEFAULT_INIT_MUTABLE_LIST + [FP8Helper.FP8_COLLECTION_NAME]
        praxis_variables = praxis_layer.init(init_key, *test_inputs, mutable=mutable_list)

        flax_layer = flax_cls()
        flax_variables = flax_layer.init(init_key, *test_inputs)
        # Drop the axis-annotation collections before syncing. Each pop is
        # guarded so a layer that does not create a given collection does not
        # raise KeyError.
        axes_collections = ["params_axes"]
        if FP8Helper.is_fp8_enabled():
            axes_collections.append(FP8Helper.FP8_COLLECTION_NAME + "_axes")
        for collection in axes_collections:
            if collection in flax_variables:
                flax_variables, _ = flax.core.pop(flax_variables, collection)

        praxis_variables, flax_variables = self.sync_variables(praxis_variables, flax_variables)

        if FP8Helper.is_fp8_enabled():
            # Feed the non-"params" gradient collections back into the variables
            # a few times before the measured step. Presumably this advances the
            # FP8 meta state (scales/amax history) -- confirm against
            # update_collections. Without FP8 these steps would have no effect,
            # so they are skipped.
            for _ in range(5):
                _, praxis_wgrads, _ = TestLayer.loss_and_grads(
                    praxis_layer, praxis_variables, *test_inputs
                )
                _, flax_wgrads, _ = TestLayer.loss_and_grads(
                    flax_layer, flax_variables, *test_inputs
                )
                praxis_wgrads.pop("params")
                praxis_variables = update_collections(praxis_wgrads, praxis_variables)
                flax_wgrads, _ = flax.core.pop(flax_wgrads, "params")
                flax_variables = update_collections(flax_wgrads, flax_variables)

        praxis_loss, praxis_wgrads, praxis_dgrad = TestLayer.loss_and_grads(
            praxis_layer, praxis_variables, *test_inputs
        )
        flax_loss, flax_wgrads, flax_dgrad = TestLayer.loss_and_grads(
            flax_layer, flax_variables, *test_inputs
        )

        assert_allclose(praxis_loss, flax_loss, rtol=rtol, atol=atol)
        assert_allclose(praxis_dgrad, flax_dgrad, rtol=rtol, atol=atol)

        praxis_wgrads, flax_wgrads = self.sync_wgrads(praxis_wgrads, flax_wgrads)
        compare_dict(praxis_wgrads, flax_wgrads, rtol=rtol, atol=atol)


class LayerNormAttr:
    """Parametrization cases for ``TestLayerNorm``."""

    LN_TYPE = "layernorm_type"
    ZERO_CEN = "zero_centered_gamma"
    ATTRS = [
        {LN_TYPE: "layernorm", ZERO_CEN: False},
        {LN_TYPE: "layernorm", ZERO_CEN: True},
        {LN_TYPE: "rmsnorm", ZERO_CEN: False},
    ]


class TestLayerNorm(TestLayer):
    """Compare the Praxis ``LayerNorm`` wrapper with the Flax ``LayerNorm`` module."""

    def input_getter(self, shape, dtype):
        """Return a 1-tuple holding a standard-normal tensor of ``shape``/``dtype``."""
        data_key = jax.random.PRNGKey(seed=1234)
        return (jax.random.normal(data_key, shape, dtype),)

    def get_layer_name(self):
        """Key of this layer inside the Praxis variable and gradient trees."""
        return "layer_norm"

    def generate_praxis_p_and_flax_cls(self, dtype, attrs):
        """Build a Praxis config and a Flax constructor with identical settings.

        Args:
            dtype: Dtype passed to both layers.
            attrs: One entry of ``LayerNormAttr.ATTRS``.

        Returns:
            ``(praxis_p, flax_cls)``.
        """
        layernorm_type = attrs[LayerNormAttr.LN_TYPE]
        zero_centered_gamma = attrs[LayerNormAttr.ZERO_CEN]
        scale_init = None
        bias_init = WeightInit.Constant(0.0)
        transpose_batch_sequence = False

        praxis_p = pax_fiddle.Config(
            LayerNorm,
            name="layer_norm",
            dtype=dtype,
            layernorm_type=layernorm_type,
            zero_centered_gamma=zero_centered_gamma,
            scale_init=scale_init,
            bias_init=bias_init,
            transpose_batch_sequence=transpose_batch_sequence,
        )
        # The Praxis side takes the WeightInit directly; the Flax side receives it
        # converted through TransformerEngineBaseLayer.generate_params_init.
        flax_cls = partial(
            flax_LayerNorm,
            layernorm_type=layernorm_type,
            zero_centered_gamma=zero_centered_gamma,
            scale_init=scale_init,
            bias_init=TransformerEngineBaseLayer.generate_params_init("ln_bias", bias_init),
            dtype=dtype,
            transpose_batch_sequence=transpose_batch_sequence,
        )

        return praxis_p, flax_cls

    @pytest.mark.parametrize("data_shape", DATA_SHAPE)
    @pytest.mark.parametrize("dtype", DTYPE)
    @pytest.mark.parametrize("attrs", LayerNormAttr.ATTRS)
    def test_forward_backward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
        """Forward and backward results of both implementations must match."""
        praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
        self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol)


class FusedSoftmaxAttr:
    """Parametrization cases for ``TestFusedSoftmax``."""

    SCALE_FACTOR = "scale_factor"
    ST_TYPE = "softmax_type"
    ATTRS = [
        {SCALE_FACTOR: 0.0, ST_TYPE: SoftmaxType.SCALED},
        {SCALE_FACTOR: 0.0, ST_TYPE: SoftmaxType.SCALED_MASKED},
        {SCALE_FACTOR: 0.0, ST_TYPE: SoftmaxType.SCALED_UPPER_TRIANG_MASKED},
    ]


class TestFusedSoftmax(TestLayer):
    """Compare the Praxis ``FusedSoftmax`` wrapper with the Flax ``Softmax`` module."""

    def input_getter(self, shape, dtype):
        """Return ``(logits, mask)``: standard-normal logits and an all-ones uint8 mask."""
        data_key = jax.random.PRNGKey(seed=1234)
        return jax.random.normal(data_key, shape, dtype), jnp.ones(shape, dtype=jnp.uint8)  # Masks

    def generate_praxis_p_and_flax_cls(self, dtype, attrs):
        """Build matching Praxis/Flax softmax layers from one ``FusedSoftmaxAttr`` entry.

        ``dtype`` is not forwarded to either layer.
        """
        scale_factor = attrs[FusedSoftmaxAttr.SCALE_FACTOR]
        softmax_type = attrs[FusedSoftmaxAttr.ST_TYPE]

        praxis_p = pax_fiddle.Config(
            FusedSoftmax, name="fused_softmax", scale_factor=scale_factor, softmax_type=softmax_type
        )
        flax_cls = partial(Softmax, scale_factor=scale_factor, softmax_type=softmax_type)

        return praxis_p, flax_cls

    # This class defines no get_layer_name, so the base-class syncing (which
    # indexes by layer name) is replaced with pass-throughs.
    def sync_variables(self, praxis_variables, flax_variables):
        """Return both variable trees unchanged."""
        return praxis_variables, flax_variables

    def sync_wgrads(self, praxis_wgrads, flax_wgrads):
        """Return both gradient trees unchanged."""
        return praxis_wgrads, flax_wgrads

    @pytest.mark.parametrize("data_shape", [(32, 1, 128, 128), (32, 1, 512, 128)])
    @pytest.mark.parametrize("dtype", DTYPE)
    @pytest.mark.parametrize("attrs", FusedSoftmaxAttr.ATTRS)
    def test_forward_backward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
        """Forward and backward results of both implementations must match."""
        if (attrs[FusedSoftmaxAttr.ST_TYPE] == SoftmaxType.SCALED_UPPER_TRIANG_MASKED) and (
            data_shape[-2] != data_shape[-1]
        ):
            # Report the unsupported combination as SKIPPED instead of a silent PASS.
            pytest.skip("SCALED_UPPER_TRIANG_MASKED is only supported for square inputs")
        praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
        self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol)


class LinearAttr:
    """Parametrization cases for ``TestLinear``."""

    FEATURE = "features"
    USE_BIAS = "use_bias"
    ATTRS = [
        {FEATURE: 512, USE_BIAS: False},
        {FEATURE: 512, USE_BIAS: True},
        {FEATURE: 1024, USE_BIAS: False},
        {FEATURE: 1024, USE_BIAS: True},
    ]


class TestLinear(TestLayer):
    """Compare the Praxis ``Linear`` wrapper with the Flax ``DenseGeneral`` module."""

    def input_getter(self, shape, dtype):
        """Return a 1-tuple holding a standard-normal tensor of ``shape``/``dtype``."""
        data_key = jax.random.PRNGKey(seed=1234)
        return (jax.random.normal(data_key, shape, dtype),)

    def get_layer_name(self):
        """Key of this layer inside the Praxis variable and gradient trees."""
        return "linear"

    def generate_praxis_p_and_flax_cls(self, dtype, attrs):
        """Build a Praxis config and a Flax constructor from one ``LinearAttr`` entry.

        Returns:
            ``(praxis_p, flax_cls)``.
        """
        out_features = attrs[LinearAttr.FEATURE]
        kernel_init = WeightInit.Gaussian(1.0)
        use_bias = attrs[LinearAttr.USE_BIAS]
        bias_init = WeightInit.Constant(0.0)
        axis = -1
        transpose_batch_sequence = False

        praxis_p = pax_fiddle.Config(
            Linear,
            name="linear",
            dtype=dtype,
            out_features=out_features,
            params_init=kernel_init,
            use_bias=use_bias,
            bias_init=bias_init,
            axis=axis,
            transpose_batch_sequence=transpose_batch_sequence,
        )
        # Flax initializers are derived from the same WeightInit objects.
        flax_cls = partial(
            DenseGeneral,
            features=out_features,
            kernel_init=TransformerEngineBaseLayer.generate_params_init("kernel", kernel_init),
            use_bias=use_bias,
            bias_init=TransformerEngineBaseLayer.generate_params_init("bias", bias_init),
            axis=axis,
            dtype=dtype,
            transpose_batch_sequence=transpose_batch_sequence,
        )

        return praxis_p, flax_cls

    @pytest.mark.parametrize("data_shape", DATA_SHAPE)
    @pytest.mark.parametrize("dtype", DTYPE)
    @pytest.mark.parametrize("attrs", LinearAttr.ATTRS)
    def test_forward_backward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
        """Forward and backward results of both implementations must match."""
        praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
        self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize("data_shape", DATA_SHAPE)
    @pytest.mark.parametrize("dtype", DTYPE)
    @pytest.mark.parametrize("attrs", LinearAttr.ATTRS)
    @pytest.mark.parametrize("fp8_format", FP8_FORMATS)
    def test_forward_backward_fp8(
        self, data_shape, dtype, attrs, fp8_format, rtol=1e-05, atol=1e-08
    ):
        """Same comparison, run inside ``fp8_autocast`` with the given FP8 format."""
        ds = DelayedScaling(fp8_format=fp8_format)
        with fp8_autocast(enabled=True, fp8_recipe=ds):
            praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
            self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol)


class LayerNormLinearAttr:
    """Parametrization cases for ``TestLayerNormLinear``.

    Each entry is distinct; identical entries would only repeat the same test case.
    """

    FEATURE = "features"
    USE_BIAS = "use_bias"
    ENABLE_LN = "enable_layernorm"
    LN_TYPE = "layernorm_type"
    ZERO_CEN = "zero_centered_gamma"
    ATTRS = [
        {FEATURE: 512, USE_BIAS: True, ENABLE_LN: True, LN_TYPE: "layernorm", ZERO_CEN: False},
        {FEATURE: 512, USE_BIAS: True, ENABLE_LN: True, LN_TYPE: "layernorm", ZERO_CEN: True},
        {FEATURE: 512, USE_BIAS: True, ENABLE_LN: True, LN_TYPE: "rmsnorm", ZERO_CEN: False},
        {FEATURE: 512, USE_BIAS: True, ENABLE_LN: False, LN_TYPE: "layernorm", ZERO_CEN: False},
    ]


class TestLayerNormLinear(TestLayer):
    """Compare Praxis ``LayerNormLinear`` with the Flax ``LayerNormDenseGeneral`` module."""

    def input_getter(self, shape, dtype):
        """Return a 1-tuple holding a standard-normal tensor of ``shape``/``dtype``."""
        data_key = jax.random.PRNGKey(seed=1234)
        return (jax.random.normal(data_key, shape, dtype),)

    def get_layer_name(self):
        """Key of this layer inside the Praxis variable and gradient trees."""
        return "ln_linear"

    def generate_praxis_p_and_flax_cls(self, dtype, attrs):
        """Build a Praxis config and a Flax constructor from one ``LayerNormLinearAttr`` entry.

        Returns:
            ``(praxis_p, flax_cls)``.
        """
        out_features = attrs[LayerNormLinearAttr.FEATURE]
        enable_layernorm = attrs[LayerNormLinearAttr.ENABLE_LN]
        layernorm_type = attrs[LayerNormLinearAttr.LN_TYPE]
        zero_centered_gamma = attrs[LayerNormLinearAttr.ZERO_CEN]
        kernel_init = WeightInit.Gaussian(1.0)
        use_bias = attrs[LayerNormLinearAttr.USE_BIAS]
        bias_init = WeightInit.Constant(0.0)
        axis = -1
        transpose_batch_sequence = False

        praxis_p = pax_fiddle.Config(
            LayerNormLinear,
            name="ln_linear",
            dtype=dtype,
            out_features=out_features,
            enable_layernorm=enable_layernorm,
            layernorm_type=layernorm_type,
            zero_centered_gamma=zero_centered_gamma,
            params_init=kernel_init,
            use_bias=use_bias,
            bias_init=bias_init,
            axis=axis,
            transpose_batch_sequence=transpose_batch_sequence,
        )
        # Flax initializers are derived from the same WeightInit objects.
        flax_cls = partial(
            LayerNormDenseGeneral,
            features=out_features,
            enable_layernorm=enable_layernorm,
            layernorm_type=layernorm_type,
            zero_centered_gamma=zero_centered_gamma,
            kernel_init=TransformerEngineBaseLayer.generate_params_init("kernel", kernel_init),
            use_bias=use_bias,
            bias_init=TransformerEngineBaseLayer.generate_params_init("bias", bias_init),
            axis=axis,
            dtype=dtype,
            transpose_batch_sequence=transpose_batch_sequence,
        )

        return praxis_p, flax_cls

    @pytest.mark.parametrize("data_shape", DATA_SHAPE)
    @pytest.mark.parametrize("dtype", DTYPE)
    @pytest.mark.parametrize("attrs", LayerNormLinearAttr.ATTRS)
    def test_forward_backward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
        """Forward and backward results of both implementations must match."""
        praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
        self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize("data_shape", DATA_SHAPE)
    @pytest.mark.parametrize("dtype", DTYPE)
    @pytest.mark.parametrize("attrs", LayerNormLinearAttr.ATTRS)
    @pytest.mark.parametrize("fp8_format", FP8_FORMATS)
    def test_forward_backward_fp8(
        self, data_shape, dtype, attrs, fp8_format, rtol=1e-05, atol=1e-08
    ):
        """Same comparison, run inside ``fp8_autocast`` with the given FP8 format."""
        ds = DelayedScaling(fp8_format=fp8_format)
        with fp8_autocast(enabled=True, fp8_recipe=ds):
            praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
            self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol)


class LayerNormMLPAttr:
    """Parametrization cases for ``TestLayerNormMLP``."""

    INTERMEDIATE_DIM = "intermediate_dim"
    USE_BIAS = "use_bias"
    ENABLE_LN = "enable_layernorm"
    LN_TYPE = "layernorm_type"
    ZERO_CEN = "zero_centered_gamma"
    ACTIVATION = "activations"
    ATTRS = [
        {
            INTERMEDIATE_DIM: 2048,
            USE_BIAS: True,
            ENABLE_LN: True,
            LN_TYPE: "layernorm",
            ZERO_CEN: False,
            ACTIVATION: ("relu",),
        },
        {
            INTERMEDIATE_DIM: 2048,
            USE_BIAS: True,
            ENABLE_LN: True,
            LN_TYPE: "layernorm",
            ZERO_CEN: True,
            ACTIVATION: ("relu",),
        },
        {
            INTERMEDIATE_DIM: 2048,
            USE_BIAS: True,
            ENABLE_LN: True,
            LN_TYPE: "rmsnorm",
            ZERO_CEN: False,
            ACTIVATION: ("relu",),
        },
        {
            INTERMEDIATE_DIM: 2048,
            USE_BIAS: True,
            ENABLE_LN: True,
            LN_TYPE: "rmsnorm",
            ZERO_CEN: False,
            ACTIVATION: ("gelu", "linear"),
        },
        {
            INTERMEDIATE_DIM: 2048,
            USE_BIAS: False,
            ENABLE_LN: True,
            LN_TYPE: "rmsnorm",
            ZERO_CEN: False,
            ACTIVATION: ("gelu", "linear"),
        },
        {
            INTERMEDIATE_DIM: 2048,
            USE_BIAS: True,
            ENABLE_LN: True,
            LN_TYPE: "rmsnorm",
            ZERO_CEN: False,
            ACTIVATION: ("silu", "linear"),
        },
        {
            INTERMEDIATE_DIM: 2048,
            USE_BIAS: False,
            ENABLE_LN: True,
            LN_TYPE: "rmsnorm",
            ZERO_CEN: False,
            ACTIVATION: ("silu", "linear"),
        },
    ]


class TestLayerNormMLP(TestLayer):
    """Compare the Praxis ``LayerNormMLP`` wrapper with the Flax ``LayerNormMLP`` module."""

    def input_getter(self, shape, dtype):
        """Return a 1-tuple holding a standard-normal tensor of ``shape``/``dtype``."""
        data_key = jax.random.PRNGKey(seed=1234)
        return (jax.random.normal(data_key, shape, dtype),)

    def get_layer_name(self):
        """Key of this layer inside the Praxis variable and gradient trees."""
        return "ln_mlp"

    def generate_praxis_p_and_flax_cls(self, dtype, attrs):
        """Build a Praxis config and a Flax constructor from one ``LayerNormMLPAttr`` entry.

        Intermediate dropout is set to 0.0 on both sides so the outputs are
        deterministic and comparable.

        Returns:
            ``(praxis_p, flax_cls)``.
        """
        intermediate_dim = attrs[LayerNormMLPAttr.INTERMEDIATE_DIM]
        enable_layernorm = attrs[LayerNormMLPAttr.ENABLE_LN]
        layernorm_type = attrs[LayerNormMLPAttr.LN_TYPE]
        zero_centered_gamma = attrs[LayerNormMLPAttr.ZERO_CEN]
        kernel_init = WeightInit.Gaussian(1.0)
        use_bias = attrs[LayerNormMLPAttr.USE_BIAS]
        bias_init = WeightInit.Constant(0.0)
        activations = attrs[LayerNormMLPAttr.ACTIVATION]
        axis = -1
        transpose_batch_sequence = False

        praxis_p = pax_fiddle.Config(
            LayerNormMLP,
            name="ln_mlp",
            dtype=dtype,
            intermediate_dim=intermediate_dim,
            enable_layernorm=enable_layernorm,
            layernorm_type=layernorm_type,
            zero_centered_gamma=zero_centered_gamma,
            params_init=kernel_init,
            use_bias=use_bias,
            bias_init=bias_init,
            activations=activations,
            intermediate_dropout_rate=0.0,
            axis=axis,
            transpose_batch_sequence=transpose_batch_sequence,
        )
        # Flax initializers are derived from the same WeightInit objects.
        flax_cls = partial(
            flax_LayerNormMLP,
            intermediate_dim=intermediate_dim,
            enable_layernorm=enable_layernorm,
            layernorm_type=layernorm_type,
            zero_centered_gamma=zero_centered_gamma,
            kernel_init=TransformerEngineBaseLayer.generate_params_init("kernel", kernel_init),
            use_bias=use_bias,
            bias_init=TransformerEngineBaseLayer.generate_params_init("bias", bias_init),
            activations=activations,
            intermediate_dropout_rate=0.0,
            axis=axis,
            dtype=dtype,
            transpose_batch_sequence=transpose_batch_sequence,
        )

        return praxis_p, flax_cls

    @pytest.mark.parametrize("data_shape", DATA_SHAPE)
    @pytest.mark.parametrize("dtype", DTYPE)
    @pytest.mark.parametrize("attrs", LayerNormMLPAttr.ATTRS)
    def test_forward_backward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
        """Forward and backward results of both implementations must match."""
        praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
        self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize("data_shape", DATA_SHAPE)
    @pytest.mark.parametrize("dtype", DTYPE)
    @pytest.mark.parametrize("attrs", LayerNormMLPAttr.ATTRS)
    @pytest.mark.parametrize("fp8_format", FP8_FORMATS)
    def test_forward_backward_fp8(
        self, data_shape, dtype, attrs, fp8_format, rtol=1e-05, atol=1e-08
    ):
        """Same comparison, run inside ``fp8_autocast`` with the given FP8 format."""
        ds = DelayedScaling(fp8_format=fp8_format)
        with fp8_autocast(enabled=True, fp8_recipe=ds):
            praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
            self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol)


class TestRelativePositionBias(TestLayer):
    """Compare Praxis and Flax ``RelativePositionBiases`` (forward pass only)."""

    def get_layer_name(self):
        """Key of this layer inside the Praxis variable tree."""
        return "relative_position_bias"

    def generate_praxis_p_and_flax_cls(self, dtype, attrs):
        """Build matching Praxis/Flax relative-position-bias layers.

        ``attrs`` is not read; the test parametrizes it with ``{}`` only.

        Returns:
            ``(praxis_p, flax_cls)``.
        """
        num_buckets = 32
        max_distance = 128
        num_attention_heads = 64
        # Embedding init stddev = 1 / sqrt(num_attention_heads * num_buckets).
        rb_stddev = (num_attention_heads * num_buckets) ** -0.5
        embedding_init = WeightInit.Gaussian(rb_stddev)

        praxis_p = pax_fiddle.Config(
            RelativePositionBiases,
            name="relative_position_bias",
            dtype=dtype,
            num_buckets=num_buckets,
            max_distance=max_distance,
            num_attention_heads=num_attention_heads,
            embedding_init=embedding_init,
        )
        flax_cls = partial(
            flax_RelativePositionBiases,
            num_buckets=num_buckets,
            max_distance=max_distance,
            num_attention_heads=num_attention_heads,
            embedding_init=TransformerEngineBaseLayer.generate_params_init(
                "rel_embedding", embedding_init
            ),
            dtype=dtype,
        )

        return praxis_p, flax_cls

    @pytest.mark.parametrize("data_shape", DATA_SHAPE)
    @pytest.mark.parametrize("dtype", DTYPE)
    @pytest.mark.parametrize("attrs", [{}])
    def test_forward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
        """Outputs of both layers must match for each hard-coded call signature.

        ``data_shape`` only widens the parametrization; the layer inputs are the
        fixed tuples below. Their third element is a bool flag, presumably
        ``bidirectional`` -- confirm against the layer signature.
        """
        praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)

        init_key = jax.random.PRNGKey(seed=1234)
        # Module construction does not depend on the call inputs, so build once.
        praxis_layer = praxis_p.Instantiate()
        flax_layer = flax_cls()

        test_inputs = [(128, 128, True), (128, 128, False)]
        for test_input in test_inputs:
            praxis_variables = praxis_layer.init(init_key, *test_input)

            flax_variables = flax_layer.init(init_key, *test_input)
            # Drop axis-annotation collections if present (guarded to avoid KeyError).
            if "params_axes" in flax_variables:
                flax_variables, _ = flax.core.pop(flax_variables, "params_axes")
            fp8_axes = FP8Helper.FP8_COLLECTION_NAME + "_axes"
            if FP8Helper.is_fp8_enabled() and fp8_axes in flax_variables:
                flax_variables, _ = flax.core.pop(flax_variables, fp8_axes)

            praxis_variables, flax_variables = self.sync_variables(praxis_variables, flax_variables)

            # mean_out=False: compare the raw bias tensors, not a reduced loss.
            praxis_out = TestLayer.loss(
                praxis_variables, *test_input, module=praxis_layer, mean_out=False
            )
            flax_out = TestLayer.loss(
                flax_variables, *test_input, module=flax_layer, mean_out=False
            )

            assert_allclose(praxis_out, flax_out, rtol=rtol, atol=atol)


class DotProductAttnAttr:
    """Parametrization cases for ``TestDotProductAttn``.

    ``NUM_GQA_GROUPS`` is declared as a key name but no entry below sets it.
    """

    ATTN_MASK_TYPE = "attn_mask_type"
    NUM_GQA_GROUPS = "num_gqa_groups"
    TRANSPOSE_BS = "transpose_batch_sequence"
    SCALE_FACTOR = "scale_factor"
    ATTRS = [
        {
            ATTN_MASK_TYPE: "padding",
            TRANSPOSE_BS: True,
            SCALE_FACTOR: 0.125,
        },
        {
            ATTN_MASK_TYPE: "padding_causal",
            TRANSPOSE_BS: True,
            SCALE_FACTOR: 0.125,
        },
        {
            ATTN_MASK_TYPE: "causal",
            TRANSPOSE_BS: True,
            SCALE_FACTOR: 0.125,
        },
        {
            ATTN_MASK_TYPE: "padding",
            TRANSPOSE_BS: False,
            SCALE_FACTOR: 0.125,
        },
        {
            ATTN_MASK_TYPE: "padding_causal",
            TRANSPOSE_BS: False,
            SCALE_FACTOR: 2.0,
        },
        {
            ATTN_MASK_TYPE: "causal",
            TRANSPOSE_BS: False,
            SCALE_FACTOR: 1.0,
        },
        {
            ATTN_MASK_TYPE: "no_mask",
            TRANSPOSE_BS: False,
            SCALE_FACTOR: 1.0,
        },
    ]


class TestDotProductAttn(TestLayer):
    """Compare the Praxis and Flax ``DotProductAttention`` layers."""

    def input_getter(self, shape, dtype):
        """Return ``[query, key, value, mask]``.

        ``shape`` is parametrized as e.g. ``(32, 128, 16, 64)``; its first two
        entries give ``b`` and ``s`` for the all-zero ``(b, 1, s, s)`` uint8
        mask. When ``self.attrs`` (set by the test) requests
        ``transpose_batch_sequence``, the first two axes of q/k/v are swapped.
        """
        key = jax.random.PRNGKey(seed=1234)
        q_key, k_key, v_key = jax.random.split(key, 3)
        b, s, *_ = shape
        if self.attrs[DotProductAttnAttr.TRANSPOSE_BS]:
            shape = (shape[1], shape[0]) + shape[2:]
        mask = jnp.zeros((b, 1, s, s), dtype=jnp.uint8)
        return [
            *map(partial(jax.random.normal, shape=shape, dtype=dtype), [q_key, k_key, v_key]),
            mask,
        ]

    def get_layer_name(self):
        """Key of this layer inside the Praxis variable and gradient trees."""
        return "dot_product_attn"

    def generate_praxis_p_and_flax_cls(self, dtype, attrs):
        """Build matching Praxis/Flax attention layers from one ``DotProductAttnAttr`` entry.

        Returns:
            ``(praxis_p, flax_cls)``.
        """
        head_dim = 64
        num_attention_heads = 16
        num_gqa_groups = num_attention_heads
        attn_mask_type = attrs[DotProductAttnAttr.ATTN_MASK_TYPE]
        transpose_batch_sequence = attrs[DotProductAttnAttr.TRANSPOSE_BS]
        # Forward the per-case scale factor so the 0.125 / 1.0 / 2.0 cases in
        # DotProductAttnAttr are actually exercised. None is passed when absent.
        scale_factor = attrs.get(DotProductAttnAttr.SCALE_FACTOR)

        praxis_p = pax_fiddle.Config(
            DotProductAttention,
            name="mha",
            dtype=dtype,
            head_dim=head_dim,
            num_attention_heads=num_attention_heads,
            num_gqa_groups=num_gqa_groups,
            attn_mask_type=attn_mask_type,
            scale_factor=scale_factor,
            transpose_batch_sequence=transpose_batch_sequence,
        )
        flax_cls = partial(
            flax_DotProductAttention,
            dtype=dtype,
            head_dim=head_dim,
            num_attention_heads=num_attention_heads,
            num_gqa_groups=num_gqa_groups,
            attn_mask_type=attn_mask_type,
            scale_factor=scale_factor,
            transpose_batch_sequence=transpose_batch_sequence,
        )

        return praxis_p, flax_cls

    @pytest.mark.parametrize("data_shape", [(32, 128, 16, 64)])
    @pytest.mark.parametrize("dtype", DTYPE)
    @pytest.mark.parametrize("attrs", DotProductAttnAttr.ATTRS)
    def test_forward_backward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
        """Forward and backward results of both implementations must match."""
        # input_getter reads the transpose flag from self.attrs.
        self.attrs = attrs
        praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
        self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol)


class MultiHeadAttnAttr:
    """Parametrization cases for ``TestMultiHeadAttn``.

    ``NUM_ATTN_HEADS``, ``NUM_GQA_GROUPS`` and ``LORA_SCOPE`` are optional keys
    that only some entries set.
    """

    USE_BIAS = "use_bias"
    LN_TYPE = "layernorm_type"
    ATTN_MASK_TYPE = "attn_mask_type"
    ZERO_CEN = "zero_centered_gamma"
    NUM_ATTN_HEADS = "num_attention_heads"
    NUM_GQA_GROUPS = "num_gqa_groups"
    TRANSPOSE_BS = "transpose_batch_sequence"
    ENABLE_ROPE = "enable_rotary_pos_emb"
    ROPE_GROUP_METHOD = "rotary_pos_emb_group_method"
    LORA_SCOPE = "low_rank_adaptation_scope"
    ATTRS = [
        {
            USE_BIAS: True,
            LN_TYPE: "layernorm",
            ZERO_CEN: False,
            ENABLE_ROPE: False,
            ROPE_GROUP_METHOD: "consecutive",
            ATTN_MASK_TYPE: "padding",
            TRANSPOSE_BS: True,
        },
        {
            USE_BIAS: True,
            LN_TYPE: "layernorm",
            ZERO_CEN: True,
            ENABLE_ROPE: False,
            ROPE_GROUP_METHOD: "consecutive",
            ATTN_MASK_TYPE: "padding",
            TRANSPOSE_BS: False,
        },
        {
            USE_BIAS: True,
            LN_TYPE: "rmsnorm",
            ZERO_CEN: False,
            ENABLE_ROPE: False,
            ROPE_GROUP_METHOD: "consecutive",
            ATTN_MASK_TYPE: "padding",
            TRANSPOSE_BS: True,
        },
        {
            USE_BIAS: True,
            LN_TYPE: "layernorm",
            ZERO_CEN: False,
            ENABLE_ROPE: False,
            ROPE_GROUP_METHOD: "consecutive",
            ATTN_MASK_TYPE: "causal",
            TRANSPOSE_BS: False,
        },
        {
            USE_BIAS: True,
            LN_TYPE: "layernorm",
            ZERO_CEN: True,
            ENABLE_ROPE: False,
            ROPE_GROUP_METHOD: "consecutive",
            ATTN_MASK_TYPE: "causal",
            TRANSPOSE_BS: True,
        },
        {
            USE_BIAS: True,
            LN_TYPE: "rmsnorm",
            ZERO_CEN: False,
            ENABLE_ROPE: False,
            ROPE_GROUP_METHOD: "consecutive",
            ATTN_MASK_TYPE: "causal",
            TRANSPOSE_BS: False,
        },
        {
            USE_BIAS: True,
            LN_TYPE: "rmsnorm",
            ZERO_CEN: False,
            ENABLE_ROPE: False,
            ROPE_GROUP_METHOD: "consecutive",
            NUM_ATTN_HEADS: 8,
            NUM_GQA_GROUPS: 4,
            ATTN_MASK_TYPE: "causal",
            TRANSPOSE_BS: True,
        },
        {
            USE_BIAS: True,
            LN_TYPE: "rmsnorm",
            ZERO_CEN: False,
            ENABLE_ROPE: True,
            ROPE_GROUP_METHOD: "consecutive",
            NUM_ATTN_HEADS: 8,
            NUM_GQA_GROUPS: 4,
            ATTN_MASK_TYPE: "causal",
            TRANSPOSE_BS: False,
        },
        {
            USE_BIAS: True,
            LN_TYPE: "rmsnorm",
            ZERO_CEN: False,
            ENABLE_ROPE: True,
            ROPE_GROUP_METHOD: "alternate",
            NUM_ATTN_HEADS: 8,
            NUM_GQA_GROUPS: 4,
            ATTN_MASK_TYPE: "causal",
            TRANSPOSE_BS: True,
        },
        {
            USE_BIAS: True,
            LN_TYPE: "layernorm",
            ZERO_CEN: False,
            ENABLE_ROPE: False,
            ROPE_GROUP_METHOD: "consecutive",
            ATTN_MASK_TYPE: "padding",
            LORA_SCOPE: "all",
            TRANSPOSE_BS: False,
        },
        {
            USE_BIAS: True,
            LN_TYPE: "layernorm",
            ZERO_CEN: False,
            ENABLE_ROPE: False,
            ROPE_GROUP_METHOD: "consecutive",
            ATTN_MASK_TYPE: "causal",
            LORA_SCOPE: "all",
            TRANSPOSE_BS: True,
        },
    ]


class TestMultiHeadAttn(TestLayer):
    """Check the Praxis ``MultiHeadAttention`` wrapper against the Flax module.

    For every case in ``MultiHeadAttnAttr.ATTRS`` an identically configured Praxis
    config and Flax module partial are built and handed to ``forward_backward_runner``
    (inherited from ``TestLayer``, defined elsewhere in this file), which presumably
    compares their forward/backward results within ``rtol``/``atol``.
    """

    def input_getter(self, shape, dtype):
        """Return ``[query, key_value, mask]`` inputs for one layer call.

        ``shape`` is unpacked as ``(b, s, *rest)``. When the current case sets
        ``transpose_batch_sequence`` the first two axes of the random query and
        key/value tensors are swapped. The mask is an all-zero ``uint8`` array of
        shape ``(b, 1, s, s)`` and is not transposed.
        """
        key = jax.random.PRNGKey(seed=1234)
        q_key, kv_key = jax.random.split(key, 2)
        b, s, *_ = shape
        if self.attrs[MultiHeadAttnAttr.TRANSPOSE_BS]:
            shape = (shape[1], shape[0]) + shape[2:]
        mask = jnp.zeros((b, 1, s, s), dtype=jnp.uint8)
        return [*map(partial(jax.random.normal, shape=shape, dtype=dtype), [q_key, kv_key]), mask]

    def get_layer_name(self):
        """Name used for this layer by the ``TestLayer`` harness."""
        return "multi_head_attn"

    def generate_praxis_p_and_flax_cls(self, dtype, attrs):
        """Build ``(praxis_p, flax_cls)`` configured identically from ``attrs``.

        Args:
            dtype: compute dtype passed to both layers.
            attrs: one case dict from ``MultiHeadAttnAttr.ATTRS``. ``NUM_ATTN_HEADS``,
                ``NUM_GQA_GROUPS`` and ``LORA_SCOPE`` are optional and default to
                16, ``None`` and ``"none"`` respectively.

        Returns:
            A ``pax_fiddle.Config`` for the Praxis layer and a ``functools.partial``
            of the Flax module. Both receive the same kwargs; only the parameter
            initializer arguments differ in form.
        """
        kernel_init = WeightInit.Gaussian(1.0)
        bias_init = WeightInit.Constant(0.0)

        # Every setting that both implementations accept under the same name comes
        # from this single dict, so the two configurations cannot drift apart.
        shared_kwargs = dict(
            dtype=dtype,
            head_dim=64,
            # Honor a per-case head count (the GQA cases request 8); the previous
            # hard-coded 16 is kept as the default.
            num_attention_heads=attrs.get(MultiHeadAttnAttr.NUM_ATTN_HEADS, 16),
            num_gqa_groups=attrs.get(MultiHeadAttnAttr.NUM_GQA_GROUPS, None),
            layernorm_type=attrs[MultiHeadAttnAttr.LN_TYPE],
            zero_centered_gamma=attrs[MultiHeadAttnAttr.ZERO_CEN],
            use_bias=attrs[MultiHeadAttnAttr.USE_BIAS],
            return_layernorm_output=False,
            input_layernorm=False,
            attn_mask_type=attrs[MultiHeadAttnAttr.ATTN_MASK_TYPE],
            enable_rotary_pos_emb=attrs[MultiHeadAttnAttr.ENABLE_ROPE],
            rotary_pos_emb_group_method=attrs[MultiHeadAttnAttr.ROPE_GROUP_METHOD],
            low_rank_adaptation_scope=attrs.get(MultiHeadAttnAttr.LORA_SCOPE, "none"),
            fuse_qkv_params=True,
            transpose_batch_sequence=attrs[MultiHeadAttnAttr.TRANSPOSE_BS],
            scale_attn_logits=False,
            scaled_query_init=True,
            float32_logits=False,
        )

        praxis_p = pax_fiddle.Config(
            MultiHeadAttention,
            name="mha",
            params_init=kernel_init,
            bias_init=bias_init,
            **shared_kwargs,
        )
        # The Flax module takes initializer callables, derived here from the same
        # Praxis WeightInit objects used above.
        flax_cls = partial(
            flax_MultiHeadAttention,
            kernel_init=TransformerEngineBaseLayer.generate_params_init("kernel", kernel_init),
            bias_init=TransformerEngineBaseLayer.generate_params_init("bias", bias_init),
            **shared_kwargs,
        )

        return praxis_p, flax_cls

    @pytest.mark.parametrize("data_shape", DATA_SHAPE)
    @pytest.mark.parametrize("dtype", DTYPE)
    @pytest.mark.parametrize("attrs", MultiHeadAttnAttr.ATTRS)
    def test_forward_backward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
        """Compare Praxis and Flax outputs without FP8."""
        # input_getter reads self.attrs to pick the batch/sequence layout.
        self.attrs = attrs
        praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
        self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize("data_shape", DATA_SHAPE)
    @pytest.mark.parametrize("dtype", DTYPE)
    @pytest.mark.parametrize("attrs", MultiHeadAttnAttr.ATTRS)
    @pytest.mark.parametrize("fp8_format", FP8_FORMATS)
    def test_forward_backward_fp8(
        self, data_shape, dtype, attrs, fp8_format, rtol=1e-05, atol=1e-08
    ):
        """Compare Praxis and Flax outputs under a DelayedScaling FP8 recipe."""
        # input_getter reads self.attrs to pick the batch/sequence layout.
        self.attrs = attrs
        ds = DelayedScaling(fp8_format=fp8_format)
        # Both layers are configured and run inside the same FP8 autocast context.
        with fp8_autocast(enabled=True, fp8_recipe=ds):
            praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
            self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol)


class TransformerLayerAttr:
    """Attribute keys and parametrization cases for ``TestTransformer``.

    The class-level string constants are the keys of the case dicts in ``ATTRS``;
    each dict in ``ATTRS`` becomes one pytest parametrization. ``LORA_SCOPE`` is
    optional in a case and is only present where LoRA is exercised. The list
    order is kept stable because pytest derives the case ids from it.
    """

    USE_BIAS = "use_bias"
    LN_TYPE = "layernorm_type"
    ACTIVATION = "activations"
    LYR_TYPE = "layer_type"
    ZERO_CEN = "zero_centered_gamma"
    TRANSPOSE_BS = "transpose_batch_sequence"
    ENABLE_ROPE = "enable_rotary_pos_emb"
    ROPE_GROUP_METHOD = "rotary_pos_emb_group_method"
    LORA_SCOPE = "low_rank_adaptation_scope"
    ATTRS = [
        # Encoder, ReLU MLP: layernorm / zero-centered layernorm / rmsnorm.
        {
            USE_BIAS: True,
            LN_TYPE: "layernorm",
            ZERO_CEN: False,
            ACTIVATION: ("relu",),
            LYR_TYPE: TransformerLayerType.ENCODER,
            ENABLE_ROPE: False,
            ROPE_GROUP_METHOD: "consecutive",
            TRANSPOSE_BS: True,
        },
        {
            USE_BIAS: True,
            LN_TYPE: "layernorm",
            ZERO_CEN: False,
            ACTIVATION: ("relu",),
            LYR_TYPE: TransformerLayerType.ENCODER,
            ENABLE_ROPE: False,
            ROPE_GROUP_METHOD: "consecutive",
            TRANSPOSE_BS: False,
        },
        {
            USE_BIAS: True,
            LN_TYPE: "layernorm",
            ZERO_CEN: True,
            ACTIVATION: ("relu",),
            LYR_TYPE: TransformerLayerType.ENCODER,
            ENABLE_ROPE: False,
            ROPE_GROUP_METHOD: "consecutive",
            TRANSPOSE_BS: True,
        },
        {
            USE_BIAS: True,
            LN_TYPE: "layernorm",
            ZERO_CEN: True,
            ACTIVATION: ("relu",),
            LYR_TYPE: TransformerLayerType.ENCODER,
            ENABLE_ROPE: False,
            ROPE_GROUP_METHOD: "consecutive",
            TRANSPOSE_BS: False,
        },
        {
            USE_BIAS: True,
            LN_TYPE: "rmsnorm",
            ZERO_CEN: False,
            ACTIVATION: ("relu",),
            LYR_TYPE: TransformerLayerType.ENCODER,
            ENABLE_ROPE: False,
            ROPE_GROUP_METHOD: "consecutive",
            TRANSPOSE_BS: True,
        },
        {
            USE_BIAS: True,
            LN_TYPE: "rmsnorm",
            ZERO_CEN: False,
            ACTIVATION: ("relu",),
            LYR_TYPE: TransformerLayerType.ENCODER,
            ENABLE_ROPE: False,
            ROPE_GROUP_METHOD: "consecutive",
            TRANSPOSE_BS: False,
        },
        # Decoder, ReLU MLP: zero-centered layernorm / layernorm / rmsnorm.
        {
            USE_BIAS: True,
            LN_TYPE: "layernorm",
            ZERO_CEN: True,
            ACTIVATION: ("relu",),
            LYR_TYPE: TransformerLayerType.DECODER,
            ENABLE_ROPE: False,
            ROPE_GROUP_METHOD: "consecutive",
            TRANSPOSE_BS: True,
        },
        {
            USE_BIAS: True,
            LN_TYPE: "layernorm",
            ZERO_CEN: True,
            ACTIVATION: ("relu",),
            LYR_TYPE: TransformerLayerType.DECODER,
            ENABLE_ROPE: False,
            ROPE_GROUP_METHOD: "consecutive",
            TRANSPOSE_BS: False,
        },
        {
            USE_BIAS: True,
            LN_TYPE: "layernorm",
            ZERO_CEN: False,
            ACTIVATION: ("relu",),
            LYR_TYPE: TransformerLayerType.DECODER,
            ENABLE_ROPE: False,
            ROPE_GROUP_METHOD: "consecutive",
            TRANSPOSE_BS: True,
        },
        {
            USE_BIAS: True,
            LN_TYPE: "layernorm",
            ZERO_CEN: False,
            ACTIVATION: ("relu",),
            LYR_TYPE: TransformerLayerType.DECODER,
            ENABLE_ROPE: False,
            ROPE_GROUP_METHOD: "consecutive",
            TRANSPOSE_BS: False,
        },
        {
            USE_BIAS: True,
            LN_TYPE: "rmsnorm",
            ZERO_CEN: False,
            ACTIVATION: ("relu",),
            LYR_TYPE: TransformerLayerType.DECODER,
            ENABLE_ROPE: False,
            ROPE_GROUP_METHOD: "consecutive",
            TRANSPOSE_BS: True,
        },
        {
            USE_BIAS: True,
            LN_TYPE: "rmsnorm",
            ZERO_CEN: False,
            ACTIVATION: ("relu",),
            LYR_TYPE: TransformerLayerType.DECODER,
            ENABLE_ROPE: False,
            ROPE_GROUP_METHOD: "consecutive",
            TRANSPOSE_BS: False,
        },
        # Encoder, gated ("gelu", "linear") MLP.
        {
            USE_BIAS: True,
            LN_TYPE: "layernorm",
            ZERO_CEN: False,
            ACTIVATION: ("gelu", "linear"),
            LYR_TYPE: TransformerLayerType.ENCODER,
            ENABLE_ROPE: False,
            ROPE_GROUP_METHOD: "consecutive",
            TRANSPOSE_BS: True,
        },
        {
            USE_BIAS: True,
            LN_TYPE: "layernorm",
            ZERO_CEN: False,
            ACTIVATION: ("gelu", "linear"),
            LYR_TYPE: TransformerLayerType.ENCODER,
            ENABLE_ROPE: False,
            ROPE_GROUP_METHOD: "consecutive",
            TRANSPOSE_BS: False,
        },
        {
            USE_BIAS: True,
            LN_TYPE: "rmsnorm",
            ZERO_CEN: False,
            ACTIVATION: ("gelu", "linear"),
            LYR_TYPE: TransformerLayerType.ENCODER,
            ENABLE_ROPE: False,
            ROPE_GROUP_METHOD: "consecutive",
            TRANSPOSE_BS: True,
        },
        {
            USE_BIAS: True,
            LN_TYPE: "rmsnorm",
            ZERO_CEN: False,
            ACTIVATION: ("gelu", "linear"),
            LYR_TYPE: TransformerLayerType.ENCODER,
            ENABLE_ROPE: False,
            ROPE_GROUP_METHOD: "consecutive",
            TRANSPOSE_BS: False,
        },
        # Encoder with LoRA applied to all supported projections.
        {
            USE_BIAS: True,
            LN_TYPE: "layernorm",
            ZERO_CEN: False,
            ACTIVATION: ("gelu",),
            LYR_TYPE: TransformerLayerType.ENCODER,
            ENABLE_ROPE: False,
            ROPE_GROUP_METHOD: "consecutive",
            TRANSPOSE_BS: False,
            LORA_SCOPE: "all",
        },
        # Decoder, gated ("gelu", "linear") MLP.
        {
            USE_BIAS: True,
            LN_TYPE: "layernorm",
            ZERO_CEN: False,
            ACTIVATION: ("gelu", "linear"),
            LYR_TYPE: TransformerLayerType.DECODER,
            ENABLE_ROPE: False,
            ROPE_GROUP_METHOD: "consecutive",
            TRANSPOSE_BS: True,
        },
        {
            USE_BIAS: True,
            LN_TYPE: "layernorm",
            ZERO_CEN: False,
            ACTIVATION: ("gelu", "linear"),
            LYR_TYPE: TransformerLayerType.DECODER,
            ENABLE_ROPE: False,
            ROPE_GROUP_METHOD: "consecutive",
            TRANSPOSE_BS: False,
        },
        {
            USE_BIAS: True,
            LN_TYPE: "rmsnorm",
            ZERO_CEN: False,
            ACTIVATION: ("gelu", "linear"),
            LYR_TYPE: TransformerLayerType.DECODER,
            ENABLE_ROPE: False,
            ROPE_GROUP_METHOD: "consecutive",
            TRANSPOSE_BS: True,
        },
        {
            USE_BIAS: True,
            LN_TYPE: "rmsnorm",
            ZERO_CEN: False,
            ACTIVATION: ("gelu", "linear"),
            LYR_TYPE: TransformerLayerType.DECODER,
            ENABLE_ROPE: False,
            ROPE_GROUP_METHOD: "consecutive",
            TRANSPOSE_BS: False,
        },
        # Rotary position embeddings, both group methods, encoder and decoder.
        {
            USE_BIAS: True,
            LN_TYPE: "layernorm",
            ZERO_CEN: True,
            ACTIVATION: ("gelu",),
            LYR_TYPE: TransformerLayerType.ENCODER,
            ENABLE_ROPE: True,
            ROPE_GROUP_METHOD: "alternate",
            TRANSPOSE_BS: False,
        },
        {
            USE_BIAS: True,
            LN_TYPE: "layernorm",
            ZERO_CEN: True,
            ACTIVATION: ("gelu",),
            LYR_TYPE: TransformerLayerType.DECODER,
            ENABLE_ROPE: True,
            ROPE_GROUP_METHOD: "alternate",
            TRANSPOSE_BS: False,
        },
        {
            USE_BIAS: True,
            LN_TYPE: "layernorm",
            ZERO_CEN: True,
            ACTIVATION: ("gelu",),
            LYR_TYPE: TransformerLayerType.ENCODER,
            ENABLE_ROPE: True,
            ROPE_GROUP_METHOD: "consecutive",
            TRANSPOSE_BS: False,
        },
        {
            USE_BIAS: True,
            LN_TYPE: "layernorm",
            ZERO_CEN: True,
            ACTIVATION: ("gelu",),
            LYR_TYPE: TransformerLayerType.DECODER,
            ENABLE_ROPE: True,
            ROPE_GROUP_METHOD: "consecutive",
            TRANSPOSE_BS: False,
        },
        # Decoder with LoRA applied to all supported projections.
        {
            USE_BIAS: True,
            LN_TYPE: "layernorm",
            ZERO_CEN: False,
            ACTIVATION: ("gelu",),
            LYR_TYPE: TransformerLayerType.DECODER,
            ENABLE_ROPE: False,
            ROPE_GROUP_METHOD: "consecutive",
            TRANSPOSE_BS: False,
            LORA_SCOPE: "all",
        },
    ]
1250
1251
1252
1253
1254


class TestTransformer(TestLayer):
    """Check the Praxis ``TransformerLayer`` wrapper against the Flax module.

    For every case in ``TransformerLayerAttr.ATTRS`` an identically configured Praxis
    config and Flax module partial are built and handed to ``forward_backward_runner``
    (inherited from ``TestLayer``, defined elsewhere in this file), which presumably
    compares their forward/backward results within ``rtol``/``atol``.
    """

    def input_getter(self, shape, dtype):
        """Return ``[query, key_value, mask, mask]`` inputs for one layer call.

        ``shape`` is unpacked as ``(b, s, *rest)``. When the current case sets
        ``transpose_batch_sequence`` the first two axes of the random query and
        key/value tensors are swapped. The same all-zero ``uint8`` mask of shape
        ``(b, 1, s, s)`` is passed twice (presumably the self- and cross-attention
        masks; confirm against ``TransformerLayer.__call__``).
        """
        key = jax.random.PRNGKey(seed=1234)
        q_key, kv_key = jax.random.split(key, 2)
        b, s, *_ = shape
        if self.attrs[TransformerLayerAttr.TRANSPOSE_BS]:
            shape = (shape[1], shape[0]) + shape[2:]
        mask = jnp.zeros((b, 1, s, s), dtype=jnp.uint8)
        return [
            *map(partial(jax.random.normal, shape=shape, dtype=dtype), [q_key, kv_key]),
            mask,
            mask,
        ]

    def get_layer_name(self):
        """Name used for this layer by the ``TestLayer`` harness."""
        return "transformerlayer"

    def generate_praxis_p_and_flax_cls(self, dtype, attrs):
        """Build ``(praxis_p, flax_cls)`` configured identically from ``attrs``.

        Args:
            dtype: compute dtype passed to both layers.
            attrs: one case dict from ``TransformerLayerAttr.ATTRS``. ``LORA_SCOPE``
                is optional and defaults to ``"none"``.

        Returns:
            A ``pax_fiddle.Config`` for the Praxis layer and a ``functools.partial``
            of the Flax module. Both receive the same shared kwargs. They differ
            only in how parameter initializers and the relative-position-bias
            submodule are supplied.
        """
        num_attention_heads = 8
        kernel_init = WeightInit.Gaussian(1.0)
        bias_init = WeightInit.Constant(0.0)

        # Every setting that both implementations accept under the same name comes
        # from this single dict, so the two configurations cannot drift apart.
        shared_kwargs = dict(
            dtype=dtype,
            hidden_size=512,
            mlp_hidden_size=2048,
            num_attention_heads=num_attention_heads,
            layernorm_type=attrs[TransformerLayerAttr.LN_TYPE],
            # Previously declared in every case but never forwarded, so the
            # zero-centered-gamma cases were not actually exercised.
            zero_centered_gamma=attrs[TransformerLayerAttr.ZERO_CEN],
            hidden_dropout=0.0,
            attention_dropout=0.0,
            intermediate_dropout=0.0,
            mlp_activations=attrs[TransformerLayerAttr.ACTIVATION],
            use_bias=attrs[TransformerLayerAttr.USE_BIAS],
            layer_type=attrs[TransformerLayerAttr.LYR_TYPE],
            enable_relative_embedding=True,
            enable_rotary_pos_emb=attrs[TransformerLayerAttr.ENABLE_ROPE],
            rotary_pos_emb_group_method=attrs[TransformerLayerAttr.ROPE_GROUP_METHOD],
            low_rank_adaptation_scope=attrs.get(TransformerLayerAttr.LORA_SCOPE, "none"),
            drop_path=0.0,
            transpose_batch_sequence=attrs[TransformerLayerAttr.TRANSPOSE_BS],
        )

        # The Praxis layer takes the relative-position-bias submodule as a config.
        # The Flax layer needs a concrete module built from that config's fields,
        # with an initializer derived from the same embedding init.
        relative_embedding = pax_fiddle.Config(
            RelativePositionBiases, dtype=dtype, num_attention_heads=num_attention_heads
        )
        rel_embedding_init = RelativePositionBiases.generate_embedding_init(
            relative_embedding.embedding_init,
            relative_embedding.num_attention_heads,
            relative_embedding.num_buckets,
        )
        relative_embedding_flax_module = flax_RelativePositionBiases(
            num_buckets=relative_embedding.num_buckets,
            max_distance=relative_embedding.max_distance,
            num_attention_heads=relative_embedding.num_attention_heads,
            embedding_init=TransformerEngineBaseLayer.generate_params_init(
                "rel_embedding", rel_embedding_init
            ),
            embedding_axes=relative_embedding.embedding_axes,
            dtype=relative_embedding.dtype,
        )

        praxis_p = pax_fiddle.Config(
            TransformerLayer,
            name="transformer_layer",
            params_init=kernel_init,
            bias_init=bias_init,
            relative_embedding=relative_embedding,
            **shared_kwargs,
        )
        flax_cls = partial(
            flax_TransformerLayer,
            mha_kernel_init=TransformerEngineBaseLayer.generate_params_init(
                "mha_kernel", kernel_init
            ),
            mlp_kernel_init=TransformerEngineBaseLayer.generate_params_init(
                "mlp_kernel", kernel_init
            ),
            bias_init=TransformerEngineBaseLayer.generate_params_init("bias", bias_init),
            relative_embedding=relative_embedding_flax_module,
            **shared_kwargs,
        )

        return praxis_p, flax_cls

    @pytest.mark.parametrize("data_shape", DATA_SHAPE)
    @pytest.mark.parametrize("dtype", DTYPE)
    @pytest.mark.parametrize("attrs", TransformerLayerAttr.ATTRS)
    def test_forward_backward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
        """Compare Praxis and Flax outputs without FP8."""
        # input_getter reads self.attrs to pick the batch/sequence layout.
        self.attrs = attrs
        praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
        self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize("data_shape", DATA_SHAPE)
    @pytest.mark.parametrize("dtype", DTYPE)
    @pytest.mark.parametrize("attrs", TransformerLayerAttr.ATTRS)
    @pytest.mark.parametrize("fp8_format", FP8_FORMATS)
    def test_forward_backward_fp8(
        self, data_shape, dtype, attrs, fp8_format, rtol=1e-05, atol=1e-08
    ):
        """Compare Praxis and Flax outputs under a DelayedScaling FP8 recipe."""
        # input_getter reads self.attrs to pick the batch/sequence layout.
        self.attrs = attrs
        ds = DelayedScaling(fp8_format=fp8_format)
        # Both layers are configured and run inside the same FP8 autocast context.
        with fp8_autocast(enabled=True, fp8_recipe=ds):
            praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
            self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol)