Commit b1de9f1a authored by hwangjeff's avatar hwangjeff Committed by Facebook GitHub Bot
Browse files

Revise encoder config arg and docstrings (#3203)

Summary:
For `StreamWriter`,
* Renames arg `config` to codec_config`.
* Renames struct `EncodingConfig` and dataclass `EncodeConfig` to `CodecConfig`.
* Adds docstrings for arg codec_config`.
* Updates `chunk` to `frames` in `write_*_chunk` methods.

Pull Request resolved: https://github.com/pytorch/audio/pull/3203

Reviewed By: mthrok

Differential Revision: D44350153

Pulled By: hwangjeff

fbshipit-source-id: 1b940b1366a43ec0565c362bfcbf62744088b343
parent 4eac61a3
...@@ -42,7 +42,7 @@ Methods ...@@ -42,7 +42,7 @@ Methods
not item.startswith('_') not item.startswith('_')
and item not in inherited_members and item not in inherited_members
and item not in attributes and item not in attributes
and item != "EncodeConfig" and item != "CodecConfig"
%} %}
{{ item | underline("~") }} {{ item | underline("~") }}
...@@ -82,10 +82,10 @@ Support Structures ...@@ -82,10 +82,10 @@ Support Structures
Support Structures Support Structures
------------------ ------------------
EncodeConfig CodecConfig
~~~~~~~~~~~~ ~~~~~~~~~~~
.. autoclass:: torchaudio.io::StreamWriter.EncodeConfig() .. autoclass:: torchaudio.io::StreamWriter.CodecConfig()
:members: :members:
{%- endif %} {%- endif %}
...@@ -536,7 +536,7 @@ class StreamWriterInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC ...@@ -536,7 +536,7 @@ class StreamWriterInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC
# for that. # for that.
assert math.isclose(val, ref) assert math.isclose(val, ref)
def test_encode_config(self): def test_codec_config(self):
"""Can successfully set configuration and write audio.""" """Can successfully set configuration and write audio."""
ext = "mp3" ext = "mp3"
filename = f"test.{ext}" filename = f"test.{ext}"
...@@ -546,14 +546,14 @@ class StreamWriterInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC ...@@ -546,14 +546,14 @@ class StreamWriterInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC
# Write data # Write data
dst = self.get_dst(filename) dst = self.get_dst(filename)
writer = torchaudio.io.StreamWriter(dst=dst, format=ext) writer = torchaudio.io.StreamWriter(dst=dst, format=ext)
config = torchaudio.io.StreamWriter.EncodeConfig(bit_rate=198_000, compression_level=3) codec_config = torchaudio.io.StreamWriter.CodecConfig(bit_rate=198_000, compression_level=3)
writer.add_audio_stream(sample_rate=sample_rate, num_channels=num_channels, config=config) writer.add_audio_stream(sample_rate=sample_rate, num_channels=num_channels, codec_config=codec_config)
audio = torch.zeros((8000, 2)) audio = torch.zeros((8000, 2))
with writer.open(): with writer.open():
writer.write_audio_chunk(0, audio) writer.write_audio_chunk(0, audio)
def test_encode_config_bit_rate_output(self): def test_codec_config_bit_rate_output(self):
"""Increasing the specified bit rate yields a larger encoded output.""" """Increasing the specified bit rate yields a larger encoded output."""
ext = "mp3" ext = "mp3"
sample_rate = 44100 sample_rate = 44100
...@@ -565,7 +565,7 @@ class StreamWriterInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC ...@@ -565,7 +565,7 @@ class StreamWriterInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC
writer.add_audio_stream( writer.add_audio_stream(
sample_rate=sample_rate, sample_rate=sample_rate,
num_channels=num_channels, num_channels=num_channels,
config=torchaudio.io.StreamWriter.EncodeConfig(bit_rate=bit_rate), codec_config=torchaudio.io.StreamWriter.CodecConfig(bit_rate=bit_rate),
) )
with writer.open(): with writer.open():
......
...@@ -35,7 +35,7 @@ PYBIND11_MODULE(_torchaudio_ffmpeg, m) { ...@@ -35,7 +35,7 @@ PYBIND11_MODULE(_torchaudio_ffmpeg, m) {
py::class_<Chunk>(m, "Chunk", py::module_local()) py::class_<Chunk>(m, "Chunk", py::module_local())
.def_readwrite("frames", &Chunk::frames) .def_readwrite("frames", &Chunk::frames)
.def_readwrite("pts", &Chunk::pts); .def_readwrite("pts", &Chunk::pts);
py::class_<EncodingConfig>(m, "EncodingConfig", py::module_local()) py::class_<CodecConfig>(m, "CodecConfig", py::module_local())
.def(py::init<int, int, int, int>()); .def(py::init<int, int, int, int>());
py::class_<StreamWriter>(m, "StreamWriter", py::module_local()) py::class_<StreamWriter>(m, "StreamWriter", py::module_local())
.def(py::init<const std::string&, const c10::optional<std::string>&>()) .def(py::init<const std::string&, const c10::optional<std::string>&>())
......
...@@ -309,7 +309,7 @@ void configure_audio_codec_ctx( ...@@ -309,7 +309,7 @@ void configure_audio_codec_ctx(
int sample_rate, int sample_rate,
int num_channels, int num_channels,
uint64_t channel_layout, uint64_t channel_layout,
const c10::optional<EncodingConfig>& config) { const c10::optional<CodecConfig>& codec_config) {
codec_ctx->sample_fmt = format; codec_ctx->sample_fmt = format;
codec_ctx->sample_rate = sample_rate; codec_ctx->sample_rate = sample_rate;
codec_ctx->time_base = av_inv_q(av_d2q(sample_rate, 1 << 24)); codec_ctx->time_base = av_inv_q(av_d2q(sample_rate, 1 << 24));
...@@ -317,8 +317,8 @@ void configure_audio_codec_ctx( ...@@ -317,8 +317,8 @@ void configure_audio_codec_ctx(
codec_ctx->channel_layout = channel_layout; codec_ctx->channel_layout = channel_layout;
// Set optional stuff // Set optional stuff
if (config) { if (codec_config) {
auto& cfg = config.value(); auto& cfg = codec_config.value();
if (cfg.bit_rate > 0) { if (cfg.bit_rate > 0) {
codec_ctx->bit_rate = cfg.bit_rate; codec_ctx->bit_rate = cfg.bit_rate;
} }
...@@ -411,7 +411,7 @@ void configure_video_codec_ctx( ...@@ -411,7 +411,7 @@ void configure_video_codec_ctx(
AVRational frame_rate, AVRational frame_rate,
int width, int width,
int height, int height,
const c10::optional<EncodingConfig>& config) { const c10::optional<CodecConfig>& codec_config) {
// TODO: Review other options and make them configurable? // TODO: Review other options and make them configurable?
// https://ffmpeg.org/doxygen/4.1/muxing_8c_source.html#l00147 // https://ffmpeg.org/doxygen/4.1/muxing_8c_source.html#l00147
// - bit_rate_tolerance // - bit_rate_tolerance
...@@ -423,8 +423,8 @@ void configure_video_codec_ctx( ...@@ -423,8 +423,8 @@ void configure_video_codec_ctx(
ctx->time_base = av_inv_q(frame_rate); ctx->time_base = av_inv_q(frame_rate);
// Set optional stuff // Set optional stuff
if (config) { if (codec_config) {
auto& cfg = config.value(); auto& cfg = codec_config.value();
if (cfg.bit_rate > 0) { if (cfg.bit_rate > 0) {
ctx->bit_rate = cfg.bit_rate; ctx->bit_rate = cfg.bit_rate;
} }
...@@ -624,7 +624,7 @@ EncodeProcess get_audio_encode_process( ...@@ -624,7 +624,7 @@ EncodeProcess get_audio_encode_process(
const c10::optional<std::string>& encoder, const c10::optional<std::string>& encoder,
const c10::optional<OptionDict>& encoder_option, const c10::optional<OptionDict>& encoder_option,
const c10::optional<std::string>& encoder_format, const c10::optional<std::string>& encoder_format,
const c10::optional<EncodingConfig>& config) { const c10::optional<CodecConfig>& codec_config) {
// 1. Check the source format, rate and channels // 1. Check the source format, rate and channels
const AVSampleFormat src_fmt = get_sample_fmt(format); const AVSampleFormat src_fmt = get_sample_fmt(format);
TORCH_CHECK( TORCH_CHECK(
...@@ -658,7 +658,7 @@ EncodeProcess get_audio_encode_process( ...@@ -658,7 +658,7 @@ EncodeProcess get_audio_encode_process(
src_sample_rate, src_sample_rate,
src_num_channels, src_num_channels,
channel_layout, channel_layout,
config); codec_config);
open_codec(codec_ctx, encoder_option); open_codec(codec_ctx, encoder_option);
// 5. Build filter graph // 5. Build filter graph
...@@ -701,7 +701,7 @@ EncodeProcess get_video_encode_process( ...@@ -701,7 +701,7 @@ EncodeProcess get_video_encode_process(
const c10::optional<OptionDict>& encoder_option, const c10::optional<OptionDict>& encoder_option,
const c10::optional<std::string>& encoder_format, const c10::optional<std::string>& encoder_format,
const c10::optional<std::string>& hw_accel, const c10::optional<std::string>& hw_accel,
const c10::optional<EncodingConfig>& config) { const c10::optional<CodecConfig>& codec_config) {
// 1. Checkc the source format, rate and resolution // 1. Checkc the source format, rate and resolution
const AVPixelFormat src_fmt = get_pix_fmt(format); const AVPixelFormat src_fmt = get_pix_fmt(format);
AVRational src_rate = av_d2q(frame_rate, 1 << 24); AVRational src_rate = av_d2q(frame_rate, 1 << 24);
...@@ -727,7 +727,7 @@ EncodeProcess get_video_encode_process( ...@@ -727,7 +727,7 @@ EncodeProcess get_video_encode_process(
AVCodecContextPtr codec_ctx = AVCodecContextPtr codec_ctx =
get_codec_ctx(codec, format_ctx->oformat->flags); get_codec_ctx(codec, format_ctx->oformat->flags);
configure_video_codec_ctx( configure_video_codec_ctx(
codec_ctx, enc_fmt, src_rate, src_width, src_height, config); codec_ctx, enc_fmt, src_rate, src_width, src_height, codec_config);
if (hw_accel) { if (hw_accel) {
#ifdef USE_CUDA #ifdef USE_CUDA
configure_hw_accel(codec_ctx, hw_accel.value()); configure_hw_accel(codec_ctx, hw_accel.value());
......
...@@ -41,7 +41,7 @@ EncodeProcess get_audio_encode_process( ...@@ -41,7 +41,7 @@ EncodeProcess get_audio_encode_process(
const c10::optional<std::string>& encoder, const c10::optional<std::string>& encoder,
const c10::optional<OptionDict>& encoder_option, const c10::optional<OptionDict>& encoder_option,
const c10::optional<std::string>& encoder_format, const c10::optional<std::string>& encoder_format,
const c10::optional<EncodingConfig>& config); const c10::optional<CodecConfig>& codec_config);
EncodeProcess get_video_encode_process( EncodeProcess get_video_encode_process(
AVFormatContext* format_ctx, AVFormatContext* format_ctx,
...@@ -53,6 +53,6 @@ EncodeProcess get_video_encode_process( ...@@ -53,6 +53,6 @@ EncodeProcess get_video_encode_process(
const c10::optional<OptionDict>& encoder_option, const c10::optional<OptionDict>& encoder_option,
const c10::optional<std::string>& encoder_format, const c10::optional<std::string>& encoder_format,
const c10::optional<std::string>& hw_accel, const c10::optional<std::string>& hw_accel,
const c10::optional<EncodingConfig>& config); const c10::optional<CodecConfig>& codec_config);
}; // namespace torchaudio::io }; // namespace torchaudio::io
...@@ -60,7 +60,7 @@ void StreamWriter::add_audio_stream( ...@@ -60,7 +60,7 @@ void StreamWriter::add_audio_stream(
const c10::optional<std::string>& encoder, const c10::optional<std::string>& encoder,
const c10::optional<OptionDict>& encoder_option, const c10::optional<OptionDict>& encoder_option,
const c10::optional<std::string>& encoder_format, const c10::optional<std::string>& encoder_format,
const c10::optional<EncodingConfig>& config) { const c10::optional<CodecConfig>& codec_config) {
TORCH_CHECK(!is_open, "Output is already opened. Cannot add a new stream."); TORCH_CHECK(!is_open, "Output is already opened. Cannot add a new stream.");
TORCH_INTERNAL_ASSERT( TORCH_INTERNAL_ASSERT(
pFormatContext->nb_streams == processes.size(), pFormatContext->nb_streams == processes.size(),
...@@ -73,7 +73,7 @@ void StreamWriter::add_audio_stream( ...@@ -73,7 +73,7 @@ void StreamWriter::add_audio_stream(
encoder, encoder,
encoder_option, encoder_option,
encoder_format, encoder_format,
config)); codec_config));
} }
void StreamWriter::add_video_stream( void StreamWriter::add_video_stream(
...@@ -85,7 +85,7 @@ void StreamWriter::add_video_stream( ...@@ -85,7 +85,7 @@ void StreamWriter::add_video_stream(
const c10::optional<OptionDict>& encoder_option, const c10::optional<OptionDict>& encoder_option,
const c10::optional<std::string>& encoder_format, const c10::optional<std::string>& encoder_format,
const c10::optional<std::string>& hw_accel, const c10::optional<std::string>& hw_accel,
const c10::optional<EncodingConfig>& config) { const c10::optional<CodecConfig>& codec_config) {
TORCH_CHECK(!is_open, "Output is already opened. Cannot add a new stream."); TORCH_CHECK(!is_open, "Output is already opened. Cannot add a new stream.");
TORCH_INTERNAL_ASSERT( TORCH_INTERNAL_ASSERT(
pFormatContext->nb_streams == processes.size(), pFormatContext->nb_streams == processes.size(),
...@@ -100,7 +100,7 @@ void StreamWriter::add_video_stream( ...@@ -100,7 +100,7 @@ void StreamWriter::add_video_stream(
encoder_option, encoder_option,
encoder_format, encoder_format,
hw_accel, hw_accel,
config)); codec_config));
} }
void StreamWriter::set_metadata(const OptionDict& metadata) { void StreamWriter::set_metadata(const OptionDict& metadata) {
......
...@@ -99,6 +99,7 @@ class StreamWriter { ...@@ -99,6 +99,7 @@ class StreamWriter {
/// override the format used for encoding. /// override the format used for encoding.
/// To list supported formats for the encoder, you can use /// To list supported formats for the encoder, you can use
/// ``ffmpeg -h encoder=<ENCODER>`` command. /// ``ffmpeg -h encoder=<ENCODER>`` command.
/// @param codec_config Codec configuration.
void add_audio_stream( void add_audio_stream(
int sample_rate, int sample_rate,
int num_channels, int num_channels,
...@@ -106,7 +107,7 @@ class StreamWriter { ...@@ -106,7 +107,7 @@ class StreamWriter {
const c10::optional<std::string>& encoder, const c10::optional<std::string>& encoder,
const c10::optional<OptionDict>& encoder_option, const c10::optional<OptionDict>& encoder_option,
const c10::optional<std::string>& encoder_format, const c10::optional<std::string>& encoder_format,
const c10::optional<EncodingConfig>& config); const c10::optional<CodecConfig>& codec_config);
/// Add an output video stream. /// Add an output video stream.
/// ///
...@@ -129,6 +130,7 @@ class StreamWriter { ...@@ -129,6 +130,7 @@ class StreamWriter {
/// @param encoder_option See ``add_audio_stream()``. /// @param encoder_option See ``add_audio_stream()``.
/// @param encoder_format See ``add_audio_stream()``. /// @param encoder_format See ``add_audio_stream()``.
/// @param hw_accel Enable hardware acceleration. /// @param hw_accel Enable hardware acceleration.
/// @param codec_config Codec configuration.
/// @parblock /// @parblock
/// When video is encoded on CUDA hardware, for example /// When video is encoded on CUDA hardware, for example
/// `encoder="h264_nvenc"`, passing CUDA device indicator to `hw_accel` /// `encoder="h264_nvenc"`, passing CUDA device indicator to `hw_accel`
...@@ -146,7 +148,7 @@ class StreamWriter { ...@@ -146,7 +148,7 @@ class StreamWriter {
const c10::optional<OptionDict>& encoder_option, const c10::optional<OptionDict>& encoder_option,
const c10::optional<std::string>& encoder_format, const c10::optional<std::string>& encoder_format,
const c10::optional<std::string>& hw_accel, const c10::optional<std::string>& hw_accel,
const c10::optional<EncodingConfig>& config); const c10::optional<CodecConfig>& codec_config);
/// Set file-level metadata /// Set file-level metadata
/// @param metadata metadata. /// @param metadata metadata.
void set_metadata(const OptionDict& metadata); void set_metadata(const OptionDict& metadata);
...@@ -164,7 +166,7 @@ class StreamWriter { ...@@ -164,7 +166,7 @@ class StreamWriter {
/// Write audio data /// Write audio data
/// @param i Stream index. /// @param i Stream index.
/// @param chunk Waveform tensor. Shape: ``(frame, channel)``. /// @param frames Waveform tensor. Shape: ``(frame, channel)``.
/// The ``dtype`` must match what was passed to ``add_audio_stream()`` method. /// The ``dtype`` must match what was passed to ``add_audio_stream()`` method.
/// @param pts /// @param pts
/// @parblock /// @parblock
...@@ -183,7 +185,7 @@ class StreamWriter { ...@@ -183,7 +185,7 @@ class StreamWriter {
const c10::optional<double>& pts = {}); const c10::optional<double>& pts = {});
/// Write video data /// Write video data
/// @param i Stream index. /// @param i Stream index.
/// @param chunk Video/image tensor. Shape: ``(time, channel, height, /// @param frames Video/image tensor. Shape: ``(time, channel, height,
/// width)``. The ``dtype`` must be ``torch.uint8``. The shape ``(height, /// width)``. The ``dtype`` must be ``torch.uint8``. The shape ``(height,
/// width and the number of channels)`` must match what was configured when /// width and the number of channels)`` must match what was configured when
/// calling ``add_video_stream()``. /// calling ``add_video_stream()``.
......
#pragma once #pragma once
namespace torchaudio::io { namespace torchaudio::io {
struct EncodingConfig { struct CodecConfig {
int bit_rate = -1; int bit_rate = -1;
int compression_level = -1; int compression_level = -1;
......
...@@ -6,7 +6,7 @@ import torchaudio ...@@ -6,7 +6,7 @@ import torchaudio
if torchaudio._extension._FFMPEG_INITIALIZED: if torchaudio._extension._FFMPEG_INITIALIZED:
ConfigBase = torchaudio.lib._torchaudio_ffmpeg.EncodingConfig ConfigBase = torchaudio.lib._torchaudio_ffmpeg.CodecConfig
else: else:
ConfigBase = object ConfigBase = object
...@@ -47,11 +47,17 @@ _encoder_format = """Format used to encode media. ...@@ -47,11 +47,17 @@ _encoder_format = """Format used to encode media.
Default: ``None``.""" Default: ``None``."""
_codec_config = """Codec configuration. Please refer to :py:class:`CodecConfig` for
configuration options.
Default: ``None``."""
_format_common_args = _format_doc( _format_common_args = _format_doc(
encoder=_encoder, encoder=_encoder,
encoder_option=_encoder_option, encoder_option=_encoder_option,
encoder_format=_encoder_format, encoder_format=_encoder_format,
codec_config=_codec_config,
) )
...@@ -111,8 +117,8 @@ class StreamWriter: ...@@ -111,8 +117,8 @@ class StreamWriter:
""" """
@dataclass @dataclass
class EncodeConfig(ConfigBase): class CodecConfig(ConfigBase):
"""Encoding configuration.""" """Codec configuration."""
bit_rate: int = -1 bit_rate: int = -1
"""Bit rate""" """Bit rate"""
...@@ -152,7 +158,7 @@ class StreamWriter: ...@@ -152,7 +158,7 @@ class StreamWriter:
encoder: Optional[str] = None, encoder: Optional[str] = None,
encoder_option: Optional[Dict[str, str]] = None, encoder_option: Optional[Dict[str, str]] = None,
encoder_format: Optional[str] = None, encoder_format: Optional[str] = None,
config: Optional[EncodeConfig] = None, codec_config: Optional[CodecConfig] = None,
): ):
"""Add an output audio stream. """Add an output audio stream.
...@@ -178,8 +184,12 @@ class StreamWriter: ...@@ -178,8 +184,12 @@ class StreamWriter:
encoder_option (dict or None, optional): {encoder_option} encoder_option (dict or None, optional): {encoder_option}
encoder_format (str or None, optional): {encoder_format} encoder_format (str or None, optional): {encoder_format}
codec_config (CodecConfig or None, optional): {codec_config}
""" """
self._s.add_audio_stream(sample_rate, num_channels, format, encoder, encoder_option, encoder_format, config) self._s.add_audio_stream(
sample_rate, num_channels, format, encoder, encoder_option, encoder_format, codec_config
)
@_format_common_args @_format_common_args
def add_video_stream( def add_video_stream(
...@@ -192,7 +202,7 @@ class StreamWriter: ...@@ -192,7 +202,7 @@ class StreamWriter:
encoder_option: Optional[Dict[str, str]] = None, encoder_option: Optional[Dict[str, str]] = None,
encoder_format: Optional[str] = None, encoder_format: Optional[str] = None,
hw_accel: Optional[str] = None, hw_accel: Optional[str] = None,
config: Optional[EncodeConfig] = None, codec_config: Optional[CodecConfig] = None,
): ):
"""Add an output video stream. """Add an output video stream.
...@@ -233,9 +243,11 @@ class StreamWriter: ...@@ -233,9 +243,11 @@ class StreamWriter:
If `None`, the video chunk Tensor has to be CPU Tensor. If `None`, the video chunk Tensor has to be CPU Tensor.
Default: ``None``. Default: ``None``.
codec_config (CodecConfig or None, optional): {codec_config}
""" """
self._s.add_video_stream( self._s.add_video_stream(
frame_rate, width, height, format, encoder, encoder_option, encoder_format, hw_accel, config frame_rate, width, height, format, encoder, encoder_option, encoder_format, hw_accel, codec_config
) )
def set_metadata(self, metadata: Dict[str, str]): def set_metadata(self, metadata: Dict[str, str]):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment