Commit d2b52805 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.10.2rc1' into v0.10.2rc1-ori

parents 9a521c23 5438967f
...@@ -115,8 +115,8 @@ class CacheConfig: ...@@ -115,8 +115,8 @@ class CacheConfig:
In some KV sharing setups, e.g. YOCO (https://arxiv.org/abs/2405.05254), In some KV sharing setups, e.g. YOCO (https://arxiv.org/abs/2405.05254),
some layers can skip tokens corresponding to prefill. This flag enables some layers can skip tokens corresponding to prefill. This flag enables
attention metadata for eligible layers to be overriden with metadata attention metadata for eligible layers to be overridden with metadata
necessary for implementating this optimization in some models (e.g. Gemma3n) necessary for implementing this optimization in some models (e.g. Gemma3n)
""" """
def compute_hash(self) -> str: def compute_hash(self) -> str:
...@@ -145,12 +145,19 @@ class CacheConfig: ...@@ -145,12 +145,19 @@ class CacheConfig:
self._verify_cache_dtype() self._verify_cache_dtype()
self._verify_prefix_caching() self._verify_prefix_caching()
self._verify_kv_sharing_fast_prefill()
def metrics_info(self): def metrics_info(self):
# convert cache_config to dict(key: str, value: str) for prometheus # convert cache_config to dict(key: str, value: str) for prometheus
# metrics info # metrics info
return {key: str(value) for key, value in self.__dict__.items()} return {key: str(value) for key, value in self.__dict__.items()}
def _verify_kv_sharing_fast_prefill(self) -> None:
if self.kv_sharing_fast_prefill and not envs.VLLM_USE_V1:
raise NotImplementedError(
"Fast prefill optimization for KV sharing is not supported "
"in V0 currently.")
@model_validator(mode='after') @model_validator(mode='after')
def _verify_args(self) -> Self: def _verify_args(self) -> Self:
if self.cpu_offload_gb < 0: if self.cpu_offload_gb < 0:
...@@ -162,11 +169,6 @@ class CacheConfig: ...@@ -162,11 +169,6 @@ class CacheConfig:
"GPU memory utilization must be less than 1.0. Got " "GPU memory utilization must be less than 1.0. Got "
f"{self.gpu_memory_utilization}.") f"{self.gpu_memory_utilization}.")
if self.kv_sharing_fast_prefill:
logger.warning_once(
"--kv-sharing-fast-prefill is currently work in progress "
"and not functional yet (i.e. no prefill savings)")
return self return self
def _verify_cache_dtype(self) -> None: def _verify_cache_dtype(self) -> None:
......
...@@ -225,7 +225,8 @@ class CompilationConfig: ...@@ -225,7 +225,8 @@ class CompilationConfig:
# CudaGraph compilation # CudaGraph compilation
cudagraph_mode: Optional[CUDAGraphMode] = None cudagraph_mode: Optional[CUDAGraphMode] = None
""" """
The mode of the cudagraph. The mode of the cudagraph:
- NONE, no cudagraph capture. - NONE, no cudagraph capture.
- PIECEWISE. (v1 default) - PIECEWISE. (v1 default)
- FULL. - FULL.
...@@ -336,6 +337,9 @@ class CompilationConfig: ...@@ -336,6 +337,9 @@ class CompilationConfig:
"vllm.unified_attention", "vllm.unified_attention",
"vllm.unified_attention_with_output", "vllm.unified_attention_with_output",
"vllm.mamba_mixer2", "vllm.mamba_mixer2",
"vllm.mamba_mixer",
"vllm.short_conv",
"vllm.linear_attention",
] ]
def compute_hash(self) -> str: def compute_hash(self) -> str:
...@@ -382,13 +386,10 @@ class CompilationConfig: ...@@ -382,13 +386,10 @@ class CompilationConfig:
if pass_config_exclude: if pass_config_exclude:
exclude["pass_config"] = pass_config_exclude exclude["pass_config"] = pass_config_exclude
# The cast to string is necessary because Pydantic is mocked in docs return TypeAdapter(CompilationConfig).dump_json(
# builds and sphinx-argparse doesn't know the return type of decode() self,
return str( exclude=exclude, # type: ignore[arg-type]
TypeAdapter(CompilationConfig).dump_json( exclude_unset=True).decode()
self,
exclude=exclude, # type: ignore[arg-type]
exclude_unset=True).decode())
__str__ = __repr__ __str__ = __repr__
......
...@@ -15,7 +15,7 @@ import vllm.envs as envs ...@@ -15,7 +15,7 @@ import vllm.envs as envs
from vllm.config.utils import config from vllm.config.utils import config
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import cuda_device_count_stateless, get_open_port from vllm.utils import cuda_device_count_stateless, get_open_ports_list
if TYPE_CHECKING: if TYPE_CHECKING:
from ray.runtime_env import RuntimeEnv from ray.runtime_env import RuntimeEnv
...@@ -32,6 +32,31 @@ logger = init_logger(__name__) ...@@ -32,6 +32,31 @@ logger = init_logger(__name__)
DistributedExecutorBackend = Literal["ray", "mp", "uni", "external_launcher"] DistributedExecutorBackend = Literal["ray", "mp", "uni", "external_launcher"]
@config
@dataclass
class EPLBConfig:
    """Configuration for Expert Parallel Load Balancing (EPLB)."""

    window_size: int = 1000
    """Window size for expert load recording."""
    step_interval: int = 3000
    """
    Interval for rearranging experts in expert parallelism.

    Note that if this is greater than the EPLB window size, only the metrics
    of the last `window_size` steps will be used for rearranging experts.
    """
    num_redundant_experts: int = 0
    """Number of redundant experts to use for expert parallelism."""
    log_balancedness: bool = False
    """
    Log the balancedness each step of expert parallelism.

    This is turned off by default since it will cause communication overhead.
    """
@config @config
@dataclass @dataclass
class ParallelConfig: class ParallelConfig:
...@@ -75,22 +100,24 @@ class ParallelConfig: ...@@ -75,22 +100,24 @@ class ParallelConfig:
"""Use expert parallelism instead of tensor parallelism for MoE layers.""" """Use expert parallelism instead of tensor parallelism for MoE layers."""
enable_eplb: bool = False enable_eplb: bool = False
"""Enable expert parallelism load balancing for MoE layers.""" """Enable expert parallelism load balancing for MoE layers."""
num_redundant_experts: int = 0 eplb_config: EPLBConfig = field(default_factory=EPLBConfig)
"""Number of redundant experts to use for expert parallelism.""" """Expert parallelism configuration."""
eplb_window_size: int = 1000 num_redundant_experts: Optional[int] = None
"""Window size for expert load recording.""" """`num_redundant_experts` is deprecated and has been replaced with
eplb_step_interval: int = 3000 `eplb_config.num_redundant_experts`. This will be removed in v0.12.0.
""" Please use `eplb_config.num_redundant_experts` instead."""
Interval for rearranging experts in expert parallelism. eplb_window_size: Optional[int] = None
"""`eplb_window_size` is deprecated and has been replaced with
Note that if this is greater than the EPLB window size, only the metrics `eplb_config.window_size`. This will be removed in v0.12.0.
of the last `eplb_window_size` steps will be used for rearranging experts. Please use `eplb_config.window_size` instead."""
""" eplb_step_interval: Optional[int] = None
eplb_log_balancedness: bool = False """`eplb_step_interval` is deprecated and has been replaced with
""" `eplb_config.step_interval`. This will be removed in v0.12.0.
Log the balancedness each step of expert parallelism. Please use `eplb_config.step_interval` instead."""
This is turned off by default since it will cause communication overhead. eplb_log_balancedness: Optional[bool] = None
""" """`eplb_log_balancedness` is deprecated and has been replaced with
`eplb_config.log_balancedness`. This will be removed in v0.12.0.
Please use `eplb_config.log_balancedness` instead."""
max_parallel_loading_workers: Optional[int] = None max_parallel_loading_workers: Optional[int] = None
"""Maximum number of parallel loading workers when loading model """Maximum number of parallel loading workers when loading model
...@@ -109,7 +136,8 @@ class ParallelConfig: ...@@ -109,7 +136,8 @@ class ParallelConfig:
placement_group: Optional[PlacementGroup] = None placement_group: Optional[PlacementGroup] = None
"""ray distributed model workers placement group.""" """ray distributed model workers placement group."""
distributed_executor_backend: Optional[Union[DistributedExecutorBackend, distributed_executor_backend: Optional[Union[str,
DistributedExecutorBackend,
type[ExecutorBase]]] = None type[ExecutorBase]]] = None
"""Backend to use for distributed model """Backend to use for distributed model
workers, either "ray" or "mp" (multiprocessing). If the product workers, either "ray" or "mp" (multiprocessing). If the product
...@@ -137,9 +165,10 @@ class ParallelConfig: ...@@ -137,9 +165,10 @@ class ParallelConfig:
rank: int = 0 rank: int = 0
"""Global rank in distributed setup.""" """Global rank in distributed setup."""
enable_multimodal_encoder_data_parallel: bool = False _data_parallel_master_port_list: list[int] = field(default_factory=list)
""" Use data parallelism instead of tensor parallelism for vision encoder. """List of open port auto-queried for data parallel messaging.
Only support LLama4 for now""" Set to be private as it's not intended to be configured by users.
"""
@property @property
def world_size_across_dp(self) -> int: def world_size_across_dp(self) -> int:
...@@ -153,11 +182,15 @@ class ParallelConfig: ...@@ -153,11 +182,15 @@ class ParallelConfig:
processes that is related to data parallelism, processes that is related to data parallelism,
e.g. both in the worker and in the engine, which e.g. both in the worker and in the engine, which
can live in different processes. To avoid port conflicts, we can live in different processes. To avoid port conflicts, we
increment the port number each time we need to initialize a pop a new port from the prepared port list each time we need to
new process group related to data parallelism. initialize a new process group related to data parallelism.
""" """
answer = self.data_parallel_master_port if self._data_parallel_master_port_list:
self.data_parallel_master_port += 1 answer = self._data_parallel_master_port_list.pop()
else:
answer = self.data_parallel_master_port
self.data_parallel_master_port += 1
return answer return answer
def stateless_init_dp_group(self) -> ProcessGroup: def stateless_init_dp_group(self) -> ProcessGroup:
...@@ -241,6 +274,38 @@ class ParallelConfig: ...@@ -241,6 +274,38 @@ class ParallelConfig:
return hashlib.sha256(str(factors).encode()).hexdigest() return hashlib.sha256(str(factors).encode()).hexdigest()
def __post_init__(self) -> None: def __post_init__(self) -> None:
# Forward deprecated fields to their new location
if self.num_redundant_experts is not None:
self.eplb_config.num_redundant_experts = (
self.num_redundant_experts)
logger.warning_once(
"num_redundant_experts is deprecated and has been replaced "
"with eplb_config.num_redundant_experts. This will be removed "
"in v0.12.0. Changing this field after initialization will "
"have no effect.")
if self.eplb_window_size is not None:
self.eplb_config.window_size = self.eplb_window_size
logger.warning_once(
"eplb_window_size is deprecated and has been replaced "
"with eplb_config.window_size. This will be removed "
"in v0.12.0. Changing this field after initialization will "
"have no effect.")
if self.eplb_step_interval is not None:
self.eplb_config.step_interval = self.eplb_step_interval
logger.warning_once(
"eplb_step_interval is deprecated and has been replaced "
"with eplb_config.step_interval. This will be removed "
"in v0.12.0. Changing this field after initialization will "
"have no effect.")
if self.eplb_log_balancedness is not None:
self.eplb_config.log_balancedness = self.eplb_log_balancedness
logger.warning_once(
"eplb_log_balancedness is deprecated and has been replaced "
"with eplb_config.log_balancedness. This will be removed "
"in v0.12.0. Changing this field after initialization will "
"have no effect.")
# Continue with the rest of the initialization
self.world_size = self.pipeline_parallel_size * \ self.world_size = self.pipeline_parallel_size * \
self.tensor_parallel_size self.tensor_parallel_size
...@@ -251,7 +316,10 @@ class ParallelConfig: ...@@ -251,7 +316,10 @@ class ParallelConfig:
if self.data_parallel_size > 1 or self.data_parallel_size_local == 0: if self.data_parallel_size > 1 or self.data_parallel_size_local == 0:
# Data parallel was specified in the engine args. # Data parallel was specified in the engine args.
self.data_parallel_master_port = get_open_port() if not self._data_parallel_master_port_list:
self._data_parallel_master_port_list = get_open_ports_list(5)
self.data_parallel_master_port = \
self._data_parallel_master_port_list.pop()
if not (0 <= self.data_parallel_rank < self.data_parallel_size): if not (0 <= self.data_parallel_rank < self.data_parallel_size):
raise ValueError( raise ValueError(
...@@ -279,10 +347,10 @@ class ParallelConfig: ...@@ -279,10 +347,10 @@ class ParallelConfig:
raise ValueError( raise ValueError(
"Expert parallelism load balancing is only supported on " "Expert parallelism load balancing is only supported on "
"CUDA devices now.") "CUDA devices now.")
if self.num_redundant_experts < 0: if self.eplb_config.num_redundant_experts < 0:
raise ValueError( raise ValueError(
"num_redundant_experts must be non-negative, but got " "num_redundant_experts must be non-negative, but got "
f"{self.num_redundant_experts}.") f"{self.eplb_config.num_redundant_experts}.")
if not self.enable_expert_parallel: if not self.enable_expert_parallel:
raise ValueError( raise ValueError(
"enable_expert_parallel must be True to use EPLB.") "enable_expert_parallel must be True to use EPLB.")
...@@ -293,10 +361,10 @@ class ParallelConfig: ...@@ -293,10 +361,10 @@ class ParallelConfig:
f"TP={self.tensor_parallel_size},DP={self.data_parallel_size}." f"TP={self.tensor_parallel_size},DP={self.data_parallel_size}."
) )
else: else:
if self.num_redundant_experts != 0: if self.eplb_config.num_redundant_experts != 0:
raise ValueError( raise ValueError(
"num_redundant_experts should be used with EPLB." "num_redundant_experts should be used with EPLB."
f"{self.num_redundant_experts}.") f"{self.eplb_config.num_redundant_experts}.")
if self.distributed_executor_backend is None and self.world_size > 1: if self.distributed_executor_backend is None and self.world_size > 1:
# We use multiprocessing by default if world_size fits on the # We use multiprocessing by default if world_size fits on the
# current node and we aren't in a ray placement group. # current node and we aren't in a ray placement group.
...@@ -342,23 +410,22 @@ class ParallelConfig: ...@@ -342,23 +410,22 @@ class ParallelConfig:
def use_ray(self) -> bool: def use_ray(self) -> bool:
return self.distributed_executor_backend == "ray" or ( return self.distributed_executor_backend == "ray" or (
isinstance(self.distributed_executor_backend, type) isinstance(self.distributed_executor_backend, type)
and self.distributed_executor_backend.uses_ray) and getattr(self.distributed_executor_backend, "uses_ray", False))
@model_validator(mode='after') @model_validator(mode='after')
def _verify_args(self) -> Self: def _verify_args(self) -> Self:
# Lazy import to avoid circular import # Lazy import to avoid circular import
from vllm.executor.executor_base import ExecutorBase from vllm.executor.executor_base import ExecutorBase
from vllm.platforms import current_platform from vllm.platforms import current_platform
if self.distributed_executor_backend not in ( if self.distributed_executor_backend is not None and not isinstance(
"ray", "mp", "uni", self.distributed_executor_backend, str) and not (isinstance(
"external_launcher", None) and not (isinstance(
self.distributed_executor_backend, type) and issubclass( self.distributed_executor_backend, type) and issubclass(
self.distributed_executor_backend, ExecutorBase)): self.distributed_executor_backend, ExecutorBase)):
raise ValueError( raise ValueError(
"Unrecognized distributed executor backend " "Unrecognized distributed executor backend "
f"{self.distributed_executor_backend}. Supported " f"{self.distributed_executor_backend}. Supported "
"values are 'ray', 'mp' 'uni', 'external_launcher' or" "values are 'ray', 'mp' 'uni', 'external_launcher', "
" custom ExecutorBase subclass.") " custom ExecutorBase subclass or its import path.")
if self.use_ray: if self.use_ray:
from vllm.executor import ray_utils from vllm.executor import ray_utils
ray_utils.assert_ray_available() ray_utils.assert_ray_available()
......
...@@ -207,7 +207,7 @@ class NaiveBlockAllocator(BlockAllocator): ...@@ -207,7 +207,7 @@ class NaiveBlockAllocator(BlockAllocator):
Args: Args:
absolute_id (int): The absolute block id for the block absolute_id (int): The absolute block id for the block
in whole allocator. in whole allocator.
Returns: Returns:
int: The zero-offset block id on certain device. int: The zero-offset block id on certain device.
......
...@@ -61,7 +61,7 @@ class PrefixCachingBlockAllocator(BlockAllocator): ...@@ -61,7 +61,7 @@ class PrefixCachingBlockAllocator(BlockAllocator):
Args: Args:
num_blocks (int): The total number of blocks to manage. num_blocks (int): The total number of blocks to manage.
block_size (int): The size of each block in tokens. block_size (int): The size of each block in tokens.
block_ids(Optional[Iterable[int]], optional): An optional iterable of block_ids (Optional[Iterable[int]], optional): An optional iterable of
block IDs. If not provided, block IDs will be assigned sequentially block IDs. If not provided, block IDs will be assigned sequentially
from 0 to num_blocks - 1. from 0 to num_blocks - 1.
""" """
......
...@@ -352,7 +352,7 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager): ...@@ -352,7 +352,7 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager):
with num_lookahead_slots. with num_lookahead_slots.
Args: Args:
sequence_group (SequenceGroup): The sequence group to swap in. seq_group (SequenceGroup): The sequence group to swap in.
num_lookahead_slots (int): Number of lookahead slots used in num_lookahead_slots (int): Number of lookahead slots used in
speculative decoding, default to 0. speculative decoding, default to 0.
...@@ -405,8 +405,6 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager): ...@@ -405,8 +405,6 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager):
Args: Args:
seq_group (SequenceGroup): The sequence group to swap out. seq_group (SequenceGroup): The sequence group to swap out.
num_lookahead_slots (int): Number of lookahead slots used in
speculative decoding, default to 0.
Returns: Returns:
bool: Whether it's possible to swap out current sequence group. bool: Whether it's possible to swap out current sequence group.
...@@ -420,7 +418,7 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager): ...@@ -420,7 +418,7 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager):
swapping out the given sequence_group with num_lookahead_slots. swapping out the given sequence_group with num_lookahead_slots.
Args: Args:
sequence_group (SequenceGroup): The sequence group to swap out. seq_group (SequenceGroup): The sequence group to swap out.
Returns: Returns:
List[Tuple[int, int]]: The mapping of swapping block from List[Tuple[int, int]]: The mapping of swapping block from
...@@ -473,7 +471,7 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager): ...@@ -473,7 +471,7 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager):
on to the 'device'. on to the 'device'.
Args: Args:
sequence_group (SequenceGroup): The sequence group to swap in/out. seq_group (SequenceGroup): The sequence group to swap in/out.
device (Device): device to swap the 'seq_group' on. device (Device): device to swap the 'seq_group' on.
status (SequenceStatus): The status of sequence which is needed status (SequenceStatus): The status of sequence which is needed
for action. RUNNING for swap out and SWAPPED for swap in for action. RUNNING for swap out and SWAPPED for swap in
......
...@@ -76,7 +76,7 @@ class LRUEvictor(Evictor): ...@@ -76,7 +76,7 @@ class LRUEvictor(Evictor):
that's recorded in the Block. If there are multiple blocks with that's recorded in the Block. If there are multiple blocks with
the same last_accessed time, then the one with the largest num_hashed_tokens the same last_accessed time, then the one with the largest num_hashed_tokens
will be evicted. If two blocks each have the lowest last_accessed time and will be evicted. If two blocks each have the lowest last_accessed time and
highest num_hashed_tokens value, then one will be chose arbitrarily highest num_hashed_tokens value, then one will be chosen arbitrarily
""" """
# CLEANUP_THRESHOLD determines the maximum allowable size of the priority # CLEANUP_THRESHOLD determines the maximum allowable size of the priority
......
...@@ -657,7 +657,7 @@ class Scheduler: ...@@ -657,7 +657,7 @@ class Scheduler:
`budget.num_batched_tokens` has not enough capacity to schedule `budget.num_batched_tokens` has not enough capacity to schedule
all tokens. all tokens.
partial_prefill_metadata: information about the partial prefills partial_prefill_metadata: information about the partial prefills
that are currently running that are currently running
Returns: Returns:
SchedulerRunningOutputs. SchedulerRunningOutputs.
...@@ -1591,7 +1591,6 @@ class Scheduler: ...@@ -1591,7 +1591,6 @@ class Scheduler:
encoder_seq_data=encoder_seq_data, encoder_seq_data=encoder_seq_data,
cross_block_table=cross_block_table, cross_block_table=cross_block_table,
state=seq_group.state, state=seq_group.state,
token_type_ids=seq_group.token_type_ids,
# `multi_modal_data` will only be present for the 1st comm # `multi_modal_data` will only be present for the 1st comm
# between engine and worker. # between engine and worker.
# the subsequent comms can still use delta, but # the subsequent comms can still use delta, but
......
...@@ -152,8 +152,13 @@ class CuMemAllocator: ...@@ -152,8 +152,13 @@ class CuMemAllocator:
self.pointer_to_data: dict[int, AllocationData] = {} self.pointer_to_data: dict[int, AllocationData] = {}
self.current_tag: str = CuMemAllocator.default_tag self.current_tag: str = CuMemAllocator.default_tag
self.allocator_and_pools: dict[str, Any] = {} self.allocator_and_pools: dict[str, Any] = {}
# Creating strong references to the two callbacks here to prevent
# these ephemeral bound-method objects being garbage collected.
# See discussions in https://github.com/vllm-project/vllm/pull/22724
self.python_malloc_callback = self._python_malloc_callback
self.python_free_callback = self._python_free_callback
def python_malloc_callback(self, allocation_handle: HandleType) -> None: def _python_malloc_callback(self, allocation_handle: HandleType) -> None:
""" """
Internal method to store the allocation data Internal method to store the allocation data
when memory is allocated in the memory pool.""" when memory is allocated in the memory pool."""
...@@ -162,7 +167,7 @@ class CuMemAllocator: ...@@ -162,7 +167,7 @@ class CuMemAllocator:
allocation_handle, self.current_tag) allocation_handle, self.current_tag)
return return
def python_free_callback(self, ptr: int) -> HandleType: def _python_free_callback(self, ptr: int) -> HandleType:
""" """
Internal method to look up the allocation data Internal method to look up the allocation data
when memory is freed in the memory pool.""" when memory is freed in the memory pool."""
...@@ -212,9 +217,9 @@ class CuMemAllocator: ...@@ -212,9 +217,9 @@ class CuMemAllocator:
def wake_up(self, tags: Optional[list[str]] = None) -> None: def wake_up(self, tags: Optional[list[str]] = None) -> None:
""" """
Wake up the allocator from sleep mode. Wake up the allocator from sleep mode.
All data that is previously offloaded will be loaded back to GPU All data that is previously offloaded will be loaded back to GPU
memory, and the rest of the data will have empty memory. memory, and the rest of the data will have empty memory.
:param tags: The tags of the memory allocation that will be loaded :param tags: The tags of the memory allocation that will be loaded
back to GPU memory. If None, all memory allocation will be loaded back to GPU memory. If None, all memory allocation will be loaded
back to GPU memory. back to GPU memory.
......
...@@ -23,6 +23,39 @@ from vllm.utils import (cuda_device_count_stateless, ...@@ -23,6 +23,39 @@ from vllm.utils import (cuda_device_count_stateless,
logger = init_logger(__name__) logger = init_logger(__name__)
# Number of bytes in one mebibyte.
MiB = 1 << 20

# Maximum message sizes (in bytes) used when symmetric memory is available.
# Outer key: device capability version string (SM architecture).
# Inner key: world size.
CUSTOM_ALL_REDUCE_MAX_SIZES = {
    "9.0": {
        2: 64 * MiB,
        4: 32 * MiB,
        6: 512 * 1024,  # 512 KiB
        8: 256 * 1024,  # 256 KiB
    },
    # Every supported world size shares the same 2 MiB cap on 10.0.
    "10.0": dict.fromkeys((2, 4, 6, 8), 2 * MiB),
}

# Same layout as above, for the symmetric-memory all-reduce path.
SYMM_MEM_ALL_REDUCE_MAX_SIZES = {
    "9.0": {
        2: 64 * MiB,
        4: 32 * MiB,
        6: 64 * MiB,
        8: 64 * MiB,
    },
    "10.0": {
        2: 8 * MiB,
        4: 32 * MiB,
        6: 128 * MiB,
        8: 128 * MiB,
    },
}
def producer(batch_src: Sequence[int], def producer(batch_src: Sequence[int],
producer_queue, producer_queue,
......
...@@ -255,7 +255,7 @@ class DeviceCommunicatorBase: ...@@ -255,7 +255,7 @@ class DeviceCommunicatorBase:
if module.__class__.__name__ == "FusedMoE" if module.__class__.__name__ == "FusedMoE"
] ]
for module in moe_modules: for module in moe_modules:
module.quant_method.init_prepare_finalize() module.quant_method.init_prepare_finalize(module)
def dispatch( def dispatch(
self, hidden_states: torch.Tensor, self, hidden_states: torch.Tensor,
......
...@@ -44,6 +44,8 @@ class CudaCommunicator(DeviceCommunicatorBase): ...@@ -44,6 +44,8 @@ class CudaCommunicator(DeviceCommunicatorBase):
PyNcclCommunicator) PyNcclCommunicator)
from vllm.distributed.device_communicators.quick_all_reduce import ( from vllm.distributed.device_communicators.quick_all_reduce import (
QuickAllReduce) QuickAllReduce)
from vllm.distributed.device_communicators.symm_mem import (
SymmMemCommunicator)
self.pynccl_comm: Optional[PyNcclCommunicator] = None self.pynccl_comm: Optional[PyNcclCommunicator] = None
if use_pynccl and self.world_size > 1: if use_pynccl and self.world_size > 1:
...@@ -54,6 +56,7 @@ class CudaCommunicator(DeviceCommunicatorBase): ...@@ -54,6 +56,7 @@ class CudaCommunicator(DeviceCommunicatorBase):
self.ca_comm: Optional[CustomAllreduce] = None self.ca_comm: Optional[CustomAllreduce] = None
self.qr_comm: Optional[QuickAllReduce] = None self.qr_comm: Optional[QuickAllReduce] = None
self.symm_mem_comm: Optional[SymmMemCommunicator] = None
if use_custom_allreduce and self.world_size > 1: if use_custom_allreduce and self.world_size > 1:
# Initialize a custom fast all-reduce implementation. # Initialize a custom fast all-reduce implementation.
self.ca_comm = CustomAllreduce( self.ca_comm = CustomAllreduce(
...@@ -69,6 +72,12 @@ class CudaCommunicator(DeviceCommunicatorBase): ...@@ -69,6 +72,12 @@ class CudaCommunicator(DeviceCommunicatorBase):
# currently be an MI300 series. # currently be an MI300 series.
self.qr_comm = QuickAllReduce(group=self.cpu_group, self.qr_comm = QuickAllReduce(group=self.cpu_group,
device=self.device) device=self.device)
if envs.VLLM_ALLREDUCE_USE_SYMM_MEM and current_platform.is_cuda():
self.symm_mem_comm = SymmMemCommunicator(
group=self.cpu_group,
device=self.device,
)
if self.use_all2all: if self.use_all2all:
all2all_backend = envs.VLLM_ALL2ALL_BACKEND all2all_backend = envs.VLLM_ALL2ALL_BACKEND
if all2all_backend == "naive": if all2all_backend == "naive":
...@@ -105,6 +114,12 @@ class CudaCommunicator(DeviceCommunicatorBase): ...@@ -105,6 +114,12 @@ class CudaCommunicator(DeviceCommunicatorBase):
out = ca_comm.custom_all_reduce(input_) out = ca_comm.custom_all_reduce(input_)
assert out is not None assert out is not None
return out return out
symm_mem_comm = self.symm_mem_comm
if symm_mem_comm is not None and \
symm_mem_comm.should_use_symm_mem(input_):
out = symm_mem_comm.all_reduce(input_)
assert out is not None
return out
pynccl_comm = self.pynccl_comm pynccl_comm = self.pynccl_comm
assert pynccl_comm is not None assert pynccl_comm is not None
out = pynccl_comm.all_reduce(input_) out = pynccl_comm.all_reduce(input_)
...@@ -137,7 +152,7 @@ class CudaCommunicator(DeviceCommunicatorBase): ...@@ -137,7 +152,7 @@ class CudaCommunicator(DeviceCommunicatorBase):
dtype=input_tensor.dtype, dtype=input_tensor.dtype,
device=input_tensor.device) device=input_tensor.device)
pynccl_comm.reduce_scatter(output, input_) pynccl_comm.reduce_scatter(output, input_tensor)
# Reshape before returning # Reshape before returning
return output.movedim(0, dim).contiguous() return output.movedim(0, dim).contiguous()
...@@ -171,9 +186,9 @@ class CudaCommunicator(DeviceCommunicatorBase): ...@@ -171,9 +186,9 @@ class CudaCommunicator(DeviceCommunicatorBase):
device=input_tensor.device) device=input_tensor.device)
if sizes is not None: if sizes is not None:
pynccl_comm.reduce_scatterv(output, input_, sizes=sizes) pynccl_comm.reduce_scatterv(output, input_tensor, sizes=sizes)
else: else:
pynccl_comm.reduce_scatter(output, input_) pynccl_comm.reduce_scatter(output, input_tensor)
# Reshape before returning # Reshape before returning
return output.movedim(0, dim).contiguous() return output.movedim(0, dim).contiguous()
......
...@@ -10,8 +10,8 @@ from torch.distributed import ProcessGroup ...@@ -10,8 +10,8 @@ from torch.distributed import ProcessGroup
import vllm.envs as envs import vllm.envs as envs
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.distributed.device_communicators.custom_all_reduce_utils import ( from vllm.distributed.device_communicators.all_reduce_utils import (
gpu_p2p_access_check) CUSTOM_ALL_REDUCE_MAX_SIZES, gpu_p2p_access_check)
from vllm.distributed.parallel_state import in_the_same_node_as from vllm.distributed.parallel_state import in_the_same_node_as
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
...@@ -109,7 +109,13 @@ class CustomAllreduce: ...@@ -109,7 +109,13 @@ class CustomAllreduce:
# now `device` is a `torch.device` object # now `device` is a `torch.device` object
assert isinstance(device, torch.device) assert isinstance(device, torch.device)
self.device = device self.device = device
device_capability = current_platform.get_device_capability(
).as_version_str()
if (current_platform.is_cuda() and envs.VLLM_ALLREDUCE_USE_SYMM_MEM
and device_capability in CUSTOM_ALL_REDUCE_MAX_SIZES):
max_size = min(
CUSTOM_ALL_REDUCE_MAX_SIZES[device_capability][world_size],
max_size)
cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES
if cuda_visible_devices: if cuda_visible_devices:
device_ids = list(map(int, cuda_visible_devices.split(","))) device_ids = list(map(int, cuda_visible_devices.split(",")))
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional, Union
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
from vllm.distributed.device_communicators.all_reduce_utils import (
SYMM_MEM_ALL_REDUCE_MAX_SIZES)
from vllm.logger import init_logger
from vllm.platforms import current_platform
try:
import torch.distributed._symmetric_memory as torch_symm_mem
symm_mem_available = True
except ImportError:
symm_mem_available = False
logger = init_logger(__name__)
class SymmMemCommunicator:
    """All-reduce communicator backed by a torch symmetric-memory buffer.

    The communicator starts out disabled and only enables itself once every
    check in ``__init__`` has passed. ``all_reduce`` returns ``None`` whenever
    the communicator declines an input (see ``should_use_symm_mem``).
    """

    # World sizes, keyed by device capability string, for which `all_reduce`
    # dispatches to the multimem kernel; every other world size uses the
    # two-shot kernel.
    _WORLD_SIZES_MULTIMEM = {
        "9.0": [4, 6, 8],
        "10.0": [6, 8],
    }

    def __init__(self, group: ProcessGroup, device: Union[int, str,
                                                          torch.device]):
        """Allocate and rendezvous the symmetric-memory staging buffer.

        Args:
            group: Process group the all-reduce runs over.
            device: Device to bind to. An ``int`` is interpreted as a CUDA
                device index, a ``str`` is passed to ``torch.device``.

        On any unsupported configuration the constructor returns early and
        leaves ``self.disabled`` set to ``True``.
        """
        self.disabled = True
        if not symm_mem_available:
            return

        if not current_platform.is_cuda():
            logger.warning("SymmMemCommunicator: symmetric "
                           "memory is not available.")
            return

        if isinstance(device, int):
            device = torch.device(f"cuda:{device}")
        elif isinstance(device, str):
            device = torch.device(device)
        torch.cuda.set_device(device)

        self.dtype = torch.bfloat16
        self.device = device
        self.group = group
        self.world_size = dist.get_world_size(self.group)
        self.device_capability = current_platform.get_device_capability(
        ).as_version_str()

        if self.device_capability not in SYMM_MEM_ALL_REDUCE_MAX_SIZES:
            logger.warning(
                "SymmMemCommunicator: Device capability %s not supported, "
                "communicator is not available.",
                self.device_capability,
            )
            return

        sizes_for_capability = SYMM_MEM_ALL_REDUCE_MAX_SIZES[
            self.device_capability]
        if self.world_size not in sizes_for_capability:
            logger.warning(
                "SymmMemCommunicator: World size %d not supported, "
                "communicator is not available.",
                self.world_size,
            )
            return

        # max_size is in bytes; the buffer is sized in elements of self.dtype.
        self.max_size = sizes_for_capability[self.world_size]
        self.buffer = torch_symm_mem.empty(
            self.max_size // self.dtype.itemsize,
            device=self.device,
            dtype=self.dtype,
        )
        handle = torch_symm_mem.rendezvous(self.buffer, self.group.group_name)
        if handle.multicast_ptr == 0:
            logger.warning("SymmMemCommunicator: symmetric memory "
                           "multicast operations are not supported.")
            return

        self.disabled = False

    def should_use_symm_mem(self, inp: torch.Tensor) -> bool:
        """Return whether ``inp`` can be reduced through this communicator.

        The input is accepted only when the communicator is enabled, the
        dtype matches ``self.dtype``, the byte size is a multiple of 4 and
        the byte size is strictly smaller than ``self.max_size``.
        """
        if self.disabled:
            return False
        if inp.dtype != self.dtype:
            return False
        inp_size = inp.numel() * inp.element_size()
        if inp_size % 4 != 0:
            return False
        return inp_size < self.max_size

    def _use_multimem(self) -> bool:
        """Return whether the multimem kernel applies to this configuration.

        A capability missing from ``_WORLD_SIZES_MULTIMEM`` selects the
        two-shot kernel instead of raising ``KeyError``.
        """
        multimem_sizes = self._WORLD_SIZES_MULTIMEM.get(
            self.device_capability, ())
        return self.world_size in multimem_sizes

    def all_reduce(
            self,
            inp: torch.Tensor,
            *,
            out: Optional[torch.Tensor] = None) -> Optional[torch.Tensor]:
        """Sum-reduce ``inp`` across the group via the symmetric buffer.

        Args:
            inp: Tensor to reduce.
            out: Optional destination tensor; a new tensor shaped like
                ``inp`` is allocated when omitted.

        Returns:
            The tensor holding the reduced values, or ``None`` when
            ``should_use_symm_mem`` rejects ``inp``.
        """
        if not self.should_use_symm_mem(inp):
            return None
        if out is None:
            out = torch.empty_like(inp)

        staging = self.buffer[:inp.numel()]
        # reshape (rather than view) also accepts non-contiguous inputs; for
        # contiguous inputs it is the same zero-copy flatten.
        staging.copy_(inp.reshape(-1))
        if self._use_multimem():
            torch.ops.symm_mem.multimem_all_reduce_(staging, "sum",
                                                    self.group.group_name)
        else:
            torch.ops.symm_mem.two_shot_all_reduce_(staging, "sum",
                                                    self.group.group_name)
        out.copy_(staging.view(out.shape))
        return out
...@@ -10,6 +10,7 @@ from torch.distributed import ProcessGroup ...@@ -10,6 +10,7 @@ from torch.distributed import ProcessGroup
from vllm.config import get_current_vllm_config from vllm.config import get_current_vllm_config
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.platforms.tpu import USE_TPU_COMMONS
from .base_device_communicator import DeviceCommunicatorBase from .base_device_communicator import DeviceCommunicatorBase
...@@ -18,16 +19,17 @@ USE_RAY = parallel_config = get_current_vllm_config( ...@@ -18,16 +19,17 @@ USE_RAY = parallel_config = get_current_vllm_config(
logger = init_logger(__name__) logger = init_logger(__name__)
if current_platform.is_tpu(): if not USE_TPU_COMMONS:
import torch_xla logger.info("tpu_commons not found, using vLLM's TpuCommunicator")
import torch_xla.core.xla_model as xm if current_platform.is_tpu():
import torch_xla.runtime as xr import torch_xla
from torch_xla._internal import pjrt import torch_xla.core.xla_model as xm
from torch_xla.distributed.xla_multiprocessing import ( import torch_xla.runtime as xr
create_optimized_replica_groups) from torch_xla._internal import pjrt
from torch_xla.distributed.xla_multiprocessing import (
if USE_RAY: create_optimized_replica_groups)
from vllm.executor import ray_utils if USE_RAY:
from vllm.executor import ray_utils
class TpuCommunicator(DeviceCommunicatorBase): class TpuCommunicator(DeviceCommunicatorBase):
...@@ -94,10 +96,7 @@ class TpuCommunicator(DeviceCommunicatorBase): ...@@ -94,10 +96,7 @@ class TpuCommunicator(DeviceCommunicatorBase):
return xm.all_gather(input_, dim=dim) return xm.all_gather(input_, dim=dim)
try: if USE_TPU_COMMONS:
from tpu_commons.distributed.device_communicators import ( from tpu_commons.distributed.device_communicators import (
TpuCommunicator as TpuCommonsCommunicator) TpuCommunicator as TpuCommonsCommunicator)
TpuCommunicator = TpuCommonsCommunicator # type: ignore TpuCommunicator = TpuCommonsCommunicator # type: ignore
except ImportError:
logger.info("tpu_commons not found, using vLLM's TpuCommunicator")
pass
...@@ -7,8 +7,13 @@ import torch ...@@ -7,8 +7,13 @@ import torch
import torch.distributed as dist import torch.distributed as dist
from torch.distributed import ProcessGroup from torch.distributed import ProcessGroup
import vllm.envs as envs
from vllm.logger import init_logger
from .base_device_communicator import DeviceCommunicatorBase from .base_device_communicator import DeviceCommunicatorBase
logger = init_logger(__name__)
class XpuCommunicator(DeviceCommunicatorBase): class XpuCommunicator(DeviceCommunicatorBase):
...@@ -18,6 +23,12 @@ class XpuCommunicator(DeviceCommunicatorBase): ...@@ -18,6 +23,12 @@ class XpuCommunicator(DeviceCommunicatorBase):
device_group: Optional[ProcessGroup] = None, device_group: Optional[ProcessGroup] = None,
unique_name: str = ""): unique_name: str = ""):
super().__init__(cpu_group, device, device_group, unique_name) super().__init__(cpu_group, device, device_group, unique_name)
if self.use_all2all:
all2all_backend = envs.VLLM_ALL2ALL_BACKEND
if all2all_backend == "naive":
from .all2all import NaiveAll2AllManager
self.all2all_manager = NaiveAll2AllManager(self.cpu_group)
logger.info("Using naive all2all manager.")
def all_reduce(self, input_) -> torch.Tensor: def all_reduce(self, input_) -> torch.Tensor:
dist.all_reduce(input_, group=self.device_group) dist.all_reduce(input_, group=self.device_group)
......
...@@ -244,7 +244,7 @@ class EplbState: ...@@ -244,7 +244,7 @@ class EplbState:
dtype=torch.int32, dtype=torch.int32,
device=device, device=device,
) )
expert_load_window_size = parallel_config.eplb_window_size expert_load_window_size = parallel_config.eplb_config.window_size
expert_load_window = torch.zeros( expert_load_window = torch.zeros(
(expert_load_window_size, model.num_moe_layers, (expert_load_window_size, model.num_moe_layers,
model.num_physical_experts), model.num_physical_experts),
...@@ -253,7 +253,7 @@ class EplbState: ...@@ -253,7 +253,7 @@ class EplbState:
) )
# Set the initial progress of rearrangement to 3/4 # Set the initial progress of rearrangement to 3/4
eplb_step_interval = parallel_config.eplb_step_interval eplb_step_interval = parallel_config.eplb_config.step_interval
expert_rearrangement_step = max( expert_rearrangement_step = max(
0, eplb_step_interval - eplb_step_interval // 4) 0, eplb_step_interval - eplb_step_interval // 4)
...@@ -409,12 +409,14 @@ class EplbState: ...@@ -409,12 +409,14 @@ class EplbState:
self.expert_rearrangement_step = 0 self.expert_rearrangement_step = 0
self.rearrange(model) self.rearrange(model)
def rearrange(self, def rearrange(
model: MixtureOfExperts, self,
is_profile: bool = False, model: MixtureOfExperts,
execute_shuffle: bool = True, is_profile: bool = False,
global_expert_load: Optional[torch.Tensor] = None, execute_shuffle: bool = True,
rank_mapping: Optional[dict[int, int]] = None) -> None: global_expert_load: Optional[torch.Tensor] = None,
rank_mapping: Optional[dict[int,
int]] = None) -> Optional[torch.Tensor]:
""" """
Rearrange the experts according to the current load. Rearrange the experts according to the current load.
""" """
...@@ -548,6 +550,7 @@ class EplbState: ...@@ -548,6 +550,7 @@ class EplbState:
" (profile) " if is_profile else " ", " (profile) " if is_profile else " ",
time_end - time_start, time_end - time_start,
) )
return None
@staticmethod @staticmethod
def recv_state() -> tuple[torch.Tensor, torch.Tensor]: def recv_state() -> tuple[torch.Tensor, torch.Tensor]:
...@@ -613,4 +616,4 @@ def _node_count_with_rank_mapping( ...@@ -613,4 +616,4 @@ def _node_count_with_rank_mapping(
if is_same_node and node_assignment[other_rank] == 0: if is_same_node and node_assignment[other_rank] == 0:
node_assignment[other_rank] = next_node_id node_assignment[other_rank] = next_node_id
return next_node_id return next_node_id
\ No newline at end of file
...@@ -40,16 +40,21 @@ class KVCacheEvent( ...@@ -40,16 +40,21 @@ class KVCacheEvent(
"""Base class for all KV cache-related events""" """Base class for all KV cache-related events"""
MEDIUM_GPU = "GPU"
class BlockStored(KVCacheEvent): class BlockStored(KVCacheEvent):
block_hashes: list[int] block_hashes: list[int]
parent_block_hash: Optional[int] parent_block_hash: Optional[int]
token_ids: list[int] token_ids: list[int]
block_size: int block_size: int
lora_id: Optional[int] lora_id: Optional[int]
medium: Optional[str]
class BlockRemoved(KVCacheEvent): class BlockRemoved(KVCacheEvent):
block_hashes: list[int] block_hashes: list[int]
medium: Optional[str]
class AllBlocksCleared(KVCacheEvent): class AllBlocksCleared(KVCacheEvent):
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
# Distributed KV cache transfer # Distributed KV cache transfer
This folder implements distributed KV cache transfer across vLLM instances. This folder implements distributed KV cache transfer across vLLM instances.
Currently the main usecase is for disaggregated prefilling. Currently the main use case is for disaggregated prefilling.
## Abstractions ## Abstractions
...@@ -14,7 +14,7 @@ The KV cache transfer contains three layer of abstractions: ...@@ -14,7 +14,7 @@ The KV cache transfer contains three layer of abstractions:
Why we need KV lookup buffer: FIFO pipe itself is not enough as prefill vLLM worker may process requests in a different order compared to decode vLLM worker. Say the QPS is really high, prefill worker may handle requests in order A -> B -> C, but the decode worker may process request C first. This is not the case that can be naturally handled by FIFO pipe, so we provide KV lookup buffer to help translate a FIFO pipe to a lookup buffer. Why we need KV lookup buffer: FIFO pipe itself is not enough as prefill vLLM worker may process requests in a different order compared to decode vLLM worker. Say the QPS is really high, prefill worker may handle requests in order A -> B -> C, but the decode worker may process request C first. This is not the case that can be naturally handled by FIFO pipe, so we provide KV lookup buffer to help translate a FIFO pipe to a lookup buffer.
NOTE: KV pipe layer is bypassible: you can skip this layer if your distributed NOTE: KV pipe layer is bypassable: you can skip this layer if your distributed
communication service already supports key-value-based lookup (like redis or communication service already supports key-value-based lookup (like redis or
RDMA database). RDMA database).
......
...@@ -19,6 +19,8 @@ The class provides the following primitives: ...@@ -19,6 +19,8 @@ The class provides the following primitives:
Returns whether KV cache should be freed now or will be Returns whether KV cache should be freed now or will be
freed asynchronously and optionally returns KV transfer freed asynchronously and optionally returns KV transfer
params. params.
take_events() - returns new KV events that were collected
by the connector since the last call.
Worker-side: runs in each worker, loads/saves KV cache to/from Worker-side: runs in each worker, loads/saves KV cache to/from
the Connector based on the metadata. the Connector based on the metadata.
...@@ -34,6 +36,7 @@ The class provides the following primitives: ...@@ -34,6 +36,7 @@ The class provides the following primitives:
import enum import enum
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Iterable
from typing import TYPE_CHECKING, Any, Callable, Literal, Optional from typing import TYPE_CHECKING, Any, Callable, Literal, Optional
import torch import torch
...@@ -45,6 +48,7 @@ from vllm.v1.outputs import KVConnectorOutput ...@@ -45,6 +48,7 @@ from vllm.v1.outputs import KVConnectorOutput
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.distributed.kv_events import KVCacheEvent
from vllm.forward_context import ForwardContext from vllm.forward_context import ForwardContext
from vllm.v1.core.kv_cache_manager import KVCacheBlocks from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.request import Request from vllm.v1.request import Request
...@@ -131,8 +135,8 @@ class KVConnectorBase_V1(ABC): ...@@ -131,8 +135,8 @@ class KVConnectorBase_V1(ABC):
Initialize with the KV caches. Useful for pre-registering the Initialize with the KV caches. Useful for pre-registering the
KV Caches in the KVConnector (e.g. for NIXL). KV Caches in the KVConnector (e.g. for NIXL).
Args: kv_caches: Args:
dictionary of layer names, kv cache kv_caches: dictionary of layer names, kv cache
""" """
return return
...@@ -313,6 +317,15 @@ class KVConnectorBase_V1(ABC): ...@@ -313,6 +317,15 @@ class KVConnectorBase_V1(ABC):
""" """
return False, None return False, None
def take_events(self) -> Iterable["KVCacheEvent"]:
    """
    Take the KV cache events from the connector.

    Returns:
        New KV cache events since the last call. This implementation
        has nothing to report and returns an empty tuple.
    """
    return ()
@classmethod @classmethod
def get_required_kvcache_layout( def get_required_kvcache_layout(
cls, vllm_config: "VllmConfig") -> Optional[str]: cls, vllm_config: "VllmConfig") -> Optional[str]:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment