test_single_gpu_mnist.py 11.9 KB
Newer Older
1
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
Jeng Bai-Cheng's avatar
Jeng Bai-Cheng committed
2
3
#
# See LICENSE for license information.
4
"""MNIST training on single GPU"""
Jeng Bai-Cheng's avatar
Jeng Bai-Cheng committed
5
6
7
import argparse
import unittest
from functools import partial
8
9
import sys
from pathlib import Path
Jeng Bai-Cheng's avatar
Jeng Bai-Cheng committed
10
11
12
13
14

import jax
import jax.numpy as jnp
import numpy as np
import optax
15
from datasets import load_dataset
Jeng Bai-Cheng's avatar
Jeng Bai-Cheng committed
16
17
18
19
from flax import linen as nn
from flax.training import train_state

import transformer_engine.jax as te
20
import transformer_engine.jax.flax as te_flax
21
from transformer_engine.jax.quantize import is_scaling_mode_supported, ScalingMode
22
23
24

# Put the directory one level above this file on the import path; the
# ``encoder.common`` import below presumably resolves from there.
DIR = str(Path(__file__).resolve().parent.parent)
sys.path.append(DIR)
25
26
27
28
29
30
31
from encoder.common import (
    is_bf16_supported,
    get_quantization_recipe_from_name_string,
    hf_login_if_available,
)

# Module-level side effect: runs once at import time, before get_datasets()
# calls load_dataset("mnist", ...). Helper comes from encoder.common; presumably
# it logs in to the Hugging Face Hub only when credentials are configured and is
# a no-op otherwise — NOTE(review): confirm in encoder/common.py.
hf_login_if_available()
Jeng Bai-Cheng's avatar
Jeng Bai-Cheng committed
32
33
34
35

# Geometry of one MNIST sample as fed to the network: height, width, channels.
IMAGE_H, IMAGE_W, IMAGE_C = 28, 28, 1

# Keys used for the RNG dicts and the Flax variable collections.
PARAMS_KEY, DROPOUT_KEY, INPUT_KEY = "params", "dropout", "input_rng"
Jeng Bai-Cheng's avatar
Jeng Bai-Cheng committed
39
40
41
42


class Net(nn.Module):
    """CNN model for MNIST.

    The convolutional trunk is always Flax ``nn.Conv`` computed in bfloat16.
    The fully connected head uses Transformer Engine ``DenseGeneral`` layers
    when ``use_te`` is set, and plain Flax ``nn.Dense`` layers otherwise.
    """

    use_te: bool = False

    @nn.compact
    def __call__(self, x, disable_dropout=False):
        # Choose the dense layer type together with its dtype. The dtype means
        # different things: TE uses it for parameter init, while Linen's
        # nn.Dense uses it as the computation dtype.
        if self.use_te:
            nn_Dense, dtype = te_flax.DenseGeneral, jnp.float32
        else:
            nn_Dense, dtype = nn.Dense, jnp.bfloat16

        # Two 3x3 conv + ReLU stages (32 then 64 channels). Submodules are
        # created in the same order as before, so Flax auto-names are unchanged.
        for num_features in (32, 64):
            conv = nn.Conv(features=num_features, kernel_size=(3, 3), strides=1, dtype=jnp.bfloat16)
            x = nn.relu(conv(x))
        x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = nn.Dropout(rate=0.25)(x, deterministic=disable_dropout)

        # Flatten all non-batch dimensions for the dense head.
        x = x.reshape(x.shape[0], -1)
        assert x.dtype == jnp.bfloat16

        x = nn.relu(nn_Dense(features=128, dtype=dtype)(x))
        x = nn.Dropout(rate=0.5)(x, deterministic=disable_dropout)

        # NOTE(review): two 32-wide dense layers back to back with no
        # nonlinearity in between compose to a single linear map; the last one
        # also fixes the logit width used by the loss in apply_model.
        for _ in range(2):
            x = nn_Dense(features=32, dtype=dtype)(x)
        assert x.dtype == jnp.bfloat16
        return x


@jax.jit
def apply_model(state, images, labels, var_collect, rngs=None):
    """Computes gradients, loss and accuracy for a single batch.

    Args:
        state: train state providing ``apply_fn`` and the current ``params``.
        images: batch of input images.
        labels: class labels for ``images``.
        var_collect: Flax variable collections; the ``PARAMS_KEY`` entry is
            replaced by ``state.params`` before the forward pass.
        rngs: RNG dict for stochastic layers. When given, the step runs in
            training mode (dropout enabled) and gradients are computed. When
            ``None``, dropout is disabled and no gradients are computed.

    Returns:
        ``(grads, loss, accuracy)``. ``grads`` has the same structure as
        ``var_collect`` (``None`` when ``rngs`` is ``None``). ``loss`` is the
        mean softmax cross-entropy. ``accuracy`` is the fraction of argmax
        predictions equal to ``labels``.
    """

    def loss_fn(variables, disable_dropout=False):
        logits = state.apply_fn(variables, images, disable_dropout, rngs=rngs)
        # Size the one-hot targets from the model head instead of hard-coding
        # the class count, so targets and logits can never disagree in width.
        one_hot = jax.nn.one_hot(labels, logits.shape[-1])
        loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
        return loss, logits

    var_collect = {**var_collect, PARAMS_KEY: state.params}

    if rngs is not None:
        # Differentiate w.r.t. every collection, not only params: the whole
        # gradient tree is handed back to the caller.
        grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
        (loss, logits), grads = grad_fn(var_collect)
    else:
        loss, logits = loss_fn(var_collect, disable_dropout=True)
        grads = None

    accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
    return grads, loss, accuracy


96
@jax.jit
def update_model(state, grads):
    """Update model params and FP8 meta.

    Feeds the ``PARAMS_KEY`` sub-tree of ``grads`` through the optimizer held
    by ``state``. The complete ``grads`` tree is returned unchanged; the
    training loop uses it as the variable collections for the next step.
    """
    new_state = state.apply_gradients(grads=grads[PARAMS_KEY])
    # NOTE(review): handing back the raw gradient tree as the next var_collect
    # presumably relies on TE emitting updated FP8 metadata through the
    # gradients of its non-param collections — confirm for the TE version used.
    return new_state, grads


103
def train_epoch(state, train_ds, batch_size, rngs, var_collect):
    """Train for a single epoch.

    The sample order is shuffled with ``rngs[INPUT_KEY]``. The data is cut into
    ``len(train_ds["image"]) // batch_size`` full batches, and any trailing
    partial batch is dropped.

    Returns:
        ``(state, avg_loss, avg_accuracy, var_collect)``. Loss and accuracy are
        the means over the batches of this epoch.
    """
    num_samples = len(train_ds["image"])
    num_steps = num_samples // batch_size
    shuffled = jax.random.permutation(rngs[INPUT_KEY], num_samples)
    # Discard the incomplete final batch, then view indices as (steps, batch).
    batches = shuffled[: num_steps * batch_size].reshape((num_steps, batch_size))

    losses, accuracies = [], []
    for batch_idx in batches:
        # NOTE(review): the same ``rngs`` (hence the same dropout key) is used
        # for every step of the epoch, so dropout masks presumably repeat
        # across steps — confirm this is intended.
        grads, loss, accuracy = apply_model(
            state,
            train_ds["image"][batch_idx, ...],
            train_ds["label"][batch_idx, ...],
            var_collect,
            rngs,
        )
        state, var_collect = update_model(state, grads)
        losses.append(loss)
        accuracies.append(accuracy)

    return state, np.mean(losses), np.mean(accuracies), var_collect


def eval_model(state, test_ds, batch_size, var_collect):
    """Evaluation loop.

    Calls ``apply_model`` without RNGs (dropout disabled, no gradients) on
    consecutive, unshuffled batches of ``test_ds``. Only full batches are
    evaluated, so the last ``len(test_ds["image"]) % batch_size`` samples are
    skipped.

    Returns:
        ``(avg_loss, avg_accuracy)``: unweighted means over the batches.
    """
    images = test_ds["image"]
    covered = (len(images) // batch_size) * batch_size
    losses = []
    accuracies = []
    for start in range(0, covered, batch_size):
        stop = start + batch_size
        _, loss, accuracy = apply_model(
            state, images[start:stop], test_ds["label"][start:stop], var_collect
        )
        losses.append(loss)
        accuracies.append(accuracy)
    return np.mean(losses), np.mean(accuracies)


def get_datasets():
    """Load MNIST train and test datasets into memory.

    Returns:
        ``(train_ds, test_ds)``. Each is a dict with two entries:
        ``"image"``, a float32 array reshaped to
        ``(num_samples, IMAGE_H, IMAGE_W, IMAGE_C)`` and divided by 255.0
        (i.e. scaled to [0, 1], assuming 8-bit pixels); and ``"label"``, as
        returned by the ``datasets`` numpy formatter.
    """

    def _load_split(split):
        # Fetch one split and convert it to in-memory numpy arrays.
        ds = load_dataset("mnist", split=split, trust_remote_code=True)
        ds.set_format(type="np")
        images = ds["image"]
        shape = (images.shape[0], IMAGE_H, IMAGE_W, IMAGE_C)
        return {
            "image": images.astype(np.float32).reshape(shape) / 255.0,
            "label": ds["label"],
        }

    return _load_split("train"), _load_split("test")
Jeng Bai-Cheng's avatar
Jeng Bai-Cheng committed
166
167
168
169


def check_fp8(state, var_collect, input_shape, label_shape):
    """Check if model includes FP8.

    Traces ``apply_model`` with bf16 placeholder inputs and no RNGs, i.e. the
    dropout-disabled forward path only. Asserts that the printed jaxpr
    mentions an FP8 dtype (``f8_e5m2`` or ``f8_e4m3``).
    """
    dummy_images = jnp.empty(input_shape, dtype=jnp.bfloat16)
    dummy_labels = jnp.empty(label_shape, dtype=jnp.bfloat16)
    jaxpr_text = str(jax.make_jaxpr(apply_model)(state, dummy_images, dummy_labels, var_collect))
    assert any(fp8_dtype in jaxpr_text for fp8_dtype in ("f8_e5m2", "f8_e4m3"))
Jeng Bai-Cheng's avatar
Jeng Bai-Cheng committed
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196


def train_and_evaluate(args):
    """Execute model training and evaluation loop.

    ``args`` is the namespace produced by ``mnist_parser``. Passing
    ``--use-fp8`` also switches ``args.use_te`` on; the namespace is modified
    in place.

    Returns:
        ``None`` for a dry run. Otherwise
        ``[train_loss, train_accuracy, test_loss, test_accuracy]`` from the
        last epoch.
    """
    print(args)

    # FP8 is only available through the Transformer Engine layers.
    if args.use_fp8:
        args.use_te = True

    train_ds, test_ds = get_datasets()

    rng = jax.random.PRNGKey(args.seed)
    rng, params_rng = jax.random.split(rng)
    rng, dropout_rng = jax.random.split(rng)
    init_rngs = {PARAMS_KEY: params_rng, DROPOUT_KEY: dropout_rng}

    input_shape = [args.batch_size, IMAGE_H, IMAGE_W, IMAGE_C]
    label_shape = [args.batch_size]

    fp8_recipe = (
        get_quantization_recipe_from_name_string(args.fp8_recipe) if args.use_fp8 else None
    )

    with te.autocast(
        enabled=args.use_fp8, recipe=fp8_recipe, mesh_resource=te.sharding.MeshResource()
    ):
        cnn = Net(args.use_te)
        var_collect = cnn.init(init_rngs, jnp.empty(input_shape, dtype=jnp.bfloat16))
        state = train_state.TrainState.create(
            apply_fn=cnn.apply,
            params=var_collect[PARAMS_KEY],
            tx=optax.sgd(args.lr, args.momentum),
        )

        if args.use_fp8:
            check_fp8(state, var_collect, input_shape, label_shape)

        if args.dry_run:
            # One training-mode step on placeholder data, just to exercise the graph.
            apply_model(
                state,
                jnp.empty(input_shape, dtype=jnp.bfloat16),
                jnp.empty(label_shape, dtype=jnp.bfloat16),
                var_collect,
                {DROPOUT_KEY: dropout_rng},
            )
            print("PASSED")
            return None

        # NOTE(review): with args.epochs < 1 the loop never runs and the final
        # return raises UnboundLocalError.
        for epoch in range(1, args.epochs + 1):
            rng, input_rng = jax.random.split(rng)
            rng, dropout_rng = jax.random.split(rng)
            state, train_loss, train_accuracy, var_collect = train_epoch(
                state,
                train_ds,
                args.batch_size,
                {INPUT_KEY: input_rng, DROPOUT_KEY: dropout_rng},
                var_collect,
            )
            test_loss, test_accuracy = eval_model(state, test_ds, args.test_batch_size, var_collect)

            print(
                f"Epoch: {epoch:>2} "
                f"Train Loss: {train_loss:.6f} "
                f"Train Accuracy: {train_accuracy:.6f} "
                f"Test Loss: {test_loss:.6f} "
                f"Test Accuracy: {test_accuracy:.6f} "
            )

    return [train_loss, train_accuracy, test_loss, test_accuracy]


def mnist_parser(args):
    """Training settings.

    Builds the command-line parser for the MNIST example and parses ``args``.

    Args:
        args: list of argument strings, or ``None`` to let argparse read
            ``sys.argv[1:]``.

    Returns:
        The parsed ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(description="JAX MNIST Example")
    parser.add_argument(
        "--batch-size",
        type=int,
        default=64,
        metavar="N",
        help="input batch size for training (default: 64)",
    )
    parser.add_argument(
        "--test-batch-size",
        type=int,
        default=800,
        metavar="N",
        help="input batch size for testing (default: 800)",
    )
    parser.add_argument(
        "--epochs",
        type=int,
        default=10,
        metavar="N",
        help="number of epochs to train (default: 10)",
    )
    parser.add_argument(
        "--lr",
        type=float,
        default=0.01,
        metavar="LR",
        help="learning rate (default: 0.01)",
    )
    parser.add_argument(
        "--momentum",
        type=float,
        default=0.9,
        metavar="M",
        help="Momentum (default: 0.9)",
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        default=False,
        help="quickly check a single pass",
    )
    parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)")
    parser.add_argument(
        "--use-fp8",
        action="store_true",
        default=False,
        help=(
            "Use FP8 for inference and training without recalibration. "
            "It also enables Transformer Engine implicitly."
        ),
    )
    # The recipe is a name string (e.g. "DelayedScaling", "Float8CurrentScaling",
    # "MXFP8BlockScaling"), so it must take a value. The former
    # action="store_true" rejected "--fp8-recipe NAME" and turned a bare flag
    # into True.
    parser.add_argument(
        "--fp8-recipe",
        type=str,
        default="DelayedScaling",
        metavar="RECIPE",
        help="FP8 recipe name (default: DelayedScaling)",
    )
    parser.add_argument(
        "--use-te", action="store_true", default=False, help="Use Transformer Engine"
    )

    return parser.parse_args(args)


class TestMNIST(unittest.TestCase):
    """MNIST unittests"""

    # Evaluated once at class-definition time; consumed by the skipIf decorators.
    is_fp8_supported, fp8_reason = is_scaling_mode_supported(ScalingMode.DELAYED_TENSOR_SCALING)
    is_mxfp8_supported, mxfp8_reason = is_scaling_mode_supported(ScalingMode.MXFP8_1D_SCALING)

    @classmethod
    def setUpClass(cls):
        """Parse the shared 5-epoch configuration.

        Every test reuses this namespace and sets ``use_te`` / ``use_fp8`` /
        ``fp8_recipe`` itself.
        """
        cls.args = mnist_parser(["--epochs", "5"])

    @staticmethod
    def verify(actual):
        """Check that loss and accuracy meet the targets.

        ``actual`` is ``[train_loss, train_accuracy, test_loss, test_accuracy]``
        as returned by ``train_and_evaluate``.
        """
        desired_training_loss = 0.055
        desired_training_accuracy = 0.98
        desired_test_loss = 0.045
        # Was 0.098: a threshold below 10-class chance level (0.1) made the
        # test-accuracy check vacuous.
        desired_test_accuracy = 0.98
        assert actual[0] < desired_training_loss
        assert actual[1] > desired_training_accuracy
        assert actual[2] < desired_test_loss
        assert actual[3] > desired_test_accuracy

    @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
    def test_te_bf16(self):
        """Test Transformer Engine with BF16"""
        self.args.use_te = True
        self.args.use_fp8 = False
        actual = train_and_evaluate(self.args)
        self.verify(actual)

    @unittest.skipIf(not is_fp8_supported, fp8_reason)
    def test_te_delayed_scaling_fp8(self):
        """Test Transformer Engine with DelayedScaling FP8"""
        self.args.use_fp8 = True
        self.args.fp8_recipe = "DelayedScaling"
        actual = train_and_evaluate(self.args)
        self.verify(actual)

    @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason)
    def test_te_mxfp8(self):
        """Test Transformer Engine with MXFP8"""
        self.args.use_fp8 = True
        self.args.fp8_recipe = "MXFP8BlockScaling"
        actual = train_and_evaluate(self.args)
        self.verify(actual)

    @unittest.skipIf(not is_fp8_supported, fp8_reason)
    def test_te_current_scaling_fp8(self):
        """Test Transformer Engine with CurrentScaling FP8"""
        self.args.use_fp8 = True
        self.args.fp8_recipe = "Float8CurrentScaling"
        actual = train_and_evaluate(self.args)
        self.verify(actual)

Jeng Bai-Cheng's avatar
Jeng Bai-Cheng committed
369
370
371

if __name__ == "__main__":
    # Script entry point: parse options from the command line and train.
    cli_args = mnist_parser(None)
    train_and_evaluate(cli_args)