Unverified Commit 56ab0368 authored by Joel Frank's avatar Joel Frank Committed by GitHub
Browse files

MFCC test refactor (#1618)

parent 0e513208
...@@ -127,41 +127,58 @@ class Tester(common_utils.TorchaudioTestCase): ...@@ -127,41 +127,58 @@ class Tester(common_utils.TorchaudioTestCase):
self.assertTrue(fb_matrix_transform.fb.sum(1).ge(0.).all()) self.assertTrue(fb_matrix_transform.fb.sum(1).ge(0.).all())
self.assertEqual(fb_matrix_transform.fb.size(), (400, 100)) self.assertEqual(fb_matrix_transform.fb.size(), (400, 100))
def test_mfcc(self): def test_mfcc_defaults(self):
audio_orig = self.waveform.clone() """Check the default configuration of the MFCC transform.
audio_scaled = self.scale(audio_orig) # (1, 16000) """
sample_rate = 16000 sample_rate = 16000
audio = common_utils.get_whitenoise(sample_rate=sample_rate)
n_mfcc = 40 n_mfcc = 40
n_mels = 128
mfcc_transform = torchaudio.transforms.MFCC(sample_rate=sample_rate, mfcc_transform = torchaudio.transforms.MFCC(sample_rate=sample_rate,
n_mfcc=n_mfcc, n_mfcc=n_mfcc,
norm='ortho') norm='ortho')
# check defaults torch_mfcc = mfcc_transform(audio) # (1, 40, 81)
torch_mfcc = mfcc_transform(audio_scaled) # (1, 40, 321) self.assertEqual(torch_mfcc.dim(), 3)
self.assertTrue(torch_mfcc.dim() == 3) self.assertEqual(torch_mfcc.shape[1], n_mfcc)
self.assertTrue(torch_mfcc.shape[1] == n_mfcc) self.assertEqual(torch_mfcc.shape[2], 81)
self.assertTrue(torch_mfcc.shape[2] == 321)
# check melkwargs are passed through def test_mfcc_kwargs_passthrough(self):
"""Check kwargs get correctly passed to the MelSpectrogram transform.
"""
sample_rate = 16000
audio = common_utils.get_whitenoise(sample_rate=sample_rate)
n_mfcc = 40
melkwargs = {'win_length': 200} melkwargs = {'win_length': 200}
mfcc_transform2 = torchaudio.transforms.MFCC(sample_rate=sample_rate, mfcc_transform = torchaudio.transforms.MFCC(sample_rate=sample_rate,
n_mfcc=n_mfcc, n_mfcc=n_mfcc,
norm='ortho', norm='ortho',
melkwargs=melkwargs) melkwargs=melkwargs)
torch_mfcc2 = mfcc_transform2(audio_scaled) # (1, 40, 641) torch_mfcc = mfcc_transform(audio) # (1, 40, 161)
self.assertTrue(torch_mfcc2.shape[2] == 641) self.assertEqual(torch_mfcc.shape[2], 161)
def test_mfcc_norms(self):
"""Check if MFCC-DCT norms work correctly.
"""
sample_rate = 16000
audio = common_utils.get_whitenoise(sample_rate=sample_rate)
n_mfcc = 40
n_mels = 128
mfcc_transform = torchaudio.transforms.MFCC(sample_rate=sample_rate,
n_mfcc=n_mfcc,
norm='ortho')
# check norms work correctly # check norms work correctly
mfcc_transform_norm_none = torchaudio.transforms.MFCC(sample_rate=sample_rate, mfcc_transform_norm_none = torchaudio.transforms.MFCC(sample_rate=sample_rate,
n_mfcc=n_mfcc, n_mfcc=n_mfcc,
norm=None) norm=None)
torch_mfcc_norm_none = mfcc_transform_norm_none(audio_scaled) # (1, 40, 321) torch_mfcc_norm_none = mfcc_transform_norm_none(audio) # (1, 40, 81)
norm_check = torch_mfcc.clone() norm_check = mfcc_transform(audio)
norm_check[:, 0, :] *= math.sqrt(n_mels) * 2 norm_check[:, 0, :] *= math.sqrt(n_mels) * 2
norm_check[:, 1:, :] *= math.sqrt(n_mels / 2) * 2 norm_check[:, 1:, :] *= math.sqrt(n_mels / 2) * 2
self.assertTrue(torch_mfcc_norm_none.allclose(norm_check)) self.assertEqual(torch_mfcc_norm_none, norm_check)
def test_resample_size(self): def test_resample_size(self):
input_path = common_utils.get_asset_path('sinewave.wav') input_path = common_utils.get_asset_path('sinewave.wav')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment