Unverified commit 448bcbfb authored by Charles Lovering, committed by GitHub

Merge pull request #3 from cjlovering/add-rouge

Add `ROUGE` metric to `PromptSourceTask`
parents d40a7ce0 ff89667f
@@ -652,8 +652,6 @@ class PromptSourceTask(Task):
     added default behavior for. If you want to add default behavior for a new metric,
     update the functions below. If you want to use one of the following metrics,
     *and* add additional custom processing, override `process_results`, `higher_is_better`, and `aggregation`.
-
-    WARNING: ROUGE is WIP.
     """

     CONFIGURED_PS_METRICS = set(["Accuracy", "BLEU", "ROUGE"])
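
For orientation, `process_results` (next hunk) asserts that every metric a prompt requests appears in this set before dispatching on it. A minimal sketch of that guard in isolation; "Edit Distance" is an invented stand-in for a metric with no default behavior:

```python
# Sketch of the configured-metric guard; "Edit Distance" is a made-up
# metric name that has no default behavior yet.
CONFIGURED_PS_METRICS = set(["Accuracy", "BLEU", "ROUGE"])

for metric in ["BLEU", "ROUGE", "Edit Distance"]:
    assert (
        metric in CONFIGURED_PS_METRICS
    ), f"Unexpected metric: {metric}. Add default behavior for it first."
# The last iteration raises: AssertionError: Unexpected metric: Edit Distance. ...
```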
@@ -765,7 +763,6 @@ class PromptSourceTask(Task):
         # NOTE: In the future, target will be a list of strings.
         pred = results[0].strip()
         out = {}
-
         for metric in self.prompt.metadata.metrics:
             assert (
                 metric in self.CONFIGURED_PS_METRICS
@@ -773,8 +770,15 @@ class PromptSourceTask(Task):
             if metric == "BLEU":
                 out["bleu"] = (target, pred)
             if metric == "ROUGE":
-                print("WARNING: Skipping Rouge.")
+                # TODO: This computes all ROUGE sub-metrics. Find a generic
+                # way to handle user-specified ROUGE sub-metrics and avoid
+                # the extra compute.
+                rouge_scores = metrics.rouge(target, pred)
+                # Flatten the nested ROUGE score dict.
+                rouge_scores = utils.flatten(rouge_scores)
+                # Merge all the ROUGE-type scores into the `out` dict.
+                out = {**out, **rouge_scores}
         return out
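
To see where keys like `rouge1_precision` come from, here is a small standalone sketch of the flatten-and-merge step. It assumes `metrics.rouge` returns one plain sub-dict of precision/recall/fmeasure per ROUGE type (the numbers are made up), and inlines the flattening that `utils.flatten` (added below) performs:

```python
# Assumed shape of `metrics.rouge(target, pred)`; the numbers are made up.
rouge_scores = {
    "rouge1": {"precision": 0.50, "recall": 0.40, "fmeasure": 0.44},
    "rouge2": {"precision": 0.25, "recall": 0.20, "fmeasure": 0.22},
}

# Equivalent to `utils.flatten(rouge_scores)`: join nested keys with "_".
flat = {
    f"{rouge_type}_{stat}": value
    for rouge_type, stats in rouge_scores.items()
    for stat, value in stats.items()
}

out = {"bleu": ("target text", "pred text")}
out = {**out, **flat}
print(sorted(out))
# ['bleu', 'rouge1_fmeasure', 'rouge1_precision', 'rouge1_recall',
#  'rouge2_fmeasure', 'rouge2_precision', 'rouge2_recall']
```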
     def higher_is_better(self):
@@ -788,7 +792,22 @@ class PromptSourceTask(Task):
             if metric == "BLEU":
                 out["bleu"] = True
             if metric == "ROUGE":
-                print("WARNING: Skipping Rouge.")
+                # TODO: Find a generic way to handle user-specified ROUGE sub-metrics.
+                out["rouge1_precision"] = True
+                out["rouge1_recall"] = True
+                out["rouge1_fmeasure"] = True
+                out["rouge2_precision"] = True
+                out["rouge2_recall"] = True
+                out["rouge2_fmeasure"] = True
+                out["rougeL_precision"] = True
+                out["rougeL_recall"] = True
+                out["rougeL_fmeasure"] = True
+                out["rougeLsum_precision"] = True
+                out["rougeLsum_recall"] = True
+                out["rougeLsum_fmeasure"] = True
         return out
     def aggregation(self):
@@ -802,7 +821,22 @@ class PromptSourceTask(Task):
             if metric == "BLEU":
                 out["bleu"] = metrics.bleu
             if metric == "ROUGE":
-                print("WARNING: Skipping Rouge.")
+                # TODO: Find a generic way to handle user-specified ROUGE sub-metrics.
+                out["rouge1_precision"] = mean
+                out["rouge1_recall"] = mean
+                out["rouge1_fmeasure"] = mean
+                out["rouge2_precision"] = mean
+                out["rouge2_recall"] = mean
+                out["rouge2_fmeasure"] = mean
+                out["rougeL_precision"] = mean
+                out["rougeL_recall"] = mean
+                out["rougeL_fmeasure"] = mean
+                out["rougeLsum_precision"] = mean
+                out["rougeLsum_recall"] = mean
+                out["rougeLsum_fmeasure"] = mean
         return out
...
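
On the TODOs above: one possible route to user-specified sub-metrics is to generate the twelve keys instead of hard-coding them. A sketch, not part of this commit; `statistics.mean` stands in for the module's `mean` aggregator:

```python
from statistics import mean  # stand-in for the aggregator used above

ROUGE_TYPES = ["rouge1", "rouge2", "rougeL", "rougeLsum"]
ROUGE_STATS = ["precision", "recall", "fmeasure"]
rouge_keys = [f"{t}_{s}" for t in ROUGE_TYPES for s in ROUGE_STATS]

higher_is_better = {key: True for key in rouge_keys}
aggregation = {key: mean for key in rouge_keys}
assert len(aggregation) == 12 and "rougeLsum_fmeasure" in aggregation
```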
@@ -202,6 +202,15 @@ def rouge(
     :param pred:
         A single prediction `str`.
     """
+    # Add newlines between sentences to correctly compute `rougeLsum`.
+    if "rougeLsum" in rouge_types:
+        # TODO: Adapt this to handle languages that do not mark sentence
+        # endings with `.`. See the GEM-metrics implementation, which uses
+        # language-specific `nltk` tokenizers to split sentences.
+        pred = pred.replace(".", ".\n")
+        refs = [ref.replace(".", ".\n") for ref in refs]
     scorer = rouge_scorer.RougeScorer(rouge_types=rouge_types, use_stemmer=True)
     # ROUGE multi-ref jackknifing
     if len(refs) > 1:
...
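
The newline insertion matters because Google's `rouge-score` package computes `rougeLsum` per line: with no line breaks, the whole text is treated as a single "sentence" and the score degenerates to plain LCS over the full token sequence. A standalone illustration (requires `pip install rouge-score`; the example strings are invented):

```python
from rouge_score import rouge_scorer

scorer = rouge_scorer.RougeScorer(["rougeLsum"], use_stemmer=True)
target = "The cat sat on the mat. It purred loudly."
pred = "It purred loudly. The cat sat on the mat."

# One long line: summary-level LCS sees a single "sentence".
unsplit = scorer.score(target, pred)["rougeLsum"].fmeasure

# One sentence per line: LCS is computed per sentence, then combined.
split = scorer.score(
    target.replace(".", ".\n"), pred.replace(".", ".\n")
)["rougeLsum"].fmeasure

print(unsplit, split)  # the split variant scores higher for this pair
```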
@@ -146,6 +146,19 @@ class Reorderer:
         return res


+def flatten(d, parent_key='', sep='_'):
+    # Flatten a nested dict into a single level, joining key paths with `sep`.
+    # From: https://stackoverflow.com/a/6027615
+    items = []
+    for k, v in d.items():
+        new_key = parent_key + sep + k if parent_key else k
+        # Use `collections.abc.MutableMapping`; the bare `collections`
+        # alias is removed in Python 3.10.
+        if isinstance(v, collections.abc.MutableMapping):
+            items.extend(flatten(v, new_key, sep=sep).items())
+        else:
+            items.append((new_key, v))
+    return dict(items)
+
+
 def positional_deprecated(fn):
     """
     A decorator to nudge users into passing only keyword args (`kwargs`) to the
...
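
For reference, `flatten` behaves like this (doctest-style, with made-up inputs); nesting flattens recursively and `sep` joins the key path:

```python
>>> flatten({"rougeL": {"precision": 1.0, "recall": 0.5}})
{'rougeL_precision': 1.0, 'rougeL_recall': 0.5}
>>> flatten({"a": {"b": {"c": 1}}}, sep=".")
{'a.b.c': 1}
```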