Commit 063ef88d authored by wenjh

Merge nv main up to v2.10.0.dev0


Signed-off-by: wenjh <wenjh@sugon.com>
parents 91670b05 5624dbb4
......@@ -13,7 +13,7 @@ def speedometer(
input: torch.Tensor,
output_grad: torch.Tensor,
forward_kwargs: dict = {},
fp8_autocast_kwargs: Optional[dict] = None,
autocast_kwargs: Optional[dict] = None,
timing_iters: int = 50,
warmup_iters: int = 50,
) -> None:
......@@ -23,20 +23,20 @@ def speedometer(
"""
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
if fp8_autocast_kwargs is None:
fp8_autocast_kwargs = {"enabled": False}
if autocast_kwargs is None:
autocast_kwargs = {"enabled": False}
# Warmup runs
torch.cuda.synchronize()
for _ in range(warmup_iters):
with te.fp8_autocast(**fp8_autocast_kwargs):
with te.autocast(**autocast_kwargs):
output = module(input, **forward_kwargs)
output.backward(output_grad)
# Timing runs
start.record()
for _ in range(timing_iters):
with te.fp8_autocast(**fp8_autocast_kwargs):
with te.autocast(**autocast_kwargs):
output = module(input, **forward_kwargs)
output.backward(output_grad)
end.record()
......
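The `speedometer` helper above uses the warmup-then-measure pattern with CUDA events. The same structure can be sketched CPU-side with `time.perf_counter`; this is an illustrative analogue (the name `cpu_speedometer` is hypothetical, not part of Transformer Engine):

```python
import time

def cpu_speedometer(fn, timing_iters=50, warmup_iters=50):
    """Time fn() with warmup iterations excluded, mirroring the CUDA-event version."""
    # Warmup runs are executed but not measured
    for _ in range(warmup_iters):
        fn()
    # Timing runs
    start = time.perf_counter()
    for _ in range(timing_iters):
        fn()
    elapsed_ms = (time.perf_counter() - start) * 1000
    return elapsed_ms / timing_iters  # Mean time per iteration in ms

per_iter = cpu_speedometer(lambda: sum(range(1000)), timing_iters=10, warmup_iters=2)
print(f"Mean time: {per_iter:.4f} ms")
```

The CUDA-event version is preferred on GPU because host-side timers would otherwise measure only kernel launch time, not execution.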
......@@ -14,7 +14,7 @@ from torch.amp import autocast
import transformer_engine as te
from transformer_engine.pytorch.attention import InferenceParams, RotaryPositionEmbedding
from transformer_engine.common.recipe import Format, DelayedScaling
from transformer_engine.pytorch.fp8 import get_default_fp8_recipe
from transformer_engine.pytorch.quantization import get_default_fp8_recipe
import transformers
from transformers.models.gemma.modeling_gemma import GemmaForCausalLM, GemmaConfig, GemmaModel
......@@ -461,8 +461,8 @@ class TEGemmaForCausalLM(GemmaForCausalLM):
# Both autocasts are needed: FP8 for operations that can run in lower
# precision and BF16 for those that cannot.
with autocast("cuda", dtype=torch.bfloat16, cache_enabled=False), te.pytorch.fp8_autocast(
enabled=self.config.fp8, fp8_recipe=self.fp8_recipe if self.config.fp8 else None
with autocast("cuda", dtype=torch.bfloat16, cache_enabled=False), te.pytorch.autocast(
enabled=self.config.fp8, recipe=self.fp8_recipe if self.config.fp8 else None
):
lengths = torch.sum(input_ids.ne(pad_token_id), dim=-1).squeeze()
# If padding is at the beginning, then shift it to the end
......@@ -694,8 +694,8 @@ class TEGemmaForCausalLMCudaGraphs(TEGemmaForCausalLM):
graphed_function = te.pytorch.make_graphed_callables(
function,
(input_tensor,),
fp8_enabled=self.config.fp8,
fp8_recipe=fp8_recipe,
enabled=self.config.fp8,
recipe=fp8_recipe,
allow_unused_input=True,
num_warmup_iters=5,
sample_kwargs=sample_kwargs,
......
......@@ -9,7 +9,7 @@ import torch
from typing import List
from transformer_engine.pytorch.fp8 import fp8_model_init
from transformer_engine.pytorch.quantization import quantized_model_init
from transformers.modeling_utils import load_state_dict
from transformers.utils.hub import get_checkpoint_shard_files
......@@ -88,10 +88,10 @@ def load_te_model(cls, config):
config.use_cache = False # To make TransformerLayer compatible with GemmaModel
# Loading model with FP8 only weights needs both the following context managers.
# 1. fp8_model_init(config.fp8_model_init) to tell TE to use FP8 only weights.
# 1. quantized_model_init(config.quantized_model_init) to tell TE to use FP8 only weights.
# 2. torch.no_grad() during TE modules' initialization so that they respect
# the `fp8_model_init` context manager.
with torch.no_grad(), fp8_model_init(config.fp8_model_init):
# the `quantized_model_init` context manager.
with torch.no_grad(), quantized_model_init(config.quantized_model_init):
# Just create a model with random weights.
vanilla_model = cls(config).cuda()
......
......@@ -77,7 +77,7 @@
"\n",
"This tutorial uses the `DelayedScaling` recipe for FP8 precision, which relies on the correct calculation of \"scaling factors\".\n",
"\n",
"If a model is trained in BF16/FP32, obtaining correct FP8 scaling factors becomes important when it is then run under the `fp8_autocast()` context manager. The value of these scaling factors defaults to their initial values, which do not capture the distribution of higher precision weights and input tensors and can cause numerical errors upon usage. Calibration involves capturing an appropriate distribution of higher precision weights and input tensor values and, in turn, calculating appropriate FP8 scaling factors from those. Once these factors are computed, the model becomes numerically stable.\n",
"If a model is trained in BF16/FP32, obtaining correct FP8 scaling factors becomes important when it is then run under the `autocast()` context manager. The value of these scaling factors defaults to their initial values, which do not capture the distribution of higher precision weights and input tensors and can cause numerical errors upon usage. Calibration involves capturing an appropriate distribution of higher precision weights and input tensor values and, in turn, calculating appropriate FP8 scaling factors from those. Once these factors are computed, the model becomes numerically stable.\n",
"\n",
"It is highly recommended to familiarize oneself with the [tutorial](../../examples/fp8_primer.ipynb) on FP8 precision to understand the importance of proper scaling factors.\n",
"\n",
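Concretely, `DelayedScaling` derives each scaling factor from the running amax (maximum absolute value) of a tensor. A minimal sketch of that arithmetic follows, using 448 as the largest finite FP8 E4M3 value; this is a simplification for illustration (the real recipe also supports a margin and different amax-history reductions):

```python
FP8_E4M3_MAX = 448.0  # Largest finite value representable in FP8 E4M3

def delayed_scale(amax_history):
    """Compute a scaling factor so the observed max maps inside the FP8 range."""
    amax = max(amax_history)    # Reduce the history to a single amax
    return FP8_E4M3_MAX / amax  # Values multiplied by this fit in E4M3

# amax values collected over recent forward passes
print(delayed_scale([1.5, 3.0, 2.2]))  # 448 / 3.0 ≈ 149.33
```

With default (initial) scaling factors this ratio is wrong, which is exactly why an uncalibrated BF16-trained model misbehaves under FP8.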
......@@ -94,12 +94,12 @@
"\n",
"The typical approach is to store weights in higher precision and then cast them to FP8 before operations. This may prevent accuracy drops in training. However, for inference, this level of precision is not necessary.\n",
"\n",
"The Transformer Engine includes a wrapper `fp8_model_init`, which allows for the creation of models that store only the FP8 copy of the weights. This eliminates the need to cast model weights from higher precision to FP8 every time, thus saving time in the forward pass during token generation. \n",
"The Transformer Engine includes a wrapper `quantized_model_init`, which allows for the creation of models that store only the FP8 copy of the weights. This eliminates the need to cast model weights from higher precision to FP8 every time, thus saving time in the forward pass during token generation. \n",
"\n",
"<figure align=\"center\">\n",
"<img src=\"./media/fp8_model_init.svg\" style=\"border: 1px solid #000; border-radius: 0;\" alt=\"\">\n",
"<figcaption>\n",
"Figure 3: Model under <b>fp8_autocast()</b> stores weights in high precision by default, and casts them if needed. If used without consideration, it could potentially not provide the expected speedup and also end up unnecessarily increasing overall GPU memory usage. Using <b>fp8_model_init()</b> results in storing model weights in FP8 by default, which can help with these potential issues.\n",
"Figure 3: Model under <b>autocast()</b> stores weights in high precision by default, and casts them if needed. If used without consideration, it could potentially not provide the expected speedup and also end up unnecessarily increasing overall GPU memory usage. Using <b>quantized_model_init()</b> results in storing model weights in FP8 by default, which can help with these potential issues.\n",
"</figcaption>\n",
"</figure>\n",
"\n",
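The memory effect described above is easy to estimate: a BF16 weight costs 2 bytes per element versus 1 byte in FP8. A back-of-the-envelope sketch (pure arithmetic, not TE code; the layer size is an assumed example):

```python
def weight_bytes(num_params, bytes_per_element):
    """Raw storage cost of a weight tensor."""
    return num_params * bytes_per_element

H, D = 4096, 4096              # One linear layer's weight matrix (assumed size)
bf16 = weight_bytes(H * D, 2)  # High-precision master copy
fp8 = weight_bytes(H * D, 1)   # FP8-only copy, as with quantized_model_init
print(f"BF16: {bf16 / 1024**2:.0f} MiB, FP8: {fp8 / 1024**2:.0f} MiB")  # BF16: 32 MiB, FP8: 16 MiB
```

Note this ignores per-tensor scaling metadata, which is negligible next to the weights themselves.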
......@@ -405,8 +405,8 @@
" graphed_function = te.pytorch.make_graphed_callables(\n",
" function,\n",
" (input_tensor,),\n",
" fp8_enabled=self.config.fp8,\n",
" fp8_recipe=fp8_recipe,\n",
" enabled=self.config.fp8,\n",
" recipe=fp8_recipe,\n",
" allow_unused_input=True,\n",
" num_warmup_iters=5,\n",
" sample_kwargs=sample_kwargs,\n",
......@@ -540,14 +540,14 @@
"source": [
"### Calibrating FP8 scaling factors for correctness\n",
"\n",
"Implementing token generation in FP8 precision with the Gemma model is not straightforward because this model was initially trained using BF16 precision, and the necessary FP8 scaling factors are missing when used with the `fp8_autocast` context manager. As Figure 5 shows, scaling factors are needed for two types of tensors for this tutorial:\n",
"Implementing token generation in FP8 precision with the Gemma model is not straightforward because this model was initially trained using BF16 precision, and the necessary FP8 scaling factors are missing when used with the `autocast` context manager. As Figure 5 shows, scaling factors are needed for two types of tensors for this tutorial:\n",
"\n",
"1. Model weight tensors\n",
"2. Input tensors\n",
"\n",
"If the model is run in FP8 precision with incorrect scaling factors, the resulting FP8-cast model weights and FP8-cast inputs (both converted from BF16 precision) will be significantly misaligned, potentially leading to large errors and inaccurate results.\n",
"\n",
"To address this issue, \"calibration\" is used. This involves running several forward iterations in BF16 precision within the context `te.fp8_autocast(enabled=False, calibration=True)`. This setup allows the forward pass to operate at higher precision, while simultaneously collecting `amax_history` and other parameters related to the FP8 precision, which are essential for calculating the \"scaling factors\" that are then used to cast higher precision tensors to FP8 precision more accurately. Calibration in the forward passes calculates the scaling factors for weight and input tensors.\n",
"To address this issue, \"calibration\" is used. This involves running several forward iterations in BF16 precision within the context `te.autocast(enabled=False, calibration=True)`. This setup allows the forward pass to operate at higher precision, while simultaneously collecting `amax_history` and other parameters related to the FP8 precision, which are essential for calculating the \"scaling factors\" that are then used to cast higher precision tensors to FP8 precision more accurately. Calibration in the forward passes calculates the scaling factors for weight and input tensors.\n",
"\n",
"*Note that other tensors might need calibration in specific use-cases, but for the generation process in this tutorial, calibrating only the input and weight tensors is needed, and so only the forward pass is considered.*\n",
" \n",
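The calibration loop described above can be pictured as plain amax bookkeeping: each BF16 forward pass records per-tensor maxima into a history, and scaling factors are derived from that history afterwards. A simplified stand-in (the `Calibrator` class and its method names are illustrative, not TE internals):

```python
FP8_E4M3_MAX = 448.0  # Largest finite FP8 E4M3 value

class Calibrator:
    """Collect per-iteration amax values, then derive an FP8 scaling factor."""
    def __init__(self):
        self.amax_history = []

    def observe(self, tensor_values):
        # One "forward pass": record the largest magnitude seen
        self.amax_history.append(max(abs(v) for v in tensor_values))

    def scale(self):
        return FP8_E4M3_MAX / max(self.amax_history)

cal = Calibrator()
for batch in ([0.1, -2.0], [1.5, 0.7], [-3.2, 0.4]):  # BF16 forward passes
    cal.observe(batch)
print(cal.scale())  # 448 / 3.2 ≈ 140.0
```

In TE, the same bookkeeping happens inside the modules when `calibrating=True`; no gradients are needed, which is why only forward passes are run.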
......@@ -590,14 +590,14 @@
"model = init_te_gemma_model(run_config)\n",
"\n",
"# Calibration\n",
"with te.fp8_autocast(enabled=False, calibrating=True), torch.autocast(\n",
"with te.autocast(enabled=False, calibrating=True), torch.autocast(\n",
" device_type=\"cuda\", dtype=torch.bfloat16\n",
"):\n",
" model.train()\n",
" run_forward_pass(model, run_config, num_iters=64)\n",
"\n",
"# Compute scale_fwd with enabled fp8 autocast\n",
"with te.fp8_autocast(enabled=True), torch.autocast(\n",
"with te.autocast(enabled=True), torch.autocast(\n",
" device_type=\"cuda\", dtype=torch.bfloat16\n",
"):\n",
" run_forward_pass(model, run_config, 1)\n",
......@@ -734,7 +734,7 @@
"2. Faster forward pass - since there is no cast kernel to cast higher precision weights to FP8 every time before a GEMM operation. (Unless the inputs are in FP8 precision already, there's still one cast kernel to cast inputs to FP8 precision.) \n",
"\n",
"\n",
"Transformer Engine supports maintaining FP8-only weights with the `fp8_model_init` context manager. Let's see a small example:"
"Transformer Engine supports maintaining FP8-only weights with the `quantized_model_init` context manager. Let's see a small example:"
]
},
{
......@@ -778,7 +778,7 @@
"del linear_bf16\n",
"\n",
"# Initialize model weights in FP8 precision\n",
"with torch.no_grad(), te.fp8_model_init(enabled=True):\n",
"with torch.no_grad(), te.quantized_model_init(enabled=True):\n",
" linear_fp8 = te.Linear(H, D)\n",
"print(f\"Actual GPU memory usage with a TE FP8 linear layer: {torch.cuda.memory_allocated()/1024**2:.2f} MB\")\n",
"del linear_fp8"
......@@ -793,11 +793,11 @@
"<figure align=\"center\">\n",
"<img src=\"./media/fp8_model_init_2_half.svg\" style=\"border: 1px solid #000; border-radius: 0;\">\n",
"<figcaption>\n",
" Figure 8: Using fp8_model_init stores the weights directly in FP8 format, which reduces both time and memory usage. Note that the inputs still need a cast kernel.\n",
" Figure 8: Using quantized_model_init stores the weights directly in FP8 format, which reduces both time and memory usage. Note that the inputs still need a cast kernel.\n",
"</figcaption>\n",
"</figure>\n",
"\n",
"Let's run the code with `fp8_model_init`:"
"Let's run the code with `quantized_model_init`:"
]
},
{
......@@ -862,7 +862,7 @@
"\n",
"# Enable FP8 math and FP8 model weights\n",
"run_config.fp8 = True\n",
"run_config.fp8_model_init = True # This will result in storing only fp8 weights.\n",
"run_config.quantized_model_init = True # This will result in storing only fp8 weights.\n",
"run_config.fp8_model_weights_filename = (\n",
" \"calibrated_weights.pth\" # <-- Add calibrated weights location here.\n",
")\n",
......@@ -885,7 +885,7 @@
"| HF (baseline) | 46.6 s | - | - | - |\n",
"| TE (substitution of GemmaDecoderLayer with te.TransformerLayer) | 12.25 s | 3.8x | 12.24 s | 3.8x |\n",
"| TE (te.TransformerLayer) + CUDA Graphs | 6.39 s | 7.2x | 6.47 s | 7.2x |\n",
"| TE (te.TransformerLayer) + CUDA Graphs + FP8 (with `fp8_model_init`) | 4.99 s | 9.3x | 5.05 s | 9.2x |"
"| TE (te.TransformerLayer) + CUDA Graphs + FP8 (with `quantized_model_init`) | 4.99 s | 9.3x | 5.05 s | 9.2x |"
]
},
{
......@@ -911,7 +911,7 @@
"It's worth noting that these features in TE are also readily applicable to other use-cases which haven't been extensively talked about in the tutorial: \n",
"\n",
"1. Longer context lengths (with paged KV cache) \n",
"2. Using less memory during generation (by storing weights in FP8 precision using `fp8_model_init`)\n",
"2. Using less memory during generation (by storing weights in FP8 precision using `quantized_model_init`)\n",
"\n",
"Readers are encouraged to explore these use cases by playing around with this tutorial, especially with larger models."
]
......
......@@ -34,7 +34,7 @@ class RunConfiguration:
# FP8 precision settings
self.fp8 = False
self.fp8_model_weights_filename = None
self.fp8_model_init = False
self.quantized_model_init = False
# Cuda graphs
self.generation_cuda_graphs = False
......
......@@ -15,8 +15,8 @@ Here, we take the `MultiheadAttention` module as an example. Its FP8 attention m
.. code-block:: python
>>> from transformer_engine.pytorch import MultiheadAttention, fp8_model_init
>>> with fp8_model_init(enabled=True):
>>> from transformer_engine.pytorch import MultiheadAttention, quantized_model_init
>>> with quantized_model_init(enabled=True):
... mha = MultiheadAttention(
... hidden_size=1024,
... num_attention_heads=16,
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Shared functions for the comm_overlap tests"""
import jax.numpy as jnp
import numpy as np
def dtype_tols(dtype, rtol=None, atol=None):
    """Expected numerical tolerance for a data type."""
    # Return immediately if tolerances are fully specified
    if rtol is not None and atol is not None:
        return {"rtol": rtol, "atol": atol}

    # Default tolerances for common dtypes
    if dtype in [jnp.float32, "float32"]:
        return {"rtol": 1e-5, "atol": 1e-8}
    elif dtype in [jnp.float16, "float16"]:
        return {"rtol": 1e-3, "atol": 1e-6}
    elif dtype in [jnp.bfloat16, "bfloat16"]:
        return {"rtol": 1e-2, "atol": 1e-5}
    else:
        return {"rtol": 1e-5, "atol": 1e-8}
def assert_allclose(
    actual,
    desired,
    rtol=None,
    atol=None,
    dtype=None,
    **kwargs,
):
    """Check if two tensors are close."""
    # Infer data type if needed
    if dtype is None:
        if isinstance(actual, float):
            dtype = "float32"
        else:
            dtype = actual.dtype

    # Determine tolerances
    tols = {}
    if rtol is None or atol is None:
        tols = dtype_tols(dtype)
    if rtol is not None:
        tols["rtol"] = rtol
    if atol is not None:
        tols["atol"] = atol

    # Cast tensors to fp32
    if not isinstance(actual, float):
        actual = actual.astype(jnp.float32)
    if not isinstance(desired, float):
        desired = desired.astype(jnp.float32)

    # Check if tensors are close
    np.testing.assert_allclose(actual, desired, **tols, **kwargs)
def assert_allclose_print_index(ref_output, gathered_output, rtol=1e-5, atol=1e-8):
    """Print the mismatch mask and absolute differences when two tensors are not close."""
    if not jnp.allclose(ref_output, gathered_output, rtol=rtol, atol=atol):
        diff = jnp.abs(ref_output - gathered_output)
        mask = diff > (atol + rtol * jnp.abs(gathered_output))
        print(mask.astype(int))
        print(jnp.where(mask, diff, 0))


# Shared constants for all tests
DP_AXIS = "data"
TPSP_AXIS = "tensor_sequence"
PARAMS_KEY = "params"
# Shared functions for distributed testing
import argparse

import jax
from jax.experimental import mesh_utils

from transformer_engine.jax.cpp_extensions.gemm import collective_gemm_bootstrap

# Global flag to track if distributed has been initialized
_distributed_initialized = False


def _is_distributed_initialized():
    """Check if JAX distributed has been initialized."""
    return _distributed_initialized
def _initialize_distributed(args):
    """Initialize JAX distributed with custom arguments."""
    global _distributed_initialized

    # Check if already initialized
    if _distributed_initialized:
        return

    if args.coordinator_address is None or args.num_processes is None or args.process_id is None:
        raise ValueError(
            "All distributed initialization arguments are required: "
            "--coordinator-address, --num-processes, --process-id"
        )

    if args.local_device_ids is None:
        assert (
            args.num_devices_per_process is not None
        ), "Either local_device_ids or num_devices_per_process must be provided"
        # Calculate device range for this process
        # Single process single device: each process gets one unique device
        # Single process multiple devices: each process gets a unique range of devices
        start_device = args.process_id * args.num_devices_per_process
        device_range = range(start_device, start_device + args.num_devices_per_process)
        global_device_ids_for_this_process = ",".join(map(str, device_range))
    else:
        # Use explicitly provided global device IDs
        global_device_ids_for_this_process = args.local_device_ids
        args.num_devices_per_process = len(args.local_device_ids.split(","))
        assert args.num_devices_per_process == 1, "Only single process single GPU is supported!"

    print(
        f"Initializing JAX distributed with coordinator={args.coordinator_address}, "
        f"num_processes={args.num_processes}, process_id={args.process_id}"
    )
    # Note: "local_device_ids" is a JAX term meaning "global CUDA devices managed by this process"
    jax.distributed.initialize(
        coordinator_address=args.coordinator_address,
        num_processes=args.num_processes,
        process_id=args.process_id,
        local_device_ids=global_device_ids_for_this_process,
    )
    _distributed_initialized = True

    jax.clear_caches()
    jax.config.update(
        "jax_use_shardy_partitioner", False
    )  # CollectiveGEMM does not work with Shardy yet

    assert jax.local_device_count() == 1, (
        f"[{args.process_id}|{args.num_devices_per_process}] Expected 1 GPU per process, found"
        f" {jax.local_device_count()}"
    )
    devices_per_process = 1
    num_total_devices = args.num_processes
    print(
        f"Initializing CGEMM communicator with num_total_devices={num_total_devices},"
        f" devices_per_process={devices_per_process}, process_id={args.process_id}"
    )
    collective_gemm_bootstrap(
        num_total_devices=num_total_devices,
        num_devices_per_process=devices_per_process,
        process_id=args.process_id,
        tensor_parallel_size=args.tensor_parallel_size,
    )
def _get_dp_and_tp_sizes(args):
    num_gpu = args.num_processes * args.num_devices_per_process
    if args.tensor_parallel_size is None:
        num_gpu_dp = 2 if args.enable_data_parallel else 1
        assert (
            num_gpu > 1 and num_gpu % num_gpu_dp == 0
        ), "Number of GPUs must be greater than 1 and divisible by number of data parallel GPUs"
        num_gpu_tp = num_gpu // num_gpu_dp
    else:
        num_gpu_tp = args.tensor_parallel_size
        assert (
            num_gpu > 1 and num_gpu % num_gpu_tp == 0
        ), "Number of GPUs must be greater than 1 and divisible by number of tensor parallel GPUs"
        num_gpu_dp = num_gpu // num_gpu_tp
    return num_gpu_dp, num_gpu_tp
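The helper above factors the world size into data-parallel and tensor-parallel group sizes. The arithmetic can be sketched standalone (the `split_dp_tp` name is hypothetical and not part of the test suite's API):

```python
def split_dp_tp(num_gpu, tensor_parallel_size=None, enable_data_parallel=False):
    """Factor a GPU count into (data-parallel, tensor-parallel) group sizes."""
    if tensor_parallel_size is None:
        # DP defaults to 2 when enabled, else 1; TP takes the remainder
        dp = 2 if enable_data_parallel else 1
        assert num_gpu > 1 and num_gpu % dp == 0
        tp = num_gpu // dp
    else:
        # Explicit TP size; DP takes the remainder
        tp = tensor_parallel_size
        assert num_gpu > 1 and num_gpu % tp == 0
        dp = num_gpu // tp
    return dp, tp

print(split_dp_tp(8, tensor_parallel_size=4))     # (2, 4)
print(split_dp_tp(8, enable_data_parallel=True))  # (2, 4)
```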
def _create_mesh(args):
    """Create mesh configuration with proper validation."""
    num_gpu = args.num_processes * args.num_devices_per_process
    assert num_gpu == len(jax.devices()), "Number of GPUs must be equal to number of devices"
    num_gpu_dp, num_gpu_tp = _get_dp_and_tp_sizes(args)
    print(f"Using {num_gpu_dp}x{num_gpu_tp} mesh ({num_gpu_dp * num_gpu_tp} total GPUs)")
    device_mesh = mesh_utils.create_device_mesh((num_gpu_dp, num_gpu_tp))
    mesh = jax.sharding.Mesh(devices=device_mesh, axis_names=(DP_AXIS, TPSP_AXIS))
    return mesh
def cgemm_parser(description="Collective GEMM test on multi-GPU with tensor parallelism"):
    """Create common argument parser for all collective GEMM tests."""
    parser = argparse.ArgumentParser(description=description)
    # Distributed initialization arguments
    parser.add_argument(
        "--coordinator-address",
        type=str,
        default=None,
        help="Coordinator address for distributed initialization",
    )
    parser.add_argument(
        "--num-processes",
        type=int,
        default=None,
        help="Number of processes for distributed initialization",
    )
    parser.add_argument(
        "--process-id", type=int, default=None, help="Process ID for distributed initialization"
    )
    parser.add_argument(
        "--local-device-ids",
        type=str,
        default=None,
        help="Local device IDs for distributed initialization (comma-separated)",
    )
    parser.add_argument(
        "--num-devices-per-process", type=int, default=1, help="Number of devices per process"
    )
    # Test configuration arguments
    parser.add_argument(
        "--tensor-parallel-size", type=int, default=None, help="Tensor parallel size"
    )
    parser.add_argument("--batch-size", type=int, default=4, help="Batch size for testing")
    parser.add_argument("--seq-len", type=int, default=8192, help="Sequence length for testing")
    parser.add_argument("--hidden-in", type=int, default=4096, help="Input hidden dimension")
    parser.add_argument("--hidden-out", type=int, default=8192, help="Output hidden dimension")
    parser.add_argument(
        "--collective-type",
        type=str,
        default="all_gather",
        choices=["all_gather", "reduce_scatter"],
        help="Type of collective operation",
    )
    parser.add_argument(
        "--fp8-recipe", type=str, default="DelayedScaling", help="FP8 recipe to use"
    )
    parser.add_argument(
        "--enable-data-parallel", action="store_true", help="Enable data parallelism"
    )
    parser.add_argument(
        "--enable-result-check", action="store_true", default=True, help="Enable result checking"
    )
    return parser
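For reference, the unittest classes later in this change call `cgemm_parser(...).parse_args([])` and then fill in the distributed fields programmatically from pytest options. A minimal standalone sketch of that pattern, with a stripped-down stand-in parser (only two of the flags, stdlib only):

```python
import argparse

# Stripped-down stand-in for cgemm_parser (only two flags shown).
parser = argparse.ArgumentParser(description="Collective GEMM test")
parser.add_argument("--num-processes", type=int, default=None)
parser.add_argument("--process-id", type=int, default=None)

# Parse an empty argv first, then fill fields in programmatically,
# the same way the unittest setUp() methods do.
args = parser.parse_args([])
args.num_processes = 2
args.process_id = 0
print(args.num_processes, args.process_id)  # 2 0
```

Parsing `[]` keeps argparse from consuming pytest's own command line, which is why the fixtures inject values afterwards.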
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""config for collective_gemm tests"""
import pytest
def pytest_addoption(parser):
    """Pytest hook for collective_gemm tests"""
    parser.addoption("--coordinator-address", action="store", default="localhost:12345")
    parser.addoption("--num-processes", action="store", default=1)
    parser.addoption("--process-id", action="store", default=0)
    parser.addoption("--local-device-ids", action="store", default=None)


@pytest.fixture(autouse=True)
def distributed_args(request):
    """Fixture for querying distributed initialization arguments"""
    if request.cls:
        request.cls.coordinator_address = request.config.getoption("--coordinator-address")
        request.cls.num_processes = int(request.config.getoption("--num-processes"))
        request.cls.process_id = int(request.config.getoption("--process-id"))
        request.cls.local_device_ids = request.config.getoption("--local-device-ids")
        request.cls.num_devices_per_process = (
            1
            if request.cls.local_device_ids is None
            else len(request.cls.local_device_ids.split(","))
        )
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
NUM_GPUS=${NUM_GPUS:-$(nvidia-smi -L | wc -l)}
: ${TE_PATH:=/opt/transformerengine}
: ${XML_LOG_DIR:=/logs}
mkdir -p "$XML_LOG_DIR"
# Check if NVLINK is supported before running tests
echo "*** Checking NVLINK support***"
NVLINK_OUTPUT=$(nvidia-smi nvlink --status 2>&1)
NVLINK_EXIT_CODE=$?
# Check if command failed OR output indicates no NVLINK
if [ $NVLINK_EXIT_CODE -ne 0 ] || [[ "$NVLINK_OUTPUT" == *"not supported"* ]] || [[ "$NVLINK_OUTPUT" == *"No devices"* ]] || [ -z "$NVLINK_OUTPUT" ]; then
    echo "NVLINK is not supported on this platform"
    echo "Collective GEMM tests require NVLINK connectivity"
    echo "SKIPPING all tests"
    exit 0
else
    echo "NVLINK support detected"
fi
# Define the test files to run
TEST_FILES=(
    "test_gemm.py"
    "test_dense_grad.py"
    "test_layernorm_mlp_grad.py"
)
echo
echo "*** Executing tests in examples/jax/collective_gemm/ ***"
HAS_FAILURE=0 # Global failure flag
PIDS=() # Array to store all process PIDs
# Cleanup function to kill all processes
cleanup() {
    for pid in "${PIDS[@]}"; do
        if kill -0 "$pid" 2>/dev/null; then
            echo "Killing process $pid"
            kill -TERM "$pid" 2>/dev/null || true
        fi
    done
    # Wait a bit and force kill if needed
    sleep 2
    for pid in "${PIDS[@]}"; do
        if kill -0 "$pid" 2>/dev/null; then
            echo "Force killing process $pid"
            kill -KILL "$pid" 2>/dev/null || true
        fi
    done
}
# Set up signal handlers to cleanup on exit
trap cleanup EXIT INT TERM
# Run each test file across all GPUs
for TEST_FILE in "${TEST_FILES[@]}"; do
    echo
    echo "=== Starting test file: $TEST_FILE ..."

    # Clear PIDs array for this test file
    PIDS=()
    for i in $(seq 0 $(($NUM_GPUS - 1))); do
        # Define output file for logs
        LOG_FILE="${TEST_FILE}_gpu_${i}.log"
        if [ $i -eq 0 ]; then
            # For process 0: show live output AND save to log file using tee
            echo "=== Live output from process 0 ==="
            pytest -s -c "$TE_PATH/tests/jax/pytest.ini" \
                -vs --junitxml=$XML_LOG_DIR/collective_gemm_${TEST_FILE}.xml \
                "$TE_PATH/examples/jax/collective_gemm/$TEST_FILE" \
                --num-processes=$NUM_GPUS \
                --process-id=$i 2>&1 | tee "$LOG_FILE" &
            PID=$!
            PIDS+=($PID)
        else
            # For other processes: redirect to log files only
            pytest -s -c "$TE_PATH/tests/jax/pytest.ini" \
                -vs "$TE_PATH/examples/jax/collective_gemm/$TEST_FILE" \
                --num-processes=$NUM_GPUS \
                --process-id=$i > "$LOG_FILE" 2>&1 &
            PID=$!
            PIDS+=($PID)
        fi
    done

    # Wait for all processes to finish
    wait

    # Check and print the log content from process 0 (now has log file thanks to tee)
    if grep -q "SKIPPED" "${TEST_FILE}_gpu_0.log"; then
        echo "... $TEST_FILE SKIPPED"
    elif grep -q "FAILED" "${TEST_FILE}_gpu_0.log"; then
        echo "... $TEST_FILE FAILED"
        HAS_FAILURE=1
    elif grep -q "PASSED" "${TEST_FILE}_gpu_0.log"; then
        echo "... $TEST_FILE PASSED"
    else
        echo "... $TEST_FILE INVALID"
        HAS_FAILURE=1
    fi

    # Remove the log files after processing them
    wait
    rm ${TEST_FILE}_gpu_*.log
done
wait
# Final cleanup (trap will also call cleanup on exit)
cleanup
exit $HAS_FAILURE
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Collective Dense Gradient test on multi-GPU with tensor parallelism"""
import argparse
import unittest
import os
import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec, NamedSharding
import flax
from common import (
assert_allclose,
_initialize_distributed,
_get_dp_and_tp_sizes,
_create_mesh,
DP_AXIS,
TPSP_AXIS,
PARAMS_KEY,
cgemm_parser,
)
from transformer_engine.jax.dense import dense
from transformer_engine.jax.quantize import autocast
from transformer_engine.jax.cpp_extensions.gemm import (
CollectiveOp,
CollectiveOpSet,
noop_collective_op_set,
)
from transformer_engine.jax.sharding import MeshResource
import transformer_engine.jax.flax as te_flax
def _get_logical_axes(collective_op):
    if collective_op.is_all_gather:
        input_axes = (DP_AXIS, TPSP_AXIS, None)
        weight_axes = (None, TPSP_AXIS)
        bias_axes = (TPSP_AXIS,)
        output_axes = (DP_AXIS, None, TPSP_AXIS)
    else:  # RS
        input_axes = (DP_AXIS, None, TPSP_AXIS)
        weight_axes = (TPSP_AXIS, None)
        bias_axes = (None,)
        output_axes = (DP_AXIS, TPSP_AXIS, None)
    return input_axes, weight_axes, bias_axes, output_axes


def _get_operand_sharding(mesh, collective_op):
    input_axes, weight_axes, bias_axes, _ = _get_logical_axes(collective_op)
    x_sharding = NamedSharding(mesh, PartitionSpec(*input_axes))
    weight_sharding = NamedSharding(mesh, PartitionSpec(*weight_axes))
    bias_sharding = NamedSharding(mesh, PartitionSpec(*bias_axes))
    return x_sharding, weight_sharding, bias_sharding


def _mean_dense(x, weight, bias, input_axes, weight_axes, output_axes, collective_op_set):
    output = dense(
        x,
        weight,
        bias,
        contracting_dims=((2,), (0,)),
        input_axes=input_axes,
        kernel_axes=weight_axes,
        output_axes=output_axes,
        collective_op_set=collective_op_set,
    )
    return jnp.mean(output.astype(jnp.float32))


def _value_and_grad_dense(x, weight, bias, input_axes, weight_axes, output_axes, collective_op_set):
    return jax.jit(jax.value_and_grad(_mean_dense, (0, 1, 2)), static_argnums=(3, 4, 5, 6))(
        x, weight, bias, input_axes, weight_axes, output_axes, collective_op_set
    )
def run_dense_grad_tests(args, mesh=None):
    """Execute Dense Gradient tests."""
    print(args)
    _initialize_distributed(args)
    mesh = mesh or _create_mesh(args)

    # Create test data
    rng = jax.random.PRNGKey(0)
    rng, x_rng, weight_rng, bias_rng = jax.random.split(rng, 4)
    x = jax.random.normal(
        x_rng, (args.batch_size, args.seq_len, args.hidden_in), dtype=jnp.bfloat16
    )
    weight = jax.random.normal(weight_rng, (args.hidden_in, args.hidden_out), dtype=jnp.bfloat16)
    bias = jax.random.normal(bias_rng, (args.hidden_out,), dtype=jnp.bfloat16)

    collective_op = (
        CollectiveOp.ALL_GATHER
        if args.collective_type == "all_gather"
        else CollectiveOp.REDUCE_SCATTER
    )
    collective_op_set = CollectiveOpSet.create(forward_collective_op=collective_op)

    with mesh, autocast(
        enabled=False,
        recipe=None,
        mesh_resource=MeshResource(dp_resource=DP_AXIS, tpsp_resource=TPSP_AXIS),
    ):
        # Get the base axis rules and extend them with TE's rules. This must be done inside autocast
        axis_rules = flax.linen.get_logical_axis_rules()
        axis_rules += ((TPSP_AXIS, TPSP_AXIS), (DP_AXIS, DP_AXIS))
        te_extended_axis_rules = te_flax.extend_logical_axis_rules(axis_rules)

        with flax.linen.logical_axis_rules(te_extended_axis_rules):
            x_sharding, weight_sharding, bias_sharding = _get_operand_sharding(mesh, collective_op)
            x_sharded = jax.device_put(x, x_sharding)
            weight_sharded = jax.device_put(weight, weight_sharding)
            bias_sharded = jax.device_put(bias, bias_sharding)

            input_axes, weight_axes, _, output_axes = _get_logical_axes(collective_op)
            ref_output, ref_grads = _value_and_grad_dense(
                x_sharded,
                weight_sharded,
                bias_sharded,
                input_axes,
                weight_axes,
                output_axes,
                noop_collective_op_set,
            )
            output, sharded_grads = _value_and_grad_dense(
                x_sharded,
                weight_sharded,
                bias_sharded,
                input_axes,
                weight_axes,
                output_axes,
                collective_op_set,
            )
            jax.block_until_ready(ref_output)
            jax.block_until_ready(output)

            gathered_grads = []
            gathered_ref_grads = []
            for ref_grad, grad in zip(ref_grads, sharded_grads):
                gathered_grads.append(
                    jax.lax.with_sharding_constraint(grad, NamedSharding(mesh, PartitionSpec(None)))
                )
                gathered_ref_grads.append(
                    jax.lax.with_sharding_constraint(ref_grad, NamedSharding(mesh, PartitionSpec(None)))
                )
            jax.block_until_ready(gathered_grads)
            jax.block_until_ready(gathered_ref_grads)

            if args.enable_result_check and args.process_id == 0:
                assert_allclose(ref_output, output, dtype=jnp.bfloat16)
                for ref_grad, gathered_grad in zip(gathered_ref_grads, gathered_grads):
                    assert_allclose(ref_grad, gathered_grad, dtype=jnp.bfloat16)
class TestCollectiveDenseGradient(unittest.TestCase):
    """Collective Dense Gradient unittests"""

    def setUp(self):
        self.args = cgemm_parser(
            "Collective Dense Gradient test on multi-GPU with tensor parallelism"
        ).parse_args([])
        self.args.coordinator_address = self.coordinator_address
        self.args.num_processes = self.num_processes
        self.args.process_id = self.process_id
        self.args.local_device_ids = self.local_device_ids
        self.args.num_devices_per_process = self.num_devices_per_process
        self.args.enable_data_parallel = True
        self.args.tensor_parallel_size = _get_dp_and_tp_sizes(self.args)[1]
        _initialize_distributed(self.args)
        # Create mesh once for all tests
        self.mesh = _create_mesh(self.args)
        jax.sharding.set_mesh(self.mesh)
        self.args.enable_result_check = True
        os.environ["NVTE_JAX_ALL_REDUCE_IN_FP32"] = "1"

    def tearDown(self):
        os.environ.pop("NVTE_JAX_ALL_REDUCE_IN_FP32", None)

    def test_te_bf16_all_gather(self):
        """Test Collective Dense Gradient with AllGather"""
        self.args.collective_type = "all_gather"
        run_dense_grad_tests(self.args, self.mesh)

    def test_te_bf16_reduce_scatter(self):
        """Test Collective Dense Gradient with ReduceScatter"""
        self.args.collective_type = "reduce_scatter"
        run_dense_grad_tests(self.args, self.mesh)
if __name__ == "__main__":
import sys
if len(sys.argv) < 7: # Need at least the 3 required distributed args
print("Error: This script requires distributed initialization arguments.")
print(
"Usage: python test_dense_grad.py --coordinator-address <address> --num-processes <num>"
" --process-id <id> [--local-device-ids <ids>] [other args]"
)
print(
"Example: python test_dense_grad.py --coordinator-address localhost:1234"
" --num-processes 4 --process-id 0"
)
print(
"Example: python test_dense_grad.py --coordinator-address localhost:1234"
" --num-processes 2 --process-id 0 --local-device-ids 0,1,2,3"
)
sys.exit(1)
args = cgemm_parser(
"Collective Dense Gradient test on multi-GPU with tensor parallelism"
).parse_args()
_initialize_distributed(args)
run_dense_grad_tests(args, mesh=None)
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Collective GEMM test on multi-GPU with tensor parallelism
This script uses custom distributed initialization with the following arguments:
- --coordinator-address: Coordinator address for distributed initialization
- --num-processes: Number of processes for distributed initialization
- --process-id: Process ID for distributed initialization
- --local-device-ids: Local device IDs for distributed initialization
Example:
python test_gemm.py --coordinator-address localhost:1234 --num-processes 2 --process-id 0 --local-device-ids 0,1,2,3
"""
import unittest
import os
from functools import partial
import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec, NamedSharding
from common import (
assert_allclose,
_initialize_distributed,
_get_dp_and_tp_sizes,
_create_mesh,
DP_AXIS,
TPSP_AXIS,
PARAMS_KEY,
cgemm_parser,
)
import transformer_engine.jax.cpp_extensions as tex
from transformer_engine.jax.quantize import autocast
from transformer_engine.jax.cpp_extensions.gemm import CollectiveOp
from transformer_engine.jax.sharding import MeshResource
def _get_operand_sharding(mesh, collective_op, is_with_dp):
dp_axis = DP_AXIS if is_with_dp else None
if collective_op == CollectiveOp.ALL_GATHER:
x_sharding = NamedSharding(mesh, PartitionSpec(dp_axis, TPSP_AXIS, None))
weight_sharding = NamedSharding(mesh, PartitionSpec(None, TPSP_AXIS))
bias_sharding = NamedSharding(mesh, PartitionSpec(TPSP_AXIS))
output_sharding = NamedSharding(mesh, PartitionSpec(dp_axis, None, TPSP_AXIS))
else: # RS
x_sharding = NamedSharding(mesh, PartitionSpec(dp_axis, None, TPSP_AXIS))
weight_sharding = NamedSharding(mesh, PartitionSpec(TPSP_AXIS, None))
bias_sharding = NamedSharding(mesh, PartitionSpec(None))
output_sharding = NamedSharding(mesh, PartitionSpec(dp_axis, TPSP_AXIS, None))
return x_sharding, weight_sharding, bias_sharding, output_sharding
def _get_dp_and_tp_sizes(args):
num_gpu = args.num_processes * args.num_devices_per_process
if args.tensor_parallel_size is None:
num_gpu_dp = 2 if args.enable_data_parallel else 1
assert (
num_gpu > 1 and num_gpu % num_gpu_dp == 0
), "Number of GPUs must be greater than 1 and divisible by number of data parallel GPUs"
num_gpu_tp = num_gpu // num_gpu_dp
else:
num_gpu_tp = args.tensor_parallel_size
assert (
num_gpu > 1 and num_gpu % num_gpu_tp == 0
), "Number of GPUs must be greater than 1 and divisible by the tensor parallel size"
num_gpu_dp = num_gpu // num_gpu_tp
return num_gpu_dp, num_gpu_tp
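The DP/TP factorization above can be sanity-checked in isolation. The sketch below mirrors the logic of `_get_dp_and_tp_sizes` as a standalone function (the function name and keyword arguments are illustrative stand-ins for the parsed args, not part of the test suite):

```python
def split_dp_tp(num_processes, num_devices_per_process,
                tensor_parallel_size=None, enable_data_parallel=False):
    """Factor the world size into (DP, TP), mirroring _get_dp_and_tp_sizes."""
    num_gpu = num_processes * num_devices_per_process
    if tensor_parallel_size is None:
        # Default policy: 2-way DP when data parallelism is enabled, else pure TP.
        num_gpu_dp = 2 if enable_data_parallel else 1
        assert num_gpu > 1 and num_gpu % num_gpu_dp == 0
        num_gpu_tp = num_gpu // num_gpu_dp
    else:
        num_gpu_tp = tensor_parallel_size
        assert num_gpu > 1 and num_gpu % num_gpu_tp == 0
        num_gpu_dp = num_gpu // num_gpu_tp
    return num_gpu_dp, num_gpu_tp

# 2 processes x 4 local devices, DP enabled -> DP=2, TP=4
print(split_dp_tp(2, 4, enable_data_parallel=True))
# Explicit TP=2 over the same 8 GPUs -> DP=4, TP=2
print(split_dp_tp(2, 4, tensor_parallel_size=2))
```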
@partial(jax.jit, static_argnames=("contracting_dims", "collective_op", "output_sharding"))
def _jitted_cgemm(x, weight, bias, contracting_dims, collective_op, output_sharding):
output = tex.gemm(
x,
weight,
bias=bias,
contracting_dims=contracting_dims,
collective_op=collective_op,
)
if output_sharding is not None:
output = jax.lax.with_sharding_constraint(output, output_sharding)
return output
def run_gemm_tests(args, mesh=None):
"""Execute GEMM tests."""
print(args)
# Collective GEMM requires Shardy partitioner to be disabled
jax.config.update("jax_use_shardy_partitioner", False)
# Initialize distributed with provided arguments
_initialize_distributed(args)
mesh = mesh or _create_mesh(args)
# Create test data
rng = jax.random.PRNGKey(0)
rng, x_rng, weight_rng, bias_rng = jax.random.split(rng, 4)
x = jax.random.normal(
x_rng, (args.batch_size, args.seq_len, args.hidden_in), dtype=jnp.bfloat16
)
weight = jax.random.normal(weight_rng, (args.hidden_in, args.hidden_out), dtype=jnp.bfloat16)
bias = jax.random.normal(bias_rng, (args.hidden_out,), dtype=jnp.bfloat16)
collective_op = (
CollectiveOp.ALL_GATHER
if args.collective_type == "all_gather"
else CollectiveOp.REDUCE_SCATTER
)
with mesh, autocast(
enabled=False,
recipe=None,
mesh_resource=MeshResource(dp_resource=DP_AXIS, tpsp_resource=TPSP_AXIS),
):
print(f"Device mesh: {mesh}")
x_sharding, weight_sharding, bias_sharding, output_sharding = _get_operand_sharding(
mesh, collective_op, args.enable_data_parallel
)
x_sharded = jax.device_put(x, x_sharding)
weight_sharded = jax.device_put(weight, weight_sharding)
bias_sharded = jax.device_put(bias, bias_sharding)
ref_output = _jitted_cgemm(
x_sharded,
weight_sharded,
bias_sharded,
contracting_dims=((2,), (0,)),
collective_op=CollectiveOp.NONE,
output_sharding=output_sharding,
)
output = _jitted_cgemm(
x_sharded,
weight_sharded,
bias_sharded,
contracting_dims=((2,), (0,)),
collective_op=collective_op,
# Collective GEMM output should already have the correct sharding without an explicit sharding constraint
output_sharding=None,
)
assert (
ref_output.sharding == output.sharding
), f"ref_output.sharding={ref_output.sharding}, output.sharding={output.sharding}"
gathered_ref_output = jax.lax.with_sharding_constraint(
ref_output, NamedSharding(mesh, PartitionSpec(None))
)
gathered_output = jax.lax.with_sharding_constraint(
output, NamedSharding(mesh, PartitionSpec(None))
)
jax.block_until_ready(gathered_ref_output)
jax.block_until_ready(gathered_output)
if args.enable_result_check and args.process_id == 0:
assert_allclose(gathered_ref_output, gathered_output)
class TestCollectiveGemmWithDP(unittest.TestCase):
"""Collective GEMM with DP unittests"""
def setUp(self):
self.args = cgemm_parser(
"Collective GEMM test on multi-GPU with tensor parallelism"
).parse_args([])
self.args.coordinator_address = self.coordinator_address
self.args.num_processes = self.num_processes
self.args.process_id = self.process_id
self.args.local_device_ids = self.local_device_ids
self.args.num_devices_per_process = self.num_devices_per_process
self.args.enable_data_parallel = True
self.args.tensor_parallel_size = _get_dp_and_tp_sizes(self.args)[1]
_initialize_distributed(self.args)
self.mesh = _create_mesh(self.args)
jax.sharding.set_mesh(self.mesh)
self.args.enable_result_check = True
os.environ["NVTE_JAX_ALL_REDUCE_IN_FP32"] = "1"
def tearDown(self):
os.environ.pop("NVTE_JAX_ALL_REDUCE_IN_FP32", None)
def test_te_bf16_all_gather_with_dp(self):
"""Test Collective GEMM with AllGather"""
self.args.collective_type = "all_gather"
run_gemm_tests(self.args, self.mesh)
def test_te_bf16_reduce_scatter_with_dp(self):
"""Test Collective GEMM with ReduceScatter"""
self.args.collective_type = "reduce_scatter"
run_gemm_tests(self.args, self.mesh)
if __name__ == "__main__":
import sys
if len(sys.argv) < 7: # Need at least the 3 required distributed args
print("Error: This script requires distributed initialization arguments.")
print(
"Usage: python test_gemm.py --coordinator-address <address> --num-processes <num>"
" --process-id <id> [--local-device-ids <ids>] [other args]"
)
sys.exit(1)
args = cgemm_parser("Collective GEMM test on multi-GPU with tensor parallelism").parse_args()
_initialize_distributed(args)
run_gemm_tests(args, mesh=None)
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Collective LayerNorm MLP Gradient test on multi-GPU with tensor parallelism"""
import argparse
import unittest
import os
import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec, NamedSharding
import flax
from common import (
assert_allclose,
_initialize_distributed,
_get_dp_and_tp_sizes,
_create_mesh,
DP_AXIS,
TPSP_AXIS,
PARAMS_KEY,
cgemm_parser,
)
from transformer_engine.jax.layernorm_mlp import layernorm_mlp
from transformer_engine.jax.quantize import autocast
from transformer_engine.jax.cpp_extensions.gemm import (
CollectiveOpSet,
CollectiveOp,
noop_collective_op_set,
)
from transformer_engine.jax.sharding import MeshResource
import transformer_engine.jax.flax as te_flax
def _get_logical_axes():
input_1_axes = (DP_AXIS, TPSP_AXIS, None)
weight_1_axes = (None, None, TPSP_AXIS)
bias_axes_1 = (None, TPSP_AXIS)
input_2_axes = (DP_AXIS, None, TPSP_AXIS)
weight_2_axes = (TPSP_AXIS, None)
bias_axes_2 = (None,)
return input_1_axes, weight_1_axes, bias_axes_1, input_2_axes, weight_2_axes, bias_axes_2
def _get_operand_sharding(mesh):
input_1_axes, weight_1_axes, bias_axes_1, input_2_axes, weight_2_axes, bias_axes_2 = (
_get_logical_axes()
)
x_sharding = NamedSharding(mesh, PartitionSpec(*input_1_axes))
weight_1_sharding = NamedSharding(mesh, PartitionSpec(*weight_1_axes))
bias_1_sharding = NamedSharding(mesh, PartitionSpec(*bias_axes_1))
weight_2_sharding = NamedSharding(mesh, PartitionSpec(*weight_2_axes))
bias_2_sharding = NamedSharding(mesh, PartitionSpec(*bias_axes_2))
return x_sharding, weight_1_sharding, bias_1_sharding, weight_2_sharding, bias_2_sharding
def _mean_layernorm_mlp(
x,
weight_1,
bias_1,
weight_2,
bias_2,
gamma,
input_1_axes,
input_2_axes,
weight_1_axes,
weight_2_axes,
collective_op_sets,
):
output = layernorm_mlp(
x,
gamma,
beta=None,
kernels=[weight_1, weight_2],
biases=[bias_1, bias_2],
norm_type="rmsnorm",
dot_1_input_axes=input_1_axes,
dot_2_input_axes=input_2_axes,
kernel_1_axes=weight_1_axes,
kernel_2_axes=weight_2_axes,
activation_type=("gelu",),
collective_op_sets=collective_op_sets,
)
return jnp.mean(output)
def _value_and_grad_layernorm_mlp(
x,
weight_1,
bias_1,
weight_2,
bias_2,
gamma,
input_1_axes,
input_2_axes,
weight_1_axes,
weight_2_axes,
collective_op_sets,
):
return jax.jit(
jax.value_and_grad(_mean_layernorm_mlp, (0, 1, 2, 3, 4, 5)), static_argnums=(6, 7, 8, 9, 10)
)(
x,
weight_1,
bias_1,
weight_2,
bias_2,
gamma,
input_1_axes,
input_2_axes,
weight_1_axes,
weight_2_axes,
collective_op_sets,
)
def run_layernorm_mlp_grad_tests(args, mesh=None):
"""Execute Dense Gradient tests."""
print(args)
# Collective GEMM requires Shardy partitioner to be disabled
jax.config.update("jax_use_shardy_partitioner", False)
# Initialize distributed with provided arguments
_initialize_distributed(args)
mesh = mesh or _create_mesh(args)
# Create test data
rng = jax.random.PRNGKey(0)
rng, x_rng, weight_1_rng, bias_1_rng, weight_2_rng, bias_2_rng, gamma_rng = jax.random.split(
rng, 7
)
x = jax.random.normal(
x_rng, (args.batch_size, args.seq_len, args.hidden_in), dtype=jnp.bfloat16
)
weight_1 = jax.random.normal(
weight_1_rng, (args.hidden_in, 1, args.hidden_out), dtype=jnp.bfloat16
) / jnp.sqrt(args.hidden_in)
bias_1 = jax.random.normal(bias_1_rng, (1, args.hidden_out), dtype=jnp.bfloat16)
weight_2 = jax.random.normal(
weight_2_rng, (args.hidden_out, args.hidden_in), dtype=jnp.bfloat16
) / jnp.sqrt(args.hidden_out)
bias_2 = jax.random.normal(bias_2_rng, (args.hidden_in,), dtype=jnp.bfloat16)
gamma = jax.random.normal(gamma_rng, (args.hidden_in,), dtype=jnp.bfloat16) / jnp.sqrt(
args.hidden_in
)
collective_op_set_1 = CollectiveOpSet.create(forward_collective_op=CollectiveOp.ALL_GATHER)
collective_op_set_2 = CollectiveOpSet.create(forward_collective_op=CollectiveOp.REDUCE_SCATTER)
collective_op_sets = (collective_op_set_1, collective_op_set_2)
noop_collective_op_sets = (noop_collective_op_set, noop_collective_op_set)
with mesh, autocast(
enabled=False,
recipe=None,
mesh_resource=MeshResource(dp_resource=DP_AXIS, tpsp_resource=TPSP_AXIS),
):
# Get the base axis rules and extend them with TE's rules. This must be done inside autocast
axis_rules = flax.linen.get_logical_axis_rules()
axis_rules += ((TPSP_AXIS, TPSP_AXIS), (DP_AXIS, DP_AXIS))
te_extended_axis_rules = te_flax.extend_logical_axis_rules(axis_rules)
with flax.linen.logical_axis_rules(te_extended_axis_rules):
x_sharding, weight_1_sharding, bias_1_sharding, weight_2_sharding, bias_2_sharding = (
_get_operand_sharding(mesh)
)
x_sharded = jax.device_put(x, x_sharding)
weight_1_sharded = jax.device_put(weight_1, weight_1_sharding)
bias_1_sharded = jax.device_put(bias_1, bias_1_sharding)
weight_2_sharded = jax.device_put(weight_2, weight_2_sharding)
bias_2_sharded = jax.device_put(bias_2, bias_2_sharding)
input_1_axes, weight_1_axes, _, input_2_axes, weight_2_axes, _ = _get_logical_axes()
ref_output, ref_grads = _value_and_grad_layernorm_mlp(
x_sharded,
weight_1_sharded,
bias_1_sharded,
weight_2_sharded,
bias_2_sharded,
gamma,
input_1_axes,
input_2_axes,
weight_1_axes,
weight_2_axes,
noop_collective_op_sets,
)
output, sharded_grads = _value_and_grad_layernorm_mlp(
x_sharded,
weight_1_sharded,
bias_1_sharded,
weight_2_sharded,
bias_2_sharded,
gamma,
input_1_axes,
input_2_axes,
weight_1_axes,
weight_2_axes,
collective_op_sets,
)
jax.block_until_ready(ref_output)
jax.block_until_ready(output)
gathered_grads = []
gathered_ref_grads = []
for ref_grad, grad in zip(ref_grads, sharded_grads):
gathered_grads.append(
jax.lax.with_sharding_constraint(grad, NamedSharding(mesh, PartitionSpec(None)))
)
gathered_ref_grads.append(
jax.lax.with_sharding_constraint(ref_grad, NamedSharding(mesh, PartitionSpec(None)))
)
jax.block_until_ready(gathered_grads)
jax.block_until_ready(gathered_ref_grads)
if args.enable_result_check and args.process_id == 0:
assert_allclose(ref_output, output, dtype=jnp.bfloat16)
for ref_grad, gathered_grad in zip(gathered_ref_grads, gathered_grads):
assert_allclose(ref_grad, gathered_grad, dtype=jnp.bfloat16)
class TestCollectiveLayerNormMLPGradient(unittest.TestCase):
"""Collective LayerNorm MLP Gradient unittests"""
def setUp(self):
self.args = cgemm_parser(
"Collective LayerNorm MLP Gradient test on multi-GPU with tensor parallelism"
).parse_args([])
self.args.coordinator_address = self.coordinator_address
self.args.num_processes = self.num_processes
self.args.process_id = self.process_id
self.args.local_device_ids = self.local_device_ids
self.args.num_devices_per_process = self.num_devices_per_process
self.args.enable_data_parallel = True
self.args.tensor_parallel_size = _get_dp_and_tp_sizes(self.args)[1]
_initialize_distributed(self.args)
# Create mesh once for all tests
self.mesh = _create_mesh(self.args)
jax.sharding.set_mesh(self.mesh)
self.args.enable_result_check = True
os.environ["NVTE_JAX_ALL_REDUCE_IN_FP32"] = "1"
def tearDown(self):
os.environ.pop("NVTE_JAX_ALL_REDUCE_IN_FP32", None)
def test_te_bf16_layernorm_mlp_grad(self):
"""Test Collective LayerNorm MLP Gradient"""
run_layernorm_mlp_grad_tests(self.args, self.mesh)
if __name__ == "__main__":
import sys
if len(sys.argv) < 7: # Need at least the 3 required distributed args
print("Error: This script requires distributed initialization arguments.")
print(
"Usage: python test_layernorm_mlp_grad.py --coordinator-address <address>"
" --num-processes <num> --process-id <id> [--local-device-ids <ids>] [other args]"
)
print(
"Example: python test_layernorm_mlp_grad.py --coordinator-address localhost:1234"
" --num-processes 4 --process-id 0"
)
print(
"Example: python test_layernorm_mlp_grad.py --coordinator-address localhost:1234"
" --num-processes 2 --process-id 0 --local-device-ids 0,1,2,3"
)
sys.exit(1)
args = cgemm_parser(
"Collective LayerNorm MLP Gradient test on multi-GPU with tensor parallelism"
).parse_args()
_initialize_distributed(args)
run_layernorm_mlp_grad_tests(args, mesh=None)
......@@ -8,7 +8,7 @@ This example uses Transformer Encoder to demonstrate the Transformer Engine usag
2. Define model: The `Net` class is a small Transformer Encoder model for sentence classification. Transformer Engine provides `te.TransformerLayer` as the encoder block and `te.DenseGeneral` as the dense layer. The structure of the encoder block follows [Scaling Up Models and Data with t5x and seqio](https://arxiv.org/abs/2203.17189)
3. Build training loop: The `train_and_evaluate` is the main routine to initialize the model and start training and evaluating. Use `fp8_autocast` context manager to enable FP8 training and check `var_collect` if the variable collection contains `Float8`.
3. Build training loop: The `train_and_evaluate` is the main routine that initializes the model and runs training and evaluation. Use the `autocast` context manager to enable FP8 training, and check whether the variable collection `var_collect` contains `Float8` entries.
4. Training process: In `train_step`, combine the FP8 metadata and the latest model parameters into `var_collect` as a frozen dictionary and pass it to the gradient function.
......@@ -29,7 +29,7 @@ python test_single_gpu_encoder.py --use-fp8
3. On the model side, the logical axes of each weight tensor of the model can be named. The `te.TransformerLayer` has default names, which are stored in `abs_var_collect`, a collection of variables returned by `jax.eval_shape(encoder.init, ...)`, under the key `params_axes`. The `te.DenseGeneral` has no default named axes because it is generic. Also, data-parallel sharding does not partition the weight tensors, so named axes are not required in this case.
4. The next is to create sharding rules, mapping the device axis to the logical axis. The `te.extend_logical_axis_rules` under fp8_autocast will return a list of pairs of the mapping, such as `(('batch', 'data'), ...)`. The first is the logical axis and second is the device axis.
4. The next step is to create sharding rules mapping device axes to logical axes. Calling `te.extend_logical_axis_rules` under autocast returns a list of mapping pairs, such as `(('batch', 'data'), ...)`, where the first element is the logical axis and the second is the device axis.
5. Refer to the structure of `abs_var_collect['params']` and `abs_var_collect['params_axes']` to set up the `PartitionSpec` for parallel jit. All logical axes should be replaced by device axes. A `PartitionSpec` value of None means no sharding, i.e. broadcasting the data to every device. Note that the `params_axes` attribute is provided by Transformer Engine; Flax's own modules, such as `nn.Embed`, do not have it. For `nn.Embed`, assigning an empty `PartitionSpec` is fine because each device has its own embedding layer in DP mode. The `get_params_pspec` routine is used for this purpose. Because each device holds a complete model in DP mode, all `PartitionSpec` values in `params_pspec` should be None. This differs in the model parallelism example.
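The rule lookup described above can be sketched without Flax: each logical axis name in a spec resolves, through `(logical_axis, device_axis)` pairs like those returned by `te.extend_logical_axis_rules`, to a device axis or to `None` for replication. The helper below is a minimal hypothetical illustration of that mapping, not an API of either library:

```python
def resolve_pspec(logical_axes, rules):
    """Map each logical axis name to its device axis via the rule table;
    unmatched (or None-mapped) axes become None, i.e. replicated."""
    table = dict(rules)  # assumes each logical axis appears at most once
    return tuple(table.get(axis) for axis in logical_axes)

rules = (("batch", "data"), ("embed", None), ("mlp", "model"))
print(resolve_pspec(("batch", "embed", "mlp"), rules))  # ('data', None, 'model')
```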
......@@ -136,4 +136,4 @@ numactl --cpunodebind=112 --membind=7 python test_multiprocessing_encoder.py --n
numactl --cpunodebind=113 --membind=7 python test_multiprocessing_encoder.py --num-process 8 --process-id 5 &
numactl --cpunodebind=80 --membind=5 python test_multiprocessing_encoder.py --num-process 8 --process-id 6 &
numactl --cpunodebind=81 --membind=5 python test_multiprocessing_encoder.py --num-process 8 --process-id 7 &
```
\ No newline at end of file
```
......@@ -33,6 +33,13 @@ def is_mxfp8_supported():
return gpu_arch >= 100
@lru_cache
def is_nvfp4_supported():
"""Return whether NVFP4 has hardware support"""
gpu_arch = get_device_compute_capability(0)
return gpu_arch >= 100
def assert_params_sufficiently_sharded(params, mesh, tolerance=0.01, print_info=False):
"""Checks whether most params are sharded across sharding axis.
......@@ -98,7 +105,7 @@ def assert_params_sufficiently_sharded(params, mesh, tolerance=0.01, print_info=
)
def get_fp8_recipe_from_name_string(name: str):
def get_quantization_recipe_from_name_string(name: str):
"""Query recipe from a given name string"""
match name:
case "DelayedScaling":
......@@ -107,5 +114,7 @@ def get_fp8_recipe_from_name_string(name: str):
return recipe.MXFP8BlockScaling()
case "Float8CurrentScaling":
return recipe.Float8CurrentScaling()
case "NVFP4BlockScaling":
return recipe.NVFP4BlockScaling()
case _:
raise ValueError(f"Invalid fp8_recipe, got {name}")
raise ValueError(f"Invalid quantization_recipe, got {name}")
......@@ -10,16 +10,44 @@ TEST_CASES=(
"test_te_delayed_scaling_fp8"
"test_te_current_scaling_fp8"
"test_te_mxfp8"
"test_te_nvfp4"
"test_te_bf16_shardy"
"test_te_delayed_scaling_fp8_shardy"
"test_te_current_scaling_fp8_shardy"
"test_te_nvfp4_shardy"
)
: ${TE_PATH:=/opt/transformerengine}
: ${XML_LOG_DIR:=/logs}
mkdir -p "$XML_LOG_DIR"
echo
echo "*** Executing tests in examples/jax/encoder/test_multiprocessing_encoder.py ***"
HAS_FAILURE=0 # Global failure flag
PIDS=() # Array to store all process PIDs
# Cleanup function to kill all processes
cleanup() {
for pid in "${PIDS[@]}"; do
if kill -0 "$pid" 2>/dev/null; then
echo "Killing process $pid"
kill -TERM "$pid" 2>/dev/null || true
fi
done
# Wait a bit and force kill if needed
sleep 2
for pid in "${PIDS[@]}"; do
if kill -0 "$pid" 2>/dev/null; then
echo "Force killing process $pid"
kill -KILL "$pid" 2>/dev/null || true
fi
done
}
# Set up signal handlers to cleanup on exit
trap cleanup EXIT INT TERM
# Run each test case across all GPUs
for TEST_CASE in "${TEST_CASES[@]}"; do
echo
......@@ -29,25 +57,40 @@ for TEST_CASE in "${TEST_CASES[@]}"; do
# Define output file for logs
LOG_FILE="${TEST_CASE}_gpu_${i}.log"
# Run pytest and redirect stdout and stderr to the log file
pytest -s -c "$TE_PATH/tests/jax/pytest.ini" \
-vs "$TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py::TestEncoder::$TEST_CASE" \
--num-process=$NUM_GPUS \
--process-id=$i > "$LOG_FILE" 2>&1 &
done
# For process 0: show live output AND save to log file using tee
if [ $i -eq 0 ]; then
echo "=== Live output from process 0 ==="
pytest -s -c "$TE_PATH/tests/jax/pytest.ini" \
-vs --junitxml=$XML_LOG_DIR/multiprocessing_encoder_${TEST_CASE}.xml \
"$TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py::TestEncoder::$TEST_CASE" \
--num-process=$NUM_GPUS \
--process-id=$i 2>&1 | tee "$LOG_FILE" &
PID=$!
PIDS+=($PID)
else
pytest -s -c "$TE_PATH/tests/jax/pytest.ini" \
-vs "$TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py::TestEncoder::$TEST_CASE" \
--num-process=$NUM_GPUS \
--process-id=$i > "$LOG_FILE" 2>&1 &
PID=$!
PIDS+=($PID)
fi
done
# Wait for the process to finish
wait
tail -n +7 "${TEST_CASE}_gpu_0.log"
# Check and print the log content accordingly
if grep -q "SKIPPED" "${TEST_CASE}_gpu_0.log"; then
echo "... $TEST_CASE SKIPPED"
elif grep -q "FAILED" "${TEST_CASE}_gpu_0.log"; then
echo "... $TEST_CASE FAILED"
HAS_FAILURE=1
elif grep -q "PASSED" "${TEST_CASE}_gpu_0.log"; then
echo "... $TEST_CASE PASSED"
else
echo "... $TEST_CASE INVALID"
HAS_FAILURE=1
echo "... $TEST_CASE FAILED"
fi
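The grep-based triage above applies a fixed precedence (SKIPPED, then FAILED, then PASSED, and anything unrecognized counts as a failure). A compact Python equivalent of that same precedence, as a sketch rather than part of the script:

```python
def classify_log(text):
    """Triage a pytest log with the script's precedence:
    SKIPPED > FAILED > PASSED > unrecognized (treated as failure)."""
    if "SKIPPED" in text:
        return "SKIPPED", False
    if "FAILED" in text:
        return "FAILED", True   # sets the global failure flag
    if "PASSED" in text:
        return "PASSED", False
    return "FAILED", True       # unrecognized output counts as a failure

print(classify_log("1 test PASSED"))      # ('PASSED', False)
print(classify_log("collected 0 items"))  # ('FAILED', True)
```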
# Remove the log file after processing it
......@@ -56,4 +99,8 @@ for TEST_CASE in "${TEST_CASES[@]}"; do
done
wait
# Final cleanup (trap will also call cleanup on exit)
cleanup
exit $HAS_FAILURE
......@@ -21,13 +21,13 @@ from jax.sharding import PartitionSpec, NamedSharding
from common import (
is_bf16_supported,
get_fp8_recipe_from_name_string,
get_quantization_recipe_from_name_string,
assert_params_sufficiently_sharded,
)
import transformer_engine.jax as te
import transformer_engine.jax.cpp_extensions as tex
import transformer_engine.jax.flax as te_flax
from transformer_engine.jax.quantize import is_fp8_available, ScalingMode
from transformer_engine.jax.quantize import is_scaling_mode_supported, ScalingMode
DEVICE_DP_AXIS = "data"
......@@ -36,6 +36,7 @@ NAMED_BROADCAST_AXIS = "my_broadcast_axis"
NAMED_TP_AXIS = "my_tp_axis"
PARAMS_KEY = "params"
PARAMS_AXES_KEY = PARAMS_KEY + "_axes"
SR_KEY = "sr_rng"
DROPOUT_KEY = "dropout"
INPUT_KEY = "input_rng"
......@@ -121,6 +122,8 @@ def train_epoch(state, train_ds, batch_size, rngs, var_collect, train_fn):
epoch_accuracy = []
for perm in perms:
# Split and reassign to 'rngs' to ensure unique rng for each step
rngs = {key: jax.random.split(rngs[key])[1] for key in rngs}
batch_inputs = train_ds["sentence"][perm, ...]
batch_masks = train_ds["mask"][perm, ...]
batch_labels = train_ds["label"][perm, ...]
......@@ -135,11 +138,11 @@ def train_epoch(state, train_ds, batch_size, rngs, var_collect, train_fn):
return state, avg_loss, avg_accuracy, var_collect
def eval_step(state, inputs, masks, labels, var_collect):
def eval_step(state, inputs, masks, labels, var_collect, rngs):
"""Computes loss and accuracy for a single batch."""
def loss_fn(var_collect, disable_dropout=False):
logits = state.apply_fn(var_collect, inputs, masks, disable_dropout)
logits = state.apply_fn(var_collect, inputs, masks, disable_dropout, rngs=rngs)
one_hot = jax.nn.one_hot(labels.astype(jnp.int32), 2)
loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
return loss, logits
......@@ -150,7 +153,7 @@ def eval_step(state, inputs, masks, labels, var_collect):
return loss, accuracy
def eval_model(state, test_ds, batch_size, var_collect, eval_fn):
def eval_model(state, test_ds, batch_size, var_collect, eval_fn, rngs):
"""Evaluation loop."""
test_ds_size = len(test_ds["sentence"])
num_steps = test_ds_size // batch_size
......@@ -159,11 +162,13 @@ def eval_model(state, test_ds, batch_size, var_collect, eval_fn):
all_accuracy = []
for batch_start in range(0, valid_size, batch_size):
# Split and reassign to 'rngs' to ensure unique rng for each step
rngs = {key: jax.random.split(rngs[key])[1] for key in rngs}
batch_end = batch_start + batch_size
batch_inputs = test_ds["sentence"][batch_start:batch_end]
batch_masks = test_ds["mask"][batch_start:batch_end]
batch_labels = test_ds["label"][batch_start:batch_end]
loss, accuracy = eval_fn(state, batch_inputs, batch_masks, batch_labels, var_collect)
loss, accuracy = eval_fn(state, batch_inputs, batch_masks, batch_labels, var_collect, rngs)
all_loss.append(loss)
all_accuracy.append(accuracy)
......@@ -223,7 +228,7 @@ def get_datasets(max_seq_len):
def check_fp8(state, var_collect, inputs, masks, labels):
"Check if model includes FP8."
rngs = {DROPOUT_KEY: jax.random.PRNGKey(0)}
rngs = {DROPOUT_KEY: jax.random.PRNGKey(0), SR_KEY: jax.random.PRNGKey(0)}
func_jaxpr = str(jax.make_jaxpr(train_step)(state, inputs, masks, labels, var_collect, rngs))
assert "f8_e5m2" in func_jaxpr or "f8_e4m3" in func_jaxpr
......@@ -257,16 +262,16 @@ def train_and_evaluate(args):
), "Test batch size needs to be multiple of 32 for MXFP8"
if args.use_fp8:
fp8_recipe = get_fp8_recipe_from_name_string(args.fp8_recipe)
fp8_recipe = get_quantization_recipe_from_name_string(args.fp8_recipe)
else:
fp8_recipe = None
device_mesh = mesh_utils.create_device_mesh((num_gpu_dp, num_gpu_tp))
with jax.sharding.Mesh(
devices=device_mesh, axis_names=(DEVICE_DP_AXIS, DEVICE_TP_AXIS)
) as mesh, te.fp8_autocast(
) as mesh, te.autocast(
enabled=args.use_fp8,
fp8_recipe=fp8_recipe,
recipe=fp8_recipe,
mesh_resource=te.MeshResource(
dp_resource=DEVICE_DP_AXIS,
tpsp_resource=DEVICE_TP_AXIS,
......@@ -275,13 +280,14 @@ def train_and_evaluate(args):
rng = jax.random.PRNGKey(args.seed)
rng, params_rng = jax.random.split(rng)
rng, dropout_rng = jax.random.split(rng)
init_rngs = {PARAMS_KEY: params_rng, DROPOUT_KEY: dropout_rng}
rng, sr_rng = jax.random.split(rng)
init_rngs = {PARAMS_KEY: params_rng, DROPOUT_KEY: dropout_rng, SR_KEY: sr_rng}
input_shape = [args.batch_size, args.max_seq_len]
mask_shape = [args.batch_size, 1, args.max_seq_len, args.max_seq_len]
label_shape = [args.batch_size]
# Get the base axis rules and extend them with TE's rules. This must be done inside fp8_autocast
# Get the base axis rules and extend them with TE's rules. This must be done inside autocast
axis_rules = flax.linen.get_logical_axis_rules()
axis_rules += ((NAMED_BROADCAST_AXIS, None), (NAMED_TP_AXIS, DEVICE_TP_AXIS))
te_extended_axis_rules = te_flax.extend_logical_axis_rules(axis_rules)
......@@ -355,7 +361,14 @@ def train_and_evaluate(args):
train_step, in_shardings=in_shardings, out_shardings=out_shardings
)
in_shardings = (state_sharding, inputs_sharding, masks_sharding, labels_sharding, None)
in_shardings = (
state_sharding,
inputs_sharding,
masks_sharding,
labels_sharding,
None,
None,
)
out_shardings = (None, None)
jit_eval_step = jax.jit(
eval_step, in_shardings=in_shardings, out_shardings=out_shardings
......@@ -367,22 +380,24 @@ def train_and_evaluate(args):
if args.dry_run:
labels = jnp.zeros(label_shape, dtype=jnp.bfloat16)
rngs = {DROPOUT_KEY: dropout_rng}
rngs = {DROPOUT_KEY: dropout_rng, SR_KEY: sr_rng}
jit_train_step(state, inputs, masks, labels, var_collect, rngs)
print("PASSED")
return None
for epoch in range(1, args.epochs + 1):
# Split and reassign to 'rng' to ensure unique rng for each step
rng, input_rng = jax.random.split(rng)
rng, dropout_rng = jax.random.split(rng)
rngs = {INPUT_KEY: input_rng, DROPOUT_KEY: dropout_rng}
rng, sr_rng = jax.random.split(rng)
rngs = {INPUT_KEY: input_rng, DROPOUT_KEY: dropout_rng, SR_KEY: sr_rng}
state, train_loss, train_accuracy, var_collect = train_epoch(
state, train_ds, args.batch_size, rngs, var_collect, jit_train_step
)
test_loss, test_accuracy = eval_model(
state, test_ds, args.test_batch_size, var_collect, jit_eval_step
state, test_ds, args.test_batch_size, var_collect, jit_eval_step, rngs
)
print(
......@@ -402,16 +417,16 @@ def encoder_parser(args):
parser.add_argument(
"--batch-size",
type=int,
default=128,
default=256,
metavar="N",
help="input batch size for training (default: 128)",
help="input batch size for training (default: 256)",
)
parser.add_argument(
"--test-batch-size",
type=int,
default=128,
default=256,
metavar="N",
help="input batch size for testing (default: 128)",
help="input batch size for testing (default: 256)",
)
parser.add_argument(
"--max-seq-len",
......@@ -466,8 +481,9 @@ def encoder_parser(args):
class TestEncoder(unittest.TestCase):
"""Encoder unittests"""
is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.DELAYED_TENSOR_SCALING)
is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING)
is_fp8_supported, fp8_reason = is_scaling_mode_supported(ScalingMode.DELAYED_TENSOR_SCALING)
is_mxfp8_supported, mxfp8_reason = is_scaling_mode_supported(ScalingMode.MXFP8_1D_SCALING)
is_nvfp4_supported, nvfp4_reason = is_scaling_mode_supported(ScalingMode.NVFP4_1D_SCALING)
def setUp(self):
"""Run 5 epochs for testing"""
......@@ -477,7 +493,7 @@ class TestEncoder(unittest.TestCase):
def test_te_bf16(self):
"""Test Transformer Engine with BF16"""
actual = train_and_evaluate(self.args)
assert actual[0] < 0.39 and actual[1] > 0.83
assert actual[0] < 0.36 and actual[1] > 0.84
@unittest.skipIf(not is_fp8_supported, fp8_reason)
def test_te_delayed_scaling_fp8(self):
......@@ -485,7 +501,7 @@ class TestEncoder(unittest.TestCase):
self.args.use_fp8 = True
self.args.fp8_recipe = "DelayedScaling"
actual = train_and_evaluate(self.args)
assert actual[0] < 0.39 and actual[1] > 0.83
assert actual[0] < 0.361 and actual[1] > 0.84
@unittest.skipIf(not is_mxfp8_supported, mxfp8_reason)
def test_te_mxfp8(self):
......@@ -493,14 +509,22 @@ class TestEncoder(unittest.TestCase):
self.args.use_fp8 = True
self.args.fp8_recipe = "MXFP8BlockScaling"
actual = train_and_evaluate(self.args)
assert actual[0] < 0.39 and actual[1] > 0.83
assert actual[0] < 0.36 and actual[1] > 0.84
@unittest.skipIf(not is_nvfp4_supported, nvfp4_reason)
def test_te_nvfp4(self):
"""Test Transformer Engine with NVFP4"""
self.args.use_fp8 = True
self.args.fp8_recipe = "NVFP4BlockScaling"
actual = train_and_evaluate(self.args)
assert actual[0] < 0.40 and actual[1] > 0.82
@unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
def test_te_bf16_with_sp(self):
"""Test Transformer Engine with BF16 + SP"""
self.args.enable_sp = True
actual = train_and_evaluate(self.args)
assert actual[0] < 0.39 and actual[1] > 0.83
assert actual[0] < 0.36 and actual[1] > 0.84
@unittest.skipIf(not is_fp8_supported, fp8_reason)
def test_te_delayed_scaling_fp8_with_sp(self):
......@@ -509,7 +533,7 @@ class TestEncoder(unittest.TestCase):
self.args.use_fp8 = True
self.args.fp8_recipe = "DelayedScaling"
actual = train_and_evaluate(self.args)
assert actual[0] < 0.39 and actual[1] > 0.83
assert actual[0] < 0.36 and actual[1] > 0.84
@unittest.skipIf(not is_mxfp8_supported, mxfp8_reason)
def test_te_mxfp8_with_sp(self):
......@@ -518,14 +542,23 @@ class TestEncoder(unittest.TestCase):
self.args.use_fp8 = True
self.args.fp8_recipe = "MXFP8BlockScaling"
actual = train_and_evaluate(self.args)
assert actual[0] < 0.39 and actual[1] > 0.83
assert actual[0] < 0.36 and actual[1] > 0.84
@unittest.skipIf(not is_nvfp4_supported, nvfp4_reason)
def test_te_nvfp4_with_sp(self):
"""Test Transformer Engine with NVFP4"""
self.args.enable_sp = True
self.args.use_fp8 = True
self.args.fp8_recipe = "NVFP4BlockScaling"
actual = train_and_evaluate(self.args)
assert actual[0] < 0.40 and actual[1] > 0.82
@unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
def test_te_bf16_shardy(self):
"""Test Transformer Engine with BF16"""
self.args.enable_shardy = True
actual = train_and_evaluate(self.args)
assert actual[0] < 0.39 and actual[1] > 0.83
assert actual[0] < 0.36 and actual[1] > 0.84
@unittest.skipIf(not is_fp8_supported, fp8_reason)
def test_te_delayed_scaling_fp8_shardy(self):
......@@ -534,7 +567,7 @@ class TestEncoder(unittest.TestCase):
self.args.use_fp8 = True
self.args.fp8_recipe = "DelayedScaling"
actual = train_and_evaluate(self.args)
assert actual[0] < 0.39 and actual[1] > 0.83
assert actual[0] < 0.36 and actual[1] > 0.84
@unittest.skipIf(not is_fp8_supported, fp8_reason)
def test_te_delayed_scaling_fp8_with_sp_shardy(self):
......@@ -544,24 +577,27 @@ class TestEncoder(unittest.TestCase):
self.args.use_fp8 = True
self.args.fp8_recipe = "DelayedScaling"
actual = train_and_evaluate(self.args)
assert actual[0] < 0.39 and actual[1] > 0.83
assert actual[0] < 0.361 and actual[1] > 0.84
@unittest.skipIf(not is_mxfp8_supported, mxfp8_reason)
@unittest.skipIf(
tex.gemm_uses_jax_dot(), "`jax.nn.scaled_matmul()` does not support the Shardy partitioner."
)
def test_te_mxfp8_shardy(self):
"""Test Transformer Engine with MXFP8"""
self.args.enable_shardy = True
self.args.use_fp8 = True
self.args.fp8_recipe = "MXFP8BlockScaling"
actual = train_and_evaluate(self.args)
assert actual[0] < 0.39 and actual[1] > 0.83
assert actual[0] < 0.36 and actual[1] > 0.84
@unittest.skipIf(not is_nvfp4_supported, nvfp4_reason)
def test_te_nvfp4_shardy(self):
"""Test Transformer Engine with NVFP4"""
self.args.enable_shardy = True
self.args.use_fp8 = True
self.args.fp8_recipe = "NVFP4BlockScaling"
actual = train_and_evaluate(self.args)
assert actual[0] < 0.40 and actual[1] > 0.82
@unittest.skipIf(not is_mxfp8_supported, mxfp8_reason)
@unittest.skipIf(
tex.gemm_uses_jax_dot(), "`jax.nn.scaled_matmul()` does not support the Shardy partitioner."
)
def test_te_mxfp8_with_sp_shardy(self):
"""Test Transformer Engine with MXFP8 + SP"""
self.args.enable_shardy = True
......@@ -569,7 +605,17 @@ class TestEncoder(unittest.TestCase):
self.args.use_fp8 = True
self.args.fp8_recipe = "MXFP8BlockScaling"
actual = train_and_evaluate(self.args)
assert actual[0] < 0.39 and actual[1] > 0.83
assert actual[0] < 0.36 and actual[1] > 0.84
@unittest.skipIf(not is_nvfp4_supported, nvfp4_reason)
def test_te_nvfp4_with_sp_shardy(self):
"""Test Transformer Engine with NVFP4"""
self.args.enable_shardy = True
self.args.enable_sp = True
self.args.use_fp8 = True
self.args.fp8_recipe = "NVFP4BlockScaling"
actual = train_and_evaluate(self.args)
assert actual[0] < 0.40 and actual[1] > 0.82
if __name__ == "__main__":
......
......@@ -19,17 +19,18 @@ from flax.training import train_state
from jax.experimental import mesh_utils
from jax.sharding import PartitionSpec, NamedSharding
from common import is_bf16_supported, get_fp8_recipe_from_name_string
from common import is_bf16_supported, get_quantization_recipe_from_name_string
import transformer_engine.jax as te
import transformer_engine.jax.cpp_extensions as tex
import transformer_engine.jax.flax as te_flax
from transformer_engine.jax.quantize import is_fp8_available, ScalingMode
from transformer_engine.jax.quantize import is_scaling_mode_supported, ScalingMode
DEVICE_DP_AXIS = "data"
PARAMS_KEY = "params"
PARAMS_AXES_KEY = PARAMS_KEY + "_axes"
DROPOUT_KEY = "dropout"
SR_KEY = "sr_rng"
INPUT_KEY = "input_rng"
......@@ -97,6 +98,8 @@ def train_epoch(state, train_ds, batch_size, rngs, var_collect, train_fn):
epoch_accuracy = []
for perm in perms:
# Split and reassign to 'rngs' to ensure unique rng for each step
rngs = {key: jax.random.split(rngs[key])[1] for key in rngs}
batch_inputs = train_ds["sentence"][perm, ...]
batch_masks = train_ds["mask"][perm, ...]
batch_labels = train_ds["label"][perm, ...]
......@@ -111,11 +114,11 @@ def train_epoch(state, train_ds, batch_size, rngs, var_collect, train_fn):
return state, avg_loss, avg_accuracy, var_collect
def eval_step(state, inputs, masks, labels, var_collect):
def eval_step(state, inputs, masks, labels, var_collect, rngs):
"""Computes loss and accuracy for a single batch."""
def loss_fn(var_collect, disable_dropout=False):
logits = state.apply_fn(var_collect, inputs, masks, disable_dropout)
logits = state.apply_fn(var_collect, inputs, masks, disable_dropout, rngs=rngs)
one_hot = jax.nn.one_hot(labels.astype(jnp.int32), 2)
loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
return loss, logits
......@@ -126,7 +129,7 @@ def eval_step(state, inputs, masks, labels, var_collect):
return loss, accuracy
def eval_model(state, test_ds, batch_size, var_collect, eval_fn):
def eval_model(state, test_ds, batch_size, var_collect, eval_fn, rngs):
"""Evaluation loop."""
test_ds_size = len(test_ds["sentence"])
num_steps = test_ds_size // batch_size
......@@ -135,11 +138,13 @@ def eval_model(state, test_ds, batch_size, var_collect, eval_fn):
all_accuracy = []
for batch_start in range(0, valid_size, batch_size):
# Split and reassign to 'rngs' to ensure unique rng for each step
rngs = {key: jax.random.split(rngs[key])[1] for key in rngs}
batch_end = batch_start + batch_size
batch_inputs = test_ds["sentence"][batch_start:batch_end]
batch_masks = test_ds["mask"][batch_start:batch_end]
batch_labels = test_ds["label"][batch_start:batch_end]
loss, accuracy = eval_fn(state, batch_inputs, batch_masks, batch_labels, var_collect)
loss, accuracy = eval_fn(state, batch_inputs, batch_masks, batch_labels, var_collect, rngs)
all_loss.append(loss)
all_accuracy.append(accuracy)
......@@ -199,7 +204,7 @@ def get_datasets(max_seq_len):
def check_fp8(state, var_collect, inputs, masks, labels):
"Check if model includes FP8."
rngs = {DROPOUT_KEY: jax.random.PRNGKey(0)}
rngs = {DROPOUT_KEY: jax.random.PRNGKey(0), SR_KEY: jax.random.PRNGKey(0)}
func_jaxpr = str(jax.make_jaxpr(train_step)(state, inputs, masks, labels, var_collect, rngs))
assert "f8_e5m2" in func_jaxpr or "f8_e4m3" in func_jaxpr
......@@ -254,29 +259,28 @@ def train_and_evaluate(args):
), "Test batch size needs to be multiple of 32 for MXFP8"
if args.use_fp8:
fp8_recipe = get_fp8_recipe_from_name_string(args.fp8_recipe)
fp8_recipe = get_quantization_recipe_from_name_string(args.fp8_recipe)
else:
fp8_recipe = None
device_mesh = mesh_utils.create_device_mesh((num_gpu,))
with jax.sharding.Mesh(
devices=device_mesh, axis_names=(DEVICE_DP_AXIS,)
) as mesh, te.fp8_autocast(
with jax.sharding.Mesh(devices=device_mesh, axis_names=(DEVICE_DP_AXIS,)) as mesh, te.autocast(
enabled=args.use_fp8,
fp8_recipe=fp8_recipe,
recipe=fp8_recipe,
mesh_resource=te.MeshResource(dp_resource=DEVICE_DP_AXIS),
):
rng = jax.random.PRNGKey(args.seed)
rng, params_rng = jax.random.split(rng)
rng, dropout_rng = jax.random.split(rng)
rng, sr_rng = jax.random.split(rng)
init_rngs = {PARAMS_KEY: params_rng, DROPOUT_KEY: dropout_rng}
input_shape = [args.batch_size, args.max_seq_len]
mask_shape = [args.batch_size, 1, args.max_seq_len, args.max_seq_len]
label_shape = [args.batch_size]
# Add TE logical axis rules to our Flax logical axis rule context. This must be done inside fp8_autocast
# Add TE logical axis rules to our Flax logical axis rule context. This must be done inside autocast
sharding_rules = te_flax.extend_logical_axis_rules(tuple())
with flax.linen.logical_axis_rules(sharding_rules):
encoder = Net(num_embed)
......@@ -322,7 +326,14 @@ def train_and_evaluate(args):
train_step, in_shardings=in_shardings, out_shardings=out_shardings
)
in_shardings = (state_sharding, inputs_sharding, masks_sharding, labels_sharding, None)
in_shardings = (
state_sharding,
inputs_sharding,
masks_sharding,
labels_sharding,
None,
None,
)
out_shardings = (None, None)
jit_eval_step = jax.jit(
eval_step, in_shardings=in_shardings, out_shardings=out_shardings
......@@ -334,22 +345,24 @@ def train_and_evaluate(args):
if args.dry_run:
labels = jnp.zeros(label_shape, dtype=jnp.bfloat16)
rngs = {DROPOUT_KEY: dropout_rng}
rngs = {DROPOUT_KEY: dropout_rng, SR_KEY: sr_rng}
jit_train_step(state, inputs, masks, labels, var_collect, rngs)
print("PASSED")
return None
for epoch in range(1, args.epochs + 1):
# Split and reassign to 'rng' to ensure unique rng for each step
rng, input_rng = jax.random.split(rng)
rng, dropout_rng = jax.random.split(rng)
rngs = {INPUT_KEY: input_rng, DROPOUT_KEY: dropout_rng}
rng, sr_rng = jax.random.split(rng)
rngs = {INPUT_KEY: input_rng, DROPOUT_KEY: dropout_rng, SR_KEY: sr_rng}
state, train_loss, train_accuracy, var_collect = train_epoch(
state, train_ds, args.batch_size, rngs, var_collect, jit_train_step
)
test_loss, test_accuracy = eval_model(
state, test_ds, args.test_batch_size, var_collect, jit_eval_step
state, test_ds, args.test_batch_size, var_collect, jit_eval_step, rngs
)
print(
......@@ -369,16 +382,16 @@ def encoder_parser(args):
parser.add_argument(
"--batch-size",
type=int,
default=256,
default=512,
metavar="N",
help="input batch size for training (default: 256)",
help="input batch size for training (default: 512)",
)
parser.add_argument(
"--test-batch-size",
type=int,
default=256,
default=512,
metavar="N",
help="input batch size for testing (default: 256)",
help="input batch size for testing (default: 512)",
)
parser.add_argument(
"--max-seq-len",
......@@ -430,8 +443,9 @@ def encoder_parser(args):
class TestEncoder(unittest.TestCase):
"""Encoder unittests"""
is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.DELAYED_TENSOR_SCALING)
is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING)
is_fp8_supported, fp8_reason = is_scaling_mode_supported(ScalingMode.DELAYED_TENSOR_SCALING)
is_mxfp8_supported, mxfp8_reason = is_scaling_mode_supported(ScalingMode.MXFP8_1D_SCALING)
is_nvfp4_supported, nvfp4_reason = is_scaling_mode_supported(ScalingMode.NVFP4_1D_SCALING)
def setUp(self):
"""Run 5 epochs for testing"""
......@@ -441,7 +455,7 @@ class TestEncoder(unittest.TestCase):
def test_te_bf16(self):
"""Test Transformer Engine with BF16"""
actual = train_and_evaluate(self.args)
assert actual[0] < 0.52 and actual[1] > 0.74
assert actual[0] < 0.51 and actual[1] > 0.75
@unittest.skipIf(not is_fp8_supported, fp8_reason)
def test_te_delayed_scaling_fp8(self):
......@@ -449,7 +463,7 @@ class TestEncoder(unittest.TestCase):
self.args.use_fp8 = True
self.args.fp8_recipe = "DelayedScaling"
actual = train_and_evaluate(self.args)
assert actual[0] < 0.52 and actual[1] > 0.74
assert actual[0] < 0.51 and actual[1] > 0.75
@unittest.skipIf(not is_fp8_supported, fp8_reason)
def test_te_current_scaling_fp8(self):
......@@ -457,7 +471,7 @@ class TestEncoder(unittest.TestCase):
self.args.use_fp8 = True
self.args.fp8_recipe = "Float8CurrentScaling"
actual = train_and_evaluate(self.args)
assert actual[0] < 0.52 and actual[1] > 0.74
assert actual[0] < 0.51 and actual[1] > 0.749
@unittest.skipIf(not is_mxfp8_supported, mxfp8_reason)
def test_te_mxfp8(self):
......@@ -465,6 +479,14 @@ class TestEncoder(unittest.TestCase):
self.args.use_fp8 = True
self.args.fp8_recipe = "MXFP8BlockScaling"
actual = train_and_evaluate(self.args)
assert actual[0] < 0.51 and actual[1] > 0.75
@unittest.skipIf(not is_nvfp4_supported, nvfp4_reason)
def test_te_nvfp4(self):
"""Test Transformer Engine with NVFP4"""
self.args.use_fp8 = True
self.args.fp8_recipe = "NVFP4BlockScaling"
actual = train_and_evaluate(self.args)
assert actual[0] < 0.52 and actual[1] > 0.74
@unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
......@@ -472,7 +494,7 @@ class TestEncoder(unittest.TestCase):
"""Test Transformer Engine with BF16"""
self.args.enable_shardy = True
actual = train_and_evaluate(self.args)
assert actual[0] < 0.52 and actual[1] > 0.74
assert actual[0] < 0.51 and actual[1] > 0.75
@unittest.skipIf(not is_fp8_supported, fp8_reason)
def test_te_delayed_scaling_fp8_shardy(self):
......@@ -481,7 +503,7 @@ class TestEncoder(unittest.TestCase):
self.args.use_fp8 = True
self.args.fp8_recipe = "DelayedScaling"
actual = train_and_evaluate(self.args)
assert actual[0] < 0.52 and actual[1] > 0.74
assert actual[0] < 0.51 and actual[1] > 0.75
@unittest.skipIf(not is_fp8_supported, fp8_reason)
def test_te_current_scaling_fp8_shardy(self):
......@@ -490,18 +512,24 @@ class TestEncoder(unittest.TestCase):
self.args.use_fp8 = True
self.args.fp8_recipe = "Float8CurrentScaling"
actual = train_and_evaluate(self.args)
assert actual[0] < 0.52 and actual[1] > 0.74
assert actual[0] < 0.51 and actual[1] > 0.749
@unittest.skipIf(not is_mxfp8_supported, mxfp8_reason)
@unittest.skipIf(
tex.gemm_uses_jax_dot(), "`jax.nn.scaled_matmul()` does not support the Shardy partitioner."
)
def test_te_mxfp8_shardy(self):
"""Test Transformer Engine with MXFP8"""
self.args.enable_shardy = True
self.args.use_fp8 = True
self.args.fp8_recipe = "MXFP8BlockScaling"
actual = train_and_evaluate(self.args)
assert actual[0] < 0.51 and actual[1] > 0.75
@unittest.skipIf(not is_nvfp4_supported, nvfp4_reason)
def test_te_nvfp4_shardy(self):
"""Test Transformer Engine with NVFP4"""
self.args.enable_shardy = True
self.args.use_fp8 = True
self.args.fp8_recipe = "NVFP4BlockScaling"
actual = train_and_evaluate(self.args)
assert actual[0] < 0.52 and actual[1] > 0.74
......
......@@ -25,7 +25,8 @@ from common import (
is_bf16_supported,
is_fp8_supported,
is_mxfp8_supported,
get_fp8_recipe_from_name_string,
is_nvfp4_supported,
get_quantization_recipe_from_name_string,
)
import transformer_engine.jax as te
import transformer_engine.jax.cpp_extensions as tex
......@@ -39,6 +40,7 @@ NAMED_BROADCAST_AXIS = "my_broadcast_axis"
NAMED_TP_AXIS = "my_tp_axis"
PARAMS_KEY = "params"
PARAMS_AXES_KEY = PARAMS_KEY + "_axes"
SR_KEY = "sr_rng"
DROPOUT_KEY = "dropout"
INPUT_KEY = "input_rng"
......@@ -175,6 +177,8 @@ def train_epoch(
epoch_accuracy = []
for perm in perms:
# Split and reassign to 'rngs' to ensure unique rng for each step
rngs = {key: jax.random.split(rngs[key])[1] for key in rngs}
batch_input = sentence[perm, ...]
batch_mask = mask[perm, ...]
batch_label = label[perm, ...]
......@@ -200,11 +204,11 @@ def train_epoch(
return state, avg_loss, avg_accuracy, var_collect
def eval_step(state, inputs, masks, labels, var_collect):
def eval_step(state, inputs, masks, labels, var_collect, rngs):
"""Computes loss and accuracy for a single batch."""
def loss_fn(var_collect, disable_dropout=False):
logits = state.apply_fn(var_collect, inputs, masks, disable_dropout)
logits = state.apply_fn(var_collect, inputs, masks, disable_dropout, rngs=rngs)
one_hot = jax.nn.one_hot(labels, 2)
loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
return loss, logits
......@@ -216,7 +220,16 @@ def eval_step(state, inputs, masks, labels, var_collect):
def eval_model(
state, test_ds, batch_size, var_collect, eval_fn, mesh, inputs_pspec, masks_pspec, labels_pspec
state,
test_ds,
batch_size,
var_collect,
eval_fn,
mesh,
inputs_pspec,
masks_pspec,
labels_pspec,
rngs,
):
"""Evaluation loop."""
global_input_shape, input_named_sharding, sentence = shard_array_wrapper(
......@@ -233,7 +246,8 @@ def eval_model(
all_accuracy = []
for batch_input, batch_mask, batch_label in zip(sentence, mask, label):
# Split and reassign to 'rngs' to ensure unique rng for each step
rngs = {key: jax.random.split(rngs[key])[1] for key in rngs}
shard_input = jax.make_array_from_single_device_arrays(
global_input_shape, input_named_sharding, [batch_input]
)
......@@ -244,7 +258,7 @@ def eval_model(
global_label_shape, label_named_sharding, [batch_label]
)
loss, accuracy = eval_fn(state, shard_input, shard_mask, shard_label, var_collect)
loss, accuracy = eval_fn(state, shard_input, shard_mask, shard_label, var_collect, rngs)
all_loss.append(loss)
all_accuracy.append(accuracy)
......@@ -303,7 +317,7 @@ def get_datasets(max_seq_len):
def check_fp8(state, var_collect, inputs, masks, labels):
"Check if model includes FP8."
rngs = {DROPOUT_KEY: jax.random.PRNGKey(0)}
rngs = {DROPOUT_KEY: jax.random.PRNGKey(0), SR_KEY: jax.random.PRNGKey(0)}
func_jaxpr = str(jax.make_jaxpr(train_step)(state, inputs, masks, labels, var_collect, rngs))
assert "f8_e5m2" in func_jaxpr or "f8_e4m3" in func_jaxpr
......@@ -372,16 +386,16 @@ def train_and_evaluate(args):
), "Test batch size needs to be multiple of 32 for MXFP8"
if args.use_fp8:
fp8_recipe = get_fp8_recipe_from_name_string(args.fp8_recipe)
fp8_recipe = get_quantization_recipe_from_name_string(args.fp8_recipe)
else:
fp8_recipe = None
device_mesh = mesh_utils.create_device_mesh((num_gpu_dp, num_gpu_tp))
with jax.sharding.Mesh(
devices=device_mesh, axis_names=(DEVICE_DP_AXIS, DEVICE_TP_AXIS)
) as mesh, te.fp8_autocast(
) as mesh, te.autocast(
enabled=args.use_fp8,
fp8_recipe=fp8_recipe,
recipe=fp8_recipe,
mesh_resource=te.MeshResource(
dp_resource=DEVICE_DP_AXIS,
tpsp_resource=DEVICE_TP_AXIS,
......@@ -390,7 +404,8 @@ def train_and_evaluate(args):
rng = jax.random.PRNGKey(args.seed)
rng, params_rng = jax.random.split(rng)
rng, dropout_rng = jax.random.split(rng)
init_rngs = {PARAMS_KEY: params_rng, DROPOUT_KEY: dropout_rng}
rng, sr_rng = jax.random.split(rng)
init_rngs = {PARAMS_KEY: params_rng, DROPOUT_KEY: dropout_rng, SR_KEY: sr_rng}
input_shape = [args.batch_size, args.max_seq_len]
mask_shape = [args.batch_size, 1, args.max_seq_len, args.max_seq_len]
......@@ -398,7 +413,7 @@ def train_and_evaluate(args):
# Create custom Flax logical axis rules for sharding.
customized_rules = ((NAMED_BROADCAST_AXIS, None), (NAMED_TP_AXIS, DEVICE_TP_AXIS))
# Extend the logical axis rules with TE's rules. This must be done inside fp8_autocast.
# Extend the logical axis rules with TE's rules. This must be done inside autocast.
sharding_rules = te_flax.extend_logical_axis_rules(customized_rules)
with flax.linen.logical_axis_rules(sharding_rules):
......@@ -444,7 +459,14 @@ def train_and_evaluate(args):
train_step, in_shardings=in_shardings, out_shardings=out_shardings
)
in_shardings = (state_sharding, inputs_sharding, masks_sharding, labels_sharding, None)
in_shardings = (
state_sharding,
inputs_sharding,
masks_sharding,
labels_sharding,
None,
None,
)
out_shardings = (None, None)
jit_eval_step = jax.jit(
eval_step, in_shardings=in_shardings, out_shardings=out_shardings
......@@ -456,14 +478,16 @@ def train_and_evaluate(args):
if args.dry_run:
labels = jnp.zeros(label_shape, dtype=jnp.bfloat16)
rngs = {DROPOUT_KEY: dropout_rng}
rngs = {DROPOUT_KEY: dropout_rng, SR_KEY: sr_rng}
jit_train_step(state, inputs, masks, labels, var_collect, rngs)
print("PASSED")
else:
for epoch in range(1, args.epochs + 1):
# Split and reassign to 'rng' to ensure unique rng for each step
rng, input_rng = jax.random.split(rng)
rng, dropout_rng = jax.random.split(rng)
rngs = {INPUT_KEY: input_rng, DROPOUT_KEY: dropout_rng}
rng, sr_rng = jax.random.split(rng)
rngs = {INPUT_KEY: input_rng, DROPOUT_KEY: dropout_rng, SR_KEY: sr_rng}
state, train_loss, train_accuracy, var_collect = train_epoch(
state,
......@@ -488,6 +512,7 @@ def train_and_evaluate(args):
inputs_pspec,
masks_pspec,
labels_sharding.spec,
rngs,
)
if args.process_id == 0:
print(
......@@ -508,16 +533,16 @@ def encoder_parser(args):
parser.add_argument(
"--batch-size",
type=int,
default=128,
default=256,
metavar="N",
help="input batch size for training (default: 128)",
help="input batch size for training (default: 256)",
)
parser.add_argument(
"--test-batch-size",
type=int,
default=128,
default=256,
metavar="N",
help="input batch size for testing (default: 128)",
help="input batch size for testing (default: 256)",
)
parser.add_argument(
"--max-seq-len",
......@@ -629,7 +654,7 @@ class TestEncoder(unittest.TestCase):
def test_te_current_scaling_fp8(self):
"""Test Transformer Engine with CurrentScaling FP8"""
result = self.exec(True, "Float8CurrentScaling")
assert result[0] < 0.43 and result[1] > 0.80
assert result[0] < 0.432 and result[1] > 0.80
@unittest.skipIf(
not is_mxfp8_supported(), "Device compute capability 10.0+ is required for MXFP8"
......@@ -639,6 +664,14 @@ class TestEncoder(unittest.TestCase):
result = self.exec(True, "MXFP8BlockScaling")
assert result[0] < 0.43 and result[1] > 0.80
@unittest.skipIf(
not is_nvfp4_supported(), "Device compute capability 10.0+ is required for NVFP4"
)
def test_te_nvfp4(self):
"""Test Transformer Engine with NVFP4"""
result = self.exec(True, "NVFP4BlockScaling")
assert result[0] < 0.451 and result[1] > 0.79
@unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
def test_te_bf16_shardy(self):
"""Test Transformer Engine with BF16"""
......@@ -659,19 +692,24 @@ class TestEncoder(unittest.TestCase):
def test_te_current_scaling_fp8_shardy(self):
"""Test Transformer Engine with CurrentScaling FP8"""
result = self.exec(True, "Float8CurrentScaling", enable_shardy=True)
assert result[0] < 0.43 and result[1] > 0.80
assert result[0] < 0.432 and result[1] > 0.80
@unittest.skipIf(
not is_mxfp8_supported(), "Device compute capability 10.0+ is required for MXFP8"
)
@unittest.skipIf(
tex.gemm_uses_jax_dot(), "`jax.nn.scaled_matmul()` does not support the Shardy partitioner."
)
def test_te_mxfp8_shardy(self):
"""Test Transformer Engine with MXFP8"""
result = self.exec(True, "MXFP8BlockScaling", enable_shardy=True)
assert result[0] < 0.43 and result[1] > 0.80
@unittest.skipIf(
not is_nvfp4_supported(), "Device compute capability 10.0+ is required for NVFP4"
)
def test_te_nvfp4_shardy(self):
"""Test Transformer Engine with NVFP4"""
result = self.exec(True, "NVFP4BlockScaling", enable_shardy=True)
assert result[0] < 0.451 and result[1] > 0.79
if __name__ == "__main__":
train_and_evaluate(encoder_parser(None))
......@@ -16,14 +16,15 @@ from datasets import load_dataset
from flax import linen as nn
from flax.training import train_state
from common import is_bf16_supported, get_fp8_recipe_from_name_string
from common import is_bf16_supported, get_quantization_recipe_from_name_string
import transformer_engine.jax as te
import transformer_engine.jax.flax as te_flax
from transformer_engine.jax.quantize import is_fp8_available, ScalingMode
from transformer_engine.jax.quantize import is_scaling_mode_supported, ScalingMode
PARAMS_KEY = "params"
DROPOUT_KEY = "dropout"
SR_KEY = "sr_rng"
INPUT_KEY = "input_rng"
......@@ -92,6 +93,8 @@ def train_epoch(state, train_ds, batch_size, rngs, var_collect):
epoch_accuracy = []
for perm in perms:
# Split and reassign to 'rngs' to ensure unique rng for each step
rngs = {key: jax.random.split(rngs[key])[1] for key in rngs}
batch_inputs = train_ds["sentence"][perm, ...]
batch_masks = train_ds["mask"][perm, ...]
batch_labels = train_ds["label"][perm, ...]
......@@ -107,11 +110,11 @@ def train_epoch(state, train_ds, batch_size, rngs, var_collect):
@jax.jit
def eval_step(state, inputs, masks, labels, var_collect):
def eval_step(state, inputs, masks, labels, var_collect, rngs):
"""Computes loss and accuracy for a single batch."""
def loss_fn(var_collect, disable_dropout=False):
logits = state.apply_fn(var_collect, inputs, masks, disable_dropout)
logits = state.apply_fn(var_collect, inputs, masks, disable_dropout, rngs=rngs)
one_hot = jax.nn.one_hot(labels.astype(jnp.int32), 2)
loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
return loss, logits
......@@ -122,7 +125,7 @@ def eval_step(state, inputs, masks, labels, var_collect):
return loss, accuracy
def eval_model(state, test_ds, batch_size, var_collect):
def eval_model(state, test_ds, batch_size, var_collect, rngs):
"""Evaluation loop."""
test_ds_size = len(test_ds["sentence"])
num_steps = test_ds_size // batch_size
......@@ -131,11 +134,15 @@ def eval_model(state, test_ds, batch_size, var_collect):
all_accuracy = []
for batch_start in range(0, valid_size, batch_size):
# Split and reassign to 'rngs' to ensure unique rng for each step
rngs = {key: jax.random.split(rngs[key])[1] for key in rngs}
batch_end = batch_start + batch_size
batch_inputs = test_ds["sentence"][batch_start:batch_end]
batch_masks = test_ds["mask"][batch_start:batch_end]
batch_labels = test_ds["label"][batch_start:batch_end]
loss, accuracy = eval_step(state, batch_inputs, batch_masks, batch_labels, var_collect)
loss, accuracy = eval_step(
state, batch_inputs, batch_masks, batch_labels, var_collect, rngs
)
all_loss.append(loss)
all_accuracy.append(accuracy)
......@@ -195,7 +202,7 @@ def get_datasets(max_seq_len):
def check_fp8(state, var_collect, inputs, masks, labels):
"Check if model includes FP8."
rngs = {DROPOUT_KEY: jax.random.PRNGKey(0)}
rngs = {DROPOUT_KEY: jax.random.PRNGKey(0), SR_KEY: jax.random.PRNGKey(0)}
func_jaxpr = str(jax.make_jaxpr(train_step)(state, inputs, masks, labels, var_collect, rngs))
assert "f8_e5m2" in func_jaxpr or "f8_e4m3" in func_jaxpr
......@@ -208,19 +215,20 @@ def train_and_evaluate(args):
rng = jax.random.PRNGKey(args.seed)
rng, params_rng = jax.random.split(rng)
rng, dropout_rng = jax.random.split(rng)
init_rngs = {PARAMS_KEY: params_rng, DROPOUT_KEY: dropout_rng}
rng, sr_rng = jax.random.split(rng)
init_rngs = {PARAMS_KEY: params_rng, DROPOUT_KEY: dropout_rng, SR_KEY: sr_rng}
input_shape = [args.batch_size, args.max_seq_len]
mask_shape = [args.batch_size, 1, args.max_seq_len, args.max_seq_len]
label_shape = [args.batch_size]
if args.use_fp8:
fp8_recipe = get_fp8_recipe_from_name_string(args.fp8_recipe)
fp8_recipe = get_quantization_recipe_from_name_string(args.fp8_recipe)
else:
fp8_recipe = None
with te.fp8_autocast(
enabled=args.use_fp8, fp8_recipe=fp8_recipe, mesh_resource=te.sharding.MeshResource()
with te.autocast(
enabled=args.use_fp8, recipe=fp8_recipe, mesh_resource=te.sharding.MeshResource()
):
encoder = Net(num_embed)
# We use nn.Embed, thus inputs need to be in int
......@@ -238,21 +246,25 @@ def train_and_evaluate(args):
if args.dry_run:
labels = jnp.zeros(label_shape, dtype=jnp.bfloat16)
rngs = {DROPOUT_KEY: dropout_rng}
rngs = {DROPOUT_KEY: dropout_rng, SR_KEY: sr_rng}
train_step(state, inputs, masks, labels, var_collect, rngs)
print("PASSED")
return None
for epoch in range(1, args.epochs + 1):
# Split and reassign to 'rng' to ensure unique rng for each step
rng, input_rng = jax.random.split(rng)
rng, dropout_rng = jax.random.split(rng)
rngs = {INPUT_KEY: input_rng, DROPOUT_KEY: dropout_rng}
rng, sr_rng = jax.random.split(rng)
rngs = {INPUT_KEY: input_rng, DROPOUT_KEY: dropout_rng, SR_KEY: sr_rng}
state, train_loss, train_accuracy, var_collect = train_epoch(
state, train_ds, args.batch_size, rngs, var_collect
)
test_loss, test_accuracy = eval_model(state, test_ds, args.test_batch_size, var_collect)
test_loss, test_accuracy = eval_model(
state, test_ds, args.test_batch_size, var_collect, rngs
)
print(
f"Epoch: {epoch:>2} "
......@@ -329,8 +341,9 @@ def encoder_parser(args):
class TestEncoder(unittest.TestCase):
"""Encoder unittests"""
is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.DELAYED_TENSOR_SCALING)
is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING)
is_fp8_supported, fp8_reason = is_scaling_mode_supported(ScalingMode.DELAYED_TENSOR_SCALING)
is_mxfp8_supported, mxfp8_reason = is_scaling_mode_supported(ScalingMode.MXFP8_1D_SCALING)
is_nvfp4_supported, nvfp4_reason = is_scaling_mode_supported(ScalingMode.NVFP4_1D_SCALING)
def setUp(self):
"""Run 3 epochs for testing"""
......@@ -340,7 +353,7 @@ class TestEncoder(unittest.TestCase):
def test_te_bf16(self):
"""Test Transformer Engine with BF16"""
actual = train_and_evaluate(self.args)
assert actual[0] < 0.45 and actual[1] > 0.79
assert actual[0] < 0.452 and actual[1] > 0.788
@unittest.skipIf(not is_fp8_supported, fp8_reason)
def test_te_delayed_scaling_fp8(self):
......@@ -348,7 +361,7 @@ class TestEncoder(unittest.TestCase):
self.args.use_fp8 = True
self.args.fp8_recipe = "DelayedScaling"
actual = train_and_evaluate(self.args)
assert actual[0] < 0.455 and actual[1] > 0.79
assert actual[0] < 0.457 and actual[1] > 0.784
@unittest.skipIf(not is_fp8_supported, fp8_reason)
def test_te_current_scaling_fp8(self):
......@@ -356,7 +369,7 @@ class TestEncoder(unittest.TestCase):
self.args.use_fp8 = True
self.args.fp8_recipe = "Float8CurrentScaling"
actual = train_and_evaluate(self.args)
assert actual[0] < 0.455 and actual[1] > 0.79
assert actual[0] < 0.461 and actual[1] > 0.784
@unittest.skipIf(not is_mxfp8_supported, mxfp8_reason)
def test_te_mxfp8(self):
......@@ -364,7 +377,15 @@ class TestEncoder(unittest.TestCase):
self.args.use_fp8 = True
self.args.fp8_recipe = "MXFP8BlockScaling"
actual = train_and_evaluate(self.args)
assert actual[0] < 0.455 and actual[1] > 0.79
assert actual[0] < 0.457 and actual[1] > 0.784
@unittest.skipIf(not is_nvfp4_supported, nvfp4_reason)
def test_te_nvfp4(self):
"""Test Transformer Engine with NVFP4"""
self.args.use_fp8 = True
self.args.fp8_recipe = "NVFP4BlockScaling"
actual = train_and_evaluate(self.args)
assert actual[0] < 0.476 and actual[1] > 0.775
if __name__ == "__main__":
......
......@@ -6,13 +6,13 @@ This example uses MNIST training to demonstrate the Transformer Engine usage. Th
2. Define model: The `Net` class is a small CNN model for image classification. It has an option to switch between using `nn.Dense` provided by Flax and `te.DenseGeneral` provided by the Transformer Engine. This allows for easy comparison between the two libraries.
3. Build training loop: The `train_and_evaluate` is the main routine to initialize the model and start training and evaluating. For FP8 training, the key is `te.fp8_autocast` context manager. If fp8_autocast is enabled, it will cast all `te.DenseGeneral` to FP8 precision. The `var_collect` is a collection including needed information for model training, such as parameters and FP8 metadata, which is necessary for correct casting of BF16 tensors into FP8 tensors at runtime. If fp8_autocast is turned on and print var_collect, you will see FP8 metadata inside, such as `fp8_meta_collection` section. The training and evaluating with FP8 have to be done under fp8_autocast. If not, then fp8_autocast will deconstruct the FP8 metadata, and the model will fall back to higher floating point precision, such as BF16 in this example. To check if FP8 is enabled, use the `check_fp8` routine. If model initialization with FP8 works fine, the string returned by jax.make_jaxpr should include the `Float8` keyword.
3. Build training loop: `train_and_evaluate` is the main routine that initializes the model and runs training and evaluation. For FP8 training, the key is the `te.autocast` context manager. When autocast is enabled, it casts all `te.DenseGeneral` layers to FP8 precision. The `var_collect` collection holds the information needed for training, such as model parameters and FP8 metadata; the metadata is necessary to correctly cast BF16 tensors into FP8 tensors at runtime. If autocast is enabled and you print `var_collect`, you will see the FP8 metadata inside, such as the `fp8_meta_collection` section. Training and evaluation with FP8 must be done under autocast; otherwise the FP8 metadata is deconstructed and the model falls back to a higher floating-point precision, such as BF16 in this example. To check whether FP8 is enabled, use the `check_fp8` routine: if model initialization with FP8 works, the string returned by `jax.make_jaxpr` should include an FP8 dtype keyword such as `f8_e4m3` or `f8_e5m2`.
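The `check_fp8` routine boils down to tracing the step function and searching the jaxpr text for FP8 dtypes. A minimal, TE-free sketch of that idea follows; the `jaxpr_mentions_fp8` helper name is ours, not part of the example:

```python
import jax
import jax.numpy as jnp

def jaxpr_mentions_fp8(fn, *args):
    """Return True if the traced program text contains an FP8 dtype."""
    jaxpr_str = str(jax.make_jaxpr(fn)(*args))
    return ("f8_e4m3" in jaxpr_str) or ("f8_e5m2" in jaxpr_str)

def dense(x):
    # A plain matmul traced in default precision -- no FP8 involved.
    return jnp.dot(x, x)

# A non-TE function should show no FP8 dtypes in its jaxpr.
print(jaxpr_mentions_fp8(dense, jnp.ones((4, 4), dtype=jnp.bfloat16)))
```

Under autocast with FP8 enabled, the same check against the real `train_step` is what the example's assertion relies on.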
4. Training process: In `apply_model`, the main difference from normal Flax usage is that, with FP8 training, the FP8 metadata has to be passed into the gradient function `grad_fn`; otherwise Transformer Engine cannot correctly cast BF16 tensors into FP8 tensors at runtime. The FP8 metadata doesn't belong in the model parameters (`state.params`), so we manually combine the metadata and the latest model parameters into `var_collect` as a frozen dictionary and pass it to the gradient function.
5. Evaluating process: Evaluation works the same way as training: ensure the FP8 metadata is inside `var_collect` and pass it into the loss function.
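The combining step in point 4 is essentially a dictionary merge: keep the FP8 metadata collections and swap in the freshest parameters. A plain-Python sketch of the pattern (real code works with Flax frozen dictionaries; the `refresh_params` helper name is hypothetical):

```python
PARAMS_KEY = "params"

def refresh_params(var_collect, latest_params):
    """Rebuild var_collect with the latest model parameters,
    leaving the FP8 metadata collections untouched."""
    merged = dict(var_collect)          # shallow copy of all collections
    merged[PARAMS_KEY] = latest_params  # overwrite only the params entry
    return merged

var_collect = {
    "fp8_meta_collection": {"amax_history": [0.5]},
    PARAMS_KEY: {"kernel": "old"},
}
updated = refresh_params(var_collect, {"kernel": "new"})
print(updated[PARAMS_KEY])  # {'kernel': 'new'}
```

The merged collection is then what gets fed to `grad_fn` (training) or the loss function (evaluation).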
6. Additional options: The `te.fp8_autocast` context manager has additional options
6. Additional options: The `te.autocast` context manager accepts additional options:
* FP8 recipe: controls FP8 training behavior. See the [FP8 tutorial](https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/examples/fp8_primer.html) for a detailed explanation of FP8 recipes and the supported options.
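In these examples the recipe is looked up by name via the shared `get_quantization_recipe_from_name_string` helper; constructing one directly and handing it to autocast might look like the sketch below (field values are illustrative, not a recommendation):

```python
# Sketch: building a recipe for te.autocast.
from transformer_engine.common.recipe import DelayedScaling, Format
import transformer_engine.jax as te

recipe = DelayedScaling(
    fp8_format=Format.HYBRID,   # E4M3 forward, E5M2 for gradients
    amax_history_len=1024,
    amax_compute_algo="max",
)

# with te.autocast(enabled=True, recipe=recipe,
#                  mesh_resource=te.sharding.MeshResource()):
#     ... build and train the model ...
```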
## Run ##
......