Unverified Commit edd7dde3 authored by Hailey Schoelkopf's avatar Hailey Schoelkopf Committed by GitHub
Browse files

Merge branch 'big-refactor' into haileyschoelkopf-patch-2

parents fb436108 365fcda9
......@@ -2,7 +2,7 @@ group:
- super-glue-lm-eval-v1
task: wsc
dataset_path: super_glue
dataset_name: wsc
dataset_name: wsc.fixed
output_type: multiple_choice
training_split: train
validation_split: validation
......
import re
from lm_eval.utils import general_detokenize
def t5_prompt_doc_to_text(x):
    """Render a WSC example in the T5 prompt format.

    Marks span1 (the candidate noun phrase) with "*" and span2 (the pronoun)
    with "#", e.g. "The * dog * chased the cat because # it # was fast".

    :param x: dict with keys "text", "span1_text", "span1_index",
        "span2_text", "span2_index" (word-based indices).
    :return: the passage with both spans marked.
    """

    def _mark_span(text, span_str, span_idx, mark):
        # Match `span_str` as the text starting at the (span_idx+1)-th
        # whitespace-delimited word and surround it with `mark`.
        # `span_str` is escaped so spans containing regex metacharacters
        # (e.g. "(my sister)") are matched literally instead of corrupting
        # the pattern or shifting the capture groups.
        pattern = r"^((?:\S+\s){%d})(%s)" % (span_idx, re.escape(span_str))
        return re.sub(pattern, r"\1{0} \2 {0}".format(mark), text)

    text = x["text"]
    text = _mark_span(text, x["span1_text"], x["span1_index"], "*")
    # Marking span1 inserted two extra "words" (the two "*" tokens), so
    # shift span2's word index when span2 comes after span1.
    span2_index = x["span2_index"] + 2 * (x["span1_index"] < x["span2_index"])
    text = _mark_span(text, x["span2_text"], span2_index, "#")
    return text
def default_doc_to_text(x):
raw_passage = x["text"]
# NOTE: HuggingFace span indices are word-based not character-based.
......
......@@ -2,16 +2,17 @@ group:
- super-glue-t5-prompt
task: super_glue-wsc-t5-prompt
dataset_path: super_glue
dataset_name: wsc
dataset_name: wsc.fixed
training_split: train
validation_split: validation
output_type: greedy_until
doc_to_text: !function "preprocess_wsc.t5_prompt_doc_to_text"
doc_to_text: !function "t5_utils.doc_to_text"
doc_to_target: label
doc_to_choice: ['False', 'True']
metric_list:
- metric: exact_match
- metric: accuracy
aggregation: mean
higher_is_better: true
ignore_case: true
ignore_punctuation: true
filter_list:
- name: "wsc_postprocessor"
filter:
- function: !function t5_utils.WSCPostprocess
import re
from lm_eval.api.filter import Filter
def doc_to_text(x):
    # Replace the bare "X" placeholder with the pronoun wrapped in asterisks,
    # then prepend the task prefix expected by T5.
    wrapped = " *" + x["span2_text"] + "* "
    return "wsc: " + re.sub(r" X ", wrapped, _wsc_inputs(x))
def _wsc_inputs(x):
words = x["text"].split(" ")
# We would need some special logic to handle the case where the pronoun is the
# first or last word in the text. None of the examples in WSC seem to have
# this, so we are ignoring these cases.
assert x["span2_index"] > 0
assert x["span2_index"] < len(words)
pronoun_index = x["span2_index"]
def create_input():
assert words[pronoun_index] == x["span2_text"]
return " ".join(
[
" ".join(words[:pronoun_index]),
"X",
" ".join(words[pronoun_index + 1 :]),
]
)
# Handle some special cases.
if (
x["text"]
== 'The boy continued to whip the pony , and eventually the pony threw him over. John laughed out quite loud. "Good for him," he said. '
):
return (
"The boy continued to whip the pony , and eventually the pony threw "
'him over. John laughed out quite loud. "Good for X ," he said.'
)
# Using the span2_index, we get 'use' instead of 'it'.
if (
x["text"]
== "When they had eventually calmed down a bit , and had gotten home, Mr. Farley put the magic pebble in an iron safe . Some day they might want to use it , but really for now, what more could they wish for?"
):
return (
"When they had eventually calmed down a bit , and had gotten home, "
"Mr. Farley put the magic pebble in an iron safe . Some day they might "
"want to use X , but really for now, what more could they wish for?"
)
return create_input()
class WSCPostprocess(Filter):
    """Map free-form WSC generations onto True/False referent predictions.

    A generation counts as predicting the referent (span1) when, after
    normalization, its word set is a subset or superset of the referent's.
    """

    def __init__(self, **kwargs):
        # Determiners are stripped before comparing prediction and referent.
        self.determiners = {
            "a", "an", "few", "her", "his", "each",
            "every", "many", "much", "my", "our", "some",
            "that", "the", "their", "these", "this", "those",
            "which", "whose", "your",
        }

    def clean(self, s):
        """Normalize: strip, lowercase, and drop determiners."""
        words = s.strip().lower().split(" ")
        return " ".join([w for w in words if w not in self.determiners])

    def apply(self, resps, docs):
        filtered_resps = []
        for prediction, reference in zip(resps, docs["span1_text"]):
            pred = self.clean(prediction[0])
            ref = self.clean(reference)
            if ("'" in pred) != ("'" in ref):
                # Possessive mismatch (e.g. predicting "Bob's hat" for
                # referent "Bob"): do not count as predicting the referent.
                predicted_referent = False
            else:
                pred_words = set(pred.split(" "))
                ref_words = set(ref.split(" "))
                # Fuzzy match: prediction "fuzzy bunny" vs. referent "bunny"
                # (either direction) still counts.
                predicted_referent = pred_words.issubset(
                    ref_words
                ) or ref_words.issubset(pred_words)
            filtered_resps.append(predicted_referent)
        return filtered_resps
# Translation Tasks
### Paper
### Citation
```
```
### Groups and Tasks
#### Groups
* `gpt3_translation_tasks`
* `wmt14`
* `wmt16`
* `wmt20`
* `iwslt2017`
#### Tasks
* Translation tasks are generated per language pair and direction by `utils.py`, e.g. `wmt14-en-fr`, `wmt14-fr-en`, `wmt16-ro-en`, `iwslt2017-en-ar`.
### Checklist
For adding novel benchmarks/datasets to the library:
* [x] Is the task an existing benchmark in the literature?
* [ ] Have you referenced the original paper that introduced the task?
* [ ] If yes, does the original paper provide a reference implementation? If so, have you checked against the reference implementation and documented how to run such a test?
If other tasks on this dataset are already supported:
* [x] Is the "Main" variant of this task clearly denoted?
* [ ] Have you provided a short sentence in a README on what each new variant adds / evaluates?
* [ ] Have you noted which, if any, published evaluation setups are matched by this variant?
* [ ] Checked for equivalence with v0.3.0 LM Evaluation Harness
# Generated by utils.py
dataset_name: iwslt2017-en-ar
dataset_path: iwslt2017
doc_to_target: ' {{translation["en"]}}'
doc_to_text: 'Arabic phrase: {{translation["ar"]}}
English phrase:'
group:
- greedy_until
- translation
- iwslt2017
include: wmt_common_yaml
task: iwslt2017-ar-en
# Generated by utils.py
dataset_name: iwslt2017-en-ar
dataset_path: iwslt2017
doc_to_target: ' {{translation["ar"]}}'
doc_to_text: 'English phrase: {{translation["en"]}}
Arabic phrase:'
group:
- greedy_until
- translation
- iwslt2017
include: wmt_common_yaml
task: iwslt2017-en-ar
import argparse
from typing import Dict, List
import yaml
import sacrebleu
try:
    # pycountry maps language codes (e.g. "en") to full names for prompts.
    import pycountry
except ModuleNotFoundError:
    raise Exception(
        "`pycountry` is required for generating translation task prompt templates. \
please install pycountry via pip install lm-eval[multilingual] or pip install -e .[multilingual]",
    )
# Different translation benchmarks included in the library. Mostly WMT.
# These correspond to dataset names (subsets) on HuggingFace for each dataset.
# A yaml file is generated by this script for each language pair.
gpt3_translation_benchmarks = {
"wmt14": ["fr-en"], # ["en-fr", "fr-en"], # French
"wmt16": [
"ro-en",
"de-en",
], # ["en-ro", "ro-en", "de-en", "en-de"], # German, Romanian
}
# 28 total
LANGUAGES = {
**gpt3_translation_benchmarks,
# "wmt20": sacrebleu.get_langpairs_for_testset("wmt20"),
"iwslt2017": ["en-ar"], # Arabic
}
def code_to_language(code):
    """Resolve an ISO 639 language code to its full English name."""
    # pycountry keys on alpha_2 vs alpha_3 depending on the code's length.
    lookup_key = f"alpha_{len(code)}"
    entry = pycountry.languages.get(**{lookup_key: code})
    return entry.name
def gen_lang_yamls(output_dir: str, overwrite: bool) -> None:
    """
    Generate a yaml file for each language.

    :param output_dir: The directory to output the files to.
    :param overwrite: Whether to overwrite files if they already exist.
    """
    failed = []
    for bench, pairs in LANGUAGES.items():
        for dataset_name in pairs:
            src_lang, _, tgt_lang = dataset_name.partition("-")
            # Emit one yaml per translation direction of the pair.
            for src, tgt in ((src_lang, tgt_lang), (tgt_lang, src_lang)):
                lang_pair = f"{src}-{tgt}"
                file_name = f"{bench}_{lang_pair}.yaml"
                try:
                    source = code_to_language(src)
                    target = code_to_language(tgt)

                    groups = ["greedy_until", "translation", bench]
                    if bench in gpt3_translation_benchmarks.keys():
                        groups.append("gpt3_translation_benchmarks")

                    # iwslt2017 subset names are prefixed on the Hub.
                    if bench == "iwslt2017":
                        hub_subset = "iwslt2017-" + dataset_name
                    else:
                        hub_subset = dataset_name

                    config = {
                        "include": "wmt_common_yaml",
                        "group": groups,
                        "dataset_path": bench,
                        "dataset_name": hub_subset,
                        "task": f"{bench}-{lang_pair}",
                        "doc_to_text": (
                            f"{source} phrase: "
                            + "{{translation["
                            + f'"{src}"'
                            + "]}}\n"
                            + f"{target} phrase:"
                        ),
                        "doc_to_target": " {{" + f'translation["{tgt}"]' + "}}",
                    }
                    mode = "w" if overwrite else "x"
                    with open(
                        f"{output_dir}/{file_name}", mode, encoding="utf8"
                    ) as f:
                        f.write("# Generated by utils.py\n")
                        yaml.dump(config, f)
                except FileExistsError:
                    failed.append(file_name)

    if failed:
        raise FileExistsError(
            "Files were not created because they already exist (use --overwrite flag):"
            f" {', '.join(failed)}"
        )
def main() -> None:
    """Parse CLI args and generate language-specific yaml files."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--overwrite",
        action="store_true",
        default=False,
        help="Overwrite files if they already exist",
    )
    parser.add_argument(
        "--output-dir",
        default=".",
        help="Directory to write yaml files to",
    )
    args = parser.parse_args()
    gen_lang_yamls(output_dir=args.output_dir, overwrite=args.overwrite)


if __name__ == "__main__":
    main()
# Generated by utils.py
dataset_name: fr-en
dataset_path: wmt14
doc_to_target: ' {{translation["fr"]}}'
doc_to_text: 'English phrase: {{translation["en"]}}
French phrase:'
group:
- greedy_until
- translation
- wmt14
- gpt3_translation_benchmarks
include: wmt_common_yaml
task: wmt14-en-fr
# Generated by utils.py
dataset_name: fr-en
dataset_path: wmt14
doc_to_target: ' {{translation["en"]}}'
doc_to_text: 'French phrase: {{translation["fr"]}}
English phrase:'
group:
- greedy_until
- translation
- wmt14
- gpt3_translation_benchmarks
include: wmt_common_yaml
task: wmt14-fr-en
# Generated by utils.py
dataset_name: de-en
dataset_path: wmt16
doc_to_target: ' {{translation["en"]}}'
doc_to_text: 'German phrase: {{translation["de"]}}
English phrase:'
group:
- greedy_until
- translation
- wmt16
- gpt3_translation_benchmarks
include: wmt_common_yaml
task: wmt16-de-en
# Generated by utils.py
dataset_name: de-en
dataset_path: wmt16
doc_to_target: ' {{translation["de"]}}'
doc_to_text: 'English phrase: {{translation["en"]}}
German phrase:'
group:
- greedy_until
- translation
- wmt16
- gpt3_translation_benchmarks
include: wmt_common_yaml
task: wmt16-en-de
# Generated by utils.py
dataset_name: ro-en
dataset_path: wmt16
doc_to_target: ' {{translation["ro"]}}'
doc_to_text: 'English phrase: {{translation["en"]}}
Romanian phrase:'
group:
- greedy_until
- translation
- wmt16
- gpt3_translation_benchmarks
include: wmt_common_yaml
task: wmt16-en-ro
# Generated by utils.py
dataset_name: ro-en
dataset_path: wmt16
doc_to_target: ' {{translation["en"]}}'
doc_to_text: 'Romanian phrase: {{translation["ro"]}}
English phrase:'
group:
- greedy_until
- translation
- wmt16
- gpt3_translation_benchmarks
include: wmt_common_yaml
task: wmt16-ro-en
output_type: greedy_until
training_split: train
validation_split: validation
fewshot_split: validation
test_split: test
metric_list:
- metric: bleu
- metric: ter
- metric: chrf
generation_kwargs:
until:
- "\n"
do_sample: false
temperature: 0.0
repeats: 1
......@@ -34,3 +34,15 @@ def wikitext_detokenizer(doc):
string = string.replace(" 's", "'s")
return string
def process_results(doc, results):
    """Turn a single loglikelihood result into perplexity metric inputs."""
    (loglikelihood,) = results
    page = doc["page"]
    # IMPORTANT: wikitext counts words in the *original* doc, before
    # detokenization.
    n_words = len(re.split(r"\s+", page))
    n_bytes = len(page.encode("utf-8"))
    return {
        "word_perplexity": (loglikelihood, n_words),
        "byte_perplexity": (loglikelihood, n_bytes),
        "bits_per_byte": (loglikelihood, n_bytes),
    }
......@@ -7,6 +7,7 @@ validation_split: validation
test_split: test
doc_to_text: ""
doc_to_target: !function preprocess_wikitext.wikitext_detokenizer
process_results: !function preprocess_wikitext.process_results
should_decontaminate: true
doc_to_decontamination_query: "{{page}}"
metric_list:
......
# WMT16
### Paper
Title: `Findings of the 2016 Conference on Machine Translation`
Abstract: http://www.aclweb.org/anthology/W/W16/W16-2301
Homepage: https://huggingface.co/datasets/wmt16
### Citation
```
@InProceedings{bojar-EtAl:2016:WMT1,
  author    = {Bojar, Ond{\v{r}}ej and Chatterjee, Rajen and Federmann, Christian and Graham, Yvette and Haddow, Barry and Huck, Matthias and Jimeno Yepes, Antonio and Koehn, Philipp and Logacheva, Varvara and Monz, Christof and Negri, Matteo and Neveol, Aurelie and Neves, Mariana and Popel, Martin and Post, Matt and Rubino, Raphael and Scarton, Carolina and Specia, Lucia and Turchi, Marco and Verspoor, Karin and Zampieri, Marcos},
title = {Findings of the 2016 Conference on Machine Translation},
booktitle = {Proceedings of the First Conference on Machine Translation},
month = {August},
year = {2016},
address = {Berlin, Germany},
publisher = {Association for Computational Linguistics},
pages = {131--198},
url = {http://www.aclweb.org/anthology/W/W16/W16-2301}
}
```
### Groups and Tasks
#### Groups
* `wmt-t5-prompt`: Group for all wmt tasks with prompt templates used for T5 (`Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer`)
#### Tasks
With specific prompt styles
* `wmt-ro-en-t5-prompt`: WMT16 with the prompt template used for T5
### Checklist
For adding novel benchmarks/datasets to the library:
* [ ] Is the task an existing benchmark in the literature?
* [ ] Have you referenced the original paper that introduced the task?
* [ ] If yes, does the original paper provide a reference implementation? If so, have you checked against the reference implementation and documented how to run such a test?
If other tasks on this dataset are already supported:
* [ ] Is the "Main" variant of this task clearly denoted?
* [ ] Have you provided a short sentence in a README on what each new variant adds / evaluates?
* [ ] Have you noted which, if any, published evaluation setups are matched by this variant?
import evaluate
def bleu(predictions, references):
    """Pass through the (prediction, reference) pair; scoring happens in agg_bleu."""
    return predictions[0], references[0]
def agg_bleu(items):
    """Compute corpus-level BLEU over the accumulated (prediction, reference) pairs."""
    scorer = evaluate.load("bleu")
    preds, refs = zip(*items)
    return scorer.compute(predictions=preds, references=refs)["bleu"]
group:
- wmt-t5-prompt
task: wmt-ro-en-t5-prompt
dataset_path: wmt16
dataset_name: ro-en
training_split: train
validation_split: validation
output_type: greedy_until
doc_to_text: "translate English to Romanian: {{translation.en}}"
doc_to_target: "{{translation.ro}}"
metric_list:
- metric: wer
aggregation: mean
higher_is_better: false
- metric: !function metrics.bleu
aggregation: !function metrics.agg_bleu
higher_is_better: true
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment