Unverified Commit 702bee46 authored by youkaichao's avatar youkaichao Committed by GitHub
Browse files

[Core][Distributed] refactor custom allreduce to support multiple tp groups (#4754)

parent a7be4d00
...@@ -16,7 +16,7 @@ from vllm.test_utils import (init_test_distributed_environment, ...@@ -16,7 +16,7 @@ from vllm.test_utils import (init_test_distributed_environment,
@ray.remote(num_gpus=1, max_calls=1) @ray.remote(num_gpus=1, max_calls=1)
def all_reduce_test_worker(tensor_parallel_size: int, rank: int, def all_reduce_test_worker(tp_size: int, pp_size: int, rank: int,
distributed_init_port: str): distributed_init_port: str):
# it is important to delete the CUDA_VISIBLE_DEVICES environment variable # it is important to delete the CUDA_VISIBLE_DEVICES environment variable
# so that each worker can see all the GPUs # so that each worker can see all the GPUs
...@@ -24,12 +24,12 @@ def all_reduce_test_worker(tensor_parallel_size: int, rank: int, ...@@ -24,12 +24,12 @@ def all_reduce_test_worker(tensor_parallel_size: int, rank: int,
del os.environ["CUDA_VISIBLE_DEVICES"] del os.environ["CUDA_VISIBLE_DEVICES"]
device = torch.device(f"cuda:{rank}") device = torch.device(f"cuda:{rank}")
torch.cuda.set_device(device) torch.cuda.set_device(device)
init_test_distributed_environment(1, tensor_parallel_size, rank, init_test_distributed_environment(tp_size, pp_size, rank,
distributed_init_port) distributed_init_port)
num_elements = 8 num_elements = 8
all_tensors = [ all_tensors = [
torch.arange(num_elements, dtype=torch.float32, device="cuda") * torch.arange(num_elements, dtype=torch.float32, device="cuda") *
(r + 1) for r in range(tensor_parallel_size) (r + 1) for r in range(tp_size)
] ]
expected = torch.sum(torch.stack(all_tensors, dim=0), dim=0) expected = torch.sum(torch.stack(all_tensors, dim=0), dim=0)
t = all_tensors[rank] t = all_tensors[rank]
...@@ -38,7 +38,7 @@ def all_reduce_test_worker(tensor_parallel_size: int, rank: int, ...@@ -38,7 +38,7 @@ def all_reduce_test_worker(tensor_parallel_size: int, rank: int,
@ray.remote(num_gpus=1, max_calls=1) @ray.remote(num_gpus=1, max_calls=1)
def all_gather_test_worker(tensor_parallel_size: int, rank: int, def all_gather_test_worker(tp_size: int, pp_size: int, rank: int,
distributed_init_port: str): distributed_init_port: str):
# it is important to delete the CUDA_VISIBLE_DEVICES environment variable # it is important to delete the CUDA_VISIBLE_DEVICES environment variable
# so that each worker can see all the GPUs # so that each worker can see all the GPUs
...@@ -46,7 +46,7 @@ def all_gather_test_worker(tensor_parallel_size: int, rank: int, ...@@ -46,7 +46,7 @@ def all_gather_test_worker(tensor_parallel_size: int, rank: int,
del os.environ["CUDA_VISIBLE_DEVICES"] del os.environ["CUDA_VISIBLE_DEVICES"]
device = torch.device(f"cuda:{rank}") device = torch.device(f"cuda:{rank}")
torch.cuda.set_device(device) torch.cuda.set_device(device)
init_test_distributed_environment(1, tensor_parallel_size, rank, init_test_distributed_environment(tp_size, pp_size, rank,
distributed_init_port) distributed_init_port)
num_dimensions = 3 num_dimensions = 3
tensor_size = list(range(2, num_dimensions + 2)) tensor_size = list(range(2, num_dimensions + 2))
...@@ -57,7 +57,7 @@ def all_gather_test_worker(tensor_parallel_size: int, rank: int, ...@@ -57,7 +57,7 @@ def all_gather_test_worker(tensor_parallel_size: int, rank: int,
all_tensors = [ all_tensors = [
torch.arange(total_size, dtype=torch.float32, torch.arange(total_size, dtype=torch.float32,
device="cuda").reshape(tensor_size) * (r + 1) device="cuda").reshape(tensor_size) * (r + 1)
for r in range(tensor_parallel_size) for r in range(tp_size)
] ]
expected = torch.cat(all_tensors, dim=all_gather_dimension) expected = torch.cat(all_tensors, dim=all_gather_dimension)
t = all_tensors[rank] t = all_tensors[rank]
...@@ -66,7 +66,7 @@ def all_gather_test_worker(tensor_parallel_size: int, rank: int, ...@@ -66,7 +66,7 @@ def all_gather_test_worker(tensor_parallel_size: int, rank: int,
@ray.remote(num_gpus=1, max_calls=1) @ray.remote(num_gpus=1, max_calls=1)
def broadcast_tensor_dict_test_worker(tensor_parallel_size: int, rank: int, def broadcast_tensor_dict_test_worker(tp_size: int, pp_size: int, rank: int,
distributed_init_port: str): distributed_init_port: str):
# it is important to delete the CUDA_VISIBLE_DEVICES environment variable # it is important to delete the CUDA_VISIBLE_DEVICES environment variable
# so that each worker can see all the GPUs # so that each worker can see all the GPUs
...@@ -74,7 +74,7 @@ def broadcast_tensor_dict_test_worker(tensor_parallel_size: int, rank: int, ...@@ -74,7 +74,7 @@ def broadcast_tensor_dict_test_worker(tensor_parallel_size: int, rank: int,
del os.environ["CUDA_VISIBLE_DEVICES"] del os.environ["CUDA_VISIBLE_DEVICES"]
device = torch.device(f"cuda:{rank}") device = torch.device(f"cuda:{rank}")
torch.cuda.set_device(device) torch.cuda.set_device(device)
init_test_distributed_environment(1, tensor_parallel_size, rank, init_test_distributed_environment(tp_size, pp_size, rank,
distributed_init_port) distributed_init_port)
test_dict = { test_dict = {
# device tensor # device tensor
...@@ -106,10 +106,10 @@ def broadcast_tensor_dict_test_worker(tensor_parallel_size: int, rank: int, ...@@ -106,10 +106,10 @@ def broadcast_tensor_dict_test_worker(tensor_parallel_size: int, rank: int,
@pytest.mark.skipif(torch.cuda.device_count() < 2, @pytest.mark.skipif(torch.cuda.device_count() < 2,
reason="Need at least 2 GPUs to run the test.") reason="Need at least 2 GPUs to run the test.")
@pytest.mark.parametrize("tensor_parallel_size", [2]) @pytest.mark.parametrize("tp_size", [2])
@pytest.mark.parametrize("test_target", [ @pytest.mark.parametrize("test_target", [
all_reduce_test_worker, all_gather_test_worker, all_reduce_test_worker, all_gather_test_worker,
broadcast_tensor_dict_test_worker broadcast_tensor_dict_test_worker
]) ])
def test_multi_process_tensor_parallel(tensor_parallel_size, test_target): def test_multi_process_tensor_parallel(tp_size, test_target):
multi_process_tensor_parallel(tensor_parallel_size, test_target) multi_process_tensor_parallel(tp_size, 1, test_target)
...@@ -6,8 +6,10 @@ import ray ...@@ -6,8 +6,10 @@ import ray
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from vllm.distributed import tensor_model_parallel_all_reduce from vllm.distributed.communication_op import ( # noqa
from vllm.distributed.device_communicators import custom_all_reduce graph_capture, tensor_model_parallel_all_reduce)
from vllm.distributed.parallel_state import (get_tensor_model_parallel_group,
get_tp_ca_communicator)
from vllm.test_utils import (init_test_distributed_environment, from vllm.test_utils import (init_test_distributed_environment,
multi_process_tensor_parallel) multi_process_tensor_parallel)
...@@ -18,17 +20,36 @@ for i, v in enumerate(test_sizes): ...@@ -18,17 +20,36 @@ for i, v in enumerate(test_sizes):
@ray.remote(num_gpus=1, max_calls=1) @ray.remote(num_gpus=1, max_calls=1)
def graph_allreduce(world_size, rank, distributed_init_port): def graph_allreduce(tp_size, pp_size, rank, distributed_init_port):
del os.environ["CUDA_VISIBLE_DEVICES"] del os.environ["CUDA_VISIBLE_DEVICES"]
device = torch.device(f"cuda:{rank}") device = torch.device(f"cuda:{rank}")
torch.cuda.set_device(device) torch.cuda.set_device(device)
init_test_distributed_environment(1, world_size, rank, init_test_distributed_environment(tp_size, pp_size, rank,
distributed_init_port) distributed_init_port)
custom_all_reduce.init_custom_ar() group = get_tensor_model_parallel_group()
# A small all_reduce for warmup.
# this is needed because device communicators might be created lazily
# (e.g. NCCL). This will ensure that the communicator is initialized
# before any communication happens, so that this group can be used for
# graph capture immediately.
data = torch.zeros(1)
data = data.to(device=device)
torch.distributed.all_reduce(data, group=group)
torch.cuda.synchronize()
del data
# we use the first group to communicate once
# and the second group to communicate twice
# and so on
# this is used to demonstrate that each group can
# communicate independently
num_communication = rank // tp_size + 1
for sz in test_sizes: for sz in test_sizes:
for dtype in [torch.float32, torch.float16, torch.bfloat16]: for dtype in [torch.float32, torch.float16, torch.bfloat16]:
with custom_all_reduce.capture(): with graph_capture():
# use integers so result matches NCCL exactly # use integers so result matches NCCL exactly
inp1 = torch.randint(1, inp1 = torch.randint(1,
16, (sz, ), 16, (sz, ),
...@@ -41,44 +62,52 @@ def graph_allreduce(world_size, rank, distributed_init_port): ...@@ -41,44 +62,52 @@ def graph_allreduce(world_size, rank, distributed_init_port):
torch.cuda.synchronize() torch.cuda.synchronize()
graph = torch.cuda.CUDAGraph() graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph): with torch.cuda.graph(graph):
for i in range(num_communication):
out1 = tensor_model_parallel_all_reduce(inp1) out1 = tensor_model_parallel_all_reduce(inp1)
# the input buffer is immediately modified to test # the input buffer is immediately modified to test
# synchronization # synchronization
dist.all_reduce(inp1) dist.all_reduce(inp1, group=group)
out2 = tensor_model_parallel_all_reduce(inp2) out2 = tensor_model_parallel_all_reduce(inp2)
dist.all_reduce(inp2) dist.all_reduce(inp2, group=group)
graph.replay() graph.replay()
assert torch.allclose(out1, inp1) assert torch.allclose(out1, inp1)
assert torch.allclose(out2, inp2) assert torch.allclose(out2, inp2)
@ray.remote(num_gpus=1, max_calls=1) @ray.remote(num_gpus=1, max_calls=1)
def eager_allreduce(world_size, rank, distributed_init_port): def eager_allreduce(tp_size, pp_size, rank, distributed_init_port):
del os.environ["CUDA_VISIBLE_DEVICES"] del os.environ["CUDA_VISIBLE_DEVICES"]
device = torch.device(f"cuda:{rank}") device = torch.device(f"cuda:{rank}")
torch.cuda.set_device(device) torch.cuda.set_device(device)
init_test_distributed_environment(1, world_size, rank, init_test_distributed_environment(tp_size, pp_size, rank,
distributed_init_port) distributed_init_port)
# we use the first group to communicate once
# and the second group to communicate twice
# and so on
# this is used to demonstrate that each group can
# communicate independently
num_communication = rank // tp_size + 1
sz = 1024 sz = 1024
custom_all_reduce.init_custom_ar() fa = get_tp_ca_communicator()
fa = custom_all_reduce.get_handle()
inp = torch.ones(sz, dtype=torch.float32, device=device) inp = torch.ones(sz, dtype=torch.float32, device=device)
out = fa.all_reduce_unreg(inp) out = inp
assert torch.allclose(out, inp * world_size) for _ in range(num_communication):
out = fa.all_reduce_unreg(out)
assert torch.allclose(out, inp * (tp_size**num_communication))
inp = torch.ones(sz * 4, dtype=torch.bfloat16, device=device) inp = torch.ones(sz * 4, dtype=torch.bfloat16, device=device)
out = fa.all_reduce_unreg(inp) out = inp
assert torch.allclose(out, inp * world_size) for _ in range(num_communication):
out = fa.all_reduce_unreg(out)
assert torch.allclose(out, inp * (tp_size**num_communication))
@pytest.mark.skipif(torch.cuda.device_count() < 2, @pytest.mark.parametrize("tp_size", [2])
reason="Need at least 2 GPUs to run the test.") @pytest.mark.parametrize("pipeline_parallel_size", [1, 2])
@pytest.mark.parametrize("tensor_parallel_size", [2])
@pytest.mark.parametrize("test_target", [eager_allreduce, graph_allreduce]) @pytest.mark.parametrize("test_target", [eager_allreduce, graph_allreduce])
def test_multi_process_tensor_parallel(tensor_parallel_size, test_target): def test_custom_allreduce(tp_size, pipeline_parallel_size, test_target):
multi_process_tensor_parallel(tensor_parallel_size, test_target) world_size = tp_size * pipeline_parallel_size
if world_size > torch.cuda.device_count():
pytest.skip("Not enough GPUs to run the test.")
if __name__ == "__main__": multi_process_tensor_parallel(tp_size, pipeline_parallel_size, test_target)
multi_process_tensor_parallel(2, graph_allreduce)
...@@ -5,7 +5,7 @@ import pytest ...@@ -5,7 +5,7 @@ import pytest
import torch import torch
from vllm.distributed.communication_op import ( # noqa from vllm.distributed.communication_op import ( # noqa
graph_capture_mode, tensor_model_parallel_all_reduce) graph_mode, tensor_model_parallel_all_reduce)
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
from vllm.distributed.device_communicators.pynccl_wrapper import NCCLLibrary from vllm.distributed.device_communicators.pynccl_wrapper import NCCLLibrary
from vllm.distributed.parallel_state import (ensure_model_parallel_initialized, from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
...@@ -103,7 +103,7 @@ def multiple_tp_with_vllm_worker_fn(): ...@@ -103,7 +103,7 @@ def multiple_tp_with_vllm_worker_fn():
device = torch.device(f"cuda:{torch.distributed.get_rank()}") device = torch.device(f"cuda:{torch.distributed.get_rank()}")
ensure_model_parallel_initialized(2, 2) ensure_model_parallel_initialized(2, 2)
tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device) tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device)
with graph_capture_mode(): with graph_mode():
# two tp groups can communicate independently # two tp groups can communicate independently
if torch.distributed.get_rank() in [0, 1]: if torch.distributed.get_rank() in [0, 1]:
tensor = tensor_model_parallel_all_reduce(tensor) tensor = tensor_model_parallel_all_reduce(tensor)
......
from collections import namedtuple from collections import namedtuple
from contextlib import contextmanager from contextlib import contextmanager, nullcontext
from typing import Any, Dict, List, Optional, Tuple, Union from typing import Any, Dict, List, Optional, Tuple, Union
import torch import torch
...@@ -9,12 +9,13 @@ from .parallel_state import (get_cpu_world_group, ...@@ -9,12 +9,13 @@ from .parallel_state import (get_cpu_world_group,
get_tensor_model_parallel_group, get_tensor_model_parallel_group,
get_tensor_model_parallel_rank, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
get_tp_ca_communicator,
get_tp_pynccl_communicator) get_tp_pynccl_communicator)
@contextmanager @contextmanager
def graph_capture_mode(): def graph_mode():
# In graph capture, we have to be very careful about the collective # In graph mode, we have to be very careful about the collective
# operations. The current status is: # operations. The current status is:
# allreduce \ Mode | Eager | Graph | # allreduce \ Mode | Eager | Graph |
# -------------------------------------------- # --------------------------------------------
...@@ -24,10 +25,32 @@ def graph_capture_mode(): ...@@ -24,10 +25,32 @@ def graph_capture_mode():
# #
# Note that custom allreduce will have a runtime check, if the tensor size # Note that custom allreduce will have a runtime check, if the tensor size
# is too large, it will fallback to the next available option. # is too large, it will fallback to the next available option.
# In summary: When using CUDA graph, we use
# either custom all-reduce kernel or pynccl. When not using CUDA
# graph, we use either custom all-reduce kernel or PyTorch NCCL.
# We always prioritize using custom all-reduce kernel but fall back
# to PyTorch or pynccl if it is disabled or not supported.
pynccl_comm = get_tp_pynccl_communicator() pynccl_comm = get_tp_pynccl_communicator()
assert pynccl_comm is not None if pynccl_comm is None:
with pynccl_comm.change_state(enable=True, context = nullcontext()
stream=torch.cuda.current_stream()): else:
context = pynccl_comm.change_state(enable=True,
stream=torch.cuda.current_stream())
with context:
yield
@contextmanager
def graph_capture():
"""
`graph_capture` is a context manager which should include the code that
is capturing the CUDA graph. Its main purpose is to ensure that the
some operations will be run after the graph is captured, before the graph
is replayed.
"""
ca_comm = get_tp_ca_communicator()
context = nullcontext() if ca_comm is None else ca_comm.capture()
with context:
yield yield
...@@ -43,13 +66,13 @@ def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor: ...@@ -43,13 +66,13 @@ def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
TLDR: always assume this function modifies its input, but use the return TLDR: always assume this function modifies its input, but use the return
value as the output. value as the output.
""" """
from vllm.distributed.device_communicators.custom_all_reduce import ( ca_comm = get_tp_ca_communicator()
custom_all_reduce)
# Bypass the function if we are using only 1 GPU. # Bypass the function if we are using only 1 GPU.
if get_tensor_model_parallel_world_size() == 1: if get_tensor_model_parallel_world_size() == 1:
return input_ return input_
out = custom_all_reduce(input_) if ca_comm is not None:
out = ca_comm.custom_all_reduce(input_)
if out is not None: if out is not None:
return out return out
pynccl_comm = get_tp_pynccl_communicator() pynccl_comm = get_tp_pynccl_communicator()
......
from contextlib import contextmanager from contextlib import contextmanager
from typing import Any, List, Optional from typing import Any, List, Optional, Union
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from torch.distributed import ProcessGroup
import vllm.envs as envs import vllm.envs as envs
from vllm.distributed.parallel_state import (
get_local_rank, get_tensor_model_parallel_cpu_group)
from vllm.logger import init_logger from vllm.logger import init_logger
try: try:
import pynvml import pynvml
from vllm._C import custom_ar from vllm._C import custom_ar
@contextmanager
def _nvml():
try:
pynvml.nvmlInit()
yield
finally:
pynvml.nvmlShutdown()
except ImportError: except ImportError:
# For AMD GPUs # For AMD GPUs
custom_ar = None custom_ar = None
pynvml = None pynvml = None
logger = init_logger(__name__) @contextmanager
def _nvml():
_CA_HANDLE: Optional["CustomAllreduce"] = None
_IS_CAPTURING = False
_SUPPORTED_WORLD_SIZES = [2, 4, 6, 8]
def init_custom_ar() -> None:
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
global _CA_HANDLE
if _CA_HANDLE is not None:
return
rank = get_tensor_model_parallel_rank()
world_size = get_tensor_model_parallel_world_size()
if world_size == 1:
# No need to initialize custom allreduce for single GPU case.
return
if world_size not in _SUPPORTED_WORLD_SIZES:
logger.warning(
"Custom allreduce is disabled due to an unsupported world size: "
"%d. Supported world sizes: %s. To silence this warning, specify"
" disable_custom_all_reduce=True explicitly.", world_size,
str(_SUPPORTED_WORLD_SIZES))
return
num_dev = torch.cuda.device_count()
# note: num dev can be larger than world_size if we're only using
# first few GPUs
if num_dev < world_size:
logger.warning(
"Cannot test GPU P2P because not all GPUs are visible to the "
"current process. This might be the case if 'CUDA_VISIBLE_DEVICES'"
" is set.")
return
# we only use a subset of GPUs here
# so we only need to check the nvlink connectivity of these GPUs
num_dev = world_size
# test nvlink first, this will filter out most of the cases
# where custom allreduce is not supported
cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES
if cuda_visible_devices:
device_ids = list(map(int, cuda_visible_devices.split(",")))
else:
device_ids = list(range(num_dev))
# this checks hardware and driver support for NVLink
full_nvlink = _is_full_nvlink(device_ids)
if world_size > 2 and not full_nvlink:
logger.warning(
"Custom allreduce is disabled because it's not supported on more"
" than two PCIe-only GPUs. To silence this warning, specify"
" disable_custom_all_reduce=True explicitly.")
return
# test P2P capability, this checks software/cudaruntime support
# this is expensive to compute at the first time
# then we cache the result
if not _can_p2p(rank, world_size):
logger.warning(
"Custom allreduce is disabled because your platform lacks GPU P2P"
" capability or P2P test failed. To silence this warning, specify"
" disable_custom_all_reduce=True explicitly.")
return
_CA_HANDLE = CustomAllreduce(rank, world_size, full_nvlink)
def begin_capture() -> None:
global _IS_CAPTURING
_IS_CAPTURING = True
def end_capture() -> None:
global _IS_CAPTURING
_IS_CAPTURING = False
def is_capturing() -> bool:
return _IS_CAPTURING and _CA_HANDLE is not None
def get_handle() -> Optional["CustomAllreduce"]:
return _CA_HANDLE
def is_initialized() -> bool:
return _CA_HANDLE is not None
@contextmanager
def capture():
try: try:
begin_capture()
yield yield
finally: finally:
end_capture() pass
handle = get_handle()
if handle is not None:
handle.register_graph_buffers()
def custom_all_reduce(input: torch.Tensor) -> Optional[torch.Tensor]:
ca_handle = get_handle()
# when custom allreduce is disabled, this will be None
if ca_handle is None:
return None
if is_capturing():
if torch.cuda.is_current_stream_capturing():
if ca_handle.should_custom_ar(input):
return ca_handle.all_reduce_reg(input)
else:
if ca_handle.should_custom_ar(input):
# if warm up, mimic the allocation pattern
# since custom allreduce is out-of-place
return torch.empty_like(input)
else:
# note: outside of cuda graph context,
# custom allreduce incurs a cost of cudaMemcpy, which should
# be small(<=1% of overall latency) compared to the performance
# gains of using custom kernels
if ca_handle.should_custom_ar(input):
return ca_handle.all_reduce_unreg(input)
return None
@contextmanager logger = init_logger(__name__)
def _nvml():
try:
pynvml.nvmlInit()
yield
finally:
pynvml.nvmlShutdown()
@_nvml() @_nvml()
...@@ -188,22 +76,112 @@ def _can_p2p(rank: int, world_size: int) -> bool: ...@@ -188,22 +76,112 @@ def _can_p2p(rank: int, world_size: int) -> bool:
class CustomAllreduce: class CustomAllreduce:
_SUPPORTED_WORLD_SIZES = [2, 4, 6, 8]
# max_size: max supported allreduce size # max_size: max supported allreduce size
def __init__(self, def __init__(self,
rank, group: Optional[ProcessGroup] = None,
world_size, device: Optional[Union[int, str, torch.device]] = None,
full_nvlink,
max_size=8192 * 1024) -> None: max_size=8192 * 1024) -> None:
"""
Args:
group: the process group to work on. If None, it will use the
default process group.
device: the device to bind the CustomAllreduce to. If None,
it will be bind to f"cuda:{local_rank}".
It is the caller's responsibility to make sure each communicator
is bind to a unique device, and all communicators in this group
are in the same node.
"""
self._IS_CAPTURING = False
self.disabled = True
if custom_ar is None:
# disable because of missing custom allreduce library
# e.g. in a non-cuda environment
return
group = group or get_tensor_model_parallel_cpu_group()
self.group = group
assert dist.get_backend(group) != dist.Backend.NCCL, (
"CustomAllreduce should be attached to a non-NCCL group.")
rank = dist.get_rank(group=self.group)
world_size = dist.get_world_size(group=self.group)
if world_size == 1:
# No need to initialize custom allreduce for single GPU case.
return
if world_size not in CustomAllreduce._SUPPORTED_WORLD_SIZES:
logger.warning(
"Custom allreduce is disabled due to an unsupported world"
" size: %d. Supported world sizes: %s. To silence this "
"warning, specify disable_custom_all_reduce=True explicitly.",
world_size, str(CustomAllreduce._SUPPORTED_WORLD_SIZES))
return
if device is None:
local_rank = get_local_rank()
device = torch.device(f"cuda:{local_rank}")
elif isinstance(device, int):
device = torch.device(f"cuda:{device}")
elif isinstance(device, str):
device = torch.device(device)
# now `device` is a `torch.device` object
assert isinstance(device, torch.device)
self.device = device
cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES
if cuda_visible_devices:
device_ids = list(map(int, cuda_visible_devices.split(",")))
else:
device_ids = list(range(torch.cuda.device_count()))
physical_device_id = device_ids[device.index]
tensor = torch.tensor([physical_device_id],
dtype=torch.int,
device="cpu")
gather_list = [
torch.tensor([0], dtype=torch.int, device="cpu")
for _ in range(world_size)
]
dist.all_gather(gather_list, tensor, group=self.group)
physical_device_ids = [t.item() for t in gather_list]
# test nvlink first, this will filter out most of the cases
# where custom allreduce is not supported
# this checks hardware and driver support for NVLink
full_nvlink = _is_full_nvlink(physical_device_ids)
if world_size > 2 and not full_nvlink:
logger.warning(
"Custom allreduce is disabled because it's not supported on"
" more than two PCIe-only GPUs. To silence this warning, "
"specify disable_custom_all_reduce=True explicitly.")
return
# test P2P capability, this checks software/cudaruntime support
# this is expensive to compute at the first time
# then we cache the result
if not _can_p2p(rank, world_size):
logger.warning(
"Custom allreduce is disabled because your platform lacks "
"GPU P2P capability or P2P test failed. To silence this "
"warning, specify disable_custom_all_reduce=True explicitly.")
return
self.disabled = False
# buffers memory are owned by this Python class and passed to C++ # buffers memory are owned by this Python class and passed to C++
# meta data composes of two parts: meta data for synchronization # meta data composes of two parts: meta data for synchronization
# (256 bytes) and a temporary buffer for storing intermediate # (256 bytes) and a temporary buffer for storing intermediate
# allreduce results. # allreduce results.
self.meta = torch.zeros(custom_ar.meta_size() + max_size, self.meta = torch.zeros(custom_ar.meta_size() + max_size,
dtype=torch.uint8, dtype=torch.uint8,
device="cuda") device=self.device)
# This is a pre-registered IPC buffer. In eager mode, input tensors # This is a pre-registered IPC buffer. In eager mode, input tensors
# are first copied into this buffer before allreduce is performed # are first copied into this buffer before allreduce is performed
self.buffer = torch.empty(max_size, dtype=torch.uint8, device="cuda") self.buffer = torch.empty(max_size,
dtype=torch.uint8,
device=self.device)
# This is a buffer for storing the tuples of pointers pointing to # This is a buffer for storing the tuples of pointers pointing to
# IPC buffers from all ranks. Each registered tuple has size of # IPC buffers from all ranks. Each registered tuple has size of
# 8*world_size bytes where world_size is at most 8. Allocating 8MB # 8*world_size bytes where world_size is at most 8. Allocating 8MB
...@@ -211,8 +189,9 @@ class CustomAllreduce: ...@@ -211,8 +189,9 @@ class CustomAllreduce:
# needs less than 10000 of registered tuples. # needs less than 10000 of registered tuples.
self.rank_data = torch.empty(8 * 1024 * 1024, self.rank_data = torch.empty(8 * 1024 * 1024,
dtype=torch.uint8, dtype=torch.uint8,
device="cuda") device=self.device)
self.max_size = max_size self.max_size = max_size
self.rank = rank
self.world_size = world_size self.world_size = world_size
handles, offsets = self._get_ipc_meta(self.meta) handles, offsets = self._get_ipc_meta(self.meta)
self.full_nvlink = full_nvlink self.full_nvlink = full_nvlink
...@@ -221,6 +200,21 @@ class CustomAllreduce: ...@@ -221,6 +200,21 @@ class CustomAllreduce:
self.full_nvlink) self.full_nvlink)
self.register_buffer(self.buffer) self.register_buffer(self.buffer)
@contextmanager
def capture(self):
"""
The main responsibility of this context manager is the
`register_graph_buffers` call at the end of the context.
It records all the buffer addresses used in the CUDA graph.
"""
try:
self._IS_CAPTURING = True
yield
finally:
self._IS_CAPTURING = False
if not self.disabled:
self.register_graph_buffers()
def _get_ipc_meta(self, inp: torch.Tensor): def _get_ipc_meta(self, inp: torch.Tensor):
data = inp.untyped_storage()._share_cuda_() data = inp.untyped_storage()._share_cuda_()
shard_data = ( shard_data = (
...@@ -230,14 +224,29 @@ class CustomAllreduce: ...@@ -230,14 +224,29 @@ class CustomAllreduce:
return self._gather_ipc_meta(shard_data) return self._gather_ipc_meta(shard_data)
def _gather_ipc_meta(self, shard_data): def _gather_ipc_meta(self, shard_data):
all_data: List[Optional[Any]] = [None] * self.world_size # Note: don't use `[[None]] * self.world_size` here
dist.all_gather_object(all_data, shard_data) # because it will create a list of the same reference
all_data: List[Optional[Any]] = [[None]
for i in range(self.world_size)]
all_data[self.rank][0] = shard_data
ranks = dist.get_process_group_ranks(group=self.group)
ranks.sort()
for i, rank in enumerate(ranks):
dist.broadcast_object_list(all_data[i],
src=rank,
group=self.group,
device="cpu")
# we cannot directly use `dist.all_gather_object` here
# because it is incompatible with `gloo` backend under inference mode.
# see https://github.com/pytorch/pytorch/issues/126032 for details.
handles = [] handles = []
offsets = [] offsets = []
for i in range(len(all_data)): for i in range(len(all_data)):
handles.append(all_data[i][0]) # type: ignore handles.append(all_data[i][0][0]) # type: ignore
offsets.append(all_data[i][1]) # type: ignore offsets.append(all_data[i][0][1]) # type: ignore
return handles, offsets return handles, offsets
def register_buffer(self, inp: torch.Tensor): def register_buffer(self, inp: torch.Tensor):
...@@ -269,8 +278,31 @@ class CustomAllreduce: ...@@ -269,8 +278,31 @@ class CustomAllreduce:
custom_ar.all_reduce_unreg(self._ptr, inp, self.buffer, out) custom_ar.all_reduce_unreg(self._ptr, inp, self.buffer, out)
return out return out
def custom_all_reduce(self, input: torch.Tensor) -> Optional[torch.Tensor]:
# when custom allreduce is disabled, this will be None
if self.disabled:
return None
if self._IS_CAPTURING:
if torch.cuda.is_current_stream_capturing():
if self.should_custom_ar(input):
return self.all_reduce_reg(input)
else:
if self.should_custom_ar(input):
# if warm up, mimic the allocation pattern
# since custom allreduce is out-of-place
return torch.empty_like(input)
else:
# note: outside of cuda graph context,
# custom allreduce incurs a cost of cudaMemcpy, which should
# be small(<=1% of overall latency) compared to the performance
# gains of using custom kernels
if self.should_custom_ar(input):
return self.all_reduce_unreg(input)
return None
def close(self): def close(self):
if self._ptr: if not self.disabled and self._ptr:
custom_ar.dispose(self._ptr) custom_ar.dispose(self._ptr)
self._ptr = 0 self._ptr = 0
......
...@@ -96,8 +96,10 @@ class PyNcclCommunicator: ...@@ -96,8 +96,10 @@ class PyNcclCommunicator:
self.stream = torch.cuda.Stream() self.stream = torch.cuda.Stream()
# A small all_reduce for warmup. # A small all_reduce for warmup.
self.all_reduce(torch.zeros(1, device=device)) data = torch.zeros(1, device=device)
self.all_reduce(data)
self.stream.synchronize() self.stream.synchronize()
del data
# by default it is disabled, e.g. in profiling models and prefill phase. # by default it is disabled, e.g. in profiling models and prefill phase.
# to use it, use under `with obj.change_state(enable=True)`, usually # to use it, use under `with obj.change_state(enable=True)`, usually
......
...@@ -13,10 +13,13 @@ from vllm.logger import init_logger ...@@ -13,10 +13,13 @@ from vllm.logger import init_logger
logger = init_logger(__name__) logger = init_logger(__name__)
_ENABLE_CUSTOM_ALL_REDUCE = True
# Tensor model parallel group that the current rank belongs to. # Tensor model parallel group that the current rank belongs to.
_TP_DEVICE_GROUP: Optional[ProcessGroup] = None _TP_DEVICE_GROUP: Optional[ProcessGroup] = None
_TP_CPU_GROUP: Optional[ProcessGroup] = None _TP_CPU_GROUP: Optional[ProcessGroup] = None
_TP_PYNCCL_COMMUNICATOR = None _TP_PYNCCL_COMMUNICATOR = None
_TP_CA_COMMUNICATOR = None
# Pipeline model parallel group that the current rank belongs to. # Pipeline model parallel group that the current rank belongs to.
_PP_DEVICE_GROUP: Optional[ProcessGroup] = None _PP_DEVICE_GROUP: Optional[ProcessGroup] = None
...@@ -47,11 +50,21 @@ _PP_GLOBAL_RANKS: Optional[List[int]] = None ...@@ -47,11 +50,21 @@ _PP_GLOBAL_RANKS: Optional[List[int]] = None
_LOCAL_RANK = -1 _LOCAL_RANK = -1
def set_custom_all_reduce(enable: bool):
global _ENABLE_CUSTOM_ALL_REDUCE
_ENABLE_CUSTOM_ALL_REDUCE = enable
def get_tp_pynccl_communicator(): def get_tp_pynccl_communicator():
global _TP_PYNCCL_COMMUNICATOR global _TP_PYNCCL_COMMUNICATOR
return _TP_PYNCCL_COMMUNICATOR return _TP_PYNCCL_COMMUNICATOR
def get_tp_ca_communicator():
global _TP_CA_COMMUNICATOR
return _TP_CA_COMMUNICATOR
def get_local_rank(): def get_local_rank():
global _LOCAL_RANK global _LOCAL_RANK
return _LOCAL_RANK return _LOCAL_RANK
...@@ -100,6 +113,9 @@ def init_distributed_environment( ...@@ -100,6 +113,9 @@ def init_distributed_environment(
if torch.cuda.is_available(): if torch.cuda.is_available():
data = data.to(device=f"cuda:{local_rank}") data = data.to(device=f"cuda:{local_rank}")
torch.distributed.all_reduce(data) torch.distributed.all_reduce(data)
if torch.cuda.is_available():
torch.cuda.synchronize()
del data
def initialize_model_parallel( def initialize_model_parallel(
...@@ -149,7 +165,8 @@ def initialize_model_parallel( ...@@ -149,7 +165,8 @@ def initialize_model_parallel(
rank = torch.distributed.get_rank() rank = torch.distributed.get_rank()
# Build the tensor model-parallel groups. # Build the tensor model-parallel groups.
global _TP_DEVICE_GROUP, _TP_CPU_GROUP, _TP_PYNCCL_COMMUNICATOR global _TP_DEVICE_GROUP, _TP_CPU_GROUP
global _TP_PYNCCL_COMMUNICATOR, _TP_CA_COMMUNICATOR
assert _TP_DEVICE_GROUP is None, ( assert _TP_DEVICE_GROUP is None, (
"tensor model parallel group is already initialized") "tensor model parallel group is already initialized")
for i in range(num_tensor_model_parallel_groups): for i in range(num_tensor_model_parallel_groups):
...@@ -168,6 +185,15 @@ def initialize_model_parallel( ...@@ -168,6 +185,15 @@ def initialize_model_parallel(
device=_LOCAL_RANK, device=_LOCAL_RANK,
) )
# Initialize a custom fast all-reduce implementation.
if _ENABLE_CUSTOM_ALL_REDUCE:
from vllm.distributed.device_communicators.custom_all_reduce import (
CustomAllreduce)
_TP_CA_COMMUNICATOR = CustomAllreduce(
group=_TP_CPU_GROUP,
device=_LOCAL_RANK,
)
# Build the pipeline model-parallel groups. # Build the pipeline model-parallel groups.
global _PP_DEVICE_GROUP global _PP_DEVICE_GROUP
global _PP_GLOBAL_RANKS global _PP_GLOBAL_RANKS
......
...@@ -6,24 +6,24 @@ from vllm.utils import get_open_port ...@@ -6,24 +6,24 @@ from vllm.utils import get_open_port
def init_test_distributed_environment( def init_test_distributed_environment(
pipeline_parallel_size: int, tp_size: int,
tensor_parallel_size: int, pp_size: int,
rank: int, rank: int,
distributed_init_port: str, distributed_init_port: str,
local_rank: int = -1, local_rank: int = -1,
) -> None: ) -> None:
distributed_init_method = f"tcp://localhost:{distributed_init_port}" distributed_init_method = f"tcp://localhost:{distributed_init_port}"
init_distributed_environment( init_distributed_environment(
world_size=pipeline_parallel_size * tensor_parallel_size, world_size=pp_size * tp_size,
rank=rank, rank=rank,
distributed_init_method=distributed_init_method, distributed_init_method=distributed_init_method,
local_rank=local_rank) local_rank=local_rank)
ensure_model_parallel_initialized(tensor_parallel_size, ensure_model_parallel_initialized(tp_size, pp_size)
pipeline_parallel_size)
def multi_process_tensor_parallel( def multi_process_tensor_parallel(
tensor_parallel_size: int, tp_size: int,
pp_size: int,
test_target, test_target,
) -> None: ) -> None:
# Using ray helps debugging the error when it failed # Using ray helps debugging the error when it failed
...@@ -32,10 +32,9 @@ def multi_process_tensor_parallel( ...@@ -32,10 +32,9 @@ def multi_process_tensor_parallel(
distributed_init_port = get_open_port() distributed_init_port = get_open_port()
refs = [] refs = []
for rank in range(tensor_parallel_size): for rank in range(tp_size * pp_size):
refs.append( refs.append(
test_target.remote(tensor_parallel_size, rank, test_target.remote(tp_size, pp_size, rank, distributed_init_port))
distributed_init_port))
ray.get(refs) ray.get(refs)
ray.shutdown() ray.shutdown()
...@@ -12,8 +12,7 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ...@@ -12,8 +12,7 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ParallelConfig, SchedulerConfig, ModelConfig, ParallelConfig, SchedulerConfig,
VisionLanguageConfig) VisionLanguageConfig)
from vllm.distributed import broadcast_tensor_dict from vllm.distributed import broadcast_tensor_dict
from vllm.distributed.communication_op import graph_capture_mode from vllm.distributed.communication_op import graph_capture, graph_mode
from vllm.distributed.device_communicators import custom_all_reduce
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.layers import LoRAMapping from vllm.lora.layers import LoRAMapping
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
...@@ -942,13 +941,7 @@ class ModelRunner: ...@@ -942,13 +941,7 @@ class ModelRunner:
bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size
] ]
# NOTE(woosuk): There are 3 backends for all-reduce: custom all-reduce with graph_capture():
# kernel, pynccl, and PyTorch NCCL. When using CUDA graph, we use
# either custom all-reduce kernel or pynccl. When not using CUDA
# graph, we use either custom all-reduce kernel or PyTorch NCCL.
# We always prioritize using custom all-reduce kernel but fall back
# to PyTorch or pynccl if it is disabled or not supported.
with custom_all_reduce.capture():
# NOTE: Capturing the largest batch size first may help reduce the # NOTE: Capturing the largest batch size first may help reduce the
# memory usage of CUDA graph. # memory usage of CUDA graph.
for batch_size in reversed(batch_size_capture_list): for batch_size in reversed(batch_size_capture_list):
...@@ -1040,7 +1033,7 @@ class CUDAGraphRunner: ...@@ -1040,7 +1033,7 @@ class CUDAGraphRunner:
# Run the model once without capturing the graph. # Run the model once without capturing the graph.
# This is to make sure that the captured graph does not include the # This is to make sure that the captured graph does not include the
# kernel launches for initial benchmarking (e.g., Triton autotune). # kernel launches for initial benchmarking (e.g., Triton autotune).
with graph_capture_mode(): with graph_mode():
self.model( self.model(
input_ids, input_ids,
positions, positions,
...@@ -1055,7 +1048,7 @@ class CUDAGraphRunner: ...@@ -1055,7 +1048,7 @@ class CUDAGraphRunner:
# https://stackoverflow.com/questions/31039022/python-multi-line-with-statement # https://stackoverflow.com/questions/31039022/python-multi-line-with-statement
self._graph = torch.cuda.CUDAGraph() self._graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(self._graph, pool=memory_pool): # noqa: SIM117 with torch.cuda.graph(self._graph, pool=memory_pool): # noqa: SIM117
with graph_capture_mode(): with graph_mode():
hidden_states = self.model( hidden_states = self.model(
input_ids, input_ids,
positions, positions,
......
...@@ -11,9 +11,8 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ...@@ -11,9 +11,8 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
VisionLanguageConfig) VisionLanguageConfig)
from vllm.distributed import (broadcast_tensor_dict, from vllm.distributed import (broadcast_tensor_dict,
ensure_model_parallel_initialized, ensure_model_parallel_initialized,
init_distributed_environment) init_distributed_environment,
from vllm.distributed.device_communicators.custom_all_reduce import ( set_custom_all_reduce)
init_custom_ar)
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.model_executor import set_random_seed from vllm.model_executor import set_random_seed
from vllm.sequence import ExecuteModelRequest, PoolerOutput, SamplerOutput from vllm.sequence import ExecuteModelRequest, PoolerOutput, SamplerOutput
...@@ -302,16 +301,14 @@ def init_worker_distributed_environment( ...@@ -302,16 +301,14 @@ def init_worker_distributed_environment(
local_rank: int = -1, local_rank: int = -1,
) -> None: ) -> None:
"""Initialize the distributed environment.""" """Initialize the distributed environment."""
set_custom_all_reduce(not parallel_config.disable_custom_all_reduce)
init_distributed_environment(parallel_config.world_size, rank, init_distributed_environment(parallel_config.world_size, rank,
distributed_init_method, local_rank) distributed_init_method, local_rank)
ensure_model_parallel_initialized(parallel_config.tensor_parallel_size, ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size) parallel_config.pipeline_parallel_size)
# Initialize a custom fast all-reduce implementation.
if not parallel_config.disable_custom_all_reduce:
init_custom_ar()
def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype): def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
# Check if the GPU supports the dtype. # Check if the GPU supports the dtype.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment