Unverified Commit c3cb2015 authored by moto's avatar moto Committed by GitHub
Browse files

Add encoding and bits_per_sample option to save function (#1226)

parent 4f9b5520
def name_func(func, _, params): def name_func(func, _, params):
return f'{func.__name__}_{"_".join(str(arg) for arg in params.args)}' return f'{func.__name__}_{"_".join(str(arg) for arg in params.args)}'
def get_enc_params(dtype):
if dtype == 'float32':
return 'PCM_F', 32
if dtype == 'int32':
return 'PCM_S', 32
if dtype == 'int16':
return 'PCM_S', 16
if dtype == 'uint8':
return 'PCM_U', 8
raise ValueError(f'Unexpected dtype: {dtype}')
...@@ -12,6 +12,7 @@ from torchaudio_unittest.common_utils import ( ...@@ -12,6 +12,7 @@ from torchaudio_unittest.common_utils import (
) )
from .common import ( from .common import (
name_func, name_func,
get_enc_params,
) )
...@@ -27,10 +28,11 @@ class TestRoundTripIO(TempDirMixin, PytorchTestCase): ...@@ -27,10 +28,11 @@ class TestRoundTripIO(TempDirMixin, PytorchTestCase):
def test_wav(self, dtype, sample_rate, num_channels): def test_wav(self, dtype, sample_rate, num_channels):
"""save/load round trip should not degrade data for wav formats""" """save/load round trip should not degrade data for wav formats"""
original = get_wav_data(dtype, num_channels, normalize=False) original = get_wav_data(dtype, num_channels, normalize=False)
enc, bps = get_enc_params(dtype)
data = original data = original
for i in range(10): for i in range(10):
path = self.get_temp_path(f'{i}.wav') path = self.get_temp_path(f'{i}.wav')
sox_io_backend.save(path, data, sample_rate) sox_io_backend.save(path, data, sample_rate, encoding=enc, bits_per_sample=bps)
data, sr = sox_io_backend.load(path, normalize=False) data, sr = sox_io_backend.load(path, normalize=False)
assert sr == sample_rate assert sr == sample_rate
self.assertEqual(original, data) self.assertEqual(original, data)
......
...@@ -17,6 +17,7 @@ from torchaudio_unittest.common_utils import ( ...@@ -17,6 +17,7 @@ from torchaudio_unittest.common_utils import (
) )
from .common import ( from .common import (
name_func, name_func,
get_enc_params,
) )
...@@ -35,8 +36,12 @@ def py_save_func( ...@@ -35,8 +36,12 @@ def py_save_func(
sample_rate: int, sample_rate: int,
channels_first: bool = True, channels_first: bool = True,
compression: Optional[float] = None, compression: Optional[float] = None,
encoding: Optional[str] = None,
bits_per_sample: Optional[int] = None,
): ):
torchaudio.save(filepath, tensor, sample_rate, channels_first, compression) torchaudio.save(
filepath, tensor, sample_rate, channels_first,
compression, None, encoding, bits_per_sample)
@skipIfNoExec('sox') @skipIfNoExec('sox')
...@@ -102,15 +107,16 @@ class SoxIO(TempDirMixin, TorchaudioTestCase): ...@@ -102,15 +107,16 @@ class SoxIO(TempDirMixin, TorchaudioTestCase):
torch.jit.script(py_save_func).save(script_path) torch.jit.script(py_save_func).save(script_path)
ts_save_func = torch.jit.load(script_path) ts_save_func = torch.jit.load(script_path)
expected = get_wav_data(dtype, num_channels) expected = get_wav_data(dtype, num_channels, normalize=False)
py_path = self.get_temp_path(f'test_save_py_{dtype}_{sample_rate}_{num_channels}.wav') py_path = self.get_temp_path(f'test_save_py_{dtype}_{sample_rate}_{num_channels}.wav')
ts_path = self.get_temp_path(f'test_save_ts_{dtype}_{sample_rate}_{num_channels}.wav') ts_path = self.get_temp_path(f'test_save_ts_{dtype}_{sample_rate}_{num_channels}.wav')
enc, bps = get_enc_params(dtype)
py_save_func(py_path, expected, sample_rate, True, None) py_save_func(py_path, expected, sample_rate, True, None, enc, bps)
ts_save_func(ts_path, expected, sample_rate, True, None) ts_save_func(ts_path, expected, sample_rate, True, None, enc, bps)
py_data, py_sr = load_wav(py_path) py_data, py_sr = load_wav(py_path, normalize=False)
ts_data, ts_sr = load_wav(ts_path) ts_data, ts_sr = load_wav(ts_path, normalize=False)
self.assertEqual(sample_rate, py_sr) self.assertEqual(sample_rate, py_sr)
self.assertEqual(sample_rate, ts_sr) self.assertEqual(sample_rate, ts_sr)
...@@ -131,8 +137,8 @@ class SoxIO(TempDirMixin, TorchaudioTestCase): ...@@ -131,8 +137,8 @@ class SoxIO(TempDirMixin, TorchaudioTestCase):
py_path = self.get_temp_path(f'test_save_py_{sample_rate}_{num_channels}_{compression_level}.flac') py_path = self.get_temp_path(f'test_save_py_{sample_rate}_{num_channels}_{compression_level}.flac')
ts_path = self.get_temp_path(f'test_save_ts_{sample_rate}_{num_channels}_{compression_level}.flac') ts_path = self.get_temp_path(f'test_save_ts_{sample_rate}_{num_channels}_{compression_level}.flac')
py_save_func(py_path, expected, sample_rate, True, compression_level) py_save_func(py_path, expected, sample_rate, True, compression_level, None, None)
ts_save_func(ts_path, expected, sample_rate, True, compression_level) ts_save_func(ts_path, expected, sample_rate, True, compression_level, None, None)
# converting to 32 bit because flac file has 24 bit depth which scipy cannot handle. # converting to 32 bit because flac file has 24 bit depth which scipy cannot handle.
py_path_wav = f'{py_path}.wav' py_path_wav = f'{py_path}.wav'
......
import sys
import subprocess import subprocess
import warnings import warnings
...@@ -32,6 +33,7 @@ def gen_audio_file( ...@@ -32,6 +33,7 @@ def gen_audio_file(
command = [ command = [
'sox', 'sox',
'-V3', # verbose '-V3', # verbose
'--no-dither', # disable automatic dithering
'-R', '-R',
# -R is supposed to be repeatable, though the implementation looks suspicious # -R is supposed to be repeatable, though the implementation looks suspicious
# and not setting the seed to a fixed value. # and not setting the seed to a fixed value.
...@@ -61,21 +63,23 @@ def gen_audio_file( ...@@ -61,21 +63,23 @@ def gen_audio_file(
] ]
if attenuation is not None: if attenuation is not None:
command += ['vol', f'-{attenuation}dB'] command += ['vol', f'-{attenuation}dB']
print(' '.join(command)) print(' '.join(command), file=sys.stderr)
subprocess.run(command, check=True) subprocess.run(command, check=True)
def convert_audio_file( def convert_audio_file(
src_path, dst_path, src_path, dst_path,
*, bit_depth=None, compression=None): *, encoding=None, bit_depth=None, compression=None):
"""Convert audio file with `sox` command.""" """Convert audio file with `sox` command."""
command = ['sox', '-V3', '-R', str(src_path)] command = ['sox', '-V3', '--no-dither', '-R', str(src_path)]
if encoding is not None:
command += ['--encoding', str(encoding)]
if bit_depth is not None: if bit_depth is not None:
command += ['--bits', str(bit_depth)] command += ['--bits', str(bit_depth)]
if compression is not None: if compression is not None:
command += ['--compression', str(compression)] command += ['--compression', str(compression)]
command += [dst_path] command += [dst_path]
print(' '.join(command)) print(' '.join(command), file=sys.stderr)
subprocess.run(command, check=True) subprocess.run(command, check=True)
......
import os import os
import warnings
from typing import Tuple, Optional from typing import Tuple, Optional
import torch import torch
...@@ -152,26 +151,6 @@ def load( ...@@ -152,26 +151,6 @@ def load(
filepath, frame_offset, num_frames, normalize, channels_first, format) filepath, frame_offset, num_frames, normalize, channels_first, format)
@torch.jit.unused
def _save(
filepath: str,
src: torch.Tensor,
sample_rate: int,
channels_first: bool = True,
compression: Optional[float] = None,
format: Optional[str] = None,
dtype: Optional[str] = None,
):
if hasattr(filepath, 'write'):
if format is None:
raise RuntimeError('`format` is required when saving to file object.')
torchaudio._torchaudio.save_audio_fileobj(
filepath, src, sample_rate, channels_first, compression, format, dtype)
else:
torch.ops.torchaudio.sox_io_save_audio_file(
os.fspath(filepath), src, sample_rate, channels_first, compression, format, dtype)
@_mod_utils.requires_module('torchaudio._torchaudio') @_mod_utils.requires_module('torchaudio._torchaudio')
def save( def save(
filepath: str, filepath: str,
...@@ -180,30 +159,11 @@ def save( ...@@ -180,30 +159,11 @@ def save(
channels_first: bool = True, channels_first: bool = True,
compression: Optional[float] = None, compression: Optional[float] = None,
format: Optional[str] = None, format: Optional[str] = None,
dtype: Optional[str] = None, encoding: Optional[str] = None,
bits_per_sample: Optional[int] = None,
): ):
"""Save audio data to file. """Save audio data to file.
Note:
Supported formats are;
* WAV, AMB
* 32-bit floating-point
* 32-bit signed integer
* 16-bit signed integer
* 8-bit unsigned integer
* MP3
* FLAC
* OGG/VORBIS
* SPHERE
* AMR-NB
To save ``MP3``, ``FLAC``, ``OGG/VORBIS``, and other codecs ``libsox`` does not
handle natively, your installation of ``torchaudio`` has to be linked to ``libsox``
and corresponding codec libraries such as ``libmad`` or ``libmp3lame`` etc.
Args: Args:
filepath (str or pathlib.Path): Path to save file. filepath (str or pathlib.Path): Path to save file.
This function also handles ``pathlib.Path`` objects, but is annotated This function also handles ``pathlib.Path`` objects, but is annotated
...@@ -215,32 +175,137 @@ def save( ...@@ -215,32 +175,137 @@ def save(
compression (Optional[float]): Used for formats other than WAV. compression (Optional[float]): Used for formats other than WAV.
This corresponds to ``-C`` option of ``sox`` command. This corresponds to ``-C`` option of ``sox`` command.
* | ``MP3``: Either bitrate (in ``kbps``) with quality factor, such as ``128.2``, or ``"mp3"``
| VBR encoding with quality factor such as ``-4.2``. Default: ``-4.5``. Either bitrate (in ``kbps``) with quality factor, such as ``128.2``, or
* | ``FLAC``: compression level. Whole number from ``0`` to ``8``. VBR encoding with quality factor such as ``-4.2``. Default: ``-4.5``.
| ``8`` is default and highest compression.
* | ``OGG/VORBIS``: number from ``-1`` to ``10``; ``-1`` is the highest compression ``"flac"``
| and lowest quality. Default: ``3``. Whole number from ``0`` to ``8``. ``8`` is default and highest compression.
``"ogg"``, ``"vorbis"``
Number from ``-1`` to ``10``; ``-1`` is the highest compression
and lowest quality. Default: ``3``.
See the detail at http://sox.sourceforge.net/soxformat.html. See the detail at http://sox.sourceforge.net/soxformat.html.
format (str, optional): Output audio format. format (str, optional): Override the audio format.
This is required when the output audio format cannot be infered from When ``filepath`` argument is path-like object, audio format is infered from
``filepath``, (such as file extension or ``name`` attribute of the given file object). file extension. If file extension is missing or different, you can specify the
dtype (str, optional): Output tensor dtype. correct format with this argument.
Valid values: ``"uint8", "int16", "int32", "float32", "float64", None``
``dtype=None`` means no conversion is performed. When ``filepath`` argument is file-like object, this argument is required.
``dtype`` parameter is only effective for ``float32`` Tensor.
Valid values are ``"wav"``, ``"mp3"``, ``"ogg"``, ``"vorbis"``, ``"amr-nb"``,
``"amb"``, ``"flac"`` and ``"sph"``.
encoding (str, optional): Changes the encoding for the supported formats.
This argument is effective only for supported formats, cush as ``"wav"``, ``""amb"``
and ``"sph"``. Valid values are;
- ``"PCM_S"`` (signed integer Linear PCM)
- ``"PCM_U"`` (unsigned integer Linear PCM)
- ``"PCM_F"`` (floating point PCM)
- ``"ULAW"`` (mu-law)
- ``"ALAW"`` (a-law)
Default values
If not provided, the default value is picked based on ``format`` and ``bits_per_sample``.
``"wav"``, ``"amb"``
- | If both ``encoding`` and ``bits_per_sample`` are not provided, the ``dtype`` of the
| Tensor is used to determine the default value.
- ``"PCM_U"`` if dtype is ``uint8``
- ``"PCM_S"`` if dtype is ``int16`` or ``int32`
- ``"PCM_F"`` if dtype is ``float32``
- ``"PCM_U"`` if ``bits_per_sample=8``
- ``"PCM_S"`` otherwise
``"sph"`` format;
- the default value is ``"PCM_S"``
bits_per_sample (int, optional): Changes the bit depth for the supported formats.
When ``format`` is one of ``"wav"``, ``"flac"``, ``"sph"``, or ``"amb"``, you can change the
bit depth. Valid values are ``8``, ``16``, ``32`` and ``64``.
Default Value;
If not provided, the default values are picked based on ``format`` and ``"encoding"``;
``"wav"``, ``"amb"``;
- | If both ``encoding`` and ``bits_per_sample`` are not provided, the ``dtype`` of the
| Tensor is used.
- ``8`` if dtype is ``uint8``
- ``16`` if dtype is ``int16``
- ``32`` if dtype is ``int32`` or ``float32``
- ``8`` if ``encoding`` is ``"PCM_U"``, ``"ULAW"`` or ``"ALAW"``
- ``16`` if ``encoding`` is ``"PCM_S"``
- ``32`` if ``encoding`` is ``"PCM_F"``
``"flac"`` format;
- the default value is ``24``
``"sph"`` format;
- ``16`` if ``encoding`` is ``"PCM_U"``, ``"PCM_S"``, ``"PCM_F"`` or not provided.
- ``8`` if ``encoding`` is ``"ULAW"`` or ``"ALAW"``
``"amb"`` format;
- ``8`` if ``encoding`` is ``"PCM_U"``, ``"ULAW"`` or ``"ALAW"``
- ``16`` if ``encoding`` is ``"PCM_S"`` or not provided.
- ``32`` if ``encoding`` is ``"PCM_F"``
Supported formats/encodings/bit depth/compression are;
``"wav"``, ``"amb"``
- 32-bit floating-point PCM
- 32-bit signed integer PCM
- 24-bit signed integer PCM
- 16-bit signed integer PCM
- 8-bit unsigned integer PCM
- 8-bit mu-law
- 8-bit a-law
Note: Default encoding/bit depth is determined by the dtype of the input Tensor.
``"mp3"``
Fixed bit rate (such as 128kHz) and variable bit rate compression.
Default: VBR with high quality.
``"flac"``
- 8-bit
- 16-bit
- 24-bit (default)
``"ogg"``, ``"vorbis"``
- Different quality level. Default: approx. 112kbps
``"sph"``
- 8-bit signed integer PCM
- 16-bit signed integer PCM
- 24-bit signed integer PCM
- 32-bit signed integer PCM (default)
- 8-bit mu-law
- 8-bit a-law
- 16-bit a-law
- 24-bit a-law
- 32-bit a-law
``"amr-nb"``
Bitrate ranging from 4.75 kbit/s to 12.2 kbit/s. Default: 4.75 kbit/s
Note:
To save into formats that ``libsox`` does not handle natively, (such as ``"mp3"``,
``"flac"``, ``"ogg"`` and ``"vorbis"``), your installation of ``torchaudio`` has
to be linked to ``libsox`` and corresponding codec libraries such as ``libmad``
or ``libmp3lame`` etc.
""" """
if src.dtype == torch.float32 and dtype is None:
warnings.warn(
'`dtype` default value will be changed to `int16` in 0.9 release.'
'Specify `dtype` to suppress this warning.'
)
if not torch.jit.is_scripting(): if not torch.jit.is_scripting():
_save(filepath, src, sample_rate, channels_first, compression, format, dtype) if hasattr(filepath, 'write'):
return torchaudio._torchaudio.save_audio_fileobj(
filepath, src, sample_rate, channels_first, compression,
format, encoding, bits_per_sample)
return
filepath = os.fspath(filepath)
torch.ops.torchaudio.sox_io_save_audio_file( torch.ops.torchaudio.sox_io_save_audio_file(
filepath, src, sample_rate, channels_first, compression, format, dtype) filepath, src, sample_rate, channels_first, compression, format, encoding, bits_per_sample)
@_mod_utils.requires_module('torchaudio._torchaudio') @_mod_utils.requires_module('torchaudio._torchaudio')
......
...@@ -9,6 +9,7 @@ set( ...@@ -9,6 +9,7 @@ set(
sox/utils.cpp sox/utils.cpp
sox/effects.cpp sox/effects.cpp
sox/effects_chain.cpp sox/effects_chain.cpp
sox/types.cpp
) )
if(BUILD_TRANSDUCER) if(BUILD_TRANSDUCER)
......
...@@ -68,21 +68,43 @@ int tensor_input_drain(sox_effect_t* effp, sox_sample_t* obuf, size_t* osamp) { ...@@ -68,21 +68,43 @@ int tensor_input_drain(sox_effect_t* effp, sox_sample_t* obuf, size_t* osamp) {
// Ensure that it's a multiple of the number of channels // Ensure that it's a multiple of the number of channels
*osamp -= *osamp % num_channels; *osamp -= *osamp % num_channels;
// Slice the input Tensor and unnormalize the values // Slice the input Tensor
const auto tensor_ = [&]() { const auto tensor_ = [&]() {
auto i_frame = index / num_channels; auto i_frame = index / num_channels;
auto num_frames = *osamp / num_channels; auto num_frames = *osamp / num_channels;
auto t = (priv->channels_first) auto t = (priv->channels_first)
? tensor.index({Slice(), Slice(i_frame, i_frame + num_frames)}).t() ? tensor.index({Slice(), Slice(i_frame, i_frame + num_frames)}).t()
: tensor.index({Slice(i_frame, i_frame + num_frames), Slice()}); : tensor.index({Slice(i_frame, i_frame + num_frames), Slice()});
return unnormalize_wav(t.reshape({-1})).contiguous(); return t.reshape({-1}).contiguous();
}(); }();
priv->index += *osamp;
// Write data to SoxEffectsChain buffer.
auto ptr = tensor_.data_ptr<int32_t>();
std::copy(ptr, ptr + *osamp, obuf);
// Convert to sox_sample_t (int32_t) and write to buffer
SOX_SAMPLE_LOCALS;
const auto dtype = tensor_.dtype();
if (dtype == torch::kFloat32) {
auto ptr = tensor_.data_ptr<float_t>();
for (size_t i = 0; i < *osamp; ++i) {
obuf[i] = SOX_FLOAT_32BIT_TO_SAMPLE(ptr[i], effp->clips);
}
} else if (dtype == torch::kInt32) {
auto ptr = tensor_.data_ptr<int32_t>();
for (size_t i = 0; i < *osamp; ++i) {
obuf[i] = SOX_SIGNED_32BIT_TO_SAMPLE(ptr[i], effp->clips);
}
} else if (dtype == torch::kInt16) {
auto ptr = tensor_.data_ptr<int16_t>();
for (size_t i = 0; i < *osamp; ++i) {
obuf[i] = SOX_SIGNED_16BIT_TO_SAMPLE(ptr[i], effp->clips);
}
} else if (dtype == torch::kUInt8) {
auto ptr = tensor_.data_ptr<uint8_t>();
for (size_t i = 0; i < *osamp; ++i) {
obuf[i] = SOX_UNSIGNED_8BIT_TO_SAMPLE(ptr[i], effp->clips);
}
} else {
throw std::runtime_error("Unexpected dtype.");
}
priv->index += *osamp;
return (priv->index == num_samples) ? SOX_EOF : SOX_SUCCESS; return (priv->index == num_samples) ? SOX_EOF : SOX_SUCCESS;
} }
...@@ -430,7 +452,7 @@ int fileobj_output_flow( ...@@ -430,7 +452,7 @@ int fileobj_output_flow(
fflush(fp); fflush(fp);
// Copy the encoded chunk to python object. // Copy the encoded chunk to python object.
fileobj->attr("write")(py::bytes(*buffer, *buffer_size)); fileobj->attr("write")(py::bytes(*buffer, ftell(fp)));
// Reset FILE* // Reset FILE*
sf->tell_off = 0; sf->tell_off = 0;
......
...@@ -116,35 +116,27 @@ void save_audio_file( ...@@ -116,35 +116,27 @@ void save_audio_file(
torch::Tensor tensor, torch::Tensor tensor,
int64_t sample_rate, int64_t sample_rate,
bool channels_first, bool channels_first,
c10::optional<double> compression, c10::optional<double>& compression,
c10::optional<std::string> format, c10::optional<std::string>& format,
c10::optional<std::string> dtype) { c10::optional<std::string>& encoding,
c10::optional<int64_t>& bits_per_sample) {
validate_input_tensor(tensor); validate_input_tensor(tensor);
if (tensor.dtype() != torch::kFloat32 && dtype.has_value()) {
throw std::runtime_error(
"dtype conversion only supported for float32 tensors");
}
const auto tgt_dtype =
(tensor.dtype() == torch::kFloat32 && dtype.has_value())
? get_dtype_from_str(dtype.value())
: tensor.dtype();
const auto filetype = [&]() { const auto filetype = [&]() {
if (format.has_value()) if (format.has_value())
return format.value(); return format.value();
return get_filetype(path); return get_filetype(path);
}(); }();
if (filetype == "amr-nb") { if (filetype == "amr-nb") {
const auto num_channels = tensor.size(channels_first ? 0 : 1); const auto num_channels = tensor.size(channels_first ? 0 : 1);
TORCH_CHECK( TORCH_CHECK(
num_channels == 1, "amr-nb format only supports single channel audio."); num_channels == 1, "amr-nb format only supports single channel audio.");
tensor = (unnormalize_wav(tensor) / 65536).to(torch::kInt16);
} }
const auto signal_info = const auto signal_info =
get_signalinfo(&tensor, sample_rate, filetype, channels_first); get_signalinfo(&tensor, sample_rate, filetype, channels_first);
const auto encoding_info = const auto encoding_info = get_encodinginfo_for_save(
get_encodinginfo_for_save(filetype, tgt_dtype, compression); filetype, tensor.dtype(), compression, encoding, bits_per_sample);
SoxFormat sf(sox_open_write( SoxFormat sf(sox_open_write(
path.c_str(), path.c_str(),
...@@ -258,19 +250,17 @@ void save_audio_fileobj( ...@@ -258,19 +250,17 @@ void save_audio_fileobj(
torch::Tensor tensor, torch::Tensor tensor,
int64_t sample_rate, int64_t sample_rate,
bool channels_first, bool channels_first,
c10::optional<double> compression, c10::optional<double>& compression,
std::string filetype, c10::optional<std::string>& format,
c10::optional<std::string> dtype) { c10::optional<std::string>& encoding,
c10::optional<int64_t>& bits_per_sample) {
validate_input_tensor(tensor); validate_input_tensor(tensor);
if (tensor.dtype() != torch::kFloat32 && dtype.has_value()) { if (!format.has_value()) {
throw std::runtime_error( throw std::runtime_error(
"dtype conversion only supported for float32 tensors"); "`format` is required when saving to file object.");
} }
const auto tgt_dtype = const auto filetype = format.value();
(tensor.dtype() == torch::kFloat32 && dtype.has_value())
? get_dtype_from_str(dtype.value())
: tensor.dtype();
if (filetype == "amr-nb") { if (filetype == "amr-nb") {
const auto num_channels = tensor.size(channels_first ? 0 : 1); const auto num_channels = tensor.size(channels_first ? 0 : 1);
...@@ -278,12 +268,11 @@ void save_audio_fileobj( ...@@ -278,12 +268,11 @@ void save_audio_fileobj(
throw std::runtime_error( throw std::runtime_error(
"amr-nb format only supports single channel audio."); "amr-nb format only supports single channel audio.");
} }
tensor = (unnormalize_wav(tensor) / 65536).to(torch::kInt16);
} }
const auto signal_info = const auto signal_info =
get_signalinfo(&tensor, sample_rate, filetype, channels_first); get_signalinfo(&tensor, sample_rate, filetype, channels_first);
const auto encoding_info = const auto encoding_info = get_encodinginfo_for_save(
get_encodinginfo_for_save(filetype, tgt_dtype, compression); filetype, tensor.dtype(), compression, encoding, bits_per_sample);
AutoReleaseBuffer buffer; AutoReleaseBuffer buffer;
......
...@@ -28,9 +28,10 @@ void save_audio_file( ...@@ -28,9 +28,10 @@ void save_audio_file(
torch::Tensor tensor, torch::Tensor tensor,
int64_t sample_rate, int64_t sample_rate,
bool channels_first, bool channels_first,
c10::optional<double> compression, c10::optional<double>& compression,
c10::optional<std::string> format, c10::optional<std::string>& format,
c10::optional<std::string> dtype); c10::optional<std::string>& encoding,
c10::optional<int64_t>& bits_per_sample);
#ifdef TORCH_API_INCLUDE_EXTENSION_H #ifdef TORCH_API_INCLUDE_EXTENSION_H
...@@ -51,9 +52,10 @@ void save_audio_fileobj( ...@@ -51,9 +52,10 @@ void save_audio_fileobj(
torch::Tensor tensor, torch::Tensor tensor,
int64_t sample_rate, int64_t sample_rate,
bool channels_first, bool channels_first,
c10::optional<double> compression, c10::optional<double>& compression,
std::string filetype, c10::optional<std::string>& format,
c10::optional<std::string> dtype); c10::optional<std::string>& encoding,
c10::optional<int64_t>& bits_per_sample);
#endif // TORCH_API_INCLUDE_EXTENSION_H #endif // TORCH_API_INCLUDE_EXTENSION_H
......
#include <torchaudio/csrc/sox/types.h>
namespace torchaudio {
namespace sox_utils {
Format get_format_from_string(const std::string& format) {
if (format == "wav")
return Format::WAV;
if (format == "mp3")
return Format::MP3;
if (format == "flac")
return Format::FLAC;
if (format == "ogg" || format == "vorbis")
return Format::VORBIS;
if (format == "amr-nb")
return Format::AMR_NB;
if (format == "amr-wb")
return Format::AMR_WB;
if (format == "amb")
return Format::AMB;
if (format == "sph")
return Format::SPHERE;
std::ostringstream stream;
stream << "Internal Error: unexpected format value: " << format;
throw std::runtime_error(stream.str());
}
std::string to_string(Encoding v) {
switch (v) {
case Encoding::UNKNOWN:
return "UNKNOWN";
case Encoding::PCM_SIGNED:
return "PCM_S";
case Encoding::PCM_UNSIGNED:
return "PCM_U";
case Encoding::PCM_FLOAT:
return "PCM_F";
case Encoding::FLAC:
return "FLAC";
case Encoding::ULAW:
return "ULAW";
case Encoding::ALAW:
return "ALAW";
case Encoding::MP3:
return "MP3";
case Encoding::VORBIS:
return "VORBIS";
case Encoding::AMR_WB:
return "AMR_WB";
case Encoding::AMR_NB:
return "AMR_NB";
case Encoding::OPUS:
return "OPUS";
default:
throw std::runtime_error("Internal Error: unexpected encoding.");
}
}
Encoding get_encoding_from_option(const c10::optional<std::string>& encoding) {
if (!encoding.has_value())
return Encoding::NOT_PROVIDED;
std::string v = encoding.value();
if (v == "PCM_S")
return Encoding::PCM_SIGNED;
if (v == "PCM_U")
return Encoding::PCM_UNSIGNED;
if (v == "PCM_F")
return Encoding::PCM_FLOAT;
if (v == "ULAW")
return Encoding::ULAW;
if (v == "ALAW")
return Encoding::ALAW;
std::ostringstream stream;
stream << "Internal Error: unexpected encoding value: " << v;
throw std::runtime_error(stream.str());
}
BitDepth get_bit_depth_from_option(const c10::optional<int64_t>& bit_depth) {
if (!bit_depth.has_value())
return BitDepth::NOT_PROVIDED;
int64_t v = bit_depth.value();
switch (v) {
case 8:
return BitDepth::B8;
case 16:
return BitDepth::B16;
case 24:
return BitDepth::B24;
case 32:
return BitDepth::B32;
case 64:
return BitDepth::B64;
default: {
std::ostringstream s;
s << "Internal Error: unexpected bit depth value: " << v;
throw std::runtime_error(s.str());
}
}
}
} // namespace sox_utils
} // namespace torchaudio
#ifndef TORCHAUDIO_SOX_TYPES_H
#define TORCHAUDIO_SOX_TYPES_H
#include <torch/script.h>
namespace torchaudio {
namespace sox_utils {
enum class Format {
WAV,
MP3,
FLAC,
VORBIS,
AMR_NB,
AMR_WB,
AMB,
SPHERE,
};
Format get_format_from_string(const std::string& format);
enum class Encoding {
NOT_PROVIDED,
UNKNOWN,
PCM_SIGNED,
PCM_UNSIGNED,
PCM_FLOAT,
FLAC,
ULAW,
ALAW,
MP3,
VORBIS,
AMR_WB,
AMR_NB,
OPUS,
};
std::string to_string(Encoding v);
Encoding get_encoding_from_option(const c10::optional<std::string>& encoding);
enum class BitDepth : unsigned {
NOT_PROVIDED = 0,
B8 = 8,
B16 = 16,
B24 = 24,
B32 = 32,
B64 = 64,
};
BitDepth get_bit_depth_from_option(const c10::optional<int64_t>& bit_depth);
} // namespace sox_utils
} // namespace torchaudio
#endif
#include <c10/core/ScalarType.h> #include <c10/core/ScalarType.h>
#include <sox.h> #include <sox.h>
#include <torchaudio/csrc/sox/types.h>
#include <torchaudio/csrc/sox/utils.h> #include <torchaudio/csrc/sox/utils.h>
namespace torchaudio { namespace torchaudio {
...@@ -163,22 +164,32 @@ torch::Tensor convert_to_tensor( ...@@ -163,22 +164,32 @@ torch::Tensor convert_to_tensor(
const caffe2::TypeMeta dtype, const caffe2::TypeMeta dtype,
const bool normalize, const bool normalize,
const bool channels_first) { const bool channels_first) {
auto t = torch::from_blob( torch::Tensor t;
buffer, {num_samples / num_channels, num_channels}, torch::kInt32); uint64_t dummy;
// Note: Tensor created from_blob does not own data but borrwos SOX_SAMPLE_LOCALS;
// So make sure to create a new copy after processing samples.
if (normalize || dtype == torch::kFloat32) { if (normalize || dtype == torch::kFloat32) {
t = t.to(torch::kFloat32); t = torch::empty(
t *= (t > 0) / 2147483647. + (t < 0) / 2147483648.; {num_samples / num_channels, num_channels}, torch::kFloat32);
auto ptr = t.data_ptr<float_t>();
for (int32_t i = 0; i < num_samples; ++i) {
ptr[i] = SOX_SAMPLE_TO_FLOAT_32BIT(buffer[i], dummy);
}
} else if (dtype == torch::kInt32) { } else if (dtype == torch::kInt32) {
t = t.clone(); t = torch::from_blob(
buffer, {num_samples / num_channels, num_channels}, torch::kInt32)
.clone();
} else if (dtype == torch::kInt16) { } else if (dtype == torch::kInt16) {
t.floor_divide_(1 << 16); t = torch::empty({num_samples / num_channels, num_channels}, torch::kInt16);
t = t.to(torch::kInt16); auto ptr = t.data_ptr<int16_t>();
for (int32_t i = 0; i < num_samples; ++i) {
ptr[i] = SOX_SAMPLE_TO_SIGNED_16BIT(buffer[i], dummy);
}
} else if (dtype == torch::kUInt8) { } else if (dtype == torch::kUInt8) {
t.floor_divide_(1 << 24); t = torch::empty({num_samples / num_channels, num_channels}, torch::kUInt8);
t += 128; auto ptr = t.data_ptr<uint8_t>();
t = t.to(torch::kUInt8); for (int32_t i = 0; i < num_samples; ++i) {
ptr[i] = SOX_SAMPLE_TO_UNSIGNED_8BIT(buffer[i], dummy);
}
} else { } else {
throw std::runtime_error("Unsupported dtype."); throw std::runtime_error("Unsupported dtype.");
} }
...@@ -188,63 +199,181 @@ torch::Tensor convert_to_tensor( ...@@ -188,63 +199,181 @@ torch::Tensor convert_to_tensor(
return t.contiguous(); return t.contiguous();
} }
torch::Tensor unnormalize_wav(const torch::Tensor input_tensor) {
const auto dtype = input_tensor.dtype();
auto tensor = input_tensor;
if (dtype == torch::kFloat32) {
double multi_pos = 2147483647.;
double multi_neg = -2147483648.;
auto mult = (tensor > 0) * multi_pos - (tensor < 0) * multi_neg;
tensor = tensor.to(torch::dtype(torch::kFloat64));
tensor *= mult;
tensor.clamp_(multi_neg, multi_pos);
tensor = tensor.to(torch::dtype(torch::kInt32));
} else if (dtype == torch::kInt32) {
// already denormalized
} else if (dtype == torch::kInt16) {
tensor = tensor.to(torch::dtype(torch::kInt32));
tensor *= ((tensor != 0) * 65536);
} else if (dtype == torch::kUInt8) {
tensor = tensor.to(torch::dtype(torch::kInt32));
tensor -= 128;
tensor *= 16777216;
} else {
throw std::runtime_error("Unexpected dtype.");
}
return tensor;
}
const std::string get_filetype(const std::string path) { const std::string get_filetype(const std::string path) {
std::string ext = path.substr(path.find_last_of(".") + 1); std::string ext = path.substr(path.find_last_of(".") + 1);
std::transform(ext.begin(), ext.end(), ext.begin(), ::tolower); std::transform(ext.begin(), ext.end(), ext.begin(), ::tolower);
return ext; return ext;
} }
sox_encoding_t get_encoding( namespace {
const std::string filetype,
const caffe2::TypeMeta dtype) { std::tuple<sox_encoding_t, unsigned> get_save_encoding_for_wav(
if (filetype == "mp3") const std::string format,
return SOX_ENCODING_MP3; const caffe2::TypeMeta dtype,
if (filetype == "flac") const Encoding& encoding,
return SOX_ENCODING_FLAC; const BitDepth& bits_per_sample) {
if (filetype == "ogg" || filetype == "vorbis") switch (encoding) {
return SOX_ENCODING_VORBIS; case Encoding::NOT_PROVIDED:
if (filetype == "wav" || filetype == "amb") { switch (bits_per_sample) {
if (dtype == torch::kUInt8) case BitDepth::NOT_PROVIDED:
return SOX_ENCODING_UNSIGNED; if (dtype == torch::kFloat32)
if (dtype == torch::kInt16) return std::make_tuple<>(SOX_ENCODING_FLOAT, 32);
return SOX_ENCODING_SIGN2; if (dtype == torch::kInt32)
if (dtype == torch::kInt32) return std::make_tuple<>(SOX_ENCODING_SIGN2, 32);
return SOX_ENCODING_SIGN2; if (dtype == torch::kInt16)
if (dtype == torch::kFloat32) return std::make_tuple<>(SOX_ENCODING_SIGN2, 16);
return SOX_ENCODING_FLOAT; if (dtype == torch::kUInt8)
throw std::runtime_error("Unsupported dtype."); return std::make_tuple<>(SOX_ENCODING_UNSIGNED, 8);
throw std::runtime_error("Internal Error: Unexpected dtype.");
case BitDepth::B8:
return std::make_tuple<>(SOX_ENCODING_UNSIGNED, 8);
default:
return std::make_tuple<>(
SOX_ENCODING_SIGN2, static_cast<unsigned>(bits_per_sample));
}
case Encoding::PCM_SIGNED:
switch (bits_per_sample) {
case BitDepth::NOT_PROVIDED:
return std::make_tuple<>(SOX_ENCODING_SIGN2, 32);
case BitDepth::B8:
throw std::runtime_error(
format + " does not support 8-bit signed PCM encoding.");
default:
return std::make_tuple<>(
SOX_ENCODING_SIGN2, static_cast<unsigned>(bits_per_sample));
}
case Encoding::PCM_UNSIGNED:
switch (bits_per_sample) {
case BitDepth::NOT_PROVIDED:
case BitDepth::B8:
return std::make_tuple<>(SOX_ENCODING_UNSIGNED, 8);
default:
throw std::runtime_error(
format + " only supports 8-bit for unsigned PCM encoding.");
}
case Encoding::PCM_FLOAT:
switch (bits_per_sample) {
case BitDepth::NOT_PROVIDED:
case BitDepth::B32:
return std::make_tuple<>(SOX_ENCODING_FLOAT, 32);
case BitDepth::B64:
return std::make_tuple<>(SOX_ENCODING_FLOAT, 64);
default:
throw std::runtime_error(
format +
" only supports 32-bit or 64-bit for floating-point PCM encoding.");
}
case Encoding::ULAW:
switch (bits_per_sample) {
case BitDepth::NOT_PROVIDED:
case BitDepth::B8:
return std::make_tuple<>(SOX_ENCODING_ULAW, 8);
default:
throw std::runtime_error(
format + " only supports 8-bit for mu-law encoding.");
}
case Encoding::ALAW:
switch (bits_per_sample) {
case BitDepth::NOT_PROVIDED:
case BitDepth::B8:
return std::make_tuple<>(SOX_ENCODING_ALAW, 8);
default:
throw std::runtime_error(
format + " only supports 8-bit for a-law encoding.");
}
default:
throw std::runtime_error(
format + " does not support encoding: " + to_string(encoding));
}
}
std::tuple<sox_encoding_t, unsigned> get_save_encoding(
const std::string& format,
const caffe2::TypeMeta dtype,
const c10::optional<std::string>& encoding,
const c10::optional<int64_t>& bits_per_sample) {
const Format fmt = get_format_from_string(format);
const Encoding enc = get_encoding_from_option(encoding);
const BitDepth bps = get_bit_depth_from_option(bits_per_sample);
switch (fmt) {
case Format::WAV:
case Format::AMB:
return get_save_encoding_for_wav(format, dtype, enc, bps);
case Format::MP3:
if (enc != Encoding::NOT_PROVIDED)
throw std::runtime_error("mp3 does not support `encoding` option.");
if (bps != BitDepth::NOT_PROVIDED)
throw std::runtime_error(
"mp3 does not support `bits_per_sample` option.");
return std::make_tuple<>(SOX_ENCODING_MP3, 16);
case Format::VORBIS:
if (enc != Encoding::NOT_PROVIDED)
throw std::runtime_error("vorbis does not support `encoding` option.");
if (bps != BitDepth::NOT_PROVIDED)
throw std::runtime_error(
"vorbis does not support `bits_per_sample` option.");
return std::make_tuple<>(SOX_ENCODING_VORBIS, 16);
case Format::AMR_NB:
if (enc != Encoding::NOT_PROVIDED)
throw std::runtime_error("amr-nb does not support `encoding` option.");
if (bps != BitDepth::NOT_PROVIDED)
throw std::runtime_error(
"amr-nb does not support `bits_per_sample` option.");
return std::make_tuple<>(SOX_ENCODING_AMR_NB, 16);
case Format::FLAC:
if (enc != Encoding::NOT_PROVIDED)
throw std::runtime_error("flac does not support `encoding` option.");
switch (bps) {
case BitDepth::B32:
case BitDepth::B64:
throw std::runtime_error(
"flac does not support `bits_per_sample` larger than 24.");
default:
return std::make_tuple<>(
SOX_ENCODING_FLAC, static_cast<unsigned>(bps));
}
case Format::SPHERE:
switch (enc) {
case Encoding::NOT_PROVIDED:
case Encoding::PCM_SIGNED:
switch (bps) {
case BitDepth::NOT_PROVIDED:
return std::make_tuple<>(SOX_ENCODING_SIGN2, 32);
default:
return std::make_tuple<>(
SOX_ENCODING_SIGN2, static_cast<unsigned>(bps));
}
case Encoding::PCM_UNSIGNED:
throw std::runtime_error(
"sph does not support unsigned integer PCM.");
case Encoding::PCM_FLOAT:
throw std::runtime_error("sph does not support floating point PCM.");
case Encoding::ULAW:
switch (bps) {
case BitDepth::NOT_PROVIDED:
case BitDepth::B8:
return std::make_tuple<>(SOX_ENCODING_ULAW, 8);
default:
throw std::runtime_error(
"sph only supports 8-bit for mu-law encoding.");
}
case Encoding::ALAW:
switch (bps) {
case BitDepth::NOT_PROVIDED:
case BitDepth::B8:
return std::make_tuple<>(SOX_ENCODING_ALAW, 8);
default:
return std::make_tuple<>(
SOX_ENCODING_ALAW, static_cast<unsigned>(bps));
}
default:
throw std::runtime_error(
"sph does not support encoding: " + encoding.value());
}
default:
throw std::runtime_error("Unsupported format: " + format);
} }
if (filetype == "sph")
return SOX_ENCODING_SIGN2;
if (filetype == "amr-nb")
return SOX_ENCODING_AMR_NB;
throw std::runtime_error("Unsupported file type: " + filetype);
} }
unsigned get_precision( unsigned get_precision(
...@@ -270,14 +399,13 @@ unsigned get_precision( ...@@ -270,14 +399,13 @@ unsigned get_precision(
if (filetype == "sph") if (filetype == "sph")
return 32; return 32;
if (filetype == "amr-nb") { if (filetype == "amr-nb") {
TORCH_INTERNAL_ASSERT(
dtype == torch::kInt16,
"When saving to AMR-NB format, the input tensor must be int16 type.");
return 16; return 16;
} }
throw std::runtime_error("Unsupported file type: " + filetype); throw std::runtime_error("Unsupported file type: " + filetype);
} }
} // namespace
sox_signalinfo_t get_signalinfo( sox_signalinfo_t get_signalinfo(
const torch::Tensor* waveform, const torch::Tensor* waveform,
const int64_t sample_rate, const int64_t sample_rate,
...@@ -325,12 +453,15 @@ sox_encodinginfo_t get_tensor_encodinginfo(const caffe2::TypeMeta dtype) { ...@@ -325,12 +453,15 @@ sox_encodinginfo_t get_tensor_encodinginfo(const caffe2::TypeMeta dtype) {
} }
sox_encodinginfo_t get_encodinginfo_for_save( sox_encodinginfo_t get_encodinginfo_for_save(
const std::string filetype, const std::string& format,
const caffe2::TypeMeta dtype, const caffe2::TypeMeta dtype,
c10::optional<double>& compression) { const c10::optional<double>& compression,
const c10::optional<std::string>& encoding,
const c10::optional<int64_t>& bits_per_sample) {
auto enc = get_save_encoding(format, dtype, encoding, bits_per_sample);
return sox_encodinginfo_t{ return sox_encodinginfo_t{
/*encoding=*/get_encoding(filetype, dtype), /*encoding=*/std::get<0>(enc),
/*bits_per_sample=*/get_precision(filetype, dtype), /*bits_per_sample=*/std::get<1>(enc),
/*compression=*/compression.value_or(HUGE_VAL), /*compression=*/compression.value_or(HUGE_VAL),
/*reverse_bytes=*/sox_option_default, /*reverse_bytes=*/sox_option_default,
/*reverse_nibbles=*/sox_option_default, /*reverse_nibbles=*/sox_option_default,
......
...@@ -93,11 +93,6 @@ torch::Tensor convert_to_tensor( ...@@ -93,11 +93,6 @@ torch::Tensor convert_to_tensor(
const bool normalize, const bool normalize,
const bool channels_first); const bool channels_first);
///
/// Convert float32/int32/int16/uint8 Tensor to int32 for Torch -> Sox
/// conversion.
torch::Tensor unnormalize_wav(const torch::Tensor);
/// Extract extension from file path /// Extract extension from file path
const std::string get_filetype(const std::string path); const std::string get_filetype(const std::string path);
...@@ -113,9 +108,11 @@ sox_encodinginfo_t get_tensor_encodinginfo(const caffe2::TypeMeta dtype); ...@@ -113,9 +108,11 @@ sox_encodinginfo_t get_tensor_encodinginfo(const caffe2::TypeMeta dtype);
/// Get sox_encodinginfo_t for saving to file/file object /// Get sox_encodinginfo_t for saving to file/file object
sox_encodinginfo_t get_encodinginfo_for_save( sox_encodinginfo_t get_encodinginfo_for_save(
const std::string filetype, const std::string& format,
const caffe2::TypeMeta dtype, const caffe2::TypeMeta dtype,
c10::optional<double>& compression); const c10::optional<double>& compression,
const c10::optional<std::string>& encoding,
const c10::optional<int64_t>& bits_per_sample);
#ifdef TORCH_API_INCLUDE_EXTENSION_H #ifdef TORCH_API_INCLUDE_EXTENSION_H
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment