# test_trainer_ext.py
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys
import unittest
from unittest.mock import patch

from transformers.file_utils import is_apex_available
from transformers.integrations import is_fairscale_available
from transformers.testing_utils import (
    TestCasePlus,
    execute_subprocess_async,
    get_gpu_count,
    require_torch_multi_gpu,
    require_torch_non_multi_gpu,
    slow,
)
from transformers.trainer_callback import TrainerState
from transformers.trainer_utils import set_seed


# Make the seq2seq example script importable so we can run it in-process.
bindir = os.path.abspath(os.path.dirname(__file__))
sys.path.append(f"{bindir}/../../seq2seq")
from run_seq2seq import main  # noqa


# Fixed seed so the quick training runs are reproducible across test runs.
set_seed(42)

MARIAN_MODEL = "sshleifer/student_marian_en_ro_6_1"
MBART_TINY = "sshleifer/tiny-mbart"
# a candidate for testing_utils
def require_fairscale(test_case):
    """
    Decorator marking a test that requires fairscale.

    Skips the test when fairscale is not installed; otherwise returns it unchanged.
    """
    if is_fairscale_available():
        return test_case
    return unittest.skip("test requires fairscale")(test_case)
# a candidate for testing_utils
def require_apex(test_case):
    """
    Decorator marking a test that requires apex.

    Skips the test when apex is not installed; otherwise returns it unchanged.
    """
    if is_apex_available():
        return test_case
    return unittest.skip("test requires apex")(test_case)
class TestTrainerExt(TestCasePlus):
    """End-to-end tests that run the seq2seq example script through the Trainer."""

    def run_seq2seq_quick(self, distributed=False, extra_args_str=None):
        """Run a 1-epoch training on tiny-mbart and check an eval bleu metric was logged."""
        output_dir = self.run_trainer(1, "12", MBART_TINY, 1, distributed, extra_args_str)

        # The Trainer records all metrics in trainer_state.json; keep only eval entries.
        logs = TrainerState.load_from_json(os.path.join(output_dir, "trainer_state.json")).log_history
        eval_metrics = [log for log in logs if "eval_loss" in log.keys()]
        first_step_stats = eval_metrics[0]
        assert "eval_bleu" in first_step_stats

    @require_torch_non_multi_gpu
    def test_run_seq2seq_no_dist(self):
        self.run_seq2seq_quick()

    # verify that the trainer can handle non-distributed with n_gpu > 1
    @require_torch_multi_gpu
    def test_run_seq2seq_dp(self):
        self.run_seq2seq_quick(distributed=False)

    # verify that the trainer can handle distributed with n_gpu > 1
    @require_torch_multi_gpu
    def test_run_seq2seq_ddp(self):
        self.run_seq2seq_quick(distributed=True)

    # test --sharded_ddp w/o --fp16
    @require_torch_multi_gpu
    @require_fairscale
    def test_run_seq2seq_ddp_sharded_ddp(self):
        self.run_seq2seq_quick(distributed=True, extra_args_str="--sharded_ddp")

    # test --sharded_ddp w/ --fp16
    @require_torch_multi_gpu
    @require_fairscale
    def test_run_seq2seq_ddp_sharded_ddp_fp16(self):
        self.run_seq2seq_quick(distributed=True, extra_args_str="--sharded_ddp --fp16")

    @require_apex
    def test_run_seq2seq_apex(self):
        self.run_seq2seq_quick(extra_args_str="--fp16 --fp16_backend=apex")

    @slow
    def test_run_seq2seq_slow(self):
        """Train for 10 epochs on the real student marian model and verify it improves."""
        # There is a missing call to __init__process_group somewhere
        output_dir = self.run_trainer(
            eval_steps=2, max_len="128", model_name=MARIAN_MODEL, num_train_epochs=10, distributed=False
        )

        # Check metrics
        logs = TrainerState.load_from_json(os.path.join(output_dir, "trainer_state.json")).log_history
        eval_metrics = [log for log in logs if "eval_loss" in log.keys()]
        first_step_stats = eval_metrics[0]
        last_step_stats = eval_metrics[-1]

        # fixed comment: the assertion checks bleu improved, i.e. the model learned *something*
        assert first_step_stats["eval_bleu"] < last_step_stats["eval_bleu"]  # model learned something
        assert isinstance(last_step_stats["eval_bleu"], float)

        # test if do_predict saves generations and metrics
        contents = os.listdir(output_dir)
        contents = {os.path.basename(p) for p in contents}
        assert "test_preds_seq2seq.txt" in contents
        assert "test_results.json" in contents

    def run_trainer(
        self,
        eval_steps: int,
        max_len: str,
        model_name: str,
        num_train_epochs: int,
        distributed: bool = False,
        extra_args_str: str = None,
    ):
        """Build and execute a run_seq2seq.py invocation; return its output dir.

        When ``distributed`` is True the script is launched via torch.distributed.launch
        in a subprocess; otherwise ``main()`` is called in-process with patched argv.
        """
        data_dir = self.examples_dir / "test_data/wmt_en_ro"
        output_dir = self.get_auto_remove_tmp_dir()
        # fix: the length arguments now use the max_len parameter instead of a
        # garbled hard-coded literal, so callers actually control sequence length
        args = f"""
            --model_name_or_path {model_name}
            --train_file {data_dir}/train.json
            --validation_file {data_dir}/val.json
            --test_file {data_dir}/test.json
            --output_dir {output_dir}
            --overwrite_output_dir
            --max_train_samples 8
            --max_val_samples 8
            --max_source_length {max_len}
            --max_target_length {max_len}
            --val_max_target_length {max_len}
            --do_train
            --do_eval
            --do_predict
            --num_train_epochs {num_train_epochs}
            --per_device_train_batch_size 4
            --per_device_eval_batch_size 4
            --learning_rate 3e-3
            --warmup_steps 8
            --evaluation_strategy steps
            --predict_with_generate
            --logging_steps 0
            --save_steps {eval_steps}
            --eval_steps {eval_steps}
            --group_by_length
            --label_smoothing_factor 0.1
            --adafactor
            --task translation
            --target_lang ro_RO
            --source_lang en_XX
        """.split()

        if extra_args_str is not None:
            args.extend(extra_args_str.split())

        if distributed:
            n_gpu = get_gpu_count()
            distributed_args = f"""
                -m torch.distributed.launch
                --nproc_per_node={n_gpu}
                {self.examples_dir_str}/seq2seq/run_seq2seq.py
            """.split()
            cmd = [sys.executable] + distributed_args + args
            execute_subprocess_async(cmd, env=self.get_env())
        else:
            # run in-process by pretending the args came from the command line
            testargs = ["run_seq2seq.py"] + args
            with patch.object(sys, "argv", testargs):
                main()

        return output_dir