Commit 53fa872c authored by wenjh

Merge branch 'nv_release_v2.8' into release_v2.8

parents 27ddce40 40c69e75
@@ -19,7 +19,7 @@ jobs:
        run: |
          apt-get update
          apt-get install -y git python3.9 pip cudnn9-cuda-12
-         pip install cmake==3.21.0 pybind11[global] ninja
+         pip install cmake==3.21.0 pybind11[global] ninja nvidia-mathdx==25.1.1
      - name: 'Checkout'
        uses: actions/checkout@v3
        with:
@@ -43,7 +43,7 @@ jobs:
        run: |
          apt-get update
          apt-get install -y git python3.9 pip cudnn9-cuda-12
-         pip install cmake torch ninja pydantic importlib-metadata>=1.0 packaging pybind11 numpy einops onnxscript
+         pip install cmake torch ninja pydantic importlib-metadata>=1.0 packaging pybind11 numpy einops onnxscript nvidia-mathdx==25.1.1
      - name: 'Checkout'
        uses: actions/checkout@v3
        with:
@@ -63,7 +63,7 @@ jobs:
      options: --user root
    steps:
      - name: 'Dependencies'
-       run: pip install pybind11[global]
+       run: pip install pybind11[global] nvidia-mathdx==25.1.1
      - name: 'Checkout'
        uses: actions/checkout@v3
        with:
@@ -83,7 +83,7 @@ jobs:
      options: --user root
    steps:
      - name: 'Dependencies'
-       run: pip install torch pybind11[global] einops onnxscript
+       run: pip install torch pybind11[global] einops onnxscript nvidia-mathdx==25.1.1
      - name: 'Checkout'
        uses: actions/checkout@v3
        with:
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import argparse
import torch
import pandas as pd
import torch.utils.benchmark as benchmark
import transformer_engine.pytorch as te
import transformer_engine_torch as tex
import transformer_engine.pytorch.cpp_extensions as ext
from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer
scale_padding_to = 1
permute_scale = False
TORCH_TO_TE_FLOAT_MAP = {
torch.bfloat16: tex.DType.kBFloat16,
}
def run_kernel(shape, stochastic_rounding: bool, input_dtype=torch.bfloat16):
# Generate random input data
M, K = shape
x = torch.randn([M, K], dtype=input_dtype, device="cuda")
assert M % 16 == 0, "M must be divisible by 16"
assert K % 16 == 0, "K must be divisible by 16"
# Quantize
nvfp4_quantizer = NVFP4Quantizer(
fp4_dtype=tex.DType.kFloat4E2M1,
rowwise=True,
columnwise=True,
with_amax_reduction=False,
amax_reduction_group=None,
with_rht=True,
with_post_rht_amax=True,
with_random_sign_mask=True,
stochastic_rounding=stochastic_rounding,
)
x_nvfp4_sut = nvfp4_quantizer.make_empty(
(M, K), dtype=x.dtype, device=x.device, requires_grad=False
)
x_nvfp4_sut = nvfp4_quantizer.update_quantized(x, x_nvfp4_sut)
with torch.no_grad():
stmt = "kernel_func(input, output)"
globals_dict = {
"kernel_func": nvfp4_quantizer.update_quantized,
"input": x,
"output": x_nvfp4_sut,
}
timing = benchmark.Timer(
stmt=stmt,
globals=globals_dict,
num_threads=1,
).blocked_autorange(min_run_time=5)
print(timing)
timing_us = timing.median * 1e6
input_nbytes = shape[0] * shape[1] * 2 # bf16
output_nbytes = shape[0] * shape[1] // 2 # //2 for fp4
sf_nbytes = shape[0] * shape[1] // 16 # //16 for 1 byte per 16 elems
total_nbytes = (
0
+ input_nbytes
* 3 # Reading input for Amax(x) & Amax(RHT(x.T)), reading input for Cast(x), reading input for Cast(RHT(x.T))
+ 2 * 4 # Output 2 * float for scale & amax
+ 2 * 4 # Input 2 * float
+ output_nbytes * 2 # Output from Cast(x) and Cast(RHT(x.T))
+ sf_nbytes * 2 # Scale factor
)
throughput_GBps = total_nbytes / (1024 * 1024 * 1024) / (timing_us / 1e6)
print(
f"Stochastic rounding: {stochastic_rounding}, Total: {total_nbytes} bytes, Throughput:"
f" {throughput_GBps} GB/s"
)
return timing_us, throughput_GBps
# Nsight Compute Profiling Command:
# ncu -f -o block_scaled_1d_cast_transpose_kernel --set=full --kernel-name "block_scaled_1d_cast_transpose_kernel" -s 5 -c 5 python benchmark_cast_transpose_1d_block.py --profile
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--profile", action="store_true", help="Enable profiling mode")
args = parser.parse_args()
if args.profile:
print("Profiling is enabled.")
else:
print("Profiling is disabled.")
shapes = [
(8192, 5120),
(8192, 10240),
(8192, 2560),
(8192, 11328),
(8192, 512),
(8192, 3584),
(5120, 8192),
(10240, 8192),
(2560, 8192),
(11328, 8192),
(512, 8192),
(3584, 8192),
(4096, 16384),
(14336, 16384),
]
if args.profile:
shapes = [
(16384, 6144),
]
data = []
for stochastic_rounding in [True]:  # Add False here to also benchmark without stochastic rounding
for shape in shapes:
print(
f"Running benchmark_func with shape {shape} and stochastic_rounding"
f" {stochastic_rounding}"
)
timing_us, throughput_GBps = run_kernel(shape, stochastic_rounding)
data.append(
[
"benchmark_func",
shape,
stochastic_rounding,
timing_us,
throughput_GBps,
]
)
df = pd.DataFrame(
data=data,
columns=[
"kernel",
"shape",
"stochastic_rounding",
"timing_us",
"throughput(GB/s)",
],
)
print(df)
df.to_csv("benchmark_cast_nvfp4.csv", index=False)
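As a sanity check on the bandwidth model in run_kernel, the following standalone snippet (illustrative only, not part of the benchmark script) reproduces the byte accounting for the first shape in the list:

# Reproduce the byte accounting from run_kernel for shape (8192, 5120).
M, K = 8192, 5120
input_nbytes = M * K * 2    # bf16 input, 2 bytes per element
output_nbytes = M * K // 2  # fp4 output, 2 elements per byte
sf_nbytes = M * K // 16     # 1 scale-factor byte per 16 elements
total_nbytes = (
    3 * input_nbytes        # input read for the amax pass, Cast(x), and Cast(RHT(x.T))
    + 2 * 4 + 2 * 4         # per-tensor scale/amax floats, in and out
    + 2 * output_nbytes     # row-wise and column-wise fp4 outputs
    + 2 * sf_nbytes         # row-wise and column-wise scale factors
)
print(total_nbytes)         # 298844176 bytes, i.e. ~0.28 GiB moved per call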
@@ -87,4 +87,5 @@ def setup_jax_extension(
         sources=[str(path) for path in sources],
         include_dirs=[str(path) for path in include_dirs],
         extra_compile_args=cxx_flags,
+        libraries=["nccl"],
     )
@@ -14,7 +14,7 @@ from typing import List
 def install_requirements() -> List[str]:
     """Install dependencies for TE/PyTorch extensions."""
-    return ["torch>=2.1", "einops"]  # "onnxscript==0.3.1", "onnx"]
+    return ["torch>=2.1", "einops"]  # "onnxscript", "onnx"]
 def test_requirements() -> List[str]:
......
@@ -272,15 +272,18 @@ def get_cuda_include_dirs() -> Tuple[str, str]:
 @functools.lru_cache(maxsize=None)
 def cuda_archs() -> str:
-    version = cuda_version()
-    if os.getenv("NVTE_CUDA_ARCHS") is None:
+    archs = os.getenv("NVTE_CUDA_ARCHS")
+    if archs is None:
+        version = cuda_version()
         if version >= (13, 0):
-            os.environ["NVTE_CUDA_ARCHS"] = "75;80;89;90;100;120"
+            archs = "75;80;89;90;100;100a;103a;120"
+        elif version >= (12, 9):
+            archs = "70;80;89;90;100;100a;103a;120"
         elif version >= (12, 8):
-            os.environ["NVTE_CUDA_ARCHS"] = "70;80;89;90;100;120"
+            archs = "70;80;89;90;100;100a;120"
         else:
-            os.environ["NVTE_CUDA_ARCHS"] = "70;80;89;90"
+            archs = "70;80;89;90"
-    return os.getenv("NVTE_CUDA_ARCHS")
+    return archs
 def cuda_version() -> Tuple[int, ...]:
......
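The reworked cuda_archs() only computes a default when NVTE_CUDA_ARCHS is unset and no longer writes the result back into os.environ. A minimal standalone sketch of the selection rule (pick_cuda_archs is a hypothetical helper for illustration, not part of the build system):

import os

def pick_cuda_archs(version, env=os.environ):
    # An explicit NVTE_CUDA_ARCHS always wins; otherwise choose by CUDA version.
    archs = env.get("NVTE_CUDA_ARCHS")
    if archs is None:
        if version >= (13, 0):
            archs = "75;80;89;90;100;100a;103a;120"
        elif version >= (12, 9):
            archs = "70;80;89;90;100;100a;103a;120"
        elif version >= (12, 8):
            archs = "70;80;89;90;100;100a;120"
        else:
            archs = "70;80;89;90"
    return archs

assert pick_cuda_archs((12, 8), env={}) == "70;80;89;90;100;100a;120"
assert pick_cuda_archs((12, 6), env={"NVTE_CUDA_ARCHS": "90"}) == "90"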
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Shared functions for the comm_overlap tests"""
import jax.numpy as jnp
import numpy as np
def dtype_tols(dtype, rtol=None, atol=None):
"""Expected numerical tolerance for a data type."""
# Return immediately if tolerances are fully specified
if rtol is not None and atol is not None:
return {"rtol": rtol, "atol": atol}
# Default tolerances for common dtypes
if dtype in [jnp.float32, "float32"]:
return {"rtol": 1e-5, "atol": 1e-8}
elif dtype in [jnp.float16, "float16"]:
return {"rtol": 1e-3, "atol": 1e-6}
elif dtype in [jnp.bfloat16, "bfloat16"]:
return {"rtol": 1e-2, "atol": 1e-5}
else:
return {"rtol": 1e-5, "atol": 1e-8}
def assert_allclose(
actual,
desired,
rtol=None,
atol=None,
dtype=None,
**kwargs,
):
"""Check if two tensors are close."""
# Infer data type if needed
if dtype is None:
if isinstance(actual, float):
dtype = "float32"
else:
dtype = actual.dtype
# Determine tolerances
tols = {}
if rtol is None or atol is None:
tols = dtype_tols(dtype)
if rtol is not None:
tols["rtol"] = rtol
if atol is not None:
tols["atol"] = atol
# Cast tensors to fp32
if not isinstance(actual, float):
actual = actual.astype(jnp.float32)
if not isinstance(desired, float):
desired = desired.astype(jnp.float32)
# Check if tensors are close
np.testing.assert_allclose(actual, desired, **tols, **kwargs)
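# Example (illustrative): with dtype=jnp.bfloat16 the defaults above give rtol=1e-2
# and atol=1e-5, so assert_allclose(1.000, 1.005, dtype=jnp.bfloat16) passes, while
# the same comparison at float32 tolerances (rtol=1e-5, atol=1e-8) would fail.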
def assert_allclose_print_index(ref_output, gathered_output, rtol=1e-5, atol=1e-8):
if not jnp.allclose(ref_output, gathered_output, rtol=rtol, atol=atol):
diff = jnp.abs(ref_output - gathered_output)
mask = diff > (atol + rtol * jnp.abs(gathered_output))
print(mask.astype(int))
print(jnp.where(mask, diff, 0))
# Shared constants for all tests
DP_AXIS = "data"
TPSP_AXIS = "tensor_sequence"
PARAMS_KEY = "params"
# Shared functions for distributed testing
import argparse
import jax
from jax.experimental import mesh_utils
from transformer_engine.jax.cpp_extensions.gemm import collective_gemm_bootstrap
# Global flag to track if distributed has been initialized
_distributed_initialized = False
def _is_distributed_initialized():
"""Check if JAX distributed has been initialized."""
return _distributed_initialized
def _initialize_distributed(args):
"""Initialize JAX distributed with custom arguments."""
global _distributed_initialized
# Check if already initialized
if _distributed_initialized:
return
if args.coordinator_address is None or args.num_processes is None or args.process_id is None:
raise ValueError(
"All distributed initialization arguments are required: "
"--coordinator-address, --num-processes, --process-id"
)
if args.local_device_ids is None:
assert (
args.num_devices_per_process is not None
), "Either local_device_ids or num_devices_per_process must be provided"
# Calculate device range for this process
# Single process single device: each process gets one unique device
# Single process multiple devices: each process gets a unique range of devices
start_device = args.process_id * args.num_devices_per_process
device_range = range(start_device, start_device + args.num_devices_per_process)
global_device_ids_for_this_process = ",".join(map(str, device_range))
else:
# Use explicitly provided global device IDs
global_device_ids_for_this_process = args.local_device_ids
args.num_devices_per_process = len(args.local_device_ids.split(","))
assert args.num_devices_per_process == 1, "Only single process single GPU is supported!"
print(
f"Initializing JAX distributed with coordinator={args.coordinator_address}, "
f"num_processes={args.num_processes}, process_id={args.process_id}"
)
# Note: "local_device_ids" is a JAX term meaning "global CUDA devices managed by this process"
jax.distributed.initialize(
coordinator_address=args.coordinator_address,
num_processes=args.num_processes,
process_id=args.process_id,
local_device_ids=global_device_ids_for_this_process,
)
_distributed_initialized = True
jax.clear_caches()
jax.config.update(
"jax_use_shardy_partitioner", False
) # CollectiveGEMM does not work with Shardy yet
assert jax.local_device_count() == 1, (
f"[{args.process_id}|{args.num_devices_per_process}] Expected 1 GPU per process, found"
f" {jax.local_device_count()}"
)
devices_per_process = 1
num_total_devices = args.num_processes
print(
f"Initializing CGEMM communicator with num_total_devices={num_total_devices},"
f" devices_per_process={devices_per_process}, process_id={args.process_id}"
)
collective_gemm_bootstrap(
num_total_devices=num_total_devices,
num_devices_per_process=devices_per_process,
process_id=args.process_id,
tensor_parallel_size=args.tensor_parallel_size,
)
def _get_dp_and_tp_sizes(args):
num_gpu = args.num_processes * args.num_devices_per_process
if args.tensor_parallel_size is None:
num_gpu_dp = 2 if args.enable_data_parallel else 1
assert (
num_gpu > 1 and num_gpu % num_gpu_dp == 0
), "Number of GPUs must be greater than 1 and divisible by number of data parallel GPUs"
num_gpu_tp = num_gpu // num_gpu_dp
else:
num_gpu_tp = args.tensor_parallel_size
assert (
num_gpu > 1 and num_gpu % num_gpu_tp == 0
), "Number of GPUs must be greater than 1 and divisible by number of tensor parallel GPUs"
num_gpu_dp = num_gpu // num_gpu_tp
return num_gpu_dp, num_gpu_tp
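# Example (illustrative): with 8 total GPUs, --tensor-parallel-size 4 yields
# (num_gpu_dp, num_gpu_tp) = (2, 4); with no TP size given, --enable-data-parallel
# also yields (2, 4), and without data parallelism it yields (1, 8).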
def _create_mesh(args):
"""Create mesh configuration with proper validation."""
num_gpu = args.num_processes * args.num_devices_per_process
assert num_gpu == len(jax.devices()), "Number of GPUs must be equal to number of devices"
num_gpu_dp, num_gpu_tp = _get_dp_and_tp_sizes(args)
print(f"Using {num_gpu_dp}x{num_gpu_tp} mesh ({num_gpu_dp * num_gpu_tp} total GPUs)")
device_mesh = mesh_utils.create_device_mesh((num_gpu_dp, num_gpu_tp))
mesh = jax.sharding.Mesh(devices=device_mesh, axis_names=(DP_AXIS, TPSP_AXIS))
return mesh
def cgemm_parser(description="Collective GEMM test on multi-GPU with tensor parallelism"):
"""Create common argument parser for all collective GEMM tests."""
parser = argparse.ArgumentParser(description=description)
# Distributed initialization arguments
parser.add_argument(
"--coordinator-address",
type=str,
default=None,
help="Coordinator address for distributed initialization",
)
parser.add_argument(
"--num-processes",
type=int,
default=None,
help="Number of processes for distributed initialization",
)
parser.add_argument(
"--process-id", type=int, default=None, help="Process ID for distributed initialization"
)
parser.add_argument(
"--local-device-ids",
type=str,
default=None,
help="Local device IDs for distributed initialization (comma-separated)",
)
parser.add_argument(
"--num-devices-per-process", type=int, default=1, help="Number of devices per process"
)
# Test configuration arguments
parser.add_argument(
"--tensor-parallel-size", type=int, default=None, help="Tensor parallel size"
)
parser.add_argument("--batch-size", type=int, default=4, help="Batch size for testing")
parser.add_argument("--seq-len", type=int, default=8192, help="Sequence length for testing")
parser.add_argument("--hidden-in", type=int, default=4096, help="Input hidden dimension")
parser.add_argument("--hidden-out", type=int, default=8192, help="Output hidden dimension")
parser.add_argument(
"--collective-type",
type=str,
default="all_gather",
choices=["all_gather", "reduce_scatter"],
help="Type of collective operation",
)
parser.add_argument(
"--fp8-recipe", type=str, default="DelayedScaling", help="FP8 recipe to use"
)
parser.add_argument(
"--enable-data-parallel", action="store_true", help="Enable data parallelism"
)
parser.add_argument(
"--enable-result-check", action="store_true", default=True, help="Enable result checking"
)
return parser
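For orientation, the helpers above are combined roughly as follows in the per-process test entry points below (a minimal sketch assuming one JAX process per GPU; the real tests add unittest wiring and result checks):

# Hypothetical per-process driver built from the helpers above (sketch only).
args = cgemm_parser("Collective GEMM smoke test").parse_args()
_initialize_distributed(args)  # one JAX process per GPU, bootstraps the CGEMM communicator
mesh = _create_mesh(args)      # (DP_AXIS, TPSP_AXIS) device mesh
with mesh:
    pass                       # run the sharded GEMM / dense / MLP checks here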
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""config for collective_gemm tests"""
import pytest
def pytest_addoption(parser):
"""Pytest hook for collective_gemm tests"""
parser.addoption("--coordinator-address", action="store", default="localhost:12345")
parser.addoption("--num-processes", action="store", default=1)
parser.addoption("--process-id", action="store", default=0)
parser.addoption("--local-device-ids", action="store", default=None)
@pytest.fixture(autouse=True)
def distributed_args(request):
"""Fixture for querying distributed initialization arguments"""
if request.cls:
request.cls.coordinator_address = request.config.getoption("--coordinator-address")
request.cls.num_processes = int(request.config.getoption("--num-processes"))
request.cls.process_id = int(request.config.getoption("--process-id"))
request.cls.local_device_ids = request.config.getoption("--local-device-ids")
request.cls.num_devices_per_process = (
1
if request.cls.local_device_ids is None
else len(request.cls.local_device_ids.split(","))
)
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
NUM_GPUS=${NUM_GPUS:-$(nvidia-smi -L | wc -l)}
: ${TE_PATH:=/opt/transformerengine}
: ${XML_LOG_DIR:=/logs}
mkdir -p "$XML_LOG_DIR"
# Check if NVLINK is supported before running tests
echo "*** Checking NVLINK support***"
NVLINK_OUTPUT=$(nvidia-smi nvlink --status 2>&1)
NVLINK_EXIT_CODE=$?
# Check if command failed OR output indicates no NVLINK
if [ $NVLINK_EXIT_CODE -ne 0 ] || [[ "$NVLINK_OUTPUT" == *"not supported"* ]] || [[ "$NVLINK_OUTPUT" == *"No devices"* ]] || [ -z "$NVLINK_OUTPUT" ]; then
echo "NVLINK is not supported on this platform"
echo "Collective GEMM tests require NVLINK connectivity"
echo "SKIPPING all tests"
exit 0
else
echo "NVLINK support detected"
fi
# Define the test files to run
TEST_FILES=(
"test_gemm.py"
"test_dense_grad.py"
"test_layernorm_mlp_grad.py"
)
echo
echo "*** Executing tests in examples/jax/collective_gemm/ ***"
HAS_FAILURE=0 # Global failure flag
PIDS=() # Array to store all process PIDs
# Cleanup function to kill all processes
cleanup() {
for pid in "${PIDS[@]}"; do
if kill -0 "$pid" 2>/dev/null; then
echo "Killing process $pid"
kill -TERM "$pid" 2>/dev/null || true
fi
done
# Wait a bit and force kill if needed
sleep 2
for pid in "${PIDS[@]}"; do
if kill -0 "$pid" 2>/dev/null; then
echo "Force killing process $pid"
kill -KILL "$pid" 2>/dev/null || true
fi
done
}
# Set up signal handlers to cleanup on exit
trap cleanup EXIT INT TERM
# Run each test file across all GPUs
for TEST_FILE in "${TEST_FILES[@]}"; do
echo
echo "=== Starting test file: $TEST_FILE ..."
# Clear PIDs array for this test file
PIDS=()
for i in $(seq 0 $(($NUM_GPUS - 1))); do
# Define output file for logs
LOG_FILE="${TEST_FILE}_gpu_${i}.log"
if [ $i -eq 0 ]; then
# For process 0: show live output AND save to log file using tee
echo "=== Live output from process 0 ==="
pytest -s -c "$TE_PATH/tests/jax/pytest.ini" \
-vs --junitxml=$XML_LOG_DIR/collective_gemm_${TEST_FILE}.xml \
"$TE_PATH/examples/jax/collective_gemm/$TEST_FILE" \
--num-processes=$NUM_GPUS \
--process-id=$i 2>&1 | tee "$LOG_FILE" &
PID=$!
PIDS+=($PID)
else
# For other processes: redirect to log files only
pytest -s -c "$TE_PATH/tests/jax/pytest.ini" \
-vs "$TE_PATH/examples/jax/collective_gemm/$TEST_FILE" \
--num-processes=$NUM_GPUS \
--process-id=$i > "$LOG_FILE" 2>&1 &
PID=$!
PIDS+=($PID)
fi
done
# Wait for all processes to finish
wait
# Check and print the log content from process 0 (now has log file thanks to tee)
if grep -q "SKIPPED" "${TEST_FILE}_gpu_0.log"; then
echo "... $TEST_FILE SKIPPED"
elif grep -q "FAILED" "${TEST_FILE}_gpu_0.log"; then
echo "... $TEST_FILE FAILED"
HAS_FAILURE=1
elif grep -q "PASSED" "${TEST_FILE}_gpu_0.log"; then
echo "... $TEST_FILE PASSED"
else
echo "... $TEST_FILE INVALID"
HAS_FAILURE=1
fi
# Remove the log files after processing them
wait
rm ${TEST_FILE}_gpu_*.log
done
wait
# Final cleanup (trap will also call cleanup on exit)
cleanup
exit $HAS_FAILURE
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Collective Dense Gradient test on multi-GPU with tensor parallelism"""
import argparse
import unittest
import os
import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec, NamedSharding
import flax
from common import (
assert_allclose,
_initialize_distributed,
_get_dp_and_tp_sizes,
_create_mesh,
DP_AXIS,
TPSP_AXIS,
PARAMS_KEY,
cgemm_parser,
)
from transformer_engine.jax.dense import dense
from transformer_engine.jax.quantize import fp8_autocast
from transformer_engine.jax.cpp_extensions.gemm import (
CollectiveOp,
CollectiveOpSet,
noop_collective_op_set,
)
from transformer_engine.jax.sharding import MeshResource
import transformer_engine.jax.flax as te_flax
def _get_logical_axes(collective_op):
if collective_op.is_all_gather:
input_axes = (DP_AXIS, TPSP_AXIS, None)
weight_axes = (None, TPSP_AXIS)
bias_axes = (TPSP_AXIS,)
output_axes = (DP_AXIS, None, TPSP_AXIS)
else: # RS
input_axes = (DP_AXIS, None, TPSP_AXIS)
weight_axes = (TPSP_AXIS, None)
bias_axes = (None,)
output_axes = (DP_AXIS, TPSP_AXIS, None)
return input_axes, weight_axes, bias_axes, output_axes
def _get_operand_sharding(mesh, collective_op):
input_axes, weight_axes, bias_axes, _ = _get_logical_axes(collective_op)
x_sharding = NamedSharding(mesh, PartitionSpec(*input_axes))
weight_sharding = NamedSharding(mesh, PartitionSpec(*weight_axes))
bias_sharding = NamedSharding(mesh, PartitionSpec(*bias_axes))
return x_sharding, weight_sharding, bias_sharding
def _mean_dense(x, weight, bias, input_axes, weight_axes, output_axes, collective_op_set):
output = dense(
x,
weight,
bias,
contracting_dims=((2,), (0,)),
input_axes=input_axes,
kernel_axes=weight_axes,
output_axes=output_axes,
collective_op_set=collective_op_set,
)
return jnp.mean(output.astype(jnp.float32))
def _value_and_grad_dense(x, weight, bias, input_axes, weight_axes, output_axes, collective_op_set):
return jax.jit(jax.value_and_grad(_mean_dense, (0, 1, 2)), static_argnums=(3, 4, 5, 6))(
x, weight, bias, input_axes, weight_axes, output_axes, collective_op_set
)
def run_dense_grad_tests(args, mesh=None):
"""Execute Dense Gradient tests."""
print(args)
_initialize_distributed(args)
mesh = mesh or _create_mesh(args)
# Create test data
rng = jax.random.PRNGKey(0)
rng, x_rng, weight_rng, bias_rng = jax.random.split(rng, 4)
x = jax.random.normal(
x_rng, (args.batch_size, args.seq_len, args.hidden_in), dtype=jnp.bfloat16
)
weight = jax.random.normal(weight_rng, (args.hidden_in, args.hidden_out), dtype=jnp.bfloat16)
bias = jax.random.normal(bias_rng, (args.hidden_out,), dtype=jnp.bfloat16)
collective_op = (
CollectiveOp.ALL_GATHER
if args.collective_type == "all_gather"
else CollectiveOp.REDUCE_SCATTER
)
collective_op_set = CollectiveOpSet.create(forward_collective_op=collective_op)
with mesh, fp8_autocast(
enabled=False,
fp8_recipe=None,
mesh_resource=MeshResource(dp_resource=DP_AXIS, tpsp_resource=TPSP_AXIS),
):
# Get the base axis rules and extend them with TE's rules. This must be done inside fp8_autocast
axis_rules = flax.linen.get_logical_axis_rules()
axis_rules += ((TPSP_AXIS, TPSP_AXIS), (DP_AXIS, DP_AXIS))
te_extended_axis_rules = te_flax.extend_logical_axis_rules(axis_rules)
with flax.linen.logical_axis_rules(te_extended_axis_rules):
x_sharding, weight_sharding, bias_sharding = _get_operand_sharding(mesh, collective_op)
x_sharded = jax.device_put(x, x_sharding)
weight_sharded = jax.device_put(weight, weight_sharding)
bias_sharded = jax.device_put(bias, bias_sharding)
input_axes, weight_axes, _, output_axes = _get_logical_axes(collective_op)
ref_output, ref_grads = _value_and_grad_dense(
x_sharded,
weight_sharded,
bias_sharded,
input_axes,
weight_axes,
output_axes,
noop_collective_op_set,
)
output, sharded_grads = _value_and_grad_dense(
x_sharded,
weight_sharded,
bias_sharded,
input_axes,
weight_axes,
output_axes,
collective_op_set,
)
jax.block_until_ready(ref_output)
jax.block_until_ready(output)
gathered_grads = []
gathered_ref_grads = []
for ref_grad, grad in zip(ref_grads, sharded_grads):
gathered_grads.append(
jax.lax.with_sharding_constraint(grad, NamedSharding(mesh, PartitionSpec(None)))
)
gathered_ref_grads.append(
jax.lax.with_sharding_constraint(ref_grad, NamedSharding(mesh, PartitionSpec(None)))
)
jax.block_until_ready(gathered_grads)
jax.block_until_ready(gathered_ref_grads)
if args.enable_result_check and args.process_id == 0:
assert_allclose(ref_output, output, dtype=jnp.bfloat16)
for ref_grad, gathered_grad in zip(gathered_ref_grads, gathered_grads):
assert_allclose(ref_grad, gathered_grad, dtype=jnp.bfloat16)
class TestCollectiveDenseGradient(unittest.TestCase):
"""Collective Dense Gradient unittests"""
def setUp(self):
self.args = cgemm_parser(
"Collective Dense Gradient test on multi-GPU with tensor parallelism"
).parse_args([])
self.args.coordinator_address = self.coordinator_address
self.args.num_processes = self.num_processes
self.args.process_id = self.process_id
self.args.local_device_ids = self.local_device_ids
self.args.num_devices_per_process = self.num_devices_per_process
self.args.enable_data_parallel = True
self.args.tensor_parallel_size = _get_dp_and_tp_sizes(self.args)[1]
_initialize_distributed(self.args)
# Create mesh once for all tests
self.mesh = _create_mesh(self.args)
jax.sharding.set_mesh(self.mesh)
self.args.enable_result_check = True
os.environ["NVTE_JAX_ALL_REDUCE_IN_FP32"] = "1"
def tearDown(self):
os.environ.pop("NVTE_JAX_ALL_REDUCE_IN_FP32", None)
def test_te_bf16_all_gather(self):
"""Test Collective Dense Gradient with AllGather"""
self.args.collective_type = "all_gather"
run_dense_grad_tests(self.args, self.mesh)
def test_te_bf16_reduce_scatter(self):
"""Test Collective Dense Gradient with ReduceScatter"""
self.args.collective_type = "reduce_scatter"
run_dense_grad_tests(self.args, self.mesh)
if __name__ == "__main__":
import sys
if len(sys.argv) < 7: # Need at least the 3 required distributed args
print("Error: This script requires distributed initialization arguments.")
print(
"Usage: python test_dense_grad.py --coordinator-address <address> --num-processes <num>"
" --process-id <id> [--local-device-ids <ids>] [other args]"
)
print(
"Example: python test_dense_grad.py --coordinator-address localhost:1234"
" --num-processes 4 --process-id 0"
)
print(
"Example: python test_dense_grad.py --coordinator-address localhost:1234"
" --num-processes 2 --process-id 0 --local-device-ids 0,1,2,3"
)
sys.exit(1)
args = cgemm_parser(
"Collective Dense Gradient test on multi-GPU with tensor parallelism"
).parse_args()
_initialize_distributed(args)
run_dense_grad_tests(args, mesh=None)
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Collective GEMM test on multi-GPU with tensor parallelism
This script uses custom distributed initialization with the following arguments:
- --coordinator-address: Coordinator address for distributed initialization
- --num-processes: Number of processes for distributed initialization
- --process-id: Process ID for distributed initialization
- --local-device-ids: Local device IDs for distributed initialization
Example:
python test_gemm.py --coordinator-address localhost:1234 --num-processes 2 --process-id 0 --local-device-ids 0,1,2,3
"""
import unittest
import os
from functools import partial
import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec, NamedSharding
from common import (
assert_allclose,
_initialize_distributed,
_get_dp_and_tp_sizes,
_create_mesh,
DP_AXIS,
TPSP_AXIS,
PARAMS_KEY,
cgemm_parser,
)
import transformer_engine.jax.cpp_extensions as tex
from transformer_engine.jax.quantize import fp8_autocast
from transformer_engine.jax.cpp_extensions.gemm import CollectiveOp
from transformer_engine.jax.sharding import MeshResource
def _get_operand_sharding(mesh, collective_op, is_with_dp):
dp_axis = DP_AXIS if is_with_dp else None
if collective_op == CollectiveOp.ALL_GATHER:
x_sharding = NamedSharding(mesh, PartitionSpec(dp_axis, TPSP_AXIS, None))
weight_sharding = NamedSharding(mesh, PartitionSpec(None, TPSP_AXIS))
bias_sharding = NamedSharding(mesh, PartitionSpec(TPSP_AXIS))
output_sharding = NamedSharding(mesh, PartitionSpec(dp_axis, None, TPSP_AXIS))
else: # RS
x_sharding = NamedSharding(mesh, PartitionSpec(dp_axis, None, TPSP_AXIS))
weight_sharding = NamedSharding(mesh, PartitionSpec(TPSP_AXIS, None))
bias_sharding = NamedSharding(mesh, PartitionSpec(None))
output_sharding = NamedSharding(mesh, PartitionSpec(dp_axis, TPSP_AXIS, None))
return x_sharding, weight_sharding, bias_sharding, output_sharding
def _get_dp_and_tp_sizes(args):
num_gpu = args.num_processes * args.num_devices_per_process
if args.tensor_parallel_size is None:
num_gpu_dp = 2 if args.enable_data_parallel else 1
assert (
num_gpu > 1 and num_gpu % num_gpu_dp == 0
), "Number of GPUs must be greater than 1 and divisible by number of data parallel GPUs"
num_gpu_tp = num_gpu // num_gpu_dp
else:
num_gpu_tp = args.tensor_parallel_size
assert (
num_gpu > 1 and num_gpu % num_gpu_tp == 0
), "Number of GPUs must be greater than 1 and divisible by number of tensor parallel GPUs"
num_gpu_dp = num_gpu // num_gpu_tp
return num_gpu_dp, num_gpu_tp
@partial(jax.jit, static_argnames=("contracting_dims", "collective_op", "output_sharding"))
def _jitted_cgemm(x, weight, bias, contracting_dims, collective_op, output_sharding):
output = tex.gemm(
x,
weight,
bias=bias,
contracting_dims=contracting_dims,
collective_op=collective_op,
)
if output_sharding is not None:
output = jax.lax.with_sharding_constraint(output, output_sharding)
return output
def run_gemm_tests(args, mesh=None):
"""Execute GEMM tests."""
print(args)
# Collective GEMM requires Shardy partitioner to be disabled
jax.config.update("jax_use_shardy_partitioner", False)
# Initialize distributed with provided arguments
_initialize_distributed(args)
mesh = mesh or _create_mesh(args)
# Create test data
rng = jax.random.PRNGKey(0)
rng, x_rng, weight_rng, bias_rng = jax.random.split(rng, 4)
x = jax.random.normal(
x_rng, (args.batch_size, args.seq_len, args.hidden_in), dtype=jnp.bfloat16
)
weight = jax.random.normal(weight_rng, (args.hidden_in, args.hidden_out), dtype=jnp.bfloat16)
bias = jax.random.normal(bias_rng, (args.hidden_out,), dtype=jnp.bfloat16)
collective_op = (
CollectiveOp.ALL_GATHER
if args.collective_type == "all_gather"
else CollectiveOp.REDUCE_SCATTER
)
with mesh, fp8_autocast(
enabled=False,
fp8_recipe=None,
mesh_resource=MeshResource(dp_resource=DP_AXIS, tpsp_resource=TPSP_AXIS),
):
print(f"Device mesh: {mesh}")
x_sharding, weight_sharding, bias_sharding, output_sharding = _get_operand_sharding(
mesh, collective_op, args.enable_data_parallel
)
x_sharded = jax.device_put(x, x_sharding)
weight_sharded = jax.device_put(weight, weight_sharding)
bias_sharded = jax.device_put(bias, bias_sharding)
ref_output = _jitted_cgemm(
x_sharded,
weight_sharded,
bias_sharded,
contracting_dims=((2,), (0,)),
collective_op=CollectiveOp.NONE,
output_sharding=output_sharding,
)
output = _jitted_cgemm(
x_sharded,
weight_sharded,
bias_sharded,
contracting_dims=((2,), (0,)),
collective_op=collective_op,
# CollectiveGEMM output should have a correct sharding without applying sharding constraint
output_sharding=None,
)
assert (
ref_output.sharding == output.sharding
), f"ref_output.sharding={ref_output.sharding}, output.sharding={output.sharding}"
gathered_ref_output = jax.lax.with_sharding_constraint(
ref_output, NamedSharding(mesh, PartitionSpec(None))
)
gathered_output = jax.lax.with_sharding_constraint(
output, NamedSharding(mesh, PartitionSpec(None))
)
jax.block_until_ready(gathered_ref_output)
jax.block_until_ready(gathered_output)
if args.enable_result_check and args.process_id == 0:
assert_allclose(gathered_ref_output, gathered_output)
class TestCollectiveGemmWithDP(unittest.TestCase):
"""Collective GEMM with DP unittests"""
def setUp(self):
self.args = cgemm_parser(
"Collective GEMM test on multi-GPU with tensor parallelism"
).parse_args([])
self.args.coordinator_address = self.coordinator_address
self.args.num_processes = self.num_processes
self.args.process_id = self.process_id
self.args.local_device_ids = self.local_device_ids
self.args.num_devices_per_process = self.num_devices_per_process
self.args.enable_data_parallel = True
self.args.tensor_parallel_size = _get_dp_and_tp_sizes(self.args)[1]
_initialize_distributed(self.args)
self.mesh = _create_mesh(self.args)
jax.sharding.set_mesh(self.mesh)
self.args.enable_result_check = True
os.environ["NVTE_JAX_ALL_REDUCE_IN_FP32"] = "1"
def tearDown(self):
os.environ.pop("NVTE_JAX_ALL_REDUCE_IN_FP32", None)
def test_te_bf16_all_gather_with_dp(self):
"""Test Collective GEMM with AllGather"""
self.args.collective_type = "all_gather"
run_gemm_tests(self.args, self.mesh)
def test_te_bf16_reduce_scatter_with_dp(self):
"""Test Collective GEMM with ReduceScatter"""
self.args.collective_type = "reduce_scatter"
run_gemm_tests(self.args, self.mesh)
if __name__ == "__main__":
import sys
if len(sys.argv) < 7: # Need at least the 3 required distributed args
print("Error: This script requires distributed initialization arguments.")
print(
"Usage: python test_gemm.py --coordinator-address <address> --num-processes <num>"
" --process-id <id> [--local-device-ids <ids>] [other args]"
)
sys.exit(1)
args = cgemm_parser("Collective GEMM test on multi-GPU with tensor parallelism").parse_args()
_initialize_distributed(args)
run_gemm_tests(args, mesh=None)
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Collective Dense Gradient test on multi-GPU with tensor parallelism"""
import argparse
import unittest
import os
import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec, NamedSharding
import flax
from common import (
assert_allclose,
_initialize_distributed,
_get_dp_and_tp_sizes,
_create_mesh,
DP_AXIS,
TPSP_AXIS,
PARAMS_KEY,
cgemm_parser,
)
from transformer_engine.jax.layernorm_mlp import layernorm_mlp
from transformer_engine.jax.quantize import fp8_autocast
from transformer_engine.jax.cpp_extensions.gemm import (
CollectiveOpSet,
CollectiveOp,
noop_collective_op_set,
)
from transformer_engine.jax.sharding import MeshResource
import transformer_engine.jax.flax as te_flax
def _get_logical_axes():
input_1_axes = (DP_AXIS, TPSP_AXIS, None)
weight_1_axes = (None, None, TPSP_AXIS)
bias_axes_1 = (None, TPSP_AXIS)
input_2_axes = (DP_AXIS, None, TPSP_AXIS)
weight_2_axes = (TPSP_AXIS, None)
bias_axes_2 = (None,)
return input_1_axes, weight_1_axes, bias_axes_1, input_2_axes, weight_2_axes, bias_axes_2
def _get_operand_sharding(mesh):
input_1_axes, weight_1_axes, bias_axes_1, input_2_axes, weight_2_axes, bias_axes_2 = (
_get_logical_axes()
)
x_sharding = NamedSharding(mesh, PartitionSpec(*input_1_axes))
weight_1_sharding = NamedSharding(mesh, PartitionSpec(*weight_1_axes))
bias_1_sharding = NamedSharding(mesh, PartitionSpec(*bias_axes_1))
weight_2_sharding = NamedSharding(mesh, PartitionSpec(*weight_2_axes))
bias_2_sharding = NamedSharding(mesh, PartitionSpec(*bias_axes_2))
return x_sharding, weight_1_sharding, bias_1_sharding, weight_2_sharding, bias_2_sharding
def _mean_layernorm_mlp(
x,
weight_1,
bias_1,
weight_2,
bias_2,
gamma,
input_1_axes,
input_2_axes,
weight_1_axes,
weight_2_axes,
collective_op_sets,
):
output = layernorm_mlp(
x,
gamma,
beta=None,
kernels=[weight_1, weight_2],
biases=[bias_1, bias_2],
norm_type="rmsnorm",
dot_1_input_axes=input_1_axes,
dot_2_input_axes=input_2_axes,
kernel_1_axes=weight_1_axes,
kernel_2_axes=weight_2_axes,
activation_type=("gelu",),
collective_op_sets=collective_op_sets,
)
return jnp.mean(output)
def _value_and_grad_layernorm_mlp(
x,
weight_1,
bias_1,
weight_2,
bias_2,
gamma,
input_1_axes,
input_2_axes,
weight_1_axes,
weight_2_axes,
collective_op_sets,
):
return jax.jit(
jax.value_and_grad(_mean_layernorm_mlp, (0, 1, 2, 3, 4, 5)), static_argnums=(6, 7, 8, 9, 10)
)(
x,
weight_1,
bias_1,
weight_2,
bias_2,
gamma,
input_1_axes,
input_2_axes,
weight_1_axes,
weight_2_axes,
collective_op_sets,
)
def run_layernorm_mlp_grad_tests(args, mesh=None):
"""Execute Dense Gradient tests."""
print(args)
# Collective GEMM requires Shardy partitioner to be disabled
jax.config.update("jax_use_shardy_partitioner", False)
# Initialize distributed with provided arguments
_initialize_distributed(args)
mesh = mesh or _create_mesh(args)
# Create test data
rng = jax.random.PRNGKey(0)
rng, x_rng, weight_1_rng, bias_1_rng, weight_2_rng, bias_2_rng, gamma_rng = jax.random.split(
rng, 7
)
x = jax.random.normal(
x_rng, (args.batch_size, args.seq_len, args.hidden_in), dtype=jnp.bfloat16
)
weight_1 = jax.random.normal(
weight_1_rng, (args.hidden_in, 1, args.hidden_out), dtype=jnp.bfloat16
) / jnp.sqrt(args.hidden_in)
bias_1 = jax.random.normal(bias_1_rng, (1, args.hidden_out), dtype=jnp.bfloat16)
weight_2 = jax.random.normal(
weight_2_rng, (args.hidden_out, args.hidden_in), dtype=jnp.bfloat16
) / jnp.sqrt(args.hidden_out)
bias_2 = jax.random.normal(bias_2_rng, (args.hidden_in,), dtype=jnp.bfloat16)
gamma = jax.random.normal(gamma_rng, (args.hidden_in,), dtype=jnp.bfloat16) / jnp.sqrt(
args.hidden_in
)
collective_op_set_1 = CollectiveOpSet.create(forward_collective_op=CollectiveOp.ALL_GATHER)
collective_op_set_2 = CollectiveOpSet.create(forward_collective_op=CollectiveOp.REDUCE_SCATTER)
collective_op_sets = (collective_op_set_1, collective_op_set_2)
noop_collective_op_sets = (noop_collective_op_set, noop_collective_op_set)
with mesh, fp8_autocast(
enabled=False,
fp8_recipe=None,
mesh_resource=MeshResource(dp_resource=DP_AXIS, tpsp_resource=TPSP_AXIS),
):
# Get the base axis rules and extend them with TE's rules. This must be done inside fp8_autocast
axis_rules = flax.linen.get_logical_axis_rules()
axis_rules += ((TPSP_AXIS, TPSP_AXIS), (DP_AXIS, DP_AXIS))
te_extended_axis_rules = te_flax.extend_logical_axis_rules(axis_rules)
with flax.linen.logical_axis_rules(te_extended_axis_rules):
x_sharding, weight_1_sharding, bias_1_sharding, weight_2_sharding, bias_2_sharding = (
_get_operand_sharding(mesh)
)
x_sharded = jax.device_put(x, x_sharding)
weight_1_sharded = jax.device_put(weight_1, weight_1_sharding)
bias_1_sharded = jax.device_put(bias_1, bias_1_sharding)
weight_2_sharded = jax.device_put(weight_2, weight_2_sharding)
bias_2_sharded = jax.device_put(bias_2, bias_2_sharding)
input_1_axes, weight_1_axes, _, input_2_axes, weight_2_axes, _ = _get_logical_axes()
ref_output, ref_grads = _value_and_grad_layernorm_mlp(
x_sharded,
weight_1_sharded,
bias_1_sharded,
weight_2_sharded,
bias_2_sharded,
gamma,
input_1_axes,
input_2_axes,
weight_1_axes,
weight_2_axes,
noop_collective_op_sets,
)
output, sharded_grads = _value_and_grad_layernorm_mlp(
x_sharded,
weight_1_sharded,
bias_1_sharded,
weight_2_sharded,
bias_2_sharded,
gamma,
input_1_axes,
input_2_axes,
weight_1_axes,
weight_2_axes,
collective_op_sets,
)
jax.block_until_ready(ref_output)
jax.block_until_ready(output)
gathered_grads = []
gathered_ref_grads = []
for ref_grad, grad in zip(ref_grads, sharded_grads):
gathered_grads.append(
jax.lax.with_sharding_constraint(grad, NamedSharding(mesh, PartitionSpec(None)))
)
gathered_ref_grads.append(
jax.lax.with_sharding_constraint(ref_grad, NamedSharding(mesh, PartitionSpec(None)))
)
jax.block_until_ready(gathered_grads)
jax.block_until_ready(gathered_ref_grads)
if args.enable_result_check and args.process_id == 0:
assert_allclose(ref_output, output, dtype=jnp.bfloat16)
for ref_grad, gathered_grad in zip(gathered_ref_grads, gathered_grads):
assert_allclose(ref_grad, gathered_grad, dtype=jnp.bfloat16)
class TestCollectiveLayerNormMLPGradient(unittest.TestCase):
"""Collective Dense Gradient unittests"""
def setUp(self):
self.args = cgemm_parser(
"Collective LayerNorm MLP Gradient test on multi-GPU with tensor parallelism"
).parse_args([])
self.args.coordinator_address = self.coordinator_address
self.args.num_processes = self.num_processes
self.args.process_id = self.process_id
self.args.local_device_ids = self.local_device_ids
self.args.num_devices_per_process = self.num_devices_per_process
self.args.enable_data_parallel = True
self.args.tensor_parallel_size = _get_dp_and_tp_sizes(self.args)[1]
_initialize_distributed(self.args)
# Create mesh once for all tests
self.mesh = _create_mesh(self.args)
jax.sharding.set_mesh(self.mesh)
self.args.enable_result_check = True
os.environ["NVTE_JAX_ALL_REDUCE_IN_FP32"] = "1"
def tearDown(self):
os.environ.pop("NVTE_JAX_ALL_REDUCE_IN_FP32", None)
def test_te_bf16_layernorm_mlp_grad(self):
"""Test Collective Dense Gradient with AllGather"""
run_layernorm_mlp_grad_tests(self.args, self.mesh)
if __name__ == "__main__":
import sys
if len(sys.argv) < 7: # Need at least the 3 required distributed args
print("Error: This script requires distributed initialization arguments.")
print(
"Usage: python test_layernorm_mlp_grad.py --coordinator-address <address>"
" --num-processes <num> --process-id <id> [--local-device-ids <ids>] [other args]"
)
print(
"Example: python test_layernorm_mlp_grad.py --coordinator-address localhost:1234"
" --num-processes 4 --process-id 0"
)
print(
"Example: python test_layernorm_mlp_grad.py --coordinator-address localhost:1234"
" --num-processes 2 --process-id 0 --local-device-ids 0,1,2,3"
)
sys.exit(1)
args = cgemm_parser(
"Collective LayerNorm MLP Gradient test on multi-GPU with tensor parallelism"
).parse_args()
_initialize_distributed(args)
run_layernorm_mlp_grad_tests(args, mesh=None)
@@ -15,11 +15,37 @@ TEST_CASES=(
     "test_te_current_scaling_fp8_shardy"
 )
+: ${TE_PATH:=/opt/transformerengine}
+: ${XML_LOG_DIR:=/logs}
+mkdir -p "$XML_LOG_DIR"
 echo
 echo "*** Executing tests in examples/jax/encoder/test_multiprocessing_encoder.py ***"
 HAS_FAILURE=0 # Global failure flag
+PIDS=() # Array to store all process PIDs
+
+# Cleanup function to kill all processes
+cleanup() {
+    for pid in "${PIDS[@]}"; do
+        if kill -0 "$pid" 2>/dev/null; then
+            echo "Killing process $pid"
+            kill -TERM "$pid" 2>/dev/null || true
+        fi
+    done
+    # Wait a bit and force kill if needed
+    sleep 2
+    for pid in "${PIDS[@]}"; do
+        if kill -0 "$pid" 2>/dev/null; then
+            echo "Force killing process $pid"
+            kill -KILL "$pid" 2>/dev/null || true
+        fi
+    done
+}
+
+# Set up signal handlers to cleanup on exit
+trap cleanup EXIT INT TERM
 # Run each test case across all GPUs
 for TEST_CASE in "${TEST_CASES[@]}"; do
     echo
@@ -29,25 +55,40 @@ for TEST_CASE in "${TEST_CASES[@]}"; do
         # Define output file for logs
         LOG_FILE="${TEST_CASE}_gpu_${i}.log"
-        # Run pytest and redirect stdout and stderr to the log file
+        # For process 0: show live output AND save to log file using tee
+        if [ $i -eq 0 ]; then
+            echo "=== Live output from process 0 ==="
+            pytest -s -c "$TE_PATH/tests/jax/pytest.ini" \
+                -vs --junitxml=$XML_LOG_DIR/multiprocessing_encoder_${TEST_CASE}.xml \
+                "$TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py::TestEncoder::$TEST_CASE" \
+                --num-process=$NUM_GPUS \
+                --process-id=$i 2>&1 | tee "$LOG_FILE" &
+            PID=$!
+            PIDS+=($PID)
+        else
             pytest -s -c "$TE_PATH/tests/jax/pytest.ini" \
                 -vs "$TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py::TestEncoder::$TEST_CASE" \
                 --num-process=$NUM_GPUS \
                 --process-id=$i > "$LOG_FILE" 2>&1 &
+            PID=$!
+            PIDS+=($PID)
+        fi
     done
     # Wait for the process to finish
     wait
-    tail -n +7 "${TEST_CASE}_gpu_0.log"
     # Check and print the log content accordingly
     if grep -q "SKIPPED" "${TEST_CASE}_gpu_0.log"; then
         echo "... $TEST_CASE SKIPPED"
+    elif grep -q "FAILED" "${TEST_CASE}_gpu_0.log"; then
+        echo "... $TEST_CASE FAILED"
+        HAS_FAILURE=1
     elif grep -q "PASSED" "${TEST_CASE}_gpu_0.log"; then
         echo "... $TEST_CASE PASSED"
     else
+        echo "... $TEST_CASE INVALID"
         HAS_FAILURE=1
-        echo "... $TEST_CASE FAILED"
     fi
     # Remove the log file after processing it
@@ -56,4 +97,8 @@ for TEST_CASE in "${TEST_CASES[@]}"; do
 done
 wait
+
+# Final cleanup (trap will also call cleanup on exit)
+cleanup
 exit $HAS_FAILURE
@@ -29,6 +29,10 @@ wait
 python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_model_parallel_encoder.xml $TE_PATH/examples/jax/encoder/test_model_parallel_encoder.py || test_fail "test_model_parallel_encoder.py"
 wait
 TE_PATH=$TE_PATH bash $TE_PATH/examples/jax/encoder/run_test_multiprocessing_encoder.sh || test_fail "run_test_multiprocessing_encoder.sh"
+wait
+
+TE_PATH=$TE_PATH bash $TE_PATH/examples/jax/collective_gemm/run_test_cgemm.sh || test_fail "run_test_cgemm.sh"
+wait
 if [ $RET -ne 0 ]; then
     echo "Error: some sub-tests failed: $FAILED_CASES"
......
@@ -36,7 +36,7 @@ export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops"
 python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_single_gpu_encoder.xml $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py || test_fail "test_single_gpu_encoder.py"
 # Test without custom calls
 export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops"
-NVTE_JAX_CUSTOM_CALLS="false" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_single_gpu_encoder.xml $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py || test_fail "test_single_gpu_encoder.py without custom calls"
+NVTE_JAX_CUSTOM_CALLS="false" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_single_gpu_encoder_without_custom_call.xml $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py || test_fail "test_single_gpu_encoder.py without custom calls"
 if [ $RET -ne 0 ]; then
     echo "Error: some sub-tests failed: $FAILED_CASES"
......
@@ -7,6 +7,8 @@
 : ${TE_PATH:=/opt/transformerengine}
 : ${NVTE_TEST_NVINSPECT_FEATURE_DIRS:=$TE_PATH/transformer_engine/debug/features}
 : ${NVTE_TEST_NVINSPECT_CONFIGS_DIR:=$TE_PATH/tests/pytorch/debug/test_configs/}
+: ${XML_LOG_DIR:=/logs}
+mkdir -p "$XML_LOG_DIR"
 # Config with the dummy feature which prevents nvinspect from being disabled.
 # Nvinspect will be disabled if no feature is active.
@@ -20,17 +22,16 @@ pip uninstall -y nvdlfw-inspect
 pip install git+https://github.com/NVIDIA/nvidia-dlfw-inspect.git
 pip install pytest==8.2.1
-pytest -v -s $TE_PATH/tests/pytorch/debug/test_sanity.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || FAIL=1
-pytest -v -s $TE_PATH/tests/pytorch/debug/test_config.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || FAIL=1
-pytest -v -s $TE_PATH/tests/pytorch/debug/test_numerics.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || FAIL=1
-pytest -v -s $TE_PATH/tests/pytorch/debug/test_log.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || FAIL=1
-NVTE_TORCH_COMPILE=0 pytest -v -s $TE_PATH/tests/pytorch/debug/test_api_features.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || FAIL=1
-pytest -v -s $TE_PATH/tests/pytorch/debug/test_log.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || FAIL=1
-pytest -v -s $TE_PATH/tests/pytorch/debug/test_perf.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || FAIL=1
+pytest -v -s --junitxml=$XML_LOG_DIR/test_sanity.xml $TE_PATH/tests/pytorch/debug/test_sanity.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || FAIL=1
+pytest -v -s --junitxml=$XML_LOG_DIR/test_config.xml $TE_PATH/tests/pytorch/debug/test_config.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || FAIL=1
+pytest -v -s --junitxml=$XML_LOG_DIR/test_numerics.xml $TE_PATH/tests/pytorch/debug/test_numerics.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || FAIL=1
+pytest -v -s --junitxml=$XML_LOG_DIR/test_log.xml $TE_PATH/tests/pytorch/debug/test_log.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || FAIL=1
+NVTE_TORCH_COMPILE=0 pytest -v -s --junitxml=$XML_LOG_DIR/test_api_features.xml $TE_PATH/tests/pytorch/debug/test_api_features.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || FAIL=1
+pytest -v -s --junitxml=$XML_LOG_DIR/test_perf.xml $TE_PATH/tests/pytorch/debug/test_perf.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || FAIL=1
 # standard sanity and numerics tests with initialized debug
-NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_sanity.py || FAIL=1
+NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s --junitxml=$XML_LOG_DIR/test_sanity_2.xml $TE_PATH/tests/pytorch/test_sanity.py || FAIL=1
-NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_numerics.py || FAIL=1
+NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s --junitxml=$XML_LOG_DIR/test_numerics_2.xml $TE_PATH/tests/pytorch/test_numerics.py || FAIL=1
 exit $FAIL
@@ -31,6 +31,7 @@ ROCBLAS_ATOMICS_MOD=0 HIPBLASLT_ATOMICS_MOD=0 PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0
 ROCBLAS_ATOMICS_MOD=0 HIPBLASLT_ATOMICS_MOD=0 PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_cuda_graphs.xml $TE_PATH/tests/pytorch/test_cuda_graphs.py || test_fail "test_cuda_graphs.py"
 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_jit.xml $TE_PATH/tests/pytorch/test_jit.py || test_fail "test_jit.py"
 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fused_rope.xml $TE_PATH/tests/pytorch/test_fused_rope.py || test_fail "test_fused_rope.py"
+python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_nvfp4.xml $TE_PATH/tests/pytorch/nvfp4 || test_fail "test_nvfp4"
 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8tensor.xml $TE_PATH/tests/pytorch/test_float8tensor.py || test_fail "test_float8tensor.py"
 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8blockwisetensor.xml $TE_PATH/tests/pytorch/test_float8blockwisetensor.py || test_fail "test_float8blockwisetensor.py"
 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_scaling_exact.xml $TE_PATH/tests/pytorch/test_float8_blockwise_scaling_exact.py || test_fail "test_float8_blockwise_scaling_exact.py"
......
@@ -30,6 +30,7 @@ pip3 install pytest==8.2.1 || error_exit "Failed to install pytest"
 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_sanity.xml $TE_PATH/tests/pytorch/distributed/test_sanity.py || test_fail "test_sanity.py"
 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_numerics.xml $TE_PATH/tests/pytorch/distributed/test_numerics.py || test_fail "test_numerics.py"
+python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_numerics_exact.xml $TE_PATH/tests/pytorch/distributed/test_numerics_exact.py || test_fail "test_numerics_exact.py"
 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops.xml $TE_PATH/tests/pytorch/distributed/test_fusible_ops.py || test_fail "test_fusible_ops.py"
 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_torch_fsdp2.xml $TE_PATH/tests/pytorch/distributed/test_torch_fsdp2.py || test_fail "test_torch_fsdp2.py"
 python3 -m pytest -v -s --log-cli-level=INFO --junitxml=$XML_LOG_DIR/pytest_test_comm_gemm_overlap.xml $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py || test_fail "test_comm_gemm_overlap.py"
@@ -47,9 +48,9 @@ python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cast_master_weights_
 : ${NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE:=$TE_PATH/tests/pytorch/debug/test_configs/dummy_feature.yaml}
 : ${NVTE_TEST_NVINSPECT_FEATURE_DIRS:=$TE_PATH/transformer_engine/debug/features}
-pytest -v -s $TE_PATH/tests/pytorch/debug/test_distributed.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || test_fail "debug test_distributed.py"
+pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_distributed.xml $TE_PATH/tests/pytorch/debug/test_distributed.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || test_fail "debug test_distributed.py"
 # standard numerics tests with initialized debug
-NVTE_TEST_NVINSPECT_ENABLED=True NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS pytest -v -s $TE_PATH/tests/pytorch/distributed/test_numerics.py || test_fail "debug test_numerics.py"
+NVTE_TEST_NVINSPECT_ENABLED=True NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_numerics_2.xml $TE_PATH/tests/pytorch/distributed/test_numerics.py || test_fail "debug test_numerics.py"
 if [ "$RET" -ne 0 ]; then
     echo "Error in the following test cases:$FAILED_CASES"
......
@@ -3,9 +3,11 @@
 # See LICENSE for license information.
-pip3 install onnxruntime==1.20.1
-pip3 install onnxruntime_extensions==0.13.0
+pip3 install onnxruntime
+pip3 install onnxruntime_extensions
 : ${TE_PATH:=/opt/transformerengine}
+: ${XML_LOG_DIR:=/logs}
+mkdir -p "$XML_LOG_DIR"
-python3 -m pytest --tb=auto $TE_PATH/tests/pytorch/test_onnx_export.py
+python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/test_onnx_export.xml $TE_PATH/tests/pytorch/test_onnx_export.py
@@ -11,6 +11,7 @@ list(APPEND test_cuda_sources
     test_cast_mxfp8_gated_swiglu.cu
     test_qdq.cu
     test_cast_mxfp8.cu
+    test_cast_nvfp4_transpose.cu
     test_cast_float8blockwise.cu
     test_dequantize_mxfp8.cu
     test_transpose.cu
@@ -66,6 +67,13 @@ else()
   add_executable(test_operator ${test_hip_sources})
 endif()
+# Add profiling and debug flags for CUDA compilation
+set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -lineinfo")          # Generate line info for device code
+set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -g")                 # Add debug symbols for host code
+set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --ptxas-options=-v") # Add info about register usage
+# Note: Using -lineinfo instead of -G to avoid conflicts and get line mapping
+
+# Find required packages
 find_package(OpenMP REQUIRED)
 if(USE_CUDA)
   list(APPEND test_operator_LINKER_LIBS CUDA::cudart GTest::gtest_main ${TE_LIB} CUDA::nvrtc CUDNN::cudnn)
......