Commit f6e3d070 authored by Moto Hira's avatar Moto Hira Committed by Facebook GitHub Bot
Browse files

Refactor the internal of StreamReader (#3188)

Summary:
Pull Request resolved: https://github.com/pytorch/audio/pull/3188

Refactor the post-decoding process in StreamReader.

The post-decode process consists of three parts,
1. preprocessing using FilterGraph
2. conversion to Tensor
3. store in Buffer

The FilterGraph class is a thin wrapper around AVFilterGraph
structure from FFmpeg, and it is agnostic to media type. However,
Tensor conversion and buffering consist of a number of different
pieces of logic.

Currently, the conversion process is abstracted away with a
template, i.e. `template<typename Conversion> Buffer`, and the whole
process is implemented in the Sink class, which consists of `FilterGraph`
and `Buffer` (the latter internally containing the Conversion logic), even
though the conversion logic and the buffer have nothing in common and are
better kept logically separated.

The new implementation replaces `Sink` class with `IPostDecodeProcess`
interface, which contains the three components.
The different post process is implemented as a template argument of the
actual implementation, i.e.

```c++
template<typename Converter, typename Buffer>
struct ProcessImpl : IPostDecodeProcess
```

and stored as `unique_ptr<IPostDecodeProcess>` on `StreamProcessor`.
([functionoid pattern](https://isocpp.org/wiki/faq/pointers-to-members#functionoids), which allows to eliminate all the branching based on the media format.)

Note:
This implementation was not possible in the initial version of
StreamReader, as there was no way to know the media attributes coming out
of `AVFilterGraph`. https://github.com/pytorch/audio/pull/3155 and https://github.com/pytorch/audio/pull/3183
added features to parse it properly, so we can finally make the post processing strongly-typed.

Reviewed By: hwangjeff

Differential Revision: D44242647

fbshipit-source-id: 96b8c6c72a2b8af4fa86a9b02292c65078ee265b
parent c17226a0
...@@ -13,7 +13,7 @@ set( ...@@ -13,7 +13,7 @@ set(
stream_reader/buffer/chunked_buffer.cpp stream_reader/buffer/chunked_buffer.cpp
stream_reader/buffer/unchunked_buffer.cpp stream_reader/buffer/unchunked_buffer.cpp
stream_reader/conversion.cpp stream_reader/conversion.cpp
stream_reader/sink.cpp stream_reader/post_process.cpp
stream_reader/stream_processor.cpp stream_reader/stream_processor.cpp
stream_reader/stream_reader.cpp stream_reader/stream_reader.cpp
stream_writer/encode_process.cpp stream_writer/encode_process.cpp
......
#pragma once
#include <torch/torch.h>
#include <torchaudio/csrc/ffmpeg/ffmpeg.h>
#include <torchaudio/csrc/ffmpeg/stream_reader/typedefs.h>
namespace torchaudio {
namespace io {
//////////////////////////////////////////////////////////////////////////////
// Buffer Interface
//////////////////////////////////////////////////////////////////////////////
// Abstract interface for post-decode frame buffers used by StreamReader.
// Implementations accept decoded AVFrames one at a time and hand the data
// back out as Chunks (Tensor + timestamp).
class Buffer {
 public:
  virtual ~Buffer() = default;
  //////////////////////////////////////////////////////////////////////////////
  // Query
  //////////////////////////////////////////////////////////////////////////////
  // Check if the buffer has enough frames to produce a chunk.
  virtual bool is_ready() const = 0;
  //////////////////////////////////////////////////////////////////////////////
  // Modifiers
  //////////////////////////////////////////////////////////////////////////////
  // Store the given frame (implementations convert it to a Tensor).
  virtual void push_frame(AVFrame* frame) = 0;
  // Remove and return the oldest chunk; empty optional when nothing is ready.
  virtual c10::optional<Chunk> pop_chunk() = 0;
  // Discard all buffered data.
  virtual void flush() = 0;
};
} // namespace io
} // namespace torchaudio
#include <torchaudio/csrc/ffmpeg/stream_reader/buffer/chunked_buffer.h> #include <torchaudio/csrc/ffmpeg/stream_reader/buffer/chunked_buffer.h>
#include <torchaudio/csrc/ffmpeg/stream_reader/conversion.h>
namespace torchaudio::io::detail { namespace torchaudio::io::detail {
template <typename Converter> ChunkedBuffer::ChunkedBuffer(
ChunkedBuffer<Converter>::ChunkedBuffer(
AVRational time_base, AVRational time_base,
int frames_per_chunk_, int frames_per_chunk_,
int num_chunks_, int num_chunks_)
Converter&& converter_)
: time_base(time_base), : time_base(time_base),
frames_per_chunk(frames_per_chunk_), frames_per_chunk(frames_per_chunk_),
num_chunks(num_chunks_), num_chunks(num_chunks_){};
converter(std::move(converter_)){};
template <typename Converter> bool ChunkedBuffer::is_ready() const {
bool ChunkedBuffer<Converter>::is_ready() const {
return num_buffered_frames >= frames_per_chunk; return num_buffered_frames >= frames_per_chunk;
} }
template <typename Converter> void ChunkedBuffer::push_frame(torch::Tensor frame, int64_t pts_) {
void ChunkedBuffer<Converter>::push_frame(AVFrame* frame_) {
int64_t pts_ = frame_->pts;
torch::Tensor frame = converter.convert(frame_);
using namespace torch::indexing; using namespace torch::indexing;
// Note: // Note:
// Audio tensors contain multiple frames while video tensors contain only // Audio tensors contain multiple frames while video tensors contain only
...@@ -114,8 +105,7 @@ void ChunkedBuffer<Converter>::push_frame(AVFrame* frame_) { ...@@ -114,8 +105,7 @@ void ChunkedBuffer<Converter>::push_frame(AVFrame* frame_) {
} }
} }
template <typename Converter> c10::optional<Chunk> ChunkedBuffer::pop_chunk() {
c10::optional<Chunk> ChunkedBuffer<Converter>::pop_chunk() {
using namespace torch::indexing; using namespace torch::indexing;
if (!num_buffered_frames) { if (!num_buffered_frames) {
return {}; return {};
...@@ -131,171 +121,9 @@ c10::optional<Chunk> ChunkedBuffer<Converter>::pop_chunk() { ...@@ -131,171 +121,9 @@ c10::optional<Chunk> ChunkedBuffer<Converter>::pop_chunk() {
return {Chunk{chunk, pts_val}}; return {Chunk{chunk, pts_val}};
} }
template <typename Converter> void ChunkedBuffer::flush() {
void ChunkedBuffer<Converter>::flush() {
num_buffered_frames = 0; num_buffered_frames = 0;
chunks.clear(); chunks.clear();
} }
// Instantiate a ChunkedBuffer with the audio converter matching the given
// sample format.
//
// tb         : time base of the stream coming out of the filter graph
// fpc        : frames per chunk
// num_chunks : maximum number of chunks retained
// fmt        : sample format reported by the filter graph
// channels   : number of audio channels, forwarded to the converter
std::unique_ptr<Buffer> get_chunked_buffer(
    AVRational tb,
    int fpc,
    int num_chunks,
    AVSampleFormat fmt,
    int channels) {
  // Helper: build the buffer for a concrete converter instance. The switch
  // below only selects the converter type; the assembly is identical.
  auto assemble = [&](auto conv) -> std::unique_ptr<Buffer> {
    using Conv = decltype(conv);
    return std::make_unique<ChunkedBuffer<Conv>>(
        tb, fpc, num_chunks, std::move(conv));
  };
  switch (fmt) {
    // Interleaved sample formats
    case AV_SAMPLE_FMT_U8:
      return assemble(AudioConverter<torch::kUInt8, false>{channels});
    case AV_SAMPLE_FMT_S16:
      return assemble(AudioConverter<torch::kInt16, false>{channels});
    case AV_SAMPLE_FMT_S32:
      return assemble(AudioConverter<torch::kInt32, false>{channels});
    case AV_SAMPLE_FMT_S64:
      return assemble(AudioConverter<torch::kInt64, false>{channels});
    case AV_SAMPLE_FMT_FLT:
      return assemble(AudioConverter<torch::kFloat32, false>{channels});
    case AV_SAMPLE_FMT_DBL:
      return assemble(AudioConverter<torch::kFloat64, false>{channels});
    // Planar sample formats
    case AV_SAMPLE_FMT_U8P:
      return assemble(AudioConverter<torch::kUInt8, true>{channels});
    case AV_SAMPLE_FMT_S16P:
      return assemble(AudioConverter<torch::kInt16, true>{channels});
    case AV_SAMPLE_FMT_S32P:
      return assemble(AudioConverter<torch::kInt32, true>{channels});
    case AV_SAMPLE_FMT_S64P:
      return assemble(AudioConverter<torch::kInt64, true>{channels});
    case AV_SAMPLE_FMT_FLTP:
      return assemble(AudioConverter<torch::kFloat32, true>{channels});
    case AV_SAMPLE_FMT_DBLP:
      return assemble(AudioConverter<torch::kFloat64, true>{channels});
    default:
      TORCH_INTERNAL_ASSERT(
          false, "Unexpected audio type:", av_get_sample_fmt_name(fmt));
  }
}
// Instantiate a ChunkedBuffer with the image converter matching the given
// pixel format and device. `h`/`w` are the frame height/width passed to the
// converter; the trailing integer argument of the CPU converters is the
// number of channels implied by the pixel format.
std::unique_ptr<Buffer> get_chunked_buffer(
    AVRational tb,
    int fpc,
    int num_chunks,
    AVPixelFormat fmt,
    int h,
    int w,
    const torch::Device& device) {
  // CUDA device: HW frames need the dedicated CUDA converters.
  if (device.type() == at::DeviceType::CUDA) {
#ifndef USE_CUDA
    // A CUDA device should have been rejected upstream in builds without
    // CUDA support; this check only fires in debug builds.
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
        false,
        "USE_CUDA is not defined, and it should be guarded before here.");
#else
    switch (fmt) {
      case AV_PIX_FMT_NV12: {
        using Conv = NV12CudaConverter;
        return std::make_unique<ChunkedBuffer<Conv>>(
            tb, fpc, num_chunks, Conv{h, w, device});
      }
      case AV_PIX_FMT_P010: {
        using Conv = P010CudaConverter;
        return std::make_unique<ChunkedBuffer<Conv>>(
            tb, fpc, num_chunks, Conv{h, w, device});
      }
      case AV_PIX_FMT_P016: {
        // Recognized format, but no converter is implemented for it.
        TORCH_CHECK(
            false,
            "Unsupported video format found in CUDA HW: ",
            av_get_pix_fmt_name(fmt));
      }
      default: {
        TORCH_CHECK(
            false,
            "Unexpected video format found in CUDA HW: ",
            av_get_pix_fmt_name(fmt));
      }
    }
#endif
  }
  // CPU pixel formats.
  switch (fmt) {
    // 3-channel interlaced formats
    case AV_PIX_FMT_RGB24:
    case AV_PIX_FMT_BGR24: {
      using Converter = InterlacedImageConverter;
      return std::make_unique<ChunkedBuffer<Converter>>(
          tb, fpc, num_chunks, Converter{h, w, 3});
    }
    // 4-channel interlaced formats
    case AV_PIX_FMT_ARGB:
    case AV_PIX_FMT_RGBA:
    case AV_PIX_FMT_ABGR:
    case AV_PIX_FMT_BGRA: {
      using Converter = InterlacedImageConverter;
      return std::make_unique<ChunkedBuffer<Converter>>(
          tb, fpc, num_chunks, Converter{h, w, 4});
    }
    // single-channel
    case AV_PIX_FMT_GRAY8: {
      using Converter = InterlacedImageConverter;
      return std::make_unique<ChunkedBuffer<Converter>>(
          tb, fpc, num_chunks, Converter{h, w, 1});
    }
    // 16-bit-per-component RGB
    case AV_PIX_FMT_RGB48LE: {
      using Converter = Interlaced16BitImageConverter;
      return std::make_unique<ChunkedBuffer<Converter>>(
          tb, fpc, num_chunks, Converter{h, w, 3});
    }
    case AV_PIX_FMT_YUV444P: {
      using Converter = PlanarImageConverter;
      return std::make_unique<ChunkedBuffer<Converter>>(
          tb, fpc, num_chunks, Converter{h, w, 3});
    }
    case AV_PIX_FMT_YUV420P: {
      using Converter = YUV420PConverter;
      return std::make_unique<ChunkedBuffer<Converter>>(
          tb, fpc, num_chunks, Converter{h, w});
    }
    case AV_PIX_FMT_NV12: {
      using Converter = NV12Converter;
      return std::make_unique<ChunkedBuffer<Converter>>(
          tb, fpc, num_chunks, Converter{h, w});
    }
    default: {
      TORCH_INTERNAL_ASSERT(
          false, "Unexpected video format found: ", av_get_pix_fmt_name(fmt));
    }
  }
}
} // namespace torchaudio::io::detail } // namespace torchaudio::io::detail
#pragma once #pragma once
#include <torchaudio/csrc/ffmpeg/ffmpeg.h> #include <torchaudio/csrc/ffmpeg/ffmpeg.h>
#include <torchaudio/csrc/ffmpeg/stream_reader/buffer.h> #include <torchaudio/csrc/ffmpeg/stream_reader/typedefs.h>
namespace torchaudio::io::detail { namespace torchaudio::io::detail {
////////////////////////////////////////////////////////////////////////////// class ChunkedBuffer {
// Chunked Buffer Implementation
//////////////////////////////////////////////////////////////////////////////
// Common to both audio and video
template <typename Converter>
class ChunkedBuffer : public Buffer {
// Each AVFrame is converted to a Tensor and stored here. // Each AVFrame is converted to a Tensor and stored here.
std::deque<torch::Tensor> chunks; std::deque<torch::Tensor> chunks;
// Time stamps corresponding the first frame of each chunk // Time stamps corresponding the first frame of each chunk
...@@ -26,35 +21,13 @@ class ChunkedBuffer : public Buffer { ...@@ -26,35 +21,13 @@ class ChunkedBuffer : public Buffer {
// one Tensor contains multiple samples, so we track here. // one Tensor contains multiple samples, so we track here.
int64_t num_buffered_frames = 0; int64_t num_buffered_frames = 0;
Converter converter;
public: public:
ChunkedBuffer( ChunkedBuffer(AVRational time_base, int frames_per_chunk, int num_chunks);
AVRational time_base,
int frames_per_chunk,
int num_chunks,
Converter&& converter);
bool is_ready() const override; bool is_ready() const;
void flush() override; void flush();
c10::optional<Chunk> pop_chunk() override; c10::optional<Chunk> pop_chunk();
void push_frame(AVFrame* frame_) override; void push_frame(torch::Tensor frame, int64_t pts_);
}; };
std::unique_ptr<Buffer> get_chunked_buffer(
AVRational time_base,
int frames_per_chunk,
int num_chunks,
AVSampleFormat fmt,
int num_channels);
std::unique_ptr<Buffer> get_chunked_buffer(
AVRational time_base,
int frames_per_chunk,
int num_chunks,
AVPixelFormat fmt,
int height,
int width,
const torch::Device& device);
} // namespace torchaudio::io::detail } // namespace torchaudio::io::detail
#include <torchaudio/csrc/ffmpeg/stream_reader/buffer/unchunked_buffer.h> #include <torchaudio/csrc/ffmpeg/stream_reader/buffer/unchunked_buffer.h>
#include <torchaudio/csrc/ffmpeg/stream_reader/conversion.h>
namespace torchaudio { namespace torchaudio::io::detail {
namespace io {
namespace detail {
template <typename Converter> UnchunkedBuffer::UnchunkedBuffer(AVRational time_base) : time_base(time_base){};
UnchunkedBuffer<Converter>::UnchunkedBuffer(
AVRational time_base,
Converter&& converter)
: time_base(time_base), converter(std::move(converter)) {}
template <typename Converter> bool UnchunkedBuffer::is_ready() const {
bool UnchunkedBuffer<Converter>::is_ready() const {
return chunks.size() > 0; return chunks.size() > 0;
} }
template <typename Converter> void UnchunkedBuffer::push_frame(torch::Tensor frame, int64_t pts_) {
void UnchunkedBuffer<Converter>::push_frame(AVFrame* frame) {
if (chunks.size() == 0) { if (chunks.size() == 0) {
pts = double(frame->pts) * time_base.num / time_base.den; pts = double(pts_) * time_base.num / time_base.den;
} }
chunks.push_back(converter.convert(frame)); chunks.push_back(frame);
} }
template <typename Converter> c10::optional<Chunk> UnchunkedBuffer::pop_chunk() {
c10::optional<Chunk> UnchunkedBuffer<Converter>::pop_chunk() {
if (chunks.size() == 0) { if (chunks.size() == 0) {
return {}; return {};
} }
...@@ -36,164 +26,8 @@ c10::optional<Chunk> UnchunkedBuffer<Converter>::pop_chunk() { ...@@ -36,164 +26,8 @@ c10::optional<Chunk> UnchunkedBuffer<Converter>::pop_chunk() {
return {Chunk{frames, pts}}; return {Chunk{frames, pts}};
} }
template <typename Converter> void UnchunkedBuffer::flush() {
void UnchunkedBuffer<Converter>::flush() {
chunks.clear(); chunks.clear();
} }
std::unique_ptr<Buffer> get_unchunked_buffer( } // namespace torchaudio::io::detail
AVRational tb,
AVSampleFormat fmt,
int channels) {
switch (fmt) {
case AV_SAMPLE_FMT_U8: {
using Converter = AudioConverter<torch::kUInt8, false>;
return std::make_unique<UnchunkedBuffer<Converter>>(
tb, Converter{channels});
}
case AV_SAMPLE_FMT_S16: {
using Converter = AudioConverter<torch::kInt16, false>;
return std::make_unique<UnchunkedBuffer<Converter>>(
tb, Converter{channels});
}
case AV_SAMPLE_FMT_S32: {
using Converter = AudioConverter<torch::kInt32, false>;
return std::make_unique<UnchunkedBuffer<Converter>>(
tb, Converter{channels});
}
case AV_SAMPLE_FMT_S64: {
using Converter = AudioConverter<torch::kInt64, false>;
return std::make_unique<UnchunkedBuffer<Converter>>(
tb, Converter{channels});
}
case AV_SAMPLE_FMT_FLT: {
using Converter = AudioConverter<torch::kFloat32, false>;
return std::make_unique<UnchunkedBuffer<Converter>>(
tb, Converter{channels});
}
case AV_SAMPLE_FMT_DBL: {
using Converter = AudioConverter<torch::kFloat64, false>;
return std::make_unique<UnchunkedBuffer<Converter>>(
tb, Converter{channels});
}
case AV_SAMPLE_FMT_U8P: {
using Converter = AudioConverter<torch::kUInt8, true>;
return std::make_unique<UnchunkedBuffer<Converter>>(
tb, Converter{channels});
}
case AV_SAMPLE_FMT_S16P: {
using Converter = AudioConverter<torch::kInt16, true>;
return std::make_unique<UnchunkedBuffer<Converter>>(
tb, Converter{channels});
}
case AV_SAMPLE_FMT_S32P: {
using Converter = AudioConverter<torch::kInt32, true>;
return std::make_unique<UnchunkedBuffer<Converter>>(
tb, Converter{channels});
}
case AV_SAMPLE_FMT_S64P: {
using Converter = AudioConverter<torch::kInt64, true>;
return std::make_unique<UnchunkedBuffer<Converter>>(
tb, Converter{channels});
}
case AV_SAMPLE_FMT_FLTP: {
using Converter = AudioConverter<torch::kFloat32, true>;
return std::make_unique<UnchunkedBuffer<Converter>>(
tb, Converter{channels});
}
case AV_SAMPLE_FMT_DBLP: {
using Converter = AudioConverter<torch::kFloat64, true>;
return std::make_unique<UnchunkedBuffer<Converter>>(
tb, Converter{channels});
}
default:
TORCH_INTERNAL_ASSERT(
false, "Unexpected audio type:", av_get_sample_fmt_name(fmt));
}
}
// Instantiate an UnchunkedBuffer with the image converter matching the
// given pixel format and device. `h`/`w` are the frame height/width passed
// to the converter; the trailing integer argument of the CPU converters is
// the number of channels implied by the pixel format.
std::unique_ptr<Buffer> get_unchunked_buffer(
    AVRational tb,
    AVPixelFormat fmt,
    int h,
    int w,
    const torch::Device& device) {
  // CUDA device: HW frames need the dedicated CUDA converters.
  if (device.type() == at::DeviceType::CUDA) {
#ifndef USE_CUDA
    // A CUDA device should have been rejected upstream in builds without
    // CUDA support; this check only fires in debug builds.
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
        false,
        "USE_CUDA is not defined, and it should be guarded before here.");
#else
    switch (fmt) {
      case AV_PIX_FMT_NV12: {
        using Conv = NV12CudaConverter;
        return std::make_unique<UnchunkedBuffer<Conv>>(tb, Conv{h, w, device});
      }
      case AV_PIX_FMT_P010: {
        using Conv = P010CudaConverter;
        return std::make_unique<UnchunkedBuffer<Conv>>(tb, Conv{h, w, device});
      }
      case AV_PIX_FMT_P016: {
        // Recognized format, but no converter is implemented for it.
        TORCH_CHECK(
            false,
            "Unsupported video format found in CUDA HW: ",
            av_get_pix_fmt_name(fmt));
      }
      default: {
        TORCH_CHECK(
            false,
            "Unexpected video format found in CUDA HW: ",
            av_get_pix_fmt_name(fmt));
      }
    }
#endif
  }
  // CPU pixel formats.
  switch (fmt) {
    // 3-channel interlaced formats
    case AV_PIX_FMT_RGB24:
    case AV_PIX_FMT_BGR24: {
      using Converter = InterlacedImageConverter;
      return std::make_unique<UnchunkedBuffer<Converter>>(
          tb, Converter{h, w, 3});
    }
    // 4-channel interlaced formats
    case AV_PIX_FMT_ARGB:
    case AV_PIX_FMT_RGBA:
    case AV_PIX_FMT_ABGR:
    case AV_PIX_FMT_BGRA: {
      using Converter = InterlacedImageConverter;
      return std::make_unique<UnchunkedBuffer<Converter>>(
          tb, Converter{h, w, 4});
    }
    // single-channel
    case AV_PIX_FMT_GRAY8: {
      using Converter = InterlacedImageConverter;
      return std::make_unique<UnchunkedBuffer<Converter>>(
          tb, Converter{h, w, 1});
    }
    // 16-bit-per-component RGB
    case AV_PIX_FMT_RGB48LE: {
      using Converter = Interlaced16BitImageConverter;
      return std::make_unique<UnchunkedBuffer<Converter>>(
          tb, Converter{h, w, 3});
    }
    case AV_PIX_FMT_YUV444P: {
      using Converter = PlanarImageConverter;
      return std::make_unique<UnchunkedBuffer<Converter>>(
          tb, Converter{h, w, 3});
    }
    case AV_PIX_FMT_YUV420P: {
      using Converter = YUV420PConverter;
      return std::make_unique<UnchunkedBuffer<Converter>>(tb, Converter{h, w});
    }
    case AV_PIX_FMT_NV12: {
      using Converter = NV12Converter;
      return std::make_unique<UnchunkedBuffer<Converter>>(tb, Converter{h, w});
    }
    default: {
      TORCH_INTERNAL_ASSERT(
          false, "Unexpected video format found: ", av_get_pix_fmt_name(fmt));
    }
  }
}
} // namespace detail
} // namespace io
} // namespace torchaudio
#pragma once #pragma once
#include <torch/torch.h> #include <torch/torch.h>
#include <torchaudio/csrc/ffmpeg/ffmpeg.h> #include <torchaudio/csrc/ffmpeg/ffmpeg.h>
#include <torchaudio/csrc/ffmpeg/stream_reader/buffer.h> #include <torchaudio/csrc/ffmpeg/stream_reader/typedefs.h>
#include <deque> #include <deque>
namespace torchaudio::io::detail { namespace torchaudio::io::detail {
////////////////////////////////////////////////////////////////////////////// class UnchunkedBuffer {
// Unchunked Buffer Interface
//////////////////////////////////////////////////////////////////////////////
// Partial implementation for unchunked buffer common to both audio and video
// Used for buffering audio/video streams without chunking
template <typename Converter>
class UnchunkedBuffer : public Buffer {
// Each AVFrame is converted to a Tensor and stored here. // Each AVFrame is converted to a Tensor and stored here.
std::deque<torch::Tensor> chunks; std::deque<torch::Tensor> chunks;
double pts = -1.; double pts = -1.;
AVRational time_base; AVRational time_base;
Converter converter;
public: public:
UnchunkedBuffer(AVRational time_base, Converter&& converter); UnchunkedBuffer(AVRational time_base);
bool is_ready() const override; bool is_ready() const;
void push_frame(AVFrame* frame) override; void push_frame(torch::Tensor frame, int64_t pts_);
c10::optional<Chunk> pop_chunk() override; c10::optional<Chunk> pop_chunk();
void flush() override; void flush();
}; };
std::unique_ptr<Buffer> get_unchunked_buffer(
AVRational time_base,
AVSampleFormat fmt,
int num_channels);
std::unique_ptr<Buffer> get_unchunked_buffer(
AVRational time_base,
AVPixelFormat fmt,
int height,
int width,
const torch::Device& device);
} // namespace torchaudio::io::detail } // namespace torchaudio::io::detail
#include <torchaudio/csrc/ffmpeg/stream_reader/buffer/chunked_buffer.h>
#include <torchaudio/csrc/ffmpeg/stream_reader/buffer/unchunked_buffer.h>
#include <torchaudio/csrc/ffmpeg/stream_reader/conversion.h>
#include <torchaudio/csrc/ffmpeg/stream_reader/post_process.h>
namespace torchaudio::io {
namespace detail {
namespace {
///////////////////////////////////////////////////////////////////////////////
// FilterGraphWrapper (FilterGraph + reset feature)
///////////////////////////////////////////////////////////////////////////////
using FilterGraphFactory = std::function<FilterGraph(const std::string&)>;
// Build a factory that constructs an audio FilterGraph for the given input
// stream. The codec context is only read here: the relevant fields are
// captured by value, so the returned closure stays valid after the codec
// context is gone.
FilterGraphFactory get_audio_factory(
    AVRational time_base,
    AVCodecContext* codec_ctx) {
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(codec_ctx->codec_type == AVMEDIA_TYPE_AUDIO);
  return [sample_fmt = codec_ctx->sample_fmt,
          time_base,
          sample_rate = codec_ctx->sample_rate,
          layout = codec_ctx->channel_layout](
             const std::string& filter_desc) -> FilterGraph {
    FilterGraph graph{AVMEDIA_TYPE_AUDIO};
    graph.add_audio_src(sample_fmt, time_base, sample_rate, layout);
    graph.add_sink();
    graph.add_process(filter_desc);
    graph.create_filter();
    return graph;
  };
}
// Build a factory that constructs a video FilterGraph for the given input
// stream. The codec context is only read here: the relevant fields are
// captured by value, so the returned closure stays valid after the codec
// context is gone.
FilterGraphFactory get_video_factory(
    AVRational time_base,
    AVRational frame_rate,
    AVCodecContext* codec_ctx) {
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(codec_ctx->codec_type == AVMEDIA_TYPE_VIDEO);
  return [pix_fmt = codec_ctx->pix_fmt,
          time_base,
          frame_rate,
          width = codec_ctx->width,
          height = codec_ctx->height,
          aspect = codec_ctx->sample_aspect_ratio,
          hw_frames_ctx = codec_ctx->hw_frames_ctx](
             const std::string& filter_desc) -> FilterGraph {
    FilterGraph graph{AVMEDIA_TYPE_VIDEO};
    graph.add_video_src(pix_fmt, time_base, frame_rate, width, height, aspect);
    graph.add_sink();
    graph.add_process(filter_desc);
    // HW-accelerated streams carry a hardware frames context; it is
    // ref-counted, hence the av_buffer_ref when handing it to the graph.
    if (hw_frames_ctx) {
      graph.create_filter(av_buffer_ref(hw_frames_ctx));
    } else {
      graph.create_filter();
    }
    return graph;
  };
}
// Bundles a FilterGraph with the recipe (factory + filter description) used
// to build it, so the graph can be rebuilt from scratch via reset().
//
// NOTE: member-declaration order matters. `filter` is initialized from
// `factory`, which is initialized from the constructor arguments, so
// `factory` must be declared before `filter`.
struct FilterGraphWrapper {
  const std::string desc;

 private:
  FilterGraphFactory factory;

 public:
  FilterGraph filter;

  // Constructor for audio input
  FilterGraphWrapper(
      AVRational input_time_base,
      AVCodecContext* codec_ctx,
      const std::string& desc)
      : desc(desc),
        factory(get_audio_factory(input_time_base, codec_ctx)),
        filter(factory(desc)) {}

  // Constructor for video input
  FilterGraphWrapper(
      AVRational input_time_base,
      AVRational frame_rate,
      AVCodecContext* codec_ctx,
      const std::string& desc)
      : desc(desc),
        factory(get_video_factory(input_time_base, frame_rate, codec_ctx)),
        filter(factory(desc)) {}

  // Discard the current graph and build a fresh one from the same recipe.
  void reset() {
    filter = factory(desc);
  }
};
///////////////////////////////////////////////////////////////////////////////
// ProcessImpl
///////////////////////////////////////////////////////////////////////////////
// Concrete post-decode process: FilterGraph -> Converter -> Buffer.
// Converter and Buffer are template parameters, so all format-dependent
// branching is resolved when the object is constructed (functionoid
// pattern); per-frame processing involves no branching on media format.
template <typename Converter, typename Buffer>
struct ProcessImpl : public IPostDecodeProcess {
 private:
  // Scratch frame reused for every output pulled from the filter graph.
  AVFramePtr frame{};
  FilterGraphWrapper filter_wrapper;

 public:
  Converter converter;
  Buffer buffer;

  ProcessImpl(
      FilterGraphWrapper&& filter_wrapper,
      Converter&& converter,
      Buffer&& buffer)
      : filter_wrapper(std::move(filter_wrapper)),
        converter(std::move(converter)),
        buffer(std::move(buffer)) {}

  bool is_buffer_ready() const override {
    return buffer.is_ready();
  }

  const std::string& get_filter_desc() const override {
    return filter_wrapper.desc;
  };

  FilterGraphOutputInfo get_filter_output_info() const override {
    return filter_wrapper.filter.get_output_info();
  };

  // Rebuild the filter graph and drop all buffered frames.
  void flush() override {
    filter_wrapper.reset();
    buffer.flush();
  }

  // Feed one decoded frame into the filter graph, then drain every frame
  // the graph produces, converting each to Tensor and buffering it.
  // Returns 0 on success (including "needs more input" / end of stream),
  // otherwise a negative FFmpeg error code.
  int process_frame(AVFrame* in_frame) override {
    int ret = filter_wrapper.filter.add_frame(in_frame);
    while (ret >= 0) {
      ret = filter_wrapper.filter.get_frame(frame);
      // AVERROR(EAGAIN) means that new input data is required to return new
      // output; AVERROR_EOF means the graph has been fully drained.
      if (ret == AVERROR(EAGAIN) || ret == AVERROR_EOF) {
        return 0;
      }
      if (ret >= 0) {
        buffer.push_frame(converter.convert(frame), frame->pts);
      }
      // Release the references held by the scratch frame before reuse.
      av_frame_unref(frame);
    }
    return ret;
  }

  c10::optional<Chunk> pop_chunk() override {
    return buffer.pop_chunk();
  }
};
///////////////////////////////////////////////////////////////////////////////
// Audio
///////////////////////////////////////////////////////////////////////////////
// Create an unchunked post-decode process for an audio stream. The sample
// format reported by the filter graph selects the converter type; the
// buffer is always an UnchunkedBuffer.
std::unique_ptr<IPostDecodeProcess> get_unchunked_audio_process(
    FilterGraphWrapper&& filter) {
  auto i = filter.filter.get_output_info();
  // Debug-only, consistent with the other process factories in this file:
  // the media type is validated upstream, so this is a sanity check.
  // (Previously used TORCH_INTERNAL_ASSERT, unlike its three siblings.)
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
      i.type == AVMEDIA_TYPE_AUDIO,
      "Unsupported media type found: ",
      av_get_media_type_string(i.type));
  using B = UnchunkedBuffer;
  // Helper: assemble the process for a concrete converter instance. The
  // switch below only selects the converter type.
  auto assemble = [&](auto conv) -> std::unique_ptr<IPostDecodeProcess> {
    using C = decltype(conv);
    return std::make_unique<ProcessImpl<C, B>>(
        std::move(filter), std::move(conv), B{i.time_base});
  };
  switch (auto fmt = (AVSampleFormat)i.format; fmt) {
    // Interleaved sample formats
    case AV_SAMPLE_FMT_U8:
      return assemble(AudioConverter<torch::kUInt8, false>{i.num_channels});
    case AV_SAMPLE_FMT_S16:
      return assemble(AudioConverter<torch::kInt16, false>{i.num_channels});
    case AV_SAMPLE_FMT_S32:
      return assemble(AudioConverter<torch::kInt32, false>{i.num_channels});
    case AV_SAMPLE_FMT_S64:
      return assemble(AudioConverter<torch::kInt64, false>{i.num_channels});
    case AV_SAMPLE_FMT_FLT:
      return assemble(AudioConverter<torch::kFloat32, false>{i.num_channels});
    case AV_SAMPLE_FMT_DBL:
      return assemble(AudioConverter<torch::kFloat64, false>{i.num_channels});
    // Planar sample formats
    case AV_SAMPLE_FMT_U8P:
      return assemble(AudioConverter<torch::kUInt8, true>{i.num_channels});
    case AV_SAMPLE_FMT_S16P:
      return assemble(AudioConverter<torch::kInt16, true>{i.num_channels});
    case AV_SAMPLE_FMT_S32P:
      return assemble(AudioConverter<torch::kInt32, true>{i.num_channels});
    case AV_SAMPLE_FMT_S64P:
      return assemble(AudioConverter<torch::kInt64, true>{i.num_channels});
    case AV_SAMPLE_FMT_FLTP:
      return assemble(AudioConverter<torch::kFloat32, true>{i.num_channels});
    case AV_SAMPLE_FMT_DBLP:
      return assemble(AudioConverter<torch::kFloat64, true>{i.num_channels});
    default:
      TORCH_INTERNAL_ASSERT(
          false, "Unexpected audio type:", av_get_sample_fmt_name(fmt));
  }
}
// Create a chunked post-decode process for an audio stream. The sample
// format reported by the filter graph selects the converter type; the
// buffer is a ChunkedBuffer sized by frames_per_chunk / num_chunks.
std::unique_ptr<IPostDecodeProcess> get_chunked_audio_process(
    FilterGraphWrapper&& filter,
    int frames_per_chunk,
    int num_chunks) {
  auto info = filter.filter.get_output_info();
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
      info.type == AVMEDIA_TYPE_AUDIO,
      "Unsupported media type found: ",
      av_get_media_type_string(info.type));
  using B = ChunkedBuffer;
  B buffer{info.time_base, frames_per_chunk, num_chunks};
  // Helper: assemble the process for a concrete converter instance. Only
  // one case of the switch runs, so the single buffer is moved exactly once.
  auto assemble = [&](auto conv) -> std::unique_ptr<IPostDecodeProcess> {
    using C = decltype(conv);
    return std::make_unique<ProcessImpl<C, B>>(
        std::move(filter), std::move(conv), std::move(buffer));
  };
  switch (auto fmt = (AVSampleFormat)info.format; fmt) {
    // Interleaved sample formats
    case AV_SAMPLE_FMT_U8:
      return assemble(AudioConverter<torch::kUInt8, false>{info.num_channels});
    case AV_SAMPLE_FMT_S16:
      return assemble(AudioConverter<torch::kInt16, false>{info.num_channels});
    case AV_SAMPLE_FMT_S32:
      return assemble(AudioConverter<torch::kInt32, false>{info.num_channels});
    case AV_SAMPLE_FMT_S64:
      return assemble(AudioConverter<torch::kInt64, false>{info.num_channels});
    case AV_SAMPLE_FMT_FLT:
      return assemble(
          AudioConverter<torch::kFloat32, false>{info.num_channels});
    case AV_SAMPLE_FMT_DBL:
      return assemble(
          AudioConverter<torch::kFloat64, false>{info.num_channels});
    // Planar sample formats
    case AV_SAMPLE_FMT_U8P:
      return assemble(AudioConverter<torch::kUInt8, true>{info.num_channels});
    case AV_SAMPLE_FMT_S16P:
      return assemble(AudioConverter<torch::kInt16, true>{info.num_channels});
    case AV_SAMPLE_FMT_S32P:
      return assemble(AudioConverter<torch::kInt32, true>{info.num_channels});
    case AV_SAMPLE_FMT_S64P:
      return assemble(AudioConverter<torch::kInt64, true>{info.num_channels});
    case AV_SAMPLE_FMT_FLTP:
      return assemble(AudioConverter<torch::kFloat32, true>{info.num_channels});
    case AV_SAMPLE_FMT_DBLP:
      return assemble(AudioConverter<torch::kFloat64, true>{info.num_channels});
    default:
      TORCH_INTERNAL_ASSERT(
          false, "Unexpected audio type:", av_get_sample_fmt_name(fmt));
  }
}
///////////////////////////////////////////////////////////////////////////////
// Video
///////////////////////////////////////////////////////////////////////////////
// Create an unchunked post-decode process for a CPU video stream. The
// pixel format reported by the filter graph selects the converter; the
// trailing integer of the interlaced/planar converters is the channel
// count implied by the format.
std::unique_ptr<IPostDecodeProcess> get_unchunked_video_process(
    FilterGraphWrapper&& filter) {
  auto info = filter.filter.get_output_info();
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
      info.type == AVMEDIA_TYPE_VIDEO,
      "Unsupported media type found: ",
      av_get_media_type_string(info.type));
  auto height = info.height;
  auto width = info.width;
  using B = UnchunkedBuffer;
  // Helper: assemble the process for a concrete converter instance.
  auto assemble = [&](auto conv) -> std::unique_ptr<IPostDecodeProcess> {
    using C = decltype(conv);
    return std::make_unique<ProcessImpl<C, B>>(
        std::move(filter), std::move(conv), B{info.time_base});
  };
  switch (auto fmt = (AVPixelFormat)info.format; fmt) {
    // 3-channel interlaced formats
    case AV_PIX_FMT_RGB24:
    case AV_PIX_FMT_BGR24:
      return assemble(InterlacedImageConverter{height, width, 3});
    // 4-channel interlaced formats
    case AV_PIX_FMT_ARGB:
    case AV_PIX_FMT_RGBA:
    case AV_PIX_FMT_ABGR:
    case AV_PIX_FMT_BGRA:
      return assemble(InterlacedImageConverter{height, width, 4});
    // single-channel
    case AV_PIX_FMT_GRAY8:
      return assemble(InterlacedImageConverter{height, width, 1});
    // 16-bit-per-component RGB
    case AV_PIX_FMT_RGB48LE:
      return assemble(Interlaced16BitImageConverter{height, width, 3});
    case AV_PIX_FMT_YUV444P:
      return assemble(PlanarImageConverter{height, width, 3});
    case AV_PIX_FMT_YUV420P:
      return assemble(YUV420PConverter{height, width});
    case AV_PIX_FMT_NV12:
      return assemble(NV12Converter{height, width});
    default: {
      TORCH_INTERNAL_ASSERT(
          false, "Unexpected video format found: ", av_get_pix_fmt_name(fmt));
    }
  }
}
// Create an unchunked post-decode process for a CUDA-decoded video stream.
// Only usable when the extension is built with USE_CUDA.
std::unique_ptr<IPostDecodeProcess> get_unchunked_cuda_video_process(
    FilterGraphWrapper&& filter,
    const torch::Device& device) {
#ifndef USE_CUDA
  // Unlike the debug-only checks elsewhere, this fires in release builds
  // too: requesting a CUDA process without CUDA support is unrecoverable.
  TORCH_INTERNAL_ASSERT(
      false,
      "USE_CUDA is not defined, but CUDA decoding process was requested.");
#else
  auto i = filter.filter.get_output_info();
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
      i.type == AVMEDIA_TYPE_VIDEO,
      "Unsupported media type found: ",
      av_get_media_type_string(i.type));
  using B = UnchunkedBuffer;
  switch (auto fmt = (AVPixelFormat)i.format; fmt) {
    case AV_PIX_FMT_NV12: {
      using C = NV12CudaConverter;
      return std::make_unique<ProcessImpl<C, B>>(
          std::move(filter), C{i.height, i.width, device}, B{i.time_base});
    }
    case AV_PIX_FMT_P010: {
      using C = P010CudaConverter;
      return std::make_unique<ProcessImpl<C, B>>(
          std::move(filter), C{i.height, i.width, device}, B{i.time_base});
    }
    case AV_PIX_FMT_P016: {
      // Recognized format, but no converter is implemented for it.
      TORCH_CHECK(
          false,
          "Unsupported video format found in CUDA HW: ",
          av_get_pix_fmt_name(fmt));
    }
    default: {
      TORCH_CHECK(
          false,
          "Unexpected video format found in CUDA HW: ",
          av_get_pix_fmt_name(fmt));
    }
  }
#endif
}
// Build the post-decode pipeline for CPU video with chunked buffering:
// filtered frames are converted to Tensor by a pixel-format-specific
// converter and grouped into chunks of `frames_per_chunk` frames.
std::unique_ptr<IPostDecodeProcess> get_chunked_video_process(
    FilterGraphWrapper&& filter,
    int frames_per_chunk,
    int num_chunks) {
  const auto info = filter.filter.get_output_info();
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
      info.type == AVMEDIA_TYPE_VIDEO,
      "Unsupported media type found: ",
      av_get_media_type_string(info.type));
  const auto h = info.height;
  const auto w = info.width;
  // Bundle a converter with a chunked buffer into a concrete ProcessImpl;
  // the converter type is deduced from the argument so that every pixel
  // format below stays a one-liner.
  auto wrap = [&](auto converter) -> std::unique_ptr<IPostDecodeProcess> {
    return std::make_unique<ProcessImpl<decltype(converter), ChunkedBuffer>>(
        std::move(filter),
        std::move(converter),
        ChunkedBuffer{info.time_base, frames_per_chunk, num_chunks});
  };
  switch (const auto fmt = static_cast<AVPixelFormat>(info.format); fmt) {
    case AV_PIX_FMT_RGB24:
    case AV_PIX_FMT_BGR24:
      return wrap(InterlacedImageConverter{h, w, 3});
    case AV_PIX_FMT_ARGB:
    case AV_PIX_FMT_RGBA:
    case AV_PIX_FMT_ABGR:
    case AV_PIX_FMT_BGRA:
      return wrap(InterlacedImageConverter{h, w, 4});
    case AV_PIX_FMT_GRAY8:
      return wrap(InterlacedImageConverter{h, w, 1});
    case AV_PIX_FMT_RGB48LE:
      return wrap(Interlaced16BitImageConverter{h, w, 3});
    case AV_PIX_FMT_YUV444P:
      return wrap(PlanarImageConverter{h, w, 3});
    case AV_PIX_FMT_YUV420P:
      return wrap(YUV420PConverter{h, w});
    case AV_PIX_FMT_NV12:
      return wrap(NV12Converter{h, w});
    default:
      TORCH_INTERNAL_ASSERT(
          false, "Unexpected video format found: ", av_get_pix_fmt_name(fmt));
  }
}
// Build the post-decode pipeline for CUDA-decoded video with chunked
// buffering: frames are converted to Tensor on `device` and grouped into
// chunks of `frames_per_chunk` frames.
std::unique_ptr<IPostDecodeProcess> get_chunked_cuda_video_process(
    FilterGraphWrapper&& filter,
    int frames_per_chunk,
    int num_chunks,
    const torch::Device& device) {
#ifndef USE_CUDA
  TORCH_INTERNAL_ASSERT(
      false,
      "USE_CUDA is not defined, but CUDA decoding process was requested.");
#else
  const auto info = filter.filter.get_output_info();
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
      info.type == AVMEDIA_TYPE_VIDEO,
      "Unsupported media type found: ",
      av_get_media_type_string(info.type));
  // Bundle a CUDA converter with a chunked buffer into a concrete
  // ProcessImpl; the converter type is deduced from the argument.
  auto wrap = [&](auto converter) -> std::unique_ptr<IPostDecodeProcess> {
    return std::make_unique<ProcessImpl<decltype(converter), ChunkedBuffer>>(
        std::move(filter),
        std::move(converter),
        ChunkedBuffer{info.time_base, frames_per_chunk, num_chunks});
  };
  const auto fmt = static_cast<AVPixelFormat>(info.format);
  if (fmt == AV_PIX_FMT_NV12) {
    return wrap(NV12CudaConverter{info.height, info.width, device});
  }
  if (fmt == AV_PIX_FMT_P010) {
    return wrap(P010CudaConverter{info.height, info.width, device});
  }
  if (fmt == AV_PIX_FMT_P016) {
    TORCH_CHECK(
        false,
        "Unsupported video format found in CUDA HW: ",
        av_get_pix_fmt_name(fmt));
  }
  TORCH_CHECK(
      false,
      "Unexpected video format found in CUDA HW: ",
      av_get_pix_fmt_name(fmt));
#endif
}
} // namespace
} // namespace detail
// Create the full post-decode process (filtering + Tensor conversion +
// buffering) for an audio stream. `frames_per_chunk == -1` selects the
// unchunked (unbounded) buffer; otherwise frames are grouped into chunks of
// `frames_per_chunk` frames with at most `num_chunks` retained.
std::unique_ptr<IPostDecodeProcess> get_audio_process(
    AVRational input_time_base,
    AVCodecContext* codec_ctx,
    const std::string& desc,
    int frames_per_chunk,
    int num_chunks) {
  TORCH_CHECK(
      frames_per_chunk > 0 || frames_per_chunk == -1,
      "`frames_per_chunk` must be positive or -1. Found: ",
      frames_per_chunk);
  TORCH_CHECK(
      num_chunks > 0 || num_chunks == -1,
      "`num_chunks` must be positive or -1. Found: ",
      num_chunks);
  detail::FilterGraphWrapper filter{input_time_base, codec_ctx, desc};
  return frames_per_chunk == -1
      ? detail::get_unchunked_audio_process(std::move(filter))
      : detail::get_chunked_audio_process(
            std::move(filter), frames_per_chunk, num_chunks);
}
// Create the full post-decode process for a video stream. Dispatches on the
// target device (CPU vs CUDA) and on whether chunked buffering was requested
// (`frames_per_chunk != -1`).
std::unique_ptr<IPostDecodeProcess> get_video_process(
    AVRational input_time_base,
    AVRational frame_rate,
    AVCodecContext* codec_ctx,
    const std::string& desc,
    int frames_per_chunk,
    int num_chunks,
    const torch::Device& device) {
  TORCH_CHECK(
      frames_per_chunk > 0 || frames_per_chunk == -1,
      "`frames_per_chunk` must be positive or -1. Found: ",
      frames_per_chunk);
  TORCH_CHECK(
      num_chunks > 0 || num_chunks == -1,
      "`num_chunks` must be positive or -1. Found: ",
      num_chunks);
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
      device.is_cuda() || device.is_cpu(), "Unexpected device type: ", device);
  detail::FilterGraphWrapper filter{
      input_time_base, frame_rate, codec_ctx, desc};
  const bool chunked = (frames_per_chunk != -1);
  if (device.is_cuda()) {
    return chunked
        ? detail::get_chunked_cuda_video_process(
              std::move(filter), frames_per_chunk, num_chunks, device)
        : detail::get_unchunked_cuda_video_process(std::move(filter), device);
  }
  return chunked
      ? detail::get_chunked_video_process(
            std::move(filter), frames_per_chunk, num_chunks)
      : detail::get_unchunked_video_process(std::move(filter));
}
} // namespace torchaudio::io
#pragma once
#include <torchaudio/csrc/ffmpeg/filter_graph.h>
#include <torchaudio/csrc/ffmpeg/stream_reader/typedefs.h>
namespace torchaudio::io {
// Interface for the per-stream post-decode pipeline: preprocessing with
// FilterGraph, conversion to Tensor, and buffering. Concrete, strongly-typed
// implementations are created via get_audio_process / get_video_process.
struct IPostDecodeProcess {
  virtual ~IPostDecodeProcess() = default;

  // Push one decoded frame through the filter graph and buffer the filtered
  // output. Returns 0 on success, a negative value on error.
  virtual int process_frame(AVFrame* frame) = 0;
  // Retrieve the next buffered chunk; an empty optional when nothing is
  // available.
  virtual c10::optional<Chunk> pop_chunk() = 0;
  // Whether the internal buffer has accumulated enough data to pop a chunk.
  virtual bool is_buffer_ready() const = 0;
  // The filter description string this process was built with.
  virtual const std::string& get_filter_desc() const = 0;
  // Media attributes (format, dimensions, time base, ...) at the output of
  // the internal filter graph.
  virtual FilterGraphOutputInfo get_filter_output_info() const = 0;
  // Reset internal state: discard buffered data and restore the filter graph.
  virtual void flush() = 0;
};
// Create the post-decode process for an audio stream.
// `frames_per_chunk` must be positive or -1 (-1 selects unchunked buffering);
// `num_chunks` must be positive or -1 and bounds how many chunks are kept.
std::unique_ptr<IPostDecodeProcess> get_audio_process(
    AVRational input_time_base,
    AVCodecContext* codec_ctx,
    const std::string& desc,
    int frames_per_chunk,
    int num_chunks);

// Create the post-decode process for a video stream.
// Same chunking parameters as get_audio_process; `device` selects between
// the CPU and CUDA conversion paths.
std::unique_ptr<IPostDecodeProcess> get_video_process(
    AVRational input_time_base,
    AVRational frame_rate,
    AVCodecContext* codec_ctx,
    const std::string& desc,
    int frames_per_chunk,
    int num_chunks,
    const torch::Device& device);
} // namespace torchaudio::io
#include <torchaudio/csrc/ffmpeg/hw_context.h>
#include <torchaudio/csrc/ffmpeg/stream_reader/buffer/chunked_buffer.h>
#include <torchaudio/csrc/ffmpeg/stream_reader/buffer/unchunked_buffer.h>
#include <torchaudio/csrc/ffmpeg/stream_reader/sink.h>
#include <stdexcept>
namespace torchaudio {
namespace io {
namespace {
// Select the Buffer implementation matching the filter graph's output media
// type (audio vs video) and the requested chunking mode.
std::unique_ptr<Buffer> get_buffer(
    FilterGraph& filter,
    int frames_per_chunk,
    int num_chunks,
    const torch::Device& device) {
  TORCH_CHECK(
      frames_per_chunk > 0 || frames_per_chunk == -1,
      "`frames_per_chunk` must be positive or -1. Found: ",
      frames_per_chunk);
  TORCH_CHECK(
      num_chunks > 0 || num_chunks == -1,
      "`num_chunks` must be positive or -1. Found: ",
      num_chunks);
  const auto info = filter.get_output_info();
  TORCH_CHECK(
      info.type == AVMEDIA_TYPE_AUDIO || info.type == AVMEDIA_TYPE_VIDEO,
      "Unsupported media type: ",
      av_get_media_type_string(info.type),
      ". Only video or audio is supported ");
  const bool unchunked = (frames_per_chunk == -1);
  if (info.type == AVMEDIA_TYPE_AUDIO) {
    const auto fmt = static_cast<AVSampleFormat>(info.format);
    return unchunked
        ? detail::get_unchunked_buffer(info.time_base, fmt, info.num_channels)
        : detail::get_chunked_buffer(
              info.time_base,
              frames_per_chunk,
              num_chunks,
              fmt,
              info.num_channels);
  }
  const auto fmt = static_cast<AVPixelFormat>(info.format);
  // CUDA frames are handled by a dedicated path; they must not reach here.
  TORCH_INTERNAL_ASSERT(fmt != AV_PIX_FMT_CUDA);
  return unchunked
      ? detail::get_unchunked_buffer(
            info.time_base, fmt, info.height, info.width, device)
      : detail::get_chunked_buffer(
            info.time_base,
            frames_per_chunk,
            num_chunks,
            fmt,
            info.height,
            info.width,
            device);
}
// Construct and finalize a FilterGraph for the given codec: attach the
// matching source (audio or video), the sink, and the user-provided filter
// description.
FilterGraph get_filter_graph(
    AVRational input_time_base,
    AVCodecContext* codec_ctx,
    AVRational frame_rate,
    const std::string& filter_description) {
  FilterGraph graph{codec_ctx->codec_type};
  switch (codec_ctx->codec_type) {
    case AVMEDIA_TYPE_AUDIO:
      graph.add_audio_src(
          codec_ctx->sample_fmt,
          input_time_base,
          codec_ctx->sample_rate,
          codec_ctx->channel_layout);
      break;
    case AVMEDIA_TYPE_VIDEO:
      graph.add_video_src(
          codec_ctx->pix_fmt,
          input_time_base,
          frame_rate,
          codec_ctx->width,
          codec_ctx->height,
          codec_ctx->sample_aspect_ratio);
      break;
    default:
      TORCH_CHECK(false, "Only audio/video are supported.");
  }
  graph.add_sink();
  graph.add_process(filter_description);
  // Hardware-decoded streams carry an AVHWFramesContext; take a new
  // reference so the filter graph co-owns it with the codec context.
  graph.create_filter(
      codec_ctx->hw_frames_ctx ? av_buffer_ref(codec_ctx->hw_frames_ctx)
                               : nullptr);
  return graph;
}
} // namespace
// Sink ties a FilterGraph to a Buffer: decoded frames are filtered, then
// pushed into the buffer. The construction parameters are stored so that
// flush() can rebuild the filter graph from scratch.
Sink::Sink(
    AVRational input_time_base_,
    AVCodecContext* codec_ctx_,
    int frames_per_chunk,
    int num_chunks,
    AVRational frame_rate_,
    const std::string& filter_desc,
    const torch::Device& device)
    : input_time_base(input_time_base_),
      codec_ctx(codec_ctx_),
      frame_rate(frame_rate_),
      filter_description(filter_desc),
      // NOTE: initialization order follows member declaration order in
      // sink.h; `filter` must be fully constructed before `buffer`, which
      // inspects the filter's output info.
      filter(get_filter_graph(
          input_time_base_,
          codec_ctx,
          frame_rate,
          filter_description)),
      buffer(get_buffer(filter, frames_per_chunk, num_chunks, device)) {}
// 0: some kind of success
// <0: Some error happened
int Sink::process_frame(AVFrame* pFrame) {
  // Feed the input frame to the filter graph, then drain every filtered
  // frame it produces into the buffer.
  if (const int ret = filter.add_frame(pFrame); ret < 0) {
    return ret;
  }
  for (;;) {
    const int ret = filter.get_frame(frame);
    // AVERROR(EAGAIN) means that new input data is required to return new
    // output.
    if (ret == AVERROR(EAGAIN) || ret == AVERROR_EOF) {
      return 0;
    }
    if (ret >= 0) {
      buffer->push_frame(frame);
    }
    av_frame_unref(frame);
    if (ret < 0) {
      return ret;
    }
  }
}
// Reset the sink: rebuild the filter graph from the saved construction
// parameters and discard all buffered data.
void Sink::flush() {
  filter = get_filter_graph(
      input_time_base, codec_ctx, frame_rate, filter_description);
  buffer->flush();
}
} // namespace io
} // namespace torchaudio
#pragma once
#include <torchaudio/csrc/ffmpeg/ffmpeg.h>
#include <torchaudio/csrc/ffmpeg/filter_graph.h>
#include <torchaudio/csrc/ffmpeg/stream_reader/buffer.h>
namespace torchaudio {
namespace io {
// Sink: the post-decode pipeline of one output stream. Owns a FilterGraph
// (preprocessing) and a Buffer (Tensor storage), plus the parameters needed
// to rebuild the FilterGraph on flush().
class Sink {
  // Scratch frame receiving the filter graph's output in process_frame.
  AVFramePtr frame{};

  // Parameters for recreating FilterGraph
  AVRational input_time_base;
  // Non-owning; presumably outlives this Sink (owned by the stream
  // processor) — TODO confirm.
  AVCodecContext* codec_ctx;
  AVRational frame_rate;

 public:
  const std::string filter_description;
  // NOTE: declaration order matters — `filter` is initialized before
  // `buffer` in the constructor, and `buffer` is derived from `filter`.
  FilterGraph filter;
  std::unique_ptr<Buffer> buffer;

  Sink(
      AVRational input_time_base,
      AVCodecContext* codec_ctx,
      int frames_per_chunk,
      int num_chunks,
      AVRational frame_rate,
      const std::string& filter_description,
      const torch::Device& device);

  // Feed one decoded frame through the filter graph and buffer the results.
  // Returns 0 on success, a negative value on error.
  int process_frame(AVFrame* frame);
  // Whether the buffer holds enough data to pop a chunk.
  bool is_buffer_ready() const;
  // Rebuild the filter graph and discard buffered data.
  void flush();
};
} // namespace io
} // namespace torchaudio
...@@ -218,28 +218,36 @@ KeyType StreamProcessor::add_stream( ...@@ -218,28 +218,36 @@ KeyType StreamProcessor::add_stream(
switch (codec_ctx->codec_type) { switch (codec_ctx->codec_type) {
case AVMEDIA_TYPE_AUDIO: case AVMEDIA_TYPE_AUDIO:
case AVMEDIA_TYPE_VIDEO: post_processes.emplace(
break;
default:
TORCH_CHECK(false, "Only Audio and Video are supported");
}
KeyType key = current_key++;
sinks.emplace(
std::piecewise_construct, std::piecewise_construct,
std::forward_as_tuple(key), std::forward_as_tuple(current_key),
std::forward_as_tuple( std::forward_as_tuple(get_audio_process(
stream_time_base, stream_time_base,
codec_ctx, codec_ctx,
filter_description,
frames_per_chunk, frames_per_chunk,
num_chunks, num_chunks)));
return current_key++;
case AVMEDIA_TYPE_VIDEO:
post_processes.emplace(
std::piecewise_construct,
std::forward_as_tuple(current_key),
std::forward_as_tuple(get_video_process(
stream_time_base,
frame_rate, frame_rate,
codec_ctx,
filter_description, filter_description,
device)); frames_per_chunk,
return key; num_chunks,
device)));
return current_key++;
default:
TORCH_CHECK(false, "Only Audio and Video are supported");
}
} }
void StreamProcessor::remove_stream(KeyType key) { void StreamProcessor::remove_stream(KeyType key) {
sinks.erase(key); post_processes.erase(key);
} }
void StreamProcessor::set_discard_timestamp(int64_t timestamp) { void StreamProcessor::set_discard_timestamp(int64_t timestamp) {
...@@ -252,17 +260,17 @@ void StreamProcessor::set_discard_timestamp(int64_t timestamp) { ...@@ -252,17 +260,17 @@ void StreamProcessor::set_discard_timestamp(int64_t timestamp) {
// Query methods // Query methods
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
std::string StreamProcessor::get_filter_description(KeyType key) const { std::string StreamProcessor::get_filter_description(KeyType key) const {
return sinks.at(key).filter_description; return post_processes.at(key)->get_filter_desc();
} }
FilterGraphOutputInfo StreamProcessor::get_filter_output_info( FilterGraphOutputInfo StreamProcessor::get_filter_output_info(
KeyType key) const { KeyType key) const {
return sinks.at(key).filter.get_output_info(); return post_processes.at(key)->get_filter_output_info();
} }
bool StreamProcessor::is_buffer_ready() const { bool StreamProcessor::is_buffer_ready() const {
for (const auto& it : sinks) { for (const auto& it : post_processes) {
if (!it.second.buffer->is_ready()) { if (!it.second->is_buffer_ready()) {
return false; return false;
} }
} }
...@@ -331,8 +339,8 @@ int StreamProcessor::process_packet(AVPacket* packet) { ...@@ -331,8 +339,8 @@ int StreamProcessor::process_packet(AVPacket* packet) {
void StreamProcessor::flush() { void StreamProcessor::flush() {
avcodec_flush_buffers(codec_ctx); avcodec_flush_buffers(codec_ctx);
for (auto& ite : sinks) { for (auto& ite : post_processes) {
ite.second.flush(); ite.second->flush();
} }
} }
...@@ -340,8 +348,8 @@ void StreamProcessor::flush() { ...@@ -340,8 +348,8 @@ void StreamProcessor::flush() {
// <0: Some error happened // <0: Some error happened
int StreamProcessor::send_frame(AVFrame* frame_) { int StreamProcessor::send_frame(AVFrame* frame_) {
int ret = 0; int ret = 0;
for (auto& ite : sinks) { for (auto& ite : post_processes) {
int ret2 = ite.second.process_frame(frame_); int ret2 = ite.second->process_frame(frame_);
if (ret2 < 0) if (ret2 < 0)
ret = ret2; ret = ret2;
} }
...@@ -352,7 +360,7 @@ int StreamProcessor::send_frame(AVFrame* frame_) { ...@@ -352,7 +360,7 @@ int StreamProcessor::send_frame(AVFrame* frame_) {
// Retrieval // Retrieval
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
c10::optional<Chunk> StreamProcessor::pop_chunk(KeyType key) { c10::optional<Chunk> StreamProcessor::pop_chunk(KeyType key) {
return sinks.at(key).buffer->pop_chunk(); return post_processes.at(key)->pop_chunk();
} }
} // namespace io } // namespace io
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
#include <torch/torch.h> #include <torch/torch.h>
#include <torchaudio/csrc/ffmpeg/ffmpeg.h> #include <torchaudio/csrc/ffmpeg/ffmpeg.h>
#include <torchaudio/csrc/ffmpeg/stream_reader/sink.h> #include <torchaudio/csrc/ffmpeg/stream_reader/post_process.h>
#include <torchaudio/csrc/ffmpeg/stream_reader/typedefs.h> #include <torchaudio/csrc/ffmpeg/stream_reader/typedefs.h>
#include <map> #include <map>
...@@ -22,7 +22,7 @@ class StreamProcessor { ...@@ -22,7 +22,7 @@ class StreamProcessor {
AVFramePtr frame; AVFramePtr frame;
KeyType current_key = 0; KeyType current_key = 0;
std::map<KeyType, Sink> sinks; std::map<KeyType, std::unique_ptr<IPostDecodeProcess>> post_processes;
// Used for precise seek. // Used for precise seek.
// 0: no discard // 0: no discard
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment