Unverified Commit 1e5194a2 authored by Leo Gao's avatar Leo Gao Committed by GitHub
Browse files

Merge pull request #153 from EleutherAI/translation-v2

Translation v2
parents 758b9e3c d0a301cc
import math
from collections import Iterable
from pprint import pprint
import numpy as np
......@@ -107,6 +108,10 @@ def ter(items):
return sacrebleu.corpus_ter(preds, refs).score
def is_non_str_iterable(obj):
return isinstance(obj, Iterable) and not isinstance(obj, str)
def _sacreformat(refs, preds):
"""Format refs and preds for sacrebleu corpus calculation. It is very particular"""
# Sacrebleu expects (List[str], List[List[str])
......@@ -118,17 +123,17 @@ def _sacreformat(refs, preds):
# We expect refs to be List[str] or List[List[str]], the outer list corresponding to preds
# Must become List[List[str]] with the inner list corresponding to preds
if not isinstance(refs, list):
if not is_non_str_iterable(refs):
refs = list(refs)
if not isinstance(refs[0], list):
if not is_non_str_iterable(refs):
refs = [[ref] for ref in refs]
refs = list(zip(*refs))
# Note the number of refs in each ref list much match the number of preds
# We expect preds to be List[str] or List[List[str]]. Must become List[str]
if not isinstance(preds, list):
if not is_non_str_iterable(preds):
preds = list(preds)
if isinstance(preds[0], list):
if is_non_str_iterable(preds[0]):
assert len(preds[0]) == 1, f"Pred must be a str, was {preds[0]}"
preds = [pred[0] for pred in preds]
......
from pprint import pprint
import sacrebleu
from . import superglue
from . import glue
from . import arc
......@@ -27,6 +29,36 @@ from . import translation
from . import headqa
from . import mathqa
########################################
# Translation tasks
########################################
# 6 total
gpt3_translation_benchmarks = {
"wmt14": ['en-fr', 'fr-en'], # French
"wmt16": ['en-ro', 'ro-en', 'de-en', 'en-de'], # German, Romanian
}
# 28 total
selected_translation_benchmarks = {
**gpt3_translation_benchmarks,
"wmt20": sacrebleu.get_langpairs_for_testset("wmt20"),
"iwslt17": ['en-ar', 'ar-en'] # Arabic
}
# 319 total
all_translation_benchmarks = {
ts: sacrebleu.get_langpairs_for_testset(ts)
for ts in sacrebleu.get_available_testsets()
}
########################################
# All tasks
########################################
TASK_REGISTRY = {
# GLUE
"cola": glue.CoLA,
......@@ -90,12 +122,13 @@ TASK_REGISTRY = {
"arithmetic_5ds": arithmetic.Arithmetic5DMinus,
"arithmetic_2dm": arithmetic.Arithmetic2DMultiplication,
"arithmetic_1dc": arithmetic.Arithmetic1DComposite,
# TODO Perhaps make these groups of tasks
# e.g. anli, arithmetic, openai_translations, harness_translations
# e.g. wmt14-fr-en
**translation.create_tasks_from_benchmarks(translation.selected_benchmarks)
**translation.create_tasks_from_benchmarks(gpt3_translation_benchmarks),
# chef's selection, mostly wmt20
**translation.create_tasks_from_benchmarks(selected_translation_benchmarks),
}
......
......@@ -2,6 +2,7 @@ import abc
import json
import random
import os
from collections import Iterable
from pprint import pprint
import pycountry
......@@ -20,36 +21,9 @@ See sacrebleu.DATASETS for all available datasets. There are a lot!
sacrebleu_datasets = sacrebleu.DATASETS
########################################
# Benchmarks one might want to run
########################################
# 6 total
gpt3_benchmarks = {
"wmt14": ['en-fr', 'fr-en'], # French
"wmt16": ['en-ro', 'ro-en', 'de-en', 'en-de'], # German, Romanian
}
# 14 total
selected_benchmarks = {
**gpt3_benchmarks,
"wmt20": ['fr-de', 'de-fr', 'en-ru', 'ru-en', 'en-iu', 'iu-en'], # French, German, Russian, Inuit
"iwslt17": ['en-ar', 'ar-en'] # Arabic
}
# 319 total
all_benchmarks = {
ts: sacrebleu.get_langpairs_for_testset(ts)
for ts in sacrebleu.get_available_testsets()
}
available_tests = {
"gpt3_tests": gpt3_benchmarks,
"selected_tests": selected_benchmarks,
"all_tests": all_benchmarks
}
def create_tasks_from_benchmarks(benchmark_dict):
"""Creates a dictionary of tasks from a dict
:param benchmark_dict: { dataset: [lang_pair, ...] }
:param benchmark_dict: { dataset: [lang_pair, ...], }
:return: {task_name: task}
e.g. {wmt14-fr-en: Task, wmt16-de-en: Task}
"""
......@@ -115,9 +89,8 @@ class GeneralTranslationTask(Task):
return doc["src"]
def doc_to_target(self, doc):
# TODO Note that some exotic tests have multiple ref lines.
# How does sacrebleu handle opening these files?
return doc["ref"]
# This shows a single target, though there may be multiple targets in a lang test
return doc["ref"] if isinstance(doc["ref"], str) else doc["ref"][0]
def construct_requests(self, doc, ctx):
""" Uses RequestFactory to construct Requests and returns an iterable of
......@@ -199,6 +172,14 @@ def print_available_tests():
pprint({ts: sacrebleu.get_langpairs_for_testset(ts) for ts in sacrebleu.get_available_testsets()})
def print_available_pairs():
list_of_pairs = [sacrebleu.get_langpairs_for_testset(ts) for ts in sacrebleu.get_available_testsets()]
pairs = set([item for sublist in list_of_pairs for item in sublist])
pairs = sorted(["-".join(map(code_to_language, pair.split("-"))) for pair in pairs])
pprint(pairs)
print(len(pairs))
def main():
# print(sacrebleu.download_test_set("wmt14", "en-fr"))
# print_available_tests()
......@@ -213,6 +194,7 @@ def main():
# Test task dictionary
# for task, task_class in create_tasks_from_benchmarks(selected_benchmarks).items():
# print(task, task_class())
print_available_pairs()
pass
......@@ -220,7 +202,6 @@ if __name__ == "__main__":
main()
########################################
# Don't mind me...!
########################################
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment