# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Encoder training on single GPU"""

import argparse
import unittest
from functools import partial

import flax
import jax
import jax.numpy as jnp
import nltk
import numpy as np
import optax
from datasets import load_dataset
from flax import linen as nn
from flax.training import train_state

import transformer_engine.jax as te
import transformer_engine.jax.flax as te_flax

# Names shared by the model, the training loop and the RNG dictionaries in this file.
# Variable collection holding the trainable parameters; also used as the init RNG stream name.
PARAMS_KEY = 'params'
# RNG stream name handed to the Transformer Engine layer through `dropout_rng_name`.
DROPOUT_KEY = 'dropout'
# RNG stream used to shuffle the training samples at the start of each epoch.
INPUT_KEY = 'input_rng'


class Net(nn.Module):
    """NLP Encoder.

    Embeds token ids, runs them through a single Transformer Engine encoder
    layer, flattens the sequence axis and classifies each sample into two
    classes.

    Attributes:
        num_embed: Number of rows in the embedding table, i.e. the number of
            distinct token ids the model accepts.
    """
    num_embed: int

    @nn.compact
    def __call__(self, x, mask, disable_dropout=False):
        """Compute classification logits for a batch of token-id sequences.

        Args:
            x: Integer token ids. Callers in this file pass an array of shape
                (batch_size, max_seq_len).
            mask: Attention mask forwarded to the encoder layer as
                `attention_mask`. Callers in this file pass a uint8 array of
                shape (batch_size, 1, max_seq_len, max_seq_len).
            disable_dropout: Forwarded to the encoder layer as `deterministic`.

        Returns:
            An array of shape (batch, 2) holding the class logits.
        """
        x = nn.Embed(num_embeddings=self.num_embed, features=256, dtype=jnp.bfloat16)(x)

        te_Encoder = partial(te_flax.TransformerLayer,
                             hidden_size=256,
                             mlp_hidden_size=1024,
                             num_attention_heads=8,
                             hidden_dropout=0.1,
                             attention_dropout=0.1,
                             dropout_rng_name=DROPOUT_KEY,
                             layer_type=te_flax.TransformerLayerType.ENCODER,
                             self_attn_mask_type='padding',
                             enable_relative_embedding=False,
                             dtype=jnp.bfloat16)
        x = te_Encoder()(x, attention_mask=mask, deterministic=disable_dropout)

        # Collapse every non-batch axis so the dense head sees one vector per sample.
        x = x.reshape(x.shape[0], -1)

        x = te_flax.DenseGeneral(features=256, dtype=jnp.bfloat16)(x)

        x = te_flax.DenseGeneral(features=256, dtype=jnp.bfloat16)(x)

        # Final projection to the two output classes.
        x = nn.Dense(features=2, dtype=jnp.bfloat16)(x)
        return x


@partial(jax.jit, static_argnums=6)
def train_step(state, inputs, masks, labels, var_collect, rngs, use_fp8):
    """Computes gradients, loss and accuracy for a single batch.

    Args:
        state: Train state providing `apply_fn`, `params` and `apply_gradients`.
        inputs: Batch of token ids passed to the model.
        masks: Batch of attention masks passed to the model.
        labels: Batch of class labels (two classes).
        var_collect: Variable collections passed to `apply_fn`; its
            `PARAMS_KEY` entry is overridden with `state.params`.
        rngs: RNG dictionary forwarded to `apply_fn`.
        use_fp8: Static (compile-time) flag; when true the returned collection
            is passed through `te.update_fp8_metas`.

    Returns:
        A tuple `(state, loss, accuracy, var_collect)` where `state` has the
        gradients applied and `var_collect` is the non-parameter part of the
        gradient tree (after `te.update_fp8_metas` when `use_fp8` is set).
    """

    def loss_fn(var_collect, disable_dropout=False):
        logits = state.apply_fn(var_collect, inputs, masks, disable_dropout, rngs=rngs)
        one_hot = jax.nn.one_hot(labels, 2)
        loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
        return loss, logits

    # Differentiate with respect to the whole collection, using the current parameters.
    var_collect = {**var_collect, PARAMS_KEY: state.params}
    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (loss, logits), grads = grad_fn(var_collect)
    accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)

    # Split the gradient tree: parameter gradients go to the optimizer, the
    # remainder is handed back to the caller as the new variable collection.
    var_collect, grads = flax.core.pop(grads, PARAMS_KEY)
    state = state.apply_gradients(grads=grads)
    if use_fp8:
        # NOTE(review): presumably refreshes the FP8 scaling metadata; confirm
        # against the Transformer Engine documentation.
        var_collect = te.update_fp8_metas(var_collect)

    return state, loss, accuracy, var_collect


def train_epoch(state, train_ds, batch_size, rngs, var_collect, use_fp8):
    """Run one pass over the shuffled training set.

    Samples are permuted with the `INPUT_KEY` RNG, grouped into full batches
    (a trailing partial batch is dropped) and fed to `train_step` one batch
    at a time.

    Returns:
        A tuple `(state, mean_loss, mean_accuracy, var_collect)` where the
        means are taken over the per-batch values of this epoch.
    """
    num_samples = len(train_ds['sentence'])
    num_batches = num_samples // batch_size

    shuffled = jax.random.permutation(rngs[INPUT_KEY], num_samples)
    # Keep only as many indices as fit into complete batches.
    batch_indices = shuffled[:num_batches * batch_size].reshape((num_batches, batch_size))

    losses, accuracies = [], []
    for indices in batch_indices:
        state, step_loss, step_accuracy, var_collect = train_step(
            state,
            train_ds['sentence'][indices, ...],
            train_ds['mask'][indices, ...],
            train_ds['label'][indices, ...],
            var_collect,
            rngs,
            use_fp8,
        )
        losses.append(step_loss)
        accuracies.append(step_accuracy)

    return state, np.mean(losses), np.mean(accuracies), var_collect


@jax.jit
def eval_step(state, inputs, masks, labels, var_collect):
    """Computes loss and accuracy for a single batch.

    Args:
        state: Train state providing `apply_fn` and `params`.
        inputs: Batch of token ids passed to the model.
        masks: Batch of attention masks passed to the model.
        labels: Batch of class labels (two classes).
        var_collect: Variable collections passed to `apply_fn`; its
            `PARAMS_KEY` entry is overridden with `state.params`.

    Returns:
        A tuple `(loss, accuracy)` for the batch, computed with dropout
        disabled.
    """

    def loss_fn(var_collect, disable_dropout=False):
        logits = state.apply_fn(var_collect, inputs, masks, disable_dropout)
        one_hot = jax.nn.one_hot(labels, 2)
        loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
        return loss, logits

    var_collect = {**var_collect, PARAMS_KEY: state.params}
    # No RNGs are supplied here, so the model has to run with dropout disabled.
    loss, logits = loss_fn(var_collect, disable_dropout=True)
    accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
    return loss, accuracy


def eval_model(state, test_ds, batch_size, var_collect):
    """Evaluate the model on consecutive full batches of `test_ds`.

    Samples that do not fill a complete batch at the end of the dataset are
    skipped.

    Returns:
        A tuple `(mean_loss, mean_accuracy)` averaged over the evaluated
        batches.
    """
    num_batches = len(test_ds['sentence']) // batch_size
    losses = []
    accuracies = []

    for step in range(num_batches):
        window = slice(step * batch_size, (step + 1) * batch_size)
        batch_loss, batch_accuracy = eval_step(state, test_ds['sentence'][window],
                                               test_ds['mask'][window],
                                               test_ds['label'][window], var_collect)
        losses.append(batch_loss)
        accuracies.append(batch_accuracy)

    return np.mean(losses), np.mean(accuracies)


def data_preprocess(dataset, vocab, word_id, max_seq_len, tokenize=None):
    """Convert tokens to numbers.

    Each sentence is tokenized, truncated to `max_seq_len` tokens and mapped
    to integer ids. Unseen tokens are added to `vocab` (which is modified in
    place) with consecutive ids starting at `word_id`. Unused trailing
    positions keep the value 0.

    Args:
        dataset: Mapping with a 'sentence' sequence of strings and a 'label'
            NumPy array.
        vocab: Dict mapping token -> id; extended in place.
        word_id: Next unused token id.
        max_seq_len: Number of token slots per sentence.
        tokenize: Optional callable turning a sentence into a list of tokens.
            When omitted, the NLTK 'punkt' data is downloaded and
            `nltk.word_tokenize` is used.

    Returns:
        A tuple `(new_dataset, vocab, word_id)` where `new_dataset` holds
        'sentence' (int32, shape (N, max_seq_len)), 'label' (float32) and
        'mask' (uint8, shape (N, 1, max_seq_len, max_seq_len); 0 inside the
        top-left seq_len x seq_len square of each sample, 1 elsewhere), and
        `word_id` is the next unused token id.
    """
    if tokenize is None:
        nltk.download('punkt')
        tokenize = nltk.word_tokenize

    dataset_size = len(dataset['sentence'])
    output = np.zeros((dataset_size, max_seq_len), dtype=np.int32)
    mask_3d = np.ones((dataset_size, max_seq_len, max_seq_len), dtype=np.uint8)

    for j, sentence in enumerate(dataset['sentence']):
        tokens = tokenize(sentence)
        seq_len = min(len(tokens), max_seq_len)

        # Tokens beyond max_seq_len are dropped and never enter the vocabulary.
        for i, word in enumerate(tokens[:seq_len]):
            if word not in vocab:
                vocab[word] = word_id
                word_id += 1
            output[j, i] = vocab[word]

        # Clear the mask wherever both positions hold real tokens.
        mask_3d[j, :seq_len, :seq_len] = 0

    new_dataset = {
        'sentence': output,
        'label': dataset['label'].astype(np.float32),
        'mask': mask_3d.reshape((dataset_size, 1, max_seq_len, max_seq_len))
    }
    return new_dataset, vocab, word_id


def get_datasets(max_seq_len):
    """Load GLUE train and test datasets into memory.

    Loads the 'cola' configuration of 'glue' and preprocesses the 'train' and
    'validation' splits with one shared vocabulary, so a token gets the same
    id in both splits.

    Args:
        max_seq_len: Number of token slots per sentence after preprocessing.

    Returns:
        A tuple `(train_ds, test_ds, word_id)` where `word_id` is the next
        unused token id after both splits were processed, i.e. the vocabulary
        size.
    """
    vocab = {}
    word_id = 0

    train_ds = load_dataset('glue', 'cola', split='train')
    train_ds.set_format(type='np')
    train_ds, vocab, word_id = data_preprocess(train_ds, vocab, word_id, max_seq_len)

    test_ds = load_dataset('glue', 'cola', split='validation')
    test_ds.set_format(type='np')
    test_ds, vocab, word_id = data_preprocess(test_ds, vocab, word_id, max_seq_len)
    return train_ds, test_ds, word_id


def check_fp8(state, var_collect, inputs, masks, labels):
    """Assert that the traced FP8 train step mentions "Float8".

    Traces `train_step` with `use_fp8=True` and raises AssertionError when the
    printed jaxpr does not contain the text "Float8".
    """
    rngs = {DROPOUT_KEY: jax.random.PRNGKey(0)}
    trace_fn = jax.make_jaxpr(train_step, static_argnums=6)
    jaxpr_text = str(trace_fn(state, inputs, masks, labels, var_collect, rngs, True))
    assert "Float8" in jaxpr_text


def train_and_evaluate(args):
    """Execute model training and evaluation loop.

    Args:
        args: Parsed command-line namespace as produced by `encoder_parser`
            (reads max_seq_len, seed, batch_size, test_batch_size, lr, epochs,
            use_fp8 and dry_run).

    Returns:
        `None` for a dry run; otherwise the list
        `[train_loss, train_accuracy, test_loss, test_accuracy]` from the last
        epoch. All four entries are `None` when no epoch was run
        (`args.epochs <= 0`).
    """
    print(args)
    train_ds, test_ds, num_embed = get_datasets(args.max_seq_len)

    rng = jax.random.PRNGKey(args.seed)
    rng, params_rng = jax.random.split(rng)
    rng, dropout_rng = jax.random.split(rng)
    init_rngs = {PARAMS_KEY: params_rng, DROPOUT_KEY: dropout_rng}

    input_shape = [args.batch_size, args.max_seq_len]
    mask_shape = [args.batch_size, 1, args.max_seq_len, args.max_seq_len]
    label_shape = [args.batch_size]

    # Defined up front so the final return does not fail when the epoch loop is empty.
    train_loss = train_accuracy = test_loss = test_accuracy = None

    with te.fp8_autocast(enabled=args.use_fp8):
        encoder = Net(num_embed)
        # Dummy all-zero batch, used only to initialise and trace the model.
        inputs = jnp.zeros(input_shape, dtype=jnp.int32)
        masks = jnp.zeros(mask_shape, dtype=jnp.uint8)
        var_collect = encoder.init(init_rngs, inputs, masks)
        tx = optax.adamw(args.lr)
        state = train_state.TrainState.create(apply_fn=encoder.apply,
                                              params=var_collect[PARAMS_KEY],
                                              tx=tx)

        if args.use_fp8:
            labels = jnp.zeros(label_shape, dtype=jnp.bfloat16)
            check_fp8(state, var_collect, inputs, masks, labels)

        if args.dry_run:
            # Single training step on the dummy batch; no evaluation.
            labels = jnp.zeros(label_shape, dtype=jnp.bfloat16)
            rngs = {DROPOUT_KEY: dropout_rng}
            train_step(state, inputs, masks, labels, var_collect, rngs, args.use_fp8)
            print("PASSED")
            return None

        for epoch in range(1, args.epochs + 1):
            # Fresh shuffle and dropout keys for every epoch.
            rng, input_rng = jax.random.split(rng)
            rng, dropout_rng = jax.random.split(rng)
            rngs = {INPUT_KEY: input_rng, DROPOUT_KEY: dropout_rng}

            state, train_loss, train_accuracy, var_collect = train_epoch(
                state, train_ds, args.batch_size, rngs, var_collect, args.use_fp8)

            test_loss, test_accuracy = eval_model(state, test_ds, args.test_batch_size, var_collect)

            print(f"Epoch: {epoch:>2} "
                  f"Train Loss: {train_loss:.6f} "
                  f"Train Accuracy: {train_accuracy:.6f} "
                  f"Test Loss: {test_loss:.6f} "
                  f"Test Accuracy: {test_accuracy:.6f} ")

    return [train_loss, train_accuracy, test_loss, test_accuracy]


def encoder_parser(args):
    """Training settings.

    Args:
        args: List of command-line tokens to parse, or `None` to let argparse
            read them from `sys.argv`.

    Returns:
        An `argparse.Namespace` with the attributes batch_size,
        test_batch_size, max_seq_len, epochs, lr, dry_run, seed and use_fp8.
    """
    parser = argparse.ArgumentParser(description="JAX Encoder Example")
    parser.add_argument(
        "--batch-size",
        type=int,
        default=64,
        metavar="N",
        help="input batch size for training (default: 64)",
    )
    parser.add_argument(
        "--test-batch-size",
        type=int,
        default=64,
        metavar="N",
        help="input batch size for testing (default: 64)",
    )
    parser.add_argument(
        "--max-seq-len",
        type=int,
        default=32,
        metavar="N",
        help="maximum sequence length (default: 32)",
    )
    parser.add_argument(
        "--epochs",
        type=int,
        default=3,
        metavar="N",
        help="number of epochs to train (default: 3)",
    )
    parser.add_argument(
        "--lr",
        type=float,
        default=0.0001,
        metavar="LR",
        help="learning rate (default: 0.0001)",
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        default=False,
        help="quickly check a single pass",
    )
    parser.add_argument("--seed", type=int, default=0, metavar="S", help="random seed (default: 0)")
    parser.add_argument("--use-fp8",
                        action="store_true",
                        default=False,
                        help="Use FP8 for inference and training without recalibration")

    return parser.parse_args(args)


class TestEncoder(unittest.TestCase):
    """Encoder unittests"""

    # Queried once at class-definition time; drives the skip of the FP8 test below.
    gpu_has_fp8, reason = te.fp8.is_fp8_available()

    @classmethod
    def setUpClass(cls):
        """Parse the shared arguments: a 3-epoch run with all other defaults."""
        cls.args = encoder_parser(["--epochs", "3"])

    def test_te_bf16(self):
        """Test Transformer Engine with BF16"""
        actual = train_and_evaluate(self.args)
        self.assertLess(actual[0], 0.45)
        self.assertGreater(actual[1], 0.79)

    @unittest.skipIf(not gpu_has_fp8, reason)
    def test_te_fp8(self):
        """Test Transformer Engine with FP8"""
        # Work on a copy so the class-level args stay untouched for other
        # tests, whatever order they run in.
        fp8_args = argparse.Namespace(**vars(self.args))
        fp8_args.use_fp8 = True
        actual = train_and_evaluate(fp8_args)
        self.assertLess(actual[0], 0.45)
        self.assertGreater(actual[1], 0.79)


if __name__ == "__main__":
    # Script entry point: parse sys.argv and run the full training/evaluation loop.
    cli_args = encoder_parser(None)
    train_and_evaluate(cli_args)