Unverified Commit b985609b authored by Vincent QB's avatar Vincent QB Committed by GitHub
Browse files

Make parameter optional string (#641)

* make parameter optional string.

* raise error if incorrect parameter choice.

* test with slaney.
parent 241ab1e8
...@@ -83,7 +83,7 @@ class Functional(common_utils.TestBaseMixin): ...@@ -83,7 +83,7 @@ class Functional(common_utils.TestBaseMixin):
f_max = 20.0 f_max = 20.0
n_mels = 10 n_mels = 10
sample_rate = 16000 sample_rate = 16000
norm = "" norm = "slaney"
return F.create_fb_matrix(n_stft, f_min, f_max, n_mels, sample_rate, norm) return F.create_fb_matrix(n_stft, f_min, f_max, n_mels, sample_rate, norm)
dummy = torch.zeros(1, 1) dummy = torch.zeros(1, 1)
......
...@@ -336,7 +336,7 @@ def create_fb_matrix( ...@@ -336,7 +336,7 @@ def create_fb_matrix(
f_max: float, f_max: float,
n_mels: int, n_mels: int,
sample_rate: int, sample_rate: int,
norm: str = "", norm: Optional[str] = None
) -> Tensor: ) -> Tensor:
r"""Create a frequency bin conversion matrix. r"""Create a frequency bin conversion matrix.
...@@ -346,8 +346,8 @@ def create_fb_matrix( ...@@ -346,8 +346,8 @@ def create_fb_matrix(
f_max (float): Maximum frequency (Hz) f_max (float): Maximum frequency (Hz)
n_mels (int): Number of mel filterbanks n_mels (int): Number of mel filterbanks
sample_rate (int): Sample rate of the audio waveform sample_rate (int): Sample rate of the audio waveform
norm (str): If 'slaney', divide the triangular mel weights by the width of the mel band norm (Optional[str]): If 'slaney', divide the triangular mel weights by the width of the mel band
(area normalization). (Default: '') (area normalization). (Default: ``None``)
Returns: Returns:
Tensor: Triangular filter banks (fb matrix) of size (``n_freqs``, ``n_mels``) Tensor: Triangular filter banks (fb matrix) of size (``n_freqs``, ``n_mels``)
...@@ -356,6 +356,10 @@ def create_fb_matrix( ...@@ -356,6 +356,10 @@ def create_fb_matrix(
size (..., ``n_freqs``), the applied result would be size (..., ``n_freqs``), the applied result would be
``A * create_fb_matrix(A.size(-1), ...)``. ``A * create_fb_matrix(A.size(-1), ...)``.
""" """
if norm is not None and norm != "slaney":
raise ValueError("norm must be one of None or 'slaney'")
# freq bins # freq bins
# Equivalent filterbank construction by Librosa # Equivalent filterbank construction by Librosa
all_freqs = torch.linspace(0, sample_rate // 2, n_freqs) all_freqs = torch.linspace(0, sample_rate // 2, n_freqs)
...@@ -376,7 +380,7 @@ def create_fb_matrix( ...@@ -376,7 +380,7 @@ def create_fb_matrix(
up_slopes = slopes[:, 2:] / f_diff[1:] # (n_freqs, n_mels) up_slopes = slopes[:, 2:] / f_diff[1:] # (n_freqs, n_mels)
fb = torch.max(zero, torch.min(down_slopes, up_slopes)) fb = torch.max(zero, torch.min(down_slopes, up_slopes))
if norm == "slaney": if norm is not None and norm == "slaney":
# Slaney-style mel is scaled to be approx constant energy per channel # Slaney-style mel is scaled to be approx constant energy per channel
enorm = 2.0 / (f_pts[2:n_mels + 2] - f_pts[:n_mels]) enorm = 2.0 / (f_pts[2:n_mels + 2] - f_pts[:n_mels])
fb *= enorm.unsqueeze(0) fb *= enorm.unsqueeze(0)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment