# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import os
import re
import sys
import unittest
from unittest.mock import patch

from transformers.file_utils import is_apex_available
from transformers.integrations import is_fairscale_available
from transformers.testing_utils import (
    CaptureStderr,
    ExtendSysPath,
    TestCasePlus,
    execute_subprocess_async,
    get_gpu_count,
    get_torch_dist_unique_port,
    require_torch_gpu,
    require_torch_multi_gpu,
    require_torch_non_multi_gpu,
    slow,
)
from transformers.trainer_callback import TrainerState
from transformers.trainer_utils import set_seed


bindir = os.path.abspath(os.path.dirname(__file__))
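# the example scripts are not installed as modules, so temporarily extend sys.path
# to import run_translation's main() and call it in-process below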
with ExtendSysPath(f"{bindir}/../../examples/pytorch/translation"):
    from run_translation import main  # noqa


set_seed(42)
MARIAN_MODEL = "sshleifer/student_marian_en_ro_6_1"
MBART_TINY = "sshleifer/tiny-mbart"
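# a small en->ro student model for the quality-sensitive @slow test, and a tiny mbart
# checkpoint for the quick smoke tests (the @slow test runs only with RUN_SLOW=1)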


# a candidate for testing_utils
def require_fairscale(test_case):
    """
    Decorator marking a test that requires fairscale
    """
    if not is_fairscale_available():
        return unittest.skip("test requires fairscale")(test_case)
    else:
        return test_case


# a candidate for testing_utils
def require_apex(test_case):
    """
    Decorator marking a test that requires apex
    """
    if not is_apex_available():
        return unittest.skip("test requires apex")(test_case)
    else:
        return test_case


class TestTrainerExt(TestCasePlus):
    def run_seq2seq_quick(
        self,
        distributed=False,
        extra_args_str=None,
        predict_with_generate=True,
        do_train=True,
        do_eval=True,
        do_predict=True,
    ):
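        """
        Run a quick 1-epoch train/eval/predict with MBART_TINY via run_trainer, then
        sanity-check the eval metrics recorded in trainer_state.json.
        """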
        output_dir = self.run_trainer(
            eval_steps=1,
            max_len=12,
            model_name=MBART_TINY,
            num_train_epochs=1,
            distributed=distributed,
            extra_args_str=extra_args_str,
            predict_with_generate=predict_with_generate,
            do_train=do_train,
            do_eval=do_eval,
            do_predict=do_predict,
        )
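        # the example script saves the trainer state at the end of the run; log_history
        # holds every metrics dict that was logged during training/eval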
        logs = TrainerState.load_from_json(os.path.join(output_dir, "trainer_state.json")).log_history

        if not do_eval:
            return

        eval_metrics = [log for log in logs if "eval_loss" in log.keys()]

        first_step_stats = eval_metrics[0]
        if predict_with_generate:
            assert "eval_bleu" in first_step_stats

            last_step_stats = eval_metrics[-1]
            assert isinstance(last_step_stats["eval_bleu"], float)
            assert not math.isnan(float(last_step_stats["eval_loss"])), "eval_loss must not be `nan`"

    @require_torch_non_multi_gpu
    def test_run_seq2seq_no_dist(self):
        self.run_seq2seq_quick()

    # verify that the trainer can handle non-distributed with n_gpu > 1
    @require_torch_multi_gpu
    def test_run_seq2seq_dp(self):
        self.run_seq2seq_quick(distributed=False)

    # verify that the trainer can handle distributed with n_gpu > 1
    @require_torch_multi_gpu
    def test_run_seq2seq_ddp(self):
        self.run_seq2seq_quick(distributed=True)

    # test --sharded_ddp w/o --fp16
    @require_torch_multi_gpu
    @require_fairscale
    def test_run_seq2seq_sharded_ddp(self):
        self.run_seq2seq_quick(distributed=True, extra_args_str="--sharded_ddp simple")

    # test --sharded_ddp w/ --fp16
    @require_torch_multi_gpu
    @require_fairscale
    def test_run_seq2seq_sharded_ddp_fp16(self):
        self.run_seq2seq_quick(distributed=True, extra_args_str="--sharded_ddp simple --fp16")

    # test --sharded_ddp zero_dp_2 w/o --fp16
    @require_torch_multi_gpu
    @require_fairscale
    def test_run_seq2seq_fully_sharded_ddp(self):
        self.run_seq2seq_quick(distributed=True, extra_args_str="--sharded_ddp zero_dp_2", predict_with_generate=False)

    # test --sharded_ddp zero_dp_2 w/ --fp16
    @require_torch_multi_gpu
    @require_fairscale
    def test_run_seq2seq_fully_sharded_ddp_fp16(self):
        self.run_seq2seq_quick(
            distributed=True, extra_args_str="--sharded_ddp zero_dp_2 --fp16", predict_with_generate=False
        )

    @require_apex
    @require_torch_gpu
    def test_run_seq2seq_apex(self):
        # XXX: apex breaks the trainer if it's run twice, e.g. by calling run_translation's
        # main() twice from the same program, and it breaks other tests that run from the same
        # pytest worker. Until this is sorted out, this test must be run only in an external
        # program, i.e. with distributed=True, and only under one or more gpus - if we want a
        # cpu version we will need to make a special test.
        #
        # specifically, the problem was traced to self.optimizer.step() - if it's run a 2nd
        # time via a 2nd main() call, it botches the future eval.
        #
        self.run_seq2seq_quick(distributed=True, extra_args_str="--fp16 --fp16_backend=apex")
        # test a 2nd time - this was producing `eval_loss: nan`
        # to reproduce the problem set distributed=False
        self.run_seq2seq_quick(distributed=True, extra_args_str="--fp16 --fp16_backend=apex")

    @require_torch_multi_gpu
    def test_trainer_log_level_replica(self):
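        """
        --log_level sets the logging verbosity of the main process and --log_level_replica
        that of the remaining ranks, so counting how many times an info-level message shows
        up in the captured stderr of a 2-process run checks each combination.
        """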
        log_info_string = "Running training"
        kwargs = dict(distributed=True, predict_with_generate=False, do_eval=False, do_predict=False)

        # test with the default log_level - should be info and thus log info once
        with CaptureStderr() as cl:
            self.run_seq2seq_quick(
                **kwargs,
                extra_args_str="",
            )
        n_matches = len(re.findall(log_info_string, cl.err))
        self.assertEqual(n_matches, 1)

        # test with low log_level and log_level_replica - should be noisy on all processes
        # now the info string should appear twice on 2 processes
        with CaptureStderr() as cl:
            self.run_seq2seq_quick(
                **kwargs,
                extra_args_str="--log_level debug --log_level_replica debug",
            )
        n_matches = len(re.findall(log_info_string, cl.err))
        self.assertEqual(n_matches, 2)

        # test with high log_level and low log_level_replica
        # now the info string should appear once only on the replica
        with CaptureStderr() as cl:
            self.run_seq2seq_quick(
                **kwargs,
                extra_args_str="--log_level error --log_level_replica debug",
            )
        n_matches = len(re.findall(log_info_string, cl.err))
        self.assertEqual(n_matches, 1)

        # test with high log_level and log_level_replica - should be quiet on all processes
        with CaptureStderr() as cl:
            self.run_seq2seq_quick(
                **kwargs,
                extra_args_str="--log_level error --log_level_replica error",
            )
        n_matches = len(re.findall(log_info_string, cl.err))
        self.assertEqual(n_matches, 0)

    @slow
    def test_run_seq2seq_slow(self):
        output_dir = self.run_trainer(
            eval_steps=2,
            max_len=128,
            model_name=MARIAN_MODEL,
            learning_rate=3e-4,
            num_train_epochs=10,
            distributed=False,
        )

        # Check metrics
        logs = TrainerState.load_from_json(os.path.join(output_dir, "trainer_state.json")).log_history
        eval_metrics = [log for log in logs if "eval_loss" in log.keys()]
        first_step_stats = eval_metrics[0]
        last_step_stats = eval_metrics[-1]

        assert first_step_stats["eval_loss"] > last_step_stats["eval_loss"], "model learned nothing"
        assert isinstance(last_step_stats["eval_bleu"], float)

        # test if do_predict saves generations and metrics
        contents = os.listdir(output_dir)
        contents = {os.path.basename(p) for p in contents}
        assert "generated_predictions.txt" in contents
        assert "predict_results.json" in contents

    def run_trainer(
        self,
        eval_steps: int,
        max_len: int,
        model_name: str,
        num_train_epochs: int,
        learning_rate: float = 3e-3,
        distributed: bool = False,
        extra_args_str: str = None,
        predict_with_generate: bool = True,
        do_train: bool = True,
        do_eval: bool = True,
        do_predict: bool = True,
    ):
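        """
        Build a run_translation.py argument list from the flags above and execute it, either
        in-process by patching sys.argv or, when distributed=True, as a subprocess under
        torch.distributed.launch. Returns the output dir for inspecting metrics and files.
        """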
        data_dir = self.test_file_dir / "../fixtures/tests_samples/wmt_en_ro"
        output_dir = self.get_auto_remove_tmp_dir()
        args_train = f"""
            --model_name_or_path {model_name}
            --train_file {data_dir}/train.json
            --validation_file {data_dir}/val.json
            --test_file {data_dir}/test.json
            --output_dir {output_dir}
            --overwrite_output_dir
            --max_train_samples 8
            --max_source_length {max_len}
            --max_target_length {max_len}
            --do_train
            --num_train_epochs {str(num_train_epochs)}
            --per_device_train_batch_size 4
            --learning_rate {learning_rate}
            --warmup_steps 8
            --logging_steps 0
            --save_steps {str(eval_steps)}
            --group_by_length
            --label_smoothing_factor 0.1
            --adafactor
            --target_lang ro_RO
            --source_lang en_XX
        """

        args_eval = f"""
            --do_eval
            --per_device_eval_batch_size 4
            --max_eval_samples 8
            --val_max_target_length {max_len}
            --evaluation_strategy steps
            --eval_steps {str(eval_steps)}
        """

        args_predict = """
            --do_predict
        """

        args = ""
        if do_train:
            args += args_train

        if do_eval:
            args += args_eval

        if do_predict:
            args += args_predict

        if predict_with_generate:
            args += "--predict_with_generate"

        args = args.split()

        if extra_args_str is not None:
            args.extend(extra_args_str.split())

        if distributed:
            n_gpu = get_gpu_count()
            master_port = get_torch_dist_unique_port()
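            # a unique port per test so concurrently running distributed tests don't clash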
            distributed_args = f"""
                -m torch.distributed.launch
                --nproc_per_node={n_gpu}
                --master_port={master_port}
                {self.examples_dir_str}/pytorch/translation/run_translation.py
            """.split()
            cmd = [sys.executable] + distributed_args + args
            execute_subprocess_async(cmd, env=self.get_env())
        else:
            testargs = ["run_translation.py"] + args
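            # run main() in-process, pretending its args came from the command line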
            with patch.object(sys, "argv", testargs):
                main()

        return output_dir