import random
import warnings

import numpy as np
import pytest
import torch
from torchvision.transforms import Compose

from common_utils import assert_equal

# scipy is optional; the tests that need it check `stats is None` and skip.
try:
    from scipy import stats
except ImportError:
    stats = None

# _transforms_video is a private module. Record (and drop) every warning raised
# while importing it so that nothing leaks into the test output.
with warnings.catch_warnings(record=True):
    warnings.simplefilter("always")
    import torchvision.transforms._transforms_video as transforms
class TestVideoTransforms:
    """Tests for ``torchvision.transforms._transforms_video``.

    Clip and crop sizes are drawn from Python's ``random`` module. The
    statistical tests (normalization, random horizontal flip) need
    ``scipy.stats`` and are skipped when it is unavailable.
    """

    @staticmethod
    def _random_crop_sizes():
        """Draw random clip and crop sizes for the crop tests.

        Returns ``(numFrames, height, width, oheight, owidth)``:

        - ``height`` and ``width`` are even values in [20, 64].
        - ``oheight`` and ``owidth`` are even values in [10, height - 2] and
          [10, width - 2], so the crop is always strictly smaller than the frame.
        """
        numFrames = random.randint(4, 128)
        height = random.randint(10, 32) * 2
        width = random.randint(10, 32) * 2
        # Floor division keeps randint's bounds integral. True division yields
        # a float, which random.randint rejects with TypeError on Python 3.12+.
        oheight = random.randint(5, (height - 2) // 2) * 2
        owidth = random.randint(5, (width - 2) // 2) * 2
        return numFrames, height, width, oheight, owidth

    def test_random_crop_video(self):
        numFrames, height, width, oheight, owidth = self._random_crop_sizes()
        # Input clip: shape (numFrames, height, width, 3), dtype uint8.
        clip = torch.randint(0, 256, (numFrames, height, width, 3), dtype=torch.uint8)
        result = Compose([
            transforms.ToTensorVideo(),
            transforms.RandomCropVideo((oheight, owidth)),
        ])(clip)
        # The cropped height and width are expected at dims 2 and 3 of the output.
        assert result.size(2) == oheight
        assert result.size(3) == owidth

        # Smoke test: repr must not raise.
        repr(transforms.RandomCropVideo((oheight, owidth)))

    def test_random_resized_crop_video(self):
        numFrames, height, width, oheight, owidth = self._random_crop_sizes()
        clip = torch.randint(0, 256, (numFrames, height, width, 3), dtype=torch.uint8)
        result = Compose([
            transforms.ToTensorVideo(),
            transforms.RandomResizedCropVideo((oheight, owidth)),
        ])(clip)
        assert result.size(2) == oheight
        assert result.size(3) == owidth

        repr(transforms.RandomResizedCropVideo((oheight, owidth)))

    def test_center_crop_video(self):
        numFrames, height, width, oheight, owidth = self._random_crop_sizes()

        # Build an all-255 clip whose centred (oheight, owidth) window is zero.
        # clipNarrow is a view, so fill_ writes through to clip.
        clip = torch.ones((numFrames, height, width, 3), dtype=torch.uint8) * 255
        oh1 = (height - oheight) // 2
        ow1 = (width - owidth) // 2
        clipNarrow = clip[:, oh1:oh1 + oheight, ow1:ow1 + owidth, :]
        clipNarrow.fill_(0)

        def center_crop(crop_h, crop_w):
            return Compose([
                transforms.ToTensorVideo(),
                transforms.CenterCropVideo((crop_h, crop_w)),
            ])(clip)

        def describe(crop_h, crop_w):
            return f"height: {height} width: {width} oheight: {crop_h} owidth: {crop_w}"

        # Cropping exactly the zeroed window leaves only zeros.
        result = center_crop(oheight, owidth)
        assert result.sum().item() == 0, describe(oheight, owidth)

        # Any larger crop cannot fit inside the zeroed window, so it must pick
        # up some 255-valued border pixels.
        sum1 = center_crop(oheight + 1, owidth + 1).sum()
        assert sum1.item() > 1, describe(oheight + 1, owidth + 1)

        # Growing the crop further must include strictly more border pixels.
        sum2 = center_crop(oheight + 2, owidth + 2).sum()
        msg = describe(oheight + 2, owidth + 2)
        assert sum2.item() > 1, msg
        assert sum2.item() > sum1.item(), msg

    @pytest.mark.skipif(stats is None, reason='scipy.stats is not available')
    @pytest.mark.parametrize('channels', [1, 3])
    def test_normalize_video(self, channels):
        def samples_from_standard_normal(tensor):
            # Kolmogorov-Smirnov test against N(0, 1). Pass a flat numpy array
            # rather than a Python list of 0-d tensors: same data, far cheaper.
            p_value = stats.kstest(tensor.reshape(-1).numpy(), 'norm', args=(0, 1)).pvalue
            return p_value > 0.0001

        random_state = random.getstate()
        random.seed(42)
        try:
            numFrames = random.randint(4, 128)
            height = random.randint(32, 256)
            width = random.randint(32, 256)
            mean = random.random()
            std = random.random()
            # A dedicated generator makes the clip reproducible without
            # touching torch's global RNG state.
            generator = torch.Generator().manual_seed(42)
            clip = torch.normal(mean, std, size=(channels, numFrames, height, width),
                                generator=generator)
            # Normalize with the clip's own per-channel statistics.
            mean = [clip[c].mean().item() for c in range(channels)]
            std = [clip[c].std().item() for c in range(channels)]
            normalized = transforms.NormalizeVideo(mean, std)(clip)
            assert samples_from_standard_normal(normalized)
        finally:
            # Restore the caller's RNG state even when the assertion fails.
            random.setstate(random_state)

        # Check the optional in-place behaviour. The returned tensor must equal
        # the input, and the input itself must hold the normalized values.
        # Comparing only the two would always pass if the same object is returned.
        tensor = torch.rand((3, 128, 16, 16))
        expected = (tensor - 0.5) / 0.5
        tensor_inplace = transforms.NormalizeVideo((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)(tensor)
        assert_equal(tensor, tensor_inplace)
        assert_equal(tensor, expected)

        repr(transforms.NormalizeVideo((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True))

    def test_to_tensor_video(self):
        numFrames, height, width = 64, 4, 4
        trans = transforms.ToTensorVideo()

        # Each invalid input gets its own pytest.raises. Inside one shared
        # block, the first call raises and the remaining calls never run.
        with pytest.raises(TypeError):  # nested Python list, not a tensor
            trans(np.random.rand(numFrames, height, width, 1).tolist())
        with pytest.raises(TypeError):  # float32 tensor instead of uint8
            trans(torch.rand((numFrames, height, width, 1), dtype=torch.float))

        with pytest.raises(ValueError):  # 5-D tensor
            trans(torch.ones((3, numFrames, height, width, 3), dtype=torch.uint8))
        with pytest.raises(ValueError):  # 3-D tensor
            trans(torch.ones((height, width, 3), dtype=torch.uint8))
        with pytest.raises(ValueError):  # 2-D tensor
            trans(torch.ones((width, 3), dtype=torch.uint8))
        with pytest.raises(ValueError):  # 1-D tensor
            trans(torch.ones((3,), dtype=torch.uint8))

        repr(trans)

    @pytest.mark.skipif(stats is None, reason='scipy.stats not available')
    def test_random_horizontal_flip_video(self):
        def binom_p_value(k, n, p):
            # scipy.stats.binom_test was removed in SciPy 1.12. binomtest
            # (SciPy >= 1.7) replaces it and returns a result object.
            binomtest = getattr(stats, 'binomtest', None)
            if binomtest is not None:
                return binomtest(k, n, p=p).pvalue
            return stats.binom_test(k, n, p=p)

        clip = torch.rand((3, 4, 112, 112), dtype=torch.float)
        hclip = clip.flip((-1,))
        num_samples = 250

        random_state = random.getstate()
        random.seed(42)
        try:
            # Seed once for both probabilities. Previously the state was
            # restored between the loops, so the p=0.7 run was unseeded.
            for p, flip in ((0.5, transforms.RandomHorizontalFlipVideo()),
                            (0.7, transforms.RandomHorizontalFlipVideo(p=0.7))):
                num_horizontal = sum(
                    1 for _ in range(num_samples)
                    if torch.all(torch.eq(flip(clip), hclip))
                )
                assert binom_p_value(num_horizontal, num_samples, p) > 0.0001
        finally:
            random.setstate(random_state)

        repr(transforms.RandomHorizontalFlipVideo())


if __name__ == '__main__':
    # Allow running this test module directly: `python test_transforms_video.py`.
    pytest.main([__file__])