"vscode:/vscode.git/clone" did not exist on "c63ec817501bf188abd38a4fc3797af8b829aa20"
Unverified Commit 24cafe31 authored by yizhang2077's avatar yizhang2077 Committed by GitHub
Browse files

add config to swtich from vllm custom allreduce to sgl_kernel custom allreduce (#2981)

parent 5a176c92
...@@ -3,6 +3,7 @@ import contextlib ...@@ -3,6 +3,7 @@ import contextlib
import functools import functools
import importlib import importlib
import logging import logging
import os
from typing import TYPE_CHECKING, List, Optional, Tuple, Union from typing import TYPE_CHECKING, List, Optional, Tuple, Union
import torch import torch
...@@ -11,8 +12,15 @@ import torch.library ...@@ -11,8 +12,15 @@ import torch.library
from sglang.srt.utils import is_hpu from sglang.srt.utils import is_hpu
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
use_vllm_custom_allreduce = os.environ.get("USE_VLLM_CUSTOM_ALLREDUCE", default=False)
if not is_hpu(): if not is_hpu():
if use_vllm_custom_allreduce:
try:
import vllm._C
except ImportError as e:
logger.warning("Failed to import from vllm._C with %r", e)
else:
try: try:
import sgl_kernel import sgl_kernel
except ImportError as e: except ImportError as e:
...@@ -48,8 +56,47 @@ def hint_on_error(fn): ...@@ -48,8 +56,47 @@ def hint_on_error(fn):
return wrapper return wrapper
# custom ar if use_vllm_custom_allreduce:
def init_custom_ar( # custom ar
def init_custom_ar(
ipc_tensors: List[torch.Tensor],
rank_data: torch.Tensor,
rank: int,
full_nvlink: bool,
) -> int:
return torch.ops._C_custom_ar.init_custom_ar(
ipc_tensors, rank_data, rank, full_nvlink
)
def all_reduce(
fa: int,
inp: torch.Tensor,
out: torch.Tensor,
reg_buffer: int,
reg_buffer_sz_bytes: int,
) -> None:
torch.ops._C_custom_ar.all_reduce(fa, inp, out, reg_buffer, reg_buffer_sz_bytes)
def dispose(fa: int) -> None:
torch.ops._C_custom_ar.dispose(fa)
def meta_size() -> int:
return torch.ops._C_custom_ar.meta_size()
def register_buffer(fa: int, ipc_tensors: List[int]) -> None:
return torch.ops._C_custom_ar.register_buffer(fa, ipc_tensors)
def get_graph_buffer_ipc_meta(fa: int) -> Tuple[List[int], List[int]]:
return torch.ops._C_custom_ar.get_graph_buffer_ipc_meta(fa)
def register_graph_buffers(
fa: int, handles: List[List[int]], offsets: List[List[int]]
) -> None:
torch.ops._C_custom_ar.register_graph_buffers(fa, handles, offsets)
else:
# custom ar
def init_custom_ar(
rank_id: int, rank_id: int,
world_size: int, world_size: int,
rank_data_base: torch.Tensor, rank_data_base: torch.Tensor,
...@@ -57,7 +104,7 @@ def init_custom_ar( ...@@ -57,7 +104,7 @@ def init_custom_ar(
tmp_result_buffers: List[int], tmp_result_buffers: List[int],
barrier_in: List[int], barrier_in: List[int],
barrier_out: List[int], barrier_out: List[int],
) -> int: ) -> int:
return sgl_kernel.ops.init_custom_reduce( return sgl_kernel.ops.init_custom_reduce(
rank_id, rank_id,
world_size, world_size,
...@@ -68,22 +115,18 @@ def init_custom_ar( ...@@ -68,22 +115,18 @@ def init_custom_ar(
barrier_out, barrier_out,
) )
def all_reduce(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None:
def all_reduce(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None:
sgl_kernel.ops.custom_reduce(fa, inp, out) sgl_kernel.ops.custom_reduce(fa, inp, out)
def dispose(fa: int) -> None:
def dispose(fa: int) -> None:
sgl_kernel.ops.custom_dispose(fa) sgl_kernel.ops.custom_dispose(fa)
def get_graph_buffer_ipc_meta(fa: int) -> Tuple[List[int], List[int]]:
def get_graph_buffer_ipc_meta(fa: int) -> Tuple[List[int], List[int]]:
return sgl_kernel.ops.get_graph_buffer_ipc_meta(fa) return sgl_kernel.ops.get_graph_buffer_ipc_meta(fa)
def register_graph_buffers(
def register_graph_buffers(
fa: int, handles: List[List[int]], offsets: List[List[int]] fa: int, handles: List[List[int]], offsets: List[List[int]]
) -> None: ) -> None:
sgl_kernel.ops.register_graph_buffers(fa, handles, offsets) sgl_kernel.ops.register_graph_buffers(fa, handles, offsets)
......
...@@ -21,8 +21,10 @@ from sglang.srt.distributed.parallel_state import in_the_same_node_as ...@@ -21,8 +21,10 @@ from sglang.srt.distributed.parallel_state import in_the_same_node_as
from sglang.srt.utils import cuda_device_count_stateless, is_cuda from sglang.srt.utils import cuda_device_count_stateless, is_cuda
try: try:
if ops.use_vllm_custom_allreduce:
ops.meta_size()
else:
import sgl_kernel import sgl_kernel
custom_ar = True custom_ar = True
except Exception: except Exception:
# For AMD GPUs and CPUs # For AMD GPUs and CPUs
...@@ -201,6 +203,29 @@ class CustomAllreduce: ...@@ -201,6 +203,29 @@ class CustomAllreduce:
self.world_size = world_size self.world_size = world_size
self.full_nvlink = full_nvlink self.full_nvlink = full_nvlink
if ops.use_vllm_custom_allreduce:
# Buffers memory are owned by this Python class and passed to C++.
# Meta data composes of two parts: meta data for synchronization and a
# temporary buffer for storing intermediate allreduce results.
self.meta_ptrs = self.create_shared_buffer(
ops.meta_size() + max_size, group=group
)
# This is a pre-registered IPC buffer. In eager mode, input tensors
# are first copied into this buffer before allreduce is performed
self.buffer_ptrs = self.create_shared_buffer(max_size, group=group)
# This is a buffer for storing the tuples of pointers pointing to
# IPC buffers from all ranks. Each registered tuple has size of
# 8*world_size bytes where world_size is at most 8. Allocating 8MB
# is enough for 131072 such tuples. The largest model I've seen only
# needs less than 10000 of registered tuples.
self.rank_data = torch.empty(
8 * 1024 * 1024, dtype=torch.uint8, device=self.device
)
self._ptr = ops.init_custom_ar(
self.meta_ptrs, self.rank_data, rank, self.full_nvlink
)
ops.register_buffer(self._ptr, self.buffer_ptrs)
else:
# From TensorRT-LLM getMaxRequiredWorkspaceSize # From TensorRT-LLM getMaxRequiredWorkspaceSize
self.max_required_workspace_size = [16 * 1024 * 1024, 8 * 1024 * 1024] self.max_required_workspace_size = [16 * 1024 * 1024, 8 * 1024 * 1024]
...@@ -208,7 +233,9 @@ class CustomAllreduce: ...@@ -208,7 +233,9 @@ class CustomAllreduce:
self.barrier_max_size = 8 * (36 + 2) * 8 self.barrier_max_size = 8 * (36 + 2) * 8
self.buffer_ptrs = self.create_shared_buffer(max_size, group=group) self.buffer_ptrs = self.create_shared_buffer(max_size, group=group)
self.tmp_result_buffer_ptrs = self.create_shared_buffer(max_size, group=group) self.tmp_result_buffer_ptrs = self.create_shared_buffer(
max_size, group=group
)
self.rank_data_base = torch.empty( self.rank_data_base = torch.empty(
8 * 1024 * 1024, dtype=torch.uint8, device=self.device 8 * 1024 * 1024, dtype=torch.uint8, device=self.device
) )
...@@ -307,6 +334,11 @@ class CustomAllreduce: ...@@ -307,6 +334,11 @@ class CustomAllreduce:
return False return False
# for 4 or more non NVLink-capable GPUs, custom allreduce provides # for 4 or more non NVLink-capable GPUs, custom allreduce provides
# little performance improvement over NCCL. # little performance improvement over NCCL.
if ops.use_vllm_custom_allreduce:
if self.world_size == 2 or self.full_nvlink:
return inp_size < self.max_size
return False
if self.world_size == 2: if self.world_size == 2:
return ( return (
inp_size < self.max_size inp_size < self.max_size
...@@ -326,6 +358,7 @@ class CustomAllreduce: ...@@ -326,6 +358,7 @@ class CustomAllreduce:
inp: torch.Tensor, inp: torch.Tensor,
*, *,
out: torch.Tensor = None, out: torch.Tensor = None,
registered: bool = False,
): ):
"""Performs an out-of-place all reduce. """Performs an out-of-place all reduce.
...@@ -335,6 +368,14 @@ class CustomAllreduce: ...@@ -335,6 +368,14 @@ class CustomAllreduce:
""" """
if out is None: if out is None:
out = torch.empty_like(inp) out = torch.empty_like(inp)
if ops.use_vllm_custom_allreduce:
if registered:
ops.all_reduce(self._ptr, inp, out, 0, 0)
else:
ops.all_reduce(
self._ptr, inp, out, self.buffer_ptrs[self.rank], self.max_size
)
else:
ops.all_reduce(self._ptr, inp, out) ops.all_reduce(self._ptr, inp, out)
return out return out
...@@ -345,17 +386,21 @@ class CustomAllreduce: ...@@ -345,17 +386,21 @@ class CustomAllreduce:
return None return None
if self._IS_CAPTURING: if self._IS_CAPTURING:
if torch.cuda.is_current_stream_capturing(): if torch.cuda.is_current_stream_capturing():
return self.all_reduce(input) return self.all_reduce(input, registered=True)
else: else:
# If warm up, mimic the allocation pattern since custom # If warm up, mimic the allocation pattern since custom
# allreduce is out-of-place. # allreduce is out-of-place.
return torch.empty_like(input) return torch.empty_like(input)
else: else:
return self.all_reduce(input) return self.all_reduce(input, registered=False)
def close(self): def close(self):
if not self.disabled and self._ptr: if not self.disabled and self._ptr:
ops.dispose(self._ptr) ops.dispose(self._ptr)
if ops.use_vllm_custom_allreduce:
self.free_shared_buffer(self.meta_ptrs)
self.free_shared_buffer(self.buffer_ptrs)
else:
self.free_shared_buffer(self.buffer_ptrs) self.free_shared_buffer(self.buffer_ptrs)
self.free_shared_buffer(self.tmp_result_buffer_ptrs) self.free_shared_buffer(self.tmp_result_buffer_ptrs)
self.free_shared_buffer(self.barrier_in_ptrs) self.free_shared_buffer(self.barrier_in_ptrs)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment