"vscode:/vscode.git/clone" did not exist on "05aea9edf9fed0bf7e83ff48e136e3f762976e27"
Unverified Commit 9cf40772 authored by Hubert Lu, committed by GitHub

Enable custom AR for AMD GPUs and maintain it in sgl-kernel (#3406)

parent d3fe9bae
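As context for the diffs below, the selection rule this commit introduces can be summarized with a small illustrative sketch (not part of the commit): the vLLM custom-allreduce bindings stay in use on CUDA unless disabled, while ROCm always goes through the sgl-kernel bindings.

# Illustrative only: mirrors the USE_VLLM_CUSTOM_ALLREDUCE / is_hip() checks added below.
import os

def custom_allreduce_backend(on_rocm: bool) -> str:
    """Return which custom all-reduce binding the wrapper would load."""
    use_vllm = os.environ.get("USE_VLLM_CUSTOM_ALLREDUCE", default=True)
    if use_vllm and not on_rocm:
        return "vllm._C custom_ar ops"
    # Both remaining paths use sgl_kernel.ops, but they bind different kernels.
    return "sgl_kernel.ops (HIP custom AR)" if on_rocm else "sgl_kernel.ops (trt_reduce)"

print(custom_allreduce_backend(on_rocm=False))
print(custom_allreduce_backend(on_rocm=True))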
@@ -9,13 +9,13 @@ from typing import TYPE_CHECKING, List, Optional, Tuple, Union
 import torch
 import torch.library
-from sglang.srt.utils import is_hpu
+from sglang.srt.utils import is_hip, is_hpu
 logger = logging.getLogger(__name__)
 use_vllm_custom_allreduce = os.environ.get("USE_VLLM_CUSTOM_ALLREDUCE", default=True)
 if not is_hpu():
-    if use_vllm_custom_allreduce:
+    # Remove vllm dependency for custom allreduce on ROCm
+    if use_vllm_custom_allreduce and not is_hip():
         try:
             import vllm._C
         except ImportError as e:
@@ -56,7 +56,7 @@ def hint_on_error(fn):
     return wrapper
-if use_vllm_custom_allreduce:
+if use_vllm_custom_allreduce and not is_hip():
     # custom ar
     def init_custom_ar(
         ipc_tensors: List[torch.Tensor],
@@ -95,39 +95,87 @@ if use_vllm_custom_allreduce:
         torch.ops._C_custom_ar.register_graph_buffers(fa, handles, offsets)
 else:
-    # custom ar
-    def init_custom_ar(
-        rank_id: int,
-        world_size: int,
-        rank_data_base: torch.Tensor,
-        buffers: List[int],
-        tmp_result_buffers: List[int],
-        barrier_in: List[int],
-        barrier_out: List[int],
-    ) -> int:
-        return sgl_kernel.ops.init_custom_reduce(
-            rank_id,
-            world_size,
-            rank_data_base,
-            buffers,
-            tmp_result_buffers,
-            barrier_in,
-            barrier_out,
-        )
-    def all_reduce(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None:
-        sgl_kernel.ops.custom_reduce(fa, inp, out)
-    def dispose(fa: int) -> None:
-        sgl_kernel.ops.custom_dispose(fa)
-    def get_graph_buffer_ipc_meta(fa: int) -> Tuple[List[int], List[int]]:
-        return sgl_kernel.ops.get_graph_buffer_ipc_meta(fa)
-    def register_graph_buffers(
-        fa: int, handles: List[List[int]], offsets: List[List[int]]
-    ) -> None:
-        sgl_kernel.ops.register_graph_buffers(fa, handles, offsets)
+    if is_hip():
+        def init_custom_ar(
+            meta: torch.Tensor,
+            rank_data: torch.Tensor,
+            handles: List[str],
+            offsets: List[int],
+            rank: int,
+            full_nvlink: bool,
+        ) -> int:
+            return sgl_kernel.ops.init_custom_ar(
+                meta, rank_data, handles, offsets, rank, full_nvlink
+            )
+        def all_reduce_reg(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None:
+            sgl_kernel.ops.all_reduce_reg(fa, inp, out)
+        def all_reduce_unreg(
+            fa: int, inp: torch.Tensor, reg_buffer: torch.Tensor, out: torch.Tensor
+        ) -> None:
+            sgl_kernel.ops.all_reduce_unreg(fa, inp, reg_buffer, out)
+        def dispose(fa: int) -> None:
+            sgl_kernel.ops.dispose(fa)
+        def meta_size() -> int:
+            return sgl_kernel.ops.meta_size()
+        def register_buffer(
+            fa: int, t: torch.Tensor, handles: List[str], offsets: List[int]
+        ) -> None:
+            return sgl_kernel.ops.register_buffer(fa, t, handles, offsets)
+        def get_graph_buffer_ipc_meta(fa: int) -> Tuple[torch.Tensor, List[int]]:
+            return sgl_kernel.ops.get_graph_buffer_ipc_meta(fa)
+        def register_graph_buffers(
+            fa: int, handles: List[str], offsets: List[List[int]]
+        ) -> None:
+            sgl_kernel.ops.register_graph_buffers(fa, handles, offsets)
+        def allocate_meta_buffer(size: int) -> torch.Tensor:
+            return sgl_kernel.ops.allocate_meta_buffer(size)
+        def get_meta_buffer_ipc_handle(inp: torch.Tensor) -> torch.Tensor:
+            return sgl_kernel.ops.get_meta_buffer_ipc_handle(inp)
+    else:
+        # custom ar
+        def init_custom_ar(
+            rank_id: int,
+            world_size: int,
+            rank_data_base: torch.Tensor,
+            buffers: List[int],
+            tmp_result_buffers: List[int],
+            barrier_in: List[int],
+            barrier_out: List[int],
+        ) -> int:
+            return sgl_kernel.ops.init_custom_reduce(
+                rank_id,
+                world_size,
+                rank_data_base,
+                buffers,
+                tmp_result_buffers,
+                barrier_in,
+                barrier_out,
+            )
+        def all_reduce(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None:
+            sgl_kernel.ops.custom_reduce(fa, inp, out)
+        def dispose(fa: int) -> None:
+            sgl_kernel.ops.custom_dispose(fa)
+        def get_graph_buffer_ipc_meta(fa: int) -> Tuple[List[int], List[int]]:
+            return sgl_kernel.ops.get_graph_buffer_ipc_meta(fa)
+        def register_graph_buffers(
+            fa: int, handles: List[List[int]], offsets: List[List[int]]
+        ) -> None:
+            sgl_kernel.ops.register_graph_buffers(fa, handles, offsets)
 # temporary fix for https://github.com/vllm-project/vllm/issues/5456
...
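For quick reference, here is an informal summary (mine, not from the commit) of how the wrapper's helpers map onto sgl_kernel.ops on each backend after this change; the ROCm column is the new surface added above.

# Informal name mapping between the pre-existing trt_reduce path and the new ROCm path.
BINDINGS = {
    "init": ("init_custom_reduce", "init_custom_ar"),
    "all-reduce": ("custom_reduce", "all_reduce_reg / all_reduce_unreg"),
    "teardown": ("custom_dispose", "dispose"),
    "graph buffer IPC meta": ("get_graph_buffer_ipc_meta", "get_graph_buffer_ipc_meta"),
    "graph buffer registration": ("register_graph_buffers", "register_graph_buffers"),
    "ROCm-only extras": (None, "meta_size / register_buffer / allocate_meta_buffer / get_meta_buffer_ipc_handle"),
}
for role, (cuda_name, rocm_name) in BINDINGS.items():
    print(f"{role:<28} CUDA: {cuda_name}  ROCm: {rocm_name}")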
@@ -18,7 +18,7 @@ from sglang.srt.distributed.device_communicators.custom_all_reduce_utils import
     gpu_p2p_access_check,
 )
 from sglang.srt.distributed.parallel_state import in_the_same_node_as
-from sglang.srt.utils import cuda_device_count_stateless, is_cuda
+from sglang.srt.utils import cuda_device_count_stateless, is_cuda, is_hip
 logger = logging.getLogger(__name__)
@@ -28,14 +28,27 @@ if is_cuda():
     except ImportError as e:
         logger.warning("Failed to import pynvml with %r", e)
+if is_hip():
+    try:
+        from amdsmi import (
+            AmdSmiException,
+            amdsmi_get_gpu_board_info,
+            amdsmi_get_processor_handles,
+            amdsmi_init,
+            amdsmi_shut_down,
+            amdsmi_topo_get_link_type,
+        )
+    except ImportError as e:
+        logger.warning("Failed to import amdsmi with %r", e)
 try:
-    if ops.use_vllm_custom_allreduce:
+    if ops.use_vllm_custom_allreduce and not is_hip():
         ops.meta_size()
     else:
         import sgl_kernel
     custom_ar = True
 except Exception:
-    # For AMD GPUs and CPUs
+    # For CPUs
     custom_ar = False
 logger = logging.getLogger(__name__)
@@ -47,37 +60,62 @@ _R = TypeVar("_R")
 def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]:
     @wraps(fn)
     def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
-        pynvml.nvmlInit()
-        try:
-            return fn(*args, **kwargs)
-        finally:
-            pynvml.nvmlShutdown()
+        if torch.version.hip:
+            try:
+                amdsmi_init()
+                return fn(*args, **kwargs)
+            finally:
+                amdsmi_shut_down()
+        else:
+            pynvml.nvmlInit()
+            try:
+                return fn(*args, **kwargs)
+            finally:
+                pynvml.nvmlShutdown()
     return wrapper
 @with_nvml_context
-def is_full_nvlink(physical_device_ids: List[int]) -> bool:
-    """
-    query if the set of gpus are fully connected by nvlink (1 hop)
-    """
-    handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in physical_device_ids]
-    for i, handle in enumerate(handles):
-        for j, peer_handle in enumerate(handles):
-            if i < j:
-                try:
-                    p2p_status = pynvml.nvmlDeviceGetP2PStatus(
-                        handle, peer_handle, pynvml.NVML_P2P_CAPS_INDEX_NVLINK
-                    )
-                    if p2p_status != pynvml.NVML_P2P_STATUS_OK:
-                        return False
-                except pynvml.NVMLError:
-                    logger.exception(
-                        "NVLink detection failed. This is normal if your"
-                        " machine has no NVLink equipped."
-                    )
-                    return False
-    return True
+def is_full_nvlink(physical_device_ids: List[int], world_size: int) -> bool:
+    if is_hip():
+        """
+        query if the set of gpus are fully connected by xgmi (1 hop)
+        """
+        handles = [amdsmi_get_processor_handles()[i] for i in physical_device_ids]
+        for i, handle in enumerate(handles):
+            for j, peer_handle in enumerate(handles):
+                if i < j:
+                    try:
+                        link_type = amdsmi_topo_get_link_type(handle, peer_handle)
+                        # type is 2 for XGMI
+                        if link_type["hops"] != 1 or link_type["type"] != 2:
+                            return False
+                    except AmdSmiException as error:
+                        logger.error("AMD 1 hop XGMI detection failed.", exc_info=error)
+                        return False
+        return True
+    else:
+        """
+        query if the set of gpus are fully connected by nvlink (1 hop)
+        """
+        handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in physical_device_ids]
+        for i, handle in enumerate(handles):
+            for j, peer_handle in enumerate(handles):
+                if i < j:
+                    try:
+                        p2p_status = pynvml.nvmlDeviceGetP2PStatus(
+                            handle, peer_handle, pynvml.NVML_P2P_CAPS_INDEX_NVLINK
+                        )
+                        if p2p_status != pynvml.NVML_P2P_STATUS_OK:
+                            return False
+                    except pynvml.NVMLError:
+                        logger.exception(
+                            "NVLink detection failed. This is normal if your"
+                            " machine has no NVLink equipped."
+                        )
+                        return False
+        return True
 def _can_p2p(rank: int, world_size: int) -> bool:
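The ROCm branch above is a pairwise topology test; the per-pair condition reduces to a small predicate on the dict returned by amdsmi_topo_get_link_type. A standalone restatement (illustrative only; the dict shape is the one assumed by the code above):

# Peers count as "fully connected" only if they are one XGMI hop apart (type code 2).
def is_one_hop_xgmi(link_type: dict) -> bool:
    XGMI = 2
    return link_type.get("hops") == 1 and link_type.get("type") == XGMI

print(is_one_hop_xgmi({"hops": 1, "type": 2}))  # True  -> counts toward full connectivity
print(is_one_hop_xgmi({"hops": 2, "type": 2}))  # False -> is_full_nvlink() returns False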
@@ -102,15 +140,18 @@ def is_weak_contiguous(inp: torch.Tensor):
 class CustomAllreduce:
     _SUPPORTED_WORLD_SIZES = [2, 4, 6, 8]
+    _MAX_CAR_SIZE = 8192 * 1024
+    if is_hip():
+        # crossover is at 16MB buffer size for ROCm
+        _MAX_CAR_SIZE = 2 * 8192 * 1024
     # max_size: max supported allreduce size
     def __init__(
         self,
         group: ProcessGroup,
         device: Union[int, str, torch.device],
-        max_size=8192 * 1024,
+        max_size=_MAX_CAR_SIZE,
     ) -> None:
         """
         Args:
@@ -185,12 +226,9 @@ class CustomAllreduce:
         # test nvlink first, this will filter out most of the cases
         # where custom allreduce is not supported
         # this checks hardware and driver support for NVLink
-        if is_cuda():
-            assert is_cuda()
-            full_nvlink = is_full_nvlink(physical_device_ids)
-        else:
-            full_nvlink = False
+        if is_cuda() or is_hip():
+            full_nvlink = is_full_nvlink(physical_device_ids, world_size)
         if world_size > 2 and not full_nvlink:
             logger.warning(
                 "Custom allreduce is disabled because it's not supported on"
@@ -201,7 +239,8 @@ class CustomAllreduce:
         # test P2P capability, this checks software/cudaruntime support
         # this is expensive to compute at the first time
         # then we cache the result
-        if not _can_p2p(rank, world_size):
+        # On AMD GPU, p2p is always enabled between XGMI connected GPUs
+        if not is_hip() and not _can_p2p(rank, world_size):
             logger.warning(
                 "Custom allreduce is disabled because your platform lacks "
                 "GPU P2P capability or P2P test failed. To silence this "
@@ -214,7 +253,7 @@ class CustomAllreduce:
         self.world_size = world_size
         self.full_nvlink = full_nvlink
-        if ops.use_vllm_custom_allreduce:
+        if ops.use_vllm_custom_allreduce and not is_hip():
            # Buffers memory are owned by this Python class and passed to C++.
            # Meta data composes of two parts: meta data for synchronization and a
            # temporary buffer for storing intermediate allreduce results.
@@ -237,35 +276,56 @@ class CustomAllreduce:
             )
             ops.register_buffer(self._ptr, self.buffer_ptrs)
         else:
-            # From TensorRT-LLM getMaxRequiredWorkspaceSize
-            self.max_required_workspace_size = [16 * 1024 * 1024, 8 * 1024 * 1024]
-            # sizeof(uint32_t) * (MAX_ALL_REDUCE_BLOCKS + 2) * MAX_RANKS_PER_NODE;
-            self.barrier_max_size = 8 * (36 + 2) * 8
-            self.buffer_ptrs = self.create_shared_buffer(max_size, group=group)
-            self.tmp_result_buffer_ptrs = self.create_shared_buffer(
-                max_size, group=group
-            )
-            self.rank_data_base = torch.empty(
-                8 * 1024 * 1024, dtype=torch.uint8, device=self.device
-            )
-            self.barrier_in_ptrs = self.create_shared_buffer(
-                self.barrier_max_size, group=group
-            )
-            self.barrier_out_ptrs = self.create_shared_buffer(
-                self.barrier_max_size, group=group
-            )
-            self._ptr = ops.init_custom_ar(
-                rank,
-                world_size,
-                self.rank_data_base,
-                self.buffer_ptrs,
-                self.tmp_result_buffer_ptrs,
-                self.barrier_in_ptrs,
-                self.barrier_out_ptrs,
-            )
+            if is_hip():
+                # meta data buffers need to be "uncached" for signal on MI200
+                self.meta = ops.allocate_meta_buffer(ops.meta_size() + max_size)
+                self.buffer = torch.empty(
+                    max_size, dtype=torch.uint8, device=self.device
+                )
+                handle = ops.get_meta_buffer_ipc_handle(self.meta)
+                shard_data = (
+                    bytes(handle),  # ipc handle to base ptr
+                    0,  # offset of base ptr
+                )
+                handles, offsets = self._gather_ipc_meta(shard_data)
+                self.rank_data = torch.empty(
+                    8 * 1024 * 1024, dtype=torch.uint8, device=self.device
+                )
+                self._ptr = ops.init_custom_ar(
+                    self.meta, self.rank_data, handles, offsets, rank, self.full_nvlink
+                )
+                self.register_buffer(self.buffer)
+                self.MSCCL = os.getenv("RCCL_MSCCL_ENABLE", "1") == "1"
+            else:
+                # From TensorRT-LLM getMaxRequiredWorkspaceSize
+                self.max_required_workspace_size = [16 * 1024 * 1024, 8 * 1024 * 1024]
+                # sizeof(uint32_t) * (MAX_ALL_REDUCE_BLOCKS + 2) * MAX_RANKS_PER_NODE;
+                self.barrier_max_size = 8 * (36 + 2) * 8
+                self.buffer_ptrs = self.create_shared_buffer(max_size, group=group)
+                self.tmp_result_buffer_ptrs = self.create_shared_buffer(
+                    max_size, group=group
+                )
+                self.rank_data_base = torch.empty(
+                    8 * 1024 * 1024, dtype=torch.uint8, device=self.device
+                )
+                self.barrier_in_ptrs = self.create_shared_buffer(
+                    self.barrier_max_size, group=group
+                )
+                self.barrier_out_ptrs = self.create_shared_buffer(
+                    self.barrier_max_size, group=group
+                )
+                self._ptr = ops.init_custom_ar(
+                    rank,
+                    world_size,
+                    self.rank_data_base,
+                    self.buffer_ptrs,
+                    self.tmp_result_buffer_ptrs,
+                    self.barrier_in_ptrs,
+                    self.barrier_out_ptrs,
+                )
         self.disabled = False
 @staticmethod
@@ -316,23 +376,69 @@ class CustomAllreduce:
         if not self.disabled:
             self.register_graph_buffers()
-    def register_graph_buffers(self):
-        handle, offset = ops.get_graph_buffer_ipc_meta(self._ptr)
-        logger.info("Registering %d cuda graph addresses", len(offset))
-        # We cannot directly use `dist.all_gather_object` here
-        # because it is incompatible with `gloo` backend under inference mode.
-        # see https://github.com/pytorch/pytorch/issues/126032 for details.
-        all_data = [[None, None] for _ in range(dist.get_world_size(group=self.group))]
-        all_data[self.rank] = [handle, offset]
-        ranks = sorted(dist.get_process_group_ranks(group=self.group))
+    def _get_ipc_meta(self, inp: torch.Tensor):
+        # _share_cuda_() doesn't accept meta buffer not allocated from
+        # PyTorch cache allocator, use direct HIP call to get IPC handle
+        handle = ops.get_meta_buffer_ipc_handle(inp)
+        shard_data = (
+            bytes(handle),  # ipc handle to base ptr
+            0,  # offset of base ptr
+        )
+        return self._gather_ipc_meta(shard_data)
+    def _gather_ipc_meta(self, shard_data):
+        # Note: don't use `[[None]] * self.world_size` here
+        # because it will create a list of the same reference
+        all_data: List[Optional[Any]] = [[None] for i in range(self.world_size)]
+        all_data[self.rank][0] = shard_data
+        ranks = dist.get_process_group_ranks(group=self.group)
+        ranks.sort()
         for i, rank in enumerate(ranks):
             dist.broadcast_object_list(
                 all_data[i], src=rank, group=self.group, device="cpu"
             )
-        # Unpack list of tuples to tuple of lists.
-        handles = [d[0] for d in all_data]  # type: ignore
-        offsets = [d[1] for d in all_data]  # type: ignore
-        ops.register_graph_buffers(self._ptr, handles, offsets)
+        # we cannot directly use `dist.all_gather_object` here
+        # because it is incompatible with `gloo` backend under inference mode.
+        # see https://github.com/pytorch/pytorch/issues/126032 for details.
+        handles = []
+        offsets = []
+        for i in range(len(all_data)):
+            handles.append(all_data[i][0][0])  # type: ignore
+            offsets.append(all_data[i][0][1])  # type: ignore
+        return handles, offsets
+    def register_buffer(self, inp: torch.Tensor):
+        handles, offsets = self._get_ipc_meta(inp)
+        ops.register_buffer(self._ptr, inp, handles, offsets)
+    def register_graph_buffers(self):
+        if is_hip():
+            handle, offset = ops.get_graph_buffer_ipc_meta(self._ptr)
+            handles, offsets = self._gather_ipc_meta((bytes(handle), offset))
+            logger.info("Registering %d cuda graph addresses", len(offset))
+            ops.register_graph_buffers(self._ptr, handles, offsets)
+        else:
+            handle, offset = ops.get_graph_buffer_ipc_meta(self._ptr)
+            logger.info("Registering %d cuda graph addresses", len(offset))
+            # We cannot directly use `dist.all_gather_object` here
+            # because it is incompatible with `gloo` backend under inference mode.
+            # see https://github.com/pytorch/pytorch/issues/126032 for details.
+            all_data = [
+                [None, None] for _ in range(dist.get_world_size(group=self.group))
+            ]
+            all_data[self.rank] = [handle, offset]
+            ranks = sorted(dist.get_process_group_ranks(group=self.group))
+            for i, rank in enumerate(ranks):
+                dist.broadcast_object_list(
+                    all_data[i], src=rank, group=self.group, device="cpu"
+                )
+            # Unpack list of tuples to tuple of lists.
+            handles = [d[0] for d in all_data]  # type: ignore
+            offsets = [d[1] for d in all_data]  # type: ignore
+            ops.register_graph_buffers(self._ptr, handles, offsets)
     def should_custom_ar(self, inp: torch.Tensor):
         if self.disabled:
@@ -345,11 +451,22 @@ class CustomAllreduce:
             return False
         # for 4 or more non NVLink-capable GPUs, custom allreduce provides
         # little performance improvement over NCCL.
-        if ops.use_vllm_custom_allreduce:
+        if ops.use_vllm_custom_allreduce and not is_hip():
             if self.world_size == 2 or self.full_nvlink:
                 return inp_size < self.max_size
             return False
+        if is_hip():
+            if self.full_nvlink:
+                if self.world_size == 8:
+                    if self.MSCCL:
+                        return False
+                    else:
+                        return inp_size < self.max_size
+                else:
+                    return inp_size < self.max_size
+            return False
         if self.world_size == 2:
             return (
                 inp_size < self.max_size
@@ -364,6 +481,21 @@ class CustomAllreduce:
             return False
+    # all reduce, assuming inp tensor is IPC registered with register_buffer,
+    # or, in the context of cuda graphs, register_graph_buffers
+    def all_reduce_reg(self, inp: torch.Tensor, out: torch.Tensor = None):
+        if out is None:
+            out = torch.empty_like(inp)
+        ops.all_reduce_reg(self._ptr, inp, out)
+        return out
+    # all reduce, assuming inp tensor is NOT IPC registered
+    def all_reduce_unreg(self, inp: torch.Tensor, out: torch.Tensor = None):
+        if out is None:
+            out = torch.empty_like(inp)
+        ops.all_reduce_unreg(self._ptr, inp, self.buffer, out)
+        return out
     def all_reduce(
         self,
         inp: torch.Tensor,
@@ -397,13 +529,23 @@ class CustomAllreduce:
             return None
         if self._IS_CAPTURING:
             if torch.cuda.is_current_stream_capturing():
-                return self.all_reduce(input, registered=True)
+                if is_hip():
+                    return self.all_reduce_reg(input)
+                else:
+                    return self.all_reduce(input, registered=True)
             else:
                 # If warm up, mimic the allocation pattern since custom
                 # allreduce is out-of-place.
                 return torch.empty_like(input)
         else:
-            return self.all_reduce(input, registered=False)
+            if is_hip():
+                # note: outside of cuda graph context,
+                # custom allreduce incurs a cost of cudaMemcpy, which should
+                # be small(<=1% of overall latency) compared to the performance
+                # gains of using custom kernels
+                return self.all_reduce_unreg(input)
+            else:
+                return self.all_reduce(input, registered=False)
     def close(self):
         if not self.disabled and self._ptr:
@@ -411,7 +553,7 @@ class CustomAllreduce:
             if ops.use_vllm_custom_allreduce:
                 self.free_shared_buffer(self.meta_ptrs)
                 self.free_shared_buffer(self.buffer_ptrs)
-            else:
+            elif is_cuda():
                 self.free_shared_buffer(self.buffer_ptrs)
                 self.free_shared_buffer(self.tmp_result_buffer_ptrs)
                 self.free_shared_buffer(self.barrier_in_ptrs)
...
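Taken together, the ROCm branches in this file boil down to a small decision rule: custom allreduce is used only under full XGMI connectivity, skipped for 8-GPU groups when MSCCL is enabled, capped at the 16 MB ROCm buffer size, and dispatched to all_reduce_reg inside CUDA-graph capture versus all_reduce_unreg otherwise. An illustrative standalone restatement (not part of the commit):

def rocm_should_custom_ar(inp_bytes: int, world_size: int, full_xgmi: bool,
                          msccl_enabled: bool, max_size: int = 2 * 8192 * 1024) -> bool:
    # Mirrors the is_hip() branch of CustomAllreduce.should_custom_ar() above.
    if not full_xgmi:
        return False
    if world_size == 8 and msccl_enabled:
        return False  # leave 8-GPU allreduce to RCCL/MSCCL
    return inp_bytes < max_size

def rocm_kernel_for(capturing_cuda_graph: bool) -> str:
    # Mirrors custom_all_reduce(): registered kernel under graph capture, staged copy otherwise.
    return "all_reduce_reg" if capturing_cuda_graph else "all_reduce_unreg"

print(rocm_should_custom_ar(1 << 20, world_size=8, full_xgmi=True, msccl_enabled=True))  # False
print(rocm_should_custom_ar(1 << 20, world_size=4, full_xgmi=True, msccl_enabled=True))  # True
print(rocm_kernel_for(capturing_cuda_graph=False))  # all_reduce_unreg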
@@ -44,6 +44,7 @@ include_dirs = [
 sources = [
     "src/sgl-kernel/torch_extension_rocm.cc",
     "src/sgl-kernel/csrc/moe_align_kernel.cu",
+    "src/sgl-kernel/csrc/custom_all_reduce.hip",
 ]
 cxx_flags = ["-O3"]
...
 import ctypes
 import os
+import torch
 if os.path.exists("/usr/local/cuda/targets/x86_64-linux/lib/libcudart.so.12"):
     ctypes.CDLL(
         "/usr/local/cuda/targets/x86_64-linux/lib/libcudart.so.12",
         mode=ctypes.RTLD_GLOBAL,
     )
-from sgl_kernel.ops import (
-    apply_rope_with_cos_sin_cache_inplace,
-    bmm_fp8,
-    build_tree_kernel,
-    build_tree_kernel_efficient,
-    cublas_grouped_gemm,
-    custom_dispose,
-    custom_reduce,
-    fp8_blockwise_scaled_mm,
-    fp8_scaled_mm,
-    fused_add_rmsnorm,
-    gelu_and_mul,
-    gelu_tanh_and_mul,
-    gemma_fused_add_rmsnorm,
-    gemma_rmsnorm,
-    get_graph_buffer_ipc_meta,
-    init_custom_reduce,
-    int8_scaled_mm,
-    lightning_attention_decode,
-    min_p_sampling_from_probs,
-    moe_align_block_size,
-    register_graph_buffers,
-    rmsnorm,
-    sampling_scaling_penalties,
-    sgl_per_token_group_quant_fp8,
-    silu_and_mul,
-    top_k_renorm_prob,
-    top_k_top_p_sampling_from_probs,
-    top_p_renorm_prob,
-    tree_speculative_sampling_target_only,
-)
-from .version import __version__
-__all__ = [
-    "apply_rope_with_cos_sin_cache_inplace",
-    "bmm_fp8",
-    "cublas_grouped_gemm",
-    "custom_dispose",
-    "custom_reduce",
-    "fp8_blockwise_scaled_mm",
-    "fp8_scaled_mm",
-    "fused_add_rmsnorm",
-    "gelu_and_mul",
-    "gelu_tanh_and_mul",
-    "gemma_fused_add_rmsnorm",
-    "gemma_rmsnorm",
-    "get_graph_buffer_ipc_meta",
-    "init_custom_reduce",
-    "int8_scaled_mm",
-    "lightning_attention_decode",
-    "min_p_sampling_from_probs",
-    "moe_align_block_size",
-    "register_graph_buffers",
-    "rmsnorm",
-    "sampling_scaling_penalties",
-    "silu_and_mul",
-    "top_k_renorm_prob",
-    "top_k_top_p_sampling_from_probs",
-    "top_p_renorm_prob",
-    "tree_speculative_sampling_target_only",
-    "build_tree_kernel_efficient",
-    "build_tree_kernel",
-    "sgl_per_token_group_quant_fp8",
-]
+from .version import __version__
+if torch.version.hip is not None:
+    from sgl_kernel.ops import (
+        all_reduce_reg,
+        all_reduce_unreg,
+        allocate_meta_buffer,
+        apply_rope_with_cos_sin_cache_inplace,
+        bmm_fp8,
+        dispose,
+        fp8_scaled_mm,
+        fused_add_rmsnorm,
+        gelu_and_mul,
+        gelu_tanh_and_mul,
+        gemma_fused_add_rmsnorm,
+        gemma_rmsnorm,
+        get_graph_buffer_ipc_meta,
+        get_meta_buffer_ipc_handle,
+        init_custom_ar,
+        int8_scaled_mm,
+        lightning_attention_decode,
+        meta_size,
+        min_p_sampling_from_probs,
+        moe_align_block_size,
+        register_buffer,
+        register_graph_buffers,
+        rmsnorm,
+        sampling_scaling_penalties,
+        silu_and_mul,
+        top_k_renorm_prob,
+        top_k_top_p_sampling_from_probs,
+        top_p_renorm_prob,
+    )
+    __all__ = [
+        "all_reduce_reg",
+        "all_reduce_unreg",
+        "allocate_meta_buffer",
+        "apply_rope_with_cos_sin_cache_inplace",
+        "bmm_fp8",
+        "dispose",
+        "fp8_scaled_mm",
+        "fused_add_rmsnorm",
+        "gelu_and_mul",
+        "gelu_tanh_and_mul",
+        "gemma_fused_add_rmsnorm",
+        "gemma_rmsnorm",
+        "get_graph_buffer_ipc_meta",
+        "get_meta_buffer_ipc_handle",
+        "init_custom_ar",
+        "int8_scaled_mm",
+        "lightning_attention_decode",
+        "meta_size",
+        "min_p_sampling_from_probs",
+        "moe_align_block_size",
+        "register_buffer",
+        "register_graph_buffers",
+        "rmsnorm",
+        "sampling_scaling_penalties",
+        "silu_and_mul",
+        "top_k_renorm_prob",
+        "top_k_top_p_sampling_from_probs",
+        "top_p_renorm_prob",
+    ]
+else:
+    from sgl_kernel.ops import (
+        apply_rope_with_cos_sin_cache_inplace,
+        bmm_fp8,
+        build_tree_kernel,
+        build_tree_kernel_efficient,
+        cublas_grouped_gemm,
+        custom_dispose,
+        custom_reduce,
+        fp8_blockwise_scaled_mm,
+        fp8_scaled_mm,
+        fused_add_rmsnorm,
+        gelu_and_mul,
+        gelu_tanh_and_mul,
+        gemma_fused_add_rmsnorm,
+        gemma_rmsnorm,
+        get_graph_buffer_ipc_meta,
+        init_custom_reduce,
+        int8_scaled_mm,
+        lightning_attention_decode,
+        min_p_sampling_from_probs,
+        moe_align_block_size,
+        register_graph_buffers,
+        rmsnorm,
+        sampling_scaling_penalties,
+        sgl_per_token_group_quant_fp8,
+        silu_and_mul,
+        top_k_renorm_prob,
+        top_k_top_p_sampling_from_probs,
+        top_p_renorm_prob,
+        tree_speculative_sampling_target_only,
+    )
+    __all__ = [
+        "apply_rope_with_cos_sin_cache_inplace",
+        "bmm_fp8",
+        "cublas_grouped_gemm",
+        "custom_dispose",
+        "custom_reduce",
+        "fp8_blockwise_scaled_mm",
+        "fp8_scaled_mm",
+        "fused_add_rmsnorm",
+        "gelu_and_mul",
+        "gelu_tanh_and_mul",
+        "gemma_fused_add_rmsnorm",
+        "gemma_rmsnorm",
+        "get_graph_buffer_ipc_meta",
+        "init_custom_reduce",
+        "int8_scaled_mm",
+        "lightning_attention_decode",
+        "min_p_sampling_from_probs",
+        "moe_align_block_size",
+        "register_graph_buffers",
+        "rmsnorm",
+        "sampling_scaling_penalties",
+        "silu_and_mul",
+        "top_k_renorm_prob",
+        "top_k_top_p_sampling_from_probs",
+        "top_p_renorm_prob",
+        "tree_speculative_sampling_target_only",
+        "build_tree_kernel_efficient",
+        "build_tree_kernel",
+        "sgl_per_token_group_quant_fp8",
+    ]
// !!! This is a file automatically generated by hipify!!!
#include <ATen/hip/Exceptions.h>
#include <ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h>
#include <ATen/hip/impl/HIPStreamMasqueradingAsCUDA.h>
#include <torch/all.h>
#include "custom_all_reduce_hip.cuh"
// fake pointer type, must match fptr_t type in ops.h
using fptr_t = int64_t;
static_assert(sizeof(void*) == sizeof(fptr_t));
fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data,
const std::vector<std::string>& handles,
const std::vector<int64_t>& offsets, int64_t rank,
bool full_nvlink) {
int world_size = offsets.size();
if (world_size > 8)
throw std::invalid_argument("world size > 8 is not supported");
if (world_size % 2 != 0)
throw std::invalid_argument("Odd num gpus is not supported for now");
if (world_size != handles.size())
throw std::invalid_argument(
"handles length should equal to offsets length");
if (rank < 0 || rank >= world_size)
throw std::invalid_argument("invalid rank passed in");
hipIpcMemHandle_t ipc_handles[8];
for (int i = 0; i < world_size; i++) {
std::memcpy(&ipc_handles[i], handles[i].data(), sizeof(hipIpcMemHandle_t));
}
return (fptr_t) new vllm::CustomAllreduce(
reinterpret_cast<vllm::Signal*>(meta.data_ptr()), rank_data.data_ptr(),
rank_data.numel(), ipc_handles, offsets, rank, full_nvlink);
}
/**
* Make sure tensor t's data lies completely within ((char)t.data_ptr()) +
* t.numel() * t.element_size(). This is slightly weaker than t.is_contiguous()
* because it allows transpose of contiguous slice (i.e. slicing the first
* dimension). Currently, we require this because stride information is not
* passed into the kernels and we treat input tensors as flat.
*
* Examples
* A = torch.zeros(3, 3, 3)
* 1. A: OK
* 2. A[1:]: OK
* 3. A.permute(2, 0, 1): OK
* 4. A[1:].permute(2, 0, 1): OK
* 5. A[None].expand(2, -1, -1, -1): Not OK
* 6. A[:, 1:, 1:]: Not OK
*/
bool _is_weak_contiguous(torch::Tensor& t) {
return t.is_contiguous() ||
(t.storage().nbytes() - t.storage_offset() * t.element_size() ==
t.numel() * t.element_size());
}
void _all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out,
hipStream_t stream) {
auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
TORCH_CHECK(_is_weak_contiguous(out));
switch (out.scalar_type()) {
case at::ScalarType::Float: {
fa->allreduce<float>(stream, reinterpret_cast<float*>(inp.data_ptr()),
reinterpret_cast<float*>(out.data_ptr()),
out.numel());
break;
}
case at::ScalarType::Half: {
fa->allreduce<half>(stream, reinterpret_cast<half*>(inp.data_ptr()),
reinterpret_cast<half*>(out.data_ptr()), out.numel());
break;
}
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
case at::ScalarType::BFloat16: {
fa->allreduce<nv_bfloat16>(
stream, reinterpret_cast<nv_bfloat16*>(inp.data_ptr()),
reinterpret_cast<nv_bfloat16*>(out.data_ptr()), out.numel());
break;
}
#endif
default:
throw std::runtime_error(
"custom allreduce only supports float32, float16 and bfloat16");
}
}
void all_reduce_reg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out) {
const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(inp));
auto stream = c10::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream();
TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type());
TORCH_CHECK_EQ(inp.numel(), out.numel());
_all_reduce(_fa, inp, out, stream);
}
void all_reduce_unreg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& reg_buffer,
torch::Tensor& out) {
const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(inp));
auto stream = c10::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream();
auto input_size = inp.numel() * inp.element_size();
TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type());
TORCH_CHECK_EQ(inp.numel(), out.numel());
TORCH_CHECK(input_size <= reg_buffer.numel() * reg_buffer.element_size(),
"registered buffer is too small to contain the input");
AT_CUDA_CHECK(hipMemcpyAsync(reg_buffer.data_ptr(), inp.data_ptr(),
input_size, hipMemcpyDeviceToDevice, stream));
_all_reduce(_fa, reg_buffer, out, stream);
}
void dispose(fptr_t _fa) {
auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
delete fa;
}
int64_t meta_size() { return sizeof(vllm::Signal); }
void register_buffer(fptr_t _fa, torch::Tensor& t,
const std::vector<std::string>& handles,
const std::vector<int64_t>& offsets) {
auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
fa->register_buffer(handles, offsets, t.data_ptr());
}
std::tuple<torch::Tensor, std::vector<int64_t>> get_graph_buffer_ipc_meta(
fptr_t _fa) {
auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
auto [handle_bytes, offsets] = fa->get_graph_buffer_ipc_meta();
auto options =
torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCPU);
auto handles =
torch::empty({static_cast<int64_t>(handle_bytes.size())}, options);
std::memcpy(handles.data_ptr(), handle_bytes.data(), handle_bytes.size());
return {handles, std::move(offsets)};
}
void register_graph_buffers(fptr_t _fa, const std::vector<std::string>& handles,
const std::vector<std::vector<int64_t>>& offsets) {
auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
fa->register_graph_buffers(handles, offsets);
}
void free_meta_buffer(void* buffer) { CUDACHECK(hipFree(buffer)); }
torch::Tensor get_meta_buffer_ipc_handle(torch::Tensor& inp) {
auto options =
torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCPU);
auto data_handle =
torch::empty({static_cast<int64_t>(sizeof(hipIpcMemHandle_t))}, options);
CUDACHECK(hipIpcGetMemHandle((hipIpcMemHandle_t*)data_handle.data_ptr(),
inp.data_ptr()));
return data_handle;
}
torch::Tensor allocate_meta_buffer(int64_t size) {
auto device_index = c10::hip::current_device();
at::DeviceGuard device_guard(at::Device(at::DeviceType::CUDA, device_index));
void* buffer;
hipStreamCaptureMode mode = hipStreamCaptureModeRelaxed;
auto stream = c10::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream();
AT_CUDA_CHECK(hipThreadExchangeStreamCaptureMode(&mode));
AT_CUDA_CHECK(
hipExtMallocWithFlags((void**)&buffer, size, hipDeviceMallocUncached));
AT_CUDA_CHECK(hipMemsetAsync(buffer, 0, size, stream));
AT_CUDA_CHECK(hipStreamSynchronize(stream));
AT_CUDA_CHECK(hipThreadExchangeStreamCaptureMode(&mode));
auto options = torch::TensorOptions()
.dtype(torch::kI8)
.device(torch::kCUDA, device_index);
return torch::from_blob(buffer, {size}, free_meta_buffer, options);
}
std::vector<uint8_t> get_device_bdf(int dev) {
char busIdStr[] = "0000:00:00.0";
std::vector<uint8_t> bdf(sizeof(busIdStr), 0);
CUDACHECK(hipDeviceGetPCIBusId((char*)bdf.data(), sizeof(busIdStr), dev));
bdf.resize(bdf.size() - 1); // remove trailing NULL
return bdf;
}
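A minimal sketch of how the meta-buffer helpers defined above are driven from Python, assuming a ROCm machine with this sgl-kernel build installed; it only shows the local allocation side, whereas CustomAllreduce additionally exchanges the IPC handle across ranks before calling init_custom_ar.

import torch

if torch.version.hip is not None:
    import sgl_kernel.ops as ops

    max_size = 2 * 8192 * 1024  # ROCm buffer size introduced by this commit
    # Signal/meta memory must come from the uncached HIP allocator, not PyTorch's cache.
    meta = ops.allocate_meta_buffer(ops.meta_size() + max_size)
    handle = ops.get_meta_buffer_ipc_handle(meta)  # hipIpcMemHandle_t as a uint8 CPU tensor
    print("meta bytes:", meta.numel(), "ipc handle bytes:", handle.numel())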
@@ -34,8 +34,23 @@ limitations under the License.
     return PyModule_Create(&module); \
   }
-// trt_reduce
 using fptr_t = int64_t;
+#ifdef USE_ROCM
+fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data, const std::vector<std::string>& handles,
+                      const std::vector<int64_t>& offsets, int64_t rank, bool full_nvlink);
+void all_reduce_reg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out);
+void all_reduce_unreg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& reg_buffer, torch::Tensor& out);
+void dispose(fptr_t _fa);
+int64_t meta_size();
+void register_buffer(fptr_t _fa, torch::Tensor& t, const std::vector<std::string>& handles,
+                     const std::vector<int64_t>& offsets);
+std::tuple<torch::Tensor, std::vector<int64_t>> get_graph_buffer_ipc_meta(fptr_t _fa);
+void register_graph_buffers(fptr_t _fa, const std::vector<std::string>& handles,
+                            const std::vector<std::vector<int64_t>>& offsets);
+torch::Tensor allocate_meta_buffer(int64_t size);
+torch::Tensor get_meta_buffer_ipc_handle(torch::Tensor& inp);
+#else
+// trt_reduce
 fptr_t init_custom_ar(int64_t rank_id, int64_t world_size, torch::Tensor& rank_data, const std::vector<fptr_t>& buffers,
                       const std::vector<fptr_t>& tmp_result_buffers, const std::vector<fptr_t>& barrier_in,
                       const std::vector<fptr_t>& barrier_out);
@@ -44,6 +59,7 @@ void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out);
 std::tuple<std::vector<int64_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta(fptr_t _fa);
 void register_graph_buffers(fptr_t _fa, const std::vector<std::vector<int64_t>>& handles,
                             const std::vector<std::vector<int64_t>>& offsets);
+#endif
 // moe_align_block_size
 void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, int64_t block_size,
...
@@ -64,28 +64,79 @@ def apply_rope_with_cos_sin_cache_inplace(
     )
-def init_custom_reduce(
-    rank_id, num_devices, rank_data, buffers, tmp_buffers, barrier_in, barrier_out
-):
-    return torch.ops.sgl_kernels.init_custom_ar(
-        rank_id, num_devices, rank_data, buffers, tmp_buffers, barrier_in, barrier_out
-    )
-def custom_dispose(fa):
-    torch.ops.sgl_kernels.dispose(fa)
-def custom_reduce(fa, inp, out):
-    torch.ops.sgl_kernels.all_reduce(fa, inp, out)
-def get_graph_buffer_ipc_meta(fa):
-    return torch.ops.sgl_kernels.get_graph_buffer_ipc_meta(fa)
-def register_graph_buffers(fa, handles, offsets):
-    torch.ops.sgl_kernels.register_graph_buffers(fa, handles, offsets)
+if torch.version.hip is not None:
+    def init_custom_ar(
+        meta: torch.Tensor,
+        rank_data: torch.Tensor,
+        handles: List[str],
+        offsets: List[int],
+        rank: int,
+        full_nvlink: bool,
+    ) -> int:
+        return torch.ops.sgl_kernels.init_custom_ar(
+            meta, rank_data, handles, offsets, rank, full_nvlink
+        )
+    def all_reduce_reg(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None:
+        torch.ops.sgl_kernels.all_reduce_reg(fa, inp, out)
+    def all_reduce_unreg(
+        fa: int, inp: torch.Tensor, reg_buffer: torch.Tensor, out: torch.Tensor
+    ) -> None:
+        torch.ops.sgl_kernels.all_reduce_unreg(fa, inp, reg_buffer, out)
+    def dispose(fa: int) -> None:
+        torch.ops.sgl_kernels.dispose(fa)
+    def meta_size() -> int:
+        return torch.ops.sgl_kernels.meta_size()
+    def register_buffer(
+        fa: int, t: torch.Tensor, handles: List[str], offsets: List[int]
+    ) -> None:
+        return torch.ops.sgl_kernels.register_buffer(fa, t, handles, offsets)
+    def get_graph_buffer_ipc_meta(fa: int) -> Tuple[torch.Tensor, List[int]]:
+        return torch.ops.sgl_kernels.get_graph_buffer_ipc_meta(fa)
+    def register_graph_buffers(
+        fa: int, handles: List[str], offsets: List[List[int]]
+    ) -> None:
+        torch.ops.sgl_kernels.register_graph_buffers(fa, handles, offsets)
+    def allocate_meta_buffer(size: int) -> torch.Tensor:
+        return torch.ops.sgl_kernels.allocate_meta_buffer(size)
+    def get_meta_buffer_ipc_handle(inp: torch.Tensor) -> torch.Tensor:
+        return torch.ops.sgl_kernels.get_meta_buffer_ipc_handle(inp)
+else:
+    # trt_reduce
+    def init_custom_reduce(
+        rank_id, num_devices, rank_data, buffers, tmp_buffers, barrier_in, barrier_out
+    ):
+        return torch.ops.sgl_kernels.init_custom_ar(
+            rank_id,
+            num_devices,
+            rank_data,
+            buffers,
+            tmp_buffers,
+            barrier_in,
+            barrier_out,
+        )
+    def custom_dispose(fa):
+        torch.ops.sgl_kernels.dispose(fa)
+    def custom_reduce(fa, inp, out):
+        torch.ops.sgl_kernels.all_reduce(fa, inp, out)
+    def get_graph_buffer_ipc_meta(fa):
+        return torch.ops.sgl_kernels.get_graph_buffer_ipc_meta(fa)
+    def register_graph_buffers(fa, handles, offsets):
+        torch.ops.sgl_kernels.register_graph_buffers(fa, handles, offsets)
 def moe_align_block_size(
...
@@ -19,6 +19,37 @@ limitations under the License.
 #include "sgl_kernels_ops.h"
 TORCH_LIBRARY_EXPAND(sgl_kernels, m) {
+  // Custom all-reduce kernels
+  m.def(
+      "init_custom_ar(Tensor meta, Tensor rank_data, "
+      "str[] handles, int[] offsets, int rank, "
+      "bool full_nvlink) -> int");
+  m.impl("init_custom_ar", torch::kCUDA, &init_custom_ar);
+  m.def("all_reduce_reg(int fa, Tensor inp, Tensor! out) -> ()");
+  m.impl("all_reduce_reg", torch::kCUDA, &all_reduce_reg);
+  m.def(
+      "all_reduce_unreg(int fa, Tensor inp, Tensor reg_buffer, Tensor! out) -> "
+      "()");
+  m.impl("all_reduce_unreg", torch::kCUDA, &all_reduce_unreg);
+  m.def("dispose", &dispose);
+  m.def("meta_size", &meta_size);
+  m.def(
+      "register_buffer(int fa, Tensor t, str[] handles, "
+      "int[] offsets) -> ()");
+  m.impl("register_buffer", torch::kCUDA, &register_buffer);
+  m.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta);
+  m.def("register_graph_buffers", &register_graph_buffers);
+  m.def("allocate_meta_buffer", &allocate_meta_buffer);
+  m.impl("allocate_meta_buffer", torch::kCUDA, &allocate_meta_buffer);
+  m.def("get_meta_buffer_ipc_handle", &get_meta_buffer_ipc_handle);
+  m.impl("get_meta_buffer_ipc_handle", torch::kCPU, &get_meta_buffer_ipc_handle);
   // moe_align_block_size
   m.def(
       "moe_align_block_size(Tensor topk_ids, int num_experts, int block_size, Tensor! sorted_token_ids, Tensor! "
...