Unverified Commit d88a566c authored by Lintang Sutawika, committed by GitHub

Merge pull request #612 from EleutherAI/benchmark-scripts

[Refactor] Benchmark scripts
parents 4168c05f 29f12dd9
env
*.pyc
output/
data/
lm_cache
.idea
......
import os
import evaluate
from lm_eval.api.model import LM
from lm_eval.logger import eval_logger
MODEL_REGISTRY = {}
......@@ -130,7 +131,7 @@ searching in HF Evaluate library..."
metric_object = evaluate.load(name)
return metric_object.compute
except Exception:
raise Warning(
eval_logger.error(
"{} not found in the evaluate library! Please check https://huggingface.co/evaluate-metric".format(name)
)
......@@ -153,7 +154,7 @@ def get_aggregation(name):
try:
return AGGREGATION_REGISTRY[name]
except KeyError:
raise Warning(
eval_logger.warning(
"{} not a registered aggregation metric!".format(name),
)
......@@ -162,7 +163,9 @@ def get_default_aggregation(metric_name):
try:
return DEFAULT_AGGREGATION_REGISTRY[metric_name]
except KeyError:
raise Warning(f"No default aggregation metric for metric '{metric_name}'!")
eval_logger.warning(
f"No default aggregation metric for metric '{metric_name}'!"
)
def is_higher_better(metric_name):
......@@ -170,3 +173,6 @@ def is_higher_better(metric_name):
return HIGHER_IS_BETTER_REGISTRY[metric_name]
except KeyError:
raise Warning(f"higher_is_better not specified for metric '{metric_name}'!")
eval_logger.warning(
f"higher_is_better not specified for metric '{metric_name}'!"
)
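
A rough sketch of how get_metric resolves a name after this change; the example metric names are illustrative, and what is actually available depends on what has been registered and on the installed `evaluate` package:

# from lm_eval.api.registry import get_metric
# get_metric("exact_match")     # returned from the metric registry if it is registered there
# get_metric("bleu")            # unregistered names fall back to evaluate.load("bleu").compute
# get_metric("no_such_metric")  # if evaluate.load also fails, the error is logged via
#                               # eval_logger and the call now returns None instead of raising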
......@@ -70,6 +70,7 @@ class TaskConfig(dict):
doc_to_target: Union[Callable, str] = None
doc_to_choice: Union[Callable, str, dict, list] = None
gold_alias: Union[Callable, str] = None
process_results: Union[Callable, str] = None
use_prompt: str = None
description: str = ""
target_delimiter: str = " "
......@@ -545,8 +546,18 @@ class ConfigurableTask(Task):
for key in metric_config
if key not in ["metric", "aggregation", "higher_is_better"]
}
self._metric_fn_list[metric_name] = get_metric(metric_name)
self._metric_fn_kwargs[metric_name] = kwargs
if self._config.process_results is not None:
self._metric_fn_list[metric_name] = None
self._metric_fn_kwargs[metric_name] = {}
elif callable(metric_name):
metric_fn = metric_name.__call__
metric_name = metric_name.__name__
self._metric_fn_list[metric_name] = metric_fn
self._metric_fn_kwargs[metric_name] = kwargs
else:
self._metric_fn_list[metric_name] = get_metric(metric_name)
self._metric_fn_kwargs[metric_name] = kwargs
if "aggregation" in metric_config:
agg_name = metric_config["aggregation"]
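
To make the new branching concrete, a minimal sketch of how a single metric entry gets resolved, written outside of ConfigurableTask; `my_exact_match`, `resolve_metric`, and the stripped-down arguments are hypothetical, not part of the harness:

from lm_eval.api.registry import get_metric

def my_exact_match(predictions, references):  # hypothetical `!function`-style metric
    return float(predictions[0].strip() == references[0].strip())

def resolve_metric(metric_name, kwargs, process_results=None):
    # mirrors the three branches above (sketch only)
    if process_results is not None:
        # the task's own process_results computes metrics, so no metric fn is stored
        return metric_name, None, {}
    if callable(metric_name):
        # e.g. `metric: !function my_module.my_exact_match` in a task YAML:
        # the registry key becomes the function's __name__
        return metric_name.__name__, metric_name.__call__, kwargs
    # plain string: resolve through the metric registry / HF evaluate fallback
    return metric_name, get_metric(metric_name), kwargs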
......@@ -885,8 +896,8 @@ class ConfigurableTask(Task):
def process_results(self, doc, results):
# if callable(self._config.process_results):
# return self._config.process_results(doc, results)
if callable(self._config.process_results):
return self._config.process_results(doc, results)
result_dict = {}
use_metric = list(self._metric_fn_list.keys())
......@@ -975,7 +986,11 @@ class ConfigurableTask(Task):
elif self.OUTPUT_TYPE == "greedy_until":
gold = self.doc_to_target(doc)
if type(gold) == int:
choices = self.doc_to_choice(doc)
gold = choices[gold]
for key, result in zip(self._metric_fn_list.keys(), results):
if self.multiple_target:
# in the case where we have multiple targets,
......
......@@ -192,14 +192,35 @@ def evaluate(
# decontaminate = decontamination_ngrams_path is not None
# stores the final result for each task, for each metric/filter pair.
results = collections.defaultdict(dict)
# Tracks each task's version.
versions = collections.defaultdict(dict)
# Tracks the YAML configs of all chosen tasks.
configs = collections.defaultdict(dict)
# logs info about each document evaluated.
samples = collections.defaultdict(list)
# tracks all Instances/requests a model must generate output on.
requests = collections.defaultdict(list)
# Stores task scores based on task grouping.
aggregate = collections.defaultdict(dict)
# tracks if a task was chosen via user selecting a group containing it
task_groups = collections.defaultdict(dict)
# stores the amount to pad out reqs per req. type so that
# number of fwd passes per distributed rank is equal
padding_requests = collections.defaultdict(int)
# Stores group related keys and values for group-aggregation
aggregate = collections.defaultdict(dict)
task_groups = collections.defaultdict(dict)
# get lists of each type of request
for task_name, task in task_dict.items():
if type(task) == tuple:
group, task = task
task_groups[task_name] = group
versions[task_name] = task.VERSION
configs[task_name] = dict(task.dump_config())
......@@ -243,6 +264,7 @@ def evaluate(
# compute number of pseudobatches to pad with (FSDP/DDP require even batches among ranks)
numpad = max(gathered_item) - gathered_item[lm.rank]
padding_requests[task.OUTPUT_TYPE] += numpad
### Run LM on inputs, get all outputs ###
# execute each type of request
......@@ -253,8 +275,8 @@ def evaluate(
for req in reqs:
cloned_reqs.extend([req] * req.repeats)
if (lm.world_size > 1) and (numpad > 0):
for _ in range(numpad):
if (lm.world_size > 1) and (padding_requests[reqtype] > 0):
for _ in range(padding_requests[reqtype]):
cloned_reqs.extend([req] * req.repeats)
# run requests through model
......@@ -264,12 +286,14 @@ def evaluate(
for x, req in zip(resps, cloned_reqs):
req.resps.append(x)
if lm.world_size > 1:
lm.accelerator.wait_for_everyone()
if lm.world_size > 1:
lm.accelerator.wait_for_everyone()
### Postprocess outputs ###
# TODO: del model here, maybe (idea: allow user to specify device of e.g. reward model separately)
for task_name, task in task_dict.items():
if type(task) == tuple:
group, task = task
task.apply_filters()
### Collect values of metrics on all datapoints ###
......@@ -277,6 +301,8 @@ def evaluate(
# unpack results and sort back in order and return control to Task
for task_name, task in task_dict.items():
if type(task) == tuple:
group, task = task
# TODO: make it possible to use a different metric per filter
# iterate over different filters used
for key in task.instances[0].filtered_resps.keys():
......@@ -362,7 +388,23 @@ def evaluate(
# aggregate results ; run bootstrap CIs
for (task_name, key, metric), items in vals.items():
task = task_dict[task_name]
results[task_name][metric + "," + key] = task.aggregation()[metric](items)
if type(task) == tuple:
group, task = task
task_score = task.aggregation()[metric](items)
results[task_name][metric + "," + key] = task_score
# Need to put back in results
# pythia | acc
# | perplexity
# | word_perplexity
# | byte_perplexity
# | bits_per_byte
if bool(task_groups):
group_name = task_groups[task_name]
if metric not in aggregate[group_name]:
aggregate[group_name][metric] = [task_score]
else:
aggregate[group_name][metric].append(task_score)
# hotfix: bleu, chrf, ter seem to be really expensive to bootstrap
# so we run them less iterations. still looking for a cleaner way to do this
......@@ -377,10 +419,21 @@ def evaluate(
if stderr is not None:
results[task_name][metric + "_stderr" + "," + key] = stderr(items)
if bool(aggregate):
for group in aggregate.keys():
for metric in aggregate[group].keys():
aggregate[group][metric] = np.average(aggregate[group][metric])
versions[group] = "N/A"
results_dict = {
"results": dict(results),
"configs": dict(configs),
"versions": dict(versions),
"results": dict(sorted(results.items())),
**(
{"aggregate": dict(sorted(aggregate.items()))}
if bool(aggregate)
else {}
),
"configs": dict(sorted(configs.items())),
"versions": dict(sorted(versions.items())),
}
if log_samples:
results_dict["samples"] = dict(samples)
......
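
For orientation, a hypothetical shape of the dictionary evaluate() now returns when tasks were selected through a group; all numbers are invented, and the ",none" suffix assumes the default filter name:

# {
#     "results":   {"piqa": {"acc,none": 0.76}, "sciq": {"acc,none": 0.93}},
#     "aggregate": {"pythia": {"acc": 0.845}},  # np.average over the member tasks
#     "configs":   {"piqa": {...}, "sciq": {...}},
#     "versions":  {"piqa": 0, "sciq": 0, "pythia": "N/A"},
# }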
from lm_eval import utils
from lm_eval.logger import eval_logger
# Prompt library.
......@@ -51,3 +52,17 @@ def get_prompt(prompt_id: str, dataset_name=None, subset_name=None):
f"expected only a single `:` as separator between \
prompt category and name, but got `{prompt_id}` instead"
)
def load_prompt_list(use_prompt: str, dataset_name=None, subset_name=None, **kwargs):
    from promptsource.templates import DatasetTemplates

    if subset_name is None:
        prompts = DatasetTemplates(dataset_name=dataset_name)
    else:
        prompts = DatasetTemplates(dataset_name=dataset_name, subset_name=subset_name)

    category_name, prompt_name = use_prompt.split(":")
    prompt_list = utils.pattern_match(prompt_name, prompts.all_template_names)
    return [":".join([category_name, prompt]) for prompt in prompt_list]
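
A hypothetical call, assuming promptsource ships templates for the chosen dataset; the returned template names are only examples:

# prompt_ids = load_prompt_list(
#     use_prompt="promptsource:*",
#     dataset_name="super_glue",
#     subset_name="cb",
# )
# -> ["promptsource:GPT-3 style", "promptsource:MNLI crowdsource", ...]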
import os
import yaml
from typing import List, Union
from lm_eval import utils
from lm_eval import prompts
from lm_eval.logger import eval_logger
from lm_eval.api.task import TaskConfig, Task, ConfigurableTask
from lm_eval.api.registry import (
......@@ -13,6 +15,58 @@ from lm_eval.api.registry import (
)
def register_configurable_task(config):
    SubClass = type(
        config["task"] + "ConfigurableTask",
        (ConfigurableTask,),
        {"CONFIG": TaskConfig(**config)},
    )

    if "task" in config:
        task_name = "{}".format(config["task"])
        register_task(task_name)(SubClass)

    if "group" in config:
        if type(config["group"]) == str:
            group_name = [config["group"]]
        else:
            group_name = config["group"]
        for group in group_name:
            register_group(group)(SubClass)

    return 0
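
As a usage sketch, registering a minimal hand-written config; the dict below is hypothetical and far smaller than a real task YAML:

# register_configurable_task(
#     {
#         "task": "my_toy_task",       # becomes the TASK_REGISTRY key
#         "group": "my_toy_group",     # str or list; each entry is registered as a group
#         "dataset_path": "super_glue",
#         "dataset_name": "cb",
#         "output_type": "greedy_until",
#     }
# )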
def check_prompt_config(config):
    all_configs = []
    if "use_prompt" in config:
        prompt_list = prompts.load_prompt_list(
            use_prompt=config["use_prompt"],
            dataset_name=config["dataset_path"],
            subset_name=config["dataset_name"],
        )
        for idx, prompt_variation in enumerate(prompt_list):
            all_configs.append(
                {
                    **config,
                    **{"use_prompt": prompt_variation},
                    **{
                        "task": "_".join(
                            [
                                get_task_name_from_config(config),
                                prompt_variation,
                            ]
                        )
                    },
                    **{"output_type": "greedy_until"},
                }
            )
    else:
        all_configs.append(config)

    return all_configs
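
For example (values hypothetical), a config carrying a promptsource wildcard expands into one config per matching template, each renamed and forced to greedy_until:

# variants = check_prompt_config(
#     {
#         "task": "cb",
#         "dataset_path": "super_glue",
#         "dataset_name": "cb",
#         "use_prompt": "promptsource:*",
#     }
# )
# -> one dict per matched template, e.g. a task named
#    "super_glue_cb_promptsource:GPT-3 style" with its own use_prompt
#    and output_type set to "greedy_until"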
def get_task_name_from_config(task_config):
if "dataset_name" in task_config:
return "{dataset_path}_{dataset_name}".format(**task_config)
......@@ -31,23 +85,10 @@ def include_task_folder(task_dir):
yaml_path = os.path.join(root, f)
try:
config = utils.load_yaml_config(yaml_path)
all_configs = check_prompt_config(config)
for config in all_configs:
register_configurable_task(config)
SubClass = type(
config["task"] + "ConfigurableTask",
(ConfigurableTask,),
{"CONFIG": TaskConfig(**config)},
)
if "task" in config:
# task_name = "{}:{}".format(
# get_task_name_from_config(config), config["task"]
# )
task_name = "{}".format(config["task"])
register_task(task_name)(SubClass)
if "group" in config:
for group in config["group"]:
register_group(group)(SubClass)
except Exception as error:
eval_logger.warning(
"Failed to load config in\n"
......@@ -57,8 +98,58 @@ def include_task_folder(task_dir):
)
def include_benchmarks(task_dir, benchmark_dir="benchmarks"):
    for root, subdirs, file_list in os.walk(os.path.join(task_dir, benchmark_dir)):
        if (subdirs == [] or subdirs == ["__pycache__"]) and (len(file_list) > 0):
            for f in file_list:
                if f.endswith(".yaml"):
                    try:
                        benchmark_path = os.path.join(root, f)
                        with open(benchmark_path, "rb") as file:
                            yaml_config = yaml.full_load(file)
                        assert "group" in yaml_config
                        group = yaml_config["group"]
                        all_task_list = yaml_config["task"]
                        config_list = [
                            task for task in all_task_list if type(task) != str
                        ]
                        task_list = [
                            task for task in all_task_list if type(task) == str
                        ]

                        for task_config in config_list:
                            var_configs = check_prompt_config(
                                {
                                    **task_config,
                                    **{"group": group},
                                }
                            )
                            for config in var_configs:
                                register_configurable_task(config)

                        task_names = utils.pattern_match(task_list, ALL_TASKS)
                        for task in task_names:
                            if task in TASK_REGISTRY:
                                if group in GROUP_REGISTRY:
                                    GROUP_REGISTRY[group].append(task)
                                else:
                                    GROUP_REGISTRY[group] = [task]
                                ALL_TASKS.add(group)
                    except Exception as error:
                        eval_logger.warning(
                            "Failed to load benchmark in\n"
                            f"                                 {benchmark_path}\n"
                            "                                 Benchmark will not be added to registry\n"
                            f"                                 Error: {error}"
                        )
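
After this loader has processed a benchmark file such as benchmarks/pythia.yaml (shown further below), the registries end up roughly as sketched here; the exact member list depends on which task YAMLs the `arc_*` pattern matched:

# GROUP_REGISTRY["pythia"]  # -> ["lambada_openai", "wikitext", "piqa", "sciq",
#                           #     "wsc", "winogrande", "arc_easy", "arc_challenge", ...]
# "pythia" in ALL_TASKS     # -> True, so the group name can be requested like a task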
task_dir = os.path.dirname(os.path.abspath(__file__)) + "/"
include_task_folder(task_dir)
include_benchmarks(task_dir)
def get_task(task_name, config):
......@@ -97,11 +188,15 @@ def get_task_dict(task_name_list: List[Union[str, dict, Task]], **kwargs):
if isinstance(task_element, str):
if task_element in GROUP_REGISTRY:
group_name = task_element
for task_name in GROUP_REGISTRY[task_element]:
if task_name not in task_name_from_registry_dict:
task_name_from_registry_dict = {
**task_name_from_registry_dict,
task_name: get_task(task_name=task_name, config=config),
task_name: (
group_name,
get_task(task_name=task_name, config=config),
),
}
else:
task_name = task_element
......
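
The practical effect of the change above, as a sketch: tasks pulled in through a group now come back as (group_name, task) tuples, which is what the `if type(task) == tuple:` unpacking in evaluator.py relies on; tasks requested directly by name are still returned bare:

# task_dict = get_task_dict(["pythia"])
# task_dict["piqa"]   # -> ("pythia", <piqa ConfigurableTask>)
# task_dict = get_task_dict(["piqa"])
# task_dict["piqa"]   # -> <piqa ConfigurableTask>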
group: pythia
task:
- lambada_openai
- wikitext
- piqa
- sciq
- wsc
- winogrande
- arc_*
# - logiqa
# - blimp_*
# - hendrycksTest*
group: t0_eval
task:
# # Coreference Resolution
# - dataset_path: super_glue
# dataset_name: wsc.fixed
# use_prompt: promptsource:*
# training_split: train
# validation_split: validation
# metric_list:
# - metric: exact_match
# aggregation: mean
# higher_is_better: true
# ignore_case: true
# ignore_punctuation: true
# # Coreference Resolution
# - dataset_path: winogrande
# dataset_name: winogrande_xl
# use_prompt: promptsource:*
# training_split: train
# validation_split: validation
# metric_list:
# - metric: exact_match
# aggregation: mean
# higher_is_better: true
# ignore_case: true
# ignore_punctuation: true
# Natural Language Inference
- dataset_path: super_glue
dataset_name: cb
use_prompt: promptsource:*
training_split: train
validation_split: validation
output_type: greedy_until
metric_list:
- metric: exact_match
aggregation: mean
higher_is_better: true
ignore_case: true
ignore_punctuation: true
# Natural Language Inference
# - dataset_path: super_glue
# dataset_name: rte
# use_prompt: promptsource:*
# training_split: train
# validation_split: validation
# metric_list:
# - metric: exact_match
# aggregation: mean
# higher_is_better: true
# ignore_case: true
# ignore_punctuation: true
# # Natural Language Inference
# # - dataset_path: anli
# # use_prompt: promptsource:*
# # training_split: train_r1
# # validation_split: dev_r1
# # Sentence Completion
# - dataset_path: super_glue
# dataset_name: copa
# use_prompt: promptsource:*
# training_split: train
# validation_split: validation
# metric_list:
# - metric: exact_match
# aggregation: mean
# higher_is_better: true
# ignore_case: true
# ignore_punctuation: true
# # Natural Language Inference
# - dataset_path: hellaswag
# use_prompt: promptsource:*
# training_split: train
# validation_split: validation
# metric_list:
# - metric: exact_match
# aggregation: mean
# higher_is_better: true
# ignore_case: true
# ignore_punctuation: true
# # Word Sense Disambiguation
# - dataset_path: super_glue
# dataset_name: wic
# use_prompt: promptsource:*
# training_split: train
# validation_split: validation
# metric_list:
# - metric: exact_match
# aggregation: mean
# higher_is_better: true
# ignore_case: true
# ignore_punctuation: true
group:
- super-glue-lm-eval-v1
task: "boolq"
task: boolq
dataset_path: super_glue
dataset_name: boolq
output_type: multiple_choice
......
......@@ -5,11 +5,15 @@ dataset_path: super_glue
dataset_name: cb
training_split: train
validation_split: validation
output_type: greedy_until
doc_to_text: "cb hypothesis: {{hypothesis}} premise {{premise}}"
doc_to_target: "{% set answer_choices = ['entailment', 'contradiction', 'neutral'] %}{{answer_choices[label]}}"
doc_to_target: label
doc_to_choice: ['entailment', 'contradiction', 'neutral']
metric_list:
- metric: exact_match
aggregation: mean
higher_is_better: true
ignore_case: true
ignore_punctuation: true
- metric: f1
aggregation: !function "aggregate.cb_multi_fi"
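
For concreteness, a sketch of how one CB example renders under this greedy_until config; the field values are invented:

# doc = {"premise": "It was raining.", "hypothesis": "The ground is wet.", "label": 0}
# doc_to_text   -> "cb hypothesis: The ground is wet. premise It was raining."
# doc_to_target -> label 0, mapped through doc_to_choice to "entailment"
# exact_match then scores the model's greedy generation against "entailment"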
......@@ -5,8 +5,10 @@ dataset_path: super_glue
dataset_name: copa
training_split: train
validation_split: validation
output_type: greedy_until
doc_to_text: "copa choice1: {{choice1}} choice2: {{choice2}} question: {{question}}"
doc_to_target: "{% set answer_choices = ['False', 'True'] %}{{answer_choices[label]}}"
doc_to_target: label
doc_to_choice: ['False', 'True']
metric_list:
- metric: exact_match
aggregation: mean
......
group:
- super-glue-promptsource
task: "I was going to say…"
dataset_path: super_glue
dataset_name: multirc
training_split: train
validation_split: validation
use_prompt: "promptsource:I was going to say…"
metric_list:
- metric: exact_match
aggregation: mean
higher_is_better: true
ignore_case: true
ignore_punctuation: true
include: promptsource-00.yaml
group:
- super-glue-promptsource
task: "Would it be good to answer…"
use_prompt: "promptsource:Would it be good to answer…"
include: promptsource-00.yaml
group:
- super-glue-promptsource
task: "confirm"
use_prompt: "promptsource:confirm"
# group:
# - super-glue-lm-eval-v1
group:
- super-glue-lm-eval-v1
task: record
dataset_path: super_glue
dataset_name: record
......@@ -9,6 +9,10 @@ validation_split: validation
doc_to_text: !function util.doc_to_text
doc_to_target: "{{answers}}"
doc_to_choice: "{{entities}}"
process_results: !function util.process_results
metric_list:
- metric: f1
aggregation: mean
- metric: em
higher_is_better: True
aggregation: mean
group:
- super-glue-promptsource
task: "Add sentence after (continuation choices)"
dataset_path: super_glue
dataset_name: record
training_split: train
validation_split: validation
use_prompt: "promptsource:Add sentence after (continuation choices)"
metric_list:
- metric: exact_match
aggregation: mean
higher_is_better: true
ignore_case: true
ignore_punctuation: true
......@@ -5,6 +5,7 @@ dataset_path: super_glue
dataset_name: record
training_split: train
validation_split: validation
output_type: greedy_until
doc_to_text: "record query: {{query}} entities: {{entities}} passage: {{passage}}"
doc_to_target: "{{answers}}"
metric_list:
......
import numpy as np
import transformers.data.metrics.squad_metrics as squad_metrics
from lm_eval.api.metrics import metric_max_over_ground_truths
def doc_to_text(doc):
    initial_text, *highlights = doc["passage"].strip().split("\n@highlight\n")
    text = initial_text + "\n\n"
......@@ -13,3 +19,25 @@ def format_answer(query, entity):
def doc_to_target(doc):
    # We only output the first correct entity in a doc
    return format_answer(query=doc["query"], entity=doc["answers"][0])
def process_results(doc, results):
    # ReCoRD's evaluation is actually deceptively simple:
    # - Pick the maximum likelihood prediction entity
    # - Evaluate the accuracy and token F1 PER EXAMPLE
    # - Average over all examples
    max_idx = np.argmax(np.array([result[0] for result in results]))

    prediction = doc["entities"][max_idx]
    gold_label_set = doc["answers"]
    f1 = metric_max_over_ground_truths(
        squad_metrics.compute_f1, prediction, gold_label_set
    )
    em = metric_max_over_ground_truths(
        squad_metrics.compute_exact, prediction, gold_label_set
    )

    return {
        "f1": f1,
        "em": em,
    }
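
A hypothetical invocation, assuming the doc fields follow the ReCoRD schema and results are per-entity (loglikelihood, is_greedy) pairs:

# doc = {"query": "...", "entities": ["Bush", "Obama"], "answers": ["Obama"]}
# results = [(-4.2, False), (-1.3, False)]
# process_results(doc, results)   # argmax over loglikelihoods picks "Obama"
# -> {"f1": 1.0, "em": 1}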
group:
- super-glue-promptsource
task: "rte"
- super-glue-lm-eval-v1
task: rte
dataset_path: super_glue
dataset_name: rte
output_type: multiple_choice
training_split: train
validation_split: validation
use_prompt: "promptsource:GPT-3 style"
generation_kwargs:
until:
- "\n"
- "\n\n"
doc_to_text: "{{premise}}\nQuestion: {{hypothesis}} True or False?\nAnswer:"
doc_to_target: label
doc_to_choice: ['True', 'False']
metric_list:
- metric: exact_match
aggregation: mean
higher_is_better: true
ignore_case: true
ignore_punctuation: true
- metric: acc
include: promptsource-00.yaml
group:
- super-glue-promptsource
task: "MNLI crowdsource"
use_prompt: "promptsource:MNLI crowdsource"