Unverified Commit c59b1e68 authored by Sam Shleifer's avatar Sam Shleifer Committed by GitHub
Browse files

[examples] unit test for run_bart_sum (#3544)

- adds pytorch-lightning dependency
parent 301bf8d1
...@@ -5,4 +5,5 @@ seqeval ...@@ -5,4 +5,5 @@ seqeval
psutil psutil
sacrebleu sacrebleu
rouge-score rouge-score
tensorflow_datasets tensorflow_datasets
\ No newline at end of file pytorch-lightning==0.7.3 # April 10, 2020 release
...@@ -8,7 +8,12 @@ import torch ...@@ -8,7 +8,12 @@ import torch
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from transformer_base import BaseTransformer, add_generic_args, generic_train, get_linear_schedule_with_warmup from transformer_base import BaseTransformer, add_generic_args, generic_train, get_linear_schedule_with_warmup
from utils import SummarizationDataset
try:
from .utils import SummarizationDataset
except ImportError:
from utils import SummarizationDataset
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -20,6 +25,11 @@ class BartSystem(BaseTransformer): ...@@ -20,6 +25,11 @@ class BartSystem(BaseTransformer):
def __init__(self, hparams): def __init__(self, hparams):
super().__init__(hparams, num_labels=None, mode=self.mode) super().__init__(hparams, num_labels=None, mode=self.mode)
self.dataset_kwargs: dict = dict(
data_dir=self.hparams.data_dir,
max_source_length=self.hparams.max_source_length,
max_target_length=self.hparams.max_target_length,
)
def forward(self, input_ids, attention_mask=None, decoder_input_ids=None, lm_labels=None): def forward(self, input_ids, attention_mask=None, decoder_input_ids=None, lm_labels=None):
return self.model( return self.model(
...@@ -92,14 +102,6 @@ class BartSystem(BaseTransformer): ...@@ -92,14 +102,6 @@ class BartSystem(BaseTransformer):
return self.test_end(outputs) return self.test_end(outputs)
@property
def dataset_kwargs(self):
return dict(
data_dir=self.hparams.data_dir,
max_source_length=self.hparams.max_source_length,
max_target_length=self.hparams.max_target_length,
)
def get_dataloader(self, type_path: str, batch_size: int) -> DataLoader: def get_dataloader(self, type_path: str, batch_size: int) -> DataLoader:
dataset = SummarizationDataset(self.tokenizer, type_path=type_path, **self.dataset_kwargs) dataset = SummarizationDataset(self.tokenizer, type_path=type_path, **self.dataset_kwargs)
dataloader = DataLoader(dataset, batch_size=batch_size, collate_fn=dataset.collate_fn) dataloader = DataLoader(dataset, batch_size=batch_size, collate_fn=dataset.collate_fn)
...@@ -153,17 +155,12 @@ class BartSystem(BaseTransformer): ...@@ -153,17 +155,12 @@ class BartSystem(BaseTransformer):
return parser return parser
if __name__ == "__main__": def main(args):
parser = argparse.ArgumentParser()
add_generic_args(parser, os.getcwd())
parser = BartSystem.add_model_specific_args(parser, os.getcwd())
args = parser.parse_args()
# If output_dir not provided, a folder will be generated in pwd # If output_dir not provided, a folder will be generated in pwd
if not args.output_dir: if not args.output_dir:
args.output_dir = os.path.join("./results", f"{args.task}_{args.model_type}_{time.strftime('%Y%m%d_%H%M%S')}",) args.output_dir = os.path.join("./results", f"{args.task}_{args.model_type}_{time.strftime('%Y%m%d_%H%M%S')}",)
os.makedirs(args.output_dir) os.makedirs(args.output_dir)
model = BartSystem(args) model = BartSystem(args)
trainer = generic_train(model, args) trainer = generic_train(model, args)
...@@ -172,3 +169,12 @@ if __name__ == "__main__": ...@@ -172,3 +169,12 @@ if __name__ == "__main__":
checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "checkpointepoch=*.ckpt"), recursive=True))) checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "checkpointepoch=*.ckpt"), recursive=True)))
BartSystem.load_from_checkpoint(checkpoints[-1]) BartSystem.load_from_checkpoint(checkpoints[-1])
trainer.test(model) trainer.test(model)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
add_generic_args(parser, os.getcwd())
parser = BartSystem.add_model_specific_args(parser, os.getcwd())
args = parser.parse_args()
main(args)
# Install newest ptl.
pip install -U git+http://github.com/PyTorchLightning/pytorch-lightning/
export OUTPUT_DIR_NAME=bart_sum export OUTPUT_DIR_NAME=bart_sum
export CURRENT_DIR=${PWD} export CURRENT_DIR=${PWD}
export OUTPUT_DIR=${CURRENT_DIR}/${OUTPUT_DIR_NAME} export OUTPUT_DIR=${CURRENT_DIR}/${OUTPUT_DIR_NAME}
...@@ -20,4 +16,4 @@ python run_bart_sum.py \ ...@@ -20,4 +16,4 @@ python run_bart_sum.py \
--train_batch_size=4 \ --train_batch_size=4 \
--eval_batch_size=4 \ --eval_batch_size=4 \
--output_dir=$OUTPUT_DIR \ --output_dir=$OUTPUT_DIR \
--do_train --do_train $@
\ No newline at end of file
# Script for verifying that run_bart_sum can be invoked from its directory
# Get tiny dataset with cnn_dm format (4 examples for train, val, test)
wget https://s3.amazonaws.com/datasets.huggingface.co/summarization/cnn_tiny.tgz
tar -xzvf cnn_tiny.tgz
rm cnn_tiny.tgz
export OUTPUT_DIR_NAME=bart_utest_output
export CURRENT_DIR=${PWD}
export OUTPUT_DIR=${CURRENT_DIR}/${OUTPUT_DIR_NAME}
# Make output directory if it doesn't exist
mkdir -p $OUTPUT_DIR
# Add parent directory to python path to access transformer_base.py and utils.py
export PYTHONPATH="../../":"${PYTHONPATH}"
python run_bart_sum.py \
--data_dir=cnn_tiny/ \
--model_type=bart \
--model_name_or_path=sshleifer/bart-tiny-random \
--learning_rate=3e-5 \
--train_batch_size=2 \
--eval_batch_size=2 \
--output_dir=$OUTPUT_DIR \
--num_train_epochs=1 \
--n_gpu=0 \
--do_train $@
rm -rf cnn_tiny
rm -rf $OUTPUT_DIR
import argparse
import logging import logging
import os
import sys import sys
import tempfile import tempfile
import unittest import unittest
...@@ -10,6 +12,7 @@ from torch.utils.data import DataLoader ...@@ -10,6 +12,7 @@ from torch.utils.data import DataLoader
from transformers import BartTokenizer from transformers import BartTokenizer
from .evaluate_cnn import run_generate from .evaluate_cnn import run_generate
from .run_bart_sum import main
from .utils import SummarizationDataset from .utils import SummarizationDataset
...@@ -17,16 +20,61 @@ logging.basicConfig(level=logging.DEBUG) ...@@ -17,16 +20,61 @@ logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger() logger = logging.getLogger()
DEFAULT_ARGS = {
"output_dir": "",
"fp16": False,
"fp16_opt_level": "O1",
"n_gpu": 1,
"n_tpu_cores": 0,
"max_grad_norm": 1.0,
"do_train": True,
"do_predict": False,
"gradient_accumulation_steps": 1,
"server_ip": "",
"server_port": "",
"seed": 42,
"model_type": "bart",
"model_name_or_path": "sshleifer/bart-tiny-random",
"config_name": "",
"tokenizer_name": "",
"cache_dir": "",
"do_lower_case": False,
"learning_rate": 3e-05,
"weight_decay": 0.0,
"adam_epsilon": 1e-08,
"warmup_steps": 0,
"num_train_epochs": 1,
"train_batch_size": 2,
"eval_batch_size": 2,
"max_source_length": 12,
"max_target_length": 12,
}
def _dump_articles(path: Path, articles: list): def _dump_articles(path: Path, articles: list):
with path.open("w") as f: with path.open("w") as f:
f.write("\n".join(articles)) f.write("\n".join(articles))
def make_test_data_dir():
tmp_dir = Path(tempfile.gettempdir())
articles = [" Sam ate lunch today", "Sams lunch ingredients"]
summaries = ["A very interesting story about what I ate for lunch.", "Avocado, celery, turkey, coffee"]
for split in ["train", "val", "test"]:
_dump_articles((tmp_dir / f"{split}.source"), articles)
_dump_articles((tmp_dir / f"{split}.target"), summaries)
return tmp_dir
class TestBartExamples(unittest.TestCase): class TestBartExamples(unittest.TestCase):
def test_bart_cnn_cli(self): @classmethod
def setUpClass(cls):
stream_handler = logging.StreamHandler(sys.stdout) stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler) logger.addHandler(stream_handler)
logging.disable(logging.CRITICAL) # remove noisy download output from tracebacks
return cls
def test_bart_cnn_cli(self):
tmp = Path(tempfile.gettempdir()) / "utest_generations_bart_sum.hypo" tmp = Path(tempfile.gettempdir()) / "utest_generations_bart_sum.hypo"
output_file_name = Path(tempfile.gettempdir()) / "utest_output_bart_sum.hypo" output_file_name = Path(tempfile.gettempdir()) / "utest_output_bart_sum.hypo"
articles = [" New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County."] articles = [" New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County."]
...@@ -34,7 +82,19 @@ class TestBartExamples(unittest.TestCase): ...@@ -34,7 +82,19 @@ class TestBartExamples(unittest.TestCase):
testargs = ["evaluate_cnn.py", str(tmp), str(output_file_name), "sshleifer/bart-tiny-random"] testargs = ["evaluate_cnn.py", str(tmp), str(output_file_name), "sshleifer/bart-tiny-random"]
with patch.object(sys, "argv", testargs): with patch.object(sys, "argv", testargs):
run_generate() run_generate()
self.assertTrue(output_file_name.exists()) self.assertTrue(Path(output_file_name).exists())
os.remove(Path(output_file_name))
def test_bart_run_sum_cli(self):
args_d: dict = DEFAULT_ARGS.copy()
tmp_dir = make_test_data_dir()
output_dir = tempfile.mkdtemp(prefix="output_")
args_d.update(
data_dir=tmp_dir, model_type="bart", train_batch_size=2, eval_batch_size=2, n_gpu=0, output_dir=output_dir,
)
args = argparse.Namespace(**args_d)
main(args)
def test_bart_summarization_dataset(self): def test_bart_summarization_dataset(self):
tmp_dir = Path(tempfile.gettempdir()) tmp_dir = Path(tempfile.gettempdir())
......
...@@ -104,8 +104,8 @@ class BaseTransformer(pl.LightningModule): ...@@ -104,8 +104,8 @@ class BaseTransformer(pl.LightningModule):
self.lr_scheduler.step() self.lr_scheduler.step()
def get_tqdm_dict(self): def get_tqdm_dict(self):
tqdm_dict = {"loss": "{:.3f}".format(self.trainer.avg_loss), "lr": self.lr_scheduler.get_last_lr()[-1]} avg_loss = getattr(self.trainer, "avg_loss", 0.0)
tqdm_dict = {"loss": "{:.3f}".format(avg_loss), "lr": self.lr_scheduler.get_last_lr()[-1]}
return tqdm_dict return tqdm_dict
def test_step(self, batch, batch_nb): def test_step(self, batch, batch_nb):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment