test_multigpu_encoder.py 14.4 KB
Newer Older
1
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
Jeng Bai-Cheng's avatar
Jeng Bai-Cheng committed
2
3
#
# See LICENSE for license information.
4
"""Encoder training on multi-GPU with data parallelism"""
Jeng Bai-Cheng's avatar
Jeng Bai-Cheng committed
5
6
7
8
import argparse
import unittest
from functools import partial

9
import flax
Jeng Bai-Cheng's avatar
Jeng Bai-Cheng committed
10
11
12
13
14
import jax
import jax.numpy as jnp
import nltk
import numpy as np
import optax
15
from datasets import load_dataset
Jeng Bai-Cheng's avatar
Jeng Bai-Cheng committed
16
from flax import linen as nn
17
from flax.linen import partitioning as nn_partitioning
Jeng Bai-Cheng's avatar
Jeng Bai-Cheng committed
18
19
from flax.training import train_state
from jax.experimental import mesh_utils
20
from jax.sharding import PartitionSpec, NamedSharding
Jeng Bai-Cheng's avatar
Jeng Bai-Cheng committed
21
22

import transformer_engine.jax as te
23
import transformer_engine.jax.flax as te_flax
Jeng Bai-Cheng's avatar
Jeng Bai-Cheng committed
24

25
26
27
28
29
DEVICE_DP_AXIS = "data"  # logical mesh axis name used for data parallelism
PARAMS_KEY = "params"  # Flax collection holding trainable parameters
PARAMS_AXES_KEY = PARAMS_KEY + "_axes"  # collection with logical-axis annotations for params
DROPOUT_KEY = "dropout"  # RNG stream name consumed by dropout layers
INPUT_KEY = "input_rng"  # RNG stream name used to shuffle training data each epoch
Jeng Bai-Cheng's avatar
Jeng Bai-Cheng committed
30
31
32
33


class Net(nn.Module):
    """NLP encoder: embedding -> TE transformer layer -> dense head (2 logits)."""

    # Vocabulary size for the embedding table.
    num_embed: int

    @nn.compact
    def __call__(self, x, mask, disable_dropout=False):
        hidden = nn.Embed(num_embeddings=self.num_embed, features=256, dtype=jnp.bfloat16)(x)

        # Pre-bind the Transformer Engine encoder-layer configuration.
        make_encoder = partial(
            te_flax.TransformerLayer,
            hidden_size=256,
            mlp_hidden_size=1024,
            num_attention_heads=8,
            hidden_dropout=0.1,
            attention_dropout=0.1,
            dropout_rng_name=DROPOUT_KEY,
            layer_type=te_flax.TransformerLayerType.ENCODER,
            self_attn_mask_type="padding",
            enable_relative_embedding=False,
            dtype=jnp.bfloat16,
        )
        hidden = make_encoder()(hidden, attention_mask=mask, deterministic=disable_dropout)

        # Flatten everything after the batch dimension for the dense head.
        hidden = hidden.reshape(hidden.shape[0], -1)

        hidden = te_flax.DenseGeneral(features=256, dtype=jnp.bfloat16)(hidden)
        hidden = te_flax.DenseGeneral(features=256, dtype=jnp.bfloat16)(hidden)

        # Two output logits (binary classification).
        return nn.Dense(features=2, dtype=jnp.bfloat16)(hidden)


66
def train_step(state, inputs, masks, labels, var_collect, rngs):
    """Computes gradients, loss and accuracy for a single batch."""

    def forward_loss(packed_vars, disable_dropout=False):
        logits = state.apply_fn(packed_vars, inputs, masks, disable_dropout, rngs=rngs)
        targets = jax.nn.one_hot(labels, 2)
        loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=targets))
        return loss, logits

    # Merge the trainable params back into the full variable collection.
    full_vars = {**var_collect, PARAMS_KEY: state.params}
    (loss, logits), grads = jax.value_and_grad(forward_loss, has_aux=True)(full_vars)
    accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)

    # Params gradients feed the optimizer; the remaining gradient collections
    # are handed back to the caller as the updated var_collect.
    other_collections, param_grads = flax.core.pop(grads, PARAMS_KEY)
    new_state = state.apply_gradients(grads=param_grads)

    return new_state, loss, accuracy, other_collections


86
def train_epoch(state, train_ds, batch_size, rngs, var_collect, train_fn):
    """Train for a single epoch."""
    num_samples = len(train_ds["sentence"])
    num_steps = num_samples // batch_size
    # Shuffle all sample indices, drop the trailing incomplete batch, and
    # group the remainder into (num_steps, batch_size) index batches.
    shuffled = jax.random.permutation(rngs[INPUT_KEY], num_samples)
    index_batches = shuffled[: num_steps * batch_size].reshape((num_steps, batch_size))

    losses = []
    accuracies = []

    for batch_idx in index_batches:
        state, loss, accuracy, var_collect = train_fn(
            state,
            train_ds["sentence"][batch_idx, ...],
            train_ds["mask"][batch_idx, ...],
            train_ds["label"][batch_idx, ...],
            var_collect,
            rngs,
        )
        losses.append(loss)
        accuracies.append(accuracy)

    return state, np.mean(losses), np.mean(accuracies), var_collect


def eval_step(state, inputs, masks, labels, var_collect):
    """Computes loss and accuracy for a single batch."""

    def forward_loss(packed_vars, disable_dropout=False):
        logits = state.apply_fn(packed_vars, inputs, masks, disable_dropout)
        targets = jax.nn.one_hot(labels, 2)
        loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=targets))
        return loss, logits

    # Evaluation runs with dropout disabled on the merged variable collection.
    full_vars = {**var_collect, PARAMS_KEY: state.params}
    loss, logits = forward_loss(full_vars, disable_dropout=True)
    accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
    return loss, accuracy


def eval_model(state, test_ds, batch_size, var_collect, eval_fn):
    """Evaluation loop: average loss/accuracy over full batches of test_ds."""
    num_steps = len(test_ds["sentence"]) // batch_size  # drop incomplete batch
    losses = []
    accuracies = []

    for step in range(num_steps):
        start = step * batch_size
        stop = start + batch_size
        loss, accuracy = eval_fn(
            state,
            test_ds["sentence"][start:stop],
            test_ds["mask"][start:stop],
            test_ds["label"][start:stop],
            var_collect,
        )
        losses.append(loss)
        accuracies.append(accuracy)

    return np.mean(losses), np.mean(accuracies)


def data_preprocess(dataset, vocab, word_id, max_seq_len):
    """Convert tokens to numbers.

    `vocab` and `word_id` are threaded through successive calls so every
    split shares one growing vocabulary. In the returned mask, entries
    covering real tokens are 0 and everything else stays 1.
    """
    nltk.download("punkt_tab")  # make sure the tokenizer model is available
    num_rows = len(dataset["sentence"])
    token_ids = np.zeros((num_rows, max_seq_len), dtype=np.int32)
    masks = np.ones((num_rows, max_seq_len, max_seq_len), dtype=np.uint8)

    for row, sentence in enumerate(dataset["sentence"]):
        tokens = nltk.word_tokenize(sentence)

        # Map each token (truncated to max_seq_len) to its vocabulary id,
        # growing the vocabulary on first sight of a word.
        for col, word in enumerate(tokens[:max_seq_len]):
            if word in vocab:
                token_ids[row, col] = vocab[word]
            else:
                vocab[word] = word_id
                token_ids[row, col] = word_id
                word_id += 1

        # Zero out the square covering the sentence's real tokens.
        seq_len = min(len(tokens), max_seq_len)
        masks[row, :seq_len, :seq_len] = 0

    processed = {
        "sentence": token_ids,
        "label": dataset["label"].astype(np.float32),
        "mask": masks.reshape((num_rows, 1, max_seq_len, max_seq_len)),
    }
    return processed, vocab, word_id
Jeng Bai-Cheng's avatar
Jeng Bai-Cheng committed
180
181
182
183
184
185


def get_datasets(max_seq_len):
    """Load GLUE train and test datasets into memory."""
    vocab = {}
    word_id = 0
    processed = {}

    # Process train first, then validation, so both share one vocabulary
    # built in the same order as the original implementation.
    for split in ("train", "validation"):
        raw = load_dataset("glue", "cola", split=split)
        raw.set_format(type="np")
        processed[split], vocab, word_id = data_preprocess(raw, vocab, word_id, max_seq_len)

    return processed["train"], processed["validation"], word_id


def check_fp8(state, var_collect, inputs, masks, labels):
    """Check if model includes FP8 by inspecting the traced train step."""
    rngs = {DROPOUT_KEY: jax.random.PRNGKey(0)}
    jaxpr = jax.make_jaxpr(train_step)(state, inputs, masks, labels, var_collect, rngs)
    assert "fp8_" in str(jaxpr)
Jeng Bai-Cheng's avatar
Jeng Bai-Cheng committed
203
204


205
206
207
def get_params_sharding(sharding_rules, abs_var_collect, mesh):
    """Refer params to create params sharding"""
    axis_to_partition = dict(sharding_rules)

    def named_sharding_for(logical_axis):
        # Translate each logical axis name into its mesh partition.
        spec = PartitionSpec(*(axis_to_partition[name] for name in logical_axis))
        return NamedSharding(mesh, spec)

    params_axes = abs_var_collect.get(PARAMS_AXES_KEY, {})
    annotated_sharding = jax.tree_util.tree_map(
        named_sharding_for, nn_partitioning.get_axis_names(params_axes)
    )
    annotated_sharding = flax.core.unfreeze(annotated_sharding)

    # Default every param to a replicated sharding, then overlay the
    # annotated shardings derived from the logical-axis metadata.
    default_sharding = jax.tree_util.tree_map(
        lambda _: NamedSharding(mesh, ()), abs_var_collect[PARAMS_KEY]
    )
    return {**default_sharding, **annotated_sharding}
Jeng Bai-Cheng's avatar
Jeng Bai-Cheng committed
223
224


225
226
def get_state_sharding(state, params_sharding):
    """Refer params_sharding to create state sharding"""
    # Dict subtrees hold params; map them to params_sharding, everything
    # else (step counters, opt-state scalars, ...) to None (replicated).
    is_param_subtree = lambda node: isinstance(node, dict)
    return jax.tree_util.tree_map(
        lambda leaf: params_sharding if is_param_subtree(leaf) else None,
        state,
        is_leaf=is_param_subtree,
    )
Jeng Bai-Cheng's avatar
Jeng Bai-Cheng committed
235
236
237
238
239


def train_and_evaluate(args):
    """Execute model training and evaluation loop."""
    print(args)
    train_ds, test_ds, num_embed = get_datasets(args.max_seq_len)

    # Batches must split evenly across all local devices on the data axis.
    num_gpu = jax.local_device_count()
    assert args.batch_size % num_gpu == 0, f"Batch size needs to be multiple of {num_gpu}"
    assert args.test_batch_size % num_gpu == 0, f"Test batch size needs to be multiple of {num_gpu}"

    # 1-D device mesh: data parallelism only (single axis named DEVICE_DP_AXIS).
    device_mesh = mesh_utils.create_device_mesh((num_gpu,))
    with jax.sharding.Mesh(devices=device_mesh, axis_names=(DEVICE_DP_AXIS,)) as mesh:

        # Independent RNG streams for parameter init and dropout.
        rng = jax.random.PRNGKey(args.seed)
        rng, params_rng = jax.random.split(rng)
        rng, dropout_rng = jax.random.split(rng)
        init_rngs = {PARAMS_KEY: params_rng, DROPOUT_KEY: dropout_rng}

        input_shape = [args.batch_size, args.max_seq_len]
        mask_shape = [args.batch_size, 1, args.max_seq_len, args.max_seq_len]
        label_shape = [args.batch_size]

        # fp8_autocast toggles FP8 execution of TE layers; MeshResource tells
        # TE which mesh axis carries data parallelism.
        with te.fp8_autocast(
            args.use_fp8, mesh_resource=te.MeshResource(DEVICE_DP_AXIS, None, None, None)
        ):
            encoder = Net(num_embed)
            inputs = jnp.zeros(input_shape, dtype=jnp.int32)
            masks = jnp.zeros(mask_shape, dtype=jnp.uint8)
            # Shape-only init: yields the variable tree structure without
            # allocating real parameters.
            abs_var_collect = jax.eval_shape(encoder.init, init_rngs, inputs, masks)

            sharding_rules = te_flax.extend_logical_axis_rules(tuple())
            params_sharding = get_params_sharding(sharding_rules, abs_var_collect, mesh)
            # Inputs and masks are sharded along the batch dimension only.
            inputs_sharding = NamedSharding(mesh, PartitionSpec(DEVICE_DP_AXIS, None))
            masks_sharding = NamedSharding(mesh, PartitionSpec(DEVICE_DP_AXIS, None, None, None))

            in_shardings = (None, inputs_sharding, masks_sharding)
            # NOTE(review): `key is PARAMS_KEY` compares string identity and
            # relies on CPython interning of the "params" literal; `==` would
            # be the safer comparison — confirm before relying on this.
            out_shardings = {
                key: params_sharding if key is PARAMS_KEY else None for key in abs_var_collect
            }
            jit_encoder_init = jax.jit(encoder.init, in_shardings, out_shardings)
            var_collect = jit_encoder_init(init_rngs, inputs, masks)

            optimizer = optax.adamw(args.lr)
            # Separate trainable params from the other variable collections.
            var_collect, params = flax.core.pop(var_collect, PARAMS_KEY)
            state = train_state.TrainState.create(
                apply_fn=encoder.apply, params=params, tx=optimizer
            )
            state_sharding = get_state_sharding(state, params_sharding)
            labels_sharding = NamedSharding(
                mesh,
                PartitionSpec(
                    DEVICE_DP_AXIS,
                ),
            )

            # train_step(state, inputs, masks, labels, var_collect, rngs)
            in_shardings = (
                state_sharding,
                inputs_sharding,
                masks_sharding,
                labels_sharding,
                None,
                None,
            )
            out_shardings = (state_sharding, None, None, None)
            jit_train_step = jax.jit(train_step, in_shardings, out_shardings)

            # eval_step(state, inputs, masks, labels, var_collect)
            in_shardings = (state_sharding, inputs_sharding, masks_sharding, labels_sharding, None)
            out_shardings = (None, None)
            jit_eval_step = jax.jit(eval_step, in_shardings, out_shardings)

            # When FP8 is requested, verify the traced train step actually
            # contains FP8 operations before training.
            if args.use_fp8:
                labels = jnp.zeros(label_shape, dtype=jnp.bfloat16)
                check_fp8(state, var_collect, inputs, masks, labels)

            # Dry run: execute a single jitted step on dummy data and exit.
            if args.dry_run:
                labels = jnp.zeros(label_shape, dtype=jnp.bfloat16)
                rngs = {DROPOUT_KEY: dropout_rng}
                jit_train_step(state, inputs, masks, labels, var_collect, rngs)
                print("PASSED")
                return None

            for epoch in range(1, args.epochs + 1):
                # Fresh shuffle and dropout RNG streams every epoch.
                rng, input_rng = jax.random.split(rng)
                rng, dropout_rng = jax.random.split(rng)
                rngs = {INPUT_KEY: input_rng, DROPOUT_KEY: dropout_rng}

                state, train_loss, train_accuracy, var_collect = train_epoch(
                    state, train_ds, args.batch_size, rngs, var_collect, jit_train_step
                )

                test_loss, test_accuracy = eval_model(
                    state, test_ds, args.test_batch_size, var_collect, jit_eval_step
                )

                print(
                    f"Epoch: {epoch:>2} "
                    f"Train Loss: {train_loss:.6f} "
                    f"Train Accuracy: {train_accuracy:.6f} "
                    f"Test Loss: {test_loss:.6f} "
                    f"Test Accuracy: {test_accuracy:.6f} "
                )

            # Metrics from the final epoch.
            return [train_loss, train_accuracy, test_loss, test_accuracy]


