test_image.py
import glob
import io
import os
import sys
from pathlib import Path

import numpy as np
import pytest
import torch
import torchvision.transforms.functional as F
from common_utils import assert_equal, needs_cuda
from PIL import __version__ as PILLOW_VERSION, Image, ImageOps
from torchvision.io.image import (
    _read_png_16,
    decode_image,
    decode_jpeg,
    decode_png,
    encode_jpeg,
    encode_png,
    ImageReadMode,
    read_file,
    read_image,
    write_file,
    write_jpeg,
    write_png,
)


IMAGE_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets")
FAKEDATA_DIR = os.path.join(IMAGE_ROOT, "fakedata")
IMAGE_DIR = os.path.join(FAKEDATA_DIR, "imagefolder")
DAMAGED_JPEG = os.path.join(IMAGE_ROOT, "damaged_jpeg")
DAMAGED_PNG = os.path.join(IMAGE_ROOT, "damaged_png")
ENCODE_JPEG = os.path.join(IMAGE_ROOT, "encode_jpeg")
INTERLACED_PNG = os.path.join(IMAGE_ROOT, "interlaced_png")
TOOSMALL_PNG = os.path.join(IMAGE_ROOT, "toosmall_png")
IS_WINDOWS = sys.platform in ("win32", "cygwin")
IS_MACOS = sys.platform == "darwin"
PILLOW_VERSION = tuple(int(x) for x in PILLOW_VERSION.split("."))


def _get_safe_image_name(name):
    # Used when we need to change the pytest "id" for an "image path" parameter.
    # If we don't, the test id (i.e. its name) will contain the whole path to the image, which is machine-specific,
    # and this creates issues when the test is running on a different machine than the one where it was collected
    # (typically, in fb internal infra).
    return name.split(os.path.sep)[-1]


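# Yield all image paths with the given extension under `directory`, skipping files
# in the "damaged_jpeg" and "jpeg_write" folders.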
def get_images(directory, img_ext):
    assert os.path.isdir(directory)
    image_paths = glob.glob(directory + f"/**/*{img_ext}", recursive=True)
    for path in image_paths:
        if path.split(os.sep)[-2] not in ["damaged_jpeg", "jpeg_write"]:
            yield path


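# Decode an image with PIL and return its raw pixel data as a tensor (HW or HWC layout).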
def pil_read_image(img_path):
    with Image.open(img_path) as img:
        return torch.from_numpy(np.array(img))


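# Bring a PIL-style array to the CHW layout used by torchvision.io:
# HWC images are permuted, and 2D (grayscale) images get a leading channel dim.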
def normalize_dimensions(img_pil):
    if len(img_pil.shape) == 3:
        img_pil = img_pil.permute(2, 0, 1)
    else:
        img_pil = img_pil.unsqueeze(0)
    return img_pil


@pytest.mark.parametrize(
    "img_path",
    [pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) for jpeg_path in get_images(IMAGE_ROOT, ".jpg")],
)
@pytest.mark.parametrize(
    "pil_mode, mode",
    [
        (None, ImageReadMode.UNCHANGED),
        ("L", ImageReadMode.GRAY),
        ("RGB", ImageReadMode.RGB),
    ],
)
def test_decode_jpeg(img_path, pil_mode, mode):

    with Image.open(img_path) as img:
        is_cmyk = img.mode == "CMYK"
        if pil_mode is not None:
            img = img.convert(pil_mode)
        img_pil = torch.from_numpy(np.array(img))
        if is_cmyk and mode == ImageReadMode.UNCHANGED:
            # flip the colors to match libjpeg
            img_pil = 255 - img_pil

    img_pil = normalize_dimensions(img_pil)
    data = read_file(img_path)
    img_ljpeg = decode_image(data, mode=mode)

    # Permit a small variation on pixel values to account for implementation
    # differences between Pillow and LibJPEG.
    abs_mean_diff = (img_ljpeg.type(torch.float32) - img_pil).abs().mean().item()
    assert abs_mean_diff < 2


@pytest.mark.parametrize("codec", ["png", "jpeg"])
@pytest.mark.parametrize("orientation", [1, 2, 3, 4, 5, 6, 7, 8, 0])
def test_decode_with_exif_orientation(tmpdir, codec, orientation):
    fp = os.path.join(tmpdir, f"exif_oriented_{orientation}.{codec}")
    t = torch.randint(0, 256, size=(3, 256, 257), dtype=torch.uint8)
    im = F.to_pil_image(t)
    exif = im.getexif()
    exif[0x0112] = orientation  # set exif orientation
    im.save(fp, codec.upper(), exif=exif.tobytes())

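    # Decoding with apply_exif_orientation=True should match PIL's ImageOps.exif_transpose()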
    data = read_file(fp)
    output = decode_image(data, apply_exif_orientation=True)

    pimg = Image.open(fp)
    pimg = ImageOps.exif_transpose(pimg)

    expected = F.pil_to_tensor(pimg)
    torch.testing.assert_close(expected, output)


@pytest.mark.parametrize("size", [65533, 1, 7, 10, 23, 33])
def test_invalid_exif(tmpdir, size):
    # Inspired from a PIL test:
    # https://github.com/python-pillow/Pillow/blob/8f63748e50378424628155994efd7e0739a4d1d1/Tests/test_file_jpeg.py#L299
    fp = os.path.join(tmpdir, "invalid_exif.jpg")
    t = torch.randint(0, 256, size=(3, 256, 257), dtype=torch.uint8)
    im = F.to_pil_image(t)
    im.save(fp, "JPEG", exif=b"1" * size)

    data = read_file(fp)
    output = decode_image(data, apply_exif_orientation=True)

    pimg = Image.open(fp)
    pimg = ImageOps.exif_transpose(pimg)

    expected = F.pil_to_tensor(pimg)
    torch.testing.assert_close(expected, output)


def test_decode_jpeg_errors():
    with pytest.raises(RuntimeError, match="Expected a non empty 1-dimensional tensor"):
        decode_jpeg(torch.empty((100, 1), dtype=torch.uint8))

    with pytest.raises(RuntimeError, match="Expected a torch.uint8 tensor"):
        decode_jpeg(torch.empty((100,), dtype=torch.float16))

    with pytest.raises(RuntimeError, match="Not a JPEG file"):
        decode_jpeg(torch.empty((100), dtype=torch.uint8))


def test_decode_bad_huffman_images():
    # sanity check: make sure we can decode the bad Huffman encoding
    bad_huff = read_file(os.path.join(DAMAGED_JPEG, "bad_huffman.jpg"))
    decode_jpeg(bad_huff)


@pytest.mark.parametrize(
    "img_path",
    [
        pytest.param(truncated_image, id=_get_safe_image_name(truncated_image))
        for truncated_image in glob.glob(os.path.join(DAMAGED_JPEG, "corrupt*.jpg"))
    ],
)
def test_damaged_corrupt_images(img_path):
    # Truncated images should raise an exception
    data = read_file(img_path)
    if "corrupt34" in img_path:
        match_message = "Image is incomplete or truncated"
    else:
        match_message = "Unsupported marker type"
    with pytest.raises(RuntimeError, match=match_message):
        decode_jpeg(data)


@pytest.mark.parametrize(
    "img_path",
    [pytest.param(png_path, id=_get_safe_image_name(png_path)) for png_path in get_images(FAKEDATA_DIR, ".png")],
)
@pytest.mark.parametrize(
    "pil_mode, mode",
    [
        (None, ImageReadMode.UNCHANGED),
        ("L", ImageReadMode.GRAY),
        ("LA", ImageReadMode.GRAY_ALPHA),
        ("RGB", ImageReadMode.RGB),
        ("RGBA", ImageReadMode.RGB_ALPHA),
    ],
)
def test_decode_png(img_path, pil_mode, mode):

    with Image.open(img_path) as img:
        if pil_mode is not None:
            img = img.convert(pil_mode)
        img_pil = torch.from_numpy(np.array(img))

    img_pil = normalize_dimensions(img_pil)

    if img_path.endswith("16.png"):
        # 16-bit image decoding is supported, but only as a private API
        # FIXME: see https://github.com/pytorch/vision/issues/4731 for potential solutions to making it public
        with pytest.raises(RuntimeError, match="At most 8-bit PNG images are supported"):
            data = read_file(img_path)
            img_lpng = decode_image(data, mode=mode)

        img_lpng = _read_png_16(img_path, mode=mode)
        assert img_lpng.dtype == torch.int32
        # PIL converts 16-bit pngs to uint8
        img_lpng = torch.round(img_lpng / (2**16 - 1) * 255).to(torch.uint8)
    else:
        data = read_file(img_path)
        img_lpng = decode_image(data, mode=mode)

    tol = 0 if pil_mode is None else 1

    if PILLOW_VERSION >= (8, 3) and pil_mode == "LA":
        # Avoid checking the transparency channel until
        # https://github.com/python-pillow/Pillow/issues/5593#issuecomment-878244910
        # is fixed.
        # TODO: remove once fix is released in PIL. Should be > 8.3.1.
        img_lpng, img_pil = img_lpng[0], img_pil[0]

    torch.testing.assert_close(img_lpng, img_pil, atol=tol, rtol=0)


def test_decode_png_errors():
    with pytest.raises(RuntimeError, match="Expected a non empty 1-dimensional tensor"):
        decode_png(torch.empty((), dtype=torch.uint8))
    with pytest.raises(RuntimeError, match="Content is not png"):
        decode_png(torch.randint(3, 5, (300,), dtype=torch.uint8))
    with pytest.raises(RuntimeError, match="Out of bound read in decode_png"):
        decode_png(read_file(os.path.join(DAMAGED_PNG, "sigsegv.png")))
    with pytest.raises(RuntimeError, match="Content is too small for png"):
        decode_png(read_file(os.path.join(TOOSMALL_PNG, "heapbof.png")))


@pytest.mark.parametrize(
    "img_path",
    [pytest.param(png_path, id=_get_safe_image_name(png_path)) for png_path in get_images(IMAGE_DIR, ".png")],
)
def test_encode_png(img_path):
    pil_image = Image.open(img_path)
    img_pil = torch.from_numpy(np.array(pil_image))
    img_pil = img_pil.permute(2, 0, 1)
    png_buf = encode_png(img_pil, compression_level=6)

    rec_img = Image.open(io.BytesIO(bytes(png_buf.tolist())))
    rec_img = torch.from_numpy(np.array(rec_img))
    rec_img = rec_img.permute(2, 0, 1)

    assert_equal(img_pil, rec_img)


def test_encode_png_errors():
    with pytest.raises(RuntimeError, match="Input tensor dtype should be uint8"):
        encode_png(torch.empty((3, 100, 100), dtype=torch.float32))

    with pytest.raises(RuntimeError, match="Compression level should be between 0 and 9"):
        encode_png(torch.empty((3, 100, 100), dtype=torch.uint8), compression_level=-1)

    with pytest.raises(RuntimeError, match="Compression level should be between 0 and 9"):
        encode_png(torch.empty((3, 100, 100), dtype=torch.uint8), compression_level=10)

    with pytest.raises(RuntimeError, match="The number of channels should be 1 or 3, got: 5"):
        encode_png(torch.empty((5, 100, 100), dtype=torch.uint8))


@pytest.mark.parametrize(
    "img_path",
    [pytest.param(png_path, id=_get_safe_image_name(png_path)) for png_path in get_images(IMAGE_DIR, ".png")],
)
def test_write_png(img_path, tmpdir):
    pil_image = Image.open(img_path)
    img_pil = torch.from_numpy(np.array(pil_image))
    img_pil = img_pil.permute(2, 0, 1)

    filename, _ = os.path.splitext(os.path.basename(img_path))
    torch_png = os.path.join(tmpdir, f"{filename}_torch.png")
    write_png(img_pil, torch_png, compression_level=6)
    saved_image = torch.from_numpy(np.array(Image.open(torch_png)))
    saved_image = saved_image.permute(2, 0, 1)

    assert_equal(img_pil, saved_image)


def test_read_file(tmpdir):
    fname, content = "test1.bin", b"TorchVision\211\n"
    fpath = os.path.join(tmpdir, fname)
    with open(fpath, "wb") as f:
        f.write(content)

    data = read_file(fpath)
    expected = torch.tensor(list(content), dtype=torch.uint8)
    os.unlink(fpath)
    assert_equal(data, expected)

    with pytest.raises(RuntimeError, match="No such file or directory: 'tst'"):
        read_file("tst")


def test_read_file_non_ascii(tmpdir):
    fname, content = "日本語(Japanese).bin", b"TorchVision\211\n"
    fpath = os.path.join(tmpdir, fname)
    with open(fpath, "wb") as f:
        f.write(content)

    data = read_file(fpath)
    expected = torch.tensor(list(content), dtype=torch.uint8)
    os.unlink(fpath)
    assert_equal(data, expected)


def test_write_file(tmpdir):
    fname, content = "test1.bin", b"TorchVision\211\n"
    fpath = os.path.join(tmpdir, fname)
    content_tensor = torch.tensor(list(content), dtype=torch.uint8)
    write_file(fpath, content_tensor)

    with open(fpath, "rb") as f:
        saved_content = f.read()
    os.unlink(fpath)
    assert content == saved_content


def test_write_file_non_ascii(tmpdir):
    fname, content = "日本語(Japanese).bin", b"TorchVision\211\n"
    fpath = os.path.join(tmpdir, fname)
    content_tensor = torch.tensor(list(content), dtype=torch.uint8)
    write_file(fpath, content_tensor)

    with open(fpath, "rb") as f:
        saved_content = f.read()
    os.unlink(fpath)
    assert content == saved_content


@pytest.mark.parametrize(
    "shape",
    [
        (27, 27),
        (60, 60),
        (105, 105),
    ],
)
def test_read_1_bit_png(shape, tmpdir):
    np_rng = np.random.RandomState(0)
    image_path = os.path.join(tmpdir, f"test_{shape}.png")
    pixels = np_rng.rand(*shape) > 0.5
    img = Image.fromarray(pixels)
    img.save(image_path)
    img1 = read_image(image_path)
    img2 = normalize_dimensions(torch.as_tensor(pixels * 255, dtype=torch.uint8))
    assert_equal(img1, img2)


@pytest.mark.parametrize(
    "shape",
    [
        (27, 27),
        (60, 60),
        (105, 105),
    ],
)
@pytest.mark.parametrize(
    "mode",
    [
        ImageReadMode.UNCHANGED,
        ImageReadMode.GRAY,
    ],
)
def test_read_1_bit_png_consistency(shape, mode, tmpdir):
    np_rng = np.random.RandomState(0)
    image_path = os.path.join(tmpdir, f"test_{shape}.png")
    pixels = np_rng.rand(*shape) > 0.5
    img = Image.fromarray(pixels)
    img.save(image_path)
    img1 = read_image(image_path, mode)
    img2 = read_image(image_path, mode)
    assert_equal(img1, img2)


def test_read_interlaced_png():
    imgs = list(get_images(INTERLACED_PNG, ".png"))
    with Image.open(imgs[0]) as im1, Image.open(imgs[1]) as im2:
        assert not (im1.info.get("interlace") is im2.info.get("interlace"))
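    # One file is interlaced and the other is not; decoding should give identical pixels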
    img1 = read_image(imgs[0])
    img2 = read_image(imgs[1])
    assert_equal(img1, img2)


@needs_cuda
@pytest.mark.parametrize(
    "img_path",
    [pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) for jpeg_path in get_images(IMAGE_ROOT, ".jpg")],
)
@pytest.mark.parametrize("mode", [ImageReadMode.UNCHANGED, ImageReadMode.GRAY, ImageReadMode.RGB])
@pytest.mark.parametrize("scripted", (False, True))
def test_decode_jpeg_cuda(mode, img_path, scripted):
    if "cmyk" in img_path:
        pytest.xfail("Decoding a CMYK jpeg isn't supported")

    data = read_file(img_path)
    img = decode_image(data, mode=mode)
    f = torch.jit.script(decode_jpeg) if scripted else decode_jpeg
    img_nvjpeg = f(data, mode=mode, device="cuda")

    # Some difference expected between jpeg implementations
    assert (img.float() - img_nvjpeg.cpu().float()).abs().mean() < 2


@needs_cuda
def test_decode_image_cuda_raises():
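    # Random bytes are not a valid encoded image, so decoding should fail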
    data = torch.randint(0, 127, size=(255,), device="cuda", dtype=torch.uint8)
    with pytest.raises(RuntimeError):
        decode_image(data)


@needs_cuda
@pytest.mark.parametrize("cuda_device", ("cuda", "cuda:0", torch.device("cuda")))
def test_decode_jpeg_cuda_device_param(cuda_device):
    """Make sure we can pass a string or a torch.device as device param"""
    path = next(path for path in get_images(IMAGE_ROOT, ".jpg") if "cmyk" not in path)
    data = read_file(path)
    decode_jpeg(data, device=cuda_device)


@needs_cuda
def test_decode_jpeg_cuda_errors():
    data = read_file(next(get_images(IMAGE_ROOT, ".jpg")))
    with pytest.raises(RuntimeError, match="Expected a non empty 1-dimensional tensor"):
        decode_jpeg(data.reshape(-1, 1), device="cuda")
    with pytest.raises(RuntimeError, match="input tensor must be on CPU"):
        decode_jpeg(data.to("cuda"), device="cuda")
    with pytest.raises(RuntimeError, match="Expected a torch.uint8 tensor"):
        decode_jpeg(data.to(torch.float), device="cuda")
    with pytest.raises(RuntimeError, match="Expected a cuda device"):
        torch.ops.image.decode_jpeg_cuda(data, ImageReadMode.UNCHANGED.value, "cpu")


def test_encode_jpeg_errors():

    with pytest.raises(RuntimeError, match="Input tensor dtype should be uint8"):
        encode_jpeg(torch.empty((3, 100, 100), dtype=torch.float32))

    with pytest.raises(ValueError, match="Image quality should be a positive number between 1 and 100"):
        encode_jpeg(torch.empty((3, 100, 100), dtype=torch.uint8), quality=-1)

    with pytest.raises(ValueError, match="Image quality should be a positive number between 1 and 100"):
        encode_jpeg(torch.empty((3, 100, 100), dtype=torch.uint8), quality=101)

    with pytest.raises(RuntimeError, match="The number of channels should be 1 or 3, got: 5"):
        encode_jpeg(torch.empty((5, 100, 100), dtype=torch.uint8))

    with pytest.raises(RuntimeError, match="Input data should be a 3-dimensional tensor"):
        encode_jpeg(torch.empty((1, 3, 100, 100), dtype=torch.uint8))

    with pytest.raises(RuntimeError, match="Input data should be a 3-dimensional tensor"):
        encode_jpeg(torch.empty((100, 100), dtype=torch.uint8))


@pytest.mark.skipif(IS_MACOS, reason="https://github.com/pytorch/vision/issues/8031")
@pytest.mark.parametrize(
    "img_path",
    [pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) for jpeg_path in get_images(ENCODE_JPEG, ".jpg")],
)
def test_encode_jpeg(img_path):
    img = read_image(img_path)

    pil_img = F.to_pil_image(img)
    buf = io.BytesIO()
    pil_img.save(buf, format="JPEG", quality=75)

    encoded_jpeg_pil = torch.frombuffer(buf.getvalue(), dtype=torch.uint8)

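    # Both the original image and a contiguous copy should encode to the same bytes as PIL at the same quality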
    for src_img in [img, img.contiguous()]:
        encoded_jpeg_torch = encode_jpeg(src_img, quality=75)
        assert_equal(encoded_jpeg_torch, encoded_jpeg_pil)


@pytest.mark.skipif(IS_MACOS, reason="https://github.com/pytorch/vision/issues/8031")
@pytest.mark.parametrize(
    "img_path",
    [pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) for jpeg_path in get_images(ENCODE_JPEG, ".jpg")],
)
def test_write_jpeg(img_path, tmpdir):
    tmpdir = Path(tmpdir)
    img = read_image(img_path)
    pil_img = F.to_pil_image(img)

    torch_jpeg = str(tmpdir / "torch.jpg")
    pil_jpeg = str(tmpdir / "pil.jpg")

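    # Writing with torchvision and with PIL at the same quality should produce identical files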
    write_jpeg(img, torch_jpeg, quality=75)
    pil_img.save(pil_jpeg, quality=75)

    with open(torch_jpeg, "rb") as f:
        torch_bytes = f.read()

    with open(pil_jpeg, "rb") as f:
        pil_bytes = f.read()

    assert_equal(torch_bytes, pil_bytes)


if __name__ == "__main__":
    pytest.main([__file__])