test_trainer_callback.py 16.2 KB
Newer Older
Sylvain Gugger's avatar
Sylvain Gugger committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

15
16

import os
Sylvain Gugger's avatar
Sylvain Gugger committed
17
18
19
import shutil
import tempfile
import unittest
20
from unittest.mock import patch
Sylvain Gugger's avatar
Sylvain Gugger committed
21
22
23

from transformers import (
    DefaultFlowCallback,
24
    EarlyStoppingCallback,
25
    IntervalStrategy,
Sylvain Gugger's avatar
Sylvain Gugger committed
26
27
28
29
    PrinterCallback,
    ProgressCallback,
    Trainer,
    TrainerCallback,
30
    TrainerState,
Sylvain Gugger's avatar
Sylvain Gugger committed
31
32
33
34
    TrainingArguments,
    is_torch_available,
)
from transformers.testing_utils import require_torch
35
from transformers.trainer_callback import ExportableState
Sylvain Gugger's avatar
Sylvain Gugger committed
36
37
38


if is_torch_available():
39
    from transformers.trainer import DEFAULT_CALLBACKS, TRAINER_STATE_NAME
Sylvain Gugger's avatar
Sylvain Gugger committed
40
41
42
43

    from .test_trainer import RegressionDataset, RegressionModelConfig, RegressionPreTrainedModel


44
45
46
47
48
49
50
51
52
53
54
55
class MyTestExportableCallback(TrainerCallback, ExportableState):
    """Minimal stateful callback used by the checkpoint round-trip tests.

    The only exported value is the constructor argument ``my_test_state``. It is exposed under the
    ``"args"`` key of ``state()``. The duplicate-callback test expects each instance's value to survive
    a save/resume cycle.
    """

    def __init__(self, my_test_state="test"):
        self.my_test_state = my_test_state

    def state(self):
        # Constructor keyword arguments describing this callback; no extra attributes are exported.
        init_kwargs = {"my_test_state": self.my_test_state}
        return {"args": init_kwargs}


Stas Bekman's avatar
Stas Bekman committed
56
class MyTestTrainerCallback(TrainerCallback):
    """A callback that records, in order, the name of every event it receives.

    Each hook appends its own name (e.g. ``"on_step_end"``) to ``self.events``. Tests can then compare
    the observed event flow against an expected sequence.
    """

    def __init__(self, my_test_state="test"):
        self.events = []
        # Plain attribute only: this class does not implement ExportableState. test_stateful_mixed_callbacks
        # expects the default value to be present after resuming from a checkpoint.
        self.my_test_state = my_test_state

    def _record(self, event_name):
        # Append-only; every hook returns None, exactly as a hook that only appends would.
        self.events.append(event_name)

    def on_init_end(self, args, state, control, **kwargs):
        self._record("on_init_end")

    def on_train_begin(self, args, state, control, **kwargs):
        self._record("on_train_begin")

    def on_train_end(self, args, state, control, **kwargs):
        self._record("on_train_end")

    def on_epoch_begin(self, args, state, control, **kwargs):
        self._record("on_epoch_begin")

    def on_epoch_end(self, args, state, control, **kwargs):
        self._record("on_epoch_end")

    def on_step_begin(self, args, state, control, **kwargs):
        self._record("on_step_begin")

    def on_step_end(self, args, state, control, **kwargs):
        self._record("on_step_end")

    def on_evaluate(self, args, state, control, **kwargs):
        self._record("on_evaluate")

    def on_predict(self, args, state, control, **kwargs):
        self._record("on_predict")

    def on_save(self, args, state, control, **kwargs):
        self._record("on_save")

    def on_log(self, args, state, control, **kwargs):
        self._record("on_log")

    def on_prediction_step(self, args, state, control, **kwargs):
        self._record("on_prediction_step")


@require_torch
class TrainerCallbackTest(unittest.TestCase):
    # TrainingArguments overrides shared by the checkpoint/restore tests. With max_steps=2 and
    # save_steps=eval_steps=2, the single evaluation and the single save ("checkpoint-2") both happen at step 2.
    _CHECKPOINT_KWARGS = {
        "save_strategy": "steps",
        "eval_strategy": "steps",
        "save_steps": 2,
        "eval_steps": 2,
        "max_steps": 2,
    }

    def setUp(self):
        self.output_dir = tempfile.mkdtemp()

    def tearDown(self):
        shutil.rmtree(self.output_dir)

    def get_trainer(self, a=0, b=0, train_len=64, eval_len=64, callbacks=None, disable_tqdm=False, **kwargs):
        """Build a small regression ``Trainer`` that writes into ``self.output_dir``.

        Args:
            a, b: Forwarded to ``RegressionModelConfig``.
            train_len, eval_len: Lengths of the training and evaluation ``RegressionDataset``.
            callbacks: Extra callbacks (classes or instances) handed to ``Trainer``.
            disable_tqdm: Forwarded to ``TrainingArguments``.
            **kwargs: Any other ``TrainingArguments`` field.
        """
        # disable_tqdm in TrainingArguments has a flaky default since it depends on the level of logging. We make sure
        # it's set to False since the tests later on depend on its value.
        train_dataset = RegressionDataset(length=train_len)
        eval_dataset = RegressionDataset(length=eval_len)
        config = RegressionModelConfig(a=a, b=b)
        model = RegressionPreTrainedModel(config)

        # report_to=[] keeps reporting integrations out, so test_init_callback can check the exact callback list.
        args = TrainingArguments(self.output_dir, disable_tqdm=disable_tqdm, report_to=[], **kwargs)
        return Trainer(
            model,
            args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            callbacks=callbacks,
        )

    @staticmethod
    def _callbacks_of_type(trainer, callback_types):
        """Return the callbacks of ``trainer`` that are instances of ``callback_types``, in handler order."""
        return [cb for cb in trainer.callback_handler.callbacks if isinstance(cb, callback_types)]

    def _recorded_events(self, trainer):
        """Return the events recorded by the single ``MyTestTrainerCallback`` attached to ``trainer``."""
        # Look the recorder up by type rather than by position in the callback list.
        recorders = self._callbacks_of_type(trainer, MyTestTrainerCallback)
        self.assertEqual(len(recorders), 1)
        return recorders[0].events

    def check_callbacks_equality(self, cbs1, cbs2):
        """Assert that two callback lists hold the same callbacks, ignoring order.

        Entries may be callback classes or instances. A class matches any instance of that class. Two
        classes, or two instances, are compared with ``==`` (identity for instances without a custom ``__eq__``).
        """
        self.assertEqual(len(cbs1), len(cbs2))

        def class_name(cb):
            return cb.__name__ if isinstance(cb, type) else cb.__class__.__name__

        # Order doesn't matter
        cbs1 = sorted(cbs1, key=class_name)
        cbs2 = sorted(cbs2, key=class_name)

        for cb1, cb2 in zip(cbs1, cbs2):
            if isinstance(cb1, type) and isinstance(cb2, type):
                self.assertEqual(cb1, cb2)
            elif isinstance(cb1, type) and not isinstance(cb2, type):
                self.assertEqual(cb1, cb2.__class__)
            elif not isinstance(cb1, type) and isinstance(cb2, type):
                self.assertEqual(cb1.__class__, cb2)
            else:
                self.assertEqual(cb1, cb2)

    def get_expected_events(self, trainer):
        """Return the event sequence ``MyTestTrainerCallback`` should record during ``trainer.train()``.

        Replays the step/epoch schedule of logging, evaluation (``steps`` or ``epoch`` strategy) and saving
        from ``trainer.args`` and ``trainer.state``.
        """
        expected_events = ["on_init_end", "on_train_begin"]
        step = 0
        # Steps per epoch come from the *training* dataloader. This assumes gradient_accumulation_steps == 1,
        # which none of these tests change.
        train_dl_len = len(trainer.get_train_dataloader())
        evaluation_events = ["on_prediction_step"] * len(trainer.get_eval_dataloader()) + ["on_log", "on_evaluate"]
        for _ in range(trainer.state.num_train_epochs):
            expected_events.append("on_epoch_begin")
            for _ in range(train_dl_len):
                step += 1
                expected_events += ["on_step_begin", "on_step_end"]
                if step % trainer.args.logging_steps == 0:
                    expected_events.append("on_log")
                if trainer.args.eval_strategy == IntervalStrategy.STEPS and step % trainer.args.eval_steps == 0:
                    expected_events += evaluation_events
                if step % trainer.args.save_steps == 0 or step == trainer.state.max_steps:
                    expected_events.append("on_save")
            expected_events.append("on_epoch_end")
            if trainer.args.eval_strategy == IntervalStrategy.EPOCH:
                expected_events += evaluation_events
        expected_events += ["on_log", "on_train_end"]
        return expected_events

    def test_init_callback(self):
        trainer = self.get_trainer()
        expected_callbacks = DEFAULT_CALLBACKS.copy() + [ProgressCallback]
        self.check_callbacks_equality(trainer.callback_handler.callbacks, expected_callbacks)

        # Callbacks passed at init are added to the default callbacks
        trainer = self.get_trainer(callbacks=[MyTestTrainerCallback])
        expected_callbacks.append(MyTestTrainerCallback)
        self.check_callbacks_equality(trainer.callback_handler.callbacks, expected_callbacks)

        # TrainingArguments.disable_tqdm controls if use ProgressCallback or PrinterCallback
        trainer = self.get_trainer(disable_tqdm=True)
        expected_callbacks = DEFAULT_CALLBACKS.copy() + [PrinterCallback]
        self.check_callbacks_equality(trainer.callback_handler.callbacks, expected_callbacks)

    def test_add_remove_callback(self):
        expected_callbacks = DEFAULT_CALLBACKS.copy() + [ProgressCallback]
        trainer = self.get_trainer()

        # We can add, pop, or remove by class name
        trainer.remove_callback(DefaultFlowCallback)
        expected_callbacks.remove(DefaultFlowCallback)
        self.check_callbacks_equality(trainer.callback_handler.callbacks, expected_callbacks)

        trainer = self.get_trainer()
        cb = trainer.pop_callback(DefaultFlowCallback)
        self.assertEqual(cb.__class__, DefaultFlowCallback)
        self.check_callbacks_equality(trainer.callback_handler.callbacks, expected_callbacks)

        trainer.add_callback(DefaultFlowCallback)
        expected_callbacks.insert(0, DefaultFlowCallback)
        self.check_callbacks_equality(trainer.callback_handler.callbacks, expected_callbacks)

        # We can also add, pop, or remove by instance
        trainer = self.get_trainer()
        cb = trainer.callback_handler.callbacks[0]
        trainer.remove_callback(cb)
        expected_callbacks.remove(DefaultFlowCallback)
        self.check_callbacks_equality(trainer.callback_handler.callbacks, expected_callbacks)

        trainer = self.get_trainer()
        cb1 = trainer.callback_handler.callbacks[0]
        cb2 = trainer.pop_callback(cb1)
        self.assertEqual(cb1, cb2)
        self.check_callbacks_equality(trainer.callback_handler.callbacks, expected_callbacks)

        trainer.add_callback(cb1)
        expected_callbacks.insert(0, DefaultFlowCallback)
        self.check_callbacks_equality(trainer.callback_handler.callbacks, expected_callbacks)

    def test_event_flow(self):
        import warnings

        # Scope the filter to this test so it does not leak into the rest of the test session.
        with warnings.catch_warnings():
            # XXX: for now ignore scatter_gather warnings in this test since it's not relevant to what's being tested
            warnings.simplefilter(action="ignore", category=UserWarning)

            scenarios = [
                {},
                # Independent log/save/eval
                {"logging_steps": 5},
                {"save_steps": 5},
                {"eval_steps": 5, "eval_strategy": "steps"},
                {"eval_strategy": "epoch"},
                # A bit of everything
                {"logging_steps": 3, "save_steps": 10, "eval_steps": 5, "eval_strategy": "steps"},
            ]
            for overrides in scenarios:
                trainer = self.get_trainer(callbacks=[MyTestTrainerCallback], **overrides)
                trainer.train()
                self.assertEqual(
                    self._recorded_events(trainer),
                    self.get_expected_events(trainer),
                    msg=f"TrainingArguments overrides: {overrides}",
                )

            # warning should be emitted for duplicated callbacks
            with patch("transformers.trainer_callback.logger.warning") as warn_mock:
                self.get_trainer(callbacks=[MyTestTrainerCallback, MyTestTrainerCallback])
                # Fail clearly (instead of a TypeError on call_args=None) if no warning was logged.
                warn_mock.assert_called()
                self.assertIn(str(MyTestTrainerCallback), warn_mock.call_args[0][0])

    def test_stateful_callbacks(self):
        # Use something with non-defaults
        cb = EarlyStoppingCallback(early_stopping_patience=5, early_stopping_threshold=0.2)
        trainer = self.get_trainer(callbacks=[cb], load_best_model_at_end=True, **self._CHECKPOINT_KWARGS)
        trainer.train()

        # Create a new trainer with defaults
        trainer = self.get_trainer(
            callbacks=[EarlyStoppingCallback()],
            load_best_model_at_end=True,
            restore_callback_states_from_checkpoint=True,
            **self._CHECKPOINT_KWARGS,
        )
        # Load it back in and verify values
        checkpoint = os.path.join(self.output_dir, "checkpoint-2")
        trainer.train(resume_from_checkpoint=checkpoint)
        cb = self._callbacks_of_type(trainer, EarlyStoppingCallback)[0]
        self.assertEqual(cb.early_stopping_patience, 5)
        self.assertEqual(cb.early_stopping_threshold, 0.2)

    def test_stateful_mixed_callbacks(self):
        # Use two callbacks, one stateful one not
        # Use something with non-defaults
        cbs = [
            MyTestTrainerCallback(my_test_state="another value"),
            EarlyStoppingCallback(early_stopping_patience=5, early_stopping_threshold=0.2),
        ]
        trainer = self.get_trainer(callbacks=cbs, load_best_model_at_end=True, **self._CHECKPOINT_KWARGS)
        trainer.train()

        # Create a new trainer with defaults
        trainer = self.get_trainer(
            callbacks=[EarlyStoppingCallback(), MyTestTrainerCallback()],
            load_best_model_at_end=True,
            restore_callback_states_from_checkpoint=True,
            **self._CHECKPOINT_KWARGS,
        )
        # Load it back in and verify values
        checkpoint = os.path.join(self.output_dir, "checkpoint-2")
        trainer.train(resume_from_checkpoint=checkpoint)
        cbs = self._callbacks_of_type(trainer, (EarlyStoppingCallback, MyTestTrainerCallback))
        self.assertEqual(len(cbs), 2)
        # Select by type: restoring callback state may reorder the handler's callback list.
        (early_stopping,) = [cb for cb in cbs if isinstance(cb, EarlyStoppingCallback)]
        (my_test,) = [cb for cb in cbs if isinstance(cb, MyTestTrainerCallback)]
        self.assertEqual(early_stopping.early_stopping_patience, 5)
        self.assertEqual(early_stopping.early_stopping_threshold, 0.2)
        # MyTestTrainerCallback is not exportable, so it keeps the value it was constructed with.
        self.assertEqual(my_test.my_test_state, "test")

    def test_stateful_duplicate_callbacks(self):
        # Use something with non-defaults
        cbs = [MyTestExportableCallback("first"), MyTestExportableCallback("second")]
        trainer = self.get_trainer(callbacks=cbs, load_best_model_at_end=True, **self._CHECKPOINT_KWARGS)
        trainer.train()

        # Create a new trainer with defaults
        trainer = self.get_trainer(
            callbacks=[MyTestExportableCallback(), MyTestExportableCallback()],
            load_best_model_at_end=True,
            restore_callback_states_from_checkpoint=True,
            **self._CHECKPOINT_KWARGS,
        )
        # Load it back in and verify values
        checkpoint = os.path.join(self.output_dir, "checkpoint-2")
        trainer.train(resume_from_checkpoint=checkpoint)
        cbs = self._callbacks_of_type(trainer, MyTestExportableCallback)
        self.assertEqual(len(cbs), 2)
        self.assertEqual(cbs[0].my_test_state, "first")
        self.assertEqual(cbs[1].my_test_state, "second")

    def test_missing_stateful_callback(self):
        cb = EarlyStoppingCallback()
        trainer = self.get_trainer(callbacks=[cb], load_best_model_at_end=True, **self._CHECKPOINT_KWARGS)
        trainer.train()

        # Create a new trainer with defaults
        trainer = self.get_trainer(restore_callback_states_from_checkpoint=True, **self._CHECKPOINT_KWARGS)
        # Load it back in and verify values
        checkpoint = os.path.join(self.output_dir, "checkpoint-2")
        # warning should be emitted for not-present callbacks
        with patch("transformers.trainer.logger.warning") as warn_mock:
            trainer.train(resume_from_checkpoint=checkpoint)
            # Fail clearly (instead of a TypeError on call_args=None) if no warning was logged.
            warn_mock.assert_called()
            self.assertIn("EarlyStoppingCallback", warn_mock.call_args[0][0])

    def test_stateful_control(self):
        trainer = self.get_trainer(
            max_steps=2,
            save_strategy="steps",
            save_steps=2,
        )
        trainer.train()
        # Load it back in and verify values
        trainer = self.get_trainer(max_steps=2, restore_callback_states_from_checkpoint=True)
        checkpoint = os.path.join(self.output_dir, "checkpoint-2")
        trainer.state = TrainerState.load_from_json(os.path.join(checkpoint, TRAINER_STATE_NAME))
        trainer._load_callback_state()
        self.assertTrue(trainer.control.should_training_stop)