# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys
import unittest
from unittest.mock import patch

from transformers import BertTokenizer, EncoderDecoderModel
from transformers.file_utils import is_datasets_available
from transformers.integrations import is_fairscale_available
from transformers.testing_utils import (
    TestCasePlus,
    execute_subprocess_async,
    get_gpu_count,
    require_torch_multi_gpu,
    require_torch_non_multi_gpu,
    slow,
)
from transformers.trainer_callback import TrainerState
from transformers.trainer_utils import set_seed

from .finetune_trainer import Seq2SeqTrainingArguments, main
from .seq2seq_trainer import Seq2SeqTrainer


set_seed(42)
MARIAN_MODEL = "sshleifer/student_marian_en_ro_6_1"
MBART_TINY = "sshleifer/tiny-mbart"


# a candidate for testing_utils
def require_fairscale(test_case):
    """
    Decorator marking a test that requires fairscale
    """
    if not is_fairscale_available():
        return unittest.skip("test requires fairscale")(test_case)
    else:
        return test_case


class TestFinetuneTrainer(TestCasePlus):
    def finetune_trainer_quick(self, distributed=None, extra_args_str=None):
        output_dir = self.run_trainer(1, "12", MBART_TINY, 1, distributed, extra_args_str)
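        # a quick run: just check that an eval pass was logged and reported BLEU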
        logs = TrainerState.load_from_json(os.path.join(output_dir, "trainer_state.json")).log_history
        eval_metrics = [log for log in logs if "eval_loss" in log.keys()]
        first_step_stats = eval_metrics[0]
        assert "eval_bleu" in first_step_stats

    @require_torch_non_multi_gpu
    def test_finetune_trainer_no_dist(self):
        self.finetune_trainer_quick()

    # the following 2 tests verify that the trainer can handle distributed (DDP) and non-distributed (DP) runs with n_gpu > 1
    @require_torch_multi_gpu
    def test_finetune_trainer_dp(self):
        self.finetune_trainer_quick(distributed=False)

    @require_torch_multi_gpu
    def test_finetune_trainer_ddp(self):
        self.finetune_trainer_quick(distributed=True)

    @require_torch_multi_gpu
    @require_fairscale
    def test_finetune_trainer_ddp_sharded_ddp(self):
        self.finetune_trainer_quick(distributed=True, extra_args_str="--sharded_ddp")

    @require_torch_multi_gpu
    @require_fairscale
    def test_finetune_trainer_ddp_sharded_ddp_fp16(self):
        self.finetune_trainer_quick(distributed=True, extra_args_str="--sharded_ddp --fp16")

    @slow
    def test_finetune_trainer_slow(self):
        # there is a missing call to torch.distributed.init_process_group somewhere, so keep this run non-distributed
        output_dir = self.run_trainer(
            eval_steps=2, max_len="128", model_name=MARIAN_MODEL, num_train_epochs=10, distributed=False
        )

        # Check metrics
        logs = TrainerState.load_from_json(os.path.join(output_dir, "trainer_state.json")).log_history
        eval_metrics = [log for log in logs if "eval_loss" in log.keys()]
        first_step_stats = eval_metrics[0]
        last_step_stats = eval_metrics[-1]

        assert first_step_stats["eval_bleu"] < last_step_stats["eval_bleu"], "model learned nothing"
        assert isinstance(last_step_stats["eval_bleu"], float)

        # test if do_predict saves generations and metrics
        contents = os.listdir(output_dir)
        contents = {os.path.basename(p) for p in contents}
        assert "test_generations.txt" in contents
        assert "test_results.json" in contents

    @slow
    def test_finetune_bert2bert(self):
        if not is_datasets_available():
            return

        import datasets

        bert2bert = EncoderDecoderModel.from_encoder_decoder_pretrained("prajjwal1/bert-tiny", "prajjwal1/bert-tiny")
        tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

        # use the encoder's vocab size and BERT's [CLS]/[SEP] as decoder start / EOS tokens for generation
        bert2bert.config.vocab_size = bert2bert.config.encoder.vocab_size
        bert2bert.config.eos_token_id = tokenizer.sep_token_id
        bert2bert.config.decoder_start_token_id = tokenizer.cls_token_id
        bert2bert.config.max_length = 128

        train_dataset = datasets.load_dataset("cnn_dailymail", "3.0.0", split="train[:1%]")
        val_dataset = datasets.load_dataset("cnn_dailymail", "3.0.0", split="validation[:1%]")

        train_dataset = train_dataset.select(range(32))
        val_dataset = val_dataset.select(range(16))

        rouge = datasets.load_metric("rouge")

        batch_size = 4

        def _map_to_encoder_decoder_inputs(batch):
            # the tokenizer automatically adds the special tokens: [CLS] <text> [SEP]
            inputs = tokenizer(batch["article"], padding="max_length", truncation=True, max_length=512)
            outputs = tokenizer(batch["highlights"], padding="max_length", truncation=True, max_length=128)
            batch["input_ids"] = inputs.input_ids
            batch["attention_mask"] = inputs.attention_mask

            batch["decoder_input_ids"] = outputs.input_ids
            batch["labels"] = outputs.input_ids.copy()
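            # replace pad tokens in the labels with -100 so they are ignored by the loss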
            batch["labels"] = [
                [-100 if token == tokenizer.pad_token_id else token for token in labels] for labels in batch["labels"]
            ]
            batch["decoder_attention_mask"] = outputs.attention_mask

            assert all([len(x) == 512 for x in inputs.input_ids])
            assert all([len(x) == 128 for x in outputs.input_ids])

            return batch

        def _compute_metrics(pred):
            labels_ids = pred.label_ids
            pred_ids = pred.predictions

            # all unnecessary tokens are removed
            pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
            label_str = tokenizer.batch_decode(labels_ids, skip_special_tokens=True)

            rouge_output = rouge.compute(predictions=pred_str, references=label_str, rouge_types=["rouge2"])[
                "rouge2"
            ].mid

            return {
                "rouge2_precision": round(rouge_output.precision, 4),
                "rouge2_recall": round(rouge_output.recall, 4),
                "rouge2_fmeasure": round(rouge_output.fmeasure, 4),
            }

        # map train dataset
        train_dataset = train_dataset.map(
            _map_to_encoder_decoder_inputs,
            batched=True,
            batch_size=batch_size,
            remove_columns=["article", "highlights"],
        )
        train_dataset.set_format(
            type="torch",
            columns=["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask", "labels"],
        )

        # same for validation dataset
        val_dataset = val_dataset.map(
            _map_to_encoder_decoder_inputs,
            batched=True,
            batch_size=batch_size,
            remove_columns=["article", "highlights"],
        )
        val_dataset.set_format(
            type="torch",
            columns=["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask", "labels"],
        )

        output_dir = self.get_auto_remove_tmp_dir()

        training_args = Seq2SeqTrainingArguments(
            output_dir=output_dir,
            per_device_train_batch_size=batch_size,
            per_device_eval_batch_size=batch_size,
            predict_with_generate=True,
            evaluation_strategy="steps",
            do_train=True,
            do_eval=True,
            warmup_steps=0,
            eval_steps=2,
            logging_steps=2,
        )

        # instantiate trainer
        trainer = Seq2SeqTrainer(
            model=bert2bert,
            args=training_args,
            compute_metrics=_compute_metrics,
            train_dataset=train_dataset,
            eval_dataset=val_dataset,
        )

        # start training
        trainer.train()

    def run_trainer(
        self,
        eval_steps: int,
        max_len: str,
        model_name: str,
        num_train_epochs: int,
        distributed: bool = False,
        extra_args_str: str = None,
    ):
        data_dir = self.examples_dir / "seq2seq/test_data/wmt_en_ro"
        output_dir = self.get_auto_remove_tmp_dir()
        args = f"""
            --model_name_or_path {model_name}
            --data_dir {data_dir}
            --output_dir {output_dir}
            --overwrite_output_dir
            --n_train 8
            --n_val 8
            --max_source_length {max_len}
            --max_target_length {max_len}
            --val_max_target_length {max_len}
            --do_train
            --do_eval
            --do_predict
            --num_train_epochs {str(num_train_epochs)}
            --per_device_train_batch_size 4
            --per_device_eval_batch_size 4
            --learning_rate 3e-3
            --warmup_steps 8
            --evaluation_strategy steps
            --predict_with_generate
            --logging_steps 0
            --save_steps {str(eval_steps)}
            --eval_steps {str(eval_steps)}
            --sortish_sampler
            --label_smoothing 0.1
            --adafactor
            --task translation
            --tgt_lang ro_RO
            --src_lang en_XX
        """.split()
        # --eval_beams  2

        if extra_args_str is not None:
            args.extend(extra_args_str.split())

        if distributed:
            n_gpu = get_gpu_count()
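            # launch finetune_trainer.py under torch.distributed.launch with one process per GPU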
            distributed_args = f"""
                -m torch.distributed.launch
                --nproc_per_node={n_gpu}
                {self.test_file_dir}/finetune_trainer.py
            """.split()
            cmd = [sys.executable] + distributed_args + args
            execute_subprocess_async(cmd, env=self.get_env())
        else:
            testargs = ["finetune_trainer.py"] + args
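            # run main() in-process with sys.argv patched to the CLI arguments built above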
            with patch.object(sys, "argv", testargs):
                main()

        return output_dir