# test_io.py
1
import os
2
import contextlib
3
4
import tempfile
import torch
5
import torchvision.datasets.utils as utils
6
7
import torchvision.io as io
import unittest
8
9
import sys
import warnings
10

11
12
13
14
15
16
from common_utils import get_tmp_dir

if sys.version_info < (3,):
    from urllib2 import URLError
else:
    from urllib.error import URLError
17
18
19

try:
    import av
20
21
    # Do a version test too
    io.video._check_av_available()
22
23
24
25
except ImportError:
    av = None


26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
def _create_video_frames(num_frames, height, width):
    """Build a synthetic clip of a Gaussian blob whose centre moves per frame.

    Args:
        num_frames: number of frames to generate.
        height: number of rows in every frame.
        width: number of columns in every frame.

    Returns:
        A uint8 tensor of shape ``(num_frames, height, width, 3)``; the three
        channels of each frame hold the same intensity map.
    """
    rows = torch.linspace(-2, 2, height)
    cols = torch.linspace(-2, 2, width)
    y, x = torch.meshgrid(rows, cols)

    def render(index):
        # the blob centre shifts with the frame index
        center_x = float(index) / num_frames
        center_y = 1 - float(index) / (2 * num_frames)
        squared_dist = (x - center_x) ** 2 + (y - center_y) ** 2
        intensity = torch.exp(-squared_dist / 2) * 255
        # replicate the single intensity plane into three identical channels
        return intensity.unsqueeze(2).repeat(1, 1, 3).byte()

    return torch.stack([render(i) for i in range(num_frames)], 0)


@contextlib.contextmanager
def temp_video(num_frames, height, width, fps, lossless=False, video_codec=None, options=None):
    """Write a synthetic clip to a temporary ``.mp4`` and yield it.

    Args:
        num_frames, height, width: dimensions of the generated clip.
        fps: frame rate passed to ``io.write_video``.
        lossless: when true, encode with ``libx264rgb`` and ``crf=0``; it
            cannot be combined with ``video_codec`` or ``options``.
        video_codec: codec name, ``'libx264'`` when omitted.
        options: codec options dict, empty when omitted.

    Yields:
        ``(file_name, frames)`` where ``frames`` is the tensor that was
        written. The file lives only for the duration of the ``with`` block.
    """
    if lossless:
        assert video_codec is None, "video_codec can't be specified together with lossless"
        assert options is None, "options can't be specified together with lossless"
        video_codec, options = 'libx264rgb', {'crf': '0'}
    else:
        video_codec = 'libx264' if video_codec is None else video_codec
        options = {} if options is None else options

    frames = _create_video_frames(num_frames, height, width)
    with tempfile.NamedTemporaryFile(suffix='.mp4') as tmp_file:
        path = tmp_file.name
        io.write_video(path, frames, fps=fps, video_codec=video_codec, options=options)
        yield path, frames


57
@unittest.skipIf(av is None, "PyAV unavailable")
class Tester(unittest.TestCase):
    """Tests for writing videos and reading frames/timestamps back via ``io``.

    Most tests encode a synthetic clip with ``temp_video`` and compare what
    ``io.read_video`` / ``io.read_video_timestamps`` return against the
    frames that were written or against metadata read through PyAV.
    """

    # compression adds artifacts, thus we add a tolerance of
    # 6 in 0-255 range
    TOLERANCE = 6

    @staticmethod
    def _expected_pts(stream):
        """Derive the baseline list of timestamps from a PyAV stream's metadata."""
        pts_step = int(round(float(1 / (stream.average_rate * stream.time_base))))
        num_frames = int(round(float(stream.average_rate * stream.time_base * stream.duration)))
        return [i * pts_step for i in range(num_frames)]

    def _assert_frames_close(self, expected, actual):
        """Assert the largest absolute pixel difference is below ``TOLERANCE``."""
        max_diff = (expected.float() - actual.float()).abs().max().item()
        self.assertLess(max_diff, self.TOLERANCE)

    def test_write_read_video(self):
        """A lossless clip is read back identical, with the fps it was written at."""
        with temp_video(10, 300, 300, 5, lossless=True) as (f_name, data):
            lv, _, info = io.read_video(f_name)

            self.assertTrue(data.equal(lv))
            self.assertEqual(info["video_fps"], 5)

    def test_read_timestamps(self):
        """Timestamps returned by ``io`` match the ones derived from stream metadata."""
        with temp_video(10, 300, 300, 5) as (f_name, data):
            pts, _ = io.read_video_timestamps(f_name)

            # note: not all formats/codecs provide accurate information for computing the
            # timestamps. For the format that we use here, this information is available,
            # so we use it as a baseline
            container = av.open(f_name)
            try:
                expected_pts = self._expected_pts(container.streams[0])
            finally:
                # always close the container, even if reading the metadata fails
                container.close()

            self.assertEqual(pts, expected_pts)

    def test_read_partial_video(self):
        """Reading a pts range of a lossless clip returns exactly those frames."""
        with temp_video(10, 300, 300, 5, lossless=True) as (f_name, data):
            pts, _ = io.read_video_timestamps(f_name)
            for start in range(5):
                for length in range(1, 4):
                    lv, _, _ = io.read_video(f_name, pts[start], pts[start + length - 1])
                    s_data = data[start:(start + length)]
                    self.assertEqual(len(lv), length)
                    self.assertTrue(s_data.equal(lv))

            # a start pts just past frame 4's timestamp is still expected to
            # yield frames 4 through 7
            lv, _, _ = io.read_video(f_name, pts[4] + 1, pts[7])
            self.assertEqual(len(lv), 4)
            self.assertTrue(data[4:8].equal(lv))

    def test_read_partial_video_bframes(self):
        """Partial reads of a lossy clip return the right frames within tolerance."""
        # do not use lossless encoding, to test the presence of B-frames
        options = {'bframes': '16', 'keyint': '10', 'min-keyint': '4'}
        with temp_video(100, 300, 300, 5, options=options) as (f_name, data):
            pts, _ = io.read_video_timestamps(f_name)
            for start in range(0, 80, 20):
                for length in range(1, 4):
                    lv, _, _ = io.read_video(f_name, pts[start], pts[start + length - 1])
                    s_data = data[start:(start + length)]
                    self.assertEqual(len(lv), length)
                    self._assert_frames_close(s_data, lv)

            lv, _, _ = io.read_video(f_name, pts[4] + 1, pts[7])
            self.assertEqual(len(lv), 4)
            self._assert_frames_close(data[4:8], lv)

    def test_read_packed_b_frames_divx_file(self):
        """Timestamps of the downloaded AVI fixture are sorted and its fps is 30."""
        with get_tmp_dir() as temp_dir:
            name = "hmdb51_Turnk_r_Pippi_Michel_cartwheel_f_cm_np2_le_med_6.avi"
            f_name = os.path.join(temp_dir, name)
            url = "https://download.pytorch.org/vision_tests/io/" + name
            # only the download is allowed to turn into a skip; the checks
            # below stay outside of the URLError handler
            try:
                utils.download_url(url, temp_dir)
            except URLError:
                msg = "could not download test file '{}'".format(url)
                warnings.warn(msg, RuntimeWarning)
                raise unittest.SkipTest(msg)

            pts, fps = io.read_video_timestamps(f_name)
            self.assertEqual(pts, sorted(pts))
            self.assertEqual(fps, 30)

    def test_read_timestamps_from_packet(self):
        """Timestamps of an ``mpeg4`` clip match the ones derived from stream metadata."""
        with temp_video(10, 300, 300, 5, video_codec='mpeg4') as (f_name, data):
            pts, _ = io.read_video_timestamps(f_name)

            # note: not all formats/codecs provide accurate information for computing the
            # timestamps. For the format that we use here, this information is available,
            # so we use it as a baseline
            container = av.open(f_name)
            try:
                stream = container.streams[0]
                # make sure we went through the optimized codepath
                self.assertIn(b'Lavc', stream.codec_context.extradata)
                expected_pts = self._expected_pts(stream)
            finally:
                # always close the container, even if the assertion above fails
                container.close()

            self.assertEqual(pts, expected_pts)

    # TODO add tests for audio


# Run the test suite when this file is executed directly as a script.
if __name__ == '__main__':
    unittest.main()