Commit 2e046ce3 authored by Julen Etxaniz's avatar Julen Etxaniz
Browse files

add --write_detailed_eval_info to dump JSON with prompts and completions

parent 8fc04fe5
......@@ -32,7 +32,7 @@ repos:
rev: 22.3.0
hooks:
- id: black
language_version: python3.8
language_version: python3.9
- repo: https://github.com/codespell-project/codespell
rev: v2.1.0
hooks:
......
......@@ -23,8 +23,9 @@ def simple_evaluate(
description_dict=None,
check_integrity=False,
decontamination_ngrams_path=None,
write_detailed_eval_info=False,
detailed_eval_info_path=None,
):
"""Instantiate and evaluate a model on a list of tasks.
:param model: Union[str, LM]
......@@ -50,6 +51,10 @@ def simple_evaluate(
Dictionary of custom task descriptions of the form: `task_name: description`
:param check_integrity: bool
Whether to run the relevant part of the test suite for the tasks
:param write_detailed_eval_info: bool
If True, write details about prompts and logits to json for all tasks
:param detailed_eval_info_path: str, optional
Directory to which detailed eval info will be written. Defaults to present working dir.
:return
Dictionary of results
"""
......@@ -91,6 +96,8 @@ def simple_evaluate(
bootstrap_iters=bootstrap_iters,
description_dict=description_dict,
decontamination_ngrams_path=decontamination_ngrams_path,
write_detailed_eval_info=write_detailed_eval_info,
detailed_eval_info_path=detailed_eval_info_path,
)
# add info about the model and few shot config
......@@ -122,6 +129,8 @@ def evaluate(
bootstrap_iters=100000,
description_dict=None,
decontamination_ngrams_path=None,
write_detailed_eval_info=False,
detailed_eval_info_path=None,
):
"""Instantiate and evaluate a model on a list of tasks.
......@@ -139,6 +148,10 @@ def evaluate(
Number of iterations for bootstrap statistics
:param description_dict: dict[str, str]
Dictionary of custom task descriptions of the form: `task_name: description`
:param write_detailed_eval_info: bool
If True, write all prompts, logits and metrics to json for offline analysis
:param detailed_eval_info_path: str, optional
Directory to which detailed eval info will be written. Defaults to present working dir
:return
Dictionary of results
"""
......@@ -175,6 +188,7 @@ def evaluate(
# TODO: we need unit tests & sanity checks or something to ensure that the return of `validation_docs` is stable
docs = {}
detailed_eval_info = {}
docs_for_decontamination = collections.defaultdict(list)
......@@ -197,6 +211,10 @@ def evaluate(
rnd = random.Random()
rnd.seed(42)
rnd.shuffle(task_docs)
print(f"Task: {task_name}; number of docs: {len(task_docs)}")
if write_detailed_eval_info:
prompt_details = []
description = (
description_dict[task_name]
......@@ -205,7 +223,6 @@ def evaluate(
)
for doc_id, doc in enumerate(itertools.islice(task_docs, 0, limit)):
if decontaminate and task.should_decontaminate():
docs_for_decontamination[(task_name, task_set)].append(
task.doc_to_decontamination_query(doc)
......@@ -216,6 +233,17 @@ def evaluate(
doc=doc, num_fewshot=num_fewshot, rnd=rnd, description=description
)
reqs = task.construct_requests(doc, ctx)
if write_detailed_eval_info:
prompt_details.append({"doc_id": doc_id})
# print the prompt for the first few documents
if doc_id < 1:
print(
f"Task: {task_name}; document {doc_id}; context prompt (starting on next line):\n{ctx}\n(end of prompt on previous line)"
)
print("Requests:", reqs)
if not isinstance(reqs, (list, tuple)):
reqs = [reqs]
for i, req in enumerate(reqs):
......@@ -224,6 +252,14 @@ def evaluate(
# doc_id: unique id that we can get back to a doc using `docs`
requests_origin[req.request_type].append((i, task_name, doc, doc_id))
if write_detailed_eval_info:
prompt_details[-1][f"prompt_{i}"] = "".join(
(map(lambda x: "".join(x), req.args))
)
if write_detailed_eval_info:
detailed_eval_info[task_name] = prompt_details
# Compare all tasks/sets at once to ensure a single training set scan
if decontaminate:
from lm_eval.decontamination.decontaminate import get_train_overlap
......@@ -252,6 +288,20 @@ def evaluate(
for resp, (i, task_name, doc, doc_id) in zip(resps, requests_origin[reqtype]):
process_res_queue[(task_name, doc_id)].append((i, resp))
if write_detailed_eval_info:
detailed_eval_info[task_name][doc_id][f"logit_{i}"] = resp
task = task_dict[task_name]
if isinstance(task, lm_eval.base.MultipleChoiceTask):
detailed_eval_info[task_name][doc_id]["truth"] = doc["gold"]
elif isinstance(task, lm_eval.tasks.winogrande.Winogrande):
detailed_eval_info[task_name][doc_id]["truth"] = task.answer_to_num[
doc["answer"]
]
else:
detailed_eval_info[task_name][doc_id]["truth"] = task.doc_to_target(
doc
)
vals = collections.defaultdict(list)
# unpack results and sort back in order and return control to Task
......@@ -266,6 +316,9 @@ def evaluate(
for metric, value in metrics.items():
vals[(task_name, metric)].append(value)
if write_detailed_eval_info:
detailed_eval_info[task_name][doc_id][metric] = str(value)
# Re-use the evaluation for the decontaminated set by just ignoring the overlaps
if decontaminate and task_name in overlaps:
if doc_id not in overlaps[task_name]:
......@@ -294,6 +347,32 @@ def evaluate(
if stderr is not None:
results[task_name][metric + "_stderr"] = stderr(items)
if write_detailed_eval_info:
import json
import pathlib
detailed_eval_info_path = (
pathlib.Path(detailed_eval_info_path)
if detailed_eval_info_path is not None
else pathlib.Path(".")
)
try:
detailed_eval_info_path.mkdir(parents=True, exist_ok=False)
except FileExistsError:
pass
for task_name, _ in task_dict_items:
with open(
detailed_eval_info_path.joinpath(
f"{task_name}_detailed_eval_info.json"
),
"w",
encoding="utf8",
) as fp:
json.dump(
detailed_eval_info[task_name], fp, indent=4, ensure_ascii=False
)
return {"results": dict(results), "versions": dict(versions)}
......
......@@ -40,6 +40,10 @@ def parse_args():
parser.add_argument("--decontamination_ngrams_path", default=None)
parser.add_argument("--description_dict_path", default=None)
parser.add_argument("--check_integrity", action="store_true")
parser.add_argument(
"--write_detailed_eval_info", action="store_true", default=False
)
parser.add_argument("--detailed_eval_info_path", type=str, default=None)
return parser.parse_args()
......@@ -88,6 +92,8 @@ def main():
description_dict=description_dict,
decontamination_ngrams_path=args.decontamination_ngrams_path,
check_integrity=args.check_integrity,
write_detailed_eval_info=args.write_detailed_eval_info,
detailed_eval_info_path=args.detailed_eval_info_path,
)
dumped = json.dumps(results, indent=2)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment