Unverified Commit 5b396457 authored by Sam Shleifer, committed by GitHub

Summarization Examples: add Bart CNN Evaluation (#3082)

* Rename and improve example

* Add test

* slightly faster test

* style

* This breaks remy prolly

* shorter test string

* no slow

* newdir structure

* New tree

* Style

* shorter

* docs

* clean

* Attempt future import

* more import hax
parent 5c5af879
### Get the CNN/Daily Mail Data
To be able to reproduce the authors' results on the CNN/Daily Mail dataset, you first need to download both the CNN and Daily Mail datasets [from Kyunghyun Cho's website](https://cs.nyu.edu/~kcho/DMQA/) (the links next to "Stories") into the same folder. Then uncompress the archives by running:
```bash
tar -xvf cnn_stories.tgz && tar -xvf dailymail_stories.tgz
```
This should create a directory called `cnn_dm/` containing files like `test.source`.
To use your own data, copy that file format: each article to be summarized is on its own line.
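If you want to evaluate on your own articles, you only need to write a file in that same one-article-per-line format. A minimal sketch; the file name `my_test.source` and the example articles are illustrative, not something the scripts require:
```python
# Write a source file in the format evaluate_cnn.py expects: one untokenized article per line.
articles = [
    "First article to summarize, written as a single line of plain text.",
    "Second article to summarize, also on its own line.",
]
with open("my_test.source", "w") as f:
    for article in articles:
        # Collapse any internal newlines so each article stays on a single line.
        f.write(article.replace("\n", " ").strip() + "\n")
```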
### Usage
To create summaries for each article in the dataset, run:
```bash
python evaluate_cnn.py <path_to_test.source> cnn_test_summaries.txt
```
The default batch size of 8 fits in 16 GB of GPU memory, but it may need to be adjusted for your system.
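The script also accepts `--bs` and `--device` flags (see the argparse setup in `evaluate_cnn.py` below), so a run on a smaller or specific GPU might look like the following; the exact values are illustrative:
```bash
# Illustrative: halve the batch size and pin the job to the second GPU.
python evaluate_cnn.py cnn_dm/test.source cnn_test_summaries.txt --bs 4 --device cuda:1
```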
### Where is the code?
The core model is in `src/transformers/modeling_bart.py`. This directory only contains examples.
### (WIP) ROUGE Scores
### Stanford CoreNLP Setup
```bash
ptb_tokenize () {
    cat $1 | java edu.stanford.nlp.process.PTBTokenizer -ioFileList -preserveLines > $2
}

sudo apt install openjdk-8-jre-headless
sudo apt-get install ant
wget http://nlp.stanford.edu/software/stanford-corenlp-full-2018-10-05.zip
unzip stanford-corenlp-full-2018-10-05.zip
cd stanford-corenlp-full-2018-10-05
export CLASSPATH=stanford-corenlp-3.9.2.jar:stanford-corenlp-3.9.2-models.jar
```
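With `CLASSPATH` exported, the `ptb_tokenize` helper above can be used to tokenize both the generated summaries and the reference summaries before scoring. The file names below are illustrative, and this assumes the `cnn_dm/` folder also contains a `test.target` file with the reference summaries:
```bash
# Tokenize hypotheses and references with the helper defined above (file names are illustrative).
ptb_tokenize cnn_test_summaries.txt cnn_test_summaries.tokenized
ptb_tokenize cnn_dm/test.target test_reference.tokenized
```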
### ROUGE Setup
Install `files2rouge` following the instructions [here](https://github.com/pltrdy/files2rouge).
I also needed to run `sudo apt-get install libxml-parser-perl`.
```python
from files2rouge import files2rouge
from files2rouge import settings
files2rouge.run(<path_to_tokenized_hypo>,
                <path_to_tokenized_target>,
                saveto='rouge_output.txt')
```
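The `files2rouge` package also installs a `files2rouge` command-line entry point that takes the two tokenized files directly; assuming that, a shell equivalent of the Python call above might look like this (paths are illustrative):
```bash
# Prints ROUGE-1/2/L scores to the terminal; redirect stdout to keep a copy.
files2rouge cnn_test_summaries.tokenized test_reference.tokenized > rouge_output.txt
```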
`evaluate_cnn.py`:

import argparse
from pathlib import Path

import torch
from tqdm import tqdm

from transformers import BartForMaskedLM, BartTokenizer


DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


def chunks(lst, n):
    """Yield successive n-sized chunks from lst."""
    for i in range(0, len(lst), n):
        yield lst[i : i + n]


def generate_summaries(lns, out_file, batch_size=8, device=DEFAULT_DEVICE):
    fout = Path(out_file).open("w")
    # Put the model on the same device as the input tensors below.
    model = BartForMaskedLM.from_pretrained("bart-large-cnn", output_past=True).to(device)
    tokenizer = BartTokenizer.from_pretrained("bart-large")

    for batch in tqdm(list(chunks(lns, batch_size))):
        # Truncate/pad each article to 1024 tokens and build the attention mask.
        dct = tokenizer.batch_encode_plus(batch, max_length=1024, return_tensors="pt", pad_to_max_length=True)
        summaries = model.generate(
            input_ids=dct["input_ids"].to(device),
            attention_mask=dct["attention_mask"].to(device),
            num_beams=4,
            length_penalty=2.0,
            max_length=140,
            min_len=55,
            no_repeat_ngram_size=3,
        )
        dec = [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in summaries]
        for hypothesis in dec:
            fout.write(hypothesis + "\n")
            fout.flush()


def _run_generate():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "source_path", type=str, help="like cnn_dm/test.source",
    )
    parser.add_argument(
        "output_path", type=str, help="where to save summaries",
    )
    parser.add_argument(
        "--device", type=str, required=False, default=DEFAULT_DEVICE, help="cuda, cuda:1, cpu etc.",
    )
    parser.add_argument(
        "--bs", type=int, default=8, required=False, help="batch size: how many to summarize at a time",
    )
    args = parser.parse_args()
    lns = [" " + x.rstrip() for x in open(args.source_path).readlines()]
    generate_summaries(lns, args.output_path, batch_size=args.bs, device=args.device)


if __name__ == "__main__":
    _run_generate()
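`_run_generate` is only a thin argparse wrapper, so `generate_summaries` can also be called directly from Python (for example in a notebook). A small sketch reusing the functions defined above; the file names are illustrative:
```python
# Illustrative programmatic use of generate_summaries from evaluate_cnn.py.
from evaluate_cnn import generate_summaries

# Same preprocessing as _run_generate: prepend a space and strip trailing whitespace.
lns = [" " + x.rstrip() for x in open("cnn_dm/test.source").readlines()]
generate_summaries(lns, "cnn_test_summaries.txt", batch_size=4, device="cuda")
```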
The accompanying unit test lives in the same directory and drives `evaluate_cnn.py` through its CLI entry point:

import logging
import sys
import tempfile
import unittest
from pathlib import Path
from unittest.mock import patch

from .evaluate_cnn import _run_generate


articles = [" New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County."]

logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger()


class TestBartExamples(unittest.TestCase):
    def test_bart_cnn_cli(self):
        stream_handler = logging.StreamHandler(sys.stdout)
        logger.addHandler(stream_handler)
        tmp = Path(tempfile.gettempdir()) / "utest_generations.hypo"
        with tmp.open("w") as f:
            f.write("\n".join(articles))
        testargs = ["evaluate_cnn.py", str(tmp), "output.txt"]
        with patch.object(sys, "argv", testargs):
            _run_generate()
            self.assertTrue(Path("output.txt").exists())
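Because the test patches `sys.argv` and calls `_run_generate` directly, it runs under any unittest-compatible runner. For example, with pytest (the module path and file name are assumptions about where this test lives in the repo):
```bash
# Run only this test; the path and filename are assumptions.
python -m pytest examples/summarization/bart/test_bart_examples.py -k test_bart_cnn_cli
```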
@@ -15,7 +15,7 @@ pip install nltk py-rouge
 cd examples/summarization
 ```
-## Reproduce the authors' results on ROUGE
+## Reproduce the authors' ROUGE score
 To be able to reproduce the authors' results on the CNN/Daily Mail dataset you first need to download both CNN and Daily Mail datasets [from Kyunghyun Cho's website](https://cs.nyu.edu/~kcho/DMQA/) (the links next to "Stories") in the same folder. Then uncompress the archives by running:
@@ -11,12 +11,13 @@ from tqdm import tqdm
 from modeling_bertabs import BertAbs, build_predictor
 from transformers import BertTokenizer
-from utils_summarization import (
-    SummarizationDataset,
+from .utils_summarization import (
+    CNNDMDataset,
     build_mask,
     compute_token_type_ids,
     encode_for_summarization,
-    fit_to_block_size,
+    truncate_or_pad,
 )
@@ -194,7 +195,7 @@ def build_data_iterator(args, tokenizer):
 def load_and_cache_examples(args, tokenizer):
-    dataset = SummarizationDataset(args.documents_dir)
+    dataset = CNNDMDataset(args.documents_dir)
     return dataset
@@ -211,7 +212,7 @@ def collate(data, tokenizer, block_size, device):
     encoded_text = [encode_for_summarization(story, summary, tokenizer) for _, story, summary in data]
     encoded_stories = torch.tensor(
-        [fit_to_block_size(story, block_size, tokenizer.pad_token_id) for story, _ in encoded_text]
+        [truncate_or_pad(story, block_size, tokenizer.pad_token_id) for story, _ in encoded_text]
     )
     encoder_token_type_ids = compute_token_type_ids(encoded_stories, tokenizer.cls_token_id)
     encoder_mask = build_mask(encoded_stories, tokenizer.pad_token_id)
@@ -17,7 +17,7 @@ import unittest
 import numpy as np
 import torch
-from utils_summarization import build_mask, compute_token_type_ids, fit_to_block_size, process_story
+from .utils_summarization import build_mask, compute_token_type_ids, process_story, truncate_or_pad
 class SummarizationDataProcessingTest(unittest.TestCase):
@@ -28,19 +28,19 @@ class SummarizationDataProcessingTest(unittest.TestCase):
         """ Pad the sequence with 0 if the sequence is smaller than the block size."""
         sequence = [1, 2, 3, 4]
         expected_output = [1, 2, 3, 4, 0, 0, 0, 0, 0, 0]
-        self.assertEqual(fit_to_block_size(sequence, self.block_size, 0), expected_output)
+        self.assertEqual(truncate_or_pad(sequence, self.block_size, 0), expected_output)

     def test_fit_to_block_sequence_fit_exactly(self):
         """ Do nothing if the sequence is the right size. """
         sequence = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
         expected_output = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
-        self.assertEqual(fit_to_block_size(sequence, self.block_size, 0), expected_output)
+        self.assertEqual(truncate_or_pad(sequence, self.block_size, 0), expected_output)

     def test_fit_to_block_sequence_too_big(self):
         """ Truncate the sequence if it is too long. """
         sequence = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]
         expected_output = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
-        self.assertEqual(fit_to_block_size(sequence, self.block_size, 0), expected_output)
+        self.assertEqual(truncate_or_pad(sequence, self.block_size, 0), expected_output)

     def test_process_story_no_highlights(self):
         """ Processing a story with no highlights returns an empty list for the summary.
@@ -10,7 +10,7 @@ from torch.utils.data import Dataset
 # ------------
-class SummarizationDataset(Dataset):
+class CNNDMDataset(Dataset):
     """ Abstracts the dataset used to train seq2seq models.

     The class will process the documents that are located in the specified
@@ -62,11 +62,11 @@ class SummarizationDataset(Dataset):
 def process_story(raw_story):
     """ Extract the story and summary from a story file.

-    Attributes:
+    Arguments:
         raw_story (str): content of the story file as an utf-8 encoded string.

     Raises:
-        IndexError: If the stoy is empty or contains no highlights.
+        IndexError: If the story is empty or contains no highlights.
     """
     nonempty_lines = list(filter(lambda x: len(x) != 0, [line.strip() for line in raw_story.split("\n")]))
@@ -107,7 +107,7 @@ def _add_missing_period(line):
 # --------------------------
-def fit_to_block_size(sequence, block_size, pad_token_id):
+def truncate_or_pad(sequence, block_size, pad_token_id):
     """ Adapt the source and target sequences' lengths to the block size.
     If the sequence is shorter we append padding token to the right of the sequence.
     """
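For reference, the behavior expected of the renamed `truncate_or_pad` helper is pinned down by its docstring and the three unit tests above: right-pad with `pad_token_id` up to `block_size`, and truncate anything longer. A minimal sketch consistent with those tests (not necessarily the exact implementation in `utils_summarization.py`):
```python
def truncate_or_pad(sequence, block_size, pad_token_id):
    """Truncate `sequence` to `block_size`, or right-pad it with `pad_token_id` to that length."""
    if len(sequence) > block_size:
        return sequence[:block_size]
    return sequence + [pad_token_id] * (block_size - len(sequence))
```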