Commit 2e046ce3 authored by Julen Etxaniz's avatar Julen Etxaniz
Browse files

add --write_detailed_eval_info to dump JSON with prompts and completions

parent 8fc04fe5
...@@ -32,7 +32,7 @@ repos: ...@@ -32,7 +32,7 @@ repos:
rev: 22.3.0 rev: 22.3.0
hooks: hooks:
- id: black - id: black
language_version: python3.8 language_version: python3.9
- repo: https://github.com/codespell-project/codespell - repo: https://github.com/codespell-project/codespell
rev: v2.1.0 rev: v2.1.0
hooks: hooks:
......
...@@ -23,8 +23,9 @@ def simple_evaluate( ...@@ -23,8 +23,9 @@ def simple_evaluate(
description_dict=None, description_dict=None,
check_integrity=False, check_integrity=False,
decontamination_ngrams_path=None, decontamination_ngrams_path=None,
write_detailed_eval_info=False,
detailed_eval_info_path=None,
): ):
"""Instantiate and evaluate a model on a list of tasks. """Instantiate and evaluate a model on a list of tasks.
:param model: Union[str, LM] :param model: Union[str, LM]
...@@ -50,6 +51,10 @@ def simple_evaluate( ...@@ -50,6 +51,10 @@ def simple_evaluate(
Dictionary of custom task descriptions of the form: `task_name: description` Dictionary of custom task descriptions of the form: `task_name: description`
:param check_integrity: bool :param check_integrity: bool
Whether to run the relevant part of the test suite for the tasks Whether to run the relevant part of the test suite for the tasks
:param write_detailed_eval_info: bool
If True, write details about prompts and logits to json for all tasks
:param detailed_eval_info_path: str, optional
Directory to which detailed eval info will be written. Defaults to present working dir.
:return :return
Dictionary of results Dictionary of results
""" """
...@@ -91,6 +96,8 @@ def simple_evaluate( ...@@ -91,6 +96,8 @@ def simple_evaluate(
bootstrap_iters=bootstrap_iters, bootstrap_iters=bootstrap_iters,
description_dict=description_dict, description_dict=description_dict,
decontamination_ngrams_path=decontamination_ngrams_path, decontamination_ngrams_path=decontamination_ngrams_path,
write_detailed_eval_info=write_detailed_eval_info,
detailed_eval_info_path=detailed_eval_info_path,
) )
# add info about the model and few shot config # add info about the model and few shot config
...@@ -122,6 +129,8 @@ def evaluate( ...@@ -122,6 +129,8 @@ def evaluate(
bootstrap_iters=100000, bootstrap_iters=100000,
description_dict=None, description_dict=None,
decontamination_ngrams_path=None, decontamination_ngrams_path=None,
write_detailed_eval_info=False,
detailed_eval_info_path=None,
): ):
"""Instantiate and evaluate a model on a list of tasks. """Instantiate and evaluate a model on a list of tasks.
...@@ -139,6 +148,10 @@ def evaluate( ...@@ -139,6 +148,10 @@ def evaluate(
Number of iterations for bootstrap statistics Number of iterations for bootstrap statistics
:param description_dict: dict[str, str] :param description_dict: dict[str, str]
Dictionary of custom task descriptions of the form: `task_name: description` Dictionary of custom task descriptions of the form: `task_name: description`
:param write_detailed_eval_info: bool
If True, write all prompts, logits and metrics to json for offline analysis
:param detailed_eval_info_path: str, optional
Directory to which detailed eval info will be written. Defaults to present working dir
:return :return
Dictionary of results Dictionary of results
""" """
...@@ -175,6 +188,7 @@ def evaluate( ...@@ -175,6 +188,7 @@ def evaluate(
# TODO: we need unit tests & sanity checks or something to ensure that the return of `validation_docs` is stable # TODO: we need unit tests & sanity checks or something to ensure that the return of `validation_docs` is stable
docs = {} docs = {}
detailed_eval_info = {}
docs_for_decontamination = collections.defaultdict(list) docs_for_decontamination = collections.defaultdict(list)
...@@ -197,6 +211,10 @@ def evaluate( ...@@ -197,6 +211,10 @@ def evaluate(
rnd = random.Random() rnd = random.Random()
rnd.seed(42) rnd.seed(42)
rnd.shuffle(task_docs) rnd.shuffle(task_docs)
print(f"Task: {task_name}; number of docs: {len(task_docs)}")
if write_detailed_eval_info:
prompt_details = []
description = ( description = (
description_dict[task_name] description_dict[task_name]
...@@ -205,7 +223,6 @@ def evaluate( ...@@ -205,7 +223,6 @@ def evaluate(
) )
for doc_id, doc in enumerate(itertools.islice(task_docs, 0, limit)): for doc_id, doc in enumerate(itertools.islice(task_docs, 0, limit)):
if decontaminate and task.should_decontaminate(): if decontaminate and task.should_decontaminate():
docs_for_decontamination[(task_name, task_set)].append( docs_for_decontamination[(task_name, task_set)].append(
task.doc_to_decontamination_query(doc) task.doc_to_decontamination_query(doc)
...@@ -216,6 +233,17 @@ def evaluate( ...@@ -216,6 +233,17 @@ def evaluate(
doc=doc, num_fewshot=num_fewshot, rnd=rnd, description=description doc=doc, num_fewshot=num_fewshot, rnd=rnd, description=description
) )
reqs = task.construct_requests(doc, ctx) reqs = task.construct_requests(doc, ctx)
if write_detailed_eval_info:
prompt_details.append({"doc_id": doc_id})
# print the prompt for the first few documents
if doc_id < 1:
print(
f"Task: {task_name}; document {doc_id}; context prompt (starting on next line):\n{ctx}\n(end of prompt on previous line)"
)
print("Requests:", reqs)
if not isinstance(reqs, (list, tuple)): if not isinstance(reqs, (list, tuple)):
reqs = [reqs] reqs = [reqs]
for i, req in enumerate(reqs): for i, req in enumerate(reqs):
...@@ -224,6 +252,14 @@ def evaluate( ...@@ -224,6 +252,14 @@ def evaluate(
# doc_id: unique id that we can get back to a doc using `docs` # doc_id: unique id that we can get back to a doc using `docs`
requests_origin[req.request_type].append((i, task_name, doc, doc_id)) requests_origin[req.request_type].append((i, task_name, doc, doc_id))
if write_detailed_eval_info:
prompt_details[-1][f"prompt_{i}"] = "".join(
(map(lambda x: "".join(x), req.args))
)
if write_detailed_eval_info:
detailed_eval_info[task_name] = prompt_details
# Compare all tasks/sets at once to ensure a single training set scan # Compare all tasks/sets at once to ensure a single training set scan
if decontaminate: if decontaminate:
from lm_eval.decontamination.decontaminate import get_train_overlap from lm_eval.decontamination.decontaminate import get_train_overlap
...@@ -252,6 +288,20 @@ def evaluate( ...@@ -252,6 +288,20 @@ def evaluate(
for resp, (i, task_name, doc, doc_id) in zip(resps, requests_origin[reqtype]): for resp, (i, task_name, doc, doc_id) in zip(resps, requests_origin[reqtype]):
process_res_queue[(task_name, doc_id)].append((i, resp)) process_res_queue[(task_name, doc_id)].append((i, resp))
if write_detailed_eval_info:
detailed_eval_info[task_name][doc_id][f"logit_{i}"] = resp
task = task_dict[task_name]
if isinstance(task, lm_eval.base.MultipleChoiceTask):
detailed_eval_info[task_name][doc_id]["truth"] = doc["gold"]
elif isinstance(task, lm_eval.tasks.winogrande.Winogrande):
detailed_eval_info[task_name][doc_id]["truth"] = task.answer_to_num[
doc["answer"]
]
else:
detailed_eval_info[task_name][doc_id]["truth"] = task.doc_to_target(
doc
)
vals = collections.defaultdict(list) vals = collections.defaultdict(list)
# unpack results and sort back in order and return control to Task # unpack results and sort back in order and return control to Task
...@@ -266,6 +316,9 @@ def evaluate( ...@@ -266,6 +316,9 @@ def evaluate(
for metric, value in metrics.items(): for metric, value in metrics.items():
vals[(task_name, metric)].append(value) vals[(task_name, metric)].append(value)
if write_detailed_eval_info:
detailed_eval_info[task_name][doc_id][metric] = str(value)
# Re-use the evaluation for the decontaminated set by just ignoring the overlaps # Re-use the evaluation for the decontaminated set by just ignoring the overlaps
if decontaminate and task_name in overlaps: if decontaminate and task_name in overlaps:
if doc_id not in overlaps[task_name]: if doc_id not in overlaps[task_name]:
...@@ -294,6 +347,32 @@ def evaluate( ...@@ -294,6 +347,32 @@ def evaluate(
if stderr is not None: if stderr is not None:
results[task_name][metric + "_stderr"] = stderr(items) results[task_name][metric + "_stderr"] = stderr(items)
if write_detailed_eval_info:
import json
import pathlib
detailed_eval_info_path = (
pathlib.Path(detailed_eval_info_path)
if detailed_eval_info_path is not None
else pathlib.Path(".")
)
try:
detailed_eval_info_path.mkdir(parents=True, exist_ok=False)
except FileExistsError:
pass
for task_name, _ in task_dict_items:
with open(
detailed_eval_info_path.joinpath(
f"{task_name}_detailed_eval_info.json"
),
"w",
encoding="utf8",
) as fp:
json.dump(
detailed_eval_info[task_name], fp, indent=4, ensure_ascii=False
)
return {"results": dict(results), "versions": dict(versions)} return {"results": dict(results), "versions": dict(versions)}
......
...@@ -40,6 +40,10 @@ def parse_args(): ...@@ -40,6 +40,10 @@ def parse_args():
parser.add_argument("--decontamination_ngrams_path", default=None) parser.add_argument("--decontamination_ngrams_path", default=None)
parser.add_argument("--description_dict_path", default=None) parser.add_argument("--description_dict_path", default=None)
parser.add_argument("--check_integrity", action="store_true") parser.add_argument("--check_integrity", action="store_true")
parser.add_argument(
"--write_detailed_eval_info", action="store_true", default=False
)
parser.add_argument("--detailed_eval_info_path", type=str, default=None)
return parser.parse_args() return parser.parse_args()
...@@ -88,6 +92,8 @@ def main(): ...@@ -88,6 +92,8 @@ def main():
description_dict=description_dict, description_dict=description_dict,
decontamination_ngrams_path=args.decontamination_ngrams_path, decontamination_ngrams_path=args.decontamination_ngrams_path,
check_integrity=args.check_integrity, check_integrity=args.check_integrity,
write_detailed_eval_info=args.write_detailed_eval_info,
detailed_eval_info_path=args.detailed_eval_info_path,
) )
dumped = json.dumps(results, indent=2) dumped = json.dumps(results, indent=2)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment