Commit af2c2bf7, authored by jamarshon and committed by cpuhrsch
Browse files

Remove numpy dependency from test_transforms.py, skip tests with missing dependencies. (#112)

parent a422f3fe
from __future__ import print_function
import math
import os
import torch
import torchaudio
from torchaudio.common_utils import IMPORT_LIBROSA, IMPORT_SCIPY
import torchaudio.transforms as transforms
import numpy as np
import unittest
if IMPORT_LIBROSA:
import librosa
if IMPORT_SCIPY:
import scipy
class Tester(unittest.TestCase):
......@@ -13,7 +21,7 @@ class Tester(unittest.TestCase):
sr = 16000
freq = 440
volume = .3
sig = (torch.cos(2 * np.pi * torch.arange(0, 4 * sr).float() * freq / sr))
sig = (torch.cos(2 * math.pi * torch.arange(0, 4 * sr).float() * freq / sr))
sig.unsqueeze_(1) # (64000, 1)
sig = (sig * volume * 2**31).long()
# file for stereo stft test
......@@ -27,8 +35,7 @@ class Tester(unittest.TestCase):
result = transforms.Scale()(audio_orig)
self.assertTrue(result.min() >= -1. and result.max() <= 1.)
maxminmax = np.abs(
[audio_orig.min(), audio_orig.max()]).max().astype(np.float)
maxminmax = float(max(abs(audio_orig.min()), abs(audio_orig.max())))
result = transforms.Scale(factor=maxminmax)(audio_orig)
self.assertTrue((result.min() == -1. or result.max() == 1.) and
result.min() >= -1. and result.max() <= 1.)
......@@ -91,14 +98,13 @@ class Tester(unittest.TestCase):
audio_orig = self.sig.clone()
length_orig = audio_orig.size(0)
length_new = int(length_orig * 1.2)
maxminmax = np.abs(
[audio_orig.min(), audio_orig.max()]).max().astype(np.float)
maxminmax = float(max(abs(audio_orig.min()), abs(audio_orig.max())))
tset = (transforms.Scale(factor=maxminmax),
transforms.PadTrim(max_len=length_new, channels_first=False))
result = transforms.Compose(tset)(audio_orig)
self.assertTrue(np.abs([result.min(), result.max()]).max() == 1.)
self.assertTrue(max(abs(result.min()), abs(result.max())) == 1.)
self.assertTrue(result.size(0) == length_new)
......@@ -194,87 +200,107 @@ class Tester(unittest.TestCase):
torch_mfcc_norm_none = mfcc_transform_norm_none(audio_scaled)
norm_check = torch_mfcc.clone()
norm_check[:, :, 0] *= np.sqrt(n_mels) * 2
norm_check[:, :, 1:] *= np.sqrt(n_mels / 2) * 2
norm_check[:, :, 0] *= math.sqrt(n_mels) * 2
norm_check[:, :, 1:] *= math.sqrt(n_mels / 2) * 2
self.assertTrue(torch_mfcc_norm_none.allclose(norm_check))
@unittest.skipIf(not IMPORT_LIBROSA or not IMPORT_SCIPY, 'Librosa and scipy are not available')
def test_librosa_consistency(self):
try:
import librosa
import scipy
except ImportError:
return
input_path = os.path.join(self.test_dirpath, 'assets', 'sinewave.wav')
sound, sample_rate = torchaudio.load(input_path)
sound_librosa = sound.cpu().numpy().squeeze().T # squeeze batch and channel first
n_fft = 400
hop_length = 200
power = 2.0
n_mels = 128
n_mfcc = 40
sample_rate = 16000
# test core spectrogram
spect_transform = torchaudio.transforms.Spectrogram(n_fft=n_fft, hop=hop_length, power=2)
out_librosa, _ = librosa.core.spectrum._spectrogram(y=sound_librosa,
n_fft=n_fft,
hop_length=hop_length,
power=2)
out_torch = spect_transform(sound).squeeze().cpu().numpy().T
self.assertTrue(np.allclose(out_torch, out_librosa, atol=1e-5))
# test mel spectrogram
melspect_transform = torchaudio.transforms.MelSpectrogram(sr=sample_rate, window=torch.hann_window,
hop=hop_length, n_mels=n_mels, n_fft=n_fft)
librosa_mel = librosa.feature.melspectrogram(y=sound_librosa, sr=sample_rate,
n_fft=n_fft, hop_length=hop_length, n_mels=n_mels,
htk=True, norm=None)
torch_mel = melspect_transform(sound).squeeze().cpu().numpy().T
# lower tolerance, think it's double vs. float
self.assertTrue(np.allclose(torch_mel, librosa_mel, atol=5e-3))
# test s2db
db_transform = torchaudio.transforms.SpectrogramToDB("power", 80.)
db_torch = db_transform(spect_transform(sound)).squeeze().cpu().numpy().T
db_librosa = librosa.core.spectrum.power_to_db(out_librosa)
self.assertTrue(np.allclose(db_torch, db_librosa, atol=5e-3))
db_torch = db_transform(melspect_transform(sound)).squeeze().cpu().numpy().T
db_librosa = librosa.core.spectrum.power_to_db(librosa_mel)
self.assertTrue(np.allclose(db_torch, db_librosa, atol=5e-3))
# test MFCC
melkwargs = {'hop': hop_length, 'n_fft': n_fft}
mfcc_transform = torchaudio.transforms.MFCC(sr=sample_rate,
n_mfcc=n_mfcc,
norm='ortho',
melkwargs=melkwargs)
# librosa.feature.mfcc doesn't pass kwargs properly since some of the
# kwargs for melspectrogram and mfcc are the same. We just follow the
# function body in https://librosa.github.io/librosa/_modules/librosa/feature/spectral.html#melspectrogram
# to mirror this function call with correct args:
# librosa_mfcc = librosa.feature.mfcc(y=sound_librosa,
# sr=sample_rate,
# n_mfcc = n_mfcc,
# hop_length=hop_length,
# n_fft=n_fft,
# htk=True,
# norm=None,
# n_mels=n_mels)
librosa_mfcc = scipy.fftpack.dct(db_librosa, axis=0, type=2, norm='ortho')[:n_mfcc]
torch_mfcc = mfcc_transform(sound).squeeze().cpu().numpy().T
self.assertTrue(np.allclose(torch_mfcc, librosa_mfcc, atol=5e-3))
def _test_librosa_consistency_helper(n_fft, hop_length, power, n_mels, n_mfcc, sample_rate):
    """Check torchaudio transforms against librosa for one parameter set.

    Loads ``assets/sinewave.wav`` and asserts that the spectrogram, mel
    spectrogram, dB conversion and MFCC computed by torchaudio are close to
    the librosa/scipy reference values.

    Args:
        n_fft: FFT size passed to both libraries.
        hop_length: Hop size passed to both libraries.
        power: Exponent passed to the core spectrogram of both libraries.
        n_mels: Number of mel bins passed to both libraries.
        n_mfcc: Number of MFCC coefficients kept.
        sample_rate: Rate passed as ``sr`` to the mel-based transforms.
    """
    input_path = os.path.join(self.test_dirpath, 'assets', 'sinewave.wav')
    # Discard the rate reported by the loader. Rebinding ``sample_rate`` here
    # would silently override the value requested by the caller, so every
    # parameter set would be run with the file's own rate.
    sound, _ = torchaudio.load(input_path)
    sound_librosa = sound.cpu().numpy().squeeze().T  # squeeze batch and channel first

    # test core spectrogram
    spect_transform = torchaudio.transforms.Spectrogram(n_fft=n_fft, hop=hop_length, power=power)
    out_librosa, _ = librosa.core.spectrum._spectrogram(y=sound_librosa,
                                                        n_fft=n_fft,
                                                        hop_length=hop_length,
                                                        power=power)

    out_torch = spect_transform(sound).squeeze().cpu().t()
    self.assertTrue(torch.allclose(out_torch, torch.from_numpy(out_librosa), atol=1e-5))

    # test mel spectrogram
    melspect_transform = torchaudio.transforms.MelSpectrogram(sr=sample_rate, window=torch.hann_window,
                                                              hop=hop_length, n_mels=n_mels, n_fft=n_fft)
    librosa_mel = librosa.feature.melspectrogram(y=sound_librosa, sr=sample_rate,
                                                 n_fft=n_fft, hop_length=hop_length, n_mels=n_mels,
                                                 htk=True, norm=None)
    torch_mel = melspect_transform(sound).squeeze().cpu().t()

    # Looser tolerance than the core spectrogram check; the difference is
    # thought to come from double vs. float precision.
    self.assertTrue(torch.allclose(torch_mel.type(torch.double), torch.from_numpy(librosa_mel), atol=5e-3))

    # test s2db
    db_transform = torchaudio.transforms.SpectrogramToDB("power", 80.)
    db_torch = db_transform(spect_transform(sound)).squeeze().cpu().t()
    db_librosa = librosa.core.spectrum.power_to_db(out_librosa)
    self.assertTrue(torch.allclose(db_torch, torch.from_numpy(db_librosa), atol=5e-3))

    db_torch = db_transform(melspect_transform(sound)).squeeze().cpu().t()
    db_librosa = librosa.core.spectrum.power_to_db(librosa_mel)
    self.assertTrue(torch.allclose(db_torch.type(torch.double), torch.from_numpy(db_librosa), atol=5e-3))

    # test MFCC
    melkwargs = {'hop': hop_length, 'n_fft': n_fft}
    mfcc_transform = torchaudio.transforms.MFCC(sr=sample_rate,
                                                n_mfcc=n_mfcc,
                                                norm='ortho',
                                                melkwargs=melkwargs)

    # librosa.feature.mfcc doesn't pass kwargs properly since some of the
    # kwargs for melspectrogram and mfcc are the same. We just follow the
    # function body in https://librosa.github.io/librosa/_modules/librosa/feature/spectral.html#melspectrogram
    # to mirror this function call with correct args:
    #     librosa_mfcc = librosa.feature.mfcc(y=sound_librosa,
    #                                         sr=sample_rate,
    #                                         n_mfcc = n_mfcc,
    #                                         hop_length=hop_length,
    #                                         n_fft=n_fft,
    #                                         htk=True,
    #                                         norm=None,
    #                                         n_mels=n_mels)
    # The DCT is applied to the mel dB spectrogram computed just above.
    librosa_mfcc = scipy.fftpack.dct(db_librosa, axis=0, type=2, norm='ortho')[:n_mfcc]
    torch_mfcc = mfcc_transform(sound).squeeze().cpu().t()

    self.assertTrue(torch.allclose(torch_mfcc.type(torch.double), torch.from_numpy(librosa_mfcc), atol=5e-3))
# Parameter sets for the consistency helper; each one is run in order and the
# first failing assertion aborts the remaining sets.
parameter_sets = (
    {
        'n_fft': 400,
        'hop_length': 200,
        'power': 2.0,
        'n_mels': 128,
        'n_mfcc': 40,
        'sample_rate': 16000,
    },
    {
        'n_fft': 600,
        'hop_length': 100,
        'power': 2.0,
        'n_mels': 128,
        'n_mfcc': 20,
        'sample_rate': 16000,
    },
    {
        'n_fft': 200,
        'hop_length': 50,
        'power': 2.0,
        'n_mels': 128,
        'n_mfcc': 50,
        'sample_rate': 24000,
    },
)

for parameters in parameter_sets:
    _test_librosa_consistency_helper(**parameters)
if __name__ == '__main__':
......
import sys
# Interpreter-version flags used to pick a module-lookup strategy.
# PY3 is True on any Python 3.x; PY34 is True on Python 3.4 and newer.
PY3 = sys.version_info[0] >= 3
PY34 = sys.version_info[:2] >= (3, 4)
def _check_module_exists(name):
    r"""Return whether a top-level module called :attr:`name` exists *without*
    importing it.

    This is generally safer than a try-catch block around an ``import X``. It
    avoids third party libraries breaking assumptions of some of our tests,
    e.g., setting multiprocessing start method when imported
    (see librosa/#747, torchvision/#544).

    Args:
        name: Name of the module to look up.

    Returns:
        ``True`` if the lookup mechanism for the running interpreter finds the
        module, ``False`` otherwise. On Python >= 3.4 an ``ImportError`` raised
        during the lookup (e.g. a dotted name whose parent package is missing)
        is reported as ``False``, matching the Python 2 branch.
    """
    if not PY3:  # Python 2
        import imp
        try:
            imp.find_module(name)
        except ImportError:
            return False
        return True

    if not PY34:  # Python [3, 3.4)
        import importlib
        return importlib.find_loader(name) is not None

    # Python >= 3.4
    import importlib.util
    try:
        spec = importlib.util.find_spec(name)
    except ImportError:
        # find_spec imports the parent package of a dotted name; a missing
        # parent means the requested module cannot exist either.
        return False
    return spec is not None
# Availability flags for the optional test dependencies; computed once at
# import time without importing the packages themselves.
IMPORT_SCIPY = _check_module_exists('scipy')

# librosa counts as unavailable on Py2: importing librosa 0.6.1 there triggers
# a TypeError (if using newest joblib), see librosa/librosa#729.
# TODO: allow Py2 when librosa 0.6.2 releases
IMPORT_LIBROSA = PY3 and _check_module_exists('librosa')
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment