Commit 9dcc7a15 authored by flyingdown's avatar flyingdown
Browse files

init v0.10.0

parent db2b0b79
Pipeline #254 failed with stages
in 0 seconds
from unittest.mock import patch
import warnings
import tarfile
import torch
from torchaudio.backend import soundfile_backend
from torchaudio._internal import module_utils as _mod_utils
from torchaudio_unittest.common_utils import (
TempDirMixin,
PytorchTestCase,
skipIfNoModule,
get_wav_data,
save_wav,
nested_params,
)
from torchaudio_unittest.backend.common import (
get_bits_per_sample,
get_encoding,
)
from .common import skipIfFormatNotSupported, parameterize
if _mod_utils.is_module_available("soundfile"):
import soundfile
@skipIfNoModule("soundfile")
class TestInfo(TempDirMixin, PytorchTestCase):
@parameterize(
["float32", "int32", "int16", "uint8"], [8000, 16000], [1, 2],
)
def test_wav(self, dtype, sample_rate, num_channels):
"""`soundfile_backend.info` can check wav file correctly"""
duration = 1
path = self.get_temp_path("data.wav")
data = get_wav_data(
dtype, num_channels, normalize=False, num_frames=duration * sample_rate
)
save_wav(path, data, sample_rate)
info = soundfile_backend.info(path)
assert info.sample_rate == sample_rate
assert info.num_frames == sample_rate * duration
assert info.num_channels == num_channels
assert info.bits_per_sample == get_bits_per_sample("wav", dtype)
assert info.encoding == get_encoding("wav", dtype)
@parameterize([8000, 16000], [1, 2])
@skipIfFormatNotSupported("FLAC")
def test_flac(self, sample_rate, num_channels):
"""`soundfile_backend.info` can check flac file correctly"""
duration = 1
num_frames = sample_rate * duration
data = torch.randn(num_frames, num_channels).numpy()
path = self.get_temp_path("data.flac")
soundfile.write(path, data, sample_rate)
info = soundfile_backend.info(path)
assert info.sample_rate == sample_rate
assert info.num_frames == num_frames
assert info.num_channels == num_channels
assert info.bits_per_sample == 16
assert info.encoding == "FLAC"
@parameterize([8000, 16000], [1, 2])
@skipIfFormatNotSupported("OGG")
def test_ogg(self, sample_rate, num_channels):
"""`soundfile_backend.info` can check ogg file correctly"""
duration = 1
num_frames = sample_rate * duration
data = torch.randn(num_frames, num_channels).numpy()
path = self.get_temp_path("data.ogg")
soundfile.write(path, data, sample_rate)
info = soundfile_backend.info(path)
assert info.sample_rate == sample_rate
assert info.num_frames == sample_rate * duration
assert info.num_channels == num_channels
assert info.bits_per_sample == 0
assert info.encoding == "VORBIS"
@nested_params(
[8000, 16000],
[1, 2],
[
('PCM_24', 24),
('PCM_32', 32)
],
)
@skipIfFormatNotSupported("NIST")
def test_sphere(self, sample_rate, num_channels, subtype_and_bit_depth):
"""`soundfile_backend.info` can check sph file correctly"""
duration = 1
num_frames = sample_rate * duration
data = torch.randn(num_frames, num_channels).numpy()
path = self.get_temp_path("data.nist")
subtype, bits_per_sample = subtype_and_bit_depth
soundfile.write(path, data, sample_rate, subtype=subtype)
info = soundfile_backend.info(path)
assert info.sample_rate == sample_rate
assert info.num_frames == sample_rate * duration
assert info.num_channels == num_channels
assert info.bits_per_sample == bits_per_sample
assert info.encoding == "PCM_S"
def test_unknown_subtype_warning(self):
"""soundfile_backend.info issues a warning when the subtype is unknown
This will happen if a new subtype is supported in SoundFile: the _SUBTYPE_TO_BITS_PER_SAMPLE
dict should be updated.
"""
def _mock_info_func(_):
class MockSoundFileInfo:
samplerate = 8000
frames = 356
channels = 2
subtype = 'UNSEEN_SUBTYPE'
format = 'UNKNOWN'
return MockSoundFileInfo()
with patch("soundfile.info", _mock_info_func):
with warnings.catch_warnings(record=True) as w:
info = soundfile_backend.info("foo")
assert len(w) == 1
assert "UNSEEN_SUBTYPE subtype is unknown to TorchAudio" in str(w[-1].message)
assert info.bits_per_sample == 0
@skipIfNoModule("soundfile")
class TestFileObject(TempDirMixin, PytorchTestCase):
def _test_fileobj(self, ext, subtype, bits_per_sample):
"""Query audio via file-like object works"""
duration = 2
sample_rate = 16000
num_channels = 2
num_frames = sample_rate * duration
path = self.get_temp_path(f'test.{ext}')
data = torch.randn(num_frames, num_channels).numpy()
soundfile.write(path, data, sample_rate, subtype=subtype)
with open(path, 'rb') as fileobj:
info = soundfile_backend.info(fileobj)
assert info.sample_rate == sample_rate
assert info.num_frames == num_frames
assert info.num_channels == num_channels
assert info.bits_per_sample == bits_per_sample
assert info.encoding == "FLAC" if ext == 'flac' else "PCM_S"
def test_fileobj_wav(self):
"""Loading audio via file-like object works"""
self._test_fileobj('wav', 'PCM_16', 16)
@skipIfFormatNotSupported("FLAC")
def test_fileobj_flac(self):
"""Loading audio via file-like object works"""
self._test_fileobj('flac', 'PCM_16', 16)
def _test_tarobj(self, ext, subtype, bits_per_sample):
"""Query compressed audio via file-like object works"""
duration = 2
sample_rate = 16000
num_channels = 2
num_frames = sample_rate * duration
audio_file = f'test.{ext}'
audio_path = self.get_temp_path(audio_file)
archive_path = self.get_temp_path('archive.tar.gz')
data = torch.randn(num_frames, num_channels).numpy()
soundfile.write(audio_path, data, sample_rate, subtype=subtype)
with tarfile.TarFile(archive_path, 'w') as tarobj:
tarobj.add(audio_path, arcname=audio_file)
with tarfile.TarFile(archive_path, 'r') as tarobj:
fileobj = tarobj.extractfile(audio_file)
info = soundfile_backend.info(fileobj)
assert info.sample_rate == sample_rate
assert info.num_frames == num_frames
assert info.num_channels == num_channels
assert info.bits_per_sample == bits_per_sample
assert info.encoding == "FLAC" if ext == 'flac' else "PCM_S"
def test_tarobj_wav(self):
"""Query compressed audio via file-like object works"""
self._test_tarobj('wav', 'PCM_16', 16)
@skipIfFormatNotSupported("FLAC")
def test_tarobj_flac(self):
"""Query compressed audio via file-like object works"""
self._test_tarobj('flac', 'PCM_16', 16)
import os
import tarfile
from unittest.mock import patch
import torch
from torchaudio._internal import module_utils as _mod_utils
from torchaudio.backend import soundfile_backend
from parameterized import parameterized
from torchaudio_unittest.common_utils import (
TempDirMixin,
PytorchTestCase,
skipIfNoModule,
get_wav_data,
normalize_wav,
load_wav,
save_wav,
)
from .common import (
parameterize,
dtype2subtype,
skipIfFormatNotSupported,
)
if _mod_utils.is_module_available("soundfile"):
import soundfile
def _get_mock_path(
ext: str, dtype: str, sample_rate: int, num_channels: int, num_frames: int,
):
return f"{dtype}_{sample_rate}_{num_channels}_{num_frames}.{ext}"
def _get_mock_params(path: str):
filename, ext = path.split(".")
parts = filename.split("_")
return {
"ext": ext,
"dtype": parts[0],
"sample_rate": int(parts[1]),
"num_channels": int(parts[2]),
"num_frames": int(parts[3]),
}
class SoundFileMock:
def __init__(self, path, mode):
assert mode == "r"
self.path = path
self._params = _get_mock_params(path)
self._start = None
@property
def samplerate(self):
return self._params["sample_rate"]
@property
def format(self):
if self._params["ext"] == "wav":
return "WAV"
if self._params["ext"] == "flac":
return "FLAC"
if self._params["ext"] == "ogg":
return "OGG"
if self._params["ext"] in ["sph", "nis", "nist"]:
return "NIST"
@property
def subtype(self):
if self._params["ext"] == "ogg":
return "VORBIS"
return dtype2subtype(self._params["dtype"])
def _prepare_read(self, start, stop, frames):
assert stop is None
self._start = start
return frames
def read(self, frames, dtype, always_2d):
assert always_2d
data = get_wav_data(
dtype,
self._params["num_channels"],
normalize=False,
num_frames=self._params["num_frames"],
channels_first=False,
).numpy()
return data[self._start:self._start + frames]
def __enter__(self):
return self
def __exit__(self, *args, **kwargs):
pass
class MockedLoadTest(PytorchTestCase):
def assert_dtype(
self, ext, dtype, sample_rate, num_channels, normalize, channels_first
):
"""When format is WAV or NIST, normalize=False will return the native dtype Tensor, otherwise float32"""
num_frames = 3 * sample_rate
path = _get_mock_path(ext, dtype, sample_rate, num_channels, num_frames)
expected_dtype = (
torch.float32
if normalize or ext not in ["wav", "nist"]
else getattr(torch, dtype)
)
with patch("soundfile.SoundFile", SoundFileMock):
found, sr = soundfile_backend.load(
path, normalize=normalize, channels_first=channels_first
)
assert found.dtype == expected_dtype
assert sample_rate == sr
@parameterize(
["uint8", "int16", "int32", "float32", "float64"],
[8000, 16000],
[1, 2],
[True, False],
[True, False],
)
def test_wav(self, dtype, sample_rate, num_channels, normalize, channels_first):
"""Returns native dtype when normalize=False else float32"""
self.assert_dtype(
"wav", dtype, sample_rate, num_channels, normalize, channels_first
)
@parameterize(
["int8", "int16", "int32"], [8000, 16000], [1, 2], [True, False], [True, False],
)
def test_sphere(self, dtype, sample_rate, num_channels, normalize, channels_first):
"""Returns float32 always"""
self.assert_dtype(
"sph", dtype, sample_rate, num_channels, normalize, channels_first
)
@parameterize([8000, 16000], [1, 2], [True, False], [True, False])
def test_ogg(self, sample_rate, num_channels, normalize, channels_first):
"""Returns float32 always"""
self.assert_dtype(
"ogg", "int16", sample_rate, num_channels, normalize, channels_first
)
@parameterize([8000, 16000], [1, 2], [True, False], [True, False])
def test_flac(self, sample_rate, num_channels, normalize, channels_first):
"""`soundfile_backend.load` can load ogg format."""
self.assert_dtype(
"flac", "int16", sample_rate, num_channels, normalize, channels_first
)
class LoadTestBase(TempDirMixin, PytorchTestCase):
def assert_wav(
self,
dtype,
sample_rate,
num_channels,
normalize,
channels_first=True,
duration=1,
):
"""`soundfile_backend.load` can load wav format correctly.
Wav data loaded with soundfile backend should match those with scipy
"""
path = self.get_temp_path("reference.wav")
num_frames = duration * sample_rate
data = get_wav_data(
dtype,
num_channels,
normalize=normalize,
num_frames=num_frames,
channels_first=channels_first,
)
save_wav(path, data, sample_rate, channels_first=channels_first)
expected = load_wav(path, normalize=normalize, channels_first=channels_first)[0]
data, sr = soundfile_backend.load(
path, normalize=normalize, channels_first=channels_first
)
assert sr == sample_rate
self.assertEqual(data, expected)
def assert_sphere(
self, dtype, sample_rate, num_channels, channels_first=True, duration=1,
):
"""`soundfile_backend.load` can load SPHERE format correctly."""
path = self.get_temp_path("reference.sph")
num_frames = duration * sample_rate
raw = get_wav_data(
dtype,
num_channels,
num_frames=num_frames,
normalize=False,
channels_first=False,
)
soundfile.write(
path, raw, sample_rate, subtype=dtype2subtype(dtype), format="NIST"
)
expected = normalize_wav(raw.t() if channels_first else raw)
data, sr = soundfile_backend.load(path, channels_first=channels_first)
assert sr == sample_rate
self.assertEqual(data, expected, atol=1e-4, rtol=1e-8)
def assert_flac(
self, dtype, sample_rate, num_channels, channels_first=True, duration=1,
):
"""`soundfile_backend.load` can load FLAC format correctly."""
path = self.get_temp_path("reference.flac")
num_frames = duration * sample_rate
raw = get_wav_data(
dtype,
num_channels,
num_frames=num_frames,
normalize=False,
channels_first=False,
)
soundfile.write(path, raw, sample_rate)
expected = normalize_wav(raw.t() if channels_first else raw)
data, sr = soundfile_backend.load(path, channels_first=channels_first)
assert sr == sample_rate
self.assertEqual(data, expected, atol=1e-4, rtol=1e-8)
@skipIfNoModule("soundfile")
class TestLoad(LoadTestBase):
"""Test the correctness of `soundfile_backend.load` for various formats"""
@parameterize(
["float32", "int32", "int16"],
[8000, 16000],
[1, 2],
[False, True],
[False, True],
)
def test_wav(self, dtype, sample_rate, num_channels, normalize, channels_first):
"""`soundfile_backend.load` can load wav format correctly."""
self.assert_wav(dtype, sample_rate, num_channels, normalize, channels_first)
@parameterize(
["int16"], [16000], [2], [False],
)
def test_wav_large(self, dtype, sample_rate, num_channels, normalize):
"""`soundfile_backend.load` can load large wav file correctly."""
two_hours = 2 * 60 * 60
self.assert_wav(dtype, sample_rate, num_channels, normalize, duration=two_hours)
@parameterize(["float32", "int32", "int16"], [4, 8, 16, 32], [False, True])
def test_multiple_channels(self, dtype, num_channels, channels_first):
"""`soundfile_backend.load` can load wav file with more than 2 channels."""
sample_rate = 8000
normalize = False
self.assert_wav(dtype, sample_rate, num_channels, normalize, channels_first)
@parameterize(["int32", "int16"], [8000, 16000], [1, 2], [False, True])
@skipIfFormatNotSupported("NIST")
def test_sphere(self, dtype, sample_rate, num_channels, channels_first):
"""`soundfile_backend.load` can load sphere format correctly."""
self.assert_sphere(dtype, sample_rate, num_channels, channels_first)
@parameterize(["int32", "int16"], [8000, 16000], [1, 2], [False, True])
@skipIfFormatNotSupported("FLAC")
def test_flac(self, dtype, sample_rate, num_channels, channels_first):
"""`soundfile_backend.load` can load flac format correctly."""
self.assert_flac(dtype, sample_rate, num_channels, channels_first)
@skipIfNoModule("soundfile")
class TestLoadFormat(TempDirMixin, PytorchTestCase):
"""Given `format` parameter, `so.load` can load files without extension"""
original = None
path = None
def _make_file(self, format_):
sample_rate = 8000
path_with_ext = self.get_temp_path(f'test.{format_}')
data = get_wav_data('float32', num_channels=2).numpy().T
soundfile.write(path_with_ext, data, sample_rate)
expected = soundfile.read(path_with_ext, dtype='float32')[0].T
path = os.path.splitext(path_with_ext)[0]
os.rename(path_with_ext, path)
return path, expected
def _test_format(self, format_):
"""Providing format allows to read file without extension"""
path, expected = self._make_file(format_)
found, _ = soundfile_backend.load(path)
self.assertEqual(found, expected)
@parameterized.expand([
('WAV', ), ('wav', ),
])
def test_wav(self, format_):
self._test_format(format_)
@parameterized.expand([
('FLAC', ), ('flac',),
])
@skipIfFormatNotSupported("FLAC")
def test_flac(self, format_):
self._test_format(format_)
@skipIfNoModule("soundfile")
class TestFileObject(TempDirMixin, PytorchTestCase):
def _test_fileobj(self, ext):
"""Loading audio via file-like object works"""
sample_rate = 16000
path = self.get_temp_path(f'test.{ext}')
data = get_wav_data('float32', num_channels=2).numpy().T
soundfile.write(path, data, sample_rate)
expected = soundfile.read(path, dtype='float32')[0].T
with open(path, 'rb') as fileobj:
found, sr = soundfile_backend.load(fileobj)
assert sr == sample_rate
self.assertEqual(expected, found)
def test_fileobj_wav(self):
"""Loading audio via file-like object works"""
self._test_fileobj('wav')
@skipIfFormatNotSupported("FLAC")
def test_fileobj_flac(self):
"""Loading audio via file-like object works"""
self._test_fileobj('flac')
def _test_tarfile(self, ext):
"""Loading audio via file-like object works"""
sample_rate = 16000
audio_file = f'test.{ext}'
audio_path = self.get_temp_path(audio_file)
archive_path = self.get_temp_path('archive.tar.gz')
data = get_wav_data('float32', num_channels=2).numpy().T
soundfile.write(audio_path, data, sample_rate)
expected = soundfile.read(audio_path, dtype='float32')[0].T
with tarfile.TarFile(archive_path, 'w') as tarobj:
tarobj.add(audio_path, arcname=audio_file)
with tarfile.TarFile(archive_path, 'r') as tarobj:
fileobj = tarobj.extractfile(audio_file)
found, sr = soundfile_backend.load(fileobj)
assert sr == sample_rate
self.assertEqual(expected, found)
def test_tarfile_wav(self):
"""Loading audio via file-like object works"""
self._test_tarfile('wav')
@skipIfFormatNotSupported("FLAC")
def test_tarfile_flac(self):
"""Loading audio via file-like object works"""
self._test_tarfile('flac')
import io
from unittest.mock import patch
from torchaudio._internal import module_utils as _mod_utils
from torchaudio.backend import soundfile_backend
from torchaudio_unittest.common_utils import (
TempDirMixin,
PytorchTestCase,
skipIfNoModule,
get_wav_data,
load_wav,
nested_params,
)
from .common import (
fetch_wav_subtype,
parameterize,
skipIfFormatNotSupported,
)
if _mod_utils.is_module_available("soundfile"):
import soundfile
class MockedSaveTest(PytorchTestCase):
@nested_params(
["float32", "int32", "int16", "uint8"],
[8000, 16000],
[1, 2],
[False, True],
[
(None, None),
('PCM_U', None),
('PCM_U', 8),
('PCM_S', None),
('PCM_S', 16),
('PCM_S', 32),
('PCM_F', None),
('PCM_F', 32),
('PCM_F', 64),
('ULAW', None),
('ULAW', 8),
('ALAW', None),
('ALAW', 8),
],
)
@patch("soundfile.write")
def test_wav(self, dtype, sample_rate, num_channels, channels_first,
enc_params, mocked_write):
"""soundfile_backend.save passes correct subtype to soundfile.write when WAV"""
filepath = "foo.wav"
input_tensor = get_wav_data(
dtype,
num_channels,
num_frames=3 * sample_rate,
normalize=dtype == "float32",
channels_first=channels_first,
).t()
encoding, bits_per_sample = enc_params
soundfile_backend.save(
filepath, input_tensor, sample_rate, channels_first=channels_first,
encoding=encoding, bits_per_sample=bits_per_sample
)
# on +Py3.8 call_args.kwargs is more descreptive
args = mocked_write.call_args[1]
assert args["file"] == filepath
assert args["samplerate"] == sample_rate
assert args["subtype"] == fetch_wav_subtype(
dtype, encoding, bits_per_sample)
assert args["format"] is None
self.assertEqual(
args["data"], input_tensor.t() if channels_first else input_tensor
)
@patch("soundfile.write")
def assert_non_wav(
self, fmt, dtype, sample_rate, num_channels, channels_first, mocked_write,
encoding=None, bits_per_sample=None,
):
"""soundfile_backend.save passes correct subtype and format to soundfile.write when SPHERE"""
filepath = f"foo.{fmt}"
input_tensor = get_wav_data(
dtype,
num_channels,
num_frames=3 * sample_rate,
normalize=False,
channels_first=channels_first,
).t()
expected_data = input_tensor.t() if channels_first else input_tensor
soundfile_backend.save(
filepath, input_tensor, sample_rate, channels_first,
encoding=encoding, bits_per_sample=bits_per_sample,
)
# on +Py3.8 call_args.kwargs is more descreptive
args = mocked_write.call_args[1]
assert args["file"] == filepath
assert args["samplerate"] == sample_rate
if fmt in ["sph", "nist", "nis"]:
assert args["format"] == "NIST"
else:
assert args["format"] is None
self.assertEqual(args["data"], expected_data)
@nested_params(
["sph", "nist", "nis"],
["int32", "int16"],
[8000, 16000],
[1, 2],
[False, True],
[
('PCM_S', 8),
('PCM_S', 16),
('PCM_S', 24),
('PCM_S', 32),
('ULAW', 8),
('ALAW', 8),
('ALAW', 16),
('ALAW', 24),
('ALAW', 32),
],
)
def test_sph(self, fmt, dtype, sample_rate, num_channels, channels_first, enc_params):
"""soundfile_backend.save passes default format and subtype (None-s) to
soundfile.write when not WAV"""
encoding, bits_per_sample = enc_params
self.assert_non_wav(fmt, dtype, sample_rate, num_channels,
channels_first, encoding=encoding,
bits_per_sample=bits_per_sample)
@parameterize(
["int32", "int16"], [8000, 16000], [1, 2], [False, True],
[8, 16, 24],
)
def test_flac(self, dtype, sample_rate, num_channels,
channels_first, bits_per_sample):
"""soundfile_backend.save passes default format and subtype (None-s) to
soundfile.write when not WAV"""
self.assert_non_wav("flac", dtype, sample_rate, num_channels,
channels_first, bits_per_sample=bits_per_sample)
@parameterize(
["int32", "int16"], [8000, 16000], [1, 2], [False, True],
)
def test_ogg(self, dtype, sample_rate, num_channels, channels_first):
"""soundfile_backend.save passes default format and subtype (None-s) to
soundfile.write when not WAV"""
self.assert_non_wav("ogg", dtype, sample_rate, num_channels, channels_first)
@skipIfNoModule("soundfile")
class SaveTestBase(TempDirMixin, PytorchTestCase):
def assert_wav(self, dtype, sample_rate, num_channels, num_frames):
"""`soundfile_backend.save` can save wav format."""
path = self.get_temp_path("data.wav")
expected = get_wav_data(
dtype, num_channels, num_frames=num_frames, normalize=False
)
soundfile_backend.save(path, expected, sample_rate)
found, sr = load_wav(path, normalize=False)
assert sample_rate == sr
self.assertEqual(found, expected)
def _assert_non_wav(self, fmt, dtype, sample_rate, num_channels):
"""`soundfile_backend.save` can save non-wav format.
Due to precision missmatch, and the lack of alternative way to decode the
resulting files without using soundfile, only meta data are validated.
"""
num_frames = sample_rate * 3
path = self.get_temp_path(f"data.{fmt}")
expected = get_wav_data(
dtype, num_channels, num_frames=num_frames, normalize=False
)
soundfile_backend.save(path, expected, sample_rate)
sinfo = soundfile.info(path)
assert sinfo.format == fmt.upper()
assert sinfo.frames == num_frames
assert sinfo.channels == num_channels
assert sinfo.samplerate == sample_rate
def assert_flac(self, dtype, sample_rate, num_channels):
"""`soundfile_backend.save` can save flac format."""
self._assert_non_wav("flac", dtype, sample_rate, num_channels)
def assert_sphere(self, dtype, sample_rate, num_channels):
"""`soundfile_backend.save` can save sph format."""
self._assert_non_wav("nist", dtype, sample_rate, num_channels)
def assert_ogg(self, dtype, sample_rate, num_channels):
"""`soundfile_backend.save` can save ogg format.
As we cannot inspect the OGG format (it's lossy), we only check the metadata.
"""
self._assert_non_wav("ogg", dtype, sample_rate, num_channels)
@skipIfNoModule("soundfile")
class TestSave(SaveTestBase):
@parameterize(
["float32", "int32", "int16"], [8000, 16000], [1, 2],
)
def test_wav(self, dtype, sample_rate, num_channels):
"""`soundfile_backend.save` can save wav format."""
self.assert_wav(dtype, sample_rate, num_channels, num_frames=None)
@parameterize(
["float32", "int32", "int16"], [4, 8, 16, 32],
)
def test_multiple_channels(self, dtype, num_channels):
"""`soundfile_backend.save` can save wav with more than 2 channels."""
sample_rate = 8000
self.assert_wav(dtype, sample_rate, num_channels, num_frames=None)
@parameterize(
["int32", "int16"], [8000, 16000], [1, 2],
)
@skipIfFormatNotSupported("NIST")
def test_sphere(self, dtype, sample_rate, num_channels):
"""`soundfile_backend.save` can save sph format."""
self.assert_sphere(dtype, sample_rate, num_channels)
@parameterize(
[8000, 16000], [1, 2],
)
@skipIfFormatNotSupported("FLAC")
def test_flac(self, sample_rate, num_channels):
"""`soundfile_backend.save` can save flac format."""
self.assert_flac("float32", sample_rate, num_channels)
@parameterize(
[8000, 16000], [1, 2],
)
@skipIfFormatNotSupported("OGG")
def test_ogg(self, sample_rate, num_channels):
"""`soundfile_backend.save` can save ogg/vorbis format."""
self.assert_ogg("float32", sample_rate, num_channels)
@skipIfNoModule("soundfile")
class TestSaveParams(TempDirMixin, PytorchTestCase):
"""Test the correctness of optional parameters of `soundfile_backend.save`"""
@parameterize([True, False])
def test_channels_first(self, channels_first):
"""channels_first swaps axes"""
path = self.get_temp_path("data.wav")
data = get_wav_data("int32", 2, channels_first=channels_first)
soundfile_backend.save(path, data, 8000, channels_first=channels_first)
found = load_wav(path)[0]
expected = data if channels_first else data.transpose(1, 0)
self.assertEqual(found, expected, atol=1e-4, rtol=1e-8)
@skipIfNoModule("soundfile")
class TestFileObject(TempDirMixin, PytorchTestCase):
def _test_fileobj(self, ext):
"""Saving audio to file-like object works"""
sample_rate = 16000
path = self.get_temp_path(f'test.{ext}')
subtype = 'FLOAT' if ext == 'wav' else None
data = get_wav_data('float32', num_channels=2)
soundfile.write(path, data.numpy().T, sample_rate, subtype=subtype)
expected = soundfile.read(path, dtype='float32')[0]
fileobj = io.BytesIO()
soundfile_backend.save(fileobj, data, sample_rate, format=ext)
fileobj.seek(0)
found, sr = soundfile.read(fileobj, dtype='float32')
assert sr == sample_rate
self.assertEqual(expected, found, atol=1e-4, rtol=1e-8)
def test_fileobj_wav(self):
"""Saving audio via file-like object works"""
self._test_fileobj('wav')
@skipIfFormatNotSupported("FLAC")
def test_fileobj_flac(self):
"""Saving audio via file-like object works"""
self._test_fileobj('flac')
@skipIfFormatNotSupported("NIST")
def test_fileobj_nist(self):
"""Saving audio via file-like object works"""
self._test_fileobj('NIST')
@skipIfFormatNotSupported("OGG")
def test_fileobj_ogg(self):
"""Saving audio via file-like object works"""
self._test_fileobj('OGG')
def name_func(func, _, params):
return f'{func.__name__}_{"_".join(str(arg) for arg in params.args)}'
def get_enc_params(dtype):
if dtype == 'float32':
return 'PCM_F', 32
if dtype == 'int32':
return 'PCM_S', 32
if dtype == 'int16':
return 'PCM_S', 16
if dtype == 'uint8':
return 'PCM_U', 8
raise ValueError(f'Unexpected dtype: {dtype}')
from contextlib import contextmanager
import io
import os
import itertools
import tarfile
from parameterized import parameterized
from torchaudio.backend import sox_io_backend
from torchaudio.utils.sox_utils import get_buffer_size, set_buffer_size
from torchaudio._internal import module_utils as _mod_utils
from torchaudio_unittest.backend.common import (
get_bits_per_sample,
get_encoding,
)
from torchaudio_unittest.common_utils import (
TempDirMixin,
HttpServerMixin,
PytorchTestCase,
skipIfNoExec,
skipIfNoModule,
skipIfNoSox,
get_asset_path,
get_wav_data,
save_wav,
sox_utils,
)
from .common import (
name_func,
)
if _mod_utils.is_module_available("requests"):
import requests
@skipIfNoExec('sox')
@skipIfNoSox
class TestInfo(TempDirMixin, PytorchTestCase):
@parameterized.expand(list(itertools.product(
['float32', 'int32', 'int16', 'uint8'],
[8000, 16000],
[1, 2],
)), name_func=name_func)
def test_wav(self, dtype, sample_rate, num_channels):
"""`sox_io_backend.info` can check wav file correctly"""
duration = 1
path = self.get_temp_path('data.wav')
data = get_wav_data(dtype, num_channels, normalize=False, num_frames=duration * sample_rate)
save_wav(path, data, sample_rate)
info = sox_io_backend.info(path)
assert info.sample_rate == sample_rate
assert info.num_frames == sample_rate * duration
assert info.num_channels == num_channels
assert info.bits_per_sample == sox_utils.get_bit_depth(dtype)
assert info.encoding == get_encoding('wav', dtype)
@parameterized.expand(list(itertools.product(
['float32', 'int32', 'int16', 'uint8'],
[8000, 16000],
[4, 8, 16, 32],
)), name_func=name_func)
def test_wav_multiple_channels(self, dtype, sample_rate, num_channels):
"""`sox_io_backend.info` can check wav file with channels more than 2 correctly"""
duration = 1
path = self.get_temp_path('data.wav')
data = get_wav_data(dtype, num_channels, normalize=False, num_frames=duration * sample_rate)
save_wav(path, data, sample_rate)
info = sox_io_backend.info(path)
assert info.sample_rate == sample_rate
assert info.num_frames == sample_rate * duration
assert info.num_channels == num_channels
assert info.bits_per_sample == sox_utils.get_bit_depth(dtype)
assert info.encoding == get_encoding('wav', dtype)
@parameterized.expand(list(itertools.product(
[8000, 16000],
[1, 2],
[96, 128, 160, 192, 224, 256, 320],
)), name_func=name_func)
def test_mp3(self, sample_rate, num_channels, bit_rate):
"""`sox_io_backend.info` can check mp3 file correctly"""
duration = 1
path = self.get_temp_path('data.mp3')
sox_utils.gen_audio_file(
path, sample_rate, num_channels,
compression=bit_rate, duration=duration,
)
info = sox_io_backend.info(path)
assert info.sample_rate == sample_rate
# mp3 does not preserve the number of samples
# assert info.num_frames == sample_rate * duration
assert info.num_channels == num_channels
assert info.bits_per_sample == 0 # bit_per_sample is irrelevant for compressed formats
assert info.encoding == "MP3"
@parameterized.expand(list(itertools.product(
[8000, 16000],
[1, 2],
list(range(9)),
)), name_func=name_func)
def test_flac(self, sample_rate, num_channels, compression_level):
"""`sox_io_backend.info` can check flac file correctly"""
duration = 1
path = self.get_temp_path('data.flac')
sox_utils.gen_audio_file(
path, sample_rate, num_channels,
compression=compression_level, duration=duration,
)
info = sox_io_backend.info(path)
assert info.sample_rate == sample_rate
assert info.num_frames == sample_rate * duration
assert info.num_channels == num_channels
assert info.bits_per_sample == 24 # FLAC standard
assert info.encoding == "FLAC"
@parameterized.expand(list(itertools.product(
[8000, 16000],
[1, 2],
[-1, 0, 1, 2, 3, 3.6, 5, 10],
)), name_func=name_func)
def test_vorbis(self, sample_rate, num_channels, quality_level):
"""`sox_io_backend.info` can check vorbis file correctly"""
duration = 1
path = self.get_temp_path('data.vorbis')
sox_utils.gen_audio_file(
path, sample_rate, num_channels,
compression=quality_level, duration=duration,
)
info = sox_io_backend.info(path)
assert info.sample_rate == sample_rate
assert info.num_frames == sample_rate * duration
assert info.num_channels == num_channels
assert info.bits_per_sample == 0 # bit_per_sample is irrelevant for compressed formats
assert info.encoding == "VORBIS"
@parameterized.expand(list(itertools.product(
[8000, 16000],
[1, 2],
[16, 32],
)), name_func=name_func)
def test_sphere(self, sample_rate, num_channels, bits_per_sample):
"""`sox_io_backend.info` can check sph file correctly"""
duration = 1
path = self.get_temp_path('data.sph')
sox_utils.gen_audio_file(
path, sample_rate, num_channels, duration=duration,
bit_depth=bits_per_sample)
info = sox_io_backend.info(path)
assert info.sample_rate == sample_rate
assert info.num_frames == sample_rate * duration
assert info.num_channels == num_channels
assert info.bits_per_sample == bits_per_sample
assert info.encoding == "PCM_S"
@parameterized.expand(list(itertools.product(
['int32', 'int16', 'uint8'],
[8000, 16000],
[1, 2],
)), name_func=name_func)
def test_amb(self, dtype, sample_rate, num_channels):
"""`sox_io_backend.info` can check amb file correctly"""
duration = 1
path = self.get_temp_path('data.amb')
bits_per_sample = sox_utils.get_bit_depth(dtype)
sox_utils.gen_audio_file(
path, sample_rate, num_channels,
bit_depth=bits_per_sample, duration=duration)
info = sox_io_backend.info(path)
assert info.sample_rate == sample_rate
assert info.num_frames == sample_rate * duration
assert info.num_channels == num_channels
assert info.bits_per_sample == bits_per_sample
assert info.encoding == get_encoding("amb", dtype)
def test_amr_nb(self):
"""`sox_io_backend.info` can check amr-nb file correctly"""
duration = 1
num_channels = 1
sample_rate = 8000
path = self.get_temp_path('data.amr-nb')
sox_utils.gen_audio_file(
path, sample_rate=sample_rate, num_channels=num_channels, bit_depth=16,
duration=duration)
info = sox_io_backend.info(path)
assert info.sample_rate == sample_rate
assert info.num_frames == sample_rate * duration
assert info.num_channels == num_channels
assert info.bits_per_sample == 0
assert info.encoding == "AMR_NB"
def test_ulaw(self):
"""`sox_io_backend.info` can check ulaw file correctly"""
duration = 1
num_channels = 1
sample_rate = 8000
path = self.get_temp_path('data.wav')
sox_utils.gen_audio_file(
path, sample_rate=sample_rate, num_channels=num_channels,
bit_depth=8, encoding='u-law',
duration=duration)
info = sox_io_backend.info(path)
assert info.sample_rate == sample_rate
assert info.num_frames == sample_rate * duration
assert info.num_channels == num_channels
assert info.bits_per_sample == 8
assert info.encoding == "ULAW"
def test_alaw(self):
"""`sox_io_backend.info` can check alaw file correctly"""
duration = 1
num_channels = 1
sample_rate = 8000
path = self.get_temp_path('data.wav')
sox_utils.gen_audio_file(
path, sample_rate=sample_rate, num_channels=num_channels,
bit_depth=8, encoding='a-law',
duration=duration)
info = sox_io_backend.info(path)
assert info.sample_rate == sample_rate
assert info.num_frames == sample_rate * duration
assert info.num_channels == num_channels
assert info.bits_per_sample == 8
assert info.encoding == "ALAW"
def test_gsm(self):
"""`sox_io_backend.info` can check gsm file correctly"""
duration = 1
num_channels = 1
sample_rate = 8000
path = self.get_temp_path('data.gsm')
sox_utils.gen_audio_file(
path, sample_rate=sample_rate, num_channels=num_channels,
duration=duration)
info = sox_io_backend.info(path)
assert info.sample_rate == sample_rate
assert info.num_channels == num_channels
assert info.bits_per_sample == 0
assert info.encoding == "GSM"
def test_htk(self):
"""`sox_io_backend.info` can check HTK file correctly"""
duration = 1
num_channels = 1
sample_rate = 8000
path = self.get_temp_path('data.htk')
sox_utils.gen_audio_file(
path, sample_rate=sample_rate, num_channels=num_channels,
bit_depth=16, duration=duration)
info = sox_io_backend.info(path)
assert info.sample_rate == sample_rate
assert info.num_frames == sample_rate * duration
assert info.num_channels == num_channels
assert info.bits_per_sample == 16
assert info.encoding == "PCM_S"
@skipIfNoSox
class TestInfoOpus(PytorchTestCase):
@parameterized.expand(list(itertools.product(
['96k'],
[1, 2],
[0, 5, 10],
)), name_func=name_func)
def test_opus(self, bitrate, num_channels, compression_level):
"""`sox_io_backend.info` can check opus file correcty"""
path = get_asset_path('io', f'{bitrate}_{compression_level}_{num_channels}ch.opus')
info = sox_io_backend.info(path)
assert info.sample_rate == 48000
assert info.num_frames == 32768
assert info.num_channels == num_channels
assert info.bits_per_sample == 0 # bit_per_sample is irrelevant for compressed formats
assert info.encoding == "OPUS"
@skipIfNoSox
class TestLoadWithoutExtension(PytorchTestCase):
def test_mp3(self):
"""Providing `format` allows to read mp3 without extension
libsox does not check header for mp3
https://github.com/pytorch/audio/issues/1040
The file was generated with the following command
ffmpeg -f lavfi -i "sine=frequency=1000:duration=5" -ar 16000 -f mp3 test_noext
"""
path = get_asset_path("mp3_without_ext")
sinfo = sox_io_backend.info(path, format="mp3")
assert sinfo.sample_rate == 16000
assert sinfo.num_frames == 81216
assert sinfo.num_channels == 1
assert sinfo.bits_per_sample == 0 # bit_per_sample is irrelevant for compressed formats
assert sinfo.encoding == "MP3"
class FileObjTestBase(TempDirMixin):
def _gen_file(self, ext, dtype, sample_rate, num_channels, num_frames, *, comments=None):
path = self.get_temp_path(f'test.{ext}')
bit_depth = sox_utils.get_bit_depth(dtype)
duration = num_frames / sample_rate
comment_file = self._gen_comment_file(comments) if comments else None
sox_utils.gen_audio_file(
path, sample_rate, num_channels=num_channels,
encoding=sox_utils.get_encoding(dtype),
bit_depth=bit_depth,
duration=duration,
comment_file=comment_file,
)
return path
def _gen_comment_file(self, comments):
comment_path = self.get_temp_path("comment.txt")
with open(comment_path, "w") as file_:
file_.writelines(comments)
return comment_path
@skipIfNoSox
@skipIfNoExec('sox')
class TestFileObject(FileObjTestBase, PytorchTestCase):
def _query_fileobj(self, ext, dtype, sample_rate, num_channels, num_frames, *, comments=None):
path = self._gen_file(ext, dtype, sample_rate, num_channels, num_frames, comments=comments)
format_ = ext if ext in ['mp3'] else None
with open(path, 'rb') as fileobj:
return sox_io_backend.info(fileobj, format_)
def _query_bytesio(self, ext, dtype, sample_rate, num_channels, num_frames):
path = self._gen_file(ext, dtype, sample_rate, num_channels, num_frames)
format_ = ext if ext in ['mp3'] else None
with open(path, 'rb') as file_:
fileobj = io.BytesIO(file_.read())
return sox_io_backend.info(fileobj, format_)
def _query_tarfile(self, ext, dtype, sample_rate, num_channels, num_frames):
audio_path = self._gen_file(ext, dtype, sample_rate, num_channels, num_frames)
audio_file = os.path.basename(audio_path)
archive_path = self.get_temp_path('archive.tar.gz')
with tarfile.TarFile(archive_path, 'w') as tarobj:
tarobj.add(audio_path, arcname=audio_file)
format_ = ext if ext in ['mp3'] else None
with tarfile.TarFile(archive_path, 'r') as tarobj:
fileobj = tarobj.extractfile(audio_file)
return sox_io_backend.info(fileobj, format_)
@contextmanager
def _set_buffer_size(self, buffer_size):
try:
original_buffer_size = get_buffer_size()
set_buffer_size(buffer_size)
yield
finally:
set_buffer_size(original_buffer_size)
@parameterized.expand([
('wav', "float32"),
('wav', "int32"),
('wav', "int16"),
('wav', "uint8"),
('mp3', "float32"),
('flac', "float32"),
('vorbis', "float32"),
('amb', "int16"),
])
def test_fileobj(self, ext, dtype):
"""Querying audio via file object works"""
sample_rate = 16000
num_frames = 3 * sample_rate
num_channels = 2
sinfo = self._query_fileobj(ext, dtype, sample_rate, num_channels, num_frames)
bits_per_sample = get_bits_per_sample(ext, dtype)
num_frames = 0 if ext in ['mp3', 'vorbis'] else num_frames
assert sinfo.sample_rate == sample_rate
assert sinfo.num_channels == num_channels
assert sinfo.num_frames == num_frames
assert sinfo.bits_per_sample == bits_per_sample
assert sinfo.encoding == get_encoding(ext, dtype)
@parameterized.expand([
('vorbis', "float32"),
])
def test_fileobj_large_header(self, ext, dtype):
"""
For audio file with header size exceeding default buffer size:
- Querying audio via file object without enlarging buffer size fails.
- Querying audio via file object after enlarging buffer size succeeds.
"""
sample_rate = 16000
num_frames = 3 * sample_rate
num_channels = 2
comments = "metadata=" + " ".join(["value" for _ in range(1000)])
with self.assertRaisesRegex(RuntimeError, "^Error loading audio file:"):
sinfo = self._query_fileobj(ext, dtype, sample_rate, num_channels, num_frames, comments=comments)
with self._set_buffer_size(16384):
sinfo = self._query_fileobj(ext, dtype, sample_rate, num_channels, num_frames, comments=comments)
bits_per_sample = get_bits_per_sample(ext, dtype)
num_frames = 0 if ext in ['mp3', 'vorbis'] else num_frames
assert sinfo.sample_rate == sample_rate
assert sinfo.num_channels == num_channels
assert sinfo.num_frames == num_frames
assert sinfo.bits_per_sample == bits_per_sample
assert sinfo.encoding == get_encoding(ext, dtype)
@parameterized.expand([
('wav', "float32"),
('wav', "int32"),
('wav', "int16"),
('wav', "uint8"),
('mp3', "float32"),
('flac', "float32"),
('vorbis', "float32"),
('amb', "int16"),
])
def test_bytesio(self, ext, dtype):
"""Querying audio via ByteIO object works for small data"""
sample_rate = 16000
num_frames = 3 * sample_rate
num_channels = 2
sinfo = self._query_bytesio(ext, dtype, sample_rate, num_channels, num_frames)
bits_per_sample = get_bits_per_sample(ext, dtype)
num_frames = 0 if ext in ['mp3', 'vorbis'] else num_frames
assert sinfo.sample_rate == sample_rate
assert sinfo.num_channels == num_channels
assert sinfo.num_frames == num_frames
assert sinfo.bits_per_sample == bits_per_sample
assert sinfo.encoding == get_encoding(ext, dtype)
@parameterized.expand([
('wav', "float32"),
('wav', "int32"),
('wav', "int16"),
('wav', "uint8"),
('mp3', "float32"),
('flac', "float32"),
('vorbis', "float32"),
('amb', "int16"),
])
def test_bytesio_tiny(self, ext, dtype):
"""Querying audio via ByteIO object works for small data"""
sample_rate = 8000
num_frames = 4
num_channels = 2
sinfo = self._query_bytesio(ext, dtype, sample_rate, num_channels, num_frames)
bits_per_sample = get_bits_per_sample(ext, dtype)
num_frames = 0 if ext in ['mp3', 'vorbis'] else num_frames
assert sinfo.sample_rate == sample_rate
assert sinfo.num_channels == num_channels
assert sinfo.num_frames == num_frames
assert sinfo.bits_per_sample == bits_per_sample
assert sinfo.encoding == get_encoding(ext, dtype)
@parameterized.expand([
('wav', "float32"),
('wav', "int32"),
('wav', "int16"),
('wav', "uint8"),
('mp3', "float32"),
('flac', "float32"),
('vorbis', "float32"),
('amb', "int16"),
])
def test_tarfile(self, ext, dtype):
"""Querying compressed audio via file-like object works"""
sample_rate = 16000
num_frames = 3.0 * sample_rate
num_channels = 2
sinfo = self._query_tarfile(ext, dtype, sample_rate, num_channels, num_frames)
bits_per_sample = get_bits_per_sample(ext, dtype)
num_frames = 0 if ext in ['mp3', 'vorbis'] else num_frames
assert sinfo.sample_rate == sample_rate
assert sinfo.num_channels == num_channels
assert sinfo.num_frames == num_frames
assert sinfo.bits_per_sample == bits_per_sample
assert sinfo.encoding == get_encoding(ext, dtype)
@skipIfNoSox
@skipIfNoExec('sox')
@skipIfNoModule("requests")
class TestFileObjectHttp(HttpServerMixin, FileObjTestBase, PytorchTestCase):
def _query_http(self, ext, dtype, sample_rate, num_channels, num_frames):
audio_path = self._gen_file(ext, dtype, sample_rate, num_channels, num_frames)
audio_file = os.path.basename(audio_path)
url = self.get_url(audio_file)
format_ = ext if ext in ['mp3'] else None
with requests.get(url, stream=True) as resp:
return sox_io_backend.info(resp.raw, format=format_)
@parameterized.expand([
('wav', "float32"),
('wav', "int32"),
('wav', "int16"),
('wav', "uint8"),
('mp3', "float32"),
('flac', "float32"),
('vorbis', "float32"),
('amb', "int16"),
])
def test_requests(self, ext, dtype):
"""Querying compressed audio via requests works"""
sample_rate = 16000
num_frames = 3.0 * sample_rate
num_channels = 2
sinfo = self._query_http(ext, dtype, sample_rate, num_channels, num_frames)
bits_per_sample = get_bits_per_sample(ext, dtype)
num_frames = 0 if ext in ['mp3', 'vorbis'] else num_frames
assert sinfo.sample_rate == sample_rate
assert sinfo.num_channels == num_channels
assert sinfo.num_frames == num_frames
assert sinfo.bits_per_sample == bits_per_sample
assert sinfo.encoding == get_encoding(ext, dtype)
@skipIfNoSox
class TestInfoNoSuchFile(PytorchTestCase):
def test_info_fail(self):
"""
When attempted to get info on a non-existing file, error message must contain the file path.
"""
path = "non_existing_audio.wav"
with self.assertRaisesRegex(RuntimeError, "^Error loading audio file: failed to open file {0}$".format(path)):
sox_io_backend.info(path)
import io
import itertools
import tarfile
from parameterized import parameterized
from torchaudio.backend import sox_io_backend
from torchaudio._internal import module_utils as _mod_utils
from torchaudio_unittest.common_utils import (
TempDirMixin,
HttpServerMixin,
PytorchTestCase,
skipIfNoExec,
skipIfNoModule,
skipIfNoSox,
get_asset_path,
get_wav_data,
load_wav,
save_wav,
sox_utils,
)
from .common import (
name_func,
)
if _mod_utils.is_module_available("requests"):
import requests
class LoadTestBase(TempDirMixin, PytorchTestCase):
def assert_format(
self,
format: str,
sample_rate: float,
num_channels: int,
compression: float = None,
bit_depth: int = None,
duration: float = 1,
normalize: bool = True,
encoding: str = None,
atol: float = 4e-05,
rtol: float = 1.3e-06,
):
"""`sox_io_backend.load` can load given format correctly.
file encodings introduce delay and boundary effects so
we create a reference wav file from the original file format
x
|
| 1. Generate given format with Sox
|
v 2. Convert to wav with Sox
given format ----------------------> wav
| |
| 3. Load with torchaudio | 4. Load with scipy
| |
v v
tensor ----------> x <----------- tensor
5. Compare
Underlying assumptions are;
i. Conversion of given format to wav with Sox preserves data.
ii. Loading wav file with scipy is correct.
By combining i & ii, step 2. and 4. allows to load reference given format
data without using torchaudio
"""
path = self.get_temp_path(f'1.original.{format}')
ref_path = self.get_temp_path('2.reference.wav')
# 1. Generate the given format with sox
sox_utils.gen_audio_file(
path, sample_rate, num_channels, encoding=encoding,
compression=compression, bit_depth=bit_depth, duration=duration,
)
# 2. Convert to wav with sox
wav_bit_depth = 32 if bit_depth == 24 else None # for 24-bit wav
sox_utils.convert_audio_file(path, ref_path, bit_depth=wav_bit_depth)
# 3. Load the given format with torchaudio
data, sr = sox_io_backend.load(path, normalize=normalize)
# 4. Load wav with scipy
data_ref = load_wav(ref_path, normalize=normalize)[0]
# 5. Compare
assert sr == sample_rate
self.assertEqual(data, data_ref, atol=atol, rtol=rtol)
def assert_wav(self, dtype, sample_rate, num_channels, normalize, duration):
"""`sox_io_backend.load` can load wav format correctly.
Wav data loaded with sox_io backend should match those with scipy
"""
path = self.get_temp_path('reference.wav')
data = get_wav_data(dtype, num_channels, normalize=normalize, num_frames=duration * sample_rate)
save_wav(path, data, sample_rate)
expected = load_wav(path, normalize=normalize)[0]
data, sr = sox_io_backend.load(path, normalize=normalize)
assert sr == sample_rate
self.assertEqual(data, expected)
@skipIfNoExec('sox')
@skipIfNoSox
class TestLoad(LoadTestBase):
"""Test the correctness of `sox_io_backend.load` for various formats"""
@parameterized.expand(list(itertools.product(
['float32', 'int32', 'int16', 'uint8'],
[8000, 16000],
[1, 2],
[False, True],
)), name_func=name_func)
def test_wav(self, dtype, sample_rate, num_channels, normalize):
"""`sox_io_backend.load` can load wav format correctly."""
self.assert_wav(dtype, sample_rate, num_channels, normalize, duration=1)
@parameterized.expand(list(itertools.product(
[8000, 16000],
[1, 2],
[False, True],
)), name_func=name_func)
def test_24bit_wav(self, sample_rate, num_channels, normalize):
"""`sox_io_backend.load` can load 24bit wav format correctly. Corectly casts it to ``int32`` tensor dtype."""
self.assert_format("wav", sample_rate, num_channels, bit_depth=24, normalize=normalize, duration=1)
@parameterized.expand(list(itertools.product(
['int16'],
[16000],
[2],
[False],
)), name_func=name_func)
def test_wav_large(self, dtype, sample_rate, num_channels, normalize):
"""`sox_io_backend.load` can load large wav file correctly."""
two_hours = 2 * 60 * 60
self.assert_wav(dtype, sample_rate, num_channels, normalize, two_hours)
@parameterized.expand(list(itertools.product(
['float32', 'int32', 'int16', 'uint8'],
[4, 8, 16, 32],
)), name_func=name_func)
def test_multiple_channels(self, dtype, num_channels):
"""`sox_io_backend.load` can load wav file with more than 2 channels."""
sample_rate = 8000
normalize = False
self.assert_wav(dtype, sample_rate, num_channels, normalize, duration=1)
@parameterized.expand(list(itertools.product(
[8000, 16000, 44100],
[1, 2],
[96, 128, 160, 192, 224, 256, 320],
)), name_func=name_func)
def test_mp3(self, sample_rate, num_channels, bit_rate):
"""`sox_io_backend.load` can load mp3 format correctly."""
self.assert_format("mp3", sample_rate, num_channels, compression=bit_rate, duration=1, atol=5e-05)
@parameterized.expand(list(itertools.product(
[16000],
[2],
[128],
)), name_func=name_func)
def test_mp3_large(self, sample_rate, num_channels, bit_rate):
"""`sox_io_backend.load` can load large mp3 file correctly."""
two_hours = 2 * 60 * 60
self.assert_format("mp3", sample_rate, num_channels, compression=bit_rate, duration=two_hours, atol=5e-05)
@parameterized.expand(list(itertools.product(
[8000, 16000],
[1, 2],
list(range(9)),
)), name_func=name_func)
def test_flac(self, sample_rate, num_channels, compression_level):
"""`sox_io_backend.load` can load flac format correctly."""
self.assert_format("flac", sample_rate, num_channels, compression=compression_level, bit_depth=16, duration=1)
@parameterized.expand(list(itertools.product(
[16000],
[2],
[0],
)), name_func=name_func)
def test_flac_large(self, sample_rate, num_channels, compression_level):
"""`sox_io_backend.load` can load large flac file correctly."""
two_hours = 2 * 60 * 60
self.assert_format(
"flac", sample_rate, num_channels, compression=compression_level, bit_depth=16, duration=two_hours)
@parameterized.expand(list(itertools.product(
[8000, 16000],
[1, 2],
[-1, 0, 1, 2, 3, 3.6, 5, 10],
)), name_func=name_func)
def test_vorbis(self, sample_rate, num_channels, quality_level):
"""`sox_io_backend.load` can load vorbis format correctly."""
self.assert_format("vorbis", sample_rate, num_channels, compression=quality_level, bit_depth=16, duration=1)
@parameterized.expand(list(itertools.product(
[16000],
[2],
[10],
)), name_func=name_func)
def test_vorbis_large(self, sample_rate, num_channels, quality_level):
"""`sox_io_backend.load` can load large vorbis file correctly."""
two_hours = 2 * 60 * 60
self.assert_format(
"vorbis", sample_rate, num_channels, compression=quality_level, bit_depth=16, duration=two_hours)
@parameterized.expand(list(itertools.product(
['96k'],
[1, 2],
[0, 5, 10],
)), name_func=name_func)
def test_opus(self, bitrate, num_channels, compression_level):
"""`sox_io_backend.load` can load opus file correctly."""
ops_path = get_asset_path('io', f'{bitrate}_{compression_level}_{num_channels}ch.opus')
wav_path = self.get_temp_path(f'{bitrate}_{compression_level}_{num_channels}ch.opus.wav')
sox_utils.convert_audio_file(ops_path, wav_path)
expected, sample_rate = load_wav(wav_path)
found, sr = sox_io_backend.load(ops_path)
assert sample_rate == sr
self.assertEqual(expected, found)
@parameterized.expand(list(itertools.product(
[8000, 16000],
[1, 2],
)), name_func=name_func)
def test_sphere(self, sample_rate, num_channels):
"""`sox_io_backend.load` can load sph format correctly."""
self.assert_format("sph", sample_rate, num_channels, bit_depth=32, duration=1)
@parameterized.expand(list(itertools.product(
['float32', 'int32', 'int16'],
[8000, 16000],
[1, 2],
[False, True],
)), name_func=name_func)
def test_amb(self, dtype, sample_rate, num_channels, normalize):
"""`sox_io_backend.load` can load amb format correctly."""
bit_depth = sox_utils.get_bit_depth(dtype)
encoding = sox_utils.get_encoding(dtype)
self.assert_format(
"amb", sample_rate, num_channels, bit_depth=bit_depth, duration=1, encoding=encoding, normalize=normalize)
def test_amr_nb(self):
"""`sox_io_backend.load` can load amr_nb format correctly."""
self.assert_format("amr-nb", sample_rate=8000, num_channels=1, bit_depth=32, duration=1)
@skipIfNoExec('sox')
@skipIfNoSox
class TestLoadParams(TempDirMixin, PytorchTestCase):
"""Test the correctness of frame parameters of `sox_io_backend.load`"""
original = None
path = None
def setUp(self):
super().setUp()
sample_rate = 8000
self.original = get_wav_data('float32', num_channels=2)
self.path = self.get_temp_path('test.wav')
save_wav(self.path, self.original, sample_rate)
@parameterized.expand(list(itertools.product(
[0, 1, 10, 100, 1000],
[-1, 1, 10, 100, 1000],
)), name_func=name_func)
def test_frame(self, frame_offset, num_frames):
"""num_frames and frame_offset correctly specify the region of data"""
found, _ = sox_io_backend.load(self.path, frame_offset, num_frames)
frame_end = None if num_frames == -1 else frame_offset + num_frames
self.assertEqual(found, self.original[:, frame_offset:frame_end])
@parameterized.expand([(True, ), (False, )], name_func=name_func)
def test_channels_first(self, channels_first):
"""channels_first swaps axes"""
found, _ = sox_io_backend.load(self.path, channels_first=channels_first)
expected = self.original if channels_first else self.original.transpose(1, 0)
self.assertEqual(found, expected)
@skipIfNoSox
class TestLoadWithoutExtension(PytorchTestCase):
def test_mp3(self):
"""Providing format allows to read mp3 without extension
libsox does not check header for mp3
https://github.com/pytorch/audio/issues/1040
The file was generated with the following command
ffmpeg -f lavfi -i "sine=frequency=1000:duration=5" -ar 16000 -f mp3 test_noext
"""
path = get_asset_path("mp3_without_ext")
_, sr = sox_io_backend.load(path, format="mp3")
assert sr == 16000
class CloggedFileObj:
def __init__(self, fileobj):
self.fileobj = fileobj
self.buffer = b''
def read(self, n):
if not self.buffer:
self.buffer += self.fileobj.read(n)
ret = self.buffer[:2]
self.buffer = self.buffer[2:]
return ret
@skipIfNoSox
@skipIfNoExec('sox')
class TestFileObject(TempDirMixin, PytorchTestCase):
"""
In this test suite, the result of file-like object input is compared against file path input,
because `load` function is rigrously tested for file path inputs to match libsox's result,
"""
@parameterized.expand([
('wav', None),
('mp3', 128),
('mp3', 320),
('flac', 0),
('flac', 5),
('flac', 8),
('vorbis', -1),
('vorbis', 10),
('amb', None),
])
def test_fileobj(self, ext, compression):
"""Loading audio via file object returns the same result as via file path."""
sample_rate = 16000
format_ = ext if ext in ['mp3'] else None
path = self.get_temp_path(f'test.{ext}')
sox_utils.gen_audio_file(
path, sample_rate, num_channels=2,
compression=compression)
expected, _ = sox_io_backend.load(path)
with open(path, 'rb') as fileobj:
found, sr = sox_io_backend.load(fileobj, format=format_)
assert sr == sample_rate
self.assertEqual(expected, found)
@parameterized.expand([
('wav', None),
('mp3', 128),
('mp3', 320),
('flac', 0),
('flac', 5),
('flac', 8),
('vorbis', -1),
('vorbis', 10),
('amb', None),
])
def test_bytesio(self, ext, compression):
"""Loading audio via BytesIO object returns the same result as via file path."""
sample_rate = 16000
format_ = ext if ext in ['mp3'] else None
path = self.get_temp_path(f'test.{ext}')
sox_utils.gen_audio_file(
path, sample_rate, num_channels=2,
compression=compression)
expected, _ = sox_io_backend.load(path)
with open(path, 'rb') as file_:
fileobj = io.BytesIO(file_.read())
found, sr = sox_io_backend.load(fileobj, format=format_)
assert sr == sample_rate
self.assertEqual(expected, found)
@parameterized.expand([
('wav', None),
('mp3', 128),
('mp3', 320),
('flac', 0),
('flac', 5),
('flac', 8),
('vorbis', -1),
('vorbis', 10),
('amb', None),
])
def test_bytesio_clogged(self, ext, compression):
"""Loading audio via clogged file object returns the same result as via file path.
This test case validates the case where fileobject returns shorter bytes than requeted.
"""
sample_rate = 16000
format_ = ext if ext in ['mp3'] else None
path = self.get_temp_path(f'test.{ext}')
sox_utils.gen_audio_file(
path, sample_rate, num_channels=2,
compression=compression)
expected, _ = sox_io_backend.load(path)
with open(path, 'rb') as file_:
fileobj = CloggedFileObj(io.BytesIO(file_.read()))
found, sr = sox_io_backend.load(fileobj, format=format_)
assert sr == sample_rate
self.assertEqual(expected, found)
@parameterized.expand([
('wav', None),
('mp3', 128),
('mp3', 320),
('flac', 0),
('flac', 5),
('flac', 8),
('vorbis', -1),
('vorbis', 10),
('amb', None),
])
def test_bytesio_tiny(self, ext, compression):
"""Loading very small audio via file object returns the same result as via file path.
"""
sample_rate = 16000
format_ = ext if ext in ['mp3'] else None
path = self.get_temp_path(f'test.{ext}')
sox_utils.gen_audio_file(
path, sample_rate, num_channels=2,
compression=compression, duration=1 / 1600)
expected, _ = sox_io_backend.load(path)
with open(path, 'rb') as file_:
fileobj = io.BytesIO(file_.read())
found, sr = sox_io_backend.load(fileobj, format=format_)
assert sr == sample_rate
self.assertEqual(expected, found)
@parameterized.expand([
('wav', None),
('mp3', 128),
('mp3', 320),
('flac', 0),
('flac', 5),
('flac', 8),
('vorbis', -1),
('vorbis', 10),
('amb', None),
])
def test_tarfile(self, ext, compression):
"""Loading compressed audio via file-like object returns the same result as via file path."""
sample_rate = 16000
format_ = ext if ext in ['mp3'] else None
audio_file = f'test.{ext}'
audio_path = self.get_temp_path(audio_file)
archive_path = self.get_temp_path('archive.tar.gz')
sox_utils.gen_audio_file(
audio_path, sample_rate, num_channels=2,
compression=compression)
expected, _ = sox_io_backend.load(audio_path)
with tarfile.TarFile(archive_path, 'w') as tarobj:
tarobj.add(audio_path, arcname=audio_file)
with tarfile.TarFile(archive_path, 'r') as tarobj:
fileobj = tarobj.extractfile(audio_file)
found, sr = sox_io_backend.load(fileobj, format=format_)
assert sr == sample_rate
self.assertEqual(expected, found)
@skipIfNoSox
@skipIfNoExec('sox')
@skipIfNoModule("requests")
class TestFileObjectHttp(HttpServerMixin, PytorchTestCase):
@parameterized.expand([
('wav', None),
('mp3', 128),
('mp3', 320),
('flac', 0),
('flac', 5),
('flac', 8),
('vorbis', -1),
('vorbis', 10),
('amb', None),
])
def test_requests(self, ext, compression):
sample_rate = 16000
format_ = ext if ext in ['mp3'] else None
audio_file = f'test.{ext}'
audio_path = self.get_temp_path(audio_file)
sox_utils.gen_audio_file(
audio_path, sample_rate, num_channels=2, compression=compression)
expected, _ = sox_io_backend.load(audio_path)
url = self.get_url(audio_file)
with requests.get(url, stream=True) as resp:
found, sr = sox_io_backend.load(resp.raw, format=format_)
assert sr == sample_rate
self.assertEqual(expected, found)
@parameterized.expand(list(itertools.product(
[0, 1, 10, 100, 1000],
[-1, 1, 10, 100, 1000],
)), name_func=name_func)
def test_frame(self, frame_offset, num_frames):
"""num_frames and frame_offset correctly specify the region of data"""
sample_rate = 8000
audio_file = 'test.wav'
audio_path = self.get_temp_path(audio_file)
original = get_wav_data('float32', num_channels=2)
save_wav(audio_path, original, sample_rate)
frame_end = None if num_frames == -1 else frame_offset + num_frames
expected = original[:, frame_offset:frame_end]
url = self.get_url(audio_file)
with requests.get(url, stream=True) as resp:
found, sr = sox_io_backend.load(resp.raw, frame_offset, num_frames)
assert sr == sample_rate
self.assertEqual(expected, found)
@skipIfNoSox
class TestLoadNoSuchFile(PytorchTestCase):
def test_load_fail(self):
"""
When attempted to load a non-existing file, error message must contain the file path.
"""
path = "non_existing_audio.wav"
with self.assertRaisesRegex(RuntimeError, "^Error loading audio file: failed to open file {0}$".format(path)):
sox_io_backend.load(path)
import itertools
from torchaudio.backend import sox_io_backend
from parameterized import parameterized
from torchaudio_unittest.common_utils import (
TempDirMixin,
PytorchTestCase,
skipIfNoExec,
skipIfNoSox,
get_wav_data,
)
from .common import (
name_func,
get_enc_params,
)
@skipIfNoExec('sox')
@skipIfNoSox
class TestRoundTripIO(TempDirMixin, PytorchTestCase):
"""save/load round trip should not degrade data for lossless formats"""
@parameterized.expand(list(itertools.product(
['float32', 'int32', 'int16', 'uint8'],
[8000, 16000],
[1, 2],
)), name_func=name_func)
def test_wav(self, dtype, sample_rate, num_channels):
"""save/load round trip should not degrade data for wav formats"""
original = get_wav_data(dtype, num_channels, normalize=False)
enc, bps = get_enc_params(dtype)
data = original
for i in range(10):
path = self.get_temp_path(f'{i}.wav')
sox_io_backend.save(path, data, sample_rate, encoding=enc, bits_per_sample=bps)
data, sr = sox_io_backend.load(path, normalize=False)
assert sr == sample_rate
self.assertEqual(original, data)
@parameterized.expand(list(itertools.product(
[8000, 16000],
[1, 2],
list(range(9)),
)), name_func=name_func)
def test_flac(self, sample_rate, num_channels, compression_level):
"""save/load round trip should not degrade data for flac formats"""
original = get_wav_data('float32', num_channels)
data = original
for i in range(10):
path = self.get_temp_path(f'{i}.flac')
sox_io_backend.save(path, data, sample_rate, compression=compression_level)
data, sr = sox_io_backend.load(path)
assert sr == sample_rate
self.assertEqual(original, data)
import io
import os
import unittest
import torch
from torchaudio.backend import sox_io_backend
from parameterized import parameterized
from torchaudio_unittest.common_utils import (
TempDirMixin,
TorchaudioTestCase,
PytorchTestCase,
skipIfNoExec,
skipIfNoSox,
get_wav_data,
load_wav,
save_wav,
sox_utils,
nested_params,
)
from .common import (
name_func,
get_enc_params,
)
def _get_sox_encoding(encoding):
encodings = {
'PCM_F': 'floating-point',
'PCM_S': 'signed-integer',
'PCM_U': 'unsigned-integer',
'ULAW': 'u-law',
'ALAW': 'a-law',
}
return encodings.get(encoding)
class SaveTestBase(TempDirMixin, TorchaudioTestCase):
def assert_save_consistency(
self,
format: str,
*,
compression: float = None,
encoding: str = None,
bits_per_sample: int = None,
sample_rate: float = 8000,
num_channels: int = 2,
num_frames: float = 3 * 8000,
src_dtype: str = 'int32',
test_mode: str = "path",
):
"""`save` function produces file that is comparable with `sox` command
To compare that the file produced by `save` function agains the file produced by
the equivalent `sox` command, we need to load both files.
But there are many formats that cannot be opened with common Python modules (like
SciPy).
So we use `sox` command to prepare the original data and convert the saved files
into a format that SciPy can read (PCM wav).
The following diagram illustrates this process. The difference is 2.1. and 3.1.
This assumes that
- loading data with SciPy preserves the data well.
- converting the resulting files into WAV format with `sox` preserve the data well.
x
| 1. Generate source wav file with SciPy
|
v
-------------- wav ----------------
| |
| 2.1. load with scipy | 3.1. Convert to the target
| then save it into the target | format depth with sox
| format with torchaudio |
v v
target format target format
| |
| 2.2. Convert to wav with sox | 3.2. Convert to wav with sox
| |
v v
wav wav
| |
| 2.3. load with scipy | 3.3. load with scipy
| |
v v
tensor -------> compare <--------- tensor
"""
cmp_encoding = 'floating-point'
cmp_bit_depth = 32
src_path = self.get_temp_path('1.source.wav')
tgt_path = self.get_temp_path(f'2.1.torchaudio.{format}')
tst_path = self.get_temp_path('2.2.result.wav')
sox_path = self.get_temp_path(f'3.1.sox.{format}')
ref_path = self.get_temp_path('3.2.ref.wav')
# 1. Generate original wav
data = get_wav_data(src_dtype, num_channels, normalize=False, num_frames=num_frames)
save_wav(src_path, data, sample_rate)
# 2.1. Convert the original wav to target format with torchaudio
data = load_wav(src_path, normalize=False)[0]
if test_mode == "path":
sox_io_backend.save(
tgt_path, data, sample_rate,
compression=compression, encoding=encoding, bits_per_sample=bits_per_sample)
elif test_mode == "fileobj":
with open(tgt_path, 'bw') as file_:
sox_io_backend.save(
file_, data, sample_rate,
format=format, compression=compression,
encoding=encoding, bits_per_sample=bits_per_sample)
elif test_mode == "bytesio":
file_ = io.BytesIO()
sox_io_backend.save(
file_, data, sample_rate,
format=format, compression=compression,
encoding=encoding, bits_per_sample=bits_per_sample)
file_.seek(0)
with open(tgt_path, 'bw') as f:
f.write(file_.read())
else:
raise ValueError(f"Unexpected test mode: {test_mode}")
# 2.2. Convert the target format to wav with sox
sox_utils.convert_audio_file(
tgt_path, tst_path, encoding=cmp_encoding, bit_depth=cmp_bit_depth)
# 2.3. Load with SciPy
found = load_wav(tst_path, normalize=False)[0]
# 3.1. Convert the original wav to target format with sox
sox_encoding = _get_sox_encoding(encoding)
sox_utils.convert_audio_file(
src_path, sox_path,
compression=compression, encoding=sox_encoding, bit_depth=bits_per_sample)
# 3.2. Convert the target format to wav with sox
sox_utils.convert_audio_file(
sox_path, ref_path, encoding=cmp_encoding, bit_depth=cmp_bit_depth)
# 3.3. Load with SciPy
expected = load_wav(ref_path, normalize=False)[0]
self.assertEqual(found, expected)
@skipIfNoExec('sox')
@skipIfNoSox
class SaveTest(SaveTestBase):
@nested_params(
["path", "fileobj", "bytesio"],
[
('PCM_U', 8),
('PCM_S', 16),
('PCM_S', 32),
('PCM_F', 32),
('PCM_F', 64),
('ULAW', 8),
('ALAW', 8),
],
)
def test_save_wav(self, test_mode, enc_params):
encoding, bits_per_sample = enc_params
self.assert_save_consistency(
"wav", encoding=encoding, bits_per_sample=bits_per_sample, test_mode=test_mode)
@nested_params(
["path", "fileobj", "bytesio"],
[
('float32', ),
('int32', ),
('int16', ),
('uint8', ),
],
)
def test_save_wav_dtype(self, test_mode, params):
dtype, = params
self.assert_save_consistency(
"wav", src_dtype=dtype, test_mode=test_mode)
@nested_params(
["path", "fileobj", "bytesio"],
[
None,
-4.2,
-0.2,
0,
0.2,
96,
128,
160,
192,
224,
256,
320,
],
)
def test_save_mp3(self, test_mode, bit_rate):
if test_mode in ["fileobj", "bytesio"]:
if bit_rate is not None and bit_rate < 1:
raise unittest.SkipTest(
"mp3 format with variable bit rate is known to "
"not yield the exact same result as sox command.")
self.assert_save_consistency(
"mp3", compression=bit_rate, test_mode=test_mode)
@nested_params(
["path", "fileobj", "bytesio"],
[8, 16, 24],
[
None,
0,
1,
2,
3,
4,
5,
6,
7,
8,
],
)
def test_save_flac(self, test_mode, bits_per_sample, compression_level):
self.assert_save_consistency(
"flac", compression=compression_level,
bits_per_sample=bits_per_sample, test_mode=test_mode)
@nested_params(
["path", "fileobj", "bytesio"],
)
def test_save_htk(self, test_mode):
self.assert_save_consistency("htk", test_mode=test_mode, num_channels=1)
@nested_params(
["path", "fileobj", "bytesio"],
[
None,
-1,
0,
1,
2,
3,
3.6,
5,
10,
],
)
def test_save_vorbis(self, test_mode, quality_level):
self.assert_save_consistency(
"vorbis", compression=quality_level, test_mode=test_mode)
@nested_params(
["path", "fileobj", "bytesio"],
[
('PCM_S', 8, ),
('PCM_S', 16, ),
('PCM_S', 24, ),
('PCM_S', 32, ),
('ULAW', 8),
('ALAW', 8),
('ALAW', 16),
('ALAW', 24),
('ALAW', 32),
],
)
def test_save_sphere(self, test_mode, enc_params):
encoding, bits_per_sample = enc_params
self.assert_save_consistency(
"sph", encoding=encoding, bits_per_sample=bits_per_sample, test_mode=test_mode)
@nested_params(
["path", "fileobj", "bytesio"],
[
('PCM_U', 8, ),
('PCM_S', 16, ),
('PCM_S', 24, ),
('PCM_S', 32, ),
('PCM_F', 32, ),
('PCM_F', 64, ),
('ULAW', 8, ),
('ALAW', 8, ),
],
)
def test_save_amb(self, test_mode, enc_params):
encoding, bits_per_sample = enc_params
self.assert_save_consistency(
"amb", encoding=encoding, bits_per_sample=bits_per_sample, test_mode=test_mode)
@nested_params(
["path", "fileobj", "bytesio"],
[
None,
0,
1,
2,
3,
4,
5,
6,
7,
],
)
def test_save_amr_nb(self, test_mode, bit_rate):
self.assert_save_consistency(
"amr-nb", compression=bit_rate, num_channels=1, test_mode=test_mode)
@nested_params(
["path", "fileobj", "bytesio"],
)
def test_save_gsm(self, test_mode):
self.assert_save_consistency(
"gsm", num_channels=1, test_mode=test_mode)
with self.assertRaises(
RuntimeError, msg="gsm format only supports single channel audio."):
self.assert_save_consistency(
"gsm", num_channels=2, test_mode=test_mode)
with self.assertRaises(
RuntimeError, msg="gsm format only supports a sampling rate of 8kHz."):
self.assert_save_consistency(
"gsm", sample_rate=16000, test_mode=test_mode)
@parameterized.expand([
("wav", "PCM_S", 16),
("mp3", ),
("flac", ),
("vorbis", ),
("sph", "PCM_S", 16),
("amr-nb", ),
("amb", "PCM_S", 16),
], name_func=name_func)
def test_save_large(self, format, encoding=None, bits_per_sample=None):
"""`sox_io_backend.save` can save large files."""
sample_rate = 8000
one_hour = 60 * 60 * sample_rate
self.assert_save_consistency(
format, num_channels=1, sample_rate=8000, num_frames=one_hour,
encoding=encoding, bits_per_sample=bits_per_sample)
@parameterized.expand([
(32, ),
(64, ),
(128, ),
(256, ),
], name_func=name_func)
def test_save_multi_channels(self, num_channels):
"""`sox_io_backend.save` can save audio with many channels"""
self.assert_save_consistency(
"wav", encoding="PCM_S", bits_per_sample=16,
num_channels=num_channels)
@skipIfNoExec('sox')
@skipIfNoSox
class TestSaveParams(TempDirMixin, PytorchTestCase):
"""Test the correctness of optional parameters of `sox_io_backend.save`"""
@parameterized.expand([(True, ), (False, )], name_func=name_func)
def test_save_channels_first(self, channels_first):
"""channels_first swaps axes"""
path = self.get_temp_path('data.wav')
data = get_wav_data(
'int16', 2, channels_first=channels_first, normalize=False)
sox_io_backend.save(
path, data, 8000, channels_first=channels_first)
found = load_wav(path, normalize=False)[0]
expected = data if channels_first else data.transpose(1, 0)
self.assertEqual(found, expected)
@parameterized.expand([
'float32', 'int32', 'int16', 'uint8'
], name_func=name_func)
def test_save_noncontiguous(self, dtype):
"""Noncontiguous tensors are saved correctly"""
path = self.get_temp_path('data.wav')
enc, bps = get_enc_params(dtype)
expected = get_wav_data(dtype, 4, normalize=False)[::2, ::2]
assert not expected.is_contiguous()
sox_io_backend.save(
path, expected, 8000, encoding=enc, bits_per_sample=bps)
found = load_wav(path, normalize=False)[0]
self.assertEqual(found, expected)
@parameterized.expand([
'float32', 'int32', 'int16', 'uint8',
])
def test_save_tensor_preserve(self, dtype):
"""save function should not alter Tensor"""
path = self.get_temp_path('data.wav')
expected = get_wav_data(dtype, 4, normalize=False)[::2, ::2]
data = expected.clone()
sox_io_backend.save(path, data, 8000)
self.assertEqual(data, expected)
@skipIfNoSox
class TestSaveNonExistingDirectory(PytorchTestCase):
def test_save_fail(self):
"""
When attempted to save into a non-existing dir, error message must contain the file path.
"""
path = os.path.join("non_existing_directory", "foo.wav")
with self.assertRaisesRegex(RuntimeError, "^Error saving audio file: failed to open file {0}$".format(path)):
sox_io_backend.save(path, torch.zeros(1, 1), 8000)
import io
import itertools
import unittest
from torchaudio.utils import sox_utils
from torchaudio.backend import sox_io_backend
from torchaudio._internal.module_utils import is_sox_available
from parameterized import parameterized
from torchaudio_unittest.common_utils import (
TempDirMixin,
TorchaudioTestCase,
skipIfNoSox,
get_wav_data,
)
from .common import name_func
skipIfNoMP3 = unittest.skipIf(
not is_sox_available() or
'mp3' not in sox_utils.list_read_formats() or
'mp3' not in sox_utils.list_write_formats(),
'"sox_io" backend does not support MP3')
@skipIfNoSox
class SmokeTest(TempDirMixin, TorchaudioTestCase):
"""Run smoke test on various audio format
The purpose of this test suite is to verify that sox_io_backend functionalities do not exhibit
abnormal behaviors.
This test suite should be able to run without any additional tools (such as sox command),
however without such tools, the correctness of each function cannot be verified.
"""
def run_smoke_test(self, ext, sample_rate, num_channels, *, compression=None, dtype='float32'):
duration = 1
num_frames = sample_rate * duration
path = self.get_temp_path(f'test.{ext}')
original = get_wav_data(dtype, num_channels, normalize=False, num_frames=num_frames)
# 1. run save
sox_io_backend.save(path, original, sample_rate, compression=compression)
# 2. run info
info = sox_io_backend.info(path)
assert info.sample_rate == sample_rate
assert info.num_channels == num_channels
# 3. run load
loaded, sr = sox_io_backend.load(path, normalize=False)
assert sr == sample_rate
assert loaded.shape[0] == num_channels
@parameterized.expand(list(itertools.product(
['float32', 'int32', 'int16', 'uint8'],
[8000, 16000],
[1, 2],
)), name_func=name_func)
def test_wav(self, dtype, sample_rate, num_channels):
"""Run smoke test on wav format"""
self.run_smoke_test('wav', sample_rate, num_channels, dtype=dtype)
@parameterized.expand(list(itertools.product(
[8000, 16000],
[1, 2],
[-4.2, -0.2, 0, 0.2, 96, 128, 160, 192, 224, 256, 320],
)))
@skipIfNoMP3
def test_mp3(self, sample_rate, num_channels, bit_rate):
"""Run smoke test on mp3 format"""
self.run_smoke_test('mp3', sample_rate, num_channels, compression=bit_rate)
@parameterized.expand(list(itertools.product(
[8000, 16000],
[1, 2],
[-1, 0, 1, 2, 3, 3.6, 5, 10],
)))
def test_vorbis(self, sample_rate, num_channels, quality_level):
"""Run smoke test on vorbis format"""
self.run_smoke_test('vorbis', sample_rate, num_channels, compression=quality_level)
@parameterized.expand(list(itertools.product(
[8000, 16000],
[1, 2],
list(range(9)),
)), name_func=name_func)
def test_flac(self, sample_rate, num_channels, compression_level):
"""Run smoke test on flac format"""
self.run_smoke_test('flac', sample_rate, num_channels, compression=compression_level)
@skipIfNoSox
class SmokeTestFileObj(TorchaudioTestCase):
"""Run smoke test on various audio format
The purpose of this test suite is to verify that sox_io_backend functionalities do not exhibit
abnormal behaviors.
This test suite should be able to run without any additional tools (such as sox command),
however without such tools, the correctness of each function cannot be verified.
"""
def run_smoke_test(self, ext, sample_rate, num_channels, *, compression=None, dtype='float32'):
duration = 1
num_frames = sample_rate * duration
original = get_wav_data(dtype, num_channels, normalize=False, num_frames=num_frames)
fileobj = io.BytesIO()
# 1. run save
sox_io_backend.save(fileobj, original, sample_rate, compression=compression, format=ext)
# 2. run info
fileobj.seek(0)
info = sox_io_backend.info(fileobj, format=ext)
assert info.sample_rate == sample_rate
assert info.num_channels == num_channels
# 3. run load
fileobj.seek(0)
loaded, sr = sox_io_backend.load(fileobj, normalize=False, format=ext)
assert sr == sample_rate
assert loaded.shape[0] == num_channels
@parameterized.expand(list(itertools.product(
['float32', 'int32', 'int16', 'uint8'],
[8000, 16000],
[1, 2],
)), name_func=name_func)
def test_wav(self, dtype, sample_rate, num_channels):
"""Run smoke test on wav format"""
self.run_smoke_test('wav', sample_rate, num_channels, dtype=dtype)
@parameterized.expand(list(itertools.product(
[8000, 16000],
[1, 2],
[-4.2, -0.2, 0, 0.2, 96, 128, 160, 192, 224, 256, 320],
)))
@skipIfNoMP3
def test_mp3(self, sample_rate, num_channels, bit_rate):
"""Run smoke test on mp3 format"""
self.run_smoke_test('mp3', sample_rate, num_channels, compression=bit_rate)
@parameterized.expand(list(itertools.product(
[8000, 16000],
[1, 2],
[-1, 0, 1, 2, 3, 3.6, 5, 10],
)))
def test_vorbis(self, sample_rate, num_channels, quality_level):
"""Run smoke test on vorbis format"""
self.run_smoke_test('vorbis', sample_rate, num_channels, compression=quality_level)
@parameterized.expand(list(itertools.product(
[8000, 16000],
[1, 2],
list(range(9)),
)), name_func=name_func)
def test_flac(self, sample_rate, num_channels, compression_level):
"""Run smoke test on flac format"""
self.run_smoke_test('flac', sample_rate, num_channels, compression=compression_level)
import itertools
from typing import Optional
import torch
import torchaudio
from parameterized import parameterized
from torchaudio_unittest.common_utils import (
TempDirMixin,
TorchaudioTestCase,
skipIfNoExec,
skipIfNoSox,
get_wav_data,
save_wav,
load_wav,
sox_utils,
torch_script,
)
from .common import (
name_func,
get_enc_params,
)
def py_info_func(filepath: str) -> torchaudio.backend.sox_io_backend.AudioMetaData:
return torchaudio.info(filepath)
def py_load_func(filepath: str, normalize: bool, channels_first: bool):
return torchaudio.load(
filepath, normalize=normalize, channels_first=channels_first)
def py_save_func(
filepath: str,
tensor: torch.Tensor,
sample_rate: int,
channels_first: bool = True,
compression: Optional[float] = None,
encoding: Optional[str] = None,
bits_per_sample: Optional[int] = None,
):
torchaudio.save(
filepath, tensor, sample_rate, channels_first,
compression, None, encoding, bits_per_sample)
@skipIfNoExec('sox')
@skipIfNoSox
class SoxIO(TempDirMixin, TorchaudioTestCase):
"""TorchScript-ability Test suite for `sox_io_backend`"""
backend = 'sox_io'
@parameterized.expand(list(itertools.product(
['float32', 'int32', 'int16', 'uint8'],
[8000, 16000],
[1, 2],
)), name_func=name_func)
def test_info_wav(self, dtype, sample_rate, num_channels):
"""`sox_io_backend.info` is torchscript-able and returns the same result"""
audio_path = self.get_temp_path(f'{dtype}_{sample_rate}_{num_channels}.wav')
data = get_wav_data(dtype, num_channels, normalize=False, num_frames=1 * sample_rate)
save_wav(audio_path, data, sample_rate)
ts_info_func = torch_script(py_info_func)
py_info = py_info_func(audio_path)
ts_info = ts_info_func(audio_path)
assert py_info.sample_rate == ts_info.sample_rate
assert py_info.num_frames == ts_info.num_frames
assert py_info.num_channels == ts_info.num_channels
@parameterized.expand(list(itertools.product(
['float32', 'int32', 'int16', 'uint8'],
[8000, 16000],
[1, 2],
[False, True],
[False, True],
)), name_func=name_func)
def test_load_wav(self, dtype, sample_rate, num_channels, normalize, channels_first):
"""`sox_io_backend.load` is torchscript-able and returns the same result"""
audio_path = self.get_temp_path(f'test_load_{dtype}_{sample_rate}_{num_channels}_{normalize}.wav')
data = get_wav_data(dtype, num_channels, normalize=False, num_frames=1 * sample_rate)
save_wav(audio_path, data, sample_rate)
ts_load_func = torch_script(py_load_func)
py_data, py_sr = py_load_func(
audio_path, normalize=normalize, channels_first=channels_first)
ts_data, ts_sr = ts_load_func(
audio_path, normalize=normalize, channels_first=channels_first)
self.assertEqual(py_sr, ts_sr)
self.assertEqual(py_data, ts_data)
@parameterized.expand(list(itertools.product(
['float32', 'int32', 'int16', 'uint8'],
[8000, 16000],
[1, 2],
)), name_func=name_func)
def test_save_wav(self, dtype, sample_rate, num_channels):
ts_save_func = torch_script(py_save_func)
expected = get_wav_data(dtype, num_channels, normalize=False)
py_path = self.get_temp_path(f'test_save_py_{dtype}_{sample_rate}_{num_channels}.wav')
ts_path = self.get_temp_path(f'test_save_ts_{dtype}_{sample_rate}_{num_channels}.wav')
enc, bps = get_enc_params(dtype)
py_save_func(py_path, expected, sample_rate, True, None, enc, bps)
ts_save_func(ts_path, expected, sample_rate, True, None, enc, bps)
py_data, py_sr = load_wav(py_path, normalize=False)
ts_data, ts_sr = load_wav(ts_path, normalize=False)
self.assertEqual(sample_rate, py_sr)
self.assertEqual(sample_rate, ts_sr)
self.assertEqual(expected, py_data)
self.assertEqual(expected, ts_data)
@parameterized.expand(list(itertools.product(
[8000, 16000],
[1, 2],
list(range(9)),
)), name_func=name_func)
def test_save_flac(self, sample_rate, num_channels, compression_level):
ts_save_func = torch_script(py_save_func)
expected = get_wav_data('float32', num_channels)
py_path = self.get_temp_path(f'test_save_py_{sample_rate}_{num_channels}_{compression_level}.flac')
ts_path = self.get_temp_path(f'test_save_ts_{sample_rate}_{num_channels}_{compression_level}.flac')
py_save_func(py_path, expected, sample_rate, True, compression_level, None, None)
ts_save_func(ts_path, expected, sample_rate, True, compression_level, None, None)
# converting to 32 bit because flac file has 24 bit depth which scipy cannot handle.
py_path_wav = f'{py_path}.wav'
ts_path_wav = f'{ts_path}.wav'
sox_utils.convert_audio_file(py_path, py_path_wav, bit_depth=32)
sox_utils.convert_audio_file(ts_path, ts_path_wav, bit_depth=32)
py_data, py_sr = load_wav(py_path_wav, normalize=True)
ts_data, ts_sr = load_wav(ts_path_wav, normalize=True)
self.assertEqual(sample_rate, py_sr)
self.assertEqual(sample_rate, ts_sr)
self.assertEqual(expected, py_data)
self.assertEqual(expected, ts_data)
import torchaudio
from torchaudio_unittest import common_utils
class BackendSwitchMixin:
"""Test set/get_audio_backend works"""
backend = None
backend_module = None
def test_switch(self):
torchaudio.set_audio_backend(self.backend)
if self.backend is None:
assert torchaudio.get_audio_backend() is None
else:
assert torchaudio.get_audio_backend() == self.backend
assert torchaudio.load == self.backend_module.load
assert torchaudio.save == self.backend_module.save
assert torchaudio.info == self.backend_module.info
class TestBackendSwitch_NoBackend(BackendSwitchMixin, common_utils.TorchaudioTestCase):
backend = None
backend_module = torchaudio.backend.no_backend
@common_utils.skipIfNoSox
class TestBackendSwitch_SoXIO(BackendSwitchMixin, common_utils.TorchaudioTestCase):
backend = 'sox_io'
backend_module = torchaudio.backend.sox_io_backend
@common_utils.skipIfNoModule('soundfile')
class TestBackendSwitch_soundfile(BackendSwitchMixin, common_utils.TorchaudioTestCase):
backend = 'soundfile'
backend_module = torchaudio.backend.soundfile_backend
from .data_utils import (
get_asset_path,
get_whitenoise,
get_sinusoid,
get_spectrogram,
)
from .backend_utils import (
set_audio_backend,
)
from .case_utils import (
TempDirMixin,
HttpServerMixin,
TestBaseMixin,
PytorchTestCase,
TorchaudioTestCase,
skipIfNoCuda,
skipIfNoExec,
skipIfNoModule,
skipIfNoKaldi,
skipIfNoSox,
skipIfRocm,
skipIfNoQengine,
)
from .wav_utils import (
get_wav_data,
normalize_wav,
load_wav,
save_wav,
)
from .parameterized_utils import (
load_params,
nested_params
)
from .func_utils import torch_script
__all__ = [
'get_asset_path',
'get_whitenoise',
'get_sinusoid',
'get_spectrogram',
'set_audio_backend',
'TempDirMixin',
'HttpServerMixin',
'TestBaseMixin',
'PytorchTestCase',
'TorchaudioTestCase',
'skipIfNoCuda',
'skipIfNoExec',
'skipIfNoModule',
'skipIfNoKaldi',
'skipIfNoSox',
'skipIfNoSoxBackend',
'skipIfRocm',
'skipIfNoQengine',
'get_wav_data',
'normalize_wav',
'load_wav',
'save_wav',
'load_params',
'nested_params',
'torch_script',
]
import unittest
import torchaudio
def set_audio_backend(backend):
"""Allow additional backend value, 'default'"""
backends = torchaudio.list_audio_backends()
if backend == 'soundfile':
be = 'soundfile'
elif backend == 'default':
if 'sox_io' in backends:
be = 'sox_io'
elif 'soundfile' in backends:
be = 'soundfile'
else:
raise unittest.SkipTest('No default backend available')
else:
be = backend
torchaudio.set_audio_backend(be)
import shutil
import os.path
import subprocess
import tempfile
import time
import unittest
import torch
from torch.testing._internal.common_utils import TestCase as PytorchTestCase
from torchaudio._internal.module_utils import (
is_module_available,
is_sox_available,
is_kaldi_available
)
from .backend_utils import set_audio_backend
class TempDirMixin:
"""Mixin to provide easy access to temp dir"""
temp_dir_ = None
@classmethod
def get_base_temp_dir(cls):
# If TORCHAUDIO_TEST_TEMP_DIR is set, use it instead of temporary directory.
# this is handy for debugging.
key = 'TORCHAUDIO_TEST_TEMP_DIR'
if key in os.environ:
return os.environ[key]
if cls.temp_dir_ is None:
cls.temp_dir_ = tempfile.TemporaryDirectory()
return cls.temp_dir_.name
@classmethod
def tearDownClass(cls):
super().tearDownClass()
if cls.temp_dir_ is not None:
cls.temp_dir_.cleanup()
cls.temp_dir_ = None
def get_temp_path(self, *paths):
temp_dir = os.path.join(self.get_base_temp_dir(), self.id())
path = os.path.join(temp_dir, *paths)
os.makedirs(os.path.dirname(path), exist_ok=True)
return path
class HttpServerMixin(TempDirMixin):
"""Mixin that serves temporary directory as web server
This class creates temporary directory and serve the directory as HTTP service.
The server is up through the execution of all the test suite defined under the subclass.
"""
_proc = None
_port = 8000
@classmethod
def setUpClass(cls):
super().setUpClass()
cls._proc = subprocess.Popen(
['python', '-m', 'http.server', f'{cls._port}'],
cwd=cls.get_base_temp_dir(),
stderr=subprocess.DEVNULL) # Disable server-side error log because it is confusing
time.sleep(2.0)
@classmethod
def tearDownClass(cls):
super().tearDownClass()
cls._proc.kill()
def get_url(self, *route):
return f'http://localhost:{self._port}/{self.id()}/{"/".join(route)}'
class TestBaseMixin:
"""Mixin to provide consistent way to define device/dtype/backend aware TestCase"""
dtype = None
device = None
backend = None
def setUp(self):
super().setUp()
set_audio_backend(self.backend)
@property
def complex_dtype(self):
if self.dtype in ['float32', 'float', torch.float, torch.float32]:
return torch.cfloat
if self.dtype in ['float64', 'double', torch.double, torch.float64]:
return torch.cdouble
raise ValueError(f'No corresponding complex dtype for {self.dtype}')
class TorchaudioTestCase(TestBaseMixin, PytorchTestCase):
pass
def skipIfNoExec(cmd):
return unittest.skipIf(shutil.which(cmd) is None, f'`{cmd}` is not available')
def skipIfNoModule(module, display_name=None):
display_name = display_name or module
return unittest.skipIf(not is_module_available(module), f'"{display_name}" is not available')
def skipIfNoCuda(test_item):
if torch.cuda.is_available():
return test_item
force_cuda_test = os.environ.get('TORCHAUDIO_TEST_FORCE_CUDA', '0')
if force_cuda_test not in ['0', '1']:
raise ValueError('"TORCHAUDIO_TEST_FORCE_CUDA" must be either "0" or "1".')
if force_cuda_test == '1':
raise RuntimeError('"TORCHAUDIO_TEST_FORCE_CUDA" is set but CUDA is not available.')
return unittest.skip('CUDA is not available.')(test_item)
skipIfNoSox = unittest.skipIf(not is_sox_available(), reason='Sox not available')
skipIfNoKaldi = unittest.skipIf(not is_kaldi_available(), reason='Kaldi not available')
skipIfRocm = unittest.skipIf(os.getenv('TORCHAUDIO_TEST_WITH_ROCM', '0') == '1',
reason="test doesn't currently work on the ROCm stack")
skipIfNoQengine = unittest.skipIf(
'fbgemm' not in torch.backends.quantized.supported_engines,
reason="`fbgemm` is not available."
)
import os.path
from typing import Union, Optional
import torch
_TEST_DIR_PATH = os.path.realpath(
os.path.join(os.path.dirname(__file__), '..'))
def get_asset_path(*paths):
"""Return full path of a test asset"""
return os.path.join(_TEST_DIR_PATH, 'assets', *paths)
def convert_tensor_encoding(
tensor: torch.tensor,
dtype: torch.dtype,
):
"""Convert input tensor with values between -1 and 1 to integer encoding
Args:
tensor: input tensor, assumed between -1 and 1
dtype: desired output tensor dtype
Returns:
Tensor: shape of (n_channels, sample_rate * duration)
"""
if dtype == torch.int32:
tensor *= (tensor > 0) * 2147483647 + (tensor < 0) * 2147483648
if dtype == torch.int16:
tensor *= (tensor > 0) * 32767 + (tensor < 0) * 32768
if dtype == torch.uint8:
tensor *= (tensor > 0) * 127 + (tensor < 0) * 128
tensor += 128
tensor = tensor.to(dtype)
return tensor
def get_whitenoise(
*,
sample_rate: int = 16000,
duration: float = 1, # seconds
n_channels: int = 1,
seed: int = 0,
dtype: Union[str, torch.dtype] = "float32",
device: Union[str, torch.device] = "cpu",
channels_first=True,
scale_factor: float = 1,
):
"""Generate pseudo audio data with whitenoise
Args:
sample_rate: Sampling rate
duration: Length of the resulting Tensor in seconds.
n_channels: Number of channels
seed: Seed value used for random number generation.
Note that this function does not modify global random generator state.
dtype: Torch dtype
device: device
channels_first: whether first dimension is n_channels
scale_factor: scale the Tensor before clamping and quantization
Returns:
Tensor: shape of (n_channels, sample_rate * duration)
"""
if isinstance(dtype, str):
dtype = getattr(torch, dtype)
if dtype not in [torch.float64, torch.float32, torch.int32, torch.int16, torch.uint8]:
raise NotImplementedError(f'dtype {dtype} is not supported.')
# According to the doc, folking rng on all CUDA devices is slow when there are many CUDA devices,
# so we only fork on CPU, generate values and move the data to the given device
with torch.random.fork_rng([]):
torch.random.manual_seed(seed)
tensor = torch.randn([n_channels, int(sample_rate * duration)],
dtype=torch.float32, device='cpu')
tensor /= 2.0
tensor *= scale_factor
tensor.clamp_(-1.0, 1.0)
if not channels_first:
tensor = tensor.t()
tensor = tensor.to(device)
return convert_tensor_encoding(tensor, dtype)
def get_sinusoid(
*,
frequency: float = 300,
sample_rate: int = 16000,
duration: float = 1, # seconds
n_channels: int = 1,
dtype: Union[str, torch.dtype] = "float32",
device: Union[str, torch.device] = "cpu",
channels_first: bool = True,
):
"""Generate pseudo audio data with sine wave.
Args:
frequency: Frequency of sine wave
sample_rate: Sampling rate
duration: Length of the resulting Tensor in seconds.
n_channels: Number of channels
dtype: Torch dtype
device: device
Returns:
Tensor: shape of (n_channels, sample_rate * duration)
"""
if isinstance(dtype, str):
dtype = getattr(torch, dtype)
pie2 = 2 * 3.141592653589793
end = pie2 * frequency * duration
theta = torch.linspace(0, end, int(sample_rate * duration), dtype=torch.float32, device=device)
tensor = torch.sin(theta, out=None).repeat([n_channels, 1])
if not channels_first:
tensor = tensor.t()
return convert_tensor_encoding(tensor, dtype)
def get_spectrogram(
waveform,
*,
n_fft: int = 2048,
hop_length: Optional[int] = None,
win_length: Optional[int] = None,
window: Optional[torch.Tensor] = None,
center: bool = True,
pad_mode: str = 'reflect',
power: Optional[float] = None,
):
"""Generate a spectrogram of the given Tensor
Args:
n_fft: The number of FFT bins.
hop_length: Stride for sliding window. default: ``n_fft // 4``.
win_length: The size of window frame and STFT filter. default: ``n_fft``.
winwdow: Window function. default: Hann window
center: Pad the input sequence if True. See ``torch.stft`` for the detail.
pad_mode: Padding method used when center is True. Default: "reflect".
power: If ``None``, raw spectrogram with complex values are returned,
otherwise the norm of the spectrogram is returned.
"""
hop_length = hop_length or n_fft // 4
win_length = win_length or n_fft
window = torch.hann_window(win_length, device=waveform.device) if window is None else window
spec = torch.stft(
waveform,
n_fft=n_fft,
hop_length=hop_length,
win_length=win_length,
center=center,
window=window,
pad_mode=pad_mode,
return_complex=True)
if power is not None:
spec = spec.abs() ** power
return spec
import io
import torch
def torch_script(obj):
"""TorchScript the given function or Module"""
buffer = io.BytesIO()
torch.jit.save(torch.jit.script(obj), buffer)
buffer.seek(0)
return torch.jit.load(buffer)
import subprocess
import torch
def convert_args(**kwargs):
args = []
for key, value in kwargs.items():
if key == 'sample_rate':
key = 'sample_frequency'
key = '--' + key.replace('_', '-')
value = str(value).lower() if value in [True, False] else str(value)
args.append('%s=%s' % (key, value))
return args
def run_kaldi(command, input_type, input_value):
"""Run provided Kaldi command, pass a tensor and get the resulting tensor
Args:
command (list of str): The command with arguments
input_type (str): 'ark' or 'scp'
input_value (Tensor for 'ark', string for 'scp'): The input to pass.
Must be a path to an audio file for 'scp'.
"""
import kaldi_io
key = 'foo'
process = subprocess.Popen(command, stdin=subprocess.PIPE, stdout=subprocess.PIPE)
if input_type == 'ark':
kaldi_io.write_mat(process.stdin, input_value.cpu().numpy(), key=key)
elif input_type == 'scp':
process.stdin.write(f'{key} {input_value}'.encode('utf8'))
else:
raise NotImplementedError('Unexpected type')
process.stdin.close()
result = dict(kaldi_io.read_mat_ark(process.stdout))['foo']
return torch.from_numpy(result.copy()) # copy supresses some torch warning
import json
from itertools import product
from parameterized import param, parameterized
from .data_utils import get_asset_path
def load_params(*paths):
with open(get_asset_path(*paths), 'r') as file:
return [param(json.loads(line)) for line in file]
def _name_func(func, _, params):
strs = []
for arg in params.args:
if isinstance(arg, tuple):
strs.append("_".join(str(a) for a in arg))
else:
strs.append(str(arg))
# sanitize the test name
name = "_".join(strs).replace(".", "_")
return f'{func.__name__}_{name}'
def nested_params(*params_set):
"""Generate the cartesian product of the given list of parameters.
Args:
params_set (list of parameters): Parameters. When using ``parameterized.param`` class,
all the parameters have to be specified with the class, only using kwargs.
"""
flatten = [p for params in params_set for p in params]
# Parameters to be nested are given as list of plain objects
if all(not isinstance(p, param) for p in flatten):
args = list(product(*params_set))
return parameterized.expand(args, name_func=_name_func)
# Parameters to be nested are given as list of `parameterized.param`
if not all(isinstance(p, param) for p in flatten):
raise TypeError(
"When using ``parameterized.param``, "
"all the parameters have to be of the ``param`` type.")
if any(p.args for p in flatten):
raise ValueError(
"When using ``parameterized.param``, "
"all the parameters have to be provided as keyword argument."
)
args = [param()]
for params in params_set:
args = [param(**x.kwargs, **y.kwargs) for x in args for y in params]
return parameterized.expand(args)
from typing import Optional
import numpy as np
import torch
def psd_numpy(
X: np.array,
mask: Optional[np.array],
multi_mask: bool = False,
normalize: bool = True,
eps: float = 1e-15
) -> np.array:
X_conj = np.conj(X)
psd_X = np.einsum("...cft,...eft->...ftce", X, X_conj)
if mask is not None:
if multi_mask:
mask = mask.mean(axis=-3)
if normalize:
mask = mask / (mask.sum(axis=-1, keepdims=True) + eps)
psd = psd_X * mask[..., None, None]
else:
psd = psd_X
psd = psd.sum(axis=-3)
return torch.tensor(psd, dtype=torch.cdouble)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment