Commit d0a301cc authored by &'s avatar &
Browse files

fix metric iterable checks

parent b1619834
import math import math
from collections import Iterable
from pprint import pprint from pprint import pprint
import numpy as np import numpy as np
...@@ -107,6 +108,10 @@ def ter(items): ...@@ -107,6 +108,10 @@ def ter(items):
return sacrebleu.corpus_ter(preds, refs).score return sacrebleu.corpus_ter(preds, refs).score
def is_non_str_iterable(obj):
return isinstance(obj, Iterable) and not isinstance(obj, str)
def _sacreformat(refs, preds): def _sacreformat(refs, preds):
"""Format refs and preds for sacrebleu corpus calculation. It is very particular""" """Format refs and preds for sacrebleu corpus calculation. It is very particular"""
# Sacrebleu expects (List[str], List[List[str]) # Sacrebleu expects (List[str], List[List[str])
...@@ -118,17 +123,17 @@ def _sacreformat(refs, preds): ...@@ -118,17 +123,17 @@ def _sacreformat(refs, preds):
# We expect refs to be List[str] or List[List[str]], the outer list corresponding to preds # We expect refs to be List[str] or List[List[str]], the outer list corresponding to preds
# Must become List[List[str]] with the inner list corresponding to preds # Must become List[List[str]] with the inner list corresponding to preds
if not isinstance(refs, list): if not is_non_str_iterable(refs):
refs = list(refs) refs = list(refs)
if not isinstance(refs[0], list): if not is_non_str_iterable(refs):
refs = [[ref] for ref in refs] refs = [[ref] for ref in refs]
refs = list(zip(*refs)) refs = list(zip(*refs))
# Note the number of refs in each ref list much match the number of preds # Note the number of refs in each ref list much match the number of preds
# We expect preds to be List[str] or List[List[str]]. Must become List[str] # We expect preds to be List[str] or List[List[str]]. Must become List[str]
if not isinstance(preds, list): if not is_non_str_iterable(preds):
preds = list(preds) preds = list(preds)
if isinstance(preds[0], list): if is_non_str_iterable(preds[0]):
assert len(preds[0]) == 1, f"Pred must be a str, was {preds[0]}" assert len(preds[0]) == 1, f"Pred must be a str, was {preds[0]}"
preds = [pred[0] for pred in preds] preds = [pred[0] for pred in preds]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment