Unverified Commit 1e5194a2 authored by Leo Gao's avatar Leo Gao Committed by GitHub
Browse files

Merge pull request #153 from EleutherAI/translation-v2

Translation v2
parents 758b9e3c d0a301cc
import math import math
from collections import Iterable
from pprint import pprint from pprint import pprint
import numpy as np import numpy as np
...@@ -107,6 +108,10 @@ def ter(items): ...@@ -107,6 +108,10 @@ def ter(items):
return sacrebleu.corpus_ter(preds, refs).score return sacrebleu.corpus_ter(preds, refs).score
def is_non_str_iterable(obj):
return isinstance(obj, Iterable) and not isinstance(obj, str)
def _sacreformat(refs, preds): def _sacreformat(refs, preds):
"""Format refs and preds for sacrebleu corpus calculation. It is very particular""" """Format refs and preds for sacrebleu corpus calculation. It is very particular"""
# Sacrebleu expects (List[str], List[List[str]) # Sacrebleu expects (List[str], List[List[str])
...@@ -118,17 +123,17 @@ def _sacreformat(refs, preds): ...@@ -118,17 +123,17 @@ def _sacreformat(refs, preds):
# We expect refs to be List[str] or List[List[str]], the outer list corresponding to preds # We expect refs to be List[str] or List[List[str]], the outer list corresponding to preds
# Must become List[List[str]] with the inner list corresponding to preds # Must become List[List[str]] with the inner list corresponding to preds
if not isinstance(refs, list): if not is_non_str_iterable(refs):
refs = list(refs) refs = list(refs)
if not isinstance(refs[0], list): if not is_non_str_iterable(refs):
refs = [[ref] for ref in refs] refs = [[ref] for ref in refs]
refs = list(zip(*refs)) refs = list(zip(*refs))
# Note the number of refs in each ref list much match the number of preds # Note the number of refs in each ref list much match the number of preds
# We expect preds to be List[str] or List[List[str]]. Must become List[str] # We expect preds to be List[str] or List[List[str]]. Must become List[str]
if not isinstance(preds, list): if not is_non_str_iterable(preds):
preds = list(preds) preds = list(preds)
if isinstance(preds[0], list): if is_non_str_iterable(preds[0]):
assert len(preds[0]) == 1, f"Pred must be a str, was {preds[0]}" assert len(preds[0]) == 1, f"Pred must be a str, was {preds[0]}"
preds = [pred[0] for pred in preds] preds = [pred[0] for pred in preds]
......
from pprint import pprint from pprint import pprint
import sacrebleu
from . import superglue from . import superglue
from . import glue from . import glue
from . import arc from . import arc
...@@ -27,6 +29,36 @@ from . import translation ...@@ -27,6 +29,36 @@ from . import translation
from . import headqa from . import headqa
from . import mathqa from . import mathqa
########################################
# Translation tasks
########################################
# 6 total
gpt3_translation_benchmarks = {
"wmt14": ['en-fr', 'fr-en'], # French
"wmt16": ['en-ro', 'ro-en', 'de-en', 'en-de'], # German, Romanian
}
# 28 total
selected_translation_benchmarks = {
**gpt3_translation_benchmarks,
"wmt20": sacrebleu.get_langpairs_for_testset("wmt20"),
"iwslt17": ['en-ar', 'ar-en'] # Arabic
}
# 319 total
all_translation_benchmarks = {
ts: sacrebleu.get_langpairs_for_testset(ts)
for ts in sacrebleu.get_available_testsets()
}
########################################
# All tasks
########################################
TASK_REGISTRY = { TASK_REGISTRY = {
# GLUE # GLUE
"cola": glue.CoLA, "cola": glue.CoLA,
...@@ -90,12 +122,13 @@ TASK_REGISTRY = { ...@@ -90,12 +122,13 @@ TASK_REGISTRY = {
"arithmetic_5ds": arithmetic.Arithmetic5DMinus, "arithmetic_5ds": arithmetic.Arithmetic5DMinus,
"arithmetic_2dm": arithmetic.Arithmetic2DMultiplication, "arithmetic_2dm": arithmetic.Arithmetic2DMultiplication,
"arithmetic_1dc": arithmetic.Arithmetic1DComposite, "arithmetic_1dc": arithmetic.Arithmetic1DComposite,
# TODO Perhaps make these groups of tasks # TODO Perhaps make these groups of tasks
# e.g. anli, arithmetic, openai_translations, harness_translations # e.g. anli, arithmetic, openai_translations, harness_translations
# e.g. wmt14-fr-en # e.g. wmt14-fr-en
**translation.create_tasks_from_benchmarks(translation.selected_benchmarks) **translation.create_tasks_from_benchmarks(gpt3_translation_benchmarks),
# chef's selection, mostly wmt20
**translation.create_tasks_from_benchmarks(selected_translation_benchmarks),
} }
......
...@@ -2,6 +2,7 @@ import abc ...@@ -2,6 +2,7 @@ import abc
import json import json
import random import random
import os import os
from collections import Iterable
from pprint import pprint from pprint import pprint
import pycountry import pycountry
...@@ -20,36 +21,9 @@ See sacrebleu.DATASETS for all available datasets. There are a lot! ...@@ -20,36 +21,9 @@ See sacrebleu.DATASETS for all available datasets. There are a lot!
sacrebleu_datasets = sacrebleu.DATASETS sacrebleu_datasets = sacrebleu.DATASETS
########################################
# Benchmarks one might want to run
########################################
# 6 total
gpt3_benchmarks = {
"wmt14": ['en-fr', 'fr-en'], # French
"wmt16": ['en-ro', 'ro-en', 'de-en', 'en-de'], # German, Romanian
}
# 14 total
selected_benchmarks = {
**gpt3_benchmarks,
"wmt20": ['fr-de', 'de-fr', 'en-ru', 'ru-en', 'en-iu', 'iu-en'], # French, German, Russian, Inuit
"iwslt17": ['en-ar', 'ar-en'] # Arabic
}
# 319 total
all_benchmarks = {
ts: sacrebleu.get_langpairs_for_testset(ts)
for ts in sacrebleu.get_available_testsets()
}
available_tests = {
"gpt3_tests": gpt3_benchmarks,
"selected_tests": selected_benchmarks,
"all_tests": all_benchmarks
}
def create_tasks_from_benchmarks(benchmark_dict): def create_tasks_from_benchmarks(benchmark_dict):
"""Creates a dictionary of tasks from a dict """Creates a dictionary of tasks from a dict
:param benchmark_dict: { dataset: [lang_pair, ...] } :param benchmark_dict: { dataset: [lang_pair, ...], }
:return: {task_name: task} :return: {task_name: task}
e.g. {wmt14-fr-en: Task, wmt16-de-en: Task} e.g. {wmt14-fr-en: Task, wmt16-de-en: Task}
""" """
...@@ -115,9 +89,8 @@ class GeneralTranslationTask(Task): ...@@ -115,9 +89,8 @@ class GeneralTranslationTask(Task):
return doc["src"] return doc["src"]
def doc_to_target(self, doc): def doc_to_target(self, doc):
# TODO Note that some exotic tests have multiple ref lines. # This shows a single target, though there may be multiple targets in a lang test
# How does sacrebleu handle opening these files? return doc["ref"] if isinstance(doc["ref"], str) else doc["ref"][0]
return doc["ref"]
def construct_requests(self, doc, ctx): def construct_requests(self, doc, ctx):
""" Uses RequestFactory to construct Requests and returns an iterable of """ Uses RequestFactory to construct Requests and returns an iterable of
...@@ -199,6 +172,14 @@ def print_available_tests(): ...@@ -199,6 +172,14 @@ def print_available_tests():
pprint({ts: sacrebleu.get_langpairs_for_testset(ts) for ts in sacrebleu.get_available_testsets()}) pprint({ts: sacrebleu.get_langpairs_for_testset(ts) for ts in sacrebleu.get_available_testsets()})
def print_available_pairs():
list_of_pairs = [sacrebleu.get_langpairs_for_testset(ts) for ts in sacrebleu.get_available_testsets()]
pairs = set([item for sublist in list_of_pairs for item in sublist])
pairs = sorted(["-".join(map(code_to_language, pair.split("-"))) for pair in pairs])
pprint(pairs)
print(len(pairs))
def main(): def main():
# print(sacrebleu.download_test_set("wmt14", "en-fr")) # print(sacrebleu.download_test_set("wmt14", "en-fr"))
# print_available_tests() # print_available_tests()
...@@ -213,6 +194,7 @@ def main(): ...@@ -213,6 +194,7 @@ def main():
# Test task dictionary # Test task dictionary
# for task, task_class in create_tasks_from_benchmarks(selected_benchmarks).items(): # for task, task_class in create_tasks_from_benchmarks(selected_benchmarks).items():
# print(task, task_class()) # print(task, task_class())
print_available_pairs()
pass pass
...@@ -220,7 +202,6 @@ if __name__ == "__main__": ...@@ -220,7 +202,6 @@ if __name__ == "__main__":
main() main()
######################################## ########################################
# Don't mind me...! # Don't mind me...!
######################################## ########################################
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment