Commit 105fa974 authored by Leo Gao's avatar Leo Gao
Browse files

Add task versioning

parent f76e6367
......@@ -5,6 +5,7 @@ from . common import HFTask
class PiQA(HFTask, MultipleChoiceTask):
VERSION = 0
DATASET_PATH = "piqa"
DATASET_NAME = None
......
......@@ -5,6 +5,7 @@ from ..metrics import mean
class Pubmed_QA(HFTask):
VERSION = 0
DATASET_PATH = "pubmed_qa"
DATASET_NAME = "pqa_labeled"
......
......@@ -5,6 +5,7 @@ from lm_eval.base import MultipleChoiceTask
class QA4MRE(MultipleChoiceTask):
VERSION = 0
YEAR = None
def download(self):
year = self.YEAR
......
......@@ -15,6 +15,7 @@ class each:
class RACE(HFTask):
VERSION = 0
DATASET_PATH = "race"
DATASET_NAME = "high"
......
......@@ -3,6 +3,7 @@ from lm_eval.base import MultipleChoiceTask
class SATAnalogies(MultipleChoiceTask):
VERSION = 0
NEEDS_MANUAL_DL = True
def __init__(self):
......
......@@ -6,6 +6,7 @@ from best_download import download_file
class SciQ(MultipleChoiceTask):
VERSION = 0
# Multiple languages and multiple years
def download(self):
if not os.path.exists('data/sciq'):
......
......@@ -18,6 +18,7 @@ def _squad_agg(key, items):
class SQuAD2(HFTask):
VERSION = 0
DATASET_PATH = "squad_v2"
DATASET_NAME = None
......
......@@ -3,6 +3,7 @@ from lm_eval.base import Task
class StoryCloze(Task):
VERSION = 0
NEEDS_MANUAL_DL = True
def download(self):
......
......@@ -13,6 +13,7 @@ from ..utils import general_detokenize
class BoolQ(HFTask):
VERSION = 0
DATASET_PATH = "super_glue"
DATASET_NAME = "boolq"
......@@ -64,6 +65,7 @@ class BoolQ(HFTask):
class CommitmentBank(HFTask):
VERSION = 0
DATASET_PATH = "super_glue"
DATASET_NAME = "cb"
......@@ -135,6 +137,7 @@ class CommitmentBank(HFTask):
class Copa(HFTask):
VERSION = 0
DATASET_PATH = "super_glue"
DATASET_NAME = "copa"
......@@ -199,6 +202,7 @@ class Copa(HFTask):
class MultiRC(HFTask):
VERSION = 0
DATASET_PATH = "super_glue"
DATASET_NAME = "multirc"
......@@ -253,6 +257,7 @@ class MultiRC(HFTask):
class ReCoRD(HFTask):
VERSION = 0
DATASET_PATH = "super_glue"
DATASET_NAME = "record"
......@@ -345,6 +350,7 @@ class ReCoRD(HFTask):
class WordsInContext(HFTask):
VERSION = 0
DATASET_PATH = "super_glue"
DATASET_NAME = "wic"
......@@ -400,6 +406,7 @@ class WordsInContext(HFTask):
class SGWinogradSchemaChallenge(HFTask):
VERSION = 0
# Note: This implementation differs from Fig G.32 because this is the SuperGLUE,
# binary version of the task.
DATASET_PATH = "super_glue"
......
......@@ -36,6 +36,7 @@ def create_translation_task(dataset, language_pair):
return TranslationTask
class GeneralTranslationTask(Task):
VERSION = 0
# e.g. ("wmt14", "fr-en")
def __init__(self, sacrebleu_dataset, sacrebleu_language_pair=None):
......
......@@ -6,6 +6,7 @@ from ..utils import sh
class TriviaQA(Task):
VERSION = 0
def download(self):
if not os.path.exists('data/triviaqa'):
sh("""
......
......@@ -14,6 +14,7 @@ def extract_gzip(gz, to):
class WordUnscrambleTask(Task):
VERSION = 0
BASE_PATH = Path("data/unscramble")
FILENAME = None
CHECKSUM = None # SHA256 Checksum.
......
......@@ -4,6 +4,7 @@ from ..metrics import mean
class WebQs(HFTask):
VERSION = 0
DATASET_PATH = "web_questions"
DATASET_NAME = None
......
......@@ -2,6 +2,7 @@ from . common import HFTask
class WikiText103(HFTask):
VERSION = 0
NLP_PATH = "wikitext"
NLP_NAME = "wikitext-103-raw-v1"
......@@ -64,6 +65,7 @@ class WikiText103(HFTask):
class WikiText2(HFTask):
VERSION = 0
NLP_PATH = "wikitext"
NLP_NAME = "wikitext-2-raw-v1"
......
......@@ -11,6 +11,7 @@ Reference: https://arxiv.org/abs/1806.02847
class Winogrande(HFTask):
VERSION = 0
DATASET_PATH = "winogrande"
DATASET_NAME = "winogrande_xl"
......
......@@ -12,6 +12,7 @@ See: https://arxiv.org/abs/1806.02847
class WinogradSchemaChallenge273(HFTask):
VERSION = 0
DATASET_PATH = "winograd_wsc"
DATASET_NAME = "wsc273"
......
......@@ -53,20 +53,35 @@ def main():
f.write(dumped)
# MAKE TABLE
from pytablewriter import MarkdownTableWriter
from pytablewriter import MarkdownTableWriter, LatexTableWriter
writer = MarkdownTableWriter()
writer.headers = ["Task", "Metric", "Value"]
md_writer = MarkdownTableWriter()
latex_writer = LatexTableWriter()
md_writer.headers = ["Task", "Version", "Metric", "Value", "", "Stderr"]
latex_writer.headers = ["Task", "Version", "Metric", "Value", "", "Stderr"]
values = []
for k, dic in results.items():
for k, dic in results["results"].items():
version = results["versions"][k]
for m, v in dic.items():
values.append([k, m, '%.4f' % v])
if m.endswith("_stderr"): continue
if m + "_stderr" in dic:
se = dic[m + "_stderr"]
values.append([k, version, m, '%.4f' % v, '±', '%.4f' % se])
else:
values.append([k, version, m, '%.4f' % v, '', ''])
k = ""
writer.value_matrix = values
version = ""
md_writer.value_matrix = values
latex_writer.value_matrix = values
# todo: make latex table look good
# print(latex_writer.dumps())
print(writer.dumps())
print(md_writer.dumps())
if __name__ == "__main__":
main()
......@@ -22,6 +22,8 @@ def test_basic_interface(taskname, Task):
for v in task.higher_is_better().values(): assert v in [True, False]
assert isinstance(task.VERSION, int)
# test deterministic docs
# (don't test train because it's slow)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment