Commit b1619834 authored by &'s avatar &
Browse files

move task list to init file

parent b5bb7f6c
from pprint import pprint from pprint import pprint
import sacrebleu
from . import superglue from . import superglue
from . import glue from . import glue
from . import arc from . import arc
...@@ -26,6 +28,36 @@ from . import qa4mre ...@@ -26,6 +28,36 @@ from . import qa4mre
from . import translation from . import translation
########################################
# Translation tasks
########################################
# 6 total
gpt3_translation_benchmarks = {
"wmt14": ['en-fr', 'fr-en'], # French
"wmt16": ['en-ro', 'ro-en', 'de-en', 'en-de'], # German, Romanian
}
# 28 total
selected_translation_benchmarks = {
**gpt3_translation_benchmarks,
"wmt20": sacrebleu.get_langpairs_for_testset("wmt20"),
"iwslt17": ['en-ar', 'ar-en'] # Arabic
}
# 319 total
all_translation_benchmarks = {
ts: sacrebleu.get_langpairs_for_testset(ts)
for ts in sacrebleu.get_available_testsets()
}
########################################
# All tasks
########################################
TASK_REGISTRY = { TASK_REGISTRY = {
# GLUE # GLUE
"cola": glue.CoLA, "cola": glue.CoLA,
...@@ -87,12 +119,13 @@ TASK_REGISTRY = { ...@@ -87,12 +119,13 @@ TASK_REGISTRY = {
"arithmetic_5ds": arithmetic.Arithmetic5DMinus, "arithmetic_5ds": arithmetic.Arithmetic5DMinus,
"arithmetic_2dm": arithmetic.Arithmetic2DMultiplication, "arithmetic_2dm": arithmetic.Arithmetic2DMultiplication,
"arithmetic_1dc": arithmetic.Arithmetic1DComposite, "arithmetic_1dc": arithmetic.Arithmetic1DComposite,
# TODO Perhaps make these groups of tasks # TODO Perhaps make these groups of tasks
# e.g. anli, arithmetic, openai_translations, harness_translations # e.g. anli, arithmetic, openai_translations, harness_translations
# e.g. wmt14-fr-en # e.g. wmt14-fr-en
**translation.create_tasks_from_benchmarks(translation.selected_benchmarks) **translation.create_tasks_from_benchmarks(gpt3_translation_benchmarks),
# chef's selection, mostly wmt20
**translation.create_tasks_from_benchmarks(selected_translation_benchmarks),
} }
......
...@@ -2,6 +2,7 @@ import abc ...@@ -2,6 +2,7 @@ import abc
import json import json
import random import random
import os import os
from collections import Iterable
from pprint import pprint from pprint import pprint
import pycountry import pycountry
...@@ -20,36 +21,9 @@ See sacrebleu.DATASETS for all available datasets. There are a lot! ...@@ -20,36 +21,9 @@ See sacrebleu.DATASETS for all available datasets. There are a lot!
sacrebleu_datasets = sacrebleu.DATASETS sacrebleu_datasets = sacrebleu.DATASETS
########################################
# Benchmarks one might want to run
########################################
# 6 total
gpt3_benchmarks = {
"wmt14": ['en-fr', 'fr-en'], # French
"wmt16": ['en-ro', 'ro-en', 'de-en', 'en-de'], # German, Romanian
}
# 14 total
selected_benchmarks = {
**gpt3_benchmarks,
"wmt20": ['fr-de', 'de-fr', 'en-ru', 'ru-en', 'en-iu', 'iu-en'], # French, German, Russian, Inuit
"iwslt17": ['en-ar', 'ar-en'] # Arabic
}
# 319 total
all_benchmarks = {
ts: sacrebleu.get_langpairs_for_testset(ts)
for ts in sacrebleu.get_available_testsets()
}
available_tests = {
"gpt3_tests": gpt3_benchmarks,
"selected_tests": selected_benchmarks,
"all_tests": all_benchmarks
}
def create_tasks_from_benchmarks(benchmark_dict): def create_tasks_from_benchmarks(benchmark_dict):
"""Creates a dictionary of tasks from a dict """Creates a dictionary of tasks from a dict
:param benchmark_dict: { dataset: [lang_pair, ...] } :param benchmark_dict: { dataset: [lang_pair, ...], }
:return: {task_name: task} :return: {task_name: task}
e.g. {wmt14-fr-en: Task, wmt16-de-en: Task} e.g. {wmt14-fr-en: Task, wmt16-de-en: Task}
""" """
...@@ -115,9 +89,8 @@ class GeneralTranslationTask(Task): ...@@ -115,9 +89,8 @@ class GeneralTranslationTask(Task):
return doc["src"] return doc["src"]
def doc_to_target(self, doc): def doc_to_target(self, doc):
# TODO Note that some exotic tests have multiple ref lines. # This shows a single target, though there may be multiple targets in a lang test
# How does sacrebleu handle opening these files? return doc["ref"] if isinstance(doc["ref"], str) else doc["ref"][0]
return doc["ref"]
def construct_requests(self, doc, ctx): def construct_requests(self, doc, ctx):
""" Uses RequestFactory to construct Requests and returns an iterable of """ Uses RequestFactory to construct Requests and returns an iterable of
...@@ -229,7 +202,6 @@ if __name__ == "__main__": ...@@ -229,7 +202,6 @@ if __name__ == "__main__":
main() main()
######################################## ########################################
# Don't mind me...! # Don't mind me...!
######################################## ########################################
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment