# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Encoder training on single GPU"""
import argparse
import unittest
from functools import partial

import flax
import jax
import jax.numpy as jnp
import nltk
import numpy as np
import optax
from datasets import load_dataset
from flax import linen as nn
from flax.training import train_state

import transformer_engine.jax as te
import transformer_engine.jax.flax as te_flax

PARAMS_KEY = "params"
DROPOUT_KEY = "dropout"
INPUT_KEY = "input_rng"


class Net(nn.Module):
    """NLP Encoder"""

    num_embed: int

    @nn.compact
    def __call__(self, x, mask, disable_dropout=False):
        x = nn.Embed(num_embeddings=self.num_embed, features=256, dtype=jnp.bfloat16)(x)

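        # Transformer Engine encoder block: runs in BF16 here, and executes
        # its matmuls in FP8 when the model is used under te.fp8_autocast
        # (see train_and_evaluate below).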
        te_Encoder = partial(
            te_flax.TransformerLayer,
            hidden_size=256,
            mlp_hidden_size=1024,
            num_attention_heads=8,
            hidden_dropout=0.1,
            attention_dropout=0.1,
            dropout_rng_name=DROPOUT_KEY,
            layer_type=te_flax.TransformerLayerType.ENCODER,
            self_attn_mask_type="padding",
            enable_relative_embedding=False,
            dtype=jnp.bfloat16,
        )
        x = te_Encoder()(x, attention_mask=mask, deterministic=disable_dropout)

        x = x.reshape(x.shape[0], -1)

        x = te_flax.DenseGeneral(features=256, dtype=jnp.bfloat16)(x)

        x = te_flax.DenseGeneral(features=256, dtype=jnp.bfloat16)(x)

        x = nn.Dense(features=2, dtype=jnp.bfloat16)(x)
        return x


@jax.jit
def train_step(state, inputs, masks, labels, var_collect, rngs):
    """Computes gradients, loss and accuracy for a single batch."""

    def loss_fn(var_collect, disable_dropout=False):
        logits = state.apply_fn(var_collect, inputs, masks, disable_dropout, rngs=rngs)
        one_hot = jax.nn.one_hot(labels, 2)
        loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
        return loss, logits

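    # Merge the trainable params back into the full variable collection,
    # which also carries TE's non-trainable state such as FP8 metadata.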
    var_collect = {**var_collect, PARAMS_KEY: state.params}
    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (loss, logits), grads = grad_fn(var_collect)
    accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)

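    # flax.core.pop returns (rest, popped): keep the parameter gradients for
    # the optimizer step and return the remaining collections (including any
    # updated FP8 metadata) to the caller as the new var_collect.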
    var_collect, grads = flax.core.pop(grads, PARAMS_KEY)
    state = state.apply_gradients(grads=grads)

    return state, loss, accuracy, var_collect


def train_epoch(state, train_ds, batch_size, rngs, var_collect):
    """Train for a single epoch."""
    train_ds_size = len(train_ds["sentence"])
    steps_per_epoch = train_ds_size // batch_size
    perms = jax.random.permutation(rngs[INPUT_KEY], train_ds_size)
    perms = perms[: steps_per_epoch * batch_size]  # skip incomplete batch
    perms = perms.reshape((steps_per_epoch, batch_size))
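    # Each row of perms now holds the shuffled indices of one mini-batch.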
    epoch_loss = []
    epoch_accuracy = []

    for perm in perms:
        batch_inputs = train_ds["sentence"][perm, ...]
        batch_masks = train_ds["mask"][perm, ...]
        batch_labels = train_ds["label"][perm, ...]
        state, loss, accuracy, var_collect = train_step(
            state, batch_inputs, batch_masks, batch_labels, var_collect, rngs
        )
        epoch_loss.append(loss)
        epoch_accuracy.append(accuracy)

    avg_loss = np.mean(epoch_loss)
    avg_accuracy = np.mean(epoch_accuracy)
    return state, avg_loss, avg_accuracy, var_collect


@jax.jit
def eval_step(state, inputs, masks, labels, var_collect):
    """Computes loss and accuracy for a single batch."""

    def loss_fn(var_collect, disable_dropout=False):
        logits = state.apply_fn(var_collect, inputs, masks, disable_dropout)
        one_hot = jax.nn.one_hot(labels, 2)
        loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
        return loss, logits

    var_collect = {**var_collect, PARAMS_KEY: state.params}
    loss, logits = loss_fn(var_collect, disable_dropout=True)
    accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
    return loss, accuracy


def eval_model(state, test_ds, batch_size, var_collect):
    """Evaluation loop."""
    test_ds_size = len(test_ds["sentence"])
    num_steps = test_ds_size // batch_size
    valid_size = num_steps * batch_size
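    # Evaluate on full batches only; a trailing partial batch is dropped.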
    all_loss = []
    all_accuracy = []

    for batch_start in range(0, valid_size, batch_size):
        batch_end = batch_start + batch_size
        batch_inputs = test_ds["sentence"][batch_start:batch_end]
        batch_masks = test_ds["mask"][batch_start:batch_end]
        batch_labels = test_ds["label"][batch_start:batch_end]
        loss, accuracy = eval_step(state, batch_inputs, batch_masks, batch_labels, var_collect)
        all_loss.append(loss)
        all_accuracy.append(accuracy)

    avg_loss = np.mean(all_loss)
    avg_accuracy = np.mean(all_accuracy)
    return avg_loss, avg_accuracy


def data_preprocess(dataset, vocab, word_id, max_seq_len):
    """Convert tokens to numbers."""
    nltk.download("punkt")
    dataset_size = len(dataset["sentence"])
    output = np.zeros((dataset_size, max_seq_len), dtype=np.int32)
    mask_3d = np.ones((dataset_size, max_seq_len, max_seq_len), dtype=np.uint8)
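    # Mask convention: 1 marks positions to mask out, 0 marks valid token
    # pairs; the loop below zeroes the seq_len x seq_len block per sentence.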

    for j, sentence in enumerate(dataset["sentence"]):
        tokens = nltk.word_tokenize(sentence)
        tensor = output[j]

        for i, word in enumerate(tokens):
            if i >= max_seq_len:
                break

            if word not in vocab:
                vocab[word] = word_id
                tensor[i] = word_id
                word_id = word_id + 1
            else:
                tensor[i] = vocab[word]

        seq_len = min(len(tokens), max_seq_len)
        mask_2d = mask_3d[j]
        mask_2d[:seq_len, :seq_len] = 0

    new_dataset = {
        "sentence": output,
        "label": dataset["label"].astype(np.float32),
        "mask": mask_3d.reshape((dataset_size, 1, max_seq_len, max_seq_len)),
    }
    return new_dataset, vocab, word_id


def get_datasets(max_seq_len):
    """Load GLUE train and test datasets into memory."""
    vocab = {}
    word_id = 0

    train_ds = load_dataset("glue", "cola", split="train")
    train_ds.set_format(type="np")
    train_ds, vocab, word_id = data_preprocess(train_ds, vocab, word_id, max_seq_len)

    test_ds = load_dataset("glue", "cola", split="validation")
    test_ds.set_format(type="np")
    test_ds, vocab, word_id = data_preprocess(test_ds, vocab, word_id, max_seq_len)
    return train_ds, test_ds, word_id


def check_fp8(state, var_collect, inputs, masks, labels):
    "Check if model includes FP8."
    rngs = {DROPOUT_KEY: jax.random.PRNGKey(0)}
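    # Tracing train_step to a jaxpr and searching it for "fp8_" primitives is
    # a cheap smoke test that FP8 kernels were actually staged out.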
    assert "fp8_" in str(
        jax.make_jaxpr(train_step)(state, inputs, masks, labels, var_collect, rngs)
    )


def train_and_evaluate(args):
    """Execute model training and evaluation loop."""
    print(args)
    train_ds, test_ds, num_embed = get_datasets(args.max_seq_len)

    rng = jax.random.PRNGKey(args.seed)
    rng, params_rng = jax.random.split(rng)
    rng, dropout_rng = jax.random.split(rng)
    init_rngs = {PARAMS_KEY: params_rng, DROPOUT_KEY: dropout_rng}

    input_shape = [args.batch_size, args.max_seq_len]
    mask_shape = [args.batch_size, 1, args.max_seq_len, args.max_seq_len]
    label_shape = [args.batch_size]

    with te.fp8_autocast(enabled=args.use_fp8):
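        # TE layers initialized and applied inside this context use FP8 when
        # enabled; otherwise they fall back to their configured dtype (BF16).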
        encoder = Net(num_embed)
        inputs = jnp.zeros(input_shape, dtype=jnp.int32)
        masks = jnp.zeros(mask_shape, dtype=jnp.uint8)
        var_collect = encoder.init(init_rngs, inputs, masks)
        tx = optax.adamw(args.lr)
        state = train_state.TrainState.create(
            apply_fn=encoder.apply, params=var_collect[PARAMS_KEY], tx=tx
        )

        if args.use_fp8:
            labels = jnp.zeros(label_shape, dtype=jnp.bfloat16)
            check_fp8(state, var_collect, inputs, masks, labels)

        if args.dry_run:
            labels = jnp.zeros(label_shape, dtype=jnp.bfloat16)
            rngs = {DROPOUT_KEY: dropout_rng}
            train_step(state, inputs, masks, labels, var_collect, rngs)
            print("PASSED")
            return None

        for epoch in range(1, args.epochs + 1):
            rng, input_rng = jax.random.split(rng)
            rng, dropout_rng = jax.random.split(rng)
            rngs = {INPUT_KEY: input_rng, DROPOUT_KEY: dropout_rng}

            state, train_loss, train_accuracy, var_collect = train_epoch(
                state, train_ds, args.batch_size, rngs, var_collect
            )

            test_loss, test_accuracy = eval_model(state, test_ds, args.test_batch_size, var_collect)

            print(
                f"Epoch: {epoch:>2} "
                f"Train Loss: {train_loss:.6f} "
                f"Train Accuracy: {train_accuracy:.6f} "
                f"Test Loss: {test_loss:.6f} "
                f"Test Accuracy: {test_accuracy:.6f} "
            )

    return [train_loss, train_accuracy, test_loss, test_accuracy]


def encoder_parser(args):
    """Training settings."""
    parser = argparse.ArgumentParser(description="JAX Encoder Example")
    parser.add_argument(
        "--batch-size",
        type=int,
        default=64,
        metavar="N",
        help="input batch size for training (default: 64)",
    )
    parser.add_argument(
        "--test-batch-size",
        type=int,
        default=64,
        metavar="N",
        help="input batch size for testing (default: 64)",
    )
    parser.add_argument(
        "--max-seq-len",
        type=int,
        default=32,
        metavar="N",
        help="maximum sequence length (default: 32)",
    )
    parser.add_argument(
        "--epochs",
        type=int,
        default=3,
        metavar="N",
        help="number of epochs to train (default: 3)",
    )
    parser.add_argument(
        "--lr",
        type=float,
        default=0.0001,
        metavar="LR",
        help="learning rate (default: 0.0001)",
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        default=False,
        help="quickly check a single pass",
    )
    parser.add_argument("--seed", type=int, default=0, metavar="S", help="random seed (default: 0)")
    parser.add_argument(
        "--use-fp8",
        action="store_true",
        default=False,
        help="Use FP8 for inference and training without recalibration",
    )

    return parser.parse_args(args)


class TestEncoder(unittest.TestCase):
    """Encoder unittests"""

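    # is_fp8_available() reports (supported, reason); the reason string
    # doubles as the skip message for the FP8 test below.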
    gpu_has_fp8, reason = te.fp8.is_fp8_available()

    @classmethod
    def setUpClass(cls):
        """Run 4 epochs for testing"""
        cls.args = encoder_parser(["--epochs", "3"])

    def test_te_bf16(self):
        """Test Transformer Engine with BF16"""
        actual = train_and_evaluate(self.args)
        assert actual[0] < 0.45 and actual[1] > 0.79

    @unittest.skipIf(not gpu_has_fp8, reason)
    def test_te_fp8(self):
        """Test Transformer Engine with FP8"""
        self.args.use_fp8 = True
        actual = train_and_evaluate(self.args)
        assert actual[0] < 0.45 and actual[1] > 0.79


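# Example invocation:
#   python test_single_gpu_encoder.py --epochs 3 --use-fp8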
if __name__ == "__main__":
    train_and_evaluate(encoder_parser(None))