Unverified Commit 25482edb authored by Yueyang Pan's avatar Yueyang Pan Committed by GitHub
Browse files

Online serving benchmarks of real datasets for hierarchical KV caching (#3211)


Co-authored-by: default avatarZhiqiang Xie <xiezhq@stanford.edu>
parent 62b362b1
......@@ -22,4 +22,70 @@ python bench_multiturn.py --model-path Qwen/Qwen2.5-14B-Instruct
Note: The performance gain of hierarchical caching depends on the ratio of reusable tokens to GPU memory capacity. The more tokens to be reused, the larger the model, and the more constrained the GPU memory size, the greater the benefit one can expect from hierarchical caching.
## More benchmarks to be added
# Benchmark with more datasets
## Download Dataset
```bash
./download.sh {sharegpt|ultragpt|loogle|nextqa|all}
```
This script will automatically download the required dataset to the current working directory
## Multiturn Benchmark
### Supported Datasets
- sharegpt
- ultrachat
- loogle
### Example Usage:
```bash
python3 bench_serving.py --model mistralai/Mistral-7B-Instruct-v0.3 --backend sglang \
--dataset-path longdep_qa.json --dataset-name loogle --request-rate 10 --num-prompts 10 \
--port 8001 --enable-multiturn --disable-shuffle
```
This uses `mistralai/Mistral-7B-Instruct-v0.3` model with `sglang` as backend. The dataset
is `longdep_qa.json`. We send `10 conversations` with `10 req/s` to port 8001. We enable
multiturn chat without shuffling the order of conversations (i.e. following the original
order in the dataset file).
### Note:
The requests of multiple conversations are sent in a round robin fashion.
For example, if we have 3 conversations A, B, C whose rounds are `[2, 3, 4]` correspondingly,
multiturn chat will send the requests to the backend in the following order: `[A1, B1, C1, A2, B2, C2, B3, C3, C4]`
This has implications on the cache reuse patterns: the cache reuse distance is the largest
under this request pattern (which means a prefix-aware local scheduler in the backend can
yield the most benefit compared to a FIFO scheduler)
## Shared Prefix Benchmark
### Supported Datasets
- loogle
### Example Usage:
```bash
python3 bench_serving.py --model mistralai/Mistral-7B-Instruct-v0.3 --backend sglang \
--dataset-path longdep_qa.json --dataset-name loogle --request-rate 10 --num-prompts 10 \
--port 8001 --enable-shared-prefix --disable-shuffle
```
### Note:
Shared Prefix benchmark sends the questions for the same prompt together. For example,
if we have 3 shared prefix A, B, C, which have [2, 3, 4] questions correspondingly,
the shared prefix benchmark will send the requests to the
backend in the following order: `[A+Q1, A+Q2, B+Q1, B+Q2, B+Q3, C+Q1, C+Q2, C+Q3]`.
## Multi Modality Benchmark (WIP)
### Supported Datasets:
- nextqa
### Example Usage:
```bash
Server:
python3 -m sglang.launch_server --model-path lmms-lab/LLaVA-NeXT-Video-7B --tp 2 --dp 1 --port 8001 \
--host 0.0.0.0 --mem-fraction-static 0.9 --tokenizer-path llava-hf/llava-1.5-7b-hf \
--json-model-override-args "{\"architectures\": [\"LlavaVidForCausalLM\"], \"model_type\":\"llava\", \"mm_spatial_pool_stride\":2}"
Client:
python3 bench_serving.py --model lmms-lab/LLaVA-NeXT-Video-7B --backend sglang --dataset-path \
NExTVideo --dataset-name nextqa --request-rate 10 --num-prompts 1 --disable-shuffle --port 8001 \ --enable-multiturn --max-frames 16 --tokenizer llava-hf/llava-1.5-7b-hf --fixed-output-len 2048
```
Note: for the server args, `tokenizer-path`, overriding architecture are necessary.
## Supported Backend
- sglang (oai)
- vllm (oai)
- lmdeploy (oai)
# Adapted from https://github.com/vllm-project/vllm/blob/6366efc67b0aedd2c1721c14385370e50b297fb3/benchmarks/backend_request_func.py
# Adapted from https://github.com/vllm-project/vllm/blob/6366efc67b0aedd2c1721c14385370e50b297fb3/benchmarks/benchmark_serving.py
"""
Benchmark online serving with dynamic requests.
Usage:
python3 -m sglang.bench_serving --backend sglang --num-prompt 10
python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompts 3000 --random-input 1024 --random-output 1024 --random-range-ratio 0.5
python3 -m sglang.bench_serving --backend sglang --dataset-name random --request-rate-range 1,2,4,8,16,32 --random-input 4096 --random-output 1024 --random-range-ratio 0.125 --multi
"""
import argparse
import asyncio
import json
import os
import random
import sys
import time
import traceback
import warnings
from argparse import ArgumentParser
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union
import aiohttp
import numpy as np
import requests
from data_processing import MsgContent, SampleOutput, get_dataset
from tqdm.asyncio import tqdm
from transformers import PreTrainedTokenizerBase
from sglang.bench_serving import get_tokenizer, remove_prefix, set_ulimit
AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=20 * 60 * 60)
global args
@dataclass
class RequestFuncInput:
prompts: List[Tuple[MsgContent, int, int]]
api_url: str
model: str
lora_name: str
extra_request_body: Dict[str, Any]
# For multiturn chat, store the context
prev_messages: List = field(default_factory=list)
finished_prompts: int = 0
@dataclass
class RequestFuncOutput:
generated_text: List[str] = field(default_factory=list)
prompt_len: List[int] = field(default_factory=list)
output_len: List[int] = field(default_factory=list)
latency: List[float] = field(default_factory=list)
ttft: List[float] = field(default_factory=list)
itl: List[float] = field(default_factory=list) # List of inter-token latencies
success: bool = False
error: str = ""
# set ignore_eos True by default
async def async_request_openai_completions(
request_func_input: RequestFuncInput,
queue: asyncio.Queue,
tokenizer: PreTrainedTokenizerBase,
pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
api_url = request_func_input.api_url
assert api_url.endswith(
"completions"
), "OpenAI Completions API URL must end with 'completions'."
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
payload = {
"model": request_func_input.model,
"temperature": 0.0,
"best_of": 1,
"stream": not args.disable_stream,
"stream_options": {"include_usage": True},
"ignore_eos": not args.disable_ignore_eos,
**request_func_input.extra_request_body,
}
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
}
output = RequestFuncOutput()
prompt_idx = request_func_input.finished_prompts
messages = request_func_input.prev_messages
prompt, input_len, max_tokens = request_func_input.prompts[prompt_idx]
prompt_len = sum(
prompt[1] + prompt[2] # input_len + output_len
for prompt in request_func_input.prompts[:prompt_idx]
)
prompt_len += input_len
# Messages
messages.append(
{
"role": "user",
"content": prompt,
}
)
payload["messages"] = messages
payload["max_tokens"] = max_tokens
# output.prompt_len = request_func_input.prompt_len
# print(payload)
generated_text = ""
ttft = 0.0
st = time.perf_counter()
most_recent_timestamp = st
try:
async with session.post(
url=api_url, json=payload, headers=headers
) as response:
if response.status == 200:
actual_prompt_len = prompt_len - 1
actual_output_len = 0
async for chunk_bytes in response.content:
chunk_bytes = chunk_bytes.strip()
if not chunk_bytes:
continue
chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data: ")
latency = time.perf_counter() - st
if chunk == "[DONE]":
pass
else:
data = json.loads(chunk)
timestamp = time.perf_counter()
# NOTE: Some completion API might have a last
# usage summary response without a token so we
# want to check a token was generated
if data["usage"] is not None and len(data["usage"]) > 0:
actual_prompt_len = data["usage"]["prompt_tokens"]
actual_output_len = data["usage"]["completion_tokens"]
continue
delta = data["choices"][0]["delta"]
if delta.get("content", None):
# First token
if ttft == 0.0:
ttft = time.perf_counter() - st
output.ttft.append(ttft)
# Decoding phase
else:
output.itl.append(timestamp - most_recent_timestamp)
generated_text += delta["content"]
most_recent_timestamp = timestamp
output.prompt_len.append(actual_prompt_len) # truncate <s>
output.output_len.append(actual_output_len)
output.generated_text.append(generated_text)
output.success = True
output.latency.append(latency)
# Prepare for the new request
request_func_input.prompts[prompt_idx] = (
prompt,
input_len,
actual_output_len, # changes from max_tokens to output_len
)
prompt_idx += 1
messages.append(
{
"role": "assistant",
"content": generated_text,
}
)
# Move the new request to the end of the queue
if prompt_idx < len(request_func_input.prompts):
request_func_input.finished_prompts = prompt_idx
request_func_input.prev_messages = messages
await queue.put(request_func_input)
else:
output.error = response.reason or ""
output.success = False
except Exception:
output.success = False
exc_info = sys.exc_info()
output.error = "".join(traceback.format_exception(*exc_info))
if pbar:
pbar.update(1)
return output
async def async_request_profile(api_url: str) -> RequestFuncOutput:
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
output = RequestFuncOutput()
try:
async with session.post(url=api_url) as response:
if response.status == 200:
output.success = True
else:
output.error = response.reason or ""
output.success = False
except Exception:
output.success = False
exc_info = sys.exc_info()
output.error = "".join(traceback.format_exception(*exc_info))
return output
ASYNC_REQUEST_FUNCS = {
"sglang": async_request_openai_completions,
"vllm": async_request_openai_completions,
"lmdeploy": async_request_openai_completions,
}
@dataclass
class BenchmarkMetrics:
completed: int
total_input: int
total_output: int
total_output_retokenized: int
request_throughput: float
input_throughput: float
output_throughput: float
output_throughput_retokenized: float
total_throughput: float
total_throughput_retokenized: float
mean_ttft_ms: float
median_ttft_ms: float
std_ttft_ms: float
p90_ttft_ms: float
p99_ttft_ms: float
mean_tpot_ms: float
median_tpot_ms: float
std_tpot_ms: float
p90_tpot_ms: float
p99_tpot_ms: float
mean_itl_ms: float
median_itl_ms: float
std_itl_ms: float
p90_itl_ms: float
p99_itl_ms: float
mean_e2e_latency_ms: float
median_e2e_latency_ms: float
std_e2e_latency_ms: float
p99_e2e_latency_ms: float
concurrency: float
async def get_requests(
input_requests_queue: asyncio.Queue,
request_rate: float,
num_actual_requests: int,
) -> AsyncGenerator[RequestFuncInput, None]:
for _ in range(num_actual_requests):
try:
request = await asyncio.wait_for(
input_requests_queue.get(), timeout=300
) # Wait for 5 minites then abort
except Exception as e:
print(f"exception: {e}")
break
yield request
if request_rate == float("inf"):
continue
interval = np.random.exponential(1.0 / request_rate)
await asyncio.sleep(interval)
def calculate_metrics(
outputs: List[RequestFuncOutput],
dur_s: float,
tokenizer: PreTrainedTokenizerBase,
backend: str,
) -> Tuple[BenchmarkMetrics, List[int]]:
output_lens: List[int] = []
retokenized_output_lens: List[int] = []
total_input = 0
completed = 0
itls: List[float] = []
tpots: List[float] = []
ttfts: List[float] = []
e2e_latencies: List[float] = []
output_success = 0
for i in range(len(outputs)):
if outputs[i].success:
output_success += 1
assert len(outputs[i].generated_text) == len(outputs[i].latency)
assert len(outputs[i].generated_text) == len(outputs[i].ttft)
for j in range(len(outputs[i].generated_text)):
output_len = outputs[i].output_len[j]
output_lens.append(output_len)
retokenized_output_len = len(
tokenizer.encode(
outputs[i].generated_text[j], add_special_tokens=False
)
)
retokenized_output_lens.append(retokenized_output_len)
total_input += outputs[i].prompt_len[j]
if output_len > 1:
tpots.append(
(outputs[i].latency[j] - outputs[i].ttft[j]) / (output_len - 1)
)
completed += 1
itls += outputs[i].itl
ttfts += outputs[i].ttft
e2e_latencies += outputs[i].latency
else:
output_lens.append(0)
retokenized_output_lens.append(0)
if completed == 0:
warnings.warn(
"All requests failed. This is likely due to a misconfiguration "
"on the benchmark arguments.",
stacklevel=2,
)
metrics = BenchmarkMetrics(
completed=completed,
total_input=total_input,
total_output=sum(output_lens),
total_output_retokenized=sum(retokenized_output_lens),
request_throughput=completed / dur_s,
input_throughput=total_input / dur_s,
output_throughput=sum(output_lens) / dur_s,
output_throughput_retokenized=sum(retokenized_output_lens) / dur_s,
total_throughput=(total_input + sum(output_lens)) / dur_s,
total_throughput_retokenized=(total_input + sum(retokenized_output_lens))
/ dur_s,
mean_ttft_ms=np.mean(ttfts or 0)
* 1000, # ttfts is empty if streaming is not supported by backend
median_ttft_ms=np.median(ttfts or 0) * 1000,
std_ttft_ms=np.std(ttfts or 0) * 1000,
p90_ttft_ms=np.percentile(ttfts or 0, 90) * 1000,
p99_ttft_ms=np.percentile(ttfts or 0, 99) * 1000,
mean_tpot_ms=np.mean(tpots or 0) * 1000,
median_tpot_ms=np.median(tpots or 0) * 1000,
std_tpot_ms=np.std(tpots or 0) * 1000,
p90_tpot_ms=np.percentile(tpots or 0, 90) * 1000,
p99_tpot_ms=np.percentile(tpots or 0, 99) * 1000,
mean_itl_ms=np.mean(itls or 0) * 1000,
median_itl_ms=np.median(itls or 0) * 1000,
std_itl_ms=np.std(itls or 0) * 1000,
p90_itl_ms=np.percentile(itls or 0, 90) * 1000,
p99_itl_ms=np.percentile(itls or 0, 99) * 1000,
mean_e2e_latency_ms=np.mean(e2e_latencies) * 1000,
median_e2e_latency_ms=np.median(e2e_latencies) * 1000,
std_e2e_latency_ms=np.std(e2e_latencies) * 1000,
p99_e2e_latency_ms=np.percentile(e2e_latencies, 99) * 1000,
concurrency=np.sum(e2e_latencies) / dur_s,
)
return metrics, output_lens
async def benchmark(
backend: str,
api_url: str,
base_url: str,
model_id: str,
tokenizer: PreTrainedTokenizerBase,
input_requests: SampleOutput,
request_rate: float,
max_concurrency: Optional[int],
disable_tqdm: bool,
lora_name: str,
extra_request_body: Dict[str, Any],
profile: bool,
enable_shared_prefix: bool,
):
if backend in ASYNC_REQUEST_FUNCS:
request_func = ASYNC_REQUEST_FUNCS[backend]
else:
raise ValueError(f"Unknown backend: {backend}")
# Limit concurrency
# From https://github.com/vllm-project/vllm/pull/9390
semaphore = asyncio.Semaphore(max_concurrency) if max_concurrency else None
async def limited_request_func(request_func_input, queue, tokenizer, pbar):
if semaphore is None:
return await request_func(
request_func_input=request_func_input,
queue=queue,
tokenizer=tokenizer,
pbar=pbar,
)
async with semaphore:
return await request_func(
request_func_input=request_func_input,
queue=queue,
tokenizer=tokenizer,
pbar=pbar,
)
num_actual_requests = sum(len(r) for r in input_requests)
print(f"Num of shared prefixes or conversations: {len(input_requests)}")
print(f"Num of total requests: {num_actual_requests}")
# flatten the requests for shared prefix
if enable_shared_prefix:
input_requests = [[r] for requests in input_requests for r in requests]
inputs_requests_queue = asyncio.Queue(maxsize=len(input_requests))
print("Starting initial single prompt test run...")
# NOTE: Just use the first request of the first conversation for warmup
test_input = RequestFuncInput(
model=model_id,
prompts=input_requests[0][:1],
api_url=api_url,
lora_name=lora_name,
extra_request_body=extra_request_body,
)
test_output = await request_func(
request_func_input=test_input, queue=inputs_requests_queue, tokenizer=tokenizer
)
if not test_output.success:
raise ValueError(
"Initial test run failed - Please make sure benchmark arguments "
f"are correctly specified. Error: {test_output.error}"
)
else:
print("Initial test run completed. Starting main benchmark run...")
# Check the states
assert inputs_requests_queue.empty()
# Flush cache
if "sglang" in backend:
requests.post(base_url + "/flush_cache")
time.sleep(1.0)
# Start profiler
if profile:
print("Starting profiler...")
profile_output = await async_request_profile(
api_url=base_url + "/start_profile"
)
if profile_output.success:
print("Profiler started")
for request in input_requests:
request_func_input = RequestFuncInput(
model=model_id,
prompts=request,
api_url=api_url,
lora_name=lora_name,
extra_request_body=extra_request_body,
)
inputs_requests_queue.put_nowait(request_func_input)
if (
not args.enable_multiturn
and not args.enable_shared_prefix
and not args.dataset_name == "generated-shared-prefix"
):
assert len(input_requests) == num_actual_requests
pbar = None if disable_tqdm else tqdm(total=num_actual_requests)
benchmark_start_time = time.perf_counter()
tasks: List[asyncio.Task] = []
async for request in get_requests(
inputs_requests_queue, request_rate, num_actual_requests
):
tasks.append(
asyncio.create_task(
limited_request_func(
request_func_input=request,
queue=inputs_requests_queue,
tokenizer=tokenizer,
pbar=pbar,
)
)
)
outputs: List[RequestFuncOutput] = await asyncio.gather(*tasks)
# Stop profiler
if profile:
print("Stopping profiler...")
profile_output = await async_request_profile(api_url=base_url + "/stop_profile")
if profile_output.success:
print("Profiler stopped")
if pbar is not None:
pbar.close()
# Compute metrics and print results
benchmark_duration = time.perf_counter() - benchmark_start_time
metrics, output_lens = calculate_metrics(
outputs=outputs,
dur_s=benchmark_duration,
tokenizer=tokenizer,
backend=backend,
)
print("\n{s:{c}^{n}}".format(s=" Serving Benchmark Result ", n=50, c="="))
print("{:<40} {:<10}".format("Backend:", backend))
print("{:<40} {:<10}".format("Traffic request rate:", request_rate))
print(
"{:<40} {:<10}".format(
"Max reqeuest concurrency:",
max_concurrency if max_concurrency else "not set",
)
)
print("{:<40} {:<10}".format("Successful requests:", metrics.completed))
print("{:<40} {:<10.2f}".format("Benchmark duration (s):", benchmark_duration))
print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input))
print("{:<40} {:<10}".format("Total generated tokens:", metrics.total_output))
print(
"{:<40} {:<10}".format(
"Total generated tokens (retokenized):", metrics.total_output_retokenized
)
)
print(
"{:<40} {:<10.2f}".format(
"Request throughput (req/s):", metrics.request_throughput
)
)
print(
"{:<40} {:<10.2f}".format(
"Input token throughput (tok/s):", metrics.input_throughput
)
)
print(
"{:<40} {:<10.2f}".format(
"Output token throughput (tok/s):", metrics.output_throughput
)
)
print(
"{:<40} {:<10.2f}".format(
"Total token throughput (tok/s):", metrics.total_throughput
)
)
print("{:<40} {:<10.2f}".format("Concurrency:", metrics.concurrency))
print("{s:{c}^{n}}".format(s="End-to-End Latency", n=50, c="-"))
print(
"{:<40} {:<10.2f}".format("Mean E2E Latency (ms):", metrics.mean_e2e_latency_ms)
)
print(
"{:<40} {:<10.2f}".format(
"Median E2E Latency (ms):", metrics.median_e2e_latency_ms
)
)
print("{s:{c}^{n}}".format(s="Time to First Token", n=50, c="-"))
print("{:<40} {:<10.2f}".format("Mean TTFT (ms):", metrics.mean_ttft_ms))
print("{:<40} {:<10.2f}".format("Median TTFT (ms):", metrics.median_ttft_ms))
print("{:<40} {:<10.2f}".format("P90 TTFT (ms):", metrics.p90_ttft_ms))
print("{:<40} {:<10.2f}".format("P99 TTFT (ms):", metrics.p99_ttft_ms))
print(
"{s:{c}^{n}}".format(s="Time per Output Token (excl. 1st token)", n=50, c="-")
)
print("{:<40} {:<10.2f}".format("Mean TPOT (ms):", metrics.mean_tpot_ms))
print("{:<40} {:<10.2f}".format("Median TPOT (ms):", metrics.median_tpot_ms))
print("{:<40} {:<10.2f}".format("P90 TPOT (ms):", metrics.p90_tpot_ms))
print("{:<40} {:<10.2f}".format("P99 TPOT (ms):", metrics.p99_tpot_ms))
print("{s:{c}^{n}}".format(s="Inter-token Latency", n=50, c="-"))
print("{:<40} {:<10.2f}".format("Mean ITL (ms):", metrics.mean_itl_ms))
print("{:<40} {:<10.2f}".format("Median ITL (ms):", metrics.median_itl_ms))
print("{:<40} {:<10.2f}".format("P90 ITL (ms):", metrics.p90_itl_ms))
print("{:<40} {:<10.2f}".format("P99 ITL (ms):", metrics.p99_itl_ms))
print("=" * 50)
if (
metrics.median_ttft_ms is not None
and metrics.mean_itl_ms is not None
and metrics.output_throughput is not None
):
result = {
# Arguments
"backend": args.backend,
"dataset_name": args.dataset_name,
"request_rate": request_rate,
"max_concurrency": max_concurrency,
"fixed_output_len": args.fixed_output_len,
"random_input_len": args.random_input_len,
"random_output_len": args.random_output_len,
"random_range_ratio": args.random_range_ratio,
# Results
"duration": benchmark_duration,
"completed": metrics.completed,
"total_input_tokens": metrics.total_input,
"total_output_tokens": metrics.total_output,
"total_output_tokens_retokenized": metrics.total_output_retokenized,
"request_throughput": metrics.request_throughput,
"input_throughput": metrics.input_throughput,
"output_throughput": metrics.output_throughput,
"mean_e2e_latency_ms": metrics.mean_e2e_latency_ms,
"median_e2e_latency_ms": metrics.median_e2e_latency_ms,
"std_e2e_latency_ms": metrics.std_e2e_latency_ms,
"p99_e2e_latency_ms": metrics.p99_e2e_latency_ms,
"mean_ttft_ms": metrics.mean_ttft_ms,
"median_ttft_ms": metrics.median_ttft_ms,
"std_ttft_ms": metrics.std_ttft_ms,
"p99_ttft_ms": metrics.p99_ttft_ms,
"mean_tpot_ms": metrics.mean_tpot_ms,
"median_tpot_ms": metrics.median_tpot_ms,
"std_tpot_ms": metrics.std_tpot_ms,
"p99_tpot_ms": metrics.p99_tpot_ms,
"mean_itl_ms": metrics.mean_itl_ms,
"median_itl_ms": metrics.median_itl_ms,
"std_itl_ms": metrics.std_itl_ms,
"p99_itl_ms": metrics.p99_itl_ms,
"concurrency": metrics.concurrency,
"input_throughput": metrics.input_throughput,
"output_throughput": metrics.output_throughput,
"fixed_output_len": args.fixed_output_len,
"random_input_len": args.random_input_len,
"random_output_len": args.random_output_len,
"random_range_ratio": args.random_range_ratio,
"duration": benchmark_duration,
"completed": metrics.completed,
}
else:
print(f"Error running benchmark for request rate: {request_rate}")
print("-" * 30)
# Determine output file name
if args.output_file:
output_file_name = args.output_file
else:
now = datetime.now().strftime("%m%d")
if args.dataset_name == "random":
output_file_name = f"{args.backend}_{now}_{args.num_prompts}_{args.random_input_len}_{args.random_output_len}.jsonl"
else:
output_file_name = (
f"{args.backend}_{now}_{args.num_prompts}_{args.dataset_name}.jsonl"
)
# Append results to a JSONL file
with open(output_file_name, "a") as file:
file.write(json.dumps(result) + "\n")
result = {
"duration": benchmark_duration,
"completed": metrics.completed,
"total_input_tokens": metrics.total_input,
"total_output_tokens": metrics.total_output,
"total_output_tokens_retokenized": metrics.total_output_retokenized,
"request_throughput": metrics.request_throughput,
"input_throughput": metrics.input_throughput,
"output_throughput": metrics.output_throughput,
"mean_ttft_ms": metrics.mean_ttft_ms,
"median_ttft_ms": metrics.median_ttft_ms,
"std_ttft_ms": metrics.std_ttft_ms,
"p90_ttft_ms": metrics.p90_ttft_ms,
"p99_ttft_ms": metrics.p99_ttft_ms,
"mean_tpot_ms": metrics.mean_tpot_ms,
"median_tpot_ms": metrics.median_tpot_ms,
"std_tpot_ms": metrics.std_tpot_ms,
"p90_tpot_ms": metrics.p90_tpot_ms,
"p99_tpot_ms": metrics.p99_tpot_ms,
"mean_itl_ms": metrics.mean_itl_ms,
"median_itl_ms": metrics.median_itl_ms,
"std_itl_ms": metrics.std_itl_ms,
"p90_itl_ms": metrics.p90_itl_ms,
"p99_itl_ms": metrics.p99_itl_ms,
"input_lens": [output.prompt_len for output in outputs],
"output_lens": output_lens,
"ttfts": [output.ttft for output in outputs],
"itls": [output.itl for output in outputs],
"generated_texts": [output.generated_text for output in outputs],
"errors": [output.error for output in outputs],
"mean_e2e_latency_ms": metrics.mean_e2e_latency_ms,
"median_e2e_latency_ms": metrics.median_e2e_latency_ms,
}
return result
def run_benchmark(args_: argparse.Namespace):
global args
args = args_
# Set default value for max_concurrency if not present
if not hasattr(args, "max_concurrency"):
args.max_concurrency = None
# Set global environments
set_ulimit()
random.seed(args.seed)
np.random.seed(args.seed)
extra_request_body = {}
if args.extra_request_body:
extra_request_body = json.loads(args.extra_request_body)
# Set url
if args.port is None:
args.port = {
"sglang": 30000,
"lmdeploy": 23333,
"vllm": 8000,
}.get(args.backend, 30000)
model_url = (
f"{args.base_url}/v1/models"
if args.base_url
else f"http://{args.host}:{args.port}/v1/models"
)
if args.backend in ["sglang", "vllm", "lmdeploy"]:
api_url = (
f"{args.base_url}/v1/chat/completions"
if args.base_url
else f"http://{args.host}:{args.port}/v1/chat/completions"
)
base_url = (
f"http://{args.host}:{args.port}" if args.base_url is None else args.base_url
)
# Get model name
if args.model is None:
if args.backend == "truss":
print(
"Please provide a model with `--model` when using truss backend. e.g. --model meta-llama/Llama-3.1-8B-Instruct"
)
sys.exit(1)
try:
response = requests.get(model_url)
model_list = response.json().get("data", [])
args.model = model_list[0]["id"] if model_list else None
except Exception as e:
print(f"Failed to fetch model from {model_url}. Error: {e}")
print(
"Please specify the correct host and port using `--host` and `--port`."
)
sys.exit(1)
if args.model is None:
print("No model specified or found. Please provide a model using `--model`.")
sys.exit(1)
# Dataset compatibility check
if args.enable_multiturn:
# TODO: Support multiturn for random
if args.dataset_name not in ["sharegpt", "ultrachat", "loogle", "nextqa"]:
print(
"Multiturn conversation is only supported for sharegpt, ultrachat, loogle, and nextqa datasets."
)
sys.exit(1)
if args.enable_shared_prefix:
if args.dataset_name not in ["loogle", "nextqa"]:
print("Shared prefix is only supported for loogle and nextqa datasets.")
sys.exit(1)
print(f"{args}\n")
# Read dataset
backend = args.backend
model_id = args.model
tokenizer_id = args.tokenizer if args.tokenizer is not None else args.model
tokenizer = get_tokenizer(tokenizer_id)
input_requests = get_dataset(args, tokenizer)
return asyncio.run(
benchmark(
backend=backend,
api_url=api_url,
base_url=base_url,
model_id=model_id,
tokenizer=tokenizer,
input_requests=input_requests,
request_rate=args.request_rate,
max_concurrency=args.max_concurrency,
disable_tqdm=args.disable_tqdm,
lora_name=args.lora_name,
extra_request_body=extra_request_body,
profile=args.profile,
enable_shared_prefix=args.enable_shared_prefix,
)
)
if __name__ == "__main__":
parser = ArgumentParser(description="Benchmark the online serving throughput.")
parser.add_argument(
"--backend",
type=str,
choices=list(ASYNC_REQUEST_FUNCS.keys()),
default="sglang",
help="Must specify a backend, depending on the LLM Inference Engine.",
)
parser.add_argument(
"--base-url",
type=str,
default=None,
help="Server or API base url if not using http host and port.",
)
parser.add_argument(
"--host", type=str, default="0.0.0.0", help="Default host is 0.0.0.0."
)
parser.add_argument(
"--port",
type=int,
help="If not set, the default port is configured according to its default value for different LLM Inference Engines.",
)
parser.add_argument(
"--dataset-name",
type=str,
default="sharegpt",
choices=[
"sharegpt",
"random",
"generated-shared-prefix",
"ultrachat",
"loogle",
"nextqa",
],
help="Name of the dataset to benchmark on.",
)
parser.add_argument(
"--dataset-path", type=str, default="", help="Path to the dataset."
)
parser.add_argument(
"--model",
type=str,
help="Name or path of the model. If not set, the default model will request /v1/models for conf.",
)
parser.add_argument(
"--tokenizer",
type=str,
help="Name or path of the tokenizer. If not set, using the model conf.",
)
parser.add_argument(
"--chat-template",
type=str,
help="The buliltin chat template name or the path of the chat template file. This is only used for OpenAI-compatible API server.",
)
parser.add_argument(
"--num-prompts",
type=int,
default=1000,
help="Number of prompts to process. Default is 1000.",
)
parser.add_argument(
"--fixed-output-len",
type=int,
default=None,
help="Output length for each request. Overrides the output length from the dataset.",
)
parser.add_argument(
"--sharegpt-context-len",
type=int,
default=None,
help="The context length of the model for the ShareGPT dataset. Requests longer than the context length will be dropped.",
)
parser.add_argument(
"--random-input-len",
type=int,
default=1024,
help="Number of input tokens per request, used only for random dataset.",
)
parser.add_argument(
"--random-output-len",
default=1024,
type=int,
help="Number of output tokens per request, used only for random dataset.",
)
parser.add_argument(
"--random-range-ratio",
type=float,
default=0.0,
help="Range of sampled ratio of input/output length, "
"used only for random dataset.",
)
parser.add_argument(
"--request-rate",
type=float,
default=float("inf"),
help="Number of requests per second. If this is inf, then all the requests are sent at time 0. "
"Otherwise, we use Poisson process to synthesize the request arrival times. Default is inf.",
)
parser.add_argument(
"--max-concurrency",
type=int,
default=None,
help="Maximum number of concurrent requests. This can be used "
"to help simulate an environment where a higher level component "
"is enforcing a maximum number of concurrent requests. While the "
"--request-rate argument controls the rate at which requests are "
"initiated, this argument will control how many are actually allowed "
"to execute at a time. This means that when used in combination, the "
"actual request rate may be lower than specified with --request-rate, "
"if the server is not processing requests fast enough to keep up.",
)
parser.add_argument(
"--multi",
action="store_true",
help="Use request rate range rather than single value.",
)
parser.add_argument(
"--request-rate-range",
type=str,
default="2,34,2",
help="Range of request rates in the format start,stop,step. Default is 2,34,2. It also supports a list of request rates, requiring the parameters to not equal three.",
)
parser.add_argument("--output-file", type=str, help="Output JSONL file name.")
parser.add_argument(
"--enable-multiturn",
action="store_true",
help="Enable multiturn chat for online serving benchmarking. "
"This option is effective on the following datasets: "
"sharegpt, ultrachat, loogle, nextqa",
)
parser.add_argument(
"--enable-shared-prefix",
action="store_true",
help="Enable shared prefix for online serving benchmarking. "
"This option is effective on the following datasets: "
"loogle, nextqa",
)
parser.add_argument(
"--disable-shuffle",
action="store_true",
help="Disable shuffling datasets. This is useful to generate stable output "
"in benchmarking",
)
parser.add_argument(
"--disable-tqdm",
action="store_true",
help="Specify to disable tqdm progress bar.",
)
parser.add_argument(
"--disable-stream",
action="store_true",
help="Disable streaming mode.",
)
parser.add_argument(
"--return-logprob",
action="store_true",
help="Return logprob.",
)
parser.add_argument("--seed", type=int, default=1, help="The random seed.")
parser.add_argument(
"--disable-ignore-eos",
action="store_true",
help="Disable ignoring EOS.",
)
parser.add_argument(
"--extra-request-body",
metavar='{"key1": "value1", "key2": "value2"}',
type=str,
help="Append given JSON object to the request payload. You can use this to specify"
"additional generate params like sampling params.",
)
parser.add_argument(
"--apply-chat-template",
action="store_true",
help="Apply chat template",
)
parser.add_argument(
"--profile",
action="store_true",
help="Use Torch Profiler. The endpoint must be launched with "
"SGLANG_TORCH_PROFILER_DIR to enable profiler.",
)
parser.add_argument(
"--lora-name",
type=str,
default=None,
help="The name of LoRA adapter",
)
group = parser.add_argument_group("generated-shared-prefix dataset arguments")
group.add_argument(
"--gsp-num-groups",
type=int,
default=64,
help="Number of system prompt groups for generated-shared-prefix dataset",
)
group.add_argument(
"--gsp-prompts-per-group",
type=int,
default=16,
help="Number of prompts per system prompt group for generated-shared-prefix dataset",
)
group.add_argument(
"--gsp-system-prompt-len",
type=int,
default=2048,
help="Target length in tokens for system prompts in generated-shared-prefix dataset",
)
group.add_argument(
"--gsp-question-len",
type=int,
default=128,
help="Target length in tokens for questions in generated-shared-prefix dataset",
)
group.add_argument(
"--gsp-output-len",
type=int,
default=256,
help="Target length in tokens for outputs in generated-shared-prefix dataset",
)
# videos specific
parser.add_argument(
"--max-frames",
type=int,
default=sys.maxsize,
help="The maximum number of frames to extract from each video. "
"This option is specific to the nextqa dataset (video benchmark). ",
)
args = parser.parse_args()
if args.enable_multiturn and args.enable_shared_prefix:
parser.error(
"--enable-multiturn and --enable-shared-prefix cannot be set at the same time."
)
run_benchmark(args)
import json
import os
import pickle
import random
from pathlib import Path
from typing import List, Optional, Tuple, Union
import numpy as np
from nextqa import NExTQALoader
# from nextqa.video import , VideoPrompt
from tqdm.asyncio import tqdm
from transformers import PreTrainedTokenizerBase
SHAREGPT_URL = "https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json"
from sglang.bench_serving import (
download_and_cache_file,
gen_prompt,
get_gen_prefix_cache_path,
)
from sglang.lang.chat_template import get_chat_template, get_chat_template_by_model_path
from sglang.srt.openai_api.protocol import ChatCompletionMessageContentPart
from sglang.utils import encode_video_base64
# type of content fields, can be only prompts or with images/videos
MsgContent = Union[str, List[ChatCompletionMessageContentPart]]
# A list of all the conversations. Each conversation is a list of
# tuples. If multiturn is not enabled, the length of list is 1,
# containing only the first Q&A pair.
# For the shared prefix workload (synthetic, loogle, nextqa), it
# is a list of conversations sharing the same prefix (synthetic,
# doc, video)
SampleOutput = List[List[Tuple[MsgContent, int, int]]]
def common_filter_chat(
num_requests: int,
new_dataset: List,
tokenizer: PreTrainedTokenizerBase,
min_prompt_len: Optional[int],
min_output_len: Optional[int],
max_prompt_len: Optional[int],
max_output_len: Optional[int],
fixed_output_len: Optional[int],
) -> SampleOutput:
# Filter out sequences that are too long or too short
filtered_dataset: SampleOutput = []
l = 0
input_tokens = 0
output_tokens = 0
while l < num_requests:
for i in range(len(new_dataset)):
if l == num_requests:
break
processed = []
for j in new_dataset[i]:
# Tokenize the prompts and completions.
prompt = j[0]
prompt_token_ids = tokenizer.encode(prompt)
prompt_len = len(prompt_token_ids)
completion = j[1]
completion_token_ids = tokenizer.encode(completion)
output_len = (
len(completion_token_ids)
if fixed_output_len is None
else fixed_output_len
)
if (
min_prompt_len is not None
and prompt_len < min_prompt_len
or min_output_len is not None
and output_len < min_output_len
or max_prompt_len is not None
and prompt_len > max_prompt_len
or max_output_len is not None
and output_len > max_output_len
):
# Prune too short sequences.
continue
input_tokens += prompt_len
output_tokens += output_len
processed.append((prompt, prompt_len, output_len))
filtered_dataset.append(processed)
l += 1
print(f"#Input tokens: {input_tokens}")
print(f"#Output tokens: {output_tokens}")
return filtered_dataset
def sample_sharegpt_requests(
dataset_path: str,
num_requests: int,
tokenizer: PreTrainedTokenizerBase,
disable_shuffle: bool = False,
enable_multiturn: bool = True,
fixed_output_len: Optional[int] = None,
) -> SampleOutput:
if fixed_output_len is not None and fixed_output_len < 4:
raise ValueError("output_len too small")
# Download sharegpt if necessary
if not os.path.isfile(dataset_path):
dataset_path = download_and_cache_file(SHAREGPT_URL)
# Load the dataset.
with open(dataset_path) as f:
dataset = json.load(f)
# Filter out the conversations with less than 2 turns.
dataset = [data for data in dataset if len(data["conversations"]) >= 2]
# Keep one conversation in one list
new_dataset = []
for data in dataset:
if len(data["conversations"]) % 2 != 0:
continue
if data["conversations"][0]["from"] != "human":
continue
chat = []
total_len = 2
if enable_multiturn:
total_len = len(data["conversations"])
for i in range(0, total_len, 2):
# One user One Assistant
chat.append(
(
data["conversations"][i]["value"],
data["conversations"][i + 1]["value"],
)
)
new_dataset.append(chat)
if not disable_shuffle:
# Shuffle the dataset.
random.shuffle(new_dataset)
# Filter out sequences that are too long or too short
filtered_dataset: SampleOutput = common_filter_chat(
num_requests, new_dataset, tokenizer, 4, 4, None, None, fixed_output_len
)
return filtered_dataset
def sample_ultrachat_requests(
dataset_path: str,
num_requests: int,
tokenizer: PreTrainedTokenizerBase,
disable_shuffle: bool = False,
enable_multiturn: bool = True,
fixed_output_len: Optional[int] = None,
) -> SampleOutput:
if fixed_output_len is not None and fixed_output_len < 4:
raise ValueError("output_len too small")
# Load the dataset
dataset = []
with open(dataset_path) as f:
while True:
line = f.readline()
if not line:
break
dataset.append(json.loads(line))
# Filter out the conversations with less than 2 turns.
dataset = [data for data in dataset if len(data["data"]) >= 2]
# Keep one conversation in one list
new_dataset = []
for data in dataset:
if len(data["data"]) % 2 != 0:
continue
chat = []
total_len = 2
if enable_multiturn:
total_len = len(data["data"])
for i in range(0, total_len, 2):
# One user One Assistant
chat.append((data["data"][i], data["data"][i + 1]))
new_dataset.append(chat)
# Shuffle the dataset.
if not disable_shuffle:
random.shuffle(new_dataset)
# Filter out sequences that are too long or too short
filtered_dataset: SampleOutput = common_filter_chat(
num_requests, new_dataset, tokenizer, 4, 4, None, None, fixed_output_len
)
return filtered_dataset
def sample_loogle_requests(
dataset_path: str,
num_requests: int,
tokenizer: PreTrainedTokenizerBase,
disable_shuffle: bool = False,
enable_multiturn: bool = True,
enable_shared_prefix: bool = False,
fixed_output_len: Optional[int] = None,
) -> SampleOutput:
if fixed_output_len is not None and fixed_output_len < 4:
raise ValueError("output_len too small")
# Load the dataset
dataset = []
with open(dataset_path) as f:
while True:
line = f.readline()
if not line:
break
dataset.append(json.loads(line))
# Keep one conversation in one list
new_dataset = []
# TODO: Add shared prefix support for loogle
# NOTE: Now we preprocess it only for chat
for data in dataset:
chat = []
if (
"qa_pairs" not in data
or data["qa_pairs"] == "none"
or len(data["qa_pairs"]) == 0
):
# If Q is none (for summarization),
# We add a question for summarization
# And keep the summary up to 1024 words
chat.append(
(
"Input: "
+ data["input"]
+ " Question: "
+ "Please summarize the input",
data["input"][:1024],
)
)
new_dataset.append(chat)
else:
qa_pairs = eval(data["qa_pairs"])
for i, qa in enumerate(qa_pairs):
if i == 0 or enable_shared_prefix:
# Combine input with the first Q
chat.append(
("Input: " + data["input"] + " Question: " + qa["Q"], qa["A"])
)
elif enable_multiturn:
chat.append((qa["Q"], qa["A"]))
new_dataset.append(chat)
# Shuffle the dataset.
if not disable_shuffle:
random.shuffle(new_dataset)
# Filter out sequences that are too long or too short
filtered_dataset: SampleOutput = common_filter_chat(
num_requests, new_dataset, tokenizer, 4, None, None, None, fixed_output_len
)
return filtered_dataset
def sample_nextqa_requests(
dataset_path: str,
num_requests: int,
tokenizer: PreTrainedTokenizerBase,
max_frames: int, # Specific for video
model_path: str,
disable_shuffle: bool = False,
enable_multiturn: bool = True, # No multiturn support for now
backend: str = "sglang-oai",
chat_template_name: Optional[str] = None,
fixed_output_len: Optional[int] = None,
) -> SampleOutput:
"""
Example of messages:
message = {
"role": "user",
"content": [
{"type": "image_url", "image_url": {"url": base64_data}},
{"type": "text", "text": video.prompt},
],
}
"""
if fixed_output_len is None:
fixed_output_len = 4096
# TODO: Check for multiturn
dataset = NExTQALoader(video_dir=dataset_path, max_frames=max_frames)
new_dataset = []
for v in dataset:
new_dataset.append(v)
if not disable_shuffle:
random.shuffle(new_dataset)
# TODO: prompt len can get from server side
filtered_dataset = []
l = 0
while l < num_requests:
for i in range(len(new_dataset)):
if l == num_requests:
break
video = new_dataset[i]
# text prompt
prompt = video.prompt
# NOTE: Chat Template is a must for video benchmark because we have to
# add special image token for later expansion
if backend == "sglang" or backend == "sglang-native":
if "chat_template" in tokenizer.init_kwargs:
chat_template = get_chat_template(tokenizer.get_chat_template())
elif chat_template_name is not None:
chat_template = get_chat_template(chat_template_name)
else:
chat_template = get_chat_template_by_model_path(model_path)
prompt = chat_template.image_token + prompt
prompt_token_ids = tokenizer(prompt).input_ids
prompt_len = len(prompt_token_ids)
output_len = fixed_output_len # max output len, not real output len
# video input
base64_data = encode_video_base64(video.path, video.num_frames)
# NOTE: This will be replaced by the expanded length from the server
prompt_len += video.num_frames
# add to content
content = [
{"type": "image_url", "image_url": {"url": base64_data}},
{"type": "text", "text": prompt},
]
filtered_dataset.append([(content, prompt_len, output_len)])
l += 1
return filtered_dataset
def sample_random_requests(
input_len: int,
output_len: int,
num_prompts: int,
range_ratio: float,
tokenizer: PreTrainedTokenizerBase,
dataset_path: str,
disable_shuffle: bool = False,
) -> SampleOutput:
input_lens = np.random.randint(
max(int(input_len * range_ratio), 1),
input_len + 1,
size=num_prompts,
)
output_lens = np.random.randint(
int(output_len * range_ratio),
output_len + 1,
size=num_prompts,
)
if True:
# Sample token ids from ShareGPT and repeat/truncate them to satisfy the input_lens
# Download sharegpt if necessary
if not os.path.isfile(dataset_path):
dataset_path = download_and_cache_file(SHAREGPT_URL)
# Load the dataset.
with open(dataset_path) as f:
dataset = json.load(f)
# Filter out the conversations with less than 2 turns.
dataset = [data for data in dataset if len(data["conversations"]) >= 2]
# Only keep the first two turns of each conversation.
dataset = [
(data["conversations"][0]["value"], data["conversations"][1]["value"])
for data in dataset
]
if not disable_shuffle:
# Shuffle the dataset.
random.shuffle(dataset)
# Filter out sequences that are too long or too short
input_requests: SampleOutput = []
for data in dataset:
i = len(input_requests)
if i == num_prompts:
break
# Tokenize the prompts and completions.
prompt = data[0]
prompt_token_ids = tokenizer.encode(prompt)
prompt_len = len(prompt_token_ids)
# Skip empty prompt
if prompt_len == 0:
continue
if prompt_len > input_lens[i]:
input_ids = prompt_token_ids[: input_lens[i]]
else:
ratio = (input_lens[i] + prompt_len - 1) // prompt_len
input_ids = (prompt_token_ids * ratio)[: input_lens[i]]
prompt = tokenizer.decode(input_ids)
input_requests.append([(prompt, int(input_lens[i]), int(output_lens[i]))])
else:
# Sample token ids from random integers. This can cause some NaN issues.
offsets = np.random.randint(0, tokenizer.vocab_size, size=num_prompts)
input_requests = []
for i in range(num_prompts):
prompt = tokenizer.decode(
[
(offsets[i] + i + j) % tokenizer.vocab_size
for j in range(input_lens[i])
]
)
input_requests.append([(prompt, int(input_lens[i]), int(output_lens[i]))])
print(f"#Input tokens: {np.sum(input_lens)}")
print(f"#Output tokens: {np.sum(output_lens)}")
return input_requests
def gen_prompt(tokenizer, token_num):
"""Generate a random prompt of specified token length using tokenizer vocabulary."""
all_available_tokens = list(tokenizer.get_vocab().values())
selected_tokens = random.choices(all_available_tokens, k=token_num)
return tokenizer.decode(selected_tokens)
def get_gen_prefix_cache_path(args, tokenizer):
"""Create cache directory under ~/.cache/sglang/benchmark"""
cache_dir = Path.home() / ".cache" / "sglang" / "benchmark"
# Create a unique cache filename based on the generation parameters
cache_key = (
f"gen_prefix_{args.gen_num_groups}_{args.gen_prompts_per_group}_"
f"{args.gen_system_prompt_len}_{args.gen_question_len}_{args.gen_output_len}_"
f"{tokenizer.__class__.__name__}.pkl"
)
return cache_dir / cache_key
def sample_generated_shared_prefix_requests(
num_groups: int,
prompts_per_group: int,
system_prompt_len: int,
question_len: int,
output_len: int,
tokenizer: PreTrainedTokenizerBase,
args,
disable_shuffle: bool = False,
) -> SampleOutput:
"""Generate benchmark requests with shared system prompts using random tokens and caching."""
cache_path = get_gen_prefix_cache_path(args, tokenizer)
# Try to load from cache first
if cache_path.exists():
print(f"\nLoading cached generated input data from {cache_path}")
with open(cache_path, "rb") as f:
return pickle.load(f)
print("\nGenerating new input data...")
# Generate system prompts for each group
system_prompts = []
for _ in range(num_groups):
system_prompt = gen_prompt(tokenizer, system_prompt_len)
system_prompts.append(system_prompt)
# Generate questions
questions = []
for _ in range(num_groups * prompts_per_group):
question = gen_prompt(tokenizer, question_len)
questions.append(question)
# Combine system prompts with questions
input_requests = []
total_input_tokens = 0
total_output_tokens = 0
for group_idx in tqdm(range(num_groups), desc="Generating system prompt"):
system_prompt = system_prompts[group_idx]
input_requests.append([])
for prompt_idx in tqdm(
range(prompts_per_group), desc="Generating questions", leave=False
):
question = questions[group_idx * prompts_per_group + prompt_idx]
full_prompt = f"{system_prompt}\n\n{question}"
prompt_len = len(tokenizer.encode(full_prompt))
input_requests[-1].append((full_prompt, prompt_len, output_len))
total_input_tokens += prompt_len
total_output_tokens += output_len
if not disable_shuffle:
# Shuffle questions
random.shuffle(input_requests)
# Print statistics
print(f"\nGenerated shared prefix dataset statistics:")
print(f"Number of groups: {num_groups}")
print(f"Prompts per group: {prompts_per_group}")
print(f"Total prompts: {len(input_requests) * prompts_per_group}")
print(f"Total input tokens: {total_input_tokens}")
print(f"Total output tokens: {total_output_tokens}")
print(
f"Average system prompt length: {sum(len(tokenizer.encode(sp)) for sp in system_prompts) / len(system_prompts):.1f} tokens"
)
print(
f"Average question length: {sum(len(tokenizer.encode(q)) for q in questions) / len(questions):.1f} tokens\n"
)
# Save to cache
cache_path.parent.mkdir(parents=True, exist_ok=True)
print(f"Caching generated input data to {cache_path}")
with open(cache_path, "wb") as f:
pickle.dump(input_requests, f)
return input_requests
def get_dataset(args, tokenizer):
if args.dataset_name == "sharegpt":
input_requests = sample_sharegpt_requests(
dataset_path=args.dataset_path,
num_requests=args.num_prompts,
tokenizer=tokenizer,
disable_shuffle=args.disable_shuffle,
enable_multiturn=args.enable_multiturn,
fixed_output_len=args.fixed_output_len,
)
elif args.dataset_name == "ultrachat":
input_requests = sample_ultrachat_requests(
dataset_path=args.dataset_path,
num_requests=args.num_prompts,
tokenizer=tokenizer,
disable_shuffle=args.disable_shuffle,
enable_multiturn=args.enable_multiturn,
fixed_output_len=args.fixed_output_len,
)
elif args.dataset_name == "loogle":
input_requests = sample_loogle_requests(
dataset_path=args.dataset_path,
num_requests=args.num_prompts,
tokenizer=tokenizer,
disable_shuffle=args.disable_shuffle,
enable_multiturn=args.enable_multiturn,
enable_shared_prefix=args.enable_shared_prefix,
fixed_output_len=args.fixed_output_len,
)
elif args.dataset_name == "nextqa":
input_requests = sample_nextqa_requests(
dataset_path=args.dataset_path,
num_requests=args.num_prompts,
tokenizer=tokenizer,
max_frames=args.max_frames,
model_path=args.model,
disable_shuffle=args.disable_shuffle,
enable_multiturn=args.enable_multiturn,
backend=args.backend,
chat_template_name=args.chat_template,
fixed_output_len=args.fixed_output_len,
)
elif args.dataset_name == "random":
input_requests = sample_random_requests(
input_len=args.random_input_len,
output_len=args.random_output_len,
num_prompts=args.num_prompts,
range_ratio=args.random_range_ratio,
tokenizer=tokenizer,
dataset_path=args.dataset_path,
)
elif args.dataset_name == "generated-shared-prefix":
input_requests = sample_generated_shared_prefix_requests(
num_groups=args.gen_num_groups,
prompts_per_group=args.gen_prompts_per_group,
system_prompt_len=args.gen_system_prompt_len,
question_len=args.gen_question_len,
output_len=args.gen_output_len,
args=args,
tokenizer=tokenizer,
)
else:
raise ValueError(f"Unknown dataset: {args.dataset_name}")
return input_requests
#!/usr/bin/bash
# The usage function
usage() {
echo "Usage: $0 {sharegpt|ultragpt|loogle|nextqa|all}"
exit 1
}
# The download function
download() {
case "$1" in
sharegpt)
echo $1
wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
;;
ultragpt)
echo $1
# Questions about the world
wget https://cloud.tsinghua.edu.cn/seafhttp/files/be1d7b87-22ca-449e-a6a7-c61d1ea7e010/ultrachat_release_230407.json
# Writing and Creation
wget https://cloud.tsinghua.edu.cn/seafhttp/files/61742d2a-25e2-4d08-b2b9-15f47ae50ace/ultrachat_material_release_230417.json
wget https://cloud.tsinghua.edu.cn/seafhttp/files/f71f6aa6-d346-4b16-85b7-8502efa3d608/ultrachat_material_release_230412.json
# External materials
wget https://cloud.tsinghua.edu.cn/seafhttp/files/42d22e28-e899-4975-a70f-5eda163e265d/ultrachat_existent_material_release_230420.json.gz
gunzip ultrachat_existent_material_release_230420.json.gz
;;
loogle)
echo $1
git lfs install
git clone git@hf.co:datasets/bigainlco/LooGLE
unzip LooGLE/data.zip
;;
nextqa)
echo $1
git lfs install
git clone https://huggingface.co/datasets/lmms-lab/NExTQA
unzip NExTQA/videos.zip
;;
*)
usage
exit 1
;;
esac
}
# Arg check
if [ "$#" -ne 1 ]; then
usage
fi
# Invoke
case "$1" in
sharegpt|ultragpt|loogle|nextqa)
download "$1"
;;
all)
download sharegpt
download ultragpt
download loogle
download nextqa
;;
*)
usage
;;
esac
import os
import sys
from typing import List
import av
from datasets import load_dataset
def find_video_files(video_dir) -> List[str]:
if os.path.isfile(video_dir):
return [video_dir]
video_files = []
for root, dirs, files in os.walk(video_dir):
for file in files:
if file.endswith((".mp4", ".avi", ".mov")):
video_files.append(os.path.join(root, file))
# if file is dir
elif os.path.isdir(file):
video_files.extend(find_video_files(file))
return video_files
def video_frames(video_path, max_frames) -> int:
container = av.open(video_path)
total_frames = container.streams.video[0].frames
return min(total_frames, max_frames)
class Video:
def __init__(self, video_path, num_frames):
self.path = video_path
self.num_frames = num_frames
def __str__(self):
return f"Video({self.path}, {self.num_frames})"
def __iter__(self):
return iter((self.path, self.num_frames))
class VideoPrompt(Video):
def __init__(self, video_path, num_frames, prompt):
super().__init__(video_path, num_frames)
self.prompt = prompt
def __str__(self):
return f"VideoPrompt({self.path}, {self.num_frames}, {self.prompt})"
def __iter__(self):
return iter((self.path, self.num_frames, self.prompt))
class VideoLoader:
pass
class VideoFileLoader(VideoLoader):
"""
Load all the videos in a directory
"""
def __init__(self, video_dir, batch_size=1, max_frames=sys.maxsize):
super().__init__()
self.video_dir = video_dir
self.video_files = find_video_files(video_dir)
self.batch_size = batch_size
self.max_frames = max_frames
print(f"batch_size: {batch_size}, max_frames: {max_frames}")
def __iter__(self): # (file, number of frames)
if self.batch_size == 1:
for video_file in self.video_files:
yield Video(video_file, video_frames(video_file, self.max_frames))
else:
batch = []
for video_file in self.video_files:
video = Video(video_file, video_frames(video_file, self.max_frames))
batch.append(video)
if len(batch) == self.batch_size:
yield batch
batch = []
class NExTQALoader(VideoLoader):
"""
Load vdideos and prompts from NExT dataset
set: train, test or validation
"""
def __init__(
self, video_dir, batch_size=1, max_frames=sys.maxsize, dset="test", task="OE"
):
"""
task: 'MV' or 'OE'
"""
super().__init__()
self.task = task
print(f"Loading the {dset} data of {task} from lmms-lab/NExTQA")
self.ds = load_dataset("lmms-lab/NExTQA", task)
self.ds = self.ds[dset]
# self.n = ds.num_rows
self.video_dir = video_dir
self.video_files = find_video_files(video_dir)
self.video_to_path = dict()
for video_file in self.video_files:
video_id = video_file.split("/")[-1].split(".")[0]
self.video_to_path[video_id] = video_file
self.batch_size = batch_size
self.max_frames = max_frames
def get_video_prompt(self, entry, max_frames) -> VideoPrompt:
# Get video
video_id = entry["video"]
video_path = self.video_to_path[video_id]
assert os.path.exists(video_path), f"Video not found: {video_path}"
num_frames = min(entry["frame_count"], max_frames)
video = Video(video_path, num_frames)
prompt = entry["question"] + "?"
if self.task == "MC": # add choices
prompt += f' a0: {entry["a0"]}, a1: {entry["a1"]}, a2: {entry["a2"]}, a3: {entry["a3"]}'
return VideoPrompt(video_path, num_frames, prompt)
def __iter__(self):
if self.batch_size == 1:
for entry in self.ds:
yield self.get_video_prompt(entry, self.max_frames)
else:
batch = []
for entry in self.ds:
video = self.get_video_prompt(entry, self.max_frames)
batch.append(video)
if len(batch) == self.batch_size:
yield batch
batch = []
# main
if __name__ == "__main__":
video_dir = "./videos"
# video_loader = VideoFileLoader(video_dir, batch_size=16)
# for batch in video_loader:
# print(f"Number of videos in batch: {len(batch)}")
# for video_file, num_frames in batch:
# print(f"Video: {video_file} number of frames: {num_frames}")
video_loader = NExTQALoader(video_dir, batch_size=16, dset="test", task="OE")
for batch in video_loader:
print(f"Number of videos in batch: {len(batch)}")
for video_file, num_frames, prompt in batch:
print(
f"Video: {video_file} number of frames: {num_frames}, prompt: {prompt}"
)
# break
# for video_file, prompt in batch:
# print(f"Video: {video_file} prompt: {prompt}")
# break
......@@ -24,10 +24,14 @@ import requests
from IPython.display import HTML, display
from tqdm import tqdm
from sglang.srt.openai_api.protocol import ChatCompletionMessageContentPart
from sglang.srt.utils import kill_process_tree
logger = logging.getLogger(__name__)
# type of content fields, can be only prompts or with images/videos
MsgContent = Union[str, List[ChatCompletionMessageContentPart]]
def get_exception_traceback():
etype, value, tb = sys.exc_info()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment