Commit d76fc11e authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.15.0rc1' into v0.15.0rc1-dev

parents 38166ec4 58996f35
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from http import HTTPStatus from http import HTTPStatus
from typing import cast from typing import Final, cast
import jinja2 import jinja2
import numpy as np import numpy as np
...@@ -11,18 +11,8 @@ from fastapi import Request ...@@ -11,18 +11,8 @@ from fastapi import Request
from vllm.engine.protocol import EngineClient from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.chat_completion.protocol import ( from vllm.entrypoints.openai.engine.protocol import ErrorResponse, UsageInfo
ChatCompletionRequest, from vllm.entrypoints.openai.engine.serving import OpenAIServing, ServeContext
)
from vllm.entrypoints.openai.engine.protocol import (
ErrorResponse,
UsageInfo,
)
from vllm.entrypoints.openai.engine.serving import (
ClassificationServeContext,
OpenAIServing,
ServeContext,
)
from vllm.entrypoints.openai.models.serving import OpenAIServingModels from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.entrypoints.pooling.classify.protocol import ( from vllm.entrypoints.pooling.classify.protocol import (
ClassificationChatRequest, ClassificationChatRequest,
...@@ -39,60 +29,68 @@ from vllm.pooling_params import PoolingParams ...@@ -39,60 +29,68 @@ from vllm.pooling_params import PoolingParams
logger = init_logger(__name__) logger = init_logger(__name__)
class ClassificationMixin(OpenAIServing): ClassificationServeContext = ServeContext[ClassificationRequest]
chat_template: str | None
chat_template_content_format: ChatTemplateContentFormatOption
trust_request_chat_template: bool class ServingClassification(OpenAIServing):
request_id_prefix = "classify"
def __init__(
self,
engine_client: EngineClient,
models: OpenAIServingModels,
*,
request_logger: RequestLogger | None,
chat_template: str | None = None,
chat_template_content_format: ChatTemplateContentFormatOption = "auto",
trust_request_chat_template: bool = False,
log_error_stack: bool = False,
) -> None:
super().__init__(
engine_client=engine_client,
models=models,
request_logger=request_logger,
log_error_stack=log_error_stack,
)
self.chat_template = chat_template
self.chat_template_content_format: Final = chat_template_content_format
self.trust_request_chat_template = trust_request_chat_template
async def _preprocess( async def _preprocess(
self, self,
ctx: ServeContext, ctx: ClassificationServeContext,
) -> ErrorResponse | None: ) -> ErrorResponse | None:
""" """
Process classification inputs: tokenize text, resolve adapters, Process classification inputs: tokenize text, resolve adapters,
and prepare model-specific inputs. and prepare model-specific inputs.
""" """
ctx = cast(ClassificationServeContext, ctx)
try: try:
request_obj = ctx.request ctx.lora_request = self._maybe_get_adapters(ctx.request)
if isinstance(request_obj, ClassificationChatRequest): if isinstance(ctx.request, ClassificationChatRequest):
chat_request = request_obj error_check_ret = self._validate_chat_template(
messages = chat_request.messages request_chat_template=ctx.request.chat_template,
trust_request_chat_template = getattr( chat_template_kwargs=ctx.request.chat_template_kwargs,
self, trust_request_chat_template=self.trust_request_chat_template,
"trust_request_chat_template",
False,
)
ret = self._validate_chat_template(
request_chat_template=chat_request.chat_template,
chat_template_kwargs=chat_request.chat_template_kwargs,
trust_request_chat_template=trust_request_chat_template,
) )
if ret: if error_check_ret:
return ret return error_check_ret
_, engine_prompts = await self._preprocess_chat( _, engine_prompts = await self._preprocess_chat(
cast(ChatCompletionRequest, chat_request), ctx.request,
self.renderer, self.renderer,
messages, ctx.request.messages,
chat_template=( chat_template=ctx.request.chat_template or self.chat_template,
chat_request.chat_template chat_template_content_format=self.chat_template_content_format,
or getattr(self, "chat_template", None) add_generation_prompt=ctx.request.add_generation_prompt,
), continue_final_message=ctx.request.continue_final_message,
chat_template_content_format=cast( add_special_tokens=ctx.request.add_special_tokens,
ChatTemplateContentFormatOption,
getattr(self, "chat_template_content_format", "auto"),
),
add_generation_prompt=chat_request.add_generation_prompt,
continue_final_message=chat_request.continue_final_message,
add_special_tokens=chat_request.add_special_tokens,
) )
ctx.engine_prompts = engine_prompts ctx.engine_prompts = engine_prompts
elif isinstance(request_obj, ClassificationCompletionRequest): elif isinstance(ctx.request, ClassificationCompletionRequest):
completion_request = request_obj input_data = ctx.request.input
input_data = completion_request.input
if input_data in (None, ""): if input_data in (None, ""):
return self.create_error_response( return self.create_error_response(
"Input or messages must be provided", "Input or messages must be provided",
...@@ -106,13 +104,10 @@ class ClassificationMixin(OpenAIServing): ...@@ -106,13 +104,10 @@ class ClassificationMixin(OpenAIServing):
prompt_input = cast(str | list[str], input_data) prompt_input = cast(str | list[str], input_data)
ctx.engine_prompts = await renderer.render_prompt( ctx.engine_prompts = await renderer.render_prompt(
prompt_or_prompts=prompt_input, prompt_or_prompts=prompt_input,
config=self._build_render_config(completion_request), config=self._build_render_config(ctx.request),
) )
else: else:
return self.create_error_response( return self.create_error_response("Invalid classification request type")
"Invalid classification request type",
status_code=HTTPStatus.BAD_REQUEST,
)
return None return None
...@@ -122,13 +117,14 @@ class ClassificationMixin(OpenAIServing): ...@@ -122,13 +117,14 @@ class ClassificationMixin(OpenAIServing):
def _build_response( def _build_response(
self, self,
ctx: ServeContext, ctx: ClassificationServeContext,
) -> ClassificationResponse | ErrorResponse: ) -> ClassificationResponse | ErrorResponse:
""" """
Convert model outputs to a formatted classification response Convert model outputs to a formatted classification response
with probabilities and labels. with probabilities and labels.
""" """
ctx = cast(ClassificationServeContext, ctx) id2label = getattr(self.model_config.hf_config, "id2label", {})
items: list[ClassificationData] = [] items: list[ClassificationData] = []
num_prompt_tokens = 0 num_prompt_tokens = 0
...@@ -139,9 +135,7 @@ class ClassificationMixin(OpenAIServing): ...@@ -139,9 +135,7 @@ class ClassificationMixin(OpenAIServing):
probs = classify_res.probs probs = classify_res.probs
predicted_index = int(np.argmax(probs)) predicted_index = int(np.argmax(probs))
label = getattr(self.model_config.hf_config, "id2label", {}).get( label = id2label.get(predicted_index)
predicted_index
)
item = ClassificationData( item = ClassificationData(
index=idx, index=idx,
...@@ -174,32 +168,6 @@ class ClassificationMixin(OpenAIServing): ...@@ -174,32 +168,6 @@ class ClassificationMixin(OpenAIServing):
add_special_tokens=request.add_special_tokens, add_special_tokens=request.add_special_tokens,
) )
class ServingClassification(ClassificationMixin):
request_id_prefix = "classify"
def __init__(
self,
engine_client: EngineClient,
models: OpenAIServingModels,
*,
request_logger: RequestLogger | None,
chat_template: str | None = None,
chat_template_content_format: ChatTemplateContentFormatOption = "auto",
trust_request_chat_template: bool = False,
log_error_stack: bool = False,
) -> None:
super().__init__(
engine_client=engine_client,
models=models,
request_logger=request_logger,
log_error_stack=log_error_stack,
)
self.chat_template = chat_template
self.chat_template_content_format = chat_template_content_format
self.trust_request_chat_template = trust_request_chat_template
async def create_classify( async def create_classify(
self, self,
request: ClassificationRequest, request: ClassificationRequest,
...@@ -215,11 +183,11 @@ class ServingClassification(ClassificationMixin): ...@@ -215,11 +183,11 @@ class ServingClassification(ClassificationMixin):
request_id=request_id, request_id=request_id,
) )
return await super().handle(ctx) # type: ignore return await self.handle(ctx) # type: ignore[return-value]
def _create_pooling_params( def _create_pooling_params(
self, self,
ctx: ServeContext[ClassificationRequest], ctx: ClassificationServeContext,
) -> PoolingParams | ErrorResponse: ) -> PoolingParams | ErrorResponse:
pooling_params = super()._create_pooling_params(ctx) pooling_params = super()._create_pooling_params(ctx)
if isinstance(pooling_params, ErrorResponse): if isinstance(pooling_params, ErrorResponse):
......
...@@ -6,21 +6,13 @@ from typing import Any, Final, cast ...@@ -6,21 +6,13 @@ from typing import Any, Final, cast
import torch import torch
from fastapi import Request from fastapi import Request
from fastapi.responses import Response from typing_extensions import assert_never
from typing_extensions import assert_never, override
from vllm.engine.protocol import EngineClient from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.engine.protocol import ( from vllm.entrypoints.openai.engine.protocol import ErrorResponse, UsageInfo
ErrorResponse, from vllm.entrypoints.openai.engine.serving import OpenAIServing, ServeContext
UsageInfo,
)
from vllm.entrypoints.openai.engine.serving import (
EmbeddingServeContext,
OpenAIServing,
ServeContext,
)
from vllm.entrypoints.openai.models.serving import OpenAIServingModels from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.entrypoints.pooling.embed.protocol import ( from vllm.entrypoints.pooling.embed.protocol import (
EmbeddingBytesResponse, EmbeddingBytesResponse,
...@@ -33,19 +25,11 @@ from vllm.entrypoints.pooling.embed.protocol import ( ...@@ -33,19 +25,11 @@ from vllm.entrypoints.pooling.embed.protocol import (
from vllm.entrypoints.renderer import RenderConfig from vllm.entrypoints.renderer import RenderConfig
from vllm.inputs.data import TokensPrompt from vllm.inputs.data import TokensPrompt
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.outputs import ( from vllm.outputs import PoolingOutput, PoolingRequestOutput
EmbeddingRequestOutput,
PoolingOutput,
PoolingRequestOutput,
RequestOutput,
)
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.utils.async_utils import merge_async_iterators from vllm.utils.async_utils import merge_async_iterators
from vllm.utils.collection_utils import chunk_list from vllm.utils.collection_utils import chunk_list
from vllm.utils.serial_utils import ( from vllm.utils.serial_utils import (
EmbedDType,
EncodingFormat,
Endianness,
encode_pooling_bytes, encode_pooling_bytes,
encode_pooling_output, encode_pooling_output,
) )
...@@ -53,9 +37,33 @@ from vllm.utils.serial_utils import ( ...@@ -53,9 +37,33 @@ from vllm.utils.serial_utils import (
logger = init_logger(__name__) logger = init_logger(__name__)
class EmbeddingMixin(OpenAIServing): EmbeddingServeContext = ServeContext[EmbeddingRequest]
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
class OpenAIServingEmbedding(OpenAIServing):
request_id_prefix = "embd"
def __init__(
self,
engine_client: EngineClient,
models: OpenAIServingModels,
*,
request_logger: RequestLogger | None,
chat_template: str | None,
chat_template_content_format: ChatTemplateContentFormatOption,
trust_request_chat_template: bool = False,
log_error_stack: bool = False,
) -> None:
super().__init__(
engine_client=engine_client,
models=models,
request_logger=request_logger,
log_error_stack=log_error_stack,
)
self.chat_template = chat_template
self.chat_template_content_format: Final = chat_template_content_format
self.trust_request_chat_template = trust_request_chat_template
pooler_config = self.model_config.pooler_config pooler_config = self.model_config.pooler_config
...@@ -69,32 +77,41 @@ class EmbeddingMixin(OpenAIServing): ...@@ -69,32 +77,41 @@ class EmbeddingMixin(OpenAIServing):
else None else None
) )
@override
async def _preprocess( async def _preprocess(
self, self,
ctx: ServeContext, ctx: EmbeddingServeContext,
) -> ErrorResponse | None: ) -> ErrorResponse | None:
ctx = cast(EmbeddingServeContext, ctx)
try: try:
ctx.lora_request = self._maybe_get_adapters(ctx.request) ctx.lora_request = self._maybe_get_adapters(ctx.request)
if isinstance(ctx.request, EmbeddingChatRequest): if isinstance(ctx.request, EmbeddingChatRequest):
error_check_ret = self._validate_chat_template(
request_chat_template=ctx.request.chat_template,
chat_template_kwargs=ctx.request.chat_template_kwargs,
trust_request_chat_template=self.trust_request_chat_template,
)
if error_check_ret is not None:
return error_check_ret
_, ctx.engine_prompts = await self._preprocess_chat( _, ctx.engine_prompts = await self._preprocess_chat(
ctx.request, ctx.request,
self.renderer, self.renderer,
ctx.request.messages, ctx.request.messages,
chat_template=ctx.request.chat_template or ctx.chat_template, chat_template=ctx.request.chat_template or self.chat_template,
chat_template_content_format=ctx.chat_template_content_format, chat_template_content_format=self.chat_template_content_format,
add_generation_prompt=ctx.request.add_generation_prompt, add_generation_prompt=ctx.request.add_generation_prompt,
continue_final_message=ctx.request.continue_final_message, continue_final_message=ctx.request.continue_final_message,
add_special_tokens=ctx.request.add_special_tokens, add_special_tokens=ctx.request.add_special_tokens,
) )
else: elif isinstance(ctx.request, EmbeddingCompletionRequest):
renderer = self._get_completion_renderer() renderer = self._get_completion_renderer()
ctx.engine_prompts = await renderer.render_prompt( ctx.engine_prompts = await renderer.render_prompt(
prompt_or_prompts=ctx.request.input, prompt_or_prompts=ctx.request.input,
config=self._build_render_config(ctx.request), config=self._build_render_config(ctx.request),
) )
else:
return self.create_error_response("Invalid classification request type")
return None return None
except (ValueError, TypeError) as e: except (ValueError, TypeError) as e:
logger.exception("Error in preprocessing prompt inputs") logger.exception("Error in preprocessing prompt inputs")
...@@ -113,16 +130,15 @@ class EmbeddingMixin(OpenAIServing): ...@@ -113,16 +130,15 @@ class EmbeddingMixin(OpenAIServing):
add_special_tokens=request.add_special_tokens, add_special_tokens=request.add_special_tokens,
) )
@override
def _build_response( def _build_response(
self, self,
ctx: ServeContext, ctx: EmbeddingServeContext,
) -> EmbeddingResponse | Response | ErrorResponse: ) -> EmbeddingResponse | EmbeddingBytesResponse | ErrorResponse:
final_res_batch_checked = cast(list[PoolingRequestOutput], ctx.final_res_batch) final_res_batch_checked = ctx.final_res_batch
encoding_format: EncodingFormat = ctx.request.encoding_format encoding_format = ctx.request.encoding_format
embed_dtype: EmbedDType = ctx.request.embed_dtype embed_dtype = ctx.request.embed_dtype
endianness: Endianness = ctx.request.endianness endianness = ctx.request.endianness
def encode_float_base64(): def encode_float_base64():
items: list[EmbeddingResponseData] = [] items: list[EmbeddingResponseData] = []
...@@ -203,8 +219,8 @@ class EmbeddingMixin(OpenAIServing): ...@@ -203,8 +219,8 @@ class EmbeddingMixin(OpenAIServing):
self, self,
ctx: EmbeddingServeContext, ctx: EmbeddingServeContext,
token_ids: list[int], token_ids: list[int],
pooling_params, pooling_params: PoolingParams,
trace_headers, trace_headers: Mapping[str, str] | None,
prompt_idx: int, prompt_idx: int,
) -> list[AsyncGenerator[PoolingRequestOutput, None]]: ) -> list[AsyncGenerator[PoolingRequestOutput, None]]:
"""Process a single prompt using chunked processing.""" """Process a single prompt using chunked processing."""
...@@ -246,7 +262,7 @@ class EmbeddingMixin(OpenAIServing): ...@@ -246,7 +262,7 @@ class EmbeddingMixin(OpenAIServing):
def _validate_input( def _validate_input(
self, self,
request, request: object,
input_ids: list[int], input_ids: list[int],
input_text: str, input_text: str,
) -> TokensPrompt: ) -> TokensPrompt:
...@@ -326,7 +342,7 @@ class EmbeddingMixin(OpenAIServing): ...@@ -326,7 +342,7 @@ class EmbeddingMixin(OpenAIServing):
pooling_params: PoolingParams, pooling_params: PoolingParams,
trace_headers: Mapping[str, str] | None, trace_headers: Mapping[str, str] | None,
prompt_index: int, prompt_index: int,
) -> AsyncGenerator[RequestOutput | PoolingRequestOutput, None]: ) -> AsyncGenerator[PoolingRequestOutput, None]:
"""Create a generator for a single prompt using standard processing.""" """Create a generator for a single prompt using standard processing."""
request_id_item = f"{ctx.request_id}-{prompt_index}" request_id_item = f"{ctx.request_id}-{prompt_index}"
...@@ -347,7 +363,6 @@ class EmbeddingMixin(OpenAIServing): ...@@ -347,7 +363,6 @@ class EmbeddingMixin(OpenAIServing):
priority=getattr(ctx.request, "priority", 0), priority=getattr(ctx.request, "priority", 0),
) )
@override
async def _prepare_generators( async def _prepare_generators(
self, self,
ctx: ServeContext, ctx: ServeContext,
...@@ -363,9 +378,7 @@ class EmbeddingMixin(OpenAIServing): ...@@ -363,9 +378,7 @@ class EmbeddingMixin(OpenAIServing):
return await super()._prepare_generators(ctx) return await super()._prepare_generators(ctx)
# Custom logic for chunked processing # Custom logic for chunked processing
generators: list[ generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
AsyncGenerator[RequestOutput | PoolingRequestOutput, None]
] = []
try: try:
trace_headers = ( trace_headers = (
...@@ -419,10 +432,9 @@ class EmbeddingMixin(OpenAIServing): ...@@ -419,10 +432,9 @@ class EmbeddingMixin(OpenAIServing):
# TODO: Use a vllm-specific Validation Error # TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e)) return self.create_error_response(str(e))
@override
async def _collect_batch( async def _collect_batch(
self, self,
ctx: ServeContext, ctx: EmbeddingServeContext,
) -> ErrorResponse | None: ) -> ErrorResponse | None:
"""Collect and aggregate batch results """Collect and aggregate batch results
with support for chunked processing. with support for chunked processing.
...@@ -431,7 +443,6 @@ class EmbeddingMixin(OpenAIServing): ...@@ -431,7 +443,6 @@ class EmbeddingMixin(OpenAIServing):
minimize memory usage. minimize memory usage.
For regular requests, collects results normally. For regular requests, collects results normally.
""" """
ctx = cast(EmbeddingServeContext, ctx)
try: try:
if ctx.engine_prompts is None: if ctx.engine_prompts is None:
return self.create_error_response("Engine prompts not available") return self.create_error_response("Engine prompts not available")
...@@ -527,12 +538,10 @@ class EmbeddingMixin(OpenAIServing): ...@@ -527,12 +538,10 @@ class EmbeddingMixin(OpenAIServing):
except (ValueError, IndexError): except (ValueError, IndexError):
prompt_idx = result_idx # Fallback to result_idx prompt_idx = result_idx # Fallback to result_idx
short_prompts_results[prompt_idx] = cast( short_prompts_results[prompt_idx] = result
PoolingRequestOutput, result
)
# Finalize aggregated results # Finalize aggregated results
final_res_batch: list[PoolingRequestOutput | EmbeddingRequestOutput] = [] final_res_batch: list[PoolingRequestOutput] = []
num_prompts = len(ctx.engine_prompts) num_prompts = len(ctx.engine_prompts)
for prompt_idx in range(num_prompts): for prompt_idx in range(num_prompts):
...@@ -580,49 +589,19 @@ class EmbeddingMixin(OpenAIServing): ...@@ -580,49 +589,19 @@ class EmbeddingMixin(OpenAIServing):
f"Failed to aggregate chunks for prompt {prompt_idx}" f"Failed to aggregate chunks for prompt {prompt_idx}"
) )
elif prompt_idx in short_prompts_results: elif prompt_idx in short_prompts_results:
final_res_batch.append( final_res_batch.append(short_prompts_results[prompt_idx])
cast(PoolingRequestOutput, short_prompts_results[prompt_idx])
)
else: else:
return self.create_error_response( return self.create_error_response(
f"Result not found for prompt {prompt_idx}" f"Result not found for prompt {prompt_idx}"
) )
ctx.final_res_batch = cast( ctx.final_res_batch = final_res_batch
list[RequestOutput | PoolingRequestOutput], final_res_batch
)
return None return None
except Exception as e: except Exception as e:
return self.create_error_response(str(e)) return self.create_error_response(str(e))
class OpenAIServingEmbedding(EmbeddingMixin):
request_id_prefix = "embd"
def __init__(
self,
engine_client: EngineClient,
models: OpenAIServingModels,
*,
request_logger: RequestLogger | None,
chat_template: str | None,
chat_template_content_format: ChatTemplateContentFormatOption,
trust_request_chat_template: bool = False,
log_error_stack: bool = False,
) -> None:
super().__init__(
engine_client=engine_client,
models=models,
request_logger=request_logger,
log_error_stack=log_error_stack,
)
self.chat_template = chat_template
self.chat_template_content_format: Final = chat_template_content_format
self.trust_request_chat_template = trust_request_chat_template
async def create_embedding( async def create_embedding(
self, self,
request: EmbeddingRequest, request: EmbeddingRequest,
...@@ -645,16 +624,13 @@ class OpenAIServingEmbedding(EmbeddingMixin): ...@@ -645,16 +624,13 @@ class OpenAIServingEmbedding(EmbeddingMixin):
raw_request=raw_request, raw_request=raw_request,
model_name=model_name, model_name=model_name,
request_id=request_id, request_id=request_id,
chat_template=self.chat_template,
chat_template_content_format=self.chat_template_content_format,
) )
return await super().handle(ctx) # type: ignore return await self.handle(ctx) # type: ignore[return-value]
@override
def _create_pooling_params( def _create_pooling_params(
self, self,
ctx: ServeContext[EmbeddingRequest], ctx: EmbeddingServeContext,
) -> PoolingParams | ErrorResponse: ) -> PoolingParams | ErrorResponse:
pooling_params = super()._create_pooling_params(ctx) pooling_params = super()._create_pooling_params(ctx)
if isinstance(pooling_params, ErrorResponse): if isinstance(pooling_params, ErrorResponse):
...@@ -666,17 +642,3 @@ class OpenAIServingEmbedding(EmbeddingMixin): ...@@ -666,17 +642,3 @@ class OpenAIServingEmbedding(EmbeddingMixin):
return self.create_error_response(str(e)) return self.create_error_response(str(e))
return pooling_params return pooling_params
async def _preprocess(
self,
ctx: ServeContext,
) -> ErrorResponse | None:
if isinstance(ctx.request, EmbeddingChatRequest):
error_check_ret = self._validate_chat_template(
request_chat_template=ctx.request.chat_template,
chat_template_kwargs=ctx.request.chat_template_kwargs,
trust_request_chat_template=self.trust_request_chat_template,
)
if error_check_ret is not None:
return error_check_ret
return await super()._preprocess(ctx)
...@@ -17,8 +17,10 @@ from starlette.background import BackgroundTask, BackgroundTasks ...@@ -17,8 +17,10 @@ from starlette.background import BackgroundTask, BackgroundTasks
from vllm import envs from vllm import envs
from vllm.engine.arg_utils import EngineArgs from vllm.engine.arg_utils import EngineArgs
from vllm.inputs import EmbedsPrompt, TokensPrompt
from vllm.logger import current_formatter_type, init_logger from vllm.logger import current_formatter_type, init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import length_from_prompt_token_ids_or_embeds
from vllm.utils.argparse_utils import FlexibleArgumentParser from vllm.utils.argparse_utils import FlexibleArgumentParser
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -32,11 +34,15 @@ if TYPE_CHECKING: ...@@ -32,11 +34,15 @@ if TYPE_CHECKING:
StreamOptions, StreamOptions,
) )
from vllm.entrypoints.openai.models.protocol import LoRAModulePath from vllm.entrypoints.openai.models.protocol import LoRAModulePath
from vllm.entrypoints.openai.responses.protocol import (
ResponsesRequest,
)
else: else:
ChatCompletionRequest = object ChatCompletionRequest = object
CompletionRequest = object CompletionRequest = object
StreamOptions = object StreamOptions = object
LoRAModulePath = object LoRAModulePath = object
ResponsesRequest = object
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -211,11 +217,26 @@ def _validate_truncation_size( ...@@ -211,11 +217,26 @@ def _validate_truncation_size(
def get_max_tokens( def get_max_tokens(
max_model_len: int, max_model_len: int,
request: "ChatCompletionRequest | CompletionRequest", request: "CompletionRequest | ChatCompletionRequest | ResponsesRequest",
input_length: int, prompt: TokensPrompt | EmbedsPrompt,
default_sampling_params: dict, default_sampling_params: dict,
) -> int: ) -> int:
max_tokens = getattr(request, "max_completion_tokens", None) or request.max_tokens # NOTE: Avoid isinstance() for better efficiency
max_tokens: int | None = None
if max_tokens is None:
# ChatCompletionRequest
max_tokens = getattr(request, "max_completion_tokens", None)
if max_tokens is None:
# ResponsesRequest
max_tokens = getattr(request, "max_output_tokens", None)
if max_tokens is None:
# CompletionRequest (also a fallback for ChatCompletionRequest)
max_tokens = getattr(request, "max_tokens", None)
input_length = length_from_prompt_token_ids_or_embeds(
prompt.get("prompt_token_ids"), # type: ignore[arg-type]
prompt.get("prompt_embeds"), # type: ignore[arg-type]
)
default_max_tokens = max_model_len - input_length default_max_tokens = max_model_len - input_length
max_output_tokens = current_platform.get_max_output_tokens(input_length) max_output_tokens = current_platform.get_max_output_tokens(input_length)
......
...@@ -87,6 +87,7 @@ if TYPE_CHECKING: ...@@ -87,6 +87,7 @@ if TYPE_CHECKING:
VLLM_HTTP_TIMEOUT_KEEP_ALIVE: int = 5 # seconds VLLM_HTTP_TIMEOUT_KEEP_ALIVE: int = 5 # seconds
VLLM_PLUGINS: list[str] | None = None VLLM_PLUGINS: list[str] | None = None
VLLM_LORA_RESOLVER_CACHE_DIR: str | None = None VLLM_LORA_RESOLVER_CACHE_DIR: str | None = None
VLLM_LORA_RESOLVER_HF_REPO_LIST: str | None = None
# Deprecated env variables for profiling, kept for backward compatibility # Deprecated env variables for profiling, kept for backward compatibility
# See also vllm/config/profiler.py and `--profiler-config` argument # See also vllm/config/profiler.py and `--profiler-config` argument
VLLM_TORCH_CUDA_PROFILE: str | None = None VLLM_TORCH_CUDA_PROFILE: str | None = None
...@@ -325,16 +326,11 @@ def use_aot_compile() -> bool: ...@@ -325,16 +326,11 @@ def use_aot_compile() -> bool:
from vllm.model_executor.layers.batch_invariant import ( from vllm.model_executor.layers.batch_invariant import (
vllm_is_batch_invariant, vllm_is_batch_invariant,
) )
from vllm.platforms import current_platform
from vllm.utils.torch_utils import is_torch_equal_or_newer from vllm.utils.torch_utils import is_torch_equal_or_newer
default_value = ( default_value = (
"1" "1"
if is_torch_equal_or_newer("2.10.0.dev") if is_torch_equal_or_newer("2.10.0.dev") and not disable_compile_cache()
and not disable_compile_cache()
# Disabling AOT_COMPILE for CPU
# See: https://github.com/vllm-project/vllm/issues/32033
and not current_platform.is_cpu()
else "0" else "0"
) )
...@@ -823,6 +819,7 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -823,6 +819,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
), ),
# Backend for Video IO # Backend for Video IO
# - "opencv": Default backend that uses OpenCV stream buffered backend. # - "opencv": Default backend that uses OpenCV stream buffered backend.
# - "identity": Returns raw video bytes for model processor to handle.
# #
# Custom backend implementations can be registered # Custom backend implementations can be registered
# via `@VIDEO_LOADER_REGISTRY.register("my_custom_video_loader")` and # via `@VIDEO_LOADER_REGISTRY.register("my_custom_video_loader")` and
...@@ -914,6 +911,13 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -914,6 +911,13 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_LORA_RESOLVER_CACHE_DIR": lambda: os.getenv( "VLLM_LORA_RESOLVER_CACHE_DIR": lambda: os.getenv(
"VLLM_LORA_RESOLVER_CACHE_DIR", None "VLLM_LORA_RESOLVER_CACHE_DIR", None
), ),
# A remote HF repo(s) containing one or more LoRA adapters, which
# may be downloaded and leveraged as needed. Only works if plugins
# are enabled and VLLM_ALLOW_RUNTIME_LORA_UPDATING is enabled.
# Values should be comma separated.
"VLLM_LORA_RESOLVER_HF_REPO_LIST": lambda: os.getenv(
"VLLM_LORA_RESOLVER_HF_REPO_LIST", None
),
# Enables torch CUDA profiling if set to 1. # Enables torch CUDA profiling if set to 1.
# Deprecated, see profiler_config. # Deprecated, see profiler_config.
"VLLM_TORCH_CUDA_PROFILE": lambda: os.getenv("VLLM_TORCH_CUDA_PROFILE"), "VLLM_TORCH_CUDA_PROFILE": lambda: os.getenv("VLLM_TORCH_CUDA_PROFILE"),
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.logging_utils.access_log_filter import (
UvicornAccessLogFilter,
create_uvicorn_log_config,
)
from vllm.logging_utils.formatter import ColoredFormatter, NewLineFormatter from vllm.logging_utils.formatter import ColoredFormatter, NewLineFormatter
from vllm.logging_utils.lazy import lazy from vllm.logging_utils.lazy import lazy
from vllm.logging_utils.log_time import logtime from vllm.logging_utils.log_time import logtime
...@@ -8,6 +12,8 @@ from vllm.logging_utils.log_time import logtime ...@@ -8,6 +12,8 @@ from vllm.logging_utils.log_time import logtime
__all__ = [ __all__ = [
"NewLineFormatter", "NewLineFormatter",
"ColoredFormatter", "ColoredFormatter",
"UvicornAccessLogFilter",
"create_uvicorn_log_config",
"lazy", "lazy",
"logtime", "logtime",
] ]
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Access log filter for uvicorn to exclude specific endpoints from logging.
This module provides a logging filter that can be used to suppress access logs
for specific endpoints (e.g., /health, /metrics) to reduce log noise in
production environments.
"""
import logging
from urllib.parse import urlparse
class UvicornAccessLogFilter(logging.Filter):
"""
A logging filter that excludes access logs for specified endpoint paths.
This filter is designed to work with uvicorn's access logger. It checks
the log record's arguments for the request path and filters out records
matching the excluded paths.
Uvicorn access log format:
'%s - "%s %s HTTP/%s" %d'
(client_addr, method, path, http_version, status_code)
Example:
127.0.0.1:12345 - "GET /health HTTP/1.1" 200
Args:
excluded_paths: A list of URL paths to exclude from logging.
Paths are matched exactly.
Example: ["/health", "/metrics"]
"""
def __init__(self, excluded_paths: list[str] | None = None):
super().__init__()
self.excluded_paths = set(excluded_paths or [])
def filter(self, record: logging.LogRecord) -> bool:
"""
Determine if the log record should be logged.
Args:
record: The log record to evaluate.
Returns:
True if the record should be logged, False otherwise.
"""
if not self.excluded_paths:
return True
# This filter is specific to uvicorn's access logs.
if record.name != "uvicorn.access":
return True
# The path is the 3rd argument in the log record's args tuple.
# See uvicorn's access logging implementation for details.
log_args = record.args
if isinstance(log_args, tuple) and len(log_args) >= 3:
path_with_query = log_args[2]
# Get path component without query string.
if isinstance(path_with_query, str):
path = urlparse(path_with_query).path
if path in self.excluded_paths:
return False
return True
def create_uvicorn_log_config(
excluded_paths: list[str] | None = None,
log_level: str = "info",
) -> dict:
"""
Create a uvicorn logging configuration with access log filtering.
This function generates a logging configuration dictionary that can be
passed to uvicorn's `log_config` parameter. It sets up the access log
filter to exclude specified paths.
Args:
excluded_paths: List of URL paths to exclude from access logs.
log_level: The log level for uvicorn loggers.
Returns:
A dictionary containing the logging configuration.
Example:
>>> config = create_uvicorn_log_config(["/health", "/metrics"])
>>> uvicorn.run(app, log_config=config)
"""
config = {
"version": 1,
"disable_existing_loggers": False,
"filters": {
"access_log_filter": {
"()": UvicornAccessLogFilter,
"excluded_paths": excluded_paths or [],
},
},
"formatters": {
"default": {
"()": "uvicorn.logging.DefaultFormatter",
"fmt": "%(levelprefix)s %(message)s",
"use_colors": None,
},
"access": {
"()": "uvicorn.logging.AccessFormatter",
"fmt": '%(levelprefix)s %(client_addr)s - "%(request_line)s" %(status_code)s', # noqa: E501
},
},
"handlers": {
"default": {
"formatter": "default",
"class": "logging.StreamHandler",
"stream": "ext://sys.stderr",
},
"access": {
"formatter": "access",
"class": "logging.StreamHandler",
"stream": "ext://sys.stdout",
"filters": ["access_log_filter"],
},
},
"loggers": {
"uvicorn": {
"handlers": ["default"],
"level": log_level.upper(),
"propagate": False,
},
"uvicorn.error": {
"level": log_level.upper(),
"handlers": ["default"],
"propagate": False,
},
"uvicorn.access": {
"handlers": ["access"],
"level": log_level.upper(),
"propagate": False,
},
},
}
return config
...@@ -62,6 +62,7 @@ def _fused_moe_lora_kernel( ...@@ -62,6 +62,7 @@ def _fused_moe_lora_kernel(
num_experts, num_experts,
lora_ids, lora_ids,
adapter_enabled, adapter_enabled,
max_loras, # <<< PR2: rename, used for masks when grid axis-2 != max_loras
# The stride variables represent how much to increase the ptr by when # The stride variables represent how much to increase the ptr by when
# moving by 1 element in a particular dimension. E.g. `stride_am` is # moving by 1 element in a particular dimension. E.g. `stride_am` is
# how much to increase `a_ptr` by to get the element one row down # how much to increase `a_ptr` by to get the element one row down
...@@ -83,6 +84,7 @@ def _fused_moe_lora_kernel( ...@@ -83,6 +84,7 @@ def _fused_moe_lora_kernel(
num_slice_c: tl.constexpr, num_slice_c: tl.constexpr,
top_k: tl.constexpr, top_k: tl.constexpr,
MUL_ROUTED_WEIGHT: tl.constexpr, MUL_ROUTED_WEIGHT: tl.constexpr,
USE_B_L2_CACHE: tl.constexpr, # new, enable .ca load for B
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
...@@ -104,10 +106,13 @@ def _fused_moe_lora_kernel( ...@@ -104,10 +106,13 @@ def _fused_moe_lora_kernel(
if moe_enabled == 0: if moe_enabled == 0:
# Early exit for the no moe lora case. # Early exit for the no moe lora case.
return return
# The grid size on axis 2 is (max_loras + 1) to handle the no-lora case # The grid's axis-2 dimension is max_loras + 1 to accommodate the -1 sentinel.
# (lora_id == -1), but sorted_token_ids and expert_ids are allocated with # This guard ensures we don't access sorted_token_ids / expert_ids /
# shape (max_loras, ...). Use (num_programs - 1) for correct bounds checking. # num_tokens_post_padded beyond their allocated bounds if an invalid
max_loras = tl.num_programs(axis=2) - 1 # lora_id somehow appears. Although the caller should pass correct
# max_loras, defensive programming prevents accidental out-of-bounds.
if lora_id >= max_loras:
return
grid_k = tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K) grid_k = tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)
# calculate pid_m,pid_n # calculate pid_m,pid_n
...@@ -136,10 +141,11 @@ def _fused_moe_lora_kernel( ...@@ -136,10 +141,11 @@ def _fused_moe_lora_kernel(
cur_b_ptr = tl.load(b_ptr + slice_id).to(tl.pointer_type(c_ptr.dtype.element_ty)) cur_b_ptr = tl.load(b_ptr + slice_id).to(tl.pointer_type(c_ptr.dtype.element_ty))
cur_c_ptr = c_ptr + (slice_id % num_slice_c) * slice_c_size cur_c_ptr = c_ptr + (slice_id % num_slice_c) * slice_c_size
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N # remove modulo wrap-around
offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int32)
offs_k = pid_sk * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) offs_k = pid_sk * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64) offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int32)
token_ind = stride_tl * lora_id + offs_token_id token_ind = stride_tl * lora_id + offs_token_id
offs_token = tl.load( offs_token = tl.load(
sorted_token_ids_ptr + token_ind, sorted_token_ids_ptr + token_ind,
...@@ -176,7 +182,13 @@ def _fused_moe_lora_kernel( ...@@ -176,7 +182,13 @@ def _fused_moe_lora_kernel(
# GDC wait waits for ALL programs in the prior kernel to complete # GDC wait waits for ALL programs in the prior kernel to complete
# before continuing. # before continuing.
# pre-fetch lora weight # pre-fetch lora weight
b = tl.load(b_ptrs, mask=offs_k[:, None] < k_remaining, other=0.0) # add (offs_bn < N) mask; optional .ca for B
b_mask = (offs_k[:, None] < k_remaining) & (offs_bn[None, :] < N)
if USE_B_L2_CACHE:
b = tl.load(b_ptrs, mask=b_mask, other=0.0, cache_modifier=".ca")
else:
b = tl.load(b_ptrs, mask=b_mask, other=0.0)
if USE_GDC and not IS_PRIMARY: if USE_GDC and not IS_PRIMARY:
tl.extra.cuda.gdc_wait() tl.extra.cuda.gdc_wait()
a = tl.load( a = tl.load(
...@@ -276,6 +288,7 @@ def _fused_moe_lora_shrink( ...@@ -276,6 +288,7 @@ def _fused_moe_lora_shrink(
num_experts, num_experts,
lora_ids, lora_ids,
adapter_enabled, adapter_enabled,
lora_a_stacked[0].shape[0],
qcurr_hidden_states.stride(0), qcurr_hidden_states.stride(0),
qcurr_hidden_states.stride(1), qcurr_hidden_states.stride(1),
w1_lora_a_stacked.stride(0), w1_lora_a_stacked.stride(0),
...@@ -292,6 +305,7 @@ def _fused_moe_lora_shrink( ...@@ -292,6 +305,7 @@ def _fused_moe_lora_shrink(
num_slice_c=num_slices, num_slice_c=num_slices,
top_k=1 if mul_routed_weight else top_k_num, top_k=1 if mul_routed_weight else top_k_num,
MUL_ROUTED_WEIGHT=False, MUL_ROUTED_WEIGHT=False,
USE_B_L2_CACHE=True, # new
IS_PRIMARY=True, IS_PRIMARY=True,
**shrink_config, **shrink_config,
) )
...@@ -377,6 +391,7 @@ def _fused_moe_lora_expand( ...@@ -377,6 +391,7 @@ def _fused_moe_lora_expand(
num_experts, num_experts,
lora_ids, lora_ids,
adapter_enabled, adapter_enabled,
lora_b_stacked[0].shape[0],
a_intermediate_cache1.stride(0), a_intermediate_cache1.stride(0),
a_intermediate_cache1.stride(1), a_intermediate_cache1.stride(1),
w1_lora_b_stacked.stride(0), w1_lora_b_stacked.stride(0),
...@@ -393,6 +408,7 @@ def _fused_moe_lora_expand( ...@@ -393,6 +408,7 @@ def _fused_moe_lora_expand(
num_slice_c=num_slices, num_slice_c=num_slices,
top_k=1, top_k=1,
MUL_ROUTED_WEIGHT=mul_routed_weight, MUL_ROUTED_WEIGHT=mul_routed_weight,
USE_B_L2_CACHE=True, # new
IS_PRIMARY=False, IS_PRIMARY=False,
**expand_config, **expand_config,
) )
......
...@@ -7,17 +7,27 @@ import torch ...@@ -7,17 +7,27 @@ import torch
from vllm.distributed import ( from vllm.distributed import (
get_ep_group, get_ep_group,
) )
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import ( from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig, FusedMoEConfig,
FusedMoEParallelConfig, FusedMoEParallelConfig,
FusedMoEQuantConfig, FusedMoEQuantConfig,
) )
from vllm.model_executor.layers.fused_moe.flashinfer_a2a_prepare_finalize import (
FlashInferA2APrepareAndFinalize,
)
from vllm.model_executor.layers.fused_moe.modular_kernel import ( from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEPrepareAndFinalize, FusedMoEPrepareAndFinalize,
) )
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNaiveEP,
MoEPrepareAndFinalizeNoEP,
)
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.import_utils import has_deep_ep, has_mori, has_pplx from vllm.utils.import_utils import has_deep_ep, has_mori, has_pplx
logger = init_logger(__name__)
if current_platform.is_cuda_alike(): if current_platform.is_cuda_alike():
if has_pplx(): if has_pplx():
from .pplx_prepare_finalize import ( from .pplx_prepare_finalize import (
...@@ -70,20 +80,46 @@ def maybe_make_prepare_finalize( ...@@ -70,20 +80,46 @@ def maybe_make_prepare_finalize(
moe: FusedMoEConfig, moe: FusedMoEConfig,
quant_config: FusedMoEQuantConfig | None, quant_config: FusedMoEQuantConfig | None,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
allow_new_interface: bool = False,
) -> FusedMoEPrepareAndFinalize | None: ) -> FusedMoEPrepareAndFinalize | None:
# NOTE(rob): we are migrating each quant_method to hold the MK
# in all cases. The allow_new_interface=False flag allow us to fall
# back to the old method for methods that have not yet been migrated.
#
# In old method:
# * maybe_init_modular_kernel() calls this function. If we are
# using no Dp/Ep or naive all2all, we return None this function
# returns None and no ModularKernelMethod is created. If non-naive
# all2all is used, this returns a PrepareAndFinalize object and
# a ModularKernelMethod is created.
# In new method:
# * maybe_make_prepare_finalize() is called from the oracle. We
# always return a PrepareAndFinalize object and the quant method
# holds the ModularKernel.
if not moe.moe_parallel_config.use_all2all_kernels: if not moe.moe_parallel_config.use_all2all_kernels:
if not allow_new_interface:
return None return None
# For DP/TP case, fall back to naive P/F.
if moe.moe_parallel_config.dp_size > 1:
logger.info_once(
"Detected DP deployment with no --enable-expert-parallel. "
"Falling back to AllGather+ReduceScatter dispatch/combine."
)
return MoEPrepareAndFinalizeNaiveEP(
is_sequence_parallel=moe.moe_parallel_config.is_sequence_parallel,
num_dispatchers=(
get_ep_group().device_communicator.all2all_manager.world_size
),
)
else:
return MoEPrepareAndFinalizeNoEP()
all2all_manager = get_ep_group().device_communicator.all2all_manager all2all_manager = get_ep_group().device_communicator.all2all_manager
assert all2all_manager is not None assert all2all_manager is not None
prepare_finalize: FusedMoEPrepareAndFinalize | None = None prepare_finalize: FusedMoEPrepareAndFinalize | None = None
# TODO(rob): update this as part of the MoE refactor.
assert not moe.use_flashinfer_cutlass_kernels, (
"Must be created in modelopt.py or fp8.py"
)
if moe.use_pplx_kernels: if moe.use_pplx_kernels:
assert quant_config is not None assert quant_config is not None
...@@ -203,4 +239,16 @@ def maybe_make_prepare_finalize( ...@@ -203,4 +239,16 @@ def maybe_make_prepare_finalize(
use_fp8_dispatch=use_fp8_dispatch, use_fp8_dispatch=use_fp8_dispatch,
) )
elif moe.use_fi_all2allv_kernels:
assert quant_config is not None
prepare_finalize = FlashInferA2APrepareAndFinalize(
num_dispatchers=all2all_manager.world_size,
)
elif moe.use_naive_all2all_kernels and allow_new_interface:
prepare_finalize = MoEPrepareAndFinalizeNaiveEP(
is_sequence_parallel=(moe.moe_parallel_config.is_sequence_parallel),
num_dispatchers=all2all_manager.world_size,
)
return prepare_finalize return prepare_finalize
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment