Commit 34ddcf9f (unverified), authored by Cyrus Leung, committed by GitHub
Browse files

[Frontend] `run-batch` supports V1 (#21541)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent fe56180c
...@@ -167,7 +167,8 @@ async def run_vllm_async( ...@@ -167,7 +167,8 @@ async def run_vllm_async(
from vllm import SamplingParams from vllm import SamplingParams
async with build_async_engine_client_from_engine_args( async with build_async_engine_client_from_engine_args(
engine_args, disable_frontend_multiprocessing engine_args,
disable_frontend_multiprocessing=disable_frontend_multiprocessing,
) as llm: ) as llm:
model_config = await llm.get_model_config() model_config = await llm.get_model_config()
assert all( assert all(
......
...@@ -295,8 +295,6 @@ async def test_metrics_exist(server: RemoteOpenAIServer, ...@@ -295,8 +295,6 @@ async def test_metrics_exist(server: RemoteOpenAIServer,
def test_metrics_exist_run_batch(use_v1: bool): def test_metrics_exist_run_batch(use_v1: bool):
if use_v1:
pytest.skip("Skipping test on vllm V1")
input_batch = """{"custom_id": "request-0", "method": "POST", "url": "/v1/embeddings", "body": {"model": "intfloat/multilingual-e5-small", "input": "You are a helpful assistant."}}""" # noqa: E501 input_batch = """{"custom_id": "request-0", "method": "POST", "url": "/v1/embeddings", "body": {"model": "intfloat/multilingual-e5-small", "input": "You are a helpful assistant."}}""" # noqa: E501
base_url = "0.0.0.0" base_url = "0.0.0.0"
...@@ -323,7 +321,8 @@ def test_metrics_exist_run_batch(use_v1: bool): ...@@ -323,7 +321,8 @@ def test_metrics_exist_run_batch(use_v1: bool):
base_url, base_url,
"--port", "--port",
port, port,
], ) ],
env={"VLLM_USE_V1": "1" if use_v1 else "0"})
def is_server_up(url): def is_server_up(url):
try: try:
......
...@@ -148,7 +148,9 @@ async def run_vllm_async( ...@@ -148,7 +148,9 @@ async def run_vllm_async(
from vllm import SamplingParams from vllm import SamplingParams
async with build_async_engine_client_from_engine_args( async with build_async_engine_client_from_engine_args(
engine_args, disable_frontend_multiprocessing) as llm: engine_args,
disable_frontend_multiprocessing=disable_frontend_multiprocessing,
) as llm:
model_config = await llm.get_model_config() model_config = await llm.get_model_config()
assert all( assert all(
model_config.max_model_len >= (request.prompt_len + model_config.max_model_len >= (request.prompt_len +
......
...@@ -149,6 +149,9 @@ async def lifespan(app: FastAPI): ...@@ -149,6 +149,9 @@ async def lifespan(app: FastAPI):
@asynccontextmanager @asynccontextmanager
async def build_async_engine_client( async def build_async_engine_client(
args: Namespace, args: Namespace,
*,
usage_context: UsageContext = UsageContext.OPENAI_API_SERVER,
disable_frontend_multiprocessing: Optional[bool] = None,
client_config: Optional[dict[str, Any]] = None, client_config: Optional[dict[str, Any]] = None,
) -> AsyncIterator[EngineClient]: ) -> AsyncIterator[EngineClient]:
...@@ -156,15 +159,24 @@ async def build_async_engine_client( ...@@ -156,15 +159,24 @@ async def build_async_engine_client(
# Ensures everything is shutdown and cleaned up on error/exit # Ensures everything is shutdown and cleaned up on error/exit
engine_args = AsyncEngineArgs.from_cli_args(args) engine_args = AsyncEngineArgs.from_cli_args(args)
if disable_frontend_multiprocessing is None:
disable_frontend_multiprocessing = bool(
args.disable_frontend_multiprocessing)
async with build_async_engine_client_from_engine_args( async with build_async_engine_client_from_engine_args(
engine_args, args.disable_frontend_multiprocessing, engine_args,
client_config) as engine: usage_context=usage_context,
disable_frontend_multiprocessing=disable_frontend_multiprocessing,
client_config=client_config,
) as engine:
yield engine yield engine
@asynccontextmanager @asynccontextmanager
async def build_async_engine_client_from_engine_args( async def build_async_engine_client_from_engine_args(
engine_args: AsyncEngineArgs, engine_args: AsyncEngineArgs,
*,
usage_context: UsageContext = UsageContext.OPENAI_API_SERVER,
disable_frontend_multiprocessing: bool = False, disable_frontend_multiprocessing: bool = False,
client_config: Optional[dict[str, Any]] = None, client_config: Optional[dict[str, Any]] = None,
) -> AsyncIterator[EngineClient]: ) -> AsyncIterator[EngineClient]:
...@@ -177,7 +189,6 @@ async def build_async_engine_client_from_engine_args( ...@@ -177,7 +189,6 @@ async def build_async_engine_client_from_engine_args(
""" """
# Create the EngineConfig (determines if we can use V1). # Create the EngineConfig (determines if we can use V1).
usage_context = UsageContext.OPENAI_API_SERVER
vllm_config = engine_args.create_engine_config(usage_context=usage_context) vllm_config = engine_args.create_engine_config(usage_context=usage_context)
# V1 AsyncLLM. # V1 AsyncLLM.
...@@ -1811,7 +1822,10 @@ async def run_server_worker(listen_address, ...@@ -1811,7 +1822,10 @@ async def run_server_worker(listen_address,
if log_config is not None: if log_config is not None:
uvicorn_kwargs['log_config'] = log_config uvicorn_kwargs['log_config'] = log_config
async with build_async_engine_client(args, client_config) as engine_client: async with build_async_engine_client(
args,
client_config=client_config,
) as engine_client:
maybe_register_tokenizer_info_endpoint(args) maybe_register_tokenizer_info_endpoint(args)
app = build_app(args) app = build_app(args)
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
import asyncio import asyncio
import tempfile import tempfile
from argparse import Namespace
from collections.abc import Awaitable from collections.abc import Awaitable
from http import HTTPStatus from http import HTTPStatus
from io import StringIO from io import StringIO
...@@ -13,10 +14,12 @@ import torch ...@@ -13,10 +14,12 @@ import torch
from prometheus_client import start_http_server from prometheus_client import start_http_server
from tqdm import tqdm from tqdm import tqdm
from vllm.config import VllmConfig
from vllm.engine.arg_utils import AsyncEngineArgs, optional_type from vllm.engine.arg_utils import AsyncEngineArgs, optional_type
from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.engine.protocol import EngineClient
from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.logger import RequestLogger
# yapf: disable # yapf: disable
from vllm.entrypoints.openai.api_server import build_async_engine_client
from vllm.entrypoints.openai.protocol import (BatchRequestInput, from vllm.entrypoints.openai.protocol import (BatchRequestInput,
BatchRequestOutput, BatchRequestOutput,
BatchResponseData, BatchResponseData,
...@@ -310,36 +313,37 @@ async def run_request(serving_engine_func: Callable, ...@@ -310,36 +313,37 @@ async def run_request(serving_engine_func: Callable,
return batch_output return batch_output
async def main(args): async def run_batch(
engine_client: EngineClient,
vllm_config: VllmConfig,
args: Namespace,
) -> None:
if args.served_model_name is not None: if args.served_model_name is not None:
served_model_names = args.served_model_name served_model_names = args.served_model_name
else: else:
served_model_names = [args.model] served_model_names = [args.model]
engine_args = AsyncEngineArgs.from_cli_args(args) if args.disable_log_requests:
engine = AsyncLLMEngine.from_engine_args( request_logger = None
engine_args, usage_context=UsageContext.OPENAI_BATCH_RUNNER) else:
request_logger = RequestLogger(max_log_len=args.max_log_len)
model_config = await engine.get_model_config()
base_model_paths = [ base_model_paths = [
BaseModelPath(name=name, model_path=args.model) BaseModelPath(name=name, model_path=args.model)
for name in served_model_names for name in served_model_names
] ]
if args.disable_log_requests: model_config = vllm_config.model_config
request_logger = None
else:
request_logger = RequestLogger(max_log_len=args.max_log_len)
# Create the openai serving objects. # Create the openai serving objects.
openai_serving_models = OpenAIServingModels( openai_serving_models = OpenAIServingModels(
engine_client=engine, engine_client=engine_client,
model_config=model_config, model_config=model_config,
base_model_paths=base_model_paths, base_model_paths=base_model_paths,
lora_modules=None, lora_modules=None,
) )
openai_serving_chat = OpenAIServingChat( openai_serving_chat = OpenAIServingChat(
engine, engine_client,
model_config, model_config,
openai_serving_models, openai_serving_models,
args.response_role, args.response_role,
...@@ -349,7 +353,7 @@ async def main(args): ...@@ -349,7 +353,7 @@ async def main(args):
enable_prompt_tokens_details=args.enable_prompt_tokens_details, enable_prompt_tokens_details=args.enable_prompt_tokens_details,
) if "generate" in model_config.supported_tasks else None ) if "generate" in model_config.supported_tasks else None
openai_serving_embedding = OpenAIServingEmbedding( openai_serving_embedding = OpenAIServingEmbedding(
engine, engine_client,
model_config, model_config,
openai_serving_models, openai_serving_models,
request_logger=request_logger, request_logger=request_logger,
...@@ -362,7 +366,7 @@ async def main(args): ...@@ -362,7 +366,7 @@ async def main(args):
"num_labels", 0) == 1) "num_labels", 0) == 1)
openai_serving_scores = ServingScores( openai_serving_scores = ServingScores(
engine, engine_client,
model_config, model_config,
openai_serving_models, openai_serving_models,
request_logger=request_logger, request_logger=request_logger,
...@@ -457,6 +461,17 @@ async def main(args): ...@@ -457,6 +461,17 @@ async def main(args):
await write_file(args.output_file, responses, args.output_tmp_dir) await write_file(args.output_file, responses, args.output_tmp_dir)
async def main(args: Namespace):
async with build_async_engine_client(
args,
usage_context=UsageContext.OPENAI_BATCH_RUNNER,
disable_frontend_multiprocessing=False,
) as engine_client:
vllm_config = await engine_client.get_vllm_config()
await run_batch(engine_client, vllm_config, args)
if __name__ == "__main__": if __name__ == "__main__":
args = parse_args() args = parse_args()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment