Commit b1de9f1a authored by hwangjeff's avatar hwangjeff Committed by Facebook GitHub Bot
Browse files

Revise encoder config arg and docstrings (#3203)

Summary:
For `StreamWriter`,
* Renames arg `config` to codec_config`.
* Renames struct `EncodingConfig` and dataclass `EncodeConfig` to `CodecConfig`.
* Adds docstrings for arg codec_config`.
* Updates `chunk` to `frames` in `write_*_chunk` methods.

Pull Request resolved: https://github.com/pytorch/audio/pull/3203

Reviewed By: mthrok

Differential Revision: D44350153

Pulled By: hwangjeff

fbshipit-source-id: 1b940b1366a43ec0565c362bfcbf62744088b343
parent 4eac61a3
......@@ -42,7 +42,7 @@ Methods
not item.startswith('_')
and item not in inherited_members
and item not in attributes
and item != "EncodeConfig"
and item != "CodecConfig"
%}
{{ item | underline("~") }}
......@@ -82,10 +82,10 @@ Support Structures
Support Structures
------------------
EncodeConfig
~~~~~~~~~~~~
CodecConfig
~~~~~~~~~~~
.. autoclass:: torchaudio.io::StreamWriter.EncodeConfig()
.. autoclass:: torchaudio.io::StreamWriter.CodecConfig()
:members:
{%- endif %}
......@@ -536,7 +536,7 @@ class StreamWriterInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC
# for that.
assert math.isclose(val, ref)
def test_encode_config(self):
def test_codec_config(self):
"""Can successfully set configuration and write audio."""
ext = "mp3"
filename = f"test.{ext}"
......@@ -546,14 +546,14 @@ class StreamWriterInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC
# Write data
dst = self.get_dst(filename)
writer = torchaudio.io.StreamWriter(dst=dst, format=ext)
config = torchaudio.io.StreamWriter.EncodeConfig(bit_rate=198_000, compression_level=3)
writer.add_audio_stream(sample_rate=sample_rate, num_channels=num_channels, config=config)
codec_config = torchaudio.io.StreamWriter.CodecConfig(bit_rate=198_000, compression_level=3)
writer.add_audio_stream(sample_rate=sample_rate, num_channels=num_channels, codec_config=codec_config)
audio = torch.zeros((8000, 2))
with writer.open():
writer.write_audio_chunk(0, audio)
def test_encode_config_bit_rate_output(self):
def test_codec_config_bit_rate_output(self):
"""Increasing the specified bit rate yields a larger encoded output."""
ext = "mp3"
sample_rate = 44100
......@@ -565,7 +565,7 @@ class StreamWriterInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC
writer.add_audio_stream(
sample_rate=sample_rate,
num_channels=num_channels,
config=torchaudio.io.StreamWriter.EncodeConfig(bit_rate=bit_rate),
codec_config=torchaudio.io.StreamWriter.CodecConfig(bit_rate=bit_rate),
)
with writer.open():
......
......@@ -35,7 +35,7 @@ PYBIND11_MODULE(_torchaudio_ffmpeg, m) {
py::class_<Chunk>(m, "Chunk", py::module_local())
.def_readwrite("frames", &Chunk::frames)
.def_readwrite("pts", &Chunk::pts);
py::class_<EncodingConfig>(m, "EncodingConfig", py::module_local())
py::class_<CodecConfig>(m, "CodecConfig", py::module_local())
.def(py::init<int, int, int, int>());
py::class_<StreamWriter>(m, "StreamWriter", py::module_local())
.def(py::init<const std::string&, const c10::optional<std::string>&>())
......
......@@ -309,7 +309,7 @@ void configure_audio_codec_ctx(
int sample_rate,
int num_channels,
uint64_t channel_layout,
const c10::optional<EncodingConfig>& config) {
const c10::optional<CodecConfig>& codec_config) {
codec_ctx->sample_fmt = format;
codec_ctx->sample_rate = sample_rate;
codec_ctx->time_base = av_inv_q(av_d2q(sample_rate, 1 << 24));
......@@ -317,8 +317,8 @@ void configure_audio_codec_ctx(
codec_ctx->channel_layout = channel_layout;
// Set optional stuff
if (config) {
auto& cfg = config.value();
if (codec_config) {
auto& cfg = codec_config.value();
if (cfg.bit_rate > 0) {
codec_ctx->bit_rate = cfg.bit_rate;
}
......@@ -411,7 +411,7 @@ void configure_video_codec_ctx(
AVRational frame_rate,
int width,
int height,
const c10::optional<EncodingConfig>& config) {
const c10::optional<CodecConfig>& codec_config) {
// TODO: Review other options and make them configurable?
// https://ffmpeg.org/doxygen/4.1/muxing_8c_source.html#l00147
// - bit_rate_tolerance
......@@ -423,8 +423,8 @@ void configure_video_codec_ctx(
ctx->time_base = av_inv_q(frame_rate);
// Set optional stuff
if (config) {
auto& cfg = config.value();
if (codec_config) {
auto& cfg = codec_config.value();
if (cfg.bit_rate > 0) {
ctx->bit_rate = cfg.bit_rate;
}
......@@ -624,7 +624,7 @@ EncodeProcess get_audio_encode_process(
const c10::optional<std::string>& encoder,
const c10::optional<OptionDict>& encoder_option,
const c10::optional<std::string>& encoder_format,
const c10::optional<EncodingConfig>& config) {
const c10::optional<CodecConfig>& codec_config) {
// 1. Check the source format, rate and channels
const AVSampleFormat src_fmt = get_sample_fmt(format);
TORCH_CHECK(
......@@ -658,7 +658,7 @@ EncodeProcess get_audio_encode_process(
src_sample_rate,
src_num_channels,
channel_layout,
config);
codec_config);
open_codec(codec_ctx, encoder_option);
// 5. Build filter graph
......@@ -701,7 +701,7 @@ EncodeProcess get_video_encode_process(
const c10::optional<OptionDict>& encoder_option,
const c10::optional<std::string>& encoder_format,
const c10::optional<std::string>& hw_accel,
const c10::optional<EncodingConfig>& config) {
const c10::optional<CodecConfig>& codec_config) {
// 1. Checkc the source format, rate and resolution
const AVPixelFormat src_fmt = get_pix_fmt(format);
AVRational src_rate = av_d2q(frame_rate, 1 << 24);
......@@ -727,7 +727,7 @@ EncodeProcess get_video_encode_process(
AVCodecContextPtr codec_ctx =
get_codec_ctx(codec, format_ctx->oformat->flags);
configure_video_codec_ctx(
codec_ctx, enc_fmt, src_rate, src_width, src_height, config);
codec_ctx, enc_fmt, src_rate, src_width, src_height, codec_config);
if (hw_accel) {
#ifdef USE_CUDA
configure_hw_accel(codec_ctx, hw_accel.value());
......
......@@ -41,7 +41,7 @@ EncodeProcess get_audio_encode_process(
const c10::optional<std::string>& encoder,
const c10::optional<OptionDict>& encoder_option,
const c10::optional<std::string>& encoder_format,
const c10::optional<EncodingConfig>& config);
const c10::optional<CodecConfig>& codec_config);
EncodeProcess get_video_encode_process(
AVFormatContext* format_ctx,
......@@ -53,6 +53,6 @@ EncodeProcess get_video_encode_process(
const c10::optional<OptionDict>& encoder_option,
const c10::optional<std::string>& encoder_format,
const c10::optional<std::string>& hw_accel,
const c10::optional<EncodingConfig>& config);
const c10::optional<CodecConfig>& codec_config);
}; // namespace torchaudio::io
......@@ -60,7 +60,7 @@ void StreamWriter::add_audio_stream(
const c10::optional<std::string>& encoder,
const c10::optional<OptionDict>& encoder_option,
const c10::optional<std::string>& encoder_format,
const c10::optional<EncodingConfig>& config) {
const c10::optional<CodecConfig>& codec_config) {
TORCH_CHECK(!is_open, "Output is already opened. Cannot add a new stream.");
TORCH_INTERNAL_ASSERT(
pFormatContext->nb_streams == processes.size(),
......@@ -73,7 +73,7 @@ void StreamWriter::add_audio_stream(
encoder,
encoder_option,
encoder_format,
config));
codec_config));
}
void StreamWriter::add_video_stream(
......@@ -85,7 +85,7 @@ void StreamWriter::add_video_stream(
const c10::optional<OptionDict>& encoder_option,
const c10::optional<std::string>& encoder_format,
const c10::optional<std::string>& hw_accel,
const c10::optional<EncodingConfig>& config) {
const c10::optional<CodecConfig>& codec_config) {
TORCH_CHECK(!is_open, "Output is already opened. Cannot add a new stream.");
TORCH_INTERNAL_ASSERT(
pFormatContext->nb_streams == processes.size(),
......@@ -100,7 +100,7 @@ void StreamWriter::add_video_stream(
encoder_option,
encoder_format,
hw_accel,
config));
codec_config));
}
void StreamWriter::set_metadata(const OptionDict& metadata) {
......
......@@ -99,6 +99,7 @@ class StreamWriter {
/// override the format used for encoding.
/// To list supported formats for the encoder, you can use
/// ``ffmpeg -h encoder=<ENCODER>`` command.
/// @param codec_config Codec configuration.
void add_audio_stream(
int sample_rate,
int num_channels,
......@@ -106,7 +107,7 @@ class StreamWriter {
const c10::optional<std::string>& encoder,
const c10::optional<OptionDict>& encoder_option,
const c10::optional<std::string>& encoder_format,
const c10::optional<EncodingConfig>& config);
const c10::optional<CodecConfig>& codec_config);
/// Add an output video stream.
///
......@@ -129,6 +130,7 @@ class StreamWriter {
/// @param encoder_option See ``add_audio_stream()``.
/// @param encoder_format See ``add_audio_stream()``.
/// @param hw_accel Enable hardware acceleration.
/// @param codec_config Codec configuration.
/// @parblock
/// When video is encoded on CUDA hardware, for example
/// `encoder="h264_nvenc"`, passing CUDA device indicator to `hw_accel`
......@@ -146,7 +148,7 @@ class StreamWriter {
const c10::optional<OptionDict>& encoder_option,
const c10::optional<std::string>& encoder_format,
const c10::optional<std::string>& hw_accel,
const c10::optional<EncodingConfig>& config);
const c10::optional<CodecConfig>& codec_config);
/// Set file-level metadata
/// @param metadata metadata.
void set_metadata(const OptionDict& metadata);
......@@ -164,7 +166,7 @@ class StreamWriter {
/// Write audio data
/// @param i Stream index.
/// @param chunk Waveform tensor. Shape: ``(frame, channel)``.
/// @param frames Waveform tensor. Shape: ``(frame, channel)``.
/// The ``dtype`` must match what was passed to ``add_audio_stream()`` method.
/// @param pts
/// @parblock
......@@ -183,7 +185,7 @@ class StreamWriter {
const c10::optional<double>& pts = {});
/// Write video data
/// @param i Stream index.
/// @param chunk Video/image tensor. Shape: ``(time, channel, height,
/// @param frames Video/image tensor. Shape: ``(time, channel, height,
/// width)``. The ``dtype`` must be ``torch.uint8``. The shape ``(height,
/// width and the number of channels)`` must match what was configured when
/// calling ``add_video_stream()``.
......
#pragma once
namespace torchaudio::io {
struct EncodingConfig {
struct CodecConfig {
int bit_rate = -1;
int compression_level = -1;
......
......@@ -6,7 +6,7 @@ import torchaudio
if torchaudio._extension._FFMPEG_INITIALIZED:
ConfigBase = torchaudio.lib._torchaudio_ffmpeg.EncodingConfig
ConfigBase = torchaudio.lib._torchaudio_ffmpeg.CodecConfig
else:
ConfigBase = object
......@@ -47,11 +47,17 @@ _encoder_format = """Format used to encode media.
Default: ``None``."""
_codec_config = """Codec configuration. Please refer to :py:class:`CodecConfig` for
configuration options.
Default: ``None``."""
_format_common_args = _format_doc(
encoder=_encoder,
encoder_option=_encoder_option,
encoder_format=_encoder_format,
codec_config=_codec_config,
)
......@@ -111,8 +117,8 @@ class StreamWriter:
"""
@dataclass
class EncodeConfig(ConfigBase):
"""Encoding configuration."""
class CodecConfig(ConfigBase):
"""Codec configuration."""
bit_rate: int = -1
"""Bit rate"""
......@@ -152,7 +158,7 @@ class StreamWriter:
encoder: Optional[str] = None,
encoder_option: Optional[Dict[str, str]] = None,
encoder_format: Optional[str] = None,
config: Optional[EncodeConfig] = None,
codec_config: Optional[CodecConfig] = None,
):
"""Add an output audio stream.
......@@ -178,8 +184,12 @@ class StreamWriter:
encoder_option (dict or None, optional): {encoder_option}
encoder_format (str or None, optional): {encoder_format}
codec_config (CodecConfig or None, optional): {codec_config}
"""
self._s.add_audio_stream(sample_rate, num_channels, format, encoder, encoder_option, encoder_format, config)
self._s.add_audio_stream(
sample_rate, num_channels, format, encoder, encoder_option, encoder_format, codec_config
)
@_format_common_args
def add_video_stream(
......@@ -192,7 +202,7 @@ class StreamWriter:
encoder_option: Optional[Dict[str, str]] = None,
encoder_format: Optional[str] = None,
hw_accel: Optional[str] = None,
config: Optional[EncodeConfig] = None,
codec_config: Optional[CodecConfig] = None,
):
"""Add an output video stream.
......@@ -233,9 +243,11 @@ class StreamWriter:
If `None`, the video chunk Tensor has to be CPU Tensor.
Default: ``None``.
codec_config (CodecConfig or None, optional): {codec_config}
"""
self._s.add_video_stream(
frame_rate, width, height, format, encoder, encoder_option, encoder_format, hw_accel, config
frame_rate, width, height, format, encoder, encoder_option, encoder_format, hw_accel, codec_config
)
def set_metadata(self, metadata: Dict[str, str]):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment