You need to sign in or sign up before continuing.
Commit fd24af00 authored by Moto Hira's avatar Moto Hira Committed by Facebook GitHub Bot
Browse files

Use null filter in case no filter is used (#3109)

Summary:
Pull Request resolved: https://github.com/pytorch/audio/pull/3109

Change the logic around StreamWriter preprocessing.
Currently, the absence of preprocessing is represented by a null `unique_ptr<FilterGraph>`.

This commit changes it to the `[a]null` filter (`anull` for audio, `null` for video), which is just a pass-through.
This makes the code a bit simpler, and better prepares it for adding
filters for CUDA processing.

Reviewed By: xiaohui-zhang

Differential Revision: D43593321

fbshipit-source-id: 9ca71c2c8bf652384a0f56b4c41b32d908f61201
parent be3bd1ac
...@@ -4,23 +4,28 @@ namespace torchaudio::io { ...@@ -4,23 +4,28 @@ namespace torchaudio::io {
namespace { namespace {
std::unique_ptr<FilterGraph> get_audio_filter( FilterGraph get_audio_filter(
AVSampleFormat src_fmt, AVSampleFormat src_fmt,
AVCodecContext* codec_ctx) { AVCodecContext* codec_ctx) {
auto desc = [&]() -> std::string {
if (src_fmt == codec_ctx->sample_fmt) { if (src_fmt == codec_ctx->sample_fmt) {
return {nullptr}; return "anull";
} else {
std::stringstream ss;
ss << "aformat=" << av_get_sample_fmt_name(codec_ctx->sample_fmt);
return ss.str();
} }
std::stringstream desc; }();
desc << "aformat=" << av_get_sample_fmt_name(codec_ctx->sample_fmt);
auto p = std::make_unique<FilterGraph>(AVMEDIA_TYPE_AUDIO); FilterGraph p{AVMEDIA_TYPE_AUDIO};
p->add_audio_src( p.add_audio_src(
src_fmt, src_fmt,
codec_ctx->time_base, codec_ctx->time_base,
codec_ctx->sample_rate, codec_ctx->sample_rate,
codec_ctx->channel_layout); codec_ctx->channel_layout);
p->add_sink(); p.add_sink();
p->add_process(desc.str()); p.add_process(desc);
p->create_filter(); p.create_filter();
return p; return p;
} }
......
...@@ -5,7 +5,7 @@ namespace torchaudio::io { ...@@ -5,7 +5,7 @@ namespace torchaudio::io {
OutputStream::OutputStream( OutputStream::OutputStream(
AVFormatContext* format_ctx, AVFormatContext* format_ctx,
AVCodecContext* codec_ctx_, AVCodecContext* codec_ctx_,
std::unique_ptr<FilterGraph>&& filter_) FilterGraph&& filter_)
: codec_ctx(codec_ctx_), : codec_ctx(codec_ctx_),
encoder(format_ctx, codec_ctx), encoder(format_ctx, codec_ctx),
filter(std::move(filter_)), filter(std::move(filter_)),
...@@ -13,13 +13,9 @@ OutputStream::OutputStream( ...@@ -13,13 +13,9 @@ OutputStream::OutputStream(
num_frames(0) {} num_frames(0) {}
void OutputStream::process_frame(AVFrame* src) { void OutputStream::process_frame(AVFrame* src) {
if (!filter) { int ret = filter.add_frame(src);
encoder.encode(src);
return;
}
int ret = filter->add_frame(src);
while (ret >= 0) { while (ret >= 0) {
ret = filter->get_frame(dst_frame); ret = filter.get_frame(dst_frame);
if (ret == AVERROR(EAGAIN) || ret == AVERROR_EOF) { if (ret == AVERROR(EAGAIN) || ret == AVERROR_EOF) {
if (ret == AVERROR_EOF) { if (ret == AVERROR_EOF) {
encoder.encode(nullptr); encoder.encode(nullptr);
......
...@@ -13,7 +13,7 @@ struct OutputStream { ...@@ -13,7 +13,7 @@ struct OutputStream {
// Encoder + Muxer // Encoder + Muxer
Encoder encoder; Encoder encoder;
// Filter for additional processing // Filter for additional processing
std::unique_ptr<FilterGraph> filter; FilterGraph filter;
// frame that output from FilterGraph is written // frame that output from FilterGraph is written
AVFramePtr dst_frame; AVFramePtr dst_frame;
// The number of samples written so far // The number of samples written so far
...@@ -22,7 +22,7 @@ struct OutputStream { ...@@ -22,7 +22,7 @@ struct OutputStream {
OutputStream( OutputStream(
AVFormatContext* format_ctx, AVFormatContext* format_ctx,
AVCodecContext* codec_ctx, AVCodecContext* codec_ctx,
std::unique_ptr<FilterGraph>&& filter); FilterGraph&& filter);
virtual void write_chunk(const torch::Tensor& input) = 0; virtual void write_chunk(const torch::Tensor& input) = 0;
void process_frame(AVFrame* src); void process_frame(AVFrame* src);
......
...@@ -8,26 +8,31 @@ namespace torchaudio::io { ...@@ -8,26 +8,31 @@ namespace torchaudio::io {
namespace { namespace {
std::unique_ptr<FilterGraph> get_video_filter( FilterGraph get_video_filter(
AVPixelFormat src_fmt, AVPixelFormat src_fmt,
AVCodecContext* codec_ctx, AVCodecContext* codec_ctx,
const torch::Device& device) { const torch::Device& device) {
if (src_fmt == codec_ctx->pix_fmt || device.type() != c10::DeviceType::CPU) { auto desc = [&]() -> std::string {
return {nullptr}; if (src_fmt == codec_ctx->pix_fmt ||
device.type() != c10::DeviceType::CPU) {
return "null";
} else {
std::stringstream ss;
ss << "format=" << av_get_pix_fmt_name(codec_ctx->pix_fmt);
return ss.str();
} }
std::stringstream desc; }();
desc << "format=" << av_get_pix_fmt_name(codec_ctx->pix_fmt);
auto p = std::make_unique<FilterGraph>(AVMEDIA_TYPE_VIDEO); FilterGraph p{AVMEDIA_TYPE_VIDEO};
p->add_video_src( p.add_video_src(
src_fmt, src_fmt,
codec_ctx->time_base, codec_ctx->time_base,
codec_ctx->width, codec_ctx->width,
codec_ctx->height, codec_ctx->height,
codec_ctx->sample_aspect_ratio); codec_ctx->sample_aspect_ratio);
p->add_sink(); p.add_sink();
p->add_process(desc.str()); p.add_process(desc);
p->create_filter(); p.create_filter();
return p; return p;
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment