test_trainer.py 21.5 KB
Newer Older
1
2
3
import json
import os
import tempfile
Julien Chaumond's avatar
Julien Chaumond committed
4
5
import unittest

6
import datasets
Sylvain Gugger's avatar
Sylvain Gugger committed
7
8
import numpy as np

9
10
from transformers import AutoTokenizer, PretrainedConfig, TrainingArguments, is_torch_available
from transformers.file_utils import WEIGHTS_NAME
11
from transformers.testing_utils import get_tests_dir, require_torch, slow
Julien Chaumond's avatar
Julien Chaumond committed
12
13
14
15


if is_torch_available():
    import torch
16
17
    from torch.utils.data import IterableDataset

Julien Chaumond's avatar
Julien Chaumond committed
18
19
20
21
    from transformers import (
        AutoModelForSequenceClassification,
        GlueDataset,
        GlueDataTrainingArguments,
22
        LineByLineTextDataset,
23
        PreTrainedModel,
24
        Trainer,
Julien Chaumond's avatar
Julien Chaumond committed
25
26
27
    )


28
PATH_SAMPLE_TEXT = f"{get_tests_dir()}/fixtures/sample_text.txt"
Julien Chaumond's avatar
Julien Chaumond committed
29
30


Sylvain Gugger's avatar
Sylvain Gugger committed
31
class RegressionDataset:
    """Synthetic 1-D regression dataset following ``y = a * x + b`` plus Gaussian noise.

    One target array is generated per entry of ``label_names``. All targets share the
    same inputs ``x`` but each one draws its own noise. Items are returned as dicts
    mapping every label name to its target value, plus ``"input_x"`` for the input.

    Args:
        a: Slope used to generate the targets.
        b: Intercept used to generate the targets.
        length: Number of samples.
        seed: Value passed to ``np.random.seed`` before sampling.
        label_names: Names of the target columns; defaults to ``["labels"]``.
    """

    def __init__(self, a=2, b=3, length=64, seed=42, label_names=None):
        # Seeds NumPy's global RNG so two datasets built with the same arguments are identical.
        np.random.seed(seed)
        self.label_names = ["labels"] if label_names is None else label_names
        self.length = length
        self.x = np.random.normal(size=(length,)).astype(np.float32)
        # The noise is sampled in float64, so the targets are cast to float32 afterwards.
        self.ys = [a * self.x + b + np.random.normal(scale=0.1, size=(length,)) for _ in self.label_names]
        self.ys = [y.astype(np.float32) for y in self.ys]

    def __len__(self):
        return self.length

    def __getitem__(self, i):
        result = {name: y[i] for name, y in zip(self.label_names, self.ys)}
        result["input_x"] = self.x[i]
        return result
Sylvain Gugger's avatar
Sylvain Gugger committed
47
48
49
50
51
52
53
54
55
56


class AlmostAccuracy:
    """Metric callable: fraction of predictions within ``thresh`` of their label."""

    def __init__(self, thresh=0.25):
        self.thresh = thresh

    def __call__(self, eval_pred):
        # ``eval_pred`` unpacks into (predictions, labels).
        preds, targets = eval_pred
        within_tolerance = np.abs(preds - targets) <= self.thresh
        hit_rate = within_tolerance.astype(np.float32).mean()
        # ``.item()`` turns the NumPy scalar into a plain Python float.
        return {"accuracy": hit_rate.item()}
57

Julien Chaumond's avatar
Julien Chaumond committed
58

59
60
61
62
63
64
65
66
class RegressionModelConfig(PretrainedConfig):
    """Configuration for the toy regression model: slope ``a``, intercept ``b``, ``double_output`` flag."""

    def __init__(self, a=0, b=0, double_output=False, **kwargs):
        # Any remaining keyword arguments are handed to the parent configuration class.
        super().__init__(**kwargs)
        self.a, self.b = a, b
        self.double_output = double_output


67
68
69
70
71
72
73
74
75
76
77
78
79
if is_torch_available():

    class SampleIterableDataset(IterableDataset):
        """Iterable dataset that yields the lines of a text file.

        Args:
            file_path: Path of the text file to read.
        """

        def __init__(self, file_path):
            self.file_path = file_path

        def parse_file(self):
            """Return every line of ``self.file_path`` as a list of strings."""
            # The context manager closes the handle; the previous code left it open.
            with open(self.file_path, "r") as f:
                return f.readlines()

        def __iter__(self):
            return iter(self.parse_file())

Sylvain Gugger's avatar
Sylvain Gugger committed
80
    class RegressionModel(torch.nn.Module):
        """Plain ``torch.nn.Module`` computing ``y = input_x * a + b`` with learnable ``a`` and ``b``.

        Args:
            a: Initial value of the slope parameter.
            b: Initial value of the intercept parameter.
            double_output: When true, the prediction is returned twice in the output tuple.
        """

        def __init__(self, a=0, b=0, double_output=False):
            super().__init__()
            self.a = torch.nn.Parameter(torch.tensor(a).float())
            self.b = torch.nn.Parameter(torch.tensor(b).float())
            self.double_output = double_output
            # This model has no configuration object; the attribute is kept as ``None``.
            self.config = None

        def forward(self, input_x=None, labels=None, **kwargs):
            """Return ``(y,)`` or ``(y, y)``, prefixed with the MSE loss when ``labels`` is given."""
            y = input_x * self.a + self.b
            if labels is None:
                return (y, y) if self.double_output else (y,)
            loss = torch.nn.functional.mse_loss(y, labels)
            return (loss, y, y) if self.double_output else (loss, y)
Sylvain Gugger's avatar
Sylvain Gugger committed
94

95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
    class RegressionPreTrainedModel(PreTrainedModel):
        """``PreTrainedModel`` flavour of the toy regression model, configured by ``RegressionModelConfig``."""

        config_class = RegressionModelConfig
        base_model_prefix = "regression"

        def __init__(self, config):
            super().__init__(config)
            slope = torch.tensor(config.a).float()
            intercept = torch.tensor(config.b).float()
            self.a = torch.nn.Parameter(slope)
            self.b = torch.nn.Parameter(intercept)
            self.double_output = config.double_output

        def forward(self, input_x=None, labels=None, **kwargs):
            """Return the prediction tuple, prefixed with the MSE loss when ``labels`` is given."""
            y = input_x * self.a + self.b
            outputs = (y, y) if self.double_output else (y,)
            if labels is not None:
                loss = torch.nn.functional.mse_loss(y, labels)
                outputs = (loss,) + outputs
            return outputs

112
    def get_regression_trainer(a=0, b=0, double_output=False, train_len=64, eval_len=64, **kwargs):
        """Build a ``Trainer`` around a ``RegressionPreTrainedModel`` and two ``RegressionDataset`` objects.

        Args:
            a: Initial slope stored in the model configuration.
            b: Initial intercept stored in the model configuration.
            double_output: Forwarded to the model configuration.
            train_len: Number of samples in the training dataset.
            eval_len: Number of samples in the evaluation dataset.
            **kwargs: ``compute_metrics``, ``data_collator``, ``optimizers`` and ``output_dir``
                are extracted here; everything else is passed to ``TrainingArguments``.

        Returns:
            The configured ``Trainer``.
        """
        # Read with ``get`` rather than ``pop``: ``label_names`` stays in ``kwargs`` and so
        # also reaches ``TrainingArguments`` below.
        label_names = kwargs.get("label_names", None)
        train_dataset = RegressionDataset(length=train_len, label_names=label_names)
        eval_dataset = RegressionDataset(length=eval_len, label_names=label_names)
        config = RegressionModelConfig(a=a, b=b, double_output=double_output)
        model = RegressionPreTrainedModel(config)
        compute_metrics = kwargs.pop("compute_metrics", None)
        data_collator = kwargs.pop("data_collator", None)
        optimizers = kwargs.pop("optimizers", (None, None))
        output_dir = kwargs.pop("output_dir", "./regression")
        args = TrainingArguments(output_dir, **kwargs)
        return Trainer(
            model,
            args,
            data_collator=data_collator,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            compute_metrics=compute_metrics,
            optimizers=optimizers,
        )

133

Julien Chaumond's avatar
Julien Chaumond committed
134
135
@require_torch
class TrainerIntegrationTest(unittest.TestCase):
Sylvain Gugger's avatar
Sylvain Gugger committed
136
137
138
    def setUp(self):
        """Record default training arguments and train two reference models.

        The reference ``(a, b)`` parameters are stored for ``check_trained_model``:
        one pair from a run without an explicit seed, one from a run with ``seed=314``.
        """
        args = TrainingArguments(".")
        self.n_epochs = args.num_train_epochs
        self.batch_size = args.train_batch_size
        trainer = get_regression_trainer(learning_rate=0.1)
        trainer.train()
        self.default_trained_model = (trainer.model.a, trainer.model.b)

        trainer = get_regression_trainer(learning_rate=0.1, seed=314)
        trainer.train()
        self.alternate_trained_model = (trainer.model.a, trainer.model.b)

    def check_trained_model(self, model, alternate_seed=False):
        """Assert that ``model.a`` and ``model.b`` match the reference parameters stored in ``setUp``."""
        # Checks a training seeded with learning_rate = 0.1
        reference = self.alternate_trained_model if alternate_seed else self.default_trained_model
        expected_a, expected_b = reference
        for actual, expected in ((model.a, expected_a), (model.b, expected_b)):
            self.assertTrue(torch.allclose(actual, expected))
Sylvain Gugger's avatar
Sylvain Gugger committed
153

154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
    def check_saved_checkpoints(self, output_dir, freq, total, is_pretrained=True):
        """Assert a ``checkpoint-<step>`` folder with the expected files exists for every ``freq`` steps below ``total``."""
        expected_files = [WEIGHTS_NAME, "training_args.bin", "log_history.json", "optimizer.pt", "scheduler.pt"]
        # ``config.json`` is only expected when the model is a pretrained one.
        if is_pretrained:
            expected_files += ["config.json"]
        for step in range(freq, total, freq):
            checkpoint_dir = os.path.join(output_dir, f"checkpoint-{step}")
            self.assertTrue(os.path.isdir(checkpoint_dir))
            for expected in expected_files:
                expected_path = os.path.join(checkpoint_dir, expected)
                self.assertTrue(os.path.isfile(expected_path))

    def check_best_model_has_been_loaded(
        self, output_dir, freq, total, trainer, metric, greater_is_better=False, is_pretrained=True
    ):
        """Assert that ``trainer.model`` holds the weights of the best checkpoint.

        The best checkpoint is found from the ``log_history.json`` of the last checkpoint
        by taking the max (or min) of ``metric``. Its weights are compared with the
        trainer's model, and a fresh evaluation must give the same best metric value.
        """
        checkpoint = os.path.join(output_dir, f"checkpoint-{(total // freq) * freq}")
        # Use a context manager so the log file handle is closed deterministically.
        with open(os.path.join(checkpoint, "log_history.json")) as f:
            log_history = json.load(f)

        values = [d[metric] for d in log_history]
        best_value = max(values) if greater_is_better else min(values)
        # The i-th logged value (0-based) is mapped to the checkpoint saved at step ``(i + 1) * freq``.
        best_checkpoint = (values.index(best_value) + 1) * freq
        checkpoint = os.path.join(output_dir, f"checkpoint-{best_checkpoint}")
        if is_pretrained:
            best_model = RegressionPreTrainedModel.from_pretrained(checkpoint)
            best_model.to(trainer.args.device)
        else:
            best_model = RegressionModel()
            state_dict = torch.load(os.path.join(checkpoint, WEIGHTS_NAME))
            best_model.load_state_dict(state_dict)
        self.assertTrue(torch.allclose(best_model.a, trainer.model.a))
        self.assertTrue(torch.allclose(best_model.b, trainer.model.b))

        metrics = trainer.evaluate()
        self.assertEqual(metrics[metric], best_value)

Sylvain Gugger's avatar
Sylvain Gugger committed
187
188
189
190
    def test_reproducible_training(self):
        # Checks that training worked, model trained and seed made a reproducible training.
        trainer = get_regression_trainer(learning_rate=0.1)
        trainer.train()
        self.check_trained_model(trainer.model)

        # Checks that a different seed gets different (reproducible) results.
        trainer = get_regression_trainer(learning_rate=0.1, seed=314)
        trainer.train()
        self.check_trained_model(trainer.model, alternate_seed=True)
Sylvain Gugger's avatar
Sylvain Gugger committed
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214

    def test_number_of_steps_in_training(self):
        # Regular training has n_epochs * len(train_dl) steps
        trainer = get_regression_trainer(learning_rate=0.1)
        steps_done = trainer.train().global_step
        self.assertEqual(steps_done, self.n_epochs * 64 / self.batch_size)

        # Check passing num_train_epochs works (and a float version too):
        trainer = get_regression_trainer(learning_rate=0.1, num_train_epochs=1.5)
        steps_done = trainer.train().global_step
        self.assertEqual(steps_done, int(1.5 * 64 / self.batch_size))

        # If we pass a max_steps, num_train_epochs is ignored
        trainer = get_regression_trainer(learning_rate=0.1, max_steps=10)
        steps_done = trainer.train().global_step
        self.assertEqual(steps_done, 10)

    def test_train_and_eval_dataloaders(self):
        # The expected batch sizes below are the per-device sizes multiplied by the number of GPUs.
        n_gpu = max(1, torch.cuda.device_count())
        trainer = get_regression_trainer(learning_rate=0.1, per_device_train_batch_size=16)
        self.assertEqual(trainer.get_train_dataloader().batch_size, 16 * n_gpu)
        trainer = get_regression_trainer(learning_rate=0.1, per_device_eval_batch_size=16)
        self.assertEqual(trainer.get_eval_dataloader().batch_size, 16 * n_gpu)

        # Check drop_last works: without it, the incomplete last batch is kept...
        trainer = get_regression_trainer(
            train_len=66, eval_len=74, learning_rate=0.1, per_device_train_batch_size=16, per_device_eval_batch_size=32
        )
        self.assertEqual(len(trainer.get_train_dataloader()), 66 // (16 * n_gpu) + 1)
        self.assertEqual(len(trainer.get_eval_dataloader()), 74 // (32 * n_gpu) + 1)

        # ... and with dataloader_drop_last=True it is dropped.
        trainer = get_regression_trainer(
            train_len=66,
            eval_len=74,
            learning_rate=0.1,
            per_device_train_batch_size=16,
            per_device_eval_batch_size=32,
            dataloader_drop_last=True,
        )
        self.assertEqual(len(trainer.get_train_dataloader()), 66 // (16 * n_gpu))
        self.assertEqual(len(trainer.get_eval_dataloader()), 74 // (32 * n_gpu))

        # Check passing a new dataset for evaluation works
        new_eval_dataset = RegressionDataset(length=128)
        self.assertEqual(len(trainer.get_eval_dataloader(new_eval_dataset)), 128 // (32 * n_gpu))
Sylvain Gugger's avatar
Sylvain Gugger committed
242
243
244
245
246

    def test_evaluate(self):
        # The model is built with a=1.5, b=2.5, so the expected predictions are 1.5 * x + 2.5.
        trainer = get_regression_trainer(a=1.5, b=2.5, compute_metrics=AlmostAccuracy())
        results = trainer.evaluate()

        x, y = trainer.eval_dataset.x, trainer.eval_dataset.ys[0]
        pred = 1.5 * x + 2.5
        expected_loss = ((pred - y) ** 2).mean()
        self.assertAlmostEqual(results["eval_loss"], expected_loss)
        expected_acc = AlmostAccuracy()((pred, y))["accuracy"]
        self.assertAlmostEqual(results["eval_accuracy"], expected_acc)

        # With a number of elements not a round multiple of the batch size
        trainer = get_regression_trainer(a=1.5, b=2.5, eval_len=66, compute_metrics=AlmostAccuracy())
        results = trainer.evaluate()

        x, y = trainer.eval_dataset.x, trainer.eval_dataset.ys[0]
        pred = 1.5 * x + 2.5
        expected_loss = ((pred - y) ** 2).mean()
        self.assertAlmostEqual(results["eval_loss"], expected_loss)
        expected_acc = AlmostAccuracy()((pred, y))["accuracy"]
        self.assertAlmostEqual(results["eval_accuracy"], expected_acc)

    def test_predict(self):
        # The model is built with a=1.5, b=2.5, so the expected predictions are 1.5 * x + 2.5.
        trainer = get_regression_trainer(a=1.5, b=2.5)
        preds = trainer.predict(trainer.eval_dataset).predictions
        x = trainer.eval_dataset.x
        self.assertTrue(np.allclose(preds, 1.5 * x + 2.5))

        # With a number of elements not a round multiple of the batch size
        trainer = get_regression_trainer(a=1.5, b=2.5, eval_len=66)
        preds = trainer.predict(trainer.eval_dataset).predictions
        x = trainer.eval_dataset.x
        self.assertTrue(np.allclose(preds, 1.5 * x + 2.5))

        # With more than one output of the model
        trainer = get_regression_trainer(a=1.5, b=2.5, double_output=True)
        preds = trainer.predict(trainer.eval_dataset).predictions
        x = trainer.eval_dataset.x
        # assertEqual, not assertTrue: assertTrue(len(preds), 2) treats 2 as the failure
        # message and passes for any non-empty ``preds``.
        self.assertEqual(len(preds), 2)
        self.assertTrue(np.allclose(preds[0], 1.5 * x + 2.5))
        self.assertTrue(np.allclose(preds[1], 1.5 * x + 2.5))

        # With more than one output/label of the model
        trainer = get_regression_trainer(a=1.5, b=2.5, double_output=True, label_names=["labels", "labels_2"])
        outputs = trainer.predict(trainer.eval_dataset)
        preds = outputs.predictions
        labels = outputs.label_ids
        x = trainer.eval_dataset.x
        self.assertEqual(len(preds), 2)
        self.assertTrue(np.allclose(preds[0], 1.5 * x + 2.5))
        self.assertTrue(np.allclose(preds[1], 1.5 * x + 2.5))
        self.assertTrue(np.array_equal(labels[0], trainer.eval_dataset.ys[0]))
        self.assertTrue(np.array_equal(labels[1], trainer.eval_dataset.ys[1]))

297
    def test_trainer_with_datasets(self):
        # Same seed, sampling order and coefficients as the RegressionDataset defaults.
        np.random.seed(42)
        x = np.random.normal(size=(64,)).astype(np.float32)
        y = 2.0 * x + 3.0 + np.random.normal(scale=0.1, size=(64,))
        train_dataset = datasets.Dataset.from_dict({"input_x": x, "label": y})

        # Base training. Should have the same results as test_reproducible_training
        model = RegressionModel()
        args = TrainingArguments("./regression", learning_rate=0.1)
        trainer = Trainer(model, args, train_dataset=train_dataset)
        trainer.train()
        self.check_trained_model(trainer.model)

        # Can return tensors.
        train_dataset.set_format(type="torch")
        model = RegressionModel()
        trainer = Trainer(model, args, train_dataset=train_dataset)
        trainer.train()
        self.check_trained_model(trainer.model)

        # Adding one column not used by the model should have no impact
        z = np.random.normal(size=(64,)).astype(np.float32)
        train_dataset = datasets.Dataset.from_dict({"input_x": x, "label": y, "extra": z})
        model = RegressionModel()
        trainer = Trainer(model, args, train_dataset=train_dataset)
        trainer.train()
        self.check_trained_model(trainer.model)

    def test_custom_optimizer(self):
        # Train with a user-supplied SGD optimizer and a constant-factor LambdaLR scheduler.
        train_dataset = RegressionDataset()
        args = TrainingArguments("./regression")
        model = RegressionModel()
        optimizer = torch.optim.SGD(model.parameters(), lr=1.0)
        lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda x: 1.0)
        trainer = Trainer(model, args, train_dataset=train_dataset, optimizers=(optimizer, lr_scheduler))
        trainer.train()

        # The result must differ from the reference run trained in setUp.
        (a, b) = self.default_trained_model
        self.assertFalse(torch.allclose(trainer.model.a, a))
        self.assertFalse(torch.allclose(trainer.model.b, b))
        # The scheduler multiplies by 1.0, so the learning rate must still be 1.0.
        self.assertEqual(trainer.optimizer.state_dict()["param_groups"][0]["lr"], 1.0)

    def test_model_init(self):
        # The trainer receives a ``model_init`` factory instead of a model instance.
        dataset = RegressionDataset()
        training_args = TrainingArguments("./regression", learning_rate=0.1)
        trainer = Trainer(args=training_args, train_dataset=dataset, model_init=lambda: RegressionModel())

        # First run, then a re-training: the latter should restart from scratch and thus
        # lead to the same results.
        for _ in range(2):
            trainer.train()
            self.check_trained_model(trainer.model)

        # Re-training should restart from scratch, thus lead to the same results, and the new seed should be used.
        trainer.args.seed = 314
        trainer.train()
        self.check_trained_model(trainer.model, alternate_seed=True)

355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
    def test_save_checkpoints(self):
        total_steps = int(self.n_epochs * 64 / self.batch_size)
        # Second pass: with a regular model that is not a PreTrainedModel
        for is_pretrained in (True, False):
            with tempfile.TemporaryDirectory() as tmpdir:
                trainer = get_regression_trainer(output_dir=tmpdir, save_steps=5)
                if not is_pretrained:
                    trainer.model = RegressionModel()
                trainer.train()
                self.check_saved_checkpoints(tmpdir, 5, total_steps, is_pretrained)

    def test_load_best_model_at_end(self):
        total = int(self.n_epochs * 64 / self.batch_size)
        steps_per_epoch = 64 // self.batch_size

        # Default metric (eval loss), evaluating every 5 steps
        with tempfile.TemporaryDirectory() as tmpdir:
            run_kwargs = dict(
                a=1.5,
                b=2.5,
                output_dir=tmpdir,
                learning_rate=0.1,
                eval_steps=5,
                evaluation_strategy="steps",
                load_best_model_at_end=True,
            )
            trainer = get_regression_trainer(**run_kwargs)
            self.assertFalse(trainer.args.greater_is_better)
            trainer.train()
            self.check_saved_checkpoints(tmpdir, 5, total)
            self.check_best_model_has_been_loaded(tmpdir, 5, total, trainer, "eval_loss")

        # Accuracy as the metric for the best model, evaluating every 5 steps
        with tempfile.TemporaryDirectory() as tmpdir:
            run_kwargs = dict(
                a=1.5,
                b=2.5,
                output_dir=tmpdir,
                learning_rate=0.1,
                eval_steps=5,
                evaluation_strategy="steps",
                load_best_model_at_end=True,
                metric_for_best_model="accuracy",
                compute_metrics=AlmostAccuracy(),
            )
            trainer = get_regression_trainer(**run_kwargs)
            self.assertTrue(trainer.args.greater_is_better)
            trainer.train()
            self.check_saved_checkpoints(tmpdir, 5, total)
            self.check_best_model_has_been_loaded(tmpdir, 5, total, trainer, "eval_accuracy", greater_is_better=True)

        # Save is done every eval regardless of the strategy
        with tempfile.TemporaryDirectory() as tmpdir:
            run_kwargs = dict(
                a=1.5,
                b=2.5,
                output_dir=tmpdir,
                learning_rate=0.1,
                evaluation_strategy="epoch",
                load_best_model_at_end=True,
                metric_for_best_model="accuracy",
                compute_metrics=AlmostAccuracy(),
            )
            trainer = get_regression_trainer(**run_kwargs)
            self.assertTrue(trainer.args.greater_is_better)
            trainer.train()
            self.check_saved_checkpoints(tmpdir, steps_per_epoch, total)
            self.check_best_model_has_been_loaded(
                tmpdir, steps_per_epoch, total, trainer, "eval_accuracy", greater_is_better=True
            )

        # Test this works with a non PreTrainedModel
        with tempfile.TemporaryDirectory() as tmpdir:
            run_kwargs = dict(
                output_dir=tmpdir,
                learning_rate=0.1,
                eval_steps=5,
                evaluation_strategy="steps",
                load_best_model_at_end=True,
            )
            trainer = get_regression_trainer(**run_kwargs)
            trainer.model = RegressionModel(a=1.5, b=2.5)
            self.assertFalse(trainer.args.greater_is_better)
            trainer.train()
            self.check_saved_checkpoints(tmpdir, 5, total, is_pretrained=False)
            self.check_best_model_has_been_loaded(tmpdir, 5, total, trainer, "eval_loss", is_pretrained=False)

436
    @slow
    def test_trainer_eval_mrpc(self):
        # Evaluates a fine-tuned MRPC checkpoint on the MRPC sample fixtures.
        MODEL_ID = "bert-base-cased-finetuned-mrpc"
        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
        model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID)
        data_args = GlueDataTrainingArguments(
            task_name="mrpc", data_dir=f"{get_tests_dir()}/fixtures/tests_samples/MRPC", overwrite_cache=True
        )
        eval_dataset = GlueDataset(data_args, tokenizer=tokenizer, mode="dev")

        training_args = TrainingArguments(output_dir="./examples", no_cuda=True)
        trainer = Trainer(model=model, args=training_args, eval_dataset=eval_dataset)
        result = trainer.evaluate()
        self.assertLess(result["eval_loss"], 0.2)
Julien Chaumond's avatar
Julien Chaumond committed
450

451
    @slow
    def test_trainer_eval_lm(self):
        # Builds a line-by-line dataset from the sample text fixture and checks its size.
        MODEL_ID = "distilroberta-base"
        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
        dataset = LineByLineTextDataset(
            tokenizer=tokenizer,
            file_path=PATH_SAMPLE_TEXT,
            block_size=tokenizer.max_len_single_sentence,
        )
        self.assertEqual(len(dataset), 31)
461
462
463
464
465
466
467
468
469

    def test_trainer_iterable_dataset(self):
        # A Trainer given an IterableDataset must still hand back a regular DataLoader.
        model = AutoModelForSequenceClassification.from_pretrained("sshleifer/tiny-distilbert-base-cased")
        trainer = Trainer(
            model=model,
            args=TrainingArguments(output_dir="./examples", no_cuda=True),
            train_dataset=SampleIterableDataset(PATH_SAMPLE_TEXT),
        )
        self.assertIsInstance(trainer.get_train_dataloader(), torch.utils.data.DataLoader)
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484

    def test_num_train_epochs_in_training(self):
        # Settings under which len(train_dl) < gradient_accumulation_steps.
        short_loader = dict(train_len=64, per_device_train_batch_size=16, gradient_accumulation_steps=5)

        # This shouldn't give ``ZeroDivisionError`` when ``max_steps`` is given.
        # It should give 1 update step for each epoch.
        trainer = get_regression_trainer(max_steps=3, **short_loader)
        self.assertEqual(trainer.train().global_step, 3)

        # Even if ``max_steps`` is not specified, we still expect 1 update step for each epoch if
        # len(train_dl) < gradient_accumulation_steps.
        trainer = get_regression_trainer(**short_loader)
        self.assertEqual(trainer.train().global_step, int(self.n_epochs))
Marcin Zab艂ocki's avatar
Marcin Zab艂ocki committed
485
486
487
488
489
490
491
492
493
494
495
496
497

    def test_flos_extraction(self):
        trainer = get_regression_trainer(learning_rate=0.1)

        def check_unwrapping(current_trainer, candidate):
            # ``_actual_model`` must hand back the trainer's own model, whose config
            # reports a non-negative ``total_flos`` (0 when the attribute is missing).
            unwrapped = current_trainer._actual_model(candidate)
            self.assertEqual(current_trainer.model, unwrapped)
            self.assertGreaterEqual(getattr(unwrapped.config, "total_flos", 0), 0)

        # with plain model
        check_unwrapping(trainer, trainer.model)

        # with enforced DataParallel
        check_unwrapping(trainer, torch.nn.DataParallel(trainer.model))