Unverified Commit 4ccef06c authored by Bruno Korbar, committed by GitHub
Browse files

Fast seek implementation (#3179)



* modify processPacket to support fast seek

* add fastSeek to ProcessPacket decoder definition

* add fastseek flag to DecoderParametersStruct

* add fastseek flag to the process packet call

* no default params in C++ implementation

* enable flag in C++ implementation

* make order of parameters more normal

* register new seek with python api

* [somewhat broken] test suite for keyframes using pyav

* revert " changes

* add type annotations to init

* Adding tests

* linter

* Flake doesn't show up :|

* Change from unittest to pytest syntax

* add return type
Co-authored-by: Prabhat Roy <prabhatroy@fb.com>
parent b43353e8
...@@ -167,6 +167,43 @@ class TestVideoApi: ...@@ -167,6 +167,43 @@ class TestVideoApi:
assert metadata["subtitles"]["duration"] is not None assert metadata["subtitles"]["duration"] is not None
os.remove(video_path) os.remove(video_path)
@pytest.mark.skipif(av is None, reason="PyAV unavailable")
def test_keyframe_reading(self):
    """Check that keyframe-only seeking returns the keyframes PyAV reports.

    For every entry in ``test_videos`` the keyframe timestamps are first
    collected with PyAV (decoding with ``skip_frame = "NONKEY"``). Then
    ``VideoReader.seek(..., True)`` is called with the midpoint of each pair
    of consecutive keyframes and finally with ``config.duration``; the
    ``pts`` of the frame returned after each seek is recorded. The two lists
    must have the same length and, except for one known-divergent file,
    match element-wise within a relative tolerance of 0.001.
    """
    for test_video, config in test_videos.items():
        full_path = os.path.join(VIDEO_DIR, test_video)

        av_keyframes = []
        vr_keyframes = []

        # Open the container in a ``with`` block so it is closed even when
        # decoding raises; the original never closed it.
        with av.open(full_path) as av_reader:
            # Check that a video stream exists *before* indexing into the
            # stream list; indexing first would raise IndexError on files
            # without video and make this guard pointless.
            if av_reader.streams.video:
                av_stream = av_reader.streams.video[0]
                # reduce streams to only keyframes
                av_stream.codec_context.skip_frame = "NONKEY"

                # get all keyframes using pyav. Then, seek into video reader
                # and assert that all the returned values are in av_keyframes
                for av_frame in av_reader.decode(av_stream):
                    av_keyframes.append(float(av_frame.pts * av_frame.time_base))

        # With fewer than two keyframes there is no midpoint to seek to.
        if len(av_keyframes) > 1:
            video_reader = VideoReader(full_path, "video")
            for i in range(1, len(av_keyframes)):
                # Seek halfway between two consecutive keyframes.
                seek_val = (av_keyframes[i] + av_keyframes[i - 1]) / 2
                data = next(video_reader.seek(seek_val, True))
                vr_keyframes.append(data["pts"])

            # One more seek at the configured duration, so that the number of
            # recorded frames equals the number of keyframes.
            data = next(video_reader.seek(config.duration, True))
            vr_keyframes.append(data["pts"])

            assert len(av_keyframes) == len(vr_keyframes)
            # NOTE: this video gets different keyframe with different
            # loaders (0.333 pyav, 0.666 for us)
            if test_video != "TrumanShow_wave_f_nm_np1_fr_med_26.avi":
                for i in range(len(av_keyframes)):
                    assert av_keyframes[i] == approx(vr_keyframes[i], rel=0.001)
if __name__ == "__main__": if __name__ == "__main__":
pytest.main([__file__]) pytest.main([__file__])
...@@ -552,9 +552,9 @@ int Decoder::getFrame(size_t workingTimeInMs) { ...@@ -552,9 +552,9 @@ int Decoder::getFrame(size_t workingTimeInMs) {
bool gotFrame = false; bool gotFrame = false;
bool hasMsg = false; bool hasMsg = false;
// packet either got consumed completely or not at all // packet either got consumed completely or not at all
if ((result = processPacket(stream, &avPacket, &gotFrame, &hasMsg)) < 0) { if ((result = processPacket(
LOG(ERROR) << "uuid=" << params_.loggingUuid stream, &avPacket, &gotFrame, &hasMsg, params_.fastSeek)) < 0) {
<< " processPacket failed with code=" << result; LOG(ERROR) << "processPacket failed with code: " << result;
break; break;
} }
...@@ -635,7 +635,8 @@ int Decoder::processPacket( ...@@ -635,7 +635,8 @@ int Decoder::processPacket(
Stream* stream, Stream* stream,
AVPacket* packet, AVPacket* packet,
bool* gotFrame, bool* gotFrame,
bool* hasMsg) { bool* hasMsg,
bool fastSeek) {
// decode package // decode package
int result; int result;
DecoderOutputMessage msg; DecoderOutputMessage msg;
...@@ -648,7 +649,15 @@ int Decoder::processPacket( ...@@ -648,7 +649,15 @@ int Decoder::processPacket(
bool endInRange = bool endInRange =
params_.endOffset <= 0 || msg.header.pts <= params_.endOffset; params_.endOffset <= 0 || msg.header.pts <= params_.endOffset;
inRange_.set(stream->getIndex(), endInRange); inRange_.set(stream->getIndex(), endInRange);
if (endInRange && msg.header.pts >= params_.startOffset) { // if fastseek is enabled, we're returning the first
// frame that we decode after (potential) seek.
// By default, we perform accurate seek to the closest
// following frame
bool startCondition = true;
if (!fastSeek) {
startCondition = msg.header.pts >= params_.startOffset;
}
if (endInRange && startCondition) {
*hasMsg = true; *hasMsg = true;
push(std::move(msg)); push(std::move(msg));
} }
......
...@@ -72,7 +72,8 @@ class Decoder : public MediaDecoder { ...@@ -72,7 +72,8 @@ class Decoder : public MediaDecoder {
Stream* stream, Stream* stream,
AVPacket* packet, AVPacket* packet,
bool* gotFrame, bool* gotFrame,
bool* hasMsg); bool* hasMsg,
bool fastSeek = false);
void flushStreams(); void flushStreams();
void cleanUp(); void cleanUp();
......
...@@ -190,6 +190,8 @@ struct DecoderParameters { ...@@ -190,6 +190,8 @@ struct DecoderParameters {
bool listen{false}; bool listen{false};
// don't copy frame body, only header // don't copy frame body, only header
bool headerOnly{false}; bool headerOnly{false};
// enable fast seek (seek only to keyframes)
bool fastSeek{false};
// interrupt init method on timeout // interrupt init method on timeout
bool preventStaleness{true}; bool preventStaleness{true};
// seek tolerated accuracy (us) // seek tolerated accuracy (us)
......
...@@ -98,6 +98,7 @@ void Video::_getDecoderParams( ...@@ -98,6 +98,7 @@ void Video::_getDecoderParams(
int64_t getPtsOnly, int64_t getPtsOnly,
std::string stream, std::string stream,
long stream_id = -1, long stream_id = -1,
bool fastSeek = true,
bool all_streams = false, bool all_streams = false,
int64_t num_threads = 1, int64_t num_threads = 1,
double seekFrameMarginUs = 10) { double seekFrameMarginUs = 10) {
...@@ -106,6 +107,7 @@ void Video::_getDecoderParams( ...@@ -106,6 +107,7 @@ void Video::_getDecoderParams(
params.timeoutMs = decoderTimeoutMs; params.timeoutMs = decoderTimeoutMs;
params.startOffset = videoStartUs; params.startOffset = videoStartUs;
params.seekAccuracy = seekFrameMarginUs; params.seekAccuracy = seekFrameMarginUs;
params.fastSeek = fastSeek;
params.headerOnly = false; params.headerOnly = false;
params.numThreads = num_threads; params.numThreads = num_threads;
...@@ -165,6 +167,7 @@ Video::Video(std::string videoPath, std::string stream, int64_t numThreads) { ...@@ -165,6 +167,7 @@ Video::Video(std::string videoPath, std::string stream, int64_t numThreads) {
0, // headerOnly 0, // headerOnly
std::get<0>(current_stream), // stream info - remove that std::get<0>(current_stream), // stream info - remove that
long(-1), // stream_id parsed from info above change to -2 long(-1), // stream_id parsed from info above change to -2
false, // fastseek: we're using the default param here
true, // read all streams true, // read all streams
numThreads_ // global number of Threads for decoding numThreads_ // global number of Threads for decoding
); );
...@@ -246,6 +249,7 @@ bool Video::setCurrentStream(std::string stream = "video") { ...@@ -246,6 +249,7 @@ bool Video::setCurrentStream(std::string stream = "video") {
std::get<0>(current_stream), // stream std::get<0>(current_stream), // stream
long(std::get<1>( long(std::get<1>(
current_stream)), // stream_id parsed from info above change to -2 current_stream)), // stream_id parsed from info above change to -2
false, // fastseek param set to 0 false by default (changed in seek)
false, // read all streams false, // read all streams
numThreads_ // global number of threads numThreads_ // global number of threads
); );
...@@ -263,7 +267,7 @@ c10::Dict<std::string, c10::Dict<std::string, std::vector<double>>> Video:: ...@@ -263,7 +267,7 @@ c10::Dict<std::string, c10::Dict<std::string, std::vector<double>>> Video::
return streamsMetadata; return streamsMetadata;
} }
void Video::Seek(double ts) { void Video::Seek(double ts, bool fastSeek = false) {
// initialize the class variables used for seeking and retrurn // initialize the class variables used for seeking and retrurn
_getDecoderParams( _getDecoderParams(
ts, // video start ts, // video start
...@@ -271,6 +275,7 @@ void Video::Seek(double ts) { ...@@ -271,6 +275,7 @@ void Video::Seek(double ts) {
std::get<0>(current_stream), // stream std::get<0>(current_stream), // stream
long(std::get<1>( long(std::get<1>(
current_stream)), // stream_id parsed from info above change to -2 current_stream)), // stream_id parsed from info above change to -2
fastSeek, // fastseek
false, // read all streams false, // read all streams
numThreads_ // global number of threads numThreads_ // global number of threads
); );
......
...@@ -23,7 +23,7 @@ struct Video : torch::CustomClassHolder { ...@@ -23,7 +23,7 @@ struct Video : torch::CustomClassHolder {
std::tuple<std::string, int64_t> getCurrentStream() const; std::tuple<std::string, int64_t> getCurrentStream() const;
c10::Dict<std::string, c10::Dict<std::string, std::vector<double>>> c10::Dict<std::string, c10::Dict<std::string, std::vector<double>>>
getStreamMetadata() const; getStreamMetadata() const;
void Seek(double ts); void Seek(double ts, bool fastSeek);
bool setCurrentStream(std::string stream); bool setCurrentStream(std::string stream);
std::tuple<torch::Tensor, double> Next(); std::tuple<torch::Tensor, double> Next();
...@@ -39,6 +39,7 @@ struct Video : torch::CustomClassHolder { ...@@ -39,6 +39,7 @@ struct Video : torch::CustomClassHolder {
int64_t getPtsOnly, int64_t getPtsOnly,
std::string stream, std::string stream,
long stream_id, long stream_id,
bool fastSeek,
bool all_streams, bool all_streams,
int64_t num_threads, int64_t num_threads,
double seekFrameMarginUs); // this needs to be improved double seekFrameMarginUs); // this needs to be improved
......
...@@ -135,11 +135,12 @@ class VideoReader: ...@@ -135,11 +135,12 @@ class VideoReader:
def __iter__(self) -> Iterator["VideoReader"]: def __iter__(self) -> Iterator["VideoReader"]:
return self return self
def seek(self, time_s: float) -> "VideoReader": def seek(self, time_s: float, keyframes_only: bool = False) -> "VideoReader":
"""Seek within current stream. """Seek within current stream.
Args: Args:
time_s (float): seek time in seconds time_s (float): seek time in seconds
keyframes_only (bool): allow to seek only to keyframes
.. note:: .. note::
Current implementation is the so-called precise seek. This Current implementation is the so-called precise seek. This
...@@ -147,7 +148,7 @@ class VideoReader: ...@@ -147,7 +148,7 @@ class VideoReader:
frame with the exact timestamp if it exists or frame with the exact timestamp if it exists or
the first frame with timestamp larger than ``time_s``. the first frame with timestamp larger than ``time_s``.
""" """
self._c.seek(time_s) self._c.seek(time_s, keyframes_only)
return self return self
def get_metadata(self) -> Dict[str, Any]: def get_metadata(self) -> Dict[str, Any]:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment