Unverified Commit 60769be1 authored by Yunmeng, committed by GitHub

Add concurrency option for benchmark (#2136)

parent a78d8f8d
@@ -859,6 +859,7 @@ async def benchmark(
     tokenizer: PreTrainedTokenizerBase,
     input_requests: List[Tuple[str, int, int]],
     request_rate: float,
+    max_concurrency: Optional[int],
     disable_tqdm: bool,
     extra_request_body: Dict[str, Any],
     profile: bool,
@@ -868,6 +869,15 @@ async def benchmark(
     else:
         raise ValueError(f"Unknown backend: {backend}")
 
+    # From https://github.com/vllm-project/vllm/pull/9390
+    semaphore = asyncio.Semaphore(max_concurrency) if max_concurrency else None
+
+    async def limited_request_func(request_func_input, pbar):
+        if semaphore is None:
+            return await request_func(request_func_input=request_func_input, pbar=pbar)
+        async with semaphore:
+            return await request_func(request_func_input=request_func_input, pbar=pbar)
+
     print("Starting initial single prompt test run...")
     test_prompt, test_prompt_len, test_output_len = input_requests[0]
     test_input = RequestFuncInput(
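The hunk above is the heart of the change: a single `asyncio.Semaphore` caps how many request coroutines may run concurrently, and `max_concurrency=None` bypasses the gate entirely. A minimal, self-contained sketch of the same pattern (the `fake_request` coroutine is hypothetical, standing in for the real backend request function):

```python
import asyncio
from typing import Optional

async def fake_request(i: int) -> int:
    # Stand-in for the real HTTP request coroutine.
    await asyncio.sleep(0.1)
    return i

async def run_all(num_requests: int, max_concurrency: Optional[int]) -> None:
    # None means "no cap": every request runs as soon as it is scheduled.
    semaphore = asyncio.Semaphore(max_concurrency) if max_concurrency else None

    async def limited(i: int) -> int:
        if semaphore is None:
            return await fake_request(i)
        async with semaphore:  # at most `max_concurrency` bodies in flight
            return await fake_request(i)

    results = await asyncio.gather(*(limited(i) for i in range(num_requests)))
    print(f"completed {len(results)} requests")

asyncio.run(run_all(32, max_concurrency=8))
```

Note that tasks are still created at the configured arrival times; excess requests simply queue on the semaphore, which is exactly why the realized request rate can drop below `--request-rate` (see the help text added below).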
@@ -913,7 +923,7 @@ async def benchmark(
         )
         tasks.append(
             asyncio.create_task(
-                request_func(request_func_input=request_func_input, pbar=pbar)
+                limited_request_func(request_func_input=request_func_input, pbar=pbar)
             )
         )
     outputs: List[RequestFuncOutput] = await asyncio.gather(*tasks)
@@ -940,6 +950,12 @@ async def benchmark(
     print("\n{s:{c}^{n}}".format(s=" Serving Benchmark Result ", n=50, c="="))
     print("{:<40} {:<10}".format("Backend:", backend))
     print("{:<40} {:<10}".format("Traffic request rate:", request_rate))
+    print(
+        "{:<40} {:<10}".format(
+            "Max request concurrency:",
+            max_concurrency if max_concurrency else "not set",
+        )
+    )
     print("{:<40} {:<10}".format("Successful requests:", metrics.completed))
     print("{:<40} {:<10.2f}".format("Benchmark duration (s):", benchmark_duration))
     print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input))
@@ -1003,6 +1019,7 @@ async def benchmark(
         "backend": args.backend,
         "dataset_name": args.dataset_name,
         "request_rate": request_rate,
+        "max_concurrency": max_concurrency,
         "total_input_tokens": metrics.total_input,
         "total_output_tokens": metrics.total_output,
         "total_output_tokens_retokenized": metrics.total_output_retokenized,
@@ -1201,6 +1218,7 @@ def run_benchmark(args_: argparse.Namespace):
             tokenizer=tokenizer,
             input_requests=input_requests,
             request_rate=args.request_rate,
+            max_concurrency=args.max_concurrency,
             disable_tqdm=args.disable_tqdm,
             extra_request_body=extra_request_body,
             profile=args.profile,
@@ -1220,6 +1238,7 @@ def run_benchmark(args_: argparse.Namespace):
                 tokenizer=tokenizer,
                 input_requests=input_requests,
                 request_rate=rate,
+                max_concurrency=args.max_concurrency,
                 disable_tqdm=args.disable_tqdm,
                 extra_request_body=extra_request_body,
                 profile=args.profile,
@@ -1319,6 +1338,19 @@ if __name__ == "__main__":
         help="Number of requests per second. If this is inf, then all the requests are sent at time 0. "
         "Otherwise, we use Poisson process to synthesize the request arrival times. Default is inf.",
     )
+    parser.add_argument(
+        "--max-concurrency",
+        type=int,
+        default=None,
+        help="Maximum number of concurrent requests. This can be used "
+        "to help simulate an environment where a higher level component "
+        "is enforcing a maximum number of concurrent requests. While the "
+        "--request-rate argument controls the rate at which requests are "
+        "initiated, this argument will control how many are actually allowed "
+        "to execute at a time. This means that when used in combination, the "
+        "actual request rate may be lower than specified with --request-rate, "
+        "if the server is not processing requests fast enough to keep up.",
+    )
     parser.add_argument("--seed", type=int, default=1, help="The random seed.")
     parser.add_argument(
         "--multi",
...
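The caveat in the new help text, that the realized request rate may fall below `--request-rate` once `--max-concurrency` becomes the bottleneck, is easy to demonstrate. A rough, self-contained simulation follows (all names and numbers here are illustrative, not part of the benchmark script):

```python
import asyncio
import random
import time

async def simulate(request_rate: float, max_concurrency: int,
                   num_requests: int, service_time: float) -> None:
    semaphore = asyncio.Semaphore(max_concurrency)

    async def one_request() -> None:
        async with semaphore:  # queue here once the cap is reached
            await asyncio.sleep(service_time)  # stand-in for server latency

    start = time.perf_counter()
    tasks = []
    for _ in range(num_requests):
        tasks.append(asyncio.create_task(one_request()))
        # Poisson arrivals: exponential inter-arrival times, as in the benchmark.
        await asyncio.sleep(random.expovariate(request_rate))
    await asyncio.gather(*tasks)
    elapsed = time.perf_counter() - start

    # With a 1 s service time and a cap of 4, throughput tops out near 4 req/s
    # regardless of how fast arrivals were generated.
    print(f"arrival rate: {request_rate:.1f}/s, "
          f"effective rate: {num_requests / elapsed:.2f}/s")

asyncio.run(simulate(request_rate=16.0, max_concurrency=4,
                     num_requests=64, service_time=1.0))
```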