# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""MNIST training on a single GPU."""

import argparse
import unittest
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np
import optax
from datasets import load_dataset
from flax import linen as nn
from flax.training import train_state

import transformer_engine.jax as te
import transformer_engine.jax.flax as te_flax

# Geometry of one MNIST sample: height x width pixels with a single channel.
IMAGE_H, IMAGE_W, IMAGE_C = 28, 28, 1

# Dictionary keys for the parameter collection and the PRNG keys passed around below.
PARAMS_KEY, DROPOUT_KEY, INPUT_KEY = "params", "dropout", "input_rng"


class Net(nn.Module):
    """Convolutional MNIST classifier with optional Transformer Engine dense layers.

    Attributes:
        use_te: when True the two hidden fully connected layers are built with
            ``te_flax.DenseGeneral``; otherwise ``nn.Dense`` is used.
    """

    use_te: bool = False

    @nn.compact
    def __call__(self, x, disable_dropout=False):
        """Return 10 unnormalized class scores per example of ``x``.

        Args:
            x: image batch; presumably NHWC ``(batch, 28, 28, 1)`` as built by
                ``get_datasets`` -- TODO confirm for other callers.
            disable_dropout: forwarded as ``deterministic`` to both dropout
                layers, so True turns dropout off.
        """
        hidden_dense = te_flax.DenseGeneral if self.use_te else nn.Dense

        # Two 3x3 bf16 convolutions (32 then 64 filters) with ReLU, then 2x2 max pooling.
        for width in (32, 64):
            conv = nn.Conv(features=width, kernel_size=(3, 3), strides=1, dtype=jnp.bfloat16)
            x = nn.relu(conv(x))
        x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = nn.Dropout(rate=0.25)(x, deterministic=disable_dropout)

        # Flatten everything except the leading (batch) axis.
        x = x.reshape(x.shape[0], -1)

        x = nn.relu(hidden_dense(features=128, dtype=jnp.bfloat16)(x))
        x = nn.Dropout(rate=0.5)(x, deterministic=disable_dropout)
        x = hidden_dense(features=16, dtype=jnp.bfloat16)(x)

        # NOTE(review): the 10-way output layer is always flax ``nn.Dense``, even with
        # use_te -- presumably because 10 is not a multiple of 16, which FP8 GEMMs
        # typically require; confirm before switching it to a TE layer.
        return nn.Dense(features=10, dtype=jnp.bfloat16)(x)


@jax.jit
def apply_model(state, images, labels, var_collect, rngs=None):
    """Computes gradients, loss and accuracy for a single batch.

    Args:
        state: train state; ``state.apply_fn`` runs the model and ``state.params``
            replaces the ``PARAMS_KEY`` entry of ``var_collect``.
        images: input batch handed to ``apply_fn``.
        labels: class ids, one-hot encoded over 10 classes for the loss.
        var_collect: all model variable collections (params are overridden).
        rngs: PRNG keys for ``apply_fn``. ``None`` selects evaluation mode:
            dropout disabled and no gradients computed.

    Returns:
        ``(grads, loss, accuracy)``. ``grads`` is ``None`` in evaluation mode,
        otherwise the gradient of the mean loss w.r.t. every collection in the
        variables dict (not only params).
    """

    def loss_fn(variables, disable_dropout=False):
        logits = state.apply_fn(variables, images, disable_dropout, rngs=rngs)
        targets = jax.nn.one_hot(labels, 10)
        per_example = optax.softmax_cross_entropy(logits=logits, labels=targets)
        return jnp.mean(per_example), logits

    variables = dict(var_collect)
    variables[PARAMS_KEY] = state.params

    # ``rngs`` being None is part of the pytree structure, so under jit this
    # branch is chosen at trace time.
    if rngs is None:
        loss, logits = loss_fn(variables, disable_dropout=True)
        grads = None
    else:
        (loss, logits), grads = jax.value_and_grad(loss_fn, has_aux=True)(variables)

    predictions = jnp.argmax(logits, -1)
    return grads, loss, jnp.mean(predictions == labels)


@jax.jit
def update_model(state, grads):
    """Update model params and FP8 meta.

    Applies the ``PARAMS_KEY`` part of ``grads`` through the optimizer held in
    ``state`` and returns ``grads`` unchanged as the second value, which
    ``train_epoch`` uses as the next ``var_collect``.

    Args:
        state: train state exposing ``apply_gradients``.
        grads: gradient tree from ``apply_model`` (all variable collections).

    Returns:
        ``(new_state, grads)``.
    """
    # NOTE(review): reusing the whole gradient tree as the next variable
    # collections presumably relies on Transformer Engine storing updated FP8
    # metadata in the non-param "gradients"; confirm against the TE version in use.
    # Its params entry is harmless downstream: apply_model substitutes state.params.
    return state.apply_gradients(grads=grads[PARAMS_KEY]), grads


def train_epoch(state, train_ds, batch_size, rngs, var_collect):
    """Run one epoch of optimizer steps over a shuffled ordering of ``train_ds``.

    Args:
        state: current train state.
        train_ds: dict with ``"image"`` and ``"label"`` arrays indexed along axis 0.
        batch_size: examples per step; the trailing ``len % batch_size`` examples
            of the shuffled order are dropped.
        rngs: dict of PRNG keys; ``rngs[INPUT_KEY]`` drives the shuffle and the
            whole dict is passed to ``apply_model``.
        var_collect: model variable collections, threaded through every step.

    Returns:
        ``(state, mean_loss, mean_accuracy, var_collect)`` after the epoch.
    """
    num_examples = len(train_ds["image"])
    num_steps = num_examples // batch_size

    shuffled = jax.random.permutation(rngs[INPUT_KEY], num_examples)
    # Keep only full batches: one row of example indices per step.
    index_rows = shuffled[: num_steps * batch_size].reshape((num_steps, batch_size))

    # NOTE(review): the same ``rngs`` dict (hence the same dropout key) is reused
    # for every step of the epoch, so dropout masks likely repeat across batches;
    # consider folding the step index into the dropout key.
    losses = []
    accuracies = []
    for idx in index_rows:
        grads, loss, accuracy = apply_model(
            state, train_ds["image"][idx, ...], train_ds["label"][idx, ...], var_collect, rngs
        )
        state, var_collect = update_model(state, grads)
        losses.append(loss)
        accuracies.append(accuracy)

    return state, np.mean(losses), np.mean(accuracies), var_collect


def eval_model(state, test_ds, batch_size, var_collect):
    """Evaluate ``state`` on ``test_ds`` in fixed-size batches, in dataset order.

    Only whole batches are scored: the trailing ``len(test_ds["image"]) % batch_size``
    examples are skipped (presumably to keep batch shapes static -- confirm).

    Returns:
        ``(mean_loss, mean_accuracy)``, each averaged over the scored batches.
    """
    num_batches = len(test_ds["image"]) // batch_size
    losses = []
    accuracies = []
    for step in range(num_batches):
        window = slice(step * batch_size, (step + 1) * batch_size)
        # No rngs: apply_model disables dropout and skips gradient computation.
        _, loss, accuracy = apply_model(
            state, test_ds["image"][window], test_ds["label"][window], var_collect
        )
        losses.append(loss)
        accuracies.append(accuracy)
    return np.mean(losses), np.mean(accuracies)


def _load_mnist_split(split):
    """Load one MNIST split as ``{"image": float32 NHWC in [0, 1], "label": ...}``."""
    # NOTE(review): trust_remote_code=True lets `datasets` run code shipped with
    # the hub dataset; keep it only if that source is trusted.
    ds = load_dataset("mnist", split=split, trust_remote_code=True)
    ds.set_format(type="np")
    # Read the image column once and reuse it for both the count and the conversion.
    images = ds["image"]
    shape = (images.shape[0], IMAGE_H, IMAGE_W, IMAGE_C)
    return {
        "image": images.astype(np.float32).reshape(shape) / 255.0,
        "label": ds["label"],
    }


def get_datasets():
    """Load MNIST train and test datasets into memory.

    Returns:
        ``(train_ds, test_ds)``: dicts with ``"image"`` (float32, shape
        ``(N, IMAGE_H, IMAGE_W, IMAGE_C)``, scaled by 1/255) and ``"label"`` arrays.
    """
    train_ds = _load_mnist_split("train")
    test_ds = _load_mnist_split("test")
    return train_ds, test_ds


def check_fp8(state, var_collect, input_shape, label_shape):
    """Check that the model's traced computation includes FP8 operations.

    Traces ``apply_model`` without rngs (its evaluation path) on bf16
    placeholder inputs and looks for the ``f8_`` substring, the spelling of FP8
    dtypes in a printed jaxpr.

    Raises:
        AssertionError: if no FP8 dtype appears in the jaxpr.
    """
    jaxpr_text = str(
        jax.make_jaxpr(apply_model)(
            state,
            jnp.empty(input_shape, dtype=jnp.bfloat16),
            jnp.empty(label_shape, dtype=jnp.bfloat16),
            var_collect,
        )
    )
    # Raise explicitly instead of using `assert`, so the check still runs under `python -O`.
    if "f8_" not in jaxpr_text:
        raise AssertionError("Expected FP8 ('f8_') operations in the jaxpr of apply_model.")


def train_and_evaluate(args):
    """Execute model training and evaluation loop.

    Args:
        args: namespace from ``mnist_parser``. ``args.use_te`` is set to True in
            place when ``args.use_fp8`` is set.

    Returns:
        ``None`` for a dry run; otherwise
        ``[train_loss, train_accuracy, test_loss, test_accuracy]`` of the last epoch.

    Raises:
        ValueError: if ``args.epochs`` is below 1 for a non-dry run.
    """
    print(args)

    # FP8 runs go through Transformer Engine layers, so --use-fp8 implies --use-te.
    if args.use_fp8:
        args.use_te = True

    # Fail before downloading data: with no epochs there are no last-epoch
    # metrics to return (the final return would hit unbound locals).
    if not args.dry_run and args.epochs < 1:
        raise ValueError(f"--epochs must be at least 1, got {args.epochs}")

    train_ds, test_ds = get_datasets()
    rng = jax.random.PRNGKey(args.seed)
    rng, params_rng = jax.random.split(rng)
    rng, dropout_rng = jax.random.split(rng)
    init_rngs = {PARAMS_KEY: params_rng, DROPOUT_KEY: dropout_rng}

    input_shape = [args.batch_size, IMAGE_H, IMAGE_W, IMAGE_C]
    label_shape = [args.batch_size]

    with te.fp8_autocast(enabled=args.use_fp8):
        cnn = Net(args.use_te)
        var_collect = cnn.init(init_rngs, jnp.empty(input_shape, dtype=jnp.bfloat16))
        tx = optax.sgd(args.lr, args.momentum)
        state = train_state.TrainState.create(
            apply_fn=cnn.apply, params=var_collect[PARAMS_KEY], tx=tx
        )

        if args.use_fp8:
            check_fp8(state, var_collect, input_shape, label_shape)

        if args.dry_run:
            # One training-mode step (dropout rng supplied) on placeholder data.
            apply_model(
                state,
                jnp.empty(input_shape, dtype=jnp.bfloat16),
                jnp.empty(label_shape, dtype=jnp.bfloat16),
                var_collect,
                {DROPOUT_KEY: dropout_rng},
            )
            print("PASSED")
            return None

        for epoch in range(1, args.epochs + 1):
            rng, input_rng = jax.random.split(rng)
            rng, dropout_rng = jax.random.split(rng)
            rngs = {INPUT_KEY: input_rng, DROPOUT_KEY: dropout_rng}

            state, train_loss, train_accuracy, var_collect = train_epoch(
                state, train_ds, args.batch_size, rngs, var_collect
            )
            test_loss, test_accuracy = eval_model(state, test_ds, args.test_batch_size, var_collect)

            print(
                f"Epoch: {epoch:>2} "
                f"Train Loss: {train_loss:.6f} "
                f"Train Accuracy: {train_accuracy:.6f} "
                f"Test Loss: {test_loss:.6f} "
                f"Test Accuracy: {test_accuracy:.6f} "
            )

    return [train_loss, train_accuracy, test_loss, test_accuracy]


def mnist_parser(args):
    """Build the example's command-line interface and parse ``args``.

    Args:
        args: list of argument strings, or ``None`` to parse ``sys.argv[1:]``.

    Returns:
        ``argparse.Namespace`` with the training settings.
    """
    parser = argparse.ArgumentParser(description="JAX MNIST Example")
    add = parser.add_argument

    # Integer hyper-parameters sharing the "N" metavar: (flag, default, help).
    int_options = (
        ("--batch-size", 64, "input batch size for training (default: 64)"),
        ("--test-batch-size", 800, "input batch size for testing (default: 800)"),
        ("--epochs", 10, "number of epochs to train (default: 10)"),
    )
    for flag, default, help_text in int_options:
        add(flag, type=int, default=default, metavar="N", help=help_text)

    add("--lr", type=float, default=0.01, metavar="LR", help="learning rate (default: 0.01)")
    add("--momentum", type=float, default=0.9, metavar="M", help="Momentum (default: 0.9)")
    add("--dry-run", action="store_true", default=False, help="quickly check a single pass")
    add("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)")
    add(
        "--use-fp8",
        action="store_true",
        default=False,
        help=(
            "Use FP8 for inference and training without recalibration. "
            "It also enables Transformer Engine implicitly."
        ),
    )
    add("--use-te", action="store_true", default=False, help="Use Transformer Engine")

    return parser.parse_args(args)


class TestMNIST(unittest.TestCase):
    """MNIST unittests"""

    # Evaluated once at class-definition time; feeds the skipIf on test_te_fp8.
    gpu_has_fp8, reason = te.fp8.is_fp8_available()

    @classmethod
    def setUpClass(cls):
        """Parse the 5-epoch training arguments shared by all tests."""
        cls.args = mnist_parser(["--epochs", "5"])

    @staticmethod
    def verify(actual):
        """Check that last-epoch loss and accuracy meet the targets.

        Args:
            actual: ``[train_loss, train_accuracy, test_loss, test_accuracy]`` as
                returned by ``train_and_evaluate``.

        Raises:
            AssertionError: if any metric misses its target.
        """
        desired_training_loss = 0.055
        desired_training_accuracy = 0.98
        desired_test_loss = 0.04
        # Chance level for 10 classes is 0.1; a threshold below it could never fail.
        desired_test_accuracy = 0.98
        train_loss, train_accuracy, test_loss, test_accuracy = actual
        assert train_loss < desired_training_loss, (
            f"train loss {train_loss} is not below {desired_training_loss}"
        )
        assert train_accuracy > desired_training_accuracy, (
            f"train accuracy {train_accuracy} is not above {desired_training_accuracy}"
        )
        assert test_loss < desired_test_loss, (
            f"test loss {test_loss} is not below {desired_test_loss}"
        )
        assert test_accuracy > desired_test_accuracy, (
            f"test accuracy {test_accuracy} is not above {desired_test_accuracy}"
        )

    def _train_with(self, **overrides):
        """Run ``train_and_evaluate`` on a copy of the shared args with ``overrides`` set.

        Copying keeps one test's settings (including the ``use_te`` flag that
        ``train_and_evaluate`` forces on for FP8) from leaking into another test.
        """
        args = argparse.Namespace(**vars(self.args))
        for name, value in overrides.items():
            setattr(args, name, value)
        return train_and_evaluate(args)

    def test_te_bf16(self):
        """Test Transformer Engine with BF16"""
        actual = self._train_with(use_te=True, use_fp8=False)
        self.verify(actual)

    @unittest.skipIf(not gpu_has_fp8, reason)
    def test_te_fp8(self):
        """Test Transformer Engine with FP8"""
        actual = self._train_with(use_fp8=True)
        self.verify(actual)


if __name__ == "__main__":
    # parse_args(None) reads the process command line (sys.argv[1:]).
    cli_args = mnist_parser(None)
    train_and_evaluate(cli_args)