"cacheflow/vscode:/vscode.git/clone" did not exist on "afdbe5d3736f156e2a2c0afd13891f47a416baf5"
Commit fcfc474d authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.8.3' into v0.8.3-dev

parents bb94d2e5 296c6572
...@@ -133,8 +133,9 @@ class RPCSleepRequest(Enum): ...@@ -133,8 +133,9 @@ class RPCSleepRequest(Enum):
SLEEP_LEVEL_2 = 2 SLEEP_LEVEL_2 = 2
class RPCWakeUpRequest(Enum): @dataclass
WAKE_UP = 1 class RPCWakeUpRequest:
tags: Optional[list[str]] = None
@dataclass @dataclass
......
...@@ -697,10 +697,10 @@ class MQLLMEngineClient(EngineClient): ...@@ -697,10 +697,10 @@ class MQLLMEngineClient(EngineClient):
return await self._send_one_way_rpc_request( return await self._send_one_way_rpc_request(
request=RPCSleepRequest(level), socket=self.input_socket) request=RPCSleepRequest(level), socket=self.input_socket)
async def wake_up(self) -> None: async def wake_up(self, tags: Optional[list[str]] = None) -> None:
"""Wake up the engine""" """Wake up the engine"""
return await self._send_one_way_rpc_request( return await self._send_one_way_rpc_request(
request=RPCWakeUpRequest.WAKE_UP, socket=self.input_socket) request=RPCWakeUpRequest(tags), socket=self.input_socket)
async def is_sleeping(self) -> bool: async def is_sleeping(self) -> bool:
"""Check whether the engine is sleeping""" """Check whether the engine is sleeping"""
......
...@@ -274,7 +274,7 @@ class MQLLMEngine: ...@@ -274,7 +274,7 @@ class MQLLMEngine:
elif isinstance(request, RPCSleepRequest): elif isinstance(request, RPCSleepRequest):
self.sleep(request.value) self.sleep(request.value)
elif isinstance(request, RPCWakeUpRequest): elif isinstance(request, RPCWakeUpRequest):
self.wake_up() self.wake_up(request.tags)
elif isinstance(request, RPCIsSleepingRequest): elif isinstance(request, RPCIsSleepingRequest):
self._handle_is_sleeping_request(request) self._handle_is_sleeping_request(request)
else: else:
...@@ -415,8 +415,8 @@ class MQLLMEngine: ...@@ -415,8 +415,8 @@ class MQLLMEngine:
def sleep(self, level: int = 1) -> None: def sleep(self, level: int = 1) -> None:
self.engine.sleep(level) self.engine.sleep(level)
def wake_up(self) -> None: def wake_up(self, tags: Optional[list[str]] = None) -> None:
self.engine.wake_up() self.engine.wake_up(tags)
def is_sleeping(self) -> bool: def is_sleeping(self) -> bool:
return self.engine.is_sleeping() return self.engine.is_sleeping()
......
...@@ -282,7 +282,7 @@ class EngineClient(ABC): ...@@ -282,7 +282,7 @@ class EngineClient(ABC):
... ...
@abstractmethod @abstractmethod
async def wake_up(self) -> None: async def wake_up(self, tags: Optional[list[str]] = None) -> None:
"""Wake up the engine""" """Wake up the engine"""
... ...
......
...@@ -306,7 +306,24 @@ def _detect_content_format( ...@@ -306,7 +306,24 @@ def _detect_content_format(
return "openai" return "openai"
def _resolve_hf_chat_template( def resolve_mistral_chat_template(
chat_template: Optional[str],
**kwargs: Any,
) -> Optional[str]:
if chat_template is not None:
logger.warning_once(
"'chat_template' cannot be overridden for mistral tokenizer.")
if "add_generation_prompt" in kwargs:
logger.warning_once(
"'add_generation_prompt' is not supported for mistral tokenizer, "
"so it will be ignored.")
if "continue_final_message" in kwargs:
logger.warning_once(
"'continue_final_message' is not supported for mistral tokenizer, "
"so it will be ignored.")
return None
def resolve_hf_chat_template(
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
chat_template: Optional[str], chat_template: Optional[str],
tools: Optional[list[dict[str, Any]]], tools: Optional[list[dict[str, Any]]],
...@@ -352,7 +369,7 @@ def _resolve_chat_template_content_format( ...@@ -352,7 +369,7 @@ def _resolve_chat_template_content_format(
trust_remote_code: bool, trust_remote_code: bool,
) -> _ChatTemplateContentFormat: ) -> _ChatTemplateContentFormat:
if isinstance(tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)): if isinstance(tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)):
hf_chat_template = _resolve_hf_chat_template( hf_chat_template = resolve_hf_chat_template(
tokenizer, tokenizer,
chat_template=chat_template, chat_template=chat_template,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
...@@ -470,7 +487,8 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]): ...@@ -470,7 +487,8 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
return "<|endoftext10|>" # 200010 (see vocab.json in hf model) return "<|endoftext10|>" # 200010 (see vocab.json in hf model)
if model_type in ("minicpmo", "minicpmv"): if model_type in ("minicpmo", "minicpmv"):
return "(<image>./</image>)" return "(<image>./</image>)"
if model_type in ("blip-2", "fuyu", "paligemma", "pixtral"): if model_type in ("blip-2", "fuyu", "paligemma", "pixtral",
"mistral3"):
# These models do not use image tokens in the prompt # These models do not use image tokens in the prompt
return None return None
if model_type == "qwen": if model_type == "qwen":
...@@ -478,10 +496,11 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]): ...@@ -478,10 +496,11 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
if model_type.startswith("llava"): if model_type.startswith("llava"):
return self._cached_token_str(self._tokenizer, return self._cached_token_str(self._tokenizer,
hf_config.image_token_index) hf_config.image_token_index)
if model_type in ("chameleon", "deepseek_vl_v2", "internvl_chat", if model_type in ("aya_vision", "chameleon", "deepseek_vl_v2",
"NVLM_D", "h2ovl_chat"): "internvl_chat", "skywork_chat", "NVLM_D",
"h2ovl_chat"):
return "<image>" return "<image>"
if model_type == "mllama": if model_type in ("mllama", "llama4"):
return "<|image|>" return "<|image|>"
if model_type in ("qwen2_vl", "qwen2_5_vl"): if model_type in ("qwen2_vl", "qwen2_5_vl"):
return "<|vision_start|><|image_pad|><|vision_end|>" return "<|vision_start|><|image_pad|><|vision_end|>"
...@@ -1140,7 +1159,7 @@ def apply_hf_chat_template( ...@@ -1140,7 +1159,7 @@ def apply_hf_chat_template(
tokenize: bool = False, # Different from HF's default tokenize: bool = False, # Different from HF's default
**kwargs: Any, **kwargs: Any,
) -> str: ) -> str:
hf_chat_template = _resolve_hf_chat_template( hf_chat_template = resolve_hf_chat_template(
tokenizer, tokenizer,
chat_template=chat_template, chat_template=chat_template,
tools=tools, tools=tools,
...@@ -1169,17 +1188,12 @@ def apply_mistral_chat_template( ...@@ -1169,17 +1188,12 @@ def apply_mistral_chat_template(
tools: Optional[list[dict[str, Any]]], tools: Optional[list[dict[str, Any]]],
**kwargs: Any, **kwargs: Any,
) -> list[int]: ) -> list[int]:
if chat_template is not None: # The return value of resolve_mistral_chat_template is always None,
logger.warning_once( # and we won't use it.
"'chat_template' cannot be overridden for mistral tokenizer.") resolve_mistral_chat_template(
if "add_generation_prompt" in kwargs: chat_template=chat_template,
logger.warning_once( **kwargs,
"'add_generation_prompt' is not supported for mistral tokenizer, " )
"so it will be ignored.")
if "continue_final_message" in kwargs:
logger.warning_once(
"'continue_final_message' is not supported for mistral tokenizer, "
"so it will be ignored.")
return tokenizer.apply_chat_template( return tokenizer.apply_chat_template(
messages=messages, messages=messages,
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# The CLI entrypoint to vLLM. # The CLI entrypoint to vLLM.
import os
import signal import signal
import sys import sys
...@@ -9,11 +8,9 @@ import vllm.entrypoints.cli.benchmark.main ...@@ -9,11 +8,9 @@ import vllm.entrypoints.cli.benchmark.main
import vllm.entrypoints.cli.openai import vllm.entrypoints.cli.openai
import vllm.entrypoints.cli.serve import vllm.entrypoints.cli.serve
import vllm.version import vllm.version
from vllm.logger import init_logger from vllm.entrypoints.utils import cli_env_setup
from vllm.utils import FlexibleArgumentParser from vllm.utils import FlexibleArgumentParser
logger = init_logger(__name__)
CMD_MODULES = [ CMD_MODULES = [
vllm.entrypoints.cli.openai, vllm.entrypoints.cli.openai,
vllm.entrypoints.cli.serve, vllm.entrypoints.cli.serve,
...@@ -30,29 +27,8 @@ def register_signal_handlers(): ...@@ -30,29 +27,8 @@ def register_signal_handlers():
signal.signal(signal.SIGTSTP, signal_handler) signal.signal(signal.SIGTSTP, signal_handler)
def env_setup():
# The safest multiprocessing method is `spawn`, as the default `fork` method
# is not compatible with some accelerators. The default method will be
# changing in future versions of Python, so we should use it explicitly when
# possible.
#
# We only set it here in the CLI entrypoint, because changing to `spawn`
# could break some existing code using vLLM as a library. `spawn` will cause
# unexpected behavior if the code is not protected by
# `if __name__ == "__main__":`.
#
# References:
# - https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods
# - https://pytorch.org/docs/stable/notes/multiprocessing.html#cuda-in-multiprocessing
# - https://pytorch.org/docs/stable/multiprocessing.html#sharing-cuda-tensors
# - https://docs.habana.ai/en/latest/PyTorch/Getting_Started_with_PyTorch_and_Gaudi/Getting_Started_with_PyTorch.html?highlight=multiprocessing#torch-multiprocessing-for-dataloaders
if "VLLM_WORKER_MULTIPROC_METHOD" not in os.environ:
logger.debug("Setting VLLM_WORKER_MULTIPROC_METHOD to 'spawn'")
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
def main(): def main():
env_setup() cli_env_setup()
parser = FlexibleArgumentParser(description="vLLM CLI") parser = FlexibleArgumentParser(description="vLLM CLI")
parser.add_argument('-v', parser.add_argument('-v',
......
...@@ -4,7 +4,6 @@ import argparse ...@@ -4,7 +4,6 @@ import argparse
import uvloop import uvloop
from vllm.engine.arg_utils import EngineArgs
from vllm.entrypoints.cli.types import CLISubcommand from vllm.entrypoints.cli.types import CLISubcommand
from vllm.entrypoints.openai.api_server import run_server from vllm.entrypoints.openai.api_server import run_server
from vllm.entrypoints.openai.cli_args import (make_arg_parser, from vllm.entrypoints.openai.cli_args import (make_arg_parser,
...@@ -21,14 +20,9 @@ class ServeSubcommand(CLISubcommand): ...@@ -21,14 +20,9 @@ class ServeSubcommand(CLISubcommand):
@staticmethod @staticmethod
def cmd(args: argparse.Namespace) -> None: def cmd(args: argparse.Namespace) -> None:
# The default value of `--model` # If model is specified in CLI (as positional arg), it takes precedence
if args.model != EngineArgs.model: if hasattr(args, 'model_tag') and args.model_tag is not None:
raise ValueError( args.model = args.model_tag
"With `vllm serve`, you should provide the model as a "
"positional argument instead of via the `--model` option.")
# EngineArgs expects the model name to be passed as --model.
args.model = args.model_tag
uvloop.run(run_server(args)) uvloop.run(run_server(args))
...@@ -41,10 +35,12 @@ class ServeSubcommand(CLISubcommand): ...@@ -41,10 +35,12 @@ class ServeSubcommand(CLISubcommand):
serve_parser = subparsers.add_parser( serve_parser = subparsers.add_parser(
"serve", "serve",
help="Start the vLLM OpenAI Compatible API server", help="Start the vLLM OpenAI Compatible API server",
usage="vllm serve <model_tag> [options]") usage="vllm serve [model_tag] [options]")
serve_parser.add_argument("model_tag", serve_parser.add_argument("model_tag",
type=str, type=str,
help="The model tag to serve") nargs='?',
help="The model tag to serve "
"(optional if specified in config)")
serve_parser.add_argument( serve_parser.add_argument(
"--config", "--config",
type=str, type=str,
......
...@@ -492,8 +492,8 @@ class LLM: ...@@ -492,8 +492,8 @@ class LLM:
It is recommended to use this API to only pass control messages, It is recommended to use this API to only pass control messages,
and set up data-plane communication to pass data. and set up data-plane communication to pass data.
""" """
executor = self.llm_engine.model_executor
return executor.collective_rpc(method, timeout, args, kwargs) return self.llm_engine.collective_rpc(method, timeout, args, kwargs)
def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]: def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
""" """
...@@ -1200,26 +1200,35 @@ class LLM: ...@@ -1200,26 +1200,35 @@ class LLM:
The caller should guarantee that no requests are being processed The caller should guarantee that no requests are being processed
during the sleep period, before `wake_up` is called. during the sleep period, before `wake_up` is called.
:param level: The sleep level. Level 1 sleep will offload the model Args:
weights and discard the kv cache. The content of kv cache is level: The sleep level. Level 1 sleep will offload the model
forgotten. Level 1 sleep is good for sleeping and waking up the weights and discard the kv cache. The content of kv cache
engine to run the same model again. The model weights are backed is forgotten. Level 1 sleep is good for sleeping and waking
up in CPU memory. Please make sure there's enough CPU memory to up the engine to run the same model again. The model weights
store the model weights. Level 2 sleep will discard both the model are backed up in CPU memory. Please make sure there's enough
weights and the kv cache. The content of both the model weights CPU memory to store the model weights. Level 2 sleep will
and kv cache is forgotten. Level 2 sleep is good for sleeping and discard both the model weights and the kv cache. The content
waking up the engine to run a different model or update the model, of both the model weights and kv cache is forgotten. Level 2
where previous model weights are not needed. It reduces CPU memory sleep is good for sleeping and waking up the engine to run a
pressure. different model or update the model, where previous model
weights are not needed. It reduces CPU memory pressure.
""" """
self.reset_prefix_cache() self.reset_prefix_cache()
self.llm_engine.sleep(level=level) self.llm_engine.sleep(level=level)
def wake_up(self): def wake_up(self, tags: Optional[list[str]] = None):
""" """
Wake up the engine from sleep mode. See the :meth:`sleep` method Wake up the engine from sleep mode. See the :meth:`sleep` method
for more details.""" for more details.
self.llm_engine.wake_up()
Args:
tags: An optional list of tags to reallocate the engine memory
for specific memory allocations. Values must be in
("weights", "kv_cache",). If None, all memory is reallocated.
wake_up should be called with all tags (or None) before the
engine is used again.
"""
self.llm_engine.wake_up(tags)
# LEGACY # LEGACY
def _convert_v1_inputs( def _convert_v1_inputs(
......
...@@ -24,6 +24,7 @@ from fastapi import APIRouter, Depends, FastAPI, Form, HTTPException, Request ...@@ -24,6 +24,7 @@ from fastapi import APIRouter, Depends, FastAPI, Form, HTTPException, Request
from fastapi.exceptions import RequestValidationError from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, Response, StreamingResponse from fastapi.responses import JSONResponse, Response, StreamingResponse
from starlette.concurrency import iterate_in_threadpool
from starlette.datastructures import State from starlette.datastructures import State
from starlette.routing import Mount from starlette.routing import Mount
from typing_extensions import assert_never from typing_extensions import assert_never
...@@ -35,7 +36,9 @@ from vllm.engine.async_llm_engine import AsyncLLMEngine # type: ignore ...@@ -35,7 +36,9 @@ from vllm.engine.async_llm_engine import AsyncLLMEngine # type: ignore
from vllm.engine.multiprocessing.client import MQLLMEngineClient from vllm.engine.multiprocessing.client import MQLLMEngineClient
from vllm.engine.multiprocessing.engine import run_mp_engine from vllm.engine.multiprocessing.engine import run_mp_engine
from vllm.engine.protocol import EngineClient from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import load_chat_template from vllm.entrypoints.chat_utils import (load_chat_template,
resolve_hf_chat_template,
resolve_mistral_chat_template)
from vllm.entrypoints.launcher import serve_http from vllm.entrypoints.launcher import serve_http
from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.cli_args import (make_arg_parser, from vllm.entrypoints.openai.cli_args import (make_arg_parser,
...@@ -65,7 +68,6 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, ...@@ -65,7 +68,6 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
TranscriptionRequest, TranscriptionRequest,
TranscriptionResponse, TranscriptionResponse,
UnloadLoRAAdapterRequest) UnloadLoRAAdapterRequest)
from vllm.entrypoints.openai.reasoning_parsers import ReasoningParserManager
# yapf: enable # yapf: enable
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
...@@ -80,10 +82,13 @@ from vllm.entrypoints.openai.serving_tokenization import ( ...@@ -80,10 +82,13 @@ from vllm.entrypoints.openai.serving_tokenization import (
from vllm.entrypoints.openai.serving_transcription import ( from vllm.entrypoints.openai.serving_transcription import (
OpenAIServingTranscription) OpenAIServingTranscription)
from vllm.entrypoints.openai.tool_parsers import ToolParserManager from vllm.entrypoints.openai.tool_parsers import ToolParserManager
from vllm.entrypoints.utils import load_aware_call, with_cancellation from vllm.entrypoints.utils import (cli_env_setup, load_aware_call,
with_cancellation)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.reasoning import ReasoningParserManager
from vllm.transformers_utils.config import ( from vllm.transformers_utils.config import (
maybe_register_config_serialize_by_value) maybe_register_config_serialize_by_value)
from vllm.transformers_utils.tokenizer import MistralTokenizer
from vllm.usage.usage_lib import UsageContext from vllm.usage.usage_lib import UsageContext
from vllm.utils import (Device, FlexibleArgumentParser, get_open_zmq_ipc_path, from vllm.utils import (Device, FlexibleArgumentParser, get_open_zmq_ipc_path,
is_valid_ipv6_address, set_ulimit) is_valid_ipv6_address, set_ulimit)
...@@ -307,6 +312,7 @@ def mount_metrics(app: FastAPI): ...@@ -307,6 +312,7 @@ def mount_metrics(app: FastAPI):
# See https://prometheus.github.io/client_python/multiprocess/ # See https://prometheus.github.io/client_python/multiprocess/
from prometheus_client import (CollectorRegistry, make_asgi_app, from prometheus_client import (CollectorRegistry, make_asgi_app,
multiprocess) multiprocess)
from prometheus_fastapi_instrumentator import Instrumentator
prometheus_multiproc_dir_path = os.getenv("PROMETHEUS_MULTIPROC_DIR", None) prometheus_multiproc_dir_path = os.getenv("PROMETHEUS_MULTIPROC_DIR", None)
if prometheus_multiproc_dir_path is not None: if prometheus_multiproc_dir_path is not None:
...@@ -314,6 +320,16 @@ def mount_metrics(app: FastAPI): ...@@ -314,6 +320,16 @@ def mount_metrics(app: FastAPI):
prometheus_multiproc_dir_path) prometheus_multiproc_dir_path)
registry = CollectorRegistry() registry = CollectorRegistry()
multiprocess.MultiProcessCollector(registry) multiprocess.MultiProcessCollector(registry)
Instrumentator(
excluded_handlers=[
"/metrics",
"/health",
"/load",
"/ping",
"/version",
],
registry=registry,
).add().instrument(app).expose(app)
# Add prometheus asgi middleware to route /metrics requests # Add prometheus asgi middleware to route /metrics requests
metrics_route = Mount("/metrics", make_asgi_app(registry=registry)) metrics_route = Mount("/metrics", make_asgi_app(registry=registry))
...@@ -689,7 +705,6 @@ if envs.VLLM_SERVER_DEV_MODE: ...@@ -689,7 +705,6 @@ if envs.VLLM_SERVER_DEV_MODE:
async def sleep(raw_request: Request): async def sleep(raw_request: Request):
# get POST params # get POST params
level = raw_request.query_params.get("level", "1") level = raw_request.query_params.get("level", "1")
logger.info("sleep the engine with level %s", level)
await engine_client(raw_request).sleep(int(level)) await engine_client(raw_request).sleep(int(level))
# FIXME: in v0 with frontend multiprocessing, the sleep command # FIXME: in v0 with frontend multiprocessing, the sleep command
# is sent but does not finish yet when we return a response. # is sent but does not finish yet when we return a response.
...@@ -697,8 +712,12 @@ if envs.VLLM_SERVER_DEV_MODE: ...@@ -697,8 +712,12 @@ if envs.VLLM_SERVER_DEV_MODE:
@router.post("/wake_up") @router.post("/wake_up")
async def wake_up(raw_request: Request): async def wake_up(raw_request: Request):
logger.info("wake up the engine") tags = raw_request.query_params.getlist("tags")
await engine_client(raw_request).wake_up() if tags == []:
# set to None to wake up all tags if no tags are provided
tags = None
logger.info("wake up the engine with tags: %s", tags)
await engine_client(raw_request).wake_up(tags)
# FIXME: in v0 with frontend multiprocessing, the wake-up command # FIXME: in v0 with frontend multiprocessing, the wake-up command
# is sent but does not finish yet when we return a response. # is sent but does not finish yet when we return a response.
return Response(status_code=200) return Response(status_code=200)
...@@ -814,7 +833,8 @@ def build_app(args: Namespace) -> FastAPI: ...@@ -814,7 +833,8 @@ def build_app(args: Namespace) -> FastAPI:
return JSONResponse(err.model_dump(), return JSONResponse(err.model_dump(),
status_code=HTTPStatus.BAD_REQUEST) status_code=HTTPStatus.BAD_REQUEST)
if token := envs.VLLM_API_KEY or args.api_key: # Ensure --api-key option from CLI takes precedence over VLLM_API_KEY
if token := args.api_key or envs.VLLM_API_KEY:
@app.middleware("http") @app.middleware("http")
async def authentication(request: Request, call_next): async def authentication(request: Request, call_next):
...@@ -843,6 +863,21 @@ def build_app(args: Namespace) -> FastAPI: ...@@ -843,6 +863,21 @@ def build_app(args: Namespace) -> FastAPI:
response.headers["X-Request-Id"] = request_id response.headers["X-Request-Id"] = request_id
return response return response
if envs.VLLM_DEBUG_LOG_API_SERVER_RESPONSE:
logger.warning("CAUTION: Enabling log response in the API Server. "
"This can include sensitive information and should be "
"avoided in production.")
@app.middleware("http")
async def log_response(request: Request, call_next):
response = await call_next(request)
response_body = [
section async for section in response.body_iterator
]
response.body_iterator = iterate_in_threadpool(iter(response_body))
logger.info("response_body={%s}", response_body[0].decode())
return response
for middleware in args.middleware: for middleware in args.middleware:
module_path, object_name = middleware.rsplit(".", 1) module_path, object_name = middleware.rsplit(".", 1)
imported = getattr(importlib.import_module(module_path), object_name) imported = getattr(importlib.import_module(module_path), object_name)
...@@ -883,8 +918,26 @@ async def init_app_state( ...@@ -883,8 +918,26 @@ async def init_app_state(
resolved_chat_template = load_chat_template(args.chat_template) resolved_chat_template = load_chat_template(args.chat_template)
if resolved_chat_template is not None: if resolved_chat_template is not None:
logger.info("Using supplied chat template:\n%s", # Get the tokenizer to check official template
resolved_chat_template) tokenizer = await engine_client.get_tokenizer()
if isinstance(tokenizer, MistralTokenizer):
# The warning is logged in resolve_mistral_chat_template.
resolved_chat_template = resolve_mistral_chat_template(
chat_template=resolved_chat_template)
else:
hf_chat_template = resolve_hf_chat_template(
tokenizer,
chat_template=None,
tools=None,
trust_remote_code=model_config.trust_remote_code)
if hf_chat_template != resolved_chat_template:
logger.warning(
"Using supplied chat template: %s\n"
"It is different from official chat template '%s'. "
"This discrepancy may lead to performance degradation.",
resolved_chat_template, args.model)
state.openai_serving_models = OpenAIServingModels( state.openai_serving_models = OpenAIServingModels(
engine_client=engine_client, engine_client=engine_client,
...@@ -1048,15 +1101,17 @@ async def run_server(args, **uvicorn_kwargs) -> None: ...@@ -1048,15 +1101,17 @@ async def run_server(args, **uvicorn_kwargs) -> None:
) )
# NB: Await server shutdown only after the backend context is exited # NB: Await server shutdown only after the backend context is exited
await shutdown_task try:
await shutdown_task
sock.close() finally:
sock.close()
if __name__ == "__main__": if __name__ == "__main__":
# NOTE(simon): # NOTE(simon):
# This section should be in sync with vllm/entrypoints/cli/main.py for CLI # This section should be in sync with vllm/entrypoints/cli/main.py for CLI
# entrypoints. # entrypoints.
cli_env_setup()
parser = FlexibleArgumentParser( parser = FlexibleArgumentParser(
description="vLLM OpenAI-Compatible RESTful API server.") description="vLLM OpenAI-Compatible RESTful API server.")
parser = make_arg_parser(parser) parser = make_arg_parser(parser)
......
...@@ -247,7 +247,7 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: ...@@ -247,7 +247,7 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
default=None, default=None,
help='Max number of prompt characters or prompt ' help='Max number of prompt characters or prompt '
'ID numbers being printed in log.' 'ID numbers being printed in log.'
'\n\nDefault: Unlimited') ' The default of None means unlimited.')
parser.add_argument( parser.add_argument(
"--disable-fastapi-docs", "--disable-fastapi-docs",
......
...@@ -61,7 +61,7 @@ class OpenAIBaseModel(BaseModel): ...@@ -61,7 +61,7 @@ class OpenAIBaseModel(BaseModel):
field_names = set() field_names = set()
for field_name, field in cls.model_fields.items(): for field_name, field in cls.model_fields.items():
field_names.add(field_name) field_names.add(field_name)
if alias := getattr(field, 'alias', None): if alias := getattr(field, "alias", None):
field_names.add(alias) field_names.add(alias)
cls.field_names = field_names cls.field_names = field_names
...@@ -70,7 +70,8 @@ class OpenAIBaseModel(BaseModel): ...@@ -70,7 +70,8 @@ class OpenAIBaseModel(BaseModel):
logger.warning( logger.warning(
"The following fields were present in the request " "The following fields were present in the request "
"but ignored: %s", "but ignored: %s",
data.keys() - field_names) data.keys() - field_names,
)
return result return result
...@@ -234,8 +235,12 @@ class ChatCompletionRequest(OpenAIBaseModel): ...@@ -234,8 +235,12 @@ class ChatCompletionRequest(OpenAIBaseModel):
temperature: Optional[float] = None temperature: Optional[float] = None
top_p: Optional[float] = None top_p: Optional[float] = None
tools: Optional[list[ChatCompletionToolsParam]] = None tools: Optional[list[ChatCompletionToolsParam]] = None
tool_choice: Optional[Union[Literal["none"], Literal["auto"], tool_choice: Optional[Union[
ChatCompletionNamedToolChoiceParam]] = "none" Literal["none"],
Literal["auto"],
Literal["required"],
ChatCompletionNamedToolChoiceParam,
]] = "none"
# NOTE this will be ignored by vLLM -- the model determines the behavior # NOTE this will be ignored by vLLM -- the model determines the behavior
parallel_tool_calls: Optional[bool] = False parallel_tool_calls: Optional[bool] = False
...@@ -340,24 +345,28 @@ class ChatCompletionRequest(OpenAIBaseModel): ...@@ -340,24 +345,28 @@ class ChatCompletionRequest(OpenAIBaseModel):
description=( description=(
"If specified, will override the default guided decoding backend " "If specified, will override the default guided decoding backend "
"of the server for this specific request. If set, must be either " "of the server for this specific request. If set, must be either "
"'outlines' / 'lm-format-enforcer'")) "'outlines' / 'lm-format-enforcer'"),
)
guided_whitespace_pattern: Optional[str] = Field( guided_whitespace_pattern: Optional[str] = Field(
default=None, default=None,
description=( description=(
"If specified, will override the default whitespace pattern " "If specified, will override the default whitespace pattern "
"for guided json decoding.")) "for guided json decoding."),
)
priority: int = Field( priority: int = Field(
default=0, default=0,
description=( description=(
"The priority of the request (lower means earlier handling; " "The priority of the request (lower means earlier handling; "
"default: 0). Any priority other than 0 will raise an error " "default: 0). Any priority other than 0 will raise an error "
"if the served model does not use priority scheduling.")) "if the served model does not use priority scheduling."),
)
request_id: str = Field( request_id: str = Field(
default_factory=lambda: f"{random_uuid()}", default_factory=lambda: f"{random_uuid()}",
description=( description=(
"The request_id related to this request. If the caller does " "The request_id related to this request. If the caller does "
"not set it, a random_uuid will be generated. This id is used " "not set it, a random_uuid will be generated. This id is used "
"through out the inference process and return in response.")) "through out the inference process and return in response."),
)
logits_processors: Optional[LogitsProcessors] = Field( logits_processors: Optional[LogitsProcessors] = Field(
default=None, default=None,
description=( description=(
...@@ -415,13 +424,15 @@ class ChatCompletionRequest(OpenAIBaseModel): ...@@ -415,13 +424,15 @@ class ChatCompletionRequest(OpenAIBaseModel):
ignore_eos=self.ignore_eos, ignore_eos=self.ignore_eos,
temperature=temperature, temperature=temperature,
length_penalty=self.length_penalty, length_penalty=self.length_penalty,
include_stop_str_in_output=self.include_stop_str_in_output) include_stop_str_in_output=self.include_stop_str_in_output,
)
def to_sampling_params( def to_sampling_params(
self, self,
default_max_tokens: int, default_max_tokens: int,
logits_processor_pattern: Optional[str], logits_processor_pattern: Optional[str],
default_sampling_params: Optional[dict] = None) -> SamplingParams: default_sampling_params: Optional[dict] = None,
) -> SamplingParams:
# TODO(#9845): remove max_tokens when field is removed from OpenAI API # TODO(#9845): remove max_tokens when field is removed from OpenAI API
max_tokens = self.max_completion_tokens or self.max_tokens max_tokens = self.max_completion_tokens or self.max_tokens
...@@ -475,7 +486,8 @@ class ChatCompletionRequest(OpenAIBaseModel): ...@@ -475,7 +486,8 @@ class ChatCompletionRequest(OpenAIBaseModel):
grammar=self.guided_grammar, grammar=self.guided_grammar,
json_object=guided_json_object, json_object=guided_json_object,
backend=self.guided_decoding_backend, backend=self.guided_decoding_backend,
whitespace_pattern=self.guided_whitespace_pattern) whitespace_pattern=self.guided_whitespace_pattern,
)
return SamplingParams.from_optional( return SamplingParams.from_optional(
n=self.n, n=self.n,
...@@ -522,6 +534,41 @@ class ChatCompletionRequest(OpenAIBaseModel): ...@@ -522,6 +534,41 @@ class ChatCompletionRequest(OpenAIBaseModel):
tool = tools[tool_name] tool = tools[tool_name]
return tool.parameters return tool.parameters
if self.tool_choice == "required":
# Pydantic schema generation cannot be used since the JSON schema
# has to be constructed for a specific instantiation of a tool list
# so that parameters of a function are correctly generated
# based on the chosen function name
def get_tool_schema(tool: ChatCompletionToolsParam) -> dict:
return {
"properties": {
"name": {
"type": "string",
"enum": [tool.function.name]
},
# parameters are always generated as '{}' in the final
# output if they are missing from the request
# (i.e. are None or '{}') so the schema is
# updated to produce an empty object in that case
"parameters": tool.function.parameters
if tool.function.parameters else {
"type": "object",
"properties": {}
}
},
"required": ["name", "parameters"]
}
json_schema = {
"type": "array",
"minItems": 1,
"items": {
"type": "object",
"anyOf": [get_tool_schema(tool) for tool in self.tools]
}
}
return json_schema
return None return None
@model_validator(mode="before") @model_validator(mode="before")
...@@ -572,8 +619,11 @@ class ChatCompletionRequest(OpenAIBaseModel): ...@@ -572,8 +619,11 @@ class ChatCompletionRequest(OpenAIBaseModel):
"You can only use one kind of guided decoding " "You can only use one kind of guided decoding "
"('guided_json', 'guided_regex' or 'guided_choice').") "('guided_json', 'guided_regex' or 'guided_choice').")
# you can only either use guided decoding or tools, not both # you can only either use guided decoding or tools, not both
if guide_count > 1 and data.get("tool_choice", if guide_count > 1 and data.get("tool_choice", "none") not in (
"none") not in ("none", "auto"): "none",
"auto",
"required",
):
raise ValueError( raise ValueError(
"You can only either use guided decoding or tools, not both.") "You can only either use guided decoding or tools, not both.")
return data return data
...@@ -602,12 +652,15 @@ class ChatCompletionRequest(OpenAIBaseModel): ...@@ -602,12 +652,15 @@ class ChatCompletionRequest(OpenAIBaseModel):
"When using `tool_choice`, `tools` must be set.") "When using `tool_choice`, `tools` must be set.")
# make sure that tool choice is either a named tool # make sure that tool choice is either a named tool
# OR that it's set to "auto" # OR that it's set to "auto" or "required"
if data["tool_choice"] != "auto" and not isinstance( if data["tool_choice"] not in [
data["tool_choice"], dict): "auto", "required"
raise ValueError( ] and not isinstance(data["tool_choice"], dict):
"`tool_choice` must either be a named tool, \"auto\", " raise NotImplementedError(
"or \"none\".") f'Invalid value for `tool_choice`: {data["tool_choice"]}! '\
'Only named tools, "none", "auto" or "required" '\
'are supported.'
)
# ensure that if "tool_choice" is specified as an object, # ensure that if "tool_choice" is specified as an object,
# it matches a valid tool # it matches a valid tool
...@@ -722,18 +775,21 @@ class CompletionRequest(OpenAIBaseModel): ...@@ -722,18 +775,21 @@ class CompletionRequest(OpenAIBaseModel):
description=( description=(
"If specified, will override the default guided decoding backend " "If specified, will override the default guided decoding backend "
"of the server for this specific request. If set, must be one of " "of the server for this specific request. If set, must be one of "
"'outlines' / 'lm-format-enforcer'")) "'outlines' / 'lm-format-enforcer'"),
)
guided_whitespace_pattern: Optional[str] = Field( guided_whitespace_pattern: Optional[str] = Field(
default=None, default=None,
description=( description=(
"If specified, will override the default whitespace pattern " "If specified, will override the default whitespace pattern "
"for guided json decoding.")) "for guided json decoding."),
)
priority: int = Field( priority: int = Field(
default=0, default=0,
description=( description=(
"The priority of the request (lower means earlier handling; " "The priority of the request (lower means earlier handling; "
"default: 0). Any priority other than 0 will raise an error " "default: 0). Any priority other than 0 will raise an error "
"if the served model does not use priority scheduling.")) "if the served model does not use priority scheduling."),
)
logits_processors: Optional[LogitsProcessors] = Field( logits_processors: Optional[LogitsProcessors] = Field(
default=None, default=None,
description=( description=(
...@@ -745,6 +801,7 @@ class CompletionRequest(OpenAIBaseModel): ...@@ -745,6 +801,7 @@ class CompletionRequest(OpenAIBaseModel):
"arguments. For example: {'qualname': " "arguments. For example: {'qualname': "
"'my_module.MyLogitsProcessor', 'args': [1, 2], 'kwargs': " "'my_module.MyLogitsProcessor', 'args': [1, 2], 'kwargs': "
"{'param': 'value'}}.")) "{'param': 'value'}}."))
return_tokens_as_token_ids: Optional[bool] = Field( return_tokens_as_token_ids: Optional[bool] = Field(
default=None, default=None,
description=( description=(
...@@ -789,13 +846,15 @@ class CompletionRequest(OpenAIBaseModel): ...@@ -789,13 +846,15 @@ class CompletionRequest(OpenAIBaseModel):
ignore_eos=self.ignore_eos, ignore_eos=self.ignore_eos,
temperature=temperature, temperature=temperature,
length_penalty=self.length_penalty, length_penalty=self.length_penalty,
include_stop_str_in_output=self.include_stop_str_in_output) include_stop_str_in_output=self.include_stop_str_in_output,
)
def to_sampling_params( def to_sampling_params(
self, self,
default_max_tokens: int, default_max_tokens: int,
logits_processor_pattern: Optional[str], logits_processor_pattern: Optional[str],
default_sampling_params: Optional[dict] = None) -> SamplingParams: default_sampling_params: Optional[dict] = None,
) -> SamplingParams:
max_tokens = self.max_tokens max_tokens = self.max_tokens
if default_sampling_params is None: if default_sampling_params is None:
...@@ -844,7 +903,8 @@ class CompletionRequest(OpenAIBaseModel): ...@@ -844,7 +903,8 @@ class CompletionRequest(OpenAIBaseModel):
grammar=self.guided_grammar, grammar=self.guided_grammar,
json_object=guided_json_object, json_object=guided_json_object,
backend=self.guided_decoding_backend, backend=self.guided_decoding_backend,
whitespace_pattern=self.guided_whitespace_pattern) whitespace_pattern=self.guided_whitespace_pattern,
)
return SamplingParams.from_optional( return SamplingParams.from_optional(
n=self.n, n=self.n,
...@@ -942,7 +1002,8 @@ class EmbeddingCompletionRequest(OpenAIBaseModel): ...@@ -942,7 +1002,8 @@ class EmbeddingCompletionRequest(OpenAIBaseModel):
description=( description=(
"The priority of the request (lower means earlier handling; " "The priority of the request (lower means earlier handling; "
"default: 0). Any priority other than 0 will raise an error " "default: 0). Any priority other than 0 will raise an error "
"if the served model does not use priority scheduling.")) "if the served model does not use priority scheduling."),
)
# doc: end-embedding-extra-params # doc: end-embedding-extra-params
...@@ -995,7 +1056,8 @@ class EmbeddingChatRequest(OpenAIBaseModel): ...@@ -995,7 +1056,8 @@ class EmbeddingChatRequest(OpenAIBaseModel):
description=( description=(
"The priority of the request (lower means earlier handling; " "The priority of the request (lower means earlier handling; "
"default: 0). Any priority other than 0 will raise an error " "default: 0). Any priority other than 0 will raise an error "
"if the served model does not use priority scheduling.")) "if the served model does not use priority scheduling."),
)
# doc: end-chat-embedding-extra-params # doc: end-chat-embedding-extra-params
@model_validator(mode="before") @model_validator(mode="before")
...@@ -1034,7 +1096,8 @@ class ScoreRequest(OpenAIBaseModel): ...@@ -1034,7 +1096,8 @@ class ScoreRequest(OpenAIBaseModel):
description=( description=(
"The priority of the request (lower means earlier handling; " "The priority of the request (lower means earlier handling; "
"default: 0). Any priority other than 0 will raise an error " "default: 0). Any priority other than 0 will raise an error "
"if the served model does not use priority scheduling.")) "if the served model does not use priority scheduling."),
)
# doc: end-score-extra-params # doc: end-score-extra-params
...@@ -1059,7 +1122,8 @@ class RerankRequest(OpenAIBaseModel): ...@@ -1059,7 +1122,8 @@ class RerankRequest(OpenAIBaseModel):
description=( description=(
"The priority of the request (lower means earlier handling; " "The priority of the request (lower means earlier handling; "
"default: 0). Any priority other than 0 will raise an error " "default: 0). Any priority other than 0 will raise an error "
"if the served model does not use priority scheduling.")) "if the served model does not use priority scheduling."),
)
# doc: end-rerank-extra-params # doc: end-rerank-extra-params
...@@ -1238,6 +1302,9 @@ class ChatCompletionLogProb(OpenAIBaseModel): ...@@ -1238,6 +1302,9 @@ class ChatCompletionLogProb(OpenAIBaseModel):
class ChatCompletionLogProbsContent(ChatCompletionLogProb): class ChatCompletionLogProbsContent(ChatCompletionLogProb):
# Workaround: redefine fields name cache so that it's not
# shared with the super class.
field_names: ClassVar[Optional[set[str]]] = None
top_logprobs: list[ChatCompletionLogProb] = Field(default_factory=list) top_logprobs: list[ChatCompletionLogProb] = Field(default_factory=list)
......
...@@ -2,13 +2,16 @@ ...@@ -2,13 +2,16 @@
import asyncio import asyncio
import json import json
import re
import time import time
from collections.abc import AsyncGenerator, AsyncIterator from collections.abc import AsyncGenerator, AsyncIterator
from collections.abc import Sequence as GenericSequence from collections.abc import Sequence as GenericSequence
from typing import Callable, Final, Optional, Union from typing import Callable, Final, Optional, Union
import jinja2 import jinja2
import partial_json_parser
from fastapi import Request from fastapi import Request
from pydantic import TypeAdapter
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient from vllm.engine.protocol import EngineClient
...@@ -21,10 +24,8 @@ from vllm.entrypoints.openai.protocol import ( ...@@ -21,10 +24,8 @@ from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest, ChatCompletionResponse, ChatCompletionRequest, ChatCompletionResponse,
ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice, ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
ChatCompletionStreamResponse, ChatMessage, DeltaFunctionCall, DeltaMessage, ChatCompletionStreamResponse, ChatMessage, DeltaFunctionCall, DeltaMessage,
DeltaToolCall, ErrorResponse, FunctionCall, PromptTokenUsageInfo, DeltaToolCall, ErrorResponse, FunctionCall, FunctionDefinition,
RequestResponseMetadata, ToolCall, UsageInfo) PromptTokenUsageInfo, RequestResponseMetadata, ToolCall, UsageInfo)
from vllm.entrypoints.openai.reasoning_parsers import (ReasoningParser,
ReasoningParserManager)
from vllm.entrypoints.openai.serving_engine import (OpenAIServing, from vllm.entrypoints.openai.serving_engine import (OpenAIServing,
clamp_prompt_logprobs) clamp_prompt_logprobs)
from vllm.entrypoints.openai.serving_models import OpenAIServingModels from vllm.entrypoints.openai.serving_models import OpenAIServingModels
...@@ -33,6 +34,7 @@ from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import ( ...@@ -33,6 +34,7 @@ from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import (
MistralToolCall) MistralToolCall)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.outputs import CompletionOutput, RequestOutput from vllm.outputs import CompletionOutput, RequestOutput
from vllm.reasoning import ReasoningParser, ReasoningParserManager
from vllm.sampling_params import BeamSearchParams, SamplingParams from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.sequence import Logprob from vllm.sequence import Logprob
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
...@@ -151,12 +153,6 @@ class OpenAIServingChat(OpenAIServing): ...@@ -151,12 +153,6 @@ class OpenAIServingChat(OpenAIServing):
tool_parser = self.tool_parser tool_parser = self.tool_parser
# validation for OpenAI tools
# tool_choice = "required" is not supported
if request.tool_choice == "required":
return self.create_error_response(
"tool_choice = \"required\" is not supported!")
if isinstance(tokenizer, MistralTokenizer): if isinstance(tokenizer, MistralTokenizer):
# because of issues with pydantic we need to potentially # because of issues with pydantic we need to potentially
# re-serialize the tool_calls field of the request # re-serialize the tool_calls field of the request
...@@ -197,16 +193,8 @@ class OpenAIServingChat(OpenAIServing): ...@@ -197,16 +193,8 @@ class OpenAIServingChat(OpenAIServing):
truncate_prompt_tokens=request.truncate_prompt_tokens, truncate_prompt_tokens=request.truncate_prompt_tokens,
add_special_tokens=request.add_special_tokens, add_special_tokens=request.add_special_tokens,
) )
except ValueError as e: except (ValueError, TypeError, RuntimeError,
logger.exception("Error in preprocessing prompt inputs") jinja2.TemplateError) as e:
return self.create_error_response(str(e))
except TypeError as e:
logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(str(e))
except RuntimeError as e:
logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(str(e))
except jinja2.TemplateError as e:
logger.exception("Error in preprocessing prompt inputs") logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(str(e)) return self.create_error_response(str(e))
...@@ -286,6 +274,122 @@ class OpenAIServingChat(OpenAIServing): ...@@ -286,6 +274,122 @@ class OpenAIServingChat(OpenAIServing):
return self.response_role return self.response_role
return request.messages[-1]["role"] return request.messages[-1]["role"]
@staticmethod
def _bracket_level(s: str, opening='{', closing='}') -> int:
"""
Calculate the current level of nested brackets in a given string.
"""
level = 0
for char in s:
if char == opening:
level += 1
elif char == closing:
level -= 1
return level
@staticmethod
def _filter_delta_text(delta_text: str,
previous_text: str) -> tuple[str, bool]:
# remove last '},' of the tool definition stemming from the
# "name"/"parameters" outer object or closing ']' of the tool list
# count occurrences of opening and closing curly braces and
# once level 0 is reached stop outputting text
# if 0 is reached while parsing the delta_text we know the current
# tool will finish in this current iteration
bracket_level = OpenAIServingChat._bracket_level(previous_text)
updated_delta, passed_zero = "", False
for c in delta_text:
if c == '{':
bracket_level += 1
passed_zero = bracket_level == 0
elif c == '}':
bracket_level -= 1
passed_zero = bracket_level == 0
if bracket_level != 0:
updated_delta += c
else:
# if a comma is reached at level 0 we can stop
if c == ',':
break
return updated_delta, passed_zero
def extract_tool_call_required_streaming(
self,
previous_text: str,
current_text: str,
delta_text: str,
function_name_returned: bool,
) -> tuple[Optional[DeltaMessage], bool]:
try:
obj = partial_json_parser.loads(current_text)
except partial_json_parser.core.exceptions.MalformedJSON:
logger.debug('not enough tokens to parse into JSON yet')
obj = None
# check if the current text is a valid array
# containing a partial tool calling object
# if not repeat
if obj is None or not isinstance(obj, list) or not len(obj) > 0:
function_name_returned = False
delta_message = None
else:
_, finishes_previous_tool = OpenAIServingChat._filter_delta_text(
delta_text, previous_text)
# take the last tool call from the generated list
current_tool_call = obj[-1]
# once parameters have been generated the name is complete as well
if not finishes_previous_tool and ("name" not in current_tool_call
or "parameters"
not in current_tool_call):
function_name_returned = False
delta_message = None
else:
if not function_name_returned:
# get partly generated arguments from the latest tool call
param_match = re.search(r'.*"parameters":\s*(.*)',
current_text)
arguments = param_match.group(1) if param_match else ""
arguments, _ = OpenAIServingChat._filter_delta_text(
arguments, previous_text)
# if this iteration finishes a previous tool call but a
# new incomplete tool is already generated, take the
# previous from the list
if (finishes_previous_tool
and "parameters" not in current_tool_call):
current_tool_call = obj[-2]
function_name_returned = True
delta_message = DeltaMessage(tool_calls=[
DeltaToolCall(function=DeltaFunctionCall(
name=current_tool_call["name"],
arguments=arguments),
index=len(obj) - 1,
type="function")
])
else:
delta_text, _ = OpenAIServingChat._filter_delta_text(
delta_text, previous_text)
if delta_text != "":
delta_message = DeltaMessage(tool_calls=[
DeltaToolCall(
function=DeltaFunctionCall(
# OpenAI API returns None
# instead of name every time
name=None,
arguments=delta_text),
index=len(obj) - 1,
type="function")
])
else:
delta_message = None
return delta_message, function_name_returned
async def chat_completion_stream_generator( async def chat_completion_stream_generator(
self, self,
request: ChatCompletionRequest, request: ChatCompletionRequest,
...@@ -321,6 +425,7 @@ class OpenAIServingChat(OpenAIServing): ...@@ -321,6 +425,7 @@ class OpenAIServingChat(OpenAIServing):
self._should_stream_with_reasoning_parsing(request)) self._should_stream_with_reasoning_parsing(request))
all_previous_token_ids: Optional[list[list[int]]] all_previous_token_ids: Optional[list[list[int]]]
function_name_returned: Optional[list[bool]] = None
# Only one of these will be used, thus previous_texts and # Only one of these will be used, thus previous_texts and
# all_previous_token_ids will not be used twice in the same iteration. # all_previous_token_ids will not be used twice in the same iteration.
...@@ -331,6 +436,10 @@ class OpenAIServingChat(OpenAIServing): ...@@ -331,6 +436,10 @@ class OpenAIServingChat(OpenAIServing):
# For reasoning parser and tool call all enabled # For reasoning parser and tool call all enabled
added_content_delta_arr = [False] * num_choices added_content_delta_arr = [False] * num_choices
reasoning_end_arr = [False] * num_choices reasoning_end_arr = [False] * num_choices
elif request.tool_choice == "required":
previous_texts = [""] * num_choices
function_name_returned = [False] * num_choices
all_previous_token_ids = None
else: else:
previous_texts, all_previous_token_ids = None, None previous_texts, all_previous_token_ids = None, None
...@@ -530,6 +639,23 @@ class OpenAIServingChat(OpenAIServing): ...@@ -530,6 +639,23 @@ class OpenAIServingChat(OpenAIServing):
index=i) index=i)
]) ])
elif request.tool_choice == "required":
assert previous_texts is not None
assert function_name_returned is not None
previous_text = previous_texts[i]
current_text = previous_text + delta_text
fn_name_returned = function_name_returned[i]
delta_message, function_name_returned[i] = (
self.extract_tool_call_required_streaming(
previous_text=previous_text,
current_text=current_text,
delta_text=delta_text,
function_name_returned=fn_name_returned))
# update the previous values for the next iteration
previous_texts[i] = current_text
# handle streaming deltas for tools with "auto" tool choice # handle streaming deltas for tools with "auto" tool choice
# and reasoning parser # and reasoning parser
elif tool_choice_auto and self.enable_reasoning: elif tool_choice_auto and self.enable_reasoning:
...@@ -830,10 +956,10 @@ class OpenAIServingChat(OpenAIServing): ...@@ -830,10 +956,10 @@ class OpenAIServingChat(OpenAIServing):
# if auto tools are not enabled, and a named tool choice using # if auto tools are not enabled, and a named tool choice using
# outlines is not being used # outlines is not being used
if (not self.enable_auto_tools if (not self.enable_auto_tools or not self.tool_parser) and \
or not self.tool_parser) and not isinstance( (not isinstance(request.tool_choice,
request.tool_choice, ChatCompletionNamedToolChoiceParam
ChatCompletionNamedToolChoiceParam): ) and request.tool_choice != "required"):
message = ChatMessage(role=role, message = ChatMessage(role=role,
reasoning_content=reasoning_content, reasoning_content=reasoning_content,
content=content) content=content)
...@@ -854,6 +980,24 @@ class OpenAIServingChat(OpenAIServing): ...@@ -854,6 +980,24 @@ class OpenAIServingChat(OpenAIServing):
arguments=content)) arguments=content))
]) ])
elif request.tool_choice and request.tool_choice == "required":
tool_call_class = MistralToolCall if isinstance(
tokenizer, MistralTokenizer) else ToolCall
# the fields of FunctionDefinition are a superset of the
# tool call outputs and can be used for parsing
tool_calls = TypeAdapter(
list[FunctionDefinition]).validate_json(output.text)
message = ChatMessage(
role=role,
content="",
tool_calls=[
tool_call_class(function=FunctionCall(
name=tool_call.name,
arguments=json.dumps(tool_call.parameters)))
for tool_call in tool_calls
])
# if the request doesn't use tool choice # if the request doesn't use tool choice
# OR specifies to not use a tool # OR specifies to not use a tool
elif not request.tool_choice or request.tool_choice == "none": elif not request.tool_choice or request.tool_choice == "none":
......
...@@ -139,10 +139,7 @@ class OpenAIServingEmbedding(OpenAIServing): ...@@ -139,10 +139,7 @@ class OpenAIServingEmbedding(OpenAIServing):
truncate_prompt_tokens=truncate_prompt_tokens, truncate_prompt_tokens=truncate_prompt_tokens,
add_special_tokens=request.add_special_tokens, add_special_tokens=request.add_special_tokens,
) )
except ValueError as e: except (ValueError, TypeError) as e:
logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(str(e))
except TypeError as e:
logger.exception("Error in preprocessing prompt inputs") logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(str(e)) return self.create_error_response(str(e))
......
...@@ -537,7 +537,7 @@ class OpenAIServing: ...@@ -537,7 +537,7 @@ class OpenAIServing:
lora_request: Optional[LoRARequest] = None) -> str: lora_request: Optional[LoRARequest] = None) -> str:
if lora_request: if lora_request:
return lora_request.lora_name return lora_request.lora_name
if model_name is None: if not model_name:
return self.models.base_model_paths[0].name return self.models.base_model_paths[0].name
return model_name return model_name
......
...@@ -162,7 +162,7 @@ class OpenAIServingModels: ...@@ -162,7 +162,7 @@ class OpenAIServingModels:
except BaseException as e: except BaseException as e:
error_type = "BadRequestError" error_type = "BadRequestError"
status_code = HTTPStatus.BAD_REQUEST status_code = HTTPStatus.BAD_REQUEST
if isinstance(e, ValueError) and "No adapter found" in str(e): if "No adapter found" in str(e):
error_type = "NotFoundError" error_type = "NotFoundError"
status_code = HTTPStatus.NOT_FOUND status_code = HTTPStatus.NOT_FOUND
......
...@@ -136,13 +136,7 @@ class OpenAIServingPooling(OpenAIServing): ...@@ -136,13 +136,7 @@ class OpenAIServingPooling(OpenAIServing):
truncate_prompt_tokens=truncate_prompt_tokens, truncate_prompt_tokens=truncate_prompt_tokens,
add_special_tokens=request.add_special_tokens, add_special_tokens=request.add_special_tokens,
) )
except ValueError as e: except (ValueError, TypeError, jinja2.TemplateError) as e:
logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(str(e))
except TypeError as e:
logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(str(e))
except jinja2.TemplateError as e:
logger.exception("Error in preprocessing prompt inputs") logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(str(e)) return self.create_error_response(str(e))
......
...@@ -89,13 +89,7 @@ class OpenAIServingTokenization(OpenAIServing): ...@@ -89,13 +89,7 @@ class OpenAIServingTokenization(OpenAIServing):
request.prompt, request.prompt,
add_special_tokens=request.add_special_tokens, add_special_tokens=request.add_special_tokens,
) )
except ValueError as e: except (ValueError, TypeError, jinja2.TemplateError) as e:
logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(str(e))
except TypeError as e:
logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(str(e))
except jinja2.TemplateError as e:
logger.exception("Error in preprocessing prompt inputs") logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(str(e)) return self.create_error_response(str(e))
......
...@@ -8,11 +8,12 @@ from .internlm2_tool_parser import Internlm2ToolParser ...@@ -8,11 +8,12 @@ from .internlm2_tool_parser import Internlm2ToolParser
from .jamba_tool_parser import JambaToolParser from .jamba_tool_parser import JambaToolParser
from .llama_tool_parser import Llama3JsonToolParser from .llama_tool_parser import Llama3JsonToolParser
from .mistral_tool_parser import MistralToolParser from .mistral_tool_parser import MistralToolParser
from .phi4mini_tool_parser import Phi4MiniJsonToolParser
from .pythonic_tool_parser import PythonicToolParser from .pythonic_tool_parser import PythonicToolParser
__all__ = [ __all__ = [
"ToolParser", "ToolParserManager", "Granite20bFCToolParser", "ToolParser", "ToolParserManager", "Granite20bFCToolParser",
"GraniteToolParser", "Hermes2ProToolParser", "MistralToolParser", "GraniteToolParser", "Hermes2ProToolParser", "MistralToolParser",
"Internlm2ToolParser", "Llama3JsonToolParser", "JambaToolParser", "Internlm2ToolParser", "Llama3JsonToolParser", "JambaToolParser",
"PythonicToolParser" "PythonicToolParser", "Phi4MiniJsonToolParser"
] ]
# SPDX-License-Identifier: Apache-2.0
import json
import re
from collections.abc import Sequence
from typing import Any, Optional
from transformers import PreTrainedTokenizerBase
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaMessage,
ExtractedToolCallInformation,
FunctionCall, ToolCall)
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser, ToolParserManager)
from vllm.logger import init_logger
from vllm.utils import random_uuid
logger = init_logger(__name__)
@ToolParserManager.register_module("phi4_mini_json")
class Phi4MiniJsonToolParser(ToolParser):
"""
Tool call parser for phi-4-mini models intended for use with the
examples/tool_chat_template_llama.jinja template.
Used when --enable-auto-tool-choice --tool-call-parser phi4_mini_json
are all set
"""
def __init__(self, tokenizer: PreTrainedTokenizerBase) -> None:
super().__init__(tokenizer)
# initialize properties used for state when parsing tool calls in
# streaming mode
self.prev_tool_call_arr: list[dict[str, Any]] = []
self.current_tool_id: int = -1
self.current_tool_name_sent: bool = False
self.streamed_args_for_tool: list[str] = [
] # map what has been streamed for each tool so far to a list
self.bot_token: str = "functools"
def extract_tool_calls(
self, model_output: str,
request: ChatCompletionRequest) -> ExtractedToolCallInformation:
"""
Extract the tool calls from a complete model response.
"""
logger.debug("Model output: %s", model_output)
pattern = r'functools\[(.*?)\]'
matches = re.search(pattern, model_output, re.DOTALL)
if not matches:
logger.debug("No function calls found")
return ExtractedToolCallInformation(tools_called=False,
tool_calls=[],
content=model_output)
try:
function_call_arr: list[dict[str, Any]] = []
try:
json_content = '[' + matches.group(1) + ']'
function_call_arr = json.loads(json_content)
logger.debug("Successfully extracted %d function calls",
len(function_call_arr))
except json.JSONDecodeError as e:
logger.error(
"Failed to parse function calls from model output: %s. "
"Error: %s", model_output, str(e))
tool_calls: list[ToolCall] = [
ToolCall(
id=f"chatcmpl-tool-{random_uuid()}",
type="function",
function=FunctionCall(
name=raw_function_call["name"],
# function call args are JSON but as a string
arguments=json.dumps(
raw_function_call["arguments"] if "arguments" in
raw_function_call else
raw_function_call["parameters"])))
for raw_function_call in function_call_arr
]
# get any content before the tool call
ret = ExtractedToolCallInformation(tools_called=True,
tool_calls=tool_calls,
content=None)
return ret
except Exception:
return ExtractedToolCallInformation(tools_called=False,
tool_calls=[],
content=model_output)
def extract_tool_calls_streaming(
self,
previous_text: str,
current_text: str,
delta_text: str,
previous_token_ids: Sequence[int],
current_token_ids: Sequence[int],
delta_token_ids: Sequence[int],
request: ChatCompletionRequest,
) -> Optional[DeltaMessage]:
return None
...@@ -2,11 +2,16 @@ ...@@ -2,11 +2,16 @@
import asyncio import asyncio
import functools import functools
import os
from fastapi import Request from fastapi import Request
from fastapi.responses import JSONResponse, StreamingResponse from fastapi.responses import JSONResponse, StreamingResponse
from starlette.background import BackgroundTask, BackgroundTasks from starlette.background import BackgroundTask, BackgroundTasks
from vllm.logger import init_logger
logger = init_logger(__name__)
async def listen_for_disconnect(request: Request) -> None: async def listen_for_disconnect(request: Request) -> None:
"""Returns if a disconnect message is received""" """Returns if a disconnect message is received"""
...@@ -68,13 +73,20 @@ def decrement_server_load(request: Request): ...@@ -68,13 +73,20 @@ def decrement_server_load(request: Request):
def load_aware_call(func): def load_aware_call(func):
@functools.wraps(func) @functools.wraps(func)
async def wrapper(*args, raw_request: Request, **kwargs): async def wrapper(*args, **kwargs):
raw_request = kwargs.get("raw_request",
args[1] if len(args) > 1 else None)
if raw_request is None:
raise ValueError(
"raw_request required when server load tracking is enabled")
if not raw_request.app.state.enable_server_load_tracking: if not raw_request.app.state.enable_server_load_tracking:
return await func(*args, raw_request=raw_request, **kwargs) return await func(*args, **kwargs)
raw_request.app.state.server_load_metrics += 1 raw_request.app.state.server_load_metrics += 1
try: try:
response = await func(*args, raw_request=raw_request, **kwargs) response = await func(*args, **kwargs)
except Exception: except Exception:
raw_request.app.state.server_load_metrics -= 1 raw_request.app.state.server_load_metrics -= 1
raise raise
...@@ -101,3 +113,24 @@ def load_aware_call(func): ...@@ -101,3 +113,24 @@ def load_aware_call(func):
return response return response
return wrapper return wrapper
def cli_env_setup():
# The safest multiprocessing method is `spawn`, as the default `fork` method
# is not compatible with some accelerators. The default method will be
# changing in future versions of Python, so we should use it explicitly when
# possible.
#
# We only set it here in the CLI entrypoint, because changing to `spawn`
# could break some existing code using vLLM as a library. `spawn` will cause
# unexpected behavior if the code is not protected by
# `if __name__ == "__main__":`.
#
# References:
# - https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods
# - https://pytorch.org/docs/stable/notes/multiprocessing.html#cuda-in-multiprocessing
# - https://pytorch.org/docs/stable/multiprocessing.html#sharing-cuda-tensors
# - https://docs.habana.ai/en/latest/PyTorch/Getting_Started_with_PyTorch_and_Gaudi/Getting_Started_with_PyTorch.html?highlight=multiprocessing#torch-multiprocessing-for-dataloaders
if "VLLM_WORKER_MULTIPROC_METHOD" not in os.environ:
logger.debug("Setting VLLM_WORKER_MULTIPROC_METHOD to 'spawn'")
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment