Unverified Commit 7296fea1 authored by Sam Shleifer's avatar Sam Shleifer Committed by GitHub
Browse files

[s2s] rougeLSum expects \n between sentences (#7410)


Co-authored-by: default avatarSwetha Mandava <smandava@nvidia.com>
parent eab5f596
...@@ -11,6 +11,7 @@ git-python==1.0.3 ...@@ -11,6 +11,7 @@ git-python==1.0.3
faiss-cpu faiss-cpu
streamlit streamlit
elasticsearch elasticsearch
nltk
pandas pandas
datasets datasets
fire fire
......
import fire
from utils import calculate_rouge, save_json
def calculate_rouge_path(pred_path, tgt_path, save_path=None, **kwargs):
"""Kwargs will be passed to calculate_rouge"""
pred_lns = [x.strip() for x in open(pred_path).readlines()]
tgt_lns = [x.strip() for x in open(tgt_path).readlines()][: len(pred_lns)]
metrics = calculate_rouge(pred_lns, tgt_lns, **kwargs)
if save_path is not None:
save_json(metrics, save_path)
return metrics # these print nicely
if __name__ == "__main__":
fire.Fire(calculate_rouge_path)
...@@ -7,13 +7,14 @@ import sys ...@@ -7,13 +7,14 @@ import sys
from collections import OrderedDict from collections import OrderedDict
from run_eval import datetime_now, run_generate from run_eval import datetime_now, run_generate
from utils import ROUGE_KEYS
# A table of supported tasks and the list of scores in the order of importance to be sorted by. # A table of supported tasks and the list of scores in the order of importance to be sorted by.
# To add a new task, simply list the score names that `run_eval.run_generate()` returns # To add a new task, simply list the score names that `run_eval.run_generate()` returns
task_score_names = { task_score_names = {
"translation": ["bleu"], "translation": ["bleu"],
"summarization": ["rouge1", "rouge2", "rougeL"], "summarization": ROUGE_KEYS,
} }
......
import re
try:
import nltk
NLTK_AVAILABLE = True
except (ImportError, ModuleNotFoundError):
NLTK_AVAILABLE = False
if NLTK_AVAILABLE:
try:
nltk.download("punkt", quiet=True)
except FileExistsError: # multiprocessing race condition
pass
def add_newline_to_end_of_each_sentence(x: str) -> str:
re.sub("<n>", "", x) # remove pegasus newline char
assert NLTK_AVAILABLE, "nltk must be installed to separate newlines betwee sentences. (pip install nltk)"
return "\n".join(nltk.sent_tokenize(x))
from collections import defaultdict
from pathlib import Path
import pandas as pd
from rouge_cli import calculate_rouge_path
from utils import calculate_rouge
PRED = [
'Prosecutor: "No videos were used in the crash investigation" German papers say they saw a cell phone video of the final seconds on board Flight 9525. The Germanwings co-pilot says he had a "previous episode of severe depression" German airline confirms it knew of Andreas Lubitz\'s depression years before he took control.',
"The Palestinian Authority officially becomes the 123rd member of the International Criminal Court. The formal accession was marked with a ceremony at The Hague, in the Netherlands. The Palestinians signed the ICC's founding Rome Statute in January. Israel and the United States opposed the Palestinians' efforts to join the body.",
"Amnesty International releases its annual report on the death penalty. The report catalogs the use of state-sanctioned killing as a punitive measure across the globe. At least 607 people were executed around the world in 2014, compared to 778 in 2013. The U.S. remains one of the worst offenders for imposing capital punishment.",
]
TGT = [
'Marseille prosecutor says "so far no videos were used in the crash investigation" despite media reports . Journalists at Bild and Paris Match are "very confident" the video clip is real, an editor says . Andreas Lubitz had informed his Lufthansa training school of an episode of severe depression, airline says .',
"Membership gives the ICC jurisdiction over alleged crimes committed in Palestinian territories since last June . Israel and the United States opposed the move, which could open the door to war crimes investigations against Israelis .",
"Amnesty's annual death penalty report catalogs encouraging signs, but setbacks in numbers of those sentenced to death . Organization claims that governments around the world are using the threat of terrorism to advance executions . The number of executions worldwide has gone down by almost 22% compared with 2013, but death sentences up by 28% .",
]
def test_disaggregated_scores_are_determinstic():
no_aggregation = calculate_rouge(PRED, TGT, bootstrap_aggregation=False, rouge_keys=["rouge2", "rougeL"])
assert isinstance(no_aggregation, defaultdict)
no_aggregation_just_r2 = calculate_rouge(PRED, TGT, bootstrap_aggregation=False, rouge_keys=["rouge2"])
assert (
pd.DataFrame(no_aggregation["rouge2"]).fmeasure.mean()
== pd.DataFrame(no_aggregation_just_r2["rouge2"]).fmeasure.mean()
)
def test_newline_cnn_improvement():
k = "rougeLsum"
score = calculate_rouge(PRED, TGT, newline_sep=True, rouge_keys=[k])[k]
score_no_sep = calculate_rouge(PRED, TGT, newline_sep=False, rouge_keys=[k])[k]
assert score > score_no_sep
def test_newline_irrelevant_for_other_metrics():
k = ["rouge1", "rouge2", "rougeL"]
score_sep = calculate_rouge(PRED, TGT, newline_sep=True, rouge_keys=k)
score_no_sep = calculate_rouge(PRED, TGT, newline_sep=False, rouge_keys=k)
assert score_sep == score_no_sep
def test_single_sent_scores_dont_depend_on_newline_sep():
pred = [
"Her older sister, Margot Frank, died in 1945, a month earlier than previously thought.",
'Marseille prosecutor says "so far no videos were used in the crash investigation" despite media reports .',
]
tgt = [
"Margot Frank, died in 1945, a month earlier than previously thought.",
'Prosecutor: "No videos were used in the crash investigation" German papers say they saw a cell phone video of the final seconds on board Flight 9525.',
]
assert calculate_rouge(pred, tgt, newline_sep=True) == calculate_rouge(pred, tgt, newline_sep=False)
def test_pegasus_newline():
pred = [
"""" "a person who has such a video needs to immediately give it to the investigators," prosecutor says .<n> "it is a very disturbing scene," editor-in-chief of bild online tells "erin burnett: outfront" """
]
tgt = [
""" Marseille prosecutor says "so far no videos were used in the crash investigation" despite media reports . Journalists at Bild and Paris Match are "very confident" the video clip is real, an editor says . Andreas Lubitz had informed his Lufthansa training school of an episode of severe depression, airline says ."""
]
prev_score = calculate_rouge(pred, tgt, rouge_keys=["rougeLsum"], newline_sep=False)["rougeLsum"]
new_score = calculate_rouge(pred, tgt, rouge_keys=["rougeLsum"])["rougeLsum"]
assert new_score > prev_score
def test_rouge_cli():
data_dir = Path("examples/seq2seq/test_data/wmt_en_ro")
metrics = calculate_rouge_path(data_dir.joinpath("test.source"), data_dir.joinpath("test.target"))
assert isinstance(metrics, dict)
metrics_default_dict = calculate_rouge_path(
data_dir.joinpath("test.source"), data_dir.joinpath("test.target"), bootstrap_aggregation=False
)
assert isinstance(metrics_default_dict, defaultdict)
...@@ -20,7 +20,7 @@ from run_eval_search import run_search ...@@ -20,7 +20,7 @@ from run_eval_search import run_search
from transformers import AutoConfig, AutoModelForSeq2SeqLM from transformers import AutoConfig, AutoModelForSeq2SeqLM
from transformers.hf_api import HfApi from transformers.hf_api import HfApi
from transformers.testing_utils import CaptureStderr, CaptureStdout, require_multigpu, require_torch_and_cuda, slow from transformers.testing_utils import CaptureStderr, CaptureStdout, require_multigpu, require_torch_and_cuda, slow
from utils import label_smoothed_nll_loss, lmap, load_json from utils import ROUGE_KEYS, label_smoothed_nll_loss, lmap, load_json
logging.basicConfig(level=logging.DEBUG) logging.basicConfig(level=logging.DEBUG)
...@@ -365,7 +365,7 @@ def test_run_eval_search(model): ...@@ -365,7 +365,7 @@ def test_run_eval_search(model):
if "translation" in task: if "translation" in task:
expected_strings.append("bleu") expected_strings.append("bleu")
else: else:
expected_strings.extend(["rouge1", "rouge2", "rougeL"]) expected_strings.extend(ROUGE_KEYS)
for w in expected_strings: for w in expected_strings:
assert w in cs.out assert w in cs.out
for w in un_expected_strings: for w in un_expected_strings:
......
...@@ -18,6 +18,7 @@ from sacrebleu import corpus_bleu ...@@ -18,6 +18,7 @@ from sacrebleu import corpus_bleu
from torch import nn from torch import nn
from torch.utils.data import Dataset, Sampler from torch.utils.data import Dataset, Sampler
from sentence_splitter import add_newline_to_end_of_each_sentence
from transformers import BartTokenizer from transformers import BartTokenizer
from transformers.file_utils import cached_property from transformers.file_utils import cached_property
...@@ -378,19 +379,63 @@ def get_git_info(): ...@@ -378,19 +379,63 @@ def get_git_info():
return repo_infos return repo_infos
ROUGE_KEYS = ["rouge1", "rouge2", "rougeL"] ROUGE_KEYS = ["rouge1", "rouge2", "rougeL", "rougeLsum"]
def calculate_rouge(output_lns: List[str], reference_lns: List[str], use_stemmer=True) -> Dict: def extract_rouge_mid_statistics(dct):
scorer = rouge_scorer.RougeScorer(ROUGE_KEYS, use_stemmer=use_stemmer) new_dict = {}
aggregator = scoring.BootstrapAggregator() for k1, v1 in dct.items():
mid = v1.mid
new_dict[k1] = {stat: round(getattr(mid, stat), 4) for stat in ["precision", "recall", "fmeasure"]}
return new_dict
def calculate_rouge(
pred_lns: List[str],
tgt_lns: List[str],
use_stemmer=True,
rouge_keys=ROUGE_KEYS,
return_precision_and_recall=False,
bootstrap_aggregation=True,
newline_sep=True,
) -> Dict:
"""Calculate rouge using rouge_scorer package.
for reference_ln, output_ln in zip(reference_lns, output_lns): Args:
scores = scorer.score(reference_ln, output_ln) pred_lns: list of summaries generated by model
tgt_lns: list of groundtruth summaries (e.g. contents of val.target)
use_stemmer: Bool indicating whether Porter stemmer should be used to
strip word suffixes to improve matching.
rouge_keys: which metrics to compute, defaults to rouge1, rouge2, rougeL, rougeLsum
return_precision_and_recall: (False) whether to also return precision and recall.
bootstrap_aggregation: whether to do the typical bootstrap resampling of scores. Defaults to True, if False
this function returns a collections.defaultdict[metric: list of values for each observation for each subscore]``
newline_sep:(default=True) whether to add newline between sentences. This is essential for calculation rougeL
on multi sentence summaries (CNN/DM dataset).
Returns:
Dict[score: value] if aggregate else defaultdict(list) keyed by rouge_keys
"""
scorer = rouge_scorer.RougeScorer(rouge_keys, use_stemmer=use_stemmer)
aggregator = scoring.BootstrapAggregator()
for pred, tgt in zip(tgt_lns, pred_lns):
# rougeLsum expects "\n" separated sentences within a summary
if newline_sep:
pred = add_newline_to_end_of_each_sentence(pred)
tgt = add_newline_to_end_of_each_sentence(tgt)
scores = scorer.score(pred, tgt)
aggregator.add_scores(scores) aggregator.add_scores(scores)
result = aggregator.aggregate() if bootstrap_aggregation:
return {k: round(v.mid.fmeasure * 100, 4) for k, v in result.items()} result = aggregator.aggregate()
if return_precision_and_recall:
return extract_rouge_mid_statistics(result) # here we return dict
else:
return {k: round(v.mid.fmeasure * 100, 4) for k, v in result.items()}
else:
return aggregator._scores # here we return defaultdict(list)
# Utilities for freezing parameters and checking whether they are frozen # Utilities for freezing parameters and checking whether they are frozen
...@@ -423,9 +468,6 @@ def assert_not_all_frozen(model): ...@@ -423,9 +468,6 @@ def assert_not_all_frozen(model):
assert any(model_grads), f"none of {npars} weights require grad" assert any(model_grads), f"none of {npars} weights require grad"
# CLI Parsing utils
def parse_numeric_n_bool_cl_kwargs(unparsed_args: List[str]) -> Dict[str, Union[int, float, bool]]: def parse_numeric_n_bool_cl_kwargs(unparsed_args: List[str]) -> Dict[str, Union[int, float, bool]]:
""" """
Parse an argv list of unspecified command line args to a dict. Parse an argv list of unspecified command line args to a dict.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment