Unverified Commit d75bf43f authored by Phuong Nguyen, committed by GitHub

[JAX] CollectiveGemm (#2166)



* init cgemm + unit tests

* UB bootstrap with NCCL, no MPI dependency

* add NVLINK-P2P check + error message

* skip tests if no NVLINK available

* use std::vector to store ncclComm_t

* update misuse of TP warning
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

---------
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
parent 4d145786
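For context, the bootstrap pieces introduced in this commit fit together roughly as follows on each process: initialize JAX distributed, disable the Shardy partitioner, then set up the CGEMM NCCL communicators. A minimal sketch with illustrative values (one GPU per process, mirroring the test helpers added below):

import jax
from transformer_engine.jax.cpp_extensions.gemm import collective_gemm_bootstrap

jax.distributed.initialize(
    coordinator_address="localhost:1234",  # illustrative address
    num_processes=8,
    process_id=0,           # unique per process
    local_device_ids=[0],   # global CUDA device owned by this process
)
jax.config.update("jax_use_shardy_partitioner", False)  # CollectiveGEMM does not work with Shardy yet

collective_gemm_bootstrap(
    num_total_devices=8,
    num_devices_per_process=1,
    process_id=0,
    tensor_parallel_size=4,
)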
......@@ -87,4 +87,5 @@ def setup_jax_extension(
sources=[str(path) for path in sources],
include_dirs=[str(path) for path in include_dirs],
extra_compile_args=cxx_flags,
libraries=["nccl"],
)
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Shared functions for the comm_overlap tests"""
import jax.numpy as jnp
import numpy as np
def dtype_tols(dtype, rtol=None, atol=None):
"""Expected numerical tolerance for a data type."""
# Return immediately if tolerances are fully specified
if rtol is not None and atol is not None:
return {"rtol": rtol, "atol": atol}
# Default tolerances for common dtypes
if dtype in [jnp.float32, "float32"]:
return {"rtol": 1e-5, "atol": 1e-8}
elif dtype in [jnp.float16, "float16"]:
return {"rtol": 1e-3, "atol": 1e-6}
elif dtype in [jnp.bfloat16, "bfloat16"]:
return {"rtol": 1e-2, "atol": 1e-5}
else:
return {"rtol": 1e-5, "atol": 1e-8}
def assert_allclose(
actual,
desired,
rtol=None,
atol=None,
dtype=None,
**kwargs,
):
"""Check if two tensors are close."""
# Infer data type if needed
if dtype is None:
if isinstance(actual, float):
dtype = "float32"
else:
dtype = actual.dtype
# Determine tolerances
tols = {}
if rtol is None or atol is None:
tols = dtype_tols(dtype)
if rtol is not None:
tols["rtol"] = rtol
if atol is not None:
tols["atol"] = atol
# Cast tensors to fp32
if not isinstance(actual, float):
actual = actual.astype(jnp.float32)
if not isinstance(desired, float):
desired = desired.astype(jnp.float32)
# Check if tensors are close
np.testing.assert_allclose(actual, desired, **tols, **kwargs)
def assert_allclose_print_index(ref_output, gathered_output, rtol=1e-5, atol=1e-8):
if not jnp.allclose(ref_output, gathered_output, rtol=rtol, atol=atol):
diff = jnp.abs(ref_output - gathered_output)
mask = diff > (atol + rtol * jnp.abs(gathered_output))
print(mask.astype(int))
print(jnp.where(mask, diff, 0))
# Shared constants for all tests
DP_AXIS = "data"
TPSP_AXIS = "tensor_sequence"
PARAMS_KEY = "params"
# Shared functions for distributed testing
import argparse
import jax
from jax.experimental import mesh_utils
from transformer_engine.jax.cpp_extensions.gemm import collective_gemm_bootstrap
# Global flag to track if distributed has been initialized
_distributed_initialized = False
def _is_distributed_initialized():
"""Check if JAX distributed has been initialized."""
return _distributed_initialized
def _initialize_distributed(args):
"""Initialize JAX distributed with custom arguments."""
global _distributed_initialized
# Check if already initialized
if _distributed_initialized:
return
if args.coordinator_address is None or args.num_processes is None or args.process_id is None:
raise ValueError(
"All distributed initialization arguments are required: "
"--coordinator-address, --num-processes, --process-id"
)
if args.local_device_ids is None:
assert (
args.num_devices_per_process is not None
), "Either local_device_ids or num_devices_per_process must be provided"
# Calculate device range for this process
# Single process single device: each process gets one unique device
# Single process multiple devices: each process gets a unique range of devices
start_device = args.process_id * args.num_devices_per_process
device_range = range(start_device, start_device + args.num_devices_per_process)
global_device_ids_for_this_process = ",".join(map(str, device_range))
else:
# Use explicitly provided global device IDs
global_device_ids_for_this_process = args.local_device_ids
args.num_devices_per_process = len(args.local_device_ids.split(","))
assert args.num_devices_per_process == 1, "Only a single GPU per process is currently supported!"
print(
f"Initializing JAX distributed with coordinator={args.coordinator_address}, "
f"num_processes={args.num_processes}, process_id={args.process_id}"
)
# Note: "local_device_ids" is a JAX term meaning "global CUDA devices managed by this process"
jax.distributed.initialize(
coordinator_address=args.coordinator_address,
num_processes=args.num_processes,
process_id=args.process_id,
local_device_ids=global_device_ids_for_this_process,
)
_distributed_initialized = True
jax.clear_caches()
jax.config.update(
"jax_use_shardy_partitioner", False
) # CollectiveGEMM does not work with Shardy yet
assert jax.local_device_count() == 1, (
f"[{args.process_id}|{args.num_devices_per_process}] Expected 1 GPU per process, found"
f" {jax.local_device_count()}"
)
devices_per_process = 1
num_total_devices = args.num_processes
print(
f"Initializing CGEMM communicator with num_total_devices={num_total_devices},"
f" devices_per_process={devices_per_process}, process_id={args.process_id}"
)
collective_gemm_bootstrap(
num_total_devices=num_total_devices,
num_devices_per_process=devices_per_process,
process_id=args.process_id,
tensor_parallel_size=args.tensor_parallel_size,
)
def _get_dp_and_tp_sizes(args):
num_gpu = args.num_processes * args.num_devices_per_process
if args.tensor_parallel_size is None:
num_gpu_dp = 2 if args.enable_data_parallel else 1
assert (
num_gpu > 1 and num_gpu % num_gpu_dp == 0
), "Number of GPUs must be greater than 1 and divisible by number of data parallel GPUs"
num_gpu_tp = num_gpu // num_gpu_dp
else:
num_gpu_tp = args.tensor_parallel_size
assert (
num_gpu > 1 and num_gpu % num_gpu_tp == 0
), "Number of GPUs must be greater than 1 and divisible by number of data parallel GPUs"
num_gpu_dp = num_gpu // num_gpu_tp
return num_gpu_dp, num_gpu_tp
def _create_mesh(args):
"""Create mesh configuration with proper validation."""
num_gpu = args.num_processes * args.num_devices_per_process
assert num_gpu == len(jax.devices()), "Number of GPUs must be equal to number of devices"
num_gpu_dp, num_gpu_tp = _get_dp_and_tp_sizes(args)
print(f"Using {num_gpu_dp}x{num_gpu_tp} mesh ({num_gpu_dp * num_gpu_tp} total GPUs)")
device_mesh = mesh_utils.create_device_mesh((num_gpu_dp, num_gpu_tp))
mesh = jax.sharding.Mesh(devices=device_mesh, axis_names=(DP_AXIS, TPSP_AXIS))
return mesh
def cgemm_parser(description="Collective GEMM test on multi-GPU with tensor parallelism"):
"""Create common argument parser for all collective GEMM tests."""
parser = argparse.ArgumentParser(description=description)
# Distributed initialization arguments
parser.add_argument(
"--coordinator-address",
type=str,
default=None,
help="Coordinator address for distributed initialization",
)
parser.add_argument(
"--num-processes",
type=int,
default=None,
help="Number of processes for distributed initialization",
)
parser.add_argument(
"--process-id", type=int, default=None, help="Process ID for distributed initialization"
)
parser.add_argument(
"--local-device-ids",
type=str,
default=None,
help="Local device IDs for distributed initialization (comma-separated)",
)
parser.add_argument(
"--num-devices-per-process", type=int, default=1, help="Number of devices per process"
)
# Test configuration arguments
parser.add_argument(
"--tensor-parallel-size", type=int, default=None, help="Tensor parallel size"
)
parser.add_argument("--batch-size", type=int, default=4, help="Batch size for testing")
parser.add_argument("--seq-len", type=int, default=8192, help="Sequence length for testing")
parser.add_argument("--hidden-in", type=int, default=4096, help="Input hidden dimension")
parser.add_argument("--hidden-out", type=int, default=8192, help="Output hidden dimension")
parser.add_argument(
"--collective-type",
type=str,
default="all_gather",
choices=["all_gather", "reduce_scatter"],
help="Type of collective operation",
)
parser.add_argument(
"--fp8-recipe", type=str, default="DelayedScaling", help="FP8 recipe to use"
)
parser.add_argument(
"--enable-data-parallel", action="store_true", help="Enable data parallelism"
)
parser.add_argument(
"--enable-result-check", action="store_true", default=True, help="Enable result checking"
)
return parser
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""config for collective_gemm tests"""
import pytest
def pytest_addoption(parser):
"""Pytest hook for collective_gemm tests"""
parser.addoption("--coordinator-address", action="store", default="localhost:12345")
parser.addoption("--num-processes", action="store", default=1)
parser.addoption("--process-id", action="store", default=0)
parser.addoption("--local-device-ids", action="store", default=None)
@pytest.fixture(autouse=True)
def distributed_args(request):
"""Fixture for querying distributed initialization arguments"""
if request.cls:
request.cls.coordinator_address = request.config.getoption("--coordinator-address")
request.cls.num_processes = int(request.config.getoption("--num-processes"))
request.cls.process_id = int(request.config.getoption("--process-id"))
request.cls.local_device_ids = request.config.getoption("--local-device-ids")
request.cls.num_devices_per_process = (
1
if request.cls.local_device_ids is None
else len(request.cls.local_device_ids.split(","))
)
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
NUM_GPUS=${NUM_GPUS:-$(nvidia-smi -L | wc -l)}
# Check if NVLINK is supported before running tests
echo "*** Checking NVLINK support***"
NVLINK_OUTPUT=$(nvidia-smi nvlink --status 2>&1)
NVLINK_EXIT_CODE=$?
# Check if command failed OR output indicates no NVLINK
if [ $NVLINK_EXIT_CODE -ne 0 ] || [[ "$NVLINK_OUTPUT" == *"not supported"* ]] || [[ "$NVLINK_OUTPUT" == *"No devices"* ]] || [ -z "$NVLINK_OUTPUT" ]; then
echo "NVLINK is not supported on this platform"
echo "Collective GEMM tests require NVLINK connectivity"
echo "SKIPPING all tests"
exit 0
else
echo "NVLINK support detected"
fi
# Define the test files to run
TEST_FILES=(
"test_gemm.py"
"test_dense_grad.py"
"test_layernorm_mlp_grad.py"
)
echo
echo "*** Executing tests in examples/jax/collective_gemm/ ***"
HAS_FAILURE=0 # Global failure flag
PIDS=() # Array to store all process PIDs
# Cleanup function to kill all processes
cleanup() {
for pid in "${PIDS[@]}"; do
if kill -0 "$pid" 2>/dev/null; then
echo "Killing process $pid"
kill -TERM "$pid" 2>/dev/null || true
fi
done
# Wait a bit and force kill if needed
sleep 2
for pid in "${PIDS[@]}"; do
if kill -0 "$pid" 2>/dev/null; then
echo "Force killing process $pid"
kill -KILL "$pid" 2>/dev/null || true
fi
done
}
# Set up signal handlers to cleanup on exit
trap cleanup EXIT INT TERM
# Run each test file across all GPUs
for TEST_FILE in "${TEST_FILES[@]}"; do
echo
echo "=== Starting test file: $TEST_FILE ..."
# Clear PIDs array for this test file
PIDS=()
for i in $(seq 0 $(($NUM_GPUS - 1))); do
# Define output file for logs
LOG_FILE="${TEST_FILE}_gpu_${i}.log"
if [ $i -eq 0 ]; then
# For process 0: show live output AND save to log file using tee
echo "=== Live output from process 0 ==="
pytest -s -c "$TE_PATH/tests/jax/pytest.ini" \
-vs "$TE_PATH/examples/jax/collective_gemm/$TEST_FILE" \
--num-processes=$NUM_GPUS \
--process-id=$i 2>&1 | tee "$LOG_FILE" &
PID=$!
PIDS+=($PID)
else
# For other processes: redirect to log files only
pytest -s -c "$TE_PATH/tests/jax/pytest.ini" \
-vs "$TE_PATH/examples/jax/collective_gemm/$TEST_FILE" \
--num-processes=$NUM_GPUS \
--process-id=$i > "$LOG_FILE" 2>&1 &
PID=$!
PIDS+=($PID)
fi
done
# Wait for all processes to finish
wait
# Check and print the log content from process 0 (now has log file thanks to tee)
if grep -q "SKIPPED" "${TEST_FILE}_gpu_0.log"; then
echo "... $TEST_FILE SKIPPED"
elif grep -q "FAILED" "${TEST_FILE}_gpu_0.log"; then
echo "... $TEST_FILE FAILED"
HAS_FAILURE=1
else
echo "... $TEST_FILE PASSED"
fi
# Remove the log files after processing them
wait
rm ${TEST_FILE}_gpu_*.log
done
wait
# Final cleanup (trap will also call cleanup on exit)
cleanup
exit $HAS_FAILURE
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Collective Dense Gradient test on multi-GPU with tensor parallelism"""
import argparse
import unittest
import os
import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec, NamedSharding
import flax
from common import (
assert_allclose,
_initialize_distributed,
_get_dp_and_tp_sizes,
_create_mesh,
DP_AXIS,
TPSP_AXIS,
PARAMS_KEY,
cgemm_parser,
)
from transformer_engine.jax.dense import dense
from transformer_engine.jax.quantize import fp8_autocast
from transformer_engine.jax.cpp_extensions.gemm import (
CollectiveOp,
CollectiveOpSet,
noop_collective_op_set,
)
from transformer_engine.jax.sharding import MeshResource
import transformer_engine.jax.flax as te_flax
def _get_logical_axes(collective_op):
if collective_op.is_all_gather:
input_axes = (DP_AXIS, TPSP_AXIS, None)
weight_axes = (None, TPSP_AXIS)
bias_axes = (TPSP_AXIS,)
output_axes = (DP_AXIS, None, TPSP_AXIS)
else: # RS
input_axes = (DP_AXIS, None, TPSP_AXIS)
weight_axes = (TPSP_AXIS, None)
bias_axes = (None,)
output_axes = (DP_AXIS, TPSP_AXIS, None)
return input_axes, weight_axes, bias_axes, output_axes
def _get_operand_sharding(mesh, collective_op):
input_axes, weight_axes, bias_axes, _ = _get_logical_axes(collective_op)
x_sharding = NamedSharding(mesh, PartitionSpec(*input_axes))
weight_sharding = NamedSharding(mesh, PartitionSpec(*weight_axes))
bias_sharding = NamedSharding(mesh, PartitionSpec(*bias_axes))
return x_sharding, weight_sharding, bias_sharding
def _mean_dense(x, weight, bias, input_axes, weight_axes, output_axes, collective_op_set):
output = dense(
x,
weight,
bias,
contracting_dims=((2,), (0,)),
input_axes=input_axes,
kernel_axes=weight_axes,
output_axes=output_axes,
collective_op_set=collective_op_set,
)
return jnp.mean(output.astype(jnp.float32))
def _value_and_grad_dense(x, weight, bias, input_axes, weight_axes, output_axes, collective_op_set):
return jax.jit(jax.value_and_grad(_mean_dense, (0, 1, 2)), static_argnums=(3, 4, 5, 6))(
x, weight, bias, input_axes, weight_axes, output_axes, collective_op_set
)
def run_dense_grad_tests(args, mesh=None):
"""Execute Dense Gradient tests."""
print(args)
_initialize_distributed(args)
mesh = mesh or _create_mesh(args)
# Create test data
rng = jax.random.PRNGKey(0)
rng, x_rng, weight_rng, bias_rng = jax.random.split(rng, 4)
x = jax.random.normal(
x_rng, (args.batch_size, args.seq_len, args.hidden_in), dtype=jnp.bfloat16
)
weight = jax.random.normal(weight_rng, (args.hidden_in, args.hidden_out), dtype=jnp.bfloat16)
bias = jax.random.normal(bias_rng, (args.hidden_out,), dtype=jnp.bfloat16)
collective_op = (
CollectiveOp.ALL_GATHER
if args.collective_type == "all_gather"
else CollectiveOp.REDUCE_SCATTER
)
collective_op_set = CollectiveOpSet.create(forward_collective_op=collective_op)
with mesh, fp8_autocast(
enabled=False,
fp8_recipe=None,
mesh_resource=MeshResource(dp_resource=DP_AXIS, tpsp_resource=TPSP_AXIS),
):
# Get the base axis rules and extend them with TE's rules. This must be done inside fp8_autocast
axis_rules = flax.linen.get_logical_axis_rules()
axis_rules += ((TPSP_AXIS, TPSP_AXIS), (DP_AXIS, DP_AXIS))
te_extended_axis_rules = te_flax.extend_logical_axis_rules(axis_rules)
with flax.linen.logical_axis_rules(te_extended_axis_rules):
x_sharding, weight_sharding, bias_sharding = _get_operand_sharding(mesh, collective_op)
x_sharded = jax.device_put(x, x_sharding)
weight_sharded = jax.device_put(weight, weight_sharding)
bias_sharded = jax.device_put(bias, bias_sharding)
input_axes, weight_axes, _, output_axes = _get_logical_axes(collective_op)
ref_output, ref_grads = _value_and_grad_dense(
x_sharded,
weight_sharded,
bias_sharded,
input_axes,
weight_axes,
output_axes,
noop_collective_op_set,
)
output, sharded_grads = _value_and_grad_dense(
x_sharded,
weight_sharded,
bias_sharded,
input_axes,
weight_axes,
output_axes,
collective_op_set,
)
jax.block_until_ready(ref_output)
jax.block_until_ready(output)
gathered_grads = []
gathered_ref_grads = []
for ref_grad, grad in zip(ref_grads, sharded_grads):
gathered_grads.append(
jax.lax.with_sharding_constraint(grad, NamedSharding(mesh, PartitionSpec(None)))
)
gathered_ref_grads.append(
jax.lax.with_sharding_constraint(ref_grad, NamedSharding(mesh, PartitionSpec(None)))
)
jax.block_until_ready(gathered_grads)
jax.block_until_ready(gathered_ref_grads)
if args.enable_result_check and args.process_id == 0:
assert_allclose(ref_output, output, dtype=jnp.bfloat16)
for ref_grad, gathered_grad in zip(gathered_ref_grads, gathered_grads):
assert_allclose(ref_grad, gathered_grad, dtype=jnp.bfloat16)
class TestCollectiveDenseGradient(unittest.TestCase):
"""Collective Dense Gradient unittests"""
def setUp(self):
self.args = cgemm_parser(
"Collective Dense Gradient test on multi-GPU with tensor parallelism"
).parse_args([])
self.args.coordinator_address = self.coordinator_address
self.args.num_processes = self.num_processes
self.args.process_id = self.process_id
self.args.local_device_ids = self.local_device_ids
self.args.num_devices_per_process = self.num_devices_per_process
self.args.enable_data_parallel = True
self.args.tensor_parallel_size = _get_dp_and_tp_sizes(self.args)[1]
_initialize_distributed(self.args)
# Create mesh once for all tests
self.mesh = _create_mesh(self.args)
jax.sharding.set_mesh(self.mesh)
self.args.enable_result_check = True
os.environ["NVTE_JAX_ALL_REDUCE_IN_FP32"] = "1"
def tearDown(self):
os.environ.pop("NVTE_JAX_ALL_REDUCE_IN_FP32", None)
def test_te_bf16_all_gather(self):
"""Test Collective Dense Gradient with AllGather"""
self.args.collective_type = "all_gather"
run_dense_grad_tests(self.args, self.mesh)
def test_te_bf16_reduce_scatter(self):
"""Test Collective Dense Gradient with ReduceScatter"""
self.args.collective_type = "reduce_scatter"
run_dense_grad_tests(self.args, self.mesh)
if __name__ == "__main__":
import sys
if len(sys.argv) < 7: # Need at least the 3 required distributed args
print("Error: This script requires distributed initialization arguments.")
print(
"Usage: python test_dense_grad.py --coordinator-address <address> --num-processes <num>"
" --process-id <id> [--local-device-ids <ids>] [other args]"
)
print(
"Example: python test_dense_grad.py --coordinator-address localhost:1234"
" --num-processes 4 --process-id 0"
)
print(
"Example: python test_dense_grad.py --coordinator-address localhost:1234"
" --num-processes 2 --process-id 0 --local-device-ids 0,1,2,3"
)
sys.exit(1)
args = cgemm_parser(
"Collective Dense Gradient test on multi-GPU with tensor parallelism"
).parse_args()
_initialize_distributed(args)
run_dense_grad_tests(args, mesh=None)
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Collective GEMM test on multi-GPU with tensor parallelism
This script uses custom distributed initialization with the following arguments:
- --coordinator-address: Coordinator address for distributed initialization
- --num-processes: Number of processes for distributed initialization
- --process-id: Process ID for distributed initialization
- --local-device-ids: Local device IDs for distributed initialization
Example:
python test_gemm.py --coordinator-address localhost:1234 --num-processes 2 --process-id 0 --local-device-ids 0,1,2,3
"""
import unittest
import os
from functools import partial
import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec, NamedSharding
from common import (
assert_allclose,
_initialize_distributed,
_get_dp_and_tp_sizes,
_create_mesh,
DP_AXIS,
TPSP_AXIS,
PARAMS_KEY,
cgemm_parser,
)
import transformer_engine.jax.cpp_extensions as tex
from transformer_engine.jax.quantize import fp8_autocast
from transformer_engine.jax.cpp_extensions.gemm import CollectiveOp
from transformer_engine.jax.sharding import MeshResource
def _get_operand_sharding(mesh, collective_op, is_with_dp):
dp_axis = DP_AXIS if is_with_dp else None
if collective_op == CollectiveOp.ALL_GATHER:
x_sharding = NamedSharding(mesh, PartitionSpec(dp_axis, TPSP_AXIS, None))
weight_sharding = NamedSharding(mesh, PartitionSpec(None, TPSP_AXIS))
bias_sharding = NamedSharding(mesh, PartitionSpec(TPSP_AXIS))
output_sharding = NamedSharding(mesh, PartitionSpec(dp_axis, None, TPSP_AXIS))
else: # RS
x_sharding = NamedSharding(mesh, PartitionSpec(dp_axis, None, TPSP_AXIS))
weight_sharding = NamedSharding(mesh, PartitionSpec(TPSP_AXIS, None))
bias_sharding = NamedSharding(mesh, PartitionSpec(None))
output_sharding = NamedSharding(mesh, PartitionSpec(dp_axis, TPSP_AXIS, None))
return x_sharding, weight_sharding, bias_sharding, output_sharding
def _get_dp_and_tp_sizes(args):
num_gpu = args.num_processes * args.num_devices_per_process
if args.tensor_parallel_size is None:
num_gpu_dp = 2 if args.enable_data_parallel else 1
assert (
num_gpu > 1 and num_gpu % num_gpu_dp == 0
), "Number of GPUs must be greater than 1 and divisible by number of data parallel GPUs"
num_gpu_tp = num_gpu // num_gpu_dp
else:
num_gpu_tp = args.tensor_parallel_size
assert (
num_gpu > 1 and num_gpu % num_gpu_tp == 0
), "Number of GPUs must be greater than 1 and divisible by number of data parallel GPUs"
num_gpu_dp = num_gpu // num_gpu_tp
return num_gpu_dp, num_gpu_tp
@partial(jax.jit, static_argnames=("contracting_dims", "collective_op", "output_sharding"))
def _jitted_cgemm(x, weight, bias, contracting_dims, collective_op, output_sharding):
output = tex.gemm(
x,
weight,
bias=bias,
contracting_dims=contracting_dims,
collective_op=collective_op,
)
if output_sharding is not None:
output = jax.lax.with_sharding_constraint(output, output_sharding)
return output
def run_gemm_tests(args, mesh=None):
"""Execute GEMM tests."""
print(args)
# Collective GEMM requires Shardy partitioner to be disabled
jax.config.update("jax_use_shardy_partitioner", False)
# Initialize distributed with provided arguments
_initialize_distributed(args)
mesh = mesh or _create_mesh(args)
# Create test data
rng = jax.random.PRNGKey(0)
rng, x_rng, weight_rng, bias_rng = jax.random.split(rng, 4)
x = jax.random.normal(
x_rng, (args.batch_size, args.seq_len, args.hidden_in), dtype=jnp.bfloat16
)
weight = jax.random.normal(weight_rng, (args.hidden_in, args.hidden_out), dtype=jnp.bfloat16)
bias = jax.random.normal(bias_rng, (args.hidden_out,), dtype=jnp.bfloat16)
collective_op = (
CollectiveOp.ALL_GATHER
if args.collective_type == "all_gather"
else CollectiveOp.REDUCE_SCATTER
)
with mesh, fp8_autocast(
enabled=False,
fp8_recipe=None,
mesh_resource=MeshResource(dp_resource=DP_AXIS, tpsp_resource=TPSP_AXIS),
):
print(f"Device mesh: {mesh}")
x_sharding, weight_sharding, bias_sharding, output_sharding = _get_operand_sharding(
mesh, collective_op, args.enable_data_parallel
)
x_sharded = jax.device_put(x, x_sharding)
weight_sharded = jax.device_put(weight, weight_sharding)
bias_sharded = jax.device_put(bias, bias_sharding)
ref_output = _jitted_cgemm(
x_sharded,
weight_sharded,
bias_sharded,
contracting_dims=((2,), (0,)),
collective_op=CollectiveOp.NONE,
output_sharding=output_sharding,
)
output = _jitted_cgemm(
x_sharded,
weight_sharded,
bias_sharded,
contracting_dims=((2,), (0,)),
collective_op=collective_op,
# CollectiveGEMM output should already have the correct sharding without applying a sharding constraint
output_sharding=None,
)
assert (
ref_output.sharding == output.sharding
), f"ref_output.sharding={ref_output.sharding}, output.sharding={output.sharding}"
gathered_ref_output = jax.lax.with_sharding_constraint(
ref_output, NamedSharding(mesh, PartitionSpec(None))
)
gathered_output = jax.lax.with_sharding_constraint(
output, NamedSharding(mesh, PartitionSpec(None))
)
jax.block_until_ready(gathered_ref_output)
jax.block_until_ready(gathered_output)
if args.enable_result_check and args.process_id == 0:
assert_allclose(gathered_ref_output, gathered_output)
class TestCollectiveGemmWithDP(unittest.TestCase):
"""Collective GEMM with DP unittests"""
def setUp(self):
self.args = cgemm_parser(
"Collective GEMM test on multi-GPU with tensor parallelism"
).parse_args([])
self.args.coordinator_address = self.coordinator_address
self.args.num_processes = self.num_processes
self.args.process_id = self.process_id
self.args.local_device_ids = self.local_device_ids
self.args.num_devices_per_process = self.num_devices_per_process
self.args.enable_data_parallel = True
self.args.tensor_parallel_size = _get_dp_and_tp_sizes(self.args)[1]
_initialize_distributed(self.args)
self.mesh = _create_mesh(self.args)
jax.sharding.set_mesh(self.mesh)
self.args.enable_result_check = True
os.environ["NVTE_JAX_ALL_REDUCE_IN_FP32"] = "1"
def tearDown(self):
os.environ.pop("NVTE_JAX_ALL_REDUCE_IN_FP32", None)
def test_te_bf16_all_gather_with_dp(self):
"""Test Collective GEMM with AllGather"""
self.args.collective_type = "all_gather"
run_gemm_tests(self.args, self.mesh)
def test_te_bf16_reduce_scatter_with_dp(self):
"""Test Collective GEMM with ReduceScatter"""
self.args.collective_type = "reduce_scatter"
run_gemm_tests(self.args, self.mesh)
if __name__ == "__main__":
import sys
if len(sys.argv) < 7: # Need at least the 3 required distributed args
print("Error: This script requires distributed initialization arguments.")
print(
"Usage: python test_gemm.py --coordinator-address <address> --num-processes <num>"
" --process-id <id> [--local-device-ids <ids>] [other args]"
)
sys.exit(1)
args = cgemm_parser("Collective GEMM test on multi-GPU with tensor parallelism").parse_args()
_initialize_distributed(args)
run_gemm_tests(args, mesh=None)
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Collective Dense Gradient test on multi-GPU with tensor parallelism"""
import argparse
import unittest
import os
import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec, NamedSharding
import flax
from common import (
assert_allclose,
_initialize_distributed,
_get_dp_and_tp_sizes,
_create_mesh,
DP_AXIS,
TPSP_AXIS,
PARAMS_KEY,
cgemm_parser,
)
from transformer_engine.jax.layernorm_mlp import layernorm_mlp
from transformer_engine.jax.quantize import fp8_autocast
from transformer_engine.jax.cpp_extensions.gemm import (
CollectiveOpSet,
CollectiveOp,
noop_collective_op_set,
)
from transformer_engine.jax.sharding import MeshResource
import transformer_engine.jax.flax as te_flax
def _get_logical_axes():
input_1_axes = (DP_AXIS, TPSP_AXIS, None)
weight_1_axes = (None, None, TPSP_AXIS)
bias_axes_1 = (None, TPSP_AXIS)
input_2_axes = (DP_AXIS, None, TPSP_AXIS)
weight_2_axes = (TPSP_AXIS, None)
bias_axes_2 = (None,)
return input_1_axes, weight_1_axes, bias_axes_1, input_2_axes, weight_2_axes, bias_axes_2
def _get_operand_sharding(mesh):
input_1_axes, weight_1_axes, bias_axes_1, input_2_axes, weight_2_axes, bias_axes_2 = (
_get_logical_axes()
)
x_sharding = NamedSharding(mesh, PartitionSpec(*input_1_axes))
weight_1_sharding = NamedSharding(mesh, PartitionSpec(*weight_1_axes))
bias_1_sharding = NamedSharding(mesh, PartitionSpec(*bias_axes_1))
weight_2_sharding = NamedSharding(mesh, PartitionSpec(*weight_2_axes))
bias_2_sharding = NamedSharding(mesh, PartitionSpec(*bias_axes_2))
return x_sharding, weight_1_sharding, bias_1_sharding, weight_2_sharding, bias_2_sharding
def _mean_layernorm_mlp(
x,
weight_1,
bias_1,
weight_2,
bias_2,
gamma,
input_1_axes,
input_2_axes,
weight_1_axes,
weight_2_axes,
collective_op_sets,
):
output = layernorm_mlp(
x,
gamma,
beta=None,
kernels=[weight_1, weight_2],
biases=[bias_1, bias_2],
norm_type="rmsnorm",
dot_1_input_axes=input_1_axes,
dot_2_input_axes=input_2_axes,
kernel_1_axes=weight_1_axes,
kernel_2_axes=weight_2_axes,
activation_type=("gelu",),
collective_op_sets=collective_op_sets,
)
return jnp.mean(output)
def _value_and_grad_layernorm_mlp(
x,
weight_1,
bias_1,
weight_2,
bias_2,
gamma,
input_1_axes,
input_2_axes,
weight_1_axes,
weight_2_axes,
collective_op_sets,
):
return jax.jit(
jax.value_and_grad(_mean_layernorm_mlp, (0, 1, 2, 3, 4, 5)), static_argnums=(6, 7, 8, 9, 10)
)(
x,
weight_1,
bias_1,
weight_2,
bias_2,
gamma,
input_1_axes,
input_2_axes,
weight_1_axes,
weight_2_axes,
collective_op_sets,
)
def run_layernorm_mlp_grad_tests(args, mesh=None):
"""Execute Dense Gradient tests."""
print(args)
# Collective GEMM requires Shardy partitioner to be disabled
jax.config.update("jax_use_shardy_partitioner", False)
# Initialize distributed with provided arguments
_initialize_distributed(args)
mesh = mesh or _create_mesh(args)
# Create test data
rng = jax.random.PRNGKey(0)
rng, x_rng, weight_1_rng, bias_1_rng, weight_2_rng, bias_2_rng, gamma_rng = jax.random.split(
rng, 7
)
x = jax.random.normal(
x_rng, (args.batch_size, args.seq_len, args.hidden_in), dtype=jnp.bfloat16
)
weight_1 = jax.random.normal(
weight_1_rng, (args.hidden_in, 1, args.hidden_out), dtype=jnp.bfloat16
) / jnp.sqrt(args.hidden_in)
bias_1 = jax.random.normal(bias_1_rng, (1, args.hidden_out), dtype=jnp.bfloat16)
weight_2 = jax.random.normal(
weight_2_rng, (args.hidden_out, args.hidden_in), dtype=jnp.bfloat16
) / jnp.sqrt(args.hidden_out)
bias_2 = jax.random.normal(bias_2_rng, (args.hidden_in,), dtype=jnp.bfloat16)
gamma = jax.random.normal(gamma_rng, (args.hidden_in,), dtype=jnp.bfloat16) / jnp.sqrt(
args.hidden_in
)
collective_op_set_1 = CollectiveOpSet.create(forward_collective_op=CollectiveOp.ALL_GATHER)
collective_op_set_2 = CollectiveOpSet.create(forward_collective_op=CollectiveOp.REDUCE_SCATTER)
collective_op_sets = (collective_op_set_1, collective_op_set_2)
noop_collective_op_sets = (noop_collective_op_set, noop_collective_op_set)
with mesh, fp8_autocast(
enabled=False,
fp8_recipe=None,
mesh_resource=MeshResource(dp_resource=DP_AXIS, tpsp_resource=TPSP_AXIS),
):
# Get the base axis rules and extend them with TE's rules. This must be done inside fp8_autocast
axis_rules = flax.linen.get_logical_axis_rules()
axis_rules += ((TPSP_AXIS, TPSP_AXIS), (DP_AXIS, DP_AXIS))
te_extended_axis_rules = te_flax.extend_logical_axis_rules(axis_rules)
with flax.linen.logical_axis_rules(te_extended_axis_rules):
x_sharding, weight_1_sharding, bias_1_sharding, weight_2_sharding, bias_2_sharding = (
_get_operand_sharding(mesh)
)
x_sharded = jax.device_put(x, x_sharding)
weight_1_sharded = jax.device_put(weight_1, weight_1_sharding)
bias_1_sharded = jax.device_put(bias_1, bias_1_sharding)
weight_2_sharded = jax.device_put(weight_2, weight_2_sharding)
bias_2_sharded = jax.device_put(bias_2, bias_2_sharding)
input_1_axes, weight_1_axes, _, input_2_axes, weight_2_axes, _ = _get_logical_axes()
ref_output, ref_grads = _value_and_grad_layernorm_mlp(
x_sharded,
weight_1_sharded,
bias_1_sharded,
weight_2_sharded,
bias_2_sharded,
gamma,
input_1_axes,
input_2_axes,
weight_1_axes,
weight_2_axes,
noop_collective_op_sets,
)
output, sharded_grads = _value_and_grad_layernorm_mlp(
x_sharded,
weight_1_sharded,
bias_1_sharded,
weight_2_sharded,
bias_2_sharded,
gamma,
input_1_axes,
input_2_axes,
weight_1_axes,
weight_2_axes,
collective_op_sets,
)
jax.block_until_ready(ref_output)
jax.block_until_ready(output)
gathered_grads = []
gathered_ref_grads = []
for ref_grad, grad in zip(ref_grads, sharded_grads):
gathered_grads.append(
jax.lax.with_sharding_constraint(grad, NamedSharding(mesh, PartitionSpec(None)))
)
gathered_ref_grads.append(
jax.lax.with_sharding_constraint(ref_grad, NamedSharding(mesh, PartitionSpec(None)))
)
jax.block_until_ready(gathered_grads)
jax.block_until_ready(gathered_ref_grads)
if args.enable_result_check and args.process_id == 0:
assert_allclose(ref_output, output, dtype=jnp.bfloat16)
for ref_grad, gathered_grad in zip(gathered_ref_grads, gathered_grads):
assert_allclose(ref_grad, gathered_grad, dtype=jnp.bfloat16)
class TestCollectiveLayerNormMLPGradient(unittest.TestCase):
"""Collective Dense Gradient unittests"""
def setUp(self):
self.args = cgemm_parser(
"Collective LayerNorm MLP Gradient test on multi-GPU with tensor parallelism"
).parse_args([])
self.args.coordinator_address = self.coordinator_address
self.args.num_processes = self.num_processes
self.args.process_id = self.process_id
self.args.local_device_ids = self.local_device_ids
self.args.num_devices_per_process = self.num_devices_per_process
self.args.enable_data_parallel = True
self.args.tensor_parallel_size = _get_dp_and_tp_sizes(self.args)[1]
_initialize_distributed(self.args)
# Create mesh once for all tests
self.mesh = _create_mesh(self.args)
jax.sharding.set_mesh(self.mesh)
self.args.enable_result_check = True
os.environ["NVTE_JAX_ALL_REDUCE_IN_FP32"] = "1"
def tearDown(self):
os.environ.pop("NVTE_JAX_ALL_REDUCE_IN_FP32", None)
def test_te_bf16_layernorm_mlp_grad(self):
"""Test Collective Dense Gradient with AllGather"""
run_layernorm_mlp_grad_tests(self.args, self.mesh)
if __name__ == "__main__":
import sys
if len(sys.argv) < 7: # Need at least the 3 required distributed args
print("Error: This script requires distributed initialization arguments.")
print(
"Usage: python test_layernorm_mlp_grad.py --coordinator-address <address>"
" --num-processes <num> --process-id <id> [--local-device-ids <ids>] [other args]"
)
print(
"Example: python test_layernorm_mlp_grad.py --coordinator-address localhost:1234"
" --num-processes 4 --process-id 0"
)
print(
"Example: python test_layernorm_mlp_grad.py --coordinator-address localhost:1234"
" --num-processes 2 --process-id 0 --local-device-ids 0,1,2,3"
)
sys.exit(1)
args = cgemm_parser(
"Collective LayerNorm MLP Gradient test on multi-GPU with tensor parallelism"
).parse_args()
_initialize_distributed(args)
run_layernorm_mlp_grad_tests(args, mesh=None)
......@@ -29,6 +29,10 @@ wait
python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_model_parallel_encoder.xml $TE_PATH/examples/jax/encoder/test_model_parallel_encoder.py || test_fail "test_model_parallel_encoder.py"
wait
TE_PATH=$TE_PATH bash $TE_PATH/examples/jax/encoder/run_test_multiprocessing_encoder.sh || test_fail "run_test_multiprocessing_encoder.sh"
wait
TE_PATH=$TE_PATH bash $TE_PATH/examples/jax/collective_gemm/run_test_cgemm.sh || test_fail "run_test_cgemm.sh"
wait
if [ $RET -ne 0 ]; then
echo "Error: some sub-tests failed: $FAILED_CASES"
......
......@@ -64,6 +64,15 @@ CommOverlapCore::CommOverlapCore(int myrank, int numranks, int mylocal, int numl
#endif
_comm_created = true;
}
initialize(tp_size, num_splits, num_max_streams, comm_cga_size, gemm_priority, comm_priority,
num_comm_sm, set_sm_margin, use_ce, atomic_gemm);
}
void CommOverlapCore::initialize(int tp_size, int num_splits, int num_max_streams,
int comm_cga_size, int gemm_priority, int comm_priority,
int num_comm_sm, bool set_sm_margin, bool use_ce,
bool atomic_gemm) {
_use_ce = static_cast<int>(use_ce);
_num_comm_sm = num_comm_sm;
_cga_size = comm_cga_size;
......@@ -278,6 +287,11 @@ CommOverlapBase::CommOverlapBase(const std::vector<size_t> &buffer_shape, DType
allgather_handle, barrier_handle, num_splits, num_max_streams, comm_cga_size,
gemm_priority, comm_priority, num_comm_sm, set_sm_margin, false,
atomic_gemm) {
initialize(buffer_shape, buffer_dtype, rs_overlap_first_gemm);
}
void CommOverlapBase::initialize(const std::vector<size_t> &buffer_shape, DType buffer_dtype,
bool rs_overlap_first_gemm) {
_rs_overlap_first_gemm = rs_overlap_first_gemm;
_rs_kernel_type = getenv<int>("NVTE_RS_STRIDED_ATOMIC", 0);
NVTE_CHECK(_rs_kernel_type >= 0 && _rs_kernel_type <= 3,
......@@ -288,7 +302,9 @@ CommOverlapBase::CommOverlapBase(const std::vector<size_t> &buffer_shape, DType
size_t buffer_bytes = get_buffer_size_bytes(buffer_shape[0], buffer_shape[1], buffer_dtype);
void *buffer_ptr;
_ub_reg = register_user_buffer_collective(&buffer_ptr, buffer_bytes, _ub_comm, true);
if (_ub_comm->myrank == 0) printf("!!! [UB] Register UBuf %d\n", _ub_reg);
if (_ub_comm->myrank == 0) {
printf("!!! [UB] Register UBuf %d\n", _ub_reg);
}
_ubuf = TensorWrapper(buffer_ptr, buffer_shape, buffer_dtype);
NVTE_CHECK_CUDA(
......@@ -640,6 +656,11 @@ CommOverlapP2PBase::CommOverlapP2PBase(const std::vector<size_t> &buffer_shape,
allgather_handle, barrier_handle, tp_size, num_max_streams, comm_cga_size,
gemm_priority, comm_priority, num_comm_sm, set_sm_margin, use_ce,
atomic_gemm) {
initialize(buffer_shape, buffer_dtype, comm_type, aggregate);
}
void CommOverlapP2PBase::initialize(const std::vector<size_t> &buffer_shape, DType buffer_dtype,
CommOverlapType comm_type, bool aggregate) {
_is_p2p = true;
_is_reduce_scatter = comm_type == CommOverlapType::RS;
_aggregate = aggregate;
......@@ -647,28 +668,28 @@ CommOverlapP2PBase::CommOverlapP2PBase(const std::vector<size_t> &buffer_shape,
// Create workspace tensor with userbuffer
NVTE_CHECK(buffer_shape.size() == 2, "Userbuffer shape must be 2-dimensional!");
size_t buffer_bytes = get_buffer_size_bytes(buffer_shape[0], buffer_shape[1], buffer_dtype);
int buffer_chunk_bytes = buffer_bytes / tp_size;
_num_ubuf_chunks = tp_size;
int buffer_chunk_bytes = buffer_bytes / _tp_size;
_num_ubuf_chunks = _tp_size;
if (_is_reduce_scatter) {
// GEMM + RS overlap: Allocate `2 x tp_size - 1` buffers to hold received GEMM chunk
// outputs for reduction at the end of the pipelining.
buffer_bytes = buffer_bytes / tp_size * (tp_size * 2 - 1);
_num_ubuf_chunks = tp_size * 2 - 1;
buffer_bytes = buffer_bytes / _tp_size * (_tp_size * 2 - 1);
_num_ubuf_chunks = _tp_size * 2 - 1;
}
void *buffer_ptr;
_ub_reg = register_user_buffer_collective(&buffer_ptr, buffer_bytes, _ub_comm, true);
if (_rank == 0) printf("!!! [UBP2P] Register UBuf %d\n", _ub_reg);
if (_rank == 0) printf("!!! [UBP2P] UBuf %d\n", _ub_reg);
_ubuf = TensorWrapper(
buffer_ptr,
std::vector<size_t>{buffer_shape[0] / tp_size * _num_ubuf_chunks, buffer_shape[1]},
std::vector<size_t>{buffer_shape[0] / _tp_size * _num_ubuf_chunks, buffer_shape[1]},
buffer_dtype);
// Create tensor chunks for easy management
char *ubuf_byte_ptr = reinterpret_cast<char *>(buffer_ptr);
for (int i = 0; i < _num_ubuf_chunks; i++) {
_ubufs.push_back(TensorWrapper(reinterpret_cast<void *>(ubuf_byte_ptr),
std::vector<size_t>{buffer_shape[0] / tp_size, buffer_shape[1]},
std::vector<size_t>{buffer_shape[0] / _tp_size, buffer_shape[1]},
buffer_dtype));
ubuf_byte_ptr += buffer_chunk_bytes;
}
......@@ -691,7 +712,7 @@ CommOverlapP2PBase::CommOverlapP2PBase(const std::vector<size_t> &buffer_shape,
NVTE_CHECK_CUDA(cudaMemset(_counter.dptr(), 0, sizeof(int32_t)));
}
for (int i = 0; i < std::min(num_max_streams, _tp_size); i++) {
for (int i = 0; i < _stream_compute.size(); i++) {
cudaStream_t stream;
NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&stream, cudaStreamNonBlocking, _comm_priority));
_stream_send.push_back(std::move(stream));
......@@ -711,6 +732,38 @@ CommOverlapP2PBase::~CommOverlapP2PBase() {
}
}
void CommOverlapP2PBase::copy_into_buffer(cudaStream_t stream, const TensorWrapper &source,
bool local_chunk, bool rowwise) {
// Check element size
const size_t element_size = source.element_size();
NVTE_CHECK(_ubuf.element_size() == element_size,
"Tried to copy data into a Userbuffers buffer but dtypes are not compatible ",
"(source dtype has ", element_size, " bytes, UB dtype has ", _ubuf.element_size(),
" bytes)");
// Input data
const size_t source_size = source.numel();
const void *src_ptr = (rowwise) ? source.dptr() : source.columnwise_dptr();
// Userbuffers data
void *dst_ptr;
if (local_chunk) {
NVTE_CHECK(_ubufs[_tp_id].numel() == source_size,
"Tried to copy an invalid tensor into a local chunk of a Userbuffers buffer ",
"(source_size=", source_size, ", local_ubuf_size=", _ubufs[_tp_id].numel(), ")");
dst_ptr = _ubufs[_tp_id].dptr();
} else {
NVTE_CHECK(_ubuf.numel() == source_size,
"Tried to copy an invalid tensor into a Userbuffers buffer ",
"(source_size=", source_size, ", ubuf_size=", _ubuf.numel(), ")");
dst_ptr = _ubuf.dptr();
}
// Copy data
NVTE_CHECK_CUDA(cudaMemcpyAsync(dst_ptr, src_ptr, source_size * element_size,
cudaMemcpyDeviceToDevice, stream));
}
TensorWrapper CommOverlapP2PBase::get_buffer_chunk_by_id(const TensorWrapper &source,
size_t chunk_id) {
// Start with a chunk of the source tensor
......@@ -851,6 +904,15 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
const bool do_gelu = pre_gelu_out.numel() > 0;
size_t workspace_size_chunk = workspace.numel() / _stream_compute.size();
// Check B copy sizing
if (B_copy.numel() > 0) {
NVTE_CHECK(B_copy.numel() == _ubuf.numel(), "Expected all-gathered B copy buffer with ",
_ubuf.numel(), " elements but got ", B_copy.numel());
NVTE_CHECK(B_copy.element_size() == _ubuf.element_size(),
"Expected all-gathered B copy buffer with ", _ubuf.element_size() * 8,
"-bit data type but got ", B_copy.element_size() * 8, "-bit");
}
NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, stream_main));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send[0], _start_compute, 0));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_recv, _start_compute, 0));
......@@ -919,12 +981,6 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send[0], _stop_recv, 0));
NVTE_CHECK_CUDA(
cudaStreamWaitEvent(_stream_compute[(i + 1) % _stream_compute.size()], _stop_recv, 0));
} else if (B_copy.numel() > 0) {
assert(B_copy.numel() == _ubufs[_tp_id].numel());
assert(B_copy.element_size() == _ubufs[_tp_id].element_size());
NVTE_CHECK_CUDA(cudaMemcpyAsync(B_copy.dptr(), _ubufs[_tp_id].dptr(),
_ubufs[_tp_id].bytes(), cudaMemcpyDeviceToDevice,
_stream_send[0]));
}
}
} else {
......@@ -972,16 +1028,16 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send[0], _stop_recv, 0));
NVTE_CHECK_CUDA(
cudaStreamWaitEvent(_stream_compute[(i + 1) % _stream_compute.size()], _stop_recv, 0));
} else if (B_copy.numel() > 0) {
assert(B_copy.numel() == _ubufs[_tp_id].numel());
assert(B_copy.element_size() == _ubufs[_tp_id].element_size());
NVTE_CHECK_CUDA(cudaMemcpyAsync(B_copy.dptr(), _ubufs[_tp_id].dptr(),
_ubufs[_tp_id].bytes(), cudaMemcpyDeviceToDevice,
_stream_send[0]));
}
}
}
// Copy all-gathered B from communication buffer into auxiliary output
if (B_copy.numel() > 0) {
NVTE_CHECK_CUDA(cudaMemcpyAsync(B_copy.dptr(), _ubuf.dptr(), _ubuf.bytes(),
cudaMemcpyDeviceToDevice, _stream_send[0]));
}
_ub_comm->sms = ori_sms;
for (size_t i = 0; i < _stream_compute.size(); i++) {
NVTE_CHECK_CUDA(cudaEventRecord(_stop_compute, _stream_compute[i]));
......
......@@ -670,9 +670,36 @@ int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator *
reinterpret_cast<void *>(&memhndl), sizeof(cudaIpcMemHandle_t),
comm->comm_intra);
// Check for NVLINK support before attempting IPC operations
if (comm->nvsize > 1) {
int current_device;
NVTE_CHECK_CUDA(cudaGetDevice(&current_device));
cudaDeviceProp deviceProp;
NVTE_CHECK_CUDA(cudaGetDeviceProperties(&deviceProp, current_device));
bool peer_access_available = false;
for (int i = 0; i < comm->nvsize; i++) {
if (i != comm->nvrank) {
NVTE_CHECK_CUDA(cudaIpcOpenMemHandle(&(comm->peer_ptr[hndl][i]), tmp[i], // NOLINT(*)
int can_access_peer;
cudaError_t peer_result = cudaDeviceCanAccessPeer(&can_access_peer, current_device, i);
if (peer_result == cudaSuccess && can_access_peer) {
peer_access_available = true;
break;
}
}
}
if (!peer_access_available) {
free(tmp);
NVTE_ERROR(
"No peer-to-peer access available between GPUs. This platform does not support the "
"GPU-to-GPU "
"communication required for multi-GPU userbuffers. Consider using single-GPU mode.");
return 1;
}
}
for (int i = 0; i < comm->nvsize; i++) {
if (i != comm->nvrank) {
NVTE_CHECK_CUDA(cudaIpcOpenMemHandle(&(comm->peer_ptr[hndl][i]), tmp[i],
cudaIpcMemLazyEnablePeerAccess));
}
}
......@@ -693,4 +720,5 @@ int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator *
comm->mem_ptr[hndl] = *gpubuff;
return comm->free_region++;
printf("***** Returning *****\n");
}
......@@ -67,6 +67,11 @@ class CommOverlapCore {
std::vector<cudaStream_t> _stream_compute;
cudaEvent_t _start_compute, _stop_compute, _start_comm, _stop_comm, _comm_launch_event;
private:
void initialize(int tp_size, int num_splits, int num_max_streams, int comm_cga_size,
int gemm_priority, int comm_priority, int num_comm_sm, bool set_sm_margin,
bool use_ce, bool atomic_gemm);
public:
CommOverlapCore() {} // dummy constructor for exposing type to Python
......@@ -78,17 +83,26 @@ class CommOverlapCore {
virtual ~CommOverlapCore();
void *get_ubuf_dptr() { return _ubuf.dptr(); }
void set_ubuf_scale_inv(float *scale_inv) {
_ubuf_scale_inv = scale_inv;
_ubuf_scale_inv_initialized = true;
}
virtual void copy_into_buffer(cudaStream_t stream, const TensorWrapper &source, bool local_chunk,
bool rowwise = true) {
NVTE_ERROR("Operation is not implemented.");
}
TensorWrapper get_tensor_chunk(const TensorWrapper &source, size_t offset,
const std::vector<size_t> &shape);
TensorWrapper get_buffer_chunk_like(const TensorWrapper &source, size_t offset,
const std::vector<size_t> &shape);
int get_tp_size() { return _tp_size; }
bool is_atomic_gemm() { return _atomic_gemm; }
bool is_p2p_overlap() { return _is_p2p; }
......@@ -148,6 +162,10 @@ class CommOverlapBase : public CommOverlapCore {
cudaStream_t _stream_comm;
cudaEvent_t _start_d2dcopy;
private:
void initialize(const std::vector<size_t> &buffer_shape, DType buffer_dtype,
bool rs_overlap_first_gemm);
public:
CommOverlapBase() {} // dummy constructor for exposing type to Python
......@@ -224,6 +242,10 @@ class CommOverlapP2PBase : public CommOverlapCore {
cudaStream_t _stream_recv;
cudaEvent_t _stop_send, _stop_recv;
private:
void initialize(const std::vector<size_t> &buffer_shape, DType buffer_dtype,
CommOverlapType comm_type, bool aggregate);
public:
CommOverlapP2PBase() {} // dummy constructor for exposing type to Python
......@@ -237,6 +259,9 @@ class CommOverlapP2PBase : public CommOverlapCore {
virtual ~CommOverlapP2PBase();
void copy_into_buffer(cudaStream_t stream, const TensorWrapper &source, bool local_chunk,
bool rowwise = true) override;
TensorWrapper get_buffer_chunk_by_id(const TensorWrapper &source, size_t buffer_id);
void bulk_overlap(const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb,
......
......@@ -12,6 +12,8 @@
#include <cudnn.h>
#include <nvrtc.h>
#include "nccl.h"
#ifdef NVTE_WITH_CUBLASMP
#include <cublasmp.h>
#endif // NVTE_WITH_CUBLASMP
......@@ -104,4 +106,12 @@
#endif // NVTE_WITH_CUBLASMP
#define NVTE_CHECK_NCCL(expr) \
do { \
const ncclResult_t status_NVTE_CHECK_NCCL = (expr); \
if (status_NVTE_CHECK_NCCL != ncclSuccess) { \
NVTE_ERROR("NCCL Error: ", ncclGetErrorString(status_NVTE_CHECK_NCCL)); \
} \
} while (false)
#endif // TRANSFORMER_ENGINE_COMMON_UTIL_LOGGING_H_
......@@ -6,8 +6,10 @@
import math
import operator
from collections.abc import Iterable
from typing import Tuple, Sequence, Union
from dataclasses import dataclass
from functools import partial, reduce
from typing import Tuple, Sequence, Union
from enum import Enum
import warnings
import jax
......@@ -16,8 +18,13 @@ from jax import dtypes
from jax.sharding import NamedSharding, PartitionSpec
from jax.experimental.custom_partitioning import SdyShardingRule
import transformer_engine_jax as tex
from transformer_engine_jax import get_num_compute_streams
from transformer_engine_jax import (
get_num_compute_streams,
JAXX_Collective_Op,
get_device_compute_capability,
initialize_cgemm_communicator,
get_cgemm_num_max_streams,
)
from .base import BasePrimitive, register_primitive
from .quantization import grouped_quantize
......@@ -37,11 +44,19 @@ from ..quantize import (
is_fp8_gemm_with_all_layouts_supported,
apply_padding_to_scale_inv,
)
from ..sharding import global_mesh_resource
from .misc import get_padded_spec
from .misc import get_padded_spec, is_all_reduce_in_float32
from ..sharding import (
global_mesh_resource,
tpsp_axis_size,
dp_or_fsdp_axis_size,
)
__all__ = [
"CollectiveOp",
"CollectiveOpSet",
"collective_gemm_bootstrap",
"noop_collective_op_set",
"gemm",
"grouped_gemm",
"gemm_uses_jax_dot",
......@@ -56,7 +71,7 @@ num_cublas_streams = get_num_compute_streams()
def get_cublas_workspace_size_bytes() -> int:
"""Return 32 MiB if using Hopper or newer, 4 MiB for all other architectures."""
if tex.get_device_compute_capability(0) >= 90:
if get_device_compute_capability(0) >= 90:
return 33_554_432
return 4_194_304
......@@ -152,6 +167,161 @@ def _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_
return lhs_q, rhs_q
def collective_gemm_bootstrap(
num_total_devices,
num_devices_per_process,
process_id,
tensor_parallel_size,
num_max_streams=3,
compute_stream_priority=0,
communication_stream_priority=0,
num_sm_for_communication=2,
use_ce=True,
aggregate_all_gather=False,
):
"""Initialize NCCL communicators for Collective GEMM operations.
This function sets up the distributed communication infrastructure needed for
tensor parallel collective GEMM operations. It supports two main scenarios:
1. **Multi-device per process**: TP domain = single process
- Each process manages multiple GPUs (num_devices_per_process > 1)
- TP group consists of GPUs within the same process
- Example: 2 processes × 4 GPUs each = 8 total ranks, tp_size=4
2. **Single device per process**: TP domain spans multiple processes
- Each process manages one GPU (num_devices_per_process = 1)
- TP group spans across multiple processes
- Example: 8 processes × 1 GPU each = 8 total ranks, tp_size=4
Args:
num_total_devices (int): Total number of ranks across all processes.
Must be divisible by num_devices_per_process.
num_devices_per_process (int): Number of GPUs per process.
- For multi-device: equals tp_size (e.g., 4 GPUs per process)
- For single-device: equals 1 (1 GPU per process)
process_id (int): Process identifier (0-based).
Must be in range [0, num_total_devices // num_devices_per_process).
tensor_parallel_size (int): Size of tensor parallel groups.
Must divide num_total_devices evenly.
num_max_streams (int, optional): Maximum number of CUDA streams for overlap.
Higher values enable more parallelism but use more GPU resources. Default: 3.
compute_stream_priority (int, optional): Priority for GEMM computation streams.
Lower values = higher priority. Range: 0 (highest) to 3 (lowest). Default: 0.
communication_stream_priority (int, optional): Priority for NCCL communication streams.
Lower values = higher priority. Range: 0 (highest) to 3 (lowest). Default: 0.
num_sm_for_communication (int, optional): Number of streaming multiprocessors
reserved for communication operations. Default: 2.
use_ce (bool, optional): Enable CUDA copy engines for memory transfers.
Can improve performance by offloading memory operations. Default: True.
aggregate_all_gather (bool, optional): Aggregate multiple small all-gather operations
into larger ones for better efficiency. Default: False.
Raises:
AssertionError: If num_total_devices is not divisible by num_devices_per_process,
or if process_id is out of valid range.
AssertionError: If num_devices_per_process is not 1 (Temporary: only single device per process is supported for now)
RuntimeError: If NCCL initialization fails or if configuration
is invalid (e.g., insufficient GPUs).
Example:
# Basic initialization (single device per process)
collective_gemm_bootstrap(
num_total_devices=8,
num_devices_per_process=1,
process_id=0,
tensor_parallel_size=4
)
# Advanced configuration with custom performance settings
collective_gemm_bootstrap(
num_total_devices=8,
num_devices_per_process=1,
process_id=0,
tensor_parallel_size=4,
num_max_streams=5, # More parallelism
compute_stream_priority=1, # Lower compute priority
communication_stream_priority=0, # Higher comm priority
num_sm_for_communication=4, # More SMs for communication
use_ce=True, # Enable copy engines
aggregate_all_gather=True # Aggregate small operations
)
Note:
This function must be called after JAX distributed initialization
and before any collective GEMM operations. Each process should call
this function with its own unique process_id.
"""
assert (
num_devices_per_process == 1 and jax.local_device_count() == 1
), "Only single device per process is supported at the moment!"
assert num_total_devices % num_devices_per_process == 0, (
f"Invalid num_total_devices={num_total_devices},"
f" num_devices_per_process={num_devices_per_process}"
)
assert 0 <= process_id < num_total_devices, f"Invalid process_id={process_id}"
initialize_cgemm_communicator(
num_total_devices,
num_devices_per_process,
process_id,
tensor_parallel_size,
num_max_streams,
compute_stream_priority,
communication_stream_priority,
num_sm_for_communication,
use_ce,
aggregate_all_gather,
)
class CollectiveOp(Enum):
"Enum for Collective Type in Collective GEMM"
NONE = JAXX_Collective_Op.NONE
ALL_GATHER = JAXX_Collective_Op.ALL_GATHER
REDUCE_SCATTER = JAXX_Collective_Op.REDUCE_SCATTER
@property
def is_all_gather(self) -> bool:
"""Check if AllGather"""
return self == CollectiveOp.ALL_GATHER
@property
def is_reduce_scatter(self) -> bool:
"""Check if ReduceScatter"""
return self == CollectiveOp.REDUCE_SCATTER
@property
def is_none(self) -> bool:
"""Check if None"""
return self == CollectiveOp.NONE
@dataclass(frozen=True)
class CollectiveOpSet:
"""
A set of complementary CollectiveOp configurations for the forward and backward passes through dense layers.
"""
forward: CollectiveOp
backward: CollectiveOp
@staticmethod
def create(forward_collective_op: CollectiveOp):
"""Create a set of CollectiveOp for forward and backward passes"""
if forward_collective_op.is_all_gather:
backward_collective_op = CollectiveOp.REDUCE_SCATTER
elif forward_collective_op.is_reduce_scatter:
backward_collective_op = CollectiveOp.ALL_GATHER
else:
backward_collective_op = CollectiveOp.NONE
return CollectiveOpSet(forward=forward_collective_op, backward=backward_collective_op)
noop_collective_op_set = CollectiveOpSet.create(forward_collective_op=CollectiveOp.NONE)
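# Illustrative usage (module-level sketch): a forward all-gather pairs with a
# backward reduce-scatter, and vice versa, e.g.:
#   ag_set = CollectiveOpSet.create(forward_collective_op=CollectiveOp.ALL_GATHER)
#   assert ag_set.forward.is_all_gather and ag_set.backward.is_reduce_scatter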
@partial(jax.jit, static_argnums=(1, 2))
def swizzled_scale(scale_inv, flatten_axis, is_colwise):
"Swizzle scale_inv via JAX transpose ops"
......@@ -174,7 +344,7 @@ class GemmPrimitive(BasePrimitive):
name = "te_gemm_ffi"
multiple_results = True
impl_static_args = (6, 7, 8, 9, 10, 11, 12)
impl_static_args = 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16
inner_primitive = None
outer_primitive = None
......@@ -193,8 +363,12 @@ class GemmPrimitive(BasePrimitive):
fuse_gelu,
grad,
use_split_accumulator,
transpose_batch_sequence,
sequence_dim,
is_outer,
collective_op,
):
del use_split_accumulator
del use_split_accumulator, transpose_batch_sequence
def _dims_are_consecutive(dims):
if len(dims) <= 1:
......@@ -238,7 +412,7 @@ class GemmPrimitive(BasePrimitive):
), "Quantized cuBLAS GEMM requires inverse scaling factors for both operands."
if (
scaling_mode != ScalingMode.MXFP8_1D_SCALING
and not tex.is_non_nt_fp8_gemm_supported()
and not is_fp8_gemm_with_all_layouts_supported()
):
assert not lhs_is_transposed and rhs_is_transposed, (
"cuBLAS FP8 GEMM on devices with compute capability < 10.0 (Hopper) "
......@@ -263,6 +437,19 @@ class GemmPrimitive(BasePrimitive):
out_shape = (*lhs_non_contracting_shape, *rhs_non_contracting_shape)
output = jax.core.ShapedArray(shape=out_shape, dtype=out_dtype)
# Adjust output shape for comm+GEMM overlap
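# For illustration (assumed 3D case): with a per-call out_shape of (B, S, H) and
# sequence_dim == 1, AG overlap yields (B, S * tpsp_axis_size(), H) while RS
# overlap yields (B, S // tpsp_axis_size(), H).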
if not collective_op.is_none and not is_outer: # Inner abstract
assert sequence_dim == 1, f"Invalid sequence_dim. Got sequence_dim={sequence_dim}"
overlap_out_shape = list(out_shape).copy()
if collective_op.is_all_gather:
overlap_out_shape[1] *= tpsp_axis_size()
else: # RS
overlap_out_shape[sequence_dim] = (
overlap_out_shape[sequence_dim] // tpsp_axis_size()
)
assert out_dtype == jnp.bfloat16, f"Unsupported out_dtype={out_dtype}"
output = jax.core.ShapedArray(shape=overlap_out_shape, dtype=out_dtype)
# Validate bias
bias_shape = (0,)
bias_dtype = out_dtype
......@@ -302,9 +489,12 @@ class GemmPrimitive(BasePrimitive):
pre_gelu_out = jax.core.ShapedArray(shape=pre_gelu_shape, dtype=pre_gelu_dtype)
# Declare cuBLAS workspace
workspace_size = get_cublas_workspace_size_bytes()
if not collective_op.is_none:
workspace_size *= get_cgemm_num_max_streams()
# cuBLAS workspace ptr must be 256 bytes aligned but JAX buffers are not
# necessarily 256 bytes aligned, we add some padding to ensure alignment.
workspace_size = get_cublas_workspace_size_bytes() + 256
workspace_size += 256
workspace = jax.core.ShapedArray(shape=(workspace_size,), dtype=jnp.uint8)
return output, bias_grad, pre_gelu_out, workspace
......@@ -330,8 +520,12 @@ class GemmPrimitive(BasePrimitive):
fuse_gelu,
grad,
use_split_accumulator,
transpose_batch_sequence,
sequence_dim,
is_outer,
collective_op,
):
del out_dtype
del out_dtype, transpose_batch_sequence, sequence_dim, is_outer
lhs_aval, _, rhs_aval, *_ = ctx.avals_in
lhs_cdims, rhs_cdims = map(sanitize_dims, (lhs_aval.ndim, rhs_aval.ndim), contracting_dims)
......@@ -350,6 +544,7 @@ class GemmPrimitive(BasePrimitive):
"fuse_gelu": fuse_gelu,
"grad": grad,
"use_split_accumulator": use_split_accumulator,
"collective_op": int(collective_op.value),
}
operand_output_aliases = {}
......@@ -378,6 +573,10 @@ class GemmPrimitive(BasePrimitive):
fuse_gelu,
grad,
use_split_accumulator,
transpose_batch_sequence,
sequence_dim,
is_outer,
collective_op,
):
if scaling_mode.is_1d_block_scaling():
lhs_cdims, rhs_cdims = map(sanitize_dims, (lhs.ndim, rhs.ndim), contracting_dims)
......@@ -396,7 +595,34 @@ class GemmPrimitive(BasePrimitive):
lhs_scale_inv = swizzled_scale(lhs_scale_inv, lhs_flatten_axis, lhs_transposed)
rhs_scale_inv = swizzled_scale(rhs_scale_inv, rhs_flatten_axis, not rhs_transposed)
outputs = GemmPrimitive.inner_primitive.bind(
# Alter lhs blocks so that CGEMM RS outputs correctly
if (
collective_op.is_reduce_scatter
and not transpose_batch_sequence
and not is_outer
and not lhs.shape[0] == 1
):
assert sequence_dim == 1, f"Invalid sequence_dim. Got sequence_dim={sequence_dim}"
original_shape = lhs.shape
assert original_shape[0] % dp_or_fsdp_axis_size() == 0 or original_shape[0] == 1, (
f"Original_shape[0]={original_shape[0]} is not divisible by"
f" dp_or_fsdp_axis_size()={dp_or_fsdp_axis_size()}"
)
assert original_shape[1] % tpsp_axis_size() == 0 or original_shape[1] == 1, (
f"Original_shape[1]={original_shape[1]} is not divisible by"
f" tpsp_axis_size()={tpsp_axis_size()}"
)
reshaped = lhs.reshape(
dp_or_fsdp_axis_size(),
int(original_shape[0] / dp_or_fsdp_axis_size()),
tpsp_axis_size(),
int(original_shape[1] / tpsp_axis_size()),
*original_shape[2:],
)
reordered = reshaped.transpose(2, 0, 1, 3, *range(4, reshaped.ndim))
lhs = reordered.reshape(original_shape)
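# Reorder sketch (hypothetical sizes): with dp_or_fsdp_axis_size() == 2, tpsp_axis_size() == 2 and
# lhs of shape (4, 8, H), the blocks are regrouped TP-major so the CGEMM RS kernel scatters the
# correct sequence chunks:
#   lhs.reshape(2, 2, 2, 4, H).transpose(2, 0, 1, 3, 4).reshape(4, 8, H)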
(output, bias_grad, pre_gelu_out, _) = GemmPrimitive.inner_primitive.bind(
lhs,
lhs_scale_inv,
rhs,
......@@ -410,8 +636,39 @@ class GemmPrimitive(BasePrimitive):
fuse_gelu=fuse_gelu,
grad=grad,
use_split_accumulator=use_split_accumulator,
collective_op=collective_op,
transpose_batch_sequence=transpose_batch_sequence,
sequence_dim=sequence_dim,
is_outer=is_outer,
)
return outputs[:-1] # discard workspace array
# Alter output blocks for CGEMM AG
if (
collective_op.is_all_gather
and not transpose_batch_sequence
and not is_outer
and not output.shape[0] == 1
):
assert sequence_dim == 1, f"Invalid sequence_dim. Got sequence_dim={sequence_dim}"
original_shape = output.shape
assert original_shape[0] % dp_or_fsdp_axis_size() == 0 or original_shape[0] == 1, (
f"Original_shape[0]={original_shape[0]} is not divisible by"
f" dp_or_fsdp_axis_size()={dp_or_fsdp_axis_size()}"
)
assert original_shape[1] % tpsp_axis_size() == 0 or original_shape[1] == 1, (
f"Original_shape[1]={original_shape[1]} is not divisible by"
f" tpsp_axis_size()={tpsp_axis_size()}"
)
reshaped = output.reshape(
tpsp_axis_size(),
dp_or_fsdp_axis_size(),
int(original_shape[0] / dp_or_fsdp_axis_size()),
int(original_shape[1] / tpsp_axis_size()),
*original_shape[2:],
)
reordered = reshaped.transpose(1, 2, 0, 3, *range(4, reshaped.ndim))
output = reordered.reshape(original_shape)
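# This AG reorder is the inverse regrouping of the RS case above: the inner primitive emits
# TP-major blocks, and transpose(1, 2, 0, 3, ...) restores the (dp, batch, tp, seq) block order
# the outer sharded computation expects.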
return [output, bias_grad, pre_gelu_out]
@staticmethod
def outer_impl(
......@@ -428,6 +685,10 @@ class GemmPrimitive(BasePrimitive):
fuse_gelu,
grad,
use_split_accumulator,
transpose_batch_sequence,
sequence_dim,
is_outer,
collective_op,
):
return GemmPrimitive.impl(
lhs,
......@@ -443,6 +704,10 @@ class GemmPrimitive(BasePrimitive):
fuse_gelu,
grad,
use_split_accumulator,
transpose_batch_sequence,
sequence_dim,
is_outer,
collective_op,
)
@staticmethod
......@@ -456,7 +721,12 @@ class GemmPrimitive(BasePrimitive):
fuse_gelu,
grad,
use_split_accumulator,
collective_op,
transpose_batch_sequence,
sequence_dim,
is_outer,
):
del transpose_batch_sequence, sequence_dim, is_outer
assert GemmPrimitive.outer_primitive is not None
lhs_bdims, _, rhs_bdims, *_ = batch_dims
......@@ -484,6 +754,10 @@ class GemmPrimitive(BasePrimitive):
fuse_gelu=fuse_gelu,
grad=grad,
use_split_accumulator=use_split_accumulator,
collective_op=collective_op,
transpose_batch_sequence=transpose_batch_sequence,
sequence_dim=sequence_dim,
is_outer=is_outer,
),
(out_bdims, bias_bdims, pre_gelu_bdims),
)
......@@ -492,6 +766,8 @@ class GemmPrimitive(BasePrimitive):
def _parse_operand_output_specs(
arg_infos,
contracting_dims,
transpose_batch_sequence,
collective_op,
):
lhs_specs, _, rhs_specs, *_ = map(get_padded_spec, arg_infos)
......@@ -499,13 +775,11 @@ class GemmPrimitive(BasePrimitive):
# Ensure that tensor sequence parallelism is not being configured via tp_resource
if gsr.tp_resource is not None:
for i in range(len(lhs_specs) - 1):
if lhs_specs[i] == gsr.tp_resource and lhs_specs[i + 1] == gsr.tp_resource:
if gsr.tp_resource in lhs_specs:
warnings.warn(
"Tensor sequence parallelism is detected as"
f" tp_resource='{gsr.tp_resource}' appears twice consecutively in"
f" lhs_specs: {lhs_specs}. Please setting MeshResource.tpsp_resource for"
" tensor sequence parallelism to avoid potential issues."
"Tensor sequence parallelism is detected as tp_resource='{gsr.tp_resource}'"
" appears in lhs_specs: {lhs_specs}. Please setting MeshResource.tpsp_resource"
" for tensor sequence parallelism to avoid potential issues."
)
lhs_ndim, rhs_ndim = map(len, (lhs_specs, rhs_specs))
......@@ -528,9 +802,42 @@ class GemmPrimitive(BasePrimitive):
assert reduce_spec is None, "Multiple reduce dimension is detected!"
reduce_spec = l
sequence_dim = None
# Find the sequence dimension in lhs_specs if tensor sequence parallelism is enabled.
# CollectiveGemm AG is only applied to x or dY, so they are always the LHS and carry the sequence dim.
if collective_op.is_all_gather:
try:
tpsp_idx = lhs_specs.index(gsr.tpsp_resource)
except ValueError as exc:
raise ValueError(
f"tpsp_resource '{gsr.tpsp_resource}' is not found in lhs_specs: {lhs_specs}."
" Please check your sharding configuration."
) from exc
sequence_dim = tpsp_idx
assert (sequence_dim == 1) ^ transpose_batch_sequence, (
"CollectiveGEMM supports only (sequence_dim=1 and transpose_batch_sequence=False)"
" or (sequence_dim=0 and transpose_batch_sequence=True). Received:"
f" sequence_dim={sequence_dim},"
f" transpose_batch_sequence={transpose_batch_sequence}."
)
elif collective_op.is_reduce_scatter:
assert reduce_spec == gsr.tpsp_resource, (
"Only CollectiveGemm RS with the Reduction over the TPSP axis is supported! Got"
f" reduce_spec={reduce_spec}, tpsp_resource={gsr.tpsp_resource}"
)
sequence_dim = int(not transpose_batch_sequence)
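# Inference sketch (illustrative): with lhs_specs == ('data', 'tensor_sequence', None) and
# transpose_batch_sequence == False, AG infers sequence_dim = 1 from the position of
# tpsp_resource in lhs_specs, while RS uses int(not transpose_batch_sequence) == 1 directly.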
if reduce_spec is not None:
# Other non-reduce cdims (if any) need to be unsharded
lhs_cspecs = tuple(s if s == reduce_spec else None for s in lhs_cspecs)
# Only do AG Sequence dim if not Overlap
if collective_op.is_all_gather:
rhs_cspecs = tuple(
s if s in (reduce_spec, gsr.tpsp_resource) else None for s in rhs_cspecs
)
else:
rhs_cspecs = tuple(s if s == reduce_spec else None for s in rhs_cspecs)
# Non-contracting dims of RHS always need to be gathered, e.g. for TP + activation_hidden
......@@ -551,13 +858,31 @@ class GemmPrimitive(BasePrimitive):
for spec in rhs_non_cspecs
)
# Only all-gather the LHS non-contracting (sequence) dims here when AG overlap is not used
if not collective_op.is_all_gather:
# Non-contracting dims of LHS need to be gathered along the SP axis.
# Minor note: this causes MaxText TP (= Megatron TP + activation_hidden sharding) to gather x for
# dW1 = x^T * dY1, which is unexpected. This is a known issue with no solution found yet.
lhs_non_cspecs = tuple(None if spec in rhs_non_cspecs else spec for spec in lhs_non_cspecs)
lhs_non_cspecs = tuple(
None if spec in rhs_non_cspecs else spec for spec in lhs_non_cspecs
)
out_specs = lhs_non_cspecs + rhs_non_cspecs
# Adjust the output spec on the sequence dim: AG overlap produces a gathered (unsharded)
# sequence dim, while RS overlap produces a TPSP-sharded one
if collective_op.is_all_gather:
assert sequence_dim <= len(
lhs_non_cspecs
), f"Sequence dim {sequence_dim} is out of bounds for lhs_non_cspecs: {lhs_non_cspecs}"
out_specs = out_specs[:sequence_dim] + (None,) + out_specs[sequence_dim + 1 :]
elif collective_op.is_reduce_scatter:
assert sequence_dim <= len(
lhs_non_cspecs
), f"Sequence dim {sequence_dim} is out of bounds for lhs_non_cspecs: {lhs_non_cspecs}"
out_specs = (
out_specs[:sequence_dim] + (gsr.tpsp_resource,) + out_specs[sequence_dim + 1 :]
)
# specs = merge(cspecs, non_cspecs)
lhs_specs, rhs_specs = map(
lambda cdims, cspecs, non_cspecs: (
......@@ -572,10 +897,14 @@ class GemmPrimitive(BasePrimitive):
bias_specs = tuple(list(rhs_non_cspecs).copy())
gelu_specs = tuple(list(out_specs).copy())
if not collective_op.is_none:
assert sequence_dim >= 0, f"Invalid sequence_dim. Got sequence_dim={sequence_dim}"
return (
(lhs_specs, rhs_specs, bias_specs, gelu_specs),
(out_specs, bias_specs, gelu_specs),
reduce_spec,
sequence_dim,
)
@staticmethod
......@@ -587,6 +916,10 @@ class GemmPrimitive(BasePrimitive):
fuse_gelu,
grad,
use_split_accumulator,
transpose_batch_sequence,
sequence_dim,
is_outer,
collective_op,
mesh,
arg_infos,
result_infos,
......@@ -595,11 +928,16 @@ class GemmPrimitive(BasePrimitive):
out_dtype,
scaling_mode,
grad,
use_split_accumulator,
result_infos,
is_outer,
sequence_dim,
)
del use_split_accumulator, result_infos
(_, (out_specs, dbias_specs, pre_gelu_specs), _) = (
GemmPrimitive._parse_operand_output_specs(arg_infos, contracting_dims)
(_, (out_specs, dbias_specs, pre_gelu_specs), *_) = (
GemmPrimitive._parse_operand_output_specs(
arg_infos, contracting_dims, transpose_batch_sequence, collective_op
)
)
out_sharding = NamedSharding(mesh, PartitionSpec(*out_specs))
......@@ -624,20 +962,29 @@ class GemmPrimitive(BasePrimitive):
fuse_gelu,
grad,
use_split_accumulator,
transpose_batch_sequence,
sequence_dim,
is_outer,
collective_op,
mesh,
arg_infos,
result_infos,
):
del result_infos
del result_infos, is_outer, sequence_dim
(
(lhs_specs, rhs_specs, bias_input_specs, gelu_input_specs),
(out_specs, dbias_specs, pre_gelu_specs),
reduce_spec,
) = GemmPrimitive._parse_operand_output_specs(arg_infos, contracting_dims)
inferred_sequence_dim,
) = GemmPrimitive._parse_operand_output_specs(
arg_infos,
contracting_dims,
transpose_batch_sequence,
collective_op,
)
# Assemble argument shardings
# NOTE: Block scale inverses match their operands, but tensor scale inverses are unsharded.
# Block scale inverses match their operands, but tensor scale inverses are unsharded.
none_sharding = NamedSharding(mesh, PartitionSpec(None))
lhs_sharding = NamedSharding(mesh, PartitionSpec(*lhs_specs))
rhs_sharding = NamedSharding(mesh, PartitionSpec(*rhs_specs))
......@@ -686,10 +1033,18 @@ class GemmPrimitive(BasePrimitive):
fuse_gelu=fuse_gelu,
grad=grad,
use_split_accumulator=use_split_accumulator,
transpose_batch_sequence=transpose_batch_sequence,
sequence_dim=inferred_sequence_dim,
is_outer=False,
collective_op=collective_op,
)
# All-Reduce GEMM output
if reduce_spec is not None:
if reduce_spec is not None and not collective_op.is_reduce_scatter:
if is_all_reduce_in_float32(): # For unittest only
outputs[0] = jax.lax.psum(outputs[0].astype(jnp.float32), reduce_spec).astype(
out_dtype
)
else:
outputs[0] = jax.lax.psum(outputs[0], reduce_spec)
return outputs
......@@ -705,12 +1060,22 @@ class GemmPrimitive(BasePrimitive):
fuse_gelu,
grad,
use_split_accumulator,
transpose_batch_sequence,
sequence_dim,
is_outer,
collective_op,
mesh,
operand_types,
result_types,
):
del out_dtype, grad, use_split_accumulator
del mesh, result_types
del mesh, result_types, transpose_batch_sequence, sequence_dim, is_outer
if not collective_op.is_none:
raise NotImplementedError(
"CollectiveGEMM with Shardy propagation is not supported yet! Please turn off"
" Shardy by exporting env var JAX_USE_SHARDY_PARTITIONER=false"
)
prefix = "Gemm_"
......@@ -792,6 +1157,8 @@ def _te_gemm(
fuse_gelu: bool = False,
grad: bool = False,
use_split_accumulator: bool = get_quantize_config().FP8_2X_ACC_FPROP,
transpose_batch_sequence: bool = False,
collective_op: CollectiveOp = CollectiveOp.NONE,
) -> Tuple[jax.Array, ...]:
# Prepare non-quantized GEMM operands
......@@ -800,6 +1167,7 @@ def _te_gemm(
lhs_scale_inv = jnp.empty(0, dtype=jnp.float32)
rhs_scale_inv = jnp.empty(0, dtype=jnp.float32)
scaling_mode = ScalingMode.NO_SCALING
lhs_is_transposed, rhs_is_transposed = _get_gemm_layout((lhs.ndim, rhs.ndim), contracting_dims)
lhs_cdims, rhs_cdims = map(sanitize_dims, (lhs.ndim, rhs.ndim), contracting_dims)
......@@ -859,6 +1227,10 @@ def _te_gemm(
fuse_gelu=fuse_gelu,
grad=grad,
use_split_accumulator=use_split_accumulator,
transpose_batch_sequence=transpose_batch_sequence,
sequence_dim=-1,
is_outer=True,
collective_op=collective_op,
)
......@@ -1176,6 +1548,8 @@ def gemm(
contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((-1,), (0,)),
lhs_quantizer: Quantizer = None,
rhs_quantizer: Quantizer = None,
transpose_batch_sequence: bool = False,
collective_op: CollectiveOp = CollectiveOp.NONE,
**kwargs,
) -> Tuple[jnp.ndarray, ...]:
r"""General matrix multiplication with optional quantization.
......@@ -1209,8 +1583,11 @@ def gemm(
TE's custom call to cuBLAS GEMM.
use_split_accumulator: bool, default = True
Enable promoting some intermediate sums to higher precision when accumulating the result in
the cuBLAS GEMM kernel. Disabling this trades off numerical accuracy for speed. Only
supported with TE's custom call to cuBLAS GEMM.
the cuBLAS GEMM kernel. Disabling this trades off numerical accuracy for speed.
transpose_batch_sequence: bool, default = False
Transpose the batch and sequence dimensions of the input tensor.
collective_op: CollectiveOp, default = CollectiveOp.NONE
Collective operation type for collective GEMM.
Returns
-------
......@@ -1254,6 +1631,7 @@ def gemm(
"`jax.lax.dot_general` and `jax.nn.scaled_matmul` backends used when the custom cuBLAS "
"GEMM primitive is disabled."
)
assert collective_op.is_none, "JAX GEMM does not support collective GEMM"
return _jax_gemm(lhs, rhs, contracting_dims, lhs_quantizer, rhs_quantizer)
outputs = _te_gemm(
......@@ -1262,6 +1640,8 @@ def gemm(
lhs_quantizer=lhs_quantizer,
rhs_quantizer=rhs_quantizer,
contracting_dims=contracting_dims,
transpose_batch_sequence=transpose_batch_sequence,
collective_op=collective_op,
**kwargs,
)
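# Usage sketch (illustrative; assumes a mesh whose tpsp_resource shards the sequence dimension of x):
#   outputs = gemm(x, kernel, contracting_dims=((-1,), (0,)),
#                  transpose_batch_sequence=False,
#                  collective_op=CollectiveOp.ALL_GATHER)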
......
......@@ -293,3 +293,11 @@ class NamedSharding(jax.sharding.NamedSharding):
Create a new NamedSharding with the same mesh and spec but with a new description.
"""
return NamedSharding(self.mesh, self.spec, desc=desc)
@functools.lru_cache(maxsize=1)
def is_all_reduce_in_float32():
"""
Check whether the GEMM all-reduce should be performed in float32 (controlled by NVTE_JAX_ALL_REDUCE_IN_FP32)
"""
return os.getenv("NVTE_JAX_ALL_REDUCE_IN_FP32", "0") == "1"
......@@ -13,6 +13,7 @@
#include <cudnn.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <transformer_engine/comm_gemm_overlap.h>
#include <transformer_engine/normalization.h>
#include <transformer_engine/transformer_engine.h>
......@@ -32,9 +33,6 @@
#include "transformer_engine/activation.h"
#include "transformer_engine/multi_stream.h"
// ENUM_ATTR and DICT_ATTR recoding need to be registered in the global namespace
XLA_FFI_REGISTER_ENUM_ATTR_DECODING(transformer_engine::jax::JAXX_Scaling_Mode);
namespace transformer_engine {
namespace jax {
......@@ -121,6 +119,7 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
// GEMM
XLA_FFI_DECLARE_HANDLER_SYMBOL(GemmHandler);
XLA_FFI_DECLARE_HANDLER_SYMBOL(CollectiveGemmInitHandler);
// Grouped GEMM
XLA_FFI_DECLARE_HANDLER_SYMBOL(GroupedGemmHandler);
......@@ -134,4 +133,8 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(CublasHandleInitHandler);
} // namespace jax
} // namespace transformer_engine
// ENUM_ATTR and DICT_ATTR recoding need to be registered in the global namespace
XLA_FFI_REGISTER_ENUM_ATTR_DECODING(transformer_engine::jax::JAXX_Scaling_Mode);
XLA_FFI_REGISTER_ENUM_ATTR_DECODING(transformer_engine::jax::JAXX_Collective_Op);
#endif // TRANSFORMER_ENGINE_JAX_CSRC_FP8_MODULES_H_
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "cgemm_helper.h"
#include "common/util/system.h"
#include "nccl.h"
namespace transformer_engine {
namespace jax {
ncclUniqueId CommunicatorHandler::coordinate_nccl_unique_id(const std::string &id_type) {
ncclUniqueId unique_id;
int tp_domain_id = get_tp_domain_id();
bool is_tp_leader = (get_local_device_id_within_tp_domain() == 0);
pid_t pgid = getpgid(0);
std::string base_path = getenv<std::string>("NVTE_JAX_NCCL_FILE_PATH", "/tmp");
std::string id_file = base_path + "/nccl_" + id_type + "_unique_id_pgid_" + std::to_string(pgid) +
"_" + std::to_string(num_total_devices) + "_" + std::to_string(tp_size) +
"_domain_" + std::to_string(tp_domain_id) + ".bin";
if (is_tp_leader) {
NVTE_CHECK_NCCL(ncclGetUniqueId(&unique_id));
// Write the ID to a temporary file
std::ofstream file(id_file, std::ios::binary);
NVTE_CHECK(file.is_open(), "Failed to create NCCL unique ID file: ", id_file);
file.write(reinterpret_cast<const char *>(&unique_id), sizeof(ncclUniqueId));
file.close();
} else {
// Wait for the ID file to be created and read it
int attempts = 0;
const int max_attempts = 100;
while (attempts < max_attempts) {
std::ifstream file(id_file, std::ios::binary);
if (file.is_open()) {
file.read(reinterpret_cast<char *>(&unique_id), sizeof(ncclUniqueId));
if (file.gcount() == sizeof(ncclUniqueId)) {
file.close();
break;
}
file.close();
}
std::this_thread::sleep_for(std::chrono::milliseconds(100));
attempts++;
}
NVTE_CHECK(attempts < max_attempts,
"Timeout waiting for " + id_type + " NCCL unique ID file from leader: ", id_file);
}
if (is_tp_leader) {
_nccl_id_file_name.push_back(id_file);
}
return unique_id;
}
void CommunicatorHandler::init(int num_total_devices, int num_devices_per_process, int process_id,
int tp_size) {
// Validate inputs
NVTE_CHECK(num_devices_per_process == 1,
"num_devices_per_process must be == 1, got num_devices_per_process=",
num_devices_per_process);
NVTE_CHECK(num_total_devices >= 1,
"num_total_devices must be >= 1, got num_total_devices=", num_total_devices);
NVTE_CHECK(
num_total_devices % num_devices_per_process == 0,
"num_total_devices must be divisible by num_devices_per_process, got num_total_devices=",
num_total_devices, ", num_devices_per_process=", num_devices_per_process);
// Validate TP size
NVTE_CHECK(tp_size > 0, "tp_size must be > 0, got tp_size=", tp_size);
NVTE_CHECK(num_total_devices % tp_size == 0,
"num_total_devices must be divisible by tp_size, got num_total_devices=",
num_total_devices, ", tp_size=", tp_size);
auto &handler = get(false);
handler.num_total_devices = num_total_devices;
handler.num_devices_per_process = num_devices_per_process;
handler.process_id = process_id;
handler.num_processes = num_total_devices / num_devices_per_process;
handler.tp_size = tp_size;
handler.tp_num_domains = num_total_devices / tp_size;
// Initialize vectors with the correct size
handler.local_device_ids_within_process.resize(num_devices_per_process);
handler.local_device_ids_within_tp_domain.resize(num_devices_per_process);
handler.tp_domain_ids.resize(num_devices_per_process);
handler.global_device_ids.resize(num_devices_per_process);
handler.tp_comms.resize(num_devices_per_process);
NVTE_CHECK(0 <= process_id && process_id < handler.num_processes,
"Invalid process_id=", process_id, ", which is out of range [0, ",
handler.num_processes, ")");
// Initialize local devices and calculate their global device IDs and TP topology
for (int local_idx = 0; local_idx < num_devices_per_process; local_idx++) {
// Use the device that JAX has already assigned to this process
int current_device;
NVTE_CHECK_CUDA(cudaGetDevice(&current_device));
handler.local_device_ids_within_process[local_idx] = current_device;
handler.global_device_ids[local_idx] = process_id * num_devices_per_process + local_idx;
// Calculate TP-related values for this device
int global_device_id = handler.global_device_ids[local_idx];
if (num_devices_per_process == tp_size) {
// Scenario 1: Multi-device per process - TP domain = single process
handler.local_device_ids_within_tp_domain[local_idx] = local_idx;
handler.tp_domain_ids[local_idx] = process_id;
} else {
// Scenario 2: Single device per process - TP domain spans multiple processes
handler.local_device_ids_within_tp_domain[local_idx] = global_device_id % tp_size;
handler.tp_domain_ids[local_idx] = global_device_id / tp_size;
}
}
ncclUniqueId tp_id = handler.coordinate_nccl_unique_id("tp");
NVTE_CHECK_NCCL(ncclGroupStart());
for (int local_idx = 0; local_idx < num_devices_per_process; local_idx++) {
NVTE_CHECK_CUDA(cudaSetDevice(handler.local_device_ids_within_process[local_idx]));
int tp_local_rank = handler.local_device_ids_within_tp_domain[local_idx];
NVTE_CHECK_NCCL(
ncclCommInitRank(&handler.tp_comms[local_idx], handler.tp_size, tp_id, tp_local_rank));
}
NVTE_CHECK_NCCL(ncclGroupEnd());
// Allocate device memory for barrier operations
NVTE_CHECK_CUDA(cudaMalloc(&handler._device_barrier, sizeof(int)));
handler._initialize = true;
// Bootstrap UB by creating a dummy CommOverlapP2PBase object
std::vector<size_t> buffer_shape{1, 1};
auto _ = CollectiveGemmPlanRegistry::getInstance().get_executor(buffer_shape, DType::kFloat32,
JAXX_Collective_Op::ALL_GATHER);
}
void InitializeCgemmCommunicator(int num_total_devices, int num_devices_per_process, int process_id,
int tp_size, int num_max_streams, int gemm_priority,
int comm_priority, int num_comm_sm, bool use_ce,
bool aggregate_ag) {
auto &config = CgemmConfig::get(false);
config.init(num_max_streams, gemm_priority, comm_priority, num_comm_sm, use_ce, aggregate_ag);
auto &handler = CommunicatorHandler::get(false);
handler.init(num_total_devices, num_devices_per_process, process_id, tp_size);
}
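// Illustrative Python-side bootstrap via the "initialize_cgemm_communicator" pybind export
// (argument values depend on the launch topology):
//   transformer_engine_jax.initialize_cgemm_communicator(
//       num_total_devices, 1, process_id, tp_size,
//       num_max_streams, gemm_priority, comm_priority, num_comm_sm, use_ce, aggregate_ag)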
int GetCgemmNumMaxStreams() {
auto &config = CgemmConfig::get();
return config.num_max_streams;
}
CommOverlapCore *CollectiveGemmPlanRegistry::get_executor(std::vector<size_t> buffer_shape,
DType dtype,
JAXX_Collective_Op collective_op) {
auto &comm_handler = CommunicatorHandler::get();
auto &cgemm_config = CgemmConfig::get();
int device_idx = comm_handler.get_local_device_idx_for_current_device();
int64_t plan_id = 0;
hash_combine(plan_id, buffer_shape[0], buffer_shape[1], static_cast<size_t>(dtype),
static_cast<int>(collective_op), comm_handler.tp_size, cgemm_config.num_max_streams,
cgemm_config.gemm_priority, cgemm_config.comm_priority, cgemm_config.num_comm_sm,
cgemm_config.use_ce, cgemm_config.aggregate_ag, device_idx);
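// Reuse a cached executor when an identical plan (buffer shape, dtype, collective op,
// TP/CGEMM configuration, and device index) has already been created.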
auto it = plan_map.find(plan_id);
if (it != plan_map.end()) {
return it->second.get();
}
if (comm_handler.num_devices_per_process == comm_handler.tp_size) {
// Multi-device per process
} else if (comm_handler.num_devices_per_process == 1) {
// Single device per process
NVTE_CHECK(comm_handler.num_total_devices % comm_handler.tp_size == 0,
"For single device per process, num_total_devices must be divisible by tp_size, "
"got num_total_devices=",
comm_handler.num_total_devices, ", tp_size=", comm_handler.tp_size);
} else {
NVTE_ERROR("Unsupported TP configuration: num_devices_per_process=",
comm_handler.num_devices_per_process, ", tp_size=", comm_handler.tp_size,
". Supported scenarios: "
"(1) num_devices_per_process == tp_size (multi-device per process), "
"(2) num_devices_per_process == 1 (single device per process)");
}
std::unique_ptr<CommOverlapCore> executor;
executor = std::make_unique<CommOverlapP2PBase>(
buffer_shape, dtype, comm_handler.get_global_rank(), comm_handler.num_total_devices,
comm_handler.get_local_device_id_within_tp_domain(), comm_handler.tp_size,
comm_handler.get_tp_domain_id(), comm_handler.get_tp_num_domains(), comm_handler.tp_size,
comm_handler.allgather_func, comm_handler.barrier_func, get_nvte_collective_op(collective_op),
cgemm_config.num_max_streams, 1 /*comm_cga_size*/, cgemm_config.gemm_priority,
cgemm_config.comm_priority, cgemm_config.num_comm_sm, true /*set_sm_margin*/,
cgemm_config.use_ce, false /*atomic_gemm*/, cgemm_config.aggregate_ag);
CommOverlapCore *executor_ptr = executor.get();
plan_map[plan_id] = std::move(executor);
return executor_ptr;
}
void CommunicatorHandler::nccl_device_barrier_impl(ExtComm) {
NVTE_CHECK(_initialize, "CommunicatorHandler must be initialized before using barrier");
int device_idx = get_local_device_idx_for_current_device();
ncclComm_t tp_comm = tp_comms[device_idx];
NVTE_CHECK_NCCL(
ncclAllReduce(_device_barrier, _device_barrier, 1, ncclInt, ncclSum, tp_comm, nullptr));
cudaDeviceSynchronize();
}
void CommunicatorHandler::nccl_allgather_impl(void *output_buf, size_t output_bytes,
void *input_buf, size_t input_bytes, ExtComm) {
NVTE_CHECK(_initialize, "CommunicatorHandler must be initialized before using allgather");
int device_idx = get_local_device_idx_for_current_device();
ncclComm_t tp_comm = tp_comms[device_idx];
size_t expected_output_bytes = input_bytes * tp_size;
NVTE_CHECK(output_bytes == expected_output_bytes, "TP allgather buffer size mismatch: expected ",
expected_output_bytes, ", got ", output_bytes);
NVTE_CHECK_NCCL(ncclAllGather(input_buf, output_buf, input_bytes, ncclChar, tp_comm, nullptr));
cudaDeviceSynchronize();
}
CommunicatorHandler::CommunicatorHandler() : _device_barrier(nullptr) {
allgather_func = [this](void *output_buf, size_t output_bytes, void *input_buf,
size_t input_bytes, ExtComm comm) {
this->nccl_allgather_impl(output_buf, output_bytes, input_buf, input_bytes, comm);
};
barrier_func = [this](ExtComm comm) { this->nccl_device_barrier_impl(comm); };
}
CommunicatorHandler::~CommunicatorHandler() {
if (_initialize && !tp_comms.empty()) {
for (auto &comm : tp_comms) {
if (comm != nullptr) {
ncclCommDestroy(comm);
}
}
}
if (_device_barrier) cudaFree(_device_barrier);
for (const auto &file_path : _nccl_id_file_name) {
std::remove(file_path.c_str());
}
}
} // namespace jax
} // namespace transformer_engine
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#ifndef TRANSFORMER_ENGINE_JAX_CGEMM_HELPER_H_
#define TRANSFORMER_ENGINE_JAX_CGEMM_HELPER_H_
#include <unistd.h>
#include <chrono>
#include <cstdio>
#include <fstream>
#include <functional>
#include <memory>
#include <thread>
#include <unordered_map>
#include "../extensions.h"
#include "common/comm_gemm_overlap/userbuffers/userbuffers.h"
#include "common/util/cuda_runtime.h"
#include "common/util/logging.h"
#include "transformer_engine/comm_gemm_overlap.h"
namespace transformer_engine {
namespace jax {
// Configuration singleton for CGEMM parameters
class CgemmConfig {
public:
int num_max_streams;
int gemm_priority;
int comm_priority;
int num_comm_sm;
bool use_ce;
bool aggregate_ag;
static void init(int _num_max_streams, int _gemm_priority, int _comm_priority, int _num_comm_sm,
bool _use_ce, bool _aggregate_ag) {
auto &config = get(false);
config._initialized = true;
config.num_max_streams = _num_max_streams;
config.gemm_priority = _gemm_priority;
config.comm_priority = _comm_priority;
config.num_comm_sm = _num_comm_sm;
config.use_ce = _use_ce;
config.aggregate_ag = _aggregate_ag;
}
static CgemmConfig &get(bool is_initialized = true) {
static thread_local CgemmConfig instance;
NVTE_CHECK(
instance._initialized == is_initialized,
"CgemmConfig must be initialized before using it, got is_initialized=", is_initialized);
return instance;
}
CgemmConfig(const CgemmConfig &) = delete;
CgemmConfig &operator=(const CgemmConfig &) = delete;
private:
CgemmConfig() = default;
~CgemmConfig() = default;
bool _initialized = false;
};
// Forward declaration
class CollectiveGemmPlanRegistry;
// NCCL communicator handler for collective GEMM operations
// Supports both one device per process AND multiple devices per process.
// Two scenarios:
// 1. Multiple devices per process: the TP domain is a single process (num_devices_per_process == tp_size)
// 2. One device per process: the TP domain spans multiple processes (num_devices_per_process == 1)
class CommunicatorHandler {
public:
int num_total_devices = -1;
int num_devices_per_process = -1;
int process_id = -1;
int num_processes = -1;
int tp_size = -1;
int tp_num_domains = -1;
std::vector<int> local_device_ids_within_tp_domain;
std::vector<int> tp_domain_ids;
std::vector<ncclComm_t> tp_comms;
std::vector<int> local_device_ids_within_process;
std::vector<int> global_device_ids;
int get_global_rank() const {
int device_idx = get_local_device_idx_for_current_device();
return global_device_ids[device_idx];
}
void nccl_device_barrier_impl(ExtComm);
void nccl_allgather_impl(void *output_buf, size_t output_bytes, void *input_buf,
size_t input_bytes, ExtComm);
ncclComm_t get_comm_for_current_device() const {
int device_idx = get_local_device_idx_for_current_device();
return tp_comms[device_idx];
}
int get_local_device_idx_for_current_device() const {
int current_device;
NVTE_CHECK_CUDA(cudaGetDevice(&current_device));
for (int i = 0; i < num_devices_per_process; i++) {
if (local_device_ids_within_process[i] == current_device) {
return i;
}
}
NVTE_ERROR("Current CUDA device ", current_device,
" not found in local_device_ids_within_process");
}
int get_local_device_id_within_tp_domain() const {
int device_idx = get_local_device_idx_for_current_device();
return local_device_ids_within_tp_domain[device_idx];
}
int get_tp_domain_id() const {
int device_idx = get_local_device_idx_for_current_device();
return tp_domain_ids[device_idx];
}
int get_tp_num_domains() const { return tp_num_domains; }
static void init(int num_total_devices, int num_devices_per_process, int process_id, int tp_size);
private:
ncclUniqueId coordinate_nccl_unique_id(const std::string &id_type);
public:
static CommunicatorHandler &get(bool is_initialized = true) {
static CommunicatorHandler instance;
NVTE_CHECK(instance._initialize == is_initialized,
"CommunicatorHandler._initialize=", instance._initialize,
", is_initialized=", is_initialized);
return instance;
}
ExtAllgatherOp allgather_func;
ExtBarrierOp barrier_func;
CommunicatorHandler(const CommunicatorHandler &) = delete;
CommunicatorHandler &operator=(const CommunicatorHandler &) = delete;
private:
CommunicatorHandler();
~CommunicatorHandler();
bool _initialize = false;
int *_device_barrier = nullptr;
std::vector<std::string> _nccl_id_file_name;
};
// Plan registry for caching collective GEMM executors
class CollectiveGemmPlanRegistry {
public:
static CollectiveGemmPlanRegistry &getInstance() {
static thread_local CollectiveGemmPlanRegistry instance;
return instance;
}
CommOverlapCore *get_executor(std::vector<size_t> buffer_shape, DType dtype,
JAXX_Collective_Op collective_op);
private:
CollectiveGemmPlanRegistry() {}
CollectiveGemmPlanRegistry(const CollectiveGemmPlanRegistry &) = delete;
CollectiveGemmPlanRegistry &operator=(const CollectiveGemmPlanRegistry &) = delete;
std::unordered_map<int64_t, std::unique_ptr<CommOverlapCore>> plan_map;
};
// Function declarations
void InitializeCgemmCommunicator(int num_total_devices, int num_devices_per_process, int process_id,
int tp_size, int num_max_streams, int gemm_priority,
int comm_priority, int num_comm_sm, bool use_ce,
bool aggregate_ag);
int GetCgemmNumMaxStreams();
} // namespace jax
} // namespace transformer_engine
#endif // TRANSFORMER_ENGINE_JAX_CGEMM_HELPER_H_
......@@ -6,13 +6,19 @@
#include "transformer_engine/gemm.h"
#include <memory>
#include <mutex>
#include <stdexcept>
#include <string_view>
#include <tuple>
#include "../extensions.h"
#include "cgemm_helper.h"
#include "common.h"
#include "common/util/cuda_runtime.h"
#include "common/util/string.h"
#include "common/util/system.h"
#include "cuda_runtime.h"
#include "nccl.h"
#include "transformer_engine/swizzle.h"
#include "xla/ffi/api/c_api.h"
......@@ -66,12 +72,75 @@ std::tuple<TensorWrapper, std::vector<size_t>> xla_buffer_to_nvte_gemm_operand(
return std::make_tuple(std::move(input), input_shape);
}
Error_Type CollectiveGemmInitFFI(Buffer_Type lhs, Buffer_Type lhs_scale_inv, Buffer_Type rhs,
Buffer_Type rhs_scale_inv, Buffer_Type bias,
Buffer_Type gelu_input, Result_Type output, Result_Type bias_grad,
Result_Type pre_gelu_out, Result_Type workspace,
JAXX_Scaling_Mode scaling_mode, int64_t lhs_axis_boundary,
int64_t rhs_axis_boundary, bool lhs_transposed,
bool rhs_transposed, bool fuse_bias, bool fuse_gelu, bool grad,
bool use_split_accumulator, JAXX_Collective_Op collective_op) {
nvte_cublas_handle_init();
// Init UB buffer
if (collective_op != JAXX_Collective_Op::NONE) {
auto &comm_handler = CommunicatorHandler::get();
std::vector<size_t> lhs_shape = {
product(lhs.dimensions(), 0, lhs_axis_boundary),
product(lhs.dimensions(), lhs_axis_boundary, lhs.dimensions().size())};
std::vector<size_t> rhs_shape = {
product(rhs.dimensions(), 0, rhs_axis_boundary),
product(rhs.dimensions(), rhs_axis_boundary, rhs.dimensions().size())};
std::vector<size_t> out_shape = {(lhs_transposed) ? lhs_shape[1] : lhs_shape[0],
(rhs_transposed) ? rhs_shape[0] : rhs_shape[1]};
std::vector<size_t> buffer_shape{0, 0};
DType buffer_dtype = convert_ffi_datatype_to_te_dtype(output->element_type());
if (collective_op == JAXX_Collective_Op::ALL_GATHER) {
buffer_shape[0] = lhs_shape[0] * comm_handler.tp_size;
buffer_shape[1] = lhs_shape[1];
buffer_dtype = convert_ffi_datatype_to_te_dtype(lhs.element_type());
} else if (collective_op == JAXX_Collective_Op::REDUCE_SCATTER) {
buffer_shape[0] = out_shape[0];
buffer_shape[1] = out_shape[1];
}
auto _ = CollectiveGemmPlanRegistry::getInstance().get_executor(buffer_shape, buffer_dtype,
collective_op);
}
return ffi_with_cuda_error_check();
}
XLA_FFI_DEFINE_HANDLER_SYMBOL(CollectiveGemmInitHandler, CollectiveGemmInitFFI,
FFI::Bind<FFI_Prepare>()
.Arg<Buffer_Type>() // lhs
.Arg<Buffer_Type>() // lhs_scale_inv
.Arg<Buffer_Type>() // rhs
.Arg<Buffer_Type>() // rhs_scale_inv
.Arg<Buffer_Type>() // bias
.Arg<Buffer_Type>() // gelu_input
.Ret<Buffer_Type>() // output
.Ret<Buffer_Type>() // bias_grad
.Ret<Buffer_Type>() // pre_gelu_out
.Ret<Buffer_Type>() // workspace
.Attr<JAXX_Scaling_Mode>("scaling_mode")
.Attr<int64_t>("lhs_axis_boundary")
.Attr<int64_t>("rhs_axis_boundary")
.Attr<bool>("lhs_transposed")
.Attr<bool>("rhs_transposed")
.Attr<bool>("fuse_bias")
.Attr<bool>("fuse_gelu")
.Attr<bool>("grad")
.Attr<bool>("use_split_accumulator")
.Attr<JAXX_Collective_Op>("collective_op"));
Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_inv, Buffer_Type rhs,
Buffer_Type rhs_scale_inv, Buffer_Type bias, Buffer_Type gelu_input,
Result_Type output, Result_Type bias_grad, Result_Type pre_gelu_out,
Result_Type workspace, JAXX_Scaling_Mode scaling_mode, int64_t lhs_axis_boundary,
int64_t rhs_axis_boundary, bool lhs_transposed, bool rhs_transposed,
bool fuse_bias, bool fuse_gelu, bool grad, bool use_split_accumulator) {
bool fuse_bias, bool fuse_gelu, bool grad, bool use_split_accumulator,
JAXX_Collective_Op collective_op) {
// NOTE: TensorWrapper operands are always rowwise for full-precision GEMM, or FP8 GEMM when
// device supports non-TN layouts (compute capability >= 10.0, excluding 12.x)
bool always_rowwise = (scaling_mode == JAXX_Scaling_Mode::NO_SCALING ||
......@@ -83,16 +152,9 @@ Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_i
auto [rhs_, rhs_shape] = xla_buffer_to_nvte_gemm_operand(stream, rhs, rhs_scale_inv, scaling_mode,
rhs_axis_boundary, make_rhs_rowwise);
// Output tensor
std::vector<size_t> out_shape = {(lhs_transposed) ? lhs_shape[1] : lhs_shape[0],
(rhs_transposed) ? rhs_shape[0] : rhs_shape[1]};
auto out_dtype = convert_ffi_datatype_to_te_dtype(output->element_type());
auto out_ = TensorWrapper(output->untyped_data(), out_shape, out_dtype);
NVTE_CHECK(out_.numel() == output->element_count(),
"cuBLAS GEMM output buffer size is incorrect, "
"expected ",
out_.numel(), " elements ", to_string_like(out_shape), " but got ",
output->element_count(), " elements ", to_string_like(output->dimensions()));
// Bias input to forward pass or bias gradient output from backward pass
void *bias_ptr = nullptr;
......@@ -133,9 +195,62 @@ Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_i
// Launch TE/common kernel with swapped LHS/RHS for cuBLAS column-major order
auto num_math_sm = cuda::sm_count() - getenv<int>("NVTE_EXT_MARGIN_SM", 0);
if (collective_op == JAXX_Collective_Op::NONE) {
auto out_ = TensorWrapper(output->untyped_data(), out_shape, out_dtype);
NVTE_CHECK(out_.numel() == output->element_count(),
"cuBLAS GEMM output buffer size is incorrect, expected ", out_.numel(), " elements ",
to_string_like(out_shape), " but got ", output->element_count(), " elements ",
to_string_like(output->dimensions()));
nvte_cublas_gemm(rhs_.data(), lhs_.data(), out_.data(), bias_.data(), pre_gelu_.data(),
rhs_transposed, lhs_transposed, grad, workspace_.data(), false,
use_split_accumulator, num_math_sm, stream);
} else {
std::vector<size_t> buffer_shape{0, 0};
DType buffer_dtype = out_dtype;
auto &comm_handler = CommunicatorHandler::get();
if (collective_op == JAXX_Collective_Op::ALL_GATHER) {
buffer_shape[0] = lhs_shape[0] * comm_handler.tp_size;
buffer_shape[1] = lhs_shape[1];
out_shape[0] = out_shape[0] * comm_handler.tp_size;
buffer_dtype = convert_ffi_datatype_to_te_dtype(lhs.element_type());
} else if (collective_op == JAXX_Collective_Op::REDUCE_SCATTER) {
buffer_shape[0] = out_shape[0];
buffer_shape[1] = out_shape[1];
out_shape[0] = out_shape[0] / comm_handler.tp_size;
}
auto executor = CollectiveGemmPlanRegistry::getInstance().get_executor(
buffer_shape, buffer_dtype, collective_op);
if (collective_op == JAXX_Collective_Op::REDUCE_SCATTER) {
auto ubuf_out_ = TensorWrapper(executor->get_ubuf_dptr(), buffer_shape, out_dtype);
// Prepare the auxiliary buffer for the reduce-scattered GEMM output
auto out_ = TensorWrapper(output->untyped_data(), out_shape, out_dtype);
NVTE_CHECK(out_.numel() == output->element_count(),
"cuBLAS GEMM output buffer size is incorrect, expected ", out_.numel(),
" elements ", to_string_like(out_shape), " but got ", output->element_count(),
" elements ", to_string_like(output->dimensions()));
// Launch GEMM+RS
executor->split_overlap_rs(rhs_, rhs_transposed, lhs_, lhs_transposed, ubuf_out_, bias_,
pre_gelu_, workspace_, grad, false, use_split_accumulator, out_,
stream);
} else if (collective_op == JAXX_Collective_Op::ALL_GATHER) {
auto aux_out_ = TensorWrapper(nullptr, std::vector<size_t>{0}, out_dtype); // Empty
auto out_ = TensorWrapper(output->untyped_data(), out_shape, out_dtype);
NVTE_CHECK(out_.numel() == output->element_count(),
"cuBLAS GEMM output buffer size is incorrect, expected ", out_.numel(),
" elements ", to_string_like(out_shape), " but got ", output->element_count(),
" elements ", to_string_like(output->dimensions()));
// Copy the distributed LHS operand into the local chunk of the communication buffer
executor->copy_into_buffer(stream, lhs_, true, make_lhs_rowwise);
// Launch AG+GEMM
executor->split_overlap_ag(rhs_, rhs_transposed, lhs_, lhs_transposed, out_, bias_, pre_gelu_,
workspace_, grad, false, use_split_accumulator, aux_out_, stream);
}
}
return ffi_with_cuda_error_check();
}
......@@ -161,7 +276,8 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GemmHandler, GemmFFI,
.Attr<bool>("fuse_bias")
.Attr<bool>("fuse_gelu")
.Attr<bool>("grad")
.Attr<bool>("use_split_accumulator"),
.Attr<bool>("use_split_accumulator")
.Attr<JAXX_Collective_Op>("collective_op"),
FFI_CudaGraph_Traits);
Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type lhs_sinv,
......
......@@ -87,5 +87,31 @@ constexpr struct Alignment {
std::vector<size_t> get_mxfp8_scale_shape(size_t M, size_t N, bool is_colwise);
template <typename T, typename... Rest>
void hash_combine(int64_t &seed, const T &v, Rest... rest) {
seed ^= std::hash<T>{}(v) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
(hash_combine(seed, rest), ...);
}
enum class JAXX_Collective_Op : int64_t {
NONE = 0,
ALL_GATHER = 1,
REDUCE_SCATTER = 2,
};
static CommOverlapType get_nvte_collective_op(const JAXX_Collective_Op &op) {
switch (op) {
case JAXX_Collective_Op::ALL_GATHER:
return CommOverlapType::AG;
break;
case JAXX_Collective_Op::REDUCE_SCATTER:
return CommOverlapType::RS;
break;
default:
NVTE_ERROR("Invalid Collective Op ", static_cast<int>(op));
break;
}
}
} // namespace jax
} // namespace transformer_engine
......@@ -5,6 +5,8 @@
************************************************************************/
#include "../extensions.h"
#include "cgemm_helper.h"
#include "common/util/cuda_runtime.h"
namespace transformer_engine {
namespace jax {
......@@ -57,7 +59,7 @@ pybind11::dict Registrations() {
// GEMM
dict["te_gemm_ffi"] =
pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CublasHandleInitHandler),
pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CollectiveGemmInitHandler),
pybind11::arg("execute") = EncapsulateFFI(GemmHandler));
// Grouped GEMM
......@@ -84,6 +86,8 @@ PYBIND11_MODULE(transformer_engine_jax, m) {
m.def("get_fused_attn_bwd_workspace_sizes", &GetFusedAttnBackwardWorkspaceSizes);
m.def("nvte_get_qkv_format", &nvte_get_qkv_format);
m.def("is_non_nt_fp8_gemm_supported", &nvte_is_non_tn_fp8_gemm_supported);
m.def("initialize_cgemm_communicator", &InitializeCgemmCommunicator);
m.def("get_cgemm_num_max_streams", &GetCgemmNumMaxStreams);
pybind11::enum_<DType>(m, "DType", pybind11::module_local())
.value("kByte", DType::kByte)
......@@ -159,6 +163,12 @@ PYBIND11_MODULE(transformer_engine_jax, m) {
.value("COLWISE", transformer_engine::jax::QuantizeLayout::COLWISE)
.value("ROWWISE_COLWISE", transformer_engine::jax::QuantizeLayout::ROWWISE_COLWISE)
.export_values();
pybind11::enum_<JAXX_Collective_Op>(m, "JAXX_Collective_Op", pybind11::module_local())
.value("NONE", JAXX_Collective_Op::NONE)
.value("ALL_GATHER", JAXX_Collective_Op::ALL_GATHER)
.value("REDUCE_SCATTER", JAXX_Collective_Op::REDUCE_SCATTER)
.export_values();
}
} // namespace jax
......