# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Test transformer_engine.jax.flax.TransformerLayer"""
import os
from functools import partial
from typing import Dict, Tuple, Optional

import flax
import jax
import jax.numpy as jnp
import pytest

from utils import (
    assert_allclose,
    assert_tree_like_allclose,
    dtype_tols,
    sync_params_values,
)
from utils import DecoderLayer as RefDecoderLayer
from utils import EncoderLayer as RefEncoderLayer

from transformer_engine.common import recipe
from transformer_engine.jax.flax import TransformerLayer, TransformerLayerType
from transformer_engine.jax.quantize import (
    get_global_quantize_recipe,
    get_quantize_config_with_recipe,
    ScalingMode,
    is_fp8_available,
    update_collections,
    TensorSource,
    autocast,
)
from transformer_engine.jax.sharding import MeshResource
35

36

37
@pytest.fixture(autouse=True, scope="function")
def enable_fused_attn():
    """Enable fused attention (NVTE_FUSED_ATTN=1) for the duration of each test.

    The value the variable had before the test (or its absence) is restored on
    teardown. A value exported by the user is therefore not clobbered, and any
    change a test makes (e.g. forcing "0") does not leak into later tests.
    """
    previous = os.environ.get("NVTE_FUSED_ATTN")
    os.environ["NVTE_FUSED_ATTN"] = "1"
    yield
    if previous is None:
        # pop() rather than del: the test itself may have removed the variable.
        os.environ.pop("NVTE_FUSED_ATTN", None)
    else:
        os.environ["NVTE_FUSED_ATTN"] = previous


45
is_fp8_supported, reason = is_fp8_available()
# Keep the MXFP8 reason in its own name: ``reason`` is what the FP8 ``skipif``
# markers report, and it must describe why plain FP8 is unavailable.
is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING)

# Quantization recipes supported on this device; used to parametrize the FP8 tests.
QUANTIZE_RECIPES = []
if is_fp8_supported:
    QUANTIZE_RECIPES.append(pytest.param(recipe.DelayedScaling(), id="DelayedScaling"))
if is_mxfp8_supported:
    QUANTIZE_RECIPES.append(pytest.param(recipe.MXFP8BlockScaling(), id="MXFP8BlockScaling"))


56
57
# Input shapes as (batch, seqlen, emb_dim); each pytest id is the shape joined by "-".
DATA_SHAPE = [
    pytest.param(shape, id="-".join(str(dim) for dim in shape)) for shape in [(32, 128, 1024)]
]
# Input dtypes under test (also selects the comparison tolerances via dtype_tols).
DTYPE = [jnp.bfloat16]
60
61
62
63
64

# Keyword-argument names passed (via **attrs) to both TransformerLayer and the
# reference layers, grouped by concern.

# Normalization
_KEY_OF_LAYERNORM_TYPE = "layernorm_type"
_KEY_OF_LAYERNORM_EPS = "layernorm_epsilon"
_KEY_OF_ZERO_CENTERED_GAMMA = "zero_centered_gamma"
_KEY_OF_RESIDUAL_POST_LAYERNORM = "apply_residual_connection_post_layernorm"
_KEY_OF_OUTPUT_LAYERNORM = "output_layernorm"

# Dropout / regularization
_KEY_OF_HIDDEN_DROPOUT = "hidden_dropout"
_KEY_OF_ATTENTION_DROPOUT = "attention_dropout"
_KEY_OF_INTERMEDIATE_DROPOUT = "intermediate_dropout"
_KEY_OF_HIDDEN_DROPOUT_DIMS = "hidden_dropout_dims"
_KEY_OF_INTERMEDIATE_DROPOUT_DIMS = "intermediate_dropout_dims"
_KEY_OF_DROP_PATH = "drop_path"

# Layout, projections and MLP
_KEY_OF_TRANSPOSE_BS = "transpose_batch_sequence"
_KEY_OF_FUSE_QKV_PARAMS = "fuse_qkv_params"
_KEY_OF_USE_BIAS = "use_bias"
_KEY_OF_MLP_ACTIVATIONS = "mlp_activations"

# Attention
_KEY_OF_NUM_HEADS = "num_attention_heads"
_KEY_OF_NUM_GQA_GROUPS = "num_gqa_groups"
_KEY_OF_SCALE_ATTN_LOGITS = "scale_attn_logits"
_KEY_OF_FLOAT32_ATTENTION_LOGITS = "float32_attention_logits"
_KEY_OF_SELF_ATTN_BIAS_TYPE = "self_attn_bias_type"
_KEY_OF_SELF_ATTN_MASK_TYPE = "self_attn_mask_type"
_KEY_OF_WINDOW_SIZE = "window_size"
_KEY_OF_SOFTMAX_TYPE = "softmax_type"
_KEY_OF_RELATIVE_EMBEDDING = "enable_relative_embedding"

# Rotary position embedding
_KEY_OF_ENABLE_ROPE = "enable_rotary_pos_emb"
_KEY_OF_ROPE_GROUP_METHOD = "rotary_pos_emb_group_method"
87

# Defaults shared by every test case; each ATTRS entry overrides a subset of these.
BASE_ATTRS = {
    # Layout and head count.
    _KEY_OF_TRANSPOSE_BS: True,
    _KEY_OF_NUM_HEADS: 8,
    # Dropout is off unless a case opts in.
    _KEY_OF_HIDDEN_DROPOUT: 0,
    _KEY_OF_ATTENTION_DROPOUT: 0.0,
    _KEY_OF_INTERMEDIATE_DROPOUT: 0,
    # Masking, normalization and attention window.
    _KEY_OF_SELF_ATTN_MASK_TYPE: "padding_causal",
    _KEY_OF_LAYERNORM_TYPE: "layernorm",
    _KEY_OF_WINDOW_SIZE: (-1, -1),
}
98

99
# Per-case overrides of BASE_ATTRS. The "# attrsN" labels match the ids pytest
# auto-generates for the ``attrs`` parameter, so keep them in sync with list order.
ATTRS = [
    # attrs0
    {},
    # attrs1
    {
        _KEY_OF_LAYERNORM_TYPE: "rmsnorm",
    },
    # attrs2
    {
        _KEY_OF_ZERO_CENTERED_GAMMA: True,
        _KEY_OF_LAYERNORM_EPS: 1e-2,
    },
    # attrs3
    {_KEY_OF_LAYERNORM_TYPE: "rmsnorm", _KEY_OF_RESIDUAL_POST_LAYERNORM: True},
    # attrs4
    {_KEY_OF_LAYERNORM_TYPE: "rmsnorm", _KEY_OF_OUTPUT_LAYERNORM: True},
    # attrs5
    {
        _KEY_OF_LAYERNORM_TYPE: "rmsnorm",
        _KEY_OF_RESIDUAL_POST_LAYERNORM: True,
        _KEY_OF_OUTPUT_LAYERNORM: True,
    },
    # attrs6
    {_KEY_OF_LAYERNORM_TYPE: "rmsnorm", _KEY_OF_DROP_PATH: 0.1},
    # attrs7
    {_KEY_OF_LAYERNORM_TYPE: "rmsnorm", _KEY_OF_FUSE_QKV_PARAMS: False},
    # attrs8
    {
        _KEY_OF_LAYERNORM_TYPE: "rmsnorm",
        _KEY_OF_MLP_ACTIVATIONS: ("gelu", "linear"),
    },
    # attrs9
    {
        _KEY_OF_SCALE_ATTN_LOGITS: True,
        _KEY_OF_LAYERNORM_TYPE: "rmsnorm",
        _KEY_OF_HIDDEN_DROPOUT: 0.8,
        _KEY_OF_INTERMEDIATE_DROPOUT: 0.5,
        _KEY_OF_MLP_ACTIVATIONS: ("gelu", "linear"),
        _KEY_OF_USE_BIAS: True,
    },
    # attrs10
    {
        _KEY_OF_TRANSPOSE_BS: False,
        _KEY_OF_SCALE_ATTN_LOGITS: True,
        _KEY_OF_LAYERNORM_TYPE: "rmsnorm",
        _KEY_OF_MLP_ACTIVATIONS: ("gelu", "linear"),
    },
    # attrs11
    {
        _KEY_OF_NUM_HEADS: 8,
        _KEY_OF_NUM_GQA_GROUPS: 4,
        _KEY_OF_TRANSPOSE_BS: False,
        _KEY_OF_SCALE_ATTN_LOGITS: True,
        _KEY_OF_MLP_ACTIVATIONS: ("gelu",),
        _KEY_OF_USE_BIAS: True,
    },
    # attrs12
    {
        _KEY_OF_TRANSPOSE_BS: False,
        _KEY_OF_LAYERNORM_TYPE: "rmsnorm",
        _KEY_OF_NUM_GQA_GROUPS: 1,
        _KEY_OF_ENABLE_ROPE: True,
        _KEY_OF_ROPE_GROUP_METHOD: "consecutive",
        _KEY_OF_FLOAT32_ATTENTION_LOGITS: True,
    },
    # attrs13
    {
        _KEY_OF_TRANSPOSE_BS: True,
        _KEY_OF_ENABLE_ROPE: True,
        _KEY_OF_ROPE_GROUP_METHOD: "consecutive",
        _KEY_OF_USE_BIAS: True,
    },
    # attrs14
    {
        _KEY_OF_TRANSPOSE_BS: False,
        _KEY_OF_LAYERNORM_TYPE: "layernorm",
        _KEY_OF_NUM_GQA_GROUPS: 2,
        _KEY_OF_ENABLE_ROPE: True,
        _KEY_OF_ROPE_GROUP_METHOD: "alternate",
        _KEY_OF_USE_BIAS: True,
        _KEY_OF_FLOAT32_ATTENTION_LOGITS: True,
    },
    # attrs15
    {
        _KEY_OF_TRANSPOSE_BS: True,
        _KEY_OF_LAYERNORM_TYPE: "rmsnorm",
        _KEY_OF_ENABLE_ROPE: True,
        _KEY_OF_ROPE_GROUP_METHOD: "alternate",
        _KEY_OF_USE_BIAS: True,
    },
    # attrs16
    {
        _KEY_OF_HIDDEN_DROPOUT: 0.3,
        _KEY_OF_HIDDEN_DROPOUT_DIMS: (0,),
        _KEY_OF_INTERMEDIATE_DROPOUT: 0.5,
        _KEY_OF_INTERMEDIATE_DROPOUT_DIMS: (1,),
    },
    # attrs17
    {
        _KEY_OF_SELF_ATTN_MASK_TYPE: "padding",
        _KEY_OF_USE_BIAS: True,
    },
    # attrs18
    {
        _KEY_OF_RELATIVE_EMBEDDING: False,
        _KEY_OF_SELF_ATTN_BIAS_TYPE: "no_bias",
    },
    # attrs19
    {
        _KEY_OF_ATTENTION_DROPOUT: 0.3,
    },
    # attrs20
    {
        _KEY_OF_MLP_ACTIVATIONS: ("relu", "relu"),
    },
    # attrs21
    {
        _KEY_OF_TRANSPOSE_BS: False,
        _KEY_OF_RELATIVE_EMBEDDING: False,
        _KEY_OF_SELF_ATTN_MASK_TYPE: "causal",
        _KEY_OF_WINDOW_SIZE: (64, 0),  # Left window size must be < DATA_SHAPE seqlen
        _KEY_OF_FLOAT32_ATTENTION_LOGITS: True,
    },
    # attrs22
    {
        _KEY_OF_TRANSPOSE_BS: False,
        _KEY_OF_RELATIVE_EMBEDDING: False,
        _KEY_OF_SELF_ATTN_MASK_TYPE: "causal",
        _KEY_OF_WINDOW_SIZE: None,
        _KEY_OF_FLOAT32_ATTENTION_LOGITS: True,
    },
    # attrs23
    {
        _KEY_OF_TRANSPOSE_BS: False,
        _KEY_OF_RELATIVE_EMBEDDING: False,
        _KEY_OF_SELF_ATTN_MASK_TYPE: "causal",
        _KEY_OF_FLOAT32_ATTENTION_LOGITS: True,
    },
    # attrs24
    {
        _KEY_OF_TRANSPOSE_BS: False,
        _KEY_OF_RELATIVE_EMBEDDING: False,
        _KEY_OF_SELF_ATTN_MASK_TYPE: "no_mask",
    },
    # attrs25
    {
        _KEY_OF_TRANSPOSE_BS: False,
        _KEY_OF_RELATIVE_EMBEDDING: False,
        _KEY_OF_SELF_ATTN_MASK_TYPE: "no_mask",
        _KEY_OF_WINDOW_SIZE: (2, 2),
    },
    # attrs26
    {
        _KEY_OF_TRANSPOSE_BS: False,
        _KEY_OF_RELATIVE_EMBEDDING: False,
        _KEY_OF_SELF_ATTN_MASK_TYPE: "padding",
        _KEY_OF_WINDOW_SIZE: (2, 2),
    },
    # attrs27
    {
        _KEY_OF_TRANSPOSE_BS: False,
        _KEY_OF_RELATIVE_EMBEDDING: False,
        _KEY_OF_SELF_ATTN_MASK_TYPE: "padding",
        _KEY_OF_WINDOW_SIZE: None,
    },
    # attrs28
    {
        _KEY_OF_TRANSPOSE_BS: False,
        _KEY_OF_RELATIVE_EMBEDDING: False,
        _KEY_OF_WINDOW_SIZE: (2, 2),
    },
    # attrs29
    {
        _KEY_OF_RELATIVE_EMBEDDING: True,
        _KEY_OF_SELF_ATTN_BIAS_TYPE: "pre_scale_bias",
    },
    # attrs30
    {
        _KEY_OF_RELATIVE_EMBEDDING: True,
        _KEY_OF_SELF_ATTN_BIAS_TYPE: "post_scale_bias",
    },
    # attrs31
    {
        _KEY_OF_SOFTMAX_TYPE: "off_by_one",
    },
    # attrs32
    {
        _KEY_OF_SOFTMAX_TYPE: "learnable",
    },
]

# Fill in every key a case does not override from the shared defaults.
ATTRS = [{**BASE_ATTRS, **attr} for attr in ATTRS]


293
294
class BaseRunner:
    """Base runner to define forward and backward tests.

    Subclasses provide ``layer_type``, ``reference_layer``, ``transformations`` and
    ``generate_inputs``. The runner builds the reference layer and the Transformer
    Engine ``TransformerLayer`` from the same ``attrs``, copies the reference params
    into the TE layer, and compares losses (and gradients in ``test_backward``).
    """

    layer_type: TransformerLayerType = None
    reference_layer: flax.linen.Module = None
    # Param-path renames handed to ``sync_params_values`` when copying values
    # between the reference and TE parameter trees.
    transformations: Optional[Dict[str, str]] = None

    def __init__(self, attrs):
        self.attrs = attrs
        self._generate_test_rngs()
        # The fused attention kernel implements attention dropout differently from the
        # reference, so fall back to the unfused path when attention dropout is on.
        # The autouse ``enable_fused_attn`` fixture resets NVTE_FUSED_ATTN after the test.
        if attrs.get(_KEY_OF_ATTENTION_DROPOUT, False) and os.getenv("NVTE_FUSED_ATTN"):
            os.environ["NVTE_FUSED_ATTN"] = "0"

    def _generate_test_rngs(self):
        """Derive fixed init/apply RNG streams from a constant root key."""
        root_rng = jax.random.PRNGKey(0)
        params_rng, init_dropout_rng, apply_dropout_rng = jax.random.split(root_rng, 3)
        self.init_rng = {"params": params_rng, "dropout": init_dropout_rng}
        self.apply_rng = {"dropout": apply_dropout_rng}

    def _generate_layer(self, layer_cls, diff_inputs, no_diff_inputs):
        """Instantiate ``layer_cls`` and initialize it on the given inputs.

        Returns ``(layer, params, others)`` where ``others`` holds every variable
        collection except ``"params"``.
        """
        layer = layer_cls()
        variables = layer.init(self.init_rng, *diff_inputs, *no_diff_inputs)
        others, params = flax.core.pop(variables, "params")
        return layer, params, others

    def _loss_fn(self, diff_xs, no_diff_xs, params, others, model):
        """Apply ``model`` and reduce its output to a scalar.

        The mean is accumulated in float32 and cast back to the output dtype.
        """
        variables = {"params": params, **others}
        output = model.apply(variables, *diff_xs, *no_diff_xs, rngs=self.apply_rng)
        return jnp.mean(output, dtype=jnp.float32).astype(output.dtype)

    def _sync_params(self, ref, target):
        """Copy the reference values into the target tree.

        Returns ``(ref, target)`` with ``target`` replaced by the result of
        ``sync_params_values(target, ref, self.transformations)``.
        """
        target = sync_params_values(target, ref, self.transformations)
        return ref, target

    def generate_inputs(self, data_shape, dtype):
        """Return ``inputs, (ref_masks, test_masks)``; subclasses must override."""
        raise NotImplementedError(f"{type(self).__name__} must implement generate_inputs()")

    def _build_layers(self, inputs, ref_masks, test_masks):
        """Create and initialize both layers, then sync reference params into the TE layer.

        Returns ``((ref_layer, ref_params, ref_others), (test_layer, test_params, test_others))``.
        """
        ref_layer_cls = partial(self.reference_layer, **self.attrs)
        layer_cls = partial(TransformerLayer, layer_type=self.layer_type, **self.attrs)

        ref_layer, ref_params, ref_others = self._generate_layer(ref_layer_cls, inputs, ref_masks)
        test_layer, test_params, test_others = self._generate_layer(layer_cls, inputs, test_masks)
        ref_params, test_params = self._sync_params(ref_params, test_params)
        return (ref_layer, ref_params, ref_others), (test_layer, test_params, test_others)

    def _refresh_quantize_meta(self, inputs, test_masks, test_params, test_others, test_layer):
        """Update the TE layer's quantization metadata with a few backward passes.

        Only applies when FP8 is enabled with DELAYED_TENSOR_SCALING: the gradient of
        the loss w.r.t. ``test_others`` carries the updated quantization collection,
        which is merged back into ``test_others``. In every other case the passes
        would not change anything, so ``test_others`` is returned as-is.
        """
        # Resolve the active config once; it does not change while this method runs.
        quantize_config = get_quantize_config_with_recipe(get_global_quantize_recipe())
        if not quantize_config.is_fp8_enabled():
            return test_others
        scaling_mode = quantize_config.get_scaling_mode(TensorSource.X)
        if scaling_mode != ScalingMode.DELAYED_TENSOR_SCALING:
            return test_others

        collection_name = quantize_config.COLLECTION_NAME
        meta_grad_fn = jax.value_and_grad(self._loss_fn, argnums=(3,), has_aux=False)
        for _ in range(4):
            _, (updated_state,) = meta_grad_fn(
                inputs, test_masks, test_params, test_others, test_layer
            )
            _, updated_quantize_meta = flax.core.pop(updated_state, collection_name)
            test_others = update_collections(
                {collection_name: updated_quantize_meta}, test_others
            )
        return test_others

    def test_forward(
        self,
        data_shape: Tuple[int, ...],
        dtype: jnp.dtype,
        rtol: Optional[float] = None,
        atol: Optional[float] = None,
    ) -> None:
        """Test only the forward"""
        inputs, (ref_masks, test_masks) = self.generate_inputs(data_shape, dtype)
        (ref_layer, ref_params, ref_others), (test_layer, test_params, test_others) = (
            self._build_layers(inputs, ref_masks, test_masks)
        )

        ref_out = self._loss_fn(inputs, ref_masks, ref_params, ref_others, ref_layer)
        test_out = self._loss_fn(inputs, test_masks, test_params, test_others, test_layer)

        tols = dtype_tols(dtype, rtol=rtol, atol=atol)
        assert_allclose(ref_out, test_out, **tols)

    def test_backward(
        self,
        data_shape: Tuple[int, ...],
        dtype: jnp.dtype,
        rtol: Optional[float] = None,
        atol: Optional[float] = None,
    ) -> None:
        """Test forward and backward through value_and_grad()"""
        inputs, (ref_masks, test_masks) = self.generate_inputs(data_shape, dtype)
        (ref_layer, ref_params, ref_others), (test_layer, test_params, test_others) = (
            self._build_layers(inputs, ref_masks, test_masks)
        )
        test_others = self._refresh_quantize_meta(
            inputs, test_masks, test_params, test_others, test_layer
        )

        # Gradients w.r.t. the differentiable inputs (dgrads) and the params (wgrads).
        grad_fn = jax.value_and_grad(self._loss_fn, argnums=(0, 2), has_aux=False)
        ref_out, (ref_dgrads, ref_wgrads) = grad_fn(
            inputs, ref_masks, ref_params, ref_others, ref_layer
        )
        test_out, (test_dgrads, test_wgrads) = grad_fn(
            inputs, test_masks, test_params, test_others, test_layer
        )

        tols = dtype_tols(dtype, rtol=rtol, atol=atol)
        assert_allclose(ref_out, test_out, **tols)
        assert_tree_like_allclose(ref_dgrads, test_dgrads, **tols)

        # Re-map the reference weight grads onto the TE param tree before comparing.
        _, restructured_ref_wgrads = self._sync_params(ref_wgrads, test_wgrads)
        assert_tree_like_allclose(restructured_ref_wgrads, test_wgrads, **tols)
419
420
421


class EncoderRunner(BaseRunner):
    """Runner that checks the TE encoder layer against the reference ``EncoderLayer``."""

    layer_type = TransformerLayerType.ENCODER
    reference_layer = RefEncoderLayer
    transformations = {
        # Self-attention layernorm parameters.
        "attention/qkv/scale": "pre_attention_layer_norm/scale",
        "attention/qkv/ln_bias": "pre_attention_layer_norm/ln_bias",
        "attention/query/scale": "pre_attention_layer_norm/scale",
        "attention/query/ln_bias": "pre_attention_layer_norm/ln_bias",
        # Softmax offset (used by the non-vanilla softmax types).
        "attention/DotProductAttention_0/_UnfusedDotProductAttention_0/softmax_offset": (
            "attention/DotProductAttention_0/softmax_offset"
        ),
        # MLP weights and its layernorm.
        "mlp/wi_kernel": "mlp/wi/kernel",
        "mlp/wi_bias": "mlp/wi/bias",
        "mlp/wo_kernel": "mlp/wo/kernel",
        "mlp/wo_bias": "mlp/wo/bias",
        "mlp/scale": "pre_mlp_layer_norm/scale",
        "mlp/ln_bias": "pre_mlp_layer_norm/ln_bias",
    }

    def generate_inputs(self, data_shape, dtype):
        """Build one random input tensor and the matching attention masks.

        Returns ``inputs, (ref_masks, test_masks)``. The TE layer gets ``mask``;
        the reference layer gets ``1 - mask``.
        """
        batch, seqlen, *feature_dims = data_shape
        input_shape = data_shape
        if self.attrs[_KEY_OF_TRANSPOSE_BS]:
            input_shape = (seqlen, batch, *feature_dims)

        inputs = (jax.random.normal(jax.random.PRNGKey(2024), input_shape, dtype),)

        mask_shape = (batch, 1, seqlen, seqlen)
        if self.attrs[_KEY_OF_SELF_ATTN_MASK_TYPE] not in ["causal", "padding_causal"]:
            mask = jnp.zeros(mask_shape, dtype=jnp.uint8)
        else:
            # Ones strictly above the diagonal, i.e. on future positions.
            mask = jnp.triu(jnp.ones(mask_shape, dtype=jnp.uint8), k=1)

        # TransformerLayer's second positional arg is the encoded tokens, hence the None.
        return inputs, ((1 - mask,), (None, mask))


class DecoderRunner(BaseRunner):
    """Runner that checks the TE decoder layer against the reference ``DecoderLayer``."""

    layer_type = TransformerLayerType.DECODER
    reference_layer = RefDecoderLayer
    transformations = {
        # Encoder-decoder (cross) attention layernorm parameters.
        "encoder_decoder_attention/qkv/scale": "pre_cross_attention_layer_norm/scale",
        "encoder_decoder_attention/qkv/ln_bias": "pre_cross_attention_layer_norm/ln_bias",
        "encoder_decoder_attention/query/scale": "pre_cross_attention_layer_norm/scale",
        "encoder_decoder_attention/query/ln_bias": "pre_cross_attention_layer_norm/ln_bias",
        "encoder_decoder_attention/DotProductAttention_0/_UnfusedDotProductAttention_0/softmax_offset": (
            "encoder_decoder_attention/DotProductAttention_0/softmax_offset"
        ),
        # Self-attention layernorm parameters.
        "self_attention/qkv/scale": "pre_self_attention_layer_norm/scale",
        "self_attention/qkv/ln_bias": "pre_self_attention_layer_norm/ln_bias",
        "self_attention/query/scale": "pre_self_attention_layer_norm/scale",
        "self_attention/query/ln_bias": "pre_self_attention_layer_norm/ln_bias",
        "self_attention/DotProductAttention_0/_UnfusedDotProductAttention_0/softmax_offset": (
            "self_attention/DotProductAttention_0/softmax_offset"
        ),
        # MLP weights and its layernorm.
        "mlp/wi_kernel": "mlp/wi/kernel",
        "mlp/wi_bias": "mlp/wi/bias",
        "mlp/wo_kernel": "mlp/wo/kernel",
        "mlp/wo_bias": "mlp/wo/bias",
        "mlp/scale": "pre_mlp_layer_norm/scale",
        "mlp/ln_bias": "pre_mlp_layer_norm/ln_bias",
    }

    def generate_inputs(self, data_shape, dtype):
        """Build two random input tensors and the matching attention masks.

        Returns ``inputs, (ref_masks, test_masks)``. The TE layer gets
        ``(self_mask, padded_mask)``; the reference layer gets their complements
        (``1 - mask``).
        """
        batch, seqlen, *feature_dims = data_shape
        input_shape = (
            (seqlen, batch, *feature_dims) if self.attrs[_KEY_OF_TRANSPOSE_BS] else data_shape
        )

        first_rng, second_rng = jax.random.split(jax.random.PRNGKey(0), 2)
        inputs = tuple(jax.random.normal(key, input_shape, dtype) for key in (first_rng, second_rng))

        mask_shape = (batch, 1, seqlen, seqlen)
        padded_mask = jnp.zeros(mask_shape, dtype=jnp.uint8)
        is_causal = self.attrs[_KEY_OF_SELF_ATTN_MASK_TYPE] in ["causal", "padding_causal"]
        # Causal: ones strictly above the diagonal, i.e. on future positions.
        self_mask = (
            jnp.triu(jnp.ones(mask_shape, dtype=jnp.uint8), k=1) if is_causal else padded_mask
        )

        test_masks = (self_mask, padded_mask)
        ref_masks = tuple(1 - mask for mask in test_masks)
        return inputs, (ref_masks, test_masks)


525
526
527
528
@pytest.mark.parametrize("data_shape", DATA_SHAPE)
@pytest.mark.parametrize("dtype", DTYPE)
@pytest.mark.parametrize("attrs", ATTRS)
class BaseTester:
    """Pytest entry points that drive ``runner`` with FP8 disabled and enabled.

    Every test runs under an empty ``MeshResource`` because it uses a single device.
    """

    runner = BaseRunner

    def _run(self, check, attrs, data_shape, dtype, **autocast_kwargs):
        """Create ``runner(attrs)`` inside ``autocast`` and call its ``check`` method."""
        with autocast(mesh_resource=MeshResource(), **autocast_kwargs):
            getattr(self.runner(attrs), check)(data_shape, dtype)

    def test_forward(self, data_shape, dtype, attrs):
        """Test normal datatype forward"""
        # enabled=False makes sure FP8 is off.
        self._run("test_forward", attrs, data_shape, dtype, enabled=False)

    def test_backward(self, data_shape, dtype, attrs):
        """Test normal datatype backward"""
        # enabled=False makes sure FP8 is off.
        self._run("test_backward", attrs, data_shape, dtype, enabled=False)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize("fp8_recipe", QUANTIZE_RECIPES)
    def test_forward_with_fp8(self, data_shape, dtype, attrs, fp8_recipe):
        """Test forward with fp8 enabled"""
        self._run("test_forward", attrs, data_shape, dtype, enabled=True, recipe=fp8_recipe)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize("fp8_recipe", QUANTIZE_RECIPES)
    def test_backward_with_fp8(self, data_shape, dtype, attrs, fp8_recipe):
        """Test backward with fp8 enabled"""
        self._run("test_backward", attrs, data_shape, dtype, enabled=True, recipe=fp8_recipe)
564
565
566
567
568
569


class TestEncoderLayer(BaseTester):
    """Test transformer_engine.jax.flax.TransformerLayer(layer_type=Encoder)"""

    # Encoder-specific inputs, masks and param mapping come from this runner.
    runner = EncoderRunner


class TestDecoderLayer(BaseTester):
    """Test transformer_engine.jax.flax.TransformerLayer(layer_type=Decoder)"""

    # Decoder-specific inputs, masks and param mapping come from this runner.
    runner = DecoderRunner