Commit 8b616bce authored by David Pollack, committed by Soumith Chintala
Browse files

fix MEL2 and update filterbank conversion matrix

parent a8d6a41b
...@@ -155,13 +155,26 @@ class Tester(unittest.TestCase): ...@@ -155,13 +155,26 @@ class Tester(unittest.TestCase):
audio_orig = self.sig.clone() # (16000, 1) audio_orig = self.sig.clone() # (16000, 1)
audio_scaled = transforms.Scale()(audio_orig) # (16000, 1) audio_scaled = transforms.Scale()(audio_orig) # (16000, 1)
audio_scaled = transforms.LC2CL()(audio_scaled) # (1, 16000) audio_scaled = transforms.LC2CL()(audio_scaled) # (1, 16000)
mel_transform = transforms.MEL2(window=torch.hamming_window, pad=10) mel_transform = transforms.MEL2()
# check defaults
spectrogram_torch = mel_transform(audio_scaled) # (1, 319, 40) spectrogram_torch = mel_transform(audio_scaled) # (1, 319, 40)
self.assertTrue(spectrogram_torch.dim() == 3) self.assertTrue(spectrogram_torch.dim() == 3)
self.assertTrue(spectrogram_torch.le(0.).all()) self.assertTrue(spectrogram_torch.le(0.).all())
self.assertTrue(spectrogram_torch.ge(mel_transform.top_db).all()) self.assertTrue(spectrogram_torch.ge(mel_transform.top_db).all())
self.assertEqual(spectrogram_torch.size(-1), mel_transform.n_mels) self.assertEqual(spectrogram_torch.size(-1), mel_transform.n_mels)
# load stereo file # check correctness of filterbank conversion matrix
self.assertTrue(mel_transform.fm.fb.sum(1).le(1.).all())
self.assertTrue(mel_transform.fm.fb.sum(1).ge(0.).all())
# check options
mel_transform2 = transforms.MEL2(window=torch.hamming_window, pad=10, ws=500, hop=125, n_fft=800, n_mels=50)
spectrogram2_torch = mel_transform2(audio_scaled) # (1, 506, 50)
self.assertTrue(spectrogram2_torch.dim() == 3)
self.assertTrue(spectrogram2_torch.le(0.).all())
self.assertTrue(spectrogram2_torch.ge(mel_transform.top_db).all())
self.assertEqual(spectrogram2_torch.size(-1), mel_transform2.n_mels)
self.assertTrue(mel_transform2.fm.fb.sum(1).le(1.).all())
self.assertTrue(mel_transform2.fm.fb.sum(1).ge(0.).all())
# check on multi-channel audio
x_stereo, sr_stereo = torchaudio.load(self.test_filepath) x_stereo, sr_stereo = torchaudio.load(self.test_filepath)
spectrogram_stereo = mel_transform(x_stereo) spectrogram_stereo = mel_transform(x_stereo)
self.assertTrue(spectrogram_stereo.dim() == 3) self.assertTrue(spectrogram_stereo.dim() == 3)
...@@ -169,6 +182,11 @@ class Tester(unittest.TestCase): ...@@ -169,6 +182,11 @@ class Tester(unittest.TestCase):
self.assertTrue(spectrogram_stereo.le(0.).all()) self.assertTrue(spectrogram_stereo.le(0.).all())
self.assertTrue(spectrogram_stereo.ge(mel_transform.top_db).all()) self.assertTrue(spectrogram_stereo.ge(mel_transform.top_db).all())
self.assertEqual(spectrogram_stereo.size(-1), mel_transform.n_mels) self.assertEqual(spectrogram_stereo.size(-1), mel_transform.n_mels)
# check filterbank matrix creation
fb_matrix_transform = transforms.F2M(n_mels=100, sr=16000, f_max=None, f_min=0., n_stft=400)
self.assertTrue(fb_matrix_transform.fb.sum(1).le(1.).all())
self.assertTrue(fb_matrix_transform.fb.sum(1).ge(0.).all())
self.assertEqual(fb_matrix_transform.fb.size(), (400, 100))
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -160,23 +160,22 @@ class SPECTROGRAM(object): ...@@ -160,23 +160,22 @@ class SPECTROGRAM(object):
Args: Args:
sr (int): sample rate of audio signal sr (int): sample rate of audio signal
ws (int): window size, often called the fft size as well ws (int): window size
hop (int, optional): length of hop between STFT windows. default: ws // 2 hop (int, optional): length of hop between STFT windows. default: ws // 2
n_fft (int, optional): number of fft bins. default: ws // 2 + 1 n_fft (int, optional): size of fft, creates n_fft // 2 + 1 bins. default: ws
pad (int): two sided padding of signal pad (int): two sided padding of signal
window (torch windowing function): default: torch.hann_window window (torch windowing function): default: torch.hann_window
wkwargs (dict, optional): arguments for window function wkwargs (dict, optional): arguments for window function
""" """
def __init__(self, sr=16000, ws=400, hop=None, n_fft=None, def __init__(self, ws=400, hop=None, n_fft=None,
pad=0, window=torch.hann_window, wkwargs=None): pad=0, window=torch.hann_window, wkwargs=None):
self.window = window(ws) if wkwargs is None else window(ws, **wkwargs) self.window = window(ws) if wkwargs is None else window(ws, **wkwargs)
self.sr = sr
self.ws = ws self.ws = ws
self.hop = hop if hop is not None else ws // 2 self.hop = hop if hop is not None else ws // 2
# number of fft bins. the returned STFT result will have n_fft // 2 + 1 # number of fft bins. the returned STFT result will have n_fft // 2 + 1
# number of frequencies due to onesided=True in torch.stft # number of frequencies due to onesided=True in torch.stft
self.n_fft = (n_fft - 1) * 2 if n_fft is not None else ws self.n_fft = n_fft if n_fft is not None else ws
self.pad = pad self.pad = pad
self.wkwargs = wkwargs self.wkwargs = wkwargs
...@@ -212,17 +211,17 @@ class F2M(object): ...@@ -212,17 +211,17 @@ class F2M(object):
Args: Args:
n_mels (int): number of MEL bins n_mels (int): number of MEL bins
sr (int): sample rate of audio signal sr (int): sample rate of audio signal
f_max (float, optional): maximum frequency. default: sr // 2 f_max (float, optional): maximum frequency. default: `sr` // 2
f_min (float): minimum frequency. default: 0 f_min (float): minimum frequency. default: 0
n_fft (int, optional): number of filter banks from stft. Calculated from first input n_stft (int, optional): number of filter banks from stft. Calculated from first input
if `None` is given. if `None` is given. See `n_fft` in `SPECTROGRAM`.
""" """
def __init__(self, n_mels=40, sr=16000, f_max=None, f_min=0., n_fft=None): def __init__(self, n_mels=40, sr=16000, f_max=None, f_min=0., n_stft=None):
self.n_mels = n_mels self.n_mels = n_mels
self.sr = sr self.sr = sr
self.f_max = f_max if f_max is not None else sr // 2 self.f_max = f_max if f_max is not None else sr // 2
self.f_min = f_min self.f_min = f_min
self.fb = self._create_fb_matrix(n_fft) if n_fft is not None else n_fft self.fb = self._create_fb_matrix(n_stft) if n_stft is not None else n_stft
def __call__(self, spec_f): def __call__(self, spec_f):
if self.fb is None: if self.fb is None:
...@@ -230,27 +229,35 @@ class F2M(object): ...@@ -230,27 +229,35 @@ class F2M(object):
spec_m = torch.matmul(spec_f, self.fb) # (c, l, n_fft) dot (n_fft, n_mels) -> (c, l, n_mels) spec_m = torch.matmul(spec_f, self.fb) # (c, l, n_fft) dot (n_fft, n_mels) -> (c, l, n_mels)
return spec_m return spec_m
def _create_fb_matrix(self, n_fft): def _create_fb_matrix(self, n_stft):
""" Create a frequency bin conversion matrix. """ Create a frequency bin conversion matrix.
Args: Args:
n_fft (int): number of filter banks from spectrogram n_stft (int): number of filter banks from spectrogram
""" """
m_min = 0. if self.f_min == 0 else 2595 * np.log10(1. + (self.f_min / 700)) # get stft freq bins
m_max = 2595 * np.log10(1. + (self.f_max / 700)) stft_freqs = torch.linspace(self.f_min, self.f_max, n_stft)
# calculate mel freq bins
m_min = 0. if self.f_min == 0 else self._hertz_to_mel(self.f_min)
m_max = self._hertz_to_mel(self.f_max)
m_pts = torch.linspace(m_min, m_max, self.n_mels + 2) m_pts = torch.linspace(m_min, m_max, self.n_mels + 2)
f_pts = (700 * (10**(m_pts / 2595) - 1)) f_pts = self._mel_to_hertz(m_pts)
# calculate the difference between each mel point and each stft freq point in hertz
f_diff = f_pts[1:] - f_pts[:-1] # (n_mels + 1)
slopes = f_pts.unsqueeze(0) - stft_freqs.unsqueeze(1) # (n_stft, n_mels + 2)
# create overlapping triangles
z = torch.tensor(0.)
down_slopes = (-1. * slopes[:, :-2]) / f_diff[:-1] # (n_stft, n_mels)
up_slopes = slopes[:, 2:] / f_diff[1:] # (n_stft, n_mels)
fb = torch.max(z, torch.min(down_slopes, up_slopes))
return fb
bins = torch.floor(((n_fft - 1) * 2) * f_pts / self.sr).long() def _hertz_to_mel(self, f):
return 2595. * torch.log10(torch.tensor(1.) + (f / 700.))
fb = torch.zeros(n_fft, self.n_mels, dtype=torch.float) def _mel_to_hertz(self, mel):
for m in range(1, self.n_mels + 1): return 700. * (10**(mel / 2595.) - 1.)
f_m_minus = bins[m - 1].item()
f_m_plus = bins[m + 1].item()
fb[f_m_minus:f_m_plus, m - 1] = torch.bartlett_window(f_m_plus - f_m_minus)
return fb
class SPEC2DB(object): class SPEC2DB(object):
...@@ -287,12 +294,12 @@ class MEL2(object): ...@@ -287,12 +294,12 @@ class MEL2(object):
Args: Args:
sr (int): sample rate of audio signal sr (int): sample rate of audio signal
ws (int): window size, often called the fft size as well ws (int): window size
hop (int, optional): length of hop between STFT windows. default: ws // 2 hop (int, optional): length of hop between STFT windows. default: `ws` // 2
n_fft (int, optional): number of fft bins. default: ws // 2 + 1 n_fft (int, optional): number of fft bins. default: `ws` // 2 + 1
pad (int): two sided padding of signal pad (int): two sided padding of signal
n_mels (int): number of MEL bins n_mels (int): number of MEL bins
window (torch windowing function): default: torch.hann_window window (torch windowing function): default: `torch.hann_window`
wkwargs (dict, optional): arguments for window function wkwargs (dict, optional): arguments for window function
Example: Example:
...@@ -312,9 +319,9 @@ class MEL2(object): ...@@ -312,9 +319,9 @@ class MEL2(object):
self.top_db = -80. self.top_db = -80.
self.f_max = None self.f_max = None
self.f_min = 0. self.f_min = 0.
self.spec = SPECTROGRAM(self.sr, self.ws, self.hop, self.n_fft, self.spec = SPECTROGRAM(self.ws, self.hop, self.n_fft,
self.pad, self.window, self.wkwargs) self.pad, self.window, self.wkwargs)
self.fm = F2M(self.n_mels, self.sr, self.f_max, self.f_min, self.n_fft) self.fm = F2M(self.n_mels, self.sr, self.f_max, self.f_min)
self.s2db = SPEC2DB("power", self.top_db) self.s2db = SPEC2DB("power", self.top_db)
self.transforms = Compose([ self.transforms = Compose([
self.spec, self.fm, self.s2db, self.spec, self.fm, self.s2db,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment