Commit c721b814 authored by zhuwenwen's avatar zhuwenwen
Browse files

sync v0.15.1

parent d53fe7e5
......@@ -280,10 +280,9 @@ class DynamicShapesConfig:
until this change picked up https://github.com/pytorch/pytorch/pull/169239.
"""
assume_32_bit_indexing: bool = False
assume_32_bit_indexing: bool = True
"""
whether all tensor sizes can use 32 bit indexing.
`True` requires PyTorch 2.10+
"""
def compute_hash(self) -> str:
......
......@@ -34,7 +34,6 @@ MTPModelTypes = Literal[
"mimo_mtp",
"glm4_moe_mtp",
"glm4_moe_lite_mtp",
"glm_ocr_mtp",
"ernie_mtp",
"exaone_moe_mtp",
"qwen3_next_mtp",
......@@ -223,17 +222,6 @@ class SpeculativeConfig:
}
)
if hf_config.architectures[0] == "GlmOcrForConditionalGeneration":
hf_config.model_type = "glm_ocr_mtp"
n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
hf_config.update(
{
"num_hidden_layers": 0,
"n_predict": n_predict,
"architectures": ["GlmOcrMTPModel"],
}
)
if hf_config.model_type == "ernie4_5_moe":
hf_config.model_type = "ernie_mtp"
if hf_config.model_type == "ernie_mtp":
......
......@@ -59,7 +59,7 @@ class NaiveAll2AllManager(All2AllManagerBase):
return buffer
def dispatch_router_logits(
def dispatch(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
......@@ -84,34 +84,6 @@ class NaiveAll2AllManager(All2AllManagerBase):
return hidden_states, router_logits
def dispatch(
self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
if extra_tensors is not None:
raise NotImplementedError(
"extra_tensors is not supported for NaiveAll2AllManager"
)
sp_size = self.tp_group.world_size if is_sequence_parallel else 1
dp_metadata = get_forward_context().dp_metadata
assert dp_metadata is not None
cu_tokens_across_sp_cpu = dp_metadata.cu_tokens_across_sp(sp_size)
hidden_states = self.naive_multicast(
hidden_states, cu_tokens_across_sp_cpu, is_sequence_parallel
)
topk_weights = self.naive_multicast(
topk_weights, cu_tokens_across_sp_cpu, is_sequence_parallel
)
topk_ids = self.naive_multicast(
topk_ids, cu_tokens_across_sp_cpu, is_sequence_parallel
)
return hidden_states, topk_weights, topk_ids
def combine(
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
) -> torch.Tensor:
......@@ -142,7 +114,7 @@ class AgRsAll2AllManager(All2AllManagerBase):
def __init__(self, cpu_group):
super().__init__(cpu_group)
def dispatch_router_logits(
def dispatch(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
......@@ -176,46 +148,6 @@ class AgRsAll2AllManager(All2AllManagerBase):
return (gathered_tensors[0], gathered_tensors[1], gathered_tensors[2:])
return gathered_tensors[0], gathered_tensors[1]
def dispatch(
self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
) -> (
tuple[torch.Tensor, torch.Tensor, torch.Tensor]
| tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
"""
Gather hidden_states and router_logits from all dp ranks.
"""
dp_metadata = get_forward_context().dp_metadata
assert dp_metadata is not None
sizes = dp_metadata.get_chunk_sizes_across_dp_rank()
assert sizes is not None
dist_group = get_ep_group() if is_sequence_parallel else get_dp_group()
assert sizes[dist_group.rank_in_group] == hidden_states.shape[0]
tensors_to_gather = [hidden_states, topk_weights, topk_ids]
if extra_tensors is not None:
tensors_to_gather.extend(extra_tensors)
gathered_tensors = dist_group.all_gatherv(
tensors_to_gather,
dim=0,
sizes=sizes,
)
hidden_states = gathered_tensors[0]
topk_weights = gathered_tensors[1]
topk_ids = gathered_tensors[2]
if extra_tensors is None:
return hidden_states, topk_weights, topk_ids
return hidden_states, topk_weights, topk_ids, gathered_tensors[3:]
def combine(
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
) -> torch.Tensor:
......@@ -284,7 +216,7 @@ class PPLXAll2AllManager(All2AllManagerBase):
pplx.AllToAll.internode if self.internode else pplx.AllToAll.intranode,
)
def dispatch_router_logits(
def dispatch(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
......@@ -293,19 +225,6 @@ class PPLXAll2AllManager(All2AllManagerBase):
) -> tuple[torch.Tensor, torch.Tensor]:
raise NotImplementedError
def dispatch(
self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
) -> (
tuple[torch.Tensor, torch.Tensor, torch.Tensor]
| tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
raise NotImplementedError
def combine(
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
) -> torch.Tensor:
......@@ -345,7 +264,7 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase):
def get_handle(self, kwargs):
raise NotImplementedError
def dispatch_router_logits(
def dispatch(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
......@@ -354,19 +273,6 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase):
) -> tuple[torch.Tensor, torch.Tensor]:
raise NotImplementedError
def dispatch(
self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
) -> (
tuple[torch.Tensor, torch.Tensor, torch.Tensor]
| tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
raise NotImplementedError
def combine(
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
) -> torch.Tensor:
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import threading
from typing import Any
from weakref import WeakValueDictionary
import torch
......@@ -63,32 +64,13 @@ class All2AllManagerBase:
# and reuse it for the same config.
raise NotImplementedError
def dispatch_router_logits(
def dispatch(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
) -> (
tuple[torch.Tensor, torch.Tensor]
| tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
# Subclasses should either:
# - implement handling for extra_tensors, or
# - raise a clear error if extra_tensors is not supported.
raise NotImplementedError
def dispatch(
self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
) -> (
tuple[torch.Tensor, torch.Tensor, torch.Tensor]
| tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
) -> Any:
# Subclasses should either:
# - implement handling for extra_tensors, or
# - raise a clear error if extra_tensors is not supported.
......@@ -298,7 +280,7 @@ class DeviceCommunicatorBase:
for module in moe_modules:
module.maybe_init_modular_kernel()
def dispatch_router_logits(
def dispatch(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
......@@ -312,29 +294,8 @@ class DeviceCommunicatorBase:
Dispatch the hidden states and router logits to the appropriate device.
This is a no-op in the base class.
"""
if extra_tensors is not None:
return hidden_states, router_logits, extra_tensors
return hidden_states, router_logits
def dispatch(
self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
) -> (
tuple[torch.Tensor, torch.Tensor, torch.Tensor]
| tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
"""
Dispatch the hidden states and topk weights/ids to the appropriate device.
This is a no-op in the base class.
"""
if extra_tensors is not None:
return hidden_states, topk_weights, topk_ids, extra_tensors
return hidden_states, topk_weights, topk_ids
def combine(
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
) -> torch.Tensor:
......
......@@ -130,65 +130,30 @@ class CpuCommunicator(DeviceCommunicatorBase):
) -> dict[str, torch.Tensor | Any]:
return self.dist_module.recv_tensor_dict(src)
def dispatch_router_logits(
def dispatch(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
) -> (
tuple[torch.Tensor, torch.Tensor]
| tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
"""
Dispatch the hidden states and router logits to the appropriate device.
This is a no-op in the base class.
"""
) -> tuple[torch.Tensor, torch.Tensor]:
assert self.all2all_manager is not None
return self.all2all_manager.dispatch_router_logits(
hidden_states,
router_logits,
is_sequence_parallel,
extra_tensors,
)
def dispatch(
self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
) -> (
tuple[torch.Tensor, torch.Tensor, torch.Tensor]
| tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
"""
Dispatch the hidden states and topk weights/ids to the appropriate device.
This is a no-op in the base class.
"""
assert self.all2all_manager is not None
return self.all2all_manager.dispatch(
hidden_states,
topk_weights,
topk_ids,
router_logits,
is_sequence_parallel,
extra_tensors=extra_tensors,
extra_tensors, # type: ignore[call-arg]
)
def combine(
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
) -> torch.Tensor:
"""
Combine the hidden states and router logits from the appropriate device.
This is a no-op in the base class.
"""
assert self.all2all_manager is not None
return self.all2all_manager.combine(
hidden_states,
is_sequence_parallel,
hidden_states = self.all2all_manager.combine(
hidden_states, is_sequence_parallel
)
return hidden_states
class _CPUSHMDistributed:
......
......@@ -322,7 +322,7 @@ class CudaCommunicator(DeviceCommunicatorBase):
return output_list
def dispatch_router_logits(
def dispatch(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
......@@ -332,52 +332,20 @@ class CudaCommunicator(DeviceCommunicatorBase):
tuple[torch.Tensor, torch.Tensor]
| tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
"""
Dispatch the hidden states and router logits to the appropriate device.
This is a no-op in the base class.
"""
assert self.all2all_manager is not None
return self.all2all_manager.dispatch_router_logits(
hidden_states,
router_logits,
is_sequence_parallel,
extra_tensors,
)
def dispatch(
self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
) -> (
tuple[torch.Tensor, torch.Tensor, torch.Tensor]
| tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
"""
Dispatch the hidden states and topk weights/ids to the appropriate device.
This is a no-op in the base class.
"""
assert self.all2all_manager is not None
return self.all2all_manager.dispatch(
hidden_states,
topk_weights,
topk_ids,
router_logits,
is_sequence_parallel,
extra_tensors=extra_tensors,
extra_tensors, # type: ignore[call-arg]
)
def combine(
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
) -> torch.Tensor:
"""
Combine the hidden states and router logits from the appropriate device.
This is a no-op in the base class.
"""
assert self.all2all_manager is not None
return self.all2all_manager.combine(
hidden_states,
is_sequence_parallel,
hidden_states = self.all2all_manager.combine(
hidden_states, is_sequence_parallel
)
return hidden_states
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any
import torch.distributed as dist
from flashinfer.comm.mnnvl import CommBackend as CommBackend
......@@ -24,15 +23,3 @@ class CustomCommunicator(CommBackend):
gathered = [None] * self.Get_size()
dist.all_gather_object(gathered, data, group=self._group)
return gathered
# NOTE(rob): CommBackend is an abstract class, and bcast/barrier
# are unimplemented on vLLM side. If we need to utilize these
# methods in the future, can create a concrete implementation.
def bcast(self, data: Any, root: int) -> Any:
raise NotImplementedError
def barrier(self) -> None:
raise NotImplementedError
def Split(self, color: int, key: int) -> "CustomCommunicator":
return self
......@@ -196,62 +196,26 @@ class XpuCommunicator(DeviceCommunicatorBase):
def broadcast(self, input_: torch.Tensor, src: int = 0) -> None:
dist.broadcast(input_, src=src, group=self.device_group)
def dispatch_router_logits(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
) -> (
tuple[torch.Tensor, torch.Tensor]
| tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
"""
Dispatch the hidden states and router logits to the appropriate device.
This is a no-op in the base class.
"""
assert self.all2all_manager is not None
return self.all2all_manager.dispatch_router_logits(
hidden_states,
router_logits,
is_sequence_parallel,
extra_tensors,
)
def dispatch(
self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
router_logits: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
) -> (
tuple[torch.Tensor, torch.Tensor, torch.Tensor]
| tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
"""
Dispatch the hidden states and topk weights/ids to the appropriate device.
This is a no-op in the base class.
"""
) -> tuple[torch.Tensor, torch.Tensor]:
assert self.all2all_manager is not None
return self.all2all_manager.dispatch(
hidden_states,
topk_weights,
topk_ids,
router_logits,
is_sequence_parallel,
extra_tensors=extra_tensors,
extra_tensors, # type: ignore[call-arg]
)
def combine(
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
) -> torch.Tensor:
"""
Combine the hidden states and router logits from the appropriate device.
This is a no-op in the base class.
"""
assert self.all2all_manager is not None
return self.all2all_manager.combine(
hidden_states,
is_sequence_parallel,
hidden_states = self.all2all_manager.combine(
hidden_states, is_sequence_parallel
)
return hidden_states
......@@ -298,7 +298,6 @@ class NixlConnectorMetadata(KVConnectorMetadata):
class NixlConnector(KVConnectorBase_V1):
def __init__(
self,
vllm_config: VllmConfig,
......
......@@ -1000,7 +1000,7 @@ class GroupCoordinator:
if self.device_communicator is not None:
self.device_communicator.prepare_communication_buffer_for_model(model)
def dispatch_router_logits(
def dispatch(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
......@@ -1011,7 +1011,7 @@ class GroupCoordinator:
| tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
if self.device_communicator is not None:
return self.device_communicator.dispatch_router_logits(
return self.device_communicator.dispatch( # type: ignore[call-arg]
hidden_states,
router_logits,
is_sequence_parallel,
......@@ -1020,28 +1020,6 @@ class GroupCoordinator:
else:
return hidden_states, router_logits
def dispatch(
self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
) -> (
tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
| tuple[torch.Tensor, torch.Tensor, torch.Tensor]
):
if self.device_communicator is not None:
return self.device_communicator.dispatch(
hidden_states,
topk_weights,
topk_ids,
is_sequence_parallel,
extra_tensors,
)
else:
return hidden_states, topk_weights, topk_ids
def combine(
self, hidden_states, is_sequence_parallel: bool = False
) -> torch.Tensor:
......
......@@ -264,39 +264,6 @@ def load_log_config(log_config_file: str | None) -> dict | None:
return None
def get_uvicorn_log_config(args: Namespace) -> dict | None:
"""
Get the uvicorn log config based on the provided arguments.
Priority:
1. If log_config_file is specified, use it
2. If disable_access_log_for_endpoints is specified, create a config with
the access log filter
3. Otherwise, return None (use uvicorn defaults)
"""
# First, try to load from file if specified
log_config = load_log_config(args.log_config_file)
if log_config is not None:
return log_config
# If endpoints to filter are specified, create a config with the filter
if args.disable_access_log_for_endpoints:
from vllm.logging_utils import create_uvicorn_log_config
# Parse comma-separated string into list
excluded_paths = [
p.strip()
for p in args.disable_access_log_for_endpoints.split(",")
if p.strip()
]
return create_uvicorn_log_config(
excluded_paths=excluded_paths,
log_level=args.uvicorn_log_level,
)
return None
class AuthenticationMiddleware:
"""
Pure ASGI middleware that authenticates each request by checking
......
......@@ -44,7 +44,6 @@ from vllm.entrypoints.openai.engine.protocol import (
DeltaMessage,
DeltaToolCall,
ErrorResponse,
FunctionCall,
PromptTokenUsageInfo,
RequestResponseMetadata,
ToolCall,
......@@ -68,7 +67,6 @@ from vllm.entrypoints.openai.parser.harmony_utils import (
from vllm.entrypoints.openai.utils import maybe_filter_parallel_tool_calls
from vllm.entrypoints.utils import get_max_tokens, should_include_usage
from vllm.inputs.data import TokensPrompt
from vllm.inputs.parse import get_prompt_components
from vllm.logger import init_logger
from vllm.logprobs import Logprob
from vllm.outputs import CompletionOutput, RequestOutput
......@@ -145,6 +143,11 @@ class OpenAIServingChat(OpenAIServing):
self.enable_prompt_tokens_details = enable_prompt_tokens_details
self.enable_force_include_usage = enable_force_include_usage
self.default_sampling_params = self.model_config.get_diff_sampling_param()
if self.model_config.hf_config.model_type == "kimi_k2":
self.tool_call_id_type = "kimi_k2"
else:
self.tool_call_id_type = "random"
self.use_harmony = self.model_config.hf_config.model_type == "gpt_oss"
if self.use_harmony:
if "stop_token_ids" not in self.default_sampling_params:
......@@ -153,16 +156,6 @@ class OpenAIServingChat(OpenAIServing):
get_stop_tokens_for_assistant_actions()
)
# Handle tool call ID type for Kimi K2 (supporting test mocking via overrides)
hf_overrides = getattr(self.model_config, "hf_overrides", None)
if self.model_config.hf_text_config.model_type == "kimi_k2" or (
isinstance(hf_overrides, dict)
and hf_overrides.get("model_type") == "kimi_k2"
):
self.tool_call_id_type = "kimi_k2"
else:
self.tool_call_id_type = "random"
# NOTE(woosuk): While OpenAI's chat completion API supports browsing
# for some models, currently vLLM doesn't support it. Please use the
# Responses API instead.
......@@ -254,8 +247,8 @@ class OpenAIServingChat(OpenAIServing):
# because of issues with pydantic we need to potentially
# re-serialize the tool_calls field of the request
# for more info: see comment in `maybe_serialize_tool_calls`
maybe_serialize_tool_calls(request) # type: ignore[arg-type]
truncate_tool_call_ids(request) # type: ignore[arg-type]
maybe_serialize_tool_calls(request)
truncate_tool_call_ids(request)
validate_request_params(request)
# Check if tool parsing is unavailable (common condition)
......@@ -375,18 +368,20 @@ class OpenAIServingChat(OpenAIServing):
generators: list[AsyncGenerator[RequestOutput, None]] = []
try:
for i, engine_prompt in enumerate(engine_prompts):
prompt_text, _, _ = get_prompt_components(engine_prompt)
prompt_text, _, _ = self._get_prompt_components(engine_prompt)
# If we are creating sub requests for multiple prompts, ensure that they
# have unique request ids.
sub_request_id = (
request_id if len(engine_prompts) == 1 else f"{request_id}_{i}"
)
if self.default_sampling_params is None:
self.default_sampling_params = {}
max_tokens = get_max_tokens(
max_model_len=self.max_model_len,
request=request,
prompt=engine_prompt,
input_length=len(engine_prompt["prompt_token_ids"]),
default_sampling_params=self.default_sampling_params,
)
......@@ -459,7 +454,6 @@ class OpenAIServingChat(OpenAIServing):
# Streaming response
tokenizer = self.renderer.tokenizer
assert tokenizer is not None
if request.stream:
return self.chat_completion_stream_generator(
......@@ -638,11 +632,9 @@ class OpenAIServingChat(OpenAIServing):
request_id: str,
model_name: str,
conversation: list[ConversationMessage],
tokenizer: TokenizerLike,
tokenizer: TokenizerLike | None,
request_metadata: RequestResponseMetadata,
) -> AsyncGenerator[str, None]:
from vllm.tokenizers.mistral import MistralTokenizer
created_time = int(time.time())
chunk_object_type: Final = "chat.completion.chunk"
first_iteration = True
......@@ -706,7 +698,7 @@ class OpenAIServingChat(OpenAIServing):
)
reasoning_parser = self.reasoning_parser(
tokenizer,
chat_template_kwargs=chat_template_kwargs or {}, # type: ignore[call-arg]
chat_template_kwargs=chat_template_kwargs, # type: ignore[call-arg]
)
except RuntimeError as e:
logger.exception("Error in reasoning parser creation.")
......@@ -963,17 +955,8 @@ class OpenAIServingChat(OpenAIServing):
index=i,
)
else:
# Generate ID based on tokenizer type
if isinstance(tokenizer, MistralTokenizer):
tool_call_id = MistralToolCall.generate_random_id()
else:
tool_call_id = make_tool_call_id(
id_type=self.tool_call_id_type,
func_name=tool_choice_function_name,
idx=history_tool_call_cnt,
)
delta_tool_call = DeltaToolCall(
id=tool_call_id,
id=make_tool_call_id(),
type="function",
function=DeltaFunctionCall(
name=tool_choice_function_name,
......@@ -1404,11 +1387,9 @@ class OpenAIServingChat(OpenAIServing):
request_id: str,
model_name: str,
conversation: list[ConversationMessage],
tokenizer: TokenizerLike,
tokenizer: TokenizerLike | None,
request_metadata: RequestResponseMetadata,
) -> ErrorResponse | ChatCompletionResponse:
from vllm.tokenizers.mistral import MistralTokenizer
created_time = int(time.time())
final_res: RequestOutput | None = None
......@@ -1543,85 +1524,39 @@ class OpenAIServingChat(OpenAIServing):
tool_call_class = (
MistralToolCall if isinstance(tokenizer, MistralTokenizer) else ToolCall
)
if self.use_harmony:
# Harmony models already have parsed content and tool_calls
# through parse_chat_output. Respect its output directly.
message = ChatMessage(
role=role,
reasoning=reasoning,
content=content,
tool_calls=tool_calls if tool_calls else [],
)
elif (not self.enable_auto_tools or not self.tool_parser) and (
if (not self.enable_auto_tools or not self.tool_parser) and (
not isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam)
and request.tool_choice != "required"
):
message = ChatMessage(role=role, reasoning=reasoning, content=content)
# if the request uses tools and specified a tool choice
elif (
request.tool_choice
and type(request.tool_choice) is ChatCompletionNamedToolChoiceParam
):
assert tool_calls is not None and len(tool_calls) > 0
tool_call_class_items = []
for idx, tc in enumerate(tool_calls):
# Use native ID if available (e.g., Kimi K2),
# otherwise generate ID with correct id_type
if tc.id:
tool_call_class_items.append(
tool_call_class(id=tc.id, function=tc)
)
else:
# Generate ID using the correct format (kimi_k2 or random),
# but leave it to the class if it's Mistral to preserve
# 9-char IDs
if isinstance(tokenizer, MistralTokenizer):
tool_call_class_items.append(tool_call_class(function=tc))
else:
generated_id = make_tool_call_id(
id_type=self.tool_call_id_type,
func_name=tc.name,
idx=history_tool_call_cnt + idx,
)
tool_call_class_items.append(
tool_call_class(id=generated_id, function=tc)
)
history_tool_call_cnt += 1
message = ChatMessage(
role=role,
reasoning=reasoning,
content="",
tool_calls=tool_call_class_items,
tool_calls=[tool_call_class(function=tc) for tc in tool_calls],
)
elif request.tool_choice and request.tool_choice == "required":
tool_call_class_items = []
assert tool_calls is not None and len(tool_calls) > 0
for idx, tool_call in enumerate(tool_calls):
# Use native ID if available,
# otherwise generate ID with correct id_type
if tool_call.id:
tool_call_class_items.append(
tool_call_class(id=tool_call.id, function=tool_call)
)
else:
# Generate ID using the correct format (kimi_k2 or random),
# but leave it to the class if it's Mistral to preserve
# 9-char IDs
if isinstance(tokenizer, MistralTokenizer):
tool_call_class_items.append(
tool_call_class(function=tool_call)
)
else:
generated_id = make_tool_call_id(
for tool_call in tool_calls:
tool_call_class_items.append(
tool_call_class(
id=make_tool_call_id(
id_type=self.tool_call_id_type,
func_name=tool_call.name,
idx=history_tool_call_cnt + idx,
)
tool_call_class_items.append(
tool_call_class(id=generated_id, function=tool_call)
)
idx=history_tool_call_cnt,
),
function=tool_call,
)
)
history_tool_call_cnt += 1
message = ChatMessage(
role=role,
......@@ -1647,35 +1582,17 @@ class OpenAIServingChat(OpenAIServing):
# call. The same is not true for named function calls
auto_tools_called = tool_calls is not None and len(tool_calls) > 0
if tool_calls:
tool_call_items = []
for idx, tc in enumerate(tool_calls):
# Use native ID if available (e.g., Kimi K2),
# otherwise generate ID with correct id_type
if tc.id:
tool_call_items.append(
tool_call_class(id=tc.id, function=tc)
)
else:
# Generate ID using the correct format (kimi_k2 or random),
# but leave it to the class if it's Mistral to preserve
# 9-char IDs
if isinstance(tokenizer, MistralTokenizer):
tool_call_items.append(tool_call_class(function=tc))
else:
generated_id = make_tool_call_id(
id_type=self.tool_call_id_type,
func_name=tc.name,
idx=history_tool_call_cnt + idx,
)
tool_call_items.append(
tool_call_class(id=generated_id, function=tc)
)
history_tool_call_cnt += 1
message = ChatMessage(
role=role,
reasoning=reasoning,
content=content,
tool_calls=tool_call_items,
tool_calls=[
ToolCall(
function=tc,
type="function",
)
for tc in tool_calls
],
)
else:
......@@ -1784,11 +1701,13 @@ class OpenAIServingChat(OpenAIServing):
elif choice.message.tool_calls:
# For tool calls, log the function name and arguments
tool_call_descriptions = []
for tc in choice.message.tool_calls: # type: ignore
function_call: FunctionCall = tc.function # type: ignore
tool_call_descriptions.append(
f"{function_call.name}({function_call.arguments})"
)
for tc in choice.message.tool_calls:
if hasattr(tc.function, "name") and hasattr(
tc.function, "arguments"
):
tool_call_descriptions.append(
f"{tc.function.name}({tc.function.arguments})"
)
tool_calls_str = ", ".join(tool_call_descriptions)
output_text = f"[tool_calls: {tool_calls_str}]"
......@@ -1976,7 +1895,7 @@ class OpenAIServingChat(OpenAIServing):
# because of issues with pydantic we need to potentially
# re-serialize the tool_calls field of the request
# for more info: see comment in `maybe_serialize_tool_calls`
maybe_serialize_tool_calls(request) # type: ignore[arg-type]
maybe_serialize_tool_calls(request)
# Add system message.
# NOTE: In Chat Completion API, browsing is enabled by default
......@@ -1994,7 +1913,7 @@ class OpenAIServingChat(OpenAIServing):
# Add developer message.
if request.tools:
dev_msg = get_developer_message(
tools=request.tools if should_include_tools else None # type: ignore[arg-type]
tools=request.tools if should_include_tools else None
)
messages.append(dev_msg)
......@@ -2009,4 +1928,4 @@ class OpenAIServingChat(OpenAIServing):
if request.cache_salt is not None:
engine_prompt["cache_salt"] = request.cache_salt
return messages, [engine_prompt]
return messages, [engine_prompt]
\ No newline at end of file
......@@ -85,12 +85,6 @@ class FrontendArgs:
"""Log level for uvicorn."""
disable_uvicorn_access_log: bool = False
"""Disable uvicorn access log."""
disable_access_log_for_endpoints: str | None = None
"""Comma-separated list of endpoint paths to exclude from uvicorn access
logs. This is useful to reduce log noise from high-frequency endpoints
like health checks. Example: "/health,/metrics,/ping".
When set, access logs for requests to these paths will be suppressed
while keeping logs for other endpoints."""
allow_credentials: bool = False
"""Allow credentials."""
allowed_origins: list[str] = field(default_factory=lambda: ["*"])
......@@ -250,11 +244,6 @@ class FrontendArgs:
del frontend_kwargs["middleware"]["nargs"]
frontend_kwargs["middleware"]["default"] = []
# Special case: disable_access_log_for_endpoints is a single
# comma-separated string, not a list
if "nargs" in frontend_kwargs["disable_access_log_for_endpoints"]:
del frontend_kwargs["disable_access_log_for_endpoints"]["nargs"]
# Special case: Tool call parser shows built-in options.
valid_tool_parsers = list(ToolParserManager.list_registered())
parsers_str = ",".join(valid_tool_parsers)
......
......@@ -163,12 +163,25 @@ class OpenAIServingCompletion(OpenAIServing):
generators: list[AsyncGenerator[RequestOutput, None]] = []
try:
for i, engine_prompt in enumerate(engine_prompts):
prompt_text, _, _ = get_prompt_components(engine_prompt)
prompt_text, prompt_token_ids, prompt_embeds = (
self._get_prompt_components(engine_prompt)
)
input_length = None
if prompt_token_ids is not None:
input_length = len(prompt_token_ids)
elif prompt_embeds is not None:
input_length = len(prompt_embeds)
else:
raise NotImplementedError
if self.default_sampling_params is None:
self.default_sampling_params = {}
max_tokens = get_max_tokens(
max_model_len=self.max_model_len,
request=request,
prompt=engine_prompt,
input_length=input_length,
default_sampling_params=self.default_sampling_params,
)
......
......@@ -218,10 +218,6 @@ def get_logits_processors(
class FunctionCall(OpenAIBaseModel):
# Internal field to preserve native tool call ID from tool parser.
# Excluded from serialization to maintain OpenAI API compatibility
# (function object should only contain 'name' and 'arguments').
id: str | None = Field(default=None, exclude=True)
name: str
arguments: str
......
......@@ -64,12 +64,13 @@ from vllm.entrypoints.openai.translations.protocol import (
from vllm.entrypoints.pooling.classify.protocol import (
ClassificationChatRequest,
ClassificationCompletionRequest,
ClassificationRequest,
ClassificationResponse,
)
from vllm.entrypoints.pooling.embed.protocol import (
EmbeddingBytesResponse,
EmbeddingChatRequest,
EmbeddingCompletionRequest,
EmbeddingRequest,
EmbeddingResponse,
)
from vllm.entrypoints.pooling.pooling.protocol import (
......@@ -94,14 +95,11 @@ from vllm.entrypoints.serve.tokenize.protocol import (
TokenizeCompletionRequest,
TokenizeResponse,
)
from vllm.entrypoints.utils import (
_validate_truncation_size,
get_max_tokens,
sanitize_message,
)
from vllm.entrypoints.utils import _validate_truncation_size, sanitize_message
from vllm.exceptions import VLLMValidationError
from vllm.inputs.data import PromptType, TokensPrompt
from vllm.inputs.parse import (
PromptComponents,
get_prompt_components,
is_explicit_encoder_decoder_prompt,
)
......@@ -172,7 +170,6 @@ AnyResponse: TypeAlias = (
CompletionResponse
| ChatCompletionResponse
| EmbeddingResponse
| EmbeddingBytesResponse
| TranscriptionResponse
| TokenizeResponse
| PoolingResponse
......@@ -186,21 +183,51 @@ RequestT = TypeVar("RequestT", bound=AnyRequest)
@dataclass(kw_only=True)
class ServeContext(Generic[RequestT]):
class RequestProcessingMixin:
"""
Mixin for request processing,
handling prompt preparation and engine input.
"""
engine_prompts: list[TokensPrompt] | None = field(default_factory=list)
@dataclass(kw_only=True)
class ResponseGenerationMixin:
"""
Mixin for response generation,
managing result generators and final batch results.
"""
result_generator: (
AsyncGenerator[tuple[int, RequestOutput | PoolingRequestOutput], None] | None
) = None
final_res_batch: list[RequestOutput | PoolingRequestOutput] = field(
default_factory=list
)
model_config = ConfigDict(arbitrary_types_allowed=True)
@dataclass(kw_only=True)
class ServeContext(RequestProcessingMixin, ResponseGenerationMixin, Generic[RequestT]):
request: RequestT
raw_request: Request | None = None
model_name: str
request_id: str
created_time: int = field(default_factory=lambda: int(time.time()))
lora_request: LoRARequest | None = None
engine_prompts: list[TokensPrompt] | None = None
result_generator: AsyncGenerator[tuple[int, PoolingRequestOutput], None] | None = (
None
)
final_res_batch: list[PoolingRequestOutput] = field(default_factory=list)
model_config = ConfigDict(arbitrary_types_allowed=True)
@dataclass(kw_only=True)
class ClassificationServeContext(ServeContext[ClassificationRequest]):
pass
@dataclass(kw_only=True)
class EmbeddingServeContext(ServeContext[EmbeddingRequest]):
chat_template: str | None = None
chat_template_content_format: ChatTemplateContentFormatOption
class OpenAIServing:
......@@ -578,7 +605,10 @@ class OpenAIServing:
self,
ctx: ServeContext,
) -> AnyResponse | ErrorResponse:
async for response in self._pipeline(ctx):
generation: AsyncGenerator[AnyResponse | ErrorResponse, None]
generation = self._pipeline(ctx)
async for response in generation:
return response
return self.create_error_response("No response yielded from pipeline")
......@@ -637,7 +667,9 @@ class OpenAIServing:
ctx: ServeContext,
) -> ErrorResponse | None:
"""Schedule the request and get the result generator."""
generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
generators: list[
AsyncGenerator[RequestOutput | PoolingRequestOutput, None]
] = []
try:
trace_headers = (
......@@ -691,7 +723,7 @@ class OpenAIServing:
return self.create_error_response("Engine prompts not available")
num_prompts = len(ctx.engine_prompts)
final_res_batch: list[PoolingRequestOutput | None]
final_res_batch: list[RequestOutput | PoolingRequestOutput | None]
final_res_batch = [None] * num_prompts
if ctx.result_generator is None:
......@@ -979,7 +1011,7 @@ class OpenAIServing:
def _validate_input(
self,
request: object,
request: AnyRequest,
input_ids: list[int],
input_text: str,
) -> TokensPrompt:
......@@ -1290,7 +1322,7 @@ class OpenAIServing:
priority: int = 0,
**kwargs,
):
prompt_text, _, _ = get_prompt_components(engine_prompt)
prompt_text, _, _ = self._get_prompt_components(engine_prompt)
orig_priority = priority
sub_request = 0
......@@ -1341,12 +1373,10 @@ class OpenAIServing:
# yield context
# Create inputs for the next turn.
# Render the next prompt token ids and update sampling_params.
# Render the next prompt token ids.
if isinstance(context, (HarmonyContext, StreamingHarmonyContext)):
token_ids = context.render_for_completion()
engine_prompt = TokensPrompt(prompt_token_ids=token_ids)
sampling_params.max_tokens = self.max_model_len - len(token_ids)
prompt_token_ids = context.render_for_completion()
engine_prompt = TokensPrompt(prompt_token_ids=prompt_token_ids)
elif isinstance(context, ParsableContext):
engine_prompts = await self._render_next_turn(
context.request,
......@@ -1358,19 +1388,19 @@ class OpenAIServing:
context.chat_template_content_format,
)
engine_prompt = engine_prompts[0]
prompt_text, _, _ = get_prompt_components(engine_prompt)
sampling_params.max_tokens = get_max_tokens(
self.max_model_len,
context.request,
engine_prompt,
self.default_sampling_params, # type: ignore
)
prompt_text, _, _ = self._get_prompt_components(engine_prompt)
# Update the sampling params.
sampling_params.max_tokens = self.max_model_len - len(
engine_prompt["prompt_token_ids"]
)
# OPTIMIZATION
priority = orig_priority - 1
sub_request += 1
def _get_prompt_components(self, prompt: PromptType) -> PromptComponents:
return get_prompt_components(prompt)
def _log_inputs(
self,
request_id: str,
......@@ -1381,7 +1411,7 @@ class OpenAIServing:
if self.request_logger is None:
return
prompt, prompt_token_ids, prompt_embeds = get_prompt_components(inputs)
prompt, prompt_token_ids, prompt_embeds = self._get_prompt_components(inputs)
self.request_logger.log_inputs(
request_id,
......@@ -1495,7 +1525,6 @@ class OpenAIServing:
# extract_tool_calls() returns a list of tool calls.
function_calls.extend(
FunctionCall(
id=tool_call.id,
name=tool_call.function.name,
arguments=tool_call.function.arguments,
)
......@@ -1548,4 +1577,4 @@ def clamp_prompt_logprobs(
for logprob_values in logprob_dict.values():
if logprob_values.logprob == float("-inf"):
logprob_values.logprob = -9999.0
return prompt_logprobs
return prompt_logprobs
\ No newline at end of file
......@@ -63,7 +63,6 @@ from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import (
ChatCompletionMessageParam,
ChatTemplateContentFormatOption,
make_tool_call_id,
)
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.mcp.tool_server import ToolServer
......@@ -116,7 +115,6 @@ from vllm.entrypoints.openai.responses.utils import (
extract_tool_types,
should_continue_final_message,
)
from vllm.entrypoints.utils import get_max_tokens
from vllm.exceptions import VLLMValidationError
from vllm.inputs.data import TokensPrompt
from vllm.logger import init_logger
......@@ -252,17 +250,6 @@ class OpenAIServingResponses(OpenAIServing):
self.default_sampling_params["stop_token_ids"].extend(
get_stop_tokens_for_assistant_actions()
)
# Handle tool call ID type for Kimi K2 (supporting test mocking via overrides)
hf_overrides = getattr(self.model_config, "hf_overrides", None)
if self.model_config.hf_text_config.model_type == "kimi_k2" or (
isinstance(hf_overrides, dict)
and hf_overrides.get("model_type") == "kimi_k2"
):
self.tool_call_id_type = "kimi_k2"
else:
self.tool_call_id_type = "random"
self.enable_auto_tools = enable_auto_tools
# set up tool use
self.tool_parser = self._get_tool_parser(
......@@ -436,11 +423,8 @@ class OpenAIServingResponses(OpenAIServing):
if maybe_error is not None:
return maybe_error
default_max_tokens = get_max_tokens(
self.max_model_len,
request,
engine_prompt,
self.default_sampling_params,
default_max_tokens = self.max_model_len - len(
engine_prompt["prompt_token_ids"]
)
sampling_params = request.to_sampling_params(
......@@ -970,28 +954,25 @@ class OpenAIServingResponses(OpenAIServing):
enable_auto_tools=self.enable_auto_tools,
tool_parser_cls=self.tool_parser,
)
if content or (self.use_harmony and tool_calls):
res_text_part = None
if content:
res_text_part = ResponseOutputText(
text=content,
annotations=[], # TODO
type="output_text",
logprobs=(
self._create_response_logprobs(
token_ids=final_output.token_ids,
logprobs=final_output.logprobs,
tokenizer=tokenizer,
top_logprobs=request.top_logprobs,
)
if request.is_include_output_logprobs()
else None
),
)
if content:
output_text = ResponseOutputText(
text=content,
annotations=[], # TODO
type="output_text",
logprobs=(
self._create_response_logprobs(
token_ids=final_output.token_ids,
logprobs=final_output.logprobs,
tokenizer=tokenizer,
top_logprobs=request.top_logprobs,
)
if request.is_include_output_logprobs()
else None
),
)
message_item = ResponseOutputMessage(
id=f"msg_{random_uuid()}",
content=[res_text_part] if res_text_part else [],
content=[output_text],
role="assistant",
status="completed",
type="message",
......@@ -1003,28 +984,17 @@ class OpenAIServingResponses(OpenAIServing):
if message_item:
outputs.append(message_item)
if tool_calls:
# We use a simple counter for history_tool_call_count because
# we don't track the history of tool calls in the Responses API yet.
# This means that the tool call index will start from 0 for each
# request.
tool_call_items = []
for history_tool_call_cnt, tool_call in enumerate(tool_calls):
tool_call_items.append(
ResponseFunctionToolCall(
id=f"fc_{random_uuid()}",
call_id=tool_call.id
if tool_call.id
else make_tool_call_id(
id_type=self.tool_call_id_type,
func_name=tool_call.name,
idx=history_tool_call_cnt,
),
type="function_call",
status="completed",
name=tool_call.name,
arguments=tool_call.arguments,
)
tool_call_items = [
ResponseFunctionToolCall(
id=f"fc_{random_uuid()}",
call_id=f"call_{random_uuid()}",
type="function_call",
status="completed",
name=tool_call.name,
arguments=tool_call.arguments,
)
for tool_call in tool_calls
]
outputs.extend(tool_call_items)
return outputs
......@@ -2589,4 +2559,4 @@ class OpenAIServingResponses(OpenAIServing):
sequence_number=-1,
response=final_response,
)
)
)
\ No newline at end of file
......@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from http import HTTPStatus
from typing import Final, cast
from typing import cast
import jinja2
import numpy as np
......@@ -11,8 +11,18 @@ from fastapi import Request
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.engine.protocol import ErrorResponse, UsageInfo
from vllm.entrypoints.openai.engine.serving import OpenAIServing, ServeContext
from vllm.entrypoints.openai.chat_completion.protocol import (
ChatCompletionRequest,
)
from vllm.entrypoints.openai.engine.protocol import (
ErrorResponse,
UsageInfo,
)
from vllm.entrypoints.openai.engine.serving import (
ClassificationServeContext,
OpenAIServing,
ServeContext,
)
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.entrypoints.pooling.classify.protocol import (
ClassificationChatRequest,
......@@ -29,68 +39,60 @@ from vllm.pooling_params import PoolingParams
logger = init_logger(__name__)
ClassificationServeContext = ServeContext[ClassificationRequest]
class ServingClassification(OpenAIServing):
request_id_prefix = "classify"
def __init__(
self,
engine_client: EngineClient,
models: OpenAIServingModels,
*,
request_logger: RequestLogger | None,
chat_template: str | None = None,
chat_template_content_format: ChatTemplateContentFormatOption = "auto",
trust_request_chat_template: bool = False,
log_error_stack: bool = False,
) -> None:
super().__init__(
engine_client=engine_client,
models=models,
request_logger=request_logger,
log_error_stack=log_error_stack,
)
self.chat_template = chat_template
self.chat_template_content_format: Final = chat_template_content_format
self.trust_request_chat_template = trust_request_chat_template
class ClassificationMixin(OpenAIServing):
chat_template: str | None
chat_template_content_format: ChatTemplateContentFormatOption
trust_request_chat_template: bool
async def _preprocess(
self,
ctx: ClassificationServeContext,
ctx: ServeContext,
) -> ErrorResponse | None:
"""
Process classification inputs: tokenize text, resolve adapters,
and prepare model-specific inputs.
"""
ctx = cast(ClassificationServeContext, ctx)
try:
ctx.lora_request = self._maybe_get_adapters(ctx.request)
if isinstance(ctx.request, ClassificationChatRequest):
error_check_ret = self._validate_chat_template(
request_chat_template=ctx.request.chat_template,
chat_template_kwargs=ctx.request.chat_template_kwargs,
trust_request_chat_template=self.trust_request_chat_template,
request_obj = ctx.request
if isinstance(request_obj, ClassificationChatRequest):
chat_request = request_obj
messages = chat_request.messages
trust_request_chat_template = getattr(
self,
"trust_request_chat_template",
False,
)
if error_check_ret:
return error_check_ret
ret = self._validate_chat_template(
request_chat_template=chat_request.chat_template,
chat_template_kwargs=chat_request.chat_template_kwargs,
trust_request_chat_template=trust_request_chat_template,
)
if ret:
return ret
_, engine_prompts = await self._preprocess_chat(
ctx.request,
cast(ChatCompletionRequest, chat_request),
self.renderer,
ctx.request.messages,
chat_template=ctx.request.chat_template or self.chat_template,
chat_template_content_format=self.chat_template_content_format,
add_generation_prompt=ctx.request.add_generation_prompt,
continue_final_message=ctx.request.continue_final_message,
add_special_tokens=ctx.request.add_special_tokens,
messages,
chat_template=(
chat_request.chat_template
or getattr(self, "chat_template", None)
),
chat_template_content_format=cast(
ChatTemplateContentFormatOption,
getattr(self, "chat_template_content_format", "auto"),
),
add_generation_prompt=chat_request.add_generation_prompt,
continue_final_message=chat_request.continue_final_message,
add_special_tokens=chat_request.add_special_tokens,
)
ctx.engine_prompts = engine_prompts
elif isinstance(ctx.request, ClassificationCompletionRequest):
input_data = ctx.request.input
elif isinstance(request_obj, ClassificationCompletionRequest):
completion_request = request_obj
input_data = completion_request.input
if input_data in (None, ""):
return self.create_error_response(
"Input or messages must be provided",
......@@ -104,10 +106,13 @@ class ServingClassification(OpenAIServing):
prompt_input = cast(str | list[str], input_data)
ctx.engine_prompts = await renderer.render_prompt(
prompt_or_prompts=prompt_input,
config=self._build_render_config(ctx.request),
config=self._build_render_config(completion_request),
)
else:
return self.create_error_response("Invalid classification request type")
return self.create_error_response(
"Invalid classification request type",
status_code=HTTPStatus.BAD_REQUEST,
)
return None
......@@ -117,14 +122,13 @@ class ServingClassification(OpenAIServing):
def _build_response(
self,
ctx: ClassificationServeContext,
ctx: ServeContext,
) -> ClassificationResponse | ErrorResponse:
"""
Convert model outputs to a formatted classification response
with probabilities and labels.
"""
id2label = getattr(self.model_config.hf_config, "id2label", {})
ctx = cast(ClassificationServeContext, ctx)
items: list[ClassificationData] = []
num_prompt_tokens = 0
......@@ -135,7 +139,9 @@ class ServingClassification(OpenAIServing):
probs = classify_res.probs
predicted_index = int(np.argmax(probs))
label = id2label.get(predicted_index)
label = getattr(self.model_config.hf_config, "id2label", {}).get(
predicted_index
)
item = ClassificationData(
index=idx,
......@@ -168,6 +174,32 @@ class ServingClassification(OpenAIServing):
add_special_tokens=request.add_special_tokens,
)
class ServingClassification(ClassificationMixin):
request_id_prefix = "classify"
def __init__(
self,
engine_client: EngineClient,
models: OpenAIServingModels,
*,
request_logger: RequestLogger | None,
chat_template: str | None = None,
chat_template_content_format: ChatTemplateContentFormatOption = "auto",
trust_request_chat_template: bool = False,
log_error_stack: bool = False,
) -> None:
super().__init__(
engine_client=engine_client,
models=models,
request_logger=request_logger,
log_error_stack=log_error_stack,
)
self.chat_template = chat_template
self.chat_template_content_format = chat_template_content_format
self.trust_request_chat_template = trust_request_chat_template
async def create_classify(
self,
request: ClassificationRequest,
......@@ -183,11 +215,11 @@ class ServingClassification(OpenAIServing):
request_id=request_id,
)
return await self.handle(ctx) # type: ignore[return-value]
return await super().handle(ctx) # type: ignore
def _create_pooling_params(
self,
ctx: ClassificationServeContext,
ctx: ServeContext[ClassificationRequest],
) -> PoolingParams | ErrorResponse:
pooling_params = super()._create_pooling_params(ctx)
if isinstance(pooling_params, ErrorResponse):
......@@ -198,4 +230,4 @@ class ServingClassification(OpenAIServing):
except ValueError as e:
return self.create_error_response(str(e))
return pooling_params
return pooling_params
\ No newline at end of file
......@@ -6,13 +6,21 @@ from typing import Any, Final, cast
import torch
from fastapi import Request
from typing_extensions import assert_never
from fastapi.responses import Response
from typing_extensions import assert_never, override
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.engine.protocol import ErrorResponse, UsageInfo
from vllm.entrypoints.openai.engine.serving import OpenAIServing, ServeContext
from vllm.entrypoints.openai.engine.protocol import (
ErrorResponse,
UsageInfo,
)
from vllm.entrypoints.openai.engine.serving import (
EmbeddingServeContext,
OpenAIServing,
ServeContext,
)
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.entrypoints.pooling.embed.protocol import (
EmbeddingBytesResponse,
......@@ -25,11 +33,19 @@ from vllm.entrypoints.pooling.embed.protocol import (
from vllm.entrypoints.renderer import RenderConfig
from vllm.inputs.data import TokensPrompt
from vllm.logger import init_logger
from vllm.outputs import PoolingOutput, PoolingRequestOutput
from vllm.outputs import (
EmbeddingRequestOutput,
PoolingOutput,
PoolingRequestOutput,
RequestOutput,
)
from vllm.pooling_params import PoolingParams
from vllm.utils.async_utils import merge_async_iterators
from vllm.utils.collection_utils import chunk_list
from vllm.utils.serial_utils import (
EmbedDType,
EncodingFormat,
Endianness,
encode_pooling_bytes,
encode_pooling_output,
)
......@@ -37,33 +53,9 @@ from vllm.utils.serial_utils import (
logger = init_logger(__name__)
EmbeddingServeContext = ServeContext[EmbeddingRequest]
class OpenAIServingEmbedding(OpenAIServing):
request_id_prefix = "embd"
def __init__(
self,
engine_client: EngineClient,
models: OpenAIServingModels,
*,
request_logger: RequestLogger | None,
chat_template: str | None,
chat_template_content_format: ChatTemplateContentFormatOption,
trust_request_chat_template: bool = False,
log_error_stack: bool = False,
) -> None:
super().__init__(
engine_client=engine_client,
models=models,
request_logger=request_logger,
log_error_stack=log_error_stack,
)
self.chat_template = chat_template
self.chat_template_content_format: Final = chat_template_content_format
self.trust_request_chat_template = trust_request_chat_template
class EmbeddingMixin(OpenAIServing):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
pooler_config = self.model_config.pooler_config
......@@ -77,41 +69,32 @@ class OpenAIServingEmbedding(OpenAIServing):
else None
)
@override
async def _preprocess(
self,
ctx: EmbeddingServeContext,
ctx: ServeContext,
) -> ErrorResponse | None:
ctx = cast(EmbeddingServeContext, ctx)
try:
ctx.lora_request = self._maybe_get_adapters(ctx.request)
if isinstance(ctx.request, EmbeddingChatRequest):
error_check_ret = self._validate_chat_template(
request_chat_template=ctx.request.chat_template,
chat_template_kwargs=ctx.request.chat_template_kwargs,
trust_request_chat_template=self.trust_request_chat_template,
)
if error_check_ret is not None:
return error_check_ret
_, ctx.engine_prompts = await self._preprocess_chat(
ctx.request,
self.renderer,
ctx.request.messages,
chat_template=ctx.request.chat_template or self.chat_template,
chat_template_content_format=self.chat_template_content_format,
chat_template=ctx.request.chat_template or ctx.chat_template,
chat_template_content_format=ctx.chat_template_content_format,
add_generation_prompt=ctx.request.add_generation_prompt,
continue_final_message=ctx.request.continue_final_message,
add_special_tokens=ctx.request.add_special_tokens,
)
elif isinstance(ctx.request, EmbeddingCompletionRequest):
else:
renderer = self._get_completion_renderer()
ctx.engine_prompts = await renderer.render_prompt(
prompt_or_prompts=ctx.request.input,
config=self._build_render_config(ctx.request),
)
else:
return self.create_error_response("Invalid classification request type")
return None
except (ValueError, TypeError) as e:
logger.exception("Error in preprocessing prompt inputs")
......@@ -130,15 +113,16 @@ class OpenAIServingEmbedding(OpenAIServing):
add_special_tokens=request.add_special_tokens,
)
@override
def _build_response(
self,
ctx: EmbeddingServeContext,
) -> EmbeddingResponse | EmbeddingBytesResponse | ErrorResponse:
final_res_batch_checked = ctx.final_res_batch
ctx: ServeContext,
) -> EmbeddingResponse | Response | ErrorResponse:
final_res_batch_checked = cast(list[PoolingRequestOutput], ctx.final_res_batch)
encoding_format = ctx.request.encoding_format
embed_dtype = ctx.request.embed_dtype
endianness = ctx.request.endianness
encoding_format: EncodingFormat = ctx.request.encoding_format
embed_dtype: EmbedDType = ctx.request.embed_dtype
endianness: Endianness = ctx.request.endianness
def encode_float_base64():
items: list[EmbeddingResponseData] = []
......@@ -219,8 +203,8 @@ class OpenAIServingEmbedding(OpenAIServing):
self,
ctx: EmbeddingServeContext,
token_ids: list[int],
pooling_params: PoolingParams,
trace_headers: Mapping[str, str] | None,
pooling_params,
trace_headers,
prompt_idx: int,
) -> list[AsyncGenerator[PoolingRequestOutput, None]]:
"""Process a single prompt using chunked processing."""
......@@ -262,7 +246,7 @@ class OpenAIServingEmbedding(OpenAIServing):
def _validate_input(
self,
request: object,
request,
input_ids: list[int],
input_text: str,
) -> TokensPrompt:
......@@ -342,7 +326,7 @@ class OpenAIServingEmbedding(OpenAIServing):
pooling_params: PoolingParams,
trace_headers: Mapping[str, str] | None,
prompt_index: int,
) -> AsyncGenerator[PoolingRequestOutput, None]:
) -> AsyncGenerator[RequestOutput | PoolingRequestOutput, None]:
"""Create a generator for a single prompt using standard processing."""
request_id_item = f"{ctx.request_id}-{prompt_index}"
......@@ -363,6 +347,7 @@ class OpenAIServingEmbedding(OpenAIServing):
priority=getattr(ctx.request, "priority", 0),
)
@override
async def _prepare_generators(
self,
ctx: ServeContext,
......@@ -378,7 +363,9 @@ class OpenAIServingEmbedding(OpenAIServing):
return await super()._prepare_generators(ctx)
# Custom logic for chunked processing
generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
generators: list[
AsyncGenerator[RequestOutput | PoolingRequestOutput, None]
] = []
try:
trace_headers = (
......@@ -432,9 +419,10 @@ class OpenAIServingEmbedding(OpenAIServing):
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
@override
async def _collect_batch(
self,
ctx: EmbeddingServeContext,
ctx: ServeContext,
) -> ErrorResponse | None:
"""Collect and aggregate batch results
with support for chunked processing.
......@@ -443,6 +431,7 @@ class OpenAIServingEmbedding(OpenAIServing):
minimize memory usage.
For regular requests, collects results normally.
"""
ctx = cast(EmbeddingServeContext, ctx)
try:
if ctx.engine_prompts is None:
return self.create_error_response("Engine prompts not available")
......@@ -538,10 +527,12 @@ class OpenAIServingEmbedding(OpenAIServing):
except (ValueError, IndexError):
prompt_idx = result_idx # Fallback to result_idx
short_prompts_results[prompt_idx] = result
short_prompts_results[prompt_idx] = cast(
PoolingRequestOutput, result
)
# Finalize aggregated results
final_res_batch: list[PoolingRequestOutput] = []
final_res_batch: list[PoolingRequestOutput | EmbeddingRequestOutput] = []
num_prompts = len(ctx.engine_prompts)
for prompt_idx in range(num_prompts):
......@@ -589,19 +580,49 @@ class OpenAIServingEmbedding(OpenAIServing):
f"Failed to aggregate chunks for prompt {prompt_idx}"
)
elif prompt_idx in short_prompts_results:
final_res_batch.append(short_prompts_results[prompt_idx])
final_res_batch.append(
cast(PoolingRequestOutput, short_prompts_results[prompt_idx])
)
else:
return self.create_error_response(
f"Result not found for prompt {prompt_idx}"
)
ctx.final_res_batch = final_res_batch
ctx.final_res_batch = cast(
list[RequestOutput | PoolingRequestOutput], final_res_batch
)
return None
except Exception as e:
return self.create_error_response(str(e))
class OpenAIServingEmbedding(EmbeddingMixin):
request_id_prefix = "embd"
def __init__(
self,
engine_client: EngineClient,
models: OpenAIServingModels,
*,
request_logger: RequestLogger | None,
chat_template: str | None,
chat_template_content_format: ChatTemplateContentFormatOption,
trust_request_chat_template: bool = False,
log_error_stack: bool = False,
) -> None:
super().__init__(
engine_client=engine_client,
models=models,
request_logger=request_logger,
log_error_stack=log_error_stack,
)
self.chat_template = chat_template
self.chat_template_content_format: Final = chat_template_content_format
self.trust_request_chat_template = trust_request_chat_template
async def create_embedding(
self,
request: EmbeddingRequest,
......@@ -624,13 +645,16 @@ class OpenAIServingEmbedding(OpenAIServing):
raw_request=raw_request,
model_name=model_name,
request_id=request_id,
chat_template=self.chat_template,
chat_template_content_format=self.chat_template_content_format,
)
return await self.handle(ctx) # type: ignore[return-value]
return await super().handle(ctx) # type: ignore
@override
def _create_pooling_params(
self,
ctx: EmbeddingServeContext,
ctx: ServeContext[EmbeddingRequest],
) -> PoolingParams | ErrorResponse:
pooling_params = super()._create_pooling_params(ctx)
if isinstance(pooling_params, ErrorResponse):
......@@ -642,3 +666,17 @@ class OpenAIServingEmbedding(OpenAIServing):
return self.create_error_response(str(e))
return pooling_params
async def _preprocess(
self,
ctx: ServeContext,
) -> ErrorResponse | None:
if isinstance(ctx.request, EmbeddingChatRequest):
error_check_ret = self._validate_chat_template(
request_chat_template=ctx.request.chat_template,
chat_template_kwargs=ctx.request.chat_template_kwargs,
trust_request_chat_template=self.trust_request_chat_template,
)
if error_check_ret is not None:
return error_check_ret
return await super()._preprocess(ctx)
\ No newline at end of file
......@@ -17,10 +17,8 @@ from starlette.background import BackgroundTask, BackgroundTasks
from vllm import envs
from vllm.engine.arg_utils import EngineArgs
from vllm.inputs import EmbedsPrompt, TokensPrompt
from vllm.logger import current_formatter_type, init_logger
from vllm.platforms import current_platform
from vllm.utils import length_from_prompt_token_ids_or_embeds
from vllm.utils.argparse_utils import FlexibleArgumentParser
if TYPE_CHECKING:
......@@ -34,15 +32,11 @@ if TYPE_CHECKING:
StreamOptions,
)
from vllm.entrypoints.openai.models.protocol import LoRAModulePath
from vllm.entrypoints.openai.responses.protocol import (
ResponsesRequest,
)
else:
ChatCompletionRequest = object
CompletionRequest = object
StreamOptions = object
LoRAModulePath = object
ResponsesRequest = object
logger = init_logger(__name__)
......@@ -217,26 +211,11 @@ def _validate_truncation_size(
def get_max_tokens(
max_model_len: int,
request: "CompletionRequest | ChatCompletionRequest | ResponsesRequest",
prompt: TokensPrompt | EmbedsPrompt,
request: "ChatCompletionRequest | CompletionRequest",
input_length: int,
default_sampling_params: dict,
) -> int:
# NOTE: Avoid isinstance() for better efficiency
max_tokens: int | None = None
if max_tokens is None:
# ChatCompletionRequest
max_tokens = getattr(request, "max_completion_tokens", None)
if max_tokens is None:
# ResponsesRequest
max_tokens = getattr(request, "max_output_tokens", None)
if max_tokens is None:
# CompletionRequest (also a fallback for ChatCompletionRequest)
max_tokens = getattr(request, "max_tokens", None)
input_length = length_from_prompt_token_ids_or_embeds(
prompt.get("prompt_token_ids"), # type: ignore[arg-type]
prompt.get("prompt_embeds"), # type: ignore[arg-type]
)
max_tokens = getattr(request, "max_completion_tokens", None) or request.max_tokens
default_max_tokens = max_model_len - input_length
max_output_tokens = current_platform.get_max_output_tokens(input_length)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment