test_single_gpu_encoder.py 13.5 KB
Newer Older
1
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
Jeng Bai-Cheng's avatar
Jeng Bai-Cheng committed
2
3
#
# See LICENSE for license information.
4
"""Encoder training on single GPU"""
Jeng Bai-Cheng's avatar
Jeng Bai-Cheng committed
5
6
7
8
import argparse
import unittest
from functools import partial

9
import flax
Jeng Bai-Cheng's avatar
Jeng Bai-Cheng committed
10
11
12
13
14
import jax
import jax.numpy as jnp
import nltk
import numpy as np
import optax
15
from datasets import load_dataset
Jeng Bai-Cheng's avatar
Jeng Bai-Cheng committed
16
17
18
from flax import linen as nn
from flax.training import train_state

19
20
21
from common import (
    is_bf16_supported,
    get_quantization_recipe_from_name_string,
22
    unpack_cached_datasets_if_available,
23
)
Jeng Bai-Cheng's avatar
Jeng Bai-Cheng committed
24
import transformer_engine.jax as te
25
import transformer_engine.jax.flax as te_flax
26
from transformer_engine.jax.quantize import is_scaling_mode_supported, ScalingMode
Jeng Bai-Cheng's avatar
Jeng Bai-Cheng committed
27

28
unpack_cached_datasets_if_available()
29

30
31
PARAMS_KEY = "params"
DROPOUT_KEY = "dropout"
32
SR_KEY = "sr_rng"
33
INPUT_KEY = "input_rng"
Jeng Bai-Cheng's avatar
Jeng Bai-Cheng committed
34
35
36
37


class Net(nn.Module):
    """NLP Encoder: token embedding -> TE transformer layer -> dense classifier head."""

    num_embed: int  # vocabulary size for the embedding table

    @nn.compact
    def __call__(self, x, mask, disable_dropout=False):
        hidden = nn.Embed(num_embeddings=self.num_embed, features=256, dtype=jnp.bfloat16)(x)

        encoder_layer = te_flax.TransformerLayer(
            hidden_size=256,
            mlp_hidden_size=1024,
            num_attention_heads=8,
            hidden_dropout=0.1,
            attention_dropout=0.1,
            dropout_rng_name=DROPOUT_KEY,
            layer_type=te_flax.TransformerLayerType.ENCODER,
            self_attn_mask_type="padding",
            enable_relative_embedding=False,
        )
        hidden = encoder_layer(hidden, attention_mask=mask, deterministic=disable_dropout)

        # Collapse every non-batch dimension so the dense head sees one vector per example.
        hidden = hidden.reshape(hidden.shape[0], -1)

        hidden = te_flax.DenseGeneral(features=256)(hidden)
        hidden = te_flax.DenseGeneral(features=256)(hidden)
        return nn.Dense(features=2)(hidden)


69
@jax.jit
def train_step(state, inputs, masks, labels, var_collect, rngs):
    """Computes gradients, loss and accuracy for a single batch."""

    def _objective(variables, disable_dropout=False):
        # Forward pass; dropout is active unless explicitly disabled.
        logits = state.apply_fn(variables, inputs, masks, disable_dropout, rngs=rngs)
        targets = jax.nn.one_hot(labels.astype(jnp.int32), 2)
        mean_loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=targets))
        return mean_loss, logits

    # Merge the trainable params into the variable collection before differentiating.
    full_collect = {**var_collect, PARAMS_KEY: state.params}
    (loss, logits), grads = jax.value_and_grad(_objective, has_aux=True)(full_collect)
    accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)

    # Separate the param grads from the rest of the differentiated collection;
    # the remainder is handed back to the caller as the updated var_collect.
    remaining_collect, param_grads = flax.core.pop(grads, PARAMS_KEY)
    new_state = state.apply_gradients(grads=param_grads)

    return new_state, loss, accuracy, remaining_collect


90
def train_epoch(state, train_ds, batch_size, rngs, var_collect):
    """Train for a single epoch."""
    num_examples = len(train_ds["sentence"])
    num_batches = num_examples // batch_size

    shuffled = jax.random.permutation(rngs[INPUT_KEY], num_examples)
    # Drop the tail that does not fill a whole batch, then group into batches.
    shuffled = shuffled[: num_batches * batch_size]
    batch_indices = shuffled.reshape((num_batches, batch_size))

    losses = []
    accuracies = []
    for indices in batch_indices:
        # Split and reassign to 'rngs' so every step uses fresh randomness.
        rngs = {name: jax.random.split(rngs[name])[1] for name in rngs}
        state, loss, accuracy, var_collect = train_step(
            state,
            train_ds["sentence"][indices, ...],
            train_ds["mask"][indices, ...],
            train_ds["label"][indices, ...],
            var_collect,
            rngs,
        )
        losses.append(loss)
        accuracies.append(accuracy)

    return state, np.mean(losses), np.mean(accuracies), var_collect


@jax.jit
def eval_step(state, inputs, masks, labels, var_collect, rngs):
    """Computes loss and accuracy for a single batch."""

    def _forward_loss(variables, disable_dropout=False):
        logits = state.apply_fn(variables, inputs, masks, disable_dropout, rngs=rngs)
        targets = jax.nn.one_hot(labels.astype(jnp.int32), 2)
        mean_loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=targets))
        return mean_loss, logits

    # Evaluation runs with dropout disabled on the merged variable collection.
    merged = {**var_collect, PARAMS_KEY: state.params}
    loss, logits = _forward_loss(merged, disable_dropout=True)
    accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
    return loss, accuracy


133
def eval_model(state, test_ds, batch_size, var_collect, rngs):
    """Evaluation loop."""
    num_examples = len(test_ds["sentence"])
    num_steps = num_examples // batch_size  # incomplete final batch is skipped

    all_loss = []
    all_accuracy = []
    for step in range(num_steps):
        # Split and reassign to 'rngs' so every step uses fresh randomness.
        rngs = {name: jax.random.split(rngs[name])[1] for name in rngs}

        start = step * batch_size
        end = start + batch_size
        loss, accuracy = eval_step(
            state,
            test_ds["sentence"][start:end],
            test_ds["mask"][start:end],
            test_ds["label"][start:end],
            var_collect,
            rngs,
        )
        all_loss.append(loss)
        all_accuracy.append(accuracy)

    return np.mean(all_loss), np.mean(all_accuracy)


def data_preprocess(dataset, vocab, word_id, max_seq_len):
    """Convert tokens to numbers."""
    nltk.download("punkt_tab")
    num_rows = len(dataset["sentence"])
    token_ids = np.zeros((num_rows, max_seq_len), dtype=np.int32)
    # Padding mask starts fully masked (1); valid positions are zeroed below.
    masks = np.ones((num_rows, max_seq_len, max_seq_len), dtype=np.uint8)

    for row, sentence in enumerate(dataset["sentence"]):
        words = nltk.word_tokenize(sentence)

        # Assign ids, growing the vocabulary on first sight; truncate at max_seq_len.
        for col, word in enumerate(words[:max_seq_len]):
            if word not in vocab:
                vocab[word] = word_id
                word_id += 1
            token_ids[row, col] = vocab[word]

        valid = min(len(words), max_seq_len)
        masks[row, :valid, :valid] = 0

    processed = {
        "sentence": token_ids,
        "label": dataset["label"].astype(np.float32),
        "mask": masks.reshape((num_rows, 1, max_seq_len, max_seq_len)),
    }
    return processed, vocab, word_id
Jeng Bai-Cheng's avatar
Jeng Bai-Cheng committed
191
192
193
194
195
196


def get_datasets(max_seq_len):
    """Load GLUE train and test datasets into memory."""
    vocab = {}
    next_word_id = 0

    # The vocabulary is shared across splits so test tokens reuse train ids.
    processed = []
    for split_name in ("train", "validation"):
        raw = load_dataset("glue", "cola", split=split_name)
        raw.set_format(type="np")
        ds, vocab, next_word_id = data_preprocess(raw, vocab, next_word_id, max_seq_len)
        processed.append(ds)

    train_ds, test_ds = processed
    return train_ds, test_ds, next_word_id


def check_fp8(state, var_collect, inputs, masks, labels):
    "Check if model includes FP8."
    rngs = {DROPOUT_KEY: jax.random.PRNGKey(0), SR_KEY: jax.random.PRNGKey(0)}
    # Trace train_step and require at least one FP8 dtype in the jaxpr text.
    jaxpr_text = str(jax.make_jaxpr(train_step)(state, inputs, masks, labels, var_collect, rngs))
    assert ("f8_e5m2" in jaxpr_text) or ("f8_e4m3" in jaxpr_text)
Jeng Bai-Cheng's avatar
Jeng Bai-Cheng committed
213
214
215
216
217


def train_and_evaluate(args):
    """Execute model training and evaluation loop.

    Args:
        args: parsed CLI namespace (see ``encoder_parser``).

    Returns:
        ``[train_loss, train_accuracy, test_loss, test_accuracy]`` from the
        final epoch, or ``None`` when ``args.dry_run`` is set.
    """
    print(args)
    train_ds, test_ds, num_embed = get_datasets(args.max_seq_len)

    # Derive independent RNG streams for init: params, dropout, and "sr"
    # (presumably stochastic rounding — confirm against the TE recipe docs).
    rng = jax.random.PRNGKey(args.seed)
    rng, params_rng = jax.random.split(rng)
    rng, dropout_rng = jax.random.split(rng)
    rng, sr_rng = jax.random.split(rng)
    init_rngs = {PARAMS_KEY: params_rng, DROPOUT_KEY: dropout_rng, SR_KEY: sr_rng}

    input_shape = [args.batch_size, args.max_seq_len]
    mask_shape = [args.batch_size, 1, args.max_seq_len, args.max_seq_len]
    label_shape = [args.batch_size]

    # Resolve the quantization recipe only when FP8 is requested.
    if args.use_fp8:
        fp8_recipe = get_quantization_recipe_from_name_string(args.fp8_recipe)
    else:
        fp8_recipe = None

    with te.autocast(
        enabled=args.use_fp8, recipe=fp8_recipe, mesh_resource=te.sharding.MeshResource()
    ):
        encoder = Net(num_embed)
        # We use nn.Embed, thus inputs need to be in int
        inputs = jnp.zeros(input_shape, dtype=jnp.int32)
        masks = jnp.zeros(mask_shape, dtype=jnp.uint8)
        var_collect = encoder.init(init_rngs, inputs, masks)
        tx = optax.adamw(args.lr)
        state = train_state.TrainState.create(
            apply_fn=encoder.apply, params=var_collect[PARAMS_KEY], tx=tx
        )

        if args.use_fp8:
            # Sanity-check that the traced train step actually contains FP8 ops.
            labels = jnp.zeros(label_shape, dtype=jnp.bfloat16)
            check_fp8(state, var_collect, inputs, masks, labels)

        if args.dry_run:
            # Single training step on dummy data, then exit early.
            labels = jnp.zeros(label_shape, dtype=jnp.bfloat16)
            rngs = {DROPOUT_KEY: dropout_rng, SR_KEY: sr_rng}
            train_step(state, inputs, masks, labels, var_collect, rngs)
            print("PASSED")
            return None

        for epoch in range(1, args.epochs + 1):
            # Split and reassign to 'rng' to ensure unique rng for each step
            rng, input_rng = jax.random.split(rng)
            rng, dropout_rng = jax.random.split(rng)
            rng, sr_rng = jax.random.split(rng)
            rngs = {INPUT_KEY: input_rng, DROPOUT_KEY: dropout_rng, SR_KEY: sr_rng}

            state, train_loss, train_accuracy, var_collect = train_epoch(
                state, train_ds, args.batch_size, rngs, var_collect
            )

            test_loss, test_accuracy = eval_model(
                state, test_ds, args.test_batch_size, var_collect, rngs
            )

            print(
                f"Epoch: {epoch:>2} "
                f"Train Loss: {train_loss:.6f} "
                f"Train Accuracy: {train_accuracy:.6f} "
                f"Test Loss: {test_loss:.6f} "
                f"Test Accuracy: {test_accuracy:.6f} "
            )

    return [train_loss, train_accuracy, test_loss, test_accuracy]


def encoder_parser(args):
    """Training settings.

    Args:
        args: list of CLI tokens to parse, or None to read ``sys.argv``.

    Returns:
        ``argparse.Namespace`` holding the training hyperparameters.
    """
    parser = argparse.ArgumentParser(description="JAX Encoder Example")
    parser.add_argument(
        "--batch-size",
        type=int,
        default=64,
        metavar="N",
        help="input batch size for training (default: 64)",
    )
    parser.add_argument(
        "--test-batch-size",
        type=int,
        default=64,
        metavar="N",
        help="input batch size for testing (default: 64)",
    )
    parser.add_argument(
        "--max-seq-len",
        type=int,
        default=32,
        metavar="N",
        help="maximum sequence length (default: 32)",
    )
    parser.add_argument(
        "--epochs",
        type=int,
        default=3,
        metavar="N",
        help="number of epochs to train (default: 3)",
    )
    parser.add_argument(
        "--lr",
        type=float,
        default=0.0001,
        metavar="LR",
        help="learning rate (default: 0.0001)",
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        default=False,
        help="quickly check a single pass",
    )
    parser.add_argument("--seed", type=int, default=0, metavar="S", help="random seed (default: 0)")
    parser.add_argument(
        "--use-fp8",
        action="store_true",
        default=False,
        help="Use FP8 for inference and training without recalibration",
    )
    # Fix: this was declared with action="store_true" and a string default,
    # which made it a boolean flag — passing a recipe name on the CLI was a
    # parse error, and passing the bare flag stored True instead of a recipe
    # string, breaking get_quantization_recipe_from_name_string downstream.
    parser.add_argument(
        "--fp8-recipe",
        type=str,
        default="DelayedScaling",
        help="Use FP8 recipe (default: DelayedScaling)",
    )

    return parser.parse_args(args)


class TestEncoder(unittest.TestCase):
    """Encoder unittests"""

    # Probe hardware support once at class-definition time; the reason strings
    # feed the skipIf decorators below.
    is_fp8_supported, fp8_reason = is_scaling_mode_supported(ScalingMode.DELAYED_TENSOR_SCALING)
    is_mxfp8_supported, mxfp8_reason = is_scaling_mode_supported(ScalingMode.MXFP8_1D_SCALING)
    is_nvfp4_supported, nvfp4_reason = is_scaling_mode_supported(ScalingMode.NVFP4_1D_SCALING)

    def setUp(self):
        """Run 3 epochs for testing"""
        self.args = encoder_parser(["--epochs", "3"])

    def _train_with_recipe(self, recipe_name):
        """Enable FP8 with the given recipe name and run the training loop."""
        self.args.use_fp8 = True
        self.args.fp8_recipe = recipe_name
        return train_and_evaluate(self.args)

    @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
    def test_te_bf16(self):
        """Test Transformer Engine with BF16"""
        metrics = train_and_evaluate(self.args)
        assert metrics[0] < 0.452
        assert metrics[1] > 0.788

    @unittest.skipIf(not is_fp8_supported, fp8_reason)
    def test_te_delayed_scaling_fp8(self):
        """Test Transformer Engine with DelayedScaling FP8"""
        metrics = self._train_with_recipe("DelayedScaling")
        assert metrics[0] < 0.457
        assert metrics[1] > 0.784

    @unittest.skipIf(not is_fp8_supported, fp8_reason)
    def test_te_current_scaling_fp8(self):
        """Test Transformer Engine with CurrentScaling FP8"""
        metrics = self._train_with_recipe("Float8CurrentScaling")
        assert metrics[0] < 0.461
        assert metrics[1] > 0.784

    @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason)
    def test_te_mxfp8(self):
        """Test Transformer Engine with MXFP8"""
        metrics = self._train_with_recipe("MXFP8BlockScaling")
        assert metrics[0] < 0.457
        assert metrics[1] > 0.784

    @unittest.skipIf(not is_nvfp4_supported, nvfp4_reason)
    def test_te_nvfp4(self):
        """Test Transformer Engine with NVFP4"""
        metrics = self._train_with_recipe("NVFP4BlockScaling")
        assert metrics[0] < 0.477
        assert metrics[1] > 0.769
Jeng Bai-Cheng's avatar
Jeng Bai-Cheng committed
394
395
396
397


if __name__ == "__main__":
    # Parse flags from sys.argv (parse_args(None)) and run the full loop.
    train_and_evaluate(encoder_parser(None))