Commit 1ed76cfa authored by lintangsutawika

moved benchmarks back to tasks/

parent 10d8ed64
import os
import yaml

from lm_eval import utils
from lm_eval.tasks import register_configurable_task, check_prompt_config
from lm_eval.logger import eval_logger
from lm_eval.api.registry import (
    TASK_REGISTRY,
    GROUP_REGISTRY,
    ALL_TASKS,
)


def include_benchmarks(task_dir: str) -> None:
    # Walk the benchmark directory and register every benchmark YAML found
    # in its leaf folders.
    for root, subdirs, file_list in os.walk(task_dir):
        if (subdirs == [] or subdirs == ["__pycache__"]) and (len(file_list) > 0):
            for f in file_list:
                if f.endswith(".yaml"):
                    try:
                        benchmark_path = os.path.join(root, f)
                        with open(benchmark_path, "rb") as file:
                            yaml_config = yaml.full_load(file)

                        assert "group" in yaml_config
                        group = yaml_config["group"]
                        all_task_list = yaml_config["task"]
                        # Split inline task configs (dicts) from task names (strings).
                        config_list = [
                            task for task in all_task_list if type(task) != str
                        ]
                        task_list = [
                            task for task in all_task_list if type(task) == str
                        ]

                        # Register inline task configs under this group.
                        for task_config in config_list:
                            var_configs = check_prompt_config(
                                {
                                    **task_config,
                                    **{"group": group},
                                }
                            )
                            for config in var_configs:
                                register_configurable_task(config)

                        # Attach already-registered tasks (matched by pattern) to the group.
                        task_names = utils.pattern_match(task_list, ALL_TASKS)
                        for task in task_names:
                            if (task in TASK_REGISTRY) or (task in GROUP_REGISTRY):
                                if group in GROUP_REGISTRY:
                                    GROUP_REGISTRY[group].append(task)
                                else:
                                    GROUP_REGISTRY[group] = [task]
                                    ALL_TASKS.add(group)
                    except Exception as error:
                        eval_logger.warning(
                            "Failed to load benchmark in\n"
                            f"  {benchmark_path}\n"
                            "  Benchmark will not be added to registry\n"
                            f"  Error: {error}"
                        )


task_dir = os.path.dirname(os.path.abspath(__file__)) + "/"
include_benchmarks(task_dir)
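For context, a minimal sketch (not part of the commit) of the benchmark YAML shape that include_benchmarks expects: a group name plus a task list that can mix task-name strings with inline task-config dicts. The group, task, and dataset names below are hypothetical.

import yaml

example_benchmark_yaml = """
group: my_benchmark           # hypothetical group name
task:
  - arc_easy                  # an already-registered task, referenced by name
  - task: my_custom_subtask   # an inline task config, registered under the group
    dataset_path: some/dataset
"""

yaml_config = yaml.full_load(example_benchmark_yaml)
group = yaml_config["group"]
inline_configs = [t for t in yaml_config["task"] if not isinstance(t, str)]
named_tasks = [t for t in yaml_config["task"] if isinstance(t, str)]
print(group, named_tasks, [c["task"] for c in inline_configs])
# -> my_benchmark ['arc_easy'] ['my_custom_subtask']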
@@ -11,7 +11,6 @@ import numpy as np
import lm_eval.api
import lm_eval.tasks
import lm_eval.benchmarks
import lm_eval.models
import lm_eval.api.metrics
import lm_eval.api.registry
@@ -37,6 +37,37 @@ def register_configurable_task(config: Dict[str, str]) -> int:
    return 0


def register_configurable_group(config: Dict[str, str]) -> int:
    group = config["group"]
    all_task_list = config["task"]
    config_list = [
        task for task in all_task_list if type(task) != str
    ]
    task_list = [
        task for task in all_task_list if type(task) == str
    ]

    for task_config in config_list:
        var_configs = check_prompt_config(
            {
                **task_config,
                **{"group": group},
            }
        )
        for config in var_configs:
            register_configurable_task(config)

    task_names = utils.pattern_match(task_list, ALL_TASKS)
    for task in task_names:
        if (task in TASK_REGISTRY) or (task in GROUP_REGISTRY):
            if group in GROUP_REGISTRY:
                GROUP_REGISTRY[group].append(task)
            else:
                GROUP_REGISTRY[group] = [task]
                ALL_TASKS.add(group)

    return 0


def check_prompt_config(config: Dict[str, str]) -> List[Dict[str, str]]:
    all_configs = []
@@ -87,9 +118,15 @@ def include_task_folder(task_dir: str) -> None:
                    yaml_path = os.path.join(root, f)
                    try:
                        config = utils.load_yaml_config(yaml_path)

                        # If the `task` entry in the config is a list,
                        # the config describes a benchmark (a group of tasks).
                        if type(config["task"]) == list:
                            register_configurable_group(config)
                        else:
                            all_configs = check_prompt_config(config)
                            for config in all_configs:
                                register_configurable_task(config)
                    except Exception as error:
                        eval_logger.warning(
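A small sketch, also not from the commit, of the dispatch rule added above: a config whose task value is a list is routed to register_configurable_group, anything else to register_configurable_task. The task and group names are hypothetical.

def config_kind(config: dict) -> str:
    # List-valued "task" means a benchmark group; otherwise a single task.
    if isinstance(config.get("task"), list):
        return "group"  # would be handled by register_configurable_group
    return "task"       # would be handled by register_configurable_task


assert config_kind({"group": "my_benchmark", "task": ["arc_easy", "hellaswag"]}) == "group"
assert config_kind({"task": "squadv2"}) == "task"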
@@ -4,44 +4,22 @@ output_type: greedy_until
training_split: train
validation_split: validation
doc_to_text: "Title: {{title}}\n\nBackground: {{context}}\n\nQuestion: {{question}}\n\n Answer:"
doc_to_target: "{% if answers.text| length > 0 %}{{answers.text}}{% else %}{{['unanswerable']}}{% endif %}"
doc_to_target: "{% if answers.text| length > 0 %}{{answers.text}}{% else %}{{['']}}{% endif %}"
target_delimiter: ""
should_decontaminate: true
doc_to_decontamination_query: context
process_results: !function utils.process_results
generation_kwargs:
  until:
    - "\n\n"
    - "\n"
  do_sample: false
  temperature: 0.0
filter_list:
  - name: remove_whitespace
    filter:
      - function: remove_whitespace
      - function: take_first
# filter_list:
#   - name: remove_whitespace
#     filter:
#       - function: remove_whitespace
#       - function: take_first
metric_list:
  - metric: exact
    aggregation: !function utils.exact
  - metric: !function utils.exact
    aggregation: mean
    higher_is_better: true
  - metric: !function utils.f1
    aggregation: mean
    higher_is_better: true
#   - metric: f1
#     aggregation: mean
#     higher_is_better: true
#   - metric: HasAns_exact
#     aggregation: mean
#     higher_is_better: true
#   - metric: HasAns_f1
#     aggregation: mean
#     higher_is_better: true
#   - metric: NoAns_exact
#     aggregation: mean
#     higher_is_better: true
#   - metric: NoAns_f1
#     aggregation: mean
#     higher_is_better: true
#   - metric: best_exact
#     aggregation: mean
#     higher_is_better: true
#   - metric: best_f1
#     aggregation: mean
#     higher_is_better: true
import evaluate

from math import exp
from functools import partial


def process_results(doc, results):
    continuation = results[0]
    no_answer_probability = 0  # exp(logprob_unanswerable)

    predictions = {
        "id": doc["id"],
        "prediction_text": continuation,
        "no_answer_probability": no_answer_probability,
    }

    references = {
        "id": doc["id"],
        "answers": doc["answers"],
    }

    return {
        "predictions": predictions,
        "reference": references,
    }
    # return _squad_metric([predictions], [references])
    # return {key: value if key in metrics for key, value in score.items()}


def _squad_metric(predictions, references):
    squad_metric = evaluate.load("squad_v2")
    return squad_metric.compute(predictions=predictions, references=references)
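A usage sketch (the id and strings are hypothetical) showing how dicts shaped like the ones process_results builds line up with what the Hugging Face evaluate "squad_v2" metric expects; running it requires the evaluate package and the metric to be downloadable or cached.

import evaluate

squad_v2 = evaluate.load("squad_v2")
preds = [{"id": "ex-1", "prediction_text": "Normans",
          "no_answer_probability": 0.0}]
refs = [{"id": "ex-1",
         "answers": {"text": ["Normans"], "answer_start": [4]}}]
print(squad_v2.compute(predictions=preds, references=refs)["exact"])  # 100.0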
import re
import string
import collections


def normalize_answer(s):
    """Lower text and remove punctuation, articles and extra whitespace."""

    def remove_articles(text):
        regex = re.compile(r'\b(a|an|the)\b', re.UNICODE)
        return re.sub(regex, ' ', text)

    def white_space_fix(text):
        return ' '.join(text.split())

    def remove_punc(text):
        exclude = set(string.punctuation)
        return ''.join(ch for ch in text if ch not in exclude)

    def lower(text):
        return text.lower()

    return white_space_fix(remove_articles(remove_punc(lower(s))))


def get_tokens(s):
    if not s: return []
    return normalize_answer(s).split()


# Exact match (the normalized answer exactly match the gold answer)
def exact(items):
    print(items)
    import sys; sys.exit()
    predictions, references = zip(*items)
    return _squad_metric(predictions=predictions, references=references)["exact"]


def exact(predictions, references):
    return int(normalize_answer(references[0]) == normalize_answer(predictions[0]))


# The F-score of predicted tokens versus the gold answer
def f1(predictions, references):
    return _squad_metric(predictions=predictions, references=references)["f1"]


# Exact match (the normalized answer exactly match the gold answer)
def HasAns_exact(predictions, references):
    return _squad_metric(predictions=predictions, references=references)["HasAns_exact"]


# The F-score of predicted tokens versus the gold answer
def HasAns_f1(predictions, references):
    return _squad_metric(predictions=predictions, references=references)["HasAns_f1"]


# Exact match (the normalized answer exactly match the gold answer)
def NoAns_exact(predictions, references):
    return _squad_metric(predictions=predictions, references=references)["NoAns_exact"]


# The F-score of predicted tokens versus the gold answer
def NoAns_f1(predictions, references):
    return _squad_metric(predictions=predictions, references=references)["NoAns_f1"]


# Best exact match (with varying threshold)
def best_exact(predictions, references):
    return _squad_metric(predictions=predictions, references=references)["best_exact"]


# Best F1 (with varying threshold)
def best_f1(predictions, references):
    return _squad_metric(predictions=predictions, references=references)["best_f1"]
# The F-score of predicted tokens versus the gold answer
def f1(predictions, references):
    gold_toks = get_tokens(references[0])
    pred_toks = get_tokens(predictions[0])
    common = collections.Counter(gold_toks) & collections.Counter(pred_toks)
    num_same = sum(common.values())
    if len(gold_toks) == 0 or len(pred_toks) == 0:
        # If either is no-answer, then F1 is 1 if they agree, 0 otherwise
        return int(gold_toks == pred_toks)
    if num_same == 0:
        return 0
    precision = 1.0 * num_same / len(pred_toks)
    recall = 1.0 * num_same / len(gold_toks)
    f1 = (2 * precision * recall) / (precision + recall)
    return f1
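A usage sketch (hypothetical strings) for the token-level scorers above: normalization strips case, punctuation, and articles before exact match and F1 are computed on the first prediction-reference pair.

print(exact(["The Eiffel Tower."], ["eiffel tower"]))   # 1
print(f1(["Eiffel Tower, Paris"], ["eiffel tower"]))    # 0.8 (precision 2/3, recall 1)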
@@ -11,7 +11,7 @@ from lm_eval import evaluator, utils
from lm_eval.api.registry import ALL_TASKS
from lm_eval.logger import eval_logger, SPACING
from lm_eval.tasks import include_task_folder
from lm_eval.benchmarks import include_benchmarks
# from lm_eval.benchmarks import include_benchmarks
os.environ["TOKENIZERS_PARALLELISM"] = "false"