# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing suite for the PyTorch ViT model."""

import unittest

from transformers import ViTConfig
from transformers.testing_utils import (
    require_accelerate,
    require_torch,
    require_torch_accelerator,
    require_torch_fp16,
    require_vision,
    slow,
    torch_device,
)
from transformers.utils import cached_property, is_torch_available, is_vision_available

from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
from ...test_pipeline_mixin import PipelineTesterMixin


if is_torch_available():
    import torch
    from torch import nn

    from transformers import ViTForImageClassification, ViTForMaskedImageModeling, ViTModel


if is_vision_available():
    from PIL import Image

    from transformers import ViTImageProcessor


class ViTModelTester:
    """Build a tiny ViT config plus random inputs and run shape checks against the ViT model classes.

    Args:
        parent: the test case that owns this tester; its `assertEqual` is used for every check.
        batch_size, image_size, patch_size, num_channels: shape of the random `pixel_values` batch,
            also passed into the config.
        use_labels: whether `prepare_config_and_inputs` also draws random classification labels.
        The remaining arguments are stored as attributes. Most are forwarded into `ViTConfig` by
        `get_config`; the others (e.g. `is_training`, `scope`, `mask_ratio`) are presumably read by
        the shared test mixins.
    """

    def __init__(
        self,
        parent,
        batch_size=13,
        image_size=30,
        patch_size=2,
        num_channels=3,
        is_training=True,
        use_labels=True,
        hidden_size=32,
        num_hidden_layers=2,
        num_attention_heads=4,
        intermediate_size=37,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        type_sequence_label_size=10,
        initializer_range=0.02,
        scope=None,
        encoder_stride=2,
        mask_ratio=0.5,
        attn_implementation="eager",
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.is_training = is_training
        self.use_labels = use_labels
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.type_sequence_label_size = type_sequence_label_size
        self.initializer_range = initializer_range
        self.scope = scope
        self.encoder_stride = encoder_stride
        self.attn_implementation = attn_implementation

        # in ViT, the seq length equals the number of patches + 1 (we add 1 for the [CLS] token)
        num_patches = (image_size // patch_size) ** 2
        self.seq_length = num_patches + 1
        self.mask_ratio = mask_ratio
        # NOTE(review): num_masks is derived from seq_length (which includes [CLS]), while
        # mask_length counts patches only; presumably consumed by shared mixin tests — confirm.
        self.num_masks = int(mask_ratio * self.seq_length)
        self.mask_length = num_patches

    def prepare_config_and_inputs(self):
        """Return `(config, pixel_values, labels)`; `labels` is None unless `use_labels` is set."""
        pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])

        labels = None
        if self.use_labels:
            labels = ids_tensor([self.batch_size], self.type_sequence_label_size)

        config = self.get_config()

        return config, pixel_values, labels

    def get_config(self):
        """Build a fresh `ViTConfig` from the tester's hyper-parameters."""
        return ViTConfig(
            image_size=self.image_size,
            patch_size=self.patch_size,
            num_channels=self.num_channels,
            hidden_size=self.hidden_size,
            num_hidden_layers=self.num_hidden_layers,
            num_attention_heads=self.num_attention_heads,
            intermediate_size=self.intermediate_size,
            hidden_act=self.hidden_act,
            hidden_dropout_prob=self.hidden_dropout_prob,
            attention_probs_dropout_prob=self.attention_probs_dropout_prob,
            is_decoder=False,
            initializer_range=self.initializer_range,
            encoder_stride=self.encoder_stride,
            attn_implementation=self.attn_implementation,
        )

    def create_and_check_model(self, config, pixel_values, labels):
        """Check that `ViTModel` returns hidden states of shape (batch_size, seq_length, hidden_size)."""
        model = ViTModel(config=config)
        model.to(torch_device)
        model.eval()
        result = model(pixel_values)
        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))

    def create_and_check_for_masked_image_modeling(self, config, pixel_values, labels):
        """Check the reconstruction shape of `ViTForMaskedImageModeling` for RGB and greyscale input.

        Note: sets `config.num_channels = 1` in place for the greyscale check.
        """
        model = ViTForMaskedImageModeling(config=config)
        model.to(torch_device)
        model.eval()
        result = model(pixel_values)
        self.parent.assertEqual(
            result.reconstruction.shape, (self.batch_size, self.num_channels, self.image_size, self.image_size)
        )

        # test greyscale images
        config.num_channels = 1
        model = ViTForMaskedImageModeling(config)
        model.to(torch_device)
        model.eval()

        pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size])
        result = model(pixel_values)
        self.parent.assertEqual(result.reconstruction.shape, (self.batch_size, 1, self.image_size, self.image_size))

    def create_and_check_for_image_classification(self, config, pixel_values, labels):
        """Check the logits shape of `ViTForImageClassification`, with labels and for greyscale input.

        Note: mutates `config` in place (`num_labels`, then `num_channels = 1`).
        """
        config.num_labels = self.type_sequence_label_size
        model = ViTForImageClassification(config)
        model.to(torch_device)
        model.eval()
        result = model(pixel_values, labels=labels)
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))

        # test greyscale images
        config.num_channels = 1
        model = ViTForImageClassification(config)
        model.to(torch_device)
        model.eval()

        pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size])
        result = model(pixel_values)
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))

    def prepare_config_and_inputs_for_common(self):
        """Return `(config, inputs_dict)` for the shared (common) model tests; labels are dropped."""
        config, pixel_values, _ = self.prepare_config_and_inputs()
        inputs_dict = {"pixel_values": pixel_values}
        return config, inputs_dict


@require_torch
class ViTModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
    """
    Model tests for ViT. Some tests of test_modeling_common.py are overridden or skipped here, as ViT does not use
    input_ids, inputs_embeds, attention_mask and seq_length.
    """

    all_model_classes = (
        (
            ViTModel,
            ViTForImageClassification,
            ViTForMaskedImageModeling,
        )
        if is_torch_available()
        else ()
    )
    pipeline_model_mapping = (
        {"image-feature-extraction": ViTModel, "image-classification": ViTForImageClassification}
        if is_torch_available()
        else {}
    )
    fx_compatible = True

    test_pruning = False
    test_resize_embeddings = False
    test_head_masking = False

    def setUp(self):
        self.model_tester = ViTModelTester(self)
        self.config_tester = ConfigTester(self, config_class=ViTConfig, has_text_modality=False, hidden_size=37)

    # The two literals are implicitly concatenated, so the first one must end with a space.
    @unittest.skip(
        "Since `torch==2.3+cu121`, although this test passes, many subsequent tests have `CUDA error: misaligned address`. "
        "If `nvidia-xxx-cu118` are also installed, no failure (even with `torch==2.3+cu121`)."
    )
    def test_multi_gpu_data_parallel_forward(self):
        super().test_multi_gpu_data_parallel_forward()

    def test_config(self):
        self.config_tester.run_common_tests()

    @unittest.skip(reason="ViT does not use inputs_embeds")
    def test_inputs_embeds(self):
        pass

    def test_model_get_set_embeddings(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            self.assertIsInstance(model.get_input_embeddings(), nn.Module)
            # A model may expose no output embeddings at all; when it does, they must be a linear layer.
            output_embeddings = model.get_output_embeddings()
            self.assertTrue(output_embeddings is None or isinstance(output_embeddings, nn.Linear))

    def test_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model(*config_and_inputs)

    def test_for_masked_image_modeling(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_for_masked_image_modeling(*config_and_inputs)

    def test_for_image_classification(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_for_image_classification(*config_and_inputs)

    @slow
    def test_model_from_pretrained(self):
        model_name = "google/vit-base-patch16-224"
        model = ViTModel.from_pretrained(model_name)
        self.assertIsNotNone(model)


# We will verify our results on an image of cute cats
def prepare_img():
    """Open and return the COCO fixture image (cats) used by the integration tests.

    The path is relative to the current working directory, so the tests must be run from the repository root.
    """
    image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
    return image


@require_torch
@require_vision
class ViTModelIntegrationTest(unittest.TestCase):
    """Integration tests that run pretrained ViT checkpoints on the COCO fixture image."""

    @cached_property
    def default_image_processor(self):
        # Falls back to None when `is_vision_available()` is false.
        return ViTImageProcessor.from_pretrained("google/vit-base-patch16-224") if is_vision_available() else None

    @slow
    def test_inference_image_classification_head(self):
        model = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224").to(torch_device)

        image_processor = self.default_image_processor
        image = prepare_img()
        inputs = image_processor(images=image, return_tensors="pt").to(torch_device)

        # forward pass
        with torch.no_grad():
            outputs = model(**inputs)

        # verify the logits
        expected_shape = torch.Size((1, 1000))
        self.assertEqual(outputs.logits.shape, expected_shape)

        expected_slice = torch.tensor([-0.2744, 0.8215, -0.0836]).to(torch_device)

        # rtol=1e-5 is torch.allclose's default, so the acceptance criterion is unchanged;
        # assert_close additionally reports the mismatching elements when it fails.
        torch.testing.assert_close(outputs.logits[0, :3], expected_slice, rtol=1e-5, atol=1e-4)

    @slow
    def test_inference_interpolate_pos_encoding(self):
        # ViT models have an `interpolate_pos_encoding` argument in their forward method,
        # allowing to interpolate the pre-trained position embeddings in order to use
        # the model on higher resolutions. The DINO model by Facebook AI leverages this
        # to visualize self-attention on higher resolution images.
        model = ViTModel.from_pretrained("facebook/dino-vits8").to(torch_device)

        image_processor = ViTImageProcessor.from_pretrained("facebook/dino-vits8", size=480)
        image = prepare_img()
        inputs = image_processor(images=image, return_tensors="pt")
        pixel_values = inputs.pixel_values.to(torch_device)

        # forward pass
        with torch.no_grad():
            outputs = model(pixel_values, interpolate_pos_encoding=True)

        # verify the last hidden state: (480 / 8) ** 2 = 3600 patches + 1 [CLS] token
        expected_shape = torch.Size((1, 3601, 384))
        self.assertEqual(outputs.last_hidden_state.shape, expected_shape)

        expected_slice = torch.tensor(
            [[4.2340, 4.3906, -6.6692], [4.5463, 1.8928, -6.7257], [4.4429, 0.8496, -5.8585]]
        ).to(torch_device)

        torch.testing.assert_close(outputs.last_hidden_state[0, :3, :3], expected_slice, rtol=1e-5, atol=1e-4)

    @slow
    @require_accelerate
    @require_torch_accelerator
    @require_torch_fp16
    def test_inference_fp16(self):
        r"""
        A small test to make sure that inference works in half precision without any problem.
        """
        model = ViTModel.from_pretrained("facebook/dino-vits8", torch_dtype=torch.float16, device_map="auto")
        image_processor = self.default_image_processor

        image = prepare_img()
        inputs = image_processor(images=image, return_tensors="pt")
        # NOTE(review): pixel_values are not cast to float16 here; this relies on the model casting
        # its inputs to the weight dtype — confirm in ViTModel.forward.
        pixel_values = inputs.pixel_values.to(torch_device)

        # forward pass to make sure inference works in fp16
        with torch.no_grad():
            _ = model(pixel_values)