def encoder_parser(args):
    """Training settings: parse the given argument list (sys.argv if None)."""
    parser = argparse.ArgumentParser(description="JAX Encoder Example")
    parser.add_argument("--batch-size", type=int, default=128, metavar="N",
                        help="input batch size for training (default: 128)")
    parser.add_argument("--test-batch-size", type=int, default=128, metavar="N",
                        help="input batch size for testing (default: 128)")
    parser.add_argument("--max-seq-len", type=int, default=32, metavar="N",
                        help="maximum sequence length (default: 32)")
    parser.add_argument("--epochs", type=int, default=3, metavar="N",
                        help="number of epochs to train (default: 3)")
    parser.add_argument("--lr", type=float, default=0.0001, metavar="LR",
                        help="learning rate (default: 0.0001)")
    parser.add_argument("--dry-run", action="store_true", default=False,
                        help="quickly check a single pass")
    parser.add_argument("--seed", type=int, default=0, metavar="S",
                        help="random seed (default: 0)")
    parser.add_argument("--use-fp8", action="store_true", default=False,
                        help="Use FP8 for inference and training without recalibration")
    return parser.parse_args(args)


class TestEncoder(unittest.TestCase):
    """Encoder unittests"""

    # Probed once at class-definition time; `reason` doubles as the skip
    # message for the FP8 test below.
    gpu_has_fp8, reason = te.fp8.is_fp8_available()

    @classmethod
    def setUpClass(cls):
        """Run 3 epochs for testing"""
        cls.args = encoder_parser(["--epochs", "3"])

    def test_te_bf16(self):
        """Test Transformer Engine with BF16"""
        actual = train_and_evaluate(self.args)
        # actual = [train_loss, train_accuracy, test_loss, test_accuracy]
        assert actual[0] < 0.50 and actual[1] > 0.76

    @unittest.skipIf(not gpu_has_fp8, reason)
    def test_te_fp8(self):
        """Test Transformer Engine with FP8"""
        self.args.use_fp8 = True
        actual = train_and_evaluate(self.args)
        assert actual[0] < 0.50 and actual[1] > 0.76
Jeng Bai-Cheng's avatar
Jeng Bai-Cheng committed
416
417
418
419


if __name__ == "__main__":
    # parse_args(None) reads sys.argv, so the script honors CLI flags.
    train_and_evaluate(encoder_parser(None))