# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Suraj Patil's avatar
Suraj Patil committed
15
16
import os
import sys
17
import unittest
Suraj Patil's avatar
Suraj Patil committed
18
19
from unittest.mock import patch

Sylvain Gugger's avatar
Sylvain Gugger committed
20
from transformers.file_utils import is_apex_available
21
from transformers.integrations import is_fairscale_available
22
23
24
25
26
27
28
29
from transformers.testing_utils import (
    TestCasePlus,
    execute_subprocess_async,
    get_gpu_count,
    require_torch_multi_gpu,
    require_torch_non_multi_gpu,
    slow,
)
Sylvain Gugger's avatar
Sylvain Gugger committed
30
31
from transformers.trainer_callback import TrainerState
from transformers.trainer_utils import set_seed
Suraj Patil's avatar
Suraj Patil committed
32

33
34
35
36

# `run_seq2seq.py` is not an installed module; it lives in the examples/seq2seq
# directory relative to this test file, so extend sys.path to make it importable.
bindir = os.path.abspath(os.path.dirname(__file__))
sys.path.append(f"{bindir}/../../seq2seq")
from run_seq2seq import main  # noqa
37

Suraj Patil's avatar
Suraj Patil committed
38

39
# Fix the global RNG state so every test run is deterministic.
set_seed(42)

# Checkpoints exercised by the tests: a distilled Marian en->ro model for the
# slow/full test, and a tiny mBART for the quick smoke tests.
MARIAN_MODEL = "sshleifer/student_marian_en_ro_6_1"
MBART_TINY = "sshleifer/tiny-mbart"


44
45
46
47
48
49
50
51
52
53
54
# a candidate for testing_utils
def require_fairscale(test_case):
    """
    Decorator marking a test that requires fairscale
    """
    # Guard-clause form: pass the test through untouched when fairscale is
    # installed, otherwise wrap it with unittest's skip decorator.
    if is_fairscale_available():
        return test_case
    return unittest.skip("test requires fairscale")(test_case)
# a candidate for testing_utils
def require_apex(test_case):
    """
    Decorator marking a test that requires apex
    """
    # Pass the test through when apex is available; otherwise mark it skipped.
    if is_apex_available():
        return test_case
    return unittest.skip("test requires apex")(test_case)
class TestTrainerExt(TestCasePlus):
    def run_seq2seq_quick(self, distributed=False, extra_args_str=None, eval=True, predict_with_generate=True):
        """
        Run a 1-epoch fine-tuning of MBART_TINY and sanity-check the logged eval metrics.

        NOTE(review): the ``eval`` argument shadows the builtin and is not used in this
        body; it is kept only so existing callers' signatures remain valid.
        """
        output_dir = self.run_trainer(1, "12", MBART_TINY, 1, distributed, extra_args_str, predict_with_generate)

        # The Trainer persists its full log history in trainer_state.json; reload it
        # and keep only the evaluation entries (those containing "eval_loss").
        logs = TrainerState.load_from_json(os.path.join(output_dir, "trainer_state.json")).log_history
        eval_metrics = [log for log in logs if "eval_loss" in log.keys()]
        first_step_stats = eval_metrics[0]

        # BLEU is only computed when generation is enabled during evaluation.
        if predict_with_generate:
            assert "eval_bleu" in first_step_stats

    @require_torch_non_multi_gpu
    def test_run_seq2seq_no_dist(self):
        self.run_seq2seq_quick()

    # verify that the trainer can handle non-distributed with n_gpu > 1
    @require_torch_multi_gpu
    def test_run_seq2seq_dp(self):
        self.run_seq2seq_quick(distributed=False)

    # verify that the trainer can handle distributed with n_gpu > 1
    @require_torch_multi_gpu
    def test_run_seq2seq_ddp(self):
        self.run_seq2seq_quick(distributed=True)

    # test --sharded_ddp w/o --fp16
    @require_torch_multi_gpu
    @require_fairscale
    def test_run_seq2seq_sharded_ddp(self):
        self.run_seq2seq_quick(distributed=True, extra_args_str="--sharded_ddp simple")

    # test --sharded_ddp w/ --fp16
    @require_torch_multi_gpu
    @require_fairscale
    def test_run_seq2seq_sharded_ddp_fp16(self):
        self.run_seq2seq_quick(distributed=True, extra_args_str="--sharded_ddp simple --fp16")

    # test --sharded_ddp zero2 w/o --fp16
    @require_torch_multi_gpu
    @require_fairscale
    def test_run_seq2seq_fully_sharded_ddp(self):
        self.run_seq2seq_quick(distributed=True, extra_args_str="--sharded_ddp zero2", predict_with_generate=False)

    # test --sharded_ddp zero2 w/ --fp16
    @require_torch_multi_gpu
    @require_fairscale
    def test_run_seq2seq_fully_sharded_ddp_fp16(self):
        self.run_seq2seq_quick(
            distributed=True, extra_args_str="--sharded_ddp zero2 --fp16", predict_with_generate=False
        )

    @require_apex
    def test_run_seq2seq_apex(self):
        self.run_seq2seq_quick(extra_args_str="--fp16 --fp16_backend=apex")

    @slow
    def test_run_seq2seq_slow(self):
        # There is a missing call to __init__process_group somewhere
        output_dir = self.run_trainer(
            eval_steps=2, max_len="128", model_name=MARIAN_MODEL, num_train_epochs=10, distributed=False
        )

        # Check metrics
        logs = TrainerState.load_from_json(os.path.join(output_dir, "trainer_state.json")).log_history
        eval_metrics = [log for log in logs if "eval_loss" in log.keys()]
        first_step_stats = eval_metrics[0]
        last_step_stats = eval_metrics[-1]

        # BLEU must improve over training; if it doesn't, the model learned nothing.
        assert first_step_stats["eval_bleu"] < last_step_stats["eval_bleu"]
        assert isinstance(last_step_stats["eval_bleu"], float)

        # test if do_predict saves generations and metrics
        contents = os.listdir(output_dir)
        contents = {os.path.basename(p) for p in contents}
        assert "test_preds_seq2seq.txt" in contents
        assert "test_results.json" in contents

    def run_trainer(
        self,
        eval_steps: int,
        max_len: str,
        model_name: str,
        num_train_epochs: int,
        distributed: bool = False,
        extra_args_str: str = None,
        predict_with_generate: bool = True,
    ):
        """
        Launch one full run of examples/seq2seq/run_seq2seq.py.

        Args:
            eval_steps: evaluate (and save) every this many steps.
            max_len: max source/target sequence length, passed verbatim on the CLI.
            model_name: model checkpoint to fine-tune.
            num_train_epochs: number of training epochs.
            distributed: launch via ``torch.distributed.launch`` across all GPUs
                when True; otherwise run in-process by patching ``sys.argv``.
            extra_args_str: extra CLI arguments appended to the command line.
            predict_with_generate: add ``--predict_with_generate`` when True.

        Returns:
            The (auto-removed) output directory the run wrote to.
        """
        data_dir = self.examples_dir / "test_data/wmt_en_ro"
        output_dir = self.get_auto_remove_tmp_dir()
        # Fix: max_len was previously ignored (the lengths were hardcoded); it is
        # now interpolated into the three length arguments below.
        args = f"""
            --model_name_or_path {model_name}
            --train_file {data_dir}/train.json
            --validation_file {data_dir}/val.json
            --test_file {data_dir}/test.json
            --output_dir {output_dir}
            --overwrite_output_dir
            --max_train_samples 8
            --max_val_samples 8
            --max_source_length {max_len}
            --max_target_length {max_len}
            --val_max_target_length {max_len}
            --do_train
            --do_eval
            --do_predict
            --num_train_epochs {str(num_train_epochs)}
            --per_device_train_batch_size 4
            --per_device_eval_batch_size 4
            --learning_rate 3e-3
            --warmup_steps 8
            --evaluation_strategy steps
            --logging_steps 0
            --save_steps {str(eval_steps)}
            --eval_steps {str(eval_steps)}
            --group_by_length
            --label_smoothing_factor 0.1
            --adafactor
            --task translation
            --target_lang ro_RO
            --source_lang en_XX
        """
        if predict_with_generate:
            args += "--predict_with_generate"

        args = args.split()

        if extra_args_str is not None:
            args.extend(extra_args_str.split())

        if distributed:
            # Spawn one process per available GPU via the torch distributed launcher.
            n_gpu = get_gpu_count()
            distributed_args = f"""
                -m torch.distributed.launch
                --nproc_per_node={n_gpu}
                {self.examples_dir_str}/seq2seq/run_seq2seq.py
            """.split()
            cmd = [sys.executable] + distributed_args + args
            execute_subprocess_async(cmd, env=self.get_env())
        else:
            # Run in-process: fake the CLI by patching sys.argv and calling main().
            testargs = ["run_seq2seq.py"] + args
            with patch.object(sys, "argv", testargs):
                main()

        return output_dir