import os
import contextlib
import sys
import tempfile
import torch
import torchvision.datasets.utils as utils
import torchvision.io as io
from torchvision import get_video_backend
import unittest
import warnings
from urllib.error import URLError

from common_utils import get_tmp_dir


try:
    import av
    # Do a version test too
    io.video._check_av_available()
except ImportError:
    av = None


VIDEO_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "videos")


def _create_video_frames(num_frames, height, width):
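    # synthetic clip: each frame is a Gaussian blob rendered on a [-2, 2] grid,
    # with the blob centre drifting from frame to frame so consecutive frames
    # differ; returns a uint8 tensor of shape (num_frames, height, width, 3)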
    y, x = torch.meshgrid(torch.linspace(-2, 2, height), torch.linspace(-2, 2, width))
    data = []
    for i in range(num_frames):
        xc = float(i) / num_frames
        yc = 1 - float(i) / (2 * num_frames)
        d = torch.exp(-((x - xc) ** 2 + (y - yc) ** 2) / 2) * 255
        data.append(d.unsqueeze(2).repeat(1, 1, 3).byte())

    return torch.stack(data, 0)


@contextlib.contextmanager
def temp_video(num_frames, height, width, fps, lossless=False, video_codec=None, options=None):
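    # writes a synthetic clip to a temporary .mp4 and yields (file name, frames);
    # with lossless=True the clip is encoded with libx264rgb at crf=0, so decoded
    # frames can be compared bit-exactly against the input tensor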
    if lossless:
        if video_codec is not None:
            raise ValueError("video_codec can't be specified together with lossless")
        if options is not None:
            raise ValueError("options can't be specified together with lossless")
        video_codec = 'libx264rgb'
        options = {'crf': '0'}

    if video_codec is None:
        if get_video_backend() == "pyav":
            video_codec = 'libx264'
        else:
            # when video_codec is not set and the video_reader backend is in use,
            # default to libx264rgb, which accepts RGB pixel formats as input instead of YUV
            video_codec = 'libx264rgb'
    if options is None:
        options = {}

    data = _create_video_frames(num_frames, height, width)
    with tempfile.NamedTemporaryFile(suffix='.mp4') as f:
        f.close()
        io.write_video(f.name, data, fps=fps, video_codec=video_codec, options=options)
        yield f.name, data
    os.unlink(f.name)


@unittest.skipIf(get_video_backend() != "pyav" and not io._HAS_VIDEO_OPT,
                 "video_reader backend not available")
@unittest.skipIf(av is None, "PyAV unavailable")
class TestIO(unittest.TestCase):
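    # round-trip and decoding tests for torchvision.io video read/write, run
    # against whichever video backend is currently active (pyav or video_reader)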
    # compression adds artifacts, so we allow a tolerance of
    # 6 in the 0-255 range
    TOLERANCE = 6

    def test_write_read_video(self):
        with temp_video(10, 300, 300, 5, lossless=True) as (f_name, data):
            lv, _, info = io.read_video(f_name)
            self.assertTrue(data.equal(lv))
            self.assertEqual(info["video_fps"], 5)

    @unittest.skipIf(not io._HAS_VIDEO_OPT, "video_reader backend is not chosen")
    def test_probe_video_from_file(self):
        with temp_video(10, 300, 300, 5) as (f_name, data):
            video_info = io._probe_video_from_file(f_name)
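            # 10 frames at 5 fps -> expected duration of about 2 seconds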
            self.assertAlmostEqual(video_info.video_duration, 2, delta=0.1)
            self.assertAlmostEqual(video_info.video_fps, 5, delta=0.1)

    @unittest.skipIf(not io._HAS_VIDEO_OPT, "video_reader backend is not chosen")
    def test_probe_video_from_memory(self):
        with temp_video(10, 300, 300, 5) as (f_name, data):
            with open(f_name, "rb") as fp:
                filebuffer = fp.read()
            video_info = io._probe_video_from_memory(filebuffer)
            self.assertAlmostEqual(video_info.video_duration, 2, delta=0.1)
            self.assertAlmostEqual(video_info.video_fps, 5, delta=0.1)

    def test_read_timestamps(self):
        with temp_video(10, 300, 300, 5) as (f_name, data):
            pts, _ = io.read_video_timestamps(f_name)
            # note: not all formats/codecs provide accurate information for computing the
            # timestamps. For the format that we use here, this information is available,
            # so we use it as a baseline
            container = av.open(f_name)
            stream = container.streams[0]
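            # pts are expressed in units of the stream time_base, so consecutive
            # frames are spaced 1 / (average_rate * time_base) ticks apart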
            pts_step = int(round(float(1 / (stream.average_rate * stream.time_base))))
            num_frames = int(round(float(stream.average_rate * stream.time_base * stream.duration)))
            expected_pts = [i * pts_step for i in range(num_frames)]

            self.assertEqual(pts, expected_pts)
            container.close()

    def test_read_partial_video(self):
        with temp_video(10, 300, 300, 5, lossless=True) as (f_name, data):
            pts, _ = io.read_video_timestamps(f_name)
            for start in range(5):
                for offset in range(1, 4):
                    lv, _, _ = io.read_video(f_name, pts[start], pts[start + offset - 1])
                    s_data = data[start:(start + offset)]
                    self.assertEqual(len(lv), offset)
                    self.assertTrue(s_data.equal(lv))

            if get_video_backend() == "pyav":
                # for "video_reader" backend, we don't decode the closest early frame
                # when the given start pts is not matching any frame pts
                lv, _, _ = io.read_video(f_name, pts[4] + 1, pts[7])
                self.assertEqual(len(lv), 4)
                self.assertTrue(data[4:8].equal(lv))

    def test_read_partial_video_bframes(self):
        # do not use lossless encoding, so that the encoded stream contains B-frames
        options = {'bframes': '16', 'keyint': '10', 'min-keyint': '4'}
        with temp_video(100, 300, 300, 5, options=options) as (f_name, data):
            pts, _ = io.read_video_timestamps(f_name)
            for start in range(0, 80, 20):
                for offset in range(1, 4):
                    lv, _, _ = io.read_video(f_name, pts[start], pts[start + offset - 1])
                    s_data = data[start:(start + offset)]
                    self.assertEqual(len(lv), offset)
                    self.assertTrue((s_data.float() - lv.float()).abs().max() < self.TOLERANCE)

            lv, _, _ = io.read_video(f_name, pts[4] + 1, pts[7])
            # TODO fix this
            if get_video_backend() == 'pyav':
                self.assertEqual(len(lv), 4)
                self.assertTrue((data[4:8].float() - lv.float()).abs().max() < self.TOLERANCE)
            else:
                self.assertEqual(len(lv), 3)
                self.assertTrue((data[5:8].float() - lv.float()).abs().max() < self.TOLERANCE)

    def test_read_packed_b_frames_divx_file(self):
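        # this DivX-encoded AVI uses packed B-frames; decoding its timestamps
        # should still yield a sorted pts list and the expected frame rate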
        with get_tmp_dir() as temp_dir:
            name = "hmdb51_Turnk_r_Pippi_Michel_cartwheel_f_cm_np2_le_med_6.avi"
            f_name = os.path.join(temp_dir, name)
            url = "https://download.pytorch.org/vision_tests/io/" + name
            try:
                utils.download_url(url, temp_dir)
                pts, fps = io.read_video_timestamps(f_name)

                self.assertEqual(pts, sorted(pts))
                self.assertEqual(fps, 30)
            except URLError:
                msg = "could not download test file '{}'".format(url)
                warnings.warn(msg, RuntimeWarning)
                raise unittest.SkipTest(msg)

    def test_read_timestamps_from_packet(self):
        with temp_video(10, 300, 300, 5, video_codec='mpeg4') as (f_name, data):
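            # streams encoded by libavcodec carry a b'Lavc' signature in their codec
            # extradata; for such files the pyav backend can take an optimized path
            # that reads timestamps from demuxed packets instead of decoding frames
            # (checked via the assertIn on extradata below)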
            pts, _ = io.read_video_timestamps(f_name)
            # note: not all formats/codecs provide accurate information for computing the
            # timestamps. For the format that we use here, this information is available,
            # so we use it as a baseline
            container = av.open(f_name)
            stream = container.streams[0]
            # make sure we went through the optimized codepath
            self.assertIn(b'Lavc', stream.codec_context.extradata)
            pts_step = int(round(float(1 / (stream.average_rate * stream.time_base))))
            num_frames = int(round(float(stream.average_rate * stream.time_base * stream.duration)))
            expected_pts = [i * pts_step for i in range(num_frames)]

            self.assertEqual(pts, expected_pts)
            container.close()

    def test_read_video_pts_unit_sec(self):
        with temp_video(10, 300, 300, 5, lossless=True) as (f_name, data):
            lv, _, info = io.read_video(f_name, pts_unit='sec')

            self.assertTrue(data.equal(lv))
            self.assertEqual(info["video_fps"], 5)
            self.assertEqual(info, {"video_fps": 5})

    def test_read_timestamps_pts_unit_sec(self):
        with temp_video(10, 300, 300, 5) as (f_name, data):
            pts, _ = io.read_video_timestamps(f_name, pts_unit='sec')

            container = av.open(f_name)
            stream = container.streams[0]
            pts_step = int(round(float(1 / (stream.average_rate * stream.time_base))))
            num_frames = int(round(float(stream.average_rate * stream.time_base * stream.duration)))
            expected_pts = [i * pts_step * stream.time_base for i in range(num_frames)]

            self.assertEqual(pts, expected_pts)
            container.close()

    def test_read_partial_video_pts_unit_sec(self):
        with temp_video(10, 300, 300, 5, lossless=True) as (f_name, data):
            pts, _ = io.read_video_timestamps(f_name, pts_unit='sec')

            for start in range(5):
                for offset in range(1, 4):
                    lv, _, _ = io.read_video(f_name, pts[start], pts[start + offset - 1], pts_unit='sec')
                    s_data = data[start:(start + offset)]
                    self.assertEqual(len(lv), offset)
                    self.assertTrue(s_data.equal(lv))

            container = av.open(f_name)
            stream = container.streams[0]
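            # request a start point one time_base tick after the pts of frame 4,
            # so that the start does not exactly match any frame pts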
            lv, _, _ = io.read_video(f_name,
                                     int(pts[4] * (1.0 / stream.time_base) + 1) * stream.time_base, pts[7],
                                     pts_unit='sec')
            if get_video_backend() == "pyav":
                # for "video_reader" backend, we don't decode the closest early frame
                # when the given start pts is not matching any frame pts
                self.assertEqual(len(lv), 4)
                self.assertTrue(data[4:8].equal(lv))
            container.close()

    def test_read_video_corrupted_file(self):
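        # a file that cannot be decoded at all should yield empty video and audio
        # tensors and an empty info dict rather than raising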
        with tempfile.NamedTemporaryFile(suffix='.mp4') as f:
            f.write(b'This is not an mpg4 file')
            video, audio, info = io.read_video(f.name)
            self.assertIsInstance(video, torch.Tensor)
            self.assertIsInstance(audio, torch.Tensor)
            self.assertEqual(video.numel(), 0)
            self.assertEqual(audio.numel(), 0)
            self.assertEqual(info, {})

    def test_read_video_timestamps_corrupted_file(self):
        with tempfile.NamedTemporaryFile(suffix='.mp4') as f:
            f.write(b'This is not an mpg4 file')
            video_pts, video_fps = io.read_video_timestamps(f.name)
            self.assertEqual(video_pts, [])
            self.assertIs(video_fps, None)

    @unittest.skip("Temporarily disabled due to new pyav")
    def test_read_video_partially_corrupted_file(self):
        with temp_video(5, 4, 4, 5, lossless=True) as (f_name, data):
            with open(f_name, 'r+b') as f:
                size = os.path.getsize(f_name)
                bytes_to_overwrite = size // 10
                # seek to the middle of the file
                f.seek(5 * bytes_to_overwrite)
                # corrupt 10% of the file from the middle
                f.write(b'\xff' * bytes_to_overwrite)
            # this exercises the container.decode assertion check
            video, audio, info = io.read_video(f.name, pts_unit='sec')
            # check that fewer than the original 5 frames are decoded (3 with pyav, 4 with video_reader)
            # TODO fix this
            if get_video_backend() == 'pyav':
                self.assertEqual(len(video), 3)
            else:
                self.assertEqual(len(video), 4)
            # but the valid decoded content is still correct
            self.assertTrue(video[:3].equal(data[:3]))
            # and the last few frames are wrong
            self.assertFalse(video.equal(data))

    @unittest.skipIf(sys.platform == 'win32', 'temporarily disabled on Windows')
    def test_write_video_with_audio(self):
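        # the video stream is re-encoded losslessly (libx264rgb, crf=0) so frames
        # can be compared exactly; aac audio is lossy, so the audio is compared via
        # stream properties (rate, frame count, frame size) rather than sample values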
        f_name = os.path.join(VIDEO_DIR, "R6llTwEh07w.mp4")
        video_tensor, audio_tensor, info = io.read_video(f_name, pts_unit="sec")

        with get_tmp_dir() as tmpdir:
            out_f_name = os.path.join(tmpdir, "testing.mp4")
            io.video.write_video(
                out_f_name,
                video_tensor,
                round(info["video_fps"]),
                video_codec="libx264rgb",
                options={'crf': '0'},
                audio_array=audio_tensor,
                audio_fps=info["audio_fps"],
                audio_codec="aac",
            )

            out_video_tensor, out_audio_tensor, out_info = io.read_video(
                out_f_name, pts_unit="sec"
            )

            self.assertEqual(info["video_fps"], out_info["video_fps"])
            self.assertTrue(video_tensor.equal(out_video_tensor))

            audio_stream = av.open(f_name).streams.audio[0]
            out_audio_stream = av.open(out_f_name).streams.audio[0]

            self.assertEqual(info["audio_fps"], out_info["audio_fps"])
            self.assertEqual(audio_stream.rate, out_audio_stream.rate)
            self.assertAlmostEqual(audio_stream.frames, out_audio_stream.frames, delta=1)
            self.assertEqual(audio_stream.frame_size, out_audio_stream.frame_size)

    # TODO add tests for audio


if __name__ == '__main__':
    unittest.main()