Unverified Commit 87c78641 authored by Bruno Korbar, committed by GitHub

(WIP) Initial implementation of the new videoReader API (#2683)



* adding base files

* setup modification to actually build the thing

* video api constructor registration

* FAIL metadata

* FAIL update for QS

* revert

* debugging with Victor

* metadata registration works

* API build next

* test

* Merge change

* formatting parameters to avoid the segfault

* next now works on a video

* make size of the output tensor format dependent

* Make next work on audio stream only as well

* refactoring the _setCurrentStream param

* Fixing the last frame return and sensor

* todo docs

* Formatting

* cleanup and comments

* introducing new tests for the API

* cleanup

* Comment out unnecessary format (will add following FFMPEG fix)

* Reformat parsing function

* removing the seek bug `get_decoder_params`

* Removing unnecessary code/variables

* enforce RGB24 as a reading format (will crash before ffmpeg fix)

* permute the dimensions to return (RGB x H x W)

* Changing the return type to std::tuple<torch::Tensor, double> as opposed to tensor list

* Adjusting tests for the new return type

* remove unnecessary jitter

* clangangangang

* Metadata return changes (#1)

* remove implicit calls to set a current stream (#2)

* Adding new tests to check the accuracy of the seek

* cleanup debugging statements

* Addressing PR comments

* addressing Francisco's comments

* CLANG build formatting

* Updated testing to test against pyav for the video tensor reads

* Formatting

* remove pyav from pip deps and add it to conda build

* add pyav and ffmpeg to conda builds

* Formatting?

* Setting up linter once and for all hopefully

* Testing pyav

* Fix to 8.0.0

* Try 6.2.0

* See what happens with av from pip

* Remove FFMPEG blocker

* What is going on?

* More tests

* Forgot something

* unblocker

* Check if cache is messing up with things

* Now try with different ffmpeg

* Now try with different ffmpeg

* Do not install av

* Test with ffmpeg 4.2

* clean up video tests

* cleaning up the tests a bit to better test partial reading

* arrgh linter

* Forgot the av test

* forgot av test

* checkout build files from master

* revert circleci

* addressing Francisco's comments

* addressing Francisco's comments

* Ignore ffmpeg in travis
Co-authored-by: Francisco Massa <fvsmassa@gmail.com>
Co-authored-by: Edgar Andrés Margffoy Tuay <andfoy@gmail.com>
parent 754c954f
...
@@ -38,7 +38,6 @@ before_install:
       fi
   - conda install av -c conda-forge
 install:
   # Using pip instead of setup.py ensures we install a non-compressed version of the package
   # (as opposed to an egg), which is necessary to collect coverage.
...
@@ -55,7 +54,7 @@ install:
       cd -
 script:
-  - pytest --cov-config .coveragerc --cov torchvision --cov $TV_INSTALL_PATH -k 'not TestVideoReader and not TestVideoTransforms and not TestIO' test --ignore=test/test_datasets_download.py
+  - pytest --cov-config .coveragerc --cov torchvision --cov $TV_INSTALL_PATH -k 'not TestVideo and not TestVideoReader and not TestVideoTransforms and not TestIO' test --ignore=test/test_datasets_download.py
   - pytest test/test_hub.py
 after_success:
...
...
@@ -347,10 +347,13 @@ def get_extensions():
     base_decoder_src_dir = os.path.join(this_dir, 'torchvision', 'csrc', 'cpu', 'decoder')
     base_decoder_src = glob.glob(
         os.path.join(base_decoder_src_dir, "*.cpp"))
+    # Torchvision video API
+    videoapi_src_dir = os.path.join(this_dir, 'torchvision', 'csrc', 'cpu', 'video')
+    videoapi_src = glob.glob(os.path.join(videoapi_src_dir, "*.cpp"))
     # exclude tests
     base_decoder_src = [x for x in base_decoder_src if '_test.cpp' not in x]
-    combined_src = video_reader_src + base_decoder_src
+    combined_src = video_reader_src + base_decoder_src + videoapi_src
     ext_modules.append(
         CppExtension(
@@ -359,6 +362,7 @@ def get_extensions():
             include_dirs=[
                 base_decoder_src_dir,
                 video_reader_src_dir,
+                videoapi_src_dir,
                 ffmpeg_include_dir,
                 extensions_dir,
             ],
...
import os
import collections
import contextlib
import tempfile
import unittest
import random
import numpy as np
import torch
import torchvision
from torchvision.io import _HAS_VIDEO_OPT
try:
import av
# Do a version test too
torchvision.io.video._check_av_available()
except ImportError:
av = None
VIDEO_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "videos")
CheckerConfig = [
"duration",
"video_fps",
"audio_sample_rate",
# We find that for some videos (e.g. HMDB51 videos), the decoded audio frames and pts
# differ slightly between the TorchVision and PyAv decoders, so we omit them during the check
"check_aframes",
"check_aframe_pts",
]
GroundTruth = collections.namedtuple("GroundTruth", " ".join(CheckerConfig))
all_check_config = GroundTruth(
duration=0,
video_fps=0,
audio_sample_rate=0,
check_aframes=True,
check_aframe_pts=True,
)
test_videos = {
"RATRACE_wave_f_nm_np1_fr_goo_37.avi": GroundTruth(
duration=2.0,
video_fps=30.0,
audio_sample_rate=None,
check_aframes=True,
check_aframe_pts=True,
),
"SchoolRulesHowTheyHelpUs_wave_f_nm_np1_ba_med_0.avi": GroundTruth(
duration=2.0,
video_fps=30.0,
audio_sample_rate=None,
check_aframes=True,
check_aframe_pts=True,
),
"TrumanShow_wave_f_nm_np1_fr_med_26.avi": GroundTruth(
duration=2.0,
video_fps=30.0,
audio_sample_rate=None,
check_aframes=True,
check_aframe_pts=True,
),
"v_SoccerJuggling_g23_c01.avi": GroundTruth(
duration=8.0,
video_fps=29.97,
audio_sample_rate=None,
check_aframes=True,
check_aframe_pts=True,
),
"v_SoccerJuggling_g24_c01.avi": GroundTruth(
duration=8.0,
video_fps=29.97,
audio_sample_rate=None,
check_aframes=True,
check_aframe_pts=True,
),
# The last three tests segfault on the video reader (see issues)
"R6llTwEh07w.mp4": GroundTruth(
duration=10.0,
video_fps=30.0,
audio_sample_rate=44100,
# PyAv misses one audio frame at the beginning (pts=0)
check_aframes=False,
check_aframe_pts=False,
),
"SOX5yA1l24A.mp4": GroundTruth(
duration=11.0,
video_fps=29.97,
audio_sample_rate=48000,
# PyAv misses one audio frame at the beginning (pts=0)
check_aframes=False,
check_aframe_pts=False,
),
"WUzgd7C1pWA.mp4": GroundTruth(
duration=11.0,
video_fps=29.97,
audio_sample_rate=48000,
# PyAv misses one audio frame at the beginning (pts=0)
check_aframes=False,
check_aframe_pts=False,
),
}
DecoderResult = collections.namedtuple(
"DecoderResult", "vframes vframe_pts vtimebase aframes aframe_pts atimebase"
)
def _read_from_stream(
container, start_pts, end_pts, stream, stream_name, buffer_size=4
):
"""
Args:
container: pyav container
start_pts/end_pts: the starting/ending Presentation TimeStamp where
frames are read
stream: pyav stream
stream_name: a dictionary of streams. For example, {"video": 0} means
video stream at stream index 0
buffer_size: the pts of frames decoded by PyAv are not guaranteed to be in
ascending order, so we keep decoding a few more frames even after reaching
end_pts
"""
# seeking in the stream is imprecise. Thus, seek to an earlier PTS by a margin
margin = 1
seek_offset = max(start_pts - margin, 0)
container.seek(seek_offset, any_frame=False, backward=True, stream=stream)
frames = {}
buffer_count = 0
for frame in container.decode(**stream_name):
if frame.pts < start_pts:
continue
if frame.pts <= end_pts:
frames[frame.pts] = frame
else:
buffer_count += 1
if buffer_count >= buffer_size:
break
result = [frames[pts] for pts in sorted(frames)]
return result
def _fraction_to_tensor(fraction):
ret = torch.zeros([2], dtype=torch.int32)
ret[0] = fraction.numerator
ret[1] = fraction.denominator
return ret
def _decode_frames_by_av_module(
full_path,
video_start_pts=0,
video_end_pts=None,
audio_start_pts=0,
audio_end_pts=None,
):
"""
Use PyAv to decode video frames. This provides a reference for our decoder
to compare the decoding results.
Input arguments:
full_path: video file path
video_start_pts/video_end_pts: the starting/ending Presentation TimeStamp where
frames are read
"""
if video_end_pts is None:
video_end_pts = float("inf")
if audio_end_pts is None:
audio_end_pts = float("inf")
container = av.open(full_path)
video_frames = []
vtimebase = torch.zeros([0], dtype=torch.int32)
if container.streams.video:
video_frames = _read_from_stream(
container,
video_start_pts,
video_end_pts,
container.streams.video[0],
{"video": 0},
)
# container.streams.video[0].average_rate is not a reliable estimator of the
# frame rate. It can be wrong for certain codecs, such as VP80.
# So we do not return the video fps here
vtimebase = _fraction_to_tensor(container.streams.video[0].time_base)
audio_frames = []
atimebase = torch.zeros([0], dtype=torch.int32)
if container.streams.audio:
audio_frames = _read_from_stream(
container,
audio_start_pts,
audio_end_pts,
container.streams.audio[0],
{"audio": 0},
)
atimebase = _fraction_to_tensor(container.streams.audio[0].time_base)
container.close()
vframes = [frame.to_rgb().to_ndarray() for frame in video_frames]
vframes = torch.as_tensor(np.stack(vframes))
vframe_pts = torch.tensor([frame.pts for frame in video_frames], dtype=torch.int64)
aframes = [frame.to_ndarray() for frame in audio_frames]
if aframes:
aframes = np.transpose(np.concatenate(aframes, axis=1))
aframes = torch.as_tensor(aframes)
else:
aframes = torch.empty((1, 0), dtype=torch.float32)
aframe_pts = torch.tensor(
[audio_frame.pts for audio_frame in audio_frames], dtype=torch.int64
)
return DecoderResult(
vframes=vframes.permute(0, 3, 1, 2),
vframe_pts=vframe_pts,
vtimebase=vtimebase,
aframes=aframes,
aframe_pts=aframe_pts,
atimebase=atimebase,
)
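# Note (added for clarity): vframes above are returned as an N x C x H x W uint8
# tensor (the permute moves channels first), matching the per-frame C x H x W
# layout produced by the new video API.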
def _template_read_video(video_object, s=0, e=None):
if e is None:
e = float("inf")
if e < s:
raise ValueError(
"end time should be larger than start time, got "
"start time={} and end time={}".format(s, e)
)
video_object.set_current_stream("video")
video_object.seek(s)
video_frames = torch.empty(0)
frames = []
video_pts = []
t, pts = video_object.next()
while t.numel() > 0 and (pts >= s and pts <= e):
frames.append(t)
video_pts.append(pts)
t, pts = video_object.next()
if len(frames) > 0:
video_frames = torch.stack(frames, 0)
video_object.set_current_stream("audio")
video_object.seek(s)
audio_frames = torch.empty(0)
frames = []
audio_pts = []
t, pts = video_object.next()
while t.numel() > 0 and (pts > s and pts <= e):
frames.append(t)
audio_pts.append(pts)
t, pts = video_object.next()
if len(frames) > 0:
audio_frames = torch.stack(frames, 0)
return DecoderResult(
vframes=video_frames,
vframe_pts=video_pts,
vtimebase=None,
aframes=audio_frames,
aframe_pts=audio_pts,
atimebase=None,
)
return video_frames, audio_frames, video_object.get_metadata()
@unittest.skipIf(_HAS_VIDEO_OPT is False, "Didn't compile with ffmpeg")
class TestVideo(unittest.TestCase):
@unittest.skipIf(av is None, "PyAV unavailable")
def test_read_video_tensor(self):
"""
Check if reading the video using the `next` based API yields the
same sized tensors as the pyav alternative.
"""
torchvision.set_video_backend("pyav")
for test_video, config in test_videos.items():
full_path = os.path.join(VIDEO_DIR, test_video)
# pass 1: decode all frames using existing TV decoder
tv_result, _, _ = torchvision.io.read_video(full_path, pts_unit="sec")
tv_result = tv_result.permute(0, 3, 1, 2)
# pass 2: decode all frames using new api
reader = torch.classes.torchvision.Video(full_path, "video")
frames = []
t, _ = reader.next()
while t.numel() > 0:
frames.append(t)
t, _ = reader.next()
new_api = torch.stack(frames, 0)
self.assertEqual(tv_result.size(), new_api.size())
# def test_partial_video_reading_fn(self):
# torchvision.set_video_backend("video_reader")
# for test_video, config in test_videos.items():
# full_path = os.path.join(VIDEO_DIR, test_video)
# # select two random points between 0 and duration
# r = []
# r.append(random.uniform(0, config.duration))
# r.append(random.uniform(0, config.duration))
# s = min(r)
# e = max(r)
# reader = torch.classes.torchvision.Video(full_path, "video")
# results = _template_read_video(reader, s, e)
# tv_video, tv_audio, info = torchvision.io.read_video(
# full_path, start_pts=s, end_pts=e, pts_unit="sec"
# )
# self.assertAlmostEqual(tv_video.size(0), results.vframes.size(0), delta=2.0)
# def test_pts(self):
# """
# Check that the pts of every frame read by the new API matches the pyav timestamps
# """
# torchvision.set_video_backend("video_reader")
# for test_video, config in test_videos.items():
# full_path = os.path.join(VIDEO_DIR, test_video)
# tv_timestamps, _ = torchvision.io.read_video_timestamps(
# full_path, pts_unit="sec"
# )
# # pass 2: decode all frames using new api
# reader = torch.classes.torchvision.Video(full_path, "video")
# pts = []
# t, p = reader.next()
# while t.numel() > 0:
# pts.append(p)
# t, p = reader.next()
# tv_timestamps = [float(p) for p in tv_timestamps]
# napi_pts = [float(p) for p in pts]
# for i in range(len(napi_pts)):
# self.assertAlmostEqual(napi_pts[i], tv_timestamps[i], delta=0.001)
# # check if pts of video frames are sorted in ascending order
# for i in range(len(napi_pts) - 1):
# self.assertEqual(napi_pts[i] < napi_pts[i + 1], True)
@unittest.skipIf(av is None, "PyAV unavailable")
def test_metadata(self):
"""
Test that the metadata returned via pyav corresponds to the one returned
by the new video decoder API
"""
torchvision.set_video_backend("pyav")
for test_video, config in test_videos.items():
full_path = os.path.join(VIDEO_DIR, test_video)
reader = torch.classes.torchvision.Video(full_path, "video")
reader_md = reader.get_metadata()
self.assertAlmostEqual(
config.video_fps, reader_md["video"]["fps"][0], delta=0.0001
)
self.assertAlmostEqual(
config.duration, reader_md["video"]["duration"][0], delta=0.5
)
@unittest.skipIf(av is None, "PyAV unavailable")
def test_video_reading_fn(self):
"""
Test that the pyav and ffmpeg decoder outputs are mostly the same
"""
for test_video, config in test_videos.items():
full_path = os.path.join(VIDEO_DIR, test_video)
ref_result = _decode_frames_by_av_module(full_path)
reader = torch.classes.torchvision.Video(full_path, "video")
newapi_result = _template_read_video(reader)
# First we check if the frames are approximately the same
# (note that every codec context has signature artefacts which
# make a direct comparison not feasible)
if newapi_result.vframes.numel() > 0 and ref_result.vframes.numel() > 0:
mean_delta = torch.mean(
torch.abs(
newapi_result.vframes.float() - ref_result.vframes.float()
)
)
self.assertAlmostEqual(mean_delta, 0, delta=8.0)
# Just a sanity check: are the two of the correct size?
self.assertEqual(newapi_result.vframes.size(), ref_result.vframes.size())
# Lastly, we compare the resulting audio streams
if (
config.check_aframes
and newapi_result.aframes.numel() > 0
and ref_result.aframes.numel() > 0
):
"""Audio stream is available and audio frame is required to return
from decoder"""
is_same = torch.all(
torch.eq(newapi_result.aframes, ref_result.aframes)
).item()
self.assertEqual(is_same, True)
if __name__ == "__main__":
unittest.main()
#include "Video.h"
#include <c10/util/Logging.h>
#include <torch/script.h>
#include "defs.h"
#include "memory_buffer.h"
#include "sync_decoder.h"
using namespace std;
using namespace ffmpeg;
// If we are in a Windows environment, we need to define
// initialization functions for the _custom_ops extension
// #ifdef _WIN32
// #if PY_MAJOR_VERSION < 3
// PyMODINIT_FUNC init_video_reader(void) {
// // No need to do anything.
// return NULL;
// }
// #else
// PyMODINIT_FUNC PyInit_video_reader(void) {
// // No need to do anything.
// return NULL;
// }
// #endif
// #endif
const size_t decoderTimeoutMs = 600000;
const AVPixelFormat defaultVideoPixelFormat = AV_PIX_FMT_RGB24;
const AVSampleFormat defaultAudioSampleFormat = AV_SAMPLE_FMT_FLT;
// copies the decoded payload into the frame tensor; returns the element size in bytes
template <typename T>
size_t fillTensorList(DecoderOutputMessage& msgs, torch::Tensor& frame) {
const auto& msg = msgs;
T* frameData = frame.numel() > 0 ? frame.data_ptr<T>() : nullptr;
if (frameData) {
auto sizeInBytes = msg.payload->length();
memcpy(frameData, msg.payload->data(), sizeInBytes);
}
return sizeof(T);
}
size_t fillVideoTensor(DecoderOutputMessage& msgs, torch::Tensor& videoFrame) {
return fillTensorList<uint8_t>(msgs, videoFrame);
}
size_t fillAudioTensor(DecoderOutputMessage& msgs, torch::Tensor& audioFrame) {
return fillTensorList<float>(msgs, audioFrame);
}
std::pair<std::string, ffmpeg::MediaType> const* _parse_type(
const std::string& stream_string) {
static const std::array<std::pair<std::string, MediaType>, 4> types = {{
{"video", TYPE_VIDEO},
{"audio", TYPE_AUDIO},
{"subtitle", TYPE_SUBTITLE},
{"cc", TYPE_CC},
}};
auto device = std::find_if(
types.begin(),
types.end(),
[stream_string](const std::pair<std::string, MediaType>& p) {
return p.first == stream_string;
});
if (device != types.end()) {
return device;
}
AT_ERROR("Expected one of [audio, video, subtitle, cc] ", stream_string);
}
std::string parse_type_to_string(const std::string& stream_string) {
auto device = _parse_type(stream_string);
return device->first;
}
MediaType parse_type_to_mt(const std::string& stream_string) {
auto device = _parse_type(stream_string);
return device->second;
}
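// Illustrative note: _parseStream below accepts strings such as "video" or
// "audio:0"; the optional index defaults to -1 (automatic selection) when it
// is omitted, e.g. "video" -> ("video", -1) and "audio:0" -> ("audio", 0).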
std::tuple<std::string, long> _parseStream(const std::string& streamString) {
TORCH_CHECK(!streamString.empty(), "Stream string must not be empty");
static const std::regex regex("([a-zA-Z_]+)(?::([1-9]\\d*|0))?");
std::smatch match;
TORCH_CHECK(
std::regex_match(streamString, match, regex),
"Invalid stream string: '",
streamString,
"'");
std::string type_ = "video";
type_ = parse_type_to_string(match[1].str());
long index_ = -1;
if (match[2].matched) {
try {
index_ = c10::stoi(match[2].str());
} catch (const std::exception&) {
AT_ERROR(
"Could not parse device index '",
match[2].str(),
"' in device string '",
streamString,
"'");
}
}
return std::make_tuple(type_, index_);
}
void Video::_getDecoderParams(
double videoStartS,
int64_t getPtsOnly,
std::string stream,
long stream_id = -1,
bool all_streams = false,
double seekFrameMarginUs = 10) {
int64_t videoStartUs = int64_t(videoStartS * 1e6);
params.timeoutMs = decoderTimeoutMs;
params.startOffset = videoStartUs;
params.seekAccuracy = seekFrameMarginUs;
params.headerOnly = false;
params.preventStaleness = false; // not sure what this is about
if (all_streams == true) {
MediaFormat format;
format.stream = -2;
format.type = TYPE_AUDIO;
params.formats.insert(format);
format.type = TYPE_VIDEO;
format.stream = -2;
format.format.video.width = 0;
format.format.video.height = 0;
format.format.video.cropImage = 0;
format.format.video.format = defaultVideoPixelFormat;
params.formats.insert(format);
format.type = TYPE_SUBTITLE;
format.stream = -2;
params.formats.insert(format);
format.type = TYPE_CC;
format.stream = -2;
params.formats.insert(format);
} else {
// parse stream type
MediaType stream_type = parse_type_to_mt(stream);
// TODO: reset params.formats
std::set<MediaFormat> formats;
params.formats = formats;
// Define new format
MediaFormat format;
format.type = stream_type;
format.stream = stream_id;
if (stream_type == TYPE_VIDEO) {
format.format.video.width = 0;
format.format.video.height = 0;
format.format.video.cropImage = 0;
format.format.video.format = defaultVideoPixelFormat;
}
params.formats.insert(format);
}
} // _getDecoderParams
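// Illustrative call (assumption), mirroring setCurrentStream below: request only
// the video stream with automatic stream selection, starting 1.5 s into the file:
//   _getDecoderParams(1.5, 0, "video", long(-1), false);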
Video::Video(std::string videoPath, std::string stream) {
// parse stream information
current_stream = _parseStream(stream);
// note that in the initial call we want to get all streams
Video::_getDecoderParams(
0, // video start
0, // headerOnly
get<0>(current_stream), // stream info - remove that
long(-1), // stream_id parsed from info above change to -2
true // read all streams
);
std::string logMessage, logType;
// TODO: add read from memory option
params.uri = videoPath;
logType = "file";
logMessage = videoPath;
// locals
std::vector<double> audioFPS, videoFPS, ccFPS, subsFPS;
std::vector<double> audioDuration, videoDuration, ccDuration, subsDuration;
std::vector<double> audioTB, videoTB, ccTB, subsTB;
c10::Dict<std::string, std::vector<double>> audioMetadata;
c10::Dict<std::string, std::vector<double>> videoMetadata;
// callback and metadata defined in struct
succeeded = decoder.init(params, std::move(callback), &metadata);
if (succeeded) {
for (const auto& header : metadata) {
double fps = double(header.fps);
double duration = double(header.duration) * 1e-6; // * timeBase;
if (header.format.type == TYPE_VIDEO) {
videoFPS.push_back(fps);
videoDuration.push_back(duration);
} else if (header.format.type == TYPE_AUDIO) {
audioFPS.push_back(fps);
audioDuration.push_back(duration);
} else if (header.format.type == TYPE_CC) {
ccFPS.push_back(fps);
ccDuration.push_back(duration);
} else if (header.format.type == TYPE_SUBTITLE) {
subsFPS.push_back(fps);
subsDuration.push_back(duration);
};
}
}
audioMetadata.insert("duration", audioDuration);
audioMetadata.insert("framerate", audioFPS);
videoMetadata.insert("duration", videoDuration);
videoMetadata.insert("fps", videoFPS);
streamsMetadata.insert("video", videoMetadata);
streamsMetadata.insert("audio", audioMetadata);
succeeded = Video::setCurrentStream(stream);
LOG(INFO) << "\nDecoder inited with: " << succeeded << "\n";
if (get<1>(current_stream) != -1) {
LOG(INFO)
<< "Stream index set to " << get<1>(current_stream)
<< ". If you encounter trouble, consider switching it to automatic stream discovery. \n";
}
} // video
bool Video::setCurrentStream(std::string stream = "video") {
if ((!stream.empty()) && (_parseStream(stream) != current_stream)) {
current_stream = _parseStream(stream);
}
double ts = 0;
if (seekTS > 0) {
ts = seekTS;
}
_getDecoderParams(
ts, // video start
0, // headerOnly
get<0>(current_stream), // stream
long(get<1>(
current_stream)), // stream_id parsed from info above change to -2
false // read all streams
);
// callback and metadata defined in Video.h
return (decoder.init(params, std::move(callback), &metadata));
}
std::tuple<std::string, int64_t> Video::getCurrentStream() const {
return current_stream;
}
c10::Dict<std::string, c10::Dict<std::string, std::vector<double>>> Video::
getStreamMetadata() const {
return streamsMetadata;
}
void Video::Seek(double ts) {
// initialize the class variables used for seeking and return
_getDecoderParams(
ts, // video start
0, // headerOnly
get<0>(current_stream), // stream
long(get<1>(
current_stream)), // stream_id parsed from info above change to -2
false // read all streams
);
// callback and metadata defined in Video.h
succeeded = decoder.init(params, std::move(callback), &metadata);
LOG(INFO) << "Decoder init at seek " << succeeded << "\n";
}
std::tuple<torch::Tensor, double> Video::Next() {
// if decoding fails we simply return an empty tensor (note: should we
// raise an exception instead?)
double frame_pts_s = 0; // initialized so a failed decode does not return an undefined pts
torch::Tensor outFrame = torch::zeros({0}, torch::kByte);
// decode single frame
DecoderOutputMessage out;
int64_t res = decoder.decode(&out, decoderTimeoutMs);
// if successful
if (res == 0) {
frame_pts_s = double(double(out.header.pts) * 1e-6);
auto header = out.header;
const auto& format = header.format;
// initialize the output variables based on type
if (format.type == TYPE_VIDEO) {
// note: this can potentially be optimized
// by having the global tensor that we fill at decode time
// (would avoid allocations)
int outHeight = format.format.video.height;
int outWidth = format.format.video.width;
int numChannels = 3;
outFrame = torch::zeros({outHeight, outWidth, numChannels}, torch::kByte);
auto numberWrittenBytes = fillVideoTensor(out, outFrame);
outFrame = outFrame.permute({2, 0, 1});
} else if (format.type == TYPE_AUDIO) {
int outAudioChannels = format.format.audio.channels;
int bytesPerSample = av_get_bytes_per_sample(
static_cast<AVSampleFormat>(format.format.audio.format));
int frameSizeTotal = out.payload->length();
CHECK_EQ(frameSizeTotal % (outAudioChannels * bytesPerSample), 0);
int numAudioSamples =
frameSizeTotal / (outAudioChannels * bytesPerSample);
outFrame =
torch::zeros({numAudioSamples, outAudioChannels}, torch::kFloat);
auto numberWrittenBytes = fillAudioTensor(out, outFrame);
}
// currently not supporting other formats (will do soon)
out.payload.reset();
} else if (res == 61) {
LOG(INFO) << "Decoder ran out of frames (error 61)\n";
} else {
LOG(ERROR) << "Decoder failed with ERROR_CODE " << res;
}
std::tuple<torch::Tensor, double> result = {outFrame, frame_pts_s};
return result;
}
#pragma once
#include <map>
#include <regex>
#include <string>
#include <vector>
#include <ATen/ATen.h>
#include <Python.h>
#include <c10/util/Logging.h>
#include <torch/script.h>
#include <exception>
#include "defs.h"
#include "memory_buffer.h"
#include "sync_decoder.h"
using namespace ffmpeg;
struct Video : torch::CustomClassHolder {
std::tuple<std::string, long> current_stream; // stream type, id
// global video metadata
c10::Dict<std::string, c10::Dict<std::string, std::vector<double>>>
streamsMetadata;
public:
Video(std::string videoPath, std::string stream);
std::tuple<std::string, int64_t> getCurrentStream() const;
c10::Dict<std::string, c10::Dict<std::string, std::vector<double>>>
getStreamMetadata() const;
void Seek(double ts);
bool setCurrentStream(std::string stream);
std::tuple<torch::Tensor, double> Next();
private:
bool video_any_frame = false; // add this to input parameters?
bool succeeded = false; // decoder init flag
// seekTS and doSeek act as a flag - if it's not set, the next function simply
// returns the next frame. If it is set, we look at the global seek
// time in combination with the any_frame setting
double seekTS = -1;
bool doSeek = false;
void _getDecoderParams(
double videoStartS,
int64_t getPtsOnly,
std::string stream,
long stream_id,
bool all_streams,
double seekFrameMarginUs); // this needs to be improved
std::map<std::string, std::vector<double>> streamTimeBase; // not used
DecoderInCallback callback = nullptr;
std::vector<DecoderMetadata> metadata;
protected:
SyncDecoder decoder;
DecoderParameters params;
}; // struct Video
#ifndef REGISTER_H
#define REGISTER_H
#include "Video.h"
namespace {
static auto registerVideo =
torch::class_<Video>("torchvision", "Video")
.def(torch::init<std::string, std::string>())
.def("get_current_stream", &Video::getCurrentStream)
.def("set_current_stream", &Video::setCurrentStream)
.def("get_metadata", &Video::getStreamMetadata)
.def("seek", &Video::Seek)
.def("next", &Video::Next);
} // namespace
#endif
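For reference, a minimal usage sketch of the class registered above, mirroring how the new tests in this PR drive it; the file path is illustrative, and torchvision must have been built with ffmpeg support for the op to be available:

import torch
import torchvision  # importing torchvision makes the compiled video ops available

# illustrative path; any container with a video stream should work
reader = torch.classes.torchvision.Video("path/to/video.mp4", "video")
md = reader.get_metadata()  # e.g. md["video"]["fps"][0], md["video"]["duration"][0]
reader.seek(2.0)            # seek (in seconds) within the current stream
frames = []
frame, pts = reader.next()  # frame: C x H x W uint8 tensor, pts: float seconds
while frame.numel() > 0:
    frames.append(frame)
    frame, pts = reader.next()
video = torch.stack(frames, 0)      # N x C x H x W
reader.set_current_stream("audio")  # switch streams, e.g. to decode audio next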