Unverified Commit 7de5f98c authored by Jcaw's avatar Jcaw Committed by GitHub
Browse files

Parameterize librosa compatibility test (#1350)

Parameterize `test_create_fb` so each set of values are tested
independently. Also explicitly skip on older versions of librosa (< 0.7.2) when
`norm="slaney"`.
parent f4589714
......@@ -3,7 +3,7 @@ import unittest
from distutils.version import StrictVersion
import torch
from parameterized import parameterized
from parameterized import parameterized, param
import torchaudio.functional as F
from torchaudio._internal.module_utils import is_module_available
......@@ -46,15 +46,25 @@ class TestFunctional(common_utils.TorchaudioTestCase):
self.assertEqual(ta_out, lr_out, atol=5e-5, rtol=1e-5)
def _test_create_fb(
self, n_mels=40,
sample_rate=22050,
n_fft=2048,
fmin=0.0,
fmax=8000.0,
norm=None,
mel_scale="htk",
):
@parameterized.expand([
param(norm=norm, mel_scale=mel_scale, **p.kwargs)
for p in [
param(),
param(n_mels=128, sample_rate=44100),
param(n_mels=128, fmin=2000.0, fmax=5000.0),
param(n_mels=56, fmin=100.0, fmax=9000.0),
param(n_mels=56, fmin=800.0, fmax=900.0),
param(n_mels=56, fmin=1900.0, fmax=900.0),
param(n_mels=10, fmin=1900.0, fmax=900.0),
]
for norm in [None, 'slaney']
for mel_scale in ['htk', 'slaney']
])
def test_create_fb(self, n_mels=40, sample_rate=22050, n_fft=2048,
fmin=0.0, fmax=8000.0, norm=None, mel_scale="htk"):
if (norm == "slaney" and StrictVersion(librosa.__version__) < StrictVersion("0.7.2")):
self.skipTest('Test is known to fail with older versions of librosa.')
librosa_fb = librosa.filters.mel(sr=sample_rate,
n_fft=n_fft,
n_mels=n_mels,
......@@ -74,30 +84,6 @@ class TestFunctional(common_utils.TorchaudioTestCase):
self.assertEqual(
fb[:, i_mel_bank], torch.tensor(librosa_fb[i_mel_bank]), atol=1e-4, rtol=1e-5)
def test_create_fb(self):
self._test_create_fb()
self._test_create_fb(n_mels=128, sample_rate=44100)
self._test_create_fb(n_mels=128, fmin=2000.0, fmax=5000.0)
self._test_create_fb(n_mels=56, fmin=100.0, fmax=9000.0)
self._test_create_fb(n_mels=56, fmin=800.0, fmax=900.0)
self._test_create_fb(n_mels=56, fmin=1900.0, fmax=900.0)
self._test_create_fb(n_mels=10, fmin=1900.0, fmax=900.0)
self._test_create_fb(mel_scale="slaney")
self._test_create_fb(n_mels=128, sample_rate=44100, mel_scale="slaney")
self._test_create_fb(n_mels=128, fmin=2000.0, fmax=5000.0, mel_scale="slaney")
self._test_create_fb(n_mels=56, fmin=100.0, fmax=9000.0, mel_scale="slaney")
self._test_create_fb(n_mels=56, fmin=800.0, fmax=900.0, mel_scale="slaney")
self._test_create_fb(n_mels=56, fmin=1900.0, fmax=900.0, mel_scale="slaney")
self._test_create_fb(n_mels=10, fmin=1900.0, fmax=900.0, mel_scale="slaney")
if StrictVersion(librosa.__version__) < StrictVersion("0.7.2"):
return
self._test_create_fb(n_mels=128, sample_rate=44100, norm="slaney")
self._test_create_fb(n_mels=128, fmin=2000.0, fmax=5000.0, norm="slaney")
self._test_create_fb(n_mels=56, fmin=100.0, fmax=9000.0, norm="slaney")
self._test_create_fb(n_mels=56, fmin=800.0, fmax=900.0, norm="slaney")
self._test_create_fb(n_mels=56, fmin=1900.0, fmax=900.0, norm="slaney")
self._test_create_fb(n_mels=10, fmin=1900.0, fmax=900.0, norm="slaney")
def test_amplitude_to_DB(self):
spec = torch.rand((6, 201))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment