Commit 60f29ca0 authored by Joao Gomes, committed by Facebook GitHub Bot
Browse files

Add precise seek (#2737)

Summary:
cc mthrok

Implements precise seek and seek to any frame in torchaudio

Pull Request resolved: https://github.com/pytorch/audio/pull/2737

Reviewed By: mthrok

Differential Revision: D40546716

Pulled By: jdsgomes

fbshipit-source-id: d37da7f55977337eb16a3c4df44ce8c3c102698e
parent 82d92da5
...@@ -418,15 +418,16 @@ class StreamReaderInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC ...@@ -418,15 +418,16 @@ class StreamReaderInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC
if i >= 40: if i >= 40:
break break
def test_seek(self): @parameterized.expand(["key", "any", "precise"])
def test_seek(self, mode):
"""Calling `seek` multiple times should not segfault""" """Calling `seek` multiple times should not segfault"""
s = StreamReader(self.get_src()) s = StreamReader(self.get_src())
for i in range(10): for i in range(10):
s.seek(i) s.seek(i, mode)
for _ in range(0): for _ in range(0):
s.seek(0) s.seek(0, mode)
for i in range(10, 0, -1): for i in range(10, 0, -1):
s.seek(i) s.seek(i, mode)
def test_seek_negative(self): def test_seek_negative(self):
"""Calling `seek` with negative value should raise an exception""" """Calling `seek` with negative value should raise an exception"""
...@@ -434,6 +435,79 @@ class StreamReaderInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC ...@@ -434,6 +435,79 @@ class StreamReaderInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC
with self.assertRaises(RuntimeError): with self.assertRaises(RuntimeError):
s.seek(-1.0) s.seek(-1.0)
def test_seek_invalid_mode(self):
"""Calling `seek` with an invalid mode should raise an exception"""
s = StreamReader(self.get_src())
with self.assertRaises(ValueError):
s.seek(10, "magic_seek")
@parameterized.expand(
[
# Test keyframe seek
# The source mp4 video has two key frames: the first frame and the 203rd frame, at 8.08 seconds.
# If the seek timestamp is smaller than 8.08, it will seek into the first frame at 0.0 seconds.
("nasa_13013.mp4", "key", 0.2, (0, 0)),
("nasa_13013.mp4", "key", 8.04, (0, 0)),
("nasa_13013.mp4", "key", 8.08, (0, 202)),
("nasa_13013.mp4", "key", 8.12, (0, 202)),
# The source avi video has one keyframe every twelve frames (0, 12, 24, ...), i.e. every 0.4004 seconds.
# If we seek to a timestamp smaller than 0.4004, it will seek into the first frame at 0.0 seconds.
("nasa_13013.avi", "key", 0.2, (0, 0)),
("nasa_13013.avi", "key", 1.01, (0, 24)),
("nasa_13013.avi", "key", 7.37, (0, 216)),
("nasa_13013.avi", "key", 7.7, (0, 216)),
# Test precise seek
("nasa_13013.mp4", "precise", 0.0, (0, 0)),
("nasa_13013.mp4", "precise", 0.2, (0, 5)),
("nasa_13013.mp4", "precise", 8.04, (0, 201)),
("nasa_13013.mp4", "precise", 8.08, (0, 202)),
("nasa_13013.mp4", "precise", 8.12, (0, 203)),
("nasa_13013.avi", "precise", 0.0, (0, 0)),
("nasa_13013.avi", "precise", 0.2, (0, 1)),
("nasa_13013.avi", "precise", 8.1, (0, 238)),
("nasa_13013.avi", "precise", 8.14, (0, 239)),
("nasa_13013.avi", "precise", 8.17, (0, 240)),
# Test any seek
# The source avi video has one keyframe every twelve frames (0, 12, 24, ...), i.e. every 0.4004 seconds.
("nasa_13013.avi", "any", 0.0, (0, 0)),
("nasa_13013.avi", "any", 0.56, (0, 12)),
("nasa_13013.avi", "any", 7.77, (0, 228)),
("nasa_13013.avi", "any", 0.2002, (11, 12)),
("nasa_13013.avi", "any", 0.233567, (10, 12)),
("nasa_13013.avi", "any", 0.266933, (9, 12)),
]
)
def test_seek_modes(self, src, mode, seek_time, ref_indices):
"""We expect the following behaviour from the different kinds of seek:
- `key`: the reader will seek to the nearest keyframe at or before the given timestamp
- `precise`: the reader will seek to the nearest keyframe at or before the given timestamp
and start decoding from that position until the given timestamp (discarding all frames in between)
- `any`: the reader will seek to the closest frame to the given timestamp,
but if this is not a keyframe, the content will be the delta from other frames
To test this behaviour we parameterize the test with the tuple ref_indices. ref_indices[0]
is the expected index in the list of frames decoded after the seek, and ref_indices[1] is the expected index in
the list of all frames decoded from the beginning (reference frames). This test checks that
the frames decoded after the seek, from index ref_indices[0] onward, match the reference frames
from index ref_indices[1] onward. Please note that with `any`
and `key` seek we only compare keyframes, but with `precise` seek we can compare any frame content.
"""
# Using the first video stream (which is not default video stream)
stream_index = 0
# Decode all frames for reference
src_bin = self.get_src(src)
s = StreamReader(src_bin)
s.add_basic_video_stream(-1, stream_index=stream_index)
s.process_all_packets()
(ref_frames,) = s.pop_chunks()
s.seek(seek_time, mode=mode)
s.process_all_packets()
(frame,) = s.pop_chunks()
hyp_index, ref_index = ref_indices
self.assertEqual(frame[hyp_index:], ref_frames[ref_index:])
def _to_fltp(original): def _to_fltp(original):
"""Convert Tensor to float32 with value range [-1, 1]""" """Convert Tensor to float32 with value range [-1, 1]"""
......
...@@ -84,10 +84,12 @@ int Sink::process_frame(AVFrame* pFrame) { ...@@ -84,10 +84,12 @@ int Sink::process_frame(AVFrame* pFrame) {
ret = filter->get_frame(frame); ret = filter->get_frame(frame);
// AVERROR(EAGAIN) means that new input data is required to return new // AVERROR(EAGAIN) means that new input data is required to return new
// output. // output.
if (ret == AVERROR(EAGAIN) || ret == AVERROR_EOF) if (ret == AVERROR(EAGAIN) || ret == AVERROR_EOF) {
return 0; return 0;
if (ret >= 0) }
if (ret >= 0) {
buffer->push_frame(frame); buffer->push_frame(frame);
}
av_frame_unref(frame); av_frame_unref(frame);
} }
return ret; return ret;
......
...@@ -68,7 +68,9 @@ bool StreamProcessor::is_buffer_ready() const { ...@@ -68,7 +68,9 @@ bool StreamProcessor::is_buffer_ready() const {
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
// 0: some kind of success // 0: some kind of success
// <0: Some error happened // <0: Some error happened
int StreamProcessor::process_packet(AVPacket* packet) { int StreamProcessor::process_packet(
AVPacket* packet,
int64_t discard_before_pts) {
int ret = decoder.process_packet(packet); int ret = decoder.process_packet(packet);
while (ret >= 0) { while (ret >= 0) {
ret = decoder.get_frame(pFrame1); ret = decoder.get_frame(pFrame1);
...@@ -80,7 +82,12 @@ int StreamProcessor::process_packet(AVPacket* packet) { ...@@ -80,7 +82,12 @@ int StreamProcessor::process_packet(AVPacket* packet) {
return send_frame(NULL); return send_frame(NULL);
if (ret < 0) if (ret < 0)
return ret; return ret;
if (pFrame1->pts >= discard_before_pts) {
send_frame(pFrame1); send_frame(pFrame1);
}
// else we can just unref the frame and continue
av_frame_unref(pFrame1); av_frame_unref(pFrame1);
} }
return ret; return ret;
......
...@@ -70,7 +70,7 @@ class StreamProcessor { ...@@ -70,7 +70,7 @@ class StreamProcessor {
// 2. pass the decoded data to filters // 2. pass the decoded data to filters
// 3. each filter store the result to the corresponding buffer // 3. each filter store the result to the corresponding buffer
// - Sending NULL will drain (flush) the internal // - Sending NULL will drain (flush) the internal
int process_packet(AVPacket* packet); int process_packet(AVPacket* packet, int64_t discard_before_pts = -1);
// flush the internal buffer of decoder. // flush the internal buffer of decoder.
// To be use when seeking // To be use when seeking
......
...@@ -163,12 +163,38 @@ bool StreamReader::is_buffer_ready() const { ...@@ -163,12 +163,38 @@ bool StreamReader::is_buffer_ready() const {
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
// Configure methods // Configure methods
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
void StreamReader::seek(double timestamp) { void StreamReader::seek(double timestamp_s, int64_t mode) {
TORCH_CHECK(timestamp >= 0, "timestamp must be non-negative."); TORCH_CHECK(timestamp_s >= 0, "timestamp must be non-negative.");
TORCH_CHECK(
pFormatContext->nb_streams > 0,
"At least one stream must exist in this context");
int64_t timestamp_av_tb = static_cast<int64_t>(timestamp_s * AV_TIME_BASE);
int flag = AVSEEK_FLAG_BACKWARD;
switch (mode) {
case 0:
seek_timestamp =
-1; // reset seek_timestamp as it is only used for precise seek
break;
case 1:
flag |= AVSEEK_FLAG_ANY;
seek_timestamp =
-1; // reset seek_timestamp as it is only used for precise seek
break;
case 2:
seek_timestamp = timestamp_av_tb;
break;
default:
TORCH_CHECK(false, "Invalid mode value: ", mode);
}
int ret = av_seek_frame(pFormatContext, -1, timestamp_av_tb, flag);
int64_t ts = static_cast<int64_t>(timestamp * AV_TIME_BASE); if (ret < 0) {
int ret = avformat_seek_file(pFormatContext, -1, INT64_MIN, ts, INT64_MAX, 0); seek_timestamp = -1;
TORCH_CHECK(ret >= 0, "Failed to seek. (" + av_err2string(ret) + ".)"); TORCH_CHECK(false, "Failed to seek. (" + av_err2string(ret) + ".)");
}
for (const auto& it : processors) { for (const auto& it : processors) {
if (it) { if (it) {
it->flush(); it->flush();
...@@ -301,7 +327,14 @@ int StreamReader::process_packet() { ...@@ -301,7 +327,14 @@ int StreamReader::process_packet() {
if (!processor) { if (!processor) {
return 0; return 0;
} }
ret = processor->process_packet(packet);
AVRational stream_tb =
pFormatContext->streams[pPacket->stream_index]->time_base;
int64_t seek_timestamp_in_stream_tb =
av_rescale_q(seek_timestamp, av_get_time_base_q(), stream_tb);
ret = processor->process_packet(packet, seek_timestamp_in_stream_tb);
return (ret < 0) ? ret : 0; return (ret < 0) ? ret : 0;
} }
......
...@@ -17,6 +17,9 @@ class StreamReader { ...@@ -17,6 +17,9 @@ class StreamReader {
// the second is the map key inside of processor. // the second is the map key inside of processor.
std::vector<std::pair<int, int>> stream_indices; std::vector<std::pair<int, int>> stream_indices;
// timestamp to seek to expressed in AV_TIME_BASE
int64_t seek_timestamp = -1;
public: public:
explicit StreamReader(AVFormatInputContextPtr&& p); explicit StreamReader(AVFormatInputContextPtr&& p);
~StreamReader() = default; ~StreamReader() = default;
...@@ -57,7 +60,7 @@ class StreamReader { ...@@ -57,7 +60,7 @@ class StreamReader {
////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////
// Configure methods // Configure methods
////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////
void seek(double timestamp); void seek(double timestamp_s, int64_t mode);
void add_audio_stream( void add_audio_stream(
int64_t i, int64_t i,
......
...@@ -42,7 +42,7 @@ TORCH_LIBRARY_FRAGMENT(torchaudio, m) { ...@@ -42,7 +42,7 @@ TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
.def( .def(
"find_best_video_stream", "find_best_video_stream",
[](S s) { return s->find_best_video_stream(); }) [](S s) { return s->find_best_video_stream(); })
.def("seek", [](S s, double t) { return s->seek(t); }) .def("seek", [](S s, double t, int64_t mode) { return s->seek(t, mode); })
.def( .def(
"add_audio_stream", "add_audio_stream",
[](S s, [](S s,
......
...@@ -123,7 +123,7 @@ TORCH_LIBRARY_FRAGMENT(torchaudio, m) { ...@@ -123,7 +123,7 @@ TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
.def( .def(
"find_best_video_stream", "find_best_video_stream",
[](S s) { return s->find_best_video_stream(); }) [](S s) { return s->find_best_video_stream(); })
.def("seek", [](S s, double t) { return s->seek(t); }) .def("seek", [](S s, double t, int64_t mode) { return s->seek(t, mode); })
.def( .def(
"add_audio_stream", "add_audio_stream",
[](S s, [](S s,
......
...@@ -425,13 +425,34 @@ class StreamReader: ...@@ -425,13 +425,34 @@ class StreamReader:
""" """
return _parse_oi(self._be.get_out_stream_info(i)) return _parse_oi(self._be.get_out_stream_info(i))
def seek(self, timestamp: float): def seek(self, timestamp: float, mode: str = "precise"):
"""Seek the stream to the given timestamp [second] """Seek the stream to the given timestamp [second]
Args: Args:
timestamp (float): Target time in second. timestamp (float): Target time in second.
mode (str): Controls how seek is done.
Valid choices are:
* "key": Seek into the nearest key frame before the given timestamp.
* "any": Seek into any frame (including non-key frames) before the given timestamp.
* "precise": First seek into the nearest key frame before the given timestamp, then
decode frames until it reaches the closest frame to the given timestamp.
Note:
All the modes invalidate and reset the internal state of the decoder.
When using "any" mode, if it ends up seeking into a non-key frame,
the decoded image may be invalid due to the lack of a key frame.
Using "precise" will work around this issue by decoding frames from the previous
key frame, but will be slower.
""" """
self._be.seek(timestamp) modes = {
"key": 0,
"any": 1,
"precise": 2,
}
if mode not in modes:
raise ValueError(f"The value of mode must be one of {list(modes.keys())}. Found: {mode}")
self._be.seek(timestamp, modes[mode])
@_format_audio_args @_format_audio_args
def add_basic_audio_stream( def add_basic_audio_stream(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment