Commit 52b6bc3b authored by moto, committed by Facebook GitHub Bot

Refactor chunked buffer implementation (#2984)

Summary:
Refactor the chunked buffer implementation so that the number of Tensor frames stored in the buffers is always a multiple of `frames_per_chunk`.

This makes it easy to store PTS values in an aligned manner.

Pull Request resolved: https://github.com/pytorch/audio/pull/2984

Reviewed By: nateanl

Differential Revision: D42526670

Pulled By: mthrok

fbshipit-source-id: d83ee914b7e50de3b51758069b0e0b6b3ebe2e54
parent 3ecf78d6
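
For context, here is a minimal Python sketch of the invariant this change establishes. This is not the torchaudio source: the class name `AlignedChunkBuffer` and its methods are hypothetical, and the trimming/warning logic is omitted. It shows that every tensor stored in the deque has exactly `frames_per_chunk` rows (the tail chunk is zero-padded), so chunk boundaries always fall on multiples of `frames_per_chunk` and per-chunk metadata such as PTS can be kept as one value per deque entry.

    # Hypothetical sketch of the chunk-aligned buffer invariant (not the
    # torchaudio source). Assumes only torch.
    from collections import deque

    import torch


    class AlignedChunkBuffer:
        def __init__(self, frames_per_chunk: int):
            self.frames_per_chunk = frames_per_chunk
            self.chunks = deque()         # each entry: frames_per_chunk rows
            self.num_buffered_frames = 0  # tail chunk may be partially valid

        def push(self, frames: torch.Tensor):
            fpc = self.frames_per_chunk
            # 1. Fill the tail of the last, partially filled chunk.
            filled = self.num_buffered_frames % fpc
            if filled:
                append = min(fpc - filled, frames.size(0))
                self.chunks[-1][filled:filled + append] = frames[:append]
                self.num_buffered_frames += append
                frames = frames[append:]
            # 2. Slice the rest into fixed-size chunks, zero-padding the last.
            for start in range(0, frames.size(0), fpc):
                chunk = frames[start:start + fpc]
                n = chunk.size(0)
                if n < fpc:
                    padded = torch.zeros((fpc,) + chunk.shape[1:], dtype=chunk.dtype)
                    padded[:n] = chunk
                    chunk = padded
                self.chunks.append(chunk)
                self.num_buffered_frames += n

        def pop(self):
            # 3. Pop one aligned chunk; trim the padding when fewer than
            # frames_per_chunk valid frames remain.
            if not self.num_buffered_frames:
                return None
            ret = self.chunks.popleft()
            if self.num_buffered_frames < self.frames_per_chunk:
                ret = ret[:self.num_buffered_frames]
            self.num_buffered_frames -= ret.size(0)
            return ret

The `pop` mirrors the new `ChunkedBuffer::pop_chunk` in the diff below: it returns full chunks until fewer than `frames_per_chunk` valid frames remain, then slices off the padding.
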
@@ -553,6 +553,28 @@ class StreamReaderInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC
         assert video.shape == torch.Size([390, 3, 270, 480])
         assert audio.shape == torch.Size([208896, 2])
 
+    @parameterized.expand([(1,), (3,), (5,), (10,)])
+    def test_frames_per_chunk(self, fpc):
+        """Changing frames_per_chunk does not change the returned content"""
+        src = self.get_src()
+        s = StreamReader(src)
+        s.add_video_stream(frames_per_chunk=-1, buffer_chunk_size=-1)
+        s.add_audio_stream(frames_per_chunk=-1, buffer_chunk_size=-1)
+        s.process_all_packets()
+        ref_video, ref_audio = s.pop_chunks()
+
+        if self.test_type == "fileobj":
+            src.seek(0)
+        s = StreamReader(src)
+        s.add_video_stream(frames_per_chunk=fpc, buffer_chunk_size=-1)
+        s.add_audio_stream(frames_per_chunk=fpc, buffer_chunk_size=-1)
+        chunks = list(s.stream())
+        video_chunks = torch.cat([c[0] for c in chunks if c[0] is not None])
+        audio_chunks = torch.cat([c[1] for c in chunks if c[1] is not None])
+        self.assertEqual(ref_video, video_chunks)
+        self.assertEqual(ref_audio, audio_chunks)
+
     def test_buffer_chunk_size(self):
         """`buffer_chunk_size=-1` does not drop frames."""
         src = self.get_src()
......
@@ -20,108 +20,111 @@ bool ChunkedBuffer::is_ready() const {
   return num_buffered_frames >= frames_per_chunk;
 }
 
-void ChunkedAudioBuffer::push_tensor(torch::Tensor frame) {
-  // Push
+void ChunkedBuffer::push_tensor(torch::Tensor frame) {
+  using namespace torch::indexing;
   // Note:
-  // For audio, the incoming tensor contains multiple of samples.
-  // For small `frames_per_chunk` value, it might be more than `max_frames`.
-  // If we push the tensor as-is, then, the whole frame might be popped at
-  // trimming stage, resulting buffer always empty. So we slice push the
-  // incoming Tensor.
-
-  // Check the last inserted Tensor and if the numbe of frames is not
-  // frame_per_chunk, reprocess it again with the incomping tensor
-  if (num_buffered_frames % frames_per_chunk) {
+  // Audio tensors contain multiple frames while video tensors contain only
+  // one frame. Video tensors can be regarded as a special degenerate case of
+  // audio, so in the following, we only consider audio processing.
+  //
+  // The incoming Tensor might contain more frames than the value of
+  // `frames_per_chunk`.
+  // If we pushed the input tensor to the deque as-is, then, at the trimming
+  // stage, all of its frames would be trimmed at once, which is not ideal.
+  // We want to keep at most `frames_per_chunk * num_chunks` frames.
+  // So we slice-push the incoming Tensor.
+  //
+  // 1. Check if the last chunk is fully filled. If not, fill it.
+  //
+  //  <----- frames per chunk ----->^
+  //  x x x x x x x x x x x x x x x |
+  //  x x x x x x x + + + + + + - - | num_chunks
+  //  - - - - - - - - - - - - - - - |
+  //  <-- filled --><--- remain --->v
+  //                <- append->
+  //
+  if (int64_t filled = num_buffered_frames % frames_per_chunk) {
+    int64_t num_frames = frame.size(0);
+    int64_t remain = frames_per_chunk - filled;
+    int64_t append = remain < num_frames ? remain : num_frames;
     torch::Tensor prev = chunks.back();
-    chunks.pop_back();
-    num_buffered_frames -= prev.size(0);
-    frame = torch::cat({prev, frame}, 0);
+    // prev[filled:filled+append] = frame[:append]
+    prev.index_put_(
+        {Slice(filled, filled + append)}, frame.index({Slice(None, append)}));
+    num_buffered_frames += append;
+    // frame = frame[append:]
+    frame = frame.index({Slice(append)});
   }
 
-  while (true) {
-    int64_t num_input_frames = frame.size(0);
-    if (num_input_frames <= frames_per_chunk) {
-      chunks.push_back(frame);
-      num_buffered_frames += num_input_frames;
-      break;
-    }
-    // The input tensor contains more frames than frames_per_chunk
-    auto splits =
-        torch::tensor_split(frame, {frames_per_chunk, num_input_frames});
-    chunks.push_back(splits[0]);
-    num_buffered_frames += frames_per_chunk;
-    frame = splits[1];
+  // 2. Return if all the input frames were consumed while filling the last
+  // chunk, i.e. there is nothing left to push.
+  if (frame.numel() == 0) {
+    return;
   }
 
-  // Trim
-  // If frames_per_chunk > 0, we only retain the following number of frames and
-  // Discard older frames.
-  if (num_chunks > 0) {
-    int64_t max_frames = num_chunks * frames_per_chunk;
-    while (num_buffered_frames > max_frames) {
-      TORCH_WARN_ONCE(
-          "The number of buffered frames exceeded the buffer size. "
-          "Dropping the old frames. "
-          "To avoid this, you can set a higher buffer_chunk_size value.");
-      torch::Tensor& t = chunks.front();
-      num_buffered_frames -= t.size(0);
-      chunks.pop_front();
+  // 3. Now the existing buffer chunks are fully filled. Start adding new
+  // chunks.
+  //
+  //  <----- frames per chunk ----->^
+  //  x x x x x x x x x x x x x x x |
+  //  x x x x x x x x x x x x x x x | num_chunks
+  //  + + + + + + + + + + + + + + + |
+  //  <---------- append ---------->v
+  //
+  int64_t num_frames = frame.size(0);
+  int64_t num_splits =
+      num_frames / frames_per_chunk + (num_frames % frames_per_chunk ? 1 : 0);
+  for (int64_t i = 0; i < num_splits; ++i) {
+    int64_t start = i * frames_per_chunk;
+    // chunk = frame[i*frames_per_chunk:(i+1)*frames_per_chunk]
+    auto chunk = frame.index({Slice(start, start + frames_per_chunk)});
+    int64_t chunk_size = chunk.size(0);
+    TORCH_INTERNAL_ASSERT(
+        chunk_size <= frames_per_chunk,
+        "Chunk size is larger than frames per chunk. Please file an issue.");
+    // Allocate a full-sized chunk so that the deque is always aligned to
+    // `frames_per_chunk`, even when the tail chunk is only partially filled.
+    if (chunk_size < frames_per_chunk) {
+      auto shape = chunk.sizes().vec();
+      shape[0] = frames_per_chunk;
+      auto temp = torch::empty(shape, frame.options());
+      temp.index_put_({Slice(None, chunk_size)}, chunk);
+      chunk = temp;
     }
-  }
-}
-
-void ChunkedAudioBuffer::push_frame(AVFrame* frame) {
-  push_tensor(detail::convert_audio(frame));
-}
-
-void ChunkedVideoBuffer::push_tensor(const torch::Tensor& frame) {
-  // the video frames is expected to contain only one frame
-  chunks.push_back(frame);
-  num_buffered_frames += frame.size(0);
-
-  // Trim
-  if (num_chunks > 0) {
-    int64_t max_frames = num_chunks * frames_per_chunk;
-    if (num_buffered_frames > max_frames) {
+    chunks.push_back(chunk);
+    num_buffered_frames += chunk_size;
+    // Trim if num_chunks > 0
+    if (num_chunks > 0 && chunks.size() > num_chunks) {
       TORCH_WARN_ONCE(
           "The number of buffered frames exceeded the buffer size. "
           "Dropping the old frames. "
          "To avoid this, you can set a higher buffer_chunk_size value.");
-      torch::Tensor& t = chunks.front();
-      num_buffered_frames -= t.size(0);
       chunks.pop_front();
+      num_buffered_frames -= frames_per_chunk;
     }
   }
 }
 
-void ChunkedVideoBuffer::push_frame(AVFrame* frame) {
-  push_tensor(detail::convert_image(frame, device));
-}
-
-c10::optional<torch::Tensor> ChunkedAudioBuffer::pop_chunk() {
+c10::optional<torch::Tensor> ChunkedBuffer::pop_chunk() {
+  using namespace torch::indexing;
   if (!num_buffered_frames) {
     return {};
   }
-  // Audio deque are aligned with `frames_per_chunk`
   torch::Tensor ret = chunks.front();
   chunks.pop_front();
+  // The deque is aligned to `frames_per_chunk`; when the remaining frames do
+  // not fill a whole chunk, return only the valid portion.
+  if (num_buffered_frames < frames_per_chunk) {
+    ret = ret.index({Slice(None, num_buffered_frames)});
+  }
   num_buffered_frames -= ret.size(0);
   return c10::optional<torch::Tensor>{ret};
 }
 
-c10::optional<torch::Tensor> ChunkedVideoBuffer::pop_chunk() {
-  if (!num_buffered_frames) {
-    return {};
-  }
-  // Video deque contains one frame par one tensor
-  std::vector<torch::Tensor> ret;
-  while (num_buffered_frames > 0 && ret.size() < frames_per_chunk) {
-    torch::Tensor& t = chunks.front();
-    ret.push_back(t);
-    chunks.pop_front();
-    num_buffered_frames -= 1;
-  }
-  return c10::optional<torch::Tensor>{torch::cat(ret, 0)};
+void ChunkedAudioBuffer::push_frame(AVFrame* frame) {
+  push_tensor(detail::convert_audio(frame));
+}
+
+void ChunkedVideoBuffer::push_frame(AVFrame* frame) {
+  push_tensor(detail::convert_image(frame, device));
 }
 
 void ChunkedBuffer::flush() {
......
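
To make the slice-push above concrete, here is a standalone trace (assuming only torch; the variable names are illustrative, not from the source) of step 3 with frames_per_chunk = 3 and 8 incoming frames: the loop produces two full chunks and one zero-padded tail chunk, so the deque stays aligned.

    import torch

    fpc = 3                                   # frames_per_chunk
    frames = torch.arange(8.0).reshape(8, 1)  # 8 frames, 1 channel

    # Ceil-divide, as `num_splits` is computed in the C++ loop above.
    num_splits = frames.size(0) // fpc + (1 if frames.size(0) % fpc else 0)
    chunks = []
    for i in range(num_splits):
        chunk = frames[i * fpc:(i + 1) * fpc]  # frame[i*fpc : (i+1)*fpc]
        if chunk.size(0) < fpc:                # zero-pad the tail chunk
            padded = torch.zeros(fpc, frames.size(1))
            padded[:chunk.size(0)] = chunk
            chunk = padded
        chunks.append(chunk)

    # The alignment invariant: every stored tensor has exactly fpc rows.
    assert all(c.size(0) == fpc for c in chunks)
    assert len(chunks) == 3                    # [0,1,2], [3,4,5], [6,7,0(pad)]
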
@@ -10,9 +10,6 @@ namespace ffmpeg {
 //////////////////////////////////////////////////////////////////////////////
 // Common to both audio and video
 class ChunkedBuffer : public Buffer {
- protected:
-  ChunkedBuffer(int frames_per_chunk, int num_chunks);
-
   // Each AVFrame is converted to a Tensor and stored here.
   std::deque<torch::Tensor> chunks;
@@ -26,24 +23,26 @@ class ChunkedBuffer : public Buffer {
   // one Tensor contains multiple samples, so we track here.
   int64_t num_buffered_frames = 0;
 
+ protected:
+  ChunkedBuffer(int frames_per_chunk, int num_chunks);
+
+  void push_tensor(torch::Tensor frame);
+
  public:
   bool is_ready() const override;
   void flush() override;
+  c10::optional<torch::Tensor> pop_chunk() override;
 };
 
 class ChunkedAudioBuffer : public ChunkedBuffer {
-  void push_tensor(torch::Tensor frame);
-
  public:
   ChunkedAudioBuffer(int frames_per_chunk, int num_chunks);
   void push_frame(AVFrame* frame) override;
-  c10::optional<torch::Tensor> pop_chunk() override;
 };
 
 class ChunkedVideoBuffer : public ChunkedBuffer {
   const torch::Device device;
 
-  void push_tensor(const torch::Tensor& frame);
-
  public:
   ChunkedVideoBuffer(
@@ -52,7 +51,6 @@ class ChunkedVideoBuffer : public ChunkedBuffer {
       const torch::Device& device);
 
   void push_frame(AVFrame* frame) override;
-  c10::optional<torch::Tensor> pop_chunk() override;
 };
 
 } // namespace ffmpeg
......
@@ -8,14 +8,11 @@ UnchunkedVideoBuffer::UnchunkedVideoBuffer(const torch::Device& device)
     : device(device) {}
 
 bool UnchunkedBuffer::is_ready() const {
-  return num_buffered_frames > 0;
+  return chunks.size() > 0;
 }
 
 void UnchunkedBuffer::push_tensor(const torch::Tensor& t) {
-  // If frames_per_chunk < 0, users want to fetch all frames.
-  // Just push back to chunks and that's it.
   chunks.push_back(t);
-  num_buffered_frames += t.size(0);
 }
 
 void UnchunkedAudioBuffer::push_frame(AVFrame* frame) {
@@ -27,19 +24,14 @@ void UnchunkedVideoBuffer::push_frame(AVFrame* frame) {
 }
 
 c10::optional<torch::Tensor> UnchunkedBuffer::pop_chunk() {
-  if (!num_buffered_frames) {
-    return c10::optional<torch::Tensor>{};
+  if (chunks.size() == 0) {
+    return {};
   }
-  std::vector<torch::Tensor> ret;
-  while (chunks.size()) {
-    torch::Tensor& t = chunks.front();
-    int64_t n_frames = t.size(0);
-    ret.push_back(t);
-    chunks.pop_front();
-    num_buffered_frames -= n_frames;
-  }
-  return c10::optional<torch::Tensor>{torch::cat(ret, 0)};
+  auto ret =
+      torch::cat(std::vector<torch::Tensor>{chunks.begin(), chunks.end()}, 0);
+  chunks.clear();
+  return {ret};
 }
 
 void UnchunkedBuffer::flush() {
......
@@ -16,11 +16,6 @@ class UnchunkedBuffer : public Buffer {
   // Each AVFrame is converted to a Tensor and stored here.
   std::deque<torch::Tensor> chunks;
 
-  // The number of currently stored chunks
-  // For video, one Tensor corresponds to one frame, but for audio,
-  // one Tensor contains multiple samples, so we track here.
-  int64_t num_buffered_frames = 0;
-
  protected:
   void push_tensor(const torch::Tensor& t);
......
@@ -22,28 +22,27 @@ std::unique_ptr<Buffer> get_buffer(
       "`num_chunks` must be positive or -1. Found: ",
       num_chunks);
 
-  switch (type) {
-    case AVMEDIA_TYPE_AUDIO: {
-      if (frames_per_chunk < 0) {
-        return std::unique_ptr<Buffer>(new UnchunkedAudioBuffer());
-      } else {
-        return std::unique_ptr<Buffer>(
-            new ChunkedAudioBuffer(frames_per_chunk, num_chunks));
-      }
-    }
-    case AVMEDIA_TYPE_VIDEO: {
-      if (frames_per_chunk < 0) {
-        return std::unique_ptr<Buffer>(new UnchunkedVideoBuffer(device));
-      } else {
-        return std::unique_ptr<Buffer>(
-            new ChunkedVideoBuffer(frames_per_chunk, num_chunks, device));
-      }
-    }
-    default:
-      TORCH_CHECK(
-          false,
-          std::string("Unsupported media type: ") +
-              av_get_media_type_string(type));
+  TORCH_INTERNAL_ASSERT(
+      type == AVMEDIA_TYPE_AUDIO || type == AVMEDIA_TYPE_VIDEO,
+      "Unsupported media type: ",
+      av_get_media_type_string(type),
+      ". Only video or audio is supported.");
+
+  // Chunked mode
+  if (frames_per_chunk > 0) {
+    if (type == AVMEDIA_TYPE_AUDIO) {
+      return std::unique_ptr<Buffer>(
+          new ChunkedAudioBuffer(frames_per_chunk, num_chunks));
+    } else {
+      return std::unique_ptr<Buffer>(
+          new ChunkedVideoBuffer(frames_per_chunk, num_chunks, device));
+    }
+  } else { // Unchunked mode
+    if (type == AVMEDIA_TYPE_AUDIO) {
+      return std::unique_ptr<Buffer>(new UnchunkedAudioBuffer());
+    } else {
+      return std::unique_ptr<Buffer>(new UnchunkedVideoBuffer(device));
+    }
   }
 }
@@ -121,10 +120,6 @@ std::string Sink::get_filter_description() const {
   return filter_description;
 }
 
-bool Sink::is_buffer_ready() const {
-  return buffer->is_ready();
-}
-
 void Sink::flush() {
   filter = get_filter_graph(input_time_base, codecpar, filter_description);
   buffer->flush();
......
@@ -62,7 +62,7 @@ std::string StreamProcessor::get_filter_description(KeyType key) const {
 
 bool StreamProcessor::is_buffer_ready() const {
   for (const auto& it : sinks) {
-    if (!it.second.is_buffer_ready()) {
+    if (!it.second.buffer->is_ready()) {
       return false;
     }
   }
......
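
Finally, the user-visible contract exercised by the new test, as a usage sketch. This assumes the `torchaudio.io.StreamReader` API as used in the test above; "input.mp4" is a placeholder path. Chunks come out aligned to `frames_per_chunk`, and concatenating them reproduces the full stream regardless of the chunk size chosen.

    import torch
    from torchaudio.io import StreamReader

    s = StreamReader("input.mp4")  # placeholder media path
    s.add_audio_stream(frames_per_chunk=1024, buffer_chunk_size=-1)

    # stream() yields one tuple per iteration, one entry per output stream.
    chunks = [audio for (audio,) in s.stream() if audio is not None]

    # Concatenating the chunks reproduces the full stream content, which is
    # what test_frames_per_chunk verifies for several frames_per_chunk values.
    waveform = torch.cat(chunks)
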