Commit 60f29ca0 authored by Joao Gomes, committed by Facebook GitHub Bot
Browse files

Add precise seek (#2737)

Summary:
cc mthrok

Implements precise seek and seek to any frame in torchaudio

Pull Request resolved: https://github.com/pytorch/audio/pull/2737

Reviewed By: mthrok

Differential Revision: D40546716

Pulled By: jdsgomes

fbshipit-source-id: d37da7f55977337eb16a3c4df44ce8c3c102698e
parent 82d92da5
......@@ -418,15 +418,16 @@ class StreamReaderInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC
if i >= 40:
break
def test_seek(self):
@parameterized.expand(["key", "any", "precise"])
def test_seek(self, mode):
"""Calling `seek` multiple times should not segfault"""
s = StreamReader(self.get_src())
# Seek forward through increasing timestamps.
for i in range(10):
s.seek(i)
s.seek(i, mode)
# NOTE(review): `range(0)` is empty, so this loop body never executes -- confirm the intended repeat count.
for _ in range(0):
s.seek(0)
s.seek(0, mode)
# Seek backward through decreasing timestamps.
for i in range(10, 0, -1):
s.seek(i)
s.seek(i, mode)
def test_seek_negative(self):
"""Calling `seek` with negative value should raise an exception"""
......@@ -434,6 +435,79 @@ class StreamReaderInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC
with self.assertRaises(RuntimeError):
s.seek(-1.0)
def test_seek_invalid_mode(self):
"""Calling `seek` with an invalid mode should raise an exception"""
s = StreamReader(self.get_src())
with self.assertRaises(ValueError):
# "magic_seek" is not one of the accepted mode names.
s.seek(10, "magic_seek")
@parameterized.expand(
[
# Test keyframe seek
# The source mp4 video has two key frames: the first frame and the 203rd frame at 8.08 seconds.
# If the seek time stamp is smaller than 8.08, it will seek into the first frame at 0.0 second.
("nasa_13013.mp4", "key", 0.2, (0, 0)),
("nasa_13013.mp4", "key", 8.04, (0, 0)),
("nasa_13013.mp4", "key", 8.08, (0, 202)),
("nasa_13013.mp4", "key", 8.12, (0, 202)),
# The source avi video has one keyframe every twelve frames (0, 12, 24, ...), i.e. every 0.4004 seconds.
# If we seek to a time stamp smaller than 0.4004 it will seek into the first frame at 0.0 second.
("nasa_13013.avi", "key", 0.2, (0, 0)),
("nasa_13013.avi", "key", 1.01, (0, 24)),
("nasa_13013.avi", "key", 7.37, (0, 216)),
("nasa_13013.avi", "key", 7.7, (0, 216)),
# Test precise seek
("nasa_13013.mp4", "precise", 0.0, (0, 0)),
("nasa_13013.mp4", "precise", 0.2, (0, 5)),
("nasa_13013.mp4", "precise", 8.04, (0, 201)),
("nasa_13013.mp4", "precise", 8.08, (0, 202)),
("nasa_13013.mp4", "precise", 8.12, (0, 203)),
("nasa_13013.avi", "precise", 0.0, (0, 0)),
("nasa_13013.avi", "precise", 0.2, (0, 1)),
("nasa_13013.avi", "precise", 8.1, (0, 238)),
("nasa_13013.avi", "precise", 8.14, (0, 239)),
("nasa_13013.avi", "precise", 8.17, (0, 240)),
# Test any seek
# The source avi video has one keyframe every twelve frames (0, 12, 24, ...), i.e. every 0.4004 seconds.
("nasa_13013.avi", "any", 0.0, (0, 0)),
("nasa_13013.avi", "any", 0.56, (0, 12)),
("nasa_13013.avi", "any", 7.77, (0, 228)),
# For the cases below the comparison starts at reference index 12 (a keyframe),
# skipping the first frames decoded after the seek.
("nasa_13013.avi", "any", 0.2002, (11, 12)),
("nasa_13013.avi", "any", 0.233567, (10, 12)),
("nasa_13013.avi", "any", 0.266933, (9, 12)),
]
)
def test_seek_modes(self, src, mode, seek_time, ref_indices):
"""We expect the following behaviour from the different kinds of seek:
- `key`: the reader will seek to the nearest keyframe at or before the given timestamp
- `precise`: the reader will seek to the nearest keyframe at or before the given timestamp
and start decoding from that position until the given timestamp (discarding all frames in between)
- `any`: the reader will seek to the closest frame to the timestamp
given, but if this is not a keyframe, the content will be the delta from other frames
To test this behaviour we parameterize the test with the tuple ref_indices. ref_indices[0]
is the expected index in the list of frames decoded after the seek and ref_indices[1] is the expected
index in the list of all frames decoded from the beginning (reference frames). This test checks that
the frames decoded after the seek, from index ref_indices[0] onwards, are the same as the reference
frames from index ref_indices[1] onwards. Please note that with `any`
and `key` seek we only compare keyframes, but with `precise` seek we can compare any frame content.
"""
# Using the first video stream (which is not default video stream)
stream_index = 0
# Decode all frames for reference
src_bin = self.get_src(src)
s = StreamReader(src_bin)
s.add_basic_video_stream(-1, stream_index=stream_index)
s.process_all_packets()
(ref_frames,) = s.pop_chunks()
# Seek on the same reader, then decode everything again from the new position
s.seek(seek_time, mode=mode)
s.process_all_packets()
(frame,) = s.pop_chunks()
hyp_index, ref_index = ref_indices
# Compare the tail of the post-seek frames with the tail of the reference frames
self.assertEqual(frame[hyp_index:], ref_frames[ref_index:])
def _to_fltp(original):
"""Convert Tensor to float32 with value range [-1, 1]"""
......
......@@ -84,10 +84,12 @@ int Sink::process_frame(AVFrame* pFrame) {
ret = filter->get_frame(frame);
// AVERROR(EAGAIN) means that new input data is required to return new
// output.
if (ret == AVERROR(EAGAIN) || ret == AVERROR_EOF)
if (ret == AVERROR(EAGAIN) || ret == AVERROR_EOF) {
return 0;
if (ret >= 0)
}
if (ret >= 0) {
buffer->push_frame(frame);
}
av_frame_unref(frame);
}
return ret;
......
......@@ -68,7 +68,9 @@ bool StreamProcessor::is_buffer_ready() const {
////////////////////////////////////////////////////////////////////////////////
// 0: some kind of success
// <0: Some error happened
int StreamProcessor::process_packet(AVPacket* packet) {
int StreamProcessor::process_packet(
AVPacket* packet,
int64_t discard_before_pts) {
int ret = decoder.process_packet(packet);
while (ret >= 0) {
ret = decoder.get_frame(pFrame1);
......@@ -80,7 +82,12 @@ int StreamProcessor::process_packet(AVPacket* packet) {
return send_frame(NULL);
if (ret < 0)
return ret;
if (pFrame1->pts >= discard_before_pts) {
send_frame(pFrame1);
}
// else we can just unref the frame and continue
av_frame_unref(pFrame1);
}
return ret;
......
......@@ -70,7 +70,7 @@ class StreamProcessor {
// 2. pass the decoded data to filters
// 3. each filter stores the result in the corresponding buffer
// - Sending NULL will drain (flush) the internal
int process_packet(AVPacket* packet);
int process_packet(AVPacket* packet, int64_t discard_before_pts = -1);
// flush the internal buffer of decoder.
// To be used when seeking
......
......@@ -163,12 +163,38 @@ bool StreamReader::is_buffer_ready() const {
////////////////////////////////////////////////////////////////////////////////
// Configure methods
////////////////////////////////////////////////////////////////////////////////
void StreamReader::seek(double timestamp) {
TORCH_CHECK(timestamp >= 0, "timestamp must be non-negative.");
void StreamReader::seek(double timestamp_s, int64_t mode) {
TORCH_CHECK(timestamp_s >= 0, "timestamp must be non-negative.");
TORCH_CHECK(
pFormatContext->nb_streams > 0,
"At least one stream must exist in this context");
int64_t timestamp_av_tb = static_cast<int64_t>(timestamp_s * AV_TIME_BASE);
int flag = AVSEEK_FLAG_BACKWARD;
switch (mode) {
case 0:
seek_timestamp =
-1; // reset seek_timestap as it is only used for precise seek
break;
case 1:
flag |= AVSEEK_FLAG_ANY;
seek_timestamp =
-1; // reset seek_timestap as it is only used for precise seek
break;
case 2:
seek_timestamp = timestamp_av_tb;
break;
default:
TORCH_CHECK(false, "Invalid mode value: ", mode);
}
int ret = av_seek_frame(pFormatContext, -1, timestamp_av_tb, flag);
int64_t ts = static_cast<int64_t>(timestamp * AV_TIME_BASE);
int ret = avformat_seek_file(pFormatContext, -1, INT64_MIN, ts, INT64_MAX, 0);
TORCH_CHECK(ret >= 0, "Failed to seek. (" + av_err2string(ret) + ".)");
if (ret < 0) {
seek_timestamp = -1;
TORCH_CHECK(false, "Failed to seek. (" + av_err2string(ret) + ".)");
}
for (const auto& it : processors) {
if (it) {
it->flush();
......@@ -301,7 +327,14 @@ int StreamReader::process_packet() {
if (!processor) {
return 0;
}
ret = processor->process_packet(packet);
AVRational stream_tb =
pFormatContext->streams[pPacket->stream_index]->time_base;
int64_t seek_timestamp_in_stream_tb =
av_rescale_q(seek_timestamp, av_get_time_base_q(), stream_tb);
ret = processor->process_packet(packet, seek_timestamp_in_stream_tb);
return (ret < 0) ? ret : 0;
}
......
......@@ -17,6 +17,9 @@ class StreamReader {
// the second is the map key inside of processor.
std::vector<std::pair<int, int>> stream_indices;
// timestamp to seek to expressed in AV_TIME_BASE
int64_t seek_timestamp = -1;
public:
explicit StreamReader(AVFormatInputContextPtr&& p);
~StreamReader() = default;
......@@ -57,7 +60,7 @@ class StreamReader {
//////////////////////////////////////////////////////////////////////////////
// Configure methods
//////////////////////////////////////////////////////////////////////////////
void seek(double timestamp);
void seek(double timestamp_s, int64_t mode);
void add_audio_stream(
int64_t i,
......
......@@ -42,7 +42,7 @@ TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
.def(
"find_best_video_stream",
[](S s) { return s->find_best_video_stream(); })
.def("seek", [](S s, double t) { return s->seek(t); })
.def("seek", [](S s, double t, int64_t mode) { return s->seek(t, mode); })
.def(
"add_audio_stream",
[](S s,
......
......@@ -123,7 +123,7 @@ TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
.def(
"find_best_video_stream",
[](S s) { return s->find_best_video_stream(); })
.def("seek", [](S s, double t) { return s->seek(t); })
.def("seek", [](S s, double t, int64_t mode) { return s->seek(t, mode); })
.def(
"add_audio_stream",
[](S s,
......
......@@ -425,13 +425,34 @@ class StreamReader:
"""
return _parse_oi(self._be.get_out_stream_info(i))
def seek(self, timestamp: float):
def seek(self, timestamp: float, mode: str = "precise"):
"""Seek the stream to the given timestamp [second]
Args:
timestamp (float): Target time in seconds.
mode (str): Controls how seek is done. Defaults to "precise".
Valid choices are:
* "key": Seek into the nearest key frame before the given timestamp.
* "any": Seek into any frame (including non-key frames) before the given timestamp.
* "precise": First seek into the nearest key frame before the given timestamp, then
decode frames until it reaches the closest frame to the given timestamp.
Note:
All the modes invalidate and reset the internal state of decoder.
When using "any" mode and if it ends up seeking into a non-key frame,
the image decoded may be invalid due to lack of key frame.
Using "precise" will work around this issue by decoding frames from the previous
key frame, but will be slower.
Raises:
ValueError: If ``mode`` is not one of "key", "any" or "precise".
"""
self._be.seek(timestamp)
# Map the public mode name to the integer code passed to the backend `seek`.
modes = {
"key": 0,
"any": 1,
"precise": 2,
}
if mode not in modes:
raise ValueError(f"The value of mode must be one of {list(modes.keys())}. Found: {mode}")
self._be.seek(timestamp, modes[mode])
@_format_audio_args
def add_basic_audio_stream(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment