Commit 9bb35070 authored by moto's avatar moto Committed by Facebook GitHub Bot
Browse files

Add EncodingConfig (#3179)

Summary:
Adds config object `EncodingConfig` and modifies `StreamWriter` to allow for passing in additional encoder configuration parameters, e.g. bit rate and compression level.

Pull Request resolved: https://github.com/pytorch/audio/pull/3179

Pull Request resolved: https://github.com/pytorch/audio/pull/3164

Reviewed By: mthrok

Differential Revision: D43861413

Pulled By: hwangjeff

fbshipit-source-id: c1682cb2f6e682ab6f1a506511d2be7c7b254161
parent a6b34a5d
......@@ -38,7 +38,12 @@ Methods
-------
{%- for item in members %}
{%- if not item.startswith('_') and item not in inherited_members and item not in attributes %}
{%- if
not item.startswith('_')
and item not in inherited_members
and item not in attributes
and item != "EncodeConfig"
%}
{{ item | underline("~") }}
......@@ -50,6 +55,7 @@ Methods
{%- endfor %}
{%- endif %}
{%- if name == "StreamReader" %}
Support Structures
......@@ -71,4 +77,15 @@ Support Structures
:members:
{%- endfor %}
{%- elif name == "StreamWriter" %}
Support Structures
------------------
EncodeConfig
~~~~~~~~~~~~
.. autoclass:: torchaudio.io::StreamWriter.EncodeConfig()
:members:
{%- endif %}
import io
import math
import torch
......@@ -487,3 +488,48 @@ class StreamWriterInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC
# could introduce a descrepancy, so we compare floats and use math.isclose
# for that.
assert math.isclose(val, ref)
def test_encode_config(self):
"""Can successfully set configuration and write audio."""
ext = "mp3"
filename = f"test.{ext}"
sample_rate = 44100
num_channels = 2
# Write data
dst = self.get_dst(filename)
writer = torchaudio.io.StreamWriter(dst=dst, format=ext)
config = torchaudio.io.StreamWriter.EncodeConfig(bit_rate=198_000, compression_level=3)
writer.add_audio_stream(sample_rate=sample_rate, num_channels=num_channels, config=config)
audio = torch.zeros((8000, 2))
with writer.open():
writer.write_audio_chunk(0, audio)
def test_encode_config_bit_rate_output(self):
"""Increasing the specified bit rate yields a larger encoded output."""
ext = "mp3"
sample_rate = 44100
num_channels = 2
audio = torch.rand((8000, num_channels))
def write_audio(buffer, bit_rate):
writer = torchaudio.io.StreamWriter(dst=buffer, format=ext)
writer.add_audio_stream(
sample_rate=sample_rate,
num_channels=num_channels,
config=torchaudio.io.StreamWriter.EncodeConfig(bit_rate=bit_rate),
)
with writer.open():
writer.write_audio_chunk(0, audio)
dst = io.BytesIO()
write_audio(dst, 198_000)
out0_size = dst.tell()
dst = io.BytesIO()
write_audio(dst, 320_000)
out1_size = dst.tell()
self.assertGreater(out1_size, out0_size)
......@@ -33,6 +33,8 @@ PYBIND11_MODULE(_torchaudio_ffmpeg, m) {
py::class_<Chunk>(m, "Chunk", py::module_local())
.def_readwrite("frames", &Chunk::frames)
.def_readwrite("pts", &Chunk::pts);
py::class_<EncodingConfig>(m, "EncodingConfig", py::module_local())
.def(py::init<int, int, int, int>());
py::class_<StreamWriter>(m, "StreamWriter", py::module_local())
.def(py::init<const std::string&, const c10::optional<std::string>&>())
.def("set_metadata", &StreamWriter::set_metadata)
......
......@@ -89,7 +89,8 @@ void configure_audio_codec(
AVCodecContextPtr& ctx,
int64_t sample_rate,
int64_t num_channels,
const c10::optional<std::string>& format) {
const c10::optional<std::string>& format,
const c10::optional<EncodingConfig>& config) {
// TODO: Review options and make them configurable?
// https://ffmpeg.org/doxygen/4.1/muxing_8c_source.html#l00122
// - bit_rate
......@@ -160,6 +161,17 @@ void configure_audio_codec(
}
}
ctx->channel_layout = static_cast<uint64_t>(layout);
// Set optional stuff
if (config) {
auto& cfg = config.value();
if (cfg.bit_rate > 0) {
ctx->bit_rate = cfg.bit_rate;
}
if (cfg.compression_level != -1) {
ctx->compression_level = cfg.compression_level;
}
}
}
void open_codec(
......@@ -177,9 +189,10 @@ AVCodecContextPtr get_audio_codec(
int64_t num_channels,
const c10::optional<std::string>& encoder,
const c10::optional<OptionDict>& encoder_option,
const c10::optional<std::string>& encoder_format) {
const c10::optional<std::string>& encoder_format,
const c10::optional<EncodingConfig>& config) {
AVCodecContextPtr ctx = get_codec_ctx(AVMEDIA_TYPE_AUDIO, oformat, encoder);
configure_audio_codec(ctx, sample_rate, num_channels, encoder_format);
configure_audio_codec(ctx, sample_rate, num_channels, encoder_format, config);
open_codec(ctx, encoder_option);
return ctx;
}
......@@ -275,7 +288,8 @@ void configure_video_codec(
double frame_rate,
int64_t width,
int64_t height,
const c10::optional<std::string>& format) {
const c10::optional<std::string>& format,
const c10::optional<EncodingConfig>& config) {
// TODO: Review other options and make them configurable?
// https://ffmpeg.org/doxygen/4.1/muxing_8c_source.html#l00147
// - bit_rate
......@@ -340,6 +354,23 @@ void configure_video_codec(
}
return ret;
}();
// Set optional stuff
if (config) {
auto& cfg = config.value();
if (cfg.bit_rate > 0) {
ctx->bit_rate = cfg.bit_rate;
}
if (cfg.compression_level != -1) {
ctx->compression_level = cfg.compression_level;
}
if (cfg.gop_size != -1) {
ctx->gop_size = cfg.gop_size;
}
if (cfg.max_b_frames != -1) {
ctx->max_b_frames = cfg.max_b_frames;
}
}
}
void configure_hw_accel(AVCodecContext* ctx, const std::string& hw_accel) {
......@@ -400,9 +431,10 @@ AVCodecContextPtr get_video_codec(
const c10::optional<std::string>& encoder,
const c10::optional<OptionDict>& encoder_option,
const c10::optional<std::string>& encoder_format,
const c10::optional<std::string>& hw_accel) {
const c10::optional<std::string>& hw_accel,
const c10::optional<EncodingConfig>& config) {
AVCodecContextPtr ctx = get_codec_ctx(AVMEDIA_TYPE_VIDEO, oformat, encoder);
configure_video_codec(ctx, frame_rate, width, height, encoder_format);
configure_video_codec(ctx, frame_rate, width, height, encoder_format, config);
if (hw_accel) {
#ifdef USE_CUDA
......@@ -479,14 +511,16 @@ EncodeProcess::EncodeProcess(
const enum AVSampleFormat format,
const c10::optional<std::string>& encoder,
const c10::optional<OptionDict>& encoder_option,
const c10::optional<std::string>& encoder_format)
const c10::optional<std::string>& encoder_format,
const c10::optional<EncodingConfig>& config)
: codec_ctx(get_audio_codec(
format_ctx->oformat,
sample_rate,
num_channels,
encoder,
encoder_option,
encoder_format)),
encoder_format,
config)),
encoder(format_ctx, codec_ctx),
filter(get_audio_filter(format, codec_ctx)),
src_frame(get_audio_frame(format, sample_rate, num_channels, codec_ctx)),
......@@ -501,7 +535,8 @@ EncodeProcess::EncodeProcess(
const c10::optional<std::string>& encoder,
const c10::optional<OptionDict>& encoder_option,
const c10::optional<std::string>& encoder_format,
const c10::optional<std::string>& hw_accel)
const c10::optional<std::string>& hw_accel,
const c10::optional<EncodingConfig>& config)
: codec_ctx(get_video_codec(
format_ctx->oformat,
frame_rate,
......@@ -510,7 +545,8 @@ EncodeProcess::EncodeProcess(
encoder,
encoder_option,
encoder_format,
hw_accel)),
hw_accel,
config)),
encoder(format_ctx, codec_ctx),
filter(get_video_filter(format, codec_ctx)),
src_frame(get_video_frame(format, codec_ctx)),
......
......@@ -4,6 +4,7 @@
#include <torchaudio/csrc/ffmpeg/filter_graph.h>
#include <torchaudio/csrc/ffmpeg/stream_writer/encoder.h>
#include <torchaudio/csrc/ffmpeg/stream_writer/tensor_converter.h>
#include <torchaudio/csrc/ffmpeg/stream_writer/types.h>
namespace torchaudio::io {
......@@ -25,7 +26,8 @@ class EncodeProcess {
const enum AVSampleFormat format,
const c10::optional<std::string>& encoder,
const c10::optional<OptionDict>& encoder_option,
const c10::optional<std::string>& encoder_format);
const c10::optional<std::string>& encoder_format,
const c10::optional<EncodingConfig>& config);
// constructor for video
EncodeProcess(
......@@ -37,7 +39,8 @@ class EncodeProcess {
const c10::optional<std::string>& encoder,
const c10::optional<OptionDict>& encoder_option,
const c10::optional<std::string>& encoder_format,
const c10::optional<std::string>& hw_accel);
const c10::optional<std::string>& hw_accel,
const c10::optional<EncodingConfig>& config);
void process(
AVMediaType type,
......
......@@ -101,7 +101,8 @@ void StreamWriter::add_audio_stream(
const std::string& format,
const c10::optional<std::string>& encoder,
const c10::optional<OptionDict>& encoder_option,
const c10::optional<std::string>& encoder_format) {
const c10::optional<std::string>& encoder_format,
const c10::optional<EncodingConfig>& config) {
processes.emplace_back(
pFormatContext,
sample_rate,
......@@ -109,7 +110,8 @@ void StreamWriter::add_audio_stream(
get_src_sample_fmt(format),
encoder,
encoder_option,
encoder_format);
encoder_format,
config);
}
void StreamWriter::add_video_stream(
......@@ -120,7 +122,8 @@ void StreamWriter::add_video_stream(
const c10::optional<std::string>& encoder,
const c10::optional<OptionDict>& encoder_option,
const c10::optional<std::string>& encoder_format,
const c10::optional<std::string>& hw_accel) {
const c10::optional<std::string>& hw_accel,
const c10::optional<EncodingConfig>& config) {
processes.emplace_back(
pFormatContext,
frame_rate,
......@@ -130,7 +133,8 @@ void StreamWriter::add_video_stream(
encoder,
encoder_option,
encoder_format,
hw_accel);
hw_accel,
config);
}
void StreamWriter::set_metadata(const OptionDict& metadata) {
......
......@@ -4,6 +4,7 @@
#include <torchaudio/csrc/ffmpeg/ffmpeg.h>
#include <torchaudio/csrc/ffmpeg/filter_graph.h>
#include <torchaudio/csrc/ffmpeg/stream_writer/encode_process.h>
#include <torchaudio/csrc/ffmpeg/stream_writer/types.h>
namespace torchaudio {
namespace io {
......@@ -104,7 +105,9 @@ class StreamWriter {
const std::string& format,
const c10::optional<std::string>& encoder,
const c10::optional<OptionDict>& encoder_option,
const c10::optional<std::string>& encoder_format);
const c10::optional<std::string>& encoder_format,
const c10::optional<EncodingConfig>& config);
/// Add an output video stream.
///
/// @param frame_rate Frame rate
......@@ -142,7 +145,8 @@ class StreamWriter {
const c10::optional<std::string>& encoder,
const c10::optional<OptionDict>& encoder_option,
const c10::optional<std::string>& encoder_format,
const c10::optional<std::string>& hw_accel);
const c10::optional<std::string>& hw_accel,
const c10::optional<EncodingConfig>& config);
/// Set file-level metadata
/// @param metadata metadata.
void set_metadata(const OptionDict& metadata);
......
#pragma once
namespace torchaudio::io {
struct EncodingConfig {
int bit_rate = -1;
int compression_level = -1;
// video
int gop_size = -1;
int max_b_frames = -1;
};
} // namespace torchaudio::io
from dataclasses import dataclass
from typing import BinaryIO, Dict, Optional, Union
import torch
import torchaudio
if torchaudio._extension._FFMPEG_INITIALIZED:
ConfigBase = torchaudio.lib._torchaudio_ffmpeg.EncodingConfig
else:
ConfigBase = object
def _format_doc(**kwargs):
def decorator(obj):
obj.__doc__ = obj.__doc__.format(**kwargs)
......@@ -103,6 +110,25 @@ class StreamWriter:
Default: `4096`.
"""
@dataclass
class EncodeConfig(ConfigBase):
"""Encoding configuration."""
bit_rate: int = -1
"""Bit rate"""
compression_level: int = -1
"""Compression level"""
gop_size: int = -1
"""The number of pictures in a group of pictures, or 0 for intra_only"""
max_b_frames: int = -1
"""maximum number of B-frames between non-B-frames."""
def __post_init__(self):
super().__init__(self.bit_rate, self.compression_level, self.gop_size, self.max_b_frames)
def __init__(
self,
dst: Union[str, BinaryIO],
......@@ -126,6 +152,7 @@ class StreamWriter:
encoder: Optional[str] = None,
encoder_option: Optional[Dict[str, str]] = None,
encoder_format: Optional[str] = None,
config: Optional[EncodeConfig] = None,
):
"""Add an output audio stream.
......@@ -152,7 +179,7 @@ class StreamWriter:
encoder_format (str or None, optional): {encoder_format}
"""
self._s.add_audio_stream(sample_rate, num_channels, format, encoder, encoder_option, encoder_format)
self._s.add_audio_stream(sample_rate, num_channels, format, encoder, encoder_option, encoder_format, config)
@_format_common_args
def add_video_stream(
......@@ -165,6 +192,7 @@ class StreamWriter:
encoder_option: Optional[Dict[str, str]] = None,
encoder_format: Optional[str] = None,
hw_accel: Optional[str] = None,
config: Optional[EncodeConfig] = None,
):
"""Add an output video stream.
......@@ -206,7 +234,9 @@ class StreamWriter:
If `None`, the video chunk Tensor has to be CPU Tensor.
Default: ``None``.
"""
self._s.add_video_stream(frame_rate, width, height, format, encoder, encoder_option, encoder_format, hw_accel)
self._s.add_video_stream(
frame_rate, width, height, format, encoder, encoder_option, encoder_format, hw_accel, config
)
def set_metadata(self, metadata: Dict[str, str]):
"""Set file-level metadata
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment