test_io.py 4.18 KB
Newer Older
1
import os
2
import contextlib
3
4
5
6
7
8
9
10
11
12
13
14
import tempfile
import torch
import torchvision.io as io
import unittest


try:
    import av
except ImportError:
    av = None


15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
def _create_video_frames(num_frames, height, width):
    """Build a synthetic RGB clip showing a Gaussian blob that drifts over time.

    Args:
        num_frames: number of frames to generate.
        height: frame height in pixels.
        width: frame width in pixels.

    Returns:
        A uint8 tensor of shape ``(num_frames, height, width, 3)`` whose three
        channels are identical copies of the same intensity map.
    """
    # Coordinate grids spanning [-2, 2] along each axis; rows vary along
    # dim 0 and cols along dim 1.
    rows, cols = torch.meshgrid(torch.linspace(-2, 2, height), torch.linspace(-2, 2, width))

    def render(index):
        # The blob centre moves right and down as the frame index grows.
        center_x = float(index) / num_frames
        center_y = 1 - float(index) / (2 * num_frames)
        sq_dist = (cols - center_x) ** 2 + (rows - center_y) ** 2
        intensity = torch.exp(-sq_dist / 2) * 255
        # Quantise first, then replicate the single channel three times.
        return intensity.byte().unsqueeze(2).repeat(1, 1, 3)

    return torch.stack([render(index) for index in range(num_frames)], 0)


@contextlib.contextmanager
def temp_video(num_frames, height, width, fps, lossless=False, video_codec=None, options=None):
    """Write a synthetic video to a temporary ``.mp4`` file and yield it.

    Args:
        num_frames: number of frames to generate.
        height: frame height in pixels.
        width: frame width in pixels.
        fps: frame rate passed to ``io.write_video``.
        lossless: when true, encode with ``libx264rgb`` and ``crf=0``; cannot
            be combined with ``video_codec`` or ``options``.
        video_codec: codec name passed to ``io.write_video``; defaults to
            ``'libx264'``.
        options: codec options passed to ``io.write_video``; defaults to an
            empty dict.

    Yields:
        A ``(path, frames)`` tuple: the path of the written file and the
        tensor of frames that was encoded into it. The file is removed when
        the context exits.

    Raises:
        AssertionError: if ``lossless`` is combined with ``video_codec`` or
            ``options``.
    """
    if lossless:
        assert video_codec is None, "video_codec can't be specified together with lossless"
        assert options is None, "options can't be specified together with lossless"
        video_codec = 'libx264rgb'
        options = {'crf': '0'}

    if video_codec is None:
        video_codec = 'libx264'
    if options is None:
        options = {}

    data = _create_video_frames(num_frames, height, width)
    # Write into a temporary directory rather than an open NamedTemporaryFile:
    # re-opening a NamedTemporaryFile by name while it is still open is not
    # possible on every platform (notably Windows).
    with tempfile.TemporaryDirectory() as tmp_dir:
        f_name = os.path.join(tmp_dir, 'video.mp4')
        io.write_video(f_name, data, fps=fps, video_codec=video_codec, options=options)
        yield f_name, data


46
47
48
49
50
51
52
class Tester(unittest.TestCase):
    """Round-trip tests for the video helpers exposed by ``torchvision.io``."""

    # compression adds artifacts, thus we add a tolerance of
    # 6 in 0-255 range
    TOLERANCE = 6

    def _assert_frames_close(self, expected, actual):
        """Assert the max absolute per-pixel difference is below ``TOLERANCE``."""
        max_diff = (expected.float() - actual.float()).abs().max().item()
        self.assertLess(max_diff, self.TOLERANCE)

    @unittest.skipIf(av is None, "PyAV unavailable")
    def test_write_read_video(self):
        """A losslessly written video is read back unchanged, with its fps."""
        with temp_video(10, 300, 300, 5, lossless=True) as (f_name, data):
            lv, _, info = io.read_video(f_name)

            self.assertTrue(data.equal(lv))
            self.assertEqual(info["video_fps"], 5)

    @unittest.skipIf(av is None, "PyAV unavailable")
    def test_read_timestamps(self):
        """``read_video_timestamps`` matches pts derived from stream metadata."""
        with temp_video(10, 300, 300, 5) as (f_name, data):
            pts, _ = io.read_video_timestamps(f_name)

            # note: not all formats/codecs provide accurate information for computing the
            # timestamps. For the format that we use here, this information is available,
            # so we use it as a baseline
            container = av.open(f_name)
            try:
                stream = container.streams[0]
                pts_step = int(round(float(1 / (stream.average_rate * stream.time_base))))
                num_frames = int(round(float(stream.average_rate * stream.time_base * stream.duration)))
            finally:
                # release the file handle even if the metadata lookup fails
                container.close()
            expected_pts = [i * pts_step for i in range(num_frames)]

            self.assertEqual(pts, expected_pts)

    @unittest.skipIf(av is None, "PyAV unavailable")
    def test_read_partial_video(self):
        """Reading a pts range of a lossless video returns exactly those frames."""
        with temp_video(10, 300, 300, 5, lossless=True) as (f_name, data):
            pts, _ = io.read_video_timestamps(f_name)
            for start in range(5):
                for length in range(1, 4):
                    lv, _, _ = io.read_video(f_name, pts[start], pts[start + length - 1])
                    s_data = data[start:(start + length)]
                    self.assertEqual(len(lv), length)
                    self.assertTrue(s_data.equal(lv))

            # a start pts just past frame 4's timestamp is still expected to
            # yield frames 4 through 7
            lv, _, _ = io.read_video(f_name, pts[4] + 1, pts[7])
            self.assertEqual(len(lv), 4)
            self.assertTrue(data[4:8].equal(lv))

    @unittest.skipIf(av is None, "PyAV unavailable")
    def test_read_partial_video_bframes(self):
        """Partial reads stay within tolerance when the encoding uses B-frames."""
        # do not use lossless encoding, to test the presence of B-frames
        options = {'bframes': '16', 'keyint': '10', 'min-keyint': '4'}
        with temp_video(100, 300, 300, 5, options=options) as (f_name, data):
            pts, _ = io.read_video_timestamps(f_name)
            for start in range(0, 80, 20):
                for length in range(1, 4):
                    lv, _, _ = io.read_video(f_name, pts[start], pts[start + length - 1])
                    s_data = data[start:(start + length)]
                    self.assertEqual(len(lv), length)
                    self._assert_frames_close(s_data, lv)

            # a start pts just past frame 4's timestamp is still expected to
            # yield frames 4 through 7
            lv, _, _ = io.read_video(f_name, pts[4] + 1, pts[7])
            self.assertEqual(len(lv), 4)
            self._assert_frames_close(data[4:8], lv)

    # TODO add tests for audio


# Run the test suite when this file is executed directly as a script.
if __name__ == '__main__':
    unittest.main()