Unverified Commit bda68845 authored by Stella Biderman's avatar Stella Biderman Committed by GitHub
Browse files

Merge pull request #492 from juletx/eval-info

Add option to dump prompts and completions to a JSON file
parents 96a83d45 af913422
......@@ -271,6 +271,19 @@ python main.py \
--num_fewshot K
```
### Checking the Model Outputs
The `write_out.py` script mentioned previously can be used to verify that the prompts look as intended. If you also want to save model outputs, you can pass the `--write_out` flag to `main.py` to dump the prompts and completions of each task to a JSON file. The output directory can be chosen with `--output_base_path` (it defaults to the current working directory). This is helpful for debugging and for exploring model outputs.
```sh
python main.py \
--model gpt2 \
--model_args device=<device-name> \
--tasks <task-name> \
--num_fewshot K \
--write_out \
--output_base_path <path>
```
### Running Unit Tests
To run the entire test suite, use:
......
......@@ -23,8 +23,9 @@ def simple_evaluate(
description_dict=None,
check_integrity=False,
decontamination_ngrams_path=None,
write_out=False,
output_base_path=None,
):
"""Instantiate and evaluate a model on a list of tasks.
:param model: Union[str, LM]
......@@ -50,6 +51,10 @@ def simple_evaluate(
Dictionary of custom task descriptions of the form: `task_name: description`
:param check_integrity: bool
Whether to run the relevant part of the test suite for the tasks
:param write_out: bool
If True, write details about prompts and logits to json for all tasks
:param output_base_path: str, optional
Directory to which detailed eval info will be written. Defaults to present working dir.
:return
Dictionary of results
"""
......@@ -91,6 +96,8 @@ def simple_evaluate(
bootstrap_iters=bootstrap_iters,
description_dict=description_dict,
decontamination_ngrams_path=decontamination_ngrams_path,
write_out=write_out,
output_base_path=output_base_path,
)
# add info about the model and few shot config
......@@ -122,6 +129,8 @@ def evaluate(
bootstrap_iters=100000,
description_dict=None,
decontamination_ngrams_path=None,
write_out=False,
output_base_path=None,
):
"""Instantiate and evaluate a model on a list of tasks.
......@@ -139,6 +148,10 @@ def evaluate(
Number of iterations for bootstrap statistics
:param description_dict: dict[str, str]
Dictionary of custom task descriptions of the form: `task_name: description`
:param write_out: bool
If True, write all prompts, logits and metrics to json for offline analysis
:param output_base_path: str, optional
Directory to which detailed eval info will be written. Defaults to present working dir
:return
Dictionary of results
"""
......@@ -175,6 +188,7 @@ def evaluate(
# TODO: we need unit tests & sanity checks or something to ensure that the return of `validation_docs` is stable
docs = {}
write_out_info = {}
docs_for_decontamination = collections.defaultdict(list)
......@@ -197,6 +211,10 @@ def evaluate(
rnd = random.Random()
rnd.seed(42)
rnd.shuffle(task_docs)
print(f"Task: {task_name}; number of docs: {len(task_docs)}")
if write_out:
prompt_details = []
description = (
description_dict[task_name]
......@@ -207,7 +225,6 @@ def evaluate(
limit = int(len(task_docs) * limit) if limit < 1.0 else int(limit)
for doc_id, doc in enumerate(itertools.islice(task_docs, 0, limit)):
if decontaminate and task.should_decontaminate():
docs_for_decontamination[(task_name, task_set)].append(
task.doc_to_decontamination_query(doc)
......@@ -218,6 +235,17 @@ def evaluate(
doc=doc, num_fewshot=num_fewshot, rnd=rnd, description=description
)
reqs = task.construct_requests(doc, ctx)
if write_out:
prompt_details.append({"doc_id": doc_id})
# print the prompt for the first few documents
if doc_id < 1:
print(
f"Task: {task_name}; document {doc_id}; context prompt (starting on next line):\n{ctx}\n(end of prompt on previous line)"
)
print("Requests:", reqs)
if not isinstance(reqs, (list, tuple)):
reqs = [reqs]
for i, req in enumerate(reqs):
......@@ -226,6 +254,14 @@ def evaluate(
# doc_id: unique id that we can get back to a doc using `docs`
requests_origin[req.request_type].append((i, task_name, doc, doc_id))
if write_out:
prompt_details[-1][f"prompt_{i}"] = "".join(
(map(lambda x: "".join(x), req.args))
)
if write_out:
write_out_info[task_name] = prompt_details
# Compare all tasks/sets at once to ensure a single training set scan
if decontaminate:
from lm_eval.decontamination.decontaminate import get_train_overlap
......@@ -254,6 +290,18 @@ def evaluate(
for resp, (i, task_name, doc, doc_id) in zip(resps, requests_origin[reqtype]):
process_res_queue[(task_name, doc_id)].append((i, resp))
if write_out:
write_out_info[task_name][doc_id][f"logit_{i}"] = resp
task = task_dict[task_name]
if isinstance(task, lm_eval.base.MultipleChoiceTask):
write_out_info[task_name][doc_id]["truth"] = doc["gold"]
elif isinstance(task, lm_eval.tasks.winogrande.Winogrande):
write_out_info[task_name][doc_id]["truth"] = task.answer_to_num[
doc["answer"]
]
else:
write_out_info[task_name][doc_id]["truth"] = task.doc_to_target(doc)
vals = collections.defaultdict(list)
# unpack results and sort back in order and return control to Task
......@@ -268,6 +316,9 @@ def evaluate(
for metric, value in metrics.items():
vals[(task_name, metric)].append(value)
if write_out:
write_out_info[task_name][doc_id][metric] = str(value)
# Re-use the evaluation for the decontaminated set by just ignoring the overlaps
if decontaminate and task_name in overlaps:
if doc_id not in overlaps[task_name]:
......@@ -296,6 +347,28 @@ def evaluate(
if stderr is not None:
results[task_name][metric + "_stderr"] = stderr(items)
if write_out:
import json
import pathlib
output_base_path = (
pathlib.Path(output_base_path)
if output_base_path is not None
else pathlib.Path(".")
)
try:
output_base_path.mkdir(parents=True, exist_ok=False)
except FileExistsError:
pass
for task_name, _ in task_dict_items:
with open(
output_base_path.joinpath(f"{task_name}_write_out_info.json"),
"w",
encoding="utf8",
) as fp:
json.dump(write_out_info[task_name], fp, indent=4, ensure_ascii=False)
return {"results": dict(results), "versions": dict(versions)}
......
......@@ -44,6 +44,8 @@ def parse_args():
parser.add_argument("--decontamination_ngrams_path", default=None)
parser.add_argument("--description_dict_path", default=None)
parser.add_argument("--check_integrity", action="store_true")
parser.add_argument("--write_out", action="store_true", default=False)
parser.add_argument("--output_base_path", type=str, default=None)
return parser.parse_args()
......@@ -92,6 +94,8 @@ def main():
description_dict=description_dict,
decontamination_ngrams_path=args.decontamination_ngrams_path,
check_integrity=args.check_integrity,
write_out=args.write_out,
output_base_path=args.output_base_path,
)
dumped = json.dumps(results, indent=2)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment