# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Encoder training on multi-GPU with tesnor parallelism"""
import argparse
import unittest
from functools import partial

import flax
import jax
import jax.numpy as jnp
import nltk
import numpy as np
import optax
from datasets import load_dataset
from flax import linen as nn
from flax.linen import partitioning as nn_partitioning
from flax.training import train_state
from jax.experimental import mesh_utils
from jax.sharding import PartitionSpec, NamedSharding

import transformer_engine.jax as te
import transformer_engine.jax.flax as te_flax

from common import is_bf16_supported

DEVICE_DP_AXIS = "data"
DEVICE_TP_AXIS = "model"
NAMED_BROADCAST_AXIS = "my_broadcast_axis"
NAMED_TP_AXIS = "my_tp_axis"
PARAMS_KEY = "params"
PARAMS_AXES_KEY = PARAMS_KEY + "_axes"
DROPOUT_KEY = "dropout"
INPUT_KEY = "input_rng"


class Net(nn.Module):
    """NLP Encoder"""

    num_embed: int
    enable_seq_paral: bool

    @nn.compact
    def __call__(self, x, mask, disable_dropout=False):
        x = nn.Embed(num_embeddings=self.num_embed, features=256, dtype=jnp.bfloat16)(x)

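        # Configure a Transformer Engine encoder block. TE shards its internal
        # attention/MLP weights over the tensor-parallel mesh axis according to
        # the logical axis rules set up in train_and_evaluate().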
        te_Encoder = partial(
            te_flax.TransformerLayer,
            hidden_size=256,
            mlp_hidden_size=1024,
            num_attention_heads=8,
            hidden_dropout=0.1,
            attention_dropout=0.1,
            dropout_rng_name=DROPOUT_KEY,
            layer_type=te_flax.TransformerLayerType.ENCODER,
            self_attn_mask_type="padding",
            enable_relative_embedding=False,
            enable_sequence_parallel=self.enable_seq_paral,
            dtype=jnp.bfloat16,
        )
        x = te_Encoder()(x, attention_mask=mask, deterministic=disable_dropout)

        x = x.reshape(x.shape[0], -1)

        if self.enable_seq_paral:
            # Trigger an all-gather so each device holds the complete tensor
            # along the sequence dimension.
            x = jax.lax.with_sharding_constraint(
                x, jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None)
            )

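        # Column-parallel projection: the kernel's output features and the bias
        # are sharded over the tensor-parallel axis, so each device computes a
        # slice of the 256 outputs.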
        x = te_flax.DenseGeneral(
            features=256,
            kernel_axes=(NAMED_BROADCAST_AXIS, NAMED_TP_AXIS),
            bias_axes=(NAMED_TP_AXIS,),
            dtype=jnp.bfloat16,
        )(x)

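        # Row-parallel projection: the kernel's input dimension is sharded over
        # the tensor-parallel axis, completing the Megatron-style column/row
        # pairing; the partitioner inserts the cross-device reduction.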
        x = te_flax.DenseGeneral(
            features=256,
            kernel_axes=(NAMED_TP_AXIS, NAMED_BROADCAST_AXIS),
            bias_axes=(NAMED_BROADCAST_AXIS,),
            dtype=jnp.bfloat16,
        )(x)

        x = nn.Dense(features=2, dtype=jnp.bfloat16)(x)
        return x


def train_step(state, inputs, masks, labels, var_collect, rngs):
    """Computes gradients, loss and accuracy for a single batch."""

    def loss_fn(var_collect, disable_dropout=False):
        logits = state.apply_fn(var_collect, inputs, masks, disable_dropout, rngs=rngs)
        one_hot = jax.nn.one_hot(labels, 2)
        loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
        return loss, logits

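    # Merge the trainable params back into the full variable collection, which
    # also carries TE's FP8 metadata, before differentiating.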
    var_collect = {**var_collect, PARAMS_KEY: state.params}
    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (loss, logits), grads = grad_fn(var_collect)
    accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)

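    # Split the gradient pytree: the params entry drives the optimizer update,
    # while the remaining entries (e.g. updated FP8 scaling metadata) become
    # the new var_collect carried to the next step.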
    var_collect, grads = flax.core.pop(grads, PARAMS_KEY)
    state = state.apply_gradients(grads=grads)

    return state, loss, accuracy, var_collect


def train_epoch(state, train_ds, batch_size, rngs, var_collect, train_fn):
    """Train for a single epoch."""
    train_ds_size = len(train_ds["sentence"])
    steps_per_epoch = train_ds_size // batch_size
    perms = jax.random.permutation(rngs[INPUT_KEY], train_ds_size)
    perms = perms[: steps_per_epoch * batch_size]  # skip incomplete batch
    perms = perms.reshape((steps_per_epoch, batch_size))
    epoch_loss = []
    epoch_accuracy = []

    for perm in perms:
        batch_inputs = train_ds["sentence"][perm, ...]
        batch_masks = train_ds["mask"][perm, ...]
        batch_labels = train_ds["label"][perm, ...]
        state, loss, accuracy, var_collect = train_fn(
            state, batch_inputs, batch_masks, batch_labels, var_collect, rngs
        )
        epoch_loss.append(loss)
        epoch_accuracy.append(accuracy)

    avg_loss = np.mean(epoch_loss)
    avg_accuracy = np.mean(epoch_accuracy)
    return state, avg_loss, avg_accuracy, var_collect


def eval_step(state, inputs, masks, labels, var_collect):
    """Computes loss and accuracy for a single batch."""

    def loss_fn(var_collect, disable_dropout=False):
        logits = state.apply_fn(var_collect, inputs, masks, disable_dropout)
        one_hot = jax.nn.one_hot(labels, 2)
        loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
        return loss, logits

    var_collect = {**var_collect, PARAMS_KEY: state.params}
    loss, logits = loss_fn(var_collect, disable_dropout=True)
    accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
    return loss, accuracy


def eval_model(state, test_ds, batch_size, var_collect, eval_fn):
    """Evaluation loop."""
    test_ds_size = len(test_ds["sentence"])
    num_steps = test_ds_size // batch_size
    valid_size = num_steps * batch_size
    all_loss = []
    all_accuracy = []

    for batch_start in range(0, valid_size, batch_size):
        batch_end = batch_start + batch_size
        batch_inputs = test_ds["sentence"][batch_start:batch_end]
        batch_masks = test_ds["mask"][batch_start:batch_end]
        batch_labels = test_ds["label"][batch_start:batch_end]
        loss, accuracy = eval_fn(state, batch_inputs, batch_masks, batch_labels, var_collect)
        all_loss.append(loss)
        all_accuracy.append(accuracy)

    avg_loss = np.mean(all_loss)
    avg_accuracy = np.mean(all_accuracy)
    return avg_loss, avg_accuracy


def data_preprocess(dataset, vocab, word_id, max_seq_len):
    """Convert tokens to numbers."""
    nltk.download("punkt_tab")
    dataset_size = len(dataset["sentence"])
    output = np.zeros((dataset_size, max_seq_len), dtype=np.int32)
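    # Masks start as all 1s (masked); the valid seq_len x seq_len block is
    # zeroed below, matching TE's padding-mask convention (1 = masked out).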
    mask_3d = np.ones((dataset_size, max_seq_len, max_seq_len), dtype=np.uint8)

    for j, sentence in enumerate(dataset["sentence"]):
        tokens = nltk.word_tokenize(sentence)
        tensor = output[j]

        for i, word in enumerate(tokens):
            if i >= max_seq_len:
                break

            if word not in vocab:
                vocab[word] = word_id
                tensor[i] = word_id
                word_id = word_id + 1
            else:
                tensor[i] = vocab[word]

        seq_len = min(len(tokens), max_seq_len)
        mask_2d = mask_3d[j]
        mask_2d[:seq_len, :seq_len] = 0

    new_dataset = {
        "sentence": output,
        "label": dataset["label"].astype(np.float32),
        "mask": mask_3d.reshape((dataset_size, 1, max_seq_len, max_seq_len)),
    }
    return new_dataset, vocab, word_id


def get_datasets(max_seq_len):
    """Load GLUE train and test datasets into memory."""
    vocab = {}
    word_id = 0

    train_ds = load_dataset("glue", "cola", split="train")
    train_ds.set_format(type="np")
    train_ds, vocab, word_id = data_preprocess(train_ds, vocab, word_id, max_seq_len)

    test_ds = load_dataset("glue", "cola", split="validation")
    test_ds.set_format(type="np")
    test_ds, vocab, word_id = data_preprocess(test_ds, vocab, word_id, max_seq_len)
    return train_ds, test_ds, word_id


def check_fp8(state, var_collect, inputs, masks, labels):
    "Check if model includes FP8."
    rngs = {DROPOUT_KEY: jax.random.PRNGKey(0)}
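    # Trace train_step without executing it and look for FP8 primitives
    # (names containing "fp8_") in the resulting jaxpr.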
    assert "fp8_" in str(
        jax.make_jaxpr(train_step)(state, inputs, masks, labels, var_collect, rngs)
    )


def get_params_sharding(sharding_rules, abs_var_collect, mesh):
    """Refer params to create params sharding"""
    rules_dict = dict(sharding_rules)

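    # Translate a param's logical axis names (e.g. NAMED_TP_AXIS) into mesh
    # axis names via the sharding rules, yielding a NamedSharding per param.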
    def to_device_axis(logical_axis):
        partitions = [rules_dict[key] for key in logical_axis]
        return NamedSharding(mesh, PartitionSpec(*partitions))

    params_axes = abs_var_collect.get(PARAMS_AXES_KEY, {})
    params_axes_sharding = jax.tree_util.tree_map(
        to_device_axis, nn_partitioning.get_axis_names(params_axes)
    )
    params_axes_sharding = flax.core.unfreeze(params_axes_sharding)
    params_sharding = jax.tree_util.tree_map(
        lambda x: NamedSharding(mesh, ()), abs_var_collect[PARAMS_KEY]
    )
    params_sharding = {**params_sharding, **params_axes_sharding}
    return params_sharding


def get_state_sharding(state, params_sharding):
    """Refer params_sharding to create state sharding"""

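    # Every dict leaf in the TrainState (params, plus the optimizer moments
    # that mirror params) reuses params_sharding; other leaves stay
    # replicated (None).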
    def replace_params(x):
        return params_sharding if isinstance(x, dict) else None

    state_sharding = jax.tree_util.tree_map(
        replace_params, state, is_leaf=lambda x: isinstance(x, dict)
    )
    return state_sharding


def train_and_evaluate(args):
    """Execute model training and evaluation loop."""
    print(args)
    train_ds, test_ds, num_embed = get_datasets(args.max_seq_len)

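    # Use 2-way tensor parallelism when the local GPU count divides evenly;
    # otherwise fall back to a trivial 1x1 mesh.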
    num_gpu = jax.local_device_count()
    num_gpu_tp = 2
    if num_gpu % num_gpu_tp == 0:
        num_gpu_dp = num_gpu // num_gpu_tp
    else:
        num_gpu_dp = 1
        num_gpu_tp = 1

    assert args.batch_size % num_gpu_dp == 0, f"Batch size needs to be multiple of {num_gpu_dp}"
    assert (
        args.test_batch_size % num_gpu_dp == 0
    ), f"Test batch size needs to be multiple of {num_gpu_dp}"

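    # Build a 2D device mesh: axis "data" shards the batch (data parallel),
    # axis "model" shards weights and activations (tensor parallel).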
    device_mesh = mesh_utils.create_device_mesh((num_gpu_dp, num_gpu_tp))
    with jax.sharding.Mesh(
        devices=device_mesh, axis_names=(DEVICE_DP_AXIS, DEVICE_TP_AXIS)
    ) as mesh:

        rng = jax.random.PRNGKey(args.seed)
        rng, params_rng = jax.random.split(rng)
        rng, dropout_rng = jax.random.split(rng)
        init_rngs = {PARAMS_KEY: params_rng, DROPOUT_KEY: dropout_rng}

        input_shape = [args.batch_size, args.max_seq_len]
        mask_shape = [args.batch_size, 1, args.max_seq_len, args.max_seq_len]
        label_shape = [args.batch_size]

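        # fp8_autocast enables FP8 GEMMs when requested; MeshResource tells TE
        # which mesh axes carry data and tensor parallelism (the remaining
        # entries cover other parallelism modes and are unused here).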
        with te.fp8_autocast(
            args.use_fp8, mesh_resource=te.MeshResource(DEVICE_DP_AXIS, DEVICE_TP_AXIS, None, None)
        ):
            encoder = Net(num_embed, args.enable_sp)
            inputs = jnp.zeros(input_shape, dtype=jnp.int32)
            masks = jnp.zeros(mask_shape, dtype=jnp.uint8)
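            # eval_shape traces init abstractly, returning the variable pytree
            # with only shapes/dtypes so no parameter memory is allocated yet.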
            abs_var_collect = jax.eval_shape(encoder.init, init_rngs, inputs, masks)

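            # Extend TE's built-in logical-axis rules with the two custom axes
            # used by the DenseGeneral layers in Net.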
            customized_rules = ((NAMED_BROADCAST_AXIS, None), (NAMED_TP_AXIS, DEVICE_TP_AXIS))
            sharding_rules = te_flax.extend_logical_axis_rules(tuple()) + customized_rules
            params_sharding = get_params_sharding(sharding_rules, abs_var_collect, mesh)
            inputs_sharding = NamedSharding(mesh, PartitionSpec(DEVICE_DP_AXIS, None))
            masks_sharding = NamedSharding(mesh, PartitionSpec(DEVICE_DP_AXIS, None, None, None))

            in_shardings = (None, inputs_sharding, masks_sharding)
            out_shardings = {
                key: params_sharding if key == PARAMS_KEY else None for key in abs_var_collect
            }
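            # JIT init with explicit output shardings so each parameter is
            # created directly with its target layout on the mesh.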
            jit_encoder_init = jax.jit(encoder.init, in_shardings, out_shardings)
            var_collect = jit_encoder_init(init_rngs, inputs, masks)

            optimizer = optax.adamw(args.lr)
            var_collect, params = flax.core.pop(var_collect, PARAMS_KEY)
            state = train_state.TrainState.create(
                apply_fn=encoder.apply, params=params, tx=optimizer
            )
            state_sharding = get_state_sharding(state, params_sharding)
            labels_sharding = NamedSharding(mesh, PartitionSpec(DEVICE_DP_AXIS))

            in_shardings = (
                state_sharding,
                inputs_sharding,
                masks_sharding,
                labels_sharding,
                None,
                None,
            )
            out_shardings = (state_sharding, None, None, None)
            jit_train_step = jax.jit(train_step, in_shardings, out_shardings)

            in_shardings = (state_sharding, inputs_sharding, masks_sharding, labels_sharding, None)
            out_shardings = (None, None)
            jit_eval_step = jax.jit(eval_step, in_shardings, out_shardings)

            if args.use_fp8:
                labels = jnp.zeros(label_shape, dtype=jnp.bfloat16)
                check_fp8(state, var_collect, inputs, masks, labels)

            if args.dry_run:
                labels = jnp.zeros(label_shape, dtype=jnp.bfloat16)
                rngs = {DROPOUT_KEY: dropout_rng}
                jit_train_step(state, inputs, masks, labels, var_collect, rngs)
                print("PASSED")
                return None

            for epoch in range(1, args.epochs + 1):
                rng, input_rng = jax.random.split(rng)
                rng, dropout_rng = jax.random.split(rng)
                rngs = {INPUT_KEY: input_rng, DROPOUT_KEY: dropout_rng}

                state, train_loss, train_accuracy, var_collect = train_epoch(
                    state, train_ds, args.batch_size, rngs, var_collect, jit_train_step
                )

                test_loss, test_accuracy = eval_model(
                    state, test_ds, args.test_batch_size, var_collect, jit_eval_step
                )

                print(
                    f"Epoch: {epoch:>2} "
                    f"Train Loss: {train_loss:.6f} "
                    f"Train Accuracy: {train_accuracy:.6f} "
                    f"Test Loss: {test_loss:.6f} "
                    f"Test Accuracy: {test_accuracy:.6f} "
                )

            return [train_loss, train_accuracy, test_loss, test_accuracy]


def encoder_parser(args):
    """Training settings."""
    parser = argparse.ArgumentParser(description="JAX Encoder Example")
    parser.add_argument(
        "--batch-size",
        type=int,
        default=64,
        metavar="N",
        help="input batch size for training (default: 64)",
    )
    parser.add_argument(
        "--test-batch-size",
        type=int,
        default=64,
        metavar="N",
        help="input batch size for testing (default: 64)",
    )
    parser.add_argument(
        "--max-seq-len",
        type=int,
        default=32,
        metavar="N",
        help="maximum sequence length (default: 32)",
    )
    parser.add_argument(
        "--epochs",
        type=int,
        default=3,
        metavar="N",
        help="number of epochs to train (default: 3)",
    )
    parser.add_argument(
        "--lr",
        type=float,
        default=0.0001,
        metavar="LR",
        help="learning rate (default: 0.0001)",
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        default=False,
        help="quickly check a single pass",
    )
    parser.add_argument("--seed", type=int, default=0, metavar="S", help="random seed (default: 0)")
    parser.add_argument(
        "--use-fp8",
        action="store_true",
        default=False,
        help="Use FP8 for inference and training without recalibration",
    )
    parser.add_argument(
        "--enable-sp", action="store_true", default=False, help="Enable sequence parallelism."
    )

    return parser.parse_args(args)


class TestEncoder(unittest.TestCase):
    """Encoder unittests"""

    gpu_has_fp8, reason = te.fp8.is_fp8_available()

    @classmethod
    def setUpClass(cls):
        """Run 3 epochs for testing"""
        cls.args = encoder_parser(["--epochs", "3"])

    @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
    def test_te_bf16(self):
        """Test Transformer Engine with BF16"""
        actual = train_and_evaluate(self.args)
        assert actual[0] < 0.45 and actual[1] > 0.79

    @unittest.skipIf(not gpu_has_fp8, reason)
    def test_te_fp8(self):
        """Test Transformer Engine with FP8"""
        self.args.use_fp8 = True
        actual = train_and_evaluate(self.args)
        assert actual[0] < 0.45 and actual[1] > 0.79

    @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
    def test_te_bf16_sp(self):
        """Test Transformer Engine with BF16 + SP"""
        self.args.enable_sp = True
        actual = train_and_evaluate(self.args)
        assert actual[0] < 0.45 and actual[1] > 0.79

    @unittest.skipIf(not gpu_has_fp8, reason)
    def test_te_fp8_sp(self):
        """Test Transformer Engine with FP8 + SP"""
        self.args.enable_sp = True
        self.args.use_fp8 = True
        actual = train_and_evaluate(self.args)
        assert actual[0] < 0.45 and actual[1] > 0.79


if __name__ == "__main__":
    train_and_evaluate(encoder_parser(None))