# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Encoder training on multi-GPU with tesnor parallelism"""
import argparse
import os
import unittest
from functools import partial

import flax
import jax
import jax.numpy as jnp
import nltk
import numpy as np
import optax
from datasets import load_dataset
from flax import linen as nn
from flax.linen import partitioning as nn_partitioning
from flax.training import train_state
from jax.experimental import mesh_utils
from jax.sharding import PartitionSpec, NamedSharding

from common import (
    is_bf16_supported,
    get_quantization_recipe_from_name_string,
    assert_params_sufficiently_sharded,
    unpack_cached_datasets_if_available,
)
import transformer_engine.jax as te
import transformer_engine.jax.cpp_extensions as tex
import transformer_engine.jax.flax as te_flax
from transformer_engine.jax.quantize import is_scaling_mode_supported, ScalingMode

unpack_cached_datasets_if_available()

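# Physical mesh axes ("data", "model") name the device grid, while the NAMED_*
# logical axes below annotate parameter dimensions; the two are connected via the
# logical axis rules set up later in train_and_evaluate().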
DEVICE_DP_AXIS = "data"
DEVICE_TP_AXIS = "model"
NAMED_BROADCAST_AXIS = "my_broadcast_axis"
NAMED_TP_AXIS = "my_tp_axis"
PARAMS_KEY = "params"
PARAMS_AXES_KEY = PARAMS_KEY + "_axes"
SR_KEY = "sr_rng"
DROPOUT_KEY = "dropout"
INPUT_KEY = "input_rng"


class Net(nn.Module):
    """NLP Encoder"""

    num_embed: int
    enable_seq_paral: bool

    @nn.compact
    def __call__(self, x, mask, disable_dropout=False):
        x = nn.Embed(num_embeddings=self.num_embed, features=256, dtype=jnp.bfloat16)(x)

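        # A small TE TransformerLayer config: hidden size 256, 8 heads, GELU MLP,
        # padding-only self-attention mask. enable_sequence_parallel additionally
        # shards activations along the sequence dimension over the TP mesh axis.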
        te_Encoder = partial(
            te_flax.TransformerLayer,
            hidden_size=256,
            mlp_hidden_size=1024,
            num_attention_heads=8,
            hidden_dropout=0.1,
            attention_dropout=0.1,
            dropout_rng_name=DROPOUT_KEY,
            layer_type=te_flax.TransformerLayerType.ENCODER,
            self_attn_mask_type="padding",
            enable_relative_embedding=False,
            enable_sequence_parallel=self.enable_seq_paral,
            mlp_activations=("gelu", "linear"),
        )
        x = te_Encoder()(x, attention_mask=mask, deterministic=disable_dropout)

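        # Flatten (batch, seqlen, hidden) into (batch, seqlen * hidden) for the MLP head.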
        x = x.reshape(x.shape[0], -1)

        if self.enable_seq_paral:
            # Trigger an all-gather so each device holds the complete tensor along the sequence dimension.
            x = jax.lax.with_sharding_constraint(
                x, jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None)
            )

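        # Megatron-style MLP pair: the first DenseGeneral is column-parallel (kernel
        # output dim annotated with NAMED_TP_AXIS), the second is row-parallel (kernel
        # input dim annotated), so the activation between them stays sharded over TP.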
        x = te_flax.DenseGeneral(
            features=256,
            kernel_axes=(NAMED_BROADCAST_AXIS, NAMED_TP_AXIS),
            bias_axes=(NAMED_TP_AXIS,),
        )(x)

        x = te_flax.DenseGeneral(
            features=256,
            kernel_axes=(NAMED_TP_AXIS, NAMED_BROADCAST_AXIS),
            bias_axes=(NAMED_BROADCAST_AXIS,),
        )(x)

        x = nn.Dense(features=2)(x)
        return x


def train_step(state, inputs, masks, labels, var_collect, rngs):
    """Computes gradients, loss and accuracy for a single batch."""

    def loss_fn(var_collect, disable_dropout=False):
        logits = state.apply_fn(var_collect, inputs, masks, disable_dropout, rngs=rngs)
        one_hot = jax.nn.one_hot(labels.astype(jnp.int32), 2)
        loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
        return loss, logits

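    # Differentiate w.r.t. the whole variable collection (params plus any extra
    # collections such as FP8 scaling metadata); the "params" gradients are popped
    # back out below for the optimizer update.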
    var_collect = {**var_collect, PARAMS_KEY: state.params}
    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (loss, logits), grads = grad_fn(var_collect)
    accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)

    var_collect, grads = flax.core.pop(grads, PARAMS_KEY)
    state = state.apply_gradients(grads=grads)

    return state, loss, accuracy, var_collect


def train_epoch(state, train_ds, batch_size, rngs, var_collect, train_fn):
    """Train for a single epoch."""
    train_ds_size = len(train_ds["sentence"])
    steps_per_epoch = train_ds_size // batch_size
    perms = jax.random.permutation(rngs[INPUT_KEY], train_ds_size)
    perms = perms[: steps_per_epoch * batch_size]  # skip incomplete batch
    perms = perms.reshape((steps_per_epoch, batch_size))
    epoch_loss = []
    epoch_accuracy = []

    for perm in perms:
        # Split and reassign to 'rngs' to ensure unique rng for each step
        rngs = {key: jax.random.split(rngs[key])[1] for key in rngs}
        batch_inputs = train_ds["sentence"][perm, ...]
        batch_masks = train_ds["mask"][perm, ...]
        batch_labels = train_ds["label"][perm, ...]
        state, loss, accuracy, var_collect = train_fn(
            state, batch_inputs, batch_masks, batch_labels, var_collect, rngs
        )
        epoch_loss.append(loss)
        epoch_accuracy.append(accuracy)

    avg_loss = np.mean(epoch_loss)
    avg_accuracy = np.mean(epoch_accuracy)
    return state, avg_loss, avg_accuracy, var_collect


def eval_step(state, inputs, masks, labels, var_collect, rngs):
    """Computes loss and accuracy for a single batch."""

    def loss_fn(var_collect, disable_dropout=False):
        logits = state.apply_fn(var_collect, inputs, masks, disable_dropout, rngs=rngs)
        one_hot = jax.nn.one_hot(labels.astype(jnp.int32), 2)
        loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
        return loss, logits

    var_collect = {**var_collect, PARAMS_KEY: state.params}
    loss, logits = loss_fn(var_collect, disable_dropout=True)
    accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
    return loss, accuracy


def eval_model(state, test_ds, batch_size, var_collect, eval_fn, rngs):
    """Evaluation loop."""
    test_ds_size = len(test_ds["sentence"])
    num_steps = test_ds_size // batch_size
    valid_size = num_steps * batch_size
    all_loss = []
    all_accuracy = []

    for batch_start in range(0, valid_size, batch_size):
        # Split and reassign to 'rngs' to ensure unique rng for each step
        rngs = {key: jax.random.split(rngs[key])[1] for key in rngs}
        batch_end = batch_start + batch_size
        batch_inputs = test_ds["sentence"][batch_start:batch_end]
        batch_masks = test_ds["mask"][batch_start:batch_end]
        batch_labels = test_ds["label"][batch_start:batch_end]
        loss, accuracy = eval_fn(state, batch_inputs, batch_masks, batch_labels, var_collect, rngs)
        all_loss.append(loss)
        all_accuracy.append(accuracy)

    avg_loss = np.mean(all_loss)
    avg_accuracy = np.mean(all_accuracy)
    return avg_loss, avg_accuracy


def data_preprocess(dataset, vocab, word_id, max_seq_len):
    """Convert tokens to numbers."""
    nltk.download("punkt_tab")
    dataset_size = len(dataset["sentence"])
    output = np.zeros((dataset_size, max_seq_len), dtype=np.int32)
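    # Attention mask convention here: 1 = masked (padding), 0 = attend. Start
    # fully masked and clear the valid seq_len x seq_len square per sentence below.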
    mask_3d = np.ones((dataset_size, max_seq_len, max_seq_len), dtype=np.uint8)

    for j, sentence in enumerate(dataset["sentence"]):
        tokens = nltk.word_tokenize(sentence)
        tensor = output[j]

        for i, word in enumerate(tokens):
            if i >= max_seq_len:
                break

            if word not in vocab:
                vocab[word] = word_id
                tensor[i] = word_id
                word_id = word_id + 1
            else:
                tensor[i] = vocab[word]

        seq_len = min(len(tokens), max_seq_len)
        mask_2d = mask_3d[j]
        mask_2d[:seq_len, :seq_len] = 0

    new_dataset = {
        "sentence": output,
        "label": dataset["label"].astype(np.float32),
        "mask": mask_3d.reshape((dataset_size, 1, max_seq_len, max_seq_len)),
    }
    return new_dataset, vocab, word_id


def get_datasets(max_seq_len):
    """Load GLUE train and test datasets into memory."""
    vocab = {}
    word_id = 0

    train_ds = load_dataset("glue", "cola", split="train")
    train_ds.set_format(type="np")
    train_ds, vocab, word_id = data_preprocess(train_ds, vocab, word_id, max_seq_len)

    test_ds = load_dataset("glue", "cola", split="validation")
    test_ds.set_format(type="np")
    test_ds, vocab, word_id = data_preprocess(test_ds, vocab, word_id, max_seq_len)
    return train_ds, test_ds, word_id


def check_fp8(state, var_collect, inputs, masks, labels):
    "Check if model includes FP8."
    rngs = {DROPOUT_KEY: jax.random.PRNGKey(0), SR_KEY: jax.random.PRNGKey(0)}
    func_jaxpr = str(jax.make_jaxpr(train_step)(state, inputs, masks, labels, var_collect, rngs))
    assert "f8_e5m2" in func_jaxpr or "f8_e4m3" in func_jaxpr


def train_and_evaluate(args):
    """Execute model training and evaluation loop."""
    print(args)
    jax.config.update("jax_use_shardy_partitioner", args.enable_shardy)

    train_ds, test_ds, num_embed = get_datasets(args.max_seq_len)

    num_gpu = jax.local_device_count()
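    # Use a fixed TP size of 2 when the device count allows it; otherwise fall
    # back to a 1x1 mesh (no parallelism).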
    num_gpu_tp = 2
    if num_gpu % num_gpu_tp == 0:
        num_gpu_dp = num_gpu // num_gpu_tp
    else:
        num_gpu_dp = 1
        num_gpu_tp = 1

    assert args.batch_size % num_gpu_dp == 0, f"Batch size needs to be multiple of {num_gpu_dp}"
    assert (
        args.test_batch_size % num_gpu_dp == 0
    ), f"Test batch size needs to be multiple of {num_gpu_dp}"

    if args.fp8_recipe == "MXFP8BlockScaling":
        assert (
            args.batch_size // num_gpu_dp % 32 == 0
        ), "Batch size per data-parallel group needs to be a multiple of 32 for MXFP8"
        assert (
            args.test_batch_size // num_gpu_dp % 32 == 0
        ), "Test batch size per data-parallel group needs to be a multiple of 32 for MXFP8"

    if args.use_fp8:
        fp8_recipe = get_quantization_recipe_from_name_string(args.fp8_recipe)
    else:
        fp8_recipe = None

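    # Build a (dp, tp) device grid, then tell TE via MeshResource which mesh axis
    # carries data parallelism and which carries tensor(+sequence) parallelism.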
    device_mesh = mesh_utils.create_device_mesh((num_gpu_dp, num_gpu_tp))
    with jax.sharding.Mesh(
        devices=device_mesh, axis_names=(DEVICE_DP_AXIS, DEVICE_TP_AXIS)
    ) as mesh, te.autocast(
        enabled=args.use_fp8,
        recipe=fp8_recipe,
        mesh_resource=te.MeshResource(
            dp_resource=DEVICE_DP_AXIS,
            tpsp_resource=DEVICE_TP_AXIS,
        ),
    ):
        rng = jax.random.PRNGKey(args.seed)
        rng, params_rng = jax.random.split(rng)
        rng, dropout_rng = jax.random.split(rng)
        rng, sr_rng = jax.random.split(rng)
        init_rngs = {PARAMS_KEY: params_rng, DROPOUT_KEY: dropout_rng, SR_KEY: sr_rng}

        input_shape = [args.batch_size, args.max_seq_len]
        mask_shape = [args.batch_size, 1, args.max_seq_len, args.max_seq_len]
        label_shape = [args.batch_size]

        # Get the base axis rules and extend them with TE's rules. This must be done inside the autocast context.
        axis_rules = flax.linen.get_logical_axis_rules()
        axis_rules += ((NAMED_BROADCAST_AXIS, None), (NAMED_TP_AXIS, DEVICE_TP_AXIS))
        te_extended_axis_rules = te_flax.extend_logical_axis_rules(axis_rules)

        with flax.linen.logical_axis_rules(te_extended_axis_rules):

            print(f"Device mesh: {mesh}")
            print(f"Axis rules: {te_extended_axis_rules}")

            encoder = Net(num_embed, args.enable_sp)
            inputs = jnp.zeros(input_shape, dtype=jnp.int32)
            masks = jnp.zeros(mask_shape, dtype=jnp.uint8)
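            # Abstract init via eval_shape: produces shapes and axis annotations
            # only, so shardings can be derived without materializing parameters.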
            abs_var_collect = jax.eval_shape(encoder.init, init_rngs, inputs, masks)

            logical_partition_spec = nn.get_partition_spec(abs_var_collect)

            # Note that `nn.logical_to_mesh_sharding` returns a dict with an extra
            # "params" key that contains the sharding for the parameters.
            params_sharding = nn.logical_to_mesh_sharding(
                logical_partition_spec, mesh, te_extended_axis_rules
            )

            inputs_sharding = NamedSharding(mesh, PartitionSpec(DEVICE_DP_AXIS, None))
            masks_sharding = NamedSharding(mesh, PartitionSpec(DEVICE_DP_AXIS, None, None, None))

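            # Jit the init with explicit shardings so parameters are created
            # directly in sharded form; None means replicated (the init rngs),
            # while inputs/masks are sharded along the data-parallel axis.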
            in_shardings = (None, inputs_sharding, masks_sharding)
            out_shardings = {
                key: params_sharding[PARAMS_KEY] if key == PARAMS_KEY else None
                for key in abs_var_collect
            }
            jit_encoder_init = jax.jit(
                encoder.init, in_shardings=in_shardings, out_shardings=out_shardings
            )
            var_collect = jit_encoder_init(init_rngs, inputs, masks)

            # Check if params are sufficiently sharded after initialization
            assert_params_sufficiently_sharded(var_collect, mesh, print_info=False)

            optimizer = optax.adamw(args.lr)
            var_collect, params = flax.core.pop(var_collect, PARAMS_KEY)
            state = train_state.TrainState.create(
                apply_fn=encoder.apply, params=params, tx=optimizer
            )

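            # Repeat the eval_shape trick on the full TrainState so the optimizer
            # state picks up the same logical-axis shardings as the parameters.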
            abs_state = jax.eval_shape(
                lambda: train_state.TrainState.create(
                    apply_fn=encoder.apply, params=params, tx=optimizer
                )
            )
            logical_state_partition_spec = nn.get_partition_spec(abs_state)
            state_sharding = nn.logical_to_mesh_sharding(
                logical_state_partition_spec, mesh, te_extended_axis_rules
            )

            # Check if params are sufficiently sharded after jitting the state creation
            assert_params_sufficiently_sharded(state.params, mesh, print_info=False)

            labels_sharding = NamedSharding(mesh, PartitionSpec(DEVICE_DP_AXIS))

            in_shardings = (
                state_sharding,
                inputs_sharding,
                masks_sharding,
                labels_sharding,
                None,
                None,
            )
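            # train_step returns (state, loss, accuracy, var_collect); only the
            # state carries a nontrivial sharding, the rest stay replicated.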
            out_shardings = (state_sharding, None, None, None)
            jit_train_step = jax.jit(
                train_step, in_shardings=in_shardings, out_shardings=out_shardings
            )

            in_shardings = (
                state_sharding,
                inputs_sharding,
                masks_sharding,
                labels_sharding,
                None,
                None,
            )
            out_shardings = (None, None)
            jit_eval_step = jax.jit(
                eval_step, in_shardings=in_shardings, out_shardings=out_shardings
            )

            if args.use_fp8:
                labels = jnp.zeros(label_shape, dtype=jnp.bfloat16)
                check_fp8(state, var_collect, inputs, masks, labels)

            if args.dry_run:
                labels = jnp.zeros(label_shape, dtype=jnp.bfloat16)
                rngs = {DROPOUT_KEY: dropout_rng, SR_KEY: sr_rng}
                jit_train_step(state, inputs, masks, labels, var_collect, rngs)
                print("PASSED")
                return None

            for epoch in range(1, args.epochs + 1):
                # Split and reassign to 'rng' to ensure unique rng for each step
                rng, input_rng = jax.random.split(rng)
                rng, dropout_rng = jax.random.split(rng)
                rng, sr_rng = jax.random.split(rng)
                rngs = {INPUT_KEY: input_rng, DROPOUT_KEY: dropout_rng, SR_KEY: sr_rng}

                state, train_loss, train_accuracy, var_collect = train_epoch(
                    state, train_ds, args.batch_size, rngs, var_collect, jit_train_step
                )

                test_loss, test_accuracy = eval_model(
                    state, test_ds, args.test_batch_size, var_collect, jit_eval_step, rngs
                )

                print(
                    f"Epoch: {epoch:>2} "
                    f"Train Loss: {train_loss:.6f} "
                    f"Train Accuracy: {train_accuracy:.6f} "
                    f"Test Loss: {test_loss:.6f} "
                    f"Test Accuracy: {test_accuracy:.6f} "
                )

            return [train_loss, train_accuracy, test_loss, test_accuracy]


def encoder_parser(args):
    """Training settings."""
    parser = argparse.ArgumentParser(description="JAX Encoder Example")
    parser.add_argument(
        "--batch-size",
        type=int,
        default=256,
        metavar="N",
        help="input batch size for training (default: 256)",
    )
    parser.add_argument(
        "--test-batch-size",
        type=int,
        default=256,
        metavar="N",
        help="input batch size for testing (default: 256)",
    )
    parser.add_argument(
        "--max-seq-len",
        type=int,
        default=64,
        metavar="N",
        help="maximum sequence length (default: 32)",
    )
    parser.add_argument(
        "--epochs",
        type=int,
        default=3,
        metavar="N",
        help="number of epochs to train (default: 3)",
    )
    parser.add_argument(
        "--lr",
        type=float,
        default=0.0001,
        metavar="LR",
        help="learning rate (default: 0.0001)",
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        default=False,
        help="quickly check a single pass",
    )
    parser.add_argument("--seed", type=int, default=0, metavar="S", help="random seed (default: 0)")
    parser.add_argument(
        "--use-fp8",
        action="store_true",
        default=False,
        help="Use FP8 for inference and training without recalibration",
    )
    parser.add_argument(
        "--fp8-recipe",
        type=str,
        default="DelayedScaling",
        help="FP8 recipe to use (default: DelayedScaling)",
    )
    parser.add_argument(
        "--enable-sp", action="store_true", default=False, help="Enable sequence parallelism."
    )
    parser.add_argument(
        "--enable-shardy", action="store_true", default=False, help="Enable Shardy (experimental)."
    )

    return parser.parse_args(args)


class TestEncoder(unittest.TestCase):
    """Encoder unittests"""

    is_fp8_supported, fp8_reason = is_scaling_mode_supported(ScalingMode.DELAYED_TENSOR_SCALING)
    is_mxfp8_supported, mxfp8_reason = is_scaling_mode_supported(ScalingMode.MXFP8_1D_SCALING)
    is_nvfp4_supported, nvfp4_reason = is_scaling_mode_supported(ScalingMode.NVFP4_1D_SCALING)

    def setUp(self):
        """Run 5 epochs for testing"""
        # TODO(jberchtold): Remove once fused attention from cuDNN supports determinism on Blackwell
        if "NVTE_FUSED_ATTN" not in os.environ:
            os.environ["NVTE_FUSED_ATTN"] = "0"
        self.args = encoder_parser(["--epochs", "5"])

    @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
    def test_te_bf16(self):
        """Test Transformer Engine with BF16"""
        actual = train_and_evaluate(self.args)
        assert actual[0] < 0.36 and actual[1] > 0.84

    @unittest.skipIf(not is_fp8_supported, fp8_reason)
    def test_te_delayed_scaling_fp8(self):
        """Test Transformer Engine with DelayedScaling FP8"""
        self.args.use_fp8 = True
        self.args.fp8_recipe = "DelayedScaling"
        actual = train_and_evaluate(self.args)
        assert actual[0] < 0.362 and actual[1] > 0.84

    @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason)
    def test_te_mxfp8(self):
        """Test Transformer Engine with MXFP8"""
        self.args.use_fp8 = True
        self.args.fp8_recipe = "MXFP8BlockScaling"
        actual = train_and_evaluate(self.args)
        assert actual[0] < 0.36 and actual[1] > 0.84

    @unittest.skipIf(not is_nvfp4_supported, nvfp4_reason)
    def test_te_nvfp4(self):
        """Test Transformer Engine with NVFP4"""
        self.args.use_fp8 = True
        self.args.fp8_recipe = "NVFP4BlockScaling"
        actual = train_and_evaluate(self.args)
        assert actual[0] < 0.40 and actual[1] > 0.82

    @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
    def test_te_bf16_with_sp(self):
        """Test Transformer Engine with BF16 + SP"""
        self.args.enable_sp = True
        actual = train_and_evaluate(self.args)
        assert actual[0] < 0.36 and actual[1] > 0.84

    @unittest.skipIf(not is_fp8_supported, fp8_reason)
    def test_te_delayed_scaling_fp8_with_sp(self):
        """Test Transformer Engine with DelayedScaling FP8 + SP"""
        self.args.enable_sp = True
        self.args.use_fp8 = True
        self.args.fp8_recipe = "DelayedScaling"
        actual = train_and_evaluate(self.args)
        assert actual[0] < 0.362 and actual[1] > 0.84

    @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason)
    def test_te_mxfp8_with_sp(self):
        """Test Transformer Engine with MXFP8 + SP"""
        self.args.enable_sp = True
        self.args.use_fp8 = True
        self.args.fp8_recipe = "MXFP8BlockScaling"
        actual = train_and_evaluate(self.args)
        assert actual[0] < 0.36 and actual[1] > 0.84

    @unittest.skipIf(not is_nvfp4_supported, nvfp4_reason)
    def test_te_nvfp4_with_sp(self):
        """Test Transformer Engine with NVFP4"""
        self.args.enable_sp = True
        self.args.use_fp8 = True
        self.args.fp8_recipe = "NVFP4BlockScaling"
        actual = train_and_evaluate(self.args)
        assert actual[0] < 0.40 and actual[1] > 0.82

    @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
    def test_te_bf16_shardy(self):
        """Test Transformer Engine with BF16"""
        self.args.enable_shardy = True
        actual = train_and_evaluate(self.args)
        assert actual[0] < 0.36 and actual[1] > 0.84

    @unittest.skipIf(not is_fp8_supported, fp8_reason)
    def test_te_delayed_scaling_fp8_shardy(self):
        """Test Transformer Engine with DelayedScaling FP8"""
        self.args.enable_shardy = True
        self.args.use_fp8 = True
        self.args.fp8_recipe = "DelayedScaling"
        actual = train_and_evaluate(self.args)
        assert actual[0] < 0.362 and actual[1] > 0.84

    @unittest.skipIf(not is_fp8_supported, fp8_reason)
    def test_te_delayed_scaling_fp8_with_sp_shardy(self):
        """Test Transformer Engine with DelayedScaling FP8 + SP"""
        self.args.enable_shardy = True
        self.args.enable_sp = True
        self.args.use_fp8 = True
        self.args.fp8_recipe = "DelayedScaling"
        actual = train_and_evaluate(self.args)
        assert actual[0] < 0.362 and actual[1] > 0.84

    @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason)
    def test_te_mxfp8_shardy(self):
        """Test Transformer Engine with MXFP8"""
        self.args.enable_shardy = True
        self.args.use_fp8 = True
        self.args.fp8_recipe = "MXFP8BlockScaling"
        actual = train_and_evaluate(self.args)
        assert actual[0] < 0.36 and actual[1] > 0.84

    @unittest.skipIf(not is_nvfp4_supported, nvfp4_reason)
    def test_te_nvfp4_shardy(self):
        """Test Transformer Engine with NVFP4"""
        self.args.enable_shardy = True
        self.args.use_fp8 = True
        self.args.fp8_recipe = "NVFP4BlockScaling"
        actual = train_and_evaluate(self.args)
        assert actual[0] < 0.40 and actual[1] > 0.82

    @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason)
    def test_te_mxfp8_with_sp_shardy(self):
        """Test Transformer Engine with MXFP8 + SP"""
        self.args.enable_shardy = True
        self.args.enable_sp = True
        self.args.use_fp8 = True
        self.args.fp8_recipe = "MXFP8BlockScaling"
        actual = train_and_evaluate(self.args)
        assert actual[0] < 0.36 and actual[1] > 0.84

    @unittest.skipIf(not is_nvfp4_supported, nvfp4_reason)
    def test_te_nvfp4_with_sp_shardy(self):
        """Test Transformer Engine with NVFP4"""
        self.args.enable_shardy = True
        self.args.enable_sp = True
        self.args.use_fp8 = True
        self.args.fp8_recipe = "NVFP4BlockScaling"
        actual = train_and_evaluate(self.args)
        assert actual[0] < 0.40 and actual[1] > 0.82


if __name__ == "__main__":
    train_and_evaluate(encoder_parser(None))