"""Tests for the clip samplers in ``torchvision.datasets.samplers``."""
import contextlib
import os
import sys

import pytest
import torch

from torchvision import io
from torchvision import get_video_backend
from torchvision.datasets.samplers import (
    DistributedSampler,
    RandomClipSampler,
    UniformClipSampler,
)
from torchvision.datasets.video_utils import VideoClips, unfold

from common_utils import get_tmp_dir
from _assert_utils import assert_equal


@contextlib.contextmanager
def get_list_of_videos(num_videos=5, sizes=None, fps=None):
    """Write ``num_videos`` synthetic random videos into a temporary directory.

    Each video ``i`` is written to ``<tmp_dir>/<i>.mp4`` with
    ``io.write_video``, and the list of paths is yielded. The files live
    inside ``get_tmp_dir()``; whatever cleanup that helper performs happens
    when this context exits (presumably removal of the directory -- confirm
    against ``common_utils``).

    Args:
        num_videos: Number of videos to generate.
        sizes: Optional sequence of frame counts, indexed per video. When
            ``None``, video ``i`` gets ``5 * (i + 1)`` frames.
        fps: Optional sequence of frame rates, indexed per video. When
            ``None``, every video is written at 5 fps.

    Yields:
        list[str]: Paths of the written video files, in index order.
    """
    with get_tmp_dir() as tmp_dir:
        names = []
        for i in range(num_videos):
            # Default lengths differ per video so callers get unequal videos.
            size = 5 * (i + 1) if sizes is None else sizes[i]
            f = 5 if fps is None else fps[i]

            # Random uint8 frames of shape (size, 300, 400, 3).
            data = torch.randint(0, 256, (size, 300, 400, 3), dtype=torch.uint8)
            name = os.path.join(tmp_dir, "{}.mp4".format(i))
            names.append(name)
            io.write_video(name, data, fps=f)

        yield names


@pytest.mark.skipif(not io.video._av_available(), reason="this test requires av")
class TestDatasetsSamplers:
    """Tests for the clip samplers over small synthetic videos.

    Every test builds ``VideoClips(video_list, 5, 5)``. Going by the counts
    asserted below, a 25-frame video then yields 5 clips and a 10-frame video
    yields 2 (presumably 5-frame clips with a step of 5). Clip indices are
    global across videos, so with 5 clips per video ``index // 5`` gives the
    index of the video a clip came from.
    """

    def test_random_clip_sampler(self):
        """RandomClipSampler draws exactly 3 clips from each of three equal videos."""
        with get_list_of_videos(num_videos=3, sizes=[25, 25, 25]) as video_list:
            video_clips = VideoClips(video_list, 5, 5)
            sampler = RandomClipSampler(video_clips, 3)
            assert len(sampler) == 3 * 3
            indices = torch.tensor(list(iter(sampler)))
            # Map each global clip index back to its video (5 clips per video).
            videos = torch.div(indices, 5, rounding_mode='floor')
            v_idxs, count = torch.unique(videos, return_counts=True)
            assert_equal(v_idxs, torch.tensor([0, 1, 2]))
            assert_equal(count, torch.tensor([3, 3, 3]))

    def test_random_clip_sampler_unequal(self):
        """A video with fewer clips than requested contributes all of its clips."""
        with get_list_of_videos(num_videos=3, sizes=[10, 25, 25]) as video_list:
            video_clips = VideoClips(video_list, 5, 5)
            sampler = RandomClipSampler(video_clips, 3)
            # The short video only has 2 clips; the others supply 3 each.
            assert len(sampler) == 2 + 3 + 3
            indices = list(iter(sampler))
            # Both clips of the short video (global indices 0 and 1) are drawn.
            assert 0 in indices
            assert 1 in indices
            # remove elements of the first video, to simplify testing
            indices.remove(0)
            indices.remove(1)
            # Shift so the two remaining 25-frame videos start at index 0.
            indices = torch.tensor(indices) - 2
            videos = torch.div(indices, 5, rounding_mode='floor')
            v_idxs, count = torch.unique(videos, return_counts=True)
            assert_equal(v_idxs, torch.tensor([0, 1]))
            assert_equal(count, torch.tensor([3, 3]))

    def test_uniform_clip_sampler(self):
        """UniformClipSampler picks evenly spaced clips from every video."""
        with get_list_of_videos(num_videos=3, sizes=[25, 25, 25]) as video_list:
            video_clips = VideoClips(video_list, 5, 5)
            sampler = UniformClipSampler(video_clips, 3)
            assert len(sampler) == 3 * 3
            indices = torch.tensor(list(iter(sampler)))
            videos = torch.div(indices, 5, rounding_mode='floor')
            v_idxs, count = torch.unique(videos, return_counts=True)
            assert_equal(v_idxs, torch.tensor([0, 1, 2]))
            assert_equal(count, torch.tensor([3, 3, 3]))
            # 3 of 5 clips, evenly spaced: first, middle and last of each video.
            assert_equal(indices, torch.tensor([0, 2, 4, 5, 7, 9, 10, 12, 14]))

    def test_uniform_clip_sampler_insufficient_clips(self):
        """Clips of a video shorter than requested are repeated to fill the quota."""
        with get_list_of_videos(num_videos=3, sizes=[10, 25, 25]) as video_list:
            video_clips = VideoClips(video_list, 5, 5)
            sampler = UniformClipSampler(video_clips, 3)
            # Still 3 per video, even though the first video only has 2 clips.
            assert len(sampler) == 3 * 3
            indices = torch.tensor(list(iter(sampler)))
            # Clip 0 of the short video appears twice to make up the shortfall.
            assert_equal(indices, torch.tensor([0, 0, 1, 2, 4, 6, 7, 9, 11]))

    def test_distributed_sampler_and_uniform_clip_sampler(self):
        """DistributedSampler shares groups of 3 uniform clips between 2 ranks.

        The three per-video groups do not divide evenly over two replicas, so
        the expected output for rank 1 reuses the first group ([0, 2, 4]) as
        padding.
        """
        with get_list_of_videos(num_videos=3, sizes=[25, 25, 25]) as video_list:
            video_clips = VideoClips(video_list, 5, 5)
            clip_sampler = UniformClipSampler(video_clips, 3)

            distributed_sampler_rank0 = DistributedSampler(
                clip_sampler,
                num_replicas=2,
                rank=0,
                group_size=3,
            )
            indices = torch.tensor(list(iter(distributed_sampler_rank0)))
            assert len(distributed_sampler_rank0) == 6
            # Rank 0 receives the groups for videos 0 and 2.
            assert_equal(indices, torch.tensor([0, 2, 4, 10, 12, 14]))

            distributed_sampler_rank1 = DistributedSampler(
                clip_sampler,
                num_replicas=2,
                rank=1,
                group_size=3,
            )
            indices = torch.tensor(list(iter(distributed_sampler_rank1)))
            assert len(distributed_sampler_rank1) == 6
            # Rank 1 receives video 1's group plus the padding group of video 0.
            assert_equal(indices, torch.tensor([5, 7, 9, 0, 2, 4]))

if __name__ == '__main__':
    # Allow running this module directly instead of through the pytest CLI.
    pytest.main([__file__])