Unverified Commit 83dc5ec7 authored by Joel Frank's avatar Joel Frank Committed by GitHub
Browse files

Add melscale_fbanks and deprecate create_fb_matrix (#1653)

- Renamed torchaudio.functional.create_fb_matrix to torchaudio.functional.melscale_fbanks.
- Added interface with a warning for create_fb_matrix
parent 8a347b62
......@@ -26,6 +26,11 @@ create_fb_matrix
.. autofunction:: create_fb_matrix
melscale_fbanks
---------------
.. autofunction:: melscale_fbanks
linear_fbanks
-------------
......
......@@ -438,20 +438,20 @@ class Functional(TestBaseMixin):
class FunctionalCPUOnly(TestBaseMixin):
def test_create_fb_matrix_no_warning_high_n_freq(self):
def test_melscale_fbanks_no_warning_high_n_freq(self):
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
F.create_fb_matrix(288, 0, 8000, 128, 16000)
F.melscale_fbanks(288, 0, 8000, 128, 16000)
assert len(w) == 0
def test_create_fb_matrix_no_warning_low_n_mels(self):
def test_melscale_fbanks_no_warning_low_n_mels(self):
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
F.create_fb_matrix(201, 0, 8000, 89, 16000)
F.melscale_fbanks(201, 0, 8000, 89, 16000)
assert len(w) == 0
def test_create_fb_matrix_warning(self):
def test_melscale_fbanks_warning(self):
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
F.create_fb_matrix(201, 0, 8000, 128, 16000)
F.melscale_fbanks(201, 0, 8000, 128, 16000)
assert len(w) == 1
......@@ -76,8 +76,8 @@ class Functional(TestBaseMixin):
[param(norm=n) for n in [None, 'slaney']],
[param(mel_scale=s) for s in ['htk', 'slaney']],
)
def test_create_fb(self, n_mels=40, sample_rate=22050, n_fft=2048,
fmin=0.0, fmax=8000.0, norm=None, mel_scale="htk"):
def test_create_mel_fb(self, n_mels=40, sample_rate=22050, n_fft=2048,
fmin=0.0, fmax=8000.0, norm=None, mel_scale="htk"):
if (norm == "slaney" and StrictVersion(librosa.__version__) < StrictVersion("0.7.2")):
self.skipTest('Test is known to fail with older versions of librosa.')
if self.device != 'cpu':
......@@ -91,7 +91,7 @@ class Functional(TestBaseMixin):
fmin=fmin,
htk=mel_scale == "htk",
norm=norm).T
result = F.create_fb_matrix(
result = F.melscale_fbanks(
sample_rate=sample_rate,
n_mels=n_mels,
f_max=fmax,
......
......@@ -116,7 +116,7 @@ class Functional(TempDirMixin, TestBaseMixin):
self._assert_consistency(func, waveform)
def test_create_fb_matrix(self):
def test_melscale_fbanks(self):
if self.device != torch.device('cpu'):
raise unittest.SkipTest('No need to perform test on device other than CPU')
......@@ -127,7 +127,7 @@ class Functional(TempDirMixin, TestBaseMixin):
n_mels = 10
sample_rate = 16000
norm = "slaney"
return F.create_fb_matrix(n_stft, f_min, f_max, n_mels, sample_rate, norm)
return F.melscale_fbanks(n_stft, f_min, f_max, n_mels, sample_rate, norm)
dummy = torch.zeros(1, 1)
self._assert_consistency(func, dummy)
......
......@@ -6,6 +6,7 @@ from .functional import (
compute_kaldi_pitch,
create_dct,
create_fb_matrix,
melscale_fbanks,
linear_fbanks,
DB_to_amplitude,
detect_pitch_frequency,
......@@ -56,6 +57,7 @@ __all__ = [
'compute_kaldi_pitch',
'create_dct',
'create_fb_matrix',
'melscale_fbanks',
'linear_fbanks',
'DB_to_amplitude',
'detect_pitch_frequency',
......
......@@ -19,6 +19,7 @@ __all__ = [
"compute_deltas",
"compute_kaldi_pitch",
"create_fb_matrix",
"melscale_fbanks",
"linear_fbanks",
"create_dct",
"compute_deltas",
......@@ -431,6 +432,52 @@ def create_fb_matrix(
size (..., ``n_freqs``), the applied result would be
``A * create_fb_matrix(A.size(-1), ...)``.
"""
warnings.warn(
"The use of `create_fb_matrix` is now deprecated and will be removed in "
"the 0.11 release. "
"Please migrate your code to use `melscale_fbanks` instead. "
"For more information, please refer to https://github.com/pytorch/audio/issues/1574."
)
return melscale_fbanks(
n_freqs=n_freqs,
f_min=f_min,
f_max=f_max,
n_mels=n_mels,
sample_rate=sample_rate,
norm=norm,
mel_scale=mel_scale
)
def melscale_fbanks(
n_freqs: int,
f_min: float,
f_max: float,
n_mels: int,
sample_rate: int,
norm: Optional[str] = None,
mel_scale: str = "htk",
) -> Tensor:
r"""Create a frequency bin conversion matrix.
Args:
n_freqs (int): Number of frequencies to highlight/apply
f_min (float): Minimum frequency (Hz)
f_max (float): Maximum frequency (Hz)
n_mels (int): Number of mel filterbanks
sample_rate (int): Sample rate of the audio waveform
norm (Optional[str]): If 'slaney', divide the triangular mel weights by the width of the mel band
(area normalization). (Default: ``None``)
mel_scale (str, optional): Scale to use: ``htk`` or ``slaney``. (Default: ``htk``)
Returns:
Tensor: Triangular filter banks (fb matrix) of size (``n_freqs``, ``n_mels``)
meaning number of frequencies to highlight/apply to x the number of filterbanks.
Each column is a filterbank so that assuming there is a matrix A of
size (..., ``n_freqs``), the applied result would be
``A * melscale_fbanks(A.size(-1), ...)``.
"""
if norm is not None and norm != "slaney":
raise ValueError("norm must be one of None or 'slaney'")
......
......@@ -269,7 +269,7 @@ class MelScale(torch.nn.Module):
self.mel_scale = mel_scale
assert f_min <= self.f_max, 'Require f_min: {} < f_max: {}'.format(f_min, self.f_max)
fb = F.create_fb_matrix(
fb = F.melscale_fbanks(
n_stft, self.f_min, self.f_max, self.n_mels, self.sample_rate, self.norm,
self.mel_scale)
self.register_buffer('fb', fb)
......@@ -337,8 +337,8 @@ class InverseMelScale(torch.nn.Module):
assert f_min <= self.f_max, 'Require f_min: {} < f_max: {}'.format(f_min, self.f_max)
fb = F.create_fb_matrix(n_stft, self.f_min, self.f_max, self.n_mels, self.sample_rate, norm,
mel_scale)
fb = F.melscale_fbanks(n_stft, self.f_min, self.f_max, self.n_mels, self.sample_rate,
norm, mel_scale)
self.register_buffer('fb', fb)
def forward(self, melspec: Tensor) -> Tensor:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment