Unverified Commit 674a71d1 authored by Caroline Chen's avatar Caroline Chen Committed by GitHub
Browse files

Add target `dtype` argument to `save` function for sox backend (#1204)

parent 47d97e30
import io
import itertools
import torch
from torchaudio.backend import sox_io_backend
from parameterized import parameterized
......@@ -24,7 +25,7 @@ class SaveTestBase(TempDirMixin, PytorchTestCase):
"""`sox_io_backend.save` can save wav format."""
path = self.get_temp_path('data.wav')
expected = get_wav_data(dtype, num_channels, num_frames=num_frames)
sox_io_backend.save(path, expected, sample_rate)
sox_io_backend.save(path, expected, sample_rate, dtype=None)
found, sr = load_wav(path)
assert sample_rate == sr
self.assertEqual(found, expected)
......@@ -68,7 +69,7 @@ class SaveTestBase(TempDirMixin, PytorchTestCase):
save_wav(src_path, data, sample_rate)
# 2.1. Convert the original wav to mp3 with torchaudio
sox_io_backend.save(
mp3_path, load_wav(src_path)[0], sample_rate, compression=bit_rate)
mp3_path, load_wav(src_path)[0], sample_rate, compression=bit_rate, dtype=None)
# 2.2. Convert the mp3 to wav with Sox
sox_utils.convert_audio_file(mp3_path, wav_path)
# 2.3. Load
......@@ -99,7 +100,7 @@ class SaveTestBase(TempDirMixin, PytorchTestCase):
save_wav(src_path, data, sample_rate)
# 2.1. Convert the original wav to flac with torchaudio
sox_io_backend.save(
flc_path, load_wav(src_path)[0], sample_rate, compression=compression_level)
flc_path, load_wav(src_path)[0], sample_rate, compression=compression_level, dtype=None)
# 2.2. Convert the flac to wav with Sox
# converting to 32 bit because flac file has 24 bit depth which scipy cannot handle.
sox_utils.convert_audio_file(flc_path, wav_path, bit_depth=32)
......@@ -132,7 +133,7 @@ class SaveTestBase(TempDirMixin, PytorchTestCase):
save_wav(src_path, data, sample_rate)
# 2.1. Convert the original wav to vorbis with torchaudio
sox_io_backend.save(
vbs_path, load_wav(src_path)[0], sample_rate, compression=quality_level)
vbs_path, load_wav(src_path)[0], sample_rate, compression=quality_level, dtype=None)
# 2.2. Convert the vorbis to wav with Sox
sox_utils.convert_audio_file(vbs_path, wav_path)
# 2.3. Load
......@@ -184,7 +185,7 @@ class SaveTestBase(TempDirMixin, PytorchTestCase):
data = get_wav_data('float32', num_channels, normalize=True, num_frames=duration * sample_rate)
save_wav(src_path, data, sample_rate)
# 2.1. Convert the original wav to sph with torchaudio
sox_io_backend.save(flc_path, load_wav(src_path)[0], sample_rate)
sox_io_backend.save(flc_path, load_wav(src_path)[0], sample_rate, dtype=None)
# 2.2. Convert the sph to wav with Sox
# converting to 32 bit because sph file has 24 bit depth which scipy cannot handle.
sox_utils.convert_audio_file(flc_path, wav_path, bit_depth=32)
......@@ -216,7 +217,7 @@ class SaveTestBase(TempDirMixin, PytorchTestCase):
data = get_wav_data(dtype, num_channels, normalize=False, num_frames=duration * sample_rate)
save_wav(src_path, data, sample_rate)
# 2.1. Convert the original wav to amb with torchaudio
sox_io_backend.save(amb_path, load_wav(src_path, normalize=False)[0], sample_rate)
sox_io_backend.save(amb_path, load_wav(src_path, normalize=False)[0], sample_rate, dtype=None)
# 2.2. Convert the amb to wav with Sox
sox_utils.convert_audio_file(amb_path, wav_path)
# 2.3. Load
......@@ -248,7 +249,7 @@ class SaveTestBase(TempDirMixin, PytorchTestCase):
data = get_wav_data('int16', num_channels, normalize=False, num_frames=duration * sample_rate)
save_wav(src_path, data, sample_rate)
# 2.1. Convert the original wav to amr_nb with torchaudio
sox_io_backend.save(amr_path, load_wav(src_path, normalize=False)[0], sample_rate)
sox_io_backend.save(amr_path, load_wav(src_path, normalize=False)[0], sample_rate, dtype=None)
# 2.2. Convert the amr_nb to wav with Sox
sox_utils.convert_audio_file(amr_path, wav_path)
# 2.3. Load
......@@ -389,7 +390,7 @@ class TestSaveParams(TempDirMixin, PytorchTestCase):
path = self.get_temp_path('data.wav')
data = get_wav_data('int32', 2, channels_first=channels_first)
sox_io_backend.save(
path, data, 8000, channels_first=channels_first)
path, data, 8000, channels_first=channels_first, dtype=None)
found = load_wav(path)[0]
expected = data if channels_first else data.transpose(1, 0)
self.assertEqual(found, expected)
......@@ -402,7 +403,7 @@ class TestSaveParams(TempDirMixin, PytorchTestCase):
path = self.get_temp_path('data.wav')
expected = get_wav_data(dtype, 4)[::2, ::2]
assert not expected.is_contiguous()
sox_io_backend.save(path, expected, 8000)
sox_io_backend.save(path, expected, 8000, dtype=None)
found = load_wav(path)[0]
self.assertEqual(found, expected)
......@@ -415,10 +416,24 @@ class TestSaveParams(TempDirMixin, PytorchTestCase):
expected = get_wav_data(dtype, 4)[::2, ::2]
data = expected.clone()
sox_io_backend.save(path, data, 8000)
sox_io_backend.save(path, data, 8000, dtype=None)
self.assertEqual(data, expected)
@parameterized.expand([
('float32', torch.tensor([-1.0, -0.5, 0, 0.5, 1.0]).to(torch.float32)),
('int32', torch.tensor([-2147483648, -1073741824, 0, 1073741824, 2147483647]).to(torch.int32)),
('int16', torch.tensor([-32768, -16384, 0, 16384, 32767]).to(torch.int16)),
('uint8', torch.tensor([0, 64, 128, 192, 255]).to(torch.uint8)),
])
def test_dtype_conversion(self, dtype, expected):
"""`save` performs dtype conversion on float32 src tensors only."""
path = self.get_temp_path("data.wav")
data = torch.tensor([-1.0, -0.5, 0, 0.5, 1.0]).to(torch.float32).view(-1, 1)
sox_io_backend.save(path, data, 8000, dtype=dtype)
found = load_wav(path, normalize=False)[0]
self.assertEqual(found, expected.view(-1, 1))
@skipIfNoExtension
@skipIfNoExec('sox')
......@@ -452,11 +467,11 @@ class TestFileObject(SaveTestBase):
res_path = self.get_temp_path(f'test.{ext}')
sox_io_backend.save(
ref_path, data, channels_first=channels_first,
sample_rate=sample_rate, compression=compression)
sample_rate=sample_rate, compression=compression, dtype=None)
with open(res_path, 'wb') as fileobj:
sox_io_backend.save(
fileobj, data, channels_first=channels_first,
sample_rate=sample_rate, compression=compression, format=ext)
sample_rate=sample_rate, compression=compression, format=ext, dtype=None)
expected_data, _ = sox_io_backend.load(ref_path)
data, sr = sox_io_backend.load(res_path)
......@@ -489,11 +504,11 @@ class TestFileObject(SaveTestBase):
res_path = self.get_temp_path(f'test.{ext}')
sox_io_backend.save(
ref_path, data, channels_first=channels_first,
sample_rate=sample_rate, compression=compression)
sample_rate=sample_rate, compression=compression, dtype=None)
fileobj = io.BytesIO()
sox_io_backend.save(
fileobj, data, channels_first=channels_first,
sample_rate=sample_rate, compression=compression, format=ext)
sample_rate=sample_rate, compression=compression, format=ext, dtype=None)
fileobj.seek(0)
with open(res_path, 'wb') as file_:
file_.write(fileobj.read())
......
import os
import warnings
from typing import Tuple, Optional
import torch
......@@ -178,15 +179,16 @@ def _save(
channels_first: bool = True,
compression: Optional[float] = None,
format: Optional[str] = None,
dtype: Optional[str] = None,
):
if hasattr(filepath, 'write'):
if format is None:
raise RuntimeError('`format` is required when saving to file object.')
torchaudio._torchaudio.save_audio_fileobj(
filepath, src, sample_rate, channels_first, compression, format)
filepath, src, sample_rate, channels_first, compression, format, dtype)
else:
torch.ops.torchaudio.sox_io_save_audio_file(
os.fspath(filepath), src, sample_rate, channels_first, compression, format)
os.fspath(filepath), src, sample_rate, channels_first, compression, format, dtype)
@_mod_utils.requires_module('torchaudio._torchaudio')
......@@ -197,6 +199,7 @@ def save(
channels_first: bool = True,
compression: Optional[float] = None,
format: Optional[str] = None,
dtype: Optional[str] = None,
):
"""Save audio data to file.
......@@ -243,12 +246,22 @@ def save(
format (str, optional):
Output audio format. This is required when the output audio format cannot be infered from
``filepath``, (such as file extension or ``name`` attribute of the given file object).
dtype (str, optional)
Output tensor dtype.
Valid values: ``"uint8", "int16", "int32", "float32", "float64", None``
``dtype=None`` means no conversion is performed.
``dtype`` parameter is only effective for ``float32`` Tensor.
"""
if src.dtype == torch.float32 and dtype is None:
warnings.warn(
'`dtype` default value will be changed to `int16` in 0.9 release.'
'Specify `dtype` to suppress this warning.'
)
if not torch.jit.is_scripting():
_save(filepath, src, sample_rate, channels_first, compression, format)
_save(filepath, src, sample_rate, channels_first, compression, format, dtype)
return
torch.ops.torchaudio.sox_io_save_audio_file(
filepath, src, sample_rate, channels_first, compression, format)
filepath, src, sample_rate, channels_first, compression, format, dtype)
@_mod_utils.requires_module('torchaudio._torchaudio')
......
......@@ -107,10 +107,19 @@ void save_audio_file(
int64_t sample_rate,
bool channels_first,
c10::optional<double> compression,
c10::optional<std::string> format) {
c10::optional<std::string> format,
c10::optional<std::string> dtype) {
validate_input_tensor(tensor);
auto signal = TensorSignal(tensor, sample_rate, channels_first);
if (tensor.dtype() != torch::kFloat32 && dtype.has_value()) {
throw std::runtime_error(
"dtype conversion only supported for float32 tensors");
}
const auto tgt_dtype =
(tensor.dtype() == torch::kFloat32 && dtype.has_value())
? get_dtype_from_str(dtype.value())
: tensor.dtype();
const auto filetype = [&]() {
if (format.has_value())
......@@ -124,8 +133,7 @@ void save_audio_file(
tensor = (unnormalize_wav(tensor) / 65536).to(torch::kInt16);
}
const auto signal_info = get_signalinfo(&signal, filetype);
const auto encoding_info =
get_encodinginfo(filetype, tensor.dtype(), compression);
const auto encoding_info = get_encodinginfo(filetype, tgt_dtype, compression);
SoxFormat sf(sox_open_write(
path.c_str(),
......@@ -239,10 +247,19 @@ void save_audio_fileobj(
int64_t sample_rate,
bool channels_first,
c10::optional<double> compression,
std::string filetype) {
std::string filetype,
c10::optional<std::string> dtype) {
validate_input_tensor(tensor);
auto signal = TensorSignal(tensor, sample_rate, channels_first);
if (tensor.dtype() != torch::kFloat32 && dtype.has_value()) {
throw std::runtime_error(
"dtype conversion only supported for float32 tensors");
}
const auto tgt_dtype =
(tensor.dtype() == torch::kFloat32 && dtype.has_value())
? get_dtype_from_str(dtype.value())
: tensor.dtype();
if (filetype == "amr-nb") {
const auto num_channels = tensor.size(channels_first ? 0 : 1);
......@@ -253,8 +270,7 @@ void save_audio_fileobj(
tensor = (unnormalize_wav(tensor) / 65536).to(torch::kInt16);
}
const auto signal_info = get_signalinfo(&signal, filetype);
const auto encoding_info =
get_encodinginfo(filetype, tensor.dtype(), compression);
const auto encoding_info = get_encodinginfo(filetype, tgt_dtype, compression);
AutoReleaseBuffer buffer;
......
......@@ -46,7 +46,8 @@ void save_audio_file(
int64_t sample_rate,
bool channels_first,
c10::optional<double> compression,
c10::optional<std::string> format);
c10::optional<std::string> format,
c10::optional<std::string> dtype);
#ifdef TORCH_API_INCLUDE_EXTENSION_H
......@@ -68,7 +69,8 @@ void save_audio_fileobj(
int64_t sample_rate,
bool channels_first,
c10::optional<double> compression,
std::string filetype);
std::string filetype,
c10::optional<std::string> dtype);
#endif // TORCH_API_INCLUDE_EXTENSION_H
......
......@@ -156,6 +156,24 @@ caffe2::TypeMeta get_dtype(
return c10::scalarTypeToTypeMeta(dtype);
}
caffe2::TypeMeta get_dtype_from_str(const std::string dtype) {
const auto tgt_dtype = [&]() {
if (dtype == "uint8")
return torch::kUInt8;
else if (dtype == "int16")
return torch::kInt16;
else if (dtype == "int32")
return torch::kInt32;
else if (dtype == "float32")
return torch::kFloat32;
else if (dtype == "float64")
return torch::kFloat64;
else
throw std::runtime_error("Unsupported dtype");
}();
return c10::scalarTypeToTypeMeta(tgt_dtype);
}
torch::Tensor convert_to_tensor(
sox_sample_t* buffer,
const int32_t num_samples,
......
......@@ -85,6 +85,8 @@ caffe2::TypeMeta get_dtype(
const sox_encoding_t encoding,
const unsigned precision);
caffe2::TypeMeta get_dtype_from_str(const std::string dtype);
///
/// Convert sox_sample_t buffer to uint8/int16/int32/float32 Tensor
/// NOTE: This function might modify the values in the input buffer to
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment