#!/usr/bin/env python

import argparse
import os
import sys
from unittest.mock import patch

import pytorch_lightning as pl
import timeout_decorator
import torch

from distillation import BartSummarizationDistiller, distill_main
from finetune import SummarizationModule, main
from transformers import MarianMTModel
from transformers.file_utils import cached_path
from transformers.testing_utils import TestCasePlus, require_torch_gpu, require_torch_non_multigpu_but_fix_me, slow
from utils import load_json


MARIAN_MODEL = "sshleifer/mar_enro_6_3_student"
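
# These tests exercise the example training bash scripts end to end: each test
# reads a script, keeps only the arguments after its Python entrypoint, swaps
# the $VAR placeholders for small/fast values, and then runs training
# in-process by patching sys.argv.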


class TestMbartCc25Enro(TestCasePlus):
    def setUp(self):
        super().setUp()

        # Download the 41k-pair WMT en-ro subset once (cached_path caches it)
        # and unpack the tarball in place.
        data_cached = cached_path(
            "https://cdn-datasets.huggingface.co/translation/wmt_en_ro-tr40k-va0.5k-te0.5k.tar.gz",
            extract_compressed_file=True,
        )
        self.data_dir = f"{data_cached}/wmt_en_ro-tr40k-va0.5k-te0.5k"

    @slow
    @require_torch_gpu
    @require_torch_non_multigpu_but_fix_me
    def test_model_download(self):
        """This warms up the cache so that we can time the next test without including download time, which varies between machines."""
        MarianMTModel.from_pretrained(MARIAN_MODEL)

    # @timeout_decorator.timeout(1200)
    @slow
    @require_torch_gpu
    @require_torch_non_multigpu_but_fix_me
    def test_train_mbart_cc25_enro_script(self):
        env_vars_to_replace = {
            "$MAX_LEN": 64,
            "$BS": 64,
            "$GAS": 1,
            "$ENRO_DIR": self.data_dir,
            "facebook/mbart-large-cc25": MARIAN_MODEL,
            # "val_check_interval=0.25": "val_check_interval=1.0",
            "--learning_rate=3e-5": "--learning_rate 3e-4",
            "--num_train_epochs 6": "--num_train_epochs 1",
        }

        # Clean up bash script: keep only the arguments after "finetune.py",
        # undo line continuations, and drop the '"$@"' passthrough before
        # substituting the $VAR placeholders.
        bash_script = (self.test_file_dir / "train_mbart_cc25_enro.sh").open().read().split("finetune.py")[1].strip()
        bash_script = bash_script.replace("\\\n", "").strip().replace('"$@"', "")
        for k, v in env_vars_to_replace.items():
            bash_script = bash_script.replace(k, str(v))
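        # temp output dir; TestCasePlus removes it automatically after the test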
        output_dir = self.get_auto_remove_tmp_dir()

        # bash_script = bash_script.replace("--fp16 ", "")
        args = f"""
            --output_dir {output_dir}
            --tokenizer_name Helsinki-NLP/opus-mt-en-ro
            --sortish_sampler
            --do_predict
            --gpus 1
            --freeze_encoder
            --n_train 40000
            --n_val 500
            --n_test 500
            --fp16_opt_level O1
            --num_sanity_val_steps 0
            --eval_beams 2
        """.split()
        # XXX: handle args.gpus > 1 (multi-GPU) in the future

        testargs = ["finetune.py"] + bash_script.split() + args
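        # Run finetune.main in-process, as if the reconstructed command line
        # had been invoked from the shell.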
        with patch.object(sys, "argv", testargs):
            parser = argparse.ArgumentParser()
            parser = pl.Trainer.add_argparse_args(parser)
            parser = SummarizationModule.add_model_specific_args(parser, os.getcwd())
            args = parser.parse_args()
            model = main(args)

        # Check metrics
        metrics = load_json(model.metrics_save_path)
        first_step_stats = metrics["val"][0]
        last_step_stats = metrics["val"][-1]
        self.assertEqual(len(metrics["val"]), (args.max_epochs / args.val_check_interval))
        assert isinstance(last_step_stats[f"val_avg_{model.val_metric}"], float)

        self.assertGreater(last_step_stats["val_avg_gen_time"], 0.01)
        # an overly long generate time suggests the model is hanging or a bad config was saved
        self.assertLessEqual(last_step_stats["val_avg_gen_time"], 1.0)

        # test learning requirements:

        # 1. BLEU improves over the course of training by more than 2 pts
        self.assertGreater(last_step_stats["val_avg_bleu"] - first_step_stats["val_avg_bleu"], 2)

        # 2. BLEU finishes above 17
        self.assertGreater(last_step_stats["val_avg_bleu"], 17)

        # 3. test BLEU and val BLEU within ~1.1 pt.
        self.assertLess(abs(metrics["val"][-1]["val_avg_bleu"] - metrics["test"][-1]["test_avg_bleu"]), 1.1)

        # check lightning ckpt can be loaded and has a reasonable statedict
        contents = os.listdir(output_dir)
        ckpt_path = [x for x in contents if x.endswith(".ckpt")][0]
        full_path = os.path.join(args.output_dir, ckpt_path)
        ckpt = torch.load(full_path, map_location="cpu")
        expected_key = "model.model.decoder.layers.0.encoder_attn_layer_norm.weight"
        assert expected_key in ckpt["state_dict"]
        assert ckpt["state_dict"][expected_key].dtype == torch.float32

        # TODO: turn on args.do_predict when PL bug fixed.
        if args.do_predict:
            contents = {os.path.basename(p) for p in contents}
            assert "test_generations.txt" in contents
            assert "test_results.txt" in contents
            # assert len(metrics["val"]) ==  desired_n_evals
            assert len(metrics["test"]) == 1
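

# A minimal sketch (hypothetical helper, not wired into the tests above) of the
# bash-script munging steps both tests share: keep only the arguments after the
# Python entrypoint, undo line continuations, drop the '"$@"' passthrough, and
# substitute the $VAR placeholders with concrete values. Example (hypothetical):
# _bash_script_to_argv(script_text, "finetune.py", {"$BS": 64, "$GAS": 1}).
def _bash_script_to_argv(script_text, entrypoint, replacements):
    argv = script_text.split(entrypoint)[1].strip()
    argv = argv.replace("\\\n", "").strip().replace('"$@"', "")
    for key, value in replacements.items():
        argv = argv.replace(key, str(value))
    return argv.split()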


class TestDistilMarianNoTeacher(TestCasePlus):
    @timeout_decorator.timeout(600)
    @slow
    @require_torch_gpu
    @require_torch_non_multigpu_but_fix_me
    def test_opus_mt_distill_script(self):
        data_dir = f"{self.test_file_dir_str}/test_data/wmt_en_ro"
        env_vars_to_replace = {
            "--fp16_opt_level=O1": "",
            "$MAX_LEN": 128,
            "$BS": 16,
            "$GAS": 1,
            "$ENRO_DIR": data_dir,
            "$m": "sshleifer/student_marian_en_ro_6_1",
            "val_check_interval=0.25": "val_check_interval=1.0",
        }

        # Clean up bash script
        bash_script = (
            (self.test_file_dir / "distil_marian_no_teacher.sh").open().read().split("distillation.py")[1].strip()
        )
        bash_script = bash_script.replace("\\\n", "").strip().replace('"$@"', "")
        bash_script = bash_script.replace("--fp16 ", " ")

        for k, v in env_vars_to_replace.items():
            bash_script = bash_script.replace(k, str(v))
        output_dir = self.get_auto_remove_tmp_dir()
        bash_script = bash_script.replace("--fp16", "")
        epochs = 6
        testargs = (
            ["distillation.py"]
            + bash_script.split()
            + [
                f"--output_dir={output_dir}",
                "--gpus=1",
                "--learning_rate=1e-3",
                f"--num_train_epochs={epochs}",
                "--warmup_steps=10",
                "--val_check_interval=1.0",
                "--do_predict",
            ]
        )
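        # Run the distillation entrypoint in-process with the reconstructed argv.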
        with patch.object(sys, "argv", testargs):
            parser = argparse.ArgumentParser()
            parser = pl.Trainer.add_argparse_args(parser)
            parser = BartSummarizationDistiller.add_model_specific_args(parser, os.getcwd())
            args = parser.parse_args()
            # NB: asserting args.gpus == gpus here breaks for multi-GPU runs

            model = distill_main(args)

        # Check metrics
        metrics = load_json(model.metrics_save_path)
        first_step_stats = metrics["val"][0]
        last_step_stats = metrics["val"][-1]
        assert len(metrics["val"]) >= (args.max_epochs / args.val_check_interval)  # >=, not ==: val_sanity_check can add an extra eval entry

        assert last_step_stats["val_avg_gen_time"] >= 0.01

        assert first_step_stats["val_avg_bleu"] < last_step_stats["val_avg_bleu"]  # fails if the model learned nothing
        assert last_step_stats["val_avg_gen_time"] <= 1.0  # model hanging on generate; maybe a bad config was saved
        assert isinstance(last_step_stats[f"val_avg_{model.val_metric}"], float)

        # check lightning ckpt can be loaded and has a reasonable statedict
        contents = os.listdir(output_dir)
        ckpt_path = [x for x in contents if x.endswith(".ckpt")][0]
        full_path = os.path.join(args.output_dir, ckpt_path)
        ckpt = torch.load(full_path, map_location="cpu")
        expected_key = "model.model.decoder.layers.0.encoder_attn_layer_norm.weight"
        assert expected_key in ckpt["state_dict"]
        assert ckpt["state_dict"][expected_key].dtype == torch.float32

        # TODO: turn on args.do_predict when PL bug fixed.
        if args.do_predict:
            contents = {os.path.basename(p) for p in contents}
            assert "test_generations.txt" in contents
            assert "test_results.txt" in contents
            # assert len(metrics["val"]) ==  desired_n_evals
            assert len(metrics["test"]) == 1