"model/models/git@developer.sourcefind.cn:OpenDAS/ollama.git" did not exist on "df94175a0fb0356c9b9e9a62b73d908633c08810"
Unverified Commit 02e5f796 authored by Amil Khare, committed by GitHub
[examples] consolidate summarization examples (#4837)

parent 9f5d5a53
### Get CNN Data
Both types of models require the CNN data, but each obtains it through a different procedure.
#### For BART models
To be able to reproduce the authors' results on the CNN/Daily Mail dataset, you first need to download both the CNN and Daily Mail datasets [from Kyunghyun Cho's website](https://cs.nyu.edu/~kcho/DMQA/) (the links next to "Stories") into the same folder. Then uncompress the archives by running:
```bash
wget https://s3.amazonaws.com/datasets.huggingface.co/summarization/cnn_dm.tgz
tar -xzvf cnn_dm.tgz
```
This should create a directory called `cnn_dm/` with files like `test.source`.
To use your own data, copy that file's format: each article to be summarized is on its own line.
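For instance, a minimal sketch of writing your own articles in that format (the file and variable names here are just placeholders):

```python
# Write one article per line; newlines inside an article must be collapsed.
# "my_articles" and "my_data.source" are hypothetical names for illustration.
my_articles = [
    "First article, with internal newlines collapsed so it fits on one line.",
    "Second article to summarize.",
]
with open("my_data.source", "w", encoding="utf-8") as f:
    for article in my_articles:
        f.write(article.replace("\n", " ") + "\n")
```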
#### For T5 models
First, you need to download the CNN data. It's about 400 MB and can be downloaded by running
```bash
python download_cnn_daily_mail.py cnn_articles_input_data.txt cnn_articles_reference_summaries.txt
```
You should confirm that each file has 11490 lines:
```bash
wc -l cnn_articles_input_data.txt # should print 11490
wc -l cnn_articles_reference_summaries.txt # should print 11490
```
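If `wc` is not available (e.g. on Windows), a quick Python equivalent of the same check:

```python
# Both files should report 11490 lines.
for path in ["cnn_articles_input_data.txt", "cnn_articles_reference_summaries.txt"]:
    with open(path, encoding="utf-8") as f:
        print(path, sum(1 for _ in f))
```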
### Evaluation
To create summaries for each article in the dataset, run:
```bash
python evaluate_cnn.py <path_to_test.source> test_generations.txt <model-name>
```
The default batch size, 8, fits in 16GB GPU memory, but may need to be adjusted to fit your system.
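The generation can also be driven from Python by importing `generate_summaries` from the evaluation script (defined later in this diff); a minimal sketch, assuming `evaluate_cnn.py` is on the import path:

```python
from evaluate_cnn import generate_summaries

# Writes one summary per input line to test_generations.txt.
articles = [line.rstrip() for line in open("cnn_dm/test.source", encoding="utf-8")]
generate_summaries(articles, "test_generations.txt", "facebook/bart-large-cnn", batch_size=8)
```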
### Training
Run/modify `finetune_bart.sh` or `finetune_t5.sh`.
### Where is the code?
The core model is in `src/transformers/modeling_bart.py`. This directory only contains examples.
## (WIP) Rouge Scores
To create summaries for each article in the dataset and also calculate ROUGE scores, run:
```bash
python evaluate_cnn.py <path_to_test.source> test_generations.txt <model-name> --reference_path <path_to_correct_summaries> --score_path <path_to_save_rouge_scores>
```
The ROUGE scores rouge1, rouge2 and rougeL are automatically calculated and saved in ``<path_to_save_rouge_scores>``.
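The body of `calculate_rouge` is elided in this diff; below is a minimal sketch of an equivalent computation with the `rouge_score` package the script imports, not necessarily the exact implementation:

```python
from rouge_score import rouge_scorer, scoring

def rouge_sketch(output_lns, reference_lns, score_path):
    # Score each (reference, hypothesis) pair, then aggregate with bootstrap resampling.
    scorer = rouge_scorer.RougeScorer(["rouge1", "rouge2", "rougeL"], use_stemmer=True)
    aggregator = scoring.BootstrapAggregator()
    for reference, hypothesis in zip(reference_lns, output_lns):
        aggregator.add_scores(scorer.score(reference, hypothesis))
    result = aggregator.aggregate()  # metric name -> AggregateScore(low, mid, high)
    with open(score_path, "w", encoding="utf-8") as f:
        for name, agg in result.items():
            f.write(f"{name}: {agg.mid.fmeasure:.4f}\n")
```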
### Stanford CoreNLP Setup
```
ptb_tokenize () {
...
import argparse
from pathlib import Path

import torch
from tqdm import tqdm

from transformers import BartForConditionalGeneration, BartTokenizer


DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


def chunks(lst, n):
    """Yield successive n-sized chunks from lst."""
    for i in range(0, len(lst), n):
        yield lst[i : i + n]


def generate_summaries(
    examples: list, out_file: str, model_name: str, batch_size: int = 8, device: str = DEFAULT_DEVICE
):
    """Summarize each example with the given BART checkpoint, writing one summary per line."""
    fout = Path(out_file).open("w")
    model = BartForConditionalGeneration.from_pretrained(model_name).to(device)
    tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")

    max_length = 140
    min_length = 55

    for batch in tqdm(list(chunks(examples, batch_size))):
        dct = tokenizer.batch_encode_plus(batch, max_length=1024, return_tensors="pt", pad_to_max_length=True)
        summaries = model.generate(
            input_ids=dct["input_ids"].to(device),
            attention_mask=dct["attention_mask"].to(device),
            num_beams=4,
            length_penalty=2.0,
            max_length=max_length + 2,  # +2 from original because we start at step=1 and stop before max_length
            min_length=min_length + 1,  # +1 from original because we start at step=1
            no_repeat_ngram_size=3,
            early_stopping=True,
            decoder_start_token_id=model.config.eos_token_id,
        )
        dec = [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in summaries]
        for hypothesis in dec:
            fout.write(hypothesis + "\n")
            fout.flush()


def run_generate():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "source_path", type=str, help="like cnn_dm/test.source",
    )
    parser.add_argument(
        "output_path", type=str, help="where to save summaries",
    )
    parser.add_argument(
        "model_name", type=str, default="facebook/bart-large-cnn", help="like bart-large-cnn",
    )
    parser.add_argument(
        "--device", type=str, required=False, default=DEFAULT_DEVICE, help="cuda, cuda:1, cpu etc.",
    )
    parser.add_argument(
        "--bs", type=int, default=8, required=False, help="batch size: how many to summarize at a time",
    )
    args = parser.parse_args()
    examples = [" " + x.rstrip() for x in open(args.source_path).readlines()]
    generate_summaries(examples, args.output_path, args.model_name, batch_size=args.bs, device=args.device)


if __name__ == "__main__":
    run_generate()
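For a single document, the batching and file handling above reduce to a few lines; here is a sketch using the same checkpoint and generation settings as the script (the sample sentence is borrowed from the test file below):

```python
import torch
from transformers import BartForConditionalGeneration, BartTokenizer

device = "cuda" if torch.cuda.is_available() else "cpu"
model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn").to(device)
tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")

article = "New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County."
batch = tokenizer.batch_encode_plus([" " + article], max_length=1024, return_tensors="pt", pad_to_max_length=True)
generated = model.generate(
    input_ids=batch["input_ids"].to(device),
    attention_mask=batch["attention_mask"].to(device),
    num_beams=4,
    length_penalty=2.0,
    max_length=142,  # 140 + 2, as in generate_summaries above
    min_length=56,   # 55 + 1
    no_repeat_ngram_size=3,
    early_stopping=True,
    decoder_start_token_id=model.config.eos_token_id,
)
print(tokenizer.decode(generated[0], skip_special_tokens=True, clean_up_tokenization_spaces=False))
```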
# -*- coding: utf-8 -*-
import argparse
from pathlib import Path
...

def main(input_path, reference_path, data_dir):
    cnn_ds = tfds.load("cnn_dailymail", split="test", shuffle_files=False, data_dir=data_dir)
    cnn_ds_iter = tfds.as_numpy(cnn_ds)
    test_articles_file = Path(input_path).open("w", encoding="utf-8")
    test_summaries_file = Path(reference_path).open("w", encoding="utf-8")
    for example in cnn_ds_iter:
        test_articles_file.write(example["article"].decode("utf-8") + "\n")
        ...
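The line that writes the reference summaries is elided above; it presumably reads the dataset's `highlights` field. A quick sketch for peeking at one test example with `tensorflow_datasets` (the `highlights` key is an assumption here, since that line is not shown):

```python
import tensorflow_datasets as tfds

# Inspect a single test example; "article" appears in the diff above,
# "highlights" (the reference summary) is assumed for the elided line.
ds = tfds.as_numpy(tfds.load("cnn_dailymail", split="test", shuffle_files=False))
example = next(iter(ds))
print(example["article"].decode("utf-8")[:200])
print(example["highlights"].decode("utf-8")[:200])
```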
...
import torch
from rouge_score import rouge_scorer, scoring
from tqdm import tqdm

from transformers import AutoModelWithLMHead, AutoTokenizer


DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


def chunks(lst, n):
    ...
    yield lst[i : i + n]


def generate_summaries(
    examples: list, out_file: str, model_name: str, batch_size: int = 8, device: str = DEFAULT_DEVICE
):
    fout = Path(out_file).open("w", encoding="utf-8")
    model = AutoModelWithLMHead.from_pretrained(model_name).to(device)
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # update config with summarization specific params
    task_specific_params = model.config.task_specific_params
    if task_specific_params is not None:
        model.config.update(task_specific_params.get("summarization", {}))

    for batch in tqdm(list(chunks(examples, batch_size))):
        if "t5" in model_name:
            batch = [model.config.prefix + text for text in batch]
        dct = tokenizer.batch_encode_plus(batch, max_length=1024, return_tensors="pt", pad_to_max_length=True).to(
            device
        )
        summaries = model.generate(**dct)
        dec = tokenizer.batch_decode(summaries, skip_special_tokens=True, clean_up_tokenization_spaces=False)
        for hypothesis in dec:
            fout.write(hypothesis + "\n")
            fout.flush()


def calculate_rouge(output_lns, reference_lns, score_path):
    ...


def run_generate():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "input_path", type=str, help="like cnn_dm/test.source or cnn_dm/test_articles_input.txt",
    )
    parser.add_argument(
        "output_path", type=str, help="where to save summaries",
    )
    parser.add_argument(
        "model_name",
        type=str,
        default="facebook/bart-large-cnn",
        help="like bart-large-cnn, 't5-small', 't5-base', 't5-large', 't5-3b', 't5-11b",
    )
    parser.add_argument("--reference_path", type=str, required=False, help="like cnn_dm/test_reference_summaries.txt")
    parser.add_argument(
        "--score_path", type=str, required=False, help="where to save the rouge score",
    )
    parser.add_argument(
        "--device", type=str, required=False, default=DEFAULT_DEVICE, help="cuda, cuda:1, cpu etc.",
    )
    parser.add_argument(
        "--bs", type=int, default=8, required=False, help="batch size: how many to summarize at a time",
    )
    args = parser.parse_args()
    examples = [" " + x.rstrip() if "t5" in args.model_name else x.rstrip() for x in open(args.input_path).readlines()]

    generate_summaries(examples, args.output_path, args.model_name, batch_size=args.bs, device=args.device)
    if args.score_path is not None:
        output_lns = [x.rstrip() for x in open(args.output_path).readlines()]
        reference_lns = [x.rstrip() for x in open(args.reference_path).readlines()]
        calculate_rouge(output_lns, reference_lns, args.score_path)


if __name__ == "__main__":
    ...
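The `task_specific_params` update in `generate_summaries` is what lets one script serve both model families: T5 checkpoints ship summarization defaults (beam size, length limits, the generation prefix) in their config under `task_specific_params`. A quick way to inspect them, with the exact values depending on the checkpoint:

```python
from transformers import AutoConfig

# For T5 checkpoints this typically contains a "summarization" entry with
# generation defaults such as num_beams, max_length and prefix.
config = AutoConfig.from_pretrained("t5-base")
print(config.task_specific_params)
```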
export OUTPUT_DIR_NAME=t5
export CURRENT_DIR=${PWD}
export OUTPUT_DIR=${CURRENT_DIR}/${OUTPUT_DIR_NAME}
# Make output directory if it doesn't exist
mkdir -p $OUTPUT_DIR
# Add parent directory to python path to access lightning_base.py
export PYTHONPATH="../../":"${PYTHONPATH}"
python finetune.py \
--data_dir=./cnn-dailymail/cnn_dm \
--model_name_or_path=t5-large \
--model_type=t5 \
--learning_rate=3e-5 \
--train_batch_size=4 \
--eval_batch_size=4 \
--output_dir=$OUTPUT_DIR \
--do_train $@
***This script evaluates the multitask pre-trained checkpoint for ``t5-base`` (see the paper [here](https://arxiv.org/pdf/1910.10683.pdf)) on the CNN/Daily Mail test dataset. Please note that the results in the paper were attained using a model fine-tuned on summarization, so the results here will be worse by approx. 0.5 ROUGE points.***
### Get the CNN Data
First, you need to download the CNN data. It's about 400 MB and can be downloaded by running
```bash
python download_cnn_daily_mail.py cnn_articles_input_data.txt cnn_articles_reference_summaries.txt
```
You should confirm that each file has 11490 lines:
```bash
wc -l cnn_articles_input_data.txt # should print 11490
wc -l cnn_articles_reference_summaries.txt # should print 11490
```
### Generating Summaries
To create summaries for each article in the dataset, run:
```bash
python evaluate_cnn.py cnn_articles_input_data.txt cnn_generated_articles_summaries.txt cnn_articles_reference_summaries.txt rouge_score.txt
```
The default batch size, 8, fits in 16GB GPU memory, but may need to be adjusted to fit your system.
The ROUGE scores rouge1, rouge2 and rougeL are automatically calculated and saved in ``rouge_score.txt``.
### Finetuning
Pass `--model_type=t5` and a T5 model via `--model_name_or_path` to `examples/summarization/bart/finetune.py`, as in `finetune_t5.sh` above.
import logging
import sys
import tempfile
import unittest
from pathlib import Path
from unittest.mock import patch

from .evaluate_cnn import run_generate


output_file_name = "output_t5_sum.txt"
score_file_name = "score_t5_sum.txt"
articles = ["New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County."]

logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger()


class TestT5Examples(unittest.TestCase):
    def test_t5_cli(self):
        stream_handler = logging.StreamHandler(sys.stdout)
        logger.addHandler(stream_handler)
        tmp = Path(tempfile.gettempdir()) / "utest_generations_t5_sum.hypo"
        with tmp.open("w") as f:
            f.write("\n".join(articles))
        output_file_name = Path(tempfile.gettempdir()) / "utest_output_t5_sum.hypo"
        score_file_name = Path(tempfile.gettempdir()) / "utest_score_t5_sum.hypo"
        testargs = [
            "evaluate_cnn.py",
            "patrickvonplaten/t5-tiny-random",
            str(tmp),
            str(output_file_name),
            str(tmp),
            str(score_file_name),
        ]
        with patch.object(sys, "argv", testargs):
            run_generate()
            self.assertTrue(Path(output_file_name).exists())
            self.assertTrue(Path(score_file_name).exists())
class TestBartExamples(unittest.TestCase):
    ...
        # show that targets were truncated
        self.assertEqual(batch["target_ids"].shape[1], trunc_target)  # Truncated
        self.assertGreater(max_len_target, trunc_target)  # Truncated


class TestT5Examples(unittest.TestCase):
    def test_t5_cli(self):
        output_file_name = "output_t5_sum.txt"
        score_file_name = "score_t5_sum.txt"
        articles = ["New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County."]
        stream_handler = logging.StreamHandler(sys.stdout)
        logger.addHandler(stream_handler)
        tmp = Path(tempfile.gettempdir()) / "utest_generations_t5_sum.hypo"
        with tmp.open("w", encoding="utf-8") as f:
            f.write("\n".join(articles))
        output_file_name = Path(tempfile.gettempdir()) / "utest_output_t5_sum.hypo"
        score_file_name = Path(tempfile.gettempdir()) / "utest_score_t5_sum.hypo"
        testargs = [
            "evaluate_cnn.py",
            str(tmp),
            str(output_file_name),
            "patrickvonplaten/t5-tiny-random",
            "--reference_path",
            str(tmp),
            "--score_path",
            str(score_file_name),
        ]
        with patch.object(sys, "argv", testargs):
            run_generate()
            self.assertTrue(Path(output_file_name).exists())
            self.assertTrue(Path(score_file_name).exists())