Unverified Commit 7de5f98c authored by Jcaw's avatar Jcaw Committed by GitHub
Browse files

Parameterize librosa compatibility test (#1350)

Parameterize `test_create_fb` so each set of values are tested
independently. Also explicitly skip on older versions of librosa (< 0.7.2) when
`norm="slaney"`.
parent f4589714
...@@ -3,7 +3,7 @@ import unittest ...@@ -3,7 +3,7 @@ import unittest
from distutils.version import StrictVersion from distutils.version import StrictVersion
import torch import torch
from parameterized import parameterized from parameterized import parameterized, param
import torchaudio.functional as F import torchaudio.functional as F
from torchaudio._internal.module_utils import is_module_available from torchaudio._internal.module_utils import is_module_available
...@@ -46,15 +46,25 @@ class TestFunctional(common_utils.TorchaudioTestCase): ...@@ -46,15 +46,25 @@ class TestFunctional(common_utils.TorchaudioTestCase):
self.assertEqual(ta_out, lr_out, atol=5e-5, rtol=1e-5) self.assertEqual(ta_out, lr_out, atol=5e-5, rtol=1e-5)
def _test_create_fb( @parameterized.expand([
self, n_mels=40, param(norm=norm, mel_scale=mel_scale, **p.kwargs)
sample_rate=22050, for p in [
n_fft=2048, param(),
fmin=0.0, param(n_mels=128, sample_rate=44100),
fmax=8000.0, param(n_mels=128, fmin=2000.0, fmax=5000.0),
norm=None, param(n_mels=56, fmin=100.0, fmax=9000.0),
mel_scale="htk", param(n_mels=56, fmin=800.0, fmax=900.0),
): param(n_mels=56, fmin=1900.0, fmax=900.0),
param(n_mels=10, fmin=1900.0, fmax=900.0),
]
for norm in [None, 'slaney']
for mel_scale in ['htk', 'slaney']
])
def test_create_fb(self, n_mels=40, sample_rate=22050, n_fft=2048,
fmin=0.0, fmax=8000.0, norm=None, mel_scale="htk"):
if (norm == "slaney" and StrictVersion(librosa.__version__) < StrictVersion("0.7.2")):
self.skipTest('Test is known to fail with older versions of librosa.')
librosa_fb = librosa.filters.mel(sr=sample_rate, librosa_fb = librosa.filters.mel(sr=sample_rate,
n_fft=n_fft, n_fft=n_fft,
n_mels=n_mels, n_mels=n_mels,
...@@ -74,30 +84,6 @@ class TestFunctional(common_utils.TorchaudioTestCase): ...@@ -74,30 +84,6 @@ class TestFunctional(common_utils.TorchaudioTestCase):
self.assertEqual( self.assertEqual(
fb[:, i_mel_bank], torch.tensor(librosa_fb[i_mel_bank]), atol=1e-4, rtol=1e-5) fb[:, i_mel_bank], torch.tensor(librosa_fb[i_mel_bank]), atol=1e-4, rtol=1e-5)
def test_create_fb(self):
self._test_create_fb()
self._test_create_fb(n_mels=128, sample_rate=44100)
self._test_create_fb(n_mels=128, fmin=2000.0, fmax=5000.0)
self._test_create_fb(n_mels=56, fmin=100.0, fmax=9000.0)
self._test_create_fb(n_mels=56, fmin=800.0, fmax=900.0)
self._test_create_fb(n_mels=56, fmin=1900.0, fmax=900.0)
self._test_create_fb(n_mels=10, fmin=1900.0, fmax=900.0)
self._test_create_fb(mel_scale="slaney")
self._test_create_fb(n_mels=128, sample_rate=44100, mel_scale="slaney")
self._test_create_fb(n_mels=128, fmin=2000.0, fmax=5000.0, mel_scale="slaney")
self._test_create_fb(n_mels=56, fmin=100.0, fmax=9000.0, mel_scale="slaney")
self._test_create_fb(n_mels=56, fmin=800.0, fmax=900.0, mel_scale="slaney")
self._test_create_fb(n_mels=56, fmin=1900.0, fmax=900.0, mel_scale="slaney")
self._test_create_fb(n_mels=10, fmin=1900.0, fmax=900.0, mel_scale="slaney")
if StrictVersion(librosa.__version__) < StrictVersion("0.7.2"):
return
self._test_create_fb(n_mels=128, sample_rate=44100, norm="slaney")
self._test_create_fb(n_mels=128, fmin=2000.0, fmax=5000.0, norm="slaney")
self._test_create_fb(n_mels=56, fmin=100.0, fmax=9000.0, norm="slaney")
self._test_create_fb(n_mels=56, fmin=800.0, fmax=900.0, norm="slaney")
self._test_create_fb(n_mels=56, fmin=1900.0, fmax=900.0, norm="slaney")
self._test_create_fb(n_mels=10, fmin=1900.0, fmax=900.0, norm="slaney")
def test_amplitude_to_DB(self): def test_amplitude_to_DB(self):
spec = torch.rand((6, 201)) spec = torch.rand((6, 201))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment