Commit c17226a0 authored by Moto Hira's avatar Moto Hira Committed by Facebook GitHub Bot
Browse files

Refactor StreamReader internals (#3184)

Summary:
Pull Request resolved: https://github.com/pytorch/audio/pull/3184

Tweak internals of StreamReader
1. Pass time_base to Buffer class so that
    * no need to pass frame_duration separately
    * Conversion of PTS to double type can be delayed until the chunk is popped
2. Merge `get_output_timebase` method into `get_output_stream_info`.
3. If filter description is not provided, fill in null filter at top-level StreamReader
4. Expose filter and filter description from Sink class to get rid of wrapper get methods.

Reviewed By: nateanl

Differential Revision: D44207976

fbshipit-source-id: f25ac9be69c9897e9dcec0c6e978f29b83b166e8
parent 9533d300
...@@ -177,33 +177,35 @@ void FilterGraph::create_filter(AVBufferRef* hw_frames_ctx) { ...@@ -177,33 +177,35 @@ void FilterGraph::create_filter(AVBufferRef* hw_frames_ctx) {
////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////
// Query methods // Query methods
////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////
AVRational FilterGraph::get_output_timebase() const {
TORCH_INTERNAL_ASSERT(buffersink_ctx, "FilterGraph is not initialized.");
return buffersink_ctx->inputs[0]->time_base;
}
FilterGraphOutputInfo FilterGraph::get_output_info() const { FilterGraphOutputInfo FilterGraph::get_output_info() const {
TORCH_INTERNAL_ASSERT(buffersink_ctx, "FilterGraph is not initialized."); TORCH_INTERNAL_ASSERT(buffersink_ctx, "FilterGraph is not initialized.");
AVFilterLink* l = buffersink_ctx->inputs[0]; AVFilterLink* l = buffersink_ctx->inputs[0];
FilterGraphOutputInfo ret{}; FilterGraphOutputInfo ret{};
ret.type = l->type; ret.type = l->type;
ret.format = l->format; ret.format = l->format;
if (l->type == AVMEDIA_TYPE_AUDIO) { ret.time_base = l->time_base;
ret.sample_rate = l->sample_rate; switch (l->type) {
case AVMEDIA_TYPE_AUDIO: {
ret.sample_rate = l->sample_rate;
#if LIBAVFILTER_VERSION_MAJOR >= 8 && LIBAVFILTER_VERSION_MINOR >= 44 #if LIBAVFILTER_VERSION_MAJOR >= 8 && LIBAVFILTER_VERSION_MINOR >= 44
ret.num_channels = l->ch_layout.nb_channels; ret.num_channels = l->ch_layout.nb_channels;
#else #else
// Before FFmpeg 5.1 // Before FFmpeg 5.1
ret.num_channels = av_get_channel_layout_nb_channels(l->channel_layout); ret.num_channels = av_get_channel_layout_nb_channels(l->channel_layout);
#endif #endif
} else { break;
if (l->format == AV_PIX_FMT_CUDA && l->hw_frames_ctx) { }
auto frames_ctx = (AVHWFramesContext*)(l->hw_frames_ctx->data); case AVMEDIA_TYPE_VIDEO: {
ret.format = frames_ctx->sw_format; if (l->format == AV_PIX_FMT_CUDA && l->hw_frames_ctx) {
auto frames_ctx = (AVHWFramesContext*)(l->hw_frames_ctx->data);
ret.format = frames_ctx->sw_format;
}
ret.frame_rate = l->frame_rate;
ret.height = l->h;
ret.width = l->w;
break;
} }
ret.frame_rate = l->frame_rate; default:;
ret.height = l->h;
ret.width = l->w;
} }
return ret; return ret;
} }
......
...@@ -9,6 +9,8 @@ struct FilterGraphOutputInfo { ...@@ -9,6 +9,8 @@ struct FilterGraphOutputInfo {
AVMediaType type = AVMEDIA_TYPE_UNKNOWN; AVMediaType type = AVMEDIA_TYPE_UNKNOWN;
int format = -1; int format = -1;
AVRational time_base = {1, 1};
// Audio // Audio
int sample_rate = -1; int sample_rate = -1;
int num_channels = -1; int num_channels = -1;
...@@ -68,7 +70,6 @@ class FilterGraph { ...@@ -68,7 +70,6 @@ class FilterGraph {
////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////
// Query methods // Query methods
////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////
[[nodiscard]] AVRational get_output_timebase() const;
[[nodiscard]] FilterGraphOutputInfo get_output_info() const; [[nodiscard]] FilterGraphOutputInfo get_output_info() const;
////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////
......
...@@ -22,7 +22,7 @@ class Buffer { ...@@ -22,7 +22,7 @@ class Buffer {
////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////
// Modifiers // Modifiers
////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////
virtual void push_frame(AVFrame* frame, double pts) = 0; virtual void push_frame(AVFrame* frame) = 0;
virtual c10::optional<Chunk> pop_chunk() = 0; virtual c10::optional<Chunk> pop_chunk() = 0;
......
...@@ -5,11 +5,11 @@ namespace torchaudio::io::detail { ...@@ -5,11 +5,11 @@ namespace torchaudio::io::detail {
template <typename Converter> template <typename Converter>
ChunkedBuffer<Converter>::ChunkedBuffer( ChunkedBuffer<Converter>::ChunkedBuffer(
AVRational time_base,
int frames_per_chunk_, int frames_per_chunk_,
int num_chunks_, int num_chunks_,
double frame_duration_,
Converter&& converter_) Converter&& converter_)
: frame_duration(frame_duration_), : time_base(time_base),
frames_per_chunk(frames_per_chunk_), frames_per_chunk(frames_per_chunk_),
num_chunks(num_chunks_), num_chunks(num_chunks_),
converter(std::move(converter_)){}; converter(std::move(converter_)){};
...@@ -20,7 +20,8 @@ bool ChunkedBuffer<Converter>::is_ready() const { ...@@ -20,7 +20,8 @@ bool ChunkedBuffer<Converter>::is_ready() const {
} }
template <typename Converter> template <typename Converter>
void ChunkedBuffer<Converter>::push_frame(AVFrame* frame_, double pts_) { void ChunkedBuffer<Converter>::push_frame(AVFrame* frame_) {
int64_t pts_ = frame_->pts;
torch::Tensor frame = converter.convert(frame_); torch::Tensor frame = converter.convert(frame_);
using namespace torch::indexing; using namespace torch::indexing;
...@@ -61,7 +62,7 @@ void ChunkedBuffer<Converter>::push_frame(AVFrame* frame_, double pts_) { ...@@ -61,7 +62,7 @@ void ChunkedBuffer<Converter>::push_frame(AVFrame* frame_, double pts_) {
num_buffered_frames += append; num_buffered_frames += append;
// frame = frame[append:] // frame = frame[append:]
frame = frame.index({Slice(append)}); frame = frame.index({Slice(append)});
pts_ += double(append) * frame_duration; pts_ += append;
} }
// 2. Return if the number of input frames are smaller than the empty buffer. // 2. Return if the number of input frames are smaller than the empty buffer.
...@@ -85,7 +86,7 @@ void ChunkedBuffer<Converter>::push_frame(AVFrame* frame_, double pts_) { ...@@ -85,7 +86,7 @@ void ChunkedBuffer<Converter>::push_frame(AVFrame* frame_, double pts_) {
int64_t start = i * frames_per_chunk; int64_t start = i * frames_per_chunk;
// chunk = frame[i*frames_per_chunk:(i+1) * frames_per_chunk] // chunk = frame[i*frames_per_chunk:(i+1) * frames_per_chunk]
auto chunk = frame.index({Slice(start, start + frames_per_chunk)}); auto chunk = frame.index({Slice(start, start + frames_per_chunk)});
double pts_val = pts_ + double(start) * frame_duration; int64_t pts_val = pts_ + start;
int64_t chunk_size = chunk.size(0); int64_t chunk_size = chunk.size(0);
TORCH_INTERNAL_ASSERT( TORCH_INTERNAL_ASSERT(
chunk_size <= frames_per_chunk, chunk_size <= frames_per_chunk,
...@@ -120,7 +121,7 @@ c10::optional<Chunk> ChunkedBuffer<Converter>::pop_chunk() { ...@@ -120,7 +121,7 @@ c10::optional<Chunk> ChunkedBuffer<Converter>::pop_chunk() {
return {}; return {};
} }
torch::Tensor chunk = chunks.front(); torch::Tensor chunk = chunks.front();
double pts_val = pts.front(); double pts_val = double(pts.front()) * time_base.num / time_base.den;
chunks.pop_front(); chunks.pop_front();
pts.pop_front(); pts.pop_front();
if (num_buffered_frames < frames_per_chunk) { if (num_buffered_frames < frames_per_chunk) {
...@@ -137,71 +138,71 @@ void ChunkedBuffer<Converter>::flush() { ...@@ -137,71 +138,71 @@ void ChunkedBuffer<Converter>::flush() {
} }
std::unique_ptr<Buffer> get_chunked_buffer( std::unique_ptr<Buffer> get_chunked_buffer(
int frames_per_chunk, AVRational tb,
int fpc,
int num_chunks, int num_chunks,
double frame_duration,
AVSampleFormat fmt, AVSampleFormat fmt,
int channels) { int channels) {
switch (fmt) { switch (fmt) {
case AV_SAMPLE_FMT_U8: { case AV_SAMPLE_FMT_U8: {
using Converter = AudioConverter<torch::kUInt8, false>; using Converter = AudioConverter<torch::kUInt8, false>;
return std::make_unique<ChunkedBuffer<Converter>>( return std::make_unique<ChunkedBuffer<Converter>>(
frames_per_chunk, num_chunks, frame_duration, Converter{channels}); tb, fpc, num_chunks, Converter{channels});
} }
case AV_SAMPLE_FMT_S16: { case AV_SAMPLE_FMT_S16: {
using Converter = AudioConverter<torch::kInt16, false>; using Converter = AudioConverter<torch::kInt16, false>;
return std::make_unique<ChunkedBuffer<Converter>>( return std::make_unique<ChunkedBuffer<Converter>>(
frames_per_chunk, num_chunks, frame_duration, Converter{channels}); tb, fpc, num_chunks, Converter{channels});
} }
case AV_SAMPLE_FMT_S32: { case AV_SAMPLE_FMT_S32: {
using Converter = AudioConverter<torch::kInt32, false>; using Converter = AudioConverter<torch::kInt32, false>;
return std::make_unique<ChunkedBuffer<Converter>>( return std::make_unique<ChunkedBuffer<Converter>>(
frames_per_chunk, num_chunks, frame_duration, Converter{channels}); tb, fpc, num_chunks, Converter{channels});
} }
case AV_SAMPLE_FMT_S64: { case AV_SAMPLE_FMT_S64: {
using Converter = AudioConverter<torch::kInt64, false>; using Converter = AudioConverter<torch::kInt64, false>;
return std::make_unique<ChunkedBuffer<Converter>>( return std::make_unique<ChunkedBuffer<Converter>>(
frames_per_chunk, num_chunks, frame_duration, Converter{channels}); tb, fpc, num_chunks, Converter{channels});
} }
case AV_SAMPLE_FMT_FLT: { case AV_SAMPLE_FMT_FLT: {
using Converter = AudioConverter<torch::kFloat32, false>; using Converter = AudioConverter<torch::kFloat32, false>;
return std::make_unique<ChunkedBuffer<Converter>>( return std::make_unique<ChunkedBuffer<Converter>>(
frames_per_chunk, num_chunks, frame_duration, Converter{channels}); tb, fpc, num_chunks, Converter{channels});
} }
case AV_SAMPLE_FMT_DBL: { case AV_SAMPLE_FMT_DBL: {
using Converter = AudioConverter<torch::kFloat64, false>; using Converter = AudioConverter<torch::kFloat64, false>;
return std::make_unique<ChunkedBuffer<Converter>>( return std::make_unique<ChunkedBuffer<Converter>>(
frames_per_chunk, num_chunks, frame_duration, Converter{channels}); tb, fpc, num_chunks, Converter{channels});
} }
case AV_SAMPLE_FMT_U8P: { case AV_SAMPLE_FMT_U8P: {
using Converter = AudioConverter<torch::kUInt8, true>; using Converter = AudioConverter<torch::kUInt8, true>;
return std::make_unique<ChunkedBuffer<Converter>>( return std::make_unique<ChunkedBuffer<Converter>>(
frames_per_chunk, num_chunks, frame_duration, Converter{channels}); tb, fpc, num_chunks, Converter{channels});
} }
case AV_SAMPLE_FMT_S16P: { case AV_SAMPLE_FMT_S16P: {
using Converter = AudioConverter<torch::kInt16, true>; using Converter = AudioConverter<torch::kInt16, true>;
return std::make_unique<ChunkedBuffer<Converter>>( return std::make_unique<ChunkedBuffer<Converter>>(
frames_per_chunk, num_chunks, frame_duration, Converter{channels}); tb, fpc, num_chunks, Converter{channels});
} }
case AV_SAMPLE_FMT_S32P: { case AV_SAMPLE_FMT_S32P: {
using Converter = AudioConverter<torch::kInt32, true>; using Converter = AudioConverter<torch::kInt32, true>;
return std::make_unique<ChunkedBuffer<Converter>>( return std::make_unique<ChunkedBuffer<Converter>>(
frames_per_chunk, num_chunks, frame_duration, Converter{channels}); tb, fpc, num_chunks, Converter{channels});
} }
case AV_SAMPLE_FMT_S64P: { case AV_SAMPLE_FMT_S64P: {
using Converter = AudioConverter<torch::kInt64, true>; using Converter = AudioConverter<torch::kInt64, true>;
return std::make_unique<ChunkedBuffer<Converter>>( return std::make_unique<ChunkedBuffer<Converter>>(
frames_per_chunk, num_chunks, frame_duration, Converter{channels}); tb, fpc, num_chunks, Converter{channels});
} }
case AV_SAMPLE_FMT_FLTP: { case AV_SAMPLE_FMT_FLTP: {
using Converter = AudioConverter<torch::kFloat32, true>; using Converter = AudioConverter<torch::kFloat32, true>;
return std::make_unique<ChunkedBuffer<Converter>>( return std::make_unique<ChunkedBuffer<Converter>>(
frames_per_chunk, num_chunks, frame_duration, Converter{channels}); tb, fpc, num_chunks, Converter{channels});
} }
case AV_SAMPLE_FMT_DBLP: { case AV_SAMPLE_FMT_DBLP: {
using Converter = AudioConverter<torch::kFloat64, true>; using Converter = AudioConverter<torch::kFloat64, true>;
return std::make_unique<ChunkedBuffer<Converter>>( return std::make_unique<ChunkedBuffer<Converter>>(
frames_per_chunk, num_chunks, frame_duration, Converter{channels}); tb, fpc, num_chunks, Converter{channels});
} }
default: default:
TORCH_INTERNAL_ASSERT( TORCH_INTERNAL_ASSERT(
...@@ -210,9 +211,9 @@ std::unique_ptr<Buffer> get_chunked_buffer( ...@@ -210,9 +211,9 @@ std::unique_ptr<Buffer> get_chunked_buffer(
} }
std::unique_ptr<Buffer> get_chunked_buffer( std::unique_ptr<Buffer> get_chunked_buffer(
int frames_per_chunk, AVRational tb,
int fpc,
int num_chunks, int num_chunks,
double frame_duration,
AVPixelFormat fmt, AVPixelFormat fmt,
int h, int h,
int w, int w,
...@@ -227,12 +228,12 @@ std::unique_ptr<Buffer> get_chunked_buffer( ...@@ -227,12 +228,12 @@ std::unique_ptr<Buffer> get_chunked_buffer(
case AV_PIX_FMT_NV12: { case AV_PIX_FMT_NV12: {
using Conv = NV12CudaConverter; using Conv = NV12CudaConverter;
return std::make_unique<ChunkedBuffer<Conv>>( return std::make_unique<ChunkedBuffer<Conv>>(
frames_per_chunk, num_chunks, frame_duration, Conv{h, w, device}); tb, fpc, num_chunks, Conv{h, w, device});
} }
case AV_PIX_FMT_P010: { case AV_PIX_FMT_P010: {
using Conv = P010CudaConverter; using Conv = P010CudaConverter;
return std::make_unique<ChunkedBuffer<Conv>>( return std::make_unique<ChunkedBuffer<Conv>>(
frames_per_chunk, num_chunks, frame_duration, Conv{h, w, device}); tb, fpc, num_chunks, Conv{h, w, device});
} }
case AV_PIX_FMT_P016: { case AV_PIX_FMT_P016: {
TORCH_CHECK( TORCH_CHECK(
...@@ -255,7 +256,7 @@ std::unique_ptr<Buffer> get_chunked_buffer( ...@@ -255,7 +256,7 @@ std::unique_ptr<Buffer> get_chunked_buffer(
case AV_PIX_FMT_BGR24: { case AV_PIX_FMT_BGR24: {
using Converter = InterlacedImageConverter; using Converter = InterlacedImageConverter;
return std::make_unique<ChunkedBuffer<Converter>>( return std::make_unique<ChunkedBuffer<Converter>>(
frames_per_chunk, num_chunks, frame_duration, Converter{h, w, 3}); tb, fpc, num_chunks, Converter{h, w, 3});
} }
case AV_PIX_FMT_ARGB: case AV_PIX_FMT_ARGB:
case AV_PIX_FMT_RGBA: case AV_PIX_FMT_RGBA:
...@@ -263,32 +264,32 @@ std::unique_ptr<Buffer> get_chunked_buffer( ...@@ -263,32 +264,32 @@ std::unique_ptr<Buffer> get_chunked_buffer(
case AV_PIX_FMT_BGRA: { case AV_PIX_FMT_BGRA: {
using Converter = InterlacedImageConverter; using Converter = InterlacedImageConverter;
return std::make_unique<ChunkedBuffer<Converter>>( return std::make_unique<ChunkedBuffer<Converter>>(
frames_per_chunk, num_chunks, frame_duration, Converter{h, w, 4}); tb, fpc, num_chunks, Converter{h, w, 4});
} }
case AV_PIX_FMT_GRAY8: { case AV_PIX_FMT_GRAY8: {
using Converter = InterlacedImageConverter; using Converter = InterlacedImageConverter;
return std::make_unique<ChunkedBuffer<Converter>>( return std::make_unique<ChunkedBuffer<Converter>>(
frames_per_chunk, num_chunks, frame_duration, Converter{h, w, 1}); tb, fpc, num_chunks, Converter{h, w, 1});
} }
case AV_PIX_FMT_RGB48LE: { case AV_PIX_FMT_RGB48LE: {
using Converter = Interlaced16BitImageConverter; using Converter = Interlaced16BitImageConverter;
return std::make_unique<ChunkedBuffer<Converter>>( return std::make_unique<ChunkedBuffer<Converter>>(
frames_per_chunk, num_chunks, frame_duration, Converter{h, w, 3}); tb, fpc, num_chunks, Converter{h, w, 3});
} }
case AV_PIX_FMT_YUV444P: { case AV_PIX_FMT_YUV444P: {
using Converter = PlanarImageConverter; using Converter = PlanarImageConverter;
return std::make_unique<ChunkedBuffer<Converter>>( return std::make_unique<ChunkedBuffer<Converter>>(
frames_per_chunk, num_chunks, frame_duration, Converter{h, w, 3}); tb, fpc, num_chunks, Converter{h, w, 3});
} }
case AV_PIX_FMT_YUV420P: { case AV_PIX_FMT_YUV420P: {
using Converter = YUV420PConverter; using Converter = YUV420PConverter;
return std::make_unique<ChunkedBuffer<Converter>>( return std::make_unique<ChunkedBuffer<Converter>>(
frames_per_chunk, num_chunks, frame_duration, Converter{h, w}); tb, fpc, num_chunks, Converter{h, w});
} }
case AV_PIX_FMT_NV12: { case AV_PIX_FMT_NV12: {
using Converter = NV12Converter; using Converter = NV12Converter;
return std::make_unique<ChunkedBuffer<Converter>>( return std::make_unique<ChunkedBuffer<Converter>>(
frames_per_chunk, num_chunks, frame_duration, Converter{h, w}); tb, fpc, num_chunks, Converter{h, w});
} }
default: { default: {
TORCH_INTERNAL_ASSERT( TORCH_INTERNAL_ASSERT(
......
...@@ -13,9 +13,8 @@ class ChunkedBuffer : public Buffer { ...@@ -13,9 +13,8 @@ class ChunkedBuffer : public Buffer {
// Each AVFrame is converted to a Tensor and stored here. // Each AVFrame is converted to a Tensor and stored here.
std::deque<torch::Tensor> chunks; std::deque<torch::Tensor> chunks;
// Time stamps corresponding the first frame of each chunk // Time stamps corresponding the first frame of each chunk
std::deque<double> pts; std::deque<int64_t> pts;
// Duration of one frame, used to recalculate the PTS of audio samples AVRational time_base;
double frame_duration;
// The number of frames to return as a chunk // The number of frames to return as a chunk
// If <0, then user wants to receive all the frames // If <0, then user wants to receive all the frames
...@@ -31,28 +30,28 @@ class ChunkedBuffer : public Buffer { ...@@ -31,28 +30,28 @@ class ChunkedBuffer : public Buffer {
public: public:
ChunkedBuffer( ChunkedBuffer(
AVRational time_base,
int frames_per_chunk, int frames_per_chunk,
int num_chunks, int num_chunks,
double frame_duration,
Converter&& converter); Converter&& converter);
bool is_ready() const override; bool is_ready() const override;
void flush() override; void flush() override;
c10::optional<Chunk> pop_chunk() override; c10::optional<Chunk> pop_chunk() override;
void push_frame(AVFrame* frame_, double pts_) override; void push_frame(AVFrame* frame_) override;
}; };
std::unique_ptr<Buffer> get_chunked_buffer( std::unique_ptr<Buffer> get_chunked_buffer(
AVRational time_base,
int frames_per_chunk, int frames_per_chunk,
int num_chunks, int num_chunks,
double frame_duration,
AVSampleFormat fmt, AVSampleFormat fmt,
int num_channels); int num_channels);
std::unique_ptr<Buffer> get_chunked_buffer( std::unique_ptr<Buffer> get_chunked_buffer(
AVRational time_base,
int frames_per_chunk, int frames_per_chunk,
int num_chunks, int num_chunks,
double frame_duration,
AVPixelFormat fmt, AVPixelFormat fmt,
int height, int height,
int width, int width,
......
...@@ -6,8 +6,10 @@ namespace io { ...@@ -6,8 +6,10 @@ namespace io {
namespace detail { namespace detail {
template <typename Converter> template <typename Converter>
UnchunkedBuffer<Converter>::UnchunkedBuffer(Converter&& converter_) UnchunkedBuffer<Converter>::UnchunkedBuffer(
: converter(std::move(converter_)) {} AVRational time_base,
Converter&& converter)
: time_base(time_base), converter(std::move(converter)) {}
template <typename Converter> template <typename Converter>
bool UnchunkedBuffer<Converter>::is_ready() const { bool UnchunkedBuffer<Converter>::is_ready() const {
...@@ -15,9 +17,9 @@ bool UnchunkedBuffer<Converter>::is_ready() const { ...@@ -15,9 +17,9 @@ bool UnchunkedBuffer<Converter>::is_ready() const {
} }
template <typename Converter> template <typename Converter>
void UnchunkedBuffer<Converter>::push_frame(AVFrame* frame, double pts_) { void UnchunkedBuffer<Converter>::push_frame(AVFrame* frame) {
if (chunks.size() == 0) { if (chunks.size() == 0) {
pts = pts_; pts = double(frame->pts) * time_base.num / time_base.den;
} }
chunks.push_back(converter.convert(frame)); chunks.push_back(converter.convert(frame));
} }
...@@ -39,55 +41,70 @@ void UnchunkedBuffer<Converter>::flush() { ...@@ -39,55 +41,70 @@ void UnchunkedBuffer<Converter>::flush() {
chunks.clear(); chunks.clear();
} }
std::unique_ptr<Buffer> get_unchunked_buffer(AVSampleFormat fmt, int channels) { std::unique_ptr<Buffer> get_unchunked_buffer(
AVRational tb,
AVSampleFormat fmt,
int channels) {
switch (fmt) { switch (fmt) {
case AV_SAMPLE_FMT_U8: { case AV_SAMPLE_FMT_U8: {
using Converter = AudioConverter<torch::kUInt8, false>; using Converter = AudioConverter<torch::kUInt8, false>;
return std::make_unique<UnchunkedBuffer<Converter>>(Converter{channels}); return std::make_unique<UnchunkedBuffer<Converter>>(
tb, Converter{channels});
} }
case AV_SAMPLE_FMT_S16: { case AV_SAMPLE_FMT_S16: {
using Converter = AudioConverter<torch::kInt16, false>; using Converter = AudioConverter<torch::kInt16, false>;
return std::make_unique<UnchunkedBuffer<Converter>>(Converter{channels}); return std::make_unique<UnchunkedBuffer<Converter>>(
tb, Converter{channels});
} }
case AV_SAMPLE_FMT_S32: { case AV_SAMPLE_FMT_S32: {
using Converter = AudioConverter<torch::kInt32, false>; using Converter = AudioConverter<torch::kInt32, false>;
return std::make_unique<UnchunkedBuffer<Converter>>(Converter{channels}); return std::make_unique<UnchunkedBuffer<Converter>>(
tb, Converter{channels});
} }
case AV_SAMPLE_FMT_S64: { case AV_SAMPLE_FMT_S64: {
using Converter = AudioConverter<torch::kInt64, false>; using Converter = AudioConverter<torch::kInt64, false>;
return std::make_unique<UnchunkedBuffer<Converter>>(Converter{channels}); return std::make_unique<UnchunkedBuffer<Converter>>(
tb, Converter{channels});
} }
case AV_SAMPLE_FMT_FLT: { case AV_SAMPLE_FMT_FLT: {
using Converter = AudioConverter<torch::kFloat32, false>; using Converter = AudioConverter<torch::kFloat32, false>;
return std::make_unique<UnchunkedBuffer<Converter>>(Converter{channels}); return std::make_unique<UnchunkedBuffer<Converter>>(
tb, Converter{channels});
} }
case AV_SAMPLE_FMT_DBL: { case AV_SAMPLE_FMT_DBL: {
using Converter = AudioConverter<torch::kFloat64, false>; using Converter = AudioConverter<torch::kFloat64, false>;
return std::make_unique<UnchunkedBuffer<Converter>>(Converter{channels}); return std::make_unique<UnchunkedBuffer<Converter>>(
tb, Converter{channels});
} }
case AV_SAMPLE_FMT_U8P: { case AV_SAMPLE_FMT_U8P: {
using Converter = AudioConverter<torch::kUInt8, true>; using Converter = AudioConverter<torch::kUInt8, true>;
return std::make_unique<UnchunkedBuffer<Converter>>(Converter{channels}); return std::make_unique<UnchunkedBuffer<Converter>>(
tb, Converter{channels});
} }
case AV_SAMPLE_FMT_S16P: { case AV_SAMPLE_FMT_S16P: {
using Converter = AudioConverter<torch::kInt16, true>; using Converter = AudioConverter<torch::kInt16, true>;
return std::make_unique<UnchunkedBuffer<Converter>>(Converter{channels}); return std::make_unique<UnchunkedBuffer<Converter>>(
tb, Converter{channels});
} }
case AV_SAMPLE_FMT_S32P: { case AV_SAMPLE_FMT_S32P: {
using Converter = AudioConverter<torch::kInt32, true>; using Converter = AudioConverter<torch::kInt32, true>;
return std::make_unique<UnchunkedBuffer<Converter>>(Converter{channels}); return std::make_unique<UnchunkedBuffer<Converter>>(
tb, Converter{channels});
} }
case AV_SAMPLE_FMT_S64P: { case AV_SAMPLE_FMT_S64P: {
using Converter = AudioConverter<torch::kInt64, true>; using Converter = AudioConverter<torch::kInt64, true>;
return std::make_unique<UnchunkedBuffer<Converter>>(Converter{channels}); return std::make_unique<UnchunkedBuffer<Converter>>(
tb, Converter{channels});
} }
case AV_SAMPLE_FMT_FLTP: { case AV_SAMPLE_FMT_FLTP: {
using Converter = AudioConverter<torch::kFloat32, true>; using Converter = AudioConverter<torch::kFloat32, true>;
return std::make_unique<UnchunkedBuffer<Converter>>(Converter{channels}); return std::make_unique<UnchunkedBuffer<Converter>>(
tb, Converter{channels});
} }
case AV_SAMPLE_FMT_DBLP: { case AV_SAMPLE_FMT_DBLP: {
using Converter = AudioConverter<torch::kFloat64, true>; using Converter = AudioConverter<torch::kFloat64, true>;
return std::make_unique<UnchunkedBuffer<Converter>>(Converter{channels}); return std::make_unique<UnchunkedBuffer<Converter>>(
tb, Converter{channels});
} }
default: default:
TORCH_INTERNAL_ASSERT( TORCH_INTERNAL_ASSERT(
...@@ -96,6 +113,7 @@ std::unique_ptr<Buffer> get_unchunked_buffer(AVSampleFormat fmt, int channels) { ...@@ -96,6 +113,7 @@ std::unique_ptr<Buffer> get_unchunked_buffer(AVSampleFormat fmt, int channels) {
} }
std::unique_ptr<Buffer> get_unchunked_buffer( std::unique_ptr<Buffer> get_unchunked_buffer(
AVRational tb,
AVPixelFormat fmt, AVPixelFormat fmt,
int h, int h,
int w, int w,
...@@ -109,11 +127,11 @@ std::unique_ptr<Buffer> get_unchunked_buffer( ...@@ -109,11 +127,11 @@ std::unique_ptr<Buffer> get_unchunked_buffer(
switch (fmt) { switch (fmt) {
case AV_PIX_FMT_NV12: { case AV_PIX_FMT_NV12: {
using Conv = NV12CudaConverter; using Conv = NV12CudaConverter;
return std::make_unique<UnchunkedBuffer<Conv>>(Conv{h, w, device}); return std::make_unique<UnchunkedBuffer<Conv>>(tb, Conv{h, w, device});
} }
case AV_PIX_FMT_P010: { case AV_PIX_FMT_P010: {
using Conv = P010CudaConverter; using Conv = P010CudaConverter;
return std::make_unique<UnchunkedBuffer<Conv>>(Conv{h, w, device}); return std::make_unique<UnchunkedBuffer<Conv>>(tb, Conv{h, w, device});
} }
case AV_PIX_FMT_P016: { case AV_PIX_FMT_P016: {
TORCH_CHECK( TORCH_CHECK(
...@@ -135,34 +153,39 @@ std::unique_ptr<Buffer> get_unchunked_buffer( ...@@ -135,34 +153,39 @@ std::unique_ptr<Buffer> get_unchunked_buffer(
case AV_PIX_FMT_RGB24: case AV_PIX_FMT_RGB24:
case AV_PIX_FMT_BGR24: { case AV_PIX_FMT_BGR24: {
using Converter = InterlacedImageConverter; using Converter = InterlacedImageConverter;
return std::make_unique<UnchunkedBuffer<Converter>>(Converter{h, w, 3}); return std::make_unique<UnchunkedBuffer<Converter>>(
tb, Converter{h, w, 3});
} }
case AV_PIX_FMT_ARGB: case AV_PIX_FMT_ARGB:
case AV_PIX_FMT_RGBA: case AV_PIX_FMT_RGBA:
case AV_PIX_FMT_ABGR: case AV_PIX_FMT_ABGR:
case AV_PIX_FMT_BGRA: { case AV_PIX_FMT_BGRA: {
using Converter = InterlacedImageConverter; using Converter = InterlacedImageConverter;
return std::make_unique<UnchunkedBuffer<Converter>>(Converter{h, w, 4}); return std::make_unique<UnchunkedBuffer<Converter>>(
tb, Converter{h, w, 4});
} }
case AV_PIX_FMT_GRAY8: { case AV_PIX_FMT_GRAY8: {
using Converter = InterlacedImageConverter; using Converter = InterlacedImageConverter;
return std::make_unique<UnchunkedBuffer<Converter>>(Converter{h, w, 1}); return std::make_unique<UnchunkedBuffer<Converter>>(
tb, Converter{h, w, 1});
} }
case AV_PIX_FMT_RGB48LE: { case AV_PIX_FMT_RGB48LE: {
using Converter = Interlaced16BitImageConverter; using Converter = Interlaced16BitImageConverter;
return std::make_unique<UnchunkedBuffer<Converter>>(Converter{h, w, 3}); return std::make_unique<UnchunkedBuffer<Converter>>(
tb, Converter{h, w, 3});
} }
case AV_PIX_FMT_YUV444P: { case AV_PIX_FMT_YUV444P: {
using Converter = PlanarImageConverter; using Converter = PlanarImageConverter;
return std::make_unique<UnchunkedBuffer<Converter>>(Converter{h, w, 3}); return std::make_unique<UnchunkedBuffer<Converter>>(
tb, Converter{h, w, 3});
} }
case AV_PIX_FMT_YUV420P: { case AV_PIX_FMT_YUV420P: {
using Converter = YUV420PConverter; using Converter = YUV420PConverter;
return std::make_unique<UnchunkedBuffer<Converter>>(Converter{h, w}); return std::make_unique<UnchunkedBuffer<Converter>>(tb, Converter{h, w});
} }
case AV_PIX_FMT_NV12: { case AV_PIX_FMT_NV12: {
using Converter = NV12Converter; using Converter = NV12Converter;
return std::make_unique<UnchunkedBuffer<Converter>>(Converter{h, w}); return std::make_unique<UnchunkedBuffer<Converter>>(tb, Converter{h, w});
} }
default: { default: {
TORCH_INTERNAL_ASSERT( TORCH_INTERNAL_ASSERT(
......
...@@ -16,21 +16,24 @@ class UnchunkedBuffer : public Buffer { ...@@ -16,21 +16,24 @@ class UnchunkedBuffer : public Buffer {
// Each AVFrame is converted to a Tensor and stored here. // Each AVFrame is converted to a Tensor and stored here.
std::deque<torch::Tensor> chunks; std::deque<torch::Tensor> chunks;
double pts = -1.; double pts = -1.;
AVRational time_base;
Converter converter; Converter converter;
public: public:
UnchunkedBuffer(Converter&& converter); UnchunkedBuffer(AVRational time_base, Converter&& converter);
bool is_ready() const override; bool is_ready() const override;
void push_frame(AVFrame* frame, double pts_) override; void push_frame(AVFrame* frame) override;
c10::optional<Chunk> pop_chunk() override; c10::optional<Chunk> pop_chunk() override;
void flush() override; void flush() override;
}; };
std::unique_ptr<Buffer> get_unchunked_buffer( std::unique_ptr<Buffer> get_unchunked_buffer(
AVRational time_base,
AVSampleFormat fmt, AVSampleFormat fmt,
int num_channels); int num_channels);
std::unique_ptr<Buffer> get_unchunked_buffer( std::unique_ptr<Buffer> get_unchunked_buffer(
AVRational time_base,
AVPixelFormat fmt, AVPixelFormat fmt,
int height, int height,
int width, int width,
......
...@@ -9,7 +9,6 @@ namespace io { ...@@ -9,7 +9,6 @@ namespace io {
namespace { namespace {
std::unique_ptr<Buffer> get_buffer( std::unique_ptr<Buffer> get_buffer(
AVCodecContext* codec_ctx,
FilterGraph& filter, FilterGraph& filter,
int frames_per_chunk, int frames_per_chunk,
int num_chunks, int num_chunks,
...@@ -32,32 +31,27 @@ std::unique_ptr<Buffer> get_buffer( ...@@ -32,32 +31,27 @@ std::unique_ptr<Buffer> get_buffer(
av_get_media_type_string(info.type), av_get_media_type_string(info.type),
". Only video or audio is supported "); ". Only video or audio is supported ");
auto time_base = filter.get_output_timebase();
double frame_duration = double(time_base.num) / time_base.den;
if (info.type == AVMEDIA_TYPE_AUDIO) { if (info.type == AVMEDIA_TYPE_AUDIO) {
AVSampleFormat fmt = (AVSampleFormat)(info.format); AVSampleFormat fmt = (AVSampleFormat)(info.format);
if (frames_per_chunk == -1) { if (frames_per_chunk == -1) {
return detail::get_unchunked_buffer(fmt, codec_ctx->channels); return detail::get_unchunked_buffer(
info.time_base, fmt, info.num_channels);
} else { } else {
return detail::get_chunked_buffer( return detail::get_chunked_buffer(
frames_per_chunk, info.time_base, frames_per_chunk, num_chunks, fmt, info.num_channels);
num_chunks,
frame_duration,
fmt,
codec_ctx->channels);
} }
} else { } else {
AVPixelFormat fmt = (AVPixelFormat)(info.format); AVPixelFormat fmt = (AVPixelFormat)(info.format);
TORCH_INTERNAL_ASSERT(fmt != AV_PIX_FMT_CUDA); TORCH_INTERNAL_ASSERT(fmt != AV_PIX_FMT_CUDA);
if (frames_per_chunk == -1) { if (frames_per_chunk == -1) {
return detail::get_unchunked_buffer(fmt, info.height, info.width, device); return detail::get_unchunked_buffer(
info.time_base, fmt, info.height, info.width, device);
} else { } else {
return detail::get_chunked_buffer( return detail::get_chunked_buffer(
info.time_base,
frames_per_chunk, frames_per_chunk,
num_chunks, num_chunks,
frame_duration,
fmt, fmt,
info.height, info.height,
info.width, info.width,
...@@ -110,22 +104,18 @@ Sink::Sink( ...@@ -110,22 +104,18 @@ Sink::Sink(
int frames_per_chunk, int frames_per_chunk,
int num_chunks, int num_chunks,
AVRational frame_rate_, AVRational frame_rate_,
const c10::optional<std::string>& filter_description_, const std::string& filter_desc,
const torch::Device& device) const torch::Device& device)
: input_time_base(input_time_base_), : input_time_base(input_time_base_),
codec_ctx(codec_ctx_), codec_ctx(codec_ctx_),
frame_rate(frame_rate_), frame_rate(frame_rate_),
filter_description(filter_description_.value_or( filter_description(filter_desc),
codec_ctx->codec_type == AVMEDIA_TYPE_AUDIO ? "anull" : "null")),
filter(get_filter_graph( filter(get_filter_graph(
input_time_base_, input_time_base_,
codec_ctx, codec_ctx,
frame_rate, frame_rate,
filter_description)), filter_description)),
output_time_base(filter.get_output_timebase()), buffer(get_buffer(filter, frames_per_chunk, num_chunks, device)) {}
buffer(
get_buffer(codec_ctx, filter, frames_per_chunk, num_chunks, device)) {
}
// 0: some kind of success // 0: some kind of success
// <0: Some error happened // <0: Some error happened
...@@ -139,23 +129,13 @@ int Sink::process_frame(AVFrame* pFrame) { ...@@ -139,23 +129,13 @@ int Sink::process_frame(AVFrame* pFrame) {
return 0; return 0;
} }
if (ret >= 0) { if (ret >= 0) {
double pts = buffer->push_frame(frame);
double(frame->pts * output_time_base.num) / output_time_base.den;
buffer->push_frame(frame, pts);
} }
av_frame_unref(frame); av_frame_unref(frame);
} }
return ret; return ret;
} }
std::string Sink::get_filter_description() const {
return filter_description;
}
FilterGraphOutputInfo Sink::get_filter_output_info() const {
return filter.get_output_info();
}
void Sink::flush() { void Sink::flush() {
filter = get_filter_graph( filter = get_filter_graph(
input_time_base, codec_ctx, frame_rate, filter_description); input_time_base, codec_ctx, frame_rate, filter_description);
......
...@@ -8,31 +8,27 @@ namespace torchaudio { ...@@ -8,31 +8,27 @@ namespace torchaudio {
namespace io { namespace io {
class Sink { class Sink {
AVFramePtr frame; AVFramePtr frame{};
// Parameters for recreating FilterGraph // Parameters for recreating FilterGraph
AVRational input_time_base; AVRational input_time_base;
AVCodecContext* codec_ctx; AVCodecContext* codec_ctx;
AVRational frame_rate; AVRational frame_rate;
std::string filter_description;
FilterGraph filter;
// time_base of filter graph output, used for PTS calc
AVRational output_time_base;
public: public:
const std::string filter_description;
FilterGraph filter;
std::unique_ptr<Buffer> buffer; std::unique_ptr<Buffer> buffer;
Sink( Sink(
AVRational input_time_base, AVRational input_time_base,
AVCodecContext* codec_ctx, AVCodecContext* codec_ctx,
int frames_per_chunk, int frames_per_chunk,
int num_chunks, int num_chunks,
AVRational frame_rate, AVRational frame_rate,
const c10::optional<std::string>& filter_description, const std::string& filter_description,
const torch::Device& device); const torch::Device& device);
[[nodiscard]] std::string get_filter_description() const;
[[nodiscard]] FilterGraphOutputInfo get_filter_output_info() const;
int process_frame(AVFrame* frame); int process_frame(AVFrame* frame);
bool is_buffer_ready() const; bool is_buffer_ready() const;
......
...@@ -182,7 +182,7 @@ KeyType StreamProcessor::add_stream( ...@@ -182,7 +182,7 @@ KeyType StreamProcessor::add_stream(
int frames_per_chunk, int frames_per_chunk,
int num_chunks, int num_chunks,
AVRational frame_rate, AVRational frame_rate,
const c10::optional<std::string>& filter_description, const std::string& filter_description,
const torch::Device& device) { const torch::Device& device) {
// If device is provided, then check that codec_ctx has hw_device_ctx set. // If device is provided, then check that codec_ctx has hw_device_ctx set.
// In case, defining an output stream with HW accel on an input stream that // In case, defining an output stream with HW accel on an input stream that
...@@ -252,12 +252,12 @@ void StreamProcessor::set_discard_timestamp(int64_t timestamp) { ...@@ -252,12 +252,12 @@ void StreamProcessor::set_discard_timestamp(int64_t timestamp) {
// Query methods // Query methods
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
std::string StreamProcessor::get_filter_description(KeyType key) const { std::string StreamProcessor::get_filter_description(KeyType key) const {
return sinks.at(key).get_filter_description(); return sinks.at(key).filter_description;
} }
FilterGraphOutputInfo StreamProcessor::get_filter_output_info( FilterGraphOutputInfo StreamProcessor::get_filter_output_info(
KeyType key) const { KeyType key) const {
return sinks.at(key).get_filter_output_info(); return sinks.at(key).filter.get_output_info();
} }
bool StreamProcessor::is_buffer_ready() const { bool StreamProcessor::is_buffer_ready() const {
......
...@@ -59,7 +59,7 @@ class StreamProcessor { ...@@ -59,7 +59,7 @@ class StreamProcessor {
int frames_per_chunk, int frames_per_chunk,
int num_chunks, int num_chunks,
AVRational frame_rate, AVRational frame_rate,
const c10::optional<std::string>& filter_description, const std::string& filter_description,
const torch::Device& device); const torch::Device& device);
// 1. Remove the stream // 1. Remove the stream
......
...@@ -285,7 +285,7 @@ void StreamReader::add_audio_stream( ...@@ -285,7 +285,7 @@ void StreamReader::add_audio_stream(
AVMEDIA_TYPE_AUDIO, AVMEDIA_TYPE_AUDIO,
static_cast<int>(frames_per_chunk), static_cast<int>(frames_per_chunk),
static_cast<int>(num_chunks), static_cast<int>(num_chunks),
filter_desc, filter_desc.value_or("anull"),
decoder, decoder,
decoder_option, decoder_option,
torch::Device(torch::DeviceType::CPU)); torch::Device(torch::DeviceType::CPU));
...@@ -322,7 +322,7 @@ void StreamReader::add_video_stream( ...@@ -322,7 +322,7 @@ void StreamReader::add_video_stream(
AVMEDIA_TYPE_VIDEO, AVMEDIA_TYPE_VIDEO,
static_cast<int>(frames_per_chunk), static_cast<int>(frames_per_chunk),
static_cast<int>(num_chunks), static_cast<int>(num_chunks),
filter_desc, filter_desc.value_or("null"),
decoder, decoder,
decoder_option, decoder_option,
device); device);
...@@ -333,7 +333,7 @@ void StreamReader::add_stream( ...@@ -333,7 +333,7 @@ void StreamReader::add_stream(
AVMediaType media_type, AVMediaType media_type,
int frames_per_chunk, int frames_per_chunk,
int num_chunks, int num_chunks,
const c10::optional<std::string>& filter_desc, const std::string& filter_desc,
const c10::optional<std::string>& decoder, const c10::optional<std::string>& decoder,
const c10::optional<OptionDict>& decoder_option, const c10::optional<OptionDict>& decoder_option,
const torch::Device& device) { const torch::Device& device) {
......
...@@ -229,7 +229,7 @@ class StreamReader { ...@@ -229,7 +229,7 @@ class StreamReader {
AVMediaType media_type, AVMediaType media_type,
int frames_per_chunk, int frames_per_chunk,
int num_chunks, int num_chunks,
const c10::optional<std::string>& filter_desc, const std::string& filter_desc,
const c10::optional<std::string>& decoder, const c10::optional<std::string>& decoder,
const c10::optional<OptionDict>& decoder_option, const c10::optional<OptionDict>& decoder_option,
const torch::Device& device); const torch::Device& device);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment