Unverified Commit 8c2ffaaf authored by hzh0425, committed by GitHub

fix(hicache-long-bench): adjust context workload generator to use full query set (#9847)


Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>
parent 20445327
@@ -31,9 +31,10 @@ class ContextWorkloadGenerator(WorkloadGenerator):
         self.completed_requests = 0
         self.dataset = json.load(open(args.dataset_path))
+        num_requests = min(args.num_clients, len(self.dataset["queries"]))
         init_requests = []
-        for i in range(min(args.num_clients, len(self.dataset["queries"]))):
+        for i in range(num_requests):
             context_id = self.dataset["queries"][i]["context"]
             init_requests.append(
                 (
@@ -52,13 +53,14 @@ class ContextWorkloadGenerator(WorkloadGenerator):
         self.ready_queue = ReadyQueue(init_requests=init_requests)
         self.response_queue = queue.Queue()
-        self.pbar = tqdm(total=args.num_clients * args.num_rounds)
+        self.pbar = tqdm(total=num_requests)
         self.performance_metrics = {
             "ttft": [],
             "latency": [],
             "itl": [],
             "prompt_len": [],
             "cached_tokens": [],
+            "generated_len": [],
         }
         self.max_parallel = args.max_parallel
@@ -75,6 +77,9 @@ class ContextWorkloadGenerator(WorkloadGenerator):
                 self.performance_metrics["ttft"].append(response.ttft)
                 self.performance_metrics["itl"].extend(response.itl)
                 self.performance_metrics["latency"].append(response.latency)
+                self.performance_metrics["prompt_len"].append(response.prompt_len)
+                self.performance_metrics["cached_tokens"].append(response.cached_tokens)
+                self.performance_metrics["generated_len"].append(response.generated_len)
                 self.completed_requests += 1
             except queue.Empty:
@@ -85,7 +90,7 @@ class ContextWorkloadGenerator(WorkloadGenerator):
 if __name__ == "__main__":
     args = parse_args()
     args.num_rounds = 1
-    args.max_parallel = 128
+    args.max_parallel = 24
     flush_cache_url = f"http://{args.host}:{args.port}/flush_cache"
     for request_rate in [24, 16, 12, 8, 4, 2, 1]:
...
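
Note: the first file is the HiCache long-bench context workload generator. The change factors the request count into num_requests so that the progress bar tracks the requests actually issued (one per dataset query, capped by the client count) rather than args.num_clients * args.num_rounds. A minimal standalone sketch of that sizing logic, assuming a dataset shaped like {"queries": [{"context": ...}, ...]} as in the hunks above (the function name and arguments are illustrative, not from the patch):

    import json

    from tqdm import tqdm

    def plan_requests(dataset_path, num_clients):
        # One request per query in the dataset, capped by the number of
        # simulated clients.
        dataset = json.load(open(dataset_path))
        num_requests = min(num_clients, len(dataset["queries"]))
        # The progress-bar total now matches the requests actually issued,
        # not num_clients * num_rounds.
        pbar = tqdm(total=num_requests)
        return dataset, num_requests, pbar
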
@@ -191,6 +191,7 @@ async def async_request_sglang_generate(
                 output.latency = latency
                 output.prompt_len = prompt_tokens
                 output.cached_tokens = cached_tokens
+                output.generated_len = len(output.itl) + 1
             else:
                 output.error = response.reason or ""
                 output.success = False
@@ -321,6 +322,7 @@ class WorkloadGenerator:
             "latency": [],
             "prompt_len": [],
             "cached_tokens": [],
+            "generated_len": [],
         }
         self.num_rounds = args.num_rounds
         self.max_parallel = args.max_parallel
@@ -383,6 +385,7 @@ class WorkloadGenerator:
                 self.performance_metrics["latency"].append(response.latency)
                 self.performance_metrics["prompt_len"].append(response.prompt_len)
                 self.performance_metrics["cached_tokens"].append(response.cached_tokens)
+                self.performance_metrics["generated_len"].append(response.generated_len)
                 self.completed_requests += 1
                 if self.client_records[client_id]["round"] < self.num_rounds:
@@ -418,6 +421,7 @@ class WorkloadGenerator:
         response_thread.join()
         self.pbar.close()
+        duration = self.finished_time - self.start_time
         performance_data = {
             "summary": {
                 "total_requests": len(self.performance_metrics["ttft"]),
@@ -438,7 +442,13 @@ class WorkloadGenerator:
                 "median_latency": sorted(self.performance_metrics["latency"])[
                     len(self.performance_metrics["latency"]) // 2
                 ],
-                "throughput": self.pbar.total / (self.finished_time - self.start_time),
+                "input_token_throughput": sum(self.performance_metrics["prompt_len"])
+                / duration,
+                "output_token_throughput": sum(
+                    self.performance_metrics["generated_len"]
+                )
+                / duration,
+                "throughput": self.pbar.total / duration,
                 "cache_hit_rate": (
                     0
                     if sum(self.performance_metrics["prompt_len"]) == 0
@@ -461,7 +471,13 @@ class WorkloadGenerator:
         print(f" P90 latency: {performance_data['summary']['p90_latency']:.2f}")
         print(f" Median latency: {performance_data['summary']['median_latency']:.2f}")
         print(
-            f" Throughput: {performance_data['summary']['throughput']:.2f} requests per second"
+            f" Input token throughput: {performance_data['summary']['input_token_throughput']:.2f} tokens per second"
+        )
+        print(
+            f" Output token throughput: {performance_data['summary']['output_token_throughput']:.2f} tokens per second"
+        )
+        print(
+            f" Request Throughput: {performance_data['summary']['throughput']:.2f} requests per second"
         )
         print(f" Cache Hit Rate: {performance_data['summary']['cache_hit_rate']:.6f}")
         return performance_data
...
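
Two details worth noting in the second file (the shared workload-generator harness). First, output.itl stores inter-token latencies, so a streamed completion with N output tokens produces N - 1 intervals; len(output.itl) + 1 therefore recovers the generated token count without re-tokenizing the response. Second, the summary now reports token throughput alongside request throughput. A minimal sketch of that arithmetic, assuming per-request metric lists like those collected above (the standalone function and parameter names are illustrative):

    def summarize_throughput(prompt_lens, generated_lens, start_time, finished_time, total_requests):
        # Each throughput is a run-wide total divided by wall-clock duration.
        # total_requests mirrors self.pbar.total in the patch, i.e. the
        # planned request count, not just the completed ones.
        duration = finished_time - start_time
        return {
            "input_token_throughput": sum(prompt_lens) / duration,      # prompt tokens/s
            "output_token_throughput": sum(generated_lens) / duration,  # generated tokens/s
            "throughput": total_requests / duration,                    # requests/s
        }

For example, 100 requests carrying 80,000 prompt tokens and 12,000 generated tokens over a 40-second run would report 2000.00 input tokens per second, 300.00 output tokens per second, and 2.50 requests per second.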