Commit d2b52805 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.10.2rc1' into v0.10.2rc1-ori

parents 9a521c23 5438967f
......@@ -115,8 +115,8 @@ class CacheConfig:
In some KV sharing setups, e.g. YOCO (https://arxiv.org/abs/2405.05254),
some layers can skip tokens corresponding to prefill. This flag enables
attention metadata for eligible layers to be overriden with metadata
necessary for implementating this optimization in some models (e.g. Gemma3n)
attention metadata for eligible layers to be overridden with metadata
necessary for implementing this optimization in some models (e.g. Gemma3n)
"""
def compute_hash(self) -> str:
......@@ -145,12 +145,19 @@ class CacheConfig:
self._verify_cache_dtype()
self._verify_prefix_caching()
self._verify_kv_sharing_fast_prefill()
def metrics_info(self):
# convert cache_config to dict(key: str, value: str) for prometheus
# metrics info
return {key: str(value) for key, value in self.__dict__.items()}
def _verify_kv_sharing_fast_prefill(self) -> None:
if self.kv_sharing_fast_prefill and not envs.VLLM_USE_V1:
raise NotImplementedError(
"Fast prefill optimization for KV sharing is not supported "
"in V0 currently.")
@model_validator(mode='after')
def _verify_args(self) -> Self:
if self.cpu_offload_gb < 0:
......@@ -162,11 +169,6 @@ class CacheConfig:
"GPU memory utilization must be less than 1.0. Got "
f"{self.gpu_memory_utilization}.")
if self.kv_sharing_fast_prefill:
logger.warning_once(
"--kv-sharing-fast-prefill is currently work in progress "
"and not functional yet (i.e. no prefill savings)")
return self
def _verify_cache_dtype(self) -> None:
......
......@@ -225,7 +225,8 @@ class CompilationConfig:
# CudaGraph compilation
cudagraph_mode: Optional[CUDAGraphMode] = None
"""
The mode of the cudagraph.
The mode of the cudagraph:
- NONE, no cudagraph capture.
- PIECEWISE. (v1 default)
- FULL.
......@@ -336,6 +337,9 @@ class CompilationConfig:
"vllm.unified_attention",
"vllm.unified_attention_with_output",
"vllm.mamba_mixer2",
"vllm.mamba_mixer",
"vllm.short_conv",
"vllm.linear_attention",
]
def compute_hash(self) -> str:
......@@ -382,13 +386,10 @@ class CompilationConfig:
if pass_config_exclude:
exclude["pass_config"] = pass_config_exclude
# The cast to string is necessary because Pydantic is mocked in docs
# builds and sphinx-argparse doesn't know the return type of decode()
return str(
TypeAdapter(CompilationConfig).dump_json(
self,
exclude=exclude, # type: ignore[arg-type]
exclude_unset=True).decode())
return TypeAdapter(CompilationConfig).dump_json(
self,
exclude=exclude, # type: ignore[arg-type]
exclude_unset=True).decode()
__str__ = __repr__
......
......@@ -15,7 +15,7 @@ import vllm.envs as envs
from vllm.config.utils import config
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import cuda_device_count_stateless, get_open_port
from vllm.utils import cuda_device_count_stateless, get_open_ports_list
if TYPE_CHECKING:
from ray.runtime_env import RuntimeEnv
......@@ -32,6 +32,31 @@ logger = init_logger(__name__)
DistributedExecutorBackend = Literal["ray", "mp", "uni", "external_launcher"]
@config
@dataclass
class EPLBConfig:
"""Configuration for Expert Parallel Load Balancing (EP)."""
window_size: int = 1000
"""Window size for expert load recording."""
step_interval: int = 3000
"""
Interval for rearranging experts in expert parallelism.
Note that if this is greater than the EPLB window size, only the metrics
of the last `lb_window_size` steps will be used for rearranging experts.
"""
num_redundant_experts: int = 0
"""Number of redundant experts to use for expert parallelism."""
log_balancedness: bool = False
"""
Log the balancedness each step of expert parallelism.
This is turned off by default since it will cause communication overhead.
"""
@config
@dataclass
class ParallelConfig:
......@@ -75,22 +100,24 @@ class ParallelConfig:
"""Use expert parallelism instead of tensor parallelism for MoE layers."""
enable_eplb: bool = False
"""Enable expert parallelism load balancing for MoE layers."""
num_redundant_experts: int = 0
"""Number of redundant experts to use for expert parallelism."""
eplb_window_size: int = 1000
"""Window size for expert load recording."""
eplb_step_interval: int = 3000
"""
Interval for rearranging experts in expert parallelism.
Note that if this is greater than the EPLB window size, only the metrics
of the last `eplb_window_size` steps will be used for rearranging experts.
"""
eplb_log_balancedness: bool = False
"""
Log the balancedness each step of expert parallelism.
This is turned off by default since it will cause communication overhead.
"""
eplb_config: EPLBConfig = field(default_factory=EPLBConfig)
"""Expert parallelism configuration."""
num_redundant_experts: Optional[int] = None
"""`num_redundant_experts` is deprecated and has been replaced with
`eplb_config.num_redundant_experts`. This will be removed in v0.12.0.
Please use `eplb_config.num_redundant_experts` instead."""
eplb_window_size: Optional[int] = None
"""`eplb_window_size` is deprecated and has been replaced with
`eplb_config.window_size`. This will be removed in v0.12.0.
Please use `eplb_config.window_size` instead."""
eplb_step_interval: Optional[int] = None
"""`eplb_step_interval` is deprecated and has been replaced with
`eplb_config.step_interval`. This will be removed in v0.12.0.
Please use `eplb_config.step_interval` instead."""
eplb_log_balancedness: Optional[bool] = None
"""`eplb_log_balancedness` is deprecated and has been replaced with
`eplb_config.log_balancedness`. This will be removed in v0.12.0.
Please use `eplb_config.log_balancedness` instead."""
max_parallel_loading_workers: Optional[int] = None
"""Maximum number of parallel loading workers when loading model
......@@ -109,7 +136,8 @@ class ParallelConfig:
placement_group: Optional[PlacementGroup] = None
"""ray distributed model workers placement group."""
distributed_executor_backend: Optional[Union[DistributedExecutorBackend,
distributed_executor_backend: Optional[Union[str,
DistributedExecutorBackend,
type[ExecutorBase]]] = None
"""Backend to use for distributed model
workers, either "ray" or "mp" (multiprocessing). If the product
......@@ -137,9 +165,10 @@ class ParallelConfig:
rank: int = 0
"""Global rank in distributed setup."""
enable_multimodal_encoder_data_parallel: bool = False
""" Use data parallelism instead of tensor parallelism for vision encoder.
Only support LLama4 for now"""
_data_parallel_master_port_list: list[int] = field(default_factory=list)
"""List of open port auto-queried for data parallel messaging.
Set to be private as it's not intended to be configured by users.
"""
@property
def world_size_across_dp(self) -> int:
......@@ -153,11 +182,15 @@ class ParallelConfig:
processes that is related to data parallelism,
e.g. both in the worker and in the engine, which
can live in different processes. To avoid port conflicts, we
increment the port number each time we need to initialize a
new process group related to data parallelism.
pop a new port from the prepared port list each time we need to
initialize a new process group related to data parallelism.
"""
answer = self.data_parallel_master_port
self.data_parallel_master_port += 1
if self._data_parallel_master_port_list:
answer = self._data_parallel_master_port_list.pop()
else:
answer = self.data_parallel_master_port
self.data_parallel_master_port += 1
return answer
def stateless_init_dp_group(self) -> ProcessGroup:
......@@ -241,6 +274,38 @@ class ParallelConfig:
return hashlib.sha256(str(factors).encode()).hexdigest()
def __post_init__(self) -> None:
# Forward deprecated fields to their new location
if self.num_redundant_experts is not None:
self.eplb_config.num_redundant_experts = (
self.num_redundant_experts)
logger.warning_once(
"num_redundant_experts is deprecated and has been replaced "
"with eplb_config.num_redundant_experts. This will be removed "
"in v0.12.0. Changing this field after initialization will "
"have no effect.")
if self.eplb_window_size is not None:
self.eplb_config.window_size = self.eplb_window_size
logger.warning_once(
"eplb_window_size is deprecated and has been replaced "
"with eplb_config.window_size. This will be removed "
"in v0.12.0. Changing this field after initialization will "
"have no effect.")
if self.eplb_step_interval is not None:
self.eplb_config.step_interval = self.eplb_step_interval
logger.warning_once(
"eplb_step_interval is deprecated and has been replaced "
"with eplb_config.step_interval. This will be removed "
"in v0.12.0. Changing this field after initialization will "
"have no effect.")
if self.eplb_log_balancedness is not None:
self.eplb_config.log_balancedness = self.eplb_log_balancedness
logger.warning_once(
"eplb_log_balancedness is deprecated and has been replaced "
"with eplb_config.log_balancedness. This will be removed "
"in v0.12.0. Changing this field after initialization will "
"have no effect.")
# Continue with the rest of the initialization
self.world_size = self.pipeline_parallel_size * \
self.tensor_parallel_size
......@@ -251,7 +316,10 @@ class ParallelConfig:
if self.data_parallel_size > 1 or self.data_parallel_size_local == 0:
# Data parallel was specified in the engine args.
self.data_parallel_master_port = get_open_port()
if not self._data_parallel_master_port_list:
self._data_parallel_master_port_list = get_open_ports_list(5)
self.data_parallel_master_port = \
self._data_parallel_master_port_list.pop()
if not (0 <= self.data_parallel_rank < self.data_parallel_size):
raise ValueError(
......@@ -279,10 +347,10 @@ class ParallelConfig:
raise ValueError(
"Expert parallelism load balancing is only supported on "
"CUDA devices now.")
if self.num_redundant_experts < 0:
if self.eplb_config.num_redundant_experts < 0:
raise ValueError(
"num_redundant_experts must be non-negative, but got "
f"{self.num_redundant_experts}.")
f"{self.eplb_config.num_redundant_experts}.")
if not self.enable_expert_parallel:
raise ValueError(
"enable_expert_parallel must be True to use EPLB.")
......@@ -293,10 +361,10 @@ class ParallelConfig:
f"TP={self.tensor_parallel_size},DP={self.data_parallel_size}."
)
else:
if self.num_redundant_experts != 0:
if self.eplb_config.num_redundant_experts != 0:
raise ValueError(
"num_redundant_experts should be used with EPLB."
f"{self.num_redundant_experts}.")
f"{self.eplb_config.num_redundant_experts}.")
if self.distributed_executor_backend is None and self.world_size > 1:
# We use multiprocessing by default if world_size fits on the
# current node and we aren't in a ray placement group.
......@@ -342,23 +410,22 @@ class ParallelConfig:
def use_ray(self) -> bool:
return self.distributed_executor_backend == "ray" or (
isinstance(self.distributed_executor_backend, type)
and self.distributed_executor_backend.uses_ray)
and getattr(self.distributed_executor_backend, "uses_ray", False))
@model_validator(mode='after')
def _verify_args(self) -> Self:
# Lazy import to avoid circular import
from vllm.executor.executor_base import ExecutorBase
from vllm.platforms import current_platform
if self.distributed_executor_backend not in (
"ray", "mp", "uni",
"external_launcher", None) and not (isinstance(
if self.distributed_executor_backend is not None and not isinstance(
self.distributed_executor_backend, str) and not (isinstance(
self.distributed_executor_backend, type) and issubclass(
self.distributed_executor_backend, ExecutorBase)):
raise ValueError(
"Unrecognized distributed executor backend "
f"{self.distributed_executor_backend}. Supported "
"values are 'ray', 'mp' 'uni', 'external_launcher' or"
" custom ExecutorBase subclass.")
"values are 'ray', 'mp' 'uni', 'external_launcher', "
" custom ExecutorBase subclass or its import path.")
if self.use_ray:
from vllm.executor import ray_utils
ray_utils.assert_ray_available()
......
......@@ -207,7 +207,7 @@ class NaiveBlockAllocator(BlockAllocator):
Args:
absolute_id (int): The absolute block id for the block
in whole allocator.
in whole allocator.
Returns:
int: The zero-offset block id on certain device.
......
......@@ -61,7 +61,7 @@ class PrefixCachingBlockAllocator(BlockAllocator):
Args:
num_blocks (int): The total number of blocks to manage.
block_size (int): The size of each block in tokens.
block_ids(Optional[Iterable[int]], optional): An optional iterable of
block_ids (Optional[Iterable[int]], optional): An optional iterable of
block IDs. If not provided, block IDs will be assigned sequentially
from 0 to num_blocks - 1.
"""
......
......@@ -352,7 +352,7 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager):
with num_lookahead_slots.
Args:
sequence_group (SequenceGroup): The sequence group to swap in.
seq_group (SequenceGroup): The sequence group to swap in.
num_lookahead_slots (int): Number of lookahead slots used in
speculative decoding, default to 0.
......@@ -405,8 +405,6 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager):
Args:
seq_group (SequenceGroup): The sequence group to swap out.
num_lookahead_slots (int): Number of lookahead slots used in
speculative decoding, default to 0.
Returns:
bool: Whether it's possible to swap out current sequence group.
......@@ -420,7 +418,7 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager):
swapping out the given sequence_group with num_lookahead_slots.
Args:
sequence_group (SequenceGroup): The sequence group to swap out.
seq_group (SequenceGroup): The sequence group to swap out.
Returns:
List[Tuple[int, int]]: The mapping of swapping block from
......@@ -473,7 +471,7 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager):
on to the 'device'.
Args:
sequence_group (SequenceGroup): The sequence group to swap in/out.
seq_group (SequenceGroup): The sequence group to swap in/out.
device (Device): device to swap the 'seq_group' on.
status (SequenceStatus): The status of sequence which is needed
for action. RUNNING for swap out and SWAPPED for swap in
......
......@@ -76,7 +76,7 @@ class LRUEvictor(Evictor):
that's recorded in the Block. If there are multiple blocks with
the same last_accessed time, then the one with the largest num_hashed_tokens
will be evicted. If two blocks each have the lowest last_accessed time and
highest num_hashed_tokens value, then one will be chose arbitrarily
highest num_hashed_tokens value, then one will be chosen arbitrarily
"""
# CLEANUP_THRESHOLD determines the maximum allowable size of the priority
......
......@@ -657,7 +657,7 @@ class Scheduler:
`budget.num_batched_tokens` has not enough capacity to schedule
all tokens.
partial_prefill_metadata: information about the partial prefills
that are currently running
that are currently running
Returns:
SchedulerRunningOutputs.
......@@ -1591,7 +1591,6 @@ class Scheduler:
encoder_seq_data=encoder_seq_data,
cross_block_table=cross_block_table,
state=seq_group.state,
token_type_ids=seq_group.token_type_ids,
# `multi_modal_data` will only be present for the 1st comm
# between engine and worker.
# the subsequent comms can still use delta, but
......
......@@ -152,8 +152,13 @@ class CuMemAllocator:
self.pointer_to_data: dict[int, AllocationData] = {}
self.current_tag: str = CuMemAllocator.default_tag
self.allocator_and_pools: dict[str, Any] = {}
# Creating strong references to the two callbacks here to prevent
# these ephemeral bound-method objects being garbage collected.
# See discussions in https://github.com/vllm-project/vllm/pull/22724
self.python_malloc_callback = self._python_malloc_callback
self.python_free_callback = self._python_free_callback
def python_malloc_callback(self, allocation_handle: HandleType) -> None:
def _python_malloc_callback(self, allocation_handle: HandleType) -> None:
"""
Internal method to store the allocation data
when memory is allocated in the memory pool."""
......@@ -162,7 +167,7 @@ class CuMemAllocator:
allocation_handle, self.current_tag)
return
def python_free_callback(self, ptr: int) -> HandleType:
def _python_free_callback(self, ptr: int) -> HandleType:
"""
Internal method to look up the allocation data
when memory is freed in the memory pool."""
......@@ -212,9 +217,9 @@ class CuMemAllocator:
def wake_up(self, tags: Optional[list[str]] = None) -> None:
"""
Wake up the allocator from sleep mode.
All data that is previously offloaded will be loaded back to GPU
All data that is previously offloaded will be loaded back to GPU
memory, and the rest of the data will have empty memory.
:param tags: The tags of the memory allocation that will be loaded
back to GPU memory. If None, all memory allocation will be loaded
back to GPU memory.
......
......@@ -23,6 +23,39 @@ from vllm.utils import (cuda_device_count_stateless,
logger = init_logger(__name__)
MiB = 1024 * 1024
# Max size for each world size in case symmetric memory is available
# For different SM architectures
CUSTOM_ALL_REDUCE_MAX_SIZES = {
"9.0": {
2: 64 * MiB, # 64 MB
4: 32 * MiB, # 32 MB
6: MiB // 2, # 512 KB
8: MiB // 4, # 256 KB
},
"10.0": {
2: 2 * MiB, # 2 MB
4: 2 * MiB, # 2 MB
6: 2 * MiB, # 2 MB
8: 2 * MiB, # 2 MB
}
}
SYMM_MEM_ALL_REDUCE_MAX_SIZES = {
"9.0": {
2: 64 * MiB, # 64 MB
4: 32 * MiB, # 32 MB
6: 64 * MiB, # 64 MB
8: 64 * MiB, # 64 MB
},
"10.0": {
2: 8 * MiB, # 8 MB
4: 32 * MiB, # 32 MB
6: 128 * MiB, # 128 MB
8: 128 * MiB, # 128 MB
}
}
def producer(batch_src: Sequence[int],
producer_queue,
......
......@@ -255,7 +255,7 @@ class DeviceCommunicatorBase:
if module.__class__.__name__ == "FusedMoE"
]
for module in moe_modules:
module.quant_method.init_prepare_finalize()
module.quant_method.init_prepare_finalize(module)
def dispatch(
self, hidden_states: torch.Tensor,
......
......@@ -44,6 +44,8 @@ class CudaCommunicator(DeviceCommunicatorBase):
PyNcclCommunicator)
from vllm.distributed.device_communicators.quick_all_reduce import (
QuickAllReduce)
from vllm.distributed.device_communicators.symm_mem import (
SymmMemCommunicator)
self.pynccl_comm: Optional[PyNcclCommunicator] = None
if use_pynccl and self.world_size > 1:
......@@ -54,6 +56,7 @@ class CudaCommunicator(DeviceCommunicatorBase):
self.ca_comm: Optional[CustomAllreduce] = None
self.qr_comm: Optional[QuickAllReduce] = None
self.symm_mem_comm: Optional[SymmMemCommunicator] = None
if use_custom_allreduce and self.world_size > 1:
# Initialize a custom fast all-reduce implementation.
self.ca_comm = CustomAllreduce(
......@@ -69,6 +72,12 @@ class CudaCommunicator(DeviceCommunicatorBase):
# currently be an MI300 series.
self.qr_comm = QuickAllReduce(group=self.cpu_group,
device=self.device)
if envs.VLLM_ALLREDUCE_USE_SYMM_MEM and current_platform.is_cuda():
self.symm_mem_comm = SymmMemCommunicator(
group=self.cpu_group,
device=self.device,
)
if self.use_all2all:
all2all_backend = envs.VLLM_ALL2ALL_BACKEND
if all2all_backend == "naive":
......@@ -105,6 +114,12 @@ class CudaCommunicator(DeviceCommunicatorBase):
out = ca_comm.custom_all_reduce(input_)
assert out is not None
return out
symm_mem_comm = self.symm_mem_comm
if symm_mem_comm is not None and \
symm_mem_comm.should_use_symm_mem(input_):
out = symm_mem_comm.all_reduce(input_)
assert out is not None
return out
pynccl_comm = self.pynccl_comm
assert pynccl_comm is not None
out = pynccl_comm.all_reduce(input_)
......@@ -137,7 +152,7 @@ class CudaCommunicator(DeviceCommunicatorBase):
dtype=input_tensor.dtype,
device=input_tensor.device)
pynccl_comm.reduce_scatter(output, input_)
pynccl_comm.reduce_scatter(output, input_tensor)
# Reshape before returning
return output.movedim(0, dim).contiguous()
......@@ -171,9 +186,9 @@ class CudaCommunicator(DeviceCommunicatorBase):
device=input_tensor.device)
if sizes is not None:
pynccl_comm.reduce_scatterv(output, input_, sizes=sizes)
pynccl_comm.reduce_scatterv(output, input_tensor, sizes=sizes)
else:
pynccl_comm.reduce_scatter(output, input_)
pynccl_comm.reduce_scatter(output, input_tensor)
# Reshape before returning
return output.movedim(0, dim).contiguous()
......
......@@ -10,8 +10,8 @@ from torch.distributed import ProcessGroup
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.distributed.device_communicators.custom_all_reduce_utils import (
gpu_p2p_access_check)
from vllm.distributed.device_communicators.all_reduce_utils import (
CUSTOM_ALL_REDUCE_MAX_SIZES, gpu_p2p_access_check)
from vllm.distributed.parallel_state import in_the_same_node_as
from vllm.logger import init_logger
from vllm.platforms import current_platform
......@@ -109,7 +109,13 @@ class CustomAllreduce:
# now `device` is a `torch.device` object
assert isinstance(device, torch.device)
self.device = device
device_capability = current_platform.get_device_capability(
).as_version_str()
if (current_platform.is_cuda() and envs.VLLM_ALLREDUCE_USE_SYMM_MEM
and device_capability in CUSTOM_ALL_REDUCE_MAX_SIZES):
max_size = min(
CUSTOM_ALL_REDUCE_MAX_SIZES[device_capability][world_size],
max_size)
cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES
if cuda_visible_devices:
device_ids = list(map(int, cuda_visible_devices.split(",")))
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional, Union
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
from vllm.distributed.device_communicators.all_reduce_utils import (
SYMM_MEM_ALL_REDUCE_MAX_SIZES)
from vllm.logger import init_logger
from vllm.platforms import current_platform
try:
import torch.distributed._symmetric_memory as torch_symm_mem
symm_mem_available = True
except ImportError:
symm_mem_available = False
logger = init_logger(__name__)
class SymmMemCommunicator:
_WORLD_SIZES_MULTIMEM = {
"9.0": [4, 6, 8],
"10.0": [6, 8],
}
def __init__(self, group: ProcessGroup, device: Union[int, str,
torch.device]):
self.disabled = True
if not symm_mem_available:
return
if not current_platform.is_cuda():
logger.warning("SymmMemCommunicator: symmetric "
"memory is not available.")
return
if isinstance(device, int):
device = torch.device(f"cuda:{device}")
elif isinstance(device, str):
device = torch.device(device)
torch.cuda.set_device(device)
self.dtype = torch.bfloat16
self.device = device
self.group = group
self.world_size = dist.get_world_size(self.group)
self.device_capability = current_platform.get_device_capability(
).as_version_str()
if self.device_capability not in SYMM_MEM_ALL_REDUCE_MAX_SIZES:
logger.warning(
"SymmMemCommunicator: Device capability %s not supported, "
"communicator is not available.",
self.device_capability,
)
return
if self.world_size not in SYMM_MEM_ALL_REDUCE_MAX_SIZES[
self.device_capability]:
logger.warning(
"SymmMemCommunicator: World size %d not supported, "
"communicator is not available.",
self.world_size,
)
return
self.max_size = SYMM_MEM_ALL_REDUCE_MAX_SIZES[self.device_capability][
self.world_size]
self.buffer = torch_symm_mem.empty(
self.max_size // self.dtype.itemsize,
device=self.device,
dtype=self.dtype,
)
handle = torch_symm_mem.rendezvous(self.buffer, self.group.group_name)
if handle.multicast_ptr == 0:
logger.warning("SymmMemCommunicator: symmetric memory "
"multicast operations are not supported.")
return
self.disabled = False
def should_use_symm_mem(self, inp: torch.Tensor):
if self.disabled:
return False
if inp.dtype != self.dtype:
return False
inp_size = inp.numel() * inp.element_size()
if inp_size % 4 != 0:
return False
return inp_size < self.max_size
def all_reduce(
self,
inp: torch.Tensor,
*,
out: Optional[torch.Tensor] = None) -> Optional[torch.Tensor]:
if not self.should_use_symm_mem(inp):
return None
if out is None:
out = torch.empty_like(inp)
self.buffer[:inp.numel()].copy_(inp.view(-1))
if self.world_size in self._WORLD_SIZES_MULTIMEM[
self.device_capability]:
torch.ops.symm_mem.multimem_all_reduce_(self.buffer[:inp.numel()],
"sum",
self.group.group_name)
else:
torch.ops.symm_mem.two_shot_all_reduce_(self.buffer[:inp.numel()],
"sum",
self.group.group_name)
out.copy_(self.buffer[:inp.numel()].view(out.shape))
return out
......@@ -10,6 +10,7 @@ from torch.distributed import ProcessGroup
from vllm.config import get_current_vllm_config
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.platforms.tpu import USE_TPU_COMMONS
from .base_device_communicator import DeviceCommunicatorBase
......@@ -18,16 +19,17 @@ USE_RAY = parallel_config = get_current_vllm_config(
logger = init_logger(__name__)
if current_platform.is_tpu():
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
from torch_xla._internal import pjrt
from torch_xla.distributed.xla_multiprocessing import (
create_optimized_replica_groups)
if USE_RAY:
from vllm.executor import ray_utils
if not USE_TPU_COMMONS:
logger.info("tpu_commons not found, using vLLM's TpuCommunicator")
if current_platform.is_tpu():
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
from torch_xla._internal import pjrt
from torch_xla.distributed.xla_multiprocessing import (
create_optimized_replica_groups)
if USE_RAY:
from vllm.executor import ray_utils
class TpuCommunicator(DeviceCommunicatorBase):
......@@ -94,10 +96,7 @@ class TpuCommunicator(DeviceCommunicatorBase):
return xm.all_gather(input_, dim=dim)
try:
if USE_TPU_COMMONS:
from tpu_commons.distributed.device_communicators import (
TpuCommunicator as TpuCommonsCommunicator)
TpuCommunicator = TpuCommonsCommunicator # type: ignore
except ImportError:
logger.info("tpu_commons not found, using vLLM's TpuCommunicator")
pass
......@@ -7,8 +7,13 @@ import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
import vllm.envs as envs
from vllm.logger import init_logger
from .base_device_communicator import DeviceCommunicatorBase
logger = init_logger(__name__)
class XpuCommunicator(DeviceCommunicatorBase):
......@@ -18,6 +23,12 @@ class XpuCommunicator(DeviceCommunicatorBase):
device_group: Optional[ProcessGroup] = None,
unique_name: str = ""):
super().__init__(cpu_group, device, device_group, unique_name)
if self.use_all2all:
all2all_backend = envs.VLLM_ALL2ALL_BACKEND
if all2all_backend == "naive":
from .all2all import NaiveAll2AllManager
self.all2all_manager = NaiveAll2AllManager(self.cpu_group)
logger.info("Using naive all2all manager.")
def all_reduce(self, input_) -> torch.Tensor:
dist.all_reduce(input_, group=self.device_group)
......
......@@ -244,7 +244,7 @@ class EplbState:
dtype=torch.int32,
device=device,
)
expert_load_window_size = parallel_config.eplb_window_size
expert_load_window_size = parallel_config.eplb_config.window_size
expert_load_window = torch.zeros(
(expert_load_window_size, model.num_moe_layers,
model.num_physical_experts),
......@@ -253,7 +253,7 @@ class EplbState:
)
# Set the initial progress of rearrangement to 3/4
eplb_step_interval = parallel_config.eplb_step_interval
eplb_step_interval = parallel_config.eplb_config.step_interval
expert_rearrangement_step = max(
0, eplb_step_interval - eplb_step_interval // 4)
......@@ -409,12 +409,14 @@ class EplbState:
self.expert_rearrangement_step = 0
self.rearrange(model)
def rearrange(self,
model: MixtureOfExperts,
is_profile: bool = False,
execute_shuffle: bool = True,
global_expert_load: Optional[torch.Tensor] = None,
rank_mapping: Optional[dict[int, int]] = None) -> None:
def rearrange(
self,
model: MixtureOfExperts,
is_profile: bool = False,
execute_shuffle: bool = True,
global_expert_load: Optional[torch.Tensor] = None,
rank_mapping: Optional[dict[int,
int]] = None) -> Optional[torch.Tensor]:
"""
Rearrange the experts according to the current load.
"""
......@@ -548,6 +550,7 @@ class EplbState:
" (profile) " if is_profile else " ",
time_end - time_start,
)
return None
@staticmethod
def recv_state() -> tuple[torch.Tensor, torch.Tensor]:
......@@ -613,4 +616,4 @@ def _node_count_with_rank_mapping(
if is_same_node and node_assignment[other_rank] == 0:
node_assignment[other_rank] = next_node_id
return next_node_id
\ No newline at end of file
return next_node_id
......@@ -40,16 +40,21 @@ class KVCacheEvent(
"""Base class for all KV cache-related events"""
MEDIUM_GPU = "GPU"
class BlockStored(KVCacheEvent):
block_hashes: list[int]
parent_block_hash: Optional[int]
token_ids: list[int]
block_size: int
lora_id: Optional[int]
medium: Optional[str]
class BlockRemoved(KVCacheEvent):
block_hashes: list[int]
medium: Optional[str]
class AllBlocksCleared(KVCacheEvent):
......
......@@ -2,7 +2,7 @@
# Distributed KV cache transfer
This folder implements distributed KV cache transfer across vLLM instances.
Currently the main usecase is for disaggregated prefilling.
Currently the main use case is for disaggregated prefilling.
## Abstractions
......@@ -14,7 +14,7 @@ The KV cache transfer contains three layer of abstractions:
Why we need KV lookup buffer: FIFO pipe itself is not enough as prefill vLLM worker may process requests in a different order compared to decode vLLM worker. Say the QPS is really high, prefill worker may handle requests in order A -> B -> C, but the decode worker may process request C first. This is not the case that can be naturally handled by FIFO pipe, so we provide KV lookup buffer to help translate a FIFO pipe to a lookup buffer.
NOTE: KV pipe layer is bypassible: you can skip this layer if your distributed
NOTE: KV pipe layer is bypassable: you can skip this layer if your distributed
communication service already supports key-value-based lookup (like redis or
RDMA database).
......
......@@ -19,6 +19,8 @@ The class provides the following primitives:
Returns whether KV cache should be freed now or will be
freed asynchronously and optionally returns KV transfer
params.
take_events() - returns new KV events that were collected
by the connector since the last call.
Worker-side: runs in each worker, loads/saves KV cache to/from
the Connector based on the metadata.
......@@ -34,6 +36,7 @@ The class provides the following primitives:
import enum
from abc import ABC, abstractmethod
from collections.abc import Iterable
from typing import TYPE_CHECKING, Any, Callable, Literal, Optional
import torch
......@@ -45,6 +48,7 @@ from vllm.v1.outputs import KVConnectorOutput
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import VllmConfig
from vllm.distributed.kv_events import KVCacheEvent
from vllm.forward_context import ForwardContext
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.request import Request
......@@ -131,8 +135,8 @@ class KVConnectorBase_V1(ABC):
Initialize with the KV caches. Useful for pre-registering the
KV Caches in the KVConnector (e.g. for NIXL).
Args: kv_caches:
dictionary of layer names, kv cache
Args:
kv_caches: dictionary of layer names, kv cache
"""
return
......@@ -313,6 +317,15 @@ class KVConnectorBase_V1(ABC):
"""
return False, None
def take_events(self) -> Iterable["KVCacheEvent"]:
"""
Take the KV cache events from the connector.
Yields:
New KV cache events since the last call.
"""
return ()
@classmethod
def get_required_kvcache_layout(
cls, vllm_config: "VllmConfig") -> Optional[str]:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment