Commit 1b648626 authored by moto's avatar moto Committed by Facebook GitHub Bot
Browse files

Support encode spec change in StreamWriter (#3207)

Summary:
This commit adds support for changing the spec of media
(such as sample rate, #channels, image size and frame rate)
on-the-fly at encoding time.

The motivation behind this addition is that certain media
formats support only limited number of spec, and it is
cumbersome to require client code to change the spec
every time.

For example, OPUS supports only 48kHz sampling rate, and
vorbis only supports stereo.

To make it easy to work with media of different formats,
this commit makes it so that anything that's not compatible
with the format is automatically converted, and allows
users to specify the override.

Notable implementation detail is that, for sample format and
pixel format, the default value of encoder has higher precedent
to source value, while for other attributes like sample rate and
#channels, the source value has higher precedent as long as
they are supported.

Pull Request resolved: https://github.com/pytorch/audio/pull/3207

Reviewed By: nateanl

Differential Revision: D44439622

Pulled By: mthrok

fbshipit-source-id: 09524f201d485d201150481884a3e9e4d2aab081
parent 4bc4ca75
......@@ -594,8 +594,10 @@ class StreamWriterCorrectnessTest(TempDirMixin, TorchaudioTestCase):
def test_filter_graph_video(self):
"""Can apply additional effect with filter graph"""
rate = 30
src_rate = 30
num_frames, width, height = 400, 160, 90
filter_desc = "framestep=2"
enc_rate = 15
ext = "mp4"
filename = f"test.{ext}"
......@@ -603,7 +605,15 @@ class StreamWriterCorrectnessTest(TempDirMixin, TorchaudioTestCase):
dst = self.get_temp_path(filename)
w = StreamWriter(dst, format=ext)
w.add_video_stream(frame_rate=rate, format="rgb24", height=height, width=width, filter_desc="framestep=2")
w.add_video_stream(
frame_rate=src_rate,
format="rgb24",
height=height,
width=width,
filter_desc=filter_desc,
encoder_format="yuv420p",
encoder_frame_rate=enc_rate,
)
with w.open():
w.write_video_chunk(0, original)
......@@ -614,3 +624,129 @@ class StreamWriterCorrectnessTest(TempDirMixin, TorchaudioTestCase):
(output,) = reader.pop_chunks()
self.assertEqual(output.shape, [num_frames // 2, 3, height, width])
@parameterized.expand(
[
("wav", "pcm_s16le", 8000, 16000, 1, 2),
("wav", "pcm_s16le", 8000, 16000, 2, 1),
("wav", "pcm_s16le", 8000, 16000, 2, 4),
("wav", "pcm_s16le", 16000, 8000, 1, 2),
("wav", "pcm_s16le", 16000, 8000, 2, 1),
("wav", "pcm_s16le", 16000, 8000, 2, 4),
("wav", "pcm_f32le", 8000, 16000, 1, 2),
("wav", "pcm_f32le", 8000, 16000, 2, 1),
("wav", "pcm_f32le", 8000, 16000, 2, 4),
("wav", "pcm_f32le", 16000, 8000, 1, 2),
("wav", "pcm_f32le", 16000, 8000, 2, 1),
("wav", "pcm_f32le", 16000, 8000, 2, 4),
("ogg", "opus", 8000, 48000, 1, 2),
("ogg", "opus", 8000, 48000, 2, 1),
("ogg", "flac", 8000, 41000, 1, 2),
("ogg", "flac", 8000, 41000, 2, 1),
("ogg", "vorbis", 16000, 8000, 1, 2),
("ogg", "vorbis", 16000, 8000, 4, 2),
]
)
def test_change_audio_encoder_spec(self, ext, encoder, src_sr, enc_sr, src_num_channels, enc_num_channels):
"""Can change sample rate and channels on-the-fly"""
filename = f"test.{ext}"
original = get_sinusoid(sample_rate=src_sr, n_channels=src_num_channels, channels_first=False, duration=0.1)
dst = self.get_temp_path(filename)
w = StreamWriter(dst, format=ext)
w.add_audio_stream(
sample_rate=src_sr,
format="flt",
num_channels=src_num_channels,
encoder=encoder,
encoder_sample_rate=enc_sr,
encoder_num_channels=enc_num_channels,
)
with w.open():
w.write_audio_chunk(0, original)
# check
reader = torchaudio.io.StreamReader(src=self.get_temp_path(filename))
i = reader.get_src_stream_info(0)
self.assertEqual(i.sample_rate, enc_sr)
self.assertEqual(i.num_channels, enc_num_channels)
@parameterized.expand(
[
# opus only supports 48kHz
("ogg", "opus", 8000, 48000, 1, 1),
("ogg", "opus", 16000, 48000, 2, 2),
# vorbis only supports 2 channels
("ogg", "vorbis", 16000, 16000, 1, 2),
("ogg", "vorbis", 16000, 16000, 2, 2),
("ogg", "vorbis", 16000, 16000, 4, 2),
]
)
def test_change_encoder_spec_default(
self, ext, encoder, src_sr, expected_sr, src_num_channels, expected_num_channels
):
"""If input rate/channels are not supported, encoder picks supported one automatically."""
filename = f"test.{ext}"
original = get_sinusoid(sample_rate=src_sr, n_channels=src_num_channels, channels_first=False, duration=0.1)
dst = self.get_temp_path(filename)
w = StreamWriter(dst, format=ext)
w.add_audio_stream(
sample_rate=src_sr,
format="flt",
num_channels=src_num_channels,
encoder=encoder,
)
with w.open():
w.write_audio_chunk(0, original)
# check
reader = torchaudio.io.StreamReader(src=self.get_temp_path(filename))
i = reader.get_src_stream_info(0)
self.assertEqual(i.sample_rate, expected_sr)
self.assertEqual(i.num_channels, expected_num_channels)
@parameterized.expand(
[
("mp4", None, 10, 30, (100, 160), (200, 320)),
("mp4", None, 10, 30, (100, 160), (50, 80)),
("mp4", None, 30, 10, (100, 160), (200, 320)),
("mp4", None, 30, 10, (100, 160), (50, 80)),
]
)
def test_change_video_encoder_spec(self, ext, encoder, src_rate, enc_rate, src_size, enc_size):
"""Can change the frame rate and image size on-the-fly"""
width, height = src_size
enc_width, enc_height = enc_size
ext = "mp4"
filename = f"test.{ext}"
num_frames = 256
original = torch.zeros((num_frames, 3, height, width), dtype=torch.uint8)
dst = self.get_temp_path(filename)
w = StreamWriter(dst, format=ext)
w.add_video_stream(
frame_rate=src_rate,
format="rgb24",
height=height,
width=width,
encoder_format="yuv420p",
encoder_frame_rate=enc_rate,
encoder_width=enc_width,
encoder_height=enc_height,
)
with w.open():
w.write_video_chunk(0, original)
# check
reader = torchaudio.io.StreamReader(src=self.get_temp_path(filename))
i = reader.get_src_stream_info(0)
self.assertEqual(i.frame_rate, enc_rate)
self.assertEqual(i.width, enc_width)
self.assertEqual(i.height, enc_height)
#include <torchaudio/csrc/ffmpeg/hw_context.h>
#include <torchaudio/csrc/ffmpeg/stream_writer/encode_process.h>
#include <cmath>
namespace torchaudio::io {
......@@ -23,8 +24,13 @@ void EncodeProcess::process(
const torch::Tensor& tensor,
const c10::optional<double>& pts) {
if (pts) {
const double& pts_val = pts.value();
TORCH_CHECK(
std::isfinite(pts_val) && pts_val >= 0.0,
"The value of PTS must be positive and finite. Found: ",
pts_val)
AVRational tb = codec_ctx->time_base;
auto val = static_cast<int64_t>(std::round(pts.value() * tb.den / tb.num));
auto val = static_cast<int64_t>(std::round(pts_val * tb.den / tb.num));
if (src_frame->pts > val) {
TORCH_WARN_ONCE(
"The provided PTS value is smaller than the next expected value.");
......@@ -64,7 +70,7 @@ void EncodeProcess::flush() {
namespace {
enum AVSampleFormat get_sample_fmt(const std::string& src) {
enum AVSampleFormat get_src_sample_fmt(const std::string& src) {
auto fmt = av_get_sample_fmt(src.c_str());
if (fmt != AV_SAMPLE_FMT_NONE && !av_sample_fmt_is_planar(fmt)) {
return fmt;
......@@ -90,7 +96,7 @@ enum AVSampleFormat get_sample_fmt(const std::string& src) {
".");
}
enum AVPixelFormat get_pix_fmt(const std::string& src) {
enum AVPixelFormat get_src_pix_fmt(const std::string& src) {
AVPixelFormat fmt = av_get_pix_fmt(src.c_str());
switch (fmt) {
case AV_PIX_FMT_GRAY8:
......@@ -205,14 +211,13 @@ bool supported_sample_fmt(
return false;
}
std::vector<std::string> get_supported_formats(
const AVSampleFormat* sample_fmts) {
std::string get_supported_formats(const AVSampleFormat* sample_fmts) {
std::vector<std::string> ret;
while (*sample_fmts != AV_SAMPLE_FMT_NONE) {
ret.emplace_back(av_get_sample_fmt_name(*sample_fmts));
++sample_fmts;
}
return ret;
return c10::Join(", ", ret);
}
AVSampleFormat get_enc_fmt(
......@@ -230,7 +235,7 @@ AVSampleFormat get_enc_fmt(
" does not support ",
encoder_format.value(),
" format. Supported values are; ",
c10::Join(", ", get_supported_formats(codec->sample_fmts)));
get_supported_formats(codec->sample_fmts));
return fmt;
}
if (codec->sample_fmts) {
......@@ -239,22 +244,21 @@ AVSampleFormat get_enc_fmt(
return src_fmt;
};
bool supported_sample_rate(
const int sample_rate,
const int* supported_samplerates) {
if (!supported_samplerates) {
bool supported_sample_rate(const int sample_rate, const AVCodec* codec) {
if (!codec->supported_samplerates) {
return true;
}
while (*supported_samplerates) {
if (sample_rate == *supported_samplerates) {
const int* it = codec->supported_samplerates;
while (*it) {
if (sample_rate == *it) {
return true;
}
++supported_samplerates;
++it;
}
return false;
}
std::vector<int> get_supported_samplerates(const int* supported_samplerates) {
std::string get_supported_samplerates(const int* supported_samplerates) {
std::vector<int> ret;
if (supported_samplerates) {
while (*supported_samplerates) {
......@@ -262,59 +266,99 @@ std::vector<int> get_supported_samplerates(const int* supported_samplerates) {
++supported_samplerates;
}
}
return ret;
return c10::Join(", ", ret);
}
void validate_sample_rate(int sample_rate, const AVCodec* codec) {
TORCH_CHECK(
supported_sample_rate(sample_rate, codec->supported_samplerates),
codec->name,
" does not support sample rate ",
sample_rate,
". Supported values are; ",
c10::Join(", ", get_supported_samplerates(codec->supported_samplerates)));
int get_enc_sr(
int src_sample_rate,
const c10::optional<int>& encoder_sample_rate,
const AVCodec* codec) {
if (encoder_sample_rate) {
const int& encoder_sr = encoder_sample_rate.value();
TORCH_CHECK(
encoder_sr > 0,
"Encoder sample rate must be positive. Found: ",
encoder_sr);
TORCH_CHECK(
supported_sample_rate(encoder_sr, codec),
codec->name,
" does not support sample rate ",
encoder_sr,
". Supported values are; ",
get_supported_samplerates(codec->supported_samplerates));
return encoder_sr;
}
if (codec->supported_samplerates &&
!supported_sample_rate(src_sample_rate, codec)) {
return codec->supported_samplerates[0];
}
return src_sample_rate;
}
std::vector<std::string> get_supported_channels(
const uint64_t* channel_layouts) {
std::vector<std::string> ret;
std::string get_supported_channels(const uint64_t* channel_layouts) {
std::vector<std::string> names;
while (*channel_layouts) {
ret.emplace_back(av_get_channel_name(*channel_layouts));
std::stringstream ss;
ss << av_get_channel_layout_nb_channels(*channel_layouts);
ss << " (" << av_get_channel_name(*channel_layouts) << ")";
names.emplace_back(ss.str());
++channel_layouts;
}
return ret;
return c10::Join(", ", names);
}
uint64_t get_channel_layout(int num_channels, const AVCodec* codec) {
uint64_t get_channel_layout(
const uint64_t src_ch_layout,
const c10::optional<int> enc_num_channels,
const AVCodec* codec) {
// If the override is presented, and if it is supported by codec, we use it.
if (enc_num_channels) {
const int& val = enc_num_channels.value();
TORCH_CHECK(
val > 0, "The number of channels must be greater than 0. Found: ", val);
if (!codec->channel_layouts) {
return static_cast<uint64_t>(av_get_default_channel_layout(val));
}
for (const uint64_t* it = codec->channel_layouts; *it; ++it) {
if (av_get_channel_layout_nb_channels(*it) == val) {
return *it;
}
}
TORCH_CHECK(
false,
"Codec ",
codec->name,
" does not support a channel layout consists of ",
val,
" channels. Supported values are: ",
get_supported_channels(codec->channel_layouts));
}
// If the codec does not have restriction on channel layout, we reuse the
// source channel layout
if (!codec->channel_layouts) {
return static_cast<uint64_t>(av_get_default_channel_layout(num_channels));
return src_ch_layout;
}
// If the codec has restriction, and source layout is supported, we reuse the
// source channel layout
for (const uint64_t* it = codec->channel_layouts; *it; ++it) {
if (av_get_channel_layout_nb_channels(*it) == num_channels) {
return *it;
if (*it == src_ch_layout) {
return src_ch_layout;
}
}
TORCH_CHECK(
false,
"Codec ",
codec->name,
" does not support a channel layout consists of ",
num_channels,
" channels. Supported values are: ",
c10::Join(", ", get_supported_channels(codec->channel_layouts)));
// Use the default layout of the codec.
return codec->channel_layouts[0];
}
void configure_audio_codec_ctx(
AVCodecContext* codec_ctx,
AVSampleFormat format,
int sample_rate,
int num_channels,
uint64_t channel_layout,
const c10::optional<CodecConfig>& codec_config) {
codec_ctx->sample_fmt = format;
codec_ctx->sample_rate = sample_rate;
codec_ctx->time_base = av_inv_q(av_d2q(sample_rate, 1 << 24));
codec_ctx->channels = num_channels;
codec_ctx->channels = av_get_channel_layout_nb_channels(channel_layout);
codec_ctx->channel_layout = channel_layout;
// Set optional stuff
......@@ -346,13 +390,13 @@ bool supported_pix_fmt(const AVPixelFormat fmt, const AVPixelFormat* pix_fmts) {
return false;
}
std::vector<std::string> get_supported_formats(const AVPixelFormat* pix_fmts) {
std::string get_supported_formats(const AVPixelFormat* pix_fmts) {
std::vector<std::string> ret;
while (*pix_fmts != AV_PIX_FMT_NONE) {
ret.emplace_back(av_get_pix_fmt_name(*pix_fmts));
++pix_fmts;
}
return ret;
return c10::Join(", ", ret);
}
AVPixelFormat get_enc_fmt(
......@@ -360,14 +404,15 @@ AVPixelFormat get_enc_fmt(
const c10::optional<std::string>& encoder_format,
const AVCodec* codec) {
if (encoder_format) {
auto fmt = get_pix_fmt(encoder_format.value());
const auto& val = encoder_format.value();
auto fmt = av_get_pix_fmt(val.c_str());
TORCH_CHECK(
supported_pix_fmt(fmt, codec->pix_fmts),
codec->name,
" does not support ",
encoder_format.value(),
val,
" format. Supported values are; ",
c10::Join(", ", get_supported_formats(codec->pix_fmts)));
get_supported_formats(codec->pix_fmts));
return fmt;
}
if (codec->pix_fmts) {
......@@ -388,22 +433,39 @@ bool supported_frame_rate(AVRational rate, const AVRational* rates) {
return false;
}
void validate_frame_rate(AVRational rate, const AVCodec* codec) {
TORCH_CHECK(
supported_frame_rate(rate, codec->supported_framerates),
codec->name,
" does not support frame rate ",
c10::Join("/", std::array<int, 2>{rate.num, rate.den}),
". Supported values are; ",
[&]() {
std::vector<std::string> ret;
for (auto r = codec->supported_framerates;
!(r->num == 0 && r->den == 0);
++r) {
ret.push_back(c10::Join("/", std::array<int, 2>{r->num, r->den}));
}
return c10::Join(", ", ret);
}());
AVRational get_enc_rate(
AVRational src_rate,
const c10::optional<double>& encoder_sample_rate,
const AVCodec* codec) {
if (encoder_sample_rate) {
const double& enc_rate = encoder_sample_rate.value();
TORCH_CHECK(
std::isfinite(enc_rate) && enc_rate > 0,
"Encoder sample rate must be positive and fininte. Found: ",
enc_rate);
AVRational rate = av_d2q(enc_rate, 1 << 24);
TORCH_CHECK(
supported_frame_rate(rate, codec->supported_framerates),
codec->name,
" does not support frame rate: ",
enc_rate,
". Supported values are; ",
[&]() {
std::vector<std::string> ret;
for (auto r = codec->supported_framerates;
!(r->num == 0 && r->den == 0);
++r) {
ret.push_back(c10::Join("/", std::array<int, 2>{r->num, r->den}));
}
return c10::Join(", ", ret);
}());
return rate;
}
if (codec->supported_framerates &&
!supported_frame_rate(src_rate, codec->supported_framerates)) {
return codec->supported_framerates[0];
}
return src_rate;
}
void configure_video_codec_ctx(
......@@ -506,38 +568,40 @@ AVStream* get_stream(AVFormatContext* format_ctx, AVCodecContext* codec_ctx) {
FilterGraph get_audio_filter_graph(
AVSampleFormat src_fmt,
int sample_rate,
uint64_t channel_layout,
int src_sample_rate,
uint64_t src_ch_layout,
const c10::optional<std::string>& filter_desc,
AVSampleFormat enc_fmt,
int enc_sample_rate,
uint64_t enc_ch_layout,
int nb_samples) {
const std::string desc = [&]() -> const std::string {
if (src_fmt == enc_fmt) {
if (nb_samples == 0) {
return filter_desc.value_or("anull");
} else {
std::stringstream ss;
if (filter_desc) {
ss << filter_desc.value() << ",";
}
ss << "asetnsamples=n=" << nb_samples << ":p=0";
return ss.str();
}
} else {
const auto desc = [&]() -> const std::string {
std::vector<std::string> parts;
if (filter_desc) {
parts.push_back(filter_desc.value());
}
if (filter_desc || src_fmt != enc_fmt ||
src_sample_rate != enc_sample_rate || src_ch_layout != enc_ch_layout) {
std::stringstream ss;
if (filter_desc) {
ss << filter_desc.value() << ",";
}
ss << "aformat=" << av_get_sample_fmt_name(enc_fmt);
if (nb_samples > 0) {
ss << ",asetnsamples=n=" << nb_samples << ":p=0";
}
return ss.str();
ss << "aformat=sample_fmts=" << av_get_sample_fmt_name(enc_fmt)
<< ":sample_rates=" << enc_sample_rate << ":channel_layouts=0x"
<< std::hex << enc_ch_layout;
parts.push_back(ss.str());
}
if (nb_samples > 0) {
std::stringstream ss;
ss << "asetnsamples=n=" << nb_samples << ":p=0";
parts.push_back(ss.str());
}
if (parts.size()) {
return c10::Join(",", parts);
}
return "anull";
}();
FilterGraph f{AVMEDIA_TYPE_AUDIO};
f.add_audio_src(src_fmt, {1, sample_rate}, sample_rate, channel_layout);
f.add_audio_src(
src_fmt, {1, src_sample_rate}, src_sample_rate, src_ch_layout);
f.add_sink();
f.add_process(desc);
f.create_filter();
......@@ -546,27 +610,48 @@ FilterGraph get_audio_filter_graph(
FilterGraph get_video_filter_graph(
AVPixelFormat src_fmt,
AVRational rate,
int width,
int height,
AVRational src_rate,
int src_width,
int src_height,
const c10::optional<std::string>& filter_desc,
AVPixelFormat enc_fmt,
AVRational enc_rate,
int enc_width,
int enc_height,
bool is_cuda) {
auto desc = [&]() -> std::string {
if (src_fmt == enc_fmt || is_cuda) {
const auto desc = [&]() -> const std::string {
if (is_cuda) {
return filter_desc.value_or("null");
} else {
}
std::vector<std::string> parts;
if (filter_desc) {
parts.push_back(filter_desc.value());
}
if (filter_desc || (src_width != enc_width || src_height != enc_height)) {
std::stringstream ss;
ss << "scale=" << enc_width << ":" << enc_height;
parts.emplace_back(ss.str());
}
if (filter_desc || src_fmt != enc_fmt) {
std::stringstream ss;
if (filter_desc) {
ss << filter_desc.value() << ",";
}
ss << "format=" << av_get_pix_fmt_name(enc_fmt);
return ss.str();
parts.emplace_back(ss.str());
}
if (filter_desc ||
(src_rate.num != enc_rate.num || src_rate.den != enc_rate.den)) {
std::stringstream ss;
ss << "fps=" << enc_rate.num << "/" << enc_rate.den;
parts.emplace_back(ss.str());
}
if (parts.size()) {
return c10::Join(",", parts);
}
return "null";
}();
FilterGraph f{AVMEDIA_TYPE_VIDEO};
f.add_video_src(src_fmt, av_inv_q(rate), rate, width, height, {1, 1});
f.add_video_src(
src_fmt, av_inv_q(src_rate), src_rate, src_width, src_height, {1, 1});
f.add_sink();
f.add_process(desc);
f.create_filter();
......@@ -587,7 +672,7 @@ AVFramePtr get_audio_frame(
frame->format = format;
frame->channel_layout = channel_layout;
frame->sample_rate = sample_rate;
frame->nb_samples = nb_samples ? nb_samples : 1024;
frame->nb_samples = nb_samples;
int ret = av_frame_get_buffer(frame, 0);
TORCH_CHECK(
ret >= 0, "Error allocating the source audio frame:", av_err2string(ret));
......@@ -630,10 +715,11 @@ EncodeProcess get_audio_encode_process(
const c10::optional<std::string>& encoder,
const c10::optional<OptionDict>& encoder_option,
const c10::optional<std::string>& encoder_format,
const c10::optional<int>& encoder_sample_rate,
const c10::optional<int>& encoder_num_channels,
const c10::optional<CodecConfig>& codec_config,
const c10::optional<std::string>& filter_desc) {
// 1. Check the source format, rate and channels
const AVSampleFormat src_fmt = get_sample_fmt(format);
TORCH_CHECK(
src_sample_rate > 0,
"Sample rate must be positive. Found: ",
......@@ -642,6 +728,9 @@ EncodeProcess get_audio_encode_process(
src_num_channels > 0,
"The number of channels must be positive. Found: ",
src_num_channels);
const AVSampleFormat src_fmt = get_src_sample_fmt(format);
const auto src_ch_layout =
static_cast<uint64_t>(av_get_default_channel_layout(src_num_channels));
// 2. Fetch codec from default or override
TORCH_CHECK(
......@@ -651,30 +740,37 @@ EncodeProcess get_audio_encode_process(
const AVCodec* codec = get_codec(format_ctx->oformat->audio_codec, encoder);
// 3. Check that encoding sample format, sample rate and channels
// TODO: introduce encoder_sampel_rate option and allow to change sample rate
const AVSampleFormat enc_fmt = get_enc_fmt(src_fmt, encoder_format, codec);
validate_sample_rate(src_sample_rate, codec);
uint64_t channel_layout = get_channel_layout(src_num_channels, codec);
const int enc_sr = get_enc_sr(src_sample_rate, encoder_sample_rate, codec);
const uint64_t enc_ch_layout = [&]() -> uint64_t {
if (std::strcmp(codec->name, "vorbis") == 0) {
// Special case for vorbis.
// It only supports 2 channels, but it is not listed in channel_layouts
// attributes.
// https://github.com/FFmpeg/FFmpeg/blob/0684e58886881a998f1a7b510d73600ff1df2b90/libavcodec/vorbisenc.c#L1277
// This is the case for at least until FFmpeg 6.0, so it will be
// like this for a while.
return static_cast<uint64_t>(av_get_default_channel_layout(2));
}
return get_channel_layout(src_ch_layout, encoder_num_channels, codec);
}();
// 4. Initialize codec context
AVCodecContextPtr codec_ctx =
get_codec_ctx(codec, format_ctx->oformat->flags);
configure_audio_codec_ctx(
codec_ctx,
enc_fmt,
src_sample_rate,
src_num_channels,
channel_layout,
codec_config);
codec_ctx, enc_fmt, enc_sr, enc_ch_layout, codec_config);
open_codec(codec_ctx, encoder_option);
// 5. Build filter graph
FilterGraph filter_graph = get_audio_filter_graph(
src_fmt,
src_sample_rate,
channel_layout,
src_ch_layout,
filter_desc,
enc_fmt,
enc_sr,
enc_ch_layout,
codec_ctx->frame_size);
// 6. Instantiate source frame
......@@ -682,8 +778,8 @@ EncodeProcess get_audio_encode_process(
src_fmt,
src_sample_rate,
src_num_channels,
channel_layout,
codec_ctx->frame_size);
src_ch_layout,
codec_ctx->frame_size > 0 ? codec_ctx->frame_size : 256);
// 7. Instantiate Converter
TensorConverter converter{
......@@ -712,18 +808,21 @@ EncodeProcess get_video_encode_process(
const c10::optional<std::string>& encoder,
const c10::optional<OptionDict>& encoder_option,
const c10::optional<std::string>& encoder_format,
const c10::optional<double>& encoder_frame_rate,
const c10::optional<int>& encoder_width,
const c10::optional<int>& encoder_height,
const c10::optional<std::string>& hw_accel,
const c10::optional<CodecConfig>& codec_config,
const c10::optional<std::string>& filter_desc) {
// 1. Checkc the source format, rate and resolution
const AVPixelFormat src_fmt = get_pix_fmt(format);
AVRational src_rate = av_d2q(frame_rate, 1 << 24);
TORCH_CHECK(
src_rate.num > 0 && src_rate.den != 0,
std::isfinite(frame_rate) && frame_rate > 0,
"Frame rate must be positive and finite. Found: ",
frame_rate);
TORCH_CHECK(src_width > 0, "width must be positive. Found: ", src_width);
TORCH_CHECK(src_height > 0, "height must be positive. Found: ", src_height);
const AVPixelFormat src_fmt = get_src_pix_fmt(format);
const AVRational src_rate = av_d2q(frame_rate, 1 << 24);
// 2. Fetch codec from default or override
TORCH_CHECK(
......@@ -734,13 +833,29 @@ EncodeProcess get_video_encode_process(
// 3. Check that encoding format, rate
const AVPixelFormat enc_fmt = get_enc_fmt(src_fmt, encoder_format, codec);
validate_frame_rate(src_rate, codec);
const AVRational enc_rate = get_enc_rate(src_rate, encoder_frame_rate, codec);
const int enc_width = [&]() -> int {
if (!encoder_width) {
return src_width;
}
const int& val = encoder_width.value();
TORCH_CHECK(val > 0, "Encoder width must be positive. Found: ", val);
return val;
}();
const int enc_height = [&]() -> int {
if (!encoder_height) {
return src_height;
}
const int& val = encoder_height.value();
TORCH_CHECK(val > 0, "Encoder height must be positive. Found: ", val);
return val;
}();
// 4. Initialize codec context
AVCodecContextPtr codec_ctx =
get_codec_ctx(codec, format_ctx->oformat->flags);
configure_video_codec_ctx(
codec_ctx, enc_fmt, src_rate, src_width, src_height, codec_config);
codec_ctx, enc_fmt, enc_rate, enc_width, enc_height, codec_config);
if (hw_accel) {
#ifdef USE_CUDA
configure_hw_accel(codec_ctx, hw_accel.value());
......@@ -761,6 +876,9 @@ EncodeProcess get_video_encode_process(
src_height,
filter_desc,
enc_fmt,
enc_rate,
enc_width,
enc_height,
hw_accel.has_value());
// 6. Instantiate source frame
......
......@@ -28,9 +28,10 @@ class EncodeProcess {
void process(const torch::Tensor& tensor, const c10::optional<double>& pts);
void process_frame(AVFrame* src);
void flush();
private:
void process_frame(AVFrame* src);
};
EncodeProcess get_audio_encode_process(
......@@ -41,6 +42,8 @@ EncodeProcess get_audio_encode_process(
const c10::optional<std::string>& encoder,
const c10::optional<OptionDict>& encoder_option,
const c10::optional<std::string>& encoder_format,
const c10::optional<int>& encoder_sample_rate,
const c10::optional<int>& encoder_num_channels,
const c10::optional<CodecConfig>& codec_config,
const c10::optional<std::string>& filter_desc);
......@@ -53,6 +56,9 @@ EncodeProcess get_video_encode_process(
const c10::optional<std::string>& encoder,
const c10::optional<OptionDict>& encoder_option,
const c10::optional<std::string>& encoder_format,
const c10::optional<double>& encoder_frame_rate,
const c10::optional<int>& encoder_width,
const c10::optional<int>& encoder_height,
const c10::optional<std::string>& hw_accel,
const c10::optional<CodecConfig>& codec_config,
const c10::optional<std::string>& filter_desc);
......
......@@ -60,6 +60,8 @@ void StreamWriter::add_audio_stream(
const c10::optional<std::string>& encoder,
const c10::optional<OptionDict>& encoder_option,
const c10::optional<std::string>& encoder_format,
const c10::optional<int>& encoder_sample_rate,
const c10::optional<int>& encoder_num_channels,
const c10::optional<CodecConfig>& codec_config,
const c10::optional<std::string>& filter_desc) {
TORCH_CHECK(!is_open, "Output is already opened. Cannot add a new stream.");
......@@ -74,6 +76,8 @@ void StreamWriter::add_audio_stream(
encoder,
encoder_option,
encoder_format,
encoder_sample_rate,
encoder_num_channels,
codec_config,
filter_desc));
}
......@@ -86,6 +90,9 @@ void StreamWriter::add_video_stream(
const c10::optional<std::string>& encoder,
const c10::optional<OptionDict>& encoder_option,
const c10::optional<std::string>& encoder_format,
const c10::optional<double>& encoder_frame_rate,
const c10::optional<int>& encoder_width,
const c10::optional<int>& encoder_height,
const c10::optional<std::string>& hw_accel,
const c10::optional<CodecConfig>& codec_config,
const c10::optional<std::string>& filter_desc) {
......@@ -102,6 +109,9 @@ void StreamWriter::add_video_stream(
encoder,
encoder_option,
encoder_format,
encoder_frame_rate,
encoder_width,
encoder_height,
hw_accel,
codec_config,
filter_desc));
......
......@@ -109,6 +109,8 @@ class StreamWriter {
const c10::optional<std::string>& encoder = c10::nullopt,
const c10::optional<OptionDict>& encoder_option = c10::nullopt,
const c10::optional<std::string>& encoder_format = c10::nullopt,
const c10::optional<int>& encoder_sample_rate = c10::nullopt,
const c10::optional<int>& encoder_num_channels = c10::nullopt,
const c10::optional<CodecConfig>& codec_config = c10::nullopt,
const c10::optional<std::string>& filter_desc = c10::nullopt);
......@@ -152,6 +154,9 @@ class StreamWriter {
const c10::optional<std::string>& encoder = c10::nullopt,
const c10::optional<OptionDict>& encoder_option = c10::nullopt,
const c10::optional<std::string>& encoder_format = c10::nullopt,
const c10::optional<double>& encoder_frame_rate = c10::nullopt,
const c10::optional<int>& encoder_width = c10::nullopt,
const c10::optional<int>& encoder_height = c10::nullopt,
const c10::optional<std::string>& hw_accel = c10::nullopt,
const c10::optional<CodecConfig>& codec_config = c10::nullopt,
const c10::optional<std::string>& filter_desc = c10::nullopt);
......
......@@ -45,7 +45,18 @@ _encoder_format = """Format used to encode media.
To list supported formats for the encoder, you can use
``ffmpeg -h encoder=<ENCODER>`` command.
Default: ``None``."""
Default: ``None``.
Note:
When ``encoder_format`` option is not provided, encoder uses its default format.
For example, when encoding audio into wav format, 16-bit signed integer is used,
and when encoding video into mp4 format (h264 encoder), one of YUV format is used.
This is because typically, 32-bit or 16-bit floating point is used in audio models but
they are not commonly used in audio formats. Similarly, RGB24 is commonly used in vision
models, but video formats usually (and better) support YUV formats.
"""
_codec_config = """Codec configuration. Please refer to :py:class:`CodecConfig` for
configuration options.
......@@ -162,6 +173,8 @@ class StreamWriter:
encoder: Optional[str] = None,
encoder_option: Optional[Dict[str, str]] = None,
encoder_format: Optional[str] = None,
encoder_sample_rate: Optional[int] = None,
encoder_num_channels: Optional[int] = None,
codec_config: Optional[CodecConfig] = None,
filter_desc: Optional[str] = None,
):
......@@ -190,12 +203,53 @@ class StreamWriter:
encoder_format (str or None, optional): {encoder_format}
encoder_sample_rate (int or None, optional): Override the sample rate used for encoding time.
Some encoders pose restriction on the sample rate used for encoding.
If the source sample rate is not supported by the encoder, the source sample rate is used,
otherwise a default one is picked.
For example, ``"opus"`` encoder only supports 48k Hz, so, when encoding a
waveform with ``"opus"`` encoder, it is always encoded as 48k Hz.
Meanwhile ``"mp3"`` (``"libmp3lame"``) supports 44.1k, 48k, 32k, 22.05k,
24k, 16k, 11.025k, 12k and 8k Hz.
If the original sample rate is one of these, then the original sample rate
is used, otherwise it will be resampled to a default one (44.1k).
When encoding into WAV format, there is no restriction on sample rate,
so the original sample rate will be used.
Providing ``encoder_sample_rate`` will override this behavior and
make encoder attempt to use the provided sample rate.
The provided value must be one support by the encoder.
encoder_num_channels (int or None, optional): Override the number of channels used for encoding.
Similar to sample rate, some encoders (such as ``"opus"``,
``"vorbis"`` and ``"g722"``) pose restriction on
the numbe of channels that can be used for encoding.
If the original number of channels is supported by encoder,
then it will be used, otherwise, the encoder attempts to
remix the channel to one of the supported ones.
Providing ``encoder_num_channels`` will override this behavior and
make encoder attempt to use the provided number of channels.
The provided value must be one support by the encoder.
codec_config (CodecConfig or None, optional): {codec_config}
filter_desc (str or None, optional): {filter_desc}
"""
self._s.add_audio_stream(
sample_rate, num_channels, format, encoder, encoder_option, encoder_format, codec_config, filter_desc
sample_rate,
num_channels,
format,
encoder,
encoder_option,
encoder_format,
encoder_sample_rate,
encoder_num_channels,
codec_config,
filter_desc,
)
@_format_common_args
......@@ -208,6 +262,9 @@ class StreamWriter:
encoder: Optional[str] = None,
encoder_option: Optional[Dict[str, str]] = None,
encoder_format: Optional[str] = None,
encoder_frame_rate: Optional[float] = None,
encoder_width: Optional[int] = None,
encoder_height: Optional[int] = None,
hw_accel: Optional[str] = None,
codec_config: Optional[CodecConfig] = None,
filter_desc: Optional[str] = None,
......@@ -242,6 +299,24 @@ class StreamWriter:
encoder_format (str or None, optional): {encoder_format}
encoder_frame_rate (float or None, optional): Override the frame rate used for encoding.
Some encoders, (such as ``"mpeg1"`` and ``"mpeg2"``) pose restriction on the
frame rate that can be used for encoding.
If such case, if the source frame rate (provided as ``frame_rate``) is not
one of the supported frame rate, then a default one is picked, and the frame rate
is changed on-the-fly. Otherwise the source frame rate is used.
Providing ``encoder_frame_rate`` will override this behavior and
make encoder attempts to use the provided sample rate.
The provided value must be one support by the encoder.
encoder_width (int or None, optional): Width of the image used for encoding.
This allows to change the image size during encoding.
encoder_height (int or None, optional): Height of the image used for encoding.
This allows to change the image size during encoding.
hw_accel (str or None, optional): Enable hardware acceleration.
When video is encoded on CUDA hardware, for example
......@@ -264,6 +339,9 @@ class StreamWriter:
encoder,
encoder_option,
encoder_format,
encoder_frame_rate,
encoder_width,
encoder_height,
hw_accel,
codec_config,
filter_desc,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment