Commit a810671a authored by zhuwenwen
Browse files

Merge tag 'v0.14.0rc0' into v0.14.0rc0-ori

parents 86b5aefe 6a09612b
......@@ -170,8 +170,7 @@ class PiecewiseBackend:
range_entry = self._find_range_for_shape(runtime_shape)
assert range_entry is not None, (
f"Shape out of considered range: {runtime_shape} "
"[1, max_num_batched_tokens]"
f"Shape: {runtime_shape} out of considered ranges: {self.compile_ranges}"
)
self._maybe_compile_for_range_entry(range_entry, args)
......
......@@ -437,14 +437,14 @@ class CompilationConfig:
compile_ranges_split_points: list[int] | None = None
"""Split points that represent compile ranges for inductor.
The compile ranges are
[1, split_points[0]],
[split_points[0] + 1, split_points[1]], ...,
The compile ranges are
[1, split_points[0]],
[split_points[0] + 1, split_points[1]], ...,
[split_points[-1] + 1, max_num_batched_tokens].
Compile sizes are also used as single-element ranges,
the range is represented as [compile_sizes[i], compile_sizes[i]].
If a range overlaps with the compile size, graph for compile size
If a range overlaps with the compile size, graph for compile size
will be prioritized, i.e. if we have a range [1, 8] and a compile size 4,
graph for compile size 4 will be compiled and used instead of the graph
for range [1, 8].
......@@ -899,7 +899,7 @@ class CompilationConfig:
self.compute_bs_to_padded_graph_size()
def set_splitting_ops_for_v1(
self, all2all_backend: str | None = None, data_parallel_size: int | None = None
self, all2all_backend: str, data_parallel_size: int = 1
):
# To be compatible with OOT hardware plugin platform (for example vllm-ascend)
# which currently only supports sequence parallelism in eager mode.
......@@ -934,7 +934,7 @@ class CompilationConfig:
or self.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE
):
logger.warning_once(
"Using piecewise compilation with empty splitting_ops"
"Using piecewise cudagraph with empty splitting_ops"
)
if self.cudagraph_mode == CUDAGraphMode.PIECEWISE:
logger.warning_once(
......@@ -956,11 +956,9 @@ class CompilationConfig:
self.splitting_ops = []
# Disable CUDA graphs for DeepEP high-throughput since it's not CG compatible
backend = all2all_backend or envs.VLLM_ALL2ALL_BACKEND
dp_size = data_parallel_size if data_parallel_size is not None else 1
if (
backend == "deepep_high_throughput"
and dp_size > 1
all2all_backend == "deepep_high_throughput"
and data_parallel_size > 1
and self.cudagraph_mode != CUDAGraphMode.NONE
):
# TODO: Piecewise Cuda graph might be enabled
......
......@@ -64,6 +64,9 @@ class ObservabilityConfig:
module in the model and attach information such as input/output shapes to
nvtx range markers. Note that this doesn't work with CUDA graphs enabled."""
enable_mfu_metrics: bool = False
"""Enable Model FLOPs Utilization (MFU) metrics."""
@cached_property
def collect_model_forward_time(self) -> bool:
"""Whether to collect model forward time for the request."""
......
......@@ -36,6 +36,14 @@ ExpertPlacementStrategy = Literal["linear", "round_robin"]
DistributedExecutorBackend = Literal["ray", "mp", "uni", "external_launcher"]
DataParallelBackend = Literal["ray", "mp"]
EPLBPolicyOption = Literal["default"]
# Closed set of backend names accepted for the MoE expert-parallel all2all
# communication setting (used as the type of `all2all_backend`).
All2AllBackend = Literal[
"naive",
"pplx",
"deepep_high_throughput",
"deepep_low_latency",
"allgather_reducescatter",
"flashinfer_all2allv",
]
@config
......@@ -126,24 +134,14 @@ class ParallelConfig:
with 4 experts and 2 ranks, rank 0 will have experts [0, 2] and rank 1
will have experts [1, 3]. This strategy can help improve load balancing
for grouped expert models with no redundant experts."""
all2all_backend: (
Literal[
"naive",
"pplx",
"deepep_high_throughput",
"deepep_low_latency",
"allgather_reducescatter",
"flashinfer_all2allv",
]
| None
) = None
"""All2All backend for MoE expert parallel communication. If not set, uses
the value from VLLM_ALL2ALL_BACKEND environment variable. Available options:
- "naive": Naive all2all implementation using broadcasts
- "allgather_reducescatter": All2all based on allgather and reducescatter
- "pplx": Use pplx kernels
- "deepep_high_throughput": Use deepep high-throughput kernels
- "deepep_low_latency": Use deepep low-latency kernels
all2all_backend: All2AllBackend = "allgather_reducescatter"
"""All2All backend for MoE expert parallel communication. Available options:
- "naive": Naive all2all implementation using broadcasts\n
- "allgather_reducescatter": All2all based on allgather and reducescatter\n
- "pplx": Use pplx kernels\n
- "deepep_high_throughput": Use deepep high-throughput kernels\n
- "deepep_low_latency": Use deepep low-latency kernels\n
- "flashinfer_all2allv": Use flashinfer alltoallv kernels for mnnvl"""
max_parallel_loading_workers: int | None = None
......@@ -156,6 +154,8 @@ class ParallelConfig:
enable_dbo: bool = False
"""Enable dual batch overlap for the model executor."""
ubatch_size: int = 0
"""Number of ubatches (micro-batches)."""
dbo_decode_token_threshold: int = 32
"""The threshold for dual batch overlap for batches only containing decodes.
......@@ -325,6 +325,14 @@ class ParallelConfig:
including data parallelism."""
return self.world_size * self.data_parallel_size
@property
def use_ubatching(self) -> bool:
    """Whether micro-batching is active.

    True when dual batch overlap is enabled, or when `ubatch_size` is
    greater than 1.
    """
    dbo_enabled = self.enable_dbo
    if dbo_enabled:
        return dbo_enabled
    return self.ubatch_size > 1
@property
def num_ubatches(self) -> int:
    """Number of micro-batches in use.

    Dual batch overlap always uses exactly two; otherwise the configured
    `ubatch_size` is returned unchanged.
    """
    if self.enable_dbo:
        return 2
    return self.ubatch_size
def get_next_dp_init_port(self) -> int:
"""
We might need to initialize process groups in multiple
......@@ -485,20 +493,17 @@ class ParallelConfig:
from vllm.config.utils import get_hash_factors, hash_factors
factors = get_hash_factors(self, ignored_factors)
# Explicitly include backend affecting env factor as before
factors["VLLM_ALL2ALL_BACKEND"] = str(envs.VLLM_ALL2ALL_BACKEND)
return hash_factors(factors)
def __post_init__(self) -> None:
# Set all2all_backend from env var if not specified, with deprecation warning
if self.all2all_backend is None:
if envs.is_set("VLLM_ALL2ALL_BACKEND"):
logger.warning_once(
"VLLM_ALL2ALL_BACKEND environment variable is deprecated and "
"will be removed in v0.15.0. Please use the "
"--all2all-backend command-line argument instead."
)
self.all2all_backend = envs.VLLM_ALL2ALL_BACKEND
if envs.is_set("VLLM_ALL2ALL_BACKEND"):
logger.warning_once(
"VLLM_ALL2ALL_BACKEND environment variable is deprecated and "
"will be removed in a future release. Please use the "
"--all2all-backend command-line argument instead."
)
# Continue with the rest of the initialization
self.world_size = (
......
......@@ -870,9 +870,12 @@ class VllmConfig:
f"cudagraph_mode={self.compilation_config.cudagraph_mode}"
)
if self.parallel_config.enable_dbo:
if self.parallel_config.use_ubatching:
a2a_backend = self.parallel_config.all2all_backend
assert a2a_backend in ["deepep_low_latency", "deepep_high_throughput"], (
assert a2a_backend in [
"deepep_low_latency",
"deepep_high_throughput",
], (
"Microbatching currently only supports the deepep_low_latency and "
f"deepep_high_throughput all2all backend. {a2a_backend} is not "
"supported. To fix use --all2all-backend=deepep_low_latency or "
......
......@@ -64,7 +64,12 @@ class NaiveAll2AllManager(All2AllManagerBase):
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
if extra_tensors is not None:
raise NotImplementedError(
"extra_tensors is not supported for NaiveAll2AllManager"
)
sp_size = self.tp_group.world_size if is_sequence_parallel else 1
dp_metadata = get_forward_context().dp_metadata
assert dp_metadata is not None
......@@ -76,6 +81,7 @@ class NaiveAll2AllManager(All2AllManagerBase):
router_logits = self.naive_multicast(
router_logits, cu_tokens_across_sp_cpu, is_sequence_parallel
)
return hidden_states, router_logits
def combine(
......@@ -113,7 +119,11 @@ class AgRsAll2AllManager(All2AllManagerBase):
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
is_sequence_parallel: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
extra_tensors: list[torch.Tensor] | None = None,
) -> (
tuple[torch.Tensor, torch.Tensor]
| tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
"""
Gather hidden_states and router_logits from all dp ranks.
"""
......@@ -121,15 +131,22 @@ class AgRsAll2AllManager(All2AllManagerBase):
assert dp_metadata is not None
sizes = dp_metadata.get_chunk_sizes_across_dp_rank()
assert sizes is not None
dist_group = get_ep_group() if is_sequence_parallel else get_dp_group()
assert sizes[dist_group.rank_in_group] == hidden_states.shape[0]
hidden_states, router_logits = dist_group.all_gatherv(
[hidden_states, router_logits],
tensors_to_gather = [hidden_states, router_logits]
if extra_tensors is not None:
tensors_to_gather.extend(extra_tensors)
gathered_tensors = dist_group.all_gatherv(
tensors_to_gather,
dim=0,
sizes=sizes,
)
return hidden_states, router_logits
if extra_tensors is not None:
return (gathered_tensors[0], gathered_tensors[1], gathered_tensors[2:])
return gathered_tensors[0], gathered_tensors[1]
def combine(
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
......@@ -204,6 +221,7 @@ class PPLXAll2AllManager(All2AllManagerBase):
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
raise NotImplementedError
......@@ -251,6 +269,7 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase):
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
raise NotImplementedError
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import threading
from typing import Any
from weakref import WeakValueDictionary
import torch
......@@ -68,7 +69,11 @@ class All2AllManagerBase:
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
is_sequence_parallel: bool = False,
):
extra_tensors: list[torch.Tensor] | None = None,
) -> Any:
# Subclasses should either:
# - implement handling for extra_tensors, or
# - raise a clear error if extra_tensors is not supported.
raise NotImplementedError
def set_num_sms(self, num_sms: int):
......
......@@ -318,17 +318,23 @@ class CudaCommunicator(DeviceCommunicatorBase):
return output_list
def dispatch(
def dispatch( # type: ignore[override]
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
is_sequence_parallel: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
extra_tensors: list[torch.Tensor] | None = None,
) -> (
tuple[torch.Tensor, torch.Tensor]
| tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
assert self.all2all_manager is not None
hidden_states, router_logits = self.all2all_manager.dispatch(
hidden_states, router_logits, is_sequence_parallel
return self.all2all_manager.dispatch(
hidden_states,
router_logits,
is_sequence_parallel,
extra_tensors, # type: ignore[call-arg]
)
return hidden_states, router_logits
def combine(
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
......
......@@ -73,6 +73,7 @@ class ECExampleConnector(ECConnectorBase):
data hashes (`mm_hash`) to encoder cache tensors.
kwargs (dict): Additional keyword arguments for the connector.
"""
from vllm.platforms import current_platform
# Get the metadata
metadata: ECConnectorMetadata = self._get_connector_metadata()
......@@ -91,7 +92,9 @@ class ECExampleConnector(ECConnectorBase):
if mm_data.mm_hash in encoder_cache:
continue
filename = self._generate_filename_debug(mm_data.mm_hash)
ec_cache = safetensors.torch.load_file(filename)["ec_cache"].cuda()
ec_cache = safetensors.torch.load_file(
filename, device=current_platform.device_type
)["ec_cache"]
encoder_cache[mm_data.mm_hash] = ec_cache
logger.debug("Success load encoder cache for hash %s", mm_data.mm_hash)
......
......@@ -4,6 +4,7 @@
KV cache helper for store.
"""
from collections.abc import Iterator
from dataclasses import dataclass
from typing import TYPE_CHECKING, Literal
......@@ -21,6 +22,8 @@ if TYPE_CHECKING:
logger = init_logger(__name__)
EngineId = str
def get_kv_connector_cache_layout():
# NOTE (NickLucche) When running disaggregated PD with NIXL, HND layout is
......@@ -201,6 +204,26 @@ def copy_kv_blocks(
copy_fn(src_tensor, dst_tensor, src_indices, dst_indices)
def yield_req_data(
scheduler_output,
) -> Iterator[tuple[str, tuple[list[int], ...], bool]]:
"""
Yields:
(req_id, new_block_id_groups, preempted)
"""
# new requests
for req_data in scheduler_output.scheduled_new_reqs:
yield req_data.req_id, req_data.block_ids, False
# cached requests
cached_reqs = scheduler_output.scheduled_cached_reqs
yield from zip(
cached_reqs.req_ids,
cached_reqs.new_block_ids,
(req_id in cached_reqs.resumed_req_ids for req_id in cached_reqs.req_ids),
)
@dataclass
class TpKVTopology:
"""
......@@ -209,12 +232,12 @@ class TpKVTopology:
"""
tp_rank: int
remote_tp_size: dict[str, int]
remote_tp_size: dict[EngineId, int]
is_mla: bool
total_num_kv_heads: int
attn_backend: type[AttentionBackend]
engine_id: str
remote_block_size: dict[str, int]
engine_id: EngineId
remote_block_size: dict[EngineId, int]
def __post_init__(self):
# Figure out whether the first dimension of the cache is K/V
......@@ -256,18 +279,28 @@ class TpKVTopology:
Calculate the tensor parallel ratio between local and remote TP.
We can think of it as the number of local TP workers-per-remote TP
workers. Local workers will read from the same remote TP worker in
groups of size `tp_ratio`.
groups of size `tp_ratio`. If remote tp_size > local tp_size, the
ratio is flipped (remote_size/local_size) and the returned value is
negative.
"""
assert self.tp_size % remote_tp_size == 0, (
f"Local tensor parallel size {self.tp_size} is not divisible "
f"by remote tensor parallel size {remote_tp_size}."
if self.tp_size >= remote_tp_size:
assert self.tp_size % remote_tp_size == 0, (
f"Local tensor parallel size {self.tp_size} is not divisible "
f"by remote tensor parallel size {remote_tp_size}."
)
return self.tp_size // remote_tp_size
assert remote_tp_size % self.tp_size == 0, (
f"Remote tensor parallel size {remote_tp_size} is not divisible "
f"by local tensor parallel size {self.tp_size}."
)
return self.tp_size // remote_tp_size
# P TP > D TP case, return the ratio as negative
return -remote_tp_size // self.tp_size
def block_size_ratio(
self,
remote_block_size: int,
) -> float:
) -> int:
"""
Calculate the block size ratio between local and remote TP.
"""
......@@ -279,19 +312,19 @@ class TpKVTopology:
def tp_ratio_from_engine_id(
self,
remote_engine_id: str,
remote_engine_id: EngineId,
) -> int:
remote_tp_size = self.remote_tp_size[remote_engine_id]
return self.tp_ratio(remote_tp_size)
def block_size_ratio_from_engine_id(
self,
remote_engine_id: str,
) -> float:
remote_engine_id: EngineId,
) -> int:
remote_block_size = self.remote_block_size[remote_engine_id]
return self.block_size_ratio(remote_block_size)
def is_kv_replicated(self, engine_id: str) -> bool:
def is_kv_replicated(self, engine_id: EngineId) -> bool:
"""
Whether the KV cache is replicated across TP workers due to the
number of TP workers being greater than the number of KV heads.
......@@ -299,24 +332,30 @@ class TpKVTopology:
tp_size = self.remote_tp_size[engine_id]
return tp_size // self.total_num_kv_heads >= 1
def replicates_kv_cache(self, remote_engine_id: str) -> bool:
def replicates_kv_cache(self, remote_engine_id: EngineId) -> bool:
# MLA is always replicated as the hidden dim can't be split.
return self.is_mla or self.is_kv_replicated(remote_engine_id)
def get_target_remote_rank(
def get_target_remote_ranks(
self,
remote_tp_size: int,
) -> int:
) -> list[int]:
"""
Get the remote TP rank (on P) that the current local TP rank
(on D) will read from.
(on D) will read from. When remote tp_size > local tp_size, we
read from multiple remote ranks.
"""
tp_ratio = self.tp_ratio(remote_tp_size)
return self.tp_rank // tp_ratio
if tp_ratio > 0:
return [self.tp_rank // tp_ratio]
# P TP > D TP case, D reads from |tp_ratio| remote workers.
tp_ratio = -tp_ratio
return [self.tp_rank * tp_ratio + i for i in range(tp_ratio)]
def get_target_remote_rank_from_engine_id(
def get_target_remote_ranks_from_engine_id(
self,
remote_engine_id: str,
) -> int:
remote_engine_id: EngineId,
) -> list[int]:
remote_tp_size = self.remote_tp_size[remote_engine_id]
return self.get_target_remote_rank(remote_tp_size)
return self.get_target_remote_ranks(remote_tp_size)
......@@ -147,6 +147,14 @@ class LMCacheMPSchedulerAdapter:
"""
return self.blocks_in_chunk
def _cleanup_lookup_result(self, request_id: str) -> None:
    """
    Drop the stored lookup future of a finished request so that it does
    not accumulate in `lookup_futures`. Unknown ids are ignored.

    Args:
        request_id: The ID of the finished request.
    """
    pending_lookups = self.lookup_futures
    pending_lookups.pop(request_id, None)
# Helper functions
def _create_key(self, block_hash: bytes) -> IPCCacheEngineKey:
"""Convert a block hash to an IPC cache engine key"""
......@@ -262,6 +270,7 @@ class LMCacheMPWorkerAdapter:
):
keys = []
block_ids = []
for op in ops:
keys.extend(self._block_hashes_to_keys(op.block_hashes))
block_ids.extend(op.block_ids)
......
......@@ -24,6 +24,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.lmcache_integration import (
)
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.outputs import KVConnectorOutput
from vllm.v1.request import RequestStatus
from vllm.v1.utils import ConstantList
if TYPE_CHECKING:
......@@ -211,7 +212,7 @@ class LMCacheMPRequestTracker:
"""
self.num_stored_blocks += num_new_blocks
def update_block_ids(
def append_block_ids(
self,
new_block_ids: list[int],
):
......@@ -455,10 +456,6 @@ class LMCacheMPConnector(KVConnectorBase_V1):
metadata = self._get_connector_metadata()
assert isinstance(metadata, LMCacheMPConnectorMetadata)
with torch.cuda.stream(torch.cuda.current_stream()):
event = torch.cuda.Event(interprocess=True)
event.record()
request_ids = []
ops = []
......@@ -468,10 +465,14 @@ class LMCacheMPConnector(KVConnectorBase_V1):
request_ids.append(meta.request_id)
ops.append(meta.op)
if len(request_ids) > 0:
self.worker_adapter.batched_submit_retrieve_requests(
request_ids, ops, event
)
if len(request_ids) == 0:
return
with torch.cuda.stream(torch.cuda.current_stream()):
event = torch.cuda.Event(interprocess=True)
event.record()
self.worker_adapter.batched_submit_retrieve_requests(request_ids, ops, event)
def wait_for_layer_load(self, layer_name: str) -> None:
"""
......@@ -518,10 +519,6 @@ class LMCacheMPConnector(KVConnectorBase_V1):
metadata = self._get_connector_metadata()
assert isinstance(metadata, LMCacheMPConnectorMetadata)
with torch.cuda.stream(torch.cuda.current_stream()):
event = torch.cuda.Event(interprocess=True)
event.record()
request_ids = []
ops = []
for meta in metadata.requests:
......@@ -530,8 +527,14 @@ class LMCacheMPConnector(KVConnectorBase_V1):
request_ids.append(meta.request_id)
ops.append(meta.op)
if len(request_ids) > 0:
self.worker_adapter.batched_submit_store_requests(request_ids, ops, event)
if len(request_ids) == 0:
return
with torch.cuda.stream(torch.cuda.current_stream()):
event = torch.cuda.Event(interprocess=True)
event.record()
self.worker_adapter.batched_submit_store_requests(request_ids, ops, event)
def get_finished(
self, finished_req_ids: set[str]
......@@ -627,6 +630,9 @@ class LMCacheMPConnector(KVConnectorBase_V1):
into account.
"""
tracker = self._get_or_create_request_tracker(request)
# TODO: support loading KV for preempted requests in the future
if request.status == RequestStatus.PREEMPTED:
return 0, False
self.scheduler_adapter.maybe_submit_lookup_request(
request.request_id, convert_block_hashes_to_bytes(request.block_hashes)
......@@ -683,7 +689,7 @@ class LMCacheMPConnector(KVConnectorBase_V1):
# Whether or not we need to retrieve, we need to update
# the block ids into the tracker
tracker.update_block_ids(block_ids)
tracker.append_block_ids(block_ids)
# Update the state of the tracker
condition = tracker.needs_retrieve()
......@@ -695,6 +701,8 @@ class LMCacheMPConnector(KVConnectorBase_V1):
if condition
else LMCacheMPRequestState.READY
)
# Clean up lookup future in scheduler adapter
self.scheduler_adapter._cleanup_lookup_result(request.request_id)
def build_connector_meta(
self, scheduler_output: SchedulerOutput
......@@ -748,6 +756,8 @@ class LMCacheMPConnector(KVConnectorBase_V1):
Optional KVTransferParams to be included in the request outputs
returned by the engine.
"""
# Clean up request tracker to prevent memory leak
self._cleanup_request_tracker(request.request_id)
return True, None
def take_events(self) -> Iterable["KVCacheEvent"]:
......@@ -866,7 +876,8 @@ class LMCacheMPConnector(KVConnectorBase_V1):
# Update block ids
new_block_ids = reformat_block_ids(cached_reqs.new_block_ids[idx])
request_tracker.update_block_ids(new_block_ids)
if request_id not in cached_reqs.resumed_req_ids:
request_tracker.append_block_ids(new_block_ids)
# Update new scheduled tokens
num_new_tokens = cached_reqs.num_computed_tokens[idx]
......@@ -889,7 +900,34 @@ class LMCacheMPConnector(KVConnectorBase_V1):
self, request: "Request"
) -> LMCacheMPRequestTracker:
request_id = request.request_id
# Remove the old tracker that was created before the preemption
if (
request.status == RequestStatus.PREEMPTED
and request_id in self.request_trackers
):
tracker = self.request_trackers[request_id]
# NOTE: since this function may be called multiple times for the
# same request (because get_num_new_matched_tokens may be called
# multiple times), we only remove the tracker if it is not in
# the "fresh" state, i.e., PREFETCHING
if tracker.state != LMCacheMPRequestState.PREFETCHING:
self.request_trackers.pop(request_id)
if request_id not in self.request_trackers:
new_tracker = LMCacheMPRequestTracker(request)
self.request_trackers[request_id] = new_tracker
return self.request_trackers[request_id]
def _cleanup_request_tracker(self, request_id: str) -> None:
    """
    Remove the request tracker stored for `request_id`, if any.

    This should be called when a request is finished to prevent memory leak.
    Only `request_trackers` is touched here; unknown ids are ignored.

    Args:
        request_id: The ID of the request whose tracker should be dropped.
    """
    removed_tracker = self.request_trackers.pop(request_id, None)
    # Compare against None rather than relying on the tracker's truthiness,
    # so the debug log reflects whether an entry was actually removed.
    if removed_tracker is not None:
        logger.debug(
            "[KVConnector] Cleaned up request_tracker for request %s",
            request_id,
        )
......@@ -23,7 +23,11 @@ from vllm import envs
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.attention.selector import get_attn_backend
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.utils import TpKVTopology
from vllm.distributed.kv_transfer.kv_connector.utils import (
EngineId,
TpKVTopology,
yield_req_data,
)
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
CopyBlocksOp,
KVConnectorBase_V1,
......@@ -56,7 +60,6 @@ if TYPE_CHECKING:
from vllm.v1.request import Request
TransferHandle = int
EngineId = str
ReqId = str
#
......@@ -482,7 +485,7 @@ class NixlConnectorScheduler:
# New requests are added by update_state_after_alloc in
# the scheduler. Used to make metadata passed to Worker.
self._reqs_need_recv: dict[ReqId, tuple[Request, list[int]]] = {}
self._reqs_need_save: dict[ReqId, tuple[Request, list[int]]] = {}
self._reqs_need_save: dict[ReqId, Request] = {}
# Reqs to send and their expiration time
self._reqs_need_send: dict[ReqId, float] = {}
self._reqs_in_batch: set[ReqId] = set()
......@@ -628,16 +631,7 @@ class NixlConnectorScheduler:
if self.use_host_buffer and params.get("do_remote_decode"):
# NOTE: when accelerator is not directly supported by Nixl,
# prefilled blocks need to be saved to host memory before transfer.
# save all blocks
block_ids = blocks.get_block_ids()[0]
# TODO: skip the blocks that are already in the host xfer buffer.
# Currently, the host xfer buffer block is 1-to-1 mapped to device
# kv blocks, so host blocks won't be flushed as long as its device
# block is not overwritten; and it will be safe to skip saving them
# to host xfer buffer.
if block_ids:
self._reqs_need_save[request.request_id] = (request, block_ids)
self._reqs_need_save[request.request_id] = request
elif params.get("do_remote_prefill"):
if params.get("remote_block_ids"):
if all(
......@@ -689,13 +683,32 @@ class NixlConnectorScheduler:
kv_transfer_params=req.kv_transfer_params,
)
for req_id, (req, block_ids) in self._reqs_need_save.items():
# NOTE: For the prefill side, there might be a chance that an early added
# request is a chunked prefill, so we need to check if new blocks are added
for req_id, new_block_id_groups, _ in yield_req_data(scheduler_output):
req_to_save = self._reqs_need_save.get(req_id)
if req_to_save is None or new_block_id_groups is None:
continue
req = req_to_save
assert req.kv_transfer_params is not None
meta.add_new_req_to_save(
request_id=req_id,
local_block_ids=block_ids,
local_block_ids=new_block_id_groups[0],
kv_transfer_params=req.kv_transfer_params,
)
assert scheduler_output.num_scheduled_tokens is not None
num_scheduled_tokens = scheduler_output.num_scheduled_tokens[req_id]
is_partial = (
req.num_computed_tokens + num_scheduled_tokens
) < req.num_prompt_tokens
if not is_partial:
# For non-partial prefills, once new req_meta is scheduled, it
# can be removed from _reqs_need_save.
# For partial prefill case, we will retain the request in
# _reqs_need_save until all blocks are scheduled with req_meta.
# Therefore, only pop if `not is_partial`.
self._reqs_need_save.pop(req_id)
meta.reqs_to_send = self._reqs_need_send
meta.reqs_in_batch = self._reqs_in_batch
......@@ -703,7 +716,6 @@ class NixlConnectorScheduler:
# Clear the list once workers start the transfers
self._reqs_need_recv.clear()
self._reqs_need_save.clear()
self._reqs_in_batch = set()
self._reqs_not_processed = set()
self._reqs_need_send = {}
......@@ -749,6 +761,8 @@ class NixlConnectorScheduler:
# Also include the case of a P/D Prefill request with immediate
# block free (eg abort). Stop tracking this request.
self._reqs_not_processed.add(request.request_id)
# Clear _reqs_need_save if a request is aborted as partial prefill.
self._reqs_need_save.pop(request.request_id, None)
return False, None
# TODO: check whether block_ids actually ever be 0. If not we could
......@@ -873,9 +887,10 @@ class NixlConnectorWorker:
self.copy_blocks: CopyBlocksOp | None = None
# Map of engine_id -> kv_caches_base_addr. For TP case, each local
# rank will still only pull from a single remote TP worker.
self.kv_caches_base_addr: dict[EngineId, list[int]] = {}
self.device_id: int = 0
# Current rank may pull from multiple remote TP workers.
# EngineId, dict[int, list[int]] -> engine_id, tp_rank, base_addr_for_layer
self.kv_caches_base_addr = defaultdict[EngineId, dict[int, list[int]]](dict)
# Number of NIXL regions. Currently one region per cache
# (so 1 per layer for MLA, otherwise 2 per layer)
......@@ -883,10 +898,12 @@ class NixlConnectorWorker:
self.num_layers = 0
# nixl_prepped_dlist_handle.
self.src_xfer_side_handle: int = 0
self.src_xfer_side_handles: dict[int, int] = {}
# Map of engine_id -> nixl_prepped_dlist_handle (int)].
self.dst_xfer_side_handles: dict[EngineId, int] = {}
self.src_xfer_handles_by_block_size: dict[int, int] = {}
# Populated dynamically during handshake based on remote configuration.
# Keep track of regions at different tp_ratio values. tp_ratio->handles
self.src_xfer_handles_by_tp_ratio: dict[int, list[int]] = {}
# Map of engine_id -> {tp_rank: nixl_prepped_dlist_handle (int)}.
self.dst_xfer_side_handles = defaultdict[EngineId, dict[int, int]](dict)
# Map of engine_id -> num_blocks. All ranks in the same deployment will
# have the same number of blocks.
......@@ -977,103 +994,108 @@ class NixlConnectorWorker:
expected_engine_id: str,
) -> dict[int, str]:
"""Do a NIXL handshake with a remote instance."""
start_time = time.perf_counter()
# NOTE(rob): we need each rank to have a unique port. This is
# a hack to keep us moving. We will switch when moving to etcd
# or where we have a single ZMQ socket in the scheduler.
# Handshake only with the remote TP rank that current local rank will
# pull from. With homogeneous TP it happens to be the same rank_i.
p_remote_rank = self.kv_topo.get_target_remote_rank(remote_tp_size)
# When target instance TP > local TP, we need to perform multiple
# handshakes. Do it in a single background job for simplicity.
# Regardless, only handshake with the remote TP rank(s) that current
# local rank will read from. Note that with homogeneous TP,
# this happens to be the same single rank_i.
p_remote_ranks = self.kv_topo.get_target_remote_ranks(remote_tp_size)
remote_rank_to_agent_name = {}
path = make_zmq_path("tcp", host, port)
logger.debug(
"Querying metadata on path: %s at remote tp rank %s", path, p_remote_rank
)
# Send query for the request.
with zmq_ctx(zmq.REQ, path) as sock:
msg = msgspec.msgpack.encode((GET_META_MSG, p_remote_rank))
# Set receive timeout to 5 seconds to avoid hanging on dead server
sock.setsockopt(zmq.RCVTIMEO, 5000) # milliseconds
sock.send(msg)
handshake_bytes = sock.recv()
# Decode handshake payload to get compatibility hash
handshake_decoder = msgspec.msgpack.Decoder(NixlHandshakePayload)
try:
handshake_payload = handshake_decoder.decode(handshake_bytes)
except (msgspec.DecodeError, msgspec.ValidationError) as e:
raise RuntimeError(
f"Failed to decode NixlHandshakePayload. This likely indicates "
f"an incompatibility between connector version. Error: {e}"
) from e
got_metadata_time = time.perf_counter()
logger.debug(
"NIXL handshake: get metadata took: %s", got_metadata_time - start_time
)
# Check compatibility hash BEFORE decoding agent metadata
if (
self.enforce_compat_hash
and handshake_payload.compatibility_hash != self.compat_hash
):
raise RuntimeError(
f"NIXL compatibility hash mismatch. "
f"Local: {self.compat_hash}, "
f"Remote: {handshake_payload.compatibility_hash}. "
f"Prefill and decode instances have incompatible configurations. "
f"This may be due to: different vLLM versions, models, dtypes, "
f"KV cache layouts, attention backends, etc. "
f"Both instances must use identical configurations."
f"Disable this check using "
f'--kv-transfer-config \'{{"kv_connector_extra_config": '
f'{{"enforce_handshake_compat": false}}}}\''
for remote_rank in p_remote_ranks:
logger.debug(
"Querying metadata on path: %s at remote tp rank %s",
path,
remote_rank,
)
logger.info(
"NIXL compatibility check passed (hash: %s)",
handshake_payload.compatibility_hash,
)
start_time = time.perf_counter()
# Send query for the request.
msg = msgspec.msgpack.encode((GET_META_MSG, remote_rank))
# Set receive timeout to 5 seconds to avoid hanging on dead server
sock.setsockopt(zmq.RCVTIMEO, 5000) # milliseconds
sock.send(msg)
handshake_bytes = sock.recv()
# Decode agent metadata
metadata_decoder = msgspec.msgpack.Decoder(NixlAgentMetadata)
try:
metadata = metadata_decoder.decode(
handshake_payload.agent_metadata_bytes
# Decode handshake payload to get compatibility hash
handshake_decoder = msgspec.msgpack.Decoder(NixlHandshakePayload)
try:
handshake_payload = handshake_decoder.decode(handshake_bytes)
except (msgspec.DecodeError, msgspec.ValidationError) as e:
raise RuntimeError(
f"Failed to decode NixlHandshakePayload. This likely indicates "
f"an incompatibility between connector version. Error: {e}"
) from e
got_metadata_time = time.perf_counter()
logger.debug(
"NIXL handshake: get metadata took: %s",
got_metadata_time - start_time,
)
except (msgspec.DecodeError, msgspec.ValidationError) as e:
# This should not happen if hash matched
raise RuntimeError(
f"Failed to decode NixlAgentMetadata. Error: {e}"
) from e
# Ensure engine id matches.
if metadata.engine_id != expected_engine_id:
raise RuntimeError(
f"Remote NIXL agent engine ID mismatch. "
f"Expected {expected_engine_id},"
f"received {metadata.engine_id}."
)
# Check compatibility hash BEFORE decoding agent metadata
if (
self.enforce_compat_hash
and handshake_payload.compatibility_hash != self.compat_hash
):
raise RuntimeError(
f"NIXL compatibility hash mismatch. "
f"Local: {self.compat_hash}, "
f"Remote: {handshake_payload.compatibility_hash}. "
f"Prefill and decode instances have incompatible "
f"configurations. This may be due to: different vLLM versions,"
f" models, dtypes, KV cache layouts, attention backends, etc. "
f"Both instances must use identical configurations."
f"Disable this check using "
f'--kv-transfer-config \'{{"kv_connector_extra_config": '
f'{{"enforce_handshake_compat": false}}}}\''
)
# Register Remote agent.
assert metadata.block_size <= self.block_size, (
"nP > nD is not supported yet."
)
remote_agent_name = self.add_remote_agent(
metadata, p_remote_rank, remote_tp_size
)
logger.info(
"NIXL compatibility check passed (hash: %s)",
handshake_payload.compatibility_hash,
)
setup_agent_time = time.perf_counter()
logger.debug(
"NIXL handshake: add agent took: %s",
setup_agent_time - got_metadata_time,
)
# Decode agent metadata
metadata_decoder = msgspec.msgpack.Decoder(NixlAgentMetadata)
try:
metadata = metadata_decoder.decode(
handshake_payload.agent_metadata_bytes
)
except (msgspec.DecodeError, msgspec.ValidationError) as e:
# This should not happen if hash matched
raise RuntimeError(
f"Failed to decode NixlAgentMetadata. Error: {e}"
) from e
# Ensure engine id matches.
if metadata.engine_id != expected_engine_id:
raise RuntimeError(
f"Remote NIXL agent engine ID mismatch. "
f"Expected {expected_engine_id},"
f"received {metadata.engine_id}."
)
# Ensure engine id matches.
if metadata.engine_id != expected_engine_id:
raise RuntimeError(
f"Remote NIXL agent engine ID mismatch. "
f"Expected {expected_engine_id},"
f"received {metadata.engine_id}."
)
setup_agent_time = time.perf_counter()
# Remote rank -> agent name.
return {p_remote_rank: remote_agent_name}
# Register Remote agent.
remote_agent_name = self.add_remote_agent(
metadata, remote_rank, remote_tp_size
)
logger.debug(
"NIXL handshake: add agent took: %s",
setup_agent_time - got_metadata_time,
)
remote_rank_to_agent_name[remote_rank] = remote_agent_name
return remote_rank_to_agent_name
def initialize_host_xfer_buffer(self, kv_caches: dict[str, torch.Tensor]) -> None:
"""
......@@ -1283,7 +1305,7 @@ class NixlConnectorWorker:
assert len(self.block_len_per_layer) == len(seen_base_addresses)
assert self.num_blocks != 0
self.kv_caches_base_addr[self.engine_id] = seen_base_addresses
self.kv_caches_base_addr[self.engine_id][self.tp_rank] = seen_base_addresses
self.num_regions = len(caches_data)
self.num_layers = len(xfer_buffers.keys())
......@@ -1310,9 +1332,9 @@ class NixlConnectorWorker:
# Register local/src descr for NIXL xfer.
self.seen_base_addresses = seen_base_addresses
self.src_xfer_side_handle = self.register_local_xfer_handler(self.block_size)
self.src_xfer_side_handles[self.block_size] = self.src_xfer_side_handle
self.src_xfer_handles_by_block_size[self.block_size], self.src_blocks_data = (
self.register_local_xfer_handler(self.block_size)
)
# TODO(mgoin): Hybrid memory allocator is currently disabled for
# models with local attention (Llama 4). Can remove this once enabled.
......@@ -1340,8 +1362,8 @@ class NixlConnectorWorker:
agent_metadata = NixlAgentMetadata(
engine_id=self.engine_id,
agent_metadata=self.nixl_wrapper.get_agent_metadata(),
kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id],
device_id=self.device_id,
kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id][self.tp_rank],
num_blocks=self.num_blocks,
block_lens=self.block_len_per_layer,
kv_cache_layout=self.kv_cache_layout
......@@ -1359,7 +1381,7 @@ class NixlConnectorWorker:
def register_local_xfer_handler(
self,
block_size: int,
) -> int:
) -> tuple[int, list[tuple[int, int, int]]]:
"""
Function used for register local xfer handler with local block_size or
Remote block_size.
......@@ -1407,7 +1429,7 @@ class NixlConnectorWorker:
descs = self.nixl_wrapper.get_xfer_descs(blocks_data, self.nixl_memory_type)
# NIXL_INIT_AGENT to be used for preparations of local descs.
return self.nixl_wrapper.prep_xfer_dlist("NIXL_INIT_AGENT", descs)
return self.nixl_wrapper.prep_xfer_dlist("NIXL_INIT_AGENT", descs), blocks_data
def add_remote_agent(
self,
......@@ -1421,10 +1443,12 @@ class NixlConnectorWorker:
In particular, handle both homogeneous and heterogeneous TP. The former
requires local rank_i to read from remote rank_i.
The latter, assuming D.world_size > P.world_size, requires that two or
more local TP worker share the xfer from a single TP worker.
The latter, in the case of D.world_size < P.world_size, requires that a
local (D) TP worker reads from multiple remote (P) TP workers.
Conversely, assuming D.world_size > P.world_size, two or more local TP
workers will read from a single remote TP worker.
Here's an example (non-MLA case):
Here's an example for the last case described above (non-MLA):
rank_offset p_remote_tp_rank
(kv split no)
......@@ -1474,9 +1498,6 @@ class NixlConnectorWorker:
nixl_agent_meta.agent_metadata
)
# Handle tp_size>num_kv_heads: replicate KV cache.
replicates_kv_cache = self.kv_topo.replicates_kv_cache(engine_id)
# Create dst descs and xfer side handles. TP workers have same #blocks
# so we only register once per engine_id.
# Example:
......@@ -1490,14 +1511,52 @@ class NixlConnectorWorker:
self.dst_num_blocks[engine_id] = nixl_agent_meta.num_blocks
# Keep track of remote agent kv caches base addresses.
self.kv_caches_base_addr[engine_id] = nixl_agent_meta.kv_caches_base_addr
self.kv_caches_base_addr[engine_id][remote_tp_rank] = (
nixl_agent_meta.kv_caches_base_addr
)
self._validate_remote_agent_handshake(nixl_agent_meta, remote_tp_size)
# Number of D TP workers reading from a single P TP worker. This is
# 1 when P and D `--tensor-parallel-size` match.
# This is 1 when P and D `--tensor-parallel-size` match. Otherwise,
# this is the ratio between the two sizes.
tp_ratio = self.kv_topo.tp_ratio_from_engine_id(engine_id)
# Handle tp_size>num_kv_heads: replicate KV cache.
indexes_into_remote = (
not self.kv_topo.replicates_kv_cache(engine_id) and tp_ratio > 0
)
logger.debug(
"Registering remote agent (%s, rank %s) memory regions with tp_ratio %s",
engine_id,
remote_tp_rank,
tp_ratio,
)
### (Optional) Register local agent memory regions. MLA is not split.
if (
tp_ratio < 0
and not self.use_mla
and tp_ratio not in self.src_xfer_handles_by_tp_ratio
):
# Remote tp_size > local tp_size: read from multiple remote ranks.
# Logically "split" own regions into |tp_ratio| chunks. Mind that
# we only do this once per remote tp_size (replica-friendly).
self.src_xfer_handles_by_tp_ratio[tp_ratio] = []
for i in range(-tp_ratio):
blocks_data = []
for memory_region in self.src_blocks_data:
addr, local_block_len, own_tp_rank = memory_region
# Computing block len layer by layer allows for different
# block sizes to be used.
remote_block_len = local_block_len // (-tp_ratio)
addr = addr + i * remote_block_len
blocks_data.append((addr, remote_block_len, own_tp_rank))
descs = self.nixl_wrapper.get_xfer_descs(
blocks_data, self.nixl_memory_type
)
handle = self.nixl_wrapper.prep_xfer_dlist("NIXL_INIT_AGENT", descs)
self.src_xfer_handles_by_tp_ratio[tp_ratio].append(handle)
### Register remote agent memory regions
blocks_data = []
# With homogeneous TP, D pulls the whole kv cache from corresponding
......@@ -1507,14 +1566,19 @@ class NixlConnectorWorker:
# Register all remote blocks, but only the corresponding kv heads.
for i, base_addr in enumerate(nixl_agent_meta.kv_caches_base_addr):
kv_block_len = self.get_backend_aware_kv_block_len(layer_idx=i)
remote_kv_block_len = kv_block_len // block_size_ratio
# Read our whole local region size from remote.
local_block_len = self.get_backend_aware_kv_block_len(layer_idx=i)
remote_kv_block_len = local_block_len // block_size_ratio
if block_size_ratio > 1:
# using remote kv_block_len as transfer unit
kv_block_len = remote_kv_block_len
local_block_len = remote_kv_block_len
if tp_ratio < 0 and not self.use_mla:
# Remote tp is bigger: read a chunk of local region from remote
local_block_len = local_block_len // (-tp_ratio)
rank_offset = (
self.tp_rank % tp_ratio * remote_kv_block_len
if not replicates_kv_cache
if indexes_into_remote
else 0
)
for block_id in range(nixl_agent_meta.num_blocks):
......@@ -1524,7 +1588,7 @@ class NixlConnectorWorker:
# self.block_len == remote_block_len//tp_ratio bytes.
addr = base_addr + block_offset + rank_offset
# (addr, len, device id)
blocks_data.append((addr, kv_block_len, nixl_agent_meta.device_id))
blocks_data.append((addr, local_block_len, nixl_agent_meta.device_id))
if self.kv_topo.is_kv_layout_blocks_first:
# With FlashInfer index V separately to allow head splitting.
......@@ -1533,7 +1597,7 @@ class NixlConnectorWorker:
addr = base_addr + block_offset + rank_offset
v_addr = addr + nixl_agent_meta.block_lens[i] // 2
blocks_data.append(
(v_addr, kv_block_len, nixl_agent_meta.device_id)
(v_addr, local_block_len, nixl_agent_meta.device_id)
)
logger.debug(
......@@ -1546,15 +1610,15 @@ class NixlConnectorWorker:
# Register with NIXL.
descs = self.nixl_wrapper.get_xfer_descs(blocks_data, self.nixl_memory_type)
self.dst_xfer_side_handles[engine_id] = self.nixl_wrapper.prep_xfer_dlist(
remote_agent_name, descs
self.dst_xfer_side_handles[engine_id][remote_tp_rank] = (
self.nixl_wrapper.prep_xfer_dlist(remote_agent_name, descs)
)
if block_size_ratio > 1:
# when prefill with smaller block_size, we need to init a
# new handler with same block_len to match
self.src_xfer_side_handles[nixl_agent_meta.block_size] = (
self.register_local_xfer_handler(nixl_agent_meta.block_size)
self.src_xfer_handles_by_block_size[nixl_agent_meta.block_size] = (
self.register_local_xfer_handler(nixl_agent_meta.block_size)[0]
)
return remote_agent_name
......@@ -1574,7 +1638,9 @@ class NixlConnectorWorker:
block_size_ratio = self.kv_topo.block_size_ratio_from_engine_id(
remote_engine_id
)
assert tp_ratio > 0, "Decode TP cannot be smaller than prefill TP"
# Num kv_heads > tp_size and P TP > D TP case, not supported
assert not (tp_ratio < 0 and self.kv_topo.is_kv_replicated(remote_engine_id))
assert not self._use_pallas or tp_ratio == 1, (
"TPU (pallas_v1) DOES NOT support heterogeneous TP yet."
)
......@@ -1616,17 +1682,29 @@ class NixlConnectorWorker:
"All remote layers must have the same block size"
)
assert (
remote_block_len
== (self.block_len_per_layer[0] * tp_ratio) // block_size_ratio
), (
"Remote P worker KV layer cache must be of shape [2, N, "
"local_kv_heads*tp_ratio, block_size, head_dim] and same dtype."
)
if tp_ratio > 0:
# Remote tp is smaller: remote block_len size is bigger
assert (
remote_block_len
== (self.block_len_per_layer[0] * tp_ratio) // block_size_ratio
), (
"Remote P worker KV layer cache must be of shape [2, N, "
"local_kv_heads*tp_ratio, page_size, head_dim] and same dtype."
) # noqa: E501
else:
assert block_size_ratio == 1, (
"Different local/remote block sizes are not supported when"
" P TP > D TP."
)
# Remote tp is bigger: remote block_len size is smaller
assert remote_block_len == self.block_len_per_layer[0] // (-tp_ratio), (
"Remote P worker KV layer cache must be of shape [2, N, "
"local_kv_heads/tp_ratio, page_size, head_dim] and same dtype."
) # noqa: E501
# TP workers have same #blocks.
# TP workers that handshake with same remote have same #blocks.
assert self.dst_num_blocks[remote_engine_id] == nixl_agent_meta.num_blocks
# Same number of regions/~layers.
assert len(nixl_agent_meta.kv_caches_base_addr) == len(self.block_len_per_layer)
def sync_recved_kv_to_device(self, req_id: str, meta: ReqMeta):
......@@ -1710,7 +1788,7 @@ class NixlConnectorWorker:
)
cache.index_copy_(0, indices, permuted_blocks)
def blocksize_post_process(self, block_ids_per_ratio: dict[float, list[list[int]]]):
def blocksize_post_process(self, block_ids_per_ratio: dict[int, list[list[int]]]):
def _process_local_gt_remote(blocks_to_update, block_size_ratio):
n_kv_heads, block_size, head_size = blocks_to_update.shape[1:]
remote_block_size = block_size // block_size_ratio
......@@ -1840,7 +1918,7 @@ class NixlConnectorWorker:
notified_req_ids: set[str] = set()
for notifs in self.nixl_wrapper.get_new_notifs().values():
for notif in notifs:
req_id, tp_ratio = notif.decode("utf-8").rsplit(":", 1)
req_id, tp_size = notif.decode("utf-8").rsplit(":", 1)
if (
req_id not in self._reqs_to_send
and req_id not in self._reqs_to_process
......@@ -1853,9 +1931,22 @@ class NixlConnectorWorker:
)
continue
# NOTE: `tp_ratio` is the opposite when swapping local<>remote
n_consumers = int(tp_size)
tp_ratio = self.kv_topo.tp_ratio(n_consumers)
# Number of reads *per producer* to wait for.
# When remote D TP > local P TP we expect `tp_ratio` reads.
consumers_per_producer = (
-tp_ratio if n_consumers > self.world_size else 1
)
self.consumer_notification_counts_by_req[req_id] += 1
# Wait all consumers (D) to be done reading before freeing.
if self.consumer_notification_counts_by_req[req_id] == int(tp_ratio):
if (
self.consumer_notification_counts_by_req[req_id]
== consumers_per_producer
):
notified_req_ids.add(req_id)
del self.consumer_notification_counts_by_req[req_id]
self._reqs_to_process.remove(req_id)
......@@ -1872,7 +1963,7 @@ class NixlConnectorWorker:
"""
done_req_ids: set[str] = set()
for req_id, handles in list(transfers.items()):
in_progress = False
in_progress = []
for handle in handles:
try:
xfer_state = self.nixl_wrapper.check_xfer_state(handle)
......@@ -1882,7 +1973,7 @@ class NixlConnectorWorker:
self.xfer_stats.record_transfer(res)
self.nixl_wrapper.release_xfer_handle(handle)
elif xfer_state == "PROC":
in_progress = True
in_progress.append(handle)
continue
else:
logger.error(
......@@ -1892,7 +1983,6 @@ class NixlConnectorWorker:
xfer_state,
)
self._handle_failed_transfer(req_id, handle)
in_progress = False
except Exception:
logger.exception(
"NIXL transfer exception for request %s. "
......@@ -1900,11 +1990,13 @@ class NixlConnectorWorker:
req_id,
)
self._handle_failed_transfer(req_id, handle)
in_progress = False
if not in_progress:
# Only report request as completed when all transfers are done.
done_req_ids.add(req_id)
del transfers[req_id]
else:
transfers[req_id] = in_progress
return done_req_ids
def _handle_failed_transfer(self, req_id: str, handle: int):
......@@ -1982,18 +2074,62 @@ class NixlConnectorWorker:
def _read_blocks_for_req(self, req_id: str, meta: ReqMeta):
assert meta.remote is not None
logger.debug(
"Remote agent %s available, calling _read_blocks for req %s",
meta.remote.engine_id,
req_id,
)
self._read_blocks(
request_id=req_id,
dst_engine_id=meta.remote.engine_id,
remote_request_id=meta.remote.request_id,
local_block_ids=meta.local_physical_block_ids,
remote_block_ids=meta.remote.block_ids,
remote_ranks = self.kv_topo.get_target_remote_ranks_from_engine_id(
meta.remote.engine_id
)
tp_ratio = self.kv_topo.tp_ratio_from_engine_id(meta.remote.engine_id)
# D may have to perform multiple reads from different remote ranks.
for i, remote_rank in enumerate(remote_ranks):
if self.use_mla and tp_ratio < 0 and i > 0:
# MLA opt: when P TP > D TP, only a single read is executed for
# the first remote rank (cache is duplicated)..
break
remote_block_size = self.kv_topo.remote_block_size[meta.remote.engine_id]
logger.debug(
"Remote agent %s available, calling _read_blocks"
" on remote rank %s with remote block size %s for req %s",
meta.remote.engine_id,
remote_rank,
remote_block_size,
req_id,
)
# Get side handles.
if tp_ratio < 0 and not self.use_mla:
assert remote_block_size == self.block_size
# Remote tp_size > local tp_size: we must perform multiple
# reads. Get the memory chunk onto which we will write to.
local_xfer_side_handle = self.src_xfer_handles_by_tp_ratio[tp_ratio][i]
else:
# Single read from remote, we write to the whole memory region.
# Also handle remote block size different from local block size.
local_xfer_side_handle = self.src_xfer_handles_by_block_size[
remote_block_size
]
# Destination handle: remote_engine_id -> remote_rank -> handle.
remote_xfer_side_handle = self.dst_xfer_side_handles[meta.remote.engine_id][
remote_rank
]
self._read_blocks(
request_id=req_id,
dst_engine_id=meta.remote.engine_id,
remote_request_id=meta.remote.request_id,
local_block_ids=meta.local_physical_block_ids,
remote_block_ids=meta.remote.block_ids,
remote_rank=remote_rank,
local_xfer_side_handle=local_xfer_side_handle,
remote_xfer_side_handle=remote_xfer_side_handle,
)
if self.use_mla and tp_ratio < 0:
# ..but we still need to notify the other remote ranks that we
# have the blocks we need so they can update the request state.
notif_id = f"{req_id}:{self.world_size}".encode()
remote_agents = self._remote_agents[meta.remote.engine_id]
for rank_to_notify, agent in remote_agents.items():
if rank_to_notify != remote_rank:
self.nixl_wrapper.send_notif(agent, notif_msg=notif_id)
def _read_blocks(
self,
......@@ -2002,7 +2138,14 @@ class NixlConnectorWorker:
dst_engine_id: str,
request_id: str,
remote_request_id: str,
remote_rank: int,
local_xfer_side_handle: int,
remote_xfer_side_handle: int,
):
"""
Post a READ point-to-point xfer request from a single local worker to
a single remote worker.
"""
block_size_ratio = self.kv_topo.block_size_ratio_from_engine_id(dst_engine_id)
if block_size_ratio > 1:
local_block_ids = self.get_mapped_blocks(
......@@ -2031,18 +2174,14 @@ class NixlConnectorWorker:
# saturate IB with heterogeneous TP sizes. We should remove the staging
# blocks until we are ready.
# Number of D TP workers that will read from dst P. Propagate tp_ratio
# Number of D TP workers that will read from dst P. Propagate info
# on notification so that dst worker can wait before freeing blocks.
tp_ratio = self.kv_topo.tp_ratio_from_engine_id(dst_engine_id)
notif_id = f"{remote_request_id}:{tp_ratio}".encode()
notif_id = f"{remote_request_id}:{self.world_size}".encode()
# Full prefix cache hit: do not need to read remote blocks,
# just notify P worker that we have the blocks we need.
num_local_blocks = len(local_block_ids)
if num_local_blocks == 0:
remote_rank = self.kv_topo.get_target_remote_rank_from_engine_id(
dst_engine_id
)
agent_name = self._remote_agents[dst_engine_id][remote_rank]
try:
self.nixl_wrapper.send_notif(agent_name, notif_msg=notif_id)
......@@ -2062,13 +2201,6 @@ class NixlConnectorWorker:
if num_local_blocks < num_remote_blocks:
remote_block_ids = remote_block_ids[-num_local_blocks:]
# Get side handles.
remote_block_size = self.kv_topo.remote_block_size[dst_engine_id]
local_xfer_side_handle = self.src_xfer_side_handles.get(
remote_block_size, self.src_xfer_side_handle
)
remote_xfer_side_handle = self.dst_xfer_side_handles[dst_engine_id]
# NOTE (nicolo) With homogeneous TP, each TP worker loads KV from
# corresponding rank. With heterogeneous TP, fixing D>P, the D tp
# workers will issue xfers to parts of the P worker remote kv caches.
......@@ -2230,7 +2362,7 @@ class NixlConnectorWorker:
block_ids_np, self._physical_blocks_per_logical_kv_block, block_arange
).tolist()
def get_backend_aware_kv_block_len(self, layer_idx: int):
def get_backend_aware_kv_block_len(self, layer_idx: int) -> int:
"""
Get the block length for one K/V element (K and V have the same size).
......@@ -2276,11 +2408,16 @@ class NixlConnectorWorker:
for handle in handles:
self.nixl_wrapper.release_xfer_handle(handle)
self._recving_transfers.clear()
if self.src_xfer_side_handle:
self.nixl_wrapper.release_dlist_handle(self.src_xfer_side_handle)
self.src_xfer_side_handle = 0
for dst_xfer_side_handle in self.dst_xfer_side_handles.values():
self.nixl_wrapper.release_dlist_handle(dst_xfer_side_handle)
for handle in self.src_xfer_handles_by_block_size.values():
self.nixl_wrapper.release_dlist_handle(handle)
self.src_xfer_handles_by_block_size.clear()
for handles in self.src_xfer_handles_by_tp_ratio.values():
for handle in handles:
self.nixl_wrapper.release_dlist_handle(handle)
self.src_xfer_handles_by_tp_ratio.clear()
for dst_xfer_side_handles in self.dst_xfer_side_handles.values():
for dst_xfer_side_handle in dst_xfer_side_handles.values():
self.nixl_wrapper.release_dlist_handle(dst_xfer_side_handle)
self.dst_xfer_side_handles.clear()
for remote_agents in self._remote_agents.values():
for agent_name in remote_agents.values():
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections import defaultdict
from collections.abc import Iterable, Iterator
from collections.abc import Iterable
from dataclasses import dataclass
from itertools import islice
from typing import Any, ClassVar
......@@ -12,6 +12,7 @@ from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata
from vllm.attention.layer import Attention
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.distributed.kv_events import BlockRemoved, BlockStored, KVCacheEvent
from vllm.distributed.kv_transfer.kv_connector.utils import yield_req_data
from vllm.distributed.kv_transfer.kv_connector.v1 import (
KVConnectorBase_V1,
KVConnectorRole,
......@@ -516,23 +517,3 @@ class OffloadingConnectorWorker:
del self._store_jobs[req_id]
return finished_sending, finished_recving
def yield_req_data(
    scheduler_output,
) -> Iterator[tuple[str, tuple[list[int], ...], bool]]:
    """
    Walk over every request found in `scheduler_output`, new ones first.

    Yields:
        (req_id, new_block_id_groups, preempted)
    """
    # Requests from `scheduled_new_reqs` always carry preempted=False.
    for new_req in scheduler_output.scheduled_new_reqs:
        yield new_req.req_id, new_req.block_ids, False

    # Requests from `scheduled_cached_reqs`: the flag is True exactly when
    # the request id is listed in `resumed_req_ids`.
    cached = scheduler_output.scheduled_cached_reqs
    for req_id, block_id_groups in zip(cached.req_ids, cached.new_block_ids):
        yield req_id, block_id_groups, (req_id in cached.resumed_req_ids)
......@@ -1007,10 +1007,17 @@ class GroupCoordinator:
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
is_sequence_parallel: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
extra_tensors: list[torch.Tensor] | None = None,
) -> (
tuple[torch.Tensor, torch.Tensor]
| tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
if self.device_communicator is not None:
return self.device_communicator.dispatch(
hidden_states, router_logits, is_sequence_parallel
return self.device_communicator.dispatch( # type: ignore[call-arg]
hidden_states,
router_logits,
is_sequence_parallel,
extra_tensors,
)
else:
return hidden_states, router_logits
......
......@@ -93,6 +93,7 @@ from vllm.transformers_utils.utils import is_cloud_storage
from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.utils.mem_constants import GiB_bytes
from vllm.utils.network_utils import get_ip
from vllm.utils.torch_utils import resolve_kv_cache_dtype_string
from vllm.v1.sample.logits_processor import LogitsProcessor
if TYPE_CHECKING:
......@@ -106,6 +107,7 @@ else:
LoadFormats = Any
UsageContext = Any
logger = init_logger(__name__)
# object is used to allow for special typing forms
......@@ -406,8 +408,9 @@ class EngineArgs:
data_parallel_external_lb: bool = False
data_parallel_backend: str = ParallelConfig.data_parallel_backend
enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel
all2all_backend: str | None = ParallelConfig.all2all_backend
all2all_backend: str = ParallelConfig.all2all_backend
enable_dbo: bool = ParallelConfig.enable_dbo
ubatch_size: int = ParallelConfig.ubatch_size
dbo_decode_token_threshold: int = ParallelConfig.dbo_decode_token_threshold
dbo_prefill_token_threshold: int = ParallelConfig.dbo_prefill_token_threshold
disable_nccl_for_dp_synchronization: bool = (
......@@ -520,6 +523,7 @@ class EngineArgs:
enable_layerwise_nvtx_tracing: bool = (
ObservabilityConfig.enable_layerwise_nvtx_tracing
)
enable_mfu_metrics: bool = ObservabilityConfig.enable_mfu_metrics
scheduling_policy: SchedulerPolicy = SchedulerConfig.policy
scheduler_cls: str | type[object] | None = SchedulerConfig.scheduler_cls
......@@ -841,6 +845,10 @@ class EngineArgs:
"--all2all-backend", **parallel_kwargs["all2all_backend"]
)
parallel_group.add_argument("--enable-dbo", **parallel_kwargs["enable_dbo"])
parallel_group.add_argument(
"--ubatch-size",
**parallel_kwargs["ubatch_size"],
)
parallel_group.add_argument(
"--dbo-decode-token-threshold",
**parallel_kwargs["dbo_decode_token_threshold"],
......@@ -1035,6 +1043,10 @@ class EngineArgs:
"--enable-layerwise-nvtx-tracing",
**observability_kwargs["enable_layerwise_nvtx_tracing"],
)
observability_group.add_argument(
"--enable-mfu-metrics",
**observability_kwargs["enable_mfu_metrics"],
)
# Scheduler arguments
scheduler_kwargs = get_kwargs(SchedulerConfig)
......@@ -1356,12 +1368,17 @@ class EngineArgs:
f"dcp_size={self.decode_context_parallel_size}."
)
# Resolve "auto" kv_cache_dtype to actual value from model config
resolved_cache_dtype = resolve_kv_cache_dtype_string(
self.kv_cache_dtype, model_config
)
cache_config = CacheConfig(
block_size=self.block_size,
gpu_memory_utilization=self.gpu_memory_utilization,
kv_cache_memory_bytes=self.kv_cache_memory_bytes,
swap_space=self.swap_space,
cache_dtype=self.kv_cache_dtype,
cache_dtype=resolved_cache_dtype,
is_attention_free=model_config.is_attention_free,
num_gpu_blocks_override=self.num_gpu_blocks_override,
sliding_window=sliding_window,
......@@ -1557,6 +1574,7 @@ class EngineArgs:
enable_expert_parallel=self.enable_expert_parallel,
all2all_backend=self.all2all_backend,
enable_dbo=self.enable_dbo,
ubatch_size=self.ubatch_size,
dbo_decode_token_threshold=self.dbo_decode_token_threshold,
dbo_prefill_token_threshold=self.dbo_prefill_token_threshold,
disable_nccl_for_dp_synchronization=self.disable_nccl_for_dp_synchronization,
......@@ -1676,6 +1694,7 @@ class EngineArgs:
kv_cache_metrics_sample=self.kv_cache_metrics_sample,
cudagraph_metrics=self.cudagraph_metrics,
enable_layerwise_nvtx_tracing=self.enable_layerwise_nvtx_tracing,
enable_mfu_metrics=self.enable_mfu_metrics,
)
# Compilation config overrides
......
......@@ -2,11 +2,13 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import contextlib
import copy
import json
import logging
from abc import ABC, abstractmethod
from collections.abc import Callable
from contextlib import AsyncExitStack
from dataclasses import replace
from typing import TYPE_CHECKING, Union
from openai.types.responses.response_function_tool_call_output_item import (
......@@ -164,6 +166,12 @@ class SimpleContext(ConversationContext):
def __init__(self):
self.last_output = None
# Accumulated final output for streaming mode
self._accumulated_text: str = ""
self._accumulated_token_ids: list[int] = []
self._accumulated_logprobs: list = []
self.num_prompt_tokens = 0
self.num_output_tokens = 0
self.num_cached_tokens = 0
......@@ -183,6 +191,13 @@ class SimpleContext(ConversationContext):
self.num_cached_tokens = output.num_cached_tokens or 0
self.num_output_tokens += len(output.outputs[0].token_ids or [])
# Accumulate text, token_ids, and logprobs for streaming mode
delta_output = output.outputs[0]
self._accumulated_text += delta_output.text
self._accumulated_token_ids.extend(delta_output.token_ids)
if delta_output.logprobs is not None:
self._accumulated_logprobs.extend(delta_output.logprobs)
if len(self.input_messages) == 0:
output_prompt = output.prompt or ""
output_prompt_token_ids = output.prompt_token_ids or []
......@@ -194,11 +209,26 @@ class SimpleContext(ConversationContext):
)
self.output_messages.append(
ResponseRawMessageAndToken(
message=output.outputs[0].text,
tokens=output.outputs[0].token_ids,
message=delta_output.text,
tokens=delta_output.token_ids,
)
)
@property
def final_output(self) -> RequestOutput | None:
    """Final output whose first item carries the accumulated text/token_ids/logprobs.

    If `last_output` is None or has no `outputs`, it is returned unchanged.
    """
    last = self.last_output
    if last is None or not last.outputs:
        return last

    assert isinstance(last, RequestOutput)
    merged = copy.copy(last)
    # Copy each inner item so the assignments below do not modify last_output.
    merged.outputs = [replace(item) for item in last.outputs]

    head = merged.outputs[0]
    head.text = self._accumulated_text
    head.token_ids = tuple(self._accumulated_token_ids)
    # Accumulated logprobs replace the item's own only when some were collected.
    if self._accumulated_logprobs:
        head.logprobs = self._accumulated_logprobs
    return merged
def append_tool_output(self, output) -> None:
raise NotImplementedError("Should not be called.")
......@@ -267,12 +297,40 @@ class ParsableContext(ConversationContext):
self.chat_template = chat_template
self.chat_template_content_format = chat_template_content_format
self.input_messages: list[ResponseRawMessageAndToken] = []
self.output_messages: list[ResponseRawMessageAndToken] = []
def append_output(self, output: RequestOutput) -> None:
self.num_prompt_tokens = len(output.prompt_token_ids or [])
self.num_cached_tokens = output.num_cached_tokens or 0
self.num_output_tokens += len(output.outputs[0].token_ids or [])
self.parser.process(output.outputs[0])
# only store if enable_response_messages is True, save memory
if self.request.enable_response_messages:
output_prompt = output.prompt or ""
output_prompt_token_ids = output.prompt_token_ids or []
if len(self.input_messages) == 0:
self.input_messages.append(
ResponseRawMessageAndToken(
message=output_prompt,
tokens=output_prompt_token_ids,
)
)
else:
self.output_messages.append(
ResponseRawMessageAndToken(
message=output_prompt,
tokens=output_prompt_token_ids,
)
)
self.output_messages.append(
ResponseRawMessageAndToken(
message=output.outputs[0].text,
tokens=output.outputs[0].token_ids,
)
)
def append_tool_output(self, output: list[ResponseInputOutputItem]) -> None:
self.parser.response_messages.extend(output)
......
......@@ -18,6 +18,7 @@ from vllm.beam_search import (
create_sort_beams_key_function,
)
from vllm.config import (
AttentionConfig,
CompilationConfig,
PoolerConfig,
ProfilerConfig,
......@@ -175,6 +176,10 @@ class LLM:
compilation_config: Either an integer or a dictionary. If it is an
integer, it is used as the mode of compilation optimization. If it
is a dictionary, it can specify the full compilation configuration.
attention_config: Configuration for attention mechanisms. Can be a
dictionary or an AttentionConfig instance. If a dictionary, it will
be converted to an AttentionConfig. Allows specifying the attention
backend and other attention-related settings.
**kwargs: Arguments for [`EngineArgs`][vllm.EngineArgs].
Note:
......@@ -213,6 +218,7 @@ class LLM:
| StructuredOutputsConfig
| None = None,
profiler_config: dict[str, Any] | ProfilerConfig | None = None,
attention_config: dict[str, Any] | AttentionConfig | None = None,
kv_cache_memory_bytes: int | None = None,
compilation_config: int | dict[str, Any] | CompilationConfig | None = None,
logits_processors: list[str | type[LogitsProcessor]] | None = None,
......@@ -252,51 +258,28 @@ class LLM:
if hf_overrides is None:
hf_overrides = {}
if compilation_config is not None:
if isinstance(compilation_config, int):
compilation_config_instance = CompilationConfig(
mode=CompilationMode(compilation_config)
)
elif isinstance(compilation_config, dict):
compilation_config_instance = CompilationConfig(
**{
k: v
for k, v in compilation_config.items()
if is_init_field(CompilationConfig, k)
}
)
else:
compilation_config_instance = compilation_config
else:
compilation_config_instance = CompilationConfig()
if structured_outputs_config is not None:
if isinstance(structured_outputs_config, dict):
structured_outputs_instance = StructuredOutputsConfig(
**{
k: v
for k, v in structured_outputs_config.items()
if is_init_field(StructuredOutputsConfig, k)
}
)
else:
structured_outputs_instance = structured_outputs_config
else:
structured_outputs_instance = StructuredOutputsConfig()
if profiler_config is not None:
if isinstance(profiler_config, dict):
profiler_config_instance = ProfilerConfig(
**{
k: v
for k, v in profiler_config.items()
if is_init_field(ProfilerConfig, k)
}
)
else:
profiler_config_instance = profiler_config
def _make_config(value: Any, cls: type[_R]) -> _R:
    """Normalize `value` into an instance of the config class `cls`.

    `None` gives a default-constructed `cls()`; a dict gives `cls` built from
    the entries whose key passes `is_init_field(cls, key)`; any other value
    is returned as-is.
    """
    if isinstance(value, dict):
        init_kwargs = {
            name: item for name, item in value.items() if is_init_field(cls, name)
        }
        return cls(**init_kwargs)  # type: ignore[arg-type]
    return cls() if value is None else value
if isinstance(compilation_config, int):
compilation_config_instance = CompilationConfig(
mode=CompilationMode(compilation_config)
)
else:
profiler_config_instance = ProfilerConfig()
compilation_config_instance = _make_config(
compilation_config, CompilationConfig
)
structured_outputs_instance = _make_config(
structured_outputs_config, StructuredOutputsConfig
)
profiler_config_instance = _make_config(profiler_config, ProfilerConfig)
attention_config_instance = _make_config(attention_config, AttentionConfig)
# warn about single-process data parallel usage.
_dp_size = int(kwargs.get("data_parallel_size", 1))
......@@ -341,6 +324,7 @@ class LLM:
pooler_config=pooler_config,
structured_outputs_config=structured_outputs_instance,
profiler_config=profiler_config_instance,
attention_config=attention_config_instance,
compilation_config=compilation_config_instance,
logits_processors=logits_processors,
**kwargs,
......
......@@ -17,21 +17,20 @@ from argparse import Namespace
from collections.abc import AsyncGenerator, AsyncIterator, Awaitable
from contextlib import asynccontextmanager
from http import HTTPStatus
from typing import Annotated, Any, Literal
from typing import Annotated, Any
import model_hosting_container_standards.sagemaker as sagemaker_standards
import pydantic
import uvloop
from fastapi import APIRouter, Depends, FastAPI, Form, HTTPException, Query, Request
from fastapi import APIRouter, Depends, FastAPI, Form, HTTPException, Request
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, Response, StreamingResponse
from fastapi.responses import JSONResponse, StreamingResponse
from starlette.concurrency import iterate_in_threadpool
from starlette.datastructures import URL, Headers, MutableHeaders, State
from starlette.types import ASGIApp, Message, Receive, Scope, Send
import vllm.envs as envs
from vllm.config import VllmConfig
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.anthropic.protocol import (
......@@ -639,97 +638,6 @@ async def create_translations(
return StreamingResponse(content=generator, media_type="text/event-stream")
# Development-only endpoints. Everything below is registered on `router` only
# when the VLLM_SERVER_DEV_MODE environment flag is set, so the nesting under
# this `if` is what keeps these routes out of production servers.
if envs.VLLM_SERVER_DEV_MODE:
    logger.warning(
        "SECURITY WARNING: Development endpoints are enabled! "
        "This should NOT be used in production!"
    )

    # Built once at import time and reused by every /server_info request.
    PydanticVllmConfig = pydantic.TypeAdapter(VllmConfig)

    @router.get("/server_info")
    async def show_server_info(
        raw_request: Request,
        config_format: Annotated[Literal["text", "json"], Query()] = "text",
    ):
        """
        Return the server's `VllmConfig` under the `vllm_config` key.

        With `config_format=text` (the default) the config is rendered with
        `str()`; with `config_format=json` it is dumped through pydantic in
        JSON mode.
        """
        vllm_config: VllmConfig = raw_request.app.state.vllm_config
        if config_format == "text":
            rendered = str(vllm_config)
        else:
            # fallback=str is needed to handle e.g. torch.dtype
            rendered = PydanticVllmConfig.dump_python(
                vllm_config, mode="json", fallback=str
            )
        return JSONResponse(content={"vllm_config": rendered})

    @router.post("/reset_prefix_cache")
    async def reset_prefix_cache(
        raw_request: Request,
        reset_running_requests: bool = Query(default=False),
        reset_external: bool = Query(default=False),
    ):
        """
        Reset the local prefix cache.

        Optionally, if the query parameter `reset_external=true`
        also resets the external (connector-managed) prefix cache.

        Note that we currently do not check if the prefix cache
        is successfully reset in the API server.

        Example:
            POST /reset_prefix_cache?reset_external=true
        """
        logger.info("Resetting prefix cache...")
        # Both flags are forwarded positionally, unchanged, to the engine
        # client; the result of the call is intentionally not inspected.
        await engine_client(raw_request).reset_prefix_cache(
            reset_running_requests, reset_external
        )
        return Response(status_code=200)

    @router.post("/reset_mm_cache")
    async def reset_mm_cache(raw_request: Request):
        """
        Reset the multi-modal cache. Note that we currently do not check if the
        multi-modal cache is successfully reset in the API server.
        """
        logger.info("Resetting multi-modal cache...")
        await engine_client(raw_request).reset_mm_cache()
        return Response(status_code=200)

    @router.post("/collective_rpc")
    async def collective_rpc(raw_request: Request):
        """
        Invoke `collective_rpc` on the engine client.

        The request body must be a JSON object with a required `method` and
        optional `args` (list), `kwargs` (object) and `timeout`. Responds with
        400 when the body is not valid JSON, is not a JSON object, or lacks
        `method`. Responds with an empty 200 when the engine returns `None`;
        otherwise with `{"results": [...]}` where every result that is not
        `None`, a dict or a list is converted with `str()`.
        """
        try:
            body = await raw_request.json()
        except json.JSONDecodeError as e:
            raise HTTPException(
                status_code=HTTPStatus.BAD_REQUEST.value,
                detail=f"JSON decode error: {e}",
            ) from e

        # Valid JSON is not necessarily an object: `[]`, `"x"` or `42` would
        # make the `.get()` calls below fail with an AttributeError (HTTP 500).
        if not isinstance(body, dict):
            raise HTTPException(
                status_code=HTTPStatus.BAD_REQUEST.value,
                detail="Request body must be a JSON object",
            )

        method = body.get("method")
        if method is None:
            raise HTTPException(
                status_code=HTTPStatus.BAD_REQUEST.value,
                detail="Missing 'method' in request body",
            )

        # For security reason, only serialized string args/kwargs are passed.
        # User-defined `method` is responsible for deserialization if needed.
        args: list[str] = body.get("args", [])
        kwargs: dict[str, str] = body.get("kwargs", {})
        timeout: float | None = body.get("timeout")
        results = await engine_client(raw_request).collective_rpc(
            method=method, timeout=timeout, args=tuple(args), kwargs=kwargs
        )
        if results is None:
            return Response(status_code=200)

        # Pass None/dict/list results through as-is and stringify the rest so
        # the payload handed to JSONResponse stays serializable.
        response: list[Any] = [
            result
            if result is None or isinstance(result, dict | list)
            else str(result)
            for result in results
        ]
        return JSONResponse(content={"results": response})
def load_log_config(log_config_file: str | None) -> dict | None:
if not log_config_file:
return None
......@@ -1174,6 +1082,9 @@ async def init_app_state(
if "generate" in supported_tasks
else None
)
# Warm up chat template processing to avoid first-request latency
if state.openai_serving_chat is not None:
await state.openai_serving_chat.warmup()
state.openai_serving_completion = (
OpenAIServingCompletion(
engine_client,
......
......@@ -3,7 +3,11 @@
import logging
from collections.abc import Callable
from openai.types.responses.response_function_tool_call import ResponseFunctionToolCall
from openai.types.responses import ResponseFunctionToolCall, ResponseOutputItem
from openai.types.responses.response_function_tool_call_output_item import (
ResponseFunctionToolCallOutputItem,
)
from openai.types.responses.response_output_item import McpCall
from openai.types.responses.response_output_message import ResponseOutputMessage
from openai.types.responses.response_output_text import ResponseOutputText
from openai.types.responses.response_reasoning_item import (
......@@ -11,6 +15,7 @@ from openai.types.responses.response_reasoning_item import (
ResponseReasoningItem,
)
from vllm.entrypoints.constants import MCP_PREFIX
from vllm.entrypoints.openai.protocol import ResponseInputOutputItem, ResponsesRequest
from vllm.outputs import CompletionOutput
from vllm.reasoning.abs_reasoning_parsers import ReasoningParser
......@@ -111,6 +116,37 @@ class ResponsesParser:
return self
def make_response_output_items_from_parsable_context(
    self,
) -> list[ResponseOutputItem]:
    """Build the response output items from the parsed response messages.

    Only the entries of ``self.response_messages`` after the first
    ``self.num_init_messages`` are considered. Any message that is not a
    ``ResponseFunctionToolCallOutputItem`` is emitted unchanged. A tool-call
    output is folded into the item emitted just before it: when that item is
    a ``ResponseFunctionToolCall`` it is replaced by a completed ``McpCall``
    carrying the call's name and arguments plus the output; when it is any
    other type, the tool-call output is not emitted.

    Raises:
        ValueError: If a tool-call output is encountered before any item
            has been emitted.
    """
    items: list[ResponseOutputItem] = []
    for msg in self.response_messages[self.num_init_messages :]:
        # Everything except a tool-call output is passed straight through.
        if not isinstance(msg, ResponseFunctionToolCallOutputItem):
            items.append(msg)
            continue

        if not items:
            raise ValueError(
                "Cannot have a FunctionToolCallOutput before FunctionToolCall."
            )

        previous = items[-1]
        if not isinstance(previous, ResponseFunctionToolCall):
            continue

        # Merge the call and its output into one MCP call item, in place.
        items[-1] = McpCall(
            id=f"{MCP_PREFIX}{random_uuid()}",
            arguments=previous.arguments,
            name=previous.name,
            server_label=previous.name,  # TODO: store the server label
            type="mcp_call",
            status="completed",
            output=msg.output,
            # TODO: support error output
        )
    return items
def get_responses_parser_for_simple_context(
*,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment