Commit 1297e342

add tasks to registry

parent 12ba8426
@@ -21,6 +21,7 @@ from . import pubmedqa
 from . import sciq
 from . import webqs
 from . import qa4mre
+from . import translation

 TASK_REGISTRY = {
@@ -85,6 +86,9 @@ TASK_REGISTRY = {
     "arithmetic_2dm": arithmetic.Arithmetic2DMultiplication,
     "arithmetic_1dc": arithmetic.Arithmetic1DComposite,
+    # TODO Perhaps make these groups of tasks
+    #      e.g. anli, arithmetic, openai_translations, harness_translations
+    **translation.create_tasks_from_benchmarks(translation.selected_benchmarks)
 }
...
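The splat on the last entry merges every generated translation task into TASK_REGISTRY alongside the hand-written entries. A sketch (not from the commit) of how a generated task would then be looked up, assuming the f"{dataset}-{language_pair}" key format used by create_tasks_from_benchmarks below:

# Sketch, not part of the commit: fetching one of the generated tasks.
# Keys follow f"{dataset}-{language_pair}", e.g. "wmt14-en-fr".
task_class = TASK_REGISTRY["wmt14-en-fr"]  # a TranslationTask subclass
task = task_class()  # same class create_translation_task("wmt14", "en-fr") builds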
@@ -26,26 +26,26 @@ sacrebleu_datasets = sacrebleu.DATASETS
 # 6 total
-gpt3_tests = {
+gpt3_benchmarks = {
     "wmt14": ['en-fr', 'fr-en'],  # French
     "wmt16": ['en-ro', 'ro-en', 'de-en', 'en-de'],  # German, Romanian
 }

 # 14 total
-selected_tests = {
-    **gpt3_tests,
+selected_benchmarks = {
+    **gpt3_benchmarks,
     "wmt20": ['fr-de', 'de-fr', 'en-ru', 'ru-en', 'en-iu', 'iu-en'],  # French, German, Russian, Inuit
     "iwslt17": ['en-ar', 'ar-en']  # Arabic
 }

 # 319 total
-all_tests = {
+all_benchmarks = {
     ts: sacrebleu.get_langpairs_for_testset(ts)
     for ts in sacrebleu.get_available_testsets()
 }

 available_tests = {
-    "gpt3_tests": gpt3_tests,
-    "selected_tests": selected_tests,
-    "all_tests": all_tests
+    "gpt3_tests": gpt3_benchmarks,
+    "selected_tests": selected_benchmarks,
+    "all_tests": all_benchmarks
 }
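The "# 6 total", "# 14 total", and "# 319 total" comments count language pairs, and can be reproduced with the same sacrebleu helpers the module uses. A quick sanity check (exact totals depend on the installed sacrebleu version):

import sacrebleu

# Count every language pair across all test sets sacrebleu knows about.
total = sum(
    len(sacrebleu.get_langpairs_for_testset(ts))
    for ts in sacrebleu.get_available_testsets()
)
print(total)  # 319 with the sacrebleu version used in this commit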
@@ -53,6 +53,14 @@ available_tests = {
 ########################################
 # Tasks
 ########################################

+def create_tasks_from_benchmarks(benchmark_dict):
+    """Creates a dictionary of tasks from a dict {dataset: [lang_pair, ...]}"""
+    return {
+        f"{dataset}-{language_pair}": create_translation_task(dataset, language_pair)
+        for dataset, language_pairs in benchmark_dict.items()
+        for language_pair in language_pairs
+    }
+
 def create_translation_task(dataset, language_pair):
     class TranslationTask(GeneralTranslationTask):
         def __init__(self):
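A short usage sketch for the new helper, assuming the gpt3_benchmarks dict defined above:

tasks = create_tasks_from_benchmarks(gpt3_benchmarks)
print(sorted(tasks))
# ['wmt14-en-fr', 'wmt14-fr-en', 'wmt16-de-en',
#  'wmt16-en-de', 'wmt16-en-ro', 'wmt16-ro-en']
# Each value is a distinct class built by create_translation_task;
# call it with no arguments to instantiate the task.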
@@ -125,10 +133,11 @@ class GeneralTranslationTask(Task):
     def process_results(self, doc, results):
         # These metrics are corpus-level not sentence level, so we'll hide the
         # results in this dict and compute the corpus score in the aggregate method
+        ref_pred = (doc["ref"], results)
         return {
-            "bleu": (doc["ref"], results),
-            "chrf": (doc["ref"], results),
-            "ter": (doc["ref"], results),
+            "bleu": ref_pred,
+            "chrf": ref_pred,
+            "ter": ref_pred,
         }

     def aggregation(self):
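Because BLEU, chrF, and TER are corpus-level metrics, process_results only stashes (reference, prediction) pairs; the real scoring happens in aggregation. A hedged sketch of an aggregator over those pairs (the commit does not show aggregation's body, and corpus_bleu_from_pairs is a hypothetical name; it assumes results is a list whose first element is the generated text):

import sacrebleu

def corpus_bleu_from_pairs(items):
    # items: list of (ref, results) tuples collected by process_results
    refs = [ref for ref, _ in items]
    preds = [results[0] for _, results in items]
    return sacrebleu.corpus_bleu(preds, [refs]).score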
@@ -157,7 +166,9 @@ class GeneralTranslationTask(Task):
     def fewshot_description(self):
         language_codes = self.sacrebleu_language_pair.split("-")
-        return f"Translate {code_to_language(language_codes[0])} to {language_codes[1]}."
+        src_lang = code_to_language(language_codes[0])
+        tar_lang = code_to_language(language_codes[1])
+        return f"Translate these {src_lang} phrases to {tar_lang}."
         # TODO This should be something like
         # French: {src_line}
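For the "en-fr" pair, the new description renders as "Translate these English phrases to French."; the old version only converted the source code and would have produced "Translate English to fr."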
@@ -165,6 +176,12 @@ class GeneralTranslationTask(Task):
     def fewshot_context(self, doc, num_fewshot, provide_description):
         return ""

+    def __str__(self):
+        language_codes = self.sacrebleu_language_pair.split("-")
+        src_lang = code_to_language(language_codes[0])
+        tar_lang = code_to_language(language_codes[1])
+        return f"{self.sacrebleu_dataset.upper()} {src_lang} to {tar_lang} Task"
+
 ########################################
 # Util
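With the new __str__, str() on the "wmt14" / "fr-en" task reads "WMT14 French to English Task"; the commented-out print(GeneralTranslationTask("wmt14", "fr-en")) added to main() below exercises exactly this.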
@@ -173,7 +190,7 @@ class GeneralTranslationTask(Task):
 def code_to_language(code):
     # key is alpha_2 or alpha_3 depending on the code length
-    language_tuple = pycountry.languages.get({f"alpha_{len(code)}": code})
+    language_tuple = pycountry.languages.get(**{f"alpha_{len(code)}": code})
     return language_tuple.name

 def print_available_tests():
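The added ** matters because pycountry.languages.get takes keyword arguments only, so the lookup key (alpha_2 for two-letter codes, alpha_3 for three-letter codes) has to be splatted in; passing the dict positionally raises a TypeError. For example:

import pycountry

print(pycountry.languages.get(alpha_2="fr").name)   # French
print(pycountry.languages.get(alpha_3="ron").name)  # Romanian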
@@ -181,14 +198,20 @@ def print_available_tests():
 def main():
     # print(sacrebleu.download_test_set("wmt14", "en-fr"))
     # print_available_tests()
-    # print(len(sacrebleu.print_test_set("wmt14", "fr-en", "src")))
+    # sacrebleu.print_test_set("wmt14", "fr-en", "src")
+    # print(GeneralTranslationTask("wmt14", "fr-en"))

-    print(sum(
-        [len(sacrebleu.get_langpairs_for_testset(ts)) for ts in sacrebleu.get_available_testsets()])
-    )
-    pass
+    # # Print number of benchmarks
+    # print(sum([
+    #     len(sacrebleu.get_langpairs_for_testset(ts))
+    #     for ts in sacrebleu.get_available_testsets()
+    # ]))

+    # Test task dictionary
+    # for task, task_class in create_tasks_from_benchmarks(selected_benchmarks).items():
+    #     print(task, task_class())

 if __name__ == "__main__":
...