Commit 31fad34f authored by Zhicheng Yan, committed by Francisco Massa

[video reader] inception commit (#1303)

* [video reader] inception commit

* add method save_metadata to class VideoClips in video_utils.py

* add load_metadata() method to VideoClips class

* add Exception to avoid catching unexpected events such as memory errors and interrupts

* fix bugs in video_plus.py

* [video reader] remove logging. update setup.py

* remove time measurement in test_video_reader.py

* Remove glog and try making ffmpeg finding more robust

* Add ffmpeg to conda build

* Add ffmpeg to conda build [again]

* Make library path finding more robust

* Missing import

* One more missing fix for import

* Py2 compatibility and change package to av to avoid version conflict with ffmpeg

* Fix for python2

* [video reader] support decoding only one stream (e.g. video/audio stream)

* remove argument _precomputed_metadata_filepath

* remove save_metadata method

* add get_metadata method

* expose _precomputed_metadata and frame_rate arguments in video dataset __init__ method

* remove ssize_t

* remove size_t to pass CI check on Windows

* add PyInit__video_reader function to pass CI check on Windows

* minor fix to define PyInit_video_reader symbol

* Make c++ video reader optional

* Temporarily revert changes to test_io

* Revert changes to python files

* Rename files to make them private

* Fix python lint

* Fix C++ lint

* add a functor object EnumClassHash to make enum class instances usable as the key type of std::unordered_map

* fix cpp format check
parent a6a926bc
@@ -12,6 +12,7 @@ requirements:
   host:
     - python
     - setuptools
+    - av
   {{ environ.get('CONDA_PYTORCH_BUILD_CONSTRAINT') }}
   {{ environ.get('CONDA_CUDATOOLKIT_CONSTRAINT') }}
   {{ environ.get('CONDA_CPUONLY_FEATURE') }}
@@ -21,6 +22,7 @@ requirements:
     - pillow >=4.1.1
     - numpy >=1.11
     - six
+    - av
   {{ environ.get('CONDA_PYTORCH_CONSTRAINT') }}
   {{ environ.get('CONDA_CUDATOOLKIT_CONSTRAINT') }}
@@ -7,11 +7,12 @@ from setuptools import setup, find_packages
 from pkg_resources import get_distribution, DistributionNotFound
 import subprocess
 import distutils.command.clean
+import distutils.spawn
 import glob
 import shutil
 import torch
-from torch.utils.cpp_extension import CppExtension, CUDAExtension, CUDA_HOME
+from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME

 def read(*names, **kwargs):
@@ -124,6 +125,17 @@ def get_extensions():
     include_dirs = [extensions_dir]
     tests_include_dirs = [test_dir, models_dir]

+    ffmpeg_exe = distutils.spawn.find_executable('ffmpeg')
+    has_ffmpeg = ffmpeg_exe is not None
+    if has_ffmpeg:
+        ffmpeg_bin = os.path.dirname(ffmpeg_exe)
+        ffmpeg_root = os.path.dirname(ffmpeg_bin)
+        ffmpeg_include_dir = os.path.join(ffmpeg_root, 'include')
+
+    # TorchVision video reader
+    video_reader_src_dir = os.path.join(this_dir, 'torchvision', 'csrc', 'cpu', 'video_reader')
+    video_reader_src = glob.glob(os.path.join(video_reader_src_dir, "*.cpp"))
+
     ext_modules = [
         extension(
             'torchvision._C',
@@ -140,6 +152,27 @@ def get_extensions():
             extra_compile_args=extra_compile_args,
         ),
     ]
+    if has_ffmpeg:
+        ext_modules.append(
+            CppExtension(
+                'torchvision.video_reader',
+                video_reader_src,
+                include_dirs=[
+                    video_reader_src_dir,
+                    ffmpeg_include_dir,
+                    extensions_dir,
+                ],
+                libraries=[
+                    'avcodec',
+                    'avformat',
+                    'avutil',
+                    'swresample',
+                    'swscale',
+                ],
+                extra_compile_args=["-std=c++14"],
+                extra_link_args=["-std=c++14"],
+            )
+        )
     return ext_modules
@@ -179,6 +212,8 @@ setup(
         "scipy": ["scipy"],
     },
     ext_modules=get_extensions(),
-    cmdclass={'build_ext': torch.utils.cpp_extension.BuildExtension,
-              'clean': clean}
+    cmdclass={
+        'build_ext': BuildExtension.with_options(no_python_abi_suffix=True),
+        'clean': clean,
+    }
 )
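Whether the optional C++ video reader was actually compiled can be checked at runtime. A minimal sketch, assuming torchvision was built as above with an ffmpeg executable on PATH; _HAS_VIDEO_OPT is the same flag the tests below gate on:

from torchvision.io._video_opt import _HAS_VIDEO_OPT

if _HAS_VIDEO_OPT:
    print("C++ video reader available: torch.ops.video_reader.* can be called")
else:
    print("torchvision was built without ffmpeg; C++ video reader is disabled")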
Video meta-information notation:
- video file name
- video: codec, fps
- audio: codec, bits per sample, sample rate

Test videos are listed below.
--------------------------------
- RATRACE_wave_f_nm_np1_fr_goo_37.avi
- source: hmdb51
- video: DivX MPEG-4
- fps: 30
- audio: N/A
- SchoolRulesHowTheyHelpUs_wave_f_nm_np1_ba_med_0.avi
- source: hmdb51
- video: DivX MPEG-4
- fps: 30
- audio: N/A
- TrumanShow_wave_f_nm_np1_fr_med_26.avi
- source: hmdb51
- video: DivX MPEG-4
- fps: 30
- audio: N/A
- v_SoccerJuggling_g23_c01.avi
- source: ucf101
- video: Xvid MPEG-4
- fps: 29.97
- audio: N/A
- v_SoccerJuggling_g24_c01.avi
- source: ucf101
- video: Xvid MPEG-4
- fps: 29.97
- audio: N/A
- R6llTwEh07w.mp4
- source: kinetics-400
- video: H.264 - MPEG-4 AVC (part 10) (avc1)
- fps: 30
- audio: MPEG AAC audio (mp4a)
- sample rate: 44.1 kHz
- SOX5yA1l24A.mp4
- source: kinetics-400
- video: H.264 - MPEG-4 AVC (part 10) (avc1)
- fps: 29.97
- audio: MPEG AAC audio (mp4a)
- sample rate: 48 kHz
- WUzgd7C1pWA.mp4
- source: kinetics-400
- video: H.264 - MPEG-4 AVC (part 10) (avc1)
- fps: 29.97
- audio: MPEG AAC audio (mp4a)
- sample rate: 48 kHz
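The per-file metadata above can be reproduced with PyAV. A minimal sketch (the file name comes from the list above; the attributes are the same PyAV APIs the tests below rely on):

import av

container = av.open("R6llTwEh07w.mp4")
print(container.streams.video[0].average_rate)   # frame rate, e.g. 30
if container.streams.audio:
    print(container.streams.audio[0].rate)       # audio sample rate, e.g. 44100
container.close()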
import collections
from common_utils import get_tmp_dir
from fractions import Fraction
import math
import numpy as np
import os
import sys
import time
import torch
import torchvision.io as io
import unittest
from numpy.random import randint
try:
import av
# Do a version test too
io.video._check_av_available()
except ImportError:
av = None
if sys.version_info < (3,):
from urllib2 import URLError
else:
from urllib.error import URLError
from torchvision.io._video_opt import _HAS_VIDEO_OPT
VIDEO_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "videos")
CheckerConfig = [
"video_fps",
"audio_sample_rate",
# For some videos (e.g. HMDB51 videos), the decoded audio frames and pts differ
# slightly between the TorchVision decoder and the PyAV decoder, so they are
# omitted from the check when the flags below are False
"check_aframes",
"check_aframe_pts",
]
GroundTruth = collections.namedtuple(
"GroundTruth",
" ".join(CheckerConfig)
)
all_check_config = GroundTruth(
video_fps=0,
audio_sample_rate=0,
check_aframes=True,
check_aframe_pts=True,
)
test_videos = {
"RATRACE_wave_f_nm_np1_fr_goo_37.avi": GroundTruth(
video_fps=30.0,
audio_sample_rate=None,
check_aframes=True,
check_aframe_pts=True,
),
"SchoolRulesHowTheyHelpUs_wave_f_nm_np1_ba_med_0.avi": GroundTruth(
video_fps=30.0,
audio_sample_rate=None,
check_aframes=True,
check_aframe_pts=True,
),
"TrumanShow_wave_f_nm_np1_fr_med_26.avi": GroundTruth(
video_fps=30.0,
audio_sample_rate=None,
check_aframes=True,
check_aframe_pts=True,
),
"v_SoccerJuggling_g23_c01.avi": GroundTruth(
video_fps=29.97,
audio_sample_rate=None,
check_aframes=True,
check_aframe_pts=True,
),
"v_SoccerJuggling_g24_c01.avi": GroundTruth(
video_fps=29.97,
audio_sample_rate=None,
check_aframes=True,
check_aframe_pts=True,
),
"R6llTwEh07w.mp4": GroundTruth(
video_fps=30.0,
audio_sample_rate=44100,
# PyAV misses one audio frame at the beginning (pts=0)
check_aframes=False,
check_aframe_pts=False,
),
"SOX5yA1l24A.mp4": GroundTruth(
video_fps=29.97,
audio_sample_rate=48000,
# PyAV misses one audio frame at the beginning (pts=0)
check_aframes=False,
check_aframe_pts=False,
),
"WUzgd7C1pWA.mp4": GroundTruth(
video_fps=29.97,
audio_sample_rate=48000,
# PyAV misses one audio frame at the beginning (pts=0)
check_aframes=False,
check_aframe_pts=False,
),
}
DecoderResult = collections.namedtuple(
"DecoderResult", "vframes vframe_pts vtimebase aframes aframe_pts atimebase"
)
"""av_seek_frame is imprecise so seek to a timestamp earlier by a margin
The unit of margin is second"""
seek_frame_margin = 0.25
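# Illustration with hypothetical numbers: the margin is specified in seconds,
# while seeking and the start/end bounds operate in pts units, so the decoder
# divides it by the stream time base. E.g., for a time base of 1/30000:
#     int(seek_frame_margin / Fraction(1, 30000)) == 7500  # 0.25 s in pts units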
def _read_from_stream(
container, start_pts, end_pts, stream, stream_name, buffer_size=4
):
"""
Args:
container: pyav container
start_pts/end_pts: the starting/ending presentation timestamps between
which frames are read
stream: pyav stream
stream_name: a dictionary of streams. For example, {"video": 0} means
video stream at stream index 0
buffer_size: the pts of frames decoded by PyAV are not guaranteed to be in
ascending order, so we keep decoding more frames even after reaching the
end pts
"""
# seeking in the stream is imprecise, so seek to an earlier pts by a margin
margin = 1
seek_offset = max(start_pts - margin, 0)
container.seek(seek_offset, any_frame=False, backward=True, stream=stream)
frames = {}
buffer_count = 0
for frame in container.decode(**stream_name):
if frame.pts < start_pts:
continue
if frame.pts <= end_pts:
frames[frame.pts] = frame
else:
buffer_count += 1
if buffer_count >= buffer_size:
break
result = [frames[pts] for pts in sorted(frames)]
return result
def _get_timebase_by_av_module(full_path):
container = av.open(full_path)
video_time_base = container.streams.video[0].time_base
if container.streams.audio:
audio_time_base = container.streams.audio[0].time_base
else:
audio_time_base = None
return video_time_base, audio_time_base
def _fraction_to_tensor(fraction):
ret = torch.zeros([2], dtype=torch.int32)
ret[0] = fraction.numerator
ret[1] = fraction.denominator
return ret
def _decode_frames_by_av_module(
full_path,
video_start_pts=0,
video_end_pts=None,
audio_start_pts=0,
audio_end_pts=None,
):
"""
Use PyAV to decode video frames. This provides a reference against which the
decoding results of our decoder are compared.
Input arguments:
full_path: video file path
video_start_pts/video_end_pts: the starting/ending presentation timestamps
between which video frames are read
"""
if video_end_pts is None:
video_end_pts = float('inf')
if audio_end_pts is None:
audio_end_pts = float('inf')
container = av.open(full_path)
video_frames = []
vtimebase = torch.zeros([0], dtype=torch.int32)
if container.streams.video:
video_frames = _read_from_stream(
container,
video_start_pts,
video_end_pts,
container.streams.video[0],
{"video": 0},
)
# container.streams.video[0].average_rate is not a reliable estimator of
# frame rate. It can be wrong for certain codecs, such as VP80, so we do
# not return the video fps here
vtimebase = _fraction_to_tensor(container.streams.video[0].time_base)
audio_frames = []
atimebase = torch.zeros([0], dtype=torch.int32)
if container.streams.audio:
audio_frames = _read_from_stream(
container,
audio_start_pts,
audio_end_pts,
container.streams.audio[0],
{"audio": 0},
)
atimebase = _fraction_to_tensor(container.streams.audio[0].time_base)
container.close()
vframes = [frame.to_rgb().to_ndarray() for frame in video_frames]
vframes = torch.as_tensor(np.stack(vframes))
vframe_pts = torch.tensor([frame.pts for frame in video_frames], dtype=torch.int64)
aframes = [frame.to_ndarray() for frame in audio_frames]
if aframes:
aframes = np.transpose(np.concatenate(aframes, axis=1))
aframes = torch.as_tensor(aframes)
else:
aframes = torch.empty((1, 0), dtype=torch.float32)
aframe_pts = torch.tensor(
[audio_frame.pts for audio_frame in audio_frames], dtype=torch.int64
)
return DecoderResult(
vframes=vframes,
vframe_pts=vframe_pts,
vtimebase=vtimebase,
aframes=aframes,
aframe_pts=aframe_pts,
atimebase=atimebase,
)
def _pts_convert(pts, timebase_from, timebase_to, round_func=math.floor):
"""convert pts between different time bases
Args:
pts: presentation timestamp, float
timebase_from: original timebase. Fraction
timebase_to: new timebase. Fraction
round_func: rounding function.
"""
new_pts = Fraction(pts, 1) * timebase_from / timebase_to
return int(round_func(new_pts))
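# Worked example (illustrative numbers): a video pts of 7500 in a 1/30000 time
# base is 0.25 s; mapped into a 1/44100 audio time base,
#     _pts_convert(7500, Fraction(1, 30000), Fraction(1, 44100), math.floor)
# returns 11025. The tests below floor the start bounds and ceil the end bounds.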
def _get_video_tensor(video_dir, video_file):
"""open a video file, and represent the video data by a PT tensor"""
full_path = os.path.join(video_dir, video_file)
assert os.path.exists(full_path), "File not found: %s" % full_path
with open(full_path, "rb") as fp:
video_tensor = torch.from_numpy(np.frombuffer(fp.read(), dtype=np.uint8))
return full_path, video_tensor
@unittest.skipIf(av is None, "PyAV unavailable")
@unittest.skipIf(_HAS_VIDEO_OPT is False, "Didn't compile with ffmpeg")
class TestVideoReader(unittest.TestCase):
def check_separate_decoding_result(self, tv_result, config):
"""check the decoding results from TorchVision decoder
"""
vframes, vframe_pts, vtimebase, vfps, aframes, aframe_pts, atimebase, asample_rate = (
tv_result
)
self.assertAlmostEqual(vfps.item(), config.video_fps, delta=0.5)
if asample_rate.numel() > 0:
self.assertEqual(asample_rate.item(), config.audio_sample_rate)
# check that the pts of video frames are sorted in ascending order
for i in range(len(vframe_pts) - 1):
self.assertLess(vframe_pts[i], vframe_pts[i + 1])
if len(aframe_pts) > 1:
# check that the pts of audio frames are sorted in ascending order
for i in range(len(aframe_pts) - 1):
self.assertLess(aframe_pts[i], aframe_pts[i + 1])
def compare_decoding_result(self, tv_result, ref_result, config=all_check_config):
"""
Compare decoding results from two sources.
Args:
tv_result: decoding results from TorchVision decoder
ref_result: reference decoding results which can be from either PyAv
decoder or TorchVision decoder with getPtsOnly = 1
config: config of decoding results checker
"""
vframes, vframe_pts, vtimebase, _vfps, aframes, aframe_pts, atimebase, _asample_rate = (
tv_result
)
if isinstance(ref_result, list):
# the ref_result is from new video_reader decoder
ref_result = DecoderResult(
vframes=ref_result[0],
vframe_pts=ref_result[1],
vtimebase=ref_result[2],
aframes=ref_result[4],
aframe_pts=ref_result[5],
atimebase=ref_result[6],
)
if vframes.numel() > 0 and ref_result.vframes.numel() > 0:
mean_delta = torch.mean(torch.abs(vframes.float() - ref_result.vframes.float()))
self.assertAlmostEqual(mean_delta, 0, delta=8.0)
mean_delta = torch.mean(torch.abs(vframe_pts.float() - ref_result.vframe_pts.float()))
self.assertAlmostEqual(mean_delta, 0, delta=1.0)
is_same = torch.all(torch.eq(vtimebase, ref_result.vtimebase)).item()
self.assertTrue(is_same)
if config.check_aframes and aframes.numel() > 0 and ref_result.aframes.numel() > 0:
# audio stream is available and audio frames are required from the decoder
is_same = torch.all(torch.eq(aframes, ref_result.aframes)).item()
self.assertTrue(is_same)
if config.check_aframe_pts and aframe_pts.numel() > 0 and ref_result.aframe_pts.numel() > 0:
# audio stream is available
is_same = torch.all(torch.eq(aframe_pts, ref_result.aframe_pts)).item()
self.assertTrue(is_same)
is_same = torch.all(torch.eq(atimebase, ref_result.atimebase)).item()
self.assertTrue(is_same)
@unittest.skip(
"This stress test will iteratively decode the same set of videos."
"It helps to detect memory leak but it takes lots of time to run."
"By default, it is disabled"
)
def test_stress_test_read_video_from_file(self):
num_iter = 10000
# video related
width, height, min_dimension = 0, 0, 0
video_start_pts, video_end_pts = 0, -1
video_timebase_num, video_timebase_den = 0, 1
# audio related
samples, channels = 0, 0
audio_start_pts, audio_end_pts = 0, -1
audio_timebase_num, audio_timebase_den = 0, 1
for i in range(num_iter):
for test_video, config in test_videos.items():
full_path = os.path.join(VIDEO_DIR, test_video)
# pass 1: decode all frames using new decoder
_ = torch.ops.video_reader.read_video_from_file(
full_path,
seek_frame_margin,
0, # getPtsOnly
1, # readVideoStream
width,
height,
min_dimension,
video_start_pts,
video_end_pts,
video_timebase_num,
video_timebase_den,
1, # readAudioStream
samples,
channels,
audio_start_pts,
audio_end_pts,
audio_timebase_num,
audio_timebase_den,
)
def test_read_video_from_file(self):
"""
Test the case when the decoder starts from a video file and decodes all frames.
"""
# video related
width, height, min_dimension = 0, 0, 0
video_start_pts, video_end_pts = 0, -1
video_timebase_num, video_timebase_den = 0, 1
# audio related
samples, channels = 0, 0
audio_start_pts, audio_end_pts = 0, -1
audio_timebase_num, audio_timebase_den = 0, 1
for test_video, config in test_videos.items():
full_path = os.path.join(VIDEO_DIR, test_video)
# pass 1: decode all frames using new decoder
tv_result = torch.ops.video_reader.read_video_from_file(
full_path,
seek_frame_margin,
0, # getPtsOnly
1, # readVideoStream
width,
height,
min_dimension,
video_start_pts,
video_end_pts,
video_timebase_num,
video_timebase_den,
1, # readAudioStream
samples,
channels,
audio_start_pts,
audio_end_pts,
audio_timebase_num,
audio_timebase_den,
)
# pass 2: decode all frames using av
pyav_result = _decode_frames_by_av_module(full_path)
# check results from TorchVision decoder
self.check_separate_decoding_result(tv_result, config)
# compare decoding results
self.compare_decoding_result(tv_result, pyav_result, config)
def test_read_video_from_file_read_single_stream_only(self):
"""
Test the case when the decoder starts from a video file and reads only one
of the video and audio streams, ignoring the other
"""
# video related
width, height, min_dimension = 0, 0, 0
video_start_pts, video_end_pts = 0, -1
video_timebase_num, video_timebase_den = 0, 1
# audio related
samples, channels = 0, 0
audio_start_pts, audio_end_pts = 0, -1
audio_timebase_num, audio_timebase_den = 0, 1
for test_video, config in test_videos.items():
full_path = os.path.join(VIDEO_DIR, test_video)
for readVideoStream, readAudioStream in [(1, 0), (0, 1)]:
# decode all frames using new decoder
tv_result = torch.ops.video_reader.read_video_from_file(
full_path,
seek_frame_margin,
0, # getPtsOnly
readVideoStream,
width,
height,
min_dimension,
video_start_pts,
video_end_pts,
video_timebase_num,
video_timebase_den,
readAudioStream,
samples,
channels,
audio_start_pts,
audio_end_pts,
audio_timebase_num,
audio_timebase_den,
)
vframes, vframe_pts, vtimebase, vfps, aframes, aframe_pts, atimebase, asample_rate = (
tv_result
)
self.assertEqual(vframes.numel() > 0, readVideoStream)
self.assertEqual(vframe_pts.numel() > 0, readVideoStream)
self.assertEqual(vtimebase.numel() > 0, readVideoStream)
self.assertEqual(vfps.numel() > 0, readVideoStream)
expect_audio_data = readAudioStream == 1 and config.audio_sample_rate is not None
self.assertEqual(aframes.numel() > 0, expect_audio_data)
self.assertEqual(aframe_pts.numel() > 0, expect_audio_data)
self.assertEqual(atimebase.numel() > 0, expect_audio_data)
self.assertEqual(asample_rate.numel() > 0, expect_audio_data)
def test_read_video_from_file_rescale_min_dimension(self):
"""
Test the case when the decoder starts from a video file and the minimum of
video height and width is set.
"""
# video related
width, height, min_dimension = 0, 0, 128
video_start_pts, video_end_pts = 0, -1
video_timebase_num, video_timebase_den = 0, 1
# audio related
samples, channels = 0, 0
audio_start_pts, audio_end_pts = 0, -1
audio_timebase_num, audio_timebase_den = 0, 1
for test_video, config in test_videos.items():
full_path = os.path.join(VIDEO_DIR, test_video)
tv_result = torch.ops.video_reader.read_video_from_file(
full_path,
seek_frame_margin,
0, # getPtsOnly
1, # readVideoStream
width,
height,
min_dimension,
video_start_pts,
video_end_pts,
video_timebase_num,
video_timebase_den,
1, # readAudioStream
samples,
channels,
audio_start_pts,
audio_end_pts,
audio_timebase_num,
audio_timebase_den,
)
self.assertEqual(min_dimension, min(tv_result[0].size(1), tv_result[0].size(2)))
def test_read_video_from_file_rescale_width(self):
"""
Test the case when the decoder starts from a video file and the video width
is set.
"""
# video related
width, height, min_dimension = 256, 0, 0
video_start_pts, video_end_pts = 0, -1
video_timebase_num, video_timebase_den = 0, 1
# audio related
samples, channels = 0, 0
audio_start_pts, audio_end_pts = 0, -1
audio_timebase_num, audio_timebase_den = 0, 1
for test_video, config in test_videos.items():
full_path = os.path.join(VIDEO_DIR, test_video)
tv_result = torch.ops.video_reader.read_video_from_file(
full_path,
seek_frame_margin,
0, # getPtsOnly
1, # readVideoStream
width,
height,
min_dimension,
video_start_pts,
video_end_pts,
video_timebase_num,
video_timebase_den,
1, # readAudioStream
samples,
channels,
audio_start_pts,
audio_end_pts,
audio_timebase_num,
audio_timebase_den,
)
self.assertEqual(tv_result[0].size(2), width)
def test_read_video_from_file_rescale_height(self):
"""
Test the case when the decoder starts from a video file and the video height
is set.
"""
# video related
width, height, min_dimension = 0, 224, 0
video_start_pts, video_end_pts = 0, -1
video_timebase_num, video_timebase_den = 0, 1
# audio related
samples, channels = 0, 0
audio_start_pts, audio_end_pts = 0, -1
audio_timebase_num, audio_timebase_den = 0, 1
for test_video, config in test_videos.items():
full_path = os.path.join(VIDEO_DIR, test_video)
tv_result = torch.ops.video_reader.read_video_from_file(
full_path,
seek_frame_margin,
0, # getPtsOnly
1, # readVideoStream
width,
height,
min_dimension,
video_start_pts,
video_end_pts,
video_timebase_num,
video_timebase_den,
1, # readAudioStream
samples,
channels,
audio_start_pts,
audio_end_pts,
audio_timebase_num,
audio_timebase_den,
)
self.assertEqual(tv_result[0].size(1), height)
def test_read_video_from_file_rescale_width_and_height(self):
"""
Test the case when the decoder starts from a video file and both video height
and width are set.
"""
# video related
width, height, min_dimension = 320, 240, 0
video_start_pts, video_end_pts = 0, -1
video_timebase_num, video_timebase_den = 0, 1
# audio related
samples, channels = 0, 0
audio_start_pts, audio_end_pts = 0, -1
audio_timebase_num, audio_timebase_den = 0, 1
for test_video, config in test_videos.items():
full_path = os.path.join(VIDEO_DIR, test_video)
tv_result = torch.ops.video_reader.read_video_from_file(
full_path,
seek_frame_margin,
0, # getPtsOnly
1, # readVideoStream
width,
height,
min_dimension,
video_start_pts,
video_end_pts,
video_timebase_num,
video_timebase_den,
1, # readAudioStream
samples,
channels,
audio_start_pts,
audio_end_pts,
audio_timebase_num,
audio_timebase_den,
)
self.assertEqual(tv_result[0].size(1), height)
self.assertEqual(tv_result[0].size(2), width)
def test_read_video_from_file_audio_resampling(self):
"""
Test the case when the decoder starts from a video file and the audio
waveform is resampled
"""
for samples in [
9600, # downsampling
96000, # upsampling
]:
# video related
width, height, min_dimension = 0, 0, 0
video_start_pts, video_end_pts = 0, -1
video_timebase_num, video_timebase_den = 0, 1
# audio related
channels = 0
audio_start_pts, audio_end_pts = 0, -1
audio_timebase_num, audio_timebase_den = 0, 1
for test_video, config in test_videos.items():
full_path = os.path.join(VIDEO_DIR, test_video)
tv_result = torch.ops.video_reader.read_video_from_file(
full_path,
seek_frame_margin,
0, # getPtsOnly
1, # readVideoStream
width,
height,
min_dimension,
video_start_pts,
video_end_pts,
video_timebase_num,
video_timebase_den,
1, # readAudioStream
samples,
channels,
audio_start_pts,
audio_end_pts,
audio_timebase_num,
audio_timebase_den,
)
vframes, vframe_pts, vtimebase, vfps, aframes, aframe_pts, atimebase, a_sample_rate = (
tv_result
)
if aframes.numel() > 0:
self.assertEqual(samples, a_sample_rate.item())
self.assertEqual(1, aframes.size(1))
# when audio stream is found
duration = float(aframe_pts[-1]) * float(atimebase[0]) / float(atimebase[1])
self.assertAlmostEqual(
aframes.size(0),
int(duration * a_sample_rate.item()),
delta=0.1 * a_sample_rate.item(),
)
def test_compare_read_video_from_memory_and_file(self):
"""
Decode the same video from memory and from file, and compare the two results
"""
# video related
width, height, min_dimension = 0, 0, 0
video_start_pts, video_end_pts = 0, -1
video_timebase_num, video_timebase_den = 0, 1
# audio related
samples, channels = 0, 0
audio_start_pts, audio_end_pts = 0, -1
audio_timebase_num, audio_timebase_den = 0, 1
for test_video, config in test_videos.items():
full_path, video_tensor = _get_video_tensor(VIDEO_DIR, test_video)
# pass 1: decode all frames using cpp decoder
tv_result_memory = torch.ops.video_reader.read_video_from_memory(
video_tensor,
seek_frame_margin,
0, # getPtsOnly
1, # readVideoStream
width,
height,
min_dimension,
video_start_pts,
video_end_pts,
video_timebase_num,
video_timebase_den,
1, # readAudioStream
samples,
channels,
audio_start_pts,
audio_end_pts,
audio_timebase_num,
audio_timebase_den,
)
self.check_separate_decoding_result(tv_result_memory, config)
# pass 2: decode all frames from file
tv_result_file = torch.ops.video_reader.read_video_from_file(
full_path,
seek_frame_margin,
0, # getPtsOnly
1, # readVideoStream
width,
height,
min_dimension,
video_start_pts,
video_end_pts,
video_timebase_num,
video_timebase_den,
1, # readAudioStream
samples,
channels,
audio_start_pts,
audio_end_pts,
audio_timebase_num,
audio_timebase_den,
)
self.check_separate_decoding_result(tv_result_file, config)
# finally, compare results decoded from memory and file
self.compare_decoding_result(tv_result_memory, tv_result_file)
def test_read_video_from_memory(self):
"""
Test the case when the video is already in memory and the decoder reads data from memory
"""
# video related
width, height, min_dimension = 0, 0, 0
video_start_pts, video_end_pts = 0, -1
video_timebase_num, video_timebase_den = 0, 1
# audio related
samples, channels = 0, 0
audio_start_pts, audio_end_pts = 0, -1
audio_timebase_num, audio_timebase_den = 0, 1
for test_video, config in test_videos.items():
full_path, video_tensor = _get_video_tensor(VIDEO_DIR, test_video)
# pass 1: decode all frames using cpp decoder
tv_result = torch.ops.video_reader.read_video_from_memory(
video_tensor,
seek_frame_margin,
0, # getPtsOnly
1, # readVideoStream
width,
height,
min_dimension,
video_start_pts,
video_end_pts,
video_timebase_num,
video_timebase_den,
1, # readAudioStream
samples,
channels,
audio_start_pts,
audio_end_pts,
audio_timebase_num,
audio_timebase_den,
)
# pass 2: decode all frames using av
pyav_result = _decode_frames_by_av_module(full_path)
self.check_separate_decoding_result(tv_result, config)
self.compare_decoding_result(tv_result, pyav_result, config)
def test_read_video_from_memory_get_pts_only(self):
"""
Test the case when the video is already in memory and the decoder reads data
from memory. Compare frame pts between pts-only decoding and full decoding
of both pts and frame data
"""
# video related
width, height, min_dimension = 0, 0, 0
video_start_pts, video_end_pts = 0, -1
video_timebase_num, video_timebase_den = 0, 1
# audio related
samples, channels = 0, 0
audio_start_pts, audio_end_pts = 0, -1
audio_timebase_num, audio_timebase_den = 0, 1
for test_video, config in test_videos.items():
full_path, video_tensor = _get_video_tensor(VIDEO_DIR, test_video)
# pass 1: decode all frames using cpp decoder
tv_result = torch.ops.video_reader.read_video_from_memory(
video_tensor,
seek_frame_margin,
0, # getPtsOnly
1, # readVideoStream
width,
height,
min_dimension,
video_start_pts,
video_end_pts,
video_timebase_num,
video_timebase_den,
1, # readAudioStream
samples,
channels,
audio_start_pts,
audio_end_pts,
audio_timebase_num,
audio_timebase_den,
)
self.assertAlmostEqual(config.video_fps, tv_result[3].item(), delta=0.01)
# pass 2: decode all frames to get PTS only using cpp decoder
tv_result_pts_only = torch.ops.video_reader.read_video_from_memory(
video_tensor,
seek_frame_margin,
1, # getPtsOnly
1, # readVideoStream
width,
height,
min_dimension,
video_start_pts,
video_end_pts,
video_timebase_num,
video_timebase_den,
1, # readAudioStream
samples,
channels,
audio_start_pts,
audio_end_pts,
audio_timebase_num,
audio_timebase_den,
)
self.assertEqual(tv_result_pts_only[0].numel(), 0)
self.assertEqual(tv_result_pts_only[4].numel(), 0)
self.compare_decoding_result(tv_result, tv_result_pts_only)
def test_read_video_in_range_from_memory(self):
"""
Test the case when the video is already in memory and the decoder reads data
from memory. In addition, the decoder takes meaningful start and end pts as
input and decodes only the frames within that interval
"""
for test_video, config in test_videos.items():
full_path, video_tensor = _get_video_tensor(VIDEO_DIR, test_video)
# video related
width, height, min_dimension = 0, 0, 0
video_start_pts, video_end_pts = 0, -1
video_timebase_num, video_timebase_den = 0, 1
# audio related
samples, channels = 0, 0
audio_start_pts, audio_end_pts = 0, -1
audio_timebase_num, audio_timebase_den = 0, 1
# pass 1: decode all frames using new decoder
tv_result = torch.ops.video_reader.read_video_from_memory(
video_tensor,
seek_frame_margin,
0, # getPtsOnly
1, # readVideoStream
width,
height,
min_dimension,
video_start_pts,
video_end_pts,
video_timebase_num,
video_timebase_den,
1, # readAudioStream
samples,
channels,
audio_start_pts,
audio_end_pts,
audio_timebase_num,
audio_timebase_den,
)
vframes, vframe_pts, vtimebase, vfps, aframes, aframe_pts, atimebase, asample_rate = (
tv_result
)
self.assertAlmostEqual(config.video_fps, vfps.item(), delta=0.01)
for num_frames in [4, 8, 16, 32, 64, 128]:
start_pts_ind_max = vframe_pts.size(0) - num_frames
if start_pts_ind_max <= 0:
continue
# randomly pick start pts
start_pts_ind = randint(0, start_pts_ind_max)
end_pts_ind = start_pts_ind + num_frames - 1
video_start_pts = vframe_pts[start_pts_ind]
video_end_pts = vframe_pts[end_pts_ind]
video_timebase_num, video_timebase_den = vtimebase[0], vtimebase[1]
if len(atimebase) > 0:
# when audio stream is available
audio_timebase_num, audio_timebase_den = atimebase[0], atimebase[1]
audio_start_pts = _pts_convert(
video_start_pts.item(),
Fraction(video_timebase_num.item(), video_timebase_den.item()),
Fraction(audio_timebase_num.item(), audio_timebase_den.item()),
math.floor,
)
audio_end_pts = _pts_convert(
video_end_pts.item(),
Fraction(video_timebase_num.item(), video_timebase_den.item()),
Fraction(audio_timebase_num.item(), audio_timebase_den.item()),
math.ceil,
)
# pass 2: decode frames in the randomly generated range
tv_result = torch.ops.video_reader.read_video_from_memory(
video_tensor,
seek_frame_margin,
0, # getPtsOnly
1, # readVideoStream
width,
height,
min_dimension,
video_start_pts,
video_end_pts,
video_timebase_num,
video_timebase_den,
1, # readAudioStream
samples,
channels,
audio_start_pts,
audio_end_pts,
audio_timebase_num,
audio_timebase_den,
)
# pass 3: decode frames in range using PyAv
video_timebase_av, audio_timebase_av = _get_timebase_by_av_module(full_path)
video_start_pts_av = _pts_convert(
video_start_pts.item(),
Fraction(video_timebase_num.item(), video_timebase_den.item()),
Fraction(video_timebase_av.numerator, video_timebase_av.denominator),
math.floor,
)
video_end_pts_av = _pts_convert(
video_end_pts.item(),
Fraction(video_timebase_num.item(), video_timebase_den.item()),
Fraction(video_timebase_av.numerator, video_timebase_av.denominator),
math.ceil,
)
if audio_timebase_av:
audio_start_pts = _pts_convert(
video_start_pts.item(),
Fraction(video_timebase_num.item(), video_timebase_den.item()),
Fraction(audio_timebase_av.numerator, audio_timebase_av.denominator),
math.floor,
)
audio_end_pts = _pts_convert(
video_end_pts.item(),
Fraction(video_timebase_num.item(), video_timebase_den.item()),
Fraction(audio_timebase_av.numerator, audio_timebase_av.denominator),
math.ceil,
)
pyav_result = _decode_frames_by_av_module(
full_path,
video_start_pts_av,
video_end_pts_av,
audio_start_pts,
audio_end_pts,
)
self.assertEqual(tv_result[0].size(0), num_frames)
if pyav_result.vframes.size(0) == num_frames:
# compare the decoding results between the TorchVision video reader and
# PyAV only when PyAV decodes the same number of video frames; otherwise
# skip the comparison
self.compare_decoding_result(tv_result, pyav_result, config)
if __name__ == '__main__':
unittest.main()
#include "FfmpegAudioSampler.h"
#include <memory>
#include "FfmpegUtil.h"
using namespace std;
FfmpegAudioSampler::FfmpegAudioSampler(
const AudioFormat& in,
const AudioFormat& out)
: inFormat_(in), outFormat_(out) {}
FfmpegAudioSampler::~FfmpegAudioSampler() {
if (swrContext_) {
swr_free(&swrContext_);
}
}
int FfmpegAudioSampler::init() {
swrContext_ = swr_alloc_set_opts(
nullptr, // we're allocating a new context
av_get_default_channel_layout(outFormat_.channels), // out_ch_layout
static_cast<AVSampleFormat>(outFormat_.format), // out_sample_fmt
outFormat_.samples, // out_sample_rate
av_get_default_channel_layout(inFormat_.channels), // in_ch_layout
static_cast<AVSampleFormat>(inFormat_.format), // in_sample_fmt
inFormat_.samples, // in_sample_rate
0, // log_offset
nullptr); // log_ctx
if (swrContext_ == nullptr) {
LOG(ERROR) << "swr_alloc_set_opts fails";
return -1;
}
int result = 0;
if ((result = swr_init(swrContext_)) < 0) {
LOG(ERROR) << "swr_init failed, err: " << ffmpeg_util::getErrorDesc(result)
<< ", in -> format: " << inFormat_.format
<< ", channels: " << inFormat_.channels
<< ", samples: " << inFormat_.samples
<< ", out -> format: " << outFormat_.format
<< ", channels: " << outFormat_.channels
<< ", samples: " << outFormat_.samples;
return -1;
}
return 0;
}
int64_t FfmpegAudioSampler::getSampleBytes(const AVFrame* frame) const {
auto outSamples = getOutNumSamples(frame->nb_samples);
return av_samples_get_buffer_size(
nullptr,
outFormat_.channels,
outSamples,
static_cast<AVSampleFormat>(outFormat_.format),
1);
}
// https://www.ffmpeg.org/doxygen/3.2/group__lswr.html
unique_ptr<DecodedFrame> FfmpegAudioSampler::sample(const AVFrame* frame) {
if (!frame) {
return nullptr; // nothing to flush for the audio sampler
}
auto inNumSamples = frame->nb_samples;
auto outNumSamples = getOutNumSamples(frame->nb_samples);
auto outSampleSize = getSampleBytes(frame);
AvDataPtr frameData(static_cast<uint8_t*>(av_malloc(outSampleSize)));
uint8_t* outPlanes[AVRESAMPLE_MAX_CHANNELS];
int result = 0;
if ((result = av_samples_fill_arrays(
outPlanes,
nullptr, // linesize is not needed
frameData.get(),
outFormat_.channels,
outNumSamples,
static_cast<AVSampleFormat>(outFormat_.format),
1)) < 0) {
LOG(ERROR) << "av_samples_fill_arrays failed, err: "
<< ffmpeg_util::getErrorDesc(result)
<< ", outNumSamples: " << outNumSamples
<< ", format: " << outFormat_.format;
return nullptr;
}
if ((result = swr_convert(
swrContext_,
&outPlanes[0],
outNumSamples,
(const uint8_t**)&frame->data[0],
inNumSamples)) < 0) {
LOG(ERROR) << "swr_convert faield, err: "
<< ffmpeg_util::getErrorDesc(result);
return nullptr;
}
// The value returned by swr_convert is the actual number of output samples,
// so recompute the buffer size with av_samples_get_buffer_size
result = av_samples_get_buffer_size(
nullptr,
outFormat_.channels,
result,
static_cast<AVSampleFormat>(outFormat_.format),
1);
return make_unique<DecodedFrame>(std::move(frameData), result, 0);
}
/*
Because of the resampler delay, the returned value is an upper bound on the
number of output samples
*/
int64_t FfmpegAudioSampler::getOutNumSamples(int inNumSamples) const {
return av_rescale_rnd(
swr_get_delay(swrContext_, inFormat_.samples) + inNumSamples,
outFormat_.samples,
inFormat_.samples,
AV_ROUND_UP);
}
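// Worked instance of the bound above (illustrative numbers): with zero
// buffered resampler delay, converting 1024 input samples from 44100 Hz to
// 96000 Hz yields at most
//   av_rescale_rnd(0 + 1024, 96000, 44100, AV_ROUND_UP)
//     == ceil(1024 * 96000 / 44100.0) == 2230
// output samples.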
#pragma once
#include "FfmpegSampler.h"
#define AVRESAMPLE_MAX_CHANNELS 32
/**
* Class that transcodes audio frames from one format into another
*/
class FfmpegAudioSampler : public FfmpegSampler {
public:
explicit FfmpegAudioSampler(const AudioFormat& in, const AudioFormat& out);
~FfmpegAudioSampler() override;
int init() override;
int64_t getSampleBytes(const AVFrame* frame) const;
// FfmpegSampler overrides
// returns number of bytes of the sampled data
std::unique_ptr<DecodedFrame> sample(const AVFrame* frame) override;
const AudioFormat& getInFormat() const {
return inFormat_;
}
private:
int64_t getOutNumSamples(int inNumSamples) const;
AudioFormat inFormat_;
AudioFormat outFormat_;
SwrContext* swrContext_{nullptr};
};
#include "FfmpegAudioStream.h"
#include "FfmpegUtil.h"
using namespace std;
namespace {
bool operator==(const AudioFormat& x, const AVCodecContext& y) {
return x.samples == y.sample_rate && x.channels == y.channels &&
x.format == y.sample_fmt;
}
AudioFormat& toAudioFormat(
AudioFormat& audioFormat,
const AVCodecContext& codecCtx) {
audioFormat.samples = codecCtx.sample_rate;
audioFormat.channels = codecCtx.channels;
audioFormat.format = codecCtx.sample_fmt;
return audioFormat;
}
} // namespace
FfmpegAudioStream::FfmpegAudioStream(
AVFormatContext* inputCtx,
int index,
enum AVMediaType avMediaType,
MediaFormat mediaFormat,
double seekFrameMargin)
: FfmpegStream(inputCtx, index, avMediaType, seekFrameMargin),
mediaFormat_(mediaFormat) {}
FfmpegAudioStream::~FfmpegAudioStream() {}
void FfmpegAudioStream::checkStreamDecodeParams() {
auto timeBase = getTimeBase();
if (timeBase.first > 0) {
CHECK_EQ(timeBase.first, inputCtx_->streams[index_]->time_base.num);
CHECK_EQ(timeBase.second, inputCtx_->streams[index_]->time_base.den);
}
}
void FfmpegAudioStream::updateStreamDecodeParams() {
auto timeBase = getTimeBase();
if (timeBase.first == 0) {
mediaFormat_.format.audio.timeBaseNum =
inputCtx_->streams[index_]->time_base.num;
mediaFormat_.format.audio.timeBaseDen =
inputCtx_->streams[index_]->time_base.den;
}
}
int FfmpegAudioStream::initFormat() {
AudioFormat& format = mediaFormat_.format.audio;
if (format.samples == 0) {
format.samples = codecCtx_->sample_rate;
}
if (format.channels == 0) {
format.channels = codecCtx_->channels;
}
if (format.format == AV_SAMPLE_FMT_NONE) {
format.format = codecCtx_->sample_fmt;
VLOG(2) << "set stream format sample_fmt: " << format.format;
}
checkStreamDecodeParams();
updateStreamDecodeParams();
if (format.samples > 0 && format.channels > 0 &&
format.format != AV_SAMPLE_FMT_NONE) {
return 0;
} else {
return -1;
}
}
unique_ptr<DecodedFrame> FfmpegAudioStream::sampleFrameData() {
AudioFormat& audioFormat = mediaFormat_.format.audio;
if (!sampler_ || !(sampler_->getInFormat() == *codecCtx_)) {
AudioFormat newInFormat;
newInFormat = toAudioFormat(newInFormat, *codecCtx_);
sampler_ = make_unique<FfmpegAudioSampler>(newInFormat, audioFormat);
VLOG(1) << "Set sampler input audio format"
<< ", samples: " << newInFormat.samples
<< ", channels: " << newInFormat.channels
<< ", format: " << newInFormat.format
<< " : output audio sampler format"
<< ", samples: " << audioFormat.samples
<< ", channels: " << audioFormat.channels
<< ", format: " << audioFormat.format;
int ret = sampler_->init();
if (ret < 0) {
VLOG(1) << "Fail to initialize audio sampler";
return nullptr;
}
}
return sampler_->sample(frame_);
}
#pragma once
#include <utility>
#include "FfmpegAudioSampler.h"
#include "FfmpegStream.h"
/**
* Class that uses the FFMPEG library to decode one audio stream.
*/
class FfmpegAudioStream : public FfmpegStream {
public:
explicit FfmpegAudioStream(
AVFormatContext* inputCtx,
int index,
enum AVMediaType avMediaType,
MediaFormat mediaFormat,
double seekFrameMargin);
~FfmpegAudioStream() override;
// FfmpegStream overrides
MediaType getMediaType() const override {
return MediaType::TYPE_AUDIO;
}
FormatUnion getMediaFormat() const override {
return mediaFormat_.format;
}
int64_t getStartPts() const override {
return mediaFormat_.format.audio.startPts;
}
int64_t getEndPts() const override {
return mediaFormat_.format.audio.endPts;
}
// return numerator and denominator of time base
std::pair<int, int> getTimeBase() const {
return std::make_pair(
mediaFormat_.format.audio.timeBaseNum,
mediaFormat_.format.audio.timeBaseDen);
}
void checkStreamDecodeParams();
void updateStreamDecodeParams();
protected:
int initFormat() override;
std::unique_ptr<DecodedFrame> sampleFrameData() override;
private:
MediaFormat mediaFormat_;
std::unique_ptr<FfmpegAudioSampler> sampler_{nullptr};
};
#include "FfmpegDecoder.h"
#include "FfmpegAudioStream.h"
#include "FfmpegUtil.h"
#include "FfmpegVideoStream.h"
using namespace std;
static AVPacket avPkt;
namespace {
unique_ptr<FfmpegStream> createFfmpegStream(
MediaType type,
AVFormatContext* ctx,
int idx,
MediaFormat& mediaFormat,
double seekFrameMargin) {
enum AVMediaType avType;
CHECK(ffmpeg_util::mapMediaType(type, &avType));
switch (type) {
case MediaType::TYPE_VIDEO:
return make_unique<FfmpegVideoStream>(
ctx, idx, avType, mediaFormat, seekFrameMargin);
case MediaType::TYPE_AUDIO:
return make_unique<FfmpegAudioStream>(
ctx, idx, avType, mediaFormat, seekFrameMargin);
default:
return nullptr;
}
}
} // namespace
FfmpegAvioContext::FfmpegAvioContext()
: workBuffersize_(VIO_BUFFER_SZ),
workBuffer_((uint8_t*)av_malloc(workBuffersize_)),
inputFile_(nullptr),
inputBuffer_(nullptr),
inputBufferSize_(0) {}
int FfmpegAvioContext::initAVIOContext(const uint8_t* buffer, int64_t size) {
inputBuffer_ = buffer;
inputBufferSize_ = size;
avioCtx_ = avio_alloc_context(
workBuffer_,
workBuffersize_,
0,
reinterpret_cast<void*>(this),
&FfmpegAvioContext::readMemory,
nullptr, // no write function
&FfmpegAvioContext::seekMemory);
return 0;
}
FfmpegAvioContext::~FfmpegAvioContext() {
/* note: the internal buffer could have changed, and be != workBuffer_ */
if (avioCtx_) {
av_freep(&avioCtx_->buffer);
av_freep(&avioCtx_);
} else {
av_freep(&workBuffer_);
}
if (inputFile_) {
fclose(inputFile_);
}
}
int FfmpegAvioContext::read(uint8_t* buf, int buf_size) {
if (inputBuffer_) {
return readMemory(this, buf, buf_size);
} else {
return -1;
}
}
int FfmpegAvioContext::readMemory(void* opaque, uint8_t* buf, int buf_size) {
FfmpegAvioContext* h = static_cast<FfmpegAvioContext*>(opaque);
if (buf_size < 0) {
return -1;
}
int remainder = h->inputBufferSize_ - h->offset_;
int r = buf_size < remainder ? buf_size : remainder;
if (r < 0) {
return AVERROR_EOF;
}
memcpy(buf, h->inputBuffer_ + h->offset_, r);
h->offset_ += r;
return r;
}
int64_t FfmpegAvioContext::seek(int64_t offset, int whence) {
if (inputBuffer_) {
return seekMemory(this, offset, whence);
} else {
return -1;
}
}
int64_t FfmpegAvioContext::seekMemory(
void* opaque,
int64_t offset,
int whence) {
FfmpegAvioContext* h = static_cast<FfmpegAvioContext*>(opaque);
switch (whence) {
case SEEK_CUR: // from current position
h->offset_ += offset;
break;
case SEEK_END: // from eof
h->offset_ = h->inputBufferSize_ + offset;
break;
case SEEK_SET: // from beginning of file
h->offset_ = offset;
break;
case AVSEEK_SIZE:
return h->inputBufferSize_;
}
return h->offset_;
}
int FfmpegDecoder::init(
const std::string& filename,
bool isDecodeFile,
FfmpegAvioContext& ioctx,
DecoderOutput& decoderOutput) {
cleanUp();
int ret = 0;
if (!isDecodeFile) {
formatCtx_ = avformat_alloc_context();
if (!formatCtx_) {
LOG(ERROR) << "avformat_alloc_context failed";
return -1;
}
formatCtx_->pb = ioctx.get_avio();
formatCtx_->flags |= AVFMT_FLAG_CUSTOM_IO;
// Determining the input format:
int probeSz = AVPROBE_SIZE + AVPROBE_PADDING_SIZE;
uint8_t* probe((uint8_t*)av_malloc(probeSz));
memset(probe, 0, probeSz);
int len = ioctx.read(probe, probeSz - AVPROBE_PADDING_SIZE);
if (len < probeSz - AVPROBE_PADDING_SIZE) {
LOG(ERROR) << "Insufficient data to determine video format";
av_freep(&probe);
return -1;
}
// seek back to start of stream
ioctx.seek(0, SEEK_SET);
unique_ptr<AVProbeData> probeData(new AVProbeData());
probeData->buf = probe;
probeData->buf_size = len;
probeData->filename = "";
// Determine the input-format:
formatCtx_->iformat = av_probe_input_format(probeData.get(), 1);
// this is to avoid the double-free error
if (formatCtx_->iformat == nullptr) {
LOG(ERROR) << "av_probe_input_format fails";
return -1;
}
VLOG(1) << "av_probe_input_format succeeds";
av_freep(&probe);
ret = avformat_open_input(&formatCtx_, "", nullptr, nullptr);
} else {
ret = avformat_open_input(&formatCtx_, filename.c_str(), nullptr, nullptr);
}
if (ret < 0) {
LOG(ERROR) << "avformat_open_input failed, error: "
<< ffmpeg_util::getErrorDesc(ret);
cleanUp();
return ret;
}
ret = avformat_find_stream_info(formatCtx_, nullptr);
if (ret < 0) {
LOG(ERROR) << "avformat_find_stream_info failed, error: "
<< ffmpeg_util::getErrorDesc(ret);
cleanUp();
return ret;
}
if (!initStreams()) {
LOG(ERROR) << "Cannot activate streams";
cleanUp();
return -1;
}
for (auto& stream : streams_) {
MediaType mediaType = stream.second->getMediaType();
decoderOutput.initMediaType(mediaType, stream.second->getMediaFormat());
}
VLOG(1) << "FfmpegDecoder initialized";
return 0;
}
int FfmpegDecoder::decodeFile(
unique_ptr<DecoderParameters> params,
const string& fileName,
DecoderOutput& decoderOutput) {
VLOG(1) << "decode file: " << fileName;
FfmpegAvioContext ioctx;
int ret = decodeLoop(std::move(params), fileName, true, ioctx, decoderOutput);
return ret;
}
int FfmpegDecoder::decodeMemory(
unique_ptr<DecoderParameters> params,
const uint8_t* buffer,
int64_t size,
DecoderOutput& decoderOutput) {
VLOG(1) << "decode video data in memory";
FfmpegAvioContext ioctx;
int ret = ioctx.initAVIOContext(buffer, size);
if (ret == 0) {
ret =
decodeLoop(std::move(params), string(""), false, ioctx, decoderOutput);
}
return ret;
}
void FfmpegDecoder::cleanUp() {
if (formatCtx_) {
for (auto& stream : streams_) {
// Drain stream buffers.
DecoderOutput decoderOutput;
stream.second->flush(1, decoderOutput);
stream.second.reset();
}
streams_.clear();
avformat_close_input(&formatCtx_);
}
}
FfmpegStream* FfmpegDecoder::findStreamByIndex(int streamIndex) const {
auto it = streams_.find(streamIndex);
return it != streams_.end() ? it->second.get() : nullptr;
}
/*
Reference implementation:
https://ffmpeg.org/doxygen/3.4/demuxing_decoding_8c-example.html
*/
int FfmpegDecoder::decodeLoop(
unique_ptr<DecoderParameters> params,
const std::string& filename,
bool isDecodeFile,
FfmpegAvioContext& ioctx,
DecoderOutput& decoderOutput) {
params_ = std::move(params);
int ret = init(filename, isDecodeFile, ioctx, decoderOutput);
if (ret < 0) {
return ret;
}
// init packet
av_init_packet(&avPkt);
avPkt.data = nullptr;
avPkt.size = 0;
int result = 0;
bool ptsInRange = true;
while (ptsInRange) {
result = av_read_frame(formatCtx_, &avPkt);
if (result == AVERROR(EAGAIN)) {
VLOG(1) << "Decoder is busy";
ret = 0;
break;
} else if (result == AVERROR_EOF) {
VLOG(1) << "Stream decoding is completed";
ret = 0;
break;
} else if (result < 0) {
VLOG(1) << "av_read_frame fails. Break decoder loop. Error: "
<< ffmpeg_util::getErrorDesc(result);
ret = result;
break;
}
ret = 0;
auto stream = findStreamByIndex(avPkt.stream_index);
if (stream == nullptr) {
// the packet is from a stream the caller is not interested in. Ignore it
VLOG(2) << "avPkt ignored. stream index: " << avPkt.stream_index;
// Free the AVPacket's memory; otherwise it leaks
av_packet_unref(&avPkt);
continue;
}
do {
result = stream->sendPacket(&avPkt);
if (result == AVERROR(EAGAIN)) {
VLOG(2) << "avcodec_send_packet returns AVERROR(EAGAIN)";
// start to receive available frames from the internal buffer
stream->receiveAvailFrames(params_->getPtsOnly, decoderOutput);
if (isPtsExceedRange()) {
// exit the most-outer while loop
VLOG(1) << "In all streams, exceed the end pts. Exit decoding loop";
ret = 0;
ptsInRange = false;
break;
}
} else if (result < 0) {
LOG(WARNING) << "avcodec_send_packet failed. Error: "
<< ffmpeg_util::getErrorDesc(result);
ret = result;
break;
} else {
VLOG(2) << "avcodec_send_packet succeeds";
// success. Read the next AVPacket and send it out
break;
}
} while (ptsInRange);
// Free the AVPacket's memory; otherwise it leaks
av_packet_unref(&avPkt);
}
/* flush cached frames */
flushStreams(decoderOutput);
return ret;
}
bool FfmpegDecoder::initStreams() {
for (auto it = params_->formats.begin(); it != params_->formats.end(); ++it) {
AVMediaType mediaType;
if (!ffmpeg_util::mapMediaType(it->first, &mediaType)) {
LOG(ERROR) << "Unknown media type: " << it->first;
return false;
}
int streamIdx =
av_find_best_stream(formatCtx_, mediaType, -1, -1, nullptr, 0);
if (streamIdx >= 0) {
VLOG(2) << "find stream index: " << streamIdx;
auto stream = createFfmpegStream(
it->first,
formatCtx_,
streamIdx,
it->second,
params_->seekFrameMargin);
CHECK(stream);
if (stream->openCodecContext() < 0) {
LOG(ERROR) << "Cannot open codec. Stream index: " << streamIdx;
return false;
}
streams_.emplace(streamIdx, move(stream));
} else {
VLOG(1) << "Cannot open find stream of type " << it->first;
}
}
// Seek frames in each stream
int ret = 0;
for (auto& stream : streams_) {
auto startPts = stream.second->getStartPts();
VLOG(1) << "stream: " << stream.first << " startPts: " << startPts;
if (startPts > 0 && (ret = stream.second->seekFrame(startPts)) < 0) {
LOG(WARNING) << "seekFrame in stream fails";
return false;
}
}
VLOG(1) << "initStreams succeeds";
return true;
}
bool FfmpegDecoder::isPtsExceedRange() {
bool exceed = true;
for (auto& stream : streams_) {
exceed = exceed && stream.second->isFramePtsExceedRange();
}
return exceed;
}
void FfmpegDecoder::flushStreams(DecoderOutput& decoderOutput) {
for (auto& stream : streams_) {
stream.second->flush(params_->getPtsOnly, decoderOutput);
}
}
#pragma once
#include <string>
#include <vector>
#include "FfmpegHeaders.h"
#include "FfmpegStream.h"
#include "Interface.h"
#define VIO_BUFFER_SZ 81920
#define AVPROBE_SIZE 8192
class DecoderParameters {
public:
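// Map from media type to the requested format. The EnumClassHash functor
// (added in this commit) hashes enum class values so that MediaType can be
// used as the key type of std::unordered_map.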
std::unordered_map<MediaType, MediaFormat, EnumClassHash> formats;
// av_seek_frame is imprecise, so seek to a timestamp earlier by a margin.
// The margin unit is seconds
double seekFrameMargin{1.0};
// When getPtsOnly is set to 1, only the pts of each frame is returned and
// frame data is not output, which is much faster
int64_t getPtsOnly{0};
};
class FfmpegAvioContext {
public:
FfmpegAvioContext();
int initAVIOContext(const uint8_t* buffer, int64_t size);
~FfmpegAvioContext();
int read(uint8_t* buf, int buf_size);
static int readMemory(void* opaque, uint8_t* buf, int buf_size);
int64_t seek(int64_t offset, int whence);
static int64_t seekMemory(void* opaque, int64_t offset, int whence);
AVIOContext* get_avio() {
return avioCtx_;
}
private:
int workBuffersize_;
uint8_t* workBuffer_;
// for file mode
FILE* inputFile_;
// for memory mode
const uint8_t* inputBuffer_;
int inputBufferSize_;
int offset_ = 0;
AVIOContext* avioCtx_{nullptr};
};
class FfmpegDecoder {
public:
FfmpegDecoder() {
av_register_all();
}
~FfmpegDecoder() {
cleanUp();
}
// return 0 on success
// return negative number on failure
int decodeFile(
std::unique_ptr<DecoderParameters> params,
const std::string& filename,
DecoderOutput& decoderOutput);
// return 0 on success
// return negative number on failure
int decodeMemory(
std::unique_ptr<DecoderParameters> params,
const uint8_t* buffer,
int64_t size,
DecoderOutput& decoderOutput);
void cleanUp();
private:
FfmpegStream* findStreamByIndex(int streamIndex) const;
int init(
const std::string& filename,
bool isDecodeFile,
FfmpegAvioContext& ioctx,
DecoderOutput& decoderOutput);
// return 0 on success
// return negative number on failure
int decodeLoop(
std::unique_ptr<DecoderParameters> params,
const std::string& filename,
bool isDecodeFile,
FfmpegAvioContext& ioctx,
DecoderOutput& decoderOutput);
bool initStreams();
void flushStreams(DecoderOutput& decoderOutput);
// whether the pts of the most recent frame exceeds the range in all streams
bool isPtsExceedRange();
std::unordered_map<int, std::unique_ptr<FfmpegStream>> streams_;
AVFormatContext* formatCtx_{nullptr};
std::unique_ptr<DecoderParameters> params_{nullptr};
};
#pragma once
extern "C" {
#include <libavcodec/avcodec.h>
#include <libavformat/avformat.h>
#include <libavformat/avio.h>
#include <libavutil/avutil.h>
#include <libavutil/imgutils.h>
#include <libavutil/log.h>
#include <libavutil/samplefmt.h>
#include <libswresample/swresample.h>
#include <libswscale/swscale.h>
}
#pragma once
#include "FfmpegHeaders.h"
#include "Interface.h"
/**
* Class that samples data from an AVFrame
*/
class FfmpegSampler {
public:
virtual ~FfmpegSampler() = default;
// return 0 on success and negative number on failure
virtual int init() = 0;
// sample from the given frame
virtual std::unique_ptr<DecodedFrame> sample(const AVFrame* frame) = 0;
};