test_io.py 11.5 KB
Newer Older
1
import os
2
import contextlib
3
4
import tempfile
import torch
5
import torchvision.datasets.utils as utils
6
import torchvision.io as io
7
from torchvision import get_video_backend
8
import unittest
9
import warnings
10
from urllib.error import URLError
11

12
13
from common_utils import get_tmp_dir

14
15
16

try:
    import av
17
18
    # Do a version test too
    io.video._check_av_available()
19
20
21
22
except ImportError:
    av = None


23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
def _create_video_frames(num_frames, height, width):
    """Build a synthetic clip of a Gaussian blob whose centre drifts per frame.

    Args:
        num_frames: number of frames to generate.
        height: number of rows in every frame.
        width: number of columns in every frame.

    Returns:
        A ``uint8`` tensor of shape ``(num_frames, height, width, 3)`` in which
        the three channels of every pixel hold the same value.
    """
    rows = torch.linspace(-2, 2, height)
    cols = torch.linspace(-2, 2, width)
    y, x = torch.meshgrid(rows, cols)

    frames = []
    for idx in range(num_frames):
        # the blob centre moves right and down as the frame index grows
        center_x = float(idx) / num_frames
        center_y = 1 - float(idx) / (2 * num_frames)
        sq_dist = (x - center_x) ** 2 + (y - center_y) ** 2
        intensity = torch.exp(-sq_dist / 2) * 255
        # replicate the single intensity plane over three channels
        frames.append(intensity.unsqueeze(2).repeat(1, 1, 3).byte())

    return torch.stack(frames, 0)


@contextlib.contextmanager
def temp_video(num_frames, height, width, fps, lossless=False, video_codec=None, options=None):
    """Write a synthetic video to a temporary ``.mp4`` file and yield it.

    Args:
        num_frames: number of frames in the generated clip.
        height: frame height in pixels.
        width: frame width in pixels.
        fps: frame rate passed to ``io.write_video``.
        lossless: when true, encode with ``libx264rgb`` and ``crf=0``; cannot be
            combined with ``video_codec`` or ``options``.
        video_codec: codec name forwarded to ``io.write_video``. When omitted,
            ``libx264`` is used for the "pyav" backend and ``libx264rgb``
            otherwise.
        options: codec options forwarded to ``io.write_video``; defaults to an
            empty dict.

    Yields:
        A ``(file_name, frames)`` tuple, where ``frames`` is the tensor that was
        written to ``file_name``.

    Raises:
        ValueError: if ``lossless`` is combined with ``video_codec`` or
            ``options``.
    """
    if lossless:
        if video_codec is not None:
            raise ValueError("video_codec can't be specified together with lossless")
        if options is not None:
            raise ValueError("options can't be specified together with lossless")
        video_codec = 'libx264rgb'
        options = {'crf': '0'}

    if video_codec is None:
        if get_video_backend() == "pyav":
            video_codec = 'libx264'
        else:
            # when video_codec is not set, we assume it is libx264rgb which accepts
            # RGB pixel formats as input instead of YUV
            video_codec = 'libx264rgb'
    if options is None:
        options = {}

    data = _create_video_frames(num_frames, height, width)
    with tempfile.NamedTemporaryFile(suffix='.mp4') as f:
        # Only the unique name is needed: close the handle right away so that
        # io.write_video can create the file at that path itself.
        f.close()
        try:
            io.write_video(f.name, data, fps=fps, video_codec=video_codec, options=options)
            yield f.name, data
        finally:
            # Remove the written file even when the caller's ``with`` body (or
            # the write itself) raised; ignore the case where nothing was
            # written so the original error is not masked.
            with contextlib.suppress(FileNotFoundError):
                os.unlink(f.name)
61

Francisco Massa's avatar
Francisco Massa committed
62
63
@unittest.skipIf(get_video_backend() != "pyav" and not io._HAS_VIDEO_OPT,
                 "video_reader backend not available")
@unittest.skipIf(av is None, "PyAV unavailable")
class Tester(unittest.TestCase):
    """Tests for writing, reading and probing videos through ``torchvision.io``.

    Most tests write a synthetic clip with ``temp_video`` and compare what
    ``io.read_video`` / ``io.read_video_timestamps`` return against the frames
    that were written or against stream metadata read with PyAV.
    """

    # compression adds artifacts, thus we add a tolerance of
    # 6 in 0-255 range
    TOLERANCE = 6

    @staticmethod
    def _expected_pts(stream, in_seconds=False):
        """Derive the expected presentation timestamps from stream metadata.

        Args:
            stream: a PyAV video stream exposing ``average_rate``,
                ``time_base`` and ``duration``.
            in_seconds: when true, scale every timestamp by ``time_base``.

        Returns:
            One timestamp per frame, evenly spaced by the pts step implied by
            the frame rate and time base.
        """
        pts_step = int(round(float(1 / (stream.average_rate * stream.time_base))))
        num_frames = int(round(float(stream.average_rate * stream.time_base * stream.duration)))
        scale = stream.time_base if in_seconds else 1
        return [i * pts_step * scale for i in range(num_frames)]

    def _assert_within_tolerance(self, expected, actual):
        """Assert that two frame tensors differ by less than ``TOLERANCE``."""
        max_diff = (expected.float() - actual.float()).abs().max()
        self.assertTrue(max_diff < self.TOLERANCE)

    def test_write_read_video(self):
        """A losslessly written clip reads back bit-exact with the right fps."""
        with temp_video(10, 300, 300, 5, lossless=True) as (f_name, data):
            lv, _, info = io.read_video(f_name)
            self.assertTrue(data.equal(lv))
            self.assertEqual(info["video_fps"], 5)

    @unittest.skipIf(not io._HAS_VIDEO_OPT, "video_reader backend is not chosen")
    def test_probe_video_from_file(self):
        """Probing a file reports ~2 s duration and ~5 fps (10 frames at 5 fps)."""
        with temp_video(10, 300, 300, 5) as (f_name, data):
            video_info = io._probe_video_from_file(f_name)
            self.assertAlmostEqual(video_info.video_duration, 2, delta=0.1)
            self.assertAlmostEqual(video_info.video_fps, 5, delta=0.1)

    @unittest.skipIf(not io._HAS_VIDEO_OPT, "video_reader backend is not chosen")
    def test_probe_video_from_memory(self):
        """Probing an in-memory buffer reports the same duration and fps."""
        with temp_video(10, 300, 300, 5) as (f_name, data):
            with open(f_name, "rb") as fp:
                filebuffer = fp.read()
            video_info = io._probe_video_from_memory(filebuffer)
            self.assertAlmostEqual(video_info.video_duration, 2, delta=0.1)
            self.assertAlmostEqual(video_info.video_fps, 5, delta=0.1)

    def test_read_timestamps(self):
        """Timestamps read from the file match those implied by stream metadata."""
        with temp_video(10, 300, 300, 5) as (f_name, data):
            pts, _ = io.read_video_timestamps(f_name)
            # note: not all formats/codecs provide accurate information for computing the
            # timestamps. For the format that we use here, this information is available,
            # so we use it as a baseline
            # closing() guarantees the container is released even on failure
            with contextlib.closing(av.open(f_name)) as container:
                expected_pts = self._expected_pts(container.streams[0])

            self.assertEqual(pts, expected_pts)

    def test_read_partial_video(self):
        """Reading a pts range of a lossless clip returns exactly those frames."""
        with temp_video(10, 300, 300, 5, lossless=True) as (f_name, data):
            pts, _ = io.read_video_timestamps(f_name)
            for start in range(5):
                for length in range(1, 4):
                    lv, _, _ = io.read_video(f_name, pts[start], pts[start + length - 1])
                    s_data = data[start:(start + length)]
                    self.assertEqual(len(lv), length)
                    self.assertTrue(s_data.equal(lv))

            if get_video_backend() == "pyav":
                # for "video_reader" backend, we don't decode the closest early frame
                # when the given start pts is not matching any frame pts
                lv, _, _ = io.read_video(f_name, pts[4] + 1, pts[7])
                self.assertEqual(len(lv), 4)
                self.assertTrue(data[4:8].equal(lv))

    def test_read_partial_video_bframes(self):
        """Partial reads stay within tolerance when the stream has B-frames."""
        # do not use lossless encoding, to test the presence of B-frames
        options = {'bframes': '16', 'keyint': '10', 'min-keyint': '4'}
        with temp_video(100, 300, 300, 5, options=options) as (f_name, data):
            pts, _ = io.read_video_timestamps(f_name)
            for start in range(0, 80, 20):
                for length in range(1, 4):
                    lv, _, _ = io.read_video(f_name, pts[start], pts[start + length - 1])
                    s_data = data[start:(start + length)]
                    self.assertEqual(len(lv), length)
                    self._assert_within_tolerance(s_data, lv)

            lv, _, _ = io.read_video(f_name, pts[4] + 1, pts[7])
            # TODO fix this
            # the two backends currently disagree on whether frame 4 is returned
            if get_video_backend() == 'pyav':
                self.assertEqual(len(lv), 4)
                self._assert_within_tolerance(data[4:8], lv)
            else:
                self.assertEqual(len(lv), 3)
                self._assert_within_tolerance(data[5:8], lv)

    def test_read_packed_b_frames_divx_file(self):
        """Timestamps of a downloaded DivX sample are sorted and report 30 fps.

        The test is skipped (with a warning) when the sample cannot be
        downloaded.
        """
        with get_tmp_dir() as temp_dir:
            name = "hmdb51_Turnk_r_Pippi_Michel_cartwheel_f_cm_np2_le_med_6.avi"
            f_name = os.path.join(temp_dir, name)
            url = "https://download.pytorch.org/vision_tests/io/" + name
            # keep the try body limited to the call that can hit the network
            try:
                utils.download_url(url, temp_dir)
            except URLError:
                msg = "could not download test file '{}'".format(url)
                warnings.warn(msg, RuntimeWarning)
                raise unittest.SkipTest(msg)

            pts, fps = io.read_video_timestamps(f_name)
            self.assertEqual(pts, sorted(pts))
            self.assertEqual(fps, 30)

    def test_read_timestamps_from_packet(self):
        """Timestamps of an ``mpeg4`` clip match those implied by stream metadata."""
        with temp_video(10, 300, 300, 5, video_codec='mpeg4') as (f_name, data):
            pts, _ = io.read_video_timestamps(f_name)
            # note: not all formats/codecs provide accurate information for computing the
            # timestamps. For the format that we use here, this information is available,
            # so we use it as a baseline
            with contextlib.closing(av.open(f_name)) as container:
                stream = container.streams[0]
                # make sure we went through the optimized codepath
                self.assertIn(b'Lavc', stream.codec_context.extradata)
                expected_pts = self._expected_pts(stream)

            self.assertEqual(pts, expected_pts)

    def test_read_video_pts_unit_sec(self):
        """Reading with ``pts_unit='sec'`` returns the full lossless clip."""
        with temp_video(10, 300, 300, 5, lossless=True) as (f_name, data):
            lv, _, info = io.read_video(f_name, pts_unit='sec')

            self.assertTrue(data.equal(lv))
            self.assertEqual(info["video_fps"], 5)
            # no other metadata keys are expected for this clip
            self.assertEqual(info, {"video_fps": 5})

    def test_read_timestamps_pts_unit_sec(self):
        """Timestamps read in seconds equal the metadata pts scaled by time_base."""
        with temp_video(10, 300, 300, 5) as (f_name, data):
            pts, _ = io.read_video_timestamps(f_name, pts_unit='sec')

            with contextlib.closing(av.open(f_name)) as container:
                expected_pts = self._expected_pts(container.streams[0], in_seconds=True)

            self.assertEqual(pts, expected_pts)

    def test_read_partial_video_pts_unit_sec(self):
        """Partial reads addressed in seconds return exactly the requested frames."""
        with temp_video(10, 300, 300, 5, lossless=True) as (f_name, data):
            pts, _ = io.read_video_timestamps(f_name, pts_unit='sec')

            for start in range(5):
                for length in range(1, 4):
                    lv, _, _ = io.read_video(f_name, pts[start], pts[start + length - 1], pts_unit='sec')
                    s_data = data[start:(start + length)]
                    self.assertEqual(len(lv), length)
                    self.assertTrue(s_data.equal(lv))

            with contextlib.closing(av.open(f_name)) as container:
                stream = container.streams[0]
                # start one time-base tick after frame 4, expressed in seconds
                start_sec = int(pts[4] * (1.0 / stream.time_base) + 1) * stream.time_base

            lv, _, _ = io.read_video(f_name, start_sec, pts[7], pts_unit='sec')
            if get_video_backend() == "pyav":
                # for "video_reader" backend, we don't decode the closest early frame
                # when the given start pts is not matching any frame pts
                self.assertEqual(len(lv), 4)
                self.assertTrue(data[4:8].equal(lv))

    def test_read_video_corrupted_file(self):
        """Reading a file that is not a video yields empty tensors and empty info."""
        with tempfile.NamedTemporaryFile(suffix='.mp4') as f:
            # NOTE(review): the write is not flushed before read_video opens the
            # path, so the reader may actually see an empty file - confirm
            # whether that is intended.
            f.write(b'This is not an mpg4 file')
            video, audio, info = io.read_video(f.name)
            self.assertIsInstance(video, torch.Tensor)
            self.assertIsInstance(audio, torch.Tensor)
            self.assertEqual(video.numel(), 0)
            self.assertEqual(audio.numel(), 0)
            self.assertEqual(info, {})

    def test_read_video_timestamps_corrupted_file(self):
        """Reading timestamps of a non-video file yields no pts and no fps."""
        with tempfile.NamedTemporaryFile(suffix='.mp4') as f:
            # NOTE(review): same unflushed-write caveat as in
            # test_read_video_corrupted_file - confirm.
            f.write(b'This is not an mpg4 file')
            video_pts, video_fps = io.read_video_timestamps(f.name)
            self.assertEqual(video_pts, [])
            self.assertIs(video_fps, None)

    def test_read_video_partially_corrupted_file(self):
        """A clip with a corrupted middle still decodes its leading frames."""
        with temp_video(5, 4, 4, 5, lossless=True) as (f_name, data):
            with open(f_name, 'r+b') as f:
                size = os.path.getsize(f_name)
                bytes_to_overwrite = size // 10
                # seek to the middle of the file
                f.seek(5 * bytes_to_overwrite)
                # corrupt 10% of the file from the middle
                f.write(b'\xff' * bytes_to_overwrite)
            # this exercises the container.decode assertion check
            video, audio, info = io.read_video(f_name, pts_unit='sec')
            # check that size is not equal to 5, but 3
            # TODO fix this
            if get_video_backend() == 'pyav':
                self.assertEqual(len(video), 3)
            else:
                self.assertEqual(len(video), 4)
            # but the valid decoded content is still correct
            self.assertTrue(video[:3].equal(data[:3]))
            # and the last few frames are wrong
            self.assertFalse(video.equal(data))

    # TODO add tests for audio


# Run the test suite when this file is executed directly as a script.
if __name__ == '__main__':
    unittest.main()