Commit 61aa411b authored by &'s avatar &
Browse files

Fixed metrics. Translation works

parent 1297e342
env
*.pyc
data/
.idea
\ No newline at end of file
.idea
lm_cache
\ No newline at end of file
import math
from pprint import pprint
import numpy as np
import sacrebleu
......@@ -73,7 +74,7 @@ def bleu(items):
"""
refs = list(zip(*items))[0]
preds = list(zip(*items))[1]
refs, preds = sacreformat(refs, preds)
refs, preds = _sacreformat(refs, preds)
return sacrebleu.corpus_bleu(preds, refs).score
......@@ -87,7 +88,7 @@ def chrf(items):
"""
refs = list(zip(*items))[0]
preds = list(zip(*items))[1]
refs, preds = sacreformat(refs, preds)
refs, preds = _sacreformat(refs, preds)
return sacrebleu.corpus_chrf(preds, refs).score
......@@ -102,22 +103,33 @@ def ter(items):
"""
refs = list(zip(*items))[0]
preds = list(zip(*items))[1]
refs, preds = sacreformat(refs, preds)
refs, preds = _sacreformat(refs, preds)
return sacrebleu.corpus_ter(preds, refs).score
def sacreformat(refs, preds):
def _sacreformat(refs, preds):
"""Format refs and preds for sacrebleu corpus calculation. It is very particular"""
# Sacrebleu expects List[List[str]
# Sacrebleu expects (List[str], List[List[str])
# e.g. sacrebleu.corpus_bleu([pred_t], [[ref1_stream], [ref2_stream], ...])
# We expect refs to be List[str] or List[List[str]]
# Note [ref1_stream] is the first reference for each pred.
# So lists are size N and (M, N) for N preds and M possible refs for each pred
# This is a different order of dimensions that I would expect
# We expect refs to be List[str] or List[List[str]], the outer list corresponding to preds
# Must become List[List[str]] with the inner list corresponding to preds
if not isinstance(refs, list):
refs = list(refs)
if not isinstance(refs[0], list):
refs = [[ref] for ref in refs]
refs = list(zip(*refs))
# Note the number of refs in each ref list much match the number of preds
# We expect preds to be List[str] or List[List[str]]. Must become List[str]
if not isinstance(preds, list):
preds = list(preds)
if isinstance(preds[0], list):
assert len(preds[0]) == 1, f"Pred must be a str, was {preds[0]}"
preds = [pred[0] for pred in preds]
return refs, preds
......@@ -24,7 +24,6 @@ sacrebleu_datasets = sacrebleu.DATASETS
# Benchmarks one might want to run
########################################
# 6 total
gpt3_benchmarks = {
"wmt14": ['en-fr', 'fr-en'], # French
......@@ -48,19 +47,22 @@ available_tests = {
"all_tests": all_benchmarks
}
########################################
# Tasks
########################################
def create_tasks_from_benchmarks(benchmark_dict):
"""Creates a dictionary of tasks from a dict {dataset: [lang_pair, ...]}"""
"""Creates a dictionary of tasks from a dict
:param benchmark_dict: { dataset: [lang_pair, ...] }
:return: {task_name: task}
e.g. {wmt14-fr-en: Task, wmt16-de-en: Task}
"""
return {
f"{dataset}-{language_pair}": create_translation_task(dataset, language_pair)
for dataset, language_pairs in benchmark_dict.items()
for language_pair in language_pairs
}
########################################
# Tasks
########################################
def create_translation_task(dataset, language_pair):
class TranslationTask(GeneralTranslationTask):
def __init__(self):
......@@ -198,7 +200,6 @@ def print_available_tests():
def main():
# print(sacrebleu.download_test_set("wmt14", "en-fr"))
# print_available_tests()
# sacrebleu.print_test_set("wmt14", "fr-en", "src")
......@@ -212,12 +213,18 @@ def main():
# Test task dictionary
# for task, task_class in create_tasks_from_benchmarks(selected_benchmarks).items():
# print(task, task_class())
pass
if __name__ == "__main__":
main()
########################################
# Don't mind me...!
########################################
# Available tests as of 2020/02/11
"""
{'iwslt17': ['en-fr',
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment