Unverified commit 8e85ee88 authored by fzyzcjy, committed by GitHub

Support simple evals in text comparator (#8867)

parent adf73175
import argparse
+import hashlib
import json
from pathlib import Path
@@ -13,7 +14,11 @@ Supported inputs:
def main(args):
-    df_input = _transform_df_input(_compute_df_raw(args))
+    if args.data_type == "simple_evals":
+        df_input = _compute_df_input_mode_simple_evals(args)
+    else:
+        df_input = _transform_df_input(_compute_df_raw(args))
    assert all(
        c in df_input.columns
        for c in ["category", "trial_index", "prompt_id", "prompt", "output", "correct"]
@@ -37,8 +42,9 @@ def main(args):
                df_meta=df_meta.to_dicts(),
                df_good_to_bad=df_good_to_bad.to_dicts(),
                df_bad_to_good=df_bad_to_good.to_dicts(),
-            )
-        )
+            ),
+            indent=4,
+        ),
    )
    if not args.disable_print_details:
@@ -65,19 +71,70 @@ def main(args):
        print(df)

+def _compute_df_input_mode_simple_evals(args):
+    return pl.concat(
+        [
+            _compute_df_input_one_mode_simple_evals(**info)
+            for info in _get_file_infos(args=args)
+        ]
+    )
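# The concatenated frame carries one row per evaluated example across all
# baseline/target files, with exactly the columns that main() asserts on:
# category, trial_index, prompt_id, prompt, output, correct.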

+def _compute_df_input_one_mode_simple_evals(path, category, trial_index):
+    data = json.loads(Path(path).read_text())
+    rows = []
+    for single_eval_result in data["metadata"]["single_eval_results"]:
+        prompt = single_eval_result["example_level_metadata"][
+            "actual_queried_prompt_messages"
+        ]
+        score = single_eval_result["score"]
+        assert score in {0.0, 1.0}, f"{score=}"
+        row = dict(
+            category=category,
+            trial_index=trial_index,
+            prompt_id=_compute_id_from_object(prompt),
+            prompt=json.dumps(prompt),
+            output=single_eval_result["example_level_metadata"]["response_text"],
+            correct=score == 1.0,
+        )
+        rows.append(row)
+    return pl.DataFrame(rows)
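# Illustrative sketch of a simple_evals report (assumed minimal shape: only the
# fields read by the parser above are real, the concrete values are hypothetical):
_example_simple_evals_report = {
    "metadata": {
        "single_eval_results": [
            {
                "score": 1.0,
                "example_level_metadata": {
                    "actual_queried_prompt_messages": [
                        {"role": "user", "content": "What is 2 + 2?"}
                    ],
                    "response_text": "4",
                },
            }
        ]
    }
}
# Written to a JSON file and parsed with category="target", trial_index=0, this
# would yield a one-row DataFrame with correct=True.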

+def _compute_id_from_object(obj):
+    if isinstance(obj, pl.Series):
+        obj = obj.to_list()
+    json_str = json.dumps(obj, sort_keys=True, ensure_ascii=False)
+    return hashlib.sha256(json_str.encode("utf-8")).hexdigest()
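# Illustrative check (example values are hypothetical): sort_keys=True canonicalises
# the JSON before hashing, so key order does not change the resulting prompt_id,
# which is what lets baseline and target rows for the same prompt be matched.
assert _compute_id_from_object({"a": 1, "b": 2}) == _compute_id_from_object({"b": 2, "a": 1})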

def _compute_df_raw(args):
    return pl.concat(
        [
-            _read_df_raw(p, category=category, trial_index=i)
-            for category, paths in [
-                ("baseline", args.baseline_path),
-                ("target", args.target_path),
-            ]
-            for i, p in enumerate(paths)
+            _read_df_raw(
+                path=info["path"],
+                category=info["category"],
+                trial_index=info["trial_index"],
+            )
+            for info in _get_file_infos(args=args)
        ]
    )

+def _get_file_infos(args):
+    return [
+        dict(path=path, category=category, trial_index=trial_index)
+        for category, paths in [
+            ("baseline", args.baseline_path),
+            ("target", args.target_path),
+        ]
+        for trial_index, path in enumerate(paths)
+    ]
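# Illustrative expansion (file names are hypothetical): given
#     --baseline-path base_run0.jsonl base_run1.jsonl --target-path tgt_run0.jsonl
# the helper above returns
#     [
#         {"path": "base_run0.jsonl", "category": "baseline", "trial_index": 0},
#         {"path": "base_run1.jsonl", "category": "baseline", "trial_index": 1},
#         {"path": "tgt_run0.jsonl", "category": "target", "trial_index": 0},
#     ]
# i.e. trial_index counts separately within each category.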

def _read_df_raw(path: str, category: str, trial_index: int):
    return pl.read_ndjson(path).with_columns(
        category=pl.lit(category), trial_index=trial_index
@@ -108,7 +165,9 @@ def _transform_df_input(df: pl.DataFrame):
print("Transform mode: SGLang bench")
return df
else:
raise Exception(f"Unknown data: {df.columns}")
raise Exception(
f"Unknown data: {df.columns}. You may need to set `--data-type` if using e.g. simple_evals."
)
def _compute_df_meta(df_input: pl.DataFrame):
@@ -127,7 +186,9 @@ def _compute_df_meta(df_input: pl.DataFrame):
def _handle_one_prompt(df_one_prompt: pl.DataFrame):
-    assert len(set(df_one_prompt["prompt"])) == 1
+    assert (
+        len(set(_compute_id_from_object(obj) for obj in df_one_prompt["prompt"])) == 1
+    )
    df_baseline = df_one_prompt.filter(pl.col("category") == "baseline")
    df_target = df_one_prompt.filter(pl.col("category") == "target")
@@ -162,6 +223,7 @@ def _compute_str_prefix_len(a: str, b: str) -> int:
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description=_DESCRIPTION)
+    parser.add_argument("--data-type", type=str, default="auto")
    parser.add_argument("--baseline-path", type=str, nargs="+")
    parser.add_argument("--target-path", type=str, nargs="+")
    parser.add_argument(
......
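# Illustrative invocations (the script name is hypothetical; the flags are the ones
# defined above):
#
#     # SGLang bench ndjson outputs, auto-detected with the default --data-type auto:
#     python text_comparator.py --baseline-path baseline.jsonl --target-path target.jsonl
#
#     # simple_evals JSON reports, selected via the new flag:
#     python text_comparator.py --data-type simple_evals \
#         --baseline-path baseline_run0.json --target-path target_run0.json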