import unittest

from transformers import AutoTokenizer, TrainingArguments, is_torch_available
from transformers.testing_utils import require_torch


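# Torch-dependent imports are gated on availability so this module can still be
# imported (its tests are skipped via @require_torch) when torch is absent.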
if is_torch_available():
    import torch
    from transformers import (
        Trainer,
        LineByLineTextDataset,
        AutoModelForSequenceClassification,
        default_data_collator,
        DataCollatorForLanguageModeling,
        GlueDataset,
        GlueDataTrainingArguments,
        TextDataset,
    )


PATH_SAMPLE_TEXT = "./tests/fixtures/sample_text.txt"


@require_torch
class DataCollatorIntegrationTest(unittest.TestCase):
    def test_default_with_dict(self):
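        # default_data_collator stacks the list features into tensors and
        # collects the per-example "label" values into a batched "labels" tensor.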
        features = [{"label": i, "inputs": [0, 1, 2, 3, 4, 5]} for i in range(8)]
        batch = default_data_collator(features)
        self.assertTrue(batch["labels"].equal(torch.tensor(list(range(8)))))
        self.assertEqual(batch["labels"].dtype, torch.long)
        self.assertEqual(batch["inputs"].shape, torch.Size([8, 6]))

        # With label_ids
        features = [{"label_ids": [0, 1, 2], "inputs": [0, 1, 2, 3, 4, 5]} for i in range(8)]
        batch = default_data_collator(features)
        self.assertTrue(batch["labels"].equal(torch.tensor([[0, 1, 2]] * 8)))
        self.assertEqual(batch["labels"].dtype, torch.long)
        self.assertEqual(batch["inputs"].shape, torch.Size([8, 6]))

        # Features can already be tensors
        features = [{"label": i, "inputs": torch.randint(10, [10])} for i in range(8)]
        batch = default_data_collator(features)
        self.assertTrue(batch["labels"].equal(torch.tensor(list(range(8)))))
        self.assertEqual(batch["labels"].dtype, torch.long)
        self.assertEqual(batch["inputs"].shape, torch.Size([8, 10]))

        # Labels can already be tensors
        features = [{"label": torch.tensor(i), "inputs": torch.randint(10, [10])} for i in range(8)]
        batch = default_data_collator(features)
        self.assertEqual(batch["labels"].dtype, torch.long)
        self.assertTrue(batch["labels"].equal(torch.tensor(list(range(8)))))
        self.assertEqual(batch["labels"].dtype, torch.long)
        self.assertEqual(batch["inputs"].shape, torch.Size([8, 10]))

    def test_default_with_no_labels(self):
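        # A feature whose label is None should be collated without any "labels" key.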
        features = [{"label": None, "inputs": [0, 1, 2, 3, 4, 5]} for i in range(8)]
        batch = default_data_collator(features)
        self.assertTrue("labels" not in batch)
        self.assertEqual(batch["inputs"].shape, torch.Size([8, 6]))

        # With label_ids
        features = [{"label_ids": None, "inputs": [0, 1, 2, 3, 4, 5]} for i in range(8)]
        batch = default_data_collator(features)
        self.assertTrue("labels" not in batch)
        self.assertEqual(batch["inputs"].shape, torch.Size([8, 6]))

    def test_default_classification(self):
        MODEL_ID = "bert-base-cased-finetuned-mrpc"
        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
        data_args = GlueDataTrainingArguments(
            task_name="mrpc", data_dir="./tests/fixtures/tests_samples/MRPC", overwrite_cache=True
        )
        dataset = GlueDataset(data_args, tokenizer=tokenizer, mode="dev")
        data_collator = default_data_collator
        batch = data_collator(dataset.features)
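        # MRPC labels are class indices, so they must collate to torch.long.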
        self.assertEqual(batch["labels"].dtype, torch.long)

    def test_default_regression(self):
        MODEL_ID = "distilroberta-base"
        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
        data_args = GlueDataTrainingArguments(
            task_name="sts-b", data_dir="./tests/fixtures/tests_samples/STS-B", overwrite_cache=True
        )
        dataset = GlueDataset(data_args, tokenizer=tokenizer, mode="dev")
        data_collator = default_data_collator
        batch = data_collator(dataset.features)
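        # STS-B labels are continuous similarity scores, so they must collate to torch.float.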
        self.assertEqual(batch["labels"].dtype, torch.float)

    def test_lm_tokenizer_without_padding(self):
        tokenizer = AutoTokenizer.from_pretrained("gpt2")
        data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)
        # ^ causal lm

        dataset = LineByLineTextDataset(tokenizer, file_path=PATH_SAMPLE_TEXT, block_size=512)
        examples = [dataset[i] for i in range(len(dataset))]
        with self.assertRaises(ValueError):
            # Expect error due to padding token missing on gpt2:
            data_collator(examples)

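        # TextDataset concatenates the tokenized corpus and slices it into fixed
        # block_size chunks, so its examples collate without any padding.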
        dataset = TextDataset(tokenizer, file_path=PATH_SAMPLE_TEXT, block_size=512, overwrite_cache=True)
        examples = [dataset[i] for i in range(len(dataset))]
        batch = data_collator(examples)
        self.assertIsInstance(batch, dict)
        self.assertEqual(batch["input_ids"].shape, torch.Size((2, 512)))
        self.assertEqual(batch["labels"].shape, torch.Size((2, 512)))

    def test_lm_tokenizer_with_padding(self):
        tokenizer = AutoTokenizer.from_pretrained("distilroberta-base")
        data_collator = DataCollatorForLanguageModeling(tokenizer)
        # ^ masked lm

        dataset = LineByLineTextDataset(tokenizer, file_path=PATH_SAMPLE_TEXT, block_size=512)
        examples = [dataset[i] for i in range(len(dataset))]
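        # Unlike gpt2, distilroberta-base defines a padding token, so the
        # variable-length line examples can be padded into a single batch.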
        batch = data_collator(examples)
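        # The fixture yields 31 line examples, the longest being 107 tokens,
        # hence the (31, 107) dynamically padded batch checked below.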
        self.assertIsInstance(batch, dict)
        self.assertEqual(batch["input_ids"].shape, torch.Size((31, 107)))
        self.assertEqual(batch["labels"].shape, torch.Size((31, 107)))

        dataset = TextDataset(tokenizer, file_path=PATH_SAMPLE_TEXT, block_size=512, overwrite_cache=True)
        examples = [dataset[i] for i in range(len(dataset))]
        batch = data_collator(examples)
        self.assertIsInstance(batch, dict)
        self.assertEqual(batch["input_ids"].shape, torch.Size((2, 512)))
        self.assertEqual(batch["labels"].shape, torch.Size((2, 512)))


@require_torch
class TrainerIntegrationTest(unittest.TestCase):
    def test_trainer_eval_mrpc(self):
        MODEL_ID = "bert-base-cased-finetuned-mrpc"
        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
        model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID)
        data_args = GlueDataTrainingArguments(
            task_name="mrpc", data_dir="./tests/fixtures/tests_samples/MRPC", overwrite_cache=True
        )
        eval_dataset = GlueDataset(data_args, tokenizer=tokenizer, mode="dev")

        training_args = TrainingArguments(output_dir="./examples", no_cuda=True)
        trainer = Trainer(model=model, args=training_args, eval_dataset=eval_dataset)
        result = trainer.evaluate()
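        # The checkpoint is already fine-tuned on MRPC, so its dev-set loss should be low.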
        self.assertLess(result["eval_loss"], 0.2)

    def test_trainer_eval_lm(self):
        MODEL_ID = "distilroberta-base"
        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
        dataset = LineByLineTextDataset(
            tokenizer=tokenizer, file_path=PATH_SAMPLE_TEXT, block_size=tokenizer.max_len_single_sentence,
        )
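        # LineByLineTextDataset creates one example per non-empty line of the fixture.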
        self.assertEqual(len(dataset), 31)