Unverified Commit 62222bd2 authored by fzyzcjy, committed by GitHub

Minor tool for comparison of benchmark results (#7974)

parent ed0fdbf3
@@ -10,6 +10,7 @@ import numpy as np
 from sglang.api import set_default_backend
 from sglang.test.test_utils import (
     add_common_sglang_args_and_parse,
+    dump_bench_raw_result,
     select_sglang_backend,
 )
 from sglang.utils import download_and_cache_file, dump_state_text, read_jsonl
@@ -115,6 +116,12 @@ def main(args):
     # Dump results
     dump_state_text(f"tmp_output_{args.backend}.txt", states)
+    dump_bench_raw_result(
+        path=args.raw_result_file,
+        states=states,
+        preds=preds,
+        labels=labels,
+    )
 
     with open(args.result_file, "a") as fout:
         value = {
...
@@ -9,6 +9,7 @@ import tiktoken
 from sglang.test.test_utils import (
     add_common_sglang_args_and_parse,
+    dump_bench_raw_result,
     select_sglang_backend,
 )
@@ -142,6 +143,13 @@ def main(args):
     assert pt == len(cors)
     weighted_acc = np.mean(cors)
 
+    dump_bench_raw_result(
+        path=args.raw_result_file,
+        states=states,
+        preds=preds,
+        labels=labels,
+    )
+
     # Print results
     print("Total latency: {:.3f}".format(latency))
     print("Average accuracy: {:.3f}".format(weighted_acc))
...
import argparse
import json
from pathlib import Path
import polars as pl
_DESCRIPTION = """Compare and find differences between benchmark outputs.
Supported inputs:
* The samples jsonl from `lm_eval --log_samples --output_path FOLDER_NAME`
* The output from `gsm8k/bench_sglang.py --raw-result-file FILE_NAME` (or mmlu)
"""
def main(args):
df_input = _transform_df_input(_compute_df_raw(args))
assert all(
c in df_input.columns
for c in ["category", "trial_index", "prompt_id", "prompt", "output", "correct"]
)
df_meta = _compute_df_meta(df_input)
df_correctness_per_trial = df_input.group_by(
"category", "trial_index", maintain_order=True
).agg(pl.col("correct").mean())
df_correctness_delta = (
df_meta.group_by("correctness_delta").len().sort("correctness_delta")
)
df_good_to_bad = df_meta.filter(pl.col("correctness_delta") < 0)
df_bad_to_good = df_meta.filter(pl.col("correctness_delta") > 0)
print(f"Dump output to {args.output_path}")
Path(args.output_path).write_text(
json.dumps(
dict(
df_meta=df_meta.to_dicts(),
df_good_to_bad=df_good_to_bad.to_dicts(),
df_bad_to_good=df_bad_to_good.to_dicts(),
)
)
)
if not args.disable_print_details:
with pl.Config(
fmt_str_lengths=10000,
tbl_cols=-1,
tbl_rows=-1,
tbl_width_chars=-1,
tbl_formatting="UTF8_FULL",
):
print("====== Correctness per trial ======")
print(df_correctness_per_trial)
print(
"====== Correctness Delta (-1.0 means all-right becomes all-wrong) ======"
)
print(df_correctness_delta)
for name, df in [
("Good->Bad", df_good_to_bad),
("Bad->Good", df_bad_to_good),
]:
print(f"====== Concrete Examples: {name} ======")
print(df)
def _compute_df_raw(args):
return pl.concat(
[
_read_df_raw(p, category=category, trial_index=i)
for category, paths in [
("baseline", args.baseline_path),
("target", args.target_path),
]
for i, p in enumerate(paths)
]
)
def _read_df_raw(path: str, category: str, trial_index: int):
return pl.read_ndjson(path).with_columns(
category=pl.lit(category), trial_index=trial_index
)
def _transform_df_input(df: pl.DataFrame):
if "doc_id" in df.columns:
print("Transform mode: lm_eval")
filter_names = df["filter"].unique(maintain_order=True).to_list()
if len(filter_names) > 1:
filter_name = filter_names[0]
print(f"Choose {filter_name=} among {filter_names}")
df = df.filter(pl.col("filter") == filter_name)
df = df.select(
pl.col("category"),
pl.col("trial_index"),
prompt_id=pl.col("doc_id"),
prompt=pl.col("arguments").struct.field("gen_args_0").struct.field("arg_0"),
output=pl.col("resps").list.get(0).list.get(0),
correct=pl.col("exact_match").cast(bool),
)
return df
elif "prompt_id" in df.columns:
print("Transform mode: SGLang bench")
return df
else:
raise Exception(f"Unknown data: {df.columns}")
def _compute_df_meta(df_input: pl.DataFrame):
df_input = df_input.sort("prompt_id", "category", "trial_index")
df_meta = pl.DataFrame(
[
_handle_one_prompt(df_one_prompt)
for df_one_prompt in df_input.partition_by("prompt_id", maintain_order=True)
]
)
df_meta = df_meta.with_columns(
correctness_delta=pl.col("correctness_target") - pl.col("correctness_baseline"),
)
df_meta = df_meta.sort("correctness_delta", "output_same_prefix_len")
return df_meta
def _handle_one_prompt(df_one_prompt: pl.DataFrame):
assert len(set(df_one_prompt["prompt"])) == 1
df_baseline = df_one_prompt.filter(pl.col("category") == "baseline")
df_target = df_one_prompt.filter(pl.col("category") == "target")
outputs_baseline = df_baseline["output"].to_list()
outputs_target = df_target["output"].to_list()
output_same_prefix_len = max(
_compute_str_prefix_len(output_baseline, output_target)
for output_baseline in outputs_baseline
for output_target in outputs_target
)
return dict(
prompt_id=df_one_prompt[0, "prompt_id"],
correctness_baseline=df_baseline["correct"].mean(),
correctness_target=df_target["correct"].mean(),
output_same_prefix_len=output_same_prefix_len,
prompt=df_one_prompt[0, "prompt"],
outputs_baseline=outputs_baseline,
outputs_target=outputs_target,
)
def _compute_str_prefix_len(a: str, b: str) -> int:
min_len = min(len(a), len(b))
for i in range(min_len):
if a[i] != b[i]:
return i
return min_len
if __name__ == "__main__":
parser = argparse.ArgumentParser(description=_DESCRIPTION)
parser.add_argument("--baseline-path", type=str, nargs="+")
parser.add_argument("--target-path", type=str, nargs="+")
parser.add_argument(
"--output-path", type=str, default="/tmp/text_comparator_output.json"
)
parser.add_argument("--disable-print-details", action="store_true")
args = parser.parse_args()
main(args)
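A minimal usage sketch (not part of this commit) of how the helpers above fit together on toy data in the SGLang-bench schema; all values below are invented for illustration:

import polars as pl

# Toy data: prompt 0 stays correct, prompt 1 flips from wrong (baseline) to right (target).
baseline = pl.DataFrame(
    {
        "category": ["baseline", "baseline"],
        "trial_index": [0, 0],
        "prompt_id": [0, 1],
        "prompt": ["Q1", "Q2"],
        "output": ["Answer: 4", "Answer: 7"],
        "correct": [True, False],
    }
)
target = pl.DataFrame(
    {
        "category": ["target", "target"],
        "trial_index": [0, 0],
        "prompt_id": [0, 1],
        "prompt": ["Q1", "Q2"],
        "output": ["Answer: 4", "Answer: 8"],
        "correct": [True, True],
    }
)

# _compute_df_meta is defined in the script above. correctness_delta is target
# accuracy minus baseline accuracy per prompt, so this prints 0.0 and 1.0.
df_meta = _compute_df_meta(pl.concat([baseline, target]))
print(df_meta.select("prompt_id", "correctness_delta", "output_same_prefix_len"))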
@@ -15,6 +15,7 @@ import unittest
 from concurrent.futures import ThreadPoolExecutor
 from dataclasses import dataclass
 from functools import partial
+from pathlib import Path
 from types import SimpleNamespace
 from typing import Awaitable, Callable, List, Optional, Tuple
@@ -27,6 +28,7 @@ from sglang.bench_serving import run_benchmark
 from sglang.global_config import global_config
 from sglang.lang.backend.openai import OpenAI
 from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
+from sglang.lang.interpreter import ProgramState
 from sglang.srt.utils import (
     get_bool_env_var,
     get_device,
@@ -348,6 +350,7 @@ def add_common_sglang_args_and_parse(parser: argparse.ArgumentParser):
         help="Device type (auto/cuda/rocm/cpu). Auto will detect available platforms",
     )
     parser.add_argument("--result-file", type=str, default="result.jsonl")
+    parser.add_argument("--raw-result-file", type=str)
     args = parser.parse_args()
     return args
@@ -1309,3 +1312,35 @@ class CustomTestCase(unittest.TestCase):
             lambda: super(CustomTestCase, self)._callTestMethod(method),
             max_retry=max_retry,
         )
+
+
+def dump_bench_raw_result(
+    path: str,
+    states,
+    preds,
+    labels,
+):
+    if not path:
+        return
+
+    rows = []
+    for i in range(len(states)):
+        state = states[i]
+        output = state["answer"]
+        prompt = _ensure_remove_suffix(state.text(), output)
+        rows.append(
+            dict(
+                prompt_id=i,
+                prompt=prompt,
+                output=output,
+                correct=bool(preds[i] == labels[i]),
+            )
+        )
+
+    print(f"BenchRawResultDumper save results to {path}")
+    Path(path).write_text("\n".join(json.dumps(row) for row in rows))
+
+
+def _ensure_remove_suffix(text: str, suffix: str):
+    assert text.endswith(suffix)
+    return text.removesuffix(suffix)
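For reference, a small sketch (not from this commit; the file path and rows are made up) of the raw-result format: dump_bench_raw_result writes one JSON object per line, and the comparison tool loads that file with pl.read_ndjson, the same call used by its _read_df_raw helper:

import json
from pathlib import Path

import polars as pl

# Rows mirror the dict built inside dump_bench_raw_result above.
rows = [
    {"prompt_id": 0, "prompt": "Question: 2+2=", "output": "Answer: 4", "correct": True},
    {"prompt_id": 1, "prompt": "Question: 3+4=", "output": "Answer: 8", "correct": False},
]
Path("/tmp/raw_result_example.jsonl").write_text("\n".join(json.dumps(r) for r in rows))

print(pl.read_ndjson("/tmp/raw_result_example.jsonl"))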