Unverified commit 56ab0368, authored by Joel Frank, committed by GitHub
Browse files

MFCC test refactor (#1618)

parent 0e513208
......@@ -127,41 +127,58 @@ class Tester(common_utils.TorchaudioTestCase):
self.assertTrue(fb_matrix_transform.fb.sum(1).ge(0.).all())
self.assertEqual(fb_matrix_transform.fb.size(), (400, 100))
# NOTE(review): indentation has been lost in this view, so where this method's
# body ends cannot be determined from here -- TODO confirm against the real file.
def test_mfcc(self):
# Clone the fixture waveform and hand the copy to ``self.scale``.
audio_orig = self.waveform.clone()
# The trailing shape note is the author's; it is not verifiable from this view.
audio_scaled = self.scale(audio_orig) # (1, 16000)
def test_mfcc_defaults(self):
    """Check the default configuration of the MFCC transform.

    White noise generated at 16 kHz is passed through an MFCC transform
    that only overrides ``n_mfcc`` and ``norm``.  The result must be a
    3-D tensor with ``n_mfcc`` coefficients and 81 frames.
    """
    sample_rate = 16000
    audio = common_utils.get_whitenoise(sample_rate=sample_rate)

    n_mfcc = 40
    mfcc_transform = torchaudio.transforms.MFCC(sample_rate=sample_rate,
                                                n_mfcc=n_mfcc,
                                                norm='ortho')

    # check defaults
    torch_mfcc = mfcc_transform(audio)  # (1, 40, 81)
    # assertEqual (not assertTrue(a == b)) so a failure reports both values.
    self.assertEqual(torch_mfcc.dim(), 3)
    self.assertEqual(torch_mfcc.shape[1], n_mfcc)
    self.assertEqual(torch_mfcc.shape[2], 81)
def test_mfcc_kwargs_passthrough(self):
    """Check kwargs get correctly passed to the MelSpectrogram transform.

    The transform is built with ``melkwargs={'win_length': 200}`` and the
    number of output frames is checked against the expected value of 161.
    """
    sample_rate = 16000
    audio = common_utils.get_whitenoise(sample_rate=sample_rate)

    n_mfcc = 40
    melkwargs = {'win_length': 200}
    mfcc_transform = torchaudio.transforms.MFCC(sample_rate=sample_rate,
                                                n_mfcc=n_mfcc,
                                                norm='ortho',
                                                melkwargs=melkwargs)

    # NOTE(review): presumably the frame count depends on win_length because
    # the hop length is derived from it when not given explicitly -- confirm
    # against the MelSpectrogram documentation.
    torch_mfcc = mfcc_transform(audio)  # (1, 40, 161)
    self.assertEqual(torch_mfcc.shape[2], 161)
def test_mfcc_norms(self):
    """Check if MFCC-DCT norms work correctly.

    The same white-noise input is run through one transform built with
    ``norm='ortho'`` and one built with ``norm=None``.  The 'ortho' output
    is rescaled per coefficient and must then equal the ``norm=None``
    output.
    """
    sample_rate = 16000
    audio = common_utils.get_whitenoise(sample_rate=sample_rate)

    n_mfcc = 40
    # assumes this matches the number of mel bins the transform uses by
    # default -- TODO confirm against the MFCC documentation
    n_mels = 128
    mfcc_transform = torchaudio.transforms.MFCC(sample_rate=sample_rate,
                                                n_mfcc=n_mfcc,
                                                norm='ortho')

    # check norms work correctly
    mfcc_transform_norm_none = torchaudio.transforms.MFCC(sample_rate=sample_rate,
                                                          n_mfcc=n_mfcc,
                                                          norm=None)
    torch_mfcc_norm_none = mfcc_transform_norm_none(audio)  # (1, 40, 81)

    # Rescale the 'ortho' result: coefficient 0 and the remaining
    # coefficients use different factors.
    norm_check = mfcc_transform(audio)
    norm_check[:, 0, :] *= math.sqrt(n_mels) * 2
    norm_check[:, 1:, :] *= math.sqrt(n_mels / 2) * 2

    self.assertEqual(torch_mfcc_norm_none, norm_check)
def test_resample_size(self):
input_path = common_utils.get_asset_path('sinewave.wav')
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment