Commit 61aa411b authored by &'s avatar &
Browse files

Fixed metrics. Translation works

parent 1297e342
env env
*.pyc *.pyc
data/ data/
.idea .idea
\ No newline at end of file lm_cache
\ No newline at end of file
import math import math
from pprint import pprint
import numpy as np import numpy as np
import sacrebleu import sacrebleu
...@@ -73,7 +74,7 @@ def bleu(items): ...@@ -73,7 +74,7 @@ def bleu(items):
""" """
refs = list(zip(*items))[0] refs = list(zip(*items))[0]
preds = list(zip(*items))[1] preds = list(zip(*items))[1]
refs, preds = sacreformat(refs, preds) refs, preds = _sacreformat(refs, preds)
return sacrebleu.corpus_bleu(preds, refs).score return sacrebleu.corpus_bleu(preds, refs).score
...@@ -87,7 +88,7 @@ def chrf(items): ...@@ -87,7 +88,7 @@ def chrf(items):
""" """
refs = list(zip(*items))[0] refs = list(zip(*items))[0]
preds = list(zip(*items))[1] preds = list(zip(*items))[1]
refs, preds = sacreformat(refs, preds) refs, preds = _sacreformat(refs, preds)
return sacrebleu.corpus_chrf(preds, refs).score return sacrebleu.corpus_chrf(preds, refs).score
...@@ -102,22 +103,33 @@ def ter(items): ...@@ -102,22 +103,33 @@ def ter(items):
""" """
refs = list(zip(*items))[0] refs = list(zip(*items))[0]
preds = list(zip(*items))[1] preds = list(zip(*items))[1]
refs, preds = sacreformat(refs, preds) refs, preds = _sacreformat(refs, preds)
return sacrebleu.corpus_ter(preds, refs).score return sacrebleu.corpus_ter(preds, refs).score
def sacreformat(refs, preds): def _sacreformat(refs, preds):
"""Format refs and preds for sacrebleu corpus calculation. It is very particular""" """Format refs and preds for sacrebleu corpus calculation. It is very particular"""
# Sacrebleu expects List[List[str] # Sacrebleu expects (List[str], List[List[str])
# e.g. sacrebleu.corpus_bleu([pred_t], [[ref1_stream], [ref2_stream], ...]) # e.g. sacrebleu.corpus_bleu([pred_t], [[ref1_stream], [ref2_stream], ...])
# We expect refs to be List[str] or List[List[str]] # Note [ref1_stream] is the first reference for each pred.
# So lists are size N and (M, N) for N preds and M possible refs for each pred
# This is a different order of dimensions that I would expect
# We expect refs to be List[str] or List[List[str]], the outer list corresponding to preds
# Must become List[List[str]] with the inner list corresponding to preds
if not isinstance(refs, list): if not isinstance(refs, list):
refs = list(refs) refs = list(refs)
if not isinstance(refs[0], list): if not isinstance(refs[0], list):
refs = [[ref] for ref in refs] refs = [[ref] for ref in refs]
refs = list(zip(*refs))
# Note the number of refs in each ref list much match the number of preds
# We expect preds to be List[str] or List[List[str]]. Must become List[str]
if not isinstance(preds, list): if not isinstance(preds, list):
preds = list(preds) preds = list(preds)
if isinstance(preds[0], list):
assert len(preds[0]) == 1, f"Pred must be a str, was {preds[0]}"
preds = [pred[0] for pred in preds]
return refs, preds return refs, preds
...@@ -24,7 +24,6 @@ sacrebleu_datasets = sacrebleu.DATASETS ...@@ -24,7 +24,6 @@ sacrebleu_datasets = sacrebleu.DATASETS
# Benchmarks one might want to run # Benchmarks one might want to run
######################################## ########################################
# 6 total # 6 total
gpt3_benchmarks = { gpt3_benchmarks = {
"wmt14": ['en-fr', 'fr-en'], # French "wmt14": ['en-fr', 'fr-en'], # French
...@@ -48,19 +47,22 @@ available_tests = { ...@@ -48,19 +47,22 @@ available_tests = {
"all_tests": all_benchmarks "all_tests": all_benchmarks
} }
########################################
# Tasks
########################################
def create_tasks_from_benchmarks(benchmark_dict): def create_tasks_from_benchmarks(benchmark_dict):
"""Creates a dictionary of tasks from a dict {dataset: [lang_pair, ...]}""" """Creates a dictionary of tasks from a dict
:param benchmark_dict: { dataset: [lang_pair, ...] }
:return: {task_name: task}
e.g. {wmt14-fr-en: Task, wmt16-de-en: Task}
"""
return { return {
f"{dataset}-{language_pair}": create_translation_task(dataset, language_pair) f"{dataset}-{language_pair}": create_translation_task(dataset, language_pair)
for dataset, language_pairs in benchmark_dict.items() for dataset, language_pairs in benchmark_dict.items()
for language_pair in language_pairs for language_pair in language_pairs
} }
########################################
# Tasks
########################################
def create_translation_task(dataset, language_pair): def create_translation_task(dataset, language_pair):
class TranslationTask(GeneralTranslationTask): class TranslationTask(GeneralTranslationTask):
def __init__(self): def __init__(self):
...@@ -198,7 +200,6 @@ def print_available_tests(): ...@@ -198,7 +200,6 @@ def print_available_tests():
def main(): def main():
# print(sacrebleu.download_test_set("wmt14", "en-fr")) # print(sacrebleu.download_test_set("wmt14", "en-fr"))
# print_available_tests() # print_available_tests()
# sacrebleu.print_test_set("wmt14", "fr-en", "src") # sacrebleu.print_test_set("wmt14", "fr-en", "src")
...@@ -212,12 +213,18 @@ def main(): ...@@ -212,12 +213,18 @@ def main():
# Test task dictionary # Test task dictionary
# for task, task_class in create_tasks_from_benchmarks(selected_benchmarks).items(): # for task, task_class in create_tasks_from_benchmarks(selected_benchmarks).items():
# print(task, task_class()) # print(task, task_class())
pass
if __name__ == "__main__": if __name__ == "__main__":
main() main()
########################################
# Don't mind me...!
########################################
# Available tests as of 2020/02/11 # Available tests as of 2020/02/11
""" """
{'iwslt17': ['en-fr', {'iwslt17': ['en-fr',
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment