Unverified Commit d88a566c authored by Lintang Sutawika's avatar Lintang Sutawika Committed by GitHub
Browse files

Merge pull request #612 from EleutherAI/benchmark-scripts

[Refactor] Benchmark scripts
parents 4168c05f 29f12dd9
include: promptsource-00.yaml
group:
- super-glue-promptsource
task: "based on the previous passage"
use_prompt: "promptsource:based on the previous passage"
group:
- super-glue-promptsource
task: "GPT-3 Style"
dataset_path: super_glue
dataset_name: wsc.fixed
training_split: train
validation_split: validation
use_prompt: "promptsource:GPT-3 Style"
metric_list:
- metric: exact_match
aggregation: mean
higher_is_better: true
ignore_case: true
ignore_punctuation: true
include: promptsource-00.yaml
group:
- super-glue-promptsource
task: "I think they mean"
use_prompt: "promptsource:I think they mean"
include: promptsource-00.yaml
group:
- super-glue-promptsource
task: "Who or what is/are"
use_prompt: "promptsource:Who or what is/are"
...@@ -18,14 +18,14 @@ def t5_prompt_doc_to_text(x): ...@@ -18,14 +18,14 @@ def t5_prompt_doc_to_text(x):
return text return text
def default_doc_to_text(doc): def default_doc_to_text(x):
raw_passage = doc["text"] raw_passage = x["text"]
# NOTE: HuggingFace span indices are word-based not character-based. # NOTE: HuggingFace span indices are word-based not character-based.
pre = " ".join(raw_passage.split()[: doc["span2_index"]]) pre = " ".join(raw_passage.split()[: x["span2_index"]])
post = raw_passage[len(pre) + len(doc["span2_text"]) + 1 :] post = raw_passage[len(pre) + len(x["span2_text"]) + 1 :]
passage = general_detokenize(pre + " *{}*".format(doc["span2_text"]) + post) passage = general_detokenize(pre + " *{}*".format(x["span2_text"]) + post)
noun = doc["span1_text"] noun = x["span1_text"]
pronoun = doc["span2_text"] pronoun = x["span2_text"]
text = ( text = (
f"Passage: {passage}\n" f"Passage: {passage}\n"
+ f'Question: In the passage above, does the pronoun "*{pronoun}*" refer to "*{noun}*"?\n' + f'Question: In the passage above, does the pronoun "*{pronoun}*" refer to "*{noun}*"?\n'
......
...@@ -5,6 +5,7 @@ dataset_path: super_glue ...@@ -5,6 +5,7 @@ dataset_path: super_glue
dataset_name: wsc dataset_name: wsc
training_split: train training_split: train
validation_split: validation validation_split: validation
output_type: greedy_until
doc_to_text: !function "preprocess_wsc.t5_prompt_doc_to_text" doc_to_text: !function "preprocess_wsc.t5_prompt_doc_to_text"
doc_to_target: label doc_to_target: label
doc_to_choice: ['False', 'True'] doc_to_choice: ['False', 'True']
......
...@@ -7,6 +7,8 @@ validation_split: validation ...@@ -7,6 +7,8 @@ validation_split: validation
doc_to_text: !function preprocess_winogrande.doc_to_text doc_to_text: !function preprocess_winogrande.doc_to_text
doc_to_target: !function preprocess_winogrande.doc_to_target doc_to_target: !function preprocess_winogrande.doc_to_target
doc_to_choice: !function preprocess_winogrande.doc_to_choice doc_to_choice: !function preprocess_winogrande.doc_to_choice
should_decontaminate: true
doc_to_decontamination_query: sentence
metric_list: metric_list:
- metric: acc - metric: acc
aggregation: mean aggregation: mean
......
...@@ -108,6 +108,10 @@ class MultiChoice: ...@@ -108,6 +108,10 @@ class MultiChoice:
# Returns a list containing all values of the source_list that # Returns a list containing all values of the source_list that
# match at least one of the patterns # match at least one of the patterns
def pattern_match(patterns, source_list): def pattern_match(patterns, source_list):
if type(patterns) == str:
patterns = [patterns]
task_names = set() task_names = set()
for pattern in patterns: for pattern in patterns:
for matching in fnmatch.filter(source_list, pattern): for matching in fnmatch.filter(source_list, pattern):
...@@ -259,16 +263,20 @@ class Grouper: ...@@ -259,16 +263,20 @@ class Grouper:
return res return res
def make_table(result_dict): def make_table(result_dict, column="results"):
"""Generate table of results.""" """Generate table of results."""
from pytablewriter import MarkdownTableWriter, LatexTableWriter from pytablewriter import MarkdownTableWriter, LatexTableWriter
if column == "results":
column_name = "Task"
elif column == "aggregate":
column_name = "Benchmark"
md_writer = MarkdownTableWriter() md_writer = MarkdownTableWriter()
latex_writer = LatexTableWriter() latex_writer = LatexTableWriter()
md_writer.headers = [ md_writer.headers = [
"Task", column_name,
"Version", "Version",
"Fewshot",
"Filter", "Filter",
"Metric", "Metric",
"Value", "Value",
...@@ -276,7 +284,7 @@ def make_table(result_dict): ...@@ -276,7 +284,7 @@ def make_table(result_dict):
"Stderr", "Stderr",
] ]
latex_writer.headers = [ latex_writer.headers = [
"Task", column_name,
"Version", "Version",
"Fewshot", "Fewshot",
"Filter", "Filter",
...@@ -288,7 +296,7 @@ def make_table(result_dict): ...@@ -288,7 +296,7 @@ def make_table(result_dict):
values = [] values = []
for k, dic in result_dict["results"].items(): for k, dic in result_dict[column].items():
version = result_dict["versions"][k] version = result_dict["versions"][k]
n = str(result_dict["configs"][k]["num_fewshot"]) n = str(result_dict["configs"][k]["num_fewshot"])
for (mf), v in dic.items(): for (mf), v in dic.items():
......
...@@ -10,6 +10,7 @@ from pathlib import Path ...@@ -10,6 +10,7 @@ from pathlib import Path
from lm_eval import evaluator, utils from lm_eval import evaluator, utils
from lm_eval.api.registry import ALL_TASKS from lm_eval.api.registry import ALL_TASKS
from lm_eval.logger import eval_logger from lm_eval.logger import eval_logger
from lm_eval.tasks import include_task_folder
os.environ["TOKENIZERS_PARALLELISM"] = "false" os.environ["TOKENIZERS_PARALLELISM"] = "false"
...@@ -23,7 +24,7 @@ def parse_args(): ...@@ -23,7 +24,7 @@ def parse_args():
help="String arguments for model, e.g. `pretrained=EleutherAI/pythia-160m,dtype=float32`", help="String arguments for model, e.g. `pretrained=EleutherAI/pythia-160m,dtype=float32`",
) )
parser.add_argument( parser.add_argument(
"--tasks", default=None, choices=utils.MultiChoice(sorted(ALL_TASKS)) "--tasks", default=None # , choices=utils.MultiChoice(sorted(ALL_TASKS))
) )
parser.add_argument( parser.add_argument(
"--num_fewshot", "--num_fewshot",
...@@ -82,6 +83,18 @@ def parse_args(): ...@@ -82,6 +83,18 @@ def parse_args():
default=False, default=False,
help="If True, write out all model outputs and documents for per-sample measurement and post-hoc analysis", help="If True, write out all model outputs and documents for per-sample measurement and post-hoc analysis",
) )
parser.add_argument(
"--show_config",
action="store_true",
default=False,
help="If True, shows the the full config of all tasks at the end of the evaluation.",
)
parser.add_argument(
"--include_path",
type=str,
default=None,
help="Additional path to include if there are external tasks to include.",
)
return parser.parse_args() return parser.parse_args()
...@@ -94,6 +107,10 @@ def main(): ...@@ -94,6 +107,10 @@ def main():
"REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT." "REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT."
) )
if args.include_path is not None:
eval_logger.info(f"Including path: {args.include_path}")
include_task_folder(args.include_path)
if args.tasks is None: if args.tasks is None:
task_names = ALL_TASKS task_names = ALL_TASKS
else: else:
...@@ -120,6 +137,7 @@ def main(): ...@@ -120,6 +137,7 @@ def main():
eval_logger.warning( eval_logger.warning(
f"File already exists at {path}. Results will be overwritten." f"File already exists at {path}. Results will be overwritten."
) )
output_path_file = path.joinpath("results.json")
assert not path.is_file(), "File already exists" assert not path.is_file(), "File already exists"
# if path json then get parent dir # if path json then get parent dir
elif path.suffix in (".json", ".jsonl"): elif path.suffix in (".json", ".jsonl"):
...@@ -154,7 +172,8 @@ def main(): ...@@ -154,7 +172,8 @@ def main():
if args.log_samples: if args.log_samples:
samples = results.pop("samples") samples = results.pop("samples")
dumped = json.dumps(results, indent=2, default=lambda o: str(o)) dumped = json.dumps(results, indent=2, default=lambda o: str(o))
print(dumped) if args.show_config:
print(dumped)
batch_sizes = ",".join(map(str, results["config"]["batch_sizes"])) batch_sizes = ",".join(map(str, results["config"]["batch_sizes"]))
...@@ -164,7 +183,7 @@ def main(): ...@@ -164,7 +183,7 @@ def main():
if args.log_samples: if args.log_samples:
for task_name, config in results["configs"].items(): for task_name, config in results["configs"].items():
output_name = "{}_{}".format( output_name = "{}_{}".format(
re.sub("/", "__", args.model_args), task_name re.sub("/|=", "__", args.model_args), task_name
) )
filename = path.joinpath(f"{output_name}.jsonl") filename = path.joinpath(f"{output_name}.jsonl")
...@@ -176,6 +195,8 @@ def main(): ...@@ -176,6 +195,8 @@ def main():
f"batch_size: {args.batch_size}{f' ({batch_sizes})' if batch_sizes else ''}" f"batch_size: {args.batch_size}{f' ({batch_sizes})' if batch_sizes else ''}"
) )
print(evaluator.make_table(results)) print(evaluator.make_table(results))
if "aggregate" in results:
print(evaluator.make_table(results, "aggregate"))
if __name__ == "__main__": if __name__ == "__main__":
......
import os
import yaml
import argparse
from tqdm import tqdm
from promptsource.templates import DatasetTemplates
from lm_eval import utils
# from lm_eval.api.registry import ALL_TASKS
from lm_eval.logger import eval_logger
# from lm_eval.tasks import include_task_folder
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--benchmark_name", required=True)
parser.add_argument("--benchmark_path", required=True)
parser.add_argument("--task_save_path", default="lm_eval/tasks/")
return parser.parse_args()
if __name__ == "__main__":
args = parse_args()
with open(args.benchmark_path) as file:
TASK_LIST = yaml.full_load(file)
for task in tqdm(TASK_LIST):
eval_logger.info(f"Processing {task}")
dataset_name = task["dataset_path"]
if "dataset_name" in task:
subset_name = task["dataset_name"]
file_subdir = f"{dataset_name}/{subset_name}"
else:
subset_name = None
file_subdir = f"{dataset_name}"
file_path = os.path.join(args.task_save_path, file_subdir, "promptsource/")
os.makedirs(file_path, exist_ok=True)
if subset_name is None:
prompts = DatasetTemplates(dataset_name=dataset_name)
else:
prompts = DatasetTemplates(
dataset_name=dataset_name, subset_name=subset_name
)
for idx, prompt_name in enumerate(prompts.all_template_names):
full_file_name = f"promptsource_{idx}.yaml"
config_dict = {
"group": args.benchmark_name,
"include": "promptsource_template.yaml",
"use_prompts": f"promptsource:{prompt_name}",
}
file_save_path = os.path.join(file_path, full_file_name)
eval_logger.info(f"Save to {file_save_path}")
with open(file_save_path, "w") as yaml_file:
yaml.dump(config_dict, yaml_file)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment