Unverified commit 590f2da0 authored by Yuan Luo, committed by GitHub

[Feat] Support Torch Symm Mem AllReduce (#10571)


Co-authored-by: luoyuan.luo <luoyuan.luo@antgroup.com>
parent 148d8d48
"""For Now, SYMM_MEM is only supported on TP8 case
export WORLD_SIZE=1
export RANK=0
export MASTER_ADDR=127.0.0.1
export MASTER_PORT=12345
torchrun --nproc_per_node gpu \
--nnodes $WORLD_SIZE \
--node_rank $RANK \
--master_addr $MASTER_ADDR \
--master_port $MASTER_PORT ./benchmark/kernels/all_reduce/benchmark_symm_mem.py
"""
import os
from contextlib import nullcontext
from typing import List

import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup

from sglang.srt.distributed import init_distributed_environment
from sglang.srt.distributed.device_communicators.pynccl import PyNcclCommunicator
from sglang.srt.distributed.device_communicators.symm_mem import SymmMemCommunicator
from sglang.srt.distributed.parallel_state import (
    get_tensor_model_parallel_group,
    graph_capture,
    initialize_model_parallel,
    set_symm_mem_all_reduce,
)

# CI environment detection
IS_CI = (
    os.getenv("CI", "false").lower() == "true"
    or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
)
def torch_allreduce(torch_input: torch.Tensor, group: ProcessGroup) -> torch.Tensor:
    dist.all_reduce(torch_input, group=group)
    return torch_input


def symm_mem_allreduce(
    symm_mem_input: torch.Tensor, symm_mem_comm: SymmMemCommunicator
) -> torch.Tensor:
    return symm_mem_comm.all_reduce(symm_mem_input)


def pynccl_allreduce(
    pynccl_input: torch.Tensor, pynccl_comm: PyNcclCommunicator
) -> torch.Tensor:
    pynccl_comm.all_reduce(pynccl_input)
    return pynccl_input
def _bench_graph_time(func, inp_randn, warmup_loop=2, graph_loop=10, test_loop=10):
    graph_input = inp_randn.clone()
    with graph_capture() as graph_capture_context:
        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph, stream=graph_capture_context.stream):
            for _ in range(graph_loop):
                graph_out = func(graph_input)
    graph.replay()
    func_output = graph_out.clone()
    for _ in range(warmup_loop):
        graph.replay()
    torch.cuda.synchronize()

    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    latencies: List[float] = []
    for _ in range(test_loop):
        torch.cuda.synchronize()
        dist.barrier()
        start_event.record()
        graph.replay()
        end_event.record()
        end_event.synchronize()
        latencies.append(start_event.elapsed_time(end_event))

    # Each replay runs graph_loop captured ops; convert the average replay
    # latency (ms) into per-op microseconds.
    func_cost_us = sum(latencies) / len(latencies) / graph_loop * 1000
    graph.reset()
    return func_output, func_cost_us
def _bench_eager_time(func, inp_randn, warmup_loop=2, test_loop=10):
    eager_input = inp_randn.clone()
    eager_output = func(eager_input)
    func_output = eager_output.clone()
    for _ in range(warmup_loop):
        func(eager_input)
    torch.cuda.synchronize()

    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start_event.record()
    for _ in range(test_loop):
        func(eager_input)
    end_event.record()
    torch.cuda.synchronize()
    func_cost_us = start_event.elapsed_time(end_event) / test_loop * 1000
    return func_output, func_cost_us
def get_torch_prof_ctx(do_prof: bool):
    ctx = (
        torch.profiler.profile(
            activities=[
                torch.profiler.ProfilerActivity.CPU,
                torch.profiler.ProfilerActivity.CUDA,
            ],
            record_shapes=True,
            with_stack=True,
        )
        if do_prof
        else nullcontext()
    )
    return ctx


def human_readable_size(size, decimal_places=1):
    for unit in ["B", "KiB", "MiB", "GiB", "TiB", "PiB"]:
        if size < 1024.0 or unit == "PiB":
            break
        size /= 1024.0
    return f"{size:.{decimal_places}f} {unit}"
try:
    from tabulate import tabulate
except ImportError:
    print("tabulate not installed, falling back to manual markdown formatting")
    tabulate = None


def print_markdown_table(data):
    if tabulate is not None:
        print(tabulate(data, headers="keys", tablefmt="github"))
        return

    headers = data[0].keys()
    header_row = "| " + " | ".join(headers) + " |"
    separator = "| " + " | ".join(["---"] * len(headers)) + " |"
    rows = []
    for item in data:
        row = "| " + " | ".join(str(item[key]) for key in headers) + " |"
        rows.append(row)
    markdown_table = "\n".join([header_row, separator] + rows)
    print(markdown_table)
if __name__ == "__main__":
import logging
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(levelname)s - %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
force=True,
)
if not dist.is_initialized():
dist.init_process_group(backend="nccl")
world, world_size = dist.group.WORLD, dist.get_world_size()
rank = dist.get_rank()
torch.cuda.set_device(rank % 8)
device = torch.cuda.current_device()
set_symm_mem_all_reduce(True)
init_distributed_environment(
world_size=world_size,
rank=rank,
local_rank=rank % 8,
)
initialize_model_parallel(tensor_model_parallel_size=world_size)
group = get_tensor_model_parallel_group().device_group
cpu_group = get_tensor_model_parallel_group().cpu_group
pynccl_comm = get_tensor_model_parallel_group().pynccl_comm
symm_mem_comm = get_tensor_model_parallel_group().symm_mem_comm
dist.barrier()
profile = False
dtype = torch.bfloat16
ctx = get_torch_prof_ctx(profile)
result = []
with ctx:
if IS_CI:
i_range = range(10, 11)
else:
i_range = range(10, 20)
for i in i_range:
sz = 2**i
if sz * dtype.itemsize > 2**24:
break
inp_randn = torch.randint(1, 16, (sz,), dtype=dtype, device=device)
memory = torch.empty_like(inp_randn)
memory_out = torch.empty_like(memory)
torch_eager_output, torch_eager_time = _bench_eager_time(
lambda inp: torch_allreduce(inp, group), inp_randn
)
symm_mem_eager_output, symm_mem_eager_time = _bench_eager_time(
lambda inp: symm_mem_allreduce(inp, symm_mem_comm), inp_randn
)
symm_mem_graph_output, symm_mem_graph_time = _bench_graph_time(
lambda inp: symm_mem_allreduce(inp, symm_mem_comm), inp_randn
)
# since pynccl is inplace op, this return result is not correct if graph loop > 1
_, pynccl_graph_time = _bench_graph_time(
lambda inp: pynccl_allreduce(inp, pynccl_comm), inp_randn
)
torch.testing.assert_close(torch_eager_output, symm_mem_graph_output)
torch.testing.assert_close(torch_eager_output, symm_mem_eager_output)
result.append(
{
"msg_size": human_readable_size(inp_randn.nbytes),
"torch eager time": torch_eager_time,
"symm mem eager time": symm_mem_eager_time,
"symm mem graph time": symm_mem_graph_time,
"pynccl graph time": pynccl_graph_time,
}
)
if rank == 0:
print(f"sz={sz}, dtype={dtype}: correctness check PASS!")
if rank == 0:
print_markdown_table(result)
if profile:
prof_dir = f"prof/symm_mem"
os.makedirs(prof_dir, exist_ok=True)
ctx.export_chrome_trace(f"{prof_dir}/trace_rank{dist.get_rank()}.json.gz")
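The graph-path timing above folds graph_loop all-reduce calls into a single CUDA graph and then divides the averaged replay latency by graph_loop. A brief restatement of that arithmetic, with illustrative numbers that are assumptions rather than measured results:

# Illustrative arithmetic only; the latencies below are made up, not benchmark output.
graph_loop = 10                      # all-reduce calls captured inside one CUDA graph
latencies_ms = [0.52, 0.48, 0.50]    # hypothetical per-replay CUDA-event timings (ms)
avg_replay_ms = sum(latencies_ms) / len(latencies_ms)   # 0.50 ms per graph replay
func_cost_us = avg_replay_ms / graph_loop * 1000        # 50.0 us per all-reduce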
@@ -28,6 +28,8 @@ def launch_server(args):
        cmd += "--disable-custom-all-reduce"
    if args.enable_mscclpp:
        cmd += "--enable-mscclpp"
    if args.enable_torch_symm_mem:
        cmd += "--enable-torch-symm-mem"
    print(cmd)
    os.system(cmd)
@@ -70,6 +72,11 @@ if __name__ == "__main__":
        action="store_true",
        help="Enable using mscclpp for small messages for all-reduce kernel and fall back to NCCL.",
    )
    parser.add_argument(
        "--enable-torch-symm-mem",
        action="store_true",
        help="Enable using torch symm mem for all-reduce kernel and fall back to NCCL.",
    )
    args = parser.parse_args()
    launch_server(args)
MiB = 1024 * 1024

# Maximum message size (bytes) eligible for symm-mem all-reduce,
# keyed by compute capability major and then by world size.
SYMM_MEM_ALL_REDUCE_MAX_SIZES = {
    9: {
        2: 64 * MiB,  # 64 MiB
        4: 32 * MiB,  # 32 MiB
        6: 64 * MiB,  # 64 MiB
        8: 64 * MiB,  # 64 MiB
    },
    10: {
        2: 64 * MiB,  # 64 MiB
        4: 32 * MiB,  # 32 MiB
        6: 128 * MiB,  # 128 MiB
        8: 128 * MiB,  # 128 MiB
    },
}
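A small sketch of how this table is read, assuming compute capability major 9 and a tensor-parallel world size of 8; the lookup mirrors the one performed by SymmMemCommunicator below, and the example tensor is hypothetical:

# Illustrative lookup; the cc_major/world_size values and the tensor are example assumptions.
import torch

from sglang.srt.distributed.device_communicators.all_reduce_utils import (
    SYMM_MEM_ALL_REDUCE_MAX_SIZES,
)

cc_major, world_size = 9, 8
max_size = SYMM_MEM_ALL_REDUCE_MAX_SIZES[cc_major][world_size]   # 64 MiB cap

x = torch.empty(1 << 20, dtype=torch.bfloat16)        # 2 MiB payload
eligible = x.numel() * x.element_size() < max_size    # True: stays under the cap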
# Adapted from https://github.com/vllm-project/vllm/blob/bf214ca22625e311a2c4c0dfbf7af19128f4919c/vllm/distributed/device_communicators/symm_mem.py
import logging
from typing import Optional, Union

import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup

from sglang.srt.distributed.device_communicators.all_reduce_utils import (
    SYMM_MEM_ALL_REDUCE_MAX_SIZES,
)
from sglang.srt.utils import get_device_capability, is_cuda, is_hip

try:
    import torch.distributed._symmetric_memory as torch_symm_mem

    symm_mem_available = True
except ImportError:
    symm_mem_available = False

logger = logging.getLogger(__name__)

_is_cuda = is_cuda()
_is_hip = is_hip()

symm_mem_is_available = False
if _is_hip:
    symm_mem_is_available = False
if _is_cuda:
    symm_mem_is_available = True
class SymmMemCommunicator:
    """
    Thin wrapper around symmetric-memory collectives.

    This communicator:
    - Validates device capability and world size.
    - Allocates a shared symmetric buffer.
    - Chooses between 'multimem' and 'two-shot' all-reduce kernels.
    - Exposes a fast-path all_reduce() compatible with bfloat16 inputs.

    If any prerequisite is not met, the instance remains disabled and will
    decline to perform symmetric-memory all-reduce.
    """

    # Mapping: compute capability major -> supported world sizes for multimem.
    # If the current (cc_major, world_size) pair is not listed, we fall back
    # to the two-shot path.
    _WORLD_SIZES_MULTIMEM = {
        9: [4, 6, 8],
        10: [6, 8],
    }
    def __init__(self, group: ProcessGroup, device: Union[int, str, torch.device]):
        """
        Args:
            group: Torch process group used for rendezvous and naming.
            device: Target CUDA device (index, 'cuda:X', or torch.device).
        """
        self.disabled = True

        if not symm_mem_available:
            return

        if isinstance(device, int):
            device = torch.device(f"cuda:{device}")
        elif isinstance(device, str):
            device = torch.device(device)
        torch.cuda.set_device(device)

        self.dtype = torch.bfloat16
        self.device = device
        self.group = group
        self.world_size = dist.get_world_size(self.group)
        self.device_capability = torch.cuda.get_device_capability(device)[0]

        if self.device_capability < 9:
            logger.warning(
                "SymmMemCommunicator: Device capability %s not supported, "
                "communicator is not available.",
                self.device_capability,
            )
            return
        if self.world_size not in SYMM_MEM_ALL_REDUCE_MAX_SIZES[self.device_capability]:
            logger.warning(
                "SymmMemCommunicator: World size %d not supported, "
                "communicator is not available.",
                self.world_size,
            )
            return

        self.max_size = SYMM_MEM_ALL_REDUCE_MAX_SIZES[self.device_capability][
            self.world_size
        ]
        self.buffer = torch_symm_mem.empty(
            self.max_size // self.dtype.itemsize,
            device=self.device,
            dtype=self.dtype,
        )
        handle = torch_symm_mem.rendezvous(self.buffer, self.group.group_name)
        if handle.multicast_ptr == 0:
            logger.warning(
                "SymmMemCommunicator: symmetric memory "
                "multicast operations are not supported."
            )
            self.buffer = None
            self.disabled = True
            return
        self.disabled = False
    def should_symm_mem_allreduce(self, inp: torch.Tensor):
        """
        Fast-path eligibility check for a given tensor.

        Conditions:
        - Communicator must be enabled.
        - dtype must be bfloat16 (matches kernel + buffer dtype).
        - Total byte size must be 4-byte aligned (hardware requirement).
        - Payload must be smaller than the symmetric-memory max size.

        Returns:
            True if the symmetric-memory path can handle this tensor.
        """
        if self.disabled:
            return False
        if inp.dtype != self.dtype:
            return False
        inp_size = inp.numel() * inp.element_size()
        # Enforce 4-byte alignment.
        if inp_size % 4 != 0:
            return False
        return inp_size < self.max_size
    def all_reduce(
        self, inp: torch.Tensor, *, out: Optional[torch.Tensor] = None
    ) -> Optional[torch.Tensor]:
        """
        Perform a sum all-reduce via symmetric memory.

        Args:
            inp: Input tensor on the target CUDA device (bfloat16).
            out: Optional output tensor; if omitted, a new tensor is allocated.

        Returns:
            The reduced tensor (same shape as inp), or None if disabled.

        Implementation details:
        - Stages 'inp' into the symmetric buffer.
        - Selects the 'multimem' or 'two_shot' kernel based on topology.
        - Writes the result into 'out' and returns it.
        """
        if out is None:
            out = torch.empty_like(inp)
        self.buffer[: inp.numel()].copy_(inp.view(-1))
        if self.world_size in self._WORLD_SIZES_MULTIMEM[self.device_capability]:
            torch.ops.symm_mem.multimem_all_reduce_(
                self.buffer[: inp.numel()], "sum", self.group.group_name
            )
        else:
            torch.ops.symm_mem.two_shot_all_reduce_(
                self.buffer[: inp.numel()], "sum", self.group.group_name
            )
        out.copy_(self.buffer[: inp.numel()].view(out.shape))
        return out
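A hedged usage sketch of the communicator above. It assumes torch.distributed has already been initialized with an NCCL backend and that rendezvous over the default WORLD group is acceptable; the real wiring in the GroupCoordinator diff below passes the coordinator's own group and device instead.

# Minimal sketch; assumes dist.init_process_group("nccl") has already run on every rank.
import torch
import torch.distributed as dist

from sglang.srt.distributed.device_communicators.symm_mem import SymmMemCommunicator

device = torch.device(f"cuda:{dist.get_rank() % 8}")
comm = SymmMemCommunicator(group=dist.group.WORLD, device=device)

x = torch.randn(4096, dtype=torch.bfloat16, device=device)
if comm.should_symm_mem_allreduce(x):
    y = comm.all_reduce(x)   # staged through the symmetric buffer, summed across ranks
else:
    dist.all_reduce(x)       # fall back to the regular NCCL path
    y = x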
@@ -215,12 +215,14 @@ class GroupCoordinator:
    use_pynccl: bool  # a hint of whether to use PyNccl
    use_pymscclpp: bool  # a hint of whether to use PyMsccl
    use_custom_allreduce: bool  # a hint of whether to use CustomAllreduce
    use_torch_symm_mem: bool  # a hint of whether to use SymmMemAllReduce
    use_message_queue_broadcaster: (
        bool  # a hint of whether to use message queue broadcaster
    )

    # communicators are only created for world size > 1
    pynccl_comm: Optional[Any]  # PyNccl communicator
    ca_comm: Optional[Any]  # Custom allreduce communicator
    symm_mem_comm: Optional[Any]  # Symm mem communicator
    mq_broadcaster: Optional[Any]  # shared memory broadcaster

    def __init__(
@@ -231,6 +233,7 @@ class GroupCoordinator:
        use_pynccl: bool,
        use_pymscclpp: bool,
        use_custom_allreduce: bool,
        use_torch_symm_mem: bool,
        use_hpu_communicator: bool,
        use_xpu_communicator: bool,
        use_npu_communicator: bool,
@@ -279,6 +282,7 @@ class GroupCoordinator:
        self.use_pynccl = use_pynccl
        self.use_pymscclpp = use_pymscclpp
        self.use_custom_allreduce = use_custom_allreduce
        self.use_torch_symm_mem = use_torch_symm_mem
        self.use_hpu_communicator = use_hpu_communicator
        self.use_xpu_communicator = use_xpu_communicator
        self.use_npu_communicator = use_npu_communicator
@@ -294,6 +298,9 @@ class GroupCoordinator:
        from sglang.srt.distributed.device_communicators.pynccl import (
            PyNcclCommunicator,
        )
        from sglang.srt.distributed.device_communicators.symm_mem import (
            SymmMemCommunicator,
        )

        if is_hip():
            from sglang.srt.distributed.device_communicators.quick_all_reduce import (
@@ -342,6 +349,13 @@ class GroupCoordinator:
            except Exception as e:
                logger.warning(f"Failed to initialize QuickAllReduce: {e}")

        self.symm_mem_comm: Optional[SymmMemCommunicator] = None
        if self.use_torch_symm_mem and self.world_size > 1:
            self.symm_mem_comm = SymmMemCommunicator(
                group=self.cpu_group,
                device=self.device,
            )

        # Create communicator for other hardware backends
        from sglang.srt.distributed.device_communicators.hpu_communicator import (
            HpuCommunicator,
@@ -446,6 +460,7 @@ class GroupCoordinator:
        # custom allreduce  | enabled  | enabled  |
        # PyNccl            | disabled | enabled  |
        # PyMscclpp         | disabled | enabled  |
        # TorchSymmMem      | disabled | enabled  |
        # torch.distributed | enabled  | disabled |
        #
        # Note: When custom quick allreduce is enabled, a runtime check
@@ -547,7 +562,12 @@ class GroupCoordinator:
            and self.pymscclpp_comm.should_mscclpp_allreduce(input_)
        ):
            outplace_all_reduce_method = "pymscclpp"
        elif (
            self.symm_mem_comm is not None
            and not self.symm_mem_comm.disabled
            and self.symm_mem_comm.should_symm_mem_allreduce(input_)
        ):
            outplace_all_reduce_method = "symm_mem"
        if outplace_all_reduce_method is not None:
            return torch.ops.sglang.outplace_all_reduce(
                input_,
@@ -564,6 +584,7 @@ class GroupCoordinator:
        ca_comm = self.ca_comm
        qr_comm = self.qr_comm
        pymscclpp_comm = self.pymscclpp_comm
        symm_mem_comm = self.symm_mem_comm
        assert any([qr_comm, ca_comm, pymscclpp_comm])
        if outplace_all_reduce_method == "ca":
            assert not ca_comm.disabled
@@ -571,6 +592,9 @@ class GroupCoordinator:
        elif outplace_all_reduce_method == "qr":
            assert not qr_comm.disabled
            out = qr_comm.quick_all_reduce(input_)
        elif outplace_all_reduce_method == "symm_mem":
            assert not symm_mem_comm.disabled
            out = symm_mem_comm.all_reduce(input_)
        else:
            assert not pymscclpp_comm.disabled
            out = pymscclpp_comm.all_reduce(input_)
@@ -1219,6 +1243,7 @@ def init_world_group(
        use_pynccl=False,
        use_pymscclpp=False,
        use_custom_allreduce=False,
        use_torch_symm_mem=False,
        use_hpu_communicator=False,
        use_xpu_communicator=False,
        use_npu_communicator=False,
@@ -1234,11 +1259,14 @@ def init_model_parallel_group(
    use_message_queue_broadcaster: bool = False,
    group_name: Optional[str] = None,
    use_mscclpp_allreduce: Optional[bool] = None,
    use_symm_mem_allreduce: Optional[bool] = None,
) -> GroupCoordinator:
    if use_custom_allreduce is None:
        use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE
    if use_mscclpp_allreduce is None:
        use_mscclpp_allreduce = _ENABLE_MSCCLPP_ALL_REDUCE
    if use_symm_mem_allreduce is None:
        use_symm_mem_allreduce = _ENABLE_SYMM_MEM_ALL_REDUCE
    return GroupCoordinator(
        group_ranks=group_ranks,
        local_rank=local_rank,
@@ -1246,6 +1274,7 @@ def init_model_parallel_group(
        use_pynccl=not _is_npu,
        use_pymscclpp=use_mscclpp_allreduce,
        use_custom_allreduce=use_custom_allreduce,
        use_torch_symm_mem=use_symm_mem_allreduce,
        use_hpu_communicator=True,
        use_xpu_communicator=True,
        use_npu_communicator=True,
@@ -1331,6 +1360,7 @@ logger = logging.getLogger(__name__)
_ENABLE_CUSTOM_ALL_REDUCE = True
_ENABLE_MSCCLPP_ALL_REDUCE = False
_ENABLE_SYMM_MEM_ALL_REDUCE = False


def set_custom_all_reduce(enable: bool):
@@ -1343,6 +1373,11 @@ def set_mscclpp_all_reduce(enable: bool):
    _ENABLE_MSCCLPP_ALL_REDUCE = enable


def set_symm_mem_all_reduce(enable: bool):
    global _ENABLE_SYMM_MEM_ALL_REDUCE
    _ENABLE_SYMM_MEM_ALL_REDUCE = enable


def init_distributed_environment(
    world_size: int = -1,
    rank: int = -1,
......
@@ -263,6 +263,7 @@ def initialize_dp_attention(
        use_pynccl=SYNC_TOKEN_IDS_ACROSS_TP,
        use_pymscclpp=False,
        use_custom_allreduce=False,
        use_torch_symm_mem=False,
        use_hpu_communicator=False,
        use_xpu_communicator=False,
        use_npu_communicator=False,
......
@@ -42,6 +42,7 @@ from sglang.srt.distributed import (
    initialize_model_parallel,
    set_custom_all_reduce,
    set_mscclpp_all_reduce,
    set_symm_mem_all_reduce,
)
from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state
from sglang.srt.eplb.eplb_manager import EPLBManager
@@ -646,6 +647,7 @@ class ModelRunner:
        dist_init_method = f"tcp://127.0.0.1:{self.dist_port}"
        set_custom_all_reduce(not self.server_args.disable_custom_all_reduce)
        set_mscclpp_all_reduce(self.server_args.enable_mscclpp)
        set_symm_mem_all_reduce(self.server_args.enable_torch_symm_mem)
        if not self.is_draft_worker:
            if self.device == "cpu":
......
@@ -382,6 +382,7 @@ class ServerArgs:
    disable_outlines_disk_cache: bool = False
    disable_custom_all_reduce: bool = False
    enable_mscclpp: bool = False
    enable_torch_symm_mem: bool = False
    disable_overlap_schedule: bool = False
    enable_mixed_chunk: bool = False
    enable_dp_attention: bool = False
@@ -2443,6 +2444,11 @@ class ServerArgs:
            action="store_true",
            help="Enable using mscclpp for small messages for all-reduce kernel and fall back to NCCL.",
        )
        parser.add_argument(
            "--enable-torch-symm-mem",
            action="store_true",
            help="Enable using torch symm mem for the all-reduce kernel and fall back to NCCL. Only supported on CUDA devices with SM90 and above. SM90 supports world sizes 4, 6, 8; SM100 supports world sizes 6, 8.",
        )
        parser.add_argument(
            "--disable-overlap-schedule",
            action="store_true",
......
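End to end, the new server flag reaches the communicator through the calls shown in the model_runner and parallel_state diffs above. A condensed sketch of that control flow; the server_args stand-in object and its value are assumptions for illustration:

# Condensed control flow (names mirror the diffs above; _Args is a hypothetical stand-in
# for ServerArgs used only in this sketch).
from sglang.srt.distributed import set_symm_mem_all_reduce


class _Args:
    enable_torch_symm_mem = True


server_args = _Args()
set_symm_mem_all_reduce(server_args.enable_torch_symm_mem)
# init_model_parallel_group(...) then reads _ENABLE_SYMM_MEM_ALL_REDUCE and passes
# use_torch_symm_mem into GroupCoordinator, which constructs a SymmMemCommunicator
# when the world size is greater than 1.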