test_image.py 13.3 KB
Newer Older
1
import glob
2
3
import io
import os
4
5
import unittest

6
import pytest
7
import numpy as np
8
9
import torch
from PIL import Image
10
from common_utils import get_tmp_dir, needs_cuda
11
from _assert_utils import assert_equal
12

13
from torchvision.io.image import (
14
    decode_png, decode_jpeg, encode_jpeg, write_jpeg, decode_image, read_file,
15
    encode_png, write_png, write_file, ImageReadMode)
Francisco Massa's avatar
Francisco Massa committed
16

17
# Root directory of the test assets, located next to this test file.
IMAGE_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets")
# Sub-directory scanned for PNG files by the PNG decoding test.
FAKEDATA_DIR = os.path.join(IMAGE_ROOT, "fakedata")
# Sub-directory of FAKEDATA_DIR scanned by the PNG encoding/writing tests.
IMAGE_DIR = os.path.join(FAKEDATA_DIR, "imagefolder")
# Deliberately broken JPEG files used by the damaged-image test.
DAMAGED_JPEG = os.path.join(IMAGE_ROOT, "damaged_jpeg")
# JPEG inputs (and their reference encodings) used by the JPEG encoding/writing tests.
ENCODE_JPEG = os.path.join(IMAGE_ROOT, "encode_jpeg")
22
23
24
25


def get_images(directory, img_ext):
    """Yield the paths of all files under ``directory`` whose name ends with ``img_ext``.

    The search is recursive. Files whose immediate parent directory is named
    ``damaged_jpeg`` or ``jpeg_write`` are skipped.

    Args:
        directory: Existing directory to search.
        img_ext: File-name suffix to match, e.g. ``".jpg"``.

    Yields:
        Matching file paths, in the order returned by ``glob.glob``.

    Raises:
        AssertionError: If ``directory`` is not an existing directory (raised
            when the generator is first advanced).
    """
    assert os.path.isdir(directory)
    excluded_parents = ('damaged_jpeg', 'jpeg_write')
    image_paths = glob.glob(directory + f'/**/*{img_ext}', recursive=True)
    for path in image_paths:
        # Derive the parent directory name with os.path so that doubled or
        # mixed separators in the path cannot hide an excluded directory.
        if os.path.basename(os.path.dirname(path)) not in excluded_parents:
            yield path
30
31


32
33
34
35
36
37
38
39
40
41
42
43
44
def pil_read_image(img_path):
    """Open ``img_path`` with Pillow and return its pixel data as a torch tensor."""
    with Image.open(img_path) as pil_img:
        # np.array copies the pixels, so the tensor stays valid after the file is closed.
        pixels = np.array(pil_img)
    return torch.from_numpy(pixels)


def normalize_dimensions(img_pil):
    """Rearrange a tensor so that the channel axis comes first.

    A 3-dimensional input has its last axis moved to the front
    (``permute(2, 0, 1)``); any other input gets a new leading axis of size 1.
    """
    if img_pil.ndim == 3:
        return img_pil.permute(2, 0, 1)
    return img_pil.unsqueeze(0)


45
class ImageTester(unittest.TestCase):
    """Tests for the JPEG/PNG decoding and encoding ops and the raw file I/O helpers.

    Pillow serves as the reference, either by decoding the same file on the fly
    or through reference files stored in the asset directories.
    """

    def _check_read_file(self, fname):
        """Write a small binary payload to ``fname`` in a temp dir and read it back with ``read_file``."""
        content = b'TorchVision\211\n'
        with get_tmp_dir() as d:
            fpath = os.path.join(d, fname)
            with open(fpath, 'wb') as f:
                f.write(content)

            data = read_file(fpath)
            expected = torch.tensor(list(content), dtype=torch.uint8)
            assert_equal(data, expected)
            os.unlink(fpath)

    def _check_write_file(self, fname):
        """Write a small binary payload to ``fname`` with ``write_file`` and compare the bytes on disk."""
        content = b'TorchVision\211\n'
        with get_tmp_dir() as d:
            fpath = os.path.join(d, fname)
            content_tensor = torch.tensor(list(content), dtype=torch.uint8)
            write_file(fpath, content_tensor)

            with open(fpath, 'rb') as f:
                saved_content = f.read()
            self.assertEqual(content, saved_content)
            os.unlink(fpath)

    def test_decode_jpeg(self):
        """Decode every JPEG asset in each read mode and compare the result with Pillow."""
        conversion = [(None, ImageReadMode.UNCHANGED), ("L", ImageReadMode.GRAY), ("RGB", ImageReadMode.RGB)]
        for img_path in get_images(IMAGE_ROOT, ".jpg"):
            for pil_mode, mode in conversion:
                with Image.open(img_path) as img:
                    is_cmyk = img.mode == "CMYK"
                    if pil_mode is not None:
                        if is_cmyk:
                            # libjpeg does not support the conversion
                            continue
                        img = img.convert(pil_mode)
                    img_pil = torch.from_numpy(np.array(img))
                    if is_cmyk:
                        # flip the colors to match libjpeg
                        img_pil = 255 - img_pil

                img_pil = normalize_dimensions(img_pil)
                data = read_file(img_path)
                img_ljpeg = decode_image(data, mode=mode)

                # Permit a small variation on pixel values to account for implementation
                # differences between Pillow and LibJPEG.
                abs_mean_diff = (img_ljpeg.type(torch.float32) - img_pil).abs().mean().item()
                self.assertLess(abs_mean_diff, 2, msg=f"{img_path} decoded with mode={mode}")

        with self.assertRaisesRegex(RuntimeError, "Expected a non empty 1-dimensional tensor"):
            decode_jpeg(torch.empty((100, 1), dtype=torch.uint8))

        with self.assertRaisesRegex(RuntimeError, "Expected a torch.uint8 tensor"):
            decode_jpeg(torch.empty((100,), dtype=torch.float16))

        # Use a zero-filled buffer rather than uninitialised memory so the
        # "not a JPEG" input is the same on every run.
        with self.assertRaises(RuntimeError):
            decode_jpeg(torch.zeros(100, dtype=torch.uint8))

    def test_damaged_images(self):
        """Check that a bad-Huffman JPEG still decodes and that corrupt JPEGs raise."""
        # Test image with bad Huffman encoding (should not raise)
        bad_huff = read_file(os.path.join(DAMAGED_JPEG, 'bad_huffman.jpg'))
        try:
            decode_jpeg(bad_huff)
        except RuntimeError as exc:
            # Surface the decoder's message instead of a bare "False is not true".
            self.fail(f"decoding bad_huffman.jpg raised unexpectedly: {exc}")

        # Truncated images should raise an exception
        truncated_images = glob.glob(os.path.join(DAMAGED_JPEG, 'corrupt*.jpg'))
        for image_path in truncated_images:
            data = read_file(image_path)
            with self.assertRaises(RuntimeError, msg=image_path):
                decode_jpeg(data)

    def test_encode_jpeg(self):
        """Encode decoded JPEG assets and compare the bytes with the ``*_pil.jpg`` reference files."""
        for img_path in get_images(ENCODE_JPEG, ".jpg"):
            dirname = os.path.dirname(img_path)
            filename, _ = os.path.splitext(os.path.basename(img_path))
            write_folder = os.path.join(dirname, 'jpeg_write')
            expected_file = os.path.join(write_folder, f'{filename}_pil.jpg')
            img = decode_jpeg(read_file(img_path))

            with open(expected_file, 'rb') as f:
                pil_bytes = f.read()
            pil_bytes = torch.as_tensor(list(pil_bytes), dtype=torch.uint8)
            for src_img in [img, img.contiguous()]:
                # PIL sets jpeg quality to 75 by default
                jpeg_bytes = encode_jpeg(src_img, quality=75)
                assert_equal(jpeg_bytes, pil_bytes)

        with self.assertRaisesRegex(
                RuntimeError, "Input tensor dtype should be uint8"):
            encode_jpeg(torch.empty((3, 100, 100), dtype=torch.float32))

        # Both out-of-range qualities must be rejected with the same message.
        for bad_quality in (-1, 101):
            with self.assertRaisesRegex(
                    ValueError, "Image quality should be a positive number "
                                "between 1 and 100"):
                encode_jpeg(torch.empty((3, 100, 100), dtype=torch.uint8), quality=bad_quality)

        with self.assertRaisesRegex(
                RuntimeError, "The number of channels should be 1 or 3, got: 5"):
            encode_jpeg(torch.empty((5, 100, 100), dtype=torch.uint8))

        # Inputs with too many or too few dimensions.
        for bad_shape in [(1, 3, 100, 100), (100, 100)]:
            with self.assertRaisesRegex(
                    RuntimeError, "Input data should be a 3-dimensional tensor"):
                encode_jpeg(torch.empty(bad_shape, dtype=torch.uint8))

    def test_write_jpeg(self):
        """Write decoded JPEG assets with ``write_jpeg`` and compare the files with the ``*_pil.jpg`` references."""
        with get_tmp_dir() as d:
            for img_path in get_images(ENCODE_JPEG, ".jpg"):
                data = read_file(img_path)
                img = decode_jpeg(data)

                basedir = os.path.dirname(img_path)
                filename, _ = os.path.splitext(os.path.basename(img_path))
                torch_jpeg = os.path.join(d, f'{filename}_torch.jpg')
                pil_jpeg = os.path.join(basedir, 'jpeg_write', f'{filename}_pil.jpg')

                write_jpeg(img, torch_jpeg, quality=75)

                with open(torch_jpeg, 'rb') as f:
                    torch_bytes = f.read()

                with open(pil_jpeg, 'rb') as f:
                    pil_bytes = f.read()

                self.assertEqual(torch_bytes, pil_bytes)

    def test_decode_png(self):
        """Decode every PNG under the fake-data directory in each read mode and compare with Pillow."""
        conversion = [(None, ImageReadMode.UNCHANGED), ("L", ImageReadMode.GRAY), ("LA", ImageReadMode.GRAY_ALPHA),
                      ("RGB", ImageReadMode.RGB), ("RGBA", ImageReadMode.RGB_ALPHA)]
        for img_path in get_images(FAKEDATA_DIR, ".png"):
            for pil_mode, mode in conversion:
                with Image.open(img_path) as img:
                    if pil_mode is not None:
                        img = img.convert(pil_mode)
                    img_pil = torch.from_numpy(np.array(img))

                img_pil = normalize_dimensions(img_pil)
                data = read_file(img_path)
                img_lpng = decode_image(data, mode=mode)

                # Require an exact match when no conversion is applied and allow a
                # difference of 1 otherwise. The check is on ``pil_mode``: the
                # ``conversion`` list itself is never None.
                tol = 0 if pil_mode is None else 1
                self.assertTrue(img_lpng.allclose(img_pil, atol=tol),
                                msg=f"{img_path} decoded with mode={mode}")

        with self.assertRaises(RuntimeError):
            decode_png(torch.empty((), dtype=torch.uint8))
        with self.assertRaises(RuntimeError):
            decode_png(torch.randint(3, 5, (300,), dtype=torch.uint8))

    def test_encode_png(self):
        """Encode PNG assets with ``encode_png`` and check that Pillow decodes them back unchanged."""
        for img_path in get_images(IMAGE_DIR, '.png'):
            # Close the image files explicitly instead of leaking the handles.
            with Image.open(img_path) as pil_image:
                img_pil = torch.from_numpy(np.array(pil_image))
            img_pil = img_pil.permute(2, 0, 1)
            png_buf = encode_png(img_pil, compression_level=6)

            with Image.open(io.BytesIO(bytes(png_buf.tolist()))) as rec_image:
                rec_img = torch.from_numpy(np.array(rec_image))
            rec_img = rec_img.permute(2, 0, 1)

            assert_equal(img_pil, rec_img)

        with self.assertRaisesRegex(
                RuntimeError, "Input tensor dtype should be uint8"):
            encode_png(torch.empty((3, 100, 100), dtype=torch.float32))

        # Both out-of-range compression levels must be rejected with the same message.
        for bad_level in (-1, 10):
            with self.assertRaisesRegex(
                    RuntimeError, "Compression level should be between 0 and 9"):
                encode_png(torch.empty((3, 100, 100), dtype=torch.uint8),
                           compression_level=bad_level)

        with self.assertRaisesRegex(
                RuntimeError, "The number of channels should be 1 or 3, got: 5"):
            encode_png(torch.empty((5, 100, 100), dtype=torch.uint8))

    def test_write_png(self):
        """Write PNG assets with ``write_png`` and check that Pillow reads the files back unchanged."""
        with get_tmp_dir() as d:
            for img_path in get_images(IMAGE_DIR, '.png'):
                with Image.open(img_path) as pil_image:
                    img_pil = torch.from_numpy(np.array(pil_image))
                img_pil = img_pil.permute(2, 0, 1)

                filename, _ = os.path.splitext(os.path.basename(img_path))
                torch_png = os.path.join(d, f'{filename}_torch.png')
                write_png(img_pil, torch_png, compression_level=6)
                # Close the written file before the temporary directory is cleaned up.
                with Image.open(torch_png) as saved_pil:
                    saved_image = torch.from_numpy(np.array(saved_pil))
                saved_image = saved_image.permute(2, 0, 1)

                assert_equal(img_pil, saved_image)

    def test_read_file(self):
        """``read_file`` returns the file bytes as a uint8 tensor and raises for a missing file."""
        self._check_read_file('test1.bin')

        with self.assertRaisesRegex(
                RuntimeError, "No such file or directory: 'tst'"):
            read_file('tst')

    def test_read_file_non_ascii(self):
        """``read_file`` handles a file name containing non-ASCII characters."""
        self._check_read_file('日本語(Japanese).bin')

    def test_write_file(self):
        """``write_file`` stores the tensor content byte-for-byte."""
        self._check_write_file('test1.bin')

    def test_write_file_non_ascii(self):
        """``write_file`` handles a file name containing non-ASCII characters."""
        self._check_write_file('日本語(Japanese).bin')

282

283
@needs_cuda
@pytest.mark.parametrize('img_path', [
    # We need to change the "id" for that parameter.
    # If we don't, the test id (i.e. its name) will contain the whole path to the image which is machine-specific,
    # and this creates issues when the test is running in a different machine than where it was collected
    # (typically, in fb internal infra)
    # os.path.basename is used so the id is the bare file name whatever the path separator is.
    pytest.param(jpeg_path, id=os.path.basename(jpeg_path))
    for jpeg_path in get_images(IMAGE_ROOT, ".jpg")
])
@pytest.mark.parametrize('mode', [ImageReadMode.UNCHANGED, ImageReadMode.GRAY, ImageReadMode.RGB])
@pytest.mark.parametrize('scripted', (False, True))
def test_decode_jpeg_cuda(mode, img_path, scripted):
    """Compare ``decode_jpeg(..., device='cuda')`` with the result of ``decode_image``.

    Args:
        mode: ``ImageReadMode`` passed to both decoders.
        img_path: Path of the JPEG file to decode.
        scripted: When True, ``decode_jpeg`` is wrapped with ``torch.jit.script`` first.
    """
    if 'cmyk' in img_path:
        pytest.xfail("Decoding a CMYK jpeg isn't supported")
    data = read_file(img_path)
    img = decode_image(data, mode=mode)
    f = torch.jit.script(decode_jpeg) if scripted else decode_jpeg
    img_nvjpeg = f(data, mode=mode, device='cuda')

    # Some difference expected between jpeg implementations
    assert (img.float() - img_nvjpeg.cpu().float()).abs().mean() < 2


@needs_cuda
@pytest.mark.parametrize('cuda_device', ('cuda', 'cuda:0', torch.device('cuda')))
def test_decode_jpeg_cuda_device_param(cuda_device):
    """Check that ``decode_jpeg`` accepts the device both as a string and as a ``torch.device``."""
    first_jpeg = next(get_images(IMAGE_ROOT, ".jpg"))
    encoded = read_file(first_jpeg)
    decode_jpeg(encoded, device=cuda_device)


@needs_cuda
def test_decode_jpeg_cuda_errors():
    """Check the error raised by CUDA JPEG decoding for each kind of invalid input."""
    data = read_file(next(get_images(IMAGE_ROOT, ".jpg")))
    # Each entry pairs the expected error message with a call that must trigger it.
    # The calls are wrapped in lambdas so they only run inside pytest.raises.
    failing_calls = [
        ("Expected a non empty 1-dimensional tensor",
         lambda: decode_jpeg(data.reshape(-1, 1), device='cuda')),
        ("input tensor must be on CPU",
         lambda: decode_jpeg(data.to('cuda'), device='cuda')),
        ("Expected a torch.uint8 tensor",
         lambda: decode_jpeg(data.to(torch.float), device='cuda')),
        ("Expected a cuda device",
         lambda: torch.ops.image.decode_jpeg_cuda(data, ImageReadMode.UNCHANGED.value, 'cpu')),
    ]
    for expected_message, call in failing_calls:
        with pytest.raises(RuntimeError, match=expected_message):
            call()


328
329
# Run the unittest runner when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()