# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import os
import sys
import unittest
from unittest.mock import patch

from transformers.file_utils import is_apex_available
from transformers.integrations import is_fairscale_available
from transformers.testing_utils import (
    TestCasePlus,
    execute_subprocess_async,
    get_gpu_count,
    require_torch_gpu,
    require_torch_multi_gpu,
    require_torch_non_multi_gpu,
    slow,
)
from transformers.trainer_callback import TrainerState
from transformers.trainer_utils import set_seed


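# make run_seq2seq.py from the examples/seq2seq dir importable as a module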
bindir = os.path.abspath(os.path.dirname(__file__))
sys.path.append(f"{bindir}/../../seq2seq")
from run_seq2seq import main  # noqa


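# fix the seeds to make the test runs reproducible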
set_seed(42)
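# checkpoints used below: a tiny mbart keeps the quick tests fast; the distilled
# marian en-ro student is exercised only by the @slow test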
MARIAN_MODEL = "sshleifer/student_marian_en_ro_6_1"
MBART_TINY = "sshleifer/tiny-mbart"


# a candidate for testing_utils
def require_fairscale(test_case):
    """
    Decorator marking a test that requires fairscale
    """
    if not is_fairscale_available():
        return unittest.skip("test requires fairscale")(test_case)
    else:
        return test_case


# a candidate for testing_utils
def require_apex(test_case):
    """
    Decorator marking a test that requires apex
    """
    if not is_apex_available():
        return unittest.skip("test requires apex")(test_case)
    else:
        return test_case


class TestTrainerExt(TestCasePlus):
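    # shared helper: fine-tune a tiny model for one epoch and sanity-check the eval
    # metrics that the trainer logged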
    def run_seq2seq_quick(self, distributed=False, extra_args_str=None, eval=True, predict_with_generate=True):
        output_dir = self.run_trainer(
            eval_steps=1,
            max_len=12,
            model_name=MBART_TINY,
            num_train_epochs=1,
            distributed=distributed,
            extra_args_str=extra_args_str,
            predict_with_generate=predict_with_generate,
        )
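        # the trainer records the metrics it logs in trainer_state.json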
        logs = TrainerState.load_from_json(os.path.join(output_dir, "trainer_state.json")).log_history
        eval_metrics = [log for log in logs if "eval_loss" in log.keys()]

        first_step_stats = eval_metrics[0]
        if predict_with_generate:
            assert "eval_bleu" in first_step_stats

            last_step_stats = eval_metrics[-1]
            assert isinstance(last_step_stats["eval_bleu"], float)
            assert not math.isnan(float(last_step_stats["eval_loss"])), "eval_loss must not be `nan`"

    @require_torch_non_multi_gpu
    def test_run_seq2seq_no_dist(self):
        self.run_seq2seq_quick()

    # verify that the trainer can handle non-distributed with n_gpu > 1
    @require_torch_multi_gpu
    def test_run_seq2seq_dp(self):
        self.run_seq2seq_quick(distributed=False)

    # verify that the trainer can handle distributed with n_gpu > 1
    @require_torch_multi_gpu
    def test_run_seq2seq_ddp(self):
        self.run_seq2seq_quick(distributed=True)

    # test --sharded_ddp w/o --fp16
    @require_torch_multi_gpu
    @require_fairscale
    def test_run_seq2seq_sharded_ddp(self):
        self.run_seq2seq_quick(distributed=True, extra_args_str="--sharded_ddp simple")

    # test --sharded_ddp w/ --fp16
    @require_torch_multi_gpu
    @require_fairscale
    def test_run_seq2seq_sharded_ddp_fp16(self):
        self.run_seq2seq_quick(distributed=True, extra_args_str="--sharded_ddp simple --fp16")

    # test --sharded_ddp zero_dp_2 w/o --fp16
    @require_torch_multi_gpu
    @require_fairscale
    @unittest.skip("XXX: Fixme: hanging")
    def test_run_seq2seq_fully_sharded_ddp(self):
        self.run_seq2seq_quick(distributed=True, extra_args_str="--sharded_ddp zero_dp_2", predict_with_generate=False)

    # test --sharded_ddp zero_dp_2 w/ --fp16
    @require_torch_multi_gpu
    @require_fairscale
    @unittest.skip("XXX: Fixme: hanging")
    def test_run_seq2seq_fully_sharded_ddp_fp16(self):
        self.run_seq2seq_quick(
            distributed=True, extra_args_str="--sharded_ddp zero_dp_2 --fp16", predict_with_generate=False
        )

    @require_apex
    @require_torch_gpu
    def test_run_seq2seq_apex(self):
        # XXX: apex breaks the trainer if run_seq2seq.main() is run twice from the same
        # program, and it also breaks other tests that run from the same pytest worker.
        # Until this is sorted out, this test must be run in an external program, i.e.
        # distributed=True in this test, and only on one or more gpus - if we want a cpu
        # version, a special test will be needed.
        #
        # specifically, the problem was traced to self.optimizer.step() - if it runs a
        # 2nd time via a 2nd main() call, it botches the subsequent eval.
        #
        self.run_seq2seq_quick(distributed=True, extra_args_str="--fp16 --fp16_backend=apex")
        # test a 2nd time - was getting 'eval_loss': nan
        # to reproduce the problem set distributed=False
        self.run_seq2seq_quick(distributed=True, extra_args_str="--fp16 --fp16_backend=apex")

    @slow
    def test_run_seq2seq_slow(self):
        output_dir = self.run_trainer(
            eval_steps=2,
            max_len=128,
            model_name=MARIAN_MODEL,
            learning_rate=3e-4,
            num_train_epochs=10,
            distributed=False,
        )

        # Check metrics
        logs = TrainerState.load_from_json(os.path.join(output_dir, "trainer_state.json")).log_history
        eval_metrics = [log for log in logs if "eval_loss" in log.keys()]
        first_step_stats = eval_metrics[0]
        last_step_stats = eval_metrics[-1]

        assert first_step_stats["eval_loss"] > last_step_stats["eval_loss"], "model learned nothing"
        assert isinstance(last_step_stats["eval_bleu"], float)

        # test if do_predict saves generations and metrics
        contents = os.listdir(output_dir)
        contents = {os.path.basename(p) for p in contents}
        assert "test_generations.txt" in contents
        assert "test_results.json" in contents

    def run_trainer(
        self,
        eval_steps: int,
        max_len: int,
        model_name: str,
        num_train_epochs: int,
        learning_rate: float = 3e-3,
        distributed: bool = False,
        extra_args_str: str = None,
        predict_with_generate: bool = True,
    ):
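        # assemble a run_seq2seq.py command line; in distributed mode it is executed in a
        # subprocess via torch.distributed.launch, otherwise main() is invoked in-process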
        data_dir = self.examples_dir / "test_data/wmt_en_ro"
        output_dir = self.get_auto_remove_tmp_dir()
        args = f"""
            --model_name_or_path {model_name}
            --train_file {data_dir}/train.json
            --validation_file {data_dir}/val.json
            --test_file {data_dir}/test.json
            --output_dir {output_dir}
            --overwrite_output_dir
            --max_train_samples 8
            --max_val_samples 8
            --max_source_length {max_len}
            --max_target_length {max_len}
            --val_max_target_length {max_len}
            --do_train
            --do_eval
            --do_predict
            --num_train_epochs {str(num_train_epochs)}
            --per_device_train_batch_size 4
            --per_device_eval_batch_size 4
            --learning_rate {learning_rate}
            --warmup_steps 8
            --evaluation_strategy steps
            --logging_steps 0
            --save_steps {str(eval_steps)}
            --eval_steps {str(eval_steps)}
            --group_by_length
            --label_smoothing_factor 0.1
            --adafactor
            --task translation
            --target_lang ro_RO
            --source_lang en_XX
        """
        if predict_with_generate:
            args += "--predict_with_generate"

        args = args.split()

        if extra_args_str is not None:
            args.extend(extra_args_str.split())

        if distributed:
            n_gpu = get_gpu_count()
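            # launch the example script in a subprocess, one process per gpu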
            distributed_args = f"""
                -m torch.distributed.launch
                --nproc_per_node={n_gpu}
                {self.examples_dir_str}/seq2seq/run_seq2seq.py
            """.split()
            cmd = [sys.executable] + distributed_args + args
            execute_subprocess_async(cmd, env=self.get_env())
        else:
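            # run main() in-process with sys.argv patched, as if run_seq2seq.py had been
            # invoked from the command line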
            testargs = ["run_seq2seq.py"] + args
            with patch.object(sys, "argv", testargs):
                main()

        return output_dir