# coding=utf-8
# Copyright 2021 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import json
import os
import tempfile
import unittest

import numpy as np
from datasets import load_dataset

from transformers.testing_utils import require_torch, require_vision, slow
from transformers.utils import is_torch_available, is_vision_available

from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs


if is_torch_available():
    import torch

if is_vision_available():
    from PIL import Image

    from transformers import ImageGPTImageProcessor


class ImageGPTImageProcessingTester(unittest.TestCase):
    """Configuration holder shared by the ImageGPT image processor tests.

    Although it subclasses ``unittest.TestCase``, it defines no ``test_*`` methods.
    It stores the processor settings, builds the kwargs used to construct the
    processor, and produces random inputs and expected output shapes for ``parent``.
    """

    def __init__(
        self,
        parent,
        batch_size=7,
        num_channels=3,
        image_size=18,
        min_resolution=30,
        max_resolution=400,
        do_resize=True,
        size=None,
        do_normalize=True,
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.num_channels = num_channels
        self.image_size = image_size
        self.min_resolution = min_resolution
        self.max_resolution = max_resolution
        self.do_resize = do_resize
        # Fall back to an 18x18 target size when the caller does not supply one.
        self.size = {"height": 18, "width": 18} if size is None else size
        self.do_normalize = do_normalize

    def prepare_image_processor_dict(self):
        """Return the keyword arguments used to build the processor under test."""
        # Two clusters keep things simple; each has 3 components, one per input
        # channel (the 4-channel mixin test is skipped for this reason).
        two_clusters = [
            [0.8866443634033203, 0.6618829369544983, 0.3891746401786804],
            [-0.6042559146881104, -0.02295008860528469, 0.5423797369003296],
        ]
        return dict(
            clusters=np.asarray(two_clusters),
            do_resize=self.do_resize,
            size=self.size,
            do_normalize=self.do_normalize,
        )

    def expected_output_image_shape(self, images):
        """Return the per-image shape of ``input_ids``: a flat ``height * width`` sequence.

        ``images`` is accepted for interface compatibility with the common mixin but
        is not used; the shape depends only on the configured ``size``.
        """
        height, width = self.size["height"], self.size["width"]
        return (height * width,)

    def prepare_image_inputs(self, equal_resolution=False, numpify=False, torchify=False):
        """Build a batch of random images via the shared ``prepare_image_inputs`` helper."""
        generator_kwargs = {
            "batch_size": self.batch_size,
            "num_channels": self.num_channels,
            "min_resolution": self.min_resolution,
            "max_resolution": self.max_resolution,
            "equal_resolution": equal_resolution,
            "numpify": numpify,
            "torchify": torchify,
        }
        return prepare_image_inputs(**generator_kwargs)


@require_torch
@require_vision
class ImageGPTImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
    """Unit tests for ``ImageGPTImageProcessor``.

    This processor is configured with a ``clusters`` array and returns ``input_ids``
    rather than ``pixel_values``. Several ``ImageProcessingTestMixin`` tests are
    therefore overridden here, and ``clusters`` is compared element-wise.
    """

    image_processing_class = ImageGPTImageProcessor if is_vision_available() else None

    def setUp(self):
        self.image_processor_tester = ImageGPTImageProcessingTester(self)

    @property
    def image_processor_dict(self):
        return self.image_processor_tester.prepare_image_processor_dict()

    def _assert_image_processor_dicts_equal(self, expected, actual):
        """Assert that every entry of ``expected`` matches the same key in ``actual``.

        ``clusters`` holds an array (a list after a JSON round-trip), so it is compared
        with ``np.array_equal``; every other value is compared with ``assertEqual``.
        A key missing from ``actual`` raises ``KeyError`` and fails the test.
        """
        for key, value in expected.items():
            if key == "clusters":
                self.assertTrue(np.array_equal(value, actual[key]), msg=f"mismatch for key {key!r}")
            else:
                self.assertEqual(actual[key], value, msg=f"mismatch for key {key!r}")

    def _check_input_ids_shapes(self, image_processing, image_inputs):
        """Assert the ``input_ids`` shape for one unbatched image and for the full batch."""
        expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs)
        batch_size = self.image_processor_tester.batch_size

        # Test not batched input
        encoded_images = image_processing(image_inputs[0], return_tensors="pt").input_ids
        self.assertEqual(tuple(encoded_images.shape), (1, *expected_output_image_shape))

        # Test batched
        encoded_images = image_processing(image_inputs, return_tensors="pt").input_ids
        self.assertEqual(tuple(encoded_images.shape), (batch_size, *expected_output_image_shape))

    def test_image_processor_properties(self):
        image_processing = self.image_processing_class(**self.image_processor_dict)
        self.assertTrue(hasattr(image_processing, "clusters"))
        self.assertTrue(hasattr(image_processing, "do_resize"))
        self.assertTrue(hasattr(image_processing, "size"))
        self.assertTrue(hasattr(image_processing, "do_normalize"))

    def test_image_processor_from_dict_with_kwargs(self):
        image_processor = self.image_processing_class.from_dict(self.image_processor_dict)
        self.assertEqual(image_processor.size, {"height": 18, "width": 18})

        image_processor = self.image_processing_class.from_dict(self.image_processor_dict, size=42)
        self.assertEqual(image_processor.size, {"height": 42, "width": 42})

    def test_image_processor_to_json_string(self):
        image_processor = self.image_processing_class(**self.image_processor_dict)
        obj = json.loads(image_processor.to_json_string())
        self._assert_image_processor_dicts_equal(self.image_processor_dict, obj)

    def test_image_processor_to_json_file(self):
        image_processor_first = self.image_processing_class(**self.image_processor_dict)

        with tempfile.TemporaryDirectory() as tmpdirname:
            json_file_path = os.path.join(tmpdirname, "image_processor.json")
            image_processor_first.to_json_file(json_file_path)
            image_processor_second = self.image_processing_class.from_json_file(json_file_path).to_dict()

        # Compare the original against the reloaded processor (not against itself),
        # so a lossy JSON round-trip actually fails the test.
        self._assert_image_processor_dicts_equal(image_processor_first.to_dict(), image_processor_second)

    def test_image_processor_from_and_save_pretrained(self):
        image_processor_first = self.image_processing_class(**self.image_processor_dict)

        with tempfile.TemporaryDirectory() as tmpdirname:
            image_processor_first.save_pretrained(tmpdirname)
            image_processor_second = self.image_processing_class.from_pretrained(tmpdirname).to_dict()

        # Compare the original against the reloaded processor (not against itself),
        # so a lossy save/load round-trip actually fails the test.
        self._assert_image_processor_dicts_equal(image_processor_first.to_dict(), image_processor_second)

    @unittest.skip("ImageGPT requires clusters at initialization")
    def test_init_without_params(self):
        pass

    # Override the test from ImageProcessingTestMixin as ImageGPT model takes input_ids as input
    def test_call_pil(self):
        image_processing = self.image_processing_class(**self.image_processor_dict)
        # create random PIL images
        image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False)
        for image in image_inputs:
            self.assertIsInstance(image, Image.Image)

        self._check_input_ids_shapes(image_processing, image_inputs)

    # Override the test from ImageProcessingTestMixin as ImageGPT model takes input_ids as input
    def test_call_numpy(self):
        image_processing = self.image_processing_class(**self.image_processor_dict)
        # create random numpy tensors
        image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, numpify=True)
        for image in image_inputs:
            self.assertIsInstance(image, np.ndarray)

        self._check_input_ids_shapes(image_processing, image_inputs)

    @unittest.skip("ImageGPT assumes clusters for 3 channels")
    def test_call_numpy_4_channels(self):
        pass

    # Override the test from ImageProcessingTestMixin as ImageGPT model takes input_ids as input
    def test_call_pytorch(self):
        image_processing = self.image_processing_class(**self.image_processor_dict)
        # create random PyTorch tensors
        image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, torchify=True)
        for image in image_inputs:
            self.assertIsInstance(image, torch.Tensor)

        self._check_input_ids_shapes(image_processing, image_inputs)


def prepare_images():
    """Open rows 4 and 5 of the ``fixtures_image_utils`` test split as PIL images.

    Returns a two-element list, in row order, used by the integration test.
    """
    fixtures = load_dataset("hf-internal-testing/fixtures_image_utils", split="test")
    return [Image.open(fixtures[row]["file"]) for row in (4, 5)]


@require_vision
@require_torch
class ImageGPTImageProcessorIntegrationTest(unittest.TestCase):
    """Checks the pretrained ``openai/imagegpt-small`` processor against reference token ids."""

    @slow
    def test_image(self):
        image_processing = ImageGPTImageProcessor.from_pretrained("openai/imagegpt-small")
        images = prepare_images()

        # test non-batched: one image -> a (1, 1024) LongTensor of ids
        single_ids = image_processing(images[0], return_tensors="pt").input_ids
        self.assertIsInstance(single_ids, torch.LongTensor)
        self.assertEqual(single_ids.shape, (1, 1024))
        self.assertEqual(single_ids[0, :3].tolist(), [306, 191, 191])

        # test batched: both images -> a (2, 1024) LongTensor; check the tail of the 2nd row
        batch_ids = image_processing(images, return_tensors="pt").input_ids
        self.assertIsInstance(batch_ids, torch.LongTensor)
        self.assertEqual(batch_ids.shape, (2, 1024))
        self.assertEqual(batch_ids[1, -3:].tolist(), [303, 13, 13])