import glob
import io
import os
import sys
from pathlib import Path

import numpy as np
import pytest
import torch
import torchvision.transforms.functional as F
from common_utils import assert_equal, needs_cuda
from PIL import __version__ as PILLOW_VERSION, Image
from torchvision.io.image import (
    _read_png_16,
    decode_image,
    decode_jpeg,
    decode_png,
    encode_jpeg,
    encode_png,
    ImageReadMode,
    read_file,
    read_image,
    write_file,
    write_jpeg,
    write_png,
)

IMAGE_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets")
FAKEDATA_DIR = os.path.join(IMAGE_ROOT, "fakedata")
IMAGE_DIR = os.path.join(FAKEDATA_DIR, "imagefolder")
DAMAGED_JPEG = os.path.join(IMAGE_ROOT, "damaged_jpeg")
DAMAGED_PNG = os.path.join(IMAGE_ROOT, "damaged_png")
ENCODE_JPEG = os.path.join(IMAGE_ROOT, "encode_jpeg")
INTERLACED_PNG = os.path.join(IMAGE_ROOT, "interlaced_png")
IS_WINDOWS = sys.platform in ("win32", "cygwin")
PILLOW_VERSION = tuple(int(x) for x in PILLOW_VERSION.split("."))


def _get_safe_image_name(name):
    # Used when we need to change the pytest "id" for an "image path" parameter.
    # If we don't, the test id (i.e. its name) will contain the whole path to the image, which is machine-specific,
    # and this creates issues when the test is running on a different machine than the one where it was collected
    # (typically, in fb internal infra)
    return name.split(os.path.sep)[-1]


def get_images(directory, img_ext):
    assert os.path.isdir(directory)
    image_paths = glob.glob(directory + f"/**/*{img_ext}", recursive=True)
    for path in image_paths:
        # skip the deliberately damaged images and the PIL-encoded reference images
        if path.split(os.sep)[-2] not in ["damaged_jpeg", "jpeg_write"]:
            yield path


def pil_read_image(img_path):
    with Image.open(img_path) as img:
        return torch.from_numpy(np.array(img))


def normalize_dimensions(img_pil):
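    # HWC -> CHW for 3-dim inputs; 2-dim (grayscale) inputs get a leading channel dimension.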
    if len(img_pil.shape) == 3:
        img_pil = img_pil.permute(2, 0, 1)
    else:
        img_pil = img_pil.unsqueeze(0)
    return img_pil


@pytest.mark.parametrize(
    "img_path",
    [pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) for jpeg_path in get_images(IMAGE_ROOT, ".jpg")],
)
@pytest.mark.parametrize(
    "pil_mode, mode",
    [
        (None, ImageReadMode.UNCHANGED),
        ("L", ImageReadMode.GRAY),
        ("RGB", ImageReadMode.RGB),
    ],
)
def test_decode_jpeg(img_path, pil_mode, mode):

    with Image.open(img_path) as img:
        is_cmyk = img.mode == "CMYK"
        if pil_mode is not None:
            if is_cmyk:
                # libjpeg does not support the conversion
                pytest.xfail("Decoding a CMYK jpeg isn't supported")
            img = img.convert(pil_mode)
        img_pil = torch.from_numpy(np.array(img))
        if is_cmyk:
            # flip the colors to match libjpeg
            img_pil = 255 - img_pil

    img_pil = normalize_dimensions(img_pil)
    data = read_file(img_path)
    img_ljpeg = decode_image(data, mode=mode)

    # Permit a small variation on pixel values to account for implementation
    # differences between Pillow and LibJPEG.
    abs_mean_diff = (img_ljpeg.type(torch.float32) - img_pil).abs().mean().item()
    assert abs_mean_diff < 2


def test_decode_jpeg_errors():
    with pytest.raises(RuntimeError, match="Expected a non empty 1-dimensional tensor"):
        decode_jpeg(torch.empty((100, 1), dtype=torch.uint8))

    with pytest.raises(RuntimeError, match="Expected a torch.uint8 tensor"):
        decode_jpeg(torch.empty((100,), dtype=torch.float16))

    with pytest.raises(RuntimeError, match="Not a JPEG file"):
        decode_jpeg(torch.empty((100), dtype=torch.uint8))


def test_decode_bad_huffman_images():
    # sanity check: make sure we can decode the bad Huffman encoding
    bad_huff = read_file(os.path.join(DAMAGED_JPEG, "bad_huffman.jpg"))
    decode_jpeg(bad_huff)


@pytest.mark.parametrize(
    "img_path",
    [
        pytest.param(truncated_image, id=_get_safe_image_name(truncated_image))
        for truncated_image in glob.glob(os.path.join(DAMAGED_JPEG, "corrupt*.jpg"))
    ],
)
def test_damaged_corrupt_images(img_path):
    # Truncated images should raise an exception
    data = read_file(img_path)
    if "corrupt34" in img_path:
        match_message = "Image is incomplete or truncated"
    else:
        match_message = "Unsupported marker type"
    with pytest.raises(RuntimeError, match=match_message):
        decode_jpeg(data)


@pytest.mark.parametrize(
    "img_path",
    [pytest.param(png_path, id=_get_safe_image_name(png_path)) for png_path in get_images(FAKEDATA_DIR, ".png")],
)
@pytest.mark.parametrize(
    "pil_mode, mode",
    [
        (None, ImageReadMode.UNCHANGED),
        ("L", ImageReadMode.GRAY),
        ("LA", ImageReadMode.GRAY_ALPHA),
        ("RGB", ImageReadMode.RGB),
        ("RGBA", ImageReadMode.RGB_ALPHA),
    ],
)
def test_decode_png(img_path, pil_mode, mode):

    with Image.open(img_path) as img:
        if pil_mode is not None:
            img = img.convert(pil_mode)
        img_pil = torch.from_numpy(np.array(img))

    img_pil = normalize_dimensions(img_pil)

    if img_path.endswith("16.png"):
        # 16-bit image decoding is supported, but only as a private API
        # FIXME: see https://github.com/pytorch/vision/issues/4731 for potential solutions to making it public
        with pytest.raises(RuntimeError, match="At most 8-bit PNG images are supported"):
            data = read_file(img_path)
            img_lpng = decode_image(data, mode=mode)

        img_lpng = _read_png_16(img_path, mode=mode)
        assert img_lpng.dtype == torch.int32
        # PIL converts 16-bit PNGs to uint8
        img_lpng = torch.round(img_lpng / (2**16 - 1) * 255).to(torch.uint8)
    else:
        data = read_file(img_path)
        img_lpng = decode_image(data, mode=mode)

    tol = 0 if pil_mode is None else 1

    if PILLOW_VERSION >= (8, 3) and pil_mode == "LA":
        # Avoid checking the transparency channel until
        # https://github.com/python-pillow/Pillow/issues/5593#issuecomment-878244910
        # is fixed.
        # TODO: remove once fix is released in PIL. Should be > 8.3.1.
        img_lpng, img_pil = img_lpng[0], img_pil[0]

    torch.testing.assert_close(img_lpng, img_pil, atol=tol, rtol=0)


def test_decode_png_errors():
    with pytest.raises(RuntimeError, match="Expected a non empty 1-dimensional tensor"):
        decode_png(torch.empty((), dtype=torch.uint8))
    with pytest.raises(RuntimeError, match="Content is not png"):
        decode_png(torch.randint(3, 5, (300,), dtype=torch.uint8))
    with pytest.raises(RuntimeError, match="Out of bound read in decode_png"):
        decode_png(read_file(os.path.join(DAMAGED_PNG, "sigsegv.png")))


@pytest.mark.parametrize(
    "img_path",
    [pytest.param(png_path, id=_get_safe_image_name(png_path)) for png_path in get_images(IMAGE_DIR, ".png")],
)
def test_encode_png(img_path):
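    # Encode with torchvision's encode_png, decode the resulting bytes with PIL,
    # and check that the round trip is lossless.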
    pil_image = Image.open(img_path)
    img_pil = torch.from_numpy(np.array(pil_image))
    img_pil = img_pil.permute(2, 0, 1)
    png_buf = encode_png(img_pil, compression_level=6)

    rec_img = Image.open(io.BytesIO(bytes(png_buf.tolist())))
    rec_img = torch.from_numpy(np.array(rec_img))
    rec_img = rec_img.permute(2, 0, 1)

    assert_equal(img_pil, rec_img)


def test_encode_png_errors():
    with pytest.raises(RuntimeError, match="Input tensor dtype should be uint8"):
        encode_png(torch.empty((3, 100, 100), dtype=torch.float32))

    with pytest.raises(RuntimeError, match="Compression level should be between 0 and 9"):
        encode_png(torch.empty((3, 100, 100), dtype=torch.uint8), compression_level=-1)

    with pytest.raises(RuntimeError, match="Compression level should be between 0 and 9"):
        encode_png(torch.empty((3, 100, 100), dtype=torch.uint8), compression_level=10)

    with pytest.raises(RuntimeError, match="The number of channels should be 1 or 3, got: 5"):
        encode_png(torch.empty((5, 100, 100), dtype=torch.uint8))


@pytest.mark.parametrize(
    "img_path",
    [pytest.param(png_path, id=_get_safe_image_name(png_path)) for png_path in get_images(IMAGE_DIR, ".png")],
)
def test_write_png(img_path, tmpdir):
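    # Write the PNG with torchvision's write_png, read it back with PIL, and compare the pixels.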
    pil_image = Image.open(img_path)
    img_pil = torch.from_numpy(np.array(pil_image))
    img_pil = img_pil.permute(2, 0, 1)

    filename, _ = os.path.splitext(os.path.basename(img_path))
    torch_png = os.path.join(tmpdir, f"{filename}_torch.png")
    write_png(img_pil, torch_png, compression_level=6)
    saved_image = torch.from_numpy(np.array(Image.open(torch_png)))
    saved_image = saved_image.permute(2, 0, 1)

    assert_equal(img_pil, saved_image)


def test_read_file(tmpdir):
    fname, content = "test1.bin", b"TorchVision\211\n"
    fpath = os.path.join(tmpdir, fname)
    with open(fpath, "wb") as f:
        f.write(content)

    data = read_file(fpath)
    expected = torch.tensor(list(content), dtype=torch.uint8)
    os.unlink(fpath)
    assert_equal(data, expected)

    with pytest.raises(RuntimeError, match="No such file or directory: 'tst'"):
        read_file("tst")


def test_read_file_non_ascii(tmpdir):
    fname, content = "日本語(Japanese).bin", b"TorchVision\211\n"
    fpath = os.path.join(tmpdir, fname)
    with open(fpath, "wb") as f:
        f.write(content)

    data = read_file(fpath)
    expected = torch.tensor(list(content), dtype=torch.uint8)
    os.unlink(fpath)
    assert_equal(data, expected)


def test_write_file(tmpdir):
    fname, content = "test1.bin", b"TorchVision\211\n"
    fpath = os.path.join(tmpdir, fname)
    content_tensor = torch.tensor(list(content), dtype=torch.uint8)
    write_file(fpath, content_tensor)

    with open(fpath, "rb") as f:
        saved_content = f.read()
    os.unlink(fpath)
    assert content == saved_content


def test_write_file_non_ascii(tmpdir):
    fname, content = "日本語(Japanese).bin", b"TorchVision\211\n"
    fpath = os.path.join(tmpdir, fname)
    content_tensor = torch.tensor(list(content), dtype=torch.uint8)
    write_file(fpath, content_tensor)

    with open(fpath, "rb") as f:
        saved_content = f.read()
    os.unlink(fpath)
    assert content == saved_content


@pytest.mark.parametrize(
    "shape",
    [
        (27, 27),
        (60, 60),
        (105, 105),
    ],
)
def test_read_1_bit_png(shape, tmpdir):
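    # A 1-bit PNG should decode to a uint8 tensor with values 0 and 255.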
    np_rng = np.random.RandomState(0)
    image_path = os.path.join(tmpdir, f"test_{shape}.png")
    pixels = np_rng.rand(*shape) > 0.5
    img = Image.fromarray(pixels)
    img.save(image_path)
    img1 = read_image(image_path)
    img2 = normalize_dimensions(torch.as_tensor(pixels * 255, dtype=torch.uint8))
    assert_equal(img1, img2)


@pytest.mark.parametrize(
    "shape",
    [
        (27, 27),
        (60, 60),
        (105, 105),
    ],
)
@pytest.mark.parametrize(
    "mode",
    [
        ImageReadMode.UNCHANGED,
        ImageReadMode.GRAY,
    ],
)
def test_read_1_bit_png_consistency(shape, mode, tmpdir):
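    # Reading the same 1-bit PNG twice with the same mode should give identical results.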
    np_rng = np.random.RandomState(0)
    image_path = os.path.join(tmpdir, f"test_{shape}.png")
    pixels = np_rng.rand(*shape) > 0.5
    img = Image.fromarray(pixels)
    img.save(image_path)
    img1 = read_image(image_path, mode)
    img2 = read_image(image_path, mode)
    assert_equal(img1, img2)


def test_read_interlaced_png():
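    # The two assets differ in their interlace setting but should decode to identical pixels.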
    imgs = list(get_images(INTERLACED_PNG, ".png"))
    with Image.open(imgs[0]) as im1, Image.open(imgs[1]) as im2:
        assert not (im1.info.get("interlace") is im2.info.get("interlace"))
    img1 = read_image(imgs[0])
    img2 = read_image(imgs[1])
    assert_equal(img1, img2)


@needs_cuda
@pytest.mark.parametrize(
    "img_path",
    [pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) for jpeg_path in get_images(IMAGE_ROOT, ".jpg")],
)
@pytest.mark.parametrize("mode", [ImageReadMode.UNCHANGED, ImageReadMode.GRAY, ImageReadMode.RGB])
@pytest.mark.parametrize("scripted", (False, True))
def test_decode_jpeg_cuda(mode, img_path, scripted):
    if "cmyk" in img_path:
        pytest.xfail("Decoding a CMYK jpeg isn't supported")

    data = read_file(img_path)
    img = decode_image(data, mode=mode)
    f = torch.jit.script(decode_jpeg) if scripted else decode_jpeg
    img_nvjpeg = f(data, mode=mode, device="cuda")

    # Some difference expected between jpeg implementations
    assert (img.float() - img_nvjpeg.cpu().float()).abs().mean() < 2


@needs_cuda
def test_decode_image_cuda_raises():
    data = torch.randint(0, 127, size=(255,), device="cuda", dtype=torch.uint8)
    with pytest.raises(RuntimeError):
        decode_image(data)


@needs_cuda
@pytest.mark.parametrize("cuda_device", ("cuda", "cuda:0", torch.device("cuda")))
def test_decode_jpeg_cuda_device_param(cuda_device):
    """Make sure we can pass a string or a torch.device as device param"""
    path = next(path for path in get_images(IMAGE_ROOT, ".jpg") if "cmyk" not in path)
    data = read_file(path)
    decode_jpeg(data, device=cuda_device)


@needs_cuda
def test_decode_jpeg_cuda_errors():
    data = read_file(next(get_images(IMAGE_ROOT, ".jpg")))
    with pytest.raises(RuntimeError, match="Expected a non empty 1-dimensional tensor"):
        decode_jpeg(data.reshape(-1, 1), device="cuda")
    with pytest.raises(RuntimeError, match="input tensor must be on CPU"):
        decode_jpeg(data.to("cuda"), device="cuda")
    with pytest.raises(RuntimeError, match="Expected a torch.uint8 tensor"):
        decode_jpeg(data.to(torch.float), device="cuda")
    with pytest.raises(RuntimeError, match="Expected a cuda device"):
        torch.ops.image.decode_jpeg_cuda(data, ImageReadMode.UNCHANGED.value, "cpu")


def test_encode_jpeg_errors():

    with pytest.raises(RuntimeError, match="Input tensor dtype should be uint8"):
        encode_jpeg(torch.empty((3, 100, 100), dtype=torch.float32))

    with pytest.raises(ValueError, match="Image quality should be a positive number between 1 and 100"):
        encode_jpeg(torch.empty((3, 100, 100), dtype=torch.uint8), quality=-1)

    with pytest.raises(ValueError, match="Image quality should be a positive number between 1 and 100"):
        encode_jpeg(torch.empty((3, 100, 100), dtype=torch.uint8), quality=101)

    with pytest.raises(RuntimeError, match="The number of channels should be 1 or 3, got: 5"):
        encode_jpeg(torch.empty((5, 100, 100), dtype=torch.uint8))

    with pytest.raises(RuntimeError, match="Input data should be a 3-dimensional tensor"):
        encode_jpeg(torch.empty((1, 3, 100, 100), dtype=torch.uint8))

    with pytest.raises(RuntimeError, match="Input data should be a 3-dimensional tensor"):
        encode_jpeg(torch.empty((100, 100), dtype=torch.uint8))


def _collect_if(cond):
    # TODO: remove this once test_encode_jpeg_reference and test_write_jpeg_reference
    # are removed
    def _inner(test_func):
        if cond:
            return test_func
        else:
            return pytest.mark.dont_collect(test_func)

    return _inner


@_collect_if(cond=False)
@pytest.mark.parametrize(
    "img_path",
    [pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) for jpeg_path in get_images(ENCODE_JPEG, ".jpg")],
)
def test_encode_jpeg_reference(img_path):
    # This test is *wrong*.
    # It compares a torchvision-encoded jpeg with a PIL-encoded jpeg (the reference), but it
    # starts encoding the torchvision version from an image that comes from
    # decode_jpeg, which can yield different results from pil.decode (see
    # test_decode... which uses a high tolerance).
    # Instead, we should start encoding from the exact same decoded image, for a
    # valid comparison. This is done in test_encode_jpeg, but unfortunately
    # these more correct tests fail on windows (probably because of a difference
    # in libjpeg between torchvision and PIL).
    # FIXME: make the correct tests pass on windows and remove this.
    dirname = os.path.dirname(img_path)
    filename, _ = os.path.splitext(os.path.basename(img_path))
    write_folder = os.path.join(dirname, "jpeg_write")
    expected_file = os.path.join(write_folder, f"{filename}_pil.jpg")
    img = decode_jpeg(read_file(img_path))

    with open(expected_file, "rb") as f:
        pil_bytes = f.read()
        pil_bytes = torch.as_tensor(list(pil_bytes), dtype=torch.uint8)
    for src_img in [img, img.contiguous()]:
        # PIL sets jpeg quality to 75 by default
        jpeg_bytes = encode_jpeg(src_img, quality=75)
        assert_equal(jpeg_bytes, pil_bytes)


@_collect_if(cond=False)
@pytest.mark.parametrize(
    "img_path",
    [pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) for jpeg_path in get_images(ENCODE_JPEG, ".jpg")],
)
def test_write_jpeg_reference(img_path, tmpdir):
    # FIXME: Remove this eventually, see test_encode_jpeg_reference
    data = read_file(img_path)
    img = decode_jpeg(data)

    basedir = os.path.dirname(img_path)
    filename, _ = os.path.splitext(os.path.basename(img_path))
    torch_jpeg = os.path.join(tmpdir, f"{filename}_torch.jpg")
    pil_jpeg = os.path.join(basedir, "jpeg_write", f"{filename}_pil.jpg")

    write_jpeg(img, torch_jpeg, quality=75)

    with open(torch_jpeg, "rb") as f:
        torch_bytes = f.read()

    with open(pil_jpeg, "rb") as f:
        pil_bytes = f.read()

    assert_equal(torch_bytes, pil_bytes)


# TODO: Remove the skip. See https://github.com/pytorch/vision/issues/5162.
@pytest.mark.skip("this test fails because PIL uses libjpeg-turbo")
@pytest.mark.parametrize(
    "img_path",
    [pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) for jpeg_path in get_images(ENCODE_JPEG, ".jpg")],
)
def test_encode_jpeg(img_path):
    img = read_image(img_path)

    pil_img = F.to_pil_image(img)
    buf = io.BytesIO()
    pil_img.save(buf, format="JPEG", quality=75)

    encoded_jpeg_pil = torch.frombuffer(buf.getvalue(), dtype=torch.uint8)

    for src_img in [img, img.contiguous()]:
        encoded_jpeg_torch = encode_jpeg(src_img, quality=75)
        assert_equal(encoded_jpeg_torch, encoded_jpeg_pil)


# TODO: Remove the skip. See https://github.com/pytorch/vision/issues/5162.
@pytest.mark.skip("this test fails because PIL uses libjpeg-turbo")
@pytest.mark.parametrize(
    "img_path",
    [pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) for jpeg_path in get_images(ENCODE_JPEG, ".jpg")],
)
def test_write_jpeg(img_path, tmpdir):
    tmpdir = Path(tmpdir)
    img = read_image(img_path)
    pil_img = F.to_pil_image(img)

    torch_jpeg = str(tmpdir / "torch.jpg")
    pil_jpeg = str(tmpdir / "pil.jpg")

    write_jpeg(img, torch_jpeg, quality=75)
    pil_img.save(pil_jpeg, quality=75)

    with open(torch_jpeg, "rb") as f:
        torch_bytes = f.read()

    with open(pil_jpeg, "rb") as f:
        pil_bytes = f.read()

    assert_equal(torch_bytes, pil_bytes)


if __name__ == "__main__":
    pytest.main([__file__])