"...gpu/git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "f3a8933c4730ec9127072ab1ff4f89a1976e465e"
Commit 4d147bdd authored by Jonathan Tow

Merge branch 'master' of https://github.com/EleutherAI/lm-evaluation-harness into task-guide

parents 011cc891 dc937d4b
"""
QuAC: Question Answering in Context
https://arxiv.org/abs/1808.07036
@article{choi2018quac,
title={Quac: Question answering in context},
author={Choi, Eunsol and He, He and Iyyer, Mohit and Yatskar, Mark and Yih, Wen-tau and Choi, Yejin and Liang, Percy and Zettlemoyer, Luke},
journal={arXiv preprint arXiv:1808.07036},
year={2018}
}
"""
import json
import os
from lm_eval.base import Task
from ..utils import sh


class QuAC(Task):
    VERSION = 0

    def __init__(self):
        super().__init__()

    def download(self):
        if not os.path.exists('data/quac'):
            # TODO: convert to use best_download
            sh("""
                mkdir -p data/quac
                wget https://s3.amazonaws.com/my89public/quac/train_v0.2.json -O data/quac/train_v0.2.json
......
@@ -15,6 +15,7 @@ class each:

class RACE(HFTask):
    VERSION = 0
    DATASET_PATH = "race"
    DATASET_NAME = "high"
......
@@ -3,6 +3,7 @@ from lm_eval.base import MultipleChoiceTask

class SATAnalogies(MultipleChoiceTask):
    VERSION = 0
    NEEDS_MANUAL_DL = True

    def __init__(self):
......
@@ -6,6 +6,7 @@ from best_download import download_file

class SciQ(MultipleChoiceTask):
    VERSION = 0
    # Multiple languages and multiple years

    def download(self):
        if not os.path.exists('data/sciq'):
......
import datasets
from math import exp
from lm_eval.base import rf
from lm_eval.metrics import f1_score, mean
from . common import HFTask
from functools import partial
from packaging import version


def _squad_metric(predictions, references):
    squad_metric = datasets.load_metric("squad_v2")
    return squad_metric.compute(predictions=predictions, references=references)


def _squad_agg(key, items):
    predictions, references = zip(*items)
    return _squad_metric(predictions=predictions, references=references)[key]


class SQuAD2(HFTask):
    VERSION = 1
    DATASET_PATH = "squad_v2"
    DATASET_NAME = None

    # HF changed squad on us so we have to make sure we aren't running the old one
    assert version.parse(datasets.__version__) >= version.parse("1.11.0"), "datasets v1.11.0 or later required for SQuAD"
    def has_training_docs(self):
        return True

@@ -15,16 +36,14 @@ class SQuAD(HFTask):
        return False

    def training_docs(self):
        return self.data["train"]

    def validation_docs(self):
        return self.data["validation"]

    def fewshot_description(self):
        # TODO: figure out description
        return ""

    def doc_to_text(self, doc):
        return 'Title: ' + doc['title'] + '\n\n' + 'Background: ' + doc['context'] + '\n\n' + 'Question: ' + doc['question'] + '\n\n' + 'Answer:'

@@ -35,7 +54,7 @@ class SQuAD(HFTask):
            answer = answer_list[0]
        else:
            answer = 'unanswerable'
        return " " + answer
    def construct_requests(self, doc, ctx):
        """Uses RequestFactory to construct Requests and returns an iterable of
@@ -48,8 +67,9 @@ class SQuAD(HFTask):
        language description, as well as the few shot examples, and the question
        part of the document for `doc`.
        """
        continuation = rf.greedy_until(ctx, ['\n'])
        is_unanswerable = rf.loglikelihood(ctx, " " + "unanswerable")
        return continuation, is_unanswerable
    def process_results(self, doc, results):
        """Take a single document and the LM results and evaluates, returning a
@@ -61,8 +81,31 @@ class SQuAD(HFTask):
        :param results:
            The results of the requests created in construct_requests.
        """
        continuation, (logprob_unanswerable, _) = results

        no_answer_probability = exp(logprob_unanswerable)

        predictions = {
            'id': doc['id'],
            'prediction_text': continuation,
            'no_answer_probability': no_answer_probability,
        }

        references = {
            'id': doc['id'],
            'answers': doc['answers'],
        }

        return {
            'exact': (predictions, references),  # Exact match (the normalized answer exactly matches the gold answer)
            'f1': (predictions, references),  # The F-score of predicted tokens versus the gold answer
            'HasAns_exact': (predictions, references),  # Exact match, answerable questions only
            'HasAns_f1': (predictions, references),  # F-score, answerable questions only
            'NoAns_exact': (predictions, references),  # Exact match, unanswerable questions only
            'NoAns_f1': (predictions, references),  # F-score, unanswerable questions only
            'best_exact': (predictions, references),  # Best exact match (with varying threshold)
            'best_f1': (predictions, references),  # Best F1 (with varying threshold)
        }
    def aggregation(self):
        """
@@ -70,8 +113,16 @@ class SQuAD(HFTask):
        A dictionary where keys are the names of submetrics and values are
            functions that aggregate a list of metrics
        """
        return {
            'exact': partial(_squad_agg, 'exact'),  # Exact match (the normalized answer exactly matches the gold answer)
            'f1': partial(_squad_agg, 'f1'),  # The F-score of predicted tokens versus the gold answer
            'HasAns_exact': partial(_squad_agg, 'HasAns_exact'),  # Exact match, answerable questions only
            'HasAns_f1': partial(_squad_agg, 'HasAns_f1'),  # F-score, answerable questions only
            'NoAns_exact': partial(_squad_agg, 'NoAns_exact'),  # Exact match, unanswerable questions only
            'NoAns_f1': partial(_squad_agg, 'NoAns_f1'),  # F-score, unanswerable questions only
            'best_exact': partial(_squad_agg, 'best_exact'),  # Best exact match (with varying threshold)
            'best_f1': partial(_squad_agg, 'best_f1'),  # Best F1 (with varying threshold)
        }
    def higher_is_better(self):
        """
@@ -79,5 +130,13 @@ class SQuAD(HFTask):
        A dictionary where keys are the names of submetrics and values are
            whether a higher value of the submetric is better
        """
        return {
            'exact': True,  # Exact match (the normalized answer exactly matches the gold answer)
            'f1': True,  # The F-score of predicted tokens versus the gold answer
            'HasAns_exact': True,  # Exact match, answerable questions only
            'HasAns_f1': True,  # F-score, answerable questions only
            'NoAns_exact': True,  # Exact match, unanswerable questions only
            'NoAns_f1': True,  # F-score, unanswerable questions only
            'best_exact': True,  # Best exact match (with varying threshold)
            'best_f1': True,  # Best F1 (with varying threshold)
        }
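# A hedged sketch of how the pieces above fit together; `outputs` below is a
# hypothetical list of (doc, results) pairs that the harness's evaluator would
# normally collect. Every document contributes the same (predictions, references)
# pair to each submetric, and aggregation defers to HF's squad_v2 metric in one
# corpus-level call:
#
#   task = SQuAD2()
#   items = [task.process_results(doc, results)['f1'] for doc, results in outputs]
#   f1 = task.aggregation()['f1'](items)  # equivalent to _squad_metric(...)['f1']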
@@ -3,6 +3,7 @@ from lm_eval.base import Task

class StoryCloze(Task):
    VERSION = 0
    NEEDS_MANUAL_DL = True

    def download(self):
......
@@ -13,6 +13,7 @@ from ..utils import general_detokenize

class BoolQ(HFTask):
    VERSION = 0
    DATASET_PATH = "super_glue"
    DATASET_NAME = "boolq"

@@ -64,6 +65,7 @@ class BoolQ(HFTask):

class CommitmentBank(HFTask):
    VERSION = 0
    DATASET_PATH = "super_glue"
    DATASET_NAME = "cb"

@@ -135,6 +137,7 @@ class CommitmentBank(HFTask):

class Copa(HFTask):
    VERSION = 0
    DATASET_PATH = "super_glue"
    DATASET_NAME = "copa"

@@ -199,6 +202,7 @@ class Copa(HFTask):

class MultiRC(HFTask):
    VERSION = 0
    DATASET_PATH = "super_glue"
    DATASET_NAME = "multirc"

@@ -253,6 +257,7 @@ class MultiRC(HFTask):

class ReCoRD(HFTask):
    VERSION = 0
    DATASET_PATH = "super_glue"
    DATASET_NAME = "record"

@@ -345,6 +350,7 @@ class ReCoRD(HFTask):

class WordsInContext(HFTask):
    VERSION = 0
    DATASET_PATH = "super_glue"
    DATASET_NAME = "wic"

@@ -400,6 +406,7 @@ class WordsInContext(HFTask):

class SGWinogradSchemaChallenge(HFTask):
    VERSION = 0
    # Note: This implementation differs from Fig G.32 because this is the SuperGLUE,
    # binary version of the task.
    DATASET_PATH = "super_glue"

@@ -412,7 +419,7 @@ class SGWinogradSchemaChallenge(HFTask):
        return True

    def has_test_docs(self):
        return False

    def training_docs(self):
        if self.has_training_docs():
......
@@ -3,6 +3,9 @@ from pprint import pprint
from sacrebleu import sacrebleu
from lm_eval import metrics
from lm_eval.base import Task, rf
from typing import List

"""
This file implements translation tasks using datasets from WMT conferences, provided by sacrebleu.

@@ -19,23 +22,46 @@ def create_tasks_from_benchmarks(benchmark_dict):
    :return: {task_name: task}
        e.g. {wmt14-fr-en: Task, wmt16-de-en: Task}
    """
    def version_of(dataset, language_pair):
        if language_pair[-2:] in ["zh", "ja"]:
            return 1  # changed to use jieba/nagisa
        return 0

    return {
        f"{dataset}-{language_pair}": create_translation_task(dataset, language_pair, version_of(dataset, language_pair))
        for dataset, language_pairs in benchmark_dict.items()
        for language_pair in language_pairs
    }
########################################
# Language Specifics
########################################

def zh_split(zh_text: List[str]) -> List[str]:
    """Chinese splitting"""
    import jieba

    return [" ".join(jieba.cut(txt.strip())) for txt in zh_text]


def ja_split(ja_text: List[str]) -> List[str]:
    """Japanese splitting"""
    import nagisa

    return [" ".join(nagisa.tagging(txt.strip()).words) for txt in ja_text]


NO_SPACE_LANG = {"zh": zh_split, "ja": ja_split}
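# Illustrative only: the exact segmentation depends on the installed jieba/nagisa
# versions, so the output below is an assumption, not a fixture.
#
#   zh_split(["我爱自然语言处理"])  # -> roughly ["我 爱 自然语言 处理"]
#
# The point is that hypothesis and reference both gain spaces, giving
# sacrebleu's default word-level BLEU tokens to align.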
########################################
# Tasks
########################################

def create_translation_task(dataset, language_pair, version=0):
    class TranslationTask(GeneralTranslationTask):
        VERSION = version

        def __init__(self):
            super().__init__(dataset, language_pair)

    return TranslationTask
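# For instance (hedged; the benchmark dict here is invented for illustration),
# the factory wiring yields one task class per language pair, bumping the
# version only for pairs whose target language needs segmentation:
#
#   tasks = create_tasks_from_benchmarks({"wmt20": ["en-zh", "en-de"]})
#   tasks["wmt20-en-zh"].VERSION  # 1: zh target, scored with jieba splitting
#   tasks["wmt20-en-de"].VERSION  # 0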
class GeneralTranslationTask(Task):
    VERSION = 0

    # e.g. ("wmt14", "fr-en")
    def __init__(self, sacrebleu_dataset, sacrebleu_language_pair=None):
@@ -101,6 +127,12 @@ class GeneralTranslationTask(Task):
        return rf.greedy_until(ctx, ["\n"])

    def process_results(self, doc, results):
        # Add spaces between words for BLEU score calculation of target languages like Chinese
        tar_lang_code = self.sacrebleu_language_pair.split("-")[-1]
        if tar_lang_code in NO_SPACE_LANG:
            doc["ref"] = NO_SPACE_LANG[tar_lang_code]([doc["ref"]])[0]
            results = NO_SPACE_LANG[tar_lang_code](results)
        # These metrics are corpus-level, not sentence-level, so we'll hide the
        # results in this dict and compute the corpus score in the aggregate method
        ref_pred = (doc["ref"], results)
@@ -156,283 +188,3 @@ def code_to_language(code):
    # key is alpha_2 or alpha_3 depending on the code length
    language_tuple = pycountry.languages.get(**{f"alpha_{len(code)}": code})
    return language_tuple.name

def print_available_tests():
    pprint({ts: sacrebleu.get_langpairs_for_testset(ts) for ts in sacrebleu.get_available_testsets()})


def print_available_pairs():
    list_of_pairs = [sacrebleu.get_langpairs_for_testset(ts) for ts in sacrebleu.get_available_testsets()]
    pairs = set([item for sublist in list_of_pairs for item in sublist])
    pairs = sorted(["-".join(map(code_to_language, pair.split("-"))) for pair in pairs])
    pprint(pairs)
    print(len(pairs))


def main():
    # print(sacrebleu.download_test_set("wmt14", "en-fr"))
    # print_available_tests()
    # sacrebleu.print_test_set("wmt14", "fr-en", "src")

    # # Print number of benchmarks
    # print(sum([
    #     len(sacrebleu.get_langpairs_for_testset(ts))
    #     for ts in sacrebleu.get_available_testsets()
    # ]))

    # Test task dictionary
    # for task, task_class in create_tasks_from_benchmarks(selected_benchmarks).items():
    #     print(task, task_class())
    print_available_pairs()
    pass


if __name__ == "__main__":
    main()
########################################
# Don't mind me...!
########################################
# Available tests as of 2020/02/11
"""
{'iwslt17': ['en-fr',
'fr-en',
'en-de',
'de-en',
'en-zh',
'zh-en',
'en-ar',
'ar-en',
'en-ja',
'ja-en',
'en-ko',
'ko-en'],
'iwslt17/dev2010': ['en-fr', 'fr-en', 'en-de', 'de-en', 'en-zh', 'zh-en'],
'iwslt17/tst2010': ['en-fr', 'fr-en', 'en-de', 'de-en', 'en-zh', 'zh-en'],
'iwslt17/tst2011': ['en-fr', 'fr-en', 'en-de', 'de-en', 'en-zh', 'zh-en'],
'iwslt17/tst2012': ['en-fr', 'fr-en', 'en-de', 'de-en', 'en-zh', 'zh-en'],
'iwslt17/tst2013': ['en-fr', 'fr-en', 'en-de', 'de-en', 'en-zh', 'zh-en'],
'iwslt17/tst2014': ['en-fr', 'fr-en', 'en-de', 'de-en', 'en-zh', 'zh-en'],
'iwslt17/tst2015': ['en-fr', 'fr-en', 'en-de', 'de-en', 'en-zh', 'zh-en'],
'iwslt17/tst2016': ['en-fr', 'fr-en', 'en-de', 'de-en', 'en-zh', 'zh-en'],
'mtnt1.1/test': ['en-fr', 'fr-en', 'en-ja', 'ja-en'],
'mtnt1.1/train': ['en-fr', 'fr-en', 'en-ja', 'ja-en'],
'mtnt1.1/valid': ['en-fr', 'fr-en', 'en-ja', 'ja-en'],
'mtnt2019': ['en-fr', 'fr-en', 'en-ja', 'ja-en'],
'multi30k/2016': ['en-fr', 'en-de', 'en-cs'],
'multi30k/2017': ['en-fr', 'en-de'],
'multi30k/2018': ['en-fr', 'en-de'],
'wmt08': ['cs-en',
'en-cs',
'de-en',
'en-de',
'es-en',
'en-es',
'fr-en',
'en-fr',
'hu-en',
'en-hu'],
'wmt08/europarl': ['de-en', 'en-de', 'es-en', 'en-es', 'fr-en', 'en-fr'],
'wmt08/nc': ['cs-en', 'en-cs'],
'wmt09': ['cs-en',
'en-cs',
'de-en',
'en-de',
'es-en',
'en-es',
'fr-en',
'en-fr',
'hu-en',
'en-hu',
'it-en',
'en-it'],
'wmt10': ['cs-en',
'en-cs',
'de-en',
'en-de',
'es-en',
'en-es',
'fr-en',
'en-fr'],
'wmt11': ['cs-en',
'en-cs',
'de-en',
'en-de',
'fr-en',
'en-fr',
'es-en',
'en-es'],
'wmt12': ['cs-en',
'en-cs',
'de-en',
'en-de',
'es-en',
'en-es',
'fr-en',
'en-fr'],
'wmt13': ['cs-en',
'en-cs',
'de-en',
'en-de',
'es-en',
'en-es',
'fr-en',
'en-fr',
'ru-en',
'en-ru'],
'wmt14': ['cs-en',
'en-cs',
'de-en',
'en-de',
'en-fr',
'fr-en',
'en-hi',
'hi-en',
'en-ru',
'ru-en'],
'wmt14/full': ['cs-en',
'en-cs',
'de-en',
'en-de',
'en-fr',
'fr-en',
'en-hi',
'hi-en',
'en-ru',
'ru-en'],
'wmt15': ['en-fr',
'fr-en',
'cs-en',
'de-en',
'en-cs',
'en-de',
'en-fi',
'en-ru',
'fi-en',
'ru-en'],
'wmt16': ['cs-en',
'de-en',
'en-cs',
'en-de',
'en-fi',
'en-ro',
'en-ru',
'en-tr',
'fi-en',
'ro-en',
'ru-en',
'tr-en'],
'wmt16/B': ['en-fi'],
'wmt16/dev': ['en-ro', 'en-tr', 'ro-en', 'tr-en'],
'wmt16/tworefs': ['en-fi'],
'wmt17': ['cs-en',
'de-en',
'en-cs',
'en-de',
'en-fi',
'en-lv',
'en-ru',
'en-tr',
'en-zh',
'fi-en',
'lv-en',
'ru-en',
'tr-en',
'zh-en'],
'wmt17/B': ['en-fi'],
'wmt17/dev': ['en-lv', 'en-zh', 'lv-en', 'zh-en'],
'wmt17/improved': ['en-zh', 'zh-en'],
'wmt17/ms': ['zh-en'],
'wmt17/tworefs': ['en-fi'],
'wmt18': ['cs-en',
'de-en',
'en-cs',
'en-de',
'en-et',
'en-fi',
'en-ru',
'et-en',
'fi-en',
'ru-en',
'en-tr',
'tr-en',
'en-zh',
'zh-en'],
'wmt18/dev': ['et-en', 'en-et'],
'wmt18/test-ts': ['cs-en',
'de-en',
'en-cs',
'en-de',
'en-et',
'en-fi',
'en-ru',
'et-en',
'fi-en',
'ru-en',
'en-tr',
'tr-en',
'en-zh',
'zh-en'],
'wmt19': ['cs-de',
'de-cs',
'de-en',
'de-fr',
'en-cs',
'en-de',
'en-fi',
'en-gu',
'en-kk',
'en-lt',
'en-ru',
'en-zh',
'fi-en',
'fr-de',
'gu-en',
'kk-en',
'lt-en',
'ru-en',
'zh-en'],
'wmt19/dev': ['lt-en', 'en-lt', 'gu-en', 'en-gu', 'kk-en', 'en-kk'],
'wmt19/google/ar': ['en-de'],
'wmt19/google/arp': ['en-de'],
'wmt19/google/hqall': ['en-de'],
'wmt19/google/hqp': ['en-de'],
'wmt19/google/hqr': ['en-de'],
'wmt19/google/wmtp': ['en-de'],
'wmt20': ['cs-en',
'de-en',
'de-fr',
'en-cs',
'en-de',
'en-iu',
'en-ja',
'en-km',
'en-pl',
'en-ps',
'en-ru',
'en-ta',
'en-zh',
'fr-de',
'iu-en',
'ja-en',
'km-en',
'pl-en',
'ps-en',
'ru-en',
'ta-en',
'zh-en'],
'wmt20/dev': ['iu-en',
'en-iu',
'ja-en',
'en-ja',
'pl-en',
'en-pl',
'ta-en',
'en-ta'],
'wmt20/robust/set1': ['en-ja', 'en-de'],
'wmt20/robust/set2': ['en-ja', 'ja-en'],
'wmt20/robust/set3': ['de-en'],
'wmt20/tworefs': ['de-en', 'en-de', 'en-zh', 'ru-en', 'zh-en']}
"""
import os
import json
import jsonlines
from lm_eval.base import Task, rf
from ..metrics import mean
from ..utils import sh
from best_download import download_file


class TriviaQA(Task):
    VERSION = 0

    def download(self):
        if not os.path.exists('data/triviaqa/unfiltered-web-train.jsonl'):
            os.makedirs("data/triviaqa/", exist_ok=True)
            download_file("http://eaidata.bmk.sh/data/triviaqa-unfiltered.tar.gz", "data/triviaqa/triviaqa-unfiltered.tar.gz", "adc19b42769062d241a8fbe834c56e58598d9322eb6c614e9f33a68a2cf5523e")
            sh("""
                cd data/triviaqa/
                tar -xf triviaqa-unfiltered.tar.gz
            """)
    def has_training_docs(self):
@@ -25,20 +28,20 @@ class TriviaQA(Task):
        return False

    def training_docs(self):
        return jsonlines.open('data/triviaqa/unfiltered-web-train.jsonl')

    def validation_docs(self):
        return jsonlines.open('data/triviaqa/unfiltered-web-dev.jsonl')

    def test_docs(self):
        raise NotImplementedError()

    def fewshot_description(self):
        # TODO: figure out fewshot description
        return ""

    def doc_to_text(self, doc):
        return f"Question: {doc['Question']}\nAnswer:"

    def doc_to_target(self, doc):
        return " " + doc['Answer']['Value']
......
@@ -14,6 +14,7 @@ def extract_gzip(gz, to):

class WordUnscrambleTask(Task):
    VERSION = 0
    BASE_PATH = Path("data/unscramble")
    FILENAME = None
    CHECKSUM = None  # SHA256 Checksum.

@@ -23,7 +24,7 @@ class WordUnscrambleTask(Task):
    def download(self):
        if not self.BASE_PATH.exists():
            Path.mkdir(self.BASE_PATH, parents=True)
        file = self.BASE_PATH / self.FILENAME
        if not file.exists():
            rawfile = file.parent / (file.name + ".gz")
......
@@ -4,6 +4,7 @@ from ..metrics import mean

class WebQs(HFTask):
    VERSION = 0
    DATASET_PATH = "web_questions"
    DATASET_NAME = None
......
import os
import re
from lm_eval.base import rf, PerplexityTask
from lm_eval.utils import sh
from best_download import download_file


def wikitext_detokenizer(string):
    # contractions
    string = string.replace("s '", "s'")
    string = re.sub(r"/' [0-9]/", r"/'[0-9]/", string)
    # number separators
    string = string.replace(" @-@ ", "-")
    string = string.replace(" @,@ ", ",")
    string = string.replace(" @.@ ", ".")
    # punctuation
    string = string.replace(" : ", ": ")
    string = string.replace(" ; ", "; ")
    string = string.replace(" . ", ". ")
    string = string.replace(" ! ", "! ")
    string = string.replace(" ? ", "? ")
    string = string.replace(" , ", ", ")
    # double brackets
    string = re.sub(r"\(\s*([^\)]*?)\s*\)", r"(\1)", string)
    string = re.sub(r"\[\s*([^\]]*?)\s*\]", r"[\1]", string)
    string = re.sub(r"{\s*([^}]*?)\s*}", r"{\1}", string)
    string = re.sub(r"\"\s*([^\"]*?)\s*\"", r'"\1"', string)
    string = re.sub(r"'\s*([^']*?)\s*'", r"'\1'", string)
    # miscellaneous
    string = string.replace("= = = =", "====")
    string = string.replace("= = =", "===")
    string = string.replace("= =", "==")
    string = string.replace(" " + chr(176) + " ", chr(176))
    string = string.replace(" \n", "\n")
    string = string.replace("\n ", "\n")
    string = string.replace(" N ", " 1 ")
    string = string.replace(" 's", "'s")
    return string
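# A hand-worked example of the detokenizer above (my own input, not a fixture):
#
#   wikitext_detokenizer("= = History = = \n The 2 @.@ 5 km road cost £ 1 @,@ 000 . \n")
#   # -> "== History ==\nThe 2.5 km road cost £ 1,000.\n"
#
# i.e. the " @.@ "/" @,@ " number separators collapse, the spaced-out headings
# and punctuation are rejoined, and stray spaces around newlines are dropped.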

class WikiText(PerplexityTask):
    VERSION = 0

    def download(self):
        if not os.path.exists('data/wikitext/wikitext-2-raw/wiki.valid.raw'):
            os.makedirs("data/wikitext/", exist_ok=True)
            download_file("https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip", "data/wikitext/wikitext-2-raw-v1.zip", "ef7edb566e3e2b2d31b29c1fdb0c89a4cc683597484c3dc2517919c615435a11")
            sh("cd data/wikitext/ && unzip wikitext-2-raw-v1.zip")

    def fewshot_description(self):
        # TODO: figure out fewshot description
        return ""

    def has_validation_docs(self):
        return True

    def has_train_docs(self):
        return True

    def has_test_docs(self):
        return True

    def docs_for_split(self, split):
        ret = []
        for line in open(f"data/wikitext/wikitext-2-raw/wiki.{split}.raw").read().split('\n'):
            rline = line.replace("= = =", "===").replace("= =", "==").strip()
            if rline.startswith('= ') and rline.strip().endswith(' ='):
                s = '\n'.join(ret)
                if s.strip(): yield s
                ret = []
            ret.append(line)
        yield '\n'.join(ret)

    def validation_docs(self):
        return self.docs_for_split('valid')

    def train_docs(self):
        return self.docs_for_split('train')

    def test_docs(self):
        return self.docs_for_split('test')

    def doc_to_target(self, doc):
        return wikitext_detokenizer(doc)

    def count_words(self, doc):
        # count number of words in *original doc before detokenization*
        return len(re.split(r"\s+", doc))
@@ -11,6 +11,7 @@ Reference: https://arxiv.org/abs/1806.02847

class Winogrande(HFTask):
    VERSION = 0
    DATASET_PATH = "winogrande"
    DATASET_NAME = "winogrande_xl"
......
@@ -12,6 +12,7 @@ See: https://arxiv.org/abs/1806.02847

class WinogradSchemaChallenge273(HFTask):
    VERSION = 0
    DATASET_PATH = "winograd_wsc"
    DATASET_NAME = "wsc273"
......
@@ -61,6 +61,56 @@ def general_detokenize(string):
    return string


def get_rolling_token_windows(token_list, prefix_token, max_seq_len, context_len):
    """
    - context_len allows for a rolling window context, allowing each prediction window to potentially
      condition on some context

    :param token_list: list
        List of tokens to be PREDICTED
    :param max_seq_len: int
        max_seq_len of model (or max_seq_len we want to use)
    :param context_len: int
        Amount of desired token context for prediction. Needs to be at least 1.
    :param prefix_token: token
        Dummy token like <eos> so the first token has something to condition on
    :return: generator
        Generator of tuples
            (input_tokens, pred_tokens)
        Note: Score only the last len(pred_tokens) logits of the LM
    """
    assert 1 <= context_len <= max_seq_len
    if not token_list:
        return
    # +1 offset, going from input->preds
    pred_len = max_seq_len - context_len + 1
    predicted = 0

    # Special handling for first window: predict all tokens
    first_seq_len = min(max_seq_len, len(token_list))
    yield (
        [prefix_token] + token_list[:first_seq_len - 1],
        token_list[:first_seq_len]
    )
    predicted += first_seq_len

    while predicted < len(token_list):
        window_pred_len = min(len(token_list) - predicted, pred_len)
        window_end = predicted + window_pred_len

        yield (
            token_list[window_end - max_seq_len - 1:window_end - 1],
            token_list[window_end - window_pred_len:window_end],
        )
        predicted += window_pred_len


def make_disjoint_window(pair):
    """Takes output from get_rolling_token_windows and makes the context not overlap with the continuation"""
    a, b = pair
    return a[:-(len(b) - 1)], b
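
# Hedged sanity check for the two helpers above (example values worked out by
# hand, not taken from the repo's test suite). With max_seq_len=4 and
# context_len=2, the first window predicts a full 4 tokens conditioned on the
# prefix token, and every later window keeps context_len - 1 = 1 token of
# overlapping context:
if __name__ == "__main__":
    assert list(get_rolling_token_windows(
        token_list=list(range(10)), prefix_token=-1, max_seq_len=4, context_len=2,
    )) == [
        ([-1, 0, 1, 2], [0, 1, 2, 3]),
        ([2, 3, 4, 5], [4, 5, 6]),
        ([5, 6, 7, 8], [7, 8, 9]),
    ]
    # make_disjoint_window trims the overlap off the context side:
    assert make_disjoint_window(([-1, 0, 1, 2], [0, 1, 2, 3])) == ([-1], [0, 1, 2, 3])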

class Reorderer:
    def __init__(self, arr, fn):
        self.size = len(arr)
......
@@ -15,6 +15,8 @@ def parse_args():
    parser.add_argument('--tasks', default="all_tasks")
    parser.add_argument('--provide_description', action="store_true")
    parser.add_argument('--num_fewshot', type=int, default=0)
    parser.add_argument('--batch_size', type=int, default=None)
    parser.add_argument('--device', type=str, default=None)
    parser.add_argument('--seed', type=int, default=1234)
    parser.add_argument('--output_path', default=None)
    parser.add_argument('--limit', type=int, default=None)

@@ -27,7 +29,9 @@ def main():
    random.seed(args.seed)
    np.random.seed(args.seed)

    lm = models.get_model(args.model).create_from_arg_string(args.model_args, {
        'batch_size': args.batch_size, 'device': args.device
    })

    if args.limit:
        print("WARNING: --limit SHOULD ONLY BE USED FOR TESTING. REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT.")

@@ -49,20 +53,36 @@ def main():
        f.write(dumped)

    # MAKE TABLE
    from pytablewriter import MarkdownTableWriter, LatexTableWriter

    md_writer = MarkdownTableWriter()
    latex_writer = LatexTableWriter()
    md_writer.headers = ["Task", "Version", "Metric", "Value", "", "Stderr"]
    latex_writer.headers = ["Task", "Version", "Metric", "Value", "", "Stderr"]

    values = []

    for k, dic in results["results"].items():
        version = results["versions"][k]
        for m, v in dic.items():
            if m.endswith("_stderr"): continue

            if m + "_stderr" in dic:
                se = dic[m + "_stderr"]
                values.append([k, version, m, '%.4f' % v, '±', '%.4f' % se])
            else:
                values.append([k, version, m, '%.4f' % v, '', ''])
            k = ""
            version = ""
    md_writer.value_matrix = values
    latex_writer.value_matrix = values

    # todo: make latex table look good
    # print(latex_writer.dumps())

    print(f"{args.model} ({args.model_args}), limit: {args.limit}, provide_description: {args.provide_description}, num_fewshot: {args.num_fewshot}, batch_size: {args.batch_size}")
    print(md_writer.dumps())


if __name__ == "__main__":
    main()
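# The printed markdown ends up looking roughly like this (values invented for
# illustration; pytablewriter's exact column alignment may differ):
#
#   |  Task  |Version|Metric|Value |   |Stderr|
#   |--------|------:|------|-----:|---|-----:|
#   |lambada |      0|ppl   |4.5000|±  |0.1000|
#   |        |       |acc   |0.7000|±  |0.0050|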
-e .
best_download>=0.0.5
datasets>=1.2.1
click>=7.1
scikit-learn>=0.24.1
torch>=1.7
transformers>=4.1
sqlitedict==1.6.0
pytablewriter==0.58.0
sacrebleu==1.5.0
pycountry==20.7.3
numexpr==2.7.2
import os
import zstandard
import json
import jsonlines
import io
import datetime


def json_serial(obj):
    """JSON serializer for objects not serializable by default json code"""
    if isinstance(obj, (datetime.datetime,)):
        return obj.isoformat()
    raise TypeError("Type %s not serializable" % type(obj))


# Modified version of lm_dataformat Archive for single file.
class Archive:
    def __init__(self, file_path, compression_level=3):
        self.file_path = file_path
        dir_name = os.path.dirname(file_path)
        if dir_name:
            os.makedirs(dir_name, exist_ok=True)
        self.fh = open(self.file_path, 'wb')
        self.cctx = zstandard.ZstdCompressor(level=compression_level)
        self.compressor = self.cctx.stream_writer(self.fh)

    def add_data(self, data, meta={}):
        self.compressor.write(json.dumps({'text': data, 'meta': meta}, default=json_serial).encode('UTF-8') + b'\n')

    def commit(self):
        self.compressor.flush(zstandard.FLUSH_FRAME)
        self.fh.flush()
        self.fh.close()


# Modified version of lm_dataformat Reader with self.fh set, allowing peeking for tqdm.
class Reader:
    def __init__(self):
        pass

    def read(self, file, get_meta=False, autojoin_paragraphs=True, para_joiner='\n\n'):
        with open(file, 'rb') as fh:
            self.fh = fh
            cctx = zstandard.ZstdDecompressor()
            reader = io.BufferedReader(cctx.stream_reader(fh))
            rdr = jsonlines.Reader(reader)
            for ob in rdr:
                # naive jsonl where each object is just the string itself, with no meta. For legacy compatibility.
                if isinstance(ob, str):
                    assert not get_meta
                    yield ob
                    continue

                text = ob['text']
                if autojoin_paragraphs and isinstance(text, list):
                    text = para_joiner.join(text)

                if get_meta:
                    yield text, (ob['meta'] if 'meta' in ob else {})
                else:
                    yield text


# Simple text reader and writer with same interface as above
class TextArchive:
    def __init__(self, file_path, mode="ab"):
        self.file_path = file_path
        dir_name = os.path.dirname(file_path)
        if dir_name:
            os.makedirs(dir_name, exist_ok=True)
        self.fh = open(self.file_path, mode)

    def add_data(self, data, meta={}):
        self.fh.write(data.encode('UTF-8') + b'\n')

    def commit(self):
        self.fh.flush()
        self.fh.close()


class TextReader:
    def __init__(self, file_path):
        self.file_path = file_path

    def read(self):
        with open(self.file_path, 'r', encoding="utf8") as fh:
            self.fh = fh
            while True:
                line = self.fh.readline()
                # readline() returns "" at EOF (it never returns -1), so stop there
                if line == "":
                    break
                else:
                    yield line[:-1]
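# Hedged round-trip sketch for the Archive/Reader pair above (the path is
# invented for illustration):
#
#   ar = Archive("out/sample.jsonl.zst")
#   ar.add_data("hello world", meta={"source": "demo"})
#   ar.commit()
#
#   for text, meta in Reader().read("out/sample.jsonl.zst", get_meta=True):
#       print(text, meta)  # -> hello world {'source': 'demo'}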