Unverified Commit daa0007a authored by moto's avatar moto Committed by GitHub
Browse files

Add smoke tests to sox_io and sox_effects (#806)

Currently all the tests in `sox_io_backend` and `sox_effects` (for new SoX effects implementation) requires additional `sox`, and this prevents running test in environment where `sox` command is not available even though `torchaudio` extension is available (such as fb internal). This PR adds smoke tests for these modules, which just runs functions to see if they do not crash.
parent 3781cb23
......@@ -42,7 +42,7 @@
{"effects": [["fade", "l", "3"]]}
{"effects": [["fade", "p", "3"]]}
{"effects": [["fir", "0.0195", "-0.082", "0.234", "0.891", "-0.145", "0.043"]]}
{"effects": [["fir", "test/assets/sox_effect_test_fir_coeffs.txt"]]}
{"effects": [["fir", "<ASSET_DIR>/sox_effect_test_fir_coeffs.txt"]]}
{"effects": [["flanger"]]}
{"effects": [["gain", "-n"]]}
{"effects": [["gain", "-n", "-3"]]}
......
import json
from parameterized import param
from ..common_utils import get_asset_path
def name_func(func, _, params):
if isinstance(params.args[0], str):
args = "_".join([str(arg) for arg in params.args])
else:
args = "_".join([str(arg) for arg in params.args[0]])
return f'{func.__name__}_{args}'
def load_params(*paths):
params = []
with open(get_asset_path(*paths), 'r') as file:
for line in file:
data = json.loads(line)
for effect in data['effects']:
for i, arg in enumerate(effect):
if arg.startswith("<ASSET_DIR>"):
effect[i] = arg.replace("<ASSET_DIR>", get_asset_path())
params.append(param(data))
return params
from torchaudio import sox_effects
from parameterized import parameterized
from ..common_utils import (
TempDirMixin,
TorchaudioTestCase,
skipIfNoExtension,
get_wav_data,
get_sinusoid,
save_wav,
)
from .common import (
name_func,
load_params,
)
@skipIfNoExtension
class SmokeTest(TempDirMixin, TorchaudioTestCase):
"""Run smoke test on various effects
The purpose of this test suite is to verify that sox_effect functionalities do not exhibit
abnormal behaviors.
This test suite should be able to run without any additional tools (such as sox command),
however without such tools, the correctness of each function cannot be verified.
"""
@parameterized.expand(
load_params("sox_effect_test_args.json"),
name_func=lambda f, i, p: f'{f.__name__}_{i}_{p.args[0]["effects"][0][0]}',
)
def test_apply_effects_tensor(self, args):
"""`apply_effects_tensor` should not crash"""
effects = args['effects']
num_channels = args.get("num_channels", 2)
input_sr = args.get("input_sample_rate", 8000)
original = get_sinusoid(
frequency=800, sample_rate=input_sr,
n_channels=num_channels, dtype='float32')
_found, _sr = sox_effects.apply_effects_tensor(original, input_sr, effects)
@parameterized.expand(
load_params("sox_effect_test_args.json"),
name_func=lambda f, i, p: f'{f.__name__}_{i}_{p.args[0]["effects"][0][0]}',
)
def test_apply_effects(self, args):
"""`apply_effects_file` should return identical data as sox command"""
dtype = 'int32'
channels_first = True
effects = args['effects']
num_channels = args.get("num_channels", 2)
input_sr = args.get("input_sample_rate", 8000)
input_path = self.get_temp_path('input.wav')
data = get_wav_data(dtype, num_channels, channels_first=channels_first)
save_wav(input_path, data, input_sr, channels_first=channels_first)
_found, _sr = sox_effects.apply_effects_file(
input_path, effects, normalize=False, channels_first=channels_first)
......@@ -9,7 +9,6 @@ from ..common_utils import (
PytorchTestCase,
skipIfNoExtension,
get_whitenoise,
load_wav,
save_wav,
)
......
......@@ -11,10 +11,10 @@ from ..common_utils import (
get_wav_data,
save_wav,
load_wav,
load_params,
sox_utils,
)
from .common import (
load_params,
name_func,
)
......
......@@ -9,9 +9,11 @@ from ..common_utils import (
PytorchTestCase,
skipIfNoExtension,
get_sinusoid,
load_params,
save_wav,
)
from .common import (
load_params,
)
class SoxEffectTensorTransform(torch.nn.Module):
......
import itertools
from torchaudio.backend import sox_io_backend
from parameterized import parameterized
from ..common_utils import (
TempDirMixin,
TorchaudioTestCase,
skipIfNoExtension,
get_wav_data,
)
from .common import name_func
@skipIfNoExtension
class SmokeTest(TempDirMixin, TorchaudioTestCase):
"""Run smoke test on various audio format
The purpose of this test suite is to verify that sox_io_backend functionalities do not exhibit
abnormal behaviors.
This test suite should be able to run without any additional tools (such as sox command),
however without such tools, the correctness of each function cannot be verified.
"""
def run_smoke_test(self, ext, sample_rate, num_channels, *, compression=None, dtype='float32'):
duration = 1
num_frames = sample_rate * duration
path = self.get_temp_path(f'test.{ext}')
original = get_wav_data(dtype, num_channels, normalize=False, num_frames=num_frames)
# 1. run save
sox_io_backend.save(path, original, sample_rate, compression=compression)
# 2. run info
info = sox_io_backend.info(path)
assert info.sample_rate == sample_rate
assert info.num_channels == num_channels
# 3. run load
loaded, sr = sox_io_backend.load(path, normalize=False)
assert sr == sample_rate
assert loaded.shape[0] == num_channels
@parameterized.expand(list(itertools.product(
['float32', 'int32', 'int16', 'uint8'],
[8000, 16000],
[1, 2],
)), name_func=name_func)
def test_wav(self, dtype, sample_rate, num_channels):
"""Run smoke test on wav format"""
self.run_smoke_test('wav', sample_rate, num_channels, dtype=dtype)
@parameterized.expand(list(itertools.product(
[8000, 16000],
[1, 2],
[-4.2, -0.2, 0, 0.2, 96, 128, 160, 192, 224, 256, 320],
)), name_func=name_func)
def test_mp3(self, sample_rate, num_channels, bit_rate):
"""Run smoke test on mp3 format"""
self.run_smoke_test('mp3', sample_rate, num_channels, compression=bit_rate)
@parameterized.expand(list(itertools.product(
[8000, 16000],
[1, 2],
[-1, 0, 1, 2, 3, 3.6, 5, 10],
)), name_func=name_func)
def test_vorbis(self, sample_rate, num_channels, quality_level):
"""Run smoke test on vorbis format"""
self.run_smoke_test('vorbis', sample_rate, num_channels, compression=quality_level)
@parameterized.expand(list(itertools.product(
[8000, 16000],
[1, 2],
list(range(9)),
)), name_func=name_func)
def test_flac(self, sample_rate, num_channels, compression_level):
"""Run smoke test on flac format"""
self.run_smoke_test('flac', sample_rate, num_channels, compression=compression_level)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment