Commit b799fcd6 authored by hwangjeff's avatar hwangjeff Committed by Facebook GitHub Bot
Browse files

Introduce I/O backend dispatcher (#3015)

Summary:
Adds I/O backend dispatcher that routes I/O requests to FFmpeg, SoX, or Soundfile backend, per library availability. It allows users to specify a backend mapped to a media library, i.e. one of `["ffmpeg", "sox", "soundfile"]`, to use via keyword argument, with FFmpeg being the default. Environment variable `TORCHAUDIO_USE_BACKEND_DISPATCHER` gates enablement of the dispatcher; specifically, if `TORCHAUDIO_USE_BACKEND_DISPATCHER` is explicitly set to `1`, importing TorchAudio makes it accessible via `torchaudio.info`, `torchaudio.load`, and `torchaudio.save`.

Pull Request resolved: https://github.com/pytorch/audio/pull/3015

Reviewed By: mthrok

Differential Revision: D43258649

Pulled By: hwangjeff

fbshipit-source-id: 8f12e4e56b9fa3f0814dd3fed3e1783ab23a53a1
parent 9db4bdf1
import io
from unittest.mock import patch
import torch
from parameterized import parameterized
from torchaudio._backend.utils import (
FFmpegBackend,
get_info_func,
get_load_func,
get_save_func,
SoundfileBackend,
SoXBackend,
)
from torchaudio_unittest.common_utils import PytorchTestCase
class DispatcherTest(PytorchTestCase):
@parameterized.expand(
[
# FFmpeg backend is used when no backend is specified.
({"ffmpeg": FFmpegBackend, "sox": SoXBackend, "soundfile": SoundfileBackend}, FFmpegBackend),
# SoX backend is used when no backend is specified and FFmpeg is not available.
({"sox": SoXBackend, "soundfile": SoundfileBackend}, SoXBackend),
]
)
def test_info(self, available_backends, expected_backend):
filename = "test.wav"
format = "wav"
with patch("torchaudio._backend.utils.get_available_backends", return_value=available_backends), patch(
f"torchaudio._backend.utils.{expected_backend.__name__}.info"
) as mock_info:
get_info_func()(filename, format=format)
mock_info.assert_called_once_with(filename, format, 4096)
@parameterized.expand(
[
# FFmpeg backend is used when no backend is specified.
({"ffmpeg": FFmpegBackend, "sox": SoXBackend, "soundfile": SoundfileBackend}, FFmpegBackend),
# Soundfile backend is used when no backend is specified, FFmpeg is not available,
# and input is file-like object (i.e. SoX is properly skipped over).
({"sox": SoXBackend, "soundfile": SoundfileBackend}, SoundfileBackend),
]
)
def test_info_fileobj(self, available_backends, expected_backend):
f = io.BytesIO()
format = "wav"
buffer_size = 8192
with patch("torchaudio._backend.utils.get_available_backends", return_value=available_backends), patch(
f"torchaudio._backend.utils.{expected_backend.__name__}.info"
) as mock_info:
get_info_func()(f, format=format, buffer_size=buffer_size)
mock_info.assert_called_once_with(f, format, buffer_size)
@parameterized.expand(
[
# FFmpeg backend is used when no backend is specified.
({"ffmpeg": FFmpegBackend, "sox": SoXBackend, "soundfile": SoundfileBackend}, FFmpegBackend),
# SoX backend is used when no backend is specified and FFmpeg is not available.
({"sox": SoXBackend, "soundfile": SoundfileBackend}, SoXBackend),
]
)
def test_load(self, available_backends, expected_backend):
filename = "test.wav"
format = "wav"
with patch("torchaudio._backend.utils.get_available_backends", return_value=available_backends), patch(
f"torchaudio._backend.utils.{expected_backend.__name__}.load"
) as mock_load:
get_load_func()(filename, format=format)
mock_load.assert_called_once_with(filename, 0, -1, True, True, format, 4096)
@parameterized.expand(
[
# FFmpeg backend is used when no backend is specified.
({"ffmpeg": FFmpegBackend, "sox": SoXBackend, "soundfile": SoundfileBackend}, FFmpegBackend),
# Soundfile backend is used when no backend is specified, FFmpeg is not available,
# and input is file-like object (i.e. SoX is properly skipped over).
({"sox": SoXBackend, "soundfile": SoundfileBackend}, SoundfileBackend),
]
)
def test_load_fileobj(self, available_backends, expected_backend):
f = io.BytesIO()
format = "wav"
buffer_size = 8192
with patch("torchaudio._backend.utils.get_available_backends", return_value=available_backends), patch(
f"torchaudio._backend.utils.{expected_backend.__name__}.load"
) as mock_load:
get_load_func()(f, format=format, buffer_size=buffer_size)
mock_load.assert_called_once_with(f, 0, -1, True, True, format, buffer_size)
@parameterized.expand(
[
# FFmpeg backend is used when no backend is specified.
({"ffmpeg": FFmpegBackend, "sox": SoXBackend, "soundfile": SoundfileBackend}, FFmpegBackend),
# SoX backend is used when no backend is specified and FFmpeg is not available.
({"sox": SoXBackend, "soundfile": SoundfileBackend}, SoXBackend),
]
)
def test_save(self, available_backends, expected_backend):
src = torch.zeros((2, 10))
filename = "test.wav"
format = "wav"
sample_rate = 16000
with patch("torchaudio._backend.utils.get_available_backends", return_value=available_backends), patch(
f"torchaudio._backend.utils.{expected_backend.__name__}.save"
) as mock_save:
get_save_func()(filename, src, sample_rate, format=format)
mock_save.assert_called_once_with(filename, src, sample_rate, True, format, None, None, 4096)
@parameterized.expand(
[
# FFmpeg backend is used when no backend is specified.
({"ffmpeg": FFmpegBackend, "sox": SoXBackend, "soundfile": SoundfileBackend}, FFmpegBackend),
# Soundfile backend is used when no backend is specified, FFmpeg is not available,
# and input is file-like object (i.e. SoX is properly skipped over).
({"sox": SoXBackend, "soundfile": SoundfileBackend}, SoundfileBackend),
]
)
def test_save_fileobj(self, available_backends, expected_backend):
src = torch.zeros((2, 10))
f = io.BytesIO()
format = "wav"
buffer_size = 8192
sample_rate = 16000
with patch("torchaudio._backend.utils.get_available_backends", return_value=available_backends), patch(
f"torchaudio._backend.utils.{expected_backend.__name__}.save"
) as mock_save:
get_save_func()(f, src, sample_rate, format=format, buffer_size=buffer_size)
mock_save.assert_called_once_with(f, src, sample_rate, True, format, None, None, buffer_size)
import io
import os
import subprocess
import sys
from functools import partial
import torch
from parameterized import parameterized
from torchaudio._backend.utils import get_save_func
from torchaudio.io._compat import _get_encoder, _get_encoder_format
from torchaudio_unittest.backend.dispatcher.sox.common import get_enc_params, name_func
from torchaudio_unittest.common_utils import (
get_wav_data,
load_wav,
nested_params,
PytorchTestCase,
save_wav,
skipIfNoExec,
skipIfNoFFmpeg,
TempDirMixin,
TorchaudioTestCase,
)
def _convert_audio_file(src_path, dst_path, format=None, acodec=None):
command = ["ffmpeg", "-i", src_path, "-strict", "-2"]
if format:
command += ["-sample_fmt", format]
if acodec:
command += ["-acodec", acodec]
command += [dst_path]
print(" ".join(command), file=sys.stderr)
subprocess.run(command, check=True)
class SaveTestBase(TempDirMixin, TorchaudioTestCase):
_save = partial(get_save_func(), backend="ffmpeg")
def assert_save_consistency(
self,
format: str,
*,
encoding: str = None,
bits_per_sample: int = None,
sample_rate: float = 8000,
num_channels: int = 2,
num_frames: float = 3 * 8000,
src_dtype: str = "int32",
test_mode: str = "path",
):
"""`save` function produces file that is comparable with `ffmpeg` command
To compare that the file produced by `save` function agains the file produced by
the equivalent `ffmpeg` command, we need to load both files.
But there are many formats that cannot be opened with common Python modules (like
SciPy).
So we use `ffmpeg` command to prepare the original data and convert the saved files
into a format that SciPy can read (PCM wav).
The following diagram illustrates this process. The difference is 2.1. and 3.1.
This assumes that
- loading data with SciPy preserves the data well.
- converting the resulting files into WAV format with `ffmpeg` preserve the data well.
x
| 1. Generate source wav file with SciPy
|
v
-------------- wav ----------------
| |
| 2.1. load with scipy | 3.1. Convert to the target
| then save it into the target | format depth with ffmpeg
| format with torchaudio |
v v
target format target format
| |
| 2.2. Convert to wav with ffmpeg | 3.2. Convert to wav with ffmpeg
| |
v v
wav wav
| |
| 2.3. load with scipy | 3.3. load with scipy
| |
v v
tensor -------> compare <--------- tensor
"""
src_path = self.get_temp_path("1.source.wav")
tgt_path = self.get_temp_path(f"2.1.torchaudio.{format}")
tst_path = self.get_temp_path("2.2.result.wav")
sox_path = self.get_temp_path(f"3.1.ffmpeg.{format}")
ref_path = self.get_temp_path("3.2.ref.wav")
# 1. Generate original wav
data = get_wav_data(src_dtype, num_channels, normalize=False, num_frames=num_frames)
save_wav(src_path, data, sample_rate)
# 2.1. Convert the original wav to target format with torchaudio
data = load_wav(src_path, normalize=False)[0]
if test_mode == "path":
self._save(tgt_path, data, sample_rate, encoding=encoding, bits_per_sample=bits_per_sample)
elif test_mode == "fileobj":
with open(tgt_path, "bw") as file_:
self._save(
file_,
data,
sample_rate,
format=format,
encoding=encoding,
bits_per_sample=bits_per_sample,
)
elif test_mode == "bytesio":
file_ = io.BytesIO()
self._save(
file_,
data,
sample_rate,
format=format,
encoding=encoding,
bits_per_sample=bits_per_sample,
)
file_.seek(0)
with open(tgt_path, "bw") as f:
f.write(file_.read())
else:
raise ValueError(f"Unexpected test mode: {test_mode}")
# 2.2. Convert the target format to wav with ffmpeg
_convert_audio_file(tgt_path, tst_path, acodec="pcm_f32le")
# 2.3. Load with SciPy
found = load_wav(tst_path, normalize=False)[0]
# 3.1. Convert the original wav to target format with ffmpeg
acodec = _get_encoder(data.dtype, format, encoding, bits_per_sample)
sample_fmt = _get_encoder_format(format, bits_per_sample)
_convert_audio_file(src_path, sox_path, acodec=acodec, format=sample_fmt)
# 3.2. Convert the target format to wav with ffmpeg
_convert_audio_file(sox_path, ref_path, acodec="pcm_f32le")
# 3.3. Load with SciPy
expected = load_wav(ref_path, normalize=False)[0]
self.assertEqual(found, expected)
@skipIfNoExec("ffmpeg")
@skipIfNoFFmpeg
class SaveTest(SaveTestBase):
@nested_params(
["path", "fileobj", "bytesio"],
[
("PCM_U", 8),
("PCM_S", 16),
("PCM_S", 32),
("PCM_F", 32),
("PCM_F", 64),
("ULAW", 8),
("ALAW", 8),
],
)
def test_save_wav(self, test_mode, enc_params):
encoding, bits_per_sample = enc_params
self.assert_save_consistency("wav", encoding=encoding, bits_per_sample=bits_per_sample, test_mode=test_mode)
@nested_params(
["path", "fileobj", "bytesio"],
[
("float32",),
("int32",),
("int16",),
("uint8",),
],
)
def test_save_wav_dtype(self, test_mode, params):
(dtype,) = params
self.assert_save_consistency("wav", src_dtype=dtype, test_mode=test_mode)
@nested_params(
["path", "fileobj", "bytesio"],
# NOTE: Supported sample formats: s16 s32 (24 bits)
# [8, 16, 24],
[16, 24],
)
def test_save_flac(self, test_mode, bits_per_sample):
# -acodec flac -sample_fmt s16
# 24 bits needs to be mapped to s32
self.assert_save_consistency("flac", bits_per_sample=bits_per_sample, test_mode=test_mode)
# @nested_params(
# ["path", "fileobj", "bytesio"],
# )
# # NOTE: FFmpeg: Unable to find a suitable output format
# def test_save_htk(self, test_mode):
# self.assert_save_consistency("htk", test_mode=test_mode, num_channels=1)
@nested_params(
["path", "fileobj", "bytesio"],
)
def test_save_vorbis(self, test_mode):
# NOTE: ffmpeg doesn't recognize extension "vorbis", so we use "ogg"
# self.assert_save_consistency("vorbis", test_mode=test_mode)
self.assert_save_consistency("ogg", test_mode=test_mode)
# @nested_params(
# ["path", "fileobj", "bytesio"],
# [
# (
# "PCM_S",
# 8,
# ),
# (
# "PCM_S",
# 16,
# ),
# (
# "PCM_S",
# 24,
# ),
# (
# "PCM_S",
# 32,
# ),
# ("ULAW", 8),
# ("ALAW", 8),
# ("ALAW", 16),
# ("ALAW", 24),
# ("ALAW", 32),
# ],
# )
# NOTE: FFmpeg doesn't support encoding sphere files.
# def test_save_sphere(self, test_mode, enc_params):
# encoding, bits_per_sample = enc_params
# self.assert_save_consistency("sph", encoding=encoding, bits_per_sample=bits_per_sample, test_mode=test_mode)
# @nested_params(
# ["path", "fileobj", "bytesio"],
# [
# (
# "PCM_U",
# 8,
# ),
# (
# "PCM_S",
# 16,
# ),
# (
# "PCM_S",
# 24,
# ),
# (
# "PCM_S",
# 32,
# ),
# (
# "PCM_F",
# 32,
# ),
# (
# "PCM_F",
# 64,
# ),
# (
# "ULAW",
# 8,
# ),
# (
# "ALAW",
# 8,
# ),
# ],
# )
# NOTE: FFmpeg doesn't support amb.
# def test_save_amb(self, test_mode, enc_params):
# encoding, bits_per_sample = enc_params
# self.assert_save_consistency("amb", encoding=encoding, bits_per_sample=bits_per_sample, test_mode=test_mode)
# @nested_params(
# ["path", "fileobj", "bytesio"],
# )
# # NOTE: FFmpeg: Unable to find a suitable output format
# def test_save_amr_nb(self, test_mode):
# self.assert_save_consistency("amr-nb", num_channels=1, test_mode=test_mode)
# @nested_params(
# ["path", "fileobj", "bytesio"],
# )
# # NOTE: FFmpeg: RuntimeError: Unexpected codec: gsm
# def test_save_gsm(self, test_mode):
# self.assert_save_consistency("gsm", num_channels=1, test_mode=test_mode)
# with self.assertRaises(RuntimeError, msg="gsm format only supports single channel audio."):
# self.assert_save_consistency("gsm", num_channels=2, test_mode=test_mode)
# with self.assertRaises(RuntimeError, msg="gsm format only supports a sampling rate of 8kHz."):
# self.assert_save_consistency("gsm", sample_rate=16000, test_mode=test_mode)
@parameterized.expand(
[
("wav", "PCM_S", 16),
("flac",),
("ogg",),
# ("sph", "PCM_S", 16),
# ("amr-nb",),
# ("amb", "PCM_S", 16),
],
name_func=name_func,
)
def test_save_large(self, format, encoding=None, bits_per_sample=None):
"""`self._save` can save large files."""
sample_rate = 8000
one_hour = 60 * 60 * sample_rate
self.assert_save_consistency(
format,
# NOTE: for ogg, ffmpeg only supports >= 2 channels
num_channels=2,
sample_rate=8000,
num_frames=one_hour,
encoding=encoding,
bits_per_sample=bits_per_sample,
)
@parameterized.expand(
[
(16,),
# NOTE: FFmpeg doesn't support more than 16 channels.
# (32,),
# (64,),
# (128,),
# (256,),
],
name_func=name_func,
)
def test_save_multi_channels(self, num_channels):
"""`self._save` can save audio with many channels"""
self.assert_save_consistency("wav", encoding="PCM_S", bits_per_sample=16, num_channels=num_channels)
@skipIfNoFFmpeg
class TestSaveParams(TempDirMixin, PytorchTestCase):
"""Test the correctness of optional parameters of `self._save`"""
_save = partial(get_save_func(), backend="ffmpeg")
@parameterized.expand([(True,), (False,)], name_func=name_func)
def test_save_channels_first(self, channels_first):
"""channels_first swaps axes"""
path = self.get_temp_path("data.wav")
data = get_wav_data("int16", 2, channels_first=channels_first, normalize=False)
self._save(path, data, 8000, channels_first=channels_first)
found = load_wav(path, normalize=False)[0]
expected = data if channels_first else data.transpose(1, 0)
self.assertEqual(found, expected)
@parameterized.expand(["float32", "int32", "int16", "uint8"], name_func=name_func)
def test_save_noncontiguous(self, dtype):
"""Noncontiguous tensors are saved correctly"""
path = self.get_temp_path("data.wav")
enc, bps = get_enc_params(dtype)
expected = get_wav_data(dtype, 4, normalize=False)[::2, ::2]
assert not expected.is_contiguous()
self._save(path, expected, 8000, encoding=enc, bits_per_sample=bps)
found = load_wav(path, normalize=False)[0]
self.assertEqual(found, expected)
@parameterized.expand(
[
"float32",
"int32",
"int16",
"uint8",
]
)
def test_save_tensor_preserve(self, dtype):
"""save function should not alter Tensor"""
path = self.get_temp_path("data.wav")
expected = get_wav_data(dtype, 4, normalize=False)[::2, ::2]
data = expected.clone()
self._save(path, data, 8000)
self.assertEqual(data, expected)
@skipIfNoFFmpeg
class TestSaveNonExistingDirectory(PytorchTestCase):
_save = partial(get_save_func(), backend="ffmpeg")
def test_save_fail(self):
"""
When attempted to save into a non-existing dir, error message must contain the file path.
"""
path = os.path.join("non_existing_directory", "foo.wav")
with self.assertRaisesRegex(RuntimeError, path):
self._save(path, torch.zeros(1, 1), 8000)
import io
from torchaudio._backend.utils import get_info_func, get_load_func, get_save_func
from torchaudio_unittest.common_utils import get_wav_data, PytorchTestCase, skipIfNoFFmpeg, TempDirMixin
@skipIfNoFFmpeg
class SmokeTest(TempDirMixin, PytorchTestCase):
def run_smoke_test(self, ext, sample_rate, num_channels, *, dtype="float32"):
duration = 1
num_frames = sample_rate * duration
path = self.get_temp_path(f"test.{ext}")
original = get_wav_data(dtype, num_channels, normalize=False, num_frames=num_frames)
get_save_func()(path, original, sample_rate)
info = get_info_func()(path)
assert info.sample_rate == sample_rate
assert info.num_channels == num_channels
loaded, sr = get_load_func()(path, normalize=False)
assert sr == sample_rate
assert loaded.shape[0] == num_channels
def test_wav(self):
dtype = "float32"
sample_rate = 16000
num_channels = 2
self.run_smoke_test("wav", sample_rate, num_channels, dtype=dtype)
@skipIfNoFFmpeg
class SmokeTestFileObj(TempDirMixin, PytorchTestCase):
def run_smoke_test(self, ext, sample_rate, num_channels, *, dtype="float32"):
buffer_size = 8192
duration = 1
num_frames = sample_rate * duration
fileobj = io.BytesIO()
original = get_wav_data(dtype, num_channels, normalize=False, num_frames=num_frames)
get_save_func()(fileobj, original, sample_rate, format=ext, buffer_size=buffer_size)
fileobj.seek(0)
info = get_info_func()(fileobj, format=ext, buffer_size=buffer_size)
assert info.sample_rate == sample_rate
assert info.num_channels == num_channels
fileobj.seek(0)
loaded, sr = get_load_func()(fileobj, normalize=False, format=ext, buffer_size=buffer_size)
assert sr == sample_rate
assert loaded.shape[0] == num_channels
def test_wav(self):
dtype = "float32"
sample_rate = 16000
num_channels = 2
self.run_smoke_test("wav", sample_rate, num_channels, dtype=dtype)
import itertools
from unittest import skipIf
from parameterized import parameterized
from torchaudio._internal.module_utils import is_module_available
def name_func(func, _, params):
return f'{func.__name__}_{"_".join(str(arg) for arg in params.args)}'
def dtype2subtype(dtype):
return {
"float64": "DOUBLE",
"float32": "FLOAT",
"int32": "PCM_32",
"int16": "PCM_16",
"uint8": "PCM_U8",
"int8": "PCM_S8",
}[dtype]
def skipIfFormatNotSupported(fmt):
fmts = []
if is_module_available("soundfile"):
import soundfile
fmts = soundfile.available_formats()
return skipIf(fmt not in fmts, f'"{fmt}" is not supported by soundfile')
return skipIf(True, '"soundfile" not available.')
def parameterize(*params):
return parameterized.expand(list(itertools.product(*params)), name_func=name_func)
def fetch_wav_subtype(dtype, encoding, bits_per_sample):
subtype = {
(None, None): dtype2subtype(dtype),
(None, 8): "PCM_U8",
("PCM_U", None): "PCM_U8",
("PCM_U", 8): "PCM_U8",
("PCM_S", None): "PCM_32",
("PCM_S", 16): "PCM_16",
("PCM_S", 32): "PCM_32",
("PCM_F", None): "FLOAT",
("PCM_F", 32): "FLOAT",
("PCM_F", 64): "DOUBLE",
("ULAW", None): "ULAW",
("ULAW", 8): "ULAW",
("ALAW", None): "ALAW",
("ALAW", 8): "ALAW",
}.get((encoding, bits_per_sample))
if subtype:
return subtype
raise ValueError(f"wav does not support ({encoding}, {bits_per_sample}).")
import tarfile
import warnings
from functools import partial
from unittest.mock import patch
import torch
from torchaudio._backend.utils import get_info_func
from torchaudio._internal import module_utils as _mod_utils
from torchaudio_unittest.backend.common import get_bits_per_sample, get_encoding
from torchaudio_unittest.common_utils import (
get_wav_data,
nested_params,
PytorchTestCase,
save_wav,
skipIfNoModule,
TempDirMixin,
)
from .common import parameterize, skipIfFormatNotSupported
if _mod_utils.is_module_available("soundfile"):
import soundfile
@skipIfNoModule("soundfile")
class TestInfo(TempDirMixin, PytorchTestCase):
_info = partial(get_info_func(), backend="soundfile")
@parameterize(
["float32", "int32", "int16", "uint8"],
[8000, 16000],
[1, 2],
)
def test_wav(self, dtype, sample_rate, num_channels):
"""`self._info` can check wav file correctly"""
duration = 1
path = self.get_temp_path("data.wav")
data = get_wav_data(dtype, num_channels, normalize=False, num_frames=duration * sample_rate)
save_wav(path, data, sample_rate)
info = self._info(path)
assert info.sample_rate == sample_rate
assert info.num_frames == sample_rate * duration
assert info.num_channels == num_channels
assert info.bits_per_sample == get_bits_per_sample("wav", dtype)
assert info.encoding == get_encoding("wav", dtype)
@parameterize([8000, 16000], [1, 2])
@skipIfFormatNotSupported("FLAC")
def test_flac(self, sample_rate, num_channels):
"""`self._info` can check flac file correctly"""
duration = 1
num_frames = sample_rate * duration
data = torch.randn(num_frames, num_channels).numpy()
path = self.get_temp_path("data.flac")
soundfile.write(path, data, sample_rate)
info = self._info(path)
assert info.sample_rate == sample_rate
assert info.num_frames == num_frames
assert info.num_channels == num_channels
assert info.bits_per_sample == 16
assert info.encoding == "FLAC"
@parameterize([8000, 16000], [1, 2])
@skipIfFormatNotSupported("OGG")
def test_ogg(self, sample_rate, num_channels):
"""`self._info` can check ogg file correctly"""
duration = 1
num_frames = sample_rate * duration
data = torch.randn(num_frames, num_channels).numpy()
path = self.get_temp_path("data.ogg")
soundfile.write(path, data, sample_rate)
info = self._info(path)
assert info.sample_rate == sample_rate
assert info.num_frames == sample_rate * duration
assert info.num_channels == num_channels
assert info.bits_per_sample == 0
assert info.encoding == "VORBIS"
@nested_params(
[8000, 16000],
[1, 2],
[("PCM_24", 24), ("PCM_32", 32)],
)
@skipIfFormatNotSupported("NIST")
def test_sphere(self, sample_rate, num_channels, subtype_and_bit_depth):
"""`self._info` can check sph file correctly"""
duration = 1
num_frames = sample_rate * duration
data = torch.randn(num_frames, num_channels).numpy()
path = self.get_temp_path("data.nist")
subtype, bits_per_sample = subtype_and_bit_depth
soundfile.write(path, data, sample_rate, subtype=subtype)
info = self._info(path)
assert info.sample_rate == sample_rate
assert info.num_frames == sample_rate * duration
assert info.num_channels == num_channels
assert info.bits_per_sample == bits_per_sample
assert info.encoding == "PCM_S"
def test_unknown_subtype_warning(self):
"""self._info issues a warning when the subtype is unknown
This will happen if a new subtype is supported in SoundFile: the _SUBTYPE_TO_BITS_PER_SAMPLE
dict should be updated.
"""
def _mock_info_func(_):
class MockSoundFileInfo:
samplerate = 8000
frames = 356
channels = 2
subtype = "UNSEEN_SUBTYPE"
format = "UNKNOWN"
return MockSoundFileInfo()
with patch("soundfile.info", _mock_info_func):
with warnings.catch_warnings(record=True) as w:
info = self._info("foo")
assert len(w) == 1
assert "UNSEEN_SUBTYPE subtype is unknown to TorchAudio" in str(w[-1].message)
assert info.bits_per_sample == 0
@skipIfNoModule("soundfile")
class TestFileObject(TempDirMixin, PytorchTestCase):
_info = partial(get_info_func(), backend="soundfile")
def _test_fileobj(self, ext, subtype, bits_per_sample):
"""Query audio via file-like object works"""
duration = 2
sample_rate = 16000
num_channels = 2
num_frames = sample_rate * duration
path = self.get_temp_path(f"test.{ext}")
data = torch.randn(num_frames, num_channels).numpy()
soundfile.write(path, data, sample_rate, subtype=subtype)
with open(path, "rb") as fileobj:
info = self._info(fileobj)
assert info.sample_rate == sample_rate
assert info.num_frames == num_frames
assert info.num_channels == num_channels
assert info.bits_per_sample == bits_per_sample
assert info.encoding == "FLAC" if ext == "flac" else "PCM_S"
def test_fileobj_wav(self):
"""Loading audio via file-like object works"""
self._test_fileobj("wav", "PCM_16", 16)
@skipIfFormatNotSupported("FLAC")
def test_fileobj_flac(self):
"""Loading audio via file-like object works"""
self._test_fileobj("flac", "PCM_16", 16)
def _test_tarobj(self, ext, subtype, bits_per_sample):
"""Query compressed audio via file-like object works"""
duration = 2
sample_rate = 16000
num_channels = 2
num_frames = sample_rate * duration
audio_file = f"test.{ext}"
audio_path = self.get_temp_path(audio_file)
archive_path = self.get_temp_path("archive.tar.gz")
data = torch.randn(num_frames, num_channels).numpy()
soundfile.write(audio_path, data, sample_rate, subtype=subtype)
with tarfile.TarFile(archive_path, "w") as tarobj:
tarobj.add(audio_path, arcname=audio_file)
with tarfile.TarFile(archive_path, "r") as tarobj:
fileobj = tarobj.extractfile(audio_file)
info = self._info(fileobj)
assert info.sample_rate == sample_rate
assert info.num_frames == num_frames
assert info.num_channels == num_channels
assert info.bits_per_sample == bits_per_sample
assert info.encoding == "FLAC" if ext == "flac" else "PCM_S"
def test_tarobj_wav(self):
"""Query compressed audio via file-like object works"""
self._test_tarobj("wav", "PCM_16", 16)
@skipIfFormatNotSupported("FLAC")
def test_tarobj_flac(self):
"""Query compressed audio via file-like object works"""
self._test_tarobj("flac", "PCM_16", 16)
import os
import tarfile
from functools import partial
from unittest.mock import patch
import torch
from parameterized import parameterized
from torchaudio._backend.utils import get_load_func
from torchaudio._internal import module_utils as _mod_utils
from torchaudio_unittest.common_utils import (
get_wav_data,
load_wav,
normalize_wav,
PytorchTestCase,
save_wav,
skipIfNoModule,
TempDirMixin,
)
from .common import dtype2subtype, parameterize, skipIfFormatNotSupported
if _mod_utils.is_module_available("soundfile"):
import soundfile
def _get_mock_path(
ext: str,
dtype: str,
sample_rate: int,
num_channels: int,
num_frames: int,
):
return f"{dtype}_{sample_rate}_{num_channels}_{num_frames}.{ext}"
def _get_mock_params(path: str):
filename, ext = path.split(".")
parts = filename.split("_")
return {
"ext": ext,
"dtype": parts[0],
"sample_rate": int(parts[1]),
"num_channels": int(parts[2]),
"num_frames": int(parts[3]),
}
class SoundFileMock:
def __init__(self, path, mode):
assert mode == "r"
self.path = path
self._params = _get_mock_params(path)
self._start = None
@property
def samplerate(self):
return self._params["sample_rate"]
@property
def format(self):
if self._params["ext"] == "wav":
return "WAV"
if self._params["ext"] == "flac":
return "FLAC"
if self._params["ext"] == "ogg":
return "OGG"
if self._params["ext"] in ["sph", "nis", "nist"]:
return "NIST"
@property
def subtype(self):
if self._params["ext"] == "ogg":
return "VORBIS"
return dtype2subtype(self._params["dtype"])
def _prepare_read(self, start, stop, frames):
assert stop is None
self._start = start
return frames
def read(self, frames, dtype, always_2d):
assert always_2d
data = get_wav_data(
dtype,
self._params["num_channels"],
normalize=False,
num_frames=self._params["num_frames"],
channels_first=False,
).numpy()
return data[self._start : self._start + frames]
def __enter__(self):
return self
def __exit__(self, *args, **kwargs):
pass
class MockedLoadTest(PytorchTestCase):
_load = partial(get_load_func(), backend="soundfile")
def assert_dtype(self, ext, dtype, sample_rate, num_channels, normalize, channels_first):
"""When format is WAV or NIST, normalize=False will return the native dtype Tensor, otherwise float32"""
num_frames = 3 * sample_rate
path = _get_mock_path(ext, dtype, sample_rate, num_channels, num_frames)
expected_dtype = torch.float32 if normalize or ext not in ["wav", "nist"] else getattr(torch, dtype)
with patch("soundfile.SoundFile", SoundFileMock):
found, sr = self._load(path, normalize=normalize, channels_first=channels_first)
assert found.dtype == expected_dtype
assert sample_rate == sr
@parameterize(
["uint8", "int16", "int32", "float32", "float64"],
[8000, 16000],
[1, 2],
[True, False],
[True, False],
)
def test_wav(self, dtype, sample_rate, num_channels, normalize, channels_first):
"""Returns native dtype when normalize=False else float32"""
self.assert_dtype("wav", dtype, sample_rate, num_channels, normalize, channels_first)
@parameterize(
["int8", "int16", "int32"],
[8000, 16000],
[1, 2],
[True, False],
[True, False],
)
def test_sphere(self, dtype, sample_rate, num_channels, normalize, channels_first):
"""Returns float32 always"""
self.assert_dtype("sph", dtype, sample_rate, num_channels, normalize, channels_first)
@parameterize([8000, 16000], [1, 2], [True, False], [True, False])
def test_ogg(self, sample_rate, num_channels, normalize, channels_first):
"""Returns float32 always"""
self.assert_dtype("ogg", "int16", sample_rate, num_channels, normalize, channels_first)
@parameterize([8000, 16000], [1, 2], [True, False], [True, False])
def test_flac(self, sample_rate, num_channels, normalize, channels_first):
"""`soundfile_backend.load` can load ogg format."""
self.assert_dtype("flac", "int16", sample_rate, num_channels, normalize, channels_first)
class LoadTestBase(TempDirMixin, PytorchTestCase):
_load = partial(get_load_func(), backend="soundfile")
def assert_wav(
self,
dtype,
sample_rate,
num_channels,
normalize,
channels_first=True,
duration=1,
):
"""`soundfile_backend.load` can load wav format correctly.
Wav data loaded with soundfile backend should match those with scipy
"""
path = self.get_temp_path("reference.wav")
num_frames = duration * sample_rate
data = get_wav_data(
dtype,
num_channels,
normalize=normalize,
num_frames=num_frames,
channels_first=channels_first,
)
save_wav(path, data, sample_rate, channels_first=channels_first)
expected = load_wav(path, normalize=normalize, channels_first=channels_first)[0]
data, sr = self._load(path, normalize=normalize, channels_first=channels_first)
assert sr == sample_rate
self.assertEqual(data, expected)
def assert_sphere(
self,
dtype,
sample_rate,
num_channels,
channels_first=True,
duration=1,
):
"""`soundfile_backend.load` can load SPHERE format correctly."""
path = self.get_temp_path("reference.sph")
num_frames = duration * sample_rate
raw = get_wav_data(
dtype,
num_channels,
num_frames=num_frames,
normalize=False,
channels_first=False,
)
soundfile.write(path, raw, sample_rate, subtype=dtype2subtype(dtype), format="NIST")
expected = normalize_wav(raw.t() if channels_first else raw)
data, sr = self._load(path, channels_first=channels_first)
assert sr == sample_rate
self.assertEqual(data, expected, atol=1e-4, rtol=1e-8)
def assert_flac(
self,
dtype,
sample_rate,
num_channels,
channels_first=True,
duration=1,
):
"""`soundfile_backend.load` can load FLAC format correctly."""
path = self.get_temp_path("reference.flac")
num_frames = duration * sample_rate
raw = get_wav_data(
dtype,
num_channels,
num_frames=num_frames,
normalize=False,
channels_first=False,
)
soundfile.write(path, raw, sample_rate)
expected = normalize_wav(raw.t() if channels_first else raw)
data, sr = self._load(path, channels_first=channels_first)
assert sr == sample_rate
self.assertEqual(data, expected, atol=1e-4, rtol=1e-8)
@skipIfNoModule("soundfile")
class TestLoad(LoadTestBase):
"""Test the correctness of `soundfile_backend.load` for various formats"""
@parameterize(
["float32", "int32", "int16"],
[8000, 16000],
[1, 2],
[False, True],
[False, True],
)
def test_wav(self, dtype, sample_rate, num_channels, normalize, channels_first):
"""`soundfile_backend.load` can load wav format correctly."""
self.assert_wav(dtype, sample_rate, num_channels, normalize, channels_first)
@parameterize(
["int16"],
[16000],
[2],
[False],
)
def test_wav_large(self, dtype, sample_rate, num_channels, normalize):
"""`soundfile_backend.load` can load large wav file correctly."""
two_hours = 2 * 60 * 60
self.assert_wav(dtype, sample_rate, num_channels, normalize, duration=two_hours)
@parameterize(["float32", "int32", "int16"], [4, 8, 16, 32], [False, True])
def test_multiple_channels(self, dtype, num_channels, channels_first):
"""`soundfile_backend.load` can load wav file with more than 2 channels."""
sample_rate = 8000
normalize = False
self.assert_wav(dtype, sample_rate, num_channels, normalize, channels_first)
@parameterize(["int32", "int16"], [8000, 16000], [1, 2], [False, True])
@skipIfFormatNotSupported("NIST")
def test_sphere(self, dtype, sample_rate, num_channels, channels_first):
"""`soundfile_backend.load` can load sphere format correctly."""
self.assert_sphere(dtype, sample_rate, num_channels, channels_first)
@parameterize(["int32", "int16"], [8000, 16000], [1, 2], [False, True])
@skipIfFormatNotSupported("FLAC")
def test_flac(self, dtype, sample_rate, num_channels, channels_first):
"""`soundfile_backend.load` can load flac format correctly."""
self.assert_flac(dtype, sample_rate, num_channels, channels_first)
@skipIfNoModule("soundfile")
class TestLoadFormat(TempDirMixin, PytorchTestCase):
"""Given `format` parameter, `so.load` can load files without extension"""
_load = partial(get_load_func(), backend="soundfile")
original = None
path = None
def _make_file(self, format_):
sample_rate = 8000
path_with_ext = self.get_temp_path(f"test.{format_}")
data = get_wav_data("float32", num_channels=2).numpy().T
soundfile.write(path_with_ext, data, sample_rate)
expected = soundfile.read(path_with_ext, dtype="float32")[0].T
path = os.path.splitext(path_with_ext)[0]
os.rename(path_with_ext, path)
return path, expected
def _test_format(self, format_):
"""Providing format allows to read file without extension"""
path, expected = self._make_file(format_)
found, _ = self._load(path)
self.assertEqual(found, expected)
@parameterized.expand(
[
("WAV",),
("wav",),
]
)
def test_wav(self, format_):
self._test_format(format_)
@parameterized.expand(
[
("FLAC",),
("flac",),
]
)
@skipIfFormatNotSupported("FLAC")
def test_flac(self, format_):
self._test_format(format_)
@skipIfNoModule("soundfile")
class TestFileObject(TempDirMixin, PytorchTestCase):
_load = partial(get_load_func(), backend="soundfile")
def _test_fileobj(self, ext):
"""Loading audio via file-like object works"""
sample_rate = 16000
path = self.get_temp_path(f"test.{ext}")
data = get_wav_data("float32", num_channels=2).numpy().T
soundfile.write(path, data, sample_rate)
expected = soundfile.read(path, dtype="float32")[0].T
with open(path, "rb") as fileobj:
found, sr = self._load(fileobj)
assert sr == sample_rate
self.assertEqual(expected, found)
def test_fileobj_wav(self):
"""Loading audio via file-like object works"""
self._test_fileobj("wav")
@skipIfFormatNotSupported("FLAC")
def test_fileobj_flac(self):
"""Loading audio via file-like object works"""
self._test_fileobj("flac")
def _test_tarfile(self, ext):
"""Loading audio via file-like object works"""
sample_rate = 16000
audio_file = f"test.{ext}"
audio_path = self.get_temp_path(audio_file)
archive_path = self.get_temp_path("archive.tar.gz")
data = get_wav_data("float32", num_channels=2).numpy().T
soundfile.write(audio_path, data, sample_rate)
expected = soundfile.read(audio_path, dtype="float32")[0].T
with tarfile.TarFile(archive_path, "w") as tarobj:
tarobj.add(audio_path, arcname=audio_file)
with tarfile.TarFile(archive_path, "r") as tarobj:
fileobj = tarobj.extractfile(audio_file)
found, sr = self._load(fileobj)
assert sr == sample_rate
self.assertEqual(expected, found)
def test_tarfile_wav(self):
"""Loading audio via file-like object works"""
self._test_tarfile("wav")
@skipIfFormatNotSupported("FLAC")
def test_tarfile_flac(self):
"""Loading audio via file-like object works"""
self._test_tarfile("flac")
import io
from functools import partial
from unittest.mock import patch
from torchaudio._backend.utils import get_save_func
from torchaudio._internal import module_utils as _mod_utils
from torchaudio_unittest.common_utils import (
get_wav_data,
load_wav,
nested_params,
PytorchTestCase,
skipIfNoModule,
TempDirMixin,
)
from .common import fetch_wav_subtype, parameterize, skipIfFormatNotSupported
if _mod_utils.is_module_available("soundfile"):
import soundfile
class MockedSaveTest(PytorchTestCase):
_save = partial(get_save_func(), backend="soundfile")
@nested_params(
["float32", "int32", "int16", "uint8"],
[8000, 16000],
[1, 2],
[False, True],
[
(None, None),
("PCM_U", None),
("PCM_U", 8),
("PCM_S", None),
("PCM_S", 16),
("PCM_S", 32),
("PCM_F", None),
("PCM_F", 32),
("PCM_F", 64),
("ULAW", None),
("ULAW", 8),
("ALAW", None),
("ALAW", 8),
],
)
@patch("soundfile.write")
def test_wav(self, dtype, sample_rate, num_channels, channels_first, enc_params, mocked_write):
"""self._save passes correct subtype to soundfile.write when WAV"""
filepath = "foo.wav"
input_tensor = get_wav_data(
dtype,
num_channels,
num_frames=3 * sample_rate,
normalize=dtype == "float32",
channels_first=channels_first,
).t()
encoding, bits_per_sample = enc_params
self._save(
filepath,
input_tensor,
sample_rate,
channels_first=channels_first,
encoding=encoding,
bits_per_sample=bits_per_sample,
)
# on +Py3.8 call_args.kwargs is more descreptive
args = mocked_write.call_args[1]
assert args["file"] == filepath
assert args["samplerate"] == sample_rate
assert args["subtype"] == fetch_wav_subtype(dtype, encoding, bits_per_sample)
assert args["format"] is None
self.assertEqual(args["data"], input_tensor.t() if channels_first else input_tensor)
@patch("soundfile.write")
def assert_non_wav(
self,
fmt,
dtype,
sample_rate,
num_channels,
channels_first,
mocked_write,
encoding=None,
bits_per_sample=None,
):
"""self._save passes correct subtype and format to soundfile.write when SPHERE"""
filepath = f"foo.{fmt}"
input_tensor = get_wav_data(
dtype,
num_channels,
num_frames=3 * sample_rate,
normalize=False,
channels_first=channels_first,
).t()
expected_data = input_tensor.t() if channels_first else input_tensor
self._save(
filepath,
input_tensor,
sample_rate,
channels_first,
encoding=encoding,
bits_per_sample=bits_per_sample,
)
# on +Py3.8 call_args.kwargs is more descreptive
args = mocked_write.call_args[1]
assert args["file"] == filepath
assert args["samplerate"] == sample_rate
if fmt in ["sph", "nist", "nis"]:
assert args["format"] == "NIST"
else:
assert args["format"] is None
self.assertEqual(args["data"], expected_data)
@nested_params(
["sph", "nist", "nis"],
["int32", "int16"],
[8000, 16000],
[1, 2],
[False, True],
[
("PCM_S", 8),
("PCM_S", 16),
("PCM_S", 24),
("PCM_S", 32),
("ULAW", 8),
("ALAW", 8),
("ALAW", 16),
("ALAW", 24),
("ALAW", 32),
],
)
def test_sph(self, fmt, dtype, sample_rate, num_channels, channels_first, enc_params):
"""self._save passes default format and subtype (None-s) to
soundfile.write when not WAV"""
encoding, bits_per_sample = enc_params
self.assert_non_wav(
fmt, dtype, sample_rate, num_channels, channels_first, encoding=encoding, bits_per_sample=bits_per_sample
)
@parameterize(
["int32", "int16"],
[8000, 16000],
[1, 2],
[False, True],
[8, 16, 24],
)
def test_flac(self, dtype, sample_rate, num_channels, channels_first, bits_per_sample):
"""self._save passes default format and subtype (None-s) to
soundfile.write when not WAV"""
self.assert_non_wav("flac", dtype, sample_rate, num_channels, channels_first, bits_per_sample=bits_per_sample)
@parameterize(
["int32", "int16"],
[8000, 16000],
[1, 2],
[False, True],
)
def test_ogg(self, dtype, sample_rate, num_channels, channels_first):
"""self._save passes default format and subtype (None-s) to
soundfile.write when not WAV"""
self.assert_non_wav("ogg", dtype, sample_rate, num_channels, channels_first)
@skipIfNoModule("soundfile")
class SaveTestBase(TempDirMixin, PytorchTestCase):
_save = partial(get_save_func(), backend="soundfile")
def assert_wav(self, dtype, sample_rate, num_channels, num_frames):
"""`self._save` can save wav format."""
path = self.get_temp_path("data.wav")
expected = get_wav_data(dtype, num_channels, num_frames=num_frames, normalize=False)
self._save(path, expected, sample_rate)
found, sr = load_wav(path, normalize=False)
assert sample_rate == sr
self.assertEqual(found, expected)
def _assert_non_wav(self, fmt, dtype, sample_rate, num_channels):
"""`self._save` can save non-wav format.
Due to precision missmatch, and the lack of alternative way to decode the
resulting files without using soundfile, only meta data are validated.
"""
num_frames = sample_rate * 3
path = self.get_temp_path(f"data.{fmt}")
expected = get_wav_data(dtype, num_channels, num_frames=num_frames, normalize=False)
self._save(path, expected, sample_rate)
sinfo = soundfile.info(path)
assert sinfo.format == fmt.upper()
assert sinfo.frames == num_frames
assert sinfo.channels == num_channels
assert sinfo.samplerate == sample_rate
def assert_flac(self, dtype, sample_rate, num_channels):
"""`self._save` can save flac format."""
self._assert_non_wav("flac", dtype, sample_rate, num_channels)
def assert_sphere(self, dtype, sample_rate, num_channels):
"""`self._save` can save sph format."""
self._assert_non_wav("nist", dtype, sample_rate, num_channels)
def assert_ogg(self, dtype, sample_rate, num_channels):
"""`self._save` can save ogg format.
As we cannot inspect the OGG format (it's lossy), we only check the metadata.
"""
self._assert_non_wav("ogg", dtype, sample_rate, num_channels)
@skipIfNoModule("soundfile")
class TestSave(SaveTestBase):
@parameterize(
["float32", "int32", "int16"],
[8000, 16000],
[1, 2],
)
def test_wav(self, dtype, sample_rate, num_channels):
"""`self._save` can save wav format."""
self.assert_wav(dtype, sample_rate, num_channels, num_frames=None)
@parameterize(
["float32", "int32", "int16"],
[4, 8, 16, 32],
)
def test_multiple_channels(self, dtype, num_channels):
"""`self._save` can save wav with more than 2 channels."""
sample_rate = 8000
self.assert_wav(dtype, sample_rate, num_channels, num_frames=None)
@parameterize(
["int32", "int16"],
[8000, 16000],
[1, 2],
)
@skipIfFormatNotSupported("NIST")
def test_sphere(self, dtype, sample_rate, num_channels):
"""`self._save` can save sph format."""
self.assert_sphere(dtype, sample_rate, num_channels)
@parameterize(
[8000, 16000],
[1, 2],
)
@skipIfFormatNotSupported("FLAC")
def test_flac(self, sample_rate, num_channels):
"""`self._save` can save flac format."""
self.assert_flac("float32", sample_rate, num_channels)
@parameterize(
[8000, 16000],
[1, 2],
)
@skipIfFormatNotSupported("OGG")
def test_ogg(self, sample_rate, num_channels):
"""`self._save` can save ogg/vorbis format."""
self.assert_ogg("float32", sample_rate, num_channels)
@skipIfNoModule("soundfile")
class TestSaveParams(TempDirMixin, PytorchTestCase):
"""Test the correctness of optional parameters of `self._save`"""
_save = partial(get_save_func(), backend="soundfile")
@parameterize([True, False])
def test_channels_first(self, channels_first):
"""channels_first swaps axes"""
path = self.get_temp_path("data.wav")
data = get_wav_data("int32", 2, channels_first=channels_first)
self._save(path, data, 8000, channels_first=channels_first)
found = load_wav(path)[0]
expected = data if channels_first else data.transpose(1, 0)
self.assertEqual(found, expected, atol=1e-4, rtol=1e-8)
@skipIfNoModule("soundfile")
class TestFileObject(TempDirMixin, PytorchTestCase):
_save = partial(get_save_func(), backend="soundfile")
def _test_fileobj(self, ext):
"""Saving audio to file-like object works"""
sample_rate = 16000
path = self.get_temp_path(f"test.{ext}")
subtype = "FLOAT" if ext == "wav" else None
data = get_wav_data("float32", num_channels=2)
soundfile.write(path, data.numpy().T, sample_rate, subtype=subtype)
expected = soundfile.read(path, dtype="float32")[0]
fileobj = io.BytesIO()
self._save(fileobj, data, sample_rate, format=ext)
fileobj.seek(0)
found, sr = soundfile.read(fileobj, dtype="float32")
assert sr == sample_rate
self.assertEqual(expected, found, atol=1e-4, rtol=1e-8)
def test_fileobj_wav(self):
"""Saving audio via file-like object works"""
self._test_fileobj("wav")
@skipIfFormatNotSupported("FLAC")
def test_fileobj_flac(self):
"""Saving audio via file-like object works"""
self._test_fileobj("flac")
@skipIfFormatNotSupported("NIST")
def test_fileobj_nist(self):
"""Saving audio via file-like object works"""
self._test_fileobj("NIST")
@skipIfFormatNotSupported("OGG")
def test_fileobj_ogg(self):
"""Saving audio via file-like object works"""
self._test_fileobj("OGG")
def name_func(func, _, params):
return f'{func.__name__}_{"_".join(str(arg) for arg in params.args)}'
def get_enc_params(dtype):
if dtype == "float32":
return "PCM_F", 32
if dtype == "int32":
return "PCM_S", 32
if dtype == "int16":
return "PCM_S", 16
if dtype == "uint8":
return "PCM_U", 8
raise ValueError(f"Unexpected dtype: {dtype}")
import itertools
import os
from functools import partial
from parameterized import parameterized
from torchaudio._backend.utils import get_info_func
from torchaudio._internal import module_utils as _mod_utils
from torchaudio_unittest.backend.common import get_encoding
from torchaudio_unittest.common_utils import (
get_asset_path,
get_wav_data,
HttpServerMixin,
PytorchTestCase,
save_wav,
skipIfNoExec,
skipIfNoModule,
skipIfNoSox,
sox_utils,
TempDirMixin,
)
from .common import name_func
if _mod_utils.is_module_available("requests"):
import requests
@skipIfNoExec("sox")
@skipIfNoSox
class TestInfo(TempDirMixin, PytorchTestCase):
_info = partial(get_info_func(), backend="sox")
@parameterized.expand(
list(
itertools.product(
["float32", "int32", "int16", "uint8"],
[8000, 16000],
[1, 2],
)
),
name_func=name_func,
)
def test_wav(self, dtype, sample_rate, num_channels):
"""`self._info` can check wav file correctly"""
duration = 1
path = self.get_temp_path("data.wav")
data = get_wav_data(dtype, num_channels, normalize=False, num_frames=duration * sample_rate)
save_wav(path, data, sample_rate)
info = self._info(path)
assert info.sample_rate == sample_rate
assert info.num_frames == sample_rate * duration
assert info.num_channels == num_channels
assert info.bits_per_sample == sox_utils.get_bit_depth(dtype)
assert info.encoding == get_encoding("wav", dtype)
@parameterized.expand(
list(
itertools.product(
["float32", "int32", "int16", "uint8"],
[8000, 16000],
[4, 8, 16, 32],
)
),
name_func=name_func,
)
def test_wav_multiple_channels(self, dtype, sample_rate, num_channels):
"""`self._info` can check wav file with channels more than 2 correctly"""
duration = 1
path = self.get_temp_path("data.wav")
data = get_wav_data(dtype, num_channels, normalize=False, num_frames=duration * sample_rate)
save_wav(path, data, sample_rate)
info = self._info(path)
assert info.sample_rate == sample_rate
assert info.num_frames == sample_rate * duration
assert info.num_channels == num_channels
assert info.bits_per_sample == sox_utils.get_bit_depth(dtype)
assert info.encoding == get_encoding("wav", dtype)
@parameterized.expand(
list(
itertools.product(
[8000, 16000],
[1, 2],
list(range(9)),
)
),
name_func=name_func,
)
def test_flac(self, sample_rate, num_channels, compression_level):
"""`self._info` can check flac file correctly"""
duration = 1
path = self.get_temp_path("data.flac")
sox_utils.gen_audio_file(
path,
sample_rate,
num_channels,
compression=compression_level,
duration=duration,
)
info = self._info(path)
assert info.sample_rate == sample_rate
assert info.num_frames == sample_rate * duration
assert info.num_channels == num_channels
assert info.bits_per_sample == 24 # FLAC standard
assert info.encoding == "FLAC"
@parameterized.expand(
list(
itertools.product(
[8000, 16000],
[1, 2],
[-1, 0, 1, 2, 3, 3.6, 5, 10],
)
),
name_func=name_func,
)
def test_vorbis(self, sample_rate, num_channels, quality_level):
"""`self._info` can check vorbis file correctly"""
duration = 1
path = self.get_temp_path("data.vorbis")
sox_utils.gen_audio_file(
path,
sample_rate,
num_channels,
compression=quality_level,
duration=duration,
)
info = self._info(path)
assert info.sample_rate == sample_rate
assert info.num_frames == sample_rate * duration
assert info.num_channels == num_channels
assert info.bits_per_sample == 0 # bit_per_sample is irrelevant for compressed formats
assert info.encoding == "VORBIS"
@parameterized.expand(
list(
itertools.product(
[8000, 16000],
[1, 2],
[16, 32],
)
),
name_func=name_func,
)
def test_sphere(self, sample_rate, num_channels, bits_per_sample):
"""`self._info` can check sph file correctly"""
duration = 1
path = self.get_temp_path("data.sph")
sox_utils.gen_audio_file(path, sample_rate, num_channels, duration=duration, bit_depth=bits_per_sample)
info = self._info(path)
assert info.sample_rate == sample_rate
assert info.num_frames == sample_rate * duration
assert info.num_channels == num_channels
assert info.bits_per_sample == bits_per_sample
assert info.encoding == "PCM_S"
@parameterized.expand(
list(
itertools.product(
["int32", "int16", "uint8"],
[8000, 16000],
[1, 2],
)
),
name_func=name_func,
)
def test_amb(self, dtype, sample_rate, num_channels):
"""`self._info` can check amb file correctly"""
duration = 1
path = self.get_temp_path("data.amb")
bits_per_sample = sox_utils.get_bit_depth(dtype)
sox_utils.gen_audio_file(path, sample_rate, num_channels, bit_depth=bits_per_sample, duration=duration)
info = self._info(path)
assert info.sample_rate == sample_rate
assert info.num_frames == sample_rate * duration
assert info.num_channels == num_channels
assert info.bits_per_sample == bits_per_sample
assert info.encoding == get_encoding("amb", dtype)
def test_amr_nb(self):
"""`self._info` can check amr-nb file correctly"""
duration = 1
num_channels = 1
sample_rate = 8000
path = self.get_temp_path("data.amr-nb")
sox_utils.gen_audio_file(
path, sample_rate=sample_rate, num_channels=num_channels, bit_depth=16, duration=duration
)
info = self._info(path)
assert info.sample_rate == sample_rate
assert info.num_frames == sample_rate * duration
assert info.num_channels == num_channels
assert info.bits_per_sample == 0
assert info.encoding == "AMR_NB"
def test_ulaw(self):
"""`self._info` can check ulaw file correctly"""
duration = 1
num_channels = 1
sample_rate = 8000
path = self.get_temp_path("data.wav")
sox_utils.gen_audio_file(
path, sample_rate=sample_rate, num_channels=num_channels, bit_depth=8, encoding="u-law", duration=duration
)
info = self._info(path)
assert info.sample_rate == sample_rate
assert info.num_frames == sample_rate * duration
assert info.num_channels == num_channels
assert info.bits_per_sample == 8
assert info.encoding == "ULAW"
def test_alaw(self):
"""`self._info` can check alaw file correctly"""
duration = 1
num_channels = 1
sample_rate = 8000
path = self.get_temp_path("data.wav")
sox_utils.gen_audio_file(
path, sample_rate=sample_rate, num_channels=num_channels, bit_depth=8, encoding="a-law", duration=duration
)
info = self._info(path)
assert info.sample_rate == sample_rate
assert info.num_frames == sample_rate * duration
assert info.num_channels == num_channels
assert info.bits_per_sample == 8
assert info.encoding == "ALAW"
def test_gsm(self):
"""`self._info` can check gsm file correctly"""
duration = 1
num_channels = 1
sample_rate = 8000
path = self.get_temp_path("data.gsm")
sox_utils.gen_audio_file(path, sample_rate=sample_rate, num_channels=num_channels, duration=duration)
info = self._info(path)
assert info.sample_rate == sample_rate
assert info.num_channels == num_channels
assert info.bits_per_sample == 0
assert info.encoding == "GSM"
def test_htk(self):
"""`self._info` can check HTK file correctly"""
duration = 1
num_channels = 1
sample_rate = 8000
path = self.get_temp_path("data.htk")
sox_utils.gen_audio_file(
path, sample_rate=sample_rate, num_channels=num_channels, bit_depth=16, duration=duration
)
info = self._info(path)
assert info.sample_rate == sample_rate
assert info.num_frames == sample_rate * duration
assert info.num_channels == num_channels
assert info.bits_per_sample == 16
assert info.encoding == "PCM_S"
@skipIfNoSox
class TestInfoOpus(PytorchTestCase):
_info = partial(get_info_func(), backend="sox")
@parameterized.expand(
list(
itertools.product(
["96k"],
[1, 2],
[0, 5, 10],
)
),
name_func=name_func,
)
def test_opus(self, bitrate, num_channels, compression_level):
"""`self._info` can check opus file correcty"""
path = get_asset_path("io", f"{bitrate}_{compression_level}_{num_channels}ch.opus")
info = self._info(path)
assert info.sample_rate == 48000
assert info.num_frames == 32768
assert info.num_channels == num_channels
assert info.bits_per_sample == 0 # bit_per_sample is irrelevant for compressed formats
assert info.encoding == "OPUS"
class FileObjTestBase(TempDirMixin):
def _gen_file(self, ext, dtype, sample_rate, num_channels, num_frames, *, comments=None):
path = self.get_temp_path(f"test.{ext}")
bit_depth = sox_utils.get_bit_depth(dtype)
duration = num_frames / sample_rate
comment_file = self._gen_comment_file(comments) if comments else None
sox_utils.gen_audio_file(
path,
sample_rate,
num_channels=num_channels,
encoding=sox_utils.get_encoding(dtype),
bit_depth=bit_depth,
duration=duration,
comment_file=comment_file,
)
return path
def _gen_comment_file(self, comments):
comment_path = self.get_temp_path("comment.txt")
with open(comment_path, "w") as file_:
file_.writelines(comments)
return comment_path
class Unseekable:
def __init__(self, fileobj):
self.fileobj = fileobj
def read(self, n):
return self.fileobj.read(n)
@skipIfNoSox
@skipIfNoExec("sox")
class TestFileObject(FileObjTestBase, PytorchTestCase):
_info = partial(get_info_func(), backend="sox")
def _query_fileobj(self, ext, dtype, sample_rate, num_channels, num_frames, *, comments=None):
path = self._gen_file(ext, dtype, sample_rate, num_channels, num_frames, comments=comments)
with open(path, "rb") as fileobj:
return self._info(fileobj, None)
@parameterized.expand(
[
("wav", "float32"),
("wav", "int32"),
("wav", "int16"),
("wav", "uint8"),
# ("mp3", "float32"),
("flac", "float32"),
("vorbis", "float32"),
("amb", "int16"),
]
)
def test_fileobj(self, ext, dtype):
"""Querying audio via file object works"""
sample_rate = 16000
num_frames = 3 * sample_rate
num_channels = 2
with self.assertRaisesRegex(ValueError, "SoX backend does not support reading"):
self._query_fileobj(ext, dtype, sample_rate, num_channels, num_frames)
@skipIfNoSox
@skipIfNoExec("sox")
@skipIfNoModule("requests")
class TestFileObjectHttp(HttpServerMixin, FileObjTestBase, PytorchTestCase):
_info = partial(get_info_func(), backend="sox")
def _query_http(self, ext, dtype, sample_rate, num_channels, num_frames):
audio_path = self._gen_file(ext, dtype, sample_rate, num_channels, num_frames)
audio_file = os.path.basename(audio_path)
url = self.get_url(audio_file)
# format_ = ext if ext in ["mp3"] else None
with requests.get(url, stream=True) as resp:
return self._info(Unseekable(resp.raw), format=None)
@parameterized.expand(
[
("wav", "float32"),
("wav", "int32"),
("wav", "int16"),
("wav", "uint8"),
# ("mp3", "float32"),
("flac", "float32"),
("vorbis", "float32"),
("amb", "int16"),
]
)
def test_requests(self, ext, dtype):
"""Querying compressed audio via requests works"""
sample_rate = 16000
num_frames = 3.0 * sample_rate
num_channels = 2
with self.assertRaisesRegex(ValueError, "SoX backend does not support reading"):
self._query_http(ext, dtype, sample_rate, num_channels, num_frames)
@skipIfNoSox
class TestInfoNoSuchFile(PytorchTestCase):
_info = partial(get_info_func(), backend="sox")
def test_info_fail(self):
"""
When attempted to get info on a non-existing file, error message must contain the file path.
"""
path = "non_existing_audio.wav"
with self.assertRaisesRegex(RuntimeError, path):
self._info(path)
import itertools
from functools import partial
import torch
import torchaudio
from parameterized import parameterized
from torchaudio._backend.utils import get_load_func
from torchaudio_unittest.common_utils import (
get_asset_path,
get_wav_data,
load_wav,
nested_params,
PytorchTestCase,
save_wav,
skipIfNoExec,
skipIfNoSox,
sox_utils,
TempDirMixin,
)
from .common import name_func
class LoadTestBase(TempDirMixin, PytorchTestCase):
_load = partial(get_load_func(), backend="sox")
def assert_format(
self,
format: str,
sample_rate: float,
num_channels: int,
compression: float = None,
bit_depth: int = None,
duration: float = 1,
normalize: bool = True,
encoding: str = None,
atol: float = 4e-05,
rtol: float = 1.3e-06,
):
"""`sox_io_backend.load` can load given format correctly.
file encodings introduce delay and boundary effects so
we create a reference wav file from the original file format
x
|
| 1. Generate given format with Sox
|
v 2. Convert to wav with Sox
given format ----------------------> wav
| |
| 3. Load with torchaudio | 4. Load with scipy
| |
v v
tensor ----------> x <----------- tensor
5. Compare
Underlying assumptions are;
i. Conversion of given format to wav with Sox preserves data.
ii. Loading wav file with scipy is correct.
By combining i & ii, step 2. and 4. allows to load reference given format
data without using torchaudio
"""
path = self.get_temp_path(f"1.original.{format}")
ref_path = self.get_temp_path("2.reference.wav")
# 1. Generate the given format with sox
sox_utils.gen_audio_file(
path,
sample_rate,
num_channels,
encoding=encoding,
compression=compression,
bit_depth=bit_depth,
duration=duration,
)
# 2. Convert to wav with sox
wav_bit_depth = 32 if bit_depth == 24 else None # for 24-bit wav
sox_utils.convert_audio_file(path, ref_path, bit_depth=wav_bit_depth)
# 3. Load the given format with torchaudio
data, sr = self._load(path, normalize=normalize)
# 4. Load wav with scipy
data_ref = load_wav(ref_path, normalize=normalize)[0]
# 5. Compare
assert sr == sample_rate
self.assertEqual(data, data_ref, atol=atol, rtol=rtol)
def assert_wav(self, dtype, sample_rate, num_channels, normalize, duration):
"""`sox_io_backend.load` can load wav format correctly.
Wav data loaded with sox_io backend should match those with scipy
"""
path = self.get_temp_path("reference.wav")
data = get_wav_data(dtype, num_channels, normalize=normalize, num_frames=duration * sample_rate)
save_wav(path, data, sample_rate)
expected = load_wav(path, normalize=normalize)[0]
data, sr = self._load(path, normalize=normalize)
assert sr == sample_rate
self.assertEqual(data, expected)
@skipIfNoExec("sox")
@skipIfNoSox
class TestLoad(LoadTestBase):
"""Test the correctness of `sox_io_backend.load` for various formats"""
@parameterized.expand(
list(
itertools.product(
["float32", "int32", "int16", "uint8"],
[8000, 16000],
[1, 2],
[False, True],
)
),
name_func=name_func,
)
def test_wav(self, dtype, sample_rate, num_channels, normalize):
"""`sox_io_backend.load` can load wav format correctly."""
self.assert_wav(dtype, sample_rate, num_channels, normalize, duration=1)
@parameterized.expand(
list(
itertools.product(
[8000, 16000],
[1, 2],
[False, True],
)
),
name_func=name_func,
)
def test_24bit_wav(self, sample_rate, num_channels, normalize):
"""`sox_io_backend.load` can load 24bit wav format correctly. Corectly casts it to ``int32`` tensor dtype."""
self.assert_format("wav", sample_rate, num_channels, bit_depth=24, normalize=normalize, duration=1)
@parameterized.expand(
list(
itertools.product(
["int16"],
[16000],
[2],
[False],
)
),
name_func=name_func,
)
def test_wav_large(self, dtype, sample_rate, num_channels, normalize):
"""`sox_io_backend.load` can load large wav file correctly."""
two_hours = 2 * 60 * 60
self.assert_wav(dtype, sample_rate, num_channels, normalize, two_hours)
@parameterized.expand(
list(
itertools.product(
["float32", "int32", "int16", "uint8"],
[4, 8, 16, 32],
)
),
name_func=name_func,
)
def test_multiple_channels(self, dtype, num_channels):
"""`sox_io_backend.load` can load wav file with more than 2 channels."""
sample_rate = 8000
normalize = False
self.assert_wav(dtype, sample_rate, num_channels, normalize, duration=1)
@parameterized.expand(
list(
itertools.product(
[8000, 16000],
[1, 2],
list(range(9)),
)
),
name_func=name_func,
)
def test_flac(self, sample_rate, num_channels, compression_level):
"""`sox_io_backend.load` can load flac format correctly."""
self.assert_format("flac", sample_rate, num_channels, compression=compression_level, bit_depth=16, duration=1)
@parameterized.expand(
list(
itertools.product(
[16000],
[2],
[0],
)
),
name_func=name_func,
)
def test_flac_large(self, sample_rate, num_channels, compression_level):
"""`sox_io_backend.load` can load large flac file correctly."""
two_hours = 2 * 60 * 60
self.assert_format(
"flac", sample_rate, num_channels, compression=compression_level, bit_depth=16, duration=two_hours
)
@parameterized.expand(
list(
itertools.product(
[8000, 16000],
[1, 2],
[-1, 0, 1, 2, 3, 3.6, 5, 10],
)
),
name_func=name_func,
)
def test_vorbis(self, sample_rate, num_channels, quality_level):
"""`sox_io_backend.load` can load vorbis format correctly."""
self.assert_format("vorbis", sample_rate, num_channels, compression=quality_level, bit_depth=16, duration=1)
@parameterized.expand(
list(
itertools.product(
[16000],
[2],
[10],
)
),
name_func=name_func,
)
def test_vorbis_large(self, sample_rate, num_channels, quality_level):
"""`sox_io_backend.load` can load large vorbis file correctly."""
two_hours = 2 * 60 * 60
self.assert_format(
"vorbis", sample_rate, num_channels, compression=quality_level, bit_depth=16, duration=two_hours
)
@parameterized.expand(
list(
itertools.product(
["96k"],
[1, 2],
[0, 5, 10],
)
),
name_func=name_func,
)
def test_opus(self, bitrate, num_channels, compression_level):
"""`sox_io_backend.load` can load opus file correctly."""
ops_path = get_asset_path("io", f"{bitrate}_{compression_level}_{num_channels}ch.opus")
wav_path = self.get_temp_path(f"{bitrate}_{compression_level}_{num_channels}ch.opus.wav")
sox_utils.convert_audio_file(ops_path, wav_path)
expected, sample_rate = load_wav(wav_path)
found, sr = self._load(ops_path)
assert sample_rate == sr
self.assertEqual(expected, found)
@parameterized.expand(
list(
itertools.product(
[8000, 16000],
[1, 2],
)
),
name_func=name_func,
)
def test_sphere(self, sample_rate, num_channels):
"""`sox_io_backend.load` can load sph format correctly."""
self.assert_format("sph", sample_rate, num_channels, bit_depth=32, duration=1)
@parameterized.expand(
list(
itertools.product(
["float32", "int32", "int16"],
[8000, 16000],
[1, 2],
[False, True],
)
),
name_func=name_func,
)
def test_amb(self, dtype, sample_rate, num_channels, normalize):
"""`sox_io_backend.load` can load amb format correctly."""
bit_depth = sox_utils.get_bit_depth(dtype)
encoding = sox_utils.get_encoding(dtype)
self.assert_format(
"amb", sample_rate, num_channels, bit_depth=bit_depth, duration=1, encoding=encoding, normalize=normalize
)
def test_amr_nb(self):
"""`sox_io_backend.load` can load amr_nb format correctly."""
self.assert_format("amr-nb", sample_rate=8000, num_channels=1, bit_depth=32, duration=1)
@skipIfNoSox
class TestLoadParams(TempDirMixin, PytorchTestCase):
"""Test the correctness of frame parameters of `sox_io_backend.load`"""
def _test(self, func, frame_offset, num_frames, channels_first, normalize):
original = get_wav_data("int16", num_channels=2, normalize=False)
path = self.get_temp_path("test.wav")
save_wav(path, original, sample_rate=8000)
output, _ = func(path, frame_offset, num_frames, normalize, channels_first, None)
frame_end = None if num_frames == -1 else frame_offset + num_frames
expected = original[:, slice(frame_offset, frame_end)]
if not channels_first:
expected = expected.T
if normalize:
expected = expected.to(torch.float32) / (2**15)
self.assertEqual(output, expected)
@nested_params(
[0, 1, 10, 100, 1000],
[-1, 1, 10, 100, 1000],
[True, False],
[True, False],
)
def test_sox(self, frame_offset, num_frames, channels_first, normalize):
"""The combination of properly changes the output tensor"""
self._test(torch.ops.torchaudio.sox_io_load_audio_file, frame_offset, num_frames, channels_first, normalize)
# test file-like obj
def func(path, *args):
with open(path, "rb") as fileobj:
return torchaudio.lib._torchaudio_sox.load_audio_fileobj(fileobj, *args)
self._test(func, frame_offset, num_frames, channels_first, normalize)
@nested_params(
[0, 1, 10, 100, 1000],
[-1, 1, 10, 100, 1000],
[True, False],
[True, False],
)
def test_ffmpeg(self, frame_offset, num_frames, channels_first, normalize):
"""The combination of properly changes the output tensor"""
from torchaudio.io._compat import load_audio, load_audio_fileobj
self._test(load_audio, frame_offset, num_frames, channels_first, normalize)
# test file-like obj
def func(path, *args):
with open(path, "rb") as fileobj:
return load_audio_fileobj(fileobj, *args)
self._test(func, frame_offset, num_frames, channels_first, normalize)
@skipIfNoSox
@skipIfNoExec("sox")
class TestFileObject(TempDirMixin, PytorchTestCase):
"""
In this test suite, the result of file-like object input is compared against file path input,
because `load` function is rigrously tested for file path inputs to match libsox's result,
"""
_load = partial(get_load_func(), backend="sox")
@parameterized.expand(
[
("wav", {"bit_depth": 16}),
("wav", {"bit_depth": 24}),
("wav", {"bit_depth": 32}),
("flac", {"compression": 0}),
("flac", {"compression": 5}),
("flac", {"compression": 8}),
("vorbis", {"compression": -1}),
("vorbis", {"compression": 10}),
("amb", {}),
]
)
def test_fileobj(self, ext, kwargs):
"""Loading audio via file object returns the same result as via file path."""
sample_rate = 16000
format_ = ext if ext in ["mp3"] else None
path = self.get_temp_path(f"test.{ext}")
sox_utils.gen_audio_file(path, sample_rate, num_channels=2, **kwargs)
expected, _ = self._load(path)
with open(path, "rb") as fileobj:
with self.assertRaisesRegex(ValueError, "SoX backend does not support loading"):
self._load(fileobj, format=format_)
@skipIfNoSox
class TestLoadNoSuchFile(PytorchTestCase):
_load = partial(get_load_func(), backend="sox")
def test_load_fail(self):
"""
When attempted to load a non-existing file, error message must contain the file path.
"""
path = "non_existing_audio.wav"
with self.assertRaisesRegex(RuntimeError, path):
self._load(path)
import itertools
from functools import partial
from parameterized import parameterized
from torchaudio._backend.utils import get_load_func, get_save_func
from torchaudio_unittest.common_utils import get_wav_data, PytorchTestCase, skipIfNoExec, skipIfNoSox, TempDirMixin
from .common import get_enc_params, name_func
@skipIfNoExec("sox")
@skipIfNoSox
class TestRoundTripIO(TempDirMixin, PytorchTestCase):
"""save/load round trip should not degrade data for lossless formats"""
_load = partial(get_load_func(), backend="sox")
_save = partial(get_save_func(), backend="sox")
@parameterized.expand(
list(
itertools.product(
["float32", "int32", "int16", "uint8"],
[8000, 16000],
[1, 2],
)
),
name_func=name_func,
)
def test_wav(self, dtype, sample_rate, num_channels):
"""save/load round trip should not degrade data for wav formats"""
original = get_wav_data(dtype, num_channels, normalize=False)
enc, bps = get_enc_params(dtype)
data = original
for i in range(10):
path = self.get_temp_path(f"{i}.wav")
self._save(path, data, sample_rate, encoding=enc, bits_per_sample=bps)
data, sr = self._load(path, normalize=False)
assert sr == sample_rate
self.assertEqual(original, data)
@parameterized.expand(
list(
itertools.product(
[8000, 16000],
[1, 2],
)
),
name_func=name_func,
)
def test_flac(self, sample_rate, num_channels):
"""save/load round trip should not degrade data for flac formats"""
original = get_wav_data("float32", num_channels)
data = original
for i in range(10):
path = self.get_temp_path(f"{i}.flac")
self._save(path, data, sample_rate)
data, sr = self._load(path)
assert sr == sample_rate
self.assertEqual(original, data)
import io
import os
from functools import partial
import torch
from parameterized import parameterized
from torchaudio._backend.utils import get_save_func
from torchaudio_unittest.common_utils import (
get_wav_data,
load_wav,
nested_params,
PytorchTestCase,
save_wav,
skipIfNoExec,
skipIfNoSox,
sox_utils,
TempDirMixin,
TorchaudioTestCase,
)
from .common import get_enc_params, name_func
def _get_sox_encoding(encoding):
encodings = {
"PCM_F": "floating-point",
"PCM_S": "signed-integer",
"PCM_U": "unsigned-integer",
"ULAW": "u-law",
"ALAW": "a-law",
}
return encodings.get(encoding)
class SaveTestBase(TempDirMixin, TorchaudioTestCase):
_save = partial(get_save_func(), backend="sox")
def assert_save_consistency(
self,
format: str,
*,
encoding: str = None,
bits_per_sample: int = None,
sample_rate: float = 8000,
num_channels: int = 2,
num_frames: float = 3 * 8000,
src_dtype: str = "int32",
test_mode: str = "path",
):
"""`save` function produces file that is comparable with `sox` command
To compare that the file produced by `save` function agains the file produced by
the equivalent `sox` command, we need to load both files.
But there are many formats that cannot be opened with common Python modules (like
SciPy).
So we use `sox` command to prepare the original data and convert the saved files
into a format that SciPy can read (PCM wav).
The following diagram illustrates this process. The difference is 2.1. and 3.1.
This assumes that
- loading data with SciPy preserves the data well.
- converting the resulting files into WAV format with `sox` preserve the data well.
x
| 1. Generate source wav file with SciPy
|
v
-------------- wav ----------------
| |
| 2.1. load with scipy | 3.1. Convert to the target
| then save it into the target | format depth with sox
| format with torchaudio |
v v
target format target format
| |
| 2.2. Convert to wav with sox | 3.2. Convert to wav with sox
| |
v v
wav wav
| |
| 2.3. load with scipy | 3.3. load with scipy
| |
v v
tensor -------> compare <--------- tensor
"""
cmp_encoding = "floating-point"
cmp_bit_depth = 32
src_path = self.get_temp_path("1.source.wav")
tgt_path = self.get_temp_path(f"2.1.torchaudio.{format}")
tst_path = self.get_temp_path("2.2.result.wav")
sox_path = self.get_temp_path(f"3.1.sox.{format}")
ref_path = self.get_temp_path("3.2.ref.wav")
# 1. Generate original wav
data = get_wav_data(src_dtype, num_channels, normalize=False, num_frames=num_frames)
save_wav(src_path, data, sample_rate)
# 2.1. Convert the original wav to target format with torchaudio
data = load_wav(src_path, normalize=False)[0]
if test_mode == "path":
self._save(tgt_path, data, sample_rate, encoding=encoding, bits_per_sample=bits_per_sample)
elif test_mode == "fileobj":
with open(tgt_path, "bw") as file_:
self._save(
file_,
data,
sample_rate,
format=format,
encoding=encoding,
bits_per_sample=bits_per_sample,
)
elif test_mode == "bytesio":
file_ = io.BytesIO()
self._save(
file_,
data,
sample_rate,
format=format,
encoding=encoding,
bits_per_sample=bits_per_sample,
)
file_.seek(0)
with open(tgt_path, "bw") as f:
f.write(file_.read())
else:
raise ValueError(f"Unexpected test mode: {test_mode}")
# 2.2. Convert the target format to wav with sox
sox_utils.convert_audio_file(tgt_path, tst_path, encoding=cmp_encoding, bit_depth=cmp_bit_depth)
# 2.3. Load with SciPy
found = load_wav(tst_path, normalize=False)[0]
# 3.1. Convert the original wav to target format with sox
sox_encoding = _get_sox_encoding(encoding)
sox_utils.convert_audio_file(src_path, sox_path, encoding=sox_encoding, bit_depth=bits_per_sample)
# 3.2. Convert the target format to wav with sox
sox_utils.convert_audio_file(sox_path, ref_path, encoding=cmp_encoding, bit_depth=cmp_bit_depth)
# 3.3. Load with SciPy
expected = load_wav(ref_path, normalize=False)[0]
self.assertEqual(found, expected)
@skipIfNoExec("sox")
@skipIfNoSox
class SaveTest(SaveTestBase):
@nested_params(
[
("PCM_U", 8),
("PCM_S", 16),
("PCM_S", 32),
("PCM_F", 32),
("PCM_F", 64),
("ULAW", 8),
("ALAW", 8),
],
)
def test_save_wav(self, enc_params):
encoding, bits_per_sample = enc_params
self.assert_save_consistency("wav", encoding=encoding, bits_per_sample=bits_per_sample, test_mode="path")
@nested_params(
[
("float32",),
("int32",),
("int16",),
("uint8",),
],
)
def test_save_wav_dtype(self, params):
(dtype,) = params
self.assert_save_consistency("wav", src_dtype=dtype, test_mode="path")
@nested_params(
[8, 16, 24],
)
def test_save_flac(self, bits_per_sample):
self.assert_save_consistency("flac", bits_per_sample=bits_per_sample, test_mode="path")
def test_save_htk(self):
self.assert_save_consistency("htk", test_mode="path", num_channels=1)
def test_save_vorbis(self):
self.assert_save_consistency("vorbis", test_mode="path")
@nested_params(
[
(
"PCM_S",
8,
),
(
"PCM_S",
16,
),
(
"PCM_S",
24,
),
(
"PCM_S",
32,
),
("ULAW", 8),
("ALAW", 8),
("ALAW", 16),
("ALAW", 24),
("ALAW", 32),
],
)
def test_save_sphere(self, enc_params):
encoding, bits_per_sample = enc_params
self.assert_save_consistency("sph", encoding=encoding, bits_per_sample=bits_per_sample, test_mode="path")
@nested_params(
[
(
"PCM_U",
8,
),
(
"PCM_S",
16,
),
(
"PCM_S",
24,
),
(
"PCM_S",
32,
),
(
"PCM_F",
32,
),
(
"PCM_F",
64,
),
(
"ULAW",
8,
),
(
"ALAW",
8,
),
],
)
def test_save_amb(self, enc_params):
encoding, bits_per_sample = enc_params
self.assert_save_consistency("amb", encoding=encoding, bits_per_sample=bits_per_sample, test_mode="path")
def test_save_amr_nb(self):
self.assert_save_consistency("amr-nb", num_channels=1, test_mode="path")
def test_save_gsm(self):
self.assert_save_consistency("gsm", num_channels=1, test_mode="path")
with self.assertRaises(RuntimeError, msg="gsm format only supports single channel audio."):
self.assert_save_consistency("gsm", num_channels=2, test_mode="path")
with self.assertRaises(RuntimeError, msg="gsm format only supports a sampling rate of 8kHz."):
self.assert_save_consistency("gsm", sample_rate=16000, test_mode="path")
@parameterized.expand(
[
("wav", "PCM_S", 16),
("flac",),
("vorbis",),
("sph", "PCM_S", 16),
("amr-nb",),
("amb", "PCM_S", 16),
],
name_func=name_func,
)
def test_save_large(self, format, encoding=None, bits_per_sample=None):
"""`self._save` can save large files."""
sample_rate = 8000
one_hour = 60 * 60 * sample_rate
self.assert_save_consistency(
format,
num_channels=1,
sample_rate=8000,
num_frames=one_hour,
encoding=encoding,
bits_per_sample=bits_per_sample,
)
@parameterized.expand(
[
(32,),
(64,),
(128,),
(256,),
],
name_func=name_func,
)
def test_save_multi_channels(self, num_channels):
"""`self._save` can save audio with many channels"""
self.assert_save_consistency("wav", encoding="PCM_S", bits_per_sample=16, num_channels=num_channels)
@skipIfNoExec("sox")
@skipIfNoSox
class TestSaveParams(TempDirMixin, PytorchTestCase):
"""Test the correctness of optional parameters of `self._save`"""
_save = partial(get_save_func(), backend="sox")
@parameterized.expand([(True,), (False,)], name_func=name_func)
def test_save_channels_first(self, channels_first):
"""channels_first swaps axes"""
path = self.get_temp_path("data.wav")
data = get_wav_data("int16", 2, channels_first=channels_first, normalize=False)
self._save(path, data, 8000, channels_first=channels_first)
found = load_wav(path, normalize=False)[0]
expected = data if channels_first else data.transpose(1, 0)
self.assertEqual(found, expected)
@parameterized.expand(["float32", "int32", "int16", "uint8"], name_func=name_func)
def test_save_noncontiguous(self, dtype):
"""Noncontiguous tensors are saved correctly"""
path = self.get_temp_path("data.wav")
enc, bps = get_enc_params(dtype)
expected = get_wav_data(dtype, 4, normalize=False)[::2, ::2]
assert not expected.is_contiguous()
self._save(path, expected, 8000, encoding=enc, bits_per_sample=bps)
found = load_wav(path, normalize=False)[0]
self.assertEqual(found, expected)
@parameterized.expand(
[
"float32",
"int32",
"int16",
"uint8",
]
)
def test_save_tensor_preserve(self, dtype):
"""save function should not alter Tensor"""
path = self.get_temp_path("data.wav")
expected = get_wav_data(dtype, 4, normalize=False)[::2, ::2]
data = expected.clone()
self._save(path, data, 8000)
self.assertEqual(data, expected)
@skipIfNoSox
class TestSaveNonExistingDirectory(PytorchTestCase):
_save = partial(get_save_func(), backend="sox")
def test_save_fail(self):
"""
When attempted to save into a non-existing dir, error message must contain the file path.
"""
path = os.path.join("non_existing_directory", "foo.wav")
with self.assertRaisesRegex(RuntimeError, path):
self._save(path, torch.zeros(1, 1), 8000)
import itertools
from functools import partial
from parameterized import parameterized
from torchaudio._backend.utils import get_info_func, get_load_func, get_save_func
from torchaudio_unittest.common_utils import get_wav_data, skipIfNoSox, TempDirMixin, TorchaudioTestCase
from .common import name_func
@skipIfNoSox
class SmokeTest(TempDirMixin, TorchaudioTestCase):
"""Run smoke test on various audio format
The purpose of this test suite is to verify that sox_io_backend functionalities do not exhibit
abnormal behaviors.
This test suite should be able to run without any additional tools (such as sox command),
however without such tools, the correctness of each function cannot be verified.
"""
_info = partial(get_info_func(), backend="sox")
_load = partial(get_load_func(), backend="sox")
_save = partial(get_save_func(), backend="sox")
def run_smoke_test(self, ext, sample_rate, num_channels, *, dtype="float32"):
duration = 1
num_frames = sample_rate * duration
path = self.get_temp_path(f"test.{ext}")
original = get_wav_data(dtype, num_channels, normalize=False, num_frames=num_frames)
# 1. run save
self._save(path, original, sample_rate)
# 2. run info
info = self._info(path)
assert info.sample_rate == sample_rate
assert info.num_channels == num_channels
# 3. run load
loaded, sr = self._load(path, normalize=False)
assert sr == sample_rate
assert loaded.shape[0] == num_channels
@parameterized.expand(
list(
itertools.product(
["float32", "int32", "int16", "uint8"],
[8000, 16000],
[1, 2],
)
),
name_func=name_func,
)
def test_wav(self, dtype, sample_rate, num_channels):
"""Run smoke test on wav format"""
self.run_smoke_test("wav", sample_rate, num_channels, dtype=dtype)
@parameterized.expand(
list(
itertools.product(
[8000, 16000],
[1, 2],
)
)
)
def test_vorbis(self, sample_rate, num_channels):
"""Run smoke test on vorbis format"""
self.run_smoke_test("vorbis", sample_rate, num_channels)
@parameterized.expand(
list(
itertools.product(
[8000, 16000],
[1, 2],
)
),
name_func=name_func,
)
def test_flac(self, sample_rate, num_channels):
"""Run smoke test on flac format"""
self.run_smoke_test("flac", sample_rate, num_channels)
......@@ -11,6 +11,7 @@ from torchaudio import ( # noqa: F401
transforms,
utils,
)
from torchaudio.backend import get_audio_backend, list_audio_backends, set_audio_backend
try:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment