Commit 5b6ef054 authored by yuguo
parents 76060570 a7eeb28b
datasets
flax>=0.7.1
nltk>=3.8.2
optax
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
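# Launch one pytest process per GPU for each multiprocessing encoder test.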
NUM_GPUS=${NUM_GPUS:-$(nvidia-smi -L | wc -l)}
for i in $(seq 0 $(($NUM_GPUS-1)))
do
pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py::TestEncoder::test_te_bf16 --num-process=$NUM_GPUS --process-id=$i &
done
wait
for i in $(seq 0 $(($NUM_GPUS-1)))
do
pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py::TestEncoder::test_te_fp8 --num-process=$NUM_GPUS --process-id=$i &
done
wait
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Encoder training on multi-GPU with tesnor parallelism"""
import argparse
import unittest
from functools import partial
import flax
import jax
import jax.numpy as jnp
import nltk
import numpy as np
import optax
from datasets import load_dataset
from flax import linen as nn
from flax.linen import partitioning as nn_partitioning
from flax.training import train_state
from jax.experimental import mesh_utils
from jax.sharding import PartitionSpec, NamedSharding
import transformer_engine.jax as te
import transformer_engine.jax.flax as te_flax
from common import is_bf16_supported
DEVICE_DP_AXIS = "data"
DEVICE_TP_AXIS = "model"
NAMED_BROADCAST_AXIS = "my_broadcast_axis"
NAMED_TP_AXIS = "my_tp_axis"
PARAMS_KEY = "params"
PARAMS_AXES_KEY = PARAMS_KEY + "_axes"
DROPOUT_KEY = "dropout"
INPUT_KEY = "input_rng"
class Net(nn.Module):
"""NLP Encoder"""
num_embed: int
enable_seq_paral: bool
@nn.compact
def __call__(self, x, mask, disable_dropout=False):
x = nn.Embed(num_embeddings=self.num_embed, features=256, dtype=jnp.bfloat16)(x)
te_Encoder = partial(
te_flax.TransformerLayer,
hidden_size=256,
mlp_hidden_size=1024,
num_attention_heads=8,
hidden_dropout=0.1,
attention_dropout=0.1,
dropout_rng_name=DROPOUT_KEY,
layer_type=te_flax.TransformerLayerType.ENCODER,
self_attn_mask_type="padding",
enable_relative_embedding=False,
enable_sequence_parallel=self.enable_seq_paral,
)
x = te_Encoder()(x, attention_mask=mask, deterministic=disable_dropout)
x = x.reshape(x.shape[0], -1)
if self.enable_seq_paral:
            # Trigger an all-gather so each device holds the complete tensor along the sequence dimension.
x = jax.lax.with_sharding_constraint(
x, jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None)
)
x = te_flax.DenseGeneral(
features=256,
kernel_axes=(NAMED_BROADCAST_AXIS, NAMED_TP_AXIS),
bias_axes=(NAMED_TP_AXIS,),
)(x)
x = te_flax.DenseGeneral(
features=256,
kernel_axes=(NAMED_TP_AXIS, NAMED_BROADCAST_AXIS),
bias_axes=(NAMED_BROADCAST_AXIS,),
)(x)
x = nn.Dense(features=2)(x)
return x
def train_step(state, inputs, masks, labels, var_collect, rngs):
"""Computes gradients, loss and accuracy for a single batch."""
def loss_fn(var_collect, disable_dropout=False):
logits = state.apply_fn(var_collect, inputs, masks, disable_dropout, rngs=rngs)
one_hot = jax.nn.one_hot(labels.astype(jnp.int32), 2)
loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
return loss, logits
var_collect = {**var_collect, PARAMS_KEY: state.params}
grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
(loss, logits), grads = grad_fn(var_collect)
accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
var_collect, grads = flax.core.pop(grads, PARAMS_KEY)
state = state.apply_gradients(grads=grads)
return state, loss, accuracy, var_collect
def train_epoch(state, train_ds, batch_size, rngs, var_collect, train_fn):
"""Train for a single epoch."""
train_ds_size = len(train_ds["sentence"])
steps_per_epoch = train_ds_size // batch_size
perms = jax.random.permutation(rngs[INPUT_KEY], train_ds_size)
perms = perms[: steps_per_epoch * batch_size] # skip incomplete batch
perms = perms.reshape((steps_per_epoch, batch_size))
epoch_loss = []
epoch_accuracy = []
for perm in perms:
batch_inputs = train_ds["sentence"][perm, ...]
batch_masks = train_ds["mask"][perm, ...]
batch_labels = train_ds["label"][perm, ...]
state, loss, accuracy, var_collect = train_fn(
state, batch_inputs, batch_masks, batch_labels, var_collect, rngs
)
epoch_loss.append(loss)
epoch_accuracy.append(accuracy)
avg_loss = np.mean(epoch_loss)
avg_accuracy = np.mean(epoch_accuracy)
return state, avg_loss, avg_accuracy, var_collect
def eval_step(state, inputs, masks, labels, var_collect):
"""Computes loss and accuracy for a single batch."""
def loss_fn(var_collect, disable_dropout=False):
logits = state.apply_fn(var_collect, inputs, masks, disable_dropout)
one_hot = jax.nn.one_hot(labels.astype(jnp.int32), 2)
loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
return loss, logits
var_collect = {**var_collect, PARAMS_KEY: state.params}
loss, logits = loss_fn(var_collect, disable_dropout=True)
accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
return loss, accuracy
def eval_model(state, test_ds, batch_size, var_collect, eval_fn):
"""Evaluation loop."""
test_ds_size = len(test_ds["sentence"])
num_steps = test_ds_size // batch_size
valid_size = num_steps * batch_size
all_loss = []
all_accuracy = []
for batch_start in range(0, valid_size, batch_size):
batch_end = batch_start + batch_size
batch_inputs = test_ds["sentence"][batch_start:batch_end]
batch_masks = test_ds["mask"][batch_start:batch_end]
batch_labels = test_ds["label"][batch_start:batch_end]
loss, accuracy = eval_fn(state, batch_inputs, batch_masks, batch_labels, var_collect)
all_loss.append(loss)
all_accuracy.append(accuracy)
avg_loss = np.mean(all_loss)
avg_accuracy = np.mean(all_accuracy)
return avg_loss, avg_accuracy
def data_preprocess(dataset, vocab, word_id, max_seq_len):
"""Convert tokens to numbers."""
nltk.download("punkt_tab")
dataset_size = len(dataset["sentence"])
output = np.zeros((dataset_size, max_seq_len), dtype=np.int32)
mask_3d = np.ones((dataset_size, max_seq_len, max_seq_len), dtype=np.uint8)
for j, sentence in enumerate(dataset["sentence"]):
tokens = nltk.word_tokenize(sentence)
tensor = output[j]
for i, word in enumerate(tokens):
if i >= max_seq_len:
break
if word not in vocab:
vocab[word] = word_id
tensor[i] = word_id
word_id = word_id + 1
else:
tensor[i] = vocab[word]
seq_len = min(len(tokens), max_seq_len)
mask_2d = mask_3d[j]
mask_2d[:seq_len, :seq_len] = 0
new_dataset = {
"sentence": output,
"label": dataset["label"].astype(np.float32),
"mask": mask_3d.reshape((dataset_size, 1, max_seq_len, max_seq_len)),
}
return new_dataset, vocab, word_id
def get_datasets(max_seq_len):
"""Load GLUE train and test datasets into memory."""
vocab = {}
word_id = 0
train_ds = load_dataset("glue", "cola", split="train")
train_ds.set_format(type="np")
train_ds, vocab, word_id = data_preprocess(train_ds, vocab, word_id, max_seq_len)
test_ds = load_dataset("glue", "cola", split="validation")
test_ds.set_format(type="np")
test_ds, vocab, word_id = data_preprocess(test_ds, vocab, word_id, max_seq_len)
return train_ds, test_ds, word_id
def check_fp8(state, var_collect, inputs, masks, labels):
"Check if model includes FP8."
rngs = {DROPOUT_KEY: jax.random.PRNGKey(0)}
assert "fp8_" in str(
jax.make_jaxpr(train_step)(state, inputs, masks, labels, var_collect, rngs)
)
def get_params_sharding(sharding_rules, abs_var_collect, mesh):
"""Refer params to create params sharding"""
rules_dict = dict(sharding_rules)
def to_device_axis(logical_axis):
partitions = [rules_dict[key] for key in logical_axis]
return NamedSharding(mesh, PartitionSpec(*partitions))
params_axes = abs_var_collect.get(PARAMS_AXES_KEY, {})
params_axes_sharding = jax.tree_util.tree_map(
to_device_axis, nn_partitioning.get_axis_names(params_axes)
)
params_axes_sharding = flax.core.unfreeze(params_axes_sharding)
params_sharding = jax.tree_util.tree_map(
lambda x: NamedSharding(mesh, PartitionSpec(None)), abs_var_collect[PARAMS_KEY]
)
params_sharding = {**params_sharding, **params_axes_sharding}
return params_sharding
def get_state_sharding(state, params_sharding):
"""Refer params_sharding to create state sharding"""
def replace_params(x):
return params_sharding if isinstance(x, dict) else None
state_sharding = jax.tree_util.tree_map(
replace_params, state, is_leaf=lambda x: isinstance(x, dict)
)
return state_sharding
def train_and_evaluate(args):
"""Execute model training and evaluation loop."""
print(args)
train_ds, test_ds, num_embed = get_datasets(args.max_seq_len)
num_gpu = jax.local_device_count()
num_gpu_tp = 2
if num_gpu % num_gpu_tp == 0:
num_gpu_dp = num_gpu // num_gpu_tp
else:
num_gpu_dp = 1
num_gpu_tp = 1
assert args.batch_size % num_gpu_dp == 0, f"Batch size needs to be multiple of {num_gpu_dp}"
assert (
args.test_batch_size % num_gpu_dp == 0
), f"Test batch size needs to be multiple of {num_gpu_dp}"
device_mesh = mesh_utils.create_device_mesh((num_gpu_dp, num_gpu_tp))
with jax.sharding.Mesh(
devices=device_mesh, axis_names=(DEVICE_DP_AXIS, DEVICE_TP_AXIS)
) as mesh:
rng = jax.random.PRNGKey(args.seed)
rng, params_rng = jax.random.split(rng)
rng, dropout_rng = jax.random.split(rng)
init_rngs = {PARAMS_KEY: params_rng, DROPOUT_KEY: dropout_rng}
input_shape = [args.batch_size, args.max_seq_len]
mask_shape = [args.batch_size, 1, args.max_seq_len, args.max_seq_len]
label_shape = [args.batch_size]
with te.fp8_autocast(
args.use_fp8, mesh_resource=te.MeshResource(DEVICE_DP_AXIS, DEVICE_TP_AXIS, None, None)
):
encoder = Net(num_embed, args.enable_sp)
inputs = jnp.zeros(input_shape, dtype=jnp.int32)
masks = jnp.zeros(mask_shape, dtype=jnp.uint8)
abs_var_collect = jax.eval_shape(encoder.init, init_rngs, inputs, masks)
customized_rules = ((NAMED_BROADCAST_AXIS, None), (NAMED_TP_AXIS, DEVICE_TP_AXIS))
sharding_rules = te_flax.extend_logical_axis_rules(tuple()) + customized_rules
params_sharding = get_params_sharding(sharding_rules, abs_var_collect, mesh)
inputs_sharding = NamedSharding(mesh, PartitionSpec(DEVICE_DP_AXIS, None))
masks_sharding = NamedSharding(mesh, PartitionSpec(DEVICE_DP_AXIS, None, None, None))
in_shardings = (None, inputs_sharding, masks_sharding)
out_shardings = {
                key: params_sharding if key == PARAMS_KEY else None for key in abs_var_collect
}
jit_encoder_init = jax.jit(encoder.init, in_shardings, out_shardings)
var_collect = jit_encoder_init(init_rngs, inputs, masks)
optimizer = optax.adamw(args.lr)
var_collect, params = flax.core.pop(var_collect, PARAMS_KEY)
state = train_state.TrainState.create(
apply_fn=encoder.apply, params=params, tx=optimizer
)
state_sharding = get_state_sharding(state, params_sharding)
labels_sharding = NamedSharding(mesh, PartitionSpec(DEVICE_DP_AXIS))
in_shardings = (
state_sharding,
inputs_sharding,
masks_sharding,
labels_sharding,
None,
None,
)
out_shardings = (state_sharding, None, None, None)
jit_train_step = jax.jit(train_step, in_shardings, out_shardings)
in_shardings = (state_sharding, inputs_sharding, masks_sharding, labels_sharding, None)
out_shardings = (None, None)
jit_eval_step = jax.jit(eval_step, in_shardings, out_shardings)
if args.use_fp8:
labels = jnp.zeros(label_shape, dtype=jnp.bfloat16)
check_fp8(state, var_collect, inputs, masks, labels)
if args.dry_run:
labels = jnp.zeros(label_shape, dtype=jnp.bfloat16)
rngs = {DROPOUT_KEY: dropout_rng}
jit_train_step(state, inputs, masks, labels, var_collect, rngs)
print("PASSED")
return None
for epoch in range(1, args.epochs + 1):
rng, input_rng = jax.random.split(rng)
rng, dropout_rng = jax.random.split(rng)
rngs = {INPUT_KEY: input_rng, DROPOUT_KEY: dropout_rng}
state, train_loss, train_accuracy, var_collect = train_epoch(
state, train_ds, args.batch_size, rngs, var_collect, jit_train_step
)
test_loss, test_accuracy = eval_model(
state, test_ds, args.test_batch_size, var_collect, jit_eval_step
)
print(
f"Epoch: {epoch:>2} "
f"Train Loss: {train_loss:.6f} "
f"Train Accuracy: {train_accuracy:.6f} "
f"Test Loss: {test_loss:.6f} "
f"Test Accuracy: {test_accuracy:.6f} "
)
return [train_loss, train_accuracy, test_loss, test_accuracy]
def encoder_parser(args):
"""Training settings."""
parser = argparse.ArgumentParser(description="JAX Encoder Example")
parser.add_argument(
"--batch-size",
type=int,
default=64,
metavar="N",
help="input batch size for training (default: 64)",
)
parser.add_argument(
"--test-batch-size",
type=int,
default=64,
metavar="N",
help="input batch size for testing (default: 64)",
)
parser.add_argument(
"--max-seq-len",
type=int,
default=32,
metavar="N",
help="maximum sequence length (default: 32)",
)
parser.add_argument(
"--epochs",
type=int,
default=3,
metavar="N",
help="number of epochs to train (default: 3)",
)
parser.add_argument(
"--lr",
type=float,
default=0.0001,
metavar="LR",
help="learning rate (default: 0.0001)",
)
parser.add_argument(
"--dry-run",
action="store_true",
default=False,
help="quickly check a single pass",
)
parser.add_argument("--seed", type=int, default=0, metavar="S", help="random seed (default: 0)")
parser.add_argument(
"--use-fp8",
action="store_true",
default=False,
help="Use FP8 for inference and training without recalibration",
)
parser.add_argument(
"--enable-sp", action="store_true", default=False, help="Enable sequence parallelism."
)
return parser.parse_args(args)
class TestEncoder(unittest.TestCase):
"""Encoder unittests"""
gpu_has_fp8, reason = te.fp8.is_fp8_available()
@classmethod
def setUpClass(cls):
"""Run 3 epochs for testing"""
cls.args = encoder_parser(["--epochs", "3"])
@unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
def test_te_bf16(self):
"""Test Transformer Engine with BF16"""
actual = train_and_evaluate(self.args)
assert actual[0] < 0.45 and actual[1] > 0.79
@unittest.skipIf(not gpu_has_fp8, reason)
def test_te_fp8(self):
"""Test Transformer Engine with FP8"""
self.args.use_fp8 = True
actual = train_and_evaluate(self.args)
assert actual[0] < 0.455 and actual[1] > 0.785
@unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
def test_te_bf16_sp(self):
"""Test Transformer Engine with BF16 + SP"""
self.args.enable_sp = True
actual = train_and_evaluate(self.args)
assert actual[0] < 0.45 and actual[1] > 0.79
@unittest.skipIf(not gpu_has_fp8, reason)
def test_te_fp8_sp(self):
"""Test Transformer Engine with FP8 + SP"""
self.args.enable_sp = True
self.args.use_fp8 = True
actual = train_and_evaluate(self.args)
assert actual[0] < 0.455 and actual[1] > 0.785
if __name__ == "__main__":
train_and_evaluate(encoder_parser(None))
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Encoder training on multi-GPU with data parallelism"""
import argparse
import unittest
from functools import partial
import flax
import jax
import jax.numpy as jnp
import nltk
import numpy as np
import optax
from datasets import load_dataset
from flax import linen as nn
from flax.linen import partitioning as nn_partitioning
from flax.training import train_state
from jax.experimental import mesh_utils
from jax.sharding import PartitionSpec, NamedSharding
import transformer_engine.jax as te
import transformer_engine.jax.flax as te_flax
from common import is_bf16_supported
DEVICE_DP_AXIS = "data"
PARAMS_KEY = "params"
PARAMS_AXES_KEY = PARAMS_KEY + "_axes"
DROPOUT_KEY = "dropout"
INPUT_KEY = "input_rng"
class Net(nn.Module):
"""NLP Encoder"""
num_embed: int
@nn.compact
def __call__(self, x, mask, disable_dropout=False):
x = nn.Embed(num_embeddings=self.num_embed, features=256, dtype=jnp.bfloat16)(x)
te_Encoder = partial(
te_flax.TransformerLayer,
hidden_size=256,
mlp_hidden_size=1024,
num_attention_heads=8,
hidden_dropout=0.1,
attention_dropout=0.1,
dropout_rng_name=DROPOUT_KEY,
layer_type=te_flax.TransformerLayerType.ENCODER,
self_attn_mask_type="padding",
enable_relative_embedding=False,
)
x = te_Encoder()(x, attention_mask=mask, deterministic=disable_dropout)
x = x.reshape(x.shape[0], -1)
x = te_flax.DenseGeneral(features=256)(x)
x = te_flax.DenseGeneral(features=256)(x)
x = nn.Dense(features=2)(x)
return x
def train_step(state, inputs, masks, labels, var_collect, rngs):
"""Computes gradients, loss and accuracy for a single batch."""
def loss_fn(var_collect, disable_dropout=False):
logits = state.apply_fn(var_collect, inputs, masks, disable_dropout, rngs=rngs)
one_hot = jax.nn.one_hot(labels.astype(jnp.int32), 2)
loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
return loss, logits
var_collect = {**var_collect, PARAMS_KEY: state.params}
grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
(loss, logits), grads = grad_fn(var_collect)
accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
var_collect, grads = flax.core.pop(grads, PARAMS_KEY)
state = state.apply_gradients(grads=grads)
return state, loss, accuracy, var_collect
def train_epoch(state, train_ds, batch_size, rngs, var_collect, train_fn):
"""Train for a single epoch."""
train_ds_size = len(train_ds["sentence"])
steps_per_epoch = train_ds_size // batch_size
perms = jax.random.permutation(rngs[INPUT_KEY], train_ds_size)
perms = perms[: steps_per_epoch * batch_size] # skip incomplete batch
perms = perms.reshape((steps_per_epoch, batch_size))
epoch_loss = []
epoch_accuracy = []
for perm in perms:
batch_inputs = train_ds["sentence"][perm, ...]
batch_masks = train_ds["mask"][perm, ...]
batch_labels = train_ds["label"][perm, ...]
state, loss, accuracy, var_collect = train_fn(
state, batch_inputs, batch_masks, batch_labels, var_collect, rngs
)
epoch_loss.append(loss)
epoch_accuracy.append(accuracy)
avg_loss = np.mean(epoch_loss)
avg_accuracy = np.mean(epoch_accuracy)
return state, avg_loss, avg_accuracy, var_collect
def eval_step(state, inputs, masks, labels, var_collect):
"""Computes loss and accuracy for a single batch."""
def loss_fn(var_collect, disable_dropout=False):
logits = state.apply_fn(var_collect, inputs, masks, disable_dropout)
one_hot = jax.nn.one_hot(labels.astype(jnp.int32), 2)
loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
return loss, logits
var_collect = {**var_collect, PARAMS_KEY: state.params}
loss, logits = loss_fn(var_collect, disable_dropout=True)
accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
return loss, accuracy
def eval_model(state, test_ds, batch_size, var_collect, eval_fn):
"""Evaluation loop."""
test_ds_size = len(test_ds["sentence"])
num_steps = test_ds_size // batch_size
valid_size = num_steps * batch_size
all_loss = []
all_accuracy = []
for batch_start in range(0, valid_size, batch_size):
batch_end = batch_start + batch_size
batch_inputs = test_ds["sentence"][batch_start:batch_end]
batch_masks = test_ds["mask"][batch_start:batch_end]
batch_labels = test_ds["label"][batch_start:batch_end]
loss, accuracy = eval_fn(state, batch_inputs, batch_masks, batch_labels, var_collect)
all_loss.append(loss)
all_accuracy.append(accuracy)
avg_loss = np.mean(all_loss)
avg_accuracy = np.mean(all_accuracy)
return avg_loss, avg_accuracy
def data_preprocess(dataset, vocab, word_id, max_seq_len):
"""Convert tokens to numbers."""
nltk.download("punkt_tab")
dataset_size = len(dataset["sentence"])
output = np.zeros((dataset_size, max_seq_len), dtype=np.int32)
mask_3d = np.ones((dataset_size, max_seq_len, max_seq_len), dtype=np.uint8)
for j, sentence in enumerate(dataset["sentence"]):
tokens = nltk.word_tokenize(sentence)
tensor = output[j]
for i, word in enumerate(tokens):
if i >= max_seq_len:
break
if word not in vocab:
vocab[word] = word_id
tensor[i] = word_id
word_id = word_id + 1
else:
tensor[i] = vocab[word]
seq_len = min(len(tokens), max_seq_len)
mask_2d = mask_3d[j]
mask_2d[:seq_len, :seq_len] = 0
new_dataset = {
"sentence": output,
"label": dataset["label"].astype(np.float32),
"mask": mask_3d.reshape((dataset_size, 1, max_seq_len, max_seq_len)),
}
return new_dataset, vocab, word_id
def get_datasets(max_seq_len):
"""Load GLUE train and test datasets into memory."""
vocab = {}
word_id = 0
train_ds = load_dataset("glue", "cola", split="train")
train_ds.set_format(type="np")
train_ds, vocab, word_id = data_preprocess(train_ds, vocab, word_id, max_seq_len)
test_ds = load_dataset("glue", "cola", split="validation")
test_ds.set_format(type="np")
test_ds, vocab, word_id = data_preprocess(test_ds, vocab, word_id, max_seq_len)
return train_ds, test_ds, word_id
def check_fp8(state, var_collect, inputs, masks, labels):
"Check if model includes FP8."
rngs = {DROPOUT_KEY: jax.random.PRNGKey(0)}
assert "fp8_" in str(
jax.make_jaxpr(train_step)(state, inputs, masks, labels, var_collect, rngs)
)
def get_params_sharding(sharding_rules, abs_var_collect, mesh):
"""Refer params to create params sharding"""
rules_dict = dict(sharding_rules)
def to_device_axis(logical_axis):
partitions = [rules_dict[key] for key in logical_axis]
return NamedSharding(mesh, PartitionSpec(*partitions))
params_axes = abs_var_collect.get(PARAMS_AXES_KEY, {})
params_axes_sharding = jax.tree_util.tree_map(
to_device_axis, nn_partitioning.get_axis_names(params_axes)
)
params_axes_sharding = flax.core.unfreeze(params_axes_sharding)
params_sharding = jax.tree_util.tree_map(
lambda x: NamedSharding(mesh, PartitionSpec(None)), abs_var_collect[PARAMS_KEY]
)
params_sharding = {**params_sharding, **params_axes_sharding}
return params_sharding
def get_state_sharding(state, params_sharding):
"""Refer params_sharding to create state sharding"""
def replace_params(x):
return params_sharding if isinstance(x, dict) else None
state_sharding = jax.tree_util.tree_map(
replace_params, state, is_leaf=lambda x: isinstance(x, dict)
)
return state_sharding
def train_and_evaluate(args):
"""Execute model training and evaluation loop."""
print(args)
train_ds, test_ds, num_embed = get_datasets(args.max_seq_len)
num_gpu = jax.local_device_count()
assert args.batch_size % num_gpu == 0, f"Batch size needs to be multiple of {num_gpu}"
assert args.test_batch_size % num_gpu == 0, f"Test batch size needs to be multiple of {num_gpu}"
device_mesh = mesh_utils.create_device_mesh((num_gpu,))
with jax.sharding.Mesh(devices=device_mesh, axis_names=(DEVICE_DP_AXIS,)) as mesh:
rng = jax.random.PRNGKey(args.seed)
rng, params_rng = jax.random.split(rng)
rng, dropout_rng = jax.random.split(rng)
init_rngs = {PARAMS_KEY: params_rng, DROPOUT_KEY: dropout_rng}
input_shape = [args.batch_size, args.max_seq_len]
mask_shape = [args.batch_size, 1, args.max_seq_len, args.max_seq_len]
label_shape = [args.batch_size]
with te.fp8_autocast(
args.use_fp8, mesh_resource=te.MeshResource(DEVICE_DP_AXIS, None, None, None)
):
encoder = Net(num_embed)
inputs = jnp.zeros(input_shape, dtype=jnp.int32)
masks = jnp.zeros(mask_shape, dtype=jnp.uint8)
abs_var_collect = jax.eval_shape(encoder.init, init_rngs, inputs, masks)
sharding_rules = te_flax.extend_logical_axis_rules(tuple())
params_sharding = get_params_sharding(sharding_rules, abs_var_collect, mesh)
inputs_sharding = NamedSharding(mesh, PartitionSpec(DEVICE_DP_AXIS, None))
masks_sharding = NamedSharding(mesh, PartitionSpec(DEVICE_DP_AXIS, None, None, None))
in_shardings = (None, inputs_sharding, masks_sharding)
out_shardings = {
                key: params_sharding if key == PARAMS_KEY else None for key in abs_var_collect
}
jit_encoder_init = jax.jit(encoder.init, in_shardings, out_shardings)
var_collect = jit_encoder_init(init_rngs, inputs, masks)
optimizer = optax.adamw(args.lr)
var_collect, params = flax.core.pop(var_collect, PARAMS_KEY)
state = train_state.TrainState.create(
apply_fn=encoder.apply, params=params, tx=optimizer
)
state_sharding = get_state_sharding(state, params_sharding)
labels_sharding = NamedSharding(
mesh,
PartitionSpec(
DEVICE_DP_AXIS,
),
)
in_shardings = (
state_sharding,
inputs_sharding,
masks_sharding,
labels_sharding,
None,
None,
)
out_shardings = (state_sharding, None, None, None)
jit_train_step = jax.jit(train_step, in_shardings, out_shardings)
in_shardings = (state_sharding, inputs_sharding, masks_sharding, labels_sharding, None)
out_shardings = (None, None)
jit_eval_step = jax.jit(eval_step, in_shardings, out_shardings)
if args.use_fp8:
labels = jnp.zeros(label_shape, dtype=jnp.bfloat16)
check_fp8(state, var_collect, inputs, masks, labels)
if args.dry_run:
labels = jnp.zeros(label_shape, dtype=jnp.bfloat16)
rngs = {DROPOUT_KEY: dropout_rng}
jit_train_step(state, inputs, masks, labels, var_collect, rngs)
print("PASSED")
return None
for epoch in range(1, args.epochs + 1):
rng, input_rng = jax.random.split(rng)
rng, dropout_rng = jax.random.split(rng)
rngs = {INPUT_KEY: input_rng, DROPOUT_KEY: dropout_rng}
state, train_loss, train_accuracy, var_collect = train_epoch(
state, train_ds, args.batch_size, rngs, var_collect, jit_train_step
)
test_loss, test_accuracy = eval_model(
state, test_ds, args.test_batch_size, var_collect, jit_eval_step
)
print(
f"Epoch: {epoch:>2} "
f"Train Loss: {train_loss:.6f} "
f"Train Accuracy: {train_accuracy:.6f} "
f"Test Loss: {test_loss:.6f} "
f"Test Accuracy: {test_accuracy:.6f} "
)
return [train_loss, train_accuracy, test_loss, test_accuracy]
def encoder_parser(args):
"""Training settings."""
parser = argparse.ArgumentParser(description="JAX Encoder Example")
parser.add_argument(
"--batch-size",
type=int,
default=128,
metavar="N",
help="input batch size for training (default: 128)",
)
parser.add_argument(
"--test-batch-size",
type=int,
default=128,
metavar="N",
help="input batch size for testing (default: 128)",
)
parser.add_argument(
"--max-seq-len",
type=int,
default=32,
metavar="N",
help="maximum sequence length (default: 32)",
)
parser.add_argument(
"--epochs",
type=int,
default=3,
metavar="N",
help="number of epochs to train (default: 3)",
)
parser.add_argument(
"--lr",
type=float,
default=0.0001,
metavar="LR",
help="learning rate (default: 0.0001)",
)
parser.add_argument(
"--dry-run",
action="store_true",
default=False,
help="quickly check a single pass",
)
parser.add_argument("--seed", type=int, default=0, metavar="S", help="random seed (default: 0)")
parser.add_argument(
"--use-fp8",
action="store_true",
default=False,
help="Use FP8 for inference and training without recalibration",
)
return parser.parse_args(args)
class TestEncoder(unittest.TestCase):
"""Encoder unittests"""
gpu_has_fp8, reason = te.fp8.is_fp8_available()
@classmethod
def setUpClass(cls):
"""Run 3 epochs for testing"""
cls.args = encoder_parser(["--epochs", "3"])
@unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
def test_te_bf16(self):
"""Test Transformer Engine with BF16"""
actual = train_and_evaluate(self.args)
assert actual[0] < 0.50 and actual[1] > 0.76
@unittest.skipIf(not gpu_has_fp8, reason)
def test_te_fp8(self):
"""Test Transformer Engine with FP8"""
self.args.use_fp8 = True
actual = train_and_evaluate(self.args)
assert actual[0] < 0.50 and actual[1] > 0.76
if __name__ == "__main__":
train_and_evaluate(encoder_parser(None))
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Encoder training with multi-GPU, multiprocessing, and tensor parallelism"""
import argparse
import os
import unittest
from functools import partial
import pytest
import flax
import jax
import jax.numpy as jnp
import nltk
import numpy as np
import optax
from datasets import load_dataset
from flax import linen as nn
from flax.linen import partitioning as nn_partitioning
from flax.training import train_state
from jax.experimental import mesh_utils
from jax.sharding import PartitionSpec, NamedSharding
from common import is_bf16_supported, is_fp8_supported
import transformer_engine.jax as te
import transformer_engine.jax.flax as te_flax
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
DEVICE_DP_AXIS = "data"
DEVICE_TP_AXIS = "model"
NAMED_BROADCAST_AXIS = "my_broadcast_axis"
NAMED_TP_AXIS = "my_tp_axis"
PARAMS_KEY = "params"
PARAMS_AXES_KEY = PARAMS_KEY + "_axes"
DROPOUT_KEY = "dropout"
INPUT_KEY = "input_rng"
class Net(nn.Module):
"""NLP Encoder"""
num_embed: int
@nn.compact
def __call__(self, x, mask, disable_dropout=False):
x = nn.Embed(num_embeddings=self.num_embed, features=256, dtype=jnp.bfloat16)(x)
te_Encoder = partial(
te_flax.TransformerLayer,
hidden_size=256,
mlp_hidden_size=1024,
num_attention_heads=8,
hidden_dropout=0.1,
attention_dropout=0.1,
dropout_rng_name=DROPOUT_KEY,
layer_type=te_flax.TransformerLayerType.ENCODER,
self_attn_mask_type="padding",
enable_relative_embedding=False,
)
x = te_Encoder()(x, attention_mask=mask, deterministic=disable_dropout)
x = x.reshape(x.shape[0], -1)
x = te_flax.DenseGeneral(
features=256,
kernel_axes=(NAMED_BROADCAST_AXIS, NAMED_TP_AXIS),
bias_axes=(NAMED_TP_AXIS,),
)(x)
x = te_flax.DenseGeneral(
features=256,
kernel_axes=(NAMED_TP_AXIS, NAMED_BROADCAST_AXIS),
bias_axes=(NAMED_BROADCAST_AXIS,),
)(x)
x = nn.Dense(features=2)(x)
return x
def valid_shard_size(total_size, batch_size, dp_size, tp_size):
"""Get sharded input shape"""
global_batch_size = dp_size * batch_size
num_steps = total_size // global_batch_size
valid_size = num_steps * global_batch_size
gpu_id = jax.local_devices()[0].id
tp_group_id = gpu_id // tp_size
return valid_size, global_batch_size, num_steps, tp_group_id
def shard_array_wrapper(dataset, batch_size, mesh, pspec, enable_partition=False):
"""Generate needed args for jax.make_array_from_single_device_arrays"""
inputs = jnp.asarray(dataset)
total_input_size = len(inputs)
(dp_size, tp_size) = mesh.device_ids.shape
valid_input_size, global_batch_size, num_steps, tp_group_id = valid_shard_size(
total_input_size, batch_size, dp_size, tp_size
)
inputs = inputs[:valid_input_size] # skip incomplete batch
single_input_shape = inputs.shape[1:]
global_input_shape = (global_batch_size, *single_input_shape)
named_sharding = jax.sharding.NamedSharding(mesh, pspec)
if enable_partition:
inputs = inputs.reshape(dp_size, num_steps, batch_size, *single_input_shape)
inputs = inputs[tp_group_id]
return global_input_shape, named_sharding, inputs
def train_step(state, inputs, masks, labels, var_collect, rngs):
"""Computes gradients, loss and accuracy for a single batch."""
def loss_fn(var_collect, disable_dropout=False):
logits = state.apply_fn(var_collect, inputs, masks, disable_dropout, rngs=rngs)
one_hot = jax.nn.one_hot(labels, 2)
loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
return loss, logits
var_collect = {**var_collect, PARAMS_KEY: state.params}
grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
(loss, logits), grads = grad_fn(var_collect)
accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
var_collect, grads = flax.core.pop(grads, PARAMS_KEY)
state = state.apply_gradients(grads=grads)
return state, loss, accuracy, var_collect
def train_epoch(
state,
train_ds,
batch_size,
rngs,
var_collect,
train_fn,
mesh,
inputs_pspec,
masks_pspec,
labels_pspec,
):
"""Train for a single epoch."""
total_batch_size = len(train_ds["sentence"])
(dp_size, tp_size) = mesh.device_ids.shape
valid_size, _, num_steps, tp_group_id = valid_shard_size(
total_batch_size, batch_size, dp_size, tp_size
)
perms = jax.random.permutation(rngs[INPUT_KEY], valid_size)
perms = perms.reshape(dp_size, num_steps, batch_size)
perms = perms[tp_group_id]
global_input_shape, input_named_sharding, sentence = shard_array_wrapper(
train_ds["sentence"], batch_size, mesh, inputs_pspec
)
global_mask_shape, mask_named_sharding, mask = shard_array_wrapper(
train_ds["mask"], batch_size, mesh, masks_pspec
)
global_label_shape, label_named_sharding, label = shard_array_wrapper(
train_ds["label"], batch_size, mesh, labels_pspec
)
epoch_loss = []
epoch_accuracy = []
for perm in perms:
batch_input = sentence[perm, ...]
batch_mask = mask[perm, ...]
batch_label = label[perm, ...]
shard_input = jax.make_array_from_single_device_arrays(
global_input_shape, input_named_sharding, [batch_input]
)
shard_mask = jax.make_array_from_single_device_arrays(
global_mask_shape, mask_named_sharding, [batch_mask]
)
shard_label = jax.make_array_from_single_device_arrays(
global_label_shape, label_named_sharding, [batch_label]
)
state, loss, accuracy, var_collect = train_fn(
state, shard_input, shard_mask, shard_label, var_collect, rngs
)
epoch_loss.append(loss)
epoch_accuracy.append(accuracy)
avg_loss = np.mean(epoch_loss)
avg_accuracy = np.mean(epoch_accuracy)
return state, avg_loss, avg_accuracy, var_collect
def eval_step(state, inputs, masks, labels, var_collect):
"""Computes loss and accuracy for a single batch."""
def loss_fn(var_collect, disable_dropout=False):
logits = state.apply_fn(var_collect, inputs, masks, disable_dropout)
one_hot = jax.nn.one_hot(labels, 2)
loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
return loss, logits
var_collect = {**var_collect, PARAMS_KEY: state.params}
loss, logits = loss_fn(var_collect, disable_dropout=True)
accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
return loss, accuracy
def eval_model(
state, test_ds, batch_size, var_collect, eval_fn, mesh, inputs_pspec, masks_pspec, labels_pspec
):
"""Evaluation loop."""
global_input_shape, input_named_sharding, sentence = shard_array_wrapper(
test_ds["sentence"], batch_size, mesh, inputs_pspec, enable_partition=True
)
global_mask_shape, mask_named_sharding, mask = shard_array_wrapper(
test_ds["mask"], batch_size, mesh, masks_pspec, enable_partition=True
)
global_label_shape, label_named_sharding, label = shard_array_wrapper(
test_ds["label"], batch_size, mesh, labels_pspec, enable_partition=True
)
all_loss = []
all_accuracy = []
for batch_input, batch_mask, batch_label in zip(sentence, mask, label):
shard_input = jax.make_array_from_single_device_arrays(
global_input_shape, input_named_sharding, [batch_input]
)
shard_mask = jax.make_array_from_single_device_arrays(
global_mask_shape, mask_named_sharding, [batch_mask]
)
shard_label = jax.make_array_from_single_device_arrays(
global_label_shape, label_named_sharding, [batch_label]
)
loss, accuracy = eval_fn(state, shard_input, shard_mask, shard_label, var_collect)
all_loss.append(loss)
all_accuracy.append(accuracy)
avg_loss = np.mean(all_loss)
avg_accuracy = np.mean(all_accuracy)
return avg_loss, avg_accuracy
def data_preprocess(dataset, vocab, word_id, max_seq_len):
"""Convert tokens to numbers."""
dataset_size = len(dataset["sentence"])
output = np.zeros((dataset_size, max_seq_len), dtype=np.int32)
mask_3d = np.ones((dataset_size, max_seq_len, max_seq_len), dtype=np.uint8)
for j, sentence in enumerate(dataset["sentence"]):
tokens = nltk.word_tokenize(sentence)
tensor = output[j]
for i, word in enumerate(tokens):
if i >= max_seq_len:
break
if word not in vocab:
vocab[word] = word_id
tensor[i] = word_id
word_id = word_id + 1
else:
tensor[i] = vocab[word]
seq_len = min(len(tokens), max_seq_len)
mask_2d = mask_3d[j]
mask_2d[:seq_len, :seq_len] = 0
new_dataset = {
"sentence": output,
"label": dataset["label"].astype(np.float32),
"mask": mask_3d.reshape((dataset_size, 1, max_seq_len, max_seq_len)),
}
return new_dataset, vocab, word_id
def get_datasets(max_seq_len):
"""Load GLUE train and test datasets into memory."""
vocab = {}
word_id = 0
train_ds = load_dataset("glue", "cola", split="train")
train_ds.set_format(type="np")
train_ds, vocab, word_id = data_preprocess(train_ds, vocab, word_id, max_seq_len)
test_ds = load_dataset("glue", "cola", split="validation")
test_ds.set_format(type="np")
test_ds, vocab, word_id = data_preprocess(test_ds, vocab, word_id, max_seq_len)
return train_ds, test_ds, word_id
def check_fp8(state, var_collect, inputs, masks, labels):
"Check if model includes FP8."
rngs = {DROPOUT_KEY: jax.random.PRNGKey(0)}
assert "fp8_" in str(
jax.make_jaxpr(train_step)(state, inputs, masks, labels, var_collect, rngs)
)
def get_params_sharding(sharding_rules, abs_var_collect, mesh):
"""Refer params to create params sharding"""
rules_dict = dict(sharding_rules)
def to_device_axis(logical_axis):
partitions = [rules_dict[key] for key in logical_axis]
return NamedSharding(mesh, jax.sharding.PartitionSpec(*partitions))
params_axes = abs_var_collect.get(PARAMS_AXES_KEY, {})
params_axes_sharding = jax.tree_util.tree_map(
to_device_axis, nn_partitioning.get_axis_names(params_axes)
)
params_axes_sharding = flax.core.unfreeze(params_axes_sharding)
params_sharding = jax.tree_util.tree_map(
lambda x: NamedSharding(mesh, PartitionSpec(None)), abs_var_collect[PARAMS_KEY]
)
params_sharding = {**params_sharding, **params_axes_sharding}
return params_sharding
def get_state_sharding(state, params_sharding):
"""Refer params_sharding to create state sharding"""
def replace_params(x):
return params_sharding if isinstance(x, dict) else None
state_sharding = jax.tree_util.tree_map(
replace_params, state, is_leaf=lambda x: isinstance(x, dict)
)
return state_sharding
def train_and_evaluate(args):
"""Execute model training and evaluation loop."""
print(args)
if args.process_id == 0:
nltk.download("punkt_tab")
train_ds, test_ds, num_embed = get_datasets(args.max_seq_len)
jax.distributed.initialize(
coordinator_address=args.coordinator_address,
num_processes=args.num_process,
process_id=args.process_id,
local_device_ids=args.process_id,
)
assert jax.local_device_count() == 1, "1 GPU per process"
num_gpu_tp = 2
if args.num_process % num_gpu_tp == 0:
num_gpu_dp = args.num_process // num_gpu_tp
else:
assert args.num_process == 1, "number of processes should be multiple of 2, or 1"
num_gpu_dp = 1
num_gpu_tp = 1
assert args.batch_size % num_gpu_dp == 0, f"Batch size needs to be multiple of {num_gpu_dp}"
assert (
args.test_batch_size % num_gpu_dp == 0
), f"Test batch size needs to be multiple of {num_gpu_dp}"
device_mesh = mesh_utils.create_device_mesh((num_gpu_dp, num_gpu_tp))
with jax.sharding.Mesh(
devices=device_mesh, axis_names=(DEVICE_DP_AXIS, DEVICE_TP_AXIS)
) as mesh:
rng = jax.random.PRNGKey(args.seed)
rng, params_rng = jax.random.split(rng)
rng, dropout_rng = jax.random.split(rng)
init_rngs = {PARAMS_KEY: params_rng, DROPOUT_KEY: dropout_rng}
input_shape = [args.batch_size, args.max_seq_len]
mask_shape = [args.batch_size, 1, args.max_seq_len, args.max_seq_len]
label_shape = [args.batch_size]
with te.fp8_autocast(
args.use_fp8, mesh_resource=te.MeshResource(DEVICE_DP_AXIS, DEVICE_TP_AXIS, None, None)
):
encoder = Net(num_embed)
inputs = jnp.zeros(input_shape, dtype=jnp.int32)
masks = jnp.zeros(mask_shape, dtype=jnp.uint8)
abs_var_collect = jax.eval_shape(encoder.init, init_rngs, inputs, masks)
customized_rules = ((NAMED_BROADCAST_AXIS, None), (NAMED_TP_AXIS, DEVICE_TP_AXIS))
sharding_rules = te_flax.extend_logical_axis_rules(tuple()) + customized_rules
params_sharding = get_params_sharding(sharding_rules, abs_var_collect, mesh)
inputs_pspec = jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None)
masks_pspec = jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None, None, None)
inputs_sharding = NamedSharding(mesh, inputs_pspec)
masks_sharding = NamedSharding(mesh, masks_pspec)
in_shardings = (None, inputs_sharding, masks_sharding)
out_shardings = {
                key: params_sharding if key == PARAMS_KEY else None for key in abs_var_collect
}
jit_encoder_init = jax.jit(encoder.init, in_shardings, out_shardings)
var_collect = jit_encoder_init(init_rngs, inputs, masks)
optimizer = optax.adamw(args.lr)
var_collect, params = flax.core.pop(var_collect, PARAMS_KEY)
state = train_state.TrainState.create(
apply_fn=encoder.apply, params=params, tx=optimizer
)
state_sharding = get_state_sharding(state, params_sharding)
labels_sharding = NamedSharding(mesh, PartitionSpec(DEVICE_DP_AXIS))
in_shardings = (
state_sharding,
inputs_sharding,
masks_sharding,
labels_sharding,
None,
None,
)
out_shardings = (state_sharding, None, None, None)
jit_train_step = jax.jit(train_step, in_shardings, out_shardings)
in_shardings = (state_sharding, inputs_sharding, masks_sharding, labels_sharding, None)
out_shardings = (None, None)
jit_eval_step = jax.jit(eval_step, in_shardings, out_shardings)
if args.use_fp8:
labels = jnp.zeros(label_shape, dtype=jnp.bfloat16)
check_fp8(state, var_collect, inputs, masks, labels)
            if args.dry_run:
                labels = jnp.zeros(label_shape, dtype=jnp.bfloat16)
                rngs = {DROPOUT_KEY: dropout_rng}
                jit_train_step(state, inputs, masks, labels, var_collect, rngs)
                print("PASSED")
                # A dry run skips training and evaluation, so report no metrics.
                train_loss = train_accuracy = test_loss = test_accuracy = None
else:
for epoch in range(1, args.epochs + 1):
rng, input_rng = jax.random.split(rng)
rng, dropout_rng = jax.random.split(rng)
rngs = {INPUT_KEY: input_rng, DROPOUT_KEY: dropout_rng}
state, train_loss, train_accuracy, var_collect = train_epoch(
state,
train_ds,
args.batch_size,
rngs,
var_collect,
jit_train_step,
mesh,
inputs_pspec,
masks_pspec,
labels_sharding.spec,
)
test_loss, test_accuracy = eval_model(
state,
test_ds,
args.test_batch_size,
var_collect,
jit_eval_step,
mesh,
inputs_pspec,
masks_pspec,
labels_sharding.spec,
)
if args.process_id == 0:
print(
f"Epoch: {epoch:>2} "
f"Train Loss: {train_loss:.6f} "
f"Train Accuracy: {train_accuracy:.6f} "
f"Test Loss: {test_loss:.6f} "
f"Test Accuracy: {test_accuracy:.6f} "
)
jax.distributed.shutdown()
return [train_loss, train_accuracy, test_loss, test_accuracy]
def encoder_parser(args):
"""Training settings."""
parser = argparse.ArgumentParser(description="JAX Encoder Example")
parser.add_argument(
"--batch-size",
type=int,
default=64,
metavar="N",
help="input batch size for training (default: 64)",
)
parser.add_argument(
"--test-batch-size",
type=int,
default=64,
metavar="N",
help="input batch size for testing (default: 64)",
)
parser.add_argument(
"--max-seq-len",
type=int,
default=32,
metavar="N",
help="maximum sequence length (default: 32)",
)
parser.add_argument(
"--epochs",
type=int,
default=3,
metavar="N",
help="number of epochs to train (default: 3)",
)
parser.add_argument(
"--lr",
type=float,
default=0.0001,
metavar="LR",
help="learning rate (default: 0.0001)",
)
parser.add_argument(
"--dry-run",
action="store_true",
default=False,
help="quickly check a single pass",
)
parser.add_argument("--seed", type=int, default=0, metavar="S", help="random seed (default: 0)")
parser.add_argument(
"--use-fp8",
action="store_true",
default=False,
help="Use FP8 for inference and training without recalibration",
)
parser.add_argument(
"--coordinator-address",
type=str,
default="127.0.0.1:1234",
help=(
"the IP address of process 0 and a port on which that"
" process should launch a coordinator service (default:"
" 127.0.0.1:1234)"
),
)
parser.add_argument(
"--num-process", type=int, default=1, help="number of processes (default: 1)"
)
parser.add_argument(
"--process-id",
type=int,
default=0,
help="the ID number of the current process (default: 0)",
)
return parser.parse_args(args)
@pytest.mark.usefixtures("multiprocessing_parses")
class TestEncoder(unittest.TestCase):
"""Encoder unittests"""
gpu_has_fp8 = is_fp8_supported()
gpu_has_bf16 = is_bf16_supported()
def exec(self, use_fp8):
"""Run 3 epochs for testing"""
args = encoder_parser([])
num_gpu = self.num_process
tp_size = 2 if num_gpu > 1 and num_gpu % 2 == 0 else 1
dp_size = num_gpu // tp_size
batch_size = 64 // dp_size
args.use_fp8 = use_fp8
args.batch_size = batch_size
args.test_batch_size = batch_size
args.num_process = num_gpu
args.process_id = self.process_id
return train_and_evaluate(args)
@unittest.skipIf(not gpu_has_bf16, "Device compute capability 8.0+ is required for BF16")
def test_te_bf16(self):
"""Test Transformer Engine with BF16"""
result = self.exec(False)
assert result[0] < 0.45 and result[1] > 0.79
@unittest.skipIf(not gpu_has_fp8, "Device compute capability 9.0+ is required for FP8")
def test_te_fp8(self):
"""Test Transformer Engine with FP8"""
result = self.exec(True)
assert result[0] < 0.455 and result[1] > 0.79
if __name__ == "__main__":
train_and_evaluate(encoder_parser(None))
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Encoder training on single GPU"""
import argparse
import unittest
from functools import partial
import flax
import jax
import jax.numpy as jnp
import nltk
import numpy as np
import optax
from datasets import load_dataset
from flax import linen as nn
from flax.training import train_state
import transformer_engine.jax as te
import transformer_engine.jax.flax as te_flax
from common import is_bf16_supported
PARAMS_KEY = "params"
DROPOUT_KEY = "dropout"
INPUT_KEY = "input_rng"
class Net(nn.Module):
"""NLP Encoder"""
num_embed: int
@nn.compact
def __call__(self, x, mask, disable_dropout=False):
x = nn.Embed(num_embeddings=self.num_embed, features=256, dtype=jnp.bfloat16)(x)
te_Encoder = partial(
te_flax.TransformerLayer,
hidden_size=256,
mlp_hidden_size=1024,
num_attention_heads=8,
hidden_dropout=0.1,
attention_dropout=0.1,
dropout_rng_name=DROPOUT_KEY,
layer_type=te_flax.TransformerLayerType.ENCODER,
self_attn_mask_type="padding",
enable_relative_embedding=False,
)
x = te_Encoder()(x, attention_mask=mask, deterministic=disable_dropout)
x = x.reshape(x.shape[0], -1)
x = te_flax.DenseGeneral(features=256)(x)
x = te_flax.DenseGeneral(features=256)(x)
x = nn.Dense(features=2)(x)
return x
@partial(jax.jit)
def train_step(state, inputs, masks, labels, var_collect, rngs):
"""Computes gradients, loss and accuracy for a single batch."""
def loss_fn(var_collect, disable_dropout=False):
logits = state.apply_fn(var_collect, inputs, masks, disable_dropout, rngs=rngs)
one_hot = jax.nn.one_hot(labels.astype(jnp.int32), 2)
loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
return loss, logits
var_collect = {**var_collect, PARAMS_KEY: state.params}
grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
(loss, logits), grads = grad_fn(var_collect)
accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
var_collect, grads = flax.core.pop(grads, PARAMS_KEY)
state = state.apply_gradients(grads=grads)
return state, loss, accuracy, var_collect
def train_epoch(state, train_ds, batch_size, rngs, var_collect):
"""Train for a single epoch."""
train_ds_size = len(train_ds["sentence"])
steps_per_epoch = train_ds_size // batch_size
perms = jax.random.permutation(rngs[INPUT_KEY], train_ds_size)
perms = perms[: steps_per_epoch * batch_size] # skip incomplete batch
perms = perms.reshape((steps_per_epoch, batch_size))
epoch_loss = []
epoch_accuracy = []
for perm in perms:
batch_inputs = train_ds["sentence"][perm, ...]
batch_masks = train_ds["mask"][perm, ...]
batch_labels = train_ds["label"][perm, ...]
state, loss, accuracy, var_collect = train_step(
state, batch_inputs, batch_masks, batch_labels, var_collect, rngs
)
epoch_loss.append(loss)
epoch_accuracy.append(accuracy)
avg_loss = np.mean(epoch_loss)
avg_accuracy = np.mean(epoch_accuracy)
return state, avg_loss, avg_accuracy, var_collect
@jax.jit
def eval_step(state, inputs, masks, labels, var_collect):
"""Computes loss and accuracy for a single batch."""
def loss_fn(var_collect, disable_dropout=False):
logits = state.apply_fn(var_collect, inputs, masks, disable_dropout)
one_hot = jax.nn.one_hot(labels.astype(jnp.int32), 2)
loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
return loss, logits
var_collect = {**var_collect, PARAMS_KEY: state.params}
loss, logits = loss_fn(var_collect, disable_dropout=True)
accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
return loss, accuracy
def eval_model(state, test_ds, batch_size, var_collect):
"""Evaluation loop."""
test_ds_size = len(test_ds["sentence"])
num_steps = test_ds_size // batch_size
valid_size = num_steps * batch_size
all_loss = []
all_accuracy = []
for batch_start in range(0, valid_size, batch_size):
batch_end = batch_start + batch_size
batch_inputs = test_ds["sentence"][batch_start:batch_end]
batch_masks = test_ds["mask"][batch_start:batch_end]
batch_labels = test_ds["label"][batch_start:batch_end]
loss, accuracy = eval_step(state, batch_inputs, batch_masks, batch_labels, var_collect)
all_loss.append(loss)
all_accuracy.append(accuracy)
avg_loss = np.mean(all_loss)
avg_accuracy = np.mean(all_accuracy)
return avg_loss, avg_accuracy
def data_preprocess(dataset, vocab, word_id, max_seq_len):
"""Convert tokens to numbers."""
nltk.download("punkt_tab")
dataset_size = len(dataset["sentence"])
output = np.zeros((dataset_size, max_seq_len), dtype=np.int32)
mask_3d = np.ones((dataset_size, max_seq_len, max_seq_len), dtype=np.uint8)
for j, sentence in enumerate(dataset["sentence"]):
tokens = nltk.word_tokenize(sentence)
tensor = output[j]
for i, word in enumerate(tokens):
if i >= max_seq_len:
break
if word not in vocab:
vocab[word] = word_id
tensor[i] = word_id
word_id = word_id + 1
else:
tensor[i] = vocab[word]
seq_len = min(len(tokens), max_seq_len)
mask_2d = mask_3d[j]
mask_2d[:seq_len, :seq_len] = 0
new_dataset = {
"sentence": output,
"label": dataset["label"].astype(np.float32),
"mask": mask_3d.reshape((dataset_size, 1, max_seq_len, max_seq_len)),
}
return new_dataset, vocab, word_id
def get_datasets(max_seq_len):
"""Load GLUE train and test datasets into memory."""
vocab = {}
word_id = 0
train_ds = load_dataset("glue", "cola", split="train")
train_ds.set_format(type="np")
train_ds, vocab, word_id = data_preprocess(train_ds, vocab, word_id, max_seq_len)
test_ds = load_dataset("glue", "cola", split="validation")
test_ds.set_format(type="np")
test_ds, vocab, word_id = data_preprocess(test_ds, vocab, word_id, max_seq_len)
return train_ds, test_ds, word_id
def check_fp8(state, var_collect, inputs, masks, labels):
"Check if model includes FP8."
rngs = {DROPOUT_KEY: jax.random.PRNGKey(0)}
assert "fp8_" in str(
jax.make_jaxpr(train_step)(state, inputs, masks, labels, var_collect, rngs)
)
def train_and_evaluate(args):
"""Execute model training and evaluation loop."""
print(args)
train_ds, test_ds, num_embed = get_datasets(args.max_seq_len)
rng = jax.random.PRNGKey(args.seed)
rng, params_rng = jax.random.split(rng)
rng, dropout_rng = jax.random.split(rng)
init_rngs = {PARAMS_KEY: params_rng, DROPOUT_KEY: dropout_rng}
input_shape = [args.batch_size, args.max_seq_len]
mask_shape = [args.batch_size, 1, args.max_seq_len, args.max_seq_len]
label_shape = [args.batch_size]
with te.fp8_autocast(enabled=args.use_fp8):
encoder = Net(num_embed)
        # We use nn.Embed, so the inputs must be integers
inputs = jnp.zeros(input_shape, dtype=jnp.int32)
masks = jnp.zeros(mask_shape, dtype=jnp.uint8)
var_collect = encoder.init(init_rngs, inputs, masks)
tx = optax.adamw(args.lr)
state = train_state.TrainState.create(
apply_fn=encoder.apply, params=var_collect[PARAMS_KEY], tx=tx
)
if args.use_fp8:
labels = jnp.zeros(label_shape, dtype=jnp.bfloat16)
check_fp8(state, var_collect, inputs, masks, labels)
if args.dry_run:
labels = jnp.zeros(label_shape, dtype=jnp.bfloat16)
rngs = {DROPOUT_KEY: dropout_rng}
train_step(state, inputs, masks, labels, var_collect, rngs)
print("PASSED")
return None
for epoch in range(1, args.epochs + 1):
rng, input_rng = jax.random.split(rng)
rng, dropout_rng = jax.random.split(rng)
rngs = {INPUT_KEY: input_rng, DROPOUT_KEY: dropout_rng}
state, train_loss, train_accuracy, var_collect = train_epoch(
state, train_ds, args.batch_size, rngs, var_collect
)
test_loss, test_accuracy = eval_model(state, test_ds, args.test_batch_size, var_collect)
print(
f"Epoch: {epoch:>2} "
f"Train Loss: {train_loss:.6f} "
f"Train Accuracy: {train_accuracy:.6f} "
f"Test Loss: {test_loss:.6f} "
f"Test Accuracy: {test_accuracy:.6f} "
)
return [train_loss, train_accuracy, test_loss, test_accuracy]
def encoder_parser(args):
"""Training settings."""
parser = argparse.ArgumentParser(description="JAX Encoder Example")
parser.add_argument(
"--batch-size",
type=int,
default=64,
metavar="N",
help="input batch size for training (default: 64)",
)
parser.add_argument(
"--test-batch-size",
type=int,
default=64,
metavar="N",
help="input batch size for testing (default: 64)",
)
parser.add_argument(
"--max-seq-len",
type=int,
default=32,
metavar="N",
help="maximum sequence length (default: 32)",
)
parser.add_argument(
"--epochs",
type=int,
default=3,
metavar="N",
help="number of epochs to train (default: 3)",
)
parser.add_argument(
"--lr",
type=float,
default=0.0001,
metavar="LR",
help="learning rate (default: 0.0001)",
)
parser.add_argument(
"--dry-run",
action="store_true",
default=False,
help="quickly check a single pass",
)
parser.add_argument("--seed", type=int, default=0, metavar="S", help="random seed (default: 0)")
parser.add_argument(
"--use-fp8",
action="store_true",
default=False,
help="Use FP8 for inference and training without recalibration",
)
return parser.parse_args(args)
class TestEncoder(unittest.TestCase):
"""Encoder unittests"""
gpu_has_fp8, reason = te.fp8.is_fp8_available()
@classmethod
def setUpClass(cls):
"""Run 4 epochs for testing"""
cls.args = encoder_parser(["--epochs", "3"])
@unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
def test_te_bf16(self):
"""Test Transformer Engine with BF16"""
actual = train_and_evaluate(self.args)
assert actual[0] < 0.45 and actual[1] > 0.79
@unittest.skipIf(not gpu_has_fp8, reason)
def test_te_fp8(self):
"""Test Transformer Engine with FP8"""
self.args.use_fp8 = True
actual = train_and_evaluate(self.args)
assert actual[0] < 0.455 and actual[1] > 0.79
if __name__ == "__main__":
train_and_evaluate(encoder_parser(None))
# Basic MNIST Example with Optional FP8 #
This example uses MNIST training to demonstrate Transformer Engine usage. Transformer Engine is built on top of [Flax](https://github.com/google/flax), a neural network library and ecosystem for JAX, so it interoperates freely with other Flax modules. For basic Flax usage, see [Flax Basics](https://flax.readthedocs.io/en/latest/guides/flax_basics.html).
1. Set up the dataset: The first step is to prepare the dataset. The `get_datasets` routine handles this, downloading the MNIST dataset via the `datasets` library and preprocessing the images.
2. Define the model: The `Net` class is a small CNN model for image classification. It has an option to switch between `nn.Dense` provided by Flax and `te.DenseGeneral` provided by Transformer Engine, which allows for easy comparison between the two libraries.
3. Build the training loop: `train_and_evaluate` is the main routine that initializes the model and runs training and evaluation. For FP8 training, the key is the `te.fp8_autocast` context manager: when it is enabled, it casts all `te.DenseGeneral` layers to FP8 precision. `var_collect` is a collection of the information needed for model training, such as the parameters and the FP8 metadata required to correctly cast BF16 tensors to FP8 tensors at runtime. If fp8_autocast is enabled and you print var_collect, you will see the FP8 metadata inside, such as the `fp8_meta_collection` section. FP8 training and evaluation must be done under fp8_autocast; once the context exits, fp8_autocast deconstructs the FP8 metadata and the model falls back to a higher floating-point precision, such as BF16 in this example. To check whether FP8 is enabled, use the `check_fp8` routine: if model initialization with FP8 works, the string returned by jax.make_jaxpr should include the `Float8` keyword.
4. Training process: In `apply_model`, the main difference from normal Flax usage is that, for FP8 training, the FP8 metadata has to be passed into the gradient function `grad_fn`; otherwise Transformer Engine does not know how to correctly cast BF16 tensors to FP8 tensors at runtime. The FP8 metadata does not belong to the model parameters (`state.params`), so we manually combine the metadata and the latest model parameters into var_collect as a frozen dictionary and feed it to the gradient function (see the sketch after this list).
5. Evaluation process: Evaluation works the same way as training: ensure the FP8 metadata is inside var_collect and pass it to the loss function.
6. Additional options: The `te.fp8_autocast` context manager has additional options:
* FP8 Recipe: control FP8 training behavior. See the [FP8 tutorial](https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/examples/fp8_primer.html) for a detailed explanation of FP8 recipes and the supported options.
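The following is a minimal sketch that condenses steps 3-5 into a single function, adapted from the encoder and MNIST examples in this commit. The function name `fp8_train_step_sketch`, its signature, and the hyperparameter defaults are illustrative assumptions rather than part of the example code:

```python
import flax
import jax
import jax.numpy as jnp
import optax
from flax.training import train_state
import transformer_engine.jax as te


def fp8_train_step_sketch(encoder, images, labels, lr=1e-4, seed=0):
    """One FP8 training step; `encoder` is any Flax module such as Net."""
    params_rng, dropout_rng = jax.random.split(jax.random.PRNGKey(seed))
    init_rngs = {"params": params_rng, "dropout": dropout_rng}
    with te.fp8_autocast(enabled=True):
        # var_collect holds the params plus the FP8 metadata that TE needs
        # to cast BF16 tensors to FP8 at runtime (step 3).
        var_collect = encoder.init(init_rngs, images)
        var_collect, params = flax.core.pop(var_collect, "params")
        state = train_state.TrainState.create(
            apply_fn=encoder.apply, params=params, tx=optax.adamw(lr)
        )

        def loss_fn(variables):
            logits = state.apply_fn(variables, images, rngs={"dropout": dropout_rng})
            one_hot = jax.nn.one_hot(labels, logits.shape[-1])
            return jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))

        # Step 4: merge the FP8 metadata with the latest params and hand the
        # combined collection to the gradient function.
        merged = {**var_collect, "params": state.params}
        loss, grads = jax.value_and_grad(loss_fn)(merged)
        # The non-param gradients carry the refreshed FP8 metadata; keep them
        # as the new var_collect, mirroring train_step in the encoder examples.
        var_collect, param_grads = flax.core.pop(grads, "params")
        state = state.apply_gradients(grads=param_grads)
    return state, loss, var_collect
```

In the actual examples this logic is split across `train_and_evaluate`, `apply_model`, and `update_model`, and the step function is additionally wrapped in `jax.jit`; the sketch only shows how the pieces fit together.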
## Run ##
1. Use Flax to train MNIST with BF16 as usual
```bash
python test_single_gpu_mnist.py
```
2. Use `te.DenseGeneral` provided by Transformer Engine to train MNIST with BF16
```bash
python test_single_gpu_mnist.py --use-te
```
3. Use `te.DenseGeneral` provided by Transformer Engine to train MNIST and enable FP8 training and evaluation.
```bash
python test_single_gpu_mnist.py --use-fp8
```
datasets
flax>=0.7.1
optax
Pillow
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""MNIST training on single GPU"""
import argparse
import unittest
from functools import partial
import jax
import jax.numpy as jnp
import numpy as np
import optax
from datasets import load_dataset
from flax import linen as nn
from flax.training import train_state
import transformer_engine.jax as te
import transformer_engine.jax.flax as te_flax
IMAGE_H = 28
IMAGE_W = 28
IMAGE_C = 1
PARAMS_KEY = "params"
DROPOUT_KEY = "dropout"
INPUT_KEY = "input_rng"
class Net(nn.Module):
"""CNN model for MNIST."""
use_te: bool = False
@nn.compact
def __call__(self, x, disable_dropout=False):
if self.use_te:
nn_Dense = te_flax.DenseGeneral
else:
nn_Dense = nn.Dense
# dtype is used for param init in TE but computation in Linen.nn
dtype = jnp.float32 if self.use_te else jnp.bfloat16
x = nn.Conv(features=32, kernel_size=(3, 3), strides=1, dtype=jnp.bfloat16)(x)
x = nn.relu(x)
x = nn.Conv(features=64, kernel_size=(3, 3), strides=1, dtype=jnp.bfloat16)(x)
x = nn.relu(x)
x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))
x = nn.Dropout(rate=0.25)(x, deterministic=disable_dropout)
x = x.reshape(x.shape[0], -1)
assert x.dtype == jnp.bfloat16
x = nn_Dense(features=128, dtype=dtype)(x)
x = nn.relu(x)
x = nn.Dropout(rate=0.5)(x, deterministic=disable_dropout)
x = nn_Dense(features=16, dtype=dtype)(x)
x = nn_Dense(features=10, dtype=dtype)(x)
assert x.dtype == jnp.bfloat16
return x
@jax.jit
def apply_model(state, images, labels, var_collect, rngs=None):
"""Computes gradients, loss and accuracy for a single batch."""
def loss_fn(var_collect, disable_dropout=False):
logits = state.apply_fn(var_collect, images, disable_dropout, rngs=rngs)
one_hot = jax.nn.one_hot(labels, 10)
loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
return loss, logits
var_collect = {**var_collect, PARAMS_KEY: state.params}
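# rngs carries the dropout key during training; when rngs is None, run a
# deterministic evaluation pass and skip the gradient computation.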
if rngs is not None:
grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
(loss, logits), grads = grad_fn(var_collect)
else:
loss, logits = loss_fn(var_collect, disable_dropout=True)
grads = None
accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
return grads, loss, accuracy
@jax.jit
def update_model(state, grads):
"""Update model params and FP8 meta."""
state = state.apply_gradients(grads=grads[PARAMS_KEY])
return state, grads
def train_epoch(state, train_ds, batch_size, rngs, var_collect):
"""Train for a single epoch."""
train_ds_size = len(train_ds["image"])
steps_per_epoch = train_ds_size // batch_size
perms = jax.random.permutation(rngs[INPUT_KEY], train_ds_size)
perms = perms[: steps_per_epoch * batch_size] # skip incomplete batch
perms = perms.reshape((steps_per_epoch, batch_size))
epoch_loss = []
epoch_accuracy = []
for perm in perms:
batch_images = train_ds["image"][perm, ...]
batch_labels = train_ds["label"][perm, ...]
grads, loss, accuracy = apply_model(state, batch_images, batch_labels, var_collect, rngs)
state, var_collect = update_model(state, grads)
epoch_loss.append(loss)
epoch_accuracy.append(accuracy)
avg_loss = np.mean(epoch_loss)
avg_accuracy = np.mean(epoch_accuracy)
return state, avg_loss, avg_accuracy, var_collect
def eval_model(state, test_ds, batch_size, var_collect):
"""Evaluation loop."""
test_ds_size = len(test_ds["image"])
num_steps = test_ds_size // batch_size
valid_size = num_steps * batch_size
all_loss = []
all_accuracy = []
for batch_start in range(0, valid_size, batch_size):
batch_end = batch_start + batch_size
batch_images = test_ds["image"][batch_start:batch_end]
batch_labels = test_ds["label"][batch_start:batch_end]
_, loss, accuracy = apply_model(state, batch_images, batch_labels, var_collect)
all_loss.append(loss)
all_accuracy.append(accuracy)
avg_loss = np.mean(all_loss)
avg_accuracy = np.mean(all_accuracy)
return avg_loss, avg_accuracy
def get_datasets():
"""Load MNIST train and test datasets into memory."""
train_ds = load_dataset("mnist", split="train", trust_remote_code=True)
train_ds.set_format(type="np")
batch_size = train_ds["image"].shape[0]
shape = (batch_size, IMAGE_H, IMAGE_W, IMAGE_C)
new_train_ds = {
"image": train_ds["image"].astype(np.float32).reshape(shape) / 255.0,
"label": train_ds["label"],
}
test_ds = load_dataset("mnist", split="test", trust_remote_code=True)
test_ds.set_format(type="np")
batch_size = test_ds["image"].shape[0]
shape = (batch_size, IMAGE_H, IMAGE_W, IMAGE_C)
new_test_ds = {
"image": test_ds["image"].astype(np.float32).reshape(shape) / 255.0,
"label": test_ds["label"],
}
return new_train_ds, new_test_ds
def check_fp8(state, var_collect, input_shape, label_shape):
"Check if model includes FP8."
assert "f8_" in str(
jax.make_jaxpr(apply_model)(
state,
jnp.empty(input_shape, dtype=jnp.bfloat16),
jnp.empty(label_shape, dtype=jnp.bfloat16),
var_collect,
)
)
def train_and_evaluate(args):
"""Execute model training and evaluation loop."""
print(args)
if args.use_fp8:
args.use_te = True
train_ds, test_ds = get_datasets()
rng = jax.random.PRNGKey(args.seed)
rng, params_rng = jax.random.split(rng)
rng, dropout_rng = jax.random.split(rng)
init_rngs = {PARAMS_KEY: params_rng, DROPOUT_KEY: dropout_rng}
input_shape = [args.batch_size, IMAGE_H, IMAGE_W, IMAGE_C]
label_shape = [args.batch_size]
with te.fp8_autocast(enabled=args.use_fp8):
cnn = Net(args.use_te)
var_collect = cnn.init(init_rngs, jnp.empty(input_shape, dtype=jnp.bfloat16))
tx = optax.sgd(args.lr, args.momentum)
state = train_state.TrainState.create(
apply_fn=cnn.apply, params=var_collect[PARAMS_KEY], tx=tx
)
if args.use_fp8:
check_fp8(state, var_collect, input_shape, label_shape)
if args.dry_run:
apply_model(
state,
jnp.empty(input_shape, dtype=jnp.bfloat16),
jnp.empty(label_shape, dtype=jnp.bfloat16),
var_collect,
{DROPOUT_KEY: dropout_rng},
)
print("PASSED")
return None
for epoch in range(1, args.epochs + 1):
rng, input_rng = jax.random.split(rng)
rng, dropout_rng = jax.random.split(rng)
rngs = {INPUT_KEY: input_rng, DROPOUT_KEY: dropout_rng}
state, train_loss, train_accuracy, var_collect = train_epoch(
state, train_ds, args.batch_size, rngs, var_collect
)
test_loss, test_accuracy = eval_model(state, test_ds, args.test_batch_size, var_collect)
print(
f"Epoch: {epoch:>2} "
f"Train Loss: {train_loss:.6f} "
f"Train Accuracy: {train_accuracy:.6f} "
f"Test Loss: {test_loss:.6f} "
f"Test Accuracy: {test_accuracy:.6f} "
)
return [train_loss, train_accuracy, test_loss, test_accuracy]
def mnist_parser(args):
"""Training settings."""
parser = argparse.ArgumentParser(description="JAX MNIST Example")
parser.add_argument(
"--batch-size",
type=int,
default=64,
metavar="N",
help="input batch size for training (default: 64)",
)
parser.add_argument(
"--test-batch-size",
type=int,
default=800,
metavar="N",
help="input batch size for testing (default: 800)",
)
parser.add_argument(
"--epochs",
type=int,
default=10,
metavar="N",
help="number of epochs to train (default: 10)",
)
parser.add_argument(
"--lr",
type=float,
default=0.01,
metavar="LR",
help="learning rate (default: 0.01)",
)
parser.add_argument(
"--momentum",
type=float,
default=0.9,
metavar="M",
help="Momentum (default: 0.9)",
)
parser.add_argument(
"--dry-run",
action="store_true",
default=False,
help="quickly check a single pass",
)
parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)")
parser.add_argument(
"--use-fp8",
action="store_true",
default=False,
help=(
"Use FP8 for inference and training without recalibration. "
"It also enables Transformer Engine implicitly."
),
)
parser.add_argument(
"--use-te", action="store_true", default=False, help="Use Transformer Engine"
)
return parser.parse_args(args)
class TestMNIST(unittest.TestCase):
"""MNIST unittests"""
gpu_has_fp8, reason = te.fp8.is_fp8_available()
@classmethod
def setUpClass(cls):
"""Run MNIST without Transformer Engine"""
cls.args = mnist_parser(["--epochs", "5"])
@staticmethod
def verify(actual):
"""Check If loss and accuracy match target"""
desired_traing_loss = 0.055
desired_traing_accuracy = 0.98
desired_test_loss = 0.04
desired_test_accuracy = 0.098
assert actual[0] < desired_traing_loss
assert actual[1] > desired_traing_accuracy
assert actual[2] < desired_test_loss
assert actual[3] > desired_test_accuracy
def test_te_bf16(self):
"""Test Transformer Engine with BF16"""
self.args.use_te = True
self.args.use_fp8 = False
actual = train_and_evaluate(self.args)
self.verify(actual)
@unittest.skipIf(not gpu_has_fp8, reason)
def test_te_fp8(self):
"""Test Transformer Engine with FP8"""
self.args.use_fp8 = True
actual = train_and_evaluate(self.args)
self.verify(actual)
if __name__ == "__main__":
train_and_evaluate(mnist_parser(None))
# Overlapping Communication with GEMM in TransformerEngine Modules
## Requirements
- Tensor-parallel GPUs must be on a single node and connected via NVLink/NVSwitch.
- `CUDA_DEVICE_MAX_CONNECTIONS=1` must be set in the environment.
- For best performance, point-to-point communication via _CUDA Multicast_ requires CUDA Toolkit 12.0+
and CUDA driver 535+ on devices with compute capability 9.0 or newer.
- Devices older than compute capability 9.0 require `UB_SKIPMC=1` in the environment in order to
fall back to a less performant implementation based on CUDA Inter-Process Communication (IPC) handles.
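A minimal sketch of the setup, condensed from `te_layer_with_overlap.py` below (launch with `torchrun`; the sizes are the script's defaults, and a single tensor-parallel group spanning all ranks is assumed):
```python
import os
import torch
import torch.distributed as dist
import transformer_engine.pytorch as te

os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"  # required for comm+GEMM overlap
dist.init_process_group(backend="nccl")
torch.cuda.set_device(int(os.getenv("LOCAL_RANK", "0")))
tp_group = dist.new_group(backend="nccl")  # all ranks form one tensor-parallel group
tp_size = dist.get_world_size(tp_group)

seq_length, batch_size, num_heads, head_dim = 2048, 2, 64, 128  # script defaults
hidden_size = num_heads * head_dim

# One-time Userbuffers bootstrap, sized for the [seq * batch, hidden] activation buffer.
te.module.base.initialize_ub(
    [seq_length * batch_size, hidden_size],
    tp_size,
    use_fp8=False,
    dtype=torch.bfloat16,
    bootstrap_backend="nccl",
)
# For TransformerLayer, a single flag enables comm+GEMM overlap.
model = te.TransformerLayer(
    hidden_size,
    4 * hidden_size,
    num_heads,
    tp_group=tp_group,
    tp_size=tp_size,
    sequence_parallel=True,
    seq_length=seq_length,
    ub_tp_comm_overlap=True,
)
```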
## Examples
### Single node, tensor-parallel LayerNormMLP:
Forward and backward passes with layer weights distributed over all GPUs in a single node.
```bash
$ torchrun --nnodes=1 --nproc-per-node=$(nvidia-smi -L | wc -l) te_layer_with_overlap.py
# Sample output on 8x H100s:
# [rank0:node0] |-- Created tensor-parallel group: [0, 1, 2, 3, 4, 5, 6, 7]
# !!! [UB] Create UbufP2PCommOverlap Communicator
# UB_TIMEOUT is set to 110 sec, 217800000000 cycles, freq: 1980000khz
# MC initialized succesfully, window size = 549755813888
# !!! [UBP2P] Register UBuf 1
# !!! [UBP2P] Register UBuf 2
# !!! [UBP2P] Register UBuf 3
# !!! [UBP2P] Register UBuf 4
# !!! [UB] Register UBuf 5
# !!! [UBP2P] Register UBuf 6
# !!! [UB] Register UBuf 7
# !!! [UB] Register UBuf 8
# !!! [UBP2P] Register UBuf 9
# !!! [UB] Register UBuf 10
# [rank0:node0] Iter 1
# [rank0:node0] |-- Generate random input batch
# [rank0:node0] |-- Forward pass
# [rank0:node0] |-- Compute loss
# [rank0:node0] |-- Backward pass
# [rank0:node0] |-- Optimizer step
# [rank0:node0] Iter 2
# [rank0:node0] |-- Generate random input batch
# [rank0:node0] |-- Forward pass
# [rank0:node0] |-- Compute loss
# [rank0:node0] |-- Backward pass
# [rank0:node0] |-- Optimizer step
# [rank0:node0] Iter 3
# [rank0:node0] |-- Generate random input batch
# [rank0:node0] |-- Forward pass
# [rank0:node0] |-- Compute loss
# [rank0:node0] |-- Backward pass
# [rank0:node0] |-- Optimizer step
# [rank0:node0] Iter 4
# [rank0:node0] |-- Generate random input batch
# [rank0:node0] |-- Forward pass
# [rank0:node0] |-- Compute loss
# [rank0:node0] |-- Backward pass
# [rank0:node0] |-- Optimizer step
# [rank0:node0] Iter 5
# [rank0:node0] |-- Generate random input batch
# [rank0:node0] |-- Forward pass
# [rank0:node0] |-- Compute loss
# [rank0:node0] |-- Backward pass
# [rank0:node0] |-- Optimizer step
```
### Single node, mixed data- and tensor-parallel LayerNormMLP:
Uses `torch.nn.parallel.DistributedDataParallel` to replicate the model across 2 tensor-parallel
groups in a single node.
```bash
$ torchrun --nnodes=1 --nproc-per-node=$(nvidia-smi -L | wc -l) te_layer_with_overlap.py --num-replicas 2
# Sample output on 8x H100s:
# [rank0:node0] |-- Created tensor-parallel group: [0, 1, 2, 3]
# [rank4:node1] |-- Created tensor-parallel group: [4, 5, 6, 7]
# [rank0:node0] |-- Created data-parallel group: [0, 4]
# [rank3:node1] |-- Created data-parallel group: [3, 7]
# [rank1:node1] |-- Created data-parallel group: [1, 5]
# [rank2:node0] |-- Created data-parallel group: [2, 6]
# !!! [UB] Create UbufP2PCommOverlap Communicator
# UB_TIMEOUT is set to 110 sec, 217800000000 cycles, freq: 1980000khz
# MC initialized succesfully, window size = 549755813888
# !!! [UBP2P] Register UBuf 1
# !!! [UBP2P] Register UBuf 2
# !!! [UBP2P] Register UBuf 3
# !!! [UBP2P] Register UBuf 4
# !!! [UB] Register UBuf 5
# !!! [UBP2P] Register UBuf 6
# !!! [UB] Register UBuf 7
# !!! [UB] Register UBuf 8
# !!! [UBP2P] Register UBuf 9
# !!! [UB] Register UBuf 10
# [rank4:node1] Iter 1
# [rank0:node0] Iter 1
# [rank0:node0] |-- Generate random input batch
# [rank4:node1] |-- Generate random input batch
# [rank0:node0] |-- Forward pass
# [rank4:node1] |-- Forward pass
# [rank4:node1] |-- Compute loss
# [rank0:node0] |-- Compute loss
# [rank0:node0] |-- Backward pass
# [rank4:node1] |-- Backward pass
# [rank4:node1] |-- Optimizer step
# [rank0:node0] |-- Optimizer step
# [rank4:node1] Iter 2
# [rank0:node0] Iter 2
# [rank0:node0] |-- Generate random input batch
# [rank4:node1] |-- Generate random input batch
# [rank4:node1] |-- Forward pass
# [rank0:node0] |-- Forward pass
# [rank4:node1] |-- Compute loss
# [rank0:node0] |-- Compute loss
# [rank4:node1] |-- Backward pass
# [rank0:node0] |-- Backward pass
# [rank4:node1] |-- Optimizer step
# [rank0:node0] |-- Optimizer step
# [rank4:node1] Iter 3
# [rank0:node0] Iter 3
# [rank0:node0] |-- Generate random input batch
# [rank4:node1] |-- Generate random input batch
# [rank0:node0] |-- Forward pass
# [rank4:node1] |-- Forward pass
# [rank4:node1] |-- Compute loss
# [rank0:node0] |-- Compute loss
# [rank4:node1] |-- Backward pass
# [rank0:node0] |-- Backward pass
# [rank0:node0] |-- Optimizer step
# [rank4:node1] |-- Optimizer step
# [rank0:node0] Iter 4
# [rank4:node1] Iter 4
# [rank0:node0] |-- Generate random input batch
# [rank4:node1] |-- Generate random input batch
# [rank0:node0] |-- Forward pass
# [rank4:node1] |-- Forward pass
# [rank0:node0] |-- Compute loss
# [rank4:node1] |-- Compute loss
# [rank4:node1] |-- Backward pass
# [rank0:node0] |-- Backward pass
# [rank4:node1] |-- Optimizer step
# [rank0:node0] |-- Optimizer step
# [rank4:node1] Iter 5
# [rank0:node0] Iter 5
# [rank0:node0] |-- Generate random input batch
# [rank4:node1] |-- Generate random input batch
# [rank0:node0] |-- Forward pass
# [rank4:node1] |-- Forward pass
# [rank0:node0] |-- Compute loss
# [rank4:node1] |-- Compute loss
# [rank0:node0] |-- Backward pass
# [rank4:node1] |-- Backward pass
# [rank4:node1] |-- Optimizer step
# [rank0:node0] |-- Optimizer step
```
**NOTE:** To run with FP8 compute on supported hardware, add the `--fp8` flag to the commands
shown above.
#!/usr/bin/python3
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import os
import sys
import socket
import fcntl
import struct
import argparse
import warnings
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
import transformer_engine.pytorch as te
import transformer_engine.pytorch.cpp_extensions as tex
from transformer_engine.common.recipe import Format, DelayedScaling
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=UserWarning)
os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"
if not tex.device_supports_multicast():
os.environ["UB_SKIPMC"] = "1"
def _te_layer_argtype(name):
te_layers = [
te.Linear,
te.LayerNormLinear,
te.LayerNormMLP,
te.MultiheadAttention,
te.TransformerLayer,
]
layer_map = dict(zip([layer.__name__.lower() for layer in te_layers], te_layers))
if name.lower() not in layer_map.keys():
raise argparse.ArgumentTypeError(
f"Invalid TE layer name! Please choose from: {layer_map.keys()}"
)
return layer_map[name.lower()]
def _parse_args(argv=None, namespace=None):
parser = argparse.ArgumentParser(
description="Train a Transformer Engine module with GEMM+comm overlap via Userbuffers."
)
parser.add_argument(
"-i", "--num-iters", type=int, default=5, help="Number of dummy 'training' iterations."
)
parser.add_argument("-b", "--batch-size", type=int, default=2, help="Input batch size.")
parser.add_argument("-s", "--seq-length", type=int, default=2048, help="Input sequence length.")
parser.add_argument(
"-n", "--num-heads", type=int, default=64, help="Number of attention heads."
)
parser.add_argument(
"-d", "--head-dim", type=int, default=128, help="Dimension of each attention head."
)
parser.add_argument(
"--layer-type",
type=_te_layer_argtype,
default=te.TransformerLayer,
help="Transformer Engine layer to train with comm+GEMM overlap.",
)
parser.add_argument("--seed", type=int, default=1234, help="RNG seed.")
parser.add_argument(
"--fp8", action="store_true", default=False, help="Enables the te.fp8_autocast() context."
)
parser.add_argument(
"--no-comm-overlap",
action="store_true",
default=False,
help="Disable the comm+GEMM overlap.",
)
parser.add_argument(
"--num-replicas",
type=int,
default=1,
help="Number of data-parallel model replicas per node.",
)
parser.add_argument(
"--use-global-replica-count",
action="store_true",
default=False,
help="Treat '--num-replicas' as the total number of replicas.",
)
parser.add_argument(
"--tcp-init",
action="store_true",
default=False,
help="Initialize torch.distributed with TcpStore.",
)
parser.add_argument(
"--bind-to-device",
action="store_true",
default=False,
help="Initialize torch.distributed with `device_id` to bind each rank to a single device.",
)
parser.add_argument(
"--bootstrap-backend",
type=str.lower,
default="nccl",
choices=["gloo", "mpi", "nccl"],
help="Communications backend for host tensor collectives during Userbuffers bootstrapping.",
)
parser.add_argument(
"-v",
"--verbose",
action="store_true",
default=False,
help="Print out from every rank instead of just the root rank of relevant process groups.",
)
parser.add_argument(
"--debug",
action="store_true",
default=False,
help="Print out additional debug information.",
)
args = parser.parse_args(argv, namespace)
if args.bootstrap_backend == "nccl":
args.bind_to_device = True
return args
def _get_layer_args(config, tp_group, tp_size, reference=False):
hidden_size = config.num_heads * config.head_dim
input_shape = [config.seq_length, config.batch_size, hidden_size]
args = [hidden_size]
kwargs = {
"params_dtype": torch.float32,
"device": "cuda",
"tp_group": tp_group,
"tp_size": tp_size,
"sequence_parallel": True,
}
kwargs["ub_overlap_ag"] = not config.no_comm_overlap
if config.layer_type is te.Linear:
input_shape[2] = hidden_size // tp_size
args.append(hidden_size)
kwargs["parallel_mode"] = "row"
kwargs["ub_overlap_rs"] = not config.no_comm_overlap
kwargs["ub_name"] = "proj"
else:
input_shape[0] = config.seq_length // tp_size
kwargs["ub_bulk_wgrad"] = not config.no_comm_overlap
kwargs["ub_bulk_dgrad"] = not config.no_comm_overlap
if config.layer_type is te.LayerNormLinear:
args.append(3 * hidden_size)
kwargs["parallel_mode"] = "column"
kwargs["ub_name"] = "qkv"
else:
kwargs["set_parallel_mode"] = True
kwargs["ub_overlap_rs"] = not config.no_comm_overlap
if config.layer_type in [te.LayerNormMLP, te.TransformerLayer]:
args.append(4 * hidden_size)
kwargs["seq_length"] = config.seq_length
if config.layer_type in [te.MultiheadAttention, te.TransformerLayer]:
args.append(config.num_heads)
kwargs["attention_dropout"] = 0.0
kwargs["fuse_qkv_params"] = True
if config.layer_type is te.MultiheadAttention:
kwargs["input_layernorm"] = True
else:
kwargs["ub_tp_comm_overlap"] = not config.no_comm_overlap
kwargs["hidden_dropout"] = 0.0
return args, kwargs, input_shape
def _train(opts):
if "OMPI_COMM_WORLD_SIZE" in os.environ:
# Execution with `mpirun -np N`
WORLD_RANK = int(os.getenv("OMPI_COMM_WORLD_RANK", "0"))
WORLD_SIZE = int(os.getenv("OMPI_COMM_WORLD_SIZE", "1"))
LOCAL_RANK = int(os.getenv("OMPI_COMM_WORLD_LOCAL_RANK", "0"))
LOCAL_SIZE = int(os.getenv("OMPI_COMM_WORLD_LOCAL_SIZE", "1"))
opts.tcp_init = True
opts.bind_to_device = True
opts.bootstrap_backend = "mpi"
else: # TORCHELASTIC, SLURM, etc...
WORLD_RANK = int(os.getenv("RANK", "0"))
WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1"))
LOCAL_RANK = int(os.getenv("LOCAL_RANK", "0"))
LOCAL_SIZE = int(os.getenv("LOCAL_WORLD_SIZE", str(torch.cuda.device_count())))
NUM_NODES = WORLD_SIZE // LOCAL_SIZE
# Initialize torch.distributed global process group and get DP/TP groups
torch.cuda.set_device(LOCAL_RANK)
dist_init_kwargs = {
"backend": "nccl",
"rank": WORLD_RANK,
"world_size": WORLD_SIZE,
}
if opts.tcp_init or NUM_NODES > 1:
if NUM_NODES > 1:
assert (
"MASTER_ADDR" in os.environ
), "Multi-node run requires MASTER_ADDR to be set in the environment."
MASTER_ADDR = os.getenv("MASTER_ADDR", socket.gethostbyname(socket.gethostname()))
MASTER_PORT = os.getenv("MASTER_PORT", "1234")
dist_init_kwargs["init_method"] = f"tcp://{MASTER_ADDR}:{MASTER_PORT}"
if opts.bind_to_device or opts.bootstrap_backend == "nccl":
dist_init_kwargs["device_id"] = torch.device(f"cuda:{LOCAL_RANK}")
assert dist.is_nccl_available()
dist.init_process_group(**dist_init_kwargs)
nccl_world = dist.new_group(backend="nccl")
def dist_print(msg, end="\n", group=nccl_world, src=0, debug=False, error=False):
if debug and not opts.debug:
return
group_rank = dist.get_rank(group)
stream = sys.stderr if error else sys.stdout
if group_rank == src:
stream.write(f"[rank{WORLD_RANK}] {msg}{end}")
dist.barrier(group)
dist_print(f"Initialized default NCCL process group with {WORLD_SIZE} GPUs")
total_replicas = (
opts.num_replicas if opts.use_global_replica_count else opts.num_replicas * NUM_NODES
)
tp_size = WORLD_SIZE // total_replicas
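# Example: 8 GPUs with 2 replicas gives tp_size = 4 -> TP groups [0,1,2,3] and [4,5,6,7],
# while DP groups pair matching TP ranks across replicas: [0,4], [1,5], [2,6], [3,7].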
if total_replicas > 1:
ranks_per_replica_list = [
[i * tp_size + t for t in range(tp_size)] for i in range(total_replicas)
]
tp_group, _ = dist.new_subgroups_by_enumeration(ranks_per_replica_list, backend="nccl")
ranks_per_replica_tensor = torch.tensor(ranks_per_replica_list, dtype=torch.int32)
dp_group, _ = dist.new_subgroups_by_enumeration(
ranks_per_replica_tensor.transpose(0, 1).tolist(), backend="nccl"
)
else:
dp_group = None
tp_group = nccl_world
tp_rank = dist.get_rank(tp_group)
tp_size = dist.get_world_size(tp_group)
dist_print(
f"Created tensor-parallel group: {dist.get_process_group_ranks(tp_group)}",
group=tp_group,
)
if dp_group is not None:
dp_rank = dist.get_rank(dp_group)
dist_print(
f"Created data-parallel group: {dist.get_process_group_ranks(dp_group)}",
group=dp_group,
)
else:
dp_rank = 0
# Initialize Userbuffers
hidden_size = opts.num_heads * opts.head_dim
batched_size = opts.seq_length * opts.batch_size
if not opts.no_comm_overlap:
te.module.base.initialize_ub(
[batched_size, hidden_size],
tp_size,
use_fp8=opts.fp8,
dtype=torch.bfloat16,
bootstrap_backend=opts.bootstrap_backend,
)
# Initialize the fused LayerNorm + Multi-layer Perceptron module
torch.manual_seed(opts.seed + dp_rank)
torch.cuda.manual_seed(opts.seed + tp_rank)
layer_args, layer_kwargs, input_size = _get_layer_args(opts, tp_group, tp_size)
model = opts.layer_type(*layer_args, **layer_kwargs)
if dp_group is not None:
model = DistributedDataParallel(model, dim=1, process_group=dp_group)
# Initialize optimizer with model parameters
optim = torch.optim.Adam(model.parameters(), lr=0.0001)
# Fp8 recipe setup
fp8_format = Format.HYBRID
fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=32, amax_compute_algo="max")
# Start dummy "training" iterations
dist_print("Starting training iterations...")
for i in range(opts.num_iters):
dist_print(f" Iter {i+1}", group=tp_group, debug=True)
dist_print(" |-- Generate random input batch", group=tp_group, debug=True)
x = torch.randn(input_size, dtype=torch.float32, device="cuda", requires_grad=True)
dist_print(" |-- Forward pass", group=tp_group, debug=True)
with torch.amp.autocast("cuda", dtype=torch.bfloat16):
with te.fp8_autocast(enabled=opts.fp8, fp8_recipe=fp8_recipe, fp8_group=nccl_world):
y = model(x)
if isinstance(y, tuple):
out, *_ = y
else:
out = y
dist_print(" |-- Compute loss", group=tp_group, debug=True)
loss = out.sum()
dist_print(" |-- Backward pass", group=tp_group, debug=True)
loss.backward()
dist_print(" |-- Optimizer step", group=tp_group, debug=True)
optim.step()
torch.cuda.synchronize()
dist_print("Finished training!")
te.module.base.destroy_ub()
dist_print("Destroying all process groups...", debug=True)
dist.destroy_process_group()
if opts.debug and WORLD_RANK == 0:
print("Exiting...\n", end="", flush=True)
return 0
if __name__ == "__main__":
sys.exit(_train(_parse_args()))
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
# Basic Example for Using PyTorch Fully Sharded Data Parallel mode with Transformer Engine
```bash
# FSDP without deferred initialization:
# Full modules are initialized on every device; device memory load is reduced only after
# torch.distributed.fsdp.FullyShardedDataParallel shards the model parameters.
$ torchrun --standalone --nnodes=1 --nproc-per-node=$(nvidia-smi -L | wc -l) fsdp.py
# Sample output on 8xL40S:
# [GPU-0] WORLD_SIZE = 8
# [GPU-0] TransformerEngine Model:
# TransformerLayer(
# (self_attention): MultiheadAttention(
# (layernorm_qkv): LayerNormLinear()
# (core_attention): DotProductAttention(
# (flash_attention): FlashAttention()
# (fused_attention): FusedAttention()
# (unfused_attention): UnfusedDotProductAttention(
# (scale_mask_softmax): FusedScaleMaskSoftmax()
# (attention_dropout): Dropout(p=0.1, inplace=False)
# )
# )
# (proj): Linear()
# )
# (layernorm_mlp): LayerNormMLP()
# )
# [GPU-0] Pre-FSDP memory use = 83.935232MiB
# [GPU-0] Post-FSDP memory use = 10.491904MiB
# [GPU-0] Iter. 1
# [GPU-0] Iter. 2
# [GPU-0] Iter. 3
# [GPU-0] Training Time: 6.647654296875s
# [GPU-0] Avg. Iter. Time: 2.2158847656250003s
# [GPU-0] Peak memory use = 3000MiB
# FSDP with deferred initialization:
# Modules are initialized with empty parameters via the `device='meta'` option, leaving zero
# load on device memory until torch.distributed.fsdp.FullyShardedDataParallel triggers a reset
# of the already-sharded model parameters.
$ torchrun --standalone --nnodes=1 --nproc-per-node=$(nvidia-smi -L | wc -l) fsdp.py --defer-init
# Sample output on 8xL40S:
# [GPU-0] WORLD_SIZE = 8
# ...
# [GPU-0] Pre-FSDP memory use = 0.0MiB
# [GPU-0] Post-FSDP memory use = 10.491904MiB
# ...
```
**NOTE:** This example has `fp8_autocast()` enabled by default. To run on GPUs without FP8 support
(e.g. A100), add the `--no-fp8` option to the commands shown above.
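The core pattern, condensed from the script below into a sketch (the sizes are placeholders; the full script adds wrap policies, activation checkpointing, and profiling):
```python
import torch
import torch.distributed as dist
import transformer_engine.pytorch as te
from torch.distributed.fsdp import FullyShardedDataParallel, MixedPrecision
from transformer_engine.pytorch.distributed import prepare_te_modules_for_fsdp

dist.init_process_group(backend="nccl")
hidden_size, num_heads = 2048, 16  # placeholder sizes

# Deferred init: parameters stay empty on the 'meta' device until FSDP shards and resets them.
model = te.TransformerLayer(hidden_size, 3 * hidden_size, num_heads, device="meta")
model = FullyShardedDataParallel(
    model,
    use_orig_params=True,
    mixed_precision=MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.float32),
)
# Shard TE-internal buffers that FSDP cannot discover on its own.
prepare_te_modules_for_fsdp(model)
```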
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import os
import argparse
from functools import partial
import torch
import torch.distributed as dist
from torch import nn
from torch.distributed.fsdp import FullyShardedDataParallel, MixedPrecision
from torch.distributed.fsdp.wrap import always_wrap_policy, transformer_auto_wrap_policy
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
apply_activation_checkpointing,
checkpoint_wrapper,
)
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Format, DelayedScaling
from transformer_engine.pytorch.distributed import prepare_te_modules_for_fsdp
LOCAL_RANK = int(os.getenv("LOCAL_RANK", "0"))
WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1"))
# RNG state tracker for checkpointing
rng_seed = 1234
torch.manual_seed(rng_seed)
torch.cuda.manual_seed(rng_seed)
CUDA_RNG_STATES_TRACKER = te.distributed.CudaRNGStatesTracker()
CUDA_RNG_STATES_TRACKER.add("model-parallel-rng", rng_seed)
def get_cuda_rng_tracker():
return CUDA_RNG_STATES_TRACKER
def apply_fsdp_checkpointing(model, blocks):
"""apply activation checkpointing to model
returns None as model is updated directly
"""
wrapper = lambda m: checkpoint_wrapper(
m,
checkpoint_fn=te.distributed.checkpoint,
use_reentrant=False,
get_rng_state_tracker=get_cuda_rng_tracker,
)
check_fn = lambda submodule: isinstance(submodule, blocks)
apply_activation_checkpointing(model, checkpoint_wrapper_fn=wrapper, check_fn=check_fn)
def lowercase(s):
return str(s).lower()
def torch_dtype(d):
typemap = {
"fp32": torch.float32,
"float32": torch.float32,
"fp16": torch.float16,
"float16": torch.float16,
"bf16": torch.bfloat16,
"bfloat16": torch.bfloat16,
}
if lowercase(d) not in typemap.keys():
raise TypeError
return typemap[lowercase(d)]
te_layer_map = {
"linear": te.Linear,
"layernorm": te.LayerNorm,
"rmsnorm": te.RMSNorm,
"layernormlinear": te.LayerNormLinear,
"layernormmlp": te.LayerNormMLP,
"multiheadattention": te.MultiheadAttention,
"transformerlayer": te.TransformerLayer,
}
def te_layer(l):
if l is not None:
if lowercase(l) not in te_layer_map.keys():
raise TypeError
return te_layer_map[lowercase(l)]
return None
def get_layer_args(opts):
hidden_size = opts.num_heads * opts.head_dim
layer_args = (hidden_size,)
layer_kwargs = {
"params_dtype": opts.dtype,
"device": "cuda" if opts.no_defer_init else "meta",
"get_rng_state_tracker": get_cuda_rng_tracker,
}
if opts.layer_type in [te.Linear, te.LayerNormLinear, te.LayerNormMLP]:
ffn_hidden_size = 3 * hidden_size if opts.num_layers == 1 else hidden_size
layer_args += (ffn_hidden_size,)
layer_kwargs["bias"] = True
if opts.layer_type == te.LayerNormMLP:
layer_kwargs["seq_length"] = opts.seq_length
elif opts.layer_type == te.MultiheadAttention:
layer_args += (opts.num_heads,)
layer_kwargs["fuse_qkv_params"] = True
layer_kwargs["input_layernorm"] = True
elif opts.layer_type == te.TransformerLayer:
layer_args += (3 * hidden_size, opts.num_heads)
layer_kwargs["fuse_qkv_params"] = True
layer_kwargs["seq_length"] = opts.seq_length
return layer_args, layer_kwargs
def parse_fsdp_args():
parser = argparse.ArgumentParser(
description="Run Transformer Engine modules with the "
+ "torch.distributed.fsdp.FullyShardedDataParallel strategy."
)
parser.add_argument(
"-v",
"--verbose",
action="store_true",
default=False,
help="Print out information from all GPUs instead of only the root GPU-0.",
)
parser.add_argument("-b", "--batch-size", type=int, default=32, help="Input batch size.")
parser.add_argument("-s", "--seq-length", type=int, default=1048, help="Input sequence length.")
parser.add_argument(
"-n", "--num-heads", type=int, default=16, help="Number of attention heads."
)
parser.add_argument(
"-d",
"--head-dim",
type=int,
default=128,
help="Dimension of each attention head (number of KV channels).",
)
parser.add_argument(
"-i", "--num-iters", type=int, default=5, help="Number of dummy 'training' iterations."
)
parser.add_argument(
"-k",
"--num-layers",
type=int,
default=3,
help="Number of modules chained together with nn.Sequential.",
)
parser.add_argument(
"--layer-type",
type=te_layer,
default=te.TransformerLayer,
choices=list(te_layer_map.values()),
help="TE module type used to construct the test model.",
)
parser.add_argument("--seed", type=int, default=1234, help="PyTorch RNG seed.")
parser.add_argument(
"--profile-memory",
action="store_true",
help="Enable memory profiling via torch.profiler.profile().",
)
parser.add_argument(
"--profile-name", type=str, default=None, help="File path for memory profiling."
)
parser.add_argument(
"--checkpoint-layer",
type=te_layer,
default=None,
help="Recompute activations of the selected layer during the backward "
+ "pass instead of saving.",
)
parser.add_argument(
"--no-fp8",
action="store_true",
default=False,
help="Disables the te.fp8_autocast() context.",
)
parser.add_argument(
"--defer-init",
action="store_true",
help="Defer module parameter initialization until after FSDP sharding.",
)
parser.add_argument(
"--no-te-fsdp",
action="store_true",
help="Disable sharding of intermediate/activation tensors in TE modules.",
)
parser.add_argument(
"--dtype",
type=torch_dtype,
default=torch.bfloat16,
help="Data type for input tensor and Transformer Engine module parameters.",
)
return parser.parse_args()
def dist_print(text, all_ranks=False, no_new_line=False):
if LOCAL_RANK == 0 or all_ranks:
end = "" if no_new_line else "\n"
print(f"[GPU-{LOCAL_RANK}] " + text, end=end)
def train(opts):
# Initialize torch.distributed global process group
dist.init_process_group(backend="nccl")
torch.cuda.set_device(LOCAL_RANK)
dist_print(f"WORLD_SIZE = {WORLD_SIZE}")
torch.manual_seed(opts.seed)
# Construct a simple homogeneous model (only one layer type) with NO PARALLELISM
layer_args, layer_kwargs = get_layer_args(opts)
if opts.num_layers > 1:
te_layer_list = []
for i in range(opts.num_layers):
if opts.layer_type in [te.MultiheadAttention, te.TransformerLayer]:
layer_kwargs["layer_number"] = i + 1
te_layer_list.append(opts.layer_type(*layer_args, **layer_kwargs))
te_model = nn.Sequential(*te_layer_list)
else:
# Single layer model
te_model = opts.layer_type(*layer_args, **layer_kwargs)
# Print out allocated device memory before the model parameters are sharded by FSDP
pre_mem_use = torch.cuda.memory_allocated(device=f"cuda:{LOCAL_RANK}") * 1e-6
dist_print(f"Pre-FSDP memory use = {pre_mem_use}MiB")
# Wrap the model with FSDP
# NOTE: The TE model itself has no inherent parallelism. FSDP shards model parameters and
# controls all communication.
all_gpus = dist.new_group(backend="nccl")
fsdp_wrap_policy = always_wrap_policy
if opts.layer_type == te.TransformerLayer:
# NOTE: FSDP causes illegal memory access without this special policy for Transformers
fsdp_wrap_policy = partial(
transformer_auto_wrap_policy, transformer_layer_cls={te.TransformerLayer}
)
te_model = FullyShardedDataParallel(
te_model,
process_group=all_gpus,
use_orig_params=True,
mixed_precision=MixedPrecision(
param_dtype=opts.dtype,
reduce_dtype=torch.float32,
),
auto_wrap_policy=fsdp_wrap_policy,
)
if opts.checkpoint_layer is not None:
# Recompute the activations of the selected layer during the backward pass instead of
# saving them during the forward pass
apply_fsdp_checkpointing(te_model, blocks=opts.checkpoint_layer)
elif not opts.no_te_fsdp:
# Prepare TE modules to shard internal buffers that FSDP cannot shard on its own
prepare_te_modules_for_fsdp(te_model)
# Print out allocated device memory after the model parameters are sharded
post_mem_use = torch.cuda.memory_allocated(device=f"cuda:{LOCAL_RANK}") * 1e-6
dist_print(f"Post-FSDP memory use = {post_mem_use}MiB")
dist_print(f"FSDP-Wrapped + Checkpointed TE Model:\n{te_model}")
# Fp8 setup for TE
fp8_format = Format.HYBRID
fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=32, amax_compute_algo="max")
# Optimizer must be created after the model is wrapped in FSDP and the parameters are sharded
optim = torch.optim.Adam(te_model.parameters(), lr=0.0001)
# Profile memory use
if opts.profile_memory:
torch.cuda.memory._record_memory_history(max_entries=100000)
else:
torch.cuda.reset_peak_memory_stats()
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
torch.cuda.synchronize()
start.record()
for i in range(opts.num_iters):
# Generate a random input batch
x = torch.rand(
opts.seq_length,
opts.batch_size,
opts.num_heads * opts.head_dim,
dtype=opts.dtype,
device="cuda",
)
# fp8_autocast needs to be given the FSDP process group for amax reductions
with te.fp8_autocast(enabled=not opts.no_fp8, fp8_recipe=fp8_recipe, fp8_group=all_gpus):
y = te_model(x)
loss = y.sum()
# calculate gradient and take training step outside the fp8_autocast context
loss.backward()
optim.step()
optim.zero_grad(set_to_none=True)
del x
if opts.profile_memory:
torch.cuda.memory._dump_snapshot(f"gpu{LOCAL_RANK}_{opts.profile_name}.pickle")
torch.cuda.memory._record_memory_history(enabled=None)
else:
end.record()
torch.cuda.synchronize()
peak_mem = torch.cuda.max_memory_allocated()
train_time = start.elapsed_time(end) / 1000.0
dist_print(f"Training Time: {train_time}s")
dist_print(f"Avg. Iter. Time: {train_time / opts.num_iters}s")
dist_print(f"Peak Memory Use: {peak_mem * 1e-6}MBs")
# Run with:
# torchrun --nnodes=1 --nproc-per-node=$(nvidia-smi -L | wc -l) fsdp.py --defer-init
if __name__ == "__main__":
args = parse_fsdp_args()
train(args)
# Basic MNIST Example with optional FP8
```bash
python main.py
python main.py --use-te # Linear layers from TransformerEngine
python main.py --use-fp8 # FP8 + TransformerEngine for Linear layers
```
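The core pattern, condensed from `main.py` below into a sketch (assumes an FP8-capable GPU; the dimensions are chosen as multiples of 16 to satisfy FP8 GEMM alignment requirements):
```python
import torch
import transformer_engine.pytorch as te

# te.Linear is a drop-in replacement for nn.Linear.
model = torch.nn.Sequential(te.Linear(784, 128), torch.nn.ReLU(), te.Linear(128, 16)).cuda()
x = torch.randn(64, 784, device="cuda")
with te.fp8_autocast(enabled=True):  # TE layers compute in FP8 inside this context
    y = model(x)
y.sum().backward()  # gradients flow through the FP8 GEMMs of the TE layers
```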
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR
from transformer_engine import pytorch as te
class Net(nn.Module):
def __init__(self, use_te=False):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 32, 3, 1)
self.conv2 = nn.Conv2d(32, 64, 3, 1)
self.dropout1 = nn.Dropout(0.25)
self.dropout2 = nn.Dropout(0.5)
if use_te:
self.fc1 = te.Linear(9216, 128)
self.fc2 = te.Linear(128, 16)
else:
self.fc1 = nn.Linear(9216, 128)
self.fc2 = nn.Linear(128, 16)
self.fc3 = nn.Linear(16, 10)
def forward(self, x):
"""FWD"""
x = self.conv1(x)
x = F.relu(x)
x = self.conv2(x)
x = F.relu(x)
x = F.max_pool2d(x, 2)
x = self.dropout1(x)
x = torch.flatten(x, 1)
x = self.fc1(x)
x = F.relu(x)
x = self.dropout2(x)
x = self.fc2(x)
x = self.fc3(x)
output = F.log_softmax(x, dim=1)
return output
def train(args, model, device, train_loader, optimizer, epoch, use_fp8):
"""Training function."""
model.train()
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
with te.fp8_autocast(enabled=use_fp8):
output = model(data)
loss = F.nll_loss(output, target)
loss.backward()
optimizer.step()
if batch_idx % args.log_interval == 0:
print(
f"Train Epoch: {epoch} "
f"[{batch_idx * len(data)}/{len(train_loader.dataset)} "
f"({100. * batch_idx / len(train_loader):.0f}%)]\t"
f"Loss: {loss.item():.6f}"
)
if args.dry_run:
break
def calibrate(model, device, test_loader, fp8):
"""Calibration function."""
model.eval()
with torch.no_grad():
for data, target in test_loader:
data, target = data.to(device), target.to(device)
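# calibrating=True runs the forward pass in higher precision while recording FP8
# scaling statistics, so the model can later run FP8 inference without recalibration.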
with te.fp8_autocast(enabled=fp8, calibrating=True):
output = model(data)
def test(model, device, test_loader, use_fp8):
"""Testing function."""
model.eval()
test_loss = 0
correct = 0
with torch.no_grad():
for data, target in test_loader:
data, target = data.to(device), target.to(device)
with te.fp8_autocast(enabled=use_fp8):
output = model(data)
test_loss += F.nll_loss(output, target, reduction="sum").item() # sum up batch loss
pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability
correct += pred.eq(target.view_as(pred)).sum().item()
test_loss /= len(test_loader.dataset)
print(
f"\nTest set: Average loss: {test_loss:.4f}, "
f"Accuracy: {correct}/{len(test_loader.dataset)} "
f"({100. * correct / len(test_loader.dataset):.0f}%)\n"
)
def main():
# Training settings
parser = argparse.ArgumentParser(description="PyTorch MNIST Example")
parser.add_argument(
"--batch-size",
type=int,
default=64,
metavar="N",
help="input batch size for training (default: 64)",
)
parser.add_argument(
"--test-batch-size",
type=int,
default=1000,
metavar="N",
help="input batch size for testing (default: 1000)",
)
parser.add_argument(
"--epochs",
type=int,
default=14,
metavar="N",
help="number of epochs to train (default: 14)",
)
parser.add_argument(
"--lr",
type=float,
default=1.0,
metavar="LR",
help="learning rate (default: 1.0)",
)
parser.add_argument(
"--gamma",
type=float,
default=0.7,
metavar="M",
help="Learning rate step gamma (default: 0.7)",
)
parser.add_argument(
"--dry-run",
action="store_true",
default=False,
help="quickly check a single pass",
)
parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)")
parser.add_argument(
"--log-interval",
type=int,
default=10,
metavar="N",
help="how many batches to wait before logging training status",
)
parser.add_argument(
"--save-model",
action="store_true",
default=False,
help="For Saving the current Model",
)
parser.add_argument(
"--use-fp8",
action="store_true",
default=False,
help="Use FP8 for inference and training without recalibration",
)
parser.add_argument(
"--use-fp8-infer", action="store_true", default=False, help="Use FP8 inference only"
)
parser.add_argument(
"--use-te", action="store_true", default=False, help="Use Transformer Engine"
)
args = parser.parse_args()
use_cuda = torch.cuda.is_available()
if args.use_fp8 or args.use_fp8_infer:
assert use_cuda, "CUDA needed for FP8 execution."
args.use_te = True
torch.manual_seed(args.seed)
device = torch.device("cuda" if use_cuda else "cpu")
train_kwargs = {"batch_size": args.batch_size}
test_kwargs = {"batch_size": args.test_batch_size}
if use_cuda:
cuda_kwargs = {"num_workers": 1, "pin_memory": True, "shuffle": True}
train_kwargs.update(cuda_kwargs)
test_kwargs.update(cuda_kwargs)
transform = transforms.Compose(
[transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
)
dataset1 = datasets.MNIST("../data", train=True, download=True, transform=transform)
dataset2 = datasets.MNIST("../data", train=False, transform=transform)
train_loader = torch.utils.data.DataLoader(dataset1, **train_kwargs)
test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)
model = Net(use_te=args.use_te).to(device)
optimizer = optim.Adadelta(model.parameters(), lr=args.lr)
scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
for epoch in range(1, args.epochs + 1):
train(args, model, device, train_loader, optimizer, epoch, args.use_fp8)
test(model, device, test_loader, args.use_fp8)
scheduler.step()
if args.use_fp8_infer and not args.use_fp8:
calibrate(model, device, test_loader, args.use_fp8)
if args.save_model or args.use_fp8_infer:
torch.save(model.state_dict(), "mnist_cnn.pt")
print("Eval with reloaded checkpoint : fp8=" + str(args.use_fp8_infer))
weights = torch.load("mnist_cnn.pt")
model.load_state_dict(weights)
test(model, device, test_loader, args.use_fp8_infer)
if __name__ == "__main__":
main()
[MASTER]
extension-pkg-whitelist=flash_attn_2_cuda,
torch,
transformer_engine_torch,
transformer_engine_jax
disable=too-many-locals,
too-few-public-methods,
too-many-public-methods,
too-many-positional-arguments,
invalid-name,
too-many-arguments,
abstract-method,
arguments-differ,
too-many-instance-attributes,
unsubscriptable-object,
import-outside-toplevel,
too-many-statements,
import-error,
too-many-lines,
use-maxsplit-arg,
protected-access,
pointless-string-statement,
cyclic-import,
duplicate-code,
no-member,
attribute-defined-outside-init,
global-statement,
too-many-branches,
global-variable-not-assigned,
redefined-argument-from-local,
line-too-long,
too-many-return-statements,
too-many-nested-blocks
[TYPECHECK]
ignored-modules=torch
ignored-classes=torch
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
set -e
# Find TE
: ${TE_PATH:=/opt/transformerengine}
TE_LIB_PATH=`pip3 show transformer-engine | grep Location | cut -d ' ' -f 2`
export LD_LIBRARY_PATH=$TE_LIB_PATH:$LD_LIBRARY_PATH
cd $TE_PATH/tests/cpp
cmake -GNinja -Bbuild .
cmake --build build
ctest --test-dir build -j4
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
set -xe
: ${TE_PATH:=/opt/transformerengine}
pip3 install -r $TE_PATH/examples/jax/encoder/requirements.txt
# Make encoder tests run-to-run deterministic for stable CI results
export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops"
python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_multigpu_encoder.py
python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_model_parallel_encoder.py
. $TE_PATH/examples/jax/encoder/run_test_multiprocessing_encoder.sh
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
set -e
: "${TE_PATH:=/opt/transformerengine}"
pip3 install cpplint==1.6.0 pylint==3.3.1
if [ -z "${PYTHON_ONLY}" ]
then
cd $TE_PATH
echo "Checking common API headers"
python3 -m cpplint --root transformer_engine/common/include --recursive transformer_engine/common/include
echo "Checking C++ files"
python3 -m cpplint --recursive --exclude=transformer_engine/common/include --exclude=transformer_engine/build_tools/build transformer_engine/common
python3 -m cpplint --recursive transformer_engine/jax
fi
if [ -z "${CPP_ONLY}" ]
then
cd $TE_PATH
echo "Checking Python files"
python3 -m pylint --recursive=y transformer_engine/common transformer_engine/jax
fi
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
set -xe
pip3 install "nltk>=3.8.2"
pip3 install pytest==8.2.1
: ${TE_PATH:=/opt/transformerengine}
python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax -k 'not distributed' --ignore=$TE_PATH/tests/jax/test_praxis_layers.py
# Test without custom calls
NVTE_CUSTOM_CALLS_RE="" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax/test_custom_call_compute.py
pip3 install -r $TE_PATH/examples/jax/mnist/requirements.txt
pip3 install -r $TE_PATH/examples/jax/encoder/requirements.txt
python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/mnist
# Make encoder tests run-to-run deterministic for stable CI results
export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops"
python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py