# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Encoder training on multi-GPU with tesnor parallelism"""
import argparse
import unittest
from functools import partial

import flax
import jax
import jax.numpy as jnp
import nltk
import numpy as np
import optax
from datasets import load_dataset
from flax import linen as nn
from flax.linen import partitioning as nn_partitioning
from flax.training import train_state
from jax.experimental import mesh_utils
from jax.sharding import PartitionSpec, NamedSharding

import transformer_engine.jax as te
import transformer_engine.jax.flax as te_flax

from common import is_bf16_supported

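# Device mesh axes: "data" shards the batch (data parallel) and "model" shards
# weights (tensor parallel). The NAMED_* logical axes below are mapped onto
# these device axes via the sharding rules built in train_and_evaluate().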
DEVICE_DP_AXIS = "data"
DEVICE_TP_AXIS = "model"
NAMED_BROADCAST_AXIS = "my_broadcast_axis"
NAMED_TP_AXIS = "my_tp_axis"
PARAMS_KEY = "params"
PARAMS_AXES_KEY = PARAMS_KEY + "_axes"
DROPOUT_KEY = "dropout"
INPUT_KEY = "input_rng"


class Net(nn.Module):
    """NLP Encoder"""

    num_embed: int
    enable_seq_paral: bool

    @nn.compact
    def __call__(self, x, mask, disable_dropout=False):
        x = nn.Embed(num_embeddings=self.num_embed, features=256, dtype=jnp.bfloat16)(x)

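        # Configure a Transformer Engine encoder layer (self-attention + MLP).
        # Dropout draws from the DROPOUT_KEY RNG stream, and the padding mask
        # type matches the 0/1 masks built in data_preprocess().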
        te_Encoder = partial(
            te_flax.TransformerLayer,
            hidden_size=256,
            mlp_hidden_size=1024,
            num_attention_heads=8,
            hidden_dropout=0.1,
            attention_dropout=0.1,
            dropout_rng_name=DROPOUT_KEY,
            layer_type=te_flax.TransformerLayerType.ENCODER,
            self_attn_mask_type="padding",
            enable_relative_embedding=False,
            enable_sequence_parallel=self.enable_seq_paral,
        )
        x = te_Encoder()(x, attention_mask=mask, deterministic=disable_dropout)

        x = x.reshape(x.shape[0], -1)

        if self.enable_seq_paral:
            # Trigger an all-gather to collect the complete tensor along the sequence dimension on each device.
            x = jax.lax.with_sharding_constraint(
                x, jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None)
            )

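        # A column-parallel / row-parallel pair: the first kernel is sharded
        # over the TP axis along its output dimension, the second along its
        # input dimension, so the intermediate activation stays partitioned.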
        x = te_flax.DenseGeneral(
            features=256,
            kernel_axes=(NAMED_BROADCAST_AXIS, NAMED_TP_AXIS),
            bias_axes=(NAMED_TP_AXIS,),
        )(x)

        x = te_flax.DenseGeneral(
            features=256,
            kernel_axes=(NAMED_TP_AXIS, NAMED_BROADCAST_AXIS),
            bias_axes=(NAMED_BROADCAST_AXIS,),
        )(x)

        x = nn.Dense(features=2)(x)
        return x


def train_step(state, inputs, masks, labels, var_collect, rngs):
    """Computes gradients, loss and accuracy for a single batch."""

    def loss_fn(var_collect, disable_dropout=False):
        logits = state.apply_fn(var_collect, inputs, masks, disable_dropout, rngs=rngs)
        one_hot = jax.nn.one_hot(labels.astype(jnp.int32), 2)
        loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
        return loss, logits

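    # Re-insert the current params so apply_fn sees the full variable
    # collection (which also carries FP8 metadata when FP8 is enabled).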
    var_collect = {**var_collect, PARAMS_KEY: state.params}
    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (loss, logits), grads = grad_fn(var_collect)
    accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)

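    # Split the parameter grads from the remaining collections (e.g. updated
    # FP8 metadata); the remainder is threaded back out as var_collect.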
    var_collect, grads = flax.core.pop(grads, PARAMS_KEY)
    state = state.apply_gradients(grads=grads)

    return state, loss, accuracy, var_collect


def train_epoch(state, train_ds, batch_size, rngs, var_collect, train_fn):
    """Train for a single epoch."""
    train_ds_size = len(train_ds["sentence"])
    steps_per_epoch = train_ds_size // batch_size
    perms = jax.random.permutation(rngs[INPUT_KEY], train_ds_size)
    perms = perms[: steps_per_epoch * batch_size]  # skip incomplete batch
    perms = perms.reshape((steps_per_epoch, batch_size))
    epoch_loss = []
    epoch_accuracy = []

    for perm in perms:
        batch_inputs = train_ds["sentence"][perm, ...]
        batch_masks = train_ds["mask"][perm, ...]
        batch_labels = train_ds["label"][perm, ...]
        state, loss, accuracy, var_collect = train_fn(
            state, batch_inputs, batch_masks, batch_labels, var_collect, rngs
        )
        epoch_loss.append(loss)
        epoch_accuracy.append(accuracy)

    avg_loss = np.mean(epoch_loss)
    avg_accuracy = np.mean(epoch_accuracy)
    return state, avg_loss, avg_accuracy, var_collect


def eval_step(state, inputs, masks, labels, var_collect):
    """Computes loss and accuracy for a single batch."""

    def loss_fn(var_collect, disable_dropout=False):
        logits = state.apply_fn(var_collect, inputs, masks, disable_dropout)
        one_hot = jax.nn.one_hot(labels.astype(jnp.int32), 2)
        loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
        return loss, logits

    var_collect = {**var_collect, PARAMS_KEY: state.params}
    loss, logits = loss_fn(var_collect, disable_dropout=True)
    accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
    return loss, accuracy


def eval_model(state, test_ds, batch_size, var_collect, eval_fn):
    """Evaluation loop."""
    test_ds_size = len(test_ds["sentence"])
    num_steps = test_ds_size // batch_size
    valid_size = num_steps * batch_size
    all_loss = []
    all_accuracy = []

    for batch_start in range(0, valid_size, batch_size):
        batch_end = batch_start + batch_size
        batch_inputs = test_ds["sentence"][batch_start:batch_end]
        batch_masks = test_ds["mask"][batch_start:batch_end]
        batch_labels = test_ds["label"][batch_start:batch_end]
        loss, accuracy = eval_fn(state, batch_inputs, batch_masks, batch_labels, var_collect)
        all_loss.append(loss)
        all_accuracy.append(accuracy)

    avg_loss = np.mean(all_loss)
    avg_accuracy = np.mean(all_accuracy)
    return avg_loss, avg_accuracy


def data_preprocess(dataset, vocab, word_id, max_seq_len):
    """Convert tokens to numbers."""
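    # Fetch the tokenizer data required by nltk.word_tokenize.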
    nltk.download("punkt_tab")
    dataset_size = len(dataset["sentence"])
    output = np.zeros((dataset_size, max_seq_len), dtype=np.int32)
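    # Attention mask convention: 1 = masked (padding), 0 = attend.
    # Valid token positions are zeroed out per sentence below.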
    mask_3d = np.ones((dataset_size, max_seq_len, max_seq_len), dtype=np.uint8)

    for j, sentence in enumerate(dataset["sentence"]):
        tokens = nltk.word_tokenize(sentence)
        tensor = output[j]

        for i, word in enumerate(tokens):
            if i >= max_seq_len:
                break

            if word not in vocab:
                vocab[word] = word_id
                tensor[i] = word_id
                word_id = word_id + 1
            else:
                tensor[i] = vocab[word]

        seq_len = min(len(tokens), max_seq_len)
        mask_2d = mask_3d[j]
        mask_2d[:seq_len, :seq_len] = 0

    new_dataset = {
        "sentence": output,
        "label": dataset["label"].astype(np.float32),
        "mask": mask_3d.reshape((dataset_size, 1, max_seq_len, max_seq_len)),
    }
    return new_dataset, vocab, word_id


def get_datasets(max_seq_len):
    """Load GLUE train and test datasets into memory."""
    vocab = {}
    word_id = 0

    train_ds = load_dataset("glue", "cola", split="train")
    train_ds.set_format(type="np")
    train_ds, vocab, word_id = data_preprocess(train_ds, vocab, word_id, max_seq_len)

    test_ds = load_dataset("glue", "cola", split="validation")
    test_ds.set_format(type="np")
    test_ds, vocab, word_id = data_preprocess(test_ds, vocab, word_id, max_seq_len)
    return train_ds, test_ds, word_id


def check_fp8(state, var_collect, inputs, masks, labels):
    "Check if model includes FP8."
    rngs = {DROPOUT_KEY: jax.random.PRNGKey(0)}
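    # Trace one train step; when FP8 is active, the jaxpr contains
    # "fp8_"-prefixed primitives.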
    assert "fp8_" in str(
        jax.make_jaxpr(train_step)(state, inputs, masks, labels, var_collect, rngs)
    )


def get_params_sharding(sharding_rules, abs_var_collect, mesh):
    """Refer params to create params sharding"""
    rules_dict = dict(sharding_rules)

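    # Translate a parameter's logical axis names into device mesh axes via
    # the rules and wrap the result in a NamedSharding.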
    def to_device_axis(logical_axis):
        partitions = [rules_dict[key] for key in logical_axis]
        return NamedSharding(mesh, PartitionSpec(*partitions))

    params_axes = abs_var_collect.get(PARAMS_AXES_KEY, {})
    params_axes_sharding = jax.tree_util.tree_map(
        to_device_axis, nn_partitioning.get_axis_names(params_axes)
    )
    params_axes_sharding = flax.core.unfreeze(params_axes_sharding)
    params_sharding = jax.tree_util.tree_map(
        lambda x: NamedSharding(mesh, PartitionSpec(None)), abs_var_collect[PARAMS_KEY]
    )
    params_sharding = {**params_sharding, **params_axes_sharding}
    return params_sharding


def get_state_sharding(state, params_sharding):
    """Refer params_sharding to create state sharding"""

    def replace_params(x):
        return params_sharding if isinstance(x, dict) else None

    state_sharding = jax.tree_util.tree_map(
        replace_params, state, is_leaf=lambda x: isinstance(x, dict)
    )
    return state_sharding


def train_and_evaluate(args):
    """Execute model training and evaluation loop."""
    print(args)
    train_ds, test_ds, num_embed = get_datasets(args.max_seq_len)

    num_gpu = jax.local_device_count()
    num_gpu_tp = 2
    if num_gpu % num_gpu_tp == 0:
        num_gpu_dp = num_gpu // num_gpu_tp
    else:
        num_gpu_dp = 1
        num_gpu_tp = 1

    assert args.batch_size % num_gpu_dp == 0, f"Batch size needs to be a multiple of {num_gpu_dp}"
    assert (
        args.test_batch_size % num_gpu_dp == 0
    ), f"Test batch size needs to be a multiple of {num_gpu_dp}"

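    # Build a (dp, tp) device mesh; the axis order matches
    # (DEVICE_DP_AXIS, DEVICE_TP_AXIS) in the Mesh below.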
    device_mesh = mesh_utils.create_device_mesh((num_gpu_dp, num_gpu_tp))
    with jax.sharding.Mesh(
        devices=device_mesh, axis_names=(DEVICE_DP_AXIS, DEVICE_TP_AXIS)
    ) as mesh:

        rng = jax.random.PRNGKey(args.seed)
        rng, params_rng = jax.random.split(rng)
        rng, dropout_rng = jax.random.split(rng)
        init_rngs = {PARAMS_KEY: params_rng, DROPOUT_KEY: dropout_rng}

        input_shape = [args.batch_size, args.max_seq_len]
        mask_shape = [args.batch_size, 1, args.max_seq_len, args.max_seq_len]
        label_shape = [args.batch_size]

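        # fp8_autocast enables FP8 execution when requested and registers the
        # mesh axes TE should use for data and tensor parallelism (the
        # remaining MeshResource fields are unused here).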
        with te.fp8_autocast(
            args.use_fp8, mesh_resource=te.MeshResource(DEVICE_DP_AXIS, DEVICE_TP_AXIS, None, None)
        ):
            encoder = Net(num_embed, args.enable_sp)
            inputs = jnp.zeros(input_shape, dtype=jnp.int32)
            masks = jnp.zeros(mask_shape, dtype=jnp.uint8)
            abs_var_collect = jax.eval_shape(encoder.init, init_rngs, inputs, masks)

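            # Extend TE's built-in logical-axis rules with this example's axes:
            # NAMED_BROADCAST_AXIS is replicated (None) and NAMED_TP_AXIS maps
            # to the "model" mesh axis.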
            customized_rules = ((NAMED_BROADCAST_AXIS, None), (NAMED_TP_AXIS, DEVICE_TP_AXIS))
            sharding_rules = te_flax.extend_logical_axis_rules(tuple()) + customized_rules
            params_sharding = get_params_sharding(sharding_rules, abs_var_collect, mesh)
            inputs_sharding = NamedSharding(mesh, PartitionSpec(DEVICE_DP_AXIS, None))
            masks_sharding = NamedSharding(mesh, PartitionSpec(DEVICE_DP_AXIS, None, None, None))

            in_shardings = (None, inputs_sharding, masks_sharding)
            out_shardings = {
                key: params_sharding if key == PARAMS_KEY else None for key in abs_var_collect
            }
            jit_encoder_init = jax.jit(encoder.init, in_shardings, out_shardings)
            var_collect = jit_encoder_init(init_rngs, inputs, masks)

            optimizer = optax.adamw(args.lr)
            var_collect, params = flax.core.pop(var_collect, PARAMS_KEY)
            state = train_state.TrainState.create(
                apply_fn=encoder.apply, params=params, tx=optimizer
            )
            state_sharding = get_state_sharding(state, params_sharding)
            labels_sharding = NamedSharding(mesh, PartitionSpec(DEVICE_DP_AXIS))

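            # Shard train_step: the state follows state_sharding, batch inputs
            # are split along the data axis, and var_collect / rngs are left
            # unconstrained (None).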
            in_shardings = (
                state_sharding,
                inputs_sharding,
                masks_sharding,
                labels_sharding,
                None,
                None,
            )
            out_shardings = (state_sharding, None, None, None)
            jit_train_step = jax.jit(train_step, in_shardings, out_shardings)

            in_shardings = (state_sharding, inputs_sharding, masks_sharding, labels_sharding, None)
            out_shardings = (None, None)
            jit_eval_step = jax.jit(eval_step, in_shardings, out_shardings)

            if args.use_fp8:
                labels = jnp.zeros(label_shape, dtype=jnp.bfloat16)
                check_fp8(state, var_collect, inputs, masks, labels)

            if args.dry_run:
                labels = jnp.zeros(label_shape, dtype=jnp.bfloat16)
                rngs = {DROPOUT_KEY: dropout_rng}
                jit_train_step(state, inputs, masks, labels, var_collect, rngs)
                print("PASSED")
                return None

            for epoch in range(1, args.epochs + 1):
                rng, input_rng = jax.random.split(rng)
                rng, dropout_rng = jax.random.split(rng)
                rngs = {INPUT_KEY: input_rng, DROPOUT_KEY: dropout_rng}

                state, train_loss, train_accuracy, var_collect = train_epoch(
                    state, train_ds, args.batch_size, rngs, var_collect, jit_train_step
                )

                test_loss, test_accuracy = eval_model(
                    state, test_ds, args.test_batch_size, var_collect, jit_eval_step
                )

                print(
                    f"Epoch: {epoch:>2} "
                    f"Train Loss: {train_loss:.6f} "
                    f"Train Accuracy: {train_accuracy:.6f} "
                    f"Test Loss: {test_loss:.6f} "
                    f"Test Accuracy: {test_accuracy:.6f} "
                )

            return [train_loss, train_accuracy, test_loss, test_accuracy]


def encoder_parser(args):
    """Training settings."""
    parser = argparse.ArgumentParser(description="JAX Encoder Example")
    parser.add_argument(
        "--batch-size",
        type=int,
        default=64,
        metavar="N",
        help="input batch size for training (default: 64)",
    )
    parser.add_argument(
        "--test-batch-size",
        type=int,
        default=64,
        metavar="N",
        help="input batch size for testing (default: 64)",
    )
    parser.add_argument(
        "--max-seq-len",
        type=int,
        default=32,
        metavar="N",
        help="maximum sequence length (default: 32)",
    )
    parser.add_argument(
        "--epochs",
        type=int,
        default=3,
        metavar="N",
        help="number of epochs to train (default: 3)",
    )
    parser.add_argument(
        "--lr",
        type=float,
        default=0.0001,
        metavar="LR",
        help="learning rate (default: 0.0001)",
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        default=False,
        help="quickly check a single pass",
    )
    parser.add_argument("--seed", type=int, default=0, metavar="S", help="random seed (default: 0)")
    parser.add_argument(
        "--use-fp8",
        action="store_true",
        default=False,
        help="Use FP8 for inference and training without recalibration",
    )
    parser.add_argument(
        "--enable-sp", action="store_true", default=False, help="Enable sequence parallelism."
    )

    return parser.parse_args(args)


class TestEncoder(unittest.TestCase):
    """Encoder unittests"""

    gpu_has_fp8, reason = te.fp8.is_fp8_available()

    @classmethod
    def setUpClass(cls):
        """Run 3 epochs for testing"""
        cls.args = encoder_parser(["--epochs", "3"])

    @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
    def test_te_bf16(self):
        """Test Transformer Engine with BF16"""
        actual = train_and_evaluate(self.args)
        assert actual[0] < 0.45 and actual[1] > 0.79

    @unittest.skipIf(not gpu_has_fp8, reason)
    def test_te_fp8(self):
        """Test Transformer Engine with FP8"""
        self.args.use_fp8 = True
        actual = train_and_evaluate(self.args)
        assert actual[0] < 0.455 and actual[1] > 0.785

    @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
    def test_te_bf16_sp(self):
        """Test Transformer Engine with BF16 + SP"""
        self.args.enable_sp = True
        actual = train_and_evaluate(self.args)
        assert actual[0] < 0.45 and actual[1] > 0.79

    @unittest.skipIf(not gpu_has_fp8, reason)
    def test_te_fp8_sp(self):
        """Test Transformer Engine with FP8 + SP"""
        self.args.enable_sp = True
        self.args.use_fp8 = True
        actual = train_and_evaluate(self.args)
        assert actual[0] < 0.455 and actual[1] > 0.785


if __name__ == "__main__":
    train_and_evaluate(encoder_parser(None))