test_bart_examples.py 5.31 KB
Newer Older
1
import argparse
2
import logging
3
import os
4
5
6
7
8
9
import sys
import tempfile
import unittest
from pathlib import Path
from unittest.mock import patch

10
11
12
from torch.utils.data import DataLoader

from transformers import BartTokenizer
13

14
from .evaluate_cnn import run_generate
15
from .finetune import main
16
from .utils import SummarizationDataset
17
18
19
20
21
22


# Configure root logging at DEBUG level for this test module.
logging.basicConfig(level=logging.DEBUG)

# Root logger (no name is passed to getLogger); setUpClass attaches a stdout handler to it.
logger = logging.getLogger()

23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
# Baseline arguments for the tests below: each test copies this mapping, overrides a few
# entries and hands the result to ``main`` wrapped in an ``argparse.Namespace``.
DEFAULT_ARGS = {
    # Output location, precision and device settings.
    "output_dir": "",
    "fp16": False,
    "fp16_opt_level": "O1",
    "n_gpu": 1,
    "n_tpu_cores": 0,
    "max_grad_norm": 1.0,

    # Which phases to run, plus miscellaneous run settings.
    "do_train": True,
    "do_predict": False,
    "gradient_accumulation_steps": 1,
    "server_ip": "",
    "server_port": "",
    "seed": 42,

    # Model / tokenizer selection; the default checkpoint is a tiny random BART.
    "model_type": "bart",
    "model_name_or_path": "sshleifer/bart-tiny-random",
    "config_name": "",
    "tokenizer_name": "",
    "cache_dir": "",
    "do_lower_case": False,

    # Optimisation settings.
    "learning_rate": 3e-5,
    "weight_decay": 0.0,
    "adam_epsilon": 1e-8,
    "warmup_steps": 0,
    "num_train_epochs": 1,

    # Batch sizes and sequence-length limits.
    "train_batch_size": 2,
    "eval_batch_size": 2,
    "max_source_length": 12,
    "max_target_length": 12,
}

53

54
55
56
57
58
def _dump_articles(path: Path, articles: list):
    with path.open("w") as f:
        f.write("\n".join(articles))


59
60
61
62
63
64
65
66
67
68
def make_test_data_dir():
    """Create a fresh directory holding tiny ``train``/``val``/``test`` summarization splits.

    Each split gets a ``<split>.source`` file with two short articles and a
    ``<split>.target`` file with the two matching summaries.

    The files are written into a newly created private temporary directory rather
    than directly into the shared system temp directory, so concurrent test runs
    (or runs by different users) cannot overwrite or trip over each other's
    fixed-name files.

    Returns:
        Path: the directory containing the six data files. The caller owns it;
        it is not removed automatically.
    """
    tmp_dir = Path(tempfile.mkdtemp(prefix="summarization_data_"))
    articles = [" Sam ate lunch today", "Sams lunch ingredients"]
    summaries = ["A very interesting story about what I ate for lunch.", "Avocado, celery, turkey, coffee"]
    for split in ("train", "val", "test"):
        _dump_articles(tmp_dir / f"{split}.source", articles)
        _dump_articles(tmp_dir / f"{split}.target", summaries)
    return tmp_dir


69
class TestBartExamples(unittest.TestCase):
70
71
    @classmethod
    def setUpClass(cls):
        """Attach a stdout handler to the root logger, then disable logging up to CRITICAL.

        Returns the class itself, as the original implementation did.
        """
        logger.addHandler(logging.StreamHandler(sys.stdout))
        # remove noisy download output from tracebacks
        logging.disable(logging.CRITICAL)
        return cls

    def test_bart_cnn_cli(self):
        """Run ``run_generate`` through its command-line interface and check that it writes the output file.

        ``sys.argv`` is patched with: script name, input path, output path and the
        checkpoint name ``sshleifer/bart-tiny-random``.
        """
        articles = [" New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County."]
        # Work inside a private scratch directory: both files are removed even when the
        # test fails, and an output file left behind by an earlier run cannot make the
        # existence check below pass by accident.
        with tempfile.TemporaryDirectory() as scratch_dir:
            input_path = Path(scratch_dir) / "utest_generations_bart_sum.hypo"
            output_path = Path(scratch_dir) / "utest_output_bart_sum.hypo"
            _dump_articles(input_path, articles)
            testargs = ["evaluate_cnn.py", str(input_path), str(output_path), "sshleifer/bart-tiny-random"]
            with patch.object(sys, "argv", testargs):
                run_generate()
                self.assertTrue(output_path.exists())

    def test_bart_run_sum_cli(self):
        """Call ``main`` for a training pass, then a predict-only pass, and check the output directory contents."""
        data_dir = make_test_data_dir()
        out_dir = tempfile.mkdtemp(prefix="output_")
        train_args = {
            **DEFAULT_ARGS,
            "data_dir": data_dir,
            "model_type": "bart",
            "train_batch_size": 2,
            "eval_batch_size": 2,
            "n_gpu": 0,
            "output_dir": out_dir,
        }
        main(argparse.Namespace(**train_args))

        # Second pass: identical settings, but training is switched off and prediction on.
        predict_args = {**train_args, "do_train": False, "do_predict": True}
        main(argparse.Namespace(**predict_args))

        expected_contents = {"checkpointepoch=0.ckpt", "test_results.txt"}
        created_files = {os.path.basename(entry) for entry in os.listdir(out_dir)}
        self.assertSetEqual(expected_contents, created_files)
106

107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
    def test_t5_run_sum_cli(self):
        """Smoke test: call ``main`` once with the tiny random T5 checkpoint, ``do_predict`` enabled.

        There are no assertions; the test passes as long as ``main`` does not raise.
        """
        data_dir = make_test_data_dir()
        out_dir = tempfile.mkdtemp(prefix="output_")
        overrides = {
            "data_dir": data_dir,
            "model_type": "t5",
            "model_name_or_path": "patrickvonplaten/t5-tiny-random",
            "train_batch_size": 2,
            "eval_batch_size": 2,
            "n_gpu": 0,
            "output_dir": out_dir,
            "do_predict": True,
        }
        main(argparse.Namespace(**{**DEFAULT_ARGS, **overrides}))

        # NOTE: a separate predict-only pass (do_train=False, do_predict=True followed by a
        # second call to main) is currently disabled for T5.
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148

    def test_bart_summarization_dataset(self):
        """Check the batch shapes produced by ``SummarizationDataset`` and its ``collate_fn``.

        Source batches are expected to be as wide as the longest encoded article
        (less than ``max_source_length=20``), and target batches to be exactly
        ``max_target_length`` wide, which is shorter than the longest encoded summary.
        """
        data_dir = Path(tempfile.gettempdir())
        sources = [" Sam ate lunch today", "Sams lunch ingredients"]
        targets = ["A very interesting story about what I ate for lunch.", "Avocado, celery, turkey, coffee"]
        _dump_articles(data_dir / "train.source", sources)
        _dump_articles(data_dir / "train.target", targets)

        tokenizer = BartTokenizer.from_pretrained("bart-large")
        longest_source = max(len(tokenizer.encode(text)) for text in sources)
        longest_target = max(len(tokenizer.encode(text)) for text in targets)
        trunc_target = 4
        dataset = SummarizationDataset(
            tokenizer, data_dir=data_dir, type_path="train", max_source_length=20, max_target_length=trunc_target,
        )

        for batch in DataLoader(dataset, batch_size=2, collate_fn=dataset.collate_fn):
            self.assertEqual(batch["source_mask"].shape, batch["source_ids"].shape)
            # show that articles were trimmed.
            source_width = batch["source_ids"].shape[1]
            self.assertEqual(source_width, longest_source)
            self.assertGreater(20, source_width)  # trimmed significantly

            # show that targets were truncated
            self.assertEqual(batch["target_ids"].shape[1], trunc_target)  # Truncated
            self.assertGreater(longest_target, trunc_target)  # Truncated