Commit fd24af00 authored by Moto Hira's avatar Moto Hira Committed by Facebook GitHub Bot
Browse files

Use null filter in case no filter is used (#3109)

Summary:
Pull Request resolved: https://github.com/pytorch/audio/pull/3109

Change the logic around StreamWriter preprocessing.
Currently, the absence of preprocessing is represented by a null `unique_ptr<FilterGraph>`.

This commit replaces that with the `[a]null` filter (`anull` for audio, `null` for video), which is just a pass-through.
This makes the code a bit simpler and better prepares it for adding
filters for CUDA processing.

Reviewed By: xiaohui-zhang

Differential Revision: D43593321

fbshipit-source-id: 9ca71c2c8bf652384a0f56b4c41b32d908f61201
parent be3bd1ac
......@@ -4,23 +4,28 @@ namespace torchaudio::io {
namespace {
// NOTE(review): this hunk is rendered without +/- markers, so the replaced
// implementation (returns std::unique_ptr<FilterGraph>, uses `p->`) and the
// new implementation (returns FilterGraph by value, uses `p.`) appear
// interleaved below.
//
// Builds the audio FilterGraph for frames whose sample format is `src_fmt`,
// configured from the time base, sample rate and channel layout of
// `codec_ctx`.
std::unique_ptr<FilterGraph> get_audio_filter(
FilterGraph get_audio_filter(
AVSampleFormat src_fmt,
AVCodecContext* codec_ctx) {
// Replaced variant: identical source/codec sample formats yield a null
// pointer, i.e. no filter graph at all.
if (src_fmt == codec_ctx->sample_fmt) {
return {nullptr};
}
std::stringstream desc;
desc << "aformat=" << av_get_sample_fmt_name(codec_ctx->sample_fmt);
auto p = std::make_unique<FilterGraph>(AVMEDIA_TYPE_AUDIO);
p->add_audio_src(
// New variant: a filter description is always produced. It is "anull" when
// the formats already match, otherwise "aformat=<codec sample format name>".
auto desc = [&]() -> std::string {
if (src_fmt == codec_ctx->sample_fmt) {
return "anull";
} else {
std::stringstream ss;
ss << "aformat=" << av_get_sample_fmt_name(codec_ctx->sample_fmt);
return ss.str();
}
}();
FilterGraph p{AVMEDIA_TYPE_AUDIO};
p.add_audio_src(
src_fmt,
codec_ctx->time_base,
codec_ctx->sample_rate,
codec_ctx->channel_layout);
p->add_sink();
p->add_process(desc.str());
p->create_filter();
// New variant: `desc` is already a std::string, so it is passed directly.
p.add_sink();
p.add_process(desc);
p.create_filter();
return p;
}
......
......@@ -5,7 +5,7 @@ namespace torchaudio::io {
OutputStream::OutputStream(
AVFormatContext* format_ctx,
AVCodecContext* codec_ctx_,
std::unique_ptr<FilterGraph>&& filter_)
FilterGraph&& filter_)
: codec_ctx(codec_ctx_),
encoder(format_ctx, codec_ctx),
filter(std::move(filter_)),
......@@ -13,13 +13,9 @@ OutputStream::OutputStream(
num_frames(0) {}
void OutputStream::process_frame(AVFrame* src) {
if (!filter) {
encoder.encode(src);
return;
}
int ret = filter->add_frame(src);
int ret = filter.add_frame(src);
while (ret >= 0) {
ret = filter->get_frame(dst_frame);
ret = filter.get_frame(dst_frame);
if (ret == AVERROR(EAGAIN) || ret == AVERROR_EOF) {
if (ret == AVERROR_EOF) {
encoder.encode(nullptr);
......
......@@ -13,7 +13,7 @@ struct OutputStream {
// Encoder + Muxer
Encoder encoder;
// Filter for additional processing
std::unique_ptr<FilterGraph> filter;
FilterGraph filter;
// frame that output from FilterGraph is written
AVFramePtr dst_frame;
// The number of samples written so far
......@@ -22,7 +22,7 @@ struct OutputStream {
OutputStream(
AVFormatContext* format_ctx,
AVCodecContext* codec_ctx,
std::unique_ptr<FilterGraph>&& filter);
FilterGraph&& filter);
virtual void write_chunk(const torch::Tensor& input) = 0;
void process_frame(AVFrame* src);
......
......@@ -8,26 +8,31 @@ namespace torchaudio::io {
namespace {
// NOTE(review): this hunk is rendered without +/- markers, so the replaced
// implementation (returns std::unique_ptr<FilterGraph>, uses `p->`) and the
// new implementation (returns FilterGraph by value, uses `p.`) appear
// interleaved below.
//
// Builds the video FilterGraph for frames whose pixel format is `src_fmt`,
// configured from the time base, width, height and sample aspect ratio of
// `codec_ctx`.
std::unique_ptr<FilterGraph> get_video_filter(
FilterGraph get_video_filter(
AVPixelFormat src_fmt,
AVCodecContext* codec_ctx,
const torch::Device& device) {
// Replaced variant: no filter graph (null pointer) when the pixel formats
// already match or when `device` is not a CPU device.
if (src_fmt == codec_ctx->pix_fmt || device.type() != c10::DeviceType::CPU) {
return {nullptr};
}
std::stringstream desc;
desc << "format=" << av_get_pix_fmt_name(codec_ctx->pix_fmt);
// New variant: a filter description is always produced. It is "null" under
// the same condition that previously returned a null pointer, otherwise
// "format=<codec pixel format name>".
auto desc = [&]() -> std::string {
if (src_fmt == codec_ctx->pix_fmt ||
device.type() != c10::DeviceType::CPU) {
return "null";
} else {
std::stringstream ss;
ss << "format=" << av_get_pix_fmt_name(codec_ctx->pix_fmt);
return ss.str();
}
}();
auto p = std::make_unique<FilterGraph>(AVMEDIA_TYPE_VIDEO);
p->add_video_src(
FilterGraph p{AVMEDIA_TYPE_VIDEO};
p.add_video_src(
src_fmt,
codec_ctx->time_base,
codec_ctx->width,
codec_ctx->height,
codec_ctx->sample_aspect_ratio);
p->add_sink();
p->add_process(desc.str());
p->create_filter();
// New variant: `desc` is already a std::string, so it is passed directly.
p.add_sink();
p.add_process(desc);
p.create_filter();
return p;
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment