Commit 1297e342 authored by &'s avatar &
Browse files

add tasks to registry

parent 12ba8426
......@@ -21,6 +21,7 @@ from . import pubmedqa
from . import sciq
from . import webqs
from . import qa4mre
from . import translation
TASK_REGISTRY = {
......@@ -85,6 +86,9 @@ TASK_REGISTRY = {
"arithmetic_2dm": arithmetic.Arithmetic2DMultiplication,
"arithmetic_1dc": arithmetic.Arithmetic1DComposite,
# TODO Perhaps make these groups of tasks
# e.g. anli, arithmetic, openai_translations, harness_translations
**translation.create_tasks_from_benchmarks(translation.selected_benchmarks)
}
......
......@@ -26,26 +26,26 @@ sacrebleu_datasets = sacrebleu.DATASETS
# Benchmark selections, keyed by sacrebleu test-set name, each mapping to a
# list of "src-tgt" language-pair codes.

# 6 total — the translation benchmarks evaluated in the GPT-3 paper
gpt3_benchmarks = {
    "wmt14": ['en-fr', 'fr-en'],  # French
    "wmt16": ['en-ro', 'ro-en', 'de-en', 'en-de'],  # German, Romanian
}

# 14 total — GPT-3 set plus extra recent/low-resource pairs
selected_benchmarks = {
    **gpt3_benchmarks,
    "wmt20": ['fr-de', 'de-fr', 'en-ru', 'ru-en', 'en-iu', 'iu-en'],  # French, German, Russian, Inuit
    "iwslt17": ['en-ar', 'ar-en']  # Arabic
}

# 319 total — every language pair sacrebleu knows about
all_benchmarks = {
    ts: sacrebleu.get_langpairs_for_testset(ts)
    for ts in sacrebleu.get_available_testsets()
}

# Lookup of the named selections above (keys kept for backward compatibility)
available_tests = {
    "gpt3_tests": gpt3_benchmarks,
    "selected_tests": selected_benchmarks,
    "all_tests": all_benchmarks
}
......@@ -53,6 +53,14 @@ available_tests = {
# Tasks
########################################
def create_tasks_from_benchmarks(benchmark_dict):
    """Creates a dictionary of tasks from a dict {dataset: [lang_pair, ...]}

    Task names are "<dataset>-<lang_pair>", e.g. "wmt14-en-fr"; values are
    the task classes produced by `create_translation_task`.
    """
    tasks = {}
    for dataset, language_pairs in benchmark_dict.items():
        for pair in language_pairs:
            tasks[f"{dataset}-{pair}"] = create_translation_task(dataset, pair)
    return tasks
def create_translation_task(dataset, language_pair):
class TranslationTask(GeneralTranslationTask):
def __init__(self):
......@@ -125,10 +133,11 @@ class GeneralTranslationTask(Task):
def process_results(self, doc, results):
    """Return per-document (reference, prediction) pairs for each metric.

    BLEU/chrF/TER are corpus-level metrics, not sentence-level, so we hide
    the raw (ref, prediction) pair in this dict and compute the real corpus
    score later in the aggregate/aggregation step.
    """
    ref_pred = (doc["ref"], results)
    return {
        "bleu": ref_pred,
        "chrf": ref_pred,
        "ter": ref_pred,
    }
def aggregation(self):
......@@ -157,7 +166,9 @@ class GeneralTranslationTask(Task):
def fewshot_description(self):
    """Return a one-line task prompt, e.g. "Translate these French phrases to German."

    Language names are resolved from the "src-tgt" pair via `code_to_language`.
    """
    language_codes = self.sacrebleu_language_pair.split("-")
    src_lang = code_to_language(language_codes[0])
    tar_lang = code_to_language(language_codes[1])
    return f"Translate these {src_lang} phrases to {tar_lang}."
    # TODO This should be something like
    # French: {src_line}
# TODO This should be something like
# French: {src_line}
......@@ -165,6 +176,12 @@ class GeneralTranslationTask(Task):
def fewshot_context(self, doc, num_fewshot, provide_description):
    """Translation tasks use no few-shot context; the prompt is always empty."""
    empty_context = ""
    return empty_context
def __str__(self):
    """Human-readable name, e.g. "WMT14 French to English Task"."""
    src_code, tar_code = self.sacrebleu_language_pair.split("-")
    src_lang = code_to_language(src_code)
    tar_lang = code_to_language(tar_code)
    return f"{self.sacrebleu_dataset.upper()} {src_lang} to {tar_lang} Task"
########################################
# Util
......@@ -173,7 +190,7 @@ class GeneralTranslationTask(Task):
def code_to_language(code):
    """Return the English language name for an ISO code ("fr" or "fra").

    The pycountry lookup key is alpha_2 or alpha_3 depending on the code
    length, so the keyword name is built dynamically and splatted in.
    Raises AttributeError (via `None.name`) if the code is unknown.
    """
    language_tuple = pycountry.languages.get(**{f"alpha_{len(code)}": code})
    return language_tuple.name
def print_available_tests():
......@@ -181,14 +198,20 @@ def print_available_tests():
def main():
    """Ad-hoc smoke tests; uncomment a snippet below to poke at sacrebleu
    or the task registry by hand."""
    pass
    # sacrebleu.print_test_set("wmt14", "fr-en", "src")

    # # Print number of benchmarks
    # print(sum([
    #     len(sacrebleu.get_langpairs_for_testset(ts))
    #     for ts in sacrebleu.get_available_testsets()
    # ]))

    # Test task dictionary
    # for task, task_class in create_tasks_from_benchmarks(selected_benchmarks).items():
    #     print(task, task_class())
if __name__ == "__main__":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment