Commit 7e1afc40 authored by moto, committed by Facebook GitHub Bot

Flush and reset internal state after seek (#2264)

Summary:
This commit adds the following behavior to `seek` so that `seek`
works correctly after frames have been decoded.

1. Flush the decoder buffer.
2. Recreate the filter graphs (so that internal state is re-initialized).
3. Discard the buffered tensors (decoded chunks).

It also disallows negative values for the seek timestamp.
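
For illustration, the intended end-to-end usage looks roughly like the
sketch below, pieced together from the tests added in this diff. The
import path and the input file name are assumptions made for the sketch,
not part of this change.

# Rough usage sketch based on the tests in this diff; the import path and
# the input file are placeholders, not part of this commit.
from torchaudio.prototype.io import Streamer

s = Streamer("example.mp4")
s.add_audio_stream(frames_per_chunk=-1)

# Decode everything once.
s.process_all_packets()
(first,) = s.pop_chunks()

# Seeking after frames have been decoded now flushes the decoder buffer,
# re-initializes the filter graphs, and discards buffered chunks, so the
# next round of decoding starts at the requested position.
s.seek(3.0)
s.process_all_packets()
(after_seek,) = s.pop_chunks()

# Negative timestamps are rejected.
try:
    s.seek(-1.0)
except ValueError:
    pass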

Pull Request resolved: https://github.com/pytorch/audio/pull/2264

Reviewed By: carolineechen

Differential Revision: D34497826

Pulled By: mthrok

fbshipit-source-id: 8b9a5bf160dfeb15f5cced3eed2288c33e2eb35d
parent 04875eef
......@@ -300,6 +300,22 @@ class StreamerInterfaceTest(TempDirMixin, TorchaudioTestCase):
            if i >= 40:
                break

    def test_seek(self):
        """Calling `seek` multiple times should not segfault"""
        s = Streamer(get_video_asset())
        for i in range(10):
            s.seek(i)
        for _ in range(0):
            s.seek(0)
        for i in range(10, 0, -1):
            s.seek(i)

    def test_seek_negative(self):
        """Calling `seek` with negative value should raise an exception"""
        s = Streamer(get_video_asset())
        with self.assertRaises(ValueError):
            s.seek(-1.0)
@skipIfNoFFmpeg
class StreamerAudioTest(TempDirMixin, TorchaudioTestCase):
......@@ -363,6 +379,21 @@ class StreamerAudioTest(TempDirMixin, TorchaudioTestCase):
        (output,) = s.pop_chunks()
        self.assertEqual(expected, output)

    def test_audio_seek_multiple(self):
        """Calling `seek` after streaming is started should change the position properly"""
        path, original = self._get_reference_wav(1, dtype="int16", num_channels=2, num_frames=30)
        s = Streamer(path)
        s.add_audio_stream(frames_per_chunk=-1)

        ts = list(range(20)) + list(range(20, 0, -1)) + list(range(20))

        for t in ts:
            s.seek(float(t))
            s.process_all_packets()
            (output,) = s.pop_chunks()
            expected = original[t:, :]
            self.assertEqual(expected, output)
@nested_params(
[
(18, 6, 3), # num_frames is divisible by frames_per_chunk
......
......@@ -285,5 +285,9 @@ torch::Tensor Buffer::pop_all() {
  return torch::cat(ret, 0);
}

void Buffer::flush() {
  chunks.clear();
}
} // namespace ffmpeg
} // namespace torchaudio
......@@ -40,6 +40,8 @@ class Buffer {
c10::optional<torch::Tensor> pop_chunk();
void flush();
private:
virtual torch::Tensor pop_one_chunk() = 0;
torch::Tensor pop_all();
......
......@@ -16,5 +16,9 @@ int Decoder::get_frame(AVFrame* pFrame) {
  return avcodec_receive_frame(pCodecContext, pFrame);
}

void Decoder::flush_buffer() {
  avcodec_flush_buffers(pCodecContext);
}
} // namespace ffmpeg
} // namespace torchaudio
......@@ -24,6 +24,8 @@ class Decoder {
int process_packet(AVPacket* pPacket);
// Fetch a decoded frame
int get_frame(AVFrame* pFrame);
// Flush buffer (for seek)
void flush_buffer();
};
} // namespace ffmpeg
......
......@@ -196,5 +196,9 @@ AVFilterGraph* get_filter_graph() {
} // namespace
AVFilterGraphPtr::AVFilterGraphPtr()
    : Wrapper<AVFilterGraph, AVFilterGraphDeleter>(get_filter_graph()) {}

void AVFilterGraphPtr::reset() {
  ptr.reset(get_filter_graph());
}
} // namespace ffmpeg
} // namespace torchaudio
......@@ -129,6 +129,7 @@ struct AVFilterGraphDeleter {
};
struct AVFilterGraphPtr : public Wrapper<AVFilterGraph, AVFilterGraphDeleter> {
AVFilterGraphPtr();
void reset();
};
} // namespace ffmpeg
} // namespace torchaudio
......@@ -8,11 +8,11 @@ FilterGraph::FilterGraph(
    AVRational time_base,
    AVCodecParameters* codecpar,
    std::string filter_description)
    : filter_description(filter_description) {
  add_src(time_base, codecpar);
  add_sink();
  add_process();
  create_filter();
    : input_time_base(time_base),
      codecpar(codecpar),
      filter_description(std::move(filter_description)),
      media_type(codecpar->codec_type) {
  init();
}
////////////////////////////////////////////////////////////////////////////////
......@@ -62,18 +62,29 @@ std::string get_video_src_args(
} // namespace

void FilterGraph::add_src(AVRational time_base, AVCodecParameters* codecpar) {
  if (media_type != AVMEDIA_TYPE_UNKNOWN) {
    throw std::runtime_error("Source buffer is already allocated.");
  }
  media_type = codecpar->codec_type;

void FilterGraph::init() {
  add_src();
  add_sink();
  add_process();
  create_filter();
}

void FilterGraph::reset() {
  pFilterGraph.reset();
  buffersrc_ctx = nullptr;
  buffersink_ctx = nullptr;
  init();
}

void FilterGraph::add_src() {
  std::string args;
  switch (media_type) {
    case AVMEDIA_TYPE_AUDIO:
      args = get_audio_src_args(time_base, codecpar);
      args = get_audio_src_args(input_time_base, codecpar);
      break;
    case AVMEDIA_TYPE_VIDEO:
      args = get_video_src_args(time_base, codecpar);
      args = get_video_src_args(input_time_base, codecpar);
      break;
    default:
      throw std::runtime_error("Only audio/video are supported.");
......
......@@ -5,13 +5,20 @@ namespace torchaudio {
namespace ffmpeg {
class FilterGraph {
AVMediaType media_type = AVMEDIA_TYPE_UNKNOWN;
// Parameters required for `reset`
// Recreates the underlying filter_graph struct
AVRational input_time_base;
AVCodecParameters* codecpar;
std::string filter_description;
// Constant just for convenient access.
AVMediaType media_type;
AVFilterGraphPtr pFilterGraph;
// AVFilterContext is freed as a part of AVFilterGraph
// so we do not manage the resource.
AVFilterContext* buffersrc_ctx = nullptr;
AVFilterContext* buffersink_ctx = nullptr;
const std::string filter_description;
public:
FilterGraph(
......@@ -35,8 +42,12 @@ class FilterGraph {
//////////////////////////////////////////////////////////////////////////////
// Configuration methods
//////////////////////////////////////////////////////////////////////////////
void init();
void reset();
private:
void add_src(AVRational time_base, AVCodecParameters* codecpar);
void add_src();
void add_sink();
......
......@@ -53,5 +53,11 @@ int Sink::process_frame(AVFrame* pFrame) {
bool Sink::is_buffer_ready() const {
  return buffer->is_ready();
}

void Sink::flush() {
  filter.reset();
  buffer->flush();
}
} // namespace ffmpeg
} // namespace torchaudio
......@@ -22,6 +22,8 @@ class Sink {
int process_frame(AVFrame* frame);
bool is_buffer_ready() const;
void flush();
};
} // namespace ffmpeg
......
......@@ -82,6 +82,13 @@ int StreamProcessor::process_packet(AVPacket* packet) {
  return ret;
}

void StreamProcessor::flush() {
  decoder.flush_buffer();
  for (auto& ite : sinks) {
    ite.second.flush();
  }
}
// 0: some kind of success
// <0: Some error happened
int StreamProcessor::send_frame(AVFrame* pFrame) {
......
......@@ -68,6 +68,10 @@ class StreamProcessor {
// - Sending NULL will drain (flush) the internal
int process_packet(AVPacket* packet);
// Flush the internal buffer of the decoder.
// To be used when seeking.
void flush();
private:
int send_frame(AVFrame* pFrame);
......
......@@ -136,11 +136,20 @@ bool Streamer::is_buffer_ready() const {
// Configure methods
////////////////////////////////////////////////////////////////////////////////
void Streamer::seek(double timestamp) {
  if (timestamp < 0) {
    throw std::invalid_argument("timestamp must be non-negative.");
  }
  int64_t ts = static_cast<int64_t>(timestamp * AV_TIME_BASE);
  int ret = avformat_seek_file(pFormatContext, -1, INT64_MIN, ts, INT64_MAX, 0);
  if (ret < 0) {
    throw std::runtime_error("Failed to seek. (" + av_err2string(ret) + ".)");
  }
  for (const auto& it : processors) {
    if (it) {
      it->flush();
    }
  }
}
void Streamer::add_audio_stream(
......