# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Test transformer_engine.jax.flax.TransformerLayer"""
import os
from functools import partial
from typing import Dict, Tuple, Optional

import flax
import jax
import jax.numpy as jnp
import pytest

from utils import (
    assert_allclose,
    assert_tree_like_allclose,
    dtype_tols,
    sync_params_values,
)
from utils import DecoderLayer as RefDecoderLayer
from utils import EncoderLayer as RefEncoderLayer

from transformer_engine.common import recipe
from transformer_engine.jax.flax import TransformerLayer, TransformerLayerType
from transformer_engine.jax.quantize import (
    get_quantize_config,
    ScalingMode,
    is_fp8_available,
    update_collections,
    TensorSource,
    autocast,
)
from transformer_engine.jax.sharding import MeshResource


@pytest.fixture(autouse=True, scope="function")
def enable_fused_attn():
    """Enable fused attention"""
    os.environ["NVTE_FUSED_ATTN"] = "1"
    yield
    del os.environ["NVTE_FUSED_ATTN"]


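# Detect FP8 / MXFP8 hardware support so the FP8 tests can be skipped on devices without it.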
is_fp8_supported, reason = is_fp8_available()
is_mxfp8_supported, reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING)

# Find supported scaling modes
QUANTIZE_RECIPES = []
if is_fp8_supported:
    QUANTIZE_RECIPES.append(pytest.param(recipe.DelayedScaling(), id="DelayedScaling"))
if is_mxfp8_supported:
    QUANTIZE_RECIPES.append(pytest.param(recipe.MXFP8BlockScaling(), id="MXFP8BlockScaling"))


DATA_SHAPE = [  # (batch, seqlen, emb_dim)
    pytest.param((32, 128, 1024), id="32-128-1024"),
]
DTYPE = [jnp.bfloat16]

_KEY_OF_RESIDUAL_POST_LAYERNORM = "apply_residual_connection_post_layernorm"
_KEY_OF_OUTPUT_LAYERNORM = "output_layernorm"
_KEY_OF_DROP_PATH = "drop_path"
_KEY_OF_FUSE_QKV_PARAMS = "fuse_qkv_params"
_KEY_OF_HIDDEN_DROPOUT = "hidden_dropout"
_KEY_OF_ATTENTION_DROPOUT = "attention_dropout"
_KEY_OF_INTERMEDIATE_DROPOUT = "intermediate_dropout"
_KEY_OF_HIDDEN_DROPOUT_DIMS = "hidden_dropout_dims"
_KEY_OF_INTERMEDIATE_DROPOUT_DIMS = "intermediate_dropout_dims"
_KEY_OF_MLP_ACTIVATIONS = "mlp_activations"
_KEY_OF_LAYERNORM_TYPE = "layernorm_type"
_KEY_OF_LAYERNORM_EPS = "layernorm_epsilon"
_KEY_OF_ZERO_CENTERED_GAMMA = "zero_centered_gamma"
_KEY_OF_TRANSPOSE_BS = "transpose_batch_sequence"
_KEY_OF_SCALE_ATTN_LOGITS = "scale_attn_logits"
_KEY_OF_NUM_HEADS = "num_attention_heads"
_KEY_OF_NUM_GQA_GROUPS = "num_gqa_groups"
_KEY_OF_ENABLE_ROPE = "enable_rotary_pos_emb"
_KEY_OF_ROPE_GROUP_METHOD = "rotary_pos_emb_group_method"
_KEY_OF_SELF_ATTN_BIAS_TYPE = "self_attn_bias_type"
_KEY_OF_SELF_ATTN_MASK_TYPE = "self_attn_mask_type"
_KEY_OF_FLOAT32_ATTENTION_LOGITS = "float32_attention_logits"
_KEY_OF_USE_BIAS = "use_bias"
_KEY_OF_RELATIVE_EMBEDDING = "enable_relative_embedding"
_KEY_OF_WINDOW_SIZE = "window_size"

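# Default TransformerLayer kwargs shared by every test case; each entry in ATTRS overrides a subset of them.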
BASE_ATTRS = {
    _KEY_OF_TRANSPOSE_BS: True,
    _KEY_OF_NUM_HEADS: 8,
    _KEY_OF_HIDDEN_DROPOUT: 0,
    _KEY_OF_ATTENTION_DROPOUT: 0.0,
    _KEY_OF_INTERMEDIATE_DROPOUT: 0,
    _KEY_OF_SELF_ATTN_MASK_TYPE: "padding_causal",
    _KEY_OF_LAYERNORM_TYPE: "layernorm",
    _KEY_OF_WINDOW_SIZE: (-1, -1),
}

ATTRS = [
    # attrs0
    {},
    # attrs1
    {
        _KEY_OF_LAYERNORM_TYPE: "rmsnorm",
    },
    # attrs2
    {
        _KEY_OF_ZERO_CENTERED_GAMMA: True,
        _KEY_OF_LAYERNORM_EPS: 1e-2,
    },
    # attrs3
    {_KEY_OF_LAYERNORM_TYPE: "rmsnorm", _KEY_OF_RESIDUAL_POST_LAYERNORM: True},
    # attrs4
    {_KEY_OF_LAYERNORM_TYPE: "rmsnorm", _KEY_OF_OUTPUT_LAYERNORM: True},
    # attrs5
    {
        _KEY_OF_LAYERNORM_TYPE: "rmsnorm",
        _KEY_OF_RESIDUAL_POST_LAYERNORM: True,
        _KEY_OF_OUTPUT_LAYERNORM: True,
    },
    # attrs6
    {_KEY_OF_LAYERNORM_TYPE: "rmsnorm", _KEY_OF_DROP_PATH: 0.1},
    # attrs7
    {_KEY_OF_LAYERNORM_TYPE: "rmsnorm", _KEY_OF_FUSE_QKV_PARAMS: False},
    # attrs8
    {
        _KEY_OF_LAYERNORM_TYPE: "rmsnorm",
        _KEY_OF_MLP_ACTIVATIONS: ("gelu", "linear"),
    },
    # attrs9
    {
        _KEY_OF_SCALE_ATTN_LOGITS: True,
        _KEY_OF_LAYERNORM_TYPE: "rmsnorm",
        _KEY_OF_HIDDEN_DROPOUT: 0.8,
        _KEY_OF_INTERMEDIATE_DROPOUT: 0.5,
        _KEY_OF_MLP_ACTIVATIONS: ("gelu", "linear"),
        _KEY_OF_USE_BIAS: True,
    },
    # attrs10
    {
        _KEY_OF_TRANSPOSE_BS: False,
        _KEY_OF_SCALE_ATTN_LOGITS: True,
        _KEY_OF_LAYERNORM_TYPE: "rmsnorm",
        _KEY_OF_MLP_ACTIVATIONS: ("gelu", "linear"),
    },
    # attrs11
    {
        _KEY_OF_NUM_HEADS: 8,
        _KEY_OF_NUM_GQA_GROUPS: 4,
        _KEY_OF_TRANSPOSE_BS: False,
        _KEY_OF_SCALE_ATTN_LOGITS: True,
        _KEY_OF_MLP_ACTIVATIONS: ("gelu",),
        _KEY_OF_USE_BIAS: True,
    },
    # attrs12
    {
        _KEY_OF_TRANSPOSE_BS: False,
        _KEY_OF_LAYERNORM_TYPE: "rmsnorm",
        _KEY_OF_NUM_GQA_GROUPS: 1,
        _KEY_OF_ENABLE_ROPE: True,
        _KEY_OF_ROPE_GROUP_METHOD: "consecutive",
        _KEY_OF_FLOAT32_ATTENTION_LOGITS: True,
    },
    # attrs13
    {
        _KEY_OF_TRANSPOSE_BS: True,
        _KEY_OF_ENABLE_ROPE: True,
        _KEY_OF_ROPE_GROUP_METHOD: "consecutive",
        _KEY_OF_USE_BIAS: True,
    },
    # attrs14
    {
        _KEY_OF_TRANSPOSE_BS: False,
        _KEY_OF_LAYERNORM_TYPE: "layernorm",
        _KEY_OF_NUM_GQA_GROUPS: 2,
        _KEY_OF_ENABLE_ROPE: True,
        _KEY_OF_ROPE_GROUP_METHOD: "alternate",
        _KEY_OF_USE_BIAS: True,
        _KEY_OF_FLOAT32_ATTENTION_LOGITS: True,
    },
    # attrs15
    {
        _KEY_OF_TRANSPOSE_BS: True,
        _KEY_OF_LAYERNORM_TYPE: "rmsnorm",
        _KEY_OF_ENABLE_ROPE: True,
        _KEY_OF_ROPE_GROUP_METHOD: "alternate",
        _KEY_OF_USE_BIAS: True,
    },
    # attrs16
    {
        _KEY_OF_HIDDEN_DROPOUT: 0.3,
        _KEY_OF_HIDDEN_DROPOUT_DIMS: (0,),
        _KEY_OF_INTERMEDIATE_DROPOUT: 0.5,
        _KEY_OF_INTERMEDIATE_DROPOUT_DIMS: (1,),
    },
    # attrs17
    {
        _KEY_OF_SELF_ATTN_MASK_TYPE: "padding",
        _KEY_OF_USE_BIAS: True,
    },
    # attrs18
    {
        _KEY_OF_RELATIVE_EMBEDDING: False,
        _KEY_OF_SELF_ATTN_BIAS_TYPE: "no_bias",
    },
    # attrs19
    {
        _KEY_OF_ATTENTION_DROPOUT: 0.3,
    },
    # attrs20
    {
        _KEY_OF_MLP_ACTIVATIONS: ("relu", "relu"),
    },
    # attrs21
    {
        _KEY_OF_TRANSPOSE_BS: False,
        _KEY_OF_RELATIVE_EMBEDDING: False,
        _KEY_OF_SELF_ATTN_MASK_TYPE: "causal",
        _KEY_OF_WINDOW_SIZE: (64, 0),  # Left window size must be smaller than the DATA_SHAPE seqlen
        _KEY_OF_FLOAT32_ATTENTION_LOGITS: True,
    },
    # attrs22
    {
        _KEY_OF_TRANSPOSE_BS: False,
        _KEY_OF_RELATIVE_EMBEDDING: False,
        _KEY_OF_SELF_ATTN_MASK_TYPE: "causal",
        _KEY_OF_WINDOW_SIZE: None,
        _KEY_OF_FLOAT32_ATTENTION_LOGITS: True,
    },
    # attrs23
    {
        _KEY_OF_TRANSPOSE_BS: False,
        _KEY_OF_RELATIVE_EMBEDDING: False,
        _KEY_OF_SELF_ATTN_MASK_TYPE: "causal",
        _KEY_OF_FLOAT32_ATTENTION_LOGITS: True,
    },
    # attrs24
    {
        _KEY_OF_TRANSPOSE_BS: False,
        _KEY_OF_RELATIVE_EMBEDDING: False,
        _KEY_OF_SELF_ATTN_MASK_TYPE: "no_mask",
    },
    # attrs25
    {
        _KEY_OF_TRANSPOSE_BS: False,
        _KEY_OF_RELATIVE_EMBEDDING: False,
        _KEY_OF_SELF_ATTN_MASK_TYPE: "no_mask",
        _KEY_OF_WINDOW_SIZE: (2, 2),
    },
    # attrs26
    {
        _KEY_OF_TRANSPOSE_BS: False,
        _KEY_OF_RELATIVE_EMBEDDING: False,
        _KEY_OF_SELF_ATTN_MASK_TYPE: "padding",
        _KEY_OF_WINDOW_SIZE: (2, 2),
    },
    # attrs27
    {
        _KEY_OF_TRANSPOSE_BS: False,
        _KEY_OF_RELATIVE_EMBEDDING: False,
        _KEY_OF_SELF_ATTN_MASK_TYPE: "padding",
        _KEY_OF_WINDOW_SIZE: None,
    },
    # attrs28
    {
        _KEY_OF_TRANSPOSE_BS: False,
        _KEY_OF_RELATIVE_EMBEDDING: False,
        _KEY_OF_WINDOW_SIZE: (2, 2),
    },
    # attrs29
    {
        _KEY_OF_RELATIVE_EMBEDDING: True,
        _KEY_OF_SELF_ATTN_BIAS_TYPE: "pre_scale_bias",
    },
    # attrs30
    {
        _KEY_OF_RELATIVE_EMBEDDING: True,
        _KEY_OF_SELF_ATTN_BIAS_TYPE: "post_scale_bias",
    },
]

ATTRS = [{**BASE_ATTRS, **attr} for attr in ATTRS]


class BaseRunner:
    """Base runner to define forward and backward tests"""

    layer_type: TransformerLayerType = None
    reference_layer: flax.linen.Module = None
    transformations: Dict[str, str] = None

    def __init__(self, attrs):
        self.attrs = attrs
        self._generate_test_rngs()
        # Fused and unfused attention use different dropout implementations, so disable fused attention when attention dropout is enabled.
        if attrs.get(_KEY_OF_ATTENTION_DROPOUT, False) and os.getenv("NVTE_FUSED_ATTN"):
            os.environ["NVTE_FUSED_ATTN"] = "0"

    def _generate_test_rngs(self):
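        # Split one root key into independent RNGs for parameter init and for dropout at init/apply time.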
        root_rng = jax.random.PRNGKey(0)
        params_rng, init_dropout_rng, apply_dropout_rng = jax.random.split(root_rng, 3)
        self.init_rng = {"params": params_rng, "dropout": init_dropout_rng}
        self.apply_rng = {"dropout": apply_dropout_rng}

    def _generate_layer(self, layer_cls, diff_inputs, no_diff_inputs):
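        # Initialize the layer and separate the trainable params from the other variable collections.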
        layer = layer_cls()
        variables = layer.init(self.init_rng, *diff_inputs, *no_diff_inputs)
        others, params = flax.core.pop(variables, "params")
        del variables
        return layer, params, others

    def _loss_fn(self, diff_xs, no_diff_xs, params, others, model):
        variables = {"params": params, **others}
        output = model.apply(variables, *diff_xs, *no_diff_xs, rngs=self.apply_rng)
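        # Reduce to a scalar loss; accumulate the mean in fp32 for numerical stability, then cast back to the output dtype.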
        return jnp.mean(output, dtype=jnp.float32).astype(output.dtype)

    def _sync_params(self, ref, target):
        """Copy the reference params to target"""
        target = sync_params_values(target, ref, self.transformations)
        return ref, target

    def test_forward(
        self,
        data_shape: Tuple[int],
        dtype: jnp.dtype,
        rtol: Optional[float] = None,
        atol: Optional[float] = None,
    ) -> None:
        """Test only the forward"""
        inputs, (ref_masks, test_masks) = self.generate_inputs(data_shape, dtype)

        ref_layer_cls = partial(self.reference_layer, **self.attrs)
        layer_cls = partial(TransformerLayer, layer_type=self.layer_type, **self.attrs)

        ref_layer, ref_params, ref_others = self._generate_layer(ref_layer_cls, inputs, ref_masks)
        test_layer, test_params, test_others = self._generate_layer(layer_cls, inputs, test_masks)
        ref_params, test_params = self._sync_params(ref_params, test_params)

        ref_out = self._loss_fn(inputs, ref_masks, ref_params, ref_others, ref_layer)
        test_out = self._loss_fn(inputs, test_masks, test_params, test_others, test_layer)

        tols = dtype_tols(dtype, rtol=rtol, atol=atol)
        assert_allclose(ref_out, test_out, **tols)

    def test_backward(
        self,
        data_shape: Tuple[int],
        dtype: jnp.dtype,
        rtol: Optional[float] = None,
        atol: Optional[float] = None,
    ) -> None:
        """Test forward and backward through value_and_grad()"""
        inputs, (ref_masks, test_masks) = self.generate_inputs(data_shape, dtype)

        ref_layer_cls = partial(self.reference_layer, **self.attrs)
        layer_cls = partial(TransformerLayer, layer_type=self.layer_type, **self.attrs)

        ref_layer, ref_params, ref_others = self._generate_layer(ref_layer_cls, inputs, ref_masks)
        test_layer, test_params, test_others = self._generate_layer(layer_cls, inputs, test_masks)

        ref_params, test_params = self._sync_params(ref_params, test_params)

        if get_quantize_config().is_fp8_enabled():
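            # Run a few iterations first so the delayed-scaling FP8 metadata (amax history/scales) settles before gradients are compared.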
            for _ in range(4):
                _, updated_state = jax.value_and_grad(self._loss_fn, argnums=(3,), has_aux=False)(
                    inputs,
                    test_masks,
                    test_params,
                    test_others,
                    test_layer,
                )
                if (
                    get_quantize_config().get_scaling_mode(TensorSource.X)
                    == ScalingMode.DELAYED_TENSOR_SCALING
                ):
                    _, updated_quantize_meta = flax.core.pop(
                        updated_state[0], get_quantize_config().COLLECTION_NAME
                    )
                    test_others = update_collections(
                        {get_quantize_config().COLLECTION_NAME: updated_quantize_meta}, test_others
                    )
                    del updated_quantize_meta
                del updated_state

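        # Differentiate w.r.t. the inputs (argnums=0) and the params (argnums=2) to check both dgrads and wgrads.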
        grad_fn = jax.value_and_grad(self._loss_fn, argnums=(0, 2), has_aux=False)

        ref_out, (ref_dgrads, ref_wgrads) = grad_fn(
            inputs, ref_masks, ref_params, ref_others, ref_layer
        )
        test_out, (test_dgrads, test_wgrads) = grad_fn(
            inputs, test_masks, test_params, test_others, test_layer
        )

        tols = dtype_tols(dtype, rtol=rtol, atol=atol)
        assert_allclose(ref_out, test_out, **tols)
        assert_tree_like_allclose(ref_dgrads, test_dgrads, **tols)

        _, restructured_ref_wgrads = self._sync_params(ref_wgrads, test_wgrads)
        assert_tree_like_allclose(restructured_ref_wgrads, test_wgrads, **tols)


class EncoderRunner(BaseRunner):
    """Encoder runner implementations"""

    layer_type = TransformerLayerType.ENCODER
    reference_layer = RefEncoderLayer
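    # Maps TE parameter paths to the reference layer's paths; used by sync_params_values to copy weights.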
    transformations = {
        "attention/qkv/scale": "pre_attention_layer_norm/scale",
        "attention/qkv/ln_bias": "pre_attention_layer_norm/ln_bias",
        "attention/query/scale": "pre_attention_layer_norm/scale",
        "attention/query/ln_bias": "pre_attention_layer_norm/ln_bias",
        "mlp/wi_kernel": "mlp/wi/kernel",
        "mlp/wi_bias": "mlp/wi/bias",
        "mlp/wo_kernel": "mlp/wo/kernel",
        "mlp/wo_bias": "mlp/wo/bias",
        "mlp/scale": "pre_mlp_layer_norm/scale",
        "mlp/ln_bias": "pre_mlp_layer_norm/ln_bias",
    }

    def generate_inputs(self, data_shape, dtype):
        """
        Return inputs, (ref_masks, test_masks)
        """
        transpose_batch_sequence = self.attrs[_KEY_OF_TRANSPOSE_BS]
        batch, seqlen = data_shape[:2]
        if transpose_batch_sequence:
            data_shape = (data_shape[1], data_shape[0], *data_shape[2:])

        data_rng = jax.random.PRNGKey(2024)
        inputs = (jax.random.normal(data_rng, data_shape, dtype),)

        mask_shape = (batch, 1, seqlen, seqlen)
        padded_mask = jnp.zeros(mask_shape, dtype=jnp.uint8)
        causal_mask = jnp.triu(jnp.ones(mask_shape, dtype=jnp.uint8), k=1)
        if self.attrs[_KEY_OF_SELF_ATTN_MASK_TYPE] in ["causal", "padding_causal"]:
            mask = causal_mask
        else:
            mask = padded_mask
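        # The reference layer expects 1 for positions that may be attended to, while the TE layer masks out positions marked with 1, hence the inversion below.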
        ref_masks = (1 - mask,)
        test_masks = (None, mask)  # TransformerLayer's second arg is the encoded tokens (None for an encoder); the mask is the attention mask.

        return inputs, (ref_masks, test_masks)


class DecoderRunner(BaseRunner):
    """
    Decoder runner implementations
    """

    layer_type = TransformerLayerType.DECODER
    reference_layer = RefDecoderLayer
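    # Maps TE parameter paths to the reference layer's paths; used by sync_params_values to copy weights.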
    transformations = {
        "encoder_decoder_attention/qkv/scale": "pre_cross_attention_layer_norm/scale",
        "encoder_decoder_attention/qkv/ln_bias": "pre_cross_attention_layer_norm/ln_bias",
        "encoder_decoder_attention/query/scale": "pre_cross_attention_layer_norm/scale",
        "encoder_decoder_attention/query/ln_bias": "pre_cross_attention_layer_norm/ln_bias",
        "self_attention/qkv/scale": "pre_self_attention_layer_norm/scale",
        "self_attention/qkv/ln_bias": "pre_self_attention_layer_norm/ln_bias",
        "self_attention/query/scale": "pre_self_attention_layer_norm/scale",
        "self_attention/query/ln_bias": "pre_self_attention_layer_norm/ln_bias",
        "mlp/wi_kernel": "mlp/wi/kernel",
        "mlp/wi_bias": "mlp/wi/bias",
        "mlp/wo_kernel": "mlp/wo/kernel",
        "mlp/wo_bias": "mlp/wo/bias",
        "mlp/scale": "pre_mlp_layer_norm/scale",
        "mlp/ln_bias": "pre_mlp_layer_norm/ln_bias",
    }

    def generate_inputs(self, data_shape, dtype):
        """
        Return inputs, (ref_masks, test_masks)
        """
        transpose_batch_sequence = self.attrs[_KEY_OF_TRANSPOSE_BS]
        batch, seqlen = data_shape[:2]
        if transpose_batch_sequence:
            data_shape = (data_shape[1], data_shape[0], *data_shape[2:])

        data_rng = jax.random.PRNGKey(0)
        data_rng_0, data_rng_1 = jax.random.split(data_rng, 2)
        inputs = (
            jax.random.normal(data_rng_0, data_shape, dtype),
            jax.random.normal(data_rng_1, data_shape, dtype),
        )

        padded_mask = jnp.zeros((batch, 1, seqlen, seqlen), dtype=jnp.uint8)
        causal_mask = jnp.triu(jnp.ones((batch, 1, seqlen, seqlen), dtype=jnp.uint8), k=1)
        if self.attrs[_KEY_OF_SELF_ATTN_MASK_TYPE] in ["causal", "padding_causal"]:
            self_mask = causal_mask
        else:
            self_mask = padded_mask

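        # Invert the masks for the reference layer, which uses the opposite mask convention from TE.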
        ref_masks = (1 - self_mask, 1 - padded_mask)
        test_masks = (self_mask, padded_mask)

        return inputs, (ref_masks, test_masks)


@pytest.mark.parametrize("data_shape", DATA_SHAPE)
@pytest.mark.parametrize("dtype", DTYPE)
@pytest.mark.parametrize("attrs", ATTRS)
class BaseTester:
    """
    Pytest interface to invoke the runner
    """

    runner = BaseRunner

    def test_forward(self, data_shape, dtype, attrs):
        """Test normal datatype forward"""
        # Ensure FP8 disabled.
        # Empty MeshResource is used as we are running on a single device
        with autocast(enabled=False, mesh_resource=MeshResource()):
            self.runner(attrs).test_forward(data_shape, dtype)

    def test_backward(self, data_shape, dtype, attrs):
        """Test normal datatype backward"""
        # Ensure FP8 disabled.
        # Empty MeshResource is used as we are running on a single device
        with autocast(enabled=False, mesh_resource=MeshResource()):
            self.runner(attrs).test_backward(data_shape, dtype)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize("fp8_recipe", QUANTIZE_RECIPES)
    def test_forward_with_fp8(self, data_shape, dtype, attrs, fp8_recipe):
        """Test forward with fp8 enabled"""
        # Empty MeshResource is used as we are running on a single device
        with autocast(enabled=True, recipe=fp8_recipe, mesh_resource=MeshResource()):
            self.runner(attrs).test_forward(data_shape, dtype, rtol=1e-4, atol=1e-3)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize("fp8_recipe", QUANTIZE_RECIPES)
    def test_backward_with_fp8(self, data_shape, dtype, attrs, fp8_recipe):
        """Test backward with fp8 enabled"""
        # Empty MeshResource is used as we are running on a single device
        with autocast(enabled=True, recipe=fp8_recipe, mesh_resource=MeshResource()):
            self.runner(attrs).test_backward(data_shape, dtype, rtol=1e-4, atol=1e-3)


class TestEncoderLayer(BaseTester):
    """
    Test transformer_engine.jax.flax.TransformerLayer(layer_type=Encoder)
    """

    runner = EncoderRunner


class TestDecoderLayer(BaseTester):
    """
    Test transformer_engine.jax.flax.TransformerLayer(layer_type=Decoder)
    """

    runner = DecoderRunner