"sgl-kernel/git@developer.sourcefind.cn:change/sglang.git" did not exist on "8e7b31546c724326780058c510fce3968c0b5285"
test_image.py 21.5 KB
Newer Older
import glob
import io
import os
import re
import sys
from pathlib import Path

import numpy as np
import pytest
import requests
import torch
import torchvision.transforms.functional as F
from common_utils import assert_equal, needs_cuda
from PIL import __version__ as PILLOW_VERSION, Image, ImageOps, ImageSequence
from torchvision.io.image import (
    _read_png_16,
    decode_gif,
    decode_image,
    decode_jpeg,
    decode_png,
    encode_jpeg,
    encode_png,
    ImageReadMode,
    read_file,
    read_image,
    write_file,
    write_jpeg,
    write_png,
)

IMAGE_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets")
FAKEDATA_DIR = os.path.join(IMAGE_ROOT, "fakedata")
IMAGE_DIR = os.path.join(FAKEDATA_DIR, "imagefolder")
DAMAGED_JPEG = os.path.join(IMAGE_ROOT, "damaged_jpeg")
DAMAGED_PNG = os.path.join(IMAGE_ROOT, "damaged_png")
ENCODE_JPEG = os.path.join(IMAGE_ROOT, "encode_jpeg")
INTERLACED_PNG = os.path.join(IMAGE_ROOT, "interlaced_png")
TOOSMALL_PNG = os.path.join(IMAGE_ROOT, "toosmall_png")
IS_WINDOWS = sys.platform in ("win32", "cygwin")
IS_MACOS = sys.platform == "darwin"
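# Parse the Pillow version string (e.g. "9.3.0") into a tuple so it can be compared against version tuples below.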
PILLOW_VERSION = tuple(int(x) for x in PILLOW_VERSION.split("."))


def _get_safe_image_name(name):
    # Used when we need to change the pytest "id" for an "image path" parameter.
    # If we don't, the test id (i.e. its name) will contain the whole path to the image, which is machine-specific
    # and creates issues when the test runs on a different machine than the one where it was collected
    # (typically, in fb internal infra).
    return name.split(os.path.sep)[-1]


def get_images(directory, img_ext):
    assert os.path.isdir(directory)
    image_paths = glob.glob(directory + f"/**/*{img_ext}", recursive=True)
    for path in image_paths:
        if path.split(os.sep)[-2] not in ["damaged_jpeg", "jpeg_write"]:
            yield path


def pil_read_image(img_path):
    with Image.open(img_path) as img:
        return torch.from_numpy(np.array(img))


def normalize_dimensions(img_pil):
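    # PIL yields HWC arrays for color images and HW for grayscale; convert both to the CHW layout returned by the torchvision decoders.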
    if len(img_pil.shape) == 3:
        img_pil = img_pil.permute(2, 0, 1)
    else:
        img_pil = img_pil.unsqueeze(0)
    return img_pil


@pytest.mark.parametrize(
    "img_path",
    [pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) for jpeg_path in get_images(IMAGE_ROOT, ".jpg")],
)
@pytest.mark.parametrize(
    "pil_mode, mode",
    [
        (None, ImageReadMode.UNCHANGED),
        ("L", ImageReadMode.GRAY),
        ("RGB", ImageReadMode.RGB),
    ],
)
@pytest.mark.parametrize("scripted", (False, True))
@pytest.mark.parametrize("decode_fun", (decode_jpeg, decode_image))
def test_decode_jpeg(img_path, pil_mode, mode, scripted, decode_fun):

    with Image.open(img_path) as img:
        is_cmyk = img.mode == "CMYK"
        if pil_mode is not None:
            img = img.convert(pil_mode)
        img_pil = torch.from_numpy(np.array(img))
        if is_cmyk and mode == ImageReadMode.UNCHANGED:
            # flip the colors to match libjpeg
            img_pil = 255 - img_pil

    img_pil = normalize_dimensions(img_pil)
    data = read_file(img_path)
    if scripted:
        decode_fun = torch.jit.script(decode_fun)
    img_ljpeg = decode_fun(data, mode=mode)

    # Permit a small variation on pixel values to account for implementation
    # differences between Pillow and LibJPEG.
    abs_mean_diff = (img_ljpeg.type(torch.float32) - img_pil).abs().mean().item()
    assert abs_mean_diff < 2


@pytest.mark.parametrize("codec", ["png", "jpeg"])
@pytest.mark.parametrize("orientation", [1, 2, 3, 4, 5, 6, 7, 8, 0])
def test_decode_with_exif_orientation(tmpdir, codec, orientation):
    fp = os.path.join(tmpdir, f"exif_oriented_{orientation}.{codec}")
    t = torch.randint(0, 256, size=(3, 256, 257), dtype=torch.uint8)
    im = F.to_pil_image(t)
    exif = im.getexif()
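    # 0x0112 is the standard EXIF Orientation tag; values 1-8 describe the rotation/flip needed to display the image upright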
    exif[0x0112] = orientation  # set exif orientation
    im.save(fp, codec.upper(), exif=exif.tobytes())

    data = read_file(fp)
    output = decode_image(data, apply_exif_orientation=True)

    pimg = Image.open(fp)
    pimg = ImageOps.exif_transpose(pimg)

    expected = F.pil_to_tensor(pimg)
    torch.testing.assert_close(expected, output)


@pytest.mark.parametrize("size", [65533, 1, 7, 10, 23, 33])
def test_invalid_exif(tmpdir, size):
    # Inspired from a PIL test:
    # https://github.com/python-pillow/Pillow/blob/8f63748e50378424628155994efd7e0739a4d1d1/Tests/test_file_jpeg.py#L299
    fp = os.path.join(tmpdir, "invalid_exif.jpg")
    t = torch.randint(0, 256, size=(3, 256, 257), dtype=torch.uint8)
    im = F.to_pil_image(t)
    im.save(fp, "JPEG", exif=b"1" * size)

    data = read_file(fp)
    output = decode_image(data, apply_exif_orientation=True)

    pimg = Image.open(fp)
    pimg = ImageOps.exif_transpose(pimg)

    expected = F.pil_to_tensor(pimg)
    torch.testing.assert_close(expected, output)


def test_decode_jpeg_errors():
    with pytest.raises(RuntimeError, match="Expected a non empty 1-dimensional tensor"):
        decode_jpeg(torch.empty((100, 1), dtype=torch.uint8))

    with pytest.raises(RuntimeError, match="Expected a torch.uint8 tensor"):
        decode_jpeg(torch.empty((100,), dtype=torch.float16))

    with pytest.raises(RuntimeError, match="Not a JPEG file"):
        decode_jpeg(torch.empty((100), dtype=torch.uint8))


def test_decode_bad_huffman_images():
    # sanity check: make sure we can decode the bad Huffman encoding
    bad_huff = read_file(os.path.join(DAMAGED_JPEG, "bad_huffman.jpg"))
    decode_jpeg(bad_huff)


@pytest.mark.parametrize(
    "img_path",
    [
        pytest.param(truncated_image, id=_get_safe_image_name(truncated_image))
        for truncated_image in glob.glob(os.path.join(DAMAGED_JPEG, "corrupt*.jpg"))
    ],
)
def test_damaged_corrupt_images(img_path):
    # Truncated images should raise an exception
    data = read_file(img_path)
    if "corrupt34" in img_path:
        match_message = "Image is incomplete or truncated"
    else:
        match_message = "Unsupported marker type"
    with pytest.raises(RuntimeError, match=match_message):
        decode_jpeg(data)


@pytest.mark.parametrize(
    "img_path",
    [pytest.param(png_path, id=_get_safe_image_name(png_path)) for png_path in get_images(FAKEDATA_DIR, ".png")],
)
@pytest.mark.parametrize(
    "pil_mode, mode",
    [
        (None, ImageReadMode.UNCHANGED),
        ("L", ImageReadMode.GRAY),
        ("LA", ImageReadMode.GRAY_ALPHA),
        ("RGB", ImageReadMode.RGB),
        ("RGBA", ImageReadMode.RGB_ALPHA),
    ],
)
@pytest.mark.parametrize("scripted", (False, True))
@pytest.mark.parametrize("decode_fun", (decode_png, decode_image))
def test_decode_png(img_path, pil_mode, mode, scripted, decode_fun):

    if scripted:
        decode_fun = torch.jit.script(decode_fun)

    with Image.open(img_path) as img:
        if pil_mode is not None:
            img = img.convert(pil_mode)
        img_pil = torch.from_numpy(np.array(img))

    img_pil = normalize_dimensions(img_pil)

    if img_path.endswith("16.png"):
        # 16-bit image decoding is supported, but only as a private API
        # FIXME: see https://github.com/pytorch/vision/issues/4731 for potential solutions to making it public
        with pytest.raises(RuntimeError, match="At most 8-bit PNG images are supported"):
            data = read_file(img_path)
            img_lpng = decode_fun(data, mode=mode)

        img_lpng = _read_png_16(img_path, mode=mode)
        assert img_lpng.dtype == torch.int32
        # PIL converts 16-bit PNGs to uint8, so scale [0, 65535] down to [0, 255] before comparing
        img_lpng = torch.round(img_lpng / (2**16 - 1) * 255).to(torch.uint8)
    else:
        data = read_file(img_path)
        img_lpng = decode_fun(data, mode=mode)

    tol = 0 if pil_mode is None else 1

    if PILLOW_VERSION >= (8, 3) and pil_mode == "LA":
        # Avoid checking the transparency channel until
        # https://github.com/python-pillow/Pillow/issues/5593#issuecomment-878244910
        # is fixed.
        # TODO: remove once fix is released in PIL. Should be > 8.3.1.
        img_lpng, img_pil = img_lpng[0], img_pil[0]

    torch.testing.assert_close(img_lpng, img_pil, atol=tol, rtol=0)


def test_decode_png_errors():
    with pytest.raises(RuntimeError, match="Expected a non empty 1-dimensional tensor"):
        decode_png(torch.empty((), dtype=torch.uint8))
    with pytest.raises(RuntimeError, match="Content is not png"):
        decode_png(torch.randint(3, 5, (300,), dtype=torch.uint8))
    with pytest.raises(RuntimeError, match="Out of bound read in decode_png"):
        decode_png(read_file(os.path.join(DAMAGED_PNG, "sigsegv.png")))
    with pytest.raises(RuntimeError, match="Content is too small for png"):
        decode_png(read_file(os.path.join(TOOSMALL_PNG, "heapbof.png")))


@pytest.mark.parametrize(
    "img_path",
    [pytest.param(png_path, id=_get_safe_image_name(png_path)) for png_path in get_images(IMAGE_DIR, ".png")],
)
@pytest.mark.parametrize("scripted", (True, False))
def test_encode_png(img_path, scripted):
    pil_image = Image.open(img_path)
    img_pil = torch.from_numpy(np.array(pil_image))
    img_pil = img_pil.permute(2, 0, 1)
    encode = torch.jit.script(encode_png) if scripted else encode_png
    png_buf = encode(img_pil, compression_level=6)

    rec_img = Image.open(io.BytesIO(bytes(png_buf.tolist())))
    rec_img = torch.from_numpy(np.array(rec_img))
    rec_img = rec_img.permute(2, 0, 1)

    assert_equal(img_pil, rec_img)


def test_encode_png_errors():
    with pytest.raises(RuntimeError, match="Input tensor dtype should be uint8"):
        encode_png(torch.empty((3, 100, 100), dtype=torch.float32))

    with pytest.raises(RuntimeError, match="Compression level should be between 0 and 9"):
        encode_png(torch.empty((3, 100, 100), dtype=torch.uint8), compression_level=-1)

    with pytest.raises(RuntimeError, match="Compression level should be between 0 and 9"):
        encode_png(torch.empty((3, 100, 100), dtype=torch.uint8), compression_level=10)

    with pytest.raises(RuntimeError, match="The number of channels should be 1 or 3, got: 5"):
        encode_png(torch.empty((5, 100, 100), dtype=torch.uint8))


@pytest.mark.parametrize(
    "img_path",
    [pytest.param(png_path, id=_get_safe_image_name(png_path)) for png_path in get_images(IMAGE_DIR, ".png")],
)
@pytest.mark.parametrize("scripted", (True, False))
def test_write_png(img_path, tmpdir, scripted):
    pil_image = Image.open(img_path)
    img_pil = torch.from_numpy(np.array(pil_image))
    img_pil = img_pil.permute(2, 0, 1)

    filename, _ = os.path.splitext(os.path.basename(img_path))
    torch_png = os.path.join(tmpdir, f"{filename}_torch.png")
    write = torch.jit.script(write_png) if scripted else write_png
    write(img_pil, torch_png, compression_level=6)
    saved_image = torch.from_numpy(np.array(Image.open(torch_png)))
    saved_image = saved_image.permute(2, 0, 1)

    assert_equal(img_pil, saved_image)


def test_read_image():
    # Just testing torchscript; the functionality is already somewhat tested in other tests.
    path = next(get_images(IMAGE_ROOT, ".jpg"))
    out = read_image(path)
    out_scripted = torch.jit.script(read_image)(path)
    torch.testing.assert_close(out, out_scripted, atol=0, rtol=0)


@pytest.mark.parametrize("scripted", (True, False))
def test_read_file(tmpdir, scripted):
    fname, content = "test1.bin", b"TorchVision\211\n"
    fpath = os.path.join(tmpdir, fname)
    with open(fpath, "wb") as f:
        f.write(content)

    fun = torch.jit.script(read_file) if scripted else read_file
    data = fun(fpath)
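    # read_file returns the raw bytes of the file as a 1-D uint8 tensor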
    expected = torch.tensor(list(content), dtype=torch.uint8)
    os.unlink(fpath)
    assert_equal(data, expected)

    with pytest.raises(RuntimeError, match="No such file or directory: 'tst'"):
        read_file("tst")


def test_read_file_non_ascii(tmpdir):
    fname, content = "日本語(Japanese).bin", b"TorchVision\211\n"
    fpath = os.path.join(tmpdir, fname)
    with open(fpath, "wb") as f:
        f.write(content)

    data = read_file(fpath)
    expected = torch.tensor(list(content), dtype=torch.uint8)
    os.unlink(fpath)
    assert_equal(data, expected)


@pytest.mark.parametrize("scripted", (True, False))
def test_write_file(tmpdir, scripted):
    fname, content = "test1.bin", b"TorchVision\211\n"
    fpath = os.path.join(tmpdir, fname)
    content_tensor = torch.tensor(list(content), dtype=torch.uint8)
    write = torch.jit.script(write_file) if scripted else write_file
    write(fpath, content_tensor)

    with open(fpath, "rb") as f:
        saved_content = f.read()
    os.unlink(fpath)
    assert content == saved_content


def test_write_file_non_ascii(tmpdir):
    fname, content = "日本語(Japanese).bin", b"TorchVision\211\n"
    fpath = os.path.join(tmpdir, fname)
    content_tensor = torch.tensor(list(content), dtype=torch.uint8)
    write_file(fpath, content_tensor)

    with open(fpath, "rb") as f:
        saved_content = f.read()
    os.unlink(fpath)
    assert content == saved_content


@pytest.mark.parametrize(
    "shape",
    [
        (27, 27),
        (60, 60),
        (105, 105),
    ],
)
def test_read_1_bit_png(shape, tmpdir):
    np_rng = np.random.RandomState(0)
    image_path = os.path.join(tmpdir, f"test_{shape}.png")
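    # A boolean array saved through PIL becomes a 1-bit PNG; the decoder expands it back to 0/255 uint8 values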
    pixels = np_rng.rand(*shape) > 0.5
    img = Image.fromarray(pixels)
    img.save(image_path)
    img1 = read_image(image_path)
    img2 = normalize_dimensions(torch.as_tensor(pixels * 255, dtype=torch.uint8))
    assert_equal(img1, img2)


@pytest.mark.parametrize(
    "shape",
    [
        (27, 27),
        (60, 60),
        (105, 105),
    ],
)
@pytest.mark.parametrize(
    "mode",
    [
        ImageReadMode.UNCHANGED,
        ImageReadMode.GRAY,
    ],
)
def test_read_1_bit_png_consistency(shape, mode, tmpdir):
    np_rng = np.random.RandomState(0)
    image_path = os.path.join(tmpdir, f"test_{shape}.png")
    pixels = np_rng.rand(*shape) > 0.5
    img = Image.fromarray(pixels)
    img.save(image_path)
    img1 = read_image(image_path, mode)
    img2 = read_image(image_path, mode)
    assert_equal(img1, img2)


def test_read_interlaced_png():
    imgs = list(get_images(INTERLACED_PNG, ".png"))
    with Image.open(imgs[0]) as im1, Image.open(imgs[1]) as im2:
        assert not (im1.info.get("interlace") is im2.info.get("interlace"))
    img1 = read_image(imgs[0])
    img2 = read_image(imgs[1])
    assert_equal(img1, img2)


@needs_cuda
@pytest.mark.parametrize(
    "img_path",
    [pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) for jpeg_path in get_images(IMAGE_ROOT, ".jpg")],
)
@pytest.mark.parametrize("mode", [ImageReadMode.UNCHANGED, ImageReadMode.GRAY, ImageReadMode.RGB])
@pytest.mark.parametrize("scripted", (False, True))
def test_decode_jpeg_cuda(mode, img_path, scripted):
    if "cmyk" in img_path:
        pytest.xfail("Decoding a CMYK jpeg isn't supported")

    data = read_file(img_path)
    img = decode_image(data, mode=mode)
    f = torch.jit.script(decode_jpeg) if scripted else decode_jpeg
    img_nvjpeg = f(data, mode=mode, device="cuda")

    # Some difference expected between jpeg implementations
    assert (img.float() - img_nvjpeg.cpu().float()).abs().mean() < 2


@needs_cuda
def test_decode_image_cuda_raises():
    data = torch.randint(0, 127, size=(255,), device="cuda", dtype=torch.uint8)
    with pytest.raises(RuntimeError):
        decode_image(data)


@needs_cuda
@pytest.mark.parametrize("cuda_device", ("cuda", "cuda:0", torch.device("cuda")))
def test_decode_jpeg_cuda_device_param(cuda_device):
    """Make sure we can pass a string or a torch.device as device param"""
    path = next(path for path in get_images(IMAGE_ROOT, ".jpg") if "cmyk" not in path)
    data = read_file(path)
    decode_jpeg(data, device=cuda_device)


@needs_cuda
def test_decode_jpeg_cuda_errors():
    data = read_file(next(get_images(IMAGE_ROOT, ".jpg")))
    with pytest.raises(RuntimeError, match="Expected a non empty 1-dimensional tensor"):
        decode_jpeg(data.reshape(-1, 1), device="cuda")
    with pytest.raises(RuntimeError, match="input tensor must be on CPU"):
        decode_jpeg(data.to("cuda"), device="cuda")
    with pytest.raises(RuntimeError, match="Expected a torch.uint8 tensor"):
        decode_jpeg(data.to(torch.float), device="cuda")
    with pytest.raises(RuntimeError, match="Expected a cuda device"):
        torch.ops.image.decode_jpeg_cuda(data, ImageReadMode.UNCHANGED.value, "cpu")


def test_encode_jpeg_errors():

    with pytest.raises(RuntimeError, match="Input tensor dtype should be uint8"):
        encode_jpeg(torch.empty((3, 100, 100), dtype=torch.float32))

    with pytest.raises(ValueError, match="Image quality should be a positive number between 1 and 100"):
        encode_jpeg(torch.empty((3, 100, 100), dtype=torch.uint8), quality=-1)

    with pytest.raises(ValueError, match="Image quality should be a positive number between 1 and 100"):
        encode_jpeg(torch.empty((3, 100, 100), dtype=torch.uint8), quality=101)

    with pytest.raises(RuntimeError, match="The number of channels should be 1 or 3, got: 5"):
        encode_jpeg(torch.empty((5, 100, 100), dtype=torch.uint8))

    with pytest.raises(RuntimeError, match="Input data should be a 3-dimensional tensor"):
        encode_jpeg(torch.empty((1, 3, 100, 100), dtype=torch.uint8))

    with pytest.raises(RuntimeError, match="Input data should be a 3-dimensional tensor"):
        encode_jpeg(torch.empty((100, 100), dtype=torch.uint8))


@pytest.mark.skipif(IS_MACOS, reason="https://github.com/pytorch/vision/issues/8031")
@pytest.mark.parametrize(
    "img_path",
    [pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) for jpeg_path in get_images(ENCODE_JPEG, ".jpg")],
)
@pytest.mark.parametrize("scripted", (True, False))
def test_encode_jpeg(img_path, scripted):
    img = read_image(img_path)

    pil_img = F.to_pil_image(img)
    buf = io.BytesIO()
    pil_img.save(buf, format="JPEG", quality=75)

    encoded_jpeg_pil = torch.frombuffer(buf.getvalue(), dtype=torch.uint8)

    encode = torch.jit.script(encode_jpeg) if scripted else encode_jpeg
    for src_img in [img, img.contiguous()]:
        encoded_jpeg_torch = encode(src_img, quality=75)
        assert_equal(encoded_jpeg_torch, encoded_jpeg_pil)


@pytest.mark.skipif(IS_MACOS, reason="https://github.com/pytorch/vision/issues/8031")
@pytest.mark.parametrize(
    "img_path",
    [pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) for jpeg_path in get_images(ENCODE_JPEG, ".jpg")],
)
@pytest.mark.parametrize("scripted", (True, False))
def test_write_jpeg(img_path, tmpdir, scripted):
    tmpdir = Path(tmpdir)
    img = read_image(img_path)
    pil_img = F.to_pil_image(img)

    torch_jpeg = str(tmpdir / "torch.jpg")
    pil_jpeg = str(tmpdir / "pil.jpg")

    write = torch.jit.script(write_jpeg) if scripted else write_jpeg
    write(img, torch_jpeg, quality=75)
    pil_img.save(pil_jpeg, quality=75)

    with open(torch_jpeg, "rb") as f:
        torch_bytes = f.read()

    with open(pil_jpeg, "rb") as f:
        pil_bytes = f.read()

    assert_equal(torch_bytes, pil_bytes)


def test_pathlib_support(tmpdir):
    # Just make sure pathlib.Path is supported where relevant

    jpeg_path = Path(next(get_images(ENCODE_JPEG, ".jpg")))

    read_file(jpeg_path)
    read_image(jpeg_path)

    write_path = Path(tmpdir) / "whatever"
    img = torch.randint(0, 10, size=(3, 4, 4), dtype=torch.uint8)

    write_file(write_path, data=img.flatten())
    write_jpeg(img, write_path)
    write_png(img, write_path)


@pytest.mark.parametrize(
    "name", ("gifgrid", "fire", "porsche", "treescap", "treescap-interlaced", "solid2", "x-trans", "earth")
)
@pytest.mark.parametrize("scripted", (True, False))
def test_decode_gif(tmpdir, name, scripted):
    # Using test images from GIFLIB
    # https://sourceforge.net/p/giflib/code/ci/master/tree/pic/, we assert PIL
    # and torchvision decoded outputs are equal.
    # We're not testing against "welcome2" because PIL and GIFLIB disagree on what
    # the background color should be (likely a difference in the way they handle
    # transparency?)
    # 'earth' image is from wikipedia, licensed under CC BY-SA 3.0
    # https://creativecommons.org/licenses/by-sa/3.0/
    # it allows us to properly test for transparency, TOP-LEFT offsets, and
    # disposal modes.

    path = tmpdir / f"{name}.gif"
    if name == "earth":
        url = "https://upload.wikimedia.org/wikipedia/commons/2/2c/Rotating_earth_%28large%29.gif"
    else:
        url = f"https://sourceforge.net/p/giflib/code/ci/master/tree/pic/{name}.gif?format=raw"
    with open(path, "wb") as f:
        f.write(requests.get(url).content)

    encoded_bytes = read_file(path)
    f = torch.jit.script(decode_gif) if scripted else decode_gif
    tv_out = f(encoded_bytes)
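    # decode_gif can return a CHW tensor (single-frame GIF) or NCHW (animated); normalize to NCHW so the per-frame comparison below works either way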
    if tv_out.ndim == 3:
        tv_out = tv_out[None]

    assert tv_out.is_contiguous(memory_format=torch.channels_last)

    # For some reason, not using Image.open() as a CM causes "ResourceWarning: unclosed file"
    with Image.open(path) as pil_img:
        pil_seq = ImageSequence.Iterator(pil_img)

        for pil_frame, tv_frame in zip(pil_seq, tv_out):
            pil_frame = F.pil_to_tensor(pil_frame.convert("RGB"))
            torch.testing.assert_close(tv_frame, pil_frame, atol=0, rtol=0)


def test_decode_gif_errors():
    encoded_data = torch.randint(0, 256, (100,), dtype=torch.uint8)
    with pytest.raises(RuntimeError, match="Input tensor must be 1-dimensional"):
        decode_gif(encoded_data[None])
    with pytest.raises(RuntimeError, match="Input tensor must have uint8 data type"):
        decode_gif(encoded_data.float())
    with pytest.raises(RuntimeError, match="Input tensor must be contiguous"):
        decode_gif(encoded_data[::2])
    with pytest.raises(RuntimeError, match=re.escape("DGifOpenFileName() failed - 103")):
        decode_gif(encoded_data)


if __name__ == "__main__":
    pytest.main([__file__])