Commit fcfc474d authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.8.3' into v0.8.3-dev

parents bb94d2e5 296c6572
......@@ -133,8 +133,9 @@ class RPCSleepRequest(Enum):
SLEEP_LEVEL_2 = 2
class RPCWakeUpRequest(Enum):
WAKE_UP = 1
@dataclass
class RPCWakeUpRequest:
tags: Optional[list[str]] = None
@dataclass
......
......@@ -697,10 +697,10 @@ class MQLLMEngineClient(EngineClient):
return await self._send_one_way_rpc_request(
request=RPCSleepRequest(level), socket=self.input_socket)
async def wake_up(self) -> None:
async def wake_up(self, tags: Optional[list[str]] = None) -> None:
"""Wake up the engine"""
return await self._send_one_way_rpc_request(
request=RPCWakeUpRequest.WAKE_UP, socket=self.input_socket)
request=RPCWakeUpRequest(tags), socket=self.input_socket)
async def is_sleeping(self) -> bool:
"""Check whether the engine is sleeping"""
......
......@@ -274,7 +274,7 @@ class MQLLMEngine:
elif isinstance(request, RPCSleepRequest):
self.sleep(request.value)
elif isinstance(request, RPCWakeUpRequest):
self.wake_up()
self.wake_up(request.tags)
elif isinstance(request, RPCIsSleepingRequest):
self._handle_is_sleeping_request(request)
else:
......@@ -415,8 +415,8 @@ class MQLLMEngine:
def sleep(self, level: int = 1) -> None:
self.engine.sleep(level)
def wake_up(self) -> None:
self.engine.wake_up()
def wake_up(self, tags: Optional[list[str]] = None) -> None:
self.engine.wake_up(tags)
def is_sleeping(self) -> bool:
return self.engine.is_sleeping()
......
......@@ -282,7 +282,7 @@ class EngineClient(ABC):
...
@abstractmethod
async def wake_up(self) -> None:
async def wake_up(self, tags: Optional[list[str]] = None) -> None:
"""Wake up the engine"""
...
......
......@@ -306,7 +306,24 @@ def _detect_content_format(
return "openai"
def _resolve_hf_chat_template(
def resolve_mistral_chat_template(
chat_template: Optional[str],
**kwargs: Any,
) -> Optional[str]:
if chat_template is not None:
logger.warning_once(
"'chat_template' cannot be overridden for mistral tokenizer.")
if "add_generation_prompt" in kwargs:
logger.warning_once(
"'add_generation_prompt' is not supported for mistral tokenizer, "
"so it will be ignored.")
if "continue_final_message" in kwargs:
logger.warning_once(
"'continue_final_message' is not supported for mistral tokenizer, "
"so it will be ignored.")
return None
def resolve_hf_chat_template(
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
chat_template: Optional[str],
tools: Optional[list[dict[str, Any]]],
......@@ -352,7 +369,7 @@ def _resolve_chat_template_content_format(
trust_remote_code: bool,
) -> _ChatTemplateContentFormat:
if isinstance(tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)):
hf_chat_template = _resolve_hf_chat_template(
hf_chat_template = resolve_hf_chat_template(
tokenizer,
chat_template=chat_template,
trust_remote_code=trust_remote_code,
......@@ -470,7 +487,8 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
return "<|endoftext10|>" # 200010 (see vocab.json in hf model)
if model_type in ("minicpmo", "minicpmv"):
return "(<image>./</image>)"
if model_type in ("blip-2", "fuyu", "paligemma", "pixtral"):
if model_type in ("blip-2", "fuyu", "paligemma", "pixtral",
"mistral3"):
# These models do not use image tokens in the prompt
return None
if model_type == "qwen":
......@@ -478,10 +496,11 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
if model_type.startswith("llava"):
return self._cached_token_str(self._tokenizer,
hf_config.image_token_index)
if model_type in ("chameleon", "deepseek_vl_v2", "internvl_chat",
"NVLM_D", "h2ovl_chat"):
if model_type in ("aya_vision", "chameleon", "deepseek_vl_v2",
"internvl_chat", "skywork_chat", "NVLM_D",
"h2ovl_chat"):
return "<image>"
if model_type == "mllama":
if model_type in ("mllama", "llama4"):
return "<|image|>"
if model_type in ("qwen2_vl", "qwen2_5_vl"):
return "<|vision_start|><|image_pad|><|vision_end|>"
......@@ -1140,7 +1159,7 @@ def apply_hf_chat_template(
tokenize: bool = False, # Different from HF's default
**kwargs: Any,
) -> str:
hf_chat_template = _resolve_hf_chat_template(
hf_chat_template = resolve_hf_chat_template(
tokenizer,
chat_template=chat_template,
tools=tools,
......@@ -1169,17 +1188,12 @@ def apply_mistral_chat_template(
tools: Optional[list[dict[str, Any]]],
**kwargs: Any,
) -> list[int]:
if chat_template is not None:
logger.warning_once(
"'chat_template' cannot be overridden for mistral tokenizer.")
if "add_generation_prompt" in kwargs:
logger.warning_once(
"'add_generation_prompt' is not supported for mistral tokenizer, "
"so it will be ignored.")
if "continue_final_message" in kwargs:
logger.warning_once(
"'continue_final_message' is not supported for mistral tokenizer, "
"so it will be ignored.")
# The return value of resolve_mistral_chat_template is always None,
# and we won't use it.
resolve_mistral_chat_template(
chat_template=chat_template,
**kwargs,
)
return tokenizer.apply_chat_template(
messages=messages,
......
# SPDX-License-Identifier: Apache-2.0
# The CLI entrypoint to vLLM.
import os
import signal
import sys
......@@ -9,11 +8,9 @@ import vllm.entrypoints.cli.benchmark.main
import vllm.entrypoints.cli.openai
import vllm.entrypoints.cli.serve
import vllm.version
from vllm.logger import init_logger
from vllm.entrypoints.utils import cli_env_setup
from vllm.utils import FlexibleArgumentParser
logger = init_logger(__name__)
CMD_MODULES = [
vllm.entrypoints.cli.openai,
vllm.entrypoints.cli.serve,
......@@ -30,29 +27,8 @@ def register_signal_handlers():
signal.signal(signal.SIGTSTP, signal_handler)
def env_setup():
# The safest multiprocessing method is `spawn`, as the default `fork` method
# is not compatible with some accelerators. The default method will be
# changing in future versions of Python, so we should use it explicitly when
# possible.
#
# We only set it here in the CLI entrypoint, because changing to `spawn`
# could break some existing code using vLLM as a library. `spawn` will cause
# unexpected behavior if the code is not protected by
# `if __name__ == "__main__":`.
#
# References:
# - https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods
# - https://pytorch.org/docs/stable/notes/multiprocessing.html#cuda-in-multiprocessing
# - https://pytorch.org/docs/stable/multiprocessing.html#sharing-cuda-tensors
# - https://docs.habana.ai/en/latest/PyTorch/Getting_Started_with_PyTorch_and_Gaudi/Getting_Started_with_PyTorch.html?highlight=multiprocessing#torch-multiprocessing-for-dataloaders
if "VLLM_WORKER_MULTIPROC_METHOD" not in os.environ:
logger.debug("Setting VLLM_WORKER_MULTIPROC_METHOD to 'spawn'")
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
def main():
env_setup()
cli_env_setup()
parser = FlexibleArgumentParser(description="vLLM CLI")
parser.add_argument('-v',
......
......@@ -4,7 +4,6 @@ import argparse
import uvloop
from vllm.engine.arg_utils import EngineArgs
from vllm.entrypoints.cli.types import CLISubcommand
from vllm.entrypoints.openai.api_server import run_server
from vllm.entrypoints.openai.cli_args import (make_arg_parser,
......@@ -21,14 +20,9 @@ class ServeSubcommand(CLISubcommand):
@staticmethod
def cmd(args: argparse.Namespace) -> None:
# The default value of `--model`
if args.model != EngineArgs.model:
raise ValueError(
"With `vllm serve`, you should provide the model as a "
"positional argument instead of via the `--model` option.")
# EngineArgs expects the model name to be passed as --model.
args.model = args.model_tag
# If model is specified in CLI (as positional arg), it takes precedence
if hasattr(args, 'model_tag') and args.model_tag is not None:
args.model = args.model_tag
uvloop.run(run_server(args))
......@@ -41,10 +35,12 @@ class ServeSubcommand(CLISubcommand):
serve_parser = subparsers.add_parser(
"serve",
help="Start the vLLM OpenAI Compatible API server",
usage="vllm serve <model_tag> [options]")
usage="vllm serve [model_tag] [options]")
serve_parser.add_argument("model_tag",
type=str,
help="The model tag to serve")
nargs='?',
help="The model tag to serve "
"(optional if specified in config)")
serve_parser.add_argument(
"--config",
type=str,
......
......@@ -492,8 +492,8 @@ class LLM:
It is recommended to use this API to only pass control messages,
and set up data-plane communication to pass data.
"""
executor = self.llm_engine.model_executor
return executor.collective_rpc(method, timeout, args, kwargs)
return self.llm_engine.collective_rpc(method, timeout, args, kwargs)
def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
"""
......@@ -1200,26 +1200,35 @@ class LLM:
The caller should guarantee that no requests are being processed
during the sleep period, before `wake_up` is called.
:param level: The sleep level. Level 1 sleep will offload the model
weights and discard the kv cache. The content of kv cache is
forgotten. Level 1 sleep is good for sleeping and waking up the
engine to run the same model again. The model weights are backed
up in CPU memory. Please make sure there's enough CPU memory to
store the model weights. Level 2 sleep will discard both the model
weights and the kv cache. The content of both the model weights
and kv cache is forgotten. Level 2 sleep is good for sleeping and
waking up the engine to run a different model or update the model,
where previous model weights are not needed. It reduces CPU memory
pressure.
Args:
level: The sleep level. Level 1 sleep will offload the model
weights and discard the kv cache. The content of kv cache
is forgotten. Level 1 sleep is good for sleeping and waking
up the engine to run the same model again. The model weights
are backed up in CPU memory. Please make sure there's enough
CPU memory to store the model weights. Level 2 sleep will
discard both the model weights and the kv cache. The content
of both the model weights and kv cache is forgotten. Level 2
sleep is good for sleeping and waking up the engine to run a
different model or update the model, where previous model
weights are not needed. It reduces CPU memory pressure.
"""
self.reset_prefix_cache()
self.llm_engine.sleep(level=level)
def wake_up(self):
def wake_up(self, tags: Optional[list[str]] = None):
"""
Wake up the engine from sleep mode. See the :meth:`sleep` method
for more details."""
self.llm_engine.wake_up()
for more details.
Args:
tags: An optional list of tags to reallocate the engine memory
for specific memory allocations. Values must be in
("weights", "kv_cache",). If None, all memory is reallocated.
wake_up should be called with all tags (or None) before the
engine is used again.
"""
self.llm_engine.wake_up(tags)
# LEGACY
def _convert_v1_inputs(
......
......@@ -24,6 +24,7 @@ from fastapi import APIRouter, Depends, FastAPI, Form, HTTPException, Request
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, Response, StreamingResponse
from starlette.concurrency import iterate_in_threadpool
from starlette.datastructures import State
from starlette.routing import Mount
from typing_extensions import assert_never
......@@ -35,7 +36,9 @@ from vllm.engine.async_llm_engine import AsyncLLMEngine # type: ignore
from vllm.engine.multiprocessing.client import MQLLMEngineClient
from vllm.engine.multiprocessing.engine import run_mp_engine
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import load_chat_template
from vllm.entrypoints.chat_utils import (load_chat_template,
resolve_hf_chat_template,
resolve_mistral_chat_template)
from vllm.entrypoints.launcher import serve_http
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.cli_args import (make_arg_parser,
......@@ -65,7 +68,6 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
TranscriptionRequest,
TranscriptionResponse,
UnloadLoRAAdapterRequest)
from vllm.entrypoints.openai.reasoning_parsers import ReasoningParserManager
# yapf: enable
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
......@@ -80,10 +82,13 @@ from vllm.entrypoints.openai.serving_tokenization import (
from vllm.entrypoints.openai.serving_transcription import (
OpenAIServingTranscription)
from vllm.entrypoints.openai.tool_parsers import ToolParserManager
from vllm.entrypoints.utils import load_aware_call, with_cancellation
from vllm.entrypoints.utils import (cli_env_setup, load_aware_call,
with_cancellation)
from vllm.logger import init_logger
from vllm.reasoning import ReasoningParserManager
from vllm.transformers_utils.config import (
maybe_register_config_serialize_by_value)
from vllm.transformers_utils.tokenizer import MistralTokenizer
from vllm.usage.usage_lib import UsageContext
from vllm.utils import (Device, FlexibleArgumentParser, get_open_zmq_ipc_path,
is_valid_ipv6_address, set_ulimit)
......@@ -307,6 +312,7 @@ def mount_metrics(app: FastAPI):
# See https://prometheus.github.io/client_python/multiprocess/
from prometheus_client import (CollectorRegistry, make_asgi_app,
multiprocess)
from prometheus_fastapi_instrumentator import Instrumentator
prometheus_multiproc_dir_path = os.getenv("PROMETHEUS_MULTIPROC_DIR", None)
if prometheus_multiproc_dir_path is not None:
......@@ -314,6 +320,16 @@ def mount_metrics(app: FastAPI):
prometheus_multiproc_dir_path)
registry = CollectorRegistry()
multiprocess.MultiProcessCollector(registry)
Instrumentator(
excluded_handlers=[
"/metrics",
"/health",
"/load",
"/ping",
"/version",
],
registry=registry,
).add().instrument(app).expose(app)
# Add prometheus asgi middleware to route /metrics requests
metrics_route = Mount("/metrics", make_asgi_app(registry=registry))
......@@ -689,7 +705,6 @@ if envs.VLLM_SERVER_DEV_MODE:
async def sleep(raw_request: Request):
# get POST params
level = raw_request.query_params.get("level", "1")
logger.info("sleep the engine with level %s", level)
await engine_client(raw_request).sleep(int(level))
# FIXME: in v0 with frontend multiprocessing, the sleep command
# is sent but does not finish yet when we return a response.
......@@ -697,8 +712,12 @@ if envs.VLLM_SERVER_DEV_MODE:
@router.post("/wake_up")
async def wake_up(raw_request: Request):
logger.info("wake up the engine")
await engine_client(raw_request).wake_up()
tags = raw_request.query_params.getlist("tags")
if tags == []:
# set to None to wake up all tags if no tags are provided
tags = None
logger.info("wake up the engine with tags: %s", tags)
await engine_client(raw_request).wake_up(tags)
# FIXME: in v0 with frontend multiprocessing, the wake-up command
# is sent but does not finish yet when we return a response.
return Response(status_code=200)
......@@ -814,7 +833,8 @@ def build_app(args: Namespace) -> FastAPI:
return JSONResponse(err.model_dump(),
status_code=HTTPStatus.BAD_REQUEST)
if token := envs.VLLM_API_KEY or args.api_key:
# Ensure --api-key option from CLI takes precedence over VLLM_API_KEY
if token := args.api_key or envs.VLLM_API_KEY:
@app.middleware("http")
async def authentication(request: Request, call_next):
......@@ -843,6 +863,21 @@ def build_app(args: Namespace) -> FastAPI:
response.headers["X-Request-Id"] = request_id
return response
if envs.VLLM_DEBUG_LOG_API_SERVER_RESPONSE:
logger.warning("CAUTION: Enabling log response in the API Server. "
"This can include sensitive information and should be "
"avoided in production.")
@app.middleware("http")
async def log_response(request: Request, call_next):
response = await call_next(request)
response_body = [
section async for section in response.body_iterator
]
response.body_iterator = iterate_in_threadpool(iter(response_body))
logger.info("response_body={%s}", response_body[0].decode())
return response
for middleware in args.middleware:
module_path, object_name = middleware.rsplit(".", 1)
imported = getattr(importlib.import_module(module_path), object_name)
......@@ -883,8 +918,26 @@ async def init_app_state(
resolved_chat_template = load_chat_template(args.chat_template)
if resolved_chat_template is not None:
logger.info("Using supplied chat template:\n%s",
resolved_chat_template)
# Get the tokenizer to check official template
tokenizer = await engine_client.get_tokenizer()
if isinstance(tokenizer, MistralTokenizer):
# The warning is logged in resolve_mistral_chat_template.
resolved_chat_template = resolve_mistral_chat_template(
chat_template=resolved_chat_template)
else:
hf_chat_template = resolve_hf_chat_template(
tokenizer,
chat_template=None,
tools=None,
trust_remote_code=model_config.trust_remote_code)
if hf_chat_template != resolved_chat_template:
logger.warning(
"Using supplied chat template: %s\n"
"It is different from official chat template '%s'. "
"This discrepancy may lead to performance degradation.",
resolved_chat_template, args.model)
state.openai_serving_models = OpenAIServingModels(
engine_client=engine_client,
......@@ -1048,15 +1101,17 @@ async def run_server(args, **uvicorn_kwargs) -> None:
)
# NB: Await server shutdown only after the backend context is exited
await shutdown_task
sock.close()
try:
await shutdown_task
finally:
sock.close()
if __name__ == "__main__":
# NOTE(simon):
# This section should be in sync with vllm/entrypoints/cli/main.py for CLI
# entrypoints.
cli_env_setup()
parser = FlexibleArgumentParser(
description="vLLM OpenAI-Compatible RESTful API server.")
parser = make_arg_parser(parser)
......
......@@ -247,7 +247,7 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
default=None,
help='Max number of prompt characters or prompt '
'ID numbers being printed in log.'
'\n\nDefault: Unlimited')
' The default of None means unlimited.')
parser.add_argument(
"--disable-fastapi-docs",
......
......@@ -61,7 +61,7 @@ class OpenAIBaseModel(BaseModel):
field_names = set()
for field_name, field in cls.model_fields.items():
field_names.add(field_name)
if alias := getattr(field, 'alias', None):
if alias := getattr(field, "alias", None):
field_names.add(alias)
cls.field_names = field_names
......@@ -70,7 +70,8 @@ class OpenAIBaseModel(BaseModel):
logger.warning(
"The following fields were present in the request "
"but ignored: %s",
data.keys() - field_names)
data.keys() - field_names,
)
return result
......@@ -234,8 +235,12 @@ class ChatCompletionRequest(OpenAIBaseModel):
temperature: Optional[float] = None
top_p: Optional[float] = None
tools: Optional[list[ChatCompletionToolsParam]] = None
tool_choice: Optional[Union[Literal["none"], Literal["auto"],
ChatCompletionNamedToolChoiceParam]] = "none"
tool_choice: Optional[Union[
Literal["none"],
Literal["auto"],
Literal["required"],
ChatCompletionNamedToolChoiceParam,
]] = "none"
# NOTE this will be ignored by vLLM -- the model determines the behavior
parallel_tool_calls: Optional[bool] = False
......@@ -340,24 +345,28 @@ class ChatCompletionRequest(OpenAIBaseModel):
description=(
"If specified, will override the default guided decoding backend "
"of the server for this specific request. If set, must be either "
"'outlines' / 'lm-format-enforcer'"))
"'outlines' / 'lm-format-enforcer'"),
)
guided_whitespace_pattern: Optional[str] = Field(
default=None,
description=(
"If specified, will override the default whitespace pattern "
"for guided json decoding."))
"for guided json decoding."),
)
priority: int = Field(
default=0,
description=(
"The priority of the request (lower means earlier handling; "
"default: 0). Any priority other than 0 will raise an error "
"if the served model does not use priority scheduling."))
"if the served model does not use priority scheduling."),
)
request_id: str = Field(
default_factory=lambda: f"{random_uuid()}",
description=(
"The request_id related to this request. If the caller does "
"not set it, a random_uuid will be generated. This id is used "
"through out the inference process and return in response."))
"through out the inference process and return in response."),
)
logits_processors: Optional[LogitsProcessors] = Field(
default=None,
description=(
......@@ -415,13 +424,15 @@ class ChatCompletionRequest(OpenAIBaseModel):
ignore_eos=self.ignore_eos,
temperature=temperature,
length_penalty=self.length_penalty,
include_stop_str_in_output=self.include_stop_str_in_output)
include_stop_str_in_output=self.include_stop_str_in_output,
)
def to_sampling_params(
self,
default_max_tokens: int,
logits_processor_pattern: Optional[str],
default_sampling_params: Optional[dict] = None) -> SamplingParams:
self,
default_max_tokens: int,
logits_processor_pattern: Optional[str],
default_sampling_params: Optional[dict] = None,
) -> SamplingParams:
# TODO(#9845): remove max_tokens when field is removed from OpenAI API
max_tokens = self.max_completion_tokens or self.max_tokens
......@@ -475,7 +486,8 @@ class ChatCompletionRequest(OpenAIBaseModel):
grammar=self.guided_grammar,
json_object=guided_json_object,
backend=self.guided_decoding_backend,
whitespace_pattern=self.guided_whitespace_pattern)
whitespace_pattern=self.guided_whitespace_pattern,
)
return SamplingParams.from_optional(
n=self.n,
......@@ -522,6 +534,41 @@ class ChatCompletionRequest(OpenAIBaseModel):
tool = tools[tool_name]
return tool.parameters
if self.tool_choice == "required":
# Pydantic schema generation cannot be used since the JSON schema
# has to be constructed for a specific instantiation of a tool list
# so that parameters of a function are correctly generated
# based on the chosen function name
def get_tool_schema(tool: ChatCompletionToolsParam) -> dict:
return {
"properties": {
"name": {
"type": "string",
"enum": [tool.function.name]
},
# parameters are always generated as '{}' in the final
# output if they are missing from the request
# (i.e. are None or '{}') so the schema is
# updated to produce an empty object in that case
"parameters": tool.function.parameters
if tool.function.parameters else {
"type": "object",
"properties": {}
}
},
"required": ["name", "parameters"]
}
json_schema = {
"type": "array",
"minItems": 1,
"items": {
"type": "object",
"anyOf": [get_tool_schema(tool) for tool in self.tools]
}
}
return json_schema
return None
@model_validator(mode="before")
......@@ -572,8 +619,11 @@ class ChatCompletionRequest(OpenAIBaseModel):
"You can only use one kind of guided decoding "
"('guided_json', 'guided_regex' or 'guided_choice').")
# you can only either use guided decoding or tools, not both
if guide_count > 1 and data.get("tool_choice",
"none") not in ("none", "auto"):
if guide_count > 1 and data.get("tool_choice", "none") not in (
"none",
"auto",
"required",
):
raise ValueError(
"You can only either use guided decoding or tools, not both.")
return data
......@@ -602,12 +652,15 @@ class ChatCompletionRequest(OpenAIBaseModel):
"When using `tool_choice`, `tools` must be set.")
# make sure that tool choice is either a named tool
# OR that it's set to "auto"
if data["tool_choice"] != "auto" and not isinstance(
data["tool_choice"], dict):
raise ValueError(
"`tool_choice` must either be a named tool, \"auto\", "
"or \"none\".")
# OR that it's set to "auto" or "required"
if data["tool_choice"] not in [
"auto", "required"
] and not isinstance(data["tool_choice"], dict):
raise NotImplementedError(
f'Invalid value for `tool_choice`: {data["tool_choice"]}! '\
'Only named tools, "none", "auto" or "required" '\
'are supported.'
)
# ensure that if "tool_choice" is specified as an object,
# it matches a valid tool
......@@ -722,18 +775,21 @@ class CompletionRequest(OpenAIBaseModel):
description=(
"If specified, will override the default guided decoding backend "
"of the server for this specific request. If set, must be one of "
"'outlines' / 'lm-format-enforcer'"))
"'outlines' / 'lm-format-enforcer'"),
)
guided_whitespace_pattern: Optional[str] = Field(
default=None,
description=(
"If specified, will override the default whitespace pattern "
"for guided json decoding."))
"for guided json decoding."),
)
priority: int = Field(
default=0,
description=(
"The priority of the request (lower means earlier handling; "
"default: 0). Any priority other than 0 will raise an error "
"if the served model does not use priority scheduling."))
"if the served model does not use priority scheduling."),
)
logits_processors: Optional[LogitsProcessors] = Field(
default=None,
description=(
......@@ -745,6 +801,7 @@ class CompletionRequest(OpenAIBaseModel):
"arguments. For example: {'qualname': "
"'my_module.MyLogitsProcessor', 'args': [1, 2], 'kwargs': "
"{'param': 'value'}}."))
return_tokens_as_token_ids: Optional[bool] = Field(
default=None,
description=(
......@@ -789,13 +846,15 @@ class CompletionRequest(OpenAIBaseModel):
ignore_eos=self.ignore_eos,
temperature=temperature,
length_penalty=self.length_penalty,
include_stop_str_in_output=self.include_stop_str_in_output)
include_stop_str_in_output=self.include_stop_str_in_output,
)
def to_sampling_params(
self,
default_max_tokens: int,
logits_processor_pattern: Optional[str],
default_sampling_params: Optional[dict] = None) -> SamplingParams:
self,
default_max_tokens: int,
logits_processor_pattern: Optional[str],
default_sampling_params: Optional[dict] = None,
) -> SamplingParams:
max_tokens = self.max_tokens
if default_sampling_params is None:
......@@ -844,7 +903,8 @@ class CompletionRequest(OpenAIBaseModel):
grammar=self.guided_grammar,
json_object=guided_json_object,
backend=self.guided_decoding_backend,
whitespace_pattern=self.guided_whitespace_pattern)
whitespace_pattern=self.guided_whitespace_pattern,
)
return SamplingParams.from_optional(
n=self.n,
......@@ -942,7 +1002,8 @@ class EmbeddingCompletionRequest(OpenAIBaseModel):
description=(
"The priority of the request (lower means earlier handling; "
"default: 0). Any priority other than 0 will raise an error "
"if the served model does not use priority scheduling."))
"if the served model does not use priority scheduling."),
)
# doc: end-embedding-extra-params
......@@ -995,7 +1056,8 @@ class EmbeddingChatRequest(OpenAIBaseModel):
description=(
"The priority of the request (lower means earlier handling; "
"default: 0). Any priority other than 0 will raise an error "
"if the served model does not use priority scheduling."))
"if the served model does not use priority scheduling."),
)
# doc: end-chat-embedding-extra-params
@model_validator(mode="before")
......@@ -1034,7 +1096,8 @@ class ScoreRequest(OpenAIBaseModel):
description=(
"The priority of the request (lower means earlier handling; "
"default: 0). Any priority other than 0 will raise an error "
"if the served model does not use priority scheduling."))
"if the served model does not use priority scheduling."),
)
# doc: end-score-extra-params
......@@ -1059,7 +1122,8 @@ class RerankRequest(OpenAIBaseModel):
description=(
"The priority of the request (lower means earlier handling; "
"default: 0). Any priority other than 0 will raise an error "
"if the served model does not use priority scheduling."))
"if the served model does not use priority scheduling."),
)
# doc: end-rerank-extra-params
......@@ -1238,6 +1302,9 @@ class ChatCompletionLogProb(OpenAIBaseModel):
class ChatCompletionLogProbsContent(ChatCompletionLogProb):
# Workaround: redefine fields name cache so that it's not
# shared with the super class.
field_names: ClassVar[Optional[set[str]]] = None
top_logprobs: list[ChatCompletionLogProb] = Field(default_factory=list)
......
......@@ -2,13 +2,16 @@
import asyncio
import json
import re
import time
from collections.abc import AsyncGenerator, AsyncIterator
from collections.abc import Sequence as GenericSequence
from typing import Callable, Final, Optional, Union
import jinja2
import partial_json_parser
from fastapi import Request
from pydantic import TypeAdapter
from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient
......@@ -21,10 +24,8 @@ from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest, ChatCompletionResponse,
ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
ChatCompletionStreamResponse, ChatMessage, DeltaFunctionCall, DeltaMessage,
DeltaToolCall, ErrorResponse, FunctionCall, PromptTokenUsageInfo,
RequestResponseMetadata, ToolCall, UsageInfo)
from vllm.entrypoints.openai.reasoning_parsers import (ReasoningParser,
ReasoningParserManager)
DeltaToolCall, ErrorResponse, FunctionCall, FunctionDefinition,
PromptTokenUsageInfo, RequestResponseMetadata, ToolCall, UsageInfo)
from vllm.entrypoints.openai.serving_engine import (OpenAIServing,
clamp_prompt_logprobs)
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
......@@ -33,6 +34,7 @@ from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import (
MistralToolCall)
from vllm.logger import init_logger
from vllm.outputs import CompletionOutput, RequestOutput
from vllm.reasoning import ReasoningParser, ReasoningParserManager
from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.sequence import Logprob
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
......@@ -151,12 +153,6 @@ class OpenAIServingChat(OpenAIServing):
tool_parser = self.tool_parser
# validation for OpenAI tools
# tool_choice = "required" is not supported
if request.tool_choice == "required":
return self.create_error_response(
"tool_choice = \"required\" is not supported!")
if isinstance(tokenizer, MistralTokenizer):
# because of issues with pydantic we need to potentially
# re-serialize the tool_calls field of the request
......@@ -197,16 +193,8 @@ class OpenAIServingChat(OpenAIServing):
truncate_prompt_tokens=request.truncate_prompt_tokens,
add_special_tokens=request.add_special_tokens,
)
except ValueError as e:
logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(str(e))
except TypeError as e:
logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(str(e))
except RuntimeError as e:
logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(str(e))
except jinja2.TemplateError as e:
except (ValueError, TypeError, RuntimeError,
jinja2.TemplateError) as e:
logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(str(e))
......@@ -286,6 +274,122 @@ class OpenAIServingChat(OpenAIServing):
return self.response_role
return request.messages[-1]["role"]
@staticmethod
def _bracket_level(s: str, opening='{', closing='}') -> int:
"""
Calculate the current level of nested brackets in a given string.
"""
level = 0
for char in s:
if char == opening:
level += 1
elif char == closing:
level -= 1
return level
@staticmethod
def _filter_delta_text(delta_text: str,
previous_text: str) -> tuple[str, bool]:
# remove last '},' of the tool definition stemming from the
# "name"/"parameters" outer object or closing ']' of the tool list
# count occurrences of opening and closing curly braces and
# once level 0 is reached stop outputting text
# if 0 is reached while parsing the delta_text we know the current
# tool will finish in this current iteration
bracket_level = OpenAIServingChat._bracket_level(previous_text)
updated_delta, passed_zero = "", False
for c in delta_text:
if c == '{':
bracket_level += 1
passed_zero = bracket_level == 0
elif c == '}':
bracket_level -= 1
passed_zero = bracket_level == 0
if bracket_level != 0:
updated_delta += c
else:
# if a comma is reached at level 0 we can stop
if c == ',':
break
return updated_delta, passed_zero
def extract_tool_call_required_streaming(
self,
previous_text: str,
current_text: str,
delta_text: str,
function_name_returned: bool,
) -> tuple[Optional[DeltaMessage], bool]:
try:
obj = partial_json_parser.loads(current_text)
except partial_json_parser.core.exceptions.MalformedJSON:
logger.debug('not enough tokens to parse into JSON yet')
obj = None
# check if the current text is a valid array
# containing a partial tool calling object
# if not repeat
if obj is None or not isinstance(obj, list) or not len(obj) > 0:
function_name_returned = False
delta_message = None
else:
_, finishes_previous_tool = OpenAIServingChat._filter_delta_text(
delta_text, previous_text)
# take the last tool call from the generated list
current_tool_call = obj[-1]
# once parameters have been generated the name is complete as well
if not finishes_previous_tool and ("name" not in current_tool_call
or "parameters"
not in current_tool_call):
function_name_returned = False
delta_message = None
else:
if not function_name_returned:
# get partly generated arguments from the latest tool call
param_match = re.search(r'.*"parameters":\s*(.*)',
current_text)
arguments = param_match.group(1) if param_match else ""
arguments, _ = OpenAIServingChat._filter_delta_text(
arguments, previous_text)
# if this iteration finishes a previous tool call but a
# new incomplete tool is already generated, take the
# previous from the list
if (finishes_previous_tool
and "parameters" not in current_tool_call):
current_tool_call = obj[-2]
function_name_returned = True
delta_message = DeltaMessage(tool_calls=[
DeltaToolCall(function=DeltaFunctionCall(
name=current_tool_call["name"],
arguments=arguments),
index=len(obj) - 1,
type="function")
])
else:
delta_text, _ = OpenAIServingChat._filter_delta_text(
delta_text, previous_text)
if delta_text != "":
delta_message = DeltaMessage(tool_calls=[
DeltaToolCall(
function=DeltaFunctionCall(
# OpenAI API returns None
# instead of name every time
name=None,
arguments=delta_text),
index=len(obj) - 1,
type="function")
])
else:
delta_message = None
return delta_message, function_name_returned
async def chat_completion_stream_generator(
self,
request: ChatCompletionRequest,
......@@ -321,6 +425,7 @@ class OpenAIServingChat(OpenAIServing):
self._should_stream_with_reasoning_parsing(request))
all_previous_token_ids: Optional[list[list[int]]]
function_name_returned: Optional[list[bool]] = None
# Only one of these will be used, thus previous_texts and
# all_previous_token_ids will not be used twice in the same iteration.
......@@ -331,6 +436,10 @@ class OpenAIServingChat(OpenAIServing):
# For reasoning parser and tool call all enabled
added_content_delta_arr = [False] * num_choices
reasoning_end_arr = [False] * num_choices
elif request.tool_choice == "required":
previous_texts = [""] * num_choices
function_name_returned = [False] * num_choices
all_previous_token_ids = None
else:
previous_texts, all_previous_token_ids = None, None
......@@ -530,6 +639,23 @@ class OpenAIServingChat(OpenAIServing):
index=i)
])
elif request.tool_choice == "required":
assert previous_texts is not None
assert function_name_returned is not None
previous_text = previous_texts[i]
current_text = previous_text + delta_text
fn_name_returned = function_name_returned[i]
delta_message, function_name_returned[i] = (
self.extract_tool_call_required_streaming(
previous_text=previous_text,
current_text=current_text,
delta_text=delta_text,
function_name_returned=fn_name_returned))
# update the previous values for the next iteration
previous_texts[i] = current_text
# handle streaming deltas for tools with "auto" tool choice
# and reasoning parser
elif tool_choice_auto and self.enable_reasoning:
......@@ -830,10 +956,10 @@ class OpenAIServingChat(OpenAIServing):
# if auto tools are not enabled, and a named tool choice using
# outlines is not being used
if (not self.enable_auto_tools
or not self.tool_parser) and not isinstance(
request.tool_choice,
ChatCompletionNamedToolChoiceParam):
if (not self.enable_auto_tools or not self.tool_parser) and \
(not isinstance(request.tool_choice,
ChatCompletionNamedToolChoiceParam
) and request.tool_choice != "required"):
message = ChatMessage(role=role,
reasoning_content=reasoning_content,
content=content)
......@@ -854,6 +980,24 @@ class OpenAIServingChat(OpenAIServing):
arguments=content))
])
elif request.tool_choice and request.tool_choice == "required":
tool_call_class = MistralToolCall if isinstance(
tokenizer, MistralTokenizer) else ToolCall
# the fields of FunctionDefinition are a superset of the
# tool call outputs and can be used for parsing
tool_calls = TypeAdapter(
list[FunctionDefinition]).validate_json(output.text)
message = ChatMessage(
role=role,
content="",
tool_calls=[
tool_call_class(function=FunctionCall(
name=tool_call.name,
arguments=json.dumps(tool_call.parameters)))
for tool_call in tool_calls
])
# if the request doesn't use tool choice
# OR specifies to not use a tool
elif not request.tool_choice or request.tool_choice == "none":
......
......@@ -139,10 +139,7 @@ class OpenAIServingEmbedding(OpenAIServing):
truncate_prompt_tokens=truncate_prompt_tokens,
add_special_tokens=request.add_special_tokens,
)
except ValueError as e:
logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(str(e))
except TypeError as e:
except (ValueError, TypeError) as e:
logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(str(e))
......
......@@ -537,7 +537,7 @@ class OpenAIServing:
lora_request: Optional[LoRARequest] = None) -> str:
if lora_request:
return lora_request.lora_name
if model_name is None:
if not model_name:
return self.models.base_model_paths[0].name
return model_name
......
......@@ -162,7 +162,7 @@ class OpenAIServingModels:
except BaseException as e:
error_type = "BadRequestError"
status_code = HTTPStatus.BAD_REQUEST
if isinstance(e, ValueError) and "No adapter found" in str(e):
if "No adapter found" in str(e):
error_type = "NotFoundError"
status_code = HTTPStatus.NOT_FOUND
......
......@@ -136,13 +136,7 @@ class OpenAIServingPooling(OpenAIServing):
truncate_prompt_tokens=truncate_prompt_tokens,
add_special_tokens=request.add_special_tokens,
)
except ValueError as e:
logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(str(e))
except TypeError as e:
logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(str(e))
except jinja2.TemplateError as e:
except (ValueError, TypeError, jinja2.TemplateError) as e:
logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(str(e))
......
......@@ -89,13 +89,7 @@ class OpenAIServingTokenization(OpenAIServing):
request.prompt,
add_special_tokens=request.add_special_tokens,
)
except ValueError as e:
logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(str(e))
except TypeError as e:
logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(str(e))
except jinja2.TemplateError as e:
except (ValueError, TypeError, jinja2.TemplateError) as e:
logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(str(e))
......
......@@ -8,11 +8,12 @@ from .internlm2_tool_parser import Internlm2ToolParser
from .jamba_tool_parser import JambaToolParser
from .llama_tool_parser import Llama3JsonToolParser
from .mistral_tool_parser import MistralToolParser
from .phi4mini_tool_parser import Phi4MiniJsonToolParser
from .pythonic_tool_parser import PythonicToolParser
__all__ = [
"ToolParser", "ToolParserManager", "Granite20bFCToolParser",
"GraniteToolParser", "Hermes2ProToolParser", "MistralToolParser",
"Internlm2ToolParser", "Llama3JsonToolParser", "JambaToolParser",
"PythonicToolParser"
"PythonicToolParser", "Phi4MiniJsonToolParser"
]
# SPDX-License-Identifier: Apache-2.0
import json
import re
from collections.abc import Sequence
from typing import Any, Optional
from transformers import PreTrainedTokenizerBase
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaMessage,
ExtractedToolCallInformation,
FunctionCall, ToolCall)
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser, ToolParserManager)
from vllm.logger import init_logger
from vllm.utils import random_uuid
logger = init_logger(__name__)
@ToolParserManager.register_module("phi4_mini_json")
class Phi4MiniJsonToolParser(ToolParser):
"""
Tool call parser for phi-4-mini models intended for use with the
examples/tool_chat_template_llama.jinja template.
Used when --enable-auto-tool-choice --tool-call-parser phi4_mini_json
are all set
"""
def __init__(self, tokenizer: PreTrainedTokenizerBase) -> None:
super().__init__(tokenizer)
# initialize properties used for state when parsing tool calls in
# streaming mode
self.prev_tool_call_arr: list[dict[str, Any]] = []
self.current_tool_id: int = -1
self.current_tool_name_sent: bool = False
self.streamed_args_for_tool: list[str] = [
] # map what has been streamed for each tool so far to a list
self.bot_token: str = "functools"
def extract_tool_calls(
self, model_output: str,
request: ChatCompletionRequest) -> ExtractedToolCallInformation:
"""
Extract the tool calls from a complete model response.
"""
logger.debug("Model output: %s", model_output)
pattern = r'functools\[(.*?)\]'
matches = re.search(pattern, model_output, re.DOTALL)
if not matches:
logger.debug("No function calls found")
return ExtractedToolCallInformation(tools_called=False,
tool_calls=[],
content=model_output)
try:
function_call_arr: list[dict[str, Any]] = []
try:
json_content = '[' + matches.group(1) + ']'
function_call_arr = json.loads(json_content)
logger.debug("Successfully extracted %d function calls",
len(function_call_arr))
except json.JSONDecodeError as e:
logger.error(
"Failed to parse function calls from model output: %s. "
"Error: %s", model_output, str(e))
tool_calls: list[ToolCall] = [
ToolCall(
id=f"chatcmpl-tool-{random_uuid()}",
type="function",
function=FunctionCall(
name=raw_function_call["name"],
# function call args are JSON but as a string
arguments=json.dumps(
raw_function_call["arguments"] if "arguments" in
raw_function_call else
raw_function_call["parameters"])))
for raw_function_call in function_call_arr
]
# get any content before the tool call
ret = ExtractedToolCallInformation(tools_called=True,
tool_calls=tool_calls,
content=None)
return ret
except Exception:
return ExtractedToolCallInformation(tools_called=False,
tool_calls=[],
content=model_output)
def extract_tool_calls_streaming(
self,
previous_text: str,
current_text: str,
delta_text: str,
previous_token_ids: Sequence[int],
current_token_ids: Sequence[int],
delta_token_ids: Sequence[int],
request: ChatCompletionRequest,
) -> Optional[DeltaMessage]:
return None
......@@ -2,11 +2,16 @@
import asyncio
import functools
import os
from fastapi import Request
from fastapi.responses import JSONResponse, StreamingResponse
from starlette.background import BackgroundTask, BackgroundTasks
from vllm.logger import init_logger
logger = init_logger(__name__)
async def listen_for_disconnect(request: Request) -> None:
"""Returns if a disconnect message is received"""
......@@ -68,13 +73,20 @@ def decrement_server_load(request: Request):
def load_aware_call(func):
@functools.wraps(func)
async def wrapper(*args, raw_request: Request, **kwargs):
async def wrapper(*args, **kwargs):
raw_request = kwargs.get("raw_request",
args[1] if len(args) > 1 else None)
if raw_request is None:
raise ValueError(
"raw_request required when server load tracking is enabled")
if not raw_request.app.state.enable_server_load_tracking:
return await func(*args, raw_request=raw_request, **kwargs)
return await func(*args, **kwargs)
raw_request.app.state.server_load_metrics += 1
try:
response = await func(*args, raw_request=raw_request, **kwargs)
response = await func(*args, **kwargs)
except Exception:
raw_request.app.state.server_load_metrics -= 1
raise
......@@ -101,3 +113,24 @@ def load_aware_call(func):
return response
return wrapper
def cli_env_setup():
# The safest multiprocessing method is `spawn`, as the default `fork` method
# is not compatible with some accelerators. The default method will be
# changing in future versions of Python, so we should use it explicitly when
# possible.
#
# We only set it here in the CLI entrypoint, because changing to `spawn`
# could break some existing code using vLLM as a library. `spawn` will cause
# unexpected behavior if the code is not protected by
# `if __name__ == "__main__":`.
#
# References:
# - https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods
# - https://pytorch.org/docs/stable/notes/multiprocessing.html#cuda-in-multiprocessing
# - https://pytorch.org/docs/stable/multiprocessing.html#sharing-cuda-tensors
# - https://docs.habana.ai/en/latest/PyTorch/Getting_Started_with_PyTorch_and_Gaudi/Getting_Started_with_PyTorch.html?highlight=multiprocessing#torch-multiprocessing-for-dataloaders
if "VLLM_WORKER_MULTIPROC_METHOD" not in os.environ:
logger.debug("Setting VLLM_WORKER_MULTIPROC_METHOD to 'spawn'")
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment