test_trainer.py 47.7 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# coding=utf-8
# Copyright 2018 the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

16
import dataclasses
17
import gc
18
19
import os
import tempfile
Julien Chaumond's avatar
Julien Chaumond committed
20
21
import unittest

Sylvain Gugger's avatar
Sylvain Gugger committed
22
23
import numpy as np

24
from transformers import AutoTokenizer, IntervalStrategy, PretrainedConfig, TrainingArguments, is_torch_available
25
from transformers.file_utils import WEIGHTS_NAME
26
from transformers.testing_utils import (
27
    TestCasePlus,
28
    get_tests_dir,
29
    require_datasets,
30
    require_optuna,
31
    require_ray,
32
33
34
    require_sentencepiece,
    require_tokenizers,
    require_torch,
35
    require_torch_gpu,
36
    require_torch_multi_gpu,
37
38
39
    slow,
)
from transformers.utils.hp_naming import TrialShortNamer
Julien Chaumond's avatar
Julien Chaumond committed
40
41
42
43


if is_torch_available():
    import torch
44
45
    from torch.utils.data import IterableDataset

Julien Chaumond's avatar
Julien Chaumond committed
46
47
    from transformers import (
        AutoModelForSequenceClassification,
48
        EarlyStoppingCallback,
Julien Chaumond's avatar
Julien Chaumond committed
49
50
        GlueDataset,
        GlueDataTrainingArguments,
51
52
        GPT2Config,
        GPT2LMHeadModel,
53
        LineByLineTextDataset,
54
        PreTrainedModel,
55
        Trainer,
56
        TrainerState,
Julien Chaumond's avatar
Julien Chaumond committed
57
    )
58
    from transformers.modeling_utils import unwrap_model
Julien Chaumond's avatar
Julien Chaumond committed
59
60


61
# Location of the `sample_text.txt` fixture inside the tests directory.
PATH_SAMPLE_TEXT = f"{get_tests_dir()}/fixtures/sample_text.txt"
Julien Chaumond's avatar
Julien Chaumond committed
62
63


Sylvain Gugger's avatar
Sylvain Gugger committed
64
class RegressionDataset:
    """Synthetic 1-D regression data: noisy samples of ``y = a * x + b``.

    One target array is generated per entry of ``label_names``; all targets share the same inputs ``x`` but each
    gets its own noise draw. Indexing returns a dict with one entry per label name plus ``"input_x"``.
    """

    def __init__(self, a=2, b=3, length=64, seed=42, label_names=None):
        # Seeds the global NumPy RNG, so two datasets built with the same arguments hold the same values.
        np.random.seed(seed)
        if label_names is None:
            label_names = ["labels"]
        self.label_names = label_names
        self.length = length
        self.x = np.random.normal(size=(length,)).astype(np.float32)
        targets = []
        for _ in self.label_names:
            noisy = a * self.x + b + np.random.normal(scale=0.1, size=(length,))
            targets.append(noisy.astype(np.float32))
        self.ys = targets

    def __len__(self):
        return self.length

    def __getitem__(self, i):
        # Label entries come first, the model input last.
        sample = dict(zip(self.label_names, (y[i] for y in self.ys)))
        sample["input_x"] = self.x[i]
        return sample
Sylvain Gugger's avatar
Sylvain Gugger committed
80
81


82
83
84
85
86
87
@dataclasses.dataclass
class RegressionTrainingArguments(TrainingArguments):
    # `TrainingArguments` extended with the two regression coefficients, so they are carried along with the
    # trainer's arguments. Both default to 0.0.
    a: float = 0.0
    b: float = 0.0


88
89
90
91
92
93
94
95
96
97
98
99
class RepeatDataset:
    """Dataset of ``length`` items that hands back the same sequence ``x`` for every index."""

    def __init__(self, x, length=64):
        self.length = length
        self.x = x

    def __len__(self):
        return self.length

    def __getitem__(self, i):
        # The index is ignored; inputs and labels are the very same object.
        return dict(input_ids=self.x, labels=self.x)


100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
class DynamicShapesDataset:
    """Dataset whose samples are 1-D arrays of varying length.

    Lengths are drawn at random per batch: every ``batch_size`` consecutive samples share the same size, which
    keeps batching easy. Inputs (``xs``) and labels (``ys``) use the same sizes but independent values.
    """

    def __init__(self, length=64, seed=42, batch_size=8):
        self.length = length
        np.random.seed(seed)
        size_per_batch = np.random.randint(1, 20, (length // batch_size,))
        size_per_sample = size_per_batch.repeat(batch_size)
        self.xs = [np.random.normal(size=(n,)) for n in size_per_sample]
        self.ys = [np.random.normal(size=(n,)) for n in size_per_sample]

    def __len__(self):
        return self.length

    def __getitem__(self, i):
        return dict(input_x=self.xs[i], labels=self.ys[i])


Sylvain Gugger's avatar
Sylvain Gugger committed
116
117
118
119
120
121
122
123
class AlmostAccuracy:
    """Metric callable: fraction of predictions lying within ``thresh`` of their label."""

    def __init__(self, thresh=0.25):
        self.thresh = thresh

    def __call__(self, eval_pred):
        predictions, labels = eval_pred
        error = np.abs(predictions - labels)
        close_enough = (error <= self.thresh).astype(np.float32)
        # `.item()` turns the NumPy scalar into a plain Python float.
        return {"accuracy": close_enough.mean().item()}
124

Julien Chaumond's avatar
Julien Chaumond committed
125

126
127
128
129
130
131
class RegressionModelConfig(PretrainedConfig):
    """Configuration of the regression test model: coefficients ``a`` and ``b`` plus the ``double_output`` flag."""

    def __init__(self, a=0, b=0, double_output=False, **kwargs):
        super().__init__(**kwargs)
        self.a, self.b = a, b
        self.double_output = double_output
        # Always 1, independently of the constructor arguments.
        self.hidden_size = 1
133
134


135
136
137
if is_torch_available():

    class SampleIterableDataset(IterableDataset):
        """Iterable-style wrapper that streams the samples of a freshly built `RegressionDataset`."""

        def __init__(self, a=2, b=3, length=64, seed=42, label_names=None):
            dataset_kwargs = dict(a=a, b=b, length=length, seed=seed, label_names=label_names)
            self.dataset = RegressionDataset(**dataset_kwargs)

        def __iter__(self):
            # Walk the wrapped dataset in index order.
            position, total = 0, len(self.dataset)
            while position < total:
                yield self.dataset[position]
                position += 1
144

Sylvain Gugger's avatar
Sylvain Gugger committed
145
    class RegressionModel(torch.nn.Module):
        """Plain `torch.nn.Module` computing ``y = input_x * a + b`` with learnable scalars ``a`` and ``b``.

        `forward` returns a tuple: the prediction (twice when ``double_output`` is set), preceded by the MSE loss
        when labels are given.
        """

        def __init__(self, a=0, b=0, double_output=False):
            super().__init__()
            self.a = torch.nn.Parameter(torch.tensor(a).float())
            self.b = torch.nn.Parameter(torch.tensor(b).float())
            self.double_output = double_output
            self.config = None

        def forward(self, input_x, labels=None, **kwargs):
            prediction = input_x * self.a + self.b
            outputs = (prediction, prediction) if self.double_output else (prediction,)
            if labels is not None:
                # The loss always goes first in the returned tuple.
                loss = torch.nn.functional.mse_loss(prediction, labels)
                outputs = (loss,) + outputs
            return outputs
Sylvain Gugger's avatar
Sylvain Gugger committed
159

160
161
162
163
164
165
166
    class RegressionDictModel(torch.nn.Module):
        """Same linear model as `RegressionModel`, but `forward` returns a dict instead of a tuple.

        The dict always has an ``"output"`` entry and gains a ``"loss"`` entry when labels are given.
        """

        def __init__(self, a=0, b=0):
            super().__init__()
            self.a = torch.nn.Parameter(torch.tensor(a).float())
            self.b = torch.nn.Parameter(torch.tensor(b).float())
            self.config = None

        def forward(self, input_x, labels=None, **kwargs):
            prediction = input_x * self.a + self.b
            if labels is None:
                return {"output": prediction}
            return {"output": prediction, "loss": torch.nn.functional.mse_loss(prediction, labels)}

174
175
176
177
178
179
180
181
182
183
    class RegressionPreTrainedModel(PreTrainedModel):
        """`PreTrainedModel` version of the linear regression model, configured by a `RegressionModelConfig`.

        `forward` returns a tuple: the prediction (twice when ``double_output`` is set), preceded by the MSE loss
        when labels are given.
        """

        config_class = RegressionModelConfig
        base_model_prefix = "regression"

        def __init__(self, config):
            super().__init__(config)
            self.a = torch.nn.Parameter(torch.tensor(config.a).float())
            self.b = torch.nn.Parameter(torch.tensor(config.b).float())
            self.double_output = config.double_output

        def forward(self, input_x, labels=None, **kwargs):
            prediction = input_x * self.a + self.b
            outputs = (prediction, prediction) if self.double_output else (prediction,)
            if labels is not None:
                # The loss always goes first in the returned tuple.
                loss = torch.nn.functional.mse_loss(prediction, labels)
                outputs = (loss,) + outputs
            return outputs

191
192
193
194
195
196
197
198
199
200
201
202
203
204
    class TstLayer(torch.nn.Module):
        """Small residual block: two Linear + ReLU sub-layers, two layer norms and an extra learned bias.

        Args:
            hidden_size: size of the last dimension of the input, kept unchanged by the layer.
        """

        def __init__(self, hidden_size):
            super().__init__()
            self.linear1 = torch.nn.Linear(hidden_size, hidden_size)
            self.ln1 = torch.nn.LayerNorm(hidden_size)
            self.linear2 = torch.nn.Linear(hidden_size, hidden_size)
            self.ln2 = torch.nn.LayerNorm(hidden_size)
            self.bias = torch.nn.Parameter(torch.zeros(hidden_size))

        def forward(self, x):
            h = self.ln1(torch.nn.functional.relu(self.linear1(x)))
            # Feed the first sub-layer's output into the second one. Using `x` here would discard `h` and leave
            # `linear1` / `ln1` out of the computation entirely.
            h = torch.nn.functional.relu(self.linear2(h))
            # Residual connection around both sub-layers, plus the standalone bias, then the final layer norm.
            return self.ln2(x + h + self.bias)

205
    def get_regression_trainer(a=0, b=0, double_output=False, train_len=64, eval_len=64, pretrained=True, **kwargs):
        """Build a `Trainer` around a regression model and two `RegressionDataset`s.

        Args:
            a, b: initial coefficients of the model, also stored on the training arguments.
            double_output: forwarded to the model.
            train_len, eval_len: sizes of the training and evaluation datasets.
            pretrained: build a `RegressionPreTrainedModel` when true, a plain `RegressionModel` otherwise.
            **kwargs: `compute_metrics`, `data_collator`, `optimizers` and `model_init` go to the `Trainer`,
                `output_dir` (default ``"./regression"``) and everything else to `RegressionTrainingArguments`.
        """
        # `label_names` is only read, not popped: it is also passed on to the training arguments below.
        label_names = kwargs.get("label_names", None)
        train_dataset = RegressionDataset(length=train_len, label_names=label_names)
        eval_dataset = RegressionDataset(length=eval_len, label_names=label_names)

        if not pretrained:
            model = RegressionModel(a=a, b=b, double_output=double_output)
        else:
            model = RegressionPreTrainedModel(RegressionModelConfig(a=a, b=b, double_output=double_output))

        # Split off the options meant for the Trainer itself; what is left in kwargs are training arguments.
        trainer_options = {
            "compute_metrics": kwargs.pop("compute_metrics", None),
            "data_collator": kwargs.pop("data_collator", None),
            "optimizers": kwargs.pop("optimizers", (None, None)),
        }
        output_dir = kwargs.pop("output_dir", "./regression")
        trainer_options["model_init"] = kwargs.pop("model_init", None)

        args = RegressionTrainingArguments(output_dir, a=a, b=b, **kwargs)
        return Trainer(model, args, train_dataset=train_dataset, eval_dataset=eval_dataset, **trainer_options)

232

233
class TrainerIntegrationCommon:
    """Assertion helpers shared by the trainer integration tests.

    Meant to be mixed into a test case class: every method relies on the `assert*` methods of `self`.
    """

    def check_saved_checkpoints(self, output_dir, freq, total, is_pretrained=True):
        """Assert that every expected `checkpoint-<step>` folder exists and holds all expected files.

        Args:
            output_dir: directory the checkpoints were saved into.
            freq: number of steps between two checkpoints.
            total: total number of steps; the steps `freq, 2 * freq, ...` strictly below it are checked.
            is_pretrained: when true, a `config.json` is expected in each checkpoint as well.
        """
        file_list = [WEIGHTS_NAME, "training_args.bin", "optimizer.pt", "scheduler.pt", "trainer_state.json"]
        if is_pretrained:
            file_list.append("config.json")
        for step in range(freq, total, freq):
            checkpoint = os.path.join(output_dir, f"checkpoint-{step}")
            self.assertTrue(os.path.isdir(checkpoint))
            for filename in file_list:
                self.assertTrue(os.path.isfile(os.path.join(checkpoint, filename)))

    def check_best_model_has_been_loaded(
        self, output_dir, freq, total, trainer, metric, greater_is_better=False, is_pretrained=True
    ):
        """Assert that `trainer.model` holds the weights of the checkpoint with the best value of `metric`.

        Args:
            output_dir: directory the checkpoints were saved into.
            freq: number of steps between two checkpoints.
            total: total number of training steps.
            trainer: the trainer whose model is checked; it is also used to re-run an evaluation.
            metric: key of the metric to look up in the log history.
            greater_is_better: whether the best value is the maximum (true) or the minimum (false).
            is_pretrained: load the checkpoint with `RegressionPreTrainedModel.from_pretrained` when true, as a
                plain state dict into a `RegressionModel` otherwise.
        """
        # The log history is read from the last checkpoint written during training.
        checkpoint = os.path.join(output_dir, f"checkpoint-{(total // freq) * freq}")
        log_history = TrainerState.load_from_json(os.path.join(checkpoint, "trainer_state.json")).log_history

        # NOTE(review): this assumes every log entry contains `metric` and that the i-th entry belongs to
        # checkpoint `(i + 1) * freq` -- confirm against the logging setup of the calling test.
        values = [d[metric] for d in log_history]
        best_value = max(values) if greater_is_better else min(values)
        best_checkpoint = (values.index(best_value) + 1) * freq
        checkpoint = os.path.join(output_dir, f"checkpoint-{best_checkpoint}")
        if is_pretrained:
            best_model = RegressionPreTrainedModel.from_pretrained(checkpoint)
            best_model.to(trainer.args.device)
        else:
            best_model = RegressionModel()
            state_dict = torch.load(os.path.join(checkpoint, WEIGHTS_NAME))
            best_model.load_state_dict(state_dict)
            best_model.to(trainer.args.device)
        self.assertTrue(torch.allclose(best_model.a, trainer.model.a))
        self.assertTrue(torch.allclose(best_model.b, trainer.model.b))

        # Evaluating the model currently held by the trainer must give back the best logged value.
        metrics = trainer.evaluate()
        self.assertEqual(metrics[metric], best_value)

    def check_trainer_state_are_the_same(self, trainer_state, trainer_state1):
        """Assert that two trainer states (as dicts) are equal, ignoring timing entries of the log history.

        Neither argument is modified, including the log dicts nested inside them.
        """
        # "log_history" is popped, so work on (shallow) copies of the two top-level dicts.
        state = trainer_state.copy()
        state1 = trainer_state1.copy()
        # Log history may contain different logs for the time metrics (after resuming a training).
        log_history = state.pop("log_history", None)
        log_history1 = state1.pop("log_history", None)
        self.assertEqual(state, state1)
        timing_keys = ("train_runtime", "train_samples_per_second")
        for log, log1 in zip(log_history, log_history1):
            # Compare filtered copies: the shallow copies above still share the log dicts with the caller, so
            # popping from them would silently alter the states passed in.
            comparable = {key: value for key, value in log.items() if key not in timing_keys}
            comparable1 = {key: value for key, value in log1.items() if key not in timing_keys}
            self.assertEqual(comparable, comparable1)

283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306

@require_torch
@require_sentencepiece
@require_tokenizers
class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
    def setUp(self):
        super().setUp()
        # Default number of epochs / batch size, read from a throwaway set of default arguments.
        default_args = TrainingArguments(".")
        self.n_epochs = default_args.num_train_epochs
        self.batch_size = default_args.train_batch_size

        # Reference weights obtained with the default seed...
        reference = get_regression_trainer(learning_rate=0.1)
        reference.train()
        self.default_trained_model = (reference.model.a, reference.model.b)

        # ... and with an alternate seed.
        alternate = get_regression_trainer(learning_rate=0.1, seed=314)
        alternate.train()
        self.alternate_trained_model = (alternate.model.a, alternate.model.b)

    def check_trained_model(self, model, alternate_seed=False):
        # Compares `model` with the reference weights stored in setUp (trainings run with learning_rate = 0.1),
        # picking the alternate-seed reference when `alternate_seed` is set.
        if alternate_seed:
            expected_a, expected_b = self.alternate_trained_model
        else:
            expected_a, expected_b = self.default_trained_model
        self.assertTrue(torch.allclose(model.a, expected_a))
        self.assertTrue(torch.allclose(model.b, expected_b))

307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
    def test_trainer_works_with_dict(self):
        # Edge case: Apex with mode O2 will change our models to return dicts. Run training, evaluation and
        # prediction with a dict-returning model to check nothing breaks.
        train_data = RegressionDataset()
        eval_data = RegressionDataset()
        trainer = Trainer(
            RegressionDictModel(),
            TrainingArguments("./regression"),
            train_dataset=train_data,
            eval_dataset=eval_data,
        )
        trainer.train()
        trainer.evaluate()
        trainer.predict(eval_data)

    def test_evaluation_with_keys_to_drop(self):
        tiny_config = GPT2Config(vocab_size=100, n_positions=128, n_ctx=128, n_embd=32, n_layer=3, n_head=4)
        model = GPT2LMHeadModel(tiny_config)
        token_ids = torch.randint(0, 100, (128,))
        dataset = RepeatDataset(token_ids)
        trainer = Trainer(model, TrainingArguments("./test"), eval_dataset=dataset)

        # By default the past_key_values are removed, so the predictions are a single array.
        default_output = trainer.predict(dataset)
        self.assertIsInstance(default_output.predictions, np.ndarray)

        # We can still get them by setting ignore_keys to []: the predictions then come back as a pair.
        full_output = trainer.predict(dataset, ignore_keys=[])
        self.assertIsInstance(full_output.predictions, tuple)
        self.assertEqual(len(full_output.predictions), 2)

334
335
336
337
    def test_training_arguments_are_left_untouched(self):
        # After a training run, the trainer's arguments must still match freshly built default arguments.
        trainer = get_regression_trainer()
        trainer.train()
        reference = TrainingArguments("./regression").to_dict()
        trained = trainer.args.to_dict()
        for name, value in reference.items():
            # Logging dir can be slightly different as they default to something with the time.
            if name == "logging_dir":
                continue
            self.assertEqual(value, trained[name])
343

Sylvain Gugger's avatar
Sylvain Gugger committed
344
345
346
347
    def test_reproducible_training(self):
        # First pass: training worked, the model trained and the default seed reproduces the reference weights.
        # Second pass: a different seed gets different (but equally reproducible) results.
        for extra_args, alternate in (({}, False), ({"seed": 314}, True)):
            trainer = get_regression_trainer(learning_rate=0.1, **extra_args)
            trainer.train()
            self.check_trained_model(trainer.model, alternate_seed=alternate)
Sylvain Gugger's avatar
Sylvain Gugger committed
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371

    def test_number_of_steps_in_training(self):
        # Regular training has n_epochs * len(train_dl) steps
        steps = get_regression_trainer(learning_rate=0.1).train().global_step
        self.assertEqual(steps, self.n_epochs * 64 / self.batch_size)

        # Check passing num_train_epochs works (and a float version too):
        steps = get_regression_trainer(learning_rate=0.1, num_train_epochs=1.5).train().global_step
        self.assertEqual(steps, int(1.5 * 64 / self.batch_size))

        # If we pass a max_steps, num_train_epochs is ignored
        steps = get_regression_trainer(learning_rate=0.1, max_steps=10).train().global_step
        self.assertEqual(steps, 10)

    def test_train_and_eval_dataloaders(self):
        # Batch sizes are given per device, so the dataloaders are expected to use them times the GPU count.
        n_gpu = max(1, torch.cuda.device_count())
        train_batch, eval_batch = 16 * n_gpu, 32 * n_gpu

        trainer = get_regression_trainer(learning_rate=0.1, per_device_train_batch_size=16)
        self.assertEqual(trainer.get_train_dataloader().batch_size, train_batch)
        trainer = get_regression_trainer(learning_rate=0.1, per_device_eval_batch_size=16)
        self.assertEqual(trainer.get_eval_dataloader().batch_size, 16 * n_gpu)

        # Check drop_last works: dataset sizes are picked so that the last batch is incomplete.
        uneven = dict(
            train_len=66, eval_len=74, learning_rate=0.1, per_device_train_batch_size=16, per_device_eval_batch_size=32
        )
        trainer = get_regression_trainer(**uneven)
        self.assertEqual(len(trainer.get_train_dataloader()), 66 // train_batch + 1)
        self.assertEqual(len(trainer.get_eval_dataloader()), 74 // eval_batch + 1)

        trainer = get_regression_trainer(dataloader_drop_last=True, **uneven)
        self.assertEqual(len(trainer.get_train_dataloader()), 66 // train_batch)
        self.assertEqual(len(trainer.get_eval_dataloader()), 74 // eval_batch)

        # Check passing a new dataset for evaluation works
        new_eval_dataset = RegressionDataset(length=128)
        self.assertEqual(len(trainer.get_eval_dataloader(new_eval_dataset)), 128 // eval_batch)
Sylvain Gugger's avatar
Sylvain Gugger committed
399

400
401
402
403
404
405
    @require_torch_multi_gpu
    def test_data_is_not_parallelized_when_model_is_parallel(self):
        fake_parallel_model = RegressionModel()
        # Make the Trainer believe it's a parallelized model
        fake_parallel_model.is_parallelizable = True
        fake_parallel_model.model_parallel = True
        trainer = Trainer(
            fake_parallel_model,
            TrainingArguments("./regression", per_device_train_batch_size=16, per_device_eval_batch_size=16),
            train_dataset=RegressionDataset(),
            eval_dataset=RegressionDataset(),
        )
        # Check the Trainer was fooled
        self.assertTrue(trainer.is_model_parallel)
        self.assertEqual(trainer.args.n_gpu, 1)

        # The batch size of the training and evaluation dataloaders should be 16, not 16 * n_gpu
        for loader in (trainer.get_train_dataloader(), trainer.get_eval_dataloader()):
            self.assertEqual(loader.batch_size, 16)
            self.assertEqual(len(loader), 64 // 16)

Sylvain Gugger's avatar
Sylvain Gugger committed
418
419
420
421
    def test_evaluate(self):
        # Run once with the default evaluation length, then with a number of elements that is not a round
        # multiple of the batch size.
        for extra_args in ({}, {"eval_len": 66}):
            trainer = get_regression_trainer(a=1.5, b=2.5, compute_metrics=AlmostAccuracy(), **extra_args)
            results = trainer.evaluate()

            # Recompute loss and accuracy by hand from the known coefficients.
            inputs = trainer.eval_dataset.x
            targets = trainer.eval_dataset.ys[0]
            pred = 1.5 * inputs + 2.5
            self.assertAlmostEqual(results["eval_loss"], ((pred - targets) ** 2).mean())
            expected_acc = AlmostAccuracy()((pred, targets))["accuracy"]
            self.assertAlmostEqual(results["eval_accuracy"], expected_acc)

    def test_predict(self):
        # With known coefficients and no training, predictions must equal a * x + b.
        trainer = get_regression_trainer(a=1.5, b=2.5)
        preds = trainer.predict(trainer.eval_dataset).predictions
        x = trainer.eval_dataset.x
        self.assertTrue(np.allclose(preds, 1.5 * x + 2.5))

        # With a number of elements not a round multiple of the batch size
        trainer = get_regression_trainer(a=1.5, b=2.5, eval_len=66)
        preds = trainer.predict(trainer.eval_dataset).predictions
        x = trainer.eval_dataset.x
        self.assertTrue(np.allclose(preds, 1.5 * x + 2.5))

        # With more than one output of the model
        trainer = get_regression_trainer(a=1.5, b=2.5, double_output=True)
        preds = trainer.predict(trainer.eval_dataset).predictions
        x = trainer.eval_dataset.x
        # assertEqual rather than assertTrue(len(preds), 2): the latter takes 2 as the failure message and passes
        # for any non-empty result, so the number of outputs was never actually checked.
        self.assertEqual(len(preds), 2)
        self.assertTrue(np.allclose(preds[0], 1.5 * x + 2.5))
        self.assertTrue(np.allclose(preds[1], 1.5 * x + 2.5))

        # With more than one output/label of the model
        trainer = get_regression_trainer(a=1.5, b=2.5, double_output=True, label_names=["labels", "labels_2"])
        outputs = trainer.predict(trainer.eval_dataset)
        preds = outputs.predictions
        labels = outputs.label_ids
        x = trainer.eval_dataset.x
        self.assertEqual(len(preds), 2)
        self.assertTrue(np.allclose(preds[0], 1.5 * x + 2.5))
        self.assertTrue(np.allclose(preds[1], 1.5 * x + 2.5))
        self.assertTrue(np.array_equal(labels[0], trainer.eval_dataset.ys[0]))
        self.assertTrue(np.array_equal(labels[1], trainer.eval_dataset.ys[1]))

472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
    def test_dynamic_shapes(self):
        eval_dataset = DynamicShapesDataset(batch_size=self.batch_size)
        model = RegressionModel(a=2, b=1)

        # Same checks without, then with, eval accumulation.
        for extra_args in ({}, {"eval_accumulation_steps": 2}):
            args = TrainingArguments("./regression", **extra_args)
            trainer = Trainer(model, args, eval_dataset=eval_dataset)

            # Check evaluation can run to completion
            trainer.evaluate()

            # Check predictions: each row starts with the real values and is filled with -100 afterwards.
            preds = trainer.predict(eval_dataset)
            for expected, seen in zip(eval_dataset.ys, preds.label_ids):
                size = expected.shape[0]
                self.assertTrue(np.array_equal(expected, seen[:size]))
                self.assertTrue(np.all(seen[size:] == -100))

            for inputs, seen in zip(eval_dataset.xs, preds.predictions):
                size = inputs.shape[0]
                self.assertTrue(np.array_equal(2 * inputs + 1, seen[:size]))
                self.assertTrue(np.all(seen[size:] == -100))

508
    @require_datasets
    def test_trainer_with_datasets(self):
        import datasets

        np.random.seed(42)
        x = np.random.normal(size=(64,)).astype(np.float32)
        y = 2.0 * x + 3.0 + np.random.normal(scale=0.1, size=(64,))
        train_dataset = datasets.Dataset.from_dict({"input_x": x, "label": y})
        args = TrainingArguments("./regression", learning_rate=0.1)

        def train_and_check(dataset):
            # Train a fresh model on `dataset` and compare it with the reference weights from setUp.
            trainer = Trainer(RegressionModel(), args, train_dataset=dataset)
            trainer.train()
            self.check_trained_model(trainer.model)

        # Base training. Should have the same results as test_reproducible_training
        train_and_check(train_dataset)

        # Can return tensors.
        train_dataset.set_format(type="torch", dtype=torch.float32)
        train_and_check(train_dataset)

        # Adding one column not used by the model should have no impact
        z = np.random.normal(size=(64,)).astype(np.float32)
        train_dataset = datasets.Dataset.from_dict({"input_x": x, "label": y, "extra": z})
        train_and_check(train_dataset)

    def test_custom_optimizer(self):
        # Train with a user-supplied SGD optimizer and a constant schedule instead of the Trainer defaults.
        dataset = RegressionDataset()
        args = TrainingArguments("./regression")
        model = RegressionModel()
        sgd = torch.optim.SGD(model.parameters(), lr=1.0)
        constant_schedule = torch.optim.lr_scheduler.LambdaLR(sgd, lr_lambda=lambda x: 1.0)
        trainer = Trainer(model, args, train_dataset=dataset, optimizers=(sgd, constant_schedule))
        trainer.train()

        # The resulting weights must differ from the reference run stored in setUp...
        default_a, default_b = self.default_trained_model
        for trained, reference in ((trainer.model.a, default_a), (trainer.model.b, default_b)):
            self.assertFalse(torch.allclose(trained, reference))
        # ... and the learning rate of the custom optimizer must still be the one it was created with.
        self.assertEqual(trainer.optimizer.state_dict()["param_groups"][0]["lr"], 1.0)

    def test_model_init(self):
        dataset = RegressionDataset()
        training_args = TrainingArguments("./regression", learning_rate=0.1)
        trainer = Trainer(args=training_args, train_dataset=dataset, model_init=lambda: RegressionModel())

        # Each call to train should restart from scratch, so the second run must lead to the same results as
        # the first one.
        for _ in range(2):
            trainer.train()
            self.check_trained_model(trainer.model)

        # Re-training should again restart from scratch, this time using the new seed.
        trainer.args.seed = 314
        trainer.train()
        self.check_trained_model(trainer.model, alternate_seed=True)

569
570
571
572
573
574
575
576
    def test_save_checkpoints(self):
        # First with a PreTrainedModel, then with a regular model that is not a PreTrainedModel.
        for pretrained in (True, False):
            with tempfile.TemporaryDirectory() as tmpdir:
                trainer = get_regression_trainer(output_dir=tmpdir, save_steps=5, pretrained=pretrained)
                trainer.train()
                total_steps = int(self.n_epochs * 64 / self.batch_size)
                self.check_saved_checkpoints(tmpdir, 5, total_steps, pretrained)

581
582
583
584
585
586
587
588
    def test_gradient_accumulation(self):
        # Training with half the batch size but accumulation steps as 2 should give the same results as the
        # reference run stored in setUp.
        accumulation_args = dict(gradient_accumulation_steps=2, per_device_train_batch_size=4, learning_rate=0.1)
        trainer = get_regression_trainer(**accumulation_args)
        trainer.train()
        self.check_trained_model(trainer.model)

589
590
591
592
593
594
595
596
597
598
599
600
601
    @require_torch_multi_gpu
    def test_run_seq2seq_double_train_wrap_once(self):
        # Test that we don't wrap the model more than once, for example DataParallel(DataParallel(model)).
        # Wrapping primarily happens on multi-gpu setups, hence the multi-gpu requirement.
        trainer = get_regression_trainer()
        wrapped_models = []
        for _ in range(2):
            trainer.train()
            wrapped_models.append(trainer.model_wrapped)
        model_wrapped_before, model_wrapped_after = wrapped_models
        self.assertIs(model_wrapped_before, model_wrapped_after, "should be not wrapped twice")

602
603
604
605
606
607
    def test_can_resume_training(self):
        if torch.cuda.device_count() > 2:
            # This test will fail for more than 2 GPUs since the batch size will get bigger and with the number of
            # save_steps, the checkpoint will resume training at epoch 2 or more (so the data seen by the model
            # won't be the same since the training dataloader is shuffled).
            return

        # Run the same scenario twice: first with the default regression model, then with a regular model
        # that is not a PreTrainedModel (`pretrained=False`).
        for model_kwargs in ({}, {"pretrained": False}):
            with tempfile.TemporaryDirectory() as tmpdir:
                kwargs = dict(output_dir=tmpdir, train_len=128, save_steps=5, learning_rate=0.1, **model_kwargs)

                # Reference run: train from scratch and record the final weights and trainer state.
                trainer = get_regression_trainer(**kwargs)
                trainer.train()
                (a, b) = trainer.model.a.item(), trainer.model.b.item()
                state = dataclasses.asdict(trainer.state)

                # "checkpoint-5" is an early checkpoint; "checkpoint-15" checks that resuming also works when
                # we span over one epoch.
                for checkpoint_name in ("checkpoint-5", "checkpoint-15"):
                    checkpoint = os.path.join(tmpdir, checkpoint_name)

                    # Reinitialize trainer and resume from the saved checkpoint
                    trainer = get_regression_trainer(**kwargs)
                    trainer.train(resume_from_checkpoint=checkpoint)

                    # The resumed run must end with the same weights and state as the reference run.
                    (a1, b1) = trainer.model.a.item(), trainer.model.b.item()
                    state1 = dataclasses.asdict(trainer.state)
                    self.assertEqual(a, a1)
                    self.assertEqual(b, b1)
                    self.check_trainer_state_are_the_same(state, state1)

        # Now check failures

        # 1. fail to find a bogus checkpoint (`checkpoint` still holds the last path used above)
        trainer = get_regression_trainer()
        with self.assertRaises(Exception) as context:
            trainer.train(resume_from_checkpoint=f"{checkpoint}-bogus")
        self.assertIn("Can't find a valid checkpoint at", str(context.exception))

        # 2. fail to find any checkpoint - due a fresh output_dir
        output_dir2 = self.get_auto_remove_tmp_dir()
        trainer = get_regression_trainer(output_dir=output_dir2)
        with self.assertRaises(Exception) as context:
            trainer.train(resume_from_checkpoint=True)
        self.assertIn("No valid checkpoint found in output directory", str(context.exception))

690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
    def test_resume_training_with_gradient_accumulation(self):
        if torch.cuda.device_count() > 2:
            # This test will fail for more than 2 GPUs since the batch size will get bigger and with the number of
            # save_steps, the checkpoint will resume training at epoch 2 or more (so the data seen by the model
            # won't be the same since the training dataloader is shuffled).
            return

        with tempfile.TemporaryDirectory() as tmpdir:
            # Both the reference run and the resumed run are built from the exact same arguments.
            trainer_kwargs = dict(
                output_dir=tmpdir,
                train_len=128,
                gradient_accumulation_steps=2,
                per_device_train_batch_size=4,
                save_steps=5,
                learning_rate=0.1,
            )

            # Reference run, trained from scratch
            reference_trainer = get_regression_trainer(**trainer_kwargs)
            reference_trainer.train()
            a = reference_trainer.model.a.item()
            b = reference_trainer.model.b.item()
            state = dataclasses.asdict(reference_trainer.state)

            # Fresh trainer resumed from the checkpoint saved at step 5
            resumed_trainer = get_regression_trainer(**trainer_kwargs)
            resumed_trainer.train(resume_from_checkpoint=os.path.join(tmpdir, "checkpoint-5"))
            a1 = resumed_trainer.model.a.item()
            b1 = resumed_trainer.model.b.item()
            state1 = dataclasses.asdict(resumed_trainer.state)

            self.assertEqual(a, a1)
            self.assertEqual(b, b1)
            self.check_trainer_state_are_the_same(state, state1)
727

728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
    def test_load_best_model_at_end(self):
        total = int(self.n_epochs * 64 / self.batch_size)

        # Best model tracked on the loss, evaluated every 5 steps
        with tempfile.TemporaryDirectory() as tmpdir:
            loss_kwargs = dict(
                a=1.5,
                b=2.5,
                output_dir=tmpdir,
                learning_rate=0.1,
                eval_steps=5,
                evaluation_strategy="steps",
                load_best_model_at_end=True,
            )
            trainer = get_regression_trainer(**loss_kwargs)
            self.assertFalse(trainer.args.greater_is_better)
            trainer.train()
            self.check_saved_checkpoints(tmpdir, 5, total)
            self.check_best_model_has_been_loaded(tmpdir, 5, total, trainer, "eval_loss")

        # Best model tracked on the accuracy, evaluated every 5 steps
        with tempfile.TemporaryDirectory() as tmpdir:
            accuracy_kwargs = dict(
                a=1.5,
                b=2.5,
                output_dir=tmpdir,
                learning_rate=0.1,
                eval_steps=5,
                evaluation_strategy="steps",
                load_best_model_at_end=True,
                metric_for_best_model="accuracy",
                compute_metrics=AlmostAccuracy(),
            )
            trainer = get_regression_trainer(**accuracy_kwargs)
            self.assertTrue(trainer.args.greater_is_better)
            trainer.train()
            self.check_saved_checkpoints(tmpdir, 5, total)
            self.check_best_model_has_been_loaded(tmpdir, 5, total, trainer, "eval_accuracy", greater_is_better=True)

        # Save is done every eval regardless of the strategy
        with tempfile.TemporaryDirectory() as tmpdir:
            epoch_kwargs = dict(
                a=1.5,
                b=2.5,
                output_dir=tmpdir,
                learning_rate=0.1,
                evaluation_strategy="epoch",
                load_best_model_at_end=True,
                metric_for_best_model="accuracy",
                compute_metrics=AlmostAccuracy(),
            )
            trainer = get_regression_trainer(**epoch_kwargs)
            self.assertTrue(trainer.args.greater_is_better)
            trainer.train()
            self.check_saved_checkpoints(tmpdir, 64 // self.batch_size, total)
            self.check_best_model_has_been_loaded(
                tmpdir, 64 // self.batch_size, total, trainer, "eval_accuracy", greater_is_better=True
            )

        # Test this works with a non PreTrainedModel
        with tempfile.TemporaryDirectory() as tmpdir:
            plain_model_kwargs = dict(
                output_dir=tmpdir,
                learning_rate=0.1,
                eval_steps=5,
                evaluation_strategy="steps",
                load_best_model_at_end=True,
                pretrained=False,
            )
            trainer = get_regression_trainer(**plain_model_kwargs)
            self.assertFalse(trainer.args.greater_is_better)
            trainer.train()
            self.check_saved_checkpoints(tmpdir, 5, total, is_pretrained=False)
            self.check_best_model_has_been_loaded(tmpdir, 5, total, trainer, "eval_loss", is_pretrained=False)

796
    @slow
    def test_trainer_eval_mrpc(self):
        # Evaluate a checkpoint already fine-tuned on MRPC against the MRPC samples shipped with the tests.
        model_id = "bert-base-cased-finetuned-mrpc"
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        model = AutoModelForSequenceClassification.from_pretrained(model_id)

        mrpc_dir = f"{get_tests_dir()}/fixtures/tests_samples/MRPC"
        data_args = GlueDataTrainingArguments(task_name="mrpc", data_dir=mrpc_dir, overwrite_cache=True)
        eval_dataset = GlueDataset(data_args, tokenizer=tokenizer, mode="dev")

        trainer = Trainer(
            model=model,
            args=TrainingArguments(output_dir="./examples", no_cuda=True),
            eval_dataset=eval_dataset,
        )
        result = trainer.evaluate()
        self.assertLess(result["eval_loss"], 0.2)
Julien Chaumond's avatar
Julien Chaumond committed
810

811
    @slow
    def test_trainer_eval_lm(self):
        # Build a line-by-line dataset from the sample text file and check how many examples it yields.
        tokenizer = AutoTokenizer.from_pretrained("distilroberta-base")
        dataset_kwargs = dict(
            tokenizer=tokenizer,
            file_path=PATH_SAMPLE_TEXT,
            block_size=tokenizer.max_len_single_sentence,
        )
        dataset = LineByLineTextDataset(**dataset_kwargs)
        self.assertEqual(len(dataset), 31)
821

822
    def test_training_iterable_dataset(self):
        # Train for a fixed number of steps on an iterable dataset.
        model = RegressionPreTrainedModel(RegressionModelConfig())
        training_args = RegressionTrainingArguments(output_dir="./examples", max_steps=4)
        trainer = Trainer(model=model, args=training_args, train_dataset=SampleIterableDataset())

        trainer.train()
        self.assertEqual(trainer.state.global_step, 4)

        # The train dataloader must be a plain DataLoader using the infinite constant sampler.
        train_loader = trainer.get_train_dataloader()
        self.assertIsInstance(train_loader, torch.utils.data.DataLoader)
        self.assertIsInstance(train_loader.sampler, torch.utils.data.dataloader._InfiniteConstantSampler)

836
837
838
839
840
841
842
843
    def test_evaluation_iterable_dataset(self):
        def check_results(results, iterable_dataset):
            # Recompute loss and accuracy by hand from the underlying data, using the a=1.5 / b=2.5
            # values the model config is built with below.
            x, y = iterable_dataset.dataset.x, iterable_dataset.dataset.ys[0]
            pred = 1.5 * x + 2.5
            expected_loss = ((pred - y) ** 2).mean()
            self.assertAlmostEqual(results["eval_loss"], expected_loss)
            expected_acc = AlmostAccuracy()((pred, y))["accuracy"]
            self.assertAlmostEqual(results["eval_accuracy"], expected_acc)

        model = RegressionPreTrainedModel(RegressionModelConfig(a=1.5, b=2.5))
        trainer = Trainer(
            model=model,
            args=RegressionTrainingArguments(output_dir="./examples"),
            eval_dataset=SampleIterableDataset(),
            compute_metrics=AlmostAccuracy(),
        )

        # Evaluation on the dataset the trainer was built with
        default_results = trainer.evaluate()
        check_results(default_results, trainer.eval_dataset)

        # With a number of elements not a round multiple of the batch size
        uneven_dataset = SampleIterableDataset(length=66)
        uneven_results = trainer.evaluate(uneven_dataset)
        check_results(uneven_results, uneven_dataset)

    def test_predict_iterable_dataset(self):
        # Predictions on an iterable dataset must match the model's a * x + b with a=1.5, b=2.5.
        model = RegressionPreTrainedModel(RegressionModelConfig(a=1.5, b=2.5))
        eval_dataset = SampleIterableDataset()
        trainer = Trainer(
            model=model,
            args=RegressionTrainingArguments(output_dir="./examples"),
            eval_dataset=eval_dataset,
            compute_metrics=AlmostAccuracy(),
        )

        predictions = trainer.predict(trainer.eval_dataset).predictions
        inputs = eval_dataset.dataset.x
        self.assertTrue(np.allclose(predictions, 1.5 * inputs + 2.5))

        # With a number of elements not a round multiple of the batch size
        test_dataset = SampleIterableDataset(length=66)
        predictions = trainer.predict(test_dataset).predictions
        inputs = test_dataset.dataset.x
        self.assertTrue(np.allclose(predictions, 1.5 * inputs + 2.5))
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894

    def test_num_train_epochs_in_training(self):
        # In both runs below len(train_dl) < gradient_accumulation_steps.
        short_epoch_kwargs = dict(train_len=64, per_device_train_batch_size=16, gradient_accumulation_steps=5)

        # When ``max_steps`` is given this shouldn't give ``ZeroDivisionError``.
        # It should give 1 update step for each epoch.
        trainer = get_regression_trainer(max_steps=3, **short_epoch_kwargs)
        self.assertEqual(trainer.train().global_step, 3)

        # Even if ``max_steps`` is not specified, we still expect 1 update step for each epoch.
        trainer = get_regression_trainer(**short_epoch_kwargs)
        self.assertEqual(trainer.train().global_step, int(self.n_epochs))
Marcin Zab艂ocki's avatar
Marcin Zab艂ocki committed
895

896
897
    def test_early_stopping_callback(self):
        # early stopping stops training before num_training_epochs
        with tempfile.TemporaryDirectory() as tmp_dir:
            trainer = get_regression_trainer(
                output_dir=tmp_dir,
                num_train_epochs=20,
                gradient_accumulation_steps=1,
                per_device_train_batch_size=16,
                load_best_model_at_end=True,
                evaluation_strategy=IntervalStrategy.EPOCH,
                compute_metrics=AlmostAccuracy(),
                metric_for_best_model="accuracy",
            )
            trainer.add_callback(EarlyStoppingCallback(1, 0.0001))
            train_output = trainer.train()
            # 20 * 64 / 16 is the step count of a full 20-epoch run; stopping early must end below it.
            self.assertLess(train_output.global_step, 20 * 64 / 16)

        # Invalid inputs to trainer with early stopping callback result in assertion error
        # (this trainer is built without `load_best_model_at_end=True`).
        with tempfile.TemporaryDirectory() as tmp_dir:
            trainer = get_regression_trainer(
                output_dir=tmp_dir,
                num_train_epochs=20,
                gradient_accumulation_steps=1,
                per_device_train_batch_size=16,
                evaluation_strategy=IntervalStrategy.EPOCH,
                compute_metrics=AlmostAccuracy(),
                metric_for_best_model="accuracy",
            )
            trainer.add_callback(EarlyStoppingCallback(1))
            self.assertEqual(trainer.state.global_step, 0)
            # Use assertRaises rather than try/except so that the test fails when train() does NOT raise;
            # the previous try/except form passed silently in that case.
            with self.assertRaises(AssertionError):
                trainer.train()
            # No training step may have been taken before the failure.
            self.assertEqual(trainer.state.global_step, 0)
930

Marcin Zab艂ocki's avatar
Marcin Zab艂ocki committed
931
932
933
934
    def test_flos_extraction(self):
        trainer = get_regression_trainer(learning_rate=0.1)

        def assert_flos_extraction(trainer, wrapped_model_to_check):
            # Unwrapping must give back the trainer's own model, and the `total_flos` value read from its
            # config (defaulting to 0 when the attribute is absent) must be non-negative.
            self.assertEqual(trainer.model, unwrap_model(wrapped_model_to_check))
            self.assertGreaterEqual(getattr(unwrap_model(wrapped_model_to_check).config, "total_flos", 0), 0)

        # with plain model
        assert_flos_extraction(trainer, trainer.model)

        # with enforced DataParallel
        assert_flos_extraction(trainer, torch.nn.DataParallel(trainer.model))

        trainer.train()
        # assertIsInstance reports the actual type on failure, unlike assertTrue(isinstance(...)).
        self.assertIsInstance(trainer.state.total_flos, float)

947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
    def check_mem_metrics(self, trainer, check_func):
        # Run train / evaluate / predict on `trainer` and apply `check_func(key, metrics)` to every
        # memory metric key of the corresponding stage. GPU keys are only checked when a GPU is visible.
        def check_keys(metrics, cpu_keys, gpu_keys):
            for key in cpu_keys:
                check_func(key, metrics)
            if torch.cuda.device_count() > 0:
                for key in gpu_keys:
                    check_func(key, metrics)

        check_keys(
            trainer.train().metrics,
            cpu_keys=("init_mem_cpu_alloc_delta", "train_mem_cpu_alloc_delta"),
            gpu_keys=("init_mem_gpu_alloc_delta", "train_mem_gpu_alloc_delta"),
        )
        check_keys(
            trainer.evaluate(),
            cpu_keys=("eval_mem_cpu_alloc_delta",),
            gpu_keys=("eval_mem_gpu_alloc_delta",),
        )
        check_keys(
            trainer.predict(RegressionDataset()).metrics,
            cpu_keys=("test_mem_cpu_alloc_delta",),
            gpu_keys=("test_mem_gpu_alloc_delta",),
        )

    def test_mem_metrics(self):
        scenarios = (
            # with mem metrics enabled: every memory key must be reported
            ({}, self.assertIn),
            # with mem metrics disabled: no memory key may be reported
            ({"skip_memory_metrics": True}, self.assertNotIn),
        )
        for trainer_kwargs, check_func in scenarios:
            trainer = get_regression_trainer(**trainer_kwargs)
            self.check_mem_metrics(trainer, check_func)

975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
    @require_torch_gpu
    def test_fp16_full_eval(self):
        # Compares the GPU memory deltas reported by `evaluate()` for a default (fp32) trainer against a
        # trainer created with `fp16_full_eval=True`.

        # this is a sensitive test so let's keep debugging printouts in place for quick diagnosis.
        # it's using pretty large safety margins, but small enough to detect broken functionality.
        debug = 0

        bs = 8
        # make the params somewhat big so that there will be enough RAM consumed to be able to
        # measure things. We should get about 64KB for a+b in fp32
        a = torch.ones(1000, bs) + 0.001
        b = torch.ones(1000, bs) - 0.001

        # 1. fp32 baseline (trainer created without `fp16_full_eval`)
        trainer = get_regression_trainer(a=a, b=b, eval_len=16)
        metrics = trainer.evaluate()
        # drop the first trainer and collect garbage before building the second trainer
        del trainer
        gc.collect()

        fp32_init = metrics["init_mem_gpu_alloc_delta"]
        fp32_eval = metrics["eval_mem_gpu_alloc_delta"]

        if debug:
            print(f"fp32_init {fp32_init}")
            print(f"fp32_eval {fp32_eval}")

        # here we expect the model to be preloaded in trainer.__init__ and consume around 64K gpu ram.
        # perfect world: fp32_init == 64<<10
        self.assertGreater(fp32_init, 59_000)
        # after eval should be no extra memory allocated - with a small margin (other than the peak
        # memory consumption for the forward calculation that gets recovered)
        # perfect world: fp32_eval == close to zero
        self.assertLess(fp32_eval, 5_000)

        # 2. with `fp16_full_eval` enabled
        trainer = get_regression_trainer(a=a, b=b, eval_len=16, fp16_full_eval=True)
        metrics = trainer.evaluate()
        fp16_init = metrics["init_mem_gpu_alloc_delta"]
        fp16_eval = metrics["eval_mem_gpu_alloc_delta"]

        if debug:
            print(f"fp16_init {fp16_init}")
            print(f"fp16_eval {fp16_eval}")

        # here we expect the model to not be preloaded in trainer.__init__, so with a small margin it should be close to 0
        # perfect world: fp16_init == close to zero
        self.assertLess(fp16_init, 5_000)
        # here we put the model on device in eval and only `half()` of it, i.e. about 32K,(again we ignore the peak margin which gets returned back)
        # perfect world: fp16_eval == 32<<10
        self.assertGreater(fp16_eval, 27_000)

        # 3. relative comparison fp32 vs full fp16
        # fp16_eval should be about half of fp32_init
        # perfect world: fp32_init/2 == fp16_eval
        self.assertAlmostEqual(fp16_eval, fp32_init / 2, delta=5_000)

1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
    def test_no_wd_param_group(self):
        # The optimizer's first param group must hold exactly the parameters named in `wd_names`,
        # the second group all the remaining ones, both in `named_parameters()` order.
        nested_layers = torch.nn.ModuleList([TstLayer(128), TstLayer(128)])
        model = torch.nn.Sequential(TstLayer(128), nested_layers)
        trainer = Trainer(model=model)
        trainer.create_optimizer_and_scheduler(10)
        # fmt: off
        wd_names = ['0.linear1.weight', '0.linear2.weight', '1.0.linear1.weight', '1.0.linear2.weight', '1.1.linear1.weight', '1.1.linear2.weight']
        # fmt: on

        # Split the parameters in a single pass, preserving their original order within each group.
        wd_params, no_wd_params = [], []
        for name, param in model.named_parameters():
            if name in wd_names:
                wd_params.append(param)
            else:
                no_wd_params.append(param)

        param_groups = trainer.optimizer.param_groups
        self.assertListEqual(param_groups[0]["params"], wd_params)
        self.assertListEqual(param_groups[1]["params"], no_wd_params)

1043
1044
1045

@require_torch
@require_optuna
class TrainerHyperParameterOptunaIntegrationTest(unittest.TestCase):
    def setUp(self):
        # Record the default epoch count and train batch size of a plain TrainingArguments.
        default_args = TrainingArguments(".")
        self.n_epochs = default_args.num_train_epochs
        self.batch_size = default_args.train_batch_size

    def test_hyperparameter_search(self):
        class MyTrialShortNamer(TrialShortNamer):
            DEFAULTS = {"a": 0, "b": 0}

        def hp_space(trial):
            # Nothing is sampled here: `a` and `b` are suggested inside `model_init`.
            return {}

        def model_init(trial):
            if trial is None:
                # No trial available: fall back to a zero-initialized model.
                a, b = 0, 0
            else:
                a = trial.suggest_int("a", -4, 4)
                b = trial.suggest_int("b", -4, 4)
            return RegressionPreTrainedModel(RegressionModelConfig(a=a, b=b, double_output=False))

        def hp_name(trial):
            return MyTrialShortNamer.shortname(trial.params)

        with tempfile.TemporaryDirectory() as tmp_dir:
            trainer_kwargs = dict(
                output_dir=tmp_dir,
                learning_rate=0.1,
                logging_steps=1,
                evaluation_strategy=IntervalStrategy.EPOCH,
                num_train_epochs=4,
                disable_tqdm=True,
                load_best_model_at_end=True,
                logging_dir="runs",
                run_name="test",
                model_init=model_init,
            )
            trainer = get_regression_trainer(**trainer_kwargs)
            trainer.hyperparameter_search(direction="minimize", hp_space=hp_space, hp_name=hp_name, n_trials=4)
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121


@require_torch
@require_ray
class TrainerHyperParameterRayIntegrationTest(unittest.TestCase):
    def setUp(self):
        # Record the default epoch count and train batch size of a plain TrainingArguments.
        default_args = TrainingArguments(".")
        self.n_epochs = default_args.num_train_epochs
        self.batch_size = default_args.train_batch_size

    def test_hyperparameter_search(self):
        class MyTrialShortNamer(TrialShortNamer):
            DEFAULTS = {"a": 0, "b": 0}

        def hp_space(trial):
            # Imported locally so that ray is only needed when the search space is actually built.
            from ray import tune

            return {name: tune.randint(-4, 4) for name in ("a", "b")}

        def model_init(config):
            model_config = RegressionModelConfig(a=config["a"], b=config["b"], double_output=False)
            return RegressionPreTrainedModel(model_config)

        def hp_name(params):
            return MyTrialShortNamer.shortname(params)

        with tempfile.TemporaryDirectory() as tmp_dir:
            trainer_kwargs = dict(
                output_dir=tmp_dir,
                learning_rate=0.1,
                logging_steps=1,
                evaluation_strategy=IntervalStrategy.EPOCH,
                num_train_epochs=4,
                disable_tqdm=True,
                load_best_model_at_end=True,
                logging_dir="runs",
                run_name="test",
                model_init=model_init,
            )
            trainer = get_regression_trainer(**trainer_kwargs)
            trainer.hyperparameter_search(
                direction="minimize", hp_space=hp_space, hp_name=hp_name, backend="ray", n_trials=4
            )