Commit b1619834 authored by &'s avatar &
Browse files

move task list to init file

parent b5bb7f6c
from pprint import pprint
import sacrebleu
from . import superglue
from . import glue
from . import arc
......@@ -26,6 +28,36 @@ from . import qa4mre
from . import translation
########################################
# Translation tasks
########################################
# 6 total
gpt3_translation_benchmarks = {
"wmt14": ['en-fr', 'fr-en'], # French
"wmt16": ['en-ro', 'ro-en', 'de-en', 'en-de'], # German, Romanian
}
# 28 total
selected_translation_benchmarks = {
**gpt3_translation_benchmarks,
"wmt20": sacrebleu.get_langpairs_for_testset("wmt20"),
"iwslt17": ['en-ar', 'ar-en'] # Arabic
}
# 319 total
all_translation_benchmarks = {
ts: sacrebleu.get_langpairs_for_testset(ts)
for ts in sacrebleu.get_available_testsets()
}
########################################
# All tasks
########################################
TASK_REGISTRY = {
# GLUE
"cola": glue.CoLA,
......@@ -87,12 +119,13 @@ TASK_REGISTRY = {
"arithmetic_5ds": arithmetic.Arithmetic5DMinus,
"arithmetic_2dm": arithmetic.Arithmetic2DMultiplication,
"arithmetic_1dc": arithmetic.Arithmetic1DComposite,
# TODO Perhaps make these groups of tasks
# e.g. anli, arithmetic, openai_translations, harness_translations
# e.g. wmt14-fr-en
**translation.create_tasks_from_benchmarks(translation.selected_benchmarks)
**translation.create_tasks_from_benchmarks(gpt3_translation_benchmarks),
# chef's selection, mostly wmt20
**translation.create_tasks_from_benchmarks(selected_translation_benchmarks),
}
......
......@@ -2,6 +2,7 @@ import abc
import json
import random
import os
from collections import Iterable
from pprint import pprint
import pycountry
......@@ -20,36 +21,9 @@ See sacrebleu.DATASETS for all available datasets. There are a lot!
sacrebleu_datasets = sacrebleu.DATASETS
########################################
# Benchmarks one might want to run
########################################
# 6 total
gpt3_benchmarks = {
"wmt14": ['en-fr', 'fr-en'], # French
"wmt16": ['en-ro', 'ro-en', 'de-en', 'en-de'], # German, Romanian
}
# 14 total
selected_benchmarks = {
**gpt3_benchmarks,
"wmt20": ['fr-de', 'de-fr', 'en-ru', 'ru-en', 'en-iu', 'iu-en'], # French, German, Russian, Inuit
"iwslt17": ['en-ar', 'ar-en'] # Arabic
}
# 319 total
all_benchmarks = {
ts: sacrebleu.get_langpairs_for_testset(ts)
for ts in sacrebleu.get_available_testsets()
}
available_tests = {
"gpt3_tests": gpt3_benchmarks,
"selected_tests": selected_benchmarks,
"all_tests": all_benchmarks
}
def create_tasks_from_benchmarks(benchmark_dict):
"""Creates a dictionary of tasks from a dict
:param benchmark_dict: { dataset: [lang_pair, ...] }
:param benchmark_dict: { dataset: [lang_pair, ...], }
:return: {task_name: task}
e.g. {wmt14-fr-en: Task, wmt16-de-en: Task}
"""
......@@ -115,9 +89,8 @@ class GeneralTranslationTask(Task):
return doc["src"]
def doc_to_target(self, doc):
# TODO Note that some exotic tests have multiple ref lines.
# How does sacrebleu handle opening these files?
return doc["ref"]
# This shows a single target, though there may be multiple targets in a lang test
return doc["ref"] if isinstance(doc["ref"], str) else doc["ref"][0]
def construct_requests(self, doc, ctx):
""" Uses RequestFactory to construct Requests and returns an iterable of
......@@ -229,7 +202,6 @@ if __name__ == "__main__":
main()
########################################
# Don't mind me...!
########################################
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment