Commit c721b814 authored by zhuwenwen
Browse files

sync v0.15.1

parent d53fe7e5
...@@ -280,10 +280,9 @@ class DynamicShapesConfig: ...@@ -280,10 +280,9 @@ class DynamicShapesConfig:
until this change picked up https://github.com/pytorch/pytorch/pull/169239. until this change picked up https://github.com/pytorch/pytorch/pull/169239.
""" """
assume_32_bit_indexing: bool = False assume_32_bit_indexing: bool = True
""" """
whether all tensor sizes can use 32 bit indexing. whether all tensor sizes can use 32 bit indexing.
`True` requires PyTorch 2.10+
""" """
def compute_hash(self) -> str: def compute_hash(self) -> str:
......
...@@ -34,7 +34,6 @@ MTPModelTypes = Literal[ ...@@ -34,7 +34,6 @@ MTPModelTypes = Literal[
"mimo_mtp", "mimo_mtp",
"glm4_moe_mtp", "glm4_moe_mtp",
"glm4_moe_lite_mtp", "glm4_moe_lite_mtp",
"glm_ocr_mtp",
"ernie_mtp", "ernie_mtp",
"exaone_moe_mtp", "exaone_moe_mtp",
"qwen3_next_mtp", "qwen3_next_mtp",
...@@ -223,17 +222,6 @@ class SpeculativeConfig: ...@@ -223,17 +222,6 @@ class SpeculativeConfig:
} }
) )
if hf_config.architectures[0] == "GlmOcrForConditionalGeneration":
hf_config.model_type = "glm_ocr_mtp"
n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
hf_config.update(
{
"num_hidden_layers": 0,
"n_predict": n_predict,
"architectures": ["GlmOcrMTPModel"],
}
)
if hf_config.model_type == "ernie4_5_moe": if hf_config.model_type == "ernie4_5_moe":
hf_config.model_type = "ernie_mtp" hf_config.model_type = "ernie_mtp"
if hf_config.model_type == "ernie_mtp": if hf_config.model_type == "ernie_mtp":
......
...@@ -59,7 +59,7 @@ class NaiveAll2AllManager(All2AllManagerBase): ...@@ -59,7 +59,7 @@ class NaiveAll2AllManager(All2AllManagerBase):
return buffer return buffer
def dispatch_router_logits( def dispatch(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
...@@ -84,34 +84,6 @@ class NaiveAll2AllManager(All2AllManagerBase): ...@@ -84,34 +84,6 @@ class NaiveAll2AllManager(All2AllManagerBase):
return hidden_states, router_logits return hidden_states, router_logits
def dispatch(
self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
if extra_tensors is not None:
raise NotImplementedError(
"extra_tensors is not supported for NaiveAll2AllManager"
)
sp_size = self.tp_group.world_size if is_sequence_parallel else 1
dp_metadata = get_forward_context().dp_metadata
assert dp_metadata is not None
cu_tokens_across_sp_cpu = dp_metadata.cu_tokens_across_sp(sp_size)
hidden_states = self.naive_multicast(
hidden_states, cu_tokens_across_sp_cpu, is_sequence_parallel
)
topk_weights = self.naive_multicast(
topk_weights, cu_tokens_across_sp_cpu, is_sequence_parallel
)
topk_ids = self.naive_multicast(
topk_ids, cu_tokens_across_sp_cpu, is_sequence_parallel
)
return hidden_states, topk_weights, topk_ids
def combine( def combine(
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
) -> torch.Tensor: ) -> torch.Tensor:
...@@ -142,7 +114,7 @@ class AgRsAll2AllManager(All2AllManagerBase): ...@@ -142,7 +114,7 @@ class AgRsAll2AllManager(All2AllManagerBase):
def __init__(self, cpu_group): def __init__(self, cpu_group):
super().__init__(cpu_group) super().__init__(cpu_group)
def dispatch_router_logits( def dispatch(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
...@@ -176,46 +148,6 @@ class AgRsAll2AllManager(All2AllManagerBase): ...@@ -176,46 +148,6 @@ class AgRsAll2AllManager(All2AllManagerBase):
return (gathered_tensors[0], gathered_tensors[1], gathered_tensors[2:]) return (gathered_tensors[0], gathered_tensors[1], gathered_tensors[2:])
return gathered_tensors[0], gathered_tensors[1] return gathered_tensors[0], gathered_tensors[1]
def dispatch(
    self,
    hidden_states: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    is_sequence_parallel: bool = False,
    extra_tensors: list[torch.Tensor] | None = None,
) -> (
    tuple[torch.Tensor, torch.Tensor, torch.Tensor]
    | tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
    """
    Gather hidden_states, topk_weights and topk_ids (plus any
    extra_tensors) across the dispatch group with a single all_gatherv
    along dim 0.

    The group is the EP group when ``is_sequence_parallel`` is true and the
    DP group otherwise. A fourth element holding the gathered extras is
    returned only when ``extra_tensors`` was provided.
    """
    dp_metadata = get_forward_context().dp_metadata
    assert dp_metadata is not None
    sizes = dp_metadata.get_chunk_sizes_across_dp_rank()
    assert sizes is not None

    if is_sequence_parallel:
        dist_group = get_ep_group()
    else:
        dist_group = get_dp_group()
    # This rank's chunk size must match the number of local rows.
    assert sizes[dist_group.rank_in_group] == hidden_states.shape[0]

    has_extras = extra_tensors is not None
    tensors_to_gather = [hidden_states, topk_weights, topk_ids]
    if has_extras:
        tensors_to_gather.extend(extra_tensors)

    gathered = dist_group.all_gatherv(tensors_to_gather, dim=0, sizes=sizes)

    core = (gathered[0], gathered[1], gathered[2])
    if has_extras:
        return (*core, gathered[3:])
    return core
def combine( def combine(
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
) -> torch.Tensor: ) -> torch.Tensor:
...@@ -284,7 +216,7 @@ class PPLXAll2AllManager(All2AllManagerBase): ...@@ -284,7 +216,7 @@ class PPLXAll2AllManager(All2AllManagerBase):
pplx.AllToAll.internode if self.internode else pplx.AllToAll.intranode, pplx.AllToAll.internode if self.internode else pplx.AllToAll.intranode,
) )
def dispatch_router_logits( def dispatch(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
...@@ -293,19 +225,6 @@ class PPLXAll2AllManager(All2AllManagerBase): ...@@ -293,19 +225,6 @@ class PPLXAll2AllManager(All2AllManagerBase):
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
raise NotImplementedError raise NotImplementedError
def dispatch(
self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
) -> (
tuple[torch.Tensor, torch.Tensor, torch.Tensor]
| tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
raise NotImplementedError
def combine( def combine(
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
) -> torch.Tensor: ) -> torch.Tensor:
...@@ -345,7 +264,7 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase): ...@@ -345,7 +264,7 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase):
def get_handle(self, kwargs): def get_handle(self, kwargs):
raise NotImplementedError raise NotImplementedError
def dispatch_router_logits( def dispatch(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
...@@ -354,19 +273,6 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase): ...@@ -354,19 +273,6 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase):
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
raise NotImplementedError raise NotImplementedError
def dispatch(
self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
) -> (
tuple[torch.Tensor, torch.Tensor, torch.Tensor]
| tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
raise NotImplementedError
def combine( def combine(
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
) -> torch.Tensor: ) -> torch.Tensor:
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import threading import threading
from typing import Any
from weakref import WeakValueDictionary from weakref import WeakValueDictionary
import torch import torch
...@@ -63,32 +64,13 @@ class All2AllManagerBase: ...@@ -63,32 +64,13 @@ class All2AllManagerBase:
# and reuse it for the same config. # and reuse it for the same config.
raise NotImplementedError raise NotImplementedError
def dispatch_router_logits( def dispatch(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
is_sequence_parallel: bool = False, is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None, extra_tensors: list[torch.Tensor] | None = None,
) -> ( ) -> Any:
tuple[torch.Tensor, torch.Tensor]
| tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
# Subclasses should either:
# - implement handling for extra_tensors, or
# - raise a clear error if extra_tensors is not supported.
raise NotImplementedError
def dispatch(
self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
) -> (
tuple[torch.Tensor, torch.Tensor, torch.Tensor]
| tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
# Subclasses should either: # Subclasses should either:
# - implement handling for extra_tensors, or # - implement handling for extra_tensors, or
# - raise a clear error if extra_tensors is not supported. # - raise a clear error if extra_tensors is not supported.
...@@ -298,7 +280,7 @@ class DeviceCommunicatorBase: ...@@ -298,7 +280,7 @@ class DeviceCommunicatorBase:
for module in moe_modules: for module in moe_modules:
module.maybe_init_modular_kernel() module.maybe_init_modular_kernel()
def dispatch_router_logits( def dispatch(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
...@@ -312,29 +294,8 @@ class DeviceCommunicatorBase: ...@@ -312,29 +294,8 @@ class DeviceCommunicatorBase:
Dispatch the hidden states and router logits to the appropriate device. Dispatch the hidden states and router logits to the appropriate device.
This is a no-op in the base class. This is a no-op in the base class.
""" """
if extra_tensors is not None:
return hidden_states, router_logits, extra_tensors
return hidden_states, router_logits return hidden_states, router_logits
def dispatch(
self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
) -> (
tuple[torch.Tensor, torch.Tensor, torch.Tensor]
| tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
"""
Dispatch the hidden states and topk weights/ids to the appropriate device.
This is a no-op in the base class.
"""
if extra_tensors is not None:
return hidden_states, topk_weights, topk_ids, extra_tensors
return hidden_states, topk_weights, topk_ids
def combine( def combine(
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
) -> torch.Tensor: ) -> torch.Tensor:
......
...@@ -130,65 +130,30 @@ class CpuCommunicator(DeviceCommunicatorBase): ...@@ -130,65 +130,30 @@ class CpuCommunicator(DeviceCommunicatorBase):
) -> dict[str, torch.Tensor | Any]: ) -> dict[str, torch.Tensor | Any]:
return self.dist_module.recv_tensor_dict(src) return self.dist_module.recv_tensor_dict(src)
def dispatch_router_logits( def dispatch(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
is_sequence_parallel: bool = False, is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None, extra_tensors: list[torch.Tensor] | None = None,
) -> ( ) -> tuple[torch.Tensor, torch.Tensor]:
tuple[torch.Tensor, torch.Tensor]
| tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
"""
Dispatch the hidden states and router logits to the appropriate device.
This is a no-op in the base class.
"""
assert self.all2all_manager is not None
return self.all2all_manager.dispatch_router_logits(
hidden_states,
router_logits,
is_sequence_parallel,
extra_tensors,
)
def dispatch(
self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
) -> (
tuple[torch.Tensor, torch.Tensor, torch.Tensor]
| tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
"""
Dispatch the hidden states and topk weights/ids to the appropriate device.
This is a no-op in the base class.
"""
assert self.all2all_manager is not None assert self.all2all_manager is not None
return self.all2all_manager.dispatch( return self.all2all_manager.dispatch(
hidden_states, hidden_states,
topk_weights, router_logits,
topk_ids,
is_sequence_parallel, is_sequence_parallel,
extra_tensors=extra_tensors, extra_tensors, # type: ignore[call-arg]
) )
def combine( def combine(
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
) -> torch.Tensor: ) -> torch.Tensor:
"""
Combine the hidden states and router logits from the appropriate device.
This is a no-op in the base class.
"""
assert self.all2all_manager is not None assert self.all2all_manager is not None
return self.all2all_manager.combine( hidden_states = self.all2all_manager.combine(
hidden_states, hidden_states, is_sequence_parallel
is_sequence_parallel,
) )
return hidden_states
class _CPUSHMDistributed: class _CPUSHMDistributed:
......
...@@ -322,7 +322,7 @@ class CudaCommunicator(DeviceCommunicatorBase): ...@@ -322,7 +322,7 @@ class CudaCommunicator(DeviceCommunicatorBase):
return output_list return output_list
def dispatch_router_logits( def dispatch(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
...@@ -332,52 +332,20 @@ class CudaCommunicator(DeviceCommunicatorBase): ...@@ -332,52 +332,20 @@ class CudaCommunicator(DeviceCommunicatorBase):
tuple[torch.Tensor, torch.Tensor] tuple[torch.Tensor, torch.Tensor]
| tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]] | tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]]
): ):
"""
Dispatch the hidden states and router logits to the appropriate device.
This is a no-op in the base class.
"""
assert self.all2all_manager is not None
return self.all2all_manager.dispatch_router_logits(
hidden_states,
router_logits,
is_sequence_parallel,
extra_tensors,
)
def dispatch(
self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
) -> (
tuple[torch.Tensor, torch.Tensor, torch.Tensor]
| tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
"""
Dispatch the hidden states and topk weights/ids to the appropriate device.
This is a no-op in the base class.
"""
assert self.all2all_manager is not None assert self.all2all_manager is not None
return self.all2all_manager.dispatch( return self.all2all_manager.dispatch(
hidden_states, hidden_states,
topk_weights, router_logits,
topk_ids,
is_sequence_parallel, is_sequence_parallel,
extra_tensors=extra_tensors, extra_tensors, # type: ignore[call-arg]
) )
def combine( def combine(
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
) -> torch.Tensor: ) -> torch.Tensor:
"""
Combine the hidden states and router logits from the appropriate device.
This is a no-op in the base class.
"""
assert self.all2all_manager is not None assert self.all2all_manager is not None
return self.all2all_manager.combine( hidden_states = self.all2all_manager.combine(
hidden_states, hidden_states, is_sequence_parallel
is_sequence_parallel,
) )
return hidden_states
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any
import torch.distributed as dist import torch.distributed as dist
from flashinfer.comm.mnnvl import CommBackend as CommBackend from flashinfer.comm.mnnvl import CommBackend as CommBackend
...@@ -24,15 +23,3 @@ class CustomCommunicator(CommBackend): ...@@ -24,15 +23,3 @@ class CustomCommunicator(CommBackend):
gathered = [None] * self.Get_size() gathered = [None] * self.Get_size()
dist.all_gather_object(gathered, data, group=self._group) dist.all_gather_object(gathered, data, group=self._group)
return gathered return gathered
# NOTE(rob): CommBackend is an abstract class, and bcast/barrier
# are unimplemented on vLLM side. If we need to utilize these
# methods in the future, can create a concrete implementation.
def bcast(self, data: Any, root: int) -> Any:
    """Broadcast is not implemented here; always raises NotImplementedError."""
    raise NotImplementedError()
def barrier(self) -> None:
    """Barrier is not implemented here; always raises NotImplementedError."""
    raise NotImplementedError()
def Split(self, color: int, key: int) -> "CustomCommunicator":
    """Return this communicator itself; ``color`` and ``key`` are ignored."""
    same_communicator = self
    return same_communicator
...@@ -196,62 +196,26 @@ class XpuCommunicator(DeviceCommunicatorBase): ...@@ -196,62 +196,26 @@ class XpuCommunicator(DeviceCommunicatorBase):
def broadcast(self, input_: torch.Tensor, src: int = 0) -> None: def broadcast(self, input_: torch.Tensor, src: int = 0) -> None:
dist.broadcast(input_, src=src, group=self.device_group) dist.broadcast(input_, src=src, group=self.device_group)
def dispatch_router_logits(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
) -> (
tuple[torch.Tensor, torch.Tensor]
| tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
"""
Dispatch the hidden states and router logits to the appropriate device.
This is a no-op in the base class.
"""
assert self.all2all_manager is not None
return self.all2all_manager.dispatch_router_logits(
hidden_states,
router_logits,
is_sequence_parallel,
extra_tensors,
)
def dispatch( def dispatch(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
topk_weights: torch.Tensor, router_logits: torch.Tensor,
topk_ids: torch.Tensor,
is_sequence_parallel: bool = False, is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None, extra_tensors: list[torch.Tensor] | None = None,
) -> ( ) -> tuple[torch.Tensor, torch.Tensor]:
tuple[torch.Tensor, torch.Tensor, torch.Tensor]
| tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
"""
Dispatch the hidden states and topk weights/ids to the appropriate device.
This is a no-op in the base class.
"""
assert self.all2all_manager is not None assert self.all2all_manager is not None
return self.all2all_manager.dispatch( return self.all2all_manager.dispatch(
hidden_states, hidden_states,
topk_weights, router_logits,
topk_ids,
is_sequence_parallel, is_sequence_parallel,
extra_tensors=extra_tensors, extra_tensors, # type: ignore[call-arg]
) )
def combine( def combine(
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
) -> torch.Tensor: ) -> torch.Tensor:
"""
Combine the hidden states and router logits from the appropriate device.
This is a no-op in the base class.
"""
assert self.all2all_manager is not None assert self.all2all_manager is not None
return self.all2all_manager.combine( hidden_states = self.all2all_manager.combine(
hidden_states, hidden_states, is_sequence_parallel
is_sequence_parallel,
) )
return hidden_states
...@@ -298,7 +298,6 @@ class NixlConnectorMetadata(KVConnectorMetadata): ...@@ -298,7 +298,6 @@ class NixlConnectorMetadata(KVConnectorMetadata):
class NixlConnector(KVConnectorBase_V1): class NixlConnector(KVConnectorBase_V1):
def __init__( def __init__(
self, self,
vllm_config: VllmConfig, vllm_config: VllmConfig,
......
...@@ -1000,7 +1000,7 @@ class GroupCoordinator: ...@@ -1000,7 +1000,7 @@ class GroupCoordinator:
if self.device_communicator is not None: if self.device_communicator is not None:
self.device_communicator.prepare_communication_buffer_for_model(model) self.device_communicator.prepare_communication_buffer_for_model(model)
def dispatch_router_logits( def dispatch(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
...@@ -1011,7 +1011,7 @@ class GroupCoordinator: ...@@ -1011,7 +1011,7 @@ class GroupCoordinator:
| tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]] | tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]]
): ):
if self.device_communicator is not None: if self.device_communicator is not None:
return self.device_communicator.dispatch_router_logits( return self.device_communicator.dispatch( # type: ignore[call-arg]
hidden_states, hidden_states,
router_logits, router_logits,
is_sequence_parallel, is_sequence_parallel,
...@@ -1020,28 +1020,6 @@ class GroupCoordinator: ...@@ -1020,28 +1020,6 @@ class GroupCoordinator:
else: else:
return hidden_states, router_logits return hidden_states, router_logits
def dispatch(
self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
) -> (
tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
| tuple[torch.Tensor, torch.Tensor, torch.Tensor]
):
if self.device_communicator is not None:
return self.device_communicator.dispatch(
hidden_states,
topk_weights,
topk_ids,
is_sequence_parallel,
extra_tensors,
)
else:
return hidden_states, topk_weights, topk_ids
def combine( def combine(
self, hidden_states, is_sequence_parallel: bool = False self, hidden_states, is_sequence_parallel: bool = False
) -> torch.Tensor: ) -> torch.Tensor:
......
...@@ -264,39 +264,6 @@ def load_log_config(log_config_file: str | None) -> dict | None: ...@@ -264,39 +264,6 @@ def load_log_config(log_config_file: str | None) -> dict | None:
return None return None
def get_uvicorn_log_config(args: Namespace) -> dict | None:
    """
    Resolve the uvicorn log config from the parsed arguments.

    Resolution order:
    1. The config loaded from ``args.log_config_file``, if that yields one.
    2. A config built with an access-log filter for the comma-separated
       paths in ``args.disable_access_log_for_endpoints``, if that is set.
    3. ``None``, which leaves uvicorn on its defaults.
    """
    file_config = load_log_config(args.log_config_file)
    if file_config is not None:
        return file_config

    endpoints_arg = args.disable_access_log_for_endpoints
    if not endpoints_arg:
        return None

    from vllm.logging_utils import create_uvicorn_log_config

    # Split on commas, trim whitespace, and drop empty entries.
    excluded_paths = [
        path for raw in endpoints_arg.split(",") if (path := raw.strip())
    ]
    return create_uvicorn_log_config(
        excluded_paths=excluded_paths,
        log_level=args.uvicorn_log_level,
    )
class AuthenticationMiddleware: class AuthenticationMiddleware:
""" """
Pure ASGI middleware that authenticates each request by checking Pure ASGI middleware that authenticates each request by checking
......
...@@ -44,7 +44,6 @@ from vllm.entrypoints.openai.engine.protocol import ( ...@@ -44,7 +44,6 @@ from vllm.entrypoints.openai.engine.protocol import (
DeltaMessage, DeltaMessage,
DeltaToolCall, DeltaToolCall,
ErrorResponse, ErrorResponse,
FunctionCall,
PromptTokenUsageInfo, PromptTokenUsageInfo,
RequestResponseMetadata, RequestResponseMetadata,
ToolCall, ToolCall,
...@@ -68,7 +67,6 @@ from vllm.entrypoints.openai.parser.harmony_utils import ( ...@@ -68,7 +67,6 @@ from vllm.entrypoints.openai.parser.harmony_utils import (
from vllm.entrypoints.openai.utils import maybe_filter_parallel_tool_calls from vllm.entrypoints.openai.utils import maybe_filter_parallel_tool_calls
from vllm.entrypoints.utils import get_max_tokens, should_include_usage from vllm.entrypoints.utils import get_max_tokens, should_include_usage
from vllm.inputs.data import TokensPrompt from vllm.inputs.data import TokensPrompt
from vllm.inputs.parse import get_prompt_components
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.logprobs import Logprob from vllm.logprobs import Logprob
from vllm.outputs import CompletionOutput, RequestOutput from vllm.outputs import CompletionOutput, RequestOutput
...@@ -145,6 +143,11 @@ class OpenAIServingChat(OpenAIServing): ...@@ -145,6 +143,11 @@ class OpenAIServingChat(OpenAIServing):
self.enable_prompt_tokens_details = enable_prompt_tokens_details self.enable_prompt_tokens_details = enable_prompt_tokens_details
self.enable_force_include_usage = enable_force_include_usage self.enable_force_include_usage = enable_force_include_usage
self.default_sampling_params = self.model_config.get_diff_sampling_param() self.default_sampling_params = self.model_config.get_diff_sampling_param()
if self.model_config.hf_config.model_type == "kimi_k2":
self.tool_call_id_type = "kimi_k2"
else:
self.tool_call_id_type = "random"
self.use_harmony = self.model_config.hf_config.model_type == "gpt_oss" self.use_harmony = self.model_config.hf_config.model_type == "gpt_oss"
if self.use_harmony: if self.use_harmony:
if "stop_token_ids" not in self.default_sampling_params: if "stop_token_ids" not in self.default_sampling_params:
...@@ -153,16 +156,6 @@ class OpenAIServingChat(OpenAIServing): ...@@ -153,16 +156,6 @@ class OpenAIServingChat(OpenAIServing):
get_stop_tokens_for_assistant_actions() get_stop_tokens_for_assistant_actions()
) )
# Handle tool call ID type for Kimi K2 (supporting test mocking via overrides)
hf_overrides = getattr(self.model_config, "hf_overrides", None)
if self.model_config.hf_text_config.model_type == "kimi_k2" or (
isinstance(hf_overrides, dict)
and hf_overrides.get("model_type") == "kimi_k2"
):
self.tool_call_id_type = "kimi_k2"
else:
self.tool_call_id_type = "random"
# NOTE(woosuk): While OpenAI's chat completion API supports browsing # NOTE(woosuk): While OpenAI's chat completion API supports browsing
# for some models, currently vLLM doesn't support it. Please use the # for some models, currently vLLM doesn't support it. Please use the
# Responses API instead. # Responses API instead.
...@@ -254,8 +247,8 @@ class OpenAIServingChat(OpenAIServing): ...@@ -254,8 +247,8 @@ class OpenAIServingChat(OpenAIServing):
# because of issues with pydantic we need to potentially # because of issues with pydantic we need to potentially
# re-serialize the tool_calls field of the request # re-serialize the tool_calls field of the request
# for more info: see comment in `maybe_serialize_tool_calls` # for more info: see comment in `maybe_serialize_tool_calls`
maybe_serialize_tool_calls(request) # type: ignore[arg-type] maybe_serialize_tool_calls(request)
truncate_tool_call_ids(request) # type: ignore[arg-type] truncate_tool_call_ids(request)
validate_request_params(request) validate_request_params(request)
# Check if tool parsing is unavailable (common condition) # Check if tool parsing is unavailable (common condition)
...@@ -375,18 +368,20 @@ class OpenAIServingChat(OpenAIServing): ...@@ -375,18 +368,20 @@ class OpenAIServingChat(OpenAIServing):
generators: list[AsyncGenerator[RequestOutput, None]] = [] generators: list[AsyncGenerator[RequestOutput, None]] = []
try: try:
for i, engine_prompt in enumerate(engine_prompts): for i, engine_prompt in enumerate(engine_prompts):
prompt_text, _, _ = get_prompt_components(engine_prompt) prompt_text, _, _ = self._get_prompt_components(engine_prompt)
# If we are creating sub requests for multiple prompts, ensure that they # If we are creating sub requests for multiple prompts, ensure that they
# have unique request ids. # have unique request ids.
sub_request_id = ( sub_request_id = (
request_id if len(engine_prompts) == 1 else f"{request_id}_{i}" request_id if len(engine_prompts) == 1 else f"{request_id}_{i}"
) )
if self.default_sampling_params is None:
self.default_sampling_params = {}
max_tokens = get_max_tokens( max_tokens = get_max_tokens(
max_model_len=self.max_model_len, max_model_len=self.max_model_len,
request=request, request=request,
prompt=engine_prompt, input_length=len(engine_prompt["prompt_token_ids"]),
default_sampling_params=self.default_sampling_params, default_sampling_params=self.default_sampling_params,
) )
...@@ -459,7 +454,6 @@ class OpenAIServingChat(OpenAIServing): ...@@ -459,7 +454,6 @@ class OpenAIServingChat(OpenAIServing):
# Streaming response # Streaming response
tokenizer = self.renderer.tokenizer tokenizer = self.renderer.tokenizer
assert tokenizer is not None
if request.stream: if request.stream:
return self.chat_completion_stream_generator( return self.chat_completion_stream_generator(
...@@ -638,11 +632,9 @@ class OpenAIServingChat(OpenAIServing): ...@@ -638,11 +632,9 @@ class OpenAIServingChat(OpenAIServing):
request_id: str, request_id: str,
model_name: str, model_name: str,
conversation: list[ConversationMessage], conversation: list[ConversationMessage],
tokenizer: TokenizerLike, tokenizer: TokenizerLike | None,
request_metadata: RequestResponseMetadata, request_metadata: RequestResponseMetadata,
) -> AsyncGenerator[str, None]: ) -> AsyncGenerator[str, None]:
from vllm.tokenizers.mistral import MistralTokenizer
created_time = int(time.time()) created_time = int(time.time())
chunk_object_type: Final = "chat.completion.chunk" chunk_object_type: Final = "chat.completion.chunk"
first_iteration = True first_iteration = True
...@@ -706,7 +698,7 @@ class OpenAIServingChat(OpenAIServing): ...@@ -706,7 +698,7 @@ class OpenAIServingChat(OpenAIServing):
) )
reasoning_parser = self.reasoning_parser( reasoning_parser = self.reasoning_parser(
tokenizer, tokenizer,
chat_template_kwargs=chat_template_kwargs or {}, # type: ignore[call-arg] chat_template_kwargs=chat_template_kwargs, # type: ignore[call-arg]
) )
except RuntimeError as e: except RuntimeError as e:
logger.exception("Error in reasoning parser creation.") logger.exception("Error in reasoning parser creation.")
...@@ -963,17 +955,8 @@ class OpenAIServingChat(OpenAIServing): ...@@ -963,17 +955,8 @@ class OpenAIServingChat(OpenAIServing):
index=i, index=i,
) )
else: else:
# Generate ID based on tokenizer type
if isinstance(tokenizer, MistralTokenizer):
tool_call_id = MistralToolCall.generate_random_id()
else:
tool_call_id = make_tool_call_id(
id_type=self.tool_call_id_type,
func_name=tool_choice_function_name,
idx=history_tool_call_cnt,
)
delta_tool_call = DeltaToolCall( delta_tool_call = DeltaToolCall(
id=tool_call_id, id=make_tool_call_id(),
type="function", type="function",
function=DeltaFunctionCall( function=DeltaFunctionCall(
name=tool_choice_function_name, name=tool_choice_function_name,
...@@ -1404,11 +1387,9 @@ class OpenAIServingChat(OpenAIServing): ...@@ -1404,11 +1387,9 @@ class OpenAIServingChat(OpenAIServing):
request_id: str, request_id: str,
model_name: str, model_name: str,
conversation: list[ConversationMessage], conversation: list[ConversationMessage],
tokenizer: TokenizerLike, tokenizer: TokenizerLike | None,
request_metadata: RequestResponseMetadata, request_metadata: RequestResponseMetadata,
) -> ErrorResponse | ChatCompletionResponse: ) -> ErrorResponse | ChatCompletionResponse:
from vllm.tokenizers.mistral import MistralTokenizer
created_time = int(time.time()) created_time = int(time.time())
final_res: RequestOutput | None = None final_res: RequestOutput | None = None
...@@ -1543,84 +1524,38 @@ class OpenAIServingChat(OpenAIServing): ...@@ -1543,84 +1524,38 @@ class OpenAIServingChat(OpenAIServing):
tool_call_class = ( tool_call_class = (
MistralToolCall if isinstance(tokenizer, MistralTokenizer) else ToolCall MistralToolCall if isinstance(tokenizer, MistralTokenizer) else ToolCall
) )
if self.use_harmony: if (not self.enable_auto_tools or not self.tool_parser) and (
# Harmony models already have parsed content and tool_calls
# through parse_chat_output. Respect its output directly.
message = ChatMessage(
role=role,
reasoning=reasoning,
content=content,
tool_calls=tool_calls if tool_calls else [],
)
elif (not self.enable_auto_tools or not self.tool_parser) and (
not isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam) not isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam)
and request.tool_choice != "required" and request.tool_choice != "required"
): ):
message = ChatMessage(role=role, reasoning=reasoning, content=content) message = ChatMessage(role=role, reasoning=reasoning, content=content)
# if the request uses tools and specified a tool choice
elif ( elif (
request.tool_choice request.tool_choice
and type(request.tool_choice) is ChatCompletionNamedToolChoiceParam and type(request.tool_choice) is ChatCompletionNamedToolChoiceParam
): ):
assert tool_calls is not None and len(tool_calls) > 0 assert tool_calls is not None and len(tool_calls) > 0
tool_call_class_items = []
for idx, tc in enumerate(tool_calls):
# Use native ID if available (e.g., Kimi K2),
# otherwise generate ID with correct id_type
if tc.id:
tool_call_class_items.append(
tool_call_class(id=tc.id, function=tc)
)
else:
# Generate ID using the correct format (kimi_k2 or random),
# but leave it to the class if it's Mistral to preserve
# 9-char IDs
if isinstance(tokenizer, MistralTokenizer):
tool_call_class_items.append(tool_call_class(function=tc))
else:
generated_id = make_tool_call_id(
id_type=self.tool_call_id_type,
func_name=tc.name,
idx=history_tool_call_cnt + idx,
)
tool_call_class_items.append(
tool_call_class(id=generated_id, function=tc)
)
history_tool_call_cnt += 1
message = ChatMessage( message = ChatMessage(
role=role, role=role,
reasoning=reasoning, reasoning=reasoning,
content="", content="",
tool_calls=tool_call_class_items, tool_calls=[tool_call_class(function=tc) for tc in tool_calls],
) )
elif request.tool_choice and request.tool_choice == "required": elif request.tool_choice and request.tool_choice == "required":
tool_call_class_items = [] tool_call_class_items = []
assert tool_calls is not None and len(tool_calls) > 0 assert tool_calls is not None and len(tool_calls) > 0
for idx, tool_call in enumerate(tool_calls): for tool_call in tool_calls:
# Use native ID if available,
# otherwise generate ID with correct id_type
if tool_call.id:
tool_call_class_items.append(
tool_call_class(id=tool_call.id, function=tool_call)
)
else:
# Generate ID using the correct format (kimi_k2 or random),
# but leave it to the class if it's Mistral to preserve
# 9-char IDs
if isinstance(tokenizer, MistralTokenizer):
tool_call_class_items.append( tool_call_class_items.append(
tool_call_class(function=tool_call) tool_call_class(
) id=make_tool_call_id(
else:
generated_id = make_tool_call_id(
id_type=self.tool_call_id_type, id_type=self.tool_call_id_type,
func_name=tool_call.name, func_name=tool_call.name,
idx=history_tool_call_cnt + idx, idx=history_tool_call_cnt,
),
function=tool_call,
) )
tool_call_class_items.append(
tool_call_class(id=generated_id, function=tool_call)
) )
history_tool_call_cnt += 1 history_tool_call_cnt += 1
message = ChatMessage( message = ChatMessage(
...@@ -1647,35 +1582,17 @@ class OpenAIServingChat(OpenAIServing): ...@@ -1647,35 +1582,17 @@ class OpenAIServingChat(OpenAIServing):
# call. The same is not true for named function calls # call. The same is not true for named function calls
auto_tools_called = tool_calls is not None and len(tool_calls) > 0 auto_tools_called = tool_calls is not None and len(tool_calls) > 0
if tool_calls: if tool_calls:
tool_call_items = []
for idx, tc in enumerate(tool_calls):
# Use native ID if available (e.g., Kimi K2),
# otherwise generate ID with correct id_type
if tc.id:
tool_call_items.append(
tool_call_class(id=tc.id, function=tc)
)
else:
# Generate ID using the correct format (kimi_k2 or random),
# but leave it to the class if it's Mistral to preserve
# 9-char IDs
if isinstance(tokenizer, MistralTokenizer):
tool_call_items.append(tool_call_class(function=tc))
else:
generated_id = make_tool_call_id(
id_type=self.tool_call_id_type,
func_name=tc.name,
idx=history_tool_call_cnt + idx,
)
tool_call_items.append(
tool_call_class(id=generated_id, function=tc)
)
history_tool_call_cnt += 1
message = ChatMessage( message = ChatMessage(
role=role, role=role,
reasoning=reasoning, reasoning=reasoning,
content=content, content=content,
tool_calls=tool_call_items, tool_calls=[
ToolCall(
function=tc,
type="function",
)
for tc in tool_calls
],
) )
else: else:
...@@ -1784,10 +1701,12 @@ class OpenAIServingChat(OpenAIServing): ...@@ -1784,10 +1701,12 @@ class OpenAIServingChat(OpenAIServing):
elif choice.message.tool_calls: elif choice.message.tool_calls:
# For tool calls, log the function name and arguments # For tool calls, log the function name and arguments
tool_call_descriptions = [] tool_call_descriptions = []
for tc in choice.message.tool_calls: # type: ignore for tc in choice.message.tool_calls:
function_call: FunctionCall = tc.function # type: ignore if hasattr(tc.function, "name") and hasattr(
tc.function, "arguments"
):
tool_call_descriptions.append( tool_call_descriptions.append(
f"{function_call.name}({function_call.arguments})" f"{tc.function.name}({tc.function.arguments})"
) )
tool_calls_str = ", ".join(tool_call_descriptions) tool_calls_str = ", ".join(tool_call_descriptions)
output_text = f"[tool_calls: {tool_calls_str}]" output_text = f"[tool_calls: {tool_calls_str}]"
...@@ -1976,7 +1895,7 @@ class OpenAIServingChat(OpenAIServing): ...@@ -1976,7 +1895,7 @@ class OpenAIServingChat(OpenAIServing):
# because of issues with pydantic we need to potentially # because of issues with pydantic we need to potentially
# re-serialize the tool_calls field of the request # re-serialize the tool_calls field of the request
# for more info: see comment in `maybe_serialize_tool_calls` # for more info: see comment in `maybe_serialize_tool_calls`
maybe_serialize_tool_calls(request) # type: ignore[arg-type] maybe_serialize_tool_calls(request)
# Add system message. # Add system message.
# NOTE: In Chat Completion API, browsing is enabled by default # NOTE: In Chat Completion API, browsing is enabled by default
...@@ -1994,7 +1913,7 @@ class OpenAIServingChat(OpenAIServing): ...@@ -1994,7 +1913,7 @@ class OpenAIServingChat(OpenAIServing):
# Add developer message. # Add developer message.
if request.tools: if request.tools:
dev_msg = get_developer_message( dev_msg = get_developer_message(
tools=request.tools if should_include_tools else None # type: ignore[arg-type] tools=request.tools if should_include_tools else None
) )
messages.append(dev_msg) messages.append(dev_msg)
......
...@@ -85,12 +85,6 @@ class FrontendArgs: ...@@ -85,12 +85,6 @@ class FrontendArgs:
"""Log level for uvicorn.""" """Log level for uvicorn."""
disable_uvicorn_access_log: bool = False disable_uvicorn_access_log: bool = False
"""Disable uvicorn access log.""" """Disable uvicorn access log."""
disable_access_log_for_endpoints: str | None = None
"""Comma-separated list of endpoint paths to exclude from uvicorn access
logs. This is useful to reduce log noise from high-frequency endpoints
like health checks. Example: "/health,/metrics,/ping".
When set, access logs for requests to these paths will be suppressed
while keeping logs for other endpoints."""
allow_credentials: bool = False allow_credentials: bool = False
"""Allow credentials.""" """Allow credentials."""
allowed_origins: list[str] = field(default_factory=lambda: ["*"]) allowed_origins: list[str] = field(default_factory=lambda: ["*"])
...@@ -250,11 +244,6 @@ class FrontendArgs: ...@@ -250,11 +244,6 @@ class FrontendArgs:
del frontend_kwargs["middleware"]["nargs"] del frontend_kwargs["middleware"]["nargs"]
frontend_kwargs["middleware"]["default"] = [] frontend_kwargs["middleware"]["default"] = []
# Special case: disable_access_log_for_endpoints is a single
# comma-separated string, not a list
if "nargs" in frontend_kwargs["disable_access_log_for_endpoints"]:
del frontend_kwargs["disable_access_log_for_endpoints"]["nargs"]
# Special case: Tool call parser shows built-in options. # Special case: Tool call parser shows built-in options.
valid_tool_parsers = list(ToolParserManager.list_registered()) valid_tool_parsers = list(ToolParserManager.list_registered())
parsers_str = ",".join(valid_tool_parsers) parsers_str = ",".join(valid_tool_parsers)
......
...@@ -163,12 +163,25 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -163,12 +163,25 @@ class OpenAIServingCompletion(OpenAIServing):
generators: list[AsyncGenerator[RequestOutput, None]] = [] generators: list[AsyncGenerator[RequestOutput, None]] = []
try: try:
for i, engine_prompt in enumerate(engine_prompts): for i, engine_prompt in enumerate(engine_prompts):
prompt_text, _, _ = get_prompt_components(engine_prompt) prompt_text, prompt_token_ids, prompt_embeds = (
self._get_prompt_components(engine_prompt)
)
input_length = None
if prompt_token_ids is not None:
input_length = len(prompt_token_ids)
elif prompt_embeds is not None:
input_length = len(prompt_embeds)
else:
raise NotImplementedError
if self.default_sampling_params is None:
self.default_sampling_params = {}
max_tokens = get_max_tokens( max_tokens = get_max_tokens(
max_model_len=self.max_model_len, max_model_len=self.max_model_len,
request=request, request=request,
prompt=engine_prompt, input_length=input_length,
default_sampling_params=self.default_sampling_params, default_sampling_params=self.default_sampling_params,
) )
......
...@@ -218,10 +218,6 @@ def get_logits_processors( ...@@ -218,10 +218,6 @@ def get_logits_processors(
class FunctionCall(OpenAIBaseModel): class FunctionCall(OpenAIBaseModel):
# Internal field to preserve native tool call ID from tool parser.
# Excluded from serialization to maintain OpenAI API compatibility
# (function object should only contain 'name' and 'arguments').
id: str | None = Field(default=None, exclude=True)
name: str name: str
arguments: str arguments: str
......
...@@ -64,12 +64,13 @@ from vllm.entrypoints.openai.translations.protocol import ( ...@@ -64,12 +64,13 @@ from vllm.entrypoints.openai.translations.protocol import (
from vllm.entrypoints.pooling.classify.protocol import ( from vllm.entrypoints.pooling.classify.protocol import (
ClassificationChatRequest, ClassificationChatRequest,
ClassificationCompletionRequest, ClassificationCompletionRequest,
ClassificationRequest,
ClassificationResponse, ClassificationResponse,
) )
from vllm.entrypoints.pooling.embed.protocol import ( from vllm.entrypoints.pooling.embed.protocol import (
EmbeddingBytesResponse,
EmbeddingChatRequest, EmbeddingChatRequest,
EmbeddingCompletionRequest, EmbeddingCompletionRequest,
EmbeddingRequest,
EmbeddingResponse, EmbeddingResponse,
) )
from vllm.entrypoints.pooling.pooling.protocol import ( from vllm.entrypoints.pooling.pooling.protocol import (
...@@ -94,14 +95,11 @@ from vllm.entrypoints.serve.tokenize.protocol import ( ...@@ -94,14 +95,11 @@ from vllm.entrypoints.serve.tokenize.protocol import (
TokenizeCompletionRequest, TokenizeCompletionRequest,
TokenizeResponse, TokenizeResponse,
) )
from vllm.entrypoints.utils import ( from vllm.entrypoints.utils import _validate_truncation_size, sanitize_message
_validate_truncation_size,
get_max_tokens,
sanitize_message,
)
from vllm.exceptions import VLLMValidationError from vllm.exceptions import VLLMValidationError
from vllm.inputs.data import PromptType, TokensPrompt from vllm.inputs.data import PromptType, TokensPrompt
from vllm.inputs.parse import ( from vllm.inputs.parse import (
PromptComponents,
get_prompt_components, get_prompt_components,
is_explicit_encoder_decoder_prompt, is_explicit_encoder_decoder_prompt,
) )
...@@ -172,7 +170,6 @@ AnyResponse: TypeAlias = ( ...@@ -172,7 +170,6 @@ AnyResponse: TypeAlias = (
CompletionResponse CompletionResponse
| ChatCompletionResponse | ChatCompletionResponse
| EmbeddingResponse | EmbeddingResponse
| EmbeddingBytesResponse
| TranscriptionResponse | TranscriptionResponse
| TokenizeResponse | TokenizeResponse
| PoolingResponse | PoolingResponse
...@@ -186,21 +183,51 @@ RequestT = TypeVar("RequestT", bound=AnyRequest) ...@@ -186,21 +183,51 @@ RequestT = TypeVar("RequestT", bound=AnyRequest)
@dataclass(kw_only=True) @dataclass(kw_only=True)
class ServeContext(Generic[RequestT]): class RequestProcessingMixin:
"""
Mixin for request processing,
handling prompt preparation and engine input.
"""
engine_prompts: list[TokensPrompt] | None = field(default_factory=list)
@dataclass(kw_only=True)
class ResponseGenerationMixin:
"""
Mixin for response generation,
managing result generators and final batch results.
"""
result_generator: (
AsyncGenerator[tuple[int, RequestOutput | PoolingRequestOutput], None] | None
) = None
final_res_batch: list[RequestOutput | PoolingRequestOutput] = field(
default_factory=list
)
model_config = ConfigDict(arbitrary_types_allowed=True)
@dataclass(kw_only=True)
class ServeContext(RequestProcessingMixin, ResponseGenerationMixin, Generic[RequestT]):
request: RequestT request: RequestT
raw_request: Request | None = None raw_request: Request | None = None
model_name: str model_name: str
request_id: str request_id: str
created_time: int = field(default_factory=lambda: int(time.time())) created_time: int = field(default_factory=lambda: int(time.time()))
lora_request: LoRARequest | None = None lora_request: LoRARequest | None = None
engine_prompts: list[TokensPrompt] | None = None
result_generator: AsyncGenerator[tuple[int, PoolingRequestOutput], None] | None = (
None
)
final_res_batch: list[PoolingRequestOutput] = field(default_factory=list)
model_config = ConfigDict(arbitrary_types_allowed=True) @dataclass(kw_only=True)
class ClassificationServeContext(ServeContext[ClassificationRequest]):
pass
@dataclass(kw_only=True)
class EmbeddingServeContext(ServeContext[EmbeddingRequest]):
chat_template: str | None = None
chat_template_content_format: ChatTemplateContentFormatOption
class OpenAIServing: class OpenAIServing:
...@@ -578,7 +605,10 @@ class OpenAIServing: ...@@ -578,7 +605,10 @@ class OpenAIServing:
self, self,
ctx: ServeContext, ctx: ServeContext,
) -> AnyResponse | ErrorResponse: ) -> AnyResponse | ErrorResponse:
async for response in self._pipeline(ctx): generation: AsyncGenerator[AnyResponse | ErrorResponse, None]
generation = self._pipeline(ctx)
async for response in generation:
return response return response
return self.create_error_response("No response yielded from pipeline") return self.create_error_response("No response yielded from pipeline")
...@@ -637,7 +667,9 @@ class OpenAIServing: ...@@ -637,7 +667,9 @@ class OpenAIServing:
ctx: ServeContext, ctx: ServeContext,
) -> ErrorResponse | None: ) -> ErrorResponse | None:
"""Schedule the request and get the result generator.""" """Schedule the request and get the result generator."""
generators: list[AsyncGenerator[PoolingRequestOutput, None]] = [] generators: list[
AsyncGenerator[RequestOutput | PoolingRequestOutput, None]
] = []
try: try:
trace_headers = ( trace_headers = (
...@@ -691,7 +723,7 @@ class OpenAIServing: ...@@ -691,7 +723,7 @@ class OpenAIServing:
return self.create_error_response("Engine prompts not available") return self.create_error_response("Engine prompts not available")
num_prompts = len(ctx.engine_prompts) num_prompts = len(ctx.engine_prompts)
final_res_batch: list[PoolingRequestOutput | None] final_res_batch: list[RequestOutput | PoolingRequestOutput | None]
final_res_batch = [None] * num_prompts final_res_batch = [None] * num_prompts
if ctx.result_generator is None: if ctx.result_generator is None:
...@@ -979,7 +1011,7 @@ class OpenAIServing: ...@@ -979,7 +1011,7 @@ class OpenAIServing:
def _validate_input( def _validate_input(
self, self,
request: object, request: AnyRequest,
input_ids: list[int], input_ids: list[int],
input_text: str, input_text: str,
) -> TokensPrompt: ) -> TokensPrompt:
...@@ -1290,7 +1322,7 @@ class OpenAIServing: ...@@ -1290,7 +1322,7 @@ class OpenAIServing:
priority: int = 0, priority: int = 0,
**kwargs, **kwargs,
): ):
prompt_text, _, _ = get_prompt_components(engine_prompt) prompt_text, _, _ = self._get_prompt_components(engine_prompt)
orig_priority = priority orig_priority = priority
sub_request = 0 sub_request = 0
...@@ -1341,12 +1373,10 @@ class OpenAIServing: ...@@ -1341,12 +1373,10 @@ class OpenAIServing:
# yield context # yield context
# Create inputs for the next turn. # Create inputs for the next turn.
# Render the next prompt token ids and update sampling_params. # Render the next prompt token ids.
if isinstance(context, (HarmonyContext, StreamingHarmonyContext)): if isinstance(context, (HarmonyContext, StreamingHarmonyContext)):
token_ids = context.render_for_completion() prompt_token_ids = context.render_for_completion()
engine_prompt = TokensPrompt(prompt_token_ids=token_ids) engine_prompt = TokensPrompt(prompt_token_ids=prompt_token_ids)
sampling_params.max_tokens = self.max_model_len - len(token_ids)
elif isinstance(context, ParsableContext): elif isinstance(context, ParsableContext):
engine_prompts = await self._render_next_turn( engine_prompts = await self._render_next_turn(
context.request, context.request,
...@@ -1358,19 +1388,19 @@ class OpenAIServing: ...@@ -1358,19 +1388,19 @@ class OpenAIServing:
context.chat_template_content_format, context.chat_template_content_format,
) )
engine_prompt = engine_prompts[0] engine_prompt = engine_prompts[0]
prompt_text, _, _ = get_prompt_components(engine_prompt) prompt_text, _, _ = self._get_prompt_components(engine_prompt)
sampling_params.max_tokens = get_max_tokens( # Update the sampling params.
self.max_model_len, sampling_params.max_tokens = self.max_model_len - len(
context.request, engine_prompt["prompt_token_ids"]
engine_prompt,
self.default_sampling_params, # type: ignore
) )
# OPTIMIZATION # OPTIMIZATION
priority = orig_priority - 1 priority = orig_priority - 1
sub_request += 1 sub_request += 1
def _get_prompt_components(self, prompt: PromptType) -> PromptComponents:
return get_prompt_components(prompt)
def _log_inputs( def _log_inputs(
self, self,
request_id: str, request_id: str,
...@@ -1381,7 +1411,7 @@ class OpenAIServing: ...@@ -1381,7 +1411,7 @@ class OpenAIServing:
if self.request_logger is None: if self.request_logger is None:
return return
prompt, prompt_token_ids, prompt_embeds = get_prompt_components(inputs) prompt, prompt_token_ids, prompt_embeds = self._get_prompt_components(inputs)
self.request_logger.log_inputs( self.request_logger.log_inputs(
request_id, request_id,
...@@ -1495,7 +1525,6 @@ class OpenAIServing: ...@@ -1495,7 +1525,6 @@ class OpenAIServing:
# extract_tool_calls() returns a list of tool calls. # extract_tool_calls() returns a list of tool calls.
function_calls.extend( function_calls.extend(
FunctionCall( FunctionCall(
id=tool_call.id,
name=tool_call.function.name, name=tool_call.function.name,
arguments=tool_call.function.arguments, arguments=tool_call.function.arguments,
) )
......
...@@ -63,7 +63,6 @@ from vllm.engine.protocol import EngineClient ...@@ -63,7 +63,6 @@ from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import ( from vllm.entrypoints.chat_utils import (
ChatCompletionMessageParam, ChatCompletionMessageParam,
ChatTemplateContentFormatOption, ChatTemplateContentFormatOption,
make_tool_call_id,
) )
from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.mcp.tool_server import ToolServer from vllm.entrypoints.mcp.tool_server import ToolServer
...@@ -116,7 +115,6 @@ from vllm.entrypoints.openai.responses.utils import ( ...@@ -116,7 +115,6 @@ from vllm.entrypoints.openai.responses.utils import (
extract_tool_types, extract_tool_types,
should_continue_final_message, should_continue_final_message,
) )
from vllm.entrypoints.utils import get_max_tokens
from vllm.exceptions import VLLMValidationError from vllm.exceptions import VLLMValidationError
from vllm.inputs.data import TokensPrompt from vllm.inputs.data import TokensPrompt
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -252,17 +250,6 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -252,17 +250,6 @@ class OpenAIServingResponses(OpenAIServing):
self.default_sampling_params["stop_token_ids"].extend( self.default_sampling_params["stop_token_ids"].extend(
get_stop_tokens_for_assistant_actions() get_stop_tokens_for_assistant_actions()
) )
# Handle tool call ID type for Kimi K2 (supporting test mocking via overrides)
hf_overrides = getattr(self.model_config, "hf_overrides", None)
if self.model_config.hf_text_config.model_type == "kimi_k2" or (
isinstance(hf_overrides, dict)
and hf_overrides.get("model_type") == "kimi_k2"
):
self.tool_call_id_type = "kimi_k2"
else:
self.tool_call_id_type = "random"
self.enable_auto_tools = enable_auto_tools self.enable_auto_tools = enable_auto_tools
# set up tool use # set up tool use
self.tool_parser = self._get_tool_parser( self.tool_parser = self._get_tool_parser(
...@@ -436,11 +423,8 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -436,11 +423,8 @@ class OpenAIServingResponses(OpenAIServing):
if maybe_error is not None: if maybe_error is not None:
return maybe_error return maybe_error
default_max_tokens = get_max_tokens( default_max_tokens = self.max_model_len - len(
self.max_model_len, engine_prompt["prompt_token_ids"]
request,
engine_prompt,
self.default_sampling_params,
) )
sampling_params = request.to_sampling_params( sampling_params = request.to_sampling_params(
...@@ -970,11 +954,8 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -970,11 +954,8 @@ class OpenAIServingResponses(OpenAIServing):
enable_auto_tools=self.enable_auto_tools, enable_auto_tools=self.enable_auto_tools,
tool_parser_cls=self.tool_parser, tool_parser_cls=self.tool_parser,
) )
if content or (self.use_harmony and tool_calls):
res_text_part = None
if content: if content:
res_text_part = ResponseOutputText( output_text = ResponseOutputText(
text=content, text=content,
annotations=[], # TODO annotations=[], # TODO
type="output_text", type="output_text",
...@@ -991,7 +972,7 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -991,7 +972,7 @@ class OpenAIServingResponses(OpenAIServing):
) )
message_item = ResponseOutputMessage( message_item = ResponseOutputMessage(
id=f"msg_{random_uuid()}", id=f"msg_{random_uuid()}",
content=[res_text_part] if res_text_part else [], content=[output_text],
role="assistant", role="assistant",
status="completed", status="completed",
type="message", type="message",
...@@ -1003,28 +984,17 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -1003,28 +984,17 @@ class OpenAIServingResponses(OpenAIServing):
if message_item: if message_item:
outputs.append(message_item) outputs.append(message_item)
if tool_calls: if tool_calls:
# We use a simple counter for history_tool_call_count because tool_call_items = [
# we don't track the history of tool calls in the Responses API yet.
# This means that the tool call index will start from 0 for each
# request.
tool_call_items = []
for history_tool_call_cnt, tool_call in enumerate(tool_calls):
tool_call_items.append(
ResponseFunctionToolCall( ResponseFunctionToolCall(
id=f"fc_{random_uuid()}", id=f"fc_{random_uuid()}",
call_id=tool_call.id call_id=f"call_{random_uuid()}",
if tool_call.id
else make_tool_call_id(
id_type=self.tool_call_id_type,
func_name=tool_call.name,
idx=history_tool_call_cnt,
),
type="function_call", type="function_call",
status="completed", status="completed",
name=tool_call.name, name=tool_call.name,
arguments=tool_call.arguments, arguments=tool_call.arguments,
) )
) for tool_call in tool_calls
]
outputs.extend(tool_call_items) outputs.extend(tool_call_items)
return outputs return outputs
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from http import HTTPStatus from http import HTTPStatus
from typing import Final, cast from typing import cast
import jinja2 import jinja2
import numpy as np import numpy as np
...@@ -11,8 +11,18 @@ from fastapi import Request ...@@ -11,8 +11,18 @@ from fastapi import Request
from vllm.engine.protocol import EngineClient from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.engine.protocol import ErrorResponse, UsageInfo from vllm.entrypoints.openai.chat_completion.protocol import (
from vllm.entrypoints.openai.engine.serving import OpenAIServing, ServeContext ChatCompletionRequest,
)
from vllm.entrypoints.openai.engine.protocol import (
ErrorResponse,
UsageInfo,
)
from vllm.entrypoints.openai.engine.serving import (
ClassificationServeContext,
OpenAIServing,
ServeContext,
)
from vllm.entrypoints.openai.models.serving import OpenAIServingModels from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.entrypoints.pooling.classify.protocol import ( from vllm.entrypoints.pooling.classify.protocol import (
ClassificationChatRequest, ClassificationChatRequest,
...@@ -29,68 +39,60 @@ from vllm.pooling_params import PoolingParams ...@@ -29,68 +39,60 @@ from vllm.pooling_params import PoolingParams
logger = init_logger(__name__) logger = init_logger(__name__)
ClassificationServeContext = ServeContext[ClassificationRequest] class ClassificationMixin(OpenAIServing):
chat_template: str | None
chat_template_content_format: ChatTemplateContentFormatOption
class ServingClassification(OpenAIServing): trust_request_chat_template: bool
request_id_prefix = "classify"
def __init__(
self,
engine_client: EngineClient,
models: OpenAIServingModels,
*,
request_logger: RequestLogger | None,
chat_template: str | None = None,
chat_template_content_format: ChatTemplateContentFormatOption = "auto",
trust_request_chat_template: bool = False,
log_error_stack: bool = False,
) -> None:
super().__init__(
engine_client=engine_client,
models=models,
request_logger=request_logger,
log_error_stack=log_error_stack,
)
self.chat_template = chat_template
self.chat_template_content_format: Final = chat_template_content_format
self.trust_request_chat_template = trust_request_chat_template
async def _preprocess( async def _preprocess(
self, self,
ctx: ClassificationServeContext, ctx: ServeContext,
) -> ErrorResponse | None: ) -> ErrorResponse | None:
""" """
Process classification inputs: tokenize text, resolve adapters, Process classification inputs: tokenize text, resolve adapters,
and prepare model-specific inputs. and prepare model-specific inputs.
""" """
ctx = cast(ClassificationServeContext, ctx)
try: try:
ctx.lora_request = self._maybe_get_adapters(ctx.request) request_obj = ctx.request
if isinstance(ctx.request, ClassificationChatRequest): if isinstance(request_obj, ClassificationChatRequest):
error_check_ret = self._validate_chat_template( chat_request = request_obj
request_chat_template=ctx.request.chat_template, messages = chat_request.messages
chat_template_kwargs=ctx.request.chat_template_kwargs, trust_request_chat_template = getattr(
trust_request_chat_template=self.trust_request_chat_template, self,
"trust_request_chat_template",
False,
)
ret = self._validate_chat_template(
request_chat_template=chat_request.chat_template,
chat_template_kwargs=chat_request.chat_template_kwargs,
trust_request_chat_template=trust_request_chat_template,
) )
if error_check_ret: if ret:
return error_check_ret return ret
_, engine_prompts = await self._preprocess_chat( _, engine_prompts = await self._preprocess_chat(
ctx.request, cast(ChatCompletionRequest, chat_request),
self.renderer, self.renderer,
ctx.request.messages, messages,
chat_template=ctx.request.chat_template or self.chat_template, chat_template=(
chat_template_content_format=self.chat_template_content_format, chat_request.chat_template
add_generation_prompt=ctx.request.add_generation_prompt, or getattr(self, "chat_template", None)
continue_final_message=ctx.request.continue_final_message, ),
add_special_tokens=ctx.request.add_special_tokens, chat_template_content_format=cast(
ChatTemplateContentFormatOption,
getattr(self, "chat_template_content_format", "auto"),
),
add_generation_prompt=chat_request.add_generation_prompt,
continue_final_message=chat_request.continue_final_message,
add_special_tokens=chat_request.add_special_tokens,
) )
ctx.engine_prompts = engine_prompts ctx.engine_prompts = engine_prompts
elif isinstance(ctx.request, ClassificationCompletionRequest): elif isinstance(request_obj, ClassificationCompletionRequest):
input_data = ctx.request.input completion_request = request_obj
input_data = completion_request.input
if input_data in (None, ""): if input_data in (None, ""):
return self.create_error_response( return self.create_error_response(
"Input or messages must be provided", "Input or messages must be provided",
...@@ -104,10 +106,13 @@ class ServingClassification(OpenAIServing): ...@@ -104,10 +106,13 @@ class ServingClassification(OpenAIServing):
prompt_input = cast(str | list[str], input_data) prompt_input = cast(str | list[str], input_data)
ctx.engine_prompts = await renderer.render_prompt( ctx.engine_prompts = await renderer.render_prompt(
prompt_or_prompts=prompt_input, prompt_or_prompts=prompt_input,
config=self._build_render_config(ctx.request), config=self._build_render_config(completion_request),
) )
else: else:
return self.create_error_response("Invalid classification request type") return self.create_error_response(
"Invalid classification request type",
status_code=HTTPStatus.BAD_REQUEST,
)
return None return None
...@@ -117,14 +122,13 @@ class ServingClassification(OpenAIServing): ...@@ -117,14 +122,13 @@ class ServingClassification(OpenAIServing):
def _build_response( def _build_response(
self, self,
ctx: ClassificationServeContext, ctx: ServeContext,
) -> ClassificationResponse | ErrorResponse: ) -> ClassificationResponse | ErrorResponse:
""" """
Convert model outputs to a formatted classification response Convert model outputs to a formatted classification response
with probabilities and labels. with probabilities and labels.
""" """
id2label = getattr(self.model_config.hf_config, "id2label", {}) ctx = cast(ClassificationServeContext, ctx)
items: list[ClassificationData] = [] items: list[ClassificationData] = []
num_prompt_tokens = 0 num_prompt_tokens = 0
...@@ -135,7 +139,9 @@ class ServingClassification(OpenAIServing): ...@@ -135,7 +139,9 @@ class ServingClassification(OpenAIServing):
probs = classify_res.probs probs = classify_res.probs
predicted_index = int(np.argmax(probs)) predicted_index = int(np.argmax(probs))
label = id2label.get(predicted_index) label = getattr(self.model_config.hf_config, "id2label", {}).get(
predicted_index
)
item = ClassificationData( item = ClassificationData(
index=idx, index=idx,
...@@ -168,6 +174,32 @@ class ServingClassification(OpenAIServing): ...@@ -168,6 +174,32 @@ class ServingClassification(OpenAIServing):
add_special_tokens=request.add_special_tokens, add_special_tokens=request.add_special_tokens,
) )
class ServingClassification(ClassificationMixin):
request_id_prefix = "classify"
def __init__(
self,
engine_client: EngineClient,
models: OpenAIServingModels,
*,
request_logger: RequestLogger | None,
chat_template: str | None = None,
chat_template_content_format: ChatTemplateContentFormatOption = "auto",
trust_request_chat_template: bool = False,
log_error_stack: bool = False,
) -> None:
super().__init__(
engine_client=engine_client,
models=models,
request_logger=request_logger,
log_error_stack=log_error_stack,
)
self.chat_template = chat_template
self.chat_template_content_format = chat_template_content_format
self.trust_request_chat_template = trust_request_chat_template
async def create_classify( async def create_classify(
self, self,
request: ClassificationRequest, request: ClassificationRequest,
...@@ -183,11 +215,11 @@ class ServingClassification(OpenAIServing): ...@@ -183,11 +215,11 @@ class ServingClassification(OpenAIServing):
request_id=request_id, request_id=request_id,
) )
return await self.handle(ctx) # type: ignore[return-value] return await super().handle(ctx) # type: ignore
def _create_pooling_params( def _create_pooling_params(
self, self,
ctx: ClassificationServeContext, ctx: ServeContext[ClassificationRequest],
) -> PoolingParams | ErrorResponse: ) -> PoolingParams | ErrorResponse:
pooling_params = super()._create_pooling_params(ctx) pooling_params = super()._create_pooling_params(ctx)
if isinstance(pooling_params, ErrorResponse): if isinstance(pooling_params, ErrorResponse):
......
...@@ -6,13 +6,21 @@ from typing import Any, Final, cast ...@@ -6,13 +6,21 @@ from typing import Any, Final, cast
import torch import torch
from fastapi import Request from fastapi import Request
from typing_extensions import assert_never from fastapi.responses import Response
from typing_extensions import assert_never, override
from vllm.engine.protocol import EngineClient from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.engine.protocol import ErrorResponse, UsageInfo from vllm.entrypoints.openai.engine.protocol import (
from vllm.entrypoints.openai.engine.serving import OpenAIServing, ServeContext ErrorResponse,
UsageInfo,
)
from vllm.entrypoints.openai.engine.serving import (
EmbeddingServeContext,
OpenAIServing,
ServeContext,
)
from vllm.entrypoints.openai.models.serving import OpenAIServingModels from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.entrypoints.pooling.embed.protocol import ( from vllm.entrypoints.pooling.embed.protocol import (
EmbeddingBytesResponse, EmbeddingBytesResponse,
...@@ -25,11 +33,19 @@ from vllm.entrypoints.pooling.embed.protocol import ( ...@@ -25,11 +33,19 @@ from vllm.entrypoints.pooling.embed.protocol import (
from vllm.entrypoints.renderer import RenderConfig from vllm.entrypoints.renderer import RenderConfig
from vllm.inputs.data import TokensPrompt from vllm.inputs.data import TokensPrompt
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.outputs import PoolingOutput, PoolingRequestOutput from vllm.outputs import (
EmbeddingRequestOutput,
PoolingOutput,
PoolingRequestOutput,
RequestOutput,
)
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.utils.async_utils import merge_async_iterators from vllm.utils.async_utils import merge_async_iterators
from vllm.utils.collection_utils import chunk_list from vllm.utils.collection_utils import chunk_list
from vllm.utils.serial_utils import ( from vllm.utils.serial_utils import (
EmbedDType,
EncodingFormat,
Endianness,
encode_pooling_bytes, encode_pooling_bytes,
encode_pooling_output, encode_pooling_output,
) )
...@@ -37,33 +53,9 @@ from vllm.utils.serial_utils import ( ...@@ -37,33 +53,9 @@ from vllm.utils.serial_utils import (
logger = init_logger(__name__) logger = init_logger(__name__)
EmbeddingServeContext = ServeContext[EmbeddingRequest] class EmbeddingMixin(OpenAIServing):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
class OpenAIServingEmbedding(OpenAIServing):
request_id_prefix = "embd"
def __init__(
self,
engine_client: EngineClient,
models: OpenAIServingModels,
*,
request_logger: RequestLogger | None,
chat_template: str | None,
chat_template_content_format: ChatTemplateContentFormatOption,
trust_request_chat_template: bool = False,
log_error_stack: bool = False,
) -> None:
super().__init__(
engine_client=engine_client,
models=models,
request_logger=request_logger,
log_error_stack=log_error_stack,
)
self.chat_template = chat_template
self.chat_template_content_format: Final = chat_template_content_format
self.trust_request_chat_template = trust_request_chat_template
pooler_config = self.model_config.pooler_config pooler_config = self.model_config.pooler_config
...@@ -77,41 +69,32 @@ class OpenAIServingEmbedding(OpenAIServing): ...@@ -77,41 +69,32 @@ class OpenAIServingEmbedding(OpenAIServing):
else None else None
) )
@override
async def _preprocess( async def _preprocess(
self, self,
ctx: EmbeddingServeContext, ctx: ServeContext,
) -> ErrorResponse | None: ) -> ErrorResponse | None:
ctx = cast(EmbeddingServeContext, ctx)
try: try:
ctx.lora_request = self._maybe_get_adapters(ctx.request) ctx.lora_request = self._maybe_get_adapters(ctx.request)
if isinstance(ctx.request, EmbeddingChatRequest): if isinstance(ctx.request, EmbeddingChatRequest):
error_check_ret = self._validate_chat_template(
request_chat_template=ctx.request.chat_template,
chat_template_kwargs=ctx.request.chat_template_kwargs,
trust_request_chat_template=self.trust_request_chat_template,
)
if error_check_ret is not None:
return error_check_ret
_, ctx.engine_prompts = await self._preprocess_chat( _, ctx.engine_prompts = await self._preprocess_chat(
ctx.request, ctx.request,
self.renderer, self.renderer,
ctx.request.messages, ctx.request.messages,
chat_template=ctx.request.chat_template or self.chat_template, chat_template=ctx.request.chat_template or ctx.chat_template,
chat_template_content_format=self.chat_template_content_format, chat_template_content_format=ctx.chat_template_content_format,
add_generation_prompt=ctx.request.add_generation_prompt, add_generation_prompt=ctx.request.add_generation_prompt,
continue_final_message=ctx.request.continue_final_message, continue_final_message=ctx.request.continue_final_message,
add_special_tokens=ctx.request.add_special_tokens, add_special_tokens=ctx.request.add_special_tokens,
) )
elif isinstance(ctx.request, EmbeddingCompletionRequest): else:
renderer = self._get_completion_renderer() renderer = self._get_completion_renderer()
ctx.engine_prompts = await renderer.render_prompt( ctx.engine_prompts = await renderer.render_prompt(
prompt_or_prompts=ctx.request.input, prompt_or_prompts=ctx.request.input,
config=self._build_render_config(ctx.request), config=self._build_render_config(ctx.request),
) )
else:
return self.create_error_response("Invalid classification request type")
return None return None
except (ValueError, TypeError) as e: except (ValueError, TypeError) as e:
logger.exception("Error in preprocessing prompt inputs") logger.exception("Error in preprocessing prompt inputs")
...@@ -130,15 +113,16 @@ class OpenAIServingEmbedding(OpenAIServing): ...@@ -130,15 +113,16 @@ class OpenAIServingEmbedding(OpenAIServing):
add_special_tokens=request.add_special_tokens, add_special_tokens=request.add_special_tokens,
) )
@override
def _build_response( def _build_response(
self, self,
ctx: EmbeddingServeContext, ctx: ServeContext,
) -> EmbeddingResponse | EmbeddingBytesResponse | ErrorResponse: ) -> EmbeddingResponse | Response | ErrorResponse:
final_res_batch_checked = ctx.final_res_batch final_res_batch_checked = cast(list[PoolingRequestOutput], ctx.final_res_batch)
encoding_format = ctx.request.encoding_format encoding_format: EncodingFormat = ctx.request.encoding_format
embed_dtype = ctx.request.embed_dtype embed_dtype: EmbedDType = ctx.request.embed_dtype
endianness = ctx.request.endianness endianness: Endianness = ctx.request.endianness
def encode_float_base64(): def encode_float_base64():
items: list[EmbeddingResponseData] = [] items: list[EmbeddingResponseData] = []
...@@ -219,8 +203,8 @@ class OpenAIServingEmbedding(OpenAIServing): ...@@ -219,8 +203,8 @@ class OpenAIServingEmbedding(OpenAIServing):
self, self,
ctx: EmbeddingServeContext, ctx: EmbeddingServeContext,
token_ids: list[int], token_ids: list[int],
pooling_params: PoolingParams, pooling_params,
trace_headers: Mapping[str, str] | None, trace_headers,
prompt_idx: int, prompt_idx: int,
) -> list[AsyncGenerator[PoolingRequestOutput, None]]: ) -> list[AsyncGenerator[PoolingRequestOutput, None]]:
"""Process a single prompt using chunked processing.""" """Process a single prompt using chunked processing."""
...@@ -262,7 +246,7 @@ class OpenAIServingEmbedding(OpenAIServing): ...@@ -262,7 +246,7 @@ class OpenAIServingEmbedding(OpenAIServing):
def _validate_input( def _validate_input(
self, self,
request: object, request,
input_ids: list[int], input_ids: list[int],
input_text: str, input_text: str,
) -> TokensPrompt: ) -> TokensPrompt:
...@@ -342,7 +326,7 @@ class OpenAIServingEmbedding(OpenAIServing): ...@@ -342,7 +326,7 @@ class OpenAIServingEmbedding(OpenAIServing):
pooling_params: PoolingParams, pooling_params: PoolingParams,
trace_headers: Mapping[str, str] | None, trace_headers: Mapping[str, str] | None,
prompt_index: int, prompt_index: int,
) -> AsyncGenerator[PoolingRequestOutput, None]: ) -> AsyncGenerator[RequestOutput | PoolingRequestOutput, None]:
"""Create a generator for a single prompt using standard processing.""" """Create a generator for a single prompt using standard processing."""
request_id_item = f"{ctx.request_id}-{prompt_index}" request_id_item = f"{ctx.request_id}-{prompt_index}"
...@@ -363,6 +347,7 @@ class OpenAIServingEmbedding(OpenAIServing): ...@@ -363,6 +347,7 @@ class OpenAIServingEmbedding(OpenAIServing):
priority=getattr(ctx.request, "priority", 0), priority=getattr(ctx.request, "priority", 0),
) )
@override
async def _prepare_generators( async def _prepare_generators(
self, self,
ctx: ServeContext, ctx: ServeContext,
...@@ -378,7 +363,9 @@ class OpenAIServingEmbedding(OpenAIServing): ...@@ -378,7 +363,9 @@ class OpenAIServingEmbedding(OpenAIServing):
return await super()._prepare_generators(ctx) return await super()._prepare_generators(ctx)
# Custom logic for chunked processing # Custom logic for chunked processing
generators: list[AsyncGenerator[PoolingRequestOutput, None]] = [] generators: list[
AsyncGenerator[RequestOutput | PoolingRequestOutput, None]
] = []
try: try:
trace_headers = ( trace_headers = (
...@@ -432,9 +419,10 @@ class OpenAIServingEmbedding(OpenAIServing): ...@@ -432,9 +419,10 @@ class OpenAIServingEmbedding(OpenAIServing):
# TODO: Use a vllm-specific Validation Error # TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e)) return self.create_error_response(str(e))
@override
async def _collect_batch( async def _collect_batch(
self, self,
ctx: EmbeddingServeContext, ctx: ServeContext,
) -> ErrorResponse | None: ) -> ErrorResponse | None:
"""Collect and aggregate batch results """Collect and aggregate batch results
with support for chunked processing. with support for chunked processing.
...@@ -443,6 +431,7 @@ class OpenAIServingEmbedding(OpenAIServing): ...@@ -443,6 +431,7 @@ class OpenAIServingEmbedding(OpenAIServing):
minimize memory usage. minimize memory usage.
For regular requests, collects results normally. For regular requests, collects results normally.
""" """
ctx = cast(EmbeddingServeContext, ctx)
try: try:
if ctx.engine_prompts is None: if ctx.engine_prompts is None:
return self.create_error_response("Engine prompts not available") return self.create_error_response("Engine prompts not available")
...@@ -538,10 +527,12 @@ class OpenAIServingEmbedding(OpenAIServing): ...@@ -538,10 +527,12 @@ class OpenAIServingEmbedding(OpenAIServing):
except (ValueError, IndexError): except (ValueError, IndexError):
prompt_idx = result_idx # Fallback to result_idx prompt_idx = result_idx # Fallback to result_idx
short_prompts_results[prompt_idx] = result short_prompts_results[prompt_idx] = cast(
PoolingRequestOutput, result
)
# Finalize aggregated results # Finalize aggregated results
final_res_batch: list[PoolingRequestOutput] = [] final_res_batch: list[PoolingRequestOutput | EmbeddingRequestOutput] = []
num_prompts = len(ctx.engine_prompts) num_prompts = len(ctx.engine_prompts)
for prompt_idx in range(num_prompts): for prompt_idx in range(num_prompts):
...@@ -589,19 +580,49 @@ class OpenAIServingEmbedding(OpenAIServing): ...@@ -589,19 +580,49 @@ class OpenAIServingEmbedding(OpenAIServing):
f"Failed to aggregate chunks for prompt {prompt_idx}" f"Failed to aggregate chunks for prompt {prompt_idx}"
) )
elif prompt_idx in short_prompts_results: elif prompt_idx in short_prompts_results:
final_res_batch.append(short_prompts_results[prompt_idx]) final_res_batch.append(
cast(PoolingRequestOutput, short_prompts_results[prompt_idx])
)
else: else:
return self.create_error_response( return self.create_error_response(
f"Result not found for prompt {prompt_idx}" f"Result not found for prompt {prompt_idx}"
) )
ctx.final_res_batch = final_res_batch ctx.final_res_batch = cast(
list[RequestOutput | PoolingRequestOutput], final_res_batch
)
return None return None
except Exception as e: except Exception as e:
return self.create_error_response(str(e)) return self.create_error_response(str(e))
class OpenAIServingEmbedding(EmbeddingMixin):
request_id_prefix = "embd"
def __init__(
self,
engine_client: EngineClient,
models: OpenAIServingModels,
*,
request_logger: RequestLogger | None,
chat_template: str | None,
chat_template_content_format: ChatTemplateContentFormatOption,
trust_request_chat_template: bool = False,
log_error_stack: bool = False,
) -> None:
super().__init__(
engine_client=engine_client,
models=models,
request_logger=request_logger,
log_error_stack=log_error_stack,
)
self.chat_template = chat_template
self.chat_template_content_format: Final = chat_template_content_format
self.trust_request_chat_template = trust_request_chat_template
async def create_embedding( async def create_embedding(
self, self,
request: EmbeddingRequest, request: EmbeddingRequest,
...@@ -624,13 +645,16 @@ class OpenAIServingEmbedding(OpenAIServing): ...@@ -624,13 +645,16 @@ class OpenAIServingEmbedding(OpenAIServing):
raw_request=raw_request, raw_request=raw_request,
model_name=model_name, model_name=model_name,
request_id=request_id, request_id=request_id,
chat_template=self.chat_template,
chat_template_content_format=self.chat_template_content_format,
) )
return await self.handle(ctx) # type: ignore[return-value] return await super().handle(ctx) # type: ignore
@override
def _create_pooling_params( def _create_pooling_params(
self, self,
ctx: EmbeddingServeContext, ctx: ServeContext[EmbeddingRequest],
) -> PoolingParams | ErrorResponse: ) -> PoolingParams | ErrorResponse:
pooling_params = super()._create_pooling_params(ctx) pooling_params = super()._create_pooling_params(ctx)
if isinstance(pooling_params, ErrorResponse): if isinstance(pooling_params, ErrorResponse):
...@@ -642,3 +666,17 @@ class OpenAIServingEmbedding(OpenAIServing): ...@@ -642,3 +666,17 @@ class OpenAIServingEmbedding(OpenAIServing):
return self.create_error_response(str(e)) return self.create_error_response(str(e))
return pooling_params return pooling_params
async def _preprocess(
self,
ctx: ServeContext,
) -> ErrorResponse | None:
if isinstance(ctx.request, EmbeddingChatRequest):
error_check_ret = self._validate_chat_template(
request_chat_template=ctx.request.chat_template,
chat_template_kwargs=ctx.request.chat_template_kwargs,
trust_request_chat_template=self.trust_request_chat_template,
)
if error_check_ret is not None:
return error_check_ret
return await super()._preprocess(ctx)
\ No newline at end of file
...@@ -17,10 +17,8 @@ from starlette.background import BackgroundTask, BackgroundTasks ...@@ -17,10 +17,8 @@ from starlette.background import BackgroundTask, BackgroundTasks
from vllm import envs from vllm import envs
from vllm.engine.arg_utils import EngineArgs from vllm.engine.arg_utils import EngineArgs
from vllm.inputs import EmbedsPrompt, TokensPrompt
from vllm.logger import current_formatter_type, init_logger from vllm.logger import current_formatter_type, init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import length_from_prompt_token_ids_or_embeds
from vllm.utils.argparse_utils import FlexibleArgumentParser from vllm.utils.argparse_utils import FlexibleArgumentParser
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -34,15 +32,11 @@ if TYPE_CHECKING: ...@@ -34,15 +32,11 @@ if TYPE_CHECKING:
StreamOptions, StreamOptions,
) )
from vllm.entrypoints.openai.models.protocol import LoRAModulePath from vllm.entrypoints.openai.models.protocol import LoRAModulePath
from vllm.entrypoints.openai.responses.protocol import (
ResponsesRequest,
)
else: else:
ChatCompletionRequest = object ChatCompletionRequest = object
CompletionRequest = object CompletionRequest = object
StreamOptions = object StreamOptions = object
LoRAModulePath = object LoRAModulePath = object
ResponsesRequest = object
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -217,26 +211,11 @@ def _validate_truncation_size( ...@@ -217,26 +211,11 @@ def _validate_truncation_size(
def get_max_tokens( def get_max_tokens(
max_model_len: int, max_model_len: int,
request: "CompletionRequest | ChatCompletionRequest | ResponsesRequest", request: "ChatCompletionRequest | CompletionRequest",
prompt: TokensPrompt | EmbedsPrompt, input_length: int,
default_sampling_params: dict, default_sampling_params: dict,
) -> int: ) -> int:
# NOTE: Avoid isinstance() for better efficiency max_tokens = getattr(request, "max_completion_tokens", None) or request.max_tokens
max_tokens: int | None = None
if max_tokens is None:
# ChatCompletionRequest
max_tokens = getattr(request, "max_completion_tokens", None)
if max_tokens is None:
# ResponsesRequest
max_tokens = getattr(request, "max_output_tokens", None)
if max_tokens is None:
# CompletionRequest (also a fallback for ChatCompletionRequest)
max_tokens = getattr(request, "max_tokens", None)
input_length = length_from_prompt_token_ids_or_embeds(
prompt.get("prompt_token_ids"), # type: ignore[arg-type]
prompt.get("prompt_embeds"), # type: ignore[arg-type]
)
default_max_tokens = max_model_len - input_length default_max_tokens = max_model_len - input_length
max_output_tokens = current_platform.get_max_output_tokens(input_length) max_output_tokens = current_platform.get_max_output_tokens(input_length)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment