# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Encoder training on multi-GPU with data parallelism"""
import argparse
import unittest
from functools import partial

import flax
import jax
import jax.numpy as jnp
import nltk
import numpy as np
import optax
from datasets import load_dataset
from flax import linen as nn
from flax.linen import partitioning as nn_partitioning
from flax.training import train_state
from jax.experimental import mesh_utils
from jax.sharding import PartitionSpec, NamedSharding

from common import is_bf16_supported, get_quantization_recipe_from_name_string
import transformer_engine.jax as te
import transformer_engine.jax.cpp_extensions as tex
import transformer_engine.jax.flax as te_flax
from transformer_engine.jax.quantize import is_scaling_mode_supported, ScalingMode

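# Mesh axis and collection names shared across the script: the data-parallel mesh
# axis, the Flax "params"/"params_axes" collections, and the PRNG stream names
# (dropout, the "sr_rng" stream passed alongside it, likely for stochastic rounding
# in low-precision recipes, and the input-shuffling stream used by train_epoch).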
DEVICE_DP_AXIS = "data"
PARAMS_KEY = "params"
PARAMS_AXES_KEY = PARAMS_KEY + "_axes"
DROPOUT_KEY = "dropout"
SR_KEY = "sr_rng"
INPUT_KEY = "input_rng"


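# The toy classifier: token embedding -> a single TE TransformerLayer encoder block
# -> flatten -> two TE DenseGeneral projections -> a 2-way nn.Dense head. The mask
# built in data_preprocess is 0 on valid token positions and 1 on padding, matching
# the "padding" self-attention mask type used below.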
class Net(nn.Module):
    """NLP Encoder"""

    num_embed: int

    @nn.compact
    def __call__(self, x, mask, disable_dropout=False):
        x = nn.Embed(num_embeddings=self.num_embed, features=256, dtype=jnp.bfloat16)(x)

        te_Encoder = partial(
            te_flax.TransformerLayer,
            hidden_size=256,
            mlp_hidden_size=1024,
            num_attention_heads=8,
            hidden_dropout=0.1,
            attention_dropout=0.1,
            dropout_rng_name=DROPOUT_KEY,
            layer_type=te_flax.TransformerLayerType.ENCODER,
            self_attn_mask_type="padding",
            enable_relative_embedding=False,
        )
        x = te_Encoder()(x, attention_mask=mask, deterministic=disable_dropout)

        x = x.reshape(x.shape[0], -1)

        x = te_flax.DenseGeneral(features=256)(x)

        x = te_flax.DenseGeneral(features=256)(x)

        x = nn.Dense(features=2)(x)

        return x


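# Note on var_collect in train_step: gradients are computed over the full variable
# collection (with params merged back in); the "params" entry of the grad tree feeds
# the optimizer, while the remaining entries are returned as the new var_collect,
# which is how non-param state (e.g. TE quantization metadata) is carried between steps.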
def train_step(state, inputs, masks, labels, var_collect, rngs):
    """Computes gradients, loss and accuracy for a single batch."""

    def loss_fn(var_collect, disable_dropout=False):
        logits = state.apply_fn(var_collect, inputs, masks, disable_dropout, rngs=rngs)
        one_hot = jax.nn.one_hot(labels.astype(jnp.int32), 2)
        loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
        return loss, logits

    var_collect = {**var_collect, PARAMS_KEY: state.params}
    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (loss, logits), grads = grad_fn(var_collect)
    accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)

    var_collect, grads = flax.core.pop(grads, PARAMS_KEY)
    state = state.apply_gradients(grads=grads)

    return state, loss, accuracy, var_collect


def train_epoch(state, train_ds, batch_size, rngs, var_collect, train_fn):
    """Train for a single epoch."""
    train_ds_size = len(train_ds["sentence"])
    steps_per_epoch = train_ds_size // batch_size
    perms = jax.random.permutation(rngs[INPUT_KEY], train_ds_size)
    perms = perms[: steps_per_epoch * batch_size]  # skip incomplete batch
    perms = perms.reshape((steps_per_epoch, batch_size))
    epoch_loss = []
    epoch_accuracy = []

    for perm in perms:
        # Split and reassign to 'rngs' to ensure unique rng for each step
        rngs = {key: jax.random.split(rngs[key])[1] for key in rngs}
        batch_inputs = train_ds["sentence"][perm, ...]
        batch_masks = train_ds["mask"][perm, ...]
        batch_labels = train_ds["label"][perm, ...]
        state, loss, accuracy, var_collect = train_fn(
            state, batch_inputs, batch_masks, batch_labels, var_collect, rngs
        )
        epoch_loss.append(loss)
        epoch_accuracy.append(accuracy)

    avg_loss = np.mean(epoch_loss)
    avg_accuracy = np.mean(epoch_accuracy)
    return state, avg_loss, avg_accuracy, var_collect


def eval_step(state, inputs, masks, labels, var_collect, rngs):
    """Computes loss and accuracy for a single batch."""

    def loss_fn(var_collect, disable_dropout=False):
        logits = state.apply_fn(var_collect, inputs, masks, disable_dropout, rngs=rngs)
        one_hot = jax.nn.one_hot(labels.astype(jnp.int32), 2)
        loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
        return loss, logits

    var_collect = {**var_collect, PARAMS_KEY: state.params}
    loss, logits = loss_fn(var_collect, disable_dropout=True)
    accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
    return loss, accuracy


def eval_model(state, test_ds, batch_size, var_collect, eval_fn, rngs):
    """Evaluation loop."""
    test_ds_size = len(test_ds["sentence"])
    num_steps = test_ds_size // batch_size
    valid_size = num_steps * batch_size
    all_loss = []
    all_accuracy = []

    for batch_start in range(0, valid_size, batch_size):
        # Split and reassign to 'rngs' to ensure unique rng for each step
        rngs = {key: jax.random.split(rngs[key])[1] for key in rngs}
        batch_end = batch_start + batch_size
        batch_inputs = test_ds["sentence"][batch_start:batch_end]
        batch_masks = test_ds["mask"][batch_start:batch_end]
        batch_labels = test_ds["label"][batch_start:batch_end]
        loss, accuracy = eval_fn(state, batch_inputs, batch_masks, batch_labels, var_collect, rngs)
        all_loss.append(loss)
        all_accuracy.append(accuracy)

    avg_loss = np.mean(all_loss)
    avg_accuracy = np.mean(all_accuracy)
    return avg_loss, avg_accuracy


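# data_preprocess builds a fixed-length integer encoding per sentence plus a
# (batch, 1, max_seq_len, max_seq_len) attention mask that is 0 where both query and
# key positions fall inside the real sentence and 1 elsewhere (padding).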
def data_preprocess(dataset, vocab, word_id, max_seq_len):
    """Convert tokens to numbers."""
    nltk.download("punkt_tab")
    dataset_size = len(dataset["sentence"])
    output = np.zeros((dataset_size, max_seq_len), dtype=np.int32)
    mask_3d = np.ones((dataset_size, max_seq_len, max_seq_len), dtype=np.uint8)

    for j, sentence in enumerate(dataset["sentence"]):
        tokens = nltk.word_tokenize(sentence)
        tensor = output[j]

        for i, word in enumerate(tokens):
            if i >= max_seq_len:
                break

            if word not in vocab:
                vocab[word] = word_id
                tensor[i] = word_id
                word_id = word_id + 1
            else:
                tensor[i] = vocab[word]

        seq_len = min(len(tokens), max_seq_len)
        mask_2d = mask_3d[j]
        mask_2d[:seq_len, :seq_len] = 0

    new_dataset = {
        "sentence": output,
        "label": dataset["label"].astype(np.float32),
        "mask": mask_3d.reshape((dataset_size, 1, max_seq_len, max_seq_len)),
    }
    return new_dataset, vocab, word_id


def get_datasets(max_seq_len):
    """Load GLUE train and test datasets into memory."""
    vocab = {}
    word_id = 0

    train_ds = load_dataset("glue", "cola", split="train")
    train_ds.set_format(type="np")
    train_ds, vocab, word_id = data_preprocess(train_ds, vocab, word_id, max_seq_len)

    test_ds = load_dataset("glue", "cola", split="validation")
    test_ds.set_format(type="np")
    test_ds, vocab, word_id = data_preprocess(test_ds, vocab, word_id, max_seq_len)
    return train_ds, test_ds, word_id


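# check_fp8 traces the train step with jax.make_jaxpr and asserts that FP8 dtypes
# (f8_e4m3 / f8_e5m2) show up in the printed jaxpr, i.e. that TE's autocast actually
# inserted FP8 computation rather than silently falling back to higher precision.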
def check_fp8(state, var_collect, inputs, masks, labels):
    """Check that the model's train step contains FP8 operations."""
    rngs = {DROPOUT_KEY: jax.random.PRNGKey(0), SR_KEY: jax.random.PRNGKey(0)}
    func_jaxpr = str(jax.make_jaxpr(train_step)(state, inputs, masks, labels, var_collect, rngs))
    assert "f8_e5m2" in func_jaxpr or "f8_e4m3" in func_jaxpr


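# Sharding helpers: TE records a logical axis name for every parameter in the
# "params_axes" collection. get_params_sharding maps those names to mesh axes via the
# logical-axis rules (defaulting to replication), and get_state_sharding reuses that
# mapping for each params-like dict inside the TrainState.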
def get_params_sharding(sharding_rules, abs_var_collect, mesh):
    """Create the params sharding from the logical axis rules."""
    rules_dict = dict(sharding_rules)

    def to_device_axis(logical_axis):
        partitions = [rules_dict[key] for key in logical_axis]
        return NamedSharding(mesh, PartitionSpec(*partitions))

    params_axes = abs_var_collect.get(PARAMS_AXES_KEY, {})
    params_axes_sharding = jax.tree_util.tree_map(
        to_device_axis, nn_partitioning.get_axis_names(params_axes)
    )
    params_axes_sharding = flax.core.unfreeze(params_axes_sharding)
    params_sharding = jax.tree_util.tree_map(
        lambda x: NamedSharding(mesh, PartitionSpec(None)), abs_var_collect[PARAMS_KEY]
    )
    params_sharding = {**params_sharding, **params_axes_sharding}
    return params_sharding


def get_state_sharding(state, params_sharding):
    """Create the state sharding from the params sharding."""

    def replace_params(x):
        return params_sharding if isinstance(x, dict) else None

    state_sharding = jax.tree_util.tree_map(
        replace_params, state, is_leaf=lambda x: isinstance(x, dict)
    )
    return state_sharding


def train_and_evaluate(args):
    """Execute model training and evaluation loop."""
    print(args)
    jax.config.update("jax_use_shardy_partitioner", args.enable_shardy)
    train_ds, test_ds, num_embed = get_datasets(args.max_seq_len)

    num_gpu = jax.local_device_count()
    assert args.batch_size % num_gpu == 0, f"Batch size needs to be multiple of {num_gpu}"
    assert args.test_batch_size % num_gpu == 0, f"Test batch size needs to be multiple of {num_gpu}"
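    # MXFP8 scales tensors in 32-element blocks, so the per-GPU shard of the batch
    # is additionally required to be divisible by 32.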
    if args.fp8_recipe == "MXFP8BlockScaling":
        assert (
            args.batch_size / num_gpu % 32 == 0
        ), "Batch size needs to be multiple of 32 for MXFP8"
        assert (
            args.test_batch_size / num_gpu % 32 == 0
        ), "Test batch size needs to be multiple of 32 for MXFP8"

    if args.use_fp8:
        fp8_recipe = get_quantization_recipe_from_name_string(args.fp8_recipe)
    else:
        fp8_recipe = None

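    # All GPUs form a 1-D data-parallel mesh; te.autocast enables the requested
    # quantization recipe and tells TE which mesh axis is data-parallel so its
    # collectives (e.g. amax reductions) run over the right axis.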
    device_mesh = mesh_utils.create_device_mesh((num_gpu,))
    with jax.sharding.Mesh(devices=device_mesh, axis_names=(DEVICE_DP_AXIS,)) as mesh, te.autocast(
        enabled=args.use_fp8,
        recipe=fp8_recipe,
        mesh_resource=te.MeshResource(dp_resource=DEVICE_DP_AXIS),
    ):

        rng = jax.random.PRNGKey(args.seed)
        rng, params_rng = jax.random.split(rng)
        rng, dropout_rng = jax.random.split(rng)
        rng, sr_rng = jax.random.split(rng)
        init_rngs = {PARAMS_KEY: params_rng, DROPOUT_KEY: dropout_rng}

        input_shape = [args.batch_size, args.max_seq_len]
        mask_shape = [args.batch_size, 1, args.max_seq_len, args.max_seq_len]
        label_shape = [args.batch_size]

        # Add TE logical axis rules to our Flax logical axis rule context.
        # This must be done inside the te.autocast context.
        sharding_rules = te_flax.extend_logical_axis_rules(tuple())
        with flax.linen.logical_axis_rules(sharding_rules):
            encoder = Net(num_embed)
            inputs = jnp.zeros(input_shape, dtype=jnp.int32)
            masks = jnp.zeros(mask_shape, dtype=jnp.uint8)
            abs_var_collect = jax.eval_shape(encoder.init, init_rngs, inputs, masks)

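            # Params are sharded according to TE's logical axis rules; inputs and masks
            # are split along the data-parallel axis on their leading (batch) dimension.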
            params_sharding = get_params_sharding(sharding_rules, abs_var_collect, mesh)
            inputs_sharding = NamedSharding(mesh, PartitionSpec(DEVICE_DP_AXIS, None))
            masks_sharding = NamedSharding(mesh, PartitionSpec(DEVICE_DP_AXIS, None, None, None))

            in_shardings = (None, inputs_sharding, masks_sharding)
            out_shardings = {
                key: params_sharding if key == PARAMS_KEY else None for key in abs_var_collect
            }
            jit_encoder_init = jax.jit(
                encoder.init, in_shardings=in_shardings, out_shardings=out_shardings
            )
            var_collect = jit_encoder_init(init_rngs, inputs, masks)

            optimizer = optax.adamw(args.lr)
            var_collect, params = flax.core.pop(var_collect, PARAMS_KEY)
            state = train_state.TrainState.create(
                apply_fn=encoder.apply, params=params, tx=optimizer
            )
            state_sharding = get_state_sharding(state, params_sharding)
            labels_sharding = NamedSharding(
                mesh,
                PartitionSpec(
                    DEVICE_DP_AXIS,
                ),
            )
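            # Shardings for the jitted train step: the TrainState follows state_sharding,
            # the batch tensors are split along the data axis, and var_collect / rngs
            # (the trailing None entries) are left unspecified.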
            in_shardings = (
                state_sharding,
                inputs_sharding,
                masks_sharding,
                labels_sharding,
                None,
                None,
            )
            out_shardings = (state_sharding, None, None, None)
            jit_train_step = jax.jit(
                train_step, in_shardings=in_shardings, out_shardings=out_shardings
            )

            in_shardings = (
                state_sharding,
                inputs_sharding,
                masks_sharding,
                labels_sharding,
                None,
                None,
            )
            out_shardings = (None, None)
            jit_eval_step = jax.jit(
                eval_step, in_shardings=in_shardings, out_shardings=out_shardings
            )

            if args.use_fp8:
                labels = jnp.zeros(label_shape, dtype=jnp.bfloat16)
                check_fp8(state, var_collect, inputs, masks, labels)

            if args.dry_run:
                labels = jnp.zeros(label_shape, dtype=jnp.bfloat16)
                rngs = {DROPOUT_KEY: dropout_rng, SR_KEY: sr_rng}
                jit_train_step(state, inputs, masks, labels, var_collect, rngs)
                print("PASSED")
                return None

            for epoch in range(1, args.epochs + 1):
                # Split and reassign to 'rng' to ensure unique rng for each step
                rng, input_rng = jax.random.split(rng)
                rng, dropout_rng = jax.random.split(rng)
                rng, sr_rng = jax.random.split(rng)
                rngs = {INPUT_KEY: input_rng, DROPOUT_KEY: dropout_rng, SR_KEY: sr_rng}

                state, train_loss, train_accuracy, var_collect = train_epoch(
                    state, train_ds, args.batch_size, rngs, var_collect, jit_train_step
                )

                test_loss, test_accuracy = eval_model(
                    state, test_ds, args.test_batch_size, var_collect, jit_eval_step, rngs
                )

                print(
                    f"Epoch: {epoch:>2} "
                    f"Train Loss: {train_loss:.6f} "
                    f"Train Accuracy: {train_accuracy:.6f} "
                    f"Test Loss: {test_loss:.6f} "
                    f"Test Accuracy: {test_accuracy:.6f} "
                )

            return [train_loss, train_accuracy, test_loss, test_accuracy]


def encoder_parser(args):
    """Training settings."""
    parser = argparse.ArgumentParser(description="JAX Encoder Example")
    parser.add_argument(
        "--batch-size",
        type=int,
        default=512,
        metavar="N",
        help="input batch size for training (default: 512)",
    )
    parser.add_argument(
        "--test-batch-size",
        type=int,
        default=512,
        metavar="N",
        help="input batch size for testing (default: 512)",
    )
    parser.add_argument(
        "--max-seq-len",
        type=int,
        default=32,
        metavar="N",
        help="maximum sequence length (default: 32)",
    )
    parser.add_argument(
        "--epochs",
        type=int,
        default=3,
        metavar="N",
        help="number of epochs to train (default: 3)",
    )
    parser.add_argument(
        "--lr",
        type=float,
        default=0.0001,
        metavar="LR",
        help="learning rate (default: 0.0001)",
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        default=False,
        help="quickly check a single pass",
    )
    parser.add_argument("--seed", type=int, default=0, metavar="S", help="random seed (default: 0)")
    parser.add_argument(
        "--use-fp8",
        action="store_true",
        default=False,
        help="Use FP8 for inference and training without recalibration",
    )
    parser.add_argument(
        "--fp8-recipe",
        type=str,
        default="DelayedScaling",
        help="FP8 recipe to use (default: DelayedScaling)",
    )
    parser.add_argument(
        "--enable-shardy", action="store_true", default=False, help="Enable Shardy (experimental)."
    )

    return parser.parse_args(args)


class TestEncoder(unittest.TestCase):
    """Encoder unittests"""

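    # Probe which scaling modes the local devices support; the results gate the
    # skipIf decorators on the individual tests below.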
    is_fp8_supported, fp8_reason = is_scaling_mode_supported(ScalingMode.DELAYED_TENSOR_SCALING)
    is_mxfp8_supported, mxfp8_reason = is_scaling_mode_supported(ScalingMode.MXFP8_1D_SCALING)
    is_nvfp4_supported, nvfp4_reason = is_scaling_mode_supported(ScalingMode.NVFP4_1D_SCALING)

    def setUp(self):
        """Run 5 epochs for testing"""
        self.args = encoder_parser(["--epochs", "5"])

    @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
    def test_te_bf16(self):
        """Test Transformer Engine with BF16"""
        actual = train_and_evaluate(self.args)
        assert actual[0] < 0.51 and actual[1] > 0.75

    @unittest.skipIf(not is_fp8_supported, fp8_reason)
    def test_te_delayed_scaling_fp8(self):
        """Test Transformer Engine with DelayedScaling FP8"""
        self.args.use_fp8 = True
        self.args.fp8_recipe = "DelayedScaling"
        actual = train_and_evaluate(self.args)
        assert actual[0] < 0.51 and actual[1] > 0.75

    @unittest.skipIf(not is_fp8_supported, fp8_reason)
    def test_te_current_scaling_fp8(self):
        """Test Transformer Engine with CurrentScaling FP8"""
        self.args.use_fp8 = True
        self.args.fp8_recipe = "Float8CurrentScaling"
        actual = train_and_evaluate(self.args)
        assert actual[0] < 0.51 and actual[1] > 0.749

    @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason)
    def test_te_mxfp8(self):
        """Test Transformer Engine with MXFP8"""
        self.args.use_fp8 = True
        self.args.fp8_recipe = "MXFP8BlockScaling"
        actual = train_and_evaluate(self.args)
        assert actual[0] < 0.51 and actual[1] > 0.75

    @unittest.skipIf(not is_nvfp4_supported, nvfp4_reason)
    def test_te_nvfp4(self):
        """Test Transformer Engine with NVFP4"""
        self.args.use_fp8 = True
        self.args.fp8_recipe = "NVFP4BlockScaling"
        actual = train_and_evaluate(self.args)
        assert actual[0] < 0.52 and actual[1] > 0.74

    @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
    def test_te_bf16_shardy(self):
        """Test Transformer Engine with BF16"""
        self.args.enable_shardy = True
        actual = train_and_evaluate(self.args)
        assert actual[0] < 0.51 and actual[1] > 0.75

    @unittest.skipIf(not is_fp8_supported, fp8_reason)
    def test_te_delayed_scaling_fp8_shardy(self):
        """Test Transformer Engine with DelayedScaling FP8"""
        self.args.enable_shardy = True
        self.args.use_fp8 = True
        self.args.fp8_recipe = "DelayedScaling"
        actual = train_and_evaluate(self.args)
        assert actual[0] < 0.51 and actual[1] > 0.75

    @unittest.skipIf(not is_fp8_supported, fp8_reason)
    def test_te_current_scaling_fp8_shardy(self):
        """Test Transformer Engine with CurrentScaling FP8"""
        self.args.enable_shardy = True
        self.args.use_fp8 = True
        self.args.fp8_recipe = "Float8CurrentScaling"
        actual = train_and_evaluate(self.args)
        assert actual[0] < 0.51 and actual[1] > 0.749

    @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason)
    def test_te_mxfp8_shardy(self):
        """Test Transformer Engine with MXFP8"""
        self.args.enable_shardy = True
        self.args.use_fp8 = True
        self.args.fp8_recipe = "MXFP8BlockScaling"
        actual = train_and_evaluate(self.args)
        assert actual[0] < 0.51 and actual[1] > 0.75

    @unittest.skipIf(not is_nvfp4_supported, nvfp4_reason)
    def test_te_nvfp4_shardy(self):
        """Test Transformer Engine with NVFP4"""
        self.args.enable_shardy = True
        self.args.use_fp8 = True
        self.args.fp8_recipe = "NVFP4BlockScaling"
        actual = train_and_evaluate(self.args)
        assert actual[0] < 0.52 and actual[1] > 0.74


if __name__ == "__main__":
    train_and_evaluate(encoder_parser(None))