Unverified Commit 9cf40772 authored by Hubert Lu's avatar Hubert Lu Committed by GitHub
Browse files

Enable custom AR for AMD GPUs and maintain it in sgl-kernel (#3406)

parent d3fe9bae
......@@ -9,13 +9,13 @@ from typing import TYPE_CHECKING, List, Optional, Tuple, Union
import torch
import torch.library
from sglang.srt.utils import is_hpu
from sglang.srt.utils import is_hip, is_hpu
logger = logging.getLogger(__name__)
use_vllm_custom_allreduce = os.environ.get("USE_VLLM_CUSTOM_ALLREDUCE", default=True)
if not is_hpu():
if use_vllm_custom_allreduce:
# Remove vllm dependency for custom allreduce on ROCm
if use_vllm_custom_allreduce and not is_hip():
try:
import vllm._C
except ImportError as e:
......@@ -56,7 +56,7 @@ def hint_on_error(fn):
return wrapper
if use_vllm_custom_allreduce:
if use_vllm_custom_allreduce and not is_hip():
# custom ar
def init_custom_ar(
ipc_tensors: List[torch.Tensor],
......@@ -95,39 +95,87 @@ if use_vllm_custom_allreduce:
torch.ops._C_custom_ar.register_graph_buffers(fa, handles, offsets)
else:
# custom ar
def init_custom_ar(
rank_id: int,
world_size: int,
rank_data_base: torch.Tensor,
buffers: List[int],
tmp_result_buffers: List[int],
barrier_in: List[int],
barrier_out: List[int],
) -> int:
return sgl_kernel.ops.init_custom_reduce(
rank_id,
world_size,
rank_data_base,
buffers,
tmp_result_buffers,
barrier_in,
barrier_out,
)
if is_hip():
def init_custom_ar(
meta: torch.Tensor,
rank_data: torch.Tensor,
handles: List[str],
offsets: List[int],
rank: int,
full_nvlink: bool,
) -> int:
return sgl_kernel.ops.init_custom_ar(
meta, rank_data, handles, offsets, rank, full_nvlink
)
def all_reduce(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None:
sgl_kernel.ops.custom_reduce(fa, inp, out)
def all_reduce_reg(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None:
sgl_kernel.ops.all_reduce_reg(fa, inp, out)
def dispose(fa: int) -> None:
sgl_kernel.ops.custom_dispose(fa)
def all_reduce_unreg(
fa: int, inp: torch.Tensor, reg_buffer: torch.Tensor, out: torch.Tensor
) -> None:
sgl_kernel.ops.all_reduce_unreg(fa, inp, reg_buffer, out)
def get_graph_buffer_ipc_meta(fa: int) -> Tuple[List[int], List[int]]:
return sgl_kernel.ops.get_graph_buffer_ipc_meta(fa)
def dispose(fa: int) -> None:
sgl_kernel.ops.dispose(fa)
def register_graph_buffers(
fa: int, handles: List[List[int]], offsets: List[List[int]]
) -> None:
sgl_kernel.ops.register_graph_buffers(fa, handles, offsets)
def meta_size() -> int:
return sgl_kernel.ops.meta_size()
def register_buffer(
fa: int, t: torch.Tensor, handles: List[str], offsets: List[int]
) -> None:
return sgl_kernel.ops.register_buffer(fa, t, handles, offsets)
def get_graph_buffer_ipc_meta(fa: int) -> Tuple[torch.Tensor, List[int]]:
return sgl_kernel.ops.get_graph_buffer_ipc_meta(fa)
def register_graph_buffers(
fa: int, handles: List[str], offsets: List[List[int]]
) -> None:
sgl_kernel.ops.register_graph_buffers(fa, handles, offsets)
def allocate_meta_buffer(size: int) -> torch.Tensor:
return sgl_kernel.ops.allocate_meta_buffer(size)
def get_meta_buffer_ipc_handle(inp: torch.Tensor) -> torch.Tensor:
return sgl_kernel.ops.get_meta_buffer_ipc_handle(inp)
else:
# custom ar
def init_custom_ar(
rank_id: int,
world_size: int,
rank_data_base: torch.Tensor,
buffers: List[int],
tmp_result_buffers: List[int],
barrier_in: List[int],
barrier_out: List[int],
) -> int:
return sgl_kernel.ops.init_custom_reduce(
rank_id,
world_size,
rank_data_base,
buffers,
tmp_result_buffers,
barrier_in,
barrier_out,
)
def all_reduce(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None:
sgl_kernel.ops.custom_reduce(fa, inp, out)
def dispose(fa: int) -> None:
sgl_kernel.ops.custom_dispose(fa)
def get_graph_buffer_ipc_meta(fa: int) -> Tuple[List[int], List[int]]:
return sgl_kernel.ops.get_graph_buffer_ipc_meta(fa)
def register_graph_buffers(
fa: int, handles: List[List[int]], offsets: List[List[int]]
) -> None:
sgl_kernel.ops.register_graph_buffers(fa, handles, offsets)
# temporary fix for https://github.com/vllm-project/vllm/issues/5456
......
......@@ -18,7 +18,7 @@ from sglang.srt.distributed.device_communicators.custom_all_reduce_utils import
gpu_p2p_access_check,
)
from sglang.srt.distributed.parallel_state import in_the_same_node_as
from sglang.srt.utils import cuda_device_count_stateless, is_cuda
from sglang.srt.utils import cuda_device_count_stateless, is_cuda, is_hip
logger = logging.getLogger(__name__)
......@@ -28,14 +28,27 @@ if is_cuda():
except ImportError as e:
logger.warning("Failed to import pynvml with %r", e)
if is_hip():
try:
from amdsmi import (
AmdSmiException,
amdsmi_get_gpu_board_info,
amdsmi_get_processor_handles,
amdsmi_init,
amdsmi_shut_down,
amdsmi_topo_get_link_type,
)
except ImportError as e:
logger.warning("Failed to import amdsmi with %r", e)
try:
if ops.use_vllm_custom_allreduce:
if ops.use_vllm_custom_allreduce and not is_hip():
ops.meta_size()
else:
import sgl_kernel
custom_ar = True
except Exception:
# For AMD GPUs and CPUs
# For CPUs
custom_ar = False
logger = logging.getLogger(__name__)
......@@ -47,37 +60,62 @@ _R = TypeVar("_R")
def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]:
@wraps(fn)
def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
pynvml.nvmlInit()
try:
return fn(*args, **kwargs)
finally:
pynvml.nvmlShutdown()
if torch.version.hip:
try:
amdsmi_init()
return fn(*args, **kwargs)
finally:
amdsmi_shut_down()
else:
pynvml.nvmlInit()
try:
return fn(*args, **kwargs)
finally:
pynvml.nvmlShutdown()
return wrapper
@with_nvml_context
def is_full_nvlink(physical_device_ids: List[int]) -> bool:
"""
query if the set of gpus are fully connected by nvlink (1 hop)
"""
handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in physical_device_ids]
for i, handle in enumerate(handles):
for j, peer_handle in enumerate(handles):
if i < j:
try:
p2p_status = pynvml.nvmlDeviceGetP2PStatus(
handle, peer_handle, pynvml.NVML_P2P_CAPS_INDEX_NVLINK
)
if p2p_status != pynvml.NVML_P2P_STATUS_OK:
def is_full_nvlink(physical_device_ids: List[int], world_size: int) -> bool:
if is_hip():
"""
query if the set of gpus are fully connected by xgmi (1 hop)
"""
handles = [amdsmi_get_processor_handles()[i] for i in physical_device_ids]
for i, handle in enumerate(handles):
for j, peer_handle in enumerate(handles):
if i < j:
try:
link_type = amdsmi_topo_get_link_type(handle, peer_handle)
# type is 2 for XGMI
if link_type["hops"] != 1 or link_type["type"] != 2:
return False
except AmdSmiException as error:
logger.error("AMD 1 hop XGMI detection failed.", exc_info=error)
return False
except pynvml.NVMLError:
logger.exception(
"NVLink detection failed. This is normal if your"
" machine has no NVLink equipped."
)
return False
return True
return True
else:
"""
query if the set of gpus are fully connected by nvlink (1 hop)
"""
handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in physical_device_ids]
for i, handle in enumerate(handles):
for j, peer_handle in enumerate(handles):
if i < j:
try:
p2p_status = pynvml.nvmlDeviceGetP2PStatus(
handle, peer_handle, pynvml.NVML_P2P_CAPS_INDEX_NVLINK
)
if p2p_status != pynvml.NVML_P2P_STATUS_OK:
return False
except pynvml.NVMLError:
logger.exception(
"NVLink detection failed. This is normal if your"
" machine has no NVLink equipped."
)
return False
return True
def _can_p2p(rank: int, world_size: int) -> bool:
......@@ -102,15 +140,18 @@ def is_weak_contiguous(inp: torch.Tensor):
class CustomAllreduce:
_SUPPORTED_WORLD_SIZES = [2, 4, 6, 8]
_MAX_CAR_SIZE = 8192 * 1024
if is_hip():
# crossover is at 16MB buffer size for ROCm
_MAX_CAR_SIZE = 2 * 8192 * 1024
# max_size: max supported allreduce size
def __init__(
self,
group: ProcessGroup,
device: Union[int, str, torch.device],
max_size=8192 * 1024,
max_size=_MAX_CAR_SIZE,
) -> None:
"""
Args:
......@@ -185,12 +226,9 @@ class CustomAllreduce:
# test nvlink first, this will filter out most of the cases
# where custom allreduce is not supported
# this checks hardware and driver support for NVLink
if is_cuda():
assert is_cuda()
if is_cuda() or is_hip():
full_nvlink = is_full_nvlink(physical_device_ids, world_size)
full_nvlink = is_full_nvlink(physical_device_ids)
else:
full_nvlink = False
if world_size > 2 and not full_nvlink:
logger.warning(
"Custom allreduce is disabled because it's not supported on"
......@@ -201,7 +239,8 @@ class CustomAllreduce:
# test P2P capability, this checks software/cudaruntime support
# this is expensive to compute at the first time
# then we cache the result
if not _can_p2p(rank, world_size):
# On AMD GPU, p2p is always enabled between XGMI connected GPUs
if not is_hip() and not _can_p2p(rank, world_size):
logger.warning(
"Custom allreduce is disabled because your platform lacks "
"GPU P2P capability or P2P test failed. To silence this "
......@@ -214,7 +253,7 @@ class CustomAllreduce:
self.world_size = world_size
self.full_nvlink = full_nvlink
if ops.use_vllm_custom_allreduce:
if ops.use_vllm_custom_allreduce and not is_hip():
# Buffers memory are owned by this Python class and passed to C++.
# Meta data composes of two parts: meta data for synchronization and a
# temporary buffer for storing intermediate allreduce results.
......@@ -237,35 +276,56 @@ class CustomAllreduce:
)
ops.register_buffer(self._ptr, self.buffer_ptrs)
else:
# From TensorRT-LLM getMaxRequiredWorkspaceSize
self.max_required_workspace_size = [16 * 1024 * 1024, 8 * 1024 * 1024]
if is_hip():
# meta data buffers need to be "uncached" for signal on MI200
self.meta = ops.allocate_meta_buffer(ops.meta_size() + max_size)
self.buffer = torch.empty(
max_size, dtype=torch.uint8, device=self.device
)
handle = ops.get_meta_buffer_ipc_handle(self.meta)
shard_data = (
bytes(handle), # ipc handle to base ptr
0, # offset of base ptr
)
handles, offsets = self._gather_ipc_meta(shard_data)
self.rank_data = torch.empty(
8 * 1024 * 1024, dtype=torch.uint8, device=self.device
)
self._ptr = ops.init_custom_ar(
self.meta, self.rank_data, handles, offsets, rank, self.full_nvlink
)
self.register_buffer(self.buffer)
self.MSCCL = os.getenv("RCCL_MSCCL_ENABLE", "1") == "1"
else:
# From TensorRT-LLM getMaxRequiredWorkspaceSize
self.max_required_workspace_size = [16 * 1024 * 1024, 8 * 1024 * 1024]
# sizeof(uint32_t) * (MAX_ALL_REDUCE_BLOCKS + 2) * MAX_RANKS_PER_NODE;
self.barrier_max_size = 8 * (36 + 2) * 8
# sizeof(uint32_t) * (MAX_ALL_REDUCE_BLOCKS + 2) * MAX_RANKS_PER_NODE;
self.barrier_max_size = 8 * (36 + 2) * 8
self.buffer_ptrs = self.create_shared_buffer(max_size, group=group)
self.tmp_result_buffer_ptrs = self.create_shared_buffer(
max_size, group=group
)
self.rank_data_base = torch.empty(
8 * 1024 * 1024, dtype=torch.uint8, device=self.device
)
self.barrier_in_ptrs = self.create_shared_buffer(
self.barrier_max_size, group=group
)
self.barrier_out_ptrs = self.create_shared_buffer(
self.barrier_max_size, group=group
)
self.buffer_ptrs = self.create_shared_buffer(max_size, group=group)
self.tmp_result_buffer_ptrs = self.create_shared_buffer(
max_size, group=group
)
self.rank_data_base = torch.empty(
8 * 1024 * 1024, dtype=torch.uint8, device=self.device
)
self.barrier_in_ptrs = self.create_shared_buffer(
self.barrier_max_size, group=group
)
self.barrier_out_ptrs = self.create_shared_buffer(
self.barrier_max_size, group=group
)
self._ptr = ops.init_custom_ar(
rank,
world_size,
self.rank_data_base,
self.buffer_ptrs,
self.tmp_result_buffer_ptrs,
self.barrier_in_ptrs,
self.barrier_out_ptrs,
)
self._ptr = ops.init_custom_ar(
rank,
world_size,
self.rank_data_base,
self.buffer_ptrs,
self.tmp_result_buffer_ptrs,
self.barrier_in_ptrs,
self.barrier_out_ptrs,
)
self.disabled = False
@staticmethod
......@@ -316,23 +376,69 @@ class CustomAllreduce:
if not self.disabled:
self.register_graph_buffers()
def register_graph_buffers(self):
handle, offset = ops.get_graph_buffer_ipc_meta(self._ptr)
logger.info("Registering %d cuda graph addresses", len(offset))
# We cannot directly use `dist.all_gather_object` here
# because it is incompatible with `gloo` backend under inference mode.
# see https://github.com/pytorch/pytorch/issues/126032 for details.
all_data = [[None, None] for _ in range(dist.get_world_size(group=self.group))]
all_data[self.rank] = [handle, offset]
ranks = sorted(dist.get_process_group_ranks(group=self.group))
def _get_ipc_meta(self, inp: torch.Tensor):
# _share_cuda_() doesn't accept meta buffer not allocated from
# PyTorch cache allocator, use direct HIP call to get IPC handle
handle = ops.get_meta_buffer_ipc_handle(inp)
shard_data = (
bytes(handle), # ipc handle to base ptr
0, # offset of base ptr
)
return self._gather_ipc_meta(shard_data)
def _gather_ipc_meta(self, shard_data):
# Note: don't use `[[None]] * self.world_size` here
# because it will create a list of the same reference
all_data: List[Optional[Any]] = [[None] for i in range(self.world_size)]
all_data[self.rank][0] = shard_data
ranks = dist.get_process_group_ranks(group=self.group)
ranks.sort()
for i, rank in enumerate(ranks):
dist.broadcast_object_list(
all_data[i], src=rank, group=self.group, device="cpu"
)
# Unpack list of tuples to tuple of lists.
handles = [d[0] for d in all_data] # type: ignore
offsets = [d[1] for d in all_data] # type: ignore
ops.register_graph_buffers(self._ptr, handles, offsets)
# we cannot directly use `dist.all_gather_object` here
# because it is incompatible with `gloo` backend under inference mode.
# see https://github.com/pytorch/pytorch/issues/126032 for details.
handles = []
offsets = []
for i in range(len(all_data)):
handles.append(all_data[i][0][0]) # type: ignore
offsets.append(all_data[i][0][1]) # type: ignore
return handles, offsets
def register_buffer(self, inp: torch.Tensor):
handles, offsets = self._get_ipc_meta(inp)
ops.register_buffer(self._ptr, inp, handles, offsets)
def register_graph_buffers(self):
if is_hip():
handle, offset = ops.get_graph_buffer_ipc_meta(self._ptr)
handles, offsets = self._gather_ipc_meta((bytes(handle), offset))
logger.info("Registering %d cuda graph addresses", len(offset))
ops.register_graph_buffers(self._ptr, handles, offsets)
else:
handle, offset = ops.get_graph_buffer_ipc_meta(self._ptr)
logger.info("Registering %d cuda graph addresses", len(offset))
# We cannot directly use `dist.all_gather_object` here
# because it is incompatible with `gloo` backend under inference mode.
# see https://github.com/pytorch/pytorch/issues/126032 for details.
all_data = [
[None, None] for _ in range(dist.get_world_size(group=self.group))
]
all_data[self.rank] = [handle, offset]
ranks = sorted(dist.get_process_group_ranks(group=self.group))
for i, rank in enumerate(ranks):
dist.broadcast_object_list(
all_data[i], src=rank, group=self.group, device="cpu"
)
# Unpack list of tuples to tuple of lists.
handles = [d[0] for d in all_data] # type: ignore
offsets = [d[1] for d in all_data] # type: ignore
ops.register_graph_buffers(self._ptr, handles, offsets)
def should_custom_ar(self, inp: torch.Tensor):
if self.disabled:
......@@ -345,11 +451,22 @@ class CustomAllreduce:
return False
# for 4 or more non NVLink-capable GPUs, custom allreduce provides
# little performance improvement over NCCL.
if ops.use_vllm_custom_allreduce:
if ops.use_vllm_custom_allreduce and not is_hip():
if self.world_size == 2 or self.full_nvlink:
return inp_size < self.max_size
return False
if is_hip():
if self.full_nvlink:
if self.world_size == 8:
if self.MSCCL:
return False
else:
return inp_size < self.max_size
else:
return inp_size < self.max_size
return False
if self.world_size == 2:
return (
inp_size < self.max_size
......@@ -364,6 +481,21 @@ class CustomAllreduce:
return False
# all reduce, assuming inp tensor is IPC registered with register_buffer,
# or, in the context of cuda graphs, register_graph_buffers
def all_reduce_reg(self, inp: torch.Tensor, out: torch.Tensor = None):
if out is None:
out = torch.empty_like(inp)
ops.all_reduce_reg(self._ptr, inp, out)
return out
# all reduce, assuming inp tensor is NOT IPC registered
def all_reduce_unreg(self, inp: torch.Tensor, out: torch.Tensor = None):
if out is None:
out = torch.empty_like(inp)
ops.all_reduce_unreg(self._ptr, inp, self.buffer, out)
return out
def all_reduce(
self,
inp: torch.Tensor,
......@@ -397,13 +529,23 @@ class CustomAllreduce:
return None
if self._IS_CAPTURING:
if torch.cuda.is_current_stream_capturing():
return self.all_reduce(input, registered=True)
if is_hip():
return self.all_reduce_reg(input)
else:
return self.all_reduce(input, registered=True)
else:
# If warm up, mimic the allocation pattern since custom
# allreduce is out-of-place.
return torch.empty_like(input)
else:
return self.all_reduce(input, registered=False)
if is_hip():
# note: outside of cuda graph context,
# custom allreduce incurs a cost of cudaMemcpy, which should
# be small(<=1% of overall latency) compared to the performance
# gains of using custom kernels
return self.all_reduce_unreg(input)
else:
return self.all_reduce(input, registered=False)
def close(self):
if not self.disabled and self._ptr:
......@@ -411,7 +553,7 @@ class CustomAllreduce:
if ops.use_vllm_custom_allreduce:
self.free_shared_buffer(self.meta_ptrs)
self.free_shared_buffer(self.buffer_ptrs)
else:
elif is_cuda():
self.free_shared_buffer(self.buffer_ptrs)
self.free_shared_buffer(self.tmp_result_buffer_ptrs)
self.free_shared_buffer(self.barrier_in_ptrs)
......
......@@ -44,6 +44,7 @@ include_dirs = [
sources = [
"src/sgl-kernel/torch_extension_rocm.cc",
"src/sgl-kernel/csrc/moe_align_kernel.cu",
"src/sgl-kernel/csrc/custom_all_reduce.hip",
]
cxx_flags = ["-O3"]
......
import ctypes
import os
import torch
if os.path.exists("/usr/local/cuda/targets/x86_64-linux/lib/libcudart.so.12"):
ctypes.CDLL(
"/usr/local/cuda/targets/x86_64-linux/lib/libcudart.so.12",
mode=ctypes.RTLD_GLOBAL,
)
from .version import __version__
from sgl_kernel.ops import (
apply_rope_with_cos_sin_cache_inplace,
bmm_fp8,
build_tree_kernel,
build_tree_kernel_efficient,
cublas_grouped_gemm,
custom_dispose,
custom_reduce,
fp8_blockwise_scaled_mm,
fp8_scaled_mm,
fused_add_rmsnorm,
gelu_and_mul,
gelu_tanh_and_mul,
gemma_fused_add_rmsnorm,
gemma_rmsnorm,
get_graph_buffer_ipc_meta,
init_custom_reduce,
int8_scaled_mm,
lightning_attention_decode,
min_p_sampling_from_probs,
moe_align_block_size,
register_graph_buffers,
rmsnorm,
sampling_scaling_penalties,
sgl_per_token_group_quant_fp8,
silu_and_mul,
top_k_renorm_prob,
top_k_top_p_sampling_from_probs,
top_p_renorm_prob,
tree_speculative_sampling_target_only,
)
if torch.version.hip is not None:
from sgl_kernel.ops import (
all_reduce_reg,
all_reduce_unreg,
allocate_meta_buffer,
apply_rope_with_cos_sin_cache_inplace,
bmm_fp8,
dispose,
fp8_scaled_mm,
fused_add_rmsnorm,
gelu_and_mul,
gelu_tanh_and_mul,
gemma_fused_add_rmsnorm,
gemma_rmsnorm,
get_graph_buffer_ipc_meta,
get_meta_buffer_ipc_handle,
init_custom_ar,
int8_scaled_mm,
lightning_attention_decode,
meta_size,
min_p_sampling_from_probs,
moe_align_block_size,
register_buffer,
register_graph_buffers,
rmsnorm,
sampling_scaling_penalties,
silu_and_mul,
top_k_renorm_prob,
top_k_top_p_sampling_from_probs,
top_p_renorm_prob,
)
from .version import __version__
__all__ = [
"all_reduce_reg",
"all_reduce_unreg",
"allocate_meta_buffer",
"apply_rope_with_cos_sin_cache_inplace",
"bmm_fp8",
"dispose",
"fp8_scaled_mm",
"fused_add_rmsnorm",
"gelu_and_mul",
"gelu_tanh_and_mul",
"gemma_fused_add_rmsnorm",
"gemma_rmsnorm",
"get_graph_buffer_ipc_meta",
"get_meta_buffer_ipc_handle",
"init_custom_ar",
"int8_scaled_mm",
"lightning_attention_decode",
"meta_size",
"min_p_sampling_from_probs",
"moe_align_block_size",
"register_buffer",
"register_graph_buffers",
"rmsnorm",
"sampling_scaling_penalties",
"silu_and_mul",
"top_k_renorm_prob",
"top_k_top_p_sampling_from_probs",
"top_p_renorm_prob",
]
else:
from sgl_kernel.ops import (
apply_rope_with_cos_sin_cache_inplace,
bmm_fp8,
build_tree_kernel,
build_tree_kernel_efficient,
cublas_grouped_gemm,
custom_dispose,
custom_reduce,
fp8_blockwise_scaled_mm,
fp8_scaled_mm,
fused_add_rmsnorm,
gelu_and_mul,
gelu_tanh_and_mul,
gemma_fused_add_rmsnorm,
gemma_rmsnorm,
get_graph_buffer_ipc_meta,
init_custom_reduce,
int8_scaled_mm,
lightning_attention_decode,
min_p_sampling_from_probs,
moe_align_block_size,
register_graph_buffers,
rmsnorm,
sampling_scaling_penalties,
sgl_per_token_group_quant_fp8,
silu_and_mul,
top_k_renorm_prob,
top_k_top_p_sampling_from_probs,
top_p_renorm_prob,
tree_speculative_sampling_target_only,
)
__all__ = [
"apply_rope_with_cos_sin_cache_inplace",
"bmm_fp8",
"cublas_grouped_gemm",
"custom_dispose",
"custom_reduce",
"fp8_blockwise_scaled_mm",
"fp8_scaled_mm",
"fused_add_rmsnorm",
"gelu_and_mul",
"gelu_tanh_and_mul",
"gemma_fused_add_rmsnorm",
"gemma_rmsnorm",
"get_graph_buffer_ipc_meta",
"init_custom_reduce",
"int8_scaled_mm",
"lightning_attention_decode",
"min_p_sampling_from_probs",
"moe_align_block_size",
"register_graph_buffers",
"rmsnorm",
"sampling_scaling_penalties",
"silu_and_mul",
"top_k_renorm_prob",
"top_k_top_p_sampling_from_probs",
"top_p_renorm_prob",
"tree_speculative_sampling_target_only",
"build_tree_kernel_efficient",
"build_tree_kernel",
"sgl_per_token_group_quant_fp8",
]
__all__ = [
"apply_rope_with_cos_sin_cache_inplace",
"bmm_fp8",
"cublas_grouped_gemm",
"custom_dispose",
"custom_reduce",
"fp8_blockwise_scaled_mm",
"fp8_scaled_mm",
"fused_add_rmsnorm",
"gelu_and_mul",
"gelu_tanh_and_mul",
"gemma_fused_add_rmsnorm",
"gemma_rmsnorm",
"get_graph_buffer_ipc_meta",
"init_custom_reduce",
"int8_scaled_mm",
"lightning_attention_decode",
"min_p_sampling_from_probs",
"moe_align_block_size",
"register_graph_buffers",
"rmsnorm",
"sampling_scaling_penalties",
"silu_and_mul",
"top_k_renorm_prob",
"top_k_top_p_sampling_from_probs",
"top_p_renorm_prob",
"tree_speculative_sampling_target_only",
"build_tree_kernel_efficient",
"build_tree_kernel",
"sgl_per_token_group_quant_fp8",
]
// !!! This is a file automatically generated by hipify!!!
#include <ATen/hip/Exceptions.h>
#include <ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h>
#include <ATen/hip/impl/HIPStreamMasqueradingAsCUDA.h>
#include <torch/all.h>
#include "custom_all_reduce_hip.cuh"
// fake pointer type, must match fptr_t type in ops.h
using fptr_t = int64_t;
static_assert(sizeof(void*) == sizeof(fptr_t));
fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data,
const std::vector<std::string>& handles,
const std::vector<int64_t>& offsets, int64_t rank,
bool full_nvlink) {
int world_size = offsets.size();
if (world_size > 8)
throw std::invalid_argument("world size > 8 is not supported");
if (world_size % 2 != 0)
throw std::invalid_argument("Odd num gpus is not supported for now");
if (world_size != handles.size())
throw std::invalid_argument(
"handles length should equal to offsets length");
if (rank < 0 || rank >= world_size)
throw std::invalid_argument("invalid rank passed in");
hipIpcMemHandle_t ipc_handles[8];
for (int i = 0; i < world_size; i++) {
std::memcpy(&ipc_handles[i], handles[i].data(), sizeof(hipIpcMemHandle_t));
}
return (fptr_t) new vllm::CustomAllreduce(
reinterpret_cast<vllm::Signal*>(meta.data_ptr()), rank_data.data_ptr(),
rank_data.numel(), ipc_handles, offsets, rank, full_nvlink);
}
/**
* Make sure tensor t's data lies completely within ((char)t.data_ptr()) +
* t.numel() * t.element_size(). This is slightly weaker than t.is_contiguous()
* because it allows transpose of contiguous slice (i.e. slicing the first
* dimension). Currently, we require this because stride information is not
* passed into the kernels and we treat input tensors as flat.
*
* Examples
* A = torch.zeros(3, 3, 3)
* 1. A: OK
* 2. A[1:]: OK
* 3. A.permute(2, 0, 1): OK
* 4. A[1:].permute(2, 0, 1): OK
* 5. A[None].expand(2, -1, -1, -1): Not OK
* 6. A[:, 1:, 1:]: Not OK
*/
bool _is_weak_contiguous(torch::Tensor& t) {
return t.is_contiguous() ||
(t.storage().nbytes() - t.storage_offset() * t.element_size() ==
t.numel() * t.element_size());
}
void _all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out,
hipStream_t stream) {
auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
TORCH_CHECK(_is_weak_contiguous(out));
switch (out.scalar_type()) {
case at::ScalarType::Float: {
fa->allreduce<float>(stream, reinterpret_cast<float*>(inp.data_ptr()),
reinterpret_cast<float*>(out.data_ptr()),
out.numel());
break;
}
case at::ScalarType::Half: {
fa->allreduce<half>(stream, reinterpret_cast<half*>(inp.data_ptr()),
reinterpret_cast<half*>(out.data_ptr()), out.numel());
break;
}
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
case at::ScalarType::BFloat16: {
fa->allreduce<nv_bfloat16>(
stream, reinterpret_cast<nv_bfloat16*>(inp.data_ptr()),
reinterpret_cast<nv_bfloat16*>(out.data_ptr()), out.numel());
break;
}
#endif
default:
throw std::runtime_error(
"custom allreduce only supports float32, float16 and bfloat16");
}
}
void all_reduce_reg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out) {
const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(inp));
auto stream = c10::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream();
TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type());
TORCH_CHECK_EQ(inp.numel(), out.numel());
_all_reduce(_fa, inp, out, stream);
}
void all_reduce_unreg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& reg_buffer,
torch::Tensor& out) {
const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(inp));
auto stream = c10::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream();
auto input_size = inp.numel() * inp.element_size();
TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type());
TORCH_CHECK_EQ(inp.numel(), out.numel());
TORCH_CHECK(input_size <= reg_buffer.numel() * reg_buffer.element_size(),
"registered buffer is too small to contain the input");
AT_CUDA_CHECK(hipMemcpyAsync(reg_buffer.data_ptr(), inp.data_ptr(),
input_size, hipMemcpyDeviceToDevice, stream));
_all_reduce(_fa, reg_buffer, out, stream);
}
void dispose(fptr_t _fa) {
auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
delete fa;
}
int64_t meta_size() { return sizeof(vllm::Signal); }
void register_buffer(fptr_t _fa, torch::Tensor& t,
const std::vector<std::string>& handles,
const std::vector<int64_t>& offsets) {
auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
fa->register_buffer(handles, offsets, t.data_ptr());
}
std::tuple<torch::Tensor, std::vector<int64_t>> get_graph_buffer_ipc_meta(
fptr_t _fa) {
auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
auto [handle_bytes, offsets] = fa->get_graph_buffer_ipc_meta();
auto options =
torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCPU);
auto handles =
torch::empty({static_cast<int64_t>(handle_bytes.size())}, options);
std::memcpy(handles.data_ptr(), handle_bytes.data(), handle_bytes.size());
return {handles, std::move(offsets)};
}
void register_graph_buffers(fptr_t _fa, const std::vector<std::string>& handles,
const std::vector<std::vector<int64_t>>& offsets) {
auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
fa->register_graph_buffers(handles, offsets);
}
void free_meta_buffer(void* buffer) { CUDACHECK(hipFree(buffer)); }
torch::Tensor get_meta_buffer_ipc_handle(torch::Tensor& inp) {
auto options =
torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCPU);
auto data_handle =
torch::empty({static_cast<int64_t>(sizeof(hipIpcMemHandle_t))}, options);
CUDACHECK(hipIpcGetMemHandle((hipIpcMemHandle_t*)data_handle.data_ptr(),
inp.data_ptr()));
return data_handle;
}
torch::Tensor allocate_meta_buffer(int64_t size) {
auto device_index = c10::hip::current_device();
at::DeviceGuard device_guard(at::Device(at::DeviceType::CUDA, device_index));
void* buffer;
hipStreamCaptureMode mode = hipStreamCaptureModeRelaxed;
auto stream = c10::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream();
AT_CUDA_CHECK(hipThreadExchangeStreamCaptureMode(&mode));
AT_CUDA_CHECK(
hipExtMallocWithFlags((void**)&buffer, size, hipDeviceMallocUncached));
AT_CUDA_CHECK(hipMemsetAsync(buffer, 0, size, stream));
AT_CUDA_CHECK(hipStreamSynchronize(stream));
AT_CUDA_CHECK(hipThreadExchangeStreamCaptureMode(&mode));
auto options = torch::TensorOptions()
.dtype(torch::kI8)
.device(torch::kCUDA, device_index);
return torch::from_blob(buffer, {size}, free_meta_buffer, options);
}
std::vector<uint8_t> get_device_bdf(int dev) {
char busIdStr[] = "0000:00:00.0";
std::vector<uint8_t> bdf(sizeof(busIdStr), 0);
CUDACHECK(hipDeviceGetPCIBusId((char*)bdf.data(), sizeof(busIdStr), dev));
bdf.resize(bdf.size() - 1); // remove trailing NULL
return bdf;
}
This diff is collapsed.
......@@ -34,8 +34,23 @@ limitations under the License.
return PyModule_Create(&module); \
}
// trt_reduce
using fptr_t = int64_t;
#ifdef USE_ROCM
fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data, const std::vector<std::string>& handles,
const std::vector<int64_t>& offsets, int64_t rank, bool full_nvlink);
void all_reduce_reg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out);
void all_reduce_unreg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& reg_buffer, torch::Tensor& out);
void dispose(fptr_t _fa);
int64_t meta_size();
void register_buffer(fptr_t _fa, torch::Tensor& t, const std::vector<std::string>& handles,
const std::vector<int64_t>& offsets);
std::tuple<torch::Tensor, std::vector<int64_t>> get_graph_buffer_ipc_meta(fptr_t _fa);
void register_graph_buffers(fptr_t _fa, const std::vector<std::string>& handles,
const std::vector<std::vector<int64_t>>& offsets);
torch::Tensor allocate_meta_buffer(int64_t size);
torch::Tensor get_meta_buffer_ipc_handle(torch::Tensor& inp);
#else
// trt_reduce
fptr_t init_custom_ar(int64_t rank_id, int64_t world_size, torch::Tensor& rank_data, const std::vector<fptr_t>& buffers,
const std::vector<fptr_t>& tmp_result_buffers, const std::vector<fptr_t>& barrier_in,
const std::vector<fptr_t>& barrier_out);
......@@ -44,6 +59,7 @@ void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out);
std::tuple<std::vector<int64_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta(fptr_t _fa);
void register_graph_buffers(fptr_t _fa, const std::vector<std::vector<int64_t>>& handles,
const std::vector<std::vector<int64_t>>& offsets);
#endif
// moe_align_block_size
void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, int64_t block_size,
......
......@@ -64,28 +64,79 @@ def apply_rope_with_cos_sin_cache_inplace(
)
def init_custom_reduce(
rank_id, num_devices, rank_data, buffers, tmp_buffers, barrier_in, barrier_out
):
return torch.ops.sgl_kernels.init_custom_ar(
rank_id, num_devices, rank_data, buffers, tmp_buffers, barrier_in, barrier_out
)
if torch.version.hip is not None:
def init_custom_ar(
meta: torch.Tensor,
rank_data: torch.Tensor,
handles: List[str],
offsets: List[int],
rank: int,
full_nvlink: bool,
) -> int:
return torch.ops.sgl_kernels.init_custom_ar(
meta, rank_data, handles, offsets, rank, full_nvlink
)
def all_reduce_reg(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None:
torch.ops.sgl_kernels.all_reduce_reg(fa, inp, out)
def all_reduce_unreg(
fa: int, inp: torch.Tensor, reg_buffer: torch.Tensor, out: torch.Tensor
) -> None:
torch.ops.sgl_kernels.all_reduce_unreg(fa, inp, reg_buffer, out)
def dispose(fa: int) -> None:
torch.ops.sgl_kernels.dispose(fa)
def custom_dispose(fa):
torch.ops.sgl_kernels.dispose(fa)
def meta_size() -> int:
return torch.ops.sgl_kernels.meta_size()
def register_buffer(
fa: int, t: torch.Tensor, handles: List[str], offsets: List[int]
) -> None:
return torch.ops.sgl_kernels.register_buffer(fa, t, handles, offsets)
def custom_reduce(fa, inp, out):
torch.ops.sgl_kernels.all_reduce(fa, inp, out)
def get_graph_buffer_ipc_meta(fa: int) -> Tuple[torch.Tensor, List[int]]:
return torch.ops.sgl_kernels.get_graph_buffer_ipc_meta(fa)
def register_graph_buffers(
fa: int, handles: List[str], offsets: List[List[int]]
) -> None:
torch.ops.sgl_kernels.register_graph_buffers(fa, handles, offsets)
def allocate_meta_buffer(size: int) -> torch.Tensor:
return torch.ops.sgl_kernels.allocate_meta_buffer(size)
def get_meta_buffer_ipc_handle(inp: torch.Tensor) -> torch.Tensor:
return torch.ops.sgl_kernels.get_meta_buffer_ipc_handle(inp)
else:
# trt_reduce
def init_custom_reduce(
rank_id, num_devices, rank_data, buffers, tmp_buffers, barrier_in, barrier_out
):
return torch.ops.sgl_kernels.init_custom_ar(
rank_id,
num_devices,
rank_data,
buffers,
tmp_buffers,
barrier_in,
barrier_out,
)
def custom_dispose(fa):
torch.ops.sgl_kernels.dispose(fa)
def get_graph_buffer_ipc_meta(fa):
return torch.ops.sgl_kernels.get_graph_buffer_ipc_meta(fa)
def custom_reduce(fa, inp, out):
torch.ops.sgl_kernels.all_reduce(fa, inp, out)
def get_graph_buffer_ipc_meta(fa):
return torch.ops.sgl_kernels.get_graph_buffer_ipc_meta(fa)
def register_graph_buffers(fa, handles, offsets):
torch.ops.sgl_kernels.register_graph_buffers(fa, handles, offsets)
def register_graph_buffers(fa, handles, offsets):
torch.ops.sgl_kernels.register_graph_buffers(fa, handles, offsets)
def moe_align_block_size(
......
......@@ -19,6 +19,37 @@ limitations under the License.
#include "sgl_kernels_ops.h"
TORCH_LIBRARY_EXPAND(sgl_kernels, m) {
// Custom all-reduce kernels
m.def(
"init_custom_ar(Tensor meta, Tensor rank_data, "
"str[] handles, int[] offsets, int rank, "
"bool full_nvlink) -> int");
m.impl("init_custom_ar", torch::kCUDA, &init_custom_ar);
m.def("all_reduce_reg(int fa, Tensor inp, Tensor! out) -> ()");
m.impl("all_reduce_reg", torch::kCUDA, &all_reduce_reg);
m.def(
"all_reduce_unreg(int fa, Tensor inp, Tensor reg_buffer, Tensor! out) -> "
"()");
m.impl("all_reduce_unreg", torch::kCUDA, &all_reduce_unreg);
m.def("dispose", &dispose);
m.def("meta_size", &meta_size);
m.def(
"register_buffer(int fa, Tensor t, str[] handles, "
"int[] offsets) -> ()");
m.impl("register_buffer", torch::kCUDA, &register_buffer);
m.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta);
m.def("register_graph_buffers", &register_graph_buffers);
m.def("allocate_meta_buffer", &allocate_meta_buffer);
m.impl("allocate_meta_buffer", torch::kCUDA, &allocate_meta_buffer);
m.def("get_meta_buffer_ipc_handle", &get_meta_buffer_ipc_handle);
m.impl("get_meta_buffer_ipc_handle", torch::kCPU, &get_meta_buffer_ipc_handle);
// moe_align_block_size
m.def(
"moe_align_block_size(Tensor topk_ids, int num_experts, int block_size, Tensor! sorted_token_ids, Tensor! "
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment