Unverified commit f0c96faf, authored by Sam Shleifer, committed by GitHub

[examples] summarization/bart/finetune.py supports t5 (#3824)

renames `run_bart_sum.py` to `finetune.py`
parent 0cec4fab
@@ -19,7 +19,7 @@ except ImportError:
logger = logging.getLogger(__name__)
-class BartSystem(BaseTransformer):
+class SummarizationTrainer(BaseTransformer):

    mode = "language-modeling"
@@ -64,18 +64,18 @@ class BartSystem(BaseTransformer):
return {"avg_val_loss": avg_loss, "log": tensorboard_logs}
def test_step(self, batch, batch_idx):
# NOTE: this generation will not use the cache.
pad_token_id = self.tokenizer.pad_token_id
source_ids, source_mask, y = SummarizationDataset.trim_seq2seq_batch(batch, pad_token_id)
# NOTE: these kwargs get more speed and lower quality summaries than those in evaluate_cnn.py.
# NOTE: the following kwargs get more speed and lower quality summaries than those in evaluate_cnn.py
generated_ids = self.model.generate(
source_ids,
source_mask,
input_ids=source_ids,
attention_mask=source_mask,
num_beams=1,
max_length=80,
repetition_penalty=2.5,
length_penalty=1.0,
early_stopping=True,
use_cache=True,
)
preds = [
self.tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True)
@@ -161,20 +161,20 @@ def main(args):
    if not args.output_dir:
        args.output_dir = os.path.join("./results", f"{args.task}_{args.model_type}_{time.strftime('%Y%m%d_%H%M%S')}",)
        os.makedirs(args.output_dir)
-    model = BartSystem(args)
+    model = SummarizationTrainer(args)
    trainer = generic_train(model, args)

    # Optionally, predict on dev set and write to output_dir
    if args.do_predict:
        checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "checkpointepoch=*.ckpt"), recursive=True)))
-        BartSystem.load_from_checkpoint(checkpoints[-1])
+        SummarizationTrainer.load_from_checkpoint(checkpoints[-1])
        trainer.test(model)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    add_generic_args(parser, os.getcwd())
-    parser = BartSystem.add_model_specific_args(parser, os.getcwd())
+    parser = SummarizationTrainer.add_model_specific_args(parser, os.getcwd())
    args = parser.parse_args()
    main(args)
@@ -8,7 +8,7 @@ mkdir -p $OUTPUT_DIR
# Add parent directory to python path to access transformer_base.py
export PYTHONPATH="../../":"${PYTHONPATH}"
-python run_bart_sum.py \
+python finetune.py \
--data_dir=./cnn-dailymail/cnn_dm \
--model_type=bart \
--model_name_or_path=bart-large \
......
@@ -14,7 +14,7 @@ mkdir -p $OUTPUT_DIR
# Add parent directory to python path to access transformer_base.py and utils.py
export PYTHONPATH="../../":"${PYTHONPATH}"
-python run_bart_sum.py \
+python finetune.py \
--data_dir=cnn_tiny/ \
--model_type=bart \
--model_name_or_path=sshleifer/bart-tiny-random \
......
@@ -12,7 +12,7 @@ from torch.utils.data import DataLoader
from transformers import BartTokenizer
from .evaluate_cnn import run_generate
-from .run_bart_sum import main
+from .finetune import main
from .utils import SummarizationDataset
@@ -92,9 +92,27 @@ class TestBartExamples(unittest.TestCase):
        args_d.update(
            data_dir=tmp_dir, model_type="bart", train_batch_size=2, eval_batch_size=2, n_gpu=0, output_dir=output_dir,
        )
-        main(argparse.Namespace(**args_d))
-        args_d.update({"do_train": False, "do_predict": True})
-        main(argparse.Namespace(**args_d))
+        args = argparse.Namespace(**args_d)
+        main(args)
+
+    def test_t5_run_sum_cli(self):
+        args_d: dict = DEFAULT_ARGS.copy()
+        tmp_dir = make_test_data_dir()
+        output_dir = tempfile.mkdtemp(prefix="output_")
+        args_d.update(
+            data_dir=tmp_dir,
+            model_type="t5",
+            model_name_or_path="patrickvonplaten/t5-tiny-random",
+            train_batch_size=2,
+            eval_batch_size=2,
+            n_gpu=0,
+            output_dir=output_dir,
+            do_predict=True,
+        )
+        main(argparse.Namespace(**args_d))
+        # args_d.update({"do_train": False, "do_predict": True})
+        # main(argparse.Namespace(**args_d))

    def test_bart_summarization_dataset(self):
        tmp_dir = Path(tempfile.gettempdir())
......
@@ -15,7 +15,7 @@ wc -l cnn_articles_input_data.txt # should print 11490
wc -l cnn_articles_reference_summaries.txt # should print 11490
```
-### Usage
+### Generating Summaries
To create summaries for each article in the dataset, run:
```bash
@@ -23,3 +23,7 @@ python evaluate_cnn.py cnn_articles_input_data.txt cnn_generated_articles_summar
```
The default batch size, 8, fits in 16GB of GPU memory, but may need to be adjusted for your system.
The ROUGE scores (rouge1, rouge2, rougeL) are computed automatically and saved in `rouge_score.txt`.
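
For reference, a minimal sketch of this step, mirroring the truncated command visible in the hunk header above. The output filename here is an illustrative placeholder, and the full invocation in the README may take further arguments:

```bash
# Write one generated summary per input article; per the README,
# ROUGE scores are saved automatically alongside the output.
python evaluate_cnn.py cnn_articles_input_data.txt my_generated_summaries.txt
# Inspect the scores the script writes:
cat rouge_score.txt
```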
+### Finetuning
+To finetune T5 instead of BART, pass `--model_type=t5` and a T5 checkpoint via `--model_name_or_path` to `examples/summarization/bart/finetune.py`, as sketched below.
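
A minimal sketch of such a run, assuming the generic `--output_dir`/`--do_train`/`--do_predict` flags registered in `transformer_base.py` (not shown in this diff) and borrowing the tiny checkpoint from `test_t5_run_sum_cli` above; substitute a full-size checkpoint such as `t5-base` for real finetuning:

```bash
# Hedged example: same flag style as run_train.sh earlier in this diff;
# the tiny T5 checkpoint is the one the new test exercises.
python finetune.py \
--data_dir=./cnn-dailymail/cnn_dm \
--model_type=t5 \
--model_name_or_path=patrickvonplaten/t5-tiny-random \
--output_dir=$OUTPUT_DIR \
--do_train \
--do_predict
```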