Commit 8d75f22e authored by zhuwenwen
Browse files

Merge tag 'v0.13.0rc1' into v0.13.0rc1-ori

parents ce888aa4 7d80c73d
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
from http import HTTPStatus
from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request
from fastapi.responses import JSONResponse
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.openai.api_server import validate_json_request
from vllm.entrypoints.openai.protocol import (
ErrorResponse,
)
from vllm.entrypoints.serve.elastic_ep.middleware import (
get_scaling_elastic_ep,
set_scaling_elastic_ep,
)
from vllm.logger import init_logger
# Module-level logger and the router that collects the elastic-EP endpoints.
logger = init_logger(__name__)
router = APIRouter()


def engine_client(request: Request) -> EngineClient:
    """Return the engine client stored on the application state."""
    app_state = request.app.state
    return app_state.engine_client
def _is_positive_int(value) -> bool:
    """Return True only for real integers greater than zero.

    ``bool`` is a subclass of ``int``, so JSON ``true`` would otherwise be
    accepted as ``1``; it is rejected explicitly.
    """
    return isinstance(value, int) and not isinstance(value, bool) and value > 0


@router.post(
    "/scale_elastic_ep",
    dependencies=[Depends(validate_json_request)],
    responses={
        HTTPStatus.OK.value: {"model": dict},
        HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
        HTTPStatus.REQUEST_TIMEOUT.value: {"model": ErrorResponse},
        HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
    },
)
async def scale_elastic_ep(raw_request: Request):
    """Scale the engine to a new data parallel size.

    The JSON body must be an object with:

    * ``new_data_parallel_size`` (required): positive integer.
    * ``drain_timeout`` (optional, default 120): positive integer that is
      passed to the engine client and reported in the timeout error.

    The module-level scaling flag is set for the duration of the call and
    always cleared afterwards.

    Raises:
        HTTPException: 400 for a malformed body or invalid parameters,
            408 when the engine client raises ``TimeoutError``, and 500 for
            any other failure.
    """
    try:
        body = await raw_request.json()
    except json.JSONDecodeError as e:
        raise HTTPException(status_code=400, detail="Invalid JSON format") from e

    # Valid JSON is not necessarily an object (it may be a list, a string,
    # ...); without this check ``body.get`` would fail with a 500.
    if not isinstance(body, dict):
        raise HTTPException(
            status_code=400, detail="Request body must be a JSON object"
        )

    new_data_parallel_size = body.get("new_data_parallel_size")
    drain_timeout = body.get("drain_timeout", 120)  # Default 2 minutes

    if new_data_parallel_size is None:
        raise HTTPException(
            status_code=400, detail="new_data_parallel_size is required"
        )

    if not _is_positive_int(new_data_parallel_size):
        raise HTTPException(
            status_code=400,
            detail="new_data_parallel_size must be a positive integer",
        )

    if not _is_positive_int(drain_timeout):
        raise HTTPException(
            status_code=400, detail="drain_timeout must be a positive integer"
        )

    # Set scaling flag to prevent new requests
    set_scaling_elastic_ep(True)
    try:
        # The client lookup happens inside the ``try`` so that a failure here
        # cannot leave the scaling flag set.
        client = engine_client(raw_request)
        await client.scale_elastic_ep(new_data_parallel_size, drain_timeout)
        return JSONResponse(
            {
                "message": f"Scaled to {new_data_parallel_size} data parallel engines",
            }
        )
    except TimeoutError as e:
        raise HTTPException(
            status_code=408,
            detail="Scale failed due to request drain timeout "
            f"after {drain_timeout} seconds",
        ) from e
    except Exception as e:
        # ``exception`` keeps the traceback in the log.
        logger.exception("Scale failed: %s", e)
        raise HTTPException(status_code=500, detail="Scale failed") from e
    finally:
        set_scaling_elastic_ep(False)
@router.post("/is_scaling_elastic_ep")
async def is_scaling_elastic_ep(raw_request: Request):
    """Return the current value of the scaling flag as a JSON object."""
    payload = {"is_scaling_elastic_ep": get_scaling_elastic_ep()}
    return JSONResponse(payload)
def attach_router(app: FastAPI):
    """Register this module's router on ``app``."""
    app.include_router(router)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Awaitable
from fastapi.responses import JSONResponse
from starlette.types import ASGIApp, Receive, Scope, Send
# Module-level flag recording whether a scaling operation is in progress.
_scaling_elastic_ep = False


def get_scaling_elastic_ep():
    """Return the current value of the module-level scaling flag."""
    return _scaling_elastic_ep


def set_scaling_elastic_ep(value):
    """Store ``value`` as the module-level scaling flag."""
    global _scaling_elastic_ep
    _scaling_elastic_ep = value
class ScalingMiddleware:
    """
    ASGI middleware that short-circuits HTTP requests while the global
    scaling flag is set.

    Non-HTTP scopes are always forwarded to the wrapped application. HTTP
    scopes are forwarded too, unless ``get_scaling_elastic_ep()`` is truthy,
    in which case a 503 Service Unavailable JSON response is sent instead.
    """

    def __init__(self, app: ASGIApp) -> None:
        self.app = app

    def __call__(self, scope: Scope, receive: Receive, send: Send) -> Awaitable[None]:
        is_http = scope["type"] == "http"
        # The flag is consulted for HTTP traffic only.
        if is_http and get_scaling_elastic_ep():
            unavailable = JSONResponse(
                content={
                    "error": "The model is currently scaling. Please try again later."
                },
                status_code=503,
            )
            return unavailable(scope, receive, send)

        return self.app(scope, receive, send)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from fastapi import APIRouter, Request
from fastapi.responses import Response
from vllm.engine.protocol import EngineClient
from vllm.logger import init_logger
from vllm.v1.engine.exceptions import EngineDeadError
# Module-level logger and the router that collects the health endpoint.
logger = init_logger(__name__)
router = APIRouter()


def engine_client(request: Request) -> EngineClient:
    """Return the engine client stored on the application state."""
    app_state = request.app.state
    return app_state.engine_client
@router.get("/health", response_class=Response)
async def health(raw_request: Request) -> Response:
    """Health check.

    Responds 200 when the engine client's ``check_health`` completes and
    503 when it raises ``EngineDeadError``.
    """
    try:
        await engine_client(raw_request).check_health()
    except EngineDeadError:
        status = 503
    else:
        status = 200
    return Response(status_code=status)
def attach_router(app):
    """Register the health router on ``app``."""
    app.include_router(router)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import prometheus_client
import regex as re
from fastapi import FastAPI, Response
from prometheus_client import make_asgi_app
from prometheus_fastapi_instrumentator import Instrumentator
from starlette.routing import Mount
from vllm.v1.metrics.prometheus import get_prometheus_registry
class PrometheusResponse(Response):
    """Response whose media type is the Prometheus client's latest content type."""

    media_type = prometheus_client.CONTENT_TYPE_LATEST
def attach_router(app: FastAPI):
    """Mount prometheus metrics to a FastAPI app."""
    registry = get_prometheus_registry()

    # Handlers that are excluded from instrumentation.
    uninstrumented_handlers = [
        "/metrics",
        "/health",
        "/load",
        "/ping",
        "/version",
        "/server_info",
    ]

    # `response_class=PrometheusResponse` is needed to return an HTTP response
    # with header "Content-Type: text/plain; version=0.0.4; charset=utf-8"
    # instead of the default "application/json" which is incorrect.
    # See https://github.com/trallnag/prometheus-fastapi-instrumentator/issues/163#issue-1296092364
    instrumentator = Instrumentator(
        excluded_handlers=uninstrumented_handlers,
        registry=registry,
    )
    instrumentator.add().instrument(app).expose(
        app, response_class=PrometheusResponse
    )

    # Add prometheus asgi middleware to route /metrics requests
    metrics_app = make_asgi_app(registry=registry)
    metrics_route = Mount("/metrics", metrics_app)

    # Workaround for 307 Redirect for /metrics
    metrics_route.path_regex = re.compile("^/metrics(?P<path>.*)$")
    app.routes.append(metrics_route)
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import model_hosting_container_standards.sagemaker as sagemaker_standards import model_hosting_container_standards.sagemaker as sagemaker_standards
from fastapi import APIRouter, Depends, Request from fastapi import APIRouter, Depends, FastAPI, Request
from fastapi.responses import JSONResponse, Response from fastapi.responses import JSONResponse, Response
from vllm import envs
from vllm.entrypoints.openai.api_server import models, validate_json_request from vllm.entrypoints.openai.api_server import models, validate_json_request
from vllm.entrypoints.openai.protocol import ( from vllm.entrypoints.openai.protocol import (
ErrorResponse, ErrorResponse,
...@@ -14,9 +17,18 @@ from vllm.entrypoints.openai.serving_models import OpenAIServingModels ...@@ -14,9 +17,18 @@ from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.logger import init_logger from vllm.logger import init_logger
logger = init_logger(__name__) logger = init_logger(__name__)
router = APIRouter()
def register_dynamic_lora_routes(router: APIRouter): def attach_router(app: FastAPI):
if not envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
"""If LoRA dynamic loading & unloading is not enabled, do nothing."""
return
logger.warning(
"LoRA dynamic loading & unloading is enabled in the API server. "
"This should ONLY be used for local development!"
)
@sagemaker_standards.register_load_adapter_handler( @sagemaker_standards.register_load_adapter_handler(
request_shape={ request_shape={
"lora_name": "body.name", "lora_name": "body.name",
...@@ -54,4 +66,5 @@ def register_dynamic_lora_routes(router: APIRouter): ...@@ -54,4 +66,5 @@ def register_dynamic_lora_routes(router: APIRouter):
return Response(status_code=200, content=response) return Response(status_code=200, content=response)
return router # register the router
app.include_router(router)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from fastapi import APIRouter, FastAPI, Request
from fastapi.responses import Response
from vllm.config import ProfilerConfig
from vllm.engine.protocol import EngineClient
from vllm.logger import init_logger
# Module-level logger and the router that collects the profiling endpoints.
logger = init_logger(__name__)
router = APIRouter()


def engine_client(request: Request) -> EngineClient:
    """Return the engine client stored on the application state."""
    app_state = request.app.state
    return app_state.engine_client
@router.post("/start_profile")
async def start_profile(raw_request: Request):
    """Ask the engine client to start profiling and reply 200 once it returns."""
    logger.info("Starting profiler...")
    client = engine_client(raw_request)
    await client.start_profile()
    logger.info("Profiler started.")
    return Response(status_code=200)
@router.post("/stop_profile")
async def stop_profile(raw_request: Request):
    """Ask the engine client to stop profiling and reply 200 once it returns."""
    logger.info("Stopping profiler...")
    client = engine_client(raw_request)
    await client.stop_profile()
    logger.info("Profiler stopped.")
    return Response(status_code=200)
def attach_router(app: FastAPI):
    """Register the profiling routes when a profiler is configured.

    Reads ``profiler_config`` from ``app.state.args`` (missing means
    ``None``). Nothing is registered when there is no config or when its
    ``profiler`` attribute is ``None``.
    """
    profiler_config = getattr(app.state.args, "profiler_config", None)
    assert profiler_config is None or isinstance(profiler_config, ProfilerConfig)

    # NOTE(review): route registration is treated as conditional on a
    # profiler being configured - confirm against the intended behavior.
    if profiler_config is None or profiler_config.profiler is None:
        return

    logger.warning_once(
        "Profiler with mode '%s' is enabled in the "
        "API server. This should ONLY be used for local development!",
        profiler_config.profiler,
    )
    app.include_router(router)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from http import HTTPStatus
from fastapi import APIRouter, FastAPI, Query, Request
from fastapi.responses import JSONResponse
from vllm.engine.protocol import EngineClient
from vllm.logger import init_logger
# Module-level logger and the router that collects the pause/resume endpoints.
logger = init_logger(__name__)
router = APIRouter()


def engine_client(request: Request) -> EngineClient:
    """Return the engine client stored on the application state."""
    app_state = request.app.state
    return app_state.engine_client
@router.post("/pause")
async def pause_generation(
    raw_request: Request,
    wait_for_inflight_requests: bool = Query(False),
    clear_cache: bool = Query(True),
) -> JSONResponse:
    """Pause generation requests to allow weight updates.

    Args:
        wait_for_inflight_requests: When ``True`` waits for in-flight
            requests to finish before pausing. When ``False`` (default),
            aborts any in-flight requests immediately.
        clear_cache: Whether to clear KV/prefix caches after draining.

    Returns:
        200 with ``{"status": "paused"}`` on success, 400 when the engine
        raises ``ValueError``, 500 for any other exception.
    """
    engine = engine_client(raw_request)

    try:
        await engine.pause_generation(
            wait_for_inflight_requests=wait_for_inflight_requests,
            clear_cache=clear_cache,
        )
    except ValueError as err:
        payload = {"error": str(err)}
        status = HTTPStatus.BAD_REQUEST.value
    except Exception as err:  # pragma: no cover - defensive
        logger.exception("Failed to pause generation")
        payload = {"error": f"Failed to pause generation: {err}"}
        status = HTTPStatus.INTERNAL_SERVER_ERROR.value
    else:
        payload = {"status": "paused"}
        status = HTTPStatus.OK.value

    return JSONResponse(content=payload, status_code=status)
@router.post("/resume")
async def resume_generation(raw_request: Request) -> JSONResponse:
    """Resume generation after a pause.

    Returns 200 with ``{"status": "resumed"}`` on success and 500 with an
    error message when the engine raises.
    """
    engine = engine_client(raw_request)

    try:
        await engine.resume_generation()
    except Exception as err:  # pragma: no cover - defensive
        logger.exception("Failed to resume generation")
        payload = {"error": f"Failed to resume generation: {err}"}
        status = HTTPStatus.INTERNAL_SERVER_ERROR.value
    else:
        payload = {"status": "resumed"}
        status = HTTPStatus.OK.value

    return JSONResponse(content=payload, status_code=status)
@router.get("/is_paused")
async def is_paused(raw_request: Request) -> JSONResponse:
    """Return the current pause status.

    Responds with ``{"is_paused": <value>}``, or with a 500 error payload
    when the engine raises while being queried.
    """
    engine = engine_client(raw_request)

    try:
        paused = await engine.is_paused()
    except Exception as err:  # pragma: no cover - defensive
        logger.exception("Failed to fetch pause status")
        failure = {"error": f"Failed to fetch pause status: {err}"}
        return JSONResponse(
            content=failure,
            status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value,
        )
    else:
        return JSONResponse(content={"is_paused": paused})
def attach_router(app: FastAPI):
    """Register the pause/resume router on ``app``."""
    app.include_router(router)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from fastapi import APIRouter, FastAPI, Request
from fastapi.responses import JSONResponse, Response
import vllm.envs as envs
from vllm.engine.protocol import EngineClient
from vllm.logger import init_logger
# Module-level logger and the router that collects the sleep/wake-up endpoints.
logger = init_logger(__name__)
router = APIRouter()


def engine_client(request: Request) -> EngineClient:
    """Return the engine client stored on the application state."""
    app_state = request.app.state
    return app_state.engine_client
@router.post("/sleep")
async def sleep(raw_request: Request):
    """Put the engine to sleep.

    The sleep level is read from the ``level`` query parameter (default
    ``"1"``) and must parse as an integer.

    Returns:
        200 after the engine client's ``sleep`` call returns, or 400 with an
        ``{"error": ...}`` payload when ``level`` is not an integer.
    """
    # get POST params
    raw_level = raw_request.query_params.get("level", "1")
    try:
        level = int(raw_level)
    except ValueError:
        # A malformed parameter is a client error, not a server failure.
        return JSONResponse(
            content={"error": f"level must be an integer, got {raw_level!r}"},
            status_code=400,
        )

    await engine_client(raw_request).sleep(level)
    # FIXME: in v0 with frontend multiprocessing, the sleep command
    # is sent but does not finish yet when we return a response.
    return Response(status_code=200)
@router.post("/wake_up")
async def wake_up(raw_request: Request):
    """Wake the engine up, optionally restricted to the ``tags`` query values."""
    requested_tags = raw_request.query_params.getlist("tags")
    # No tags supplied means "wake up everything", which is passed as None.
    tags = requested_tags if requested_tags else None
    logger.info("wake up the engine with tags: %s", tags)
    client = engine_client(raw_request)
    await client.wake_up(tags)
    # FIXME: in v0 with frontend multiprocessing, the wake-up command
    # is sent but does not finish yet when we return a response.
    return Response(status_code=200)
@router.get("/is_sleeping")
async def is_sleeping(raw_request: Request):
    """Return ``{"is_sleeping": <value>}`` as reported by the engine client."""
    logger.info("check whether the engine is sleeping")
    client = engine_client(raw_request)
    sleeping = await client.is_sleeping()
    return JSONResponse(content={"is_sleeping": sleeping})
def attach_router(app: FastAPI):
    """Register the sleep/wake-up routes when ``VLLM_SERVER_DEV_MODE`` is set.

    When the flag is falsy nothing is registered and nothing is logged.
    """
    if envs.VLLM_SERVER_DEV_MODE:
        logger.warning(
            "SECURITY WARNING: Development endpoints are enabled! "
            "This should NOT be used in production!"
        )
        app.include_router(router)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from http import HTTPStatus
from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from typing_extensions import assert_never
from vllm.entrypoints.openai.api_server import validate_json_request
from vllm.entrypoints.openai.protocol import (
DetokenizeRequest,
DetokenizeResponse,
ErrorResponse,
TokenizeRequest,
TokenizeResponse,
)
from vllm.entrypoints.serve.tokenize.serving import OpenAIServingTokenization
from vllm.entrypoints.utils import (
with_cancellation,
)
from vllm.logger import init_logger
# Module-level logger and the router that collects the tokenization endpoints.
logger = init_logger(__name__)
router = APIRouter()


def tokenization(request: Request) -> OpenAIServingTokenization:
    """Return the tokenization handler stored on the application state."""
    app_state = request.app.state
    return app_state.openai_serving_tokenization
@router.post(
    "/tokenize",
    dependencies=[Depends(validate_json_request)],
    responses={
        HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
        HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse},
        HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
        HTTPStatus.NOT_IMPLEMENTED.value: {"model": ErrorResponse},
    },
)
@with_cancellation
async def tokenize(request: TokenizeRequest, raw_request: Request):
    """Tokenize the request via the tokenization handler.

    ``NotImplementedError`` from the handler becomes a 501 and any other
    exception a 500. An ``ErrorResponse`` result is returned with its own
    error code; a ``TokenizeResponse`` is returned as JSON.
    """
    handler = tokenization(raw_request)

    try:
        result = await handler.create_tokenize(request, raw_request)
    except NotImplementedError as e:
        not_implemented = HTTPStatus.NOT_IMPLEMENTED.value
        raise HTTPException(status_code=not_implemented, detail=str(e)) from e
    except Exception as e:
        server_error = HTTPStatus.INTERNAL_SERVER_ERROR.value
        raise HTTPException(status_code=server_error, detail=str(e)) from e

    if isinstance(result, ErrorResponse):
        return JSONResponse(
            content=result.model_dump(), status_code=result.error.code
        )
    if isinstance(result, TokenizeResponse):
        return JSONResponse(content=result.model_dump())

    assert_never(result)
@router.post(
    "/detokenize",
    dependencies=[Depends(validate_json_request)],
    responses={
        HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
        HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse},
        HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
    },
)
@with_cancellation
async def detokenize(request: DetokenizeRequest, raw_request: Request):
    """Detokenize the request via the tokenization handler.

    ``OverflowError`` from the handler is re-raised as a
    ``RequestValidationError`` and any other exception becomes a 500. An
    ``ErrorResponse`` result is returned with its own error code; a
    ``DetokenizeResponse`` is returned as JSON.
    """
    handler = tokenization(raw_request)

    try:
        result = await handler.create_detokenize(request, raw_request)
    except OverflowError as e:
        raise RequestValidationError(errors=[str(e)]) from e
    except Exception as e:
        server_error = HTTPStatus.INTERNAL_SERVER_ERROR.value
        raise HTTPException(status_code=server_error, detail=str(e)) from e

    if isinstance(result, ErrorResponse):
        return JSONResponse(
            content=result.model_dump(), status_code=result.error.code
        )
    if isinstance(result, DetokenizeResponse):
        return JSONResponse(content=result.model_dump())

    assert_never(result)
def attach_router(app: FastAPI):
    """Register the tokenization routes on ``app``.

    The ``/tokenizer_info`` endpoint is added to the router only when
    ``app.state.args.enable_tokenizer_info_endpoint`` is truthy (a missing
    attribute counts as disabled).
    """
    # NOTE(review): the router is assumed to be included unconditionally,
    # with only the tokenizer-info route being optional - confirm.
    if getattr(app.state.args, "enable_tokenizer_info_endpoint", False):

        @router.get("/tokenizer_info")
        async def get_tokenizer_info(raw_request: Request):
            """Get comprehensive tokenizer information."""
            result = await tokenization(raw_request).get_tokenizer_info()
            return JSONResponse(
                content=result.model_dump(),
                status_code=result.error.code
                if isinstance(result, ErrorResponse)
                else 200,
            )

    app.include_router(router)
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
import os import os
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any
from openai.types.responses.response_function_tool_call_output_item import (
ResponseFunctionToolCallOutputItem,
)
from openai_harmony import Author, Message, Role, TextContent from openai_harmony import Author, Message, Role, TextContent
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils import random_uuid
if TYPE_CHECKING: if TYPE_CHECKING:
# Avoid circular import. # Avoid circular import.
...@@ -46,6 +51,10 @@ class Tool(ABC): ...@@ -46,6 +51,10 @@ class Tool(ABC):
async def get_result(self, context: "ConversationContext") -> Any: async def get_result(self, context: "ConversationContext") -> Any:
pass pass
@abstractmethod
async def get_result_parsable_context(self, context: "ConversationContext") -> Any:
    """Abstract hook taking a conversation context; this base body does nothing."""
    return None
class HarmonyBrowserTool(Tool): class HarmonyBrowserTool(Tool):
def __init__(self): def __init__(self):
...@@ -81,6 +90,9 @@ class HarmonyBrowserTool(Tool): ...@@ -81,6 +90,9 @@ class HarmonyBrowserTool(Tool):
tool_output_msgs.append(msg) tool_output_msgs.append(msg)
return tool_output_msgs return tool_output_msgs
async def get_result_parsable_context(self, context: "ConversationContext") -> Any:
    """Unsupported for this tool: always raises ``NotImplementedError``."""
    raise NotImplementedError("Not implemented yet")
@property @property
def tool_config(self) -> Any: def tool_config(self) -> Any:
return self.browser_tool.tool_config return self.browser_tool.tool_config
...@@ -138,6 +150,38 @@ class HarmonyPythonTool(Tool): ...@@ -138,6 +150,38 @@ class HarmonyPythonTool(Tool):
tool_output_msgs.append(msg) tool_output_msgs.append(msg)
return tool_output_msgs return tool_output_msgs
async def get_result_parsable_context(self, context: "ConversationContext") -> Any:
    """
    This function converts parsable context types to harmony and
    back so we can use GPTOSS demo python tool

    The ``code`` field of the last parser message's JSON arguments is
    wrapped in a harmony message addressed to ``python``; every message
    yielded by ``self.python_tool.process`` is converted into a
    ``ResponseFunctionToolCallOutputItem`` and the list of them is returned.
    """
    from vllm.entrypoints.context import ParsableContext

    assert isinstance(context, ParsableContext)

    call_msg = context.parser.response_messages[-1]
    call_args = json.loads(call_msg.arguments)

    harmony_request = Message(
        author=Author(role="assistant", name=None),
        content=[TextContent(text=call_args["code"])],
        channel="analysis",
        recipient="python",
        content_type="code",
    )

    # NOTE(review): ``call_id`` is freshly generated here rather than taken
    # from the originating call message - confirm this is intended.
    return [
        ResponseFunctionToolCallOutputItem(
            id=f"fco_{random_uuid()}",
            type="function_call_output",
            call_id=f"call_{random_uuid()}",
            output=reply.content[0].text,
            status="completed",
        )
        async for reply in self.python_tool.process(harmony_request)
    ]
@property @property
def tool_config(self) -> Any: def tool_config(self) -> Any:
return self.python_tool.tool_config return self.python_tool.tool_config
...@@ -37,7 +37,7 @@ if TYPE_CHECKING: ...@@ -37,7 +37,7 @@ if TYPE_CHECKING:
VLLM_DISABLE_FLASHINFER_PREFILL: bool = False VLLM_DISABLE_FLASHINFER_PREFILL: bool = False
VLLM_DO_NOT_TRACK: bool = False VLLM_DO_NOT_TRACK: bool = False
VLLM_USAGE_SOURCE: str = "" VLLM_USAGE_SOURCE: str = ""
VLLM_CONFIGURE_LOGGING: int = 1 VLLM_CONFIGURE_LOGGING: bool = True
VLLM_LOGGING_LEVEL: str = "INFO" VLLM_LOGGING_LEVEL: str = "INFO"
VLLM_LOGGING_PREFIX: str = "" VLLM_LOGGING_PREFIX: str = ""
VLLM_LOGGING_STREAM: str = "ext://sys.stdout" VLLM_LOGGING_STREAM: str = "ext://sys.stdout"
...@@ -75,11 +75,12 @@ if TYPE_CHECKING: ...@@ -75,11 +75,12 @@ if TYPE_CHECKING:
VLLM_MM_INPUT_CACHE_GIB: int = 4 VLLM_MM_INPUT_CACHE_GIB: int = 4
VLLM_TARGET_DEVICE: str = "cuda" VLLM_TARGET_DEVICE: str = "cuda"
VLLM_MAIN_CUDA_VERSION: str = "12.9" VLLM_MAIN_CUDA_VERSION: str = "12.9"
VLLM_FLOAT32_MATMUL_PRECISION: Literal["highest", "high", "medium"] = "highest"
MAX_JOBS: str | None = None MAX_JOBS: str | None = None
NVCC_THREADS: str | None = None NVCC_THREADS: str | None = None
VLLM_USE_PRECOMPILED: bool = False VLLM_USE_PRECOMPILED: bool = False
VLLM_SKIP_PRECOMPILED_VERSION_SUFFIX: bool = False
VLLM_DOCKER_BUILD_CONTEXT: bool = False VLLM_DOCKER_BUILD_CONTEXT: bool = False
VLLM_TEST_USE_PRECOMPILED_NIGHTLY_WHEEL: bool = False
VLLM_KEEP_ALIVE_ON_ENGINE_DEATH: bool = False VLLM_KEEP_ALIVE_ON_ENGINE_DEATH: bool = False
CMAKE_BUILD_TYPE: Literal["Debug", "Release", "RelWithDebInfo"] | None = None CMAKE_BUILD_TYPE: Literal["Debug", "Release", "RelWithDebInfo"] | None = None
VERBOSE: bool = False VERBOSE: bool = False
...@@ -88,20 +89,23 @@ if TYPE_CHECKING: ...@@ -88,20 +89,23 @@ if TYPE_CHECKING:
VLLM_HTTP_TIMEOUT_KEEP_ALIVE: int = 5 # seconds VLLM_HTTP_TIMEOUT_KEEP_ALIVE: int = 5 # seconds
VLLM_PLUGINS: list[str] | None = None VLLM_PLUGINS: list[str] | None = None
VLLM_LORA_RESOLVER_CACHE_DIR: str | None = None VLLM_LORA_RESOLVER_CACHE_DIR: str | None = None
VLLM_TORCH_CUDA_PROFILE: bool = False # Deprecated env variables for profiling, kept for backward compatibility
# See also vllm/config/profiler.py and `--profiler-config` argument
VLLM_TORCH_CUDA_PROFILE: str | None = None
VLLM_TORCH_PROFILER_DIR: str | None = None VLLM_TORCH_PROFILER_DIR: str | None = None
VLLM_TORCH_PROFILER_RECORD_SHAPES: bool = False VLLM_TORCH_PROFILER_RECORD_SHAPES: str | None = None
VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY: bool = False VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY: str | None = None
VLLM_TORCH_PROFILER_DISABLE_ASYNC_LLM: bool = False VLLM_TORCH_PROFILER_DISABLE_ASYNC_LLM: str | None = None
VLLM_TORCH_PROFILER_WITH_STACK: str | None = None
VLLM_TORCH_PROFILER_WITH_FLOPS: str | None = None
VLLM_TORCH_PROFILER_USE_GZIP: str | None = None
VLLM_TORCH_PROFILER_DUMP_CUDA_TIME_TOTAL: str | None = None
VLLM_PROFILER_DELAY_ITERS: str | None = None
VLLM_PROFILER_MAX_ITERS: str | None = None
# End of deprecated env variables for profiling
VLLM_USE_AOT_COMPILE: bool = False VLLM_USE_AOT_COMPILE: bool = False
VLLM_USE_BYTECODE_HOOK: bool = False VLLM_USE_BYTECODE_HOOK: bool = False
VLLM_FORCE_AOT_LOAD: bool = False VLLM_FORCE_AOT_LOAD: bool = False
VLLM_TORCH_PROFILER_WITH_STACK: bool = True
VLLM_TORCH_PROFILER_WITH_FLOPS: bool = False
VLLM_PROFILER_DELAY_ITERS: int = 0
VLLM_PROFILER_MAX_ITERS: int = 0
VLLM_TORCH_PROFILER_USE_GZIP: bool = True
VLLM_TORCH_PROFILER_DUMP_CUDA_TIME_TOTAL: bool = True
VLLM_USE_TRITON_AWQ: bool = False VLLM_USE_TRITON_AWQ: bool = False
VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False
VLLM_SKIP_P2P_CHECK: bool = False VLLM_SKIP_P2P_CHECK: bool = False
...@@ -144,6 +148,7 @@ if TYPE_CHECKING: ...@@ -144,6 +148,7 @@ if TYPE_CHECKING:
VLLM_DP_MASTER_IP: str = "" VLLM_DP_MASTER_IP: str = ""
VLLM_DP_MASTER_PORT: int = 0 VLLM_DP_MASTER_PORT: int = 0
VLLM_MOE_DP_CHUNK_SIZE: int = 256 VLLM_MOE_DP_CHUNK_SIZE: int = 256
VLLM_ENABLE_MOE_DP_CHUNK: bool = True
VLLM_RANDOMIZE_DP_DUMMY_INPUTS: bool = False VLLM_RANDOMIZE_DP_DUMMY_INPUTS: bool = False
VLLM_RAY_DP_PACK_STRATEGY: Literal["strict", "fill", "span"] = "strict" VLLM_RAY_DP_PACK_STRATEGY: Literal["strict", "fill", "span"] = "strict"
VLLM_MARLIN_USE_ATOMIC_ADD: bool = False VLLM_MARLIN_USE_ATOMIC_ADD: bool = False
...@@ -175,6 +180,7 @@ if TYPE_CHECKING: ...@@ -175,6 +180,7 @@ if TYPE_CHECKING:
VLLM_ALLOW_INSECURE_SERIALIZATION: bool = False VLLM_ALLOW_INSECURE_SERIALIZATION: bool = False
VLLM_NIXL_SIDE_CHANNEL_HOST: str = "localhost" VLLM_NIXL_SIDE_CHANNEL_HOST: str = "localhost"
VLLM_NIXL_SIDE_CHANNEL_PORT: int = 5600 VLLM_NIXL_SIDE_CHANNEL_PORT: int = 5600
VLLM_MOONCAKE_BOOTSTRAP_PORT: int = 8998
VLLM_ALL2ALL_BACKEND: Literal[ VLLM_ALL2ALL_BACKEND: Literal[
"naive", "naive",
"pplx", "pplx",
...@@ -197,6 +203,7 @@ if TYPE_CHECKING: ...@@ -197,6 +203,7 @@ if TYPE_CHECKING:
VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16: bool = True VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16: bool = True
VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB: int | None = None VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB: int | None = None
VLLM_NIXL_ABORT_REQUEST_TIMEOUT: int = 480 VLLM_NIXL_ABORT_REQUEST_TIMEOUT: int = 480
VLLM_MOONCAKE_ABORT_REQUEST_TIMEOUT: int = 480
VLLM_USE_CUDNN_PREFILL: bool = False VLLM_USE_CUDNN_PREFILL: bool = False
VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL: bool = False VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL: bool = False
VLLM_ENABLE_CUDAGRAPH_GC: bool = False VLLM_ENABLE_CUDAGRAPH_GC: bool = False
...@@ -214,6 +221,7 @@ if TYPE_CHECKING: ...@@ -214,6 +221,7 @@ if TYPE_CHECKING:
VLLM_ALLREDUCE_USE_SYMM_MEM: bool = True VLLM_ALLREDUCE_USE_SYMM_MEM: bool = True
VLLM_TUNED_CONFIG_FOLDER: str | None = None VLLM_TUNED_CONFIG_FOLDER: str | None = None
VLLM_GPT_OSS_SYSTEM_TOOL_MCP_LABELS: set[str] = set() VLLM_GPT_OSS_SYSTEM_TOOL_MCP_LABELS: set[str] = set()
VLLM_USE_EXPERIMENTAL_PARSER_CONTEXT: bool = False
VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS: bool = False VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS: bool = False
VLLM_TOOL_JSON_ERROR_AUTOMATIC_RETRY: bool = False VLLM_TOOL_JSON_ERROR_AUTOMATIC_RETRY: bool = False
VLLM_CUSTOM_SCOPES_FOR_PROFILING: bool = False VLLM_CUSTOM_SCOPES_FOR_PROFILING: bool = False
...@@ -449,6 +457,14 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -449,6 +457,14 @@ environment_variables: dict[str, Callable[[], Any]] = {
# Main CUDA version of vLLM. This follows PyTorch but can be overridden. # Main CUDA version of vLLM. This follows PyTorch but can be overridden.
"VLLM_MAIN_CUDA_VERSION": lambda: os.getenv("VLLM_MAIN_CUDA_VERSION", "").lower() "VLLM_MAIN_CUDA_VERSION": lambda: os.getenv("VLLM_MAIN_CUDA_VERSION", "").lower()
or "12.9", or "12.9",
# Controls PyTorch float32 matmul precision mode within vLLM workers.
# Valid options mirror torch.set_float32_matmul_precision
"VLLM_FLOAT32_MATMUL_PRECISION": env_with_choices(
"VLLM_FLOAT32_MATMUL_PRECISION",
"highest",
["highest", "high", "medium"],
case_sensitive=False,
),
# Maximum number of compilation jobs to run in parallel. # Maximum number of compilation jobs to run in parallel.
# By default this is the number of CPUs # By default this is the number of CPUs
"MAX_JOBS": lambda: os.getenv("MAX_JOBS", None), "MAX_JOBS": lambda: os.getenv("MAX_JOBS", None),
...@@ -462,17 +478,16 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -462,17 +478,16 @@ environment_variables: dict[str, Callable[[], Any]] = {
.lower() .lower()
in ("1", "true") in ("1", "true")
or bool(os.environ.get("VLLM_PRECOMPILED_WHEEL_LOCATION")), or bool(os.environ.get("VLLM_PRECOMPILED_WHEEL_LOCATION")),
# If set, skip adding +precompiled suffix to version string
"VLLM_SKIP_PRECOMPILED_VERSION_SUFFIX": lambda: bool(
int(os.environ.get("VLLM_SKIP_PRECOMPILED_VERSION_SUFFIX", "0"))
),
# Used to mark that setup.py is running in a Docker build context, # Used to mark that setup.py is running in a Docker build context,
# in order to force the use of precompiled binaries. # in order to force the use of precompiled binaries.
"VLLM_DOCKER_BUILD_CONTEXT": lambda: os.environ.get("VLLM_DOCKER_BUILD_CONTEXT", "") "VLLM_DOCKER_BUILD_CONTEXT": lambda: os.environ.get("VLLM_DOCKER_BUILD_CONTEXT", "")
.strip() .strip()
.lower() .lower()
in ("1", "true"), in ("1", "true"),
# Whether to force using nightly wheel in python build.
# This is used for testing the nightly wheel in python build.
"VLLM_TEST_USE_PRECOMPILED_NIGHTLY_WHEEL": lambda: bool(
int(os.getenv("VLLM_TEST_USE_PRECOMPILED_NIGHTLY_WHEEL", "0"))
),
# CMake build type # CMake build type
# If not set, defaults to "Debug" or "RelWithDebInfo" # If not set, defaults to "Debug" or "RelWithDebInfo"
# Available options: "Debug", "Release", "RelWithDebInfo" # Available options: "Debug", "Release", "RelWithDebInfo"
...@@ -618,7 +633,9 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -618,7 +633,9 @@ environment_variables: dict[str, Callable[[], Any]] = {
# If set to 0, vllm will not configure logging # If set to 0, vllm will not configure logging
# If set to 1, vllm will configure logging using the default configuration # If set to 1, vllm will configure logging using the default configuration
# or the configuration file specified by VLLM_LOGGING_CONFIG_PATH # or the configuration file specified by VLLM_LOGGING_CONFIG_PATH
"VLLM_CONFIGURE_LOGGING": lambda: int(os.getenv("VLLM_CONFIGURE_LOGGING", "1")), "VLLM_CONFIGURE_LOGGING": lambda: bool(
int(os.getenv("VLLM_CONFIGURE_LOGGING", "1"))
),
"VLLM_LOGGING_CONFIG_PATH": lambda: os.getenv("VLLM_LOGGING_CONFIG_PATH"), "VLLM_LOGGING_CONFIG_PATH": lambda: os.getenv("VLLM_LOGGING_CONFIG_PATH"),
# this is used for configuring the default logging level # this is used for configuring the default logging level
"VLLM_LOGGING_LEVEL": lambda: os.getenv("VLLM_LOGGING_LEVEL", "INFO").upper(), "VLLM_LOGGING_LEVEL": lambda: os.getenv("VLLM_LOGGING_LEVEL", "INFO").upper(),
...@@ -837,71 +854,52 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -837,71 +854,52 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_LORA_RESOLVER_CACHE_DIR": lambda: os.getenv( "VLLM_LORA_RESOLVER_CACHE_DIR": lambda: os.getenv(
"VLLM_LORA_RESOLVER_CACHE_DIR", None "VLLM_LORA_RESOLVER_CACHE_DIR", None
), ),
# Enables torch CUDA profiling if set. # Enables torch CUDA profiling if set to 1.
# On NVIDIA GPUs, this will start/stop cudaProfilerApi when triggered. # Deprecated, see profiler_config.
"VLLM_TORCH_CUDA_PROFILE": lambda: bool( "VLLM_TORCH_CUDA_PROFILE": lambda: os.getenv("VLLM_TORCH_CUDA_PROFILE"),
os.getenv("VLLM_TORCH_CUDA_PROFILE", "0") != "0"
),
# Enables torch profiler if set. # Enables torch profiler if set.
# Both AsyncLLM's CPU traces as well as workers' # Deprecated, see profiler_config.
# traces (CPU & GPU) will be saved under this directory. "VLLM_TORCH_PROFILER_DIR": lambda: os.getenv("VLLM_TORCH_PROFILER_DIR"),
# Note that it must be an absolute path. # Enable torch profiler to record shapes if set to 1.
"VLLM_TORCH_PROFILER_DIR": lambda: ( # Deprecated, see profiler_config.
None "VLLM_TORCH_PROFILER_RECORD_SHAPES": lambda: (
if (val := os.getenv("VLLM_TORCH_PROFILER_DIR")) is None os.getenv("VLLM_TORCH_PROFILER_RECORD_SHAPES")
else ( ),
val # Enable torch profiler to profile memory if set to 1.
if val.startswith("gs://") and val[5:] and val[5] != "/" # Deprecated, see profiler_config.
else os.path.abspath(os.path.expanduser(val)) "VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY": lambda: (
) os.getenv("VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY")
), ),
# Enable torch profiler to record shapes if set # Enable torch profiler to profile stack if set to 1.
# VLLM_TORCH_PROFILER_RECORD_SHAPES=1. If not set, torch profiler will # Deprecated, see profiler_config.
# not record shapes. "VLLM_TORCH_PROFILER_WITH_STACK": lambda: (
"VLLM_TORCH_PROFILER_RECORD_SHAPES": lambda: bool( os.getenv("VLLM_TORCH_PROFILER_WITH_STACK")
os.getenv("VLLM_TORCH_PROFILER_RECORD_SHAPES", "0") != "0" ),
), # Enable torch profiler to profile flops if set to 1.
# Enable torch profiler to profile memory if set # Deprecated, see profiler_config.
# VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY=1. If not set, torch profiler "VLLM_TORCH_PROFILER_WITH_FLOPS": lambda: (
# will not profile memory. os.getenv("VLLM_TORCH_PROFILER_WITH_FLOPS")
"VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY": lambda: bool( ),
os.getenv("VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY", "0") != "0" # Disable torch profiling of the AsyncLLMEngine process if set to 1.
), # Deprecated, see profiler_config.
# Enable torch profiler to profile stack if set "VLLM_TORCH_PROFILER_DISABLE_ASYNC_LLM": lambda: (
# VLLM_TORCH_PROFILER_WITH_STACK=1. If not set, torch profiler WILL os.getenv("VLLM_TORCH_PROFILER_DISABLE_ASYNC_LLM")
# profile stack by default.
"VLLM_TORCH_PROFILER_WITH_STACK": lambda: bool(
os.getenv("VLLM_TORCH_PROFILER_WITH_STACK", "1") != "0"
),
# Enable torch profiler to profile flops if set
# VLLM_TORCH_PROFILER_WITH_FLOPS=1. If not set, torch profiler will
# not profile flops.
"VLLM_TORCH_PROFILER_WITH_FLOPS": lambda: bool(
os.getenv("VLLM_TORCH_PROFILER_WITH_FLOPS", "0") != "0"
),
# Disable torch profiling of the AsyncLLMEngine process.
# If set to 1, will not profile the engine process.
"VLLM_TORCH_PROFILER_DISABLE_ASYNC_LLM": lambda: bool(
os.getenv("VLLM_TORCH_PROFILER_DISABLE_ASYNC_LLM", "0") != "0"
), ),
# Delay number of iterations before starting profiling when using # Delay number of iterations before starting profiling when using
# the torch/torch CUDA profiler. If set to 0, will start profiling immediately. # the torch/torch CUDA profiler. If set to 0, will start profiling immediately.
"VLLM_PROFILER_DELAY_ITERS": lambda: int( # Deprecated, see profiler_config.
os.getenv("VLLM_PROFILER_DELAY_ITERS", "0") "VLLM_PROFILER_DELAY_ITERS": lambda: (os.getenv("VLLM_PROFILER_DELAY_ITERS")),
),
# Maximum number of iterations to profile when using the torch/torch CUDA profiler. # Maximum number of iterations to profile when using the torch/torch CUDA profiler.
# If set to 0, will not limit the number of iterations. # If set to 0, will not limit the number of iterations.
"VLLM_PROFILER_MAX_ITERS": lambda: int(os.getenv("VLLM_PROFILER_MAX_ITERS", "0")), "VLLM_PROFILER_MAX_ITERS": lambda: os.getenv("VLLM_PROFILER_MAX_ITERS"),
# Control whether torch profiler gzip-compresses profiling files. # Control whether torch profiler gzip-compresses profiling files.
# Set VLLM_TORCH_PROFILER_USE_GZIP=0 to disable gzip (enabled by default). # Deprecated, see profiler_config.
"VLLM_TORCH_PROFILER_USE_GZIP": lambda: bool( "VLLM_TORCH_PROFILER_USE_GZIP": lambda: os.getenv("VLLM_TORCH_PROFILER_USE_GZIP"),
os.getenv("VLLM_TORCH_PROFILER_USE_GZIP", "1") != "0"
),
# Control whether torch profiler dumps the self_cuda_time_total table. # Control whether torch profiler dumps the self_cuda_time_total table.
# Set VLLM_TORCH_PROFILER_DUMP_CUDA_TIME_TOTAL=0 to disable dumping # Set to 0 to disable dumping the table.
# (enabled by default). # Deprecated, see profiler_config.
"VLLM_TORCH_PROFILER_DUMP_CUDA_TIME_TOTAL": lambda: bool( "VLLM_TORCH_PROFILER_DUMP_CUDA_TIME_TOTAL": lambda: (
os.getenv("VLLM_TORCH_PROFILER_DUMP_CUDA_TIME_TOTAL", "1") != "0" os.getenv("VLLM_TORCH_PROFILER_DUMP_CUDA_TIME_TOTAL")
), ),
# If set, vLLM will use Triton implementations of AWQ. # If set, vLLM will use Triton implementations of AWQ.
"VLLM_USE_TRITON_AWQ": lambda: bool(int(os.getenv("VLLM_USE_TRITON_AWQ", "0"))), "VLLM_USE_TRITON_AWQ": lambda: bool(int(os.getenv("VLLM_USE_TRITON_AWQ", "0"))),
...@@ -1098,6 +1096,9 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1098,6 +1096,9 @@ environment_variables: dict[str, Callable[[], Any]] = {
# rank. All DP ranks process the activations in VLLM_MOE_DP_CHUNK_SIZE # rank. All DP ranks process the activations in VLLM_MOE_DP_CHUNK_SIZE
# units. # units.
"VLLM_MOE_DP_CHUNK_SIZE": lambda: int(os.getenv("VLLM_MOE_DP_CHUNK_SIZE", "256")), "VLLM_MOE_DP_CHUNK_SIZE": lambda: int(os.getenv("VLLM_MOE_DP_CHUNK_SIZE", "256")),
"VLLM_ENABLE_MOE_DP_CHUNK": lambda: bool(
int(os.getenv("VLLM_ENABLE_MOE_DP_CHUNK", "1"))
),
# Randomize inputs during dummy runs when using Data Parallel # Randomize inputs during dummy runs when using Data Parallel
"VLLM_RANDOMIZE_DP_DUMMY_INPUTS": lambda: os.environ.get( "VLLM_RANDOMIZE_DP_DUMMY_INPUTS": lambda: os.environ.get(
"VLLM_RANDOMIZE_DP_DUMMY_INPUTS", "0" "VLLM_RANDOMIZE_DP_DUMMY_INPUTS", "0"
...@@ -1259,6 +1260,10 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1259,6 +1260,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_NIXL_SIDE_CHANNEL_PORT": lambda: int( "VLLM_NIXL_SIDE_CHANNEL_PORT": lambda: int(
os.getenv("VLLM_NIXL_SIDE_CHANNEL_PORT", "5600") os.getenv("VLLM_NIXL_SIDE_CHANNEL_PORT", "5600")
), ),
# Port used for Mooncake handshake between remote agents.
"VLLM_MOONCAKE_BOOTSTRAP_PORT": lambda: int(
os.getenv("VLLM_MOONCAKE_BOOTSTRAP_PORT", "8998")
),
# all2all backend for vllm's expert parallel communication # all2all backend for vllm's expert parallel communication
# Available options: # Available options:
# - "naive": naive all2all implementation using broadcasts # - "naive": naive all2all implementation using broadcasts
...@@ -1368,6 +1373,10 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1368,6 +1373,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_NIXL_ABORT_REQUEST_TIMEOUT": lambda: int( "VLLM_NIXL_ABORT_REQUEST_TIMEOUT": lambda: int(
os.getenv("VLLM_NIXL_ABORT_REQUEST_TIMEOUT", "480") os.getenv("VLLM_NIXL_ABORT_REQUEST_TIMEOUT", "480")
), ),
# Timeout (in seconds) for MooncakeConnector in PD disaggregated setup.
"VLLM_MOONCAKE_ABORT_REQUEST_TIMEOUT": lambda: int(
os.getenv("VLLM_MOONCAKE_ABORT_REQUEST_TIMEOUT", "480")
),
# Controls whether or not to use cudnn prefill # Controls whether or not to use cudnn prefill
"VLLM_USE_CUDNN_PREFILL": lambda: bool( "VLLM_USE_CUDNN_PREFILL": lambda: bool(
int(os.getenv("VLLM_USE_CUDNN_PREFILL", "0")) int(os.getenv("VLLM_USE_CUDNN_PREFILL", "0"))
...@@ -1445,6 +1454,10 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1445,6 +1454,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_ALLREDUCE_USE_SYMM_MEM": lambda: bool( "VLLM_ALLREDUCE_USE_SYMM_MEM": lambda: bool(
int(os.getenv("VLLM_ALLREDUCE_USE_SYMM_MEM", "1")) int(os.getenv("VLLM_ALLREDUCE_USE_SYMM_MEM", "1"))
), ),
# Experimental: use this to enable MCP tool calling for non harmony models
"VLLM_USE_EXPERIMENTAL_PARSER_CONTEXT": lambda: bool(
int(os.getenv("VLLM_USE_EXPERIMENTAL_PARSER_CONTEXT", "0"))
),
# Allows vllm to find tuned config under customized folder # Allows vllm to find tuned config under customized folder
"VLLM_TUNED_CONFIG_FOLDER": lambda: os.getenv("VLLM_TUNED_CONFIG_FOLDER", None), "VLLM_TUNED_CONFIG_FOLDER": lambda: os.getenv("VLLM_TUNED_CONFIG_FOLDER", None),
# Valid values are container,code_interpreter,web_search_preview # Valid values are container,code_interpreter,web_search_preview
......
...@@ -292,7 +292,7 @@ def set_forward_context( ...@@ -292,7 +292,7 @@ def set_forward_context(
if num_tokens_across_dp is None: if num_tokens_across_dp is None:
assert ubatch_slices is None assert ubatch_slices is None
assert num_tokens is not None assert num_tokens is not None
_, num_tokens_across_dp = coordinate_batch_across_dp( _, num_tokens_across_dp, _ = coordinate_batch_across_dp(
num_tokens_unpadded=num_tokens, num_tokens_unpadded=num_tokens,
parallel_config=vllm_config.parallel_config, parallel_config=vllm_config.parallel_config,
allow_microbatching=False, allow_microbatching=False,
......
...@@ -198,7 +198,7 @@ class InputPreprocessor: ...@@ -198,7 +198,7 @@ class InputPreprocessor:
) -> dict[str, Any]: ) -> dict[str, Any]:
kwargs = dict[str, Any]() kwargs = dict[str, Any]()
if self.model_config.hf_config.model_type == "whisper": if self.model_config.is_encoder_decoder:
# For Whisper, special tokens should be provided by the user based # For Whisper, special tokens should be provided by the user based
# on the task and language of their request. Also needed to avoid # on the task and language of their request. Also needed to avoid
# appending an EOS token to the prompt which disrupts generation. # appending an EOS token to the prompt which disrupts generation.
...@@ -573,7 +573,6 @@ class InputPreprocessor: ...@@ -573,7 +573,6 @@ class InputPreprocessor:
""" """
encoder_inputs: SingletonInputs encoder_inputs: SingletonInputs
decoder_inputs: SingletonInputs | None decoder_inputs: SingletonInputs | None
if is_explicit_encoder_decoder_prompt(prompt): if is_explicit_encoder_decoder_prompt(prompt):
# `cast` is needed for mypy, but not pyright # `cast` is needed for mypy, but not pyright
prompt_ = cast(ExplicitEncoderDecoderPrompt, prompt) prompt_ = cast(ExplicitEncoderDecoderPrompt, prompt)
...@@ -585,7 +584,9 @@ class InputPreprocessor: ...@@ -585,7 +584,9 @@ class InputPreprocessor:
if (decoder_input := prompt_["decoder_prompt"]) is None: if (decoder_input := prompt_["decoder_prompt"]) is None:
decoder_inputs = None decoder_inputs = None
else: else:
decoder_inputs = self._prompt_to_llm_inputs(decoder_input) decoder_inputs = self._prompt_to_llm_inputs(
decoder_input, tokenization_kwargs=tokenization_kwargs
)
# For multimodal model, override decoder prompt from processor # For multimodal model, override decoder prompt from processor
# with explicit decoder prompt. # with explicit decoder prompt.
if self.model_config.is_multimodal_model: if self.model_config.is_multimodal_model:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment