Unverified commit 5521f6c7, authored by Vincent QB and committed by GitHub
Browse files

add mel_scale option for slaney/htk (#593)

parent ecfed4d9
......@@ -46,20 +46,29 @@ class TestFunctional(common_utils.TorchaudioTestCase):
self.assertEqual(ta_out, lr_out, atol=5e-5, rtol=1e-5)
def _test_create_fb(self, n_mels=40, sample_rate=22050, n_fft=2048, fmin=0.0, fmax=8000.0, norm=None):
def _test_create_fb(
self, n_mels=40,
sample_rate=22050,
n_fft=2048,
fmin=0.0,
fmax=8000.0,
norm=None,
mel_scale="htk",
):
librosa_fb = librosa.filters.mel(sr=sample_rate,
n_fft=n_fft,
n_mels=n_mels,
fmax=fmax,
fmin=fmin,
htk=True,
htk=mel_scale == "htk",
norm=norm)
fb = F.create_fb_matrix(sample_rate=sample_rate,
n_mels=n_mels,
f_max=fmax,
f_min=fmin,
n_freqs=(n_fft // 2 + 1),
norm=norm)
norm=norm,
mel_scale=mel_scale)
for i_mel_bank in range(n_mels):
self.assertEqual(
......@@ -73,6 +82,13 @@ class TestFunctional(common_utils.TorchaudioTestCase):
self._test_create_fb(n_mels=56, fmin=800.0, fmax=900.0)
self._test_create_fb(n_mels=56, fmin=1900.0, fmax=900.0)
self._test_create_fb(n_mels=10, fmin=1900.0, fmax=900.0)
self._test_create_fb(mel_scale="slaney")
self._test_create_fb(n_mels=128, sample_rate=44100, mel_scale="slaney")
self._test_create_fb(n_mels=128, fmin=2000.0, fmax=5000.0, mel_scale="slaney")
self._test_create_fb(n_mels=56, fmin=100.0, fmax=9000.0, mel_scale="slaney")
self._test_create_fb(n_mels=56, fmin=800.0, fmax=900.0, mel_scale="slaney")
self._test_create_fb(n_mels=56, fmin=1900.0, fmax=900.0, mel_scale="slaney")
self._test_create_fb(n_mels=10, fmin=1900.0, fmax=900.0, mel_scale="slaney")
if StrictVersion(librosa.__version__) < StrictVersion("0.7.2"):
return
self._test_create_fb(n_mels=128, sample_rate=44100, norm="slaney")
......
......@@ -46,31 +46,32 @@ class TestTransforms(common_utils.TorchaudioTestCase):
self.assertEqual(out_torch, torch.from_numpy(out_librosa), atol=1e-5, rtol=1e-5)
@parameterized.expand([
param(norm=norm, **p.kwargs)
param(norm=norm, mel_scale=mel_scale, **p.kwargs)
for p in [
param(n_fft=400, hop_length=200, n_mels=128),
param(n_fft=600, hop_length=100, n_mels=128),
param(n_fft=200, hop_length=50, n_mels=128),
]
for norm in [None, 'slaney']
for mel_scale in ['htk', 'slaney']
])
def test_mel_spectrogram(self, n_fft, hop_length, n_mels, norm):
def test_mel_spectrogram(self, n_fft, hop_length, n_mels, norm, mel_scale):
sample_rate = 16000
sound = common_utils.get_sinusoid(n_channels=1, sample_rate=sample_rate)
sound_librosa = sound.cpu().numpy().squeeze()
melspect_transform = torchaudio.transforms.MelSpectrogram(
sample_rate=sample_rate, window_fn=torch.hann_window,
hop_length=hop_length, n_mels=n_mels, n_fft=n_fft, norm=norm)
hop_length=hop_length, n_mels=n_mels, n_fft=n_fft, norm=norm, mel_scale=mel_scale)
librosa_mel = librosa.feature.melspectrogram(
y=sound_librosa, sr=sample_rate, n_fft=n_fft,
hop_length=hop_length, n_mels=n_mels, htk=True, norm=norm)
hop_length=hop_length, n_mels=n_mels, htk=mel_scale == "htk", norm=norm)
librosa_mel_tensor = torch.from_numpy(librosa_mel)
torch_mel = melspect_transform(sound).squeeze().cpu()
self.assertEqual(
torch_mel.type(librosa_mel_tensor.dtype), librosa_mel_tensor, atol=5e-3, rtol=1e-5)
@parameterized.expand([
param(norm=norm, **p.kwargs)
param(norm=norm, mel_scale=mel_scale, **p.kwargs)
for p in [
param(n_fft=400, hop_length=200, power=2.0, n_mels=128),
param(n_fft=600, hop_length=100, power=2.0, n_mels=128),
......@@ -79,8 +80,9 @@ class TestTransforms(common_utils.TorchaudioTestCase):
param(n_fft=200, hop_length=50, power=2.0, n_mels=128, skip_ci=True),
]
for norm in [None, 'slaney']
for mel_scale in ['htk', 'slaney']
])
def test_s2db(self, n_fft, hop_length, power, n_mels, norm, skip_ci=False):
def test_s2db(self, n_fft, hop_length, power, n_mels, norm, mel_scale, skip_ci=False):
if skip_ci and 'CI' in os.environ:
self.skipTest('Test is known to fail on CI')
sample_rate = 16000
......@@ -92,10 +94,10 @@ class TestTransforms(common_utils.TorchaudioTestCase):
y=sound_librosa, n_fft=n_fft, hop_length=hop_length, power=power)
melspect_transform = torchaudio.transforms.MelSpectrogram(
sample_rate=sample_rate, window_fn=torch.hann_window,
hop_length=hop_length, n_mels=n_mels, n_fft=n_fft, norm=norm)
hop_length=hop_length, n_mels=n_mels, n_fft=n_fft, norm=norm, mel_scale=mel_scale)
librosa_mel = librosa.feature.melspectrogram(
y=sound_librosa, sr=sample_rate, n_fft=n_fft,
hop_length=hop_length, n_mels=n_mels, htk=True, norm=norm)
hop_length=hop_length, n_mels=n_mels, htk=mel_scale == "htk", norm=norm)
power_to_db_transform = torchaudio.transforms.AmplitudeToDB('power', 80.)
power_to_db_torch = power_to_db_transform(spect_transform(sound)).squeeze().cpu()
......
......@@ -296,13 +296,81 @@ def DB_to_amplitude(
return ref * torch.pow(torch.pow(10.0, 0.1 * x), power)
def _hz_to_mel(freq: float, mel_scale: str = "htk") -> float:
r"""Convert Hz to Mels.
Args:
freqs (float): Frequencies in Hz
mel_scale (str, optional): Scale to use: ``htk`` or ``slaney``. (Default: ``htk``)
Returns:
mels (float): Frequency in Mels
"""
if mel_scale not in ['slaney', 'htk']:
raise ValueError('mel_scale should be one of "htk" or "slaney".')
if mel_scale == "htk":
return 2595.0 * math.log10(1.0 + (freq / 700.0))
# Fill in the linear part
f_min = 0.0
f_sp = 200.0 / 3
mels = (freq - f_min) / f_sp
# Fill in the log-scale part
min_log_hz = 1000.0
min_log_mel = (min_log_hz - f_min) / f_sp
logstep = math.log(6.4) / 27.0
if freq >= min_log_hz:
mels = min_log_mel + math.log(freq / min_log_hz) / logstep
return mels
def _mel_to_hz(mels: Tensor, mel_scale: str = "htk") -> Tensor:
"""Convert mel bin numbers to frequencies.
Args:
mels (Tensor): Mel frequencies
mel_scale (str, optional): Scale to use: ``htk`` or ``slaney``. (Default: ``htk``)
Returns:
freqs (Tensor): Mels converted in Hz
"""
if mel_scale not in ['slaney', 'htk']:
raise ValueError('mel_scale should be one of "htk" or "slaney".')
if mel_scale == "htk":
return 700.0 * (10.0**(mels / 2595.0) - 1.0)
# Fill in the linear scale
f_min = 0.0
f_sp = 200.0 / 3
freqs = f_min + f_sp * mels
# And now the nonlinear scale
min_log_hz = 1000.0
min_log_mel = (min_log_hz - f_min) / f_sp
logstep = math.log(6.4) / 27.0
log_t = (mels >= min_log_mel)
freqs[log_t] = min_log_hz * torch.exp(logstep * (mels[log_t] - min_log_mel))
return freqs
def create_fb_matrix(
n_freqs: int,
f_min: float,
f_max: float,
n_mels: int,
sample_rate: int,
norm: Optional[str] = None
norm: Optional[str] = None,
mel_scale: str = "htk",
) -> Tensor:
r"""Create a frequency bin conversion matrix.
......@@ -314,6 +382,7 @@ def create_fb_matrix(
sample_rate (int): Sample rate of the audio waveform
norm (Optional[str]): If 'slaney', divide the triangular mel weights by the width of the mel band
(area normalization). (Default: ``None``)
mel_scale (str, optional): Scale to use: ``htk`` or ``slaney``. (Default: ``htk``)
Returns:
Tensor: Triangular filter banks (fb matrix) of size (``n_freqs``, ``n_mels``)
......@@ -331,12 +400,12 @@ def create_fb_matrix(
all_freqs = torch.linspace(0, sample_rate // 2, n_freqs)
# calculate mel freq bins
# hertz to mel(f) is 2595. * math.log10(1. + (f / 700.))
m_min = 2595.0 * math.log10(1.0 + (f_min / 700.0))
m_max = 2595.0 * math.log10(1.0 + (f_max / 700.0))
m_min = _hz_to_mel(f_min, mel_scale=mel_scale)
m_max = _hz_to_mel(f_max, mel_scale=mel_scale)
m_pts = torch.linspace(m_min, m_max, n_mels + 2)
# mel to hertz(mel) is 700. * (10**(mel / 2595.) - 1.)
f_pts = 700.0 * (10 ** (m_pts / 2595.0) - 1.0)
f_pts = _mel_to_hz(m_pts, mel_scale=mel_scale)
# calculate the difference between each mel point and each stft freq point in hertz
f_diff = f_pts[1:] - f_pts[:-1] # (n_mels + 1)
slopes = f_pts.unsqueeze(0) - all_freqs.unsqueeze(1) # (n_freqs, n_mels + 2)
......
......@@ -249,6 +249,7 @@ class MelScale(torch.nn.Module):
if None is given. See ``n_fft`` in :class:`Spectrogram`. (Default: ``None``)
norm (Optional[str]): If 'slaney', divide the triangular mel weights by the width of the mel band
(area normalization). (Default: ``None``)
mel_scale (str, optional): Scale to use: ``htk`` or ``slaney``. (Default: ``htk``)
"""
__constants__ = ['n_mels', 'sample_rate', 'f_min', 'f_max']
......@@ -258,18 +259,21 @@ class MelScale(torch.nn.Module):
f_min: float = 0.,
f_max: Optional[float] = None,
n_stft: Optional[int] = None,
norm: Optional[str] = None) -> None:
norm: Optional[str] = None,
mel_scale: str = "htk") -> None:
super(MelScale, self).__init__()
self.n_mels = n_mels
self.sample_rate = sample_rate
self.f_max = f_max if f_max is not None else float(sample_rate // 2)
self.f_min = f_min
self.norm = norm
self.mel_scale = mel_scale
assert f_min <= self.f_max, 'Require f_min: {} < f_max: {}'.format(f_min, self.f_max)
fb = torch.empty(0) if n_stft is None else F.create_fb_matrix(
n_stft, self.f_min, self.f_max, self.n_mels, self.sample_rate, self.norm)
n_stft, self.f_min, self.f_max, self.n_mels, self.sample_rate, self.norm,
self.mel_scale)
self.register_buffer('fb', fb)
def forward(self, specgram: Tensor) -> Tensor:
......@@ -287,7 +291,8 @@ class MelScale(torch.nn.Module):
if self.fb.numel() == 0:
tmp_fb = F.create_fb_matrix(specgram.size(1), self.f_min, self.f_max,
self.n_mels, self.sample_rate, self.norm)
self.n_mels, self.sample_rate, self.norm,
self.mel_scale)
# Attributes cannot be reassigned outside __init__ so workaround
self.fb.resize_(tmp_fb.size())
self.fb.copy_(tmp_fb)
......@@ -321,6 +326,7 @@ class InverseMelScale(torch.nn.Module):
sgdargs (dict or None, optional): Arguments for the SGD optimizer. (Default: ``None``)
norm (Optional[str]): If 'slaney', divide the triangular mel weights by the width of the mel band
(area normalization). (Default: ``None``)
mel_scale (str, optional): Scale to use: ``htk`` or ``slaney``. (Default: ``htk``)
"""
__constants__ = ['n_stft', 'n_mels', 'sample_rate', 'f_min', 'f_max', 'max_iter', 'tolerance_loss',
'tolerance_change', 'sgdargs']
......@@ -335,7 +341,8 @@ class InverseMelScale(torch.nn.Module):
tolerance_loss: float = 1e-5,
tolerance_change: float = 1e-8,
sgdargs: Optional[dict] = None,
norm: Optional[str] = None) -> None:
norm: Optional[str] = None,
mel_scale: str = "htk") -> None:
super(InverseMelScale, self).__init__()
self.n_mels = n_mels
self.sample_rate = sample_rate
......@@ -348,7 +355,8 @@ class InverseMelScale(torch.nn.Module):
assert f_min <= self.f_max, 'Require f_min: {} < f_max: {}'.format(f_min, self.f_max)
fb = F.create_fb_matrix(n_stft, self.f_min, self.f_max, self.n_mels, self.sample_rate, norm)
fb = F.create_fb_matrix(n_stft, self.f_min, self.f_max, self.n_mels, self.sample_rate, norm,
mel_scale)
self.register_buffer('fb', fb)
def forward(self, melspec: Tensor) -> Tensor:
......@@ -427,6 +435,7 @@ class MelSpectrogram(torch.nn.Module):
avoid redundancy. Default: ``True``
norm (Optional[str]): If 'slaney', divide the triangular mel weights by the width of the mel band
(area normalization). (Default: ``None``)
mel_scale (str, optional): Scale to use: ``htk`` or ``slaney``. (Default: ``htk``)
Example
>>> waveform, sample_rate = torchaudio.load('test.wav', normalization=True)
......@@ -450,7 +459,8 @@ class MelSpectrogram(torch.nn.Module):
center: bool = True,
pad_mode: str = "reflect",
onesided: bool = True,
norm: Optional[str] = None) -> None:
norm: Optional[str] = None,
mel_scale: str = "htk") -> None:
super(MelSpectrogram, self).__init__()
self.sample_rate = sample_rate
self.n_fft = n_fft
......@@ -467,7 +477,15 @@ class MelSpectrogram(torch.nn.Module):
pad=self.pad, window_fn=window_fn, power=self.power,
normalized=self.normalized, wkwargs=wkwargs,
center=center, pad_mode=pad_mode, onesided=onesided)
self.mel_scale = MelScale(self.n_mels, self.sample_rate, self.f_min, self.f_max, self.n_fft // 2 + 1, norm)
self.mel_scale = MelScale(
self.n_mels,
self.sample_rate,
self.f_min,
self.f_max,
self.n_fft // 2 + 1,
norm,
mel_scale
)
def forward(self, waveform: Tensor) -> Tensor:
r"""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment