Unverified Commit f0c96faf authored by Sam Shleifer, committed by GitHub

[examples] summarization/bart/finetune.py supports t5 (#3824)

renames `run_bart_sum.py` to `finetune.py`
parent 0cec4fab
@@ -19,7 +19,7 @@ except ImportError:
 logger = logging.getLogger(__name__)


-class BartSystem(BaseTransformer):
+class SummarizationTrainer(BaseTransformer):
     mode = "language-modeling"
@@ -64,18 +64,18 @@ class BartSystem(BaseTransformer):
         return {"avg_val_loss": avg_loss, "log": tensorboard_logs}

     def test_step(self, batch, batch_idx):
-        # NOTE: this generation will not use the cache.
         pad_token_id = self.tokenizer.pad_token_id
         source_ids, source_mask, y = SummarizationDataset.trim_seq2seq_batch(batch, pad_token_id)
-        # NOTE: these kwargs get more speed and lower quality summaries than those in evaluate_cnn.py.
+        # NOTE: the following kwargs get more speed and lower quality summaries than those in evaluate_cnn.py
         generated_ids = self.model.generate(
-            source_ids,
-            source_mask,
+            input_ids=source_ids,
+            attention_mask=source_mask,
             num_beams=1,
             max_length=80,
             repetition_penalty=2.5,
             length_penalty=1.0,
             early_stopping=True,
+            use_cache=True,
         )
         preds = [
             self.tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True)
@@ -161,20 +161,20 @@ def main(args):
     if not args.output_dir:
         args.output_dir = os.path.join("./results", f"{args.task}_{args.model_type}_{time.strftime('%Y%m%d_%H%M%S')}",)
         os.makedirs(args.output_dir)
-    model = BartSystem(args)
+    model = SummarizationTrainer(args)
     trainer = generic_train(model, args)

     # Optionally, predict on dev set and write to output_dir
     if args.do_predict:
         checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "checkpointepoch=*.ckpt"), recursive=True)))
-        BartSystem.load_from_checkpoint(checkpoints[-1])
+        SummarizationTrainer.load_from_checkpoint(checkpoints[-1])
         trainer.test(model)


 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     add_generic_args(parser, os.getcwd())
-    parser = BartSystem.add_model_specific_args(parser, os.getcwd())
+    parser = SummarizationTrainer.add_model_specific_args(parser, os.getcwd())
     args = parser.parse_args()
     main(args)
@@ -8,7 +8,7 @@ mkdir -p $OUTPUT_DIR
 # Add parent directory to python path to access transformer_base.py
 export PYTHONPATH="../../":"${PYTHONPATH}"

-python run_bart_sum.py \
+python finetune.py \
 --data_dir=./cnn-dailymail/cnn_dm \
 --model_type=bart \
 --model_name_or_path=bart-large \
...
@@ -14,7 +14,7 @@ mkdir -p $OUTPUT_DIR
 # Add parent directory to python path to access transformer_base.py and utils.py
 export PYTHONPATH="../../":"${PYTHONPATH}"

-python run_bart_sum.py \
+python finetune.py \
 --data_dir=cnn_tiny/ \
 --model_type=bart \
 --model_name_or_path=sshleifer/bart-tiny-random \
...
@@ -12,7 +12,7 @@ from torch.utils.data import DataLoader
 from transformers import BartTokenizer

 from .evaluate_cnn import run_generate
-from .run_bart_sum import main
+from .finetune import main
 from .utils import SummarizationDataset
@@ -92,9 +92,27 @@ class TestBartExamples(unittest.TestCase):
         args_d.update(
             data_dir=tmp_dir, model_type="bart", train_batch_size=2, eval_batch_size=2, n_gpu=0, output_dir=output_dir,
         )
-        args = argparse.Namespace(**args_d)
-        main(args)
+        main(argparse.Namespace(**args_d))
+        args_d.update({"do_train": False, "do_predict": True})
+        main(argparse.Namespace(**args_d))
+
+    def test_t5_run_sum_cli(self):
+        args_d: dict = DEFAULT_ARGS.copy()
+        tmp_dir = make_test_data_dir()
+        output_dir = tempfile.mkdtemp(prefix="output_")
+        args_d.update(
+            data_dir=tmp_dir,
+            model_type="t5",
+            model_name_or_path="patrickvonplaten/t5-tiny-random",
+            train_batch_size=2,
+            eval_batch_size=2,
+            n_gpu=0,
+            output_dir=output_dir,
+            do_predict=True,
+        )
+        main(argparse.Namespace(**args_d))
+        # args_d.update({"do_train": False, "do_predict": True})
+        # main(argparse.Namespace(**args_d))

     def test_bart_summarization_dataset(self):
         tmp_dir = Path(tempfile.gettempdir())
...
@@ -15,7 +15,7 @@ wc -l cnn_articles_input_data.txt # should print 11490
 wc -l cnn_articles_reference_summaries.txt # should print 11490
 ```
-### Usage
+### Generating Summaries

 To create summaries for each article in dataset, run:
 ```bash
@@ -23,3 +23,7 @@ python evaluate_cnn.py cnn_articles_input_data.txt cnn_generated_articles_summar
 ```
 The default batch size, 8, fits in 16GB GPU memory, but may need to be adjusted to fit your system.
 The rouge scores "rouge1, rouge2, rougeL" are automatically created and saved in ``rouge_score.txt``.
+### Finetuning
+To finetune T5 instead of BART, pass `model_type=t5` and a T5 checkpoint as `model_name_or_path` to `examples/summarization/bart/finetune.py`.
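For illustration, here is a minimal T5 fine-tuning invocation consistent with the diff above. It is a sketch only: the tiny checkpoint name comes from the new `test_t5_run_sum_cli` test, the output directory name is arbitrary, and the generic `--do_train`/`--do_predict`/`--output_dir`/batch-size flags are assumed to follow the argparse names that `main()` and the tests use.

```bash
# Sketch: fine-tune a tiny random T5 checkpoint on the toy CNN/DM data,
# then predict on the dev set. Flag names mirror run_train_tiny.sh and the new T5 test.
export PYTHONPATH="../../":"${PYTHONPATH}"  # lets finetune.py import transformer_base.py

python finetune.py \
--data_dir=cnn_tiny/ \
--model_type=t5 \
--model_name_or_path=patrickvonplaten/t5-tiny-random \
--output_dir=t5_tiny_output \
--train_batch_size=2 \
--eval_batch_size=2 \
--do_train \
--do_predict
```

The same script still handles BART: swapping in `--model_type=bart` and `--model_name_or_path=sshleifer/bart-tiny-random` reproduces the existing run_train_tiny.sh behaviour.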