# Captured from a repository web view of test_transforms_v2_functional.py (1.35 KB).
# The view reported: "references/vscode:/vscode.git/clone" did not exist on
# "9cece405db3a8301463b219d9e0885142837ae4f".
1
import numpy as np
2
import PIL.Image
3
import pytest
4
import torch
5

6
from torchvision.transforms.v2 import functional as F
7

@pytest.mark.parametrize(
    ("alias", "target"),
    [
        # Use the alias's own name as the test id so failures point at the alias.
        pytest.param(alias, target, id=alias.__name__)
        for alias, target in [
            (F.hflip, F.horizontal_flip),
            (F.vflip, F.vertical_flip),
            (F.get_image_num_channels, F.get_num_channels),
            # NOTE(review): this pair compares to_pil_image with itself, so it only
            # checks that the name exists; possibly left over from a rename —
            # confirm the intended target.
            (F.to_pil_image, F.to_pil_image),
            (F.elastic_transform, F.elastic),
            (F.to_grayscale, F.rgb_to_grayscale),
        ]
    ],
)
def test_alias(alias, target):
    """Check that each alias is the very same object as its target function."""
    assert alias is target


@pytest.mark.parametrize(
    "inpt",
    [
        127 * np.ones((32, 32, 3), dtype="uint8"),
        PIL.Image.new("RGB", (32, 32), 122),
    ],
)
def test_to_image(inpt):
    """Convert a (32, 32, 3) uint8 array or a 32x32 RGB PIL image with ``F.to_image``.

    The result must be a ``torch.Tensor`` of shape (3, 32, 32) whose pixel sum
    equals the input's, a cheap check that values were carried over unscaled.
    """
    output = F.to_image(inpt)
    assert isinstance(output, torch.Tensor)
    assert output.shape == (3, 32, 32)
    # A sum is independent of memory layout, so it can compare the HWC input
    # against the channels-first output directly.
    assert np.asarray(inpt).sum() == output.sum().item()


@pytest.mark.parametrize(
    "inpt",
    [
        torch.randint(0, 256, size=(3, 32, 32), dtype=torch.uint8),
        127 * np.ones((32, 32, 3), dtype="uint8"),
    ],
)
@pytest.mark.parametrize("mode", [None, "RGB"])
def test_to_pil_image(inpt, mode):
    """Convert a (3, 32, 32) uint8 tensor or (32, 32, 3) uint8 array to a PIL image.

    Runs with ``mode=None`` (inferred) and ``mode="RGB"`` (explicit); both must
    yield a 32x32 RGB ``PIL.Image.Image`` whose pixel sum matches the input's.
    """
    output = F.to_pil_image(inpt, mode=mode)
    assert isinstance(output, PIL.Image.Image)
    # PIL reports size as (width, height); both inputs are 32x32.
    assert output.size == (32, 32)
    # 3-channel uint8 input is expected to map to RGB whether inferred or explicit.
    assert output.mode == "RGB"
    # Sum is layout-independent, so CHW tensor input compares fine with HWC output.
    assert np.asarray(inpt).sum() == np.asarray(output).sum()