Commit 37c3139d authored by thefazzer

Merge remote-tracking branch 'origin/master' into fazz/refactor-task-coqa

parents 79c9b68a 7ad6bf45
import os
import numpy as np
from best_download import download_file
from lm_eval.base import MultipleChoiceTask, rf
from lm_eval.metrics import mean
import xml.etree.ElementTree as ET
import random
class QA4MRE(MultipleChoiceTask):
YEAR = None
def download(self):
year = self.YEAR
lang = "EN"
base_path = (
"http://nlp.uned.es/clef-qa/repository/js/scripts/downloadFile.php?"
"file=/var/www/html/nlp/clef-qa/repository/resources/QA4MRE/"
)
# TODO: add side tasks?
variable_year_path = {
2011: '2011/Training_Data/Goldstandard/',
2012: '2012/Main_Task/Training_Data/Goldstandard/Used_in_Evaluation/',
2013: '2013/Main_Task/Training_Data/Goldstandard/'
}
sha256sums = {
2011 : "6d2524952a3a015f2a82df785b85b5578681e3602ec276b4e72c01f4ebc50034",
2012 : "f9edaf408f8ac93f89a643a0d0b19263a1bb5ce64f19b2af10df279a656dfb24",
2013 : "c60e5aa4ec77e0493ef0b11d46bd1d74d58a499a3a2f871b8cf3af9536f0f094",
}
vpath = variable_year_path[year]
url_path = f"{base_path}{vpath}QA4MRE-{year}-{lang}_GS.xml"
if not os.path.exists("data/qa4mre"):
os.mkdir("data/qa4mre")
if not os.path.isfile(f"data/qa4mre/QA4MRE-{year}-{lang}_GS.xml"):
download_file(
url_path,
f"data/qa4mre/QA4MRE-{year}-{lang}_GS.xml",
checksum=sha256sums[year],
)
def has_training_docs(self):
return False
def has_validation_docs(self):
return False
def has_test_docs(self):
return True
def fewshot_examples(self, k):
# This task only has test docs, so sample few-shot examples from them.
if self._training_docs is None:
self._training_docs = list(self.test_docs())
return random.sample(self._training_docs, k)
def _convert_standard(self, question):
choices = [i.text for i in question.iter('answer')]
out_doc = {
"query" : question.find('q_str').text,
"choices": choices,
"gold" : int(question.find("./answer[@correct='Yes']").attrib["a_id"]) - 1,
}
return out_doc
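# Illustrative shape of the returned doc (values are made up; a QA4MRE <q> element
# carries <answer> children with a_id="1".."5", one of them marked correct='Yes'):
#   {"query": "What is ...?",
#    "choices": ["...", "...", "...", "...", "..."],
#    "gold": 2}   # zero-based index of the correct answer, i.e. a_id - 1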
def load_docs(self, textfilename, tfds=False):
tree = ET.parse(textfilename)
root = tree.getroot()
# TODO: the source passage is sometimes much larger than the model's context window;
# at the moment it just gets left-truncated by the LM automatically, and maybe that's good enough?
for reading_test in root.iter('reading-test'):
src = reading_test[0].text
src = src.strip().replace("\\'", "'")
for qid, question in enumerate(reading_test.iter('q')):
out_doc = self._convert_standard(question)
out_doc['source'] = src
yield out_doc
def fewshot_description(self):
return ""
def test_docs(self):
return self.load_docs(f"data/qa4mre/QA4MRE-{self.YEAR}-EN_GS.xml")
def doc_to_text(self, doc):
return "{}\nQuestion: {}\nAnswer:".format(doc["source"], doc["query"])
class QA4MRE_2011(QA4MRE):
YEAR = 2011
class QA4MRE_2012(QA4MRE):
YEAR = 2012
class QA4MRE_2013(QA4MRE):
YEAR = 2013
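A minimal sketch of exercising one of these tasks directly; it assumes a local data/ directory already exists and that the UNED URL above is still reachable, and the __main__ guard keeps it from running on import:
if __name__ == "__main__":
    # Fetch the 2013 gold-standard XML (skipped if already present) and show the
    # first prompt plus its correct answer.
    task = QA4MRE_2013()
    task.download()
    doc = next(iter(task.test_docs()))
    print(task.doc_to_text(doc))         # "<source passage>\nQuestion: ...\nAnswer:"
    print(doc["choices"][doc["gold"]])   # text of the gold answer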
import collections
import datasets
import numpy as np
from lm_eval.base import rf, mean
from lm_eval.base import rf
from ..metrics import mean
from . common import HFTask
import os
......@@ -82,10 +83,13 @@ class RACE(HFTask):
def doc_to_text(self, doc):
text = 'Article: ' + doc['article'] + '\n\n'
for problem in doc['problems'][:-1]:
question = 'Q: ' + problem['question'] + '\n\n'
answer = 'A: ' + self.get_answer_option(problem) + '\n\n'
text += question + answer
text += 'Q: ' + self.last_problem(doc)['question'] + '\n\n' + 'A:'
if problem['question'][-6:] == '  _  .':
text += problem['question'][:-5] + self.get_answer_option(problem) + '\n'
else:
question = 'Question: ' + problem['question'] + '\n'
answer = 'Answer: ' + self.get_answer_option(problem) + '\n'
text += question + answer
text += self.last_problem(doc)['question']
return text
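# Rough shape of the prompt built above (illustrative):
#   Article: <passage>
#
#   <cloze question ending in "  _  ." with the blank filled by its answer option>
#   Question: <ordinary question>
#   Answer: <its answer option>
#   <final question, left for the model to answer>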
def doc_to_target(self, doc):
......
import json
import random
import os
from lm_eval.base import MultipleChoiceTask, rf, mean
from lm_eval.base import MultipleChoiceTask, rf
from ..metrics import mean
from tqdm import auto as tqdm_lib
from . common import simple_accuracy_metric
import numpy as np
......
import os
import json
from ..utils import sh
from lm_eval.base import MultipleChoiceTask, rf, mean
from lm_eval.base import MultipleChoiceTask, rf
from ..metrics import mean
import zipfile
from best_download import download_file
class SciQ(MultipleChoiceTask):
......@@ -10,9 +12,11 @@ class SciQ(MultipleChoiceTask):
def download(self):
if not os.path.exists('data/sciq'):
os.mkdir('data/sciq')
sh((
"wget https://ai2-public-datasets.s3.amazonaws.com/sciq/SciQ.zip -O data/sciq/SciQ.zip"
))
download_file(
'https://ai2-public-datasets.s3.amazonaws.com/sciq/SciQ.zip',
'data/sciq/SciQ.zip',
'7f3312f6ac6b09970b32942d106a8c44ec0dad46a0369f17d635aff8e348a87c',
)
with zipfile.ZipFile("data/sciq/SciQ.zip", "r") as zf:
zf.extractall("data/sciq/")
......@@ -48,8 +52,6 @@ class SciQ(MultipleChoiceTask):
yield self._convert_standard(record)
def fewshot_description(self):
# Average ctx length in labelled dataset is 238.9
# 2 few-shot examples push it beyond the context window
return ""
def training_docs(self):
......@@ -62,4 +64,4 @@ class SciQ(MultipleChoiceTask):
return self.load_docs("data/sciq/SciQ dataset-2 3/test.json")
def doc_to_text(self, doc):
return "{}\n{}".format(doc["source"], doc["query"])
\ No newline at end of file
return "{}\nQuestion: {}\nAnswer:".format(doc["source"], doc["query"]).strip()
\ No newline at end of file
......@@ -30,7 +30,7 @@ class SQuAD(HFTask):
return "Title: The_Title_of_It\n\nBackground: A text passage as background to answer the question with.\n\nQ: Question about the passage.\n\nA: Answer."
def doc_to_text(self, doc):
return 'Title: ' + doc['title'] + '\n\n' + 'Background: ' + doc['context'] + '\n\n' + 'Q: ' + doc['question'] + '\n\n' + 'A: '
return 'Title: ' + doc['title'] + '\n\n' + 'Background: ' + doc['context'] + '\n\n' + 'Question: ' + doc['question'] + '\n\n' + 'Answer:'
def doc_to_target(self, doc):
answer_list = doc['answers']['text']
......
......@@ -5,9 +5,11 @@ To-do:
"""
import numpy as np
from . common import HFTask, yesno
from lm_eval.base import rf, mean, acc_all, metric_max_over_ground_truths
from lm_eval.base import rf
from ..metrics import mean, acc_all, metric_max_over_ground_truths
import sklearn
import transformers.data.metrics.squad_metrics as squad_metrics
from ..utils import general_detokenize
class BoolQ(HFTask):
......@@ -28,7 +30,7 @@ class BoolQ(HFTask):
return "Read the following passages and answer each question with a yes or a no."
def doc_to_text(self, doc):
return f"{doc['passage']}\nquestion: {doc['question']}\nanswer:"
return f"{doc['passage']}\nQuestion: {doc['question']}\nAnswer:"
def doc_to_target(self, doc):
return " " + yesno(doc['label'])
......@@ -80,7 +82,7 @@ class CommitmentBank(HFTask):
"to the truth of the hypothesis. The three possible labels are true, false or neither."
def doc_to_text(self, doc):
return "{}\nquestion: {} true, false or neither?\nanswer:".format(
return "{}\nQuestion: {}. True, False or Neither?\nAnswer:".format(
doc["premise"],
doc["hypothesis"],
)
......@@ -89,12 +91,12 @@ class CommitmentBank(HFTask):
# True = entailment
# False = contradiction
# Neither = neutral
return " {}".format({0: "true", 1: "neither", 2: "false"}[doc["label"]])
return " {}".format({0: "True", 1: "Neither", 2: "False"}[doc["label"]])
def construct_requests(self, doc, ctx):
ll_true, _ = rf.loglikelihood(ctx, ' true')
ll_neither, _ = rf.loglikelihood(ctx, ' neither')
ll_false, _ = rf.loglikelihood(ctx, ' false')
ll_true, _ = rf.loglikelihood(ctx, ' True')
ll_neither, _ = rf.loglikelihood(ctx, ' Neither')
ll_false, _ = rf.loglikelihood(ctx, ' False')
return ll_true, ll_neither, ll_false
......@@ -214,15 +216,15 @@ class MultiRC(HFTask):
return "READING COMPREHENSION ANSWER KEY"
def doc_to_text(self, doc):
return f"{doc['paragraph']}\n\n{doc['question']}\n"
return f"{doc['paragraph']}\nQuestion: {doc['question']}\nAnswer:"
def doc_to_target(self, doc):
return self.format_answer(answer=doc["answer"], label=doc["label"])
return " " + self.format_answer(answer=doc["answer"], label=doc["label"])
@staticmethod
def format_answer(answer, label):
label_str = "True" if label else "False"
return f"[{label_str}] {answer}"
label_str = "yes" if label else "no"
return f"{label_str}, {answer}"
def construct_requests(self, doc, ctx):
true_choice = self.format_answer(answer=doc["answer"], label=True)
......@@ -364,8 +366,8 @@ class WordsInContext(HFTask):
return ""
def doc_to_text(self, doc):
return "{}\n{}\nQuestion: Is the word '{}' used in the same way in the" \
" two sentences above?\nanswer:".format(
return "Sentence 1: {}\nSentence 2: {}\nQuestion: Is the word '{}' used in the same way in the" \
" two sentences above?\nAnswer:".format(
doc["sentence1"],
doc["sentence2"],
doc["sentence1"][doc["start1"]:doc["end1"]],
......@@ -438,7 +440,7 @@ class SGWinogradSchemaChallenge(HFTask):
# NOTE: HuggingFace span indices are word-based not character-based.
pre = " ".join(raw_passage.split()[:doc["span2_index"]])
post = raw_passage[len(pre) + len(doc["span2_text"]) + 1:]
passage = pre + " *{}*".format(doc['span2_text']) + post
passage = general_detokenize(pre + " *{}*".format(doc['span2_text']) + post)
noun = doc["span1_text"]
pronoun = doc["span2_text"]
text = (
......
import abc
import json
import random
import os
from pprint import pprint
import pycountry
from sacrebleu import sacrebleu
import logging
from lm_eval import metrics
from lm_eval.base import Task, rf
"""
This file implements translation tasks using datasets from WMT conferences, provided by sacrebleu.
Traditionally they are evaluated with BLEU scores. TER and CHRF are other options.
See sacrebleu.DATASETS for all available datasets. There are a lot!
"""
sacrebleu_datasets = sacrebleu.DATASETS
########################################
# Benchmarks one might want to run
########################################
# 6 total
gpt3_benchmarks = {
"wmt14": ['en-fr', 'fr-en'], # French
"wmt16": ['en-ro', 'ro-en', 'de-en', 'en-de'], # German, Romanian
}
# 14 total
selected_benchmarks = {
**gpt3_benchmarks,
"wmt20": ['fr-de', 'de-fr', 'en-ru', 'ru-en', 'en-iu', 'iu-en'], # French, German, Russian, Inuit
"iwslt17": ['en-ar', 'ar-en'] # Arabic
}
# 319 total
all_benchmarks = {
ts: sacrebleu.get_langpairs_for_testset(ts)
for ts in sacrebleu.get_available_testsets()
}
available_tests = {
"gpt3_tests": gpt3_benchmarks,
"selected_tests": selected_benchmarks,
"all_tests": all_benchmarks
}
def create_tasks_from_benchmarks(benchmark_dict):
"""Creates a dictionary of tasks from a dict
:param benchmark_dict: { dataset: [lang_pair, ...] }
:return: {task_name: task}
e.g. {wmt14-fr-en: Task, wmt16-de-en: Task}
"""
return {
f"{dataset}-{language_pair}": create_translation_task(dataset, language_pair)
for dataset, language_pairs in benchmark_dict.items()
for language_pair in language_pairs
}
########################################
# Tasks
########################################
def create_translation_task(dataset, language_pair):
class TranslationTask(GeneralTranslationTask):
def __init__(self):
super().__init__(dataset, language_pair)
return TranslationTask
class GeneralTranslationTask(Task):
# e.g. ("wmt14", "fr-en")
def __init__(self, sacrebleu_dataset, sacrebleu_language_pair=None):
self.sacrebleu_dataset = sacrebleu_dataset
self.sacrebleu_language_pair = sacrebleu_language_pair
self.src_file = self.ref_file = self.src_data = self.ref_data = None
super().__init__()
def download(self):
# This caches in the user's home dir automatically
self.src_file, self.ref_file = \
sacrebleu.download_test_set(self.sacrebleu_dataset, self.sacrebleu_language_pair)
self.src_data, self.ref_data = [
[line.rstrip() for line in sacrebleu.smart_open(file)]
for file in (self.src_file, self.ref_file)
]
def has_training_docs(self):
"""Whether the task has a training set"""
# TODO In the future we could be more discerning. Some more recent tests have train and dev sets
return False
def has_validation_docs(self):
"""Whether the task has a validation set"""
return False
def has_test_docs(self):
"""Whether the task has a test set"""
return True
def test_docs(self):
"""
:return: Iterable[obj]
An iterable of any object that doc_to_text can handle
"""
return [{
"src": src,
"ref": ref
} for src, ref in zip(self.src_data, self.ref_data)]
def doc_to_text(self, doc):
return doc["src"]
def doc_to_target(self, doc):
# TODO Note that some exotic tests have multiple ref lines.
# How does sacrebleu handle opening these files?
return doc["ref"]
def construct_requests(self, doc, ctx):
""" Uses RequestFactory to construct Requests and returns an iterable of
Requests which will be sent to the LM.
:param doc:
The document as returned from training_docs, validation_docs, or test_docs.
:param ctx: str
The context string, generated by fewshot_context. This includes the natural
language description, as well as the few shot examples, and the question
part of the document for `doc`.
"""
return rf.greedy_until(ctx, ["\n"])
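# i.e. generate greedily and stop at the first newline; the single-line
# continuation is taken as the model's translation of doc["src"].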
def process_results(self, doc, results):
# These metrics are corpus-level, not sentence-level, so we'll hide the
# results in this dict and compute the corpus score in the aggregate method
ref_pred = (doc["ref"], results)
return {
"bleu": ref_pred,
"chrf": ref_pred,
"ter": ref_pred,
}
def aggregation(self):
"""
:returns: {str: [float] -> float}
A dictionary where keys are the names of submetrics and values are
functions that aggregate a list of metrics
"""
return {
"bleu": metrics.bleu,
"chrf": metrics.chrf,
"ter": metrics.ter,
}
def higher_is_better(self):
"""
:returns: {str: bool}
A dictionary where keys are the names of submetrics and values are
whether a higher value of the submetric is better
"""
return {
"bleu": True,
"chrf": True,
"ter": False,
}
def fewshot_description(self):
language_codes = self.sacrebleu_language_pair.split("-")
src_lang = code_to_language(language_codes[0])
tar_lang = code_to_language(language_codes[1])
return f"Translate these {src_lang} phrases to {tar_lang}."
# TODO This should be something like
# French: {src_line}
# English: {ref_line}
def fewshot_context(self, doc, num_fewshot, provide_description):
return ""
def __str__(self):
language_codes = self.sacrebleu_language_pair.split("-")
src_lang = code_to_language(language_codes[0])
tar_lang = code_to_language(language_codes[1])
return f"{self.sacrebleu_dataset.upper()} {src_lang} to {tar_lang} Task"
########################################
# Util
########################################
def code_to_language(code):
# key is alpha_2 or alpha_3 depending on the code length
language_tuple = pycountry.languages.get(**{f"alpha_{len(code)}": code})
return language_tuple.name
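# Illustrative use of the factories above (the variable names here are made up):
#   tasks_by_name = create_tasks_from_benchmarks(gpt3_benchmarks)
#   fr_en = tasks_by_name["wmt14-fr-en"]()   # keys follow f"{dataset}-{language_pair}"
#   fr_en.download()                         # fetches/caches the test set via sacrebleu
#   print(fr_en)                             # -> "WMT14 French to English Task"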
def print_available_tests():
pprint({ts: sacrebleu.get_langpairs_for_testset(ts) for ts in sacrebleu.get_available_testsets()})
def main():
# print(sacrebleu.download_test_set("wmt14", "en-fr"))
# print_available_tests()
# sacrebleu.print_test_set("wmt14", "fr-en", "src")
# # Print number of benchmarks
# print(sum([
# len(sacrebleu.get_langpairs_for_testset(ts))
# for ts in sacrebleu.get_available_testsets()
# ]))
# Test task dictionary
# for task, task_class in create_tasks_from_benchmarks(selected_benchmarks).items():
# print(task, task_class())
pass
if __name__ == "__main__":
main()
########################################
# Don't mind me...!
########################################
# Available tests as of 2020/02/11
"""
{'iwslt17': ['en-fr',
'fr-en',
'en-de',
'de-en',
'en-zh',
'zh-en',
'en-ar',
'ar-en',
'en-ja',
'ja-en',
'en-ko',
'ko-en'],
'iwslt17/dev2010': ['en-fr', 'fr-en', 'en-de', 'de-en', 'en-zh', 'zh-en'],
'iwslt17/tst2010': ['en-fr', 'fr-en', 'en-de', 'de-en', 'en-zh', 'zh-en'],
'iwslt17/tst2011': ['en-fr', 'fr-en', 'en-de', 'de-en', 'en-zh', 'zh-en'],
'iwslt17/tst2012': ['en-fr', 'fr-en', 'en-de', 'de-en', 'en-zh', 'zh-en'],
'iwslt17/tst2013': ['en-fr', 'fr-en', 'en-de', 'de-en', 'en-zh', 'zh-en'],
'iwslt17/tst2014': ['en-fr', 'fr-en', 'en-de', 'de-en', 'en-zh', 'zh-en'],
'iwslt17/tst2015': ['en-fr', 'fr-en', 'en-de', 'de-en', 'en-zh', 'zh-en'],
'iwslt17/tst2016': ['en-fr', 'fr-en', 'en-de', 'de-en', 'en-zh', 'zh-en'],
'mtnt1.1/test': ['en-fr', 'fr-en', 'en-ja', 'ja-en'],
'mtnt1.1/train': ['en-fr', 'fr-en', 'en-ja', 'ja-en'],
'mtnt1.1/valid': ['en-fr', 'fr-en', 'en-ja', 'ja-en'],
'mtnt2019': ['en-fr', 'fr-en', 'en-ja', 'ja-en'],
'multi30k/2016': ['en-fr', 'en-de', 'en-cs'],
'multi30k/2017': ['en-fr', 'en-de'],
'multi30k/2018': ['en-fr', 'en-de'],
'wmt08': ['cs-en',
'en-cs',
'de-en',
'en-de',
'es-en',
'en-es',
'fr-en',
'en-fr',
'hu-en',
'en-hu'],
'wmt08/europarl': ['de-en', 'en-de', 'es-en', 'en-es', 'fr-en', 'en-fr'],
'wmt08/nc': ['cs-en', 'en-cs'],
'wmt09': ['cs-en',
'en-cs',
'de-en',
'en-de',
'es-en',
'en-es',
'fr-en',
'en-fr',
'hu-en',
'en-hu',
'it-en',
'en-it'],
'wmt10': ['cs-en',
'en-cs',
'de-en',
'en-de',
'es-en',
'en-es',
'fr-en',
'en-fr'],
'wmt11': ['cs-en',
'en-cs',
'de-en',
'en-de',
'fr-en',
'en-fr',
'es-en',
'en-es'],
'wmt12': ['cs-en',
'en-cs',
'de-en',
'en-de',
'es-en',
'en-es',
'fr-en',
'en-fr'],
'wmt13': ['cs-en',
'en-cs',
'de-en',
'en-de',
'es-en',
'en-es',
'fr-en',
'en-fr',
'ru-en',
'en-ru'],
'wmt14': ['cs-en',
'en-cs',
'de-en',
'en-de',
'en-fr',
'fr-en',
'en-hi',
'hi-en',
'en-ru',
'ru-en'],
'wmt14/full': ['cs-en',
'en-cs',
'de-en',
'en-de',
'en-fr',
'fr-en',
'en-hi',
'hi-en',
'en-ru',
'ru-en'],
'wmt15': ['en-fr',
'fr-en',
'cs-en',
'de-en',
'en-cs',
'en-de',
'en-fi',
'en-ru',
'fi-en',
'ru-en'],
'wmt16': ['cs-en',
'de-en',
'en-cs',
'en-de',
'en-fi',
'en-ro',
'en-ru',
'en-tr',
'fi-en',
'ro-en',
'ru-en',
'tr-en'],
'wmt16/B': ['en-fi'],
'wmt16/dev': ['en-ro', 'en-tr', 'ro-en', 'tr-en'],
'wmt16/tworefs': ['en-fi'],
'wmt17': ['cs-en',
'de-en',
'en-cs',
'en-de',
'en-fi',
'en-lv',
'en-ru',
'en-tr',
'en-zh',
'fi-en',
'lv-en',
'ru-en',
'tr-en',
'zh-en'],
'wmt17/B': ['en-fi'],
'wmt17/dev': ['en-lv', 'en-zh', 'lv-en', 'zh-en'],
'wmt17/improved': ['en-zh', 'zh-en'],
'wmt17/ms': ['zh-en'],
'wmt17/tworefs': ['en-fi'],
'wmt18': ['cs-en',
'de-en',
'en-cs',
'en-de',
'en-et',
'en-fi',
'en-ru',
'et-en',
'fi-en',
'ru-en',
'en-tr',
'tr-en',
'en-zh',
'zh-en'],
'wmt18/dev': ['et-en', 'en-et'],
'wmt18/test-ts': ['cs-en',
'de-en',
'en-cs',
'en-de',
'en-et',
'en-fi',
'en-ru',
'et-en',
'fi-en',
'ru-en',
'en-tr',
'tr-en',
'en-zh',
'zh-en'],
'wmt19': ['cs-de',
'de-cs',
'de-en',
'de-fr',
'en-cs',
'en-de',
'en-fi',
'en-gu',
'en-kk',
'en-lt',
'en-ru',
'en-zh',
'fi-en',
'fr-de',
'gu-en',
'kk-en',
'lt-en',
'ru-en',
'zh-en'],
'wmt19/dev': ['lt-en', 'en-lt', 'gu-en', 'en-gu', 'kk-en', 'en-kk'],
'wmt19/google/ar': ['en-de'],
'wmt19/google/arp': ['en-de'],
'wmt19/google/hqall': ['en-de'],
'wmt19/google/hqp': ['en-de'],
'wmt19/google/hqr': ['en-de'],
'wmt19/google/wmtp': ['en-de'],
'wmt20': ['cs-en',
'de-en',
'de-fr',
'en-cs',
'en-de',
'en-iu',
'en-ja',
'en-km',
'en-pl',
'en-ps',
'en-ru',
'en-ta',
'en-zh',
'fr-de',
'iu-en',
'ja-en',
'km-en',
'pl-en',
'ps-en',
'ru-en',
'ta-en',
'zh-en'],
'wmt20/dev': ['iu-en',
'en-iu',
'ja-en',
'en-ja',
'pl-en',
'en-pl',
'ta-en',
'en-ta'],
'wmt20/robust/set1': ['en-ja', 'en-de'],
'wmt20/robust/set2': ['en-ja', 'ja-en'],
'wmt20/robust/set3': ['de-en'],
'wmt20/tworefs': ['de-en', 'en-de', 'en-zh', 'ru-en', 'zh-en']}
"""
\ No newline at end of file
import os
import json
import random
from lm_eval.base import Task, mean, rf
from lm_eval.base import Task, rf
from ..metrics import mean
from ..utils import sh
class TriviaQA(Task):
......
from . common import HFTask
from lm_eval.base import mean, rf
from lm_eval.base import rf
from ..metrics import mean
class WebQs(HFTask):
DATASET_PATH = "web_questions"
......@@ -19,7 +21,7 @@ class WebQs(HFTask):
return ""
def doc_to_text(self, doc):
return "Q: " + doc['question'] + '\nA:'
return "Question: " + doc['question'] + '\nAnswer:'
def doc_to_target(self, doc):
# this picks one answer to be the "correct" one, despite sometimes
......
import numpy as np
from . common import HFTask
from lm_eval.base import rf, mean
from lm_eval.base import rf
from ..metrics import mean
"""
This evaluation of Winogrande uses partial evaluation as described by
......@@ -13,6 +14,8 @@ class Winogrande(HFTask):
DATASET_PATH = "winogrande"
DATASET_NAME = "winogrande_xl"
answer_to_num = {'1': 0, '2': 1}
def has_training_docs(self):
return True
......@@ -20,54 +23,59 @@ class Winogrande(HFTask):
return True
def has_test_docs(self):
return True
return False
def doc_to_text(self, doc):
return self.partial_context(doc, doc["option" + doc["answer"]])
def fewshot_description(self):
# TODO: redo description
return "Winograd schema sentence including a either a ___ blank with a missing word, making the pronoun ambiguous, or the same with the word filled in."
@classmethod
def partial_context(cls, doc):
# Substitute the pronoun in the sentence with each candidate choice
def partial_context(cls, doc, option):
# Substitute the pronoun in the sentence with the specified option
# and ignore everything after.
pronoun_loc = doc["sentence"].index("_")
context1 = doc["sentence"][:pronoun_loc] + doc["option1"]
context2 = doc["sentence"][:pronoun_loc] + doc["option2"]
return context1, context2
return doc["sentence"][:pronoun_loc] + option
def doc_to_target(self, doc):
return self.partial_target(doc)
@classmethod
def partial_target(cls, doc):
# The target is everything after the document specified pronoun.
pronoun_loc = doc["sentence"].index("_") + 1
return doc["sentence"][pronoun_loc:].strip()
def doc_to_text(self, doc):
context1, context2 = self.partial_context(doc)
return context1 + '\n' + context2 + '\n'
def doc_to_target(self, doc):
return self.partial_target(doc)
return " " + doc["sentence"][pronoun_loc:].strip()
def construct_requests(self, doc, ctx):
""" Uses RequestFactory to construct Requests and returns an iterable of
"""Uses RequestFactory to construct Requests and returns an iterable of
Requests which will be sent to the LM.
:param doc:
The document as returned from training_docs, validation_docs, or test_docs.
:param ctx: str
The context string, generated by fewshot_context. This includes the natural
The context string, generated by fewshot_context. This includes the natural
language description, as well as the few shot examples, and the question
part of the document for `doc`.
part of the document for `doc`.
"""
target = self.partial_target(doc)
context1, context2 = self.partial_context(doc)
ll_context1, _ = rf.loglikelihood(context1, " " + target)
ll_context2, _ = rf.loglikelihood(context2, " " + target)
return ll_context1, ll_context2
lls = []
for option in [doc["option1"], doc["option2"]]:
partial_ctx = self.partial_context(doc, option)
full_ctx = self.append_context(ctx, partial_ctx)
lls.append(rf.loglikelihood(full_ctx, target)[0])
return lls
@classmethod
def append_context(cls, ctx, partial_ctx):
ctx = ctx.split("\n\n") # Each fewshot context is on its own new line.
ctx.pop() # Remove the correct context put in by `doc_to_text`.
return "\n\n".join([*ctx, partial_ctx]) if ctx else partial_ctx
def process_results(self, doc, results):
"""Take a single document and the LM results and evaluates, returning a
dict where keys are the names of submetrics and values are the values of
"""Take a single document and the LM results and evaluates, returning a
dict where keys are the names of submetrics and values are the values of
the metric for that one document
:param doc:
......@@ -75,15 +83,14 @@ class Winogrande(HFTask):
:param results:
The results of the requests created in construct_requests.
"""
answer = int(doc["answer"]) - 1 # `- 1` b/c doc["answer"] ∈ {'1', '2'}
return {
"acc": np.argmax(results) == answer
"acc": np.argmax(results) == self.answer_to_num[doc["answer"]]
}
def aggregation(self):
"""
:returns: {str: [float] -> float}
A dictionary where keys are the names of submetrics and values are
functions that aggregate a list of metrics
"""
return {
......@@ -93,7 +100,7 @@ class Winogrande(HFTask):
def higher_is_better(self):
"""
:returns: {str: bool}
A dictionary where keys are the names of submetrics and values are
whether a higher value of the submetric is better
"""
return {
......
import numpy as np
import random
from lm_eval.base import rf, mean
from lm_eval.base import rf
from ..metrics import mean
from . common import HFTask
"""
......@@ -26,14 +27,14 @@ class WinogradSchemaChallenge273(HFTask):
data = []
for doc in self.data["test"]:
doc["text"] = doc["text"].replace(" ", " ")
doc["options"][0] = self.__normalize_option(doc["options"][0], doc)
doc["options"][1] = self.__normalize_option(doc["options"][1], doc)
doc["options"][0] = self.__normalize_option(doc, doc["options"][0])
doc["options"][1] = self.__normalize_option(doc, doc["options"][1])
data.append(doc)
return {"test": data}
def __normalize_option(self, option, doc):
def __normalize_option(self, doc, option):
# Append `'s` to possessive determiner based options.
if doc["pronoun"].lower() in ["my", "his", "her", "our", "their"]:
if doc["pronoun"].lower() in ["my", "his", "her", "our", "their"]:
option += "'s"
# Appropriately lowercase the pronoun in the option.
pronoun = option.split()[0]
......@@ -51,56 +52,61 @@ class WinogradSchemaChallenge273(HFTask):
def has_test_docs(self):
return True
def fewshot_description(self):
# TODO: redo description
return "Winograd schema sentence with correct continuation. True. Winograd schema sentence with incorrect continuation. False."
def fewshot_examples(self, k):
# NOTE: `super().fewshot_examples` samples from training docs which are
# not available for this test-set-only dataset.
return random.sample(list(self.test_docs()), k)
def fewshot_description(self):
# TODO: redo description
return "Winograd schema sentence with correct continuation. True. Winograd schema sentence with incorrect continuation. False."
def doc_to_text(self, doc):
return self.partial_context(doc, doc["options"][doc["label"]])
@classmethod
def partial_context(cls, doc):
# Substitute the pronoun in the original text with each candidate
# choice and ignore everything after.
context1 = doc["text"][:doc["pronoun_loc"]] + doc["options"][0]
context2 = doc["text"][:doc["pronoun_loc"]] + doc["options"][1]
return context1, context2
def partial_context(cls, doc, option):
# Substitute the pronoun in the original text with the specified
# option and ignore everything after.
return doc["text"][:doc["pronoun_loc"]] + option
def doc_to_target(self, doc):
return self.partial_target(doc)
@classmethod
def partial_target(cls, doc):
# The target is everything after the document specified pronoun.
start_index = doc["pronoun_loc"] + len(doc["pronoun"])
return doc["text"][start_index:].strip()
def doc_to_text(self, doc):
context1, context2 = self.partial_context(doc)
return context1 + '\n' + context2 + '\n'
def doc_to_target(self, doc):
return self.partial_target(doc)
return " " + doc["text"][start_index:].strip()
def construct_requests(self, doc, ctx):
""" Uses RequestFactory to construct Requests and returns an iterable of
"""Uses RequestFactory to construct Requests and returns an iterable of
Requests which will be sent to the LM.
:param doc:
The document as returned from training_docs, validation_docs, or test_docs.
:param ctx: str
The context string, generated by fewshot_context. This includes the natural
The context string, generated by fewshot_context. This includes the natural
language description, as well as the few shot examples, and the question
part of the document for `doc`.
part of the document for `doc`.
"""
target = self.partial_target(doc)
context1, context2 = self.partial_context(doc)
ll_context1, _ = rf.loglikelihood(context1, " " + target)
ll_context2, _ = rf.loglikelihood(context2, " " + target)
return ll_context1, ll_context2
lls = []
for option in doc["options"]:
partial_ctx = self.partial_context(doc, option)
full_ctx = self.append_context(ctx, partial_ctx)
lls.append(rf.loglikelihood(full_ctx, target)[0])
return lls
@classmethod
def append_context(cls, ctx, partial_ctx):
ctx = ctx.split("\n\n") # Each fewshot context is on its own new line.
ctx.pop() # Remove the correct context put in by `doc_to_text`.
return "\n\n".join([*ctx, partial_ctx]) if ctx else partial_ctx
def process_results(self, doc, results):
"""Take a single document and the LM results and evaluates, returning a
dict where keys are the names of submetrics and values are the values of
"""Take a single document and the LM results and evaluates, returning a
dict where keys are the names of submetrics and values are the values of
the metric for that one document
:param doc:
......@@ -115,7 +121,7 @@ class WinogradSchemaChallenge273(HFTask):
def aggregation(self):
"""
:returns: {str: [float] -> float}
A dictionary where keys are the names of submetrics and values are
functions that aggregate a list of metrics
"""
return {
......@@ -125,7 +131,7 @@ class WinogradSchemaChallenge273(HFTask):
def higher_is_better(self):
"""
:returns: {str: bool}
A dictionary where keys are the names of submetrics and values are
whether a higher value of the submetric is better
"""
return {
......
import os
import re
class ExitCodeError(Exception):
......@@ -39,4 +40,13 @@ def chunks(iter, n):
yield arr
arr = []
if arr: yield arr
def general_detokenize(string):
string = string.replace(" n't", "n't")
string = string.replace(" )", ")")
string = string.replace("( ", "(")
string = string.replace("\" ", "\"")
string = string.replace(" \"", "\"")
string = re.sub(r" (['.,])", r"\1", string)
return string
\ No newline at end of file
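A small self-contained check of the detokenizer above; the example sentence is made up:
if __name__ == "__main__":
    # Spaces before n't, around parentheses/quotes, and before ' . , are stitched back.
    assert general_detokenize("He did n't say ( much ) .") == "He didn't say (much)."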
......@@ -16,11 +16,11 @@ def parse_args():
parser.add_argument('--model_args', default="")
parser.add_argument('--tasks', default="all_tasks")
parser.add_argument('--provide_description', action="store_true")
parser.add_argument('--num_fewshot', type=int, default=1)
parser.add_argument('--num_fewshot', type=int, default=0)
parser.add_argument('--seed', type=int, default=1234)
parser.add_argument('--output_path', default=None)
parser.add_argument('--limit', type=int, default=None)
parser.add_argument('--cache', action="store_true")
parser.add_argument('--no_cache', action="store_true")
return parser.parse_args()
def main():
......@@ -31,7 +31,7 @@ def main():
lm = models.get_model(args.model).create_from_arg_string(args.model_args)
if args.cache:
if not args.no_cache:
lm = base.CachingLM(lm, 'lm_cache/' + args.model + '_' + args.model_args.replace('=', '-').replace(',', '_') + '.db')
if args.tasks == "all_tasks":
task_names = tasks.ALL_TASKS
......@@ -47,6 +47,21 @@ def main():
with open(args.output_path, "w") as f:
f.write(dumped)
# MAKE TABLE
from pytablewriter import MarkdownTableWriter
writer = MarkdownTableWriter()
writer.headers = ["Task", "Metric", "Value"]
values = []
for k, dic in results.items():
for m, v in dic.items():
values.append([k, m, '%.4f' % v])
k = ""
writer.value_matrix = values
print(writer.dumps())
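# Rough shape of the printed table (values illustrative; exact alignment depends on pytablewriter):
# | Task   | Metric | Value  |
# |--------|--------|--------|
# | task_a | acc    | 0.1234 |
# |        | ppl    | 5.6789 |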
if __name__ == "__main__":
main()
......@@ -5,4 +5,7 @@ click>=7.1
scikit-learn>=0.24.1
torch>=1.7
transformers>=4.1
sqlitedict==1.6.0
pytablewriter==0.58.0
sacrebleu==1.5.0
pycountry==20.7.3
\ No newline at end of file
......@@ -11,3 +11,11 @@ def test_gpt2():
assert ll_dog > ll_cat
assert not ig_cat
# test empty context
gpt2.loglikelihood([('', 'test')])
gen, = gpt2.greedy_until([
('The quick brown fox jumps over the lazy', ['.', '\n'])
])
assert gen == ', lazy fox and they both fall to the ground'
\ No newline at end of file
......@@ -22,6 +22,33 @@ def test_basic_interface(taskname, Task):
for v in task.higher_is_better().values(): assert v in [True, False]
# test deterministic docs
# (don't test train because it's slow)
task2 = Task()
if task.has_validation_docs():
arr = list(islice(task.validation_docs(), 100))
arr2 = list(islice(task2.validation_docs(), 100))
assert arr == arr2
reqs = [task.construct_requests(doc, task.doc_to_text(doc)) for doc in arr]
reqs2 = [task2.construct_requests(doc, task2.doc_to_text(doc)) for doc in arr2]
assert reqs == reqs2
if task.has_test_docs():
arr = list(islice(task.test_docs(), 100))
arr2 = list(islice(task2.test_docs(), 100))
assert arr == arr2
reqs = [task.construct_requests(doc, task.doc_to_text(doc)) for doc in arr]
reqs2 = [task2.construct_requests(doc, task2.doc_to_text(doc)) for doc in arr2]
assert reqs == reqs2
@pytest.mark.parametrize("taskname,Task", tasks.TASK_REGISTRY.items())
def test_documents_and_requests(taskname, Task):
......