Unverified Commit 2c723e8c authored by moto's avatar moto Committed by GitHub
Browse files

Add "soundfile" backend compatible to "sox_io" (#922)

As a part of the "sox" backend sunset plan (#903), we add a "soundfile" backend that is compatible with the "sox_io" backend. No new public backend name is added. We provide a switch to change the interface/behavior of "soundfile" backend.

This commit contains:
 - The implementation of the new "soundfile" backend.
 - The flag to switch the behavior of "soundfile" backend. (`torchaudio.USE_SOUNDFILE_LEGACY_INTERFACE`)
 - Test for the new backend and switching mechanism.

The default behavior of the "soundfile" backend is not changed. Users who want to opt in to the new "soundfile" interface can do so by setting `torchaudio.USE_SOUNDFILE_LEGACY_INTERFACE = False` before changing the backend to "soundfile".

In the 0.8.0 release, the "soundfile" backend will use this interface by default, and users can still use the legacy one with `torchaudio.USE_SOUNDFILE_LEGACY_INTERFACE = True`. In 0.9.0, the legacy interface will be removed, and the `torchaudio.USE_SOUNDFILE_LEGACY_INTERFACE` flag will eventually be removed as well.
parent 8e370559
...@@ -38,6 +38,18 @@ class TestBackendSwitch_SoXIO(BackendSwitchMixin, common_utils.TorchaudioTestCas ...@@ -38,6 +38,18 @@ class TestBackendSwitch_SoXIO(BackendSwitchMixin, common_utils.TorchaudioTestCas
@common_utils.skipIfNoModule('soundfile') @common_utils.skipIfNoModule('soundfile')
class TestBackendSwitch_soundfile(BackendSwitchMixin, common_utils.TorchaudioTestCase): class TestBackendSwitch_soundfile_legacy(BackendSwitchMixin, common_utils.TorchaudioTestCase):
backend = 'soundfile' backend = 'soundfile'
backend_module = torchaudio.backend.soundfile_backend backend_module = torchaudio.backend.soundfile_backend
@common_utils.skipIfNoModule('soundfile')
class TestBackendSwitch_soundfile(BackendSwitchMixin, common_utils.TorchaudioTestCase):
backend = 'soundfile'
backend_module = torchaudio.backend._soundfile_backend
def setUp(self):
torchaudio.USE_SOUNDFILE_LEGACY_INTERFACE = False
def tearDown(self):
torchaudio.USE_SOUNDFILE_LEGACY_INTERFACE = True
...@@ -52,12 +52,18 @@ def get_wav_data( ...@@ -52,12 +52,18 @@ def get_wav_data(
if dtype == 'uint8': if dtype == 'uint8':
base = torch.linspace(0, 255, num_frames, dtype=dtype_) base = torch.linspace(0, 255, num_frames, dtype=dtype_)
if dtype == 'float32': elif dtype == 'int8':
base = torch.linspace(-128, 127, num_frames, dtype=dtype_)
elif dtype == 'float32':
base = torch.linspace(-1., 1., num_frames, dtype=dtype_) base = torch.linspace(-1., 1., num_frames, dtype=dtype_)
if dtype == 'int32': elif dtype == 'float64':
base = torch.linspace(-1., 1., num_frames, dtype=dtype_)
elif dtype == 'int32':
base = torch.linspace(-2147483648, 2147483647, num_frames, dtype=dtype_) base = torch.linspace(-2147483648, 2147483647, num_frames, dtype=dtype_)
if dtype == 'int16': elif dtype == 'int16':
base = torch.linspace(-32768, 32767, num_frames, dtype=dtype_) base = torch.linspace(-32768, 32767, num_frames, dtype=dtype_)
else:
raise NotImplementedError(f'Unsupported dtype {dtype}')
data = base.repeat([num_channels, 1]) data = base.repeat([num_channels, 1])
if not channels_first: if not channels_first:
data = data.transpose(1, 0) data = data.transpose(1, 0)
......
...@@ -45,6 +45,9 @@ class Test_LoadSave(unittest.TestCase): ...@@ -45,6 +45,9 @@ class Test_LoadSave(unittest.TestCase):
test_filepath_wav = os.path.join(test_dirpath, "assets", test_filepath_wav = os.path.join(test_dirpath, "assets",
"steam-train-whistle-daniel_simon.wav") "steam-train-whistle-daniel_simon.wav")
def setUp(self):
torchaudio.USE_SOUNDFILE_LEGACY_INTERFACE = True
def test_1_save(self): def test_1_save(self):
for backend in BACKENDS_MP3: for backend in BACKENDS_MP3:
with self.subTest(): with self.subTest():
......
import itertools
from unittest import skipIf
from parameterized import parameterized
from torchaudio._internal.module_utils import is_module_available
def name_func(func, _, params):
    """Build the name of one generated test case.

    The result is the decorated function's ``__name__`` followed by the
    ``str`` form of every entry of ``params.args``, all joined with
    underscores. The second positional argument is ignored.
    """
    suffix = "_".join(map(str, params.args))
    return f"{func.__name__}_{suffix}"
def dtype2subtype(dtype):
    """Translate a dtype name such as ``"int16"`` into its subtype name.

    Raises:
        KeyError: If ``dtype`` is not one of the six names listed below.
    """
    subtypes = dict(
        float64="DOUBLE",
        float32="FLOAT",
        int32="PCM_32",
        int16="PCM_16",
        uint8="PCM_U8",
        int8="PCM_S8",
    )
    return subtypes[dtype]
def skipIfFormatNotSupported(fmt):
    """Return a ``unittest.skipIf`` decorator for tests that need format ``fmt``.

    Args:
        fmt (str): Format key to look up in ``soundfile.available_formats()``,
            e.g. ``"FLAC"``.

    Returns:
        A decorator that skips the test unconditionally when the "soundfile"
        module is not available, and otherwise skips it only when ``fmt`` is
        not among the formats reported by ``soundfile.available_formats()``.
    """
    if not is_module_available("soundfile"):
        return skipIf(True, '"soundfile" not available.')
    # Imported lazily so that this module can be imported without soundfile.
    import soundfile
    supported = soundfile.available_formats()
    return skipIf(fmt not in supported, f'"{fmt}" is not supported by soundfile')
def parameterize(*params):
    """Expand a test over the Cartesian product of the given value lists.

    Every positional argument is an iterable of candidate values for one test
    parameter. The product of all of them is handed to
    ``parameterized.expand``, with ``name_func`` naming the generated cases.
    """
    combinations = list(itertools.product(*params))
    return parameterized.expand(combinations, name_func=name_func)
import torch
from torchaudio.backend import _soundfile_backend as soundfile_backend
from torchaudio._internal import module_utils as _mod_utils
from torchaudio_unittest.common_utils import (
TempDirMixin,
PytorchTestCase,
skipIfNoModule,
get_wav_data,
save_wav,
)
from .common import skipIfFormatNotSupported, parameterize
if _mod_utils.is_module_available("soundfile"):
import soundfile
@skipIfNoModule("soundfile")
class TestInfo(TempDirMixin, PytorchTestCase):
    """Check the metadata `soundfile_backend.info` reports for files on disk."""

    def _check_info(self, path, sample_rate, num_frames, num_channels):
        """Assert that `soundfile_backend.info(path)` reports the given values."""
        info = soundfile_backend.info(path)
        assert info.sample_rate == sample_rate
        assert info.num_frames == num_frames
        assert info.num_channels == num_channels

    def _check_wav(self, dtype, sample_rate, num_channels):
        """Write a one-second wav with `save_wav` and validate its metadata."""
        path = self.get_temp_path("data.wav")
        num_frames = 1 * sample_rate
        data = get_wav_data(
            dtype, num_channels, normalize=False, num_frames=num_frames
        )
        save_wav(path, data, sample_rate)
        self._check_info(path, sample_rate, num_frames, num_channels)

    def _check_written_by_soundfile(self, filename, sample_rate, num_channels):
        """Write one second of random samples with `soundfile.write` and validate
        the metadata of the resulting file."""
        num_frames = sample_rate * 1
        data = torch.randn(num_frames, num_channels).numpy()
        path = self.get_temp_path(filename)
        soundfile.write(path, data, sample_rate)
        self._check_info(path, sample_rate, num_frames, num_channels)

    @parameterize(
        ["float32", "int32", "int16", "uint8"], [8000, 16000], [1, 2],
    )
    def test_wav(self, dtype, sample_rate, num_channels):
        """`soundfile_backend.info` can check wav file correctly"""
        self._check_wav(dtype, sample_rate, num_channels)

    @parameterize(
        ["float32", "int32", "int16", "uint8"], [8000, 16000], [4, 8, 16, 32],
    )
    def test_wav_multiple_channels(self, dtype, sample_rate, num_channels):
        """`soundfile_backend.info` can check wav file with channels more than 2 correctly"""
        self._check_wav(dtype, sample_rate, num_channels)

    @parameterize([8000, 16000], [1, 2])
    @skipIfFormatNotSupported("FLAC")
    def test_flac(self, sample_rate, num_channels):
        """`soundfile_backend.info` can check flac file correctly"""
        self._check_written_by_soundfile("data.flac", sample_rate, num_channels)

    @parameterize([8000, 16000], [1, 2])
    @skipIfFormatNotSupported("OGG")
    def test_ogg(self, sample_rate, num_channels):
        """`soundfile_backend.info` can check ogg file correctly"""
        self._check_written_by_soundfile("data.ogg", sample_rate, num_channels)

    @parameterize([8000, 16000], [1, 2])
    @skipIfFormatNotSupported("NIST")
    def test_sphere(self, sample_rate, num_channels):
        """`soundfile_backend.info` can check sph file correctly"""
        self._check_written_by_soundfile("data.nist", sample_rate, num_channels)
import itertools
from unittest.mock import patch
import torch
from torchaudio._internal import module_utils as _mod_utils
from torchaudio.backend import _soundfile_backend as soundfile_backend
from parameterized import parameterized
from torchaudio_unittest.common_utils import (
TempDirMixin,
PytorchTestCase,
skipIfNoModule,
get_wav_data,
normalize_wav,
load_wav,
save_wav,
)
from .common import (
parameterize,
dtype2subtype,
skipIfFormatNotSupported,
)
if _mod_utils.is_module_available("soundfile"):
import soundfile
def _get_mock_path(
    ext: str, dtype: str, sample_rate: int, num_channels: int, num_frames: int,
):
    """Encode the mock audio parameters into a fake file name.

    The result has the form ``<dtype>_<sample_rate>_<num_channels>_<num_frames>.<ext>``.
    """
    stem = "_".join(map(str, (dtype, sample_rate, num_channels, num_frames)))
    return f"{stem}.{ext}"
def _get_mock_params(path: str):
    """Decode a fake file name of the form ``<dtype>_<rate>_<channels>_<frames>.<ext>``.

    Returns a dict with the keys ``ext``, ``dtype``, ``sample_rate``,
    ``num_channels`` and ``num_frames``; the last three are converted to ``int``.
    ``path`` must contain exactly one ``"."``.
    """
    stem, ext = path.split(".")
    fields = stem.split("_")
    params = {"ext": ext, "dtype": fields[0]}
    numeric_keys = ("sample_rate", "num_channels", "num_frames")
    for position, key in enumerate(numeric_keys, start=1):
        params[key] = int(fields[position])
    return params
class SoundFileMock:
    """Stand-in for ``soundfile.SoundFile`` used by the mocked load tests.

    All properties are derived from the file name, which must follow the
    pattern produced by ``_get_mock_path``. No file is ever opened; ``read``
    synthesizes the samples with ``get_wav_data``.
    """

    # File extension -> value reported by the ``format`` property.
    _FORMATS = {
        "wav": "WAV",
        "flac": "FLAC",
        "ogg": "OGG",
        "sph": "NIST",
        "nis": "NIST",
        "nist": "NIST",
    }

    def __init__(self, path, mode):
        # Only read mode is exercised through this mock.
        assert mode == "r"
        self.path = path
        self._params = _get_mock_params(path)
        # Set by ``_prepare_read``; ``read`` slices from this offset.
        self._start = None

    @property
    def samplerate(self):
        """Sample rate encoded in the file name."""
        return self._params["sample_rate"]

    @property
    def format(self):
        """Format name for the extension, or ``None`` for an unknown extension."""
        return self._FORMATS.get(self._params["ext"])

    @property
    def subtype(self):
        """``"VORBIS"`` for ogg files, otherwise the subtype of the encoded dtype."""
        if self._params["ext"] == "ogg":
            return "VORBIS"
        return dtype2subtype(self._params["dtype"])

    def _prepare_read(self, start, stop, frames):
        """Record the start offset and return the number of frames to read.

        A negative ``frames`` means "everything from ``start`` onward" and is
        resolved to a concrete count here. Returning it unchanged would make
        the slice in ``read`` end at ``start - 1`` and silently drop the
        final frame.
        """
        assert stop is None
        self._start = start
        if frames < 0:
            frames = max(self._params["num_frames"] - start, 0)
        return frames

    def read(self, frames, dtype, always_2d):
        """Return ``frames`` synthetic frames of ``dtype`` starting at the
        offset recorded by ``_prepare_read``, laid out as ``[time, channel]``."""
        assert always_2d
        data = get_wav_data(
            dtype,
            self._params["num_channels"],
            normalize=False,
            num_frames=self._params["num_frames"],
            channels_first=False,
        ).numpy()
        return data[self._start:self._start + frames]

    def __enter__(self):
        return self

    def __exit__(self, *args, **kwargs):
        # Nothing to release: the mock never opens a real file.
        pass
@skipIfNoModule("soundfile")
class MockedLoadTest(PytorchTestCase):
    """Check the dtype returned by `soundfile_backend.load` with
    ``soundfile.SoundFile`` replaced by ``SoundFileMock``.

    The tests still require the "soundfile" module: ``patch`` has to import it
    to replace ``soundfile.SoundFile``, hence the class-level skip.
    """

    def assert_dtype(
        self, ext, dtype, sample_rate, num_channels, normalize, channels_first
    ):
        """Only WAV with normalize=False returns the native dtype Tensor, otherwise float32"""
        num_frames = 3 * sample_rate
        path = _get_mock_path(ext, dtype, sample_rate, num_channels, num_frames)
        # `load` picks a native dtype only when the reported format is "WAV";
        # every other format (including NIST/SPHERE) is read as float32.
        if normalize or ext != "wav":
            expected_dtype = torch.float32
        else:
            expected_dtype = getattr(torch, dtype)
        with patch("soundfile.SoundFile", SoundFileMock):
            found, sr = soundfile_backend.load(
                path, normalize=normalize, channels_first=channels_first
            )
            assert found.dtype == expected_dtype
            assert sample_rate == sr

    @parameterize(
        ["uint8", "int16", "int32", "float32", "float64"],
        [8000, 16000],
        [1, 2],
        [True, False],
        [True, False],
    )
    def test_wav(self, dtype, sample_rate, num_channels, normalize, channels_first):
        """Returns native dtype when normalize=False else float32"""
        self.assert_dtype(
            "wav", dtype, sample_rate, num_channels, normalize, channels_first
        )

    @parameterize(
        ["int8", "int16", "int32"], [8000, 16000], [1, 2], [True, False], [True, False],
    )
    def test_sphere(self, dtype, sample_rate, num_channels, normalize, channels_first):
        """Returns float32 always"""
        self.assert_dtype(
            "sph", dtype, sample_rate, num_channels, normalize, channels_first
        )

    @parameterize([8000, 16000], [1, 2], [True, False], [True, False])
    def test_ogg(self, sample_rate, num_channels, normalize, channels_first):
        """Returns float32 always"""
        self.assert_dtype(
            "ogg", "int16", sample_rate, num_channels, normalize, channels_first
        )

    @parameterize([8000, 16000], [1, 2], [True, False], [True, False])
    def test_flac(self, sample_rate, num_channels, normalize, channels_first):
        """Returns float32 always"""
        self.assert_dtype(
            "flac", "int16", sample_rate, num_channels, normalize, channels_first
        )
class LoadTestBase(TempDirMixin, PytorchTestCase):
    """Reusable assertions that compare `soundfile_backend.load` with reference data."""

    def assert_wav(
        self,
        dtype,
        sample_rate,
        num_channels,
        normalize,
        channels_first=True,
        duration=1,
    ):
        """`soundfile_backend.load` can load wav format correctly.

        The wav file is written with `save_wav`; the Tensor returned by the
        backend must equal the one `load_wav` returns for the same file.
        """
        path = self.get_temp_path("reference.wav")
        layout = dict(channels_first=channels_first)
        reference = get_wav_data(
            dtype,
            num_channels,
            normalize=normalize,
            num_frames=duration * sample_rate,
            **layout,
        )
        save_wav(path, reference, sample_rate, **layout)
        expected, _ = load_wav(path, normalize=normalize, **layout)
        found, sr = soundfile_backend.load(path, normalize=normalize, **layout)
        assert sr == sample_rate
        self.assertEqual(found, expected)

    def _assert_written_by_soundfile(
        self, filename, make_write_kwargs, dtype, sample_rate, num_channels,
        channels_first, duration,
    ):
        """Write un-normalized data with `soundfile.write`, load it back with the
        backend, and compare against the normalized source data.

        ``make_write_kwargs`` is called after the data has been generated and
        must return the extra keyword arguments for `soundfile.write`.
        """
        path = self.get_temp_path(filename)
        raw = get_wav_data(
            dtype,
            num_channels,
            num_frames=duration * sample_rate,
            normalize=False,
            channels_first=False,
        )
        soundfile.write(path, raw, sample_rate, **make_write_kwargs())
        if channels_first:
            expected = normalize_wav(raw.t())
        else:
            expected = normalize_wav(raw)
        found, sr = soundfile_backend.load(path, channels_first=channels_first)
        assert sr == sample_rate
        self.assertEqual(found, expected, atol=1e-4, rtol=1e-8)

    def assert_sphere(
        self, dtype, sample_rate, num_channels, channels_first=True, duration=1,
    ):
        """`soundfile_backend.load` can load SPHERE format correctly."""
        self._assert_written_by_soundfile(
            "reference.sph",
            lambda: dict(subtype=dtype2subtype(dtype), format="NIST"),
            dtype,
            sample_rate,
            num_channels,
            channels_first,
            duration,
        )

    def assert_flac(
        self, dtype, sample_rate, num_channels, channels_first=True, duration=1,
    ):
        """`soundfile_backend.load` can load FLAC format correctly."""
        self._assert_written_by_soundfile(
            "reference.flac",
            dict,
            dtype,
            sample_rate,
            num_channels,
            channels_first,
            duration,
        )
@skipIfNoModule("soundfile")
class TestLoad(LoadTestBase):
    """Test the correctness of `soundfile_backend.load` for various formats"""

    @parameterize(
        ["float32", "int32", "int16"],
        [8000, 16000],
        [1, 2],
        [False, True],
        [False, True],
    )
    def test_wav(self, dtype, sample_rate, num_channels, normalize, channels_first):
        """`soundfile_backend.load` can load wav format correctly."""
        self.assert_wav(
            dtype,
            sample_rate,
            num_channels,
            normalize,
            channels_first=channels_first,
        )

    @parameterize(
        ["int16"], [16000], [2], [False],
    )
    def test_wav_large(self, dtype, sample_rate, num_channels, normalize):
        """`soundfile_backend.load` can load large wav file correctly."""
        seconds_in_two_hours = 2 * 60 * 60
        self.assert_wav(
            dtype, sample_rate, num_channels, normalize, duration=seconds_in_two_hours
        )

    @parameterize(["float32", "int32", "int16"], [4, 8, 16, 32], [False, True])
    def test_multiple_channels(self, dtype, num_channels, channels_first):
        """`soundfile_backend.load` can load wav file with more than 2 channels."""
        self.assert_wav(
            dtype, 8000, num_channels, False, channels_first=channels_first
        )

    @parameterize(["int32", "int16"], [8000, 16000], [1, 2], [False, True])
    @skipIfFormatNotSupported("NIST")
    def test_sphere(self, dtype, sample_rate, num_channels, channels_first):
        """`soundfile_backend.load` can load sphere format correctly."""
        self.assert_sphere(
            dtype, sample_rate, num_channels, channels_first=channels_first
        )

    @parameterize(["int32", "int16"], [8000, 16000], [1, 2], [False, True])
    @skipIfFormatNotSupported("FLAC")
    def test_flac(self, dtype, sample_rate, num_channels, channels_first):
        """`soundfile_backend.load` can load flac format correctly."""
        self.assert_flac(
            dtype, sample_rate, num_channels, channels_first=channels_first
        )
import itertools
from unittest.mock import patch
from torchaudio._internal import module_utils as _mod_utils
from torchaudio.backend import _soundfile_backend as soundfile_backend
from parameterized import parameterized
from torchaudio_unittest.common_utils import (
TempDirMixin,
PytorchTestCase,
skipIfNoModule,
get_wav_data,
load_wav,
)
from .common import parameterize, dtype2subtype, skipIfFormatNotSupported
if _mod_utils.is_module_available("soundfile"):
import soundfile
@skipIfNoModule("soundfile")
class MockedSaveTest(PytorchTestCase):
    """Check the arguments `soundfile_backend.save` forwards to `soundfile.write`.

    `soundfile.write` is patched, so nothing is written to disk. The tests
    still require the "soundfile" module because ``patch`` has to import it,
    hence the class-level skip.
    """

    @parameterize(
        ["float32", "int32", "int16", "uint8"], [8000, 16000], [1, 2], [False, True],
    )
    @patch("soundfile.write")
    def test_wav(self, dtype, sample_rate, num_channels, channels_first, mocked_write):
        """soundfile_backend.save passes correct subtype to soundfile.write when WAV"""
        filepath = "foo.wav"
        input_tensor = get_wav_data(
            dtype,
            num_channels,
            num_frames=3 * sample_rate,
            normalize=dtype == "float32",
            channels_first=channels_first,
        ).t()

        soundfile_backend.save(
            filepath, input_tensor, sample_rate, channels_first=channels_first
        )

        # On Python 3.8+ `call_args.kwargs` is more descriptive; index 1 of
        # `call_args` is the same keyword dict and also works on older versions.
        args = mocked_write.call_args[1]
        assert args["file"] == filepath
        assert args["samplerate"] == sample_rate
        assert args["subtype"] == dtype2subtype(dtype)
        assert args["format"] is None
        self.assertEqual(
            args["data"], input_tensor.t() if channels_first else input_tensor
        )

    @patch("soundfile.write")
    def assert_non_wav(
        self, fmt, dtype, sample_rate, num_channels, channels_first, mocked_write
    ):
        """soundfile_backend.save passes ``subtype=None`` to soundfile.write for
        non-WAV files, and ``format="NIST"`` only for SPHERE extensions."""
        filepath = f"foo.{fmt}"
        input_tensor = get_wav_data(
            dtype,
            num_channels,
            num_frames=3 * sample_rate,
            normalize=False,
            channels_first=channels_first,
        ).t()
        expected_data = input_tensor.t() if channels_first else input_tensor

        soundfile_backend.save(
            filepath, input_tensor, sample_rate, channels_first=channels_first
        )

        # On Python 3.8+ `call_args.kwargs` is more descriptive; index 1 of
        # `call_args` is the same keyword dict and also works on older versions.
        args = mocked_write.call_args[1]
        assert args["file"] == filepath
        assert args["samplerate"] == sample_rate
        assert args["subtype"] is None
        if fmt in ["sph", "nist", "nis"]:
            assert args["format"] == "NIST"
        else:
            assert args["format"] is None
        self.assertEqual(args["data"], expected_data)

    @parameterize(
        ["sph", "nist", "nis"],
        ["int32", "int16"],
        [8000, 16000],
        [1, 2],
        [False, True],
    )
    def test_sph(self, fmt, dtype, sample_rate, num_channels, channels_first):
        """soundfile_backend.save passes "NIST" format and default subtype (None)
        to soundfile.write for SPHERE extensions"""
        self.assert_non_wav(fmt, dtype, sample_rate, num_channels, channels_first)

    @parameterize(
        ["int32", "int16"], [8000, 16000], [1, 2], [False, True],
    )
    def test_flac(self, dtype, sample_rate, num_channels, channels_first):
        """soundfile_backend.save passes default format and subtype (None-s) to
        soundfile.write when not WAV"""
        self.assert_non_wav("flac", dtype, sample_rate, num_channels, channels_first)

    @parameterize(
        ["int32", "int16"], [8000, 16000], [1, 2], [False, True],
    )
    def test_ogg(self, dtype, sample_rate, num_channels, channels_first):
        """soundfile_backend.save passes default format and subtype (None-s) to
        soundfile.write when not WAV"""
        self.assert_non_wav("ogg", dtype, sample_rate, num_channels, channels_first)
@skipIfNoModule("soundfile")
class SaveTestBase(TempDirMixin, PytorchTestCase):
    """Reusable assertions for files written by `soundfile_backend.save`."""

    def assert_wav(self, dtype, sample_rate, num_channels, num_frames):
        """`soundfile_backend.save` can save wav format."""
        path = self.get_temp_path("data.wav")
        source = get_wav_data(
            dtype, num_channels, num_frames=num_frames, normalize=False
        )
        soundfile_backend.save(path, source, sample_rate)
        loaded, loaded_rate = load_wav(path, normalize=False)
        assert sample_rate == loaded_rate
        self.assertEqual(loaded, source)

    def _assert_non_wav(self, fmt, dtype, sample_rate, num_channels):
        """`soundfile_backend.save` can save non-wav format.

        Only the metadata reported by `soundfile.info` is validated, because of
        precision mismatch and because there is no alternative way to decode
        the resulting files without using soundfile.
        """
        num_frames = sample_rate * 3
        path = self.get_temp_path(f"data.{fmt}")
        source = get_wav_data(
            dtype, num_channels, num_frames=num_frames, normalize=False
        )
        soundfile_backend.save(path, source, sample_rate)
        written = soundfile.info(path)
        assert written.format == fmt.upper()
        assert written.frames == num_frames
        assert written.channels == num_channels
        assert written.samplerate == sample_rate

    def assert_flac(self, dtype, sample_rate, num_channels):
        """`soundfile_backend.save` can save flac format."""
        self._assert_non_wav("flac", dtype, sample_rate, num_channels)

    def assert_sphere(self, dtype, sample_rate, num_channels):
        """`soundfile_backend.save` can save sph format."""
        self._assert_non_wav("nist", dtype, sample_rate, num_channels)

    def assert_ogg(self, dtype, sample_rate, num_channels):
        """`soundfile_backend.save` can save ogg format.

        As we cannot inspect the OGG format (it's lossy), we only check the metadata.
        """
        self._assert_non_wav("ogg", dtype, sample_rate, num_channels)
@skipIfNoModule("soundfile")
class TestSave(SaveTestBase):
    """Run the `SaveTestBase` assertions over a grid of parameters."""

    @parameterize(
        ["float32", "int32", "int16"], [8000, 16000], [1, 2],
    )
    def test_wav(self, dtype, sample_rate, num_channels):
        """`soundfile_backend.save` can save wav format."""
        self.assert_wav(dtype, sample_rate, num_channels, None)

    @parameterize(
        ["float32", "int32", "int16"], [4, 8, 16, 32],
    )
    def test_multiple_channels(self, dtype, num_channels):
        """`soundfile_backend.save` can save wav with more than 2 channels."""
        self.assert_wav(dtype, 8000, num_channels, None)

    @parameterize(
        ["int32", "int16"], [8000, 16000], [1, 2],
    )
    @skipIfFormatNotSupported("NIST")
    def test_sphere(self, dtype, sample_rate, num_channels):
        """`soundfile_backend.save` can save sph format."""
        self.assert_sphere(dtype, sample_rate, num_channels)

    @parameterize(
        [8000, 16000], [1, 2],
    )
    @skipIfFormatNotSupported("FLAC")
    def test_flac(self, sample_rate, num_channels):
        """`soundfile_backend.save` can save flac format."""
        source_dtype = "float32"
        self.assert_flac(source_dtype, sample_rate, num_channels)

    @parameterize(
        [8000, 16000], [1, 2],
    )
    @skipIfFormatNotSupported("OGG")
    def test_ogg(self, sample_rate, num_channels):
        """`soundfile_backend.save` can save ogg/vorbis format."""
        source_dtype = "float32"
        self.assert_ogg(source_dtype, sample_rate, num_channels)
@skipIfNoModule("soundfile")
class TestSaveParams(TempDirMixin, PytorchTestCase):
    """Test the correctness of optional parameters of `soundfile_backend.save`"""

    # `parameterize` takes one list of candidate values per test parameter and
    # builds their product, so the values must be bare booleans. Wrapping them
    # in 1-tuples would pass `(True,)` / `(False,)` as `channels_first`, both of
    # which are truthy, and the `False` case would never run.
    @parameterize([True, False])
    def test_channels_first(self, channels_first):
        """channels_first swaps axes"""
        path = self.get_temp_path("data.wav")
        data = get_wav_data("int32", 2, channels_first=channels_first)
        soundfile_backend.save(path, data, 8000, channels_first=channels_first)
        found = load_wav(path)[0]
        # `load_wav` is called with its defaults here, so when the data was
        # generated as [time, channel] it is transposed before comparing.
        expected = data if channels_first else data.transpose(1, 0)
        self.assertEqual(found, expected, atol=1e-4, rtol=1e-8)
...@@ -8,6 +8,9 @@ from torchaudio import ( ...@@ -8,6 +8,9 @@ from torchaudio import (
sox_effects, sox_effects,
transforms transforms
) )
USE_SOUNDFILE_LEGACY_INTERFACE = True
from torchaudio.backend import ( from torchaudio.backend import (
list_audio_backends, list_audio_backends,
get_audio_backend, get_audio_backend,
......
"""The new soundfile backend which will become default in 0.8.0 onward"""
from typing import Tuple, Optional
import warnings
import torch
from torchaudio._internal import module_utils as _mod_utils
from .common import AudioMetaData
if _mod_utils.is_module_available("soundfile"):
import soundfile
@_mod_utils.requires_module("soundfile")
def info(filepath: str) -> AudioMetaData:
    """Get signal information of an audio file.

    Args:
        filepath (str or pathlib.Path): Path to audio file.
            This function also handles ``pathlib.Path`` objects, but is annotated as ``str``
            for the consistency with "sox_io" backend, which has a restriction on type annotation
            for TorchScript compiler compatibility.

    Returns:
        AudioMetaData: Sample rate, number of frames and number of channels as
        reported by ``soundfile.info``.
    """
    meta = soundfile.info(filepath)
    return AudioMetaData(
        sample_rate=meta.samplerate,
        num_frames=meta.frames,
        num_channels=meta.channels,
    )
# Maps a soundfile subtype name to the dtype name passed to ``SoundFile.read``
# by ``load`` when a WAV file is loaded with ``normalize=False``.
# Subtypes missing from this table make ``load`` raise ``ValueError`` in that case.
_SUBTYPE2DTYPE = {
    "PCM_S8": "int8",
    "PCM_U8": "uint8",
    "PCM_16": "int16",
    "PCM_32": "int32",
    "FLOAT": "float32",
    "DOUBLE": "float64",
}
@_mod_utils.requires_module("soundfile")
def load(
    filepath: str,
    frame_offset: int = 0,
    num_frames: int = -1,
    normalize: bool = True,
    channels_first: bool = True,
) -> Tuple[torch.Tensor, int]:
    """Load audio data from file.

    Note:
        The formats this function can handle depend on the soundfile installation.
        This function is tested on the following formats;

        * WAV

            * 32-bit floating-point
            * 32-bit signed integer
            * 16-bit signed integer
            * 8-bit unsigned integer

        * FLAC
        * OGG/VORBIS
        * SPHERE

    By default (``normalize=True``, ``channels_first=True``), the data is read as
    ``float32`` and returned with the shape ``[channel, time]``.

    The dtype only differs from ``float32`` when the file format is WAV and
    ``normalize=False``. In that case the data is read with the dtype that
    corresponds to the subtype of the file: ``uint8`` for 8-bit unsigned PCM,
    ``int8`` for 8-bit signed PCM, ``int16`` for 16-bit signed PCM, ``int32`` for
    32-bit signed PCM, ``float32`` for 32-bit float and ``float64`` for 64-bit
    float. Other WAV subtypes (for example 24-bit signed PCM) are rejected.

    Args:
        filepath (str or pathlib.Path): Path to audio file.
            This function also handles ``pathlib.Path`` objects, but is annotated as ``str``
            for the consistency with "sox_io" backend, which has a restriction on type annotation
            for TorchScript compiler compatibility.
        frame_offset (int):
            Number of frames to skip before start reading data.
        num_frames (int):
            Maximum number of frames to read. ``-1`` reads all the remaining samples,
            starting from ``frame_offset``.
            This function may return the less number of frames if there is not enough
            frames in the given file.
        normalize (bool):
            When ``True``, the data is always read as ``float32``.
            When ``False`` and the file is WAV, the dtype follows the subtype of the file.
            This argument has no effect for formats other than WAV.
        channels_first (bool):
            When True, the returned Tensor has dimension ``[channel, time]``.
            Otherwise, the returned Tensor's dimension is ``[time, channel]``.

    Returns:
        Tuple[torch.Tensor, int]: The loaded samples and the sample rate of the file.

    Raises:
        ValueError: If the file is WAV, ``normalize=False`` and the subtype of the
            file has no corresponding dtype.
    """
    with soundfile.SoundFile(filepath, "r") as handle:
        if handle.format == "WAV" and not normalize:
            # Keep the on-disk sample type when there is a matching dtype.
            dtype = _SUBTYPE2DTYPE.get(handle.subtype)
            if dtype is None:
                raise ValueError(f"Unsupported subtype: {handle.subtype}")
        else:
            dtype = "float32"

        count = handle._prepare_read(frame_offset, None, num_frames)
        samples = handle.read(count, dtype, always_2d=True)
        sample_rate = handle.samplerate

    # `always_2d=True` yields [time, channel]; transpose on request.
    tensor = torch.from_numpy(samples)
    return (tensor.t() if channels_first else tensor), sample_rate
@_mod_utils.requires_module("soundfile")
def save(
    filepath: str,
    src: torch.Tensor,
    sample_rate: int,
    channels_first: bool = True,
    compression: Optional[float] = None,
):
    """Save audio data to file.

    Note:
        The formats this function can handle depend on the soundfile installation.
        This function is tested on the following formats;

        * WAV

            * 32-bit floating-point
            * 32-bit signed integer
            * 16-bit signed integer
            * 8-bit unsigned integer

        * FLAC
        * OGG/VORBIS
        * SPHERE

    Args:
        filepath (str or pathlib.Path): Path to audio file.
            This function also handles ``pathlib.Path`` objects, but is annotated as ``str``
            for the consistency with "sox_io" backend, which has a restriction on type annotation
            for TorchScript compiler compatibility.
        src (torch.Tensor): Audio data to save. Must be a 2D tensor.
        sample_rate (int): sampling rate
        channels_first (bool):
            If ``True``, the given tensor is interpreted as ``[channel, time]``,
            otherwise ``[time, channel]``.
        compression (Optional[float]):
            Not used. It is here only for interface compatibility reasons with the
            "sox_io" backend. A warning is issued when a value is given.

    Raises:
        ValueError: If ``src`` is not 2D, or if the file extension is ``wav`` and
            the dtype of ``src`` is not one of ``uint8``, ``int16``, ``int32``,
            ``float32`` and ``float64``.
    """
    if src.ndim != 2:
        raise ValueError(f"Expected 2D Tensor, got {src.ndim}D.")
    if compression is not None:
        warnings.warn(
            '`save` function of "soundfile" backend does not support "compression" parameter. '
            "The argument is silently ignored."
        )

    # The text after the last "." decides how the file is written.
    ext = str(filepath).split(".")[-1].lower()

    if ext == "wav":
        # For WAV the subtype is chosen explicitly from the tensor dtype.
        wav_subtypes = {
            torch.uint8: "PCM_U8",
            torch.int16: "PCM_16",
            torch.int32: "PCM_32",
            torch.float32: "FLOAT",
            torch.float64: "DOUBLE",
        }
        if src.dtype not in wav_subtypes:
            raise ValueError(f"Unsupported dtype for WAV: {src.dtype}")
        subtype = wav_subtypes[src.dtype]
    else:
        # For every other extension the subtype is left to soundfile.
        subtype = None

    # sph is a extension used in TED-LIUM but soundfile does not recognize it as NIST format,
    # so we extend the extensions manually here
    format_ = "NIST" if ext in ("nis", "nist", "sph") else None

    data = src.t() if channels_first else src
    soundfile.write(
        file=filepath, data=data, samplerate=sample_rate, subtype=subtype, format=format_
    )
@_mod_utils.requires_module("soundfile")
@_mod_utils.deprecated('Please use "torchaudio.load".', "0.9.0")
def load_wav(
    filepath: str,
    frame_offset: int = 0,
    num_frames: int = -1,
    channels_first: bool = True,
) -> Tuple[torch.Tensor, int]:
    """Load wave file.

    This function is defined only for the purpose of compatibility against other backend
    for simple usecases, such as ``torchaudio.load_wav(filepath)``.
    It forwards every argument to :py:func:`load` and always passes ``normalize=False``.
    """
    options = dict(normalize=False, channels_first=channels_first)
    return load(filepath, frame_offset, num_frames, **options)
from typing import Any, Optional from typing import Any, Optional
class AudioMetaData:
    """Data class to be returned by :py:func:`~torchaudio.info`.

    The three constructor arguments are stored unchanged as attributes of the
    same name.

    :ivar int sample_rate: Sample rate
    :ivar int num_frames: The number of frames
    :ivar int num_channels: The number of channels
    """
    def __init__(self, sample_rate: int, num_frames: int, num_channels: int):
        # Plain attribute storage; no validation or conversion happens here.
        self.sample_rate = sample_rate
        self.num_frames = num_frames
        self.num_channels = num_channels
class SignalInfo: class SignalInfo:
"""Data class returned ``info`` functions. """Data class returned ``info`` functions.
......
...@@ -5,18 +5,7 @@ from torchaudio._internal import ( ...@@ -5,18 +5,7 @@ from torchaudio._internal import (
module_utils as _mod_utils, module_utils as _mod_utils,
) )
from .common import AudioMetaData
class AudioMetaData:
"""Data class to be returned by :py:func:`~torchaudio.backend.sox_io_backend.info`.
:ivar int sample_rate: Sample rate
:ivar int num_frames: The number of frames
:ivar int num_channels: The number of channels
"""
def __init__(self, sample_rate: int, num_frames: int, num_channels: int):
self.sample_rate = sample_rate
self.num_frames = num_frames
self.num_channels = num_channels
@_mod_utils.requires_module('torchaudio._torchaudio') @_mod_utils.requires_module('torchaudio._torchaudio')
......
...@@ -9,6 +9,7 @@ from . import ( ...@@ -9,6 +9,7 @@ from . import (
sox_backend, sox_backend,
sox_io_backend, sox_io_backend,
soundfile_backend, soundfile_backend,
_soundfile_backend,
) )
__all__ = [ __all__ = [
...@@ -58,7 +59,18 @@ def set_audio_backend(backend: Optional[str]): ...@@ -58,7 +59,18 @@ def set_audio_backend(backend: Optional[str]):
elif backend == 'sox_io': elif backend == 'sox_io':
module = sox_io_backend module = sox_io_backend
elif backend == 'soundfile': elif backend == 'soundfile':
module = soundfile_backend if torchaudio.USE_SOUNDFILE_LEGACY_INTERFACE:
warnings.warn(
'The interface of "soundfile" backend is planned to change in 0.8.0 to '
'match that of "sox_io" backend and the current interface will be removed in 0.9.0. '
'To use the new interface, do '
'`torchaudio.USE_SOUNDFILE_LEGACY_INTERFACE = False` '
'before setting the backend to "soundfile". '
'Please refer to https://github.com/pytorch/audio/issues/903 for the detail.'
)
module = soundfile_backend
else:
module = _soundfile_backend
else: else:
raise NotImplementedError(f'Unexpected backend "{backend}"') raise NotImplementedError(f'Unexpected backend "{backend}"')
...@@ -89,6 +101,6 @@ def get_audio_backend() -> Optional[str]: ...@@ -89,6 +101,6 @@ def get_audio_backend() -> Optional[str]:
return 'sox' return 'sox'
if torchaudio.load == sox_io_backend.load: if torchaudio.load == sox_io_backend.load:
return 'sox_io' return 'sox_io'
if torchaudio.load == soundfile_backend.load: if torchaudio.load in [soundfile_backend.load, _soundfile_backend.load]:
return 'soundfile' return 'soundfile'
raise ValueError('Unknown backend.') raise ValueError('Unknown backend.')
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment