Commit 116c540a authored by haileyschoelkopf's avatar haileyschoelkopf
Browse files

make chrf and ter aggregations

parent 8806eff5
......@@ -74,6 +74,37 @@ def bleu(items):
return sacrebleu.corpus_bleu(preds, refs).score
@register_aggregation("chrf")
def chrf(items):
"""chrF++ is a tool for automatic evaluation of machine translation output
based on character n-gram precision and recall enhanced with word n-grams.
Source: https://github.com/m-popovic/chrF
Paper: https://www.aclweb.org/anthology/W15-3049.pdf
Higher is better # TODO I think
"""
refs = list(zip(*items))[0]
preds = list(zip(*items))[1]
refs, preds = _sacreformat(refs, preds)
return sacrebleu.corpus_chrf(preds, refs).score
@register_aggregation("ter")
def ter(items):
"""Translation Error Rate is an error metric for machine translation that
measures the number of edits required to change a system output into one
of the references
Source: http://www.cs.umd.edu/~snover/tercom/
Paper: http://mt-archive.info/AMTA-2006-Snover.pdf
Lower is better
"""
refs = list(zip(*items))[0]
preds = list(zip(*items))[1]
refs, preds = _sacreformat(refs, preds)
return sacrebleu.corpus_ter(preds, refs).score
@register_metric(
metric="acc",
higher_is_better=True,
......@@ -188,6 +219,26 @@ def bleu_fn(items): # This is a passthrough function
return items
@register_metric(
metric="chrf",
higher_is_better=True,
output_type="greedy_until",
aggregation="chrf",
)
def chrf_fn(items): # This is a passthrough function
return items
@register_metric(
metric="ter",
higher_is_better=True,
output_type="greedy_until",
aggregation="ter",
)
def ter_fn(items): # This is a passthrough function
return items
@register_metric(
metric="acc_all",
higher_is_better=True,
......@@ -245,37 +296,6 @@ def weighted_mean(items):
return sum(a) / sum(b)
@register_metric(metric="chrf", higher_is_better=True, aggregation="mean")
def chrf(items):
"""chrF++ is a tool for automatic evaluation of machine translation output
based on character n-gram precision and recall enhanced with word n-grams.
Source: https://github.com/m-popovic/chrF
Paper: https://www.aclweb.org/anthology/W15-3049.pdf
Higher is better # TODO I think
"""
refs = list(zip(*items))[0]
preds = list(zip(*items))[1]
refs, preds = _sacreformat(refs, preds)
return sacrebleu.corpus_chrf(preds, refs).score
@register_metric(metric="ter", higher_is_better=True, aggregation="mean")
def ter(items):
"""Translation Error Rate is an error metric for machine translation that
measures the number of edits required to change a system output into one
of the references
Source: http://www.cs.umd.edu/~snover/tercom/
Paper: http://mt-archive.info/AMTA-2006-Snover.pdf
Lower is better
"""
refs = list(zip(*items))[0]
preds = list(zip(*items))[1]
refs, preds = _sacreformat(refs, preds)
return sacrebleu.corpus_ter(preds, refs).score
def is_non_str_iterable(obj):
return isinstance(obj, Iterable) and not isinstance(obj, str)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment