Unverified Commit ea3890a5 authored by youkaichao's avatar youkaichao Committed by GitHub
Browse files

[Core][Distributed] code deduplication in tp&pp with coordinator(#5293)

[Core][Distributed] add coordinator to reduce code duplication in tp and pp (#5293)
parent 2135cacb
...@@ -15,7 +15,8 @@ from transformers import (AutoModelForCausalLM, AutoModelForVision2Seq, ...@@ -15,7 +15,8 @@ from transformers import (AutoModelForCausalLM, AutoModelForVision2Seq,
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
from vllm.config import TokenizerPoolConfig, VisionLanguageConfig from vllm.config import TokenizerPoolConfig, VisionLanguageConfig
from vllm.distributed import destroy_model_parallel from vllm.distributed import (destroy_distributed_environment,
destroy_model_parallel)
from vllm.inputs import TextPrompt from vllm.inputs import TextPrompt
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.multimodal import MultiModalData from vllm.multimodal import MultiModalData
...@@ -54,6 +55,7 @@ def _read_prompts(filename: str) -> List[str]: ...@@ -54,6 +55,7 @@ def _read_prompts(filename: str) -> List[str]:
def cleanup(): def cleanup():
destroy_model_parallel() destroy_model_parallel()
destroy_distributed_environment()
with contextlib.suppress(AssertionError): with contextlib.suppress(AssertionError):
torch.distributed.destroy_process_group() torch.distributed.destroy_process_group()
gc.collect() gc.collect()
......
...@@ -7,9 +7,9 @@ import torch ...@@ -7,9 +7,9 @@ import torch
import torch.distributed as dist import torch.distributed as dist
from vllm.distributed.communication_op import ( # noqa from vllm.distributed.communication_op import ( # noqa
graph_capture, tensor_model_parallel_all_reduce) tensor_model_parallel_all_reduce)
from vllm.distributed.parallel_state import (get_tensor_model_parallel_group, from vllm.distributed.parallel_state import (get_tensor_model_parallel_group,
get_tp_ca_communicator) get_tp_group, graph_capture)
from ..utils import (init_test_distributed_environment, from ..utils import (init_test_distributed_environment,
multi_process_tensor_parallel) multi_process_tensor_parallel)
...@@ -91,7 +91,7 @@ def eager_allreduce(tp_size, pp_size, rank, distributed_init_port): ...@@ -91,7 +91,7 @@ def eager_allreduce(tp_size, pp_size, rank, distributed_init_port):
# communicate independently # communicate independently
num_communication = rank // tp_size + 1 num_communication = rank // tp_size + 1
sz = 1024 sz = 1024
fa = get_tp_ca_communicator() fa = get_tp_group().ca_comm
inp = torch.ones(sz, dtype=torch.float32, device=device) inp = torch.ones(sz, dtype=torch.float32, device=device)
out = inp out = inp
for _ in range(num_communication): for _ in range(num_communication):
......
...@@ -6,10 +6,11 @@ import torch ...@@ -6,10 +6,11 @@ import torch
import torch.distributed import torch.distributed
from vllm.distributed.communication_op import ( # noqa from vllm.distributed.communication_op import ( # noqa
graph_capture, tensor_model_parallel_all_reduce) tensor_model_parallel_all_reduce)
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
from vllm.distributed.device_communicators.pynccl_wrapper import NCCLLibrary from vllm.distributed.device_communicators.pynccl_wrapper import NCCLLibrary
from vllm.distributed.parallel_state import (ensure_model_parallel_initialized, from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
get_world_group, graph_capture,
init_distributed_environment) init_distributed_environment)
from vllm.utils import update_environment_variables from vllm.utils import update_environment_variables
...@@ -53,7 +54,8 @@ def worker_fn_wrapper(fn): ...@@ -53,7 +54,8 @@ def worker_fn_wrapper(fn):
@worker_fn_wrapper @worker_fn_wrapper
def worker_fn(): def worker_fn():
pynccl_comm = PyNcclCommunicator() pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group,
device=get_world_group().device)
tensor = torch.ones(16, 1024, 1024, tensor = torch.ones(16, 1024, 1024,
dtype=torch.float32).cuda(pynccl_comm.rank) dtype=torch.float32).cuda(pynccl_comm.rank)
with pynccl_comm.change_state(enable=True): with pynccl_comm.change_state(enable=True):
...@@ -129,7 +131,8 @@ def test_pynccl_multiple_allreduce_with_vllm(): ...@@ -129,7 +131,8 @@ def test_pynccl_multiple_allreduce_with_vllm():
def worker_fn_with_cudagraph(): def worker_fn_with_cudagraph():
with torch.no_grad(): with torch.no_grad():
graph = torch.cuda.CUDAGraph() graph = torch.cuda.CUDAGraph()
pynccl_comm = PyNcclCommunicator() pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group,
device=get_world_group().device)
# run something in the default stream to initialize torch engine # run something in the default stream to initialize torch engine
a = torch.ones((4, 4), device=f'cuda:{pynccl_comm.rank}') a = torch.ones((4, 4), device=f'cuda:{pynccl_comm.rank}')
torch.cuda.synchronize() torch.cuda.synchronize()
...@@ -154,7 +157,8 @@ def test_pynccl_with_cudagraph(): ...@@ -154,7 +157,8 @@ def test_pynccl_with_cudagraph():
@worker_fn_wrapper @worker_fn_wrapper
def send_recv_worker_fn(): def send_recv_worker_fn():
pynccl_comm = PyNcclCommunicator() pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group,
device=get_world_group().device)
if pynccl_comm.rank == 0: if pynccl_comm.rank == 0:
tensor = torch.ones(16, 1024, 1024, tensor = torch.ones(16, 1024, 1024,
dtype=torch.float32).cuda(pynccl_comm.rank) dtype=torch.float32).cuda(pynccl_comm.rank)
......
...@@ -12,7 +12,10 @@ from huggingface_hub import snapshot_download ...@@ -12,7 +12,10 @@ from huggingface_hub import snapshot_download
import vllm import vllm
from vllm.config import LoRAConfig from vllm.config import LoRAConfig
from vllm.distributed import destroy_model_parallel, initialize_model_parallel from vllm.distributed import (destroy_distributed_environment,
destroy_model_parallel,
init_distributed_environment,
initialize_model_parallel)
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear, MergedColumnParallelLinear,
RowParallelLinear) RowParallelLinear)
...@@ -35,6 +38,7 @@ LONG_LORA_INFOS = [{ ...@@ -35,6 +38,7 @@ LONG_LORA_INFOS = [{
def cleanup(): def cleanup():
destroy_model_parallel() destroy_model_parallel()
destroy_distributed_environment()
with contextlib.suppress(AssertionError): with contextlib.suppress(AssertionError):
torch.distributed.destroy_process_group() torch.distributed.destroy_process_group()
gc.collect() gc.collect()
...@@ -64,15 +68,14 @@ def cleanup_fixture(should_do_global_cleanup_after_test: bool): ...@@ -64,15 +68,14 @@ def cleanup_fixture(should_do_global_cleanup_after_test: bool):
@pytest.fixture @pytest.fixture
def dist_init(): def dist_init():
if not torch.distributed.is_initialized(): temp_file = tempfile.mkstemp()[1]
temp_file = tempfile.mkstemp()[1] init_distributed_environment(
torch.distributed.init_process_group( world_size=1,
backend="nccl", rank=0,
world_size=1, distributed_init_method=f"file://{temp_file}",
rank=0, local_rank=0,
init_method=f"file://{temp_file}", backend="nccl",
) )
torch.distributed.all_reduce(torch.zeros(1).cuda())
initialize_model_parallel(1, 1) initialize_model_parallel(1, 1)
yield yield
cleanup() cleanup()
......
import pytest import pytest
import torch import torch
from vllm.distributed.parallel_state import init_distributed_environment from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
init_distributed_environment)
from vllm.engine.arg_utils import EngineArgs from vllm.engine.arg_utils import EngineArgs
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
...@@ -292,6 +293,7 @@ def distributed_init(): ...@@ -292,6 +293,7 @@ def distributed_init():
rank=0, rank=0,
distributed_init_method=f"tcp://127.0.0.1:{get_open_port()}", distributed_init_method=f"tcp://127.0.0.1:{get_open_port()}",
local_rank=0) local_rank=0)
ensure_model_parallel_initialized(1, 1)
@pytest.mark.parametrize("batch_size", list(range(2, 128))) @pytest.mark.parametrize("batch_size", list(range(2, 128)))
......
...@@ -110,7 +110,7 @@ class PallasAttentionBackendImpl(AttentionImpl): ...@@ -110,7 +110,7 @@ class PallasAttentionBackendImpl(AttentionImpl):
raise NotImplementedError("TPU version must be 4 or higher.") raise NotImplementedError("TPU version must be 4 or higher.")
self.megacore_mode = None self.megacore_mode = None
tpu_type = torch_xla.tpu.get_tpu_env()["TYPE"].lower() tpu_type = torch_xla.tpu.get_tp_groupu_env()["TYPE"].lower()
if not tpu_type.endswith("lite"): if not tpu_type.endswith("lite"):
if self.num_kv_heads % 2 == 0: if self.num_kv_heads % 2 == 0:
self.megacore_mode = "kv_head" self.megacore_mode = "kv_head"
......
from collections import namedtuple from typing import Any, Dict, Optional, Union
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union
import torch import torch
from torch.distributed import ProcessGroup import torch.distributed
from .parallel_state import (get_cpu_world_group, get_pp_pynccl_communicator, from .parallel_state import get_tp_group
get_tensor_model_parallel_group,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
get_tp_ca_communicator,
get_tp_pynccl_communicator)
@dataclass
class GraphCaptureContext:
stream: torch.cuda.Stream
@contextmanager
def graph_capture():
"""
`graph_capture` is a context manager which should surround the code that
is capturing the CUDA graph. Its main purpose is to ensure that the
some operations will be run after the graph is captured, before the graph
is replayed. It returns a `GraphCaptureContext` object which contains the
necessary data for the graph capture. Currently, it only contains the
stream that the graph capture is running on. This stream is set to the
current CUDA stream when the context manager is entered and reset to the
default stream when the context manager is exited. This is to ensure that
the graph capture is running on a separate stream from the default stream,
in order to explicitly distinguish the kernels to capture
from other kernels possibly launched on background in the default stream.
"""
stream = torch.cuda.Stream()
graph_capture_context = GraphCaptureContext(stream)
ca_comm = get_tp_ca_communicator()
maybe_ca_context = nullcontext() if ca_comm is None else ca_comm.capture()
with torch.cuda.stream(stream), maybe_ca_context:
# In graph mode, we have to be very careful about the collective
# operations. The current status is:
# allreduce \ Mode | Eager | Graph |
# --------------------------------------------
# custom allreduce | enabled | enabled |
# PyNccl | disabled| enabled |
# torch.distributed | enabled | disabled|
#
# Note that custom allreduce will have a runtime check, if the tensor
# size is too large, it will fallback to the next available option.
# In summary: When using CUDA graph, we use
# either custom all-reduce kernel or pynccl. When not using CUDA
# graph, we use either custom all-reduce kernel or PyTorch NCCL.
# We always prioritize using custom all-reduce kernel but fall back
# to PyTorch or pynccl if it is disabled or not supported.
tp_pynccl_comm = get_tp_pynccl_communicator()
pp_pynccl_comm = get_pp_pynccl_communicator()
if not tp_pynccl_comm:
maybe_tp_pynccl_context = nullcontext()
else:
maybe_tp_pynccl_context = tp_pynccl_comm.change_state(
enable=True, stream=torch.cuda.current_stream())
if not pp_pynccl_comm:
maybe_pp_pynccl_context = nullcontext()
else:
maybe_pp_pynccl_context = pp_pynccl_comm.change_state(
enable=True, stream=torch.cuda.current_stream())
with maybe_tp_pynccl_context, maybe_pp_pynccl_context:
yield graph_capture_context
def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor: def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
"""All-reduce the input tensor across model parallel group. """All-reduce the input tensor across model parallel group."""
return get_tp_group().all_reduce(input_)
NOTE: This operation will be applied in-place on the input tensor if
disable_custom_all_reduce is set to True. Otherwise, this operation may or
may not be applied in place depending on whether custom all reduce is
invoked for a particular tensor, which further depends on the tensor size
and GPU topology.
TLDR: always assume this function modifies its input, but use the return
value as the output.
"""
ca_comm = get_tp_ca_communicator()
# Bypass the function if we are using only 1 GPU.
if get_tensor_model_parallel_world_size() == 1:
return input_
if ca_comm is not None:
out = ca_comm.custom_all_reduce(input_)
if out is not None:
return out
pynccl_comm = get_tp_pynccl_communicator()
if (pynccl_comm is not None and not pynccl_comm.disabled):
pynccl_comm.all_reduce(input_)
else:
torch.distributed.all_reduce(input_,
group=get_tensor_model_parallel_group())
return input_
def tensor_model_parallel_all_gather(input_: torch.Tensor, def tensor_model_parallel_all_gather(input_: torch.Tensor,
dim: int = -1) -> torch.Tensor: dim: int = -1) -> torch.Tensor:
"""All-gather the input tensor across model parallel group.""" """All-gather the input tensor across model parallel group."""
world_size = get_tensor_model_parallel_world_size() return get_tp_group().all_gather(input_, dim)
# Bypass the function if we are using only 1 GPU.
if world_size == 1:
return input_
assert -input_.dim() <= dim < input_.dim(), (
f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
if dim < 0:
# Convert negative dim to positive.
dim += input_.dim()
input_size = input_.size()
# Allocate output tensor.
output_tensor = torch.empty((world_size, ) + input_size,
dtype=input_.dtype,
device=input_.device)
# All-gather.
torch.distributed.all_gather_into_tensor(
output_tensor, input_, group=get_tensor_model_parallel_group())
# Reshape
output_tensor = output_tensor.movedim(0, dim)
output_tensor = output_tensor.reshape(input_size[:dim] +
(world_size * input_size[dim], ) +
input_size[dim + 1:])
return output_tensor
def tensor_model_parallel_gather(input_: torch.Tensor, def tensor_model_parallel_gather(input_: torch.Tensor,
dst: int = 0, dst: int = 0,
dim: int = -1) -> torch.Tensor: dim: int = -1) -> torch.Tensor:
"""Gather the input tensor across model parallel group. """Gather the input tensor across model parallel group."""
return get_tp_group().gather(input_, dst, dim)
NOTE: We assume that the input tensor is on the same device across
all the ranks.
"""
world_size = get_tensor_model_parallel_world_size()
# Bypass the function if we are using only 1 GPU.
if world_size == 1:
return input_
assert -input_.dim() <= dim < input_.dim(), (
f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
if dim < 0:
# Convert negative dim to positive.
dim += input_.dim()
# Allocate output tensor.
if get_tensor_model_parallel_rank() == dst:
gather_list = [torch.empty_like(input_) for _ in range(world_size)]
else:
gather_list = None
# Gather.
torch.distributed.gather(input_,
gather_list,
dst=dst,
group=get_tensor_model_parallel_group())
if get_tensor_model_parallel_rank() == dst:
output_tensor = torch.cat(gather_list, dim=dim)
else:
output_tensor = None
return output_tensor
def broadcast(input_: torch.Tensor,
src: int = 0,
group: Optional[ProcessGroup] = None):
"""Broadcast the input tensor."""
group = group or torch.distributed.group.WORLD
ranks = torch.distributed.get_process_group_ranks(group)
assert src in ranks, f"Invalid src rank ({src})"
# Bypass the function if we are using only 1 GPU.
world_size = torch.distributed.get_world_size(group=group)
if world_size == 1:
return input_
# Broadcast.
torch.distributed.broadcast(input_, src=src, group=group)
return input_
def broadcast_object_list(obj_list: List[Any], def broadcast_tensor_dict(tensor_dict: Optional[Dict[Any, Union[torch.Tensor,
src: int = 0, Any]]] = None,
group: Optional[ProcessGroup] = None): src: int = 0):
"""Broadcast the input object list.""" if not torch.distributed.is_initialized():
group = group or torch.distributed.group.WORLD
ranks = torch.distributed.get_process_group_ranks(group)
assert src in ranks, f"Invalid src rank ({src})"
# Bypass the function if we are using only 1 GPU.
world_size = torch.distributed.get_world_size(group=group)
if world_size == 1:
return obj_list
# Broadcast.
torch.distributed.broadcast_object_list(obj_list, src=src, group=group)
return obj_list
TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"])
def _split_tensor_dict(
tensor_dict: Dict[Any, Union[torch.Tensor, Any]]
) -> Tuple[List[Tuple[str, Any]], List[torch.Tensor]]:
"""Split the tensor dictionary into two parts:
1. A list of (key, value) pairs. If the value is a tensor, it is replaced
by its metadata.
2. A list of tensors.
"""
metadata_list = []
tensor_list = []
for key, value in tensor_dict.items():
if isinstance(value, torch.Tensor):
# Note: we cannot use `value.device` here,
# because it contains not only the device type but also the device
# index (e.g. "cuda:0"). We only need the device type.
# receiving side will set the device index.
device = "cpu" if value.is_cpu else "cuda"
metadata_list.append(
(key, TensorMetadata(device, value.dtype, value.size())))
tensor_list.append(value)
else:
metadata_list.append((key, value))
return metadata_list, tensor_list
def broadcast_tensor_dict(
tensor_dict: Optional[Dict[Any, Union[torch.Tensor, Any]]] = None,
src: int = 0,
group: Optional[ProcessGroup] = None,
metadata_group: Optional[ProcessGroup] = None
) -> Optional[Dict[Any, Union[torch.Tensor, Any]]]:
"""Broadcast the input tensor dictionary.
`group` is used to broadcast the tensors, while `metadata_group` is used
to broadcast the metadata of the dict (e.g. dict structure, tensor sizes,
dtypes).
"""
# Bypass the function if we are using only 1 GPU.
if (not torch.distributed.is_initialized()
or torch.distributed.get_world_size(group=group) == 1):
return tensor_dict return tensor_dict
return get_tp_group().broadcast_tensor_dict(tensor_dict, src)
group = group or torch.distributed.group.WORLD
metadata_group = metadata_group or get_cpu_world_group()
ranks = torch.distributed.get_process_group_ranks(group)
assert src in ranks, f"Invalid src rank ({src})"
rank = torch.distributed.get_rank()
if rank == src:
metadata_list: List[Tuple[Any, Any]] = []
assert isinstance(
tensor_dict,
dict), (f"Expecting a dictionary, got {type(tensor_dict)}")
metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
# `metadata_list` lives in CPU memory.
# `broadcast_object_list` involves serialization and deserialization,
# all happening on CPU. Therefore, we can use the CPU group.
torch.distributed.broadcast_object_list([metadata_list],
src=src,
group=metadata_group)
async_handles = []
for tensor in tensor_list:
if tensor.numel() == 0:
# Skip broadcasting empty tensors.
continue
if tensor.is_cpu:
# use metadata_group for CPU tensors
handle = torch.distributed.broadcast(tensor,
src=src,
group=metadata_group,
async_op=True)
else:
# use group for GPU tensors
handle = torch.distributed.broadcast(tensor,
src=src,
group=group,
async_op=True)
async_handles.append(handle)
for async_handle in async_handles:
async_handle.wait()
else:
recv_metadata_list = [None]
torch.distributed.broadcast_object_list(recv_metadata_list,
src=src,
group=metadata_group)
assert recv_metadata_list[0] is not None
tensor_dict = {}
async_handles = []
for key, value in recv_metadata_list[0]:
if isinstance(value, TensorMetadata):
tensor = torch.empty(value.size,
dtype=value.dtype,
device=value.device)
if tensor.numel() == 0:
# Skip broadcasting empty tensors.
tensor_dict[key] = tensor
continue
if tensor.is_cpu:
# use metadata_group for CPU tensors
handle = torch.distributed.broadcast(tensor,
src=src,
group=metadata_group,
async_op=True)
else:
# use group for GPU tensors
handle = torch.distributed.broadcast(tensor,
src=src,
group=group,
async_op=True)
async_handles.append(handle)
tensor_dict[key] = tensor
else:
tensor_dict[key] = value
for async_handle in async_handles:
async_handle.wait()
return tensor_dict
...@@ -9,8 +9,7 @@ import vllm.envs as envs ...@@ -9,8 +9,7 @@ import vllm.envs as envs
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.distributed.device_communicators.custom_all_reduce_utils import ( from vllm.distributed.device_communicators.custom_all_reduce_utils import (
gpu_p2p_access_check) gpu_p2p_access_check)
from vllm.distributed.parallel_state import ( from vllm.distributed.parallel_state import is_in_the_same_node
get_local_rank, get_tensor_model_parallel_cpu_group, is_in_the_same_node)
from vllm.logger import init_logger from vllm.logger import init_logger
try: try:
...@@ -86,8 +85,8 @@ class CustomAllreduce: ...@@ -86,8 +85,8 @@ class CustomAllreduce:
# max_size: max supported allreduce size # max_size: max supported allreduce size
def __init__(self, def __init__(self,
group: Optional[ProcessGroup] = None, group: ProcessGroup,
device: Optional[Union[int, str, torch.device]] = None, device: Union[int, str, torch.device],
max_size=8192 * 1024) -> None: max_size=8192 * 1024) -> None:
""" """
Args: Args:
...@@ -107,7 +106,6 @@ class CustomAllreduce: ...@@ -107,7 +106,6 @@ class CustomAllreduce:
# e.g. in a non-cuda environment # e.g. in a non-cuda environment
return return
group = group or get_tensor_model_parallel_cpu_group()
self.group = group self.group = group
assert dist.get_backend(group) != dist.Backend.NCCL, ( assert dist.get_backend(group) != dist.Backend.NCCL, (
...@@ -134,10 +132,7 @@ class CustomAllreduce: ...@@ -134,10 +132,7 @@ class CustomAllreduce:
world_size, str(CustomAllreduce._SUPPORTED_WORLD_SIZES)) world_size, str(CustomAllreduce._SUPPORTED_WORLD_SIZES))
return return
if device is None: if isinstance(device, int):
local_rank = get_local_rank()
device = torch.device(f"cuda:{local_rank}")
elif isinstance(device, int):
device = torch.device(f"cuda:{device}") device = torch.device(f"cuda:{device}")
elif isinstance(device, str): elif isinstance(device, str):
device = torch.device(device) device = torch.device(device)
......
...@@ -11,7 +11,6 @@ import torch.distributed as dist ...@@ -11,7 +11,6 @@ import torch.distributed as dist
import torch.multiprocessing as mp import torch.multiprocessing as mp
import vllm.envs as envs import vllm.envs as envs
from vllm.distributed.parallel_state import get_cpu_world_group, get_local_rank
from vllm.logger import init_logger from vllm.logger import init_logger
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -162,7 +161,8 @@ def gpu_p2p_access_check(i: int, j: int) -> bool: ...@@ -162,7 +161,8 @@ def gpu_p2p_access_check(i: int, j: int) -> bool:
f"{VLLM_CONFIG_ROOT}/vllm/gpu_p2p_access_cache_for_{cuda_visible_devices}.json" f"{VLLM_CONFIG_ROOT}/vllm/gpu_p2p_access_cache_for_{cuda_visible_devices}.json"
) )
os.makedirs(os.path.dirname(path), exist_ok=True) os.makedirs(os.path.dirname(path), exist_ok=True)
if ((not is_distributed or get_local_rank() == 0) from vllm.distributed.parallel_state import get_world_group
if ((not is_distributed or get_world_group().local_rank == 0)
and (not os.path.exists(path))): and (not os.path.exists(path))):
# only the local master process (with local_rank == 0) can # only the local master process (with local_rank == 0) can
# enter this block to calculate the cache # enter this block to calculate the cache
...@@ -174,8 +174,7 @@ def gpu_p2p_access_check(i: int, j: int) -> bool: ...@@ -174,8 +174,7 @@ def gpu_p2p_access_check(i: int, j: int) -> bool:
with open(path, "w") as f: with open(path, "w") as f:
json.dump(cache, f, indent=4) json.dump(cache, f, indent=4)
if is_distributed: if is_distributed:
cpu_world_group = get_cpu_world_group() get_world_group().barrier()
dist.barrier(cpu_world_group)
logger.info("reading GPU P2P access cache from %s", path) logger.info("reading GPU P2P access cache from %s", path)
with open(path, "r") as f: with open(path, "r") as f:
cache = json.load(f) cache = json.load(f)
......
...@@ -9,7 +9,6 @@ from torch.distributed import ProcessGroup, ReduceOp ...@@ -9,7 +9,6 @@ from torch.distributed import ProcessGroup, ReduceOp
from vllm.distributed.device_communicators.pynccl_wrapper import ( from vllm.distributed.device_communicators.pynccl_wrapper import (
NCCLLibrary, buffer_type, cudaStream_t, ncclComm_t, ncclDataTypeEnum, NCCLLibrary, buffer_type, cudaStream_t, ncclComm_t, ncclDataTypeEnum,
ncclRedOpTypeEnum, ncclUniqueId) ncclRedOpTypeEnum, ncclUniqueId)
from vllm.distributed.parallel_state import get_cpu_world_group, get_local_rank
from vllm.logger import init_logger from vllm.logger import init_logger
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -19,8 +18,8 @@ class PyNcclCommunicator: ...@@ -19,8 +18,8 @@ class PyNcclCommunicator:
def __init__( def __init__(
self, self,
group: Optional[ProcessGroup] = None, group: ProcessGroup,
device: Optional[Union[int, str, torch.device]] = None, device: Union[int, str, torch.device],
library_path: Optional[str] = None, library_path: Optional[str] = None,
): ):
""" """
...@@ -35,7 +34,6 @@ class PyNcclCommunicator: ...@@ -35,7 +34,6 @@ class PyNcclCommunicator:
is bind to a unique device. is bind to a unique device.
""" """
assert dist.is_initialized() assert dist.is_initialized()
group = get_cpu_world_group() if group is None else group
assert dist.get_backend(group) != dist.Backend.NCCL, ( assert dist.get_backend(group) != dist.Backend.NCCL, (
"PyNcclCommunicator should be attached to a non-NCCL group.") "PyNcclCommunicator should be attached to a non-NCCL group.")
self.group = group self.group = group
...@@ -77,10 +75,7 @@ class PyNcclCommunicator: ...@@ -77,10 +75,7 @@ class PyNcclCommunicator:
byte_list = tensor.tolist() byte_list = tensor.tolist()
for i, byte in enumerate(byte_list): for i, byte in enumerate(byte_list):
self.unique_id.internal[i] = byte self.unique_id.internal[i] = byte
if device is None: if isinstance(device, int):
local_rank = get_local_rank()
device = torch.device(f"cuda:{local_rank}")
elif isinstance(device, int):
device = torch.device(f"cuda:{device}") device = torch.device(f"cuda:{device}")
elif isinstance(device, str): elif isinstance(device, str):
device = torch.device(device) device = torch.device(device)
......
This diff is collapsed.
...@@ -13,7 +13,7 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ...@@ -13,7 +13,7 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ParallelConfig, SchedulerConfig, ModelConfig, ParallelConfig, SchedulerConfig,
VisionLanguageConfig) VisionLanguageConfig)
from vllm.distributed import broadcast_tensor_dict from vllm.distributed import broadcast_tensor_dict
from vllm.distributed.communication_op import graph_capture from vllm.distributed.parallel_state import graph_capture
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.layers import LoRAMapping from vllm.lora.layers import LoRAMapping
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment