Unverified Commit 767c9dec authored by yizhang2077, committed by GitHub

adapt custom allreduce for tensorrt llm (#2511)

parent a53454c5
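In short, this commit replaces the vLLM `_C_vllm_ar` custom-allreduce bindings with the TensorRT-LLM-style kernels shipped in `sgl-kernel` (bumped to `>=0.0.2.post14`), rewires `CustomAllreduce` to allocate the IPC input, temporary-result, and barrier buffers those kernels expect, and adds a Ray-based unit test. The sketch below is not part of the diff; it mirrors the new test's eager path and shows what the change enables from a caller's point of view, assuming a node with at least two CUDA GPUs and some launcher that invokes one copy per rank (port and tensor size are illustrative).

import torch
import torch.distributed as dist

from sglang.srt.distributed import init_distributed_environment
from sglang.srt.distributed.communication_op import tensor_model_parallel_all_reduce
from sglang.srt.distributed.parallel_state import (
    get_tensor_model_parallel_group,
    initialize_model_parallel,
)


def run(rank: int, world_size: int = 2, port: int = 29500) -> None:
    # Invoked once per GPU (e.g. by Ray or torch.multiprocessing) with ranks
    # 0..world_size-1; mirrors the setup in the new test file below.
    torch.cuda.set_device(rank)
    init_distributed_environment(
        world_size=world_size,
        rank=rank,
        distributed_init_method=f"tcp://localhost:{port}",
        local_rank=rank,
    )
    initialize_model_parallel(tensor_model_parallel_size=world_size)

    # Integer values so the custom-kernel result matches NCCL bit-exactly.
    inp = torch.randint(1, 16, (4096,), dtype=torch.float16, device=f"cuda:{rank}")
    out = tensor_model_parallel_all_reduce(inp)  # routes through CustomAllreduce when eligible

    # Cross-check against NCCL on the same data.
    dist.all_reduce(inp, group=get_tensor_model_parallel_group().device_group)
    torch.testing.assert_close(out, inp)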
@@ -27,7 +27,7 @@ runtime_common = [
 ]
 srt = [
     "sglang[runtime_common]", "cuda-python",
-    "sgl-kernel>=0.0.2.post12", "torch", "vllm>=0.6.3.post1,<=0.6.4.post1",
+    "sgl-kernel>=0.0.2.post14", "torch", "vllm>=0.6.3.post1,<=0.6.4.post1",
     "flashinfer==0.1.6"
 ]
...
-# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/_custom_ops.py
+# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/_custom_ops.py
 import contextlib
 import functools
 import importlib
@@ -14,7 +14,7 @@ logger = logging.getLogger(__name__)
 if not is_hpu():
     try:
-        import custom_ar
+        import sgl_kernel
     except ImportError as e:
         logger.warning("Failed to import from custom_ar with %r", e)
@@ -50,46 +50,41 @@ def hint_on_error(fn):
 # custom ar
 def init_custom_ar(
-    ipc_tensors: List[torch.Tensor],
-    rank_data: torch.Tensor,
-    rank: int,
-    full_nvlink: bool,
+    rank_id: int,
+    world_size: int,
+    rank_data_base: torch.Tensor,
+    buffers: List[int],
+    tmp_result_buffers: List[int],
+    barrier_in: List[int],
+    barrier_out: List[int],
 ) -> int:
-    return torch.ops._C_vllm_ar.init_custom_ar(
-        ipc_tensors, rank_data, rank, full_nvlink
-    )
+    return sgl_kernel.ops.init_custom_reduce(
+        rank_id,
+        world_size,
+        rank_data_base,
+        buffers,
+        tmp_result_buffers,
+        barrier_in,
+        barrier_out,
+    )

-def all_reduce(
-    fa: int,
-    inp: torch.Tensor,
-    out: torch.Tensor,
-    reg_buffer: int,
-    reg_buffer_sz_bytes: int,
-) -> None:
-    torch.ops._C_vllm_ar.all_reduce(fa, inp, out, reg_buffer, reg_buffer_sz_bytes)
+def all_reduce(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None:
+    sgl_kernel.ops.custom_reduce(fa, inp, out)

 def dispose(fa: int) -> None:
-    torch.ops._C_vllm_ar.dispose(fa)
+    sgl_kernel.ops.custom_dispose(fa)

-def meta_size() -> int:
-    return torch.ops._C_vllm_ar.meta_size()

-def register_buffer(fa: int, ipc_tensors: List[int]) -> None:
-    return torch.ops._C_vllm_ar.register_buffer(fa, ipc_tensors)

 def get_graph_buffer_ipc_meta(fa: int) -> Tuple[List[int], List[int]]:
-    return torch.ops._C_vllm_ar.get_graph_buffer_ipc_meta(fa)
+    return sgl_kernel.ops.get_graph_buffer_ipc_meta(fa)

 def register_graph_buffers(
     fa: int, handles: List[List[int]], offsets: List[List[int]]
 ) -> None:
-    torch.ops._C_vllm_ar.register_graph_buffers(fa, handles, offsets)
+    sgl_kernel.ops.register_graph_buffers(fa, handles, offsets)

 # temporary fix for https://github.com/vllm-project/vllm/issues/5456
...
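The renamed entry points are few: `torch.ops._C_vllm_ar.{init_custom_ar, all_reduce, dispose, get_graph_buffer_ipc_meta, register_graph_buffers}` become `sgl_kernel.ops.{init_custom_reduce, custom_reduce, custom_dispose, get_graph_buffer_ipc_meta, register_graph_buffers}`, while `meta_size` and `register_buffer` disappear because the new kernels take all shared buffers up front. Below is a minimal sketch, not part of the commit, of the call order `CustomAllreduce` drives through these wrappers; the IPC pointer lists are taken as parameters rather than allocated here (in the real code they come from `CustomAllreduce.create_shared_buffer`), and the import path `sglang.srt._custom_ops` is an assumption.

from typing import List

import torch

from sglang.srt import _custom_ops as ops  # assumed module path for the wrappers above


def run_custom_allreduce(
    rank: int,
    world_size: int,
    buffers: List[int],              # per-rank IPC input buffers
    tmp_result_buffers: List[int],   # per-rank temporary result buffers
    barrier_in: List[int],           # per-rank barrier-in flag buffers
    barrier_out: List[int],          # per-rank barrier-out flag buffers
    inp: torch.Tensor,
) -> torch.Tensor:
    # Scratch buffer handed to the kernel (mirrors rank_data_base in CustomAllreduce).
    rank_data = torch.empty(8 * 1024 * 1024, dtype=torch.uint8, device=inp.device)
    fa = ops.init_custom_ar(
        rank, world_size, rank_data, buffers,
        tmp_result_buffers, barrier_in, barrier_out,
    )
    try:
        out = torch.empty_like(inp)
        ops.all_reduce(fa, inp, out)  # TRT-LLM style custom allreduce kernel
        return out
    finally:
        ops.dispose(fa)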
@@ -21,7 +21,8 @@ from sglang.srt.distributed.parallel_state import in_the_same_node_as
 from sglang.srt.utils import cuda_device_count_stateless, is_cuda

 try:
-    ops.meta_size()
+    import sgl_kernel
+
     custom_ar = True
 except Exception:
     # For AMD GPUs and CPUs
@@ -29,7 +30,6 @@ except Exception:
 logger = logging.getLogger(__name__)
 _P = ParamSpec("_P")
 _R = TypeVar("_R")
@@ -47,7 +47,7 @@ def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]:
 @with_nvml_context
-def is_full_nvlink(cls, physical_device_ids: List[int]) -> bool:
+def is_full_nvlink(physical_device_ids: List[int]) -> bool:
     """
     query if the set of gpus are fully connected by nvlink (1 hop)
     """
@@ -196,32 +196,39 @@ class CustomAllreduce:
             )
             return

-        self.disabled = False
-        # Buffers memory are owned by this Python class and passed to C++.
-        # Meta data composes of two parts: meta data for synchronization and a
-        # temporary buffer for storing intermediate allreduce results.
-        self.meta_ptrs = self.create_shared_buffer(
-            ops.meta_size() + max_size, group=group
-        )
-        # This is a pre-registered IPC buffer. In eager mode, input tensors
-        # are first copied into this buffer before allreduce is performed
-        self.buffer_ptrs = self.create_shared_buffer(max_size, group=group)
-        # This is a buffer for storing the tuples of pointers pointing to
-        # IPC buffers from all ranks. Each registered tuple has size of
-        # 8*world_size bytes where world_size is at most 8. Allocating 8MB
-        # is enough for 131072 such tuples. The largest model I've seen only
-        # needs less than 10000 of registered tuples.
-        self.rank_data = torch.empty(
-            8 * 1024 * 1024, dtype=torch.uint8, device=self.device
-        )
         self.max_size = max_size
         self.rank = rank
         self.world_size = world_size
         self.full_nvlink = full_nvlink
+        # From TensorRT-LLM getMaxRequiredWorkspaceSize
+        self.max_required_workspace_size = [16 * 1024 * 1024, 8 * 1024 * 1024]
+        # sizeof(uint32_t) * (MAX_ALL_REDUCE_BLOCKS + 2) * MAX_RANKS_PER_NODE;
+        self.barrier_max_size = 8 * (36 + 2) * 8
+        self.buffer_ptrs = self.create_shared_buffer(max_size, group=group)
+        self.tmp_result_buffer_ptrs = self.create_shared_buffer(max_size, group=group)
+        self.rank_data_base = torch.empty(
+            8 * 1024 * 1024, dtype=torch.uint8, device=self.device
+        )
+        self.barrier_in_ptrs = self.create_shared_buffer(
+            self.barrier_max_size, group=group
+        )
+        self.barrier_out_ptrs = self.create_shared_buffer(
+            self.barrier_max_size, group=group
+        )
         self._ptr = ops.init_custom_ar(
-            self.meta_ptrs, self.rank_data, rank, self.full_nvlink
+            rank,
+            world_size,
+            self.rank_data_base,
+            self.buffer_ptrs,
+            self.tmp_result_buffer_ptrs,
+            self.barrier_in_ptrs,
+            self.barrier_out_ptrs,
         )
-        ops.register_buffer(self._ptr, self.buffer_ptrs)
+        self.disabled = False

     @staticmethod
     def create_shared_buffer(
@@ -300,12 +307,25 @@ class CustomAllreduce:
             return False
         # for 4 or more non NVLink-capable GPUs, custom allreduce provides
        # little performance improvement over NCCL.
-        if self.world_size == 2 or self.full_nvlink:
-            return inp_size < self.max_size
+        if self.world_size == 2:
+            return (
+                inp_size < self.max_size
+                and inp_size < self.max_required_workspace_size[0]
+            )
+        if self.full_nvlink:
+            return (
+                inp_size < self.max_size
+                and inp_size < self.max_required_workspace_size[1]
+            )
         return False

     def all_reduce(
-        self, inp: torch.Tensor, *, out: torch.Tensor = None, registered: bool = False
+        self,
+        inp: torch.Tensor,
+        *,
+        out: torch.Tensor = None,
     ):
         """Performs an out-of-place all reduce.
@@ -315,12 +335,7 @@ class CustomAllreduce:
         """
         if out is None:
             out = torch.empty_like(inp)
-        if registered:
-            ops.all_reduce(self._ptr, inp, out, 0, 0)
-        else:
-            ops.all_reduce(
-                self._ptr, inp, out, self.buffer_ptrs[self.rank], self.max_size
-            )
+        ops.all_reduce(self._ptr, inp, out)
         return out

     def custom_all_reduce(self, input: torch.Tensor) -> Optional[torch.Tensor]:
@@ -330,23 +345,22 @@ class CustomAllreduce:
             return None
         if self._IS_CAPTURING:
             if torch.cuda.is_current_stream_capturing():
-                return self.all_reduce(input, registered=True)
+                return self.all_reduce(input)
             else:
                 # If warm up, mimic the allocation pattern since custom
                 # allreduce is out-of-place.
                 return torch.empty_like(input)
         else:
-            # Note: outside of cuda graph context, custom allreduce incurs a
-            # cost of cudaMemcpy, which should be small (<=1% of overall
-            # latency) compared to the performance gain of using custom kernels
-            return self.all_reduce(input, registered=False)
+            return self.all_reduce(input)

     def close(self):
         if not self.disabled and self._ptr:
             ops.dispose(self._ptr)
-            self._ptr = 0
-            self.free_shared_buffer(self.meta_ptrs)
             self.free_shared_buffer(self.buffer_ptrs)
+            self.free_shared_buffer(self.tmp_result_buffer_ptrs)
+            self.free_shared_buffer(self.barrier_in_ptrs)
+            self.free_shared_buffer(self.barrier_out_ptrs)
+            self._ptr = 0

     def __del__(self):
         self.close()
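The eligibility test above is the behavioral core of the change: instead of staging inputs through a pre-registered buffer, the input must now fit both the class-level max_size and a TensorRT-LLM workspace cap that depends on topology. The self-contained sketch below, not part of the commit, reproduces that gating and the barrier-buffer arithmetic using only constants visible in this diff; the guards earlier in the real method are omitted, and the 8 MiB max_size in the usage check is an illustrative value, not a documented default.

MiB = 1024 * 1024

# From TensorRT-LLM getMaxRequiredWorkspaceSize: index 0 applies to the
# world_size == 2 path, index 1 to the full-NVLink path.
MAX_REQUIRED_WORKSPACE_SIZE = [16 * MiB, 8 * MiB]

# Barrier flag buffer: 8 * (36 + 2) * 8 = 2432 bytes, allocated once for
# barrier_in and once for barrier_out in CustomAllreduce.__init__.
BARRIER_MAX_SIZE = 8 * (36 + 2) * 8


def should_use_custom_allreduce(
    inp_size: int, max_size: int, world_size: int, full_nvlink: bool
) -> bool:
    """Mirror of the new size gating in CustomAllreduce.should_custom_ar."""
    if world_size == 2:
        return inp_size < max_size and inp_size < MAX_REQUIRED_WORKSPACE_SIZE[0]
    if full_nvlink:
        return inp_size < max_size and inp_size < MAX_REQUIRED_WORKSPACE_SIZE[1]
    # 4+ GPUs without full NVLink: fall back to NCCL.
    return False


if __name__ == "__main__":
    # Assuming an 8 MiB max_size: a 4 MiB tensor qualifies on a full-NVLink
    # node, while a 16 MiB tensor falls back to NCCL.
    assert should_use_custom_allreduce(4 * MiB, 8 * MiB, 8, True)
    assert not should_use_custom_allreduce(16 * MiB, 8 * MiB, 8, True)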
@@ -12,6 +12,7 @@ suites = {
     "sampling/penaltylib",
     "test_abort.py",
     "test_chunked_prefill.py",
+    "test_custom_allreduce.py",
     "test_double_sparsity.py",
     "test_eagle_infer.py",
     "test_embedding_openai_server.py",
...
import os
import random
import socket
import unittest
from typing import Any

import ray
import torch
import torch.distributed as dist

from sglang.srt.distributed import init_distributed_environment
from sglang.srt.distributed.communication_op import (  # noqa
    tensor_model_parallel_all_reduce,
)
from sglang.srt.distributed.parallel_state import (
    get_tensor_model_parallel_group,
    graph_capture,
    initialize_model_parallel,
)


def get_open_port() -> int:
    # try ipv4
    try:
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            s.bind(("", 0))
            return s.getsockname()[1]
    except OSError:
        # try ipv6
        with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s:
            s.bind(("", 0))
            return s.getsockname()[1]


def multi_process_parallel(
    world_size: int,
    cls: Any,
    test_target: Any,
) -> None:
    # Using ray helps debugging the error when it failed
    # as compared to multiprocessing.
    # NOTE: We need to set working_dir for distributed tests,
    # otherwise we may get import errors on ray workers
    ray.init(log_to_driver=False)

    distributed_init_port = get_open_port()
    refs = []
    for rank in range(world_size):
        refs.append(test_target.remote(cls, world_size, rank, distributed_init_port))
    ray.get(refs)

    ray.shutdown()


class TestCustomAllReduce(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        random.seed(42)
        # 512B to 32MB
        cls.test_sizes = [512, 4096, 32768, 262144, 2097152, 16777216, 33554432]
        cls.world_sizes = [2, 4, 6, 8]
        cls.test_loop = 10

    def test_graph_allreduce(self):
        for world_size in self.world_sizes:
            if world_size > torch.cuda.device_count():
                continue
            multi_process_parallel(world_size, self, self.graph_allreduce)

    def test_eager_allreduce(self):
        for world_size in self.world_sizes:
            if world_size > torch.cuda.device_count():
                continue
            multi_process_parallel(world_size, self, self.eager_allreduce)

    @ray.remote(num_gpus=1, max_calls=1)
    def graph_allreduce(self, world_size, rank, distributed_init_port):
        del os.environ["CUDA_VISIBLE_DEVICES"]
        device = torch.device(f"cuda:{rank}")
        torch.cuda.set_device(device)
        distributed_init_method = f"tcp://localhost:{distributed_init_port}"
        init_distributed_environment(
            world_size=world_size,
            rank=rank,
            distributed_init_method=distributed_init_method,
            local_rank=rank,
        )
        initialize_model_parallel(tensor_model_parallel_size=world_size)
        group = get_tensor_model_parallel_group().device_group

        # A small all_reduce for warmup.
        # this is needed because device communicators might be created lazily
        # (e.g. NCCL). This will ensure that the communicator is initialized
        # before any communication happens, so that this group can be used for
        # graph capture immediately.
        data = torch.zeros(1)
        data = data.to(device=device)
        torch.distributed.all_reduce(data, group=group)
        torch.cuda.synchronize()
        del data

        for sz in self.test_sizes:
            for dtype in [torch.float32, torch.float16, torch.bfloat16]:
                for _ in range(self.test_loop):
                    with graph_capture() as graph_capture_context:
                        # use integers so result matches NCCL exactly
                        inp1 = torch.randint(
                            1,
                            16,
                            (sz,),
                            dtype=dtype,
                            device=torch.cuda.current_device(),
                        )
                        inp2 = torch.randint(
                            1,
                            16,
                            (sz,),
                            dtype=dtype,
                            device=torch.cuda.current_device(),
                        )
                        torch.cuda.synchronize()
                        graph = torch.cuda.CUDAGraph()
                        with torch.cuda.graph(
                            graph, stream=graph_capture_context.stream
                        ):
                            out1 = tensor_model_parallel_all_reduce(inp1)
                            # the input buffer is immediately modified to test
                            # synchronization
                            dist.all_reduce(inp1, group=group)
                            out2 = tensor_model_parallel_all_reduce(inp2)
                            dist.all_reduce(inp2, group=group)
                    graph.replay()
                    torch.testing.assert_close(out1, inp1)
                    torch.testing.assert_close(out2, inp2)

    @ray.remote(num_gpus=1, max_calls=1)
    def eager_allreduce(self, world_size, rank, distributed_init_port):
        del os.environ["CUDA_VISIBLE_DEVICES"]
        device = torch.device(f"cuda:{rank}")
        torch.cuda.set_device(device)
        distributed_init_method = f"tcp://localhost:{distributed_init_port}"
        init_distributed_environment(
            world_size=world_size,
            rank=rank,
            distributed_init_method=distributed_init_method,
            local_rank=rank,
        )
        initialize_model_parallel(tensor_model_parallel_size=world_size)
        group = get_tensor_model_parallel_group().device_group

        for sz in self.test_sizes:
            for dtype in [torch.float32, torch.float16, torch.bfloat16]:
                for _ in range(self.test_loop):
                    inp1 = torch.randint(
                        1, 16, (sz,), dtype=dtype, device=torch.cuda.current_device()
                    )
                    out1 = tensor_model_parallel_all_reduce(inp1)
                    dist.all_reduce(inp1, group=group)
                    torch.testing.assert_close(out1, inp1)


if __name__ == "__main__":
    unittest.main()