# coding=utf-8
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import tempfile
import unittest

from pytest import mark

from transformers import DistilBertConfig, is_torch_available
from transformers.testing_utils import require_flash_attn, require_torch, require_torch_accelerator, slow, torch_device

from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
from ...test_pipeline_mixin import PipelineTesterMixin


# torch and the model classes below are only imported when is_torch_available() reports True,
# so that this module can still be imported in an environment without torch.
if is_torch_available():
    import torch

    from transformers import (
        DISTILBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
        DistilBertForMaskedLM,
        DistilBertForMultipleChoice,
        DistilBertForQuestionAnswering,
        DistilBertForSequenceClassification,
        DistilBertForTokenClassification,
        DistilBertModel,
    )


class DistilBertModelTester:
    """Helper that builds a small DistilBertConfig plus matching inputs and checks model output shapes.

    The instance stores the sizes used to build inputs and the config. Every ``create_and_check_*``
    method instantiates one DistilBert model class, moves it to ``torch_device``, puts it in eval
    mode, runs a forward pass and asserts the output shape through ``parent.assertEqual``.

    ``use_token_type_ids``, ``type_vocab_size`` and ``scope`` are stored on the instance but are not
    read by any method of this class.
    """

    def __init__(
        self,
        parent,
        batch_size=13,
        seq_length=7,
        is_training=True,
        use_input_mask=True,
        use_token_type_ids=False,
        use_labels=True,
        vocab_size=99,
        hidden_size=32,
        num_hidden_layers=2,
        num_attention_heads=4,
        intermediate_size=37,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        type_vocab_size=16,
        type_sequence_label_size=2,
        initializer_range=0.02,
        num_labels=3,
        num_choices=4,
        scope=None,
    ):
        """Store every argument unchanged as an attribute of the same name.

        ``parent`` is the object whose ``assertEqual`` is used by the ``create_and_check_*`` methods
        (presumably the running ``unittest.TestCase`` -- verify against callers).
        """
        self.parent = parent
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.is_training = is_training
        self.use_input_mask = use_input_mask
        self.use_token_type_ids = use_token_type_ids
        self.use_labels = use_labels
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        self.type_sequence_label_size = type_sequence_label_size
        self.initializer_range = initializer_range
        self.num_labels = num_labels
        self.num_choices = num_choices
        self.scope = scope

    def prepare_config_and_inputs(self):
        """Build a config and the inputs/labels consumed by the ``create_and_check_*`` methods.

        Returns:
            A 6-tuple ``(config, input_ids, input_mask, sequence_labels, token_labels, choice_labels)``.
            ``input_mask`` is ``None`` when ``use_input_mask`` is false, and the three label entries
            are ``None`` when ``use_labels`` is false.
        """
        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

        input_mask = None
        if self.use_input_mask:
            input_mask = random_attention_mask([self.batch_size, self.seq_length])

        sequence_labels = None
        token_labels = None
        choice_labels = None
        if self.use_labels:
            sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
            token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
            choice_labels = ids_tensor([self.batch_size], self.num_choices)

        config = self.get_config()

        return config, input_ids, input_mask, sequence_labels, token_labels, choice_labels

    def get_config(self):
        """Return a DistilBertConfig built from this tester's size attributes.

        The generic attribute names used here are mapped onto DistilBertConfig's own keyword names
        (``dim``, ``n_layers``, ``n_heads``, ``hidden_dim``, ``dropout``, ``attention_dropout``).
        """
        return DistilBertConfig(
            vocab_size=self.vocab_size,
            dim=self.hidden_size,
            n_layers=self.num_hidden_layers,
            n_heads=self.num_attention_heads,
            hidden_dim=self.intermediate_size,
            hidden_act=self.hidden_act,
            dropout=self.hidden_dropout_prob,
            attention_dropout=self.attention_probs_dropout_prob,
            max_position_embeddings=self.max_position_embeddings,
            initializer_range=self.initializer_range,
        )

    def create_and_check_distilbert_model(
        self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        """Run DistilBertModel with and without the mask and check ``last_hidden_state``'s shape."""
        model = DistilBertModel(config=config)
        model.to(torch_device)
        model.eval()
        # Both forward variants are asserted; previously the masked result was overwritten unchecked.
        for result in (model(input_ids, input_mask), model(input_ids)):
            self.parent.assertEqual(
                result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)
            )

    def create_and_check_distilbert_for_masked_lm(
        self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        """Run DistilBertForMaskedLM with ``token_labels`` and check the logits shape."""
        model = DistilBertForMaskedLM(config=config)
        model.to(torch_device)
        model.eval()
        result = model(input_ids, attention_mask=input_mask, labels=token_labels)
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))

    def create_and_check_distilbert_for_question_answering(
        self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        """Run DistilBertForQuestionAnswering and check the start/end logits shapes.

        ``sequence_labels`` is passed as both ``start_positions`` and ``end_positions``.
        """
        model = DistilBertForQuestionAnswering(config=config)
        model.to(torch_device)
        model.eval()
        result = model(
            input_ids, attention_mask=input_mask, start_positions=sequence_labels, end_positions=sequence_labels
        )
        self.parent.assertEqual(result.start_logits.shape, (self.batch_size, self.seq_length))
        self.parent.assertEqual(result.end_logits.shape, (self.batch_size, self.seq_length))

    def create_and_check_distilbert_for_sequence_classification(
        self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        """Run DistilBertForSequenceClassification and check the logits shape.

        Sets ``config.num_labels`` on the passed config before building the model.
        """
        config.num_labels = self.num_labels
        model = DistilBertForSequenceClassification(config)
        model.to(torch_device)
        model.eval()
        result = model(input_ids, attention_mask=input_mask, labels=sequence_labels)
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))

    def create_and_check_distilbert_for_token_classification(
        self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        """Run DistilBertForTokenClassification and check the logits shape.

        Sets ``config.num_labels`` on the passed config before building the model.
        """
        config.num_labels = self.num_labels
        model = DistilBertForTokenClassification(config=config)
        model.to(torch_device)
        model.eval()

        result = model(input_ids, attention_mask=input_mask, labels=token_labels)
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.num_labels))

    def create_and_check_distilbert_for_multiple_choice(
        self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        """Run DistilBertForMultipleChoice on inputs repeated per choice and check the logits shape.

        ``input_ids`` (and ``input_mask`` when present) gain a new dimension 1 that is expanded to
        ``num_choices`` before being passed to the model.
        """
        config.num_choices = self.num_choices
        model = DistilBertForMultipleChoice(config=config)
        model.to(torch_device)
        model.eval()
        multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
        # input_mask is None when the tester was built with use_input_mask=False; pass it through as-is.
        multiple_choice_input_mask = None
        if input_mask is not None:
            multiple_choice_input_mask = input_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
        result = model(
            multiple_choice_inputs_ids,
            attention_mask=multiple_choice_input_mask,
            labels=choice_labels,
        )
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_choices))

    def prepare_config_and_inputs_for_common(self):
        """Return ``(config, inputs_dict)`` where the dict holds only ``input_ids`` and ``attention_mask``."""
        config, input_ids, input_mask, _, _, _ = self.prepare_config_and_inputs()
        inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask}
        return config, inputs_dict

@require_torch
class DistilBertModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
    """Test case for the DistilBert model classes.

    Combines the shared ``ModelTesterMixin`` / ``PipelineTesterMixin`` behaviour with DistilBert
    specific shape checks (delegated to ``DistilBertModelTester``), a TorchScript save/load test and
    flash-attention-2 comparison tests.
    """

    # Model classes exercised by this test case; None when torch is unavailable.
    all_model_classes = (
        (
            DistilBertModel,
            DistilBertForMaskedLM,
            DistilBertForMultipleChoice,
            DistilBertForQuestionAnswering,
            DistilBertForSequenceClassification,
            DistilBertForTokenClassification,
        )
        if is_torch_available()
        else None
    )
    # Pipeline task name -> model class used for that task; empty when torch is unavailable.
    pipeline_model_mapping = (
        {
            "feature-extraction": DistilBertModel,
            "fill-mask": DistilBertForMaskedLM,
            "question-answering": DistilBertForQuestionAnswering,
            "text-classification": DistilBertForSequenceClassification,
            "token-classification": DistilBertForTokenClassification,
            "zero-shot": DistilBertForSequenceClassification,
        }
        if is_torch_available()
        else {}
    )
    # Flags below are presumably read by the shared mixins to toggle common tests -- verify there.
    fx_compatible = True
    test_pruning = True
    test_resize_embeddings = True
    test_resize_position_embeddings = True

    def setUp(self):
        """Create the model tester and the config tester used by the tests below."""
        self.model_tester = DistilBertModelTester(self)
        self.config_tester = ConfigTester(self, config_class=DistilBertConfig, dim=37)

    def test_config(self):
        """Run the common configuration tests for DistilBertConfig."""
        self.config_tester.run_common_tests()

    def test_distilbert_model(self):
        """Check the base model's output shape."""
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_distilbert_model(*config_and_inputs)

    def test_for_masked_lm(self):
        """Check the masked-LM head's output shape."""
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_distilbert_for_masked_lm(*config_and_inputs)

    def test_for_question_answering(self):
        """Check the question-answering head's output shapes."""
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_distilbert_for_question_answering(*config_and_inputs)

    def test_for_sequence_classification(self):
        """Check the sequence-classification head's output shape."""
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_distilbert_for_sequence_classification(*config_and_inputs)

    def test_for_token_classification(self):
        """Check the token-classification head's output shape."""
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_distilbert_for_token_classification(*config_and_inputs)

    def test_for_multiple_choice(self):
        """Check the multiple-choice head's output shape."""
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_distilbert_for_multiple_choice(*config_and_inputs)

    @slow
    def test_model_from_pretrained(self):
        """Load the first checkpoint of the archive list and check that a model is returned."""
        for model_name in DISTILBERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
            model = DistilBertModel.from_pretrained(model_name)
            self.assertIsNotNone(model)

    @slow
    @require_torch_accelerator
    def test_torchscript_device_change(self):
        """Trace each model on CPU inputs, save it, reload it onto ``torch_device`` and run it there."""
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        config.torchscript = True
        for model_class in self.all_model_classes:
            # DistilBertForMultipleChoice behaves incorrectly in JIT environments.
            # Skip only that class: a `return` here would silently drop every class listed after it.
            if model_class == DistilBertForMultipleChoice:
                continue

            model = model_class(config=config)

            # Keep the shared inputs_dict untouched so each class starts from the same inputs.
            class_inputs = self._prepare_for_class(inputs_dict, model_class)
            traced_model = torch.jit.trace(
                model, (class_inputs["input_ids"].to("cpu"), class_inputs["attention_mask"].to("cpu"))
            )

            with tempfile.TemporaryDirectory() as tmp:
                traced_path = os.path.join(tmp, "traced_model.pt")
                torch.jit.save(traced_model, traced_path)
                loaded = torch.jit.load(traced_path, map_location=torch_device)
                loaded(class_inputs["input_ids"].to(torch_device), class_inputs["attention_mask"].to(torch_device))

    def _check_flash_attn_2_matches_default(self, compared_rows):
        """Compare last hidden states of a flash-attention-2 model and a non-flash model.

        For every model class, a freshly initialised model is saved, then reloaded twice in
        bfloat16 (``use_flash_attention_2=True`` and ``False``). Both copies are run on the same
        fixed input, first without and then with an attention mask.

        Args:
            compared_rows: slice applied to dimension 0 of the last hidden states before the
                masked comparison; the unmasked comparison always uses the full tensors.
        """
        import torch

        for model_class in self.all_model_classes:
            dummy_input = torch.LongTensor(
                [
                    [1, 2, 3, 4],
                    [1, 2, 8, 9],
                    [1, 2, 11, 12],
                    [1, 2, 13, 14],
                ]
            ).to(torch_device)
            # NOTE(review): the mask zeroes the FIRST position of every row for both callers,
            # including the "padding_right" test -- confirm whether that is intended.
            dummy_attention_mask = torch.LongTensor(
                [
                    [0, 1, 1, 1],
                    [0, 1, 1, 1],
                    [0, 1, 1, 1],
                    [0, 1, 1, 1],
                ]
            ).to(torch_device)

            # Only the config is needed; the generated inputs are replaced by the fixed tensors above.
            config, _ = self.model_tester.prepare_config_and_inputs_for_common()
            model = model_class(config)

            with tempfile.TemporaryDirectory() as tmpdirname:
                model.save_pretrained(tmpdirname)
                model_fa = model_class.from_pretrained(
                    tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=True
                )
                model_fa.to(torch_device)

                model = model_class.from_pretrained(
                    tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=False
                )
                model.to(torch_device)

                logits = model(dummy_input, output_hidden_states=True).hidden_states[-1]
                logits_fa = model_fa(dummy_input, output_hidden_states=True).hidden_states[-1]

                self.assertTrue(torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2))

                output_fa = model_fa(dummy_input, attention_mask=dummy_attention_mask, output_hidden_states=True)
                logits_fa = output_fa.hidden_states[-1]

                output = model(dummy_input, attention_mask=dummy_attention_mask, output_hidden_states=True)
                logits = output.hidden_states[-1]

                # NOTE(review): the slice indexes dimension 0 of the hidden states, which looks like
                # the batch axis rather than the padded sequence position -- confirm.
                self.assertTrue(
                    torch.allclose(logits_fa[compared_rows], logits[compared_rows], atol=4e-2, rtol=4e-2)
                )

    # Because DistilBertForMultipleChoice requires inputs with different shapes we need to override this test.
    @require_flash_attn
    @require_torch_accelerator
    @mark.flash_attn_test
    @slow
    def test_flash_attn_2_inference(self):
        """Flash-attention-2 vs. default attention; the masked comparison uses rows ``[1:]``."""
        self._check_flash_attn_2_matches_default(slice(1, None))

    # Because DistilBertForMultipleChoice requires inputs with different shapes we need to override this test.
    @require_flash_attn
    @require_torch_accelerator
    @mark.flash_attn_test
    @slow
    def test_flash_attn_2_inference_padding_right(self):
        """Flash-attention-2 vs. default attention; the masked comparison uses rows ``[:-1]``."""
        self._check_flash_attn_2_matches_default(slice(None, -1))


@require_torch
class DistilBertModelIntergrationTest(unittest.TestCase):
    """Integration test comparing a pretrained DistilBertModel's output with stored reference values."""

    @slow
    def test_inference_no_head_absolute_embedding(self):
        """Run ``distilbert-base-uncased`` on one fixed sequence and check shape and a value slice."""
        model = DistilBertModel.from_pretrained("distilbert-base-uncased")
        input_ids = torch.tensor([[0, 345, 232, 328, 740, 140, 1695, 69, 6078, 1588, 2]])
        # The first position is masked out; every other position is attended to.
        attention_mask = torch.tensor([[0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])
        with torch.no_grad():
            output = model(input_ids, attention_mask=attention_mask)[0]

        expected_shape = torch.Size((1, 11, 768))
        self.assertEqual(output.shape, expected_shape)

        # Reference values for positions 1..3 and hidden units 1..3 of the first output.
        expected_slice = torch.tensor(
            [[[-0.1639, 0.3299, 0.1648], [-0.1746, 0.3289, 0.1710], [-0.1884, 0.3357, 0.1810]]]
        )
        self.assertTrue(torch.allclose(output[:, 1:4, 1:4], expected_slice, atol=1e-4))