Unverified Commit d4db9f53 authored by Nick Hill's avatar Nick Hill Committed by GitHub
Browse files

[Benchmark] Add `--async-engine` option to benchmark_throughput.py (#7964)

parent 2188a60c
...@@ -6,13 +6,16 @@ import time ...@@ -6,13 +6,16 @@ import time
from typing import List, Optional, Tuple from typing import List, Optional, Tuple
import torch import torch
import uvloop
from tqdm import tqdm from tqdm import tqdm
from transformers import (AutoModelForCausalLM, AutoTokenizer, from transformers import (AutoModelForCausalLM, AutoTokenizer,
PreTrainedTokenizerBase) PreTrainedTokenizerBase)
from vllm.engine.arg_utils import EngineArgs from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
from vllm.entrypoints.openai.api_server import (
build_async_engine_client_from_engine_args)
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.utils import FlexibleArgumentParser from vllm.utils import FlexibleArgumentParser, merge_async_iterators
def sample_requests( def sample_requests(
...@@ -135,6 +138,93 @@ def run_vllm( ...@@ -135,6 +138,93 @@ def run_vllm(
return end - start return end - start
async def run_vllm_async(
    requests: List[Tuple[str, int, int]],
    model: str,
    tokenizer: str,
    quantization: Optional[str],
    tensor_parallel_size: int,
    seed: int,
    n: int,
    use_beam_search: bool,
    trust_remote_code: bool,
    dtype: str,
    max_model_len: Optional[int],
    enforce_eager: bool,
    kv_cache_dtype: str,
    quantization_param_path: Optional[str],
    device: str,
    enable_prefix_caching: bool,
    enable_chunked_prefill: bool,
    max_num_batched_tokens: int,
    distributed_executor_backend: Optional[str],
    gpu_memory_utilization: float = 0.9,
    num_scheduler_steps: int = 1,
    use_v2_block_manager: bool = False,
    download_dir: Optional[str] = None,
    load_format: str = EngineArgs.load_format,
    disable_async_output_proc: bool = False,
    disable_frontend_multiprocessing: bool = False,
) -> float:
    """Run the benchmark requests through the vLLM async engine client.

    Every request is submitted via ``llm.generate`` and the resulting
    async generators are merged and drained to completion.

    Args:
        requests: One tuple per request. The first element is the prompt
            and the third is passed as ``max_tokens``; the second element
            is not used by this function.
        n: Forwarded to ``SamplingParams(n=...)``.
        use_beam_search: Forwarded to ``SamplingParams``; also selects a
            temperature of 0.0 (otherwise 1.0).
        disable_frontend_multiprocessing: Forwarded as the second argument
            of ``build_async_engine_client_from_engine_args``.

        All remaining parameters are forwarded unchanged to
        ``AsyncEngineArgs`` under the same keyword name.

    Returns:
        Elapsed time in seconds (``time.perf_counter`` delta) between
        submitting the first request and draining the last output.

    Raises:
        RuntimeError: If the engine client context manager yields ``None``
            (it is typed as ``Optional``), i.e. no client is available.
    """
    from vllm import SamplingParams

    engine_args = AsyncEngineArgs(
        model=model,
        tokenizer=tokenizer,
        quantization=quantization,
        tensor_parallel_size=tensor_parallel_size,
        seed=seed,
        trust_remote_code=trust_remote_code,
        dtype=dtype,
        max_model_len=max_model_len,
        gpu_memory_utilization=gpu_memory_utilization,
        enforce_eager=enforce_eager,
        kv_cache_dtype=kv_cache_dtype,
        quantization_param_path=quantization_param_path,
        device=device,
        enable_prefix_caching=enable_prefix_caching,
        download_dir=download_dir,
        enable_chunked_prefill=enable_chunked_prefill,
        max_num_batched_tokens=max_num_batched_tokens,
        distributed_executor_backend=distributed_executor_backend,
        load_format=load_format,
        num_scheduler_steps=num_scheduler_steps,
        use_v2_block_manager=use_v2_block_manager,
        disable_async_output_proc=disable_async_output_proc,
        worker_use_ray=False,
        engine_use_ray=False,
        disable_log_requests=True,
    )

    async with build_async_engine_client_from_engine_args(
            engine_args, disable_frontend_multiprocessing) as llm:
        # The context manager may yield None; fail with a clear message
        # instead of an AttributeError on `None.generate`.
        if llm is None:
            raise RuntimeError(
                "Failed to create the async engine client; "
                "cannot run the benchmark.")

        # Prepare all inputs up front so that only request submission and
        # generation fall inside the timed region.
        prompts: List[str] = [prompt for prompt, _, _ in requests]
        sampling_params: List[SamplingParams] = [
            SamplingParams(
                n=n,
                temperature=0.0 if use_beam_search else 1.0,
                top_p=1.0,
                use_beam_search=use_beam_search,
                ignore_eos=True,
                max_tokens=output_len,
            ) for _, _, output_len in requests
        ]

        start = time.perf_counter()
        generators = [
            llm.generate(prompt, sp, request_id=f"test{i}")
            for i, (prompt, sp) in enumerate(zip(prompts, sampling_params))
        ]
        # Outputs are discarded; the streams are drained only so that
        # every request runs to completion before the clock stops.
        async for _ in merge_async_iterators(*generators):
            pass
        end = time.perf_counter()
        return end - start
def run_hf( def run_hf(
requests: List[Tuple[str, int, int]], requests: List[Tuple[str, int, int]],
model: str, model: str,
...@@ -230,7 +320,7 @@ def main(args: argparse.Namespace): ...@@ -230,7 +320,7 @@ def main(args: argparse.Namespace):
args.output_len) args.output_len)
if args.backend == "vllm": if args.backend == "vllm":
elapsed_time = run_vllm( run_args = [
requests, args.model, args.tokenizer, args.quantization, requests, args.model, args.tokenizer, args.quantization,
args.tensor_parallel_size, args.seed, args.n, args.use_beam_search, args.tensor_parallel_size, args.seed, args.n, args.use_beam_search,
args.trust_remote_code, args.dtype, args.max_model_len, args.trust_remote_code, args.dtype, args.max_model_len,
...@@ -240,7 +330,14 @@ def main(args: argparse.Namespace): ...@@ -240,7 +330,14 @@ def main(args: argparse.Namespace):
args.max_num_batched_tokens, args.distributed_executor_backend, args.max_num_batched_tokens, args.distributed_executor_backend,
args.gpu_memory_utilization, args.num_scheduler_steps, args.gpu_memory_utilization, args.num_scheduler_steps,
args.use_v2_block_manager, args.download_dir, args.load_format, args.use_v2_block_manager, args.download_dir, args.load_format,
args.disable_async_output_proc) args.disable_async_output_proc
]
if args.async_engine:
run_args.append(args.disable_frontend_multiprocessing)
elapsed_time = uvloop.run(run_vllm_async(*run_args))
else:
elapsed_time = run_vllm(*run_args)
elif args.backend == "hf": elif args.backend == "hf":
assert args.tensor_parallel_size == 1 assert args.tensor_parallel_size == 1
elapsed_time = run_hf(requests, args.model, tokenizer, args.n, elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
...@@ -426,6 +523,14 @@ if __name__ == "__main__": ...@@ -426,6 +523,14 @@ if __name__ == "__main__":
action='store_true', action='store_true',
default=False, default=False,
help="Disable async output processor for vLLM backend.") help="Disable async output processor for vLLM backend.")
parser.add_argument("--async-engine",
action='store_true',
default=False,
help="Use vLLM async engine rather than LLM class.")
parser.add_argument("--disable-frontend-multiprocessing",
action='store_true',
default=False,
help="Disable decoupled async engine frontend.")
args = parser.parse_args() args = parser.parse_args()
if args.tokenizer is None: if args.tokenizer is None:
args.tokenizer = args.model args.tokenizer = args.model
......
...@@ -67,7 +67,7 @@ _running_tasks: Set[asyncio.Task] = set() ...@@ -67,7 +67,7 @@ _running_tasks: Set[asyncio.Task] = set()
def model_is_embedding(model_name: str, trust_remote_code: bool, def model_is_embedding(model_name: str, trust_remote_code: bool,
quantization: str) -> bool: quantization: Optional[str]) -> bool:
return ModelConfig(model=model_name, return ModelConfig(model=model_name,
tokenizer=model_name, tokenizer=model_name,
tokenizer_mode="auto", tokenizer_mode="auto",
...@@ -96,13 +96,6 @@ async def lifespan(app: FastAPI): ...@@ -96,13 +96,6 @@ async def lifespan(app: FastAPI):
@asynccontextmanager @asynccontextmanager
async def build_async_engine_client( async def build_async_engine_client(
args: Namespace) -> AsyncIterator[Optional[AsyncEngineClient]]: args: Namespace) -> AsyncIterator[Optional[AsyncEngineClient]]:
"""
Create AsyncEngineClient, either:
- in-process using the AsyncLLMEngine Directly
- multiprocess using AsyncLLMEngine RPC
Returns the Client or None if the creation failed.
"""
# Context manager to handle async_engine_client lifecycle # Context manager to handle async_engine_client lifecycle
# Ensures everything is shutdown and cleaned up on error/exit # Ensures everything is shutdown and cleaned up on error/exit
...@@ -112,14 +105,37 @@ async def build_async_engine_client( ...@@ -112,14 +105,37 @@ async def build_async_engine_client(
# Backend itself still global for the silly lil' health handler # Backend itself still global for the silly lil' health handler
global async_engine_client global async_engine_client
async with build_async_engine_client_from_engine_args(
engine_args, args.disable_frontend_multiprocessing) as engine:
async_engine_client = engine # type: ignore[assignment]
yield engine
@asynccontextmanager
async def build_async_engine_client_from_engine_args(
engine_args: AsyncEngineArgs,
disable_frontend_multiprocessing: bool = False,
) -> AsyncIterator[Optional[AsyncEngineClient]]:
"""
Create AsyncEngineClient, either:
- in-process using the AsyncLLMEngine Directly
- multiprocess using AsyncLLMEngine RPC
Returns the Client or None if the creation failed.
"""
# If manually triggered or embedding model, use AsyncLLMEngine in process. # If manually triggered or embedding model, use AsyncLLMEngine in process.
# TODO: support embedding model via RPC. # TODO: support embedding model via RPC.
if (model_is_embedding(args.model, args.trust_remote_code, if (model_is_embedding(engine_args.model, engine_args.trust_remote_code,
args.quantization) engine_args.quantization)
or args.disable_frontend_multiprocessing): or disable_frontend_multiprocessing):
async_engine_client = AsyncLLMEngine.from_engine_args( engine_client = AsyncLLMEngine.from_engine_args(
engine_args, usage_context=UsageContext.OPENAI_API_SERVER) engine_args, usage_context=UsageContext.OPENAI_API_SERVER)
yield async_engine_client try:
yield engine_client
finally:
engine_client.shutdown_background_loop()
return return
# Otherwise, use the multiprocessing AsyncLLMEngine. # Otherwise, use the multiprocessing AsyncLLMEngine.
...@@ -148,7 +164,6 @@ async def build_async_engine_client( ...@@ -148,7 +164,6 @@ async def build_async_engine_client(
# NOTE: Actually, this is not true yet. We still need to support # NOTE: Actually, this is not true yet. We still need to support
# embedding models via RPC (see TODO above) # embedding models via RPC (see TODO above)
rpc_client = AsyncEngineRPCClient(rpc_path) rpc_client = AsyncEngineRPCClient(rpc_path)
async_engine_client = rpc_client # type: ignore
# Start RPCServer in separate process (holds the AsyncLLMEngine). # Start RPCServer in separate process (holds the AsyncLLMEngine).
context = multiprocessing.get_context("spawn") context = multiprocessing.get_context("spawn")
...@@ -174,7 +189,7 @@ async def build_async_engine_client( ...@@ -174,7 +189,7 @@ async def build_async_engine_client(
yield None yield None
return return
yield async_engine_client yield rpc_client # type: ignore[misc]
finally: finally:
# Ensure rpc server process was terminated # Ensure rpc server process was terminated
rpc_server_process.terminate() rpc_server_process.terminate()
......
...@@ -7,6 +7,7 @@ from uuid import uuid4 ...@@ -7,6 +7,7 @@ from uuid import uuid4
import cloudpickle import cloudpickle
import zmq import zmq
import zmq.asyncio import zmq.asyncio
from zmq import Frame # type: ignore[attr-defined]
from zmq.asyncio import Socket from zmq.asyncio import Socket
from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig, from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
...@@ -214,6 +215,7 @@ class AsyncEngineRPCClient: ...@@ -214,6 +215,7 @@ class AsyncEngineRPCClient:
# Await the data from the Server. # Await the data from the Server.
frame = await socket.recv(copy=False) frame = await socket.recv(copy=False)
assert isinstance(frame, Frame)
data = pickle.loads(frame.buffer) data = pickle.loads(frame.buffer)
if isinstance(data, Exception): if isinstance(data, Exception):
...@@ -247,6 +249,7 @@ class AsyncEngineRPCClient: ...@@ -247,6 +249,7 @@ class AsyncEngineRPCClient:
f"{self._data_timeout} ms") f"{self._data_timeout} ms")
frame = await socket.recv(copy=False) frame = await socket.recv(copy=False)
assert isinstance(frame, Frame)
return pickle.loads(frame.buffer) return pickle.loads(frame.buffer)
# Make a new socket connection. # Make a new socket connection.
...@@ -395,6 +398,7 @@ class AsyncEngineRPCClient: ...@@ -395,6 +398,7 @@ class AsyncEngineRPCClient:
# Stream back the results from the RPC Server. # Stream back the results from the RPC Server.
while not finished: while not finished:
message = await socket.recv(copy=False) message = await socket.recv(copy=False)
assert isinstance(message, Frame)
request_output = pickle.loads(message.buffer) request_output = pickle.loads(message.buffer)
if isinstance(request_output, Exception): if isinstance(request_output, Exception):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment