"vllm/vscode:/vscode.git/clone" did not exist on "6a6108511f251c2b8278a84e4266504c55e1f037"
Commit a810671a authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.14.0rc0' into v0.14.0rc0-ori

parents 86b5aefe 6a09612b
...@@ -170,8 +170,7 @@ class PiecewiseBackend: ...@@ -170,8 +170,7 @@ class PiecewiseBackend:
range_entry = self._find_range_for_shape(runtime_shape) range_entry = self._find_range_for_shape(runtime_shape)
assert range_entry is not None, ( assert range_entry is not None, (
f"Shape out of considered range: {runtime_shape} " f"Shape: {runtime_shape} out of considered ranges: {self.compile_ranges}"
"[1, max_num_batched_tokens]"
) )
self._maybe_compile_for_range_entry(range_entry, args) self._maybe_compile_for_range_entry(range_entry, args)
......
...@@ -437,14 +437,14 @@ class CompilationConfig: ...@@ -437,14 +437,14 @@ class CompilationConfig:
compile_ranges_split_points: list[int] | None = None compile_ranges_split_points: list[int] | None = None
"""Split points that represent compile ranges for inductor. """Split points that represent compile ranges for inductor.
The compile ranges are The compile ranges are
[1, split_points[0]], [1, split_points[0]],
[split_points[0] + 1, split_points[1]], ..., [split_points[0] + 1, split_points[1]], ...,
[split_points[-1] + 1, max_num_batched_tokens]. [split_points[-1] + 1, max_num_batched_tokens].
Compile sizes are also used as single-element ranges, Compile sizes are also used as single-element ranges,
the range is represented as [compile_sizes[i], compile_sizes[i]]. the range is represented as [compile_sizes[i], compile_sizes[i]].
If a range overlaps with the compile size, graph for compile size If a range overlaps with the compile size, graph for compile size
will be prioritized, i.e. if we have a range [1, 8] and a compile size 4, will be prioritized, i.e. if we have a range [1, 8] and a compile size 4,
graph for compile size 4 will be compiled and used instead of the graph graph for compile size 4 will be compiled and used instead of the graph
for range [1, 8]. for range [1, 8].
...@@ -899,7 +899,7 @@ class CompilationConfig: ...@@ -899,7 +899,7 @@ class CompilationConfig:
self.compute_bs_to_padded_graph_size() self.compute_bs_to_padded_graph_size()
def set_splitting_ops_for_v1( def set_splitting_ops_for_v1(
self, all2all_backend: str | None = None, data_parallel_size: int | None = None self, all2all_backend: str, data_parallel_size: int = 1
): ):
# To be compatible with OOT hardware plugin platforms (for example vllm-ascend) # To be compatible with OOT hardware plugin platforms (for example vllm-ascend)
# which currently only supports sequence parallelism in eager mode. # which currently only supports sequence parallelism in eager mode.
...@@ -934,7 +934,7 @@ class CompilationConfig: ...@@ -934,7 +934,7 @@ class CompilationConfig:
or self.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE or self.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE
): ):
logger.warning_once( logger.warning_once(
"Using piecewise compilation with empty splitting_ops" "Using piecewise cudagraph with empty splitting_ops"
) )
if self.cudagraph_mode == CUDAGraphMode.PIECEWISE: if self.cudagraph_mode == CUDAGraphMode.PIECEWISE:
logger.warning_once( logger.warning_once(
...@@ -956,11 +956,9 @@ class CompilationConfig: ...@@ -956,11 +956,9 @@ class CompilationConfig:
self.splitting_ops = [] self.splitting_ops = []
# Disable CUDA graphs for DeepEP high-throughput since its not CG compatible # Disable CUDA graphs for DeepEP high-throughput since its not CG compatible
backend = all2all_backend or envs.VLLM_ALL2ALL_BACKEND
dp_size = data_parallel_size if data_parallel_size is not None else 1
if ( if (
backend == "deepep_high_throughput" all2all_backend == "deepep_high_throughput"
and dp_size > 1 and data_parallel_size > 1
and self.cudagraph_mode != CUDAGraphMode.NONE and self.cudagraph_mode != CUDAGraphMode.NONE
): ):
# TODO: Piecewise Cuda graph might be enabled # TODO: Piecewise Cuda graph might be enabled
......
...@@ -64,6 +64,9 @@ class ObservabilityConfig: ...@@ -64,6 +64,9 @@ class ObservabilityConfig:
module in the model and attach information such as input/output shapes to module in the model and attach information such as input/output shapes to
nvtx range markers. Note that this doesn't work with CUDA graphs enabled.""" nvtx range markers. Note that this doesn't work with CUDA graphs enabled."""
enable_mfu_metrics: bool = False
"""Enable Model FLOPs Utilization (MFU) metrics."""
@cached_property @cached_property
def collect_model_forward_time(self) -> bool: def collect_model_forward_time(self) -> bool:
"""Whether to collect model forward time for the request.""" """Whether to collect model forward time for the request."""
......
...@@ -36,6 +36,14 @@ ExpertPlacementStrategy = Literal["linear", "round_robin"] ...@@ -36,6 +36,14 @@ ExpertPlacementStrategy = Literal["linear", "round_robin"]
DistributedExecutorBackend = Literal["ray", "mp", "uni", "external_launcher"] DistributedExecutorBackend = Literal["ray", "mp", "uni", "external_launcher"]
DataParallelBackend = Literal["ray", "mp"] DataParallelBackend = Literal["ray", "mp"]
EPLBPolicyOption = Literal["default"] EPLBPolicyOption = Literal["default"]
All2AllBackend = Literal[
"naive",
"pplx",
"deepep_high_throughput",
"deepep_low_latency",
"allgather_reducescatter",
"flashinfer_all2allv",
]
@config @config
...@@ -126,24 +134,14 @@ class ParallelConfig: ...@@ -126,24 +134,14 @@ class ParallelConfig:
with 4 experts and 2 ranks, rank 0 will have experts [0, 2] and rank 1 with 4 experts and 2 ranks, rank 0 will have experts [0, 2] and rank 1
will have experts [1, 3]. This strategy can help improve load balancing will have experts [1, 3]. This strategy can help improve load balancing
for grouped expert models with no redundant experts.""" for grouped expert models with no redundant experts."""
all2all_backend: ( all2all_backend: All2AllBackend = "allgather_reducescatter"
Literal[ """All2All backend for MoE expert parallel communication. Available options:
"naive",
"pplx", - "naive": Naive all2all implementation using broadcasts\n
"deepep_high_throughput", - "allgather_reducescatter": All2all based on allgather and reducescatter\n
"deepep_low_latency", - "pplx": Use pplx kernels\n
"allgather_reducescatter", - "deepep_high_throughput": Use deepep high-throughput kernels\n
"flashinfer_all2allv", - "deepep_low_latency": Use deepep low-latency kernels\n
]
| None
) = None
"""All2All backend for MoE expert parallel communication. If not set, uses
the value from VLLM_ALL2ALL_BACKEND environment variable. Available options:
- "naive": Naive all2all implementation using broadcasts
- "allgather_reducescatter": All2all based on allgather and reducescatter
- "pplx": Use pplx kernels
- "deepep_high_throughput": Use deepep high-throughput kernels
- "deepep_low_latency": Use deepep low-latency kernels
- "flashinfer_all2allv": Use flashinfer alltoallv kernels for mnnvl""" - "flashinfer_all2allv": Use flashinfer alltoallv kernels for mnnvl"""
max_parallel_loading_workers: int | None = None max_parallel_loading_workers: int | None = None
...@@ -156,6 +154,8 @@ class ParallelConfig: ...@@ -156,6 +154,8 @@ class ParallelConfig:
enable_dbo: bool = False enable_dbo: bool = False
"""Enable dual batch overlap for the model executor.""" """Enable dual batch overlap for the model executor."""
ubatch_size: int = 0
"""Number of micro-batches (ubatches). A value greater than 1 enables micro-batching."""
dbo_decode_token_threshold: int = 32 dbo_decode_token_threshold: int = 32
"""The threshold for dual batch overlap for batches only containing decodes. """The threshold for dual batch overlap for batches only containing decodes.
...@@ -325,6 +325,14 @@ class ParallelConfig: ...@@ -325,6 +325,14 @@ class ParallelConfig:
including data parallelism.""" including data parallelism."""
return self.world_size * self.data_parallel_size return self.world_size * self.data_parallel_size
@property
def use_ubatching(self) -> bool:
    """Whether micro-batching is in effect.

    True when dual batch overlap (``enable_dbo``) is enabled, or when
    ``ubatch_size`` is greater than 1.
    """
    return self.enable_dbo or self.ubatch_size > 1
@property
def num_ubatches(self) -> int:
    """Number of micro-batches to use.

    Returns 2 when ``enable_dbo`` is set; otherwise returns ``ubatch_size``
    unchanged.
    """
    return 2 if self.enable_dbo else self.ubatch_size
def get_next_dp_init_port(self) -> int: def get_next_dp_init_port(self) -> int:
""" """
We might need to initialize process groups in multiple We might need to initialize process groups in multiple
...@@ -485,20 +493,17 @@ class ParallelConfig: ...@@ -485,20 +493,17 @@ class ParallelConfig:
from vllm.config.utils import get_hash_factors, hash_factors from vllm.config.utils import get_hash_factors, hash_factors
factors = get_hash_factors(self, ignored_factors) factors = get_hash_factors(self, ignored_factors)
# Explicitly include backend affecting env factor as before
factors["VLLM_ALL2ALL_BACKEND"] = str(envs.VLLM_ALL2ALL_BACKEND)
return hash_factors(factors) return hash_factors(factors)
def __post_init__(self) -> None: def __post_init__(self) -> None:
# Set all2all_backend from env var if not specified, with deprecation warning # Set all2all_backend from env var if not specified, with deprecation warning
if self.all2all_backend is None: if envs.is_set("VLLM_ALL2ALL_BACKEND"):
logger.warning_once(
"VLLM_ALL2ALL_BACKEND environment variable is deprecated and "
"will be removed in v0.15.0. Please use the "
"--all2all-backend command-line argument instead."
)
self.all2all_backend = envs.VLLM_ALL2ALL_BACKEND self.all2all_backend = envs.VLLM_ALL2ALL_BACKEND
if envs.is_set("VLLM_ALL2ALL_BACKEND"):
logger.warning_once(
"VLLM_ALL2ALL_BACKEND environment variable is deprecated and "
"will be removed in a future release. Please use the "
"--all2all-backend command-line argument instead."
)
# Continue with the rest of the initialization # Continue with the rest of the initialization
self.world_size = ( self.world_size = (
......
...@@ -870,9 +870,12 @@ class VllmConfig: ...@@ -870,9 +870,12 @@ class VllmConfig:
f"cudagraph_mode={self.compilation_config.cudagraph_mode}" f"cudagraph_mode={self.compilation_config.cudagraph_mode}"
) )
if self.parallel_config.enable_dbo: if self.parallel_config.use_ubatching:
a2a_backend = self.parallel_config.all2all_backend a2a_backend = self.parallel_config.all2all_backend
assert a2a_backend in ["deepep_low_latency", "deepep_high_throughput"], ( assert a2a_backend in [
"deepep_low_latency",
"deepep_high_throughput",
], (
"Microbatching currently only supports the deepep_low_latency and " "Microbatching currently only supports the deepep_low_latency and "
f"deepep_high_throughput all2all backend. {a2a_backend} is not " f"deepep_high_throughput all2all backend. {a2a_backend} is not "
"supported. To fix use --all2all-backend=deepep_low_latency or " "supported. To fix use --all2all-backend=deepep_low_latency or "
......
...@@ -64,7 +64,12 @@ class NaiveAll2AllManager(All2AllManagerBase): ...@@ -64,7 +64,12 @@ class NaiveAll2AllManager(All2AllManagerBase):
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
is_sequence_parallel: bool = False, is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
if extra_tensors is not None:
raise NotImplementedError(
"extra_tensors is not supported for NaiveAll2AllManager"
)
sp_size = self.tp_group.world_size if is_sequence_parallel else 1 sp_size = self.tp_group.world_size if is_sequence_parallel else 1
dp_metadata = get_forward_context().dp_metadata dp_metadata = get_forward_context().dp_metadata
assert dp_metadata is not None assert dp_metadata is not None
...@@ -76,6 +81,7 @@ class NaiveAll2AllManager(All2AllManagerBase): ...@@ -76,6 +81,7 @@ class NaiveAll2AllManager(All2AllManagerBase):
router_logits = self.naive_multicast( router_logits = self.naive_multicast(
router_logits, cu_tokens_across_sp_cpu, is_sequence_parallel router_logits, cu_tokens_across_sp_cpu, is_sequence_parallel
) )
return hidden_states, router_logits return hidden_states, router_logits
def combine( def combine(
...@@ -113,7 +119,11 @@ class AgRsAll2AllManager(All2AllManagerBase): ...@@ -113,7 +119,11 @@ class AgRsAll2AllManager(All2AllManagerBase):
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
is_sequence_parallel: bool = False, is_sequence_parallel: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]: extra_tensors: list[torch.Tensor] | None = None,
) -> (
tuple[torch.Tensor, torch.Tensor]
| tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
""" """
Gather hidden_states and router_logits from all dp ranks. Gather hidden_states and router_logits from all dp ranks.
""" """
...@@ -121,15 +131,22 @@ class AgRsAll2AllManager(All2AllManagerBase): ...@@ -121,15 +131,22 @@ class AgRsAll2AllManager(All2AllManagerBase):
assert dp_metadata is not None assert dp_metadata is not None
sizes = dp_metadata.get_chunk_sizes_across_dp_rank() sizes = dp_metadata.get_chunk_sizes_across_dp_rank()
assert sizes is not None assert sizes is not None
dist_group = get_ep_group() if is_sequence_parallel else get_dp_group() dist_group = get_ep_group() if is_sequence_parallel else get_dp_group()
assert sizes[dist_group.rank_in_group] == hidden_states.shape[0] assert sizes[dist_group.rank_in_group] == hidden_states.shape[0]
hidden_states, router_logits = dist_group.all_gatherv(
[hidden_states, router_logits], tensors_to_gather = [hidden_states, router_logits]
if extra_tensors is not None:
tensors_to_gather.extend(extra_tensors)
gathered_tensors = dist_group.all_gatherv(
tensors_to_gather,
dim=0, dim=0,
sizes=sizes, sizes=sizes,
) )
return hidden_states, router_logits
if extra_tensors is not None:
return (gathered_tensors[0], gathered_tensors[1], gathered_tensors[2:])
return gathered_tensors[0], gathered_tensors[1]
def combine( def combine(
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
...@@ -204,6 +221,7 @@ class PPLXAll2AllManager(All2AllManagerBase): ...@@ -204,6 +221,7 @@ class PPLXAll2AllManager(All2AllManagerBase):
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
is_sequence_parallel: bool = False, is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
raise NotImplementedError raise NotImplementedError
...@@ -251,6 +269,7 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase): ...@@ -251,6 +269,7 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase):
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
is_sequence_parallel: bool = False, is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
raise NotImplementedError raise NotImplementedError
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import threading import threading
from typing import Any
from weakref import WeakValueDictionary from weakref import WeakValueDictionary
import torch import torch
...@@ -68,7 +69,11 @@ class All2AllManagerBase: ...@@ -68,7 +69,11 @@ class All2AllManagerBase:
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
is_sequence_parallel: bool = False, is_sequence_parallel: bool = False,
): extra_tensors: list[torch.Tensor] | None = None,
) -> Any:
# Subclasses should either:
# - implement handling for extra_tensors, or
# - raise a clear error if extra_tensors is not supported.
raise NotImplementedError raise NotImplementedError
def set_num_sms(self, num_sms: int): def set_num_sms(self, num_sms: int):
......
...@@ -318,17 +318,23 @@ class CudaCommunicator(DeviceCommunicatorBase): ...@@ -318,17 +318,23 @@ class CudaCommunicator(DeviceCommunicatorBase):
return output_list return output_list
def dispatch( def dispatch( # type: ignore[override]
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
is_sequence_parallel: bool = False, is_sequence_parallel: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]: extra_tensors: list[torch.Tensor] | None = None,
) -> (
tuple[torch.Tensor, torch.Tensor]
| tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
assert self.all2all_manager is not None assert self.all2all_manager is not None
hidden_states, router_logits = self.all2all_manager.dispatch( return self.all2all_manager.dispatch(
hidden_states, router_logits, is_sequence_parallel hidden_states,
router_logits,
is_sequence_parallel,
extra_tensors, # type: ignore[call-arg]
) )
return hidden_states, router_logits
def combine( def combine(
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
......
...@@ -73,6 +73,7 @@ class ECExampleConnector(ECConnectorBase): ...@@ -73,6 +73,7 @@ class ECExampleConnector(ECConnectorBase):
data hashes (`mm_hash`) to encoder cache tensors. data hashes (`mm_hash`) to encoder cache tensors.
kwargs (dict): Additional keyword arguments for the connector. kwargs (dict): Additional keyword arguments for the connector.
""" """
from vllm.platforms import current_platform
# Get the metadata # Get the metadata
metadata: ECConnectorMetadata = self._get_connector_metadata() metadata: ECConnectorMetadata = self._get_connector_metadata()
...@@ -91,7 +92,9 @@ class ECExampleConnector(ECConnectorBase): ...@@ -91,7 +92,9 @@ class ECExampleConnector(ECConnectorBase):
if mm_data.mm_hash in encoder_cache: if mm_data.mm_hash in encoder_cache:
continue continue
filename = self._generate_filename_debug(mm_data.mm_hash) filename = self._generate_filename_debug(mm_data.mm_hash)
ec_cache = safetensors.torch.load_file(filename)["ec_cache"].cuda() ec_cache = safetensors.torch.load_file(
filename, device=current_platform.device_type
)["ec_cache"]
encoder_cache[mm_data.mm_hash] = ec_cache encoder_cache[mm_data.mm_hash] = ec_cache
logger.debug("Success load encoder cache for hash %s", mm_data.mm_hash) logger.debug("Success load encoder cache for hash %s", mm_data.mm_hash)
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
KV cache helper for store. KV cache helper for store.
""" """
from collections.abc import Iterator
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, Literal from typing import TYPE_CHECKING, Literal
...@@ -21,6 +22,8 @@ if TYPE_CHECKING: ...@@ -21,6 +22,8 @@ if TYPE_CHECKING:
logger = init_logger(__name__) logger = init_logger(__name__)
EngineId = str
def get_kv_connector_cache_layout(): def get_kv_connector_cache_layout():
# NOTE (NickLucche) When running disaggregated PD with NIXL, HND layout is # NOTE (NickLucche) When running disaggregated PD with NIXL, HND layout is
...@@ -201,6 +204,26 @@ def copy_kv_blocks( ...@@ -201,6 +204,26 @@ def copy_kv_blocks(
copy_fn(src_tensor, dst_tensor, src_indices, dst_indices) copy_fn(src_tensor, dst_tensor, src_indices, dst_indices)
def yield_req_data(
    scheduler_output,
) -> Iterator[tuple[str, tuple[list[int], ...], bool]]:
    """Iterate over the per-request block data of a scheduler output.

    Entries from ``scheduled_new_reqs`` are produced first and are always
    flagged as not preempted. Entries from ``scheduled_cached_reqs`` follow,
    flagged as preempted when their id is in ``resumed_req_ids``.

    Args:
        scheduler_output: Object exposing ``scheduled_new_reqs`` and
            ``scheduled_cached_reqs``.

    Yields:
        (req_id, new_block_id_groups, preempted)
    """
    # new requests
    for req_data in scheduler_output.scheduled_new_reqs:
        yield req_data.req_id, req_data.block_ids, False

    # cached requests
    cached_reqs = scheduler_output.scheduled_cached_reqs
    resumed_req_ids = cached_reqs.resumed_req_ids
    for req_id, new_block_ids in zip(cached_reqs.req_ids, cached_reqs.new_block_ids):
        yield req_id, new_block_ids, (req_id in resumed_req_ids)
@dataclass @dataclass
class TpKVTopology: class TpKVTopology:
""" """
...@@ -209,12 +232,12 @@ class TpKVTopology: ...@@ -209,12 +232,12 @@ class TpKVTopology:
""" """
tp_rank: int tp_rank: int
remote_tp_size: dict[str, int] remote_tp_size: dict[EngineId, int]
is_mla: bool is_mla: bool
total_num_kv_heads: int total_num_kv_heads: int
attn_backend: type[AttentionBackend] attn_backend: type[AttentionBackend]
engine_id: str engine_id: EngineId
remote_block_size: dict[str, int] remote_block_size: dict[EngineId, int]
def __post_init__(self): def __post_init__(self):
# Figure out whether the first dimension of the cache is K/V # Figure out whether the first dimension of the cache is K/V
...@@ -256,18 +279,28 @@ class TpKVTopology: ...@@ -256,18 +279,28 @@ class TpKVTopology:
Calculate the tensor parallel ratio between local and remote TP. Calculate the tensor parallel ratio between local and remote TP.
We can think of it as the number of local TP workers-per-remote TP We can think of it as the number of local TP workers-per-remote TP
workers. Local workers will read from the same remote TP worker in workers. Local workers will read from the same remote TP worker in
groups of size `tp_ratio`. groups of size `tp_ratio`. If remote tp_size > local tp_size, the
ratio is flipped (remote_size/local_size) and the returned value is
negative.
""" """
assert self.tp_size % remote_tp_size == 0, ( if self.tp_size >= remote_tp_size:
f"Local tensor parallel size {self.tp_size} is not divisible " assert self.tp_size % remote_tp_size == 0, (
f"by remote tensor parallel size {remote_tp_size}." f"Local tensor parallel size {self.tp_size} is not divisible "
f"by remote tensor parallel size {remote_tp_size}."
)
return self.tp_size // remote_tp_size
assert remote_tp_size % self.tp_size == 0, (
f"Remote tensor parallel size {remote_tp_size} is not divisible "
f"by local tensor parallel size {self.tp_size}."
) )
return self.tp_size // remote_tp_size # P TP > D TP case, return the ratio as negative
return -remote_tp_size // self.tp_size
def block_size_ratio( def block_size_ratio(
self, self,
remote_block_size: int, remote_block_size: int,
) -> float: ) -> int:
""" """
Calculate the block size ratio between local and remote TP. Calculate the block size ratio between local and remote TP.
""" """
...@@ -279,19 +312,19 @@ class TpKVTopology: ...@@ -279,19 +312,19 @@ class TpKVTopology:
def tp_ratio_from_engine_id( def tp_ratio_from_engine_id(
self, self,
remote_engine_id: str, remote_engine_id: EngineId,
) -> int: ) -> int:
remote_tp_size = self.remote_tp_size[remote_engine_id] remote_tp_size = self.remote_tp_size[remote_engine_id]
return self.tp_ratio(remote_tp_size) return self.tp_ratio(remote_tp_size)
def block_size_ratio_from_engine_id( def block_size_ratio_from_engine_id(
self, self,
remote_engine_id: str, remote_engine_id: EngineId,
) -> float: ) -> int:
remote_block_size = self.remote_block_size[remote_engine_id] remote_block_size = self.remote_block_size[remote_engine_id]
return self.block_size_ratio(remote_block_size) return self.block_size_ratio(remote_block_size)
def is_kv_replicated(self, engine_id: str) -> bool: def is_kv_replicated(self, engine_id: EngineId) -> bool:
""" """
Whether the KV cache is replicated across TP workers due to the Whether the KV cache is replicated across TP workers due to the
number of TP workers being greater than the number of KV heads. number of TP workers being greater than the number of KV heads.
...@@ -299,24 +332,30 @@ class TpKVTopology: ...@@ -299,24 +332,30 @@ class TpKVTopology:
tp_size = self.remote_tp_size[engine_id] tp_size = self.remote_tp_size[engine_id]
return tp_size // self.total_num_kv_heads >= 1 return tp_size // self.total_num_kv_heads >= 1
def replicates_kv_cache(self, remote_engine_id: str) -> bool: def replicates_kv_cache(self, remote_engine_id: EngineId) -> bool:
# MLA is always replicated as the hidden dim can't be split. # MLA is always replicated as the hidden dim can't be split.
return self.is_mla or self.is_kv_replicated(remote_engine_id) return self.is_mla or self.is_kv_replicated(remote_engine_id)
def get_target_remote_rank( def get_target_remote_ranks(
self, self,
remote_tp_size: int, remote_tp_size: int,
) -> int: ) -> list[int]:
""" """
Get the remote TP rank (on P) that the current local TP rank Get the remote TP rank (on P) that the current local TP rank
(on D) will read from. (on D) will read from. When remote tp_size > local tp_size, we
read from multiple remote ranks.
""" """
tp_ratio = self.tp_ratio(remote_tp_size) tp_ratio = self.tp_ratio(remote_tp_size)
return self.tp_rank // tp_ratio if tp_ratio > 0:
return [self.tp_rank // tp_ratio]
# P TP > D TP case, D reads from |tp_ratio| remote workers.
tp_ratio = -tp_ratio
return [self.tp_rank * tp_ratio + i for i in range(tp_ratio)]
def get_target_remote_rank_from_engine_id( def get_target_remote_ranks_from_engine_id(
self, self,
remote_engine_id: str, remote_engine_id: EngineId,
) -> int: ) -> list[int]:
remote_tp_size = self.remote_tp_size[remote_engine_id] remote_tp_size = self.remote_tp_size[remote_engine_id]
return self.get_target_remote_rank(remote_tp_size) return self.get_target_remote_ranks(remote_tp_size)
...@@ -147,6 +147,14 @@ class LMCacheMPSchedulerAdapter: ...@@ -147,6 +147,14 @@ class LMCacheMPSchedulerAdapter:
""" """
return self.blocks_in_chunk return self.blocks_in_chunk
def _cleanup_lookup_result(self, request_id: str) -> None:
"""
Clean up lookup future for a finished request to prevent memory leak.
Args:
request_id: The ID of the finished request.
"""
self.lookup_futures.pop(request_id, None)
# Helper functions # Helper functions
def _create_key(self, block_hash: bytes) -> IPCCacheEngineKey: def _create_key(self, block_hash: bytes) -> IPCCacheEngineKey:
"""Convert a block hash to an IPC cache engine key""" """Convert a block hash to an IPC cache engine key"""
...@@ -262,6 +270,7 @@ class LMCacheMPWorkerAdapter: ...@@ -262,6 +270,7 @@ class LMCacheMPWorkerAdapter:
): ):
keys = [] keys = []
block_ids = [] block_ids = []
for op in ops: for op in ops:
keys.extend(self._block_hashes_to_keys(op.block_hashes)) keys.extend(self._block_hashes_to_keys(op.block_hashes))
block_ids.extend(op.block_ids) block_ids.extend(op.block_ids)
......
...@@ -24,6 +24,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.lmcache_integration import ( ...@@ -24,6 +24,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.lmcache_integration import (
) )
from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.outputs import KVConnectorOutput from vllm.v1.outputs import KVConnectorOutput
from vllm.v1.request import RequestStatus
from vllm.v1.utils import ConstantList from vllm.v1.utils import ConstantList
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -211,7 +212,7 @@ class LMCacheMPRequestTracker: ...@@ -211,7 +212,7 @@ class LMCacheMPRequestTracker:
""" """
self.num_stored_blocks += num_new_blocks self.num_stored_blocks += num_new_blocks
def update_block_ids( def append_block_ids(
self, self,
new_block_ids: list[int], new_block_ids: list[int],
): ):
...@@ -455,10 +456,6 @@ class LMCacheMPConnector(KVConnectorBase_V1): ...@@ -455,10 +456,6 @@ class LMCacheMPConnector(KVConnectorBase_V1):
metadata = self._get_connector_metadata() metadata = self._get_connector_metadata()
assert isinstance(metadata, LMCacheMPConnectorMetadata) assert isinstance(metadata, LMCacheMPConnectorMetadata)
with torch.cuda.stream(torch.cuda.current_stream()):
event = torch.cuda.Event(interprocess=True)
event.record()
request_ids = [] request_ids = []
ops = [] ops = []
...@@ -468,10 +465,14 @@ class LMCacheMPConnector(KVConnectorBase_V1): ...@@ -468,10 +465,14 @@ class LMCacheMPConnector(KVConnectorBase_V1):
request_ids.append(meta.request_id) request_ids.append(meta.request_id)
ops.append(meta.op) ops.append(meta.op)
if len(request_ids) > 0: if len(request_ids) == 0:
self.worker_adapter.batched_submit_retrieve_requests( return
request_ids, ops, event
) with torch.cuda.stream(torch.cuda.current_stream()):
event = torch.cuda.Event(interprocess=True)
event.record()
self.worker_adapter.batched_submit_retrieve_requests(request_ids, ops, event)
def wait_for_layer_load(self, layer_name: str) -> None: def wait_for_layer_load(self, layer_name: str) -> None:
""" """
...@@ -518,10 +519,6 @@ class LMCacheMPConnector(KVConnectorBase_V1): ...@@ -518,10 +519,6 @@ class LMCacheMPConnector(KVConnectorBase_V1):
metadata = self._get_connector_metadata() metadata = self._get_connector_metadata()
assert isinstance(metadata, LMCacheMPConnectorMetadata) assert isinstance(metadata, LMCacheMPConnectorMetadata)
with torch.cuda.stream(torch.cuda.current_stream()):
event = torch.cuda.Event(interprocess=True)
event.record()
request_ids = [] request_ids = []
ops = [] ops = []
for meta in metadata.requests: for meta in metadata.requests:
...@@ -530,8 +527,14 @@ class LMCacheMPConnector(KVConnectorBase_V1): ...@@ -530,8 +527,14 @@ class LMCacheMPConnector(KVConnectorBase_V1):
request_ids.append(meta.request_id) request_ids.append(meta.request_id)
ops.append(meta.op) ops.append(meta.op)
if len(request_ids) > 0: if len(request_ids) == 0:
self.worker_adapter.batched_submit_store_requests(request_ids, ops, event) return
with torch.cuda.stream(torch.cuda.current_stream()):
event = torch.cuda.Event(interprocess=True)
event.record()
self.worker_adapter.batched_submit_store_requests(request_ids, ops, event)
def get_finished( def get_finished(
self, finished_req_ids: set[str] self, finished_req_ids: set[str]
...@@ -627,6 +630,9 @@ class LMCacheMPConnector(KVConnectorBase_V1): ...@@ -627,6 +630,9 @@ class LMCacheMPConnector(KVConnectorBase_V1):
into account. into account.
""" """
tracker = self._get_or_create_request_tracker(request) tracker = self._get_or_create_request_tracker(request)
# TODO: support loading KV for preempted requests in the future
if request.status == RequestStatus.PREEMPTED:
return 0, False
self.scheduler_adapter.maybe_submit_lookup_request( self.scheduler_adapter.maybe_submit_lookup_request(
request.request_id, convert_block_hashes_to_bytes(request.block_hashes) request.request_id, convert_block_hashes_to_bytes(request.block_hashes)
...@@ -683,7 +689,7 @@ class LMCacheMPConnector(KVConnectorBase_V1): ...@@ -683,7 +689,7 @@ class LMCacheMPConnector(KVConnectorBase_V1):
# No matter we need to retrieve or not, we need to update # No matter we need to retrieve or not, we need to update
# the block ids into the tracker # the block ids into the tracker
tracker.update_block_ids(block_ids) tracker.append_block_ids(block_ids)
# Update the state of the tracker # Update the state of the tracker
condition = tracker.needs_retrieve() condition = tracker.needs_retrieve()
...@@ -695,6 +701,8 @@ class LMCacheMPConnector(KVConnectorBase_V1): ...@@ -695,6 +701,8 @@ class LMCacheMPConnector(KVConnectorBase_V1):
if condition if condition
else LMCacheMPRequestState.READY else LMCacheMPRequestState.READY
) )
# Clean up lookup future in scheduler adapter
self.scheduler_adapter._cleanup_lookup_result(request.request_id)
def build_connector_meta( def build_connector_meta(
self, scheduler_output: SchedulerOutput self, scheduler_output: SchedulerOutput
...@@ -748,6 +756,8 @@ class LMCacheMPConnector(KVConnectorBase_V1): ...@@ -748,6 +756,8 @@ class LMCacheMPConnector(KVConnectorBase_V1):
Optional KVTransferParams to be included in the request outputs Optional KVTransferParams to be included in the request outputs
returned by the engine. returned by the engine.
""" """
# Clean up request tracker to prevent memory leak
self._cleanup_request_tracker(request.request_id)
return True, None return True, None
def take_events(self) -> Iterable["KVCacheEvent"]: def take_events(self) -> Iterable["KVCacheEvent"]:
...@@ -866,7 +876,8 @@ class LMCacheMPConnector(KVConnectorBase_V1): ...@@ -866,7 +876,8 @@ class LMCacheMPConnector(KVConnectorBase_V1):
# Update block ids # Update block ids
new_block_ids = reformat_block_ids(cached_reqs.new_block_ids[idx]) new_block_ids = reformat_block_ids(cached_reqs.new_block_ids[idx])
request_tracker.update_block_ids(new_block_ids) if request_id not in cached_reqs.resumed_req_ids:
request_tracker.append_block_ids(new_block_ids)
# Update new scheduled tokens # Update new scheduled tokens
num_new_tokens = cached_reqs.num_computed_tokens[idx] num_new_tokens = cached_reqs.num_computed_tokens[idx]
...@@ -889,7 +900,34 @@ class LMCacheMPConnector(KVConnectorBase_V1): ...@@ -889,7 +900,34 @@ class LMCacheMPConnector(KVConnectorBase_V1):
self, request: "Request" self, request: "Request"
) -> LMCacheMPRequestTracker: ) -> LMCacheMPRequestTracker:
request_id = request.request_id request_id = request.request_id
# Remove the old tracker that was created before the preemption
if (
request.status == RequestStatus.PREEMPTED
and request_id in self.request_trackers
):
tracker = self.request_trackers[request_id]
# NOTE: since this function may be called multiple times
# for a single request (because get_num_new_matched_tokens
# may be called multiple times) for the same request, we
# will only do the remove if the tracker is not in the "fresh"
# state, i.e., PREFETCHING
if tracker.state != LMCacheMPRequestState.PREFETCHING:
self.request_trackers.pop(request_id)
if request_id not in self.request_trackers: if request_id not in self.request_trackers:
new_tracker = LMCacheMPRequestTracker(request) new_tracker = LMCacheMPRequestTracker(request)
self.request_trackers[request_id] = new_tracker self.request_trackers[request_id] = new_tracker
return self.request_trackers[request_id] return self.request_trackers[request_id]
def _cleanup_request_tracker(self, request_id: str) -> None:
    """
    Drop the request tracker stored for ``request_id``, if there is one.

    This should be called when a request is finished so that
    ``self.request_trackers`` does not keep entries for completed
    requests (memory leak). Only the tracker entry is removed; this
    method does not touch any lookup future.

    Args:
        request_id: ID of the request whose tracker should be removed.
            IDs without a tracker are ignored silently.
    """
    # Compare against None rather than relying on truthiness, so that a
    # tracker object that happens to evaluate as falsy is still reported.
    if self.request_trackers.pop(request_id, None) is not None:
        logger.debug(
            "[KVConnector] Cleaned up request_tracker for request %s",
            request_id,
        )
...@@ -23,7 +23,11 @@ from vllm import envs ...@@ -23,7 +23,11 @@ from vllm import envs
from vllm.attention.backends.abstract import AttentionMetadata from vllm.attention.backends.abstract import AttentionMetadata
from vllm.attention.selector import get_attn_backend from vllm.attention.selector import get_attn_backend
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.utils import TpKVTopology from vllm.distributed.kv_transfer.kv_connector.utils import (
EngineId,
TpKVTopology,
yield_req_data,
)
from vllm.distributed.kv_transfer.kv_connector.v1.base import ( from vllm.distributed.kv_transfer.kv_connector.v1.base import (
CopyBlocksOp, CopyBlocksOp,
KVConnectorBase_V1, KVConnectorBase_V1,
...@@ -56,7 +60,6 @@ if TYPE_CHECKING: ...@@ -56,7 +60,6 @@ if TYPE_CHECKING:
from vllm.v1.request import Request from vllm.v1.request import Request
TransferHandle = int TransferHandle = int
EngineId = str
ReqId = str ReqId = str
# #
...@@ -482,7 +485,7 @@ class NixlConnectorScheduler: ...@@ -482,7 +485,7 @@ class NixlConnectorScheduler:
# New requests are added by update_state_after_alloc in # New requests are added by update_state_after_alloc in
# the scheduler. Used to make metadata passed to Worker. # the scheduler. Used to make metadata passed to Worker.
self._reqs_need_recv: dict[ReqId, tuple[Request, list[int]]] = {} self._reqs_need_recv: dict[ReqId, tuple[Request, list[int]]] = {}
self._reqs_need_save: dict[ReqId, tuple[Request, list[int]]] = {} self._reqs_need_save: dict[ReqId, Request] = {}
# Reqs to send and their expiration time # Reqs to send and their expiration time
self._reqs_need_send: dict[ReqId, float] = {} self._reqs_need_send: dict[ReqId, float] = {}
self._reqs_in_batch: set[ReqId] = set() self._reqs_in_batch: set[ReqId] = set()
...@@ -628,16 +631,7 @@ class NixlConnectorScheduler: ...@@ -628,16 +631,7 @@ class NixlConnectorScheduler:
if self.use_host_buffer and params.get("do_remote_decode"): if self.use_host_buffer and params.get("do_remote_decode"):
# NOTE: when accelerator is not directly supported by Nixl, # NOTE: when accelerator is not directly supported by Nixl,
# prefilled blocks need to be saved to host memory before transfer. # prefilled blocks need to be saved to host memory before transfer.
self._reqs_need_save[request.request_id] = request
# save all blocks
block_ids = blocks.get_block_ids()[0]
# TODO: skip the blocks that are already in the host xfer buffer.
# Currently, the host xfer buffer block is 1-to-1 mapped to device
# kv blocks, so host blocks won't be flushed as long as its device
# block is not overwritten; and it will be safe to skip saving them
# to host xfer buffer.
if block_ids:
self._reqs_need_save[request.request_id] = (request, block_ids)
elif params.get("do_remote_prefill"): elif params.get("do_remote_prefill"):
if params.get("remote_block_ids"): if params.get("remote_block_ids"):
if all( if all(
...@@ -689,13 +683,32 @@ class NixlConnectorScheduler: ...@@ -689,13 +683,32 @@ class NixlConnectorScheduler:
kv_transfer_params=req.kv_transfer_params, kv_transfer_params=req.kv_transfer_params,
) )
for req_id, (req, block_ids) in self._reqs_need_save.items(): # NOTE: For the prefill side, there might be a chance that an early added
# request is a chunked prefill, so we need to check if new blocks are added
for req_id, new_block_id_groups, _ in yield_req_data(scheduler_output):
req_to_save = self._reqs_need_save.get(req_id)
if req_to_save is None or new_block_id_groups is None:
continue
req = req_to_save
assert req.kv_transfer_params is not None assert req.kv_transfer_params is not None
meta.add_new_req_to_save( meta.add_new_req_to_save(
request_id=req_id, request_id=req_id,
local_block_ids=block_ids, local_block_ids=new_block_id_groups[0],
kv_transfer_params=req.kv_transfer_params, kv_transfer_params=req.kv_transfer_params,
) )
assert scheduler_output.num_scheduled_tokens is not None
num_scheduled_tokens = scheduler_output.num_scheduled_tokens[req_id]
is_partial = (
req.num_computed_tokens + num_scheduled_tokens
) < req.num_prompt_tokens
if not is_partial:
# For non-partial prefills, once new req_meta is scheduled, it
# can be removed from _reqs_need_save.
# For partial prefill case, we will retain the request in
# _reqs_need_save until all blocks are scheduled with req_meta.
# Therefore, only pop if `not is_partial`.
self._reqs_need_save.pop(req_id)
meta.reqs_to_send = self._reqs_need_send meta.reqs_to_send = self._reqs_need_send
meta.reqs_in_batch = self._reqs_in_batch meta.reqs_in_batch = self._reqs_in_batch
...@@ -703,7 +716,6 @@ class NixlConnectorScheduler: ...@@ -703,7 +716,6 @@ class NixlConnectorScheduler:
# Clear the list once workers start the transfers # Clear the list once workers start the transfers
self._reqs_need_recv.clear() self._reqs_need_recv.clear()
self._reqs_need_save.clear()
self._reqs_in_batch = set() self._reqs_in_batch = set()
self._reqs_not_processed = set() self._reqs_not_processed = set()
self._reqs_need_send = {} self._reqs_need_send = {}
...@@ -749,6 +761,8 @@ class NixlConnectorScheduler: ...@@ -749,6 +761,8 @@ class NixlConnectorScheduler:
# Also include the case of a P/D Prefill request with immediate # Also include the case of a P/D Prefill request with immediate
# block free (eg abort). Stop tracking this request. # block free (eg abort). Stop tracking this request.
self._reqs_not_processed.add(request.request_id) self._reqs_not_processed.add(request.request_id)
# Clear _reqs_need_save if a request is aborted as partial prefill.
self._reqs_need_save.pop(request.request_id, None)
return False, None return False, None
# TODO: check whether block_ids actually ever be 0. If not we could # TODO: check whether block_ids actually ever be 0. If not we could
...@@ -873,9 +887,10 @@ class NixlConnectorWorker: ...@@ -873,9 +887,10 @@ class NixlConnectorWorker:
self.copy_blocks: CopyBlocksOp | None = None self.copy_blocks: CopyBlocksOp | None = None
# Map of engine_id -> kv_caches_base_addr. For TP case, each local # Map of engine_id -> kv_caches_base_addr. For TP case, each local
# rank will still only pull from a single remote TP worker.
self.kv_caches_base_addr: dict[EngineId, list[int]] = {}
self.device_id: int = 0 self.device_id: int = 0
# Current rank may pull from multiple remote TP workers.
# EngineId, dict[int, list[int]] -> engine_id, tp_rank, base_addr_for_layer
self.kv_caches_base_addr = defaultdict[EngineId, dict[int, list[int]]](dict)
# Number of NIXL regions. Currently one region per cache # Number of NIXL regions. Currently one region per cache
# (so 1 per layer for MLA, otherwise 2 per layer) # (so 1 per layer for MLA, otherwise 2 per layer)
...@@ -883,10 +898,12 @@ class NixlConnectorWorker: ...@@ -883,10 +898,12 @@ class NixlConnectorWorker:
self.num_layers = 0 self.num_layers = 0
# nixl_prepped_dlist_handle. # nixl_prepped_dlist_handle.
self.src_xfer_side_handle: int = 0 self.src_xfer_handles_by_block_size: dict[int, int] = {}
self.src_xfer_side_handles: dict[int, int] = {} # Populated dynamically during handshake based on remote configuration.
# Map of engine_id -> nixl_prepped_dlist_handle (int)]. # Keep track of regions at different tp_ratio values. tp_ratio->handles
self.dst_xfer_side_handles: dict[EngineId, int] = {} self.src_xfer_handles_by_tp_ratio: dict[int, list[int]] = {}
# Map of engine_id -> {tp_rank: nixl_prepped_dlist_handle (int)}.
self.dst_xfer_side_handles = defaultdict[EngineId, dict[int, int]](dict)
# Map of engine_id -> num_blocks. All ranks in the same deployment will # Map of engine_id -> num_blocks. All ranks in the same deployment will
# have the same number of blocks. # have the same number of blocks.
...@@ -977,103 +994,108 @@ class NixlConnectorWorker: ...@@ -977,103 +994,108 @@ class NixlConnectorWorker:
expected_engine_id: str, expected_engine_id: str,
) -> dict[int, str]: ) -> dict[int, str]:
"""Do a NIXL handshake with a remote instance.""" """Do a NIXL handshake with a remote instance."""
# When target instance TP > local TP, we need to perform multiple
start_time = time.perf_counter() # handshakes. Do it in a single background job for simplicity.
# Regardless, only handshake with the remote TP rank(s) that current
# NOTE(rob): we need each rank to have a unique port. This is # local rank will read from. Note that With homogeneous TP,
# a hack to keep us moving. We will switch when moving to etcd # this happens to be the same single rank_i.
# or where we have a single ZMQ socket in the scheduler. p_remote_ranks = self.kv_topo.get_target_remote_ranks(remote_tp_size)
remote_rank_to_agent_name = {}
# Handshake only with the remote TP rank that current local rank will
# pull from. With homogeneous TP it happens to be the same rank_i.
p_remote_rank = self.kv_topo.get_target_remote_rank(remote_tp_size)
path = make_zmq_path("tcp", host, port) path = make_zmq_path("tcp", host, port)
logger.debug(
"Querying metadata on path: %s at remote tp rank %s", path, p_remote_rank
)
# Send query for the request.
with zmq_ctx(zmq.REQ, path) as sock: with zmq_ctx(zmq.REQ, path) as sock:
msg = msgspec.msgpack.encode((GET_META_MSG, p_remote_rank)) for remote_rank in p_remote_ranks:
# Set receive timeout to 5 seconds to avoid hanging on dead server logger.debug(
sock.setsockopt(zmq.RCVTIMEO, 5000) # milliseconds "Querying metadata on path: %s at remote tp rank %s",
sock.send(msg) path,
handshake_bytes = sock.recv() remote_rank,
# Decode handshake payload to get compatibility hash
handshake_decoder = msgspec.msgpack.Decoder(NixlHandshakePayload)
try:
handshake_payload = handshake_decoder.decode(handshake_bytes)
except (msgspec.DecodeError, msgspec.ValidationError) as e:
raise RuntimeError(
f"Failed to decode NixlHandshakePayload. This likely indicates "
f"an incompatibility between connector version. Error: {e}"
) from e
got_metadata_time = time.perf_counter()
logger.debug(
"NIXL handshake: get metadata took: %s", got_metadata_time - start_time
)
# Check compatibility hash BEFORE decoding agent metadata
if (
self.enforce_compat_hash
and handshake_payload.compatibility_hash != self.compat_hash
):
raise RuntimeError(
f"NIXL compatibility hash mismatch. "
f"Local: {self.compat_hash}, "
f"Remote: {handshake_payload.compatibility_hash}. "
f"Prefill and decode instances have incompatible configurations. "
f"This may be due to: different vLLM versions, models, dtypes, "
f"KV cache layouts, attention backends, etc. "
f"Both instances must use identical configurations."
f"Disable this check using "
f'--kv-transfer-config \'{{"kv_connector_extra_config": '
f'{{"enforce_handshake_compat": false}}}}\''
) )
logger.info( start_time = time.perf_counter()
"NIXL compatibility check passed (hash: %s)", # Send query for the request.
handshake_payload.compatibility_hash, msg = msgspec.msgpack.encode((GET_META_MSG, remote_rank))
) # Set receive timeout to 5 seconds to avoid hanging on dead server
sock.setsockopt(zmq.RCVTIMEO, 5000) # milliseconds
sock.send(msg)
handshake_bytes = sock.recv()
# Decode agent metadata # Decode handshake payload to get compatibility hash
metadata_decoder = msgspec.msgpack.Decoder(NixlAgentMetadata) handshake_decoder = msgspec.msgpack.Decoder(NixlHandshakePayload)
try: try:
metadata = metadata_decoder.decode( handshake_payload = handshake_decoder.decode(handshake_bytes)
handshake_payload.agent_metadata_bytes except (msgspec.DecodeError, msgspec.ValidationError) as e:
raise RuntimeError(
f"Failed to decode NixlHandshakePayload. This likely indicates "
f"an incompatibility between connector version. Error: {e}"
) from e
got_metadata_time = time.perf_counter()
logger.debug(
"NIXL handshake: get metadata took: %s",
got_metadata_time - start_time,
) )
except (msgspec.DecodeError, msgspec.ValidationError) as e:
# This should not happen if hash matched
raise RuntimeError(
f"Failed to decode NixlAgentMetadata. Error: {e}"
) from e
# Ensure engine id matches. # Check compatibility hash BEFORE decoding agent metadata
if metadata.engine_id != expected_engine_id: if (
raise RuntimeError( self.enforce_compat_hash
f"Remote NIXL agent engine ID mismatch. " and handshake_payload.compatibility_hash != self.compat_hash
f"Expected {expected_engine_id}," ):
f"received {metadata.engine_id}." raise RuntimeError(
) f"NIXL compatibility hash mismatch. "
f"Local: {self.compat_hash}, "
f"Remote: {handshake_payload.compatibility_hash}. "
f"Prefill and decode instances have incompatible "
f"configurations. This may be due to: different vLLM versions,"
f" models, dtypes, KV cache layouts, attention backends, etc. "
f"Both instances must use identical configurations."
f"Disable this check using "
f'--kv-transfer-config \'{{"kv_connector_extra_config": '
f'{{"enforce_handshake_compat": false}}}}\''
)
# Register Remote agent. logger.info(
assert metadata.block_size <= self.block_size, ( "NIXL compatibility check passed (hash: %s)",
"nP > nD is not supported yet." handshake_payload.compatibility_hash,
) )
remote_agent_name = self.add_remote_agent(
metadata, p_remote_rank, remote_tp_size
)
setup_agent_time = time.perf_counter() # Decode agent metadata
logger.debug( metadata_decoder = msgspec.msgpack.Decoder(NixlAgentMetadata)
"NIXL handshake: add agent took: %s", try:
setup_agent_time - got_metadata_time, metadata = metadata_decoder.decode(
) handshake_payload.agent_metadata_bytes
)
except (msgspec.DecodeError, msgspec.ValidationError) as e:
# This should not happen if hash matched
raise RuntimeError(
f"Failed to decode NixlAgentMetadata. Error: {e}"
) from e
# Ensure engine id matches.
if metadata.engine_id != expected_engine_id:
raise RuntimeError(
f"Remote NIXL agent engine ID mismatch. "
f"Expected {expected_engine_id},"
f"received {metadata.engine_id}."
)
# Ensure engine id matches.
if metadata.engine_id != expected_engine_id:
raise RuntimeError(
f"Remote NIXL agent engine ID mismatch. "
f"Expected {expected_engine_id},"
f"received {metadata.engine_id}."
)
setup_agent_time = time.perf_counter()
# Remote rank -> agent name. # Register Remote agent.
return {p_remote_rank: remote_agent_name} remote_agent_name = self.add_remote_agent(
metadata, remote_rank, remote_tp_size
)
logger.debug(
"NIXL handshake: add agent took: %s",
setup_agent_time - got_metadata_time,
)
remote_rank_to_agent_name[remote_rank] = remote_agent_name
return remote_rank_to_agent_name
def initialize_host_xfer_buffer(self, kv_caches: dict[str, torch.Tensor]) -> None: def initialize_host_xfer_buffer(self, kv_caches: dict[str, torch.Tensor]) -> None:
""" """
...@@ -1283,7 +1305,7 @@ class NixlConnectorWorker: ...@@ -1283,7 +1305,7 @@ class NixlConnectorWorker:
assert len(self.block_len_per_layer) == len(seen_base_addresses) assert len(self.block_len_per_layer) == len(seen_base_addresses)
assert self.num_blocks != 0 assert self.num_blocks != 0
self.kv_caches_base_addr[self.engine_id] = seen_base_addresses self.kv_caches_base_addr[self.engine_id][self.tp_rank] = seen_base_addresses
self.num_regions = len(caches_data) self.num_regions = len(caches_data)
self.num_layers = len(xfer_buffers.keys()) self.num_layers = len(xfer_buffers.keys())
...@@ -1310,9 +1332,9 @@ class NixlConnectorWorker: ...@@ -1310,9 +1332,9 @@ class NixlConnectorWorker:
# Register local/src descr for NIXL xfer. # Register local/src descr for NIXL xfer.
self.seen_base_addresses = seen_base_addresses self.seen_base_addresses = seen_base_addresses
self.src_xfer_side_handle = self.register_local_xfer_handler(self.block_size) self.src_xfer_handles_by_block_size[self.block_size], self.src_blocks_data = (
self.register_local_xfer_handler(self.block_size)
self.src_xfer_side_handles[self.block_size] = self.src_xfer_side_handle )
# TODO(mgoin): Hybrid memory allocator is currently disabled for # TODO(mgoin): Hybrid memory allocator is currently disabled for
# models with local attention (Llama 4). Can remove this once enabled. # models with local attention (Llama 4). Can remove this once enabled.
...@@ -1340,8 +1362,8 @@ class NixlConnectorWorker: ...@@ -1340,8 +1362,8 @@ class NixlConnectorWorker:
agent_metadata = NixlAgentMetadata( agent_metadata = NixlAgentMetadata(
engine_id=self.engine_id, engine_id=self.engine_id,
agent_metadata=self.nixl_wrapper.get_agent_metadata(), agent_metadata=self.nixl_wrapper.get_agent_metadata(),
kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id],
device_id=self.device_id, device_id=self.device_id,
kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id][self.tp_rank],
num_blocks=self.num_blocks, num_blocks=self.num_blocks,
block_lens=self.block_len_per_layer, block_lens=self.block_len_per_layer,
kv_cache_layout=self.kv_cache_layout kv_cache_layout=self.kv_cache_layout
...@@ -1359,7 +1381,7 @@ class NixlConnectorWorker: ...@@ -1359,7 +1381,7 @@ class NixlConnectorWorker:
def register_local_xfer_handler( def register_local_xfer_handler(
self, self,
block_size: int, block_size: int,
) -> int: ) -> tuple[int, list[tuple[int, int, int]]]:
""" """
Function used for register local xfer handler with local block_size or Function used for register local xfer handler with local block_size or
Remote block_size. Remote block_size.
...@@ -1407,7 +1429,7 @@ class NixlConnectorWorker: ...@@ -1407,7 +1429,7 @@ class NixlConnectorWorker:
descs = self.nixl_wrapper.get_xfer_descs(blocks_data, self.nixl_memory_type) descs = self.nixl_wrapper.get_xfer_descs(blocks_data, self.nixl_memory_type)
# NIXL_INIT_AGENT to be used for preparations of local descs. # NIXL_INIT_AGENT to be used for preparations of local descs.
return self.nixl_wrapper.prep_xfer_dlist("NIXL_INIT_AGENT", descs) return self.nixl_wrapper.prep_xfer_dlist("NIXL_INIT_AGENT", descs), blocks_data
def add_remote_agent( def add_remote_agent(
self, self,
...@@ -1421,10 +1443,12 @@ class NixlConnectorWorker: ...@@ -1421,10 +1443,12 @@ class NixlConnectorWorker:
In particular, handle both homogeneous and heterogeneous TP. The former In particular, handle both homogeneous and heterogeneous TP. The former
requires local rank_i to read from remote rank_i. requires local rank_i to read from remote rank_i.
The latter, assuming D.world_size > P.world_size, requires that two or The latter, in the case of D.world_size < P.world_size, requires that a
more local TP worker share the xfer from a single TP worker. local (D) TP worker reads from multiple remote (P) TP workers.
Conversely, assuming D.world_size > P.world_size, two or more local TP
workers will read from a single remote TP worker.
Here's an example (non-MLA case): Here's an example for the last case described above (non-MLA):
rank_offset p_remote_tp_rank rank_offset p_remote_tp_rank
(kv split no) (kv split no)
...@@ -1474,9 +1498,6 @@ class NixlConnectorWorker: ...@@ -1474,9 +1498,6 @@ class NixlConnectorWorker:
nixl_agent_meta.agent_metadata nixl_agent_meta.agent_metadata
) )
# Handle tp_size>num_kv_heads: replicate KV cache.
replicates_kv_cache = self.kv_topo.replicates_kv_cache(engine_id)
# Create dst descs and xfer side handles. TP workers have same #blocks # Create dst descs and xfer side handles. TP workers have same #blocks
# so we only register once per engine_id. # so we only register once per engine_id.
# Example: # Example:
...@@ -1490,14 +1511,52 @@ class NixlConnectorWorker: ...@@ -1490,14 +1511,52 @@ class NixlConnectorWorker:
self.dst_num_blocks[engine_id] = nixl_agent_meta.num_blocks self.dst_num_blocks[engine_id] = nixl_agent_meta.num_blocks
# Keep track of remote agent kv caches base addresses. # Keep track of remote agent kv caches base addresses.
self.kv_caches_base_addr[engine_id] = nixl_agent_meta.kv_caches_base_addr self.kv_caches_base_addr[engine_id][remote_tp_rank] = (
nixl_agent_meta.kv_caches_base_addr
)
self._validate_remote_agent_handshake(nixl_agent_meta, remote_tp_size) self._validate_remote_agent_handshake(nixl_agent_meta, remote_tp_size)
# Number of D TP workers reading from a single P TP worker. This is # This is 1 when P and D `--tensor-parallel-size` match. Otherwise,
# 1 when P and D `--tensor-parallel-size` match. # this is the ratio between the two sizes.
tp_ratio = self.kv_topo.tp_ratio_from_engine_id(engine_id) tp_ratio = self.kv_topo.tp_ratio_from_engine_id(engine_id)
# Handle tp_size>num_kv_heads: replicate KV cache.
indexes_into_remote = (
not self.kv_topo.replicates_kv_cache(engine_id) and tp_ratio > 0
)
logger.debug(
"Registering remote agent (%s, rank %s) memory regions with tp_ratio %s",
engine_id,
remote_tp_rank,
tp_ratio,
)
### (Optional) Register local agent memory regions. MLA is not split.
if (
tp_ratio < 0
and not self.use_mla
and tp_ratio not in self.src_xfer_handles_by_tp_ratio
):
# Remote tp_size > local tp_size: read from multiple remote ranks.
# Logically "split" own regions into |tp_ratio| chunks. Mind that
# we only do this once per remote tp_size (replica-friendly).
self.src_xfer_handles_by_tp_ratio[tp_ratio] = []
for i in range(-tp_ratio):
blocks_data = []
for memory_region in self.src_blocks_data:
addr, local_block_len, own_tp_rank = memory_region
# Computing block len layer by layer allows for different
# block sizes to be used.
remote_block_len = local_block_len // (-tp_ratio)
addr = addr + i * remote_block_len
blocks_data.append((addr, remote_block_len, own_tp_rank))
descs = self.nixl_wrapper.get_xfer_descs(
blocks_data, self.nixl_memory_type
)
handle = self.nixl_wrapper.prep_xfer_dlist("NIXL_INIT_AGENT", descs)
self.src_xfer_handles_by_tp_ratio[tp_ratio].append(handle)
### Register remote agent memory regions ### Register remote agent memory regions
blocks_data = [] blocks_data = []
# With homogeneous TP, D pulls the whole kv cache from corresponding # With homogeneous TP, D pulls the whole kv cache from corresponding
...@@ -1507,14 +1566,19 @@ class NixlConnectorWorker: ...@@ -1507,14 +1566,19 @@ class NixlConnectorWorker:
# Register all remote blocks, but only the corresponding kv heads. # Register all remote blocks, but only the corresponding kv heads.
for i, base_addr in enumerate(nixl_agent_meta.kv_caches_base_addr): for i, base_addr in enumerate(nixl_agent_meta.kv_caches_base_addr):
kv_block_len = self.get_backend_aware_kv_block_len(layer_idx=i) # Read our whole local region size from remote.
remote_kv_block_len = kv_block_len // block_size_ratio local_block_len = self.get_backend_aware_kv_block_len(layer_idx=i)
remote_kv_block_len = local_block_len // block_size_ratio
if block_size_ratio > 1: if block_size_ratio > 1:
# using remote kv_block_len as transfer unit # using remote kv_block_len as transfer unit
kv_block_len = remote_kv_block_len local_block_len = remote_kv_block_len
if tp_ratio < 0 and not self.use_mla:
# Remote tp is bigger: read a chunk of local region from remote
local_block_len = local_block_len // (-tp_ratio)
rank_offset = ( rank_offset = (
self.tp_rank % tp_ratio * remote_kv_block_len self.tp_rank % tp_ratio * remote_kv_block_len
if not replicates_kv_cache if indexes_into_remote
else 0 else 0
) )
for block_id in range(nixl_agent_meta.num_blocks): for block_id in range(nixl_agent_meta.num_blocks):
...@@ -1524,7 +1588,7 @@ class NixlConnectorWorker: ...@@ -1524,7 +1588,7 @@ class NixlConnectorWorker:
# self.block_len == remote_block_len//tp_ratio bytes. # self.block_len == remote_block_len//tp_ratio bytes.
addr = base_addr + block_offset + rank_offset addr = base_addr + block_offset + rank_offset
# (addr, len, device id) # (addr, len, device id)
blocks_data.append((addr, kv_block_len, nixl_agent_meta.device_id)) blocks_data.append((addr, local_block_len, nixl_agent_meta.device_id))
if self.kv_topo.is_kv_layout_blocks_first: if self.kv_topo.is_kv_layout_blocks_first:
# With FlashInfer index V separately to allow head splitting. # With FlashInfer index V separately to allow head splitting.
...@@ -1533,7 +1597,7 @@ class NixlConnectorWorker: ...@@ -1533,7 +1597,7 @@ class NixlConnectorWorker:
addr = base_addr + block_offset + rank_offset addr = base_addr + block_offset + rank_offset
v_addr = addr + nixl_agent_meta.block_lens[i] // 2 v_addr = addr + nixl_agent_meta.block_lens[i] // 2
blocks_data.append( blocks_data.append(
(v_addr, kv_block_len, nixl_agent_meta.device_id) (v_addr, local_block_len, nixl_agent_meta.device_id)
) )
logger.debug( logger.debug(
...@@ -1546,15 +1610,15 @@ class NixlConnectorWorker: ...@@ -1546,15 +1610,15 @@ class NixlConnectorWorker:
# Register with NIXL. # Register with NIXL.
descs = self.nixl_wrapper.get_xfer_descs(blocks_data, self.nixl_memory_type) descs = self.nixl_wrapper.get_xfer_descs(blocks_data, self.nixl_memory_type)
self.dst_xfer_side_handles[engine_id] = self.nixl_wrapper.prep_xfer_dlist( self.dst_xfer_side_handles[engine_id][remote_tp_rank] = (
remote_agent_name, descs self.nixl_wrapper.prep_xfer_dlist(remote_agent_name, descs)
) )
if block_size_ratio > 1: if block_size_ratio > 1:
# when prefill with smaller block_size, we need to init a # when prefill with smaller block_size, we need to init a
# new handler with same block_len to match # new handler with same block_len to match
self.src_xfer_side_handles[nixl_agent_meta.block_size] = ( self.src_xfer_handles_by_block_size[nixl_agent_meta.block_size] = (
self.register_local_xfer_handler(nixl_agent_meta.block_size) self.register_local_xfer_handler(nixl_agent_meta.block_size)[0]
) )
return remote_agent_name return remote_agent_name
...@@ -1574,7 +1638,9 @@ class NixlConnectorWorker: ...@@ -1574,7 +1638,9 @@ class NixlConnectorWorker:
block_size_ratio = self.kv_topo.block_size_ratio_from_engine_id( block_size_ratio = self.kv_topo.block_size_ratio_from_engine_id(
remote_engine_id remote_engine_id
) )
assert tp_ratio > 0, "Decode TP cannot be smaller than prefill TP" # Num kv_heads > tp_size and P TP > D TP case, not supported
assert not (tp_ratio < 0 and self.kv_topo.is_kv_replicated(remote_engine_id))
assert not self._use_pallas or tp_ratio == 1, ( assert not self._use_pallas or tp_ratio == 1, (
"TPU (pallas_v1) DOES NOT support heterogeneous TP yet." "TPU (pallas_v1) DOES NOT support heterogeneous TP yet."
) )
...@@ -1616,17 +1682,29 @@ class NixlConnectorWorker: ...@@ -1616,17 +1682,29 @@ class NixlConnectorWorker:
"All remote layers must have the same block size" "All remote layers must have the same block size"
) )
assert ( if tp_ratio > 0:
remote_block_len # Remote tp is smaller: remote block_len size is bigger
== (self.block_len_per_layer[0] * tp_ratio) // block_size_ratio assert (
), ( remote_block_len
"Remote P worker KV layer cache must be of shape [2, N, " == (self.block_len_per_layer[0] * tp_ratio) // block_size_ratio
"local_kv_heads*tp_ratio, block_size, head_dim] and same dtype." ), (
) "Remote P worker KV layer cache must be of shape [2, N, "
"local_kv_heads*tp_ratio, page_size, head_dim] and same dtype."
) # noqa: E501
else:
assert block_size_ratio == 1, (
"Different local/remote block sizes are not supported when"
" P TP > D TP."
)
# Remote tp is bigger: remote block_len size is smaller
assert remote_block_len == self.block_len_per_layer[0] // (-tp_ratio), (
"Remote P worker KV layer cache must be of shape [2, N, "
"local_kv_heads/tp_ratio, page_size, head_dim] and same dtype."
) # noqa: E501
# TP workers have same #blocks. # TP workers that handshake with same remote have same #blocks.
assert self.dst_num_blocks[remote_engine_id] == nixl_agent_meta.num_blocks assert self.dst_num_blocks[remote_engine_id] == nixl_agent_meta.num_blocks
# Same number of regions/~layers.
assert len(nixl_agent_meta.kv_caches_base_addr) == len(self.block_len_per_layer) assert len(nixl_agent_meta.kv_caches_base_addr) == len(self.block_len_per_layer)
def sync_recved_kv_to_device(self, req_id: str, meta: ReqMeta): def sync_recved_kv_to_device(self, req_id: str, meta: ReqMeta):
...@@ -1710,7 +1788,7 @@ class NixlConnectorWorker: ...@@ -1710,7 +1788,7 @@ class NixlConnectorWorker:
) )
cache.index_copy_(0, indices, permuted_blocks) cache.index_copy_(0, indices, permuted_blocks)
def blocksize_post_process(self, block_ids_per_ratio: dict[float, list[list[int]]]): def blocksize_post_process(self, block_ids_per_ratio: dict[int, list[list[int]]]):
def _process_local_gt_remote(blocks_to_update, block_size_ratio): def _process_local_gt_remote(blocks_to_update, block_size_ratio):
n_kv_heads, block_size, head_size = blocks_to_update.shape[1:] n_kv_heads, block_size, head_size = blocks_to_update.shape[1:]
remote_block_size = block_size // block_size_ratio remote_block_size = block_size // block_size_ratio
...@@ -1840,7 +1918,7 @@ class NixlConnectorWorker: ...@@ -1840,7 +1918,7 @@ class NixlConnectorWorker:
notified_req_ids: set[str] = set() notified_req_ids: set[str] = set()
for notifs in self.nixl_wrapper.get_new_notifs().values(): for notifs in self.nixl_wrapper.get_new_notifs().values():
for notif in notifs: for notif in notifs:
req_id, tp_ratio = notif.decode("utf-8").rsplit(":", 1) req_id, tp_size = notif.decode("utf-8").rsplit(":", 1)
if ( if (
req_id not in self._reqs_to_send req_id not in self._reqs_to_send
and req_id not in self._reqs_to_process and req_id not in self._reqs_to_process
...@@ -1853,9 +1931,22 @@ class NixlConnectorWorker: ...@@ -1853,9 +1931,22 @@ class NixlConnectorWorker:
) )
continue continue
# NOTE: `tp_ratio` is the opposite when swapping local<>remote
n_consumers = int(tp_size)
tp_ratio = self.kv_topo.tp_ratio(n_consumers)
# Number of reads *per producer* to wait for.
# When remote D TP > local P TP we expect `tp_ratio` reads.
consumers_per_producer = (
-tp_ratio if n_consumers > self.world_size else 1
)
self.consumer_notification_counts_by_req[req_id] += 1 self.consumer_notification_counts_by_req[req_id] += 1
# Wait all consumers (D) to be done reading before freeing. # Wait all consumers (D) to be done reading before freeing.
if self.consumer_notification_counts_by_req[req_id] == int(tp_ratio): if (
self.consumer_notification_counts_by_req[req_id]
== consumers_per_producer
):
notified_req_ids.add(req_id) notified_req_ids.add(req_id)
del self.consumer_notification_counts_by_req[req_id] del self.consumer_notification_counts_by_req[req_id]
self._reqs_to_process.remove(req_id) self._reqs_to_process.remove(req_id)
...@@ -1872,7 +1963,7 @@ class NixlConnectorWorker: ...@@ -1872,7 +1963,7 @@ class NixlConnectorWorker:
""" """
done_req_ids: set[str] = set() done_req_ids: set[str] = set()
for req_id, handles in list(transfers.items()): for req_id, handles in list(transfers.items()):
in_progress = False in_progress = []
for handle in handles: for handle in handles:
try: try:
xfer_state = self.nixl_wrapper.check_xfer_state(handle) xfer_state = self.nixl_wrapper.check_xfer_state(handle)
...@@ -1882,7 +1973,7 @@ class NixlConnectorWorker: ...@@ -1882,7 +1973,7 @@ class NixlConnectorWorker:
self.xfer_stats.record_transfer(res) self.xfer_stats.record_transfer(res)
self.nixl_wrapper.release_xfer_handle(handle) self.nixl_wrapper.release_xfer_handle(handle)
elif xfer_state == "PROC": elif xfer_state == "PROC":
in_progress = True in_progress.append(handle)
continue continue
else: else:
logger.error( logger.error(
...@@ -1892,7 +1983,6 @@ class NixlConnectorWorker: ...@@ -1892,7 +1983,6 @@ class NixlConnectorWorker:
xfer_state, xfer_state,
) )
self._handle_failed_transfer(req_id, handle) self._handle_failed_transfer(req_id, handle)
in_progress = False
except Exception: except Exception:
logger.exception( logger.exception(
"NIXL transfer exception for request %s. " "NIXL transfer exception for request %s. "
...@@ -1900,11 +1990,13 @@ class NixlConnectorWorker: ...@@ -1900,11 +1990,13 @@ class NixlConnectorWorker:
req_id, req_id,
) )
self._handle_failed_transfer(req_id, handle) self._handle_failed_transfer(req_id, handle)
in_progress = False
if not in_progress: if not in_progress:
# Only report request as completed when all transfers are done.
done_req_ids.add(req_id) done_req_ids.add(req_id)
del transfers[req_id] del transfers[req_id]
else:
transfers[req_id] = in_progress
return done_req_ids return done_req_ids
def _handle_failed_transfer(self, req_id: str, handle: int): def _handle_failed_transfer(self, req_id: str, handle: int):
...@@ -1982,18 +2074,62 @@ class NixlConnectorWorker: ...@@ -1982,18 +2074,62 @@ class NixlConnectorWorker:
def _read_blocks_for_req(self, req_id: str, meta: ReqMeta): def _read_blocks_for_req(self, req_id: str, meta: ReqMeta):
assert meta.remote is not None assert meta.remote is not None
logger.debug( remote_ranks = self.kv_topo.get_target_remote_ranks_from_engine_id(
"Remote agent %s available, calling _read_blocks for req %s", meta.remote.engine_id
meta.remote.engine_id,
req_id,
)
self._read_blocks(
request_id=req_id,
dst_engine_id=meta.remote.engine_id,
remote_request_id=meta.remote.request_id,
local_block_ids=meta.local_physical_block_ids,
remote_block_ids=meta.remote.block_ids,
) )
tp_ratio = self.kv_topo.tp_ratio_from_engine_id(meta.remote.engine_id)
# D may have to perform multiple reads from different remote ranks.
for i, remote_rank in enumerate(remote_ranks):
if self.use_mla and tp_ratio < 0 and i > 0:
# MLA opt: when P TP > D TP, only a single read is executed for
# the first remote rank (cache is duplicated)..
break
remote_block_size = self.kv_topo.remote_block_size[meta.remote.engine_id]
logger.debug(
"Remote agent %s available, calling _read_blocks"
" on remote rank %s with remote block size %s for req %s",
meta.remote.engine_id,
remote_rank,
remote_block_size,
req_id,
)
# Get side handles.
if tp_ratio < 0 and not self.use_mla:
assert remote_block_size == self.block_size
# Remote tp_size > local tp_size: we must perform multiple
# reads. Get the memory chunk onto which we will write to.
local_xfer_side_handle = self.src_xfer_handles_by_tp_ratio[tp_ratio][i]
else:
# Single read from remote, we write to the whole memory region.
# Also handle remote block size different from local block size.
local_xfer_side_handle = self.src_xfer_handles_by_block_size[
remote_block_size
]
# Destination handle: remote_engine_id -> remote_rank -> handle.
remote_xfer_side_handle = self.dst_xfer_side_handles[meta.remote.engine_id][
remote_rank
]
self._read_blocks(
request_id=req_id,
dst_engine_id=meta.remote.engine_id,
remote_request_id=meta.remote.request_id,
local_block_ids=meta.local_physical_block_ids,
remote_block_ids=meta.remote.block_ids,
remote_rank=remote_rank,
local_xfer_side_handle=local_xfer_side_handle,
remote_xfer_side_handle=remote_xfer_side_handle,
)
if self.use_mla and tp_ratio < 0:
# ..but we still need to notify the other remote ranks that we
# have the blocks we need so they can update the request state.
notif_id = f"{req_id}:{self.world_size}".encode()
remote_agents = self._remote_agents[meta.remote.engine_id]
for rank_to_notify, agent in remote_agents.items():
if rank_to_notify != remote_rank:
self.nixl_wrapper.send_notif(agent, notif_msg=notif_id)
def _read_blocks( def _read_blocks(
self, self,
...@@ -2002,7 +2138,14 @@ class NixlConnectorWorker: ...@@ -2002,7 +2138,14 @@ class NixlConnectorWorker:
dst_engine_id: str, dst_engine_id: str,
request_id: str, request_id: str,
remote_request_id: str, remote_request_id: str,
remote_rank: int,
local_xfer_side_handle: int,
remote_xfer_side_handle: int,
): ):
"""
Post a READ point-to-point xfer request from a single local worker to
a single remote worker.
"""
block_size_ratio = self.kv_topo.block_size_ratio_from_engine_id(dst_engine_id) block_size_ratio = self.kv_topo.block_size_ratio_from_engine_id(dst_engine_id)
if block_size_ratio > 1: if block_size_ratio > 1:
local_block_ids = self.get_mapped_blocks( local_block_ids = self.get_mapped_blocks(
...@@ -2031,18 +2174,14 @@ class NixlConnectorWorker: ...@@ -2031,18 +2174,14 @@ class NixlConnectorWorker:
# saturate IB with heterogeneous TP sizes. We should remove the staging # saturate IB with heterogeneous TP sizes. We should remove the staging
# blocks until we are ready. # blocks until we are ready.
# Number of D TP workers that will read from dst P. Propagate tp_ratio # Number of D TP workers that will read from dst P. Propagate info
# on notification so that dst worker can wait before freeing blocks. # on notification so that dst worker can wait before freeing blocks.
tp_ratio = self.kv_topo.tp_ratio_from_engine_id(dst_engine_id) notif_id = f"{remote_request_id}:{self.world_size}".encode()
notif_id = f"{remote_request_id}:{tp_ratio}".encode()
# Full prefix cache hit: do not need to read remote blocks, # Full prefix cache hit: do not need to read remote blocks,
# just notify P worker that we have the blocks we need. # just notify P worker that we have the blocks we need.
num_local_blocks = len(local_block_ids) num_local_blocks = len(local_block_ids)
if num_local_blocks == 0: if num_local_blocks == 0:
remote_rank = self.kv_topo.get_target_remote_rank_from_engine_id(
dst_engine_id
)
agent_name = self._remote_agents[dst_engine_id][remote_rank] agent_name = self._remote_agents[dst_engine_id][remote_rank]
try: try:
self.nixl_wrapper.send_notif(agent_name, notif_msg=notif_id) self.nixl_wrapper.send_notif(agent_name, notif_msg=notif_id)
...@@ -2062,13 +2201,6 @@ class NixlConnectorWorker: ...@@ -2062,13 +2201,6 @@ class NixlConnectorWorker:
if num_local_blocks < num_remote_blocks: if num_local_blocks < num_remote_blocks:
remote_block_ids = remote_block_ids[-num_local_blocks:] remote_block_ids = remote_block_ids[-num_local_blocks:]
# Get side handles.
remote_block_size = self.kv_topo.remote_block_size[dst_engine_id]
local_xfer_side_handle = self.src_xfer_side_handles.get(
remote_block_size, self.src_xfer_side_handle
)
remote_xfer_side_handle = self.dst_xfer_side_handles[dst_engine_id]
# NOTE (nicolo) With homogeneous TP, each TP worker loads KV from # NOTE (nicolo) With homogeneous TP, each TP worker loads KV from
# corresponding rank. With heterogeneous TP, fixing D>P, the D tp # corresponding rank. With heterogeneous TP, fixing D>P, the D tp
# workers will issue xfers to parts of the P worker remote kv caches. # workers will issue xfers to parts of the P worker remote kv caches.
...@@ -2230,7 +2362,7 @@ class NixlConnectorWorker: ...@@ -2230,7 +2362,7 @@ class NixlConnectorWorker:
block_ids_np, self._physical_blocks_per_logical_kv_block, block_arange block_ids_np, self._physical_blocks_per_logical_kv_block, block_arange
).tolist() ).tolist()
def get_backend_aware_kv_block_len(self, layer_idx: int): def get_backend_aware_kv_block_len(self, layer_idx: int) -> int:
""" """
Get the block length for one K/V element (K and V have the same size). Get the block length for one K/V element (K and V have the same size).
...@@ -2276,11 +2408,16 @@ class NixlConnectorWorker: ...@@ -2276,11 +2408,16 @@ class NixlConnectorWorker:
for handle in handles: for handle in handles:
self.nixl_wrapper.release_xfer_handle(handle) self.nixl_wrapper.release_xfer_handle(handle)
self._recving_transfers.clear() self._recving_transfers.clear()
if self.src_xfer_side_handle: for handle in self.src_xfer_handles_by_block_size.values():
self.nixl_wrapper.release_dlist_handle(self.src_xfer_side_handle) self.nixl_wrapper.release_dlist_handle(handle)
self.src_xfer_side_handle = 0 self.src_xfer_handles_by_block_size.clear()
for dst_xfer_side_handle in self.dst_xfer_side_handles.values(): for handles in self.src_xfer_handles_by_tp_ratio.values():
self.nixl_wrapper.release_dlist_handle(dst_xfer_side_handle) for handle in handles:
self.nixl_wrapper.release_dlist_handle(handle)
self.src_xfer_handles_by_tp_ratio.clear()
for dst_xfer_side_handles in self.dst_xfer_side_handles.values():
for dst_xfer_side_handle in dst_xfer_side_handles.values():
self.nixl_wrapper.release_dlist_handle(dst_xfer_side_handle)
self.dst_xfer_side_handles.clear() self.dst_xfer_side_handles.clear()
for remote_agents in self._remote_agents.values(): for remote_agents in self._remote_agents.values():
for agent_name in remote_agents.values(): for agent_name in remote_agents.values():
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections import defaultdict from collections import defaultdict
from collections.abc import Iterable, Iterator from collections.abc import Iterable
from dataclasses import dataclass from dataclasses import dataclass
from itertools import islice from itertools import islice
from typing import Any, ClassVar from typing import Any, ClassVar
...@@ -12,6 +12,7 @@ from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata ...@@ -12,6 +12,7 @@ from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata
from vllm.attention.layer import Attention from vllm.attention.layer import Attention
from vllm.config import VllmConfig, get_layers_from_vllm_config from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.distributed.kv_events import BlockRemoved, BlockStored, KVCacheEvent from vllm.distributed.kv_events import BlockRemoved, BlockStored, KVCacheEvent
from vllm.distributed.kv_transfer.kv_connector.utils import yield_req_data
from vllm.distributed.kv_transfer.kv_connector.v1 import ( from vllm.distributed.kv_transfer.kv_connector.v1 import (
KVConnectorBase_V1, KVConnectorBase_V1,
KVConnectorRole, KVConnectorRole,
...@@ -516,23 +517,3 @@ class OffloadingConnectorWorker: ...@@ -516,23 +517,3 @@ class OffloadingConnectorWorker:
del self._store_jobs[req_id] del self._store_jobs[req_id]
return finished_sending, finished_recving return finished_sending, finished_recving
def yield_req_data(
scheduler_output,
) -> Iterator[tuple[str, tuple[list[int], ...], bool]]:
"""
Yields:
(req_id, new_block_id_groups, preempted)
"""
# new requests
for req_data in scheduler_output.scheduled_new_reqs:
yield req_data.req_id, req_data.block_ids, False
# cached requests
cached_reqs = scheduler_output.scheduled_cached_reqs
yield from zip(
cached_reqs.req_ids,
cached_reqs.new_block_ids,
(req_id in cached_reqs.resumed_req_ids for req_id in cached_reqs.req_ids),
)
...@@ -1007,10 +1007,17 @@ class GroupCoordinator: ...@@ -1007,10 +1007,17 @@ class GroupCoordinator:
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
is_sequence_parallel: bool = False, is_sequence_parallel: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]: extra_tensors: list[torch.Tensor] | None = None,
) -> (
tuple[torch.Tensor, torch.Tensor]
| tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
if self.device_communicator is not None: if self.device_communicator is not None:
return self.device_communicator.dispatch( return self.device_communicator.dispatch( # type: ignore[call-arg]
hidden_states, router_logits, is_sequence_parallel hidden_states,
router_logits,
is_sequence_parallel,
extra_tensors,
) )
else: else:
return hidden_states, router_logits return hidden_states, router_logits
......
...@@ -93,6 +93,7 @@ from vllm.transformers_utils.utils import is_cloud_storage ...@@ -93,6 +93,7 @@ from vllm.transformers_utils.utils import is_cloud_storage
from vllm.utils.argparse_utils import FlexibleArgumentParser from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.utils.mem_constants import GiB_bytes from vllm.utils.mem_constants import GiB_bytes
from vllm.utils.network_utils import get_ip from vllm.utils.network_utils import get_ip
from vllm.utils.torch_utils import resolve_kv_cache_dtype_string
from vllm.v1.sample.logits_processor import LogitsProcessor from vllm.v1.sample.logits_processor import LogitsProcessor
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -106,6 +107,7 @@ else: ...@@ -106,6 +107,7 @@ else:
LoadFormats = Any LoadFormats = Any
UsageContext = Any UsageContext = Any
logger = init_logger(__name__) logger = init_logger(__name__)
# object is used to allow for special typing forms # object is used to allow for special typing forms
...@@ -406,8 +408,9 @@ class EngineArgs: ...@@ -406,8 +408,9 @@ class EngineArgs:
data_parallel_external_lb: bool = False data_parallel_external_lb: bool = False
data_parallel_backend: str = ParallelConfig.data_parallel_backend data_parallel_backend: str = ParallelConfig.data_parallel_backend
enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel
all2all_backend: str | None = ParallelConfig.all2all_backend all2all_backend: str = ParallelConfig.all2all_backend
enable_dbo: bool = ParallelConfig.enable_dbo enable_dbo: bool = ParallelConfig.enable_dbo
ubatch_size: int = ParallelConfig.ubatch_size
dbo_decode_token_threshold: int = ParallelConfig.dbo_decode_token_threshold dbo_decode_token_threshold: int = ParallelConfig.dbo_decode_token_threshold
dbo_prefill_token_threshold: int = ParallelConfig.dbo_prefill_token_threshold dbo_prefill_token_threshold: int = ParallelConfig.dbo_prefill_token_threshold
disable_nccl_for_dp_synchronization: bool = ( disable_nccl_for_dp_synchronization: bool = (
...@@ -520,6 +523,7 @@ class EngineArgs: ...@@ -520,6 +523,7 @@ class EngineArgs:
enable_layerwise_nvtx_tracing: bool = ( enable_layerwise_nvtx_tracing: bool = (
ObservabilityConfig.enable_layerwise_nvtx_tracing ObservabilityConfig.enable_layerwise_nvtx_tracing
) )
enable_mfu_metrics: bool = ObservabilityConfig.enable_mfu_metrics
scheduling_policy: SchedulerPolicy = SchedulerConfig.policy scheduling_policy: SchedulerPolicy = SchedulerConfig.policy
scheduler_cls: str | type[object] | None = SchedulerConfig.scheduler_cls scheduler_cls: str | type[object] | None = SchedulerConfig.scheduler_cls
...@@ -841,6 +845,10 @@ class EngineArgs: ...@@ -841,6 +845,10 @@ class EngineArgs:
"--all2all-backend", **parallel_kwargs["all2all_backend"] "--all2all-backend", **parallel_kwargs["all2all_backend"]
) )
parallel_group.add_argument("--enable-dbo", **parallel_kwargs["enable_dbo"]) parallel_group.add_argument("--enable-dbo", **parallel_kwargs["enable_dbo"])
parallel_group.add_argument(
"--ubatch-size",
**parallel_kwargs["ubatch_size"],
)
parallel_group.add_argument( parallel_group.add_argument(
"--dbo-decode-token-threshold", "--dbo-decode-token-threshold",
**parallel_kwargs["dbo_decode_token_threshold"], **parallel_kwargs["dbo_decode_token_threshold"],
...@@ -1035,6 +1043,10 @@ class EngineArgs: ...@@ -1035,6 +1043,10 @@ class EngineArgs:
"--enable-layerwise-nvtx-tracing", "--enable-layerwise-nvtx-tracing",
**observability_kwargs["enable_layerwise_nvtx_tracing"], **observability_kwargs["enable_layerwise_nvtx_tracing"],
) )
observability_group.add_argument(
"--enable-mfu-metrics",
**observability_kwargs["enable_mfu_metrics"],
)
# Scheduler arguments # Scheduler arguments
scheduler_kwargs = get_kwargs(SchedulerConfig) scheduler_kwargs = get_kwargs(SchedulerConfig)
...@@ -1356,12 +1368,17 @@ class EngineArgs: ...@@ -1356,12 +1368,17 @@ class EngineArgs:
f"dcp_size={self.decode_context_parallel_size}." f"dcp_size={self.decode_context_parallel_size}."
) )
# Resolve "auto" kv_cache_dtype to actual value from model config
resolved_cache_dtype = resolve_kv_cache_dtype_string(
self.kv_cache_dtype, model_config
)
cache_config = CacheConfig( cache_config = CacheConfig(
block_size=self.block_size, block_size=self.block_size,
gpu_memory_utilization=self.gpu_memory_utilization, gpu_memory_utilization=self.gpu_memory_utilization,
kv_cache_memory_bytes=self.kv_cache_memory_bytes, kv_cache_memory_bytes=self.kv_cache_memory_bytes,
swap_space=self.swap_space, swap_space=self.swap_space,
cache_dtype=self.kv_cache_dtype, cache_dtype=resolved_cache_dtype,
is_attention_free=model_config.is_attention_free, is_attention_free=model_config.is_attention_free,
num_gpu_blocks_override=self.num_gpu_blocks_override, num_gpu_blocks_override=self.num_gpu_blocks_override,
sliding_window=sliding_window, sliding_window=sliding_window,
...@@ -1557,6 +1574,7 @@ class EngineArgs: ...@@ -1557,6 +1574,7 @@ class EngineArgs:
enable_expert_parallel=self.enable_expert_parallel, enable_expert_parallel=self.enable_expert_parallel,
all2all_backend=self.all2all_backend, all2all_backend=self.all2all_backend,
enable_dbo=self.enable_dbo, enable_dbo=self.enable_dbo,
ubatch_size=self.ubatch_size,
dbo_decode_token_threshold=self.dbo_decode_token_threshold, dbo_decode_token_threshold=self.dbo_decode_token_threshold,
dbo_prefill_token_threshold=self.dbo_prefill_token_threshold, dbo_prefill_token_threshold=self.dbo_prefill_token_threshold,
disable_nccl_for_dp_synchronization=self.disable_nccl_for_dp_synchronization, disable_nccl_for_dp_synchronization=self.disable_nccl_for_dp_synchronization,
...@@ -1676,6 +1694,7 @@ class EngineArgs: ...@@ -1676,6 +1694,7 @@ class EngineArgs:
kv_cache_metrics_sample=self.kv_cache_metrics_sample, kv_cache_metrics_sample=self.kv_cache_metrics_sample,
cudagraph_metrics=self.cudagraph_metrics, cudagraph_metrics=self.cudagraph_metrics,
enable_layerwise_nvtx_tracing=self.enable_layerwise_nvtx_tracing, enable_layerwise_nvtx_tracing=self.enable_layerwise_nvtx_tracing,
enable_mfu_metrics=self.enable_mfu_metrics,
) )
# Compilation config overrides # Compilation config overrides
......
...@@ -2,11 +2,13 @@ ...@@ -2,11 +2,13 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio import asyncio
import contextlib import contextlib
import copy
import json import json
import logging import logging
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Callable from collections.abc import Callable
from contextlib import AsyncExitStack from contextlib import AsyncExitStack
from dataclasses import replace
from typing import TYPE_CHECKING, Union from typing import TYPE_CHECKING, Union
from openai.types.responses.response_function_tool_call_output_item import ( from openai.types.responses.response_function_tool_call_output_item import (
...@@ -164,6 +166,12 @@ class SimpleContext(ConversationContext): ...@@ -164,6 +166,12 @@ class SimpleContext(ConversationContext):
def __init__(self): def __init__(self):
self.last_output = None self.last_output = None
# Accumulated final output for streaming mode
self._accumulated_text: str = ""
self._accumulated_token_ids: list[int] = []
self._accumulated_logprobs: list = []
self.num_prompt_tokens = 0 self.num_prompt_tokens = 0
self.num_output_tokens = 0 self.num_output_tokens = 0
self.num_cached_tokens = 0 self.num_cached_tokens = 0
...@@ -183,6 +191,13 @@ class SimpleContext(ConversationContext): ...@@ -183,6 +191,13 @@ class SimpleContext(ConversationContext):
self.num_cached_tokens = output.num_cached_tokens or 0 self.num_cached_tokens = output.num_cached_tokens or 0
self.num_output_tokens += len(output.outputs[0].token_ids or []) self.num_output_tokens += len(output.outputs[0].token_ids or [])
# Accumulate text, token_ids, and logprobs for streaming mode
delta_output = output.outputs[0]
self._accumulated_text += delta_output.text
self._accumulated_token_ids.extend(delta_output.token_ids)
if delta_output.logprobs is not None:
self._accumulated_logprobs.extend(delta_output.logprobs)
if len(self.input_messages) == 0: if len(self.input_messages) == 0:
output_prompt = output.prompt or "" output_prompt = output.prompt or ""
output_prompt_token_ids = output.prompt_token_ids or [] output_prompt_token_ids = output.prompt_token_ids or []
...@@ -194,11 +209,26 @@ class SimpleContext(ConversationContext): ...@@ -194,11 +209,26 @@ class SimpleContext(ConversationContext):
) )
self.output_messages.append( self.output_messages.append(
ResponseRawMessageAndToken( ResponseRawMessageAndToken(
message=output.outputs[0].text, message=delta_output.text,
tokens=output.outputs[0].token_ids, tokens=delta_output.token_ids,
) )
) )
@property
def final_output(self) -> RequestOutput | None:
"""Return the final output, with complete text/token_ids/logprobs."""
if self.last_output is not None and self.last_output.outputs:
assert isinstance(self.last_output, RequestOutput)
final_output = copy.copy(self.last_output)
# copy inner item to avoid modify last_output
final_output.outputs = [replace(item) for item in self.last_output.outputs]
final_output.outputs[0].text = self._accumulated_text
final_output.outputs[0].token_ids = tuple(self._accumulated_token_ids)
if self._accumulated_logprobs:
final_output.outputs[0].logprobs = self._accumulated_logprobs
return final_output
return self.last_output
def append_tool_output(self, output) -> None: def append_tool_output(self, output) -> None:
raise NotImplementedError("Should not be called.") raise NotImplementedError("Should not be called.")
...@@ -267,12 +297,40 @@ class ParsableContext(ConversationContext): ...@@ -267,12 +297,40 @@ class ParsableContext(ConversationContext):
self.chat_template = chat_template self.chat_template = chat_template
self.chat_template_content_format = chat_template_content_format self.chat_template_content_format = chat_template_content_format
self.input_messages: list[ResponseRawMessageAndToken] = []
self.output_messages: list[ResponseRawMessageAndToken] = []
def append_output(self, output: RequestOutput) -> None: def append_output(self, output: RequestOutput) -> None:
self.num_prompt_tokens = len(output.prompt_token_ids or []) self.num_prompt_tokens = len(output.prompt_token_ids or [])
self.num_cached_tokens = output.num_cached_tokens or 0 self.num_cached_tokens = output.num_cached_tokens or 0
self.num_output_tokens += len(output.outputs[0].token_ids or []) self.num_output_tokens += len(output.outputs[0].token_ids or [])
self.parser.process(output.outputs[0]) self.parser.process(output.outputs[0])
# only store if enable_response_messages is True, save memory
if self.request.enable_response_messages:
output_prompt = output.prompt or ""
output_prompt_token_ids = output.prompt_token_ids or []
if len(self.input_messages) == 0:
self.input_messages.append(
ResponseRawMessageAndToken(
message=output_prompt,
tokens=output_prompt_token_ids,
)
)
else:
self.output_messages.append(
ResponseRawMessageAndToken(
message=output_prompt,
tokens=output_prompt_token_ids,
)
)
self.output_messages.append(
ResponseRawMessageAndToken(
message=output.outputs[0].text,
tokens=output.outputs[0].token_ids,
)
)
def append_tool_output(self, output: list[ResponseInputOutputItem]) -> None: def append_tool_output(self, output: list[ResponseInputOutputItem]) -> None:
self.parser.response_messages.extend(output) self.parser.response_messages.extend(output)
......
...@@ -18,6 +18,7 @@ from vllm.beam_search import ( ...@@ -18,6 +18,7 @@ from vllm.beam_search import (
create_sort_beams_key_function, create_sort_beams_key_function,
) )
from vllm.config import ( from vllm.config import (
AttentionConfig,
CompilationConfig, CompilationConfig,
PoolerConfig, PoolerConfig,
ProfilerConfig, ProfilerConfig,
...@@ -175,6 +176,10 @@ class LLM: ...@@ -175,6 +176,10 @@ class LLM:
compilation_config: Either an integer or a dictionary. If it is an compilation_config: Either an integer or a dictionary. If it is an
integer, it is used as the mode of compilation optimization. If it integer, it is used as the mode of compilation optimization. If it
is a dictionary, it can specify the full compilation configuration. is a dictionary, it can specify the full compilation configuration.
attention_config: Configuration for attention mechanisms. Can be a
dictionary or an AttentionConfig instance. If a dictionary, it will
be converted to an AttentionConfig. Allows specifying the attention
backend and other attention-related settings.
**kwargs: Arguments for [`EngineArgs`][vllm.EngineArgs]. **kwargs: Arguments for [`EngineArgs`][vllm.EngineArgs].
Note: Note:
...@@ -213,6 +218,7 @@ class LLM: ...@@ -213,6 +218,7 @@ class LLM:
| StructuredOutputsConfig | StructuredOutputsConfig
| None = None, | None = None,
profiler_config: dict[str, Any] | ProfilerConfig | None = None, profiler_config: dict[str, Any] | ProfilerConfig | None = None,
attention_config: dict[str, Any] | AttentionConfig | None = None,
kv_cache_memory_bytes: int | None = None, kv_cache_memory_bytes: int | None = None,
compilation_config: int | dict[str, Any] | CompilationConfig | None = None, compilation_config: int | dict[str, Any] | CompilationConfig | None = None,
logits_processors: list[str | type[LogitsProcessor]] | None = None, logits_processors: list[str | type[LogitsProcessor]] | None = None,
...@@ -252,51 +258,28 @@ class LLM: ...@@ -252,51 +258,28 @@ class LLM:
if hf_overrides is None: if hf_overrides is None:
hf_overrides = {} hf_overrides = {}
if compilation_config is not None: def _make_config(value: Any, cls: type[_R]) -> _R:
if isinstance(compilation_config, int): """Convert dict/None/instance to a config instance."""
compilation_config_instance = CompilationConfig( if value is None:
mode=CompilationMode(compilation_config) return cls()
) if isinstance(value, dict):
elif isinstance(compilation_config, dict): return cls(**{k: v for k, v in value.items() if is_init_field(cls, k)}) # type: ignore[arg-type]
compilation_config_instance = CompilationConfig( return value
**{
k: v if isinstance(compilation_config, int):
for k, v in compilation_config.items() compilation_config_instance = CompilationConfig(
if is_init_field(CompilationConfig, k) mode=CompilationMode(compilation_config)
} )
)
else:
compilation_config_instance = compilation_config
else:
compilation_config_instance = CompilationConfig()
if structured_outputs_config is not None:
if isinstance(structured_outputs_config, dict):
structured_outputs_instance = StructuredOutputsConfig(
**{
k: v
for k, v in structured_outputs_config.items()
if is_init_field(StructuredOutputsConfig, k)
}
)
else:
structured_outputs_instance = structured_outputs_config
else:
structured_outputs_instance = StructuredOutputsConfig()
if profiler_config is not None:
if isinstance(profiler_config, dict):
profiler_config_instance = ProfilerConfig(
**{
k: v
for k, v in profiler_config.items()
if is_init_field(ProfilerConfig, k)
}
)
else:
profiler_config_instance = profiler_config
else: else:
profiler_config_instance = ProfilerConfig() compilation_config_instance = _make_config(
compilation_config, CompilationConfig
)
structured_outputs_instance = _make_config(
structured_outputs_config, StructuredOutputsConfig
)
profiler_config_instance = _make_config(profiler_config, ProfilerConfig)
attention_config_instance = _make_config(attention_config, AttentionConfig)
# warn about single-process data parallel usage. # warn about single-process data parallel usage.
_dp_size = int(kwargs.get("data_parallel_size", 1)) _dp_size = int(kwargs.get("data_parallel_size", 1))
...@@ -341,6 +324,7 @@ class LLM: ...@@ -341,6 +324,7 @@ class LLM:
pooler_config=pooler_config, pooler_config=pooler_config,
structured_outputs_config=structured_outputs_instance, structured_outputs_config=structured_outputs_instance,
profiler_config=profiler_config_instance, profiler_config=profiler_config_instance,
attention_config=attention_config_instance,
compilation_config=compilation_config_instance, compilation_config=compilation_config_instance,
logits_processors=logits_processors, logits_processors=logits_processors,
**kwargs, **kwargs,
......
...@@ -17,21 +17,20 @@ from argparse import Namespace ...@@ -17,21 +17,20 @@ from argparse import Namespace
from collections.abc import AsyncGenerator, AsyncIterator, Awaitable from collections.abc import AsyncGenerator, AsyncIterator, Awaitable
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from http import HTTPStatus from http import HTTPStatus
from typing import Annotated, Any, Literal from typing import Annotated, Any
import model_hosting_container_standards.sagemaker as sagemaker_standards import model_hosting_container_standards.sagemaker as sagemaker_standards
import pydantic import pydantic
import uvloop import uvloop
from fastapi import APIRouter, Depends, FastAPI, Form, HTTPException, Query, Request from fastapi import APIRouter, Depends, FastAPI, Form, HTTPException, Request
from fastapi.exceptions import RequestValidationError from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, Response, StreamingResponse from fastapi.responses import JSONResponse, StreamingResponse
from starlette.concurrency import iterate_in_threadpool from starlette.concurrency import iterate_in_threadpool
from starlette.datastructures import URL, Headers, MutableHeaders, State from starlette.datastructures import URL, Headers, MutableHeaders, State
from starlette.types import ASGIApp, Message, Receive, Scope, Send from starlette.types import ASGIApp, Message, Receive, Scope, Send
import vllm.envs as envs import vllm.envs as envs
from vllm.config import VllmConfig
from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.protocol import EngineClient from vllm.engine.protocol import EngineClient
from vllm.entrypoints.anthropic.protocol import ( from vllm.entrypoints.anthropic.protocol import (
...@@ -639,97 +638,6 @@ async def create_translations( ...@@ -639,97 +638,6 @@ async def create_translations(
return StreamingResponse(content=generator, media_type="text/event-stream") return StreamingResponse(content=generator, media_type="text/event-stream")
if envs.VLLM_SERVER_DEV_MODE:
    # Everything in this branch is registered on the router only when the
    # VLLM_SERVER_DEV_MODE flag is set; the warning below is logged once at
    # import time to make that visible in the server logs.
    logger.warning(
        "SECURITY WARNING: Development endpoints are enabled! "
        "This should NOT be used in production!"
    )

    # Built once and reused by /server_info to serialize the config object.
    PydanticVllmConfig = pydantic.TypeAdapter(VllmConfig)

    @router.get("/server_info")
    async def show_server_info(
        raw_request: Request,
        config_format: Annotated[Literal["text", "json"], Query()] = "text",
    ):
        """
        Return the server's `VllmConfig` under the `vllm_config` key.

        With `config_format=text` (the default) the config is rendered with
        `str()`; with `config_format=json` it is dumped through the pydantic
        type adapter in JSON mode.
        """
        vllm_config: VllmConfig = raw_request.app.state.vllm_config
        server_info = {
            "vllm_config": str(vllm_config)
            if config_format == "text"
            else PydanticVllmConfig.dump_python(vllm_config, mode="json", fallback=str)
            # fallback=str is needed to handle e.g. torch.dtype
        }
        return JSONResponse(content=server_info)

    @router.post("/reset_prefix_cache")
    async def reset_prefix_cache(
        raw_request: Request,
        reset_running_requests: bool = Query(default=False),
        reset_external: bool = Query(default=False),
    ):
        """
        Reset the local prefix cache.

        Optionally, if the query parameter `reset_external=true`
        also resets the external (connector-managed) prefix cache.

        Note that we currently do not check if the prefix cache
        is successfully reset in the API server.

        Example:
            POST /reset_prefix_cache?reset_external=true
        """
        logger.info("Resetting prefix cache...")
        await engine_client(raw_request).reset_prefix_cache(
            reset_running_requests, reset_external
        )
        return Response(status_code=200)

    @router.post("/reset_mm_cache")
    async def reset_mm_cache(raw_request: Request):
        """
        Reset the multi-modal cache. Note that we currently do not check if the
        multi-modal cache is successfully reset in the API server.
        """
        logger.info("Resetting multi-modal cache...")
        await engine_client(raw_request).reset_mm_cache()
        return Response(status_code=200)

    @router.post("/collective_rpc")
    async def collective_rpc(raw_request: Request):
        """
        Forward a `collective_rpc` call to the engine client.

        The request body must be a JSON object with a required `method` key
        and optional `args` (list), `kwargs` (object) and `timeout` keys.

        Returns an empty 200 response when the engine returns `None`;
        otherwise a JSON object `{"results": [...]}` in which every result
        that is not `None`, a dict or a list is converted with `str()`.

        Raises:
            HTTPException: 400 if the body is not valid JSON, is not a JSON
                object, or does not contain `method`.
        """
        try:
            body = await raw_request.json()
        except json.JSONDecodeError as e:
            raise HTTPException(
                status_code=HTTPStatus.BAD_REQUEST.value,
                detail=f"JSON decode error: {e}",
            ) from e
        # Valid JSON is not necessarily an object: `[]`, `"x"` or `1` would
        # otherwise fail on `.get` below and surface as a 500.
        if not isinstance(body, dict):
            raise HTTPException(
                status_code=HTTPStatus.BAD_REQUEST.value,
                detail="Request body must be a JSON object",
            )
        method = body.get("method")
        if method is None:
            raise HTTPException(
                status_code=HTTPStatus.BAD_REQUEST.value,
                detail="Missing 'method' in request body",
            )
        # For security reason, only serialized string args/kwargs are passed.
        # User-defined `method` is responsible for deserialization if needed.
        args: list[str] = body.get("args", [])
        kwargs: dict[str, str] = body.get("kwargs", {})
        timeout: float | None = body.get("timeout")
        results = await engine_client(raw_request).collective_rpc(
            method=method, timeout=timeout, args=tuple(args), kwargs=kwargs
        )
        if results is None:
            return Response(status_code=200)
        response: list[Any] = []
        for result in results:
            # dicts/lists/None are handed to JSONResponse as-is; anything
            # else is stringified first.
            if result is None or isinstance(result, dict | list):
                response.append(result)
            else:
                response.append(str(result))
        return JSONResponse(content={"results": response})
def load_log_config(log_config_file: str | None) -> dict | None: def load_log_config(log_config_file: str | None) -> dict | None:
if not log_config_file: if not log_config_file:
return None return None
...@@ -1174,6 +1082,9 @@ async def init_app_state( ...@@ -1174,6 +1082,9 @@ async def init_app_state(
if "generate" in supported_tasks if "generate" in supported_tasks
else None else None
) )
# Warm up chat template processing to avoid first-request latency
if state.openai_serving_chat is not None:
await state.openai_serving_chat.warmup()
state.openai_serving_completion = ( state.openai_serving_completion = (
OpenAIServingCompletion( OpenAIServingCompletion(
engine_client, engine_client,
......
...@@ -3,7 +3,11 @@ ...@@ -3,7 +3,11 @@
import logging import logging
from collections.abc import Callable from collections.abc import Callable
from openai.types.responses.response_function_tool_call import ResponseFunctionToolCall from openai.types.responses import ResponseFunctionToolCall, ResponseOutputItem
from openai.types.responses.response_function_tool_call_output_item import (
ResponseFunctionToolCallOutputItem,
)
from openai.types.responses.response_output_item import McpCall
from openai.types.responses.response_output_message import ResponseOutputMessage from openai.types.responses.response_output_message import ResponseOutputMessage
from openai.types.responses.response_output_text import ResponseOutputText from openai.types.responses.response_output_text import ResponseOutputText
from openai.types.responses.response_reasoning_item import ( from openai.types.responses.response_reasoning_item import (
...@@ -11,6 +15,7 @@ from openai.types.responses.response_reasoning_item import ( ...@@ -11,6 +15,7 @@ from openai.types.responses.response_reasoning_item import (
ResponseReasoningItem, ResponseReasoningItem,
) )
from vllm.entrypoints.constants import MCP_PREFIX
from vllm.entrypoints.openai.protocol import ResponseInputOutputItem, ResponsesRequest from vllm.entrypoints.openai.protocol import ResponseInputOutputItem, ResponsesRequest
from vllm.outputs import CompletionOutput from vllm.outputs import CompletionOutput
from vllm.reasoning.abs_reasoning_parsers import ReasoningParser from vllm.reasoning.abs_reasoning_parsers import ReasoningParser
...@@ -111,6 +116,37 @@ class ResponsesParser: ...@@ -111,6 +116,37 @@ class ResponsesParser:
return self return self
def make_response_output_items_from_parsable_context(
    self,
) -> list[ResponseOutputItem]:
    """Build the output items from the messages recorded after the initial ones.

    Only `self.response_messages[self.num_init_messages:]` is considered.
    Every message that is not a `ResponseFunctionToolCallOutputItem` is
    passed through unchanged. A tool-call output item is not emitted on its
    own: if the most recently emitted item is a `ResponseFunctionToolCall`,
    that item is replaced by a completed `McpCall` carrying the call's
    name/arguments and the tool's output; otherwise the output item is
    skipped.

    Raises:
        ValueError: if a tool-call output item appears before any other
            item has been emitted.
    """
    produced = self.response_messages[self.num_init_messages :]
    items: list[ResponseOutputItem] = []
    for msg in produced:
        if not isinstance(msg, ResponseFunctionToolCallOutputItem):
            items.append(msg)
            continue

        # From here on `msg` is a tool-call output item.
        if not items:
            raise ValueError(
                "Cannot have a FunctionToolCallOutput before FunctionToolCall."
            )

        call = items[-1]
        if not isinstance(call, ResponseFunctionToolCall):
            # Previous item is not a function call: nothing to merge into.
            continue

        # Fold the call and its output into a single MCP call item.
        items[-1] = McpCall(
            id=f"{MCP_PREFIX}{random_uuid()}",
            arguments=call.arguments,
            name=call.name,
            server_label=call.name,  # TODO: store the server label
            type="mcp_call",
            status="completed",
            output=msg.output,
            # TODO: support error output
        )
    return items
def get_responses_parser_for_simple_context( def get_responses_parser_for_simple_context(
*, *,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment