test_io.py 10.1 KB
Newer Older
1
import os
2
import contextlib
3
4
import tempfile
import torch
5
import torchvision.datasets.utils as utils
6
import torchvision.io as io
7
from torchvision import get_video_backend
8
import unittest
9
10
import sys
import warnings
11

12
13
14
15
16
17
from common_utils import get_tmp_dir

if sys.version_info < (3,):
    from urllib2 import URLError
else:
    from urllib.error import URLError
18
19
20

try:
    import av
21
22
    # Do a version test too
    io.video._check_av_available()
23
24
25
except ImportError:
    av = None

26
27
28
29
30
31
32
33
34
35
36
37
38
39
# Video backend reported by torchvision, captured once when this module is
# imported. The helpers and tests below compare it against "pyav" to decide
# which io entry points to call.
_video_backend = get_video_backend()


def _read_video(filename, start_pts=0, end_pts=None):
    """Read a video through the entry point matching ``_video_backend``.

    With the "pyav" backend the call is forwarded to ``io.read_video`` with
    the arguments unchanged. For any other backend the call goes to
    ``io._read_video_from_file``, and an ``end_pts`` of None is passed as -1
    in the pts range.
    """
    if _video_backend == "pyav":
        return io.read_video(filename, start_pts, end_pts)

    # the file-based reader takes a (start, end) pts range; None becomes -1
    last_pts = -1 if end_pts is None else end_pts
    return io._read_video_from_file(filename, video_pts_range=(start_pts, last_pts))

40

41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
def _create_video_frames(num_frames, height, width):
    y, x = torch.meshgrid(torch.linspace(-2, 2, height), torch.linspace(-2, 2, width))
    data = []
    for i in range(num_frames):
        xc = float(i) / num_frames
        yc = 1 - float(i) / (2 * num_frames)
        d = torch.exp(-((x - xc) ** 2 + (y - yc) ** 2) / 2) * 255
        data.append(d.unsqueeze(2).repeat(1, 1, 3).byte())

    return torch.stack(data, 0)


@contextlib.contextmanager
def temp_video(num_frames, height, width, fps, lossless=False, video_codec=None, options=None):
    """Write a synthetic clip to a temporary ``.mp4`` and yield ``(path, frames)``.

    The frames come from ``_create_video_frames(num_frames, height, width)``
    and are encoded with ``io.write_video``. The file is a
    ``tempfile.NamedTemporaryFile`` that stays open for the duration of the
    ``with`` block wrapping the ``yield``.

    Args:
        num_frames, height, width: size of the generated clip.
        fps: frame rate handed to ``io.write_video``.
        lossless: when true, encode with 'libx264rgb' and ``{'crf': '0'}``;
            ``video_codec`` and ``options`` must then be left as None.
        video_codec: codec name; when None a backend-dependent default is used.
        options: encoder options dict; when None an empty dict is used.

    Raises:
        AssertionError: if ``lossless`` is combined with ``video_codec`` or
            ``options``.
    """
    if lossless:
        assert video_codec is None, "video_codec can't be specified together with lossless"
        assert options is None, "options can't be specified together with lossless"
        video_codec = 'libx264rgb'
        options = {'crf': '0'}

    if video_codec is None:
        if _video_backend == "pyav":
            video_codec = 'libx264'
        else:
            # when video_codec is not set, we assume it is libx264rgb which accepts
            # RGB pixel formats as input instead of YUV
            video_codec = 'libx264rgb'

    if options is None:
        options = {}

    data = _create_video_frames(num_frames, height, width)
    with tempfile.NamedTemporaryFile(suffix='.mp4') as f:
        io.write_video(f.name, data, fps=fps, video_codec=video_codec, options=options)
        yield f.name, data


77
@unittest.skipIf(av is None, "PyAV unavailable")
# startswith() rather than a substring test: 'win' is also a substring of
# 'darwin' and 'cygwin', which would skip the suite on those platforms too
@unittest.skipIf(sys.platform.startswith('win'), 'temporarily disabled on Windows')
class Tester(unittest.TestCase):
    """Round-trip tests for video writing, reading and timestamp extraction.

    Most tests encode a synthetic clip with ``temp_video`` and read it back
    through the entry points selected by ``_video_backend``. PyAV is used
    directly to derive the expected timestamps from the container metadata.
    """

    # compression adds artifacts, thus we add a tolerance of
    # 6 in 0-255 range
    TOLERANCE = 6

    def test_write_read_video(self):
        # a lossless encode must come back identical, with the fps it was written at
        with temp_video(10, 300, 300, 5, lossless=True) as (f_name, data):
            lv, _, info = _read_video(f_name)
            self.assertTrue(data.equal(lv))
            self.assertEqual(info["video_fps"], 5)

    @unittest.skipIf(not io._HAS_VIDEO_OPT, "video_reader backend is not chosen")
    def test_probe_video_from_file(self):
        # 10 frames at 5 fps: expect roughly 2 seconds and 5 fps from the probe
        with temp_video(10, 300, 300, 5) as (f_name, data):
            video_info = io._probe_video_from_file(f_name)
            self.assertAlmostEqual(video_info["video_duration"], 2, delta=0.1)
            self.assertAlmostEqual(video_info["video_fps"], 5, delta=0.1)

    @unittest.skipIf(not io._HAS_VIDEO_OPT, "video_reader backend is not chosen")
    def test_probe_video_from_memory(self):
        # same expectations as the file-based probe, fed from an in-memory buffer
        with temp_video(10, 300, 300, 5) as (f_name, data):
            with open(f_name, "rb") as fp:
                filebuffer = fp.read()
            video_info = io._probe_video_from_memory(filebuffer)
            self.assertAlmostEqual(video_info["video_duration"], 2, delta=0.1)
            self.assertAlmostEqual(video_info["video_fps"], 5, delta=0.1)

    def test_read_timestamps(self):
        with temp_video(10, 300, 300, 5) as (f_name, data):
            if _video_backend == "pyav":
                pts, _ = io.read_video_timestamps(f_name)
            else:
                pts, _, _ = io._read_video_timestamps_from_file(f_name)

            # note: not all formats/codecs provide accurate information for computing the
            # timestamps. For the format that we use here, this information is available,
            # so we use it as a baseline
            container = av.open(f_name)
            try:
                stream = container.streams[0]
                pts_step = int(round(float(1 / (stream.average_rate * stream.time_base))))
                num_frames = int(round(float(stream.average_rate * stream.time_base * stream.duration)))
            finally:
                # release the container even if reading the metadata fails
                container.close()
            expected_pts = [i * pts_step for i in range(num_frames)]

            self.assertEqual(pts, expected_pts)

    def test_read_partial_video(self):
        with temp_video(10, 300, 300, 5, lossless=True) as (f_name, data):
            if _video_backend == "pyav":
                pts, _ = io.read_video_timestamps(f_name)
            else:
                pts, _, _ = io._read_video_timestamps_from_file(f_name)

            # every window of 1 to 3 frames starting at frames 0..4 must match exactly
            for start in range(5):
                for offset in range(1, 4):
                    lv, _, _ = _read_video(f_name, pts[start], pts[start + offset - 1])
                    s_data = data[start:(start + offset)]
                    self.assertEqual(len(lv), offset)
                    self.assertTrue(s_data.equal(lv))

            if _video_backend == "pyav":
                # for "video_reader" backend, we don't decode the closest early frame
                # when the given start pts is not matching any frame pts
                lv, _, _ = _read_video(f_name, pts[4] + 1, pts[7])
                self.assertEqual(len(lv), 4)
                self.assertTrue(data[4:8].equal(lv))

    def test_read_partial_video_bframes(self):
        # do not use lossless encoding, to test the presence of B-frames
        options = {'bframes': '16', 'keyint': '10', 'min-keyint': '4'}
        with temp_video(100, 300, 300, 5, options=options) as (f_name, data):
            if _video_backend == "pyav":
                pts, _ = io.read_video_timestamps(f_name)
            else:
                pts, _, _ = io._read_video_timestamps_from_file(f_name)

            for start in range(0, 80, 20):
                for offset in range(1, 4):
                    lv, _, _ = _read_video(f_name, pts[start], pts[start + offset - 1])
                    s_data = data[start:(start + offset)]
                    self.assertEqual(len(lv), offset)
                    # lossy encode: compare within TOLERANCE instead of exactly
                    self.assertTrue((s_data.float() - lv.float()).abs().max() < self.TOLERANCE)

            # this read calls io.read_video directly, whatever _video_backend is
            lv, _, _ = io.read_video(f_name, pts[4] + 1, pts[7])
            self.assertEqual(len(lv), 4)
            self.assertTrue((data[4:8].float() - lv.float()).abs().max() < self.TOLERANCE)

    def test_read_packed_b_frames_divx_file(self):
        with get_tmp_dir() as temp_dir:
            name = "hmdb51_Turnk_r_Pippi_Michel_cartwheel_f_cm_np2_le_med_6.avi"
            f_name = os.path.join(temp_dir, name)
            url = "https://download.pytorch.org/vision_tests/io/" + name
            # only the download is expected to raise URLError, so only it is guarded
            try:
                utils.download_url(url, temp_dir)
            except URLError:
                msg = "could not download test file '{}'".format(url)
                warnings.warn(msg, RuntimeWarning)
                raise unittest.SkipTest(msg)

            if _video_backend == "pyav":
                pts, fps = io.read_video_timestamps(f_name)
            else:
                pts, _, info = io._read_video_timestamps_from_file(f_name)
                fps = info["video_fps"]

            # timestamps must come back in presentation order
            self.assertEqual(pts, sorted(pts))
            self.assertEqual(fps, 30)

    def test_read_timestamps_from_packet(self):
        with temp_video(10, 300, 300, 5, video_codec='mpeg4') as (f_name, data):
            if _video_backend == "pyav":
                pts, _ = io.read_video_timestamps(f_name)
            else:
                pts, _, _ = io._read_video_timestamps_from_file(f_name)

            # note: not all formats/codecs provide accurate information for computing the
            # timestamps. For the format that we use here, this information is available,
            # so we use it as a baseline
            container = av.open(f_name)
            try:
                stream = container.streams[0]
                # make sure we went through the optimized codepath
                self.assertIn(b'Lavc', stream.codec_context.extradata)
                pts_step = int(round(float(1 / (stream.average_rate * stream.time_base))))
                num_frames = int(round(float(stream.average_rate * stream.time_base * stream.duration)))
            finally:
                # release the container even if the assertion above fails
                container.close()
            expected_pts = [i * pts_step for i in range(num_frames)]

            self.assertEqual(pts, expected_pts)

    def test_read_video_pts_unit_sec(self):
        with temp_video(10, 300, 300, 5, lossless=True) as (f_name, data):
            lv, _, info = io.read_video(f_name, pts_unit='sec')

            self.assertTrue(data.equal(lv))
            self.assertEqual(info["video_fps"], 5)

    def test_read_timestamps_pts_unit_sec(self):
        with temp_video(10, 300, 300, 5) as (f_name, data):
            pts, _ = io.read_video_timestamps(f_name, pts_unit='sec')

            container = av.open(f_name)
            try:
                stream = container.streams[0]
                pts_step = int(round(float(1 / (stream.average_rate * stream.time_base))))
                num_frames = int(round(float(stream.average_rate * stream.time_base * stream.duration)))
                # expected values are the integer pts scaled by the stream time base
                expected_pts = [i * pts_step * stream.time_base for i in range(num_frames)]
            finally:
                container.close()

            self.assertEqual(pts, expected_pts)

    def test_read_partial_video_pts_unit_sec(self):
        with temp_video(10, 300, 300, 5, lossless=True) as (f_name, data):
            pts, _ = io.read_video_timestamps(f_name, pts_unit='sec')

            for start in range(5):
                for offset in range(1, 4):
                    lv, _, _ = io.read_video(f_name, pts[start], pts[start + offset - 1], pts_unit='sec')
                    s_data = data[start:(start + offset)]
                    self.assertEqual(len(lv), offset)
                    self.assertTrue(s_data.equal(lv))

            # only the time base is needed from the container, so read it and close
            container = av.open(f_name)
            try:
                time_base = container.streams[0].time_base
            finally:
                container.close()

            # start one pts tick after frame 4, expressed back in the unit of pts
            start_sec = int(pts[4] * (1.0 / time_base) + 1) * time_base
            lv, _, _ = io.read_video(f_name, start_sec, pts[7], pts_unit='sec')
            self.assertEqual(len(lv), 4)
            self.assertTrue(data[4:8].equal(lv))

    # TODO add tests for audio


if __name__ == '__main__':
    unittest.main()