Commit 4eac61a3 authored by Moto Hira's avatar Moto Hira Committed by Facebook GitHub Bot
Browse files

Refactor the initialization of EncodeProcess (#3205)

Summary:
Pull Request resolved: https://github.com/pytorch/audio/pull/3205

This commit refactors the initialization of EncodeProcess.

Interface-wise, the signature of the constructor of EncodeProcess
has made simpler just to take rvalues of its components, and the
initialization of the components have been moved to helper functions.

Implementat-wise, the order that the components are initialized is
revised, and the source of initialization parameters is also revised.

For example, the original implementation first creates AVCodecContext,
and passes it around to create the other components. This relied on
an assumption that parameters AVCodecContext has (such as image size
and sample rate) are same as the source data. This is not always right,
and as we will introduce custom filter graph and allow on-the-fly
transform of rates and dimensions, it will become even less correct.

The new initialization constructs source AVFrame, TensorConverter and
FilterGraph from source attributes. This makes it easy to introduce
on-the-fly transform.

Reviewed By: nateanl

Differential Revision: D44360650

fbshipit-source-id: bf0e77dc1a5a40fc8e9870c50d07339d812762e8
parent d8a37a21
......@@ -9,47 +9,50 @@
namespace torchaudio::io {
class EncodeProcess {
// In the reverse order of the process
AVCodecContextPtr codec_ctx;
Encoder encoder;
AVFramePtr dst_frame{};
FilterGraph filter;
AVFramePtr src_frame;
TensorConverter converter;
AVFramePtr src_frame;
FilterGraph filter;
AVFramePtr dst_frame{};
Encoder encoder;
AVCodecContextPtr codec_ctx;
public:
// Constructor for audio
EncodeProcess(
AVFormatContext* format_ctx,
int sample_rate,
int num_channels,
const enum AVSampleFormat format,
const c10::optional<std::string>& encoder,
const c10::optional<OptionDict>& encoder_option,
const c10::optional<std::string>& encoder_format,
const c10::optional<EncodingConfig>& config);
// constructor for video
EncodeProcess(
AVFormatContext* format_ctx,
double frame_rate,
int width,
int height,
const enum AVPixelFormat format,
const c10::optional<std::string>& encoder,
const c10::optional<OptionDict>& encoder_option,
const c10::optional<std::string>& encoder_format,
const c10::optional<std::string>& hw_accel,
const c10::optional<EncodingConfig>& config);
void process(
AVMediaType type,
const torch::Tensor& tensor,
const c10::optional<double>& pts);
TensorConverter&& converter,
AVFramePtr&& frame,
FilterGraph&& filter_graph,
Encoder&& encoder,
AVCodecContextPtr&& codec_ctx) noexcept;
EncodeProcess(EncodeProcess&&) noexcept = default;
void process(const torch::Tensor& tensor, const c10::optional<double>& pts);
void process_frame(AVFrame* src);
void flush();
};
EncodeProcess get_audio_encode_process(
AVFormatContext* format_ctx,
int sample_rate,
int num_channels,
const std::string& format,
const c10::optional<std::string>& encoder,
const c10::optional<OptionDict>& encoder_option,
const c10::optional<std::string>& encoder_format,
const c10::optional<EncodingConfig>& config);
EncodeProcess get_video_encode_process(
AVFormatContext* format_ctx,
double frame_rate,
int width,
int height,
const std::string& format,
const c10::optional<std::string>& encoder,
const c10::optional<OptionDict>& encoder_option,
const c10::optional<std::string>& encoder_format,
const c10::optional<std::string>& hw_accel,
const c10::optional<EncodingConfig>& config);
}; // namespace torchaudio::io
......@@ -2,28 +2,11 @@
namespace torchaudio::io {
namespace {
AVStream* add_stream(AVFormatContext* format_ctx, AVCodecContext* codec_ctx) {
AVStream* stream = avformat_new_stream(format_ctx, nullptr);
TORCH_CHECK(stream, "Failed to allocate stream.");
stream->time_base = codec_ctx->time_base;
int ret = avcodec_parameters_from_context(stream->codecpar, codec_ctx);
TORCH_CHECK(
ret >= 0,
"Failed to copy the stream parameter. (",
av_err2string(ret),
")");
return stream;
}
} // namespace
Encoder::Encoder(AVFormatContext* format_ctx_, AVCodecContext* codec_ctx_)
: format_ctx(format_ctx_),
codec_ctx(codec_ctx_),
stream(add_stream(format_ctx, codec_ctx)) {}
Encoder::Encoder(
AVFormatContext* format_ctx,
AVCodecContext* codec_ctx,
AVStream* stream) noexcept
: format_ctx(format_ctx), codec_ctx(codec_ctx), stream(stream) {}
///
/// Encode the given AVFrame data
......
......@@ -12,16 +12,17 @@ class Encoder {
AVFormatContext* format_ctx;
// Reference to codec context (encoder)
AVCodecContext* codec_ctx;
// Stream object
// Encoder object creates AVStream, but it will be deallocated along with
// AVFormatContext, So Encoder does not own it.
// Stream object as reference. Owned by AVFormatContext.
AVStream* stream;
// Temporary object used during the encoding
// Encoder owns it.
AVPacketPtr packet{};
public:
Encoder(AVFormatContext* format_ctx, AVCodecContext* codec_ctx);
Encoder(
AVFormatContext* format_ctx,
AVCodecContext* codec_ctx,
AVStream* stream) noexcept;
void encode(AVFrame* frame);
};
......
......@@ -53,88 +53,54 @@ StreamWriter::StreamWriter(
const c10::optional<std::string>& format)
: StreamWriter(get_output_format_context(dst, format, nullptr)) {}
namespace {
enum AVSampleFormat get_src_sample_fmt(const std::string& src) {
auto fmt = av_get_sample_fmt(src.c_str());
TORCH_CHECK(fmt != AV_SAMPLE_FMT_NONE, "Unknown sample format: ", src);
TORCH_CHECK(
!av_sample_fmt_is_planar(fmt),
"Unexpected sample fotmat value. Valid values are ",
av_get_sample_fmt_name(AV_SAMPLE_FMT_U8),
", ",
av_get_sample_fmt_name(AV_SAMPLE_FMT_S16),
", ",
av_get_sample_fmt_name(AV_SAMPLE_FMT_S32),
", ",
av_get_sample_fmt_name(AV_SAMPLE_FMT_S64),
", ",
av_get_sample_fmt_name(AV_SAMPLE_FMT_FLT),
", ",
av_get_sample_fmt_name(AV_SAMPLE_FMT_DBL),
". ",
"Found: ",
src);
return fmt;
}
enum AVPixelFormat get_src_pixel_fmt(const std::string& src) {
auto fmt = av_get_pix_fmt(src.c_str());
switch (fmt) {
case AV_PIX_FMT_GRAY8:
case AV_PIX_FMT_RGB24:
case AV_PIX_FMT_BGR24:
case AV_PIX_FMT_YUV444P:
return fmt;
case AV_PIX_FMT_NONE:
TORCH_CHECK(false, "Unknown pixel format: ", src);
default:
TORCH_CHECK(false, "Unsupported pixel format: ", src);
}
}
} // namespace
void StreamWriter::add_audio_stream(
int64_t sample_rate,
int64_t num_channels,
int sample_rate,
int num_channels,
const std::string& format,
const c10::optional<std::string>& encoder,
const c10::optional<OptionDict>& encoder_option,
const c10::optional<std::string>& encoder_format,
const c10::optional<EncodingConfig>& config) {
processes.emplace_back(
TORCH_CHECK(!is_open, "Output is already opened. Cannot add a new stream.");
TORCH_INTERNAL_ASSERT(
pFormatContext->nb_streams == processes.size(),
"The number of encode process and the number of output streams do not match.");
processes.emplace_back(get_audio_encode_process(
pFormatContext,
sample_rate,
num_channels,
get_src_sample_fmt(format),
format,
encoder,
encoder_option,
encoder_format,
config);
config));
}
void StreamWriter::add_video_stream(
double frame_rate,
int64_t width,
int64_t height,
int width,
int height,
const std::string& format,
const c10::optional<std::string>& encoder,
const c10::optional<OptionDict>& encoder_option,
const c10::optional<std::string>& encoder_format,
const c10::optional<std::string>& hw_accel,
const c10::optional<EncodingConfig>& config) {
processes.emplace_back(
TORCH_CHECK(!is_open, "Output is already opened. Cannot add a new stream.");
TORCH_INTERNAL_ASSERT(
pFormatContext->nb_streams == processes.size(),
"The number of encode process and the number of output streams do not match.");
processes.emplace_back(get_video_encode_process(
pFormatContext,
frame_rate,
width,
height,
get_src_pixel_fmt(format),
format,
encoder,
encoder_option,
encoder_format,
hw_accel,
config);
config));
}
void StreamWriter::set_metadata(const OptionDict& metadata) {
......@@ -149,6 +115,10 @@ void StreamWriter::dump_format(int64_t i) {
}
void StreamWriter::open(const c10::optional<OptionDict>& option) {
TORCH_INTERNAL_ASSERT(
pFormatContext->nb_streams == processes.size(),
"The number of encode process and the number of output streams do not match.");
int ret = 0;
// Open the file if it was not provided by client code (i.e. when not
......@@ -210,12 +180,17 @@ void StreamWriter::write_audio_chunk(
const c10::optional<double>& pts) {
TORCH_CHECK(is_open, "Output is not opened. Did you call `open` method?");
TORCH_CHECK(
0 <= i && i < static_cast<int>(processes.size()),
0 <= i && i < static_cast<int>(pFormatContext->nb_streams),
"Invalid stream index. Index must be in range of [0, ",
processes.size(),
pFormatContext->nb_streams,
"). Found: ",
i);
processes[i].process(AVMEDIA_TYPE_AUDIO, waveform, pts);
TORCH_CHECK(
pFormatContext->streams[i]->codecpar->codec_type == AVMEDIA_TYPE_AUDIO,
"Stream ",
i,
" is not audio type.");
processes[i].process(waveform, pts);
}
void StreamWriter::write_video_chunk(
......@@ -224,12 +199,17 @@ void StreamWriter::write_video_chunk(
const c10::optional<double>& pts) {
TORCH_CHECK(is_open, "Output is not opened. Did you call `open` method?");
TORCH_CHECK(
0 <= i && i < static_cast<int>(processes.size()),
0 <= i && i < static_cast<int>(pFormatContext->nb_streams),
"Invalid stream index. Index must be in range of [0, ",
processes.size(),
pFormatContext->nb_streams,
"). Found: ",
i);
processes[i].process(AVMEDIA_TYPE_VIDEO, frames, pts);
TORCH_CHECK(
pFormatContext->streams[i]->codecpar->codec_type == AVMEDIA_TYPE_VIDEO,
"Stream ",
i,
" is not video type.");
processes[i].process(frames, pts);
}
void StreamWriter::flush() {
......
......@@ -100,8 +100,8 @@ class StreamWriter {
/// To list supported formats for the encoder, you can use
/// ``ffmpeg -h encoder=<ENCODER>`` command.
void add_audio_stream(
int64_t sample_rate,
int64_t num_channels,
int sample_rate,
int num_channels,
const std::string& format,
const c10::optional<std::string>& encoder,
const c10::optional<OptionDict>& encoder_option,
......@@ -139,8 +139,8 @@ class StreamWriter {
/// @endparblock
void add_video_stream(
double frame_rate,
int64_t width,
int64_t height,
int width,
int height,
const std::string& format,
const c10::optional<std::string>& encoder,
const c10::optional<OptionDict>& encoder_option,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment