"vscode:/vscode.git/clone" did not exist on "30de083dd061a8769cf5a85cbc6805e1bc0a557f"
Commit 118f1fc7 authored by maxiao1's avatar maxiao1
Browse files

sglangv0.5.2 & support Qwen3-Next-80B-A3B-Instruct

parents
## Run benchmark
### Benchmark sglang
```
python3 -m sglang.launch_server --model-path codellama/CodeLlama-7b-instruct-hf --port 30000
```
```
python3 bench_sglang.py --num-questions 10 --parallel 1
```
### Benchmark vLLM
```
python3 -m vllm.entrypoints.api_server --tokenizer-mode auto --model codellama/CodeLlama-7b-instruct-hf --disable-log-requests --port 21000 --gpu-memory-utilization 0.97
```
```
python3 bench_other.py --backend vllm --num-questions 64
```
### Benchmark guidance
```
python3 bench_other.py --backend guidance --num-questions 32 --parallel 1 --n-ctx 11000 --model-path path/to/code-llama/gguf
```
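Both `bench_sglang.py` and `bench_other.py` (shown below) print accuracy and latency and append one JSON record per run to the result file (`args.result_file`, added by the common benchmark args). A rough sketch of that record, with placeholder values:
```python
# Illustrative sketch of one result record; the real dict is built near the
# end of bench_sglang.py / bench_other.py below. Values here are placeholders.
result_record = {
    "task": "multi_document_qa",
    "backend": "<backend name>",   # whatever backend the run used
    "num_gpus": 1,
    "latency": 12.345,             # seconds, rounded to 3 decimals
    "num_requests": 10,            # equals --num-questions
    "accuracy": 0.9,
    "other": {"num_questions": 10, "parallel": 1},
}
```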
### Build dataset
```
pip install PyPDF2
python3 build_dataset.py
```
```python
import PyPDF2
with open('llama2.pdf', 'rb') as file:
    reader = PyPDF2.PdfReader(file)
    text = ''
    for page in reader.pages:
        text += page.extract_text()
# Write the extracted text to the file that build_dataset.py reads.
with open('llama2.txt', 'w') as text_file:
    text_file.write(text)
```
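`build_dataset.py` (below) segments `llama2.txt` into roughly 1100-token chunks and writes a single JSON line to `questions.jsonl` containing the documents plus the question/answer lists. A minimal sketch of reading it back, mirroring the keys dumped at the end of that script:
```python
# Minimal sketch: load the single record that build_dataset.py writes.
# The keys mirror the dict dumped at the end of that script.
import json

with open("questions.jsonl") as fin:
    record = json.loads(fin.readline())

print(len(record["documents"]))   # ~1100-token text segments from llama2.txt
print(record["questions"][0])     # e.g. "What is the name of the fine-tuned LLMs?"
print(record["answers"][0])       # e.g. "Llama 2 Chat"
```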
import argparse
import json
import time
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from tqdm import tqdm
from sglang.test.test_utils import add_common_other_args_and_parse, get_call_generate
from sglang.utils import dump_state_text, read_jsonl
USER_PREFIX = "[INST] "
USER_SUFFIX = " [/INST]"
ASSISTANT_PREFIX = ""
ASSISTANT_SUFFIX = " </s><s>"
def multi_document_qa(docs, question, generate):
s = USER_PREFIX
s += "Please answer a question according to given documents.\n"
s += "Question:" + question + "Documents begin.\n"
s += "".join(docs)
s += "\nDocuments end."
s += (
"\n\nBased on the above documents, please answer this question:\n"
+ question
+ "\nAnswer in three words or fewer."
)
s += USER_SUFFIX
s += ASSISTANT_PREFIX
answer = generate(s, max_tokens=16, stop=None)
return answer
def main(args):
lines = read_jsonl(args.data_path)
l = lines[0]
arguments = []
labels = []
num_docs = 10
if args.backend == "guidance":
num_docs = 7 # due to OOM
for i in range(len(l["questions"][: args.num_questions])):
arguments.append(
{
"docs": l["documents"][:num_docs],
"question": l["questions"][i],
}
)
labels.append(l["answers"][i])
states = [None] * len(arguments)
# Select backend
call_generate = partial(get_call_generate(args), temperature=0)
# Run requests
def get_one_answer(i):
states[i] = multi_document_qa(generate=call_generate, **arguments[i])
tic = time.perf_counter()
if args.parallel == 1:
for i in tqdm(range(len(labels))):
get_one_answer(i)
else:
with ThreadPoolExecutor(args.parallel) as executor:
list(
tqdm(
executor.map(get_one_answer, list(range(len(labels)))),
total=len(labels),
)
)
latency = time.perf_counter() - tic
# Compute accuracy
print(states)
correct = 0
for s, label in zip(states, labels):
answer = s.lower()
if all(x in answer for x in label.lower().split(" ")):
correct += 1
accuracy = correct / len(labels)
print(f"Accuracy: {accuracy:.3f}")
print(f"Latency: {latency:.3f}")
# Write results
dump_state_text(f"tmp_output_{args.backend}.txt", states)
with open(args.result_file, "a") as fout:
value = {
"task": "multi_document_qa",
"backend": args.backend,
"num_gpus": 1,
"latency": round(latency, 3),
"num_requests": args.num_questions,
"accuracy": accuracy,
"other": {
"num_questions": args.num_questions,
"parallel": args.parallel,
},
}
fout.write(json.dumps(value) + "\n")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--data-path", type=str, default="questions.jsonl")
parser.add_argument("--num-questions", type=int, default=100)
args = add_common_other_args_and_parse(parser)
main(args)
import argparse
import json
import time
import sglang as sgl
from sglang.test.test_utils import (
add_common_sglang_args_and_parse,
select_sglang_backend,
)
from sglang.utils import dump_state_text, read_jsonl
@sgl.function
def multi_document_qa(s, docs, question):
s += sgl.user_begin()
s += "Please answer a question according to given documents.\n"
s += "Question:" + question + "Documents begin.\n"
forks = s.fork(len(docs))
forks += lambda i: docs[i]
forks.join("concate_and_append")
s += "\nDocuments end."
s += (
"\n\nBased on the above documents, please answer this question:\n"
+ question
+ "\nAnswer in three words or fewer."
)
s += sgl.user_end()
s += sgl.assistant(sgl.gen("answer", max_tokens=16))
def main(args):
lines = read_jsonl(args.data_path)
l = lines[0]
arguments = []
labels = []
for i in range(len(l["questions"][: args.num_questions])):
arguments.append(
{
"docs": l["documents"][:10],
"question": l["questions"][i],
}
)
labels.append(l["answers"][i])
# Select backend
backend = select_sglang_backend(args)
sgl.set_default_backend(backend)
# Run requests
tic = time.perf_counter()
states = multi_document_qa.run_batch(
arguments, temperature=0, num_threads=args.parallel, progress_bar=True
)
latency = time.perf_counter() - tic
# Compute accuracy
print([s["answer"] for s in states])
correct = 0
for s, label in zip(states, labels):
answer = s["answer"].lower()
if all(x in answer for x in label.lower().split(" ")):
correct += 1
accuracy = correct / len(labels)
print(f"Accuracy: {accuracy:.3f}")
print(f"Latency: {latency:.3f}")
# Write results
dump_state_text(f"tmp_output_{args.backend}.txt", states)
with open(args.result_file, "a") as fout:
value = {
"task": "multi_document_qa",
"backend": args.backend,
"num_gpus": 1,
"latency": round(latency, 3),
"num_requests": args.num_questions,
"accuracy": accuracy,
"other": {
"num_questions": args.num_questions,
"parallel": args.parallel,
},
}
fout.write(json.dumps(value) + "\n")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--data-path", type=str, default="questions.jsonl")
parser.add_argument("--num-questions", type=int, default=100)
args = add_common_sglang_args_and_parse(parser)
main(args)
import json
import transformers
content = "\n".join(
open("llama2.txt", "r", encoding="utf-8", errors="ignore").readlines()
)
content = content.replace("\n\n", "\n")
# Count token
name = "meta-llama/Llama-2-7b-chat-hf"
t = transformers.AutoTokenizer.from_pretrained(name)
print(f"num tokens: {len(t.encode(content))}")
# Segment
SEP = "\n\n"
parts = content.split(SEP)
print(f"num segments: {len(parts)}")
segment_len = 1100
segments = []
tmp = []
tmp_len = 0
for i in range(len(parts)):
tmp.append(parts[i])
tmp_len += len(t.encode(parts[i]))
if tmp_len > segment_len:
segments.append(SEP.join(tmp))
tmp = []
tmp_len = 0
for i, s in enumerate(segments):
    print(i, len(t.encode(s)))
# Dump
with open("questions.jsonl", "w") as fout:
fout.write(
json.dumps(
{
"documents": segments[:30],
"questions": [
"What is the name of the fine-tuned LLMs?",
"Which figure shows the helpfulness human evaluation results for Llama 2-Chat?",
"What is the number of parameters in the largest Llama 2 model?",
"What is the batch size of fine-tuning?",
"Where can we find the details of potential data contamination?",
"What is the full name of MPT?",
"What is the power consumption of RSC in Watt?",
"How many tokens of data do they train on?",
"Which model's release is delayed due to a lack of time to sufficiently red team?",
"Which activation function is used in Llama?",
],
"answers": [
"Llama 2 Chat",
"1",
"70 B",
"64",
"A 6",
"MosaicML",
"400",
"2 trillion",
"34 B",
"SwiGLU",
],
}
)
+ "\n"
)
### Benchmark sglang
Run Llama-7B
```
python3 -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
```
Run Mixtral-8x7B
(If you hit a CUDA out-of-memory error, try reducing `--mem-fraction-static`.)
```
python3 -m sglang.launch_server --model-path mistralai/Mixtral-8x7B-Instruct-v0.1 --port 30000 --tp-size 8
```
Benchmark (short output)
```
python3 bench_sglang.py --tokenizer meta-llama/Llama-2-7b-chat-hf
```
Benchmark (long output)
```
python3 bench_sglang.py --tokenizer meta-llama/Llama-2-7b-chat-hf --long
```
### Benchmark vLLM
Run Llama-7B
```
python3 -m vllm.entrypoints.api_server --tokenizer-mode auto --model meta-llama/Llama-2-7b-chat-hf --disable-log-requests --port 21000
```
Run Mixtral-8x7B
```
python3 -m vllm.entrypoints.api_server --tokenizer-mode auto --model mistralai/Mixtral-8x7B-Instruct-v0.1 --disable-log-requests --port 21000 --tensor-parallel-size 8
```
Benchmark (short output)
```
python3 bench_other.py --tokenizer meta-llama/Llama-2-7b-chat-hf --backend vllm
```
Benchmark (long output)
```
python3 bench_other.py --tokenizer meta-llama/Llama-2-7b-chat-hf --backend vllm --long
```
### Benchmark guidance
Benchmark Llama-7B (short output)
```
python3 bench_other.py --tokenizer meta-llama/Llama-2-7b-chat-hf --backend guidance --parallel 1 --n-ctx 4096 --model-path path/to/gguf
```
Benchmark Llama-7B (long output)
```
python3 bench_other.py --tokenizer meta-llama/Llama-2-7b-chat-hf --backend guidance --parallel 1 --n-ctx 4096 --model-path path/to/gguf --long
```
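Both scripts also accept flags that shape the synthetic multi-turn workload (`--turns`, `--num-qa`, and the question/answer length ranges; see their argument parsers below), for example:
```
python3 bench_sglang.py --tokenizer meta-llama/Llama-2-7b-chat-hf --turns 8 --num-qa 64
```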
import json
import time
from argparse import ArgumentParser
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from data_gen import gen_arguments
from tqdm import tqdm
from vllm.transformers_utils.tokenizer import get_tokenizer
from sglang.test.test_utils import add_common_other_args_and_parse, get_call_generate
from sglang.utils import dump_state_text
def multi_turns(generate, qas):
s = ""
for qa in qas:
s += qa["prompt"]
s += generate(s, max_tokens=qa["new_tokens"])
return s
def main(args):
print(args)
tokenizer = get_tokenizer(args.tokenizer, trust_remote_code=args.trust_remote_code)
multi_qas = gen_arguments(args, tokenizer)
states = [None] * args.num_qa
call_generate = partial(get_call_generate(args), temperature=0)
def get_one_answer(i):
states[i] = multi_turns(generate=call_generate, **multi_qas[i])
tic = time.perf_counter()
if args.parallel == 1:
for i in tqdm(range(len(multi_qas))):
get_one_answer(i)
else:
with ThreadPoolExecutor(args.parallel) as executor:
rets = list(
tqdm(
executor.map(get_one_answer, list(range(len(multi_qas)))),
total=len(multi_qas),
)
)
for _ in rets:
pass
latency = time.perf_counter() - tic
# Compute accuracy
print(f"Latency: {latency:.3f}")
dump_state_text(f"tmp_output_{args.backend}.txt", states)
with open(args.result_file, "a") as fout:
value = {
"task": "multi_turn_chat",
"backend": args.backend,
"num_gpus": 1,
"latency": round(latency, 3),
"num_requests": args.num_qa,
"num_turns": args.turns,
"other": {
"parallel": args.parallel,
"output_mode": "long" if args.long else "short",
},
}
fout.write(json.dumps(value) + "\n")
if __name__ == "__main__":
parser = ArgumentParser()
parser.add_argument("--turns", type=int, default=4)
parser.add_argument("--num-qa", type=int, default=20)
parser.add_argument("--min-len-q", type=int, default=256)
parser.add_argument("--max-len-q", type=int, default=512)
parser.add_argument("--min-len-a", type=int, default=4)
parser.add_argument("--max-len-a", type=int, default=8)
parser.add_argument("--tokenizer", type=str, required=True)
parser.add_argument("--trust-remote-code", action="store_true")
parser.add_argument("--long", action="store_true")
args = add_common_other_args_and_parse(parser)
if args.long:
args.min_len_a = 256
args.max_len_a = 512
args.num_qa = 20
main(args)
import json
import time
from argparse import ArgumentParser
from data_gen import gen_arguments
from vllm.transformers_utils.tokenizer import get_tokenizer
import sglang as sgl
from sglang.test.test_utils import (
add_common_sglang_args_and_parse,
select_sglang_backend,
)
from sglang.utils import dump_state_text
@sgl.function
def multi_turns(s, qas):
for qa in qas:
s += qa["prompt"]
s += sgl.gen(max_tokens=qa["new_tokens"], ignore_eos=True)
def main(args):
tokenizer = get_tokenizer(args.tokenizer, trust_remote_code=args.trust_remote_code)
multi_qas = gen_arguments(args, tokenizer)
backend = select_sglang_backend(args)
tic = time.perf_counter()
states = multi_turns.run_batch(
multi_qas,
temperature=0,
backend=backend,
num_threads=args.parallel,
progress_bar=True,
)
latency = time.perf_counter() - tic
print(f"Latency: {latency:.3f}")
dump_state_text(f"tmp_output_{args.backend}.txt", states)
with open(args.result_file, "a") as fout:
value = {
"task": "multi_turn_chat",
"backend": args.backend,
"num_gpus": 1,
"latency": round(latency, 3),
"num_requests": args.num_qa,
"num_turns": args.turns,
"other": {
"parallel": args.parallel,
"output_mode": "long" if args.long else "short",
},
}
fout.write(json.dumps(value) + "\n")
if __name__ == "__main__":
parser = ArgumentParser()
parser.add_argument("--turns", type=int, default=4)
parser.add_argument("--num-qa", type=int, default=20)
parser.add_argument("--min-len-q", type=int, default=256)
parser.add_argument("--max-len-q", type=int, default=512)
parser.add_argument("--min-len-a", type=int, default=4)
parser.add_argument("--max-len-a", type=int, default=8)
parser.add_argument("--tokenizer", type=str, required=True)
parser.add_argument("--trust-remote-code", action="store_true")
parser.add_argument("--long", action="store_true")
args = add_common_sglang_args_and_parse(parser)
if args.long:
args.min_len_a = 256
args.max_len_a = 512
args.num_qa = 20
print(args)
main(args)
import random
import string
random.seed(42)
def gen_prompt(tokenizer, token_num):
cha_set = string.ascii_letters + string.digits
ret = "".join(random.choices(cha_set, k=token_num))
while len(tokenizer(ret).input_ids) < token_num:
ret += random.choice(cha_set)
return ret
def gen_arguments(args, tokenizer):
multi_qas = [{"qas": []} for _ in range(args.num_qa)]
for i in range(args.num_qa):
qas = multi_qas[i]["qas"]
for _ in range(args.turns):
prompt_len = random.randint(args.min_len_q, args.max_len_q)
new_tokens = random.randint(args.min_len_a, args.max_len_a)
qas.append(
{
"prompt": gen_prompt(tokenizer, prompt_len),
"new_tokens": new_tokens,
}
)
return multi_qas
import json
import random
import time
from argparse import ArgumentParser
from pathlib import Path
from tqdm import tqdm
import sglang as sgl
from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.test.test_utils import (
add_common_sglang_args_and_parse,
select_sglang_backend,
)
from sglang.utils import dump_state_text
def gen_prompt(tokenizer, token_num):
all_available_tokens = list(tokenizer.get_vocab().values())
selected_tokens = random.choices(all_available_tokens, k=token_num)
ret = tokenizer.decode(selected_tokens)
return ret
def get_cache_path(args):
# Create cache directory under ~/.cache/sglang
cache_dir = Path.home() / ".cache" / "sglang"
# Create a unique cache filename based on the arguments that affect generation
cache_key = f"qa_{args.num_qa}_{args.turns}_{args.system_prompt_len}_{args.len_q}_{args.len_a}_{args.tokenizer.replace('/', '_')}.json"
return cache_dir / cache_key
def gen_arguments(args, tokenizer):
cache_path = get_cache_path(args)
# Try to load from cache first
if cache_path.exists():
print(f"Loading cached arguments from {cache_path}")
with open(cache_path, "r") as f:
return json.load(f)
print("Generating new arguments...")
# First progress bar for system prompts
multi_qas = []
for _ in tqdm(range(args.num_qa), desc="Generating system prompts"):
multi_qas.append(
{"system_prompt": gen_prompt(tokenizer, args.system_prompt_len), "qas": []}
)
# Nested progress bars for QA pairs
for i in tqdm(range(args.num_qa), desc="Generating QA pairs"):
qas = multi_qas[i]["qas"]
for j in range(args.turns):
qas.append(
{
"prompt": gen_prompt(tokenizer, args.len_q),
"new_tokens": args.len_a,
}
)
# Save to cache
cache_path.parent.mkdir(parents=True, exist_ok=True)
with open(cache_path, "w") as f:
json.dump(multi_qas, f)
print(f"Cached arguments saved to {cache_path}")
return multi_qas
@sgl.function
def multi_turns(s, system_prompt, qas):
s += system_prompt
for i, qa in enumerate(qas):
s += qa["prompt"]
s += sgl.gen(max_tokens=qa["new_tokens"], ignore_eos=True)
def main(args):
tokenizer = get_tokenizer(args.tokenizer, trust_remote_code=args.trust_remote_code)
multi_qas = gen_arguments(args, tokenizer)
backend = select_sglang_backend(args)
tic = time.perf_counter()
states = multi_turns.run_batch(
multi_qas,
temperature=0,
backend=backend,
num_threads="auto",
progress_bar=True,
)
latency = time.perf_counter() - tic
print(f"Latency: {latency:.3f}")
dump_state_text(f"tmp_output_{args.backend}.txt", states)
with open(args.result_file, "a") as fout:
value = {
"task": "multi_turn_system_prompt_chat",
"backend": args.backend,
"latency": round(latency, 3),
"num_requests": args.num_qa,
"num_turns": args.turns,
"other": {
"parallel": args.parallel,
},
}
fout.write(json.dumps(value) + "\n")
if __name__ == "__main__":
parser = ArgumentParser()
parser.add_argument("--turns", type=int, default=8)
parser.add_argument("--num-qa", type=int, default=128)
parser.add_argument("--system-prompt-len", type=int, default=2048)
parser.add_argument("--len-q", type=int, default=32)
parser.add_argument("--len-a", type=int, default=128)
parser.add_argument(
"--tokenizer", type=str, default="meta-llama/Meta-Llama-3-8B-Instruct"
)
parser.add_argument("--trust-remote-code", action="store_true")
args = add_common_sglang_args_and_parse(parser)
print(args)
main(args)
"""
SGLang Embeddings Benchmark Script
This script benchmarks SGLang's /v1/embeddings API performance using HTTP requests.
Features:
- HTTP-only implementation
- Uses /v1/embeddings API endpoint directly
- Configurable RPS, duration, and batch sizes
- Progress tracking and detailed metrics
- Poisson and constant request distributions
Usage:
- Update configuration variables at the top of the file
- Ensure SGLang server is running on the configured HTTP_URL
- Run: python bench_embeddings.py
"""
import asyncio
import logging
from transformers import AutoTokenizer
from util import (
BenchmarkConfig,
generate_text_with_token_count,
run_benchmark_main,
run_generic_benchmark,
)
# Configure logging
logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)
###############################################################################
# CONFIG
###############################################################################
# Create benchmark configuration
config = BenchmarkConfig()
config.rps_values = [500]
config.duration_secs_values = [60]
config.num_unique_requests = 100
config.distribution = "POISSON"
config.profile = False
config.freeze_gc = True # Enable GC freeze functionality
# Profiler output directory - by default uses present working directory (pwd)
# Uncomment and customize the line below to override the default location:
# config.profiler_dir = "/sglang-oss-trace"
# HTTP Configuration
HTTP_URL = "http://localhost:30000/v1/embeddings"
# Embeddings API Config
EMBEDDINGS_MODEL_PATH = "/Qwen/Qwen3-Embedding-0.6B"
BATCH_SIZE = [1] # Number of items per request (batch size)
# Configurable input token length
EMBEDDINGS_INPUT_TOKENS = 500 # Default token length
# Load tokenizer once for embeddings text generation
print("Loading tokenizer for embeddings input generation...")
embeddings_tokenizer = AutoTokenizer.from_pretrained(EMBEDDINGS_MODEL_PATH)
# Generate input text with the specified token length using pre-loaded tokenizer
EMBEDDINGS_INPUT_TEXT = generate_text_with_token_count(
EMBEDDINGS_MODEL_PATH,
EMBEDDINGS_INPUT_TOKENS,
config.special_replicated_token,
tokenizer=embeddings_tokenizer,
)
###############################################################################
# REQUEST GENERATION (in parallel)
###############################################################################
def build_embeddings_request(index: int, item_count: int) -> tuple:
"""Build a single embeddings request."""
try:
# For embeddings, input can be a string or list of strings
if item_count == 1:
input_data = EMBEDDINGS_INPUT_TEXT
else:
input_data = [EMBEDDINGS_INPUT_TEXT for _ in range(item_count)]
req = {
"input": input_data,
"model": EMBEDDINGS_MODEL_PATH,
}
return (index, req)
except Exception as e:
logger.error(f"Error building request {index}: {e}")
return (index, None)
def validate_embeddings_response(response_data: dict) -> bool:
"""Validate embeddings API response."""
return "data" in response_data
def build_warmup_embeddings_request() -> dict:
"""Build a warmup request for the embeddings API."""
return {
"input": EMBEDDINGS_INPUT_TEXT,
"model": EMBEDDINGS_MODEL_PATH,
}
###############################################################################
# MAIN
###############################################################################
async def run_benchmark(rps, duration_secs, item_count):
"""Run a single embeddings benchmark with the given RPS value."""
return await run_generic_benchmark(
rps=rps,
duration_secs=duration_secs,
item_count=item_count,
config=config,
http_url=HTTP_URL,
build_request_func=build_embeddings_request,
response_validator=validate_embeddings_response,
api_name="EMBEDDINGS",
request_description="embeddings requests",
)
async def main():
additional_info = {
"Input text length": f"{EMBEDDINGS_INPUT_TOKENS} tokens",
"Input text preview": (
EMBEDDINGS_INPUT_TEXT[:100] + "..."
if len(EMBEDDINGS_INPUT_TEXT) > 100
else EMBEDDINGS_INPUT_TEXT
),
}
await run_benchmark_main(
config,
run_benchmark,
"EMBEDDINGS",
HTTP_URL,
BATCH_SIZE,
additional_info,
build_warmup_embeddings_request,
)
if __name__ == "__main__":
asyncio.run(main())
"""
SGLang Scoring Benchmark Script
This script benchmarks SGLang's scoring API performance using HTTP requests.
Current Features:
- HTTP-only implementation (open source compatible)
- Uses /v1/score API endpoint directly
- Single item scoring with batching support
- Configurable RPS, duration, and batch sizes
- Progress tracking and detailed metrics
- Poisson and constant request distributions
Usage:
- Update configuration variables at the top of the file
- Ensure SGLang server is running on the configured HTTP_URL
- Run: python bench_score.py
- Each request will contain ITEM_COUNT_VALUES items for batch scoring
"""
import asyncio
from transformers import AutoTokenizer
from util import (
BenchmarkConfig,
generate_text_with_token_count,
run_benchmark_main,
run_generic_benchmark,
)
###############################################################################
# CONFIG
###############################################################################
# Create benchmark configuration
config = BenchmarkConfig()
config.rps_values = [160]
config.duration_secs_values = [60]
config.num_unique_requests = 100
config.distribution = "POISSON"
config.profile = False
config.freeze_gc = True # Enable GC freeze functionality
# Profiler output directory - by default uses present working directory (pwd)
# Uncomment and customize the line below to override the default location:
# config.profiler_dir = "/sglang-oss-trace"
# HTTP Configuration
HTTP_URL = "http://localhost:30000/v1/score" # Use score API directly
# Score API Config
# ITEM_COUNT_VALUES determines number of items per score request (batch size)
SCORE_QUERY_TOKENS = 120
SCORE_ITEM_TOKENS = 180
SCORE_MODEL_PATH = "Qwen/Qwen3-0.6B"
SCORE_LABEL_TOKEN_IDS = [9454, 2753] # Yes/No token IDs
ITEM_COUNT_VALUES = [10] # Number of items per request
# Special token to replicate for precise token counting
SPECIAL_REPLICATED_TOKEN = "<|im_start|>"
###############################################################################
# REQUEST GENERATION (in parallel)
###############################################################################
def create_score_request_builder():
"""Create a score request builder function with shared tokenizer."""
# Load tokenizer once here to verify special token and get precise counts
print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(SCORE_MODEL_PATH)
# Verify that our special token produces exactly 1 token
special_token_count = len(
tokenizer.encode(config.special_replicated_token, add_special_tokens=False)
)
print(
f"Special token '{config.special_replicated_token}' produces "
f"{special_token_count} token(s)"
)
def generate_text_with_token_count_local(num_toks):
"""Generate text with precise token count using replicated token."""
return generate_text_with_token_count(
SCORE_MODEL_PATH,
num_toks,
config.special_replicated_token,
tokenizer=tokenizer,
)
def build_score_request(index: int, item_count: int) -> tuple:
"""Build a single score request."""
try:
# Generate query and items for score API
query = generate_text_with_token_count_local(SCORE_QUERY_TOKENS)
items = [
generate_text_with_token_count_local(SCORE_ITEM_TOKENS)
for _ in range(item_count)
]
# Return as dict for score API format
score_data = {
"query": query,
"items": items,
"label_token_ids": SCORE_LABEL_TOKEN_IDS,
"model": SCORE_MODEL_PATH,
}
return (index, score_data)
except Exception as e:
print(f"Error building request {index}: {e}")
return (index, None)
return build_score_request
def validate_score_response(response_data: dict) -> bool:
"""Validate score API response."""
return "scores" in response_data or "logprobs" in response_data
def build_warmup_score_request() -> dict:
"""Build a warmup request for the score API."""
# Load tokenizer once for warmup generation
tokenizer = AutoTokenizer.from_pretrained(SCORE_MODEL_PATH)
warmup_query = generate_text_with_token_count(
SCORE_MODEL_PATH,
SCORE_QUERY_TOKENS,
config.special_replicated_token,
tokenizer=tokenizer,
)
warmup_items = [
generate_text_with_token_count(
SCORE_MODEL_PATH,
SCORE_ITEM_TOKENS,
config.special_replicated_token,
tokenizer=tokenizer,
)
for _ in range(3)
]
return {
"query": warmup_query,
"items": warmup_items,
"label_token_ids": SCORE_LABEL_TOKEN_IDS,
"model": SCORE_MODEL_PATH,
# Add missing parameters for consistency with the original warmup
"apply_softmax": True,
"item_first": False,
}
###############################################################################
# MAIN
###############################################################################
async def run_benchmark(rps, duration_secs, item_count):
"""Run a single benchmark with the given RPS value."""
# Create the request builder function with shared tokenizer
build_request_func = create_score_request_builder()
return await run_generic_benchmark(
rps=rps,
duration_secs=duration_secs,
item_count=item_count,
config=config,
http_url=HTTP_URL,
build_request_func=build_request_func,
response_validator=validate_score_response,
api_name="SINGLE_ITEM_SCORING",
request_description="score requests",
)
async def main():
"""Main function that runs benchmarks for all RPS values."""
additional_info = {
"Query tokens per request": SCORE_QUERY_TOKENS,
"Item tokens per item": SCORE_ITEM_TOKENS,
}
await run_benchmark_main(
config,
run_benchmark,
"SINGLE_ITEM_SCORING",
HTTP_URL,
ITEM_COUNT_VALUES,
additional_info,
build_warmup_score_request,
)
if __name__ == "__main__":
asyncio.run(main())
"""
Common utilities for SGLang benchmark scripts.
This module contains shared code for benchmarking different SGLang APIs
including scoring, embeddings, and other endpoints.
"""
import asyncio
import concurrent.futures
import json
import os
import random
from statistics import mean
from typing import Any, Callable, Dict, List, Optional, Tuple
import aiohttp
import numpy as np
from tqdm import tqdm
from transformers import AutoTokenizer
class BenchmarkConfig:
"""Configuration for benchmark parameters."""
def __init__(self):
# Common benchmark settings
self.server_type = "HTTP"
self.rps_values = [70]
self.duration_secs_values = [60]
self.num_unique_requests = 100
self.distribution = "POISSON" # Options: "CONSTANT", "POISSON"
self.profile = False
# Garbage Collection Control
self.freeze_gc = True # Enable/disable garbage collection freezing
# Profiler configuration
self.profiler_dir = (
os.getcwd()
) # Default profiler output directory (current working directory)
# Special token for text generation
self.special_replicated_token = "<|im_start|>"
def generate_text_with_token_count(
model_path: str,
num_tokens: int,
special_token: str = "<|im_start|>",
tokenizer: Optional[Any] = None,
) -> str:
"""
Generate text with precise token count using a replicated token.
Args:
model_path: Path to the model for tokenizer
num_tokens: Target number of tokens
special_token: Token to replicate
tokenizer: Optional pre-loaded tokenizer to avoid repeated loading
Returns:
Generated text with approximately the target token count
"""
if tokenizer is None:
tokenizer = AutoTokenizer.from_pretrained(model_path)
# Verify token count
special_token_count = len(tokenizer.encode(special_token, add_special_tokens=False))
if special_token_count == 1:
# Simple case: token maps to exactly 1 token
return special_token * num_tokens
else:
print(f"Special token '{special_token}' produces {special_token_count} tokens")
# Handle case where special token produces multiple tokens
repetitions = (num_tokens + special_token_count - 1) // special_token_count
text = special_token * repetitions
# Verify we got the expected token count
actual_tokens = len(tokenizer.encode(text, add_special_tokens=False))
if actual_tokens < num_tokens:
print(f"Warning: Generated {actual_tokens} tokens, expected {num_tokens}")
return text
def setup_profiler(config: BenchmarkConfig, benchmark_name: str) -> None:
"""
Set up profiler environment if profiling is enabled.
Args:
config: Benchmark configuration
benchmark_name: Name of the benchmark (used in directory path)
"""
if config.profile:
# Create benchmark-specific subdirectory
profiler_path = os.path.join(
config.profiler_dir, benchmark_name.lower().replace("_", "-")
)
os.environ["SGLANG_TORCH_PROFILER_DIR"] = profiler_path
print(f"Profiler enabled. Output directory: {profiler_path}")
else:
print("Profiler disabled")
def prepare_all_requests_parallel(
num_requests: int,
item_count: int,
build_request_func: Callable[[int, int], Tuple[int, Any]],
config: BenchmarkConfig,
description: str = "requests",
) -> List[Any]:
"""
Generic function to generate unique requests in parallel, then reuse them.
Args:
num_requests: Total number of requests needed
item_count: Number of items per request (batch size)
build_request_func: Function that takes (index, item_count) and returns (index, request_data)
config: Benchmark configuration
description: Description for progress bars
Returns:
List of request data objects
"""
def build_request_wrapper(index):
"""Wrapper to call the provided build_request_func."""
try:
return build_request_func(index, item_count)
except Exception as e:
print(f"Error building request {index}: {e}")
return (index, None)
# Generate only the unique requests
unique_requests = [None] * config.num_unique_requests
max_workers = min(8, os.cpu_count() or 1) # Limit to 8 threads max
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
futures = []
for i in tqdm(
range(config.num_unique_requests),
desc=f"Submitting {description} generation tasks",
):
future = executor.submit(build_request_wrapper, i)
futures.append(future)
# Collect results as they complete
for f in tqdm(
concurrent.futures.as_completed(futures),
desc=f"Building unique {description}",
total=config.num_unique_requests,
):
try:
index, req_data = f.result()
if req_data is not None:
unique_requests[index] = req_data
else:
print(f"Failed to build request {index}")
except Exception as e:
print(f"Error processing request result: {e}")
# Check if we have any valid requests
valid_requests = [req for req in unique_requests if req is not None]
if not valid_requests:
raise RuntimeError("Failed to generate any valid requests")
print(
f"Successfully generated {len(valid_requests)} out of "
f"{config.num_unique_requests} unique {description}"
)
# Create the full request list by cycling through unique requests
print(
f"Reusing {len(valid_requests)} unique {description} to create "
f"{num_requests} total requests..."
)
all_requests = []
for i in tqdm(range(num_requests), desc=f"Reusing {description}"):
unique_index = i % len(valid_requests)
all_requests.append(valid_requests[unique_index])
print(f"All {description} prepared.\n")
return all_requests
async def sleep_with_distribution(distribution: str, rps: float) -> None:
"""
Sleep according to the specified distribution pattern.
Args:
distribution: "CONSTANT" or "POISSON"
rps: Requests per second rate
"""
if distribution == "CONSTANT":
interval = 1 / rps
await asyncio.sleep(interval)
elif distribution == "POISSON":
# For Poisson process, inter-arrival times follow exponential distribution
interval = random.expovariate(rps)
await asyncio.sleep(interval)
else:
raise ValueError(
f"Unknown distribution: {distribution}. Use 'CONSTANT' or 'POISSON'."
)
def build_http_request_json(request_data: Any) -> str:
"""
Generic function to build HTTP request JSON.
Args:
request_data: The data to serialize to JSON
Returns:
JSON string representation of the request data
"""
return json.dumps(request_data)
async def make_http_call(
session: aiohttp.ClientSession,
request_data: Any,
request_id: int,
results_queue: asyncio.Queue,
http_url: str,
response_validator: Callable[[Dict[str, Any]], bool],
api_name: str = "API",
) -> None:
"""
Generic HTTP call function for API requests.
Args:
session: aiohttp client session
request_data: Data to send in the request
request_id: Unique identifier for this request
results_queue: Queue to put results
http_url: URL to send the request to
response_validator: Function to validate the response JSON
api_name: Name of the API for error messages
"""
try:
start_time = asyncio.get_event_loop().time()
request_json = build_http_request_json(request_data)
headers = {"Content-Type": "application/json"}
async with session.post(http_url, data=request_json, headers=headers) as resp:
resp_text = await resp.text()
if resp.status != 200:
print(
f"[HTTP] {api_name} Request {request_id} failed with status "
f"{resp.status}: {resp_text}"
)
completion_time = asyncio.get_event_loop().time()
await results_queue.put((request_id, 0, False, completion_time))
return
# Parse and validate response
try:
response_data = json.loads(resp_text)
success = response_validator(response_data)
if not success:
print(
f"[HTTP] {api_name} Request {request_id} failed response validation"
)
except json.JSONDecodeError:
print(
f"[HTTP] {api_name} Request {request_id} failed to parse JSON response"
)
success = False
completion_time = asyncio.get_event_loop().time()
elapsed_time = (completion_time - start_time) * 1000
await results_queue.put((request_id, elapsed_time, success, completion_time))
except Exception as e:
print(f"[HTTP] {api_name} Error for request {request_id}: {e}")
completion_time = asyncio.get_event_loop().time()
await results_queue.put((request_id, 0, False, completion_time))
async def send_profile_request(
profile_text: str, http_url: str, session: Optional[aiohttp.ClientSession] = None
) -> None:
"""
Send a profile request (START_PROFILE or STOP_PROFILE) and wait for completion.
Args:
profile_text: "START_PROFILE" or "STOP_PROFILE"
http_url: Base HTTP URL (will derive profile endpoints from this)
session: Optional aiohttp session to use
"""
try:
if session:
print(f"Sending {profile_text} request via HTTP...")
# Determine the correct endpoint
if "/v1/" in http_url:
base_url = http_url.rsplit("/v1/", 1)[0] # Remove /v1/xxx
else:
base_url = http_url.rsplit("/", 1)[0] # Remove last path component
if profile_text == "START_PROFILE":
endpoint_url = f"{base_url}/start_profile"
elif profile_text == "STOP_PROFILE":
endpoint_url = f"{base_url}/stop_profile"
else:
print(f"Unknown profile request: {profile_text}")
return
headers = {"Content-Type": "application/json"}
async with session.post(endpoint_url, headers=headers) as resp:
resp_text = await resp.text()
if resp.status == 200:
print(f"{profile_text} request completed")
else:
print(
f"{profile_text} request failed with status "
f"{resp.status}: {resp_text}"
)
else:
print(f"Cannot send {profile_text} request - missing session")
except Exception as e:
print(f"Error sending {profile_text} request: {e}")
async def call_freeze_gc_http(session: aiohttp.ClientSession, http_url: str) -> None:
"""
Call the /freeze_gc HTTP endpoint.
Args:
session: aiohttp client session
http_url: Base HTTP URL to derive the freeze_gc endpoint from
"""
try:
# Derive freeze_gc endpoint from the API URL
if "/v1/" in http_url:
freeze_gc_url = http_url.rsplit("/v1/", 1)[0] + "/freeze_gc"
else:
freeze_gc_url = http_url.rsplit("/", 1)[0] + "/freeze_gc"
print(f"Calling freeze_gc endpoint: {freeze_gc_url}")
async with session.post(freeze_gc_url) as resp:
if resp.status == 200:
print("freeze_gc called successfully")
else:
resp_text = await resp.text()
print(f"freeze_gc failed with status {resp.status}: {resp_text}")
except Exception as e:
print(f"Failed to call freeze_gc: {e}")
async def send_warmup_requests(
session: aiohttp.ClientSession,
http_url: str,
build_warmup_request_func: Callable[[], Any],
num_warmup: int = 3,
) -> None:
"""
Send warmup requests to HTTP server.
Args:
session: aiohttp client session
http_url: URL to send warmup requests to
build_warmup_request_func: Function that returns a warmup request object
num_warmup: Number of warmup requests to send
"""
print(f"Sending {num_warmup} HTTP warmup requests...")
for i in range(num_warmup):
try:
warmup_data = build_warmup_request_func()
request_json = build_http_request_json(warmup_data)
headers = {"Content-Type": "application/json"}
async with session.post(
http_url, data=request_json, headers=headers
) as resp:
if resp.status == 200:
print(f"Warmup request {i+1}/{num_warmup} completed successfully")
else:
print(
f"Warmup request {i+1}/{num_warmup} failed with status {resp.status}"
)
except Exception as e:
print(f"Warmup request {i+1}/{num_warmup} failed with error: {e}")
print("HTTP warmup requests completed")
async def perform_global_warmup_and_freeze(
config: BenchmarkConfig,
http_url: str,
build_warmup_request_func: Callable[[], Any],
) -> None:
"""
Perform warmup and optionally GC freeze operations once before all benchmark runs.
Args:
config: Benchmark configuration
http_url: URL for API requests
build_warmup_request_func: Function that returns a warmup request object
"""
print("=" * 80)
print(f"PERFORMING GLOBAL WARMUP{' AND GC FREEZE' if config.freeze_gc else ''}")
print("=" * 80)
print(f"Performing HTTP warmup{' and GC freeze' if config.freeze_gc else ''}...")
async with aiohttp.ClientSession() as session:
await send_warmup_requests(session, http_url, build_warmup_request_func)
if config.freeze_gc:
await call_freeze_gc_http(session, http_url)
print(
f"HTTP warmup{' and GC freeze' if config.freeze_gc else ''} completed successfully."
)
print(
f"Global warmup{' and GC freeze' if config.freeze_gc else ''} operations completed."
)
print("=" * 80)
async def process_results(
results_queue: asyncio.Queue,
num_requests: int,
send_duration: float,
total_duration: float,
rps: int,
duration_secs: int,
item_count: int,
test_start_time: float,
config: BenchmarkConfig,
http_mode: str = "UNKNOWN",
) -> List[Dict[str, Any]]:
"""
Process benchmark results and group them by minute intervals.
Args:
results_queue: Queue containing result tuples
num_requests: Total number of requests sent
send_duration: Time taken to send all requests
total_duration: Total time for all requests to complete
rps: Target requests per second
duration_secs: Test duration in seconds
item_count: Number of items per request
test_start_time: Start time of the test
config: Benchmark configuration
http_mode: Description of the HTTP mode/API being tested
Returns:
List of dictionaries containing minute-by-minute results
"""
all_results = []
# Collect all results
for _ in range(num_requests):
result = await results_queue.get()
request_id, elapsed_time, success, completion_time = result
all_results.append(
{
"request_id": request_id,
"elapsed_time": elapsed_time,
"success": success,
"completion_time": completion_time,
}
)
# Group results by minute intervals
minute_results = []
num_minutes = int(duration_secs // 60) + (1 if duration_secs % 60 > 0 else 0)
for minute in range(num_minutes):
minute_start = test_start_time + (minute * 60)
minute_end = test_start_time + ((minute + 1) * 60)
# Filter results that completed in this minute
minute_data = [
r for r in all_results if minute_start <= r["completion_time"] < minute_end
]
response_times = [r["elapsed_time"] for r in minute_data if r["success"]]
successful_requests = len([r for r in minute_data if r["success"]])
failed_requests = len([r for r in minute_data if not r["success"]])
avg_response_time = mean(response_times) if response_times else 0
# Calculate percentiles using numpy
if response_times:
p50 = np.percentile(response_times, 50)
p90 = np.percentile(response_times, 90)
p99 = np.percentile(response_times, 99)
else:
p50 = p90 = p99 = 0
minute_result = {
"test_duration_secs": duration_secs,
"minute_interval": minute + 1,
"target_rps": rps,
"item_count": item_count,
"server_type": config.server_type,
"distribution": config.distribution,
"unique_requests": config.num_unique_requests,
"total_requests": len(minute_data),
"successful_requests": successful_requests,
"failed_requests": failed_requests,
"send_duration_secs": send_duration,
"total_duration_secs": total_duration,
"avg_response_time_ms": avg_response_time,
"p50_response_time_ms": p50,
"p90_response_time_ms": p90,
"p99_response_time_ms": p99,
}
minute_results.append(minute_result)
print(
f"\nMinute {minute + 1} Summary for RPS {rps}, "
f"Duration {duration_secs}s, Item Count {item_count}:"
)
print(f" Requests completed in minute: {len(minute_data)}")
print(f" Successful requests: {successful_requests}")
print(f" Failed requests: {failed_requests}")
print(f" Average response time: {avg_response_time:.2f} ms")
print(f" P50 response time: {p50:.2f} ms")
print(f" P90 response time: {p90:.2f} ms")
print(f" P99 response time: {p99:.2f} ms")
# Print overall summary
all_response_times = [r["elapsed_time"] for r in all_results if r["success"]]
total_successful = len([r for r in all_results if r["success"]])
total_failed = len([r for r in all_results if not r["success"]])
overall_avg = mean(all_response_times) if all_response_times else 0
if all_response_times:
overall_p50 = np.percentile(all_response_times, 50)
overall_p90 = np.percentile(all_response_times, 90)
overall_p99 = np.percentile(all_response_times, 99)
else:
overall_p50 = overall_p90 = overall_p99 = 0
print(
f"\nOverall Summary for RPS {rps}, Duration {duration_secs}s, "
f"Item Count {item_count}:"
)
print(f" Test duration: {duration_secs} seconds")
print(f" Server type: {config.server_type}")
print(f" HTTP mode: {http_mode}")
print(f" Target RPS: {rps}")
print(f" Item count: {item_count}")
print(f" Distribution: {config.distribution}")
print(f" Unique requests generated: {config.num_unique_requests}")
print(f" Total requests sent: {num_requests}")
print(f" Successful requests: {total_successful}")
print(f" Failed requests: {total_failed}")
print(f" Time to send all requests: {send_duration:.2f} seconds")
print(f" Time for all requests to complete: {total_duration:.2f} seconds")
print(f" Average response time: {overall_avg:.2f} ms")
print(f" P50 response time: {overall_p50:.2f} ms")
print(f" P90 response time: {overall_p90:.2f} ms")
print(f" P99 response time: {overall_p99:.2f} ms\n")
return minute_results
def print_csv_results(all_results: List[Dict[str, Any]]) -> None:
"""
Print benchmark results in CSV format.
Args:
all_results: List of result dictionaries from process_results
"""
print("\n" + "=" * 80)
print("FINAL CSV RESULTS:")
print("=" * 80)
# CSV Header
headers = [
"test_duration_secs",
"minute_interval",
"target_rps",
"item_count",
"server_type",
"distribution",
"unique_requests",
"total_requests",
"successful_requests",
"failed_requests",
"send_duration_secs",
"total_duration_secs",
"avg_response_time_ms",
"p50_response_time_ms",
"p90_response_time_ms",
"p99_response_time_ms",
]
print(",".join(headers))
# CSV Data
for result in all_results:
row = [
result["test_duration_secs"],
result["minute_interval"],
result["target_rps"],
result["item_count"],
result["server_type"],
result["distribution"],
result["unique_requests"],
result["total_requests"],
result["successful_requests"],
result["failed_requests"],
f"{result['send_duration_secs']:.2f}",
f"{result['total_duration_secs']:.2f}",
f"{result['avg_response_time_ms']:.2f}",
f"{result['p50_response_time_ms']:.2f}",
f"{result['p90_response_time_ms']:.2f}",
f"{result['p99_response_time_ms']:.2f}",
]
print(",".join(map(str, row)))
async def run_benchmark_main(
config: BenchmarkConfig,
run_single_benchmark_func,
benchmark_name: str,
http_url: str,
item_count_values: List[int],
additional_info: Optional[Dict[str, Any]] = None,
build_warmup_request_func: Optional[Callable[[], Any]] = None,
) -> None:
"""
Main benchmark orchestration function.
Args:
config: Benchmark configuration
run_single_benchmark_func: Async function to run a single benchmark
benchmark_name: Name of the benchmark (e.g., "SCORING", "EMBEDDINGS")
http_url: URL of the API endpoint
item_count_values: List of item counts to test
additional_info: Additional information to print in the header
build_warmup_request_func: Optional function to build warmup requests
"""
total_combinations = (
len(config.duration_secs_values)
* len(config.rps_values)
* len(item_count_values)
)
print(
f"Running benchmarks for {len(config.duration_secs_values)} duration "
f"values, {len(config.rps_values)} RPS values, and "
f"{len(item_count_values)} item count values = "
f"{total_combinations} total combinations"
)
print(f"Server Type: {config.server_type}")
print(f"HTTP Mode: {benchmark_name}")
print(f"API URL: {http_url}")
if additional_info:
for key, value in additional_info.items():
print(f"{key}: {value}")
print(f"Items per request (batch size): {item_count_values}")
print(f"Profiling Enabled: {config.profile}")
print(f"Duration values: {config.duration_secs_values}")
print(f"RPS values: {config.rps_values}")
print(f"Item count values: {item_count_values}")
print("=" * 80)
# Set up profiler environment
setup_profiler(config, benchmark_name)
# Perform global warmup and GC freeze operations if warmup function is provided
if build_warmup_request_func is not None:
await perform_global_warmup_and_freeze(
config, http_url, build_warmup_request_func
)
all_results = []
for duration_secs in config.duration_secs_values:
for rps in config.rps_values:
for item_count in item_count_values:
result = await run_single_benchmark_func(rps, duration_secs, item_count)
all_results.extend(result) # Extend with minute results
print_csv_results(all_results)
async def run_generic_benchmark(
rps: int,
duration_secs: int,
item_count: int,
config: BenchmarkConfig,
http_url: str,
build_request_func: Callable[[int, int], Tuple[int, Any]],
response_validator: Callable[[Dict[str, Any]], bool],
api_name: str,
request_description: str = "requests",
) -> List[Dict[str, Any]]:
"""
Generic benchmark runner that can be used for different APIs.
Args:
rps: Requests per second
duration_secs: Duration of the test in seconds
item_count: Number of items per request (batch size)
config: Benchmark configuration
http_url: URL of the API endpoint
build_request_func: Function to build individual requests
response_validator: Function to validate API responses
api_name: Name of the API for logging
request_description: Description for progress bars
Returns:
List of dictionaries containing minute-by-minute results
"""
num_requests = int(rps * duration_secs)
print(
f"Starting benchmark with RPS={rps}, Duration={duration_secs}s, "
f"Item Count={item_count}, num_requests={num_requests}"
)
print(f"Server Type: {config.server_type}")
print(f"HTTP Mode: {api_name}")
print(f"Profiling Enabled: {config.profile}")
# Build requests in parallel (unmeasured)
all_requests = prepare_all_requests_parallel(
num_requests, item_count, build_request_func, config, request_description
)
results_queue = asyncio.Queue()
tasks = []
# Track timing for sending requests
send_start_time = asyncio.get_event_loop().time()
# HTTP implementation
async with aiohttp.ClientSession(
timeout=aiohttp.ClientTimeout(total=300)
) as session:
# Send START_PROFILE if profiling is enabled
if config.profile:
await send_profile_request("START_PROFILE", http_url, session=session)
# Add progress bar for sending requests
with tqdm(
total=len(all_requests),
desc=f"Sending HTTP {request_description} at {rps} RPS",
unit="req",
) as pbar:
for i, request_data in enumerate(all_requests):
request_id = i + 1
tasks.append(
asyncio.create_task(
make_http_call(
session,
request_data,
request_id,
results_queue,
http_url,
response_validator,
api_name,
)
)
)
# Update progress bar
pbar.update(1)
# Throttle based on distribution
if i < len(all_requests) - 1:
await sleep_with_distribution(config.distribution, rps)
send_end_time = asyncio.get_event_loop().time()
send_duration = send_end_time - send_start_time
# Wait for all requests to complete with progress tracking
print(f"Waiting for {len(tasks)} HTTP {request_description} to complete...")
with tqdm(
total=len(tasks), desc=f"Completing HTTP {request_description}", unit="req"
) as completion_pbar:
completed_tasks = []
for task in asyncio.as_completed(tasks):
await task
completed_tasks.append(task)
completion_pbar.update(1)
# Send STOP_PROFILE if profiling is enabled
if config.profile:
await send_profile_request("STOP_PROFILE", http_url, session=session)
completion_end_time = asyncio.get_event_loop().time()
total_duration = completion_end_time - send_start_time
return await process_results(
results_queue,
num_requests,
send_duration,
total_duration,
rps,
duration_secs,
item_count,
send_start_time,
config,
api_name,
)
## Run benchmark
NOTE: This implementation replays a given trace for throughput/latency benchmarking purposes; it is not an actual ReAct agent implementation.
### Benchmark sglang
```
python3 -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
```
```
python3 bench_sglang.py --num-questions 100
```
### Benchmark vLLM
```
python3 -m vllm.entrypoints.api_server --tokenizer-mode auto --model meta-llama/Llama-2-7b-chat-hf --disable-log-requests --port 21000
```
```
python3 bench_other.py --num-questions 100 --backend vllm
```
### Benchmark guidance
```
python3 bench_other.py --num-questions 100 --backend guidance --parallel 1 --n-ctx 4096 --model-path path/to/gguf
```
### Benchmark lmql
```
python3 bench_other.py --num-questions 100 --backend lmql --parallel 1
```
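Both scripts read the trace from a JSONL file (`hotpotqa_100.jsonl` by default), where each line maps a question to its recorded Thought/Action/Observation triplets. A minimal, illustrative sketch of one record, inferred from the loaders in the scripts below:
```python
# Illustrative sketch of one line in hotpotqa_100.jsonl: a JSON object that
# maps a question to its recorded ReAct steps. Keys are inferred from the
# loaders in the scripts below; the text values here are placeholders.
import json

record = {
    "Which magazine was started first Arthur's Magazine or First for Women?": [
        {
            "thought": "I need to search Arthur's Magazine and First for Women ...",
            "action": "Search[Arthur's Magazine]",
            "observation": "Arthur's Magazine (1844-1846) was an American literary periodical ...",
        },
        # ... one entry per recorded Thought/Action/Observation turn
    ]
}
print(json.dumps(record)[:80])
```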
import argparse
import json
import time
from concurrent.futures import ThreadPoolExecutor
from tqdm import tqdm
from sglang.test.test_utils import add_common_other_args_and_parse, get_call_generate
from sglang.utils import dump_state_text, read_jsonl
def get_prompt(question):
prompt = (
"""Solve a question answering task with interleaving Thought, Action, Observation steps. Thought can reason about the current situation, and Action can be three types:
(1) Search[entity], which searches the exact entity on Wikipedia and returns the first paragraph if it exists. If not, it will return some similar entities to search.
(2) Lookup[keyword], which returns the next sentence containing keyword in the current passage.
(3) Finish[answer], which returns the answer and finishes the task.
Here are some examples.
Question: What is the elevation range for the area that the eastern sector of the Colorado orogeny extends into?
Thought 1: I need to search Colorado orogeny, find the area that the eastern sector of the Colorado orogeny extends into, then find the elevation range of the area.
Action 1: Search[Colorado orogeny]
Observation 1: The Colorado orogeny was an episode of mountain building (an orogeny) in Colorado and surrounding areas.
Thought 2: It does not mention the eastern sector. So I need to look up eastern sector.
Action 2: Lookup[eastern sector]
Observation 2: (Result 1 / 1) The eastern sector extends into the High Plains and is called the Central Plains orogeny.
Thought 3: The eastern sector of Colorado orogeny extends into the High Plains. So I need to search High Plains and find its elevation range.
Action 3: Search[High Plains]
Observation 3: High Plains refers to one of two distinct land regions:
Thought 4: I need to instead search High Plains (United States).
Action 4: Search[High Plains (United States)]
Observation 4: The High Plains are a subregion of the Great Plains. From east to west, the High Plains rise in elevation from around 1,800 to 7,000 ft (550 to 2,130 m).[3]
Thought 5: High Plains rise in elevation from around 1,800 to 7,000 ft, so the answer is 1,800 to 7,000 ft.
Action 5: Finish[1,800 to 7,000 ft]
Question: Musician and satirist Allie Goertz wrote a song about the "The Simpsons" character Milhouse, who Matt Groening named after who?
Thought 1: The question simplifies to "The Simpsons" character Milhouse is named after who. I only need to search Milhouse and find who it is named after.
Action 1: Search[Milhouse]
Observation 1: Milhouse Mussolini Van Houten is a recurring character in the Fox animated television series The Simpsons voiced by Pamela Hayden and created by Matt Groening.
Thought 2: The paragraph does not tell who Milhouse is named after, maybe I can look up "named after".
Action 2: Lookup[named after]
Observation 2: (Result 1 / 1) Milhouse was named after U.S. president Richard Nixon, whose middle name was Milhous.
Thought 3: Milhouse was named after U.S. president Richard Nixon, so the answer is Richard Nixon.
Action 3: Finish[Richard Nixon]
Question: Which documentary is about Finnish rock groups, Adam Clayton Powell or The Saimaa Gesture?
Thought 1: I need to search Adam Clayton Powell and The Saimaa Gesture, and find which documentary is about Finnish rock groups.
Action 1: Search[Adam Clayton Powell]
Observation 1: Could not find [Adam Clayton Powell]. Similar: ['Adam Clayton Powell III', 'Seventh Avenue (Manhattan)', 'Adam Clayton Powell Jr. State Office Building', 'Isabel Washington Powell', 'Adam Powell', 'Adam Clayton Powell (film)', 'Giancarlo Esposito'].
Thought 2: To find the documentary, I can search Adam Clayton Powell (film).
Action 2: Search[Adam Clayton Powell (film)]
Observation 2: Adam Clayton Powell is a 1989 American documentary film directed by Richard Kilberg.
The film is about the rise and fall of influential African-American politician Adam Clayton Powell Jr.[3][4] It was later aired as part of the PBS series The American Experience.
Thought 3: Adam Clayton Powell (film) is a documentary about an African-American politician, not Finnish rock groups. So the documentary about Finnish rock groups must instead be The Saimaa Gesture.
Action 3: Finish[The Saimaa Gesture]
Question: What profession does Nicholas Ray and Elia Kazan have in common?
Thought 1: I need to search Nicholas Ray and Elia Kazan, find their professions, then find the profession they have in common.
Action 1: Search[Nicholas Ray]
Observation 1: Nicholas Ray (born Raymond Nicholas Kienzle Jr., August 7, 1911 – June 16, 1979) was an American film director, screenwriter, and actor best known for the 1955 film Rebel Without a Cause.
Thought 2: Professions of Nicholas Ray are director, screenwriter, and actor. I need to search Elia Kazan next and find his professions.
Action 2: Search[Elia Kazan]
Observation 2: Elia Kazan was an American film and theatre director, producer, screenwriter and actor.
Thought 3: Professions of Elia Kazan are director, producer, screenwriter, and actor. So profession Nicholas Ray and Elia Kazan have in common is director, screenwriter, and actor.
Action 3: Finish[director, screenwriter, actor]
Question: Which magazine was started first Arthur's Magazine or First for Women?
Thought 1: I need to search Arthur's Magazine and First for Women, and find which was started first.
Action 1: Search[Arthur's Magazine]
Observation 1: Arthur's Magazine (1844-1846) was an American literary periodical published in Philadelphia in the 19th century.
Thought 2: Arthur's Magazine was started in 1844. I need to search First for Women next.
Action 2: Search[First for Women]
Observation 2: First for Women is a woman's magazine published by Bauer Media Group in the USA.[1] The magazine was started in 1989.
Thought 3: First for Women was started in 1989. 1844 (Arthur's Magazine) < 1989 (First for Women), so Arthur's Magazine was started first.
Action 3: Finish[Arthur's Magazine]
Question: Were Pavel Urysohn and Leonid Levin known for the same type of work?
Thought 1: I need to search Pavel Urysohn and Leonid Levin, find their types of work, then find if they are the same.
Action 1: Search[Pavel Urysohn]
Observation 1: Pavel Samuilovich Urysohn (February 3, 1898 – August 17, 1924) was a Soviet mathematician who is best known for his contributions in dimension theory.
Thought 2: Pavel Urysohn is a mathematician. I need to search Leonid Levin next and find its type of work.
Action 2: Search[Leonid Levin]
Observation 2: Leonid Anatolievich Levin is a Soviet-American mathematician and computer scientist.
Thought 3: Leonid Levin is a mathematician and computer scientist. So Pavel Urysohn and Leonid Levin have the same type of work.
Action 3: Finish[yes]
"""
+ question
)
return prompt
def main(args):
lines = read_jsonl(args.data_path)[: args.num_questions]
arguments = [{"question": k, "triplets": v} for l in lines for k, v in l.items()]
states = []
# Select backend
call_generate = get_call_generate(args)
def run_single_agent(argument):
question = argument["question"]
triplets = argument["triplets"]
prompt = get_prompt(question)
for i in range(1, len(triplets) + 2):
prompt += "Thought " + str(i) + ":"
states.append(prompt)
answer = call_generate(
prompt, max_tokens=200, temperature=0, stop="Observation"
)
if i > len(triplets):
break
prompt += (
triplets[i - 1]["thought"]
+ "\nAction "
+ str(i)
+ ":"
+ triplets[i - 1]["action"]
+ "\nObservation "
+ str(i)
+ ":"
+ triplets[i - 1]["observation"]
+ "\n"
)
states.append(answer)
async def run_single_agent_async(argument):
question = argument["question"]
triplets = argument["triplets"]
prompt = get_prompt(question)
for i in range(1, len(triplets) + 2):
prompt += "Thought " + str(i) + ":"
states.append(prompt)
answer = await call_generate(
prompt, max_tokens=200, temperature=0, stop="Observation", max_len=4096
)
if i > len(triplets):
break
prompt += (
triplets[i - 1]["thought"]
+ "\nAction "
+ str(i)
+ ":"
+ triplets[i - 1]["action"]
+ "\nObservation "
+ str(i)
+ ":"
+ triplets[i - 1]["observation"]
+ "\n"
)
states.append(answer)
tic = time.perf_counter()
if args.backend != "lmql":
if args.parallel == 1:
for arg in tqdm(arguments):
run_single_agent(arg)
else:
with ThreadPoolExecutor(args.parallel) as executor:
list(
tqdm(
executor.map(run_single_agent, arguments), total=len(arguments)
)
)
else:
import asyncio
loop = asyncio.get_event_loop()
batches = [
[] for _ in range((len(arguments) + args.parallel - 1) // args.parallel)
]
for i, arg in enumerate(arguments):
batches[i // args.parallel].append(arg)
for bt in tqdm(batches):
tasks = [run_single_agent_async(arg) for arg in bt]
loop.run_until_complete(asyncio.gather(*tasks))
latency = time.perf_counter() - tic
print(f"Latency: {latency:.3f}")
# Write results
dump_state_text(f"tmp_output_{args.backend}.txt", states)
with open(args.result_file, "a") as fout:
value = {
"task": "ReAct Agents",
"backend": args.backend,
"num_gpus": 1,
"latency": round(latency, 3),
"num_requests": len(arguments),
"other": {
"parallel": args.parallel,
},
}
fout.write(json.dumps(value) + "\n")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--data-path", type=str, default="hotpotqa_100.jsonl")
parser.add_argument("--num-questions", type=int, default=10)
args = add_common_other_args_and_parse(parser)
main(args)
import argparse
import json
import time
import sglang as sgl
from sglang.test.test_utils import (
add_common_sglang_args_and_parse,
select_sglang_backend,
)
from sglang.utils import dump_state_text, read_jsonl
@sgl.function
def webthink(s, question, triplets):
s += (
"""Solve a question answering task with interleaving Thought, Action, Observation steps. Thought can reason about the current situation, and Action can be three types:
(1) Search[entity], which searches the exact entity on Wikipedia and returns the first paragraph if it exists. If not, it will return some similar entities to search.
(2) Lookup[keyword], which returns the next sentence containing keyword in the current passage.
(3) Finish[answer], which returns the answer and finishes the task.
Here are some examples.
Question: What is the elevation range for the area that the eastern sector of the Colorado orogeny extends into?
Thought 1: I need to search Colorado orogeny, find the area that the eastern sector of the Colorado orogeny extends into, then find the elevation range of the area.
Action 1: Search[Colorado orogeny]
Observation 1: The Colorado orogeny was an episode of mountain building (an orogeny) in Colorado and surrounding areas.
Thought 2: It does not mention the eastern sector. So I need to look up eastern sector.
Action 2: Lookup[eastern sector]
Observation 2: (Result 1 / 1) The eastern sector extends into the High Plains and is called the Central Plains orogeny.
Thought 3: The eastern sector of Colorado orogeny extends into the High Plains. So I need to search High Plains and find its elevation range.
Action 3: Search[High Plains]
Observation 3: High Plains refers to one of two distinct land regions:
Thought 4: I need to instead search High Plains (United States).
Action 4: Search[High Plains (United States)]
Observation 4: The High Plains are a subregion of the Great Plains. From east to west, the High Plains rise in elevation from around 1,800 to 7,000 ft (550 to 2,130 m).[3]
Thought 5: High Plains rise in elevation from around 1,800 to 7,000 ft, so the answer is 1,800 to 7,000 ft.
Action 5: Finish[1,800 to 7,000 ft]
Question: Musician and satirist Allie Goertz wrote a song about the "The Simpsons" character Milhouse, who Matt Groening named after who?
Thought 1: The question simplifies to "The Simpsons" character Milhouse is named after who. I only need to search Milhouse and find who it is named after.
Action 1: Search[Milhouse]
Observation 1: Milhouse Mussolini Van Houten is a recurring character in the Fox animated television series The Simpsons voiced by Pamela Hayden and created by Matt Groening.
Thought 2: The paragraph does not tell who Milhouse is named after, maybe I can look up "named after".
Action 2: Lookup[named after]
Observation 2: (Result 1 / 1) Milhouse was named after U.S. president Richard Nixon, whose middle name was Milhous.
Thought 3: Milhouse was named after U.S. president Richard Nixon, so the answer is Richard Nixon.
Action 3: Finish[Richard Nixon]
Question: Which documentary is about Finnish rock groups, Adam Clayton Powell or The Saimaa Gesture?
Thought 1: I need to search Adam Clayton Powell and The Saimaa Gesture, and find which documentary is about Finnish rock groups.
Action 1: Search[Adam Clayton Powell]
Observation 1: Could not find [Adam Clayton Powell]. Similar: ['Adam Clayton Powell III', 'Seventh Avenue (Manhattan)', 'Adam Clayton Powell Jr. State Office Building', 'Isabel Washington Powell', 'Adam Powell', 'Adam Clayton Powell (film)', 'Giancarlo Esposito'].
Thought 2: To find the documentary, I can search Adam Clayton Powell (film).
Action 2: Search[Adam Clayton Powell (film)]
Observation 2: Adam Clayton Powell is a 1989 American documentary film directed by Richard Kilberg.
The film is about the rise and fall of influential African-American politician Adam Clayton Powell Jr.[3][4] It was later aired as part of the PBS series The American Experience.
Thought 3: Adam Clayton Powell (film) is a documentary about an African-American politician, not Finnish rock groups. So the documentary about Finnish rock groups must instead be The Saimaa Gesture.
Action 3: Finish[The Saimaa Gesture]
Question: What profession does Nicholas Ray and Elia Kazan have in common?
Thought 1: I need to search Nicholas Ray and Elia Kazan, find their professions, then find the profession they have in common.
Action 1: Search[Nicholas Ray]
Observation 1: Nicholas Ray (born Raymond Nicholas Kienzle Jr., August 7, 1911 – June 16, 1979) was an American film director, screenwriter, and actor best known for the 1955 film Rebel Without a Cause.
Thought 2: Professions of Nicholas Ray are director, screenwriter, and actor. I need to search Elia Kazan next and find his professions.
Action 2: Search[Elia Kazan]
Observation 2: Elia Kazan was an American film and theatre director, producer, screenwriter and actor.
Thought 3: Professions of Elia Kazan are director, producer, screenwriter, and actor. So profession Nicholas Ray and Elia Kazan have in common is director, screenwriter, and actor.
Action 3: Finish[director, screenwriter, actor]
Question: Which magazine was started first Arthur's Magazine or First for Women?
Thought 1: I need to search Arthur's Magazine and First for Women, and find which was started first.
Action 1: Search[Arthur's Magazine]
Observation 1: Arthur's Magazine (1844-1846) was an American literary periodical published in Philadelphia in the 19th century.
Thought 2: Arthur's Magazine was started in 1844. I need to search First for Women next.
Action 2: Search[First for Women]
Observation 2: First for Women is a woman's magazine published by Bauer Media Group in the USA.[1] The magazine was started in 1989.
Thought 3: First for Women was started in 1989. 1844 (Arthur's Magazine) < 1989 (First for Women), so Arthur's Magazine was started first.
Action 3: Finish[Arthur's Magazine]
Question: Were Pavel Urysohn and Leonid Levin known for the same type of work?
Thought 1: I need to search Pavel Urysohn and Leonid Levin, find their types of work, then find if they are the same.
Action 1: Search[Pavel Urysohn]
Observation 1: Pavel Samuilovich Urysohn (February 3, 1898 – August 17, 1924) was a Soviet mathematician who is best known for his contributions in dimension theory.
Thought 2: Pavel Urysohn is a mathematician. I need to search Leonid Levin next and find its type of work.
Action 2: Search[Leonid Levin]
Observation 2: Leonid Anatolievich Levin is a Soviet-American mathematician and computer scientist.
Thought 3: Leonid Levin is a mathematician and computer scientist. So Pavel Urysohn and Leonid Levin have the same type of work.
Action 3: Finish[yes]
"""
+ question
)
for i in range(1, len(triplets) + 2):
s += "Thought " + str(i) + ":"
# NOTE: This is an implementation for replaying a given trace for benchmark purposes. It is not an actual ReAct agent implementation.
ss = s.fork(1)
ss[0] += sgl.gen(name="thought_action", max_tokens=200, stop="Observation")
ss.join()
# To verify the correctness of the output, uncomment the line below to collect it.
# print(ss[0]["thought_action"])
if i > len(triplets):
break
s += (
triplets[i - 1]["thought"]
+ "\nAction "
+ str(i)
+ ":"
+ triplets[i - 1]["action"]
+ "\nObservation "
+ str(i)
+ ":"
+ triplets[i - 1]["observation"]
+ "\n"
)
def main(args):
lines = read_jsonl(args.data_path)[: args.num_questions]
arguments = [{"question": k, "triplets": v} for l in lines for k, v in l.items()]
# Select backend
backend = select_sglang_backend(args)
sgl.set_default_backend(backend)
states = []
tic = time.perf_counter()
states = webthink.run_batch(
arguments,
temperature=0,
num_threads=args.parallel,
progress_bar=True,
)
latency = time.perf_counter() - tic
# Compute accuracy
print(f"Latency: {latency:.3f}")
# Write results
dump_state_text(f"tmp_output_{args.backend}.txt", states)
with open(args.result_file, "a") as fout:
value = {
"task": "ReAct Agents",
"backend": args.backend,
"num_gpus": 1,
"latency": round(latency, 3),
"num_requests": len(arguments),
"other": {
"num_questions": args.num_questions,
"parallel": args.parallel,
},
}
fout.write(json.dumps(value) + "\n")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--data-path", type=str, default="hotpotqa_100.jsonl")
parser.add_argument("--num-questions", type=int, default=10)
args = add_common_sglang_args_and_parse(parser)
main(args)
# Run benchmark
This benchmark is primarily intended for reasoning models such as `DeepSeek-R1` and its distilled variants, e.g. `DeepSeek-R1-Distill-Qwen-1.5B`. Please run
```bash
pip install antlr4-python3-runtime
```
to install the ANTLR runtime required by `parse_latex`, which we use for the symbolic equality check.
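As a quick sanity check (a minimal sketch, not part of the benchmark scripts), you can verify that the ANTLR runtime is picked up by sympy's LaTeX parser; `eval_utils.py` falls back to `parse_expr` whenever this parser fails:
```python
# Sanity check: sympy's LaTeX parser needs antlr4-python3-runtime installed.
from sympy.parsing.latex import parse_latex

expr = parse_latex(r"\frac{1}{2}")
print(expr)  # prints 1/2 when the ANTLR runtime is set up correctly
```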
## Benchmark sglang
1. Launch the Server
```bash
python3 -m sglang.launch_server --model-path deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B --port 30000
```
Note that, depending on the GPU, this benchmark can take quite some time. To employ data parallelism, please use:
```bash
python3 -m sglang_router.launch_server --model-path deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B --port 30000 --dp-size 4
```
2. Benchmarking
We use the [suggested](https://github.com/deepseek-ai/DeepSeek-R1) parameters `temperature=0.6`, `top_p=0.95`, and `max_new_tokens=32768`. The command-line argument `--num-tries` can be used to evaluate the model multiple times on the same question. We use the suggested `64` tries from the repo for AIME 2024; for LIMO, we use `8` tries due to the size of the dataset.
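For reference, the benchmark script applies these sampling parameters when launching the batch, as in this excerpt from `bench_sglang.py` (included later in this commit):
```python
# Excerpt from bench_sglang.py: suggested DeepSeek-R1 sampling parameters.
states = reasoning_gen.run_batch(
    questions,
    num_threads=args.parallel,
    progress_bar=True,
    temperature=0.6,
    max_new_tokens=32768,
    top_p=0.95,
)
```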
By default, we evaluate on the LIMO dataset:
```bash
python3 bench_sglang.py --parallel 256 --num-tries 64 --port 30000
```
Evaluate on the AIME 2024 dataset:
```bash
python3 bench_sglang.py --parallel 256 --port 30000 --data-path Maxwell-Jia/AIME_2024 --question-key Problem --answer-key Answer --num-tries 64
```
Evaluate on the [AIME 2025 I dataset](https://huggingface.co/datasets/opencompass/AIME2025). For reference benchmark results, see [matharena.ai](https://matharena.ai/):
```bash
python3 bench_sglang.py --parallel 256 --port 30000 --data-path opencompass/AIME2025 --question-key question --answer-key answer --num-tries 64
```
## Results
### Evaluation Results
| Dataset | Num Tries | Accuracy | Reference | Standard Error |
|------------|-----------|----------|-----------|-----------|
| LIMO | 8 | 47.7% | ? | ? |
| AIME 2024 | 64 | 33.2% | 28.9% | 3.4% |
| AIME 2025 I| 64 | 29.9% | 25.0% | ? |
### Statistical Analysis Results
Set up the SGLang engine for the statistical analysis; for higher efficiency, we use `--dp-size 8` for data parallelism:
```bash
python3 -m sglang_router.launch_server --model-path deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B --port 30000 --dp-size 8
```
**Experiment 1**:
We fixed the number of attempts (`num_tries`) and conducted multiple runs to assess the consistency of the model's performance. The results show that all recorded accuracies lie within one standard error of the mean. This suggests that **our metric serves as an effective upper bound on the deviation of the reported accuracy**; the sketch after the figure below shows how this standard error is computed.
To collect the accuracies, run the following command 30 times:
```bash
python3 bench_sglang.py --parallel 64 --port 30000 --data-path Maxwell-Jia/AIME_2024 --question-key Problem --answer-key Answer --num-tries 64
```
![acc_hist](figure/Acc_histplot.png)
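For clarity, here is a minimal, self-contained sketch of how the reported accuracy and per-question standard error are computed; it mirrors the logic in `bench_sglang.py`, with `outcomes` standing in for the real 0/1 correctness results:
```python
import numpy as np

num_questions, num_tries = 30, 64
# Placeholder 0/1 correctness matrix: one row per question, one column per try.
outcomes = np.random.randint(0, 2, size=(num_questions, num_tries))

overall_accuracy = outcomes.mean()
# Sample standard deviation per question (ddof=1), then SE = std / sqrt(num_tries).
std_per_question = outcomes.std(axis=1, ddof=1)
se_per_question = std_per_question / np.sqrt(num_tries)
mean_se = se_per_question.mean()
print(f"accuracy={overall_accuracy:.3f}, mean SE={mean_se:.3f}")
```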
**Experiment 2**: We explored the relationship between the number of attempts (`num_tries`) and the standard error (SE) by varying `num_tries` over a range (e.g., 8, 16, 32, ..., 256) and performing a single run for each value. The results demonstrate that as the number of attempts increases, the standard error decreases, leading to **greater stability in answer accuracy**.
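This matches the usual standard-error scaling: for each question, with per-question sample standard deviation `s_q`,
```latex
\mathrm{SE}_q = \frac{s_q}{\sqrt{\text{num\_tries}}}
```
so, assuming roughly independent tries, quadrupling `num_tries` roughly halves the standard error.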
To reveal the relationship, run the command 6 times and adjust the parameter `--num-tries` for each run:
```bash
python3 bench_sglang.py --parallel 64 --port 30000 --data-path Maxwell-Jia/AIME_2024 --question-key Problem --answer-key Answer --num-tries <num_tries>
```
![SE_num_tries](figure/SE_numtries.png)
# Adapted from https://github.com/deepseek-ai/DeepSeek-Math/blob/main/evaluation/data_processing/answer_extraction.py
import re
import regex
def _fix_fracs(string):
substrs = string.split("\\frac")
new_str = substrs[0]
if len(substrs) > 1:
substrs = substrs[1:]
for substr in substrs:
new_str += "\\frac"
if len(substr) > 0 and substr[0] == "{":
new_str += substr
else:
try:
assert len(substr) >= 2
except:
return string
a = substr[0]
b = substr[1]
if b != "{":
if len(substr) > 2:
post_substr = substr[2:]
new_str += "{" + a + "}{" + b + "}" + post_substr
else:
new_str += "{" + a + "}{" + b + "}"
else:
if len(substr) > 2:
post_substr = substr[2:]
new_str += "{" + a + "}" + b + post_substr
else:
new_str += "{" + a + "}" + b
string = new_str
return string
def _fix_a_slash_b(string):
if len(string.split("/")) != 2:
return string
a = string.split("/")[0]
b = string.split("/")[1]
try:
if "sqrt" not in a:
a = int(a)
if "sqrt" not in b:
b = int(b)
assert string == "{}/{}".format(a, b)
new_string = "\\frac{" + str(a) + "}{" + str(b) + "}"
return new_string
except:
return string
def _fix_sqrt(string):
_string = re.sub(r"\\sqrt(-?[0-9.a-zA-Z]+)", r"\\sqrt{\1}", string)
_string = re.sub(r"\\sqrt\s+(\w+)$", r"\\sqrt{\1}", _string)
return _string
def _fix_tan(string):
_string = re.sub(r"\\tan(-?[0-9.a-zA-Z]+)", r"\\tan{\1}", string)
_string = re.sub(r"\\tan\s+(\w+)$", r"\\tan{\1}", _string)
return _string
def strip_string(string):
string = str(string).strip()
# linebreaks
string = string.replace("\n", "")
# right "."
string = string.rstrip(".")
# remove inverse spaces
string = string.replace("\\!", "")
# string = string.replace("\\ ", "")
# replace \\ with \
# string = string.replace("\\\\", "\\")
# string = string.replace("\\\\", "\\")
if string.startswith("\\text{") and string.endswith("}"):
string = string.split("{", 1)[1][:-1]
# replace tfrac and dfrac with frac
string = string.replace("tfrac", "frac")
string = string.replace("dfrac", "frac")
string = string.replace("cfrac", "frac")
# remove \left and \right
string = string.replace("\\left", "")
string = string.replace("\\right", "")
# Remove unit: miles, dollars if after is not none
_string = re.sub(r"\\text{.*?}$", "", string).strip()
if _string != "" and _string != string:
# print("Warning: unit not removed: '{}' -> '{}'".format(string, _string))
string = _string
# Remove circ (degrees)
string = string.replace("^{\\circ}", "").strip()
string = string.replace("^\\circ", "").strip()
string = regex.sub(r"\{(c|m)?m\}(\^(2|3))?", "", string).strip()
string = regex.sub(r"p\.m\.$", "", string).strip()
string = regex.sub(r"(\d)\s*t$", r"\1", string).strip()
# remove dollar signs
string = string.replace("\\$", "")
string = string.replace("$", "")
# string = string.replace("\\text", "")
string = string.replace("x\\in", "")
# remove percentage
string = string.replace("\\%", "%")
string = string.replace("\%", "%")
# string = string.replace("%", "")
# " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string
string = string.replace(" .", " 0.")
string = string.replace("{.", "{0.")
# cdot
string = string.replace("\\cdot", "")
# inf
string = string.replace("infinity", "\\infty")
if "\\infty" not in string:
string = string.replace("inf", "\\infty")
string = string.replace("+\\inity", "\\infty")
# and
# string = string.replace("and", "")
string = string.replace("\\mathbf", "")
string = string.replace("\\mathrm", "")
# use regex to remove \mbox{...}
string = re.sub(r"\\mbox{.*?}", "", string)
# quote
string.replace("'", "")
string.replace('"', "")
# i, j
if "j" in string and "i" not in string:
string = string.replace("j", "i")
# replace a.000b where b is not number or b is end, with ab, use regex
string = re.sub(r"(\d+)\.0+([^\d])", r"\1\2", string)
string = re.sub(r"(\d+)\.0+$", r"\1", string)
# if empty, return empty string
if len(string) == 0:
return string
if string[0] == ".":
string = "0" + string
# to consider: get rid of e.g. "k = " or "q = " at beginning
# if len(string.split("=")) == 2:
# if len(string.split("=")[0]) <= 2:
# string = string.split("=")[1]
string = _fix_sqrt(string)
string = _fix_tan(string)
string = string.replace(" ", "")
# \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b}
string = _fix_fracs(string)
# NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y
string = _fix_a_slash_b(string)
string = regex.sub(r"(\\|,|\.)+$", "", string)
return string
def extract_boxed_answers(text):
answers = []
for piece in text.split("boxed{")[1:]:
n = 0
for i in range(len(piece)):
if piece[i] == "{":
n += 1
elif piece[i] == "}":
n -= 1
if n < 0:
if i + 1 < len(piece) and piece[i + 1] == "%":
answers.append(piece[: i + 1])
else:
answers.append(piece[:i])
break
return answers
def extract_program_output(pred_str):
"""
extract output between the last ```output\n...\n```
"""
if "```output" not in pred_str:
return ""
if "```output" in pred_str:
pred_str = pred_str.split("```output")[-1]
if "```" in pred_str:
pred_str = pred_str.split("```")[0]
output = pred_str.strip()
return output
def extract_answer(pred_str, exhaust=False):
pred = []
if "final answer is $" in pred_str and "$. I hope" in pred_str:
tmp = pred_str.split("final answer is $", 1)[1]
pred = [tmp.split("$. I hope", 1)[0].strip()]
elif "boxed" in pred_str:
pred = extract_boxed_answers(pred_str)
elif "he answer is" in pred_str:
pred = [pred_str.split("he answer is")[-1].strip()]
else:
program_output = extract_program_output(pred_str)
if program_output != "":
# fall back to program
pred.append(program_output)
else: # use the last number
pattern = "-?\d*\.?\d+"
ans = re.findall(pattern, pred_str.replace(",", ""))
if len(ans) >= 1:
ans = ans[-1]
else:
ans = ""
if ans:
pred.append(ans)
# multiple line
_pred = []
for ans in pred:
ans = ans.strip().split("\n")[0]
ans = ans.lstrip(":")
ans = ans.rstrip(".")
ans = ans.rstrip("/")
ans = strip_string(ans)
_pred.append(ans)
if exhaust:
return _pred
else:
return _pred[-1] if _pred else ""
def extract_math_answer(question, reasoning, task):
answer = []
for ans in extract_answer(reasoning, exhaust=True):
if "separated by commas" in question and all(ch not in ans for ch in "()[]"):
answer.extend([a.strip() for a in ans.split(",")])
elif regex.search(r"\\text\{\s*and\s*\}", ans):
answer.extend(
[
a.strip()
for a in regex.sub(r"\\text\{\s*and\s*\}", "[SEP]", ans).split(
"[SEP]"
)
]
)
else:
answer.append(ans.strip())
return answer
import argparse
import json
import time
import answer_extraction
import eval_utils
import numpy as np
from datasets import load_dataset
import sglang as sgl
from sglang.test.test_utils import (
add_common_sglang_args_and_parse,
select_sglang_backend,
)
from sglang.utils import dump_state_text
@sgl.function
def reasoning_gen(s, question: str):
s += sgl.user(
question
+ "\nPlease reason step by step, and put your final answer within \boxed{}."
)
s += sgl.assistant(
sgl.gen(
"answer",
)
)
def convert_dataset(path: str, question_key: str, answer_key: str, num_tries: int):
raw_dataset = load_dataset(path)
questions = []
answers = []
for data in raw_dataset["train"]:
question = data[question_key]
answer = data[answer_key]
for _ in range(num_tries):
questions.append({"question": question})
answers.append({"answer": answer})
return questions, answers
def main(args):
# Select backend
sgl.set_default_backend(select_sglang_backend(args))
# Get dataset
questions, answers = convert_dataset(
args.data_path, args.question_key, args.answer_key, args.num_tries
)
# Run requests
tic = time.perf_counter()
states = reasoning_gen.run_batch(
questions,
num_threads=args.parallel,
progress_bar=True,
temperature=0.6,
max_new_tokens=32768,
top_p=0.95,
)
latency = time.perf_counter() - tic
# Extract results and record outcomes in a list.
outcomes = []
for i, state in enumerate(states):
try:
pred_answer = answer_extraction.extract_math_answer(
questions[i]["question"], state["answer"], "limo"
)
gt_answer = str(answers[i]["answer"])
pred_answer = (
pred_answer[-1] if isinstance(pred_answer, list) else pred_answer
)
is_correct = 1 if eval_utils.math_equal(pred_answer, gt_answer) else 0
except Exception as e:
print(f"Error extracting answer: {e}")
is_correct = 0
outcomes.append(is_correct)
# Calculate overall accuracy using numpy
overall_accuracy = np.mean(outcomes)
print(f"Overall Accuracy: {overall_accuracy}")
# Calculate mean standard error over questions if num_tries >= 2
if args.num_tries > 1:
outcomes_np = np.array(outcomes).reshape(-1, args.num_tries)
# Using sample standard deviation with ddof=1
std_per_question = np.std(outcomes_np, axis=1, ddof=1)
# Compute the standard error for each question: std / sqrt(num_tries)
se_per_question = std_per_question / np.sqrt(args.num_tries)
mean_se = se_per_question.mean()
print(f"Mean Standard Error of Accuracy across questions: {mean_se}")
else:
mean_se = None
print("Not enough samples per question to compute standard error.")
# Calculate output throughput
num_output_tokens = sum(
s.get_meta_info("answer")["completion_tokens"] for s in states
)
output_throughput = num_output_tokens / latency
print(f"Output throughput: {output_throughput} token/s")
# Dump results
dump_state_text(f"tmp_output_{args.backend}.txt", states)
# Write results
with open(args.result_file, "a") as fout:
value = {
"task": "limo",
"backend": args.backend,
"latency": round(latency, 3),
"overall_accuracy": round(overall_accuracy, 3),
"mean_se_accuracy": round(mean_se, 3) if mean_se is not None else None,
"num_requests": len(questions),
"other": {
"num_questions": len(questions),
"parallel": args.parallel,
},
}
fout.write(json.dumps(value) + "\n")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--data-path", type=str, default="GAIR/LIMO")
parser.add_argument("--question-key", type=str, default="question")
parser.add_argument("--answer-key", type=str, default="answer")
parser.add_argument("--num-tries", type=int, default=1)
add_common_sglang_args_and_parse(parser)
args = parser.parse_args()
main(args)
# Adapted from https://github.com/deepseek-ai/DeepSeek-Math/blob/main/evaluation/eval/eval_utils.py
from math import isclose
import regex
from sympy import N, simplify
from sympy.parsing.latex import parse_latex
from sympy.parsing.sympy_parser import parse_expr
def parse_digits(num):
# format: 234.23 || 23%
num = regex.sub(",", "", str(num))
try:
return float(num)
except:
if num.endswith("%"):
num = num[:-1]
if num.endswith("\\"):
num = num[:-1]
try:
return float(num) / 100
except:
pass
return None
def is_digit(num):
# paired with parse_digits
return parse_digits(num) is not None
def symbolic_equal(a, b):
def _parse(s):
for f in [parse_latex, parse_expr]:
try:
return f(s)
except:
pass
return s
a = _parse(a)
b = _parse(b)
try:
if simplify(a - b) == 0:
return True
except:
pass
try:
if isclose(N(a), N(b), abs_tol=1e-3):
return True
except:
pass
return False
def math_equal(prediction, reference, include_percentage=True, is_close=True):
"""
Exact match of math if and only if:
1. numerical equal: both can convert to float and are equal
2. symbolic equal: both can convert to sympy expression and are equal
"""
if str(prediction) == str(reference):
return True
try: # 1. numerical equal
if is_digit(prediction) and is_digit(reference):
prediction = parse_digits(prediction)
reference = parse_digits(reference)
# number questions
if include_percentage:
gt_result = [reference / 100, reference, reference * 100]
else:
gt_result = [reference]
for item in gt_result:
try:
if is_close:
if isclose(item, prediction, abs_tol=1e-3):
return True
else:
if item == prediction:
return True
except Exception:
continue
return False
except:
pass
if not prediction and prediction not in [0, False]:
return False
# 2. symbolic equal
reference = str(reference).strip()
prediction = str(prediction).strip()
if (
regex.match(r"(\(|\[).+(\)|\])", prediction) is not None
and regex.match(r"(\(|\[).+(\)|\])", reference) is not None
):
pred_parts = prediction[1:-1].split(",")
ref_parts = reference[1:-1].split(",")
if len(pred_parts) == len(ref_parts):
if all(
[
math_equal(
pred_parts[i], ref_parts[i], include_percentage, is_close
)
for i in range(len(pred_parts))
]
):
return True
# Add back matrix comparison
if (
(
prediction.startswith("\\begin{pmatrix}")
or prediction.startswith("\\begin{bmatrix}")
)
and (
prediction.endswith("\\end{pmatrix}")
or prediction.endswith("\\end{bmatrix}")
)
and (
reference.startswith("\\begin{pmatrix}")
or reference.startswith("\\begin{bmatrix}")
)
and (
reference.endswith("\\end{pmatrix}") or reference.endswith("\\end{bmatrix}")
)
):
pred_lines = [
line.strip()
for line in prediction[
len("\\begin{pmatrix}") : -len("\\end{pmatrix}")
].split("\\\\")
if line.strip()
]
ref_lines = [
line.strip()
for line in reference[
len("\\begin{pmatrix}") : -len("\\end{pmatrix}")
].split("\\\\")
if line.strip()
]
matched = True
if len(pred_lines) == len(ref_lines):
for pred_line, ref_line in zip(pred_lines, ref_lines):
pred_parts = pred_line.split("&")
ref_parts = ref_line.split("&")
if len(pred_parts) == len(ref_parts):
if not all(
[
math_equal(
pred_parts[i],
ref_parts[i],
include_percentage,
is_close,
)
for i in range(len(pred_parts))
]
):
matched = False
break
else:
matched = False
if not matched:
break
else:
matched = False
if matched:
return True
# Add back equation comparison
if prediction.count("=") == 1 and reference.count("=") == 1:
pred = prediction.split("=")
pred = f"{pred[0].strip()} - ({pred[1].strip()})"
ref = reference.split("=")
ref = f"{ref[0].strip()} - ({ref[1].strip()})"
if symbolic_equal(pred, ref) or symbolic_equal(f"-({pred})", ref):
return True
elif (
prediction.count("=") == 1
and len(prediction.split("=")[0].strip()) <= 2
and "=" not in reference
):
if math_equal(
prediction.split("=")[1], reference, include_percentage, is_close
):
return True
elif (
reference.count("=") == 1
and len(reference.split("=")[0].strip()) <= 2
and "=" not in prediction
):
if math_equal(
prediction, reference.split("=")[1], include_percentage, is_close
):
return True
# symbolic equal with sympy
if symbolic_equal(prediction, reference):
return True
return False