Unverified Commit c6bca702 authored by moto's avatar moto Committed by GitHub
Browse files

Extract librosa tests from test_transforms to the dedicated test module (#485)

parent 2554f826
"""Test suites for numerical compatibility with librosa"""
import os
import unittest
import torch
import torchaudio
import torchaudio.functional as F
from torchaudio.common_utils import IMPORT_LIBROSA
if IMPORT_LIBROSA:
import numpy as np
import librosa
import scipy
import pytest
import common_utils
class TestFunctional(unittest.TestCase):
class _LibrosaMixin:
"""Automatically skip tests if librosa is not available"""
def setUp(self):
super().setUp()
if not IMPORT_LIBROSA:
raise unittest.SkipTest('Librosa not available')
def test_griffinlim(self):
class TestFunctional(_LibrosaMixin, unittest.TestCase):
"""Test suite for functions in `functional` module."""
def test_griffinlim(self):
# NOTE: This test is flaky without a fixed random seed
# See https://github.com/pytorch/audio/issues/382
torch.random.manual_seed(42)
......@@ -46,10 +54,6 @@ class TestFunctional(unittest.TestCase):
assert torch.allclose(ta_out, lr_out, atol=5e-5)
def _test_create_fb(self, n_mels=40, sample_rate=22050, n_fft=2048, fmin=0.0, fmax=8000.0):
# Using a decorator here causes parametrize to fail on Python 2
if not IMPORT_LIBROSA:
raise unittest.SkipTest('Librosa is not available')
librosa_fb = librosa.filters.mel(sr=sample_rate,
n_fft=n_fft,
n_mels=n_mels,
......@@ -141,3 +145,202 @@ def test_phase_vocoder(complex_specgrams, rate, hop_length):
complex_stretch = complex_stretch[..., 0] + 1j * complex_stretch[..., 1]
assert np.allclose(complex_stretch, expected_complex_stretch, atol=1e-5)
def _load_audio_asset(*asset_paths, **kwargs):
file_path = os.path.join(common_utils.TEST_DIR_PATH, 'assets', *asset_paths)
sound, sample_rate = torchaudio.load(file_path, **kwargs)
return sound, sample_rate
def _test_compatibilities(n_fft, hop_length, power, n_mels, n_mfcc, sample_rate):
sound, sample_rate = _load_audio_asset('sinewave.wav')
sound_librosa = sound.cpu().numpy().squeeze() # (64000)
# test core spectrogram
spect_transform = torchaudio.transforms.Spectrogram(
n_fft=n_fft, hop_length=hop_length, power=power)
out_librosa, _ = librosa.core.spectrum._spectrogram(
y=sound_librosa, n_fft=n_fft, hop_length=hop_length, power=power)
out_torch = spect_transform(sound).squeeze().cpu()
assert torch.allclose(out_torch, torch.from_numpy(out_librosa), atol=1e-5)
# test mel spectrogram
melspect_transform = torchaudio.transforms.MelSpectrogram(
sample_rate=sample_rate, window_fn=torch.hann_window,
hop_length=hop_length, n_mels=n_mels, n_fft=n_fft)
librosa_mel = librosa.feature.melspectrogram(
y=sound_librosa, sr=sample_rate, n_fft=n_fft,
hop_length=hop_length, n_mels=n_mels, htk=True, norm=None)
librosa_mel_tensor = torch.from_numpy(librosa_mel)
torch_mel = melspect_transform(sound).squeeze().cpu()
assert torch.allclose(
torch_mel.type(librosa_mel_tensor.dtype), librosa_mel_tensor, atol=5e-3)
# test s2db
power_to_db_transform = torchaudio.transforms.AmplitudeToDB('power', 80.)
power_to_db_torch = power_to_db_transform(spect_transform(sound)).squeeze().cpu()
power_to_db_librosa = librosa.core.spectrum.power_to_db(out_librosa)
assert torch.allclose(power_to_db_torch, torch.from_numpy(power_to_db_librosa), atol=5e-3)
mag_to_db_transform = torchaudio.transforms.AmplitudeToDB('magnitude', 80.)
mag_to_db_torch = mag_to_db_transform(torch.abs(sound)).squeeze().cpu()
mag_to_db_librosa = librosa.core.spectrum.amplitude_to_db(sound_librosa)
assert torch.allclose(mag_to_db_torch, torch.from_numpy(mag_to_db_librosa), atol=5e-3)
power_to_db_torch = power_to_db_transform(melspect_transform(sound)).squeeze().cpu()
db_librosa = librosa.core.spectrum.power_to_db(librosa_mel)
db_librosa_tensor = torch.from_numpy(db_librosa)
assert torch.allclose(
power_to_db_torch.type(db_librosa_tensor.dtype), db_librosa_tensor, atol=5e-3)
# test MFCC
melkwargs = {'hop_length': hop_length, 'n_fft': n_fft}
mfcc_transform = torchaudio.transforms.MFCC(
sample_rate=sample_rate, n_mfcc=n_mfcc, norm='ortho', melkwargs=melkwargs)
# librosa.feature.mfcc doesn't pass kwargs properly since some of the
# kwargs for melspectrogram and mfcc are the same. We just follow the
# function body in
# https://librosa.github.io/librosa/_modules/librosa/feature/spectral.html#melspectrogram
# to mirror this function call with correct args:
#
# librosa_mfcc = librosa.feature.mfcc(
# y=sound_librosa, sr=sample_rate, n_mfcc = n_mfcc,
# hop_length=hop_length, n_fft=n_fft, htk=True, norm=None, n_mels=n_mels)
librosa_mfcc = scipy.fftpack.dct(db_librosa, axis=0, type=2, norm='ortho')[:n_mfcc]
librosa_mfcc_tensor = torch.from_numpy(librosa_mfcc)
torch_mfcc = mfcc_transform(sound).squeeze().cpu()
assert torch.allclose(
torch_mfcc.type(librosa_mfcc_tensor.dtype), librosa_mfcc_tensor, atol=5e-3)
class TestTransforms(_LibrosaMixin, unittest.TestCase):
"""Test suite for functions in `transforms` module."""
def test_basics1(self):
kwargs = {
'n_fft': 400,
'hop_length': 200,
'power': 2.0,
'n_mels': 128,
'n_mfcc': 40,
'sample_rate': 16000
}
_test_compatibilities(**kwargs)
def test_basics2(self):
kwargs = {
'n_fft': 600,
'hop_length': 100,
'power': 2.0,
'n_mels': 128,
'n_mfcc': 20,
'sample_rate': 16000
}
_test_compatibilities(**kwargs)
# NOTE: Test passes offline, but fails on TravisCI, see #372.
@unittest.skipIf(
os.environ.get('CI') == 'true' and os.environ.get('TRAVIS') == 'true',
'Test is known to fail on TravisCI')
def test_basics3(self):
kwargs = {
'n_fft': 200,
'hop_length': 50,
'power': 2.0,
'n_mels': 128,
'n_mfcc': 50,
'sample_rate': 24000
}
_test_compatibilities(**kwargs)
def test_basics4(self):
kwargs = {
'n_fft': 400,
'hop_length': 200,
'power': 3.0,
'n_mels': 128,
'n_mfcc': 40,
'sample_rate': 16000
}
_test_compatibilities(**kwargs)
@unittest.skipIf("sox" not in common_utils.BACKENDS, "sox not available")
@common_utils.AudioBackendScope("sox")
def test_MelScale(self):
"""MelScale transform is comparable to that of librosa"""
n_fft = 2048
n_mels = 256
hop_length = n_fft // 4
# Prepare spectrogram input. We use torchaudio to compute one.
sound, sample_rate = _load_audio_asset('whitenoise_1min.mp3')
sound = sound.mean(dim=0, keepdim=True)
spec_ta = F.spectrogram(
sound, pad=0, window=torch.hann_window(n_fft), n_fft=n_fft,
hop_length=hop_length, win_length=n_fft, power=2, normalized=False)
spec_lr = spec_ta.cpu().numpy().squeeze()
# Perform MelScale with torchaudio and librosa
melspec_ta = torchaudio.transforms.MelScale(n_mels=n_mels, sample_rate=sample_rate)(spec_ta)
melspec_lr = librosa.feature.melspectrogram(
S=spec_lr, sr=sample_rate, n_fft=n_fft, hop_length=hop_length,
win_length=n_fft, center=True, window='hann', n_mels=n_mels, htk=True, norm=None)
# Note: Using relaxed rtol instead of atol
assert torch.allclose(melspec_ta, torch.from_numpy(melspec_lr[None, ...]), rtol=1e-3)
def test_InverseMelScale(self):
"""InverseMelScale transform is comparable to that of librosa"""
n_fft = 2048
n_mels = 256
n_stft = n_fft // 2 + 1
hop_length = n_fft // 4
# Prepare mel spectrogram input. We use torchaudio to compute one.
sound, sample_rate = _load_audio_asset(
'steam-train-whistle-daniel_simon.wav', offset=2**10, num_frames=2**14)
sound = sound.mean(dim=0, keepdim=True)
spec_orig = F.spectrogram(
sound, pad=0, window=torch.hann_window(n_fft), n_fft=n_fft,
hop_length=hop_length, win_length=n_fft, power=2, normalized=False)
melspec_ta = torchaudio.transforms.MelScale(n_mels=n_mels, sample_rate=sample_rate)(spec_orig)
melspec_lr = melspec_ta.cpu().numpy().squeeze()
# Perform InverseMelScale with torch audio and librosa
spec_ta = torchaudio.transforms.InverseMelScale(
n_stft, n_mels=n_mels, sample_rate=sample_rate)(melspec_ta)
spec_lr = librosa.feature.inverse.mel_to_stft(
melspec_lr, sr=sample_rate, n_fft=n_fft, power=2.0, htk=True, norm=None)
spec_lr = torch.from_numpy(spec_lr[None, ...])
# Align dimensions
# librosa does not return power spectrogram while torchaudio returns power spectrogram
spec_orig = spec_orig.sqrt()
spec_ta = spec_ta.sqrt()
threshold = 2.0
# This threshold was choosen empirically, based on the following observation
#
# torch.dist(spec_lr, spec_ta, p=float('inf'))
# >>> tensor(1.9666)
#
# The spectrograms reconstructed by librosa and torchaudio are not comparable elementwise.
# This is because they use different approximation algorithms and resulting values can live
# in different magnitude. (although most of them are very close)
# See
# https://github.com/pytorch/audio/pull/366 for the discussion of the choice of algorithm
# https://github.com/pytorch/audio/pull/448/files#r385747021 for the distribution of P-inf
# distance over frequencies.
assert torch.allclose(spec_ta, spec_lr, atol=threshold)
threshold = 1700.0
# This threshold was choosen empirically, based on the following observations
#
# torch.dist(spec_orig, spec_ta, p=1)
# >>> tensor(1644.3516)
# torch.dist(spec_orig, spec_lr, p=1)
# >>> tensor(1420.7103)
# torch.dist(spec_lr, spec_ta, p=1)
# >>> tensor(943.2759)
assert torch.dist(spec_orig, spec_ta, p=1) < threshold
import math
import os
import unittest
import torch
import torchaudio
import torchaudio.transforms as transforms
import torchaudio.functional as F
from torchaudio.common_utils import IMPORT_LIBROSA, IMPORT_SCIPY
import unittest
from common_utils import AudioBackendScope, BACKENDS, create_temp_assets_dir
if IMPORT_LIBROSA:
import librosa
from common_utils import AudioBackendScope, BACKENDS, create_temp_assets_dir
if IMPORT_SCIPY:
import scipy
RUN_CUDA = torch.cuda.is_available()
print("Run test with cuda:", RUN_CUDA)
......@@ -231,124 +226,6 @@ class Tester(unittest.TestCase):
self.assertTrue(torch_mfcc_norm_none.allclose(norm_check))
@unittest.skipIf(not IMPORT_LIBROSA or not IMPORT_SCIPY, 'Librosa and scipy are not available')
def test_librosa_consistency(self):
def _test_librosa_consistency_helper(n_fft, hop_length, power, n_mels, n_mfcc, sample_rate):
input_path = os.path.join(self.test_dirpath, 'assets', 'sinewave.wav')
sound, sample_rate = torchaudio.load(input_path)
sound_librosa = sound.cpu().numpy().squeeze() # (64000)
# test core spectrogram
spect_transform = torchaudio.transforms.Spectrogram(n_fft=n_fft, hop_length=hop_length, power=power)
out_librosa, _ = librosa.core.spectrum._spectrogram(y=sound_librosa,
n_fft=n_fft,
hop_length=hop_length,
power=power)
out_torch = spect_transform(sound).squeeze().cpu()
self.assertTrue(torch.allclose(out_torch, torch.from_numpy(out_librosa), atol=1e-5))
# test mel spectrogram
melspect_transform = torchaudio.transforms.MelSpectrogram(
sample_rate=sample_rate, window_fn=torch.hann_window,
hop_length=hop_length, n_mels=n_mels, n_fft=n_fft)
librosa_mel = librosa.feature.melspectrogram(y=sound_librosa, sr=sample_rate,
n_fft=n_fft, hop_length=hop_length, n_mels=n_mels,
htk=True, norm=None)
librosa_mel_tensor = torch.from_numpy(librosa_mel)
torch_mel = melspect_transform(sound).squeeze().cpu()
self.assertTrue(torch.allclose(torch_mel.type(librosa_mel_tensor.dtype), librosa_mel_tensor, atol=5e-3))
# test s2db
power_to_db_transform = torchaudio.transforms.AmplitudeToDB('power', 80.)
power_to_db_torch = power_to_db_transform(spect_transform(sound)).squeeze().cpu()
power_to_db_librosa = librosa.core.spectrum.power_to_db(out_librosa)
self.assertTrue(torch.allclose(power_to_db_torch, torch.from_numpy(power_to_db_librosa), atol=5e-3))
mag_to_db_transform = torchaudio.transforms.AmplitudeToDB('magnitude', 80.)
mag_to_db_torch = mag_to_db_transform(torch.abs(sound)).squeeze().cpu()
mag_to_db_librosa = librosa.core.spectrum.amplitude_to_db(sound_librosa)
self.assertTrue(
torch.allclose(mag_to_db_torch, torch.from_numpy(mag_to_db_librosa), atol=5e-3)
)
power_to_db_torch = power_to_db_transform(melspect_transform(sound)).squeeze().cpu()
db_librosa = librosa.core.spectrum.power_to_db(librosa_mel)
db_librosa_tensor = torch.from_numpy(db_librosa)
self.assertTrue(
torch.allclose(power_to_db_torch.type(db_librosa_tensor.dtype), db_librosa_tensor, atol=5e-3)
)
# test MFCC
melkwargs = {'hop_length': hop_length, 'n_fft': n_fft}
mfcc_transform = torchaudio.transforms.MFCC(sample_rate=sample_rate,
n_mfcc=n_mfcc,
norm='ortho',
melkwargs=melkwargs)
# librosa.feature.mfcc doesn't pass kwargs properly since some of the
# kwargs for melspectrogram and mfcc are the same. We just follow the
# function body in https://librosa.github.io/librosa/_modules/librosa/feature/spectral.html#melspectrogram
# to mirror this function call with correct args:
# librosa_mfcc = librosa.feature.mfcc(y=sound_librosa,
# sr=sample_rate,
# n_mfcc = n_mfcc,
# hop_length=hop_length,
# n_fft=n_fft,
# htk=True,
# norm=None,
# n_mels=n_mels)
librosa_mfcc = scipy.fftpack.dct(db_librosa, axis=0, type=2, norm='ortho')[:n_mfcc]
librosa_mfcc_tensor = torch.from_numpy(librosa_mfcc)
torch_mfcc = mfcc_transform(sound).squeeze().cpu()
self.assertTrue(torch.allclose(torch_mfcc.type(librosa_mfcc_tensor.dtype), librosa_mfcc_tensor, atol=5e-3))
kwargs1 = {
'n_fft': 400,
'hop_length': 200,
'power': 2.0,
'n_mels': 128,
'n_mfcc': 40,
'sample_rate': 16000
}
kwargs2 = {
'n_fft': 600,
'hop_length': 100,
'power': 2.0,
'n_mels': 128,
'n_mfcc': 20,
'sample_rate': 16000
}
kwargs3 = {
'n_fft': 200,
'hop_length': 50,
'power': 2.0,
'n_mels': 128,
'n_mfcc': 50,
'sample_rate': 24000
}
kwargs4 = {
'n_fft': 400,
'hop_length': 200,
'power': 3.0,
'n_mels': 128,
'n_mfcc': 40,
'sample_rate': 16000
}
_test_librosa_consistency_helper(**kwargs1)
_test_librosa_consistency_helper(**kwargs2)
# NOTE Test passes offline, but fails on CircleCI, see #372.
# _test_librosa_consistency_helper(**kwargs3)
_test_librosa_consistency_helper(**kwargs4)
def test_scriptmodule_Resample(self):
tensor = torch.rand((2, 1000))
sample_rate = 100.
......@@ -631,99 +508,5 @@ class Tester(unittest.TestCase):
self.assertTrue(torch.allclose(computed, expected))
class TestLibrosaConsistency(unittest.TestCase):
test_dirpath = None
test_dir = None
@classmethod
def setUpClass(cls):
cls.test_dirpath, cls.test_dir = create_temp_assets_dir()
def _to_librosa(self, sound):
return sound.cpu().numpy().squeeze()
def _get_sample_data(self, *asset_paths, **kwargs):
file_path = os.path.join(self.test_dirpath, 'assets', *asset_paths)
sound, sample_rate = torchaudio.load(file_path, **kwargs)
return sound.mean(dim=0, keepdim=True), sample_rate
@unittest.skipIf(not IMPORT_LIBROSA, 'Librosa is not available')
@unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox")
def test_MelScale(self):
"""MelScale transform is comparable to that of librosa"""
n_fft = 2048
n_mels = 256
hop_length = n_fft // 4
# Prepare spectrogram input. We use torchaudio to compute one.
sound, sample_rate = self._get_sample_data('whitenoise_1min.mp3')
spec_ta = F.spectrogram(
sound, pad=0, window=torch.hann_window(n_fft), n_fft=n_fft,
hop_length=hop_length, win_length=n_fft, power=2, normalized=False)
spec_lr = spec_ta.cpu().numpy().squeeze()
# Perform MelScale with torchaudio and librosa
melspec_ta = transforms.MelScale(n_mels=n_mels, sample_rate=sample_rate)(spec_ta)
melspec_lr = librosa.feature.melspectrogram(
S=spec_lr, sr=sample_rate, n_fft=n_fft, hop_length=hop_length,
win_length=n_fft, center=True, window='hann', n_mels=n_mels, htk=True, norm=None)
# Note: Using relaxed rtol instead of atol
assert torch.allclose(melspec_ta, torch.from_numpy(melspec_lr[None, ...]), rtol=1e-3)
@unittest.skipIf(not IMPORT_LIBROSA, 'Librosa is not available')
def test_InverseMelScale(self):
"""InverseMelScale transform is comparable to that of librosa"""
n_fft = 2048
n_mels = 256
n_stft = n_fft // 2 + 1
hop_length = n_fft // 4
# Prepare mel spectrogram input. We use torchaudio to compute one.
sound, sample_rate = self._get_sample_data(
'steam-train-whistle-daniel_simon.wav', offset=2**10, num_frames=2**14)
spec_orig = F.spectrogram(
sound, pad=0, window=torch.hann_window(n_fft), n_fft=n_fft,
hop_length=hop_length, win_length=n_fft, power=2, normalized=False)
melspec_ta = transforms.MelScale(n_mels=n_mels, sample_rate=sample_rate)(spec_orig)
melspec_lr = melspec_ta.cpu().numpy().squeeze()
# Perform InverseMelScale with torch audio and librosa
spec_ta = transforms.InverseMelScale(
n_stft, n_mels=n_mels, sample_rate=sample_rate)(melspec_ta)
spec_lr = librosa.feature.inverse.mel_to_stft(
melspec_lr, sr=sample_rate, n_fft=n_fft, power=2.0, htk=True, norm=None)
spec_lr = torch.from_numpy(spec_lr[None, ...])
# Align dimensions
# librosa does not return power spectrogram while torchaudio returns power spectrogram
spec_orig = spec_orig.sqrt()
spec_ta = spec_ta.sqrt()
threshold = 2.0
# This threshold was choosen empirically, based on the following observation
#
# torch.dist(spec_lr, spec_ta, p=float('inf'))
# >>> tensor(1.9666)
#
# The spectrograms reconstructed by librosa and torchaudio are not very comparable elementwise.
# This is because they use different approximation algorithms and resulting values can live
# in different magnitude. (although most of them are very close)
# See https://github.com/pytorch/audio/pull/366 for the discussion of the choice of algorithm
# See https://github.com/pytorch/audio/pull/448/files#r385747021 for the distribution of P-inf
# distance over frequencies.
assert torch.allclose(spec_ta, spec_lr, atol=threshold)
threshold = 1700.0
# This threshold was choosen empirically, based on the following observations
#
# torch.dist(spec_orig, spec_ta, p=1)
# >>> tensor(1644.3516)
# torch.dist(spec_orig, spec_lr, p=1)
# >>> tensor(1420.7103)
# torch.dist(spec_lr, spec_ta, p=1)
# >>> tensor(943.2759)
assert torch.dist(spec_orig, spec_ta, p=1) < threshold
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment