Unverified Commit 62222bd2 authored by fzyzcjy, committed by GitHub

Minor tool for comparison of benchmark results (#7974)

parent ed0fdbf3
@@ -10,6 +10,7 @@ import numpy as np
from sglang.api import set_default_backend
from sglang.test.test_utils import (
    add_common_sglang_args_and_parse,
    dump_bench_raw_result,
    select_sglang_backend,
)
from sglang.utils import download_and_cache_file, dump_state_text, read_jsonl
@@ -115,6 +116,12 @@ def main(args):
    # Dump results
    dump_state_text(f"tmp_output_{args.backend}.txt", states)
    dump_bench_raw_result(
        path=args.raw_result_file,
        states=states,
        preds=preds,
        labels=labels,
    )

    with open(args.result_file, "a") as fout:
        value = {
......
@@ -9,6 +9,7 @@ import tiktoken
from sglang.test.test_utils import (
    add_common_sglang_args_and_parse,
    dump_bench_raw_result,
    select_sglang_backend,
)
@@ -142,6 +143,13 @@ def main(args):
    assert pt == len(cors)
    weighted_acc = np.mean(cors)

    dump_bench_raw_result(
        path=args.raw_result_file,
        states=states,
        preds=preds,
        labels=labels,
    )

    # Print results
    print("Total latency: {:.3f}".format(latency))
    print("Average accuracy: {:.3f}".format(weighted_acc))
......
import argparse
import json
from pathlib import Path

import polars as pl

_DESCRIPTION = """Compare benchmark outputs and find the differences between a baseline run and a target run.

Supported inputs:
* The samples jsonl from `lm_eval --log_samples --output_path FOLDER_NAME`
* The output from `gsm8k/bench_sglang.py --raw-result-file FILE_NAME` (or mmlu)
"""


def main(args):
    df_input = _transform_df_input(_compute_df_raw(args))
    assert all(
        c in df_input.columns
        for c in ["category", "trial_index", "prompt_id", "prompt", "output", "correct"]
    )

    # One row per prompt: correctness on baseline and target, their delta, and the
    # longest prefix shared by the baseline and target outputs
    df_meta = _compute_df_meta(df_input)

    df_correctness_per_trial = df_input.group_by(
        "category", "trial_index", maintain_order=True
    ).agg(pl.col("correct").mean())
    df_correctness_delta = (
        df_meta.group_by("correctness_delta").len().sort("correctness_delta")
    )
    df_good_to_bad = df_meta.filter(pl.col("correctness_delta") < 0)
    df_bad_to_good = df_meta.filter(pl.col("correctness_delta") > 0)

    print(f"Dump output to {args.output_path}")
    Path(args.output_path).write_text(
        json.dumps(
            dict(
                df_meta=df_meta.to_dicts(),
                df_good_to_bad=df_good_to_bad.to_dicts(),
                df_bad_to_good=df_bad_to_good.to_dicts(),
            )
        )
    )

    if not args.disable_print_details:
        with pl.Config(
            fmt_str_lengths=10000,
            tbl_cols=-1,
            tbl_rows=-1,
            tbl_width_chars=-1,
            tbl_formatting="UTF8_FULL",
        ):
            print("====== Correctness per trial ======")
            print(df_correctness_per_trial)

            print(
                "====== Correctness delta (-1.0 means all-correct became all-wrong) ======"
            )
            print(df_correctness_delta)

            for name, df in [
                ("Good->Bad", df_good_to_bad),
                ("Bad->Good", df_bad_to_good),
            ]:
                print(f"====== Concrete Examples: {name} ======")
                print(df)


def _compute_df_raw(args):
    return pl.concat(
        [
            _read_df_raw(p, category=category, trial_index=i)
            for category, paths in [
                ("baseline", args.baseline_path),
                ("target", args.target_path),
            ]
            for i, p in enumerate(paths)
        ]
    )


def _read_df_raw(path: str, category: str, trial_index: int):
    return pl.read_ndjson(path).with_columns(
        category=pl.lit(category), trial_index=trial_index
    )


def _transform_df_input(df: pl.DataFrame):
    if "doc_id" in df.columns:
        print("Transform mode: lm_eval")
        filter_names = df["filter"].unique(maintain_order=True).to_list()
        if len(filter_names) > 1:
            filter_name = filter_names[0]
            print(f"Choose {filter_name=} among {filter_names}")
            df = df.filter(pl.col("filter") == filter_name)
        df = df.select(
            pl.col("category"),
            pl.col("trial_index"),
            prompt_id=pl.col("doc_id"),
            prompt=pl.col("arguments").struct.field("gen_args_0").struct.field("arg_0"),
            output=pl.col("resps").list.get(0).list.get(0),
            correct=pl.col("exact_match").cast(bool),
        )
        return df
    elif "prompt_id" in df.columns:
        print("Transform mode: SGLang bench")
        return df
    else:
        raise Exception(f"Unknown data: {df.columns}")


def _compute_df_meta(df_input: pl.DataFrame):
    df_input = df_input.sort("prompt_id", "category", "trial_index")
    df_meta = pl.DataFrame(
        [
            _handle_one_prompt(df_one_prompt)
            for df_one_prompt in df_input.partition_by("prompt_id", maintain_order=True)
        ]
    )
    # correctness_delta < 0 means the prompt got worse on target; > 0 means it improved
    df_meta = df_meta.with_columns(
        correctness_delta=pl.col("correctness_target") - pl.col("correctness_baseline"),
    )
    df_meta = df_meta.sort("correctness_delta", "output_same_prefix_len")
    return df_meta


def _handle_one_prompt(df_one_prompt: pl.DataFrame):
    assert len(set(df_one_prompt["prompt"])) == 1

    df_baseline = df_one_prompt.filter(pl.col("category") == "baseline")
    df_target = df_one_prompt.filter(pl.col("category") == "target")

    outputs_baseline = df_baseline["output"].to_list()
    outputs_target = df_target["output"].to_list()

    # Longest common prefix across all baseline/target output pairs, as a rough
    # indicator of where the generations start to diverge
    output_same_prefix_len = max(
        _compute_str_prefix_len(output_baseline, output_target)
        for output_baseline in outputs_baseline
        for output_target in outputs_target
    )

    return dict(
        prompt_id=df_one_prompt[0, "prompt_id"],
        correctness_baseline=df_baseline["correct"].mean(),
        correctness_target=df_target["correct"].mean(),
        output_same_prefix_len=output_same_prefix_len,
        prompt=df_one_prompt[0, "prompt"],
        outputs_baseline=outputs_baseline,
        outputs_target=outputs_target,
    )


def _compute_str_prefix_len(a: str, b: str) -> int:
    min_len = min(len(a), len(b))
    for i in range(min_len):
        if a[i] != b[i]:
            return i
    return min_len
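
# Example: _compute_str_prefix_len("The answer is 4", "The answer is 5") == 14,
# i.e. the two outputs agree on the first 14 characters and then diverge.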


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description=_DESCRIPTION)
    parser.add_argument("--baseline-path", type=str, nargs="+")
    parser.add_argument("--target-path", type=str, nargs="+")
    parser.add_argument(
        "--output-path", type=str, default="/tmp/text_comparator_output.json"
    )
    parser.add_argument("--disable-print-details", action="store_true")
    args = parser.parse_args()
    main(args)
@@ -15,6 +15,7 @@ import unittest
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from functools import partial
from pathlib import Path
from types import SimpleNamespace
from typing import Awaitable, Callable, List, Optional, Tuple
@@ -27,6 +28,7 @@ from sglang.bench_serving import run_benchmark
from sglang.global_config import global_config
from sglang.lang.backend.openai import OpenAI
from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
from sglang.lang.interpreter import ProgramState
from sglang.srt.utils import (
    get_bool_env_var,
    get_device,
@@ -348,6 +350,7 @@ def add_common_sglang_args_and_parse(parser: argparse.ArgumentParser):
        help="Device type (auto/cuda/rocm/cpu). Auto will detect available platforms",
    )
    parser.add_argument("--result-file", type=str, default="result.jsonl")
    parser.add_argument("--raw-result-file", type=str)
    args = parser.parse_args()
    return args
@@ -1309,3 +1312,35 @@ class CustomTestCase(unittest.TestCase):
            lambda: super(CustomTestCase, self)._callTestMethod(method),
            max_retry=max_retry,
        )
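

# dump_bench_raw_result writes one JSON object per line (prompt_id, prompt, output,
# correct) when --raw-result-file is passed to a benchmark script. The comparison
# tool consumes this file in its "SGLang bench" mode (detected via the prompt_id column).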
def dump_bench_raw_result(
    path: str,
    states,
    preds,
    labels,
):
    if not path:
        return

    rows = []
    for i in range(len(states)):
        state = states[i]
        output = state["answer"]
        # The interpreter text is prompt + generated answer; strip the answer to
        # recover the prompt that was fed to the model
        prompt = _ensure_remove_suffix(state.text(), output)
        rows.append(
            dict(
                prompt_id=i,
                prompt=prompt,
                output=output,
                correct=bool(preds[i] == labels[i]),
            )
        )

    print(f"BenchRawResultDumper saves results to {path}")
    Path(path).write_text("\n".join(json.dumps(row) for row in rows))


def _ensure_remove_suffix(text: str, suffix: str):
    assert text.endswith(suffix)
    return text.removesuffix(suffix)
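
# An illustrative line of the resulting raw-result file (values are made up):
#   {"prompt_id": 0, "prompt": "Question: ...", "output": "The answer is 42.", "correct": true}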