Unverified Commit 5992e03d authored by Jeng Bai-Cheng, committed by GitHub

[JAX] Add TE examples (#108)

* refactor JAX examples
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* fix doc-string
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* add dp example
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* refactor
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* fix params_axes_pspec
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* Add model parallel example and refactor
Update readme
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* align code and readme
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* update verification
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* add mask
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* num_gpu is configurable
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* update readme
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* update readme
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* solve pylint issue
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* ignore markdown and txt file from license check
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* Update README.md
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* add flax into requirements.txt
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

---------
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>
parent b7acb6e1
# Transformer Engine Examples #
This folder contains simple examples introducing Transformer Engine and FP8 training usage.
**Examples Outline**
* MNIST training: Training on the MNIST dataset is a good starting point to learn how to use Transformer Engine and enable FP8 training.
* Encoder training: The encoder examples show how to scale up training to multiple GPUs with Transformer Engine.
# Basic Transformer Encoder Example with Optional FP8 #
This example uses a Transformer encoder to demonstrate Transformer Engine usage, with a focus on scaling up training to multiple GPUs. We highly recommend studying the [MNIST example of the Transformer Engine](/examples/jax/mnist) before reading this one. The Transformer Engine is built on top of [Flax](https://github.com/google/flax), so the examples use `pjit` to set up multi-GPU training. For basic `pjit` usage, refer to [Scale up Flax Modules on multiple devices with pjit](https://flax.readthedocs.io/en/latest/guides/flax_on_pjit.html).
## Single GPU ##
1. Setup dataset: This is done by using the `tfds` library to download the GLUE/CoLA dataset and `nltk` to tokenize the sentences. Because this example focuses on Transformer Engine usage, a simple algorithm is used to convert tokens to INT32 tensors as input to the embedding layer. The `get_datasets` and `data_preprocess` routines are used for this purpose.
2. Define model: The `Net` class is a small Transformer encoder model for sentence classification. Transformer Engine provides `te.TransformerLayer` as the encoder block, as well as `te.DenseGeneral`. The structure of the encoder block follows [Scaling Up Models and Data with t5x and seqio](https://arxiv.org/abs/2203.17189).
3. Build training loop: `train_and_evaluate` is the main routine that initializes the model and starts training and evaluation. Use the `fp8_autocast` context manager to enable FP8 training, and check whether the variable collection `var_collect` contains `Float8`.
4. Training process: In `train_step`, combine the FP8 metadata and the latest model parameters into `var_collect` as a frozen dictionary and pass it to the gradient function. Then call `te.update_fp8_metas` to update the FP8 metadata. The number of training steps between FP8 metadata updates can be customized; in this example, it is updated every step (see the sketch after this list).
5. Evaluating process: As in the training process, if FP8 computing is enabled, the FP8 metadata needs to be in `var_collect`, which is then passed to the loss function.
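The core of steps 4 and 5 is the following pattern, condensed from `train_step` in `test_single_gpu_encoder.py` (a sketch of the relevant lines, not the full function; `loss_fn`, `state`, `use_fp8`, and the other names are as defined in that file):
```python
# Merge the FP8 metadata with the latest parameters before taking gradients.
var_collect = FrozenDict({**var_collect, PARAMS_KEY: state.params})
grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
(loss, logits), grads = grad_fn(var_collect)
# Split the parameter gradients out of the collection and apply them.
var_collect, grads = grads.pop(PARAMS_KEY)
state = state.apply_gradients(grads=grads)
# Refresh the FP8 scaling metadata; in this example this happens every step.
if use_fp8:
    var_collect = te.update_fp8_metas(var_collect)
```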
### Run ###
```bash
python test_single_gpu_encoder.py
python test_single_gpu_encoder.py --use-fp8
```
## Multiple GPU with Data Parallelism ##
1. Data parallelism (DP) splits a mini-batch across multiple devices, and each device holds a complete copy of the model parameters. In this example, the first dimension of the input tensor is `batch_size`, which is 64 by default; with 8 GPUs training the model, each device takes 8 sentences at once. This "dividing" is called "sharding" in the JAX documentation.
2. In order to let JAX know how to do the sharding, a `device_mesh` needs to be defined and each axis needs to be named. A common convention is to name the axis `data`, meaning the mesh dimension used for data-parallel sharding of the batch dimension of inputs and activations. The first argument of `te.ShardingResource` is the name of the device axis used for data parallelism.
3. On the model side, the logical axes of each weight tensor of the model can be named. `te.TransformerLayer` has default names, which are stored under the `params_axes` key of `abs_var_collect`, a collection of variables returned by `jax.eval_shape(encoder.init, ...)`. `te.DenseGeneral` doesn't have default named axes because it is generic. Also, data-parallel sharding doesn't divide weight tensors, so named axes are not required in this case. However, because `te.DenseGeneral` is based on [XLA custom-call](https://www.tensorflow.org/xla/custom_call) and [xmap](https://jax.readthedocs.io/en/latest/notebooks/xmap_tutorial.html), `sharding_type` must be set so that the weights and xmap are mapped correctly.
4. The next step is to create sharding rules, mapping logical axes to device axes. Calling `te.extend_logical_axis_rules` under fp8_autocast returns a list of mapping pairs, such as `(('batch', 'data'), ...)`, where the first element is the logical axis and the second is the device axis.
5. Refer to the structure of `abs_var_collect['params']` and `abs_var_collect['params_axes']` to set up the `PartitionSpec` for pjit. All logical axes should be replaced by device axes. If a value of a PartitionSpec is None, that means no sharding, i.e., the data is broadcast to every device. Note that the `params_axes` attribute is provided by Transformer Engine; Flax modules, such as `nn.Embed`, don't have it. For `nn.Embed`, assigning an empty PartitionSpec is fine because each device has its own embedding layer in DP mode. The `get_params_pspec` routine is used for this purpose. Because each device has a complete model in DP mode, all values of the PartitionSpec in `params_pspec` should be None. This will be different in the model parallelism example.
6. Pass `encoder.init` and `params_pspec` to pjit to get a compiled function, `pjit_encoder_init`, and use it to initialize the model, so that JAX knows how to do the sharding.
7. The `train_step` and `eval_step` routines also need to be compiled by pjit. Thus, every input and output argument that contains a tensor has to have a `PartitionSpec` set up. For instance, `inputs_pspec` is `PartitionSpec('data', None)` because the input shape is (batch size, sequence length). The rest of the workflow is similar to the previous example (see the sketch after this list).
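Condensed from `train_and_evaluate` in `test_multigpu_encoder.py`, the DP setup looks roughly like this (a sketch, not the full routine; all names, including the `get_params_pspec` helper, are defined in that file):
```python
device_mesh = mesh_utils.create_device_mesh((args.num_gpu,))
with jax.sharding.Mesh(devices=device_mesh, axis_names=(DEVICE_DP_AXIS,)):
    with te.fp8_autocast(args.use_fp8, sharding_resource=te.ShardingResource(DEVICE_DP_AXIS)):
        # Abstract init to obtain shapes and logical axis names, then build the sharding rules.
        abs_var_collect = jax.eval_shape(encoder.init, init_rngs, inputs, masks)
        sharding_rules = te.extend_logical_axis_rules(tuple())
        params_pspec = get_params_pspec(sharding_rules, abs_var_collect)
        # In DP mode only the batch dimension of inputs and masks is sharded.
        inputs_pspec = jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None)
        masks_pspec = jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None, None, None)
        in_shardings = (None, inputs_pspec, masks_pspec)
        out_shardings = FrozenDict({key: params_pspec if key == PARAMS_KEY else None
                                    for key in abs_var_collect})
        pjit_encoder_init = pjit(encoder.init, in_shardings, out_shardings)
        var_collect = pjit_encoder_init(init_rngs, inputs, masks)
```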
### Run ###
```bash
python test_multigpu_encoder.py
python test_multigpu_encoder.py --use-fp8
```
## Multiple GPU with Model Parallelism ##
1. Model parallelism, also known as tensor parallelism (TP), splits a model across multiple devices, so each device holds only part of the model parameters. This example inherits from the previous DP example, but splits the model across two devices.
2. To set up the device mesh for TP, add a new named axis called `model`, which is used for sharding the parameters of the model across devices. This example splits the model into two parts (`num_gpu_tp = 2`), so each device holds only half of the model.
3. On the model side, `te.TransformerLayer` doesn't need additional settings because it already has default axis names; it is divided along `DEVICE_TP_AXIS` at model initialization. For TP, the first `te.DenseGeneral` is divided by columns and the second one by rows. Because `te.DenseGeneral` doesn't have default named axes, the names must be set manually by passing the `kernel_axes` and `bias_axes` arguments (see the sketch after the debugging tips below). The rest of the workflow is similar to the previous example.
4. Tips for debugging TP:
* Use [inspect_array_sharding](https://jax.readthedocs.io/en/latest/_autosummary/jax.debug.inspect_array_sharding.html) or [visualize_array_sharding](https://jax.readthedocs.io/en/latest/_autosummary/jax.debug.visualize_array_sharding.html) to check the sharding of activations and weights.
* Check the shape of each device buffer of a weight tensor, for instance `var_collect['params']['DenseGeneral_0']['kernel'].device_buffers[device_id].shape`, where `device_id` is an integer. If a weight tensor's shape is (256, 256) and you intend to split it across two devices along the second dimension, then the shape returned by `device_buffers` should be (256, 128).
* Dump the XLA HLO by setting `XLA_FLAGS` and check whether it contains unexpected `all-gather` operations.
```python
import os
os.environ['XLA_FLAGS'] = "--xla_dump_hlo_as_proto --xla_dump_hlo_as_text --xla_dump_hlo_as_html --xla_dump_to=<path to store XLA HLO>"
```
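The manual axis naming described in step 3, excerpted from the `Net` module in `test_model_parallel_encoder.py` (`NAMED_BROADCAST_AXIS` and `NAMED_TP_AXIS` are logical names that the customized sharding rules in that file map onto the device mesh):
```python
# First projection: the kernel is split along its output (column) dimension.
x = te.DenseGeneral(features=256,
                    kernel_axes=(NAMED_BROADCAST_AXIS, NAMED_TP_AXIS),
                    bias_axes=(NAMED_TP_AXIS,),
                    sharding_type=te.ShardingType.DP_TP_COL,
                    dtype=jnp.bfloat16)(x)
# Second projection: the kernel is split along its input (row) dimension.
x = te.DenseGeneral(features=256,
                    kernel_axes=(NAMED_TP_AXIS, NAMED_BROADCAST_AXIS),
                    bias_axes=(NAMED_BROADCAST_AXIS,),
                    sharding_type=te.ShardingType.DP_TP_ROW,
                    dtype=jnp.bfloat16)(x)
```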
### Run ###
```bash
python test_model_parallel_encoder.py
python test_model_parallel_encoder.py --use-fp8
```
flax
nltk
optax
tensorflow-datasets
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
""" Encoder training on multi-GPU with tensor parallelism"""
import argparse
import unittest
from functools import partial
import jax
import jax.numpy as jnp
import nltk
import numpy as np
import optax
import tensorflow_datasets as tfds
from cuda import cudart
from flax import linen as nn
from flax.core.frozen_dict import FrozenDict
from flax.training import train_state
from jax.experimental import mesh_utils
from jax.experimental.pjit import pjit
import transformer_engine.jax as te
DEVICE_DP_AXIS = 'data'
DEVICE_TP_AXIS = 'model'
NAMED_BROADCAST_AXIS = 'my_broadcast_axis'
NAMED_TP_AXIS = 'my_tp_axis'
PARAMS_KEY = 'params'
PARAMS_AXES_KEY = PARAMS_KEY + '_axes'
DROPOUT_KEY = 'dropout'
INPUT_KEY = 'input_rng'
def check_num_gpu(desired_num_gpu):
"""Check that the number of GPUs is correct."""
actual_num_gpu = len(jax.local_devices())
assert actual_num_gpu == desired_num_gpu, f"Number of GPUs mismatch: " \
f"{desired_num_gpu} GPUs are assigned, but the actual number of GPUs is {actual_num_gpu}"
def gpu_has_fp8():
"""Check if the GPU has FP8."""
cudaSuccess = cudart.cudaError_t.cudaSuccess
ret, gpu_id = cudart.cudaGetDevice()
assert ret == cudaSuccess
flag = cudart.cudaDeviceAttr.cudaDevAttrComputeCapabilityMajor
_, major = cudart.cudaDeviceGetAttribute(flag, gpu_id)
flag = cudart.cudaDeviceAttr.cudaDevAttrComputeCapabilityMinor
_, minor = cudart.cudaDeviceGetAttribute(flag, gpu_id)
sm_arch = major * 10 + minor
return sm_arch >= 89
class Net(nn.Module):
"""NLP Encoder"""
num_embed: int
@nn.compact
def __call__(self, x, mask, disable_dropout=False):
x = nn.Embed(num_embeddings=self.num_embed, features=256, dtype=jnp.bfloat16)(x)
te_Encoder = partial(te.TransformerLayer,
hidden_size=256,
mlp_hidden_size=1024,
num_attention_heads=8,
hidden_dropout=0.1,
attention_dropout=0.1,
dropout_rng_name=DROPOUT_KEY,
layer_type=te.TransformerLayerType.ENCODER,
enable_relative_embedding=False,
dtype=jnp.bfloat16)
x = te_Encoder()(x, attention_mask=mask, deterministic=disable_dropout)
x = x.reshape(x.shape[0], -1)
x = te.DenseGeneral(features=256,
kernel_axes=(NAMED_BROADCAST_AXIS, NAMED_TP_AXIS),
bias_axes=(NAMED_TP_AXIS,),
sharding_type=te.ShardingType.DP_TP_COL,
dtype=jnp.bfloat16)(x)
x = te.DenseGeneral(features=256,
kernel_axes=(NAMED_TP_AXIS, NAMED_BROADCAST_AXIS),
bias_axes=(NAMED_BROADCAST_AXIS,),
sharding_type=te.ShardingType.DP_TP_ROW,
dtype=jnp.bfloat16)(x)
x = nn.Dense(features=2, dtype=jnp.bfloat16)(x)
return x
def train_step(state, inputs, masks, labels, var_collect, rngs, use_fp8):
"""Computes gradients, loss and accuracy for a single batch."""
def loss_fn(var_collect, disable_dropout=False):
logits = state.apply_fn(var_collect, inputs, masks, disable_dropout, rngs=rngs)
one_hot = jax.nn.one_hot(labels, 2)
loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
return loss, logits
var_collect = FrozenDict({**var_collect, PARAMS_KEY: state.params})
grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
(loss, logits), grads = grad_fn(var_collect)
accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
var_collect, grads = grads.pop(PARAMS_KEY)
state = state.apply_gradients(grads=grads)
if use_fp8:
var_collect = te.update_fp8_metas(var_collect)
return state, loss, accuracy, var_collect
def train_epoch(state, train_ds, batch_size, rngs, var_collect, use_fp8, train_fn):
"""Train for a single epoch."""
train_ds_size = len(train_ds['sentence'])
steps_per_epoch = train_ds_size // batch_size
perms = jax.random.permutation(rngs[INPUT_KEY], train_ds_size)
perms = perms[:steps_per_epoch * batch_size] # skip incomplete batch
perms = perms.reshape((steps_per_epoch, batch_size))
epoch_loss = []
epoch_accuracy = []
for perm in perms:
batch_inputs = train_ds['sentence'][perm, ...]
batch_masks = train_ds['mask'][perm, ...]
batch_labels = train_ds['label'][perm, ...]
state, loss, accuracy, var_collect = train_fn(state, batch_inputs, batch_masks,
batch_labels, var_collect, rngs, use_fp8)
epoch_loss.append(loss)
epoch_accuracy.append(accuracy)
avg_loss = np.mean(epoch_loss)
avg_accuracy = np.mean(epoch_accuracy)
return state, avg_loss, avg_accuracy, var_collect
def eval_step(state, inputs, masks, labels, var_collect):
"""Computes loss and accuracy for a single batch."""
def loss_fn(var_collect, disable_dropout=False):
logits = state.apply_fn(var_collect, inputs, masks, disable_dropout)
one_hot = jax.nn.one_hot(labels, 2)
loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
return loss, logits
var_collect = FrozenDict({**var_collect, PARAMS_KEY: state.params})
loss, logits = loss_fn(var_collect, disable_dropout=True)
accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
return loss, accuracy
def eval_model(state, test_ds, batch_size, var_collect, eval_fn):
"""Evaluation loop."""
test_ds_size = len(test_ds['sentence'])
num_steps = test_ds_size // batch_size
valid_size = num_steps * batch_size
all_loss = []
all_accuracy = []
for batch_start in range(0, valid_size, batch_size):
batch_end = batch_start + batch_size
batch_inputs = test_ds['sentence'][batch_start:batch_end]
batch_masks = test_ds['mask'][batch_start:batch_end]
batch_labels = test_ds['label'][batch_start:batch_end]
loss, accuracy = eval_fn(state, batch_inputs, batch_masks, batch_labels, var_collect)
all_loss.append(loss)
all_accuracy.append(accuracy)
avg_loss = np.mean(all_loss)
avg_accuracy = np.mean(all_accuracy)
return avg_loss, avg_accuracy
def data_preprocess(dataset, vocab, word_id, max_seq_len):
"""Convert tokens to numbers."""
nltk.download('punkt')
dataset_size = len(dataset['sentence'])
output = np.zeros((dataset_size, max_seq_len), dtype=np.int32)
mask_3d = np.empty((dataset_size, max_seq_len, max_seq_len), dtype=np.uint8)
for j, sentence in enumerate(dataset['sentence']):
tokens = nltk.word_tokenize(sentence.decode("utf-8"))
tensor = output[j]
mask_1d = np.zeros((1, max_seq_len), dtype=np.uint8)
for i, word in enumerate(tokens):
if i >= max_seq_len:
break
if word not in vocab:
vocab[word] = word_id
tensor[i] = word_id
word_id = word_id + 1
else:
tensor[i] = vocab[word]
mask_1d[0, i] = 1
mask_2d = mask_3d[j]
np.dot(mask_1d.T, mask_1d, out=mask_2d)
np.subtract(1, mask_2d, out=mask_2d)
dataset['sentence'] = output
dataset['label'] = dataset['label'].astype(np.float32)
dataset['mask'] = mask_3d.reshape((dataset_size, 1, max_seq_len, max_seq_len))
return dataset, vocab, word_id
def get_datasets(max_seq_len):
"""Load GLUE train and test datasets into memory."""
vocab = {}
word_id = 0
dataset = 'glue/cola'
train_ds = tfds.as_numpy(tfds.load(dataset, split='train', batch_size=-1))
train_ds, vocab, word_id = data_preprocess(train_ds, vocab, word_id, max_seq_len)
test_ds = tfds.as_numpy(tfds.load(dataset, split='validation', batch_size=-1))
test_ds, vocab, word_id = data_preprocess(test_ds, vocab, word_id, max_seq_len)
return train_ds, test_ds, word_id
def check_fp8(state, var_collect, inputs, masks, labels):
"Check if model includes FP8."
rngs = {DROPOUT_KEY: jax.random.PRNGKey(0)}
assert "Float8" in str(
jax.make_jaxpr(train_step, static_argnums=6)(state, inputs, masks, labels, var_collect,
rngs, True))
def get_params_pspec(sharding_rules, abs_var_collect):
"""Refer params to create params partition spec"""
rules_dict = {}
for key, value in sharding_rules:
rules_dict[key] = value
def to_device_axis(logical_axis):
partitions = [rules_dict[key] for key in logical_axis]
return jax.sharding.PartitionSpec(*partitions)
params_axes = abs_var_collect.get(PARAMS_AXES_KEY, {})
params_axes_pspec = jax.tree_map(to_device_axis, nn.partitioning.get_axis_names(params_axes))
params_pspec = jax.tree_map(lambda x: jax.sharding.PartitionSpec(), abs_var_collect[PARAMS_KEY])
params_pspec = FrozenDict({**params_pspec, **params_axes_pspec})
return params_pspec
def get_state_pspec(state, params_pspec):
"""Refer params_pspec to create state partition spec"""
def replace_params(x):
return params_pspec if isinstance(x, FrozenDict) else None
state_pspec = jax.tree_map(replace_params, state, is_leaf=lambda x: isinstance(x, FrozenDict))
return state_pspec
def train_and_evaluate(args):
"""Execute model training and evaluation loop."""
print(args)
check_num_gpu(args.num_gpu)
if args.use_fp8:
assert gpu_has_fp8(), "GPU needs to support FP8."
num_gpu_tp = 2
if args.num_gpu % num_gpu_tp == 0:
num_gpu_dp = args.num_gpu // num_gpu_tp
else:
num_gpu_dp = 1
num_gpu_tp = 1
assert args.batch_size % num_gpu_dp == 0, f"Batch size needs to be multiple of {num_gpu_dp}"
assert args.test_batch_size % num_gpu_dp == 0, \
f"Test batch size needs to be multiple of {num_gpu_dp}"
device_mesh = mesh_utils.create_device_mesh((num_gpu_dp, num_gpu_tp))
with jax.sharding.Mesh(devices=device_mesh, axis_names=(DEVICE_DP_AXIS, DEVICE_TP_AXIS)):
rng = jax.random.PRNGKey(args.seed)
rng, params_rng = jax.random.split(rng)
rng, dropout_rng = jax.random.split(rng)
init_rngs = {PARAMS_KEY: params_rng, DROPOUT_KEY: dropout_rng}
input_shape = [args.batch_size, args.max_seq_len]
mask_shape = [args.batch_size, 1, args.max_seq_len, args.max_seq_len]
label_shape = [args.batch_size]
with te.fp8_autocast(args.use_fp8,
sharding_resource=te.ShardingResource(DEVICE_DP_AXIS, DEVICE_TP_AXIS)):
train_ds, test_ds, num_embed = get_datasets(args.max_seq_len)
encoder = Net(num_embed)
inputs = jnp.zeros(input_shape, dtype=jnp.int32)
masks = jnp.zeros(mask_shape, dtype=jnp.uint8)
abs_var_collect = jax.eval_shape(encoder.init, init_rngs, inputs, masks)
customized_rules = ((NAMED_BROADCAST_AXIS, None), (NAMED_TP_AXIS, DEVICE_TP_AXIS))
sharding_rules = te.extend_logical_axis_rules(tuple()) + customized_rules
params_pspec = get_params_pspec(sharding_rules, abs_var_collect)
inputs_pspec = jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None)
masks_pspec = jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None, None, None)
in_shardings = (None, inputs_pspec, masks_pspec)
out_shardings = FrozenDict({key: params_pspec if key == PARAMS_KEY else None \
for key in abs_var_collect})
pjit_encoder_init = pjit(encoder.init, in_shardings, out_shardings)
var_collect = pjit_encoder_init(init_rngs, inputs, masks)
optimizer = optax.adamw(args.lr)
var_collect, params = var_collect.pop(PARAMS_KEY)
state = train_state.TrainState.create(apply_fn=encoder.apply,
params=params,
tx=optimizer)
state_pspec = get_state_pspec(state, params_pspec)
labels_pspec = jax.sharding.PartitionSpec(DEVICE_DP_AXIS,)
in_shardings = (state_pspec, inputs_pspec, masks_pspec, labels_pspec, None, None)
out_shardings = (state_pspec, None, None, None)
pjit_train_step = pjit(train_step, in_shardings, out_shardings, static_argnums=(6,))
in_shardings = (state_pspec, inputs_pspec, masks_pspec, labels_pspec, None)
out_shardings = (None, None)
pjit_eval_step = pjit(eval_step, in_shardings, out_shardings)
if args.use_fp8:
labels = jnp.zeros(label_shape, dtype=jnp.bfloat16)
check_fp8(state, var_collect, inputs, masks, labels)
if args.dry_run:
labels = jnp.zeros(label_shape, dtype=jnp.bfloat16)
rngs = {DROPOUT_KEY: dropout_rng}
pjit_train_step(state, inputs, masks, labels, var_collect, rngs, args.use_fp8)
print("PASSED")
return None
for epoch in range(1, args.epochs + 1):
rng, input_rng = jax.random.split(rng)
rng, dropout_rng = jax.random.split(rng)
rngs = {INPUT_KEY: input_rng, DROPOUT_KEY: dropout_rng}
state, train_loss, train_accuracy, var_collect = train_epoch(
state, train_ds, args.batch_size, rngs, var_collect, args.use_fp8,
pjit_train_step)
test_loss, test_accuracy = eval_model(state, test_ds, args.test_batch_size,
var_collect, pjit_eval_step)
print(f"Epoch: {epoch:>2} "
f"Train Loss: {train_loss:.6f} "
f"Train Accuracy: {train_accuracy:.6f} "
f"Test Loss: {test_loss:.6f} "
f"Test Accuracy: {test_accuracy:.6f} ")
return [train_loss, train_accuracy, test_loss, test_accuracy]
def encoder_parser(args):
"""Training settings."""
parser = argparse.ArgumentParser(description="JAX Encoder Example")
parser.add_argument(
"--num-gpu",
type=int,
default=8,
metavar="N",
help="number of GPUs (default: 8)",
)
parser.add_argument(
"--batch-size",
type=int,
default=64,
metavar="N",
help="input batch size for training (default: 64)",
)
parser.add_argument(
"--test-batch-size",
type=int,
default=64,
metavar="N",
help="input batch size for testing (default: 64)",
)
parser.add_argument(
"--max-seq-len",
type=int,
default=32,
metavar="N",
help="maximum sequence length (default: 32)",
)
parser.add_argument(
"--epochs",
type=int,
default=3,
metavar="N",
help="number of epochs to train (default: 3)",
)
parser.add_argument(
"--lr",
type=float,
default=0.0001,
metavar="LR",
help="learning rate (default: 0.0001)",
)
parser.add_argument(
"--dry-run",
action="store_true",
default=False,
help="quickly check a single pass",
)
parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)")
parser.add_argument("--use-fp8",
action="store_true",
default=False,
help="Use FP8 for inference and training without recalibration")
return parser.parse_args(args)
class TestEncoder(unittest.TestCase):
"""Encoder unittests"""
@classmethod
def setUpClass(cls):
"""Run 3 epochs for testing"""
num_gpu = len(jax.local_devices())
if num_gpu % 2 != 0:
num_gpu = 1
cls.args = encoder_parser(["--epochs", "3", "--num-gpu", str(num_gpu)])
def test_te_bf16(self):
"""Test Transformer Engine with BF16"""
actual = train_and_evaluate(self.args)
assert actual[0] < 0.45 and actual[1] > 0.79
@unittest.skipIf(not gpu_has_fp8(), reason='GPU capability is not enough to run FP8')
def test_te_fp8(self):
"""Test Transformer Engine with FP8"""
self.args.use_fp8 = True
actual = train_and_evaluate(self.args)
assert actual[0] < 0.45 and actual[1] > 0.79
if __name__ == "__main__":
train_and_evaluate(encoder_parser(None))
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
""" Encoder training on multi-GPU with data parallelism"""
import argparse
import unittest
from functools import partial
import jax
import jax.numpy as jnp
import nltk
import numpy as np
import optax
import tensorflow_datasets as tfds
from cuda import cudart
from flax import linen as nn
from flax.core.frozen_dict import FrozenDict
from flax.training import train_state
from jax.experimental import mesh_utils
from jax.experimental.pjit import pjit
import transformer_engine.jax as te
DEVICE_DP_AXIS = 'data'
PARAMS_KEY = 'params'
PARAMS_AXES_KEY = PARAMS_KEY + '_axes'
DROPOUT_KEY = 'dropout'
INPUT_KEY = 'input_rng'
def check_num_gpu(desired_num_gpu):
"""Check that the number of GPUs is correct."""
actual_num_gpu = len(jax.local_devices())
assert actual_num_gpu == desired_num_gpu, f"Number of GPUs mismatch: " \
f"{desired_num_gpu} GPUs are assigned, but the actual number of GPUs is {actual_num_gpu}"
def gpu_has_fp8():
"""Check if the GPU has FP8."""
cudaSuccess = cudart.cudaError_t.cudaSuccess
ret, gpu_id = cudart.cudaGetDevice()
assert ret == cudaSuccess
flag = cudart.cudaDeviceAttr.cudaDevAttrComputeCapabilityMajor
_, major = cudart.cudaDeviceGetAttribute(flag, gpu_id)
flag = cudart.cudaDeviceAttr.cudaDevAttrComputeCapabilityMinor
_, minor = cudart.cudaDeviceGetAttribute(flag, gpu_id)
sm_arch = major * 10 + minor
return sm_arch >= 89
class Net(nn.Module):
"""NLP Encoder"""
num_embed: int
@nn.compact
def __call__(self, x, mask, disable_dropout=False):
x = nn.Embed(num_embeddings=self.num_embed, features=256, dtype=jnp.bfloat16)(x)
te_Encoder = partial(te.TransformerLayer,
hidden_size=256,
mlp_hidden_size=1024,
num_attention_heads=8,
hidden_dropout=0.1,
attention_dropout=0.1,
dropout_rng_name=DROPOUT_KEY,
layer_type=te.TransformerLayerType.ENCODER,
enable_relative_embedding=False,
dtype=jnp.bfloat16)
x = te_Encoder()(x, attention_mask=mask, deterministic=disable_dropout)
x = x.reshape(x.shape[0], -1)
x = te.DenseGeneral(features=256, sharding_type=te.ShardingType.DP, dtype=jnp.bfloat16)(x)
x = te.DenseGeneral(features=256, sharding_type=te.ShardingType.DP, dtype=jnp.bfloat16)(x)
x = nn.Dense(features=2, dtype=jnp.bfloat16)(x)
return x
def train_step(state, inputs, masks, labels, var_collect, rngs, use_fp8):
"""Computes gradients, loss and accuracy for a single batch."""
def loss_fn(var_collect, disable_dropout=False):
logits = state.apply_fn(var_collect, inputs, masks, disable_dropout, rngs=rngs)
one_hot = jax.nn.one_hot(labels, 2)
loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
return loss, logits
var_collect = FrozenDict({**var_collect, PARAMS_KEY: state.params})
grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
(loss, logits), grads = grad_fn(var_collect)
accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
var_collect, grads = grads.pop(PARAMS_KEY)
state = state.apply_gradients(grads=grads)
if use_fp8:
var_collect = te.update_fp8_metas(var_collect)
return state, loss, accuracy, var_collect
def train_epoch(state, train_ds, batch_size, rngs, var_collect, use_fp8, train_fn):
"""Train for a single epoch."""
train_ds_size = len(train_ds['sentence'])
steps_per_epoch = train_ds_size // batch_size
perms = jax.random.permutation(rngs[INPUT_KEY], train_ds_size)
perms = perms[:steps_per_epoch * batch_size] # skip incomplete batch
perms = perms.reshape((steps_per_epoch, batch_size))
epoch_loss = []
epoch_accuracy = []
for perm in perms:
batch_inputs = train_ds['sentence'][perm, ...]
batch_masks = train_ds['mask'][perm, ...]
batch_labels = train_ds['label'][perm, ...]
state, loss, accuracy, var_collect = train_fn(state, batch_inputs, batch_masks,
batch_labels, var_collect, rngs, use_fp8)
epoch_loss.append(loss)
epoch_accuracy.append(accuracy)
avg_loss = np.mean(epoch_loss)
avg_accuracy = np.mean(epoch_accuracy)
return state, avg_loss, avg_accuracy, var_collect
def eval_step(state, inputs, masks, labels, var_collect):
"""Computes loss and accuracy for a single batch."""
def loss_fn(var_collect, disable_dropout=False):
logits = state.apply_fn(var_collect, inputs, masks, disable_dropout)
one_hot = jax.nn.one_hot(labels, 2)
loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
return loss, logits
var_collect = FrozenDict({**var_collect, PARAMS_KEY: state.params})
loss, logits = loss_fn(var_collect, disable_dropout=True)
accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
return loss, accuracy
def eval_model(state, test_ds, batch_size, var_collect, eval_fn):
"""Evaluation loop."""
test_ds_size = len(test_ds['sentence'])
num_steps = test_ds_size // batch_size
valid_size = num_steps * batch_size
all_loss = []
all_accuracy = []
for batch_start in range(0, valid_size, batch_size):
batch_end = batch_start + batch_size
batch_inputs = test_ds['sentence'][batch_start:batch_end]
batch_masks = test_ds['mask'][batch_start:batch_end]
batch_labels = test_ds['label'][batch_start:batch_end]
loss, accuracy = eval_fn(state, batch_inputs, batch_masks, batch_labels, var_collect)
all_loss.append(loss)
all_accuracy.append(accuracy)
avg_loss = np.mean(all_loss)
avg_accuracy = np.mean(all_accuracy)
return avg_loss, avg_accuracy
def data_preprocess(dataset, vocab, word_id, max_seq_len):
"""Convert tokens to numbers."""
nltk.download('punkt')
dataset_size = len(dataset['sentence'])
output = np.zeros((dataset_size, max_seq_len), dtype=np.int32)
mask_3d = np.empty((dataset_size, max_seq_len, max_seq_len), dtype=np.uint8)
for j, sentence in enumerate(dataset['sentence']):
tokens = nltk.word_tokenize(sentence.decode("utf-8"))
tensor = output[j]
mask_1d = np.zeros((1, max_seq_len), dtype=np.uint8)
for i, word in enumerate(tokens):
if i >= max_seq_len:
break
if word not in vocab:
vocab[word] = word_id
tensor[i] = word_id
word_id = word_id + 1
else:
tensor[i] = vocab[word]
mask_1d[0, i] = 1
mask_2d = mask_3d[j]
np.dot(mask_1d.T, mask_1d, out=mask_2d)
np.subtract(1, mask_2d, out=mask_2d)
dataset['sentence'] = output
dataset['label'] = dataset['label'].astype(np.float32)
dataset['mask'] = mask_3d.reshape((dataset_size, 1, max_seq_len, max_seq_len))
return dataset, vocab, word_id
def get_datasets(max_seq_len):
"""Load GLUE train and test datasets into memory."""
vocab = {}
word_id = 0
dataset = 'glue/cola'
train_ds = tfds.as_numpy(tfds.load(dataset, split='train', batch_size=-1))
train_ds, vocab, word_id = data_preprocess(train_ds, vocab, word_id, max_seq_len)
test_ds = tfds.as_numpy(tfds.load(dataset, split='validation', batch_size=-1))
test_ds, vocab, word_id = data_preprocess(test_ds, vocab, word_id, max_seq_len)
return train_ds, test_ds, word_id
def check_fp8(state, var_collect, inputs, masks, labels):
"Check if model includes FP8."
rngs = {DROPOUT_KEY: jax.random.PRNGKey(0)}
assert "Float8" in str(
jax.make_jaxpr(train_step, static_argnums=6)(state, inputs, masks, labels, var_collect,
rngs, True))
def get_params_pspec(sharding_rules, abs_var_collect):
"""Refer params to create params partition spec"""
rules_dict = {}
for key, value in sharding_rules:
rules_dict[key] = value
def to_device_axis(logical_axis):
partitions = [rules_dict[key] for key in logical_axis]
return jax.sharding.PartitionSpec(*partitions)
params_axes = abs_var_collect.get(PARAMS_AXES_KEY, {})
params_axes_pspec = jax.tree_map(to_device_axis, nn.partitioning.get_axis_names(params_axes))
params_pspec = jax.tree_map(lambda x: jax.sharding.PartitionSpec(), abs_var_collect[PARAMS_KEY])
params_pspec = FrozenDict({**params_pspec, **params_axes_pspec})
return params_pspec
def get_state_pspec(state, params_pspec):
"""Refer params_pspec to create state partition spec"""
def replace_params(x):
return params_pspec if isinstance(x, FrozenDict) else None
state_pspec = jax.tree_map(replace_params, state, is_leaf=lambda x: isinstance(x, FrozenDict))
return state_pspec
def train_and_evaluate(args):
"""Execute model training and evaluation loop."""
print(args)
check_num_gpu(args.num_gpu)
assert args.batch_size % args.num_gpu == 0, f"Batch size needs to be multiple of {args.num_gpu}"
assert args.test_batch_size % args.num_gpu == 0, \
f"Test batch size needs to be multiple of {args.num_gpu}"
if args.use_fp8:
assert gpu_has_fp8(), "GPU needs to support FP8."
device_mesh = mesh_utils.create_device_mesh((args.num_gpu,))
with jax.sharding.Mesh(devices=device_mesh, axis_names=(DEVICE_DP_AXIS,)):
rng = jax.random.PRNGKey(args.seed)
rng, params_rng = jax.random.split(rng)
rng, dropout_rng = jax.random.split(rng)
init_rngs = {PARAMS_KEY: params_rng, DROPOUT_KEY: dropout_rng}
input_shape = [args.batch_size, args.max_seq_len]
mask_shape = [args.batch_size, 1, args.max_seq_len, args.max_seq_len]
label_shape = [args.batch_size]
with te.fp8_autocast(args.use_fp8, sharding_resource=te.ShardingResource(DEVICE_DP_AXIS)):
train_ds, test_ds, num_embed = get_datasets(args.max_seq_len)
encoder = Net(num_embed)
inputs = jnp.zeros(input_shape, dtype=jnp.int32)
masks = jnp.zeros(mask_shape, dtype=jnp.uint8)
abs_var_collect = jax.eval_shape(encoder.init, init_rngs, inputs, masks)
sharding_rules = te.extend_logical_axis_rules(tuple())
params_pspec = get_params_pspec(sharding_rules, abs_var_collect)
inputs_pspec = jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None)
masks_pspec = jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None, None, None)
in_shardings = (None, inputs_pspec, masks_pspec)
out_shardings = FrozenDict({key: params_pspec if key == PARAMS_KEY else None \
for key in abs_var_collect})
pjit_encoder_init = pjit(encoder.init, in_shardings, out_shardings)
var_collect = pjit_encoder_init(init_rngs, inputs, masks)
optimizer = optax.adamw(args.lr)
var_collect, params = var_collect.pop(PARAMS_KEY)
state = train_state.TrainState.create(apply_fn=encoder.apply,
params=params,
tx=optimizer)
state_pspec = get_state_pspec(state, params_pspec)
labels_pspec = jax.sharding.PartitionSpec(DEVICE_DP_AXIS,)
in_shardings = (state_pspec, inputs_pspec, masks_pspec, labels_pspec, None, None)
out_shardings = (state_pspec, None, None, None)
pjit_train_step = pjit(train_step, in_shardings, out_shardings, static_argnums=(6,))
in_shardings = (state_pspec, inputs_pspec, masks_pspec, labels_pspec, None)
out_shardings = (None, None)
pjit_eval_step = pjit(eval_step, in_shardings, out_shardings)
if args.use_fp8:
labels = jnp.zeros(label_shape, dtype=jnp.bfloat16)
check_fp8(state, var_collect, inputs, masks, labels)
if args.dry_run:
labels = jnp.zeros(label_shape, dtype=jnp.bfloat16)
rngs = {DROPOUT_KEY: dropout_rng}
pjit_train_step(state, inputs, masks, labels, var_collect, rngs, args.use_fp8)
print("PASSED")
return None
for epoch in range(1, args.epochs + 1):
rng, input_rng = jax.random.split(rng)
rng, dropout_rng = jax.random.split(rng)
rngs = {INPUT_KEY: input_rng, DROPOUT_KEY: dropout_rng}
state, train_loss, train_accuracy, var_collect = train_epoch(
state, train_ds, args.batch_size, rngs, var_collect, args.use_fp8,
pjit_train_step)
test_loss, test_accuracy = eval_model(state, test_ds, args.test_batch_size,
var_collect, pjit_eval_step)
print(f"Epoch: {epoch:>2} "
f"Train Loss: {train_loss:.6f} "
f"Train Accuracy: {train_accuracy:.6f} "
f"Test Loss: {test_loss:.6f} "
f"Test Accuracy: {test_accuracy:.6f} ")
return [train_loss, train_accuracy, test_loss, test_accuracy]
def encoder_parser(args):
"""Training settings."""
parser = argparse.ArgumentParser(description="JAX Encoder Example")
parser.add_argument(
"--num-gpu",
type=int,
default=8,
metavar="N",
help="number of GPUs (default: 8)",
)
parser.add_argument(
"--batch-size",
type=int,
default=64,
metavar="N",
help="input batch size for training (default: 64)",
)
parser.add_argument(
"--test-batch-size",
type=int,
default=64,
metavar="N",
help="input batch size for testing (default: 64)",
)
parser.add_argument(
"--max-seq-len",
type=int,
default=32,
metavar="N",
help="maximum sequence length (default: 32)",
)
parser.add_argument(
"--epochs",
type=int,
default=3,
metavar="N",
help="number of epochs to train (default: 3)",
)
parser.add_argument(
"--lr",
type=float,
default=0.0001,
metavar="LR",
help="learning rate (default: 0.0001)",
)
parser.add_argument(
"--dry-run",
action="store_true",
default=False,
help="quickly check a single pass",
)
parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)")
parser.add_argument("--use-fp8",
action="store_true",
default=False,
help="Use FP8 for inference and training without recalibration")
return parser.parse_args(args)
class TestEncoder(unittest.TestCase):
"""Encoder unittests"""
@classmethod
def setUpClass(cls):
"""Run 3 epochs for testing"""
num_gpu = len(jax.local_devices())
if num_gpu % 2 != 0:
num_gpu = 1
cls.args = encoder_parser(["--epochs", "3", "--num-gpu", str(num_gpu)])
def test_te_bf16(self):
"""Test Transformer Engine with BF16"""
actual = train_and_evaluate(self.args)
assert actual[0] < 0.45 and actual[1] > 0.79
@unittest.skipIf(not gpu_has_fp8(), reason='GPU capability is not enough to run FP8')
def test_te_fp8(self):
"""Test Transformer Engine with FP8"""
self.args.use_fp8 = True
actual = train_and_evaluate(self.args)
assert actual[0] < 0.45 and actual[1] > 0.79
if __name__ == "__main__":
train_and_evaluate(encoder_parser(None))
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
""" Encoder with BF16 Training on single GPU"""
import jax
import jax.numpy as jnp
import optax
from flax.core.frozen_dict import FrozenDict
from flax.training import train_state
import transformer_engine.jax as te
PARAMS_KEY = 'params'
BATCH = 32
SEQLEN = 512
HIDDEN = 1024
def network():
"""NLP Encoder"""
encoder = te.TransformerLayer(hidden_size=HIDDEN,
mlp_hidden_size=4 * HIDDEN,
hidden_dropout=0.0,
attention_dropout=0.0,
layernorm_type='rmsnorm',
mlp_activations=('gelu', 'linear'),
layer_type=te.TransformerLayerType.ENCODER,
transpose_batch_sequence=True,
dtype=jnp.bfloat16)
return encoder
def synthesis_data(data_rng):
"""Dataset generator"""
return jax.random.normal(data_rng, [SEQLEN, BATCH, HIDDEN], jnp.bfloat16)
def train_step(batch, state, others):
"""Training function."""
def loss_fn(collections):
logits = state.apply_fn(collections, batch)
loss = jnp.mean(logits)
return loss
grad_fn = jax.value_and_grad(loss_fn)
loss, grads = grad_fn(FrozenDict({PARAMS_KEY: state.params, **others}))
grads, params_grads = grads.pop(PARAMS_KEY)
state = state.apply_gradients(grads=params_grads)
return loss, state, others
def test_encoder():
"""Encoder example"""
rng = jax.random.PRNGKey(0)
rng, init_rng, data_rng = jax.random.split(rng, 3)
inputs = synthesis_data(data_rng)
encoder = network()
variables = jax.jit(encoder.init)(init_rng, inputs)
variables, params = variables.pop(PARAMS_KEY)
optimizer = optax.sgd(0.001, 0.9)
state = train_state.TrainState.create(apply_fn=encoder.apply, params=params, tx=optimizer)
jitted_train_step = jax.jit(train_step)
for i in range(5):
rng, data_rng = jax.random.split(rng)
inputs = synthesis_data(data_rng)
loss, state, variables = jitted_train_step(inputs, state, variables)
print(f"Step {i} - Loss: {loss}")
if __name__ == "__main__":
test_encoder()
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
""" Encoder training on single GPU"""
import argparse
import os
import unittest
from functools import partial
import jax
import jax.numpy as jnp
import nltk
import numpy as np
import optax
import tensorflow_datasets as tfds
from cuda import cudart
from flax import linen as nn
from flax.core.frozen_dict import FrozenDict
from flax.training import train_state
import transformer_engine.jax as te
PARAMS_KEY = 'params'
DROPOUT_KEY = 'dropout'
INPUT_KEY = 'input_rng'
def gpu_has_fp8():
"""Check if the GPU has FP8."""
cudaSuccess = cudart.cudaError_t.cudaSuccess
ret, gpu_id = cudart.cudaGetDevice()
assert ret == cudaSuccess
flag = cudart.cudaDeviceAttr.cudaDevAttrComputeCapabilityMajor
_, major = cudart.cudaDeviceGetAttribute(flag, gpu_id)
flag = cudart.cudaDeviceAttr.cudaDevAttrComputeCapabilityMinor
_, minor = cudart.cudaDeviceGetAttribute(flag, gpu_id)
sm_arch = major * 10 + minor
return sm_arch >= 89
class Net(nn.Module):
"""NLP Encoder"""
num_embed: int
@nn.compact
def __call__(self, x, mask, disable_dropout=False):
x = nn.Embed(num_embeddings=self.num_embed, features=256, dtype=jnp.bfloat16)(x)
te_Encoder = partial(te.TransformerLayer,
hidden_size=256,
mlp_hidden_size=1024,
num_attention_heads=8,
hidden_dropout=0.1,
attention_dropout=0.1,
dropout_rng_name=DROPOUT_KEY,
layer_type=te.TransformerLayerType.ENCODER,
enable_relative_embedding=False,
dtype=jnp.bfloat16)
x = te_Encoder()(x, attention_mask=mask, deterministic=disable_dropout)
x = x.reshape(x.shape[0], -1)
x = te.DenseGeneral(features=256, dtype=jnp.bfloat16)(x)
x = te.DenseGeneral(features=256, dtype=jnp.bfloat16)(x)
x = nn.Dense(features=2, dtype=jnp.bfloat16)(x)
return x
@partial(jax.jit, static_argnums=6)
def train_step(state, inputs, masks, labels, var_collect, rngs, use_fp8):
"""Computes gradients, loss and accuracy for a single batch."""
def loss_fn(var_collect, disable_dropout=False):
logits = state.apply_fn(var_collect, inputs, masks, disable_dropout, rngs=rngs)
one_hot = jax.nn.one_hot(labels, 2)
loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
return loss, logits
var_collect = FrozenDict({**var_collect, PARAMS_KEY: state.params})
grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
(loss, logits), grads = grad_fn(var_collect)
accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
var_collect, grads = grads.pop(PARAMS_KEY)
state = state.apply_gradients(grads=grads)
if use_fp8:
var_collect = te.update_fp8_metas(var_collect)
return state, loss, accuracy, var_collect
def train_epoch(state, train_ds, batch_size, rngs, var_collect, use_fp8):
"""Train for a single epoch."""
train_ds_size = len(train_ds['sentence'])
steps_per_epoch = train_ds_size // batch_size
perms = jax.random.permutation(rngs[INPUT_KEY], train_ds_size)
perms = perms[:steps_per_epoch * batch_size] # skip incomplete batch
perms = perms.reshape((steps_per_epoch, batch_size))
epoch_loss = []
epoch_accuracy = []
for perm in perms:
batch_inputs = train_ds['sentence'][perm, ...]
batch_masks = train_ds['mask'][perm, ...]
batch_labels = train_ds['label'][perm, ...]
state, loss, accuracy, var_collect = train_step(state, batch_inputs, batch_masks,
batch_labels, var_collect, rngs, use_fp8)
epoch_loss.append(loss)
epoch_accuracy.append(accuracy)
avg_loss = np.mean(epoch_loss)
avg_accuracy = np.mean(epoch_accuracy)
return state, avg_loss, avg_accuracy, var_collect
@jax.jit
def eval_step(state, inputs, masks, labels, var_collect):
"""Computes loss and accuracy for a single batch."""
def loss_fn(var_collect, disable_dropout=False):
logits = state.apply_fn(var_collect, inputs, masks, disable_dropout)
one_hot = jax.nn.one_hot(labels, 2)
loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
return loss, logits
var_collect = FrozenDict({**var_collect, PARAMS_KEY: state.params})
loss, logits = loss_fn(var_collect, disable_dropout=True)
accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
return loss, accuracy
def eval_model(state, test_ds, batch_size, var_collect):
"""Evaluation loop."""
test_ds_size = len(test_ds['sentence'])
num_steps = test_ds_size // batch_size
valid_size = num_steps * batch_size
all_loss = []
all_accuracy = []
for batch_start in range(0, valid_size, batch_size):
batch_end = batch_start + batch_size
batch_inputs = test_ds['sentence'][batch_start:batch_end]
batch_masks = test_ds['mask'][batch_start:batch_end]
batch_labels = test_ds['label'][batch_start:batch_end]
loss, accuracy = eval_step(state, batch_inputs, batch_masks, batch_labels, var_collect)
all_loss.append(loss)
all_accuracy.append(accuracy)
avg_loss = np.mean(all_loss)
avg_accuracy = np.mean(all_accuracy)
return avg_loss, avg_accuracy
def data_preprocess(dataset, vocab, word_id, max_seq_len):
"""Convert tokens to numbers."""
nltk.download('punkt')
dataset_size = len(dataset['sentence'])
output = np.zeros((dataset_size, max_seq_len), dtype=np.int32)
mask_3d = np.empty((dataset_size, max_seq_len, max_seq_len), dtype=np.uint8)
for j, sentence in enumerate(dataset['sentence']):
tokens = nltk.word_tokenize(sentence.decode("utf-8"))
tensor = output[j]
mask_1d = np.zeros((1, max_seq_len), dtype=np.uint8)
for i, word in enumerate(tokens):
if i >= max_seq_len:
break
if word not in vocab:
vocab[word] = word_id
tensor[i] = word_id
word_id = word_id + 1
else:
tensor[i] = vocab[word]
mask_1d[0, i] = 1
mask_2d = mask_3d[j]
np.dot(mask_1d.T, mask_1d, out=mask_2d)
np.subtract(1, mask_2d, out=mask_2d)
dataset['sentence'] = output
dataset['label'] = dataset['label'].astype(np.float32)
dataset['mask'] = mask_3d.reshape((dataset_size, 1, max_seq_len, max_seq_len))
return dataset, vocab, word_id
def get_datasets(max_seq_len):
"""Load GLUE train and test datasets into memory."""
vocab = {}
word_id = 0
dataset = 'glue/cola'
train_ds = tfds.as_numpy(tfds.load(dataset, split='train', batch_size=-1))
train_ds, vocab, word_id = data_preprocess(train_ds, vocab, word_id, max_seq_len)
test_ds = tfds.as_numpy(tfds.load(dataset, split='validation', batch_size=-1))
test_ds, vocab, word_id = data_preprocess(test_ds, vocab, word_id, max_seq_len)
return train_ds, test_ds, word_id
def check_fp8(state, var_collect, inputs, masks, labels):
"Check if model includes FP8."
rngs = {DROPOUT_KEY: jax.random.PRNGKey(0)}
assert "Float8" in str(
jax.make_jaxpr(train_step, static_argnums=6)(state, inputs, masks, labels, var_collect,
rngs, True))
def train_and_evaluate(args):
"""Execute model training and evaluation loop."""
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
print(args)
if args.use_fp8:
assert gpu_has_fp8(), "GPU needs to support FP8."
rng = jax.random.PRNGKey(args.seed)
rng, params_rng = jax.random.split(rng)
rng, dropout_rng = jax.random.split(rng)
init_rngs = {PARAMS_KEY: params_rng, DROPOUT_KEY: dropout_rng}
input_shape = [args.batch_size, args.max_seq_len]
mask_shape = [args.batch_size, 1, args.max_seq_len, args.max_seq_len]
label_shape = [args.batch_size]
with te.fp8_autocast(enabled=args.use_fp8):
train_ds, test_ds, num_embed = get_datasets(args.max_seq_len)
encoder = Net(num_embed)
inputs = jnp.zeros(input_shape, dtype=jnp.int32)
masks = jnp.zeros(mask_shape, dtype=jnp.uint8)
var_collect = encoder.init(init_rngs, inputs, masks)
tx = optax.adamw(args.lr)
state = train_state.TrainState.create(apply_fn=encoder.apply,
params=var_collect[PARAMS_KEY],
tx=tx)
if args.use_fp8:
labels = jnp.zeros(label_shape, dtype=jnp.bfloat16)
check_fp8(state, var_collect, inputs, masks, labels)
if args.dry_run:
labels = jnp.zeros(label_shape, dtype=jnp.bfloat16)
rngs = {DROPOUT_KEY: dropout_rng}
train_step(state, inputs, masks, labels, var_collect, rngs, args.use_fp8)
print("PASSED")
return None
for epoch in range(1, args.epochs + 1):
rng, input_rng = jax.random.split(rng)
rng, dropout_rng = jax.random.split(rng)
rngs = {INPUT_KEY: input_rng, DROPOUT_KEY: dropout_rng}
state, train_loss, train_accuracy, var_collect = train_epoch(
state, train_ds, args.batch_size, rngs, var_collect, args.use_fp8)
test_loss, test_accuracy = eval_model(state, test_ds, args.test_batch_size, var_collect)
print(f"Epoch: {epoch:>2} "
f"Train Loss: {train_loss:.6f} "
f"Train Accuracy: {train_accuracy:.6f} "
f"Test Loss: {test_loss:.6f} "
f"Test Accuracy: {test_accuracy:.6f} ")
return [train_loss, train_accuracy, test_loss, test_accuracy]
def encoder_parser(args):
"""Training settings."""
parser = argparse.ArgumentParser(description="JAX Encoder Example")
parser.add_argument(
"--batch-size",
type=int,
default=64,
metavar="N",
help="input batch size for training (default: 64)",
)
parser.add_argument(
"--test-batch-size",
type=int,
default=64,
metavar="N",
help="input batch size for testing (default: 64)",
)
parser.add_argument(
"--max-seq-len",
type=int,
default=32,
metavar="N",
help="maximum sequence length (default: 32)",
)
parser.add_argument(
"--epochs",
type=int,
default=3,
metavar="N",
help="number of epochs to train (default: 3)",
)
parser.add_argument(
"--lr",
type=float,
default=0.0001,
metavar="LR",
help="learning rate (default: 0.0001)",
)
parser.add_argument(
"--dry-run",
action="store_true",
default=False,
help="quickly check a single pass",
)
parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)")
parser.add_argument("--use-fp8",
action="store_true",
default=False,
help="Use FP8 for inference and training without recalibration")
return parser.parse_args(args)
class TestEncoder(unittest.TestCase):
"""Encoder unittests"""
@classmethod
def setUpClass(cls):
"""Run 3 epochs for testing"""
cls.args = encoder_parser(["--epochs", "3"])
def test_te_bf16(self):
"""Test Transformer Engine with BF16"""
actual = train_and_evaluate(self.args)
assert actual[0] < 0.45 and actual[1] > 0.79
@unittest.skipIf(not gpu_has_fp8(), reason='GPU capability is not enough to run FP8')
def test_te_fp8(self):
"""Test Transformer Engine with FP8"""
self.args.use_fp8 = True
actual = train_and_evaluate(self.args)
assert actual[0] < 0.45 and actual[1] > 0.79
if __name__ == "__main__":
train_and_evaluate(encoder_parser(None))
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
""" Encoder with FP8 Training on single GPU"""
import jax
import jax.numpy as jnp
import optax
from cuda import cudart
from flax.core.frozen_dict import FrozenDict
from flax.training import train_state
import transformer_engine.jax as te
from transformer_engine.jax.fp8 import FP8Helper
from transformer_engine.common.recipe import Format as FP8Format
from transformer_engine.common.recipe import DelayedScaling
PARAMS_KEY = 'params'
BATCH = 32
SEQLEN = 512
HIDDEN = 1024
def gpu_has_fp8():
"""GPU arch has to support FP8"""
cudaSuccess = cudart.cudaError_t.cudaSuccess
ret, gpu_id = cudart.cudaGetDevice()
assert ret == cudaSuccess
flag = cudart.cudaDeviceAttr.cudaDevAttrComputeCapabilityMajor
_, major = cudart.cudaDeviceGetAttribute(flag, gpu_id)
flag = cudart.cudaDeviceAttr.cudaDevAttrComputeCapabilityMinor
_, minor = cudart.cudaDeviceGetAttribute(flag, gpu_id)
sm_arch = major * 10 + minor
return sm_arch >= 89
def network():
"""NLP Encoder"""
encoder = te.TransformerLayer(hidden_size=HIDDEN,
mlp_hidden_size=4 * HIDDEN,
hidden_dropout=0.0,
attention_dropout=0.0,
layernorm_type='rmsnorm',
mlp_activations=('gelu', 'linear'),
layer_type=te.TransformerLayerType.ENCODER,
transpose_batch_sequence=True,
dtype=jnp.bfloat16)
return encoder
def synthesis_data(data_rng):
"""Dataset generator"""
return jax.random.normal(data_rng, [SEQLEN, BATCH, HIDDEN], jnp.bfloat16)
def train_step(batch, state, others):
"""Training function."""
def loss_fn(collections):
logits = state.apply_fn(collections, batch)
loss = jnp.mean(logits)
return loss
grad_fn = jax.value_and_grad(loss_fn)
loss, grads = grad_fn(FrozenDict({PARAMS_KEY: state.params, **others}))
grads, params_grads = grads.pop(PARAMS_KEY)
state = state.apply_gradients(grads=params_grads)
others = FP8Helper.update_fp8_metas(grads)
return loss, state, others
def test_encoder():
"""Encoder example"""
if not gpu_has_fp8():
print("GPU doesn't support FP8")
return
rng = jax.random.PRNGKey(0)
rng, init_rng, data_rng = jax.random.split(rng, 3)
inputs = synthesis_data(data_rng)
optimizer = optax.sgd(0.001, 0.9)
with te.fp8_autocast(enabled=True, fp8_recipe=DelayedScaling(fp8_format=FP8Format.HYBRID)):
encoder = network()
variables = jax.jit(encoder.init)(init_rng, inputs)
variables, params = variables.pop(PARAMS_KEY)
state = train_state.TrainState.create(apply_fn=encoder.apply, params=params, tx=optimizer)
jitted_train_step = jax.jit(train_step)
assert "fp8" in str(jax.make_jaxpr(jitted_train_step)(inputs, state, variables))
for i in range(5):
rng, data_rng = jax.random.split(rng)
inputs = synthesis_data(data_rng)
loss, state, variables = jitted_train_step(inputs, state, variables)
print(f"Step {i} - Loss: {loss}")
if __name__ == "__main__":
test_encoder()
# Basic MNIST Example with Optional FP8 #
This example uses MNIST training to demonstrate Transformer Engine usage. Transformer Engine is built on top of [Flax](https://github.com/google/flax), a neural network library and ecosystem for JAX, so it interoperates freely with other Flax modules. For basic Flax usage, refer to [Flax Basics](https://flax.readthedocs.io/en/latest/guides/flax_basics.html).
1. Setup dataset: The first step is to prepare the dataset. This is done by using the `tfds` library to download the MNIST dataset and perform image preprocessing. The `get_datasets` routine is used for this purpose.
2. Define model: The `Net` class is a small CNN model for image classification. It has an option to switch between using `nn.Dense` provided by Flax and `te.DenseGeneral` provided by the Transformer Engine. This allows for easy comparison between the two libraries.
3. Build training loop: `train_and_evaluate` is the main routine that initializes the model and starts training and evaluation. For FP8 training, the key is the `te.fp8_autocast` context manager. If fp8_autocast is enabled, it casts all `te.DenseGeneral` layers to FP8 precision. The `var_collect` is a collection of everything needed for model training, such as parameters and FP8 metadata, the latter being necessary to correctly cast BF16 tensors into FP8 tensors at runtime. If fp8_autocast is turned on and you print `var_collect`, you will see the FP8 metadata inside, such as the `fp8_meta_collection` section. Training and evaluation with FP8 have to be done under fp8_autocast; if not, fp8_autocast will deconstruct the FP8 metadata, and the model will fall back to a higher floating point precision, such as BF16 in this example. To check whether FP8 is enabled, use the `check_fp8` routine: if model initialization with FP8 works fine, the string returned by `jax.make_jaxpr` should include the `Float8` keyword.
4. Training process: In `apply_model`, the main difference from normal Flax usage is that, with FP8 training, the FP8 metadata has to be passed into the gradient function `grad_fn`; otherwise, Transformer Engine doesn't know how to correctly cast the BF16 tensors into FP8 tensors at runtime. The FP8 metadata doesn't belong to the model parameters (`state.params`), so we need to manually combine the metadata and the latest model parameters into `var_collect` as a frozen dictionary and pass it to the gradient function. After getting the loss and gradients, we also need to call `te.update_fp8_metas` in the `update_model` routine to update the FP8 metadata. The number of training steps between FP8 metadata updates can be customized; in this example, it is updated every step (see the sketch after this list).
5. Evaluating process: The evaluating process is the same as the training process: ensure the FP8 metadata is inside `var_collect` and pass it to the loss function.
6. Additional options: The `te.fp8_autocast` context manager has additional options:
* FP8 Recipe: controls FP8 training behavior. See the [FP8 tutorial](https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/examples/fp8_primer.html) for a detailed explanation of FP8 recipes and the supported options. **Note** that updating the FP8 metadata is now the user's responsibility (i.e., manually calling `te.update_fp8_metas`); the JAX version of Transformer Engine cannot update FP8 metadata on its own.
* Sharding Resource: tells Transformer Engine how to perform data parallelism and tensor parallelism. This is introduced in more detail in the encoder examples.
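A condensed sketch of steps 3 and 4, taken from `update_model` and `check_fp8` in `test_single_gpu_mnist.py` (variable names follow that file; this is not the full training loop):
```python
# Apply the parameter gradients, then refresh the FP8 scaling metadata (from update_model).
state = state.apply_gradients(grads=grads[PARAMS_KEY])
if use_fp8:
    grads = te.update_fp8_metas(grads)

# Verify that FP8 is active: the traced jaxpr should mention Float8 (from check_fp8).
assert "Float8" in str(
    jax.make_jaxpr(apply_model)(state, jnp.empty(input_shape, dtype=jnp.bfloat16),
                                jnp.empty(label_shape, dtype=jnp.bfloat16), var_collect))
```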
## Run ##
1. Use Flax to train MNIST with BF16 as usual
```bash
python test_single_gpu_mnist.py
```
2. Use `te.DenseGeneral` provided by Transformer Engine to train MNIST with BF16
```bash
python test_single_gpu_mnist.py --use-te
```
3. Use `te.DenseGeneral` provided by Transformer Engine to train MNIST and enable FP8 training and evaluation.
```bash
python test_single_gpu_mnist.py --use-fp8
```
flax
optax
tensorflow-datasets
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
""" MNIST training on single GPU"""
import argparse
import os
import unittest
from functools import partial
import jax
import jax.numpy as jnp
import numpy as np
import optax
import tensorflow_datasets as tfds
from cuda import cudart
from flax import linen as nn
from flax.core.frozen_dict import FrozenDict
from flax.training import train_state
import transformer_engine.jax as te
IMAGE_H = 28
IMAGE_W = 28
IMAGE_C = 1
PARAMS_KEY = 'params'
DROPOUT_KEY = 'dropout'
INPUT_KEY = 'input_rng'
def gpu_has_fp8():
"""Check if the GPU has FP8."""
cudaSuccess = cudart.cudaError_t.cudaSuccess
ret, gpu_id = cudart.cudaGetDevice()
assert ret == cudaSuccess
flag = cudart.cudaDeviceAttr.cudaDevAttrComputeCapabilityMajor
_, major = cudart.cudaDeviceGetAttribute(flag, gpu_id)
flag = cudart.cudaDeviceAttr.cudaDevAttrComputeCapabilityMinor
_, minor = cudart.cudaDeviceGetAttribute(flag, gpu_id)
sm_arch = major * 10 + minor
return sm_arch >= 89
class Net(nn.Module):
"""CNN model for MNIST."""
use_te: bool = False
@nn.compact
def __call__(self, x, disable_dropout=False):
if self.use_te:
nn_Dense = te.DenseGeneral
else:
nn_Dense = nn.Dense
x = nn.Conv(features=32, kernel_size=(3, 3), strides=1, dtype=jnp.bfloat16)(x)
x = nn.relu(x)
x = nn.Conv(features=64, kernel_size=(3, 3), strides=1, dtype=jnp.bfloat16)(x)
x = nn.relu(x)
x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))
x = nn.Dropout(rate=0.25)(x, deterministic=disable_dropout)
x = x.reshape(x.shape[0], -1)
x = nn_Dense(features=128, dtype=jnp.bfloat16)(x)
x = nn.relu(x)
x = nn.Dropout(rate=0.5)(x, deterministic=disable_dropout)
x = nn_Dense(features=16, dtype=jnp.bfloat16)(x)
x = nn.Dense(features=10, dtype=jnp.bfloat16)(x)
return x
@jax.jit
def apply_model(state, images, labels, var_collect, rngs=None):
"""Computes gradients, loss and accuracy for a single batch."""
def loss_fn(var_collect, disable_dropout=False):
logits = state.apply_fn(var_collect, images, disable_dropout, rngs=rngs)
one_hot = jax.nn.one_hot(labels, 10)
loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
return loss, logits
var_collect = FrozenDict({**var_collect, PARAMS_KEY: state.params})
if rngs is not None:
grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
(loss, logits), grads = grad_fn(var_collect)
else:
loss, logits = loss_fn(var_collect, disable_dropout=True)
grads = None
accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
return grads, loss, accuracy
@partial(jax.jit, static_argnums=2)
def update_model(state, grads, use_fp8):
"""Update model params and FP8 meta."""
state = state.apply_gradients(grads=grads[PARAMS_KEY])
if use_fp8:
grads = te.update_fp8_metas(grads)
return state, grads
def train_epoch(state, train_ds, batch_size, rngs, var_collect, use_fp8):
"""Train for a single epoch."""
train_ds_size = len(train_ds['image'])
steps_per_epoch = train_ds_size // batch_size
perms = jax.random.permutation(rngs[INPUT_KEY], train_ds_size)
perms = perms[:steps_per_epoch * batch_size] # skip incomplete batch
perms = perms.reshape((steps_per_epoch, batch_size))
epoch_loss = []
epoch_accuracy = []
for perm in perms:
batch_images = train_ds['image'][perm, ...]
batch_labels = train_ds['label'][perm, ...]
grads, loss, accuracy = apply_model(state, batch_images, batch_labels, var_collect, rngs)
state, var_collect = update_model(state, grads, use_fp8)
epoch_loss.append(loss)
epoch_accuracy.append(accuracy)
avg_loss = np.mean(epoch_loss)
avg_accuracy = np.mean(epoch_accuracy)
return state, avg_loss, avg_accuracy, var_collect
def eval_model(state, test_ds, batch_size, var_collect):
"""Evaluation loop."""
test_ds_size = len(test_ds['image'])
num_steps = test_ds_size // batch_size
valid_size = num_steps * batch_size
all_loss = []
all_accuracy = []
for batch_start in range(0, valid_size, batch_size):
batch_end = batch_start + batch_size
batch_images = test_ds['image'][batch_start:batch_end]
batch_labels = test_ds['label'][batch_start:batch_end]
_, loss, accuracy = apply_model(state, batch_images, batch_labels, var_collect)
all_loss.append(loss)
all_accuracy.append(accuracy)
avg_loss = np.mean(all_loss)
avg_accuracy = np.mean(all_accuracy)
return avg_loss, avg_accuracy
def get_datasets():
"""Load MNIST train and test datasets into memory."""
ds_builder = tfds.builder('mnist')
ds_builder.download_and_prepare()
train_ds = tfds.as_numpy(ds_builder.as_dataset(split='train', batch_size=-1))
test_ds = tfds.as_numpy(ds_builder.as_dataset(split='test', batch_size=-1))
train_ds['image'] = jnp.float32(train_ds['image']) / 255.
test_ds['image'] = jnp.float32(test_ds['image']) / 255.
return train_ds, test_ds
def check_fp8(state, var_collect, input_shape, label_shape):
"Check if model includes FP8."
assert "Float8" in str(
jax.make_jaxpr(apply_model)(state, jnp.empty(input_shape, dtype=jnp.bfloat16),
jnp.empty(label_shape, dtype=jnp.bfloat16), var_collect))
def train_and_evaluate(args):
"""Execute model training and evaluation loop."""
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
print(args)
if args.use_fp8:
assert gpu_has_fp8(), "GPU needs to support FP8."
args.use_te = True
train_ds, test_ds = get_datasets()
rng = jax.random.PRNGKey(args.seed)
rng, params_rng = jax.random.split(rng)
rng, dropout_rng = jax.random.split(rng)
init_rngs = {PARAMS_KEY: params_rng, DROPOUT_KEY: dropout_rng}
input_shape = [args.batch_size, IMAGE_H, IMAGE_W, IMAGE_C]
label_shape = [args.batch_size]
with te.fp8_autocast(enabled=args.use_fp8):
cnn = Net(args.use_te)
var_collect = cnn.init(init_rngs, jnp.empty(input_shape, dtype=jnp.bfloat16))
tx = optax.sgd(args.lr, args.momentum)
state = train_state.TrainState.create(apply_fn=cnn.apply,
params=var_collect[PARAMS_KEY],
tx=tx)
if args.use_fp8:
check_fp8(state, var_collect, input_shape, label_shape)
if args.dry_run:
apply_model(state, jnp.empty(input_shape, dtype=jnp.bfloat16),
jnp.empty(label_shape, dtype=jnp.bfloat16), var_collect,
{DROPOUT_KEY: dropout_rng})
print("PASSED")
return None
for epoch in range(1, args.epochs + 1):
rng, input_rng = jax.random.split(rng)
rng, dropout_rng = jax.random.split(rng)
rngs = {INPUT_KEY: input_rng, DROPOUT_KEY: dropout_rng}
state, train_loss, train_accuracy, var_collect = train_epoch(
state, train_ds, args.batch_size, rngs, var_collect, args.use_fp8)
test_loss, test_accuracy = eval_model(state, test_ds, args.test_batch_size, var_collect)
print(f"Epoch: {epoch:>2} "
f"Train Loss: {train_loss:.6f} "
f"Train Accuracy: {train_accuracy:.6f} "
f"Test Loss: {test_loss:.6f} "
f"Test Accuracy: {test_accuracy:.6f} ")
return [train_loss, train_accuracy, test_loss, test_accuracy]
def mnist_parser(args):
"""Training settings."""
parser = argparse.ArgumentParser(description="JAX MNIST Example")
parser.add_argument(
"--batch-size",
type=int,
default=64,
metavar="N",
help="input batch size for training (default: 64)",
)
parser.add_argument(
"--test-batch-size",
type=int,
default=800,
metavar="N",
help="input batch size for testing (default: 800)",
)
parser.add_argument(
"--epochs",
type=int,
default=10,
metavar="N",
help="number of epochs to train (default: 10)",
)
parser.add_argument(
"--lr",
type=float,
default=0.01,
metavar="LR",
help="learning rate (default: 0.01)",
)
parser.add_argument(
"--momentum",
type=float,
default=0.9,
metavar="M",
help="Momentum (default: 0.9)",
)
parser.add_argument(
"--dry-run",
action="store_true",
default=False,
help="quickly check a single pass",
)
parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)")
parser.add_argument("--use-fp8",
action="store_true",
default=False,
help="Use FP8 for inference and training without recalibration. " \
"It also enables Transformer Engine implicitly.")
parser.add_argument("--use-te",
action="store_true",
default=False,
help="Use Transformer Engine")
return parser.parse_args(args)
class TestMNIST(unittest.TestCase):
"""MNIST unittests"""
@classmethod
def setUpClass(cls):
"""Run MNIST without Transformer Engine"""
cls.args = mnist_parser(["--epochs", "5"])
@staticmethod
def verify(actual):
"""Check If loss and accuracy match target"""
desired_traing_loss = 0.055
desired_traing_accuracy = 0.98
desired_test_loss = 0.035
desired_test_accuracy = 0.098
assert actual[0] < desired_traing_loss
assert actual[1] > desired_traing_accuracy
assert actual[2] < desired_test_loss
assert actual[3] > desired_test_accuracy
def test_te_bf16(self):
"""Test Transformer Engine with BF16"""
self.args.use_te = True
self.args.use_fp8 = False
actual = train_and_evaluate(self.args)
self.verify(actual)
@unittest.skipIf(not gpu_has_fp8(), reason='GPU capability is not enough to run FP8')
def test_te_fp8(self):
"""Test Transformer Engine with FP8"""
self.args.use_fp8 = True
actual = train_and_evaluate(self.args)
self.verify(actual)
if __name__ == "__main__":
train_and_evaluate(mnist_parser(None))
......@@ -6,4 +6,7 @@ set -xe
: ${TE_PATH:=/opt/transformerengine}
pytest -Wignore -v $TE_PATH/tests/jax
pip install -r $TE_PATH/examples/jax/mnist/requirements.txt
pip install -r $TE_PATH/examples/jax/encoder/requirements.txt
pytest -Wignore -v $TE_PATH/examples/jax
......@@ -17,7 +17,9 @@
"VERSION",
"Doxyfile",
"pylintrc",
".json"
".json",
".md",
".txt"
],
"exclude_copyright": [],
"copyright_only": false
......
......@@ -69,6 +69,7 @@ def get_file_type(path):
"txt": ["txt"],
"cfg": ["cfg"],
"sh": ["sh"],
"md": ["md"],
}
tmp = path.split(".")
for filetype, ext_list in ext.items():
......
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import os
import tempfile
import unittest
from functools import partial
import jax
import jax.numpy as jnp
import numpy as np
import optax
import tensorflow_datasets as tfds
from flax import linen as nn
from flax.training import train_state
from transformer_engine.common.recipe import Format as FP8Format
from transformer_engine.jax import DenseGeneral
from transformer_engine.jax.fp8 import FP8Helper
from utils import is_fp8_supported
class MLPNN(nn.Module):
use_fp8_dense: bool = True
@nn.compact
def __call__(self, x):
x = x.reshape((x.shape[0], -1)) # flatten
x = nn.Dense(features=512)(x)
x = nn.relu(x)
features = [256, 256, 128]
for feature in features:
x = DenseGeneral(features=feature, transpose_batch_sequence=False,
dtype=jnp.bfloat16, use_bias=True)(x) \
if self.use_fp8_dense else nn.Dense(features=feature)(x)
x = nn.relu(x)
x = nn.Dense(features=10, use_bias=True)(x)
return x
def cross_entropy_loss(*, logits, labels):
labels_onehot = jax.nn.one_hot(labels, num_classes=10)
return optax.softmax_cross_entropy(logits=logits, labels=labels_onehot).mean()
def compute_metrics(*, logits, labels):
loss = cross_entropy_loss(logits=logits, labels=labels)
accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
metrics = {
'loss': loss,
'accuracy': accuracy,
}
return metrics
def get_datasets():
"""Load MNIST train and test datasets into memory."""
ds_builder = tfds.builder('mnist', data_dir="/tmp/tensorflow-datasets/downloads")
ds_builder.download_and_prepare()
train_ds = tfds.as_numpy(ds_builder.as_dataset(split='train', batch_size=-1))
test_ds = tfds.as_numpy(ds_builder.as_dataset(split='test', batch_size=-1))
train_ds['image'] = jnp.float32(train_ds['image']) / 255.
test_ds['image'] = jnp.float32(test_ds['image']) / 255.
return train_ds, test_ds
def create_train_state(rng, learning_rate, momentum, use_fp8_dense):
"""Creates initial `TrainState`."""
cnn = MLPNN(use_fp8_dense=use_fp8_dense)
variables = cnn.init(rng, jnp.ones([32, 28, 28, 1]))
tx = optax.sgd(learning_rate, momentum)
return train_state.TrainState.create(apply_fn=cnn.apply, params=variables['params'],
tx=tx), variables
@partial(jax.jit, static_argnums=(3,))
def train_step(state, others, batch, use_fp8_dense):
"""Train for a single step."""
def loss_fn(collections):
logits = MLPNN(use_fp8_dense=use_fp8_dense).apply(collections, batch['image'])
loss = cross_entropy_loss(logits=logits, labels=batch['label'])
return loss, logits
grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
(_, logits), grads = grad_fn(others)
state = state.apply_gradients(grads=grads['params'])
metrics = compute_metrics(logits=logits, labels=batch['label'])
return state, metrics, grads
def train_epoch(state, variables, train_ds, batch_size, rng, use_fp8_dense):
"""Train for a single epoch."""
train_ds_size = len(train_ds['image'])
steps_per_epoch = train_ds_size // batch_size
perms = jax.random.permutation(rng, train_ds_size)
perms = perms[:steps_per_epoch * batch_size] # skip incomplete batch
perms = perms.reshape((steps_per_epoch, batch_size))
batch_metrics = []
for idx, perm in enumerate(perms):
idx = idx + 1
batch = {k: v[perm, ...] for k, v in train_ds.items()}
state, metrics, grads = train_step(state, variables, batch, use_fp8_dense)
updated_coll = {'params': state.params}
if use_fp8_dense:
updated_coll[FP8Helper.FP8_COLLECTION_NAME] \
= grads[FP8Helper.FP8_COLLECTION_NAME]
variables = FP8Helper.update_collections(updated_coll, variables)
batch_metrics.append(metrics)
if use_fp8_dense:
variables = FP8Helper.update_fp8_metas(variables)
return state, variables
@partial(jax.jit, static_argnums=(2,))
def eval_step(variables, batch, use_fp8_dense):
logits = MLPNN(use_fp8_dense=use_fp8_dense).apply(variables, batch['image'])
return compute_metrics(logits=logits, labels=batch['label'])
def eval_model(variables, test_ds, batch_size, use_fp8_dense):
test_ds_size = len(test_ds['image'])
steps_per_epoch = test_ds_size // batch_size
perms = np.arange(0, test_ds_size)
perms = perms[:steps_per_epoch * batch_size] # skip incomplete batch
perms = perms.reshape((steps_per_epoch, batch_size))
total_summary = {'correct': 0, 'loss': 0, 'total': 0}
for _, perm in enumerate(perms):
batch = {k: v[perm, ...] for k, v in test_ds.items()}
metrics = eval_step(variables, batch, use_fp8_dense)
metrics = jax.device_get(metrics)
summary = jax.tree_map(lambda x: x.item(), metrics)
total_summary['correct'] += summary['accuracy'] * batch_size
total_summary['loss'] += summary['loss'] * batch_size
total_summary['total'] += batch_size
return total_summary['loss']/total_summary['total'], \
total_summary['correct']/total_summary['total']
class TestMnist(unittest.TestCase):
def setUp(self) -> None:
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
self.learning_rate = 0.1
self.momentum = 0.9
self.num_epochs = 5
self.batch_size = 512
self.train_ds, self.test_ds = get_datasets()
self.margin = 0.0
self.num_fp8_layers = 3
self.fp8_meta_update_interval = 1
self.temp_file = tempfile.NamedTemporaryFile() # pylint: disable=consider-using-with
self.fp8_ckpt_path = self.temp_file.name
self.seed = 0
acc_bfp16_ = self._mnist_baseline_runner()
acc_rtol = 0.005
self.target_accuracy = acc_bfp16_ * (1. - acc_rtol)
def tearDown(self):
self.temp_file.close()
@unittest.skipIf(not is_fp8_supported(), reason='GPU capability is not enough to run FP8')
def test_mnist_e4m3(self):
self._mnist_test_runner(FP8Format.E4M3)
@unittest.skipIf(not is_fp8_supported(), reason='GPU capability is not enough to run FP8')
def test_mnist_hybrid(self):
self._mnist_test_runner(FP8Format.HYBRID)
# Skip for now due to lack of BF16 support in TE.Format
# def test_mnist_bfloa16(self):
# self._mnist_test_runner(FP8Format.BFLOAT16)
def _mnist_baseline_runner(self):
rng = jax.random.PRNGKey(self.seed)
rng, init_rng = jax.random.split(rng)
state, variables = create_train_state(init_rng, self.learning_rate, self.momentum, False)
del init_rng
_, accuracy = self._train_model(state, variables, self.num_epochs, rng, False)
return accuracy
def _mnist_test_runner(self, fp8_format):
FP8Helper.initialize(margin=self.margin, fp8_format=fp8_format)
rng = jax.random.PRNGKey(self.seed)
rng, init_rng = jax.random.split(rng)
state, variables = create_train_state(init_rng, self.learning_rate, self.momentum, True)
del init_rng
_, test_accuracy = self._train_model(state, variables, self.num_epochs, rng, True)
self.assertGreater(
test_accuracy, self.target_accuracy,
f"Convergence test failed on MNIST with FP8Fomat.{fp8_format.name}. "
f"Test Accuracy {test_accuracy:.4f} is lower than target {self.target_accuracy:.4f}")
FP8Helper.finalize()
def _train_model(self, state, variables, epochs, rng, use_fp8_dense):
max_test_acc = 0.0
for _ in range(0, epochs):
rng, input_rng = jax.random.split(rng)
state, variables = train_epoch(state, variables, self.train_ds, self.batch_size,
input_rng, use_fp8_dense)
_, test_accuracy = eval_model(variables, self.test_ds, self.batch_size, use_fp8_dense)
max_test_acc = test_accuracy if test_accuracy > max_test_acc else max_test_acc
return state, max_test_acc
if __name__ == '__main__':
unittest.main()
......@@ -219,7 +219,7 @@ class LayerNorm(nn.Module):
-----------------------
dtype : jax.numpy.dtype, default = jax.numpy.float32
the data type used to allocate the initial parameters.
transpose_batch_sequence : bool, default = True
transpose_batch_sequence : bool, default = False
Indicate whether the input tensors were switched axis of batch
and sequence length dimension. If set to True, the input tensors
should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden).
......@@ -233,7 +233,7 @@ class LayerNorm(nn.Module):
bias_init: Initializer = nn.initializers.zeros
bias_axes: Tuple[str, ...] = ('embed',)
dtype: DType = jnp.float32
transpose_batch_sequence: bool = True
transpose_batch_sequence: bool = False
sharding_type: ShardingType = ShardingType.SINGLE
@nn.compact
......@@ -358,12 +358,12 @@ class DenseGeneral(TransformerEngineBase):
features: Union[Iterable[int], int]
kernel_init: Initializer = None
kernel_axes: Tuple[str, ...] = ()
use_bias: bool = False
use_bias: bool = True
bias_init: Initializer = nn.initializers.zeros
bias_axes: Tuple[str, ...] = ()
axis: Union[Iterable[int], int] = -1
dtype: DType = jnp.float32
transpose_batch_sequence: bool = True
transpose_batch_sequence: bool = False
sharding_type: ShardingType = ShardingType.SINGLE
def __post_init__(self):
......
......@@ -720,7 +720,7 @@ class TransformerLayer(nn.Module):
If set to True, `TransformerLayer` module exposes a single fused
parameter for query-key-value for self-attention and key-value for
cross-attention.
transpose_batch_sequence : bool, default = True
transpose_batch_sequence : bool, default = False
Indicate whether the input tensors were switched axis of batch
and sequence length dimension. if set to True, the input tensors
should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden).
......@@ -755,7 +755,7 @@ class TransformerLayer(nn.Module):
dtype: DType = jnp.float32
drop_path: float = 0.0
fuse_qkv_params: bool = True
transpose_batch_sequence: bool = True
transpose_batch_sequence: bool = False
scale_attn_logits: bool = False
scaled_query_init: bool = True
......
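The hunks above flip the `transpose_batch_sequence` default from `True` to `False` (and `use_bias` to `True`) for these modules. Below is a small hedged sketch, assuming `te.DenseGeneral` with FP8 disabled as in the BF16 path of the MNIST example above, showing what the new default means for input layout; the shapes here are illustrative.

```python
import jax
import jax.numpy as jnp
import transformer_engine.jax as te

batch, seqlen, hidden = 4, 16, 64

with te.fp8_autocast(enabled=False):
    # New default: transpose_batch_sequence=False expects (batch, seqlen, hidden) inputs.
    layer = te.DenseGeneral(features=hidden, transpose_batch_sequence=False,
                            dtype=jnp.bfloat16)
    x = jnp.ones((batch, seqlen, hidden), dtype=jnp.bfloat16)
    variables = layer.init(jax.random.PRNGKey(0), x)
    y = layer.apply(variables, x)
    assert y.shape == (batch, seqlen, hidden)
    # With transpose_batch_sequence=True (the old default), inputs would instead be
    # laid out as (seqlen, batch, hidden).
```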