"vllm/attention/ops/__init__.py" did not exist on "2daf23ab0cf00da157b1255faddcf0a269283d36"
Commit ec5e299c authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.7.3' into v0.7.3-dev

parents 47bd229c ed6e9075
......@@ -1187,6 +1187,12 @@ class AsyncLLMEngine(EngineClient):
async def reset_prefix_cache(self) -> None:
self.engine.reset_prefix_cache()
async def sleep(self, level: int = 1) -> None:
    """Forward a sleep request at the given ``level`` to the wrapped engine."""
    self.engine.sleep(level)
async def wake_up(self) -> None:
    """Forward a wake-up request to the wrapped engine."""
    self.engine.wake_up()
async def add_lora(self, lora_request: LoRARequest) -> None:
self.engine.add_lora(lora_request)
......
......@@ -20,8 +20,7 @@ import vllm.envs as envs
from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
ObservabilityConfig, ParallelConfig, SchedulerConfig,
VllmConfig)
from vllm.core.scheduler import (ScheduledSequenceGroup, Scheduler,
SchedulerOutputs)
from vllm.core.scheduler import ScheduledSequenceGroup, SchedulerOutputs
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.metrics_types import StatLoggerBase, Stats
from vllm.engine.output_processor.interfaces import (
......@@ -59,7 +58,8 @@ from vllm.transformers_utils.tokenizer_group import (
BaseTokenizerGroup, init_tokenizer_from_configs)
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
usage_message)
from vllm.utils import Counter, Device, deprecate_kwargs, weak_bind
from vllm.utils import (Counter, Device, deprecate_kwargs,
resolve_obj_by_qualname, weak_bind)
from vllm.version import __version__ as VLLM_VERSION
logger = init_logger(__name__)
......@@ -347,6 +347,11 @@ class LLMEngine:
# Create the scheduler.
# NOTE: the cache_config here have been updated with the numbers of
# GPU and CPU blocks, which are profiled in the distributed executor.
if isinstance(self.vllm_config.scheduler_config.scheduler_cls, str):
Scheduler = resolve_obj_by_qualname(
self.vllm_config.scheduler_config.scheduler_cls)
else:
Scheduler = self.vllm_config.scheduler_config.scheduler_cls
self.scheduler = [
Scheduler(
self.scheduler_config, self.cache_config, self.lora_config,
......@@ -437,6 +442,7 @@ class LLMEngine:
@classmethod
def _get_executor_cls(cls,
engine_config: VllmConfig) -> Type[ExecutorBase]:
# distributed_executor_backend must be set in VllmConfig.__post_init__
distributed_executor_backend = (
engine_config.parallel_config.distributed_executor_backend)
# Initialize the cluster and specify the executor class.
......@@ -446,30 +452,29 @@ class LLMEngine:
"distributed_executor_backend must be a subclass of "
f"ExecutorBase. Got {distributed_executor_backend}.")
executor_class = distributed_executor_backend
elif engine_config.parallel_config.world_size > 1:
if distributed_executor_backend == "ray":
from vllm.executor.ray_distributed_executor import (
RayDistributedExecutor)
executor_class = RayDistributedExecutor
elif distributed_executor_backend == "mp":
from vllm.executor.mp_distributed_executor import (
MultiprocessingDistributedExecutor)
assert not envs.VLLM_USE_RAY_SPMD_WORKER, (
"multiprocessing distributed executor backend does not "
"support VLLM_USE_RAY_SPMD_WORKER=1")
executor_class = MultiprocessingDistributedExecutor
elif distributed_executor_backend == "uni":
# JAX-style, single-process, multi-device executor.
from vllm.executor.uniproc_executor import UniProcExecutor
executor_class = UniProcExecutor
elif distributed_executor_backend == "external_launcher":
# executor with external launcher
from vllm.executor.uniproc_executor import ( # noqa
ExecutorWithExternalLauncher)
executor_class = ExecutorWithExternalLauncher
else:
elif distributed_executor_backend == "ray":
from vllm.executor.ray_distributed_executor import (
RayDistributedExecutor)
executor_class = RayDistributedExecutor
elif distributed_executor_backend == "mp":
from vllm.executor.mp_distributed_executor import (
MultiprocessingDistributedExecutor)
assert not envs.VLLM_USE_RAY_SPMD_WORKER, (
"multiprocessing distributed executor backend does not "
"support VLLM_USE_RAY_SPMD_WORKER=1")
executor_class = MultiprocessingDistributedExecutor
elif distributed_executor_backend == "uni":
# JAX-style, single-process, multi-device executor.
from vllm.executor.uniproc_executor import UniProcExecutor
executor_class = UniProcExecutor
elif distributed_executor_backend == "external_launcher":
# executor with external launcher
from vllm.executor.uniproc_executor import ( # noqa
ExecutorWithExternalLauncher)
executor_class = ExecutorWithExternalLauncher
else:
raise ValueError("unrecognized distributed_executor_backend: "
f"{distributed_executor_backend}")
return executor_class
@classmethod
......
......@@ -237,7 +237,7 @@ class Metrics:
documentation="Count of successfully processed requests.",
labelnames=labelnames + [Metrics.labelname_finish_reason])
# Speculatie decoding stats
# Speculative decoding stats
self.gauge_spec_decode_draft_acceptance_rate = self._gauge_cls(
name="vllm:spec_decode_draft_acceptance_rate",
documentation="Speulative token acceptance rate.",
......
......@@ -127,6 +127,15 @@ class RPCResetPrefixCacheRequest(Enum):
RESET_PREFIX_CACHE = 1
class RPCSleepRequest(Enum):
    """RPC message asking the engine to sleep; the value is the sleep level."""
    SLEEP_LEVEL_1 = 1
    SLEEP_LEVEL_2 = 2
class RPCWakeUpRequest(Enum):
    """RPC message asking the engine to wake up."""
    WAKE_UP = 1
@dataclass
class RPCLoadAdapterRequest:
lora_request: LoRARequest
......@@ -141,7 +150,8 @@ class RPCAdapterLoadedResponse:
RPC_REQUEST_T = Union[RPCProcessRequest, RPCAbortRequest, RPCStartupRequest,
RPCUProfileRequest, RPCLoadAdapterRequest,
RPCResetPrefixCacheRequest]
RPCResetPrefixCacheRequest, RPCSleepRequest,
RPCWakeUpRequest]
REQUEST_OUTPUTS_T = Union[List[RequestOutput], RPCAdapterLoadedResponse,
RPCError]
......
......@@ -31,8 +31,9 @@ from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
RPCLoadAdapterRequest,
RPCProcessRequest,
RPCResetPrefixCacheRequest,
RPCStartupRequest, RPCStartupResponse,
RPCUProfileRequest)
RPCSleepRequest, RPCStartupRequest,
RPCStartupResponse,
RPCUProfileRequest, RPCWakeUpRequest)
from vllm.engine.protocol import EngineClient
# yapf: enable
from vllm.envs import VLLM_RPC_TIMEOUT
......@@ -685,6 +686,16 @@ class MQLLMEngineClient(EngineClient):
request=RPCResetPrefixCacheRequest.RESET_PREFIX_CACHE,
socket=self.input_socket)
async def sleep(self, level: int = 1) -> None:
    """Send a one-way RPC asking the engine to sleep at ``level``.

    ``level`` is looked up by value in ``RPCSleepRequest``, so a level with
    no matching enum member raises ``ValueError`` before anything is sent.
    """
    return await self._send_one_way_rpc_request(
        request=RPCSleepRequest(level), socket=self.input_socket)
async def wake_up(self) -> None:
    """Send a one-way RPC asking the engine to wake up."""
    return await self._send_one_way_rpc_request(
        request=RPCWakeUpRequest.WAKE_UP, socket=self.input_socket)
async def add_lora(self, lora_request: LoRARequest) -> None:
"""Load a new LoRA adapter into the engine for future requests."""
# Uses the same I/O as generate requests
......
......@@ -20,8 +20,9 @@ from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
RPCLoadAdapterRequest,
RPCProcessRequest,
RPCResetPrefixCacheRequest,
RPCStartupRequest, RPCStartupResponse,
RPCUProfileRequest)
RPCSleepRequest, RPCStartupRequest,
RPCStartupResponse,
RPCUProfileRequest, RPCWakeUpRequest)
# yapf: enable
from vllm.logger import init_logger
from vllm.outputs import RequestOutput
......@@ -242,6 +243,10 @@ class MQLLMEngine:
self._handle_load_adapter_request(request)
elif isinstance(request, RPCResetPrefixCacheRequest):
self.reset_prefix_cache()
elif isinstance(request, RPCSleepRequest):
self.sleep(request.value)
elif isinstance(request, RPCWakeUpRequest):
self.wake_up()
else:
raise ValueError("Unknown RPCRequest Type: "
f"{type(request)}")
......@@ -369,6 +374,12 @@ class MQLLMEngine:
def reset_prefix_cache(self) -> bool:
return self.engine.reset_prefix_cache()
def sleep(self, level: int = 1) -> None:
    """Delegate the sleep request at ``level`` to the underlying engine."""
    self.engine.sleep(level)
def wake_up(self) -> None:
    """Delegate the wake-up request to the underlying engine."""
    self.engine.wake_up()
def signal_handler(*_) -> None:
raise KeyboardInterrupt("MQLLMEngine terminated")
......
......@@ -113,7 +113,7 @@ class StopChecker:
stop_string_len = len(stop_str)
# Avoid searching already-searched text.
stop_index = output_text.find(stop_str,
-new_char_count - stop_string_len)
1 - new_char_count - stop_string_len)
if stop_index == -1:
continue
......
......@@ -278,6 +278,16 @@ class EngineClient(ABC):
"""Reset the prefix cache"""
...
@abstractmethod
async def sleep(self, level: int = 1) -> None:
"""Sleep the engine"""
...
@abstractmethod
async def wake_up(self) -> None:
"""Wake up the engine"""
...
@abstractmethod
async def add_lora(self, lora_request: LoRARequest) -> None:
"""Load a new LoRA adapter into the engine for future requests."""
......
......@@ -127,6 +127,7 @@ async def run_server(args: Namespace,
shutdown_task = await serve_http(
app,
sock=None,
host=args.host,
port=args.port,
log_level=args.log_level,
......@@ -144,7 +145,7 @@ async def run_server(args: Namespace,
if __name__ == "__main__":
parser = FlexibleArgumentParser()
parser.add_argument("--host", type=str, default=None)
parser.add_argument("--port", type=int, default=8000)
parser.add_argument("--port", type=int, default=8000, ge=1024, le=65535)
parser.add_argument("--ssl-keyfile", type=str, default=None)
parser.add_argument("--ssl-certfile", type=str, default=None)
parser.add_argument("--ssl-ca-certs",
......
# SPDX-License-Identifier: Apache-2.0
# The CLI entrypoint to vLLM.
import os
import signal
import sys
import vllm.entrypoints.cli.openai
import vllm.entrypoints.cli.serve
import vllm.version
from vllm.logger import init_logger
from vllm.utils import FlexibleArgumentParser
logger = init_logger(__name__)
CMD_MODULES = [
vllm.entrypoints.cli.openai,
vllm.entrypoints.cli.serve,
]
def register_signal_handlers():
    """Install handlers that exit the process with status 0 on interrupt.

    ``SIGINT`` is always handled. ``SIGTSTP`` is not defined on every
    platform (for example Windows), so it is registered only when the
    ``signal`` module exposes it.
    """

    def signal_handler(sig, frame):
        sys.exit(0)

    signal.signal(signal.SIGINT, signal_handler)
    # Looking the constant up defensively avoids an AttributeError at
    # registration time on platforms that have no SIGTSTP.
    sigtstp = getattr(signal, "SIGTSTP", None)
    if sigtstp is not None:
        signal.signal(sigtstp, signal_handler)
def env_setup():
    """Default the worker multiprocessing start method to ``spawn``.

    An explicit ``VLLM_WORKER_MULTIPROC_METHOD`` in the environment is
    always respected; the default is applied only when it is absent.
    """
    # Why `spawn`: the default `fork` start method is not compatible with
    # some accelerators, and the default is changing in future versions of
    # Python, so being explicit is the safest option.
    #
    # Why only here: switching to `spawn` could break existing code that uses
    # vLLM as a library, because `spawn` misbehaves when the calling code is
    # not protected by `if __name__ == "__main__":`. The CLI entrypoint is the
    # one place where that is under our control.
    #
    # References:
    # - https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods
    # - https://pytorch.org/docs/stable/notes/multiprocessing.html#cuda-in-multiprocessing
    # - https://pytorch.org/docs/stable/multiprocessing.html#sharing-cuda-tensors
    # - https://docs.habana.ai/en/latest/PyTorch/Getting_Started_with_PyTorch_and_Gaudi/Getting_Started_with_PyTorch.html?highlight=multiprocessing#torch-multiprocessing-for-dataloaders
    if os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") is not None:
        return

    logger.debug("Setting VLLM_WORKER_MULTIPROC_METHOD to 'spawn'")
    os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
def main():
    """Entry point of the ``vllm`` command-line interface.

    Builds the top-level parser, lets every module in ``CMD_MODULES``
    register its subcommands, then validates and dispatches the parsed
    arguments. Prints the help text when no subcommand was selected.
    """
    env_setup()

    parser = FlexibleArgumentParser(description="vLLM CLI")
    parser.add_argument('-v',
                        '--version',
                        action='version',
                        version=vllm.version.__version__)
    subparsers = parser.add_subparsers(required=False, dest="subparser")

    # Subcommand name -> handler object, used for validation after parsing.
    registry = {}
    for module in CMD_MODULES:
        for command in module.cmd_init():
            sub_parser = command.subparser_init(subparsers)
            sub_parser.set_defaults(dispatch_function=command.cmd)
            registry[command.name] = command

    args = parser.parse_args()
    if args.subparser in registry:
        registry[args.subparser].validate(args)

    if not hasattr(args, "dispatch_function"):
        parser.print_help()
        return
    args.dispatch_function(args)
if __name__ == "__main__":
main()
# SPDX-License-Identifier: Apache-2.0
# Commands that act as an interactive OpenAI API client
import argparse
import os
import signal
import sys
from typing import List, Optional, Tuple
from openai import OpenAI
from openai.types.chat import ChatCompletionMessageParam
from vllm.entrypoints.cli.types import CLISubcommand
from vllm.utils import FlexibleArgumentParser
def _register_signal_handlers():
def signal_handler(sig, frame):
sys.exit(0)
signal.signal(signal.SIGINT, signal_handler)
signal.signal(signal.SIGTSTP, signal_handler)
def _interactive_cli(args: argparse.Namespace) -> Tuple[str, OpenAI]:
    """Prepare an interactive session against a running API server.

    Installs the signal handlers, builds an ``OpenAI`` client from
    ``args.url`` and the API key (``args.api_key``, else the
    ``OPENAI_API_KEY`` environment variable, else ``"EMPTY"``), and picks the
    model: ``args.model_name`` when given, otherwise the first model returned
    by the server's model listing.

    Returns:
        The ``(model_name, client)`` pair.
    """
    _register_signal_handlers()

    client = OpenAI(
        api_key=args.api_key or os.environ.get("OPENAI_API_KEY", "EMPTY"),
        base_url=args.url,
    )

    # The server is only queried when no model name was supplied.
    model_name = args.model_name or client.models.list().data[0].id

    print(f"Using model: {model_name}")
    return model_name, client
def chat(system_prompt: Optional[str], model_name: str,
         client: OpenAI) -> None:
    """Run an interactive chat loop against ``model_name`` using ``client``.

    The whole conversation, optionally started with ``system_prompt``, is
    sent with every request. The loop ends when stdin reaches end-of-file.
    """
    history: List[ChatCompletionMessageParam] = []
    if system_prompt is not None:
        history.append({"role": "system", "content": system_prompt})

    print("Please enter a message for the chat model:")
    while True:
        try:
            user_text = input("> ")
        except EOFError:
            return
        history.append({"role": "user", "content": user_text})

        reply = client.chat.completions.create(
            model=model_name, messages=history).choices[0].message
        text = reply.content

        # Keep the assistant turn so later requests carry the full context.
        history.append(reply)  # type: ignore
        print(text)
def _add_query_options(
        parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
    """Register the options shared by the client subcommands on ``parser``.

    Adds ``--url``, ``--model-name`` and ``--api-key`` (all strings) and
    returns the same parser object.
    """
    # (flag, default, help text), registered in this order.
    shared_options = (
        ("--url", "http://localhost:8000/v1",
         "url of the running OpenAI-Compatible RESTful API server"),
        ("--model-name", None,
         ("The model name used in prompt completion, default to "
          "the first model in list models API call.")),
        ("--api-key", None,
         ("API key for OpenAI services. If provided, this api key "
          "will overwrite the api key obtained through environment variables."
          )),
    )
    for flag, default, help_text in shared_options:
        parser.add_argument(flag, type=str, default=default, help=help_text)
    return parser
class ChatCommand(CLISubcommand):
    """The `chat` subcommand for the vLLM CLI. """

    def __init__(self):
        self.name = "chat"
        super().__init__()

    @staticmethod
    def cmd(args: argparse.Namespace) -> None:
        """Start an interactive chat session with the running API server.

        Connects via ``_interactive_cli`` and then hands over to the
        module-level ``chat`` loop instead of duplicating it here.
        """
        model_name, client = _interactive_cli(args)
        chat(args.system_prompt, model_name, client)

    def subparser_init(
            self,
            subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser:
        """Create the ``chat`` sub-parser and return it.

        Adds the shared query options plus ``--system-prompt``.
        """
        chat_parser = subparsers.add_parser(
            "chat",
            help="Generate chat completions via the running API server",
            usage="vllm chat [options]")
        _add_query_options(chat_parser)
        chat_parser.add_argument(
            "--system-prompt",
            type=str,
            default=None,
            help=("The system prompt to be added to the chat template, "
                  "used for models that support system prompts."))
        return chat_parser
class CompleteCommand(CLISubcommand):
    """The `complete` subcommand for the vLLM CLI. """

    def __init__(self):
        self.name = "complete"
        super().__init__()

    @staticmethod
    def cmd(args: argparse.Namespace) -> None:
        """Start an interactive text-completion session.

        Each line read from stdin is sent as a prompt and the first
        completion choice is printed. The loop ends when stdin reaches
        end-of-file.
        """
        model_name, client = _interactive_cli(args)
        print("Please enter prompt to complete:")
        while True:
            try:
                input_prompt = input("> ")
            except EOFError:
                # End of input (e.g. Ctrl-D): leave quietly instead of
                # surfacing a traceback, matching the `chat` subcommand.
                return
            completion = client.completions.create(model=model_name,
                                                   prompt=input_prompt)
            output = completion.choices[0].text
            print(output)

    def subparser_init(
            self,
            subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser:
        """Create the ``complete`` sub-parser with the shared query options."""
        complete_parser = subparsers.add_parser(
            "complete",
            help=("Generate text completions based on the given prompt "
                  "via the running API server"),
            usage="vllm complete [options]")
        _add_query_options(complete_parser)
        return complete_parser
def cmd_init() -> List[CLISubcommand]:
    """Return one instance of each subcommand this module provides."""
    commands: List[CLISubcommand] = [ChatCommand(), CompleteCommand()]
    return commands
# SPDX-License-Identifier: Apache-2.0
import argparse
from typing import List
import uvloop
from vllm.engine.arg_utils import EngineArgs
from vllm.entrypoints.cli.types import CLISubcommand
from vllm.entrypoints.openai.api_server import run_server
from vllm.entrypoints.openai.cli_args import (make_arg_parser,
validate_parsed_serve_args)
from vllm.utils import FlexibleArgumentParser
class ServeSubcommand(CLISubcommand):
    """The `serve` subcommand for the vLLM CLI. """

    def __init__(self):
        self.name = "serve"
        super().__init__()

    @staticmethod
    def cmd(args: argparse.Namespace) -> None:
        """Run the OpenAI-compatible server for the positional model tag.

        Raises:
            ValueError: if ``--model`` was changed from its default; with
                ``vllm serve`` the model must be given positionally.
        """
        # The default value of `--model`
        if args.model != EngineArgs.model:
            raise ValueError(
                "With `vllm serve`, you should provide the model as a "
                "positional argument instead of via the `--model` option.")

        # EngineArgs expects the model name to be passed as --model.
        args.model = args.model_tag

        uvloop.run(run_server(args))

    def validate(self, args: argparse.Namespace) -> None:
        """Validate the parsed arguments with ``validate_parsed_serve_args``."""
        validate_parsed_serve_args(args)

    def subparser_init(
            self,
            subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser:
        """Create the ``serve`` sub-parser.

        Adds the positional ``model_tag`` and ``--config`` and returns the
        parser after ``make_arg_parser`` has extended it.
        """
        serve_parser = subparsers.add_parser(
            "serve",
            help="Start the vLLM OpenAI Compatible API server",
            usage="vllm serve <model_tag> [options]")
        serve_parser.add_argument("model_tag",
                                  type=str,
                                  help="The model tag to serve")
        # Each fragment ends with a space: adjacent literals are joined
        # verbatim, so without them the help text runs the sentences together.
        serve_parser.add_argument(
            "--config",
            type=str,
            default='',
            required=False,
            help="Read CLI options from a config file. "
            "Must be a YAML with the following options: "
            "https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html#cli-reference"
        )

        return make_arg_parser(serve_parser)
def cmd_init() -> List[CLISubcommand]:
    """Return one instance of each subcommand this module provides."""
    commands: List[CLISubcommand] = [ServeSubcommand()]
    return commands
# SPDX-License-Identifier: Apache-2.0
import argparse
from vllm.utils import FlexibleArgumentParser
class CLISubcommand:
    """Base class for CLI argument handlers."""

    # Subcommand name; subclasses assign it in their constructor.
    name: str

    @staticmethod
    def cmd(args: argparse.Namespace) -> None:
        """Execute the subcommand with the parsed ``args``."""
        raise NotImplementedError("Subclasses should implement this method")

    def validate(self, args: argparse.Namespace) -> None:
        """Check the parsed ``args``; the base implementation accepts all."""
        return None

    def subparser_init(
            self,
            subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser:
        """Register this subcommand on ``subparsers`` and return its parser."""
        raise NotImplementedError("Subclasses should implement this method")
......@@ -2,8 +2,9 @@
import asyncio
import signal
import socket
from http import HTTPStatus
from typing import Any
from typing import Any, Optional
import uvicorn
from fastapi import FastAPI, Request, Response
......@@ -17,7 +18,8 @@ from vllm.utils import find_process_using_port
logger = init_logger(__name__)
async def serve_http(app: FastAPI, **uvicorn_kwargs: Any):
async def serve_http(app: FastAPI, sock: Optional[socket.socket],
**uvicorn_kwargs: Any):
logger.info("Available routes are:")
for route in app.routes:
methods = getattr(route, "methods", None)
......@@ -34,7 +36,8 @@ async def serve_http(app: FastAPI, **uvicorn_kwargs: Any):
loop = asyncio.get_running_loop()
server_task = loop.create_task(server.serve())
server_task = loop.create_task(
server.serve(sockets=[sock] if sock else None))
def signal_handler() -> None:
# prevents the uvicorn signal handler to exit early
......
......@@ -421,7 +421,7 @@ class LLM:
instead pass them via the ``inputs`` parameter.
"""
runner_type = self.llm_engine.model_config.runner_type
if runner_type != "generate":
if runner_type not in ["generate", "transcription"]:
messages = [
"LLM.generate() is only supported for (conditional) generation "
"models (XForCausalLM, XForConditionalGeneration).",
......@@ -1051,9 +1051,9 @@ class LLM:
def _cross_encoding_score(
self,
tokenizer: Union[AnyTokenizer],
text_1: List[Union[str, TextPrompt, TokensPrompt]],
text_2: List[Union[str, TextPrompt, TokensPrompt]],
tokenizer: AnyTokenizer,
text_1: List[str],
text_2: List[str],
truncate_prompt_tokens: Optional[int] = None,
use_tqdm: bool = True,
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
......@@ -1176,29 +1176,36 @@ class LLM:
if isinstance(text_1, (str, dict)):
# Convert a single prompt to a list.
text_1 = [text_1]
text_1 = [ensure_str(t) for t in text_1]
input_text_1: List[str] = [ensure_str(t) for t in text_1]
if isinstance(text_2, (str, dict)):
# Convert a single prompt to a list.
text_2 = [text_2]
text_2 = [ensure_str(t) for t in text_2]
input_text_2: List[str] = [ensure_str(t) for t in text_2]
if len(text_1) > 1 and len(text_1) != len(text_2):
if len(input_text_1) > 1 and len(input_text_1) != len(input_text_2):
raise ValueError("Input lengths must be either 1:1, 1:N or N:N")
if len(text_1) == 0:
if len(input_text_1) == 0:
raise ValueError("At least one text element must be given")
if len(text_2) == 0:
if len(input_text_2) == 0:
raise ValueError("At least one text_pair element must be given")
if self.llm_engine.model_config.is_cross_encoder:
return self._cross_encoding_score(tokenizer, text_1, text_2,
return self._cross_encoding_score(tokenizer, input_text_1,
input_text_2,
truncate_prompt_tokens, use_tqdm,
lora_request,
prompt_adapter_request)
else:
return self._embedding_score(tokenizer, text_1, text_2,
truncate_prompt_tokens, use_tqdm,
lora_request, prompt_adapter_request)
return self._embedding_score(
tokenizer,
input_text_1, # type: ignore[arg-type]
input_text_2, # type: ignore[arg-type]
truncate_prompt_tokens,
use_tqdm,
lora_request,
prompt_adapter_request)
def start_profile(self) -> None:
self.llm_engine.start_profile()
......
......@@ -10,17 +10,16 @@ import os
import re
import signal
import socket
import sys
import tempfile
import uuid
from argparse import Namespace
from contextlib import asynccontextmanager
from functools import partial
from http import HTTPStatus
from typing import AsyncIterator, Dict, Optional, Set, Tuple, Union
from typing import Annotated, AsyncIterator, Dict, Optional, Set, Tuple, Union
import uvloop
from fastapi import APIRouter, FastAPI, HTTPException, Request
from fastapi import APIRouter, Depends, FastAPI, Form, HTTPException, Request
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, Response, StreamingResponse
......@@ -62,6 +61,8 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
ScoreRequest, ScoreResponse,
TokenizeRequest,
TokenizeResponse,
TranscriptionRequest,
TranscriptionResponse,
UnloadLoraAdapterRequest)
from vllm.entrypoints.openai.reasoning_parsers import ReasoningParserManager
# yapf: enable
......@@ -76,6 +77,8 @@ from vllm.entrypoints.openai.serving_rerank import JinaAIServingRerank
from vllm.entrypoints.openai.serving_score import OpenAIServingScores
from vllm.entrypoints.openai.serving_tokenization import (
OpenAIServingTokenization)
from vllm.entrypoints.openai.serving_transcription import (
OpenAIServingTranscription)
from vllm.entrypoints.openai.tool_parsers import ToolParserManager
from vllm.entrypoints.utils import with_cancellation
from vllm.logger import init_logger
......@@ -253,6 +256,16 @@ async def build_async_engine_client_from_engine_args(
multiprocess.mark_process_dead(engine_process.pid)
async def validate_json_request(raw_request: Request):
    """Reject requests whose ``Content-Type`` media type is not JSON.

    Only the media type is compared: parameters after ``;`` (such as
    ``charset=utf-8``), surrounding whitespace and letter case are ignored.
    A missing header is treated as an empty media type and rejected.

    Raises:
        HTTPException: 415 Unsupported Media Type when the media type is not
            ``application/json``.
    """
    content_type = raw_request.headers.get("content-type", "")
    # Strip after splitting so "application/json ; charset=utf-8" (whitespace
    # before the parameter separator) is still recognised.
    media_type = content_type.split(";", maxsplit=1)[0].strip().lower()
    if media_type != "application/json":
        raise HTTPException(
            status_code=HTTPStatus.UNSUPPORTED_MEDIA_TYPE,
            detail="Unsupported Media Type: Only 'application/json' is allowed"
        )
router = APIRouter()
......@@ -319,6 +332,10 @@ def tokenization(request: Request) -> OpenAIServingTokenization:
return request.app.state.openai_serving_tokenization
def transcription(request: Request) -> OpenAIServingTranscription:
    """Return the transcription handler stored on the application state."""
    return request.app.state.openai_serving_transcription
def engine_client(request: Request) -> EngineClient:
return request.app.state.engine_client
......@@ -336,7 +353,7 @@ async def ping(raw_request: Request) -> Response:
return await health(raw_request)
@router.post("/tokenize")
@router.post("/tokenize", dependencies=[Depends(validate_json_request)])
@with_cancellation
async def tokenize(request: TokenizeRequest, raw_request: Request):
handler = tokenization(raw_request)
......@@ -351,7 +368,7 @@ async def tokenize(request: TokenizeRequest, raw_request: Request):
assert_never(generator)
@router.post("/detokenize")
@router.post("/detokenize", dependencies=[Depends(validate_json_request)])
@with_cancellation
async def detokenize(request: DetokenizeRequest, raw_request: Request):
handler = tokenization(raw_request)
......@@ -380,7 +397,8 @@ async def show_version():
return JSONResponse(content=ver)
@router.post("/v1/chat/completions")
@router.post("/v1/chat/completions",
dependencies=[Depends(validate_json_request)])
@with_cancellation
async def create_chat_completion(request: ChatCompletionRequest,
raw_request: Request):
......@@ -401,7 +419,7 @@ async def create_chat_completion(request: ChatCompletionRequest,
return StreamingResponse(content=generator, media_type="text/event-stream")
@router.post("/v1/completions")
@router.post("/v1/completions", dependencies=[Depends(validate_json_request)])
@with_cancellation
async def create_completion(request: CompletionRequest, raw_request: Request):
handler = completion(raw_request)
......@@ -419,7 +437,7 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
return StreamingResponse(content=generator, media_type="text/event-stream")
@router.post("/v1/embeddings")
@router.post("/v1/embeddings", dependencies=[Depends(validate_json_request)])
@with_cancellation
async def create_embedding(request: EmbeddingRequest, raw_request: Request):
handler = embedding(raw_request)
......@@ -465,7 +483,7 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request):
assert_never(generator)
@router.post("/pooling")
@router.post("/pooling", dependencies=[Depends(validate_json_request)])
@with_cancellation
async def create_pooling(request: PoolingRequest, raw_request: Request):
handler = pooling(raw_request)
......@@ -483,7 +501,7 @@ async def create_pooling(request: PoolingRequest, raw_request: Request):
assert_never(generator)
@router.post("/score")
@router.post("/score", dependencies=[Depends(validate_json_request)])
@with_cancellation
async def create_score(request: ScoreRequest, raw_request: Request):
handler = score(raw_request)
......@@ -501,7 +519,7 @@ async def create_score(request: ScoreRequest, raw_request: Request):
assert_never(generator)
@router.post("/v1/score")
@router.post("/v1/score", dependencies=[Depends(validate_json_request)])
@with_cancellation
async def create_score_v1(request: ScoreRequest, raw_request: Request):
logger.warning(
......@@ -511,7 +529,32 @@ async def create_score_v1(request: ScoreRequest, raw_request: Request):
return await create_score(request, raw_request)
@router.post("/rerank")
@router.post("/v1/audio/transcriptions")
@with_cancellation
async def create_transcriptions(request: Annotated[TranscriptionRequest,
                                                   Form()],
                                raw_request: Request):
    """Transcribe the uploaded audio file of a form-encoded request.

    Responds with an error when no transcription handler is configured,
    with JSON for an error or a complete transcription result, and with an
    event stream for anything else the handler returns.
    """
    handler = transcription(raw_request)
    if handler is None:
        return base(raw_request).create_error_response(
            message="The model does not support Transcriptions API")

    # The whole upload is read into memory before being handed over.
    audio_bytes = await request.file.read()
    result = await handler.create_transcription(audio_bytes, request,
                                                raw_request)

    if isinstance(result, ErrorResponse):
        return JSONResponse(content=result.model_dump(),
                            status_code=result.code)
    if isinstance(result, TranscriptionResponse):
        return JSONResponse(content=result.model_dump())

    return StreamingResponse(content=result, media_type="text/event-stream")
@router.post("/rerank", dependencies=[Depends(validate_json_request)])
@with_cancellation
async def do_rerank(request: RerankRequest, raw_request: Request):
handler = rerank(raw_request)
......@@ -528,7 +571,7 @@ async def do_rerank(request: RerankRequest, raw_request: Request):
assert_never(generator)
@router.post("/v1/rerank")
@router.post("/v1/rerank", dependencies=[Depends(validate_json_request)])
@with_cancellation
async def do_rerank_v1(request: RerankRequest, raw_request: Request):
logger.warning_once(
......@@ -539,7 +582,7 @@ async def do_rerank_v1(request: RerankRequest, raw_request: Request):
return await do_rerank(request, raw_request)
@router.post("/v2/rerank")
@router.post("/v2/rerank", dependencies=[Depends(validate_json_request)])
@with_cancellation
async def do_rerank_v2(request: RerankRequest, raw_request: Request):
return await do_rerank(request, raw_request)
......@@ -582,8 +625,26 @@ if envs.VLLM_SERVER_DEV_MODE:
await engine_client(raw_request).reset_prefix_cache()
return Response(status_code=200)
@router.post("/sleep")
async def sleep(raw_request: Request):
    """Put the engine to sleep at the level given by the ``level`` query param.

    ``level`` defaults to ``"1"`` when the parameter is omitted.

    Raises:
        HTTPException: 400 Bad Request when ``level`` is not an integer.
    """
    # The level comes from the URL query string, not from the POST body.
    raw_level = raw_request.query_params.get("level", "1")
    try:
        level = int(raw_level)
    except ValueError as err:
        # Report bad client input as a 400 instead of an unhandled 500.
        raise HTTPException(
            status_code=HTTPStatus.BAD_REQUEST,
            detail=f"Invalid sleep level: {raw_level!r}") from err
    logger.info("sleep the engine with level %s", level)
    await engine_client(raw_request).sleep(level)
    # FIXME: in v0 with frontend multiprocessing, the sleep command
    # is sent but does not finish yet when we return a response.
    return Response(status_code=200)
@router.post("/wake_up")
async def wake_up(raw_request: Request):
    """Ask the engine client to wake the engine up and reply with HTTP 200."""
    logger.info("wake up the engine")
    await engine_client(raw_request).wake_up()
    # FIXME: in v0 with frontend multiprocessing, the wake-up command
    # is sent but does not finish yet when we return a response.
    return Response(status_code=200)
@router.post("/invocations")
@router.post("/invocations", dependencies=[Depends(validate_json_request)])
async def invocations(raw_request: Request):
"""
For SageMaker, routes requests to other handlers based on model `task`.
......@@ -633,7 +694,8 @@ if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
"Lora dynamic loading & unloading is enabled in the API server. "
"This should ONLY be used for local development!")
@router.post("/v1/load_lora_adapter")
@router.post("/v1/load_lora_adapter",
dependencies=[Depends(validate_json_request)])
async def load_lora_adapter(request: LoadLoraAdapterRequest,
raw_request: Request):
handler = models(raw_request)
......@@ -644,7 +706,8 @@ if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
return Response(status_code=200, content=response)
@router.post("/v1/unload_lora_adapter")
@router.post("/v1/unload_lora_adapter",
dependencies=[Depends(validate_json_request)])
async def unload_lora_adapter(request: UnloadLoraAdapterRequest,
raw_request: Request):
handler = models(raw_request)
......@@ -753,7 +816,9 @@ async def init_app_state(
state.log_stats = not args.disable_log_stats
resolved_chat_template = load_chat_template(args.chat_template)
logger.info("Using supplied chat template:\n%s", resolved_chat_template)
if resolved_chat_template is not None:
logger.info("Using supplied chat template:\n%s",
resolved_chat_template)
state.openai_serving_models = OpenAIServingModels(
engine_client=engine_client,
......@@ -821,6 +886,12 @@ async def init_app_state(
chat_template=resolved_chat_template,
chat_template_content_format=args.chat_template_content_format,
)
state.openai_serving_transcription = OpenAIServingTranscription(
engine_client,
model_config,
state.openai_serving_models,
request_logger=request_logger,
) if model_config.runner_type == "transcription" else None
state.task = model_config.task
......@@ -831,6 +902,7 @@ def create_server_socket(addr: Tuple[str, int]) -> socket.socket:
sock = socket.socket(family=family, type=socket.SOCK_STREAM)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
sock.bind(addr)
return sock
......@@ -878,8 +950,17 @@ async def run_server(args, **uvicorn_kwargs) -> None:
model_config = await engine_client.get_model_config()
await init_app_state(engine_client, model_config, app.state, args)
def _listen_addr(a: str) -> str:
    """Format host ``a`` for display inside a URL.

    IPv6 addresses are wrapped in square brackets; an empty host is shown as
    ``0.0.0.0``; anything else is returned unchanged.
    """
    if not is_valid_ipv6_address(a):
        return a if a else "0.0.0.0"
    return '[' + a + ']'
logger.info("Starting vLLM API server on http://%s:%d",
_listen_addr(sock_addr[0]), sock_addr[1])
shutdown_task = await serve_http(
app,
sock=sock,
host=args.host,
port=args.port,
log_level=args.uvicorn_log_level,
......@@ -888,8 +969,6 @@ async def run_server(args, **uvicorn_kwargs) -> None:
ssl_certfile=args.ssl_certfile,
ssl_ca_certs=args.ssl_ca_certs,
ssl_cert_reqs=args.ssl_cert_reqs,
# Workaround to work on macOS
fd=sock.fileno() if sys.platform.startswith("darwin") else None,
**uvicorn_kwargs,
)
......@@ -901,7 +980,8 @@ async def run_server(args, **uvicorn_kwargs) -> None:
if __name__ == "__main__":
# NOTE(simon):
# This section should be in sync with vllm/scripts.py for CLI entrypoints.
# This section should be in sync with vllm/entrypoints/cli/main.py for CLI
# entrypoints.
parser = FlexibleArgumentParser(
description="vLLM OpenAI-Compatible RESTful API server.")
parser = make_arg_parser(parser)
......
......@@ -8,9 +8,10 @@ from argparse import Namespace
from typing import Any, ClassVar, Dict, List, Literal, Optional, Set, Union
import torch
from fastapi import UploadFile
from pydantic import (BaseModel, ConfigDict, Field, TypeAdapter,
ValidationInfo, field_validator, model_validator)
from typing_extensions import Annotated
from typing_extensions import Annotated, TypeAlias
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
from vllm.logger import init_logger
......@@ -311,6 +312,10 @@ class ChatCompletionRequest(OpenAIBaseModel):
description=("Additional kwargs to pass to the template renderer. "
"Will be accessible by the chat template."),
)
mm_processor_kwargs: Optional[Dict[str, Any]] = Field(
default=None,
description=("Additional kwargs to pass to the HF processor."),
)
guided_json: Optional[Union[str, dict, BaseModel]] = Field(
default=None,
description=("If specified, the output will follow the JSON schema."),
......@@ -1426,3 +1431,163 @@ class LoadLoraAdapterRequest(BaseModel):
class UnloadLoraAdapterRequest(BaseModel):
lora_name: str
lora_int_id: Optional[int] = Field(default=None)
## Protocols for Audio
AudioResponseFormat: TypeAlias = Literal["json", "text", "srt", "verbose_json",
"vtt"]
class TranscriptionRequest(OpenAIBaseModel):
    # Request model for audio transcription.
    # Fields are ordered as in the official OpenAI API documentation:
    # https://platform.openai.com/docs/api-reference/audio/createTranscription

    file: UploadFile
    """
    The audio file object (not file name) to transcribe, in one of these
    formats: flac, mp3, mp4, mpeg, mpga, m4a, ogg, wav, or webm.
    """

    model: str
    """ID of the model to use.
    """

    language: Optional[str] = None
    """The language of the input audio.
    Supplying the input language in
    [ISO-639-1](https://en.wikipedia.org/wiki/List_of_ISO_639-1_codes) format
    will improve accuracy and latency.
    """

    prompt: str = Field(default="")
    """An optional text to guide the model's style or continue a previous audio
    segment.
    The [prompt](https://platform.openai.com/docs/guides/speech-to-text#prompting)
    should match the audio language.
    """

    response_format: AudioResponseFormat = Field(default="json")
    """
    The format of the output, in one of these options: `json`, `text`, `srt`,
    `verbose_json`, or `vtt`.
    """

    # TODO (varun): support the "if set to 0 ... until certain thresholds
    # are hit" behaviour described in the field documentation below.
    temperature: float = Field(default=0.0)
    """The sampling temperature, between 0 and 1.
    Higher values like 0.8 will make the output more random, while lower values
    like 0.2 will make it more focused / deterministic. If set to 0, the model
    will use [log probability](https://en.wikipedia.org/wiki/Log_probability)
    to automatically increase the temperature until certain thresholds are hit.
    """

    timestamp_granularities: List[Literal["word", "segment"]] = Field(
        alias="timestamp_granularities[]", default=[])
    """The timestamp granularities to populate for this transcription.
    `response_format` must be set `verbose_json` to use timestamp granularities.
    Either or both of these options are supported: `word`, or `segment`. Note:
    There is no additional latency for segment timestamps, but generating word
    timestamps incurs additional latency.
    """

    # Fallback sampling parameters used when neither the request nor the
    # caller-provided defaults supply a value.
    _DEFAULT_SAMPLING_PARAMS: dict = {
        "temperature": 0,
    }

    def to_sampling_params(
            self,
            default_max_tokens: int,
            default_sampling_params: Optional[dict] = None) -> SamplingParams:
        """Build the ``SamplingParams`` for this transcription request.

        ``max_tokens`` is always ``default_max_tokens``. The temperature is
        taken from the request; if it is ``None`` it falls back to
        ``default_sampling_params["temperature"]`` and then to
        ``_DEFAULT_SAMPLING_PARAMS["temperature"]``.
        """
        # TODO(#9845): remove max_tokens when field is removed from OpenAI API
        fallbacks = ({} if default_sampling_params is None else
                     default_sampling_params)

        # NOTE(review): `temperature` is declared as a plain float with a
        # default, so this fallback looks reachable only if None gets past
        # validation - confirm whether the field should be Optional.
        chosen_temperature = self.temperature
        if chosen_temperature is None:
            chosen_temperature = fallbacks.get(
                "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"])

        return SamplingParams.from_optional(temperature=chosen_temperature,
                                            max_tokens=default_max_tokens)
# Transcription response objects
class TranscriptionResponse(OpenAIBaseModel):
    # Transcription response that carries only the transcribed text.

    text: str
    """The transcribed text."""
class TranscriptionWord(OpenAIBaseModel):
    # One transcribed word together with its start/end timestamps.

    end: float
    """End time of the word in seconds."""

    start: float
    """Start time of the word in seconds."""

    word: str
    """The text content of the word."""
class TranscriptionSegment(OpenAIBaseModel):
    # One segment of a transcription, with timing, decoding statistics and
    # the token IDs of its text.

    id: int
    """Unique identifier of the segment."""

    avg_logprob: float
    """Average logprob of the segment.
    If the value is lower than -1, consider the logprobs failed.
    """

    compression_ratio: float
    """Compression ratio of the segment.
    If the value is greater than 2.4, consider the compression failed.
    """

    end: float
    """End time of the segment in seconds."""

    no_speech_prob: float
    """Probability of no speech in the segment.
    If the value is higher than 1.0 and the `avg_logprob` is below -1, consider
    this segment silent.
    """

    seek: int
    """Seek offset of the segment."""

    start: float
    """Start time of the segment in seconds."""

    temperature: float
    """Temperature parameter used for generating the segment."""

    text: str
    """Text content of the segment."""

    tokens: List[int]
    """Array of token IDs for the text content."""
class TranscriptionResponseVerbose(OpenAIBaseModel):
    # Transcription response that adds duration, language and optional
    # per-segment / per-word details to the transcribed text.

    duration: str
    """The duration of the input audio."""

    language: str
    """The language of the input audio."""

    text: str
    """The transcribed text."""

    segments: Optional[List[TranscriptionSegment]] = None
    """Segments of the transcribed text and their corresponding details."""

    words: Optional[List[TranscriptionWord]] = None
    """Extracted words and their corresponding timestamps."""
......@@ -67,6 +67,8 @@ class DeepSeekR1ReasoningParser(ReasoningParser):
]):
return None
# Check if <think> is present in previous or delta.
# Keep compatibility with models that don't generate <think> tokens.
if self.think_start_token_id in previous_token_ids:
if self.think_end_token_id in delta_token_ids:
# <think> in previous, </think> in delta,
......@@ -85,7 +87,6 @@ class DeepSeekR1ReasoningParser(ReasoningParser):
# reasoning content continues
return DeltaMessage(reasoning_content=delta_text)
elif self.think_start_token_id in delta_token_ids:
logger.info(delta_text)
if self.think_end_token_id in delta_token_ids:
# <think> in delta, </think> in delta, extract reasoning content
start_index = delta_text.find(self.think_start_token)
......@@ -101,35 +102,46 @@ class DeepSeekR1ReasoningParser(ReasoningParser):
# reasoning content continues
return DeltaMessage(reasoning_content=delta_text)
else:
# No <think> in previous or delta, reasoning content continues.
return DeltaMessage(content=delta_text)
# No <think> in previous or delta, also need to check for </think>.
# Because the model may have generated </think> without <think>
# Ref https://huggingface.co/deepseek-ai/DeepSeek-R1/commit/8a58a132790c9935686eb97f042afa8013451c9f
if self.think_end_token_id in delta_token_ids:
# </think> in delta with more tokens,
# extract reasoning content and content
end_index = delta_text.find(self.think_end_token)
reasoning_content = delta_text[:end_index]
content = delta_text[end_index + len(self.think_end_token):]
return DeltaMessage(reasoning_content=reasoning_content,
content=content if content else None)
elif self.think_end_token_id in previous_token_ids:
# </think> in previous, thinking content ends
return DeltaMessage(content=delta_text)
else:
# no </think> in previous or delta, reasoning content continues
return DeltaMessage(reasoning_content=delta_text)
def extract_reasoning_content(
        self, model_output: str, request: ChatCompletionRequest
) -> Tuple[Optional[str], Optional[str]]:
    """Split a complete model output into reasoning content and final content.

    Args:
        model_output: The full text generated by the model.
        request: The originating chat completion request (not used here).

    Returns:
        A ``(reasoning_content, content)`` tuple. ``content`` is ``None``
        when nothing remains once the reasoning section is removed. If the
        output contains no end-of-thinking token, the whole output is
        returned as reasoning content.
    """
    start_token = self.think_start_token
    end_token = self.think_end_token

    # DeepSeek R1 doesn't generate <think> now.
    # Thus we assume the reasoning content is always at the start.
    # Ref https://huggingface.co/deepseek-ai/DeepSeek-R1/commit/8a58a132790c9935686eb97f042afa8013451c9f
    first_end = model_output.find(end_token)
    if first_end == -1:
        # The end token never appeared: everything is still reasoning.
        return model_output, None

    # Add a start token if it's missing to keep compatibility. A start token
    # that only appears *after* the first end token cannot open this
    # reasoning section, so it is treated as missing as well.
    start_index = model_output.find(start_token)
    if start_index == -1 or start_index > first_end:
        model_output = f"{start_token}{model_output}"
        start_index = 0

    # Use a regex to find the reasoning content
    reasoning_content = self.reasoning_regex.findall(model_output)[0]

    # Other models may not emit the start token at the very beginning of the
    # output, so cut the reasoning section out relative to start_index and
    # keep any text that precedes it.
    end_index = start_index + len(
        f"{start_token}{reasoning_content}{end_token}")
    final_output = model_output[:start_index] + model_output[end_index:]

    return reasoning_content, final_output or None
# SPDX-License-Identifier: Apache-2.0
import asyncio
import tempfile
from http import HTTPStatus
from io import StringIO
from typing import Awaitable, Callable, List, Optional
......@@ -51,6 +52,13 @@ def parse_args():
help="The path or url to a single output file. Currently supports "
"local file paths, or web (http or https) urls. If a URL is specified,"
" the file should be available via HTTP PUT.")
parser.add_argument(
"--output-tmp-dir",
type=str,
default=None,
help="The directory to store the output file before uploading it "
"to the output URL.",
)
parser.add_argument("--response-role",
type=nullable_str,
default="assistant",
......@@ -134,17 +142,107 @@ async def read_file(path_or_url: str) -> str:
return f.read()
async def write_file(path_or_url: str, data: str) -> None:
async def write_local_file(output_path: str,
                           batch_outputs: List[BatchRequestOutput]) -> None:
    """
    Dump the batch outputs to a local file, one JSON document per line.

    output_path: Destination path; the file is created or truncated.
    batch_outputs: The batch outputs to serialize, in order.
    """
    # This is blocking I/O inside a coroutine. That is acceptable while
    # run_batch is a standalone program; it should become truly async
    # otherwise.
    with open(output_path, "w", encoding="utf-8") as out_file:
        out_file.writelines(f"{item.model_dump_json()}\n"
                            for item in batch_outputs)
async def _put_and_check(session, output_url: str, payload, what: str) -> None:
    """PUT ``payload`` to ``output_url`` and raise unless the status is 200."""
    async with session.put(output_url, data=payload) as response:
        if response.status != 200:
            # ClientResponse.text() is a coroutine: it has to be awaited,
            # otherwise the message contains a coroutine repr instead of the
            # server's response body.
            body = await response.text()
            raise Exception(f"Failed to upload {what}.\n"
                            f"Status: {response.status}\n"
                            f"Response: {body}")


async def upload_data(output_url: str, data_or_file: str,
                      from_file: bool) -> None:
    """
    Upload a local file or an in-memory payload to a URL with HTTP PUT.

    output_url: The URL to upload to.
    data_or_file: Either the data to upload or the path to the file to upload.
    from_file: If True, data_or_file is the path to the file to upload.

    The upload is attempted up to 5 times, sleeping 5 seconds between
    attempts, and stops at the first attempt that returns status 200.
    Raises Exception (chained to the last error) when every attempt fails.
    """
    # Timeout is a common issue when uploading large files.
    # We retry max_retries times before giving up.
    max_retries = 5
    # Number of seconds to wait before retrying.
    delay = 5

    for attempt in range(1, max_retries + 1):
        try:
            # We increase the timeout to 1000 seconds to allow
            # for large files (default is 300).
            async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(
                    total=1000)) as session:
                if from_file:
                    with open(data_or_file, "rb") as file:
                        await _put_and_check(session, output_url, file,
                                             "file")
                else:
                    await _put_and_check(session, output_url, data_or_file,
                                         "data")
            # Success: stop here. Without this the loop would fall through
            # and repeat the same upload on every remaining attempt.
            return
        except Exception as e:
            if attempt < max_retries:
                logger.error(
                    "Failed to upload data (attempt %d). "
                    "Error message: %s.\nRetrying in %d seconds...",
                    attempt, e, delay)
                await asyncio.sleep(delay)
            else:
                raise Exception(f"Failed to upload data (attempt {attempt}). "
                                f"Error message: {str(e)}.") from e
async def write_file(path_or_url: str, batch_outputs: List[BatchRequestOutput],
                     output_tmp_dir: Optional[str] = None) -> None:
    """
    Write batch_outputs to a file or upload to a URL.

    path_or_url: The path or URL to write batch_outputs to. Values starting
        with "http://" or "https://" are uploaded; anything else is treated
        as a local file path.
    batch_outputs: The list of batch outputs to write.
    output_tmp_dir: The directory to store the output file before uploading it
        to the output URL. If None, the outputs are serialized in memory and
        uploaded directly.
    """
    if not path_or_url.startswith(("http://", "https://")):
        logger.info("Writing outputs to local file %s", path_or_url)
        await write_local_file(path_or_url, batch_outputs)
        return

    if output_tmp_dir is None:
        logger.info("Writing outputs to memory buffer")
        # One JSON document per line, without a trailing newline.
        payload = "\n".join(o.model_dump_json()
                            for o in batch_outputs).strip()
        logger.info("Uploading outputs to %s", path_or_url)
        await upload_data(
            path_or_url,
            payload.encode("utf-8"),
            from_file=False,
        )
    else:
        # Write responses to a temporary file and then upload it to the URL.
        # The upload happens inside the `with` block because leaving it
        # closes the temporary file, and NamedTemporaryFile deletes the file
        # on close by default.
        with tempfile.NamedTemporaryFile(
                mode="w",
                encoding="utf-8",
                dir=output_tmp_dir,
                prefix="tmp_batch_output_",
                suffix=".jsonl",
        ) as f:
            logger.info("Writing outputs to temporary local file %s", f.name)
            await write_local_file(f.name, batch_outputs)
            logger.info("Uploading outputs to %s", path_or_url)
            await upload_data(path_or_url, f.name, from_file=True)
def make_error_request_output(request: BatchRequestInput,
......@@ -317,12 +415,7 @@ async def main(args):
with tracker.pbar():
responses = await asyncio.gather(*response_futures)
output_buffer = StringIO()
for response in responses:
print(response.model_dump_json(), file=output_buffer)
output_buffer.seek(0)
await write_file(args.output_file, output_buffer.read().strip())
await write_file(args.output_file, responses, args.output_tmp_dir)
if __name__ == "__main__":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment