# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Test transformer_engine.jax.flax.TransformerLayer"""
import os
from functools import partial
from typing import Dict, Tuple, Optional

import flax
import jax
import jax.numpy as jnp
import pytest

from utils import (
    assert_allclose,
    assert_tree_like_allclose,
    dtype_tols,
    sync_params_values,
)
from utils import DecoderLayer as RefDecoderLayer
from utils import EncoderLayer as RefEncoderLayer

from transformer_engine.common import recipe
from transformer_engine.jax.flax import TransformerLayer, TransformerLayerType
from transformer_engine.jax.quantize import (
    get_global_quantize_recipe,
    get_quantize_config_with_recipe,
    ScalingMode,
    is_fp8_available,
    update_collections,
    TensorSource,
    autocast,
)
from transformer_engine.jax.sharding import MeshResource


@pytest.fixture(autouse=True, scope="function")
def enable_fused_attn():
    """Enable fused attention"""
    os.environ["NVTE_FUSED_ATTN"] = "1"
    yield
    del os.environ["NVTE_FUSED_ATTN"]


is_fp8_supported, reason = is_fp8_available()
is_mxfp8_supported, reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING)

# Find the quantization recipes supported on this device
QUANTIZE_RECIPES = []
if is_fp8_supported:
    QUANTIZE_RECIPES.append(pytest.param(recipe.DelayedScaling(), id="DelayedScaling"))
if is_mxfp8_supported:
    QUANTIZE_RECIPES.append(pytest.param(recipe.MXFP8BlockScaling(), id="MXFP8BlockScaling"))


DATA_SHAPE = [  # (batch, seqlen, emb_dim)
    pytest.param((32, 128, 1024), id="32-128-1024"),
]
DTYPE = [jnp.bfloat16]

_KEY_OF_RESIDUAL_POST_LAYERNORM = "apply_residual_connection_post_layernorm"
_KEY_OF_OUTPUT_LAYERNORM = "output_layernorm"
_KEY_OF_DROP_PATH = "drop_path"
_KEY_OF_FUSE_QKV_PARAMS = "fuse_qkv_params"
_KEY_OF_HIDDEN_DROPOUT = "hidden_dropout"
_KEY_OF_ATTENTION_DROPOUT = "attention_dropout"
_KEY_OF_INTERMEDIATE_DROPOUT = "intermediate_dropout"
_KEY_OF_HIDDEN_DROPOUT_DIMS = "hidden_dropout_dims"
_KEY_OF_INTERMEDIATE_DROPOUT_DIMS = "intermediate_dropout_dims"
_KEY_OF_MLP_ACTIVATIONS = "mlp_activations"
_KEY_OF_LAYERNORM_TYPE = "layernorm_type"
_KEY_OF_LAYERNORM_EPS = "layernorm_epsilon"
_KEY_OF_ZERO_CENTERED_GAMMA = "zero_centered_gamma"
_KEY_OF_TRANSPOSE_BS = "transpose_batch_sequence"
_KEY_OF_SCALE_ATTN_LOGITS = "scale_attn_logits"
_KEY_OF_NUM_HEADS = "num_attention_heads"
_KEY_OF_NUM_GQA_GROUPS = "num_gqa_groups"
_KEY_OF_ENABLE_ROPE = "enable_rotary_pos_emb"
_KEY_OF_ROPE_GROUP_METHOD = "rotary_pos_emb_group_method"
_KEY_OF_SELF_ATTN_BIAS_TYPE = "self_attn_bias_type"
_KEY_OF_SELF_ATTN_MASK_TYPE = "self_attn_mask_type"
_KEY_OF_FLOAT32_ATTENTION_LOGITS = "float32_attention_logits"
_KEY_OF_USE_BIAS = "use_bias"
_KEY_OF_RELATIVE_EMBEDDING = "enable_relative_embedding"
_KEY_OF_WINDOW_SIZE = "window_size"
_KEY_OF_SOFTMAX_TYPE = "softmax_type"

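# Baseline layer arguments shared by every parametrized test case below.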
BASE_ATTRS = {
    _KEY_OF_TRANSPOSE_BS: True,
    _KEY_OF_NUM_HEADS: 8,
    _KEY_OF_HIDDEN_DROPOUT: 0,
    _KEY_OF_ATTENTION_DROPOUT: 0.0,
    _KEY_OF_INTERMEDIATE_DROPOUT: 0,
    _KEY_OF_SELF_ATTN_MASK_TYPE: "padding_causal",
    _KEY_OF_LAYERNORM_TYPE: "layernorm",
    _KEY_OF_WINDOW_SIZE: (-1, -1),
}

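# Each entry overrides BASE_ATTRS for one parametrized case; the dicts are merged
# with BASE_ATTRS right after this list is defined.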
ATTRS = [
    # attrs0
    {},
    # attrs1
    {
        _KEY_OF_LAYERNORM_TYPE: "rmsnorm",
    },
    # attrs2
    {
        _KEY_OF_ZERO_CENTERED_GAMMA: True,
        _KEY_OF_LAYERNORM_EPS: 1e-2,
    },
    # attrs3
    {_KEY_OF_LAYERNORM_TYPE: "rmsnorm", _KEY_OF_RESIDUAL_POST_LAYERNORM: True},
    # attrs4
    {_KEY_OF_LAYERNORM_TYPE: "rmsnorm", _KEY_OF_OUTPUT_LAYERNORM: True},
    # attrs5
    {
        _KEY_OF_LAYERNORM_TYPE: "rmsnorm",
        _KEY_OF_RESIDUAL_POST_LAYERNORM: True,
        _KEY_OF_OUTPUT_LAYERNORM: True,
    },
    # attrs6
    {_KEY_OF_LAYERNORM_TYPE: "rmsnorm", _KEY_OF_DROP_PATH: 0.1},
    # attrs7
    {_KEY_OF_LAYERNORM_TYPE: "rmsnorm", _KEY_OF_FUSE_QKV_PARAMS: False},
    # attrs8
    {
        _KEY_OF_LAYERNORM_TYPE: "rmsnorm",
        _KEY_OF_MLP_ACTIVATIONS: ("gelu", "linear"),
    },
    # attrs9
    {
        _KEY_OF_SCALE_ATTN_LOGITS: True,
        _KEY_OF_LAYERNORM_TYPE: "rmsnorm",
        _KEY_OF_HIDDEN_DROPOUT: 0.8,
        _KEY_OF_INTERMEDIATE_DROPOUT: 0.5,
        _KEY_OF_MLP_ACTIVATIONS: ("gelu", "linear"),
        _KEY_OF_USE_BIAS: True,
    },
    # attrs10
    {
        _KEY_OF_TRANSPOSE_BS: False,
        _KEY_OF_SCALE_ATTN_LOGITS: True,
        _KEY_OF_LAYERNORM_TYPE: "rmsnorm",
        _KEY_OF_MLP_ACTIVATIONS: ("gelu", "linear"),
    },
    # attrs11
    {
        _KEY_OF_NUM_HEADS: 8,
        _KEY_OF_NUM_GQA_GROUPS: 4,
        _KEY_OF_TRANSPOSE_BS: False,
        _KEY_OF_SCALE_ATTN_LOGITS: True,
        _KEY_OF_MLP_ACTIVATIONS: ("gelu",),
        _KEY_OF_USE_BIAS: True,
    },
    # attrs12
    {
        _KEY_OF_TRANSPOSE_BS: False,
        _KEY_OF_LAYERNORM_TYPE: "rmsnorm",
        _KEY_OF_NUM_GQA_GROUPS: 1,
        _KEY_OF_ENABLE_ROPE: True,
        _KEY_OF_ROPE_GROUP_METHOD: "consecutive",
        _KEY_OF_FLOAT32_ATTENTION_LOGITS: True,
    },
    # attrs13
    {
        _KEY_OF_TRANSPOSE_BS: True,
        _KEY_OF_ENABLE_ROPE: True,
        _KEY_OF_ROPE_GROUP_METHOD: "consecutive",
        _KEY_OF_USE_BIAS: True,
    },
    # attrs14
    {
        _KEY_OF_TRANSPOSE_BS: False,
        _KEY_OF_LAYERNORM_TYPE: "layernorm",
        _KEY_OF_NUM_GQA_GROUPS: 2,
        _KEY_OF_ENABLE_ROPE: True,
        _KEY_OF_ROPE_GROUP_METHOD: "alternate",
        _KEY_OF_USE_BIAS: True,
        _KEY_OF_FLOAT32_ATTENTION_LOGITS: True,
    },
    # attrs15
    {
        _KEY_OF_TRANSPOSE_BS: True,
        _KEY_OF_LAYERNORM_TYPE: "rmsnorm",
        _KEY_OF_ENABLE_ROPE: True,
        _KEY_OF_ROPE_GROUP_METHOD: "alternate",
        _KEY_OF_USE_BIAS: True,
    },
    # attrs16
    {
        _KEY_OF_HIDDEN_DROPOUT: 0.3,
        _KEY_OF_HIDDEN_DROPOUT_DIMS: (0,),
        _KEY_OF_INTERMEDIATE_DROPOUT: 0.5,
        _KEY_OF_INTERMEDIATE_DROPOUT_DIMS: (1,),
    },
    # attrs17
    {
        _KEY_OF_SELF_ATTN_MASK_TYPE: "padding",
        _KEY_OF_USE_BIAS: True,
    },
    # attrs18
    {
        _KEY_OF_RELATIVE_EMBEDDING: False,
        _KEY_OF_SELF_ATTN_BIAS_TYPE: "no_bias",
    },
    # attrs19
    {
        _KEY_OF_ATTENTION_DROPOUT: 0.3,
    },
    # attrs20
    {
        _KEY_OF_MLP_ACTIVATIONS: ("relu", "relu"),
    },
    # attrs21
    {
        _KEY_OF_TRANSPOSE_BS: False,
        _KEY_OF_RELATIVE_EMBEDDING: False,
        _KEY_OF_SELF_ATTN_MASK_TYPE: "causal",
        _KEY_OF_WINDOW_SIZE: (64, 0),  # Left window size must be smaller than the DATA_SHAPE seqlen
        _KEY_OF_FLOAT32_ATTENTION_LOGITS: True,
    },
    # attrs22
    {
        _KEY_OF_TRANSPOSE_BS: False,
        _KEY_OF_RELATIVE_EMBEDDING: False,
        _KEY_OF_SELF_ATTN_MASK_TYPE: "causal",
        _KEY_OF_WINDOW_SIZE: None,
        _KEY_OF_FLOAT32_ATTENTION_LOGITS: True,
    },
    # attrs23
    {
        _KEY_OF_TRANSPOSE_BS: False,
        _KEY_OF_RELATIVE_EMBEDDING: False,
        _KEY_OF_SELF_ATTN_MASK_TYPE: "causal",
        _KEY_OF_FLOAT32_ATTENTION_LOGITS: True,
    },
    # attrs24
    {
        _KEY_OF_TRANSPOSE_BS: False,
        _KEY_OF_RELATIVE_EMBEDDING: False,
        _KEY_OF_SELF_ATTN_MASK_TYPE: "no_mask",
    },
    # attrs25
    {
        _KEY_OF_TRANSPOSE_BS: False,
        _KEY_OF_RELATIVE_EMBEDDING: False,
        _KEY_OF_SELF_ATTN_MASK_TYPE: "no_mask",
        _KEY_OF_WINDOW_SIZE: (2, 2),
    },
    # attrs26
    {
        _KEY_OF_TRANSPOSE_BS: False,
        _KEY_OF_RELATIVE_EMBEDDING: False,
        _KEY_OF_SELF_ATTN_MASK_TYPE: "padding",
        _KEY_OF_WINDOW_SIZE: (2, 2),
    },
    # attrs27
    {
        _KEY_OF_TRANSPOSE_BS: False,
        _KEY_OF_RELATIVE_EMBEDDING: False,
        _KEY_OF_SELF_ATTN_MASK_TYPE: "padding",
        _KEY_OF_WINDOW_SIZE: None,
    },
    # attrs28
    {
        _KEY_OF_TRANSPOSE_BS: False,
        _KEY_OF_RELATIVE_EMBEDDING: False,
        _KEY_OF_WINDOW_SIZE: (2, 2),
    },
    # attrs29
    {
        _KEY_OF_RELATIVE_EMBEDDING: True,
        _KEY_OF_SELF_ATTN_BIAS_TYPE: "pre_scale_bias",
    },
    # attrs30
    {
        _KEY_OF_RELATIVE_EMBEDDING: True,
        _KEY_OF_SELF_ATTN_BIAS_TYPE: "post_scale_bias",
    },
    # attrs31
    {
        _KEY_OF_SOFTMAX_TYPE: "off_by_one",
    },
    # attrs32
    {
        _KEY_OF_SOFTMAX_TYPE: "learnable",
    },
]

ATTRS = [{**BASE_ATTRS, **attr} for attr in ATTRS]


class BaseRunner:
    """Base runner to define forward and backward tests"""

    layer_type: TransformerLayerType = None
    reference_layer: flax.linen.Module = None
    transformations: Dict[str, str] = None

    def __init__(self, attrs):
        self.attrs = attrs
        self._generate_test_rngs()
        # Disable fused attention when attention dropout is enabled, because the fused
        # and unfused paths use different dropout implementations.
        if attrs.get(_KEY_OF_ATTENTION_DROPOUT, False) and os.getenv("NVTE_FUSED_ATTN"):
            os.environ["NVTE_FUSED_ATTN"] = "0"

    def _generate_test_rngs(self):
        root_rng = jax.random.PRNGKey(0)
        params_rng, init_dropout_rng, apply_dropout_rng = jax.random.split(root_rng, 3)
        self.init_rng = {"params": params_rng, "dropout": init_dropout_rng}
        self.apply_rng = {"dropout": apply_dropout_rng}

    def _generate_layer(self, layer_cls, diff_inputs, no_diff_inputs):
        layer = layer_cls()
        variables = layer.init(self.init_rng, *diff_inputs, *no_diff_inputs)
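        # Separate the trainable parameters from the other variable collections
        # (e.g. quantization metadata) returned by init().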
        others, params = flax.core.pop(variables, "params")
        del variables
        return layer, params, others

    def _loss_fn(self, diff_xs, no_diff_xs, params, others, model):
        variables = {"params": params, **others}
        output = model.apply(variables, *diff_xs, *no_diff_xs, rngs=self.apply_rng)
        return jnp.mean(output, dtype=jnp.float32).astype(output.dtype)

    def _sync_params(self, ref, target):
        """Copy the reference params to target"""
        target = sync_params_values(target, ref, self.transformations)
        return ref, target

    def test_forward(
        self,
        data_shape: Tuple[int],
        dtype: jnp.dtype,
        rtol: Optional[float] = None,
        atol: Optional[float] = None,
    ) -> None:
        """Test only the forward"""
        inputs, (ref_masks, test_masks) = self.generate_inputs(data_shape, dtype)

        ref_layer_cls = partial(self.reference_layer, **self.attrs)
        layer_cls = partial(TransformerLayer, layer_type=self.layer_type, **self.attrs)

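        # Instantiate both layers, copy the reference weights into the TE layer,
        # then compare the scalar outputs within dtype tolerances.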
        ref_layer, ref_params, ref_others = self._generate_layer(ref_layer_cls, inputs, ref_masks)
        test_layer, test_params, test_others = self._generate_layer(layer_cls, inputs, test_masks)
        ref_params, test_params = self._sync_params(ref_params, test_params)

        ref_out = self._loss_fn(inputs, ref_masks, ref_params, ref_others, ref_layer)
        test_out = self._loss_fn(inputs, test_masks, test_params, test_others, test_layer)

        tols = dtype_tols(dtype, rtol=rtol, atol=atol)
        assert_allclose(ref_out, test_out, **tols)

    def test_backward(
        self,
        data_shape: Tuple[int],
        dtype: jnp.dtype,
        rtol: Optional[float] = None,
        atol: Optional[float] = None,
    ) -> None:
        """Test forward and backward through value_and_grad()"""
        inputs, (ref_masks, test_masks) = self.generate_inputs(data_shape, dtype)

        ref_layer_cls = partial(self.reference_layer, **self.attrs)
        layer_cls = partial(TransformerLayer, layer_type=self.layer_type, **self.attrs)

        ref_layer, ref_params, ref_others = self._generate_layer(ref_layer_cls, inputs, ref_masks)
        test_layer, test_params, test_others = self._generate_layer(layer_cls, inputs, test_masks)

        ref_params, test_params = self._sync_params(ref_params, test_params)

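        # When FP8 is enabled, run a few warm-up steps so the quantization state
        # (e.g. delayed-scaling amax history) is populated before the comparison.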
        if get_quantize_config_with_recipe(get_global_quantize_recipe()).is_fp8_enabled():
            for _ in range(4):
                _, updated_state = jax.value_and_grad(self._loss_fn, argnums=(3,), has_aux=False)(
                    inputs,
                    test_masks,
                    test_params,
                    test_others,
                    test_layer,
                )
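                # Delayed tensor scaling keeps per-tensor state across iterations;
                # feed the updated quantize metadata back into the variable collections.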
                if (
                    get_quantize_config_with_recipe(get_global_quantize_recipe()).get_scaling_mode(
                        TensorSource.X
                    )
                    == ScalingMode.DELAYED_TENSOR_SCALING
                ):
                    _, updated_quantize_meta = flax.core.pop(
                        updated_state[0],
                        get_quantize_config_with_recipe(
                            get_global_quantize_recipe()
                        ).COLLECTION_NAME,
                    )
                    test_others = update_collections(
                        {
                            get_quantize_config_with_recipe(
                                get_global_quantize_recipe()
                            ).COLLECTION_NAME: updated_quantize_meta
                        },
                        test_others,
                    )
                    del updated_quantize_meta
                del updated_state

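        # Differentiate w.r.t. the inputs (argnum 0 -> dgrads) and the parameters
        # (argnum 2 -> wgrads) of _loss_fn.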
        grad_fn = jax.value_and_grad(self._loss_fn, argnums=(0, 2), has_aux=False)

        ref_out, (ref_dgrads, ref_wgrads) = grad_fn(
            inputs, ref_masks, ref_params, ref_others, ref_layer
        )
        test_out, (test_dgrads, test_wgrads) = grad_fn(
            inputs, test_masks, test_params, test_others, test_layer
        )

        tols = dtype_tols(dtype, rtol=rtol, atol=atol)
        assert_allclose(ref_out, test_out, **tols)
        assert_tree_like_allclose(ref_dgrads, test_dgrads, **tols)

        _, restructured_ref_wgrads = self._sync_params(ref_wgrads, test_wgrads)
        assert_tree_like_allclose(restructured_ref_wgrads, test_wgrads, **tols)


class EncoderRunner(BaseRunner):
    """Encoder runner implementations"""

    layer_type = TransformerLayerType.ENCODER
    reference_layer = RefEncoderLayer
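    # Maps TransformerLayer parameter paths to the reference EncoderLayer paths so
    # weights and wgrads can be synced and compared between the two implementations.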
    transformations = {
        "attention/qkv/scale": "pre_attention_layer_norm/scale",
        "attention/qkv/ln_bias": "pre_attention_layer_norm/ln_bias",
        "attention/query/scale": "pre_attention_layer_norm/scale",
        "attention/query/ln_bias": "pre_attention_layer_norm/ln_bias",
        "attention/DotProductAttention_0/_UnfusedDotProductAttention_0/softmax_offset": (
            "attention/DotProductAttention_0/softmax_offset"
        ),
        "attention/DotProductAttention_0/_FusedDotProductAttention_0/softmax_offset": (
            "attention/DotProductAttention_0/softmax_offset"
        ),
        "mlp/wi_kernel": "mlp/wi/kernel",
        "mlp/wi_bias": "mlp/wi/bias",
        "mlp/wo_kernel": "mlp/wo/kernel",
        "mlp/wo_bias": "mlp/wo/bias",
        "mlp/scale": "pre_mlp_layer_norm/scale",
        "mlp/ln_bias": "pre_mlp_layer_norm/ln_bias",
    }

    def generate_inputs(self, data_shape, dtype):
        """
        Return inputs, (ref_masks, test_masks)
        """
        transpose_batch_sequence = self.attrs[_KEY_OF_TRANSPOSE_BS]
        batch, seqlen = data_shape[:2]
        if transpose_batch_sequence:
            data_shape = (data_shape[1], data_shape[0], *data_shape[2:])

        data_rng = jax.random.PRNGKey(2024)
        inputs = (jax.random.normal(data_rng, data_shape, dtype),)

        mask_shape = (batch, 1, seqlen, seqlen)
        padded_mask = jnp.zeros(mask_shape, dtype=jnp.uint8)
        causal_mask = jnp.triu(jnp.ones(mask_shape, dtype=jnp.uint8), k=1)
        if self.attrs[_KEY_OF_SELF_ATTN_MASK_TYPE] in ["causal", "padding_causal"]:
            mask = causal_mask
        else:
            mask = padded_mask
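        # The reference layer takes the inverted mask (1 = attend), while
        # TransformerLayer takes the mask directly (1 = masked out).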
        ref_masks = (1 - mask,)
        test_masks = (None, mask)  # The second positional arg of TransformerLayer is the encoded tokens (unused by the encoder).

        return inputs, (ref_masks, test_masks)


class DecoderRunner(BaseRunner):
    """
    Decoder runner implementations
    """

    layer_type = TransformerLayerType.DECODER
    reference_layer = RefDecoderLayer
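    # Same kind of path mapping as the encoder, extended to cover the decoder's
    # self-attention and encoder-decoder (cross) attention blocks.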
    transformations = {
        "encoder_decoder_attention/qkv/scale": "pre_cross_attention_layer_norm/scale",
        "encoder_decoder_attention/qkv/ln_bias": "pre_cross_attention_layer_norm/ln_bias",
        "encoder_decoder_attention/query/scale": "pre_cross_attention_layer_norm/scale",
        "encoder_decoder_attention/query/ln_bias": "pre_cross_attention_layer_norm/ln_bias",
        "encoder_decoder_attention/DotProductAttention_0/_UnfusedDotProductAttention_0/softmax_offset": (
            "encoder_decoder_attention/DotProductAttention_0/softmax_offset"
        ),
        "encoder_decoder_attention/DotProductAttention_0/_FusedDotProductAttention_0/softmax_offset": (
            "encoder_decoder_attention/DotProductAttention_0/softmax_offset"
        ),
        "self_attention/qkv/scale": "pre_self_attention_layer_norm/scale",
        "self_attention/qkv/ln_bias": "pre_self_attention_layer_norm/ln_bias",
        "self_attention/query/scale": "pre_self_attention_layer_norm/scale",
        "self_attention/query/ln_bias": "pre_self_attention_layer_norm/ln_bias",
        "self_attention/DotProductAttention_0/_UnfusedDotProductAttention_0/softmax_offset": (
            "self_attention/DotProductAttention_0/softmax_offset"
        ),
        "self_attention/DotProductAttention_0/_FusedDotProductAttention_0/softmax_offset": (
            "self_attention/DotProductAttention_0/softmax_offset"
        ),
        "mlp/wi_kernel": "mlp/wi/kernel",
        "mlp/wi_bias": "mlp/wi/bias",
        "mlp/wo_kernel": "mlp/wo/kernel",
        "mlp/wo_bias": "mlp/wo/bias",
        "mlp/scale": "pre_mlp_layer_norm/scale",
        "mlp/ln_bias": "pre_mlp_layer_norm/ln_bias",
    }

    def generate_inputs(self, data_shape, dtype):
        """
        Return inputs, (ref_masks, test_masks)
        """
        transpose_batch_sequence = self.attrs[_KEY_OF_TRANSPOSE_BS]
        batch, seqlen = data_shape[:2]
        if transpose_batch_sequence:
            data_shape = (data_shape[1], data_shape[0], *data_shape[2:])

        data_rng = jax.random.PRNGKey(0)
        data_rng_0, data_rng_1 = jax.random.split(data_rng, 2)
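        # The decoder takes two inputs: target-side hidden states and the encoded
        # source representations used by cross attention.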
        inputs = (
            jax.random.normal(data_rng_0, data_shape, dtype),
            jax.random.normal(data_rng_1, data_shape, dtype),
        )

        padded_mask = jnp.zeros((batch, 1, seqlen, seqlen), dtype=jnp.uint8)
        causal_mask = jnp.triu(jnp.ones((batch, 1, seqlen, seqlen), dtype=jnp.uint8), k=1)
        if self.attrs[_KEY_OF_SELF_ATTN_MASK_TYPE] in ["causal", "padding_causal"]:
            self_mask = causal_mask
        else:
            self_mask = padded_mask

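        # Self-attention uses the (possibly causal) mask; cross attention always uses
        # the padding mask. The reference layer again expects the inverted masks.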
        ref_masks = (1 - self_mask, 1 - padded_mask)
        test_masks = (self_mask, padded_mask)

        return inputs, (ref_masks, test_masks)


@pytest.mark.parametrize("data_shape", DATA_SHAPE)
@pytest.mark.parametrize("dtype", DTYPE)
@pytest.mark.parametrize("attrs", ATTRS)
class BaseTester:
    """
    Pytest interface to invoke the runner
    """

    runner = BaseRunner

    def test_forward(self, data_shape, dtype, attrs):
        """Test normal datatype forward"""
        # Ensure FP8 disabled.
        # Empty MeshResource is used as we are running on a single device
        with autocast(enabled=False, mesh_resource=MeshResource()):
            self.runner(attrs).test_forward(data_shape, dtype)

    def test_backward(self, data_shape, dtype, attrs):
        """Test normal datatype backward"""
        # Ensure FP8 disabled.
        # Empty MeshResource is used as we are running on a single device
        with autocast(enabled=False, mesh_resource=MeshResource()):
            self.runner(attrs).test_backward(data_shape, dtype)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize("fp8_recipe", QUANTIZE_RECIPES)
    def test_forward_with_fp8(self, data_shape, dtype, attrs, fp8_recipe):
        """Test forward with fp8 enabled"""
        # Empty MeshResource is used as we are running on a single device
        with autocast(enabled=True, recipe=fp8_recipe, mesh_resource=MeshResource()):
            self.runner(attrs).test_forward(data_shape, dtype)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize("fp8_recipe", QUANTIZE_RECIPES)
    def test_backward_with_fp8(self, data_shape, dtype, attrs, fp8_recipe):
        """Test backward with fp8 enabled"""
        # Empty MeshResource is used as we are running on a single device
        with autocast(enabled=True, recipe=fp8_recipe, mesh_resource=MeshResource()):
            self.runner(attrs).test_backward(data_shape, dtype)


class TestEncoderLayer(BaseTester):
    """
    Test transformer_engine.jax.flax.TransformerLayer(layer_type=Encoder)
    """

    runner = EncoderRunner


class TestDecoderLayer(BaseTester):
    """
    Test transformer_engine.jax.flax.TransformerLayer(layer_type=Decoder)
    """

    runner = DecoderRunner