Commit 82e40fb7 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.15.0rc1' into v0.15.0rc1-ori

parents 30a1922e 58996f35
...@@ -63,6 +63,7 @@ from vllm.engine.protocol import EngineClient ...@@ -63,6 +63,7 @@ from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import ( from vllm.entrypoints.chat_utils import (
ChatCompletionMessageParam, ChatCompletionMessageParam,
ChatTemplateContentFormatOption, ChatTemplateContentFormatOption,
make_tool_call_id,
) )
from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.mcp.tool_server import ToolServer from vllm.entrypoints.mcp.tool_server import ToolServer
...@@ -115,6 +116,7 @@ from vllm.entrypoints.openai.responses.utils import ( ...@@ -115,6 +116,7 @@ from vllm.entrypoints.openai.responses.utils import (
extract_tool_types, extract_tool_types,
should_continue_final_message, should_continue_final_message,
) )
from vllm.entrypoints.utils import get_max_tokens
from vllm.exceptions import VLLMValidationError from vllm.exceptions import VLLMValidationError
from vllm.inputs.data import TokensPrompt from vllm.inputs.data import TokensPrompt
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -250,6 +252,17 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -250,6 +252,17 @@ class OpenAIServingResponses(OpenAIServing):
self.default_sampling_params["stop_token_ids"].extend( self.default_sampling_params["stop_token_ids"].extend(
get_stop_tokens_for_assistant_actions() get_stop_tokens_for_assistant_actions()
) )
# Handle tool call ID type for Kimi K2 (supporting test mocking via overrides)
hf_overrides = getattr(self.model_config, "hf_overrides", None)
if self.model_config.hf_text_config.model_type == "kimi_k2" or (
isinstance(hf_overrides, dict)
and hf_overrides.get("model_type") == "kimi_k2"
):
self.tool_call_id_type = "kimi_k2"
else:
self.tool_call_id_type = "random"
self.enable_auto_tools = enable_auto_tools self.enable_auto_tools = enable_auto_tools
# set up tool use # set up tool use
self.tool_parser = self._get_tool_parser( self.tool_parser = self._get_tool_parser(
...@@ -423,8 +436,11 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -423,8 +436,11 @@ class OpenAIServingResponses(OpenAIServing):
if maybe_error is not None: if maybe_error is not None:
return maybe_error return maybe_error
default_max_tokens = self.max_model_len - len( default_max_tokens = get_max_tokens(
engine_prompt["prompt_token_ids"] self.max_model_len,
request,
engine_prompt,
self.default_sampling_params,
) )
sampling_params = request.to_sampling_params( sampling_params = request.to_sampling_params(
...@@ -954,8 +970,11 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -954,8 +970,11 @@ class OpenAIServingResponses(OpenAIServing):
enable_auto_tools=self.enable_auto_tools, enable_auto_tools=self.enable_auto_tools,
tool_parser_cls=self.tool_parser, tool_parser_cls=self.tool_parser,
) )
if content or (self.use_harmony and tool_calls):
res_text_part = None
if content: if content:
output_text = ResponseOutputText( res_text_part = ResponseOutputText(
text=content, text=content,
annotations=[], # TODO annotations=[], # TODO
type="output_text", type="output_text",
...@@ -972,7 +991,7 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -972,7 +991,7 @@ class OpenAIServingResponses(OpenAIServing):
) )
message_item = ResponseOutputMessage( message_item = ResponseOutputMessage(
id=f"msg_{random_uuid()}", id=f"msg_{random_uuid()}",
content=[output_text], content=[res_text_part] if res_text_part else [],
role="assistant", role="assistant",
status="completed", status="completed",
type="message", type="message",
...@@ -984,17 +1003,28 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -984,17 +1003,28 @@ class OpenAIServingResponses(OpenAIServing):
if message_item: if message_item:
outputs.append(message_item) outputs.append(message_item)
if tool_calls: if tool_calls:
tool_call_items = [ # We use a simple counter for history_tool_call_count because
# we don't track the history of tool calls in the Responses API yet.
# This means that the tool call index will start from 0 for each
# request.
tool_call_items = []
for history_tool_call_cnt, tool_call in enumerate(tool_calls):
tool_call_items.append(
ResponseFunctionToolCall( ResponseFunctionToolCall(
id=f"fc_{random_uuid()}", id=f"fc_{random_uuid()}",
call_id=f"call_{random_uuid()}", call_id=tool_call.id
if tool_call.id
else make_tool_call_id(
id_type=self.tool_call_id_type,
func_name=tool_call.name,
idx=history_tool_call_cnt,
),
type="function_call", type="function_call",
status="completed", status="completed",
name=tool_call.name, name=tool_call.name,
arguments=tool_call.arguments, arguments=tool_call.arguments,
) )
for tool_call in tool_calls )
]
outputs.extend(tool_call_items) outputs.extend(tool_call_items)
return outputs return outputs
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from http import HTTPStatus from http import HTTPStatus
from typing import cast from typing import Final, cast
import jinja2 import jinja2
import numpy as np import numpy as np
...@@ -11,18 +11,8 @@ from fastapi import Request ...@@ -11,18 +11,8 @@ from fastapi import Request
from vllm.engine.protocol import EngineClient from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.chat_completion.protocol import ( from vllm.entrypoints.openai.engine.protocol import ErrorResponse, UsageInfo
ChatCompletionRequest, from vllm.entrypoints.openai.engine.serving import OpenAIServing, ServeContext
)
from vllm.entrypoints.openai.engine.protocol import (
ErrorResponse,
UsageInfo,
)
from vllm.entrypoints.openai.engine.serving import (
ClassificationServeContext,
OpenAIServing,
ServeContext,
)
from vllm.entrypoints.openai.models.serving import OpenAIServingModels from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.entrypoints.pooling.classify.protocol import ( from vllm.entrypoints.pooling.classify.protocol import (
ClassificationChatRequest, ClassificationChatRequest,
...@@ -39,60 +29,68 @@ from vllm.pooling_params import PoolingParams ...@@ -39,60 +29,68 @@ from vllm.pooling_params import PoolingParams
logger = init_logger(__name__) logger = init_logger(__name__)
class ClassificationMixin(OpenAIServing): ClassificationServeContext = ServeContext[ClassificationRequest]
chat_template: str | None
chat_template_content_format: ChatTemplateContentFormatOption
trust_request_chat_template: bool class ServingClassification(OpenAIServing):
request_id_prefix = "classify"
def __init__(
self,
engine_client: EngineClient,
models: OpenAIServingModels,
*,
request_logger: RequestLogger | None,
chat_template: str | None = None,
chat_template_content_format: ChatTemplateContentFormatOption = "auto",
trust_request_chat_template: bool = False,
log_error_stack: bool = False,
) -> None:
super().__init__(
engine_client=engine_client,
models=models,
request_logger=request_logger,
log_error_stack=log_error_stack,
)
self.chat_template = chat_template
self.chat_template_content_format: Final = chat_template_content_format
self.trust_request_chat_template = trust_request_chat_template
async def _preprocess( async def _preprocess(
self, self,
ctx: ServeContext, ctx: ClassificationServeContext,
) -> ErrorResponse | None: ) -> ErrorResponse | None:
""" """
Process classification inputs: tokenize text, resolve adapters, Process classification inputs: tokenize text, resolve adapters,
and prepare model-specific inputs. and prepare model-specific inputs.
""" """
ctx = cast(ClassificationServeContext, ctx)
try: try:
request_obj = ctx.request ctx.lora_request = self._maybe_get_adapters(ctx.request)
if isinstance(request_obj, ClassificationChatRequest): if isinstance(ctx.request, ClassificationChatRequest):
chat_request = request_obj error_check_ret = self._validate_chat_template(
messages = chat_request.messages request_chat_template=ctx.request.chat_template,
trust_request_chat_template = getattr( chat_template_kwargs=ctx.request.chat_template_kwargs,
self, trust_request_chat_template=self.trust_request_chat_template,
"trust_request_chat_template",
False,
)
ret = self._validate_chat_template(
request_chat_template=chat_request.chat_template,
chat_template_kwargs=chat_request.chat_template_kwargs,
trust_request_chat_template=trust_request_chat_template,
) )
if ret: if error_check_ret:
return ret return error_check_ret
_, engine_prompts = await self._preprocess_chat( _, engine_prompts = await self._preprocess_chat(
cast(ChatCompletionRequest, chat_request), ctx.request,
self.renderer, self.renderer,
messages, ctx.request.messages,
chat_template=( chat_template=ctx.request.chat_template or self.chat_template,
chat_request.chat_template chat_template_content_format=self.chat_template_content_format,
or getattr(self, "chat_template", None) add_generation_prompt=ctx.request.add_generation_prompt,
), continue_final_message=ctx.request.continue_final_message,
chat_template_content_format=cast( add_special_tokens=ctx.request.add_special_tokens,
ChatTemplateContentFormatOption,
getattr(self, "chat_template_content_format", "auto"),
),
add_generation_prompt=chat_request.add_generation_prompt,
continue_final_message=chat_request.continue_final_message,
add_special_tokens=chat_request.add_special_tokens,
) )
ctx.engine_prompts = engine_prompts ctx.engine_prompts = engine_prompts
elif isinstance(request_obj, ClassificationCompletionRequest): elif isinstance(ctx.request, ClassificationCompletionRequest):
completion_request = request_obj input_data = ctx.request.input
input_data = completion_request.input
if input_data in (None, ""): if input_data in (None, ""):
return self.create_error_response( return self.create_error_response(
"Input or messages must be provided", "Input or messages must be provided",
...@@ -106,13 +104,10 @@ class ClassificationMixin(OpenAIServing): ...@@ -106,13 +104,10 @@ class ClassificationMixin(OpenAIServing):
prompt_input = cast(str | list[str], input_data) prompt_input = cast(str | list[str], input_data)
ctx.engine_prompts = await renderer.render_prompt( ctx.engine_prompts = await renderer.render_prompt(
prompt_or_prompts=prompt_input, prompt_or_prompts=prompt_input,
config=self._build_render_config(completion_request), config=self._build_render_config(ctx.request),
) )
else: else:
return self.create_error_response( return self.create_error_response("Invalid classification request type")
"Invalid classification request type",
status_code=HTTPStatus.BAD_REQUEST,
)
return None return None
...@@ -122,13 +117,14 @@ class ClassificationMixin(OpenAIServing): ...@@ -122,13 +117,14 @@ class ClassificationMixin(OpenAIServing):
def _build_response( def _build_response(
self, self,
ctx: ServeContext, ctx: ClassificationServeContext,
) -> ClassificationResponse | ErrorResponse: ) -> ClassificationResponse | ErrorResponse:
""" """
Convert model outputs to a formatted classification response Convert model outputs to a formatted classification response
with probabilities and labels. with probabilities and labels.
""" """
ctx = cast(ClassificationServeContext, ctx) id2label = getattr(self.model_config.hf_config, "id2label", {})
items: list[ClassificationData] = [] items: list[ClassificationData] = []
num_prompt_tokens = 0 num_prompt_tokens = 0
...@@ -139,9 +135,7 @@ class ClassificationMixin(OpenAIServing): ...@@ -139,9 +135,7 @@ class ClassificationMixin(OpenAIServing):
probs = classify_res.probs probs = classify_res.probs
predicted_index = int(np.argmax(probs)) predicted_index = int(np.argmax(probs))
label = getattr(self.model_config.hf_config, "id2label", {}).get( label = id2label.get(predicted_index)
predicted_index
)
item = ClassificationData( item = ClassificationData(
index=idx, index=idx,
...@@ -174,32 +168,6 @@ class ClassificationMixin(OpenAIServing): ...@@ -174,32 +168,6 @@ class ClassificationMixin(OpenAIServing):
add_special_tokens=request.add_special_tokens, add_special_tokens=request.add_special_tokens,
) )
class ServingClassification(ClassificationMixin):
request_id_prefix = "classify"
def __init__(
self,
engine_client: EngineClient,
models: OpenAIServingModels,
*,
request_logger: RequestLogger | None,
chat_template: str | None = None,
chat_template_content_format: ChatTemplateContentFormatOption = "auto",
trust_request_chat_template: bool = False,
log_error_stack: bool = False,
) -> None:
super().__init__(
engine_client=engine_client,
models=models,
request_logger=request_logger,
log_error_stack=log_error_stack,
)
self.chat_template = chat_template
self.chat_template_content_format = chat_template_content_format
self.trust_request_chat_template = trust_request_chat_template
async def create_classify( async def create_classify(
self, self,
request: ClassificationRequest, request: ClassificationRequest,
...@@ -215,11 +183,11 @@ class ServingClassification(ClassificationMixin): ...@@ -215,11 +183,11 @@ class ServingClassification(ClassificationMixin):
request_id=request_id, request_id=request_id,
) )
return await super().handle(ctx) # type: ignore return await self.handle(ctx) # type: ignore[return-value]
def _create_pooling_params( def _create_pooling_params(
self, self,
ctx: ServeContext[ClassificationRequest], ctx: ClassificationServeContext,
) -> PoolingParams | ErrorResponse: ) -> PoolingParams | ErrorResponse:
pooling_params = super()._create_pooling_params(ctx) pooling_params = super()._create_pooling_params(ctx)
if isinstance(pooling_params, ErrorResponse): if isinstance(pooling_params, ErrorResponse):
......
...@@ -6,21 +6,13 @@ from typing import Any, Final, cast ...@@ -6,21 +6,13 @@ from typing import Any, Final, cast
import torch import torch
from fastapi import Request from fastapi import Request
from fastapi.responses import Response from typing_extensions import assert_never
from typing_extensions import assert_never, override
from vllm.engine.protocol import EngineClient from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.engine.protocol import ( from vllm.entrypoints.openai.engine.protocol import ErrorResponse, UsageInfo
ErrorResponse, from vllm.entrypoints.openai.engine.serving import OpenAIServing, ServeContext
UsageInfo,
)
from vllm.entrypoints.openai.engine.serving import (
EmbeddingServeContext,
OpenAIServing,
ServeContext,
)
from vllm.entrypoints.openai.models.serving import OpenAIServingModels from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.entrypoints.pooling.embed.protocol import ( from vllm.entrypoints.pooling.embed.protocol import (
EmbeddingBytesResponse, EmbeddingBytesResponse,
...@@ -33,19 +25,11 @@ from vllm.entrypoints.pooling.embed.protocol import ( ...@@ -33,19 +25,11 @@ from vllm.entrypoints.pooling.embed.protocol import (
from vllm.entrypoints.renderer import RenderConfig from vllm.entrypoints.renderer import RenderConfig
from vllm.inputs.data import TokensPrompt from vllm.inputs.data import TokensPrompt
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.outputs import ( from vllm.outputs import PoolingOutput, PoolingRequestOutput
EmbeddingRequestOutput,
PoolingOutput,
PoolingRequestOutput,
RequestOutput,
)
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.utils.async_utils import merge_async_iterators from vllm.utils.async_utils import merge_async_iterators
from vllm.utils.collection_utils import chunk_list from vllm.utils.collection_utils import chunk_list
from vllm.utils.serial_utils import ( from vllm.utils.serial_utils import (
EmbedDType,
EncodingFormat,
Endianness,
encode_pooling_bytes, encode_pooling_bytes,
encode_pooling_output, encode_pooling_output,
) )
...@@ -53,9 +37,33 @@ from vllm.utils.serial_utils import ( ...@@ -53,9 +37,33 @@ from vllm.utils.serial_utils import (
logger = init_logger(__name__) logger = init_logger(__name__)
class EmbeddingMixin(OpenAIServing): EmbeddingServeContext = ServeContext[EmbeddingRequest]
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
class OpenAIServingEmbedding(OpenAIServing):
request_id_prefix = "embd"
def __init__(
self,
engine_client: EngineClient,
models: OpenAIServingModels,
*,
request_logger: RequestLogger | None,
chat_template: str | None,
chat_template_content_format: ChatTemplateContentFormatOption,
trust_request_chat_template: bool = False,
log_error_stack: bool = False,
) -> None:
super().__init__(
engine_client=engine_client,
models=models,
request_logger=request_logger,
log_error_stack=log_error_stack,
)
self.chat_template = chat_template
self.chat_template_content_format: Final = chat_template_content_format
self.trust_request_chat_template = trust_request_chat_template
pooler_config = self.model_config.pooler_config pooler_config = self.model_config.pooler_config
...@@ -69,32 +77,41 @@ class EmbeddingMixin(OpenAIServing): ...@@ -69,32 +77,41 @@ class EmbeddingMixin(OpenAIServing):
else None else None
) )
@override
async def _preprocess( async def _preprocess(
self, self,
ctx: ServeContext, ctx: EmbeddingServeContext,
) -> ErrorResponse | None: ) -> ErrorResponse | None:
ctx = cast(EmbeddingServeContext, ctx)
try: try:
ctx.lora_request = self._maybe_get_adapters(ctx.request) ctx.lora_request = self._maybe_get_adapters(ctx.request)
if isinstance(ctx.request, EmbeddingChatRequest): if isinstance(ctx.request, EmbeddingChatRequest):
error_check_ret = self._validate_chat_template(
request_chat_template=ctx.request.chat_template,
chat_template_kwargs=ctx.request.chat_template_kwargs,
trust_request_chat_template=self.trust_request_chat_template,
)
if error_check_ret is not None:
return error_check_ret
_, ctx.engine_prompts = await self._preprocess_chat( _, ctx.engine_prompts = await self._preprocess_chat(
ctx.request, ctx.request,
self.renderer, self.renderer,
ctx.request.messages, ctx.request.messages,
chat_template=ctx.request.chat_template or ctx.chat_template, chat_template=ctx.request.chat_template or self.chat_template,
chat_template_content_format=ctx.chat_template_content_format, chat_template_content_format=self.chat_template_content_format,
add_generation_prompt=ctx.request.add_generation_prompt, add_generation_prompt=ctx.request.add_generation_prompt,
continue_final_message=ctx.request.continue_final_message, continue_final_message=ctx.request.continue_final_message,
add_special_tokens=ctx.request.add_special_tokens, add_special_tokens=ctx.request.add_special_tokens,
) )
else: elif isinstance(ctx.request, EmbeddingCompletionRequest):
renderer = self._get_completion_renderer() renderer = self._get_completion_renderer()
ctx.engine_prompts = await renderer.render_prompt( ctx.engine_prompts = await renderer.render_prompt(
prompt_or_prompts=ctx.request.input, prompt_or_prompts=ctx.request.input,
config=self._build_render_config(ctx.request), config=self._build_render_config(ctx.request),
) )
else:
return self.create_error_response("Invalid classification request type")
return None return None
except (ValueError, TypeError) as e: except (ValueError, TypeError) as e:
logger.exception("Error in preprocessing prompt inputs") logger.exception("Error in preprocessing prompt inputs")
...@@ -113,16 +130,15 @@ class EmbeddingMixin(OpenAIServing): ...@@ -113,16 +130,15 @@ class EmbeddingMixin(OpenAIServing):
add_special_tokens=request.add_special_tokens, add_special_tokens=request.add_special_tokens,
) )
@override
def _build_response( def _build_response(
self, self,
ctx: ServeContext, ctx: EmbeddingServeContext,
) -> EmbeddingResponse | Response | ErrorResponse: ) -> EmbeddingResponse | EmbeddingBytesResponse | ErrorResponse:
final_res_batch_checked = cast(list[PoolingRequestOutput], ctx.final_res_batch) final_res_batch_checked = ctx.final_res_batch
encoding_format: EncodingFormat = ctx.request.encoding_format encoding_format = ctx.request.encoding_format
embed_dtype: EmbedDType = ctx.request.embed_dtype embed_dtype = ctx.request.embed_dtype
endianness: Endianness = ctx.request.endianness endianness = ctx.request.endianness
def encode_float_base64(): def encode_float_base64():
items: list[EmbeddingResponseData] = [] items: list[EmbeddingResponseData] = []
...@@ -203,8 +219,8 @@ class EmbeddingMixin(OpenAIServing): ...@@ -203,8 +219,8 @@ class EmbeddingMixin(OpenAIServing):
self, self,
ctx: EmbeddingServeContext, ctx: EmbeddingServeContext,
token_ids: list[int], token_ids: list[int],
pooling_params, pooling_params: PoolingParams,
trace_headers, trace_headers: Mapping[str, str] | None,
prompt_idx: int, prompt_idx: int,
) -> list[AsyncGenerator[PoolingRequestOutput, None]]: ) -> list[AsyncGenerator[PoolingRequestOutput, None]]:
"""Process a single prompt using chunked processing.""" """Process a single prompt using chunked processing."""
...@@ -246,7 +262,7 @@ class EmbeddingMixin(OpenAIServing): ...@@ -246,7 +262,7 @@ class EmbeddingMixin(OpenAIServing):
def _validate_input( def _validate_input(
self, self,
request, request: object,
input_ids: list[int], input_ids: list[int],
input_text: str, input_text: str,
) -> TokensPrompt: ) -> TokensPrompt:
...@@ -326,7 +342,7 @@ class EmbeddingMixin(OpenAIServing): ...@@ -326,7 +342,7 @@ class EmbeddingMixin(OpenAIServing):
pooling_params: PoolingParams, pooling_params: PoolingParams,
trace_headers: Mapping[str, str] | None, trace_headers: Mapping[str, str] | None,
prompt_index: int, prompt_index: int,
) -> AsyncGenerator[RequestOutput | PoolingRequestOutput, None]: ) -> AsyncGenerator[PoolingRequestOutput, None]:
"""Create a generator for a single prompt using standard processing.""" """Create a generator for a single prompt using standard processing."""
request_id_item = f"{ctx.request_id}-{prompt_index}" request_id_item = f"{ctx.request_id}-{prompt_index}"
...@@ -347,7 +363,6 @@ class EmbeddingMixin(OpenAIServing): ...@@ -347,7 +363,6 @@ class EmbeddingMixin(OpenAIServing):
priority=getattr(ctx.request, "priority", 0), priority=getattr(ctx.request, "priority", 0),
) )
@override
async def _prepare_generators( async def _prepare_generators(
self, self,
ctx: ServeContext, ctx: ServeContext,
...@@ -363,9 +378,7 @@ class EmbeddingMixin(OpenAIServing): ...@@ -363,9 +378,7 @@ class EmbeddingMixin(OpenAIServing):
return await super()._prepare_generators(ctx) return await super()._prepare_generators(ctx)
# Custom logic for chunked processing # Custom logic for chunked processing
generators: list[ generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
AsyncGenerator[RequestOutput | PoolingRequestOutput, None]
] = []
try: try:
trace_headers = ( trace_headers = (
...@@ -419,10 +432,9 @@ class EmbeddingMixin(OpenAIServing): ...@@ -419,10 +432,9 @@ class EmbeddingMixin(OpenAIServing):
# TODO: Use a vllm-specific Validation Error # TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e)) return self.create_error_response(str(e))
@override
async def _collect_batch( async def _collect_batch(
self, self,
ctx: ServeContext, ctx: EmbeddingServeContext,
) -> ErrorResponse | None: ) -> ErrorResponse | None:
"""Collect and aggregate batch results """Collect and aggregate batch results
with support for chunked processing. with support for chunked processing.
...@@ -431,7 +443,6 @@ class EmbeddingMixin(OpenAIServing): ...@@ -431,7 +443,6 @@ class EmbeddingMixin(OpenAIServing):
minimize memory usage. minimize memory usage.
For regular requests, collects results normally. For regular requests, collects results normally.
""" """
ctx = cast(EmbeddingServeContext, ctx)
try: try:
if ctx.engine_prompts is None: if ctx.engine_prompts is None:
return self.create_error_response("Engine prompts not available") return self.create_error_response("Engine prompts not available")
...@@ -527,12 +538,10 @@ class EmbeddingMixin(OpenAIServing): ...@@ -527,12 +538,10 @@ class EmbeddingMixin(OpenAIServing):
except (ValueError, IndexError): except (ValueError, IndexError):
prompt_idx = result_idx # Fallback to result_idx prompt_idx = result_idx # Fallback to result_idx
short_prompts_results[prompt_idx] = cast( short_prompts_results[prompt_idx] = result
PoolingRequestOutput, result
)
# Finalize aggregated results # Finalize aggregated results
final_res_batch: list[PoolingRequestOutput | EmbeddingRequestOutput] = [] final_res_batch: list[PoolingRequestOutput] = []
num_prompts = len(ctx.engine_prompts) num_prompts = len(ctx.engine_prompts)
for prompt_idx in range(num_prompts): for prompt_idx in range(num_prompts):
...@@ -580,49 +589,19 @@ class EmbeddingMixin(OpenAIServing): ...@@ -580,49 +589,19 @@ class EmbeddingMixin(OpenAIServing):
f"Failed to aggregate chunks for prompt {prompt_idx}" f"Failed to aggregate chunks for prompt {prompt_idx}"
) )
elif prompt_idx in short_prompts_results: elif prompt_idx in short_prompts_results:
final_res_batch.append( final_res_batch.append(short_prompts_results[prompt_idx])
cast(PoolingRequestOutput, short_prompts_results[prompt_idx])
)
else: else:
return self.create_error_response( return self.create_error_response(
f"Result not found for prompt {prompt_idx}" f"Result not found for prompt {prompt_idx}"
) )
ctx.final_res_batch = cast( ctx.final_res_batch = final_res_batch
list[RequestOutput | PoolingRequestOutput], final_res_batch
)
return None return None
except Exception as e: except Exception as e:
return self.create_error_response(str(e)) return self.create_error_response(str(e))
class OpenAIServingEmbedding(EmbeddingMixin):
request_id_prefix = "embd"
def __init__(
self,
engine_client: EngineClient,
models: OpenAIServingModels,
*,
request_logger: RequestLogger | None,
chat_template: str | None,
chat_template_content_format: ChatTemplateContentFormatOption,
trust_request_chat_template: bool = False,
log_error_stack: bool = False,
) -> None:
super().__init__(
engine_client=engine_client,
models=models,
request_logger=request_logger,
log_error_stack=log_error_stack,
)
self.chat_template = chat_template
self.chat_template_content_format: Final = chat_template_content_format
self.trust_request_chat_template = trust_request_chat_template
async def create_embedding( async def create_embedding(
self, self,
request: EmbeddingRequest, request: EmbeddingRequest,
...@@ -645,16 +624,13 @@ class OpenAIServingEmbedding(EmbeddingMixin): ...@@ -645,16 +624,13 @@ class OpenAIServingEmbedding(EmbeddingMixin):
raw_request=raw_request, raw_request=raw_request,
model_name=model_name, model_name=model_name,
request_id=request_id, request_id=request_id,
chat_template=self.chat_template,
chat_template_content_format=self.chat_template_content_format,
) )
return await super().handle(ctx) # type: ignore return await self.handle(ctx) # type: ignore[return-value]
@override
def _create_pooling_params( def _create_pooling_params(
self, self,
ctx: ServeContext[EmbeddingRequest], ctx: EmbeddingServeContext,
) -> PoolingParams | ErrorResponse: ) -> PoolingParams | ErrorResponse:
pooling_params = super()._create_pooling_params(ctx) pooling_params = super()._create_pooling_params(ctx)
if isinstance(pooling_params, ErrorResponse): if isinstance(pooling_params, ErrorResponse):
...@@ -666,17 +642,3 @@ class OpenAIServingEmbedding(EmbeddingMixin): ...@@ -666,17 +642,3 @@ class OpenAIServingEmbedding(EmbeddingMixin):
return self.create_error_response(str(e)) return self.create_error_response(str(e))
return pooling_params return pooling_params
async def _preprocess(
self,
ctx: ServeContext,
) -> ErrorResponse | None:
if isinstance(ctx.request, EmbeddingChatRequest):
error_check_ret = self._validate_chat_template(
request_chat_template=ctx.request.chat_template,
chat_template_kwargs=ctx.request.chat_template_kwargs,
trust_request_chat_template=self.trust_request_chat_template,
)
if error_check_ret is not None:
return error_check_ret
return await super()._preprocess(ctx)
...@@ -17,8 +17,10 @@ from starlette.background import BackgroundTask, BackgroundTasks ...@@ -17,8 +17,10 @@ from starlette.background import BackgroundTask, BackgroundTasks
from vllm import envs from vllm import envs
from vllm.engine.arg_utils import EngineArgs from vllm.engine.arg_utils import EngineArgs
from vllm.inputs import EmbedsPrompt, TokensPrompt
from vllm.logger import current_formatter_type, init_logger from vllm.logger import current_formatter_type, init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import length_from_prompt_token_ids_or_embeds
from vllm.utils.argparse_utils import FlexibleArgumentParser from vllm.utils.argparse_utils import FlexibleArgumentParser
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -32,11 +34,15 @@ if TYPE_CHECKING: ...@@ -32,11 +34,15 @@ if TYPE_CHECKING:
StreamOptions, StreamOptions,
) )
from vllm.entrypoints.openai.models.protocol import LoRAModulePath from vllm.entrypoints.openai.models.protocol import LoRAModulePath
from vllm.entrypoints.openai.responses.protocol import (
ResponsesRequest,
)
else: else:
ChatCompletionRequest = object ChatCompletionRequest = object
CompletionRequest = object CompletionRequest = object
StreamOptions = object StreamOptions = object
LoRAModulePath = object LoRAModulePath = object
ResponsesRequest = object
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -211,11 +217,26 @@ def _validate_truncation_size( ...@@ -211,11 +217,26 @@ def _validate_truncation_size(
def get_max_tokens( def get_max_tokens(
max_model_len: int, max_model_len: int,
request: "ChatCompletionRequest | CompletionRequest", request: "CompletionRequest | ChatCompletionRequest | ResponsesRequest",
input_length: int, prompt: TokensPrompt | EmbedsPrompt,
default_sampling_params: dict, default_sampling_params: dict,
) -> int: ) -> int:
max_tokens = getattr(request, "max_completion_tokens", None) or request.max_tokens # NOTE: Avoid isinstance() for better efficiency
max_tokens: int | None = None
if max_tokens is None:
# ChatCompletionRequest
max_tokens = getattr(request, "max_completion_tokens", None)
if max_tokens is None:
# ResponsesRequest
max_tokens = getattr(request, "max_output_tokens", None)
if max_tokens is None:
# CompletionRequest (also a fallback for ChatCompletionRequest)
max_tokens = getattr(request, "max_tokens", None)
input_length = length_from_prompt_token_ids_or_embeds(
prompt.get("prompt_token_ids"), # type: ignore[arg-type]
prompt.get("prompt_embeds"), # type: ignore[arg-type]
)
default_max_tokens = max_model_len - input_length default_max_tokens = max_model_len - input_length
max_output_tokens = current_platform.get_max_output_tokens(input_length) max_output_tokens = current_platform.get_max_output_tokens(input_length)
......
...@@ -87,6 +87,7 @@ if TYPE_CHECKING: ...@@ -87,6 +87,7 @@ if TYPE_CHECKING:
VLLM_HTTP_TIMEOUT_KEEP_ALIVE: int = 5 # seconds VLLM_HTTP_TIMEOUT_KEEP_ALIVE: int = 5 # seconds
VLLM_PLUGINS: list[str] | None = None VLLM_PLUGINS: list[str] | None = None
VLLM_LORA_RESOLVER_CACHE_DIR: str | None = None VLLM_LORA_RESOLVER_CACHE_DIR: str | None = None
VLLM_LORA_RESOLVER_HF_REPO_LIST: str | None = None
# Deprecated env variables for profiling, kept for backward compatibility # Deprecated env variables for profiling, kept for backward compatibility
# See also vllm/config/profiler.py and `--profiler-config` argument # See also vllm/config/profiler.py and `--profiler-config` argument
VLLM_TORCH_CUDA_PROFILE: str | None = None VLLM_TORCH_CUDA_PROFILE: str | None = None
...@@ -288,16 +289,11 @@ def use_aot_compile() -> bool: ...@@ -288,16 +289,11 @@ def use_aot_compile() -> bool:
from vllm.model_executor.layers.batch_invariant import ( from vllm.model_executor.layers.batch_invariant import (
vllm_is_batch_invariant, vllm_is_batch_invariant,
) )
from vllm.platforms import current_platform
from vllm.utils.torch_utils import is_torch_equal_or_newer from vllm.utils.torch_utils import is_torch_equal_or_newer
default_value = ( default_value = (
"1" "1"
if is_torch_equal_or_newer("2.10.0.dev") if is_torch_equal_or_newer("2.10.0.dev") and not disable_compile_cache()
and not disable_compile_cache()
# Disabling AOT_COMPILE for CPU
# See: https://github.com/vllm-project/vllm/issues/32033
and not current_platform.is_cpu()
else "0" else "0"
) )
...@@ -782,6 +778,7 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -782,6 +778,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
), ),
# Backend for Video IO # Backend for Video IO
# - "opencv": Default backend that uses OpenCV stream buffered backend. # - "opencv": Default backend that uses OpenCV stream buffered backend.
# - "identity": Returns raw video bytes for model processor to handle.
# #
# Custom backend implementations can be registered # Custom backend implementations can be registered
# via `@VIDEO_LOADER_REGISTRY.register("my_custom_video_loader")` and # via `@VIDEO_LOADER_REGISTRY.register("my_custom_video_loader")` and
...@@ -873,6 +870,13 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -873,6 +870,13 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_LORA_RESOLVER_CACHE_DIR": lambda: os.getenv( "VLLM_LORA_RESOLVER_CACHE_DIR": lambda: os.getenv(
"VLLM_LORA_RESOLVER_CACHE_DIR", None "VLLM_LORA_RESOLVER_CACHE_DIR", None
), ),
# A remote HF repo(s) containing one or more LoRA adapters, which
# may be downloaded and leveraged as needed. Only works if plugins
# are enabled and VLLM_ALLOW_RUNTIME_LORA_UPDATING is enabled.
# Values should be comma separated.
"VLLM_LORA_RESOLVER_HF_REPO_LIST": lambda: os.getenv(
"VLLM_LORA_RESOLVER_HF_REPO_LIST", None
),
# Enables torch CUDA profiling if set to 1. # Enables torch CUDA profiling if set to 1.
# Deprecated, see profiler_config. # Deprecated, see profiler_config.
"VLLM_TORCH_CUDA_PROFILE": lambda: os.getenv("VLLM_TORCH_CUDA_PROFILE"), "VLLM_TORCH_CUDA_PROFILE": lambda: os.getenv("VLLM_TORCH_CUDA_PROFILE"),
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.logging_utils.access_log_filter import (
UvicornAccessLogFilter,
create_uvicorn_log_config,
)
from vllm.logging_utils.formatter import ColoredFormatter, NewLineFormatter from vllm.logging_utils.formatter import ColoredFormatter, NewLineFormatter
from vllm.logging_utils.lazy import lazy from vllm.logging_utils.lazy import lazy
from vllm.logging_utils.log_time import logtime from vllm.logging_utils.log_time import logtime
...@@ -8,6 +12,8 @@ from vllm.logging_utils.log_time import logtime ...@@ -8,6 +12,8 @@ from vllm.logging_utils.log_time import logtime
__all__ = [ __all__ = [
"NewLineFormatter", "NewLineFormatter",
"ColoredFormatter", "ColoredFormatter",
"UvicornAccessLogFilter",
"create_uvicorn_log_config",
"lazy", "lazy",
"logtime", "logtime",
] ]
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Access log filter for uvicorn to exclude specific endpoints from logging.
This module provides a logging filter that can be used to suppress access logs
for specific endpoints (e.g., /health, /metrics) to reduce log noise in
production environments.
"""
import logging
from urllib.parse import urlparse
class UvicornAccessLogFilter(logging.Filter):
"""
A logging filter that excludes access logs for specified endpoint paths.
This filter is designed to work with uvicorn's access logger. It checks
the log record's arguments for the request path and filters out records
matching the excluded paths.
Uvicorn access log format:
'%s - "%s %s HTTP/%s" %d'
(client_addr, method, path, http_version, status_code)
Example:
127.0.0.1:12345 - "GET /health HTTP/1.1" 200
Args:
excluded_paths: A list of URL paths to exclude from logging.
Paths are matched exactly.
Example: ["/health", "/metrics"]
"""
def __init__(self, excluded_paths: list[str] | None = None):
super().__init__()
self.excluded_paths = set(excluded_paths or [])
def filter(self, record: logging.LogRecord) -> bool:
"""
Determine if the log record should be logged.
Args:
record: The log record to evaluate.
Returns:
True if the record should be logged, False otherwise.
"""
if not self.excluded_paths:
return True
# This filter is specific to uvicorn's access logs.
if record.name != "uvicorn.access":
return True
# The path is the 3rd argument in the log record's args tuple.
# See uvicorn's access logging implementation for details.
log_args = record.args
if isinstance(log_args, tuple) and len(log_args) >= 3:
path_with_query = log_args[2]
# Get path component without query string.
if isinstance(path_with_query, str):
path = urlparse(path_with_query).path
if path in self.excluded_paths:
return False
return True
def create_uvicorn_log_config(
excluded_paths: list[str] | None = None,
log_level: str = "info",
) -> dict:
"""
Create a uvicorn logging configuration with access log filtering.
This function generates a logging configuration dictionary that can be
passed to uvicorn's `log_config` parameter. It sets up the access log
filter to exclude specified paths.
Args:
excluded_paths: List of URL paths to exclude from access logs.
log_level: The log level for uvicorn loggers.
Returns:
A dictionary containing the logging configuration.
Example:
>>> config = create_uvicorn_log_config(["/health", "/metrics"])
>>> uvicorn.run(app, log_config=config)
"""
config = {
"version": 1,
"disable_existing_loggers": False,
"filters": {
"access_log_filter": {
"()": UvicornAccessLogFilter,
"excluded_paths": excluded_paths or [],
},
},
"formatters": {
"default": {
"()": "uvicorn.logging.DefaultFormatter",
"fmt": "%(levelprefix)s %(message)s",
"use_colors": None,
},
"access": {
"()": "uvicorn.logging.AccessFormatter",
"fmt": '%(levelprefix)s %(client_addr)s - "%(request_line)s" %(status_code)s', # noqa: E501
},
},
"handlers": {
"default": {
"formatter": "default",
"class": "logging.StreamHandler",
"stream": "ext://sys.stderr",
},
"access": {
"formatter": "access",
"class": "logging.StreamHandler",
"stream": "ext://sys.stdout",
"filters": ["access_log_filter"],
},
},
"loggers": {
"uvicorn": {
"handlers": ["default"],
"level": log_level.upper(),
"propagate": False,
},
"uvicorn.error": {
"level": log_level.upper(),
"handlers": ["default"],
"propagate": False,
},
"uvicorn.access": {
"handlers": ["access"],
"level": log_level.upper(),
"propagate": False,
},
},
}
return config
...@@ -148,7 +148,8 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -148,7 +148,8 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
@staticmethod @staticmethod
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
return True # NOTE(rob): discovered an IMA with this combination. Needs investigation.
return not moe_parallel_config.use_fi_all2allv_kernels
def supports_chunking(self) -> bool: def supports_chunking(self) -> bool:
return True return True
......
...@@ -533,7 +533,13 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -533,7 +533,13 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_map: torch.Tensor | None, expert_map: torch.Tensor | None,
apply_router_weight_on_input: bool, apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig, quant_config: FusedMoEQuantConfig,
defer_input_quant: bool = False,
) -> mk.PrepareResultType: ) -> mk.PrepareResultType:
if defer_input_quant:
raise NotImplementedError(
f"{self.__class__.__name__} does not support defer_input_quant=True. "
"Please select an MoE kernel that accepts quantized inputs."
)
assert a1.dim() == 2 assert a1.dim() == 2
assert topk_ids.dim() == 2 assert topk_ids.dim() == 2
assert topk_ids.size(0) == a1.size(0) assert topk_ids.size(0) == a1.size(0)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment