# test_image.py (18.7 KB) -- header text left over from the repository web view.
1
import glob
2
3
import io
import os
4
5
import sys
from pathlib import Path
6

7
import numpy as np
8
import pytest
9
import torch
10
import torchvision.transforms.functional as F
11
from common_utils import assert_equal, needs_cuda
12
from PIL import __version__ as PILLOW_VERSION, Image, ImageOps
13
from torchvision.io.image import (
14
15
    _read_png_16,
    decode_image,
16
    decode_jpeg,
17
    decode_png,
18
19
20
    encode_jpeg,
    encode_png,
    ImageReadMode,
21
    read_file,
22
    read_image,
23
24
25
    write_file,
    write_jpeg,
    write_png,
26
)
# Commit attribution carried over from the blame view: Francisco Massa.
27

28
# Locations of the test assets, all resolved relative to this file.
IMAGE_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets")
FAKEDATA_DIR = os.path.join(IMAGE_ROOT, "fakedata")
IMAGE_DIR = os.path.join(FAKEDATA_DIR, "imagefolder")
DAMAGED_JPEG = os.path.join(IMAGE_ROOT, "damaged_jpeg")
DAMAGED_PNG = os.path.join(IMAGE_ROOT, "damaged_png")
ENCODE_JPEG = os.path.join(IMAGE_ROOT, "encode_jpeg")
INTERLACED_PNG = os.path.join(IMAGE_ROOT, "interlaced_png")
TOOSMALL_PNG = os.path.join(IMAGE_ROOT, "toosmall_png")

# Platform flags derived from sys.platform.
IS_WINDOWS = sys.platform in ("win32", "cygwin")
IS_MACOS = sys.platform == "darwin"

# Pillow version as a tuple of ints so it can be compared against e.g. (8, 3).
# Non-numeric components (such as the "dev0" of a development build) are
# dropped instead of crashing the whole module at import time.
PILLOW_VERSION = tuple(int(part) for part in PILLOW_VERSION.split(".") if part.isdigit())
39
40
41
42
43
44
45
46


def _get_safe_image_name(name):
    # Used when we need to change the pytest "id" for an "image path" parameter.
    # If we don't, the test id (i.e. its name) will contain the whole path to the image, which is machine-specific,
    # and this creates issues when the test is running in a different machine than where it was collected
    # (typically, in fb internal infra)
    return name.split(os.path.sep)[-1]
47
48
49
50


def get_images(directory, img_ext):
    """Yield the paths under ``directory`` whose name ends with ``img_ext``.

    The search is recursive. Files whose immediate parent directory is
    ``damaged_jpeg`` or ``jpeg_write`` are skipped.

    Paths are yielded in sorted order: ``glob`` returns entries in an
    arbitrary, filesystem-dependent order, and these paths feed
    ``pytest.mark.parametrize``, where a stable order keeps test collection
    reproducible across machines.
    """
    assert os.path.isdir(directory)
    excluded_parents = ("damaged_jpeg", "jpeg_write")
    pattern = os.path.join(directory, "**", f"*{img_ext}")
    for path in sorted(glob.glob(pattern, recursive=True)):
        # Use os.path helpers so the parent lookup does not depend on which
        # separator glob happened to emit.
        if os.path.basename(os.path.dirname(path)) not in excluded_parents:
            yield path
55
56


57
58
59
60
61
62
63
64
65
66
67
68
69
def pil_read_image(img_path):
    """Open ``img_path`` with Pillow and return its pixel data as a torch tensor."""
    with Image.open(img_path) as pil_img:
        pixels = np.array(pil_img)
        return torch.from_numpy(pixels)


def normalize_dimensions(img_pil):
    """Rearrange a tensor so that a new or existing channel axis comes first.

    A 3-dimensional input has its last axis moved to the front
    (``permute(2, 0, 1)``); any other input gets a leading axis of size 1.
    """
    if img_pil.dim() == 3:
        return img_pil.permute(2, 0, 1)
    return img_pil.unsqueeze(0)


70
71
72
73
74
75
76
77
78
79
80
81
@pytest.mark.parametrize(
    "img_path",
    [pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) for jpeg_path in get_images(IMAGE_ROOT, ".jpg")],
)
@pytest.mark.parametrize(
    "pil_mode, mode",
    [
        (None, ImageReadMode.UNCHANGED),
        ("L", ImageReadMode.GRAY),
        ("RGB", ImageReadMode.RGB),
    ],
)
@pytest.mark.parametrize("scripted", (False, True))
@pytest.mark.parametrize("decode_fun", (decode_jpeg, decode_image))
def test_decode_jpeg(img_path, pil_mode, mode, scripted, decode_fun):
    """Compare the decoder output with the same JPEG loaded through Pillow."""
    decoder = torch.jit.script(decode_fun) if scripted else decode_fun

    with Image.open(img_path) as pil_img:
        source_is_cmyk = pil_img.mode == "CMYK"
        converted = pil_img if pil_mode is None else pil_img.convert(pil_mode)
        reference = torch.from_numpy(np.array(converted))

    if source_is_cmyk and mode == ImageReadMode.UNCHANGED:
        # flip the colors to match libjpeg
        reference = 255 - reference
    reference = normalize_dimensions(reference)

    decoded = decoder(read_file(img_path), mode=mode)

    # Pillow and LibJPEG are different implementations, so a small deviation
    # in pixel values is tolerated.
    mean_abs_error = (decoded.type(torch.float32) - reference).abs().mean().item()
    assert mean_abs_error < 2


107
@pytest.mark.parametrize("codec", ["png", "jpeg"])
@pytest.mark.parametrize("orientation", [1, 2, 3, 4, 5, 6, 7, 8, 0])
def test_decode_with_exif_orientation(tmpdir, codec, orientation):
    """Check ``decode_image(..., apply_exif_orientation=True)`` against Pillow.

    A random image is saved with the given EXIF orientation value, then the
    decoded tensor is compared with ``ImageOps.exif_transpose`` applied to
    the same file.
    """
    fp = os.path.join(tmpdir, f"exif_oriented_{orientation}.{codec}")
    t = torch.randint(0, 256, size=(3, 256, 257), dtype=torch.uint8)
    im = F.to_pil_image(t)
    exif = im.getexif()
    exif[0x0112] = orientation  # set exif orientation
    im.save(fp, codec.upper(), exif=exif.tobytes())

    data = read_file(fp)
    output = decode_image(data, apply_exif_orientation=True)

    # Open through a context manager so the file handle is released
    # deterministically instead of leaking until garbage collection.
    with Image.open(fp) as pimg:
        expected = F.pil_to_tensor(ImageOps.exif_transpose(pimg))

    torch.testing.assert_close(expected, output)


@pytest.mark.parametrize("size", [65533, 1, 7, 10, 23, 33])
def test_invalid_exif(tmpdir, size):
    """Decode a JPEG whose EXIF payload is ``size`` bytes of junk and compare with Pillow."""
    # Inspired from a PIL test:
    # https://github.com/python-pillow/Pillow/blob/8f63748e50378424628155994efd7e0739a4d1d1/Tests/test_file_jpeg.py#L299
    fp = os.path.join(tmpdir, "invalid_exif.jpg")
    t = torch.randint(0, 256, size=(3, 256, 257), dtype=torch.uint8)
    im = F.to_pil_image(t)
    im.save(fp, "JPEG", exif=b"1" * size)

    data = read_file(fp)
    output = decode_image(data, apply_exif_orientation=True)

    # Open through a context manager so the file handle is released
    # deterministically instead of leaking until garbage collection.
    with Image.open(fp) as pimg:
        expected = F.pil_to_tensor(ImageOps.exif_transpose(pimg))

    torch.testing.assert_close(expected, output)


146
147
148
149
150
151
152
153
154
155
156
157
158
def test_decode_jpeg_errors():
    """Check the error messages ``decode_jpeg`` raises for malformed inputs."""
    with pytest.raises(RuntimeError, match="Expected a non empty 1-dimensional tensor"):
        decode_jpeg(torch.empty((100, 1), dtype=torch.uint8))

    with pytest.raises(RuntimeError, match="Expected a torch.uint8 tensor"):
        decode_jpeg(torch.empty((100,), dtype=torch.float16))

    # Use zeros rather than uninitialized memory: this case actually reaches
    # the decoder, so the bytes it sees should be deterministic.
    with pytest.raises(RuntimeError, match="Not a JPEG file"):
        decode_jpeg(torch.zeros((100,), dtype=torch.uint8))


def test_decode_bad_huffman_images():
    """Sanity check: a JPEG with a bad Huffman encoding must still decode."""
    bad_huffman_path = os.path.join(DAMAGED_JPEG, "bad_huffman.jpg")
    encoded = read_file(bad_huffman_path)
    decode_jpeg(encoded)


163
164
165
166
167
168
169
@pytest.mark.parametrize(
    "img_path",
    [
        pytest.param(truncated_image, id=_get_safe_image_name(truncated_image))
        for truncated_image in glob.glob(os.path.join(DAMAGED_JPEG, "corrupt*.jpg"))
    ],
)
def test_damaged_corrupt_images(img_path):
    """Decoding a corrupt/truncated JPEG must raise a ``RuntimeError``.

    The expected message depends on the file: paths containing ``corrupt34``
    are expected to report truncation, all others an unsupported marker.
    """
    expected_message = (
        "Image is incomplete or truncated" if "corrupt34" in img_path else "Unsupported marker type"
    )
    encoded = read_file(img_path)
    with pytest.raises(RuntimeError, match=expected_message):
        decode_jpeg(encoded)


181
182
183
184
185
186
187
188
189
190
191
192
193
194
@pytest.mark.parametrize(
    "img_path",
    [pytest.param(png_path, id=_get_safe_image_name(png_path)) for png_path in get_images(FAKEDATA_DIR, ".png")],
)
@pytest.mark.parametrize(
    "pil_mode, mode",
    [
        (None, ImageReadMode.UNCHANGED),
        ("L", ImageReadMode.GRAY),
        ("LA", ImageReadMode.GRAY_ALPHA),
        ("RGB", ImageReadMode.RGB),
        ("RGBA", ImageReadMode.RGB_ALPHA),
    ],
)
@pytest.mark.parametrize("scripted", (False, True))
@pytest.mark.parametrize("decode_fun", (decode_png, decode_image))
def test_decode_png(img_path, pil_mode, mode, scripted, decode_fun):
    """Compare the PNG decoder output with the same file loaded through Pillow."""
    decoder = torch.jit.script(decode_fun) if scripted else decode_fun

    with Image.open(img_path) as pil_img:
        converted = pil_img if pil_mode is None else pil_img.convert(pil_mode)
        reference = torch.from_numpy(np.array(converted))
    reference = normalize_dimensions(reference)

    if not img_path.endswith("16.png"):
        decoded = decoder(read_file(img_path), mode=mode)
    else:
        # 16 bits image decoding is supported, but only as a private API
        # FIXME: see https://github.com/pytorch/vision/issues/4731 for potential solutions to making it public
        with pytest.raises(RuntimeError, match="At most 8-bit PNG images are supported"):
            decoder(read_file(img_path), mode=mode)

        decoded = _read_png_16(img_path, mode=mode)
        assert decoded.dtype == torch.int32
        # PIL converts 16 bits pngs in uint8
        decoded = torch.round(decoded / (2**16 - 1) * 255).to(torch.uint8)

    # Exact match when no mode conversion happened, otherwise allow off-by-one.
    tolerance = 1 if pil_mode is not None else 0

    if PILLOW_VERSION >= (8, 3) and pil_mode == "LA":
        # Avoid checking the transparency channel until
        # https://github.com/python-pillow/Pillow/issues/5593#issuecomment-878244910
        # is fixed.
        # TODO: remove once fix is released in PIL. Should be > 8.3.1.
        decoded = decoded[0]
        reference = reference[0]

    torch.testing.assert_close(decoded, reference, atol=tolerance, rtol=0)
234
235
236
237
238
239
240


def test_decode_png_errors():
    """Check the error messages ``decode_png`` raises for malformed inputs."""
    # Each entry pairs the expected message with a thunk that builds the bad
    # input and calls the decoder, so everything runs inside pytest.raises.
    failing_calls = [
        (
            "Expected a non empty 1-dimensional tensor",
            lambda: decode_png(torch.empty((), dtype=torch.uint8)),
        ),
        (
            "Content is not png",
            lambda: decode_png(torch.randint(3, 5, (300,), dtype=torch.uint8)),
        ),
        (
            "Out of bound read in decode_png",
            lambda: decode_png(read_file(os.path.join(DAMAGED_PNG, "sigsegv.png"))),
        ),
        (
            "Content is too small for png",
            lambda: decode_png(read_file(os.path.join(TOOSMALL_PNG, "heapbof.png"))),
        ),
    ]
    for expected_message, call in failing_calls:
        with pytest.raises(RuntimeError, match=expected_message):
            call()
245
246


247
248
249
250
@pytest.mark.parametrize(
    "img_path",
    [pytest.param(png_path, id=_get_safe_image_name(png_path)) for png_path in get_images(IMAGE_DIR, ".png")],
)
@pytest.mark.parametrize("scripted", (True, False))
def test_encode_png(img_path, scripted):
    """Round-trip a PNG through ``encode_png`` and check the pixels are unchanged.

    The image is loaded with Pillow, encoded with ``encode_png``, and the
    encoded bytes are decoded again with Pillow for comparison.
    """
    # np.array copies the pixel data, so the tensor stays valid after the
    # image (and its file handle) is closed by the context manager.
    with Image.open(img_path) as pil_image:
        img_pil = torch.from_numpy(np.array(pil_image))
    img_pil = img_pil.permute(2, 0, 1)

    encode = torch.jit.script(encode_png) if scripted else encode_png
    png_buf = encode(img_pil, compression_level=6)

    with Image.open(io.BytesIO(bytes(png_buf.tolist()))) as rec_pil:
        rec_img = torch.from_numpy(np.array(rec_pil))
    rec_img = rec_img.permute(2, 0, 1)

    assert_equal(img_pil, rec_img)


def test_encode_png_errors():
    """Check the error messages ``encode_png`` raises for invalid arguments."""
    with pytest.raises(RuntimeError, match="Input tensor dtype should be uint8"):
        encode_png(torch.empty((3, 100, 100), dtype=torch.float32))

    # Compression levels just outside each end of the accepted range.
    for bad_level in (-1, 10):
        with pytest.raises(RuntimeError, match="Compression level should be between 0 and 9"):
            encode_png(torch.empty((3, 100, 100), dtype=torch.uint8), compression_level=bad_level)

    five_channel = torch.empty((5, 100, 100), dtype=torch.uint8)
    with pytest.raises(RuntimeError, match="The number of channels should be 1 or 3, got: 5"):
        encode_png(five_channel)


280
281
282
283
@pytest.mark.parametrize(
    "img_path",
    [pytest.param(png_path, id=_get_safe_image_name(png_path)) for png_path in get_images(IMAGE_DIR, ".png")],
)
@pytest.mark.parametrize("scripted", (True, False))
def test_write_png(img_path, tmpdir, scripted):
    """Write an image with ``write_png`` and check Pillow reads back the same pixels."""
    # np.array copies the pixel data, so the tensor stays valid after the
    # image (and its file handle) is closed by the context manager.
    with Image.open(img_path) as pil_image:
        img_pil = torch.from_numpy(np.array(pil_image))
    img_pil = img_pil.permute(2, 0, 1)

    filename, _ = os.path.splitext(os.path.basename(img_path))
    torch_png = os.path.join(tmpdir, f"{filename}_torch.png")
    write = torch.jit.script(write_png) if scripted else write_png
    write(img_pil, torch_png, compression_level=6)

    # Close the written file explicitly as well, so no handle into tmpdir
    # is left open when the test finishes.
    with Image.open(torch_png) as saved_pil:
        saved_image = torch.from_numpy(np.array(saved_pil))
    saved_image = saved_image.permute(2, 0, 1)

    assert_equal(img_pil, saved_image)
298
299


300
301
302
303
304
305
306
307
308
309
def test_read_image():
    """Check that eager and TorchScript ``read_image`` return identical tensors."""
    # Only the TorchScript path is exercised here; the functionality itself
    # is somewhat covered by other tests already.
    first_jpeg = next(get_images(IMAGE_ROOT, ".jpg"))
    eager_result = read_image(first_jpeg)
    scripted_read = torch.jit.script(read_image)
    scripted_result = scripted_read(first_jpeg)
    torch.testing.assert_close(eager_result, scripted_result, atol=0, rtol=0)


@pytest.mark.parametrize("scripted", (True, False))
def test_read_file(tmpdir, scripted):
    """``read_file`` returns the file bytes as a uint8 tensor and errors on a missing path."""
    payload = b"TorchVision\211\n"
    target = os.path.join(tmpdir, "test1.bin")
    with open(target, "wb") as handle:
        handle.write(payload)

    reader = torch.jit.script(read_file) if scripted else read_file
    loaded = reader(target)
    reference = torch.tensor(list(payload), dtype=torch.uint8)

    # The file is removed before the comparison, as in the original flow.
    os.unlink(target)
    assert_equal(loaded, reference)

    with pytest.raises(RuntimeError, match="No such file or directory: 'tst'"):
        read_file("tst")
323
324


325
def test_read_file_non_ascii(tmpdir):
    """``read_file`` must handle a file whose name contains non-ASCII characters."""
    payload = b"TorchVision\211\n"
    target = os.path.join(tmpdir, "日本語(Japanese).bin")
    with open(target, "wb") as handle:
        handle.write(payload)

    loaded = read_file(target)
    reference = torch.tensor(list(payload), dtype=torch.uint8)

    # The file is removed before the comparison, as in the original flow.
    os.unlink(target)
    assert_equal(loaded, reference)
335
336


337
338
@pytest.mark.parametrize("scripted", (True, False))
def test_write_file(tmpdir, scripted):
    """``write_file`` must store exactly the bytes held by the uint8 tensor."""
    payload = b"TorchVision\211\n"
    target = os.path.join(tmpdir, "test1.bin")
    payload_tensor = torch.tensor(list(payload), dtype=torch.uint8)

    writer = torch.jit.script(write_file) if scripted else write_file
    writer(target, payload_tensor)

    with open(target, "rb") as handle:
        written = handle.read()
    os.unlink(target)
    assert payload == written
349
350


351
def test_write_file_non_ascii(tmpdir):
    """``write_file`` must handle a target whose name contains non-ASCII characters."""
    payload = b"TorchVision\211\n"
    target = os.path.join(tmpdir, "日本語(Japanese).bin")
    payload_tensor = torch.tensor(list(payload), dtype=torch.uint8)
    write_file(target, payload_tensor)

    with open(target, "rb") as handle:
        written = handle.read()
    os.unlink(target)
    assert payload == written
361

362

363
364
365
366
367
368
369
370
@pytest.mark.parametrize(
    "shape",
    [
        (27, 27),
        (60, 60),
        (105, 105),
    ],
)
def test_read_1_bit_png(shape, tmpdir):
    """A boolean image saved by Pillow must read back as a 0/255 uint8 tensor."""
    rng = np.random.RandomState(0)
    mask = rng.rand(*shape) > 0.5

    png_path = os.path.join(tmpdir, f"test_{shape}.png")
    Image.fromarray(mask).save(png_path)

    decoded = read_image(png_path)
    reference = normalize_dimensions(torch.as_tensor(mask * 255, dtype=torch.uint8))
    assert_equal(decoded, reference)
# Commit attribution carried over from the blame view: Prabhat Roy.
380
381


382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
@pytest.mark.parametrize(
    "shape",
    [
        (27, 27),
        (60, 60),
        (105, 105),
    ],
)
@pytest.mark.parametrize(
    "mode",
    [
        ImageReadMode.UNCHANGED,
        ImageReadMode.GRAY,
    ],
)
def test_read_1_bit_png_consistency(shape, mode, tmpdir):
    """Reading the same boolean-image PNG twice must give identical tensors."""
    rng = np.random.RandomState(0)
    mask = rng.rand(*shape) > 0.5

    png_path = os.path.join(tmpdir, f"test_{shape}.png")
    Image.fromarray(mask).save(png_path)

    first_read = read_image(png_path, mode)
    second_read = read_image(png_path, mode)
    assert_equal(first_read, second_read)
# Commit attribution carried over from the blame view: Prabhat Roy.
406
407


408
409
410
411
412
413
414
415
416
def test_read_interlaced_png():
    """Two PNGs that differ in their interlace setting must decode to the same tensor."""
    imgs = list(get_images(INTERLACED_PNG, ".png"))
    with Image.open(imgs[0]) as im1, Image.open(imgs[1]) as im2:
        # Compare by value: an identity check (`is`) on these values only
        # works by accident of object caching in the interpreter.
        assert im1.info.get("interlace") != im2.info.get("interlace")
    img1 = read_image(imgs[0])
    img2 = read_image(imgs[1])
    assert_equal(img1, img2)


417
@needs_cuda
@pytest.mark.parametrize(
    "img_path",
    [pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) for jpeg_path in get_images(IMAGE_ROOT, ".jpg")],
)
@pytest.mark.parametrize("mode", [ImageReadMode.UNCHANGED, ImageReadMode.GRAY, ImageReadMode.RGB])
@pytest.mark.parametrize("scripted", (False, True))
def test_decode_jpeg_cuda(mode, img_path, scripted):
    """Compare ``decode_jpeg(..., device="cuda")`` with the default ``decode_image`` result."""
    if "cmyk" in img_path:
        pytest.xfail("Decoding a CMYK jpeg isn't supported")

    encoded = read_file(img_path)
    reference = decode_image(encoded, mode=mode)

    cuda_decode = torch.jit.script(decode_jpeg) if scripted else decode_jpeg
    decoded_on_gpu = cuda_decode(encoded, mode=mode, device="cuda")

    # Some difference expected between jpeg implementations
    mean_abs_error = (reference.float() - decoded_on_gpu.cpu().float()).abs().mean()
    assert mean_abs_error < 2
435

# Commit attribution carried over from the blame view: Nicolas Hug.
436

437
438
439
440
441
442
@needs_cuda
def test_decode_image_cuda_raises():
    """``decode_image`` must raise a ``RuntimeError`` when given a CUDA tensor."""
    cuda_bytes = torch.randint(0, 127, size=(255,), device="cuda", dtype=torch.uint8)
    with pytest.raises(RuntimeError):
        decode_image(cuda_bytes)

443
444

@needs_cuda
@pytest.mark.parametrize("cuda_device", ("cuda", "cuda:0", torch.device("cuda")))
def test_decode_jpeg_cuda_device_param(cuda_device):
    """Make sure we can pass a string or a torch.device as device param"""
    non_cmyk_paths = (candidate for candidate in get_images(IMAGE_ROOT, ".jpg") if "cmyk" not in candidate)
    encoded = read_file(next(non_cmyk_paths))
    decode_jpeg(encoded, device=cuda_device)


@needs_cuda
def test_decode_jpeg_cuda_errors():
    """Check the error messages of the CUDA JPEG decoding path for invalid inputs."""
    data = read_file(next(get_images(IMAGE_ROOT, ".jpg")))

    # Each entry pairs the expected message with a thunk, so that building
    # the bad input also happens inside the pytest.raises block.
    failing_calls = [
        (
            "Expected a non empty 1-dimensional tensor",
            lambda: decode_jpeg(data.reshape(-1, 1), device="cuda"),
        ),
        (
            "input tensor must be on CPU",
            lambda: decode_jpeg(data.to("cuda"), device="cuda"),
        ),
        (
            "Expected a torch.uint8 tensor",
            lambda: decode_jpeg(data.to(torch.float), device="cuda"),
        ),
        (
            "Expected a cuda device",
            lambda: torch.ops.image.decode_jpeg_cuda(data, ImageReadMode.UNCHANGED.value, "cpu"),
        ),
    ]
    for expected_message, call in failing_calls:
        with pytest.raises(RuntimeError, match=expected_message):
            call()
464
465


466
467
468
469
470
def test_encode_jpeg_errors():
    """Check the error types and messages ``encode_jpeg`` raises for invalid arguments."""
    with pytest.raises(RuntimeError, match="Input tensor dtype should be uint8"):
        encode_jpeg(torch.empty((3, 100, 100), dtype=torch.float32))

    # Quality values just outside each end of the accepted range.
    for bad_quality in (-1, 101):
        with pytest.raises(ValueError, match="Image quality should be a positive number between 1 and 100"):
            encode_jpeg(torch.empty((3, 100, 100), dtype=torch.uint8), quality=bad_quality)

    with pytest.raises(RuntimeError, match="The number of channels should be 1 or 3, got: 5"):
        encode_jpeg(torch.empty((5, 100, 100), dtype=torch.uint8))

    # One dimension too many, then one too few.
    for bad_shape in ((1, 3, 100, 100), (100, 100)):
        with pytest.raises(RuntimeError, match="Input data should be a 3-dimensional tensor"):
            encode_jpeg(torch.empty(bad_shape, dtype=torch.uint8))


487
@pytest.mark.skipif(IS_MACOS, reason="https://github.com/pytorch/vision/issues/8031")
@pytest.mark.parametrize(
    "img_path",
    [pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) for jpeg_path in get_images(ENCODE_JPEG, ".jpg")],
)
@pytest.mark.parametrize("scripted", (True, False))
def test_encode_jpeg(img_path, scripted):
    """``encode_jpeg`` at quality 75 must produce the same bytes as Pillow at quality 75."""
    img = read_image(img_path)

    pil_buffer = io.BytesIO()
    F.to_pil_image(img).save(pil_buffer, format="JPEG", quality=75)
    expected_bytes = torch.frombuffer(pil_buffer.getvalue(), dtype=torch.uint8)

    encoder = torch.jit.script(encode_jpeg) if scripted else encode_jpeg
    # Encode both the tensor as returned by read_image and its contiguous form.
    for candidate in (img, img.contiguous()):
        assert_equal(encoder(candidate, quality=75), expected_bytes)


508
@pytest.mark.skipif(IS_MACOS, reason="https://github.com/pytorch/vision/issues/8031")
@pytest.mark.parametrize(
    "img_path",
    [pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) for jpeg_path in get_images(ENCODE_JPEG, ".jpg")],
)
@pytest.mark.parametrize("scripted", (True, False))
def test_write_jpeg(img_path, tmpdir, scripted):
    """``write_jpeg`` at quality 75 must write the same bytes as Pillow at quality 75."""
    out_dir = Path(tmpdir)
    torch_path = out_dir / "torch.jpg"
    pil_path = out_dir / "pil.jpg"

    img = read_image(img_path)
    pil_img = F.to_pil_image(img)

    writer = torch.jit.script(write_jpeg) if scripted else write_jpeg
    writer(img, str(torch_path), quality=75)
    pil_img.save(str(pil_path), quality=75)

    assert_equal(torch_path.read_bytes(), pil_path.read_bytes())
533
534


535
536
# Allow running this test module directly: hand this file to pytest.
if __name__ == "__main__":
    pytest.main([__file__])