Unverified Commit 93cc6da7 authored by moto's avatar moto Committed by GitHub
Browse files

Adopt PyTorch's test util to librosa compatibilities test (#646)

parent 6fc8953c
......@@ -3,6 +3,7 @@ import os
import unittest
import torch
from torch.testing._internal.common_utils import TestCase
import torchaudio
import torchaudio.functional as F
from torchaudio.common_utils import IMPORT_LIBROSA
......@@ -17,15 +18,8 @@ import pytest
import common_utils
class _LibrosaMixin:
"""Automatically skip tests if librosa is not available"""
def setUp(self):
super().setUp()
if not IMPORT_LIBROSA:
raise unittest.SkipTest('Librosa not available')
class TestFunctional(_LibrosaMixin, unittest.TestCase):
@unittest.skipIf(not IMPORT_LIBROSA, "Librosa not available")
class TestFunctional(TestCase):
"""Test suite for functions in `functional` module."""
def test_griffinlim(self):
# NOTE: This test is flaky without a fixed random seed
......@@ -51,7 +45,7 @@ class TestFunctional(_LibrosaMixin, unittest.TestCase):
momentum=momentum, init=init, length=length)
lr_out = torch.from_numpy(lr_out).unsqueeze(0)
torch.testing.assert_allclose(ta_out, lr_out, atol=5e-5, rtol=1e-5)
self.assertEqual(ta_out, lr_out, atol=5e-5, rtol=1e-5)
def _test_create_fb(self, n_mels=40, sample_rate=22050, n_fft=2048, fmin=0.0, fmax=8000.0, norm=None):
librosa_fb = librosa.filters.mel(sr=sample_rate,
......@@ -69,8 +63,8 @@ class TestFunctional(_LibrosaMixin, unittest.TestCase):
norm=norm)
for i_mel_bank in range(n_mels):
torch.testing.assert_allclose(fb[:, i_mel_bank], torch.tensor(librosa_fb[i_mel_bank]),
atol=1e-4, rtol=1e-5)
self.assertEqual(
fb[:, i_mel_bank], torch.tensor(librosa_fb[i_mel_bank]), atol=1e-4, rtol=1e-5)
def test_create_fb(self):
self._test_create_fb()
......@@ -101,7 +95,7 @@ class TestFunctional(_LibrosaMixin, unittest.TestCase):
lr_out = librosa.core.power_to_db(spec.numpy())
lr_out = torch.from_numpy(lr_out)
torch.testing.assert_allclose(ta_out, lr_out, atol=5e-5, rtol=1e-5)
self.assertEqual(ta_out, lr_out, atol=5e-5, rtol=1e-5)
# Amplitude to DB
multiplier = 20.0
......@@ -110,7 +104,7 @@ class TestFunctional(_LibrosaMixin, unittest.TestCase):
lr_out = librosa.core.amplitude_to_db(spec.numpy())
lr_out = torch.from_numpy(lr_out)
torch.testing.assert_allclose(ta_out, lr_out, atol=5e-5, rtol=1e-5)
self.assertEqual(ta_out, lr_out, atol=5e-5, rtol=1e-5)
@pytest.mark.parametrize('complex_specgrams', [
......@@ -161,73 +155,73 @@ def _load_audio_asset(*asset_paths, **kwargs):
return sound, sample_rate
def _test_compatibilities(n_fft, hop_length, power, n_mels, n_mfcc, sample_rate):
sound, sample_rate = _load_audio_asset('sinewave.wav')
sound_librosa = sound.cpu().numpy().squeeze() # (64000)
# test core spectrogram
spect_transform = torchaudio.transforms.Spectrogram(
n_fft=n_fft, hop_length=hop_length, power=power)
out_librosa, _ = librosa.core.spectrum._spectrogram(
y=sound_librosa, n_fft=n_fft, hop_length=hop_length, power=power)
out_torch = spect_transform(sound).squeeze().cpu()
torch.testing.assert_allclose(out_torch, torch.from_numpy(out_librosa), atol=1e-5, rtol=1e-5)
# test mel spectrogram
melspect_transform = torchaudio.transforms.MelSpectrogram(
sample_rate=sample_rate, window_fn=torch.hann_window,
hop_length=hop_length, n_mels=n_mels, n_fft=n_fft)
librosa_mel = librosa.feature.melspectrogram(
y=sound_librosa, sr=sample_rate, n_fft=n_fft,
hop_length=hop_length, n_mels=n_mels, htk=True, norm=None)
librosa_mel_tensor = torch.from_numpy(librosa_mel)
torch_mel = melspect_transform(sound).squeeze().cpu()
torch.testing.assert_allclose(
torch_mel.type(librosa_mel_tensor.dtype), librosa_mel_tensor, atol=5e-3, rtol=1e-5)
# test s2db
power_to_db_transform = torchaudio.transforms.AmplitudeToDB('power', 80.)
power_to_db_torch = power_to_db_transform(spect_transform(sound)).squeeze().cpu()
power_to_db_librosa = librosa.core.spectrum.power_to_db(out_librosa)
torch.testing.assert_allclose(power_to_db_torch, torch.from_numpy(power_to_db_librosa), atol=5e-3, rtol=1e-5)
mag_to_db_transform = torchaudio.transforms.AmplitudeToDB('magnitude', 80.)
mag_to_db_torch = mag_to_db_transform(torch.abs(sound)).squeeze().cpu()
mag_to_db_librosa = librosa.core.spectrum.amplitude_to_db(sound_librosa)
torch.testing.assert_allclose(mag_to_db_torch, torch.from_numpy(mag_to_db_librosa), atol=5e-3, rtol=1e-5)
power_to_db_torch = power_to_db_transform(melspect_transform(sound)).squeeze().cpu()
db_librosa = librosa.core.spectrum.power_to_db(librosa_mel)
db_librosa_tensor = torch.from_numpy(db_librosa)
torch.testing.assert_allclose(
power_to_db_torch.type(db_librosa_tensor.dtype), db_librosa_tensor, atol=5e-3, rtol=1e-5)
# test MFCC
melkwargs = {'hop_length': hop_length, 'n_fft': n_fft}
mfcc_transform = torchaudio.transforms.MFCC(
sample_rate=sample_rate, n_mfcc=n_mfcc, norm='ortho', melkwargs=melkwargs)
# librosa.feature.mfcc doesn't pass kwargs properly since some of the
# kwargs for melspectrogram and mfcc are the same. We just follow the
# function body in
# https://librosa.github.io/librosa/_modules/librosa/feature/spectral.html#melspectrogram
# to mirror this function call with correct args:
#
# librosa_mfcc = librosa.feature.mfcc(
# y=sound_librosa, sr=sample_rate, n_mfcc = n_mfcc,
# hop_length=hop_length, n_fft=n_fft, htk=True, norm=None, n_mels=n_mels)
librosa_mfcc = scipy.fftpack.dct(db_librosa, axis=0, type=2, norm='ortho')[:n_mfcc]
librosa_mfcc_tensor = torch.from_numpy(librosa_mfcc)
torch_mfcc = mfcc_transform(sound).squeeze().cpu()
torch.testing.assert_allclose(
torch_mfcc.type(librosa_mfcc_tensor.dtype), librosa_mfcc_tensor, atol=5e-3, rtol=1e-5)
class TestTransforms(_LibrosaMixin, unittest.TestCase):
@unittest.skipIf(not IMPORT_LIBROSA, "Librosa not available")
class TestTransforms(TestCase):
"""Test suite for functions in `transforms` module."""
def assert_compatibilities(self, n_fft, hop_length, power, n_mels, n_mfcc, sample_rate):
sound, sample_rate = _load_audio_asset('sinewave.wav')
sound_librosa = sound.cpu().numpy().squeeze() # (64000)
# test core spectrogram
spect_transform = torchaudio.transforms.Spectrogram(
n_fft=n_fft, hop_length=hop_length, power=power)
out_librosa, _ = librosa.core.spectrum._spectrogram(
y=sound_librosa, n_fft=n_fft, hop_length=hop_length, power=power)
out_torch = spect_transform(sound).squeeze().cpu()
self.assertEqual(out_torch, torch.from_numpy(out_librosa), atol=1e-5, rtol=1e-5)
# test mel spectrogram
melspect_transform = torchaudio.transforms.MelSpectrogram(
sample_rate=sample_rate, window_fn=torch.hann_window,
hop_length=hop_length, n_mels=n_mels, n_fft=n_fft)
librosa_mel = librosa.feature.melspectrogram(
y=sound_librosa, sr=sample_rate, n_fft=n_fft,
hop_length=hop_length, n_mels=n_mels, htk=True, norm=None)
librosa_mel_tensor = torch.from_numpy(librosa_mel)
torch_mel = melspect_transform(sound).squeeze().cpu()
self.assertEqual(
torch_mel.type(librosa_mel_tensor.dtype), librosa_mel_tensor, atol=5e-3, rtol=1e-5)
# test s2db
power_to_db_transform = torchaudio.transforms.AmplitudeToDB('power', 80.)
power_to_db_torch = power_to_db_transform(spect_transform(sound)).squeeze().cpu()
power_to_db_librosa = librosa.core.spectrum.power_to_db(out_librosa)
self.assertEqual(power_to_db_torch, torch.from_numpy(power_to_db_librosa), atol=5e-3, rtol=1e-5)
mag_to_db_transform = torchaudio.transforms.AmplitudeToDB('magnitude', 80.)
mag_to_db_torch = mag_to_db_transform(torch.abs(sound)).squeeze().cpu()
mag_to_db_librosa = librosa.core.spectrum.amplitude_to_db(sound_librosa)
self.assertEqual(mag_to_db_torch, torch.from_numpy(mag_to_db_librosa), atol=5e-3, rtol=1e-5)
power_to_db_torch = power_to_db_transform(melspect_transform(sound)).squeeze().cpu()
db_librosa = librosa.core.spectrum.power_to_db(librosa_mel)
db_librosa_tensor = torch.from_numpy(db_librosa)
self.assertEqual(
power_to_db_torch.type(db_librosa_tensor.dtype), db_librosa_tensor, atol=5e-3, rtol=1e-5)
# test MFCC
melkwargs = {'hop_length': hop_length, 'n_fft': n_fft}
mfcc_transform = torchaudio.transforms.MFCC(
sample_rate=sample_rate, n_mfcc=n_mfcc, norm='ortho', melkwargs=melkwargs)
# librosa.feature.mfcc doesn't pass kwargs properly since some of the
# kwargs for melspectrogram and mfcc are the same. We just follow the
# function body in
# https://librosa.github.io/librosa/_modules/librosa/feature/spectral.html#melspectrogram
# to mirror this function call with correct args:
#
# librosa_mfcc = librosa.feature.mfcc(
# y=sound_librosa, sr=sample_rate, n_mfcc = n_mfcc,
# hop_length=hop_length, n_fft=n_fft, htk=True, norm=None, n_mels=n_mels)
librosa_mfcc = scipy.fftpack.dct(db_librosa, axis=0, type=2, norm='ortho')[:n_mfcc]
librosa_mfcc_tensor = torch.from_numpy(librosa_mfcc)
torch_mfcc = mfcc_transform(sound).squeeze().cpu()
self.assertEqual(
torch_mfcc.type(librosa_mfcc_tensor.dtype), librosa_mfcc_tensor, atol=5e-3, rtol=1e-5)
def test_basics1(self):
kwargs = {
'n_fft': 400,
......@@ -237,7 +231,7 @@ class TestTransforms(_LibrosaMixin, unittest.TestCase):
'n_mfcc': 40,
'sample_rate': 16000
}
_test_compatibilities(**kwargs)
self.assert_compatibilities(**kwargs)
def test_basics2(self):
kwargs = {
......@@ -248,7 +242,7 @@ class TestTransforms(_LibrosaMixin, unittest.TestCase):
'n_mfcc': 20,
'sample_rate': 16000
}
_test_compatibilities(**kwargs)
self.assert_compatibilities(**kwargs)
# NOTE: Test passes offline, but fails on TravisCI (and CircleCI), see #372.
@unittest.skipIf('CI' in os.environ, 'Test is known to fail on CI')
......@@ -261,7 +255,7 @@ class TestTransforms(_LibrosaMixin, unittest.TestCase):
'n_mfcc': 50,
'sample_rate': 24000
}
_test_compatibilities(**kwargs)
self.assert_compatibilities(**kwargs)
def test_basics4(self):
kwargs = {
......@@ -272,7 +266,7 @@ class TestTransforms(_LibrosaMixin, unittest.TestCase):
'n_mfcc': 40,
'sample_rate': 16000
}
_test_compatibilities(**kwargs)
self.assert_compatibilities(**kwargs)
@unittest.skipIf("sox" not in common_utils.BACKENDS, "sox not available")
@common_utils.AudioBackendScope("sox")
......@@ -295,7 +289,7 @@ class TestTransforms(_LibrosaMixin, unittest.TestCase):
S=spec_lr, sr=sample_rate, n_fft=n_fft, hop_length=hop_length,
win_length=n_fft, center=True, window='hann', n_mels=n_mels, htk=True, norm=None)
# Note: Using relaxed rtol instead of atol
torch.testing.assert_allclose(melspec_ta, torch.from_numpy(melspec_lr[None, ...]), atol=1e-8, rtol=1e-3)
self.assertEqual(melspec_ta, torch.from_numpy(melspec_lr[None, ...]), atol=1e-8, rtol=1e-3)
def test_InverseMelScale(self):
"""InverseMelScale transform is comparable to that of librosa"""
......@@ -338,7 +332,7 @@ class TestTransforms(_LibrosaMixin, unittest.TestCase):
# https://github.com/pytorch/audio/pull/366 for the discussion of the choice of algorithm
# https://github.com/pytorch/audio/pull/448/files#r385747021 for the distribution of P-inf
# distance over frequencies.
torch.testing.assert_allclose(spec_ta, spec_lr, atol=threshold, rtol=1e-5)
self.assertEqual(spec_ta, spec_lr, atol=threshold, rtol=1e-5)
threshold = 1700.0
# This threshold was choosen empirically, based on the following observations
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment