Unverified commit 8e3797be authored by zyksir, committed by GitHub

support 1 shot allreduce in 1-node and 2-node using mscclpp (#6277)

parent 4474eaf5
"""For Now, MSCCL is only supported on TP16 and TP8 case
export WORLD_SIZE=1
export RANK=0
export MASTER_ADDR=127.0.0.1
export MASTER_PORT=12345
torchrun --nproc_per_node gpu \
--nnodes $WORLD_SIZE \
--node_rank $RANK \
--master_addr $MASTER_ADDR \
--master_port $MASTER_PORT benchmark/kernels/all_reduce/benchmark_mscclpp.py
"""
import os
from contextlib import nullcontext
from typing import List
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
from sglang.srt.distributed import init_distributed_environment
from sglang.srt.distributed.device_communicators.pymscclpp import PyMscclppCommunicator
from sglang.srt.distributed.device_communicators.pynccl import PyNcclCommunicator
from sglang.srt.distributed.parallel_state import (
get_tensor_model_parallel_group,
graph_capture,
initialize_model_parallel,
set_mscclpp_all_reduce,
)
def torch_allreduce(torch_input: torch.Tensor, group: ProcessGroup) -> torch.Tensor:
dist.all_reduce(torch_input, group=group)
return torch_input
def msccl_allreduce(
msccl_input: torch.Tensor, msccl_comm: PyMscclppCommunicator
) -> torch.Tensor:
return msccl_comm.all_reduce(msccl_input)
def pynccl_allreduce(
msccl_input: torch.Tensor, pynccl_comm: PyNcclCommunicator
) -> torch.Tensor:
pynccl_comm.all_reduce(msccl_input)
return msccl_input
def _bench_graph_time(func, inp_randn, warmup_loop=2, graph_loop=10, test_loop=10):
graph_input = inp_randn.clone()
with graph_capture() as graph_capture_context:
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph, stream=graph_capture_context.stream):
for _ in range(graph_loop):
graph_out = func(graph_input)
graph.replay()
func_output = graph_out.clone()
for _ in range(warmup_loop):
graph.replay()
torch.cuda.synchronize()
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
latencies: List[float] = []
for _ in range(test_loop):
torch.cuda.synchronize()
dist.barrier()
start_event.record()
graph.replay()
end_event.record()
end_event.synchronize()
latencies.append(start_event.elapsed_time(end_event))
func_cost_us = sum(latencies) / len(latencies) / graph_loop * 1000
graph.reset()
return func_output, func_cost_us
def _bench_eager_time(func, inp_randn, warmup_loop=2, test_loop=10):
eager_input = inp_randn.clone()
eager_output = func(eager_input)
func_output = eager_output.clone()
for _ in range(warmup_loop):
func(eager_input)
torch.cuda.synchronize()
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
torch.cuda.synchronize()
start_event.record()
for _ in range(test_loop):
func(eager_input)
end_event.record()
torch.cuda.synchronize()
func_cost_us = start_event.elapsed_time(end_event) / test_loop * 1000
return func_output, func_cost_us
def get_torch_prof_ctx(do_prof: bool):
ctx = (
torch.profiler.profile(
activities=[
torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.CUDA,
],
record_shapes=True,
with_stack=True,
)
if do_prof
else nullcontext()
)
return ctx
def human_readable_size(size, decimal_places=1):
for unit in ["B", "KiB", "MiB", "GiB", "TiB", "PiB"]:
if size < 1024.0 or unit == "PiB":
break
size /= 1024.0
return f"{size:.{decimal_places}f} {unit}"
try:
from tabulate import tabulate
except ImportError:
print("tabulate not installed, skipping table printing")
tabulate = None
def print_markdown_table(data):
if tabulate is not None:
print(tabulate(data, headers="keys", tablefmt="github"))
return
headers = data[0].keys()
header_row = "| " + " | ".join(headers) + " |"
separator = "| " + " | ".join(["---"] * len(headers)) + " |"
rows = []
for item in data:
row = "| " + " | ".join(str(item[key]) for key in headers) + " |"
rows.append(row)
markdown_table = "\n".join([header_row, separator] + rows)
print(markdown_table)
if __name__ == "__main__":
import logging
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(levelname)s - %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
force=True,
)
if not dist.is_initialized():
dist.init_process_group(backend="nccl")
world, world_size = dist.group.WORLD, dist.get_world_size()
rank = dist.get_rank()
torch.cuda.set_device(rank % 8)
device = torch.cuda.current_device()
set_mscclpp_all_reduce(True)
init_distributed_environment(
world_size=world_size,
rank=rank,
local_rank=rank % 8,
)
initialize_model_parallel(tensor_model_parallel_size=world_size)
group = get_tensor_model_parallel_group().device_group
cpu_group = get_tensor_model_parallel_group().cpu_group
pynccl_comm = get_tensor_model_parallel_group().pynccl_comm
pymscclpp_comm = get_tensor_model_parallel_group().pymscclpp_comm
dist.barrier()
profile = False
dtype = torch.bfloat16
ctx = get_torch_prof_ctx(profile)
result = []
with ctx:
for i in range(10, 20):
sz = 2**i
if sz * dtype.itemsize > 2**20:
break
inp_randn = torch.randint(1, 16, (sz,), dtype=dtype, device=device)
memory = torch.empty_like(inp_randn)
memory_out = torch.empty_like(memory)
torch_eager_output, torch_eager_time = _bench_eager_time(
lambda inp: torch_allreduce(inp, group), inp_randn
)
msccl_eager_output, msccl_eager_time = _bench_eager_time(
lambda inp: msccl_allreduce(inp, pymscclpp_comm), inp_randn
)
msccl_graph_output, msccl_graph_time = _bench_graph_time(
lambda inp: msccl_allreduce(inp, pymscclpp_comm), inp_randn
)
# Since the pynccl all-reduce is an in-place op, the returned result is not correct when graph_loop > 1.
_, pynccl_graph_time = _bench_graph_time(
lambda inp: pynccl_allreduce(inp, pynccl_comm), inp_randn
)
torch.testing.assert_close(torch_eager_output, msccl_graph_output)
torch.testing.assert_close(torch_eager_output, msccl_eager_output)
result.append(
{
"msg_size": human_readable_size(inp_randn.nbytes),
"torch eager time": torch_eager_time,
"msccl eager time": msccl_eager_time,
"msccl graph time": msccl_graph_time,
"pynccl graph time": pynccl_graph_time,
}
)
if rank == 0:
print(f"sz={sz}, dtype={dtype}: correctness check PASS!")
if rank == 0:
print_markdown_table(result)
if profile:
prof_dir = f"prof/msccl"
os.makedirs(prof_dir, exist_ok=True)
ctx.export_chrome_trace(f"{prof_dir}/trace_rank{dist.get_rank()}.json.gz")
......@@ -26,6 +26,8 @@ def launch_server(args):
cmd += f"--tp-size {args.tp_size} "
if args.disable_custom_all_reduce:
cmd += "--disable-custom-all-reduce"
if args.enable_mscclpp:
cmd += "--enable-mscclpp"
print(cmd)
os.system(cmd)
......@@ -63,6 +65,11 @@ if __name__ == "__main__":
action="store_true",
help="Disable custom all reduce when device does not support p2p communication",
)
parser.add_argument(
"--enable-mscclpp",
action="store_true",
help="Enable using mscclpp for small messages for all-reduce kernel and fall back to NCCL.",
)
args = parser.parse_args()
launch_server(args)
......@@ -201,6 +201,7 @@ Please consult the documentation below and [server_args.py](https://github.com/s
| `disable_cuda_graph_padding` | Disable CUDA Graph when padding is needed; otherwise, still use CUDA Graph. | `False` |
| `disable_outlines_disk_cache` | Disable disk cache for outlines grammar backend. | `False` |
| `disable_custom_all_reduce` | Disable usage of custom all-reduce kernel. | `False` |
| `enable_mscclpp` | Enable the mscclpp kernel for small-message all-reduce. | `False` |
| `disable_overlap_schedule` | Disable the [Overhead-Scheduler](https://lmsys.org/blog/2024-12-04-sglang-v0-4/#zero-overhead-batch-scheduler). | `False` |
| `enable_nan_detection` | Enable warning if the logits contain `NaN`. | `False` |
| `enable_p2p_check` | Turns off the default of always allowing P2P checks when accessing GPU. | `False` |
......
......@@ -113,3 +113,37 @@ else:
def get_meta_buffer_ipc_handle(inp: torch.Tensor) -> torch.Tensor:
return sgl_kernel.allreduce.get_meta_buffer_ipc_handle(inp)
def mscclpp_generate_unique_id() -> bytes:
return sgl_kernel.allreduce.mscclpp_generate_unique_id()
def mscclpp_init_context(
unique_id: bytes,
rank: int,
world_size: int,
scratch: torch.Tensor,
put_buffer: torch.Tensor,
nranks_per_node: int,
rank_to_node: List[int],
rank_to_ib: List[int],
context_selection: int,
) -> int:
return sgl_kernel.allreduce.mscclpp_init_context(
unique_id,
rank,
world_size,
scratch,
put_buffer,
nranks_per_node,
rank_to_node,
rank_to_ib,
context_selection,
)
def mscclpp_allreduce(
context: int, inp: torch.Tensor, out: torch.Tensor, nthreads: int, nblocks: int
) -> None:
return sgl_kernel.allreduce.mscclpp_allreduce(context, inp, out, nthreads, nblocks)
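# Illustrative sketch only (not used anywhere in sglang): the intended call order of the three
# wrappers above, mirroring PyMscclppCommunicator in pymscclpp.py. It assumes torch.distributed
# is already initialized; the buffer sizes, context selection, and (nthreads, nblocks) launch
# config here are assumptions for the example.
def _example_mscclpp_call_order(
    rank: int, world_size: int, nranks_per_node: int, max_bytes: int, inp: torch.Tensor
) -> torch.Tensor:
    import torch.distributed as dist

    # rank 0 creates the unique id; every rank receives it
    uid = [mscclpp_generate_unique_id()] if rank == 0 else [None]
    dist.broadcast_object_list(uid, src=0)
    # persistent scratch / put buffers, registered once when the context is created
    scratch = torch.empty(max_bytes * 8, dtype=torch.uint8, device="cuda")
    put_buffer = torch.empty(
        max_bytes * 8 // nranks_per_node, dtype=torch.uint8, device="cuda"
    )
    context = mscclpp_init_context(
        uid[0],
        rank,
        world_size,
        scratch,
        put_buffer,
        nranks_per_node,
        [r // nranks_per_node for r in range(world_size)],  # rank_to_node
        [rank % nranks_per_node] * world_size,  # rank_to_ib: this rank's IB index, as in pymscclpp.py
        1,  # MscclContextSelection.MSCCL1SHOT1NODELL (single-node case)
    )
    out = torch.empty_like(inp)
    mscclpp_allreduce(context, inp, out, nthreads=512, nblocks=21)
    return out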
import bisect
import logging
import math
import os
from contextlib import contextmanager
from enum import IntEnum
from typing import Any, Callable, List, Optional, TypeVar, Union
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup, ReduceOp
from sglang.srt import _custom_ops as ops
from sglang.srt.utils import is_cuda, is_hip
logger = logging.getLogger(__name__)
_is_cuda = is_cuda()
_is_hip = is_hip()
mscclpp_is_available = False
if _is_hip:
# TODO(zyksir): mscclpp is untested on AMD and therefore disabled.
mscclpp_is_available = False
if _is_cuda:
try:
import sgl_kernel
mscclpp_is_available = True
except ImportError:
mscclpp_is_available = False
class MscclContextSelection(IntEnum):
MSCCL1SHOT1NODELL = 1
MSCCL1SHOT2NODELL = 2
def mscclpp_is_weak_contiguous(inp: torch.Tensor):
return inp.is_contiguous() or (
inp.storage().nbytes() - inp.storage_offset() * inp.element_size()
== inp.numel() * inp.element_size()
)
def mscclpp_convert_to_bytes(size_str):
"""
Converts a human-readable size string (e.g., "1MB", "2.5kb", "3 GB")
into the equivalent number of bytes using binary units.
Args:
size_str (str): A string representing a size with a unit (B, KB, MB, GB).
Returns:
int: Number of bytes.
"""
size_str = size_str.strip().lower()
if not size_str:
raise ValueError("Empty input string")
# Extract numeric part and unit
for i in range(len(size_str)):
if not size_str[i].isdigit() and size_str[i] != ".":
break
num_str = size_str[:i]
unit = size_str[i:].strip()
try:
num = float(num_str)
except ValueError:
raise ValueError(f"Invalid numeric value in '{size_str}'")
# Conversion factors
if unit == "b":
return int(num)
elif unit == "kb":
return int(num * 1024)
elif unit == "mb":
return int(num * 1024 * 1024)
elif unit == "gb":
return int(num * 1024 * 1024 * 1024)
else:
raise ValueError(f"Unsupported unit: {unit}, support B, KB, MB, GB only")
def mscclpp_bench_time(func, test_niter: int = 10, warmup_niter: int = 2):
# warmup
for _ in range(warmup_niter):
func()
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
torch.cuda.synchronize()
dist.barrier()
start_event.record()
for _ in range(test_niter):
func()
end_event.record()
end_event.synchronize()
func_cost_us = start_event.elapsed_time(end_event) / test_niter * 1000
return func_cost_us
class PyMscclppCommunicator:
_SUPPORTED_WORLD_SIZES = [8, 16]
_MAX_BYTES = mscclpp_convert_to_bytes(os.getenv("SGLANG_MSCCLPP_MAX_BYTES", "1MB"))
_SUPPORTED_DTYPE = [torch.float, torch.float16, torch.bfloat16]
# max_bytes: maximum message size supported by the mscclpp all-reduce.
# On A100, mscclpp is faster than NCCL only for messages smaller than 1MB.
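# For example, `export SGLANG_MSCCLPP_MAX_BYTES=512KB` would override the default 1MB threshold
# (the value is parsed by mscclpp_convert_to_bytes above).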
def __init__(
self,
group: ProcessGroup,
device: Union[int, str, torch.device],
max_bytes=_MAX_BYTES,
) -> None:
"""
Args:
group: the process group to work on. If None, it will use the
default process group.
device: the device to bind the PyMscclppCommunicator to. If None,
it will be bound to f"cuda:{local_rank}".
It is the caller's responsibility to make sure each communicator
is bound to a unique device, and that the ranks in this group span
at most two nodes (world size 8 or 16).
"""
self._IS_CAPTURING = False
self.disabled = True
if not mscclpp_is_available:
# disable because of missing mscclpp library
# e.g. in a non-cuda environment
return
self.group = group
assert (
dist.get_backend(group) != dist.Backend.NCCL
), "CustomAllreduce should be attached to a non-NCCL group."
rank = dist.get_rank(group=self.group)
world_size = dist.get_world_size(group=self.group)
if world_size == 1:
# No need to initialize mscclpp for single GPU case.
return
if world_size not in PyMscclppCommunicator._SUPPORTED_WORLD_SIZES:
logger.warning(
"PyMscclpp is disabled due to an unsupported world"
" size: %d. Supported world sizes: %s. To silence this "
"warning, specify disable_mscclpp=True explicitly.",
world_size,
str(PyMscclppCommunicator._SUPPORTED_WORLD_SIZES),
)
return
self.ranks = torch.distributed.get_process_group_ranks(group)
self.nranks_per_node = torch.cuda.device_count()
# For now, mscclpp with strided (non-consecutive) ranks in the communicator is not tested.
if not (abs(self.ranks[-1] - self.ranks[0]) == world_size - 1):
logger.warning(
"PyMscclpp is disabled due to an unsupported group %s."
"Please ensure all ranks in the group are consecutive."
"To silence this warning, specify disable_mscclpp=True explicitly.",
str(self.ranks),
)
return
if isinstance(device, int):
device = torch.device(f"cuda:{device}")
elif isinstance(device, str):
device = torch.device(device)
# now `device` is a `torch.device` object
assert isinstance(device, torch.device)
self.device = device
self.max_bytes = max_bytes
self.rank = rank
self.world_size = world_size
if dist.get_rank(group) == 0:
unique_id = [ops.mscclpp_generate_unique_id()]
else:
unique_id = [None]
dist.broadcast_object_list(unique_id, src=self.ranks[0], group=self.group)
self.unique_id = unique_id[0]
self.rank_to_node = [r // 8 for r in range(world_size)]
self.rank_to_ib = [self.rank % 8 for _ in range(world_size)]
self._context = None
self.context_selection = None
self.msg_size_for_finetune = [
2**i for i in range(10, math.floor(math.log2(self.max_bytes)) + 1)
]
self.msg_size2best_config = {}
if world_size == 8:
self.context_selection = MscclContextSelection.MSCCL1SHOT1NODELL
elif world_size == 16:
self.context_selection = MscclContextSelection.MSCCL1SHOT2NODELL
if not _is_hip:
self.scratch = torch.empty(
self.max_bytes * 8,
dtype=torch.uint8,
device=self.device,
)
self.put_buffer = torch.empty(
self.max_bytes * 8 // self.nranks_per_node,
dtype=torch.uint8,
device=self.device,
)
self._context = ops.mscclpp_init_context(
self.unique_id,
self.rank,
self.world_size,
self.scratch,
self.put_buffer,
self.nranks_per_node,
self.rank_to_node,
self.rank_to_ib,
int(self.context_selection),
)
else:
raise NotImplementedError("HIP Mscclpp is not supported yet.")
self.msg_size2best_config = {}
self.pre_tune_config()
if dist.get_rank(group) == 0:
msg_size2best_config = [self.msg_size2best_config]
else:
msg_size2best_config = [None]
dist.broadcast_object_list(
msg_size2best_config, src=self.ranks[0], group=self.group
)
self.msg_size2best_config = msg_size2best_config[0]
# PyMscclpp is enabled only inside CUDA graphs (see change_state / graph_capture);
# keep it disabled for eager-mode all-reduce.
self.disabled = True
def pre_tune_config(self, dtype=torch.bfloat16) -> None:
logger.debug(f"start to pre-tune configs for rank {self.rank}")
nthreads_to_try = [256, 512, 1024]
nblocks_to_try = [21, 42, 84]
inp_randn = torch.ones(
self.msg_size_for_finetune[-1] // dtype.itemsize, dtype=dtype, device="cuda"
)
oup_randn = torch.empty_like(inp_randn)
for msg_size in self.msg_size_for_finetune:
mock_inp, mock_outp = (
inp_randn[: msg_size // dtype.itemsize],
oup_randn[: msg_size // dtype.itemsize],
)
best_config, best_time = None, None
for nthreads in nthreads_to_try:
for nblocks in nblocks_to_try:
cur_cost = mscclpp_bench_time(
lambda: ops.mscclpp_allreduce(
self._context, mock_inp, mock_outp, nthreads, nblocks
)
)
if best_time is None or cur_cost < best_time:
best_config = (nthreads, nblocks)
best_time = cur_cost
self.msg_size2best_config[msg_size] = best_config
if self.rank == 0:
logger.debug(
f"for msg_size {msg_size}, best_config: {best_config}, best_time: {best_time}us"
)
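# The resulting self.msg_size2best_config maps each tuned message size (bytes) to a
# (nthreads, nblocks) pair, e.g. (values illustrative): {1024: (256, 21), ..., 1048576: (1024, 84)}.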
def should_mscclpp_allreduce(
self, inp: torch.Tensor, op: ReduceOp = ReduceOp.SUM
) -> bool:
if self.disabled or self._context is None:
return False
if inp.dtype not in PyMscclppCommunicator._SUPPORTED_DTYPE:
return False
if not mscclpp_is_weak_contiguous(inp):
return False
# only support sum op
if op != ReduceOp.SUM:
return False
if inp.numel() * inp.element_size() > self.max_bytes:
return False
return True
def all_reduce(self, tensor: torch.Tensor, op: ReduceOp = ReduceOp.SUM):
if self._IS_CAPTURING:
if torch.cuda.is_current_stream_capturing():
self.graph_input_set.add((tensor.dtype, tensor.numel()))
msg_size = tensor.numel() * tensor.itemsize
index = bisect.bisect_left(self.msg_size_for_finetune, msg_size)
msg_size_finetune = self.msg_size_for_finetune[index]
nthreads, nblocks = self.msg_size2best_config[msg_size_finetune]
result = torch.empty_like(tensor)
ops.mscclpp_allreduce(self._context, tensor, result, nthreads, nblocks)
return result
@contextmanager
def change_state(
self,
enable: Optional[bool] = None,
):
if enable is None:
# guess a default value when not specified
enable = self.available
old_disable = self.disabled
self.disabled = not enable
yield
self.disabled = old_disable
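# Usage sketch (assumptions: `comm` is an initialized PyMscclppCommunicator and the call runs
# under GroupCoordinator.graph_capture(), which is what enables it in practice):
#
#     with comm.change_state(enable=True):
#         out = comm.all_reduce(tensor)  # out-of-place: returns a new tensor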
......@@ -190,6 +190,7 @@ class GroupCoordinator:
cpu_group: ProcessGroup # group for CPU communication
device_group: ProcessGroup # group for device communication
use_pynccl: bool # a hint of whether to use PyNccl
use_pymscclpp: bool # a hint of whether to use PyMscclpp
use_custom_allreduce: bool # a hint of whether to use CustomAllreduce
use_message_queue_broadcaster: (
bool # a hint of whether to use message queue broadcaster
......@@ -205,6 +206,7 @@ class GroupCoordinator:
local_rank: int,
torch_distributed_backend: Union[str, Backend],
use_pynccl: bool,
use_pymscclpp: bool,
use_custom_allreduce: bool,
use_hpu_communicator: bool,
use_xpu_communicator: bool,
......@@ -244,6 +246,7 @@ class GroupCoordinator:
self.device = torch.device("cpu")
self.use_pynccl = use_pynccl
self.use_pymscclpp = use_pymscclpp
self.use_custom_allreduce = use_custom_allreduce
self.use_hpu_communicator = use_hpu_communicator
self.use_xpu_communicator = use_xpu_communicator
......@@ -265,6 +268,17 @@ class GroupCoordinator:
device=self.device,
)
from sglang.srt.distributed.device_communicators.pymscclpp import (
PyMscclppCommunicator,
)
self.pymscclpp_comm: Optional[PyMscclppCommunicator] = None
if use_pymscclpp and self.world_size > 1:
self.pymscclpp_comm = PyMscclppCommunicator(
group=self.cpu_group,
device=self.device,
)
self.ca_comm: Optional[CustomAllreduce] = None
if use_custom_allreduce and self.world_size > 1:
# Initialize a custom fast all-reduce implementation.
......@@ -373,11 +387,15 @@ class GroupCoordinator:
# --------------------------------------------
# custom allreduce | enabled | enabled |
# PyNccl | disabled| enabled |
# PyMscclpp | disabled| enabled |
# torch.distributed | enabled | disabled|
#
# Note that custom allreduce will have a runtime check, if the
# tensor size is too large, it will fallback to the next
# available option.
# Note that PyMscclpp needs to register the tensor ahead of time,
# which introduces a large overhead in the eager case;
# therefore it is only supported in the graph case.
# In summary: When using CUDA graph, we use
# either custom all-reduce kernel or pynccl. When not using
# CUDA graph, we use either custom all-reduce kernel or
......@@ -392,7 +410,14 @@ class GroupCoordinator:
maybe_pynccl_context = pynccl_comm.change_state(
enable=True, stream=torch.cuda.current_stream()
)
with maybe_pynccl_context:
pymscclpp_comm = self.pymscclpp_comm
maybe_pymscclpp_context: Any
if not pymscclpp_comm:
maybe_pymscclpp_context = nullcontext()
else:
maybe_pymscclpp_context = pymscclpp_comm.change_state(enable=True)
with maybe_pynccl_context, maybe_pymscclpp_context:
yield graph_capture_context
def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
......@@ -437,6 +462,10 @@ class GroupCoordinator:
self.ca_comm is not None
and not self.ca_comm.disabled
and self.ca_comm.should_custom_ar(input_)
) or (
self.pymscclpp_comm is not None
and not self.pymscclpp_comm.disabled
and self.pymscclpp_comm.should_mscclpp_allreduce(input_)
):
return torch.ops.sglang.outplace_all_reduce(
input_, group_name=self.unique_name
......@@ -447,9 +476,13 @@ class GroupCoordinator:
def _all_reduce_out_place(self, input_: torch.Tensor) -> torch.Tensor:
ca_comm = self.ca_comm
assert ca_comm is not None
assert not ca_comm.disabled
out = ca_comm.custom_all_reduce(input_)
pymscclpp_comm = self.pymscclpp_comm
assert ca_comm is not None or pymscclpp_comm is not None
if ca_comm is not None and not ca_comm.disabled:
out = ca_comm.custom_all_reduce(input_)
else:
assert not pymscclpp_comm.disabled
out = pymscclpp_comm.all_reduce(input_)
assert out is not None
return out
......@@ -958,6 +991,7 @@ def init_world_group(
local_rank=local_rank,
torch_distributed_backend=backend,
use_pynccl=False,
use_pymscclpp=False,
use_custom_allreduce=False,
use_hpu_communicator=False,
use_xpu_communicator=False,
......@@ -973,14 +1007,18 @@ def init_model_parallel_group(
use_custom_allreduce: Optional[bool] = None,
use_message_queue_broadcaster: bool = False,
group_name: Optional[str] = None,
use_mscclpp_allreduce: Optional[bool] = None,
) -> GroupCoordinator:
if use_custom_allreduce is None:
use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE
if use_mscclpp_allreduce is None:
use_mscclpp_allreduce = _ENABLE_MSCCLPP_ALL_REDUCE
return GroupCoordinator(
group_ranks=group_ranks,
local_rank=local_rank,
torch_distributed_backend=backend,
use_pynccl=not is_npu(),
use_pymscclpp=use_mscclpp_allreduce,
use_custom_allreduce=use_custom_allreduce,
use_hpu_communicator=True,
use_xpu_communicator=True,
......@@ -1037,6 +1075,7 @@ def graph_capture():
logger = logging.getLogger(__name__)
_ENABLE_CUSTOM_ALL_REDUCE = True
_ENABLE_MSCCLPP_ALL_REDUCE = False
def set_custom_all_reduce(enable: bool):
......@@ -1044,6 +1083,11 @@ def set_custom_all_reduce(enable: bool):
_ENABLE_CUSTOM_ALL_REDUCE = enable
def set_mscclpp_all_reduce(enable: bool):
global _ENABLE_MSCCLPP_ALL_REDUCE
_ENABLE_MSCCLPP_ALL_REDUCE = enable
def init_distributed_environment(
world_size: int = -1,
rank: int = -1,
......
......@@ -98,11 +98,12 @@ def initialize_dp_attention(
],
local_rank,
torch.distributed.get_backend(tp_group.device_group),
SYNC_TOKEN_IDS_ACROSS_TP,
False,
False,
False,
False,
use_pynccl=SYNC_TOKEN_IDS_ACROSS_TP,
use_pymscclpp=False,
use_custom_allreduce=False,
use_hpu_communicator=False,
use_xpu_communicator=False,
use_npu_communicator=False,
group_name="attention_tp",
)
......
......@@ -35,6 +35,7 @@ from sglang.srt.distributed import (
init_distributed_environment,
initialize_model_parallel,
set_custom_all_reduce,
set_mscclpp_all_reduce,
)
from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state
from sglang.srt.layers.attention.tbo_backend import TboAttnBackend
......@@ -460,6 +461,7 @@ class ModelRunner:
else:
dist_init_method = f"tcp://127.0.0.1:{self.dist_port}"
set_custom_all_reduce(not self.server_args.disable_custom_all_reduce)
set_mscclpp_all_reduce(self.server_args.enable_mscclpp)
if not self.is_draft_worker:
# Only initialize the distributed environment on the target model worker.
......
......@@ -165,6 +165,7 @@ class ServerArgs:
enable_tokenizer_batch_encode: bool = False
disable_outlines_disk_cache: bool = False
disable_custom_all_reduce: bool = False
enable_mscclpp: bool = False
disable_overlap_schedule: bool = False
enable_mixed_chunk: bool = False
enable_dp_attention: bool = False
......@@ -1168,6 +1169,11 @@ class ServerArgs:
action="store_true",
help="Disable the custom all-reduce kernel and fall back to NCCL.",
)
parser.add_argument(
"--enable-mscclpp",
action="store_true",
help="Enable using mscclpp for small messages for all-reduce kernel and fall back to NCCL.",
)
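# Usage sketch (model path illustrative): enable mscclpp on a single-node TP8 server with
#   python3 -m sglang.launch_server --model-path <model> --tp-size 8 --enable-mscclpp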
parser.add_argument(
"--disable-overlap-schedule",
action="store_true",
......
......@@ -73,6 +73,14 @@ FetchContent_Declare(
GIT_SHALLOW OFF
)
FetchContent_Populate(repo-flash-attention)
# mscclpp
FetchContent_Declare(
repo-mscclpp
GIT_REPOSITORY https://github.com/microsoft/mscclpp.git
GIT_TAG 51eca89d20f0cfb3764ccd764338d7b22cd486a6
GIT_SHALLOW OFF
)
FetchContent_Populate(repo-mscclpp)
# ccache option
option(ENABLE_CCACHE "Whether to use ccache" ON)
......@@ -99,6 +107,7 @@ include_directories(
${repo-cutlass_SOURCE_DIR}/tools/util/include
${repo-flashinfer_SOURCE_DIR}/include
${repo-flashinfer_SOURCE_DIR}/csrc
${repo-mscclpp_SOURCE_DIR}/include
)
set(SGL_KERNEL_CUDA_FLAGS
......@@ -196,6 +205,7 @@ string(REPLACE "-D__CUDA_NO_BFLOAT16_CONVERSIONS__" "" CMAKE_CUDA_FLAGS "${CMAKE
string(REPLACE "-D__CUDA_NO_HALF2_OPERATORS__" "" CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS}")
set(SOURCES
"csrc/allreduce/mscclpp_allreduce.cu"
"csrc/allreduce/custom_all_reduce.cu"
"csrc/attention/cascade.cu"
"csrc/attention/merge_attn_states.cu"
......@@ -250,7 +260,27 @@ target_include_directories(common_ops PRIVATE
${repo-cutlass_SOURCE_DIR}/examples/common
${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/src
)
target_link_libraries(common_ops PRIVATE ${TORCH_LIBRARIES} c10 cuda cublas cublasLt)
find_package(Python3 COMPONENTS Interpreter REQUIRED)
execute_process(
COMMAND ${Python3_EXECUTABLE} -c "import torch; print(int(torch._C._GLIBCXX_USE_CXX11_ABI))"
OUTPUT_VARIABLE TORCH_CXX11_ABI
OUTPUT_STRIP_TRAILING_WHITESPACE
)
if(TORCH_CXX11_ABI STREQUAL "0")
message(STATUS "Using old C++ ABI (-D_GLIBCXX_USE_CXX11_ABI=0)")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D_GLIBCXX_USE_CXX11_ABI=0")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -D_GLIBCXX_USE_CXX11_ABI=0")
else()
message(STATUS "Using new C++11 ABI (-D_GLIBCXX_USE_CXX11_ABI=1)")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D_GLIBCXX_USE_CXX11_ABI=1")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -D_GLIBCXX_USE_CXX11_ABI=1")
endif()
set(MSCCLPP_USE_CUDA ON)
set(MSCCLPP_BYPASS_GPU_CHECK ON)
set(MSCCLPP_BUILD_TESTS OFF)
add_subdirectory(${repo-mscclpp_SOURCE_DIR})
target_link_libraries(common_ops PRIVATE ${TORCH_LIBRARIES} c10 cuda cublas cublasLt mscclpp_static)
target_compile_definitions(common_ops PRIVATE
FLASHATTENTION_DISABLE_BACKWARD
......
......@@ -19,14 +19,14 @@ submodule: ## Initialize and update git submodules
@git submodule update --init --recursive
ln: submodule ## Create compilation database
@rm -rf build && mkdir build && cd build && cmake .. -DCMAKE_EXPORT_COMPILE_COMMANDS=YES
@rm -rf build && mkdir build && cd build && cmake .. -DCMAKE_EXPORT_COMPILE_COMMANDS=YES -DCMAKE_POLICY_VERSION_MINIMUM=3.5
install: submodule ## Install package in development mode
@pip install -e . --no-build-isolation
build: install-deps submodule ## Build and install wheel package
@rm -rf dist/* || true && export MAX_JOBS=$(nproc) && CMAKE_BUILD_PARALLEL_LEVEL=$(nproc) uv build --wheel -Cbuild-dir=build . --verbose --color=always --no-build-isolation && pip3 install dist/*whl --force-reinstall --no-deps
@rm -rf dist/* || true && export MAX_JOBS=$(nproc) && CMAKE_POLICY_VERSION_MINIMUM=3.5 CMAKE_BUILD_PARALLEL_LEVEL=$(nproc) uv build --wheel -Cbuild-dir=build . --verbose --color=always --no-build-isolation && pip3 install dist/*whl --force-reinstall --no-deps
clean: ## Remove build artifacts
@rm -rf build dist *.egg-info
......
......@@ -50,6 +50,9 @@ docker run --rm \
which cmake
cmake --version
yum install numactl-devel -y && \
yum install libibverbs -y && \
ln -sv /usr/lib64/libibverbs.so.1 /usr/lib64/libibverbs.so && \
${PYTHON_ROOT_PATH}/bin/${TORCH_INSTALL} && \
${PYTHON_ROOT_PATH}/bin/pip install --no-cache-dir ninja setuptools==75.0.0 wheel==0.41.0 numpy uv scikit-build-core && \
export TORCH_CUDA_ARCH_LIST='7.5 8.0 8.9 9.0+PTX' && \
......
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <torch/all.h>
#include <torch/library.h>
#include "mscclpp_allreduce.cuh"
enum MscclContextSelection {
MSCCL1NODELL = 1,
MSCCL2NODELL = 2,
};
class MscclContext {
public:
MscclContextSelection selection_;
std::shared_ptr<sglang::Msccl1NodeLLcontext> msccl_1nodeLL_context;
std::shared_ptr<sglang::Msccl2NodeLLcontext> msccl_2nodeLL_context;
MscclContext(MscclContextSelection selection) : selection_(selection) {}
template <typename T>
void allreduce(
cudaStream_t stream, T* input, T* output, const size_t input_numel, int threads = 512, int block_limit = 21) {
if (selection_ == MSCCL1NODELL) {
msccl_1nodeLL_context->allreduce<T>(stream, input, output, input_numel, threads, block_limit);
} else if (selection_ == MSCCL2NODELL) {
msccl_2nodeLL_context->allreduce<T>(stream, input, output, input_numel, threads, block_limit);
}
}
};
using fptr_t = int64_t;
static_assert(sizeof(void*) == sizeof(fptr_t));
torch::Tensor _unique_id2tensor(const mscclpp::UniqueId& unique_id) {
auto options = torch::TensorOptions().dtype(torch::kByte).device(torch::kCPU);
auto tensor = torch::empty({static_cast<int64_t>(unique_id.size())}, options);
std::memcpy(tensor.data_ptr<uint8_t>(), unique_id.data(), unique_id.size());
return tensor;
}
// Convert a byte tensor back to an mscclpp::UniqueId
mscclpp::UniqueId _tensor2unique_id(const torch::Tensor& tensor) {
mscclpp::UniqueId unique_id;
std::memcpy(unique_id.data(), tensor.data_ptr<uint8_t>(), unique_id.size());
return unique_id;
}
torch::Tensor mscclpp_generate_unique_id() {
mscclpp::UniqueId unique_id = mscclpp::TcpBootstrap::createUniqueId();
return _unique_id2tensor(unique_id);
}
fptr_t mscclpp_init_context(
const torch::Tensor& unique_id,
const int64_t rank,
const int64_t world_size,
torch::Tensor& scratch,
torch::Tensor& put_buffer,
const int64_t nranks_per_node,
const std::vector<int64_t>& rank_to_node,
const std::vector<int64_t>& rank_to_ib,
const int64_t context_selection) {
MscclContext* context_ptr = new MscclContext(static_cast<MscclContextSelection>(context_selection));
mscclpp::UniqueId uid = _tensor2unique_id(unique_id);
if (context_selection == MSCCL1NODELL) {
void* scratch_ptr = reinterpret_cast<void*>(scratch.data_ptr());
const size_t scratch_bytes = scratch.numel() * scratch.element_size();
context_ptr->msccl_1nodeLL_context = std::make_shared<sglang::Msccl1NodeLLcontext>(
uid, rank, world_size, scratch_ptr, scratch_bytes, nranks_per_node, rank_to_node, rank_to_ib);
} else if (context_selection == MSCCL2NODELL) {
void* scratch_ptr = reinterpret_cast<void*>(scratch.data_ptr());
const size_t scratch_bytes = scratch.numel() * scratch.element_size();
void* put_buffer_ptr = reinterpret_cast<void*>(put_buffer.data_ptr());
const size_t put_buffer_bytes = put_buffer.numel() * put_buffer.element_size();
context_ptr->msccl_2nodeLL_context = std::make_shared<sglang::Msccl2NodeLLcontext>(
uid,
rank,
world_size,
scratch_ptr,
scratch_bytes,
put_buffer_ptr,
put_buffer_bytes,
nranks_per_node,
rank_to_node,
rank_to_ib);
} else {
throw std::runtime_error("invalid context selection");
}
return (fptr_t)context_ptr;
}
bool _mscclpp_is_weak_contiguous(torch::Tensor& t) {
return t.is_contiguous() ||
(t.storage().nbytes() - t.storage_offset() * t.element_size() == t.numel() * t.element_size());
}
void mscclpp_allreduce(fptr_t _context, torch::Tensor& inp, torch::Tensor& out, int64_t nthreads, int64_t nblocks) {
MscclContext* context = reinterpret_cast<MscclContext*>(_context);
const at::cuda::OptionalCUDAGuard device_guard(device_of(inp));
auto stream = c10::cuda::getCurrentCUDAStream().stream();
TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type());
TORCH_CHECK_EQ(inp.numel(), out.numel());
TORCH_CHECK(_mscclpp_is_weak_contiguous(out));
TORCH_CHECK(_mscclpp_is_weak_contiguous(inp));
switch (out.scalar_type()) {
case at::ScalarType::Float: {
context->allreduce<float>(
stream,
reinterpret_cast<float*>(inp.data_ptr()),
reinterpret_cast<float*>(out.data_ptr()),
inp.numel(),
nthreads,
nblocks);
break;
}
case at::ScalarType::Half: {
context->allreduce<half>(
stream,
reinterpret_cast<half*>(inp.data_ptr()),
reinterpret_cast<half*>(out.data_ptr()),
inp.numel(),
nthreads,
nblocks);
break;
}
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
case at::ScalarType::BFloat16: {
context->allreduce<__nv_bfloat16>(
stream,
reinterpret_cast<__nv_bfloat16*>(inp.data_ptr()),
reinterpret_cast<__nv_bfloat16*>(out.data_ptr()),
inp.numel(),
nthreads,
nblocks);
break;
}
#endif
default:
throw std::runtime_error("custom allreduce only supports float32, float16 and bfloat16");
}
}
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
#pragma once
#if defined(__HIP_PLATFORM_AMD__)
#include <hip/hip_fp16.h>
#else
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#endif
#include <mscclpp/concurrency_device.hpp>
#include <mscclpp/core.hpp>
#include <mscclpp/memory_channel.hpp>
#include <mscclpp/memory_channel_device.hpp>
#include <mscclpp/nvls_device.hpp>
#include <mscclpp/port_channel.hpp>
#include <mscclpp/port_channel_device.hpp>
// Comment out this include when building test_mscclpp_allreduce.cu standalone.
#include "utils.h"
namespace sglang {
__device__ mscclpp::DeviceSyncer deviceSyncer;
__device__ mscclpp::DeviceSyncer allGatherDeviceSyncer;
__device__ mscclpp::DeviceSyncer reduceScatterDeviceSyncer;
__device__ mscclpp::DeviceSyncer ibDeviceSyncer;
template <typename To, typename From>
__forceinline__ __device__ To bit_cast(const From& src) {
static_assert(sizeof(To) == sizeof(From), "Size mismatch for bit_cast");
union {
From f;
To t;
} u;
u.f = src;
return u.t;
}
template <typename T>
__forceinline__ __device__ T add_elements(T a, T b) {
return a + b;
}
template <>
__forceinline__ __device__ __half2 add_elements(__half2 a, __half2 b) {
return __hadd2(a, b);
}
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
template <>
__forceinline__ __device__ __nv_bfloat162 add_elements(__nv_bfloat162 a, __nv_bfloat162 b) {
return __hadd2(a, b);
}
#endif
template <typename T>
__forceinline__ __device__ int4 add_vectors_helper(int4 a, int4 b) {
int4 ret;
ret.w = bit_cast<int, T>(add_elements(bit_cast<T, int>(a.w), bit_cast<T, int>(b.w)));
ret.x = bit_cast<int, T>(add_elements(bit_cast<T, int>(a.x), bit_cast<T, int>(b.x)));
ret.y = bit_cast<int, T>(add_elements(bit_cast<T, int>(a.y), bit_cast<T, int>(b.y)));
ret.z = bit_cast<int, T>(add_elements(bit_cast<T, int>(a.z), bit_cast<T, int>(b.z)));
return ret;
}
template <typename T>
__forceinline__ __device__ int4 add_vectors(int4 a, int4 b) {
return add_vectors_helper<T>(a, b);
}
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
template <>
__forceinline__ __device__ int4 add_vectors<__nv_bfloat16>(int4 a, int4 b) {
return add_vectors_helper<__nv_bfloat162>(a, b);
}
#endif
template <>
__forceinline__ __device__ int4 add_vectors<__half>(int4 a, int4 b) {
return add_vectors_helper<__half2>(a, b);
}
template <typename T>
__forceinline__ __device__ uint2 add_vectors_helper(uint2 a, uint2 b) {
uint2 ret;
ret.x = bit_cast<int, T>(add_elements(bit_cast<T, int>(a.x), bit_cast<T, int>(b.x)));
ret.y = bit_cast<int, T>(add_elements(bit_cast<T, int>(a.y), bit_cast<T, int>(b.y)));
return ret;
}
template <typename T>
__forceinline__ __device__ uint2 add_vectors(uint2 a, uint2 b) {
return add_vectors_helper<T>(a, b);
}
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
template <>
__forceinline__ __device__ uint2 add_vectors<__nv_bfloat16>(uint2 a, uint2 b) {
return add_vectors_helper<__nv_bfloat162>(a, b);
}
#endif
template <>
__forceinline__ __device__ uint2 add_vectors<__half>(uint2 a, uint2 b) {
return add_vectors_helper<__half2>(a, b);
}
template <typename T>
__forceinline__ __device__ int add_vectors_helper(int a, int b) {
return bit_cast<int, T>(add_elements(bit_cast<T, int>(a), bit_cast<T, int>(b)));
}
template <typename T>
__forceinline__ __device__ int add_vectors(int a, int b) {
return add_vectors_helper<T>(a, b);
}
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
template <>
__forceinline__ __device__ int add_vectors<__nv_bfloat16>(int a, int b) {
return add_vectors_helper<__nv_bfloat162>(a, b);
}
#endif
template <>
__forceinline__ __device__ int add_vectors<__half>(int a, int b) {
return add_vectors_helper<__half2>(a, b);
}
// -------------------------------------------------------
// allreduce_LL_1node using LLPacket (adapted from mscclpp's allreduce2)
// -------------------------------------------------------
__device__ uint64_t globalFlag = 1;
template <typename TYPE>
__global__ void __launch_bounds__(1024, 1) allreduce_LL_1node(
mscclpp::MemoryChannelDeviceHandle* memChans,
TYPE* buff,
TYPE* scratch,
void* resultBuff,
int rank,
int worldSize,
size_t nelems) {
nelems = nelems / (sizeof(int) / sizeof(TYPE));
// This version of allreduce only works for single nodes
const int nPeers = worldSize - 1;
const size_t nPkts = nelems / 2;
const int nelemsPerRank = nelems / worldSize;
const int nPktsPerRank = nelemsPerRank / 2;
// flag for packets. Initially 1
const uint32_t flag = (uint32_t)globalFlag;
// thread block & channel info
const int nBlocksPerPeer = gridDim.x / nPeers;
const int localBlockIdx = blockIdx.x % nBlocksPerPeer;
const int peerIdx = blockIdx.x / nBlocksPerPeer;
const int remoteRank = peerIdx < rank ? peerIdx : peerIdx + 1;
mscclpp::MemoryChannelDeviceHandle memChan = memChans[peerIdx];
const int tid = threadIdx.x + localBlockIdx * blockDim.x;
// double buffering
size_t scratchBaseOffset = (flag & 1) ? 0 : nPkts * sizeof(mscclpp::LLPacket);
void* scratchBuff = (void*)((char*)scratch + scratchBaseOffset);
size_t scratchOffset = scratchBaseOffset + rank * nPktsPerRank * sizeof(mscclpp::LLPacket);
size_t scratchResultOffset =
(flag & 1) ? 2 * nPkts * sizeof(mscclpp::LLPacket) : 3 * nPkts * sizeof(mscclpp::LLPacket);
size_t srcOffset = remoteRank * nelemsPerRank * sizeof(int);
uint2* src = (uint2*)((char*)buff + rank * nelemsPerRank * sizeof(int));
uint2* dst = (uint2*)((char*)resultBuff + rank * nelemsPerRank * sizeof(int));
// step 1: write to scratch buffer
memChan.putPackets(scratchOffset, srcOffset, nelemsPerRank * sizeof(int), tid, blockDim.x * nBlocksPerPeer, flag);
// step 2: get data from scratch buffer, reduce data and write result to remote scratch buffer
for (int idx = threadIdx.x + blockIdx.x * blockDim.x; idx < nPktsPerRank; idx += blockDim.x * gridDim.x) {
uint2 data = make_uint2(0, 0);
for (int index = 0; index < nPeers; index++) {
const int remoteRank = index < rank ? index : index + 1;
mscclpp::LLPacket* dstPkt = (mscclpp::LLPacket*)scratchBuff + remoteRank * nPktsPerRank;
uint2 val = dstPkt[idx].read(flag);
data = add_vectors<TYPE>(val, data);
}
data = add_vectors<TYPE>(data, src[idx]);
dst[idx] = data;
mscclpp::LLPacket packet;
packet.data1 = data.x;
packet.flag1 = flag;
packet.data2 = data.y;
packet.flag2 = flag;
size_t offset = scratchResultOffset / sizeof(mscclpp::LLPacket) + (idx + rank * nPktsPerRank);
for (int index = 0; index < nPeers; index++) {
memChans[index].write(offset, packet);
}
}
// step 3: get data result from scratch buffer
mscclpp::LLPacket* dstPkt = (mscclpp::LLPacket*)((char*)scratch + scratchResultOffset);
const int dstOffset = remoteRank * nPktsPerRank;
uint2* result = (uint2*)((char*)resultBuff + remoteRank * nelemsPerRank * sizeof(int));
for (int idx = threadIdx.x + localBlockIdx * blockDim.x; idx < nPktsPerRank; idx += blockDim.x * nBlocksPerPeer) {
uint2 data = dstPkt[idx + dstOffset].read(flag);
result[idx].x = data.x;
result[idx].y = data.y;
}
if (threadIdx.x == 0 && blockIdx.x == 0) {
globalFlag += 1;
}
}
// -------------------------------------------------------
// allreduce_LL_2node using LLPacket (adapted from mscclpp's allreduce5)
// -------------------------------------------------------
template <typename TYPE>
__global__ void __launch_bounds__(1024, 1) allreduce_LL_2node(
mscclpp::MemoryChannelDeviceHandle* memChans,
mscclpp::PortChannelDeviceHandle* portChans,
TYPE* buff,
TYPE* scratch,
TYPE* putBuff,
TYPE* resultBuff,
int rank,
int nRanksPerNode,
int worldSize,
size_t nelems) {
nelems = nelems / (sizeof(int) / sizeof(TYPE));
// This version of allreduce works across two nodes: memory channels within a node, port channels across nodes
const int nPeersInNode = nRanksPerNode - 1;
const int nPkts = nelems / 2;
const int nelemsPerLocalRank = nelems / nRanksPerNode;
const int nPktsPerLocalRank = nelemsPerLocalRank / 2;
const int localRankId = rank % nRanksPerNode;
// flag for packets. Initially 1
const uint32_t flag = (uint32_t)globalFlag;
// thread block & channel info
const int nBlocksPerPeer = gridDim.x / nPeersInNode;
const int localBlockIdx = blockIdx.x % nBlocksPerPeer;
const int peerIdx = blockIdx.x / nBlocksPerPeer;
const int remoteRankIdx = peerIdx < localRankId ? peerIdx : peerIdx + 1;
mscclpp::MemoryChannelDeviceHandle memChan = memChans[peerIdx];
mscclpp::PortChannelDeviceHandle portChan = portChans[localRankId];
const int tid = threadIdx.x + localBlockIdx * blockDim.x;
// double buffering
size_t scratchBaseOffset = (flag & 1) ? 0 : nPkts * sizeof(mscclpp::LLPacket);
size_t putBaseOffset = (flag & 1) ? 0 : nPktsPerLocalRank * sizeof(mscclpp::LLPacket);
void* scratchBuff = (void*)((char*)scratch + scratchBaseOffset);
size_t scratchOffset = scratchBaseOffset + localRankId * nPktsPerLocalRank * sizeof(mscclpp::LLPacket);
size_t scratchResultOffset =
(flag & 1) ? 2 * nPkts * sizeof(mscclpp::LLPacket) : 3 * nPkts * sizeof(mscclpp::LLPacket);
size_t srcOffset = remoteRankIdx * nelemsPerLocalRank * sizeof(int);
uint2* src = (uint2*)((char*)buff + localRankId * nelemsPerLocalRank * sizeof(int));
uint2* dst = (uint2*)((char*)resultBuff + localRankId * nelemsPerLocalRank * sizeof(int));
// step 1: write to scratch buffer
if (nRanksPerNode > 1) {
memChan.putPackets(
scratchOffset, srcOffset, nelemsPerLocalRank * sizeof(int), tid, blockDim.x * nBlocksPerPeer, flag);
}
// step 2: get data from scratch buffer, do local reduce-scatter in each node.
mscclpp::LLPacket* putPkt = (mscclpp::LLPacket*)((char*)putBuff + putBaseOffset);
for (int idx = threadIdx.x + blockIdx.x * blockDim.x; idx < nPktsPerLocalRank; idx += blockDim.x * gridDim.x) {
uint2 data = make_uint2(0, 0);
for (int index = 0; index < nPeersInNode; index++) {
const int remoteRank = index < localRankId ? index : index + 1;
mscclpp::LLPacket* dstPkt = (mscclpp::LLPacket*)scratchBuff + remoteRank * nPktsPerLocalRank;
uint2 val = dstPkt[idx].read(flag);
data = add_vectors<TYPE>(val, data);
}
data = add_vectors<TYPE>(data, src[idx]);
putPkt[idx].write(data.x, data.y, flag);
dst[idx] = data;
}
deviceSyncer.sync(gridDim.x);
// step 3. send local reduced data to remote node.
if (threadIdx.x == 0 && blockIdx.x == 0) {
portChan.put(scratchOffset, putBaseOffset, nPktsPerLocalRank * sizeof(mscclpp::LLPacket));
if ((flag & 63) == 0) {
portChan.flush();
}
}
// step 4. try to read the data from scratch buffer and write to local peers
mscclpp::LLPacket* dstPkt = (mscclpp::LLPacket*)scratchBuff + localRankId * nPktsPerLocalRank;
for (int idx = threadIdx.x + blockIdx.x * blockDim.x; idx < nPktsPerLocalRank; idx += blockDim.x * gridDim.x) {
uint2 res = dst[idx];
uint2 val = dstPkt[idx].read(flag);
res = add_vectors<TYPE>(res, val);
mscclpp::LLPacket packet;
packet.data1 = res.x;
packet.flag1 = flag;
packet.data2 = res.y;
packet.flag2 = flag;
size_t offset = scratchResultOffset / sizeof(mscclpp::LLPacket) + (idx + localRankId * nPktsPerLocalRank);
for (int index = 0; index < nPeersInNode; index++) {
memChans[index].write(offset, packet);
}
dst[idx] = res;
}
// step 5: get data result from scratch buffer
dstPkt = (mscclpp::LLPacket*)((char*)scratch + scratchResultOffset);
const int dstOffset = remoteRankIdx * nPktsPerLocalRank;
uint2* result = (uint2*)((char*)resultBuff + remoteRankIdx * nelemsPerLocalRank * sizeof(int));
if (nRanksPerNode > 1) {
for (int idx = threadIdx.x + localBlockIdx * blockDim.x; idx < nPktsPerLocalRank;
idx += blockDim.x * nBlocksPerPeer) {
uint2 data = dstPkt[idx + dstOffset].read(flag);
result[idx] = data;
}
}
if (threadIdx.x == 0 && blockIdx.x == 0) {
globalFlag += 1;
}
}
static const mscclpp::Transport IBs[] = {
mscclpp::Transport::IB0,
mscclpp::Transport::IB1,
mscclpp::Transport::IB2,
mscclpp::Transport::IB3,
mscclpp::Transport::IB4,
mscclpp::Transport::IB5,
mscclpp::Transport::IB6,
mscclpp::Transport::IB7};
class MscclCommGroup {
public:
std::shared_ptr<mscclpp::Communicator> comm_;
const size_t rank_;
const size_t world_size_;
const std::vector<int64_t> rank_to_node_;
const std::vector<int64_t> rank_to_ib_;
MscclCommGroup(
mscclpp::UniqueId unique_id,
const size_t rank,
const size_t world_size,
const std::vector<int64_t>& rank_to_node,
const std::vector<int64_t>& rank_to_ib)
: rank_(rank), world_size_(world_size), rank_to_node_(rank_to_node), rank_to_ib_(rank_to_ib) {
auto bootstrap = std::make_shared<mscclpp::TcpBootstrap>(rank, world_size);
bootstrap->initialize(unique_id);
comm_ = std::make_shared<mscclpp::Communicator>(bootstrap);
}
template <typename T>
void allreduce(cudaStream_t stream, T* output, size_t input_numel, int threads = 512, int block_limit = 21) {
throw std::runtime_error("you should not call allreduce of a base context");
}
bool is_same_node(int r1, int r2) {
return rank_to_node_[r1] == rank_to_node_[r2];
}
void make_connection(
std::unordered_map<int, std::shared_ptr<mscclpp::Connection>>& same_node_connections,
std::unordered_map<int, std::shared_ptr<mscclpp::Connection>>& cross_node_connections) {
same_node_connections.clear();
cross_node_connections.clear();
std::unordered_map<int, mscclpp::NonblockingFuture<std::shared_ptr<mscclpp::Connection>>> conn_futures;
for (int r = 0; r < world_size_; ++r) {
if (r == rank_) continue;
mscclpp::Transport transport = is_same_node(r, rank_) ? mscclpp::Transport::CudaIpc : IBs[rank_to_ib_[r]];
conn_futures.emplace(r, comm_->connectOnSetup(r, 0, transport));
}
comm_->setup();
for (int r = 0; r < world_size_; ++r) {
if (r == rank_) continue;
if (is_same_node(r, rank_)) {
same_node_connections.emplace(r, conn_futures[r].get());
} else {
cross_node_connections.emplace(r, conn_futures[r].get());
}
}
}
void make_memory_channels_with_scratch(
void* tensor_ptr,
const size_t tensor_bytes,
void* scratch_ptr,
const size_t scratch_bytes,
const std::unordered_map<int, std::shared_ptr<mscclpp::Connection>>& connections,
std::unordered_map<int, std::shared_ptr<mscclpp::MemoryDevice2DeviceSemaphore>>& semaphores,
std::unordered_map<int, mscclpp::RegisteredMemory>& registered_memories,
std::unordered_map<int, mscclpp::MemoryChannel>& channels) {
channels.clear();
make_semaphores<mscclpp::MemoryDevice2DeviceSemaphore>(connections, semaphores);
register_tensor_with_connections(scratch_ptr, scratch_bytes, connections, registered_memories);
for (const auto& [peer, _] : connections) {
channels.emplace(
peer, mscclpp::MemoryChannel(semaphores[peer], registered_memories[peer], tensor_ptr, scratch_ptr));
}
}
void make_port_channels_with_scratch(
std::shared_ptr<mscclpp::ProxyService> proxyService,
void* tensor_ptr,
const size_t tensor_bytes,
void* scratch_ptr,
const size_t scratch_bytes,
const std::unordered_map<int, std::shared_ptr<mscclpp::Connection>>& connections,
std::unordered_map<int, std::shared_ptr<mscclpp::Host2DeviceSemaphore>>& semaphores,
std::unordered_map<int, mscclpp::RegisteredMemory>& registered_memories,
std::unordered_map<int, mscclpp::PortChannel>& channels) {
channels.clear();
make_semaphores<mscclpp::Host2DeviceSemaphore>(connections, semaphores);
mscclpp::TransportFlags flags;
for (const auto& [_, conn] : connections) {
flags |= conn->transport();
}
auto local_reg_memory = comm_->registerMemory(tensor_ptr, tensor_bytes, flags);
register_tensor_with_connections(scratch_ptr, scratch_bytes, connections, registered_memories);
std::unordered_map<int, mscclpp::SemaphoreId> semaphore_ids;
std::unordered_map<int, size_t> memory_ids;
memory_ids[rank_] = proxyService->addMemory(local_reg_memory);
for (const auto& [peer, memory] : registered_memories) {
if (peer == rank_) continue;
memory_ids[peer] = proxyService->addMemory(memory);
}
for (const auto& [peer, semaphore] : semaphores) {
semaphore_ids[peer] = proxyService->addSemaphore(semaphore);
}
for (const auto& [peer, _] : connections) {
channels.emplace(peer, proxyService->portChannel(semaphore_ids[peer], memory_ids[peer], memory_ids[rank_]));
}
}
template <typename SemaphoreType>
void make_semaphores(
const std::unordered_map<int, std::shared_ptr<mscclpp::Connection>>& connections,
std::unordered_map<int, std::shared_ptr<SemaphoreType>>& semaphores) {
semaphores.clear();
for (const auto& [peer, conn] : connections) {
semaphores[peer] = std::make_shared<SemaphoreType>(*comm_, conn);
}
comm_->setup();
}
void register_tensor_with_connections(
void* tensor_ptr,
size_t tensor_bytes,
const std::unordered_map<int, std::shared_ptr<mscclpp::Connection>>& connections,
std::unordered_map<int, mscclpp::RegisteredMemory>& registered_memories) {
registered_memories.clear();
mscclpp::TransportFlags all_transports;
for (const auto& [_, connection] : connections) {
all_transports |= connection->transport();
}
mscclpp::RegisteredMemory buf_reg_mem = comm_->registerMemory(tensor_ptr, tensor_bytes, all_transports);
registered_memories[rank_] = buf_reg_mem;
std::unordered_map<int, mscclpp::NonblockingFuture<mscclpp::RegisteredMemory>> remote_mem_futures;
for (const auto& [r, connection] : connections) {
comm_->sendMemoryOnSetup(buf_reg_mem, r, 0);
auto remoteMemory = comm_->recvMemoryOnSetup(r, 0);
remote_mem_futures.emplace(r, remoteMemory);
}
comm_->setup();
for (auto& [r, mem_feature] : remote_mem_futures) {
registered_memories.emplace(r, mem_feature.get());
}
}
void make_device_memory_handle_base_on_new_ptr(
const std::unordered_map<int, mscclpp::MemoryChannel>& old_memory_channels,
std::unordered_map<int, mscclpp::RegisteredMemory>& registered_sm_memories,
std::unordered_map<int, std::shared_ptr<mscclpp::MemoryDevice2DeviceSemaphore>>& memory_semaphores,
std::unordered_map<int, mscclpp::MemoryChannel>& memory_channels,
mscclpp::GpuBuffer<mscclpp::MemoryChannelDeviceHandle>& device_memory_handle,
void* input,
void* scratch,
const cudaStream_t stream) {
memory_channels.clear();
for (const auto& [peer, channel] : old_memory_channels) {
memory_channels.emplace(
peer, mscclpp::MemoryChannel(memory_semaphores[peer], registered_sm_memories[peer], input, scratch));
}
std::vector<mscclpp::MemoryChannel> memory_channels_list;
for (int r = 0; r < world_size_; r++) {
if (r == rank_) continue;
if (is_same_node(r, rank_)) {
memory_channels_list.push_back(memory_channels[r]);
}
}
std::vector<mscclpp::MemoryChannelDeviceHandle> memory_channel_handlers(memory_channels_list.size());
std::transform(
memory_channels_list.begin(),
memory_channels_list.end(),
memory_channel_handlers.begin(),
[](const mscclpp::MemoryChannel& channel) { return channel.deviceHandle(); });
mscclpp::gpuMemcpyAsync<mscclpp::MemoryChannelDeviceHandle>(
device_memory_handle.data(),
memory_channel_handlers.data(),
memory_channel_handlers.size(),
stream,
cudaMemcpyHostToDevice);
}
};
class Msccl1NodeLLcontext {
private:
std::shared_ptr<MscclCommGroup> comm_group_ = nullptr;
void* scratch_;
const size_t scratch_bytes_;
std::unordered_map<int, std::shared_ptr<mscclpp::Connection>> same_node_connections_;
std::unordered_map<int, std::shared_ptr<mscclpp::Connection>> cross_node_connections_;
std::unordered_map<int, mscclpp::RegisteredMemory> registered_sm_memories_;
std::unordered_map<int, std::shared_ptr<mscclpp::MemoryDevice2DeviceSemaphore>> memory_semaphores_;
std::unordered_map<int, mscclpp::MemoryChannel> memory_channels_;
mscclpp::GpuBuffer<mscclpp::MemoryChannelDeviceHandle> d_memHandles_;
std::unordered_map<void*, std::unordered_map<int, mscclpp::MemoryChannel>> input_ptr2memory_channels_;
std::unordered_map<void*, mscclpp::GpuBuffer<mscclpp::MemoryChannelDeviceHandle>> input_ptr2d_memHandles_;
cudaStream_t h2d_stream;
const size_t nranks_per_node_;
public:
Msccl1NodeLLcontext(
mscclpp::UniqueId unique_id,
const size_t rank,
const size_t world_size,
void* scratch,
const size_t scratch_bytes,
const size_t nranks_per_node,
const std::vector<int64_t>& rank_to_node,
const std::vector<int64_t>& rank_to_ib)
: scratch_(scratch),
scratch_bytes_(scratch_bytes),
nranks_per_node_(nranks_per_node),
d_memHandles_(nranks_per_node - 1) {
CHECK_CUDA_SUCCESS(cudaStreamCreateWithFlags(&h2d_stream, cudaStreamNonBlocking));
comm_group_ = std::make_shared<MscclCommGroup>(unique_id, rank, world_size, rank_to_node, rank_to_ib);
comm_group_->make_connection(same_node_connections_, cross_node_connections_);
comm_group_->make_memory_channels_with_scratch(
scratch_,
scratch_bytes_,
scratch_,
scratch_bytes_,
same_node_connections_,
memory_semaphores_,
registered_sm_memories_,
memory_channels_);
std::vector<mscclpp::MemoryChannel> memory_channels_list;
for (int r = 0; r < comm_group_->world_size_; r++) {
if (r == comm_group_->rank_) continue;
memory_channels_list.push_back(memory_channels_[r]);
}
std::vector<mscclpp::MemoryChannelDeviceHandle> memory_channel_handlers(memory_channels_list.size());
std::transform(
memory_channels_list.begin(),
memory_channels_list.end(),
memory_channel_handlers.begin(),
[](const mscclpp::MemoryChannel& channel) { return channel.deviceHandle(); });
mscclpp::gpuMemcpy<mscclpp::MemoryChannelDeviceHandle>(
d_memHandles_.data(), memory_channel_handlers.data(), memory_channel_handlers.size(), cudaMemcpyHostToDevice);
}
~Msccl1NodeLLcontext() {
CHECK_CUDA_SUCCESS(cudaStreamDestroy(h2d_stream));
}
template <typename T>
void allreduce(cudaStream_t stream, T* input, T* output, size_t input_numel, int nthreads = 512, int nblocks = 21) {
dim3 nthrs(nthreads);
dim3 nblks(nblocks);
cudaStreamCaptureStatus capturing_status;
CHECK_CUDA_SUCCESS(cudaStreamIsCapturing(stream, &capturing_status));
mscclpp::MemoryChannelDeviceHandle* memChans;
if (capturing_status != cudaStreamCaptureStatusActive) {
std::unordered_map<int, mscclpp::MemoryChannel> memory_channels;
comm_group_->make_device_memory_handle_base_on_new_ptr(
memory_channels_,
registered_sm_memories_,
memory_semaphores_,
memory_channels,
d_memHandles_,
input,
scratch_,
h2d_stream);
CHECK_CUDA_SUCCESS(cudaStreamSynchronize(h2d_stream));
memChans = d_memHandles_.data();
} else {
void* input_void_ptr = reinterpret_cast<void*>(input);
if (input_ptr2d_memHandles_.find(input_void_ptr) == input_ptr2d_memHandles_.end()) {
std::unordered_map<int, mscclpp::MemoryChannel> memory_channels;
mscclpp::GpuBuffer<mscclpp::MemoryChannelDeviceHandle> device_memory_handle(comm_group_->world_size_ - 1);
comm_group_->make_device_memory_handle_base_on_new_ptr(
memory_channels_,
registered_sm_memories_,
memory_semaphores_,
memory_channels,
device_memory_handle,
input,
scratch_,
h2d_stream);
input_ptr2memory_channels_.emplace(input_void_ptr, memory_channels);
input_ptr2d_memHandles_.emplace(input_void_ptr, device_memory_handle);
}
auto it = input_ptr2d_memHandles_.find(input_void_ptr);
memChans = it->second.data();
}
allreduce_LL_1node<T><<<nblks, nthrs, 0, stream>>>(
memChans, (T*)input, (T*)scratch_, output, comm_group_->rank_, comm_group_->world_size_, input_numel);
cudaError_t status = cudaGetLastError();
if (status != cudaSuccess) {
printf("rank: %lu failed to launch allreduce_LL_1node: %s\n", comm_group_->rank_, cudaGetErrorString(status));
}
}
};
class Msccl2NodeLLcontext {
private:
std::shared_ptr<MscclCommGroup> comm_group_ = nullptr;
void* scratch_;
const size_t scratch_bytes_;
void* put_buffer_;
const size_t put_buffer_bytes_;
std::unordered_map<int, std::shared_ptr<mscclpp::Connection>> same_node_connections_;
std::unordered_map<int, std::shared_ptr<mscclpp::Connection>> cross_node_connections_;
std::unordered_map<int, mscclpp::RegisteredMemory> registered_sm_memories_;
std::unordered_map<int, mscclpp::RegisteredMemory> registered_port_memories_;
std::unordered_map<int, std::shared_ptr<mscclpp::MemoryDevice2DeviceSemaphore>> memory_semaphores_;
std::unordered_map<int, std::shared_ptr<mscclpp::Host2DeviceSemaphore>> port_semaphores_;
std::unordered_map<int, mscclpp::MemoryChannel> memory_channels_;
std::unordered_map<int, mscclpp::PortChannel> port_channels_;
mscclpp::GpuBuffer<mscclpp::MemoryChannelDeviceHandle> d_memHandles_;
mscclpp::GpuBuffer<mscclpp::PortChannelDeviceHandle> d_portHandles_;
std::shared_ptr<mscclpp::ProxyService> proxyService;
cudaStream_t h2d_stream;
const size_t nranks_per_node_;
std::unordered_map<void*, std::unordered_map<int, mscclpp::MemoryChannel>> input_ptr2memory_channels_;
std::unordered_map<void*, mscclpp::GpuBuffer<mscclpp::MemoryChannelDeviceHandle>> input_ptr2d_memHandles_;
public:
Msccl2NodeLLcontext(
mscclpp::UniqueId unique_id,
const size_t rank,
const size_t world_size,
void* scratch,
const size_t scratch_bytes,
void* put_buffer,
const size_t put_buffer_bytes,
const size_t nranks_per_node,
const std::vector<int64_t>& rank_to_node,
const std::vector<int64_t>& rank_to_ib)
: scratch_(scratch),
scratch_bytes_(scratch_bytes),
put_buffer_(put_buffer),
put_buffer_bytes_(put_buffer_bytes),
nranks_per_node_(nranks_per_node),
d_memHandles_(nranks_per_node - 1),
d_portHandles_(world_size - nranks_per_node) {
CHECK_CUDA_SUCCESS(cudaStreamCreateWithFlags(&h2d_stream, cudaStreamNonBlocking));
comm_group_ = std::make_shared<MscclCommGroup>(unique_id, rank, world_size, rank_to_node, rank_to_ib);
proxyService = std::make_shared<mscclpp::ProxyService>();
proxyService->startProxy();
comm_group_->make_connection(same_node_connections_, cross_node_connections_);
comm_group_->make_memory_channels_with_scratch(
scratch_,
scratch_bytes_,
scratch_,
scratch_bytes_,
same_node_connections_,
memory_semaphores_,
registered_sm_memories_,
memory_channels_);
comm_group_->make_port_channels_with_scratch(
proxyService,
put_buffer_,
put_buffer_bytes_,
scratch_,
scratch_bytes_,
cross_node_connections_,
port_semaphores_,
registered_port_memories_,
port_channels_);
std::vector<mscclpp::MemoryChannel> memory_channels_list;
std::vector<mscclpp::PortChannel> port_channels_list;
for (int r = 0; r < comm_group_->world_size_; r++) {
if (r == comm_group_->rank_) continue;
if (comm_group_->is_same_node(r, comm_group_->rank_)) {
memory_channels_list.push_back(memory_channels_[r]);
} else {
port_channels_list.push_back(port_channels_[r]);
}
}
std::vector<mscclpp::MemoryChannelDeviceHandle> memory_channel_handlers(memory_channels_list.size());
std::transform(
memory_channels_list.begin(),
memory_channels_list.end(),
memory_channel_handlers.begin(),
[](const mscclpp::MemoryChannel& channel) { return channel.deviceHandle(); });
mscclpp::gpuMemcpy<mscclpp::MemoryChannelDeviceHandle>(
d_memHandles_.data(), memory_channel_handlers.data(), memory_channel_handlers.size(), cudaMemcpyHostToDevice);
std::vector<mscclpp::PortChannelDeviceHandle> port_channel_handlers(port_channels_list.size());
std::transform(
port_channels_list.begin(),
port_channels_list.end(),
port_channel_handlers.begin(),
[](const mscclpp::PortChannel& channel) { return channel.deviceHandle(); });
mscclpp::gpuMemcpy<mscclpp::PortChannelDeviceHandle>(
d_portHandles_.data(), port_channel_handlers.data(), port_channel_handlers.size(), cudaMemcpyHostToDevice);
}
~Msccl2NodeLLcontext() {
CHECK_CUDA_SUCCESS(cudaStreamDestroy(h2d_stream));
if (proxyService) {
proxyService->stopProxy();
}
}
template <typename T>
void
allreduce(cudaStream_t stream, T* input, T* output, const size_t input_numel, int nthreads = 512, int nblocks = 21) {
dim3 nthrs(nthreads);
dim3 nblks(nblocks);
cudaStreamCaptureStatus capturing_status;
CHECK_CUDA_SUCCESS(cudaStreamIsCapturing(stream, &capturing_status));
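// Same capture-aware handle selection as in Msccl1NodeLLcontext::allreduce: rebuild the
// handles in place when the stream is not being captured, otherwise cache one handle buffer
// per input pointer so graph replays reuse a stable device allocation.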
mscclpp::MemoryChannelDeviceHandle* memChans;
if (capturing_status != cudaStreamCaptureStatusActive) {
std::unordered_map<int, mscclpp::MemoryChannel> memory_channels;
comm_group_->make_device_memory_handle_base_on_new_ptr(
memory_channels_,
registered_sm_memories_,
memory_semaphores_,
memory_channels,
d_memHandles_,
input,
scratch_,
h2d_stream);
CHECK_CUDA_SUCCESS(cudaStreamSynchronize(h2d_stream));
memChans = d_memHandles_.data();
} else {
void* input_void_ptr = reinterpret_cast<void*>(input);
if (input_ptr2d_memHandles_.find(input_void_ptr) == input_ptr2d_memHandles_.end()) {
std::unordered_map<int, mscclpp::MemoryChannel> memory_channels;
mscclpp::GpuBuffer<mscclpp::MemoryChannelDeviceHandle> device_memory_handle(7);
comm_group_->make_device_memory_handle_base_on_new_ptr(
memory_channels_,
registered_sm_memories_,
memory_semaphores_,
memory_channels,
device_memory_handle,
input,
scratch_,
h2d_stream);
input_ptr2memory_channels_.emplace(input_void_ptr, memory_channels);
input_ptr2d_memHandles_.emplace(input_void_ptr, device_memory_handle);
}
auto it = input_ptr2d_memHandles_.find(input_void_ptr);
memChans = it->second.data();
}
allreduce_LL_2node<T><<<nblks, nthrs, 0, stream>>>(
memChans,
d_portHandles_.data(),
(T*)input,
(T*)scratch_,
(T*)put_buffer_,
output,
comm_group_->rank_,
nranks_per_node_,
comm_group_->world_size_,
input_numel);
cudaError_t status = cudaGetLastError();
if (status != cudaSuccess) {
printf("rank: %lu failed to launch allreduce_LL_2node: %s\n", comm_group_->rank_, cudaGetErrorString(status));
}
}
};
} // namespace sglang
/*
* This file tests mscclpp_allreduce.cuh using mpirun.
* It is adapted from https://github.com/flashinfer-ai/flashinfer/blob/v0.2.5/src/test_sum_all_reduce.cu
usage:
cd PATH-TO-THIS-FILE
export MPI_HOME=/usr/local/mpi
# export MPI_HOME=/opt/hpcx/ompi/
export MSCCLPP_HOME=/workspace/test/mscclpp
nvcc -O2 -arch=native -std=c++17 test_mscclpp_allreduce.cu \
-o test_mscclpp_allreduce -D_GLIBCXX_USE_CXX11_ABI=0 \
-I${MSCCLPP_HOME}/include -L${MSCCLPP_HOME}/build -lmscclpp \
-lnccl -I${MPI_HOME}/include -L${MPI_HOME}/lib -lmpi
# mpirun may live under /opt/hpcx/ompi/bin/
mpirun --allow-run-as-root -H 127.0.0.1:8 -np 8 \
--map-by ppr:8:node \
--mca btl_openib_warn_no_device_params_found 0 \
--mca btl_tcp_if_include bond0 \
-x NCCL_RUNTIME_CONNECT=0 -x NCCL_IB_GID_INDEX=3 -x NCCL_DEBUG=WARN \
-x LD_PRELOAD=${MSCCLPP_HOME}/build/libmscclpp.so ./test_mscclpp_allreduce
*/
#include <mpi.h>
#include <thrust/detail/raw_pointer_cast.h>
#include <thrust/device_vector.h>
#include <thrust/host_vector.h>
#ifndef CHECK_CUDA_SUCCESS
#define CHECK_CUDA_SUCCESS(cmd) \
do { \
cudaError_t e = cmd; \
if (e != cudaSuccess) { \
printf("Failed: Cuda error %s:%d '%s'\n", __FILE__, __LINE__, cudaGetErrorString(e)); \
exit(EXIT_FAILURE); \
} \
} while (0)
#endif
#include <cstdint>
#include "mscclpp_allreduce.cuh"
template <typename T>
bool isclose(T a, T b, float rtol = 1e-5, float atol = 1e-8) {
return fabs(a - b) <= (atol + rtol * fabs(b));
}
int main(int argc, char* argv[]) {
// init mpi
MPI_Init(&argc, &argv);
printf("MPI Initialized.\n");
int nranks, rank;
// get work size and rank id
MPI_Comm_size(MPI_COMM_WORLD, &nranks);
MPI_Comm_rank(MPI_COMM_WORLD, &rank);
cudaSetDevice(rank);
printf("nranks: %d, rank: %d\n", nranks, rank);
// init host and device buffers
using T = float;
using ReduceT = float;
const size_t num_elems = 2 * 1024 * 1024;
std::vector<T> host_buf(num_elems);
for (uint32_t i = 0; i < num_elems; ++i) {
host_buf[i] = T(i + rank);
}
thrust::device_vector<T> device_buf(host_buf);
const size_t buf_size_in_bytes = num_elems * sizeof(T);
std::vector<T> host_result_buf(num_elems);
thrust::device_vector<T> device_result_buf(host_result_buf);
std::vector<T> host_scratch_buf(num_elems * 8);
for (uint32_t i = 0; i < num_elems; ++i) {
host_scratch_buf[i] = 1;
}
thrust::device_vector<T> device_scratch_buf(host_scratch_buf);
std::vector<T> host_put_buf(num_elems);
thrust::device_vector<T> device_put_buf(host_put_buf);
mscclpp::UniqueId unique_id;
if (rank == 0) unique_id = mscclpp::TcpBootstrap::createUniqueId();
MPI_Bcast(&unique_id, sizeof(unique_id), MPI_BYTE, 0, MPI_COMM_WORLD);
std::vector<int64_t> rank_to_node(nranks);
std::vector<int64_t> rank_to_ib(nranks);
for (int i = 0; i < nranks; i++) {
rank_to_node[i] = i / 8;
rank_to_ib[i] = i % 8;
}
cudaStream_t s;
CHECK_CUDA_SUCCESS(cudaStreamCreate(&s));
CHECK_CUDA_SUCCESS(cudaStreamSynchronize(s));
if (nranks == 8) {
auto context = std::make_shared<sglang::Msccl1NodeLLcontext>(
unique_id,
rank,
nranks,
thrust::raw_pointer_cast(device_scratch_buf.data()),
buf_size_in_bytes * 8,
rank_to_node,
rank_to_ib);
printf("rank: %d, Msccl1NodeLLcontext setup.\n", rank);
MPI_Barrier(MPI_COMM_WORLD);
context->allreduce<T>(
s,
thrust::raw_pointer_cast(device_buf.data()),
thrust::raw_pointer_cast(device_result_buf.data()),
device_buf.size());
} else if (nranks == 16) {
// TODO: this branch is untested since there is something wrong with mpirun on my test machine
auto context = std::make_shared<sglang::Msccl2NodeLLcontext>(
unique_id,
rank,
nranks,
thrust::raw_pointer_cast(device_scratch_buf.data()),
buf_size_in_bytes * 8,
thrust::raw_pointer_cast(device_put_buf.data()),
buf_size_in_bytes,
rank_to_node,
rank_to_ib);
printf("rank: %d, Msccl2NodeLLcontext setup.\n", rank);
MPI_Barrier(MPI_COMM_WORLD);
context->allreduce<T>(
s,
thrust::raw_pointer_cast(device_buf.data()),
thrust::raw_pointer_cast(device_result_buf.data()),
device_buf.size());
}
// check result correctness
thrust::host_vector<T> host_buf_result = device_result_buf;
size_t num_results_error_atol_1e_3_rtol_1e_3 = 0;
bool nan_detected = false;
for (uint32_t i = 0; i < num_elems; ++i) {
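// each rank contributed (i + rank), so the reduced value is
// i * nranks + (0 + 1 + ... + nranks-1) = i * nranks + nranks * (nranks - 1) / 2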
T expected = T(i * nranks + (nranks - 1) * nranks / 2);
if (std::isnan(float(host_buf_result[i]))) {
nan_detected = true;
}
if (!isclose(float(host_buf_result[i]), float(expected), 1e-3, 1e-3)) {
num_results_error_atol_1e_3_rtol_1e_3++;
}
}
float result_accuracy = 1. - float(num_results_error_atol_1e_3_rtol_1e_3) / float(num_elems);
printf("rank: %d, nan_detected: %d accuracy: %f\n", rank, nan_detected, result_accuracy);
CHECK_CUDA_SUCCESS(cudaStreamDestroy(s));
MPI_Finalize();
return 0;
}
@@ -38,6 +38,15 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
"all_reduce(int fa, Tensor inp, Tensor! out, int reg_buffer, "
"int reg_buffer_sz_bytes) -> ()");
m.impl("all_reduce", torch::kCUDA, &all_reduce);
m.def("mscclpp_generate_unique_id", &mscclpp_generate_unique_id);
m.def(
"mscclpp_init_context(Tensor unique_id, int rank, int world_size, Tensor scratch, Tensor put_buffer, "
"int nranks_per_node, int[] rank_to_node, int[] rank_to_ib, int context_selection) -> int");
m.impl("mscclpp_init_context", torch::kCUDA, &mscclpp_init_context);
m.def("mscclpp_allreduce(int context, Tensor inp, Tensor! out, int nthreads, int nblocks) -> ()");
m.impl("mscclpp_allreduce", torch::kCUDA, &mscclpp_allreduce);
/*
* From csrc/attention
*/
......
@@ -74,6 +74,18 @@ std::tuple<std::vector<int64_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta
void register_buffer(fptr_t _fa, const std::vector<fptr_t>& fake_ipc_ptrs);
void register_graph_buffers(
fptr_t _fa, const std::vector<std::vector<int64_t>>& handles, const std::vector<std::vector<int64_t>>& offsets);
torch::Tensor mscclpp_generate_unique_id();
fptr_t mscclpp_init_context(
const torch::Tensor& unique_id,
const int64_t rank,
const int64_t world_size,
torch::Tensor& scratch,
torch::Tensor& put_buffer,
const int64_t nranks_per_node,
const std::vector<int64_t>& rank_to_node,
const std::vector<int64_t>& rank_to_ib,
const int64_t context_selection);
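// context_selection chooses the all-reduce context; judging from MscclContextSelection in
// the Python tests, 1 selects the 1-node LL context and 2 the 2-node LL context.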
void mscclpp_allreduce(fptr_t _context, torch::Tensor& inp, torch::Tensor& out, int64_t nthreads, int64_t nblocks);
#endif
/*
......
@@ -49,6 +49,27 @@ if torch.version.hip is not None:
def get_meta_buffer_ipc_handle(inp: torch.Tensor) -> torch.Tensor:
return torch.ops.sgl_kernel.get_meta_buffer_ipc_handle.default(inp)
def mscclpp_generate_unique_id() -> bytes:
raise NotImplementedError()
def mscclpp_init_context(
unique_id: bytes,
rank: int,
world_size: int,
scratch: torch.Tensor,
put_buffer: torch.Tensor,
nranks_per_node: int,
rank_to_node: List[int],
rank_to_ib: List[int],
context_selection: int,
) -> int:
raise NotImplementedError()
def mscclpp_allreduce(
context: int, inp: torch.Tensor, out: torch.Tensor, nthreads: int, nblocks: int
) -> None:
raise NotImplementedError()
else:
def init_custom_ar(
@@ -85,3 +106,36 @@ else:
def meta_size() -> int:
return torch.ops.sgl_kernel.meta_size.default()
def mscclpp_generate_unique_id() -> torch.Tensor:
return torch.ops.sgl_kernel.mscclpp_generate_unique_id.default()
def mscclpp_init_context(
unique_id: torch.Tensor,
rank: int,
world_size: int,
scratch: torch.Tensor,
put_buffer: torch.Tensor,
nranks_per_node: int,
rank_to_node: List[int],
rank_to_ib: List[int],
context_selection: int,
) -> int:
return torch.ops.sgl_kernel.mscclpp_init_context.default(
unique_id,
rank,
world_size,
scratch,
put_buffer,
nranks_per_node,
rank_to_node,
rank_to_ib,
context_selection,
)
def mscclpp_allreduce(
context: int, inp: torch.Tensor, out: torch.Tensor, nthreads: int, nblocks: int
) -> None:
torch.ops.sgl_kernel.mscclpp_allreduce.default(
context, inp, out, nthreads, nblocks
)
import multiprocessing as mp
import os
import socket
import unittest
from enum import IntEnum
from typing import Any
import sgl_kernel.allreduce as custom_ops
import torch
import torch.distributed as dist
class MscclContextSelection(IntEnum):
MSCCL1SHOT1NODELL = 1
MSCCL1SHOT2NODELL = 2
def _run_correctness_worker(world_size, rank, distributed_init_port, test_sizes):
device = torch.device(f"cuda:{rank % torch.cuda.device_count()}")
torch.cuda.set_device(device)
distributed_init_method = f"tcp://localhost:{distributed_init_port}"
dist.init_process_group(
backend="nccl",
init_method=distributed_init_method,
rank=rank,
world_size=world_size,
)
group = dist.group.WORLD
cpu_group = torch.distributed.new_group(list(range(world_size)), backend="gloo")
if rank == 0:
unique_id = [custom_ops.mscclpp_generate_unique_id()]
else:
unique_id = [None]
dist.broadcast_object_list(
unique_id, src=0, device=torch.device("cpu"), group=cpu_group
)
unique_id = unique_id[0]
rank_to_node = [r // 8 for r in range(world_size)]
rank_to_ib = [r % 8 for r in range(world_size)]
MAX_BYTES = 2**20
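# scratch is allocated 8x larger than put_buffer, mirroring the buf_size_in_bytes * 8
# scratch used by the C++ MPI test; put_buffer only needs to hold a single message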
scratch = torch.empty(
MAX_BYTES * 8, dtype=torch.bfloat16, device=torch.cuda.current_device()
)
put_buffer = torch.empty(
MAX_BYTES, dtype=torch.bfloat16, device=torch.cuda.current_device()
)
print(f"[{rank}] start mscclpp_context init")
nranks_per_node = torch.cuda.device_count()
selection = int(MscclContextSelection.MSCCL1SHOT1NODELL)
mscclpp_context = custom_ops.mscclpp_init_context(
unique_id,
rank,
world_size,
scratch,
put_buffer,
nranks_per_node,
rank_to_node,
rank_to_ib,
selection,
)
try:
test_loop = 10
for sz in test_sizes:
for dtype in [torch.float32, torch.float16, torch.bfloat16]:
if sz * dtype.itemsize > MAX_BYTES:
continue
if rank == 0:
print(f"mscclpp allreduce test sz {sz}, dtype {dtype}")
for _ in range(test_loop):
inp1 = torch.randint(1, 16, (sz,), dtype=dtype, device=device)
inp1_ref = inp1.clone()
out1 = torch.empty_like(inp1)
custom_ops.mscclpp_allreduce(
mscclpp_context, inp1, out1, nthreads=512, nblocks=21
)
dist.all_reduce(inp1_ref, group=group)
torch.testing.assert_close(out1, inp1_ref)
finally:
dist.barrier(group=group)
dist.destroy_process_group(group=group)
def get_open_port() -> int:
try:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(("127.0.0.1", 0))
return s.getsockname()[1]
except OSError:
with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s:
s.bind(("::1", 0))
return s.getsockname()[1]
def multi_process_parallel(
world_size: int, test_target: Any, target_args: tuple = ()
) -> None:
mp.set_start_method("spawn", force=True)
procs = []
distributed_init_port = get_open_port()
for i in range(world_size):
proc_args = (world_size, i, distributed_init_port) + target_args
proc = mp.Process(target=test_target, args=proc_args, name=f"Worker-{i}")
proc.start()
procs.append(proc)
for i in range(world_size):
procs[i].join()
assert (
procs[i].exitcode == 0
), f"Process {i} failed with exit code {procs[i].exitcode}"
class TestMSCCLAllReduce(unittest.TestCase):
test_sizes = [
512,
2560,
4096,
5120,
7680,
32768,
262144,
524288,
]
world_sizes = [8]
def test_correctness(self):
for world_size in self.world_sizes:
available_gpus = torch.cuda.device_count()
if world_size > available_gpus:
print(
f"Skipping world_size={world_size}, found {available_gpus} and now ray is not supported here"
)
continue
print(f"Running test for world_size={world_size}")
multi_process_parallel(
world_size, _run_correctness_worker, target_args=(self.test_sizes,)
)
print(f"custom allreduce tp = {world_size}: OK")
if __name__ == "__main__":
unittest.main()
"""For Now, MSCCL is only supported on TP16 and TP8 case
if [[ $RANK -eq 0 ]]; then
ray start --block --head --port=6379 &
python3 test_mscclpp.py;
else
ray start --block --address=${MASTER_ADDR}:6379;
fi
"""
import itertools
import os
import random
import socket
import unittest
from contextlib import contextmanager, nullcontext
from typing import Any, List, Optional, Union
import ray
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup, ReduceOp
from sglang.srt.distributed import init_distributed_environment
from sglang.srt.distributed.communication_op import ( # noqa
tensor_model_parallel_all_reduce,
)
from sglang.srt.distributed.device_communicators.custom_all_reduce import (
CustomAllreduce,
)
from sglang.srt.distributed.device_communicators.pymscclpp import PyMscclppCommunicator
from sglang.srt.distributed.device_communicators.pynccl import PyNcclCommunicator
from sglang.srt.distributed.parallel_state import (
get_tensor_model_parallel_group,
graph_capture,
initialize_model_parallel,
set_custom_all_reduce,
set_mscclpp_all_reduce,
)
from sglang.srt.distributed.utils import StatelessProcessGroup
from sglang.test.test_utils import CustomTestCase
def get_open_port() -> int:
# try ipv4
try:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(("", 0))
return s.getsockname()[1]
except OSError:
# try ipv6
with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s:
s.bind(("", 0))
return s.getsockname()[1]
def multi_process_parallel(
world_size: int,
master_addr: str,
cls: Any,
test_target: Any,
) -> None:
# Using ray helps debugging the error when it failed
# as compared to multiprocessing.
# NOTE: We need to set working_dir for distributed tests,
# otherwise we may get import errors on ray workers
ray.init(log_to_driver=True)
distributed_init_port = get_open_port()
refs = []
for rank in range(world_size):
refs.append(
test_target.remote(
cls, world_size, master_addr, rank, distributed_init_port
)
)
ray.get(refs)
ray.shutdown()
class TestMSCCLAllReduce(CustomTestCase):
@classmethod
def setUpClass(cls):
random.seed(42)
# message sizes in elements (roughly 1 KiB to 2 MiB depending on dtype)
cls.test_sizes = [512, 4096, 32768, 262144, 524288]
cls.world_sizes = [8]
TEST_TP16 = int(os.getenv("SGL_MSCCLPP_TEST_TP16", "0"))
if TEST_TP16:
cls.world_sizes = [16]
cls.test_loop = 10
def test_graph_allreduce(self):
TEST_MASTER_ADDR = os.getenv("SGL_MSCCLPP_TEST_MASTER_ADDR", "localhost")
for world_size in self.world_sizes:
if world_size not in [8, 16]:
continue
multi_process_parallel(
world_size, TEST_MASTER_ADDR, self, self.graph_allreduce
)
def test_eager_allreduce(self):
TEST_MASTER_ADDR = os.getenv("SGL_MSCCLPP_TEST_MASTER_ADDR", "localhost")
for world_size in self.world_sizes:
if world_size not in [8, 16]:
continue
multi_process_parallel(
world_size, TEST_MASTER_ADDR, self, self.eager_allreduce
)
@ray.remote(num_gpus=1, max_calls=1)
def graph_allreduce(self, world_size, master_addr, rank, distributed_init_port):
del os.environ["CUDA_VISIBLE_DEVICES"]
device = torch.device(f"cuda:{rank % torch.cuda.device_count()}")
torch.cuda.set_device(device)
distributed_init_method = f"tcp://{master_addr}:{distributed_init_port}"
set_mscclpp_all_reduce(True)
set_custom_all_reduce(False)
init_distributed_environment(
world_size=world_size,
rank=rank,
distributed_init_method=distributed_init_method,
local_rank=rank % torch.cuda.device_count(),
)
initialize_model_parallel(tensor_model_parallel_size=world_size)
group = get_tensor_model_parallel_group().device_group
# A small all_reduce for warmup.
# this is needed because device communicators might be created lazily
# (e.g. NCCL). This will ensure that the communicator is initialized
# before any communication happens, so that this group can be used for
# graph capture immediately.
data = torch.zeros(1)
data = data.to(device=device)
torch.distributed.all_reduce(data, group=group)
torch.cuda.synchronize()
del data
for sz in self.test_sizes:
for dtype in [torch.float32, torch.float16, torch.bfloat16]:
for _ in range(self.test_loop):
with graph_capture() as graph_capture_context:
# use integers so result matches NCCL exactly
inp1 = torch.randint(
1,
16,
(sz,),
dtype=dtype,
device=torch.cuda.current_device(),
)
inp2 = torch.randint(
1,
16,
(sz,),
dtype=dtype,
device=torch.cuda.current_device(),
)
torch.cuda.synchronize()
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(
graph, stream=graph_capture_context.stream
):
out1 = tensor_model_parallel_all_reduce(inp1)
# the input buffer is immediately modified to test
# synchronization
dist.all_reduce(inp1, group=group)
out2 = tensor_model_parallel_all_reduce(inp2)
dist.all_reduce(inp2, group=group)
graph.replay()
torch.testing.assert_close(out1, inp1)
torch.testing.assert_close(out2, inp2)
@ray.remote(num_gpus=1, max_calls=1)
def eager_allreduce(self, world_size, master_addr, rank, distributed_init_port):
del os.environ["CUDA_VISIBLE_DEVICES"]
device = torch.device(f"cuda:{rank % torch.cuda.device_count()}")
torch.cuda.set_device(device)
distributed_init_method = f"tcp://{master_addr}:{distributed_init_port}"
set_mscclpp_all_reduce(True)
set_custom_all_reduce(False)
init_distributed_environment(
world_size=world_size,
rank=rank,
distributed_init_method=distributed_init_method,
local_rank=rank,
)
initialize_model_parallel(tensor_model_parallel_size=world_size)
group = get_tensor_model_parallel_group().device_group
for sz in self.test_sizes:
for dtype in [torch.float32, torch.float16, torch.bfloat16]:
for _ in range(self.test_loop):
inp1 = torch.randint(
1, 16, (sz,), dtype=dtype, device=torch.cuda.current_device()
)
out1 = tensor_model_parallel_all_reduce(inp1)
dist.all_reduce(inp1, group=group)
torch.testing.assert_close(out1, inp1)
if __name__ == "__main__":
unittest.main()