"git@developer.sourcefind.cn:OpenDAS/TransformerEngine.git" did not exist on "83f9cc092c4d1b986fb6979b4c2612a15dbfc7e0"
Unverified commit d75bf43f authored by Phuong Nguyen, committed by GitHub

[JAX] CollectiveGemm (#2166)



* init cgemm + unit tests

* UB bootstrap with NCCL, no MPI dependency

* add NVLINK-P2P check + error message

* skip tests if no NVLINK available

* use std::vector to store ncclComm_t

* update misuse of TP warning
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

---------
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
parent 4d145786
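For orientation before the diff: the new user-facing flow exercised by the tests added here is to bootstrap the NCCL-backed Userbuffers communicator once per process and then pass a CollectiveOp to the GEMM call. A minimal sketch based on the test code in this commit (the device counts below are illustrative, not prescribed by the commit):

from transformer_engine.jax.cpp_extensions.gemm import collective_gemm_bootstrap, CollectiveOp
import transformer_engine.jax.cpp_extensions as tex

# After jax.distributed.initialize(...), once per process
collective_gemm_bootstrap(num_total_devices=8, num_devices_per_process=1, process_id=0, tensor_parallel_size=8)

# GEMM with the all-gather of the sequence-sharded input overlapped with the matmul
out = tex.gemm(x, weight, bias=bias, contracting_dims=((2,), (0,)), collective_op=CollectiveOp.ALL_GATHER)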
@@ -87,4 +87,5 @@ def setup_jax_extension(
sources=[str(path) for path in sources],
include_dirs=[str(path) for path in include_dirs],
extra_compile_args=cxx_flags,
libraries=["nccl"],
)
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Shared functions for the comm_overlap tests"""
import jax.numpy as jnp
import numpy as np
# Numerical tolerance helpers shared by the comm_overlap tests
def dtype_tols(dtype, rtol=None, atol=None):
"""Expected numerical tolerance for a data type."""
# Return immediately if tolerances are fully specified
if rtol is not None and atol is not None:
return {"rtol": rtol, "atol": atol}
# Default tolerances for common dtypes
if dtype in [jnp.float32, "float32"]:
return {"rtol": 1e-5, "atol": 1e-8}
elif dtype in [jnp.float16, "float16"]:
return {"rtol": 1e-3, "atol": 1e-6}
elif dtype in [jnp.bfloat16, "bfloat16"]:
return {"rtol": 1e-2, "atol": 1e-5}
else:
return {"rtol": 1e-5, "atol": 1e-8}
def assert_allclose(
actual,
desired,
rtol=None,
atol=None,
dtype=None,
**kwargs,
):
"""Check if two tensors are close."""
# Infer data type if needed
if dtype is None:
if isinstance(actual, float):
dtype = "float32"
else:
dtype = actual.dtype
# Determine tolerances
tols = {}
if rtol is None or atol is None:
tols = dtype_tols(dtype)
if rtol is not None:
tols["rtol"] = rtol
if atol is not None:
tols["atol"] = atol
# Cast tensors to fp32
if not isinstance(actual, float):
actual = actual.astype(jnp.float32)
if not isinstance(desired, float):
desired = desired.astype(jnp.float32)
# Check if tensors are close
np.testing.assert_allclose(actual, desired, **tols, **kwargs)
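# Example (illustrative): comparing a bf16 result against a reference picks up the
# bf16 tolerances from dtype_tols, i.e.
#   assert_allclose(output, ref_output, dtype=jnp.bfloat16)
# behaves like np.testing.assert_allclose(..., rtol=1e-2, atol=1e-5) after casting both to fp32.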
def assert_allclose_print_index(ref_output, gathered_output, rtol=1e-5, atol=1e-8):
if not jnp.allclose(ref_output, gathered_output, rtol=rtol, atol=atol):
diff = jnp.abs(ref_output - gathered_output)
mask = diff > (atol + rtol * jnp.abs(gathered_output))
print(mask.astype(int))
print(jnp.where(mask, diff, 0))
# Shared constants for all tests
DP_AXIS = "data"
TPSP_AXIS = "tensor_sequence"
PARAMS_KEY = "params"
# Shared functions for distributed testing
import argparse
import jax
from jax.experimental import mesh_utils
from transformer_engine.jax.cpp_extensions.gemm import collective_gemm_bootstrap
# Global flag to track if distributed has been initialized
_distributed_initialized = False
def _is_distributed_initialized():
"""Check if JAX distributed has been initialized."""
return _distributed_initialized
def _initialize_distributed(args):
"""Initialize JAX distributed with custom arguments."""
global _distributed_initialized
# Check if already initialized
if _distributed_initialized:
return
if args.coordinator_address is None or args.num_processes is None or args.process_id is None:
raise ValueError(
"All distributed initialization arguments are required: "
"--coordinator-address, --num-processes, --process-id"
)
if args.local_device_ids is None:
assert (
args.num_devices_per_process is not None
), "Either local_device_ids or num_devices_per_process must be provided"
# Calculate device range for this process
# Single process single device: each process gets one unique device
# Single process multiple devices: each process gets a unique range of devices
start_device = args.process_id * args.num_devices_per_process
device_range = range(start_device, start_device + args.num_devices_per_process)
global_device_ids_for_this_process = ",".join(map(str, device_range))
else:
# Use explicitly provided global device IDs
global_device_ids_for_this_process = args.local_device_ids
args.num_devices_per_process = len(args.local_device_ids.split(","))
assert args.num_devices_per_process == 1, "Only single process single GPU is supported!"
print(
f"Initializing JAX distributed with coordinator={args.coordinator_address}, "
f"num_processes={args.num_processes}, process_id={args.process_id}"
)
# Note: "local_device_ids" is a JAX term meaning "global CUDA devices managed by this process"
jax.distributed.initialize(
coordinator_address=args.coordinator_address,
num_processes=args.num_processes,
process_id=args.process_id,
local_device_ids=global_device_ids_for_this_process,
)
_distributed_initialized = True
jax.clear_caches()
jax.config.update(
"jax_use_shardy_partitioner", False
) # CollectiveGEMM does not work with Shardy yet
assert jax.local_device_count() == 1, (
f"[{args.process_id}|{args.num_devices_per_process}] Expected 1 GPU per process, found"
f" {jax.local_device_count()}"
)
devices_per_process = 1
num_total_devices = args.num_processes
print(
f"Initializing CGEMM communicator with num_total_devices={num_total_devices},"
f" devices_per_process={devices_per_process}, process_id={args.process_id}"
)
collective_gemm_bootstrap(
num_total_devices=num_total_devices,
num_devices_per_process=devices_per_process,
process_id=args.process_id,
tensor_parallel_size=args.tensor_parallel_size,
)
def _get_dp_and_tp_sizes(args):
num_gpu = args.num_processes * args.num_devices_per_process
if args.tensor_parallel_size is None:
num_gpu_dp = 2 if args.enable_data_parallel else 1
assert (
num_gpu > 1 and num_gpu % num_gpu_dp == 0
), "Number of GPUs must be greater than 1 and divisible by number of data parallel GPUs"
num_gpu_tp = num_gpu // num_gpu_dp
else:
num_gpu_tp = args.tensor_parallel_size
assert (
num_gpu > 1 and num_gpu % num_gpu_tp == 0
), "Number of GPUs must be greater than 1 and divisible by number of data parallel GPUs"
num_gpu_dp = num_gpu // num_gpu_tp
return num_gpu_dp, num_gpu_tp
def _create_mesh(args):
"""Create mesh configuration with proper validation."""
num_gpu = args.num_processes * args.num_devices_per_process
assert num_gpu == len(jax.devices()), "Number of GPUs must be equal to number of devices"
num_gpu_dp, num_gpu_tp = _get_dp_and_tp_sizes(args)
print(f"Using {num_gpu_dp}x{num_gpu_tp} mesh ({num_gpu_dp * num_gpu_tp} total GPUs)")
device_mesh = mesh_utils.create_device_mesh((num_gpu_dp, num_gpu_tp))
mesh = jax.sharding.Mesh(devices=device_mesh, axis_names=(DP_AXIS, TPSP_AXIS))
return mesh
def cgemm_parser(description="Collective GEMM test on multi-GPU with tensor parallelism"):
"""Create common argument parser for all collective GEMM tests."""
parser = argparse.ArgumentParser(description=description)
# Distributed initialization arguments
parser.add_argument(
"--coordinator-address",
type=str,
default=None,
help="Coordinator address for distributed initialization",
)
parser.add_argument(
"--num-processes",
type=int,
default=None,
help="Number of processes for distributed initialization",
)
parser.add_argument(
"--process-id", type=int, default=None, help="Process ID for distributed initialization"
)
parser.add_argument(
"--local-device-ids",
type=str,
default=None,
help="Local device IDs for distributed initialization (comma-separated)",
)
parser.add_argument(
"--num-devices-per-process", type=int, default=1, help="Number of devices per process"
)
# Test configuration arguments
parser.add_argument(
"--tensor-parallel-size", type=int, default=None, help="Tensor parallel size"
)
parser.add_argument("--batch-size", type=int, default=4, help="Batch size for testing")
parser.add_argument("--seq-len", type=int, default=8192, help="Sequence length for testing")
parser.add_argument("--hidden-in", type=int, default=4096, help="Input hidden dimension")
parser.add_argument("--hidden-out", type=int, default=8192, help="Output hidden dimension")
parser.add_argument(
"--collective-type",
type=str,
default="all_gather",
choices=["all_gather", "reduce_scatter"],
help="Type of collective operation",
)
parser.add_argument(
"--fp8-recipe", type=str, default="DelayedScaling", help="FP8 recipe to use"
)
parser.add_argument(
"--enable-data-parallel", action="store_true", help="Enable data parallelism"
)
parser.add_argument(
"--enable-result-check", action="store_true", default=True, help="Enable result checking"
)
return parser
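# Typical wiring of these helpers (illustrative, mirroring the test scripts in this directory):
#   args = cgemm_parser().parse_args()
#   _initialize_distributed(args)  # jax.distributed.initialize + collective_gemm_bootstrap
#   mesh = _create_mesh(args)      # (data, tensor_sequence) mesh spanning all processes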
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""config for collective_gemm tests"""
import pytest
def pytest_addoption(parser):
"""Pytest hook for collective_gemm tests"""
parser.addoption("--coordinator-address", action="store", default="localhost:12345")
parser.addoption("--num-processes", action="store", default=1)
parser.addoption("--process-id", action="store", default=0)
parser.addoption("--local-device-ids", action="store", default=None)
@pytest.fixture(autouse=True)
def distributed_args(request):
"""Fixture for querying distributed initialization arguments"""
if request.cls:
request.cls.coordinator_address = request.config.getoption("--coordinator-address")
request.cls.num_processes = int(request.config.getoption("--num-processes"))
request.cls.process_id = int(request.config.getoption("--process-id"))
request.cls.local_device_ids = request.config.getoption("--local-device-ids")
request.cls.num_devices_per_process = (
1
if request.cls.local_device_ids is None
else len(request.cls.local_device_ids.split(","))
)
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
NUM_GPUS=${NUM_GPUS:-$(nvidia-smi -L | wc -l)}
# Check if NVLINK is supported before running tests
echo "*** Checking NVLINK support***"
NVLINK_OUTPUT=$(nvidia-smi nvlink --status 2>&1)
NVLINK_EXIT_CODE=$?
# Check if command failed OR output indicates no NVLINK
if [ $NVLINK_EXIT_CODE -ne 0 ] || [[ "$NVLINK_OUTPUT" == *"not supported"* ]] || [[ "$NVLINK_OUTPUT" == *"No devices"* ]] || [ -z "$NVLINK_OUTPUT" ]; then
echo "NVLINK is not supported on this platform"
echo "Collective GEMM tests require NVLINK connectivity"
echo "SKIPPING all tests"
exit 0
else
echo "NVLINK support detected"
fi
# Define the test files to run
TEST_FILES=(
"test_gemm.py"
"test_dense_grad.py"
"test_layernorm_mlp_grad.py"
)
echo
echo "*** Executing tests in examples/jax/collective_gemm/ ***"
HAS_FAILURE=0 # Global failure flag
PIDS=() # Array to store all process PIDs
# Cleanup function to kill all processes
cleanup() {
for pid in "${PIDS[@]}"; do
if kill -0 "$pid" 2>/dev/null; then
echo "Killing process $pid"
kill -TERM "$pid" 2>/dev/null || true
fi
done
# Wait a bit and force kill if needed
sleep 2
for pid in "${PIDS[@]}"; do
if kill -0 "$pid" 2>/dev/null; then
echo "Force killing process $pid"
kill -KILL "$pid" 2>/dev/null || true
fi
done
}
# Set up signal handlers to cleanup on exit
trap cleanup EXIT INT TERM
# Run each test file across all GPUs
for TEST_FILE in "${TEST_FILES[@]}"; do
echo
echo "=== Starting test file: $TEST_FILE ..."
# Clear PIDs array for this test file
PIDS=()
for i in $(seq 0 $(($NUM_GPUS - 1))); do
# Define output file for logs
LOG_FILE="${TEST_FILE}_gpu_${i}.log"
if [ $i -eq 0 ]; then
# For process 0: show live output AND save to log file using tee
echo "=== Live output from process 0 ==="
pytest -s -c "$TE_PATH/tests/jax/pytest.ini" \
-vs "$TE_PATH/examples/jax/collective_gemm/$TEST_FILE" \
--num-processes=$NUM_GPUS \
--process-id=$i 2>&1 | tee "$LOG_FILE" &
PID=$!
PIDS+=($PID)
else
# For other processes: redirect to log files only
pytest -s -c "$TE_PATH/tests/jax/pytest.ini" \
-vs "$TE_PATH/examples/jax/collective_gemm/$TEST_FILE" \
--num-processes=$NUM_GPUS \
--process-id=$i > "$LOG_FILE" 2>&1 &
PID=$!
PIDS+=($PID)
fi
done
# Wait for all processes to finish
wait
# Check and print the log content from process 0 (now has log file thanks to tee)
if grep -q "SKIPPED" "${TEST_FILE}_gpu_0.log"; then
echo "... $TEST_FILE SKIPPED"
elif grep -q "FAILED" "${TEST_FILE}_gpu_0.log"; then
echo "... $TEST_FILE FAILED"
HAS_FAILURE=1
else
echo "... $TEST_FILE PASSED"
fi
# Remove the log files after processing them
wait
rm ${TEST_FILE}_gpu_*.log
done
wait
# Final cleanup (trap will also call cleanup on exit)
cleanup
exit $HAS_FAILURE
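# Example invocation (illustrative): run the launcher directly with an explicit GPU count,
# assuming TE_PATH points at the TransformerEngine checkout:
#   TE_PATH=/path/to/TransformerEngine NUM_GPUS=4 bash $TE_PATH/examples/jax/collective_gemm/run_test_cgemm.sh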
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Collective Dense Gradient test on multi-GPU with tensor parallelism"""
import argparse
import unittest
import os
import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec, NamedSharding
import flax
from common import (
assert_allclose,
_initialize_distributed,
_get_dp_and_tp_sizes,
_create_mesh,
DP_AXIS,
TPSP_AXIS,
PARAMS_KEY,
cgemm_parser,
)
from transformer_engine.jax.dense import dense
from transformer_engine.jax.quantize import fp8_autocast
from transformer_engine.jax.cpp_extensions.gemm import (
CollectiveOp,
CollectiveOpSet,
noop_collective_op_set,
)
from transformer_engine.jax.sharding import MeshResource
import transformer_engine.jax.flax as te_flax
def _get_logical_axes(collective_op):
if collective_op.is_all_gather:
input_axes = (DP_AXIS, TPSP_AXIS, None)
weight_axes = (None, TPSP_AXIS)
bias_axes = (TPSP_AXIS,)
output_axes = (DP_AXIS, None, TPSP_AXIS)
else: # RS
input_axes = (DP_AXIS, None, TPSP_AXIS)
weight_axes = (TPSP_AXIS, None)
bias_axes = (None,)
output_axes = (DP_AXIS, TPSP_AXIS, None)
return input_axes, weight_axes, bias_axes, output_axes
def _get_operand_sharding(mesh, collective_op):
input_axes, weight_axes, bias_axes, _ = _get_logical_axes(collective_op)
x_sharding = NamedSharding(mesh, PartitionSpec(*input_axes))
weight_sharding = NamedSharding(mesh, PartitionSpec(*weight_axes))
bias_sharding = NamedSharding(mesh, PartitionSpec(*bias_axes))
return x_sharding, weight_sharding, bias_sharding
def _mean_dense(x, weight, bias, input_axes, weight_axes, output_axes, collective_op_set):
output = dense(
x,
weight,
bias,
contracting_dims=((2,), (0,)),
input_axes=input_axes,
kernel_axes=weight_axes,
output_axes=output_axes,
collective_op_set=collective_op_set,
)
return jnp.mean(output.astype(jnp.float32))
def _value_and_grad_dense(x, weight, bias, input_axes, weight_axes, output_axes, collective_op_set):
return jax.jit(jax.value_and_grad(_mean_dense, (0, 1, 2)), static_argnums=(3, 4, 5, 6))(
x, weight, bias, input_axes, weight_axes, output_axes, collective_op_set
)
def run_dense_grad_tests(args, mesh=None):
"""Execute Dense Gradient tests."""
print(args)
_initialize_distributed(args)
mesh = mesh or _create_mesh(args)
# Create test data
rng = jax.random.PRNGKey(0)
rng, x_rng, weight_rng, bias_rng = jax.random.split(rng, 4)
x = jax.random.normal(
x_rng, (args.batch_size, args.seq_len, args.hidden_in), dtype=jnp.bfloat16
)
weight = jax.random.normal(weight_rng, (args.hidden_in, args.hidden_out), dtype=jnp.bfloat16)
bias = jax.random.normal(bias_rng, (args.hidden_out,), dtype=jnp.bfloat16)
collective_op = (
CollectiveOp.ALL_GATHER
if args.collective_type == "all_gather"
else CollectiveOp.REDUCE_SCATTER
)
collective_op_set = CollectiveOpSet.create(forward_collective_op=collective_op)
with mesh, fp8_autocast(
enabled=False,
fp8_recipe=None,
mesh_resource=MeshResource(dp_resource=DP_AXIS, tpsp_resource=TPSP_AXIS),
):
# Get the base axis rules and extend them with TE's rules. This must be done inside fp8_autocast
axis_rules = flax.linen.get_logical_axis_rules()
axis_rules += ((TPSP_AXIS, TPSP_AXIS), (DP_AXIS, DP_AXIS))
te_extended_axis_rules = te_flax.extend_logical_axis_rules(axis_rules)
with flax.linen.logical_axis_rules(te_extended_axis_rules):
x_sharding, weight_sharding, bias_sharding = _get_operand_sharding(mesh, collective_op)
x_sharded = jax.device_put(x, x_sharding)
weight_sharded = jax.device_put(weight, weight_sharding)
bias_sharded = jax.device_put(bias, bias_sharding)
input_axes, weight_axes, _, output_axes = _get_logical_axes(collective_op)
ref_output, ref_grads = _value_and_grad_dense(
x_sharded,
weight_sharded,
bias_sharded,
input_axes,
weight_axes,
output_axes,
noop_collective_op_set,
)
output, sharded_grads = _value_and_grad_dense(
x_sharded,
weight_sharded,
bias_sharded,
input_axes,
weight_axes,
output_axes,
collective_op_set,
)
jax.block_until_ready(ref_output)
jax.block_until_ready(output)
gathered_grads = []
gathered_ref_grads = []
for ref_grad, grad in zip(ref_grads, sharded_grads):
gathered_grads.append(
jax.lax.with_sharding_constraint(grad, NamedSharding(mesh, PartitionSpec(None)))
)
gathered_ref_grads.append(
jax.lax.with_sharding_constraint(ref_grad, NamedSharding(mesh, PartitionSpec(None)))
)
jax.block_until_ready(gathered_grads)
jax.block_until_ready(gathered_ref_grads)
if args.enable_result_check and args.process_id == 0:
assert_allclose(ref_output, output, dtype=jnp.bfloat16)
for ref_grad, gathered_grad in zip(gathered_ref_grads, gathered_grads):
assert_allclose(ref_grad, gathered_grad, dtype=jnp.bfloat16)
class TestCollectiveDenseGradient(unittest.TestCase):
"""Collective Dense Gradient unittests"""
def setUp(self):
self.args = cgemm_parser(
"Collective Dense Gradient test on multi-GPU with tensor parallelism"
).parse_args([])
self.args.coordinator_address = self.coordinator_address
self.args.num_processes = self.num_processes
self.args.process_id = self.process_id
self.args.local_device_ids = self.local_device_ids
self.args.num_devices_per_process = self.num_devices_per_process
self.args.enable_data_parallel = True
self.args.tensor_parallel_size = _get_dp_and_tp_sizes(self.args)[1]
_initialize_distributed(self.args)
# Create mesh once for all tests
self.mesh = _create_mesh(self.args)
jax.sharding.set_mesh(self.mesh)
self.args.enable_result_check = True
os.environ["NVTE_JAX_ALL_REDUCE_IN_FP32"] = "1"
def tearDown(self):
os.environ.pop("NVTE_JAX_ALL_REDUCE_IN_FP32", None)
def test_te_bf16_all_gather(self):
"""Test Collective Dense Gradient with AllGather"""
self.args.collective_type = "all_gather"
run_dense_grad_tests(self.args, self.mesh)
def test_te_bf16_reduce_scatter(self):
"""Test Collective Dense Gradient with ReduceScatter"""
self.args.collective_type = "reduce_scatter"
run_dense_grad_tests(self.args, self.mesh)
if __name__ == "__main__":
import sys
if len(sys.argv) < 7: # Need at least the 3 required distributed args
print("Error: This script requires distributed initialization arguments.")
print(
"Usage: python test_dense_grad.py --coordinator-address <address> --num-processes <num>"
" --process-id <id> [--local-device-ids <ids>] [other args]"
)
print(
"Example: python test_dense_grad.py --coordinator-address localhost:1234"
" --num-processes 4 --process-id 0"
)
print(
"Example: python test_dense_grad.py --coordinator-address localhost:1234"
" --num-processes 2 --process-id 0 --local-device-ids 0,1,2,3"
)
sys.exit(1)
args = cgemm_parser(
"Collective Dense Gradient test on multi-GPU with tensor parallelism"
).parse_args()
_initialize_distributed(args)
run_dense_grad_tests(args, mesh=None)
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Collective GEMM test on multi-GPU with tensor parallelism
This script uses custom distributed initialization with the following arguments:
- --coordinator-address: Coordinator address for distributed initialization
- --num-processes: Number of processes for distributed initialization
- --process-id: Process ID for distributed initialization
- --local-device-ids: Local device IDs for distributed initialization
Example:
python test_gemm.py --coordinator-address localhost:1234 --num-processes 2 --process-id 0 --local-device-ids 0,1,2,3
"""
import unittest
import os
from functools import partial
import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec, NamedSharding
from common import (
assert_allclose,
_initialize_distributed,
_get_dp_and_tp_sizes,
_create_mesh,
DP_AXIS,
TPSP_AXIS,
PARAMS_KEY,
cgemm_parser,
)
import transformer_engine.jax.cpp_extensions as tex
from transformer_engine.jax.quantize import fp8_autocast
from transformer_engine.jax.cpp_extensions.gemm import CollectiveOp
from transformer_engine.jax.sharding import MeshResource
def _get_operand_sharding(mesh, collective_op, is_with_dp):
dp_axis = DP_AXIS if is_with_dp else None
if collective_op == CollectiveOp.ALL_GATHER:
x_sharding = NamedSharding(mesh, PartitionSpec(dp_axis, TPSP_AXIS, None))
weight_sharding = NamedSharding(mesh, PartitionSpec(None, TPSP_AXIS))
bias_sharding = NamedSharding(mesh, PartitionSpec(TPSP_AXIS))
output_sharding = NamedSharding(mesh, PartitionSpec(dp_axis, None, TPSP_AXIS))
else: # RS
x_sharding = NamedSharding(mesh, PartitionSpec(dp_axis, None, TPSP_AXIS))
weight_sharding = NamedSharding(mesh, PartitionSpec(TPSP_AXIS, None))
bias_sharding = NamedSharding(mesh, PartitionSpec(None))
output_sharding = NamedSharding(mesh, PartitionSpec(dp_axis, TPSP_AXIS, None))
return x_sharding, weight_sharding, bias_sharding, output_sharding
def _get_dp_and_tp_sizes(args):
num_gpu = args.num_processes * args.num_devices_per_process
if args.tensor_parallel_size is None:
num_gpu_dp = 2 if args.enable_data_parallel else 1
assert (
num_gpu > 1 and num_gpu % num_gpu_dp == 0
), "Number of GPUs must be greater than 1 and divisible by number of data parallel GPUs"
num_gpu_tp = num_gpu // num_gpu_dp
else:
num_gpu_tp = args.tensor_parallel_size
assert (
num_gpu > 1 and num_gpu % num_gpu_tp == 0
), "Number of GPUs must be greater than 1 and divisible by number of data parallel GPUs"
num_gpu_dp = num_gpu // num_gpu_tp
return num_gpu_dp, num_gpu_tp
@partial(jax.jit, static_argnames=("contracting_dims", "collective_op", "output_sharding"))
def _jitted_cgemm(x, weight, bias, contracting_dims, collective_op, output_sharding):
output = tex.gemm(
x,
weight,
bias=bias,
contracting_dims=contracting_dims,
collective_op=collective_op,
)
if output_sharding is not None:
output = jax.lax.with_sharding_constraint(output, output_sharding)
return output
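# Illustrative call (mirrors run_gemm_tests below): overlap the all-gather of the
# sequence-sharded input with the matmul and let the custom op produce the output sharding:
#   out = _jitted_cgemm(x_sharded, weight_sharded, bias_sharded,
#                       contracting_dims=((2,), (0,)),
#                       collective_op=CollectiveOp.ALL_GATHER,
#                       output_sharding=None)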
def run_gemm_tests(args, mesh=None):
"""Execute GEMM tests."""
print(args)
# Collective GEMM requires Shardy partitioner to be disabled
jax.config.update("jax_use_shardy_partitioner", False)
# Initialize distributed with provided arguments
_initialize_distributed(args)
mesh = mesh or _create_mesh(args)
# Create test data
rng = jax.random.PRNGKey(0)
rng, x_rng, weight_rng, bias_rng = jax.random.split(rng, 4)
x = jax.random.normal(
x_rng, (args.batch_size, args.seq_len, args.hidden_in), dtype=jnp.bfloat16
)
weight = jax.random.normal(weight_rng, (args.hidden_in, args.hidden_out), dtype=jnp.bfloat16)
bias = jax.random.normal(bias_rng, (args.hidden_out,), dtype=jnp.bfloat16)
collective_op = (
CollectiveOp.ALL_GATHER
if args.collective_type == "all_gather"
else CollectiveOp.REDUCE_SCATTER
)
with mesh, fp8_autocast(
enabled=False,
fp8_recipe=None,
mesh_resource=MeshResource(dp_resource=DP_AXIS, tpsp_resource=TPSP_AXIS),
):
print(f"Device mesh: {mesh}")
x_sharding, weight_sharding, bias_sharding, output_sharding = _get_operand_sharding(
mesh, collective_op, args.enable_data_parallel
)
x_sharded = jax.device_put(x, x_sharding)
weight_sharded = jax.device_put(weight, weight_sharding)
bias_sharded = jax.device_put(bias, bias_sharding)
ref_output = _jitted_cgemm(
x_sharded,
weight_sharded,
bias_sharded,
contracting_dims=((2,), (0,)),
collective_op=CollectiveOp.NONE,
output_sharding=output_sharding,
)
output = _jitted_cgemm(
x_sharded,
weight_sharded,
bias_sharded,
contracting_dims=((2,), (0,)),
collective_op=collective_op,
# CollectiveGEMM output should already have the correct sharding without applying a sharding constraint
output_sharding=None,
)
assert (
ref_output.sharding == output.sharding
), f"ref_output.sharding={ref_output.sharding}, output.sharding={output.sharding}"
gathered_ref_output = jax.lax.with_sharding_constraint(
ref_output, NamedSharding(mesh, PartitionSpec(None))
)
gathered_output = jax.lax.with_sharding_constraint(
output, NamedSharding(mesh, PartitionSpec(None))
)
jax.block_until_ready(gathered_ref_output)
jax.block_until_ready(gathered_output)
if args.enable_result_check and args.process_id == 0:
assert_allclose(gathered_ref_output, gathered_output)
class TestCollectiveGemmWithDP(unittest.TestCase):
"""Collective GEMM with DP unittests"""
def setUp(self):
self.args = cgemm_parser(
"Collective GEMM test on multi-GPU with tensor parallelism"
).parse_args([])
self.args.coordinator_address = self.coordinator_address
self.args.num_processes = self.num_processes
self.args.process_id = self.process_id
self.args.local_device_ids = self.local_device_ids
self.args.num_devices_per_process = self.num_devices_per_process
self.args.enable_data_parallel = True
self.args.tensor_parallel_size = _get_dp_and_tp_sizes(self.args)[1]
_initialize_distributed(self.args)
self.mesh = _create_mesh(self.args)
jax.sharding.set_mesh(self.mesh)
self.args.enable_result_check = True
os.environ["NVTE_JAX_ALL_REDUCE_IN_FP32"] = "1"
def tearDown(self):
os.environ.pop("NVTE_JAX_ALL_REDUCE_IN_FP32", None)
def test_te_bf16_all_gather_with_dp(self):
"""Test Collective GEMM with AllGather"""
self.args.collective_type = "all_gather"
run_gemm_tests(self.args, self.mesh)
def test_te_bf16_reduce_scatter_with_dp(self):
"""Test Collective GEMM with ReduceScatter"""
self.args.collective_type = "reduce_scatter"
run_gemm_tests(self.args, self.mesh)
if __name__ == "__main__":
import sys
if len(sys.argv) < 7: # Need at least the 3 required distributed args
print("Error: This script requires distributed initialization arguments.")
print(
"Usage: python test_gemm.py --coordinator-address <address> --num-processes <num>"
" --process-id <id> [--local-device-ids <ids>] [other args]"
)
sys.exit(1)
args = cgemm_parser("Collective GEMM test on multi-GPU with tensor parallelism").parse_args()
_initialize_distributed(args)
run_gemm_tests(args, mesh=None)
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Collective Dense Gradient test on multi-GPU with tensor parallelism"""
import argparse
import unittest
import os
import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec, NamedSharding
import flax
from common import (
assert_allclose,
_initialize_distributed,
_get_dp_and_tp_sizes,
_create_mesh,
DP_AXIS,
TPSP_AXIS,
PARAMS_KEY,
cgemm_parser,
)
from transformer_engine.jax.layernorm_mlp import layernorm_mlp
from transformer_engine.jax.quantize import fp8_autocast
from transformer_engine.jax.cpp_extensions.gemm import (
CollectiveOpSet,
CollectiveOp,
noop_collective_op_set,
)
from transformer_engine.jax.sharding import MeshResource
import transformer_engine.jax.flax as te_flax
def _get_logical_axes():
input_1_axes = (DP_AXIS, TPSP_AXIS, None)
weight_1_axes = (None, None, TPSP_AXIS)
bias_axes_1 = (None, TPSP_AXIS)
input_2_axes = (DP_AXIS, None, TPSP_AXIS)
weight_2_axes = (TPSP_AXIS, None)
bias_axes_2 = (None,)
return input_1_axes, weight_1_axes, bias_axes_1, input_2_axes, weight_2_axes, bias_axes_2
def _get_operand_sharding(mesh):
input_1_axes, weight_1_axes, bias_axes_1, input_2_axes, weight_2_axes, bias_axes_2 = (
_get_logical_axes()
)
x_sharding = NamedSharding(mesh, PartitionSpec(*input_1_axes))
weight_1_sharding = NamedSharding(mesh, PartitionSpec(*weight_1_axes))
bias_1_sharding = NamedSharding(mesh, PartitionSpec(*bias_axes_1))
weight_2_sharding = NamedSharding(mesh, PartitionSpec(*weight_2_axes))
bias_2_sharding = NamedSharding(mesh, PartitionSpec(*bias_axes_2))
return x_sharding, weight_1_sharding, bias_1_sharding, weight_2_sharding, bias_2_sharding
def _mean_layernorm_mlp(
x,
weight_1,
bias_1,
weight_2,
bias_2,
gamma,
input_1_axes,
input_2_axes,
weight_1_axes,
weight_2_axes,
collective_op_sets,
):
output = layernorm_mlp(
x,
gamma,
beta=None,
kernels=[weight_1, weight_2],
biases=[bias_1, bias_2],
norm_type="rmsnorm",
dot_1_input_axes=input_1_axes,
dot_2_input_axes=input_2_axes,
kernel_1_axes=weight_1_axes,
kernel_2_axes=weight_2_axes,
activation_type=("gelu",),
collective_op_sets=collective_op_sets,
)
return jnp.mean(output)
def _value_and_grad_layernorm_mlp(
x,
weight_1,
bias_1,
weight_2,
bias_2,
gamma,
input_1_axes,
input_2_axes,
weight_1_axes,
weight_2_axes,
collective_op_sets,
):
return jax.jit(
jax.value_and_grad(_mean_layernorm_mlp, (0, 1, 2, 3, 4, 5)), static_argnums=(6, 7, 8, 9, 10)
)(
x,
weight_1,
bias_1,
weight_2,
bias_2,
gamma,
input_1_axes,
input_2_axes,
weight_1_axes,
weight_2_axes,
collective_op_sets,
)
def run_layernorm_mlp_grad_tests(args, mesh=None):
"""Execute Dense Gradient tests."""
print(args)
# Collective GEMM requires Shardy partitioner to be disabled
jax.config.update("jax_use_shardy_partitioner", False)
# Initialize distributed with provided arguments
_initialize_distributed(args)
mesh = mesh or _create_mesh(args)
# Create test data
rng = jax.random.PRNGKey(0)
rng, x_rng, weight_1_rng, bias_1_rng, weight_2_rng, bias_2_rng, gamma_rng = jax.random.split(
rng, 7
)
x = jax.random.normal(
x_rng, (args.batch_size, args.seq_len, args.hidden_in), dtype=jnp.bfloat16
)
weight_1 = jax.random.normal(
weight_1_rng, (args.hidden_in, 1, args.hidden_out), dtype=jnp.bfloat16
) / jnp.sqrt(args.hidden_in)
bias_1 = jax.random.normal(bias_1_rng, (1, args.hidden_out), dtype=jnp.bfloat16)
weight_2 = jax.random.normal(
weight_2_rng, (args.hidden_out, args.hidden_in), dtype=jnp.bfloat16
) / jnp.sqrt(args.hidden_out)
bias_2 = jax.random.normal(bias_2_rng, (args.hidden_in,), dtype=jnp.bfloat16)
gamma = jax.random.normal(gamma_rng, (args.hidden_in,), dtype=jnp.bfloat16) / jnp.sqrt(
args.hidden_in
)
collective_op_set_1 = CollectiveOpSet.create(forward_collective_op=CollectiveOp.ALL_GATHER)
collective_op_set_2 = CollectiveOpSet.create(forward_collective_op=CollectiveOp.REDUCE_SCATTER)
collective_op_sets = (collective_op_set_1, collective_op_set_2)
noop_collective_op_sets = (noop_collective_op_set, noop_collective_op_set)
with mesh, fp8_autocast(
enabled=False,
fp8_recipe=None,
mesh_resource=MeshResource(dp_resource=DP_AXIS, tpsp_resource=TPSP_AXIS),
):
# Get the base axis rules and extend them with TE's rules. This must be done inside fp8_autocast
axis_rules = flax.linen.get_logical_axis_rules()
axis_rules += ((TPSP_AXIS, TPSP_AXIS), (DP_AXIS, DP_AXIS))
te_extended_axis_rules = te_flax.extend_logical_axis_rules(axis_rules)
with flax.linen.logical_axis_rules(te_extended_axis_rules):
x_sharding, weight_1_sharding, bias_1_sharding, weight_2_sharding, bias_2_sharding = (
_get_operand_sharding(mesh)
)
x_sharded = jax.device_put(x, x_sharding)
weight_1_sharded = jax.device_put(weight_1, weight_1_sharding)
bias_1_sharded = jax.device_put(bias_1, bias_1_sharding)
weight_2_sharded = jax.device_put(weight_2, weight_2_sharding)
bias_2_sharded = jax.device_put(bias_2, bias_2_sharding)
input_1_axes, weight_1_axes, _, input_2_axes, weight_2_axes, _ = _get_logical_axes()
ref_output, ref_grads = _value_and_grad_layernorm_mlp(
x_sharded,
weight_1_sharded,
bias_1_sharded,
weight_2_sharded,
bias_2_sharded,
gamma,
input_1_axes,
input_2_axes,
weight_1_axes,
weight_2_axes,
noop_collective_op_sets,
)
output, sharded_grads = _value_and_grad_layernorm_mlp(
x_sharded,
weight_1_sharded,
bias_1_sharded,
weight_2_sharded,
bias_2_sharded,
gamma,
input_1_axes,
input_2_axes,
weight_1_axes,
weight_2_axes,
collective_op_sets,
)
jax.block_until_ready(ref_output)
jax.block_until_ready(output)
gathered_grads = []
gathered_ref_grads = []
for ref_grad, grad in zip(ref_grads, sharded_grads):
gathered_grads.append(
jax.lax.with_sharding_constraint(grad, NamedSharding(mesh, PartitionSpec(None)))
)
gathered_ref_grads.append(
jax.lax.with_sharding_constraint(ref_grad, NamedSharding(mesh, PartitionSpec(None)))
)
jax.block_until_ready(gathered_grads)
jax.block_until_ready(gathered_ref_grads)
if args.enable_result_check and args.process_id == 0:
assert_allclose(ref_output, output, dtype=jnp.bfloat16)
for ref_grad, gathered_grad in zip(gathered_ref_grads, gathered_grads):
assert_allclose(ref_grad, gathered_grad, dtype=jnp.bfloat16)
class TestCollectiveLayerNormMLPGradient(unittest.TestCase):
"""Collective Dense Gradient unittests"""
def setUp(self):
self.args = cgemm_parser(
"Collective LayerNorm MLP Gradient test on multi-GPU with tensor parallelism"
).parse_args([])
self.args.coordinator_address = self.coordinator_address
self.args.num_processes = self.num_processes
self.args.process_id = self.process_id
self.args.local_device_ids = self.local_device_ids
self.args.num_devices_per_process = self.num_devices_per_process
self.args.enable_data_parallel = True
self.args.tensor_parallel_size = _get_dp_and_tp_sizes(self.args)[1]
_initialize_distributed(self.args)
# Create mesh once for all tests
self.mesh = _create_mesh(self.args)
jax.sharding.set_mesh(self.mesh)
self.args.enable_result_check = True
os.environ["NVTE_JAX_ALL_REDUCE_IN_FP32"] = "1"
def tearDown(self):
os.environ.pop("NVTE_JAX_ALL_REDUCE_IN_FP32", None)
def test_te_bf16_layernorm_mlp_grad(self):
"""Test Collective Dense Gradient with AllGather"""
run_layernorm_mlp_grad_tests(self.args, self.mesh)
if __name__ == "__main__":
import sys
if len(sys.argv) < 7: # Need at least the 3 required distributed args
print("Error: This script requires distributed initialization arguments.")
print(
"Usage: python test_layernorm_mlp_grad.py --coordinator-address <address>"
" --num-processes <num> --process-id <id> [--local-device-ids <ids>] [other args]"
)
print(
"Example: python test_layernorm_mlp_grad.py --coordinator-address localhost:1234"
" --num-processes 4 --process-id 0"
)
print(
"Example: python test_layernorm_mlp_grad.py --coordinator-address localhost:1234"
" --num-processes 2 --process-id 0 --local-device-ids 0,1,2,3"
)
sys.exit(1)
args = cgemm_parser(
"Collective LayerNorm MLP Gradient test on multi-GPU with tensor parallelism"
).parse_args()
_initialize_distributed(args)
run_layernorm_mlp_grad_tests(args, mesh=None)
@@ -29,6 +29,10 @@ wait
python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_model_parallel_encoder.xml $TE_PATH/examples/jax/encoder/test_model_parallel_encoder.py || test_fail "test_model_parallel_encoder.py"
wait
TE_PATH=$TE_PATH bash $TE_PATH/examples/jax/encoder/run_test_multiprocessing_encoder.sh || test_fail "run_test_multiprocessing_encoder.sh"
wait
TE_PATH=$TE_PATH bash $TE_PATH/examples/jax/collective_gemm/run_test_cgemm.sh || test_fail "run_test_cgemm.sh"
wait
if [ $RET -ne 0 ]; then
echo "Error: some sub-tests failed: $FAILED_CASES"
...
@@ -64,6 +64,11 @@ CommOverlapCore::CommOverlapCore(int myrank, int numranks, int mylocal, int numl
#endif
_comm_created = true;
}
initialize(tp_size, num_splits, num_max_streams, comm_cga_size, gemm_priority, comm_priority,
num_comm_sm, set_sm_margin, use_ce, atomic_gemm);
}
void CommOverlapCore::initialize(int tp_size, int num_splits, int num_max_streams,
int comm_cga_size, int gemm_priority, int comm_priority,
int num_comm_sm, bool set_sm_margin, bool use_ce,
bool atomic_gemm) {
_use_ce = static_cast<int>(use_ce);
_num_comm_sm = num_comm_sm;
_cga_size = comm_cga_size;
@@ -278,6 +287,11 @@ CommOverlapBase::CommOverlapBase(const std::vector<size_t> &buffer_shape, DType
allgather_handle, barrier_handle, num_splits, num_max_streams, comm_cga_size,
gemm_priority, comm_priority, num_comm_sm, set_sm_margin, false,
atomic_gemm) {
initialize(buffer_shape, buffer_dtype, rs_overlap_first_gemm);
}
void CommOverlapBase::initialize(const std::vector<size_t> &buffer_shape, DType buffer_dtype,
bool rs_overlap_first_gemm) {
_rs_overlap_first_gemm = rs_overlap_first_gemm;
_rs_kernel_type = getenv<int>("NVTE_RS_STRIDED_ATOMIC", 0);
NVTE_CHECK(_rs_kernel_type >= 0 && _rs_kernel_type <= 3,
@@ -288,7 +302,9 @@ CommOverlapBase::CommOverlapBase(const std::vector<size_t> &buffer_shape, DType
size_t buffer_bytes = get_buffer_size_bytes(buffer_shape[0], buffer_shape[1], buffer_dtype);
void *buffer_ptr;
_ub_reg = register_user_buffer_collective(&buffer_ptr, buffer_bytes, _ub_comm, true);
if (_ub_comm->myrank == 0) {
printf("!!! [UB] Register UBuf %d\n", _ub_reg);
}
_ubuf = TensorWrapper(buffer_ptr, buffer_shape, buffer_dtype);
NVTE_CHECK_CUDA(
@@ -640,6 +656,11 @@ CommOverlapP2PBase::CommOverlapP2PBase(const std::vector<size_t> &buffer_shape,
allgather_handle, barrier_handle, tp_size, num_max_streams, comm_cga_size,
gemm_priority, comm_priority, num_comm_sm, set_sm_margin, use_ce,
atomic_gemm) {
initialize(buffer_shape, buffer_dtype, comm_type, aggregate);
}
void CommOverlapP2PBase::initialize(const std::vector<size_t> &buffer_shape, DType buffer_dtype,
CommOverlapType comm_type, bool aggregate) {
_is_p2p = true;
_is_reduce_scatter = comm_type == CommOverlapType::RS;
_aggregate = aggregate;
@@ -647,28 +668,28 @@ CommOverlapP2PBase::CommOverlapP2PBase(const std::vector<size_t> &buffer_shape,
// Create workspace tensor with userbuffer
NVTE_CHECK(buffer_shape.size() == 2, "Userbuffer shape must be 2-dimensional!");
size_t buffer_bytes = get_buffer_size_bytes(buffer_shape[0], buffer_shape[1], buffer_dtype);
int buffer_chunk_bytes = buffer_bytes / _tp_size;
_num_ubuf_chunks = _tp_size;
if (_is_reduce_scatter) {
// GEMM + RS overlap: Allocate `2 x tp_size - 1` buffers to hold recieved GEMM chunk
// outputs for reduction at the end of the pipelining.
buffer_bytes = buffer_bytes / _tp_size * (_tp_size * 2 - 1);
_num_ubuf_chunks = _tp_size * 2 - 1;
}
void *buffer_ptr;
_ub_reg = register_user_buffer_collective(&buffer_ptr, buffer_bytes, _ub_comm, true);
if (_rank == 0) printf("!!! [UBP2P] UBuf %d\n", _ub_reg);
_ubuf = TensorWrapper(
buffer_ptr,
std::vector<size_t>{buffer_shape[0] / _tp_size * _num_ubuf_chunks, buffer_shape[1]},
buffer_dtype);
// Create tensor chunks for easy management
char *ubuf_byte_ptr = reinterpret_cast<char *>(buffer_ptr);
for (int i = 0; i < _num_ubuf_chunks; i++) {
_ubufs.push_back(TensorWrapper(reinterpret_cast<void *>(ubuf_byte_ptr),
std::vector<size_t>{buffer_shape[0] / _tp_size, buffer_shape[1]},
buffer_dtype));
ubuf_byte_ptr += buffer_chunk_bytes;
}
@@ -691,7 +712,7 @@ CommOverlapP2PBase::CommOverlapP2PBase(const std::vector<size_t> &buffer_shape,
NVTE_CHECK_CUDA(cudaMemset(_counter.dptr(), 0, sizeof(int32_t)));
}
for (int i = 0; i < _stream_compute.size(); i++) {
cudaStream_t stream;
NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&stream, cudaStreamNonBlocking, _comm_priority));
_stream_send.push_back(std::move(stream));
@@ -711,6 +732,38 @@ CommOverlapP2PBase::~CommOverlapP2PBase() {
}
}
void CommOverlapP2PBase::copy_into_buffer(cudaStream_t stream, const TensorWrapper &source,
bool local_chunk, bool rowwise) {
// Check element size
const size_t element_size = source.element_size();
NVTE_CHECK(_ubuf.element_size() == element_size,
"Tried to copy data into a Userbuffers buffer but dtypes are not compatible ",
"(source dtype has ", element_size, " bytes, UB dtype has ", _ubuf.element_size(),
" bytes)");
// Input data
const size_t source_size = source.numel();
const void *src_ptr = (rowwise) ? source.dptr() : source.columnwise_dptr();
// Userbuffers data
void *dst_ptr;
if (local_chunk) {
NVTE_CHECK(_ubufs[_tp_id].numel() == source_size,
"Tried to copy an invalid tensor into a local chunk of a Userbuffers buffer ",
"(source_size=", source_size, ", local_ubuf_size=", _ubufs[_tp_id].numel(), ")");
dst_ptr = _ubufs[_tp_id].dptr();
} else {
NVTE_CHECK(_ubuf.numel() == source_size,
"Tried to copy an invalid tensor into a Userbuffers buffer ",
"(source_size=", source_size, ", ubuf_size=", _ubuf.numel(), ")");
dst_ptr = _ubuf.dptr();
}
// Copy data
NVTE_CHECK_CUDA(cudaMemcpyAsync(dst_ptr, src_ptr, source_size * element_size,
cudaMemcpyDeviceToDevice, stream));
}
TensorWrapper CommOverlapP2PBase::get_buffer_chunk_by_id(const TensorWrapper &source,
size_t chunk_id) {
// Start with a chunk of the source tensor
@@ -851,6 +904,15 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
const bool do_gelu = pre_gelu_out.numel() > 0;
size_t workspace_size_chunk = workspace.numel() / _stream_compute.size();
// Check B copy sizing
if (B_copy.numel() > 0) {
NVTE_CHECK(B_copy.numel() == _ubuf.numel(), "Expected all-gathered B copy buffer with ",
_ubuf.numel(), " elements but got ", B_copy.numel());
NVTE_CHECK(B_copy.element_size() == _ubuf.element_size(),
"Expected all-gathered B copy buffer with ", _ubuf.element_size() * 8,
"-bit data type but got ", B_copy.element_size() * 8, "-bit");
}
NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, stream_main));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send[0], _start_compute, 0));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_recv, _start_compute, 0));
@@ -919,12 +981,6 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send[0], _stop_recv, 0));
NVTE_CHECK_CUDA(
cudaStreamWaitEvent(_stream_compute[(i + 1) % _stream_compute.size()], _stop_recv, 0));
}
}
} else {
@@ -972,16 +1028,16 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send[0], _stop_recv, 0));
NVTE_CHECK_CUDA(
cudaStreamWaitEvent(_stream_compute[(i + 1) % _stream_compute.size()], _stop_recv, 0));
}
}
}
// Copy all-gathered B from communication buffer into auxiliary output
if (B_copy.numel() > 0) {
NVTE_CHECK_CUDA(cudaMemcpyAsync(B_copy.dptr(), _ubuf.dptr(), _ubuf.bytes(),
cudaMemcpyDeviceToDevice, _stream_send[0]));
}
_ub_comm->sms = ori_sms;
for (size_t i = 0; i < _stream_compute.size(); i++) {
NVTE_CHECK_CUDA(cudaEventRecord(_stop_compute, _stream_compute[i]));
...
@@ -670,9 +670,36 @@ int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator *
reinterpret_cast<void *>(&memhndl), sizeof(cudaIpcMemHandle_t),
comm->comm_intra);
// Check for NVLINK support before attempting IPC operations
if (comm->nvsize > 1) {
int current_device;
NVTE_CHECK_CUDA(cudaGetDevice(&current_device));
cudaDeviceProp deviceProp;
NVTE_CHECK_CUDA(cudaGetDeviceProperties(&deviceProp, current_device));
bool peer_access_available = false;
for (int i = 0; i < comm->nvsize; i++) {
if (i != comm->nvrank) {
int can_access_peer;
cudaError_t peer_result = cudaDeviceCanAccessPeer(&can_access_peer, current_device, i);
if (peer_result == cudaSuccess && can_access_peer) {
peer_access_available = true;
break;
}
}
}
if (!peer_access_available) {
free(tmp);
NVTE_ERROR(
"No peer-to-peer access available between GPUs. This platform does not support the "
"GPU-to-GPU "
"communication required for multi-GPU userbuffers. Consider using single-GPU mode.");
return 1;
}
}
for (int i = 0; i < comm->nvsize; i++) {
if (i != comm->nvrank) {
NVTE_CHECK_CUDA(cudaIpcOpenMemHandle(&(comm->peer_ptr[hndl][i]), tmp[i],
cudaIpcMemLazyEnablePeerAccess));
}
}
@@ -693,4 +720,5 @@ int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator *
comm->mem_ptr[hndl] = *gpubuff;
return comm->free_region++;
printf("***** Returning *****\n");
}
@@ -67,6 +67,11 @@ class CommOverlapCore {
std::vector<cudaStream_t> _stream_compute;
cudaEvent_t _start_compute, _stop_compute, _start_comm, _stop_comm, _comm_launch_event;
private:
void initialize(int tp_size, int num_splits, int num_max_streams, int comm_cga_size,
int gemm_priority, int comm_priority, int num_comm_sm, bool set_sm_margin,
bool use_ce, bool atomic_gemm);
public:
CommOverlapCore() {} // dummy constructor for exposing type to Python
@@ -78,17 +83,26 @@
virtual ~CommOverlapCore();
void *get_ubuf_dptr() { return _ubuf.dptr(); }
void set_ubuf_scale_inv(float *scale_inv) {
_ubuf_scale_inv = scale_inv;
_ubuf_scale_inv_initialized = true;
}
virtual void copy_into_buffer(cudaStream_t stream, const TensorWrapper &source, bool local_chunk,
bool rowwise = true) {
NVTE_ERROR("Operation is not implemented.");
}
TensorWrapper get_tensor_chunk(const TensorWrapper &source, size_t offset,
const std::vector<size_t> &shape);
TensorWrapper get_buffer_chunk_like(const TensorWrapper &source, size_t offset,
const std::vector<size_t> &shape);
int get_tp_size() { return _tp_size; }
bool is_atomic_gemm() { return _atomic_gemm; }
bool is_p2p_overlap() { return _is_p2p; }
@@ -148,6 +162,10 @@ class CommOverlapBase : public CommOverlapCore {
cudaStream_t _stream_comm;
cudaEvent_t _start_d2dcopy;
private:
void initialize(const std::vector<size_t> &buffer_shape, DType buffer_dtype,
bool rs_overlap_first_gemm);
public:
CommOverlapBase() {} // dummy constructor for exposing type to Python
@@ -224,6 +242,10 @@ class CommOverlapP2PBase : public CommOverlapCore {
cudaStream_t _stream_recv;
cudaEvent_t _stop_send, _stop_recv;
private:
void initialize(const std::vector<size_t> &buffer_shape, DType buffer_dtype,
CommOverlapType comm_type, bool aggregate);
public:
CommOverlapP2PBase() {} // dummy constructor for exposing type to Python
@@ -237,6 +259,9 @@ class CommOverlapP2PBase : public CommOverlapCore {
virtual ~CommOverlapP2PBase();
void copy_into_buffer(cudaStream_t stream, const TensorWrapper &source, bool local_chunk,
bool rowwise = true) override;
TensorWrapper get_buffer_chunk_by_id(const TensorWrapper &source, size_t buffer_id);
void bulk_overlap(const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb,
...
@@ -12,6 +12,8 @@
#include <cudnn.h>
#include <nvrtc.h>
#include "nccl.h"
#ifdef NVTE_WITH_CUBLASMP
#include <cublasmp.h>
#endif // NVTE_WITH_CUBLASMP
@@ -104,4 +106,12 @@
#endif // NVTE_WITH_CUBLASMP
#define NVTE_CHECK_NCCL(expr) \
do { \
const ncclResult_t status_NVTE_CHECK_NCCL = (expr); \
if (status_NVTE_CHECK_NCCL != ncclSuccess) { \
NVTE_ERROR("NCCL Error: ", ncclGetErrorString(status_NVTE_CHECK_NCCL)); \
} \
} while (false)
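// Usage (illustrative): wrap NCCL calls so failures surface as TE errors, e.g.
//   NVTE_CHECK_NCCL(ncclGetUniqueId(&unique_id));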
#endif // TRANSFORMER_ENGINE_COMMON_UTIL_LOGGING_H_
@@ -293,3 +293,11 @@ class NamedSharding(jax.sharding.NamedSharding):
Create a new NamedSharding with the same mesh and spec but with a new description.
"""
return NamedSharding(self.mesh, self.spec, desc=desc)
@functools.lru_cache(maxsize=1)
def is_all_reduce_in_float32():
"""
Check if all-reduce is in float32
"""
return os.getenv("NVTE_JAX_ALL_REDUCE_IN_FP32", "0") == "1"
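# Example (illustrative): the collective_gemm tests opt in via
#   os.environ["NVTE_JAX_ALL_REDUCE_IN_FP32"] = "1"
# Note the lru_cache: the flag is read once, so it must be set before the first call.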
@@ -13,6 +13,7 @@
#include <cudnn.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <transformer_engine/comm_gemm_overlap.h>
#include <transformer_engine/normalization.h>
#include <transformer_engine/transformer_engine.h>
@@ -32,9 +33,6 @@
#include "transformer_engine/activation.h"
#include "transformer_engine/multi_stream.h"
namespace transformer_engine {
namespace jax {
@@ -121,6 +119,7 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
// GEMM
XLA_FFI_DECLARE_HANDLER_SYMBOL(GemmHandler);
XLA_FFI_DECLARE_HANDLER_SYMBOL(CollectiveGemmInitHandler);
// Grouped GEMM
XLA_FFI_DECLARE_HANDLER_SYMBOL(GroupedGemmHandler);
@@ -134,4 +133,8 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(CublasHandleInitHandler);
} // namespace jax
} // namespace transformer_engine
// ENUM_ATTR and DICT_ATTR recoding need to be registered in the global namespace
XLA_FFI_REGISTER_ENUM_ATTR_DECODING(transformer_engine::jax::JAXX_Scaling_Mode);
XLA_FFI_REGISTER_ENUM_ATTR_DECODING(transformer_engine::jax::JAXX_Collective_Op);
#endif // TRANSFORMER_ENGINE_JAX_CSRC_FP8_MODULES_H_
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "cgemm_helper.h"
#include "common/util/system.h"
#include "nccl.h"
namespace transformer_engine {
namespace jax {
ncclUniqueId CommunicatorHandler::coordinate_nccl_unique_id(const std::string &id_type) {
ncclUniqueId unique_id;
int tp_domain_id = get_tp_domain_id();
bool is_tp_leader = (get_local_device_id_within_tp_domain() == 0);
pid_t pgid = getpgid(0);
std::string base_path = getenv<std::string>("NVTE_JAX_NCCL_FILE_PATH", "/tmp");
std::string id_file = base_path + "/nccl_" + id_type + "_unique_id_pgid_" + std::to_string(pgid) +
"_" + std::to_string(num_total_devices) + "_" + std::to_string(tp_size) +
"_domain_" + std::to_string(tp_domain_id) + ".bin";
if (is_tp_leader) {
NVTE_CHECK_NCCL(ncclGetUniqueId(&unique_id));
// Write the ID to a temporary file
std::ofstream file(id_file, std::ios::binary);
NVTE_CHECK(file.is_open(), "Failed to create NCCL unique ID file: ", id_file);
file.write(reinterpret_cast<const char *>(&unique_id), sizeof(ncclUniqueId));
file.close();
} else {
// Wait for the ID file to be created and read it
int attempts = 0;
const int max_attempts = 100;
while (attempts < max_attempts) {
std::ifstream file(id_file, std::ios::binary);
if (file.is_open()) {
file.read(reinterpret_cast<char *>(&unique_id), sizeof(ncclUniqueId));
if (file.gcount() == sizeof(ncclUniqueId)) {
file.close();
break;
}
file.close();
}
std::this_thread::sleep_for(std::chrono::milliseconds(100));
attempts++;
}
NVTE_CHECK(attempts < max_attempts,
"Timeout waiting for " + id_type + " NCCL unique ID file from leader: ", id_file);
}
if (is_tp_leader) {
_nccl_id_file_name.push_back(id_file);
}
return unique_id;
}
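// Worked example of the rendezvous path built above (editorial illustration; all values are
// hypothetical): with NVTE_JAX_NCCL_FILE_PATH unset, pgid 12345, num_total_devices 8,
// tp_size 4 and tp_domain_id 1, the TP leader writes the unique ID to
//   id_file == "/tmp/nccl_tp_unique_id_pgid_12345_8_4_domain_1.bin"
// while the other ranks of that TP domain poll for the file (up to 100 attempts, 100 ms
// apart) before raising a timeout error.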
void CommunicatorHandler::init(int num_total_devices, int num_devices_per_process, int process_id,
int tp_size) {
// Validate inputs
NVTE_CHECK(num_devices_per_process == 1,
"num_devices_per_process must be == 1, got num_devices_per_process=",
num_devices_per_process);
NVTE_CHECK(num_total_devices >= 1,
"num_total_devices must be >= 1, got num_total_devices=", num_total_devices);
NVTE_CHECK(
num_total_devices % num_devices_per_process == 0,
"num_total_devices must be divisible by num_devices_per_process, got num_total_devices=",
num_total_devices, ", num_devices_per_process=", num_devices_per_process);
// Validate TP size
NVTE_CHECK(tp_size > 0, "tp_size must be > 0, got tp_size=", tp_size);
NVTE_CHECK(num_total_devices % tp_size == 0,
"num_total_devices must be divisible by tp_size, got num_total_devices=",
num_total_devices, ", tp_size=", tp_size);
auto &handler = get(false);
handler.num_total_devices = num_total_devices;
handler.num_devices_per_process = num_devices_per_process;
handler.process_id = process_id;
handler.num_processes = num_total_devices / num_devices_per_process;
handler.tp_size = tp_size;
handler.tp_num_domains = num_total_devices / tp_size;
// Initialize vectors with the correct size
handler.local_device_ids_within_process.resize(num_devices_per_process);
handler.local_device_ids_within_tp_domain.resize(num_devices_per_process);
handler.tp_domain_ids.resize(num_devices_per_process);
handler.global_device_ids.resize(num_devices_per_process);
handler.tp_comms.resize(num_devices_per_process);
NVTE_CHECK(0 <= process_id && process_id < handler.num_processes,
"Invalid process_id=", process_id, ", which is out of range [0, ",
handler.num_processes, ")");
// Initialize local devices and calculate their global device IDs and TP topology
for (int local_idx = 0; local_idx < num_devices_per_process; local_idx++) {
// Use the device that JAX has already assigned to this process
int current_device;
NVTE_CHECK_CUDA(cudaGetDevice(&current_device));
handler.local_device_ids_within_process[local_idx] = current_device;
handler.global_device_ids[local_idx] = process_id * num_devices_per_process + local_idx;
// Calculate TP-related values for this device
int global_device_id = handler.global_device_ids[local_idx];
if (num_devices_per_process == tp_size) {
// Scenario 1: Multi-device per process - TP domain = single process
handler.local_device_ids_within_tp_domain[local_idx] = local_idx;
handler.tp_domain_ids[local_idx] = process_id;
} else {
// Scenario 2: Single device per process - TP domain spans multiple processes
handler.local_device_ids_within_tp_domain[local_idx] = global_device_id % tp_size;
handler.tp_domain_ids[local_idx] = global_device_id / tp_size;
}
}
ncclUniqueId tp_id = handler.coordinate_nccl_unique_id("tp");
NVTE_CHECK_NCCL(ncclGroupStart());
for (int local_idx = 0; local_idx < num_devices_per_process; local_idx++) {
NVTE_CHECK_CUDA(cudaSetDevice(handler.local_device_ids_within_process[local_idx]));
int tp_local_rank = handler.local_device_ids_within_tp_domain[local_idx];
NVTE_CHECK_NCCL(
ncclCommInitRank(&handler.tp_comms[local_idx], handler.tp_size, tp_id, tp_local_rank));
}
NVTE_CHECK_NCCL(ncclGroupEnd());
// Allocate device memory for barrier operations
NVTE_CHECK_CUDA(cudaMalloc(&handler._device_barrier, sizeof(int)));
handler._initialize = true;
// Bootstrap UB via creating a dummy CommOverlapP2PBase object
std::vector<size_t> buffer_shape{1, 1};
auto _ = CollectiveGemmPlanRegistry::getInstance().get_executor(buffer_shape, DType::kFloat32,
JAXX_Collective_Op::ALL_GATHER);
}
void InitializeCgemmCommunicator(int num_total_devices, int num_devices_per_process, int process_id,
int tp_size, int num_max_streams, int gemm_priority,
int comm_priority, int num_comm_sm, bool use_ce,
bool aggregate_ag) {
auto &config = CgemmConfig::get(false);
config.init(num_max_streams, gemm_priority, comm_priority, num_comm_sm, use_ce, aggregate_ag);
auto &handler = CommunicatorHandler::get(false);
handler.init(num_total_devices, num_devices_per_process, process_id, tp_size);
}
int GetCgemmNumMaxStreams() {
auto &config = CgemmConfig::get();
return config.num_max_streams;
}
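// Editorial example (not part of this diff) of the expected bootstrap order: each process
// calls InitializeCgemmCommunicator exactly once, before issuing any collective GEMM.
// All numeric values below are placeholders for an 8-GPU job with one device per process
// and a tensor-parallel group size of 4.
//
//   InitializeCgemmCommunicator(/*num_total_devices=*/8, /*num_devices_per_process=*/1,
//                               /*process_id=*/rank, /*tp_size=*/4,
//                               /*num_max_streams=*/3, /*gemm_priority=*/0,
//                               /*comm_priority=*/0, /*num_comm_sm=*/16,
//                               /*use_ce=*/true, /*aggregate_ag=*/false);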
CommOverlapCore *CollectiveGemmPlanRegistry::get_executor(std::vector<size_t> buffer_shape,
DType dtype,
JAXX_Collective_Op collective_op) {
auto &comm_handler = CommunicatorHandler::get();
auto &cgemm_config = CgemmConfig::get();
int device_idx = comm_handler.get_local_device_idx_for_current_device();
int64_t plan_id = 0;
hash_combine(plan_id, buffer_shape[0], buffer_shape[1], static_cast<size_t>(dtype),
static_cast<int>(collective_op), comm_handler.tp_size, cgemm_config.num_max_streams,
cgemm_config.gemm_priority, cgemm_config.comm_priority, cgemm_config.num_comm_sm,
cgemm_config.use_ce, cgemm_config.aggregate_ag, device_idx);
auto it = plan_map.find(plan_id);
if (it != plan_map.end()) {
return it->second.get();
}
if (comm_handler.num_devices_per_process == comm_handler.tp_size) {
// Multi-device per process
} else if (comm_handler.num_devices_per_process == 1) {
// Single device per process
NVTE_CHECK(comm_handler.num_total_devices % comm_handler.tp_size == 0,
"For single device per process, num_total_devices must be divisible by tp_size, "
"got num_total_devices=",
comm_handler.num_total_devices, ", tp_size=", comm_handler.tp_size);
} else {
NVTE_ERROR("Unsupported TP configuration: num_devices_per_process=",
comm_handler.num_devices_per_process, ", tp_size=", comm_handler.tp_size,
". Supported scenarios: "
"(1) num_devices_per_process == tp_size (multi-device per process), "
"(2) num_devices_per_process == 1 (single device per process)");
}
std::unique_ptr<CommOverlapCore> executor;
executor = std::make_unique<CommOverlapP2PBase>(
buffer_shape, dtype, comm_handler.get_global_rank(), comm_handler.num_total_devices,
comm_handler.get_local_device_id_within_tp_domain(), comm_handler.tp_size,
comm_handler.get_tp_domain_id(), comm_handler.get_tp_num_domains(), comm_handler.tp_size,
comm_handler.allgather_func, comm_handler.barrier_func, get_nvte_collective_op(collective_op),
cgemm_config.num_max_streams, 1 /*comm_cga_size*/, cgemm_config.gemm_priority,
cgemm_config.comm_priority, cgemm_config.num_comm_sm, true /*set_sm_margin*/,
cgemm_config.use_ce, false /*atomic_gemm*/, cgemm_config.aggregate_ag);
CommOverlapCore *executor_ptr = executor.get();
plan_map[plan_id] = std::move(executor);
return executor_ptr;
}
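// Editorial note (not part of this diff): because the plan id hashes the buffer shape, dtype,
// collective op, TP size, the CgemmConfig knobs and the device index, repeated calls with
// identical parameters return the cached CommOverlapP2PBase executor instead of allocating a
// new userbuffers workspace. Hypothetical values:
//
//   auto *a = CollectiveGemmPlanRegistry::getInstance().get_executor(
//       {4096, 512}, DType::kBFloat16, JAXX_Collective_Op::ALL_GATHER);
//   auto *b = CollectiveGemmPlanRegistry::getInstance().get_executor(
//       {4096, 512}, DType::kBFloat16, JAXX_Collective_Op::ALL_GATHER);
//   // a == b: the second call is a cache hit.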
void CommunicatorHandler::nccl_device_barrier_impl(ExtComm) {
NVTE_CHECK(_initialize, "CommunicatorHandler must be initialized before using barrier");
int device_idx = get_local_device_idx_for_current_device();
ncclComm_t tp_comm = tp_comms[device_idx];
NVTE_CHECK_NCCL(
ncclAllReduce(_device_barrier, _device_barrier, 1, ncclInt, ncclSum, tp_comm, nullptr));
cudaDeviceSynchronize();
}
void CommunicatorHandler::nccl_allgather_impl(void *output_buf, size_t output_bytes,
void *input_buf, size_t input_bytes, ExtComm) {
NVTE_CHECK(_initialize, "CommunicatorHandler must be initialized before using allgather");
int device_idx = get_local_device_idx_for_current_device();
ncclComm_t tp_comm = tp_comms[device_idx];
size_t expected_output_bytes = input_bytes * tp_size;
NVTE_CHECK(output_bytes == expected_output_bytes, "TP allgather buffer size mismatch: expected ",
expected_output_bytes, ", got ", output_bytes);
NVTE_CHECK_NCCL(ncclAllGather(input_buf, output_buf, input_bytes, ncclChar, tp_comm, nullptr));
cudaDeviceSynchronize();
}
CommunicatorHandler::CommunicatorHandler() : _device_barrier(nullptr) {
allgather_func = [this](void *output_buf, size_t output_bytes, void *input_buf,
size_t input_bytes, ExtComm comm) {
this->nccl_allgather_impl(output_buf, output_bytes, input_buf, input_bytes, comm);
};
barrier_func = [this](ExtComm comm) { this->nccl_device_barrier_impl(comm); };
}
CommunicatorHandler::~CommunicatorHandler() {
if (_initialize && !tp_comms.empty()) {
for (auto &comm : tp_comms) {
if (comm != nullptr) {
ncclCommDestroy(comm);
}
}
}
if (_device_barrier) cudaFree(_device_barrier);
for (const auto &file_path : _nccl_id_file_name) {
std::remove(file_path.c_str());
}
}
} // namespace jax
} // namespace transformer_engine
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#ifndef TRANSFORMER_ENGINE_JAX_CGEMM_HELPER_H_
#define TRANSFORMER_ENGINE_JAX_CGEMM_HELPER_H_
#include <unistd.h>
#include <chrono>
#include <cstdio>
#include <fstream>
#include <functional>
#include <memory>
#include <thread>
#include <unordered_map>
#include "../extensions.h"
#include "common/comm_gemm_overlap/userbuffers/userbuffers.h"
#include "common/util/cuda_runtime.h"
#include "common/util/logging.h"
#include "transformer_engine/comm_gemm_overlap.h"
namespace transformer_engine {
namespace jax {
// Configuration singleton for CGEMM parameters
class CgemmConfig {
public:
int num_max_streams;
int gemm_priority;
int comm_priority;
int num_comm_sm;
bool use_ce;
bool aggregate_ag;
static void init(int _num_max_streams, int _gemm_priority, int _comm_priority, int _num_comm_sm,
bool _use_ce, bool _aggregate_ag) {
auto &config = get(false);
config._initialized = true;
config.num_max_streams = _num_max_streams;
config.gemm_priority = _gemm_priority;
config.comm_priority = _comm_priority;
config.num_comm_sm = _num_comm_sm;
config.use_ce = _use_ce;
config.aggregate_ag = _aggregate_ag;
}
static CgemmConfig &get(bool is_initialized = true) {
static thread_local CgemmConfig instance;
NVTE_CHECK(
instance._initialized == is_initialized,
"CgemmConfig must be initialized before using it, got is_initialized=", is_initialized);
return instance;
}
CgemmConfig(const CgemmConfig &) = delete;
CgemmConfig &operator=(const CgemmConfig &) = delete;
private:
CgemmConfig() = default;
~CgemmConfig() = default;
bool _initialized = false;
};
// Forward declaration
class CollectiveGemmPlanRegistry;
// NCCL communicator handler for collective GEMM operations
// Supports both one device per process and multiple devices per process
// Two scenarios:
// 1. Single process multiple devices: TP domain = process (num_devices_per_process == tp_size)
// 2. Single process single device: TP domain spans processes (num_devices_per_process == 1)
class CommunicatorHandler {
public:
int num_total_devices = -1;
int num_devices_per_process = -1;
int process_id = -1;
int num_processes = -1;
int tp_size = -1;
int tp_num_domains = -1;
std::vector<int> local_device_ids_within_tp_domain;
std::vector<int> tp_domain_ids;
std::vector<ncclComm_t> tp_comms;
std::vector<int> local_device_ids_within_process;
std::vector<int> global_device_ids;
int get_global_rank() const {
int device_idx = get_local_device_idx_for_current_device();
return global_device_ids[device_idx];
}
void nccl_device_barrier_impl(ExtComm);
void nccl_allgather_impl(void *output_buf, size_t output_bytes, void *input_buf,
size_t input_bytes, ExtComm);
ncclComm_t get_comm_for_current_device() const {
int device_idx = get_local_device_idx_for_current_device();
return tp_comms[device_idx];
}
int get_local_device_idx_for_current_device() const {
int current_device;
NVTE_CHECK_CUDA(cudaGetDevice(&current_device));
for (int i = 0; i < num_devices_per_process; i++) {
if (local_device_ids_within_process[i] == current_device) {
return i;
}
}
NVTE_ERROR("Current CUDA device ", current_device,
" not found in local_device_ids_within_process");
}
int get_local_device_id_within_tp_domain() const {
int device_idx = get_local_device_idx_for_current_device();
return local_device_ids_within_tp_domain[device_idx];
}
int get_tp_domain_id() const {
int device_idx = get_local_device_idx_for_current_device();
return tp_domain_ids[device_idx];
}
int get_tp_num_domains() const { return tp_num_domains; }
static void init(int num_total_devices, int num_devices_per_process, int process_id, int tp_size);
private:
ncclUniqueId coordinate_nccl_unique_id(const std::string &id_type);
public:
static CommunicatorHandler &get(bool is_initialized = true) {
static CommunicatorHandler instance;
NVTE_CHECK(instance._initialize == is_initialized,
"CommunicatorHandler._initialize=", instance._initialize,
", is_initialized=", is_initialized);
return instance;
}
ExtAllgatherOp allgather_func;
ExtBarrierOp barrier_func;
CommunicatorHandler(const CommunicatorHandler &) = delete;
CommunicatorHandler &operator=(const CommunicatorHandler &) = delete;
private:
CommunicatorHandler();
~CommunicatorHandler();
bool _initialize = false;
int *_device_barrier = nullptr;
std::vector<std::string> _nccl_id_file_name;
};
// Plan registry for caching collective GEMM executors
class CollectiveGemmPlanRegistry {
public:
static CollectiveGemmPlanRegistry &getInstance() {
static thread_local CollectiveGemmPlanRegistry instance;
return instance;
}
CommOverlapCore *get_executor(std::vector<size_t> buffer_shape, DType dtype,
JAXX_Collective_Op collective_op);
private:
CollectiveGemmPlanRegistry() {}
CollectiveGemmPlanRegistry(const CollectiveGemmPlanRegistry &) = delete;
CollectiveGemmPlanRegistry &operator=(const CollectiveGemmPlanRegistry &) = delete;
std::unordered_map<int64_t, std::unique_ptr<CommOverlapCore>> plan_map;
};
// Function declarations
void InitializeCgemmCommunicator(int num_total_devices, int num_devices_per_process, int process_id,
int tp_size, int num_max_streams, int gemm_priority,
int comm_priority, int num_comm_sm, bool use_ce,
bool aggregate_ag);
int GetCgemmNumMaxStreams();
} // namespace jax
} // namespace transformer_engine
#endif // TRANSFORMER_ENGINE_JAX_CGEMM_HELPER_H_
@@ -6,13 +6,19 @@
#include "transformer_engine/gemm.h"
#include <memory>
#include <mutex>
#include <stdexcept>
#include <string_view>
#include <tuple>
#include "../extensions.h"
#include "cgemm_helper.h"
#include "common.h"
#include "common/util/cuda_runtime.h" #include "common/util/cuda_runtime.h"
#include "common/util/string.h" #include "common/util/string.h"
#include "common/util/system.h" #include "common/util/system.h"
#include "cuda_runtime.h"
#include "nccl.h"
#include "transformer_engine/swizzle.h" #include "transformer_engine/swizzle.h"
#include "xla/ffi/api/c_api.h" #include "xla/ffi/api/c_api.h"
...@@ -66,12 +72,75 @@ std::tuple<TensorWrapper, std::vector<size_t>> xla_buffer_to_nvte_gemm_operand( ...@@ -66,12 +72,75 @@ std::tuple<TensorWrapper, std::vector<size_t>> xla_buffer_to_nvte_gemm_operand(
return std::make_tuple(std::move(input), input_shape); return std::make_tuple(std::move(input), input_shape);
} }
Error_Type CollectiveGemmInitFFI(Buffer_Type lhs, Buffer_Type lhs_scale_inv, Buffer_Type rhs,
Buffer_Type rhs_scale_inv, Buffer_Type bias,
Buffer_Type gelu_input, Result_Type output, Result_Type bias_grad,
Result_Type pre_gelu_out, Result_Type workspace,
JAXX_Scaling_Mode scaling_mode, int64_t lhs_axis_boundary,
int64_t rhs_axis_boundary, bool lhs_transposed,
bool rhs_transposed, bool fuse_bias, bool fuse_gelu, bool grad,
bool use_split_accumulator, JAXX_Collective_Op collective_op) {
nvte_cublas_handle_init();
// Init UB buffer
if (collective_op != JAXX_Collective_Op::NONE) {
auto &comm_handler = CommunicatorHandler::get();
std::vector<size_t> lhs_shape = {
product(lhs.dimensions(), 0, lhs_axis_boundary),
product(lhs.dimensions(), lhs_axis_boundary, lhs.dimensions().size())};
std::vector<size_t> rhs_shape = {
product(rhs.dimensions(), 0, rhs_axis_boundary),
product(rhs.dimensions(), rhs_axis_boundary, rhs.dimensions().size())};
std::vector<size_t> out_shape = {(lhs_transposed) ? lhs_shape[1] : lhs_shape[0],
(rhs_transposed) ? rhs_shape[0] : rhs_shape[1]};
std::vector<size_t> buffer_shape{0, 0};
DType buffer_dtype = convert_ffi_datatype_to_te_dtype(output->element_type());
if (collective_op == JAXX_Collective_Op::ALL_GATHER) {
buffer_shape[0] = lhs_shape[0] * comm_handler.tp_size;
buffer_shape[1] = lhs_shape[1];
buffer_dtype = convert_ffi_datatype_to_te_dtype(lhs.element_type());
} else if (collective_op == JAXX_Collective_Op::REDUCE_SCATTER) {
buffer_shape[0] = out_shape[0];
buffer_shape[1] = out_shape[1];
}
auto _ = CollectiveGemmPlanRegistry::getInstance().get_executor(buffer_shape, buffer_dtype,
collective_op);
}
return ffi_with_cuda_error_check();
}
XLA_FFI_DEFINE_HANDLER_SYMBOL(CollectiveGemmInitHandler, CollectiveGemmInitFFI,
FFI::Bind<FFI_Prepare>()
.Arg<Buffer_Type>() // lhs
.Arg<Buffer_Type>() // lhs_scale_inv
.Arg<Buffer_Type>() // rhs
.Arg<Buffer_Type>() // rhs_scale_inv
.Arg<Buffer_Type>() // bias
.Arg<Buffer_Type>() // gelu_input
.Ret<Buffer_Type>() // output
.Ret<Buffer_Type>() // bias_grad
.Ret<Buffer_Type>() // pre_gelu_out
.Ret<Buffer_Type>() // workspace
.Attr<JAXX_Scaling_Mode>("scaling_mode")
.Attr<int64_t>("lhs_axis_boundary")
.Attr<int64_t>("rhs_axis_boundary")
.Attr<bool>("lhs_transposed")
.Attr<bool>("rhs_transposed")
.Attr<bool>("fuse_bias")
.Attr<bool>("fuse_gelu")
.Attr<bool>("grad")
.Attr<bool>("use_split_accumulator")
.Attr<JAXX_Collective_Op>("collective_op"));
Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_inv, Buffer_Type rhs,
Buffer_Type rhs_scale_inv, Buffer_Type bias, Buffer_Type gelu_input,
Result_Type output, Result_Type bias_grad, Result_Type pre_gelu_out,
Result_Type workspace, JAXX_Scaling_Mode scaling_mode, int64_t lhs_axis_boundary,
int64_t rhs_axis_boundary, bool lhs_transposed, bool rhs_transposed,
bool fuse_bias, bool fuse_gelu, bool grad, bool use_split_accumulator,
JAXX_Collective_Op collective_op) {
// NOTE: TensorWrapper operands are always rowwise for full-precision GEMM, or FP8 GEMM when
// device supports non-TN layouts (compute capability >= 10.0, excluding 12.x)
bool always_rowwise = (scaling_mode == JAXX_Scaling_Mode::NO_SCALING ||
@@ -83,16 +152,9 @@ Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_i
auto [rhs_, rhs_shape] = xla_buffer_to_nvte_gemm_operand(stream, rhs, rhs_scale_inv, scaling_mode,
rhs_axis_boundary, make_rhs_rowwise);
// Output tensor
std::vector<size_t> out_shape = {(lhs_transposed) ? lhs_shape[1] : lhs_shape[0],
(rhs_transposed) ? rhs_shape[0] : rhs_shape[1]};
auto out_dtype = convert_ffi_datatype_to_te_dtype(output->element_type());
auto out_ = TensorWrapper(output->untyped_data(), out_shape, out_dtype);
NVTE_CHECK(out_.numel() == output->element_count(),
"cuBLAS GEMM output buffer size is incorrect, "
"expected ",
out_.numel(), " elements ", to_string_like(out_shape), " but got ",
output->element_count(), " elements ", to_string_like(output->dimensions()));
// Bias input to forward pass or bias gradient output from backward pass
void *bias_ptr = nullptr;
@@ -133,9 +195,62 @@ Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_i
// Launch TE/common kernel with swapped LHS/RHS for cuBLAS column-major order
auto num_math_sm = cuda::sm_count() - getenv<int>("NVTE_EXT_MARGIN_SM", 0);
if (collective_op == JAXX_Collective_Op::NONE) {
auto out_ = TensorWrapper(output->untyped_data(), out_shape, out_dtype);
NVTE_CHECK(out_.numel() == output->element_count(),
"cuBLAS GEMM output buffer size is incorrect, expected ", out_.numel(), " elements ",
to_string_like(out_shape), " but got ", output->element_count(), " elements ",
to_string_like(output->dimensions()));
nvte_cublas_gemm(rhs_.data(), lhs_.data(), out_.data(), bias_.data(), pre_gelu_.data(),
rhs_transposed, lhs_transposed, grad, workspace_.data(), false,
use_split_accumulator, num_math_sm, stream);
} else {
std::vector<size_t> buffer_shape{0, 0};
DType buffer_dtype = out_dtype;
auto &comm_handler = CommunicatorHandler::get();
if (collective_op == JAXX_Collective_Op::ALL_GATHER) {
buffer_shape[0] = lhs_shape[0] * comm_handler.tp_size;
buffer_shape[1] = lhs_shape[1];
out_shape[0] = out_shape[0] * comm_handler.tp_size;
buffer_dtype = convert_ffi_datatype_to_te_dtype(lhs.element_type());
} else if (collective_op == JAXX_Collective_Op::REDUCE_SCATTER) {
buffer_shape[0] = out_shape[0];
buffer_shape[1] = out_shape[1];
out_shape[0] = out_shape[0] / comm_handler.tp_size;
}
auto executor = CollectiveGemmPlanRegistry::getInstance().get_executor(
buffer_shape, buffer_dtype, collective_op);
if (collective_op == JAXX_Collective_Op::REDUCE_SCATTER) {
auto ubuf_out_ = TensorWrapper(executor->get_ubuf_dptr(), buffer_shape, out_dtype);
// Prepare the auxiliary buffer for the reduce-scattered GEMM output
auto out_ = TensorWrapper(output->untyped_data(), out_shape, out_dtype);
NVTE_CHECK(out_.numel() == output->element_count(),
"cuBLAS GEMM output buffer size is incorrect, expected ", out_.numel(),
" elements ", to_string_like(out_shape), " but got ", output->element_count(),
" elements ", to_string_like(output->dimensions()));
// Launch GEMM+RS
executor->split_overlap_rs(rhs_, rhs_transposed, lhs_, lhs_transposed, ubuf_out_, bias_,
pre_gelu_, workspace_, grad, false, use_split_accumulator, out_,
stream);
} else if (collective_op == JAXX_Collective_Op::ALL_GATHER) {
auto aux_out_ = TensorWrapper(nullptr, std::vector<size_t>{0}, out_dtype); // Empty
auto out_ = TensorWrapper(output->untyped_data(), out_shape, out_dtype);
NVTE_CHECK(out_.numel() == output->element_count(),
"cuBLAS GEMM output buffer size is incorrect, expected ", out_.numel(),
" elements ", to_string_like(out_shape), " but got ", output->element_count(),
" elements ", to_string_like(output->dimensions()));
// Copy the distributed LHS operand into the local chunk of the communication buffer
executor->copy_into_buffer(stream, lhs_, true, make_lhs_rowwise);
// Launch AG+GEMM
executor->split_overlap_ag(rhs_, rhs_transposed, lhs_, lhs_transposed, out_, bias_, pre_gelu_,
workspace_, grad, false, use_split_accumulator, aux_out_, stream);
}
}
return ffi_with_cuda_error_check();
}
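// Worked example of the collective branches above (editorial illustration; sizes are
// hypothetical). With tp_size = 4, a per-rank LHS chunk of {1024, 512}, an RHS of {512, 2048}
// and no transposes, the local GEMM output shape starts as {1024, 2048}:
//   ALL_GATHER:      buffer_shape = {1024 * 4, 512}   (gathered LHS, LHS dtype)
//                    out_shape    = {1024 * 4, 2048}  (every rank produces all output rows)
//   REDUCE_SCATTER:  buffer_shape = {1024, 2048}      (full partial output, output dtype)
//                    out_shape    = {1024 / 4, 2048}  (each rank keeps its scattered slice)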
@@ -161,7 +276,8 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GemmHandler, GemmFFI,
.Attr<bool>("fuse_bias") .Attr<bool>("fuse_bias")
.Attr<bool>("fuse_gelu") .Attr<bool>("fuse_gelu")
.Attr<bool>("grad") .Attr<bool>("grad")
.Attr<bool>("use_split_accumulator"), .Attr<bool>("use_split_accumulator")
.Attr<JAXX_Collective_Op>("collective_op"),
FFI_CudaGraph_Traits);
Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type lhs_sinv,
...
@@ -87,5 +87,31 @@ constexpr struct Alignment {
std::vector<size_t> get_mxfp8_scale_shape(size_t M, size_t N, bool is_colwise);
template <typename T, typename... Rest>
void hash_combine(int64_t &seed, const T &v, Rest... rest) {
seed ^= std::hash<T>{}(v) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
(hash_combine(seed, rest), ...);
}
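// Editorial example (not part of this diff) of hash_combine, in the spirit of
// CollectiveGemmPlanRegistry::get_executor: folding several plan parameters into a single
// int64_t cache key. The values are hypothetical.
//
//   int64_t plan_id = 0;
//   hash_combine(plan_id, size_t{4096}, size_t{512}, static_cast<size_t>(DType::kBFloat16),
//                static_cast<int>(JAXX_Collective_Op::ALL_GATHER), /*tp_size=*/4);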
enum class JAXX_Collective_Op : int64_t {
NONE = 0,
ALL_GATHER = 1,
REDUCE_SCATTER = 2,
};
static CommOverlapType get_nvte_collective_op(const JAXX_Collective_Op &op) {
switch (op) {
case JAXX_Collective_Op::ALL_GATHER:
return CommOverlapType::AG;
break;
case JAXX_Collective_Op::REDUCE_SCATTER:
return CommOverlapType::RS;
break;
default:
NVTE_ERROR("Invalid Collective Op ", static_cast<int>(op));
break;
}
}
} // namespace jax
} // namespace transformer_engine
@@ -5,6 +5,8 @@
************************************************************************/
#include "../extensions.h"
#include "cgemm_helper.h"
#include "common/util/cuda_runtime.h"
namespace transformer_engine {
namespace jax {
@@ -57,7 +59,7 @@ pybind11::dict Registrations() {
// GEMM
dict["te_gemm_ffi"] =
pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CublasHandleInitHandler),
pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CollectiveGemmInitHandler),
pybind11::arg("execute") = EncapsulateFFI(GemmHandler));
// Grouped GEMM
@@ -84,6 +86,8 @@ PYBIND11_MODULE(transformer_engine_jax, m) {
m.def("get_fused_attn_bwd_workspace_sizes", &GetFusedAttnBackwardWorkspaceSizes);
m.def("nvte_get_qkv_format", &nvte_get_qkv_format);
m.def("is_non_nt_fp8_gemm_supported", &nvte_is_non_tn_fp8_gemm_supported);
m.def("initialize_cgemm_communicator", &InitializeCgemmCommunicator);
m.def("get_cgemm_num_max_streams", &GetCgemmNumMaxStreams);
pybind11::enum_<DType>(m, "DType", pybind11::module_local())
.value("kByte", DType::kByte)
@@ -159,6 +163,12 @@ PYBIND11_MODULE(transformer_engine_jax, m) {
.value("COLWISE", transformer_engine::jax::QuantizeLayout::COLWISE)
.value("ROWWISE_COLWISE", transformer_engine::jax::QuantizeLayout::ROWWISE_COLWISE)
.export_values();
pybind11::enum_<JAXX_Collective_Op>(m, "JAXX_Collective_Op", pybind11::module_local())
.value("NONE", JAXX_Collective_Op::NONE)
.value("ALL_GATHER", JAXX_Collective_Op::ALL_GATHER)
.value("REDUCE_SCATTER", JAXX_Collective_Op::REDUCE_SCATTER)
.export_values();
}
} // namespace jax
...