Commit 4eac61a3 authored by Moto Hira's avatar Moto Hira Committed by Facebook GitHub Bot
Browse files

Refactor the initialization of EncodeProcess (#3205)

Summary:
Pull Request resolved: https://github.com/pytorch/audio/pull/3205

This commit refactors the initialization of EncodeProcess.

Interface-wise, the signature of the constructor of EncodeProcess
has made simpler just to take rvalues of its components, and the
initialization of the components have been moved to helper functions.

Implementat-wise, the order that the components are initialized is
revised, and the source of initialization parameters is also revised.

For example, the original implementation first creates AVCodecContext,
and passes it around to create the other components. This relied on
an assumption that parameters AVCodecContext has (such as image size
and sample rate) are same as the source data. This is not always right,
and as we will introduce custom filter graph and allow on-the-fly
transform of rates and dimensions, it will become even less correct.

The new initialization constructs source AVFrame, TensorConverter and
FilterGraph from source attributes. This makes it easy to introduce
on-the-fly transform.

Reviewed By: nateanl

Differential Revision: D44360650

fbshipit-source-id: bf0e77dc1a5a40fc8e9870c50d07339d812762e8
parent d8a37a21
...@@ -9,47 +9,50 @@ ...@@ -9,47 +9,50 @@
namespace torchaudio::io { namespace torchaudio::io {
class EncodeProcess { class EncodeProcess {
// In the reverse order of the process
AVCodecContextPtr codec_ctx;
Encoder encoder;
AVFramePtr dst_frame{};
FilterGraph filter;
AVFramePtr src_frame;
TensorConverter converter; TensorConverter converter;
AVFramePtr src_frame;
FilterGraph filter;
AVFramePtr dst_frame{};
Encoder encoder;
AVCodecContextPtr codec_ctx;
public: public:
// Constructor for audio
EncodeProcess( EncodeProcess(
TensorConverter&& converter,
AVFramePtr&& frame,
FilterGraph&& filter_graph,
Encoder&& encoder,
AVCodecContextPtr&& codec_ctx) noexcept;
EncodeProcess(EncodeProcess&&) noexcept = default;
void process(const torch::Tensor& tensor, const c10::optional<double>& pts);
void process_frame(AVFrame* src);
void flush();
};
EncodeProcess get_audio_encode_process(
AVFormatContext* format_ctx, AVFormatContext* format_ctx,
int sample_rate, int sample_rate,
int num_channels, int num_channels,
const enum AVSampleFormat format, const std::string& format,
const c10::optional<std::string>& encoder, const c10::optional<std::string>& encoder,
const c10::optional<OptionDict>& encoder_option, const c10::optional<OptionDict>& encoder_option,
const c10::optional<std::string>& encoder_format, const c10::optional<std::string>& encoder_format,
const c10::optional<EncodingConfig>& config); const c10::optional<EncodingConfig>& config);
// constructor for video EncodeProcess get_video_encode_process(
EncodeProcess(
AVFormatContext* format_ctx, AVFormatContext* format_ctx,
double frame_rate, double frame_rate,
int width, int width,
int height, int height,
const enum AVPixelFormat format, const std::string& format,
const c10::optional<std::string>& encoder, const c10::optional<std::string>& encoder,
const c10::optional<OptionDict>& encoder_option, const c10::optional<OptionDict>& encoder_option,
const c10::optional<std::string>& encoder_format, const c10::optional<std::string>& encoder_format,
const c10::optional<std::string>& hw_accel, const c10::optional<std::string>& hw_accel,
const c10::optional<EncodingConfig>& config); const c10::optional<EncodingConfig>& config);
void process(
AVMediaType type,
const torch::Tensor& tensor,
const c10::optional<double>& pts);
void process_frame(AVFrame* src);
void flush();
};
}; // namespace torchaudio::io }; // namespace torchaudio::io
...@@ -2,28 +2,11 @@ ...@@ -2,28 +2,11 @@
namespace torchaudio::io { namespace torchaudio::io {
namespace { Encoder::Encoder(
AVFormatContext* format_ctx,
AVStream* add_stream(AVFormatContext* format_ctx, AVCodecContext* codec_ctx) { AVCodecContext* codec_ctx,
AVStream* stream = avformat_new_stream(format_ctx, nullptr); AVStream* stream) noexcept
TORCH_CHECK(stream, "Failed to allocate stream."); : format_ctx(format_ctx), codec_ctx(codec_ctx), stream(stream) {}
stream->time_base = codec_ctx->time_base;
int ret = avcodec_parameters_from_context(stream->codecpar, codec_ctx);
TORCH_CHECK(
ret >= 0,
"Failed to copy the stream parameter. (",
av_err2string(ret),
")");
return stream;
}
} // namespace
Encoder::Encoder(AVFormatContext* format_ctx_, AVCodecContext* codec_ctx_)
: format_ctx(format_ctx_),
codec_ctx(codec_ctx_),
stream(add_stream(format_ctx, codec_ctx)) {}
/// ///
/// Encode the given AVFrame data /// Encode the given AVFrame data
......
...@@ -12,16 +12,17 @@ class Encoder { ...@@ -12,16 +12,17 @@ class Encoder {
AVFormatContext* format_ctx; AVFormatContext* format_ctx;
// Reference to codec context (encoder) // Reference to codec context (encoder)
AVCodecContext* codec_ctx; AVCodecContext* codec_ctx;
// Stream object // Stream object as reference. Owned by AVFormatContext.
// Encoder object creates AVStream, but it will be deallocated along with
// AVFormatContext, So Encoder does not own it.
AVStream* stream; AVStream* stream;
// Temporary object used during the encoding // Temporary object used during the encoding
// Encoder owns it. // Encoder owns it.
AVPacketPtr packet{}; AVPacketPtr packet{};
public: public:
Encoder(AVFormatContext* format_ctx, AVCodecContext* codec_ctx); Encoder(
AVFormatContext* format_ctx,
AVCodecContext* codec_ctx,
AVStream* stream) noexcept;
void encode(AVFrame* frame); void encode(AVFrame* frame);
}; };
......
...@@ -53,88 +53,54 @@ StreamWriter::StreamWriter( ...@@ -53,88 +53,54 @@ StreamWriter::StreamWriter(
const c10::optional<std::string>& format) const c10::optional<std::string>& format)
: StreamWriter(get_output_format_context(dst, format, nullptr)) {} : StreamWriter(get_output_format_context(dst, format, nullptr)) {}
namespace {
enum AVSampleFormat get_src_sample_fmt(const std::string& src) {
auto fmt = av_get_sample_fmt(src.c_str());
TORCH_CHECK(fmt != AV_SAMPLE_FMT_NONE, "Unknown sample format: ", src);
TORCH_CHECK(
!av_sample_fmt_is_planar(fmt),
"Unexpected sample fotmat value. Valid values are ",
av_get_sample_fmt_name(AV_SAMPLE_FMT_U8),
", ",
av_get_sample_fmt_name(AV_SAMPLE_FMT_S16),
", ",
av_get_sample_fmt_name(AV_SAMPLE_FMT_S32),
", ",
av_get_sample_fmt_name(AV_SAMPLE_FMT_S64),
", ",
av_get_sample_fmt_name(AV_SAMPLE_FMT_FLT),
", ",
av_get_sample_fmt_name(AV_SAMPLE_FMT_DBL),
". ",
"Found: ",
src);
return fmt;
}
enum AVPixelFormat get_src_pixel_fmt(const std::string& src) {
auto fmt = av_get_pix_fmt(src.c_str());
switch (fmt) {
case AV_PIX_FMT_GRAY8:
case AV_PIX_FMT_RGB24:
case AV_PIX_FMT_BGR24:
case AV_PIX_FMT_YUV444P:
return fmt;
case AV_PIX_FMT_NONE:
TORCH_CHECK(false, "Unknown pixel format: ", src);
default:
TORCH_CHECK(false, "Unsupported pixel format: ", src);
}
}
} // namespace
void StreamWriter::add_audio_stream( void StreamWriter::add_audio_stream(
int64_t sample_rate, int sample_rate,
int64_t num_channels, int num_channels,
const std::string& format, const std::string& format,
const c10::optional<std::string>& encoder, const c10::optional<std::string>& encoder,
const c10::optional<OptionDict>& encoder_option, const c10::optional<OptionDict>& encoder_option,
const c10::optional<std::string>& encoder_format, const c10::optional<std::string>& encoder_format,
const c10::optional<EncodingConfig>& config) { const c10::optional<EncodingConfig>& config) {
processes.emplace_back( TORCH_CHECK(!is_open, "Output is already opened. Cannot add a new stream.");
TORCH_INTERNAL_ASSERT(
pFormatContext->nb_streams == processes.size(),
"The number of encode process and the number of output streams do not match.");
processes.emplace_back(get_audio_encode_process(
pFormatContext, pFormatContext,
sample_rate, sample_rate,
num_channels, num_channels,
get_src_sample_fmt(format), format,
encoder, encoder,
encoder_option, encoder_option,
encoder_format, encoder_format,
config); config));
} }
void StreamWriter::add_video_stream( void StreamWriter::add_video_stream(
double frame_rate, double frame_rate,
int64_t width, int width,
int64_t height, int height,
const std::string& format, const std::string& format,
const c10::optional<std::string>& encoder, const c10::optional<std::string>& encoder,
const c10::optional<OptionDict>& encoder_option, const c10::optional<OptionDict>& encoder_option,
const c10::optional<std::string>& encoder_format, const c10::optional<std::string>& encoder_format,
const c10::optional<std::string>& hw_accel, const c10::optional<std::string>& hw_accel,
const c10::optional<EncodingConfig>& config) { const c10::optional<EncodingConfig>& config) {
processes.emplace_back( TORCH_CHECK(!is_open, "Output is already opened. Cannot add a new stream.");
TORCH_INTERNAL_ASSERT(
pFormatContext->nb_streams == processes.size(),
"The number of encode process and the number of output streams do not match.");
processes.emplace_back(get_video_encode_process(
pFormatContext, pFormatContext,
frame_rate, frame_rate,
width, width,
height, height,
get_src_pixel_fmt(format), format,
encoder, encoder,
encoder_option, encoder_option,
encoder_format, encoder_format,
hw_accel, hw_accel,
config); config));
} }
void StreamWriter::set_metadata(const OptionDict& metadata) { void StreamWriter::set_metadata(const OptionDict& metadata) {
...@@ -149,6 +115,10 @@ void StreamWriter::dump_format(int64_t i) { ...@@ -149,6 +115,10 @@ void StreamWriter::dump_format(int64_t i) {
} }
void StreamWriter::open(const c10::optional<OptionDict>& option) { void StreamWriter::open(const c10::optional<OptionDict>& option) {
TORCH_INTERNAL_ASSERT(
pFormatContext->nb_streams == processes.size(),
"The number of encode process and the number of output streams do not match.");
int ret = 0; int ret = 0;
// Open the file if it was not provided by client code (i.e. when not // Open the file if it was not provided by client code (i.e. when not
...@@ -210,12 +180,17 @@ void StreamWriter::write_audio_chunk( ...@@ -210,12 +180,17 @@ void StreamWriter::write_audio_chunk(
const c10::optional<double>& pts) { const c10::optional<double>& pts) {
TORCH_CHECK(is_open, "Output is not opened. Did you call `open` method?"); TORCH_CHECK(is_open, "Output is not opened. Did you call `open` method?");
TORCH_CHECK( TORCH_CHECK(
0 <= i && i < static_cast<int>(processes.size()), 0 <= i && i < static_cast<int>(pFormatContext->nb_streams),
"Invalid stream index. Index must be in range of [0, ", "Invalid stream index. Index must be in range of [0, ",
processes.size(), pFormatContext->nb_streams,
"). Found: ", "). Found: ",
i); i);
processes[i].process(AVMEDIA_TYPE_AUDIO, waveform, pts); TORCH_CHECK(
pFormatContext->streams[i]->codecpar->codec_type == AVMEDIA_TYPE_AUDIO,
"Stream ",
i,
" is not audio type.");
processes[i].process(waveform, pts);
} }
void StreamWriter::write_video_chunk( void StreamWriter::write_video_chunk(
...@@ -224,12 +199,17 @@ void StreamWriter::write_video_chunk( ...@@ -224,12 +199,17 @@ void StreamWriter::write_video_chunk(
const c10::optional<double>& pts) { const c10::optional<double>& pts) {
TORCH_CHECK(is_open, "Output is not opened. Did you call `open` method?"); TORCH_CHECK(is_open, "Output is not opened. Did you call `open` method?");
TORCH_CHECK( TORCH_CHECK(
0 <= i && i < static_cast<int>(processes.size()), 0 <= i && i < static_cast<int>(pFormatContext->nb_streams),
"Invalid stream index. Index must be in range of [0, ", "Invalid stream index. Index must be in range of [0, ",
processes.size(), pFormatContext->nb_streams,
"). Found: ", "). Found: ",
i); i);
processes[i].process(AVMEDIA_TYPE_VIDEO, frames, pts); TORCH_CHECK(
pFormatContext->streams[i]->codecpar->codec_type == AVMEDIA_TYPE_VIDEO,
"Stream ",
i,
" is not video type.");
processes[i].process(frames, pts);
} }
void StreamWriter::flush() { void StreamWriter::flush() {
......
...@@ -100,8 +100,8 @@ class StreamWriter { ...@@ -100,8 +100,8 @@ class StreamWriter {
/// To list supported formats for the encoder, you can use /// To list supported formats for the encoder, you can use
/// ``ffmpeg -h encoder=<ENCODER>`` command. /// ``ffmpeg -h encoder=<ENCODER>`` command.
void add_audio_stream( void add_audio_stream(
int64_t sample_rate, int sample_rate,
int64_t num_channels, int num_channels,
const std::string& format, const std::string& format,
const c10::optional<std::string>& encoder, const c10::optional<std::string>& encoder,
const c10::optional<OptionDict>& encoder_option, const c10::optional<OptionDict>& encoder_option,
...@@ -139,8 +139,8 @@ class StreamWriter { ...@@ -139,8 +139,8 @@ class StreamWriter {
/// @endparblock /// @endparblock
void add_video_stream( void add_video_stream(
double frame_rate, double frame_rate,
int64_t width, int width,
int64_t height, int height,
const std::string& format, const std::string& format,
const c10::optional<std::string>& encoder, const c10::optional<std::string>& encoder,
const c10::optional<OptionDict>& encoder_option, const c10::optional<OptionDict>& encoder_option,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment