test_trainer_ext.py 8.37 KB
Newer Older
Sylvain Gugger's avatar
Sylvain Gugger committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

15
import math
Suraj Patil's avatar
Suraj Patil committed
16
17
import os
import sys
18
import unittest
Suraj Patil's avatar
Suraj Patil committed
19
20
from unittest.mock import patch

Sylvain Gugger's avatar
Sylvain Gugger committed
21
from transformers.file_utils import is_apex_available
22
from transformers.integrations import is_fairscale_available
23
24
25
26
from transformers.testing_utils import (
    TestCasePlus,
    execute_subprocess_async,
    get_gpu_count,
27
    require_torch_gpu,
28
29
30
31
    require_torch_multi_gpu,
    require_torch_non_multi_gpu,
    slow,
)
Sylvain Gugger's avatar
Sylvain Gugger committed
32
33
from transformers.trainer_callback import TrainerState
from transformers.trainer_utils import set_seed
Suraj Patil's avatar
Suraj Patil committed
34

35
36
37
38

# Make the seq2seq example script importable so the non-distributed test path
# can call run_seq2seq.main() in-process.
bindir = os.path.abspath(os.path.dirname(__file__))
sys.path.append(f"{bindir}/../../seq2seq")
from run_seq2seq import main  # noqa


# Fixed seed so the quick-training assertions below are reproducible.
set_seed(42)

MARIAN_MODEL = "sshleifer/student_marian_en_ro_6_1"
MBART_TINY = "sshleifer/tiny-mbart"
Suraj Patil's avatar
Suraj Patil committed
44
45


46
47
48
49
50
51
52
53
54
55
56
# a candidate for testing_utils
def require_fairscale(test_case):
    """
    Decorator marking a test that requires fairscale
    """
    # Early return when fairscale is installed; otherwise wrap with a skip marker.
    if is_fairscale_available():
        return test_case
    return unittest.skip("test requires fairscale")(test_case)


57
58
59
60
61
62
63
64
65
66
67
# a candidate for testing_utils
def require_apex(test_case):
    """
    Decorator marking a test that requires apex
    """
    # Early return when apex is installed; otherwise wrap with a skip marker.
    if is_apex_available():
        return test_case
    return unittest.skip("test requires apex")(test_case)


68
class TestTrainerExt(TestCasePlus):
    """End-to-end tests for the seq2seq Trainer via the run_seq2seq example script."""

    def run_seq2seq_quick(self, distributed=False, extra_args_str=None, predict_with_generate=True):
        """Run one fast epoch on the tiny mbart model and sanity-check the logged eval metrics."""
        output_dir = self.run_trainer(
            eval_steps=1,
            max_len=12,
            model_name=MBART_TINY,
            num_train_epochs=1,
            distributed=distributed,
            extra_args_str=extra_args_str,
            predict_with_generate=predict_with_generate,
        )

        logs = TrainerState.load_from_json(os.path.join(output_dir, "trainer_state.json")).log_history
        eval_metrics = [log for log in logs if "eval_loss" in log.keys()]

        first_step_stats = eval_metrics[0]
        if predict_with_generate:
            # bleu is only computed when generation is enabled
            assert "eval_bleu" in first_step_stats

            last_step_stats = eval_metrics[-1]
            assert isinstance(last_step_stats["eval_bleu"], float)
            assert not math.isnan(float(last_step_stats["eval_loss"])), "eval_loss must not be `nan`"

    @require_torch_non_multi_gpu
    def test_run_seq2seq_no_dist(self):
        self.run_seq2seq_quick()

    # verify that the trainer can handle non-distributed with n_gpu > 1
    @require_torch_multi_gpu
    def test_run_seq2seq_dp(self):
        self.run_seq2seq_quick(distributed=False)

    # verify that the trainer can handle distributed with n_gpu > 1
    @require_torch_multi_gpu
    def test_run_seq2seq_ddp(self):
        self.run_seq2seq_quick(distributed=True)

    # test --sharded_ddp w/o --fp16
    @require_torch_multi_gpu
    @require_fairscale
    def test_run_seq2seq_sharded_ddp(self):
        self.run_seq2seq_quick(distributed=True, extra_args_str="--sharded_ddp simple")

    # test --sharded_ddp w/ --fp16
    @require_torch_multi_gpu
    @require_fairscale
    def test_run_seq2seq_sharded_ddp_fp16(self):
        self.run_seq2seq_quick(distributed=True, extra_args_str="--sharded_ddp simple --fp16")

    # test --sharded_ddp zero_dp_2 w/o --fp16
    @require_torch_multi_gpu
    @require_fairscale
    def test_run_seq2seq_fully_sharded_ddp(self):
        # predict_with_generate is disabled: generation is not supported under zero_dp_2
        self.run_seq2seq_quick(distributed=True, extra_args_str="--sharded_ddp zero_dp_2", predict_with_generate=False)

    # test --sharded_ddp zero_dp_2 w/ --fp16
    @require_torch_multi_gpu
    @require_fairscale
    def test_run_seq2seq_fully_sharded_ddp_fp16(self):
        self.run_seq2seq_quick(
            distributed=True, extra_args_str="--sharded_ddp zero_dp_2 --fp16", predict_with_generate=False
        )

    @require_apex
    @require_torch_gpu
    def test_run_seq2seq_apex(self):
        # XXX: apex breaks the trainer if it's run twice e.g. run_seq2seq.main() from the same
        # program and it breaks other tests that run from the same pytest worker, therefore until this is
        # sorted out it must be run only in an external program, that is distributed=True in this
        # test and only under one or more gpus - if we want cpu will need to make a special test
        #
        # specifically to the problem traced it to self.optimizer.step() - if it's run 2nd time via
        # 2nd main() call it botches the future eval.
        #
        self.run_seq2seq_quick(distributed=True, extra_args_str="--fp16 --fp16_backend=apex")
        # test 2nd time - was getting eval_loss': nan'
        # to reproduce the problem set distributed=False
        self.run_seq2seq_quick(distributed=True, extra_args_str="--fp16 --fp16_backend=apex")

    @slow
    def test_run_seq2seq_slow(self):
        """Train the small marian model for real and verify the loss improves and artifacts exist."""
        output_dir = self.run_trainer(
            eval_steps=2,
            max_len=128,
            model_name=MARIAN_MODEL,
            learning_rate=3e-4,
            num_train_epochs=10,
            distributed=False,
        )

        # Check metrics
        logs = TrainerState.load_from_json(os.path.join(output_dir, "trainer_state.json")).log_history
        eval_metrics = [log for log in logs if "eval_loss" in log.keys()]
        first_step_stats = eval_metrics[0]
        last_step_stats = eval_metrics[-1]

        assert first_step_stats["eval_loss"] > last_step_stats["eval_loss"], "model learned nothing"
        assert isinstance(last_step_stats["eval_bleu"], float)

        # test if do_predict saves generations and metrics
        contents = os.listdir(output_dir)
        contents = {os.path.basename(p) for p in contents}
        assert "test_generations.txt" in contents
        assert "test_results.json" in contents

    def run_trainer(
        self,
        eval_steps: int,
        max_len: int,
        model_name: str,
        num_train_epochs: int,
        learning_rate: float = 3e-3,
        distributed: bool = False,
        extra_args_str: str = None,
        predict_with_generate: bool = True,
    ):
        """
        Build the CLI arguments for run_seq2seq.py and execute it.

        Args:
            eval_steps: evaluate/save every this many steps.
            max_len: max source/target sequence length passed to the script.
            model_name: model name or path to fine-tune.
            num_train_epochs: number of training epochs.
            learning_rate: optimizer learning rate.
            distributed: if True, launch via torch.distributed.launch in a subprocess
                (one process per available gpu); otherwise call main() in-process.
            extra_args_str: optional extra CLI arguments, whitespace-separated.
            predict_with_generate: if True, add --predict_with_generate.

        Returns:
            The output directory (auto-removed temp dir) containing trainer artifacts.
        """
        data_dir = self.examples_dir / "test_data/wmt_en_ro"
        output_dir = self.get_auto_remove_tmp_dir()
        # Fix: use the max_len parameter for the length limits instead of a hard-coded value,
        # so run_seq2seq_quick (max_len=12) and test_run_seq2seq_slow (max_len=128) take effect.
        args = f"""
            --model_name_or_path {model_name}
            --train_file {data_dir}/train.json
            --validation_file {data_dir}/val.json
            --test_file {data_dir}/test.json
            --output_dir {output_dir}
            --overwrite_output_dir
            --max_train_samples 8
            --max_val_samples 8
            --max_source_length {max_len}
            --max_target_length {max_len}
            --val_max_target_length {max_len}
            --do_train
            --do_eval
            --do_predict
            --num_train_epochs {str(num_train_epochs)}
            --per_device_train_batch_size 4
            --per_device_eval_batch_size 4
            --learning_rate {learning_rate}
            --warmup_steps 8
            --evaluation_strategy steps
            --logging_steps 0
            --eval_steps {str(eval_steps)}
            --save_steps {str(eval_steps)}
            --group_by_length
            --label_smoothing_factor 0.1
            --adafactor
            --task translation
            --target_lang ro_RO
            --source_lang en_XX
        """
        if predict_with_generate:
            args += "--predict_with_generate"

        args = args.split()

        if extra_args_str is not None:
            args.extend(extra_args_str.split())

        if distributed:
            n_gpu = get_gpu_count()
            distributed_args = f"""
                -m torch.distributed.launch
                --nproc_per_node={n_gpu}
                {self.examples_dir_str}/seq2seq/run_seq2seq.py
            """.split()
            cmd = [sys.executable] + distributed_args + args
            execute_subprocess_async(cmd, env=self.get_env())
        else:
            # run in-process, faking argv so the script's argparse sees our args
            testargs = ["run_seq2seq.py"] + args
            with patch.object(sys, "argv", testargs):
                main()

        return output_dir