# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Test transformer_engine.jax.flax.TransformerLayer"""
import os
from functools import partial
from typing import Dict, Tuple, Optional

import flax
import jax
import jax.numpy as jnp
import pytest

from utils import (
    assert_allclose,
    assert_tree_like_allclose,
    dtype_tols,
    sync_params_values,
)
from utils import DecoderLayer as RefDecoderLayer
from utils import EncoderLayer as RefEncoderLayer

from transformer_engine.common import recipe
from transformer_engine.jax.flax import TransformerLayer, TransformerLayerType
from transformer_engine.jax.quantize import (
    QuantizeConfig,
    ScalingMode,
    is_fp8_available,
    update_collections,
)


@pytest.fixture(autouse=True, scope="function")
def enable_fused_attn():
    """Enable fused attention"""
    os.environ["NVTE_FUSED_ATTN"] = "1"
    yield
    del os.environ["NVTE_FUSED_ATTN"]


is_fp8_supported, reason = is_fp8_available()
is_mxfp8_supported, reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING)

QUANTIZE_RECIPES = []
""" Find supported scaling modes"""
if is_fp8_supported:
    QUANTIZE_RECIPES.append(pytest.param(recipe.DelayedScaling(), id="DelayedScaling"))
if is_mxfp8_supported:
    QUANTIZE_RECIPES.append(pytest.param(recipe.MXFP8BlockScaling(), id="MXFP8BlockScaling"))


DATA_SHAPE = [  # (batch, seqlen, emb_dim)
    pytest.param((32, 128, 1024), id="32-128-1024"),
]
DTYPE = [jnp.bfloat16]

_KEY_OF_RESIDUAL_POST_LAYERNORM = "apply_residual_connection_post_layernorm"
_KEY_OF_OUTPUT_LAYERNORM = "output_layernorm"
_KEY_OF_DROP_PATH = "drop_path"
_KEY_OF_FUSE_QKV_PARAMS = "fuse_qkv_params"
_KEY_OF_HIDDEN_DROPOUT = "hidden_dropout"
_KEY_OF_ATTENTION_DROPOUT = "attention_dropout"
_KEY_OF_INTERMEDIATE_DROPOUT = "intermediate_dropout"
_KEY_OF_HIDDEN_DROPOUT_DIMS = "hidden_dropout_dims"
_KEY_OF_INTERMEDIATE_DROPOUT_DIMS = "intermediate_dropout_dims"
_KEY_OF_MLP_ACTIVATIONS = "mlp_activations"
_KEY_OF_LAYERNORM_TYPE = "layernorm_type"
_KEY_OF_LAYERNORM_EPS = "layernorm_epsilon"
_KEY_OF_ZERO_CENTERED_GAMMA = "zero_centered_gamma"
_KEY_OF_TRANSPOSE_BS = "transpose_batch_sequence"
_KEY_OF_SCALE_ATTN_LOGITS = "scale_attn_logits"
_KEY_OF_NUM_HEADS = "num_attention_heads"
_KEY_OF_NUM_GQA_GROUPS = "num_gqa_groups"
_KEY_OF_ENABLE_ROPE = "enable_rotary_pos_emb"
_KEY_OF_ROPE_GROUP_METHOD = "rotary_pos_emb_group_method"
_KEY_OF_SELF_ATTN_BIAS_TYPE = "self_attn_bias_type"
_KEY_OF_SELF_ATTN_MASK_TYPE = "self_attn_mask_type"
_KEY_OF_FLOAT32_ATTENTION_LOGITS = "float32_attention_logits"
_KEY_OF_USE_BIAS = "use_bias"
_KEY_OF_RELATIVE_EMBEDDING = "enable_relative_embedding"
_KEY_OF_WINDOW_SIZE = "window_size"

BASE_ATTRS = {
    _KEY_OF_TRANSPOSE_BS: True,
    _KEY_OF_NUM_HEADS: 8,
    _KEY_OF_HIDDEN_DROPOUT: 0,
    _KEY_OF_ATTENTION_DROPOUT: 0.0,
    _KEY_OF_INTERMEDIATE_DROPOUT: 0,
    _KEY_OF_SELF_ATTN_MASK_TYPE: "padding_causal",
    _KEY_OF_LAYERNORM_TYPE: "layernorm",
    _KEY_OF_WINDOW_SIZE: (-1, -1),
}

ATTRS = [
    # attrs0
    {},
    # attrs1
    {
        _KEY_OF_LAYERNORM_TYPE: "rmsnorm",
    },
    # attrs2
    {
        _KEY_OF_ZERO_CENTERED_GAMMA: True,
        _KEY_OF_LAYERNORM_EPS: 1e-2,
    },
    # attrs3
    {_KEY_OF_LAYERNORM_TYPE: "rmsnorm", _KEY_OF_RESIDUAL_POST_LAYERNORM: True},
    # attrs4
    {_KEY_OF_LAYERNORM_TYPE: "rmsnorm", _KEY_OF_OUTPUT_LAYERNORM: True},
    # attrs5
    {
        _KEY_OF_LAYERNORM_TYPE: "rmsnorm",
        _KEY_OF_RESIDUAL_POST_LAYERNORM: True,
        _KEY_OF_OUTPUT_LAYERNORM: True,
    },
    # attrs6
    {_KEY_OF_LAYERNORM_TYPE: "rmsnorm", _KEY_OF_DROP_PATH: 0.1},
    # attrs7
    {_KEY_OF_LAYERNORM_TYPE: "rmsnorm", _KEY_OF_FUSE_QKV_PARAMS: False},
    # attrs8
    {
        _KEY_OF_LAYERNORM_TYPE: "rmsnorm",
        _KEY_OF_MLP_ACTIVATIONS: ("gelu", "linear"),
    },
    # attrs9
    {
        _KEY_OF_SCALE_ATTN_LOGITS: True,
        _KEY_OF_LAYERNORM_TYPE: "rmsnorm",
        _KEY_OF_HIDDEN_DROPOUT: 0.8,
        _KEY_OF_INTERMEDIATE_DROPOUT: 0.5,
        _KEY_OF_MLP_ACTIVATIONS: ("gelu", "linear"),
        _KEY_OF_USE_BIAS: True,
    },
    # attrs10
    {
        _KEY_OF_TRANSPOSE_BS: False,
        _KEY_OF_SCALE_ATTN_LOGITS: True,
        _KEY_OF_LAYERNORM_TYPE: "rmsnorm",
        _KEY_OF_MLP_ACTIVATIONS: ("gelu", "linear"),
    },
    # attrs11
    {
        _KEY_OF_NUM_HEADS: 8,
        _KEY_OF_NUM_GQA_GROUPS: 4,
        _KEY_OF_TRANSPOSE_BS: False,
        _KEY_OF_SCALE_ATTN_LOGITS: True,
        _KEY_OF_MLP_ACTIVATIONS: ("gelu",),
        _KEY_OF_USE_BIAS: True,
    },
    # attrs12
    {
        _KEY_OF_TRANSPOSE_BS: False,
        _KEY_OF_LAYERNORM_TYPE: "rmsnorm",
        _KEY_OF_NUM_GQA_GROUPS: 1,
        _KEY_OF_ENABLE_ROPE: True,
        _KEY_OF_ROPE_GROUP_METHOD: "consecutive",
        _KEY_OF_FLOAT32_ATTENTION_LOGITS: True,
    },
    # attrs13
    {
        _KEY_OF_TRANSPOSE_BS: True,
        _KEY_OF_ENABLE_ROPE: True,
        _KEY_OF_ROPE_GROUP_METHOD: "consecutive",
        _KEY_OF_USE_BIAS: True,
    },
    # attrs14
    {
        _KEY_OF_TRANSPOSE_BS: False,
        _KEY_OF_LAYERNORM_TYPE: "layernorm",
        _KEY_OF_NUM_GQA_GROUPS: 2,
        _KEY_OF_ENABLE_ROPE: True,
        _KEY_OF_ROPE_GROUP_METHOD: "alternate",
        _KEY_OF_USE_BIAS: True,
        _KEY_OF_FLOAT32_ATTENTION_LOGITS: True,
    },
    # attrs15
    {
        _KEY_OF_TRANSPOSE_BS: True,
        _KEY_OF_LAYERNORM_TYPE: "rmsnorm",
        _KEY_OF_ENABLE_ROPE: True,
        _KEY_OF_ROPE_GROUP_METHOD: "alternate",
        _KEY_OF_USE_BIAS: True,
    },
    # attrs16
    {
        _KEY_OF_HIDDEN_DROPOUT: 0.3,
        _KEY_OF_HIDDEN_DROPOUT_DIMS: (0,),
        _KEY_OF_INTERMEDIATE_DROPOUT: 0.5,
        _KEY_OF_INTERMEDIATE_DROPOUT_DIMS: (1,),
    },
    # attrs17
    {
        _KEY_OF_SELF_ATTN_MASK_TYPE: "padding",
        _KEY_OF_USE_BIAS: True,
    },
    # attrs18
    {
        _KEY_OF_RELATIVE_EMBEDDING: False,
        _KEY_OF_SELF_ATTN_BIAS_TYPE: "no_bias",
    },
    # attrs19
    {
        _KEY_OF_ATTENTION_DROPOUT: 0.3,
    },
    # attrs20
    {
        _KEY_OF_MLP_ACTIVATIONS: ("relu", "relu"),
    },
    # attrs21
    {
        _KEY_OF_TRANSPOSE_BS: False,
        _KEY_OF_RELATIVE_EMBEDDING: False,
        _KEY_OF_SELF_ATTN_MASK_TYPE: "causal",
        _KEY_OF_WINDOW_SIZE: (64, 0),  # Left window size must be < DATA_SHAPE seqlen
        _KEY_OF_FLOAT32_ATTENTION_LOGITS: True,
    },
    # attrs22
    {
        _KEY_OF_TRANSPOSE_BS: False,
        _KEY_OF_RELATIVE_EMBEDDING: False,
        _KEY_OF_SELF_ATTN_MASK_TYPE: "causal",
        _KEY_OF_WINDOW_SIZE: None,
        _KEY_OF_FLOAT32_ATTENTION_LOGITS: True,
    },
    # attrs23
    {
        _KEY_OF_TRANSPOSE_BS: False,
        _KEY_OF_RELATIVE_EMBEDDING: False,
        _KEY_OF_SELF_ATTN_MASK_TYPE: "causal",
        _KEY_OF_FLOAT32_ATTENTION_LOGITS: True,
    },
    # attrs24
    {
        _KEY_OF_TRANSPOSE_BS: False,
        _KEY_OF_RELATIVE_EMBEDDING: False,
        _KEY_OF_SELF_ATTN_MASK_TYPE: "no_mask",
    },
    # attrs25
    {
        _KEY_OF_TRANSPOSE_BS: False,
        _KEY_OF_RELATIVE_EMBEDDING: False,
        _KEY_OF_SELF_ATTN_MASK_TYPE: "no_mask",
        _KEY_OF_WINDOW_SIZE: (2, 2),
    },
    # attrs26
    {
        _KEY_OF_TRANSPOSE_BS: False,
        _KEY_OF_RELATIVE_EMBEDDING: False,
        _KEY_OF_SELF_ATTN_MASK_TYPE: "padding",
        _KEY_OF_WINDOW_SIZE: (2, 2),
    },
    # attrs27
    {
        _KEY_OF_TRANSPOSE_BS: False,
        _KEY_OF_RELATIVE_EMBEDDING: False,
        _KEY_OF_SELF_ATTN_MASK_TYPE: "padding",
        _KEY_OF_WINDOW_SIZE: None,
    },
    # attrs28
    {
        _KEY_OF_TRANSPOSE_BS: False,
        _KEY_OF_RELATIVE_EMBEDDING: False,
        _KEY_OF_WINDOW_SIZE: (2, 2),
    },
]

ATTRS = [{**BASE_ATTRS, **attr} for attr in ATTRS]


class BaseRunner:
    """Base runner to define forward and backward tests"""

    layer_type: TransformerLayerType = None
    reference_layer: flax.linen.Module = None
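    # Maps TransformerLayer parameter paths to the corresponding reference-layer
    # paths; consumed by sync_params_values() when copying weights between models.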
    transformations: Dict[str, str] = None

    def __init__(self, attrs):
        self.attrs = attrs
        self._generate_test_rngs()
        # Disable fused attention when attention dropout is used because the
        # dropout implementations differ between the fused and unfused paths.
        if attrs.get(_KEY_OF_ATTENTION_DROPOUT, False) and os.getenv("NVTE_FUSED_ATTN"):
            os.environ["NVTE_FUSED_ATTN"] = "0"

    def _generate_test_rngs(self):
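        """Create RNG streams: params + dropout for init(), dropout for apply()."""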
        root_rng = jax.random.PRNGKey(0)
        params_rng, init_dropout_rng, apply_dropout_rng = jax.random.split(root_rng, 3)
        self.init_rng = {"params": params_rng, "dropout": init_dropout_rng}
        self.apply_rng = {"dropout": apply_dropout_rng}

    def _generate_layer(self, layer_cls, diff_inputs, no_diff_inputs):
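        """Initialize the layer and split its variables into params and other collections."""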
        layer = layer_cls()
        variables = layer.init(self.init_rng, *diff_inputs, *no_diff_inputs)
        others, params = flax.core.pop(variables, "params")
        del variables
        return layer, params, others

    def _loss_fn(self, diff_xs, no_diff_xs, params, others, model):
        variables = {"params": params, **others}
        output = model.apply(variables, *diff_xs, *no_diff_xs, rngs=self.apply_rng)
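        # Reduce to a scalar loss; the mean is computed in float32 and cast back to the output dtype.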
        return jnp.mean(output, dtype=jnp.float32).astype(output.dtype)

    def _sync_params(self, ref, target):
        """Copy the reference params to target"""
        target = sync_params_values(target, ref, self.transformations)
        return ref, target

    def test_forward(
        self,
        data_shape: Tuple[int],
        dtype: jnp.dtype,
        rtol: Optional[float] = None,
        atol: Optional[float] = None,
    ) -> None:
        """Test only the forward"""
        inputs, (ref_masks, test_masks) = self.generate_inputs(data_shape, dtype)

        ref_layer_cls = partial(self.reference_layer, **self.attrs)
        layer_cls = partial(TransformerLayer, layer_type=self.layer_type, **self.attrs)

        ref_layer, ref_params, ref_others = self._generate_layer(ref_layer_cls, inputs, ref_masks)
        test_layer, test_params, test_others = self._generate_layer(layer_cls, inputs, test_masks)
        ref_params, test_params = self._sync_params(ref_params, test_params)

        ref_out = self._loss_fn(inputs, ref_masks, ref_params, ref_others, ref_layer)
        test_out = self._loss_fn(inputs, test_masks, test_params, test_others, test_layer)

        tols = dtype_tols(dtype, rtol=rtol, atol=atol)
        assert_allclose(ref_out, test_out, **tols)

    def test_backward(
        self,
        data_shape: Tuple[int],
        dtype: jnp.dtype,
        rtol: Optional[float] = None,
        atol: Optional[float] = None,
    ) -> None:
        """Test forward and backward through value_and_grad()"""
        inputs, (ref_masks, test_masks) = self.generate_inputs(data_shape, dtype)

        ref_layer_cls = partial(self.reference_layer, **self.attrs)
        layer_cls = partial(TransformerLayer, layer_type=self.layer_type, **self.attrs)

        ref_layer, ref_params, ref_others = self._generate_layer(ref_layer_cls, inputs, ref_masks)
        test_layer, test_params, test_others = self._generate_layer(layer_cls, inputs, test_masks)

        ref_params, test_params = self._sync_params(ref_params, test_params)

        if QuantizeConfig.is_fp8_enabled():
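            # Warm-up iterations so the FP8 quantization metadata (e.g. the amax
            # history under delayed scaling) is updated before the comparison run.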
            for _ in range(4):
                _, updated_state = jax.value_and_grad(self._loss_fn, argnums=(3,), has_aux=False)(
                    inputs,
                    test_masks,
                    test_params,
                    test_others,
                    test_layer,
                )
                if QuantizeConfig.SCALING_MODE == ScalingMode.DELAYED_TENSOR_SCALING:
                    _, updated_quantize_meta = flax.core.pop(
                        updated_state[0], QuantizeConfig.COLLECTION_NAME
                    )
                    test_others = update_collections(
                        {QuantizeConfig.COLLECTION_NAME: updated_quantize_meta}, test_others
                    )
                    del updated_quantize_meta
                del updated_state

        grad_fn = jax.value_and_grad(self._loss_fn, argnums=(0, 2), has_aux=False)

        ref_out, (ref_dgrads, ref_wgrads) = grad_fn(
            inputs, ref_masks, ref_params, ref_others, ref_layer
        )
        test_out, (test_dgrads, test_wgrads) = grad_fn(
            inputs, test_masks, test_params, test_others, test_layer
        )

        tols = dtype_tols(dtype, rtol=rtol, atol=atol)
        assert_allclose(ref_out, test_out, **tols)
        assert_tree_like_allclose(ref_dgrads, test_dgrads, **tols)

        _, restructured_ref_wgrads = self._sync_params(ref_wgrads, test_wgrads)
        assert_tree_like_allclose(restructured_ref_wgrads, test_wgrads, **tols)


class EncoderRunner(BaseRunner):
    """Encoder runner implementations"""

    layer_type = TransformerLayerType.ENCODER
    reference_layer = RefEncoderLayer
    transformations = {
        "attention/qkv/scale": "pre_attention_layer_norm/scale",
        "attention/qkv/ln_bias": "pre_attention_layer_norm/ln_bias",
        "attention/query/scale": "pre_attention_layer_norm/scale",
        "attention/query/ln_bias": "pre_attention_layer_norm/ln_bias",
        "mlp/wi_kernel": "mlp/wi/kernel",
        "mlp/wi_bias": "mlp/wi/bias",
        "mlp/wo_kernel": "mlp/wo/kernel",
        "mlp/wo_bias": "mlp/wo/bias",
        "mlp/scale": "pre_mlp_layer_norm/scale",
        "mlp/ln_bias": "pre_mlp_layer_norm/ln_bias",
    }

    def generate_inputs(self, data_shape, dtype):
        """
        Return inputs, (ref_masks, test_masks)
        """
        transpose_batch_sequence = self.attrs[_KEY_OF_TRANSPOSE_BS]
        batch, seqlen = data_shape[:2]
        if transpose_batch_sequence:
            data_shape = (data_shape[1], data_shape[0], *data_shape[2:])

        data_rng = jax.random.PRNGKey(2024)
        inputs = (jax.random.normal(data_rng, data_shape, dtype),)

        mask_shape = (batch, 1, seqlen, seqlen)
        padded_mask = jnp.zeros(mask_shape, dtype=jnp.uint8)
        causal_mask = jnp.triu(jnp.ones(mask_shape, dtype=jnp.uint8), k=1)
        if self.attrs[_KEY_OF_SELF_ATTN_MASK_TYPE] in ["causal", "padding_causal"]:
            mask = causal_mask
        else:
            mask = padded_mask
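        # The reference layer consumes the inverted mask (1 = attend), while
        # TransformerLayer expects 1 = masked out.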
        ref_masks = (1 - mask,)
        test_masks = (None, mask)  # The second arg of TransformerLayer is the encoded tokens, unused by the encoder.

        return inputs, (ref_masks, test_masks)


class DecoderRunner(BaseRunner):
    """
    Decoder runner implementations
    """

    layer_type = TransformerLayerType.DECODER
    reference_layer = RefDecoderLayer
    transformations = {
        "encoder_decoder_attention/qkv/scale": "pre_cross_attention_layer_norm/scale",
        "encoder_decoder_attention/qkv/ln_bias": "pre_cross_attention_layer_norm/ln_bias",
        "encoder_decoder_attention/query/scale": "pre_cross_attention_layer_norm/scale",
        "encoder_decoder_attention/query/ln_bias": "pre_cross_attention_layer_norm/ln_bias",
        "self_attention/qkv/scale": "pre_self_attention_layer_norm/scale",
        "self_attention/qkv/ln_bias": "pre_self_attention_layer_norm/ln_bias",
        "self_attention/query/scale": "pre_self_attention_layer_norm/scale",
        "self_attention/query/ln_bias": "pre_self_attention_layer_norm/ln_bias",
        "mlp/wi_kernel": "mlp/wi/kernel",
        "mlp/wi_bias": "mlp/wi/bias",
        "mlp/wo_kernel": "mlp/wo/kernel",
        "mlp/wo_bias": "mlp/wo/bias",
        "mlp/scale": "pre_mlp_layer_norm/scale",
        "mlp/ln_bias": "pre_mlp_layer_norm/ln_bias",
    }

    def generate_inputs(self, data_shape, dtype):
        """
        Return inputs, (ref_masks, test_masks)
        """
        transpose_batch_sequence = self.attrs[_KEY_OF_TRANSPOSE_BS]
        batch, seqlen = data_shape[:2]
        if transpose_batch_sequence:
            data_shape = (data_shape[1], data_shape[0], *data_shape[2:])

        data_rng = jax.random.PRNGKey(0)
        data_rng_0, data_rng_1 = jax.random.split(data_rng, 2)
        inputs = (
            jax.random.normal(data_rng_0, data_shape, dtype),
            jax.random.normal(data_rng_1, data_shape, dtype),
        )

        padded_mask = jnp.zeros((batch, 1, seqlen, seqlen), dtype=jnp.uint8)
        causal_mask = jnp.triu(jnp.ones((batch, 1, seqlen, seqlen), dtype=jnp.uint8), k=1)
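        # Self-attention uses the causal mask for causal mask types; the second
        # (cross-attention) mask is always the padding mask.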
        if self.attrs[_KEY_OF_SELF_ATTN_MASK_TYPE] in ["causal", "padding_causal"]:
            self_mask = causal_mask
        else:
            self_mask = padded_mask

        ref_masks = (1 - self_mask, 1 - padded_mask)
        test_masks = (self_mask, padded_mask)

        return inputs, (ref_masks, test_masks)


@pytest.mark.parametrize("data_shape", DATA_SHAPE)
@pytest.mark.parametrize("dtype", DTYPE)
@pytest.mark.parametrize("attrs", ATTRS)
class BaseTester:
    """
    Pytest interface to invoke the runner
    """

    runner = BaseRunner

    def test_forward(self, data_shape, dtype, attrs):
        """Test normal datatype forward"""
        QuantizeConfig.finalize()  # Ensure FP8 disabled.
        self.runner(attrs).test_forward(data_shape, dtype)

    def test_backward(self, data_shape, dtype, attrs):
        """Test normal datatype backward"""
        QuantizeConfig.finalize()  # Ensure FP8 disabled.
        self.runner(attrs).test_backward(data_shape, dtype)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize("fp8_recipe", QUANTIZE_RECIPES)
    def test_forward_with_fp8(self, data_shape, dtype, attrs, fp8_recipe):
        """Test forward with fp8 enabled"""
        QuantizeConfig.initialize(fp8_recipe=fp8_recipe)
        self.runner(attrs).test_forward(data_shape, dtype, rtol=1e-4, atol=1e-3)
        QuantizeConfig.finalize()

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize("fp8_recipe", QUANTIZE_RECIPES)
    def test_backward_with_fp8(self, data_shape, dtype, attrs, fp8_recipe):
        """Test backward with fp8 enabled"""
        QuantizeConfig.initialize(fp8_recipe=fp8_recipe)
        self.runner(attrs).test_backward(data_shape, dtype, rtol=1e-4, atol=1e-3)
        QuantizeConfig.finalize()


class TestEncoderLayer(BaseTester):
    """
    Test transformer_engine.jax.flax.TransformerLayer(layer_type=Encoder)
    """

    runner = EncoderRunner


class TestDecoderLayer(BaseTester):
    """
    Test transformer_engine.jax.flax.TransformerLayer(layer_type=Decoder)
    """

    runner = DecoderRunner