# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import os
import re
import sys
import unittest
from unittest.mock import patch

from parameterized import parameterized
from transformers.file_utils import is_apex_available
from transformers.integrations import is_fairscale_available
from transformers.testing_utils import (
    CaptureStderr,
    ExtendSysPath,
    TestCasePlus,
    execute_subprocess_async,
    get_gpu_count,
    get_torch_dist_unique_port,
    require_torch,
    require_torch_gpu,
    require_torch_multi_gpu,
    require_torch_non_multi_gpu,
    slow,
)
from transformers.trainer_callback import TrainerState
from transformers.trainer_utils import set_seed


bindir = os.path.abspath(os.path.dirname(__file__))
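# run_translation.py lives in the examples tree rather than in an installed package,
# so temporarily extend sys.path to make it importable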
with ExtendSysPath(f"{bindir}/../../examples/pytorch/translation"):
    from run_translation import main  # noqa


set_seed(42)
MARIAN_MODEL = "sshleifer/student_marian_en_ro_6_1"
MBART_TINY = "sshleifer/tiny-mbart"


# a candidate for testing_utils
def require_fairscale(test_case):
    """
    Decorator marking a test that requires fairscale
    """
    if not is_fairscale_available():
        return unittest.skip("test requires fairscale")(test_case)
    else:
        return test_case


# a candidate for testing_utils
def require_apex(test_case):
    """
    Decorator marking a test that requires apex
    """
    if not is_apex_available():
        return unittest.skip("test requires apex")(test_case)
    else:
        return test_case


@require_torch
class TestTrainerExt(TestCasePlus):
    def run_seq2seq_quick(
        self,
        distributed=False,
        extra_args_str=None,
        predict_with_generate=True,
        do_train=True,
        do_eval=True,
        do_predict=True,
    ):
        output_dir = self.run_trainer(
            eval_steps=1,
            max_len=12,
            model_name=MBART_TINY,
            num_train_epochs=1,
            distributed=distributed,
            extra_args_str=extra_args_str,
            predict_with_generate=predict_with_generate,
            do_train=do_train,
            do_eval=do_eval,
            do_predict=do_predict,
        )
        logs = TrainerState.load_from_json(os.path.join(output_dir, "trainer_state.json")).log_history

        if not do_eval:
            return

        eval_metrics = [log for log in logs if "eval_loss" in log.keys()]

        first_step_stats = eval_metrics[0]
        if predict_with_generate:
            assert "eval_bleu" in first_step_stats

            last_step_stats = eval_metrics[-1]
            assert isinstance(last_step_stats["eval_bleu"], float)
            assert not math.isnan(float(last_step_stats["eval_loss"])), "eval_loss must not be `nan`"

    @require_torch_non_multi_gpu
    def test_run_seq2seq_no_dist(self):
        self.run_seq2seq_quick()

    # verify that the trainer can handle non-distributed with n_gpu > 1
    @require_torch_multi_gpu
    def test_run_seq2seq_dp(self):
        self.run_seq2seq_quick(distributed=False)

    # verify that the trainer can handle distributed with n_gpu > 1
    @require_torch_multi_gpu
    def test_run_seq2seq_ddp(self):
        self.run_seq2seq_quick(distributed=True)

    # test --sharded_ddp w/o --fp16
    @require_torch_multi_gpu
    @require_fairscale
    def test_run_seq2seq_sharded_ddp(self):
        self.run_seq2seq_quick(distributed=True, extra_args_str="--sharded_ddp simple")

    # test --sharded_ddp w/ --fp16
    @require_torch_multi_gpu
    @require_fairscale
    def test_run_seq2seq_sharded_ddp_fp16(self):
        self.run_seq2seq_quick(distributed=True, extra_args_str="--sharded_ddp simple --fp16")

    # test --sharded_ddp zero_dp_2 w/o --fp16
    @require_torch_multi_gpu
    @require_fairscale
    def test_run_seq2seq_fully_sharded_ddp(self):
        self.run_seq2seq_quick(distributed=True, extra_args_str="--sharded_ddp zero_dp_2", predict_with_generate=False)

    # test --sharded_ddp zero_dp_2 w/ --fp16
    @require_torch_multi_gpu
    @require_fairscale
    def test_run_seq2seq_fully_sharded_ddp_fp16(self):
        self.run_seq2seq_quick(
            distributed=True, extra_args_str="--sharded_ddp zero_dp_2 --fp16", predict_with_generate=False
        )

    @require_apex
    @require_torch_gpu
    def test_run_seq2seq_apex(self):
        # XXX: apex breaks the trainer if main() is run twice in the same program, e.g. via two
        # run_translation.main() calls, and it also breaks other tests running in the same pytest
        # worker. Until this is sorted out, this test must run in an external program, hence
        # distributed=True here, and only with one or more gpus - a dedicated test would be
        # needed to cover cpu.
        #
        # specifically, the problem was traced to self.optimizer.step() - when it's run a 2nd
        # time via a 2nd main() call, it botches the subsequent eval.
        #
        self.run_seq2seq_quick(distributed=True, extra_args_str="--fp16 --fp16_backend=apex")
        # test a 2nd time - was getting eval_loss: nan
        # to reproduce the problem set distributed=False
        self.run_seq2seq_quick(distributed=True, extra_args_str="--fp16 --fp16_backend=apex")

    @parameterized.expand(["base", "low", "high", "mixed"])
    @require_torch_multi_gpu
    def test_trainer_log_level_replica(self, experiment_id):
        # as each sub-test is slow-ish split into multiple sub-tests to avoid CI timeout
        experiments = dict(
            # test with the default log_level - should be info and thus log info once
            base=dict(extra_args_str="", n_matches=1),
            # test with low log_level and log_level_replica - should be noisy on all processes
            # now the info string should appear twice on 2 processes
            low=dict(extra_args_str="--log_level debug --log_level_replica debug", n_matches=2),
            # test with high log_level and low log_level_replica
            # now the info string should appear once only on the replica
            high=dict(extra_args_str="--log_level error --log_level_replica debug", n_matches=1),
            # test with high log_level and log_level_replica - should be quiet on all processes
            mixed=dict(extra_args_str="--log_level error --log_level_replica error", n_matches=0),
        )

        data = experiments[experiment_id]
        kwargs = dict(distributed=True, predict_with_generate=False, do_eval=False, do_predict=False)
        log_info_string = "Running training"
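        # "Running training" is logged at info level by the Trainer; count how many of the two
        # processes emitted it - only processes whose effective log level is info or lower do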
        with CaptureStderr() as cl:
            self.run_seq2seq_quick(**kwargs, extra_args_str=data["extra_args_str"])
        n_matches = len(re.findall(log_info_string, cl.err))
        self.assertEqual(n_matches, data["n_matches"])

    @slow
    def test_run_seq2seq_slow(self):
        output_dir = self.run_trainer(
            eval_steps=2,
            max_len=128,
            model_name=MARIAN_MODEL,
            learning_rate=3e-4,
            num_train_epochs=10,
            distributed=False,
        )

        # Check metrics
        logs = TrainerState.load_from_json(os.path.join(output_dir, "trainer_state.json")).log_history
        eval_metrics = [log for log in logs if "eval_loss" in log.keys()]
        first_step_stats = eval_metrics[0]
        last_step_stats = eval_metrics[-1]

        assert first_step_stats["eval_loss"] > last_step_stats["eval_loss"], "model learned nothing"
        assert isinstance(last_step_stats["eval_bleu"], float)

        # test if do_predict saves generations and metrics
        contents = os.listdir(output_dir)
        contents = {os.path.basename(p) for p in contents}
        assert "generated_predictions.txt" in contents
        assert "predict_results.json" in contents

    def run_trainer(
        self,
        eval_steps: int,
        max_len: int,
        model_name: str,
        num_train_epochs: int,
        learning_rate: float = 3e-3,
        distributed: bool = False,
        extra_args_str: str = None,
        predict_with_generate: bool = True,
        do_train: bool = True,
        do_eval: bool = True,
        do_predict: bool = True,
    ):
        data_dir = self.test_file_dir / "../fixtures/tests_samples/wmt_en_ro"
        output_dir = self.get_auto_remove_tmp_dir()
        args_train = f"""
            --model_name_or_path {model_name}
            --train_file {data_dir}/train.json
            --validation_file {data_dir}/val.json
            --test_file {data_dir}/test.json
            --output_dir {output_dir}
            --overwrite_output_dir
            --max_train_samples 8
            --max_source_length {max_len}
            --max_target_length {max_len}
            --do_train
            --num_train_epochs {str(num_train_epochs)}
            --per_device_train_batch_size 4
            --learning_rate {learning_rate}
            --warmup_steps 8
            --logging_steps 0
            --save_steps {str(eval_steps)}
            --group_by_length
            --label_smoothing_factor 0.1
            --adafactor
            --target_lang ro_RO
            --source_lang en_XX
        """

        args_eval = f"""
            --do_eval
            --per_device_eval_batch_size 4
            --max_eval_samples 8
            --val_max_target_length {max_len}
            --evaluation_strategy steps
            --eval_steps {str(eval_steps)}
        """

        args_predict = """
            --do_predict
        """

        args = ""
        if do_train:
            args += args_train

        if do_eval:
            args += args_eval

        if do_predict:
            args += args_predict
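        # the stage blocks are whitespace-separated strings that get split into argv tokens below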

        if predict_with_generate:
            args += "--predict_with_generate"

        args = args.split()

        if extra_args_str is not None:
            args.extend(extra_args_str.split())

        # when distributed, launch the example script in a fresh subprocess via torch.distributed.launch
        if distributed:
            n_gpu = get_gpu_count()
            master_port = get_torch_dist_unique_port()
            distributed_args = f"""
                -m torch.distributed.launch
                --nproc_per_node={n_gpu}
                --master_port={master_port}
                {self.examples_dir_str}/pytorch/translation/run_translation.py
            """.split()
            cmd = [sys.executable] + distributed_args + args
            execute_subprocess_async(cmd, env=self.get_env())
        else:
            # run in-process by calling run_translation.main() with a patched sys.argv
            testargs = ["run_translation.py"] + args
            with patch.object(sys, "argv", testargs):
                main()

        return output_dir