Commit fd24af00 authored by Moto Hira's avatar Moto Hira Committed by Facebook GitHub Bot
Browse files

Use null filter in case no filter is used (#3109)

Summary:
Pull Request resolved: https://github.com/pytorch/audio/pull/3109

Change the logic around StreamWriter preprocessing.
Currently, no preprocessing is expressed as `nullptr` to `unique_ptr<FilterGraph>`.

This commit changes it to `[a]null` filter, which is just a pass through.
This makes a code a bit simpler, and serves better preparation for adding
filters for CUDA process.

Reviewed By: xiaohui-zhang

Differential Revision: D43593321

fbshipit-source-id: 9ca71c2c8bf652384a0f56b4c41b32d908f61201
parent be3bd1ac
...@@ -4,23 +4,28 @@ namespace torchaudio::io { ...@@ -4,23 +4,28 @@ namespace torchaudio::io {
namespace { namespace {
std::unique_ptr<FilterGraph> get_audio_filter( FilterGraph get_audio_filter(
AVSampleFormat src_fmt, AVSampleFormat src_fmt,
AVCodecContext* codec_ctx) { AVCodecContext* codec_ctx) {
auto desc = [&]() -> std::string {
if (src_fmt == codec_ctx->sample_fmt) { if (src_fmt == codec_ctx->sample_fmt) {
return {nullptr}; return "anull";
} else {
std::stringstream ss;
ss << "aformat=" << av_get_sample_fmt_name(codec_ctx->sample_fmt);
return ss.str();
} }
std::stringstream desc; }();
desc << "aformat=" << av_get_sample_fmt_name(codec_ctx->sample_fmt);
auto p = std::make_unique<FilterGraph>(AVMEDIA_TYPE_AUDIO); FilterGraph p{AVMEDIA_TYPE_AUDIO};
p->add_audio_src( p.add_audio_src(
src_fmt, src_fmt,
codec_ctx->time_base, codec_ctx->time_base,
codec_ctx->sample_rate, codec_ctx->sample_rate,
codec_ctx->channel_layout); codec_ctx->channel_layout);
p->add_sink(); p.add_sink();
p->add_process(desc.str()); p.add_process(desc);
p->create_filter(); p.create_filter();
return p; return p;
} }
......
...@@ -5,7 +5,7 @@ namespace torchaudio::io { ...@@ -5,7 +5,7 @@ namespace torchaudio::io {
OutputStream::OutputStream( OutputStream::OutputStream(
AVFormatContext* format_ctx, AVFormatContext* format_ctx,
AVCodecContext* codec_ctx_, AVCodecContext* codec_ctx_,
std::unique_ptr<FilterGraph>&& filter_) FilterGraph&& filter_)
: codec_ctx(codec_ctx_), : codec_ctx(codec_ctx_),
encoder(format_ctx, codec_ctx), encoder(format_ctx, codec_ctx),
filter(std::move(filter_)), filter(std::move(filter_)),
...@@ -13,13 +13,9 @@ OutputStream::OutputStream( ...@@ -13,13 +13,9 @@ OutputStream::OutputStream(
num_frames(0) {} num_frames(0) {}
void OutputStream::process_frame(AVFrame* src) { void OutputStream::process_frame(AVFrame* src) {
if (!filter) { int ret = filter.add_frame(src);
encoder.encode(src);
return;
}
int ret = filter->add_frame(src);
while (ret >= 0) { while (ret >= 0) {
ret = filter->get_frame(dst_frame); ret = filter.get_frame(dst_frame);
if (ret == AVERROR(EAGAIN) || ret == AVERROR_EOF) { if (ret == AVERROR(EAGAIN) || ret == AVERROR_EOF) {
if (ret == AVERROR_EOF) { if (ret == AVERROR_EOF) {
encoder.encode(nullptr); encoder.encode(nullptr);
......
...@@ -13,7 +13,7 @@ struct OutputStream { ...@@ -13,7 +13,7 @@ struct OutputStream {
// Encoder + Muxer // Encoder + Muxer
Encoder encoder; Encoder encoder;
// Filter for additional processing // Filter for additional processing
std::unique_ptr<FilterGraph> filter; FilterGraph filter;
// frame that output from FilterGraph is written // frame that output from FilterGraph is written
AVFramePtr dst_frame; AVFramePtr dst_frame;
// The number of samples written so far // The number of samples written so far
...@@ -22,7 +22,7 @@ struct OutputStream { ...@@ -22,7 +22,7 @@ struct OutputStream {
OutputStream( OutputStream(
AVFormatContext* format_ctx, AVFormatContext* format_ctx,
AVCodecContext* codec_ctx, AVCodecContext* codec_ctx,
std::unique_ptr<FilterGraph>&& filter); FilterGraph&& filter);
virtual void write_chunk(const torch::Tensor& input) = 0; virtual void write_chunk(const torch::Tensor& input) = 0;
void process_frame(AVFrame* src); void process_frame(AVFrame* src);
......
...@@ -8,26 +8,31 @@ namespace torchaudio::io { ...@@ -8,26 +8,31 @@ namespace torchaudio::io {
namespace { namespace {
std::unique_ptr<FilterGraph> get_video_filter( FilterGraph get_video_filter(
AVPixelFormat src_fmt, AVPixelFormat src_fmt,
AVCodecContext* codec_ctx, AVCodecContext* codec_ctx,
const torch::Device& device) { const torch::Device& device) {
if (src_fmt == codec_ctx->pix_fmt || device.type() != c10::DeviceType::CPU) { auto desc = [&]() -> std::string {
return {nullptr}; if (src_fmt == codec_ctx->pix_fmt ||
device.type() != c10::DeviceType::CPU) {
return "null";
} else {
std::stringstream ss;
ss << "format=" << av_get_pix_fmt_name(codec_ctx->pix_fmt);
return ss.str();
} }
std::stringstream desc; }();
desc << "format=" << av_get_pix_fmt_name(codec_ctx->pix_fmt);
auto p = std::make_unique<FilterGraph>(AVMEDIA_TYPE_VIDEO); FilterGraph p{AVMEDIA_TYPE_VIDEO};
p->add_video_src( p.add_video_src(
src_fmt, src_fmt,
codec_ctx->time_base, codec_ctx->time_base,
codec_ctx->width, codec_ctx->width,
codec_ctx->height, codec_ctx->height,
codec_ctx->sample_aspect_ratio); codec_ctx->sample_aspect_ratio);
p->add_sink(); p.add_sink();
p->add_process(desc.str()); p.add_process(desc);
p->create_filter(); p.create_filter();
return p; return p;
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment