# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Encoder training on single GPU"""
import argparse
import unittest
from functools import partial

import flax
import jax
import jax.numpy as jnp
import nltk
import numpy as np
import optax
from datasets import load_dataset
from flax import linen as nn
from flax.training import train_state

from common import is_bf16_supported, get_quantization_recipe_from_name_string
import transformer_engine.jax as te
import transformer_engine.jax.flax as te_flax
from transformer_engine.jax.quantize import is_scaling_mode_supported, ScalingMode


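# Names of the RNG/variable streams threaded through init and apply:
# parameter init, dropout, input shuffling, and SR_KEY ("sr_rng"), which
# presumably seeds stochastic rounding for recipes that use it (e.g. NVFP4).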
PARAMS_KEY = "params"
DROPOUT_KEY = "dropout"
SR_KEY = "sr_rng"
INPUT_KEY = "input_rng"


class Net(nn.Module):
    """NLP Encoder"""

    num_embed: int

    @nn.compact
    def __call__(self, x, mask, disable_dropout=False):
        x = nn.Embed(num_embeddings=self.num_embed, features=256, dtype=jnp.bfloat16)(x)

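        # One TE TransformerLayer used as a pure encoder block: padding-style
        # self-attention mask, no relative position embeddings, and dropout
        # RNGs drawn from the DROPOUT_KEY stream.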
        te_Encoder = partial(
            te_flax.TransformerLayer,
            hidden_size=256,
            mlp_hidden_size=1024,
            num_attention_heads=8,
            hidden_dropout=0.1,
            attention_dropout=0.1,
            dropout_rng_name=DROPOUT_KEY,
            layer_type=te_flax.TransformerLayerType.ENCODER,
            self_attn_mask_type="padding",
            enable_relative_embedding=False,
        )
        x = te_Encoder()(x, attention_mask=mask, deterministic=disable_dropout)

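        # Flatten (batch, seq, hidden) so the classifier head sees the whole
        # sequence at once.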
        x = x.reshape(x.shape[0], -1)

        x = te_flax.DenseGeneral(features=256)(x)

        x = te_flax.DenseGeneral(features=256)(x)

        x = nn.Dense(features=2)(x)
        return x


@jax.jit
def train_step(state, inputs, masks, labels, var_collect, rngs):
    """Computes gradients, loss and accuracy for a single batch."""

    def loss_fn(var_collect, disable_dropout=False):
        logits = state.apply_fn(var_collect, inputs, masks, disable_dropout, rngs=rngs)
        one_hot = jax.nn.one_hot(labels.astype(jnp.int32), 2)
        loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
        return loss, logits

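    # Merge the trainable params back into var_collect so apply_fn sees the
    # full variable dict (params plus any extra collections, such as the FP8
    # metadata TE tracks under fp8_autocast).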
    var_collect = {**var_collect, PARAMS_KEY: state.params}
    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (loss, logits), grads = grad_fn(var_collect)
    accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)

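    # flax.core.pop returns (rest, popped): the popped params gradients feed
    # the optimizer, while the remaining collections become the new
    # var_collect, which appears to be how updated non-param state (e.g. FP8
    # amax history) is threaded between steps.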
    var_collect, grads = flax.core.pop(grads, PARAMS_KEY)
    state = state.apply_gradients(grads=grads)

    return state, loss, accuracy, var_collect


def train_epoch(state, train_ds, batch_size, rngs, var_collect):
    """Train for a single epoch."""
    train_ds_size = len(train_ds["sentence"])
    steps_per_epoch = train_ds_size // batch_size
    perms = jax.random.permutation(rngs[INPUT_KEY], train_ds_size)
    perms = perms[: steps_per_epoch * batch_size]  # skip incomplete batch
    perms = perms.reshape((steps_per_epoch, batch_size))
    epoch_loss = []
    epoch_accuracy = []

    for perm in perms:
        # Split and reassign to 'rngs' to ensure unique rng for each step
        rngs = {key: jax.random.split(rngs[key])[1] for key in rngs}
        batch_inputs = train_ds["sentence"][perm, ...]
        batch_masks = train_ds["mask"][perm, ...]
        batch_labels = train_ds["label"][perm, ...]
        state, loss, accuracy, var_collect = train_step(
            state, batch_inputs, batch_masks, batch_labels, var_collect, rngs
        )
        epoch_loss.append(loss)
        epoch_accuracy.append(accuracy)

    avg_loss = np.mean(epoch_loss)
    avg_accuracy = np.mean(epoch_accuracy)
    return state, avg_loss, avg_accuracy, var_collect


@jax.jit
def eval_step(state, inputs, masks, labels, var_collect, rngs):
    """Computes loss and accuracy for a single batch."""

    def loss_fn(var_collect, disable_dropout=False):
        logits = state.apply_fn(var_collect, inputs, masks, disable_dropout, rngs=rngs)
        one_hot = jax.nn.one_hot(labels.astype(jnp.int32), 2)
        loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
        return loss, logits

    var_collect = {**var_collect, PARAMS_KEY: state.params}
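    # Evaluation is deterministic: dropout is disabled and no gradients are
    # taken.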
    loss, logits = loss_fn(var_collect, disable_dropout=True)
    accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
    return loss, accuracy


def eval_model(state, test_ds, batch_size, var_collect, rngs):
    """Evaluation loop."""
    test_ds_size = len(test_ds["sentence"])
    num_steps = test_ds_size // batch_size
    valid_size = num_steps * batch_size
    all_loss = []
    all_accuracy = []

    for batch_start in range(0, valid_size, batch_size):
        # Split and reassign to 'rngs' to ensure unique rng for each step
        rngs = {key: jax.random.split(rngs[key])[1] for key in rngs}
        batch_end = batch_start + batch_size
        batch_inputs = test_ds["sentence"][batch_start:batch_end]
        batch_masks = test_ds["mask"][batch_start:batch_end]
        batch_labels = test_ds["label"][batch_start:batch_end]
        loss, accuracy = eval_step(
            state, batch_inputs, batch_masks, batch_labels, var_collect, rngs
        )
        all_loss.append(loss)
        all_accuracy.append(accuracy)

    avg_loss = np.mean(all_loss)
    avg_accuracy = np.mean(all_accuracy)
    return avg_loss, avg_accuracy


def data_preprocess(dataset, vocab, word_id, max_seq_len):
    """Convert tokens to numbers."""
    nltk.download("punkt_tab")
    dataset_size = len(dataset["sentence"])
    output = np.zeros((dataset_size, max_seq_len), dtype=np.int32)
    mask_3d = np.ones((dataset_size, max_seq_len, max_seq_len), dtype=np.uint8)

    for j, sentence in enumerate(dataset["sentence"]):
        tokens = nltk.word_tokenize(sentence)
        tensor = output[j]

        for i, word in enumerate(tokens):
            if i >= max_seq_len:
                break

            if word not in vocab:
                vocab[word] = word_id
                tensor[i] = word_id
                word_id = word_id + 1
            else:
                tensor[i] = vocab[word]

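        # TE's "padding" mask convention: 1 masks a position out, 0 lets it
        # attend, so clear the valid [0, seq_len) block.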
        seq_len = min(len(tokens), max_seq_len)
        mask_2d = mask_3d[j]
        mask_2d[:seq_len, :seq_len] = 0

    new_dataset = {
        "sentence": output,
        "label": dataset["label"].astype(np.float32),
        "mask": mask_3d.reshape((dataset_size, 1, max_seq_len, max_seq_len)),
    }
    return new_dataset, vocab, word_id


def get_datasets(max_seq_len):
    """Load GLUE train and test datasets into memory."""
    vocab = {}
    word_id = 0
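    # The vocabulary is built on the fly while tokenizing; word_id is the
    # next unused token id and, once both splits are processed, the vocab
    # size passed to nn.Embed as num_embed.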

    train_ds = load_dataset("glue", "cola", split="train")
    train_ds.set_format(type="np")
    train_ds, vocab, word_id = data_preprocess(train_ds, vocab, word_id, max_seq_len)

    test_ds = load_dataset("glue", "cola", split="validation")
    test_ds.set_format(type="np")
    test_ds, vocab, word_id = data_preprocess(test_ds, vocab, word_id, max_seq_len)
    return train_ds, test_ds, word_id


def check_fp8(state, var_collect, inputs, masks, labels):
    """Check that the jitted train step actually contains FP8 operations."""
    rngs = {DROPOUT_KEY: jax.random.PRNGKey(0), SR_KEY: jax.random.PRNGKey(0)}
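    # Trace train_step to a jaxpr and look for FP8 dtypes; their presence
    # confirms FP8 GEMMs were actually staged into the computation.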
    func_jaxpr = str(jax.make_jaxpr(train_step)(state, inputs, masks, labels, var_collect, rngs))
    assert "f8_e5m2" in func_jaxpr or "f8_e4m3" in func_jaxpr


def train_and_evaluate(args):
    """Execute model training and evaluation loop."""
    print(args)
    train_ds, test_ds, num_embed = get_datasets(args.max_seq_len)

    rng = jax.random.PRNGKey(args.seed)
    rng, params_rng = jax.random.split(rng)
    rng, dropout_rng = jax.random.split(rng)
    rng, sr_rng = jax.random.split(rng)
    init_rngs = {PARAMS_KEY: params_rng, DROPOUT_KEY: dropout_rng, SR_KEY: sr_rng}

    input_shape = [args.batch_size, args.max_seq_len]
    mask_shape = [args.batch_size, 1, args.max_seq_len, args.max_seq_len]
    label_shape = [args.batch_size]

    if args.use_fp8:
        fp8_recipe = get_quantization_recipe_from_name_string(args.fp8_recipe)
    else:
        fp8_recipe = None

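    # Layers built inside fp8_autocast pick up the requested quantization
    # recipe; with enabled=False they run in regular precision.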
    with te.fp8_autocast(
        enabled=args.use_fp8, fp8_recipe=fp8_recipe, mesh_resource=te.sharding.MeshResource()
    ):
        encoder = Net(num_embed)
        # We use nn.Embed, so the inputs must be integer token ids
        inputs = jnp.zeros(input_shape, dtype=jnp.int32)
        masks = jnp.zeros(mask_shape, dtype=jnp.uint8)
        var_collect = encoder.init(init_rngs, inputs, masks)
        tx = optax.adamw(args.lr)
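        # Only the "params" collection is optimized; the other collections in
        # var_collect (e.g. FP8 metadata) are threaded through train_step
        # instead.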
        state = train_state.TrainState.create(
            apply_fn=encoder.apply, params=var_collect[PARAMS_KEY], tx=tx
        )

        if args.use_fp8:
            labels = jnp.zeros(label_shape, dtype=jnp.bfloat16)
            check_fp8(state, var_collect, inputs, masks, labels)

        if args.dry_run:
            labels = jnp.zeros(label_shape, dtype=jnp.bfloat16)
            rngs = {DROPOUT_KEY: dropout_rng, SR_KEY: sr_rng}
            train_step(state, inputs, masks, labels, var_collect, rngs)
            print("PASSED")
            return None

        for epoch in range(1, args.epochs + 1):
            # Split and reassign to 'rng' to ensure unique rng for each step
            rng, input_rng = jax.random.split(rng)
            rng, dropout_rng = jax.random.split(rng)
            rng, sr_rng = jax.random.split(rng)
            rngs = {INPUT_KEY: input_rng, DROPOUT_KEY: dropout_rng, SR_KEY: sr_rng}

            state, train_loss, train_accuracy, var_collect = train_epoch(
                state, train_ds, args.batch_size, rngs, var_collect
            )

            test_loss, test_accuracy = eval_model(
                state, test_ds, args.test_batch_size, var_collect, rngs
            )

            print(
                f"Epoch: {epoch:>2} "
                f"Train Loss: {train_loss:.6f} "
                f"Train Accuracy: {train_accuracy:.6f} "
                f"Test Loss: {test_loss:.6f} "
                f"Test Accuracy: {test_accuracy:.6f} "
            )

    return [train_loss, train_accuracy, test_loss, test_accuracy]


def encoder_parser(args):
    """Training settings."""
    parser = argparse.ArgumentParser(description="JAX Encoder Example")
    parser.add_argument(
        "--batch-size",
        type=int,
        default=64,
        metavar="N",
        help="input batch size for training (default: 64)",
    )
    parser.add_argument(
        "--test-batch-size",
        type=int,
        default=64,
        metavar="N",
        help="input batch size for testing (default: 64)",
    )
    parser.add_argument(
        "--max-seq-len",
        type=int,
        default=32,
        metavar="N",
        help="maximum sequence length (default: 32)",
    )
    parser.add_argument(
        "--epochs",
        type=int,
        default=3,
        metavar="N",
        help="number of epochs to train (default: 3)",
    )
    parser.add_argument(
        "--lr",
        type=float,
        default=0.0001,
        metavar="LR",
        help="learning rate (default: 0.0001)",
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        default=False,
        help="quickly check a single pass",
    )
    parser.add_argument("--seed", type=int, default=0, metavar="S", help="random seed (default: 0)")
    parser.add_argument(
        "--use-fp8",
        action="store_true",
        default=False,
        help="Use FP8 for inference and training without recalibration",
    )
    parser.add_argument(
        "--fp8-recipe",
        type=str,
        default="DelayedScaling",
        help=(
            "FP8 recipe name: DelayedScaling, Float8CurrentScaling,"
            " MXFP8BlockScaling, or NVFP4BlockScaling (default: DelayedScaling)"
        ),
    )

    return parser.parse_args(args)


class TestEncoder(unittest.TestCase):
    """Encoder unittests"""

    is_fp8_supported, fp8_reason = is_scaling_mode_supported(ScalingMode.DELAYED_TENSOR_SCALING)
    is_mxfp8_supported, mxfp8_reason = is_scaling_mode_supported(ScalingMode.MXFP8_1D_SCALING)
    is_nvfp4_supported, nvfp4_reason = is_scaling_mode_supported(ScalingMode.NVFP4_1D_SCALING)

    def setUp(self):
        """Run 3 epochs for testing"""
        self.args = encoder_parser(["--epochs", "3"])

    @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
    def test_te_bf16(self):
        """Test Transformer Engine with BF16"""
        actual = train_and_evaluate(self.args)
        assert actual[0] < 0.452 and actual[1] > 0.788

    @unittest.skipIf(not is_fp8_supported, fp8_reason)
    def test_te_delayed_scaling_fp8(self):
        """Test Transformer Engine with DelayedScaling FP8"""
        self.args.use_fp8 = True
        self.args.fp8_recipe = "DelayedScaling"
        actual = train_and_evaluate(self.args)
        assert actual[0] < 0.457 and actual[1] > 0.784

    @unittest.skipIf(not is_fp8_supported, fp8_reason)
    def test_te_current_scaling_fp8(self):
        """Test Transformer Engine with CurrentScaling FP8"""
        self.args.use_fp8 = True
        self.args.fp8_recipe = "Float8CurrentScaling"
        actual = train_and_evaluate(self.args)
        assert actual[0] < 0.461 and actual[1] > 0.784

    @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason)
    def test_te_mxfp8(self):
        """Test Transformer Engine with MXFP8"""
        self.args.use_fp8 = True
        self.args.fp8_recipe = "MXFP8BlockScaling"
        actual = train_and_evaluate(self.args)
        assert actual[0] < 0.457 and actual[1] > 0.784

    @unittest.skipIf(not is_nvfp4_supported, nvfp4_reason)
    def test_te_nvfp4(self):
        """Test Transformer Engine with NVFP4"""
        self.args.use_fp8 = True
        self.args.fp8_recipe = "NVFP4BlockScaling"
        actual = train_and_evaluate(self.args)
        assert actual[0] < 0.476 and actual[1] > 0.775


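# Example invocations (FP8 runs require hardware support for the recipe):
#   python test_single_gpu_encoder.py --epochs 3
#   python test_single_gpu_encoder.py --use-fp8 --fp8-recipe DelayedScaling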
if __name__ == "__main__":
    train_and_evaluate(encoder_parser(None))