Commit 7e1afc40 authored by moto, committed by Facebook GitHub Bot

Flush and reset internal state after seek (#2264)

Summary:
This commit adds the following behavior to `seek` so that `seek`
works correctly after frames have already been decoded.

1. Flush the decoder buffer.
2. Recreate the filter graphs (so that their internal state is re-initialized).
3. Discard the buffered tensors (decoded chunks).

It also disallows negative values for the seek timestamp.
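
Conceptually, the new `seek` path cascades a flush through the streamer, its stream processors, and their sinks/buffers. Below is a rough, self-contained sketch of that cascade using simplified stand-in types (`ChunkBuffer`, `SimpleSink`, `SimpleProcessor`, and `SimpleStreamer` are invented for illustration only; the real classes also flush the decoder and rebuild the filter graph, as the diff below shows):

#include <deque>
#include <map>
#include <memory>
#include <stdexcept>
#include <vector>

// Step 3: the buffer of decoded chunks is simply cleared.
struct ChunkBuffer {
  std::deque<std::vector<float>> chunks;
  void flush() { chunks.clear(); }
};

// Step 2 happens at this level in the real code: the filter graph is
// recreated before the chunk buffer is cleared.
struct SimpleSink {
  ChunkBuffer buffer;
  void flush() { buffer.flush(); }
};

// Step 1 happens at this level in the real code: the decoder buffer is
// flushed (avcodec_flush_buffers) before each sink is flushed.
struct SimpleProcessor {
  std::map<int, SimpleSink> sinks;
  void flush() {
    for (auto& kv : sinks) {
      kv.second.flush();
    }
  }
};

struct SimpleStreamer {
  std::vector<std::unique_ptr<SimpleProcessor>> processors;
  void seek(double timestamp) {
    if (timestamp < 0) {
      throw std::invalid_argument("timestamp must be non-negative.");
    }
    // ... the demuxer-level seek would happen here ...
    for (auto& p : processors) {
      if (p) {
        p->flush();
      }
    }
  }
};

int main() {
  SimpleStreamer s;
  s.processors.push_back(std::make_unique<SimpleProcessor>());
  s.seek(1.5);  // flushes every processor after the (omitted) seek
  return 0;
}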

Pull Request resolved: https://github.com/pytorch/audio/pull/2264

Reviewed By: carolineechen

Differential Revision: D34497826

Pulled By: mthrok

fbshipit-source-id: 8b9a5bf160dfeb15f5cced3eed2288c33e2eb35d
parent 04875eef
@@ -300,6 +300,22 @@ class StreamerInterfaceTest(TempDirMixin, TorchaudioTestCase):
             if i >= 40:
                 break
 
+    def test_seek(self):
+        """Calling `seek` multiple times should not segfault"""
+        s = Streamer(get_video_asset())
+        for i in range(10):
+            s.seek(i)
+        for _ in range(0):
+            s.seek(0)
+        for i in range(10, 0, -1):
+            s.seek(i)
+
+    def test_seek_negative(self):
+        """Calling `seek` with negative value should raise an exception"""
+        s = Streamer(get_video_asset())
+        with self.assertRaises(ValueError):
+            s.seek(-1.0)
+
 
 @skipIfNoFFmpeg
 class StreamerAudioTest(TempDirMixin, TorchaudioTestCase):
@@ -363,6 +379,21 @@ class StreamerAudioTest(TempDirMixin, TorchaudioTestCase):
         (output,) = s.pop_chunks()
         self.assertEqual(expected, output)
 
+    def test_audio_seek_multiple(self):
+        """Calling `seek` after streaming is started should change the position properly"""
+        path, original = self._get_reference_wav(1, dtype="int16", num_channels=2, num_frames=30)
+        s = Streamer(path)
+        s.add_audio_stream(frames_per_chunk=-1)
+
+        ts = list(range(20)) + list(range(20, 0, -1)) + list(range(20))
+        for t in ts:
+            s.seek(float(t))
+            s.process_all_packets()
+            (output,) = s.pop_chunks()
+            expected = original[t:, :]
+            self.assertEqual(expected, output)
+
     @nested_params(
         [
             (18, 6, 3),  # num_frames is divisible by frames_per_chunk
...
@@ -285,5 +285,9 @@ torch::Tensor Buffer::pop_all() {
   return torch::cat(ret, 0);
 }
 
+void Buffer::flush() {
+  chunks.clear();
+}
+
 } // namespace ffmpeg
 } // namespace torchaudio
@@ -40,6 +40,8 @@ class Buffer {
   c10::optional<torch::Tensor> pop_chunk();
 
+  void flush();
+
  private:
   virtual torch::Tensor pop_one_chunk() = 0;
   torch::Tensor pop_all();
...
@@ -16,5 +16,9 @@ int Decoder::get_frame(AVFrame* pFrame) {
   return avcodec_receive_frame(pCodecContext, pFrame);
 }
 
+void Decoder::flush_buffer() {
+  avcodec_flush_buffers(pCodecContext);
+}
+
 } // namespace ffmpeg
 } // namespace torchaudio
@@ -24,6 +24,8 @@ class Decoder {
   int process_packet(AVPacket* pPacket);
   // Fetch a decoded frame
   int get_frame(AVFrame* pFrame);
+  // Flush buffer (for seek)
+  void flush_buffer();
 };
 
 } // namespace ffmpeg
...
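
For context, `avcodec_flush_buffers` is the standard libavcodec call for dropping a codec's internal frame queue after the demuxer position changes; the new `Decoder::flush_buffer` above is a thin wrapper around it. A minimal, stand-alone sketch of the usual seek-then-flush pattern (the `seek_and_flush` helper is made up for this example and assumes already-opened FFmpeg contexts):

extern "C" {
#include <libavcodec/avcodec.h>
#include <libavformat/avformat.h>
}

#include <cstdint>

// Seek the demuxer close to `seconds` and discard frames the decoder has
// buffered from before the seek. Returns 0 on success, a negative AVERROR
// code otherwise.
int seek_and_flush(AVFormatContext* fmt_ctx, AVCodecContext* codec_ctx, double seconds) {
  int64_t ts = static_cast<int64_t>(seconds * AV_TIME_BASE);
  int ret = avformat_seek_file(fmt_ctx, -1, INT64_MIN, ts, INT64_MAX, 0);
  if (ret < 0) {
    return ret;
  }
  // Without this call the decoder would keep emitting frames queued
  // before the seek target.
  avcodec_flush_buffers(codec_ctx);
  return 0;
}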
@@ -196,5 +196,9 @@ AVFilterGraph* get_filter_graph() {
 } // namespace
 
 AVFilterGraphPtr::AVFilterGraphPtr()
     : Wrapper<AVFilterGraph, AVFilterGraphDeleter>(get_filter_graph()) {}
 
+void AVFilterGraphPtr::reset() {
+  ptr.reset(get_filter_graph());
+}
+
 } // namespace ffmpeg
 } // namespace torchaudio
@@ -129,6 +129,7 @@ struct AVFilterGraphDeleter {
 };
 
 struct AVFilterGraphPtr : public Wrapper<AVFilterGraph, AVFilterGraphDeleter> {
   AVFilterGraphPtr();
+  void reset();
 };
 
 } // namespace ffmpeg
 } // namespace torchaudio
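
The new `AVFilterGraphPtr::reset` above frees the old graph and takes ownership of a freshly allocated one. A minimal sketch of the same idea with a plain `std::unique_ptr` and a custom deleter (this only illustrates the pattern; torchaudio's actual `Wrapper` template and `get_filter_graph` helper are not reproduced here):

extern "C" {
#include <libavfilter/avfilter.h>
}

#include <memory>

// avfilter_graph_free expects AVFilterGraph**, so adapt it for unique_ptr.
struct GraphDeleter {
  void operator()(AVFilterGraph* graph) const {
    avfilter_graph_free(&graph);
  }
};

using GraphPtr = std::unique_ptr<AVFilterGraph, GraphDeleter>;

// Resetting the pointer with a new allocation frees the previous graph,
// which is exactly what re-initializing the filter state needs.
GraphPtr make_fresh_graph() {
  return GraphPtr(avfilter_graph_alloc());
}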
@@ -8,11 +8,11 @@ FilterGraph::FilterGraph(
     AVRational time_base,
     AVCodecParameters* codecpar,
     std::string filter_description)
-    : filter_description(filter_description) {
-  add_src(time_base, codecpar);
-  add_sink();
-  add_process();
-  create_filter();
+    : input_time_base(time_base),
+      codecpar(codecpar),
+      filter_description(std::move(filter_description)),
+      media_type(codecpar->codec_type) {
+  init();
 }
 
 ////////////////////////////////////////////////////////////////////////////////
@@ -62,18 +62,29 @@ std::string get_video_src_args(
 } // namespace
 
-void FilterGraph::add_src(AVRational time_base, AVCodecParameters* codecpar) {
-  if (media_type != AVMEDIA_TYPE_UNKNOWN) {
-    throw std::runtime_error("Source buffer is already allocated.");
-  }
-  media_type = codecpar->codec_type;
+void FilterGraph::init() {
+  add_src();
+  add_sink();
+  add_process();
+  create_filter();
+}
+
+void FilterGraph::reset() {
+  pFilterGraph.reset();
+  buffersrc_ctx = nullptr;
+  buffersink_ctx = nullptr;
+
+  init();
+}
+
+void FilterGraph::add_src() {
   std::string args;
   switch (media_type) {
     case AVMEDIA_TYPE_AUDIO:
-      args = get_audio_src_args(time_base, codecpar);
+      args = get_audio_src_args(input_time_base, codecpar);
       break;
     case AVMEDIA_TYPE_VIDEO:
-      args = get_video_src_args(time_base, codecpar);
+      args = get_video_src_args(input_time_base, codecpar);
       break;
     default:
       throw std::runtime_error("Only audio/video are supported.");
...
@@ -5,13 +5,20 @@ namespace torchaudio {
 namespace ffmpeg {
 
 class FilterGraph {
-  AVMediaType media_type = AVMEDIA_TYPE_UNKNOWN;
+  // Parameters required for `reset`
+  // Recreats the underlying filter_graph struct
+  AVRational input_time_base;
+  AVCodecParameters* codecpar;
+  std::string filter_description;
+  // Constant just for convenient access.
+  AVMediaType media_type;
 
   AVFilterGraphPtr pFilterGraph;
   // AVFilterContext is freed as a part of AVFilterGraph
   // so we do not manage the resource.
   AVFilterContext* buffersrc_ctx = nullptr;
   AVFilterContext* buffersink_ctx = nullptr;
 
-  const std::string filter_description;
-
  public:
   FilterGraph(
@@ -35,8 +42,12 @@ class FilterGraph {
   //////////////////////////////////////////////////////////////////////////////
   // Configuration methods
   //////////////////////////////////////////////////////////////////////////////
+  void init();
+
+  void reset();
+
  private:
-  void add_src(AVRational time_base, AVCodecParameters* codecpar);
+  void add_src();
   void add_sink();
...
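
The header change keeps the constructor arguments (`input_time_base`, `codecpar`, `filter_description`) as members so that `reset` can rebuild the graph from scratch. In isolation, the pattern looks roughly like the sketch below (the `ResettableGraph`/`Graph` names are hypothetical stand-ins for the real FFmpeg-backed types):

#include <memory>
#include <string>
#include <utility>

struct Graph {};  // stand-in for the wrapped AVFilterGraph

class ResettableGraph {
  // Construction parameters are stored so init() can be replayed later.
  std::string description;
  std::unique_ptr<Graph> graph;

 public:
  explicit ResettableGraph(std::string desc) : description(std::move(desc)) {
    init();
  }

  // Drop the current graph and rebuild it from the stored parameters.
  void reset() {
    graph.reset();
    init();
  }

 private:
  void init() {
    graph = std::make_unique<Graph>();
    // ... configure the new graph from `description` ...
  }
};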
@@ -53,5 +53,11 @@ int Sink::process_frame(AVFrame* pFrame) {
 bool Sink::is_buffer_ready() const {
   return buffer->is_ready();
 }
 
+void Sink::flush() {
+  filter.reset();
+  buffer->flush();
+}
+
 } // namespace ffmpeg
 } // namespace torchaudio
@@ -22,6 +22,8 @@ class Sink {
   int process_frame(AVFrame* frame);
   bool is_buffer_ready() const;
+
+  void flush();
 };
 
 } // namespace ffmpeg
...
@@ -82,6 +82,13 @@ int StreamProcessor::process_packet(AVPacket* packet) {
   return ret;
 }
 
+void StreamProcessor::flush() {
+  decoder.flush_buffer();
+  for (auto& ite : sinks) {
+    ite.second.flush();
+  }
+}
+
 // 0: some kind of success
 // <0: Some error happened
 int StreamProcessor::send_frame(AVFrame* pFrame) {
...
@@ -68,6 +68,10 @@ class StreamProcessor {
   // - Sending NULL will drain (flush) the internal
   int process_packet(AVPacket* packet);
 
+  // flush the internal buffer of decoder.
+  // To be use when seeking
+  void flush();
+
  private:
   int send_frame(AVFrame* pFrame);
...
@@ -136,11 +136,20 @@ bool Streamer::is_buffer_ready() const {
 // Configure methods
 ////////////////////////////////////////////////////////////////////////////////
 void Streamer::seek(double timestamp) {
+  if (timestamp < 0) {
+    throw std::invalid_argument("timestamp must be non-negative.");
+  }
   int64_t ts = static_cast<int64_t>(timestamp * AV_TIME_BASE);
   int ret = avformat_seek_file(pFormatContext, -1, INT64_MIN, ts, INT64_MAX, 0);
   if (ret < 0) {
     throw std::runtime_error("Failed to seek. (" + av_err2string(ret) + ".)");
   }
+
+  for (const auto& it : processors) {
+    if (it) {
+      it->flush();
+    }
+  }
 }
 
 void Streamer::add_audio_stream(
...
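
For a concrete sense of the conversion in `Streamer::seek` above: FFmpeg's `AV_TIME_BASE` is 1,000,000, so a call such as `seek(2.5)` becomes `ts = 2.5 * 1,000,000 = 2,500,000`, passed in `AV_TIME_BASE` units, which is what `avformat_seek_file` expects when the stream index is -1.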