Commit cc7f22a8 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.9.1' into v0.9.1-ori

parents b9ea0c09 b6553be1
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import signal
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import itertools
import warnings
......@@ -45,8 +46,7 @@ from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams,
from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
get_cached_tokenizer)
from vllm.usage.usage_lib import UsageContext
from vllm.utils import (Counter, Device, deprecate_args, deprecate_kwargs,
is_list_of)
from vllm.utils import Counter, Device, deprecate_kwargs, is_list_of
if TYPE_CHECKING:
from vllm.v1.metrics.reader import Metric
......@@ -143,12 +143,6 @@ class LLM:
DEPRECATE_LEGACY: ClassVar[bool] = True
"""A flag to toggle whether to deprecate the legacy generate/encode API."""
DEPRECATE_INIT_POSARGS: ClassVar[bool] = True
"""
A flag to toggle whether to deprecate positional arguments in
[LLM.__init__][].
"""
@classmethod
@contextmanager
def deprecate_legacy_api(cls):
......@@ -158,16 +152,11 @@ class LLM:
cls.DEPRECATE_LEGACY = False
@deprecate_args(
start_index=2, # Ignore self and model
is_deprecated=lambda: LLM.DEPRECATE_INIT_POSARGS,
additional_message=(
"All positional arguments other than `model` will be "
"replaced with keyword arguments in an upcoming version."),
)
def __init__(
self,
model: str,
*,
task: TaskOption = "auto",
tokenizer: Optional[str] = None,
tokenizer_mode: TokenizerMode = "auto",
skip_tokenizer_init: bool = False,
......@@ -189,8 +178,6 @@ class LLM:
hf_token: Optional[Union[bool, str]] = None,
hf_overrides: Optional[HfOverrides] = None,
mm_processor_kwargs: Optional[dict[str, Any]] = None,
# After positional args are removed, move this right below `model`
task: TaskOption = "auto",
override_pooler_config: Optional[PoolerConfig] = None,
compilation_config: Optional[Union[int, dict[str, Any]]] = None,
**kwargs,
......@@ -207,6 +194,9 @@ class LLM:
if isinstance(worker_cls, type):
kwargs["worker_cls"] = cloudpickle.dumps(worker_cls)
if hf_overrides is None:
hf_overrides = {}
if compilation_config is not None:
if isinstance(compilation_config, int):
compilation_config_instance = CompilationConfig(
......@@ -218,7 +208,7 @@ class LLM:
else:
compilation_config_instance = compilation_config
else:
compilation_config_instance = None
compilation_config_instance = CompilationConfig()
engine_args = EngineArgs(
model=model,
......@@ -291,7 +281,7 @@ class LLM:
sampling_params: Optional[Union[SamplingParams,
Sequence[SamplingParams]]] = None,
*,
use_tqdm: bool = True,
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
guided_options_request: Optional[Union[LLMGuidedOptions,
......@@ -307,7 +297,7 @@ class LLM:
sampling_params: Optional[Union[SamplingParams,
list[SamplingParams]]] = None,
prompt_token_ids: Optional[list[int]] = None,
use_tqdm: bool = True,
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
guided_options_request: Optional[Union[LLMGuidedOptions,
......@@ -323,7 +313,7 @@ class LLM:
sampling_params: Optional[Union[SamplingParams,
list[SamplingParams]]] = None,
prompt_token_ids: Optional[list[list[int]]] = None,
use_tqdm: bool = True,
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
guided_options_request: Optional[Union[LLMGuidedOptions,
......@@ -340,7 +330,7 @@ class LLM:
list[SamplingParams]]] = None,
*,
prompt_token_ids: list[int],
use_tqdm: bool = True,
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
guided_options_request: Optional[Union[LLMGuidedOptions,
......@@ -357,7 +347,7 @@ class LLM:
list[SamplingParams]]] = None,
*,
prompt_token_ids: list[list[int]],
use_tqdm: bool = True,
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
guided_options_request: Optional[Union[LLMGuidedOptions,
......@@ -372,7 +362,7 @@ class LLM:
prompts: None,
sampling_params: None,
prompt_token_ids: Union[list[int], list[list[int]]],
use_tqdm: bool = True,
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
guided_options_request: Optional[Union[LLMGuidedOptions,
......@@ -392,7 +382,7 @@ class LLM:
sampling_params: Optional[Union[SamplingParams,
Sequence[SamplingParams]]] = None,
prompt_token_ids: Optional[Union[list[int], list[list[int]]]] = None,
use_tqdm: bool = True,
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
guided_options_request: Optional[Union[LLMGuidedOptions,
......@@ -414,7 +404,10 @@ class LLM:
When it is a single value, it is applied to every prompt.
When it is a list, the list must have the same length as the
prompts and it is paired one by one with the prompt.
use_tqdm: Whether to use tqdm to display the progress bar.
use_tqdm: If `True`, shows a tqdm progress bar.
If a callable (e.g., `functools.partial(tqdm, leave=False)`),
it is used to create the progress bar.
If `False`, no progress bar is created.
lora_request: LoRA request to use for generation, if any.
prompt_adapter_request: Prompt Adapter request to use for
generation, if any.
......@@ -519,10 +512,27 @@ class LLM:
executor = self.llm_engine.model_executor
return executor.apply_model(func)
def _get_beam_search_lora_requests(
self,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]],
prompts: list[Union[TokensPrompt, TextPrompt]],
) -> list[Optional[LoRARequest]]:
"""Get the optional lora request corresponding to each prompt."""
if isinstance(lora_request,
Sequence) and len(lora_request) != len(prompts):
raise ValueError(
"Lora request list should be the same length as the prompts")
if lora_request is None or isinstance(lora_request, LoRARequest):
return [lora_request] * len(prompts)
raise TypeError(f"Invalid lora_request type {type(lora_request)}")
def beam_search(
self,
prompts: list[Union[TokensPrompt, TextPrompt]],
params: BeamSearchParams,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
) -> list[BeamSearchOutput]:
"""
Generate sequences using beam search.
......@@ -531,6 +541,7 @@ class LLM:
prompts: A list of prompts. Each prompt can be a string or a list
of token IDs.
params: The beam search parameters.
lora_request: LoRA request to use for generation, if any.
"""
# TODO: how does beam search work together with length penalty,
# frequency, penalty, and stopping criteria, etc.?
......@@ -540,6 +551,9 @@ class LLM:
ignore_eos = params.ignore_eos
length_penalty = params.length_penalty
lora_requests = self._get_beam_search_lora_requests(
lora_request, prompts)
def sort_beams_key(x: BeamSearchSequence) -> float:
return get_beam_search_score(x.tokens, x.cum_logprob,
tokenizer.eos_token_id,
......@@ -567,7 +581,7 @@ class LLM:
temperature=temperature)
instances: list[BeamSearchInstance] = []
for prompt in prompts:
for lora_req, prompt in zip(lora_requests, prompts):
# Add multimodal processor kwargs & data
mm_kwargs = {}
if "multi_modal_data" in prompt:
......@@ -583,7 +597,12 @@ class LLM:
prompt_tokens = tokenizer.encode(prompt["prompt"])
instances.append(
BeamSearchInstance(prompt_tokens, logprobs=None, **mm_kwargs))
BeamSearchInstance(
prompt_tokens,
lora_request=lora_req,
logprobs=None,
**mm_kwargs,
), )
for _ in range(max_tokens):
all_beams: list[BeamSearchSequence] = list(
......@@ -597,15 +616,17 @@ class LLM:
if len(all_beams) == 0:
break
prompts_batch = [
create_tokens_prompt_from_beam(beam) for beam in all_beams
]
# create the corresponding batch entries for prompt & optional lora
prompts_batch, lora_req_batch = zip(
*[(create_tokens_prompt_from_beam(beam), beam.lora_request)
for beam in all_beams])
# only runs for one step
# we don't need to use tqdm here
output = self.generate(prompts_batch,
sampling_params=beam_search_params,
use_tqdm=False)
use_tqdm=False,
lora_request=lora_req_batch)
for (start, end), instance in zip(instance_start_and_end,
instances):
......@@ -623,6 +644,7 @@ class LLM:
new_beam = BeamSearchSequence(
tokens=current_beam.tokens + [token_id],
logprobs=current_beam.logprobs + [logprobs],
lora_request=current_beam.lora_request,
cum_logprob=current_beam.cum_logprob +
logprob_obj.logprob,
multi_modal_data=current_beam.multi_modal_data,
......@@ -659,7 +681,7 @@ class LLM:
list[list[ChatCompletionMessageParam]]],
sampling_params: Optional[Union[SamplingParams,
list[SamplingParams]]] = None,
use_tqdm: bool = True,
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[LoRARequest] = None,
chat_template: Optional[str] = None,
chat_template_content_format: ChatTemplateContentFormatOption = "auto",
......@@ -690,7 +712,10 @@ class LLM:
is a single value, it is applied to every prompt. When it
is a list, the list must have the same length as the
prompts and it is paired one by one with the prompt.
use_tqdm: Whether to use tqdm to display the progress bar.
use_tqdm: If `True`, shows a tqdm progress bar.
If a callable (e.g., `functools.partial(tqdm, leave=False)`),
it is used to create the progress bar.
If `False`, no progress bar is created.
lora_request: LoRA request to use for generation, if any.
chat_template: The template to use for structuring the chat.
If not provided, the model's default chat template will be used.
......@@ -804,7 +829,7 @@ class LLM:
Sequence[PoolingParams]]] = None,
*,
truncate_prompt_tokens: Optional[int] = None,
use_tqdm: bool = True,
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> list[PoolingRequestOutput]:
......@@ -819,7 +844,7 @@ class LLM:
Sequence[PoolingParams]]] = None,
prompt_token_ids: Optional[list[int]] = None,
truncate_prompt_tokens: Optional[int] = None,
use_tqdm: bool = True,
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> list[PoolingRequestOutput]:
......@@ -834,7 +859,7 @@ class LLM:
Sequence[PoolingParams]]] = None,
prompt_token_ids: Optional[list[list[int]]] = None,
truncate_prompt_tokens: Optional[int] = None,
use_tqdm: bool = True,
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> list[PoolingRequestOutput]:
......@@ -850,7 +875,7 @@ class LLM:
*,
prompt_token_ids: list[int],
truncate_prompt_tokens: Optional[int] = None,
use_tqdm: bool = True,
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> list[PoolingRequestOutput]:
......@@ -866,7 +891,7 @@ class LLM:
*,
prompt_token_ids: list[list[int]],
truncate_prompt_tokens: Optional[int] = None,
use_tqdm: bool = True,
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> list[PoolingRequestOutput]:
......@@ -880,7 +905,7 @@ class LLM:
pooling_params: None,
prompt_token_ids: Union[list[int], list[list[int]]],
truncate_prompt_tokens: Optional[int] = None,
use_tqdm: bool = True,
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> list[PoolingRequestOutput]:
......@@ -899,7 +924,7 @@ class LLM:
Sequence[PoolingParams]]] = None,
prompt_token_ids: Optional[Union[list[int], list[list[int]]]] = None,
truncate_prompt_tokens: Optional[int] = None,
use_tqdm: bool = True,
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> list[PoolingRequestOutput]:
......@@ -916,7 +941,10 @@ class LLM:
for more details about the format of each prompts.
pooling_params: The pooling parameters for pooling. If None, we
use the default pooling parameters.
use_tqdm: Whether to use tqdm to display the progress bar.
use_tqdm: If `True`, shows a tqdm progress bar.
If a callable (e.g., `functools.partial(tqdm, leave=False)`),
it is used to create the progress bar.
If `False`, no progress bar is created.
lora_request: LoRA request to use for generation, if any.
prompt_adapter_request: Prompt Adapter request to use for
generation, if any.
......@@ -986,7 +1014,7 @@ class LLM:
/,
*,
truncate_prompt_tokens: Optional[int] = None,
use_tqdm: bool = True,
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
pooling_params: Optional[Union[PoolingParams,
Sequence[PoolingParams]]] = None,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
......@@ -1005,7 +1033,10 @@ class LLM:
for more details about the format of each prompts.
pooling_params: The pooling parameters for pooling. If None, we
use the default pooling parameters.
use_tqdm: Whether to use tqdm to display the progress bar.
use_tqdm: If `True`, shows a tqdm progress bar.
If a callable (e.g., `functools.partial(tqdm, leave=False)`),
it is used to create the progress bar.
If `False`, no progress bar is created.
lora_request: LoRA request to use for generation, if any.
prompt_adapter_request: Prompt Adapter request to use for
generation, if any.
......@@ -1032,7 +1063,7 @@ class LLM:
prompts: Union[PromptType, Sequence[PromptType]],
/,
*,
use_tqdm: bool = True,
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> list[ClassificationRequestOutput]:
......@@ -1047,7 +1078,10 @@ class LLM:
prompts: The prompts to the LLM. You may pass a sequence of prompts
for batch inference. See [PromptType][vllm.inputs.PromptType]
for more details about the format of each prompts.
use_tqdm: Whether to use tqdm to display the progress bar.
use_tqdm: If `True`, shows a tqdm progress bar.
If a callable (e.g., `functools.partial(tqdm, leave=False)`),
it is used to create the progress bar.
If `False`, no progress bar is created.
lora_request: LoRA request to use for generation, if any.
prompt_adapter_request: Prompt Adapter request to use for
generation, if any.
......@@ -1073,7 +1107,7 @@ class LLM:
text_1: list[Union[str, TextPrompt, TokensPrompt]],
text_2: list[Union[str, TextPrompt, TokensPrompt]],
truncate_prompt_tokens: Optional[int] = None,
use_tqdm: bool = True,
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> list[ScoringRequestOutput]:
......@@ -1107,7 +1141,7 @@ class LLM:
text_1: list[str],
text_2: list[str],
truncate_prompt_tokens: Optional[int] = None,
use_tqdm: bool = True,
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> list[ScoringRequestOutput]:
......@@ -1159,7 +1193,7 @@ class LLM:
/,
*,
truncate_prompt_tokens: Optional[int] = None,
use_tqdm: bool = True,
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> list[ScoringRequestOutput]:
......@@ -1179,7 +1213,10 @@ class LLM:
text_2: The texts to pair with the query to form the input
to the LLM. See [PromptType][vllm.inputs.PromptType] for
more details about the format of each prompts.
use_tqdm: Whether to use tqdm to display the progress bar.
use_tqdm: If `True`, shows a tqdm progress bar.
If a callable (e.g., `functools.partial(tqdm, leave=False)`),
it is used to create the progress bar.
If `False`, no progress bar is created.
lora_request: LoRA request to use for generation, if any.
prompt_adapter_request: Prompt Adapter request to use for
generation, if any.
......@@ -1360,7 +1397,7 @@ class LLM:
params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams,
Sequence[PoolingParams]],
*,
use_tqdm: bool,
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],
prompt_adapter_request: Optional[PromptAdapterRequest],
tokenization_kwargs: Optional[dict[str, Any]] = None,
......@@ -1398,7 +1435,8 @@ class LLM:
# Add requests to the engine.
it = prompts
if use_tqdm:
it = tqdm(it, desc="Adding requests")
tqdm_func = use_tqdm if callable(use_tqdm) else tqdm
it = tqdm_func(it, desc="Adding requests")
for i, prompt in enumerate(it):
self._add_request(
......@@ -1455,12 +1493,15 @@ class LLM:
return params
def _run_engine(
self, *, use_tqdm: bool
self,
*,
use_tqdm: Union[bool, Callable[..., tqdm]] = True
) -> list[Union[RequestOutput, PoolingRequestOutput]]:
# Initialize tqdm.
if use_tqdm:
num_requests = self.llm_engine.get_num_unfinished_requests()
pbar = tqdm(
tqdm_func = use_tqdm if callable(use_tqdm) else tqdm
pbar = tqdm_func(
total=num_requests,
desc="Processed prompts",
dynamic_ncols=True,
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional, Union
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import atexit
import gc
import importlib
import inspect
import json
import multiprocessing
import os
import signal
......@@ -16,8 +18,7 @@ from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from functools import partial
from http import HTTPStatus
from json import JSONDecodeError
from typing import Annotated, Optional, Union
from typing import Annotated, Any, Optional
import prometheus_client
import regex as re
......@@ -26,6 +27,8 @@ from fastapi import APIRouter, Depends, FastAPI, Form, HTTPException, Request
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, Response, StreamingResponse
from prometheus_client import make_asgi_app
from prometheus_fastapi_instrumentator import Instrumentator
from starlette.concurrency import iterate_in_threadpool
from starlette.datastructures import State
from starlette.routing import Mount
......@@ -59,9 +62,7 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
EmbeddingChatRequest,
EmbeddingCompletionRequest,
EmbeddingRequest,
EmbeddingResponse,
EmbeddingResponseData,
ErrorResponse,
EmbeddingResponse, ErrorResponse,
LoadLoRAAdapterRequest,
PoolingChatRequest,
PoolingCompletionRequest,
......@@ -99,10 +100,9 @@ from vllm.transformers_utils.tokenizer import MistralTokenizer
from vllm.usage.usage_lib import UsageContext
from vllm.utils import (Device, FlexibleArgumentParser, get_open_zmq_ipc_path,
is_valid_ipv6_address, set_ulimit)
from vllm.v1.metrics.prometheus import get_prometheus_registry
from vllm.version import __version__ as VLLM_VERSION
TIMEOUT_KEEP_ALIVE = 5 # seconds
prometheus_multiproc_dir: tempfile.TemporaryDirectory
# Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765)
......@@ -144,14 +144,17 @@ async def lifespan(app: FastAPI):
@asynccontextmanager
async def build_async_engine_client(
args: Namespace) -> AsyncIterator[EngineClient]:
args: Namespace,
client_config: Optional[dict[str, Any]] = None,
) -> AsyncIterator[EngineClient]:
# Context manager to handle engine_client lifecycle
# Ensures everything is shutdown and cleaned up on error/exit
engine_args = AsyncEngineArgs.from_cli_args(args)
async with build_async_engine_client_from_engine_args(
engine_args, args.disable_frontend_multiprocessing) as engine:
engine_args, args.disable_frontend_multiprocessing,
client_config) as engine:
yield engine
......@@ -159,6 +162,7 @@ async def build_async_engine_client(
async def build_async_engine_client_from_engine_args(
engine_args: AsyncEngineArgs,
disable_frontend_multiprocessing: bool = False,
client_config: Optional[dict[str, Any]] = None,
) -> AsyncIterator[EngineClient]:
"""
Create EngineClient, either:
......@@ -181,12 +185,16 @@ async def build_async_engine_client_from_engine_args(
from vllm.v1.engine.async_llm import AsyncLLM
async_llm: Optional[AsyncLLM] = None
client_index = client_config.pop(
"client_index") if client_config else 0
try:
async_llm = AsyncLLM.from_vllm_config(
vllm_config=vllm_config,
usage_context=usage_context,
disable_log_requests=engine_args.disable_log_requests,
disable_log_stats=engine_args.disable_log_stats)
disable_log_stats=engine_args.disable_log_stats,
client_addresses=client_config,
client_index=client_index)
# Don't keep the dummy data in memory
await async_llm.reset_mm_cache()
......@@ -320,22 +328,9 @@ class PrometheusResponse(Response):
def mount_metrics(app: FastAPI):
# Lazy import for prometheus multiprocessing.
# We need to set PROMETHEUS_MULTIPROC_DIR environment variable
# before prometheus_client is imported.
# See https://prometheus.github.io/client_python/multiprocess/
from prometheus_client import (REGISTRY, CollectorRegistry, make_asgi_app,
multiprocess)
from prometheus_fastapi_instrumentator import Instrumentator
registry = REGISTRY
"""Mount prometheus metrics to a FastAPI app."""
prometheus_multiproc_dir_path = os.getenv("PROMETHEUS_MULTIPROC_DIR", None)
if prometheus_multiproc_dir_path is not None:
logger.debug("vLLM to use %s as PROMETHEUS_MULTIPROC_DIR",
prometheus_multiproc_dir_path)
registry = CollectorRegistry()
multiprocess.MultiProcessCollector(registry)
registry = get_prometheus_registry()
# `response_class=PrometheusResponse` is needed to return an HTTP response
# with header "Content-Type: text/plain; version=0.0.4; charset=utf-8"
......@@ -627,36 +622,9 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
async def create_embedding(request: EmbeddingRequest, raw_request: Request):
handler = embedding(raw_request)
if handler is None:
fallback_handler = pooling(raw_request)
if fallback_handler is None:
return base(raw_request).create_error_response(
message="The model does not support Embeddings API")
logger.warning(
"Embeddings API will become exclusive to embedding models "
"in a future release. To return the hidden states directly, "
"use the Pooling API (`/pooling`) instead.")
res = await fallback_handler.create_pooling(request, raw_request)
generator: Union[ErrorResponse, EmbeddingResponse]
if isinstance(res, PoolingResponse):
generator = EmbeddingResponse(
id=res.id,
object=res.object,
created=res.created,
model=res.model,
data=[
EmbeddingResponseData(
index=d.index,
embedding=d.data, # type: ignore
) for d in res.data
],
usage=res.usage,
)
else:
generator = res
else:
generator = await handler.create_embedding(request, raw_request)
if isinstance(generator, ErrorResponse):
......@@ -961,7 +929,7 @@ async def invocations(raw_request: Request):
"""
try:
body = await raw_request.json()
except JSONDecodeError as e:
except json.JSONDecodeError as e:
raise HTTPException(status_code=HTTPStatus.BAD_REQUEST.value,
detail=f"JSON decode error: {e}") from e
......@@ -1034,6 +1002,18 @@ if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
return Response(status_code=200, content=response)
def load_log_config(log_config_file: Optional[str]) -> Optional[dict]:
if not log_config_file:
return None
try:
with open(log_config_file) as f:
return json.load(f)
except Exception as e:
logger.warning("Failed to load log config from file %s: error %s",
log_config_file, e)
return None
def build_app(args: Namespace) -> FastAPI:
if args.disable_fastapi_docs:
app = FastAPI(openapi_url=None,
......@@ -1285,13 +1265,7 @@ def create_server_socket(addr: tuple[str, int]) -> socket.socket:
return sock
async def run_server(args, **uvicorn_kwargs) -> None:
logger.info("vLLM API server version %s", VLLM_VERSION)
log_non_default_args(args)
if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3:
ToolParserManager.import_tool_parser(args.tool_parser_plugin)
def validate_api_server_args(args):
valid_tool_parses = ToolParserManager.tool_parsers.keys()
if args.enable_auto_tool_choice \
and args.tool_call_parser not in valid_tool_parses:
......@@ -1305,6 +1279,19 @@ async def run_server(args, **uvicorn_kwargs) -> None:
f"invalid reasoning parser: {args.reasoning_parser} "
f"(chose from {{ {','.join(valid_reasoning_parses)} }})")
def setup_server(args):
"""Validate API server args, set up signal handler, create socket
ready to serve."""
logger.info("vLLM API server version %s", VLLM_VERSION)
log_non_default_args(args)
if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3:
ToolParserManager.import_tool_parser(args.tool_parser_plugin)
validate_api_server_args(args)
# workaround to make sure that we bind the port before the engine is set up.
# This avoids race conditions with ray.
# see https://github.com/vllm-project/vllm/issues/8204
......@@ -1321,22 +1308,46 @@ async def run_server(args, **uvicorn_kwargs) -> None:
signal.signal(signal.SIGTERM, signal_handler)
async with build_async_engine_client(args) as engine_client:
addr, port = sock_addr
is_ssl = args.ssl_keyfile and args.ssl_certfile
host_part = f"[{addr}]" if is_valid_ipv6_address(
addr) else addr or "0.0.0.0"
listen_address = f"http{'s' if is_ssl else ''}://{host_part}:{port}"
return listen_address, sock
async def run_server(args, **uvicorn_kwargs) -> None:
"""Run a single-worker API server."""
listen_address, sock = setup_server(args)
await run_server_worker(listen_address, sock, args, **uvicorn_kwargs)
async def run_server_worker(listen_address,
sock,
args,
client_config=None,
**uvicorn_kwargs) -> None:
"""Run a single API server worker."""
if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3:
ToolParserManager.import_tool_parser(args.tool_parser_plugin)
server_index = client_config.get("client_index", 0) if client_config else 0
# Load logging config for uvicorn if specified
log_config = load_log_config(args.log_config_file)
if log_config is not None:
uvicorn_kwargs['log_config'] = log_config
async with build_async_engine_client(args, client_config) as engine_client:
app = build_app(args)
vllm_config = await engine_client.get_vllm_config()
await init_app_state(engine_client, vllm_config, app.state, args)
def _listen_addr(a: str) -> str:
if is_valid_ipv6_address(a):
return '[' + a + ']'
return a or "0.0.0.0"
is_ssl = args.ssl_keyfile and args.ssl_certfile
logger.info("Starting vLLM API server on http%s://%s:%d",
"s" if is_ssl else "", _listen_addr(sock_addr[0]),
sock_addr[1])
logger.info("Starting vLLM API server %d on %s", server_index,
listen_address)
shutdown_task = await serve_http(
app,
sock=sock,
......@@ -1347,7 +1358,7 @@ async def run_server(args, **uvicorn_kwargs) -> None:
# NOTE: When the 'disable_uvicorn_access_log' value is True,
# no access log will be output.
access_log=not args.disable_uvicorn_access_log,
timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
timeout_keep_alive=envs.VLLM_HTTP_TIMEOUT_KEEP_ALIVE,
ssl_keyfile=args.ssl_keyfile,
ssl_certfile=args.ssl_certfile,
ssl_ca_certs=args.ssl_ca_certs,
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
This file contains the command line arguments for the vLLM's
OpenAI-compatible server. It is kept in a separate file for documentation
......@@ -11,6 +12,7 @@ import ssl
from collections.abc import Sequence
from typing import Optional, Union, get_args
import vllm.envs as envs
from vllm.engine.arg_utils import AsyncEngineArgs, optional_type
from vllm.entrypoints.chat_utils import (ChatTemplateContentFormatOption,
validate_chat_template)
......@@ -243,6 +245,13 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
" into OpenAI API format, the name register in this plugin can be used "
"in ``--tool-call-parser``.")
parser.add_argument(
"--log-config-file",
type=str,
default=envs.VLLM_LOGGING_CONFIG_PATH,
help="Path to logging config JSON file for both vllm and uvicorn",
)
parser = AsyncEngineArgs.add_cli_args(parser)
parser.add_argument('--max-log-len',
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable
from functools import lru_cache, partial
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Adapted from
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
......@@ -175,11 +176,15 @@ class ChatCompletionNamedToolChoiceParam(OpenAIBaseModel):
type: Literal["function"] = "function"
# extra="forbid" is a workaround to have kwargs as a field,
# see https://github.com/pydantic/pydantic/issues/3125
class LogitsProcessorConstructor(BaseModel):
qualname: str
args: Optional[list[Any]] = None
kwargs: Optional[dict[str, Any]] = None
model_config = ConfigDict(extra="forbid")
LogitsProcessors = list[Union[str, LogitsProcessorConstructor]]
......@@ -234,7 +239,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
presence_penalty: Optional[float] = 0.0
response_format: Optional[AnyResponseFormat] = None
seed: Optional[int] = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max)
stop: Optional[Union[str, list[str]]] = Field(default_factory=list)
stop: Optional[Union[str, list[str]]] = []
stream: Optional[bool] = False
stream_options: Optional[StreamOptions] = None
temperature: Optional[float] = None
......@@ -258,7 +263,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
min_p: Optional[float] = None
repetition_penalty: Optional[float] = None
length_penalty: float = 1.0
stop_token_ids: Optional[list[int]] = Field(default_factory=list)
stop_token_ids: Optional[list[int]] = []
include_stop_str_in_output: bool = False
ignore_eos: bool = False
min_tokens: int = 0
......@@ -266,6 +271,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
spaces_between_special_tokens: bool = True
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
prompt_logprobs: Optional[int] = None
allowed_token_ids: Optional[list[int]] = None
# --8<-- [end:chat-completion-sampling-params]
# --8<-- [start:chat-completion-extra-params]
......@@ -544,6 +550,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
else RequestOutputKind.FINAL_ONLY,
guided_decoding=guided_decoding,
logit_bias=self.logit_bias,
allowed_token_ids=self.allowed_token_ids,
extra_args=({"kv_transfer_params": self.kv_transfer_params}
if self.kv_transfer_params else None))
......@@ -756,7 +763,7 @@ class CompletionRequest(OpenAIBaseModel):
n: int = 1
presence_penalty: Optional[float] = 0.0
seed: Optional[int] = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max)
stop: Optional[Union[str, list[str]]] = Field(default_factory=list)
stop: Optional[Union[str, list[str]]] = []
stream: Optional[bool] = False
stream_options: Optional[StreamOptions] = None
suffix: Optional[str] = None
......@@ -770,7 +777,7 @@ class CompletionRequest(OpenAIBaseModel):
min_p: Optional[float] = None
repetition_penalty: Optional[float] = None
length_penalty: float = 1.0
stop_token_ids: Optional[list[int]] = Field(default_factory=list)
stop_token_ids: Optional[list[int]] = []
include_stop_str_in_output: bool = False
ignore_eos: bool = False
min_tokens: int = 0
......@@ -1477,6 +1484,10 @@ class TranscriptionStreamResponse(OpenAIBaseModel):
usage: Optional[UsageInfo] = Field(default=None)
BatchRequestInputBody = Union[ChatCompletionRequest, EmbeddingRequest,
ScoreRequest, RerankRequest]
class BatchRequestInput(OpenAIBaseModel):
"""
The per-line object of the batch input file.
......@@ -1497,21 +1508,22 @@ class BatchRequestInput(OpenAIBaseModel):
url: str
# The parameters of the request.
body: Union[ChatCompletionRequest, EmbeddingRequest, ScoreRequest]
body: BatchRequestInputBody
@field_validator('body', mode='plain')
@classmethod
def check_type_for_url(cls, value: Any, info: ValidationInfo):
# Use url to disambiguate models
url = info.data['url']
url: str = info.data["url"]
if url == "/v1/chat/completions":
return ChatCompletionRequest.model_validate(value)
if url == "/v1/embeddings":
return TypeAdapter(EmbeddingRequest).validate_python(value)
if url == "/v1/score":
if url.endswith("/score"):
return ScoreRequest.model_validate(value)
return TypeAdapter(Union[ChatCompletionRequest, EmbeddingRequest,
ScoreRequest]).validate_python(value)
if url.endswith("/rerank"):
return RerankRequest.model_validate(value)
return TypeAdapter(BatchRequestInputBody).validate_python(value)
class BatchResponseData(OpenAIBaseModel):
......@@ -1523,7 +1535,7 @@ class BatchResponseData(OpenAIBaseModel):
# The body of the response.
body: Optional[Union[ChatCompletionResponse, EmbeddingResponse,
ScoreResponse]] = None
ScoreResponse, RerankResponse]] = None
class BatchRequestOutput(OpenAIBaseModel):
......@@ -1554,6 +1566,11 @@ class TokenizeCompletionRequest(OpenAIBaseModel):
"If true (the default), special tokens (e.g. BOS) will be added to "
"the prompt."),
)
return_token_strs: Optional[bool] = Field(
default=False,
description=("If true, also return the token strings "
"corresponding to the token ids."),
)
class TokenizeChatRequest(OpenAIBaseModel):
......@@ -1567,6 +1584,11 @@ class TokenizeChatRequest(OpenAIBaseModel):
"This is a parameter used by chat template in tokenizer config of the "
"model."),
)
return_token_strs: Optional[bool] = Field(
default=False,
description=("If true, also return the token strings "
"corresponding to the token ids."),
)
continue_final_message: bool = Field(
default=False,
description=
......@@ -1624,6 +1646,7 @@ class TokenizeResponse(OpenAIBaseModel):
count: int
max_model_len: int
tokens: list[int]
token_strs: Optional[list[str]] = None
class DetokenizeRequest(OpenAIBaseModel):
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import tempfile
......@@ -21,7 +22,7 @@ from vllm.entrypoints.openai.protocol import (BatchRequestInput,
BatchResponseData,
ChatCompletionResponse,
EmbeddingResponse, ErrorResponse,
ScoreResponse)
RerankResponse, ScoreResponse)
# yapf: enable
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
......@@ -33,9 +34,7 @@ from vllm.utils import FlexibleArgumentParser, random_uuid
from vllm.version import __version__ as VLLM_VERSION
def parse_args():
parser = FlexibleArgumentParser(
description="vLLM OpenAI-Compatible batch runner.")
def make_arg_parser(parser: FlexibleArgumentParser):
parser.add_argument(
"-i",
"--input-file",
......@@ -98,7 +97,13 @@ def parse_args():
default=False,
help="If set to True, enable prompt_tokens_details in usage.")
return parser.parse_args()
return parser
def parse_args():
parser = FlexibleArgumentParser(
description="vLLM OpenAI-Compatible batch runner.")
return make_arg_parser(parser).parse_args()
# explicitly use pure text format, with a newline at the end
......@@ -270,8 +275,11 @@ async def run_request(serving_engine_func: Callable,
tracker: BatchProgressTracker) -> BatchRequestOutput:
response = await serving_engine_func(request.body)
if isinstance(response,
(ChatCompletionResponse, EmbeddingResponse, ScoreResponse)):
if isinstance(
response,
(ChatCompletionResponse, EmbeddingResponse, ScoreResponse,
RerankResponse),
):
batch_output = BatchRequestOutput(
id=f"vllm-{random_uuid()}",
custom_id=request.custom_id,
......@@ -393,7 +401,7 @@ async def main(args):
response_futures.append(
run_request(embed_handler_fn, request, tracker))
tracker.submitted()
elif request.url == "/v1/score":
elif request.url.endswith("/score"):
score_handler_fn = openai_serving_scores.create_score if \
openai_serving_scores is not None else None
if score_handler_fn is None:
......@@ -407,13 +415,29 @@ async def main(args):
response_futures.append(
run_request(score_handler_fn, request, tracker))
tracker.submitted()
elif request.url.endswith("/rerank"):
rerank_handler_fn = openai_serving_scores.do_rerank if \
openai_serving_scores is not None else None
if rerank_handler_fn is None:
response_futures.append(
make_async_error_request_output(
request,
error_msg="The model does not support Rerank API",
))
continue
response_futures.append(
run_request(rerank_handler_fn, request, tracker))
tracker.submitted()
else:
response_futures.append(
make_async_error_request_output(
request,
error_msg=
"Only /v1/chat/completions, /v1/embeddings, and /v1/score "
"are supported in the batch endpoint.",
error_msg=f"URL {request.url} was used. "
"Supported endpoints: /v1/chat/completions, /v1/embeddings,"
" /score, /rerank ."
"See vllm/entrypoints/openai/api_server.py for supported "
"score/rerank versions.",
))
with tracker.pbar():
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import json
......@@ -236,6 +237,7 @@ class OpenAIServingChat(OpenAIServing):
prompt=engine_prompt,
request_id=request_id,
params=sampling_params,
lora_request=lora_request,
)
else:
generator = self.engine_client.generate(
......@@ -318,10 +320,13 @@ class OpenAIServingChat(OpenAIServing):
def extract_tool_call_required_streaming(
self,
previous_text: str,
current_text: str,
current_text: Optional[str],
delta_text: str,
function_name_returned: bool,
) -> tuple[Optional[DeltaMessage], bool]:
if current_text is None or current_text == "":
# if the current text is empty, we cannot parse it
return None, function_name_returned
try:
obj = partial_json_parser.loads(current_text)
except partial_json_parser.core.exceptions.MalformedJSON:
......@@ -648,10 +653,18 @@ class OpenAIServingChat(OpenAIServing):
current_text = previous_text + delta_text
fn_name_returned = function_name_returned[i]
if self.reasoning_parser:
_, content = \
reasoning_parser.extract_reasoning_content(
current_text,
request
)
else:
content = current_text
delta_message, function_name_returned[i] = (
self.extract_tool_call_required_streaming(
previous_text=previous_text,
current_text=current_text,
current_text=content,
delta_text=delta_text,
function_name_returned=fn_name_returned))
......@@ -676,7 +689,21 @@ class OpenAIServingChat(OpenAIServing):
current_token_ids,
output.token_ids,
))
# When encountering think end id in prompt_token_ids
# i.e {"enable_thinking": False},
# set reasoning status to end.
# Remove the text and token ids related
# to 'reasoning_content'.
if res.prompt_token_ids and \
reasoning_parser.is_reasoning_end(
list(res.prompt_token_ids)):
reasoning_end_arr[i] = True
current_token_ids = list(output.token_ids)
if delta_message and delta_message.content:
current_text = delta_message.content
delta_message.content = None
else:
current_text = ""
# When encountering think end id in delta_token_ids,
# set reasoning status to end.
# Remove the text and token ids related
......@@ -979,15 +1006,17 @@ class OpenAIServingChat(OpenAIServing):
# the fields of FunctionDefinition are a superset of the
# tool call outputs and can be used for parsing
assert content is not None
tool_calls = TypeAdapter(
list[FunctionDefinition]).validate_json(output.text)
list[FunctionDefinition]).validate_json(content)
message = ChatMessage(
role=role,
content="",
tool_calls=[
tool_call_class(function=FunctionCall(
name=tool_call.name,
arguments=json.dumps(tool_call.parameters)))
arguments=json.dumps(tool_call.parameters,
ensure_ascii=False)))
for tool_call in tool_calls
])
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from http import HTTPStatus
from typing import Optional, Union, cast
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import time
......@@ -186,6 +187,7 @@ class OpenAIServingCompletion(OpenAIServing):
prompt=engine_prompt,
request_id=request_id,
params=sampling_params,
lora_request=lora_request,
)
else:
generator = self.engine_client.generate(
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import base64
from typing import Final, Literal, Optional, Union, cast
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import base64
import io
import json
......@@ -134,11 +135,9 @@ class RequestProcessingMixin(BaseModel):
Mixin for request processing,
handling prompt preparation and engine input.
"""
request_prompts: Optional[Sequence[RequestPrompt]] = \
Field(default_factory=list)
request_prompts: Optional[Sequence[RequestPrompt]] = []
engine_prompts: Optional[Union[list[EngineTokensPrompt],
list[EngineEmbedsPrompt]]] = Field(
default_factory=list)
list[EngineEmbedsPrompt]]] = []
model_config = ConfigDict(arbitrary_types_allowed=True)
......@@ -528,12 +527,14 @@ class OpenAIServing:
if isinstance(request,
(EmbeddingChatRequest, EmbeddingCompletionRequest,
ScoreRequest, RerankRequest, ClassificationRequest)):
operation = {
ScoreRequest: "score",
ClassificationRequest: "classification"
}.get(type(request), "embedding generation")
if token_num > self.max_model_len:
operations: dict[type[AnyRequest], str] = {
ScoreRequest: "score",
ClassificationRequest: "classification"
}
operation = operations.get(type(request),
"embedding generation")
raise ValueError(
f"This model's maximum context length is "
f"{self.max_model_len} tokens. However, you requested "
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
import pathlib
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import base64
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import time
from collections.abc import AsyncGenerator, Mapping
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Final, Optional, Union
......@@ -110,7 +111,12 @@ class OpenAIServingTokenization(OpenAIServing):
dict) and "prompt_token_ids" in engine_prompt:
input_ids.extend(engine_prompt["prompt_token_ids"])
token_strs = None
if request.return_token_strs:
token_strs = tokenizer.convert_ids_to_tokens(input_ids)
return TokenizeResponse(tokens=input_ids,
token_strs=token_strs,
count=len(input_ids),
max_model_len=self.max_model_len)
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import io
import time
......@@ -278,7 +279,9 @@ class OpenAIServingTranscription(OpenAIServing):
result_generator: Optional[AsyncGenerator[RequestOutput, None]] = None
try:
# TODO(rob): subtract len of tokenized prompt.
# Unlike most decoder-only models, whisper generation length is not
# constrained by the size of the input audio, which is mapped to a
# fixed-size log-mel-spectogram.
default_max_tokens = self.model_config.max_model_len
sampling_params = request.to_sampling_params(
default_max_tokens, self.default_sampling_params)
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from .abstract_tool_parser import ToolParser, ToolParserManager
from .deepseekv3_tool_parser import DeepSeekV3ToolParser
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment