Commit 9bb35070 authored by moto's avatar moto Committed by Facebook GitHub Bot
Browse files

Add EncodingConfig (#3179)

Summary:
Adds config object `EncodingConfig` and modifies `StreamWriter` to allow for passing in additional encoder configuration parameters, e.g. bit rate and compression level.

Pull Request resolved: https://github.com/pytorch/audio/pull/3179

Pull Request resolved: https://github.com/pytorch/audio/pull/3164

Reviewed By: mthrok

Differential Revision: D43861413

Pulled By: hwangjeff

fbshipit-source-id: c1682cb2f6e682ab6f1a506511d2be7c7b254161
parent a6b34a5d
...@@ -38,7 +38,12 @@ Methods ...@@ -38,7 +38,12 @@ Methods
------- -------
{%- for item in members %} {%- for item in members %}
{%- if not item.startswith('_') and item not in inherited_members and item not in attributes %} {%- if
not item.startswith('_')
and item not in inherited_members
and item not in attributes
and item != "EncodeConfig"
%}
{{ item | underline("~") }} {{ item | underline("~") }}
...@@ -50,6 +55,7 @@ Methods ...@@ -50,6 +55,7 @@ Methods
{%- endfor %} {%- endfor %}
{%- endif %} {%- endif %}
{%- if name == "StreamReader" %} {%- if name == "StreamReader" %}
Support Structures Support Structures
...@@ -71,4 +77,15 @@ Support Structures ...@@ -71,4 +77,15 @@ Support Structures
:members: :members:
{%- endfor %} {%- endfor %}
{%- elif name == "StreamWriter" %}
Support Structures
------------------
EncodeConfig
~~~~~~~~~~~~
.. autoclass:: torchaudio.io::StreamWriter.EncodeConfig()
:members:
{%- endif %} {%- endif %}
import io
import math import math
import torch import torch
...@@ -487,3 +488,48 @@ class StreamWriterInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC ...@@ -487,3 +488,48 @@ class StreamWriterInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC
# could introduce a discrepancy, so we compare floats and use math.isclose # could introduce a discrepancy, so we compare floats and use math.isclose
# for that. # for that.
assert math.isclose(val, ref) assert math.isclose(val, ref)
def test_encode_config(self):
    """Can successfully set configuration and write audio.

    This is a smoke test: it passes as long as configuring the stream with an
    ``EncodeConfig`` and writing one chunk does not raise. The encoded output
    is not inspected.
    """
    ext = "mp3"
    filename = f"test.{ext}"
    sample_rate = 44100
    num_channels = 2
    num_frames = 8000

    # Write data (``self.get_dst`` is defined outside this test method).
    dst = self.get_dst(filename)
    writer = torchaudio.io.StreamWriter(dst=dst, format=ext)
    config = torchaudio.io.StreamWriter.EncodeConfig(bit_rate=198_000, compression_level=3)
    writer.add_audio_stream(sample_rate=sample_rate, num_channels=num_channels, config=config)

    # Use ``num_channels`` for the channel dimension so the chunk stays
    # consistent with the stream declared above.
    audio = torch.zeros((num_frames, num_channels))
    with writer.open():
        writer.write_audio_chunk(0, audio)
def test_encode_config_bit_rate_output(self):
    """Increasing the specified bit rate yields a larger encoded output."""
    ext = "mp3"
    sample_rate = 44100
    num_channels = 2
    audio = torch.rand((8000, num_channels))

    def encoded_size(bit_rate):
        """Encode ``audio`` into a fresh in-memory buffer and return its size in bytes."""
        buffer = io.BytesIO()
        writer = torchaudio.io.StreamWriter(dst=buffer, format=ext)
        writer.add_audio_stream(
            sample_rate=sample_rate,
            num_channels=num_channels,
            config=torchaudio.io.StreamWriter.EncodeConfig(bit_rate=bit_rate),
        )
        with writer.open():
            writer.write_audio_chunk(0, audio)
        # Measure the buffer contents rather than ``tell()``: the stream
        # position equals the output size only if the writer happens to
        # leave the cursor at the end of the buffer.
        return len(buffer.getvalue())

    out0_size = encoded_size(198_000)
    out1_size = encoded_size(320_000)

    self.assertGreater(out1_size, out0_size)
...@@ -33,6 +33,8 @@ PYBIND11_MODULE(_torchaudio_ffmpeg, m) { ...@@ -33,6 +33,8 @@ PYBIND11_MODULE(_torchaudio_ffmpeg, m) {
py::class_<Chunk>(m, "Chunk", py::module_local()) py::class_<Chunk>(m, "Chunk", py::module_local())
.def_readwrite("frames", &Chunk::frames) .def_readwrite("frames", &Chunk::frames)
.def_readwrite("pts", &Chunk::pts); .def_readwrite("pts", &Chunk::pts);
py::class_<EncodingConfig>(m, "EncodingConfig", py::module_local())
.def(py::init<int, int, int, int>());
py::class_<StreamWriter>(m, "StreamWriter", py::module_local()) py::class_<StreamWriter>(m, "StreamWriter", py::module_local())
.def(py::init<const std::string&, const c10::optional<std::string>&>()) .def(py::init<const std::string&, const c10::optional<std::string>&>())
.def("set_metadata", &StreamWriter::set_metadata) .def("set_metadata", &StreamWriter::set_metadata)
......
...@@ -89,7 +89,8 @@ void configure_audio_codec( ...@@ -89,7 +89,8 @@ void configure_audio_codec(
AVCodecContextPtr& ctx, AVCodecContextPtr& ctx,
int64_t sample_rate, int64_t sample_rate,
int64_t num_channels, int64_t num_channels,
const c10::optional<std::string>& format) { const c10::optional<std::string>& format,
const c10::optional<EncodingConfig>& config) {
// TODO: Review options and make them configurable? // TODO: Review options and make them configurable?
// https://ffmpeg.org/doxygen/4.1/muxing_8c_source.html#l00122 // https://ffmpeg.org/doxygen/4.1/muxing_8c_source.html#l00122
// - bit_rate // - bit_rate
...@@ -160,6 +161,17 @@ void configure_audio_codec( ...@@ -160,6 +161,17 @@ void configure_audio_codec(
} }
} }
ctx->channel_layout = static_cast<uint64_t>(layout); ctx->channel_layout = static_cast<uint64_t>(layout);
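// Apply caller-supplied overrides. A field is skipped (ctx keeps its current value) when bit_rate <= 0 or compression_level == -1.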
if (config) {
auto& cfg = config.value();
if (cfg.bit_rate > 0) {
ctx->bit_rate = cfg.bit_rate;
}
if (cfg.compression_level != -1) {
ctx->compression_level = cfg.compression_level;
}
}
} }
void open_codec( void open_codec(
...@@ -177,9 +189,10 @@ AVCodecContextPtr get_audio_codec( ...@@ -177,9 +189,10 @@ AVCodecContextPtr get_audio_codec(
int64_t num_channels, int64_t num_channels,
const c10::optional<std::string>& encoder, const c10::optional<std::string>& encoder,
const c10::optional<OptionDict>& encoder_option, const c10::optional<OptionDict>& encoder_option,
const c10::optional<std::string>& encoder_format) { const c10::optional<std::string>& encoder_format,
const c10::optional<EncodingConfig>& config) {
AVCodecContextPtr ctx = get_codec_ctx(AVMEDIA_TYPE_AUDIO, oformat, encoder); AVCodecContextPtr ctx = get_codec_ctx(AVMEDIA_TYPE_AUDIO, oformat, encoder);
configure_audio_codec(ctx, sample_rate, num_channels, encoder_format); configure_audio_codec(ctx, sample_rate, num_channels, encoder_format, config);
open_codec(ctx, encoder_option); open_codec(ctx, encoder_option);
return ctx; return ctx;
} }
...@@ -275,7 +288,8 @@ void configure_video_codec( ...@@ -275,7 +288,8 @@ void configure_video_codec(
double frame_rate, double frame_rate,
int64_t width, int64_t width,
int64_t height, int64_t height,
const c10::optional<std::string>& format) { const c10::optional<std::string>& format,
const c10::optional<EncodingConfig>& config) {
// TODO: Review other options and make them configurable? // TODO: Review other options and make them configurable?
// https://ffmpeg.org/doxygen/4.1/muxing_8c_source.html#l00147 // https://ffmpeg.org/doxygen/4.1/muxing_8c_source.html#l00147
// - bit_rate // - bit_rate
...@@ -340,6 +354,23 @@ void configure_video_codec( ...@@ -340,6 +354,23 @@ void configure_video_codec(
} }
return ret; return ret;
}(); }();
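// Apply caller-supplied overrides. A field is skipped (ctx keeps its current value) when bit_rate <= 0, or when compression_level, gop_size or max_b_frames is -1.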
if (config) {
auto& cfg = config.value();
if (cfg.bit_rate > 0) {
ctx->bit_rate = cfg.bit_rate;
}
if (cfg.compression_level != -1) {
ctx->compression_level = cfg.compression_level;
}
if (cfg.gop_size != -1) {
ctx->gop_size = cfg.gop_size;
}
if (cfg.max_b_frames != -1) {
ctx->max_b_frames = cfg.max_b_frames;
}
}
} }
void configure_hw_accel(AVCodecContext* ctx, const std::string& hw_accel) { void configure_hw_accel(AVCodecContext* ctx, const std::string& hw_accel) {
...@@ -400,9 +431,10 @@ AVCodecContextPtr get_video_codec( ...@@ -400,9 +431,10 @@ AVCodecContextPtr get_video_codec(
const c10::optional<std::string>& encoder, const c10::optional<std::string>& encoder,
const c10::optional<OptionDict>& encoder_option, const c10::optional<OptionDict>& encoder_option,
const c10::optional<std::string>& encoder_format, const c10::optional<std::string>& encoder_format,
const c10::optional<std::string>& hw_accel) { const c10::optional<std::string>& hw_accel,
const c10::optional<EncodingConfig>& config) {
AVCodecContextPtr ctx = get_codec_ctx(AVMEDIA_TYPE_VIDEO, oformat, encoder); AVCodecContextPtr ctx = get_codec_ctx(AVMEDIA_TYPE_VIDEO, oformat, encoder);
configure_video_codec(ctx, frame_rate, width, height, encoder_format); configure_video_codec(ctx, frame_rate, width, height, encoder_format, config);
if (hw_accel) { if (hw_accel) {
#ifdef USE_CUDA #ifdef USE_CUDA
...@@ -479,14 +511,16 @@ EncodeProcess::EncodeProcess( ...@@ -479,14 +511,16 @@ EncodeProcess::EncodeProcess(
const enum AVSampleFormat format, const enum AVSampleFormat format,
const c10::optional<std::string>& encoder, const c10::optional<std::string>& encoder,
const c10::optional<OptionDict>& encoder_option, const c10::optional<OptionDict>& encoder_option,
const c10::optional<std::string>& encoder_format) const c10::optional<std::string>& encoder_format,
const c10::optional<EncodingConfig>& config)
: codec_ctx(get_audio_codec( : codec_ctx(get_audio_codec(
format_ctx->oformat, format_ctx->oformat,
sample_rate, sample_rate,
num_channels, num_channels,
encoder, encoder,
encoder_option, encoder_option,
encoder_format)), encoder_format,
config)),
encoder(format_ctx, codec_ctx), encoder(format_ctx, codec_ctx),
filter(get_audio_filter(format, codec_ctx)), filter(get_audio_filter(format, codec_ctx)),
src_frame(get_audio_frame(format, sample_rate, num_channels, codec_ctx)), src_frame(get_audio_frame(format, sample_rate, num_channels, codec_ctx)),
...@@ -501,7 +535,8 @@ EncodeProcess::EncodeProcess( ...@@ -501,7 +535,8 @@ EncodeProcess::EncodeProcess(
const c10::optional<std::string>& encoder, const c10::optional<std::string>& encoder,
const c10::optional<OptionDict>& encoder_option, const c10::optional<OptionDict>& encoder_option,
const c10::optional<std::string>& encoder_format, const c10::optional<std::string>& encoder_format,
const c10::optional<std::string>& hw_accel) const c10::optional<std::string>& hw_accel,
const c10::optional<EncodingConfig>& config)
: codec_ctx(get_video_codec( : codec_ctx(get_video_codec(
format_ctx->oformat, format_ctx->oformat,
frame_rate, frame_rate,
...@@ -510,7 +545,8 @@ EncodeProcess::EncodeProcess( ...@@ -510,7 +545,8 @@ EncodeProcess::EncodeProcess(
encoder, encoder,
encoder_option, encoder_option,
encoder_format, encoder_format,
hw_accel)), hw_accel,
config)),
encoder(format_ctx, codec_ctx), encoder(format_ctx, codec_ctx),
filter(get_video_filter(format, codec_ctx)), filter(get_video_filter(format, codec_ctx)),
src_frame(get_video_frame(format, codec_ctx)), src_frame(get_video_frame(format, codec_ctx)),
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
#include <torchaudio/csrc/ffmpeg/filter_graph.h> #include <torchaudio/csrc/ffmpeg/filter_graph.h>
#include <torchaudio/csrc/ffmpeg/stream_writer/encoder.h> #include <torchaudio/csrc/ffmpeg/stream_writer/encoder.h>
#include <torchaudio/csrc/ffmpeg/stream_writer/tensor_converter.h> #include <torchaudio/csrc/ffmpeg/stream_writer/tensor_converter.h>
#include <torchaudio/csrc/ffmpeg/stream_writer/types.h>
namespace torchaudio::io { namespace torchaudio::io {
...@@ -25,7 +26,8 @@ class EncodeProcess { ...@@ -25,7 +26,8 @@ class EncodeProcess {
const enum AVSampleFormat format, const enum AVSampleFormat format,
const c10::optional<std::string>& encoder, const c10::optional<std::string>& encoder,
const c10::optional<OptionDict>& encoder_option, const c10::optional<OptionDict>& encoder_option,
const c10::optional<std::string>& encoder_format); const c10::optional<std::string>& encoder_format,
const c10::optional<EncodingConfig>& config);
// constructor for video // constructor for video
EncodeProcess( EncodeProcess(
...@@ -37,7 +39,8 @@ class EncodeProcess { ...@@ -37,7 +39,8 @@ class EncodeProcess {
const c10::optional<std::string>& encoder, const c10::optional<std::string>& encoder,
const c10::optional<OptionDict>& encoder_option, const c10::optional<OptionDict>& encoder_option,
const c10::optional<std::string>& encoder_format, const c10::optional<std::string>& encoder_format,
const c10::optional<std::string>& hw_accel); const c10::optional<std::string>& hw_accel,
const c10::optional<EncodingConfig>& config);
void process( void process(
AVMediaType type, AVMediaType type,
......
...@@ -101,7 +101,8 @@ void StreamWriter::add_audio_stream( ...@@ -101,7 +101,8 @@ void StreamWriter::add_audio_stream(
const std::string& format, const std::string& format,
const c10::optional<std::string>& encoder, const c10::optional<std::string>& encoder,
const c10::optional<OptionDict>& encoder_option, const c10::optional<OptionDict>& encoder_option,
const c10::optional<std::string>& encoder_format) { const c10::optional<std::string>& encoder_format,
const c10::optional<EncodingConfig>& config) {
processes.emplace_back( processes.emplace_back(
pFormatContext, pFormatContext,
sample_rate, sample_rate,
...@@ -109,7 +110,8 @@ void StreamWriter::add_audio_stream( ...@@ -109,7 +110,8 @@ void StreamWriter::add_audio_stream(
get_src_sample_fmt(format), get_src_sample_fmt(format),
encoder, encoder,
encoder_option, encoder_option,
encoder_format); encoder_format,
config);
} }
void StreamWriter::add_video_stream( void StreamWriter::add_video_stream(
...@@ -120,7 +122,8 @@ void StreamWriter::add_video_stream( ...@@ -120,7 +122,8 @@ void StreamWriter::add_video_stream(
const c10::optional<std::string>& encoder, const c10::optional<std::string>& encoder,
const c10::optional<OptionDict>& encoder_option, const c10::optional<OptionDict>& encoder_option,
const c10::optional<std::string>& encoder_format, const c10::optional<std::string>& encoder_format,
const c10::optional<std::string>& hw_accel) { const c10::optional<std::string>& hw_accel,
const c10::optional<EncodingConfig>& config) {
processes.emplace_back( processes.emplace_back(
pFormatContext, pFormatContext,
frame_rate, frame_rate,
...@@ -130,7 +133,8 @@ void StreamWriter::add_video_stream( ...@@ -130,7 +133,8 @@ void StreamWriter::add_video_stream(
encoder, encoder,
encoder_option, encoder_option,
encoder_format, encoder_format,
hw_accel); hw_accel,
config);
} }
void StreamWriter::set_metadata(const OptionDict& metadata) { void StreamWriter::set_metadata(const OptionDict& metadata) {
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
#include <torchaudio/csrc/ffmpeg/ffmpeg.h> #include <torchaudio/csrc/ffmpeg/ffmpeg.h>
#include <torchaudio/csrc/ffmpeg/filter_graph.h> #include <torchaudio/csrc/ffmpeg/filter_graph.h>
#include <torchaudio/csrc/ffmpeg/stream_writer/encode_process.h> #include <torchaudio/csrc/ffmpeg/stream_writer/encode_process.h>
#include <torchaudio/csrc/ffmpeg/stream_writer/types.h>
namespace torchaudio { namespace torchaudio {
namespace io { namespace io {
...@@ -104,7 +105,9 @@ class StreamWriter { ...@@ -104,7 +105,9 @@ class StreamWriter {
const std::string& format, const std::string& format,
const c10::optional<std::string>& encoder, const c10::optional<std::string>& encoder,
const c10::optional<OptionDict>& encoder_option, const c10::optional<OptionDict>& encoder_option,
const c10::optional<std::string>& encoder_format); const c10::optional<std::string>& encoder_format,
const c10::optional<EncodingConfig>& config);
/// Add an output video stream. /// Add an output video stream.
/// ///
/// @param frame_rate Frame rate /// @param frame_rate Frame rate
...@@ -142,7 +145,8 @@ class StreamWriter { ...@@ -142,7 +145,8 @@ class StreamWriter {
const c10::optional<std::string>& encoder, const c10::optional<std::string>& encoder,
const c10::optional<OptionDict>& encoder_option, const c10::optional<OptionDict>& encoder_option,
const c10::optional<std::string>& encoder_format, const c10::optional<std::string>& encoder_format,
const c10::optional<std::string>& hw_accel); const c10::optional<std::string>& hw_accel,
const c10::optional<EncodingConfig>& config);
/// Set file-level metadata /// Set file-level metadata
/// @param metadata metadata. /// @param metadata metadata.
void set_metadata(const OptionDict& metadata); void set_metadata(const OptionDict& metadata);
......
#pragma once
namespace torchaudio::io {
/// Optional encoder settings a caller can attach to an output stream.
///
/// Every field defaults to -1, which the codec configuration code treats as
/// "not specified": bit_rate is applied only when it is greater than 0, the
/// other fields only when they differ from -1.
///
/// This type is deliberately left without user-declared constructors so it
/// can be initialized from four ints in declaration order.
struct EncodingConfig {
  /// Bit rate.
  int bit_rate{-1};
  /// Compression level.
  int compression_level{-1};

  // Video-only settings (the audio codec configuration does not read these).

  /// Number of pictures in a group of pictures.
  int gop_size{-1};
  /// Maximum number of B-frames between non-B-frames.
  int max_b_frames{-1};
};
} // namespace torchaudio::io
from dataclasses import dataclass
from typing import BinaryIO, Dict, Optional, Union from typing import BinaryIO, Dict, Optional, Union
import torch import torch
import torchaudio import torchaudio
if torchaudio._extension._FFMPEG_INITIALIZED:
ConfigBase = torchaudio.lib._torchaudio_ffmpeg.EncodingConfig
else:
ConfigBase = object
def _format_doc(**kwargs): def _format_doc(**kwargs):
def decorator(obj): def decorator(obj):
obj.__doc__ = obj.__doc__.format(**kwargs) obj.__doc__ = obj.__doc__.format(**kwargs)
...@@ -103,6 +110,25 @@ class StreamWriter: ...@@ -103,6 +110,25 @@ class StreamWriter:
Default: `4096`. Default: `4096`.
""" """
@dataclass
class EncodeConfig(ConfigBase):
    """Encoding configuration.

    Each field defaults to ``-1``. The field values are forwarded to the
    base class once, from ``__post_init__``.

    NOTE(review): attribute assignments made after construction are presumably
    not propagated to the base object, since forwarding happens only in
    ``__post_init__`` -- confirm before relying on mutation.
    """

    bit_rate: int = -1
    """Bit rate"""

    compression_level: int = -1
    """Compression level"""

    gop_size: int = -1
    """The number of pictures in a group of pictures, or 0 for intra_only"""

    max_b_frames: int = -1
    """maximum number of B-frames between non-B-frames."""

    def __post_init__(self):
        # ``ConfigBase`` falls back to plain ``object`` when the FFmpeg
        # extension is not initialized. ``object.__init__`` rejects extra
        # arguments, so forward the values only when a real base is present.
        if ConfigBase is not object:
            super().__init__(self.bit_rate, self.compression_level, self.gop_size, self.max_b_frames)
def __init__( def __init__(
self, self,
dst: Union[str, BinaryIO], dst: Union[str, BinaryIO],
...@@ -126,6 +152,7 @@ class StreamWriter: ...@@ -126,6 +152,7 @@ class StreamWriter:
encoder: Optional[str] = None, encoder: Optional[str] = None,
encoder_option: Optional[Dict[str, str]] = None, encoder_option: Optional[Dict[str, str]] = None,
encoder_format: Optional[str] = None, encoder_format: Optional[str] = None,
config: Optional[EncodeConfig] = None,
): ):
"""Add an output audio stream. """Add an output audio stream.
...@@ -152,7 +179,7 @@ class StreamWriter: ...@@ -152,7 +179,7 @@ class StreamWriter:
encoder_format (str or None, optional): {encoder_format} encoder_format (str or None, optional): {encoder_format}
""" """
self._s.add_audio_stream(sample_rate, num_channels, format, encoder, encoder_option, encoder_format) self._s.add_audio_stream(sample_rate, num_channels, format, encoder, encoder_option, encoder_format, config)
@_format_common_args @_format_common_args
def add_video_stream( def add_video_stream(
...@@ -165,6 +192,7 @@ class StreamWriter: ...@@ -165,6 +192,7 @@ class StreamWriter:
encoder_option: Optional[Dict[str, str]] = None, encoder_option: Optional[Dict[str, str]] = None,
encoder_format: Optional[str] = None, encoder_format: Optional[str] = None,
hw_accel: Optional[str] = None, hw_accel: Optional[str] = None,
config: Optional[EncodeConfig] = None,
): ):
"""Add an output video stream. """Add an output video stream.
...@@ -206,7 +234,9 @@ class StreamWriter: ...@@ -206,7 +234,9 @@ class StreamWriter:
If `None`, the video chunk Tensor has to be CPU Tensor. If `None`, the video chunk Tensor has to be CPU Tensor.
Default: ``None``. Default: ``None``.
""" """
self._s.add_video_stream(frame_rate, width, height, format, encoder, encoder_option, encoder_format, hw_accel) self._s.add_video_stream(
frame_rate, width, height, format, encoder, encoder_option, encoder_format, hw_accel, config
)
def set_metadata(self, metadata: Dict[str, str]): def set_metadata(self, metadata: Dict[str, str]):
"""Set file-level metadata """Set file-level metadata
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment