Commit a810671a authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.14.0rc0' into v0.14.0rc0-ori

parents 86b5aefe 6a09612b
......@@ -1654,13 +1654,23 @@ class ResponsesResponse(OpenAIBaseModel):
usage: ResponseUsage | None = None
user: str | None = None
# --8<-- [start:responses-extra-params]
# --8<-- [start:responses-response-extra-params]
# These are populated when enable_response_messages is set to True
# NOTE: custom serialization is needed
# see serialize_input_messages and serialize_output_messages
input_messages: ResponseInputOutputMessage | None = None
output_messages: ResponseInputOutputMessage | None = None
# --8<-- [end:responses-extra-params]
input_messages: ResponseInputOutputMessage | None = Field(
default=None,
description=(
"If enable_response_messages, we can show raw token input to model."
),
)
output_messages: ResponseInputOutputMessage | None = Field(
default=None,
description=(
"If enable_response_messages, we can show raw token output of model."
),
)
# --8<-- [end:responses-response-extra-params]
# NOTE: openAI harmony doesn't serialize TextContent properly,
# TODO: this fixes for TextContent, but need to verify for tools etc
......@@ -2054,6 +2064,9 @@ class TranscriptionRequest(OpenAIBaseModel):
presence_penalty: float | None = 0.0
"""The presence penalty to use for sampling."""
max_completion_tokens: int | None = None
"""The maximum number of tokens to generate."""
# --8<-- [end:transcription-sampling-params]
# Default sampling parameters for transcription requests.
......@@ -2300,6 +2313,9 @@ class TranslationRequest(OpenAIBaseModel):
# Flattened stream option to simplify form data.
stream_include_usage: bool | None = False
stream_continuous_usage_stats: bool | None = False
max_completion_tokens: int | None = None
"""The maximum number of tokens to generate."""
# --8<-- [end:translation-extra-params]
# Default sampling parameters for translation requests.
......
......@@ -162,6 +162,55 @@ class OpenAIServingChat(OpenAIServing):
self.supports_code_interpreter = False
self.python_tool = None
async def warmup(self) -> None:
"""
Warm up the chat template processing to avoid first-request latency.
This method triggers Jinja2 template compilation and content format
detection that would otherwise happen on the first real request,
causing increased latency on the first request.
"""
logger.info("Warming up chat template processing...")
start_time = time.perf_counter()
try:
# Get the tokenizer from the engine
tokenizer = await self.engine_client.get_tokenizer()
# Create a minimal dummy request
dummy_request = ChatCompletionRequest(
messages=[{"role": "user", "content": "warmup"}],
model=None,
max_completion_tokens=1,
)
# Call _preprocess_chat to trigger template compilation
# This forces:
# 1. Chat template content format detection
# 2. Jinja2 template compilation
# 3. Tokenizer initialization for chat
await self._preprocess_chat(
dummy_request,
tokenizer,
dummy_request.messages,
chat_template=self.chat_template,
chat_template_content_format=self.chat_template_content_format,
add_generation_prompt=True,
continue_final_message=False,
tool_dicts=None,
documents=None,
chat_template_kwargs=None,
tool_parser=None,
add_special_tokens=False,
)
elapsed = (time.perf_counter() - start_time) * 1000
logger.info("Chat template warmup completed in %.1fms", elapsed)
except Exception:
# Log but don't fail server startup if warmup fails
logger.exception("Chat template warmup failed")
async def create_chat_completion(
self,
request: ChatCompletionRequest,
......@@ -250,7 +299,10 @@ class OpenAIServingChat(OpenAIServing):
)
else:
# For GPT-OSS.
conversation, engine_prompts = self._make_request_with_harmony(request)
should_include_tools = tool_dicts is not None
conversation, engine_prompts = self._make_request_with_harmony(
request, should_include_tools
)
except (ValueError, TypeError, RuntimeError, jinja2.TemplateError) as e:
logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(f"{e} {e.__cause__}")
......@@ -332,6 +384,7 @@ class OpenAIServingChat(OpenAIServing):
lora_request=lora_request,
trace_headers=trace_headers,
priority=request.priority,
data_parallel_rank=data_parallel_rank,
)
generator = self.engine_client.generate(
......@@ -1783,6 +1836,7 @@ class OpenAIServingChat(OpenAIServing):
def _make_request_with_harmony(
self,
request: ChatCompletionRequest,
should_include_tools: bool = True,
):
messages: list[OpenAIMessage] = []
......@@ -1800,12 +1854,14 @@ class OpenAIServingChat(OpenAIServing):
reasoning_effort=request.reasoning_effort,
browser_description=None,
python_description=None,
with_custom_tools=request.tools is not None,
with_custom_tools=should_include_tools,
)
messages.append(sys_msg)
# Add developer message.
dev_msg = get_developer_message(tools=request.tools)
dev_msg = get_developer_message(
tools=request.tools if should_include_tools else None
)
messages.append(dev_msg)
# Add user message.
......
......@@ -230,6 +230,7 @@ class OpenAIServingCompletion(OpenAIServing):
lora_request=lora_request,
trace_headers=trace_headers,
priority=request.priority,
data_parallel_rank=data_parallel_rank,
)
generator = self.engine_client.generate(
......
......@@ -1231,6 +1231,7 @@ class OpenAIServing:
lora_request: LoRARequest | None,
trace_headers: Mapping[str, str] | None,
priority: int,
data_parallel_rank: int | None = None,
) -> tuple[EngineCoreRequest, dict[str, Any]]:
"""Use the Processor to process inputs for AsyncLLM."""
tokenization_kwargs: dict[str, Any] = {}
......@@ -1246,6 +1247,7 @@ class OpenAIServing:
tokenization_kwargs=tokenization_kwargs,
trace_headers=trace_headers,
priority=priority,
data_parallel_rank=data_parallel_rank,
)
return engine_request, tokenization_kwargs
......
......@@ -104,7 +104,6 @@ from vllm.entrypoints.responses_utils import (
construct_input_messages,
construct_tool_dicts,
extract_tool_types,
make_response_output_items_from_parsable_context,
)
from vllm.entrypoints.tool_server import ToolServer
from vllm.inputs.data import TokensPrompt
......@@ -658,24 +657,19 @@ class OpenAIServingResponses(OpenAIServing):
else:
status = "incomplete"
elif isinstance(context, ParsableContext):
response_messages = context.parser.response_messages[
context.parser.num_init_messages :
]
output = make_response_output_items_from_parsable_context(response_messages)
output = context.parser.make_response_output_items_from_parsable_context()
# TODO: context for non-gptoss models doesn't use messages
# so we can't get them out yet
if request.enable_response_messages:
raise NotImplementedError(
"enable_response_messages is currently only supported for gpt-oss"
)
input_messages = context.input_messages
output_messages = context.output_messages
# TODO: Calculate usage.
# assert final_res.prompt_token_ids is not None
num_tool_output_tokens = 0
else:
assert isinstance(context, SimpleContext)
final_res = context.last_output
# Use final_output which has accumulated text/token_ids/logprobs
final_res = context.final_output
assert final_res is not None
assert len(final_res.outputs) == 1
final_output = final_res.outputs[0]
......
......@@ -35,7 +35,7 @@ from vllm.entrypoints.openai.serving_engine import OpenAIServing, SpeechToTextRe
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.inputs.data import PromptType
from vllm.logger import init_logger
from vllm.model_executor.models import SupportsTranscription
from vllm.model_executor.models import SupportsTranscription, supports_transcription
from vllm.outputs import RequestOutput
from vllm.tokenizers import get_tokenizer
from vllm.utils.import_utils import PlaceholderModule
......@@ -112,6 +112,131 @@ class OpenAISpeechToText(OpenAIServing):
self.default_sampling_params,
)
# Warm up audio preprocessing to avoid first-request latency
self._warmup_audio_preprocessing()
# Warm up input processor with dummy audio
self._warmup_input_processor()
def _warmup_audio_preprocessing(self) -> None:
"""Warm up audio processing libraries to avoid first-request latency.
The first call to librosa functions (load, get_duration, mel-spectrogram)
triggers JIT compilation and library initialization which can take ~7s.
This method warms up these operations during server initialization.
"""
# Skip warmup if librosa is not installed (optional dependency)
if isinstance(librosa, PlaceholderModule):
return
# Skip warmup if model doesn't support transcription
if not supports_transcription(self.model_cls):
return
try:
warmup_start = time.perf_counter()
logger.info("Warming up audio preprocessing libraries...")
# Create a minimal dummy audio (1 second of silence at target sample rate)
dummy_audio = np.zeros(int(self.asr_config.sample_rate), dtype=np.float32)
# Warm up librosa.load by using librosa functions on the dummy data
# This initializes FFTW, numba JIT, and other audio processing libraries
_ = librosa.get_duration(y=dummy_audio, sr=self.asr_config.sample_rate)
# Warm up mel-spectrogram computation with model-specific parameters
from vllm.transformers_utils.processor import (
cached_processor_from_config,
)
processor = cached_processor_from_config(self.model_config)
feature_extractor = None
if hasattr(processor, "feature_extractor"):
feature_extractor = processor.feature_extractor
elif hasattr(processor, "audio_processor"):
# For models like GraniteSpeech that use audio_processor
audio_proc = processor.audio_processor
if hasattr(audio_proc, "feature_extractor"):
feature_extractor = audio_proc.feature_extractor
# If audio_processor doesn't have feature_extractor,
# skip mel-spectrogram warmup for these models
if feature_extractor is not None:
_ = librosa.feature.melspectrogram(
y=dummy_audio,
sr=self.asr_config.sample_rate,
n_mels=getattr(feature_extractor, "n_mels", 128),
n_fft=getattr(feature_extractor, "n_fft", 400),
hop_length=getattr(feature_extractor, "hop_length", 160),
)
warmup_elapsed = time.perf_counter() - warmup_start
logger.info("Audio preprocessing warmup completed in %.2fs", warmup_elapsed)
except Exception:
# Don't fail initialization if warmup fails - log exception and continue
logger.exception(
"Audio preprocessing warmup failed (non-fatal): %s. "
"First request may experience higher latency.",
)
def _warmup_input_processor(self) -> None:
"""Warm up input processor with dummy audio to avoid first-request latency.
The first call to input_processor.process_inputs() with multimodal audio
triggers multimodal processing initialization which can take ~2.5s.
This method processes a dummy audio request to warm up the pipeline.
"""
# Skip warmup if model doesn't support transcription
if not supports_transcription(self.model_cls):
return
# Only warm up if model supports transcription methods
if not hasattr(self.model_cls, "get_generation_prompt"):
return
try:
from vllm.sampling_params import SamplingParams
warmup_start = time.perf_counter()
logger.info("Warming up multimodal input processor...")
# Create minimal dummy audio (1 second of silence)
dummy_audio = np.zeros(int(self.asr_config.sample_rate), dtype=np.float32)
# Use the same method that _preprocess_speech_to_text uses
# to create the prompt
dummy_prompt = self.model_cls.get_generation_prompt(
audio=dummy_audio,
stt_config=self.asr_config,
model_config=self.model_config,
language="en",
task_type=self.task_type,
request_prompt="",
to_language=None,
)
# Create minimal sampling params
dummy_params = SamplingParams(
max_tokens=1,
temperature=0.0,
)
# Process the dummy input through the input processor
# This will trigger all the multimodal processing initialization
_ = self.input_processor.process_inputs(
request_id="warmup",
prompt=dummy_prompt,
params=dummy_params,
)
warmup_elapsed = time.perf_counter() - warmup_start
logger.info("Input processor warmup completed in %.2fs", warmup_elapsed)
except Exception:
# Don't fail initialization if warmup fails - log warning and continue
logger.exception(
"Input processor warmup failed (non-fatal): %s. "
"First request may experience higher latency."
)
@cached_property
def model_cls(self) -> type[SupportsTranscription]:
from vllm.model_executor.model_loader import get_model_cls
......@@ -293,8 +418,14 @@ class OpenAISpeechToText(OpenAIServing):
try:
# Unlike most decoder-only models, whisper generation length is not
# constrained by the size of the input audio, which is mapped to a
# fixed-size log-mel-spectogram.
default_max_tokens = self.model_config.max_model_len
# fixed-size log-mel-spectogram. Still, allow for fewer tokens to be
# generated by respecting the extra completion tokens arg.
if request.max_completion_tokens is None:
default_max_tokens = self.model_config.max_model_len
else:
default_max_tokens = min(
self.model_config.max_model_len, request.max_completion_tokens
)
sampling_params = request.to_sampling_params(
default_max_tokens, self.default_sampling_params
)
......
......@@ -16,7 +16,6 @@ from openai.types.responses.response import ToolChoice
from openai.types.responses.response_function_tool_call_output_item import (
ResponseFunctionToolCallOutputItem,
)
from openai.types.responses.response_output_item import McpCall
from openai.types.responses.response_output_message import ResponseOutputMessage
from openai.types.responses.response_reasoning_item import ResponseReasoningItem
from openai.types.responses.tool import Tool
......@@ -27,38 +26,6 @@ from vllm.entrypoints.openai.protocol import (
ChatCompletionMessageParam,
ResponseInputOutputItem,
)
from vllm.utils import random_uuid
def make_response_output_items_from_parsable_context(
response_messages: list[ResponseInputOutputItem],
) -> list[ResponseOutputItem]:
"""Given a list of sentences, construct ResponseOutput Items."""
output_messages: list[ResponseOutputItem] = []
for message in response_messages:
if not isinstance(message, ResponseFunctionToolCallOutputItem):
output_messages.append(message)
else:
if len(output_messages) == 0:
raise ValueError(
"Cannot have a FunctionToolCallOutput before FunctionToolCall."
)
if isinstance(output_messages[-1], ResponseFunctionToolCall):
mcp_message = McpCall(
id=f"{MCP_PREFIX}{random_uuid()}",
arguments=output_messages[-1].arguments,
name=output_messages[-1].name,
server_label=output_messages[
-1
].name, # TODO: store the server label
type=f"{MCP_PREFIX}call",
status="completed",
output=message.output,
# TODO: support error output
)
output_messages[-1] = mcp_message
return output_messages
def construct_input_messages(
......
......@@ -4,8 +4,19 @@
from fastapi import FastAPI
import vllm.envs as envs
from vllm.logger import init_logger
logger = init_logger(__name__)
def register_vllm_serve_api_routers(app: FastAPI):
if envs.VLLM_SERVER_DEV_MODE:
logger.warning(
"SECURITY WARNING: Development endpoints are enabled! "
"This should NOT be used in production!"
)
from vllm.entrypoints.serve.lora.api_router import (
attach_router as attach_lora_router,
)
......@@ -29,6 +40,18 @@ def register_vllm_serve_api_routers(app: FastAPI):
attach_sleep_router(app)
from vllm.entrypoints.serve.rpc.api_router import (
attach_router as attach_rpc_router,
)
attach_rpc_router(app)
from vllm.entrypoints.serve.cache.api_router import (
attach_router as attach_cache_router,
)
attach_cache_router(app)
from vllm.entrypoints.serve.tokenize.api_router import (
attach_router as attach_tokenize_router,
)
......@@ -58,3 +81,9 @@ def register_vllm_serve_api_routers(app: FastAPI):
)
attach_health_router(app)
from vllm.entrypoints.serve.instrumentator.server_info import (
attach_router as attach_server_info_router,
)
attach_server_info_router(app)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from fastapi import APIRouter, FastAPI, Query, Request
from fastapi.responses import Response
import vllm.envs as envs
from vllm.engine.protocol import EngineClient
from vllm.logger import init_logger
logger = init_logger(__name__)
router = APIRouter()
def engine_client(request: Request) -> EngineClient:
return request.app.state.engine_client
@router.post("/reset_prefix_cache")
async def reset_prefix_cache(
raw_request: Request,
reset_running_requests: bool = Query(default=False),
reset_external: bool = Query(default=False),
):
"""
Reset the local prefix cache.
Optionally, if the query parameter `reset_external=true`
also resets the external (connector-managed) prefix cache.
Note that we currently do not check if the prefix cache
is successfully reset in the API server.
Example:
POST /reset_prefix_cache?reset_external=true
"""
logger.info("Resetting prefix cache...")
await engine_client(raw_request).reset_prefix_cache(
reset_running_requests, reset_external
)
return Response(status_code=200)
@router.post("/reset_mm_cache")
async def reset_mm_cache(raw_request: Request):
"""
Reset the multi-modal cache. Note that we currently do not check if the
multi-modal cache is successfully reset in the API server.
"""
logger.info("Resetting multi-modal cache...")
await engine_client(raw_request).reset_mm_cache()
return Response(status_code=200)
def attach_router(app: FastAPI):
if not envs.VLLM_SERVER_DEV_MODE:
return
app.include_router(router)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Annotated, Literal
import pydantic
from fastapi import APIRouter, FastAPI, Query, Request
from fastapi.responses import JSONResponse
import vllm.envs as envs
from vllm.config import VllmConfig
from vllm.logger import init_logger
logger = init_logger(__name__)
router = APIRouter()
PydanticVllmConfig = pydantic.TypeAdapter(VllmConfig)
@router.get("/server_info")
async def show_server_info(
raw_request: Request,
config_format: Annotated[Literal["text", "json"], Query()] = "text",
):
vllm_config: VllmConfig = raw_request.app.state.vllm_config
server_info = {
"vllm_config": str(vllm_config)
if config_format == "text"
else PydanticVllmConfig.dump_python(vllm_config, mode="json", fallback=str)
# fallback=str is needed to handle e.g. torch.dtype
}
return JSONResponse(content=server_info)
def attach_router(app: FastAPI):
if not envs.VLLM_SERVER_DEV_MODE:
return
app.include_router(router)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
from http import HTTPStatus
from typing import Any
from fastapi import APIRouter, FastAPI, HTTPException, Request
from fastapi.responses import JSONResponse, Response
import vllm.envs as envs
from vllm.engine.protocol import EngineClient
from vllm.logger import init_logger
logger = init_logger(__name__)
router = APIRouter()
def engine_client(request: Request) -> EngineClient:
return request.app.state.engine_client
@router.post("/collective_rpc")
async def collective_rpc(raw_request: Request):
try:
body = await raw_request.json()
except json.JSONDecodeError as e:
raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST.value,
detail=f"JSON decode error: {e}",
) from e
method = body.get("method")
if method is None:
raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST.value,
detail="Missing 'method' in request body",
)
# For security reason, only serialized string args/kwargs are passed.
# User-defined `method` is responsible for deserialization if needed.
args: list[str] = body.get("args", [])
kwargs: dict[str, str] = body.get("kwargs", {})
timeout: float | None = body.get("timeout")
results = await engine_client(raw_request).collective_rpc(
method=method, timeout=timeout, args=tuple(args), kwargs=kwargs
)
if results is None:
return Response(status_code=200)
response: list[Any] = []
for result in results:
if result is None or isinstance(result, dict | list):
response.append(result)
else:
response.append(str(result))
return JSONResponse(content={"results": response})
def attach_router(app: FastAPI):
if not envs.VLLM_SERVER_DEV_MODE:
return
app.include_router(router)
......@@ -52,9 +52,5 @@ async def is_sleeping(raw_request: Request):
def attach_router(app: FastAPI):
if not envs.VLLM_SERVER_DEV_MODE:
return
logger.warning(
"SECURITY WARNING: Development endpoints are enabled! "
"This should NOT be used in production!"
)
app.include_router(router)
......@@ -207,7 +207,7 @@ if TYPE_CHECKING:
VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL: bool = False
VLLM_ENABLE_CUDAGRAPH_GC: bool = False
VLLM_LOOPBACK_IP: str = ""
VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE: bool = False
VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE: bool = True
VLLM_ENABLE_RESPONSES_API_STORE: bool = False
VLLM_USE_TRTLLM_ATTENTION: str | None = None
VLLM_NVFP4_GEMM_BACKEND: str | None = None
......@@ -244,6 +244,7 @@ if TYPE_CHECKING:
VLLM_SHARED_EXPERTS_STREAM_TOKEN_THRESHOLD: int = 256
VLLM_COMPILE_CACHE_SAVE_FORMAT: Literal["binary", "unpacked"] = "binary"
VLLM_USE_V2_MODEL_RUNNER: bool = False
VLLM_DEBUG_MFU_METRICS: bool = False
VLLM_USE_FLASH_MLA: bool = False
......@@ -1263,7 +1264,8 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_MOONCAKE_BOOTSTRAP_PORT": lambda: int(
os.getenv("VLLM_MOONCAKE_BOOTSTRAP_PORT", "8998")
),
# all2all backend for vllm's expert parallel communication
# [DEPRECATED - will be removed in v0.15.0] all2all backend for vllm's
# expert parallel communication. Use --all2all-backend CLI argument instead.
# Available options:
# - "naive": naive all2all implementation using broadcasts
# - "allgather_reducescatter": all2all implementation based on allgather and
......@@ -1274,7 +1276,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
# - "flashinfer_all2allv", use flashinfer alltoallv kernels for mnnvl
"VLLM_ALL2ALL_BACKEND": env_with_choices(
"VLLM_ALL2ALL_BACKEND",
"allgather_reducescatter",
None,
[
"naive",
"pplx",
......@@ -1431,7 +1433,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
# kv-cache memory usage and enable longer contexts)
# TODO(lucas): Remove this flag once latency regression is resolved.
"VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE": lambda: bool(
int(os.getenv("VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE", "0"))
int(os.getenv("VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE", "1"))
),
# Enables support for the "store" option in the OpenAI Responses API.
# When set to 1, vLLM's OpenAI server will retain the input and output
......@@ -1566,6 +1568,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_V2_MODEL_RUNNER": lambda: bool(
int(os.getenv("VLLM_USE_V2_MODEL_RUNNER", "0"))
),
# Debug logging for --enable-mfu-metrics
"VLLM_DEBUG_MFU_METRICS": lambda: bool(
int(os.getenv("VLLM_DEBUG_MFU_METRICS", "0"))
),
# If set, vLLM will use FLASH MLA attention optimizations.
"VLLM_USE_FLASH_MLA":
lambda: bool(int(os.getenv("VLLM_USE_FLASH_MLA", "0"))),
......@@ -1658,6 +1664,7 @@ def compile_factors() -> dict[str, object]:
"VLLM_CI_USE_S3",
"VLLM_MODEL_REDIRECT_PATH",
"VLLM_HOST_IP",
"VLLM_FORCE_AOT_LOAD",
"S3_ACCESS_KEY_ID",
"S3_SECRET_ACCESS_KEY",
"S3_ENDPOINT_URL",
......
......@@ -45,16 +45,17 @@ def parse_raw_prompts(
# case 4: array of token arrays
if is_list_of(prompt, list):
first = prompt[0]
if not isinstance(first, list):
raise ValueError("prompt expected to be a list of lists")
if len(first) == 0:
raise ValueError("Please provide at least one prompt")
# strict validation: every nested list must be list[int]
if not all(is_list_of(elem, int) for elem in prompt):
raise TypeError("Nested lists must contain only integers")
if len(prompt) == 1 and isinstance(prompt[0], list) and len(prompt[0]) == 0:
raise ValueError("please provide at least one prompt")
for elem in prompt:
if not isinstance(elem, list):
raise TypeError(
"prompt must be a list of lists, but found a non-list element."
)
if not is_list_of(elem, int):
raise TypeError(
"Nested lists of tokens must contain only integers."
)
prompt = cast(list[list[int]], prompt)
return [TokensPrompt(prompt_token_ids=elem) for elem in prompt]
......
......@@ -156,16 +156,22 @@ def _fused_moe_lora_kernel(
+ offs_bn[None, :] * stride_bn
)
if USE_GDC and IS_PRIMARY:
# GDC launch dependents hints the runtime system to launch dependent kernels.
tl.extra.cuda.gdc_launch_dependents()
# accumulator
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
# GDC wait waits for ALL programs in the prior kernel to complete
# before continuing.
if USE_GDC and not IS_PRIMARY:
tl.extra.cuda.gdc_wait()
for k in range(0, grid_k):
k_remaining = K - k * (BLOCK_SIZE_K * SPLIT_K)
# pre-fetch lora weight
b = tl.load(b_ptrs, mask=offs_k[:, None] < k_remaining, other=0.0)
# GDC wait waits for ALL programs in the prior kernel to complete
# before continuing.
if USE_GDC and not IS_PRIMARY:
tl.extra.cuda.gdc_wait()
a = tl.load(
a_ptrs,
mask=token_mask[:, None] & (offs_k[None, :] < k_remaining),
......@@ -179,9 +185,6 @@ def _fused_moe_lora_kernel(
if MUL_ROUTED_WEIGHT:
moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
accumulator = accumulator * moe_weight[:, None]
if USE_GDC and IS_PRIMARY:
# GDC launch dependents hints the runtime system to launch dependent kernels.
tl.extra.cuda.gdc_launch_dependents()
accumulator = accumulator.to(c_ptr.dtype.element_ty)
# Write back the block of the output
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
......@@ -290,6 +293,7 @@ def _fused_moe_lora_shrink(
def _fused_moe_lora_expand(
output: torch.Tensor, # (num_tokens, top_k_num, N*len(lora_a_stacked),)
a_intermediate_cache1: torch.Tensor, # (num_slices, M, top_k_num, max_lora_rank)
b_intermediate_cache1: torch.Tensor, # (num_slices, M, top_k_num, output_dim_size)
lora_b_stacked: list[
torch.Tensor
], # [(max_loras, num_experts, max_lora_rank, K,),...]
......@@ -331,11 +335,6 @@ def _fused_moe_lora_expand(
-1, a_intermediate_cache1.shape[3]
)
b_intermediate_cache1 = torch.zeros(
(num_slices, M, top_k_num, w1_output_dim_size),
dtype=output.dtype,
device=device,
)
use_gdc = supports_pdl(a_intermediate_cache1.device)
expand_config = {
"BLOCK_SIZE_M": block_size_m,
......@@ -460,6 +459,12 @@ def _fused_moe_lora(
device=device,
)
b_intermediate_cache1 = torch.zeros(
(num_slices, M, top_k_num, w1_output_dim_size),
dtype=output.dtype,
device=device,
)
_fused_moe_lora_shrink(
a_intermediate_cache1,
qcurr_hidden_states,
......@@ -506,6 +511,7 @@ def _fused_moe_lora(
_fused_moe_lora_expand(
output,
a_intermediate_cache1,
b_intermediate_cache1,
lora_b_stacked,
topk_weights,
sorted_token_ids,
......
......@@ -14,11 +14,6 @@ class LoRARequest(
"""
Request for a LoRA adapter.
Note that this class should be used internally. For online
serving, it is recommended to not allow users to use this class but
instead provide another layer of abstraction to prevent users from
accessing unauthorized LoRA adapters.
lora_int_id must be globally unique for a given adapter.
This is currently not enforced in vLLM.
"""
......
......@@ -933,30 +933,26 @@ def enable_batch_invariant_mode():
_batch_invariant_MODE = True
_batch_invariant_LIB = torch.library.Library("aten", "IMPL")
# Batch invariant matmuls are no longer needed after cublas overrides
if not is_torch_equal_or_newer("2.10.0.dev"):
if (
current_platform.is_device_capability_family(100)
or current_platform.is_device_capability(80)
or current_platform.is_device_capability(89)
):
# For PyTorch 2.9, B200 uses GEMV for bs=1
# Requires https://github.com/pytorch/pytorch/pull/166735
_batch_invariant_LIB.impl("aten::mm", mm_batch_invariant, "CUDA")
_batch_invariant_LIB.impl("aten::addmm", addmm_batch_invariant, "CUDA")
_batch_invariant_LIB.impl("aten::matmul", matmul_batch_invariant, "CUDA")
_batch_invariant_LIB.impl("aten::linear", linear_batch_invariant, "CUDA")
else:
# Only source of batch invariance for Hopper is split-k, can disable through
# cuBLAS workspace config
_original_cublas_workspace_cfg = os.environ.get(
"CUBLAS_WORKSPACE_CONFIG", None
)
_original_cublaslt_workspace_size = os.environ.get(
"CUBLASLT_WORKSPACE_SIZE", None
)
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
os.environ["CUBLASLT_WORKSPACE_SIZE"] = "1"
if (
current_platform.is_device_capability_family(100)
or current_platform.is_device_capability(80)
or current_platform.is_device_capability(89)
):
# For PyTorch 2.9, B200 uses GEMV for bs=1
# Requires https://github.com/pytorch/pytorch/pull/166735
_batch_invariant_LIB.impl("aten::mm", mm_batch_invariant, "CUDA")
_batch_invariant_LIB.impl("aten::addmm", addmm_batch_invariant, "CUDA")
_batch_invariant_LIB.impl("aten::matmul", matmul_batch_invariant, "CUDA")
_batch_invariant_LIB.impl("aten::linear", linear_batch_invariant, "CUDA")
else:
# Only source of batch invariance for Hopper is split-k, can disable through
# cuBLAS workspace config
_original_cublas_workspace_cfg = os.environ.get("CUBLAS_WORKSPACE_CONFIG", None)
_original_cublaslt_workspace_size = os.environ.get(
"CUBLASLT_WORKSPACE_SIZE", None
)
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
os.environ["CUBLASLT_WORKSPACE_SIZE"] = "1"
_batch_invariant_LIB.impl(
"aten::_log_softmax", _log_softmax_batch_invariant, "CUDA"
......
......@@ -251,6 +251,6 @@ class Conv3dLayer(ConvLayerBase):
# See: https://github.com/vllm-project/vllm/issues/27406
# and https://github.com/pytorch/pytorch/issues/166122
# By default, we use CUDNN's convolution ops with optimization.
if self.enable_linear and is_torch_equal("2.9.0"):
if self.enable_linear and (is_torch_equal("2.9.0") or is_torch_equal("2.9.1")):
return self._forward_mulmat(x)
return self._forward_conv(x)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment