# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Encoder training on multi-GPU with tensor parallelism"""
import argparse
import unittest
from functools import partial

9
import flax
Jeng Bai-Cheng's avatar
Jeng Bai-Cheng committed
10
11
12
13
14
import jax
import jax.numpy as jnp
import nltk
import numpy as np
import optax
15
from datasets import load_dataset
Jeng Bai-Cheng's avatar
Jeng Bai-Cheng committed
16
from flax import linen as nn
17
from flax.linen import partitioning as nn_partitioning
Jeng Bai-Cheng's avatar
Jeng Bai-Cheng committed
18
19
from flax.training import train_state
from jax.experimental import mesh_utils
20
from jax.sharding import PartitionSpec, NamedSharding
Jeng Bai-Cheng's avatar
Jeng Bai-Cheng committed
21

22
23
24
25
26
from common import (
    is_bf16_supported,
    get_fp8_recipe_from_name_string,
    assert_params_sufficiently_sharded,
)
Jeng Bai-Cheng's avatar
Jeng Bai-Cheng committed
27
import transformer_engine.jax as te
Alp Dener's avatar
Alp Dener committed
28
import transformer_engine.jax.cpp_extensions as tex
29
import transformer_engine.jax.flax as te_flax
30
from transformer_engine.jax.quantize import is_fp8_available, ScalingMode
Jeng Bai-Cheng's avatar
Jeng Bai-Cheng committed
31

# Names of the two physical device-mesh axes: data parallel and tensor ("model") parallel.
DEVICE_DP_AXIS = "data"
DEVICE_TP_AXIS = "model"

# Logical axis names attached to parameter kernels/biases; the logical-axis rules
# built in train_and_evaluate map NAMED_TP_AXIS onto DEVICE_TP_AXIS and leave
# NAMED_BROADCAST_AXIS unsharded.
NAMED_BROADCAST_AXIS = "my_broadcast_axis"
NAMED_TP_AXIS = "my_tp_axis"

# Variable-collection and RNG-stream keys.
PARAMS_KEY = "params"
PARAMS_AXES_KEY = f"{PARAMS_KEY}_axes"
DROPOUT_KEY = "dropout"
INPUT_KEY = "input_rng"


class Net(nn.Module):
    """Sequence classifier built around one Transformer Engine encoder layer.

    Pipeline: token embedding -> TE ``TransformerLayer`` (encoder) -> flatten ->
    two ``DenseGeneral`` layers whose kernels carry tensor-parallel logical axes ->
    a 2-unit ``nn.Dense`` that produces the logits.
    """

    # Size of the embedding table (number of distinct token ids).
    num_embed: int
    # Forwarded to the TE layer's ``enable_sequence_parallel``; when true the
    # flattened activations are additionally re-constrained before the dense head.
    enable_seq_paral: bool

    @nn.compact
    def __call__(self, x, mask, disable_dropout=False):
        hidden = nn.Embed(num_embeddings=self.num_embed, features=256, dtype=jnp.bfloat16)(x)

        encoder_layer = te_flax.TransformerLayer(
            hidden_size=256,
            mlp_hidden_size=1024,
            num_attention_heads=8,
            hidden_dropout=0.1,
            attention_dropout=0.1,
            dropout_rng_name=DROPOUT_KEY,
            layer_type=te_flax.TransformerLayerType.ENCODER,
            self_attn_mask_type="padding",
            enable_relative_embedding=False,
            enable_sequence_parallel=self.enable_seq_paral,
            mlp_activations=("gelu", "linear"),
        )
        # ``disable_dropout`` is forwarded as the layer's ``deterministic`` flag.
        hidden = encoder_layer(hidden, attention_mask=mask, deterministic=disable_dropout)

        # Collapse everything after the batch dimension into one feature axis.
        hidden = hidden.reshape(hidden.shape[0], -1)

        if self.enable_seq_paral:
            # Constrain the activations to be sharded over the data axis only; this
            # triggers an all-gather so each device holds complete sequences.
            hidden = jax.lax.with_sharding_constraint(
                hidden, jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None)
            )

        # The first kernel's output dim and the second kernel's input dim carry
        # NAMED_TP_AXIS, which the logical-axis rules map onto the model mesh axis.
        dense_axes = (
            ((NAMED_BROADCAST_AXIS, NAMED_TP_AXIS), (NAMED_TP_AXIS,)),
            ((NAMED_TP_AXIS, NAMED_BROADCAST_AXIS), (NAMED_BROADCAST_AXIS,)),
        )
        for kernel_axes, bias_axes in dense_axes:
            hidden = te_flax.DenseGeneral(
                features=256, kernel_axes=kernel_axes, bias_axes=bias_axes
            )(hidden)

        return nn.Dense(features=2)(hidden)


