Commit 54d1cede authored by PCerles's avatar PCerles Committed by Soumith Chintala
Browse files

Librosa consistency (#83)

* saving commit

* more fixes

* Update test.py

* flake8 style, tests moved

* typo

* DCT now a method of MFCC class. Changed DCT implementation to @dhpollack's suggestion

* changes to address pr comments

* ordering

* fix dct docstring

* fix dct docstring

* DCT matrix needs to go on same device

* DCT matrix needs to go on same device

* log mfcc option
parent e874ef04
......@@ -136,15 +136,17 @@ class Tester(unittest.TestCase):
self.assertTrue(repr_test.__repr__())
def test_mel2(self):
top_db = 80.
s2db = transforms.SpectrogramToDB("power", top_db)
audio_orig = self.sig.clone() # (16000, 1)
audio_scaled = transforms.Scale()(audio_orig) # (16000, 1)
audio_scaled = transforms.LC2CL()(audio_scaled) # (1, 16000)
mel_transform = transforms.MelSpectrogram()
# check defaults
spectrogram_torch = mel_transform(audio_scaled) # (1, 319, 40)
spectrogram_torch = s2db(mel_transform(audio_scaled)) # (1, 319, 40)
self.assertTrue(spectrogram_torch.dim() == 3)
self.assertTrue(spectrogram_torch.le(0.).all())
self.assertTrue(spectrogram_torch.ge(mel_transform.top_db).all())
self.assertTrue(spectrogram_torch.ge(spectrogram_torch.max() - top_db).all())
self.assertEqual(spectrogram_torch.size(-1), mel_transform.n_mels)
# check correctness of filterbank conversion matrix
self.assertTrue(mel_transform.fm.fb.sum(1).le(1.).all())
......@@ -152,20 +154,18 @@ class Tester(unittest.TestCase):
# check options
kwargs = {"window": torch.hamming_window, "pad": 10, "ws": 500, "hop": 125, "n_fft": 800, "n_mels": 50}
mel_transform2 = transforms.MelSpectrogram(**kwargs)
spectrogram2_torch = mel_transform2(audio_scaled) # (1, 506, 50)
spectrogram2_torch = s2db(mel_transform2(audio_scaled)) # (1, 506, 50)
self.assertTrue(spectrogram2_torch.dim() == 3)
self.assertTrue(spectrogram2_torch.le(0.).all())
self.assertTrue(spectrogram2_torch.ge(mel_transform.top_db).all())
self.assertTrue(spectrogram_torch.ge(spectrogram_torch.max() - top_db).all())
self.assertEqual(spectrogram2_torch.size(-1), mel_transform2.n_mels)
self.assertTrue(mel_transform2.fm.fb.sum(1).le(1.).all())
self.assertTrue(mel_transform2.fm.fb.sum(1).ge(0.).all())
# check on multi-channel audio
x_stereo, sr_stereo = torchaudio.load(self.test_filepath)
spectrogram_stereo = mel_transform(x_stereo)
spectrogram_stereo = s2db(mel_transform(x_stereo))
self.assertTrue(spectrogram_stereo.dim() == 3)
self.assertTrue(spectrogram_stereo.size(0) == 2)
self.assertTrue(spectrogram_stereo.le(0.).all())
self.assertTrue(spectrogram_stereo.ge(mel_transform.top_db).all())
self.assertTrue(spectrogram_torch.ge(spectrogram_torch.max() - top_db).all())
self.assertEqual(spectrogram_stereo.size(-1), mel_transform.n_mels)
# check filterbank matrix creation
fb_matrix_transform = transforms.MelScale(n_mels=100, sr=16000, f_max=None, f_min=0., n_stft=400)
......@@ -173,5 +173,120 @@ class Tester(unittest.TestCase):
self.assertTrue(fb_matrix_transform.fb.sum(1).ge(0.).all())
self.assertEqual(fb_matrix_transform.fb.size(), (400, 100))
def test_mfcc(self):
    """Check MFCC output shapes, ``melkwargs`` pass-through and DCT scaling.

    The un-normalised (``norm=None``) output is compared against the
    orthonormal output rescaled by the factors that separate the two DCT
    normalisations.
    """
    audio_orig = self.sig.clone()
    audio_scaled = transforms.Scale()(audio_orig)  # (samples, channels)
    audio_scaled = transforms.LC2CL()(audio_scaled)  # (channels, samples)

    sample_rate = 16000
    n_mfcc = 40
    n_mels = 128  # default number of mel bins of MelSpectrogram

    mfcc_transform = torchaudio.transforms.MFCC(sr=sample_rate,
                                                n_mfcc=n_mfcc,
                                                norm='ortho')
    # check defaults
    torch_mfcc = mfcc_transform(audio_scaled)
    self.assertEqual(torch_mfcc.dim(), 3)
    self.assertEqual(torch_mfcc.shape[2], n_mfcc)
    self.assertEqual(torch_mfcc.shape[1], 321)

    # check melkwargs are passed through
    melkwargs = {'ws': 200}
    mfcc_transform2 = torchaudio.transforms.MFCC(sr=sample_rate,
                                                 n_mfcc=n_mfcc,
                                                 norm='ortho',
                                                 melkwargs=melkwargs)
    torch_mfcc2 = mfcc_transform2(audio_scaled)
    self.assertEqual(torch_mfcc2.shape[1], 641)

    # check norms work correctly
    mfcc_transform_norm_none = torchaudio.transforms.MFCC(sr=sample_rate,
                                                          n_mfcc=n_mfcc,
                                                          norm=None)
    torch_mfcc_norm_none = mfcc_transform_norm_none(audio_scaled)

    # undo the orthonormal scaling: coefficient 0 carries an extra 1/sqrt(2);
    # 2.0 keeps the division a true division on Python 2 as well
    norm_check = torch_mfcc.clone()
    norm_check[:, :, 0] *= np.sqrt(n_mels) * 2
    norm_check[:, :, 1:] *= np.sqrt(n_mels / 2.0) * 2
    self.assertTrue(torch_mfcc_norm_none.allclose(norm_check))
def test_librosa_consistency(self):
    """Check that the torchaudio transforms reproduce librosa's outputs.

    Covers the power spectrogram, the mel spectrogram, dB scaling and MFCC.
    The test is reported as skipped (rather than silently passing) when
    librosa or scipy cannot be imported.
    """
    try:
        import librosa
        # import the submodule explicitly: a bare ``import scipy`` does not
        # guarantee that ``scipy.fftpack`` is loaded
        import scipy.fftpack
    except ImportError:
        self.skipTest('librosa and scipy are required for the librosa consistency test')

    input_path = os.path.join(self.test_dirpath, 'assets', 'sinewave.wav')
    # the rate reported by load() is not used; a fixed rate is set below
    sound, _ = torchaudio.load(input_path)
    sound_librosa = sound.cpu().numpy().squeeze().T  # squeeze batch and channel first

    n_fft = 400
    hop_length = 200
    power = 2.0
    n_mels = 128
    n_mfcc = 40
    sample_rate = 16000

    # test core spectrogram
    spect_transform = torchaudio.transforms.Spectrogram(n_fft=n_fft, hop=hop_length, power=power)
    out_librosa, _ = librosa.core.spectrum._spectrogram(y=sound_librosa,
                                                        n_fft=n_fft,
                                                        hop_length=hop_length,
                                                        power=power)
    out_torch = spect_transform(sound).squeeze().cpu().numpy().T
    self.assertTrue(np.allclose(out_torch, out_librosa, atol=1e-5))

    # test mel spectrogram
    melspect_transform = torchaudio.transforms.MelSpectrogram(sr=sample_rate, window=torch.hann_window,
                                                              hop=hop_length, n_mels=n_mels, n_fft=n_fft)
    librosa_mel = librosa.feature.melspectrogram(y=sound_librosa, sr=sample_rate,
                                                 n_fft=n_fft, hop_length=hop_length, n_mels=n_mels,
                                                 htk=True, norm=None)
    torch_mel = melspect_transform(sound).squeeze().cpu().numpy().T
    # lower tolerance, think it's double vs. float
    self.assertTrue(np.allclose(torch_mel, librosa_mel, atol=5e-3))

    # test s2db
    db_transform = torchaudio.transforms.SpectrogramToDB("power", 80.)
    db_torch = db_transform(spect_transform(sound)).squeeze().cpu().numpy().T
    db_librosa = librosa.core.spectrum.power_to_db(out_librosa)
    self.assertTrue(np.allclose(db_torch, db_librosa, atol=5e-3))

    db_torch = db_transform(melspect_transform(sound)).squeeze().cpu().numpy().T
    db_librosa = librosa.core.spectrum.power_to_db(librosa_mel)
    self.assertTrue(np.allclose(db_torch, db_librosa, atol=5e-3))

    # test MFCC
    melkwargs = {'hop': hop_length, 'n_fft': n_fft}
    mfcc_transform = torchaudio.transforms.MFCC(sr=sample_rate,
                                                n_mfcc=n_mfcc,
                                                norm='ortho',
                                                melkwargs=melkwargs)

    # librosa.feature.mfcc doesn't pass kwargs properly since some of the
    # kwargs for melspectrogram and mfcc are the same. We just follow the
    # function body in https://librosa.github.io/librosa/_modules/librosa/feature/spectral.html#melspectrogram
    # to mirror this function call with correct args:
    #     librosa_mfcc = librosa.feature.mfcc(y=sound_librosa,
    #                                         sr=sample_rate,
    #                                         n_mfcc = n_mfcc,
    #                                         hop_length=hop_length,
    #                                         n_fft=n_fft,
    #                                         htk=True,
    #                                         norm=None,
    #                                         n_mels=n_mels)
    # db_librosa holds the dB-scaled mel spectrogram at this point
    librosa_mfcc = scipy.fftpack.dct(db_librosa, axis=0, type=2, norm='ortho')[:n_mfcc]
    torch_mfcc = mfcc_transform(sound).squeeze().cpu().numpy().T
    self.assertTrue(np.allclose(torch_mfcc, librosa_mfcc, atol=5e-3))
if __name__ == '__main__':
unittest.main()
......@@ -144,7 +144,6 @@ class LC2CL(object):
Returns:
tensor (Tensor): Tensor of audio signal with shape (CxL)
"""
return tensor.transpose(0, 1).contiguous()
......@@ -161,24 +160,28 @@ class Spectrogram(object):
"""Create a spectrogram from a raw audio signal
Args:
sr (int): sample rate of audio signal
ws (int): window size
n_fft (int, optional): size of fft, creates n_fft // 2 + 1 bins
ws (int): window size. default: n_fft
hop (int, optional): length of hop between STFT windows. default: ws // 2
n_fft (int, optional): size of fft, creates n_fft // 2 + 1 bins. default: ws
pad (int): two sided padding of signal
window (torch windowing function): default: torch.hann_window
power (int > 0 ) : Exponent for the magnitude spectrogram,
e.g., 1 for energy, 2 for power, etc.
normalize (bool) : whether to normalize by magnitude after stft
wkwargs (dict, optional): arguments for window function
"""
def __init__(self, n_fft=400, ws=None, hop=None,
             pad=0, window=torch.hann_window,
             power=2, normalize=False, wkwargs=None):
    """Store the STFT parameters and build the analysis window.

    ``ws`` falls back to ``n_fft`` and ``hop`` falls back to ``ws // 2``
    when they are not given. ``wkwargs``, when given, is forwarded to the
    ``window`` function as keyword arguments.
    """
    # number of fft bins. the returned STFT result will have n_fft // 2 + 1
    # number of frequencies due to onesided=True in torch.stft
    self.n_fft = n_fft
    self.ws = ws if ws is not None else n_fft
    self.hop = hop if hop is not None else self.ws // 2
    # the window is built from the resolved window size, not the raw argument
    self.window = window(self.ws) if wkwargs is None else window(self.ws, **wkwargs)
    self.pad = pad
    self.power = power
    self.normalize = normalize
    self.wkwargs = wkwargs
def __call__(self, sig):
......@@ -199,11 +202,15 @@ class Spectrogram(object):
with torch.no_grad():
sig = torch.nn.functional.pad(sig, (self.pad, self.pad), "constant")
self.window = self.window.to(sig.device)
# default values are consistent with librosa.core.spectrum._spectrogram
spec_f = torch.stft(sig, self.n_fft, self.hop, self.ws,
self.window, center=False,
normalized=True, onesided=True).transpose(1, 2)
spec_f /= self.window.pow(2).sum().sqrt()
spec_f = spec_f.pow(2).sum(-1) # get power of "complex" tensor (c, l, n_fft)
self.window, center=True,
normalized=False, onesided=True,
pad_mode='reflect').transpose(1, 2)
if self.normalize:
spec_f /= self.window.pow(2).sum().sqrt()
spec_f = spec_f.pow(self.power).sum(-1) # get power of "complex" tensor (c, l, n_fft)
return spec_f
......@@ -224,7 +231,7 @@ class MelScale(object):
n_stft (int, optional): number of filter banks from stft. Calculated from first input
if `None` is given. See `n_fft` in `Spectrogram`.
"""
def __init__(self, n_mels=40, sr=16000, f_max=None, f_min=0., n_stft=None):
def __init__(self, n_mels=128, sr=16000, f_max=None, f_min=0., n_stft=None):
self.n_mels = n_mels
self.sr = sr
self.f_max = f_max if f_max is not None else sr // 2
......@@ -234,6 +241,9 @@ class MelScale(object):
def __call__(self, spec_f):
    """Project a spectrogram onto the mel filterbank.

    The filterbank is created on first use from the size of dimension 2 of
    ``spec_f`` and is moved to the input's device on every call.

    Args:
        spec_f (Tensor): spectrogram of shape (c, l, n_fft)

    Returns:
        Tensor: mel-scaled spectrogram of shape (c, l, n_mels)
    """
    filterbank = self.fb
    if filterbank is None:
        filterbank = self._create_fb_matrix(spec_f.size(2))
    # need to ensure same device for dot product
    self.fb = filterbank.to(spec_f.device)
    # (c, l, n_fft) dot (n_fft, n_mels) -> (c, l, n_mels)
    return torch.matmul(spec_f, self.fb)
......@@ -268,38 +278,121 @@ class MelScale(object):
return 700. * (10**(mel / 2595.) - 1.)
def SPEC2DB(*args, **kwargs):
    """Deprecated alias: emit a rename warning and build a SpectrogramToDB.

    All positional and keyword arguments are forwarded unchanged.
    """
    warn("SPEC2DB has been renamed to SpectrogramToDB, please update your program")
    return SpectrogramToDB(*args, **kwargs)
class SpectrogramToDB(object):
    """Turns a spectrogram from the power/amplitude scale to the decibel scale.

    This output depends on the maximum value in the input spectrogram, and so
    may return different values for an audio clip split into snippets vs. a
    full clip.

    Args:
        stype (str): scale of input spectrogram ("power" or "magnitude"). The
            power being the elementwise square of the magnitude. default: "power"
        top_db (float, optional): minimum negative cut-off in decibels, given
            as a non-negative number. A reasonable number is 80. ``None``
            (the default) disables the cut-off.

    Raises:
        ValueError: if ``top_db`` is given and is negative.
    """

    def __init__(self, stype="power", top_db=None):
        self.stype = stype
        # top_db is optional, so only check its sign when a value was given;
        # comparing None with 0 would fail for the default argument
        if top_db is not None and top_db < 0:
            raise ValueError('top_db must be positive value')
        self.top_db = top_db
        self.multiplier = 10. if stype == "power" else 20.
        self.amin = 1e-10
        self.ref_value = 1.
        self.db_multiplier = np.log10(np.maximum(self.amin, self.ref_value))

    def __call__(self, spec):
        """Convert ``spec`` to decibels.

        Args:
            spec (Tensor): spectrogram on the scale named by ``stype``

        Returns:
            Tensor: ``spec`` in decibels; when ``top_db`` is set, values are
            floored at ``max(result) - top_db``.
        """
        # numerically stable implementation from librosa
        # https://librosa.github.io/librosa/_modules/librosa/core/spectrum.html
        spec_db = self.multiplier * torch.log10(torch.clamp(spec, min=self.amin))
        spec_db -= self.multiplier * self.db_multiplier

        if self.top_db is not None:
            spec_db = torch.max(spec_db, spec_db.new_full((1,), spec_db.max() - self.top_db))
        return spec_db
def MEL2(*args, **kwargs):
    """Deprecated alias: emit a rename warning and build a MelSpectrogram.

    All positional and keyword arguments are forwarded unchanged.
    """
    warn("MEL2 has been renamed to MelSpectrogram")
    mel_spectrogram = MelSpectrogram(*args, **kwargs)
    return mel_spectrogram
class MFCC(object):
    """Create the Mel-frequency cepstrum coefficients from an audio signal

    By default, this calculates the MFCC on the DB-scaled Mel spectrogram.
    This is not the textbook implementation, but is implemented here to
    give consistency with librosa.

    This output depends on the maximum value in the input spectrogram, and so
    may return different values for an audio clip split into snippets vs. a
    full clip.

    Args:
        sr (int) : sample rate of audio signal
        n_mfcc (int) : number of mfc coefficients to retain
        dct_type (int) : type of DCT (discrete cosine transform) to use.
            Only type 2 is supported.
        norm (string) : 'ortho' for an orthonormal DCT basis; any other value
            (e.g. None) leaves the basis scaled by 2
        log_mels (bool) : whether to use log-mel spectrograms instead of db-scaled
        melkwargs (dict, optional): arguments for MelSpectrogram

    Raises:
        ValueError: if ``dct_type`` is not supported, or if ``n_mfcc`` exceeds
            the number of mel bins of the underlying MelSpectrogram.
    """

    def __init__(self, sr=16000, n_mfcc=40, dct_type=2, norm='ortho', log_mels=False,
                 melkwargs=None):
        supported_dct_types = [2]
        if dct_type not in supported_dct_types:
            # include the offending value so the caller can see what was rejected
            raise ValueError('DCT type not supported: {}'.format(dct_type))
        self.sr = sr
        self.n_mfcc = n_mfcc
        self.dct_type = dct_type
        self.norm = norm
        self.log_mels = log_mels
        self.melkwargs = melkwargs
        self.top_db = 80.
        self.s2db = SpectrogramToDB("power", self.top_db)

        if melkwargs is not None:
            self.MelSpectrogram = MelSpectrogram(sr=self.sr, **melkwargs)
        else:
            self.MelSpectrogram = MelSpectrogram(sr=self.sr)

        if self.n_mfcc > self.MelSpectrogram.n_mels:
            raise ValueError('Cannot select more MFCC coefficients than # mel bins')
        # the DCT basis depends on n_mfcc, n_mels and norm, all fixed by now
        self.dct_mat = self.create_dct()

    def create_dct(self):
        """
        Creates a DCT transformation matrix with shape (num_mels, num_mfcc),
        normalized depending on self.norm

        Returns:
            The transformation matrix, to be right-multiplied to row-wise data.
        """
        outdim = self.n_mfcc
        dim = self.MelSpectrogram.n_mels
        # http://en.wikipedia.org/wiki/Discrete_cosine_transform#DCT-II
        n = np.arange(dim)
        k = np.arange(outdim)[:, np.newaxis]
        dct = np.cos(np.pi / dim * (n + 0.5) * k)  # (outdim, dim)
        if self.norm == 'ortho':
            # the k = 0 row gets an extra 1/sqrt(2) so that all rows have unit norm
            dct[0] *= 1.0 / np.sqrt(2)
            dct *= np.sqrt(2.0 / dim)
        else:
            dct *= 2
        return torch.Tensor(dct.T)

    def __call__(self, sig):
        """
        Args:
            sig (Tensor): Tensor of audio of size (channels [c], samples [n])

        Returns:
            mfcc (Tensor): channels x hops x n_mfcc (c, l, n_mfcc), where channels
                is unchanged, hops is the number of hops, and n_mfcc is the
                number of retained coefficients.
        """
        mel_spect = self.MelSpectrogram(sig)
        if self.log_mels:
            # small offset keeps the logarithm finite for zero-energy bins
            log_offset = 1e-6
            mel_spect = torch.log(mel_spect + log_offset)
        else:
            mel_spect = self.s2db(mel_spect)
        # (c, l, n_mels) dot (n_mels, n_mfcc) -> (c, l, n_mfcc); the DCT
        # matrix must live on the same device as the spectrogram
        mfcc = torch.matmul(mel_spect, self.dct_mat.to(mel_spect.device))
        return mfcc
class MelSpectrogram(object):
......@@ -327,25 +420,24 @@ class MelSpectrogram(object):
>>> sig, sr = torchaudio.load("test.wav", normalization=True)
>>> spec_mel = transforms.MelSpectrogram(sr)(sig) # (c, l, m)
"""
def __init__(self, sr=16000, n_fft=400, ws=None, hop=None, f_min=0., f_max=None,
             pad=0, n_mels=128, window=torch.hann_window, wkwargs=None):
    """Store the parameters and compose Spectrogram followed by MelScale.

    ``ws`` falls back to ``n_fft`` and ``hop`` falls back to ``ws // 2``
    when they are not given. No dB conversion is part of the pipeline; the
    composed transform ends at the mel-scaled power spectrogram.
    """
    self.window = window
    self.sr = sr
    self.n_fft = n_fft
    self.ws = ws if ws is not None else n_fft
    self.hop = hop if hop is not None else self.ws // 2
    self.pad = pad
    self.n_mels = n_mels  # number of mel frequency bins
    self.wkwargs = wkwargs
    self.f_max = f_max
    self.f_min = f_min
    # power spectrogram without magnitude normalisation
    self.spec = Spectrogram(n_fft=self.n_fft, ws=self.ws, hop=self.hop,
                            pad=self.pad, window=self.window, power=2,
                            normalize=False, wkwargs=self.wkwargs)
    self.fm = MelScale(self.n_mels, self.sr, self.f_max, self.f_min)
    self.transforms = Compose([
        self.spec, self.fm
    ])
def __call__(self, sig):
......@@ -354,14 +446,14 @@ class MelSpectrogram(object):
sig (Tensor): Tensor of audio of size (channels [c], samples [n])
Returns:
spec_mel_db (Tensor): channels x hops x n_mels (c, l, m), where channels
spec_mel (Tensor): channels x hops x n_mels (c, l, m), where channels
is unchanged, hops is the number of hops, and n_mels is the
number of mel bins.
"""
spec_mel_db = self.transforms(sig)
spec_mel = self.transforms(sig)
return spec_mel_db
return spec_mel
def MEL(*args, **kwargs):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment