Commit 99b0a42d authored by Julen Etxaniz's avatar Julen Etxaniz
Browse files

update parameter names and add docs

parent 2e046ce3
......@@ -271,6 +271,19 @@ python main.py \
--num_fewshot K
```
### Checking the Model Outputs
The `--write_out.py` script mentioned previously can be used to verify that the prompts look as intended. If you also want to save model outputs, you can use the `--write_out` parameter in `main.py` to dump JSON with prompts and completions. The output path can be chosen with `--output_base_path`. It is helpful for debugging and for exploring model outputs.
```sh
python main.py \
--model gpt2 \
--model_args device=<device-name> \
--tasks <task-name> \
--num_fewshot K \
--write_out \
--output_base_path <path>
```
### Running Unit Tests
To run the entire test suite, use:
......
......@@ -23,8 +23,8 @@ def simple_evaluate(
description_dict=None,
check_integrity=False,
decontamination_ngrams_path=None,
write_detailed_eval_info=False,
detailed_eval_info_path=None,
write_out=False,
output_base_path=None,
):
"""Instantiate and evaluate a model on a list of tasks.
......@@ -51,9 +51,9 @@ def simple_evaluate(
Dictionary of custom task descriptions of the form: `task_name: description`
:param check_integrity: bool
Whether to run the relevant part of the test suite for the tasks
:param write_detailed_eval_info: bool
:param write_out: bool
If True, write details about prompts and logits to json for all tasks
:param detailed_eval_info_path: str, optional
:param output_base_path: str, optional
Directory to which detailed eval info will be written. Defaults to present working dir.
:return
Dictionary of results
......@@ -96,8 +96,8 @@ def simple_evaluate(
bootstrap_iters=bootstrap_iters,
description_dict=description_dict,
decontamination_ngrams_path=decontamination_ngrams_path,
write_detailed_eval_info=write_detailed_eval_info,
detailed_eval_info_path=detailed_eval_info_path,
write_out=write_out,
output_base_path=output_base_path,
)
# add info about the model and few shot config
......@@ -129,8 +129,8 @@ def evaluate(
bootstrap_iters=100000,
description_dict=None,
decontamination_ngrams_path=None,
write_detailed_eval_info=False,
detailed_eval_info_path=None,
write_out=False,
output_base_path=None,
):
"""Instantiate and evaluate a model on a list of tasks.
......@@ -148,9 +148,9 @@ def evaluate(
Number of iterations for bootstrap statistics
:param description_dict: dict[str, str]
Dictionary of custom task descriptions of the form: `task_name: description`
:param write_detailed_eval_info: bool
:param write_out: bool
If True, write all prompts, logits and metrics to json for offline analysis
:param detailed_eval_info_path: str, optional
:param output_base_path: str, optional
Directory to which detailed eval info will be written. Defaults to present working dir
:return
Dictionary of results
......@@ -213,7 +213,7 @@ def evaluate(
rnd.shuffle(task_docs)
print(f"Task: {task_name}; number of docs: {len(task_docs)}")
if write_detailed_eval_info:
if write_out:
prompt_details = []
description = (
......@@ -234,7 +234,7 @@ def evaluate(
)
reqs = task.construct_requests(doc, ctx)
if write_detailed_eval_info:
if write_out:
prompt_details.append({"doc_id": doc_id})
# print the prompt for the first few documents
......@@ -252,12 +252,12 @@ def evaluate(
# doc_id: unique id that we can get back to a doc using `docs`
requests_origin[req.request_type].append((i, task_name, doc, doc_id))
if write_detailed_eval_info:
if write_out:
prompt_details[-1][f"prompt_{i}"] = "".join(
(map(lambda x: "".join(x), req.args))
)
if write_detailed_eval_info:
if write_out:
detailed_eval_info[task_name] = prompt_details
# Compare all tasks/sets at once to ensure a single training set scan
......@@ -288,7 +288,7 @@ def evaluate(
for resp, (i, task_name, doc, doc_id) in zip(resps, requests_origin[reqtype]):
process_res_queue[(task_name, doc_id)].append((i, resp))
if write_detailed_eval_info:
if write_out:
detailed_eval_info[task_name][doc_id][f"logit_{i}"] = resp
task = task_dict[task_name]
if isinstance(task, lm_eval.base.MultipleChoiceTask):
......@@ -316,7 +316,7 @@ def evaluate(
for metric, value in metrics.items():
vals[(task_name, metric)].append(value)
if write_detailed_eval_info:
if write_out:
detailed_eval_info[task_name][doc_id][metric] = str(value)
# Re-use the evaluation for the decontaminated set by just ignoring the overlaps
......@@ -347,25 +347,23 @@ def evaluate(
if stderr is not None:
results[task_name][metric + "_stderr"] = stderr(items)
if write_detailed_eval_info:
if write_out:
import json
import pathlib
detailed_eval_info_path = (
pathlib.Path(detailed_eval_info_path)
if detailed_eval_info_path is not None
output_base_path = (
pathlib.Path(output_base_path)
if output_base_path is not None
else pathlib.Path(".")
)
try:
detailed_eval_info_path.mkdir(parents=True, exist_ok=False)
output_base_path.mkdir(parents=True, exist_ok=False)
except FileExistsError:
pass
for task_name, _ in task_dict_items:
with open(
detailed_eval_info_path.joinpath(
f"{task_name}_detailed_eval_info.json"
),
output_base_path.joinpath(f"{task_name}_detailed_eval_info.json"),
"w",
encoding="utf8",
) as fp:
......
......@@ -40,10 +40,8 @@ def parse_args():
parser.add_argument("--decontamination_ngrams_path", default=None)
parser.add_argument("--description_dict_path", default=None)
parser.add_argument("--check_integrity", action="store_true")
parser.add_argument(
"--write_detailed_eval_info", action="store_true", default=False
)
parser.add_argument("--detailed_eval_info_path", type=str, default=None)
parser.add_argument("--write_out", action="store_true", default=False)
parser.add_argument("--output_base_path", type=str, default=None)
return parser.parse_args()
......@@ -92,8 +90,8 @@ def main():
description_dict=description_dict,
decontamination_ngrams_path=args.decontamination_ngrams_path,
check_integrity=args.check_integrity,
write_detailed_eval_info=args.write_detailed_eval_info,
detailed_eval_info_path=args.detailed_eval_info_path,
write_out=args.write_out,
output_base_path=args.output_base_path,
)
dumped = json.dumps(results, indent=2)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment