Commit 4d147bdd authored by Jonathan Tow's avatar Jonathan Tow
Browse files

Merge branch 'master' of https://github.com/EleutherAI/lm-evaluation-harness into task-guide

parents 011cc891 dc937d4b
"""
QuAC: Question Answering in Context
https://arxiv.org/abs/1808.07036
@article{choi2018quac,
title={Quac: Question answering in context},
author={Choi, Eunsol and He, He and Iyyer, Mohit and Yatskar, Mark and Yih, Wen-tau and Choi, Yejin and Liang, Percy and Zettlemoyer, Luke},
journal={arXiv preprint arXiv:1808.07036},
year={2018}
}
"""
import json
import os
from lm_eval.base import Task
from ..utils import sh
class QuAC(Task):
class QuAC(Task):
VERSION = 0
def __init__(self):
super().__init__()
def download(self):
if not os.path.exists('data/quac'):
# TODO: convert to use best_download
sh("""
mkdir -p data/quac
wget https://s3.amazonaws.com/my89public/quac/train_v0.2.json -O data/quac/train_v0.2.json
......
......@@ -15,6 +15,7 @@ class each:
class RACE(HFTask):
VERSION = 0
DATASET_PATH = "race"
DATASET_NAME = "high"
......
......@@ -3,6 +3,7 @@ from lm_eval.base import MultipleChoiceTask
class SATAnalogies(MultipleChoiceTask):
VERSION = 0
NEEDS_MANUAL_DL = True
def __init__(self):
......
......@@ -6,6 +6,7 @@ from best_download import download_file
class SciQ(MultipleChoiceTask):
VERSION = 0
# Multiple languages and multiple years
def download(self):
if not os.path.exists('data/sciq'):
......
import datasets
from math import exp
from lm_eval.base import rf
from lm_eval.metrics import f1_score, mean
from . common import HFTask
from functools import partial
from packaging import version
class SQuAD(HFTask):
def _squad_metric(predictions, references):
squad_metric = datasets.load_metric("squad_v2")
return squad_metric.compute(predictions=predictions, references=references)
def _squad_agg(key, items):
predictions, references = zip(*items)
return _squad_metric(predictions=predictions, references=references)[key]
class SQuAD2(HFTask):
VERSION = 1
DATASET_PATH = "squad_v2"
DATASET_NAME = None
# HF changed squad on us so we have to make sure we aren't running the old one
assert version.parse(datasets.__version__) >= version.parse("1.11.0"), "datasets v1.11.0 or later required for SQuAD"
def has_training_docs(self):
return True
......@@ -15,16 +36,14 @@ class SQuAD(HFTask):
return False
def training_docs(self):
if self.has_training_docs():
return self.data["train"]
return self.data["train"]
def validation_docs(self):
if self.has_validation_docs():
return self.data["validation"]
return self.data["validation"]
def fewshot_description(self):
# TODO: redo description
return "Title: The_Title_of_It\n\nBackground: A text passage as background to answer the question with.\n\nQ: Question about the passage.\n\nA: Answer."
# TODO: figure out description
return ""
def doc_to_text(self, doc):
return 'Title: ' + doc['title'] + '\n\n' + 'Background: ' + doc['context'] + '\n\n' + 'Question: ' + doc['question'] + '\n\n' + 'Answer:'
......@@ -35,7 +54,7 @@ class SQuAD(HFTask):
answer = answer_list[0]
else:
answer = 'unanswerable'
return answer
return " " + answer
def construct_requests(self, doc, ctx):
""" Uses RequestFactory to construct Requests and returns an iterable of
......@@ -48,8 +67,9 @@ class SQuAD(HFTask):
language description, as well as the few shot examples, and the question
part of the document for `doc`.
"""
# TODO: implement evaluation.
raise NotImplementedError('Evaluation not implemented')
continuation = rf.greedy_until(ctx, ['\n'])
is_unanswerable = rf.loglikelihood(ctx, " " + "unanswerable")
return continuation, is_unanswerable
def process_results(self, doc, results):
"""Take a single document and the LM results and evaluates, returning a
......@@ -61,8 +81,31 @@ class SQuAD(HFTask):
:param results:
The results of the requests created in construct_requests.
"""
# TODO: implement evaluation.
raise NotImplementedError('Evaluation not implemented')
continuation, (logprob_unanswerable, _) = results
no_answer_probability = exp(logprob_unanswerable)
predictions = {
'id': doc['id'],
'prediction_text': continuation,
'no_answer_probability': no_answer_probability,
}
references = {
'id': doc['id'],
'answers': doc['answers'],
}
return {
'exact': (predictions, references), # Exact match (the normalized answer exactly match the gold answer)
'f1': (predictions, references), # The F-score of predicted tokens versus the gold answer
'HasAns_exact': (predictions, references), # Exact match (the normalized answer exactly match the gold answer)
'HasAns_f1': (predictions, references), # The F-score of predicted tokens versus the gold answer
'NoAns_exact': (predictions, references), # Exact match (the normalized answer exactly match the gold answer)
'NoAns_f1': (predictions, references), # The F-score of predicted tokens versus the gold answer
'best_exact': (predictions, references), # Best exact match (with varying threshold)
'best_f1': (predictions, references), # Best F1 (with varying threshold)
}
def aggregation(self):
"""
......@@ -70,8 +113,16 @@ class SQuAD(HFTask):
A dictionary where keys are the names of submetrics and values are
functions that aggregate a list of metrics
"""
# TODO: implement evaluation.
raise NotImplementedError('Evaluation not implemented')
return {
'exact': partial(_squad_agg, 'exact'), # Exact match (the normalized answer exactly match the gold answer)
'f1': partial(_squad_agg, 'f1'), # The F-score of predicted tokens versus the gold answer
'HasAns_exact': partial(_squad_agg, 'HasAns_exact'), # Exact match (the normalized answer exactly match the gold answer)
'HasAns_f1': partial(_squad_agg, 'HasAns_f1'), # The F-score of predicted tokens versus the gold answer
'NoAns_exact': partial(_squad_agg, 'NoAns_exact'), # Exact match (the normalized answer exactly match the gold answer)
'NoAns_f1': partial(_squad_agg, 'NoAns_f1'), # The F-score of predicted tokens versus the gold answer
'best_exact': partial(_squad_agg, 'best_exact'), # Best exact match (with varying threshold)
'best_f1': partial(_squad_agg, 'best_f1'), # Best F1 (with varying threshold)
}
def higher_is_better(self):
"""
......@@ -79,5 +130,13 @@ class SQuAD(HFTask):
A dictionary where keys are the names of submetrics and values are
whether a higher value of the submetric is better
"""
# TODO: implement evaluation.
raise NotImplementedError('Evaluation not implemented')
return {
'exact': True, # Exact match (the normalized answer exactly match the gold answer)
'f1': True, # The F-score of predicted tokens versus the gold answer
'HasAns_exact': True, # Exact match (the normalized answer exactly match the gold answer)
'HasAns_f1': True, # The F-score of predicted tokens versus the gold answer
'NoAns_exact': True, # Exact match (the normalized answer exactly match the gold answer)
'NoAns_f1': True, # The F-score of predicted tokens versus the gold answer
'best_exact': True, # Best exact match (with varying threshold)
'best_f1': True, # Best F1 (with varying threshold)
}
......@@ -3,6 +3,7 @@ from lm_eval.base import Task
class StoryCloze(Task):
VERSION = 0
NEEDS_MANUAL_DL = True
def download(self):
......
......@@ -13,6 +13,7 @@ from ..utils import general_detokenize
class BoolQ(HFTask):
VERSION = 0
DATASET_PATH = "super_glue"
DATASET_NAME = "boolq"
......@@ -64,6 +65,7 @@ class BoolQ(HFTask):
class CommitmentBank(HFTask):
VERSION = 0
DATASET_PATH = "super_glue"
DATASET_NAME = "cb"
......@@ -135,6 +137,7 @@ class CommitmentBank(HFTask):
class Copa(HFTask):
VERSION = 0
DATASET_PATH = "super_glue"
DATASET_NAME = "copa"
......@@ -199,6 +202,7 @@ class Copa(HFTask):
class MultiRC(HFTask):
VERSION = 0
DATASET_PATH = "super_glue"
DATASET_NAME = "multirc"
......@@ -253,6 +257,7 @@ class MultiRC(HFTask):
class ReCoRD(HFTask):
VERSION = 0
DATASET_PATH = "super_glue"
DATASET_NAME = "record"
......@@ -345,6 +350,7 @@ class ReCoRD(HFTask):
class WordsInContext(HFTask):
VERSION = 0
DATASET_PATH = "super_glue"
DATASET_NAME = "wic"
......@@ -400,6 +406,7 @@ class WordsInContext(HFTask):
class SGWinogradSchemaChallenge(HFTask):
VERSION = 0
# Note: This implementation differs from Fig G.32 because this is the SuperGLUE,
# binary version of the task.
DATASET_PATH = "super_glue"
......@@ -412,7 +419,7 @@ class SGWinogradSchemaChallenge(HFTask):
return True
def has_test_docs(self):
return True
return False
def training_docs(self):
if self.has_training_docs():
......
......@@ -3,6 +3,9 @@ from pprint import pprint
from sacrebleu import sacrebleu
from lm_eval import metrics
from lm_eval.base import Task, rf
from typing import List
"""
This file implements translation tasks using datasets from WMT conferences, provided by sacrebleu.
......@@ -19,23 +22,46 @@ def create_tasks_from_benchmarks(benchmark_dict):
:return: {task_name: task}
e.g. {wmt14-fr-en: Task, wmt16-de-en: Task}
"""
def version_of(dataset, language_pair):
if language_pair[-2:] in ["zh", "ja"]:
return 1 # changed to use jieba/nagisa
return 0
return {
f"{dataset}-{language_pair}": create_translation_task(dataset, language_pair)
f"{dataset}-{language_pair}": create_translation_task(dataset, language_pair, version_of(dataset, language_pair))
for dataset, language_pairs in benchmark_dict.items()
for language_pair in language_pairs
}
########################################
# Language Specifics
########################################
def zh_split(zh_text: List[str]) -> List[str]:
"""Chinese splitting"""
import jieba
return [" ".join(jieba.cut(txt.strip())) for txt in zh_text]
def ja_split(ja_text: List[str]) -> List[str]:
"""Japanese splitting"""
import nagisa
return [" ".join(nagisa.tagging(txt.strip()).words) for txt in ja_text]
NO_SPACE_LANG = {"zh": zh_split, "ja": ja_split}
########################################
# Tasks
########################################
def create_translation_task(dataset, language_pair):
def create_translation_task(dataset, language_pair, version=0):
class TranslationTask(GeneralTranslationTask):
VERSION = version
def __init__(self):
super().__init__(dataset, language_pair)
return TranslationTask
class GeneralTranslationTask(Task):
VERSION = 0
# e.g. ("wmt14", "fr-en")
def __init__(self, sacrebleu_dataset, sacrebleu_language_pair=None):
......@@ -101,6 +127,12 @@ class GeneralTranslationTask(Task):
return rf.greedy_until(ctx, ["\n"])
def process_results(self, doc, results):
# Add spaces between words for BLEU score calculation of target languages like Chinese
tar_lang_code = self.sacrebleu_language_pair.split("-")[-1]
if tar_lang_code in NO_SPACE_LANG:
doc["ref"] = NO_SPACE_LANG[tar_lang_code]([doc["ref"]])[0]
results = NO_SPACE_LANG[tar_lang_code](results)
# These metrics are corpus-level not sentence level, so we'll hide the
# results in this dict and compute the corpus score in the aggregate method
ref_pred = (doc["ref"], results)
......@@ -156,283 +188,3 @@ def code_to_language(code):
# key is alpha_2 or alpha_3 depending on the code length
language_tuple = pycountry.languages.get(**{f"alpha_{len(code)}": code})
return language_tuple.name
def print_available_tests():
pprint({ts: sacrebleu.get_langpairs_for_testset(ts) for ts in sacrebleu.get_available_testsets()})
def print_available_pairs():
list_of_pairs = [sacrebleu.get_langpairs_for_testset(ts) for ts in sacrebleu.get_available_testsets()]
pairs = set([item for sublist in list_of_pairs for item in sublist])
pairs = sorted(["-".join(map(code_to_language, pair.split("-"))) for pair in pairs])
pprint(pairs)
print(len(pairs))
def main():
# print(sacrebleu.download_test_set("wmt14", "en-fr"))
# print_available_tests()
# sacrebleu.print_test_set("wmt14", "fr-en", "src")
# # Print number of benchmarks
# print(sum([
# len(sacrebleu.get_langpairs_for_testset(ts))
# for ts in sacrebleu.get_available_testsets()
# ]))
# Test task dictionary
# for task, task_class in create_tasks_from_benchmarks(selected_benchmarks).items():
# print(task, task_class())
print_available_pairs()
pass
if __name__ == "__main__":
main()
########################################
# Don't mind me...!
########################################
# Available tests as of 2020/02/11
"""
{'iwslt17': ['en-fr',
'fr-en',
'en-de',
'de-en',
'en-zh',
'zh-en',
'en-ar',
'ar-en',
'en-ja',
'ja-en',
'en-ko',
'ko-en'],
'iwslt17/dev2010': ['en-fr', 'fr-en', 'en-de', 'de-en', 'en-zh', 'zh-en'],
'iwslt17/tst2010': ['en-fr', 'fr-en', 'en-de', 'de-en', 'en-zh', 'zh-en'],
'iwslt17/tst2011': ['en-fr', 'fr-en', 'en-de', 'de-en', 'en-zh', 'zh-en'],
'iwslt17/tst2012': ['en-fr', 'fr-en', 'en-de', 'de-en', 'en-zh', 'zh-en'],
'iwslt17/tst2013': ['en-fr', 'fr-en', 'en-de', 'de-en', 'en-zh', 'zh-en'],
'iwslt17/tst2014': ['en-fr', 'fr-en', 'en-de', 'de-en', 'en-zh', 'zh-en'],
'iwslt17/tst2015': ['en-fr', 'fr-en', 'en-de', 'de-en', 'en-zh', 'zh-en'],
'iwslt17/tst2016': ['en-fr', 'fr-en', 'en-de', 'de-en', 'en-zh', 'zh-en'],
'mtnt1.1/test': ['en-fr', 'fr-en', 'en-ja', 'ja-en'],
'mtnt1.1/train': ['en-fr', 'fr-en', 'en-ja', 'ja-en'],
'mtnt1.1/valid': ['en-fr', 'fr-en', 'en-ja', 'ja-en'],
'mtnt2019': ['en-fr', 'fr-en', 'en-ja', 'ja-en'],
'multi30k/2016': ['en-fr', 'en-de', 'en-cs'],
'multi30k/2017': ['en-fr', 'en-de'],
'multi30k/2018': ['en-fr', 'en-de'],
'wmt08': ['cs-en',
'en-cs',
'de-en',
'en-de',
'es-en',
'en-es',
'fr-en',
'en-fr',
'hu-en',
'en-hu'],
'wmt08/europarl': ['de-en', 'en-de', 'es-en', 'en-es', 'fr-en', 'en-fr'],
'wmt08/nc': ['cs-en', 'en-cs'],
'wmt09': ['cs-en',
'en-cs',
'de-en',
'en-de',
'es-en',
'en-es',
'fr-en',
'en-fr',
'hu-en',
'en-hu',
'it-en',
'en-it'],
'wmt10': ['cs-en',
'en-cs',
'de-en',
'en-de',
'es-en',
'en-es',
'fr-en',
'en-fr'],
'wmt11': ['cs-en',
'en-cs',
'de-en',
'en-de',
'fr-en',
'en-fr',
'es-en',
'en-es'],
'wmt12': ['cs-en',
'en-cs',
'de-en',
'en-de',
'es-en',
'en-es',
'fr-en',
'en-fr'],
'wmt13': ['cs-en',
'en-cs',
'de-en',
'en-de',
'es-en',
'en-es',
'fr-en',
'en-fr',
'ru-en',
'en-ru'],
'wmt14': ['cs-en',
'en-cs',
'de-en',
'en-de',
'en-fr',
'fr-en',
'en-hi',
'hi-en',
'en-ru',
'ru-en'],
'wmt14/full': ['cs-en',
'en-cs',
'de-en',
'en-de',
'en-fr',
'fr-en',
'en-hi',
'hi-en',
'en-ru',
'ru-en'],
'wmt15': ['en-fr',
'fr-en',
'cs-en',
'de-en',
'en-cs',
'en-de',
'en-fi',
'en-ru',
'fi-en',
'ru-en'],
'wmt16': ['cs-en',
'de-en',
'en-cs',
'en-de',
'en-fi',
'en-ro',
'en-ru',
'en-tr',
'fi-en',
'ro-en',
'ru-en',
'tr-en'],
'wmt16/B': ['en-fi'],
'wmt16/dev': ['en-ro', 'en-tr', 'ro-en', 'tr-en'],
'wmt16/tworefs': ['en-fi'],
'wmt17': ['cs-en',
'de-en',
'en-cs',
'en-de',
'en-fi',
'en-lv',
'en-ru',
'en-tr',
'en-zh',
'fi-en',
'lv-en',
'ru-en',
'tr-en',
'zh-en'],
'wmt17/B': ['en-fi'],
'wmt17/dev': ['en-lv', 'en-zh', 'lv-en', 'zh-en'],
'wmt17/improved': ['en-zh', 'zh-en'],
'wmt17/ms': ['zh-en'],
'wmt17/tworefs': ['en-fi'],
'wmt18': ['cs-en',
'de-en',
'en-cs',
'en-de',
'en-et',
'en-fi',
'en-ru',
'et-en',
'fi-en',
'ru-en',
'en-tr',
'tr-en',
'en-zh',
'zh-en'],
'wmt18/dev': ['et-en', 'en-et'],
'wmt18/test-ts': ['cs-en',
'de-en',
'en-cs',
'en-de',
'en-et',
'en-fi',
'en-ru',
'et-en',
'fi-en',
'ru-en',
'en-tr',
'tr-en',
'en-zh',
'zh-en'],
'wmt19': ['cs-de',
'de-cs',
'de-en',
'de-fr',
'en-cs',
'en-de',
'en-fi',
'en-gu',
'en-kk',
'en-lt',
'en-ru',
'en-zh',
'fi-en',
'fr-de',
'gu-en',
'kk-en',
'lt-en',
'ru-en',
'zh-en'],
'wmt19/dev': ['lt-en', 'en-lt', 'gu-en', 'en-gu', 'kk-en', 'en-kk'],
'wmt19/google/ar': ['en-de'],
'wmt19/google/arp': ['en-de'],
'wmt19/google/hqall': ['en-de'],
'wmt19/google/hqp': ['en-de'],
'wmt19/google/hqr': ['en-de'],
'wmt19/google/wmtp': ['en-de'],
'wmt20': ['cs-en',
'de-en',
'de-fr',
'en-cs',
'en-de',
'en-iu',
'en-ja',
'en-km',
'en-pl',
'en-ps',
'en-ru',
'en-ta',
'en-zh',
'fr-de',
'iu-en',
'ja-en',
'km-en',
'pl-en',
'ps-en',
'ru-en',
'ta-en',
'zh-en'],
'wmt20/dev': ['iu-en',
'en-iu',
'ja-en',
'en-ja',
'pl-en',
'en-pl',
'ta-en',
'en-ta'],
'wmt20/robust/set1': ['en-ja', 'en-de'],
'wmt20/robust/set2': ['en-ja', 'ja-en'],
'wmt20/robust/set3': ['de-en'],
'wmt20/tworefs': ['de-en', 'en-de', 'en-zh', 'ru-en', 'zh-en']}
"""
\ No newline at end of file
import os
import json
import jsonlines
from lm_eval.base import Task, rf
from ..metrics import mean
from ..utils import sh
from best_download import download_file
class TriviaQA(Task):
VERSION = 0
def download(self):
if not os.path.exists('data/triviaqa'):
if not os.path.exists('data/triviaqa/unfiltered-web-train.jsonl'):
os.makedirs("data/triviaqa/", exist_ok=True)
download_file("http://eaidata.bmk.sh/data/triviaqa-unfiltered.tar.gz", "data/triviaqa/triviaqa-unfiltered.tar.gz", "adc19b42769062d241a8fbe834c56e58598d9322eb6c614e9f33a68a2cf5523e")
sh("""
mkdir -p data/triviaqa
wget http://nlp.cs.washington.edu/triviaqa/data/triviaqa-unfiltered.tar.gz -O data/triviaqa/trivia_qa-unfiltered.tar.gz
tar -xf data/triviaqa/trivia_qa-unfiltered.tar.gz
mv triviaqa-unfiltered/ data/triviaqa/
cd data/triviaqa/
tar -xf triviaqa-unfiltered.tar.gz
""")
def has_training_docs(self):
......@@ -25,20 +28,20 @@ class TriviaQA(Task):
return False
def training_docs(self):
return json.load(open('data/triviaqa/triviaqa-unfiltered/unfiltered-web-train.json'))['Data']
return jsonlines.open('data/triviaqa/unfiltered-web-train.jsonl')
def validation_docs(self):
return json.load(open('data/triviaqa/triviaqa-unfiltered/unfiltered-web-dev.json'))['Data']
return jsonlines.open('data/triviaqa/unfiltered-web-dev.jsonl')
def test_docs(self):
return json.load(open('data/triviaqa/triviaqa-unfiltered/unfiltered-web-test.json'))['Data']
raise NotImplementedError()
def fewshot_description(self):
# TODO: figure out fewshot description
return ""
def doc_to_text(self, doc):
return ''.join(['Q:', doc['Question'], '\n\n','A:'])
return f"Question: {doc['Question']}\nAnswer:"
def doc_to_target(self, doc):
return " " + doc['Answer']['Value']
......
......@@ -14,6 +14,7 @@ def extract_gzip(gz, to):
class WordUnscrambleTask(Task):
VERSION = 0
BASE_PATH = Path("data/unscramble")
FILENAME = None
CHECKSUM = None # SHA256 Checksum.
......@@ -23,7 +24,7 @@ class WordUnscrambleTask(Task):
def download(self):
if not self.BASE_PATH.exists():
Path.mkdir(self.BASE_PATH)
Path.mkdir(self.BASE_PATH, parents=True)
file = self.BASE_PATH / self.FILENAME
if not file.exists():
rawfile = file.parent / (file.name + ".gz")
......
......@@ -4,6 +4,7 @@ from ..metrics import mean
class WebQs(HFTask):
VERSION = 0
DATASET_PATH = "web_questions"
DATASET_NAME = None
......
from . common import HFTask
class WikiText103(HFTask):
NLP_PATH = "wikitext"
NLP_NAME = "wikitext-103-raw-v1"
import os
import re
from lm_eval.base import rf, PerplexityTask
from lm_eval.utils import sh
from best_download import download_file
def wikitext_detokenizer(string):
# contractions
string = string.replace("s '", "s'")
string = re.sub(r"/' [0-9]/", r"/'[0-9]/", string)
# number separators
string = string.replace(" @-@ ", "-")
string = string.replace(" @,@ ", ",")
string = string.replace(" @.@ ", ".")
# punctuation
string = string.replace(" : ", ": ")
string = string.replace(" ; ", "; ")
string = string.replace(" . ", ". ")
string = string.replace(" ! ", "! ")
string = string.replace(" ? ", "? ")
string = string.replace(" , ", ", ")
# double brackets
string = re.sub(r"\(\s*([^\)]*?)\s*\)", r"(\1)", string)
string = re.sub(r"\[\s*([^\]]*?)\s*\]", r"[\1]", string)
string = re.sub(r"{\s*([^}]*?)\s*}", r"{\1}", string)
string = re.sub(r"\"\s*([^\"]*?)\s*\"", r'"\1"', string)
string = re.sub(r"'\s*([^']*?)\s*'", r"'\1'", string)
# miscellaneous
string = string.replace("= = = =", "====")
string = string.replace("= = =", "===")
string = string.replace("= =", "==")
string = string.replace(" " + chr(176) + " ", chr(176))
string = string.replace(" \n", "\n")
string = string.replace("\n ", "\n")
string = string.replace(" N ", " 1 ")
string = string.replace(" 's", "'s")
return string
class WikiText(PerplexityTask):
VERSION = 0
def download(self):
if not os.path.exists('data/wikitext/wikitext-2-raw/wiki.valid.raw'):
os.makedirs("data/wikitext/", exist_ok=True)
download_file("https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip", "data/wikitext/wikitext-2-raw-v1.zip", "ef7edb566e3e2b2d31b29c1fdb0c89a4cc683597484c3dc2517919c615435a11")
sh("cd data/wikitext/ && unzip wikitext-2-raw-v1.zip")
def fewshot_description(self):
# TODO: figure out fewshot description
return ""
def doc_to_text(self, doc):
# TODO: implement
pass
def has_validation_docs(self):
return True
def doc_to_target(self, doc):
# TODO: implement
pass
def has_train_docs(self):
return True
def construct_requests(self, doc, ctx):
""" Uses RequestFactory to construct Requests and returns an iterable of
Requests which will be sent to the LM.
:param doc:
The document as returned from training_docs, validation_docs, or test_docs.
:param ctx: str
The context string, generated by fewshot_context. This includes the natural
language description, as well as the few shot examples, and the question
part of the document for `doc`.
"""
# TODO: implement evaluation.
raise NotImplementedError('Evaluation not implemented')
def has_test_docs(self):
return True
def process_results(self, doc, results):
"""Take a single document and the LM results and evaluates, returning a
dict where keys are the names of submetrics and values are the values of
the metric for that one document
:param doc:
The document as returned from training_docs, validation_docs, or test_docs.
:param results:
The results of the requests created in construct_requests.
"""
# TODO: implement evaluation.
raise NotImplementedError('Evaluation not implemented')
def aggregation(self):
"""
:returns: {str: [float] -> float}
A dictionary where keys are the names of submetrics and values are
functions that aggregate a list of metrics
"""
# TODO: implement evaluation.
raise NotImplementedError('Evaluation not implemented')
def higher_is_better(self):
"""
:returns: {str: bool}
A dictionary where keys are the names of submetrics and values are
whether a higher value of the submetric is better
"""
# TODO: implement evaluation.
raise NotImplementedError('Evaluation not implemented')
class WikiText2(HFTask):
NLP_PATH = "wikitext"
NLP_NAME = "wikitext-2-raw-v1"
def fewshot_description(self):
# TODO: figure out fewshot description
return ""
def doc_to_text(self, doc):
# TODO: implement
pass
def docs_for_split(self, split):
ret = []
for line in open(f"data/wikitext/wikitext-2-raw/wiki.{split}.raw").read().split('\n'):
rline = line.replace("= = =", "===").replace("= =", "==").strip()
if rline.startswith('= ') and rline.strip().endswith(' ='):
s = '\n'.join(ret)
if s.strip(): yield s
ret = []
ret.append(line)
yield '\n'.join(ret)
def validation_docs(self):
return self.docs_for_split('valid')
def train_docs(self):
return self.docs_for_split('train')
def test_docs(self):
return self.docs_for_split('test')
def doc_to_target(self, doc):
# TODO: implement
pass
def construct_requests(self, doc, ctx):
""" Uses RequestFactory to construct Requests and returns an iterable of
Requests which will be sent to the LM.
:param doc:
The document as returned from training_docs, validation_docs, or test_docs.
:param ctx: str
The context string, generated by fewshot_context. This includes the natural
language description, as well as the few shot examples, and the question
part of the document for `doc`.
"""
# TODO: implement evaluation.
raise NotImplementedError('Evaluation not implemented')
return wikitext_detokenizer(doc)
def process_results(self, doc, results):
"""Take a single document and the LM results and evaluates, returning a
dict where keys are the names of submetrics and values are the values of
the metric for that one document
:param doc:
The document as returned from training_docs, validation_docs, or test_docs.
:param results:
The results of the requests created in construct_requests.
"""
# TODO: implement evaluation.
raise NotImplementedError('Evaluation not implemented')
def aggregation(self):
"""
:returns: {str: [float] -> float}
A dictionary where keys are the names of submetrics and values are
functions that aggregate a list of metrics
"""
# TODO: implement evaluation.
raise NotImplementedError('Evaluation not implemented')
def higher_is_better(self):
"""
:returns: {str: bool}
A dictionary where keys are the names of submetrics and values are
whether a higher value of the submetric is better
"""
# TODO: implement evaluation.
raise NotImplementedError('Evaluation not implemented')
def count_words(self, doc):
# count number of words in *original doc before detokenization*
return len(re.split(r"\s+", doc))
\ No newline at end of file
......@@ -11,6 +11,7 @@ Reference: https://arxiv.org/abs/1806.02847
class Winogrande(HFTask):
VERSION = 0
DATASET_PATH = "winogrande"
DATASET_NAME = "winogrande_xl"
......
......@@ -12,6 +12,7 @@ See: https://arxiv.org/abs/1806.02847
class WinogradSchemaChallenge273(HFTask):
VERSION = 0
DATASET_PATH = "winograd_wsc"
DATASET_NAME = "wsc273"
......
......@@ -61,6 +61,56 @@ def general_detokenize(string):
return string
def get_rolling_token_windows(token_list, prefix_token, max_seq_len, context_len):
"""
- context_len allows for a rolling window context, allowing each prediction window to potentially
condition on some context
:param token_list: list
List of tokens to be PREDICTED
:param max_seq_len: int
max_seq_len of model (or max_seq_len we want to use)
:param context_len: int
Amount of desired token context for prediction. Needs to be at least 1.
:param prefix_token: token
Dummy token like <eos> so the first token has something to condition on
:return: generator
Generator of tuples
(input_tokens, pred_tokens)
Note: Score only the last len(pred_tokens) logits of the LM
"""
assert 1 <= context_len <= max_seq_len
if not token_list:
return
# +1 offset, going from input->preds
pred_len = max_seq_len - context_len + 1
predicted = 0
# Special handling for first window: predict all tokens
first_seq_len = min(max_seq_len, len(token_list))
yield (
[prefix_token] + token_list[:first_seq_len - 1],
token_list[:first_seq_len]
)
predicted += first_seq_len
while predicted < len(token_list):
window_pred_len = min(len(token_list) - predicted, pred_len)
window_end = predicted + window_pred_len
yield (
token_list[window_end - max_seq_len - 1:window_end - 1],
token_list[window_end - window_pred_len:window_end],
)
predicted += window_pred_len
def make_disjoint_window(pair):
""" Takes output from get_rolling_token_windows and makes the context not overlap with the continuation """
a, b = pair
return a[:-(len(b) - 1)], b
class Reorderer:
def __init__(self, arr, fn):
self.size = len(arr)
......
......@@ -15,6 +15,8 @@ def parse_args():
parser.add_argument('--tasks', default="all_tasks")
parser.add_argument('--provide_description', action="store_true")
parser.add_argument('--num_fewshot', type=int, default=0)
parser.add_argument('--batch_size', type=int, default=None)
parser.add_argument('--device', type=str, default=None)
parser.add_argument('--seed', type=int, default=1234)
parser.add_argument('--output_path', default=None)
parser.add_argument('--limit', type=int, default=None)
......@@ -27,7 +29,9 @@ def main():
random.seed(args.seed)
np.random.seed(args.seed)
lm = models.get_model(args.model).create_from_arg_string(args.model_args)
lm = models.get_model(args.model).create_from_arg_string(args.model_args, {
'batch_size': args.batch_size, 'device': args.device
})
if args.limit:
print("WARNING: --limit SHOULD ONLY BE USED FOR TESTING. REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT.")
......@@ -49,20 +53,36 @@ def main():
f.write(dumped)
# MAKE TABLE
from pytablewriter import MarkdownTableWriter
from pytablewriter import MarkdownTableWriter, LatexTableWriter
writer = MarkdownTableWriter()
writer.headers = ["Task", "Metric", "Value"]
md_writer = MarkdownTableWriter()
latex_writer = LatexTableWriter()
md_writer.headers = ["Task", "Version", "Metric", "Value", "", "Stderr"]
latex_writer.headers = ["Task", "Version", "Metric", "Value", "", "Stderr"]
values = []
for k, dic in results.items():
for k, dic in results["results"].items():
version = results["versions"][k]
for m, v in dic.items():
values.append([k, m, '%.4f' % v])
if m.endswith("_stderr"): continue
if m + "_stderr" in dic:
se = dic[m + "_stderr"]
values.append([k, version, m, '%.4f' % v, '±', '%.4f' % se])
else:
values.append([k, version, m, '%.4f' % v, '', ''])
k = ""
writer.value_matrix = values
version = ""
md_writer.value_matrix = values
latex_writer.value_matrix = values
# todo: make latex table look good
# print(latex_writer.dumps())
print(writer.dumps())
print(f"{args.model} ({args.model_args}), limit: {args.limit}, provide_description: {args.provide_description}, num_fewshot: {args.num_fewshot}, batch_size: {args.batch_size}")
print(md_writer.dumps())
if __name__ == "__main__":
main()
black==20.8b1
best_download>=0.0.5
datasets>=1.2.1
click>=7.1
scikit-learn>=0.24.1
torch>=1.7
transformers>=4.1
sqlitedict==1.6.0
pytablewriter==0.58.0
sacrebleu==1.5.0
pycountry==20.7.3
numexpr==2.7.2
\ No newline at end of file
-e .
import os
import zstandard
import json
import jsonlines
import io
import datetime
def json_serial(obj):
"""JSON serializer for objects not serializable by default json code"""
if isinstance(obj, (datetime.datetime,)):
return obj.isoformat()
raise TypeError ("Type %s not serializable" % type(obj))
# Modified version of lm_dataformat Archive for single file.
class Archive:
def __init__(self, file_path, compression_level=3):
self.file_path = file_path
dir_name = os.path.dirname(file_path)
if dir_name:
os.makedirs(dir_name, exist_ok=True)
self.fh = open(self.file_path, 'wb')
self.cctx = zstandard.ZstdCompressor(level=compression_level)
self.compressor = self.cctx.stream_writer(self.fh)
def add_data(self, data, meta={}):
self.compressor.write(json.dumps({'text': data, 'meta': meta}, default=json_serial).encode('UTF-8') + b'\n')
def commit(self):
self.compressor.flush(zstandard.FLUSH_FRAME)
self.fh.flush()
self.fh.close()
# Modified version of lm_dataformat Reader with self.fh set, allowing peeking for tqdm.
class Reader:
def __init__(self):
pass
def read(self, file, get_meta=False, autojoin_paragraphs=True, para_joiner='\n\n'):
with open(file, 'rb') as fh:
self.fh = fh
cctx = zstandard.ZstdDecompressor()
reader = io.BufferedReader(cctx.stream_reader(fh))
rdr = jsonlines.Reader(reader)
for ob in rdr:
# naive jsonl where each object is just the string itself, with no meta. For legacy compatibility.
if isinstance(ob, str):
assert not get_meta
yield ob
continue
text = ob['text']
if autojoin_paragraphs and isinstance(text, list):
text = para_joiner.join(text)
if get_meta:
yield text, (ob['meta'] if 'meta' in ob else {})
else:
yield text
# Simple text reader and writer with same interface as above
class TextArchive:
def __init__(self, file_path, mode="ab"):
self.file_path = file_path
dir_name = os.path.dirname(file_path)
if dir_name:
os.makedirs(dir_name, exist_ok=True)
self.fh = open(self.file_path, mode)
def add_data(self, data, meta={}):
self.fh.write(data.encode('UTF-8') + b'\n')
def commit(self):
self.fh.flush()
self.fh.close()
class TextReader:
def __init__(self, file_path):
self.file_path = file_path
def read(self):
with open(self.file_path, 'r', encoding="utf8") as fh:
self.fh = fh
while True:
line = self.fh.readline()
if line == -1 or line == "":
break
else :
yield line[:-1]
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment