# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import os
import tempfile
import unittest
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np
import optax
import tensorflow_datasets as tfds
from flax import linen as nn
from flax.training import train_state

from transformer_engine.common.recipe import Format as FP8Format
from transformer_engine.jax import DenseGeneral
from transformer_engine.jax.fp8 import FP8Helper
from utils import is_fp8_supported


class MLPNN(nn.Module):
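    """MLP for MNIST whose hidden layers optionally use TE's FP8 DenseGeneral."""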

    use_fp8_dense: bool = True

    @nn.compact
    def __call__(self, x):
        x = x.reshape((x.shape[0], -1))    # flatten
        x = nn.Dense(features=512)(x)
        x = nn.relu(x)

        features = [256, 256, 128]
        for feature in features:
            if self.use_fp8_dense:
                x = DenseGeneral(features=feature, transpose_batch_sequence=False,
                                 dtype=jnp.bfloat16, use_bias=True)(x)
            else:
                x = nn.Dense(features=feature)(x)
            x = nn.relu(x)

        x = nn.Dense(features=10, use_bias=True)(x)
        return x


def cross_entropy_loss(*, logits, labels):
    labels_onehot = jax.nn.one_hot(labels, num_classes=10)
    return optax.softmax_cross_entropy(logits=logits, labels=labels_onehot).mean()


def compute_metrics(*, logits, labels):
    loss = cross_entropy_loss(logits=logits, labels=labels)
    accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
    metrics = {
        'loss': loss,
        'accuracy': accuracy,
    }
    return metrics


def get_datasets():
    """Load MNIST train and test datasets into memory."""
    ds_builder = tfds.builder('mnist', data_dir="/tmp/tensorflow-datasets/downloads")
    ds_builder.download_and_prepare()
    train_ds = tfds.as_numpy(ds_builder.as_dataset(split='train', batch_size=-1))
    test_ds = tfds.as_numpy(ds_builder.as_dataset(split='test', batch_size=-1))
    train_ds['image'] = jnp.float32(train_ds['image']) / 255.
    test_ds['image'] = jnp.float32(test_ds['image']) / 255.
    return train_ds, test_ds


def create_train_state(rng, learning_rate, momentum, use_fp8_dense):
    """Creates initial `TrainState`."""
    mlp = MLPNN(use_fp8_dense=use_fp8_dense)
    variables = mlp.init(rng, jnp.ones([32, 28, 28, 1]))
    tx = optax.sgd(learning_rate, momentum)
    return train_state.TrainState.create(apply_fn=mlp.apply, params=variables['params'],
                                         tx=tx), variables


@partial(jax.jit, static_argnums=(3,))
def train_step(state, others, batch, use_fp8_dense):
    """Train for a single step."""

    def loss_fn(collections):
        logits = MLPNN(use_fp8_dense=use_fp8_dense).apply(collections, batch['image'])
        loss = cross_entropy_loss(logits=logits, labels=batch['label'])
        return loss, logits

    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (_, logits), grads = grad_fn(others)
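    # grads mirrors the full variable pytree passed to loss_fn; only its 'params'
    # entry is applied through the optimizer.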
    state = state.apply_gradients(grads=grads['params'])
    metrics = compute_metrics(logits=logits, labels=batch['label'])
    return state, metrics, grads


def train_epoch(state, variables, train_ds, batch_size, rng, use_fp8_dense):
    """Train for a single epoch."""
    train_ds_size = len(train_ds['image'])
    steps_per_epoch = train_ds_size // batch_size
    perms = jax.random.permutation(rng, train_ds_size)
    perms = perms[:steps_per_epoch * batch_size]    # skip incomplete batch
    perms = perms.reshape((steps_per_epoch, batch_size))
    batch_metrics = []
    for perm in perms:
        batch = {k: v[perm, ...] for k, v in train_ds.items()}
        state, metrics, grads = train_step(state, variables, batch, use_fp8_dense)

        # Fold the freshly optimized params back into the full variable collections;
        # for FP8, the meta collection is carried back through the gradient pytree.
        updated_coll = {'params': state.params}
        if use_fp8_dense:
            updated_coll[FP8Helper.FP8_COLLECTION_NAME] = \
                grads[FP8Helper.FP8_COLLECTION_NAME]
        variables = FP8Helper.update_collections(updated_coll, variables)
        batch_metrics.append(metrics)

        if use_fp8_dense:
            # Refresh the FP8 scaling factors from the recorded amax history.
            variables = FP8Helper.update_fp8_metas(variables)

    return state, variables


@partial(jax.jit, static_argnums=(2,))
def eval_step(variables, batch, use_fp8_dense):
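    """Compute loss and accuracy metrics for a single evaluation batch."""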
    logits = MLPNN(use_fp8_dense=use_fp8_dense).apply(variables, batch['image'])
    return compute_metrics(logits=logits, labels=batch['label'])


def eval_model(variables, test_ds, batch_size, use_fp8_dense):
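    """Evaluate over the whole test set and return (mean loss, accuracy)."""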
    test_ds_size = len(test_ds['image'])
    steps_per_epoch = test_ds_size // batch_size
    perms = np.arange(0, test_ds_size)
    perms = perms[:steps_per_epoch * batch_size]    # skip incomplete batch
    perms = perms.reshape((steps_per_epoch, batch_size))
    total_summary = {'correct': 0, 'loss': 0, 'total': 0}
    for perm in perms:
        batch = {k: v[perm, ...] for k, v in test_ds.items()}
        metrics = eval_step(variables, batch, use_fp8_dense)
        metrics = jax.device_get(metrics)
        summary = jax.tree_map(lambda x: x.item(), metrics)
        total_summary['correct'] += summary['accuracy'] * batch_size
        total_summary['loss'] += summary['loss'] * batch_size
        total_summary['total'] += batch_size
    return (total_summary['loss'] / total_summary['total'],
            total_summary['correct'] / total_summary['total'])


class TestMnist(unittest.TestCase):

    def setUp(self) -> None:
        os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
        self.learning_rate = 0.1
        self.momentum = 0.9

        self.num_epochs = 5
        self.batch_size = 512
        self.train_ds, self.test_ds = get_datasets()

        self.margin = 0.0
        self.num_fp8_layers = 3
        self.fp8_meta_update_interval = 1
        self.temp_file = tempfile.NamedTemporaryFile()    # pylint: disable=consider-using-with
        self.fp8_ckpt_path = self.temp_file.name

        self.seed = 0

        # Train the non-FP8 baseline first; FP8 runs must reach within 0.5%
        # (relative) of its accuracy.
        baseline_accuracy = self._mnist_baseline_runner()
        acc_rtol = 0.005
        self.target_accuracy = baseline_accuracy * (1. - acc_rtol)

    def tearDown(self):
        self.temp_file.close()

    @unittest.skipIf(not is_fp8_supported(), reason='GPU capability is not enough to run FP8')
    def test_mnist_e4m3(self):
        self._mnist_test_runner(FP8Format.E4M3)

    @unittest.skipIf(not is_fp8_supported(), reason='GPU capability is not enough to run FP8')
    def test_mnist_hybrid(self):
        self._mnist_test_runner(FP8Format.HYBRID)

    # Skipped for now: TE's Format recipe does not yet include BF16.
    # def test_mnist_bfloa16(self):
    #     self._mnist_test_runner(FP8Format.BFLOAT16)

    def _mnist_baseline_runner(self):
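        """Train the non-FP8 model to establish the target accuracy baseline."""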
        rng = jax.random.PRNGKey(self.seed)
        rng, init_rng = jax.random.split(rng)

        state, variables = create_train_state(init_rng, self.learning_rate, self.momentum, False)
        del init_rng

        _, accuracy = self._train_model(state, variables, self.num_epochs, rng, False)
        return accuracy

    def _mnist_test_runner(self, fp8_format):
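        # Enable FP8 globally for TE layers with the requested format and scaling margin.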
        FP8Helper.initialize(margin=self.margin, fp8_format=fp8_format)

        rng = jax.random.PRNGKey(self.seed)
        rng, init_rng = jax.random.split(rng)

        state, variables = create_train_state(init_rng, self.learning_rate, self.momentum, True)
        del init_rng

        _, test_accuracy = self._train_model(state, variables, self.num_epochs, rng, True)

        self.assertGreater(
            test_accuracy, self.target_accuracy,
            f"Convergence test failed on MNIST with FP8Format.{fp8_format.name}. "
            f"Test accuracy {test_accuracy:.4f} is lower than target {self.target_accuracy:.4f}")

        FP8Helper.finalize()

    def _train_model(self, state, variables, epochs, rng, use_fp8_dense):
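        """Run the training loop and return the final state and best test accuracy."""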
        max_test_acc = 0.0
        for _ in range(epochs):
            rng, input_rng = jax.random.split(rng)

            state, variables = train_epoch(state, variables, self.train_ds, self.batch_size,
                                           input_rng, use_fp8_dense)

            _, test_accuracy = eval_model(variables, self.test_ds, self.batch_size, use_fp8_dense)
            max_test_acc = max(max_test_acc, test_accuracy)
        return state, max_test_acc


if __name__ == '__main__':
    unittest.main()