test_trainer_ext.py 8.39 KB
Newer Older
Sylvain Gugger's avatar
Sylvain Gugger committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

15
import math
Suraj Patil's avatar
Suraj Patil committed
16
17
import os
import sys
18
import unittest
Suraj Patil's avatar
Suraj Patil committed
19
20
from unittest.mock import patch

Sylvain Gugger's avatar
Sylvain Gugger committed
21
from transformers.file_utils import is_apex_available
22
from transformers.integrations import is_fairscale_available
23
from transformers.testing_utils import (
24
    ExtendSysPath,
25
26
27
    TestCasePlus,
    execute_subprocess_async,
    get_gpu_count,
28
    require_torch_gpu,
29
30
31
32
    require_torch_multi_gpu,
    require_torch_non_multi_gpu,
    slow,
)
Sylvain Gugger's avatar
Sylvain Gugger committed
33
34
from transformers.trainer_callback import TrainerState
from transformers.trainer_utils import set_seed
Suraj Patil's avatar
Suraj Patil committed
35

36
37

# Directory containing this test file; used to locate the example scripts relative to it.
bindir = os.path.abspath(os.path.dirname(__file__))
# Temporarily extend sys.path so the seq2seq example script can be imported
# and its main() invoked in-process (used by the non-distributed test path).
with ExtendSysPath(f"{bindir}/../../examples/seq2seq"):
    from run_translation import main  # noqa


# Fix the global RNG seed so test runs are reproducible.
set_seed(42)

# Model checkpoints under test: the Marian student model is used by the @slow
# full-quality run, the tiny mbart checkpoint by the quick smoke tests.
MARIAN_MODEL = "sshleifer/student_marian_en_ro_6_1"
MBART_TINY = "sshleifer/tiny-mbart"
Suraj Patil's avatar
Suraj Patil committed
45
46


47
48
49
50
51
52
53
54
55
56
57
# a candidate for testing_utils
def require_fairscale(test_case):
    """
    Decorator that skips a test unless fairscale is installed.
    """
    return unittest.skipUnless(is_fairscale_available(), "test requires fairscale")(test_case)


58
59
60
61
62
63
64
65
66
67
68
# a candidate for testing_utils
def require_apex(test_case):
    """
    Decorator that skips a test unless apex is installed.
    """
    return unittest.skipUnless(is_apex_available(), "test requires apex")(test_case)


69
class TestTrainerExt(TestCasePlus):
    """End-to-end tests driving examples/seq2seq/run_translation.py through Seq2SeqTrainer.

    Quick variants run a single epoch of the tiny mbart checkpoint; the @slow variant
    trains the Marian student model long enough to verify the loss actually drops.
    """

    def run_seq2seq_quick(self, distributed=False, extra_args_str=None, predict_with_generate=True):
        """Run one short training+eval cycle on the tiny model and sanity-check its metrics.

        Args:
            distributed: launch via torch.distributed.launch instead of in-process.
            extra_args_str: extra CLI flags appended verbatim (e.g. "--sharded_ddp simple").
            predict_with_generate: pass --predict_with_generate so BLEU is computed.
        """
        output_dir = self.run_trainer(
            eval_steps=1,
            max_len=12,
            model_name=MBART_TINY,
            num_train_epochs=1,
            distributed=distributed,
            extra_args_str=extra_args_str,
            predict_with_generate=predict_with_generate,
        )

        # The Trainer persists its full log history in trainer_state.json; read it back
        # and keep only the eval entries.
        logs = TrainerState.load_from_json(os.path.join(output_dir, "trainer_state.json")).log_history
        eval_metrics = [log for log in logs if "eval_loss" in log]

        first_step_stats = eval_metrics[0]
        if predict_with_generate:
            # BLEU is only computed when generation is enabled.
            assert "eval_bleu" in first_step_stats

            last_step_stats = eval_metrics[-1]
            assert isinstance(last_step_stats["eval_bleu"], float)
            assert not math.isnan(float(last_step_stats["eval_loss"])), "eval_loss must not be `nan`"

    @require_torch_non_multi_gpu
    def test_run_seq2seq_no_dist(self):
        self.run_seq2seq_quick()

    # verify that the trainer can handle non-distributed with n_gpu > 1
    @require_torch_multi_gpu
    def test_run_seq2seq_dp(self):
        self.run_seq2seq_quick(distributed=False)

    # verify that the trainer can handle distributed with n_gpu > 1
    @require_torch_multi_gpu
    def test_run_seq2seq_ddp(self):
        self.run_seq2seq_quick(distributed=True)

    # test --sharded_ddp w/o --fp16
    @require_torch_multi_gpu
    @require_fairscale
    def test_run_seq2seq_sharded_ddp(self):
        self.run_seq2seq_quick(distributed=True, extra_args_str="--sharded_ddp simple")

    # test --sharded_ddp w/ --fp16
    @require_torch_multi_gpu
    @require_fairscale
    def test_run_seq2seq_sharded_ddp_fp16(self):
        self.run_seq2seq_quick(distributed=True, extra_args_str="--sharded_ddp simple --fp16")

    # test --sharded_ddp zero_dp_2 w/o --fp16
    @require_torch_multi_gpu
    @require_fairscale
    def test_run_seq2seq_fully_sharded_ddp(self):
        # predict_with_generate is disabled: generation is not supported under zero_dp_2
        self.run_seq2seq_quick(distributed=True, extra_args_str="--sharded_ddp zero_dp_2", predict_with_generate=False)

    # test --sharded_ddp zero_dp_2 w/ --fp16
    @require_torch_multi_gpu
    @require_fairscale
    def test_run_seq2seq_fully_sharded_ddp_fp16(self):
        self.run_seq2seq_quick(
            distributed=True, extra_args_str="--sharded_ddp zero_dp_2 --fp16", predict_with_generate=False
        )

    @require_apex
    @require_torch_gpu
    def test_run_seq2seq_apex(self):
        # XXX: apex breaks the trainer if it's run twice e.g. run_seq2seq.main() from the same
        # program and it breaks other tests that run from the same pytest worker, therefore until this is
        # sorted out it must be run only in an external program, that is distributed=True in this
        # test and only under one or more gpus - if we want cpu will need to make a special test
        #
        # specifically to the problem traced it to self.optimizer.step() - if it's run 2nd time via
        # 2nd main() call it botches the future eval.
        #
        self.run_seq2seq_quick(distributed=True, extra_args_str="--fp16 --fp16_backend=apex")
        # test 2nd time - was getting eval_loss': nan'
        # to reproduce the problem set distributed=False
        self.run_seq2seq_quick(distributed=True, extra_args_str="--fp16 --fp16_backend=apex")

    @slow
    def test_run_seq2seq_slow(self):
        """Train the Marian model for real (10 epochs) and verify learning + predict outputs."""
        output_dir = self.run_trainer(
            eval_steps=2,
            max_len=128,
            model_name=MARIAN_MODEL,
            learning_rate=3e-4,
            num_train_epochs=10,
            distributed=False,
        )

        # Check metrics
        logs = TrainerState.load_from_json(os.path.join(output_dir, "trainer_state.json")).log_history
        eval_metrics = [log for log in logs if "eval_loss" in log]
        first_step_stats = eval_metrics[0]
        last_step_stats = eval_metrics[-1]

        assert first_step_stats["eval_loss"] > last_step_stats["eval_loss"], "model learned nothing"
        assert isinstance(last_step_stats["eval_bleu"], float)

        # test if do_predict saves generations and metrics
        contents = os.listdir(output_dir)
        contents = {os.path.basename(p) for p in contents}
        assert "test_generations.txt" in contents
        assert "test_results.json" in contents

    def run_trainer(
        self,
        eval_steps: int,
        max_len: int,
        model_name: str,
        num_train_epochs: int,
        learning_rate: float = 3e-3,
        distributed: bool = False,
        extra_args_str: str = None,
        predict_with_generate: bool = True,
    ):
        """Assemble the CLI arguments for run_translation.py and execute it.

        Runs either in-process (patching sys.argv and calling main()) or, when
        ``distributed`` is true, as a subprocess under torch.distributed.launch
        across all visible GPUs.

        Returns:
            The (auto-removed) output directory containing checkpoints and metrics.
        """
        data_dir = self.examples_dir / "test_data/wmt_en_ro"
        output_dir = self.get_auto_remove_tmp_dir()
        # FIX: the source/target length flags previously used the literal "20,128",
        # which is not a valid integer and silently ignored the max_len parameter;
        # they now honor max_len (12 for quick tests, 128 for the slow run).
        args = f"""
            --model_name_or_path {model_name}
            --train_file {data_dir}/train.json
            --validation_file {data_dir}/val.json
            --test_file {data_dir}/test.json
            --output_dir {output_dir}
            --overwrite_output_dir
            --max_train_samples 8
            --max_val_samples 8
            --max_source_length {max_len}
            --max_target_length {max_len}
            --val_max_target_length {max_len}
            --do_train
            --do_eval
            --do_predict
            --num_train_epochs {str(num_train_epochs)}
            --per_device_train_batch_size 4
            --per_device_eval_batch_size 4
            --learning_rate {learning_rate}
            --warmup_steps 8
            --evaluation_strategy steps
            --logging_steps 0
            --eval_steps {str(eval_steps)}
            --save_steps {str(eval_steps)}
            --group_by_length
            --label_smoothing_factor 0.1
            --adafactor
            --target_lang ro_RO
            --source_lang en_XX
        """
        if predict_with_generate:
            args += "--predict_with_generate"

        args = args.split()

        if extra_args_str is not None:
            args.extend(extra_args_str.split())

        if distributed:
            n_gpu = get_gpu_count()
            distributed_args = f"""
                -m torch.distributed.launch
                --nproc_per_node={n_gpu}
                {self.examples_dir_str}/seq2seq/run_translation.py
            """.split()
            cmd = [sys.executable] + distributed_args + args
            execute_subprocess_async(cmd, env=self.get_env())
        else:
            # In-process run: fake argv and call the example's main() directly.
            testargs = ["run_translation.py"] + args
            with patch.object(sys, "argv", testargs):
                main()

        return output_dir