Commit 0c5532b0 authored by maxiao1's avatar maxiao1
Browse files

enable custom_allreduce

parent 785e5e90
...@@ -19,14 +19,15 @@ logger = logging.getLogger(__name__) ...@@ -19,14 +19,15 @@ logger = logging.getLogger(__name__)
use_vllm_custom_allreduce = get_bool_env_var( use_vllm_custom_allreduce = get_bool_env_var(
"USE_VLLM_CUSTOM_ALLREDUCE", default="false" "USE_VLLM_CUSTOM_ALLREDUCE", default="false"
) )
use_dcu_custom_allreduce= get_bool_env_var(
"USE_DCU_CUSTOM_ALLREDUCE", default="false"
)
if not is_hpu(): if not is_hpu():
# ROCm does not use vllm custom allreduce # ROCm does not use vllm custom allreduce
# if use_vllm_custom_allreduce and not is_hip(): if use_vllm_custom_allreduce and not is_hip():
if use_vllm_custom_allreduce:
try: try:
import vllm._C # noqa: F401 import vllm._C # noqa: F401
print("[DEBUG] ✅ Using vLLM custom allreduce (vllm._C successfully imported)")
except ImportError as e: except ImportError as e:
logger.warning("Failed to import from vllm._C with %r", e) logger.warning("Failed to import from vllm._C with %r", e)
else: else:
...@@ -35,12 +36,15 @@ if not is_hpu(): ...@@ -35,12 +36,15 @@ if not is_hpu():
except ImportError as e: except ImportError as e:
logger.warning("Failed to import from custom_ar with %r", e) logger.warning("Failed to import from custom_ar with %r", e)
if use_dcu_custom_allreduce:
try:
import vllm._C
except ImportError as e:
logger.warning("Failed to import from vllm._C with %r", e)
# if not is_hip() and not is_npu(): if not is_hip() and not is_npu():
if not is_npu():
if use_vllm_custom_allreduce: if use_vllm_custom_allreduce:
custom_op = torch.ops._C_custom_ar custom_op = torch.ops._C_custom_ar
print("[DEBUG] ✅ custom_op = torch.ops._C_custom_ar (vLLM path active)")
else: else:
custom_op = sgl_kernel.allreduce custom_op = sgl_kernel.allreduce
...@@ -79,8 +83,79 @@ if not is_npu(): ...@@ -79,8 +83,79 @@ if not is_npu():
) -> None: ) -> None:
custom_op.register_graph_buffers(fa, handles, offsets) custom_op.register_graph_buffers(fa, handles, offsets)
elif is_hip and use_dcu_custom_allreduce:
# custom ar
def init_custom_ar(ipc_tensors: list[torch.Tensor], rank_data: torch.Tensor,
rank: int, fully_connected: bool) -> int:
return torch.ops._C_custom_ar.init_custom_ar(ipc_tensors, rank_data, rank,
fully_connected)
def all_reduce(fa: int, inp: torch.Tensor, out: torch.Tensor, reg_buffer: int,
reg_buffer_sz_bytes: int) -> None:
torch.ops._C_custom_ar.all_reduce(fa, inp, out, reg_buffer,
reg_buffer_sz_bytes)
def dispose(fa: int) -> None:
torch.ops._C_custom_ar.dispose(fa)
def meta_size() -> int:
return torch.ops._C_custom_ar.meta_size()
def register_buffer(fa: int, ipc_tensors: list[int]) -> None:
return torch.ops._C_custom_ar.register_buffer(fa, ipc_tensors)
def get_graph_buffer_ipc_meta(fa: int) -> tuple[list[int], list[int]]:
return torch.ops._C_custom_ar.get_graph_buffer_ipc_meta(fa)
def register_graph_buffers(fa: int, handles: list[list[int]],
offsets: list[list[int]]) -> None:
torch.ops._C_custom_ar.register_graph_buffers(fa, handles, offsets)
def allocate_shared_buffer_and_handle(size: int) -> tuple[int, torch.Tensor]:
return torch.ops._C_custom_ar.allocate_shared_buffer_and_handle(size)
def open_mem_handle(mem_handle: torch.Tensor):
return torch.ops._C_custom_ar.open_mem_handle(mem_handle)
def free_shared_buffer(ptr: int) -> None:
torch.ops._C_custom_ar.free_shared_buffer(ptr)
def read_cache(
keys: torch.Tensor,
values: torch.Tensor,
key_caches: list[torch.Tensor],
value_caches: list[torch.Tensor],
slot_mapping: torch.Tensor,
kv_cache_dtype: str
) -> None:
torch.ops._C_cache_ops.read_cache(keys, values, key_caches,
value_caches, slot_mapping,
kv_cache_dtype)
def write_cache_multi_layers(
keys: torch.Tensor,
values: torch.Tensor,
key_caches: list[torch.Tensor],
value_caches: list[torch.Tensor],
slot_mapping: torch.Tensor,
kv_cache_dtype: str
) -> None:
torch.ops._C_cache_ops.write_cache_multi_layers(keys, values, key_caches,
value_caches, slot_mapping,
kv_cache_dtype)
else: else:
# ROCM custom allreduce # sgl_kernel ROCM custom allreduce
def init_custom_ar( def init_custom_ar(
meta: torch.Tensor, meta: torch.Tensor,
......
...@@ -27,10 +27,11 @@ _is_hip = is_hip() ...@@ -27,10 +27,11 @@ _is_hip = is_hip()
try: try:
# if ops.use_vllm_custom_allreduce and not _is_hip: if ops.use_vllm_custom_allreduce and not _is_hip:
if ops.use_vllm_custom_allreduce:
# Use vLLM custom allreduce # Use vLLM custom allreduce
ops.meta_size() ops.meta_size()
elif ops.use_dcu_custom_allreduce:
ops.meta_size()
else: else:
# Use custom allreduce from sgl kernel (ROCM and TRT-LLM) # Use custom allreduce from sgl kernel (ROCM and TRT-LLM)
import sgl_kernel # noqa: F401 import sgl_kernel # noqa: F401
...@@ -420,3 +421,274 @@ class CustomAllreduce: ...@@ -420,3 +421,274 @@ class CustomAllreduce:
def __del__(self): def __del__(self):
self.close() self.close()
class DCUCustomAllreduce:
_SUPPORTED_WORLD_SIZES = [2, 4, 6, 8, 16]
# max_size: max supported allreduce size
def __init__(self,
group: ProcessGroup,
device: Union[int, str, torch.device],
max_size=8192 * 512) -> None:
"""
Args:
group: the process group to work on. If None, it will use the
default process group.
device: the device to bind the CustomAllreduce to. If None,
it will be bind to f"cuda:{local_rank}".
It is the caller's responsibility to make sure each communicator
is bind to a unique device, and all communicators in this group
are in the same node.
"""
self._IS_CAPTURING = False
self.disabled = True
if not custom_ar:
# disable because of missing custom allreduce library
# e.g. in a non-GPU environment
logger.info("Custom allreduce is disabled because "
"of missing custom allreduce library")
return
self.group = group
assert dist.get_backend(group) != dist.Backend.NCCL, (
"CustomAllreduce should be attached to a non-NCCL group.")
if not all(in_the_same_node_as(group, source_rank=0)):
# No need to initialize custom allreduce for multi-node case.
logger.warning(
"Custom allreduce is disabled because this process group"
" spans across nodes.")
return
rank = dist.get_rank(group=self.group)
self.rank = rank
world_size = dist.get_world_size(group=self.group)
# if world_size > envs.VLLM_CUSTOM_ALLREDUCE_SUPPORTED_WORLDSIZE_MAX:
if world_size > 16:
return
if world_size == 1:
# No need to initialize custom allreduce for single GPU case.
return
if world_size not in CustomAllreduce._SUPPORTED_WORLD_SIZES:
logger.warning(
"Custom allreduce is disabled due to an unsupported world"
" size: %d. Supported world sizes: %s. To silence this "
"warning, specify disable_custom_all_reduce=True explicitly.",
world_size, str(CustomAllreduce._SUPPORTED_WORLD_SIZES))
return
if isinstance(device, int):
device = torch.device(f"cuda:{device}")
elif isinstance(device, str):
device = torch.device(device)
# now `device` is a `torch.device` object
assert isinstance(device, torch.device)
self.device = device
cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
if cuda_visible_devices:
device_ids = list(map(int, cuda_visible_devices.split(",")))
else:
device_ids = list(range(torch.cuda.device_count()))
physical_device_id = device_ids[device.index]
tensor = torch.tensor([physical_device_id],
dtype=torch.int,
device="cpu")
gather_list = [
torch.tensor([0], dtype=torch.int, device="cpu")
for _ in range(world_size)
]
dist.all_gather(gather_list, tensor, group=self.group)
physical_device_ids = [t.item() for t in gather_list]
# test nvlink first, this will filter out most of the cases
# where custom allreduce is not supported
# this checks hardware and driver support for NVLink
# assert current_platform.is_cuda_alike()
# fully_connected = current_platform.is_fully_connected(
# physical_device_ids)
if _is_cuda or _is_hip:
fully_connected = is_full_nvlink(physical_device_ids, world_size)
# if world_size > 2 and not fully_connected:
if not fully_connected:
max_size = 32 * 8192 * 2
# if not envs.VLLM_PCIE_USE_CUSTOM_ALLREDUCE:
# logger.warning(
# "Custom allreduce is disabled because it's not supported on"
# " more than two PCIe-only GPUs. To silence this warning, "
# "specify disable_custom_all_reduce=True explicitly.")
# return
logger.warning(
"We are using PCIe's custom allreduce."
"If the performance is poor, we can add "
"--disable-custom-all-reduce in the instruction.")
# test P2P capability, this checks software/cudaruntime support
# this is expensive to compute at the first time
# then we cache the result
# On AMD GPU, p2p is always enabled between XGMI connected GPUs
if not _is_hip and not _can_p2p(rank, world_size):
logger.warning(
"Custom allreduce is disabled because your platform lacks "
"GPU P2P capability or P2P test failed. To silence this "
"warning, specify disable_custom_all_reduce=True explicitly.")
return
self.disabled = False
# Buffers memory are owned by this Python class and passed to C++.
# Meta data composes of two parts: meta data for synchronization and a
# temporary buffer for storing intermediate allreduce results.
self.meta_ptrs = self.create_shared_buffer(ops.meta_size() + max_size,
group=group,
uncached=True)
# This is a pre-registered IPC buffer. In eager mode, input tensors
# are first copied into this buffer before allreduce is performed
self.buffer_ptrs = self.create_shared_buffer(max_size, group=group)
# This is a buffer for storing the tuples of pointers pointing to
# IPC buffers from all ranks. Each registered tuple has size of
# 8*world_size bytes where world_size is at most 8. Allocating 8MB
# is enough for 131072 such tuples. The largest model I've seen only
# needs less than 10000 of registered tuples.
self.rank_data = torch.empty(8 * 1024 * 1024,
dtype=torch.uint8,
device=self.device)
self.max_size = max_size
self.rank = rank
self.world_size = world_size
self.fully_connected = fully_connected
self._ptr = ops.init_custom_ar(self.meta_ptrs, self.rank_data, rank,
self.fully_connected)
ops.register_buffer(self._ptr, self.buffer_ptrs)
@contextmanager
def capture(self):
"""
The main responsibility of this context manager is the
`register_graph_buffers` call at the end of the context.
It records all the buffer addresses used in the CUDA graph.
"""
try:
self._IS_CAPTURING = True
yield
finally:
self._IS_CAPTURING = False
if not self.disabled:
self.register_graph_buffers()
def register_graph_buffers(self):
handle, offset = ops.get_graph_buffer_ipc_meta(self._ptr)
logger.info("Registering %d cuda graph addresses", len(offset))
# We cannot directly use `dist.all_gather_object` here
# because it is incompatible with `gloo` backend under inference mode.
# see https://github.com/pytorch/pytorch/issues/126032 for details.
all_data = [[None, None]
for _ in range(dist.get_world_size(group=self.group))]
all_data[self.rank] = [handle, offset]
ranks = sorted(dist.get_process_group_ranks(group=self.group))
for i, rank in enumerate(ranks):
dist.broadcast_object_list(all_data[i],
src=rank,
group=self.group,
device="cpu")
# Unpack list of tuples to tuple of lists.
handles = [d[0] for d in all_data] # type: ignore
offsets = [d[1] for d in all_data] # type: ignore
ops.register_graph_buffers(self._ptr, handles, offsets)
def should_custom_ar(self, inp: torch.Tensor):
if self.disabled:
return False
inp_size = inp.numel() * inp.element_size()
# custom allreduce requires input byte size to be multiples of 16
if inp_size % 16 != 0:
return False
if not is_weak_contiguous(inp):
return False
# for 4 or more non NVLink-capable GPUs, custom allreduce provides
# little performance improvement over NCCL.
return inp_size <= self.max_size
def all_reduce(self,
inp: torch.Tensor,
*,
out: torch.Tensor = None,
registered: bool = False):
"""Performs an out-of-place all reduce.
If registered is True, this assumes inp's pointer is already
IPC-registered. Otherwise, inp is first copied into a pre-registered
buffer.
"""
if out is None:
out = torch.empty_like(inp)
if registered:
ops.all_reduce(self._ptr, inp, out, 0, 0)
else:
ops.all_reduce(self._ptr, inp, out, self.buffer_ptrs[self.rank],
self.max_size)
return out
def custom_all_reduce(self, input: torch.Tensor) -> Optional[torch.Tensor]:
"""The main allreduce API that provides support for cuda graph."""
# When custom allreduce is disabled, this will be None.
if self.disabled or not self.should_custom_ar(input):
return None
if self._IS_CAPTURING:
if torch.cuda.is_current_stream_capturing():
return self.all_reduce(input, registered=False)
else:
# If warm up, mimic the allocation pattern since custom
# allreduce is out-of-place.
return torch.empty_like(input)
else:
# Note: outside of cuda graph context, custom allreduce incurs a
# cost of cudaMemcpy, which should be small (<=1% of overall
# latency) compared to the performance gain of using custom kernels
return self.all_reduce(input, registered=False)
def close(self):
if not self.disabled and self._ptr:
if ops is not None:
ops.dispose(self._ptr)
self._ptr = 0
self.free_shared_buffer(self.meta_ptrs, rank=self.rank)
self.free_shared_buffer(self.buffer_ptrs, rank=self.rank)
def __del__(self):
self.close()
@staticmethod
def create_shared_buffer(size_in_bytes: int,
group: Optional[ProcessGroup] = None,
uncached: Optional[bool] = False) -> list[int]:
pointer, handle = ops.allocate_shared_buffer_and_handle(size_in_bytes)
world_size = dist.get_world_size(group=group)
rank = dist.get_rank(group=group)
handles = [None] * world_size
dist.all_gather_object(handles, handle, group=group)
pointers: list[int] = []
for i, h in enumerate(handles):
if i == rank:
pointers.append(pointer) # type: ignore
else:
pointers.append(ops.open_mem_handle(h))
return pointers
@staticmethod
def free_shared_buffer(pointers: list[int],
group: Optional[ProcessGroup] = None,
rank: Optional[int] = 0) -> None:
if rank is None:
rank = dist.get_rank(group=group)
if ops is not None:
ops.free_shared_buffer(pointers[rank])
...@@ -53,6 +53,7 @@ from sglang.srt.utils import ( ...@@ -53,6 +53,7 @@ from sglang.srt.utils import (
is_xpu, is_xpu,
supports_custom_op, supports_custom_op,
) )
from sglang.srt import _custom_ops as ops
_is_npu = is_npu() _is_npu = is_npu()
_is_cpu = is_cpu() _is_cpu = is_cpu()
...@@ -303,7 +304,7 @@ class GroupCoordinator: ...@@ -303,7 +304,7 @@ class GroupCoordinator:
# Lazy import to avoid documentation build error # Lazy import to avoid documentation build error
from sglang.srt.distributed.device_communicators.custom_all_reduce import ( from sglang.srt.distributed.device_communicators.custom_all_reduce import (
CustomAllreduce, CustomAllreduce, DCUCustomAllreduce
) )
from sglang.srt.distributed.device_communicators.pymscclpp import ( from sglang.srt.distributed.device_communicators.pymscclpp import (
PyMscclppCommunicator, PyMscclppCommunicator,
...@@ -347,6 +348,12 @@ class GroupCoordinator: ...@@ -347,6 +348,12 @@ class GroupCoordinator:
else: else:
ca_max_size = 8 * 1024 * 1024 ca_max_size = 8 * 1024 * 1024
try: try:
if is_hip() and ops.use_dcu_custom_allreduce:
self.ca_comm = DCUCustomAllreduce(
group=self.cpu_group,
device=self.device,
)
else:
self.ca_comm = CustomAllreduce( self.ca_comm = CustomAllreduce(
group=self.cpu_group, group=self.cpu_group,
device=self.device, device=self.device,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment