Unverified commit 0eec4cb6 authored by Zhiqiang Xie, committed by GitHub

HiCache: add long-context benchmark plus minor fixes (#9086)


Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
parent ff1f6825
import json
import queue
import time

import requests
from bench_multiturn import (
    ReadyQueue,
    WorkloadGenerator,
    gen_payload,
    log_to_jsonl_file,
    parse_args,
)
from tqdm.asyncio import tqdm

from sglang.bench_serving import get_tokenizer


class ContextWorkloadGenerator(WorkloadGenerator):
    def __init__(self, args):
        # Construct the base URL for requests
        self.baseurl = f"http://{args.host}:{args.port}/"
        self.url = self.baseurl + "generate"
        self.tokenizer = get_tokenizer(args.model_path)

        self.distribution = args.distribution
        self.request_rate = args.request_rate
        self.start_time = None
        self.finished_time = None

        self.sent_requests = 0
        self.completed_requests = 0

        self.dataset = json.load(open(args.dataset_path))

        # Build one request per client: prepend the shared long context to the
        # question and use the tokenized length of the reference answer as the
        # requested output length.
        init_requests = []
        for i in range(min(args.num_clients, len(self.dataset["queries"]))):
            context_id = self.dataset["queries"][i]["context"]
            init_requests.append(
                (
                    i,
                    gen_payload(
                        self.dataset["contexts"][context_id]
                        + self.dataset["queries"][i]["question"],
                        len(
                            self.tokenizer(
                                self.dataset["queries"][i]["reference_answer"]
                            )["input_ids"]
                        ),
                    ),
                )
            )

        self.ready_queue = ReadyQueue(init_requests=init_requests)
        self.response_queue = queue.Queue()
        self.pbar = tqdm(total=args.num_clients * args.num_rounds)
        self.performance_metrics = {
            "ttft": [],
            "latency": [],
            "itl": [],
            "prompt_len": [],
            "cached_tokens": [],
        }

        self.max_parallel = args.max_parallel
        self.logfile = args.log_file

    def response_handler(self):
        while True:
            try:
                client_id, response = self.response_queue.get(
                    timeout=10
                )  # Block until a response is available
                if not response.success:
                    raise ValueError(f"Request failed with error: {response.error}")
                self.performance_metrics["ttft"].append(response.ttft)
                self.performance_metrics["itl"].extend(response.itl)
                self.performance_metrics["latency"].append(response.latency)
                self.completed_requests += 1
            except queue.Empty:
                if self.pbar.n == self.pbar.total:
                    break


if __name__ == "__main__":
    args = parse_args()
    args.num_rounds = 1
    args.max_parallel = 128

    flush_cache_url = f"http://{args.host}:{args.port}/flush_cache"

    # Sweep request rates from high to low, flushing the server cache before
    # each run so the measurements are independent of one another.
    for request_rate in [24, 16, 12, 8, 4, 2, 1]:
        args.request_rate = request_rate
        requests.post(flush_cache_url)
        time.sleep(1)

        performance_data = ContextWorkloadGenerator(args).run()
        log_to_jsonl_file(performance_data, args.log_file, args.tag)
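For reference, the generator above reads its workload from a JSON file whose shape can be inferred from the field accesses in __init__. The following is a minimal, illustrative sketch of such a file; the file name, text, and the assumption that "context" is an index into the "contexts" list are mine, not part of this commit:

import json

# Hypothetical example of the dataset layout consumed by ContextWorkloadGenerator:
# dataset["contexts"][context_id] is prepended to dataset["queries"][i]["question"],
# and the tokenized length of "reference_answer" sets the requested output length.
example_dataset = {
    "contexts": [
        "A very long shared document that the server can cache across queries ...",
    ],
    "queries": [
        {
            "context": 0,  # index into "contexts" (assumed)
            "question": " What does the document say about topic X?",
            "reference_answer": "The document states that ...",
        },
    ],
}

with open("long_context_dataset.json", "w") as f:  # illustrative path
    json.dump(example_dataset, f)

The script handles the request-rate sweep and cache flushing itself; the remaining settings (server host and port, model path, dataset path, number of clients, log file, and tag) come from parse_args in bench_multiturn, assuming those arguments are exposed there.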
@@ -322,6 +322,9 @@ class WorkloadGenerator:
             "prompt_len": [],
             "cached_tokens": [],
         }
+        self.num_rounds = args.num_rounds
+        self.max_parallel = args.max_parallel
+        self.output_length = args.output_length

     async def handle_request(self, item):
         try:
@@ -336,7 +339,7 @@ class WorkloadGenerator:
     def request_sender(self):
         async def request_loop():
             while True:
-                if self.sent_requests - self.completed_requests < args.max_parallel:
+                if self.sent_requests - self.completed_requests < self.max_parallel:
                     new_request = self.ready_queue.pop()
                     if new_request:
                         asyncio.create_task(self.handle_request(new_request))
@@ -382,7 +385,7 @@ class WorkloadGenerator:
                 self.performance_metrics["cached_tokens"].append(response.cached_tokens)
                 self.completed_requests += 1

-                if self.client_records[client_id]["round"] < args.num_rounds:
+                if self.client_records[client_id]["round"] < self.num_rounds:
                     # append new request to client's history
                     self.client_records[client_id][
                         "history"
@@ -392,7 +395,7 @@ class WorkloadGenerator:
                             client_id,
                             gen_payload(
                                 self.client_records[client_id]["history"],
-                                args.output_length,
+                                self.output_length,
                             ),
                         )
                     )
@@ -461,7 +464,7 @@ class WorkloadGenerator:
             f" Throughput: {performance_data['summary']['throughput']:.2f} requests per second"
         )
         print(f" Cache Hit Rate: {performance_data['summary']['cache_hit_rate']:.6f}")
-        log_to_jsonl_file(performance_data, args.log_file, tag=args.tag)
+        return performance_data


 if __name__ == "__main__":
@@ -482,4 +485,5 @@ if __name__ == "__main__":
         args.request_rate = rate
         requests.post(flush_cache_url)
         time.sleep(1)
-        WorkloadGenerator(args).run()
+        performance_data = WorkloadGenerator(args).run()
+        log_to_jsonl_file(performance_data, args.log_file, tag=args.tag)
@@ -71,8 +71,10 @@ class HiRadixCache(RadixCache):
         self.tp_group = tp_cache_group
         self.tp_world_size = torch.distributed.get_world_size(group=self.tp_group)
         self.enable_storage = hicache_storage_backend is not None
-        # todo: customizable storage prefetch threshold
+        # todo: customizable storage prefetch threshold and timeout
         self.prefetch_threshold = 256
+        self.prefetch_timeout = 3  # seconds
+        self.prefetch_stop_policy = hicache_storage_prefetch_policy

         self.load_cache_event = threading.Event()
         self.cache_controller = HiCacheController(
@@ -87,13 +89,6 @@ class HiRadixCache(RadixCache):
             prefetch_threshold=self.prefetch_threshold,
         )

-        self.prefetch_stop_policy = hicache_storage_prefetch_policy
-        # todo: customizable storage prefetch timeout
-        self.prefetch_timeout = 3  # seconds
-        logger.info(
-            f"HiCache storage prefetch policy: {hicache_storage_prefetch_policy}"
-        )
-
         # record the nodes with ongoing write through
         self.ongoing_write_through = {}
         # record the node segments with ongoing load back