# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import os
import re
import sys
import unittest
from unittest.mock import patch

from transformers.file_utils import is_apex_available
from transformers.integrations import is_fairscale_available
from transformers.testing_utils import (
    CaptureStderr,
    ExtendSysPath,
    TestCasePlus,
    execute_subprocess_async,
    get_gpu_count,
    get_torch_dist_unique_port,
    require_torch,
    require_torch_gpu,
    require_torch_multi_gpu,
    require_torch_non_multi_gpu,
    slow,
)
from transformers.trainer_callback import TrainerState
from transformers.trainer_utils import set_seed


bindir = os.path.abspath(os.path.dirname(__file__))
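# the translation example script is not installed as a package, so temporarily
# extend sys.path in order to import its main()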
with ExtendSysPath(f"{bindir}/../../examples/pytorch/translation"):
    from run_translation import main  # noqa


set_seed(42)
MARIAN_MODEL = "sshleifer/student_marian_en_ro_6_1"
MBART_TINY = "sshleifer/tiny-mbart"


# a candidate for testing_utils
def require_fairscale(test_case):
    """
    Decorator marking a test that requires fairscale
    """
    if not is_fairscale_available():
        return unittest.skip("test requires fairscale")(test_case)
    else:
        return test_case


# a candidate for testing_utils
def require_apex(test_case):
    """
    Decorator marking a test that requires apex
    """
    if not is_apex_available():
        return unittest.skip("test requires apex")(test_case)
    else:
        return test_case
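# both decorators above are applied to test methods just like the stock
# testing_utils ones, e.g.:
#
#     @require_fairscale
#     def test_run_seq2seq_sharded_ddp(self):
#         ...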


@require_torch
class TestTrainerExt(TestCasePlus):
    def run_seq2seq_quick(
        self,
        distributed=False,
        extra_args_str=None,
        predict_with_generate=True,
        do_train=True,
        do_eval=True,
        do_predict=True,
    ):
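        # quick end-to-end run of run_translation.py with a tiny model and a single
        # epoch; when evaluating, checks that the reported metrics are sane (non-nan)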
        output_dir = self.run_trainer(
            eval_steps=1,
            max_len=12,
            model_name=MBART_TINY,
            num_train_epochs=1,
            distributed=distributed,
            extra_args_str=extra_args_str,
            predict_with_generate=predict_with_generate,
            do_train=do_train,
            do_eval=do_eval,
            do_predict=do_predict,
        )
        logs = TrainerState.load_from_json(os.path.join(output_dir, "trainer_state.json")).log_history

        if not do_eval:
            return

        eval_metrics = [log for log in logs if "eval_loss" in log.keys()]

        first_step_stats = eval_metrics[0]
        if predict_with_generate:
            assert "eval_bleu" in first_step_stats

            last_step_stats = eval_metrics[-1]
            assert isinstance(last_step_stats["eval_bleu"], float)
            assert not math.isnan(float(last_step_stats["eval_loss"])), "eval_loss must not be `nan`"

    @require_torch_non_multi_gpu
    def test_run_seq2seq_no_dist(self):
        self.run_seq2seq_quick()

    # verify that the trainer can handle non-distributed with n_gpu > 1
    @require_torch_multi_gpu
    def test_run_seq2seq_dp(self):
        self.run_seq2seq_quick(distributed=False)

    # verify that the trainer can handle distributed with n_gpu > 1
    @require_torch_multi_gpu
    def test_run_seq2seq_ddp(self):
        self.run_seq2seq_quick(distributed=True)

    # test --sharded_ddp w/o --fp16
    @require_torch_multi_gpu
    @require_fairscale
    def test_run_seq2seq_sharded_ddp(self):
        self.run_seq2seq_quick(distributed=True, extra_args_str="--sharded_ddp simple")

    # test --sharded_ddp w/ --fp16
    @require_torch_multi_gpu
    @require_fairscale
    def test_run_seq2seq_sharded_ddp_fp16(self):
        self.run_seq2seq_quick(distributed=True, extra_args_str="--sharded_ddp simple --fp16")

    # test --sharded_ddp zero_dp_2 w/o --fp16
    @require_torch_multi_gpu
    @require_fairscale
    def test_run_seq2seq_fully_sharded_ddp(self):
        self.run_seq2seq_quick(distributed=True, extra_args_str="--sharded_ddp zero_dp_2", predict_with_generate=False)

    # test --sharded_ddp zero_dp_2 w/ --fp16
    @require_torch_multi_gpu
    @require_fairscale
    def test_run_seq2seq_fully_sharded_ddp_fp16(self):
        self.run_seq2seq_quick(
            distributed=True, extra_args_str="--sharded_ddp zero_dp_2 --fp16", predict_with_generate=False
        )

    @require_apex
    @require_torch_gpu
    def test_run_seq2seq_apex(self):
        # XXX: apex breaks the trainer if main() is run twice in the same program, e.g. a 2nd
        # run_seq2seq.main() call breaks other tests that run from the same pytest worker.
        # Until this is sorted out it must be run only in an external program, i.e.
        # distributed=True in this test, and only under one or more gpus - if we want cpu we
        # will need to make a special test
        #
        # specifically, the problem was traced to self.optimizer.step() - if it's run a 2nd
        # time via a 2nd main() call it botches the future eval
        #
        self.run_seq2seq_quick(distributed=True, extra_args_str="--fp16 --fp16_backend=apex")
        # test a 2nd time - was getting `eval_loss: nan` on the repeat run;
        # to reproduce the problem set distributed=False
        self.run_seq2seq_quick(distributed=True, extra_args_str="--fp16 --fp16_backend=apex")

    @require_torch_multi_gpu
    def test_trainer_log_level_replica(self):
        log_info_string = "Running training"
        kwargs = dict(distributed=True, predict_with_generate=False, do_eval=False, do_predict=False)

        # test with the default log_level - should be info and thus log info once
        with CaptureStderr() as cl:
            self.run_seq2seq_quick(
                **kwargs,
                extra_args_str="",
            )
        n_matches = len(re.findall(log_info_string, cl.err))
        self.assertEqual(n_matches, 1)

        # test with low log_level and log_level_replica - should be noisy on all processes
        # now the info string should appear twice on 2 processes
        with CaptureStderr() as cl:
            self.run_seq2seq_quick(
                **kwargs,
                extra_args_str="--log_level debug --log_level_replica debug",
            )
        n_matches = len(re.findall(log_info_string, cl.err))
        self.assertEqual(n_matches, 2)

        # test with high log_level and low log_level_replica
        # now the info string should appear once only on the replica
        with CaptureStderr() as cl:
            self.run_seq2seq_quick(
                **kwargs,
                extra_args_str="--log_level error --log_level_replica debug",
            )
        n_matches = len(re.findall(log_info_string, cl.err))
        self.assertEqual(n_matches, 1)

        # test with high log_level and log_level_replica - should be quiet on all processes
        with CaptureStderr() as cl:
            self.run_seq2seq_quick(
                **kwargs,
                extra_args_str="--log_level error --log_level_replica error",
            )
        n_matches = len(re.findall(log_info_string, cl.err))
        self.assertEqual(n_matches, 0)

    @slow
    def test_run_seq2seq_slow(self):
        output_dir = self.run_trainer(
            eval_steps=2,
            max_len=128,
            model_name=MARIAN_MODEL,
            learning_rate=3e-4,
            num_train_epochs=10,
            distributed=False,
        )

        # Check metrics
        logs = TrainerState.load_from_json(os.path.join(output_dir, "trainer_state.json")).log_history
        eval_metrics = [log for log in logs if "eval_loss" in log.keys()]
        first_step_stats = eval_metrics[0]
        last_step_stats = eval_metrics[-1]

        assert first_step_stats["eval_loss"] > last_step_stats["eval_loss"], "model learned nothing"
        assert isinstance(last_step_stats["eval_bleu"], float)

        # test if do_predict saves generations and metrics
        contents = os.listdir(output_dir)
        contents = {os.path.basename(p) for p in contents}
        assert "generated_predictions.txt" in contents
        assert "predict_results.json" in contents

    def run_trainer(
        self,
        eval_steps: int,
        max_len: int,
        model_name: str,
        num_train_epochs: int,
        learning_rate: float = 3e-3,
        distributed: bool = False,
        extra_args_str: str = None,
        predict_with_generate: bool = True,
        do_train: bool = True,
        do_eval: bool = True,
        do_predict: bool = True,
    ):
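        # assembles the run_translation.py CLI arguments from the given settings,
        # runs the script in-process or as a distributed subprocess, and returns
        # the output dir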
        data_dir = self.test_file_dir / "../fixtures/tests_samples/wmt_en_ro"
        output_dir = self.get_auto_remove_tmp_dir()
        args_train = f"""
            --model_name_or_path {model_name}
            --train_file {data_dir}/train.json
            --validation_file {data_dir}/val.json
            --test_file {data_dir}/test.json
            --output_dir {output_dir}
            --overwrite_output_dir
            --max_train_samples 8
            --max_source_length {max_len}
            --max_target_length {max_len}
            --do_train
            --num_train_epochs {str(num_train_epochs)}
            --per_device_train_batch_size 4
            --learning_rate {learning_rate}
            --warmup_steps 8
            --logging_steps 0
            --save_steps {str(eval_steps)}
            --group_by_length
            --label_smoothing_factor 0.1
            --adafactor
            --target_lang ro_RO
            --source_lang en_XX
        """

        args_eval = f"""
            --do_eval
            --per_device_eval_batch_size 4
            --max_eval_samples 8
            --val_max_target_length {max_len}
            --evaluation_strategy steps
            --eval_steps {str(eval_steps)}
        """

        args_predict = """
            --do_predict
        """

        args = ""
        if do_train:
            args += args_train

        if do_eval:
            args += args_eval

        if do_predict:
            args += args_predict

        if predict_with_generate:
            args += "--predict_with_generate"

        args = args.split()

        if extra_args_str is not None:
            args.extend(extra_args_str.split())

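        # in the distributed case the example script is launched as an external
        # subprocess via torch.distributed.launch; otherwise main() runs in-process
        # with a patched sys.argv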
        if distributed:
            n_gpu = get_gpu_count()
            master_port = get_torch_dist_unique_port()
            distributed_args = f"""
                -m torch.distributed.launch
                --nproc_per_node={n_gpu}
                --master_port={master_port}
                {self.examples_dir_str}/pytorch/translation/run_translation.py
            """.split()
            cmd = [sys.executable] + distributed_args + args
            execute_subprocess_async(cmd, env=self.get_env())
        else:
            testargs = ["run_translation.py"] + args
            with patch.object(sys, "argv", testargs):
                main()

        return output_dir