Commit 44740c6c authored by yuguo

Merge commit '7a9a0825' of https://github.com/NVIDIA/TransformerEngine
parents 8113d9e0 7a9a0825
...@@ -43,7 +43,7 @@ jobs:
      run: |
        apt-get update
        apt-get install -y git python3.9 pip cudnn9-cuda-12
-       pip install cmake torch ninja pydantic importlib-metadata>=1.0 packaging pybind11 numpy einops
+       pip install cmake torch ninja pydantic importlib-metadata>=1.0 packaging pybind11 numpy einops onnxscript
    - name: 'Checkout'
      uses: actions/checkout@v3
      with:
...@@ -83,7 +83,7 @@ jobs:
      options: --user root
    steps:
    - name: 'Dependencies'
-     run: pip install torch pybind11[global] einops
+     run: pip install torch pybind11[global] einops onnxscript
    - name: 'Checkout'
      uses: actions/checkout@v3
      with:
......
...@@ -39,3 +39,4 @@ downloads/
 compile_commands.json
 .nfs
 tensor_dumps/
+artifacts/
Subproject commit 20c28ea798fe99e31d7274e009ee2fbf0e88abfd
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import argparse
import torch
import torch.utils.benchmark as benchmark
import pandas as pd
import pathlib
from transformer_engine.pytorch.module import GroupedLinear
from transformer_engine.common.recipe import Float8BlockScaling, MXFP8BlockScaling
from transformer_engine.pytorch.fp8 import fp8_autocast, FP8GlobalStateManager
from contextlib import nullcontext
"""
# Profile BF16 recipe with Nsight Systems
nsys profile \
--output=./benchmarks/linear/b200_mkn_4096_4096_4096_numgemm_8_bf16 \
--force-overwrite true \
--trace=cuda,nvtx,cudnn,cublas \
python benchmarks/linear/benchmark_grouped_linear.py --profile --recipe bf16
# Profile FP8 sub-channel recipe with Nsight Systems
nsys profile \
--output=./benchmarks/linear/h100hbm_mkn_4096_4096_4096_numgemm_8_fp8_sub_channel \
--force-overwrite true \
--trace=cuda,nvtx,cudnn,cublas \
python benchmarks/linear/benchmark_grouped_linear.py --profile --recipe fp8_sub_channel
# Profile MXFP8 recipe with Nsight Systems
nsys profile \
--output=./benchmarks/linear/b200_mkn_4096_4096_4096_numgemm_8_mxfp8 \
--force-overwrite true \
--trace=cuda,nvtx,cudnn,cublas \
python benchmarks/linear/benchmark_grouped_linear.py --profile --recipe mxfp8
"""

RECIPES = {
"bf16": None,
"fp8_sub_channel": Float8BlockScaling(),
"mxfp8": MXFP8BlockScaling(),
}

mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
fp8_block_scaling_available, reason_for_no_fp8_block_scaling = (
FP8GlobalStateManager.is_fp8_block_scaling_available()
)


def run_linear_multiple_steps(layer, x, m_splits, mode, gradient, run_num_steps=1, recipe=None):
assert mode in ["fwd_only", "fwd_bwd"]
fp8_context = (
fp8_autocast(enabled=True, fp8_recipe=recipe) if recipe is not None else nullcontext()
)
# print(f"fp8_context: {fp8_context} and is it nullcontext? {isinstance(fp8_context, nullcontext)}")
if mode == "fwd_only":
with torch.no_grad(), fp8_context:
for i in range(run_num_steps):
y_q = layer.forward(
x,
m_splits,
is_first_microbatch=(i == 0),
)
return y_q
else:
# reset gradients
layer.zero_grad()
x.grad = None
with fp8_context:
for i in range(run_num_steps):
label = f"step_{i}"
torch.cuda.nvtx.range_push(label)
y_q = layer.forward(
x,
m_splits,
is_first_microbatch=(i == 0),
)
y_q.backward(gradient)
torch.cuda.nvtx.range_pop()
grads_q = []
grads_q.append(x.grad)
        # the remaining gradients are with respect to the model parameters
for p in layer.parameters():
if p.requires_grad:
grads_q.append(p.grad)
return y_q, grads_q


def benchmark_linear(
x,
ws,
m_splits,
bias,
recipe_name,
mode,
num_gemms=4,
):
params_dtype = torch.bfloat16
recipe = RECIPES[recipe_name]
in_features = x.shape[1]
out_features = ws[0].shape[0]
gradient = torch.ones((x.shape[0], out_features), dtype=torch.bfloat16, device=x.device)
layer = GroupedLinear(
num_gemms,
in_features,
out_features,
bias=bias is not None,
params_dtype=params_dtype,
)
layer = layer.to("cuda")
with torch.no_grad():
for i in range(num_gemms):
weight_i = getattr(layer, f"weight{i}")
weight_i.copy_(ws[i])
if bias is not None:
bias_i = getattr(layer, f"bias{i}")
bias_i.copy_(bias)
num_microbatches = 32
label = f"{recipe_name}_{'grouped'}"
torch.cuda.nvtx.range_push(label)
timing = benchmark.Timer(
stmt=(
"run_linear_multiple_steps(layer, x, m_splits, mode, gradient, num_microbatches,"
" recipe)"
),
globals={
"run_linear_multiple_steps": run_linear_multiple_steps,
"layer": layer,
"x": x,
"m_splits": m_splits,
"mode": mode,
"gradient": gradient,
"num_microbatches": num_microbatches,
"recipe": recipe,
},
num_threads=1,
).blocked_autorange(min_run_time=5)
print(f"{recipe_name}: {timing} \n")
timing_ms = timing.median * 1000 / num_microbatches
return timing_ms


def run_benchmark_linear(mkns, recipe_name, use_bias, num_gemms=4):
data = []
assert not use_bias, "Bias is not supported for GroupedLinear benchmark"
print(f"========== Benchmarking {recipe_name} ==========")
for m, k, n in mkns:
device = "cuda"
x = torch.randn((m, k), dtype=torch.bfloat16, device=device, requires_grad=True)
ws = [torch.randn((n, k), dtype=torch.bfloat16, device=device) for _ in range(num_gemms)]
assert m % num_gemms == 0
m_splits = [m // num_gemms] * num_gemms
# Bias is not supported for GroupedLinear benchmark
bias = None
# Run the benchmark
print(f"fwd_m={m}, fwd_k={k}, fwd_n={n}")
grouped_fwd_bwd_timing_ms = benchmark_linear(
x,
ws,
m_splits,
bias,
recipe_name,
mode="fwd_bwd",
num_gemms=num_gemms,
)
# Append the results
data.append(
[
m,
k,
n,
recipe_name,
num_gemms,
grouped_fwd_bwd_timing_ms,
]
)
df = pd.DataFrame(
data=data,
columns=[
"m",
"k",
"n",
"recipe",
"num_gemms",
"grouped_fwd_bwd_time_ms",
],
)
print(df, "\n")
return df


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--profile", action="store_true", help="Enable profiling mode")
parser.add_argument(
"--output_dir",
type=str,
default="benchmark_output/",
help="output path for report",
)
# arguments for recipe, options are fp8_sub_channel, mxfp8, bf16, all
parser.add_argument(
"--recipe",
type=str,
default="bf16",
help="Recipe to use, options are fp8_sub_channel, mxfp8, bf16, or all",
)
args = parser.parse_args()
use_bias = False
# Set the MKN values to benchmark
mkns = []
for m in [8192]:
# for m in [4096, 8192, 16384]:
# for n in [1024, 2048, 4096, 8192, 16384]:
for n in [8192]:
for k in [4096]:
mkns.append((m, k, n))
# default recipes to run if not specified
recipe_list = ["bf16"]
if args.recipe == "all":
recipe_list = ["bf16", "fp8_sub_channel", "mxfp8"]
else:
recipe_list = [args.recipe]
num_gemms_list = [8]
if args.profile:
mkns = [(4096, 4096, 4096)]
# in profile mode, only run one recipe specified in args.recipe
assert args.recipe != "all", (
"In profile mode, only one recipe can be specified, please specify the recipe as"
" fp8_sub_channel, mxfp8, or bf16"
)
recipe_list = [args.recipe]
num_gemms_list = [8]
torch.autograd.profiler.emit_nvtx(record_shapes=True).__enter__()
# Initialize a dataframe to store the results
df_linears = pd.DataFrame()
# Run the fp8 benchmarks
for num_gemms in num_gemms_list:
print(f"========== Benchmarking with num_gemms={num_gemms} ==========")
for recipe_name in recipe_list:
assert recipe_name in [
"bf16",
"fp8_sub_channel",
"mxfp8",
], "Recipe must be one of bf16, fp8_sub_channel, or mxfp8"
if recipe_name == "mxfp8" and not mxfp8_available:
print(f"MXFP8 is not available, skipping {recipe_name}")
continue
if recipe_name == "fp8_sub_channel" and not fp8_block_scaling_available:
print(f"FP8 block scaling is not available, skipping {recipe_name}")
continue
df = run_benchmark_linear(
mkns,
recipe_name,
use_bias,
num_gemms=num_gemms,
)
df_linears = pd.concat([df_linears, df])
print(df_linears)
if args.profile:
torch.autograd.profiler.emit_nvtx().__exit__(None, None, None)
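
For reference, a minimal standalone sketch of the workload this script measures: one forward/backward step of GroupedLinear under an optional FP8 recipe. It assumes a CUDA device and uses illustrative shapes; recipe=None falls back to plain BF16 via nullcontext, as in run_linear_multiple_steps above.

from contextlib import nullcontext

import torch
from transformer_engine.pytorch.module import GroupedLinear
from transformer_engine.pytorch.fp8 import fp8_autocast

num_gemms, m, k, n = 4, 1024, 512, 512
layer = GroupedLinear(num_gemms, k, n, bias=False, params_dtype=torch.bfloat16).to("cuda")
x = torch.randn((m, k), dtype=torch.bfloat16, device="cuda", requires_grad=True)
m_splits = [m // num_gemms] * num_gemms  # equal row split across the grouped GEMMs

recipe = None  # or Float8BlockScaling() / MXFP8BlockScaling() where the GPU supports them
ctx = fp8_autocast(enabled=True, fp8_recipe=recipe) if recipe is not None else nullcontext()
with ctx:
    y = layer.forward(x, m_splits, is_first_microbatch=True)
    y.backward(torch.ones_like(y))
print(y.shape, x.grad.shape)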
...@@ -5,6 +5,7 @@
 """JAX related extensions."""
 import os
 from pathlib import Path
+from packaging import version

 import setuptools
...@@ -27,7 +28,13 @@ def xla_path() -> str:
     Throws FileNotFoundError if XLA source is not found."""
     try:
-        from jax.extend import ffi
+        import jax
+
+        if version.parse(jax.__version__) >= version.parse("0.5.0"):
+            from jax import ffi  # pylint: disable=ungrouped-imports
+        else:
+            from jax.extend import ffi  # pylint: disable=ungrouped-imports
     except ImportError:
         if os.getenv("XLA_HOME"):
             xla_home = Path(os.getenv("XLA_HOME"))
......
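
The version gate introduced above can be exercised on its own; a small sketch, assuming only that jax and packaging are installed (jax.ffi replaced jax.extend.ffi as the public home of the FFI helpers in JAX 0.5.0):

from packaging import version

import jax

if version.parse(jax.__version__) >= version.parse("0.5.0"):
    from jax import ffi
else:
    from jax.extend import ffi

print(ffi.include_dir())  # directory containing the XLA FFI headers bundled with jaxlib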
...@@ -13,12 +13,19 @@ from typing import List
 def install_requirements() -> List[str]:
-    """Install dependencies for TE/JAX extensions."""
+    """Install dependencies for TE/PyTorch extensions."""
     reqs = ["torch>=2.1", "einops"]
     # reqs.append(
     #     "nvdlfw-inspect @"
     #     " git+https://github.com/NVIDIA/nvidia-dlfw-inspect.git@v0.1#egg=nvdlfw-inspect"
     # )
+    reqs.extend(
+        [
+            "torch>=2.1",
+            # "onnx",
+            # "onnxscript@git+https://github.com/microsoft/onnxscript.git@51ecf47523ef079c53b0e620c62d56d70cfd3871",
+        ]
+    )
     return reqs
......
-datasets
+datasets<4.0.0
 flax>=0.7.1
 nltk>=3.8.2
 optax
...@@ -6,13 +6,13 @@ NUM_GPUS=${NUM_GPUS:-$(nvidia-smi -L | wc -l)}
 # Define the test cases to run
 TEST_CASES=(
-    # "test_te_bf16"
+    "test_te_bf16"
     "test_te_delayed_scaling_fp8"
-    # "test_te_current_scaling_fp8"
-    # "test_te_mxfp8"
-    # "test_te_bf16_shardy"
+    "test_te_current_scaling_fp8"
+    "test_te_mxfp8"
+    "test_te_bf16_shardy"
     "test_te_delayed_scaling_fp8_shardy"
-    # "test_te_current_scaling_fp8_shardy"
+    "test_te_current_scaling_fp8_shardy"
 )

 echo
...@@ -30,7 +30,7 @@ for TEST_CASE in "${TEST_CASES[@]}"; do
     LOG_FILE="${TEST_CASE}_gpu_${i}.log"
     # Run pytest and redirect stdout and stderr to the log file
-    pytest -c "$TE_PATH/tests/jax/pytest.ini" \
+    pytest -s -c "$TE_PATH/tests/jax/pytest.ini" \
       -vs "$TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py::TestEncoder::$TEST_CASE" \
       --num-process=$NUM_GPUS \
       --process-id=$i > "$LOG_FILE" 2>&1 &
...@@ -40,21 +40,20 @@ for TEST_CASE in "${TEST_CASES[@]}"; do
   wait
   tail -n +7 "${TEST_CASE}_gpu_0.log"
-  tail -n +7 "${TEST_CASE}_gpu_0.log"
   # Check and print the log content accordingly
-  if grep -q "FAILED" "${TEST_CASE}_gpu_0.log"; then
-    HAS_FAILURE=1
-    echo "... $TEST_CASE FAILED"
-  elif grep -q "SKIPPED" "${TEST_CASE}_gpu_0.log"; then
+  if grep -q "SKIPPED" "${TEST_CASE}_gpu_0.log"; then
     echo "... $TEST_CASE SKIPPED"
   elif grep -q "PASSED" "${TEST_CASE}_gpu_0.log"; then
     echo "... $TEST_CASE PASSED"
   else
-    echo "Invalid ${TEST_CASE}_gpu_0.log"
+    HAS_FAILURE=1
+    echo "... $TEST_CASE FAILED"
   fi
   # Remove the log file after processing it
+  wait
   rm ${TEST_CASE}_gpu_*.log
 done

-wait
 exit $HAS_FAILURE
...@@ -25,6 +25,7 @@ from common import (
     assert_params_sufficiently_sharded,
 )
 import transformer_engine.jax as te
+import transformer_engine.jax.cpp_extensions as tex
 import transformer_engine.jax.flax as te_flax
 from transformer_engine.jax.quantize import is_fp8_available, ScalingMode
...@@ -263,8 +264,10 @@ def train_and_evaluate(args):
     device_mesh = mesh_utils.create_device_mesh((num_gpu_dp, num_gpu_tp))
     with jax.sharding.Mesh(
         devices=device_mesh, axis_names=(DEVICE_DP_AXIS, DEVICE_TP_AXIS)
-    ) as mesh, nn_partitioning.axis_rules(
-        ((NAMED_BROADCAST_AXIS, None), (NAMED_TP_AXIS, DEVICE_TP_AXIS))
+    ) as mesh, te.fp8_autocast(
+        enabled=args.use_fp8,
+        fp8_recipe=fp8_recipe,
+        mesh_resource=te.MeshResource(DEVICE_DP_AXIS, DEVICE_TP_AXIS, None, None),
     ):
         rng = jax.random.PRNGKey(args.seed)
         rng, params_rng = jax.random.split(rng)
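
The hunk above folds te.fp8_autocast into the same with-statement as the device mesh. A condensed sketch of that pattern, under the assumptions of a single "dp" mesh axis over all local devices and FP8 disabled (so no FP8-capable GPU is needed):

import flax
import jax
from jax.experimental import mesh_utils
import transformer_engine.jax as te
import transformer_engine.jax.flax as te_flax

device_mesh = mesh_utils.create_device_mesh((jax.device_count(),))
with jax.sharding.Mesh(devices=device_mesh, axis_names=("dp",)) as mesh, te.fp8_autocast(
    enabled=False,
    fp8_recipe=None,
    mesh_resource=te.MeshResource("dp", None, None, None),
):
    # TE's logical axis rules must be built inside fp8_autocast so the
    # MeshResource above is visible when the rules are constructed.
    rules = te_flax.extend_logical_axis_rules(tuple())
    with flax.linen.logical_axis_rules(rules):
        pass  # model init / train steps go here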
...@@ -275,22 +278,21 @@ def train_and_evaluate(args):
         mask_shape = [args.batch_size, 1, args.max_seq_len, args.max_seq_len]
         label_shape = [args.batch_size]

-        with te.fp8_autocast(
-            enabled=args.use_fp8,
-            fp8_recipe=fp8_recipe,
-            mesh_resource=te.MeshResource(DEVICE_DP_AXIS, DEVICE_TP_AXIS, None, None),
-        ):
+        # Get the base axis rules and extend them with TE's rules. This must be done inside fp8_autocast
+        axis_rules = flax.linen.get_logical_axis_rules()
+        axis_rules += ((NAMED_BROADCAST_AXIS, None), (NAMED_TP_AXIS, DEVICE_TP_AXIS))
+        te_extended_axis_rules = te_flax.extend_logical_axis_rules(axis_rules)
+        with flax.linen.logical_axis_rules(te_extended_axis_rules):
+            print(f"Device mesh: {mesh}")
+            print(f"Axis rules: {te_extended_axis_rules}")
             encoder = Net(num_embed, args.enable_sp)
             inputs = jnp.zeros(input_shape, dtype=jnp.int32)
             masks = jnp.zeros(mask_shape, dtype=jnp.uint8)
             abs_var_collect = jax.eval_shape(encoder.init, init_rngs, inputs, masks)
-            # Get the base axis rules and extend them with TE's rules.
-            axis_rules = nn_partitioning.get_axis_rules()
-            te_extended_axis_rules = te_flax.extend_logical_axis_rules(axis_rules)
-            print(f"Device mesh: {mesh}")
-            print(f"Axis rules: {te_extended_axis_rules}")
             logical_partition_spec = nn.get_partition_spec(abs_var_collect)

             # Note that `nn.logical_to_mesh_sharding` returns a dict with an extra
...@@ -307,7 +309,9 @@ def train_and_evaluate(args):
                 key: params_sharding[PARAMS_KEY] if key is PARAMS_KEY else None
                 for key in abs_var_collect
             }
-            jit_encoder_init = jax.jit(encoder.init, in_shardings, out_shardings)
+            jit_encoder_init = jax.jit(
+                encoder.init, in_shardings=in_shardings, out_shardings=out_shardings
+            )
             var_collect = jit_encoder_init(init_rngs, inputs, masks)

             # Check if params are sufficiently sharded after initialization
...@@ -344,11 +348,15 @@ def train_and_evaluate(args):
                 None,
             )
             out_shardings = (state_sharding, None, None, None)
-            jit_train_step = jax.jit(train_step, in_shardings, out_shardings)
+            jit_train_step = jax.jit(
+                train_step, in_shardings=in_shardings, out_shardings=out_shardings
+            )

             in_shardings = (state_sharding, inputs_sharding, masks_sharding, labels_sharding, None)
             out_shardings = (None, None)
-            jit_eval_step = jax.jit(eval_step, in_shardings, out_shardings)
+            jit_eval_step = jax.jit(
+                eval_step, in_shardings=in_shardings, out_shardings=out_shardings
+            )

             if args.use_fp8:
                 labels = jnp.zeros(label_shape, dtype=jnp.bfloat16)
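
The jax.jit changes in this hunk pass the shardings by keyword, which newer JAX releases expect. A minimal sketch of the keyword form, assuming a one-axis "dp" mesh and a toy function standing in for train_step/eval_step:

import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec

mesh = Mesh(mesh_utils.create_device_mesh((jax.device_count(),)), axis_names=("dp",))
sharding = NamedSharding(mesh, PartitionSpec("dp"))

# shardings are passed by keyword; a single sharding applies to the single argument
jit_step = jax.jit(lambda x: x * 2, in_shardings=sharding, out_shardings=sharding)
out = jit_step(jnp.arange(jax.device_count() * 4, dtype=jnp.float32))
print(out.sharding)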
...@@ -459,14 +467,14 @@ class TestEncoder(unittest.TestCase):
     is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING)

     def setUp(self):
-        """Run 3 epochs for testing"""
-        self.args = encoder_parser(["--epochs", "3"])
+        """Run 5 epochs for testing"""
+        self.args = encoder_parser(["--epochs", "5"])

     @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
     def test_te_bf16(self):
         """Test Transformer Engine with BF16"""
         actual = train_and_evaluate(self.args)
-        assert actual[0] < 0.455 and actual[1] > 0.785
+        assert actual[0] < 0.39 and actual[1] > 0.83

     @unittest.skipIf(not is_fp8_supported, fp8_reason)
     def test_te_delayed_scaling_fp8(self):
...@@ -474,7 +482,7 @@ class TestEncoder(unittest.TestCase):
         self.args.use_fp8 = True
         self.args.fp8_recipe = "DelayedScaling"
         actual = train_and_evaluate(self.args)
-        assert actual[0] < 0.455 and actual[1] > 0.785
+        assert actual[0] < 0.39 and actual[1] > 0.83

     @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason)
     def test_te_mxfp8(self):
...@@ -482,14 +490,14 @@ class TestEncoder(unittest.TestCase):
         self.args.use_fp8 = True
         self.args.fp8_recipe = "MXFP8BlockScaling"
         actual = train_and_evaluate(self.args)
-        assert actual[0] < 0.455 and actual[1] > 0.785
+        assert actual[0] < 0.39 and actual[1] > 0.83

     @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
     def test_te_bf16_with_sp(self):
         """Test Transformer Engine with BF16 + SP"""
         self.args.enable_sp = True
         actual = train_and_evaluate(self.args)
-        assert actual[0] < 0.455 and actual[1] > 0.785
+        assert actual[0] < 0.39 and actual[1] > 0.83

     @unittest.skipIf(not is_fp8_supported, fp8_reason)
     def test_te_delayed_scaling_fp8_with_sp(self):
...@@ -498,7 +506,7 @@ class TestEncoder(unittest.TestCase):
         self.args.use_fp8 = True
         self.args.fp8_recipe = "DelayedScaling"
         actual = train_and_evaluate(self.args)
-        assert actual[0] < 0.455 and actual[1] > 0.785
+        assert actual[0] < 0.39 and actual[1] > 0.83

     @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason)
     def test_te_mxfp8_with_sp(self):
...@@ -507,14 +515,14 @@ class TestEncoder(unittest.TestCase):
         self.args.use_fp8 = True
         self.args.fp8_recipe = "MXFP8BlockScaling"
         actual = train_and_evaluate(self.args)
-        assert actual[0] < 0.455 and actual[1] > 0.785
+        assert actual[0] < 0.39 and actual[1] > 0.83

     @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
     def test_te_bf16_shardy(self):
         """Test Transformer Engine with BF16"""
         self.args.enable_shardy = True
         actual = train_and_evaluate(self.args)
-        assert actual[0] < 0.455 and actual[1] > 0.785
+        assert actual[0] < 0.39 and actual[1] > 0.83

     @unittest.skipIf(not is_fp8_supported, fp8_reason)
     def test_te_delayed_scaling_fp8_shardy(self):
...@@ -523,7 +531,7 @@ class TestEncoder(unittest.TestCase):
         self.args.use_fp8 = True
         self.args.fp8_recipe = "DelayedScaling"
         actual = train_and_evaluate(self.args)
-        assert actual[0] < 0.455 and actual[1] > 0.785
+        assert actual[0] < 0.39 and actual[1] > 0.83

     @unittest.skipIf(not is_fp8_supported, fp8_reason)
     def test_te_delayed_scaling_fp8_with_sp_shardy(self):
...@@ -533,9 +541,32 @@ class TestEncoder(unittest.TestCase):
         self.args.use_fp8 = True
         self.args.fp8_recipe = "DelayedScaling"
         actual = train_and_evaluate(self.args)
-        assert actual[0] < 0.455 and actual[1] > 0.785
+        assert actual[0] < 0.39 and actual[1] > 0.83

-    # TODO(jreiffers): Add mxfp8 Shardy tests once supported in JAX.
+    @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason)
+    @unittest.skipIf(
+        tex.gemm_uses_jax_dot(), "`jax.nn.scaled_matmul()` does not support the Shardy partitioner."
+    )
+    def test_te_mxfp8_shardy(self):
+        """Test Transformer Engine with MXFP8"""
+        self.args.enable_shardy = True
+        self.args.use_fp8 = True
+        self.args.fp8_recipe = "MXFP8BlockScaling"
+        actual = train_and_evaluate(self.args)
+        assert actual[0] < 0.39 and actual[1] > 0.83
+
+    @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason)
+    @unittest.skipIf(
+        tex.gemm_uses_jax_dot(), "`jax.nn.scaled_matmul()` does not support the Shardy partitioner."
+    )
+    def test_te_mxfp8_with_sp_shardy(self):
+        """Test Transformer Engine with MXFP8 + SP"""
+        self.args.enable_shardy = True
+        self.args.enable_sp = True
+        self.args.use_fp8 = True
+        self.args.fp8_recipe = "MXFP8BlockScaling"
+        actual = train_and_evaluate(self.args)
+        assert actual[0] < 0.39 and actual[1] > 0.83

 if __name__ == "__main__":
......
...@@ -21,6 +21,7 @@ from jax.sharding import PartitionSpec, NamedSharding
 from common import is_bf16_supported, get_fp8_recipe_from_name_string
 import transformer_engine.jax as te
+import transformer_engine.jax.cpp_extensions as tex
 import transformer_engine.jax.flax as te_flax
 from transformer_engine.jax.quantize import is_fp8_available, ScalingMode
...@@ -258,7 +259,13 @@ def train_and_evaluate(args):
         fp8_recipe = None

     device_mesh = mesh_utils.create_device_mesh((num_gpu,))
-    with jax.sharding.Mesh(devices=device_mesh, axis_names=(DEVICE_DP_AXIS,)) as mesh:
+    with jax.sharding.Mesh(
+        devices=device_mesh, axis_names=(DEVICE_DP_AXIS,)
+    ) as mesh, te.fp8_autocast(
+        enabled=args.use_fp8,
+        fp8_recipe=fp8_recipe,
+        mesh_resource=te.MeshResource(DEVICE_DP_AXIS, None, None, None),
+    ):
         rng = jax.random.PRNGKey(args.seed)
         rng, params_rng = jax.random.split(rng)
...@@ -269,17 +276,14 @@ def train_and_evaluate(args):
         mask_shape = [args.batch_size, 1, args.max_seq_len, args.max_seq_len]
         label_shape = [args.batch_size]

-        with te.fp8_autocast(
-            enabled=args.use_fp8,
-            fp8_recipe=fp8_recipe,
-            mesh_resource=te.MeshResource(DEVICE_DP_AXIS, None, None, None),
-        ):
+        # Add TE logical axis rules to our Flax logical axis rule context. This must be done inside fp8_autocast
+        sharding_rules = te_flax.extend_logical_axis_rules(tuple())
+        with flax.linen.logical_axis_rules(sharding_rules):
             encoder = Net(num_embed)
             inputs = jnp.zeros(input_shape, dtype=jnp.int32)
             masks = jnp.zeros(mask_shape, dtype=jnp.uint8)
             abs_var_collect = jax.eval_shape(encoder.init, init_rngs, inputs, masks)
-            sharding_rules = te_flax.extend_logical_axis_rules(tuple())
             params_sharding = get_params_sharding(sharding_rules, abs_var_collect, mesh)
             inputs_sharding = NamedSharding(mesh, PartitionSpec(DEVICE_DP_AXIS, None))
             masks_sharding = NamedSharding(mesh, PartitionSpec(DEVICE_DP_AXIS, None, None, None))
...@@ -288,7 +292,9 @@ def train_and_evaluate(args):
             out_shardings = {
                 key: params_sharding if key is PARAMS_KEY else None for key in abs_var_collect
             }
-            jit_encoder_init = jax.jit(encoder.init, in_shardings, out_shardings)
+            jit_encoder_init = jax.jit(
+                encoder.init, in_shardings=in_shardings, out_shardings=out_shardings
+            )
             var_collect = jit_encoder_init(init_rngs, inputs, masks)

             optimizer = optax.adamw(args.lr)
...@@ -312,11 +318,15 @@ def train_and_evaluate(args):
                 None,
             )
             out_shardings = (state_sharding, None, None, None)
-            jit_train_step = jax.jit(train_step, in_shardings, out_shardings)
+            jit_train_step = jax.jit(
+                train_step, in_shardings=in_shardings, out_shardings=out_shardings
+            )

             in_shardings = (state_sharding, inputs_sharding, masks_sharding, labels_sharding, None)
             out_shardings = (None, None)
-            jit_eval_step = jax.jit(eval_step, in_shardings, out_shardings)
+            jit_eval_step = jax.jit(
+                eval_step, in_shardings=in_shardings, out_shardings=out_shardings
+            )

             if args.use_fp8:
                 labels = jnp.zeros(label_shape, dtype=jnp.bfloat16)
...@@ -424,14 +434,14 @@ class TestEncoder(unittest.TestCase):
     is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING)

     def setUp(self):
-        """Run 3 epochs for testing"""
-        self.args = encoder_parser(["--epochs", "3"])
+        """Run 5 epochs for testing"""
+        self.args = encoder_parser(["--epochs", "5"])

     @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
     def test_te_bf16(self):
         """Test Transformer Engine with BF16"""
         actual = train_and_evaluate(self.args)
-        assert actual[0] < 0.535 and actual[1] > 0.73
+        assert actual[0] < 0.52 and actual[1] > 0.74

     @unittest.skipIf(not is_fp8_supported, fp8_reason)
     def test_te_delayed_scaling_fp8(self):
...@@ -439,7 +449,7 @@ class TestEncoder(unittest.TestCase):
         self.args.use_fp8 = True
         self.args.fp8_recipe = "DelayedScaling"
         actual = train_and_evaluate(self.args)
-        assert actual[0] < 0.535 and actual[1] > 0.73
+        assert actual[0] < 0.52 and actual[1] > 0.74

     @unittest.skipIf(not is_fp8_supported, fp8_reason)
     def test_te_current_scaling_fp8(self):
...@@ -447,7 +457,7 @@ class TestEncoder(unittest.TestCase):
         self.args.use_fp8 = True
         self.args.fp8_recipe = "Float8CurrentScaling"
         actual = train_and_evaluate(self.args)
-        assert actual[0] < 0.535 and actual[1] > 0.73
+        assert actual[0] < 0.52 and actual[1] > 0.74

     @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason)
     def test_te_mxfp8(self):
...@@ -455,14 +465,14 @@ class TestEncoder(unittest.TestCase):
         self.args.use_fp8 = True
         self.args.fp8_recipe = "MXFP8BlockScaling"
         actual = train_and_evaluate(self.args)
-        assert actual[0] < 0.535 and actual[1] > 0.73
+        assert actual[0] < 0.52 and actual[1] > 0.74

     @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
     def test_te_bf16_shardy(self):
         """Test Transformer Engine with BF16"""
         self.args.enable_shardy = True
         actual = train_and_evaluate(self.args)
-        assert actual[0] < 0.535 and actual[1] > 0.73
+        assert actual[0] < 0.52 and actual[1] > 0.74

     @unittest.skipIf(not is_fp8_supported, fp8_reason)
     def test_te_delayed_scaling_fp8_shardy(self):
...@@ -471,9 +481,7 @@ class TestEncoder(unittest.TestCase):
         self.args.use_fp8 = True
         self.args.fp8_recipe = "DelayedScaling"
         actual = train_and_evaluate(self.args)
-        assert actual[0] < 0.535 and actual[1] > 0.73
-
-    # TODO(jreiffers): Add mxfp8 Shardy tests once supported in JAX.
+        assert actual[0] < 0.52 and actual[1] > 0.74

     @unittest.skipIf(not is_fp8_supported, fp8_reason)
     def test_te_current_scaling_fp8_shardy(self):
...@@ -482,7 +490,19 @@ class TestEncoder(unittest.TestCase):
         self.args.use_fp8 = True
         self.args.fp8_recipe = "Float8CurrentScaling"
         actual = train_and_evaluate(self.args)
-        assert actual[0] < 0.535 and actual[1] > 0.73
+        assert actual[0] < 0.52 and actual[1] > 0.74
+
+    @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason)
+    @unittest.skipIf(
+        tex.gemm_uses_jax_dot(), "`jax.nn.scaled_matmul()` does not support the Shardy partitioner."
+    )
+    def test_te_mxfp8_shardy(self):
+        """Test Transformer Engine with MXFP8"""
+        self.args.enable_shardy = True
+        self.args.use_fp8 = True
+        self.args.fp8_recipe = "MXFP8BlockScaling"
+        actual = train_and_evaluate(self.args)
+        assert actual[0] < 0.52 and actual[1] > 0.74

 if __name__ == "__main__":
......
...@@ -28,8 +28,8 @@ from common import (
     get_fp8_recipe_from_name_string,
 )
 import transformer_engine.jax as te
+import transformer_engine.jax.cpp_extensions as tex
 import transformer_engine.jax.flax as te_flax
-from transformer_engine.jax.quantize import is_fp8_available, ScalingMode

 os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
...@@ -379,8 +379,11 @@ def train_and_evaluate(args):
     device_mesh = mesh_utils.create_device_mesh((num_gpu_dp, num_gpu_tp))
     with jax.sharding.Mesh(
         devices=device_mesh, axis_names=(DEVICE_DP_AXIS, DEVICE_TP_AXIS)
-    ) as mesh:
+    ) as mesh, te.fp8_autocast(
+        enabled=args.use_fp8,
+        fp8_recipe=fp8_recipe,
+        mesh_resource=te.MeshResource(DEVICE_DP_AXIS, DEVICE_TP_AXIS, None, None),
+    ):
         rng = jax.random.PRNGKey(args.seed)
         rng, params_rng = jax.random.split(rng)
         rng, dropout_rng = jax.random.split(rng)
...@@ -390,18 +393,18 @@ def train_and_evaluate(args):
         mask_shape = [args.batch_size, 1, args.max_seq_len, args.max_seq_len]
         label_shape = [args.batch_size]

-        with te.fp8_autocast(
-            enabled=args.use_fp8,
-            fp8_recipe=fp8_recipe,
-            mesh_resource=te.MeshResource(DEVICE_DP_AXIS, DEVICE_TP_AXIS, None, None),
-        ):
+        # Create custom Flax logical axis rules for sharding.
+        customized_rules = ((NAMED_BROADCAST_AXIS, None), (NAMED_TP_AXIS, DEVICE_TP_AXIS))
+        # Extend the logical axis rules with TE's rules. This must be done inside fp8_autocast.
+        sharding_rules = te_flax.extend_logical_axis_rules(customized_rules)
+        with flax.linen.logical_axis_rules(sharding_rules):
             encoder = Net(num_embed)
             inputs = jnp.zeros(input_shape, dtype=jnp.int32)
             masks = jnp.zeros(mask_shape, dtype=jnp.uint8)
             abs_var_collect = jax.eval_shape(encoder.init, init_rngs, inputs, masks)
-            customized_rules = ((NAMED_BROADCAST_AXIS, None), (NAMED_TP_AXIS, DEVICE_TP_AXIS))
-            sharding_rules = te_flax.extend_logical_axis_rules(tuple()) + customized_rules
             params_sharding = get_params_sharding(sharding_rules, abs_var_collect, mesh)
             inputs_pspec = jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None)
             masks_pspec = jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None, None, None)
...@@ -412,7 +415,9 @@ def train_and_evaluate(args):
             out_shardings = {
                 key: params_sharding if key is PARAMS_KEY else None for key in abs_var_collect
             }
-            jit_encoder_init = jax.jit(encoder.init, in_shardings, out_shardings)
+            jit_encoder_init = jax.jit(
+                encoder.init, in_shardings=in_shardings, out_shardings=out_shardings
+            )
             var_collect = jit_encoder_init(init_rngs, inputs, masks)

             optimizer = optax.adamw(args.lr)
...@@ -432,11 +437,15 @@ def train_and_evaluate(args):
                 None,
             )
             out_shardings = (state_sharding, None, None, None)
-            jit_train_step = jax.jit(train_step, in_shardings, out_shardings)
+            jit_train_step = jax.jit(
+                train_step, in_shardings=in_shardings, out_shardings=out_shardings
+            )

             in_shardings = (state_sharding, inputs_sharding, masks_sharding, labels_sharding, None)
             out_shardings = (None, None)
-            jit_eval_step = jax.jit(eval_step, in_shardings, out_shardings)
+            jit_eval_step = jax.jit(
+                eval_step, in_shardings=in_shardings, out_shardings=out_shardings
+            )

             if args.use_fp8:
                 labels = jnp.zeros(label_shape, dtype=jnp.bfloat16)
...@@ -578,8 +587,8 @@ class TestEncoder(unittest.TestCase):
     """Encoder unittests"""

     def exec(self, use_fp8, fp8_recipe, *, enable_shardy=False):
-        """Run 3 epochs for testing"""
-        args = encoder_parser([])
+        """Run 5 epochs for testing"""
+        args = encoder_parser(["--epochs", "5"])

         num_gpu = self.num_process
         tp_size = 2 if num_gpu > 1 and num_gpu % 2 == 0 else 1
...@@ -601,7 +610,7 @@ class TestEncoder(unittest.TestCase):
     def test_te_bf16(self):
         """Test Transformer Engine with BF16"""
         result = self.exec(False, None)
-        assert result[0] < 0.505 and result[1] > 0.755
+        assert result[0] < 0.43 and result[1] > 0.80

     @unittest.skipIf(
         not is_fp8_supported(), "Device compute capability 9.0+ is required for DelayedScaling FP8"
...@@ -609,7 +618,7 @@ class TestEncoder(unittest.TestCase):
     def test_te_delayed_scaling_fp8(self):
         """Test Transformer Engine with DelayedScaling FP8"""
         result = self.exec(True, "DelayedScaling")
-        assert result[0] < 0.506 and result[1] > 0.753
+        assert result[0] < 0.43 and result[1] > 0.80

     @unittest.skipIf(
         not is_fp8_supported(), "Device compute capability 9.0+ is required for CurrentScaling FP8"
...@@ -617,7 +626,7 @@ class TestEncoder(unittest.TestCase):
     def test_te_current_scaling_fp8(self):
         """Test Transformer Engine with CurrentScaling FP8"""
         result = self.exec(True, "Float8CurrentScaling")
-        assert result[0] < 0.507 and result[1] > 0.753
+        assert result[0] < 0.43 and result[1] > 0.80

     @unittest.skipIf(
         not is_mxfp8_supported(), "Device compute capability 10.0+ is required for MXFP8"
...@@ -625,13 +634,13 @@ class TestEncoder(unittest.TestCase):
     def test_te_mxfp8(self):
         """Test Transformer Engine with MXFP8"""
         result = self.exec(True, "MXFP8BlockScaling")
-        assert result[0] < 0.505 and result[1] > 0.754
+        assert result[0] < 0.43 and result[1] > 0.80

     @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
     def test_te_bf16_shardy(self):
         """Test Transformer Engine with BF16"""
         result = self.exec(False, None, enable_shardy=True)
-        assert result[0] < 0.505 and result[1] > 0.755
+        assert result[0] < 0.43 and result[1] > 0.80

     @unittest.skipIf(
         not is_fp8_supported(), "Device compute capability 9.0+ is required for DelayedScaling FP8"
...@@ -639,9 +648,7 @@ class TestEncoder(unittest.TestCase):
     def test_te_delayed_scaling_fp8_shardy(self):
         """Test Transformer Engine with DelayedScaling FP8"""
         result = self.exec(True, "DelayedScaling", enable_shardy=True)
-        assert result[0] < 0.506 and result[1] > 0.753
-
-    # TODO(jreiffers): Add mxfp8 Shardy tests once supported in JAX.
+        assert result[0] < 0.43 and result[1] > 0.80

     @unittest.skipIf(
         not is_fp8_supported(), "Device compute capability 9.0+ is required for CurrentScaling FP8"
...@@ -649,7 +656,18 @@ class TestEncoder(unittest.TestCase):
     def test_te_current_scaling_fp8_shardy(self):
         """Test Transformer Engine with CurrentScaling FP8"""
         result = self.exec(True, "Float8CurrentScaling", enable_shardy=True)
-        assert result[0] < 0.507 and result[1] > 0.753
+        assert result[0] < 0.43 and result[1] > 0.80
+
+    @unittest.skipIf(
+        not is_mxfp8_supported(), "Device compute capability 10.0+ is required for MXFP8"
+    )
+    @unittest.skipIf(
+        tex.gemm_uses_jax_dot(), "`jax.nn.scaled_matmul()` does not support the Shardy partitioner."
+    )
+    def test_te_mxfp8_shardy(self):
+        """Test Transformer Engine with MXFP8"""
+        result = self.exec(True, "MXFP8BlockScaling", enable_shardy=True)
+        assert result[0] < 0.43 and result[1] > 0.80

 if __name__ == "__main__":
......
-datasets
+datasets<4.0.0
 flax>=0.7.1
 optax
 Pillow
...@@ -6,7 +6,7 @@ set -e
 # Find TE
 : ${TE_PATH:=/opt/transformerengine}
-TE_LIB_PATH=`pip3 show transformer-engine | grep Location | cut -d ' ' -f 2`
+TE_LIB_PATH=$(pip3 show transformer-engine | grep -E "Location:|Editable project location:" | tail -n 1 | awk '{print $NF}')
 export LD_LIBRARY_PATH=$TE_LIB_PATH:$LD_LIBRARY_PATH

 # Set parallelization parameters
......
...@@ -24,11 +24,11 @@ pip3 install -r $TE_PATH/examples/jax/encoder/requirements.txt || error_exit "Fa
 # Make encoder tests to have run-to-run deterministic to have the stable CI results
 export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops"

-# python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_multigpu_encoder.xml $TE_PATH/examples/jax/encoder/test_multigpu_encoder.py || test_fail "test_multigpu_encoder.py"
-# wait
-# python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_model_parallel_encoder.xml $TE_PATH/examples/jax/encoder/test_model_parallel_encoder.py || test_fail "test_model_parallel_encoder.py"
-# wait
-. $TE_PATH/examples/jax/encoder/run_test_multiprocessing_encoder.sh || test_fail "run_test_multiprocessing_encoder.sh"
+python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_multigpu_encoder.xml $TE_PATH/examples/jax/encoder/test_multigpu_encoder.py || test_fail "test_multigpu_encoder.py"
+wait
+python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_model_parallel_encoder.xml $TE_PATH/examples/jax/encoder/test_model_parallel_encoder.py || test_fail "test_model_parallel_encoder.py"
+wait
+TE_PATH=$TE_PATH bash $TE_PATH/examples/jax/encoder/run_test_multiprocessing_encoder.sh || test_fail "run_test_multiprocessing_encoder.sh"

 if [ $RET -ne 0 ]; then
     echo "Error: some sub-tests failed: $FAILED_CASES"
......
...@@ -20,7 +20,8 @@ pytest -v -s $TE_PATH/tests/pytorch/debug/test_config.py --feature_dirs=$NVTE_TE
 pytest -v -s $TE_PATH/tests/pytorch/debug/test_numerics.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || FAIL=1
 NVTE_TORCH_COMPILE=0 pytest -v -s $TE_PATH/tests/pytorch/debug/test_api_features.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || FAIL=1

-# standard numerics tests with initialized debug
+# standard sanity and numerics tests with initialized debug
+NVTE_TEST_NVINSPECT_ENABLED=True NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_sanity.py || FAIL=1
 NVTE_TEST_NVINSPECT_ENABLED=True NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_numerics.py || FAIL=1

 exit $FAIL
...@@ -23,6 +23,8 @@ set -x
 mkdir -p "$XML_LOG_DIR"

 pip3 install pytest==8.2.1 || error_exit "Failed to install pytest"
+pip3 install onnxruntime==1.20.1 || error_exit "Failed to install onnxruntime"
+pip3 install onnxruntime_extensions==0.13.0 || error_exit "Failed to install onnxruntime_extensions"

 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_sanity.xml $TE_PATH/tests/pytorch/test_sanity.py || test_fail "test_sanity.py"
 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_recipe.xml $TE_PATH/tests/pytorch/test_recipe.py || test_fail "test_recipe.py"
...@@ -43,6 +45,7 @@ NVTE_INT8_SIM_FP8=1 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_
 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_gqa.xml $TE_PATH/tests/pytorch/test_gqa.py || test_fail "test_gqa.py"
 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fused_optimizer.xml $TE_PATH/tests/pytorch/test_fused_optimizer.py || test_fail "test_fused_optimizer.py"
 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_multi_tensor.xml $TE_PATH/tests/pytorch/test_multi_tensor.py || test_fail "test_multi_tensor.py"
+python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_onnx_export.xml $TE_PATH/tests/pytorch/test_onnx_export.py || test_fail "test_onnx_export.py"
 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops.xml $TE_PATH/tests/pytorch/test_fusible_ops.py || test_fail "test_fusible_ops.py"
 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_permutation.xml $TE_PATH/tests/pytorch/test_permutation.py || test_fail "test_permutation.py"
 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_parallel_cross_entropy.xml $TE_PATH/tests/pytorch/test_parallel_cross_entropy.py || test_fail "test_parallel_cross_entropy.py"
...@@ -50,6 +53,8 @@ NVTE_FLASH_ATTN=0 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cp
 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fused_attn.xml $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py || test_fail "test_fused_attn.py"
 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_kv_cache.xml $TE_PATH/tests/pytorch/fused_attn/test_kv_cache.py || test_fail "test_kv_cache.py"
 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_hf_integration.xml $TE_PATH/tests/pytorch/test_hf_integration.py || test_fail "test_hf_integration.py"
+NVTE_TEST_CHECKPOINT_ARTIFACT_PATH=$TE_PATH/artifacts/tests/pytorch/test_checkpoint python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_checkpoint.xml $TE_PATH/tests/pytorch/test_checkpoint.py || test_fail "test_checkpoint.py"
+python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fused_router.xml $TE_PATH/tests/pytorch/test_fused_router.py || test_fail "test_fused_router.py"

 if [ "$RET" -ne 0 ]; then
     echo "Error in the following test cases:$FAILED_CASES"
......
...@@ -18,10 +18,10 @@ sm_arch=`python3 -c "import torch; sm = torch.cuda.get_device_capability(0); pri
 export FLASH_ATTN_CUDA_ARCHS=$sm_arch
 if [ $sm_arch -gt 90 ]
 then
-    FA_versions=(2.7.3)
+    FA_versions=(2.8.1)
 elif [ $sm_arch -eq 90 ]
 then
-    FA_versions=(2.5.7 2.7.3 3.0.0b1)
+    FA_versions=(2.7.3 2.8.1 3.0.0b1)
 fi
 for fa_version in "${FA_versions[@]}"
......
...@@ -66,11 +66,13 @@ enable_testing()
 include_directories(${gtest_SOURCE_DIR}/include ${gtest_SOURCE_DIR})

 if(NOT DEFINED TE_LIB_PATH)
-    execute_process(COMMAND bash -c "pip3 show transformer-engine | grep Location | cut -d ' ' -f 2 | tr -d '\n'"
-                    OUTPUT_VARIABLE TE_LIB_PATH)
+    execute_process(COMMAND bash -c "python3 -c 'import transformer_engine as te; print(te.__file__)'"
+                    OUTPUT_VARIABLE TE_LIB_FILE
+                    OUTPUT_STRIP_TRAILING_WHITESPACE)
+    get_filename_component(TE_LIB_PATH ${TE_LIB_FILE} DIRECTORY)
 endif()

-find_library(TE_LIB NAMES transformer_engine PATHS "${TE_LIB_PATH}/transformer_engine" ${TE_LIB_PATH} ENV TE_LIB_PATH REQUIRED)
+find_library(TE_LIB NAMES transformer_engine PATHS "${TE_LIB_PATH}/.." ${TE_LIB_PATH} ENV TE_LIB_PATH REQUIRED)
 message(STATUS "Found transformer_engine library: ${TE_LIB}")

 include_directories(../../transformer_engine/common/include)
......
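
For context, the CMake change above locates the library through the installed Python module instead of parsing pip3 show, which also works for editable installs. A sketch of the equivalent lookup in Python:

import os

import transformer_engine as te

# Directory of the installed transformer_engine package; find_library() above
# searches this directory and its parent for libtransformer_engine.
te_lib_path = os.path.dirname(te.__file__)
print(te_lib_path)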
...@@ -22,8 +22,10 @@ list(APPEND test_cuda_sources
     test_act.cu
     test_normalization.cu
     test_normalization_mxfp8.cu
+    test_memset.cu
     test_multi_cast_transpose.cu
     test_multi_padding.cu
+    test_multi_unpadding.cu
     test_causal_softmax.cu
     test_swizzle.cu
     ../test_common.cu)
......