Commit 715eb34a authored by moto's avatar moto Committed by Facebook GitHub Bot
Browse files

Add additional filter graph option to StreamWriter (#3194)

Summary: Pull Request resolved: https://github.com/pytorch/audio/pull/3194

Reviewed By: hwangjeff

Differential Revision: D44283910

Pulled By: mthrok

fbshipit-source-id: 49125724896bf7190ec27f056b6bfef260019f8e
parent 0846a411
......@@ -580,3 +580,57 @@ class StreamWriterInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC
out1_size = dst.tell()
self.assertGreater(out1_size, out0_size)
def test_filter_graph_audio(self):
"""Can apply additional effect with filter graph"""
sample_rate = 8000
num_channels = 2
ext = "wav"
filename = f"test.{ext}"
original = get_audio_chunk("s16", num_channels=num_channels, sample_rate=sample_rate)
dst = self.get_dst(filename)
w = StreamWriter(dst, format=ext)
w.add_audio_stream(sample_rate=8000, num_channels=num_channels, filter_desc="areverse", format="s16")
with w.open():
w.write_audio_chunk(0, original)
# check
if self.test_fileobj:
dst.flush()
reader = torchaudio.io.StreamReader(src=self.get_temp_path(filename))
reader.add_audio_stream(-1)
reader.process_all_packets()
(output,) = reader.pop_chunks()
self.assertEqual(output, original.flip(0))
def test_filter_graph_video(self):
"""Can apply additional effect with filter graph"""
rate = 30
num_frames, width, height = 400, 160, 90
ext = "mp4"
filename = f"test.{ext}"
original = torch.zeros((num_frames, 3, height, width), dtype=torch.uint8)
dst = self.get_dst(filename)
w = StreamWriter(dst, format=ext)
w.add_video_stream(frame_rate=rate, format="rgb24", height=height, width=width, filter_desc="framestep=2")
with w.open():
w.write_video_chunk(0, original)
# check
if self.test_fileobj:
dst.flush()
reader = torchaudio.io.StreamReader(src=self.get_temp_path(filename))
reader.add_video_stream(-1)
reader.process_all_packets()
(output,) = reader.pop_chunks()
self.assertEqual(output.shape, [num_frames // 2, 3, height, width])
......@@ -513,19 +513,26 @@ FilterGraph get_audio_filter_graph(
AVSampleFormat src_fmt,
int sample_rate,
uint64_t channel_layout,
const c10::optional<std::string>& filter_desc,
AVSampleFormat enc_fmt,
int nb_samples) {
const std::string filter_desc = [&]() -> const std::string {
const std::string desc = [&]() -> const std::string {
if (src_fmt == enc_fmt) {
if (nb_samples == 0) {
return "anull";
return filter_desc.value_or("anull");
} else {
std::stringstream ss;
if (filter_desc) {
ss << filter_desc.value() << ",";
}
ss << "asetnsamples=n=" << nb_samples << ":p=0";
return ss.str();
}
} else {
std::stringstream ss;
if (filter_desc) {
ss << filter_desc.value() << ",";
}
ss << "aformat=" << av_get_sample_fmt_name(enc_fmt);
if (nb_samples > 0) {
ss << ",asetnsamples=n=" << nb_samples << ":p=0";
......@@ -537,7 +544,7 @@ FilterGraph get_audio_filter_graph(
FilterGraph f{AVMEDIA_TYPE_AUDIO};
f.add_audio_src(src_fmt, {1, sample_rate}, sample_rate, channel_layout);
f.add_sink();
f.add_process(filter_desc);
f.add_process(desc);
f.create_filter();
return f;
}
......@@ -547,13 +554,17 @@ FilterGraph get_video_filter_graph(
AVRational rate,
int width,
int height,
const c10::optional<std::string>& filter_desc,
AVPixelFormat enc_fmt,
bool is_cuda) {
auto desc = [&]() -> std::string {
if (src_fmt == enc_fmt || is_cuda) {
return "null";
return filter_desc.value_or("null");
} else {
std::stringstream ss;
if (filter_desc) {
ss << filter_desc.value() << ",";
}
ss << "format=" << av_get_pix_fmt_name(enc_fmt);
return ss.str();
}
......@@ -624,7 +635,8 @@ EncodeProcess get_audio_encode_process(
const c10::optional<std::string>& encoder,
const c10::optional<OptionDict>& encoder_option,
const c10::optional<std::string>& encoder_format,
const c10::optional<CodecConfig>& codec_config) {
const c10::optional<CodecConfig>& codec_config,
const c10::optional<std::string>& filter_desc) {
// 1. Check the source format, rate and channels
const AVSampleFormat src_fmt = get_sample_fmt(format);
TORCH_CHECK(
......@@ -663,7 +675,12 @@ EncodeProcess get_audio_encode_process(
// 5. Build filter graph
FilterGraph filter_graph = get_audio_filter_graph(
src_fmt, src_sample_rate, channel_layout, enc_fmt, codec_ctx->frame_size);
src_fmt,
src_sample_rate,
channel_layout,
filter_desc,
enc_fmt,
codec_ctx->frame_size);
// 6. Instantiate source frame
AVFramePtr src_frame = get_audio_frame(
......@@ -701,7 +718,8 @@ EncodeProcess get_video_encode_process(
const c10::optional<OptionDict>& encoder_option,
const c10::optional<std::string>& encoder_format,
const c10::optional<std::string>& hw_accel,
const c10::optional<CodecConfig>& codec_config) {
const c10::optional<CodecConfig>& codec_config,
const c10::optional<std::string>& filter_desc) {
// 1. Checkc the source format, rate and resolution
const AVPixelFormat src_fmt = get_pix_fmt(format);
AVRational src_rate = av_d2q(frame_rate, 1 << 24);
......@@ -742,7 +760,13 @@ EncodeProcess get_video_encode_process(
// 5. Build filter graph
FilterGraph filter_graph = get_video_filter_graph(
src_fmt, src_rate, src_width, src_height, enc_fmt, hw_accel.has_value());
src_fmt,
src_rate,
src_width,
src_height,
filter_desc,
enc_fmt,
hw_accel.has_value());
// 6. Instantiate source frame
AVFramePtr src_frame = [&]() {
......
......@@ -41,7 +41,8 @@ EncodeProcess get_audio_encode_process(
const c10::optional<std::string>& encoder,
const c10::optional<OptionDict>& encoder_option,
const c10::optional<std::string>& encoder_format,
const c10::optional<CodecConfig>& codec_config);
const c10::optional<CodecConfig>& codec_config,
const c10::optional<std::string>& filter_desc);
EncodeProcess get_video_encode_process(
AVFormatContext* format_ctx,
......@@ -53,6 +54,7 @@ EncodeProcess get_video_encode_process(
const c10::optional<OptionDict>& encoder_option,
const c10::optional<std::string>& encoder_format,
const c10::optional<std::string>& hw_accel,
const c10::optional<CodecConfig>& codec_config);
const c10::optional<CodecConfig>& codec_config,
const c10::optional<std::string>& filter_desc);
}; // namespace torchaudio::io
......@@ -60,7 +60,8 @@ void StreamWriter::add_audio_stream(
const c10::optional<std::string>& encoder,
const c10::optional<OptionDict>& encoder_option,
const c10::optional<std::string>& encoder_format,
const c10::optional<CodecConfig>& codec_config) {
const c10::optional<CodecConfig>& codec_config,
const c10::optional<std::string>& filter_desc) {
TORCH_CHECK(!is_open, "Output is already opened. Cannot add a new stream.");
TORCH_INTERNAL_ASSERT(
pFormatContext->nb_streams == processes.size(),
......@@ -73,7 +74,8 @@ void StreamWriter::add_audio_stream(
encoder,
encoder_option,
encoder_format,
codec_config));
codec_config,
filter_desc));
}
void StreamWriter::add_video_stream(
......@@ -85,7 +87,8 @@ void StreamWriter::add_video_stream(
const c10::optional<OptionDict>& encoder_option,
const c10::optional<std::string>& encoder_format,
const c10::optional<std::string>& hw_accel,
const c10::optional<CodecConfig>& codec_config) {
const c10::optional<CodecConfig>& codec_config,
const c10::optional<std::string>& filter_desc) {
TORCH_CHECK(!is_open, "Output is already opened. Cannot add a new stream.");
TORCH_INTERNAL_ASSERT(
pFormatContext->nb_streams == processes.size(),
......@@ -100,7 +103,8 @@ void StreamWriter::add_video_stream(
encoder_option,
encoder_format,
hw_accel,
codec_config));
codec_config,
filter_desc));
}
void StreamWriter::set_metadata(const OptionDict& metadata) {
......
......@@ -100,6 +100,8 @@ class StreamWriter {
/// To list supported formats for the encoder, you can use
/// ``ffmpeg -h encoder=<ENCODER>`` command.
/// @param codec_config Codec configuration.
/// @param filter_desc Additional processing to apply before
/// encoding the input data
void add_audio_stream(
int sample_rate,
int num_channels,
......@@ -107,7 +109,8 @@ class StreamWriter {
const c10::optional<std::string>& encoder = c10::nullopt,
const c10::optional<OptionDict>& encoder_option = c10::nullopt,
const c10::optional<std::string>& encoder_format = c10::nullopt,
const c10::optional<CodecConfig>& codec_config = c10::nullopt);
const c10::optional<CodecConfig>& codec_config = c10::nullopt,
const c10::optional<std::string>& filter_desc = c10::nullopt);
/// Add an output video stream.
///
......@@ -139,6 +142,8 @@ class StreamWriter {
///
/// If `None`, the video chunk Tensor has to be a CPU Tensor.
/// @endparblock
/// @param filter_desc Additional processing to apply before
/// encoding the input data
void add_video_stream(
double frame_rate,
int width,
......@@ -148,7 +153,8 @@ class StreamWriter {
const c10::optional<OptionDict>& encoder_option = c10::nullopt,
const c10::optional<std::string>& encoder_format = c10::nullopt,
const c10::optional<std::string>& hw_accel = c10::nullopt,
const c10::optional<CodecConfig>& codec_config = c10::nullopt);
const c10::optional<CodecConfig>& codec_config = c10::nullopt,
const c10::optional<std::string>& filter_desc = c10::nullopt);
/// Set file-level metadata
/// @param metadata metadata.
void set_metadata(const OptionDict& metadata);
......
......@@ -144,9 +144,12 @@ void write_interlaced_video(
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(frame.size(2) == buffer->width);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(frame.size(3) == num_channels);
// TODO: writable
// https://ffmpeg.org/doxygen/4.1/muxing_8c_source.html#l00472
TORCH_INTERNAL_ASSERT(av_frame_is_writable(buffer), "frame is not writable.");
if (!av_frame_is_writable(buffer)) {
int ret = av_frame_make_writable(buffer);
TORCH_INTERNAL_ASSERT(
ret >= 0, "Failed to make frame writable: ", av_err2string(ret));
}
size_t stride = buffer->width * num_channels;
uint8_t* src = frame.data_ptr<uint8_t>();
......@@ -191,9 +194,12 @@ void write_planar_video(
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(frame.size(2), buffer->height);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(frame.size(3), buffer->width);
// TODO: writable
// https://ffmpeg.org/doxygen/4.1/muxing_8c_source.html#l00472
TORCH_INTERNAL_ASSERT(av_frame_is_writable(buffer), "frame is not writable.");
if (!av_frame_is_writable(buffer)) {
int ret = av_frame_make_writable(buffer);
TORCH_INTERNAL_ASSERT(
ret >= 0, "Failed to make frame writable: ", av_err2string(ret));
}
for (int j = 0; j < num_colors; ++j) {
uint8_t* src = frame.index({0, j}).data_ptr<uint8_t>();
......
......@@ -53,11 +53,15 @@ _codec_config = """Codec configuration. Please refer to :py:class:`CodecConfig`
Default: ``None``."""
_filter_desc = """Additional processing to apply before encoding the input media.
"""
_format_common_args = _format_doc(
encoder=_encoder,
encoder_option=_encoder_option,
encoder_format=_encoder_format,
codec_config=_codec_config,
filter_desc=_filter_desc,
)
......@@ -159,6 +163,7 @@ class StreamWriter:
encoder_option: Optional[Dict[str, str]] = None,
encoder_format: Optional[str] = None,
codec_config: Optional[CodecConfig] = None,
filter_desc: Optional[str] = None,
):
"""Add an output audio stream.
......@@ -186,9 +191,11 @@ class StreamWriter:
encoder_format (str or None, optional): {encoder_format}
codec_config (CodecConfig or None, optional): {codec_config}
filter_desc (str or None, optional): {filter_desc}
"""
self._s.add_audio_stream(
sample_rate, num_channels, format, encoder, encoder_option, encoder_format, codec_config
sample_rate, num_channels, format, encoder, encoder_option, encoder_format, codec_config, filter_desc
)
@_format_common_args
......@@ -203,6 +210,7 @@ class StreamWriter:
encoder_format: Optional[str] = None,
hw_accel: Optional[str] = None,
codec_config: Optional[CodecConfig] = None,
filter_desc: Optional[str] = None,
):
"""Add an output video stream.
......@@ -245,9 +253,20 @@ class StreamWriter:
Default: ``None``.
codec_config (CodecConfig or None, optional): {codec_config}
filter_desc (str or None, optional): {filter_desc}
"""
self._s.add_video_stream(
frame_rate, width, height, format, encoder, encoder_option, encoder_format, hw_accel, codec_config
frame_rate,
width,
height,
format,
encoder,
encoder_option,
encoder_format,
hw_accel,
codec_config,
filter_desc,
)
def set_metadata(self, metadata: Dict[str, str]):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment