Unverified Commit 83dc5ec7 authored by Joel Frank's avatar Joel Frank Committed by GitHub
Browse files

Add melscale_fbanks and deprecate create_fb_matrix (#1653)

- Renamed torchaudio.functional.create_fb_matrix to torchaudio.functional.melscale_fbanks.
- Added interface with a warning for create_fb_matrix
parent 8a347b62
...@@ -26,6 +26,11 @@ create_fb_matrix ...@@ -26,6 +26,11 @@ create_fb_matrix
.. autofunction:: create_fb_matrix .. autofunction:: create_fb_matrix
melscale_fbanks
---------------
.. autofunction:: melscale_fbanks
linear_fbanks linear_fbanks
------------- -------------
......
...@@ -438,20 +438,20 @@ class Functional(TestBaseMixin): ...@@ -438,20 +438,20 @@ class Functional(TestBaseMixin):
class FunctionalCPUOnly(TestBaseMixin): class FunctionalCPUOnly(TestBaseMixin):
def test_create_fb_matrix_no_warning_high_n_freq(self): def test_melscale_fbanks_no_warning_high_n_freq(self):
with warnings.catch_warnings(record=True) as w: with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always") warnings.simplefilter("always")
F.create_fb_matrix(288, 0, 8000, 128, 16000) F.melscale_fbanks(288, 0, 8000, 128, 16000)
assert len(w) == 0 assert len(w) == 0
def test_create_fb_matrix_no_warning_low_n_mels(self): def test_melscale_fbanks_no_warning_low_n_mels(self):
with warnings.catch_warnings(record=True) as w: with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always") warnings.simplefilter("always")
F.create_fb_matrix(201, 0, 8000, 89, 16000) F.melscale_fbanks(201, 0, 8000, 89, 16000)
assert len(w) == 0 assert len(w) == 0
def test_create_fb_matrix_warning(self): def test_melscale_fbanks_warning(self):
with warnings.catch_warnings(record=True) as w: with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always") warnings.simplefilter("always")
F.create_fb_matrix(201, 0, 8000, 128, 16000) F.melscale_fbanks(201, 0, 8000, 128, 16000)
assert len(w) == 1 assert len(w) == 1
...@@ -76,8 +76,8 @@ class Functional(TestBaseMixin): ...@@ -76,8 +76,8 @@ class Functional(TestBaseMixin):
[param(norm=n) for n in [None, 'slaney']], [param(norm=n) for n in [None, 'slaney']],
[param(mel_scale=s) for s in ['htk', 'slaney']], [param(mel_scale=s) for s in ['htk', 'slaney']],
) )
def test_create_fb(self, n_mels=40, sample_rate=22050, n_fft=2048, def test_create_mel_fb(self, n_mels=40, sample_rate=22050, n_fft=2048,
fmin=0.0, fmax=8000.0, norm=None, mel_scale="htk"): fmin=0.0, fmax=8000.0, norm=None, mel_scale="htk"):
if (norm == "slaney" and StrictVersion(librosa.__version__) < StrictVersion("0.7.2")): if (norm == "slaney" and StrictVersion(librosa.__version__) < StrictVersion("0.7.2")):
self.skipTest('Test is known to fail with older versions of librosa.') self.skipTest('Test is known to fail with older versions of librosa.')
if self.device != 'cpu': if self.device != 'cpu':
...@@ -91,7 +91,7 @@ class Functional(TestBaseMixin): ...@@ -91,7 +91,7 @@ class Functional(TestBaseMixin):
fmin=fmin, fmin=fmin,
htk=mel_scale == "htk", htk=mel_scale == "htk",
norm=norm).T norm=norm).T
result = F.create_fb_matrix( result = F.melscale_fbanks(
sample_rate=sample_rate, sample_rate=sample_rate,
n_mels=n_mels, n_mels=n_mels,
f_max=fmax, f_max=fmax,
......
...@@ -116,7 +116,7 @@ class Functional(TempDirMixin, TestBaseMixin): ...@@ -116,7 +116,7 @@ class Functional(TempDirMixin, TestBaseMixin):
self._assert_consistency(func, waveform) self._assert_consistency(func, waveform)
def test_create_fb_matrix(self): def test_melscale_fbanks(self):
if self.device != torch.device('cpu'): if self.device != torch.device('cpu'):
raise unittest.SkipTest('No need to perform test on device other than CPU') raise unittest.SkipTest('No need to perform test on device other than CPU')
...@@ -127,7 +127,7 @@ class Functional(TempDirMixin, TestBaseMixin): ...@@ -127,7 +127,7 @@ class Functional(TempDirMixin, TestBaseMixin):
n_mels = 10 n_mels = 10
sample_rate = 16000 sample_rate = 16000
norm = "slaney" norm = "slaney"
return F.create_fb_matrix(n_stft, f_min, f_max, n_mels, sample_rate, norm) return F.melscale_fbanks(n_stft, f_min, f_max, n_mels, sample_rate, norm)
dummy = torch.zeros(1, 1) dummy = torch.zeros(1, 1)
self._assert_consistency(func, dummy) self._assert_consistency(func, dummy)
......
...@@ -6,6 +6,7 @@ from .functional import ( ...@@ -6,6 +6,7 @@ from .functional import (
compute_kaldi_pitch, compute_kaldi_pitch,
create_dct, create_dct,
create_fb_matrix, create_fb_matrix,
melscale_fbanks,
linear_fbanks, linear_fbanks,
DB_to_amplitude, DB_to_amplitude,
detect_pitch_frequency, detect_pitch_frequency,
...@@ -56,6 +57,7 @@ __all__ = [ ...@@ -56,6 +57,7 @@ __all__ = [
'compute_kaldi_pitch', 'compute_kaldi_pitch',
'create_dct', 'create_dct',
'create_fb_matrix', 'create_fb_matrix',
'melscale_fbanks',
'linear_fbanks', 'linear_fbanks',
'DB_to_amplitude', 'DB_to_amplitude',
'detect_pitch_frequency', 'detect_pitch_frequency',
......
...@@ -19,6 +19,7 @@ __all__ = [ ...@@ -19,6 +19,7 @@ __all__ = [
"compute_deltas", "compute_deltas",
"compute_kaldi_pitch", "compute_kaldi_pitch",
"create_fb_matrix", "create_fb_matrix",
"melscale_fbanks",
"linear_fbanks", "linear_fbanks",
"create_dct", "create_dct",
"compute_deltas", "compute_deltas",
...@@ -431,6 +432,52 @@ def create_fb_matrix( ...@@ -431,6 +432,52 @@ def create_fb_matrix(
size (..., ``n_freqs``), the applied result would be size (..., ``n_freqs``), the applied result would be
``A * create_fb_matrix(A.size(-1), ...)``. ``A * create_fb_matrix(A.size(-1), ...)``.
""" """
warnings.warn(
"The use of `create_fb_matrix` is now deprecated and will be removed in "
"the 0.11 release. "
"Please migrate your code to use `melscale_fbanks` instead. "
"For more information, please refer to https://github.com/pytorch/audio/issues/1574."
)
return melscale_fbanks(
n_freqs=n_freqs,
f_min=f_min,
f_max=f_max,
n_mels=n_mels,
sample_rate=sample_rate,
norm=norm,
mel_scale=mel_scale
)
def melscale_fbanks(
n_freqs: int,
f_min: float,
f_max: float,
n_mels: int,
sample_rate: int,
norm: Optional[str] = None,
mel_scale: str = "htk",
) -> Tensor:
r"""Create a frequency bin conversion matrix.
Args:
n_freqs (int): Number of frequencies to highlight/apply
f_min (float): Minimum frequency (Hz)
f_max (float): Maximum frequency (Hz)
n_mels (int): Number of mel filterbanks
sample_rate (int): Sample rate of the audio waveform
norm (Optional[str]): If 'slaney', divide the triangular mel weights by the width of the mel band
(area normalization). (Default: ``None``)
mel_scale (str, optional): Scale to use: ``htk`` or ``slaney``. (Default: ``htk``)
Returns:
Tensor: Triangular filter banks (fb matrix) of size (``n_freqs``, ``n_mels``)
meaning number of frequencies to highlight/apply to x the number of filterbanks.
Each column is a filterbank so that assuming there is a matrix A of
size (..., ``n_freqs``), the applied result would be
``A * melscale_fbanks(A.size(-1), ...)``.
"""
if norm is not None and norm != "slaney": if norm is not None and norm != "slaney":
raise ValueError("norm must be one of None or 'slaney'") raise ValueError("norm must be one of None or 'slaney'")
......
...@@ -269,7 +269,7 @@ class MelScale(torch.nn.Module): ...@@ -269,7 +269,7 @@ class MelScale(torch.nn.Module):
self.mel_scale = mel_scale self.mel_scale = mel_scale
assert f_min <= self.f_max, 'Require f_min: {} < f_max: {}'.format(f_min, self.f_max) assert f_min <= self.f_max, 'Require f_min: {} < f_max: {}'.format(f_min, self.f_max)
fb = F.create_fb_matrix( fb = F.melscale_fbanks(
n_stft, self.f_min, self.f_max, self.n_mels, self.sample_rate, self.norm, n_stft, self.f_min, self.f_max, self.n_mels, self.sample_rate, self.norm,
self.mel_scale) self.mel_scale)
self.register_buffer('fb', fb) self.register_buffer('fb', fb)
...@@ -337,8 +337,8 @@ class InverseMelScale(torch.nn.Module): ...@@ -337,8 +337,8 @@ class InverseMelScale(torch.nn.Module):
assert f_min <= self.f_max, 'Require f_min: {} < f_max: {}'.format(f_min, self.f_max) assert f_min <= self.f_max, 'Require f_min: {} < f_max: {}'.format(f_min, self.f_max)
fb = F.create_fb_matrix(n_stft, self.f_min, self.f_max, self.n_mels, self.sample_rate, norm, fb = F.melscale_fbanks(n_stft, self.f_min, self.f_max, self.n_mels, self.sample_rate,
mel_scale) norm, mel_scale)
self.register_buffer('fb', fb) self.register_buffer('fb', fb)
def forward(self, melspec: Tensor) -> Tensor: def forward(self, melspec: Tensor) -> Tensor:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment