Unverified commit e130c6cc authored by Francisco Massa, committed by GitHub

torchscriptable functions for video io (#1653) (#1794)

* torchscriptable functions for video io (#1653)

Summary:
Pull Request resolved: https://github.com/pytorch/vision/pull/1653



Created new torchscriptable video IO functions as part of the public API: read_video_meta_data_from_memory and read_video_from_memory.

Updated the implementation of some of the internal functions to be torchscriptable.

Reviewed By: stephenyan1231

Differential Revision: D18720474

fbshipit-source-id: 4ee646b66afecd2dc338a71fd8f249f25a3263bc

* BugFix
Co-authored-by: Jon Guerin <54725679+jguerin-fb@users.noreply.github.com>
parent 28b7f8ae
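For reference, a minimal sketch of how the two new public entry points can be exercised once this change lands. It assumes torchvision was built with the video_reader extension (so _HAS_VIDEO_OPT is True); the file path and the np.frombuffer conversion are illustrative, not part of the commit:

import numpy as np
import torch
import torchvision.io as io

# Load an encoded video file into a 1-D uint8 tensor (path is hypothetical).
with open("/path/to/video.mp4", "rb") as f:
    video_tensor = torch.from_numpy(np.frombuffer(f.read(), dtype=np.uint8))

# Both entry points are plain functions, so they can be scripted directly.
scripted_probe = torch.jit.script(io.read_video_meta_data_from_memory)
meta = scripted_probe(video_tensor)
print(meta.video_fps, meta.video_duration, meta.has_audio)

scripted_read = torch.jit.script(io.read_video_from_memory)
vframes, aframes = scripted_read(
    video_tensor,
    0.25,      # seek_frame_margin
    1,         # read_video_stream
    0, 0, 0,   # video_width, video_height, video_min_dimension
    [0, -1],   # video_pts_range
    0, 1,      # video_timebase numerator / denominator
    1,         # read_audio_stream
    0, 0,      # audio_samples, audio_channels
    [0, -1],   # audio_pts_range
    0, 1,      # audio_timebase numerator / denominator
)

The scripted call mirrors test_read_video_from_memory_scripted in the diff below: the old Fraction-valued timebase arguments are replaced by integer numerator/denominator pairs so that the signature stays TorchScript-compatible.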
 import contextlib
-import sys
 import os
 import torch
 import unittest
...
@@ -92,7 +91,6 @@ class Tester(unittest.TestCase):
                 video, audio, info, video_idx = video_clips.get_clip(i)
                 self.assertEqual(video.shape[0], num_frames)
                 self.assertEqual(info["video_fps"], fps)
-                self.assertEqual(info, {"video_fps": fps})
                 # TODO add tests checking that the content is right

     def test_compute_clips_for_video(self):
...
...
@@ -81,8 +81,8 @@ class Tester(unittest.TestCase):
     def test_probe_video_from_file(self):
         with temp_video(10, 300, 300, 5) as (f_name, data):
             video_info = io._probe_video_from_file(f_name)
-            self.assertAlmostEqual(video_info["video_duration"], 2, delta=0.1)
-            self.assertAlmostEqual(video_info["video_fps"], 5, delta=0.1)
+            self.assertAlmostEqual(video_info.video_duration, 2, delta=0.1)
+            self.assertAlmostEqual(video_info.video_fps, 5, delta=0.1)

     @unittest.skipIf(not io._HAS_VIDEO_OPT, "video_reader backend is not chosen")
     def test_probe_video_from_memory(self):
@@ -90,8 +90,8 @@ class Tester(unittest.TestCase):
             with open(f_name, "rb") as fp:
                 filebuffer = fp.read()
             video_info = io._probe_video_from_memory(filebuffer)
-            self.assertAlmostEqual(video_info["video_duration"], 2, delta=0.1)
-            self.assertAlmostEqual(video_info["video_fps"], 5, delta=0.1)
+            self.assertAlmostEqual(video_info.video_duration, 2, delta=0.1)
+            self.assertAlmostEqual(video_info.video_fps, 5, delta=0.1)

     def test_read_timestamps(self):
         with temp_video(10, 300, 300, 5) as (f_name, data):
...
import collections import collections
from common_utils import get_tmp_dir
from fractions import Fraction
import math import math
import numpy as np
import os import os
import sys import sys
import time import time
import unittest
from fractions import Fraction
import numpy as np
import torch import torch
import torchvision.io as io import torchvision.io as io
import unittest
from numpy.random import randint from numpy.random import randint
from torchvision.io import _HAS_VIDEO_OPT
try: try:
import av import av
# Do a version test too # Do a version test too
io.video._check_av_available() io.video._check_av_available()
except ImportError: except ImportError:
...@@ -25,9 +28,6 @@ else: ...@@ -25,9 +28,6 @@ else:
from urllib.error import URLError from urllib.error import URLError
from torchvision.io import _HAS_VIDEO_OPT
VIDEO_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "videos") VIDEO_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "videos")
CheckerConfig = [ CheckerConfig = [
...@@ -39,10 +39,7 @@ CheckerConfig = [ ...@@ -39,10 +39,7 @@ CheckerConfig = [
"check_aframes", "check_aframes",
"check_aframe_pts", "check_aframe_pts",
] ]
-GroundTruth = collections.namedtuple(
-    "GroundTruth",
-    " ".join(CheckerConfig)
-)
+GroundTruth = collections.namedtuple("GroundTruth", " ".join(CheckerConfig))
all_check_config = GroundTruth( all_check_config = GroundTruth(
duration=0, duration=0,
...@@ -193,9 +190,9 @@ def _decode_frames_by_av_module( ...@@ -193,9 +190,9 @@ def _decode_frames_by_av_module(
frames are read frames are read
""" """
if video_end_pts is None: if video_end_pts is None:
video_end_pts = float('inf') video_end_pts = float("inf")
if audio_end_pts is None: if audio_end_pts is None:
audio_end_pts = float('inf') audio_end_pts = float("inf")
container = av.open(full_path) container = av.open(full_path)
video_frames = [] video_frames = []
...@@ -282,8 +279,10 @@ class TestVideoReader(unittest.TestCase): ...@@ -282,8 +279,10 @@ class TestVideoReader(unittest.TestCase):
def check_separate_decoding_result(self, tv_result, config): def check_separate_decoding_result(self, tv_result, config):
"""check the decoding results from TorchVision decoder """check the decoding results from TorchVision decoder
""" """
-        vframes, vframe_pts, vtimebase, vfps, vduration, aframes, aframe_pts, \
-            atimebase, asample_rate, aduration = tv_result
+        vframes, vframe_pts, vtimebase, vfps, vduration, \
+            aframes, aframe_pts, atimebase, asample_rate, aduration = (
+                tv_result
+            )
video_duration = vduration.item() * Fraction( video_duration = vduration.item() * Fraction(
vtimebase[0].item(), vtimebase[1].item() vtimebase[0].item(), vtimebase[1].item()
...@@ -321,6 +320,13 @@ class TestVideoReader(unittest.TestCase): ...@@ -321,6 +320,13 @@ class TestVideoReader(unittest.TestCase):
) )
self.assertAlmostEqual(audio_duration, config.duration, delta=0.5) self.assertAlmostEqual(audio_duration, config.duration, delta=0.5)
def check_meta_result(self, result, config):
self.assertAlmostEqual(result.video_duration, config.duration, delta=0.5)
self.assertAlmostEqual(result.video_fps, config.video_fps, delta=0.5)
if result.has_audio > 0:
self.assertEqual(result.audio_sample_rate, config.audio_sample_rate)
self.assertAlmostEqual(result.audio_duration, config.duration, delta=0.5)
def compare_decoding_result(self, tv_result, ref_result, config=all_check_config): def compare_decoding_result(self, tv_result, ref_result, config=all_check_config):
""" """
Compare decoding results from two sources. Compare decoding results from two sources.
...@@ -330,8 +336,10 @@ class TestVideoReader(unittest.TestCase): ...@@ -330,8 +336,10 @@ class TestVideoReader(unittest.TestCase):
decoder or TorchVision decoder with getPtsOnly = 1 decoder or TorchVision decoder with getPtsOnly = 1
config: config of decoding results checker config: config of decoding results checker
""" """
vframes, vframe_pts, vtimebase, _vfps, _vduration, aframes, aframe_pts, \ vframes, vframe_pts, vtimebase, _vfps, _vduration, \
atimebase, _asample_rate, _aduration = tv_result aframes, aframe_pts, atimebase, _asample_rate, _aduration = (
tv_result
)
if isinstance(ref_result, list): if isinstance(ref_result, list):
# the ref_result is from new video_reader decoder # the ref_result is from new video_reader decoder
ref_result = DecoderResult( ref_result = DecoderResult(
...@@ -344,22 +352,34 @@ class TestVideoReader(unittest.TestCase): ...@@ -344,22 +352,34 @@ class TestVideoReader(unittest.TestCase):
) )
if vframes.numel() > 0 and ref_result.vframes.numel() > 0: if vframes.numel() > 0 and ref_result.vframes.numel() > 0:
mean_delta = torch.mean(torch.abs(vframes.float() - ref_result.vframes.float())) mean_delta = torch.mean(
torch.abs(vframes.float() - ref_result.vframes.float())
)
self.assertAlmostEqual(mean_delta, 0, delta=8.0) self.assertAlmostEqual(mean_delta, 0, delta=8.0)
mean_delta = torch.mean(torch.abs(vframe_pts.float() - ref_result.vframe_pts.float())) mean_delta = torch.mean(
torch.abs(vframe_pts.float() - ref_result.vframe_pts.float())
)
self.assertAlmostEqual(mean_delta, 0, delta=1.0) self.assertAlmostEqual(mean_delta, 0, delta=1.0)
is_same = torch.all(torch.eq(vtimebase, ref_result.vtimebase)).item() is_same = torch.all(torch.eq(vtimebase, ref_result.vtimebase)).item()
self.assertEqual(is_same, True) self.assertEqual(is_same, True)
if config.check_aframes and aframes.numel() > 0 and ref_result.aframes.numel() > 0: if (
config.check_aframes
and aframes.numel() > 0
and ref_result.aframes.numel() > 0
):
"""Audio stream is available and audio frame is required to return """Audio stream is available and audio frame is required to return
from decoder""" from decoder"""
is_same = torch.all(torch.eq(aframes, ref_result.aframes)).item() is_same = torch.all(torch.eq(aframes, ref_result.aframes)).item()
self.assertEqual(is_same, True) self.assertEqual(is_same, True)
if config.check_aframe_pts and aframe_pts.numel() > 0 and ref_result.aframe_pts.numel() > 0: if (
config.check_aframe_pts
and aframe_pts.numel() > 0
and ref_result.aframe_pts.numel() > 0
):
"""Audio stream is available""" """Audio stream is available"""
is_same = torch.all(torch.eq(aframe_pts, ref_result.aframe_pts)).item() is_same = torch.all(torch.eq(aframe_pts, ref_result.aframe_pts)).item()
self.assertEqual(is_same, True) self.assertEqual(is_same, True)
...@@ -492,15 +512,19 @@ class TestVideoReader(unittest.TestCase): ...@@ -492,15 +512,19 @@ class TestVideoReader(unittest.TestCase):
audio_timebase_den, audio_timebase_den,
) )
vframes, vframe_pts, vtimebase, vfps, vduration, aframes, aframe_pts, \ vframes, vframe_pts, vtimebase, vfps, vduration, \
atimebase, asample_rate, aduration = tv_result aframes, aframe_pts, atimebase, asample_rate, aduration = (
tv_result
)
self.assertEqual(vframes.numel() > 0, readVideoStream) self.assertEqual(vframes.numel() > 0, readVideoStream)
self.assertEqual(vframe_pts.numel() > 0, readVideoStream) self.assertEqual(vframe_pts.numel() > 0, readVideoStream)
self.assertEqual(vtimebase.numel() > 0, readVideoStream) self.assertEqual(vtimebase.numel() > 0, readVideoStream)
self.assertEqual(vfps.numel() > 0, readVideoStream) self.assertEqual(vfps.numel() > 0, readVideoStream)
expect_audio_data = readAudioStream == 1 and config.audio_sample_rate is not None expect_audio_data = (
readAudioStream == 1 and config.audio_sample_rate is not None
)
self.assertEqual(aframes.numel() > 0, expect_audio_data) self.assertEqual(aframes.numel() > 0, expect_audio_data)
self.assertEqual(aframe_pts.numel() > 0, expect_audio_data) self.assertEqual(aframe_pts.numel() > 0, expect_audio_data)
self.assertEqual(atimebase.numel() > 0, expect_audio_data) self.assertEqual(atimebase.numel() > 0, expect_audio_data)
...@@ -543,7 +567,9 @@ class TestVideoReader(unittest.TestCase): ...@@ -543,7 +567,9 @@ class TestVideoReader(unittest.TestCase):
audio_timebase_num, audio_timebase_num,
audio_timebase_den, audio_timebase_den,
) )
self.assertEqual(min_dimension, min(tv_result[0].size(1), tv_result[0].size(2))) self.assertEqual(
min_dimension, min(tv_result[0].size(1), tv_result[0].size(2))
)
def test_read_video_from_file_rescale_width(self): def test_read_video_from_file_rescale_width(self):
""" """
...@@ -669,10 +695,7 @@ class TestVideoReader(unittest.TestCase): ...@@ -669,10 +695,7 @@ class TestVideoReader(unittest.TestCase):
audio waveform are resampled audio waveform are resampled
""" """
for samples in [ for samples in [9600, 96000]: # downsampling # upsampling
9600, # downsampling
96000, # upsampling
]:
# video related # video related
width, height, min_dimension = 0, 0, 0 width, height, min_dimension = 0, 0, 0
video_start_pts, video_end_pts = 0, -1 video_start_pts, video_end_pts = 0, -1
...@@ -705,13 +728,19 @@ class TestVideoReader(unittest.TestCase): ...@@ -705,13 +728,19 @@ class TestVideoReader(unittest.TestCase):
audio_timebase_num, audio_timebase_num,
audio_timebase_den, audio_timebase_den,
) )
vframes, vframe_pts, vtimebase, vfps, vduration, aframes, aframe_pts, \ vframes, vframe_pts, vtimebase, vfps, vduration, \
atimebase, asample_rate, aduration = tv_result aframes, aframe_pts, atimebase, asample_rate, aduration = (
tv_result
)
if aframes.numel() > 0: if aframes.numel() > 0:
self.assertEqual(samples, asample_rate.item()) self.assertEqual(samples, asample_rate.item())
self.assertEqual(1, aframes.size(1)) self.assertEqual(1, aframes.size(1))
# when audio stream is found # when audio stream is found
duration = float(aframe_pts[-1]) * float(atimebase[0]) / float(atimebase[1]) duration = (
float(aframe_pts[-1])
* float(atimebase[0])
/ float(atimebase[1])
)
self.assertAlmostEqual( self.assertAlmostEqual(
aframes.size(0), aframes.size(0),
int(duration * asample_rate.item()), int(duration * asample_rate.item()),
...@@ -929,8 +958,10 @@ class TestVideoReader(unittest.TestCase): ...@@ -929,8 +958,10 @@ class TestVideoReader(unittest.TestCase):
audio_timebase_num, audio_timebase_num,
audio_timebase_den, audio_timebase_den,
) )
vframes, vframe_pts, vtimebase, vfps, vduration, aframes, aframe_pts, \ vframes, vframe_pts, vtimebase, vfps, vduration, \
atimebase, asample_rate, aduration = tv_result aframes, aframe_pts, atimebase, asample_rate, aduration = (
tv_result
)
self.assertAlmostEqual(config.video_fps, vfps.item(), delta=0.01) self.assertAlmostEqual(config.video_fps, vfps.item(), delta=0.01)
for num_frames in [4, 8, 16, 32, 64, 128]: for num_frames in [4, 8, 16, 32, 64, 128]:
...@@ -983,31 +1014,41 @@ class TestVideoReader(unittest.TestCase): ...@@ -983,31 +1014,41 @@ class TestVideoReader(unittest.TestCase):
) )
# pass 3: decode frames in range using PyAv # pass 3: decode frames in range using PyAv
video_timebase_av, audio_timebase_av = _get_timebase_by_av_module(full_path) video_timebase_av, audio_timebase_av = _get_timebase_by_av_module(
full_path
)
video_start_pts_av = _pts_convert( video_start_pts_av = _pts_convert(
video_start_pts.item(), video_start_pts.item(),
Fraction(video_timebase_num.item(), video_timebase_den.item()), Fraction(video_timebase_num.item(), video_timebase_den.item()),
Fraction(video_timebase_av.numerator, video_timebase_av.denominator), Fraction(
video_timebase_av.numerator, video_timebase_av.denominator
),
math.floor, math.floor,
) )
video_end_pts_av = _pts_convert( video_end_pts_av = _pts_convert(
video_end_pts.item(), video_end_pts.item(),
Fraction(video_timebase_num.item(), video_timebase_den.item()), Fraction(video_timebase_num.item(), video_timebase_den.item()),
Fraction(video_timebase_av.numerator, video_timebase_av.denominator), Fraction(
video_timebase_av.numerator, video_timebase_av.denominator
),
math.ceil, math.ceil,
) )
if audio_timebase_av: if audio_timebase_av:
audio_start_pts = _pts_convert( audio_start_pts = _pts_convert(
video_start_pts.item(), video_start_pts.item(),
Fraction(video_timebase_num.item(), video_timebase_den.item()), Fraction(video_timebase_num.item(), video_timebase_den.item()),
Fraction(audio_timebase_av.numerator, audio_timebase_av.denominator), Fraction(
audio_timebase_av.numerator, audio_timebase_av.denominator
),
math.floor, math.floor,
) )
audio_end_pts = _pts_convert( audio_end_pts = _pts_convert(
video_end_pts.item(), video_end_pts.item(),
Fraction(video_timebase_num.item(), video_timebase_den.item()), Fraction(video_timebase_num.item(), video_timebase_den.item()),
Fraction(audio_timebase_av.numerator, audio_timebase_av.denominator), Fraction(
audio_timebase_av.numerator, audio_timebase_av.denominator
),
math.ceil, math.ceil,
) )
...@@ -1044,6 +1085,54 @@ class TestVideoReader(unittest.TestCase): ...@@ -1044,6 +1085,54 @@ class TestVideoReader(unittest.TestCase):
probe_result = torch.ops.video_reader.probe_video_from_memory(video_tensor) probe_result = torch.ops.video_reader.probe_video_from_memory(video_tensor)
self.check_probe_result(probe_result, config) self.check_probe_result(probe_result, config)
def test_read_video_meta_data_from_memory_script(self):
scripted_fun = torch.jit.script(io.read_video_meta_data_from_memory)
self.assertIsNotNone(scripted_fun)
for test_video, config in test_videos.items():
full_path, video_tensor = _get_video_tensor(VIDEO_DIR, test_video)
probe_result = scripted_fun(video_tensor)
self.check_meta_result(probe_result, config)
def test_read_video_from_memory_scripted(self):
"""
Test the case when video is already in memory, and decoder reads data in memory
"""
# video related
width, height, min_dimension = 0, 0, 0
video_start_pts, video_end_pts = 0, -1
video_timebase_num, video_timebase_den = 0, 1
# audio related
samples, channels = 0, 0
audio_start_pts, audio_end_pts = 0, -1
audio_timebase_num, audio_timebase_den = 0, 1
scripted_fun = torch.jit.script(io.read_video_from_memory)
self.assertIsNotNone(scripted_fun)
for test_video, _config in test_videos.items():
full_path, video_tensor = _get_video_tensor(VIDEO_DIR, test_video)
# decode all frames using cpp decoder
scripted_fun(
video_tensor,
seek_frame_margin,
1, # readVideoStream
width,
height,
min_dimension,
[video_start_pts, video_end_pts],
video_timebase_num,
video_timebase_den,
1, # readAudioStream
samples,
channels,
[audio_start_pts, audio_end_pts],
audio_timebase_num,
audio_timebase_den,
)
# FUTURE: check value of video / audio frames
if __name__ == '__main__': if __name__ == "__main__":
unittest.main() unittest.main()
 import bisect
-from fractions import Fraction
 import math
+from fractions import Fraction
 import torch
 from torchvision.io import (
-    _read_video_timestamps_from_file,
+    _probe_video_from_file,
     _read_video_from_file,
-    _probe_video_from_file
+    _read_video_timestamps_from_file,
+    read_video,
+    read_video_timestamps,
 )
-from torchvision.io import read_video_timestamps, read_video
 from .utils import tqdm
...@@ -48,6 +50,7 @@ class _DummyDataset(object): ...@@ -48,6 +50,7 @@ class _DummyDataset(object):
Dummy dataset used for DataLoader in VideoClips. Dummy dataset used for DataLoader in VideoClips.
Defined at top level so it can be pickled when forking. Defined at top level so it can be pickled when forking.
""" """
def __init__(self, x): def __init__(self, x):
self.x = x self.x = x
...@@ -83,10 +86,21 @@ class VideoClips(object): ...@@ -83,10 +86,21 @@ class VideoClips(object):
num_workers (int): how many subprocesses to use for data loading. num_workers (int): how many subprocesses to use for data loading.
0 means that the data will be loaded in the main process. (default: 0) 0 means that the data will be loaded in the main process. (default: 0)
""" """
def __init__(self, video_paths, clip_length_in_frames=16, frames_between_clips=1,
frame_rate=None, _precomputed_metadata=None, num_workers=0, def __init__(
_video_width=0, _video_height=0, _video_min_dimension=0, self,
_audio_samples=0, _audio_channels=0): video_paths,
clip_length_in_frames=16,
frames_between_clips=1,
frame_rate=None,
_precomputed_metadata=None,
num_workers=0,
_video_width=0,
_video_height=0,
_video_min_dimension=0,
_audio_samples=0,
_audio_channels=0,
):
self.video_paths = video_paths self.video_paths = video_paths
self.num_workers = num_workers self.num_workers = num_workers
...@@ -114,11 +128,13 @@ class VideoClips(object): ...@@ -114,11 +128,13 @@ class VideoClips(object):
# strategy: use a DataLoader to parallelize read_video_timestamps # strategy: use a DataLoader to parallelize read_video_timestamps
# so need to create a dummy dataset first # so need to create a dummy dataset first
import torch.utils.data import torch.utils.data
dl = torch.utils.data.DataLoader( dl = torch.utils.data.DataLoader(
_DummyDataset(self.video_paths), _DummyDataset(self.video_paths),
batch_size=16, batch_size=16,
num_workers=self.num_workers, num_workers=self.num_workers,
collate_fn=self._collate_fn) collate_fn=self._collate_fn,
)
with tqdm(total=len(dl)) as pbar: with tqdm(total=len(dl)) as pbar:
for batch in dl: for batch in dl:
...@@ -140,7 +156,7 @@ class VideoClips(object): ...@@ -140,7 +156,7 @@ class VideoClips(object):
_metadata = { _metadata = {
"video_paths": self.video_paths, "video_paths": self.video_paths,
"video_pts": self.video_pts, "video_pts": self.video_pts,
"video_fps": self.video_fps "video_fps": self.video_fps,
} }
return _metadata return _metadata
...@@ -151,15 +167,21 @@ class VideoClips(object): ...@@ -151,15 +167,21 @@ class VideoClips(object):
metadata = { metadata = {
"video_paths": video_paths, "video_paths": video_paths,
"video_pts": video_pts, "video_pts": video_pts,
"video_fps": video_fps "video_fps": video_fps,
} }
return type(self)(video_paths, self.num_frames, self.step, self.frame_rate, return type(self)(
_precomputed_metadata=metadata, num_workers=self.num_workers, video_paths,
_video_width=self._video_width, self.num_frames,
_video_height=self._video_height, self.step,
_video_min_dimension=self._video_min_dimension, self.frame_rate,
_audio_samples=self._audio_samples, _precomputed_metadata=metadata,
_audio_channels=self._audio_channels) num_workers=self.num_workers,
_video_width=self._video_width,
_video_height=self._video_height,
_video_min_dimension=self._video_min_dimension,
_audio_samples=self._audio_samples,
_audio_channels=self._audio_channels,
)
@staticmethod @staticmethod
def compute_clips_for_video(video_pts, num_frames, step, fps, frame_rate): def compute_clips_for_video(video_pts, num_frames, step, fps, frame_rate):
...@@ -170,7 +192,9 @@ class VideoClips(object): ...@@ -170,7 +192,9 @@ class VideoClips(object):
if frame_rate is None: if frame_rate is None:
frame_rate = fps frame_rate = fps
total_frames = len(video_pts) * (float(frame_rate) / fps) total_frames = len(video_pts) * (float(frame_rate) / fps)
idxs = VideoClips._resample_video_idx(int(math.floor(total_frames)), fps, frame_rate) idxs = VideoClips._resample_video_idx(
int(math.floor(total_frames)), fps, frame_rate
)
video_pts = video_pts[idxs] video_pts = video_pts[idxs]
clips = unfold(video_pts, num_frames, step) clips = unfold(video_pts, num_frames, step)
if isinstance(idxs, slice): if isinstance(idxs, slice):
...@@ -195,7 +219,9 @@ class VideoClips(object): ...@@ -195,7 +219,9 @@ class VideoClips(object):
self.clips = [] self.clips = []
self.resampling_idxs = [] self.resampling_idxs = []
for video_pts, fps in zip(self.video_pts, self.video_fps): for video_pts, fps in zip(self.video_pts, self.video_fps):
clips, idxs = self.compute_clips_for_video(video_pts, num_frames, step, fps, frame_rate) clips, idxs = self.compute_clips_for_video(
video_pts, num_frames, step, fps, frame_rate
)
self.clips.append(clips) self.clips.append(clips)
self.resampling_idxs.append(idxs) self.resampling_idxs.append(idxs)
clip_lengths = torch.as_tensor([len(v) for v in self.clips]) clip_lengths = torch.as_tensor([len(v) for v in self.clips])
...@@ -251,13 +277,16 @@ class VideoClips(object): ...@@ -251,13 +277,16 @@ class VideoClips(object):
video_idx (int): index of the video in `video_paths` video_idx (int): index of the video in `video_paths`
""" """
if idx >= self.num_clips(): if idx >= self.num_clips():
raise IndexError("Index {} out of range " raise IndexError(
"({} number of clips)".format(idx, self.num_clips())) "Index {} out of range "
"({} number of clips)".format(idx, self.num_clips())
)
video_idx, clip_idx = self.get_clip_location(idx) video_idx, clip_idx = self.get_clip_location(idx)
video_path = self.video_paths[video_idx] video_path = self.video_paths[video_idx]
clip_pts = self.clips[video_idx][clip_idx] clip_pts = self.clips[video_idx][clip_idx]
from torchvision import get_video_backend from torchvision import get_video_backend
backend = get_video_backend() backend = get_video_backend()
if backend == "pyav": if backend == "pyav":
...@@ -267,7 +296,9 @@ class VideoClips(object): ...@@ -267,7 +296,9 @@ class VideoClips(object):
if self._video_height != 0: if self._video_height != 0:
raise ValueError("pyav backend doesn't support _video_height != 0") raise ValueError("pyav backend doesn't support _video_height != 0")
if self._video_min_dimension != 0: if self._video_min_dimension != 0:
raise ValueError("pyav backend doesn't support _video_min_dimension != 0") raise ValueError(
"pyav backend doesn't support _video_min_dimension != 0"
)
if self._audio_samples != 0: if self._audio_samples != 0:
raise ValueError("pyav backend doesn't support _audio_samples != 0") raise ValueError("pyav backend doesn't support _audio_samples != 0")
...
@@ -277,7 +308,7 @@ class VideoClips(object):
             video, audio, info = read_video(video_path, start_pts, end_pts)
         else:
             info = _probe_video_from_file(video_path)
-            video_fps = info["video_fps"]
+            video_fps = info.video_fps
             audio_fps = None

             video_start_pts = clip_pts[0].item()
@@ -285,28 +316,27 @@ class VideoClips(object):
             audio_start_pts, audio_end_pts = 0, -1
             audio_timebase = Fraction(0, 1)
-            if "audio_timebase" in info:
-                audio_timebase = info["audio_timebase"]
+            video_timebase = Fraction(
+                info.video_timebase.numerator, info.video_timebase.denominator
+            )
+            if info.has_audio:
+                audio_timebase = Fraction(
+                    info.audio_timebase.numerator, info.audio_timebase.denominator
+                )
                 audio_start_pts = pts_convert(
-                    video_start_pts,
-                    info["video_timebase"],
-                    info["audio_timebase"],
-                    math.floor,
+                    video_start_pts, video_timebase, audio_timebase, math.floor
                 )
                 audio_end_pts = pts_convert(
-                    video_end_pts,
-                    info["video_timebase"],
-                    info["audio_timebase"],
-                    math.ceil,
+                    video_end_pts, video_timebase, audio_timebase, math.ceil
                 )
-                audio_fps = info["audio_sample_rate"]
+                audio_fps = info.audio_sample_rate
             video, audio, info = _read_video_from_file(
                 video_path,
                 video_width=self._video_width,
                 video_height=self._video_height,
                 video_min_dimension=self._video_min_dimension,
                 video_pts_range=(video_start_pts, video_end_pts),
-                video_timebase=info["video_timebase"],
+                video_timebase=video_timebase,
                 audio_samples=self._audio_samples,
                 audio_channels=self._audio_channels,
                 audio_pts_range=(audio_start_pts, audio_end_pts),
...@@ -323,5 +353,7 @@ class VideoClips(object): ...@@ -323,5 +353,7 @@ class VideoClips(object):
resampling_idx = resampling_idx - resampling_idx[0] resampling_idx = resampling_idx - resampling_idx[0]
video = video[resampling_idx] video = video[resampling_idx]
info["video_fps"] = self.frame_rate info["video_fps"] = self.frame_rate
assert len(video) == self.num_frames, "{} x {}".format(video.shape, self.num_frames) assert len(video) == self.num_frames, "{} x {}".format(
video.shape, self.num_frames
)
return video, audio, info, video_idx return video, audio, info, video_idx
-from .video import write_video, read_video, read_video_timestamps, _HAS_VIDEO_OPT
 from ._video_opt import (
-    _read_video_from_file,
-    _read_video_timestamps_from_file,
+    Timebase,
+    VideoMetaData,
+    _HAS_VIDEO_OPT,
     _probe_video_from_file,
+    _probe_video_from_memory,
+    _read_video_from_file,
     _read_video_from_memory,
+    _read_video_timestamps_from_file,
     _read_video_timestamps_from_memory,
-    _probe_video_from_memory,
 )
+from .video import (
+    read_video,
+    read_video_from_memory,
+    read_video_meta_data_from_memory,
+    read_video_timestamps,
+    write_video,
+)

 __all__ = [
-    'write_video', 'read_video', 'read_video_timestamps',
-    '_read_video_from_file', '_read_video_timestamps_from_file', '_probe_video_from_file',
-    '_read_video_from_memory', '_read_video_timestamps_from_memory', '_probe_video_from_memory',
-    '_HAS_VIDEO_OPT',
+    "write_video",
+    "read_video",
+    "read_video_timestamps",
+    "read_video_meta_data_from_memory",
+    "read_video_from_memory",
+    "_read_video_from_file",
+    "_read_video_timestamps_from_file",
+    "_probe_video_from_file",
+    "_read_video_from_memory",
+    "_read_video_timestamps_from_memory",
+    "_probe_video_from_memory",
+    "_HAS_VIDEO_OPT",
+    "_read_video_clip_from_memory",
+    "_read_video_meta_data",
+    "VideoMetaData",
+    "Timebase",
 ]
-from fractions import Fraction
+import imp
 import math
+import os
+import warnings
+from fractions import Fraction
+from typing import List, Tuple

 import numpy as np
 import torch
-import warnings

+_HAS_VIDEO_OPT = False
+
+try:
+    lib_dir = os.path.join(os.path.dirname(__file__), "..")
+    _, path, description = imp.find_module("video_reader", [lib_dir])
+    torch.ops.load_library(path)
+    _HAS_VIDEO_OPT = True
+except (ImportError, OSError):
+    pass

 default_timebase = Fraction(0, 1)
# simple class for torch scripting
# the complex Fraction class from fractions module is not scriptable
@torch.jit.script
class Timebase(object):
__annotations__ = {"numerator": int, "denominator": int}
__slots__ = ["numerator", "denominator"]
def __init__(
self,
numerator, # type: int
denominator, # type: int
):
# type: (...) -> None
self.numerator = numerator
self.denominator = denominator
@torch.jit.script
class VideoMetaData(object):
__annotations__ = {
"has_video": bool,
"video_timebase": Timebase,
"video_duration": float,
"video_fps": float,
"has_audio": bool,
"audio_timebase": Timebase,
"audio_duration": float,
"audio_sample_rate": float,
}
__slots__ = [
"has_video",
"video_timebase",
"video_duration",
"video_fps",
"has_audio",
"audio_timebase",
"audio_duration",
"audio_sample_rate",
]
def __init__(self):
self.has_video = False
self.video_timebase = Timebase(0, 1)
self.video_duration = 0.0
self.video_fps = 0.0
self.has_audio = False
self.audio_timebase = Timebase(0, 1)
self.audio_duration = 0.0
self.audio_sample_rate = 0.0
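The two scripted classes above exist because fractions.Fraction cannot be compiled by TorchScript, so Timebase stores the numerator and denominator as plain ints. A minimal eager-mode sketch (the helper name is illustrative, not part of the commit) of rebuilding an exact Fraction from a probed VideoMetaData, which is the same pattern the VideoClips and _read_video changes below use:

from fractions import Fraction

def video_pts_to_seconds(meta, pts):
    # meta: a VideoMetaData returned by the probe functions; pts: an integer
    # presentation timestamp in the video stream's timebase.
    timebase = Fraction(meta.video_timebase.numerator, meta.video_timebase.denominator)
    return pts * timebase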
 def _validate_pts(pts_range):
+    # type: (List[int])
     if pts_range[1] > 0:
-        assert pts_range[0] <= pts_range[1], \
-            """Start pts should not be smaller than end pts, got
-            start pts: %d and end pts: %d""" % (pts_range[0], pts_range[1])
+        assert (
+            pts_range[0] <= pts_range[1]
+        ), """Start pts should not be smaller than end pts, got
+        start pts: %d and end pts: %d""" % (
+            pts_range[0],
+            pts_range[1],
+        )


 def _fill_info(vtimebase, vfps, vduration, atimebase, asample_rate, aduration):
-    info = {}
+    # type: (torch.Tensor,torch.Tensor,torch.Tensor,torch.Tensor,torch.Tensor,torch.Tensor) -> VideoMetaData
+    """
+    Build update VideoMetaData struct with info about the video
+    """
+    meta = VideoMetaData()
     if vtimebase.numel() > 0:
-        info["video_timebase"] = Fraction(vtimebase[0].item(), vtimebase[1].item())
+        meta.video_timebase = Timebase(
+            int(vtimebase[0].item()), int(vtimebase[1].item())
+        )
+        timebase = vtimebase[0].item() / float(vtimebase[1].item())
         if vduration.numel() > 0:
-            video_duration = vduration.item() * info["video_timebase"]
-            info["video_duration"] = video_duration
+            meta.has_video = True
+            meta.video_duration = float(vduration.item()) * timebase
     if vfps.numel() > 0:
-        info["video_fps"] = vfps.item()
+        meta.video_fps = float(vfps.item())
     if atimebase.numel() > 0:
-        info["audio_timebase"] = Fraction(atimebase[0].item(), atimebase[1].item())
+        meta.audio_timebase = Timebase(
+            int(atimebase[0].item()), int(atimebase[1].item())
+        )
+        timebase = atimebase[0].item() / float(atimebase[1].item())
         if aduration.numel() > 0:
-            audio_duration = aduration.item() * info["audio_timebase"]
-            info["audio_duration"] = audio_duration
+            meta.has_audio = True
+            meta.audio_duration = float(aduration.item()) * timebase
     if asample_rate.numel() > 0:
-        info["audio_sample_rate"] = asample_rate.item()
+        meta.audio_sample_rate = float(asample_rate.item())
-    return info
+    return meta
def _align_audio_frames(aframes, aframe_pts, audio_pts_range): def _align_audio_frames(aframes, aframe_pts, audio_pts_range):
# type: (torch.Tensor, torch.Tensor, List[int]) -> torch.Tensor
start, end = aframe_pts[0], aframe_pts[-1] start, end = aframe_pts[0], aframe_pts[-1]
num_samples = aframes.size(0) num_samples = aframes.size(0)
step_per_aframe = float(end - start + 1) / float(num_samples) step_per_aframe = float(end - start + 1) / float(num_samples)
...@@ -136,8 +219,10 @@ def _read_video_from_file( ...@@ -136,8 +219,10 @@ def _read_video_from_file(
audio_timebase.numerator, audio_timebase.numerator,
audio_timebase.denominator, audio_timebase.denominator,
) )
-    vframes, _vframe_pts, vtimebase, vfps, vduration, aframes, aframe_pts, atimebase, \
-        asample_rate, aduration = result
+    vframes, _vframe_pts, vtimebase, vfps, vduration, \
+        aframes, aframe_pts, atimebase, asample_rate, aduration = (
+            result
+        )
info = _fill_info(vtimebase, vfps, vduration, atimebase, asample_rate, aduration) info = _fill_info(vtimebase, vfps, vduration, atimebase, asample_rate, aduration)
if aframes.numel() > 0: if aframes.numel() > 0:
# when audio stream is found # when audio stream is found
...@@ -171,8 +256,8 @@ def _read_video_timestamps_from_file(filename): ...@@ -171,8 +256,8 @@ def _read_video_timestamps_from_file(filename):
0, # audio_timebase_num 0, # audio_timebase_num
1, # audio_timebase_den 1, # audio_timebase_den
) )
-    _vframes, vframe_pts, vtimebase, vfps, vduration, _aframes, aframe_pts, atimebase, \
-        asample_rate, aduration = result
+    _vframes, vframe_pts, vtimebase, vfps, vduration, \
+        _aframes, aframe_pts, atimebase, asample_rate, aduration = (result)
info = _fill_info(vtimebase, vfps, vduration, atimebase, asample_rate, aduration) info = _fill_info(vtimebase, vfps, vduration, atimebase, asample_rate, aduration)
vframe_pts = vframe_pts.numpy().tolist() vframe_pts = vframe_pts.numpy().tolist()
...
@@ -182,10 +267,7 @@ def _read_video_timestamps_from_file(filename):
 def _probe_video_from_file(filename):
     """
-    Probe a video file.
-    Return:
-        info [dict]: contain video meta information, including video_timebase,
-            video_duration, video_fps, audio_timebase, audio_duration, audio_sample_rate
+    Probe a video file and return VideoMetaData with info about the video
     """
     result = torch.ops.video_reader.probe_video_from_file(filename)
     vtimebase, vfps, vduration, atimebase, asample_rate, aduration = result
...
@@ -194,23 +276,27 @@ def _probe_video_from_file(filename):
 def _read_video_from_memory(
-    video_data,
-    seek_frame_margin=0.25,
-    read_video_stream=1,
-    video_width=0,
-    video_height=0,
-    video_min_dimension=0,
-    video_pts_range=(0, -1),
-    video_timebase=default_timebase,
-    read_audio_stream=1,
-    audio_samples=0,
-    audio_channels=0,
-    audio_pts_range=(0, -1),
-    audio_timebase=default_timebase,
+    video_data,  # type: torch.Tensor
+    seek_frame_margin=0.25,  # type: float
+    read_video_stream=1,  # type: int
+    video_width=0,  # type: int
+    video_height=0,  # type: int
+    video_min_dimension=0,  # type: int
+    video_pts_range=(0, -1),  # type: List[int]
+    video_timebase_numerator=0,  # type: int
+    video_timebase_denominator=1,  # type: int
+    read_audio_stream=1,  # type: int
+    audio_samples=0,  # type: int
+    audio_channels=0,  # type: int
+    audio_pts_range=(0, -1),  # type: List[int]
+    audio_timebase_numerator=0,  # type: int
+    audio_timebase_denominator=1,  # type: int
 ):
+    # type: (...) -> Tuple[torch.Tensor, torch.Tensor]
     """
Reads a video from memory, returning both the video frames as well as Reads a video from memory, returning both the video frames as well as
the audio frames the audio frames
This function is torchscriptable.
Args Args
---------- ----------
...@@ -234,8 +320,8 @@ def _read_video_from_memory( ...@@ -234,8 +320,8 @@ def _read_video_from_memory(
are set to $video_width and $video_height, respectively are set to $video_width and $video_height, respectively
video_pts_range : list(int), optional video_pts_range : list(int), optional
the start and end presentation timestamp of video stream the start and end presentation timestamp of video stream
video_timebase: Fraction, optional video_timebase_numerator / video_timebase_denominator: optional
a Fraction rational number which denotes timebase in video stream a rational number which denotes timebase in video stream
read_audio_stream: int, optional read_audio_stream: int, optional
whether read audio stream. If yes, set to 1. Otherwise, 0 whether read audio stream. If yes, set to 1. Otherwise, 0
audio_samples: int, optional audio_samples: int, optional
...@@ -244,8 +330,8 @@ def _read_video_from_memory( ...@@ -244,8 +330,8 @@ def _read_video_from_memory(
audio audio_channels audio audio_channels
audio_pts_range : list(int), optional audio_pts_range : list(int), optional
the start and end presentation timestamp of audio stream the start and end presentation timestamp of audio stream
audio_timebase: Fraction, optional audio_timebase_numerator / audio_timebase_denominator: optional
a Fraction rational number which denotes time base in audio stream a rational number which denotes time base in audio stream
Returns Returns
------- -------
...@@ -254,17 +340,11 @@ def _read_video_from_memory( ...@@ -254,17 +340,11 @@ def _read_video_from_memory(
aframes : Tensor[L, K] aframes : Tensor[L, K]
the audio frames, where `L` is the number of points and the audio frames, where `L` is the number of points and
`K` is the number of channels `K` is the number of channels
info : Dict
metadata for the video and audio. Can contain the fields video fps (float)
and audio sample rate (int)
""" """
_validate_pts(video_pts_range) _validate_pts(video_pts_range)
_validate_pts(audio_pts_range) _validate_pts(audio_pts_range)
if not isinstance(video_data, torch.Tensor):
video_data = torch.from_numpy(np.frombuffer(video_data, dtype=np.uint8))
result = torch.ops.video_reader.read_video_from_memory( result = torch.ops.video_reader.read_video_from_memory(
video_data, video_data,
seek_frame_margin, seek_frame_margin,
...@@ -275,24 +355,27 @@ def _read_video_from_memory( ...@@ -275,24 +355,27 @@ def _read_video_from_memory(
video_min_dimension, video_min_dimension,
video_pts_range[0], video_pts_range[0],
video_pts_range[1], video_pts_range[1],
video_timebase.numerator, video_timebase_numerator,
video_timebase.denominator, video_timebase_denominator,
read_audio_stream, read_audio_stream,
audio_samples, audio_samples,
audio_channels, audio_channels,
audio_pts_range[0], audio_pts_range[0],
audio_pts_range[1], audio_pts_range[1],
audio_timebase.numerator, audio_timebase_numerator,
audio_timebase.denominator, audio_timebase_denominator,
) )
-    vframes, _vframe_pts, vtimebase, vfps, vduration, aframes, aframe_pts, \
-        atimebase, asample_rate, aduration = result
-    info = _fill_info(vtimebase, vfps, vduration, atimebase, asample_rate, aduration)
+    vframes, _vframe_pts, vtimebase, vfps, vduration, \
+        aframes, aframe_pts, atimebase, asample_rate, aduration = (
+            result
+        )
     if aframes.numel() > 0:
         # when audio stream is found
         aframes = _align_audio_frames(aframes, aframe_pts, audio_pts_range)
-    return vframes, aframes, info
+    return vframes, aframes
def _read_video_timestamps_from_memory(video_data): def _read_video_timestamps_from_memory(video_data):
...@@ -323,8 +406,10 @@ def _read_video_timestamps_from_memory(video_data): ...@@ -323,8 +406,10 @@ def _read_video_timestamps_from_memory(video_data):
0, # audio_timebase_num 0, # audio_timebase_num
1, # audio_timebase_den 1, # audio_timebase_den
) )
_vframes, vframe_pts, vtimebase, vfps, vduration, _aframes, aframe_pts, \ _vframes, vframe_pts, vtimebase, vfps, vduration, \
atimebase, asample_rate, aduration = result _aframes, aframe_pts, atimebase, asample_rate, aduration = (
result
)
info = _fill_info(vtimebase, vfps, vduration, atimebase, asample_rate, aduration) info = _fill_info(vtimebase, vfps, vduration, atimebase, asample_rate, aduration)
vframe_pts = vframe_pts.numpy().tolist() vframe_pts = vframe_pts.numpy().tolist()
...
@@ -333,11 +418,10 @@ def _read_video_timestamps_from_memory(video_data):
 def _probe_video_from_memory(video_data):
+    # type: (torch.Tensor) -> VideoMetaData
     """
-    Probe a video in memory.
-    Return:
-        info [dict]: contain video meta information, including video_timebase,
-            video_duration, video_fps, audio_timebase, audio_duration, audio_sample_rate
+    Probe a video in memory and return VideoMetaData with info about the video
+    This function is torchscriptable
     """
     if not isinstance(video_data, torch.Tensor):
         video_data = torch.from_numpy(np.frombuffer(video_data, dtype=np.uint8))
...
@@ -347,23 +431,25 @@ def _probe_video_from_memory(video_data):
     return info
def _read_video(filename, start_pts=0, end_pts=None, pts_unit='pts'): def _read_video(filename, start_pts=0, end_pts=None, pts_unit="pts"):
if end_pts is None: if end_pts is None:
end_pts = float("inf") end_pts = float("inf")
if pts_unit == 'pts': if pts_unit == "pts":
warnings.warn("The pts_unit 'pts' gives wrong results and will be removed in a " + warnings.warn(
"follow-up version. Please use pts_unit 'sec'.") "The pts_unit 'pts' gives wrong results and will be removed in a "
+ "follow-up version. Please use pts_unit 'sec'."
)
info = _probe_video_from_file(filename) info = _probe_video_from_file(filename)
-    has_video = 'video_timebase' in info
-    has_audio = 'audio_timebase' in info
+    has_video = info.has_video
+    has_audio = info.has_audio
def get_pts(time_base): def get_pts(time_base):
start_offset = start_pts start_offset = start_pts
end_offset = end_pts end_offset = end_pts
if pts_unit == 'sec': if pts_unit == "sec":
start_offset = int(math.floor(start_pts * (1 / time_base))) start_offset = int(math.floor(start_pts * (1 / time_base)))
if end_offset != float("inf"): if end_offset != float("inf"):
end_offset = int(math.ceil(end_pts * (1 / time_base))) end_offset = int(math.ceil(end_pts * (1 / time_base)))
...@@ -374,13 +460,17 @@ def _read_video(filename, start_pts=0, end_pts=None, pts_unit='pts'): ...@@ -374,13 +460,17 @@ def _read_video(filename, start_pts=0, end_pts=None, pts_unit='pts'):
video_pts_range = (0, -1) video_pts_range = (0, -1)
video_timebase = default_timebase video_timebase = default_timebase
if has_video: if has_video:
-        video_timebase = info['video_timebase']
+        video_timebase = Fraction(
+            info.video_timebase.numerator, info.video_timebase.denominator
+        )
video_pts_range = get_pts(video_timebase) video_pts_range = get_pts(video_timebase)
audio_pts_range = (0, -1) audio_pts_range = (0, -1)
audio_timebase = default_timebase audio_timebase = default_timebase
if has_audio: if has_audio:
-        audio_timebase = info['audio_timebase']
+        audio_timebase = Fraction(
+            info.audio_timebase.numerator, info.audio_timebase.denominator
+        )
audio_pts_range = get_pts(audio_timebase) audio_pts_range = get_pts(audio_timebase)
vframes, aframes, info = _read_video_from_file( vframes, aframes, info = _read_video_from_file(
...@@ -394,24 +484,28 @@ def _read_video(filename, start_pts=0, end_pts=None, pts_unit='pts'): ...@@ -394,24 +484,28 @@ def _read_video(filename, start_pts=0, end_pts=None, pts_unit='pts'):
) )
_info = {} _info = {}
if has_video: if has_video:
-        _info['video_fps'] = info['video_fps']
+        _info["video_fps"] = info.video_fps
     if has_audio:
-        _info['audio_fps'] = info['audio_sample_rate']
+        _info["audio_fps"] = info.audio_sample_rate
return vframes, aframes, _info return vframes, aframes, _info
def _read_video_timestamps(filename, pts_unit='pts'): def _read_video_timestamps(filename, pts_unit="pts"):
if pts_unit == 'pts': if pts_unit == "pts":
warnings.warn("The pts_unit 'pts' gives wrong results and will be removed in a " + warnings.warn(
"follow-up version. Please use pts_unit 'sec'.") "The pts_unit 'pts' gives wrong results and will be removed in a "
+ "follow-up version. Please use pts_unit 'sec'."
)
pts, _, info = _read_video_timestamps_from_file(filename) pts, _, info = _read_video_timestamps_from_file(filename)
-    if pts_unit == 'sec':
-        video_time_base = info['video_timebase']
+    if pts_unit == "sec":
+        video_time_base = Fraction(
+            info.video_timebase.numerator, info.video_timebase.denominator
+        )
         pts = [x * video_time_base for x in pts]
-    video_fps = info.get('video_fps', None)
+    video_fps = info.video_fps if info.has_video else None
return pts, video_fps return pts, video_fps
-import re
-import imp
 import gc
-import os
-import torch
-import numpy as np
 import math
+import re
 import warnings
+from typing import Tuple, List

-from . import _video_opt
+import numpy as np
+import torch

-_HAS_VIDEO_OPT = False
-try:
-    lib_dir = os.path.join(os.path.dirname(__file__), '..')
-    _, path, description = imp.find_module("video_reader", [lib_dir])
-    torch.ops.load_library(path)
-    _HAS_VIDEO_OPT = True
-except (ImportError, OSError):
-    pass
+from . import _video_opt
+from ._video_opt import VideoMetaData
try: try:
import av import av
av.logging.set_level(av.logging.ERROR) av.logging.set_level(av.logging.ERROR)
if not hasattr(av.video.frame.VideoFrame, 'pict_type'): if not hasattr(av.video.frame.VideoFrame, "pict_type"):
av = ImportError("""\ av = ImportError(
"""\
Your version of PyAV is too old for the necessary video operations in torchvision. Your version of PyAV is too old for the necessary video operations in torchvision.
If you are on Python 3.5, you will have to build from source (the conda-forge If you are on Python 3.5, you will have to build from source (the conda-forge
packages are not up-to-date). See packages are not up-to-date). See
https://github.com/mikeboers/PyAV#installation for instructions on how to https://github.com/mikeboers/PyAV#installation for instructions on how to
install PyAV on your system. install PyAV on your system.
""") """
)
except ImportError: except ImportError:
av = ImportError("""\ av = ImportError(
"""\
PyAV is not installed, and is necessary for the video operations in torchvision. PyAV is not installed, and is necessary for the video operations in torchvision.
See https://github.com/mikeboers/PyAV#installation for instructions on how to See https://github.com/mikeboers/PyAV#installation for instructions on how to
install PyAV on your system. install PyAV on your system.
""") """
)
def _check_av_available(): def _check_av_available():
...@@ -54,7 +49,7 @@ _CALLED_TIMES = 0 ...@@ -54,7 +49,7 @@ _CALLED_TIMES = 0
_GC_COLLECTION_INTERVAL = 10 _GC_COLLECTION_INTERVAL = 10
def write_video(filename, video_array, fps, video_codec='libx264', options=None): def write_video(filename, video_array, fps, video_codec="libx264", options=None):
""" """
Writes a 4d tensor in [T, H, W, C] format in a video file Writes a 4d tensor in [T, H, W, C] format in a video file
...@@ -70,17 +65,17 @@ def write_video(filename, video_array, fps, video_codec='libx264', options=None) ...@@ -70,17 +65,17 @@ def write_video(filename, video_array, fps, video_codec='libx264', options=None)
_check_av_available() _check_av_available()
video_array = torch.as_tensor(video_array, dtype=torch.uint8).numpy() video_array = torch.as_tensor(video_array, dtype=torch.uint8).numpy()
container = av.open(filename, mode='w') container = av.open(filename, mode="w")
stream = container.add_stream(video_codec, rate=fps) stream = container.add_stream(video_codec, rate=fps)
stream.width = video_array.shape[2] stream.width = video_array.shape[2]
stream.height = video_array.shape[1] stream.height = video_array.shape[1]
stream.pix_fmt = 'yuv420p' if video_codec != 'libx264rgb' else 'rgb24' stream.pix_fmt = "yuv420p" if video_codec != "libx264rgb" else "rgb24"
stream.options = options or {} stream.options = options or {}
for img in video_array: for img in video_array:
frame = av.VideoFrame.from_ndarray(img, format='rgb24') frame = av.VideoFrame.from_ndarray(img, format="rgb24")
frame.pict_type = 'NONE' frame.pict_type = "NONE"
for packet in stream.encode(frame): for packet in stream.encode(frame):
container.mux(packet) container.mux(packet)
...@@ -92,19 +87,23 @@ def write_video(filename, video_array, fps, video_codec='libx264', options=None) ...@@ -92,19 +87,23 @@ def write_video(filename, video_array, fps, video_codec='libx264', options=None)
container.close() container.close()
def _read_from_stream(container, start_offset, end_offset, pts_unit, stream, stream_name): def _read_from_stream(
container, start_offset, end_offset, pts_unit, stream, stream_name
):
global _CALLED_TIMES, _GC_COLLECTION_INTERVAL global _CALLED_TIMES, _GC_COLLECTION_INTERVAL
_CALLED_TIMES += 1 _CALLED_TIMES += 1
if _CALLED_TIMES % _GC_COLLECTION_INTERVAL == _GC_COLLECTION_INTERVAL - 1: if _CALLED_TIMES % _GC_COLLECTION_INTERVAL == _GC_COLLECTION_INTERVAL - 1:
gc.collect() gc.collect()
if pts_unit == 'sec': if pts_unit == "sec":
start_offset = int(math.floor(start_offset * (1 / stream.time_base))) start_offset = int(math.floor(start_offset * (1 / stream.time_base)))
if end_offset != float("inf"): if end_offset != float("inf"):
end_offset = int(math.ceil(end_offset * (1 / stream.time_base))) end_offset = int(math.ceil(end_offset * (1 / stream.time_base)))
else: else:
warnings.warn("The pts_unit 'pts' gives wrong results and will be removed in a " + warnings.warn(
"follow-up version. Please use pts_unit 'sec'.") "The pts_unit 'pts' gives wrong results and will be removed in a "
+ "follow-up version. Please use pts_unit 'sec'."
)
frames = {} frames = {}
should_buffer = False should_buffer = False
...@@ -141,7 +140,7 @@ def _read_from_stream(container, start_offset, end_offset, pts_unit, stream, str ...@@ -141,7 +140,7 @@ def _read_from_stream(container, start_offset, end_offset, pts_unit, stream, str
return [] return []
buffer_count = 0 buffer_count = 0
try: try:
for idx, frame in enumerate(container.decode(**stream_name)): for _idx, frame in enumerate(container.decode(**stream_name)):
frames[frame.pts] = frame frames[frame.pts] = frame
if frame.pts >= end_offset: if frame.pts >= end_offset:
if should_buffer and buffer_count < max_buffer_size: if should_buffer and buffer_count < max_buffer_size:
...@@ -152,7 +151,9 @@ def _read_from_stream(container, start_offset, end_offset, pts_unit, stream, str ...@@ -152,7 +151,9 @@ def _read_from_stream(container, start_offset, end_offset, pts_unit, stream, str
# TODO add a warning # TODO add a warning
pass pass
# ensure that the results are sorted wrt the pts # ensure that the results are sorted wrt the pts
result = [frames[i] for i in sorted(frames) if start_offset <= frames[i].pts <= end_offset] result = [
frames[i] for i in sorted(frames) if start_offset <= frames[i].pts <= end_offset
]
if len(frames) > 0 and start_offset > 0 and start_offset not in frames: if len(frames) > 0 and start_offset > 0 and start_offset not in frames:
# if there is no frame that exactly matches the pts of start_offset # if there is no frame that exactly matches the pts of start_offset
# add the last frame smaller than start_offset, to guarantee that # add the last frame smaller than start_offset, to guarantee that
...@@ -177,7 +178,7 @@ def _align_audio_frames(aframes, audio_frames, ref_start, ref_end): ...@@ -177,7 +178,7 @@ def _align_audio_frames(aframes, audio_frames, ref_start, ref_end):
return aframes[:, s_idx:e_idx] return aframes[:, s_idx:e_idx]
def read_video(filename, start_pts=0, end_pts=None, pts_unit='pts'): def read_video(filename, start_pts=0, end_pts=None, pts_unit="pts"):
""" """
Reads a video from a file, returning both the video frames as well as Reads a video from a file, returning both the video frames as well as
the audio frames the audio frames
...@@ -208,6 +209,7 @@ def read_video(filename, start_pts=0, end_pts=None, pts_unit='pts'): ...@@ -208,6 +209,7 @@ def read_video(filename, start_pts=0, end_pts=None, pts_unit='pts'):
""" """
from torchvision import get_video_backend from torchvision import get_video_backend
if get_video_backend() != "pyav": if get_video_backend() != "pyav":
return _video_opt._read_video(filename, start_pts, end_pts, pts_unit) return _video_opt._read_video(filename, start_pts, end_pts, pts_unit)
...@@ -217,30 +219,44 @@ def read_video(filename, start_pts=0, end_pts=None, pts_unit='pts'): ...@@ -217,30 +219,44 @@ def read_video(filename, start_pts=0, end_pts=None, pts_unit='pts'):
end_pts = float("inf") end_pts = float("inf")
if end_pts < start_pts: if end_pts < start_pts:
raise ValueError("end_pts should be larger than start_pts, got " raise ValueError(
"start_pts={} and end_pts={}".format(start_pts, end_pts)) "end_pts should be larger than start_pts, got "
"start_pts={} and end_pts={}".format(start_pts, end_pts)
)
info = {} info = {}
video_frames = [] video_frames = []
audio_frames = [] audio_frames = []
try: try:
container = av.open(filename, metadata_errors='ignore') container = av.open(filename, metadata_errors="ignore")
except av.AVError: except av.AVError:
# TODO raise a warning? # TODO raise a warning?
pass pass
else: else:
if container.streams.video: if container.streams.video:
video_frames = _read_from_stream(container, start_pts, end_pts, pts_unit, video_frames = _read_from_stream(
container.streams.video[0], {'video': 0}) container,
start_pts,
end_pts,
pts_unit,
container.streams.video[0],
{"video": 0},
)
video_fps = container.streams.video[0].average_rate video_fps = container.streams.video[0].average_rate
# guard against potentially corrupted files # guard against potentially corrupted files
if video_fps is not None: if video_fps is not None:
info["video_fps"] = float(video_fps) info["video_fps"] = float(video_fps)
if container.streams.audio: if container.streams.audio:
audio_frames = _read_from_stream(container, start_pts, end_pts, pts_unit, audio_frames = _read_from_stream(
container.streams.audio[0], {'audio': 0}) container,
start_pts,
end_pts,
pts_unit,
container.streams.audio[0],
{"audio": 0},
)
info["audio_fps"] = container.streams.audio[0].rate info["audio_fps"] = container.streams.audio[0].rate
container.close() container.close()
...@@ -272,7 +288,7 @@ def _can_read_timestamps_from_packets(container): ...@@ -272,7 +288,7 @@ def _can_read_timestamps_from_packets(container):
return False return False
def read_video_timestamps(filename, pts_unit='pts'): def read_video_timestamps(filename, pts_unit="pts"):
""" """
List the video frames timestamps. List the video frames timestamps.
...@@ -295,6 +311,7 @@ def read_video_timestamps(filename, pts_unit='pts'): ...@@ -295,6 +311,7 @@ def read_video_timestamps(filename, pts_unit='pts'):
""" """
from torchvision import get_video_backend from torchvision import get_video_backend
if get_video_backend() != "pyav": if get_video_backend() != "pyav":
return _video_opt._read_video_timestamps(filename, pts_unit) return _video_opt._read_video_timestamps(filename, pts_unit)
...@@ -304,7 +321,7 @@ def read_video_timestamps(filename, pts_unit='pts'): ...@@ -304,7 +321,7 @@ def read_video_timestamps(filename, pts_unit='pts'):
video_fps = None video_fps = None
try: try:
container = av.open(filename, metadata_errors='ignore') container = av.open(filename, metadata_errors="ignore")
except av.AVError: except av.AVError:
# TODO add a warning # TODO add a warning
pass pass
...@@ -314,16 +331,61 @@ def read_video_timestamps(filename, pts_unit='pts'): ...@@ -314,16 +331,61 @@ def read_video_timestamps(filename, pts_unit='pts'):
video_time_base = video_stream.time_base video_time_base = video_stream.time_base
if _can_read_timestamps_from_packets(container): if _can_read_timestamps_from_packets(container):
# fast path # fast path
video_frames = [x for x in container.demux(video=0) if x.pts is not None] video_frames = [
x for x in container.demux(video=0) if x.pts is not None
]
else: else:
video_frames = _read_from_stream(container, 0, float("inf"), pts_unit, video_frames = _read_from_stream(
video_stream, {'video': 0}) container, 0, float("inf"), pts_unit, video_stream, {"video": 0}
)
video_fps = float(video_stream.average_rate) video_fps = float(video_stream.average_rate)
container.close() container.close()
pts = [x.pts for x in video_frames] pts = [x.pts for x in video_frames]
if pts_unit == 'sec': if pts_unit == "sec":
pts = [x * video_time_base for x in pts] pts = [x * video_time_base for x in pts]
return pts, video_fps return pts, video_fps
def read_video_meta_data_from_memory(video_data):
# type: (torch.Tensor) -> VideoMetaData
return _video_opt._probe_video_from_memory(video_data)
def read_video_from_memory(
video_data, # type: torch.Tensor
seek_frame_margin=0.25, # type: float
read_video_stream=1, # type: int
video_width=0, # type: int
video_height=0, # type: int
video_min_dimension=0, # type: int
video_pts_range=(0, -1), # type: List[int]
video_timebase_numerator=0, # type: int
video_timebase_denominator=1, # type: int
read_audio_stream=1, # type: int
audio_samples=0, # type: int
audio_channels=0, # type: int
audio_pts_range=(0, -1), # type: List[int]
audio_timebase_numerator=0, # type: int
audio_timebase_denominator=1, # type: int
):
# type: (...) -> Tuple[torch.Tensor, torch.Tensor]
return _video_opt._read_video_from_memory(
video_data,
seek_frame_margin,
read_video_stream,
video_width,
video_height,
video_min_dimension,
video_pts_range,
video_timebase_numerator,
video_timebase_denominator,
read_audio_stream,
audio_samples,
audio_channels,
audio_pts_range,
audio_timebase_numerator,
audio_timebase_denominator,
)
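Because both wrappers are free functions with TorchScript-compatible type comments, a scripted version can also be saved and reloaded without the Python sources for these helpers, which is the practical payoff of this change. A small sketch (the output path is hypothetical):

import torch
import torchvision.io as io

scripted = torch.jit.script(io.read_video_meta_data_from_memory)
scripted.save("read_video_meta_data_from_memory.pt")
# Later the serialized function can be restored with torch.jit.load; the
# video_reader C++ extension still needs to be available at load time.
reloaded = torch.jit.load("read_video_meta_data_from_memory.pt")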