import glob
import io
import os
import sys
from pathlib import Path

import numpy as np
import pytest
import torch
import torchvision.transforms.functional as F
from common_utils import assert_equal, needs_cuda
from PIL import __version__ as PILLOW_VERSION, Image, ImageOps
from torchvision.io.image import (
    _read_png_16,
    decode_image,
    decode_jpeg,
    decode_png,
    encode_jpeg,
    encode_png,
    ImageReadMode,
    read_file,
    read_image,
    write_file,
    write_jpeg,
    write_png,
)

IMAGE_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets")
FAKEDATA_DIR = os.path.join(IMAGE_ROOT, "fakedata")
IMAGE_DIR = os.path.join(FAKEDATA_DIR, "imagefolder")
DAMAGED_JPEG = os.path.join(IMAGE_ROOT, "damaged_jpeg")
DAMAGED_PNG = os.path.join(IMAGE_ROOT, "damaged_png")
ENCODE_JPEG = os.path.join(IMAGE_ROOT, "encode_jpeg")
INTERLACED_PNG = os.path.join(IMAGE_ROOT, "interlaced_png")
TOOSMALL_PNG = os.path.join(IMAGE_ROOT, "toosmall_png")
IS_WINDOWS = sys.platform in ("win32", "cygwin")
IS_MACOS = sys.platform == "darwin"
PILLOW_VERSION = tuple(int(x) for x in PILLOW_VERSION.split("."))


def _get_safe_image_name(name):
    # Used to set the pytest "id" for an "image path" parameter.
    # Without it, the test id (i.e. its name) would contain the whole path to the image, which is machine-specific
    # and causes issues when the test runs on a different machine than the one where it was collected
    # (typically, in fb internal infra).
    return name.split(os.path.sep)[-1]


def get_images(directory, img_ext):
    assert os.path.isdir(directory)
    image_paths = glob.glob(directory + f"/**/*{img_ext}", recursive=True)
    for path in image_paths:
        if path.split(os.sep)[-2] not in ["damaged_jpeg", "jpeg_write"]:
            yield path


def pil_read_image(img_path):
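    # Open the image with PIL and return its pixel data as a torch tensor (HWC, or HW for single-channel images).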
    with Image.open(img_path) as img:
        return torch.from_numpy(np.array(img))


def normalize_dimensions(img_pil):
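    # Convert HWC tensors to CHW; 2-D (grayscale) tensors get a leading channel dimension.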
    if len(img_pil.shape) == 3:
        img_pil = img_pil.permute(2, 0, 1)
    else:
        img_pil = img_pil.unsqueeze(0)
    return img_pil


@pytest.mark.parametrize(
    "img_path",
    [pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) for jpeg_path in get_images(IMAGE_ROOT, ".jpg")],
)
@pytest.mark.parametrize(
    "pil_mode, mode",
    [
        (None, ImageReadMode.UNCHANGED),
        ("L", ImageReadMode.GRAY),
        ("RGB", ImageReadMode.RGB),
    ],
)
def test_decode_jpeg(img_path, pil_mode, mode):

    with Image.open(img_path) as img:
        is_cmyk = img.mode == "CMYK"
        if pil_mode is not None:
            img = img.convert(pil_mode)
        img_pil = torch.from_numpy(np.array(img))
        if is_cmyk and mode == ImageReadMode.UNCHANGED:
            # flip the colors to match libjpeg
            img_pil = 255 - img_pil

    img_pil = normalize_dimensions(img_pil)
    data = read_file(img_path)
    img_ljpeg = decode_image(data, mode=mode)

    # Permit a small variation on pixel values to account for implementation
    # differences between Pillow and LibJPEG.
    abs_mean_diff = (img_ljpeg.type(torch.float32) - img_pil).abs().mean().item()
    assert abs_mean_diff < 2


@pytest.mark.parametrize("orientation", [1, 2, 3, 4, 5, 6, 7, 8, 0])
def test_decode_jpeg_with_exif_orientation(tmpdir, orientation):
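    # Save a JPEG with an EXIF orientation tag and check that decode_image(..., apply_exif_orientation=True)
    # matches PIL's exif_transpose.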
    fp = os.path.join(tmpdir, f"exif_oriented_{orientation}.jpg")
    t = torch.randint(0, 256, size=(3, 256, 257), dtype=torch.uint8)
    im = F.to_pil_image(t)
    exif = im.getexif()
    exif[0x0112] = orientation  # set exif orientation
    im.save(fp, "JPEG", exif=exif.tobytes())

    data = read_file(fp)
    output = decode_image(data, apply_exif_orientation=True)

    pimg = Image.open(fp)
    pimg = ImageOps.exif_transpose(pimg)

    expected = F.pil_to_tensor(pimg)
    torch.testing.assert_close(expected, output)


@pytest.mark.parametrize("size", [65533, 1, 7, 10, 23, 33])
def test_invalid_exif(tmpdir, size):
    # Inspired from a PIL test:
    # https://github.com/python-pillow/Pillow/blob/8f63748e50378424628155994efd7e0739a4d1d1/Tests/test_file_jpeg.py#L299
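    # The EXIF payload below is garbage of varying size; decoding should still succeed and match PIL.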
    fp = os.path.join(tmpdir, "invalid_exif.jpg")
    t = torch.randint(0, 256, size=(3, 256, 257), dtype=torch.uint8)
    im = F.to_pil_image(t)
    im.save(fp, "JPEG", exif=b"1" * size)

    data = read_file(fp)
    output = decode_image(data, apply_exif_orientation=True)

    pimg = Image.open(fp)
    pimg = ImageOps.exif_transpose(pimg)

    expected = F.pil_to_tensor(pimg)
    torch.testing.assert_close(expected, output)


def test_decode_jpeg_errors():
    with pytest.raises(RuntimeError, match="Expected a non empty 1-dimensional tensor"):
        decode_jpeg(torch.empty((100, 1), dtype=torch.uint8))

    with pytest.raises(RuntimeError, match="Expected a torch.uint8 tensor"):
        decode_jpeg(torch.empty((100,), dtype=torch.float16))

    with pytest.raises(RuntimeError, match="Not a JPEG file"):
        decode_jpeg(torch.empty((100,), dtype=torch.uint8))


def test_decode_bad_huffman_images():
    # sanity check: make sure we can decode the bad Huffman encoding
    bad_huff = read_file(os.path.join(DAMAGED_JPEG, "bad_huffman.jpg"))
    decode_jpeg(bad_huff)


@pytest.mark.parametrize(
    "img_path",
    [
        pytest.param(truncated_image, id=_get_safe_image_name(truncated_image))
        for truncated_image in glob.glob(os.path.join(DAMAGED_JPEG, "corrupt*.jpg"))
    ],
)
def test_damaged_corrupt_images(img_path):
    # Truncated images should raise an exception
    data = read_file(img_path)
    if "corrupt34" in img_path:
        match_message = "Image is incomplete or truncated"
    else:
        match_message = "Unsupported marker type"
    with pytest.raises(RuntimeError, match=match_message):
        decode_jpeg(data)


@pytest.mark.parametrize(
    "img_path",
    [pytest.param(png_path, id=_get_safe_image_name(png_path)) for png_path in get_images(FAKEDATA_DIR, ".png")],
)
@pytest.mark.parametrize(
    "pil_mode, mode",
    [
        (None, ImageReadMode.UNCHANGED),
        ("L", ImageReadMode.GRAY),
        ("LA", ImageReadMode.GRAY_ALPHA),
        ("RGB", ImageReadMode.RGB),
        ("RGBA", ImageReadMode.RGB_ALPHA),
    ],
)
def test_decode_png(img_path, pil_mode, mode):

    with Image.open(img_path) as img:
        if pil_mode is not None:
            img = img.convert(pil_mode)
        img_pil = torch.from_numpy(np.array(img))

    img_pil = normalize_dimensions(img_pil)

    if img_path.endswith("16.png"):
        # 16-bit image decoding is supported, but only as a private API
        # FIXME: see https://github.com/pytorch/vision/issues/4731 for potential solutions to making it public
        with pytest.raises(RuntimeError, match="At most 8-bit PNG images are supported"):
            data = read_file(img_path)
            img_lpng = decode_image(data, mode=mode)

        img_lpng = _read_png_16(img_path, mode=mode)
        assert img_lpng.dtype == torch.int32
        # PIL converts 16-bit PNGs to uint8
        img_lpng = torch.round(img_lpng / (2**16 - 1) * 255).to(torch.uint8)
    else:
        data = read_file(img_path)
        img_lpng = decode_image(data, mode=mode)

    tol = 0 if pil_mode is None else 1

    if PILLOW_VERSION >= (8, 3) and pil_mode == "LA":
        # Avoid checking the transparency channel until
        # https://github.com/python-pillow/Pillow/issues/5593#issuecomment-878244910
        # is fixed.
        # TODO: remove once fix is released in PIL. Should be > 8.3.1.
        img_lpng, img_pil = img_lpng[0], img_pil[0]

    torch.testing.assert_close(img_lpng, img_pil, atol=tol, rtol=0)


def test_decode_png_errors():
    with pytest.raises(RuntimeError, match="Expected a non empty 1-dimensional tensor"):
        decode_png(torch.empty((), dtype=torch.uint8))
    with pytest.raises(RuntimeError, match="Content is not png"):
        decode_png(torch.randint(3, 5, (300,), dtype=torch.uint8))
    with pytest.raises(RuntimeError, match="Out of bound read in decode_png"):
        decode_png(read_file(os.path.join(DAMAGED_PNG, "sigsegv.png")))
    with pytest.raises(RuntimeError, match="Content is too small for png"):
        decode_png(read_file(os.path.join(TOOSMALL_PNG, "heapbof.png")))


@pytest.mark.parametrize(
    "img_path",
    [pytest.param(png_path, id=_get_safe_image_name(png_path)) for png_path in get_images(IMAGE_DIR, ".png")],
)
def test_encode_png(img_path):
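    # Round-trip: encode with torchvision, decode the bytes with PIL; PNG is lossless so pixels must match exactly.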
    pil_image = Image.open(img_path)
    img_pil = torch.from_numpy(np.array(pil_image))
    img_pil = img_pil.permute(2, 0, 1)
    png_buf = encode_png(img_pil, compression_level=6)

    rec_img = Image.open(io.BytesIO(bytes(png_buf.tolist())))
    rec_img = torch.from_numpy(np.array(rec_img))
    rec_img = rec_img.permute(2, 0, 1)

    assert_equal(img_pil, rec_img)


def test_encode_png_errors():
    with pytest.raises(RuntimeError, match="Input tensor dtype should be uint8"):
        encode_png(torch.empty((3, 100, 100), dtype=torch.float32))

    with pytest.raises(RuntimeError, match="Compression level should be between 0 and 9"):
        encode_png(torch.empty((3, 100, 100), dtype=torch.uint8), compression_level=-1)

    with pytest.raises(RuntimeError, match="Compression level should be between 0 and 9"):
        encode_png(torch.empty((3, 100, 100), dtype=torch.uint8), compression_level=10)

    with pytest.raises(RuntimeError, match="The number of channels should be 1 or 3, got: 5"):
        encode_png(torch.empty((5, 100, 100), dtype=torch.uint8))


@pytest.mark.parametrize(
    "img_path",
    [pytest.param(png_path, id=_get_safe_image_name(png_path)) for png_path in get_images(IMAGE_DIR, ".png")],
)
def test_write_png(img_path, tmpdir):
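    # Write with write_png, re-read with PIL, and check the pixels survive unchanged.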
    pil_image = Image.open(img_path)
    img_pil = torch.from_numpy(np.array(pil_image))
    img_pil = img_pil.permute(2, 0, 1)

    filename, _ = os.path.splitext(os.path.basename(img_path))
    torch_png = os.path.join(tmpdir, f"{filename}_torch.png")
    write_png(img_pil, torch_png, compression_level=6)
    saved_image = torch.from_numpy(np.array(Image.open(torch_png)))
    saved_image = saved_image.permute(2, 0, 1)

    assert_equal(img_pil, saved_image)


def test_read_file(tmpdir):
    fname, content = "test1.bin", b"TorchVision\211\n"
    fpath = os.path.join(tmpdir, fname)
    with open(fpath, "wb") as f:
        f.write(content)

    data = read_file(fpath)
    expected = torch.tensor(list(content), dtype=torch.uint8)
    os.unlink(fpath)
    assert_equal(data, expected)

    with pytest.raises(RuntimeError, match="No such file or directory: 'tst'"):
        read_file("tst")


def test_read_file_non_ascii(tmpdir):
    fname, content = "日本語(Japanese).bin", b"TorchVision\211\n"
    fpath = os.path.join(tmpdir, fname)
    with open(fpath, "wb") as f:
        f.write(content)

    data = read_file(fpath)
    expected = torch.tensor(list(content), dtype=torch.uint8)
    os.unlink(fpath)
    assert_equal(data, expected)


def test_write_file(tmpdir):
    fname, content = "test1.bin", b"TorchVision\211\n"
    fpath = os.path.join(tmpdir, fname)
    content_tensor = torch.tensor(list(content), dtype=torch.uint8)
    write_file(fpath, content_tensor)

    with open(fpath, "rb") as f:
        saved_content = f.read()
    os.unlink(fpath)
    assert content == saved_content


def test_write_file_non_ascii(tmpdir):
    fname, content = "日本語(Japanese).bin", b"TorchVision\211\n"
    fpath = os.path.join(tmpdir, fname)
    content_tensor = torch.tensor(list(content), dtype=torch.uint8)
    write_file(fpath, content_tensor)

    with open(fpath, "rb") as f:
        saved_content = f.read()
    os.unlink(fpath)
    assert content == saved_content


@pytest.mark.parametrize(
    "shape",
    [
        (27, 27),
        (60, 60),
        (105, 105),
    ],
)
def test_read_1_bit_png(shape, tmpdir):
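    # 1-bit (mode "1") PNGs should decode to uint8 tensors with values 0 and 255.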
    np_rng = np.random.RandomState(0)
    image_path = os.path.join(tmpdir, f"test_{shape}.png")
    pixels = np_rng.rand(*shape) > 0.5
    img = Image.fromarray(pixels)
    img.save(image_path)
    img1 = read_image(image_path)
    img2 = normalize_dimensions(torch.as_tensor(pixels * 255, dtype=torch.uint8))
    assert_equal(img1, img2)


@pytest.mark.parametrize(
    "shape",
    [
        (27, 27),
        (60, 60),
        (105, 105),
    ],
)
@pytest.mark.parametrize(
    "mode",
    [
        ImageReadMode.UNCHANGED,
        ImageReadMode.GRAY,
    ],
)
def test_read_1_bit_png_consistency(shape, mode, tmpdir):
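    # Reading the same 1-bit PNG twice with the same mode must give identical results.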
    np_rng = np.random.RandomState(0)
    image_path = os.path.join(tmpdir, f"test_{shape}.png")
    pixels = np_rng.rand(*shape) > 0.5
    img = Image.fromarray(pixels)
    img.save(image_path)
    img1 = read_image(image_path, mode)
    img2 = read_image(image_path, mode)
    assert_equal(img1, img2)


def test_read_interlaced_png():
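    # The two assets encode the same image, one interlaced and one not; both must decode to identical tensors.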
    imgs = list(get_images(INTERLACED_PNG, ".png"))
    with Image.open(imgs[0]) as im1, Image.open(imgs[1]) as im2:
        assert not (im1.info.get("interlace") is im2.info.get("interlace"))
    img1 = read_image(imgs[0])
    img2 = read_image(imgs[1])
    assert_equal(img1, img2)


@needs_cuda
@pytest.mark.parametrize(
    "img_path",
    [pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) for jpeg_path in get_images(IMAGE_ROOT, ".jpg")],
)
@pytest.mark.parametrize("mode", [ImageReadMode.UNCHANGED, ImageReadMode.GRAY, ImageReadMode.RGB])
@pytest.mark.parametrize("scripted", (False, True))
def test_decode_jpeg_cuda(mode, img_path, scripted):
    if "cmyk" in img_path:
        pytest.xfail("Decoding a CMYK jpeg isn't supported")

    data = read_file(img_path)
    img = decode_image(data, mode=mode)
    f = torch.jit.script(decode_jpeg) if scripted else decode_jpeg
    img_nvjpeg = f(data, mode=mode, device="cuda")

    # Some difference expected between jpeg implementations
    assert (img.float() - img_nvjpeg.cpu().float()).abs().mean() < 2


@needs_cuda
def test_decode_image_cuda_raises():
    data = torch.randint(0, 127, size=(255,), device="cuda", dtype=torch.uint8)
    with pytest.raises(RuntimeError):
        decode_image(data)


@needs_cuda
@pytest.mark.parametrize("cuda_device", ("cuda", "cuda:0", torch.device("cuda")))
def test_decode_jpeg_cuda_device_param(cuda_device):
    """Make sure we can pass a string or a torch.device as device param"""
    path = next(path for path in get_images(IMAGE_ROOT, ".jpg") if "cmyk" not in path)
    data = read_file(path)
    decode_jpeg(data, device=cuda_device)


@needs_cuda
def test_decode_jpeg_cuda_errors():
    data = read_file(next(get_images(IMAGE_ROOT, ".jpg")))
    with pytest.raises(RuntimeError, match="Expected a non empty 1-dimensional tensor"):
        decode_jpeg(data.reshape(-1, 1), device="cuda")
    with pytest.raises(RuntimeError, match="input tensor must be on CPU"):
        decode_jpeg(data.to("cuda"), device="cuda")
    with pytest.raises(RuntimeError, match="Expected a torch.uint8 tensor"):
        decode_jpeg(data.to(torch.float), device="cuda")
    with pytest.raises(RuntimeError, match="Expected a cuda device"):
        torch.ops.image.decode_jpeg_cuda(data, ImageReadMode.UNCHANGED.value, "cpu")


def test_encode_jpeg_errors():

    with pytest.raises(RuntimeError, match="Input tensor dtype should be uint8"):
        encode_jpeg(torch.empty((3, 100, 100), dtype=torch.float32))

    with pytest.raises(ValueError, match="Image quality should be a positive number between 1 and 100"):
        encode_jpeg(torch.empty((3, 100, 100), dtype=torch.uint8), quality=-1)

    with pytest.raises(ValueError, match="Image quality should be a positive number between 1 and 100"):
        encode_jpeg(torch.empty((3, 100, 100), dtype=torch.uint8), quality=101)

    with pytest.raises(RuntimeError, match="The number of channels should be 1 or 3, got: 5"):
        encode_jpeg(torch.empty((5, 100, 100), dtype=torch.uint8))

    with pytest.raises(RuntimeError, match="Input data should be a 3-dimensional tensor"):
        encode_jpeg(torch.empty((1, 3, 100, 100), dtype=torch.uint8))

    with pytest.raises(RuntimeError, match="Input data should be a 3-dimensional tensor"):
        encode_jpeg(torch.empty((100, 100), dtype=torch.uint8))


@pytest.mark.skipif(IS_MACOS, reason="https://github.com/pytorch/vision/issues/8031")
@pytest.mark.parametrize(
    "img_path",
    [pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) for jpeg_path in get_images(ENCODE_JPEG, ".jpg")],
)
def test_encode_jpeg(img_path):
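    # Encoding with torchvision should produce the same bytes as PIL at the same quality setting.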
    img = read_image(img_path)

    pil_img = F.to_pil_image(img)
    buf = io.BytesIO()
    pil_img.save(buf, format="JPEG", quality=75)

    encoded_jpeg_pil = torch.frombuffer(buf.getvalue(), dtype=torch.uint8)

    for src_img in [img, img.contiguous()]:
        encoded_jpeg_torch = encode_jpeg(src_img, quality=75)
        assert_equal(encoded_jpeg_torch, encoded_jpeg_pil)


@pytest.mark.skipif(IS_MACOS, reason="https://github.com/pytorch/vision/issues/8031")
@pytest.mark.parametrize(
    "img_path",
    [pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) for jpeg_path in get_images(ENCODE_JPEG, ".jpg")],
)
def test_write_jpeg(img_path, tmpdir):
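    # write_jpeg and PIL's save should produce byte-identical files at the same quality.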
    tmpdir = Path(tmpdir)
    img = read_image(img_path)
    pil_img = F.to_pil_image(img)

    torch_jpeg = str(tmpdir / "torch.jpg")
    pil_jpeg = str(tmpdir / "pil.jpg")

    write_jpeg(img, torch_jpeg, quality=75)
    pil_img.save(pil_jpeg, quality=75)

    with open(torch_jpeg, "rb") as f:
        torch_bytes = f.read()

    with open(pil_jpeg, "rb") as f:
        pil_bytes = f.read()

    assert_equal(torch_bytes, pil_bytes)


if __name__ == "__main__":
    pytest.main([__file__])