Commit 02ce049c authored by engineerchuan's avatar engineerchuan Committed by Vincent QB
Browse files

Fix mel filter bank (#294)

* Fixed create_fb_matrix filter bank behavior for fmin/fmax
* add better test for f_min close to f_max
* added one more test for f_min > f_max
* adding one more test
parent e8b27c80
...@@ -217,6 +217,36 @@ class TestFunctional(unittest.TestCase): ...@@ -217,6 +217,36 @@ class TestFunctional(unittest.TestCase):
self._test_istft_of_sine(amplitude=80, L=9, n=6) self._test_istft_of_sine(amplitude=80, L=9, n=6)
self._test_istft_of_sine(amplitude=99, L=10, n=7) self._test_istft_of_sine(amplitude=99, L=10, n=7)
def _test_create_fb(self, n_mels=40, sample_rate=22050, n_fft=2048, fmin=0.0, fmax=8000.0):
# Using a decorator here causes parametrize to fail on Python 2
if not IMPORT_LIBROSA:
raise unittest.SkipTest('Librosa is not available')
librosa_fb = librosa.filters.mel(sr=sample_rate,
n_fft=n_fft,
n_mels=n_mels,
fmax=fmax,
fmin=fmin,
htk=True,
norm=None)
fb = F.create_fb_matrix(sample_rate=sample_rate,
n_mels=n_mels,
f_max=fmax,
f_min=fmin,
n_freqs=(n_fft // 2 + 1))
for i_mel_bank in range(n_mels):
assert torch.allclose(fb[:, i_mel_bank], torch.tensor(librosa_fb[i_mel_bank]), atol=1e-4)
def test_create_fb(self):
self._test_create_fb()
self._test_create_fb(n_mels=128, sample_rate=44100)
self._test_create_fb(n_mels=128, fmin=2000.0, fmax=5000.0)
self._test_create_fb(n_mels=56, fmin=100.0, fmax=9000.0)
self._test_create_fb(n_mels=56, fmin=800.0, fmax=900.0)
self._test_create_fb(n_mels=56, fmin=1900.0, fmax=900.0)
self._test_create_fb(n_mels=10, fmin=1900.0, fmax=900.0)
def _num_stft_bins(signal_len, fft_len, hop_length, pad): def _num_stft_bins(signal_len, fft_len, hop_length, pad):
return (signal_len + 2 * pad - fft_len + hop_length) // hop_length return (signal_len + 2 * pad - fft_len + hop_length) // hop_length
......
...@@ -59,17 +59,18 @@ class Test_JIT(unittest.TestCase): ...@@ -59,17 +59,18 @@ class Test_JIT(unittest.TestCase):
def test_torchscript_create_fb_matrix(self): def test_torchscript_create_fb_matrix(self):
@torch.jit.script @torch.jit.script
def jit_method(n_stft, f_min, f_max, n_mels): def jit_method(n_stft, f_min, f_max, n_mels, sample_rate):
# type: (int, float, float, int) -> Tensor # type: (int, float, float, int, int) -> Tensor
return F.create_fb_matrix(n_stft, f_min, f_max, n_mels) return F.create_fb_matrix(n_stft, f_min, f_max, n_mels, sample_rate)
n_stft = 100 n_stft = 100
f_min = 0. f_min = 0.
f_max = 20. f_max = 20.
n_mels = 10 n_mels = 10
sample_rate = 16000
jit_out = jit_method(n_stft, f_min, f_max, n_mels) jit_out = jit_method(n_stft, f_min, f_max, n_mels, sample_rate)
py_out = F.create_fb_matrix(n_stft, f_min, f_max, n_mels) py_out = F.create_fb_matrix(n_stft, f_min, f_max, n_mels, sample_rate)
self.assertTrue(torch.allclose(jit_out, py_out)) self.assertTrue(torch.allclose(jit_out, py_out))
......
...@@ -290,18 +290,19 @@ def amplitude_to_DB(x, multiplier, amin, db_multiplier, top_db=None): ...@@ -290,18 +290,19 @@ def amplitude_to_DB(x, multiplier, amin, db_multiplier, top_db=None):
@torch.jit.script @torch.jit.script
def create_fb_matrix(n_freqs, f_min, f_max, n_mels): def create_fb_matrix(n_freqs, f_min, f_max, n_mels, sample_rate):
# type: (int, float, float, int) -> Tensor # type: (int, float, float, int, int) -> Tensor
r""" r"""
create_fb_matrix(n_freqs, f_min, f_max, n_mels) create_fb_matrix(n_freqs, f_min, f_max, n_mels, sample_rate)
Create a frequency bin conversion matrix. Create a frequency bin conversion matrix.
Args: Args:
n_freqs (int): Number of frequencies to highlight/apply n_freqs (int): Number of frequencies to highlight/apply
f_min (float): Minimum frequency f_min (float): Minimum frequency (Hz)
f_max (float): Maximum frequency f_max (float): Maximum frequency (Hz)
n_mels (int): Number of mel filterbanks n_mels (int): Number of mel filterbanks
sample_rate (int): Sample rate of the audio waveform
Returns: Returns:
torch.Tensor: Triangular filter banks (fb matrix) of size (``n_freqs``, ``n_mels``) torch.Tensor: Triangular filter banks (fb matrix) of size (``n_freqs``, ``n_mels``)
...@@ -311,17 +312,21 @@ def create_fb_matrix(n_freqs, f_min, f_max, n_mels): ...@@ -311,17 +312,21 @@ def create_fb_matrix(n_freqs, f_min, f_max, n_mels):
``A * create_fb_matrix(A.size(-1), ...)``. ``A * create_fb_matrix(A.size(-1), ...)``.
""" """
# freq bins # freq bins
freqs = torch.linspace(f_min, f_max, n_freqs) # Equivalent filterbank construction by Librosa
all_freqs = torch.linspace(0, sample_rate // 2, n_freqs)
i_freqs = all_freqs.ge(f_min) & all_freqs.le(f_max)
freqs = all_freqs[i_freqs]
# calculate mel freq bins # calculate mel freq bins
# hertz to mel(f) is 2595. * math.log10(1. + (f / 700.)) # hertz to mel(f) is 2595. * math.log10(1. + (f / 700.))
m_min = 0.0 if f_min == 0 else 2595.0 * math.log10(1.0 + (f_min / 700.0)) m_min = 2595.0 * math.log10(1.0 + (f_min / 700.0))
m_max = 2595.0 * math.log10(1.0 + (f_max / 700.0)) m_max = 2595.0 * math.log10(1.0 + (f_max / 700.0))
m_pts = torch.linspace(m_min, m_max, n_mels + 2) m_pts = torch.linspace(m_min, m_max, n_mels + 2)
# mel to hertz(mel) is 700. * (10**(mel / 2595.) - 1.) # mel to hertz(mel) is 700. * (10**(mel / 2595.) - 1.)
f_pts = 700.0 * (10 ** (m_pts / 2595.0) - 1.0) f_pts = 700.0 * (10 ** (m_pts / 2595.0) - 1.0)
# calculate the difference between each mel point and each stft freq point in hertz # calculate the difference between each mel point and each stft freq point in hertz
f_diff = f_pts[1:] - f_pts[:-1] # (n_mels + 1) f_diff = f_pts[1:] - f_pts[:-1] # (n_mels + 1)
slopes = f_pts.unsqueeze(0) - freqs.unsqueeze(1) # (n_freqs, n_mels + 2) slopes = f_pts.unsqueeze(0) - all_freqs.unsqueeze(1) # (n_freqs, n_mels + 2)
# create overlapping triangles # create overlapping triangles
zero = torch.zeros(1) zero = torch.zeros(1)
down_slopes = (-1.0 * slopes[:, :-2]) / f_diff[:-1] # (n_freqs, n_mels) down_slopes = (-1.0 * slopes[:, :-2]) / f_diff[:-1] # (n_freqs, n_mels)
......
...@@ -132,7 +132,7 @@ class MelScale(torch.jit.ScriptModule): ...@@ -132,7 +132,7 @@ class MelScale(torch.jit.ScriptModule):
assert f_min <= self.f_max, 'Require f_min: %f < f_max: %f' % (f_min, self.f_max) assert f_min <= self.f_max, 'Require f_min: %f < f_max: %f' % (f_min, self.f_max)
self.f_min = f_min self.f_min = f_min
fb = torch.empty(0) if n_stft is None else F.create_fb_matrix( fb = torch.empty(0) if n_stft is None else F.create_fb_matrix(
n_stft, self.f_min, self.f_max, self.n_mels) n_stft, self.f_min, self.f_max, self.n_mels, self.sample_rate)
self.fb = torch.jit.Attribute(fb, torch.Tensor) self.fb = torch.jit.Attribute(fb, torch.Tensor)
@torch.jit.script_method @torch.jit.script_method
...@@ -145,7 +145,7 @@ class MelScale(torch.jit.ScriptModule): ...@@ -145,7 +145,7 @@ class MelScale(torch.jit.ScriptModule):
torch.Tensor: Mel frequency spectrogram of size (channel, ``n_mels``, time) torch.Tensor: Mel frequency spectrogram of size (channel, ``n_mels``, time)
""" """
if self.fb.numel() == 0: if self.fb.numel() == 0:
tmp_fb = F.create_fb_matrix(specgram.size(1), self.f_min, self.f_max, self.n_mels) tmp_fb = F.create_fb_matrix(specgram.size(1), self.f_min, self.f_max, self.n_mels, self.sample_rate)
# Attributes cannot be reassigned outside __init__ so workaround # Attributes cannot be reassigned outside __init__ so workaround
self.fb.resize_(tmp_fb.size()) self.fb.resize_(tmp_fb.size())
self.fb.copy_(tmp_fb) self.fb.copy_(tmp_fb)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment