Commit af2c2bf7 authored by jamarshon, committed by cpuhrsch
Browse files

Remove numpy dependency from test_transforms.py, skip tests with missing dependencies. (#112)

parent a422f3fe
from __future__ import print_function from __future__ import print_function
import math
import os import os
import torch import torch
import torchaudio import torchaudio
from torchaudio.common_utils import IMPORT_LIBROSA, IMPORT_SCIPY
import torchaudio.transforms as transforms import torchaudio.transforms as transforms
import numpy as np
import unittest import unittest
if IMPORT_LIBROSA:
import librosa
if IMPORT_SCIPY:
import scipy
class Tester(unittest.TestCase): class Tester(unittest.TestCase):
...@@ -13,7 +21,7 @@ class Tester(unittest.TestCase): ...@@ -13,7 +21,7 @@ class Tester(unittest.TestCase):
sr = 16000 sr = 16000
freq = 440 freq = 440
volume = .3 volume = .3
sig = (torch.cos(2 * np.pi * torch.arange(0, 4 * sr).float() * freq / sr)) sig = (torch.cos(2 * math.pi * torch.arange(0, 4 * sr).float() * freq / sr))
sig.unsqueeze_(1) # (64000, 1) sig.unsqueeze_(1) # (64000, 1)
sig = (sig * volume * 2**31).long() sig = (sig * volume * 2**31).long()
# file for stereo stft test # file for stereo stft test
...@@ -27,8 +35,7 @@ class Tester(unittest.TestCase): ...@@ -27,8 +35,7 @@ class Tester(unittest.TestCase):
result = transforms.Scale()(audio_orig) result = transforms.Scale()(audio_orig)
self.assertTrue(result.min() >= -1. and result.max() <= 1.) self.assertTrue(result.min() >= -1. and result.max() <= 1.)
maxminmax = np.abs( maxminmax = float(max(abs(audio_orig.min()), abs(audio_orig.max())))
[audio_orig.min(), audio_orig.max()]).max().astype(np.float)
result = transforms.Scale(factor=maxminmax)(audio_orig) result = transforms.Scale(factor=maxminmax)(audio_orig)
self.assertTrue((result.min() == -1. or result.max() == 1.) and self.assertTrue((result.min() == -1. or result.max() == 1.) and
result.min() >= -1. and result.max() <= 1.) result.min() >= -1. and result.max() <= 1.)
...@@ -91,14 +98,13 @@ class Tester(unittest.TestCase): ...@@ -91,14 +98,13 @@ class Tester(unittest.TestCase):
audio_orig = self.sig.clone() audio_orig = self.sig.clone()
length_orig = audio_orig.size(0) length_orig = audio_orig.size(0)
length_new = int(length_orig * 1.2) length_new = int(length_orig * 1.2)
maxminmax = np.abs( maxminmax = float(max(abs(audio_orig.min()), abs(audio_orig.max())))
[audio_orig.min(), audio_orig.max()]).max().astype(np.float)
tset = (transforms.Scale(factor=maxminmax), tset = (transforms.Scale(factor=maxminmax),
transforms.PadTrim(max_len=length_new, channels_first=False)) transforms.PadTrim(max_len=length_new, channels_first=False))
result = transforms.Compose(tset)(audio_orig) result = transforms.Compose(tset)(audio_orig)
self.assertTrue(np.abs([result.min(), result.max()]).max() == 1.) self.assertTrue(max(abs(result.min()), abs(result.max())) == 1.)
self.assertTrue(result.size(0) == length_new) self.assertTrue(result.size(0) == length_new)
...@@ -194,87 +200,107 @@ class Tester(unittest.TestCase): ...@@ -194,87 +200,107 @@ class Tester(unittest.TestCase):
torch_mfcc_norm_none = mfcc_transform_norm_none(audio_scaled) torch_mfcc_norm_none = mfcc_transform_norm_none(audio_scaled)
norm_check = torch_mfcc.clone() norm_check = torch_mfcc.clone()
norm_check[:, :, 0] *= np.sqrt(n_mels) * 2 norm_check[:, :, 0] *= math.sqrt(n_mels) * 2
norm_check[:, :, 1:] *= np.sqrt(n_mels / 2) * 2 norm_check[:, :, 1:] *= math.sqrt(n_mels / 2) * 2
self.assertTrue(torch_mfcc_norm_none.allclose(norm_check)) self.assertTrue(torch_mfcc_norm_none.allclose(norm_check))
@unittest.skipIf(not IMPORT_LIBROSA or not IMPORT_SCIPY, 'Librosa and scipy are not available')
def test_librosa_consistency(self): def test_librosa_consistency(self):
try: def _test_librosa_consistency_helper(n_fft, hop_length, power, n_mels, n_mfcc, sample_rate):
import librosa input_path = os.path.join(self.test_dirpath, 'assets', 'sinewave.wav')
import scipy sound, sample_rate = torchaudio.load(input_path)
except ImportError: sound_librosa = sound.cpu().numpy().squeeze().T # squeeze batch and channel first
return
# test core spectrogram
input_path = os.path.join(self.test_dirpath, 'assets', 'sinewave.wav') spect_transform = torchaudio.transforms.Spectrogram(n_fft=n_fft, hop=hop_length, power=2)
sound, sample_rate = torchaudio.load(input_path) out_librosa, _ = librosa.core.spectrum._spectrogram(y=sound_librosa,
sound_librosa = sound.cpu().numpy().squeeze().T # squeeze batch and channel first n_fft=n_fft,
hop_length=hop_length,
n_fft = 400 power=2)
hop_length = 200
power = 2.0 out_torch = spect_transform(sound).squeeze().cpu().t()
n_mels = 128 self.assertTrue(torch.allclose(out_torch, torch.from_numpy(out_librosa), atol=1e-5))
n_mfcc = 40
sample_rate = 16000 # test mel spectrogram
melspect_transform = torchaudio.transforms.MelSpectrogram(sr=sample_rate, window=torch.hann_window,
# test core spectrogram hop=hop_length, n_mels=n_mels, n_fft=n_fft)
spect_transform = torchaudio.transforms.Spectrogram(n_fft=n_fft, hop=hop_length, power=2) librosa_mel = librosa.feature.melspectrogram(y=sound_librosa, sr=sample_rate,
out_librosa, _ = librosa.core.spectrum._spectrogram(y=sound_librosa, n_fft=n_fft, hop_length=hop_length, n_mels=n_mels,
n_fft=n_fft, htk=True, norm=None)
hop_length=hop_length, torch_mel = melspect_transform(sound).squeeze().cpu().t()
power=2)
# lower tolerance, think it's double vs. float
out_torch = spect_transform(sound).squeeze().cpu().numpy().T self.assertTrue(torch.allclose(torch_mel.type(torch.double), torch.from_numpy(librosa_mel), atol=5e-3))
self.assertTrue(np.allclose(out_torch, out_librosa, atol=1e-5))
# test s2db
# test mel spectrogram
melspect_transform = torchaudio.transforms.MelSpectrogram(sr=sample_rate, window=torch.hann_window, db_transform = torchaudio.transforms.SpectrogramToDB("power", 80.)
hop=hop_length, n_mels=n_mels, n_fft=n_fft) db_torch = db_transform(spect_transform(sound)).squeeze().cpu().t()
librosa_mel = librosa.feature.melspectrogram(y=sound_librosa, sr=sample_rate, db_librosa = librosa.core.spectrum.power_to_db(out_librosa)
n_fft=n_fft, hop_length=hop_length, n_mels=n_mels, self.assertTrue(torch.allclose(db_torch, torch.from_numpy(db_librosa), atol=5e-3))
htk=True, norm=None)
torch_mel = melspect_transform(sound).squeeze().cpu().numpy().T db_torch = db_transform(melspect_transform(sound)).squeeze().cpu().t()
db_librosa = librosa.core.spectrum.power_to_db(librosa_mel)
# lower tolerance, think it's double vs. float
self.assertTrue(np.allclose(torch_mel, librosa_mel, atol=5e-3)) self.assertTrue(torch.allclose(db_torch.type(torch.double), torch.from_numpy(db_librosa), atol=5e-3))
# test s2db # test MFCC
melkwargs = {'hop': hop_length, 'n_fft': n_fft}
db_transform = torchaudio.transforms.SpectrogramToDB("power", 80.) mfcc_transform = torchaudio.transforms.MFCC(sr=sample_rate,
db_torch = db_transform(spect_transform(sound)).squeeze().cpu().numpy().T n_mfcc=n_mfcc,
db_librosa = librosa.core.spectrum.power_to_db(out_librosa) norm='ortho',
self.assertTrue(np.allclose(db_torch, db_librosa, atol=5e-3)) melkwargs=melkwargs)
db_torch = db_transform(melspect_transform(sound)).squeeze().cpu().numpy().T # librosa.feature.mfcc doesn't pass kwargs properly since some of the
db_librosa = librosa.core.spectrum.power_to_db(librosa_mel) # kwargs for melspectrogram and mfcc are the same. We just follow the
# function body in https://librosa.github.io/librosa/_modules/librosa/feature/spectral.html#melspectrogram
self.assertTrue(np.allclose(db_torch, db_librosa, atol=5e-3)) # to mirror this function call with correct args:
# test MFCC # librosa_mfcc = librosa.feature.mfcc(y=sound_librosa,
melkwargs = {'hop': hop_length, 'n_fft': n_fft} # sr=sample_rate,
mfcc_transform = torchaudio.transforms.MFCC(sr=sample_rate, # n_mfcc = n_mfcc,
n_mfcc=n_mfcc, # hop_length=hop_length,
norm='ortho', # n_fft=n_fft,
melkwargs=melkwargs) # htk=True,
# norm=None,
# librosa.feature.mfcc doesn't pass kwargs properly since some of the # n_mels=n_mels)
# kwargs for melspectrogram and mfcc are the same. We just follow the
# function body in https://librosa.github.io/librosa/_modules/librosa/feature/spectral.html#melspectrogram librosa_mfcc = scipy.fftpack.dct(db_librosa, axis=0, type=2, norm='ortho')[:n_mfcc]
# to mirror this function call with correct args: torch_mfcc = mfcc_transform(sound).squeeze().cpu().t()
# librosa_mfcc = librosa.feature.mfcc(y=sound_librosa, self.assertTrue(torch.allclose(torch_mfcc.type(torch.double), torch.from_numpy(librosa_mfcc), atol=5e-3))
# sr=sample_rate,
# n_mfcc = n_mfcc, kwargs1 = {
# hop_length=hop_length, 'n_fft': 400,
# n_fft=n_fft, 'hop_length': 200,
# htk=True, 'power': 2.0,
# norm=None, 'n_mels': 128,
# n_mels=n_mels) 'n_mfcc': 40,
'sample_rate': 16000
librosa_mfcc = scipy.fftpack.dct(db_librosa, axis=0, type=2, norm='ortho')[:n_mfcc] }
torch_mfcc = mfcc_transform(sound).squeeze().cpu().numpy().T
kwargs2 = {
self.assertTrue(np.allclose(torch_mfcc, librosa_mfcc, atol=5e-3)) 'n_fft': 600,
'hop_length': 100,
'power': 2.0,
'n_mels': 128,
'n_mfcc': 20,
'sample_rate': 16000
}
kwargs3 = {
'n_fft': 200,
'hop_length': 50,
'power': 2.0,
'n_mels': 128,
'n_mfcc': 50,
'sample_rate': 24000
}
_test_librosa_consistency_helper(**kwargs1)
_test_librosa_consistency_helper(**kwargs2)
_test_librosa_consistency_helper(**kwargs3)
if __name__ == '__main__': if __name__ == '__main__':
......
import sys
# Interpreter-version flags used to pick the right module-lookup API below.
PY3 = sys.version_info > (3, 0)
PY34 = sys.version_info >= (3, 4)


def _check_module_exists(name):
    r"""Returns if a top-level module with :attr:`name` exists *without*
    importing it. This is generally safer than a try-catch block around an
    `import X`. It avoids third party libraries breaking assumptions of some of
    our tests, e.g., setting multiprocessing start method when imported
    (see librosa/#747, torchvision/#544).

    Arguments:
        name (str): name of the module to look for.

    Returns:
        bool: ``True`` if the module can be located, ``False`` otherwise.
        On Python >= 3.4 a dotted name whose parent package is missing
        also yields ``False`` instead of raising.
    """
    if not PY3:  # Python 2
        import imp
        try:
            imp.find_module(name)
            return True
        except ImportError:
            return False
    elif not PY34:  # Python [3, 3.4)
        import importlib
        loader = importlib.find_loader(name)
        return loader is not None
    else:  # Python >= 3.4
        import importlib
        import importlib.util
        try:
            spec = importlib.util.find_spec(name)
        except ImportError:
            # find_spec imports the parent package of a dotted name and
            # raises ImportError (ModuleNotFoundError) when that parent is
            # absent; for an availability probe that simply means "missing".
            return False
        return spec is not None
# Availability flags consumed by the test suite to skip tests whose optional
# dependencies are not installed. Neither package is imported here.
IMPORT_SCIPY = _check_module_exists('scipy')

# librosa is only considered usable on Python 3: on Py2, importing
# librosa 0.6.1 triggers a TypeError (if using newest joblib),
# see librosa/librosa#729.
# TODO: allow Py2 when librosa 0.6.2 releases
IMPORT_LIBROSA = PY3 and _check_module_exists('librosa')
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment