Unverified Commit 02973cd9 authored by fzyzcjy, committed by GitHub

Tiny refactor bench_serving to improve extensibility (#6134)

parent 6d95a35a
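
This commit replaces the positional (prompt, prompt_len, output_len) tuples threaded through the benchmarking helpers with a small DatasetRow dataclass, so call sites read r.prompt_len instead of r[1], and future per-request fields can be added in one place. A minimal standalone sketch of the pattern (the total_input_tokens helper and the sample row below are illustrative, not part of the commit):

from dataclasses import dataclass
from typing import List


@dataclass
class DatasetRow:
    prompt: str
    prompt_len: int
    output_len: int


def total_input_tokens(reqs: List[DatasetRow]) -> int:
    # Attribute access replaces the old positional lookup sum(r[1] for r in reqs).
    return sum(r.prompt_len for r in reqs)


# Illustrative row only; real rows come from the sample_*_requests functions.
rows = [DatasetRow(prompt="Hello world", prompt_len=2, output_len=16)]
assert total_input_tokens(rows) == 2
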
@@ -17,11 +17,12 @@ import logging
 import os
 import random
 import time
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional

 import numpy as np

 from sglang.bench_serving import (
+    DatasetRow,
     get_dataset,
     get_tokenizer,
     sample_random_requests,
@@ -194,7 +195,7 @@ class BenchArgs:
 def throughput_test_once(
     backend_name: str,
     backend,
-    reqs: List[Tuple[str, int, int]],
+    reqs: List[DatasetRow],
     ignore_eos: bool,
     extra_request_body: Dict,
     profile: bool,
@@ -203,7 +204,7 @@ def throughput_test_once(
         "backend": backend_name,
         "successful_requests": len(reqs),
         "total_latency": -1,
-        "total_input_tokens": sum(r[1] for r in reqs),
+        "total_input_tokens": sum(r.prompt_len for r in reqs),
         "total_output_tokens": -1,
         "request_throughput": -1,
         "input_throughput": -1,
@@ -211,11 +212,11 @@ def throughput_test_once(
         "total_throughput": -1,
     }

-    prompt = [r[0] for r in reqs]
+    prompt = [r.prompt for r in reqs]
     sampling_params = [
         {
             "temperature": 0,
-            "max_new_tokens": r[2],
+            "max_new_tokens": r.output_len,
             "ignore_eos": ignore_eos,
             **extra_request_body,
         }
@@ -267,7 +268,6 @@ def throughput_test_once(
-
 def monitor_trace_file(directory, interval=1):
     print(f"Monitoring {directory} for new trace files...")
     known_files = set(os.listdir(directory))
...
@@ -610,12 +610,19 @@ def download_and_cache_file(url: str, filename: Optional[str] = None):
     return filename


+@dataclass
+class DatasetRow:
+    prompt: str
+    prompt_len: int
+    output_len: int
+
+
 def sample_mmmu_requests(
     num_requests: int,
     tokenizer: PreTrainedTokenizerBase,
     fixed_output_len: Optional[int] = None,
     random_sample: bool = True,
-) -> List[Tuple[str, int, int]]:
+) -> List[DatasetRow]:
     """
     Sample requests from the MMMU dataset using HuggingFace datasets.
@@ -716,7 +723,11 @@ def sample_mmmu_requests(
                 output_len = fixed_output_len if fixed_output_len is not None else 256

-                filtered_dataset.append((prompt, prompt_len, output_len))
+                filtered_dataset.append(
+                    DatasetRow(
+                        prompt=prompt, prompt_len=prompt_len, output_len=output_len
+                    )
+                )
             except Exception as e:
                 print(f"Error processing example {i}: {e}")
@@ -733,7 +744,7 @@ def sample_sharegpt_requests(
     context_len: Optional[int] = None,
     prompt_suffix: Optional[str] = "",
     apply_chat_template=False,
-) -> List[Tuple[str, int, int]]:
+) -> List[DatasetRow]:
     if fixed_output_len is not None and fixed_output_len < 4:
         raise ValueError("output_len too small")
@@ -764,7 +775,7 @@ def sample_sharegpt_requests(
     random.shuffle(dataset)

     # Filter out sequences that are too long or too short
-    filtered_dataset: List[Tuple[str, int, int]] = []
+    filtered_dataset: List[DatasetRow] = []
     for i in range(len(dataset)):
         if len(filtered_dataset) == num_requests:
             break
@@ -802,10 +813,12 @@ def sample_sharegpt_requests(
             # Prune too long sequences.
             continue

-        filtered_dataset.append((prompt, prompt_len, output_len))
+        filtered_dataset.append(
+            DatasetRow(prompt=prompt, prompt_len=prompt_len, output_len=output_len)
+        )

-    print(f"#Input tokens: {np.sum([x[1] for x in filtered_dataset])}")
-    print(f"#Output tokens: {np.sum([x[2] for x in filtered_dataset])}")
+    print(f"#Input tokens: {np.sum([x.prompt_len for x in filtered_dataset])}")
+    print(f"#Output tokens: {np.sum([x.output_len for x in filtered_dataset])}")
     return filtered_dataset
@@ -817,7 +830,7 @@ def sample_random_requests(
     tokenizer: PreTrainedTokenizerBase,
     dataset_path: str,
     random_sample: bool = True,
-) -> List[Tuple[str, int, int]]:
+) -> List[DatasetRow]:
     input_lens = np.random.randint(
         max(int(input_len * range_ratio), 1),
         input_len + 1,
@@ -857,7 +870,7 @@ def sample_random_requests(
         random.shuffle(dataset)

         # Filter out sequences that are too long or too short
-        input_requests: List[Tuple[str, int, int]] = []
+        input_requests: List[DatasetRow] = []
         for data in dataset:
             i = len(input_requests)
             if i == num_prompts:
@@ -878,7 +891,13 @@ def sample_random_requests(
             ratio = (input_lens[i] + prompt_len - 1) // prompt_len
             input_ids = (prompt_token_ids * ratio)[: input_lens[i]]
             prompt = tokenizer.decode(input_ids)
-            input_requests.append((prompt, int(input_lens[i]), int(output_lens[i])))
+            input_requests.append(
+                DatasetRow(
+                    prompt=prompt,
+                    prompt_len=int(input_lens[i]),
+                    output_len=int(output_lens[i]),
+                )
+            )
     else:
         # Sample token ids from random integers. This can cause some NaN issues.
         offsets = np.random.randint(0, tokenizer.vocab_size, size=num_prompts)
@@ -890,7 +909,13 @@ def sample_random_requests(
                 for j in range(input_lens[i])
             ]
         )
-        input_requests.append((prompt, int(input_lens[i]), int(output_lens[i])))
+        input_requests.append(
+            DatasetRow(
+                prompt=prompt,
+                prompt_len=int(input_lens[i]),
+                output_len=int(output_lens[i]),
+            )
+        )

     print(f"#Input tokens: {np.sum(input_lens)}")
     print(f"#Output tokens: {np.sum(output_lens)}")
@@ -925,7 +950,7 @@ def sample_generated_shared_prefix_requests(
     output_len: int,
     tokenizer: PreTrainedTokenizerBase,
     args: argparse.Namespace,
-) -> List[Tuple[str, int, int]]:
+) -> List[DatasetRow]:
     """Generate benchmark requests with shared system prompts using random tokens and caching."""
     cache_path = get_gen_prefix_cache_path(args, tokenizer)
@@ -963,7 +988,11 @@ def sample_generated_shared_prefix_requests(
             full_prompt = f"{system_prompt}\n\n{question}"
             prompt_len = len(tokenizer.encode(full_prompt))

-            input_requests.append((full_prompt, prompt_len, output_len))
+            input_requests.append(
+                DatasetRow(
+                    prompt=full_prompt, prompt_len=prompt_len, output_len=output_len
+                )
+            )
             total_input_tokens += prompt_len
             total_output_tokens += output_len
@@ -994,9 +1023,9 @@ def sample_generated_shared_prefix_requests(
 async def get_request(
-    input_requests: List[Tuple[str, int, int]],
+    input_requests: List[DatasetRow],
     request_rate: float,
-) -> AsyncGenerator[Tuple[str, int, int], None]:
+) -> AsyncGenerator[DatasetRow, None]:
     input_requests = iter(input_requests)
     for request in input_requests:
         yield request
@@ -1012,7 +1041,7 @@ async def get_request(
 def calculate_metrics(
-    input_requests: List[Tuple[str, int, int]],
+    input_requests: List[DatasetRow],
     outputs: List[RequestFuncOutput],
     dur_s: float,
     tokenizer: PreTrainedTokenizerBase,
@@ -1034,7 +1063,7 @@ def calculate_metrics(
                 tokenizer.encode(outputs[i].generated_text, add_special_tokens=False)
             )
             retokenized_output_lens.append(retokenized_output_len)
-            total_input += input_requests[i][1]
+            total_input += input_requests[i].prompt_len
             if output_len > 1:
                 tpots.append((outputs[i].latency - outputs[i].ttft) / (output_len - 1))
             itls += outputs[i].itl
@@ -1096,7 +1125,7 @@ async def benchmark(
     base_url: str,
     model_id: str,
     tokenizer: PreTrainedTokenizerBase,
-    input_requests: List[Tuple[str, int, int]],
+    input_requests: List[DatasetRow],
     request_rate: float,
     max_concurrency: Optional[int],
     disable_tqdm: bool,
@@ -1126,7 +1155,12 @@ async def benchmark(
     print(f"Starting warmup with {warmup_requests} sequences...")

     # Use the first request for all warmup iterations
-    test_prompt, test_prompt_len, test_output_len = input_requests[0]
+    test_request = input_requests[0]
+    test_prompt, test_prompt_len, test_output_len = (
+        test_request.prompt,
+        test_request.prompt_len,
+        test_request.output_len,
+    )
     if lora_names is not None and len(lora_names) != 0:
         lora_name = lora_names[0]
     else:
@@ -1194,7 +1228,11 @@ async def benchmark(
     benchmark_start_time = time.perf_counter()
     tasks: List[asyncio.Task] = []
     async for request in get_request(input_requests, request_rate):
-        prompt, prompt_len, output_len = request
+        prompt, prompt_len, output_len = (
+            request.prompt,
+            request.prompt_len,
+            request.output_len,
+        )
         if lora_names is not None and len(lora_names) != 0:
             idx = random.randint(0, len(lora_names) - 1)
             lora_name = lora_names[idx]
...
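
Because every sampler now returns List[DatasetRow] and every consumer reads named attributes, adding a new per-request attribute becomes a single-site change. A hedged sketch of that extensibility (the image_data field is hypothetical, not part of this commit): giving a new field a default keeps every existing three-argument DatasetRow(...) call site valid.

from dataclasses import dataclass
from typing import List, Optional


@dataclass
class DatasetRow:
    prompt: str
    prompt_len: int
    output_len: int
    # Hypothetical extension: a defaulted field leaves all existing
    # DatasetRow(prompt=..., prompt_len=..., output_len=...) calls unchanged.
    image_data: Optional[List[str]] = None


row = DatasetRow(prompt="Hi", prompt_len=1, output_len=8)
assert row.image_data is None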