Commit d76fc11e authored by zhuwenwen
Browse files

Merge tag 'v0.15.0rc1' into v0.15.0rc1-dev

parents 38166ec4 58996f35
......@@ -459,7 +459,7 @@ def test_eagle_correctness(
from packaging.version import Version
installed = Version(transformers.__version__)
required = Version("5.0.0.dev")
required = Version("5.0.0")
if installed < required:
pytest.skip(
"Eagle3 with the Transformers modeling backend requires "
......
......@@ -3167,13 +3167,13 @@ if hasattr(torch.ops._C, "int8_scaled_mm_with_quant"):
class CPUDNNLGEMMHandler:
def __init__(self) -> None:
self.handler: int | None = None
self.handler_tensor: torch.Tensor | None = None
self.n = -1
self.k = -1
def __del__(self):
if self.handler is not None:
torch.ops._C.release_dnnl_matmul_handler(self.handler)
if self.handler_tensor is not None:
torch.ops._C.release_dnnl_matmul_handler(self.handler_tensor.item())
_supports_onednn = bool(hasattr(torch.ops._C, "create_onednn_mm_handler"))
......@@ -3189,8 +3189,10 @@ def create_onednn_mm(
) -> CPUDNNLGEMMHandler:
handler = CPUDNNLGEMMHandler()
handler.k, handler.n = weight.size()
handler.handler = torch.ops._C.create_onednn_mm_handler(
weight, primitive_cache_size
# store the handler pointer in a tensor so it doesn't get inlined
handler.handler_tensor = torch.tensor(
torch.ops._C.create_onednn_mm_handler(weight, primitive_cache_size),
dtype=torch.int64,
)
return handler
......@@ -3202,7 +3204,7 @@ def onednn_mm(
) -> torch.Tensor:
output = torch.empty((*x.shape[0:-1], dnnl_handler.n), dtype=x.dtype)
torch.ops._C.onednn_mm(
output, x.reshape(-1, dnnl_handler.k), bias, dnnl_handler.handler
output, x.reshape(-1, dnnl_handler.k), bias, dnnl_handler.handler_tensor
)
return output
......@@ -3218,8 +3220,17 @@ def create_onednn_scaled_mm(
) -> CPUDNNLGEMMHandler:
handler = CPUDNNLGEMMHandler()
handler.k, handler.n = weight.size()
handler.handler = torch.ops._C.create_onednn_scaled_mm_handler(
weight, weight_scales, output_type, dynamic_quant, use_azp, primitive_cache_size
# store the handler pointer in a tensor so it doesn't get inlined
handler.handler_tensor = torch.tensor(
torch.ops._C.create_onednn_scaled_mm_handler(
weight,
weight_scales,
output_type,
dynamic_quant,
use_azp,
primitive_cache_size,
),
dtype=torch.int64,
)
return handler
......@@ -3272,7 +3283,13 @@ def onednn_scaled_mm(
bias: torch.Tensor | None,
) -> torch.Tensor:
torch.ops._C.onednn_scaled_mm(
output, x, input_scale, input_zp, input_zp_adj, bias, dnnl_handler.handler
output,
x,
input_scale,
input_zp,
input_zp_adj,
bias,
dnnl_handler.handler_tensor,
)
return output
......
......@@ -281,9 +281,10 @@ class DynamicShapesConfig:
until this change picked up https://github.com/pytorch/pytorch/pull/169239.
"""
assume_32_bit_indexing: bool = True
assume_32_bit_indexing: bool = False
"""
whether all tensor sizes can use 32 bit indexing.
`True` requires PyTorch 2.10+
"""
def compute_hash(self) -> str:
......
......@@ -34,6 +34,7 @@ MTPModelTypes = Literal[
"mimo_mtp",
"glm4_moe_mtp",
"glm4_moe_lite_mtp",
"glm_ocr_mtp",
"ernie_mtp",
"exaone_moe_mtp",
"qwen3_next_mtp",
......@@ -221,6 +222,17 @@ class SpeculativeConfig:
}
)
if hf_config.architectures[0] == "GlmOcrForConditionalGeneration":
hf_config.model_type = "glm_ocr_mtp"
n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
hf_config.update(
{
"num_hidden_layers": 0,
"n_predict": n_predict,
"architectures": ["GlmOcrMTPModel"],
}
)
if hf_config.model_type == "ernie4_5_moe":
hf_config.model_type = "ernie_mtp"
if hf_config.model_type == "ernie_mtp":
......
......@@ -59,7 +59,7 @@ class NaiveAll2AllManager(All2AllManagerBase):
return buffer
def dispatch(
def dispatch_router_logits(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
......@@ -84,6 +84,34 @@ class NaiveAll2AllManager(All2AllManagerBase):
return hidden_states, router_logits
def dispatch(
self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
if extra_tensors is not None:
raise NotImplementedError(
"extra_tensors is not supported for NaiveAll2AllManager"
)
sp_size = self.tp_group.world_size if is_sequence_parallel else 1
dp_metadata = get_forward_context().dp_metadata
assert dp_metadata is not None
cu_tokens_across_sp_cpu = dp_metadata.cu_tokens_across_sp(sp_size)
hidden_states = self.naive_multicast(
hidden_states, cu_tokens_across_sp_cpu, is_sequence_parallel
)
topk_weights = self.naive_multicast(
topk_weights, cu_tokens_across_sp_cpu, is_sequence_parallel
)
topk_ids = self.naive_multicast(
topk_ids, cu_tokens_across_sp_cpu, is_sequence_parallel
)
return hidden_states, topk_weights, topk_ids
def combine(
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
) -> torch.Tensor:
......@@ -114,7 +142,7 @@ class AgRsAll2AllManager(All2AllManagerBase):
def __init__(self, cpu_group):
super().__init__(cpu_group)
def dispatch(
def dispatch_router_logits(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
......@@ -148,6 +176,46 @@ class AgRsAll2AllManager(All2AllManagerBase):
return (gathered_tensors[0], gathered_tensors[1], gathered_tensors[2:])
return gathered_tensors[0], gathered_tensors[1]
def dispatch(
    self,
    hidden_states: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    is_sequence_parallel: bool = False,
    extra_tensors: list[torch.Tensor] | None = None,
) -> (
    tuple[torch.Tensor, torch.Tensor, torch.Tensor]
    | tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
    """
    Gather hidden_states, topk_weights and topk_ids with one all_gatherv.

    The gather runs along dim 0 on the EP group when `is_sequence_parallel`
    is set and on the DP group otherwise. Any `extra_tensors` are gathered
    in the same call.

    Returns:
        `(hidden_states, topk_weights, topk_ids)`, followed by the list of
        gathered extra tensors when `extra_tensors` was provided.
    """
    dp_metadata = get_forward_context().dp_metadata
    assert dp_metadata is not None
    sizes = dp_metadata.get_chunk_sizes_across_dp_rank()
    assert sizes is not None

    if is_sequence_parallel:
        dist_group = get_ep_group()
    else:
        dist_group = get_dp_group()
    # The local chunk must match the size recorded for this rank.
    assert sizes[dist_group.rank_in_group] == hidden_states.shape[0]

    tensors_to_gather = [hidden_states, topk_weights, topk_ids]
    if extra_tensors is not None:
        tensors_to_gather += extra_tensors

    gathered = dist_group.all_gatherv(tensors_to_gather, dim=0, sizes=sizes)
    core = (gathered[0], gathered[1], gathered[2])
    if extra_tensors is None:
        return core
    # Everything after the three core tensors belongs to extra_tensors.
    return (*core, gathered[3:])
def combine(
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
) -> torch.Tensor:
......@@ -216,7 +284,7 @@ class PPLXAll2AllManager(All2AllManagerBase):
pplx.AllToAll.internode if self.internode else pplx.AllToAll.intranode,
)
def dispatch(
def dispatch_router_logits(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
......@@ -225,6 +293,19 @@ class PPLXAll2AllManager(All2AllManagerBase):
) -> tuple[torch.Tensor, torch.Tensor]:
raise NotImplementedError
def dispatch(
self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
) -> (
tuple[torch.Tensor, torch.Tensor, torch.Tensor]
| tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
raise NotImplementedError
def combine(
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
) -> torch.Tensor:
......@@ -264,7 +345,7 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase):
def get_handle(self, kwargs):
raise NotImplementedError
def dispatch(
def dispatch_router_logits(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
......@@ -273,6 +354,19 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase):
) -> tuple[torch.Tensor, torch.Tensor]:
raise NotImplementedError
def dispatch(
self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
) -> (
tuple[torch.Tensor, torch.Tensor, torch.Tensor]
| tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
raise NotImplementedError
def combine(
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
) -> torch.Tensor:
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import threading
from typing import Any
from weakref import WeakValueDictionary
import torch
......@@ -64,13 +63,32 @@ class All2AllManagerBase:
# and reuse it for the same config.
raise NotImplementedError
def dispatch(
def dispatch_router_logits(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
) -> Any:
) -> (
tuple[torch.Tensor, torch.Tensor]
| tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
# Subclasses should either:
# - implement handling for extra_tensors, or
# - raise a clear error if extra_tensors is not supported.
raise NotImplementedError
def dispatch(
self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
) -> (
tuple[torch.Tensor, torch.Tensor, torch.Tensor]
| tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
# Subclasses should either:
# - implement handling for extra_tensors, or
# - raise a clear error if extra_tensors is not supported.
......@@ -280,7 +298,7 @@ class DeviceCommunicatorBase:
for module in moe_modules:
module.maybe_init_modular_kernel()
def dispatch(
def dispatch_router_logits(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
......@@ -294,8 +312,29 @@ class DeviceCommunicatorBase:
Dispatch the hidden states and router logits to the appropriate device.
This is a no-op in the base class.
"""
if extra_tensors is not None:
return hidden_states, router_logits, extra_tensors
return hidden_states, router_logits
def dispatch(
self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
) -> (
tuple[torch.Tensor, torch.Tensor, torch.Tensor]
| tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
"""
Dispatch the hidden states and topk weights/ids to the appropriate device.
This is a no-op in the base class.
"""
if extra_tensors is not None:
return hidden_states, topk_weights, topk_ids, extra_tensors
return hidden_states, topk_weights, topk_ids
def combine(
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
) -> torch.Tensor:
......
......@@ -130,29 +130,65 @@ class CpuCommunicator(DeviceCommunicatorBase):
) -> dict[str, torch.Tensor | Any]:
return self.dist_module.recv_tensor_dict(src)
def dispatch( # type: ignore[override]
def dispatch_router_logits(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
) -> (
tuple[torch.Tensor, torch.Tensor]
| tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
"""
Dispatch the hidden states and router logits to the appropriate device.
Delegates to the all2all manager, which must be set.
"""
assert self.all2all_manager is not None
return self.all2all_manager.dispatch(
return self.all2all_manager.dispatch_router_logits(
hidden_states,
router_logits,
is_sequence_parallel,
extra_tensors, # type: ignore[call-arg]
extra_tensors,
)
def dispatch(
self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
) -> (
tuple[torch.Tensor, torch.Tensor, torch.Tensor]
| tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
"""
Dispatch the hidden states and topk weights/ids to the appropriate device.
This is a no-op in the base class.
"""
assert self.all2all_manager is not None
return self.all2all_manager.dispatch(
hidden_states,
topk_weights,
topk_ids,
is_sequence_parallel,
extra_tensors=extra_tensors,
)
def combine(
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
) -> torch.Tensor:
"""
Combine the hidden states from the appropriate device.
Delegates to the all2all manager, which must be set.
"""
assert self.all2all_manager is not None
hidden_states = self.all2all_manager.combine(
hidden_states, is_sequence_parallel
return self.all2all_manager.combine(
hidden_states,
is_sequence_parallel,
)
return hidden_states
class _CPUSHMDistributed:
......
......@@ -322,7 +322,7 @@ class CudaCommunicator(DeviceCommunicatorBase):
return output_list
def dispatch( # type: ignore[override]
def dispatch_router_logits(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
......@@ -332,19 +332,52 @@ class CudaCommunicator(DeviceCommunicatorBase):
tuple[torch.Tensor, torch.Tensor]
| tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
"""
Dispatch the hidden states and router logits to the appropriate device.
Delegates to the all2all manager, which must be set.
"""
assert self.all2all_manager is not None
return self.all2all_manager.dispatch(
return self.all2all_manager.dispatch_router_logits(
hidden_states,
router_logits,
is_sequence_parallel,
extra_tensors, # type: ignore[call-arg]
extra_tensors,
)
def dispatch(
self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
) -> (
tuple[torch.Tensor, torch.Tensor, torch.Tensor]
| tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
"""
Dispatch the hidden states and topk weights/ids to the appropriate device.
This is a no-op in the base class.
"""
assert self.all2all_manager is not None
return self.all2all_manager.dispatch(
hidden_states,
topk_weights,
topk_ids,
is_sequence_parallel,
extra_tensors=extra_tensors,
)
def combine(
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
) -> torch.Tensor:
"""
Combine the hidden states from the appropriate device.
Delegates to the all2all manager, which must be set.
"""
assert self.all2all_manager is not None
hidden_states = self.all2all_manager.combine(
hidden_states, is_sequence_parallel
return self.all2all_manager.combine(
hidden_states,
is_sequence_parallel,
)
return hidden_states
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any
import torch.distributed as dist
from flashinfer.comm.mnnvl import CommBackend as CommBackend
......@@ -23,5 +25,14 @@ class CustomCommunicator(CommBackend):
dist.all_gather_object(gathered, data, group=self._group)
return gathered
# NOTE(rob): CommBackend is an abstract class, and bcast/barrier
# are unimplemented on vLLM side. If we need to utilize these
# methods in the future, can create a concrete implementation.
def bcast(self, data: Any, root: int) -> Any:
    """Broadcast hook of the comm backend; not implemented here.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError
def barrier(self) -> None:
    """Barrier hook of the comm backend; not implemented here.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError
def Split(self, color: int, key: int) -> "CustomCommunicator":
    """Return this communicator itself; `color` and `key` are ignored."""
    return self
......@@ -196,26 +196,62 @@ class XpuCommunicator(DeviceCommunicatorBase):
def broadcast(self, input_: torch.Tensor, src: int = 0) -> None:
dist.broadcast(input_, src=src, group=self.device_group)
def dispatch(
def dispatch_router_logits(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
) -> (
tuple[torch.Tensor, torch.Tensor]
| tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
"""
Dispatch the hidden states and router logits to the appropriate device.
Delegates to the all2all manager, which must be set.
"""
assert self.all2all_manager is not None
return self.all2all_manager.dispatch(
return self.all2all_manager.dispatch_router_logits(
hidden_states,
router_logits,
is_sequence_parallel,
extra_tensors, # type: ignore[call-arg]
extra_tensors,
)
def dispatch(
self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
) -> (
tuple[torch.Tensor, torch.Tensor, torch.Tensor]
| tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
"""
Dispatch the hidden states and topk weights/ids to the appropriate device.
This is a no-op in the base class.
"""
assert self.all2all_manager is not None
return self.all2all_manager.dispatch(
hidden_states,
topk_weights,
topk_ids,
is_sequence_parallel,
extra_tensors=extra_tensors,
)
def combine(
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
) -> torch.Tensor:
"""
Combine the hidden states from the appropriate device.
Delegates to the all2all manager, which must be set.
"""
assert self.all2all_manager is not None
hidden_states = self.all2all_manager.combine(
hidden_states, is_sequence_parallel
return self.all2all_manager.combine(
hidden_states,
is_sequence_parallel,
)
return hidden_states
......@@ -302,7 +302,7 @@ class NixlConnector(KVConnectorBase_V1):
@property
def prefer_cross_layer_blocks(self) -> bool:
backend = get_current_attn_backend(self._vllm_config)
if backend().get_name() not in (
if backend.get_name() not in (
"FLASH_ATTN",
"FLASHINFER",
):
......
......@@ -1003,7 +1003,7 @@ class GroupCoordinator:
if self.device_communicator is not None:
self.device_communicator.prepare_communication_buffer_for_model(model)
def dispatch(
def dispatch_router_logits(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
......@@ -1014,7 +1014,7 @@ class GroupCoordinator:
| tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
if self.device_communicator is not None:
return self.device_communicator.dispatch( # type: ignore[call-arg]
return self.device_communicator.dispatch_router_logits(
hidden_states,
router_logits,
is_sequence_parallel,
......@@ -1023,6 +1023,28 @@ class GroupCoordinator:
else:
return hidden_states, router_logits
def dispatch(
self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
) -> (
tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
| tuple[torch.Tensor, torch.Tensor, torch.Tensor]
):
if self.device_communicator is not None:
return self.device_communicator.dispatch(
hidden_states,
topk_weights,
topk_ids,
is_sequence_parallel,
extra_tensors,
)
else:
return hidden_states, topk_weights, topk_ids
def combine(
self, hidden_states, is_sequence_parallel: bool = False
) -> torch.Tensor:
......
......@@ -46,6 +46,9 @@ from vllm.multimodal.inputs import (
MultiModalBatchedField,
MultiModalFlatField,
MultiModalSharedField,
VisionChunk,
VisionChunkImage,
VisionChunkVideo,
)
from vllm.multimodal.processing import BaseMultiModalProcessor
from vllm.multimodal.utils import MEDIA_CONNECTOR_REGISTRY, MediaConnector
......@@ -336,7 +339,9 @@ ChatTemplateContentFormatOption = Literal["auto", "string", "openai"]
ChatTemplateContentFormat = Literal["string", "openai"]
ModalityStr = Literal["image", "audio", "video", "image_embeds", "audio_embeds"]
ModalityStr = Literal[
"image", "audio", "video", "image_embeds", "audio_embeds", "vision_chunk"
]
_T = TypeVar("_T")
......@@ -449,6 +454,78 @@ def _get_embeds_data(
raise NotImplementedError(type(data_items))
def rebuild_mm_uuids_from_mm_data(
    mm_uuids: MultiModalUUIDDict,
    mm_data: MultiModalDataDict,
) -> MultiModalUUIDDict:
    """Rebuild mm_uuids from the uuids carried by vision_chunk items.

    Args:
        mm_uuids: Original UUIDs dictionary
        mm_data: Processed multimodal data, possibly with vision_chunk items

    Returns:
        `mm_uuids` itself when `mm_data` has no "vision_chunk" entry.
        Otherwise a shallow copy whose "vision_chunk" entry is replaced by
        the non-None chunk uuids (left untouched if there are none).
    """
    chunks = mm_data.get("vision_chunk")
    if chunks is None:
        return mm_uuids

    chunk_uuids = []
    for chunk in chunks:
        # vision_chunk items are always dicts (VisionChunkImage/VisionChunkVideo)
        assert isinstance(chunk, dict)
        if (chunk_uuid := chunk.get("uuid")) is not None:
            chunk_uuids.append(chunk_uuid)

    # Copy so the caller's dictionary is never mutated.
    rebuilt = dict(mm_uuids)
    if chunk_uuids:
        rebuilt["vision_chunk"] = chunk_uuids
    return rebuilt
def build_video_prompts_from_mm_data(
    mm_data: MultiModalDataDict,
) -> list[str]:
    """Build one prompt per video from vision_chunk data.

    Items whose "type" is "video_chunk" are grouped by "video_idx"
    (default 0); within a group the "prompt" values (default "") are
    concatenated in encounter order.

    Args:
        mm_data: Processed multimodal data, possibly with vision_chunk items

    Returns:
        List of video prompts ordered by ascending video_idx; empty when
        `mm_data` has no "vision_chunk" entry.
    """
    chunks = mm_data.get("vision_chunk")
    if chunks is None:
        return []

    prompts_by_video: dict[int, list[str]] = defaultdict(list)
    for chunk in chunks:
        # vision_chunk items are always dicts (VisionChunkImage/VisionChunkVideo)
        assert isinstance(chunk, dict)
        if chunk.get("type") != "video_chunk":
            continue
        prompts_by_video[chunk.get("video_idx", 0)].append(chunk.get("prompt", ""))

    return ["".join(prompts_by_video[idx]) for idx in sorted(prompts_by_video)]
class BaseMultiModalItemTracker(ABC, Generic[_T]):
"""
Tracks multi-modal items in a given request and ensures that the number
......@@ -462,6 +539,13 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
self._model_config = model_config
self._items_by_modality = defaultdict[str, list[_T]](list)
# Track original modality for each vision_chunk item (image or video)
self._modality_order = defaultdict[str, list[str]](list)
@cached_property
def use_unified_vision_chunk_modality(self) -> bool:
    """Whether the HF config sets `use_unified_vision_chunk` (default False)."""
    hf_config = self._model_config.hf_config
    return getattr(hf_config, "use_unified_vision_chunk", False)
@property
def model_config(self) -> ModelConfig:
......@@ -499,11 +583,31 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
media.
"""
input_modality = modality.replace("_embeds", "")
num_items = len(self._items_by_modality[modality]) + 1
original_modality = modality
use_vision_chunk = (
self.use_unified_vision_chunk_modality
and original_modality in ["video", "image"]
)
# If use_unified_vision_chunk_modality is enabled,
# map image/video to vision_chunk
if use_vision_chunk:
# To avoid validation fail
# because models with use_unified_vision_chunk_modality=True
# will only accept vision_chunk modality.
input_modality = "vision_chunk"
num_items = len(self._items_by_modality[input_modality]) + 1
else:
num_items = len(self._items_by_modality[original_modality]) + 1
self.mm_processor.validate_num_items(input_modality, num_items)
self._items_by_modality[modality].append(item)
# Track original modality for vision_chunk items
if use_vision_chunk:
self._items_by_modality[input_modality].append(item) # type: ignore
self._modality_order["vision_chunk"].append(original_modality)
else:
self._items_by_modality[original_modality].append(item)
return self.model_cls.get_placeholder_str(modality, num_items)
......@@ -515,6 +619,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
def _resolve_items(
items_by_modality: dict[str, list[tuple[object, str | None]]],
mm_processor: BaseMultiModalProcessor,
vision_chunk_modality_order: dict[str, list[str]],
) -> tuple[MultiModalDataDict, MultiModalUUIDDict]:
if "image" in items_by_modality and "image_embeds" in items_by_modality:
raise ValueError("Mixing raw image and embedding inputs is not allowed")
......@@ -546,6 +651,74 @@ def _resolve_items(
if "video" in items_by_modality:
mm_data["video"] = [data for data, uuid in items_by_modality["video"]]
mm_uuids["video"] = [uuid for data, uuid in items_by_modality["video"]]
if "vision_chunk" in items_by_modality:
# Process vision_chunk items - extract from (data, modality) tuples
# and convert to VisionChunk types with proper UUID handling
vision_chunk_items = items_by_modality["vision_chunk"]
modality_order = vision_chunk_modality_order.get("vision_chunk", [])
mm_uuids["vision_chunk"] = [
uuid for data, uuid in items_by_modality["vision_chunk"]
]
# Filter out None items (from asyncio.sleep(0) placeholders)
filtered_items = [
(idx, item)
for idx, item in enumerate(vision_chunk_items)
if item is not None
]
assert len(filtered_items) == len(modality_order), (
f"vision_chunk items ({len(filtered_items)}) and "
f"modality_order ({len(modality_order)}) must have same length"
)
processed_chunks: list[VisionChunk] = []
video_idx = 0
for i, (idx, item) in enumerate(filtered_items):
inner_modality = modality_order[i]
data, uuid = item
uuid_val = uuid if idx < len(mm_uuids["vision_chunk"]) else None
if inner_modality == "image":
# Cast data to proper type for image
# Use .media (PIL.Image) directly to avoid redundant
# bytes→PIL conversion in media_processor
if hasattr(data, "media"):
image_data = data.media # type: ignore[union-attr]
processed_chunks.append(
VisionChunkImage(type="image", image=image_data, uuid=uuid_val)
)
else:
processed_chunks.append(data) # type: ignore[arg-type]
elif inner_modality == "video":
# For video, we may need to split into chunks
# if processor supports it
# For now, just wrap as a video chunk placeholder
if hasattr(mm_processor, "split_video_chunks") and data is not None:
try:
video_uuid = uuid_val or random_uuid()
# video await result is (video_data, video_meta) tuple
if isinstance(data, tuple) and len(data) >= 1:
video_data = data[0]
else:
video_data = data
video_chunks = mm_processor.split_video_chunks(video_data)
for i, vc in enumerate(video_chunks):
processed_chunks.append(
VisionChunkVideo(
type="video_chunk",
video_chunk=vc["video_chunk"],
uuid=f"{video_uuid}-{i}",
video_idx=video_idx,
prompt=vc["prompt"],
)
)
video_idx += 1
except Exception as e:
logger.warning("Failed to split video chunks: %s", e)
processed_chunks.append(data) # type: ignore[arg-type]
else:
processed_chunks.append(data) # type: ignore[arg-type]
mm_data["vision_chunk"] = processed_chunks
return mm_data, mm_uuids
......@@ -557,7 +730,9 @@ class MultiModalItemTracker(BaseMultiModalItemTracker[tuple[object, str | None]]
if not self._items_by_modality:
return None, None
return _resolve_items(dict(self._items_by_modality), self.mm_processor)
return _resolve_items(
dict(self._items_by_modality), self.mm_processor, self._modality_order
)
def create_parser(self) -> "BaseMultiModalContentParser":
return MultiModalContentParser(self)
......@@ -577,7 +752,9 @@ class AsyncMultiModalItemTracker(
for modality, coros in self._items_by_modality.items()
}
return _resolve_items(resolved_items_by_modality, self.mm_processor)
return _resolve_items(
resolved_items_by_modality, self.mm_processor, self._modality_order
)
def create_parser(self) -> "BaseMultiModalContentParser":
return AsyncMultiModalContentParser(self)
......
......@@ -265,6 +265,39 @@ def load_log_config(log_config_file: str | None) -> dict | None:
return None
def get_uvicorn_log_config(args: Namespace) -> dict | None:
    """
    Resolve the uvicorn log config from the provided arguments.

    Priority:
    1. A config loaded from `args.log_config_file`, if that yields one
    2. A config built with the access log filter, if
       `args.disable_access_log_for_endpoints` is set
    3. Otherwise None (use uvicorn defaults)
    """
    file_config = load_log_config(args.log_config_file)
    if file_config is not None:
        return file_config

    endpoints_arg = args.disable_access_log_for_endpoints
    if not endpoints_arg:
        return None

    from vllm.logging_utils import create_uvicorn_log_config

    # Comma-separated string -> list of paths, dropping blank entries.
    stripped = (raw.strip() for raw in endpoints_arg.split(","))
    excluded_paths = [path for path in stripped if path]
    return create_uvicorn_log_config(
        excluded_paths=excluded_paths,
        log_level=args.uvicorn_log_level,
    )
class AuthenticationMiddleware:
"""
Pure ASGI middleware that authenticates each request by checking
......@@ -931,8 +964,8 @@ async def run_server_worker(
if args.reasoning_parser_plugin and len(args.reasoning_parser_plugin) > 3:
ReasoningParserManager.import_reasoning_parser(args.reasoning_parser_plugin)
# Load logging config for uvicorn if specified
log_config = load_log_config(args.log_config_file)
# Get uvicorn log config (from file or with endpoint filter)
log_config = get_uvicorn_log_config(args)
if log_config is not None:
uvicorn_kwargs["log_config"] = log_config
......
......@@ -44,6 +44,7 @@ from vllm.entrypoints.openai.engine.protocol import (
DeltaMessage,
DeltaToolCall,
ErrorResponse,
FunctionCall,
PromptTokenUsageInfo,
RequestResponseMetadata,
ToolCall,
......@@ -67,6 +68,7 @@ from vllm.entrypoints.openai.parser.harmony_utils import (
from vllm.entrypoints.openai.utils import maybe_filter_parallel_tool_calls
from vllm.entrypoints.utils import get_max_tokens, should_include_usage
from vllm.inputs.data import TokensPrompt
from vllm.inputs.parse import get_prompt_components
from vllm.logger import init_logger
from vllm.logprobs import Logprob
from vllm.outputs import CompletionOutput, RequestOutput
......@@ -143,11 +145,6 @@ class OpenAIServingChat(OpenAIServing):
self.enable_prompt_tokens_details = enable_prompt_tokens_details
self.enable_force_include_usage = enable_force_include_usage
self.default_sampling_params = self.model_config.get_diff_sampling_param()
if self.model_config.hf_config.model_type == "kimi_k2":
self.tool_call_id_type = "kimi_k2"
else:
self.tool_call_id_type = "random"
self.use_harmony = self.model_config.hf_config.model_type == "gpt_oss"
if self.use_harmony:
if "stop_token_ids" not in self.default_sampling_params:
......@@ -156,6 +153,16 @@ class OpenAIServingChat(OpenAIServing):
get_stop_tokens_for_assistant_actions()
)
# Handle tool call ID type for Kimi K2 (supporting test mocking via overrides)
hf_overrides = getattr(self.model_config, "hf_overrides", None)
if self.model_config.hf_text_config.model_type == "kimi_k2" or (
isinstance(hf_overrides, dict)
and hf_overrides.get("model_type") == "kimi_k2"
):
self.tool_call_id_type = "kimi_k2"
else:
self.tool_call_id_type = "random"
# NOTE(woosuk): While OpenAI's chat completion API supports browsing
# for some models, currently vLLM doesn't support it. Please use the
# Responses API instead.
......@@ -247,8 +254,8 @@ class OpenAIServingChat(OpenAIServing):
# because of issues with pydantic we need to potentially
# re-serialize the tool_calls field of the request
# for more info: see comment in `maybe_serialize_tool_calls`
maybe_serialize_tool_calls(request)
truncate_tool_call_ids(request)
maybe_serialize_tool_calls(request) # type: ignore[arg-type]
truncate_tool_call_ids(request) # type: ignore[arg-type]
validate_request_params(request)
# Check if tool parsing is unavailable (common condition)
......@@ -368,20 +375,18 @@ class OpenAIServingChat(OpenAIServing):
generators: list[AsyncGenerator[RequestOutput, None]] = []
try:
for i, engine_prompt in enumerate(engine_prompts):
prompt_text, _, _ = self._get_prompt_components(engine_prompt)
prompt_text, _, _ = get_prompt_components(engine_prompt)
# If we are creating sub requests for multiple prompts, ensure that they
# have unique request ids.
sub_request_id = (
request_id if len(engine_prompts) == 1 else f"{request_id}_{i}"
)
if self.default_sampling_params is None:
self.default_sampling_params = {}
max_tokens = get_max_tokens(
max_model_len=self.max_model_len,
request=request,
input_length=len(engine_prompt["prompt_token_ids"]),
prompt=engine_prompt,
default_sampling_params=self.default_sampling_params,
)
......@@ -454,6 +459,7 @@ class OpenAIServingChat(OpenAIServing):
# Streaming response
tokenizer = self.renderer.tokenizer
assert tokenizer is not None
if request.stream:
return self.chat_completion_stream_generator(
......@@ -632,9 +638,11 @@ class OpenAIServingChat(OpenAIServing):
request_id: str,
model_name: str,
conversation: list[ConversationMessage],
tokenizer: TokenizerLike | None,
tokenizer: TokenizerLike,
request_metadata: RequestResponseMetadata,
) -> AsyncGenerator[str, None]:
from vllm.tokenizers.mistral import MistralTokenizer
created_time = int(time.time())
chunk_object_type: Final = "chat.completion.chunk"
first_iteration = True
......@@ -698,7 +706,7 @@ class OpenAIServingChat(OpenAIServing):
)
reasoning_parser = self.reasoning_parser(
tokenizer,
chat_template_kwargs=chat_template_kwargs, # type: ignore[call-arg]
chat_template_kwargs=chat_template_kwargs or {}, # type: ignore[call-arg]
)
except RuntimeError as e:
logger.exception("Error in reasoning parser creation.")
......@@ -955,8 +963,17 @@ class OpenAIServingChat(OpenAIServing):
index=i,
)
else:
# Generate ID based on tokenizer type
if isinstance(tokenizer, MistralTokenizer):
tool_call_id = MistralToolCall.generate_random_id()
else:
tool_call_id = make_tool_call_id(
id_type=self.tool_call_id_type,
func_name=tool_choice_function_name,
idx=history_tool_call_cnt,
)
delta_tool_call = DeltaToolCall(
id=make_tool_call_id(),
id=tool_call_id,
type="function",
function=DeltaFunctionCall(
name=tool_choice_function_name,
......@@ -1387,9 +1404,11 @@ class OpenAIServingChat(OpenAIServing):
request_id: str,
model_name: str,
conversation: list[ConversationMessage],
tokenizer: TokenizerLike | None,
tokenizer: TokenizerLike,
request_metadata: RequestResponseMetadata,
) -> ErrorResponse | ChatCompletionResponse:
from vllm.tokenizers.mistral import MistralTokenizer
created_time = int(time.time())
final_res: RequestOutput | None = None
......@@ -1524,39 +1543,85 @@ class OpenAIServingChat(OpenAIServing):
tool_call_class = (
MistralToolCall if isinstance(tokenizer, MistralTokenizer) else ToolCall
)
if (not self.enable_auto_tools or not self.tool_parser) and (
if self.use_harmony:
# Harmony models already have parsed content and tool_calls
# through parse_chat_output. Respect its output directly.
message = ChatMessage(
role=role,
reasoning=reasoning,
content=content,
tool_calls=tool_calls if tool_calls else [],
)
elif (not self.enable_auto_tools or not self.tool_parser) and (
not isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam)
and request.tool_choice != "required"
):
message = ChatMessage(role=role, reasoning=reasoning, content=content)
# if the request uses tools and specified a tool choice
elif (
request.tool_choice
and type(request.tool_choice) is ChatCompletionNamedToolChoiceParam
):
assert tool_calls is not None and len(tool_calls) > 0
tool_call_class_items = []
for idx, tc in enumerate(tool_calls):
# Use native ID if available (e.g., Kimi K2),
# otherwise generate ID with correct id_type
if tc.id:
tool_call_class_items.append(
tool_call_class(id=tc.id, function=tc)
)
else:
# Generate ID using the correct format (kimi_k2 or random),
# but leave it to the class if it's Mistral to preserve
# 9-char IDs
if isinstance(tokenizer, MistralTokenizer):
tool_call_class_items.append(tool_call_class(function=tc))
else:
generated_id = make_tool_call_id(
id_type=self.tool_call_id_type,
func_name=tc.name,
idx=history_tool_call_cnt + idx,
)
tool_call_class_items.append(
tool_call_class(id=generated_id, function=tc)
)
history_tool_call_cnt += 1
message = ChatMessage(
role=role,
reasoning=reasoning,
content="",
tool_calls=[tool_call_class(function=tc) for tc in tool_calls],
tool_calls=tool_call_class_items,
)
elif request.tool_choice and request.tool_choice == "required":
tool_call_class_items = []
assert tool_calls is not None and len(tool_calls) > 0
for tool_call in tool_calls:
tool_call_class_items.append(
tool_call_class(
id=make_tool_call_id(
for idx, tool_call in enumerate(tool_calls):
# Use native ID if available,
# otherwise generate ID with correct id_type
if tool_call.id:
tool_call_class_items.append(
tool_call_class(id=tool_call.id, function=tool_call)
)
else:
# Generate ID using the correct format (kimi_k2 or random),
# but leave it to the class if it's Mistral to preserve
# 9-char IDs
if isinstance(tokenizer, MistralTokenizer):
tool_call_class_items.append(
tool_call_class(function=tool_call)
)
else:
generated_id = make_tool_call_id(
id_type=self.tool_call_id_type,
func_name=tool_call.name,
idx=history_tool_call_cnt,
),
function=tool_call,
)
)
idx=history_tool_call_cnt + idx,
)
tool_call_class_items.append(
tool_call_class(id=generated_id, function=tool_call)
)
history_tool_call_cnt += 1
message = ChatMessage(
role=role,
......@@ -1582,17 +1647,35 @@ class OpenAIServingChat(OpenAIServing):
# call. The same is not true for named function calls
auto_tools_called = tool_calls is not None and len(tool_calls) > 0
if tool_calls:
tool_call_items = []
for idx, tc in enumerate(tool_calls):
# Use native ID if available (e.g., Kimi K2),
# otherwise generate ID with correct id_type
if tc.id:
tool_call_items.append(
tool_call_class(id=tc.id, function=tc)
)
else:
# Generate ID using the correct format (kimi_k2 or random),
# but leave it to the class if it's Mistral to preserve
# 9-char IDs
if isinstance(tokenizer, MistralTokenizer):
tool_call_items.append(tool_call_class(function=tc))
else:
generated_id = make_tool_call_id(
id_type=self.tool_call_id_type,
func_name=tc.name,
idx=history_tool_call_cnt + idx,
)
tool_call_items.append(
tool_call_class(id=generated_id, function=tc)
)
history_tool_call_cnt += 1
message = ChatMessage(
role=role,
reasoning=reasoning,
content=content,
tool_calls=[
ToolCall(
function=tc,
type="function",
)
for tc in tool_calls
],
tool_calls=tool_call_items,
)
else:
......@@ -1701,13 +1784,11 @@ class OpenAIServingChat(OpenAIServing):
elif choice.message.tool_calls:
# For tool calls, log the function name and arguments
tool_call_descriptions = []
for tc in choice.message.tool_calls:
if hasattr(tc.function, "name") and hasattr(
tc.function, "arguments"
):
tool_call_descriptions.append(
f"{tc.function.name}({tc.function.arguments})"
)
for tc in choice.message.tool_calls: # type: ignore
function_call: FunctionCall = tc.function # type: ignore
tool_call_descriptions.append(
f"{function_call.name}({function_call.arguments})"
)
tool_calls_str = ", ".join(tool_call_descriptions)
output_text = f"[tool_calls: {tool_calls_str}]"
......@@ -1895,7 +1976,7 @@ class OpenAIServingChat(OpenAIServing):
# because of issues with pydantic we need to potentially
# re-serialize the tool_calls field of the request
# for more info: see comment in `maybe_serialize_tool_calls`
maybe_serialize_tool_calls(request)
maybe_serialize_tool_calls(request) # type: ignore[arg-type]
# Add system message.
# NOTE: In Chat Completion API, browsing is enabled by default
......@@ -1913,7 +1994,7 @@ class OpenAIServingChat(OpenAIServing):
# Add developer message.
if request.tools:
dev_msg = get_developer_message(
tools=request.tools if should_include_tools else None
tools=request.tools if should_include_tools else None # type: ignore[arg-type]
)
messages.append(dev_msg)
......
......@@ -85,6 +85,12 @@ class FrontendArgs:
"""Log level for uvicorn."""
disable_uvicorn_access_log: bool = False
"""Disable uvicorn access log."""
disable_access_log_for_endpoints: str | None = None
"""Comma-separated list of endpoint paths to exclude from uvicorn access
logs. This is useful to reduce log noise from high-frequency endpoints
like health checks. Example: "/health,/metrics,/ping".
When set, access logs for requests to these paths will be suppressed
while keeping logs for other endpoints."""
allow_credentials: bool = False
"""Allow credentials."""
allowed_origins: list[str] = field(default_factory=lambda: ["*"])
......@@ -244,6 +250,11 @@ class FrontendArgs:
del frontend_kwargs["middleware"]["nargs"]
frontend_kwargs["middleware"]["default"] = []
# Special case: disable_access_log_for_endpoints is a single
# comma-separated string, not a list
if "nargs" in frontend_kwargs["disable_access_log_for_endpoints"]:
del frontend_kwargs["disable_access_log_for_endpoints"]["nargs"]
# Special case: Tool call parser shows built-in options.
valid_tool_parsers = list(ToolParserManager.list_registered())
parsers_str = ",".join(valid_tool_parsers)
......
......@@ -36,6 +36,7 @@ from vllm.entrypoints.renderer import RenderConfig
from vllm.entrypoints.utils import get_max_tokens, should_include_usage
from vllm.exceptions import VLLMValidationError
from vllm.inputs.data import EmbedsPrompt, TokensPrompt, is_embeds_prompt
from vllm.inputs.parse import get_prompt_components
from vllm.logger import init_logger
from vllm.logprobs import Logprob
from vllm.outputs import RequestOutput
......@@ -162,25 +163,12 @@ class OpenAIServingCompletion(OpenAIServing):
generators: list[AsyncGenerator[RequestOutput, None]] = []
try:
for i, engine_prompt in enumerate(engine_prompts):
prompt_text, prompt_token_ids, prompt_embeds = (
self._get_prompt_components(engine_prompt)
)
input_length = None
if prompt_token_ids is not None:
input_length = len(prompt_token_ids)
elif prompt_embeds is not None:
input_length = len(prompt_embeds)
else:
raise NotImplementedError
if self.default_sampling_params is None:
self.default_sampling_params = {}
prompt_text, _, _ = get_prompt_components(engine_prompt)
max_tokens = get_max_tokens(
max_model_len=self.max_model_len,
request=request,
input_length=input_length,
prompt=engine_prompt,
default_sampling_params=self.default_sampling_params,
)
......
......@@ -218,6 +218,10 @@ def get_logits_processors(
class FunctionCall(OpenAIBaseModel):
# Internal field to preserve native tool call ID from tool parser.
# Excluded from serialization to maintain OpenAI API compatibility
# (function object should only contain 'name' and 'arguments').
id: str | None = Field(default=None, exclude=True)
name: str
arguments: str
......
......@@ -64,13 +64,12 @@ from vllm.entrypoints.openai.translations.protocol import (
from vllm.entrypoints.pooling.classify.protocol import (
ClassificationChatRequest,
ClassificationCompletionRequest,
ClassificationRequest,
ClassificationResponse,
)
from vllm.entrypoints.pooling.embed.protocol import (
EmbeddingBytesResponse,
EmbeddingChatRequest,
EmbeddingCompletionRequest,
EmbeddingRequest,
EmbeddingResponse,
)
from vllm.entrypoints.pooling.pooling.protocol import (
......@@ -95,11 +94,14 @@ from vllm.entrypoints.serve.tokenize.protocol import (
TokenizeCompletionRequest,
TokenizeResponse,
)
from vllm.entrypoints.utils import _validate_truncation_size, sanitize_message
from vllm.entrypoints.utils import (
_validate_truncation_size,
get_max_tokens,
sanitize_message,
)
from vllm.exceptions import VLLMValidationError
from vllm.inputs.data import PromptType, TokensPrompt
from vllm.inputs.parse import (
PromptComponents,
get_prompt_components,
is_explicit_encoder_decoder_prompt,
)
......@@ -170,6 +172,7 @@ AnyResponse: TypeAlias = (
CompletionResponse
| ChatCompletionResponse
| EmbeddingResponse
| EmbeddingBytesResponse
| TranscriptionResponse
| TokenizeResponse
| PoolingResponse
......@@ -183,51 +186,21 @@ RequestT = TypeVar("RequestT", bound=AnyRequest)
@dataclass(kw_only=True)
class RequestProcessingMixin:
"""
Mixin for request processing,
handling prompt preparation and engine input.
"""
engine_prompts: list[TokensPrompt] | None = field(default_factory=list)
@dataclass(kw_only=True)
class ResponseGenerationMixin:
"""
Mixin for response generation,
managing result generators and final batch results.
"""
result_generator: (
AsyncGenerator[tuple[int, RequestOutput | PoolingRequestOutput], None] | None
) = None
final_res_batch: list[RequestOutput | PoolingRequestOutput] = field(
default_factory=list
)
model_config = ConfigDict(arbitrary_types_allowed=True)
@dataclass(kw_only=True)
class ServeContext(RequestProcessingMixin, ResponseGenerationMixin, Generic[RequestT]):
class ServeContext(Generic[RequestT]):
request: RequestT
raw_request: Request | None = None
model_name: str
request_id: str
created_time: int = field(default_factory=lambda: int(time.time()))
lora_request: LoRARequest | None = None
engine_prompts: list[TokensPrompt] | None = None
result_generator: AsyncGenerator[tuple[int, PoolingRequestOutput], None] | None = (
None
)
final_res_batch: list[PoolingRequestOutput] = field(default_factory=list)
@dataclass(kw_only=True)
class ClassificationServeContext(ServeContext[ClassificationRequest]):
pass
@dataclass(kw_only=True)
class EmbeddingServeContext(ServeContext[EmbeddingRequest]):
chat_template: str | None = None
chat_template_content_format: ChatTemplateContentFormatOption
model_config = ConfigDict(arbitrary_types_allowed=True)
class OpenAIServing:
......@@ -605,10 +578,7 @@ class OpenAIServing:
self,
ctx: ServeContext,
) -> AnyResponse | ErrorResponse:
generation: AsyncGenerator[AnyResponse | ErrorResponse, None]
generation = self._pipeline(ctx)
async for response in generation:
async for response in self._pipeline(ctx):
return response
return self.create_error_response("No response yielded from pipeline")
......@@ -667,9 +637,7 @@ class OpenAIServing:
ctx: ServeContext,
) -> ErrorResponse | None:
"""Schedule the request and get the result generator."""
generators: list[
AsyncGenerator[RequestOutput | PoolingRequestOutput, None]
] = []
generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
try:
trace_headers = (
......@@ -723,7 +691,7 @@ class OpenAIServing:
return self.create_error_response("Engine prompts not available")
num_prompts = len(ctx.engine_prompts)
final_res_batch: list[RequestOutput | PoolingRequestOutput | None]
final_res_batch: list[PoolingRequestOutput | None]
final_res_batch = [None] * num_prompts
if ctx.result_generator is None:
......@@ -1012,7 +980,7 @@ class OpenAIServing:
def _validate_input(
self,
request: AnyRequest,
request: object,
input_ids: list[int],
input_text: str,
) -> TokensPrompt:
......@@ -1323,7 +1291,7 @@ class OpenAIServing:
priority: int = 0,
**kwargs,
):
prompt_text, _, _ = self._get_prompt_components(engine_prompt)
prompt_text, _, _ = get_prompt_components(engine_prompt)
orig_priority = priority
sub_request = 0
......@@ -1374,10 +1342,12 @@ class OpenAIServing:
# yield context
# Create inputs for the next turn.
# Render the next prompt token ids.
# Render the next prompt token ids and update sampling_params.
if isinstance(context, (HarmonyContext, StreamingHarmonyContext)):
prompt_token_ids = context.render_for_completion()
engine_prompt = TokensPrompt(prompt_token_ids=prompt_token_ids)
token_ids = context.render_for_completion()
engine_prompt = TokensPrompt(prompt_token_ids=token_ids)
sampling_params.max_tokens = self.max_model_len - len(token_ids)
elif isinstance(context, ParsableContext):
engine_prompts = await self._render_next_turn(
context.request,
......@@ -1389,19 +1359,19 @@ class OpenAIServing:
context.chat_template_content_format,
)
engine_prompt = engine_prompts[0]
prompt_text, _, _ = self._get_prompt_components(engine_prompt)
prompt_text, _, _ = get_prompt_components(engine_prompt)
sampling_params.max_tokens = get_max_tokens(
self.max_model_len,
context.request,
engine_prompt,
self.default_sampling_params, # type: ignore
)
# Update the sampling params.
sampling_params.max_tokens = self.max_model_len - len(
engine_prompt["prompt_token_ids"]
)
# OPTIMIZATION
priority = orig_priority - 1
sub_request += 1
def _get_prompt_components(self, prompt: PromptType) -> PromptComponents:
return get_prompt_components(prompt)
def _log_inputs(
self,
request_id: str,
......@@ -1412,7 +1382,7 @@ class OpenAIServing:
if self.request_logger is None:
return
prompt, prompt_token_ids, prompt_embeds = self._get_prompt_components(inputs)
prompt, prompt_token_ids, prompt_embeds = get_prompt_components(inputs)
self.request_logger.log_inputs(
request_id,
......@@ -1526,6 +1496,7 @@ class OpenAIServing:
# extract_tool_calls() returns a list of tool calls.
function_calls.extend(
FunctionCall(
id=tool_call.id,
name=tool_call.function.name,
arguments=tool_call.function.arguments,
)
......
......@@ -63,6 +63,7 @@ from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import (
ChatCompletionMessageParam,
ChatTemplateContentFormatOption,
make_tool_call_id,
)
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.mcp.tool_server import ToolServer
......@@ -115,6 +116,7 @@ from vllm.entrypoints.openai.responses.utils import (
extract_tool_types,
should_continue_final_message,
)
from vllm.entrypoints.utils import get_max_tokens
from vllm.exceptions import VLLMValidationError
from vllm.inputs.data import TokensPrompt
from vllm.logger import init_logger
......@@ -250,6 +252,17 @@ class OpenAIServingResponses(OpenAIServing):
self.default_sampling_params["stop_token_ids"].extend(
get_stop_tokens_for_assistant_actions()
)
# Handle tool call ID type for Kimi K2 (supporting test mocking via overrides)
hf_overrides = getattr(self.model_config, "hf_overrides", None)
if self.model_config.hf_text_config.model_type == "kimi_k2" or (
isinstance(hf_overrides, dict)
and hf_overrides.get("model_type") == "kimi_k2"
):
self.tool_call_id_type = "kimi_k2"
else:
self.tool_call_id_type = "random"
self.enable_auto_tools = enable_auto_tools
# set up tool use
self.tool_parser = self._get_tool_parser(
......@@ -423,8 +436,11 @@ class OpenAIServingResponses(OpenAIServing):
if maybe_error is not None:
return maybe_error
default_max_tokens = self.max_model_len - len(
engine_prompt["prompt_token_ids"]
default_max_tokens = get_max_tokens(
self.max_model_len,
request,
engine_prompt,
self.default_sampling_params,
)
sampling_params = request.to_sampling_params(
......@@ -954,25 +970,28 @@ class OpenAIServingResponses(OpenAIServing):
enable_auto_tools=self.enable_auto_tools,
tool_parser_cls=self.tool_parser,
)
if content:
output_text = ResponseOutputText(
text=content,
annotations=[], # TODO
type="output_text",
logprobs=(
self._create_response_logprobs(
token_ids=final_output.token_ids,
logprobs=final_output.logprobs,
tokenizer=tokenizer,
top_logprobs=request.top_logprobs,
)
if request.is_include_output_logprobs()
else None
),
)
if content or (self.use_harmony and tool_calls):
res_text_part = None
if content:
res_text_part = ResponseOutputText(
text=content,
annotations=[], # TODO
type="output_text",
logprobs=(
self._create_response_logprobs(
token_ids=final_output.token_ids,
logprobs=final_output.logprobs,
tokenizer=tokenizer,
top_logprobs=request.top_logprobs,
)
if request.is_include_output_logprobs()
else None
),
)
message_item = ResponseOutputMessage(
id=f"msg_{random_uuid()}",
content=[output_text],
content=[res_text_part] if res_text_part else [],
role="assistant",
status="completed",
type="message",
......@@ -984,17 +1003,28 @@ class OpenAIServingResponses(OpenAIServing):
if message_item:
outputs.append(message_item)
if tool_calls:
tool_call_items = [
ResponseFunctionToolCall(
id=f"fc_{random_uuid()}",
call_id=f"call_{random_uuid()}",
type="function_call",
status="completed",
name=tool_call.name,
arguments=tool_call.arguments,
# We use the enumerate() index as history_tool_call_cnt because we don't
# track the history of tool calls in the Responses API yet. This means
# that the tool call index starts from 0 for each request.
tool_call_items = []
for history_tool_call_cnt, tool_call in enumerate(tool_calls):
tool_call_items.append(
ResponseFunctionToolCall(
id=f"fc_{random_uuid()}",
call_id=tool_call.id
if tool_call.id
else make_tool_call_id(
id_type=self.tool_call_id_type,
func_name=tool_call.name,
idx=history_tool_call_cnt,
),
type="function_call",
status="completed",
name=tool_call.name,
arguments=tool_call.arguments,
)
)
for tool_call in tool_calls
]
outputs.extend(tool_call_items)
return outputs
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment