Commit 1591c68f authored by zhuwenwen
Browse files

merge v0.4.2

parents 09bcf00b c7f2cf2b
...@@ -100,6 +100,7 @@ if __name__ == "__main__": ...@@ -100,6 +100,7 @@ if __name__ == "__main__":
type=str, type=str,
default=None, default=None,
help="FastAPI root_path when app is behind a path based routing proxy") help="FastAPI root_path when app is behind a path based routing proxy")
parser.add_argument("--log-level", type=str, default="debug")
parser = AsyncEngineArgs.add_cli_args(parser) parser = AsyncEngineArgs.add_cli_args(parser)
args = parser.parse_args() args = parser.parse_args()
engine_args = AsyncEngineArgs.from_cli_args(args) engine_args = AsyncEngineArgs.from_cli_args(args)
...@@ -110,7 +111,7 @@ if __name__ == "__main__": ...@@ -110,7 +111,7 @@ if __name__ == "__main__":
uvicorn.run(app, uvicorn.run(app,
host=args.host, host=args.host,
port=args.port, port=args.port,
log_level="debug", log_level=args.log_level,
timeout_keep_alive=TIMEOUT_KEEP_ALIVE, timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
ssl_keyfile=args.ssl_keyfile, ssl_keyfile=args.ssl_keyfile,
ssl_certfile=args.ssl_certfile, ssl_certfile=args.ssl_certfile,
......
...@@ -69,6 +69,9 @@ class LLM: ...@@ -69,6 +69,9 @@ class LLM:
disable CUDA graph and always execute the model in eager mode. disable CUDA graph and always execute the model in eager mode.
If False, we will use CUDA graph and eager execution in hybrid. If False, we will use CUDA graph and eager execution in hybrid.
max_context_len_to_capture: Maximum context len covered by CUDA graphs. max_context_len_to_capture: Maximum context len covered by CUDA graphs.
When a sequence has context length larger than this, we fall back
to eager mode (DEPRECATED. Use `max_seq_len_to_capture` instead).
max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs.
When a sequence has context length larger than this, we fall back When a sequence has context length larger than this, we fall back
to eager mode. to eager mode.
disable_custom_all_reduce: See ParallelConfig disable_custom_all_reduce: See ParallelConfig
...@@ -90,7 +93,8 @@ class LLM: ...@@ -90,7 +93,8 @@ class LLM:
gpu_memory_utilization: float = 0.9, gpu_memory_utilization: float = 0.9,
swap_space: int = 4, swap_space: int = 4,
enforce_eager: bool = False, enforce_eager: bool = False,
max_context_len_to_capture: int = 8192, max_context_len_to_capture: Optional[int] = None,
max_seq_len_to_capture: int = 8192,
disable_custom_all_reduce: bool = False, disable_custom_all_reduce: bool = False,
**kwargs, **kwargs,
) -> None: ) -> None:
...@@ -112,6 +116,7 @@ class LLM: ...@@ -112,6 +116,7 @@ class LLM:
swap_space=swap_space, swap_space=swap_space,
enforce_eager=enforce_eager, enforce_eager=enforce_eager,
max_context_len_to_capture=max_context_len_to_capture, max_context_len_to_capture=max_context_len_to_capture,
max_seq_len_to_capture=max_seq_len_to_capture,
disable_custom_all_reduce=disable_custom_all_reduce, disable_custom_all_reduce=disable_custom_all_reduce,
**kwargs, **kwargs,
) )
......
import asyncio import asyncio
import importlib import importlib
import inspect import inspect
import os import re
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from http import HTTPStatus from http import HTTPStatus
from typing import Any, Set
import fastapi import fastapi
import uvicorn import uvicorn
...@@ -12,8 +13,10 @@ from fastapi.exceptions import RequestValidationError ...@@ -12,8 +13,10 @@ from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, Response, StreamingResponse from fastapi.responses import JSONResponse, Response, StreamingResponse
from prometheus_client import make_asgi_app from prometheus_client import make_asgi_app
from starlette.routing import Mount
import vllm import vllm
import vllm.envs as envs
from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.cli_args import make_arg_parser from vllm.entrypoints.openai.cli_args import make_arg_parser
...@@ -31,6 +34,8 @@ openai_serving_chat: OpenAIServingChat ...@@ -31,6 +34,8 @@ openai_serving_chat: OpenAIServingChat
openai_serving_completion: OpenAIServingCompletion openai_serving_completion: OpenAIServingCompletion
logger = init_logger(__name__) logger = init_logger(__name__)
_running_tasks: Set[asyncio.Task[Any]] = set()
@asynccontextmanager @asynccontextmanager
async def lifespan(app: fastapi.FastAPI): async def lifespan(app: fastapi.FastAPI):
...@@ -41,7 +46,9 @@ async def lifespan(app: fastapi.FastAPI): ...@@ -41,7 +46,9 @@ async def lifespan(app: fastapi.FastAPI):
await engine.do_log_stats() await engine.do_log_stats()
if not engine_args.disable_log_stats: if not engine_args.disable_log_stats:
asyncio.create_task(_force_log()) task = asyncio.create_task(_force_log())
_running_tasks.add(task)
task.add_done_callback(_running_tasks.remove)
yield yield
...@@ -55,8 +62,10 @@ def parse_args(): ...@@ -55,8 +62,10 @@ def parse_args():
# Add prometheus asgi middleware to route /metrics requests # Add prometheus asgi middleware to route /metrics requests
metrics_app = make_asgi_app() route = Mount("/metrics", make_asgi_app())
app.mount("/metrics", metrics_app) # Workaround for 307 Redirect for /metrics
route.path_regex = re.compile('^/metrics(?P<path>.*)$')
app.routes.append(route)
@app.exception_handler(RequestValidationError) @app.exception_handler(RequestValidationError)
...@@ -125,7 +134,7 @@ if __name__ == "__main__": ...@@ -125,7 +134,7 @@ if __name__ == "__main__":
allow_headers=args.allowed_headers, allow_headers=args.allowed_headers,
) )
if token := os.environ.get("VLLM_API_KEY") or args.api_key: if token := envs.VLLM_API_KEY or args.api_key:
@app.middleware("http") @app.middleware("http")
async def authentication(request: Request, call_next): async def authentication(request: Request, call_next):
...@@ -148,8 +157,8 @@ if __name__ == "__main__": ...@@ -148,8 +157,8 @@ if __name__ == "__main__":
raise ValueError(f"Invalid middleware {middleware}. " raise ValueError(f"Invalid middleware {middleware}. "
f"Must be a function or a class.") f"Must be a function or a class.")
logger.info(f"vLLM API server version {vllm.__version__}") logger.info("vLLM API server version %s", vllm.__version__)
logger.info(f"args: {args}") logger.info("args: %s", args)
if args.served_model_name is not None: if args.served_model_name is not None:
served_model_names = args.served_model_name served_model_names = args.served_model_name
......
...@@ -8,8 +8,8 @@ import argparse ...@@ -8,8 +8,8 @@ import argparse
import json import json
import ssl import ssl
from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str
from vllm.entrypoints.openai.serving_engine import LoRA from vllm.entrypoints.openai.serving_engine import LoRAModulePath
class LoRAParserAction(argparse.Action): class LoRAParserAction(argparse.Action):
...@@ -18,14 +18,17 @@ class LoRAParserAction(argparse.Action): ...@@ -18,14 +18,17 @@ class LoRAParserAction(argparse.Action):
lora_list = [] lora_list = []
for item in values: for item in values:
name, path = item.split('=') name, path = item.split('=')
lora_list.append(LoRA(name, path)) lora_list.append(LoRAModulePath(name, path))
setattr(namespace, self.dest, lora_list) setattr(namespace, self.dest, lora_list)
def make_arg_parser(): def make_arg_parser():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description="vLLM OpenAI-Compatible RESTful API server.") description="vLLM OpenAI-Compatible RESTful API server.")
parser.add_argument("--host", type=str, default=None, help="host name") parser.add_argument("--host",
type=nullable_str,
default=None,
help="host name")
parser.add_argument("--port", type=int, default=8000, help="port number") parser.add_argument("--port", type=int, default=8000, help="port number")
parser.add_argument( parser.add_argument(
"--uvicorn-log-level", "--uvicorn-log-level",
...@@ -49,49 +52,39 @@ def make_arg_parser(): ...@@ -49,49 +52,39 @@ def make_arg_parser():
default=["*"], default=["*"],
help="allowed headers") help="allowed headers")
parser.add_argument("--api-key", parser.add_argument("--api-key",
type=str, type=nullable_str,
default=None, default=None,
help="If provided, the server will require this key " help="If provided, the server will require this key "
"to be presented in the header.") "to be presented in the header.")
parser.add_argument("--served-model-name",
nargs="+",
type=str,
default=None,
help="The model name(s) used in the API. If multiple "
"names are provided, the server will respond to any "
"of the provided names. The model name in the model "
"field of a response will be the first name in this "
"list. If not specified, the model name will be the "
"same as the `--model` argument.")
parser.add_argument( parser.add_argument(
"--lora-modules", "--lora-modules",
type=str, type=nullable_str,
default=None, default=None,
nargs='+', nargs='+',
action=LoRAParserAction, action=LoRAParserAction,
help="LoRA module configurations in the format name=path. " help="LoRA module configurations in the format name=path. "
"Multiple modules can be specified.") "Multiple modules can be specified.")
parser.add_argument("--chat-template", parser.add_argument("--chat-template",
type=str, type=nullable_str,
default=None, default=None,
help="The file path to the chat template, " help="The file path to the chat template, "
"or the template in single-line form " "or the template in single-line form "
"for the specified model") "for the specified model")
parser.add_argument("--response-role", parser.add_argument("--response-role",
type=str, type=nullable_str,
default="assistant", default="assistant",
help="The role name to return if " help="The role name to return if "
"`request.add_generation_prompt=true`.") "`request.add_generation_prompt=true`.")
parser.add_argument("--ssl-keyfile", parser.add_argument("--ssl-keyfile",
type=str, type=nullable_str,
default=None, default=None,
help="The file path to the SSL key file") help="The file path to the SSL key file")
parser.add_argument("--ssl-certfile", parser.add_argument("--ssl-certfile",
type=str, type=nullable_str,
default=None, default=None,
help="The file path to the SSL cert file") help="The file path to the SSL cert file")
parser.add_argument("--ssl-ca-certs", parser.add_argument("--ssl-ca-certs",
type=str, type=nullable_str,
default=None, default=None,
help="The CA certificates file") help="The CA certificates file")
parser.add_argument( parser.add_argument(
...@@ -102,12 +95,12 @@ def make_arg_parser(): ...@@ -102,12 +95,12 @@ def make_arg_parser():
) )
parser.add_argument( parser.add_argument(
"--root-path", "--root-path",
type=str, type=nullable_str,
default=None, default=None,
help="FastAPI root_path when app is behind a path based routing proxy") help="FastAPI root_path when app is behind a path based routing proxy")
parser.add_argument( parser.add_argument(
"--middleware", "--middleware",
type=str, type=nullable_str,
action="append", action="append",
default=[], default=[],
help="Additional ASGI middleware to apply to the app. " help="Additional ASGI middleware to apply to the app. "
......
...@@ -4,14 +4,20 @@ import time ...@@ -4,14 +4,20 @@ import time
from typing import Dict, List, Literal, Optional, Union from typing import Dict, List, Literal, Optional, Union
import torch import torch
from pydantic import BaseModel, Field, model_validator from openai.types.chat import ChatCompletionMessageParam
from pydantic import BaseModel, ConfigDict, Field, model_validator
from typing_extensions import Annotated from typing_extensions import Annotated
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.utils import random_uuid from vllm.utils import random_uuid
class ErrorResponse(BaseModel): class OpenAIBaseModel(BaseModel):
# OpenAI API does not allow extra fields
model_config = ConfigDict(extra="forbid")
class ErrorResponse(OpenAIBaseModel):
object: str = "error" object: str = "error"
message: str message: str
type: str type: str
...@@ -19,7 +25,7 @@ class ErrorResponse(BaseModel): ...@@ -19,7 +25,7 @@ class ErrorResponse(BaseModel):
code: int code: int
class ModelPermission(BaseModel): class ModelPermission(OpenAIBaseModel):
id: str = Field(default_factory=lambda: f"modelperm-{random_uuid()}") id: str = Field(default_factory=lambda: f"modelperm-{random_uuid()}")
object: str = "model_permission" object: str = "model_permission"
created: int = Field(default_factory=lambda: int(time.time())) created: int = Field(default_factory=lambda: int(time.time()))
...@@ -34,7 +40,7 @@ class ModelPermission(BaseModel): ...@@ -34,7 +40,7 @@ class ModelPermission(BaseModel):
is_blocking: bool = False is_blocking: bool = False
class ModelCard(BaseModel): class ModelCard(OpenAIBaseModel):
id: str id: str
object: str = "model" object: str = "model"
created: int = Field(default_factory=lambda: int(time.time())) created: int = Field(default_factory=lambda: int(time.time()))
...@@ -44,26 +50,26 @@ class ModelCard(BaseModel): ...@@ -44,26 +50,26 @@ class ModelCard(BaseModel):
permission: List[ModelPermission] = Field(default_factory=list) permission: List[ModelPermission] = Field(default_factory=list)
class ModelList(BaseModel): class ModelList(OpenAIBaseModel):
object: str = "list" object: str = "list"
data: List[ModelCard] = Field(default_factory=list) data: List[ModelCard] = Field(default_factory=list)
class UsageInfo(BaseModel): class UsageInfo(OpenAIBaseModel):
prompt_tokens: int = 0 prompt_tokens: int = 0
total_tokens: int = 0 total_tokens: int = 0
completion_tokens: Optional[int] = 0 completion_tokens: Optional[int] = 0
class ResponseFormat(BaseModel): class ResponseFormat(OpenAIBaseModel):
# type must be "json_object" or "text" # type must be "json_object" or "text"
type: Literal["text", "json_object"] type: Literal["text", "json_object"]
class ChatCompletionRequest(BaseModel): class ChatCompletionRequest(OpenAIBaseModel):
# Ordered by official OpenAI API documentation # Ordered by official OpenAI API documentation
# https://platform.openai.com/docs/api-reference/chat/create # https://platform.openai.com/docs/api-reference/chat/create
messages: List[Dict[str, str]] messages: List[ChatCompletionMessageParam]
model: str model: str
frequency_penalty: Optional[float] = 0.0 frequency_penalty: Optional[float] = 0.0
logit_bias: Optional[Dict[str, float]] = None logit_bias: Optional[Dict[str, float]] = None
...@@ -73,7 +79,9 @@ class ChatCompletionRequest(BaseModel): ...@@ -73,7 +79,9 @@ class ChatCompletionRequest(BaseModel):
n: Optional[int] = 1 n: Optional[int] = 1
presence_penalty: Optional[float] = 0.0 presence_penalty: Optional[float] = 0.0
response_format: Optional[ResponseFormat] = None response_format: Optional[ResponseFormat] = None
seed: Optional[int] = None seed: Optional[int] = Field(None,
ge=torch.iinfo(torch.long).min,
le=torch.iinfo(torch.long).max)
stop: Optional[Union[str, List[str]]] = Field(default_factory=list) stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
stream: Optional[bool] = False stream: Optional[bool] = False
temperature: Optional[float] = 0.7 temperature: Optional[float] = 0.7
...@@ -140,6 +148,11 @@ class ChatCompletionRequest(BaseModel): ...@@ -140,6 +148,11 @@ class ChatCompletionRequest(BaseModel):
"If specified, will override the default guided decoding backend " "If specified, will override the default guided decoding backend "
"of the server for this specific request. If set, must be either " "of the server for this specific request. If set, must be either "
"'outlines' / 'lm-format-enforcer'")) "'outlines' / 'lm-format-enforcer'"))
guided_whitespace_pattern: Optional[str] = Field(
default=None,
description=(
"If specified, will override the default whitespace pattern "
"for guided json decoding."))
# doc: end-chat-completion-extra-params # doc: end-chat-completion-extra-params
...@@ -204,7 +217,7 @@ class ChatCompletionRequest(BaseModel): ...@@ -204,7 +217,7 @@ class ChatCompletionRequest(BaseModel):
return data return data
class CompletionRequest(BaseModel): class CompletionRequest(OpenAIBaseModel):
# Ordered by official OpenAI API documentation # Ordered by official OpenAI API documentation
# https://platform.openai.com/docs/api-reference/completions/create # https://platform.openai.com/docs/api-reference/completions/create
model: str model: str
...@@ -217,7 +230,9 @@ class CompletionRequest(BaseModel): ...@@ -217,7 +230,9 @@ class CompletionRequest(BaseModel):
max_tokens: Optional[int] = 16 max_tokens: Optional[int] = 16
n: int = 1 n: int = 1
presence_penalty: Optional[float] = 0.0 presence_penalty: Optional[float] = 0.0
seed: Optional[int] = None seed: Optional[int] = Field(None,
ge=torch.iinfo(torch.long).min,
le=torch.iinfo(torch.long).max)
stop: Optional[Union[str, List[str]]] = Field(default_factory=list) stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
stream: Optional[bool] = False stream: Optional[bool] = False
suffix: Optional[str] = None suffix: Optional[str] = None
...@@ -279,6 +294,11 @@ class CompletionRequest(BaseModel): ...@@ -279,6 +294,11 @@ class CompletionRequest(BaseModel):
"If specified, will override the default guided decoding backend " "If specified, will override the default guided decoding backend "
"of the server for this specific request. If set, must be one of " "of the server for this specific request. If set, must be one of "
"'outlines' / 'lm-format-enforcer'")) "'outlines' / 'lm-format-enforcer'"))
guided_whitespace_pattern: Optional[str] = Field(
default=None,
description=(
"If specified, will override the default whitespace pattern "
"for guided json decoding."))
# doc: end-completion-extra-params # doc: end-completion-extra-params
...@@ -343,19 +363,19 @@ class CompletionRequest(BaseModel): ...@@ -343,19 +363,19 @@ class CompletionRequest(BaseModel):
return data return data
class LogProbs(BaseModel): class LogProbs(OpenAIBaseModel):
text_offset: List[int] = Field(default_factory=list) text_offset: List[int] = Field(default_factory=list)
token_logprobs: List[Optional[float]] = Field(default_factory=list) token_logprobs: List[Optional[float]] = Field(default_factory=list)
tokens: List[str] = Field(default_factory=list) tokens: List[str] = Field(default_factory=list)
top_logprobs: Optional[List[Optional[Dict[str, float]]]] = None top_logprobs: Optional[List[Optional[Dict[str, float]]]] = None
class CompletionResponseChoice(BaseModel): class CompletionResponseChoice(OpenAIBaseModel):
index: int index: int
text: str text: str
logprobs: Optional[LogProbs] = None logprobs: Optional[LogProbs] = None
finish_reason: Optional[Literal["stop", "length"]] = None finish_reason: Optional[str] = None
stop_reason: Union[None, int, str] = Field( stop_reason: Optional[Union[int, str]] = Field(
default=None, default=None,
description=( description=(
"The stop string or token id that caused the completion " "The stop string or token id that caused the completion "
...@@ -364,7 +384,7 @@ class CompletionResponseChoice(BaseModel): ...@@ -364,7 +384,7 @@ class CompletionResponseChoice(BaseModel):
) )
class CompletionResponse(BaseModel): class CompletionResponse(OpenAIBaseModel):
id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}") id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
object: str = "text_completion" object: str = "text_completion"
created: int = Field(default_factory=lambda: int(time.time())) created: int = Field(default_factory=lambda: int(time.time()))
...@@ -373,12 +393,12 @@ class CompletionResponse(BaseModel): ...@@ -373,12 +393,12 @@ class CompletionResponse(BaseModel):
usage: UsageInfo usage: UsageInfo
class CompletionResponseStreamChoice(BaseModel): class CompletionResponseStreamChoice(OpenAIBaseModel):
index: int index: int
text: str text: str
logprobs: Optional[LogProbs] = None logprobs: Optional[LogProbs] = None
finish_reason: Optional[Literal["stop", "length"]] = None finish_reason: Optional[str] = None
stop_reason: Union[None, int, str] = Field( stop_reason: Optional[Union[int, str]] = Field(
default=None, default=None,
description=( description=(
"The stop string or token id that caused the completion " "The stop string or token id that caused the completion "
...@@ -387,7 +407,7 @@ class CompletionResponseStreamChoice(BaseModel): ...@@ -387,7 +407,7 @@ class CompletionResponseStreamChoice(BaseModel):
) )
class CompletionStreamResponse(BaseModel): class CompletionStreamResponse(OpenAIBaseModel):
id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}") id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
object: str = "text_completion" object: str = "text_completion"
created: int = Field(default_factory=lambda: int(time.time())) created: int = Field(default_factory=lambda: int(time.time()))
...@@ -396,20 +416,20 @@ class CompletionStreamResponse(BaseModel): ...@@ -396,20 +416,20 @@ class CompletionStreamResponse(BaseModel):
usage: Optional[UsageInfo] = Field(default=None) usage: Optional[UsageInfo] = Field(default=None)
class ChatMessage(BaseModel): class ChatMessage(OpenAIBaseModel):
role: str role: str
content: str content: str
class ChatCompletionResponseChoice(BaseModel): class ChatCompletionResponseChoice(OpenAIBaseModel):
index: int index: int
message: ChatMessage message: ChatMessage
logprobs: Optional[LogProbs] = None logprobs: Optional[LogProbs] = None
finish_reason: Optional[Literal["stop", "length"]] = None finish_reason: Optional[str] = None
stop_reason: Union[None, int, str] = None stop_reason: Optional[Union[int, str]] = None
class ChatCompletionResponse(BaseModel): class ChatCompletionResponse(OpenAIBaseModel):
id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}") id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
object: str = "chat.completion" object: str = "chat.completion"
created: int = Field(default_factory=lambda: int(time.time())) created: int = Field(default_factory=lambda: int(time.time()))
...@@ -418,20 +438,20 @@ class ChatCompletionResponse(BaseModel): ...@@ -418,20 +438,20 @@ class ChatCompletionResponse(BaseModel):
usage: UsageInfo usage: UsageInfo
class DeltaMessage(BaseModel): class DeltaMessage(OpenAIBaseModel):
role: Optional[str] = None role: Optional[str] = None
content: Optional[str] = None content: Optional[str] = None
class ChatCompletionResponseStreamChoice(BaseModel): class ChatCompletionResponseStreamChoice(OpenAIBaseModel):
index: int index: int
delta: DeltaMessage delta: DeltaMessage
logprobs: Optional[LogProbs] = None logprobs: Optional[LogProbs] = None
finish_reason: Optional[Literal["stop", "length"]] = None finish_reason: Optional[str] = None
stop_reason: Union[None, int, str] = None stop_reason: Optional[Union[int, str]] = None
class ChatCompletionStreamResponse(BaseModel): class ChatCompletionStreamResponse(OpenAIBaseModel):
id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}") id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
object: str = "chat.completion.chunk" object: str = "chat.completion.chunk"
created: int = Field(default_factory=lambda: int(time.time())) created: int = Field(default_factory=lambda: int(time.time()))
......
import asyncio
import codecs import codecs
import time import time
from typing import AsyncGenerator, AsyncIterator, List, Optional, Union from typing import (AsyncGenerator, AsyncIterator, Awaitable, Iterable, List,
Optional, Tuple, TypedDict, Union, final)
from fastapi import Request from fastapi import Request
from openai.types.chat import (ChatCompletionContentPartParam,
ChatCompletionRole)
from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.protocol import ( from vllm.entrypoints.openai.protocol import (
...@@ -10,7 +14,8 @@ from vllm.entrypoints.openai.protocol import ( ...@@ -10,7 +14,8 @@ from vllm.entrypoints.openai.protocol import (
ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice, ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse, ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse,
UsageInfo) UsageInfo)
from vllm.entrypoints.openai.serving_engine import LoRA, OpenAIServing from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
OpenAIServing)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.guided_decoding import ( from vllm.model_executor.guided_decoding import (
get_guided_decoding_logits_processor) get_guided_decoding_logits_processor)
...@@ -20,19 +25,49 @@ from vllm.utils import random_uuid ...@@ -20,19 +25,49 @@ from vllm.utils import random_uuid
logger = init_logger(__name__) logger = init_logger(__name__)
@final  # Closed to subclassing so it stays compatible with Dict[str, str]
class ConversationMessage(TypedDict):
    """A single chat message reduced to plain text.

    Both fields are strings, so an instance is an ordinary ``dict`` mapping
    ``str`` keys to ``str`` values.
    """

    # Name of the speaker role for this message.
    role: str
    # Text of the message.
    content: str
class OpenAIServingChat(OpenAIServing): class OpenAIServingChat(OpenAIServing):
def __init__(self, def __init__(self,
engine: AsyncLLMEngine, engine: AsyncLLMEngine,
served_model_names: List[str], served_model_names: List[str],
response_role: str, response_role: str,
lora_modules: Optional[List[LoRA]] = None, lora_modules: Optional[List[LoRAModulePath]] = None,
chat_template=None): chat_template: Optional[str] = None):
super().__init__(engine=engine, super().__init__(engine=engine,
served_model_names=served_model_names, served_model_names=served_model_names,
lora_modules=lora_modules) lora_modules=lora_modules,
await_post_init=self._load_chat_template(
chat_template=chat_template))
self.response_role = response_role self.response_role = response_role
self._load_chat_template(chat_template)
def _parse_chat_message_content(
    self,
    role: ChatCompletionRole,
    content: Optional[Union[str,
                            Iterable[ChatCompletionContentPartParam]]],
) -> Tuple[List[ConversationMessage], List[Awaitable[object]]]:
    """Flatten one chat message's content into text-only conversation entries.

    Args:
        role: Role to attach to the produced message.
        content: ``None``, a plain string, or an iterable of content parts
            (mappings carrying a ``"type"`` key; ``"text"`` parts also
            carry a ``"text"`` key).

    Returns:
        A pair ``(messages, awaitables)``. ``messages`` is empty when
        ``content`` is ``None``; otherwise it holds exactly one
        ``ConversationMessage``. For multi-part content the text parts are
        joined with newlines. ``awaitables`` is always an empty list here.

    Raises:
        NotImplementedError: If a content part has a type other than
            ``"text"``.
    """
    if content is None:
        return [], []
    if isinstance(content, str):
        return [ConversationMessage(role=role, content=content)], []

    texts: List[str] = []
    # Iterate the parts directly; the position of a part is never needed.
    for part in content:
        if part["type"] != "text":
            raise NotImplementedError(f"Unknown part type: {part['type']}")
        texts.append(part["text"])

    return [ConversationMessage(role=role, content="\n".join(texts))], []
async def create_chat_completion( async def create_chat_completion(
self, request: ChatCompletionRequest, raw_request: Request self, request: ChatCompletionRequest, raw_request: Request
...@@ -52,13 +87,21 @@ class OpenAIServingChat(OpenAIServing): ...@@ -52,13 +87,21 @@ class OpenAIServingChat(OpenAIServing):
return error_check_ret return error_check_ret
try: try:
conversation: List[ConversationMessage] = []
for m in request.messages:
messages, _ = self._parse_chat_message_content(
m["role"], m["content"])
conversation.extend(messages)
prompt = self.tokenizer.apply_chat_template( prompt = self.tokenizer.apply_chat_template(
conversation=request.messages, conversation=conversation,
tokenize=False, tokenize=False,
add_generation_prompt=request.add_generation_prompt) add_generation_prompt=request.add_generation_prompt,
)
except Exception as e: except Exception as e:
logger.error( logger.error("Error in applying chat template from request: %s", e)
f"Error in applying chat template from request: {str(e)}")
return self.create_error_response(str(e)) return self.create_error_response(str(e))
request_id = f"cmpl-{random_uuid()}" request_id = f"cmpl-{random_uuid()}"
...@@ -68,7 +111,7 @@ class OpenAIServingChat(OpenAIServing): ...@@ -68,7 +111,7 @@ class OpenAIServingChat(OpenAIServing):
request, prompt=prompt) request, prompt=prompt)
sampling_params = request.to_sampling_params() sampling_params = request.to_sampling_params()
lora_request = self._maybe_get_lora(request) lora_request = self._maybe_get_lora(request)
decoding_config = self.engine.engine.decoding_config decoding_config = await self.engine.get_decoding_config()
guided_decoding_backend = request.guided_decoding_backend \ guided_decoding_backend = request.guided_decoding_backend \
or decoding_config.guided_decoding_backend or decoding_config.guided_decoding_backend
guided_decode_logits_processor = ( guided_decode_logits_processor = (
...@@ -89,11 +132,12 @@ class OpenAIServingChat(OpenAIServing): ...@@ -89,11 +132,12 @@ class OpenAIServingChat(OpenAIServing):
# Streaming response # Streaming response
if request.stream: if request.stream:
return self.chat_completion_stream_generator( return self.chat_completion_stream_generator(
request, result_generator, request_id) request, result_generator, request_id, conversation)
else: else:
try: try:
return await self.chat_completion_full_generator( return await self.chat_completion_full_generator(
request, raw_request, result_generator, request_id) request, raw_request, result_generator, request_id,
conversation)
except ValueError as e: except ValueError as e:
# TODO: Use a vllm-specific Validation Error # TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e)) return self.create_error_response(str(e))
...@@ -106,9 +150,9 @@ class OpenAIServingChat(OpenAIServing): ...@@ -106,9 +150,9 @@ class OpenAIServingChat(OpenAIServing):
async def chat_completion_stream_generator( async def chat_completion_stream_generator(
self, request: ChatCompletionRequest, self, request: ChatCompletionRequest,
result_generator: AsyncIterator[RequestOutput], request_id: str result_generator: AsyncIterator[RequestOutput], request_id: str,
) -> Union[ErrorResponse, AsyncGenerator[str, None]]: conversation: List[ConversationMessage]
) -> AsyncGenerator[str, None]:
model_name = self.served_model_names[0] model_name = self.served_model_names[0]
created_time = int(time.time()) created_time = int(time.time())
chunk_object_type = "chat.completion.chunk" chunk_object_type = "chat.completion.chunk"
...@@ -147,12 +191,10 @@ class OpenAIServingChat(OpenAIServing): ...@@ -147,12 +191,10 @@ class OpenAIServingChat(OpenAIServing):
# last message # last message
if request.echo: if request.echo:
last_msg_content = "" last_msg_content = ""
if request.messages and isinstance( if conversation and conversation[-1].get(
request.messages, "content") and conversation[-1].get(
list) and request.messages[-1].get( "role") == role:
"content") and request.messages[-1].get( last_msg_content = conversation[-1]["content"]
"role") == role:
last_msg_content = request.messages[-1]["content"]
if last_msg_content: if last_msg_content:
for i in range(request.n): for i in range(request.n):
...@@ -247,13 +289,14 @@ class OpenAIServingChat(OpenAIServing): ...@@ -247,13 +289,14 @@ class OpenAIServingChat(OpenAIServing):
yield "data: [DONE]\n\n" yield "data: [DONE]\n\n"
async def chat_completion_full_generator( async def chat_completion_full_generator(
self, request: ChatCompletionRequest, raw_request: Request, self, request: ChatCompletionRequest, raw_request: Request,
result_generator: AsyncIterator[RequestOutput], result_generator: AsyncIterator[RequestOutput], request_id: str,
request_id: str) -> Union[ErrorResponse, ChatCompletionResponse]: conversation: List[ConversationMessage]
) -> Union[ErrorResponse, ChatCompletionResponse]:
model_name = self.served_model_names[0] model_name = self.served_model_names[0]
created_time = int(time.time()) created_time = int(time.time())
final_res: RequestOutput = None final_res: Optional[RequestOutput] = None
async for res in result_generator: async for res in result_generator:
if await raw_request.is_disconnected(): if await raw_request.is_disconnected():
...@@ -290,11 +333,9 @@ class OpenAIServingChat(OpenAIServing): ...@@ -290,11 +333,9 @@ class OpenAIServingChat(OpenAIServing):
if request.echo: if request.echo:
last_msg_content = "" last_msg_content = ""
if request.messages and isinstance( if conversation and conversation[-1].get(
request.messages, list) and request.messages[-1].get( "content") and conversation[-1].get("role") == role:
"content") and request.messages[-1].get( last_msg_content = conversation[-1]["content"]
"role") == role:
last_msg_content = request.messages[-1]["content"]
for choice in choices: for choice in choices:
full_message = last_msg_content + choice.message.content full_message = last_msg_content + choice.message.content
...@@ -318,7 +359,10 @@ class OpenAIServingChat(OpenAIServing): ...@@ -318,7 +359,10 @@ class OpenAIServingChat(OpenAIServing):
return response return response
def _load_chat_template(self, chat_template): async def _load_chat_template(self, chat_template: Optional[str]):
while self.tokenizer is None:
# Give the parent class time to load the tokenizer
await asyncio.sleep(0.1)
tokenizer = self.tokenizer tokenizer = self.tokenizer
if chat_template is not None: if chat_template is not None:
...@@ -338,11 +382,11 @@ class OpenAIServingChat(OpenAIServing): ...@@ -338,11 +382,11 @@ class OpenAIServingChat(OpenAIServing):
tokenizer.chat_template = codecs.decode( tokenizer.chat_template = codecs.decode(
chat_template, "unicode_escape") chat_template, "unicode_escape")
logger.info( logger.info("Using supplied chat template:\n%s",
f"Using supplied chat template:\n{tokenizer.chat_template}") tokenizer.chat_template)
elif tokenizer.chat_template is not None: elif tokenizer.chat_template is not None:
logger.info( logger.info("Using default chat template:\n%s",
f"Using default chat template:\n{tokenizer.chat_template}") tokenizer.chat_template)
else: else:
logger.warning( logger.warning(
"No chat template provided. Chat API will not work.") "No chat template provided. Chat API will not work.")
...@@ -11,7 +11,8 @@ from vllm.entrypoints.openai.protocol import (CompletionRequest, ...@@ -11,7 +11,8 @@ from vllm.entrypoints.openai.protocol import (CompletionRequest,
CompletionResponseStreamChoice, CompletionResponseStreamChoice,
CompletionStreamResponse, CompletionStreamResponse,
LogProbs, UsageInfo) LogProbs, UsageInfo)
from vllm.entrypoints.openai.serving_engine import LoRA, OpenAIServing from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
OpenAIServing)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.guided_decoding import ( from vllm.model_executor.guided_decoding import (
get_guided_decoding_logits_processor) get_guided_decoding_logits_processor)
...@@ -54,7 +55,7 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -54,7 +55,7 @@ class OpenAIServingCompletion(OpenAIServing):
def __init__(self, def __init__(self,
engine: AsyncLLMEngine, engine: AsyncLLMEngine,
served_model_names: List[str], served_model_names: List[str],
lora_modules: Optional[List[LoRA]] = None): lora_modules: Optional[List[LoRAModulePath]] = None):
super().__init__(engine=engine, super().__init__(engine=engine,
served_model_names=served_model_names, served_model_names=served_model_names,
lora_modules=lora_modules) lora_modules=lora_modules)
...@@ -84,11 +85,11 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -84,11 +85,11 @@ class OpenAIServingCompletion(OpenAIServing):
created_time = int(time.time()) created_time = int(time.time())
# Schedule the request and get the result generator. # Schedule the request and get the result generator.
generators = [] generators: List[AsyncIterator[RequestOutput]] = []
try: try:
sampling_params = request.to_sampling_params() sampling_params = request.to_sampling_params()
lora_request = self._maybe_get_lora(request) lora_request = self._maybe_get_lora(request)
decoding_config = self.engine.engine.decoding_config decoding_config = await self.engine.get_decoding_config()
guided_decoding_backend = request.guided_decoding_backend \ guided_decoding_backend = request.guided_decoding_backend \
or decoding_config.guided_decoding_backend or decoding_config.guided_decoding_backend
guided_decode_logit_processor = ( guided_decode_logit_processor = (
...@@ -148,7 +149,7 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -148,7 +149,7 @@ class OpenAIServingCompletion(OpenAIServing):
num_prompts=len(prompts)) num_prompts=len(prompts))
# Non-streaming response # Non-streaming response
final_res_batch: RequestOutput = [None] * len(prompts) final_res_batch: List[Optional[RequestOutput]] = [None] * len(prompts)
try: try:
async for i, res in result_generator: async for i, res in result_generator:
if await raw_request.is_disconnected(): if await raw_request.is_disconnected():
......
...@@ -2,7 +2,7 @@ import asyncio ...@@ -2,7 +2,7 @@ import asyncio
import json import json
from dataclasses import dataclass from dataclasses import dataclass
from http import HTTPStatus from http import HTTPStatus
from typing import Dict, List, Optional, Tuple, Union from typing import Any, Awaitable, Dict, List, Optional, Tuple, Union
from pydantic import Field from pydantic import Field
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
...@@ -22,7 +22,7 @@ logger = init_logger(__name__) ...@@ -22,7 +22,7 @@ logger = init_logger(__name__)
@dataclass @dataclass
class LoRA: class LoRAModulePath:
name: str name: str
local_path: str local_path: str
...@@ -32,7 +32,8 @@ class OpenAIServing: ...@@ -32,7 +32,8 @@ class OpenAIServing:
def __init__(self, def __init__(self,
engine: AsyncLLMEngine, engine: AsyncLLMEngine,
served_model_names: List[str], served_model_names: List[str],
lora_modules=Optional[List[LoRA]]): lora_modules: Optional[List[LoRAModulePath]],
await_post_init: Optional[Awaitable[Any]] = None):
self.engine = engine self.engine = engine
self.served_model_names = served_model_names self.served_model_names = served_model_names
if lora_modules is None: if lora_modules is None:
...@@ -58,12 +59,12 @@ class OpenAIServing: ...@@ -58,12 +59,12 @@ class OpenAIServing:
if event_loop is not None and event_loop.is_running(): if event_loop is not None and event_loop.is_running():
# If the current is instanced by Ray Serve, # If the current is instanced by Ray Serve,
# there is already a running event loop # there is already a running event loop
event_loop.create_task(self._post_init()) event_loop.create_task(self._post_init(await_post_init))
else: else:
# When using single vLLM without engine_use_ray # When using single vLLM without engine_use_ray
asyncio.run(self._post_init()) asyncio.run(self._post_init(await_post_init))
async def _post_init(self): async def _post_init(self, await_post_init):
engine_model_config = await self.engine.get_model_config() engine_model_config = await self.engine.get_model_config()
self.max_model_len = engine_model_config.max_model_len self.max_model_len = engine_model_config.max_model_len
...@@ -75,6 +76,9 @@ class OpenAIServing: ...@@ -75,6 +76,9 @@ class OpenAIServing:
trust_remote_code=engine_model_config.trust_remote_code, trust_remote_code=engine_model_config.trust_remote_code,
truncation_side="left") truncation_side="left")
if await_post_init is not None:
await await_post_init
async def show_available_models(self) -> ModelList: async def show_available_models(self) -> ModelList:
"""Show available models. Right now we only have one model.""" """Show available models. Right now we only have one model."""
model_cards = [ model_cards = [
...@@ -158,7 +162,9 @@ class OpenAIServing: ...@@ -158,7 +162,9 @@ class OpenAIServing:
}) })
return json_str return json_str
async def _check_model(self, request) -> Optional[ErrorResponse]: async def _check_model(
self, request: Union[CompletionRequest, ChatCompletionRequest]
) -> Optional[ErrorResponse]:
if request.model in self.served_model_names: if request.model in self.served_model_names:
return None return None
if request.model in [lora.lora_name for lora in self.lora_requests]: if request.model in [lora.lora_name for lora in self.lora_requests]:
...@@ -168,14 +174,16 @@ class OpenAIServing: ...@@ -168,14 +174,16 @@ class OpenAIServing:
err_type="NotFoundError", err_type="NotFoundError",
status_code=HTTPStatus.NOT_FOUND) status_code=HTTPStatus.NOT_FOUND)
def _maybe_get_lora(self, request) -> Optional[LoRARequest]: def _maybe_get_lora(
self, request: Union[CompletionRequest, ChatCompletionRequest]
) -> Optional[LoRARequest]:
if request.model in self.served_model_names: if request.model in self.served_model_names:
return None return None
for lora in self.lora_requests: for lora in self.lora_requests:
if request.model == lora.lora_name: if request.model == lora.lora_name:
return lora return lora
# if _check_model has been called earlier, this will be unreachable # if _check_model has been called earlier, this will be unreachable
raise ValueError("The model `{request.model}` does not exist.") raise ValueError(f"The model `{request.model}` does not exist.")
def _validate_prompt_and_tokenize( def _validate_prompt_and_tokenize(
self, self,
......
import os
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional
# Static declarations for the lazily-resolved module attributes below.
# `TYPE_CHECKING` is False at runtime, so nothing in this block is executed:
# the annotations only tell type checkers / IDEs which attributes exist and
# what type each one has. The assigned values are placeholders and are not
# guaranteed to equal the runtime defaults in `environment_variables`.
if TYPE_CHECKING:
    VLLM_HOST_IP: str = ""
    VLLM_USE_MODELSCOPE: bool = False
    VLLM_INSTANCE_ID: Optional[str] = None
    # Served by `environment_variables` as well, so it needs a declaration
    # here for `envs.CUDA_HOME` to type-check.
    CUDA_HOME: Optional[str] = None
    VLLM_NCCL_SO_PATH: Optional[str] = None
    LD_LIBRARY_PATH: Optional[str] = None
    VLLM_USE_TRITON_FLASH_ATTN: bool = False
    LOCAL_RANK: int = 0
    CUDA_VISIBLE_DEVICES: Optional[str] = None
    VLLM_ENGINE_ITERATION_TIMEOUT_S: int = 60
    VLLM_API_KEY: Optional[str] = None
    S3_ACCESS_KEY_ID: Optional[str] = None
    S3_SECRET_ACCESS_KEY: Optional[str] = None
    S3_ENDPOINT_URL: Optional[str] = None
    VLLM_CONFIG_ROOT: str = ""
    VLLM_USAGE_STATS_SERVER: str = "https://stats.vllm.ai"
    VLLM_NO_USAGE_STATS: bool = False
    VLLM_DO_NOT_TRACK: bool = False
    VLLM_USAGE_SOURCE: str = ""
    VLLM_CONFIGURE_LOGGING: int = 1
    VLLM_LOGGING_CONFIG_PATH: Optional[str] = None
    VLLM_TRACE_FUNCTION: int = 0
    VLLM_ATTENTION_BACKEND: Optional[str] = None
    VLLM_CPU_KVCACHE_SPACE: int = 0
    VLLM_USE_RAY_COMPILED_DAG: bool = False
    VLLM_WORKER_MULTIPROC_METHOD: str = "spawn"
    VLLM_TARGET_DEVICE: str = "cuda"
    MAX_JOBS: Optional[str] = None
    NVCC_THREADS: Optional[str] = None
    VLLM_BUILD_WITH_NEURON: bool = False
    VLLM_USE_PRECOMPILED: bool = False
    VLLM_INSTALL_PUNICA_KERNELS: bool = False
    CMAKE_BUILD_TYPE: Optional[str] = None
    VERBOSE: bool = False
# The begin-* and end-* markers here are used by the documentation generator
# to extract the used env vars.
# begin-env-vars-definition
# Maps each supported variable name to a zero-argument getter. The getter
# reads `os.environ` every time it is called, so values are never cached and
# always reflect the current process environment.
environment_variables: Dict[str, Callable[[], Any]] = {

    # ================== Installation Time Env Vars ==================

    # Target device of vLLM, supporting [cuda (by default), rocm, neuron, cpu]
    "VLLM_TARGET_DEVICE":
    lambda: os.getenv("VLLM_TARGET_DEVICE", "cuda"),

    # Maximum number of compilation jobs to run in parallel.
    # By default this is the number of CPUs
    "MAX_JOBS":
    lambda: os.getenv("MAX_JOBS", None),

    # Number of threads to use for nvcc
    # By default this is 1.
    # If set, `MAX_JOBS` will be reduced to avoid oversubscribing the CPU.
    "NVCC_THREADS":
    lambda: os.getenv("NVCC_THREADS", None),

    # If set, vllm will build with Neuron support
    # NOTE: any non-empty value enables this, including "0" and "false".
    "VLLM_BUILD_WITH_NEURON":
    lambda: bool(os.environ.get("VLLM_BUILD_WITH_NEURON", False)),

    # If set, vllm will use precompiled binaries (*.so)
    # NOTE: any non-empty value enables this, including "0" and "false".
    "VLLM_USE_PRECOMPILED":
    lambda: bool(os.environ.get("VLLM_USE_PRECOMPILED")),

    # If set, vllm will install Punica kernels
    "VLLM_INSTALL_PUNICA_KERNELS":
    lambda: bool(int(os.getenv("VLLM_INSTALL_PUNICA_KERNELS", "0"))),

    # CMake build type
    # If not set, defaults to "Debug" or "RelWithDebInfo"
    # Available options: "Debug", "Release", "RelWithDebInfo"
    "CMAKE_BUILD_TYPE":
    lambda: os.getenv("CMAKE_BUILD_TYPE"),

    # If set, vllm will print verbose logs during installation
    "VERBOSE":
    lambda: bool(int(os.getenv('VERBOSE', '0'))),

    # Root directory for VLLM configuration files
    # Note that this not only affects how vllm finds its configuration files
    # during runtime, but also affects how vllm installs its configuration
    # files during **installation**.
    "VLLM_CONFIG_ROOT":
    lambda: os.environ.get("VLLM_CONFIG_ROOT", None) or os.getenv(
        "XDG_CONFIG_HOME", None) or os.path.expanduser("~/.config"),

    # ================== Runtime Env Vars ==================

    # used in distributed environment to determine the master address
    'VLLM_HOST_IP':
    lambda: os.getenv('VLLM_HOST_IP', "") or os.getenv("HOST_IP", ""),

    # If true, will load models from ModelScope instead of Hugging Face Hub.
    # note that the value is true or false, not numbers
    "VLLM_USE_MODELSCOPE":
    lambda: os.environ.get("VLLM_USE_MODELSCOPE", "False").lower() == "true",

    # Instance id represents an instance of the VLLM. All processes in the same
    # instance should have the same instance id.
    "VLLM_INSTANCE_ID":
    lambda: os.environ.get("VLLM_INSTANCE_ID", None),

    # path to cudatoolkit home directory, under which should be bin, include,
    # and lib directories.
    "CUDA_HOME":
    lambda: os.environ.get("CUDA_HOME", None),

    # Path to the NCCL library file. It is needed because nccl>=2.19 brought
    # by PyTorch contains a bug: https://github.com/NVIDIA/nccl/issues/1234
    "VLLM_NCCL_SO_PATH":
    lambda: os.environ.get("VLLM_NCCL_SO_PATH", None),

    # when `VLLM_NCCL_SO_PATH` is not set, vllm will try to find the nccl
    # library file in the locations specified by `LD_LIBRARY_PATH`
    "LD_LIBRARY_PATH":
    lambda: os.environ.get("LD_LIBRARY_PATH", None),

    # flag to control if vllm should use triton flash attention
    "VLLM_USE_TRITON_FLASH_ATTN":
    lambda: (os.environ.get("VLLM_USE_TRITON_FLASH_ATTN", "True").lower() in
             ("true", "1")),

    # local rank of the process in the distributed setting, used to determine
    # the GPU device id
    "LOCAL_RANK":
    lambda: int(os.environ.get("LOCAL_RANK", "0")),

    # used to control the visible devices in the distributed setting
    "CUDA_VISIBLE_DEVICES":
    lambda: os.environ.get("CUDA_VISIBLE_DEVICES", None),

    # timeout for each iteration in the engine
    "VLLM_ENGINE_ITERATION_TIMEOUT_S":
    lambda: int(os.environ.get("VLLM_ENGINE_ITERATION_TIMEOUT_S", "60")),

    # API key for VLLM API server
    "VLLM_API_KEY":
    lambda: os.environ.get("VLLM_API_KEY", None),

    # S3 access information, used for tensorizer to load model from S3
    # The variable is looked up under its own name first. `S3_ACCESS_KEY` is
    # the name this getter used to read; it is kept as a fallback so that
    # environments configured against the old behaviour keep working.
    "S3_ACCESS_KEY_ID":
    lambda: (os.environ.get("S3_ACCESS_KEY_ID", None) or os.environ.get(
        "S3_ACCESS_KEY", None)),
    "S3_SECRET_ACCESS_KEY":
    lambda: os.environ.get("S3_SECRET_ACCESS_KEY", None),
    "S3_ENDPOINT_URL":
    lambda: os.environ.get("S3_ENDPOINT_URL", None),

    # Usage stats collection
    "VLLM_USAGE_STATS_SERVER":
    lambda: os.environ.get("VLLM_USAGE_STATS_SERVER", "https://stats.vllm.ai"),
    "VLLM_NO_USAGE_STATS":
    lambda: os.environ.get("VLLM_NO_USAGE_STATS", "0") == "1",
    "VLLM_DO_NOT_TRACK":
    lambda: (os.environ.get("VLLM_DO_NOT_TRACK", None) or os.environ.get(
        "DO_NOT_TRACK", None) or "0") == "1",
    "VLLM_USAGE_SOURCE":
    lambda: os.environ.get("VLLM_USAGE_SOURCE", "production"),

    # Logging configuration
    # If set to 0, vllm will not configure logging
    # If set to 1, vllm will configure logging using the default configuration
    #    or the configuration file specified by VLLM_LOGGING_CONFIG_PATH
    "VLLM_CONFIGURE_LOGGING":
    lambda: int(os.getenv("VLLM_CONFIGURE_LOGGING", "1")),
    "VLLM_LOGGING_CONFIG_PATH":
    lambda: os.getenv("VLLM_LOGGING_CONFIG_PATH"),

    # Trace function calls
    # If set to 1, vllm will trace function calls
    # Useful for debugging
    "VLLM_TRACE_FUNCTION":
    lambda: int(os.getenv("VLLM_TRACE_FUNCTION", "0")),

    # Backend for attention computation
    # Available options:
    # - "TORCH_SDPA": use torch.nn.MultiheadAttention
    # - "FLASH_ATTN": use FlashAttention
    # - "XFORMERS": use XFormers
    # - "ROCM_FLASH": use ROCmFlashAttention
    "VLLM_ATTENTION_BACKEND":
    lambda: os.getenv("VLLM_ATTENTION_BACKEND", None),

    # CPU key-value cache space
    # This getter returns 0 when the variable is unset.
    # NOTE(review): the "default is 4GB" behaviour is presumably applied by
    # the consumer when it sees 0 -- confirm against the CPU executor.
    "VLLM_CPU_KVCACHE_SPACE":
    lambda: int(os.getenv("VLLM_CPU_KVCACHE_SPACE", "0")),

    # If the env var is set, it uses the Ray's compiled DAG API
    # which optimizes the control plane overhead.
    # Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it.
    # NOTE: any non-empty value enables this, including "0".
    "VLLM_USE_RAY_COMPILED_DAG":
    lambda: bool(os.getenv("VLLM_USE_RAY_COMPILED_DAG", 0)),

    # Use dedicated multiprocess context for workers.
    # Both spawn and fork work
    "VLLM_WORKER_MULTIPROC_METHOD":
    lambda: os.getenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn"),
}
# end-env-vars-definition
def __getattr__(name):
    """Resolve a module attribute by evaluating its registered getter.

    Called for attribute names not found on the module through normal
    lookup. The getter is invoked on every access, so the environment is
    read lazily rather than at import time.

    Raises:
        AttributeError: if `name` is not a key of `environment_variables`.
    """
    getter = environment_variables.get(name)
    if getter is None:
        raise AttributeError(
            f"module {__name__!r} has no attribute {name!r}")
    return getter()
def __dir__():
    """List the names of all variables in `environment_variables`."""
    return [*environment_variables]
import os from typing import List, Set, Tuple
from typing import Dict, List, Set, Tuple
import torch import torch
import vllm.envs as envs
from vllm.config import CacheConfig, ModelConfig, SchedulerConfig from vllm.config import CacheConfig, ModelConfig, SchedulerConfig
from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port, from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
make_async) make_async)
...@@ -69,21 +69,13 @@ class CPUExecutor(ExecutorBase): ...@@ -69,21 +69,13 @@ class CPUExecutor(ExecutorBase):
# NOTE: `cpu block` for CPU backend is located on CPU memory but is # NOTE: `cpu block` for CPU backend is located on CPU memory but is
# referred as `gpu block`. Because we want to reuse the existing block # referred as `gpu block`. Because we want to reuse the existing block
# management procedure. # management procedure.
logger.info(f"# CPU blocks: {num_gpu_blocks}") logger.info("# CPU blocks: %d", num_gpu_blocks)
self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks) self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)
def execute_model(self, def execute_model(
seq_group_metadata_list: List[SequenceGroupMetadata], self,
blocks_to_swap_in: Dict[int, int], execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
blocks_to_swap_out: Dict[int, int], output = self.driver_worker.execute_model(execute_model_req)
blocks_to_copy: Dict[int, List[int]],
num_lookahead_slots: int) -> List[SamplerOutput]:
output = self.driver_worker.execute_model(
seq_group_metadata_list=seq_group_metadata_list,
blocks_to_swap_in=blocks_to_swap_in,
blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=blocks_to_copy,
)
return output return output
def add_lora(self, lora_request: LoRARequest) -> bool: def add_lora(self, lora_request: LoRARequest) -> bool:
...@@ -104,17 +96,10 @@ class CPUExecutor(ExecutorBase): ...@@ -104,17 +96,10 @@ class CPUExecutor(ExecutorBase):
class CPUExecutorAsync(CPUExecutor, ExecutorAsyncBase): class CPUExecutorAsync(CPUExecutor, ExecutorAsyncBase):
async def execute_model_async( async def execute_model_async(
self, self,
seq_group_metadata_list: List[SequenceGroupMetadata], execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
blocks_to_swap_in: Dict[int, int], output = await make_async(self.driver_worker.execute_model
blocks_to_swap_out: Dict[int, int], )(execute_model_req=execute_model_req, )
blocks_to_copy: Dict[int, List[int]],
) -> SamplerOutput:
output = await make_async(self.driver_worker.execute_model)(
seq_group_metadata_list=seq_group_metadata_list,
blocks_to_swap_in=blocks_to_swap_in,
blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=blocks_to_copy)
return output return output
async def check_health_async(self) -> None: async def check_health_async(self) -> None:
...@@ -150,8 +135,7 @@ def _verify_and_get_cache_config(config: CacheConfig) -> CacheConfig: ...@@ -150,8 +135,7 @@ def _verify_and_get_cache_config(config: CacheConfig) -> CacheConfig:
logger.warning("Prefix caching is not supported on CPU, disable it.") logger.warning("Prefix caching is not supported on CPU, disable it.")
config.enable_prefix_caching = False config.enable_prefix_caching = False
kv_cache_space_str = os.getenv("VLLM_CPU_KVCACHE_SPACE", "0") kv_cache_space = envs.VLLM_CPU_KVCACHE_SPACE
kv_cache_space = int(kv_cache_space_str)
if kv_cache_space >= 0: if kv_cache_space >= 0:
if kv_cache_space == 0: if kv_cache_space == 0:
......
from abc import abstractmethod
from typing import Any, Dict, List, Optional, Set, Tuple
from vllm.executor.executor_base import ExecutorAsyncBase
from vllm.executor.gpu_executor import GPUExecutor
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.sequence import SamplerOutput
logger = init_logger(__name__)
class DistributedGPUExecutor(GPUExecutor):
    """Abstract base for executors that fan work out to several GPU workers.

    Every operation here is expressed through `_run_workers`, which concrete
    subclasses must implement.
    """

    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Return the KV block counts that every worker can satisfy.

        `determine_num_available_blocks` is run on the workers through
        `_run_workers`, and the per-worker answers are reduced with `min` so
        that the chosen cache sizes fit on all of them.

        Returns:
            - tuple[num_gpu_blocks, num_cpu_blocks]
        """
        per_worker = self._run_workers("determine_num_available_blocks")

        # One centralized controller drives all workers, so only the smallest
        # count reported for each device kind is safe to use everywhere.
        return (
            min(counts[0] for counts in per_worker),
            min(counts[1] for counts in per_worker),
        )

    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        """Store the block counts on the cache config and set up all workers.
        """
        # NOTE: Logging happens here, once, rather than in each worker, to
        # avoid duplicate lines when there is more than one worker. Logging
        # in the engine is not an option since not all executors have GPUs.
        logger.info("# GPU blocks: %d, # CPU blocks: %d", num_gpu_blocks,
                    num_cpu_blocks)

        cache_config = self.cache_config
        cache_config.num_gpu_blocks = num_gpu_blocks
        cache_config.num_cpu_blocks = num_cpu_blocks

        self._run_workers(
            "initialize_cache",
            num_gpu_blocks=num_gpu_blocks,
            num_cpu_blocks=num_cpu_blocks,
        )

    def execute_model(self, *args, **kwargs) -> List[SamplerOutput]:
        """Run `execute_model` through `_run_workers`.

        The positional and keyword arguments are forwarded as `driver_args`
        and `driver_kwargs`.
        """
        worker_outputs = self._run_workers("execute_model",
                                           driver_args=args,
                                           driver_kwargs=kwargs)

        # Sampling results come from the driver worker only, which is the
        # first entry.
        return worker_outputs[0]

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Forward an `add_lora` call for `lora_request` to the workers."""
        assert lora_request.lora_int_id > 0, "lora_id must be greater than 0."
        return self._run_workers("add_lora", lora_request=lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        """Forward a `remove_lora` call for `lora_id` to the workers."""
        assert lora_id > 0, "lora_id must be greater than 0."
        return self._run_workers("remove_lora", lora_id=lora_id)

    def list_loras(self) -> Set[int]:
        """Forward a `list_loras` call to the workers."""
        return self._run_workers("list_loras")

    @abstractmethod
    def _run_workers(
        self,
        method: str,
        *args,
        driver_args: Optional[Tuple[Any, ...]] = None,
        driver_kwargs: Optional[Dict[str, Any]] = None,
        max_concurrent_workers: Optional[int] = None,
        **kwargs,
    ) -> Any:
        """Runs the given method on all workers.

        Must be provided by subclasses; this base version always raises
        `NotImplementedError`.
        """
        raise NotImplementedError
class DistributedGPUExecutorAsync(DistributedGPUExecutor, ExecutorAsyncBase):
    """Async counterpart of `DistributedGPUExecutor`.

    Subclasses implement `_run_workers_async`; `execute_model_async` is
    built on top of it.
    """

    @abstractmethod
    async def _run_workers_async(
        self,
        method: str,
        *args,
        driver_args: Optional[Tuple[Any, ...]] = None,
        driver_kwargs: Optional[Dict[str, Any]] = None,
        **kwargs,
    ) -> Any:
        """Runs the given method on all workers.

        Must be provided by subclasses; this base version always raises
        `NotImplementedError`.
        """
        raise NotImplementedError

    async def execute_model_async(self, *args,
                                  **kwargs) -> List[SamplerOutput]:
        """Await `execute_model` on the workers via `_run_workers_async`.

        The positional and keyword arguments are forwarded as `driver_args`
        and `driver_kwargs`.
        """
        # Sampling results come from the driver worker only, which is the
        # first entry of the awaited result.
        return (await self._run_workers_async("execute_model",
                                              driver_args=args,
                                              driver_kwargs=kwargs))[0]
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Set, Tuple from typing import List, Optional, Set, Tuple
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ParallelConfig, SchedulerConfig, ModelConfig, ParallelConfig, SchedulerConfig,
SpeculativeConfig, VisionLanguageConfig) SpeculativeConfig, VisionLanguageConfig)
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.sequence import ExecuteModelRequest, SamplerOutput
class ExecutorBase(ABC): class ExecutorBase(ABC):
...@@ -68,12 +68,9 @@ class ExecutorBase(ABC): ...@@ -68,12 +68,9 @@ class ExecutorBase(ABC):
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
def execute_model(self, def execute_model(
seq_group_metadata_list: List[SequenceGroupMetadata], self,
blocks_to_swap_in: Dict[int, int], execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
blocks_to_swap_out: Dict[int, int],
blocks_to_copy: Dict[int, List[int]],
num_lookahead_slots: int) -> List[SamplerOutput]:
"""Executes at least one model step on the given sequences.""" """Executes at least one model step on the given sequences."""
raise NotImplementedError raise NotImplementedError
...@@ -95,17 +92,20 @@ class ExecutorBase(ABC): ...@@ -95,17 +92,20 @@ class ExecutorBase(ABC):
exception.""" exception."""
raise NotImplementedError raise NotImplementedError
def shutdown(self) -> None:
    """Shut the executor down.

    This implementation performs no work and returns None.
    """
    return None
def __del__(self):
    # Finalizer hook: delegates to `shutdown()` so that cleanup also runs
    # when the object is finalized without an explicit `shutdown()` call.
    self.shutdown()
class ExecutorAsyncBase(ExecutorBase): class ExecutorAsyncBase(ExecutorBase):
@abstractmethod @abstractmethod
async def execute_model_async( async def execute_model_async(
self, self,
seq_group_metadata_list: List[SequenceGroupMetadata], execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
blocks_to_swap_in: Dict[int, int],
blocks_to_swap_out: Dict[int, int],
blocks_to_copy: Dict[int, List[int]],
) -> SamplerOutput:
"""Executes one model step on the given sequences.""" """Executes one model step on the given sequences."""
raise NotImplementedError raise NotImplementedError
......
from typing import Dict, List, Set, Tuple from typing import Any, Dict, List, Optional, Set, Tuple
from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port, from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
make_async) make_async)
from vllm.worker.worker_base import WorkerWrapperBase
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -23,30 +24,47 @@ class GPUExecutor(ExecutorBase): ...@@ -23,30 +24,47 @@ class GPUExecutor(ExecutorBase):
else: else:
self._init_spec_worker() self._init_spec_worker()
def _init_non_spec_worker(self): def _get_worker_kwargs(
# Lazy import the Worker to avoid importing torch.cuda/xformers self,
# before CUDA_VISIBLE_DEVICES is set in the Worker local_rank: int = 0,
from vllm.worker.worker import Worker rank: int = 0,
distributed_init_method: Optional[str] = None) -> Dict[str, Any]:
assert self.parallel_config.world_size == 1, ( """Return worker init args for a given rank."""
"GPUExecutor only supports single GPU.") if distributed_init_method is None:
distributed_init_method = get_distributed_init_method(
distributed_init_method = get_distributed_init_method( get_ip(), get_open_port())
get_ip(), get_open_port()) return dict(
self.driver_worker = Worker(
model_config=self.model_config, model_config=self.model_config,
parallel_config=self.parallel_config, parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config, scheduler_config=self.scheduler_config,
device_config=self.device_config, device_config=self.device_config,
cache_config=self.cache_config, cache_config=self.cache_config,
load_config=self.load_config, load_config=self.load_config,
local_rank=0, local_rank=local_rank,
rank=0, rank=rank,
distributed_init_method=distributed_init_method, distributed_init_method=distributed_init_method,
lora_config=self.lora_config, lora_config=self.lora_config,
vision_language_config=self.vision_language_config, vision_language_config=self.vision_language_config,
is_driver_worker=True, is_driver_worker=rank == 0,
) )
def _create_worker(self,
                   local_rank: int = 0,
                   rank: int = 0,
                   distributed_init_method: Optional[str] = None):
    """Create a worker for the given rank and return it.

    The worker class is named by strings ("vllm.worker.worker" / "Worker")
    and instantiated through `WorkerWrapperBase`, with constructor arguments
    taken from `_get_worker_kwargs`.

    Returns:
        The `worker` attribute of the wrapper after `init_worker` has run
        (presumably a `vllm.worker.worker.Worker` -- confirm against
        `WorkerWrapperBase`).
    """
    worker_wrapper = WorkerWrapperBase(worker_module_name="vllm.worker.worker",
                                       worker_class_name="Worker")
    init_kwargs = self._get_worker_kwargs(local_rank, rank,
                                          distributed_init_method)
    worker_wrapper.init_worker(**init_kwargs)
    return worker_wrapper.worker
def _init_non_spec_worker(self):
assert self.parallel_config.world_size == 1, (
"GPUExecutor only supports single GPU.")
self.driver_worker = self._create_worker()
self.driver_worker.init_device() self.driver_worker.init_device()
self.driver_worker.load_model() self.driver_worker.load_model()
...@@ -55,46 +73,23 @@ class GPUExecutor(ExecutorBase): ...@@ -55,46 +73,23 @@ class GPUExecutor(ExecutorBase):
""" """
assert self.speculative_config is not None assert self.speculative_config is not None
from vllm.spec_decode.multi_step_worker import MultiStepWorker
from vllm.spec_decode.spec_decode_worker import SpecDecodeWorker from vllm.spec_decode.spec_decode_worker import SpecDecodeWorker
from vllm.worker.worker import Worker
distributed_init_method = get_distributed_init_method( target_worker = self._create_worker()
get_ip(), get_open_port())
target_worker = Worker( draft_worker_kwargs = self._get_worker_kwargs()
model_config=self.model_config, # Override draft-model specific worker args.
parallel_config=self.parallel_config, draft_worker_kwargs.update(
scheduler_config=self.scheduler_config,
device_config=self.device_config,
cache_config=self.cache_config,
load_config=self.load_config,
local_rank=0,
rank=0,
distributed_init_method=distributed_init_method,
lora_config=self.lora_config,
vision_language_config=self.vision_language_config,
is_driver_worker=True,
)
draft_worker = MultiStepWorker(
model_config=self.speculative_config.draft_model_config, model_config=self.speculative_config.draft_model_config,
parallel_config=self.speculative_config.draft_parallel_config, parallel_config=self.speculative_config.draft_parallel_config,
scheduler_config=self.scheduler_config,
device_config=self.device_config,
cache_config=self.cache_config,
# TODO allow draft-model specific load config. # TODO allow draft-model specific load config.
load_config=self.load_config, #load_config=self.load_config,
local_rank=0,
rank=0,
distributed_init_method=distributed_init_method,
lora_config=self.lora_config,
vision_language_config=self.vision_language_config,
is_driver_worker=True,
) )
spec_decode_worker = SpecDecodeWorker.from_workers( spec_decode_worker = SpecDecodeWorker.create_worker(
proposer_worker=draft_worker, scorer_worker=target_worker) scorer_worker=target_worker,
draft_worker_kwargs=draft_worker_kwargs,
)
assert self.parallel_config.world_size == 1, ( assert self.parallel_config.world_size == 1, (
"GPUExecutor only supports single GPU.") "GPUExecutor only supports single GPU.")
...@@ -116,26 +111,15 @@ class GPUExecutor(ExecutorBase): ...@@ -116,26 +111,15 @@ class GPUExecutor(ExecutorBase):
# NOTE: This is logged in the executor because there can be >1 worker # NOTE: This is logged in the executor because there can be >1 worker
# with other executors. We could log in the engine level, but work # with other executors. We could log in the engine level, but work
# remains to abstract away the device for non-GPU configurations. # remains to abstract away the device for non-GPU configurations.
logger.info(f"# GPU blocks: {num_gpu_blocks}, " logger.info("# GPU blocks: %d, # CPU blocks: %d", num_gpu_blocks,
f"# CPU blocks: {num_cpu_blocks}") num_cpu_blocks)
self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks) self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)
def execute_model( def execute_model(
self, self,
seq_group_metadata_list: List[SequenceGroupMetadata], execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
blocks_to_swap_in: Dict[int, int], output = self.driver_worker.execute_model(execute_model_req)
blocks_to_swap_out: Dict[int, int],
blocks_to_copy: Dict[int, List[int]],
num_lookahead_slots: int,
) -> List[SamplerOutput]:
output = self.driver_worker.execute_model(
seq_group_metadata_list=seq_group_metadata_list,
blocks_to_swap_in=blocks_to_swap_in,
blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=blocks_to_copy,
num_lookahead_slots=num_lookahead_slots,
)
return output return output
def add_lora(self, lora_request: LoRARequest) -> bool: def add_lora(self, lora_request: LoRARequest) -> bool:
...@@ -159,14 +143,8 @@ class GPUExecutorAsync(GPUExecutor, ExecutorAsyncBase): ...@@ -159,14 +143,8 @@ class GPUExecutorAsync(GPUExecutor, ExecutorAsyncBase):
async def execute_model_async( async def execute_model_async(
self, self,
seq_group_metadata_list: List[SequenceGroupMetadata], execute_model_req: ExecuteModelRequest,
blocks_to_swap_in: Dict[int, int], ) -> List[SamplerOutput]:
blocks_to_swap_out: Dict[int, int], output = await make_async(self.driver_worker.execute_model
blocks_to_copy: Dict[int, List[int]], )(execute_model_req=execute_model_req, )
) -> SamplerOutput:
output = await make_async(self.driver_worker.execute_model)(
seq_group_metadata_list=seq_group_metadata_list,
blocks_to_swap_in=blocks_to_swap_in,
blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=blocks_to_copy)
return output return output
import asyncio
import multiprocessing
import os
import sys
import threading
import traceback
import uuid
from dataclasses import dataclass
from multiprocessing import Queue
from multiprocessing.connection import wait
from multiprocessing.process import BaseProcess
from typing import (Any, Callable, Dict, Generic, List, Optional, TextIO,
TypeVar, Union)
import vllm.envs as envs
from vllm.logger import init_logger
logger = init_logger(__name__)
T = TypeVar('T')
_TERMINATE = "TERMINATE" # sentinel
# ANSI color codes
CYAN = '\033[1;36m'
RESET = '\033[0;0m'
JOIN_TIMEOUT_S = 2
mp_method = envs.VLLM_WORKER_MULTIPROC_METHOD
mp = multiprocessing.get_context(mp_method)
@dataclass
class Result(Generic[T]):
    """Result of task dispatched to worker.

    Carries the outcome of one task: the identifier of the task it answers,
    plus either the produced value or the exception that was raised.
    """
    # Identifier of the task this result answers; used to look up the
    # future registered for that task.
    task_id: uuid.UUID
    # Value produced by the task; defaults to None.
    value: Optional[T] = None
    # Exception raised while running the task; defaults to None.
    exception: Optional[BaseException] = None
class ResultFuture(threading.Event, Generic[T]):
    """Blocking future for callers that are not running inside asyncio.

    The inherited event is set once a result has been stored, which
    releases every thread waiting in :meth:`get`.
    """

    def __init__(self):
        super().__init__()
        # Filled in by set_result(); None until the task completes.
        self.result: Optional[Result[T]] = None

    def set_result(self, result: Result[T]):
        """Store ``result`` and signal the event."""
        self.result = result
        self.set()

    def get(self) -> T:
        """Block until a result is available, then return or raise it."""
        self.wait()
        outcome = self.result
        assert outcome is not None
        error = outcome.exception
        if error is None:
            return outcome.value  # type: ignore[return-value]
        raise error
def _set_future_result(future: Union[ResultFuture, asyncio.Future],
                       result: Result):
    """Deliver ``result`` to ``future``, whichever flavour of future it is.

    A ``ResultFuture`` receives the whole ``Result`` object directly. Any
    other future is completed through ``call_soon_threadsafe`` on the loop
    returned by ``future.get_loop()``, with either the exception or the
    plain value.
    """
    if not isinstance(future, ResultFuture):
        event_loop = future.get_loop()
        if result.exception is None:
            event_loop.call_soon_threadsafe(future.set_result, result.value)
        else:
            event_loop.call_soon_threadsafe(future.set_exception,
                                            result.exception)
        return
    future.set_result(result)
class ResultHandler(threading.Thread):
    """Background thread that routes worker results to waiting futures."""

    def __init__(self) -> None:
        super().__init__(daemon=True)
        # Queue that receives Result objects (and the terminate sentinel).
        self.result_queue = mp.Queue()
        # Pending futures keyed by the id of the task they are waiting on.
        self.tasks: Dict[uuid.UUID, Union[ResultFuture, asyncio.Future]] = {}

    def run(self):
        """Dispatch results until the terminate sentinel is received."""
        while True:
            result = self.result_queue.get()
            if result == _TERMINATE:
                break
            _set_future_result(self.tasks.pop(result.task_id), result)
        # Ensure that all waiters will receive an exception
        for task_id, future in self.tasks.items():
            failure = Result(task_id=task_id,
                             exception=ChildProcessError("worker died"))
            _set_future_result(future, failure)

    def close(self):
        """Ask the handler thread to stop by enqueueing the sentinel."""
        self.result_queue.put(_TERMINATE)
class WorkerMonitor(threading.Thread):
    """Monitor worker status (in background thread).

    ``run`` blocks until at least one worker process exits. Unless
    ``close`` has already been called, it then logs abnormal exit codes,
    kills every worker and closes the result handler. ``close`` is the
    cooperative path: it asks every worker to terminate and closes the
    result handler.
    """

    def __init__(self, workers: List['ProcessWorkerWrapper'],
                 result_handler: ResultHandler):
        super().__init__(daemon=True)
        self.workers = workers
        self.result_handler = result_handler
        # Set by whichever of run()/close() starts shutdown first, so the
        # other one skips its own cleanup.
        self._close = False

    def run(self) -> None:
        """Wait for any worker to exit, then tear down all workers."""
        # Blocks until any worker exits
        dead_sentinels = wait([w.process.sentinel for w in self.workers])
        if not self._close:
            self._close = True

            # Kill / cleanup all workers
            for worker in self.workers:
                process = worker.process
                # Only processes reported as exited are joined here.
                if process.sentinel in dead_sentinels:
                    process.join(JOIN_TIMEOUT_S)
                # A non-zero exit code is reported as an error.
                if process.exitcode is not None and process.exitcode != 0:
                    logger.error("Worker %s pid %s died, exit code: %s",
                                 process.name, process.pid, process.exitcode)
            # Cleanup any remaining workers
            logger.info("Killing local vLLM worker processes")
            for worker in self.workers:
                worker.kill_worker()
            # Must be done after worker task queues are all closed
            self.result_handler.close()

        # Runs on both paths: give each process up to JOIN_TIMEOUT_S
        # seconds to finish.
        for worker in self.workers:
            worker.process.join(JOIN_TIMEOUT_S)

    def close(self):
        """Request termination of all workers; no-op if already closing."""
        if self._close:
            return
        self._close = True
        logger.info("Terminating local vLLM worker processes")
        for worker in self.workers:
            worker.terminate_worker()
        # Must be done after worker task queues are all closed
        self.result_handler.close()
class ProcessWorkerWrapper:
    """Local process wrapper for vllm.worker.Worker,
    for handling single-node multi-GPU tensor parallel."""

    def __init__(self, result_handler: ResultHandler,
                 worker_factory: Callable[[], Any]) -> None:
        # Each wrapper owns its task queue; the result queue and the table
        # of pending futures are the ones held by the result handler.
        self._task_queue = mp.Queue()
        self.result_queue = result_handler.result_queue
        self.tasks = result_handler.tasks
        process_kwargs = {
            "worker_factory": worker_factory,
            "task_queue": self._task_queue,
            "result_queue": self.result_queue,
        }
        self.process: BaseProcess = mp.Process(  # type: ignore[attr-defined]
            target=_run_worker_process,
            name="VllmWorkerProcess",
            kwargs=process_kwargs,
            daemon=True)
        self.process.start()

    def _enqueue_task(self, future: Union[ResultFuture, asyncio.Future],
                      method: str, args, kwargs):
        """Register ``future`` under a fresh task id and submit the task.

        Raises:
            ChildProcessError: if putting the task on the queue fails; the
                future is unregistered before raising.
        """
        task_id = uuid.uuid4()
        # Register before submitting so the entry exists by the time a
        # result for this id can arrive.
        self.tasks[task_id] = future
        message = (task_id, method, args, kwargs)
        try:
            self._task_queue.put(message)
        except BaseException as e:
            del self.tasks[task_id]
            raise ChildProcessError("worker died") from e

    def execute_method(self, method: str, *args, **kwargs):
        """Submit ``method`` to the worker; return a ``ResultFuture``."""
        pending: ResultFuture = ResultFuture()
        self._enqueue_task(pending, method, args, kwargs)
        return pending

    async def execute_method_async(self, method: str, *args, **kwargs):
        """Submit ``method`` to the worker and await its result."""
        running_loop = asyncio.get_running_loop()
        pending = running_loop.create_future()
        self._enqueue_task(pending, method, args, kwargs)
        return await pending

    def terminate_worker(self):
        """Send the terminate sentinel, killing the process if that fails."""
        try:
            self._task_queue.put(_TERMINATE)
        except ValueError:
            self.process.kill()
        self._task_queue.close()

    def kill_worker(self):
        """Close the task queue and kill the worker process."""
        self._task_queue.close()
        self.process.kill()
def _run_worker_process(
    worker_factory: Callable[[], Any],
    task_queue: Queue,
    result_queue: Queue,
) -> None:
    """Worker process event loop.

    Builds the worker from ``worker_factory``, then repeatedly takes
    ``(task_id, method, args, kwargs)`` items from ``task_queue``, invokes
    the named method on the worker and puts a ``Result`` on
    ``result_queue``. The loop ends when the terminate sentinel is read.
    """
    # Add process-specific prefix to stdout and stderr
    process_name = mp.current_process().name
    pid = os.getpid()
    for stream in (sys.stdout, sys.stderr):
        _add_prefix(stream, process_name, pid)

    # Initialize worker
    worker = worker_factory()
    del worker_factory

    # Accept tasks from the engine in task_queue
    # and return task output in result_queue
    logger.info("Worker ready; awaiting tasks")
    try:
        for task_id, method, args, kwargs in iter(task_queue.get,
                                                  _TERMINATE):
            output = None
            exception = None
            try:
                output = getattr(worker, method)(*args, **kwargs)
            except BaseException as e:
                # The failure is logged here and also sent back to the
                # caller inside the Result.
                logger.error(
                    "Exception in worker %s while processing method %s: %s, %s",
                    process_name, method, e, traceback.format_exc())
                exception = e
            result_queue.put(
                Result(task_id=task_id, value=output, exception=exception))
    except KeyboardInterrupt:
        pass
    except Exception:
        logger.exception("Worker failed")

    logger.info("Worker exiting")
def _add_prefix(file: TextIO, worker_name: str, pid: int) -> None:
    """Prepend each output line with process-specific prefix.

    Replaces ``file.write`` with a wrapper that emits a coloured
    ``(worker_name pid=pid)`` prefix at the start of every line. The
    ``start_new_line`` attribute stored on ``file`` records whether the
    next write begins a new line.

    Args:
        file: Text stream to patch in place.
        worker_name: Name shown in the prefix.
        pid: Process id shown in the prefix.
    """
    prefix = f"{CYAN}({worker_name} pid={pid}){RESET} "
    file_write = file.write

    def write_with_prefix(s: str) -> int:
        """Write ``s`` with prefixes and return ``len(s)``.

        The return value follows the ``TextIO.write`` contract (number of
        characters of ``s`` written); inserted prefixes are not counted.
        """
        if not s:
            return 0
        if file.start_new_line:  # type: ignore[attr-defined]
            file_write(prefix)
        idx = 0
        while (next_idx := s.find('\n', idx)) != -1:
            next_idx += 1
            file_write(s[idx:next_idx])
            if next_idx == len(s):
                # Text ended on a newline: defer the prefix to the next
                # write so no dangling prefix is emitted.
                file.start_new_line = True  # type: ignore[attr-defined]
                return len(s)
            file_write(prefix)
            idx = next_idx
        file_write(s[idx:])
        file.start_new_line = False  # type: ignore[attr-defined]
        return len(s)

    file.start_new_line = True  # type: ignore[attr-defined]
    file.write = write_with_prefix  # type: ignore[method-assign]
from typing import Dict, List, Set, Tuple from typing import List, Set, Tuple
from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.utils import make_async from vllm.utils import make_async
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -45,20 +45,18 @@ class NeuronExecutor(ExecutorBase): ...@@ -45,20 +45,18 @@ class NeuronExecutor(ExecutorBase):
""" """
self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks) self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)
def execute_model(self, def execute_model(
seq_group_metadata_list: List[SequenceGroupMetadata], self,
blocks_to_swap_in: Dict[int, int], execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
blocks_to_swap_out: Dict[int, int], assert (execute_model_req.blocks_to_swap_in == {}
blocks_to_copy: Dict[int, List[int]], and execute_model_req.blocks_to_swap_out == {}
num_lookahead_slots: int) -> List[SamplerOutput]: and execute_model_req.blocks_to_copy == {}), (
assert (blocks_to_swap_in == {} and blocks_to_swap_out == {}
and blocks_to_copy == {}), (
"Cache operations are not supported for Neuron backend.") "Cache operations are not supported for Neuron backend.")
assert num_lookahead_slots == 0, ( assert execute_model_req.num_lookahead_slots == 0, (
"lookahead not supported for Neuron backend.") "lookahead not supported for Neuron backend.")
output = self.driver_worker.execute_model( output = self.driver_worker.execute_model(
seq_group_metadata_list=seq_group_metadata_list) execute_model_req.seq_group_metadata_list)
return output return output
def add_lora(self, lora_request: LoRARequest) -> bool: def add_lora(self, lora_request: LoRARequest) -> bool:
...@@ -80,13 +78,11 @@ class NeuronExecutorAsync(NeuronExecutor, ExecutorAsyncBase): ...@@ -80,13 +78,11 @@ class NeuronExecutorAsync(NeuronExecutor, ExecutorAsyncBase):
async def execute_model_async( async def execute_model_async(
self, self,
seq_group_metadata_list: List[SequenceGroupMetadata], execute_model_req: ExecuteModelRequest,
blocks_to_swap_in: Dict[int, int], ) -> List[SamplerOutput]:
blocks_to_swap_out: Dict[int, int], output = await make_async(
blocks_to_copy: Dict[int, List[int]], self.driver_worker.execute_model
) -> SamplerOutput: )(seq_group_metadata_list=execute_model_req.seq_group_metadata_list, )
output = await make_async(self.driver_worker.execute_model)(
seq_group_metadata_list=seq_group_metadata_list, )
return output return output
async def check_health_async(self) -> None: async def check_health_async(self) -> None:
......
...@@ -3,13 +3,14 @@ import os ...@@ -3,13 +3,14 @@ import os
import pickle import pickle
from collections import defaultdict from collections import defaultdict
from itertools import islice, repeat from itertools import islice, repeat
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
from vllm.engine.ray_utils import RayWorkerWrapper, ray import vllm.envs as envs
from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase from vllm.executor.distributed_gpu_executor import ( # yapf: disable
DistributedGPUExecutor, DistributedGPUExecutorAsync)
from vllm.executor.ray_utils import RayWorkerWrapper, ray
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port, from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
get_vllm_instance_id, make_async) get_vllm_instance_id, make_async)
...@@ -21,13 +22,10 @@ if TYPE_CHECKING: ...@@ -21,13 +22,10 @@ if TYPE_CHECKING:
logger = init_logger(__name__) logger = init_logger(__name__)
# If the env var is set, it uses the Ray's compiled DAG API USE_RAY_COMPILED_DAG = envs.VLLM_USE_RAY_COMPILED_DAG
# which optimizes the control plane overhead.
# Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it.
USE_RAY_COMPILED_DAG = bool(os.getenv("VLLM_USE_RAY_COMPILED_DAG", 0))
class RayGPUExecutor(ExecutorBase): class RayGPUExecutor(DistributedGPUExecutor):
def _init_executor(self) -> None: def _init_executor(self) -> None:
assert (not self.speculative_config assert (not self.speculative_config
...@@ -74,7 +72,7 @@ class RayGPUExecutor(ExecutorBase): ...@@ -74,7 +72,7 @@ class RayGPUExecutor(ExecutorBase):
# The driver dummy worker does not actually use any resources. # The driver dummy worker does not actually use any resources.
# It holds the resource for the driver worker. # It holds the resource for the driver worker.
self.driver_dummy_worker: RayWorkerWrapper = None self.driver_dummy_worker: Optional[RayWorkerWrapper] = None
# The remaining workers are the actual ray actors. # The remaining workers are the actual ray actors.
self.workers: List[RayWorkerWrapper] = [] self.workers: List[RayWorkerWrapper] = []
...@@ -145,7 +143,7 @@ class RayGPUExecutor(ExecutorBase): ...@@ -145,7 +143,7 @@ class RayGPUExecutor(ExecutorBase):
"VLLM_INSTANCE_ID": "VLLM_INSTANCE_ID":
VLLM_INSTANCE_ID, VLLM_INSTANCE_ID,
"VLLM_TRACE_FUNCTION": "VLLM_TRACE_FUNCTION":
os.getenv("VLLM_TRACE_FUNCTION", "0"), str(envs.VLLM_TRACE_FUNCTION),
}, ) for (node_id, _) in worker_node_and_gpu_ids] }, ) for (node_id, _) in worker_node_and_gpu_ids]
self._run_workers("update_environment_variables", self._run_workers("update_environment_variables",
all_args=all_args_to_update_environment_variables) all_args=all_args_to_update_environment_variables)
...@@ -153,113 +151,31 @@ class RayGPUExecutor(ExecutorBase): ...@@ -153,113 +151,31 @@ class RayGPUExecutor(ExecutorBase):
distributed_init_method = get_distributed_init_method( distributed_init_method = get_distributed_init_method(
driver_ip, get_open_port()) driver_ip, get_open_port())
def collect_arg_helper_func(**kwargs):
# avoid writing `{"name": value}` manually
return kwargs
# Initialize the actual workers inside worker wrapper. # Initialize the actual workers inside worker wrapper.
init_worker_all_kwargs = [] init_worker_all_kwargs = [
for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids): self._get_worker_kwargs(
local_rank = node_workers[node_id].index(rank) local_rank=node_workers[node_id].index(rank),
init_worker_all_kwargs.append( rank=rank,
collect_arg_helper_func( distributed_init_method=distributed_init_method,
model_config=self.model_config, ) for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids)
parallel_config=self.parallel_config, ]
scheduler_config=self.scheduler_config,
device_config=self.device_config,
cache_config=self.cache_config,
load_config=self.load_config,
local_rank=local_rank,
rank=rank,
distributed_init_method=distributed_init_method,
lora_config=self.lora_config,
vision_language_config=self.vision_language_config,
is_driver_worker=rank == 0,
))
self._run_workers("init_worker", all_kwargs=init_worker_all_kwargs) self._run_workers("init_worker", all_kwargs=init_worker_all_kwargs)
self._run_workers("init_device") self._run_workers("init_device")
self._run_workers( self._run_workers("load_model",
"load_model", max_concurrent_workers=self.parallel_config.
max_concurrent_workers=self.parallel_config. max_parallel_loading_workers)
max_parallel_loading_workers,
)
def determine_num_available_blocks(self) -> Tuple[int, int]:
"""Determine the number of available KV blocks.
This invokes `determine_num_available_blocks` on each worker and takes
the min of the results, guaranteeing that the selected cache sizes are
compatible with all workers.
Returns:
- Tuple[num_gpu_blocks, num_cpu_blocks]
"""
# Get the maximum number of blocks that can be allocated on GPU and CPU.
num_blocks = self._run_workers("determine_num_available_blocks", )
# Since we use a shared centralized controller, we take the minimum
# number of blocks across all workers to make sure all the memory
# operators can be applied to all workers.
num_gpu_blocks = min(b[0] for b in num_blocks)
num_cpu_blocks = min(b[1] for b in num_blocks)
return num_gpu_blocks, num_cpu_blocks
def initialize_cache(self, num_gpu_blocks: int, def execute_model(
num_cpu_blocks: int) -> None: self,
"""Initialize the KV cache in all workers. execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
"""
# NOTE: We log here to avoid multiple logs when number of workers is
# greater than one. We could log in the engine, but not all executors
# have GPUs.
logger.info(f"# GPU blocks: {num_gpu_blocks}, "
f"# CPU blocks: {num_cpu_blocks}")
self.cache_config.num_gpu_blocks = num_gpu_blocks
self.cache_config.num_cpu_blocks = num_cpu_blocks
self._run_workers("initialize_cache",
num_gpu_blocks=num_gpu_blocks,
num_cpu_blocks=num_cpu_blocks)
def execute_model(self,
seq_group_metadata_list: List[SequenceGroupMetadata],
blocks_to_swap_in: Dict[int, int],
blocks_to_swap_out: Dict[int, int],
blocks_to_copy: Dict[int, List[int]],
num_lookahead_slots: int = 0) -> SamplerOutput:
all_outputs = self._run_workers( all_outputs = self._run_workers(
"execute_model", "execute_model",
driver_kwargs={ driver_kwargs={"execute_model_req": execute_model_req},
"seq_group_metadata_list": seq_group_metadata_list,
"blocks_to_swap_in": blocks_to_swap_in,
"blocks_to_swap_out": blocks_to_swap_out,
"blocks_to_copy": blocks_to_copy,
},
use_ray_compiled_dag=USE_RAY_COMPILED_DAG) use_ray_compiled_dag=USE_RAY_COMPILED_DAG)
# Only the driver worker returns the sampling results. # Only the driver worker returns the sampling results.
output = all_outputs[0] return all_outputs[0]
return output
def add_lora(self, lora_request: LoRARequest) -> bool:
assert lora_request.lora_int_id > 0, "lora_id must be greater than 0."
return self._run_workers(
"add_lora",
lora_request=lora_request,
)
def remove_lora(self, lora_id: int) -> bool:
assert lora_id > 0, "lora_id must be greater than 0."
return self._run_workers(
"remove_lora",
lora_id=lora_id,
)
def list_loras(self) -> Set[int]:
return self._run_workers("list_loras")
def _run_workers( def _run_workers(
self, self,
...@@ -318,6 +234,7 @@ class RayGPUExecutor(ExecutorBase): ...@@ -318,6 +234,7 @@ class RayGPUExecutor(ExecutorBase):
driver_worker_output = self.driver_worker.execute_method( driver_worker_output = self.driver_worker.execute_method(
method, *driver_args, **driver_kwargs) method, *driver_args, **driver_kwargs)
else: else:
assert self.driver_dummy_worker is not None
driver_worker_output = ray.get( driver_worker_output = ray.get(
self.driver_dummy_worker.execute_method.remote( self.driver_dummy_worker.execute_method.remote(
method, *driver_args, **driver_kwargs)) method, *driver_args, **driver_kwargs))
...@@ -353,8 +270,9 @@ class RayGPUExecutor(ExecutorBase): ...@@ -353,8 +270,9 @@ class RayGPUExecutor(ExecutorBase):
# a dummy value for now. It will be fixed soon. # a dummy value for now. It will be fixed soon.
with InputNode() as input_data: with InputNode() as input_data:
forward_dag = MultiOutputNode([ forward_dag = MultiOutputNode([
worker.execute_model_compiled_dag_remote.bind(input_data) worker.execute_model_compiled_dag_remote.
for worker in self.workers bind( # type: ignore[attr-defined]
input_data) for worker in self.workers
]) ])
return forward_dag.experimental_compile() return forward_dag.experimental_compile()
...@@ -376,7 +294,7 @@ class RayGPUExecutor(ExecutorBase): ...@@ -376,7 +294,7 @@ class RayGPUExecutor(ExecutorBase):
f"Dead Workers: {dead_actors}. ") f"Dead Workers: {dead_actors}. ")
class RayGPUExecutorAsync(RayGPUExecutor, ExecutorAsyncBase): class RayGPUExecutorAsync(RayGPUExecutor, DistributedGPUExecutorAsync):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
...@@ -407,23 +325,3 @@ class RayGPUExecutorAsync(RayGPUExecutor, ExecutorAsyncBase): ...@@ -407,23 +325,3 @@ class RayGPUExecutorAsync(RayGPUExecutor, ExecutorAsyncBase):
all_outputs = await asyncio.gather(*coros) all_outputs = await asyncio.gather(*coros)
return all_outputs return all_outputs
async def execute_model_async(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
blocks_to_swap_in: Dict[int, int],
blocks_to_swap_out: Dict[int, int],
blocks_to_copy: Dict[int, List[int]],
) -> SamplerOutput:
all_outputs = await self._run_workers_async(
"execute_model",
driver_kwargs={
"seq_group_metadata_list": seq_group_metadata_list,
"blocks_to_swap_in": blocks_to_swap_in,
"blocks_to_swap_out": blocks_to_swap_out,
"blocks_to_copy": blocks_to_copy,
})
# Only the driver worker returns the sampling results.
output = all_outputs[0]
return output
...@@ -43,9 +43,9 @@ try: ...@@ -43,9 +43,9 @@ try:
return output return output
except ImportError as e: except ImportError as e:
logger.warning(f"Failed to import Ray with {e!r}. " logger.warning(
"For distributed inference, please install Ray with " "Failed to import Ray with %r. For distributed inference, "
"`pip install ray`.") "please install Ray with `pip install ray`.", e)
ray = None # type: ignore ray = None # type: ignore
RayWorkerWrapper = None # type: ignore RayWorkerWrapper = None # type: ignore
......
# Adapted from
# https://github.com/skypilot-org/skypilot/blob/86dc0f6283a335e4aa37b3c10716f90999f48ab6/sky/sky_logging.py
"""Logging configuration for vLLM.""" """Logging configuration for vLLM."""
import datetime import datetime
import json
import logging import logging
import os import os
import sys import sys
from functools import partial from functools import partial
from typing import Optional from logging import Logger
from logging.config import dictConfig
from os import path
from typing import Dict, Optional
VLLM_CONFIGURE_LOGGING = int(os.getenv("VLLM_CONFIGURE_LOGGING", "1")) import vllm.envs as envs
VLLM_CONFIGURE_LOGGING = envs.VLLM_CONFIGURE_LOGGING
VLLM_LOGGING_CONFIG_PATH = envs.VLLM_LOGGING_CONFIG_PATH
_FORMAT = "%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s" _FORMAT = "%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s"
_DATE_FORMAT = "%m-%d %H:%M:%S" _DATE_FORMAT = "%m-%d %H:%M:%S"
DEFAULT_LOGGING_CONFIG = {
"formatters": {
"vllm": {
"class": "vllm.logging.NewLineFormatter",
"datefmt": _DATE_FORMAT,
"format": _FORMAT,
},
},
"handlers": {
"vllm": {
"class": "logging.StreamHandler",
"formatter": "vllm",
"level": "INFO",
"stream": "ext://sys.stdout",
},
},
"loggers": {
"vllm": {
"handlers": ["vllm"],
"level": "DEBUG",
"propagate": False,
},
},
"version": 1,
}
def _configure_vllm_root_logger() -> None:
logging_config: Optional[Dict] = None
if not VLLM_CONFIGURE_LOGGING and VLLM_LOGGING_CONFIG_PATH:
raise RuntimeError(
"VLLM_CONFIGURE_LOGGING evaluated to false, but "
"VLLM_LOGGING_CONFIG_PATH was given. VLLM_LOGGING_CONFIG_PATH "
"implies VLLM_CONFIGURE_LOGGING. Please enable "
"VLLM_CONFIGURE_LOGGING or unset VLLM_LOGGING_CONFIG_PATH.")
class NewLineFormatter(logging.Formatter): if VLLM_CONFIGURE_LOGGING:
"""Adds logging prefix to newlines to align multi-line messages.""" logging_config = DEFAULT_LOGGING_CONFIG
def __init__(self, fmt, datefmt=None): if VLLM_LOGGING_CONFIG_PATH:
logging.Formatter.__init__(self, fmt, datefmt) if not path.exists(VLLM_LOGGING_CONFIG_PATH):
raise RuntimeError(
"Could not load logging config. File does not exist: %s",
VLLM_LOGGING_CONFIG_PATH)
with open(VLLM_LOGGING_CONFIG_PATH, encoding="utf-8",
mode="r") as file:
custom_config = json.loads(file.read())
def format(self, record): if not isinstance(custom_config, dict):
msg = logging.Formatter.format(self, record) raise ValueError("Invalid logging config. Expected Dict, got %s.",
if record.message != "": type(custom_config).__name__)
parts = msg.split(record.message) logging_config = custom_config
msg = msg.replace("\n", "\r\n" + parts[0])
return msg
if logging_config:
dictConfig(logging_config)
_root_logger = logging.getLogger("vllm")
_default_handler: Optional[logging.Handler] = None
def init_logger(name: str) -> Logger:
"""The main purpose of this function is to ensure that loggers are
retrieved in such a way that we can be sure the root vllm logger has
already been configured."""
def _setup_logger(): return logging.getLogger(name)
_root_logger.setLevel(logging.DEBUG)
global _default_handler
if _default_handler is None:
_default_handler = logging.StreamHandler(sys.stdout)
_default_handler.flush = sys.stdout.flush # type: ignore
_default_handler.setLevel(logging.INFO)
_root_logger.addHandler(_default_handler)
fmt = NewLineFormatter(_FORMAT, datefmt=_DATE_FORMAT)
_default_handler.setFormatter(fmt)
# Setting this will avoid the message
# being propagated to the parent logger.
_root_logger.propagate = False
# The logger is initialized when the module is imported. # The root logger is initialized when the module is imported.
# This is thread-safe as the module is only imported once, # This is thread-safe as the module is only imported once,
# guaranteed by the Python GIL. # guaranteed by the Python GIL.
if VLLM_CONFIGURE_LOGGING: _configure_vllm_root_logger()
_setup_logger()
def init_logger(name: str):
# Use the same settings as above for root logger
logger = logging.getLogger(name)
logger.setLevel(os.getenv("LOG_LEVEL", "DEBUG"))
if VLLM_CONFIGURE_LOGGING:
if _default_handler is None:
raise ValueError(
"_default_handler is not set up. This should never happen!"
" Please open an issue on Github.")
logger.addHandler(_default_handler)
logger.propagate = False
return logger
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -126,7 +146,7 @@ def enable_trace_function_call(log_file_path: str, ...@@ -126,7 +146,7 @@ def enable_trace_function_call(log_file_path: str,
"VLLM_TRACE_FUNCTION is enabled. It will record every" "VLLM_TRACE_FUNCTION is enabled. It will record every"
" function executed by Python. This will slow down the code. It " " function executed by Python. This will slow down the code. It "
"is suggested to be used for debugging hang or crashes only.") "is suggested to be used for debugging hang or crashes only.")
logger.info(f"Trace frame log is saved to {log_file_path}") logger.info("Trace frame log is saved to %s", log_file_path)
if root_dir is None: if root_dir is None:
# by default, this is the vllm root directory # by default, this is the vllm root directory
root_dir = os.path.dirname(os.path.dirname(__file__)) root_dir = os.path.dirname(os.path.dirname(__file__))
......
from vllm.logging.formatter import NewLineFormatter
__all__ = [
"NewLineFormatter",
]
import logging
class NewLineFormatter(logging.Formatter):
    """Formatter that repeats the log prefix on every continuation line.

    Multi-line messages are aligned by inserting, after each newline, the
    text that precedes the message in the formatted record.
    """

    def __init__(self, fmt, datefmt=None, style="%"):
        super().__init__(fmt, datefmt, style)

    def format(self, record):
        """Format ``record`` and prefix every continuation line."""
        formatted = super().format(record)
        if record.message == "":
            return formatted
        # Everything before the first occurrence of the message text is
        # the prefix (level, timestamp, location, ...).
        line_prefix, _, _ = formatted.partition(record.message)
        return formatted.replace("\n", "\r\n" + line_prefix)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment