"""Tests for torchvision.io.image: JPEG/PNG decoding and encoding, and raw file I/O."""
import glob
import io
import os
import unittest

import numpy as np
import torch
from PIL import Image
from torchvision.io.image import (
    decode_png, decode_jpeg, encode_jpeg, write_jpeg, decode_image, read_file,
    encode_png, write_png, write_file, ImageReadMode)

from common_utils import get_tmp_dir

# Test assets live in an "assets" folder next to this file.
IMAGE_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets")
FAKEDATA_DIR = os.path.join(IMAGE_ROOT, "fakedata")
IMAGE_DIR = os.path.join(FAKEDATA_DIR, "imagefolder")
# Intentionally damaged JPEGs used by test_damaged_images.
DAMAGED_JPEG = os.path.join(IMAGE_ROOT, 'damaged_jpeg')
# Source JPEGs whose "jpeg_write" subfolder holds Pillow-encoded reference outputs.
ENCODE_JPEG = os.path.join(IMAGE_ROOT, "encode_jpeg")


def get_images(directory, img_ext):
    """Yield every file under ``directory`` (recursively) whose name ends in ``img_ext``.

    Files whose immediate parent folder is named ``damaged_jpeg`` or
    ``jpeg_write`` are skipped. Those folders hold deliberately corrupt inputs
    and encoder reference outputs, not regular test images. Paths are yielded
    in sorted order so test runs are reproducible.

    Args:
        directory: root folder to search. It must exist (checked with an
            assert when iteration starts).
        img_ext: extension including the leading dot, e.g. ``".jpg"``.
    """
    assert os.path.isdir(directory)

    excluded_dirs = {'damaged_jpeg', 'jpeg_write'}
    pattern = os.path.join(directory, '**', f'*{img_ext}')
    for path in sorted(glob.glob(pattern, recursive=True)):
        # Use os.path to get the parent folder name rather than splitting on
        # os.sep, which is fragile when a path mixes '/' and '\\' separators.
        if os.path.basename(os.path.dirname(path)) not in excluded_dirs:
            yield path


def pil_read_image(img_path):
    """Open ``img_path`` with Pillow and return its pixel array as a torch tensor."""
    with Image.open(img_path) as pil_img:
        # np.array copies the pixels, so the tensor stays valid after the file is closed.
        pixels = np.array(pil_img)
    return torch.from_numpy(pixels)


def normalize_dimensions(img_pil):
    """Rearrange a Pillow-layout image tensor into channels-first layout.

    A 3-D tensor is treated as (H, W, C) and permuted to (C, H, W). Any other
    tensor gets a new leading axis of size 1 (e.g. (H, W) -> (1, H, W)).
    """
    has_channel_axis = img_pil.dim() == 3
    return img_pil.permute(2, 0, 1) if has_channel_axis else img_pil.unsqueeze(0)


class ImageTester(unittest.TestCase):
    """Tests for the JPEG/PNG codecs and raw file helpers in ``torchvision.io.image``.

    Decoded pixels are compared against Pillow's reading of the same file.
    Encoders are checked byte-for-byte against stored reference files, or by
    decoding their output again with Pillow.
    """

    def test_decode_jpeg(self):
        """``decode_image`` on JPEGs should closely match Pillow in every read mode."""
        # (Pillow conversion mode, equivalent ImageReadMode); None keeps the file's own mode.
        conversion = [(None, ImageReadMode.UNCHANGED), ("L", ImageReadMode.GRAY), ("RGB", ImageReadMode.RGB)]
        for img_path in get_images(IMAGE_ROOT, ".jpg"):
            for pil_mode, mode in conversion:
                with Image.open(img_path) as img:
                    is_cmyk = img.mode == "CMYK"
                    if pil_mode is not None:
                        if is_cmyk:
                            # libjpeg does not support the conversion
                            continue
                        img = img.convert(pil_mode)
                    img_pil = torch.from_numpy(np.array(img))
                    if is_cmyk:
                        # flip the colors to match libjpeg
                        img_pil = 255 - img_pil

                img_pil = normalize_dimensions(img_pil)
                data = read_file(img_path)
                img_ljpeg = decode_image(data, mode=mode)

                # Permit a small variation on pixel values to account for implementation
                # differences between Pillow and LibJPEG. Casting to float32 first keeps
                # the uint8 subtraction from wrapping around.
                abs_mean_diff = (img_ljpeg.type(torch.float32) - img_pil).abs().mean().item()
                self.assertLess(abs_mean_diff, 2, msg=f"{img_path} ({mode})")

        with self.assertRaisesRegex(RuntimeError, "Expected a non empty 1-dimensional tensor"):
            decode_jpeg(torch.empty((100, 1), dtype=torch.uint8))

        with self.assertRaisesRegex(RuntimeError, "Expected a torch.uint8 tensor"):
            decode_jpeg(torch.empty((100,), dtype=torch.float16))

        # 100 zero bytes are not a JPEG stream. Zeros (instead of torch.empty)
        # make the input deterministic rather than whatever memory held.
        with self.assertRaises(RuntimeError):
            decode_jpeg(torch.zeros((100,), dtype=torch.uint8))

    def test_damaged_images(self):
        """A JPEG with bad Huffman data must decode; truncated ``corrupt*.jpg`` files must raise."""
        # Test image with bad Huffman encoding (should not raise)
        bad_huff = read_file(os.path.join(DAMAGED_JPEG, 'bad_huffman.jpg'))
        try:
            decode_jpeg(bad_huff)
        except RuntimeError as exc:
            self.fail(f"decode_jpeg raised on bad_huffman.jpg: {exc}")

        # Truncated images should raise an exception
        truncated_images = sorted(glob.glob(os.path.join(DAMAGED_JPEG, 'corrupt*.jpg')))
        # Without this guard the loop below would pass vacuously if the assets were missing.
        self.assertTrue(truncated_images, msg=f"no corrupt*.jpg files found in {DAMAGED_JPEG}")
        for image_path in truncated_images:
            data = read_file(image_path)
            with self.subTest(image_path=image_path), self.assertRaises(RuntimeError):
                decode_jpeg(data)

    def test_encode_jpeg(self):
        """``encode_jpeg`` at quality 75 must reproduce Pillow's reference bytes exactly."""
        for img_path in get_images(ENCODE_JPEG, ".jpg"):
            dirname = os.path.dirname(img_path)
            filename, _ = os.path.splitext(os.path.basename(img_path))
            write_folder = os.path.join(dirname, 'jpeg_write')
            expected_file = os.path.join(
                write_folder, '{0}_pil.jpg'.format(filename))
            img = decode_jpeg(read_file(img_path))

            with open(expected_file, 'rb') as f:
                pil_bytes = torch.as_tensor(list(f.read()), dtype=torch.uint8)
            for src_img in [img, img.contiguous()]:
                # PIL sets jpeg quality to 75 by default
                jpeg_bytes = encode_jpeg(src_img, quality=75)
                self.assertTrue(jpeg_bytes.equal(pil_bytes))

        with self.assertRaisesRegex(
                RuntimeError, "Input tensor dtype should be uint8"):
            encode_jpeg(torch.empty((3, 100, 100), dtype=torch.float32))

        with self.assertRaisesRegex(
                ValueError, "Image quality should be a positive number "
                            "between 1 and 100"):
            encode_jpeg(torch.empty((3, 100, 100), dtype=torch.uint8), quality=-1)

        with self.assertRaisesRegex(
                ValueError, "Image quality should be a positive number "
                            "between 1 and 100"):
            encode_jpeg(torch.empty((3, 100, 100), dtype=torch.uint8), quality=101)

        with self.assertRaisesRegex(
                RuntimeError, "The number of channels should be 1 or 3, got: 5"):
            encode_jpeg(torch.empty((5, 100, 100), dtype=torch.uint8))

        with self.assertRaisesRegex(
                RuntimeError, "Input data should be a 3-dimensional tensor"):
            encode_jpeg(torch.empty((1, 3, 100, 100), dtype=torch.uint8))

        with self.assertRaisesRegex(
                RuntimeError, "Input data should be a 3-dimensional tensor"):
            encode_jpeg(torch.empty((100, 100), dtype=torch.uint8))

    def test_write_jpeg(self):
        """``write_jpeg`` output must be byte-identical to Pillow's reference file."""
        with get_tmp_dir() as d:
            for img_path in get_images(ENCODE_JPEG, ".jpg"):
                data = read_file(img_path)
                img = decode_jpeg(data)

                basedir = os.path.dirname(img_path)
                filename, _ = os.path.splitext(os.path.basename(img_path))
                torch_jpeg = os.path.join(
                    d, '{0}_torch.jpg'.format(filename))
                pil_jpeg = os.path.join(
                    basedir, 'jpeg_write', '{0}_pil.jpg'.format(filename))

                write_jpeg(img, torch_jpeg, quality=75)

                with open(torch_jpeg, 'rb') as f:
                    torch_bytes = f.read()

                with open(pil_jpeg, 'rb') as f:
                    pil_bytes = f.read()

                self.assertEqual(torch_bytes, pil_bytes)

    def test_decode_png(self):
        """``decode_image`` on PNGs should match Pillow for every read mode."""
        conversion = [(None, ImageReadMode.UNCHANGED), ("L", ImageReadMode.GRAY), ("LA", ImageReadMode.GRAY_ALPHA),
                      ("RGB", ImageReadMode.RGB), ("RGBA", ImageReadMode.RGB_ALPHA)]
        for img_path in get_images(FAKEDATA_DIR, ".png"):
            for pil_mode, mode in conversion:
                with Image.open(img_path) as img:
                    if pil_mode is not None:
                        img = img.convert(pil_mode)
                    img_pil = torch.from_numpy(np.array(img))

                img_pil = normalize_dimensions(img_pil)
                data = read_file(img_path)
                img_lpng = decode_image(data, mode=mode)

                # Unconverted images must match exactly. Colour conversions may
                # differ by one level between Pillow and libpng.
                tol = 0 if pil_mode is None else 1
                # Compare in float: subtracting uint8 tensors would wrap around.
                self.assertTrue(
                    torch.allclose(img_lpng.float(), img_pil.float(), atol=tol),
                    msg=f"{img_path} ({mode})")

        with self.assertRaises(RuntimeError):
            decode_png(torch.empty((), dtype=torch.uint8))
        with self.assertRaises(RuntimeError):
            decode_png(torch.randint(3, 5, (300,), dtype=torch.uint8))

    def test_encode_png(self):
        """``encode_png`` output must decode (via Pillow) back to the original pixels."""
        for img_path in get_images(IMAGE_DIR, '.png'):
            with Image.open(img_path) as pil_image:
                img_pil = normalize_dimensions(torch.from_numpy(np.array(pil_image)))
            png_buf = encode_png(img_pil, compression_level=6)

            with Image.open(io.BytesIO(png_buf.numpy().tobytes())) as rec_pil:
                rec_img = normalize_dimensions(torch.from_numpy(np.array(rec_pil)))

            self.assertTrue(img_pil.equal(rec_img))

        with self.assertRaisesRegex(
                RuntimeError, "Input tensor dtype should be uint8"):
            encode_png(torch.empty((3, 100, 100), dtype=torch.float32))

        with self.assertRaisesRegex(
                RuntimeError, "Compression level should be between 0 and 9"):
            encode_png(torch.empty((3, 100, 100), dtype=torch.uint8),
                       compression_level=-1)

        with self.assertRaisesRegex(
                RuntimeError, "Compression level should be between 0 and 9"):
            encode_png(torch.empty((3, 100, 100), dtype=torch.uint8),
                       compression_level=10)

        with self.assertRaisesRegex(
                RuntimeError, "The number of channels should be 1 or 3, got: 5"):
            encode_png(torch.empty((5, 100, 100), dtype=torch.uint8))

    def test_write_png(self):
        """A file written by ``write_png`` must read back (via Pillow) as the original pixels."""
        with get_tmp_dir() as d:
            for img_path in get_images(IMAGE_DIR, '.png'):
                with Image.open(img_path) as pil_image:
                    img_pil = normalize_dimensions(torch.from_numpy(np.array(pil_image)))

                filename, _ = os.path.splitext(os.path.basename(img_path))
                torch_png = os.path.join(d, '{0}_torch.png'.format(filename))
                write_png(img_pil, torch_png, compression_level=6)
                # Close the handle before the temp dir is removed (matters on Windows).
                with Image.open(torch_png) as saved_pil:
                    saved_image = normalize_dimensions(torch.from_numpy(np.array(saved_pil)))

                self.assertTrue(img_pil.equal(saved_image))

    def _check_read_file(self, fname):
        """Write known bytes to ``fname`` in a temp dir and check ``read_file`` returns them."""
        with get_tmp_dir() as d:
            content = b'TorchVision\211\n'
            fpath = os.path.join(d, fname)
            with open(fpath, 'wb') as f:
                f.write(content)

            data = read_file(fpath)
            expected = torch.tensor(list(content), dtype=torch.uint8)
            self.assertTrue(data.equal(expected))
            os.unlink(fpath)

    def _check_write_file(self, fname):
        """Write known bytes to ``fname`` via ``write_file`` and check the file contents."""
        with get_tmp_dir() as d:
            content = b'TorchVision\211\n'
            fpath = os.path.join(d, fname)
            content_tensor = torch.tensor(list(content), dtype=torch.uint8)
            write_file(fpath, content_tensor)

            with open(fpath, 'rb') as f:
                saved_content = f.read()
            self.assertEqual(content, saved_content)
            os.unlink(fpath)

    def test_read_file(self):
        """``read_file`` returns raw bytes as uint8 and raises for a missing path."""
        self._check_read_file('test1.bin')

        with self.assertRaisesRegex(
                RuntimeError, "No such file or directory: 'tst'"):
            read_file('tst')

    def test_read_file_non_ascii(self):
        """``read_file`` works with a non-ASCII file name."""
        self._check_read_file('日本語(Japanese).bin')

    def test_write_file(self):
        """``write_file`` stores a uint8 tensor's bytes verbatim."""
        self._check_write_file('test1.bin')

    def test_write_file_non_ascii(self):
        """``write_file`` works with a non-ASCII file name."""
        self._check_write_file('日本語(Japanese).bin')


if __name__ == '__main__':
    # Allow running this module directly, e.g. `python test_image.py`.
    unittest.main()