Commit 52b6bc3b authored by moto's avatar moto Committed by Facebook GitHub Bot
Browse files

Refactor chunked buffer implementation (#2984)

Summary:
So that the number of Tensor frames stored in buffers is always a multiple of frames_per_chunk.

This makes it easy to store PTS values in aligned manner.

Pull Request resolved: https://github.com/pytorch/audio/pull/2984

Reviewed By: nateanl

Differential Revision: D42526670

Pulled By: mthrok

fbshipit-source-id: d83ee914b7e50de3b51758069b0e0b6b3ebe2e54
parent 3ecf78d6
......@@ -553,6 +553,28 @@ class StreamReaderInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC
assert video.shape == torch.Size([390, 3, 270, 480])
assert audio.shape == torch.Size([208896, 2])
@parameterized.expand([(1,), (3,), (5,), (10,)])
def test_frames_per_chunk(self, fpc):
    """The returned media content is independent of ``frames_per_chunk``."""
    src = self.get_src()
    # Reference pass: decode everything at once (no chunking, unlimited buffer).
    reader = StreamReader(src)
    reader.add_video_stream(frames_per_chunk=-1, buffer_chunk_size=-1)
    reader.add_audio_stream(frames_per_chunk=-1, buffer_chunk_size=-1)
    reader.process_all_packets()
    ref_video, ref_audio = reader.pop_chunks()
    # File-like sources were consumed by the first pass; rewind for the second.
    if self.test_type == "fileobj":
        src.seek(0)
    # Second pass: decode in chunks of ``fpc`` frames, then stitch them back.
    reader = StreamReader(src)
    reader.add_video_stream(frames_per_chunk=fpc, buffer_chunk_size=-1)
    reader.add_audio_stream(frames_per_chunk=fpc, buffer_chunk_size=-1)
    video_parts = []
    audio_parts = []
    for chunk in reader.stream():
        if chunk[0] is not None:
            video_parts.append(chunk[0])
        if chunk[1] is not None:
            audio_parts.append(chunk[1])
    # Concatenated chunked output must match the one-shot reference exactly.
    self.assertEqual(ref_video, torch.cat(video_parts))
    self.assertEqual(ref_audio, torch.cat(audio_parts))
def test_buffer_chunk_size(self):
"""`buffer_chunk_size=-1` does not drop frames."""
src = self.get_src()
......
......@@ -20,108 +20,111 @@ bool ChunkedBuffer::is_ready() const {
return num_buffered_frames >= frames_per_chunk;
}
void ChunkedAudioBuffer::push_tensor(torch::Tensor frame) {
// Push
void ChunkedBuffer::push_tensor(torch::Tensor frame) {
using namespace torch::indexing;
// Note:
// For audio, the incoming tensor contains multiple of samples.
// For small `frames_per_chunk` value, it might be more than `max_frames`.
// If we push the tensor as-is, then, the whole frame might be popped at
// trimming stage, resulting buffer always empty. So we slice push the
// incoming Tensor.
// Check the last inserted Tensor and, if its number of frames is not
// frames_per_chunk, reprocess it together with the incoming tensor
if (num_buffered_frames % frames_per_chunk) {
// Audio tensors contain multiple frames while video tensors contain only
// one frame. Video tensors can be regarded as special degenerated case of
// audio, so in the following, we only consider audio processing.
//
// The incoming Tensor might contain more frames than the value of
// `frames_per_chunk`.
// If we push the input tensor to the deque as-is, then, at the trimming
// stage, all of its frames would be trimmed at once, which is not ideal.
// We want to keep at most `frames_per_chunk * num_chunks` frames.
// So we slice the incoming Tensor and push it chunk by chunk.
//
// 1. Check if the last chunk is fully filled. If not, fill it.
//
// <----- frames per chunk ----->^
// x x x x x x x x x x x x x x x |
// x x x x x x x + + + + + + - - | num_chunks
// - - - - - - - - - - - - - - - |
// <-- filled --><--- remain --->v
// <- append->
//
if (int64_t filled = num_buffered_frames % frames_per_chunk) {
int64_t num_frames = frame.size(0);
int64_t remain = frames_per_chunk - filled;
int64_t append = remain < num_frames ? remain : num_frames;
torch::Tensor prev = chunks.back();
chunks.pop_back();
num_buffered_frames -= prev.size(0);
frame = torch::cat({prev, frame}, 0);
// prev[filled:filled+append] = frame[:append]
prev.index_put_(
{Slice(filled, filled + append)}, frame.index({Slice(None, append)}));
num_buffered_frames += append;
// frame = frame[append:]
frame = frame.index({Slice(append)});
}
while (true) {
int64_t num_input_frames = frame.size(0);
if (num_input_frames <= frames_per_chunk) {
chunks.push_back(frame);
num_buffered_frames += num_input_frames;
break;
}
// The input tensor contains more frames than frames_per_chunk
auto splits =
torch::tensor_split(frame, {frames_per_chunk, num_input_frames});
chunks.push_back(splits[0]);
num_buffered_frames += frames_per_chunk;
frame = splits[1];
// 2. Return if no input frames remain after filling the partial chunk,
// i.e. all the frames have already been pushed.
if (frame.numel() == 0) {
return;
}
// Trim
// If frames_per_chunk > 0, we only retain the following number of frames and
// Discard older frames.
if (num_chunks > 0) {
int64_t max_frames = num_chunks * frames_per_chunk;
while (num_buffered_frames > max_frames) {
TORCH_WARN_ONCE(
"The number of buffered frames exceeded the buffer size. "
"Dropping the old frames. "
"To avoid this, you can set a higher buffer_chunk_size value.");
torch::Tensor& t = chunks.front();
num_buffered_frames -= t.size(0);
chunks.pop_front();
}
// 3. Now the existing buffer chunks are fully filled, start adding new chunks
//
// <----- frames per chunk ----->^
// x x x x x x x x x x x x x x x |
// x x x x x x x x x x x x x x x | num_chunks
// + + + + + + + + + + + + + + + |
// <---------- append ---------->v
//
int64_t num_frames = frame.size(0);
int64_t num_splits =
num_frames / frames_per_chunk + (num_frames % frames_per_chunk ? 1 : 0);
for (int64_t i = 0; i < num_splits; ++i) {
int64_t start = i * frames_per_chunk;
// chunk = frame[i*frames_per_chunk:(i+1) * frames_per_chunk]
auto chunk = frame.index({Slice(start, start + frames_per_chunk)});
int64_t chunk_size = chunk.size(0);
TORCH_INTERNAL_ASSERT(
chunk_size <= frames_per_chunk,
"Chunk size is larger than frames per chunk. Please file an issue.");
if (chunk_size < frames_per_chunk) {
auto shape = chunk.sizes().vec();
shape[0] = frames_per_chunk;
auto temp = torch::empty(shape, frame.options());
temp.index_put_({Slice(None, chunk_size)}, chunk);
chunk = temp;
}
}
chunks.push_back(chunk);
num_buffered_frames += chunk_size;
// Converts the incoming audio AVFrame to a Tensor and pushes it into the
// chunked buffer.
void ChunkedAudioBuffer::push_frame(AVFrame* frame) {
  push_tensor(detail::convert_audio(frame));
}
void ChunkedVideoBuffer::push_tensor(const torch::Tensor& frame) {
// each incoming video tensor is expected to contain only one frame
chunks.push_back(frame);
num_buffered_frames += frame.size(0);
// Trim
if (num_chunks > 0) {
int64_t max_frames = num_chunks * frames_per_chunk;
if (num_buffered_frames > max_frames) {
// Trim if num_chunks > 0
if (num_chunks > 0 && chunks.size() > num_chunks) {
TORCH_WARN_ONCE(
"The number of buffered frames exceeded the buffer size. "
"Dropping the old frames. "
"To avoid this, you can set a higher buffer_chunk_size value.");
torch::Tensor& t = chunks.front();
num_buffered_frames -= t.size(0);
chunks.pop_front();
num_buffered_frames -= frames_per_chunk;
}
}
}
// Converts the incoming video AVFrame to an image Tensor on the configured
// device and pushes it into the chunked buffer.
void ChunkedVideoBuffer::push_frame(AVFrame* frame) {
  push_tensor(detail::convert_image(frame, device));
}
c10::optional<torch::Tensor> ChunkedAudioBuffer::pop_chunk() {
// Pops one chunk (at most `frames_per_chunk` frames) from the front of the
// buffer. Returns an empty optional when no frames are buffered.
c10::optional<torch::Tensor> ChunkedBuffer::pop_chunk() {
  using namespace torch::indexing;
  if (!num_buffered_frames) {
    return {};
  }
  // Tensors in the deque are aligned to `frames_per_chunk`; only the last
  // one may be partially filled. When fewer than `frames_per_chunk` frames
  // remain, slice off the unfilled (padding) tail before returning.
  torch::Tensor ret = chunks.front();
  chunks.pop_front();
  if (num_buffered_frames < frames_per_chunk) {
    ret = ret.index({Slice(None, num_buffered_frames)});
  }
  num_buffered_frames -= ret.size(0);
  return c10::optional<torch::Tensor>{ret};
}
c10::optional<torch::Tensor> ChunkedVideoBuffer::pop_chunk() {
if (!num_buffered_frames) {
return {};
}
// The video deque stores one frame per tensor
std::vector<torch::Tensor> ret;
while (num_buffered_frames > 0 && ret.size() < frames_per_chunk) {
torch::Tensor& t = chunks.front();
ret.push_back(t);
chunks.pop_front();
num_buffered_frames -= 1;
}
return c10::optional<torch::Tensor>{torch::cat(ret, 0)};
// Converts the incoming audio AVFrame to a Tensor and pushes it into the
// chunked buffer.
void ChunkedAudioBuffer::push_frame(AVFrame* frame) {
  push_tensor(detail::convert_audio(frame));
}
// Converts the incoming video AVFrame to an image Tensor on the configured
// device and pushes it into the chunked buffer.
void ChunkedVideoBuffer::push_frame(AVFrame* frame) {
  push_tensor(detail::convert_image(frame, device));
}
void ChunkedBuffer::flush() {
......
......@@ -10,9 +10,6 @@ namespace ffmpeg {
//////////////////////////////////////////////////////////////////////////////
// Common to both audio and video
class ChunkedBuffer : public Buffer {
protected:
ChunkedBuffer(int frames_per_chunk, int num_chunks);
// Each AVFrame is converted to a Tensor and stored here.
std::deque<torch::Tensor> chunks;
......@@ -26,24 +23,26 @@ class ChunkedBuffer : public Buffer {
// one Tensor contains multiple samples, so we track here.
int64_t num_buffered_frames = 0;
protected:
ChunkedBuffer(int frames_per_chunk, int num_chunks);
void push_tensor(torch::Tensor frame);
public:
bool is_ready() const override;
void flush() override;
c10::optional<torch::Tensor> pop_chunk() override;
};
class ChunkedAudioBuffer : public ChunkedBuffer {
void push_tensor(torch::Tensor frame);
public:
ChunkedAudioBuffer(int frames_per_chunk, int num_chunks);
void push_frame(AVFrame* frame) override;
c10::optional<torch::Tensor> pop_chunk() override;
};
class ChunkedVideoBuffer : public ChunkedBuffer {
const torch::Device device;
void push_tensor(const torch::Tensor& frame);
public:
ChunkedVideoBuffer(
......@@ -52,7 +51,6 @@ class ChunkedVideoBuffer : public ChunkedBuffer {
const torch::Device& device);
void push_frame(AVFrame* frame) override;
c10::optional<torch::Tensor> pop_chunk() override;
};
} // namespace ffmpeg
......
......@@ -8,14 +8,11 @@ UnchunkedVideoBuffer::UnchunkedVideoBuffer(const torch::Device& device)
: device(device) {}
bool UnchunkedBuffer::is_ready() const {
return num_buffered_frames > 0;
return chunks.size() > 0;
}
void UnchunkedBuffer::push_tensor(const torch::Tensor& t) {
// If frames_per_chunk < 0, users want to fetch all frames.
// Just push back to chunks and that's it.
chunks.push_back(t);
num_buffered_frames += t.size(0);
}
void UnchunkedAudioBuffer::push_frame(AVFrame* frame) {
......@@ -27,19 +24,14 @@ void UnchunkedVideoBuffer::push_frame(AVFrame* frame) {
}
c10::optional<torch::Tensor> UnchunkedBuffer::pop_chunk() {
if (!num_buffered_frames) {
return c10::optional<torch::Tensor>{};
if (chunks.size() == 0) {
return {};
}
std::vector<torch::Tensor> ret;
while (chunks.size()) {
torch::Tensor& t = chunks.front();
int64_t n_frames = t.size(0);
ret.push_back(t);
chunks.pop_front();
num_buffered_frames -= n_frames;
}
return c10::optional<torch::Tensor>{torch::cat(ret, 0)};
auto ret =
torch::cat(std::vector<torch::Tensor>{chunks.begin(), chunks.end()}, 0);
chunks.clear();
return {ret};
}
void UnchunkedBuffer::flush() {
......
......@@ -16,11 +16,6 @@ class UnchunkedBuffer : public Buffer {
// Each AVFrame is converted to a Tensor and stored here.
std::deque<torch::Tensor> chunks;
// The number of currently stored chunks
// For video, one Tensor corresponds to one frame, but for audio,
// one Tensor contains multiple samples, so we track here.
int64_t num_buffered_frames = 0;
protected:
void push_tensor(const torch::Tensor& t);
......
......@@ -22,28 +22,27 @@ std::unique_ptr<Buffer> get_buffer(
"`num_chunks` must be positive or -1. Found: ",
num_chunks);
switch (type) {
case AVMEDIA_TYPE_AUDIO: {
if (frames_per_chunk < 0) {
return std::unique_ptr<Buffer>(new UnchunkedAudioBuffer());
} else {
TORCH_INTERNAL_ASSERT(
type == AVMEDIA_TYPE_AUDIO || type == AVMEDIA_TYPE_VIDEO,
"Unsupported media type: ",
av_get_media_type_string(type),
". Only video or audio is supported ");
// Chunked Mode
if (frames_per_chunk > 0) {
if (type == AVMEDIA_TYPE_AUDIO) {
return std::unique_ptr<Buffer>(
new ChunkedAudioBuffer(frames_per_chunk, num_chunks));
}
}
case AVMEDIA_TYPE_VIDEO: {
if (frames_per_chunk < 0) {
return std::unique_ptr<Buffer>(new UnchunkedVideoBuffer(device));
} else {
return std::unique_ptr<Buffer>(
new ChunkedVideoBuffer(frames_per_chunk, num_chunks, device));
}
} else { // unchunked mode
if (type == AVMEDIA_TYPE_AUDIO) {
return std::unique_ptr<Buffer>(new UnchunkedAudioBuffer());
} else {
return std::unique_ptr<Buffer>(new UnchunkedVideoBuffer(device));
}
default:
TORCH_CHECK(
false,
std::string("Unsupported media type: ") +
av_get_media_type_string(type));
}
}
......@@ -121,10 +120,6 @@ std::string Sink::get_filter_description() const {
return filter_description;
}
bool Sink::is_buffer_ready() const {
return buffer->is_ready();
}
void Sink::flush() {
filter = get_filter_graph(input_time_base, codecpar, filter_description);
buffer->flush();
......
......@@ -62,7 +62,7 @@ std::string StreamProcessor::get_filter_description(KeyType key) const {
bool StreamProcessor::is_buffer_ready() const {
for (const auto& it : sinks) {
if (!it.second.is_buffer_ready()) {
if (!it.second.buffer->is_ready()) {
return false;
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment