# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import os
import re
import sys
import unittest
from unittest.mock import patch

from parameterized import parameterized
from transformers.file_utils import is_apex_available
from transformers.integrations import is_fairscale_available
from transformers.testing_utils import (
    CaptureStderr,
    ExtendSysPath,
    TestCasePlus,
    execute_subprocess_async,
    get_gpu_count,
    get_torch_dist_unique_port,
    require_torch,
    require_torch_gpu,
    require_torch_multi_gpu,
    require_torch_non_multi_gpu,
    slow,
)
from transformers.trainer_callback import TrainerState
from transformers.trainer_utils import set_seed


bindir = os.path.abspath(os.path.dirname(__file__))
with ExtendSysPath(f"{bindir}/../../examples/pytorch/translation"):
    from run_translation import main  # noqa


set_seed(42)
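# MBART_TINY drives the quick tests below; the distilled MARIAN_MODEL is only
# exercised by the @slow test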
MARIAN_MODEL = "sshleifer/student_marian_en_ro_6_1"
MBART_TINY = "sshleifer/tiny-mbart"


# a candidate for testing_utils
def require_fairscale(test_case):
    """
    Decorator marking a test that requires fairscale
    """
    if not is_fairscale_available():
        return unittest.skip("test requires fairscale")(test_case)
    else:
        return test_case


# a candidate for testing_utils
def require_apex(test_case):
    """
    Decorator marking a test that requires apex
    """
    if not is_apex_available():
        return unittest.skip("test requires apex")(test_case)
    else:
        return test_case


@require_torch
class TestTrainerExt(TestCasePlus):
    def run_seq2seq_quick(
        self,
        distributed=False,
        extra_args_str=None,
        predict_with_generate=True,
        do_train=True,
        do_eval=True,
        do_predict=True,
    ):
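        # train/eval/predict MBART_TINY for one short epoch, then sanity-check the
        # metrics recorded in trainer_state.json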
        output_dir = self.run_trainer(
            eval_steps=1,
            max_len=12,
            model_name=MBART_TINY,
            num_train_epochs=1,
            distributed=distributed,
            extra_args_str=extra_args_str,
            predict_with_generate=predict_with_generate,
            do_train=do_train,
            do_eval=do_eval,
            do_predict=do_predict,
        )
        logs = TrainerState.load_from_json(os.path.join(output_dir, "trainer_state.json")).log_history

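        # without do_eval there are no eval metrics to check, so stop here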
        if not do_eval:
            return

        eval_metrics = [log for log in logs if "eval_loss" in log.keys()]

        first_step_stats = eval_metrics[0]
        if predict_with_generate:
            assert "eval_bleu" in first_step_stats

            last_step_stats = eval_metrics[-1]
            assert isinstance(last_step_stats["eval_bleu"], float)
            assert not math.isnan(float(last_step_stats["eval_loss"])), "eval_loss must not be `nan`"

    @require_torch_non_multi_gpu
    def test_run_seq2seq_no_dist(self):
        self.run_seq2seq_quick()

    # verify that the trainer can handle non-distributed with n_gpu > 1
    @require_torch_multi_gpu
    def test_run_seq2seq_dp(self):
        self.run_seq2seq_quick(distributed=False)

    # verify that the trainer can handle distributed with n_gpu > 1
    @require_torch_multi_gpu
    def test_run_seq2seq_ddp(self):
        self.run_seq2seq_quick(distributed=True)

    # test --sharded_ddp w/o --fp16
    @require_torch_multi_gpu
    @require_fairscale
    def test_run_seq2seq_sharded_ddp(self):
        self.run_seq2seq_quick(distributed=True, extra_args_str="--sharded_ddp simple")

    # test --sharded_ddp w/ --fp16
    @unittest.skip("Requires an update of the env running those tests")
    @require_torch_multi_gpu
    @require_fairscale
    def test_run_seq2seq_sharded_ddp_fp16(self):
        self.run_seq2seq_quick(distributed=True, extra_args_str="--sharded_ddp simple --fp16")

    # test --sharded_ddp zero_dp_2 w/o --fp16
    @require_torch_multi_gpu
    @require_fairscale
    def test_run_seq2seq_fully_sharded_ddp(self):
        self.run_seq2seq_quick(distributed=True, extra_args_str="--sharded_ddp zero_dp_2", predict_with_generate=False)

    # test --sharded_ddp zero_dp_2 w/ --fp16
    @unittest.skip("Requires an update of the env running those tests")
    @require_torch_multi_gpu
    @require_fairscale
    def test_run_seq2seq_fully_sharded_ddp_fp16(self):
        self.run_seq2seq_quick(
            distributed=True, extra_args_str="--sharded_ddp zero_dp_2 --fp16", predict_with_generate=False
        )

    @require_apex
    @require_torch_gpu
    def test_run_seq2seq_apex(self):
        # XXX: apex breaks the trainer if it's run twice within the same program, e.g. by
        # calling run_seq2seq.main() a second time, and it breaks other tests that run from the
        # same pytest worker. Until this is sorted out, it must be run only in an external
        # program, i.e. distributed=True in this test, and only under one or more gpus - if we
        # want cpu, we will need to make a special test.
        #
        # specifically, the problem was traced to self.optimizer.step() - when it's run a 2nd
        # time via a 2nd main() call, it botches the subsequent eval.
        #
        self.run_seq2seq_quick(distributed=True, extra_args_str="--fp16 --fp16_backend=apex")
        # test a 2nd time - this was producing eval_loss: nan
        # to reproduce the problem, set distributed=False
        self.run_seq2seq_quick(distributed=True, extra_args_str="--fp16 --fp16_backend=apex")

    @parameterized.expand(["base", "low", "high", "mixed"])
    @require_torch_multi_gpu
    def test_trainer_log_level_replica(self, experiment_id):
        # as each sub-test is slow-ish, this test is split into multiple sub-tests to avoid a CI timeout
        experiments = dict(
            # test with the default log_level - should be info and thus log info once
            base=dict(extra_args_str="", n_matches=1),
            # test with low log_level and log_level_replica - should be noisy on all processes
            # now the info string should appear twice on 2 processes
            low=dict(extra_args_str="--log_level debug --log_level_replica debug", n_matches=2),
            # test with high log_level and low log_level_replica
            # now the info string should appear once only on the replica
            high=dict(extra_args_str="--log_level error --log_level_replica debug", n_matches=1),
            # test with high log_level and log_level_replica - should be quiet on all processes
            mixed=dict(extra_args_str="--log_level error --log_level_replica error", n_matches=0),
        )

        data = experiments[experiment_id]
        kwargs = dict(distributed=True, predict_with_generate=False, do_eval=False, do_predict=False)
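        # the Trainer emits this info-level string when training starts; count how many
        # processes let it through at the configured log levels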
        log_info_string = "Running training"
        with CaptureStderr() as cl:
            self.run_seq2seq_quick(**kwargs, extra_args_str=data["extra_args_str"])
        n_matches = len(re.findall(log_info_string, cl.err))
        self.assertEqual(n_matches, data["n_matches"])

    @slow
    def test_run_seq2seq_slow(self):
        output_dir = self.run_trainer(
            eval_steps=2,
            max_len=128,
            model_name=MARIAN_MODEL,
            learning_rate=3e-4,
            num_train_epochs=10,
            distributed=False,
        )

        # Check metrics
        logs = TrainerState.load_from_json(os.path.join(output_dir, "trainer_state.json")).log_history
        eval_metrics = [log for log in logs if "eval_loss" in log.keys()]
        first_step_stats = eval_metrics[0]
        last_step_stats = eval_metrics[-1]

        assert first_step_stats["eval_loss"] > last_step_stats["eval_loss"], "model learned nothing"
        assert isinstance(last_step_stats["eval_bleu"], float)

        # test if do_predict saves generations and metrics
        contents = os.listdir(output_dir)
        contents = {os.path.basename(p) for p in contents}
        assert "generated_predictions.txt" in contents
        assert "predict_results.json" in contents

    def run_trainer(
        self,
        eval_steps: int,
        max_len: int,
        model_name: str,
        num_train_epochs: int,
        learning_rate: float = 3e-3,
        distributed: bool = False,
        extra_args_str: str = None,
        predict_with_generate: bool = True,
        do_train: bool = True,
        do_eval: bool = True,
        do_predict: bool = True,
    ):
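        # build a run_translation.py command line from the given hyper-parameters, run it
        # either in-process (with a patched sys.argv) or via torch.distributed.launch,
        # and return output_dir for callers to inspect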
        data_dir = self.test_file_dir / "../fixtures/tests_samples/wmt_en_ro"
        output_dir = self.get_auto_remove_tmp_dir()
        args_train = f"""
            --model_name_or_path {model_name}
            --train_file {data_dir}/train.json
            --validation_file {data_dir}/val.json
            --test_file {data_dir}/test.json
            --output_dir {output_dir}
            --overwrite_output_dir
            --max_train_samples 8
            --max_source_length {max_len}
            --max_target_length {max_len}
            --do_train
            --num_train_epochs {str(num_train_epochs)}
            --per_device_train_batch_size 4
            --learning_rate {learning_rate}
            --warmup_steps 8
            --logging_steps 0
            --logging_strategy no
            --save_steps {str(eval_steps)}
            --group_by_length
            --label_smoothing_factor 0.1
            --adafactor
            --target_lang ro_RO
            --source_lang en_XX
        """

        args_eval = f"""
            --do_eval
            --per_device_eval_batch_size 4
            --max_eval_samples 8
            --val_max_target_length {max_len}
            --evaluation_strategy steps
            --eval_steps {str(eval_steps)}
        """

        args_predict = """
            --do_predict
        """

        args = ""
        if do_train:
            args += args_train

        if do_eval:
            args += args_eval

        if do_predict:
            args += args_predict

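        # each args_* block above ends with a newline plus indentation, so plain string
        # concatenation is safe - the final .split() below discards all whitespace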
        if predict_with_generate:
            args += "--predict_with_generate"

        args = args.split()

        if extra_args_str is not None:
            args.extend(extra_args_str.split())

        if distributed:
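            # launch in a fresh process group; the per-test unique master port avoids
            # collisions between concurrent pytest workers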
            n_gpu = get_gpu_count()
            master_port = get_torch_dist_unique_port()
            distributed_args = f"""
                -m torch.distributed.launch
                --nproc_per_node={n_gpu}
                --master_port={master_port}
                {self.examples_dir_str}/pytorch/translation/run_translation.py
            """.split()
            cmd = [sys.executable] + distributed_args + args
            execute_subprocess_async(cmd, env=self.get_env())
        else:
            testargs = ["run_translation.py"] + args
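            # run_translation.main() parses sys.argv, so patch it to run in-process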
            with patch.object(sys, "argv", testargs):
                main()

        return output_dir