Commit d76fc11e authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.15.0rc1' into v0.15.0rc1-dev

parents 38166ec4 58996f35
......@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from http import HTTPStatus
from typing import cast
from typing import Final, cast
import jinja2
import numpy as np
......@@ -11,18 +11,8 @@ from fastapi import Request
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.chat_completion.protocol import (
ChatCompletionRequest,
)
from vllm.entrypoints.openai.engine.protocol import (
ErrorResponse,
UsageInfo,
)
from vllm.entrypoints.openai.engine.serving import (
ClassificationServeContext,
OpenAIServing,
ServeContext,
)
from vllm.entrypoints.openai.engine.protocol import ErrorResponse, UsageInfo
from vllm.entrypoints.openai.engine.serving import OpenAIServing, ServeContext
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.entrypoints.pooling.classify.protocol import (
ClassificationChatRequest,
......@@ -39,60 +29,68 @@ from vllm.pooling_params import PoolingParams
logger = init_logger(__name__)
class ClassificationMixin(OpenAIServing):
chat_template: str | None
chat_template_content_format: ChatTemplateContentFormatOption
trust_request_chat_template: bool
ClassificationServeContext = ServeContext[ClassificationRequest]
class ServingClassification(OpenAIServing):
request_id_prefix = "classify"
def __init__(
self,
engine_client: EngineClient,
models: OpenAIServingModels,
*,
request_logger: RequestLogger | None,
chat_template: str | None = None,
chat_template_content_format: ChatTemplateContentFormatOption = "auto",
trust_request_chat_template: bool = False,
log_error_stack: bool = False,
) -> None:
super().__init__(
engine_client=engine_client,
models=models,
request_logger=request_logger,
log_error_stack=log_error_stack,
)
self.chat_template = chat_template
self.chat_template_content_format: Final = chat_template_content_format
self.trust_request_chat_template = trust_request_chat_template
async def _preprocess(
self,
ctx: ServeContext,
ctx: ClassificationServeContext,
) -> ErrorResponse | None:
"""
Process classification inputs: tokenize text, resolve adapters,
and prepare model-specific inputs.
"""
ctx = cast(ClassificationServeContext, ctx)
try:
request_obj = ctx.request
if isinstance(request_obj, ClassificationChatRequest):
chat_request = request_obj
messages = chat_request.messages
trust_request_chat_template = getattr(
self,
"trust_request_chat_template",
False,
)
ret = self._validate_chat_template(
request_chat_template=chat_request.chat_template,
chat_template_kwargs=chat_request.chat_template_kwargs,
trust_request_chat_template=trust_request_chat_template,
ctx.lora_request = self._maybe_get_adapters(ctx.request)
if isinstance(ctx.request, ClassificationChatRequest):
error_check_ret = self._validate_chat_template(
request_chat_template=ctx.request.chat_template,
chat_template_kwargs=ctx.request.chat_template_kwargs,
trust_request_chat_template=self.trust_request_chat_template,
)
if ret:
return ret
if error_check_ret:
return error_check_ret
_, engine_prompts = await self._preprocess_chat(
cast(ChatCompletionRequest, chat_request),
ctx.request,
self.renderer,
messages,
chat_template=(
chat_request.chat_template
or getattr(self, "chat_template", None)
),
chat_template_content_format=cast(
ChatTemplateContentFormatOption,
getattr(self, "chat_template_content_format", "auto"),
),
add_generation_prompt=chat_request.add_generation_prompt,
continue_final_message=chat_request.continue_final_message,
add_special_tokens=chat_request.add_special_tokens,
ctx.request.messages,
chat_template=ctx.request.chat_template or self.chat_template,
chat_template_content_format=self.chat_template_content_format,
add_generation_prompt=ctx.request.add_generation_prompt,
continue_final_message=ctx.request.continue_final_message,
add_special_tokens=ctx.request.add_special_tokens,
)
ctx.engine_prompts = engine_prompts
elif isinstance(request_obj, ClassificationCompletionRequest):
completion_request = request_obj
input_data = completion_request.input
elif isinstance(ctx.request, ClassificationCompletionRequest):
input_data = ctx.request.input
if input_data in (None, ""):
return self.create_error_response(
"Input or messages must be provided",
......@@ -106,13 +104,10 @@ class ClassificationMixin(OpenAIServing):
prompt_input = cast(str | list[str], input_data)
ctx.engine_prompts = await renderer.render_prompt(
prompt_or_prompts=prompt_input,
config=self._build_render_config(completion_request),
config=self._build_render_config(ctx.request),
)
else:
return self.create_error_response(
"Invalid classification request type",
status_code=HTTPStatus.BAD_REQUEST,
)
return self.create_error_response("Invalid classification request type")
return None
......@@ -122,13 +117,14 @@ class ClassificationMixin(OpenAIServing):
def _build_response(
self,
ctx: ServeContext,
ctx: ClassificationServeContext,
) -> ClassificationResponse | ErrorResponse:
"""
Convert model outputs to a formatted classification response
with probabilities and labels.
"""
ctx = cast(ClassificationServeContext, ctx)
id2label = getattr(self.model_config.hf_config, "id2label", {})
items: list[ClassificationData] = []
num_prompt_tokens = 0
......@@ -139,9 +135,7 @@ class ClassificationMixin(OpenAIServing):
probs = classify_res.probs
predicted_index = int(np.argmax(probs))
label = getattr(self.model_config.hf_config, "id2label", {}).get(
predicted_index
)
label = id2label.get(predicted_index)
item = ClassificationData(
index=idx,
......@@ -174,32 +168,6 @@ class ClassificationMixin(OpenAIServing):
add_special_tokens=request.add_special_tokens,
)
class ServingClassification(ClassificationMixin):
request_id_prefix = "classify"
def __init__(
self,
engine_client: EngineClient,
models: OpenAIServingModels,
*,
request_logger: RequestLogger | None,
chat_template: str | None = None,
chat_template_content_format: ChatTemplateContentFormatOption = "auto",
trust_request_chat_template: bool = False,
log_error_stack: bool = False,
) -> None:
super().__init__(
engine_client=engine_client,
models=models,
request_logger=request_logger,
log_error_stack=log_error_stack,
)
self.chat_template = chat_template
self.chat_template_content_format = chat_template_content_format
self.trust_request_chat_template = trust_request_chat_template
async def create_classify(
self,
request: ClassificationRequest,
......@@ -215,11 +183,11 @@ class ServingClassification(ClassificationMixin):
request_id=request_id,
)
return await super().handle(ctx) # type: ignore
return await self.handle(ctx) # type: ignore[return-value]
def _create_pooling_params(
self,
ctx: ServeContext[ClassificationRequest],
ctx: ClassificationServeContext,
) -> PoolingParams | ErrorResponse:
pooling_params = super()._create_pooling_params(ctx)
if isinstance(pooling_params, ErrorResponse):
......
......@@ -6,21 +6,13 @@ from typing import Any, Final, cast
import torch
from fastapi import Request
from fastapi.responses import Response
from typing_extensions import assert_never, override
from typing_extensions import assert_never
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.engine.protocol import (
ErrorResponse,
UsageInfo,
)
from vllm.entrypoints.openai.engine.serving import (
EmbeddingServeContext,
OpenAIServing,
ServeContext,
)
from vllm.entrypoints.openai.engine.protocol import ErrorResponse, UsageInfo
from vllm.entrypoints.openai.engine.serving import OpenAIServing, ServeContext
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.entrypoints.pooling.embed.protocol import (
EmbeddingBytesResponse,
......@@ -33,19 +25,11 @@ from vllm.entrypoints.pooling.embed.protocol import (
from vllm.entrypoints.renderer import RenderConfig
from vllm.inputs.data import TokensPrompt
from vllm.logger import init_logger
from vllm.outputs import (
EmbeddingRequestOutput,
PoolingOutput,
PoolingRequestOutput,
RequestOutput,
)
from vllm.outputs import PoolingOutput, PoolingRequestOutput
from vllm.pooling_params import PoolingParams
from vllm.utils.async_utils import merge_async_iterators
from vllm.utils.collection_utils import chunk_list
from vllm.utils.serial_utils import (
EmbedDType,
EncodingFormat,
Endianness,
encode_pooling_bytes,
encode_pooling_output,
)
......@@ -53,9 +37,33 @@ from vllm.utils.serial_utils import (
logger = init_logger(__name__)
class EmbeddingMixin(OpenAIServing):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
EmbeddingServeContext = ServeContext[EmbeddingRequest]
class OpenAIServingEmbedding(OpenAIServing):
request_id_prefix = "embd"
def __init__(
self,
engine_client: EngineClient,
models: OpenAIServingModels,
*,
request_logger: RequestLogger | None,
chat_template: str | None,
chat_template_content_format: ChatTemplateContentFormatOption,
trust_request_chat_template: bool = False,
log_error_stack: bool = False,
) -> None:
super().__init__(
engine_client=engine_client,
models=models,
request_logger=request_logger,
log_error_stack=log_error_stack,
)
self.chat_template = chat_template
self.chat_template_content_format: Final = chat_template_content_format
self.trust_request_chat_template = trust_request_chat_template
pooler_config = self.model_config.pooler_config
......@@ -69,32 +77,41 @@ class EmbeddingMixin(OpenAIServing):
else None
)
@override
async def _preprocess(
self,
ctx: ServeContext,
ctx: EmbeddingServeContext,
) -> ErrorResponse | None:
ctx = cast(EmbeddingServeContext, ctx)
try:
ctx.lora_request = self._maybe_get_adapters(ctx.request)
if isinstance(ctx.request, EmbeddingChatRequest):
error_check_ret = self._validate_chat_template(
request_chat_template=ctx.request.chat_template,
chat_template_kwargs=ctx.request.chat_template_kwargs,
trust_request_chat_template=self.trust_request_chat_template,
)
if error_check_ret is not None:
return error_check_ret
_, ctx.engine_prompts = await self._preprocess_chat(
ctx.request,
self.renderer,
ctx.request.messages,
chat_template=ctx.request.chat_template or ctx.chat_template,
chat_template_content_format=ctx.chat_template_content_format,
chat_template=ctx.request.chat_template or self.chat_template,
chat_template_content_format=self.chat_template_content_format,
add_generation_prompt=ctx.request.add_generation_prompt,
continue_final_message=ctx.request.continue_final_message,
add_special_tokens=ctx.request.add_special_tokens,
)
else:
elif isinstance(ctx.request, EmbeddingCompletionRequest):
renderer = self._get_completion_renderer()
ctx.engine_prompts = await renderer.render_prompt(
prompt_or_prompts=ctx.request.input,
config=self._build_render_config(ctx.request),
)
else:
return self.create_error_response("Invalid classification request type")
return None
except (ValueError, TypeError) as e:
logger.exception("Error in preprocessing prompt inputs")
......@@ -113,16 +130,15 @@ class EmbeddingMixin(OpenAIServing):
add_special_tokens=request.add_special_tokens,
)
@override
def _build_response(
self,
ctx: ServeContext,
) -> EmbeddingResponse | Response | ErrorResponse:
final_res_batch_checked = cast(list[PoolingRequestOutput], ctx.final_res_batch)
ctx: EmbeddingServeContext,
) -> EmbeddingResponse | EmbeddingBytesResponse | ErrorResponse:
final_res_batch_checked = ctx.final_res_batch
encoding_format: EncodingFormat = ctx.request.encoding_format
embed_dtype: EmbedDType = ctx.request.embed_dtype
endianness: Endianness = ctx.request.endianness
encoding_format = ctx.request.encoding_format
embed_dtype = ctx.request.embed_dtype
endianness = ctx.request.endianness
def encode_float_base64():
items: list[EmbeddingResponseData] = []
......@@ -203,8 +219,8 @@ class EmbeddingMixin(OpenAIServing):
self,
ctx: EmbeddingServeContext,
token_ids: list[int],
pooling_params,
trace_headers,
pooling_params: PoolingParams,
trace_headers: Mapping[str, str] | None,
prompt_idx: int,
) -> list[AsyncGenerator[PoolingRequestOutput, None]]:
"""Process a single prompt using chunked processing."""
......@@ -246,7 +262,7 @@ class EmbeddingMixin(OpenAIServing):
def _validate_input(
self,
request,
request: object,
input_ids: list[int],
input_text: str,
) -> TokensPrompt:
......@@ -326,7 +342,7 @@ class EmbeddingMixin(OpenAIServing):
pooling_params: PoolingParams,
trace_headers: Mapping[str, str] | None,
prompt_index: int,
) -> AsyncGenerator[RequestOutput | PoolingRequestOutput, None]:
) -> AsyncGenerator[PoolingRequestOutput, None]:
"""Create a generator for a single prompt using standard processing."""
request_id_item = f"{ctx.request_id}-{prompt_index}"
......@@ -347,7 +363,6 @@ class EmbeddingMixin(OpenAIServing):
priority=getattr(ctx.request, "priority", 0),
)
@override
async def _prepare_generators(
self,
ctx: ServeContext,
......@@ -363,9 +378,7 @@ class EmbeddingMixin(OpenAIServing):
return await super()._prepare_generators(ctx)
# Custom logic for chunked processing
generators: list[
AsyncGenerator[RequestOutput | PoolingRequestOutput, None]
] = []
generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
try:
trace_headers = (
......@@ -419,10 +432,9 @@ class EmbeddingMixin(OpenAIServing):
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
@override
async def _collect_batch(
self,
ctx: ServeContext,
ctx: EmbeddingServeContext,
) -> ErrorResponse | None:
"""Collect and aggregate batch results
with support for chunked processing.
......@@ -431,7 +443,6 @@ class EmbeddingMixin(OpenAIServing):
minimize memory usage.
For regular requests, collects results normally.
"""
ctx = cast(EmbeddingServeContext, ctx)
try:
if ctx.engine_prompts is None:
return self.create_error_response("Engine prompts not available")
......@@ -527,12 +538,10 @@ class EmbeddingMixin(OpenAIServing):
except (ValueError, IndexError):
prompt_idx = result_idx # Fallback to result_idx
short_prompts_results[prompt_idx] = cast(
PoolingRequestOutput, result
)
short_prompts_results[prompt_idx] = result
# Finalize aggregated results
final_res_batch: list[PoolingRequestOutput | EmbeddingRequestOutput] = []
final_res_batch: list[PoolingRequestOutput] = []
num_prompts = len(ctx.engine_prompts)
for prompt_idx in range(num_prompts):
......@@ -580,49 +589,19 @@ class EmbeddingMixin(OpenAIServing):
f"Failed to aggregate chunks for prompt {prompt_idx}"
)
elif prompt_idx in short_prompts_results:
final_res_batch.append(
cast(PoolingRequestOutput, short_prompts_results[prompt_idx])
)
final_res_batch.append(short_prompts_results[prompt_idx])
else:
return self.create_error_response(
f"Result not found for prompt {prompt_idx}"
)
ctx.final_res_batch = cast(
list[RequestOutput | PoolingRequestOutput], final_res_batch
)
ctx.final_res_batch = final_res_batch
return None
except Exception as e:
return self.create_error_response(str(e))
class OpenAIServingEmbedding(EmbeddingMixin):
request_id_prefix = "embd"
def __init__(
self,
engine_client: EngineClient,
models: OpenAIServingModels,
*,
request_logger: RequestLogger | None,
chat_template: str | None,
chat_template_content_format: ChatTemplateContentFormatOption,
trust_request_chat_template: bool = False,
log_error_stack: bool = False,
) -> None:
super().__init__(
engine_client=engine_client,
models=models,
request_logger=request_logger,
log_error_stack=log_error_stack,
)
self.chat_template = chat_template
self.chat_template_content_format: Final = chat_template_content_format
self.trust_request_chat_template = trust_request_chat_template
async def create_embedding(
self,
request: EmbeddingRequest,
......@@ -645,16 +624,13 @@ class OpenAIServingEmbedding(EmbeddingMixin):
raw_request=raw_request,
model_name=model_name,
request_id=request_id,
chat_template=self.chat_template,
chat_template_content_format=self.chat_template_content_format,
)
return await super().handle(ctx) # type: ignore
return await self.handle(ctx) # type: ignore[return-value]
@override
def _create_pooling_params(
self,
ctx: ServeContext[EmbeddingRequest],
ctx: EmbeddingServeContext,
) -> PoolingParams | ErrorResponse:
pooling_params = super()._create_pooling_params(ctx)
if isinstance(pooling_params, ErrorResponse):
......@@ -666,17 +642,3 @@ class OpenAIServingEmbedding(EmbeddingMixin):
return self.create_error_response(str(e))
return pooling_params
async def _preprocess(
self,
ctx: ServeContext,
) -> ErrorResponse | None:
if isinstance(ctx.request, EmbeddingChatRequest):
error_check_ret = self._validate_chat_template(
request_chat_template=ctx.request.chat_template,
chat_template_kwargs=ctx.request.chat_template_kwargs,
trust_request_chat_template=self.trust_request_chat_template,
)
if error_check_ret is not None:
return error_check_ret
return await super()._preprocess(ctx)
......@@ -17,8 +17,10 @@ from starlette.background import BackgroundTask, BackgroundTasks
from vllm import envs
from vllm.engine.arg_utils import EngineArgs
from vllm.inputs import EmbedsPrompt, TokensPrompt
from vllm.logger import current_formatter_type, init_logger
from vllm.platforms import current_platform
from vllm.utils import length_from_prompt_token_ids_or_embeds
from vllm.utils.argparse_utils import FlexibleArgumentParser
if TYPE_CHECKING:
......@@ -32,11 +34,15 @@ if TYPE_CHECKING:
StreamOptions,
)
from vllm.entrypoints.openai.models.protocol import LoRAModulePath
from vllm.entrypoints.openai.responses.protocol import (
ResponsesRequest,
)
else:
ChatCompletionRequest = object
CompletionRequest = object
StreamOptions = object
LoRAModulePath = object
ResponsesRequest = object
logger = init_logger(__name__)
......@@ -211,11 +217,26 @@ def _validate_truncation_size(
def get_max_tokens(
max_model_len: int,
request: "ChatCompletionRequest | CompletionRequest",
input_length: int,
request: "CompletionRequest | ChatCompletionRequest | ResponsesRequest",
prompt: TokensPrompt | EmbedsPrompt,
default_sampling_params: dict,
) -> int:
max_tokens = getattr(request, "max_completion_tokens", None) or request.max_tokens
# NOTE: Avoid isinstance() for better efficiency
max_tokens: int | None = None
if max_tokens is None:
# ChatCompletionRequest
max_tokens = getattr(request, "max_completion_tokens", None)
if max_tokens is None:
# ResponsesRequest
max_tokens = getattr(request, "max_output_tokens", None)
if max_tokens is None:
# CompletionRequest (also a fallback for ChatCompletionRequest)
max_tokens = getattr(request, "max_tokens", None)
input_length = length_from_prompt_token_ids_or_embeds(
prompt.get("prompt_token_ids"), # type: ignore[arg-type]
prompt.get("prompt_embeds"), # type: ignore[arg-type]
)
default_max_tokens = max_model_len - input_length
max_output_tokens = current_platform.get_max_output_tokens(input_length)
......
......@@ -87,6 +87,7 @@ if TYPE_CHECKING:
VLLM_HTTP_TIMEOUT_KEEP_ALIVE: int = 5 # seconds
VLLM_PLUGINS: list[str] | None = None
VLLM_LORA_RESOLVER_CACHE_DIR: str | None = None
VLLM_LORA_RESOLVER_HF_REPO_LIST: str | None = None
# Deprecated env variables for profiling, kept for backward compatibility
# See also vllm/config/profiler.py and `--profiler-config` argument
VLLM_TORCH_CUDA_PROFILE: str | None = None
......@@ -325,16 +326,11 @@ def use_aot_compile() -> bool:
from vllm.model_executor.layers.batch_invariant import (
vllm_is_batch_invariant,
)
from vllm.platforms import current_platform
from vllm.utils.torch_utils import is_torch_equal_or_newer
default_value = (
"1"
if is_torch_equal_or_newer("2.10.0.dev")
and not disable_compile_cache()
# Disabling AOT_COMPILE for CPU
# See: https://github.com/vllm-project/vllm/issues/32033
and not current_platform.is_cpu()
if is_torch_equal_or_newer("2.10.0.dev") and not disable_compile_cache()
else "0"
)
......@@ -823,6 +819,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
),
# Backend for Video IO
# - "opencv": Default backend that uses OpenCV stream buffered backend.
# - "identity": Returns raw video bytes for model processor to handle.
#
# Custom backend implementations can be registered
# via `@VIDEO_LOADER_REGISTRY.register("my_custom_video_loader")` and
......@@ -914,6 +911,13 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_LORA_RESOLVER_CACHE_DIR": lambda: os.getenv(
"VLLM_LORA_RESOLVER_CACHE_DIR", None
),
# A remote HF repo(s) containing one or more LoRA adapters, which
# may be downloaded and leveraged as needed. Only works if plugins
# are enabled and VLLM_ALLOW_RUNTIME_LORA_UPDATING is enabled.
# Values should be comma separated.
"VLLM_LORA_RESOLVER_HF_REPO_LIST": lambda: os.getenv(
"VLLM_LORA_RESOLVER_HF_REPO_LIST", None
),
# Enables torch CUDA profiling if set to 1.
# Deprecated, see profiler_config.
"VLLM_TORCH_CUDA_PROFILE": lambda: os.getenv("VLLM_TORCH_CUDA_PROFILE"),
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.logging_utils.access_log_filter import (
UvicornAccessLogFilter,
create_uvicorn_log_config,
)
from vllm.logging_utils.formatter import ColoredFormatter, NewLineFormatter
from vllm.logging_utils.lazy import lazy
from vllm.logging_utils.log_time import logtime
......@@ -8,6 +12,8 @@ from vllm.logging_utils.log_time import logtime
__all__ = [
"NewLineFormatter",
"ColoredFormatter",
"UvicornAccessLogFilter",
"create_uvicorn_log_config",
"lazy",
"logtime",
]
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Access log filter for uvicorn to exclude specific endpoints from logging.
This module provides a logging filter that can be used to suppress access logs
for specific endpoints (e.g., /health, /metrics) to reduce log noise in
production environments.
"""
import logging
from urllib.parse import urlparse
class UvicornAccessLogFilter(logging.Filter):
"""
A logging filter that excludes access logs for specified endpoint paths.
This filter is designed to work with uvicorn's access logger. It checks
the log record's arguments for the request path and filters out records
matching the excluded paths.
Uvicorn access log format:
'%s - "%s %s HTTP/%s" %d'
(client_addr, method, path, http_version, status_code)
Example:
127.0.0.1:12345 - "GET /health HTTP/1.1" 200
Args:
excluded_paths: A list of URL paths to exclude from logging.
Paths are matched exactly.
Example: ["/health", "/metrics"]
"""
def __init__(self, excluded_paths: list[str] | None = None):
super().__init__()
self.excluded_paths = set(excluded_paths or [])
def filter(self, record: logging.LogRecord) -> bool:
"""
Determine if the log record should be logged.
Args:
record: The log record to evaluate.
Returns:
True if the record should be logged, False otherwise.
"""
if not self.excluded_paths:
return True
# This filter is specific to uvicorn's access logs.
if record.name != "uvicorn.access":
return True
# The path is the 3rd argument in the log record's args tuple.
# See uvicorn's access logging implementation for details.
log_args = record.args
if isinstance(log_args, tuple) and len(log_args) >= 3:
path_with_query = log_args[2]
# Get path component without query string.
if isinstance(path_with_query, str):
path = urlparse(path_with_query).path
if path in self.excluded_paths:
return False
return True
def create_uvicorn_log_config(
excluded_paths: list[str] | None = None,
log_level: str = "info",
) -> dict:
"""
Create a uvicorn logging configuration with access log filtering.
This function generates a logging configuration dictionary that can be
passed to uvicorn's `log_config` parameter. It sets up the access log
filter to exclude specified paths.
Args:
excluded_paths: List of URL paths to exclude from access logs.
log_level: The log level for uvicorn loggers.
Returns:
A dictionary containing the logging configuration.
Example:
>>> config = create_uvicorn_log_config(["/health", "/metrics"])
>>> uvicorn.run(app, log_config=config)
"""
config = {
"version": 1,
"disable_existing_loggers": False,
"filters": {
"access_log_filter": {
"()": UvicornAccessLogFilter,
"excluded_paths": excluded_paths or [],
},
},
"formatters": {
"default": {
"()": "uvicorn.logging.DefaultFormatter",
"fmt": "%(levelprefix)s %(message)s",
"use_colors": None,
},
"access": {
"()": "uvicorn.logging.AccessFormatter",
"fmt": '%(levelprefix)s %(client_addr)s - "%(request_line)s" %(status_code)s', # noqa: E501
},
},
"handlers": {
"default": {
"formatter": "default",
"class": "logging.StreamHandler",
"stream": "ext://sys.stderr",
},
"access": {
"formatter": "access",
"class": "logging.StreamHandler",
"stream": "ext://sys.stdout",
"filters": ["access_log_filter"],
},
},
"loggers": {
"uvicorn": {
"handlers": ["default"],
"level": log_level.upper(),
"propagate": False,
},
"uvicorn.error": {
"level": log_level.upper(),
"handlers": ["default"],
"propagate": False,
},
"uvicorn.access": {
"handlers": ["access"],
"level": log_level.upper(),
"propagate": False,
},
},
}
return config
......@@ -103,7 +103,14 @@ def run_cutlass_moe_fp8(
or a2_scale.size(0) == a1q.shape[0]
), "Intermediate scale shape mismatch"
assert out_dtype in [torch.half, torch.bfloat16], "Invalid output dtype"
if expert_map is not None:
# NOTE(rob): the expert_map is used for the STANDARD case and
# the batched format is used by the BATCHED case.
# TODO(rob): update the MK interface to only pass the expert_map
# during the STANDARD case to make this clearer across all kernels.
if use_batched_format:
assert expert_num_tokens is not None
else:
assert expert_num_tokens is None
# We have two modes: batched experts and non-batched experts.
......@@ -379,7 +386,10 @@ class CutlassExpertsFp8(CutlassExpertsFp8Base):
# needed for STANDARD activation format kernels in DP/EP mode.
# Note that the BATCHED activation format does not use
# the expert map for identifying experts.
return not moe_parallel_config.use_all2all_kernels
return not (
moe_parallel_config.use_fi_all2allv_kernels
or moe_parallel_config.use_deepep_ht_kernels
)
def supports_chunking(self) -> bool:
return True
......@@ -641,10 +651,8 @@ def run_cutlass_moe_fp4(
class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
@staticmethod
def expects_unquantized_inputs(
moe_config: FusedMoEConfig, quant_config: FusedMoEQuantConfig
) -> bool:
@property
def expects_unquantized_inputs(self) -> bool:
return True
@staticmethod
......
......@@ -148,7 +148,8 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
@staticmethod
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
return True
# NOTE(rob): discovered an IMA with this combination. Needs investigation.
return not moe_parallel_config.use_fi_all2allv_kernels
def supports_chunking(self) -> bool:
return True
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment