# test_datasets_video_utils.py
import pytest
import torch
from common_utils import assert_equal, get_list_of_videos
from torchvision import io
from torchvision.datasets.video_utils import unfold, VideoClips


class TestVideo:
    """Tests for ``unfold`` and ``VideoClips`` from ``torchvision.datasets.video_utils``."""

    @staticmethod
    def _assert_clip_locations(video_clips, expected):
        """Assert that ``video_clips.get_clip_location`` maps global clip indices as expected.

        Args:
            video_clips: the ``VideoClips`` instance under test.
            expected: iterable of ``(global_clip_index, video_idx, clip_idx)`` triples.
        """
        for i, v_idx, c_idx in expected:
            video_idx, clip_idx = video_clips.get_clip_location(i)
            assert video_idx == v_idx
            assert clip_idx == c_idx

    def test_unfold(self):
        a = torch.arange(7)

        # The positional args after the tensor are window size, step and dilation.
        # Non-overlapping windows of 3; the incomplete trailing window ([6]) is dropped.
        r = unfold(a, 3, 3, 1)
        expected = torch.tensor(
            [
                [0, 1, 2],
                [3, 4, 5],
            ]
        )
        assert_equal(r, expected)

        # Overlapping windows: step 2 is smaller than the window size 3.
        r = unfold(a, 3, 2, 1)
        expected = torch.tensor([[0, 1, 2], [2, 3, 4], [4, 5, 6]])
        assert_equal(r, expected)

        # Dilation 2 takes every other element inside each window.
        r = unfold(a, 3, 2, 2)
        expected = torch.tensor(
            [
                [0, 2, 4],
                [2, 4, 6],
            ]
        )
        assert_equal(r, expected)

    @pytest.mark.skipif(not io.video._av_available(), reason="this test requires av")
    def test_video_clips(self, tmpdir):
        # The expected clip counts below imply that the three generated videos
        # have 5, 10 and 15 frames respectively.
        video_list = get_list_of_videos(tmpdir, num_videos=3)

        # 5-frame clips taken every 5 frames: 1 + 2 + 3 clips.
        video_clips = VideoClips(video_list, 5, 5, num_workers=2)
        assert video_clips.num_clips() == 1 + 2 + 3
        self._assert_clip_locations(
            video_clips,
            [(0, 0, 0), (1, 1, 0), (2, 1, 1), (3, 2, 0), (4, 2, 1), (5, 2, 2)],
        )

        # 6-frame clips: the first (5-frame) video yields no clip at all, so
        # global index 0 maps to video 1.
        video_clips = VideoClips(video_list, 6, 6)
        assert video_clips.num_clips() == 0 + 1 + 2
        self._assert_clip_locations(video_clips, [(0, 1, 0), (1, 2, 0), (2, 2, 1)])

        # 6-frame clips with a step of 1: a video of n frames yields n - 6 + 1 clips.
        video_clips = VideoClips(video_list, 6, 1)
        assert video_clips.num_clips() == 0 + (10 - 6 + 1) + (15 - 6 + 1)
        self._assert_clip_locations(video_clips, [(0, 1, 0), (4, 1, 4), (5, 2, 0), (6, 2, 1)])

    @pytest.mark.skipif(not io.video._av_available(), reason="this test requires av")
    def test_video_clips_custom_fps(self, tmpdir):
        video_list = get_list_of_videos(tmpdir, num_videos=3, sizes=[12, 12, 12], fps=[3, 4, 6])
        num_frames = 4
        # Target rates below, equal to and above the generated videos' rates.
        for fps in [1, 3, 4, 10]:
            video_clips = VideoClips(video_list, num_frames, num_frames, fps)
            for i in range(video_clips.num_clips()):
                video, _, info, _ = video_clips.get_clip(i)
                assert video.shape[0] == num_frames
                assert info["video_fps"] == fps
                # TODO add tests checking that the content is right

    def test_compute_clips_for_video(self):
        video_pts = torch.arange(30)

        # case 1: single clip
        # 30 frames at 30 fps resampled to 13 fps give 13 frames: one 13-frame clip.
        num_frames = 13
        orig_fps = 30
        duration = float(len(video_pts)) / orig_fps
        new_fps = 13
        clips, idxs = VideoClips.compute_clips_for_video(video_pts, num_frames, num_frames, orig_fps, new_fps)
        resampled_idxs = VideoClips._resample_video_idx(int(duration * new_fps), orig_fps, new_fps)
        assert len(clips) == 1
        assert_equal(clips, idxs)
        assert_equal(idxs[0], resampled_idxs)

        # case 2: all frames appear only once
        # 12 resampled frames split into three non-overlapping 4-frame clips.
        num_frames = 4
        orig_fps = 30
        duration = float(len(video_pts)) / orig_fps
        new_fps = 12
        clips, idxs = VideoClips.compute_clips_for_video(video_pts, num_frames, num_frames, orig_fps, new_fps)
        resampled_idxs = VideoClips._resample_video_idx(int(duration * new_fps), orig_fps, new_fps)
        assert len(clips) == 3
        assert_equal(clips, idxs)
        assert_equal(idxs.flatten(), resampled_idxs)

        # case 3: frames aren't enough for a clip
        # A 32-frame clip cannot be built from 30 source frames, so a warning is
        # emitted and nothing is returned.
        num_frames = 32
        orig_fps = 30
        new_fps = 13
        with pytest.warns(UserWarning):
            clips, idxs = VideoClips.compute_clips_for_video(video_pts, num_frames, num_frames, orig_fps, new_fps)
        assert len(clips) == 0
        assert len(idxs) == 0

if __name__ == "__main__":
    # pytest.main returns the session's exit code; re-raise it so that running
    # this file directly reports test failures to the shell / CI.
    raise SystemExit(pytest.main([__file__]))