def train_step(state, inputs, masks, labels, var_collect, rngs):
    """Run one optimisation step on a single batch.

    Gradients are taken with respect to every variable collection. The
    ``PARAMS_KEY`` part is applied to ``state`` through its optimizer. The rest
    of the gradient tree is returned in place of ``var_collect``.

    NOTE(review): returning the non-parameter gradients as the next
    ``var_collect`` presumably relies on Transformer Engine encoding updated
    FP8 metadata there; confirm against the TE docs.

    Returns:
        ``(new_state, loss, accuracy, new_var_collect)``.
    """
    all_vars = {**var_collect, PARAMS_KEY: state.params}

    def compute_loss(variables, disable_dropout=False):
        logits = state.apply_fn(variables, inputs, masks, disable_dropout, rngs=rngs)
        targets = jax.nn.one_hot(labels.astype(jnp.int32), 2)
        loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=targets))
        return loss, logits

    (loss, logits), all_grads = jax.value_and_grad(compute_loss, has_aux=True)(all_vars)
    accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)

    new_var_collect, param_grads = flax.core.pop(all_grads, PARAMS_KEY)
    new_state = state.apply_gradients(grads=param_grads)
    return new_state, loss, accuracy, new_var_collect


def train_epoch(state, train_ds, batch_size, rngs, var_collect, train_fn):
    """Run one shuffled pass over ``train_ds``, calling ``train_fn`` once per batch.

    Rows are permuted with ``rngs[INPUT_KEY]``. The trailing partial batch is
    dropped so every step sees exactly ``batch_size`` rows.

    NOTE(review): the same ``rngs`` dict, and therefore the same dropout key, is
    passed to every step of the epoch.

    Returns:
        ``(state, mean_loss, mean_accuracy, var_collect)``; the means are taken over batches.
    """
    num_samples = len(train_ds["sentence"])
    num_steps = num_samples // batch_size
    shuffled = jax.random.permutation(rngs[INPUT_KEY], num_samples)
    batches = shuffled[: num_steps * batch_size].reshape((num_steps, batch_size))

    losses, accuracies = [], []
    for idx in batches:
        state, loss, accuracy, var_collect = train_fn(
            state,
            train_ds["sentence"][idx, ...],
            train_ds["mask"][idx, ...],
            train_ds["label"][idx, ...],
            var_collect,
            rngs,
        )
        losses.append(loss)
        accuracies.append(accuracy)

    return state, np.mean(losses), np.mean(accuracies), var_collect


def eval_step(state, inputs, masks, labels, var_collect):
    """Compute loss and accuracy on one batch, with dropout disabled.

    Returns:
        ``(loss, accuracy)``, both scalar means over the batch.
    """
    variables = {**var_collect, PARAMS_KEY: state.params}
    # Fourth positional argument is the model's ``disable_dropout`` flag.
    logits = state.apply_fn(variables, inputs, masks, True)
    targets = jax.nn.one_hot(labels.astype(jnp.int32), 2)
    loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=targets))
    accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
    return loss, accuracy


def eval_model(state, test_ds, batch_size, var_collect, eval_fn):
    """Evaluate over consecutive, unshuffled batches of ``test_ds``.

    Only full batches are evaluated; trailing rows that do not fill a batch are skipped.

    Returns:
        ``(mean_loss, mean_accuracy)`` averaged over the evaluated batches.
    """
    num_batches = len(test_ds["sentence"]) // batch_size
    losses = []
    accuracies = []
    for step in range(num_batches):
        window = slice(step * batch_size, (step + 1) * batch_size)
        loss, accuracy = eval_fn(
            state,
            test_ds["sentence"][window],
            test_ds["mask"][window],
            test_ds["label"][window],
            var_collect,
        )
        losses.append(loss)
        accuracies.append(accuracy)
    return np.mean(losses), np.mean(accuracies)


def data_preprocess(dataset, vocab, word_id, max_seq_len):
    """Tokenise sentences into fixed-length id arrays plus attention masks.

    ``vocab`` (word -> id) is extended in place with unseen words, numbered
    consecutively from ``word_id``. Only the first ``max_seq_len`` tokens of each
    sentence are kept, and only those are added to the vocabulary. Shorter
    sentences are zero-padded.

    Returns:
        ``(new_dataset, vocab, next_word_id)``. ``new_dataset`` has these keys:

        - ``"sentence"``: int32, ``(N, max_seq_len)``.
        - ``"label"``: float32.
        - ``"mask"``: uint8, ``(N, 1, max_seq_len, max_seq_len)``. It is 0 where
          both positions lie inside the kept tokens and 1 elsewhere.
    """
    nltk.download("punkt_tab")
    num_samples = len(dataset["sentence"])
    token_ids = np.zeros((num_samples, max_seq_len), dtype=np.int32)
    masks = np.ones((num_samples, max_seq_len, max_seq_len), dtype=np.uint8)

    for row, sentence in enumerate(dataset["sentence"]):
        kept = nltk.word_tokenize(sentence)[:max_seq_len]
        for col, word in enumerate(kept):
            if word not in vocab:
                vocab[word] = word_id
                word_id += 1
            token_ids[row, col] = vocab[word]
        masks[row, : len(kept), : len(kept)] = 0

    new_dataset = {
        "sentence": token_ids,
        "label": dataset["label"].astype(np.float32),
        "mask": masks.reshape((num_samples, 1, max_seq_len, max_seq_len)),
    }
    return new_dataset, vocab, word_id


def get_datasets(max_seq_len):
    """Load and tokenise the GLUE/CoLA train and validation splits.

    Both splits share one vocabulary, so ids assigned while processing the
    training split are reused for the validation split.

    Returns:
        ``(train_ds, test_ds, vocab_size)``, where ``vocab_size`` is the number of
        word ids handed out across both splits.
    """
    vocab = {}
    next_id = 0
    processed = []
    for split in ("train", "validation"):
        raw = load_dataset("glue", "cola", split=split)
        raw.set_format(type="np")
        tokenised, vocab, next_id = data_preprocess(raw, vocab, next_id, max_seq_len)
        processed.append(tokenised)
    train_ds, test_ds = processed
    return train_ds, test_ds, next_id


def check_fp8(state, var_collect, inputs, masks, labels):
    """Assert that the traced ``train_step`` jaxpr contains an FP8 (e5m2 or e4m3) dtype."""
    dropout_rngs = {DROPOUT_KEY: jax.random.PRNGKey(0)}
    jaxpr_text = str(
        jax.make_jaxpr(train_step)(state, inputs, masks, labels, var_collect, dropout_rngs)
    )
    assert any(fp8_dtype in jaxpr_text for fp8_dtype in ("f8_e5m2", "f8_e4m3"))


def train_and_evaluate(args):
    """Train the encoder on GLUE/CoLA with data + tensor parallelism, then evaluate it.

    Builds a ``(data, model)`` device mesh. It uses 2-way tensor parallelism when
    the local device count is even and falls back to a 1x1 mesh otherwise. It
    then initialises sharded parameters and runs ``args.epochs`` epochs of
    training and evaluation.

    Args:
        args: namespace produced by ``encoder_parser``.

    Returns:
        ``[train_loss, train_accuracy, test_loss, test_accuracy]`` from the last
        epoch, or ``None`` when ``args.dry_run`` is set.

    Raises:
        ValueError: if ``args.epochs < 1`` outside a dry run, since no epoch
            results would exist to return.
        AssertionError: if the batch sizes are incompatible with the mesh or
            with the MXFP8 recipe.
    """
    print(args)
    if not args.dry_run and args.epochs < 1:
        raise ValueError(f"--epochs must be at least 1 for a full run, got {args.epochs}")

    jax.config.update("jax_use_shardy_partitioner", args.enable_shardy)

    train_ds, test_ds, num_embed = get_datasets(args.max_seq_len)

    num_gpu = jax.local_device_count()
    num_gpu_tp = 2
    if num_gpu % num_gpu_tp == 0:
        num_gpu_dp = num_gpu // num_gpu_tp
    else:
        # Devices cannot be split 2-way for tensor parallelism: use a 1x1 mesh.
        num_gpu_dp = 1
        num_gpu_tp = 1

    assert args.batch_size % num_gpu_dp == 0, f"Batch size needs to be multiple of {num_gpu_dp}"
    assert (
        args.test_batch_size % num_gpu_dp == 0
    ), f"Test batch size needs to be multiple of {num_gpu_dp}"

    if args.fp8_recipe == "MXFP8BlockScaling":
        # The per-data-parallel-shard batch must be a multiple of 32 for MXFP8.
        # Integer division is exact here because of the divisibility asserts above.
        assert (
            args.batch_size // num_gpu_dp % 32 == 0
        ), "Batch size needs to be multiple of 32 for MXFP8"
        assert (
            args.test_batch_size // num_gpu_dp % 32 == 0
        ), "Test batch size needs to be multiple of 32 for MXFP8"

    fp8_recipe = get_fp8_recipe_from_name_string(args.fp8_recipe) if args.use_fp8 else None

    device_mesh = mesh_utils.create_device_mesh((num_gpu_dp, num_gpu_tp))
    with jax.sharding.Mesh(
        devices=device_mesh, axis_names=(DEVICE_DP_AXIS, DEVICE_TP_AXIS)
    ) as mesh, te.fp8_autocast(
        enabled=args.use_fp8,
        fp8_recipe=fp8_recipe,
        mesh_resource=te.MeshResource(DEVICE_DP_AXIS, DEVICE_TP_AXIS, None, None),
    ):
        rng = jax.random.PRNGKey(args.seed)
        rng, params_rng = jax.random.split(rng)
        rng, dropout_rng = jax.random.split(rng)
        init_rngs = {PARAMS_KEY: params_rng, DROPOUT_KEY: dropout_rng}

        input_shape = [args.batch_size, args.max_seq_len]
        mask_shape = [args.batch_size, 1, args.max_seq_len, args.max_seq_len]
        label_shape = [args.batch_size]

        # Get the base axis rules and extend them with TE's rules. This must be done
        # inside fp8_autocast.
        axis_rules = flax.linen.get_logical_axis_rules()
        axis_rules += ((NAMED_BROADCAST_AXIS, None), (NAMED_TP_AXIS, DEVICE_TP_AXIS))
        te_extended_axis_rules = te_flax.extend_logical_axis_rules(axis_rules)

        with flax.linen.logical_axis_rules(te_extended_axis_rules):

            print(f"Device mesh: {mesh}")
            print(f"Axis rules: {te_extended_axis_rules}")

            encoder = Net(num_embed, args.enable_sp)
            inputs = jnp.zeros(input_shape, dtype=jnp.int32)
            masks = jnp.zeros(mask_shape, dtype=jnp.uint8)
            abs_var_collect = jax.eval_shape(encoder.init, init_rngs, inputs, masks)

            logical_partition_spec = nn.get_partition_spec(abs_var_collect)

            # Note that `nn.logical_to_mesh_sharding` returns a dict with an extra
            # "params" key that contains the sharding for the parameters.
            params_sharding = nn.logical_to_mesh_sharding(
                logical_partition_spec, mesh, te_extended_axis_rules
            )

            inputs_sharding = NamedSharding(mesh, PartitionSpec(DEVICE_DP_AXIS, None))
            masks_sharding = NamedSharding(mesh, PartitionSpec(DEVICE_DP_AXIS, None, None, None))

            in_shardings = (None, inputs_sharding, masks_sharding)
            # Only the parameter collection gets an explicit output sharding. Keys are
            # compared with `==`; `is` would test string identity, not equality.
            out_shardings = {
                key: params_sharding[PARAMS_KEY] if key == PARAMS_KEY else None
                for key in abs_var_collect
            }
            jit_encoder_init = jax.jit(
                encoder.init, in_shardings=in_shardings, out_shardings=out_shardings
            )
            var_collect = jit_encoder_init(init_rngs, inputs, masks)

            # Check if params are sufficiently sharded after initialization
            assert_params_sufficiently_sharded(var_collect, mesh, print_info=False)

            optimizer = optax.adamw(args.lr)
            var_collect, params = flax.core.pop(var_collect, PARAMS_KEY)
            state = train_state.TrainState.create(
                apply_fn=encoder.apply, params=params, tx=optimizer
            )

            abs_state = jax.eval_shape(
                lambda: train_state.TrainState.create(
                    apply_fn=encoder.apply, params=params, tx=optimizer
                )
            )
            logical_state_partition_spec = nn.get_partition_spec(abs_state)
            state_sharding = nn.logical_to_mesh_sharding(
                logical_state_partition_spec, mesh, te_extended_axis_rules
            )

            # Check that params are still sufficiently sharded after building the train state.
            assert_params_sufficiently_sharded(state.params, mesh, print_info=False)

            labels_sharding = NamedSharding(mesh, PartitionSpec(DEVICE_DP_AXIS))

            in_shardings = (
                state_sharding,
                inputs_sharding,
                masks_sharding,
                labels_sharding,
                None,
                None,
            )
            out_shardings = (state_sharding, None, None, None)
            jit_train_step = jax.jit(
                train_step, in_shardings=in_shardings, out_shardings=out_shardings
            )

            in_shardings = (state_sharding, inputs_sharding, masks_sharding, labels_sharding, None)
            out_shardings = (None, None)
            jit_eval_step = jax.jit(
                eval_step, in_shardings=in_shardings, out_shardings=out_shardings
            )

            if args.use_fp8:
                labels = jnp.zeros(label_shape, dtype=jnp.bfloat16)
                check_fp8(state, var_collect, inputs, masks, labels)

            if args.dry_run:
                labels = jnp.zeros(label_shape, dtype=jnp.bfloat16)
                rngs = {DROPOUT_KEY: dropout_rng}
                jit_train_step(state, inputs, masks, labels, var_collect, rngs)
                print("PASSED")
                return None

            for epoch in range(1, args.epochs + 1):
                rng, input_rng = jax.random.split(rng)
                rng, dropout_rng = jax.random.split(rng)
                rngs = {INPUT_KEY: input_rng, DROPOUT_KEY: dropout_rng}

                state, train_loss, train_accuracy, var_collect = train_epoch(
                    state, train_ds, args.batch_size, rngs, var_collect, jit_train_step
                )

                test_loss, test_accuracy = eval_model(
                    state, test_ds, args.test_batch_size, var_collect, jit_eval_step
                )

                print(
                    f"Epoch: {epoch:>2} "
                    f"Train Loss: {train_loss:.6f} "
                    f"Train Accuracy: {train_accuracy:.6f} "
                    f"Test Loss: {test_loss:.6f} "
                    f"Test Accuracy: {test_accuracy:.6f} "
                )

            return [train_loss, train_accuracy, test_loss, test_accuracy]


def encoder_parser(args):
    """Parse the training command-line options.

    Args:
        args: list of argument strings, or ``None`` to make argparse read ``sys.argv[1:]``.

    Returns:
        The parsed ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(description="JAX Encoder Example")
    parser.add_argument(
        "--batch-size",
        type=int,
        default=128,
        metavar="N",
        help="input batch size for training (default: 128)",
    )
    parser.add_argument(
        "--test-batch-size",
        type=int,
        default=128,
        metavar="N",
        help="input batch size for testing (default: 128)",
    )
    parser.add_argument(
        "--max-seq-len",
        type=int,
        default=64,
        metavar="N",
        help="maximum sequence length (default: 64)",
    )
    parser.add_argument(
        "--epochs",
        type=int,
        default=3,
        metavar="N",
        help="number of epochs to train (default: 3)",
    )
    parser.add_argument(
        "--lr",
        type=float,
        default=0.0001,
        metavar="LR",
        help="learning rate (default: 0.0001)",
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        default=False,
        help="quickly check a single pass",
    )
    parser.add_argument("--seed", type=int, default=0, metavar="S", help="random seed (default: 0)")
    parser.add_argument(
        "--use-fp8",
        action="store_true",
        default=False,
        help="Use FP8 for inference and training without recalibration",
    )
    # Takes the recipe name as a value, e.g. DelayedScaling or MXFP8BlockScaling.
    parser.add_argument(
        "--fp8-recipe",
        type=str,
        default="DelayedScaling",
        metavar="RECIPE",
        help="FP8 recipe to use when --use-fp8 is set (default: DelayedScaling)",
    )
    parser.add_argument(
        "--enable-sp", action="store_true", default=False, help="Enable sequence parallelism."
    )
    parser.add_argument(
        "--enable-shardy", action="store_true", default=False, help="Enable Shardy (experimental)."
    )

    return parser.parse_args(args)


class TestEncoder(unittest.TestCase):
    """Encoder unittests"""

    is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.DELAYED_TENSOR_SCALING)
    is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING)

    def setUp(self):
        """Run 5 epochs for testing"""
        self.args = encoder_parser(["--epochs", "5"])

    def _assert_converged(self, actual):
        """Check the final-epoch train loss (``actual[0]``) and accuracy (``actual[1]``).

        Uses unittest assertion methods rather than bare ``assert`` so the checks
        still run under ``python -O``.
        """
        self.assertLess(actual[0], 0.43)
        self.assertGreater(actual[1], 0.80)

    @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
    def test_te_bf16(self):
        """Test Transformer Engine with BF16"""
        self._assert_converged(train_and_evaluate(self.args))

    @unittest.skipIf(not is_fp8_supported, fp8_reason)
    def test_te_delayed_scaling_fp8(self):
        """Test Transformer Engine with DelayedScaling FP8"""
        self.args.use_fp8 = True
        self.args.fp8_recipe = "DelayedScaling"
        self._assert_converged(train_and_evaluate(self.args))

    @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason)
    def test_te_mxfp8(self):
        """Test Transformer Engine with MXFP8"""
        self.args.use_fp8 = True
        self.args.fp8_recipe = "MXFP8BlockScaling"
        self._assert_converged(train_and_evaluate(self.args))

    @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
    def test_te_bf16_with_sp(self):
        """Test Transformer Engine with BF16 + SP"""
        self.args.enable_sp = True
        self._assert_converged(train_and_evaluate(self.args))

    @unittest.skipIf(not is_fp8_supported, fp8_reason)
    def test_te_delayed_scaling_fp8_with_sp(self):
        """Test Transformer Engine with DelayedScaling FP8 + SP"""
        self.args.enable_sp = True
        self.args.use_fp8 = True
        self.args.fp8_recipe = "DelayedScaling"
        self._assert_converged(train_and_evaluate(self.args))

    @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason)
    def test_te_mxfp8_with_sp(self):
        """Test Transformer Engine with MXFP8 + SP"""
        self.args.enable_sp = True
        self.args.use_fp8 = True
        self.args.fp8_recipe = "MXFP8BlockScaling"
        self._assert_converged(train_and_evaluate(self.args))

    @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
    def test_te_bf16_shardy(self):
        """Test Transformer Engine with BF16 + Shardy"""
        self.args.enable_shardy = True
        self._assert_converged(train_and_evaluate(self.args))

    @unittest.skipIf(not is_fp8_supported, fp8_reason)
    def test_te_delayed_scaling_fp8_shardy(self):
        """Test Transformer Engine with DelayedScaling FP8 + Shardy"""
        self.args.enable_shardy = True
        self.args.use_fp8 = True
        self.args.fp8_recipe = "DelayedScaling"
        self._assert_converged(train_and_evaluate(self.args))

    @unittest.skipIf(not is_fp8_supported, fp8_reason)
    def test_te_delayed_scaling_fp8_with_sp_shardy(self):
        """Test Transformer Engine with DelayedScaling FP8 + SP + Shardy"""
        self.args.enable_shardy = True
        self.args.enable_sp = True
        self.args.use_fp8 = True
        self.args.fp8_recipe = "DelayedScaling"
        self._assert_converged(train_and_evaluate(self.args))

    @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason)
    @unittest.skipIf(
        tex.gemm_uses_jax_dot(), "`jax.nn.scaled_matmul()` does not support the Shardy partitioner."
    )
    def test_te_mxfp8_shardy(self):
        """Test Transformer Engine with MXFP8 + Shardy"""
        self.args.enable_shardy = True
        self.args.use_fp8 = True
        self.args.fp8_recipe = "MXFP8BlockScaling"
        self._assert_converged(train_and_evaluate(self.args))

    @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason)
    @unittest.skipIf(
        tex.gemm_uses_jax_dot(), "`jax.nn.scaled_matmul()` does not support the Shardy partitioner."
    )
    def test_te_mxfp8_with_sp_shardy(self):
        """Test Transformer Engine with MXFP8 + SP + Shardy"""
        self.args.enable_shardy = True
        self.args.enable_sp = True
        self.args.use_fp8 = True
        self.args.fp8_recipe = "MXFP8BlockScaling"
        self._assert_converged(train_and_evaluate(self.args))

if __name__ == "__main__":
    # Script entry point: parse the process's command line and run training/evaluation.
    cli_args = encoder_parser(None)
    train_and_evaluate(cli_args)