Commit 3263c572 authored by lintangsutawika's avatar lintangsutawika
Browse files

Merge branch 'big-refactor' of https://github.com/EleutherAI/lm-evaluation-harness into squadv2

parents a27e8ed1 33d52483
# Generated by utils.py
dataset_name: ja
doc_to_target: '{% if answer is not none %}{{answer[20+1]}}{% else %}{{answer_number|string}}{% endif %}'
doc_to_text: '{% if answer is not none %}{{question+"\nStep-by-Step Answer:"}}{% else %}{{"問題: "+question+"\nStep-by-Step Answer:"}}{% endif %}'
include: cot_yaml
task: mgsm_ja_direct
# Generated by utils.py
dataset_name: ru
doc_to_target: '{% if answer is not none %}{{answer[20+1]}}{% else %}{{answer_number|string}}{% endif %}'
doc_to_text: '{% if answer is not none %}{{question+"\nStep-by-Step Answer:"}}{% else %}{{"Задача: "+question+"\nStep-by-Step Answer:"}}{% endif %}'
include: cot_yaml
task: mgsm_ru_direct
# Generated by utils.py
dataset_name: sw
doc_to_target: '{% if answer is not none %}{{answer[20+1]}}{% else %}{{answer_number|string}}{% endif %}'
doc_to_text: '{% if answer is not none %}{{question+"\nStep-by-Step Answer:"}}{% else %}{{"Swali: "+question+"\nStep-by-Step Answer:"}}{% endif %}'
include: cot_yaml
task: mgsm_sw_direct
# Generated by utils.py
dataset_name: te
doc_to_target: '{% if answer is not none %}{{answer[20+1]}}{% else %}{{answer_number|string}}{% endif %}'
doc_to_text: '{% if answer is not none %}{{question+"\nStep-by-Step Answer:"}}{% else %}{{"ప్రశ్న: "+question+"\nStep-by-Step Answer:"}}{% endif %}'
include: cot_yaml
task: mgsm_te_direct
# Generated by utils.py
dataset_name: th
doc_to_target: '{% if answer is not none %}{{answer[20+1]}}{% else %}{{answer_number|string}}{% endif %}'
doc_to_text: '{% if answer is not none %}{{question+"\nStep-by-Step Answer:"}}{% else %}{{"โจทย์: "+question+"\nStep-by-Step Answer:"}}{% endif %}'
include: cot_yaml
task: mgsm_th_direct
# Generated by utils.py
dataset_name: zh
doc_to_target: '{% if answer is not none %}{{answer[20+1]}}{% else %}{{answer_number|string}}{% endif %}'
doc_to_text: '{% if answer is not none %}{{question+"\nStep-by-Step Answer:"}}{% else %}{{"问题: "+question+"\nStep-by-Step Answer:"}}{% endif %}'
include: cot_yaml
task: mgsm_zh_direct
import yaml
import argparse
LANGUAGES = {
"bn": { # Bengali
# "QUESTION": "প্রশ্ন:",
"QUESTION": "\u09aa\u09cd\u09b0\u09b6\u09cd\u09a8:",
# "ANSWER": "ধাপে ধাপে উত্তর:",
"ANSWER": "\u09a7\u09be\u09aa\u09c7 \u09a7\u09be\u09aa\u09c7 \u0989\u09a4\u09cd\u09a4\u09b0:",
"DIRECT": "Answer:",
"REGEX": "The answer is (\\-?[0-9\\.\\,]+)",
},
"de": { # German
"QUESTION": "Frage:",
# "ANSWER": "Schritt-für-Schritt-Antwort:",
"ANSWER": "Schritt-f\u00fcr-Schritt-Antwort:",
"DIRECT": "Antwort:",
"REGEX": "Die Antwort lautet (\\-?[0-9\\.\\,]+)",
},
"en": { # English
"QUESTION": "Question:",
"ANSWER": "Step-by-Step Answer:",
"DIRECT": "Answer:",
"REGEX": "The answer is (\\-?[0-9\\.\\,]+)",
},
"es": { # Spanish
"QUESTION": "Pregunta:",
"ANSWER": "Respuesta paso a paso:",
"DIRECT": "Respuesta:",
"REGEX": "La respuesta es (\\-?[0-9\\.\\,]+)",
},
"fr": { # French
"QUESTION": "Question :",
# "ANSWER": "Réponse étape par étape :"
"ANSWER": "R\u00e9ponse \u00e9tape par \u00e9tape :",
# "DIRECT": "Réponse :",
"DIRECT": "R\u00e9ponse :",
# "REGEX": "La réponse est (\\-?[0-9\\.\\,]+)",
"REGEX": "La r\u00e9ponse est (\\-?[0-9\\.\\,]+)",
},
"ru": { # Russian
# "QUESTION": "Задача:",
"QUESTION": "\u0417\u0430\u0434\u0430\u0447\u0430:",
# "ANSWER": "Пошаговоерешение:",
"ANSWER": "\u041f\u043e\u0448\u0430\u0433\u043e\u0432\u043e\u0435\u0440\u0435\u0448\u0435\u043d\u0438\u0435:",
"DIRECT": "Answer:",
# "REGEX": "Ответ — (\\-?[0-9\\.\\,]+)",
"REGEX": "\u041e\u0442\u0432\u0435\u0442 \u2014 (\\-?[0-9\\.\\,]+)",
},
"sw": { # Swahili
"QUESTION": "Swali:",
"ANSWER": "Jibu la Hatua kwa Hatua:",
"DIRECT": "Answer:",
"REGEX": "Jibu ni (\\-?[0-9\\.\\,]+)",
},
"te": { # Telugu
# "QUESTION": "ప్రశ్న:",
"QUESTION": "\u0c2a\u0c4d\u0c30\u0c36\u0c4d\u0c28:",
# "ANSWER": "దశలవారీగా సమాధానం:",
"ANSWER": "\u0c26\u0c36\u0c32\u0c35\u0c3e\u0c30\u0c40\u0c17\u0c3e \u0c38\u0c2e\u0c3e\u0c27\u0c3e\u0c28\u0c02:",
"DIRECT": "Answer:",
# "REGEX": "సమాధానం (\\-?[0-9\\.\\,]+)",
"REGEX": "\u0c38\u0c2e\u0c3e\u0c27\u0c3e\u0c28\u0c02 (\\-?[0-9\\.\\,]+)",
},
"th": { # Thai
# "QUESTION": "โจทย์:",
"QUESTION": "\u0e42\u0e08\u0e17\u0e22\u0e4c:",
# "ANSWER": "คำตอบทีละขั้นตอน:",
"ANSWER": "\u0e04\u0e33\u0e15\u0e2d\u0e1a\u0e17\u0e35\u0e25\u0e30\u0e02\u0e31\u0e49\u0e19\u0e15\u0e2d\u0e19:",
"DIRECT": "Answer:",
# "REGEX": "คำตอบคือ (\\-?[0-9\\.\\,]+)",
"REGEX": "\u0e04\u0e33\u0e15\u0e2d\u0e1a\u0e04\u0e37\u0e2d (\\-?[0-9\\.\\,]+)",
},
"ja": { # Japanese
# "QUESTION": "問題:",
"QUESTION": "\u554f\u984c:",
# "ANSWER": "ステップごとの答え:",
"ANSWER": "\u30b9\u30c6\u30c3\u30d7\u3054\u3068\u306e\u7b54\u3048:",
"DIRECT": "Answer:",
# "REGEX": "答えは(\\-?[0-9\\.\\,]+)です。",
"REGEX": "\u7b54\u3048\u306f(\\-?[0-9\\.\\,]+)\u3067\u3059\u3002",
},
"zh": { # Chinese
# "QUESTION": "问题:",
"QUESTION": "\u95ee\u9898:",
# "ANSWER": "逐步解答:",
"ANSWER": "\u9010\u6b65\u89e3\u7b54:",
"DIRECT": "Answer:",
# "REGEX": "答案是 (\\-?[0-9\\.\\,]+)。",
"REGEX": "\u7b54\u6848\u662f (\\-?[0-9\\.\\,]+)\u3002",
},
}
def add_regex_pattern(regex_pattern):
if regex_pattern is None:
return {}
return {
"filter_list": [
{
"name": "get-answer",
"filter": [
{
"function": "regex",
"regex_pattern": regex_pattern,
},
{
"function": "take_first",
},
],
},
],
}
def gen_lang_yamls(output_dir: str, overwrite: bool, mode: str) -> None:
"""
Generate a yaml file for each language.
:param output_dir: The directory to output the files to.
:param overwrite: Whether to overwrite files if they already exist.
"""
err = []
for lang in LANGUAGES.keys():
try:
QUESTION = LANGUAGES[lang]["QUESTION"]
yaml_template = "cot_yaml"
filter_list = {}
if mode == "direct":
ANSWER = LANGUAGES[lang]["DIRECT"]
REGEX = None
task_name = f"mgsm_{lang}_direct"
yaml_template = "direct_yaml"
elif mode == "native-cot":
ANSWER = LANGUAGES[lang]["ANSWER"]
REGEX = LANGUAGES[lang]["REGEX"]
task_name = f"mgsm_{lang}_native-cot"
filter_list = add_regex_pattern(REGEX)
elif mode == "en-cot":
ANSWER = LANGUAGES["en"]["ANSWER"]
REGEX = LANGUAGES["en"]["REGEX"]
task_name = f"mgsm_{lang}_en-cot"
file_name = f"{task_name}.yaml"
with open(
f"{output_dir}/{file_name}", "w" if overwrite else "x", encoding="utf8"
) as f:
f.write("# Generated by utils.py\n")
yaml.dump(
{
"include": yaml_template,
"dataset_name": lang,
"task": f"mgsm_{lang}_direct",
"doc_to_text": f"""{{% if answer is not none %}}"""
f"""{{{{question+"\\n{ANSWER}"}}}}"""
f"""{{% else %}}"""
f"""{{{{"{QUESTION} "+question+"\\n{ANSWER}"}}}}"""
f"""{{% endif %}}""",
"doc_to_target": f"""{{% if answer is not none %}}"""
f"""{{{{answer[{len(ANSWER)}+1]}}}}"""
f"""{{% else %}}"""
f"""{{{{answer_number|string}}}}"""
f"""{{% endif %}}""",
**filter_list,
},
f,
allow_unicode=True,
width=float("inf"),
)
except FileExistsError:
err.append(file_name)
if len(err) > 0:
raise FileExistsError(
"Files were not created because they already exist (use --overwrite flag):"
f" {', '.join(err)}"
)
def main() -> None:
"""Parse CLI args and generate language-specific yaml files."""
parser = argparse.ArgumentParser()
parser.add_argument(
"--overwrite",
default=False,
action="store_true",
help="Overwrite files if they already exist",
)
parser.add_argument(
"--output-dir", default=".", help="Directory to write yaml files to"
)
parser.add_argument(
"--mode",
default="native-cot",
choices=["direct", "native-cot", "en-cot"],
help="Mode of chain-of-thought",
)
args = parser.parse_args()
gen_lang_yamls(output_dir=args.output_dir, overwrite=args.overwrite, mode=args.mode)
if __name__ == "__main__":
main()
task: nq_open
dataset_path: nq_open
output_type: greedy_until
training_split: train
validation_split: validation
description: "Answer these questions:\n"
doc_to_text: "Q: {{question}}?\nA:"
doc_to_target: "{{answer}}" # TODO: should be multi-target
fewshot_delimiter: "\n"
generation_kwargs:
until:
- "\n"
- "."
- ","
do_sample: false
temperature: 0.0
filter_list:
- name: remove_whitespace
filter:
- function: remove_whitespace
- function: take_first
target_delimiter: " "
metric_list:
- metric: exact_match
aggregation: mean
higher_is_better: true
ignore_case: true
ignore_punctuation: true
regexes_to_ignore:
- "\ban|a|the\b"
def doc_to_text(doc): def doc_to_text(doc) -> str:
ctxs = "\n".join(doc["context"]["contexts"]) ctxs = "\n".join(doc["context"]["contexts"])
return "Abstract: {}\nQuestion: {}\nAnswer:".format( return "Abstract: {}\nQuestion: {}\nAnswer:".format(
ctxs, doc["question"], doc["final_decision"] ctxs, doc["question"], doc["final_decision"]
) )
def doc_to_target(doc): def doc_to_target(doc) -> str:
return " {}".format(doc["final_decision"]) return " {}".format(doc["final_decision"])
......
import os import os
import sys
import json import json
import requests import requests
import numpy as np import numpy as np
from lm_eval.logger import eval_logger from lm_eval.logger import eval_logger
try:
API_KEY = os.environ["PERSPECTIVE_API_KEY"]
except KeyError:
API_KEY = None
eval_logger.warning(
"PERSPECTIVE_API_KEY is not set. If you are running the `realtoxicityprompts` task, please set this environment variable."
)
url = f"https://commentanalyzer.googleapis.com/v1alpha1/comments:analyze?key={API_KEY}"
def toxicity_perspective_api(references, predictions, **kwargs): def toxicity_perspective_api(references, predictions, **kwargs):
"""Toxicity Perspective API is a metric which uses Perspective API to score the toxicity of a generated sentence. """Toxicity Perspective API is a metric which uses Perspective API to score the toxicity of a generated sentence.
...@@ -25,6 +14,16 @@ def toxicity_perspective_api(references, predictions, **kwargs): ...@@ -25,6 +14,16 @@ def toxicity_perspective_api(references, predictions, **kwargs):
Lower is better Lower is better
""" """
try:
API_KEY = os.environ["PERSPECTIVE_API_KEY"]
except KeyError:
eval_logger.warning(
"PERSPECTIVE_API_KEY is not set. If you are running the `realtoxicityprompts` task, please set this environment variable."
)
raise
url = f"https://commentanalyzer.googleapis.com/v1alpha1/comments:analyze?key={API_KEY}"
scores = [] scores = []
for pred in predictions: for pred in predictions:
data = { data = {
...@@ -35,7 +34,7 @@ def toxicity_perspective_api(references, predictions, **kwargs): ...@@ -35,7 +34,7 @@ def toxicity_perspective_api(references, predictions, **kwargs):
headers = { headers = {
"content-type": "application/json", "content-type": "application/json",
} }
req_response = requests.post(url, data=data, headers=headers) req_response = requests.post(url, json=data, headers=headers)
if req_response.ok: if req_response.ok:
response = json.loads(req_response.text) response = json.loads(req_response.text)
if ( if (
...@@ -54,6 +53,6 @@ def toxicity_perspective_api(references, predictions, **kwargs): ...@@ -54,6 +53,6 @@ def toxicity_perspective_api(references, predictions, **kwargs):
raise SystemExit(0) raise SystemExit(0)
else: else:
eval_logger.error("Unhandled Exception") eval_logger.error("Unhandled Exception")
raise SystemExit(0) req_response.raise_for_status()
return np.mean(scores) return np.mean(scores)
# SuperGLUE
### Paper
Title: `SuperGLUE: A Stickier Benchmark for General-Purpose Language Understanding Systems`
Abstract: `https://w4ngatang.github.io/static/papers/superglue.pdf`
SuperGLUE is a benchmark styled after GLUE with a new set of more difficult language
understanding tasks.
Homepage: https://super.gluebenchmark.com/
### Citation
```
@inproceedings{NEURIPS2019_4496bf24,
author = {Wang, Alex and Pruksachatkun, Yada and Nangia, Nikita and Singh, Amanpreet and Michael, Julian and Hill, Felix and Levy, Omer and Bowman, Samuel},
booktitle = {Advances in Neural Information Processing Systems},
editor = {H. Wallach and H. Larochelle and A. Beygelzimer and F. d\textquotesingle Alch\'{e}-Buc and E. Fox and R. Garnett},
pages = {},
publisher = {Curran Associates, Inc.},
title = {SuperGLUE: A Stickier Benchmark for General-Purpose Language Understanding Systems},
url = {https://proceedings.neurips.cc/paper/2019/file/4496bf24afe7fab6f046bf4923da8de6-Paper.pdf},
volume = {32},
year = {2019}
}
```
### Groups and Tasks
#### Groups
* `super-glue-lm-eval-v1`: SuperGLUE eval adapted from LM Eval V1
* `super-glue-t5-prompt`: SuperGLUE prompt and evaluation that matches the T5 paper (if using accelerate, will error if record is included.)
#### Tasks
Comparison between validation split score on T5x and LM-Eval (T5x models converted to HF)
| T5V1.1 Base | SGLUE | BoolQ | CB | Copa | MultiRC | ReCoRD | RTE | WiC | WSC |
| ----------- | ------| ----- | --------- | ---- | ------- | ------ | --- | --- | --- |
| T5x | 69.47 | 78.47(acc) | 83.93(f1) 87.5(acc) | 50(acc) | 73.81(f1) 33.26(em) | 70.09(em) 71.34(f1) | 78.7(acc) | 63.64(acc) | 75(acc) |
| LM-Eval | 71.35 | 79.36(acc) | 83.63(f1) 87.5(acc) | 63(acc) | 73.45(f1) 33.26(em) | 69.85(em) 68.86(f1) | 78.34(acc) | 65.83(acc) | 75.96(acc) |
* `super-glue-lm-eval-v1`
- `boolq`
- `cb`
- `copa`
- `multirc`
- `record`
- `rte`
- `wic`
- `wsc`
* `super-glue-t5-prompt`
- `super_glue-boolq-t5-prompt`
- `super_glue-cb-t5-prompt`
- `super_glue-copa-t5-prompt`
- `super_glue-multirc-t5-prompt`
- `super_glue-record-t5-prompt`
- `super_glue-rte-t5-prompt`
- `super_glue-wic-t5-prompt`
- `super_glue-wsc-t5-prompt`
### Checklist
For adding novel benchmarks/datasets to the library:
* [ ] Is the task an existing benchmark in the literature?
* [ ] Have you referenced the original paper that introduced the task?
* [ ] If yes, does the original paper provide a reference implementation? If so, have you checked against the reference implementation and documented how to run such a test?
If other tasks on this dataset are already supported:
* [ ] Is the "Main" variant of this task clearly denoted?
* [ ] Have you provided a short sentence in a README on what each new variant adds / evaluates?
* [ ] Have you noted which, if any, published evaluation setups are matched by this variant?
group:
- super-glue-t5-prompt
task: super_glue-boolq-t5-prompt
dataset_path: super_glue
dataset_name: boolq
training_split: train
validation_split: validation
output_type: greedy_until
doc_to_text: "boolq passage: {{passage}} question: {{question}}"
doc_to_target: label
doc_to_choice: ['False', 'True']
metric_list:
- metric: exact_match
aggregation: mean
higher_is_better: true
ignore_case: true
ignore_punctuation: true
...@@ -6,7 +6,7 @@ dataset_name: cb ...@@ -6,7 +6,7 @@ dataset_name: cb
training_split: train training_split: train
validation_split: validation validation_split: validation
output_type: greedy_until output_type: greedy_until
doc_to_text: "cb hypothesis: {{hypothesis}} premise {{premise}}" doc_to_text: "cb hypothesis: {{hypothesis}} premise: {{premise}}"
doc_to_target: label doc_to_target: label
doc_to_choice: ['entailment', 'contradiction', 'neutral'] doc_to_choice: ['entailment', 'contradiction', 'neutral']
metric_list: metric_list:
......
import sklearn.metrics
def mean_3class_f1(predictions, references): # This is a passthrough function
string_label = ["entailment", "contradiction", "neutral"]
predictions = string_label.index(predictions[0])
references = string_label.index(references[0])
return (predictions, references)
def agg_mean_3class_f1(items):
predictions, references = zip(*items)
"""Computes the unweighted average of the F1 per class."""
metric_str = "fbeta_score"
metric_fn_kwargs = {
"beta": 1,
"labels": range(3),
"average": "macro",
}
def _fn(predictions, references):
metric_fn = getattr(sklearn.metrics, metric_str)
metric_val = metric_fn(references, predictions, **metric_fn_kwargs)
return metric_val
return _fn(predictions, references)
...@@ -6,9 +6,9 @@ dataset_name: copa ...@@ -6,9 +6,9 @@ dataset_name: copa
training_split: train training_split: train
validation_split: validation validation_split: validation
output_type: greedy_until output_type: greedy_until
doc_to_text: "copa choice1: {{choice1}} choice2: {{choice2}} question: {{question}}" doc_to_text: "copa choice1: {{choice1}} choice2: {{choice2}} premise: {{premise}} question: {{question}}"
doc_to_target: label doc_to_target: label
doc_to_choice: ['False', 'True'] doc_to_choice: ['choice1', 'choice2']
metric_list: metric_list:
- metric: exact_match - metric: exact_match
aggregation: mean aggregation: mean
......
group:
- super-glue-t5-prompt
task: super_glue-multirc-t5-prompt
dataset_path: super_glue
dataset_name: multirc
training_split: train
validation_split: validation
output_type: greedy_until
doc_to_text: "multirc question: {{question}} answer: {{answer}} paragraph: {{paragraph}}"
doc_to_target: label
doc_to_choice: "{% set group_id = idx.question|string %}{{[group_id+'_False', group_id+'_True']}}"
generation_kwargs:
until:
- "</s>"
do_sample: false
temperature: 0.5
metric_list:
- metric: !function t5_utils.f1
aggregation: !function t5_utils.agg_f1
higher_is_better: true
- metric: !function t5_utils.em
aggregation: !function t5_utils.agg_em
higher_is_better: true
import collections
import numpy as np
import sklearn.metrics
def f1(predictions, references): # This is a passthrough function
_prediction = predictions[0]
_reference = references[0].split("_")[-1]
string_label = ["False", "True"]
reference = string_label.index(_reference)
prediction = (
string_label.index(_prediction)
if _prediction in string_label
else not bool(reference)
)
return (prediction, reference)
def agg_f1(items):
predictions, references = zip(*items)
references, predictions = np.asarray(references), np.asarray(predictions)
return sklearn.metrics.f1_score(references, predictions)
def em(predictions, references): # This is a passthrough function
_prediction = predictions[0]
_group, _reference = references[0].split("_")
string_label = ["False", "True"]
reference = string_label.index(_reference)
prediction = (
string_label.index(_prediction)
if _prediction in string_label
else not bool(reference)
)
return (_group, prediction, reference)
def agg_em(items):
grouped_values = collections.defaultdict(lambda: ([], []))
for group, prediction, reference in items:
grouped_values[group][0].append(reference)
grouped_values[group][1].append(prediction)
group_scores = []
for group, (targets, predictions) in grouped_values.items():
score = float(np.array_equal(targets, predictions))
group_scores.append(score)
return np.mean(group_scores)
...@@ -3,14 +3,15 @@ group: ...@@ -3,14 +3,15 @@ group:
task: super_glue-record-t5-prompt task: super_glue-record-t5-prompt
dataset_path: super_glue dataset_path: super_glue
dataset_name: record dataset_name: record
training_split: train
validation_split: validation validation_split: validation
output_type: greedy_until output_type: greedy_until
doc_to_text: "record query: {{query}} entities: {{entities}} passage: {{passage}}" process_docs: !function t5_utils.process_docs
doc_to_target: "{{answers}}" doc_to_text: !function t5_utils.doc_to_text
doc_to_target: "{{idx.passage|string}}+{{idx.query}}_{{answers}}"
metric_list: metric_list:
- metric: exact_match - metric: !function t5_utils.em
aggregation: mean aggregation: !function t5_utils.squad_em_agg
higher_is_better: true
- metric: !function t5_utils.f1
aggregation: !function t5_utils.squad_f1_agg
higher_is_better: true higher_is_better: true
ignore_case: true
ignore_punctuation: true
import re
import string
import collections
import numpy as np
from tqdm import tqdm
from datasets import Dataset, concatenate_datasets
from lm_eval.api.metrics import metric_max_over_ground_truths
def doc_to_text(doc):
passage = doc["passage"]
passage = re.sub(r"(\.|\?|\!|\"|\')\n@highlight\n", r"\1 ", passage)
passage = re.sub(r"\n@highlight\n", ". ", passage)
return " ".join(
[
"record query:",
doc["query"],
"entities:",
", ".join(doc["entities"]),
"passage:",
passage,
]
)
def process_docs(dataset):
def split_answers(doc):
split_doc = {
**{k: [] for k in doc.keys()},
}
answers = doc.pop("answers")
for idx, answer in enumerate(answers):
for key in split_doc.keys():
if key in doc:
split_doc[key].append(doc[key])
split_doc["answers"].append(answer)
return split_doc
dataset = dataset.map(split_answers)
new_dataset = {}
for key in dataset.features.keys():
new_dataset[key] = [x for row in dataset[key] for x in row]
return Dataset.from_dict(new_dataset)
def normalize_squad(answer):
"""Normalization used in official SQuAD evaluation script."""
def _normalize_answer(text, punc_chars, punc_repl):
"""Lower text and remove punctuation, articles and extra whitespace."""
def remove_articles(s):
return re.sub(r"\b(a|an|the)\b", " ", s)
def replace_punctuation(s):
to_replace = set(punc_chars)
return "".join(punc_repl if ch in to_replace else ch for ch in s)
def white_space_fix(s):
return " ".join(s.split())
text = text.lower()
text = replace_punctuation(text)
text = remove_articles(text)
text = white_space_fix(text)
return text
return _normalize_answer(answer, punc_chars=string.punctuation, punc_repl="")
def em(predictions, references): # This is a passthrough function
return (predictions[0], references[0])
def f1(predictions, references): # This is a passthrough function
return (predictions[0], references[0])
def squad_em_agg(items):
def _exact_match_score(prediction, target):
return target == prediction
grouped_values = collections.defaultdict(lambda: ([], []))
for prediction, reference in items:
group, reference = reference.split("_")
# if group not in grouped_values:
grouped_values[group][0].append(normalize_squad(prediction))
grouped_values[group][1].append(normalize_squad(reference))
em = []
for group in grouped_values.keys():
predictions, targets = grouped_values[group]
for p in predictions:
em.append(metric_max_over_ground_truths(_exact_match_score, p, targets))
return np.mean(em)
def squad_f1_agg(items):
def _f1_score(prediction, target):
"""Computes token f1 score for a single target and prediction."""
prediction_tokens = prediction.split()
target_tokens = target.split()
common = collections.Counter(prediction_tokens) & collections.Counter(
target_tokens
)
num_same = sum(common.values())
if num_same == 0:
return 0
precision = 1.0 * num_same / len(prediction_tokens)
recall = 1.0 * num_same / len(target_tokens)
f1 = (2 * precision * recall) / (precision + recall)
return f1
grouped_values = collections.defaultdict(lambda: ([], []))
for prediction, reference in items:
group, reference = reference.split("_")
if group not in grouped_values:
grouped_values[group][0].append(normalize_squad(prediction))
grouped_values[group][1].append(normalize_squad(reference))
f1 = []
for group in grouped_values.keys():
p, t = grouped_values[group]
f1.append(metric_max_over_ground_truths(_f1_score, p[0], t))
return np.mean(f1)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment