Commit d76fc11e authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.15.0rc1' into v0.15.0rc1-dev

parents 38166ec4 58996f35
...@@ -459,7 +459,7 @@ def test_eagle_correctness( ...@@ -459,7 +459,7 @@ def test_eagle_correctness(
from packaging.version import Version from packaging.version import Version
installed = Version(transformers.__version__) installed = Version(transformers.__version__)
required = Version("5.0.0.dev") required = Version("5.0.0")
if installed < required: if installed < required:
pytest.skip( pytest.skip(
"Eagle3 with the Transformers modeling backend requires " "Eagle3 with the Transformers modeling backend requires "
......
...@@ -3167,13 +3167,13 @@ if hasattr(torch.ops._C, "int8_scaled_mm_with_quant"): ...@@ -3167,13 +3167,13 @@ if hasattr(torch.ops._C, "int8_scaled_mm_with_quant"):
class CPUDNNLGEMMHandler: class CPUDNNLGEMMHandler:
def __init__(self) -> None: def __init__(self) -> None:
self.handler: int | None = None self.handler_tensor: torch.Tensor | None = None
self.n = -1 self.n = -1
self.k = -1 self.k = -1
def __del__(self): def __del__(self):
if self.handler is not None: if self.handler_tensor is not None:
torch.ops._C.release_dnnl_matmul_handler(self.handler) torch.ops._C.release_dnnl_matmul_handler(self.handler_tensor.item())
_supports_onednn = bool(hasattr(torch.ops._C, "create_onednn_mm_handler")) _supports_onednn = bool(hasattr(torch.ops._C, "create_onednn_mm_handler"))
...@@ -3189,8 +3189,10 @@ def create_onednn_mm( ...@@ -3189,8 +3189,10 @@ def create_onednn_mm(
) -> CPUDNNLGEMMHandler: ) -> CPUDNNLGEMMHandler:
handler = CPUDNNLGEMMHandler() handler = CPUDNNLGEMMHandler()
handler.k, handler.n = weight.size() handler.k, handler.n = weight.size()
handler.handler = torch.ops._C.create_onednn_mm_handler( # store the handler pointer in a tensor it doesn't get inlined
weight, primitive_cache_size handler.handler_tensor = torch.tensor(
torch.ops._C.create_onednn_mm_handler(weight, primitive_cache_size),
dtype=torch.int64,
) )
return handler return handler
...@@ -3202,7 +3204,7 @@ def onednn_mm( ...@@ -3202,7 +3204,7 @@ def onednn_mm(
) -> torch.Tensor: ) -> torch.Tensor:
output = torch.empty((*x.shape[0:-1], dnnl_handler.n), dtype=x.dtype) output = torch.empty((*x.shape[0:-1], dnnl_handler.n), dtype=x.dtype)
torch.ops._C.onednn_mm( torch.ops._C.onednn_mm(
output, x.reshape(-1, dnnl_handler.k), bias, dnnl_handler.handler output, x.reshape(-1, dnnl_handler.k), bias, dnnl_handler.handler_tensor
) )
return output return output
...@@ -3218,8 +3220,17 @@ def create_onednn_scaled_mm( ...@@ -3218,8 +3220,17 @@ def create_onednn_scaled_mm(
) -> CPUDNNLGEMMHandler: ) -> CPUDNNLGEMMHandler:
handler = CPUDNNLGEMMHandler() handler = CPUDNNLGEMMHandler()
handler.k, handler.n = weight.size() handler.k, handler.n = weight.size()
handler.handler = torch.ops._C.create_onednn_scaled_mm_handler( # store the handler pointer in a tensor so it doesn't get inlined
weight, weight_scales, output_type, dynamic_quant, use_azp, primitive_cache_size handler.handler_tensor = torch.tensor(
torch.ops._C.create_onednn_scaled_mm_handler(
weight,
weight_scales,
output_type,
dynamic_quant,
use_azp,
primitive_cache_size,
),
dtype=torch.int64,
) )
return handler return handler
...@@ -3272,7 +3283,13 @@ def onednn_scaled_mm( ...@@ -3272,7 +3283,13 @@ def onednn_scaled_mm(
bias: torch.Tensor | None, bias: torch.Tensor | None,
) -> torch.Tensor: ) -> torch.Tensor:
torch.ops._C.onednn_scaled_mm( torch.ops._C.onednn_scaled_mm(
output, x, input_scale, input_zp, input_zp_adj, bias, dnnl_handler.handler output,
x,
input_scale,
input_zp,
input_zp_adj,
bias,
dnnl_handler.handler_tensor,
) )
return output return output
......
...@@ -281,9 +281,10 @@ class DynamicShapesConfig: ...@@ -281,9 +281,10 @@ class DynamicShapesConfig:
until this change picked up https://github.com/pytorch/pytorch/pull/169239. until this change picked up https://github.com/pytorch/pytorch/pull/169239.
""" """
assume_32_bit_indexing: bool = True assume_32_bit_indexing: bool = False
""" """
whether all tensor sizes can use 32 bit indexing. whether all tensor sizes can use 32 bit indexing.
`True` requires PyTorch 2.10+
""" """
def compute_hash(self) -> str: def compute_hash(self) -> str:
......
...@@ -34,6 +34,7 @@ MTPModelTypes = Literal[ ...@@ -34,6 +34,7 @@ MTPModelTypes = Literal[
"mimo_mtp", "mimo_mtp",
"glm4_moe_mtp", "glm4_moe_mtp",
"glm4_moe_lite_mtp", "glm4_moe_lite_mtp",
"glm_ocr_mtp",
"ernie_mtp", "ernie_mtp",
"exaone_moe_mtp", "exaone_moe_mtp",
"qwen3_next_mtp", "qwen3_next_mtp",
...@@ -221,6 +222,17 @@ class SpeculativeConfig: ...@@ -221,6 +222,17 @@ class SpeculativeConfig:
} }
) )
if hf_config.architectures[0] == "GlmOcrForConditionalGeneration":
hf_config.model_type = "glm_ocr_mtp"
n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
hf_config.update(
{
"num_hidden_layers": 0,
"n_predict": n_predict,
"architectures": ["GlmOcrMTPModel"],
}
)
if hf_config.model_type == "ernie4_5_moe": if hf_config.model_type == "ernie4_5_moe":
hf_config.model_type = "ernie_mtp" hf_config.model_type = "ernie_mtp"
if hf_config.model_type == "ernie_mtp": if hf_config.model_type == "ernie_mtp":
......
...@@ -59,7 +59,7 @@ class NaiveAll2AllManager(All2AllManagerBase): ...@@ -59,7 +59,7 @@ class NaiveAll2AllManager(All2AllManagerBase):
return buffer return buffer
def dispatch( def dispatch_router_logits(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
...@@ -84,6 +84,34 @@ class NaiveAll2AllManager(All2AllManagerBase): ...@@ -84,6 +84,34 @@ class NaiveAll2AllManager(All2AllManagerBase):
return hidden_states, router_logits return hidden_states, router_logits
def dispatch(
self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
if extra_tensors is not None:
raise NotImplementedError(
"extra_tensors is not supported for NaiveAll2AllManager"
)
sp_size = self.tp_group.world_size if is_sequence_parallel else 1
dp_metadata = get_forward_context().dp_metadata
assert dp_metadata is not None
cu_tokens_across_sp_cpu = dp_metadata.cu_tokens_across_sp(sp_size)
hidden_states = self.naive_multicast(
hidden_states, cu_tokens_across_sp_cpu, is_sequence_parallel
)
topk_weights = self.naive_multicast(
topk_weights, cu_tokens_across_sp_cpu, is_sequence_parallel
)
topk_ids = self.naive_multicast(
topk_ids, cu_tokens_across_sp_cpu, is_sequence_parallel
)
return hidden_states, topk_weights, topk_ids
def combine( def combine(
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
) -> torch.Tensor: ) -> torch.Tensor:
...@@ -114,7 +142,7 @@ class AgRsAll2AllManager(All2AllManagerBase): ...@@ -114,7 +142,7 @@ class AgRsAll2AllManager(All2AllManagerBase):
def __init__(self, cpu_group): def __init__(self, cpu_group):
super().__init__(cpu_group) super().__init__(cpu_group)
def dispatch( def dispatch_router_logits(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
...@@ -148,6 +176,46 @@ class AgRsAll2AllManager(All2AllManagerBase): ...@@ -148,6 +176,46 @@ class AgRsAll2AllManager(All2AllManagerBase):
return (gathered_tensors[0], gathered_tensors[1], gathered_tensors[2:]) return (gathered_tensors[0], gathered_tensors[1], gathered_tensors[2:])
return gathered_tensors[0], gathered_tensors[1] return gathered_tensors[0], gathered_tensors[1]
def dispatch(
self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
) -> (
tuple[torch.Tensor, torch.Tensor, torch.Tensor]
| tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
"""
Gather hidden_states and router_logits from all dp ranks.
"""
dp_metadata = get_forward_context().dp_metadata
assert dp_metadata is not None
sizes = dp_metadata.get_chunk_sizes_across_dp_rank()
assert sizes is not None
dist_group = get_ep_group() if is_sequence_parallel else get_dp_group()
assert sizes[dist_group.rank_in_group] == hidden_states.shape[0]
tensors_to_gather = [hidden_states, topk_weights, topk_ids]
if extra_tensors is not None:
tensors_to_gather.extend(extra_tensors)
gathered_tensors = dist_group.all_gatherv(
tensors_to_gather,
dim=0,
sizes=sizes,
)
hidden_states = gathered_tensors[0]
topk_weights = gathered_tensors[1]
topk_ids = gathered_tensors[2]
if extra_tensors is None:
return hidden_states, topk_weights, topk_ids
return hidden_states, topk_weights, topk_ids, gathered_tensors[3:]
def combine( def combine(
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
) -> torch.Tensor: ) -> torch.Tensor:
...@@ -216,7 +284,7 @@ class PPLXAll2AllManager(All2AllManagerBase): ...@@ -216,7 +284,7 @@ class PPLXAll2AllManager(All2AllManagerBase):
pplx.AllToAll.internode if self.internode else pplx.AllToAll.intranode, pplx.AllToAll.internode if self.internode else pplx.AllToAll.intranode,
) )
def dispatch( def dispatch_router_logits(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
...@@ -225,6 +293,19 @@ class PPLXAll2AllManager(All2AllManagerBase): ...@@ -225,6 +293,19 @@ class PPLXAll2AllManager(All2AllManagerBase):
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
raise NotImplementedError raise NotImplementedError
def dispatch(
self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
) -> (
tuple[torch.Tensor, torch.Tensor, torch.Tensor]
| tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
raise NotImplementedError
def combine( def combine(
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
) -> torch.Tensor: ) -> torch.Tensor:
...@@ -264,7 +345,7 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase): ...@@ -264,7 +345,7 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase):
def get_handle(self, kwargs): def get_handle(self, kwargs):
raise NotImplementedError raise NotImplementedError
def dispatch( def dispatch_router_logits(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
...@@ -273,6 +354,19 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase): ...@@ -273,6 +354,19 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase):
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
raise NotImplementedError raise NotImplementedError
def dispatch(
self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
) -> (
tuple[torch.Tensor, torch.Tensor, torch.Tensor]
| tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
raise NotImplementedError
def combine( def combine(
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
) -> torch.Tensor: ) -> torch.Tensor:
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import threading import threading
from typing import Any
from weakref import WeakValueDictionary from weakref import WeakValueDictionary
import torch import torch
...@@ -64,13 +63,32 @@ class All2AllManagerBase: ...@@ -64,13 +63,32 @@ class All2AllManagerBase:
# and reuse it for the same config. # and reuse it for the same config.
raise NotImplementedError raise NotImplementedError
def dispatch( def dispatch_router_logits(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
is_sequence_parallel: bool = False, is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None, extra_tensors: list[torch.Tensor] | None = None,
) -> Any: ) -> (
tuple[torch.Tensor, torch.Tensor]
| tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
# Subclasses should either:
# - implement handling for extra_tensors, or
# - raise a clear error if extra_tensors is not supported.
raise NotImplementedError
def dispatch(
self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
) -> (
tuple[torch.Tensor, torch.Tensor, torch.Tensor]
| tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
# Subclasses should either: # Subclasses should either:
# - implement handling for extra_tensors, or # - implement handling for extra_tensors, or
# - raise a clear error if extra_tensors is not supported. # - raise a clear error if extra_tensors is not supported.
...@@ -280,7 +298,7 @@ class DeviceCommunicatorBase: ...@@ -280,7 +298,7 @@ class DeviceCommunicatorBase:
for module in moe_modules: for module in moe_modules:
module.maybe_init_modular_kernel() module.maybe_init_modular_kernel()
def dispatch( def dispatch_router_logits(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
...@@ -294,8 +312,29 @@ class DeviceCommunicatorBase: ...@@ -294,8 +312,29 @@ class DeviceCommunicatorBase:
Dispatch the hidden states and router logits to the appropriate device. Dispatch the hidden states and router logits to the appropriate device.
This is a no-op in the base class. This is a no-op in the base class.
""" """
if extra_tensors is not None:
return hidden_states, router_logits, extra_tensors
return hidden_states, router_logits return hidden_states, router_logits
def dispatch(
self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
) -> (
tuple[torch.Tensor, torch.Tensor, torch.Tensor]
| tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
"""
Dispatch the hidden states and topk weights/ids to the appropriate device.
This is a no-op in the base class.
"""
if extra_tensors is not None:
return hidden_states, topk_weights, topk_ids, extra_tensors
return hidden_states, topk_weights, topk_ids
def combine( def combine(
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
) -> torch.Tensor: ) -> torch.Tensor:
......
...@@ -130,29 +130,65 @@ class CpuCommunicator(DeviceCommunicatorBase): ...@@ -130,29 +130,65 @@ class CpuCommunicator(DeviceCommunicatorBase):
) -> dict[str, torch.Tensor | Any]: ) -> dict[str, torch.Tensor | Any]:
return self.dist_module.recv_tensor_dict(src) return self.dist_module.recv_tensor_dict(src)
def dispatch( # type: ignore[override] def dispatch_router_logits(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
is_sequence_parallel: bool = False, is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None, extra_tensors: list[torch.Tensor] | None = None,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> (
tuple[torch.Tensor, torch.Tensor]
| tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
"""
Dispatch the hidden states and router logits to the appropriate device.
This is a no-op in the base class.
"""
assert self.all2all_manager is not None assert self.all2all_manager is not None
return self.all2all_manager.dispatch( return self.all2all_manager.dispatch_router_logits(
hidden_states, hidden_states,
router_logits, router_logits,
is_sequence_parallel, is_sequence_parallel,
extra_tensors, # type: ignore[call-arg] extra_tensors,
)
def dispatch(
self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
) -> (
tuple[torch.Tensor, torch.Tensor, torch.Tensor]
| tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
"""
Dispatch the hidden states and topk weights/ids to the appropriate device.
This is a no-op in the base class.
"""
assert self.all2all_manager is not None
return self.all2all_manager.dispatch(
hidden_states,
topk_weights,
topk_ids,
is_sequence_parallel,
extra_tensors=extra_tensors,
) )
def combine( def combine(
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
) -> torch.Tensor: ) -> torch.Tensor:
"""
Combine the hidden states and router logits from the appropriate device.
This is a no-op in the base class.
"""
assert self.all2all_manager is not None assert self.all2all_manager is not None
hidden_states = self.all2all_manager.combine( return self.all2all_manager.combine(
hidden_states, is_sequence_parallel hidden_states,
is_sequence_parallel,
) )
return hidden_states
class _CPUSHMDistributed: class _CPUSHMDistributed:
......
...@@ -322,7 +322,7 @@ class CudaCommunicator(DeviceCommunicatorBase): ...@@ -322,7 +322,7 @@ class CudaCommunicator(DeviceCommunicatorBase):
return output_list return output_list
def dispatch( # type: ignore[override] def dispatch_router_logits(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
...@@ -332,19 +332,52 @@ class CudaCommunicator(DeviceCommunicatorBase): ...@@ -332,19 +332,52 @@ class CudaCommunicator(DeviceCommunicatorBase):
tuple[torch.Tensor, torch.Tensor] tuple[torch.Tensor, torch.Tensor]
| tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]] | tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]]
): ):
"""
Dispatch the hidden states and router logits to the appropriate device.
This is a no-op in the base class.
"""
assert self.all2all_manager is not None assert self.all2all_manager is not None
return self.all2all_manager.dispatch( return self.all2all_manager.dispatch_router_logits(
hidden_states, hidden_states,
router_logits, router_logits,
is_sequence_parallel, is_sequence_parallel,
extra_tensors, # type: ignore[call-arg] extra_tensors,
)
def dispatch(
self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
) -> (
tuple[torch.Tensor, torch.Tensor, torch.Tensor]
| tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
"""
Dispatch the hidden states and topk weights/ids to the appropriate device.
This is a no-op in the base class.
"""
assert self.all2all_manager is not None
return self.all2all_manager.dispatch(
hidden_states,
topk_weights,
topk_ids,
is_sequence_parallel,
extra_tensors=extra_tensors,
) )
def combine( def combine(
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
) -> torch.Tensor: ) -> torch.Tensor:
"""
Combine the hidden states and router logits from the appropriate device.
This is a no-op in the base class.
"""
assert self.all2all_manager is not None assert self.all2all_manager is not None
hidden_states = self.all2all_manager.combine( return self.all2all_manager.combine(
hidden_states, is_sequence_parallel hidden_states,
is_sequence_parallel,
) )
return hidden_states
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any
import torch.distributed as dist import torch.distributed as dist
from flashinfer.comm.mnnvl import CommBackend as CommBackend from flashinfer.comm.mnnvl import CommBackend as CommBackend
...@@ -23,5 +25,14 @@ class CustomCommunicator(CommBackend): ...@@ -23,5 +25,14 @@ class CustomCommunicator(CommBackend):
dist.all_gather_object(gathered, data, group=self._group) dist.all_gather_object(gathered, data, group=self._group)
return gathered return gathered
# NOTE(rob): CommBackend is an abstract class, and bcast/barrier
# are unimplemented on vLLM side. If we need to utilize these
# methods in the future, can create a concrete implementation.
def bcast(self, data: Any, root: int) -> Any:
raise NotImplementedError
def barrier(self) -> None:
raise NotImplementedError
def Split(self, color: int, key: int) -> "CustomCommunicator": def Split(self, color: int, key: int) -> "CustomCommunicator":
return self return self
...@@ -196,26 +196,62 @@ class XpuCommunicator(DeviceCommunicatorBase): ...@@ -196,26 +196,62 @@ class XpuCommunicator(DeviceCommunicatorBase):
def broadcast(self, input_: torch.Tensor, src: int = 0) -> None: def broadcast(self, input_: torch.Tensor, src: int = 0) -> None:
dist.broadcast(input_, src=src, group=self.device_group) dist.broadcast(input_, src=src, group=self.device_group)
def dispatch( def dispatch_router_logits(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
is_sequence_parallel: bool = False, is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None, extra_tensors: list[torch.Tensor] | None = None,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> (
tuple[torch.Tensor, torch.Tensor]
| tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
"""
Dispatch the hidden states and router logits to the appropriate device.
This is a no-op in the base class.
"""
assert self.all2all_manager is not None assert self.all2all_manager is not None
return self.all2all_manager.dispatch( return self.all2all_manager.dispatch_router_logits(
hidden_states, hidden_states,
router_logits, router_logits,
is_sequence_parallel, is_sequence_parallel,
extra_tensors, # type: ignore[call-arg] extra_tensors,
)
def dispatch(
self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
) -> (
tuple[torch.Tensor, torch.Tensor, torch.Tensor]
| tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
"""
Dispatch the hidden states and topk weights/ids to the appropriate device.
This is a no-op in the base class.
"""
assert self.all2all_manager is not None
return self.all2all_manager.dispatch(
hidden_states,
topk_weights,
topk_ids,
is_sequence_parallel,
extra_tensors=extra_tensors,
) )
def combine( def combine(
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
) -> torch.Tensor: ) -> torch.Tensor:
"""
Combine the hidden states and router logits from the appropriate device.
This is a no-op in the base class.
"""
assert self.all2all_manager is not None assert self.all2all_manager is not None
hidden_states = self.all2all_manager.combine( return self.all2all_manager.combine(
hidden_states, is_sequence_parallel hidden_states,
is_sequence_parallel,
) )
return hidden_states
...@@ -302,7 +302,7 @@ class NixlConnector(KVConnectorBase_V1): ...@@ -302,7 +302,7 @@ class NixlConnector(KVConnectorBase_V1):
@property @property
def prefer_cross_layer_blocks(self) -> bool: def prefer_cross_layer_blocks(self) -> bool:
backend = get_current_attn_backend(self._vllm_config) backend = get_current_attn_backend(self._vllm_config)
if backend().get_name() not in ( if backend.get_name() not in (
"FLASH_ATTN", "FLASH_ATTN",
"FLASHINFER", "FLASHINFER",
): ):
......
...@@ -1003,7 +1003,7 @@ class GroupCoordinator: ...@@ -1003,7 +1003,7 @@ class GroupCoordinator:
if self.device_communicator is not None: if self.device_communicator is not None:
self.device_communicator.prepare_communication_buffer_for_model(model) self.device_communicator.prepare_communication_buffer_for_model(model)
def dispatch( def dispatch_router_logits(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
...@@ -1014,7 +1014,7 @@ class GroupCoordinator: ...@@ -1014,7 +1014,7 @@ class GroupCoordinator:
| tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]] | tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]]
): ):
if self.device_communicator is not None: if self.device_communicator is not None:
return self.device_communicator.dispatch( # type: ignore[call-arg] return self.device_communicator.dispatch_router_logits(
hidden_states, hidden_states,
router_logits, router_logits,
is_sequence_parallel, is_sequence_parallel,
...@@ -1023,6 +1023,28 @@ class GroupCoordinator: ...@@ -1023,6 +1023,28 @@ class GroupCoordinator:
else: else:
return hidden_states, router_logits return hidden_states, router_logits
def dispatch(
self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
) -> (
tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
| tuple[torch.Tensor, torch.Tensor, torch.Tensor]
):
if self.device_communicator is not None:
return self.device_communicator.dispatch(
hidden_states,
topk_weights,
topk_ids,
is_sequence_parallel,
extra_tensors,
)
else:
return hidden_states, topk_weights, topk_ids
def combine( def combine(
self, hidden_states, is_sequence_parallel: bool = False self, hidden_states, is_sequence_parallel: bool = False
) -> torch.Tensor: ) -> torch.Tensor:
......
...@@ -46,6 +46,9 @@ from vllm.multimodal.inputs import ( ...@@ -46,6 +46,9 @@ from vllm.multimodal.inputs import (
MultiModalBatchedField, MultiModalBatchedField,
MultiModalFlatField, MultiModalFlatField,
MultiModalSharedField, MultiModalSharedField,
VisionChunk,
VisionChunkImage,
VisionChunkVideo,
) )
from vllm.multimodal.processing import BaseMultiModalProcessor from vllm.multimodal.processing import BaseMultiModalProcessor
from vllm.multimodal.utils import MEDIA_CONNECTOR_REGISTRY, MediaConnector from vllm.multimodal.utils import MEDIA_CONNECTOR_REGISTRY, MediaConnector
...@@ -336,7 +339,9 @@ ChatTemplateContentFormatOption = Literal["auto", "string", "openai"] ...@@ -336,7 +339,9 @@ ChatTemplateContentFormatOption = Literal["auto", "string", "openai"]
ChatTemplateContentFormat = Literal["string", "openai"] ChatTemplateContentFormat = Literal["string", "openai"]
ModalityStr = Literal["image", "audio", "video", "image_embeds", "audio_embeds"] ModalityStr = Literal[
"image", "audio", "video", "image_embeds", "audio_embeds", "vision_chunk"
]
_T = TypeVar("_T") _T = TypeVar("_T")
...@@ -449,6 +454,78 @@ def _get_embeds_data( ...@@ -449,6 +454,78 @@ def _get_embeds_data(
raise NotImplementedError(type(data_items)) raise NotImplementedError(type(data_items))
def rebuild_mm_uuids_from_mm_data(
mm_uuids: MultiModalUUIDDict,
mm_data: MultiModalDataDict,
) -> MultiModalUUIDDict:
"""Rebuild mm_uuids after vision_chunk processing.
When videos are split into chunks, the original UUIDs need to be updated
to reflect the new UUIDs generated for each chunk.
Args:
mm_uuids: Original UUIDs dictionary
mm_data: Processed multimodal data with vision_chunk items
Returns:
Updated UUIDs dictionary with chunk UUIDs
"""
vision_chunks = mm_data.get("vision_chunk")
if vision_chunks is None:
return mm_uuids
new_uuids = dict(mm_uuids)
vision_chunk_uuids = []
for item in vision_chunks:
# vision_chunk items are always dicts (VisionChunkImage/VisionChunkVideo)
assert isinstance(item, dict)
uuid_val = item.get("uuid")
if uuid_val is not None:
vision_chunk_uuids.append(uuid_val)
if vision_chunk_uuids:
new_uuids["vision_chunk"] = vision_chunk_uuids
return new_uuids
def build_video_prompts_from_mm_data(
mm_data: MultiModalDataDict,
) -> list[str]:
"""Build video prompts from vision_chunk data.
Collects prompts from video chunks and groups them by video_idx.
Args:
mm_data: Processed multimodal data with vision_chunk items
Returns:
List of video prompts, one per video.
"""
vision_chunks = mm_data.get("vision_chunk")
if vision_chunks is None:
return []
# Group chunks by video_idx
video_prompts_dict: dict[int, list[str]] = defaultdict(list)
for item in vision_chunks:
# vision_chunk items are always dicts (VisionChunkImage/VisionChunkVideo)
assert isinstance(item, dict)
if item.get("type") == "video_chunk":
video_idx = item.get("video_idx", 0)
prompt = item.get("prompt", "")
video_prompts_dict[video_idx].append(prompt)
# Build prompts in video order
video_prompts = []
for video_idx in sorted(video_prompts_dict.keys()):
video_prompts.append("".join(video_prompts_dict[video_idx]))
return video_prompts
class BaseMultiModalItemTracker(ABC, Generic[_T]): class BaseMultiModalItemTracker(ABC, Generic[_T]):
""" """
Tracks multi-modal items in a given request and ensures that the number Tracks multi-modal items in a given request and ensures that the number
...@@ -462,6 +539,13 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]): ...@@ -462,6 +539,13 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
self._model_config = model_config self._model_config = model_config
self._items_by_modality = defaultdict[str, list[_T]](list) self._items_by_modality = defaultdict[str, list[_T]](list)
# Track original modality for each vision_chunk item (image or video)
self._modality_order = defaultdict[str, list[str]](list)
@cached_property
def use_unified_vision_chunk_modality(self) -> bool:
"""Check if model uses unified vision_chunk modality for images/videos."""
return getattr(self._model_config.hf_config, "use_unified_vision_chunk", False)
@property @property
def model_config(self) -> ModelConfig: def model_config(self) -> ModelConfig:
...@@ -499,11 +583,31 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]): ...@@ -499,11 +583,31 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
media. media.
""" """
input_modality = modality.replace("_embeds", "") input_modality = modality.replace("_embeds", "")
num_items = len(self._items_by_modality[modality]) + 1 original_modality = modality
use_vision_chunk = (
self.use_unified_vision_chunk_modality
and original_modality in ["video", "image"]
)
# If use_unified_vision_chunk_modality is enabled,
# map image/video to vision_chunk
if use_vision_chunk:
# To avoid validation fail
# because models with use_unified_vision_chunk_modality=True
# will only accept vision_chunk modality.
input_modality = "vision_chunk"
num_items = len(self._items_by_modality[input_modality]) + 1
else:
num_items = len(self._items_by_modality[original_modality]) + 1
self.mm_processor.validate_num_items(input_modality, num_items) self.mm_processor.validate_num_items(input_modality, num_items)
self._items_by_modality[modality].append(item) # Track original modality for vision_chunk items
if use_vision_chunk:
self._items_by_modality[input_modality].append(item) # type: ignore
self._modality_order["vision_chunk"].append(original_modality)
else:
self._items_by_modality[original_modality].append(item)
return self.model_cls.get_placeholder_str(modality, num_items) return self.model_cls.get_placeholder_str(modality, num_items)
...@@ -515,6 +619,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]): ...@@ -515,6 +619,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
def _resolve_items( def _resolve_items(
items_by_modality: dict[str, list[tuple[object, str | None]]], items_by_modality: dict[str, list[tuple[object, str | None]]],
mm_processor: BaseMultiModalProcessor, mm_processor: BaseMultiModalProcessor,
vision_chunk_modality_order: dict[str, list[str]],
) -> tuple[MultiModalDataDict, MultiModalUUIDDict]: ) -> tuple[MultiModalDataDict, MultiModalUUIDDict]:
if "image" in items_by_modality and "image_embeds" in items_by_modality: if "image" in items_by_modality and "image_embeds" in items_by_modality:
raise ValueError("Mixing raw image and embedding inputs is not allowed") raise ValueError("Mixing raw image and embedding inputs is not allowed")
...@@ -546,6 +651,74 @@ def _resolve_items( ...@@ -546,6 +651,74 @@ def _resolve_items(
if "video" in items_by_modality: if "video" in items_by_modality:
mm_data["video"] = [data for data, uuid in items_by_modality["video"]] mm_data["video"] = [data for data, uuid in items_by_modality["video"]]
mm_uuids["video"] = [uuid for data, uuid in items_by_modality["video"]] mm_uuids["video"] = [uuid for data, uuid in items_by_modality["video"]]
if "vision_chunk" in items_by_modality:
# Process vision_chunk items - extract from (data, modality) tuples
# and convert to VisionChunk types with proper UUID handling
vision_chunk_items = items_by_modality["vision_chunk"]
modality_order = vision_chunk_modality_order.get("vision_chunk", [])
mm_uuids["vision_chunk"] = [
uuid for data, uuid in items_by_modality["vision_chunk"]
]
# Filter out None items (from asyncio.sleep(0) placeholders)
filtered_items = [
(idx, item)
for idx, item in enumerate(vision_chunk_items)
if item is not None
]
assert len(filtered_items) == len(modality_order), (
f"vision_chunk items ({len(filtered_items)}) and "
f"modality_order ({len(modality_order)}) must have same length"
)
processed_chunks: list[VisionChunk] = []
video_idx = 0
for i, (idx, item) in enumerate(filtered_items):
inner_modality = modality_order[i]
data, uuid = item
uuid_val = uuid if idx < len(mm_uuids["vision_chunk"]) else None
if inner_modality == "image":
# Cast data to proper type for image
# Use .media (PIL.Image) directly to avoid redundant
# bytes→PIL conversion in media_processor
if hasattr(data, "media"):
image_data = data.media # type: ignore[union-attr]
processed_chunks.append(
VisionChunkImage(type="image", image=image_data, uuid=uuid_val)
)
else:
processed_chunks.append(data) # type: ignore[arg-type]
elif inner_modality == "video":
# For video, we may need to split into chunks
# if processor supports it
# For now, just wrap as a video chunk placeholder
if hasattr(mm_processor, "split_video_chunks") and data is not None:
try:
video_uuid = uuid_val or random_uuid()
# video await result is (video_data, video_meta) tuple
if isinstance(data, tuple) and len(data) >= 1:
video_data = data[0]
else:
video_data = data
video_chunks = mm_processor.split_video_chunks(video_data)
for i, vc in enumerate(video_chunks):
processed_chunks.append(
VisionChunkVideo(
type="video_chunk",
video_chunk=vc["video_chunk"],
uuid=f"{video_uuid}-{i}",
video_idx=video_idx,
prompt=vc["prompt"],
)
)
video_idx += 1
except Exception as e:
logger.warning("Failed to split video chunks: %s", e)
processed_chunks.append(data) # type: ignore[arg-type]
else:
processed_chunks.append(data) # type: ignore[arg-type]
mm_data["vision_chunk"] = processed_chunks
return mm_data, mm_uuids return mm_data, mm_uuids
...@@ -557,7 +730,9 @@ class MultiModalItemTracker(BaseMultiModalItemTracker[tuple[object, str | None]] ...@@ -557,7 +730,9 @@ class MultiModalItemTracker(BaseMultiModalItemTracker[tuple[object, str | None]]
if not self._items_by_modality: if not self._items_by_modality:
return None, None return None, None
return _resolve_items(dict(self._items_by_modality), self.mm_processor) return _resolve_items(
dict(self._items_by_modality), self.mm_processor, self._modality_order
)
def create_parser(self) -> "BaseMultiModalContentParser": def create_parser(self) -> "BaseMultiModalContentParser":
return MultiModalContentParser(self) return MultiModalContentParser(self)
...@@ -577,7 +752,9 @@ class AsyncMultiModalItemTracker( ...@@ -577,7 +752,9 @@ class AsyncMultiModalItemTracker(
for modality, coros in self._items_by_modality.items() for modality, coros in self._items_by_modality.items()
} }
return _resolve_items(resolved_items_by_modality, self.mm_processor) return _resolve_items(
resolved_items_by_modality, self.mm_processor, self._modality_order
)
def create_parser(self) -> "BaseMultiModalContentParser": def create_parser(self) -> "BaseMultiModalContentParser":
return AsyncMultiModalContentParser(self) return AsyncMultiModalContentParser(self)
......
...@@ -265,6 +265,39 @@ def load_log_config(log_config_file: str | None) -> dict | None: ...@@ -265,6 +265,39 @@ def load_log_config(log_config_file: str | None) -> dict | None:
return None return None
def get_uvicorn_log_config(args: Namespace) -> dict | None:
"""
Get the uvicorn log config based on the provided arguments.
Priority:
1. If log_config_file is specified, use it
2. If disable_access_log_for_endpoints is specified, create a config with
the access log filter
3. Otherwise, return None (use uvicorn defaults)
"""
# First, try to load from file if specified
log_config = load_log_config(args.log_config_file)
if log_config is not None:
return log_config
# If endpoints to filter are specified, create a config with the filter
if args.disable_access_log_for_endpoints:
from vllm.logging_utils import create_uvicorn_log_config
# Parse comma-separated string into list
excluded_paths = [
p.strip()
for p in args.disable_access_log_for_endpoints.split(",")
if p.strip()
]
return create_uvicorn_log_config(
excluded_paths=excluded_paths,
log_level=args.uvicorn_log_level,
)
return None
class AuthenticationMiddleware: class AuthenticationMiddleware:
""" """
Pure ASGI middleware that authenticates each request by checking Pure ASGI middleware that authenticates each request by checking
...@@ -931,8 +964,8 @@ async def run_server_worker( ...@@ -931,8 +964,8 @@ async def run_server_worker(
if args.reasoning_parser_plugin and len(args.reasoning_parser_plugin) > 3: if args.reasoning_parser_plugin and len(args.reasoning_parser_plugin) > 3:
ReasoningParserManager.import_reasoning_parser(args.reasoning_parser_plugin) ReasoningParserManager.import_reasoning_parser(args.reasoning_parser_plugin)
# Load logging config for uvicorn if specified # Get uvicorn log config (from file or with endpoint filter)
log_config = load_log_config(args.log_config_file) log_config = get_uvicorn_log_config(args)
if log_config is not None: if log_config is not None:
uvicorn_kwargs["log_config"] = log_config uvicorn_kwargs["log_config"] = log_config
......
...@@ -44,6 +44,7 @@ from vllm.entrypoints.openai.engine.protocol import ( ...@@ -44,6 +44,7 @@ from vllm.entrypoints.openai.engine.protocol import (
DeltaMessage, DeltaMessage,
DeltaToolCall, DeltaToolCall,
ErrorResponse, ErrorResponse,
FunctionCall,
PromptTokenUsageInfo, PromptTokenUsageInfo,
RequestResponseMetadata, RequestResponseMetadata,
ToolCall, ToolCall,
...@@ -67,6 +68,7 @@ from vllm.entrypoints.openai.parser.harmony_utils import ( ...@@ -67,6 +68,7 @@ from vllm.entrypoints.openai.parser.harmony_utils import (
from vllm.entrypoints.openai.utils import maybe_filter_parallel_tool_calls from vllm.entrypoints.openai.utils import maybe_filter_parallel_tool_calls
from vllm.entrypoints.utils import get_max_tokens, should_include_usage from vllm.entrypoints.utils import get_max_tokens, should_include_usage
from vllm.inputs.data import TokensPrompt from vllm.inputs.data import TokensPrompt
from vllm.inputs.parse import get_prompt_components
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.logprobs import Logprob from vllm.logprobs import Logprob
from vllm.outputs import CompletionOutput, RequestOutput from vllm.outputs import CompletionOutput, RequestOutput
...@@ -143,11 +145,6 @@ class OpenAIServingChat(OpenAIServing): ...@@ -143,11 +145,6 @@ class OpenAIServingChat(OpenAIServing):
self.enable_prompt_tokens_details = enable_prompt_tokens_details self.enable_prompt_tokens_details = enable_prompt_tokens_details
self.enable_force_include_usage = enable_force_include_usage self.enable_force_include_usage = enable_force_include_usage
self.default_sampling_params = self.model_config.get_diff_sampling_param() self.default_sampling_params = self.model_config.get_diff_sampling_param()
if self.model_config.hf_config.model_type == "kimi_k2":
self.tool_call_id_type = "kimi_k2"
else:
self.tool_call_id_type = "random"
self.use_harmony = self.model_config.hf_config.model_type == "gpt_oss" self.use_harmony = self.model_config.hf_config.model_type == "gpt_oss"
if self.use_harmony: if self.use_harmony:
if "stop_token_ids" not in self.default_sampling_params: if "stop_token_ids" not in self.default_sampling_params:
...@@ -156,6 +153,16 @@ class OpenAIServingChat(OpenAIServing): ...@@ -156,6 +153,16 @@ class OpenAIServingChat(OpenAIServing):
get_stop_tokens_for_assistant_actions() get_stop_tokens_for_assistant_actions()
) )
# Handle tool call ID type for Kimi K2 (supporting test mocking via overrides)
hf_overrides = getattr(self.model_config, "hf_overrides", None)
if self.model_config.hf_text_config.model_type == "kimi_k2" or (
isinstance(hf_overrides, dict)
and hf_overrides.get("model_type") == "kimi_k2"
):
self.tool_call_id_type = "kimi_k2"
else:
self.tool_call_id_type = "random"
# NOTE(woosuk): While OpenAI's chat completion API supports browsing # NOTE(woosuk): While OpenAI's chat completion API supports browsing
# for some models, currently vLLM doesn't support it. Please use the # for some models, currently vLLM doesn't support it. Please use the
# Responses API instead. # Responses API instead.
...@@ -247,8 +254,8 @@ class OpenAIServingChat(OpenAIServing): ...@@ -247,8 +254,8 @@ class OpenAIServingChat(OpenAIServing):
# because of issues with pydantic we need to potentially # because of issues with pydantic we need to potentially
# re-serialize the tool_calls field of the request # re-serialize the tool_calls field of the request
# for more info: see comment in `maybe_serialize_tool_calls` # for more info: see comment in `maybe_serialize_tool_calls`
maybe_serialize_tool_calls(request) maybe_serialize_tool_calls(request) # type: ignore[arg-type]
truncate_tool_call_ids(request) truncate_tool_call_ids(request) # type: ignore[arg-type]
validate_request_params(request) validate_request_params(request)
# Check if tool parsing is unavailable (common condition) # Check if tool parsing is unavailable (common condition)
...@@ -368,20 +375,18 @@ class OpenAIServingChat(OpenAIServing): ...@@ -368,20 +375,18 @@ class OpenAIServingChat(OpenAIServing):
generators: list[AsyncGenerator[RequestOutput, None]] = [] generators: list[AsyncGenerator[RequestOutput, None]] = []
try: try:
for i, engine_prompt in enumerate(engine_prompts): for i, engine_prompt in enumerate(engine_prompts):
prompt_text, _, _ = self._get_prompt_components(engine_prompt) prompt_text, _, _ = get_prompt_components(engine_prompt)
# If we are creating sub requests for multiple prompts, ensure that they # If we are creating sub requests for multiple prompts, ensure that they
# have unique request ids. # have unique request ids.
sub_request_id = ( sub_request_id = (
request_id if len(engine_prompts) == 1 else f"{request_id}_{i}" request_id if len(engine_prompts) == 1 else f"{request_id}_{i}"
) )
if self.default_sampling_params is None:
self.default_sampling_params = {}
max_tokens = get_max_tokens( max_tokens = get_max_tokens(
max_model_len=self.max_model_len, max_model_len=self.max_model_len,
request=request, request=request,
input_length=len(engine_prompt["prompt_token_ids"]), prompt=engine_prompt,
default_sampling_params=self.default_sampling_params, default_sampling_params=self.default_sampling_params,
) )
...@@ -454,6 +459,7 @@ class OpenAIServingChat(OpenAIServing): ...@@ -454,6 +459,7 @@ class OpenAIServingChat(OpenAIServing):
# Streaming response # Streaming response
tokenizer = self.renderer.tokenizer tokenizer = self.renderer.tokenizer
assert tokenizer is not None
if request.stream: if request.stream:
return self.chat_completion_stream_generator( return self.chat_completion_stream_generator(
...@@ -632,9 +638,11 @@ class OpenAIServingChat(OpenAIServing): ...@@ -632,9 +638,11 @@ class OpenAIServingChat(OpenAIServing):
request_id: str, request_id: str,
model_name: str, model_name: str,
conversation: list[ConversationMessage], conversation: list[ConversationMessage],
tokenizer: TokenizerLike | None, tokenizer: TokenizerLike,
request_metadata: RequestResponseMetadata, request_metadata: RequestResponseMetadata,
) -> AsyncGenerator[str, None]: ) -> AsyncGenerator[str, None]:
from vllm.tokenizers.mistral import MistralTokenizer
created_time = int(time.time()) created_time = int(time.time())
chunk_object_type: Final = "chat.completion.chunk" chunk_object_type: Final = "chat.completion.chunk"
first_iteration = True first_iteration = True
...@@ -698,7 +706,7 @@ class OpenAIServingChat(OpenAIServing): ...@@ -698,7 +706,7 @@ class OpenAIServingChat(OpenAIServing):
) )
reasoning_parser = self.reasoning_parser( reasoning_parser = self.reasoning_parser(
tokenizer, tokenizer,
chat_template_kwargs=chat_template_kwargs, # type: ignore[call-arg] chat_template_kwargs=chat_template_kwargs or {}, # type: ignore[call-arg]
) )
except RuntimeError as e: except RuntimeError as e:
logger.exception("Error in reasoning parser creation.") logger.exception("Error in reasoning parser creation.")
...@@ -955,8 +963,17 @@ class OpenAIServingChat(OpenAIServing): ...@@ -955,8 +963,17 @@ class OpenAIServingChat(OpenAIServing):
index=i, index=i,
) )
else: else:
# Generate ID based on tokenizer type
if isinstance(tokenizer, MistralTokenizer):
tool_call_id = MistralToolCall.generate_random_id()
else:
tool_call_id = make_tool_call_id(
id_type=self.tool_call_id_type,
func_name=tool_choice_function_name,
idx=history_tool_call_cnt,
)
delta_tool_call = DeltaToolCall( delta_tool_call = DeltaToolCall(
id=make_tool_call_id(), id=tool_call_id,
type="function", type="function",
function=DeltaFunctionCall( function=DeltaFunctionCall(
name=tool_choice_function_name, name=tool_choice_function_name,
...@@ -1387,9 +1404,11 @@ class OpenAIServingChat(OpenAIServing): ...@@ -1387,9 +1404,11 @@ class OpenAIServingChat(OpenAIServing):
request_id: str, request_id: str,
model_name: str, model_name: str,
conversation: list[ConversationMessage], conversation: list[ConversationMessage],
tokenizer: TokenizerLike | None, tokenizer: TokenizerLike,
request_metadata: RequestResponseMetadata, request_metadata: RequestResponseMetadata,
) -> ErrorResponse | ChatCompletionResponse: ) -> ErrorResponse | ChatCompletionResponse:
from vllm.tokenizers.mistral import MistralTokenizer
created_time = int(time.time()) created_time = int(time.time())
final_res: RequestOutput | None = None final_res: RequestOutput | None = None
...@@ -1524,39 +1543,85 @@ class OpenAIServingChat(OpenAIServing): ...@@ -1524,39 +1543,85 @@ class OpenAIServingChat(OpenAIServing):
tool_call_class = ( tool_call_class = (
MistralToolCall if isinstance(tokenizer, MistralTokenizer) else ToolCall MistralToolCall if isinstance(tokenizer, MistralTokenizer) else ToolCall
) )
if (not self.enable_auto_tools or not self.tool_parser) and ( if self.use_harmony:
# Harmony models already have parsed content and tool_calls
# through parse_chat_output. Respect its output directly.
message = ChatMessage(
role=role,
reasoning=reasoning,
content=content,
tool_calls=tool_calls if tool_calls else [],
)
elif (not self.enable_auto_tools or not self.tool_parser) and (
not isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam) not isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam)
and request.tool_choice != "required" and request.tool_choice != "required"
): ):
message = ChatMessage(role=role, reasoning=reasoning, content=content) message = ChatMessage(role=role, reasoning=reasoning, content=content)
# if the request uses tools and specified a tool choice
elif ( elif (
request.tool_choice request.tool_choice
and type(request.tool_choice) is ChatCompletionNamedToolChoiceParam and type(request.tool_choice) is ChatCompletionNamedToolChoiceParam
): ):
assert tool_calls is not None and len(tool_calls) > 0 assert tool_calls is not None and len(tool_calls) > 0
tool_call_class_items = []
for idx, tc in enumerate(tool_calls):
# Use native ID if available (e.g., Kimi K2),
# otherwise generate ID with correct id_type
if tc.id:
tool_call_class_items.append(
tool_call_class(id=tc.id, function=tc)
)
else:
# Generate ID using the correct format (kimi_k2 or random),
# but leave it to the class if it's Mistral to preserve
# 9-char IDs
if isinstance(tokenizer, MistralTokenizer):
tool_call_class_items.append(tool_call_class(function=tc))
else:
generated_id = make_tool_call_id(
id_type=self.tool_call_id_type,
func_name=tc.name,
idx=history_tool_call_cnt + idx,
)
tool_call_class_items.append(
tool_call_class(id=generated_id, function=tc)
)
history_tool_call_cnt += 1
message = ChatMessage( message = ChatMessage(
role=role, role=role,
reasoning=reasoning, reasoning=reasoning,
content="", content="",
tool_calls=[tool_call_class(function=tc) for tc in tool_calls], tool_calls=tool_call_class_items,
) )
elif request.tool_choice and request.tool_choice == "required": elif request.tool_choice and request.tool_choice == "required":
tool_call_class_items = [] tool_call_class_items = []
assert tool_calls is not None and len(tool_calls) > 0 assert tool_calls is not None and len(tool_calls) > 0
for tool_call in tool_calls: for idx, tool_call in enumerate(tool_calls):
tool_call_class_items.append( # Use native ID if available,
tool_call_class( # otherwise generate ID with correct id_type
id=make_tool_call_id( if tool_call.id:
tool_call_class_items.append(
tool_call_class(id=tool_call.id, function=tool_call)
)
else:
# Generate ID using the correct format (kimi_k2 or random),
# but leave it to the class if it's Mistral to preserve
# 9-char IDs
if isinstance(tokenizer, MistralTokenizer):
tool_call_class_items.append(
tool_call_class(function=tool_call)
)
else:
generated_id = make_tool_call_id(
id_type=self.tool_call_id_type, id_type=self.tool_call_id_type,
func_name=tool_call.name, func_name=tool_call.name,
idx=history_tool_call_cnt, idx=history_tool_call_cnt + idx,
), )
function=tool_call, tool_call_class_items.append(
) tool_call_class(id=generated_id, function=tool_call)
) )
history_tool_call_cnt += 1 history_tool_call_cnt += 1
message = ChatMessage( message = ChatMessage(
role=role, role=role,
...@@ -1582,17 +1647,35 @@ class OpenAIServingChat(OpenAIServing): ...@@ -1582,17 +1647,35 @@ class OpenAIServingChat(OpenAIServing):
# call. The same is not true for named function calls # call. The same is not true for named function calls
auto_tools_called = tool_calls is not None and len(tool_calls) > 0 auto_tools_called = tool_calls is not None and len(tool_calls) > 0
if tool_calls: if tool_calls:
tool_call_items = []
for idx, tc in enumerate(tool_calls):
# Use native ID if available (e.g., Kimi K2),
# otherwise generate ID with correct id_type
if tc.id:
tool_call_items.append(
tool_call_class(id=tc.id, function=tc)
)
else:
# Generate ID using the correct format (kimi_k2 or random),
# but leave it to the class if it's Mistral to preserve
# 9-char IDs
if isinstance(tokenizer, MistralTokenizer):
tool_call_items.append(tool_call_class(function=tc))
else:
generated_id = make_tool_call_id(
id_type=self.tool_call_id_type,
func_name=tc.name,
idx=history_tool_call_cnt + idx,
)
tool_call_items.append(
tool_call_class(id=generated_id, function=tc)
)
history_tool_call_cnt += 1
message = ChatMessage( message = ChatMessage(
role=role, role=role,
reasoning=reasoning, reasoning=reasoning,
content=content, content=content,
tool_calls=[ tool_calls=tool_call_items,
ToolCall(
function=tc,
type="function",
)
for tc in tool_calls
],
) )
else: else:
...@@ -1701,13 +1784,11 @@ class OpenAIServingChat(OpenAIServing): ...@@ -1701,13 +1784,11 @@ class OpenAIServingChat(OpenAIServing):
elif choice.message.tool_calls: elif choice.message.tool_calls:
# For tool calls, log the function name and arguments # For tool calls, log the function name and arguments
tool_call_descriptions = [] tool_call_descriptions = []
for tc in choice.message.tool_calls: for tc in choice.message.tool_calls: # type: ignore
if hasattr(tc.function, "name") and hasattr( function_call: FunctionCall = tc.function # type: ignore
tc.function, "arguments" tool_call_descriptions.append(
): f"{function_call.name}({function_call.arguments})"
tool_call_descriptions.append( )
f"{tc.function.name}({tc.function.arguments})"
)
tool_calls_str = ", ".join(tool_call_descriptions) tool_calls_str = ", ".join(tool_call_descriptions)
output_text = f"[tool_calls: {tool_calls_str}]" output_text = f"[tool_calls: {tool_calls_str}]"
...@@ -1895,7 +1976,7 @@ class OpenAIServingChat(OpenAIServing): ...@@ -1895,7 +1976,7 @@ class OpenAIServingChat(OpenAIServing):
# because of issues with pydantic we need to potentially # because of issues with pydantic we need to potentially
# re-serialize the tool_calls field of the request # re-serialize the tool_calls field of the request
# for more info: see comment in `maybe_serialize_tool_calls` # for more info: see comment in `maybe_serialize_tool_calls`
maybe_serialize_tool_calls(request) maybe_serialize_tool_calls(request) # type: ignore[arg-type]
# Add system message. # Add system message.
# NOTE: In Chat Completion API, browsing is enabled by default # NOTE: In Chat Completion API, browsing is enabled by default
...@@ -1913,7 +1994,7 @@ class OpenAIServingChat(OpenAIServing): ...@@ -1913,7 +1994,7 @@ class OpenAIServingChat(OpenAIServing):
# Add developer message. # Add developer message.
if request.tools: if request.tools:
dev_msg = get_developer_message( dev_msg = get_developer_message(
tools=request.tools if should_include_tools else None tools=request.tools if should_include_tools else None # type: ignore[arg-type]
) )
messages.append(dev_msg) messages.append(dev_msg)
......
...@@ -85,6 +85,12 @@ class FrontendArgs: ...@@ -85,6 +85,12 @@ class FrontendArgs:
"""Log level for uvicorn.""" """Log level for uvicorn."""
disable_uvicorn_access_log: bool = False disable_uvicorn_access_log: bool = False
"""Disable uvicorn access log.""" """Disable uvicorn access log."""
disable_access_log_for_endpoints: str | None = None
"""Comma-separated list of endpoint paths to exclude from uvicorn access
logs. This is useful to reduce log noise from high-frequency endpoints
like health checks. Example: "/health,/metrics,/ping".
When set, access logs for requests to these paths will be suppressed
while keeping logs for other endpoints."""
allow_credentials: bool = False allow_credentials: bool = False
"""Allow credentials.""" """Allow credentials."""
allowed_origins: list[str] = field(default_factory=lambda: ["*"]) allowed_origins: list[str] = field(default_factory=lambda: ["*"])
...@@ -244,6 +250,11 @@ class FrontendArgs: ...@@ -244,6 +250,11 @@ class FrontendArgs:
del frontend_kwargs["middleware"]["nargs"] del frontend_kwargs["middleware"]["nargs"]
frontend_kwargs["middleware"]["default"] = [] frontend_kwargs["middleware"]["default"] = []
# Special case: disable_access_log_for_endpoints is a single
# comma-separated string, not a list
if "nargs" in frontend_kwargs["disable_access_log_for_endpoints"]:
del frontend_kwargs["disable_access_log_for_endpoints"]["nargs"]
# Special case: Tool call parser shows built-in options. # Special case: Tool call parser shows built-in options.
valid_tool_parsers = list(ToolParserManager.list_registered()) valid_tool_parsers = list(ToolParserManager.list_registered())
parsers_str = ",".join(valid_tool_parsers) parsers_str = ",".join(valid_tool_parsers)
......
...@@ -36,6 +36,7 @@ from vllm.entrypoints.renderer import RenderConfig ...@@ -36,6 +36,7 @@ from vllm.entrypoints.renderer import RenderConfig
from vllm.entrypoints.utils import get_max_tokens, should_include_usage from vllm.entrypoints.utils import get_max_tokens, should_include_usage
from vllm.exceptions import VLLMValidationError from vllm.exceptions import VLLMValidationError
from vllm.inputs.data import EmbedsPrompt, TokensPrompt, is_embeds_prompt from vllm.inputs.data import EmbedsPrompt, TokensPrompt, is_embeds_prompt
from vllm.inputs.parse import get_prompt_components
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.logprobs import Logprob from vllm.logprobs import Logprob
from vllm.outputs import RequestOutput from vllm.outputs import RequestOutput
...@@ -162,25 +163,12 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -162,25 +163,12 @@ class OpenAIServingCompletion(OpenAIServing):
generators: list[AsyncGenerator[RequestOutput, None]] = [] generators: list[AsyncGenerator[RequestOutput, None]] = []
try: try:
for i, engine_prompt in enumerate(engine_prompts): for i, engine_prompt in enumerate(engine_prompts):
prompt_text, prompt_token_ids, prompt_embeds = ( prompt_text, _, _ = get_prompt_components(engine_prompt)
self._get_prompt_components(engine_prompt)
)
input_length = None
if prompt_token_ids is not None:
input_length = len(prompt_token_ids)
elif prompt_embeds is not None:
input_length = len(prompt_embeds)
else:
raise NotImplementedError
if self.default_sampling_params is None:
self.default_sampling_params = {}
max_tokens = get_max_tokens( max_tokens = get_max_tokens(
max_model_len=self.max_model_len, max_model_len=self.max_model_len,
request=request, request=request,
input_length=input_length, prompt=engine_prompt,
default_sampling_params=self.default_sampling_params, default_sampling_params=self.default_sampling_params,
) )
......
...@@ -218,6 +218,10 @@ def get_logits_processors( ...@@ -218,6 +218,10 @@ def get_logits_processors(
class FunctionCall(OpenAIBaseModel): class FunctionCall(OpenAIBaseModel):
# Internal field to preserve native tool call ID from tool parser.
# Excluded from serialization to maintain OpenAI API compatibility
# (function object should only contain 'name' and 'arguments').
id: str | None = Field(default=None, exclude=True)
name: str name: str
arguments: str arguments: str
......
...@@ -64,13 +64,12 @@ from vllm.entrypoints.openai.translations.protocol import ( ...@@ -64,13 +64,12 @@ from vllm.entrypoints.openai.translations.protocol import (
from vllm.entrypoints.pooling.classify.protocol import ( from vllm.entrypoints.pooling.classify.protocol import (
ClassificationChatRequest, ClassificationChatRequest,
ClassificationCompletionRequest, ClassificationCompletionRequest,
ClassificationRequest,
ClassificationResponse, ClassificationResponse,
) )
from vllm.entrypoints.pooling.embed.protocol import ( from vllm.entrypoints.pooling.embed.protocol import (
EmbeddingBytesResponse,
EmbeddingChatRequest, EmbeddingChatRequest,
EmbeddingCompletionRequest, EmbeddingCompletionRequest,
EmbeddingRequest,
EmbeddingResponse, EmbeddingResponse,
) )
from vllm.entrypoints.pooling.pooling.protocol import ( from vllm.entrypoints.pooling.pooling.protocol import (
...@@ -95,11 +94,14 @@ from vllm.entrypoints.serve.tokenize.protocol import ( ...@@ -95,11 +94,14 @@ from vllm.entrypoints.serve.tokenize.protocol import (
TokenizeCompletionRequest, TokenizeCompletionRequest,
TokenizeResponse, TokenizeResponse,
) )
from vllm.entrypoints.utils import _validate_truncation_size, sanitize_message from vllm.entrypoints.utils import (
_validate_truncation_size,
get_max_tokens,
sanitize_message,
)
from vllm.exceptions import VLLMValidationError from vllm.exceptions import VLLMValidationError
from vllm.inputs.data import PromptType, TokensPrompt from vllm.inputs.data import PromptType, TokensPrompt
from vllm.inputs.parse import ( from vllm.inputs.parse import (
PromptComponents,
get_prompt_components, get_prompt_components,
is_explicit_encoder_decoder_prompt, is_explicit_encoder_decoder_prompt,
) )
...@@ -170,6 +172,7 @@ AnyResponse: TypeAlias = ( ...@@ -170,6 +172,7 @@ AnyResponse: TypeAlias = (
CompletionResponse CompletionResponse
| ChatCompletionResponse | ChatCompletionResponse
| EmbeddingResponse | EmbeddingResponse
| EmbeddingBytesResponse
| TranscriptionResponse | TranscriptionResponse
| TokenizeResponse | TokenizeResponse
| PoolingResponse | PoolingResponse
...@@ -183,51 +186,21 @@ RequestT = TypeVar("RequestT", bound=AnyRequest) ...@@ -183,51 +186,21 @@ RequestT = TypeVar("RequestT", bound=AnyRequest)
@dataclass(kw_only=True) @dataclass(kw_only=True)
class RequestProcessingMixin: class ServeContext(Generic[RequestT]):
"""
Mixin for request processing,
handling prompt preparation and engine input.
"""
engine_prompts: list[TokensPrompt] | None = field(default_factory=list)
@dataclass(kw_only=True)
class ResponseGenerationMixin:
"""
Mixin for response generation,
managing result generators and final batch results.
"""
result_generator: (
AsyncGenerator[tuple[int, RequestOutput | PoolingRequestOutput], None] | None
) = None
final_res_batch: list[RequestOutput | PoolingRequestOutput] = field(
default_factory=list
)
model_config = ConfigDict(arbitrary_types_allowed=True)
@dataclass(kw_only=True)
class ServeContext(RequestProcessingMixin, ResponseGenerationMixin, Generic[RequestT]):
request: RequestT request: RequestT
raw_request: Request | None = None raw_request: Request | None = None
model_name: str model_name: str
request_id: str request_id: str
created_time: int = field(default_factory=lambda: int(time.time())) created_time: int = field(default_factory=lambda: int(time.time()))
lora_request: LoRARequest | None = None lora_request: LoRARequest | None = None
engine_prompts: list[TokensPrompt] | None = None
result_generator: AsyncGenerator[tuple[int, PoolingRequestOutput], None] | None = (
None
)
final_res_batch: list[PoolingRequestOutput] = field(default_factory=list)
@dataclass(kw_only=True) model_config = ConfigDict(arbitrary_types_allowed=True)
class ClassificationServeContext(ServeContext[ClassificationRequest]):
pass
@dataclass(kw_only=True)
class EmbeddingServeContext(ServeContext[EmbeddingRequest]):
chat_template: str | None = None
chat_template_content_format: ChatTemplateContentFormatOption
class OpenAIServing: class OpenAIServing:
...@@ -605,10 +578,7 @@ class OpenAIServing: ...@@ -605,10 +578,7 @@ class OpenAIServing:
self, self,
ctx: ServeContext, ctx: ServeContext,
) -> AnyResponse | ErrorResponse: ) -> AnyResponse | ErrorResponse:
generation: AsyncGenerator[AnyResponse | ErrorResponse, None] async for response in self._pipeline(ctx):
generation = self._pipeline(ctx)
async for response in generation:
return response return response
return self.create_error_response("No response yielded from pipeline") return self.create_error_response("No response yielded from pipeline")
...@@ -667,9 +637,7 @@ class OpenAIServing: ...@@ -667,9 +637,7 @@ class OpenAIServing:
ctx: ServeContext, ctx: ServeContext,
) -> ErrorResponse | None: ) -> ErrorResponse | None:
"""Schedule the request and get the result generator.""" """Schedule the request and get the result generator."""
generators: list[ generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
AsyncGenerator[RequestOutput | PoolingRequestOutput, None]
] = []
try: try:
trace_headers = ( trace_headers = (
...@@ -723,7 +691,7 @@ class OpenAIServing: ...@@ -723,7 +691,7 @@ class OpenAIServing:
return self.create_error_response("Engine prompts not available") return self.create_error_response("Engine prompts not available")
num_prompts = len(ctx.engine_prompts) num_prompts = len(ctx.engine_prompts)
final_res_batch: list[RequestOutput | PoolingRequestOutput | None] final_res_batch: list[PoolingRequestOutput | None]
final_res_batch = [None] * num_prompts final_res_batch = [None] * num_prompts
if ctx.result_generator is None: if ctx.result_generator is None:
...@@ -1012,7 +980,7 @@ class OpenAIServing: ...@@ -1012,7 +980,7 @@ class OpenAIServing:
def _validate_input( def _validate_input(
self, self,
request: AnyRequest, request: object,
input_ids: list[int], input_ids: list[int],
input_text: str, input_text: str,
) -> TokensPrompt: ) -> TokensPrompt:
...@@ -1323,7 +1291,7 @@ class OpenAIServing: ...@@ -1323,7 +1291,7 @@ class OpenAIServing:
priority: int = 0, priority: int = 0,
**kwargs, **kwargs,
): ):
prompt_text, _, _ = self._get_prompt_components(engine_prompt) prompt_text, _, _ = get_prompt_components(engine_prompt)
orig_priority = priority orig_priority = priority
sub_request = 0 sub_request = 0
...@@ -1374,10 +1342,12 @@ class OpenAIServing: ...@@ -1374,10 +1342,12 @@ class OpenAIServing:
# yield context # yield context
# Create inputs for the next turn. # Create inputs for the next turn.
# Render the next prompt token ids. # Render the next prompt token ids and update sampling_params.
if isinstance(context, (HarmonyContext, StreamingHarmonyContext)): if isinstance(context, (HarmonyContext, StreamingHarmonyContext)):
prompt_token_ids = context.render_for_completion() token_ids = context.render_for_completion()
engine_prompt = TokensPrompt(prompt_token_ids=prompt_token_ids) engine_prompt = TokensPrompt(prompt_token_ids=token_ids)
sampling_params.max_tokens = self.max_model_len - len(token_ids)
elif isinstance(context, ParsableContext): elif isinstance(context, ParsableContext):
engine_prompts = await self._render_next_turn( engine_prompts = await self._render_next_turn(
context.request, context.request,
...@@ -1389,19 +1359,19 @@ class OpenAIServing: ...@@ -1389,19 +1359,19 @@ class OpenAIServing:
context.chat_template_content_format, context.chat_template_content_format,
) )
engine_prompt = engine_prompts[0] engine_prompt = engine_prompts[0]
prompt_text, _, _ = self._get_prompt_components(engine_prompt) prompt_text, _, _ = get_prompt_components(engine_prompt)
sampling_params.max_tokens = get_max_tokens(
self.max_model_len,
context.request,
engine_prompt,
self.default_sampling_params, # type: ignore
)
# Update the sampling params.
sampling_params.max_tokens = self.max_model_len - len(
engine_prompt["prompt_token_ids"]
)
# OPTIMIZATION # OPTIMIZATION
priority = orig_priority - 1 priority = orig_priority - 1
sub_request += 1 sub_request += 1
def _get_prompt_components(self, prompt: PromptType) -> PromptComponents:
return get_prompt_components(prompt)
def _log_inputs( def _log_inputs(
self, self,
request_id: str, request_id: str,
...@@ -1412,7 +1382,7 @@ class OpenAIServing: ...@@ -1412,7 +1382,7 @@ class OpenAIServing:
if self.request_logger is None: if self.request_logger is None:
return return
prompt, prompt_token_ids, prompt_embeds = self._get_prompt_components(inputs) prompt, prompt_token_ids, prompt_embeds = get_prompt_components(inputs)
self.request_logger.log_inputs( self.request_logger.log_inputs(
request_id, request_id,
...@@ -1526,6 +1496,7 @@ class OpenAIServing: ...@@ -1526,6 +1496,7 @@ class OpenAIServing:
# extract_tool_calls() returns a list of tool calls. # extract_tool_calls() returns a list of tool calls.
function_calls.extend( function_calls.extend(
FunctionCall( FunctionCall(
id=tool_call.id,
name=tool_call.function.name, name=tool_call.function.name,
arguments=tool_call.function.arguments, arguments=tool_call.function.arguments,
) )
......
...@@ -63,6 +63,7 @@ from vllm.engine.protocol import EngineClient ...@@ -63,6 +63,7 @@ from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import ( from vllm.entrypoints.chat_utils import (
ChatCompletionMessageParam, ChatCompletionMessageParam,
ChatTemplateContentFormatOption, ChatTemplateContentFormatOption,
make_tool_call_id,
) )
from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.mcp.tool_server import ToolServer from vllm.entrypoints.mcp.tool_server import ToolServer
...@@ -115,6 +116,7 @@ from vllm.entrypoints.openai.responses.utils import ( ...@@ -115,6 +116,7 @@ from vllm.entrypoints.openai.responses.utils import (
extract_tool_types, extract_tool_types,
should_continue_final_message, should_continue_final_message,
) )
from vllm.entrypoints.utils import get_max_tokens
from vllm.exceptions import VLLMValidationError from vllm.exceptions import VLLMValidationError
from vllm.inputs.data import TokensPrompt from vllm.inputs.data import TokensPrompt
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -250,6 +252,17 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -250,6 +252,17 @@ class OpenAIServingResponses(OpenAIServing):
self.default_sampling_params["stop_token_ids"].extend( self.default_sampling_params["stop_token_ids"].extend(
get_stop_tokens_for_assistant_actions() get_stop_tokens_for_assistant_actions()
) )
# Handle tool call ID type for Kimi K2 (supporting test mocking via overrides)
hf_overrides = getattr(self.model_config, "hf_overrides", None)
if self.model_config.hf_text_config.model_type == "kimi_k2" or (
isinstance(hf_overrides, dict)
and hf_overrides.get("model_type") == "kimi_k2"
):
self.tool_call_id_type = "kimi_k2"
else:
self.tool_call_id_type = "random"
self.enable_auto_tools = enable_auto_tools self.enable_auto_tools = enable_auto_tools
# set up tool use # set up tool use
self.tool_parser = self._get_tool_parser( self.tool_parser = self._get_tool_parser(
...@@ -423,8 +436,11 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -423,8 +436,11 @@ class OpenAIServingResponses(OpenAIServing):
if maybe_error is not None: if maybe_error is not None:
return maybe_error return maybe_error
default_max_tokens = self.max_model_len - len( default_max_tokens = get_max_tokens(
engine_prompt["prompt_token_ids"] self.max_model_len,
request,
engine_prompt,
self.default_sampling_params,
) )
sampling_params = request.to_sampling_params( sampling_params = request.to_sampling_params(
...@@ -954,25 +970,28 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -954,25 +970,28 @@ class OpenAIServingResponses(OpenAIServing):
enable_auto_tools=self.enable_auto_tools, enable_auto_tools=self.enable_auto_tools,
tool_parser_cls=self.tool_parser, tool_parser_cls=self.tool_parser,
) )
if content:
output_text = ResponseOutputText( if content or (self.use_harmony and tool_calls):
text=content, res_text_part = None
annotations=[], # TODO if content:
type="output_text", res_text_part = ResponseOutputText(
logprobs=( text=content,
self._create_response_logprobs( annotations=[], # TODO
token_ids=final_output.token_ids, type="output_text",
logprobs=final_output.logprobs, logprobs=(
tokenizer=tokenizer, self._create_response_logprobs(
top_logprobs=request.top_logprobs, token_ids=final_output.token_ids,
) logprobs=final_output.logprobs,
if request.is_include_output_logprobs() tokenizer=tokenizer,
else None top_logprobs=request.top_logprobs,
), )
) if request.is_include_output_logprobs()
else None
),
)
message_item = ResponseOutputMessage( message_item = ResponseOutputMessage(
id=f"msg_{random_uuid()}", id=f"msg_{random_uuid()}",
content=[output_text], content=[res_text_part] if res_text_part else [],
role="assistant", role="assistant",
status="completed", status="completed",
type="message", type="message",
...@@ -984,17 +1003,28 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -984,17 +1003,28 @@ class OpenAIServingResponses(OpenAIServing):
if message_item: if message_item:
outputs.append(message_item) outputs.append(message_item)
if tool_calls: if tool_calls:
tool_call_items = [ # We use a simple counter for history_tool_call_count because
ResponseFunctionToolCall( # we don't track the history of tool calls in the Responses API yet.
id=f"fc_{random_uuid()}", # This means that the tool call index will start from 0 for each
call_id=f"call_{random_uuid()}", # request.
type="function_call", tool_call_items = []
status="completed", for history_tool_call_cnt, tool_call in enumerate(tool_calls):
name=tool_call.name, tool_call_items.append(
arguments=tool_call.arguments, ResponseFunctionToolCall(
id=f"fc_{random_uuid()}",
call_id=tool_call.id
if tool_call.id
else make_tool_call_id(
id_type=self.tool_call_id_type,
func_name=tool_call.name,
idx=history_tool_call_cnt,
),
type="function_call",
status="completed",
name=tool_call.name,
arguments=tool_call.arguments,
)
) )
for tool_call in tool_calls
]
outputs.extend(tool_call_items) outputs.extend(tool_call_items)
return outputs return outputs
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment