Unverified Commit 448bcbfb authored by Charles Lovering, committed by GitHub

Merge pull request #3 from cjlovering/add-rouge

Add `ROUGE` metric to `PromptSourceTask`
parents d40a7ce0 ff89667f
@@ -652,8 +652,6 @@ class PromptSourceTask(Task):
     added default behavior for. If you want to add default behavior for a new metric,
     update the functions below. If you want to use one of the following metrics
     *and* add additional custom processing, override `process_results`, `higher_is_better`, and `aggregation`.
-    WARNING: ROUGE is WIP.
     """
     CONFIGURED_PS_METRICS = set(["Accuracy", "BLEU", "ROUGE"])
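For readers following along: per the docstring above, a task that needs custom handling on top of these defaults overrides the three hooks. A minimal sketch of such a subclass (the class name and the extra `pred_length` stat are hypothetical, purely illustrative):

```python
class MyGenerationTask(PromptSourceTask):
    """Hypothetical subclass illustrating the override pattern."""

    def process_results(self, doc, results):
        # Keep the default metric handling, then add a custom statistic.
        out = super().process_results(doc, results)
        out["pred_length"] = len(results[0].strip())  # illustrative only
        return out

    def higher_is_better(self):
        out = super().higher_is_better()
        out["pred_length"] = False  # direction is arbitrary for this example
        return out

    def aggregation(self):
        out = super().aggregation()
        out["pred_length"] = mean  # assumes the harness-provided `mean` aggregator
        return out
```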
@@ -765,7 +763,6 @@
         # NOTE: In the future, target will be a list of strings.
         pred = results[0].strip()
         out = {}
         for metric in self.prompt.metadata.metrics:
             assert (
                 metric in self.CONFIGURED_PS_METRICS
@@ -773,8 +770,15 @@
             if metric == "BLEU":
                 out["bleu"] = (target, pred)
             if metric == "ROUGE":
-                print("WARNING: Skipping Rouge.")
+                # TODO: This computes all rouge sub-metrics. Find a generic
+                # way to handle user-specified rouge sub-metrics to avoid
+                # extra compute.
+                rouge_scores = metrics.rouge(target, pred)
+                # Flatten the nested rouge score dict.
+                rouge_scores = utils.flatten(rouge_scores)
+                # Merge all the rouge-type scores into the `out` dict.
+                out = {**out, **rouge_scores}
+                print(out)
         return out

     def higher_is_better(self):
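For orientation, the nested dict coming back from `metrics.rouge` is assumed to look roughly like the sketch below (values are made up); flattening it is what produces the twelve `rouge*_*` keys consumed by `higher_is_better` and `aggregation`:

```python
# Assumed shape of the metrics.rouge output, inferred from the flattened
# keys used below; the numbers are placeholders.
rouge_scores = {
    "rouge1": {"precision": 0.50, "recall": 0.40, "fmeasure": 0.44},
    "rouge2": {"precision": 0.30, "recall": 0.25, "fmeasure": 0.27},
    "rougeL": {"precision": 0.45, "recall": 0.38, "fmeasure": 0.41},
    "rougeLsum": {"precision": 0.46, "recall": 0.39, "fmeasure": 0.42},
}
# utils.flatten joins nested keys with "_":
# {"rouge1_precision": 0.50, "rouge1_recall": 0.40, ..., "rougeLsum_fmeasure": 0.42}
```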
@@ -788,7 +792,22 @@
             if metric == "BLEU":
                 out["bleu"] = True
             if metric == "ROUGE":
-                print("WARNING: Skipping Rouge.")
+                # TODO: Find a generic way to handle user-specified rouge metrics.
+                out["rouge1_precision"] = True
+                out["rouge1_recall"] = True
+                out["rouge1_fmeasure"] = True
+                out["rouge2_precision"] = True
+                out["rouge2_recall"] = True
+                out["rouge2_fmeasure"] = True
+                out["rougeL_precision"] = True
+                out["rougeL_recall"] = True
+                out["rougeL_fmeasure"] = True
+                out["rougeLsum_precision"] = True
+                out["rougeLsum_recall"] = True
+                out["rougeLsum_fmeasure"] = True
         return out

     def aggregation(self):
@@ -802,7 +821,22 @@
             if metric == "BLEU":
                 out["bleu"] = metrics.bleu
             if metric == "ROUGE":
-                print("WARNING: Skipping Rouge.")
+                # TODO: Find a generic way to handle user-specified rouge metrics.
+                out["rouge1_precision"] = mean
+                out["rouge1_recall"] = mean
+                out["rouge1_fmeasure"] = mean
+                out["rouge2_precision"] = mean
+                out["rouge2_recall"] = mean
+                out["rouge2_fmeasure"] = mean
+                out["rougeL_precision"] = mean
+                out["rougeL_recall"] = mean
+                out["rougeL_fmeasure"] = mean
+                out["rougeLsum_precision"] = mean
+                out["rougeLsum_recall"] = mean
+                out["rougeLsum_fmeasure"] = mean
         return out
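A possible follow-up to the TODOs above: the twelve hard-coded entries here and in `higher_is_better` could be generated from the sub-metric names. A sketch, assuming the same key naming convention:

```python
import itertools

ROUGE_TYPES = ["rouge1", "rouge2", "rougeL", "rougeLsum"]
ROUGE_SUBSCORES = ["precision", "recall", "fmeasure"]

for rouge_type, subscore in itertools.product(ROUGE_TYPES, ROUGE_SUBSCORES):
    out[f"{rouge_type}_{subscore}"] = mean  # or `True` in higher_is_better
```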
@@ -202,6 +202,15 @@ def rouge(
     :param pred:
         A single prediction `str`.
     """
+    # Add newlines between sentences to correctly compute `rougeLsum`.
+    if "rougeLsum" in rouge_types:
+        # TODO: Adapt this to handle languages that do not mark sentence
+        # endings with `.`. See the GEM-metrics implementation, which uses
+        # language-specific `nltk` tokenizers to split sentences.
+        pred = pred.replace(".", ".\n")
+        refs = [ref.replace(".", ".\n") for ref in refs]
     scorer = rouge_scorer.RougeScorer(rouge_types=rouge_types, use_stemmer=True)
     # ROUGE multi-ref jackknifing
     if len(refs) > 1:
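The diff is collapsed after the multi-reference branch. Standard ROUGE jackknifing (leave-one-out over the references, best score per fold, averaged) would look roughly like the sketch below; this is an assumption about the collapsed code, not a copy of it:

```python
from rouge_score import rouge_scorer

def jackknife_rouge(refs, pred, rouge_types=("rougeL",)):
    # Leave-one-out jackknifing: for k references, score against each
    # (k-1)-reference subset, keep the best fmeasure per subset, then average.
    scorer = rouge_scorer.RougeScorer(rouge_types=list(rouge_types), use_stemmer=True)
    folds = []
    for i in range(len(refs)):
        subset = refs[:i] + refs[i + 1:]
        folds.append({
            t: max(scorer.score(ref, pred)[t].fmeasure for ref in subset)
            for t in rouge_types
        })
    return {t: sum(f[t] for f in folds) / len(folds) for t in rouge_types}
```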
@@ -146,6 +146,19 @@ class Reorderer:
         return res

+def flatten(d, parent_key='', sep='_'):
+    # From: https://stackoverflow.com/a/6027615
+    items = []
+    for k, v in d.items():
+        new_key = parent_key + sep + k if parent_key else k
+        if isinstance(v, collections.abc.MutableMapping):
+            items.extend(flatten(v, new_key, sep=sep).items())
+        else:
+            items.append((new_key, v))
+    return dict(items)
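A quick check of `flatten`'s behavior on an arbitrary nested dict (hypothetical values):

```python
>>> flatten({"a": {"b": 1, "c": {"d": 2}}, "e": 3})
{'a_b': 1, 'a_c_d': 2, 'e': 3}
```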
 def positional_deprecated(fn):
     """
     A decorator to nudge users into passing only keyword args (`kwargs`) to the