Unverified Commit 02973cd9 authored by fzyzcjy's avatar fzyzcjy Committed by GitHub
Browse files

Tiny refactor bench_serving to improve extensibility (#6134)

parent 6d95a35a
......@@ -17,11 +17,12 @@ import logging
import os
import random
import time
from typing import Dict, List, Optional, Tuple
from typing import Dict, List, Optional
import numpy as np
from sglang.bench_serving import (
DatasetRow,
get_dataset,
get_tokenizer,
sample_random_requests,
......@@ -194,7 +195,7 @@ class BenchArgs:
def throughput_test_once(
backend_name: str,
backend,
reqs: List[Tuple[str, int, int]],
reqs: List[DatasetRow],
ignore_eos: bool,
extra_request_body: Dict,
profile: bool,
......@@ -203,7 +204,7 @@ def throughput_test_once(
"backend": backend_name,
"successful_requests": len(reqs),
"total_latency": -1,
"total_input_tokens": sum(r[1] for r in reqs),
"total_input_tokens": sum(r.prompt_len for r in reqs),
"total_output_tokens": -1,
"request_throughput": -1,
"input_throughput": -1,
......@@ -211,11 +212,11 @@ def throughput_test_once(
"total_throughput": -1,
}
prompt = [r[0] for r in reqs]
prompt = [r.prompt for r in reqs]
sampling_params = [
{
"temperature": 0,
"max_new_tokens": r[2],
"max_new_tokens": r.output_len,
"ignore_eos": ignore_eos,
**extra_request_body,
}
......@@ -267,7 +268,6 @@ def throughput_test_once(
def monitor_trace_file(directory, interval=1):
print(f"Monitoring {directory} for new trace files...")
known_files = set(os.listdir(directory))
......
......@@ -610,12 +610,19 @@ def download_and_cache_file(url: str, filename: Optional[str] = None):
return filename
@dataclass
class DatasetRow:
prompt: str
prompt_len: int
output_len: int
def sample_mmmu_requests(
num_requests: int,
tokenizer: PreTrainedTokenizerBase,
fixed_output_len: Optional[int] = None,
random_sample: bool = True,
) -> List[Tuple[str, int, int]]:
) -> List[DatasetRow]:
"""
Sample requests from the MMMU dataset using HuggingFace datasets.
......@@ -716,7 +723,11 @@ def sample_mmmu_requests(
output_len = fixed_output_len if fixed_output_len is not None else 256
filtered_dataset.append((prompt, prompt_len, output_len))
filtered_dataset.append(
DatasetRow(
prompt=prompt, prompt_len=prompt_len, output_len=output_len
)
)
except Exception as e:
print(f"Error processing example {i}: {e}")
......@@ -733,7 +744,7 @@ def sample_sharegpt_requests(
context_len: Optional[int] = None,
prompt_suffix: Optional[str] = "",
apply_chat_template=False,
) -> List[Tuple[str, int, int]]:
) -> List[DatasetRow]:
if fixed_output_len is not None and fixed_output_len < 4:
raise ValueError("output_len too small")
......@@ -764,7 +775,7 @@ def sample_sharegpt_requests(
random.shuffle(dataset)
# Filter out sequences that are too long or too short
filtered_dataset: List[Tuple[str, int, int]] = []
filtered_dataset: List[DatasetRow] = []
for i in range(len(dataset)):
if len(filtered_dataset) == num_requests:
break
......@@ -802,10 +813,12 @@ def sample_sharegpt_requests(
# Prune too long sequences.
continue
filtered_dataset.append((prompt, prompt_len, output_len))
filtered_dataset.append(
DatasetRow(prompt=prompt, prompt_len=prompt_len, output_len=output_len)
)
print(f"#Input tokens: {np.sum([x[1] for x in filtered_dataset])}")
print(f"#Output tokens: {np.sum([x[2] for x in filtered_dataset])}")
print(f"#Input tokens: {np.sum([x.prompt_len for x in filtered_dataset])}")
print(f"#Output tokens: {np.sum([x.output_len for x in filtered_dataset])}")
return filtered_dataset
......@@ -817,7 +830,7 @@ def sample_random_requests(
tokenizer: PreTrainedTokenizerBase,
dataset_path: str,
random_sample: bool = True,
) -> List[Tuple[str, int, int]]:
) -> List[DatasetRow]:
input_lens = np.random.randint(
max(int(input_len * range_ratio), 1),
input_len + 1,
......@@ -857,7 +870,7 @@ def sample_random_requests(
random.shuffle(dataset)
# Filter out sequences that are too long or too short
input_requests: List[Tuple[str, int, int]] = []
input_requests: List[DatasetRow] = []
for data in dataset:
i = len(input_requests)
if i == num_prompts:
......@@ -878,7 +891,13 @@ def sample_random_requests(
ratio = (input_lens[i] + prompt_len - 1) // prompt_len
input_ids = (prompt_token_ids * ratio)[: input_lens[i]]
prompt = tokenizer.decode(input_ids)
input_requests.append((prompt, int(input_lens[i]), int(output_lens[i])))
input_requests.append(
DatasetRow(
prompt=prompt,
prompt_len=int(input_lens[i]),
output_len=int(output_lens[i]),
)
)
else:
# Sample token ids from random integers. This can cause some NaN issues.
offsets = np.random.randint(0, tokenizer.vocab_size, size=num_prompts)
......@@ -890,7 +909,13 @@ def sample_random_requests(
for j in range(input_lens[i])
]
)
input_requests.append((prompt, int(input_lens[i]), int(output_lens[i])))
input_requests.append(
DatasetRow(
prompt=prompt,
prompt_len=int(input_lens[i]),
output_len=int(output_lens[i]),
)
)
print(f"#Input tokens: {np.sum(input_lens)}")
print(f"#Output tokens: {np.sum(output_lens)}")
......@@ -925,7 +950,7 @@ def sample_generated_shared_prefix_requests(
output_len: int,
tokenizer: PreTrainedTokenizerBase,
args: argparse.Namespace,
) -> List[Tuple[str, int, int]]:
) -> List[DatasetRow]:
"""Generate benchmark requests with shared system prompts using random tokens and caching."""
cache_path = get_gen_prefix_cache_path(args, tokenizer)
......@@ -963,7 +988,11 @@ def sample_generated_shared_prefix_requests(
full_prompt = f"{system_prompt}\n\n{question}"
prompt_len = len(tokenizer.encode(full_prompt))
input_requests.append((full_prompt, prompt_len, output_len))
input_requests.append(
DatasetRow(
prompt=full_prompt, prompt_len=prompt_len, output_len=output_len
)
)
total_input_tokens += prompt_len
total_output_tokens += output_len
......@@ -994,9 +1023,9 @@ def sample_generated_shared_prefix_requests(
async def get_request(
input_requests: List[Tuple[str, int, int]],
input_requests: List[DatasetRow],
request_rate: float,
) -> AsyncGenerator[Tuple[str, int, int], None]:
) -> AsyncGenerator[DatasetRow, None]:
input_requests = iter(input_requests)
for request in input_requests:
yield request
......@@ -1012,7 +1041,7 @@ async def get_request(
def calculate_metrics(
input_requests: List[Tuple[str, int, int]],
input_requests: List[DatasetRow],
outputs: List[RequestFuncOutput],
dur_s: float,
tokenizer: PreTrainedTokenizerBase,
......@@ -1034,7 +1063,7 @@ def calculate_metrics(
tokenizer.encode(outputs[i].generated_text, add_special_tokens=False)
)
retokenized_output_lens.append(retokenized_output_len)
total_input += input_requests[i][1]
total_input += input_requests[i].prompt_len
if output_len > 1:
tpots.append((outputs[i].latency - outputs[i].ttft) / (output_len - 1))
itls += outputs[i].itl
......@@ -1096,7 +1125,7 @@ async def benchmark(
base_url: str,
model_id: str,
tokenizer: PreTrainedTokenizerBase,
input_requests: List[Tuple[str, int, int]],
input_requests: List[DatasetRow],
request_rate: float,
max_concurrency: Optional[int],
disable_tqdm: bool,
......@@ -1126,7 +1155,12 @@ async def benchmark(
print(f"Starting warmup with {warmup_requests} sequences...")
# Use the first request for all warmup iterations
test_prompt, test_prompt_len, test_output_len = input_requests[0]
test_request = input_requests[0]
test_prompt, test_prompt_len, test_output_len = (
test_request.prompt,
test_request.prompt_len,
test_request.output_len,
)
if lora_names is not None and len(lora_names) != 0:
lora_name = lora_names[0]
else:
......@@ -1194,7 +1228,11 @@ async def benchmark(
benchmark_start_time = time.perf_counter()
tasks: List[asyncio.Task] = []
async for request in get_request(input_requests, request_rate):
prompt, prompt_len, output_len = request
prompt, prompt_len, output_len = (
request.prompt,
request.prompt_len,
request.output_len,
)
if lora_names is not None and len(lora_names) != 0:
idx = random.randint(0, len(lora_names) - 1)
lora_name = lora_names[idx]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment