Unverified Commit e130c6cc authored by Francisco Massa, committed by GitHub

torchscriptable functions for video io (#1653) (#1794)

* torchscriptable functions for video io (#1653)

Summary:
Pull Request resolved: https://github.com/pytorch/vision/pull/1653



Created new torchscriptable video io functions as part of the API: read_video_meta_data_from_memory and read_video_from_memory.

Updated the implementation of some of the internal functions to be torchscriptable.

Reviewed By: stephenyan1231

Differential Revision: D18720474

fbshipit-source-id: 4ee646b66afecd2dc338a71fd8f249f25a3263bc

* BugFix
Co-authored-by: Jon Guerin <54725679+jguerin-fb@users.noreply.github.com>
parent 28b7f8ae
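For orientation, the new entry points can be exercised as plain functions and compiled with torch.jit.script. A minimal sketch (not part of the commit; it assumes a local file example.mp4 and a torchvision build with the video_reader extension):

import numpy as np
import torch
import torchvision.io as io

# The API expects the encoded bytes as a 1-D uint8 tensor.
with open("example.mp4", "rb") as f:
    data = torch.from_numpy(np.frombuffer(f.read(), dtype=np.uint8))

# Scripting works because the new functions avoid Fraction and dict
# in their signatures (see the Timebase/VideoMetaData classes below).
scripted_probe = torch.jit.script(io.read_video_meta_data_from_memory)
meta = scripted_probe(data)
print(meta.video_fps, meta.video_duration, meta.has_audio)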
import contextlib
import sys
import os
import torch
import unittest
......@@ -92,7 +91,6 @@ class Tester(unittest.TestCase):
video, audio, info, video_idx = video_clips.get_clip(i)
self.assertEqual(video.shape[0], num_frames)
self.assertEqual(info["video_fps"], fps)
# TODO add tests checking that the content is right
def test_compute_clips_for_video(self):
......
......@@ -81,8 +81,8 @@ class Tester(unittest.TestCase):
def test_probe_video_from_file(self):
with temp_video(10, 300, 300, 5) as (f_name, data):
video_info = io._probe_video_from_file(f_name)
self.assertAlmostEqual(video_info.video_duration, 2, delta=0.1)
self.assertAlmostEqual(video_info.video_fps, 5, delta=0.1)
@unittest.skipIf(not io._HAS_VIDEO_OPT, "video_reader backend is not chosen")
def test_probe_video_from_memory(self):
......@@ -90,8 +90,8 @@ class Tester(unittest.TestCase):
with open(f_name, "rb") as fp:
filebuffer = fp.read()
video_info = io._probe_video_from_memory(filebuffer)
self.assertAlmostEqual(video_info.video_duration, 2, delta=0.1)
self.assertAlmostEqual(video_info.video_fps, 5, delta=0.1)
def test_read_timestamps(self):
with temp_video(10, 300, 300, 5) as (f_name, data):
......
import collections
from common_utils import get_tmp_dir
from fractions import Fraction
import math
import numpy as np
import os
import sys
import time
import unittest
import torch
import torchvision.io as io
from numpy.random import randint
from torchvision.io import _HAS_VIDEO_OPT
try:
import av
# Do a version test too
io.video._check_av_available()
except ImportError:
......@@ -25,9 +28,6 @@ else:
from urllib.error import URLError
VIDEO_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "videos")
CheckerConfig = [
......@@ -39,10 +39,7 @@ CheckerConfig = [
"check_aframes",
"check_aframe_pts",
]
GroundTruth = collections.namedtuple("GroundTruth", " ".join(CheckerConfig))
all_check_config = GroundTruth(
duration=0,
......@@ -193,9 +190,9 @@ def _decode_frames_by_av_module(
frames are read
"""
if video_end_pts is None:
video_end_pts = float("inf")
if audio_end_pts is None:
audio_end_pts = float("inf")
container = av.open(full_path)
video_frames = []
......@@ -282,8 +279,10 @@ class TestVideoReader(unittest.TestCase):
def check_separate_decoding_result(self, tv_result, config):
"""check the decoding results from TorchVision decoder
"""
vframes, vframe_pts, vtimebase, vfps, vduration, \
aframes, aframe_pts, atimebase, asample_rate, aduration = (
tv_result
)
video_duration = vduration.item() * Fraction(
vtimebase[0].item(), vtimebase[1].item()
......@@ -321,6 +320,13 @@ class TestVideoReader(unittest.TestCase):
)
self.assertAlmostEqual(audio_duration, config.duration, delta=0.5)
def check_meta_result(self, result, config):
self.assertAlmostEqual(result.video_duration, config.duration, delta=0.5)
self.assertAlmostEqual(result.video_fps, config.video_fps, delta=0.5)
if result.has_audio > 0:
self.assertEqual(result.audio_sample_rate, config.audio_sample_rate)
self.assertAlmostEqual(result.audio_duration, config.duration, delta=0.5)
def compare_decoding_result(self, tv_result, ref_result, config=all_check_config):
"""
Compare decoding results from two sources.
......@@ -330,8 +336,10 @@ class TestVideoReader(unittest.TestCase):
decoder or TorchVision decoder with getPtsOnly = 1
config: config of decoding results checker
"""
vframes, vframe_pts, vtimebase, _vfps, _vduration, \
aframes, aframe_pts, atimebase, _asample_rate, _aduration = (
tv_result
)
if isinstance(ref_result, list):
# the ref_result is from new video_reader decoder
ref_result = DecoderResult(
......@@ -344,22 +352,34 @@ class TestVideoReader(unittest.TestCase):
)
if vframes.numel() > 0 and ref_result.vframes.numel() > 0:
mean_delta = torch.mean(
torch.abs(vframes.float() - ref_result.vframes.float())
)
self.assertAlmostEqual(mean_delta, 0, delta=8.0)
mean_delta = torch.mean(
torch.abs(vframe_pts.float() - ref_result.vframe_pts.float())
)
self.assertAlmostEqual(mean_delta, 0, delta=1.0)
is_same = torch.all(torch.eq(vtimebase, ref_result.vtimebase)).item()
self.assertEqual(is_same, True)
if (
config.check_aframes
and aframes.numel() > 0
and ref_result.aframes.numel() > 0
):
"""Audio stream is available and audio frame is required to return
from decoder"""
is_same = torch.all(torch.eq(aframes, ref_result.aframes)).item()
self.assertEqual(is_same, True)
if (
config.check_aframe_pts
and aframe_pts.numel() > 0
and ref_result.aframe_pts.numel() > 0
):
"""Audio stream is available"""
is_same = torch.all(torch.eq(aframe_pts, ref_result.aframe_pts)).item()
self.assertEqual(is_same, True)
......@@ -492,15 +512,19 @@ class TestVideoReader(unittest.TestCase):
audio_timebase_den,
)
vframes, vframe_pts, vtimebase, vfps, vduration, \
aframes, aframe_pts, atimebase, asample_rate, aduration = (
tv_result
)
self.assertEqual(vframes.numel() > 0, readVideoStream)
self.assertEqual(vframe_pts.numel() > 0, readVideoStream)
self.assertEqual(vtimebase.numel() > 0, readVideoStream)
self.assertEqual(vfps.numel() > 0, readVideoStream)
expect_audio_data = (
readAudioStream == 1 and config.audio_sample_rate is not None
)
self.assertEqual(aframes.numel() > 0, expect_audio_data)
self.assertEqual(aframe_pts.numel() > 0, expect_audio_data)
self.assertEqual(atimebase.numel() > 0, expect_audio_data)
......@@ -543,7 +567,9 @@ class TestVideoReader(unittest.TestCase):
audio_timebase_num,
audio_timebase_den,
)
self.assertEqual(
min_dimension, min(tv_result[0].size(1), tv_result[0].size(2))
)
def test_read_video_from_file_rescale_width(self):
"""
......@@ -669,10 +695,7 @@ class TestVideoReader(unittest.TestCase):
audio waveform are resampled
"""
for samples in [9600, 96000]:  # 9600: downsampling; 96000: upsampling
# video related
width, height, min_dimension = 0, 0, 0
video_start_pts, video_end_pts = 0, -1
......@@ -705,13 +728,19 @@ class TestVideoReader(unittest.TestCase):
audio_timebase_num,
audio_timebase_den,
)
vframes, vframe_pts, vtimebase, vfps, vduration, \
aframes, aframe_pts, atimebase, asample_rate, aduration = (
tv_result
)
if aframes.numel() > 0:
self.assertEqual(samples, asample_rate.item())
self.assertEqual(1, aframes.size(1))
# when audio stream is found
duration = (
float(aframe_pts[-1])
* float(atimebase[0])
/ float(atimebase[1])
)
self.assertAlmostEqual(
aframes.size(0),
int(duration * asample_rate.item()),
......@@ -929,8 +958,10 @@ class TestVideoReader(unittest.TestCase):
audio_timebase_num,
audio_timebase_den,
)
vframes, vframe_pts, vtimebase, vfps, vduration, \
aframes, aframe_pts, atimebase, asample_rate, aduration = (
tv_result
)
self.assertAlmostEqual(config.video_fps, vfps.item(), delta=0.01)
for num_frames in [4, 8, 16, 32, 64, 128]:
......@@ -983,31 +1014,41 @@ class TestVideoReader(unittest.TestCase):
)
# pass 3: decode frames in range using PyAv
video_timebase_av, audio_timebase_av = _get_timebase_by_av_module(
full_path
)
video_start_pts_av = _pts_convert(
video_start_pts.item(),
Fraction(video_timebase_num.item(), video_timebase_den.item()),
Fraction(
video_timebase_av.numerator, video_timebase_av.denominator
),
math.floor,
)
video_end_pts_av = _pts_convert(
video_end_pts.item(),
Fraction(video_timebase_num.item(), video_timebase_den.item()),
Fraction(
video_timebase_av.numerator, video_timebase_av.denominator
),
math.ceil,
)
if audio_timebase_av:
audio_start_pts = _pts_convert(
video_start_pts.item(),
Fraction(video_timebase_num.item(), video_timebase_den.item()),
Fraction(
audio_timebase_av.numerator, audio_timebase_av.denominator
),
math.floor,
)
audio_end_pts = _pts_convert(
video_end_pts.item(),
Fraction(video_timebase_num.item(), video_timebase_den.item()),
Fraction(
audio_timebase_av.numerator, audio_timebase_av.denominator
),
math.ceil,
)
......@@ -1044,6 +1085,54 @@ class TestVideoReader(unittest.TestCase):
probe_result = torch.ops.video_reader.probe_video_from_memory(video_tensor)
self.check_probe_result(probe_result, config)
def test_read_video_meta_data_from_memory_script(self):
scripted_fun = torch.jit.script(io.read_video_meta_data_from_memory)
self.assertIsNotNone(scripted_fun)
for test_video, config in test_videos.items():
full_path, video_tensor = _get_video_tensor(VIDEO_DIR, test_video)
probe_result = scripted_fun(video_tensor)
self.check_meta_result(probe_result, config)
def test_read_video_from_memory_scripted(self):
"""
Test the case when video is already in memory, and decoder reads data in memory
"""
# video related
width, height, min_dimension = 0, 0, 0
video_start_pts, video_end_pts = 0, -1
video_timebase_num, video_timebase_den = 0, 1
# audio related
samples, channels = 0, 0
audio_start_pts, audio_end_pts = 0, -1
audio_timebase_num, audio_timebase_den = 0, 1
scripted_fun = torch.jit.script(io.read_video_from_memory)
self.assertIsNotNone(scripted_fun)
for test_video, _config in test_videos.items():
full_path, video_tensor = _get_video_tensor(VIDEO_DIR, test_video)
# decode all frames using cpp decoder
scripted_fun(
video_tensor,
seek_frame_margin,
1, # readVideoStream
width,
height,
min_dimension,
[video_start_pts, video_end_pts],
video_timebase_num,
video_timebase_den,
1, # readAudioStream
samples,
channels,
[audio_start_pts, audio_end_pts],
audio_timebase_num,
audio_timebase_den,
)
# FUTURE: check value of video / audio frames
if __name__ == "__main__":
unittest.main()
import bisect
from fractions import Fraction
import math
import torch
from torchvision.io import (
_read_video_timestamps_from_file,
_probe_video_from_file,
_read_video_from_file,
read_video,
read_video_timestamps,
)
from .utils import tqdm
......@@ -48,6 +50,7 @@ class _DummyDataset(object):
Dummy dataset used for DataLoader in VideoClips.
Defined at top level so it can be pickled when forking.
"""
def __init__(self, x):
self.x = x
......@@ -83,10 +86,21 @@ class VideoClips(object):
num_workers (int): how many subprocesses to use for data loading.
0 means that the data will be loaded in the main process. (default: 0)
"""
def __init__(
self,
video_paths,
clip_length_in_frames=16,
frames_between_clips=1,
frame_rate=None,
_precomputed_metadata=None,
num_workers=0,
_video_width=0,
_video_height=0,
_video_min_dimension=0,
_audio_samples=0,
_audio_channels=0,
):
self.video_paths = video_paths
self.num_workers = num_workers
......@@ -114,11 +128,13 @@ class VideoClips(object):
# strategy: use a DataLoader to parallelize read_video_timestamps
# so need to create a dummy dataset first
import torch.utils.data
dl = torch.utils.data.DataLoader(
_DummyDataset(self.video_paths),
batch_size=16,
num_workers=self.num_workers,
collate_fn=self._collate_fn,
)
with tqdm(total=len(dl)) as pbar:
for batch in dl:
......@@ -140,7 +156,7 @@ class VideoClips(object):
_metadata = {
"video_paths": self.video_paths,
"video_pts": self.video_pts,
"video_fps": self.video_fps
"video_fps": self.video_fps,
}
return _metadata
......@@ -151,15 +167,21 @@ class VideoClips(object):
metadata = {
"video_paths": video_paths,
"video_pts": video_pts,
"video_fps": video_fps
"video_fps": video_fps,
}
return type(self)(
video_paths,
self.num_frames,
self.step,
self.frame_rate,
_precomputed_metadata=metadata,
num_workers=self.num_workers,
_video_width=self._video_width,
_video_height=self._video_height,
_video_min_dimension=self._video_min_dimension,
_audio_samples=self._audio_samples,
_audio_channels=self._audio_channels,
)
@staticmethod
def compute_clips_for_video(video_pts, num_frames, step, fps, frame_rate):
......@@ -170,7 +192,9 @@ class VideoClips(object):
if frame_rate is None:
frame_rate = fps
total_frames = len(video_pts) * (float(frame_rate) / fps)
idxs = VideoClips._resample_video_idx(
int(math.floor(total_frames)), fps, frame_rate
)
video_pts = video_pts[idxs]
clips = unfold(video_pts, num_frames, step)
if isinstance(idxs, slice):
......@@ -195,7 +219,9 @@ class VideoClips(object):
self.clips = []
self.resampling_idxs = []
for video_pts, fps in zip(self.video_pts, self.video_fps):
clips, idxs = self.compute_clips_for_video(
video_pts, num_frames, step, fps, frame_rate
)
self.clips.append(clips)
self.resampling_idxs.append(idxs)
clip_lengths = torch.as_tensor([len(v) for v in self.clips])
......@@ -251,13 +277,16 @@ class VideoClips(object):
video_idx (int): index of the video in `video_paths`
"""
if idx >= self.num_clips():
raise IndexError("Index {} out of range "
"({} number of clips)".format(idx, self.num_clips()))
raise IndexError(
"Index {} out of range "
"({} number of clips)".format(idx, self.num_clips())
)
video_idx, clip_idx = self.get_clip_location(idx)
video_path = self.video_paths[video_idx]
clip_pts = self.clips[video_idx][clip_idx]
from torchvision import get_video_backend
backend = get_video_backend()
if backend == "pyav":
......@@ -267,7 +296,9 @@ class VideoClips(object):
if self._video_height != 0:
raise ValueError("pyav backend doesn't support _video_height != 0")
if self._video_min_dimension != 0:
raise ValueError("pyav backend doesn't support _video_min_dimension != 0")
raise ValueError(
"pyav backend doesn't support _video_min_dimension != 0"
)
if self._audio_samples != 0:
raise ValueError("pyav backend doesn't support _audio_samples != 0")
......@@ -277,7 +308,7 @@ class VideoClips(object):
video, audio, info = read_video(video_path, start_pts, end_pts)
else:
info = _probe_video_from_file(video_path)
video_fps = info["video_fps"]
video_fps = info.video_fps
audio_fps = None
video_start_pts = clip_pts[0].item()
......@@ -285,28 +316,27 @@ class VideoClips(object):
audio_start_pts, audio_end_pts = 0, -1
audio_timebase = Fraction(0, 1)
if "audio_timebase" in info:
audio_timebase = info["audio_timebase"]
video_timebase = Fraction(
info.video_timebase.numerator, info.video_timebase.denominator
)
if info.has_audio:
audio_timebase = Fraction(
info.audio_timebase.numerator, info.audio_timebase.denominator
)
audio_start_pts = pts_convert(
video_start_pts, video_timebase, audio_timebase, math.floor
)
audio_end_pts = pts_convert(
video_end_pts, video_timebase, audio_timebase, math.ceil
)
audio_fps = info["audio_sample_rate"]
audio_fps = info.audio_sample_rate
video, audio, info = _read_video_from_file(
video_path,
video_width=self._video_width,
video_height=self._video_height,
video_min_dimension=self._video_min_dimension,
video_pts_range=(video_start_pts, video_end_pts),
video_timebase=info["video_timebase"],
video_timebase=video_timebase,
audio_samples=self._audio_samples,
audio_channels=self._audio_channels,
audio_pts_range=(audio_start_pts, audio_end_pts),
......@@ -323,5 +353,7 @@ class VideoClips(object):
resampling_idx = resampling_idx - resampling_idx[0]
video = video[resampling_idx]
info["video_fps"] = self.frame_rate
assert len(video) == self.num_frames, "{} x {}".format(video.shape, self.num_frames)
assert len(video) == self.num_frames, "{} x {}".format(
video.shape, self.num_frames
)
return video, audio, info, video_idx
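For context, a minimal VideoClips round trip looks like this (an illustrative sketch, not part of the commit; the file names are placeholders):

from torchvision.datasets.video_utils import VideoClips

# Cut each video into 16-frame clips, one clip starting every 8 frames.
clips = VideoClips(["a.mp4", "b.mp4"], clip_length_in_frames=16, frames_between_clips=8)
video, audio, info, video_idx = clips.get_clip(0)
# video: [16, H, W, C] uint8 tensor; video_idx: index into the video_paths list.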
from ._video_opt import (
Timebase,
VideoMetaData,
_HAS_VIDEO_OPT,
_probe_video_from_file,
_probe_video_from_memory,
_read_video_from_file,
_read_video_from_memory,
_read_video_timestamps_from_file,
_read_video_timestamps_from_memory,
)
from .video import (
read_video,
read_video_from_memory,
read_video_meta_data_from_memory,
read_video_timestamps,
write_video,
)
__all__ = [
"write_video",
"read_video",
"read_video_timestamps",
"read_video_meta_data_from_memory",
"read_video_from_memory",
"_read_video_from_file",
"_read_video_timestamps_from_file",
"_probe_video_from_file",
"_read_video_from_memory",
"_read_video_timestamps_from_memory",
"_probe_video_from_memory",
"_HAS_VIDEO_OPT",
"_read_video_clip_from_memory",
"_read_video_meta_data",
"VideoMetaData",
"Timebase",
]
from fractions import Fraction
import imp
import math
import os
import warnings
from typing import List, Tuple
import numpy as np
import torch
_HAS_VIDEO_OPT = False
try:
lib_dir = os.path.join(os.path.dirname(__file__), "..")
_, path, description = imp.find_module("video_reader", [lib_dir])
torch.ops.load_library(path)
_HAS_VIDEO_OPT = True
except (ImportError, OSError):
pass
default_timebase = Fraction(0, 1)
# simple class for torch scripting
# the complex Fraction class from fractions module is not scriptable
@torch.jit.script
class Timebase(object):
__annotations__ = {"numerator": int, "denominator": int}
__slots__ = ["numerator", "denominator"]
def __init__(
self,
numerator, # type: int
denominator, # type: int
):
# type: (...) -> None
self.numerator = numerator
self.denominator = denominator
@torch.jit.script
class VideoMetaData(object):
__annotations__ = {
"has_video": bool,
"video_timebase": Timebase,
"video_duration": float,
"video_fps": float,
"has_audio": bool,
"audio_timebase": Timebase,
"audio_duration": float,
"audio_sample_rate": float,
}
__slots__ = [
"has_video",
"video_timebase",
"video_duration",
"video_fps",
"has_audio",
"audio_timebase",
"audio_duration",
"audio_sample_rate",
]
def __init__(self):
self.has_video = False
self.video_timebase = Timebase(0, 1)
self.video_duration = 0.0
self.video_fps = 0.0
self.has_audio = False
self.audio_timebase = Timebase(0, 1)
self.audio_duration = 0.0
self.audio_sample_rate = 0.0
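Because fractions.Fraction is not scriptable, eager-mode callers convert the Timebase back when they need exact rational arithmetic — a minimal sketch with a hypothetical helper:

from fractions import Fraction

def timebase_to_fraction(tb):
    # Timebase stores plain ints, so this conversion is exact.
    return Fraction(tb.numerator, tb.denominator)

# One pts tick in a 1/30000 timebase lasts 1/30000 of a second.
assert timebase_to_fraction(Timebase(1, 30000)) == Fraction(1, 30000)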
def _validate_pts(pts_range):
# type: (List[int]) -> None
if pts_range[1] > 0:
assert (
pts_range[0] <= pts_range[1]
), """Start pts should not be larger than end pts, got
start pts: %d and end pts: %d""" % (
pts_range[0],
pts_range[1],
)
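Concretely (illustrative calls only): an end pts of -1 is the sentinel for "decode to the end of the stream", so only explicitly bounded ranges are checked:

_validate_pts([0, -1])    # ok: -1 disables the upper bound
_validate_pts([10, 100])  # ok: start <= end
_validate_pts([100, 10])  # AssertionError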
def _fill_info(vtimebase, vfps, vduration, atimebase, asample_rate, aduration):
# type: (torch.Tensor,torch.Tensor,torch.Tensor,torch.Tensor,torch.Tensor,torch.Tensor) -> VideoMetaData
"""
Build a VideoMetaData struct with info about the video
"""
meta = VideoMetaData()
if vtimebase.numel() > 0:
meta.video_timebase = Timebase(
int(vtimebase[0].item()), int(vtimebase[1].item())
)
timebase = vtimebase[0].item() / float(vtimebase[1].item())
if vduration.numel() > 0:
meta.has_video = True
meta.video_duration = float(vduration.item()) * timebase
if vfps.numel() > 0:
meta.video_fps = float(vfps.item())
if atimebase.numel() > 0:
meta.audio_timebase = Timebase(
int(atimebase[0].item()), int(atimebase[1].item())
)
timebase = atimebase[0].item() / float(atimebase[1].item())
if aduration.numel() > 0:
meta.has_audio = True
meta.audio_duration = float(aduration.item()) * timebase
if asample_rate.numel() > 0:
meta.audio_sample_rate = float(asample_rate.item())
return meta
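A worked example of the arithmetic above, with made-up numbers: a duration of 25600 pts ticks in a 1/12800 timebase is 25600 * (1 / 12800) = 2.0 seconds:

import torch

vtimebase = torch.tensor([1, 12800])  # numerator, denominator
vfps = torch.tensor([30.0])
vduration = torch.tensor([25600])
empty = torch.tensor([])  # no audio stream probed
meta = _fill_info(vtimebase, vfps, vduration, empty, empty, empty)
assert meta.video_duration == 2.0 and meta.video_fps == 30.0 and not meta.has_audio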
def _align_audio_frames(aframes, aframe_pts, audio_pts_range):
# type: (torch.Tensor, torch.Tensor, List[int]) -> torch.Tensor
start, end = aframe_pts[0], aframe_pts[-1]
num_samples = aframes.size(0)
step_per_aframe = float(end - start + 1) / float(num_samples)
......@@ -136,8 +219,10 @@ def _read_video_from_file(
audio_timebase.numerator,
audio_timebase.denominator,
)
vframes, _vframe_pts, vtimebase, vfps, vduration, \
aframes, aframe_pts, atimebase, asample_rate, aduration = (
result
)
info = _fill_info(vtimebase, vfps, vduration, atimebase, asample_rate, aduration)
if aframes.numel() > 0:
# when audio stream is found
......@@ -171,8 +256,8 @@ def _read_video_timestamps_from_file(filename):
0, # audio_timebase_num
1, # audio_timebase_den
)
_vframes, vframe_pts, vtimebase, vfps, vduration, \
_aframes, aframe_pts, atimebase, asample_rate, aduration = (result)
info = _fill_info(vtimebase, vfps, vduration, atimebase, asample_rate, aduration)
vframe_pts = vframe_pts.numpy().tolist()
......@@ -182,10 +267,7 @@ def _read_video_timestamps_from_file(filename):
def _probe_video_from_file(filename):
"""
Probe a video file and return VideoMetaData with info about the video.
"""
result = torch.ops.video_reader.probe_video_from_file(filename)
vtimebase, vfps, vduration, atimebase, asample_rate, aduration = result
......@@ -194,23 +276,27 @@ def _probe_video_from_file(filename):
def _read_video_from_memory(
video_data, # type: torch.Tensor
seek_frame_margin=0.25, # type: float
read_video_stream=1, # type: int
video_width=0, # type: int
video_height=0, # type: int
video_min_dimension=0, # type: int
video_pts_range=(0, -1), # type: List[int]
video_timebase_numerator=0, # type: int
video_timebase_denominator=1, # type: int
read_audio_stream=1, # type: int
audio_samples=0, # type: int
audio_channels=0, # type: int
audio_pts_range=(0, -1), # type: List[int]
audio_timebase_numerator=0, # type: int
audio_timebase_denominator=1, # type: int
):
# type: (...) -> Tuple[torch.Tensor, torch.Tensor]
"""
Reads a video from memory, returning both the video frames as well as
the audio frames
This function is torchscriptable.
Args
----------
......@@ -234,8 +320,8 @@ def _read_video_from_memory(
are set to $video_width and $video_height, respectively
video_pts_range : list(int), optional
the start and end presentation timestamp of video stream
video_timebase_numerator / video_timebase_denominator: optional
a rational number which denotes timebase in video stream
read_audio_stream: int, optional
whether read audio stream. If yes, set to 1. Otherwise, 0
audio_samples: int, optional
......@@ -244,8 +330,8 @@ def _read_video_from_memory(
audio audio_channels
audio_pts_range : list(int), optional
the start and end presentation timestamp of audio stream
audio_timebase_numerator / audio_timebase_denominator: optional
a rational number which denotes time base in audio stream
Returns
-------
......@@ -254,17 +340,11 @@ def _read_video_from_memory(
aframes : Tensor[L, K]
the audio frames, where `L` is the number of points and
`K` is the number of channels
"""
_validate_pts(video_pts_range)
_validate_pts(audio_pts_range)
result = torch.ops.video_reader.read_video_from_memory(
video_data,
seek_frame_margin,
......@@ -275,24 +355,27 @@ def _read_video_from_memory(
video_min_dimension,
video_pts_range[0],
video_pts_range[1],
video_timebase_numerator,
video_timebase_denominator,
read_audio_stream,
audio_samples,
audio_channels,
audio_pts_range[0],
audio_pts_range[1],
audio_timebase_numerator,
audio_timebase_denominator,
)
vframes, _vframe_pts, vtimebase, vfps, vduration, \
aframes, aframe_pts, atimebase, asample_rate, aduration = (
result
)
if aframes.numel() > 0:
# when audio stream is found
aframes = _align_audio_frames(aframes, aframe_pts, audio_pts_range)
return vframes, aframes
def _read_video_timestamps_from_memory(video_data):
......@@ -323,8 +406,10 @@ def _read_video_timestamps_from_memory(video_data):
0, # audio_timebase_num
1, # audio_timebase_den
)
_vframes, vframe_pts, vtimebase, vfps, vduration, \
_aframes, aframe_pts, atimebase, asample_rate, aduration = (
result
)
info = _fill_info(vtimebase, vfps, vduration, atimebase, asample_rate, aduration)
vframe_pts = vframe_pts.numpy().tolist()
......@@ -333,11 +418,10 @@ def _read_video_timestamps_from_memory(video_data):
def _probe_video_from_memory(video_data):
# type: (torch.Tensor) -> VideoMetaData
"""
Probe a video in memory and return VideoMetaData with info about the video.
This function is torchscriptable.
"""
......@@ -347,23 +431,25 @@ def _probe_video_from_memory(video_data):
return info
def _read_video(filename, start_pts=0, end_pts=None, pts_unit="pts"):
if end_pts is None:
end_pts = float("inf")
if pts_unit == "pts":
warnings.warn(
"The pts_unit 'pts' gives wrong results and will be removed in a "
+ "follow-up version. Please use pts_unit 'sec'."
)
info = _probe_video_from_file(filename)
has_video = info.has_video
has_audio = info.has_audio
def get_pts(time_base):
start_offset = start_pts
end_offset = end_pts
if pts_unit == "sec":
start_offset = int(math.floor(start_pts * (1 / time_base)))
if end_offset != float("inf"):
end_offset = int(math.ceil(end_pts * (1 / time_base)))
......@@ -374,13 +460,17 @@ def _read_video(filename, start_pts=0, end_pts=None, pts_unit='pts'):
video_pts_range = (0, -1)
video_timebase = default_timebase
if has_video:
video_timebase = Fraction(
info.video_timebase.numerator, info.video_timebase.denominator
)
video_pts_range = get_pts(video_timebase)
audio_pts_range = (0, -1)
audio_timebase = default_timebase
if has_audio:
audio_timebase = Fraction(
info.audio_timebase.numerator, info.audio_timebase.denominator
)
audio_pts_range = get_pts(audio_timebase)
vframes, aframes, info = _read_video_from_file(
......@@ -394,24 +484,28 @@ def _read_video(filename, start_pts=0, end_pts=None, pts_unit='pts'):
)
_info = {}
if has_video:
_info["video_fps"] = info.video_fps
if has_audio:
_info["audio_fps"] = info.audio_sample_rate
return vframes, aframes, _info
def _read_video_timestamps(filename, pts_unit="pts"):
if pts_unit == "pts":
warnings.warn(
"The pts_unit 'pts' gives wrong results and will be removed in a "
+ "follow-up version. Please use pts_unit 'sec'."
)
pts, _, info = _read_video_timestamps_from_file(filename)
if pts_unit == "sec":
video_time_base = Fraction(
info.video_timebase.numerator, info.video_timebase.denominator
)
pts = [x * video_time_base for x in pts]
video_fps = info.video_fps if info.has_video else None
return pts, video_fps
import re
import gc
import math
import os
import warnings
from typing import Tuple, List
import numpy as np
import torch
from . import _video_opt
from ._video_opt import VideoMetaData
try:
import av
av.logging.set_level(av.logging.ERROR)
if not hasattr(av.video.frame.VideoFrame, "pict_type"):
av = ImportError(
"""\
Your version of PyAV is too old for the necessary video operations in torchvision.
If you are on Python 3.5, you will have to build from source (the conda-forge
packages are not up-to-date). See
https://github.com/mikeboers/PyAV#installation for instructions on how to
install PyAV on your system.
""")
"""
)
except ImportError:
av = ImportError("""\
av = ImportError(
"""\
PyAV is not installed, and is necessary for the video operations in torchvision.
See https://github.com/mikeboers/PyAV#installation for instructions on how to
install PyAV on your system.
""")
"""
)
def _check_av_available():
......@@ -54,7 +49,7 @@ _CALLED_TIMES = 0
_GC_COLLECTION_INTERVAL = 10
def write_video(filename, video_array, fps, video_codec="libx264", options=None):
"""
Writes a 4d tensor in [T, H, W, C] format in a video file
......@@ -70,17 +65,17 @@ def write_video(filename, video_array, fps, video_codec='libx264', options=None)
_check_av_available()
video_array = torch.as_tensor(video_array, dtype=torch.uint8).numpy()
container = av.open(filename, mode="w")
stream = container.add_stream(video_codec, rate=fps)
stream.width = video_array.shape[2]
stream.height = video_array.shape[1]
stream.pix_fmt = "yuv420p" if video_codec != "libx264rgb" else "rgb24"
stream.options = options or {}
for img in video_array:
frame = av.VideoFrame.from_ndarray(img, format="rgb24")
frame.pict_type = "NONE"
for packet in stream.encode(frame):
container.mux(packet)
......@@ -92,19 +87,23 @@ def write_video(filename, video_array, fps, video_codec='libx264', options=None)
container.close()
def _read_from_stream(
container, start_offset, end_offset, pts_unit, stream, stream_name
):
global _CALLED_TIMES, _GC_COLLECTION_INTERVAL
_CALLED_TIMES += 1
if _CALLED_TIMES % _GC_COLLECTION_INTERVAL == _GC_COLLECTION_INTERVAL - 1:
gc.collect()
if pts_unit == "sec":
start_offset = int(math.floor(start_offset * (1 / stream.time_base)))
if end_offset != float("inf"):
end_offset = int(math.ceil(end_offset * (1 / stream.time_base)))
else:
warnings.warn("The pts_unit 'pts' gives wrong results and will be removed in a " +
"follow-up version. Please use pts_unit 'sec'.")
warnings.warn(
"The pts_unit 'pts' gives wrong results and will be removed in a "
+ "follow-up version. Please use pts_unit 'sec'."
)
frames = {}
should_buffer = False
......@@ -141,7 +140,7 @@ def _read_from_stream(container, start_offset, end_offset, pts_unit, stream, str
return []
buffer_count = 0
try:
for _idx, frame in enumerate(container.decode(**stream_name)):
frames[frame.pts] = frame
if frame.pts >= end_offset:
if should_buffer and buffer_count < max_buffer_size:
......@@ -152,7 +151,9 @@ def _read_from_stream(container, start_offset, end_offset, pts_unit, stream, str
# TODO add a warning
pass
# ensure that the results are sorted wrt the pts
result = [
frames[i] for i in sorted(frames) if start_offset <= frames[i].pts <= end_offset
]
if len(frames) > 0 and start_offset > 0 and start_offset not in frames:
# if there is no frame that exactly matches the pts of start_offset
# add the last frame smaller than start_offset, to guarantee that
......@@ -177,7 +178,7 @@ def _align_audio_frames(aframes, audio_frames, ref_start, ref_end):
return aframes[:, s_idx:e_idx]
def read_video(filename, start_pts=0, end_pts=None, pts_unit="pts"):
"""
Reads a video from a file, returning both the video frames as well as
the audio frames
......@@ -208,6 +209,7 @@ def read_video(filename, start_pts=0, end_pts=None, pts_unit='pts'):
"""
from torchvision import get_video_backend
if get_video_backend() != "pyav":
return _video_opt._read_video(filename, start_pts, end_pts, pts_unit)
......@@ -217,30 +219,44 @@ def read_video(filename, start_pts=0, end_pts=None, pts_unit='pts'):
end_pts = float("inf")
if end_pts < start_pts:
raise ValueError("end_pts should be larger than start_pts, got "
"start_pts={} and end_pts={}".format(start_pts, end_pts))
raise ValueError(
"end_pts should be larger than start_pts, got "
"start_pts={} and end_pts={}".format(start_pts, end_pts)
)
info = {}
video_frames = []
audio_frames = []
try:
container = av.open(filename, metadata_errors="ignore")
except av.AVError:
# TODO raise a warning?
pass
else:
if container.streams.video:
video_frames = _read_from_stream(
container,
start_pts,
end_pts,
pts_unit,
container.streams.video[0],
{"video": 0},
)
video_fps = container.streams.video[0].average_rate
# guard against potentially corrupted files
if video_fps is not None:
info["video_fps"] = float(video_fps)
if container.streams.audio:
audio_frames = _read_from_stream(
container,
start_pts,
end_pts,
pts_unit,
container.streams.audio[0],
{"audio": 0},
)
info["audio_fps"] = container.streams.audio[0].rate
container.close()
......@@ -272,7 +288,7 @@ def _can_read_timestamps_from_packets(container):
return False
def read_video_timestamps(filename, pts_unit="pts"):
"""
List the video frames timestamps.
......@@ -295,6 +311,7 @@ def read_video_timestamps(filename, pts_unit='pts'):
"""
from torchvision import get_video_backend
if get_video_backend() != "pyav":
return _video_opt._read_video_timestamps(filename, pts_unit)
......@@ -304,7 +321,7 @@ def read_video_timestamps(filename, pts_unit='pts'):
video_fps = None
try:
container = av.open(filename, metadata_errors="ignore")
except av.AVError:
# TODO add a warning
pass
......@@ -314,16 +331,61 @@ def read_video_timestamps(filename, pts_unit='pts'):
video_time_base = video_stream.time_base
if _can_read_timestamps_from_packets(container):
# fast path
video_frames = [
x for x in container.demux(video=0) if x.pts is not None
]
else:
video_frames = _read_from_stream(
container, 0, float("inf"), pts_unit, video_stream, {"video": 0}
)
video_fps = float(video_stream.average_rate)
container.close()
pts = [x.pts for x in video_frames]
if pts_unit == "sec":
pts = [x * video_time_base for x in pts]
return pts, video_fps
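Typical usage of the function above (illustrative; assumes example.mp4 exists):

pts, fps = read_video_timestamps("example.mp4", pts_unit="sec")
# pts: sorted per-frame timestamps in seconds (Fractions under the pyav backend);
# fps: the container-reported average frame rate, or None if it could not be read.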
def read_video_meta_data_from_memory(video_data):
# type: (torch.Tensor) -> VideoMetaData
return _video_opt._probe_video_from_memory(video_data)
def read_video_from_memory(
video_data, # type: torch.Tensor
seek_frame_margin=0.25, # type: float
read_video_stream=1, # type: int
video_width=0, # type: int
video_height=0, # type: int
video_min_dimension=0, # type: int
video_pts_range=(0, -1), # type: List[int]
video_timebase_numerator=0, # type: int
video_timebase_denominator=1, # type: int
read_audio_stream=1, # type: int
audio_samples=0, # type: int
audio_channels=0, # type: int
audio_pts_range=(0, -1), # type: List[int]
audio_timebase_numerator=0, # type: int
audio_timebase_denominator=1, # type: int
):
# type: (...) -> Tuple[torch.Tensor, torch.Tensor]
return _video_opt._read_video_from_memory(
video_data,
seek_frame_margin,
read_video_stream,
video_width,
video_height,
video_min_dimension,
video_pts_range,
video_timebase_numerator,
video_timebase_denominator,
read_audio_stream,
audio_samples,
audio_channels,
audio_pts_range,
audio_timebase_numerator,
audio_timebase_denominator,
)
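And for the full decode path, a minimal sketch (not part of the commit; assumes example.mp4 exists):

import numpy as np
import torch
import torchvision.io as io

with open("example.mp4", "rb") as f:
    data = torch.from_numpy(np.frombuffer(f.read(), dtype=np.uint8))

# Defaults decode every frame of both streams at native resolution.
vframes, aframes = io.read_video_from_memory(data)
print(vframes.shape)  # [T, H, W, C] uint8 video frames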