"tests/vscode:/vscode.git/clone" did not exist on "8d6487f3cbe89bb6e32f82fc9f04df6ce001ef24"
Unverified Commit 13387e6b authored by Zhiqiang Xie's avatar Zhiqiang Xie Committed by GitHub
Browse files

Multi-turn benchmark for hierarchical caching (#2942)

parent 120c3634
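"""
Multi-turn benchmark for hierarchical caching.

Each simulated client carries its conversation history forward: every round
appends the previous response plus a fresh prompt, so later rounds share a
long prefix with earlier ones and exercise the server's hierarchical cache.

A usage sketch (the script filename below is illustrative, not part of the
commit):

    python bench_multiturn.py --num-clients 200 --num-rounds 5 \
        --request-length 512 --output-length 64 --port 30000
"""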
import argparse
import asyncio
import json
import queue
import random
import threading
import time
from typing import Optional

import aiohttp
import requests
from tqdm.asyncio import tqdm

from sglang.bench_serving import (
    RequestFuncOutput,
    get_tokenizer,
    remove_prefix,
    sample_random_requests,
)


def parse_args():
    parser = argparse.ArgumentParser(
        description="Script to benchmark concurrent requests to a server."
    )
    parser.add_argument(
        "--num-clients",
        type=int,
        default=200,
        help="Number of concurrent clients",
    )
    parser.add_argument(
        "--request-length",
        type=int,
        default=512,
        help="Length of each new request",
    )
    parser.add_argument(
        "--output-length",
        type=int,
        default=64,
        help="Length of each output",
    )
    parser.add_argument(
        "--num-rounds",
        type=int,
        default=5,
        help="Number of rounds per client",
    )
    parser.add_argument(
        "--distribution",
        type=str,
        default="poisson",
        choices=["poisson", "uniform"],
        help="Distribution type for request intervals (poisson or uniform)",
    )
    parser.add_argument(
        "--request-rate",
        type=float,
        default=1.0,
        help="Average number of requests per second",
    )
    parser.add_argument(
        "--host",
        type=str,
        default="localhost",
        help="Server hostname or IP (default: localhost)",
    )
    parser.add_argument(
        "--port",
        type=int,
        default=30000,
        help="Server port (default: 30000)",
    )
    parser.add_argument(
        "--model",
        type=str,
        default="meta-llama/Llama-3.1-8B-Instruct",
        help="Model path compatible with Hugging Face Transformers",
    )
    return parser.parse_args()


async def async_request_sglang_generate(
    payload,
    url,
    pbar: Optional[tqdm] = None,
):
    """
    Sends a streaming request to the server and gathers the generated text
    chunk by chunk, recording TTFT and inter-token latencies.
    """
    async with aiohttp.ClientSession() as session:
        headers = {}
        generated_text = ""
        ttft = 0.0
        # Initialized so latency is defined even if the stream yields no chunks.
        latency = 0.0
        st = time.perf_counter()
        most_recent_timestamp = st
        output = RequestFuncOutput()
        try:
            async with session.post(url=url, json=payload, headers=headers) as response:
                if response.status == 200:
                    async for chunk_bytes in response.content:
                        chunk_bytes = chunk_bytes.strip()
                        if not chunk_bytes:
                            continue

                        chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data: ")
                        latency = time.perf_counter() - st
                        if chunk != "[DONE]":
                            data = json.loads(chunk)
                            if data["text"]:
                                timestamp = time.perf_counter()
                                # First token: record time to first token (TTFT).
                                if ttft == 0.0:
                                    ttft = timestamp - st
                                    output.ttft = ttft
                                # Decoding phase: record inter-token latency.
                                else:
                                    output.itl.append(timestamp - most_recent_timestamp)
                                most_recent_timestamp = timestamp
                                generated_text = data["text"]

                    output.generated_text = generated_text
                    output.success = True
                    output.latency = latency
                else:
                    output.error = response.reason or ""
                    output.success = False
        except Exception as e:
            output.success = False
            output.error = str(e)
            print(f"Request failed: {e}")

    if pbar:
        pbar.update(1)
    return output


def gen_payload(prompt, output_len):
    payload = {
        "text": prompt,
        "sampling_params": {
            "temperature": 0.0,
            "max_new_tokens": output_len,
            "ignore_eos": True,
        },
        "stream": True,
        "lora_path": "",
        "return_logprob": False,
        "logprob_start_len": -1,
    }
    return payload
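
# A minimal standalone sketch of the two helpers above, assuming an sglang
# server is already listening on the default host/port (the prompt string is
# illustrative):
#
#   payload = gen_payload("The capital of France is", 16)
#   out = asyncio.run(
#       async_request_sglang_generate(payload, "http://localhost:30000/generate")
#   )
#   print(out.success, out.ttft, out.generated_text)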


class ReadyQueue:
    """
    Thread-safe queue that pops requests in different orders depending on the
    chosen policy.
    """

    def __init__(self, init_requests=None, policy="random"):
        self.lock = threading.Lock()
        self.requests = init_requests or []
        self.policy = policy

    def append(self, item):
        with self.lock:
            self.requests.append(item)

    def pop(self):
        with self.lock:
            if not self.requests:
                return None
            if self.policy == "random":
                index = random.randrange(len(self.requests))
                return self.requests.pop(index)
            elif self.policy == "fifo":
                return self.requests.pop(0)
            else:
                # TODO: model varying client think times via other policies.
                raise ValueError(f"{self.policy} not implemented")
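
# Usage sketch (values illustrative):
#
#   rq = ReadyQueue(init_requests=[(0, "a"), (1, "b")], policy="fifo")
#   rq.append((2, "c"))
#   rq.pop()  # -> (0, "a") under "fifo"; a random element under "random"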


class WorkloadGenerator:
    def __init__(self, args):
        # Construct the base URL for requests
        self.url = f"http://{args.host}:{args.port}/generate"
        self.args = args

        self.tokenizer = get_tokenizer(args.model)
        self.distribution = args.distribution
        self.request_rate = args.request_rate
        self.start_time = None
        self.finished_time = None

        self.candidate_inputs = sample_random_requests(
            input_len=args.request_length,
            output_len=args.output_length,
            num_prompts=args.num_clients * args.num_rounds,
            range_ratio=1.0,
            tokenizer=self.tokenizer,
            dataset_path="",
        )
        # Keep only the prompt text from the sampled requests.
        self.candidate_inputs = [i[0] for i in self.candidate_inputs]

        # Each client starts with one fresh prompt; the remaining candidate
        # inputs are consumed as later rounds extend each client's history.
        init_requests = [
            (i, gen_payload(self.candidate_inputs[i], args.output_length))
            for i in range(args.num_clients)
        ]
        self.client_records = {
            i: {"round": 0, "history": init_requests[i][1]["text"]}
            for i in range(args.num_clients)
        }
        self.ready_queue = ReadyQueue(init_requests=init_requests)
        self.candidate_inputs = self.candidate_inputs[args.num_clients :]

        self.response_queue = queue.Queue()
        self.pbar = tqdm(total=args.num_clients * args.num_rounds)
        self.performance_metrics = {"ttft": [], "latency": []}

    async def handle_request(self, item):
        try:
            client_id, payload = item
            response = await async_request_sglang_generate(payload, self.url, self.pbar)
            if self.pbar.n == self.pbar.total:
                self.finished_time = time.time()
            self.response_queue.put((client_id, response))
        except Exception as e:
            print(f"Request failed: {e}")

    def request_sender(self):
        async def request_loop():
            while True:
                # Draw the wait time before the next request from the chosen
                # interarrival distribution.
                if self.distribution == "poisson":
                    sleep_time = random.expovariate(self.request_rate)
                elif self.distribution == "uniform":
                    avg_interval = (
                        1.0 / self.request_rate if self.request_rate > 0 else 1.0
                    )
                    sleep_time = random.uniform(0, 2 * avg_interval)
                else:
                    raise ValueError("Invalid distribution type")
                await asyncio.sleep(sleep_time)  # Wait before sending the next request

                new_request = self.ready_queue.pop()
                # Submit async request
                if new_request:
                    asyncio.create_task(self.handle_request(new_request))
                else:
                    if self.pbar.n == self.pbar.total:
                        break

        # Create and run the event loop for asynchronous requests
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        loop.run_until_complete(request_loop())
        loop.close()
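
    # Interarrival sanity check: with request_rate = r, expovariate(r) has
    # mean 1/r (exponential gaps, i.e. a Poisson arrival process), and
    # uniform(0, 2/r) also has mean 1/r, so both branches target the same
    # average request rate.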

    def response_handler(self):
        while True:
            try:
                client_id, response = self.response_queue.get(
                    timeout=10
                )  # Block until a response is available
                if not response.success:
                    raise ValueError(f"Request failed with error: {response.error}")

                # Append the generated text so the next round's prompt carries
                # the full conversation history.
                self.client_records[client_id]["history"] += response.generated_text
                self.client_records[client_id]["round"] += 1
                self.performance_metrics["ttft"].append(response.ttft)
                self.performance_metrics["latency"].append(response.latency)

                if self.client_records[client_id]["round"] < self.args.num_rounds:
                    self.client_records[client_id][
                        "history"
                    ] += self.candidate_inputs.pop()
                    self.ready_queue.append(
                        (
                            client_id,
                            gen_payload(
                                self.client_records[client_id]["history"],
                                self.args.output_length,
                            ),
                        )
                    )
            except queue.Empty:
                if self.pbar.n == self.pbar.total:
                    break

    def run(self):
        request_thread = threading.Thread(target=self.request_sender, daemon=True)
        response_thread = threading.Thread(target=self.response_handler, daemon=True)

        self.start_time = time.time()
        request_thread.start()
        response_thread.start()
        request_thread.join()
        response_thread.join()
        self.pbar.close()

        print("All requests completed.")
        print("Performance metrics summary:")
        print(
            f"  Total requests: {len(self.performance_metrics['ttft'])} at {self.request_rate} requests per second"
        )
        print(
            f"  Average TTFT: {sum(self.performance_metrics['ttft']) / len(self.performance_metrics['ttft']):.2f} s"
        )
        print(
            f"  Median TTFT: {sorted(self.performance_metrics['ttft'])[len(self.performance_metrics['ttft']) // 2]:.2f} s"
        )
        print(
            f"  Average latency: {sum(self.performance_metrics['latency']) / len(self.performance_metrics['latency']):.2f} s"
        )
        print(
            f"  Median latency: {sorted(self.performance_metrics['latency'])[len(self.performance_metrics['latency']) // 2]:.2f} s"
        )
        throughput = self.pbar.total / (self.finished_time - self.start_time)
        print(f"Throughput: {throughput:.2f} requests per second")


if __name__ == "__main__":
    args = parse_args()
    flush_cache_url = f"http://{args.host}:{args.port}/flush_cache"

    # Sweep request rates, flushing the server cache before each run so every
    # rate starts from a cold cache.
    for request_rate in range(1, 41, 2):
        args.request_rate = request_rate
        requests.post(flush_cache_url)
        WorkloadGenerator(args).run()