Unverified commit ff45ab7a, authored by zhongwei, committed by GitHub

[Benchmark] add disable-auto-run param for hicache/bench_multiturn (#7822)


Co-authored-by: zhongwei.ren <zhongwei.ren@bytedance.com>
Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>
parent 0f8b5386
@@ -9,6 +9,7 @@ from datetime import datetime
 from typing import Optional
 
 import aiohttp
+import numpy as np
 import requests
 from tqdm.asyncio import tqdm
@@ -97,6 +98,30 @@ def parse_args():
         default="performance_metrics.jsonl",
         help="File to log performance metrics",
     )
+    parser.add_argument(
+        "--disable-auto-run",
+        action="store_true",
+        help="If set, disable automatically testing with a range of request rates.",
+    )
+    parser.add_argument(
+        "--disable-random-sample",
+        action="store_true",
+        help="If set, disable random sampling of requests from the ShareGPT dataset.",
+    )
+    parser.add_argument(
+        "--sub-question-input-length",
+        type=int,
+        default=0,
+        help="Length of the sub-question input for each request; if 0, use request_length.",
+    )
+    parser.add_argument(
+        "--ready-queue-policy",
+        type=str,
+        default="random",
+        help="Policy for popping requests from the ready queue (random or fifo).",
+    )
+    parser.add_argument("--seed", type=int, default=1, help="The random seed.")
     return parser.parse_args()
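To make the new switches concrete, here is a minimal, self-contained sketch of how they parse. The fragment mirrors the arguments added above; the pre-existing `--request-rate` flag and its default are assumptions, since they are not part of this diff.

```python
# Standalone sketch of the new flags; --request-rate is assumed to pre-exist.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--request-rate", type=float, default=1.0)  # assumed, not in this diff
parser.add_argument("--disable-auto-run", action="store_true")
parser.add_argument("--ready-queue-policy", type=str, default="random")
parser.add_argument("--seed", type=int, default=1)

# Roughly: python bench_multiturn.py --disable-auto-run --request-rate 8
args = parser.parse_args(["--disable-auto-run", "--request-rate", "8"])
assert args.disable_auto_run is True        # store_true flips the False default
assert args.ready_queue_policy == "random"  # defaults apply when a flag is omitted
```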
@@ -234,13 +259,29 @@ class WorkloadGenerator:
         self.candidate_inputs = sample_random_requests(
             input_len=args.request_length,
             output_len=args.output_length,
-            num_prompts=args.num_clients * args.num_rounds,
+            num_prompts=args.num_clients,
             range_ratio=1.0,
             tokenizer=self.tokenizer,
             dataset_path=args.dataset_path,
+            random_sample=not args.disable_random_sample,
         )
         self.candidate_inputs = [i.prompt for i in self.candidate_inputs]
 
+        if args.sub_question_input_length != 0:
+            sub_question_input_length = args.sub_question_input_length
+        else:
+            sub_question_input_length = args.request_length
+        self.sub_question_inputs = sample_random_requests(
+            input_len=sub_question_input_length,
+            output_len=args.output_length,
+            num_prompts=args.num_clients * max(args.num_rounds - 1, 1),
+            range_ratio=1.0,
+            tokenizer=self.tokenizer,
+            dataset_path=args.dataset_path,
+            random_sample=not args.disable_random_sample,
+        )
+
         init_requests = [
             (i, gen_payload(self.candidate_inputs[i], args.output_length))
             for i in range(args.num_clients)
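The two sampling sizes above are easy to misread, so here is a worked example; the `num_clients` and `num_rounds` values are illustrative, not taken from the benchmark's defaults.

```python
# Worked example of the prompt-pool sizes used above (illustrative values).
num_clients, num_rounds = 4, 3

first_turns = num_clients                          # one opening prompt per client
follow_ups = num_clients * max(num_rounds - 1, 1)  # sub-questions for later rounds
assert (first_turns, follow_ups) == (4, 8)

# The max(..., 1) guard keeps the sub-question pool non-empty for one-round runs:
assert 4 * max(1 - 1, 1) == 4
```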
@@ -249,7 +290,9 @@ class WorkloadGenerator:
             i: {"round": 0, "history": init_requests[i][1]["text"]}
             for i in range(args.num_clients)
         }
-        self.ready_queue = ReadyQueue(init_requests=init_requests)
+        self.ready_queue = ReadyQueue(
+            init_requests=init_requests, policy=args.ready_queue_policy
+        )
         self.candidate_inputs = self.candidate_inputs[args.num_clients :]
         self.response_queue = queue.Queue()
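The ReadyQueue implementation itself is not part of this diff, so the effect of `policy` is only implied. Below is a minimal sketch of what a policy-aware queue could look like, with hypothetical names; the real class presumably also differs in detail, but it must at least synchronize access, since the sender and response-handler threads both touch the queue.

```python
# Hypothetical sketch only; the actual ReadyQueue in bench_multiturn is not shown here.
import random
import threading


class ReadyQueueSketch:
    def __init__(self, init_requests=None, policy="random"):
        if policy not in ("random", "fifo"):
            raise ValueError(f"unknown ready-queue policy: {policy}")
        self.policy = policy
        self.items = list(init_requests or [])
        self.lock = threading.Lock()  # sender and handler threads share the queue

    def append(self, item):
        with self.lock:
            self.items.append(item)

    def pop(self):
        with self.lock:
            if self.policy == "fifo":
                return self.items.pop(0)  # oldest request first
            return self.items.pop(random.randrange(len(self.items)))  # uniform pick
```

A "random" pop interleaves clients unpredictably, which stresses the cache with irregular reuse distances; "fifo" keeps each client's turns evenly spaced, which is friendlier to a multi-turn KV cache.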
@@ -314,9 +357,10 @@ class WorkloadGenerator:
                 self.completed_requests += 1
 
                 if self.client_records[client_id]["round"] < args.num_rounds:
+                    # append the next sub-question to this client's history
                     self.client_records[client_id][
                         "history"
-                    ] += self.candidate_inputs.pop()
+                    ] += self.sub_question_inputs.pop()
                     self.ready_queue.append(
                         (
                             client_id,
@@ -329,6 +373,9 @@ class WorkloadGenerator:
             except queue.Empty:
                 if self.pbar.n == self.pbar.total:
                     break
+            except ValueError as e:
+                print(f"Error processing response for client {client_id}: {e}")
+                continue
 
     def run(self):
         request_thread = threading.Thread(target=self.request_sender, daemon=True)
@@ -388,8 +435,18 @@ if __name__ == "__main__":
     args = parse_args()
     flush_cache_url = f"http://{args.host}:{args.port}/flush_cache"
 
-    for request_rate in [16, 14, 12, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1]:
-        args.request_rate = request_rate
+    random.seed(args.seed)
+    np.random.seed(args.seed)
+
+    if args.disable_auto_run:
+        print("Running with specified request rate...")
+        request_rates = [args.request_rate]
+    else:
+        print("Auto-running with different request rates...")
+        request_rates = [16, 14, 12, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1]
+
+    for rate in request_rates:
+        args.request_rate = rate
         requests.post(flush_cache_url)
         time.sleep(1)
         WorkloadGenerator(args).run()
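Seeding both `random` and `np.random` is deliberate: the numpy-backed request sampling and any stdlib-`random` behavior (e.g. a "random" ready-queue policy) draw from different generators, so fixing only one would leave runs partially nondeterministic. A quick reproducibility check, independent of the benchmark:

```python
# Minimal check that seeding both RNGs pins down a run end to end.
import random

import numpy as np


def seeded_draws(seed: int):
    random.seed(seed)     # e.g. the "random" ready-queue policy
    np.random.seed(seed)  # e.g. numpy-based request sampling
    return random.random(), float(np.random.rand())


assert seeded_draws(1) == seeded_draws(1)  # same seed -> same draws
```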