test_image.py 30.1 KB
Newer Older
1
import concurrent.futures
2
import glob
3
4
import io
import os
Nicolas Hug's avatar
Nicolas Hug committed
5
import re
6
7
import sys
from pathlib import Path
8

9
import numpy as np
10
import pytest
Nicolas Hug's avatar
Nicolas Hug committed
11
import requests
12
import torch
13
import torchvision.transforms.functional as F
14
from common_utils import assert_equal, cpu_and_cuda, IN_OSS_CI, needs_cuda
Nicolas Hug's avatar
Nicolas Hug committed
15
from PIL import __version__ as PILLOW_VERSION, Image, ImageOps, ImageSequence
16
from torchvision.io.image import (
17
    _read_png_16,
Nicolas Hug's avatar
Nicolas Hug committed
18
    decode_gif,
19
    decode_image,
20
    decode_jpeg,
21
    decode_png,
22
23
24
    encode_jpeg,
    encode_png,
    ImageReadMode,
25
    read_file,
26
    read_image,
27
28
29
    write_file,
    write_jpeg,
    write_png,
30
)
Francisco Massa's avatar
Francisco Massa committed
31

32
# Locations of the test assets, all resolved relative to this test file.
IMAGE_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets")
FAKEDATA_DIR = os.path.join(IMAGE_ROOT, "fakedata")
IMAGE_DIR = os.path.join(FAKEDATA_DIR, "imagefolder")
DAMAGED_JPEG = os.path.join(IMAGE_ROOT, "damaged_jpeg")
DAMAGED_PNG = os.path.join(IMAGE_ROOT, "damaged_png")
ENCODE_JPEG = os.path.join(IMAGE_ROOT, "encode_jpeg")
INTERLACED_PNG = os.path.join(IMAGE_ROOT, "interlaced_png")
TOOSMALL_PNG = os.path.join(IMAGE_ROOT, "toosmall_png")

# Platform flags used to skip or adapt platform-specific tests.
IS_WINDOWS = sys.platform in ("win32", "cygwin")
IS_MACOS = sys.platform == "darwin"

# Pillow version as a tuple of ints so it can be compared against e.g. ``(8, 3)``.
# Only the leading dotted-numeric part is parsed, so that pre-release / dev
# versions such as "10.1.0.dev0" or "9.1.0rc1" do not raise ValueError.
PILLOW_VERSION = tuple(int(part) for part in re.match(r"\d+(?:\.\d+)*", PILLOW_VERSION).group().split("."))
43
44
45
46
47
48
49
50


def _get_safe_image_name(name):
    """Return the last ``os.path.sep``-separated component of ``name``.

    Used as the pytest "id" of an image-path parameter. Without it the test id
    (i.e. the test name) would embed the full, machine-specific path, which
    causes problems when tests are collected on one machine and run on another
    (typically, in fb internal infra).
    """
    _, _, last_component = name.rpartition(os.path.sep)
    return last_component
51
52
53
54


def get_images(directory, img_ext):
    """Yield the paths of all files ending in ``img_ext`` found under ``directory``.

    The search is recursive. Files whose immediate parent directory is named
    ``damaged_jpeg`` or ``jpeg_write`` are skipped.

    Args:
        directory: Root directory to search; must be an existing directory.
        img_ext: File-name suffix to match, e.g. ``".jpg"``.

    Yields:
        Matching paths in sorted order.

    Raises:
        AssertionError: If ``directory`` is not a directory. Because this is a
            generator, the check only runs when iteration starts.
    """
    assert os.path.isdir(directory)
    excluded_dirs = ("damaged_jpeg", "jpeg_write")
    # glob returns entries in filesystem order; sort so that parametrized test
    # ids and ``next(get_images(...))`` are deterministic across machines/workers.
    for path in sorted(glob.glob(directory + f"/**/*{img_ext}", recursive=True)):
        # basename(dirname(...)) handles both separators on platforms that allow them.
        if os.path.basename(os.path.dirname(path)) not in excluded_dirs:
            yield path
59
60


61
62
63
64
65
66
67
68
69
70
71
72
73
def pil_read_image(img_path):
    """Open ``img_path`` with Pillow and return its pixels as a tensor.

    The image is converted to a NumPy array while the file is still open, and
    the tensor is built from that array.
    """
    with Image.open(img_path) as pil_img:
        pixels = np.array(pil_img)
    return torch.from_numpy(pixels)


def normalize_dimensions(img_pil):
    """Return ``img_pil`` with the channel axis first.

    A 3-dimensional input has its last axis moved to the front
    (``permute(2, 0, 1)``); any other input gets a new leading axis of size 1.
    """
    if img_pil.dim() != 3:
        return img_pil.unsqueeze(0)
    return img_pil.permute(2, 0, 1)


74
75
76
77
78
79
80
81
82
83
84
85
@pytest.mark.parametrize(
    "img_path",
    [pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) for jpeg_path in get_images(IMAGE_ROOT, ".jpg")],
)
@pytest.mark.parametrize(
    "pil_mode, mode",
    [
        (None, ImageReadMode.UNCHANGED),
        ("L", ImageReadMode.GRAY),
        ("RGB", ImageReadMode.RGB),
    ],
)
@pytest.mark.parametrize("scripted", (False, True))
@pytest.mark.parametrize("decode_fun", (decode_jpeg, decode_image))
def test_decode_jpeg(img_path, pil_mode, mode, scripted, decode_fun):
    """Check that JPEG decoding (eager or scripted) stays close to Pillow's output."""
    with Image.open(img_path) as img:
        is_cmyk = img.mode == "CMYK"
        # Keep the context-managed image bound to ``img`` so it is the one that gets closed.
        converted = img if pil_mode is None else img.convert(pil_mode)
        img_pil = torch.from_numpy(np.array(converted))
        if is_cmyk and mode == ImageReadMode.UNCHANGED:
            # flip the colors to match libjpeg
            img_pil = 255 - img_pil

    img_pil = normalize_dimensions(img_pil)
    data = read_file(img_path)
    if scripted:
        decode_fun = torch.jit.script(decode_fun)
    img_ljpeg = decode_fun(data, mode=mode)

    # Permit a small variation on pixel values to account for implementation
    # differences between Pillow and LibJPEG.
    abs_mean_diff = (img_ljpeg.type(torch.float32) - img_pil).abs().mean().item()
    assert abs_mean_diff < 2


111
@pytest.mark.parametrize("codec", ["png", "jpeg"])
@pytest.mark.parametrize("orientation", [1, 2, 3, 4, 5, 6, 7, 8, 0])
def test_decode_with_exif_orientation(tmpdir, codec, orientation):
    """Check ``decode_image(..., apply_exif_orientation=True)`` against ``ImageOps.exif_transpose``."""
    fp = os.path.join(tmpdir, f"exif_oriented_{orientation}.{codec}")
    t = torch.randint(0, 256, size=(3, 256, 257), dtype=torch.uint8)
    im = F.to_pil_image(t)
    exif = im.getexif()
    exif[0x0112] = orientation  # set exif orientation
    im.save(fp, codec.upper(), exif=exif.tobytes())

    data = read_file(fp)
    output = decode_image(data, apply_exif_orientation=True)

    # Open as a context manager so the file handle is closed, and build the
    # expected tensor while the image is still open.
    with Image.open(fp) as pimg:
        expected = F.pil_to_tensor(ImageOps.exif_transpose(pimg))
    torch.testing.assert_close(expected, output)


@pytest.mark.parametrize("size", [65533, 1, 7, 10, 23, 33])
def test_invalid_exif(tmpdir, size):
    """Check that a JPEG carrying a garbage EXIF payload still decodes like Pillow does."""
    # Inspired from a PIL test:
    # https://github.com/python-pillow/Pillow/blob/8f63748e50378424628155994efd7e0739a4d1d1/Tests/test_file_jpeg.py#L299
    fp = os.path.join(tmpdir, "invalid_exif.jpg")
    t = torch.randint(0, 256, size=(3, 256, 257), dtype=torch.uint8)
    im = F.to_pil_image(t)
    im.save(fp, "JPEG", exif=b"1" * size)

    data = read_file(fp)
    output = decode_image(data, apply_exif_orientation=True)

    # Open as a context manager so the file handle is closed, and build the
    # expected tensor while the image is still open.
    with Image.open(fp) as pimg:
        expected = F.pil_to_tensor(ImageOps.exif_transpose(pimg))
    torch.testing.assert_close(expected, output)


150
151
152
153
154
155
156
157
158
159
160
161
162
def test_decode_jpeg_errors():
    """Check the error messages raised by ``decode_jpeg`` for invalid inputs."""
    with pytest.raises(RuntimeError, match="Expected a non empty 1-dimensional tensor"):
        decode_jpeg(torch.empty((100, 1), dtype=torch.uint8))

    with pytest.raises(RuntimeError, match="Expected a torch.uint8 tensor"):
        decode_jpeg(torch.empty((100,), dtype=torch.float16))

    # Use zeros rather than uninitialised memory: here the byte content is what
    # is being rejected, so it must be deterministic.
    with pytest.raises(RuntimeError, match="Not a JPEG file"):
        decode_jpeg(torch.zeros((100,), dtype=torch.uint8))


def test_decode_bad_huffman_images():
    """Smoke test: decoding ``bad_huffman.jpg`` must not raise."""
    # sanity check: make sure we can decode the bad Huffman encoding
    bad_huff = read_file(os.path.join(DAMAGED_JPEG, "bad_huffman.jpg"))
    decode_jpeg(bad_huff)


167
168
169
170
171
172
173
@pytest.mark.parametrize(
    "img_path",
    [
        pytest.param(truncated_image, id=_get_safe_image_name(truncated_image))
        # Sorted so that the collection order does not depend on the filesystem.
        for truncated_image in sorted(glob.glob(os.path.join(DAMAGED_JPEG, "corrupt*.jpg")))
    ],
)
def test_damaged_corrupt_images(img_path):
    """Check that decoding each ``corrupt*.jpg`` asset raises the expected error."""
    # Truncated images should raise an exception
    data = read_file(img_path)
    # Look at the file name only, so that a directory name can never influence the choice.
    if "corrupt34" in os.path.basename(img_path):
        match_message = "Image is incomplete or truncated"
    else:
        match_message = "Unsupported marker type"
    with pytest.raises(RuntimeError, match=match_message):
        decode_jpeg(data)


185
186
187
188
189
190
191
192
193
194
195
196
197
198
@pytest.mark.parametrize(
    "img_path",
    [pytest.param(png_path, id=_get_safe_image_name(png_path)) for png_path in get_images(FAKEDATA_DIR, ".png")],
)
@pytest.mark.parametrize(
    "pil_mode, mode",
    [
        (None, ImageReadMode.UNCHANGED),
        ("L", ImageReadMode.GRAY),
        ("LA", ImageReadMode.GRAY_ALPHA),
        ("RGB", ImageReadMode.RGB),
        ("RGBA", ImageReadMode.RGB_ALPHA),
    ],
)
@pytest.mark.parametrize("scripted", (False, True))
@pytest.mark.parametrize("decode_fun", (decode_png, decode_image))
def test_decode_png(img_path, pil_mode, mode, scripted, decode_fun):
    """Check that PNG decoding (eager or scripted) matches Pillow within a small tolerance."""
    if scripted:
        decode_fun = torch.jit.script(decode_fun)

    with Image.open(img_path) as img:
        # Keep the context-managed image bound to ``img`` so it is the one that gets closed.
        converted = img if pil_mode is None else img.convert(pil_mode)
        img_pil = torch.from_numpy(np.array(converted))

    img_pil = normalize_dimensions(img_pil)
    # Read once, outside of ``pytest.raises``, so that only the decoding call is
    # expected to raise in the 16-bit case.
    data = read_file(img_path)

    if img_path.endswith("16.png"):
        # 16 bits image decoding is supported, but only as a private API
        # FIXME: see https://github.com/pytorch/vision/issues/4731 for potential solutions to making it public
        with pytest.raises(RuntimeError, match="At most 8-bit PNG images are supported"):
            decode_fun(data, mode=mode)

        img_lpng = _read_png_16(img_path, mode=mode)
        assert img_lpng.dtype == torch.int32
        # PIL converts 16 bits pngs in uint8
        img_lpng = torch.round(img_lpng / (2**16 - 1) * 255).to(torch.uint8)
    else:
        img_lpng = decode_fun(data, mode=mode)

    # Exact match when Pillow did no mode conversion, otherwise allow an off-by-one.
    tol = 0 if pil_mode is None else 1

    if PILLOW_VERSION >= (8, 3) and pil_mode == "LA":
        # Avoid checking the transparency channel until
        # https://github.com/python-pillow/Pillow/issues/5593#issuecomment-878244910
        # is fixed.
        # TODO: remove once fix is released in PIL. Should be > 8.3.1.
        img_lpng, img_pil = img_lpng[0], img_pil[0]

    torch.testing.assert_close(img_lpng, img_pil, atol=tol, rtol=0)
238
239
240
241
242
243
244


def test_decode_png_errors():
    """Check the error messages raised by ``decode_png`` for invalid or damaged inputs."""
    with pytest.raises(RuntimeError, match="Expected a non empty 1-dimensional tensor"):
        decode_png(torch.empty((), dtype=torch.uint8))
    with pytest.raises(RuntimeError, match="Content is not png"):
        decode_png(torch.randint(3, 5, (300,), dtype=torch.uint8))
    with pytest.raises(RuntimeError, match="Out of bound read in decode_png"):
        decode_png(read_file(os.path.join(DAMAGED_PNG, "sigsegv.png")))
    with pytest.raises(RuntimeError, match="Content is too small for png"):
        decode_png(read_file(os.path.join(TOOSMALL_PNG, "heapbof.png")))
249
250


251
252
253
254
@pytest.mark.parametrize(
    "img_path",
    [pytest.param(png_path, id=_get_safe_image_name(png_path)) for png_path in get_images(IMAGE_DIR, ".png")],
)
@pytest.mark.parametrize("scripted", (True, False))
def test_encode_png(img_path, scripted):
    """Check that ``encode_png`` output, re-read with Pillow, equals the source pixels."""
    # pil_read_image opens the file in a ``with`` block, so no handle is leaked.
    img_pil = pil_read_image(img_path).permute(2, 0, 1)

    encode = torch.jit.script(encode_png) if scripted else encode_png
    png_buf = encode(img_pil, compression_level=6)

    rec_img = pil_read_image(io.BytesIO(bytes(png_buf.tolist()))).permute(2, 0, 1)

    assert_equal(img_pil, rec_img)


def test_encode_png_errors():
    """Check the error messages raised by ``encode_png`` for invalid inputs."""
    with pytest.raises(RuntimeError, match="Input tensor dtype should be uint8"):
        encode_png(torch.empty((3, 100, 100), dtype=torch.float32))

    with pytest.raises(RuntimeError, match="Compression level should be between 0 and 9"):
        encode_png(torch.empty((3, 100, 100), dtype=torch.uint8), compression_level=-1)

    with pytest.raises(RuntimeError, match="Compression level should be between 0 and 9"):
        encode_png(torch.empty((3, 100, 100), dtype=torch.uint8), compression_level=10)

    with pytest.raises(RuntimeError, match="The number of channels should be 1 or 3, got: 5"):
        encode_png(torch.empty((5, 100, 100), dtype=torch.uint8))


284
285
286
287
@pytest.mark.parametrize(
    "img_path",
    [pytest.param(png_path, id=_get_safe_image_name(png_path)) for png_path in get_images(IMAGE_DIR, ".png")],
)
@pytest.mark.parametrize("scripted", (True, False))
def test_write_png(img_path, tmpdir, scripted):
    """Check that a file written by ``write_png``, re-read with Pillow, equals the source pixels."""
    # pil_read_image opens the file in a ``with`` block, so no handle is leaked.
    img_pil = pil_read_image(img_path).permute(2, 0, 1)

    filename, _ = os.path.splitext(os.path.basename(img_path))
    torch_png = os.path.join(tmpdir, f"{filename}_torch.png")
    write = torch.jit.script(write_png) if scripted else write_png
    write(img_pil, torch_png, compression_level=6)
    saved_image = pil_read_image(torch_png).permute(2, 0, 1)

    assert_equal(img_pil, saved_image)
302
303


304
305
306
307
308
309
310
311
312
313
def test_read_image():
    """Check that scripted ``read_image`` returns exactly what the eager version returns.

    Only TorchScript support is exercised here; the decoding itself is covered
    by other tests.
    """
    first_jpeg = next(get_images(IMAGE_ROOT, ".jpg"))
    eager_out = read_image(first_jpeg)
    scripted_read = torch.jit.script(read_image)
    scripted_out = scripted_read(first_jpeg)
    torch.testing.assert_close(eager_out, scripted_out, atol=0, rtol=0)


@pytest.mark.parametrize("scripted", (True, False))
def test_read_file(tmpdir, scripted):
    """Check that ``read_file`` returns the file's bytes as uint8 and raises on a missing file."""
    fname, content = "test1.bin", b"TorchVision\211\n"
    fpath = os.path.join(tmpdir, fname)
    with open(fpath, "wb") as f:
        f.write(content)

    fun = torch.jit.script(read_file) if scripted else read_file
    data = fun(fpath)
    expected = torch.tensor(list(content), dtype=torch.uint8)
    os.unlink(fpath)
    assert_equal(data, expected)

    with pytest.raises(RuntimeError, match="No such file or directory: 'tst'"):
        read_file("tst")
327
328


329
def test_read_file_non_ascii(tmpdir):
    """Check that ``read_file`` works when the file name contains non-ASCII characters."""
    fname, content = "日本語(Japanese).bin", b"TorchVision\211\n"
    fpath = os.path.join(tmpdir, fname)
    with open(fpath, "wb") as f:
        f.write(content)

    data = read_file(fpath)
    expected = torch.tensor(list(content), dtype=torch.uint8)
    os.unlink(fpath)
    assert_equal(data, expected)
339
340


341
342
@pytest.mark.parametrize("scripted", (True, False))
def test_write_file(tmpdir, scripted):
    """Check that ``write_file`` (eager or scripted) writes the tensor's bytes unchanged."""
    fname, content = "test1.bin", b"TorchVision\211\n"
    fpath = os.path.join(tmpdir, fname)
    content_tensor = torch.tensor(list(content), dtype=torch.uint8)
    write = torch.jit.script(write_file) if scripted else write_file
    write(fpath, content_tensor)

    with open(fpath, "rb") as f:
        saved_content = f.read()
    os.unlink(fpath)
    assert content == saved_content
353
354


355
def test_write_file_non_ascii(tmpdir):
    """Check that ``write_file`` works when the file name contains non-ASCII characters."""
    fname, content = "日本語(Japanese).bin", b"TorchVision\211\n"
    fpath = os.path.join(tmpdir, fname)
    content_tensor = torch.tensor(list(content), dtype=torch.uint8)
    write_file(fpath, content_tensor)

    with open(fpath, "rb") as f:
        saved_content = f.read()
    os.unlink(fpath)
    assert content == saved_content
365

366

367
368
369
370
371
372
373
374
@pytest.mark.parametrize(
    "shape",
    [
        (27, 27),
        (60, 60),
        (105, 105),
    ],
)
def test_read_1_bit_png(shape, tmpdir):
    """Check that a boolean image saved by Pillow is read back as 0/255 uint8 values."""
    np_rng = np.random.RandomState(0)
    image_path = os.path.join(tmpdir, f"test_{shape}.png")
    pixels = np_rng.rand(*shape) > 0.5
    img = Image.fromarray(pixels)
    img.save(image_path)
    img1 = read_image(image_path)
    # Expected value: True -> 255, False -> 0, with a leading axis of size 1.
    img2 = normalize_dimensions(torch.as_tensor(pixels * 255, dtype=torch.uint8))
    assert_equal(img1, img2)
Prabhat Roy's avatar
Prabhat Roy committed
384
385


386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
@pytest.mark.parametrize(
    "shape",
    [
        (27, 27),
        (60, 60),
        (105, 105),
    ],
)
@pytest.mark.parametrize(
    "mode",
    [
        ImageReadMode.UNCHANGED,
        ImageReadMode.GRAY,
    ],
)
def test_read_1_bit_png_consistency(shape, mode, tmpdir):
    """Check that reading the same boolean PNG twice with the same mode gives equal results."""
    np_rng = np.random.RandomState(0)
    image_path = os.path.join(tmpdir, f"test_{shape}.png")
    pixels = np_rng.rand(*shape) > 0.5
    img = Image.fromarray(pixels)
    img.save(image_path)
    img1 = read_image(image_path, mode)
    img2 = read_image(image_path, mode)
    assert_equal(img1, img2)
Prabhat Roy's avatar
Prabhat Roy committed
410
411


412
413
414
415
416
417
418
419
420
def test_read_interlaced_png():
    """Check that two PNGs differing in their ``interlace`` setting decode to equal tensors."""
    imgs = list(get_images(INTERLACED_PNG, ".png"))
    with Image.open(imgs[0]) as im1, Image.open(imgs[1]) as im2:
        # Compare by value: identity (``is``) only happens to work for small cached ints.
        assert im1.info.get("interlace") != im2.info.get("interlace")
    img1 = read_image(imgs[0])
    img2 = read_image(imgs[1])
    assert_equal(img1, img2)


421
@needs_cuda
@pytest.mark.parametrize(
    "img_path",
    [pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) for jpeg_path in get_images(IMAGE_ROOT, ".jpg")],
)
@pytest.mark.parametrize("mode", [ImageReadMode.UNCHANGED, ImageReadMode.GRAY, ImageReadMode.RGB])
@pytest.mark.parametrize("scripted", (False, True))
def test_decode_jpeg_cuda(mode, img_path, scripted):
    """Check that ``decode_jpeg(..., device="cuda")`` stays close to the ``decode_image`` result."""
    if "cmyk" in img_path:
        pytest.xfail("Decoding a CMYK jpeg isn't supported")

    data = read_file(img_path)
    img = decode_image(data, mode=mode)
    f = torch.jit.script(decode_jpeg) if scripted else decode_jpeg
    img_nvjpeg = f(data, mode=mode, device="cuda")

    # Some difference expected between jpeg implementations
    assert (img.float() - img_nvjpeg.cpu().float()).abs().mean() < 2
439

Nicolas Hug's avatar
Nicolas Hug committed
440

441
442
443
444
445
446
@needs_cuda
def test_decode_image_cuda_raises():
    """``decode_image`` must raise ``RuntimeError`` when given a tensor that lives on CUDA."""
    cuda_bytes = torch.randint(0, 127, size=(255,), device="cuda", dtype=torch.uint8)
    with pytest.raises(RuntimeError):
        decode_image(cuda_bytes)

447
448

@needs_cuda
@pytest.mark.parametrize("cuda_device", ("cuda", "cuda:0", torch.device("cuda")))
def test_decode_jpeg_cuda_device_param(cuda_device):
    """Make sure we can pass a string or a torch.device as device param"""
    # Use the first non-CMYK JPEG found under the assets directory.
    path = next(path for path in get_images(IMAGE_ROOT, ".jpg") if "cmyk" not in path)
    data = read_file(path)
    decode_jpeg(data, device=cuda_device)


@needs_cuda
def test_decode_jpeg_cuda_errors():
    """Check the error messages raised by CUDA JPEG decoding for invalid inputs."""
    data = read_file(next(get_images(IMAGE_ROOT, ".jpg")))
    with pytest.raises(RuntimeError, match="Expected a non empty 1-dimensional tensor"):
        decode_jpeg(data.reshape(-1, 1), device="cuda")
    with pytest.raises(RuntimeError, match="input tensor must be on CPU"):
        decode_jpeg(data.to("cuda"), device="cuda")
    with pytest.raises(RuntimeError, match="Expected a torch.uint8 tensor"):
        decode_jpeg(data.to(torch.float), device="cuda")
    with pytest.raises(RuntimeError, match="Expected a cuda device"):
        # Call the op directly with a CPU device to trigger its device check.
        torch.ops.image.decode_jpeg_cuda(data, ImageReadMode.UNCHANGED.value, "cpu")
468
469


470
471
472
473
474
def test_encode_jpeg_errors():
    """Check the exceptions raised by ``encode_jpeg`` for invalid dtype, quality and shape."""
    with pytest.raises(RuntimeError, match="Input tensor dtype should be uint8"):
        encode_jpeg(torch.empty((3, 100, 100), dtype=torch.float32))

    with pytest.raises(ValueError, match="Image quality should be a positive number between 1 and 100"):
        encode_jpeg(torch.empty((3, 100, 100), dtype=torch.uint8), quality=-1)

    with pytest.raises(ValueError, match="Image quality should be a positive number between 1 and 100"):
        encode_jpeg(torch.empty((3, 100, 100), dtype=torch.uint8), quality=101)

    with pytest.raises(RuntimeError, match="The number of channels should be 1 or 3, got: 5"):
        encode_jpeg(torch.empty((5, 100, 100), dtype=torch.uint8))

    with pytest.raises(RuntimeError, match="Input data should be a 3-dimensional tensor"):
        encode_jpeg(torch.empty((1, 3, 100, 100), dtype=torch.uint8))

    with pytest.raises(RuntimeError, match="Input data should be a 3-dimensional tensor"):
        encode_jpeg(torch.empty((100, 100), dtype=torch.uint8))


491
@pytest.mark.skipif(IS_MACOS, reason="https://github.com/pytorch/vision/issues/8031")
@pytest.mark.parametrize(
    "img_path",
    [pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) for jpeg_path in get_images(ENCODE_JPEG, ".jpg")],
)
@pytest.mark.parametrize("scripted", (True, False))
def test_encode_jpeg(img_path, scripted):
    """Check that ``encode_jpeg`` produces the same bytes as Pillow at quality 75."""
    img = read_image(img_path)

    pil_img = F.to_pil_image(img)
    buf = io.BytesIO()
    pil_img.save(buf, format="JPEG", quality=75)

    # Copy into a bytearray: torch.frombuffer warns when handed a read-only buffer.
    encoded_jpeg_pil = torch.frombuffer(bytearray(buf.getvalue()), dtype=torch.uint8)

    encode = torch.jit.script(encode_jpeg) if scripted else encode_jpeg
    for src_img in [img, img.contiguous()]:
        encoded_jpeg_torch = encode(src_img, quality=75)
        assert_equal(encoded_jpeg_torch, encoded_jpeg_pil)


512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
@needs_cuda
def test_encode_jpeg_cuda_device_param():
    """Check that ``encode_jpeg`` gives the same result for every way of naming a CUDA device.

    Also checks that the current CUDA device and stream are the same after the
    calls as before them.
    """
    path = next(path for path in get_images(IMAGE_ROOT, ".jpg") if "cmyk" not in path)

    data = read_image(path)

    current_device = torch.cuda.current_device()
    current_stream = torch.cuda.current_stream()
    num_devices = torch.cuda.device_count()
    devices = ["cuda", torch.device("cuda")] + [torch.device(f"cuda:{i}") for i in range(num_devices)]
    results = [encode_jpeg(data.to(device=device)) for device in devices]
    assert len(results) == len(devices)
    reference = results[0].cpu()
    for result in results:
        assert torch.all(result.cpu() == reference)

    assert current_device == torch.cuda.current_device()
    assert current_stream == torch.cuda.current_stream()


@needs_cuda
@pytest.mark.parametrize(
    "img_path",
    [pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) for jpeg_path in get_images(IMAGE_ROOT, ".jpg")],
)
@pytest.mark.parametrize("scripted", (False, True))
@pytest.mark.parametrize("contiguous", (False, True))
def test_encode_jpeg_cuda(img_path, scripted, contiguous):
    """Check that a CUDA-encoded JPEG, once decoded again, stays close to the source image."""
    decoded_image_tv = read_image(img_path)
    encode_fn = torch.jit.script(encode_jpeg) if scripted else encode_jpeg

    if "cmyk" in img_path:
        pytest.xfail("Encoding a CMYK jpeg isn't supported")
    if decoded_image_tv.shape[0] == 1:
        # For more detail as to why check out: https://github.com/NVIDIA/cuda-samples/issues/23#issuecomment-559283013
        pytest.xfail("Encoding a grayscale jpeg isn't supported")

    # Go through a 4-D view to set the memory format, then drop the batch axis again.
    memory_format = torch.contiguous_format if contiguous else torch.channels_last
    decoded_image_tv = decoded_image_tv[None].contiguous(memory_format=memory_format)[0]

    encoded_jpeg_cuda_tv = encode_fn(decoded_image_tv.cuda(), quality=75)
    decoded_jpeg_cuda_tv = decode_jpeg(encoded_jpeg_cuda_tv.cpu())

    # the actual encoded bytestreams from libnvjpeg and libjpeg-turbo differ for the same quality
    # instead, we re-decode the encoded image and compare to the original
    abs_mean_diff = (decoded_jpeg_cuda_tv.float() - decoded_image_tv.float()).abs().mean().item()
    assert abs_mean_diff < 3


@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("scripted", (True, False))
@pytest.mark.parametrize("contiguous", (True, False))
def test_encode_jpegs_batch(scripted, contiguous, device):
    """Check batched ``encode_jpeg`` (a list of images), including calls from several threads.

    Every encoded image is decoded again and compared to its source with a
    mean-absolute-difference threshold.
    """
    if device == "cpu" and IS_MACOS:
        pytest.skip("https://github.com/pytorch/vision/issues/8031")

    memory_format = torch.contiguous_format if contiguous else torch.channels_last
    decoded_images_tv = []
    for jpeg_path in get_images(IMAGE_ROOT, ".jpg"):
        if "cmyk" in jpeg_path:
            continue
        decoded_image = read_image(jpeg_path)
        if decoded_image.shape[0] == 1:
            continue
        # Go through a 4-D view to set the memory format, then drop the batch axis again.
        decoded_images_tv.append(decoded_image[None].contiguous(memory_format=memory_format)[0])

    encode_fn = torch.jit.script(encode_jpeg) if scripted else encode_jpeg

    decoded_images_tv_device = [img.to(device=device) for img in decoded_images_tv]
    encoded_jpegs_tv_device = encode_fn(decoded_images_tv_device, quality=75)
    encoded_jpegs_tv_device = [decode_jpeg(img.cpu()) for img in encoded_jpegs_tv_device]

    for original, encoded_decoded in zip(decoded_images_tv, encoded_jpegs_tv_device):
        abs_mean_diff = (original.float() - encoded_decoded.float()).abs().mean().item()
        assert abs_mean_diff < 3

    # test multithreaded encoding
    # in the current version we prevent this by using a lock but we still want to test it
    num_workers = 10
    with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
        futures = [executor.submit(encode_fn, decoded_images_tv_device) for _ in range(num_workers)]
    encoded_images_threaded = [future.result() for future in futures]
    assert len(encoded_images_threaded) == num_workers
    for encoded_images in encoded_images_threaded:
        assert len(decoded_images_tv_device) == len(encoded_images)
        for i, (encoded_image_cuda, decoded_image_tv) in enumerate(zip(encoded_images, decoded_images_tv_device)):
            # make sure all the threads produce identical outputs
            assert torch.all(encoded_image_cuda == encoded_images_threaded[0][i])

            # make sure the outputs are identical or close enough to baseline
            decoded_cuda_encoded_image = decode_jpeg(encoded_image_cuda.cpu())
            assert decoded_cuda_encoded_image.shape == decoded_image_tv.shape
            assert decoded_cuda_encoded_image.dtype == decoded_image_tv.dtype
            assert (decoded_cuda_encoded_image.cpu().float() - decoded_image_tv.cpu().float()).abs().mean() < 3


@needs_cuda
def test_single_encode_jpeg_cuda_errors():
    """Each invalid single CUDA tensor must make ``encode_jpeg`` raise the matching ``RuntimeError``."""
    # (shape, dtype, expected message), checked in this order.
    invalid_inputs = (
        ((3, 100, 100), torch.float32, "Input tensor dtype should be uint8"),
        ((5, 100, 100), torch.uint8, "The number of channels should be 3, got: 5"),
        ((1, 100, 100), torch.uint8, "The number of channels should be 3, got: 1"),
        ((1, 3, 100, 100), torch.uint8, "Input data should be a 3-dimensional tensor"),
        ((100, 100), torch.uint8, "Input data should be a 3-dimensional tensor"),
    )
    for shape, dtype, message in invalid_inputs:
        with pytest.raises(RuntimeError, match=message):
            encode_jpeg(torch.empty(shape, dtype=dtype, device="cuda"))


@needs_cuda
def test_batch_encode_jpegs_cuda_errors():
    """Each invalid batch (list of tensors) must make ``encode_jpeg`` raise the matching error."""

    def blank(shape=(3, 100, 100), dtype=torch.uint8, device="cuda"):
        # Uninitialised tensor; by default a valid 3x100x100 uint8 CUDA image.
        return torch.empty(shape, dtype=dtype, device=device)

    # A valid first image followed by one invalid image: (overrides for the second, message).
    invalid_second_image = (
        ({"dtype": torch.float32}, "Input tensor dtype should be uint8"),
        ({"shape": (5, 100, 100)}, "The number of channels should be 3, got: 5"),
        ({"shape": (1, 100, 100)}, "The number of channels should be 3, got: 1"),
        ({"shape": (1, 3, 100, 100)}, "Input data should be a 3-dimensional tensor"),
        ({"shape": (100, 100)}, "Input data should be a 3-dimensional tensor"),
    )
    for overrides, message in invalid_second_image:
        with pytest.raises(RuntimeError, match=message):
            encode_jpeg([blank(), blank(**overrides)])

    # Mixed devices: the expected message depends on which device comes first.
    with pytest.raises(RuntimeError, match="Input tensor should be on CPU"):
        encode_jpeg([blank(device="cpu"), blank(device="cuda")])

    same_device_message = "All input tensors must be on the same CUDA device when encoding with nvjpeg"
    with pytest.raises(RuntimeError, match=same_device_message):
        encode_jpeg([blank(device="cuda"), blank(device="cpu")])

    if torch.cuda.device_count() >= 2:
        with pytest.raises(RuntimeError, match=same_device_message):
            encode_jpeg([blank(device="cuda:0"), blank(device="cuda:1")])

    with pytest.raises(ValueError, match="encode_jpeg requires at least one input tensor when a list is passed"):
        encode_jpeg([])


706
@pytest.mark.skipif(IS_MACOS, reason="https://github.com/pytorch/vision/issues/8031")
@pytest.mark.parametrize(
    "img_path",
    [pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) for jpeg_path in get_images(ENCODE_JPEG, ".jpg")],
)
@pytest.mark.parametrize("scripted", (True, False))
def test_write_jpeg(img_path, tmpdir, scripted):
    """Check that ``write_jpeg`` writes the same bytes as Pillow at quality 75."""
    tmpdir = Path(tmpdir)
    img = read_image(img_path)
    pil_img = F.to_pil_image(img)

    torch_jpeg = str(tmpdir / "torch.jpg")
    pil_jpeg = str(tmpdir / "pil.jpg")

    write = torch.jit.script(write_jpeg) if scripted else write_jpeg
    write(img, torch_jpeg, quality=75)
    pil_img.save(pil_jpeg, quality=75)

    with open(torch_jpeg, "rb") as f:
        torch_bytes = f.read()

    with open(pil_jpeg, "rb") as f:
        pil_bytes = f.read()

    assert_equal(torch_bytes, pil_bytes)
731
732


733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
def test_pathlib_support(tmpdir):
    """Smoke test: the read/write helpers must accept ``pathlib.Path`` arguments."""
    source = Path(next(get_images(ENCODE_JPEG, ".jpg")))
    for reader in (read_file, read_image):
        reader(source)

    destination = Path(tmpdir) / "whatever"
    pixels = torch.randint(0, 10, size=(3, 4, 4), dtype=torch.uint8)

    write_file(destination, data=pixels.flatten())
    # Both image writers target the same path; only the absence of errors is checked.
    for writer in (write_jpeg, write_png):
        writer(pixels, destination)


749
750
751
@pytest.mark.parametrize(
    "name", ("gifgrid", "fire", "porsche", "treescap", "treescap-interlaced", "solid2", "x-trans", "earth")
)
@pytest.mark.parametrize("scripted", (True, False))
def test_decode_gif(tmpdir, name, scripted):
    """Download a reference GIF and check that ``decode_gif`` matches Pillow frame by frame."""
    # Using test images from GIFLIB
    # https://sourceforge.net/p/giflib/code/ci/master/tree/pic/, we assert PIL
    # and torchvision decoded outputs are equal.
    # We're not testing against "welcome2" because PIL and GIFLIB disagree on what
    # the background color should be (likely a difference in the way they handle
    # transparency?)

    # 'earth' image is from wikipedia, licensed under CC BY-SA 3.0
    # https://creativecommons.org/licenses/by-sa/3.0/
    # it allows to properly test for transparency, TOP-LEFT offsets, and
    # disposal modes.

    path = tmpdir / f"{name}.gif"
    if name == "earth":
        if IN_OSS_CI:
            # TODO: Fix this... one day.
            pytest.skip("Skipping 'earth' test as it's flaky on OSS CI")
        url = "https://upload.wikimedia.org/wikipedia/commons/2/2c/Rotating_earth_%28large%29.gif"
    else:
        url = f"https://sourceforge.net/p/giflib/code/ci/master/tree/pic/{name}.gif?format=raw"

    # Bound the request time and fail on HTTP errors, so that a hung connection
    # or an error page is reported as a download problem, not a decoding one.
    response = requests.get(url, timeout=60)
    response.raise_for_status()
    with open(path, "wb") as gif_file:
        gif_file.write(response.content)

    encoded_bytes = read_file(path)
    decode = torch.jit.script(decode_gif) if scripted else decode_gif
    tv_out = decode(encoded_bytes)
    if tv_out.ndim == 3:
        # Add a leading frame axis so the output can be iterated frame by frame below.
        tv_out = tv_out[None]

    assert tv_out.is_contiguous(memory_format=torch.channels_last)

    # For some reason, not using Image.open() as a CM causes "ResourceWarning: unclosed file"
    with Image.open(path) as pil_img:
        pil_seq = ImageSequence.Iterator(pil_img)

        for pil_frame, tv_frame in zip(pil_seq, tv_out):
            pil_frame = F.pil_to_tensor(pil_frame.convert("RGB"))
            torch.testing.assert_close(tv_frame, pil_frame, atol=0, rtol=0)


def test_decode_gif_errors():
    """Each invalid input must make ``decode_gif`` raise the matching ``RuntimeError``."""
    payload = torch.randint(0, 256, (100,), dtype=torch.uint8)
    # (input, regex the error message must match), checked in this order.
    rejected_inputs = (
        (payload[None], "Input tensor must be 1-dimensional"),
        (payload.float(), "Input tensor must have uint8 data type"),
        (payload[::2], "Input tensor must be contiguous"),
        (payload, re.escape("DGifOpenFileName() failed - 103")),
    )
    for bad_input, pattern in rejected_inputs:
        with pytest.raises(RuntimeError, match=pattern):
            decode_gif(bad_input)


805
806
# When this file is executed directly, run pytest on this file only.
if __name__ == "__main__":
    pytest.main([__file__])