Unverified Commit 5521f6c7 authored by Vincent QB's avatar Vincent QB Committed by GitHub
Browse files

add mel_scale option for slaney/htk (#593)

parent ecfed4d9
...@@ -46,20 +46,29 @@ class TestFunctional(common_utils.TorchaudioTestCase): ...@@ -46,20 +46,29 @@ class TestFunctional(common_utils.TorchaudioTestCase):
self.assertEqual(ta_out, lr_out, atol=5e-5, rtol=1e-5) self.assertEqual(ta_out, lr_out, atol=5e-5, rtol=1e-5)
def _test_create_fb(self, n_mels=40, sample_rate=22050, n_fft=2048, fmin=0.0, fmax=8000.0, norm=None): def _test_create_fb(
self, n_mels=40,
sample_rate=22050,
n_fft=2048,
fmin=0.0,
fmax=8000.0,
norm=None,
mel_scale="htk",
):
librosa_fb = librosa.filters.mel(sr=sample_rate, librosa_fb = librosa.filters.mel(sr=sample_rate,
n_fft=n_fft, n_fft=n_fft,
n_mels=n_mels, n_mels=n_mels,
fmax=fmax, fmax=fmax,
fmin=fmin, fmin=fmin,
htk=True, htk=mel_scale == "htk",
norm=norm) norm=norm)
fb = F.create_fb_matrix(sample_rate=sample_rate, fb = F.create_fb_matrix(sample_rate=sample_rate,
n_mels=n_mels, n_mels=n_mels,
f_max=fmax, f_max=fmax,
f_min=fmin, f_min=fmin,
n_freqs=(n_fft // 2 + 1), n_freqs=(n_fft // 2 + 1),
norm=norm) norm=norm,
mel_scale=mel_scale)
for i_mel_bank in range(n_mels): for i_mel_bank in range(n_mels):
self.assertEqual( self.assertEqual(
...@@ -73,6 +82,13 @@ class TestFunctional(common_utils.TorchaudioTestCase): ...@@ -73,6 +82,13 @@ class TestFunctional(common_utils.TorchaudioTestCase):
self._test_create_fb(n_mels=56, fmin=800.0, fmax=900.0) self._test_create_fb(n_mels=56, fmin=800.0, fmax=900.0)
self._test_create_fb(n_mels=56, fmin=1900.0, fmax=900.0) self._test_create_fb(n_mels=56, fmin=1900.0, fmax=900.0)
self._test_create_fb(n_mels=10, fmin=1900.0, fmax=900.0) self._test_create_fb(n_mels=10, fmin=1900.0, fmax=900.0)
self._test_create_fb(mel_scale="slaney")
self._test_create_fb(n_mels=128, sample_rate=44100, mel_scale="slaney")
self._test_create_fb(n_mels=128, fmin=2000.0, fmax=5000.0, mel_scale="slaney")
self._test_create_fb(n_mels=56, fmin=100.0, fmax=9000.0, mel_scale="slaney")
self._test_create_fb(n_mels=56, fmin=800.0, fmax=900.0, mel_scale="slaney")
self._test_create_fb(n_mels=56, fmin=1900.0, fmax=900.0, mel_scale="slaney")
self._test_create_fb(n_mels=10, fmin=1900.0, fmax=900.0, mel_scale="slaney")
if StrictVersion(librosa.__version__) < StrictVersion("0.7.2"): if StrictVersion(librosa.__version__) < StrictVersion("0.7.2"):
return return
self._test_create_fb(n_mels=128, sample_rate=44100, norm="slaney") self._test_create_fb(n_mels=128, sample_rate=44100, norm="slaney")
......
...@@ -46,31 +46,32 @@ class TestTransforms(common_utils.TorchaudioTestCase): ...@@ -46,31 +46,32 @@ class TestTransforms(common_utils.TorchaudioTestCase):
self.assertEqual(out_torch, torch.from_numpy(out_librosa), atol=1e-5, rtol=1e-5) self.assertEqual(out_torch, torch.from_numpy(out_librosa), atol=1e-5, rtol=1e-5)
@parameterized.expand([ @parameterized.expand([
param(norm=norm, **p.kwargs) param(norm=norm, mel_scale=mel_scale, **p.kwargs)
for p in [ for p in [
param(n_fft=400, hop_length=200, n_mels=128), param(n_fft=400, hop_length=200, n_mels=128),
param(n_fft=600, hop_length=100, n_mels=128), param(n_fft=600, hop_length=100, n_mels=128),
param(n_fft=200, hop_length=50, n_mels=128), param(n_fft=200, hop_length=50, n_mels=128),
] ]
for norm in [None, 'slaney'] for norm in [None, 'slaney']
for mel_scale in ['htk', 'slaney']
]) ])
def test_mel_spectrogram(self, n_fft, hop_length, n_mels, norm): def test_mel_spectrogram(self, n_fft, hop_length, n_mels, norm, mel_scale):
sample_rate = 16000 sample_rate = 16000
sound = common_utils.get_sinusoid(n_channels=1, sample_rate=sample_rate) sound = common_utils.get_sinusoid(n_channels=1, sample_rate=sample_rate)
sound_librosa = sound.cpu().numpy().squeeze() sound_librosa = sound.cpu().numpy().squeeze()
melspect_transform = torchaudio.transforms.MelSpectrogram( melspect_transform = torchaudio.transforms.MelSpectrogram(
sample_rate=sample_rate, window_fn=torch.hann_window, sample_rate=sample_rate, window_fn=torch.hann_window,
hop_length=hop_length, n_mels=n_mels, n_fft=n_fft, norm=norm) hop_length=hop_length, n_mels=n_mels, n_fft=n_fft, norm=norm, mel_scale=mel_scale)
librosa_mel = librosa.feature.melspectrogram( librosa_mel = librosa.feature.melspectrogram(
y=sound_librosa, sr=sample_rate, n_fft=n_fft, y=sound_librosa, sr=sample_rate, n_fft=n_fft,
hop_length=hop_length, n_mels=n_mels, htk=True, norm=norm) hop_length=hop_length, n_mels=n_mels, htk=mel_scale == "htk", norm=norm)
librosa_mel_tensor = torch.from_numpy(librosa_mel) librosa_mel_tensor = torch.from_numpy(librosa_mel)
torch_mel = melspect_transform(sound).squeeze().cpu() torch_mel = melspect_transform(sound).squeeze().cpu()
self.assertEqual( self.assertEqual(
torch_mel.type(librosa_mel_tensor.dtype), librosa_mel_tensor, atol=5e-3, rtol=1e-5) torch_mel.type(librosa_mel_tensor.dtype), librosa_mel_tensor, atol=5e-3, rtol=1e-5)
@parameterized.expand([ @parameterized.expand([
param(norm=norm, **p.kwargs) param(norm=norm, mel_scale=mel_scale, **p.kwargs)
for p in [ for p in [
param(n_fft=400, hop_length=200, power=2.0, n_mels=128), param(n_fft=400, hop_length=200, power=2.0, n_mels=128),
param(n_fft=600, hop_length=100, power=2.0, n_mels=128), param(n_fft=600, hop_length=100, power=2.0, n_mels=128),
...@@ -79,8 +80,9 @@ class TestTransforms(common_utils.TorchaudioTestCase): ...@@ -79,8 +80,9 @@ class TestTransforms(common_utils.TorchaudioTestCase):
param(n_fft=200, hop_length=50, power=2.0, n_mels=128, skip_ci=True), param(n_fft=200, hop_length=50, power=2.0, n_mels=128, skip_ci=True),
] ]
for norm in [None, 'slaney'] for norm in [None, 'slaney']
for mel_scale in ['htk', 'slaney']
]) ])
def test_s2db(self, n_fft, hop_length, power, n_mels, norm, skip_ci=False): def test_s2db(self, n_fft, hop_length, power, n_mels, norm, mel_scale, skip_ci=False):
if skip_ci and 'CI' in os.environ: if skip_ci and 'CI' in os.environ:
self.skipTest('Test is known to fail on CI') self.skipTest('Test is known to fail on CI')
sample_rate = 16000 sample_rate = 16000
...@@ -92,10 +94,10 @@ class TestTransforms(common_utils.TorchaudioTestCase): ...@@ -92,10 +94,10 @@ class TestTransforms(common_utils.TorchaudioTestCase):
y=sound_librosa, n_fft=n_fft, hop_length=hop_length, power=power) y=sound_librosa, n_fft=n_fft, hop_length=hop_length, power=power)
melspect_transform = torchaudio.transforms.MelSpectrogram( melspect_transform = torchaudio.transforms.MelSpectrogram(
sample_rate=sample_rate, window_fn=torch.hann_window, sample_rate=sample_rate, window_fn=torch.hann_window,
hop_length=hop_length, n_mels=n_mels, n_fft=n_fft, norm=norm) hop_length=hop_length, n_mels=n_mels, n_fft=n_fft, norm=norm, mel_scale=mel_scale)
librosa_mel = librosa.feature.melspectrogram( librosa_mel = librosa.feature.melspectrogram(
y=sound_librosa, sr=sample_rate, n_fft=n_fft, y=sound_librosa, sr=sample_rate, n_fft=n_fft,
hop_length=hop_length, n_mels=n_mels, htk=True, norm=norm) hop_length=hop_length, n_mels=n_mels, htk=mel_scale == "htk", norm=norm)
power_to_db_transform = torchaudio.transforms.AmplitudeToDB('power', 80.) power_to_db_transform = torchaudio.transforms.AmplitudeToDB('power', 80.)
power_to_db_torch = power_to_db_transform(spect_transform(sound)).squeeze().cpu() power_to_db_torch = power_to_db_transform(spect_transform(sound)).squeeze().cpu()
......
...@@ -296,13 +296,81 @@ def DB_to_amplitude( ...@@ -296,13 +296,81 @@ def DB_to_amplitude(
return ref * torch.pow(torch.pow(10.0, 0.1 * x), power) return ref * torch.pow(torch.pow(10.0, 0.1 * x), power)
def _hz_to_mel(freq: float, mel_scale: str = "htk") -> float:
r"""Convert Hz to Mels.
Args:
freqs (float): Frequencies in Hz
mel_scale (str, optional): Scale to use: ``htk`` or ``slaney``. (Default: ``htk``)
Returns:
mels (float): Frequency in Mels
"""
if mel_scale not in ['slaney', 'htk']:
raise ValueError('mel_scale should be one of "htk" or "slaney".')
if mel_scale == "htk":
return 2595.0 * math.log10(1.0 + (freq / 700.0))
# Fill in the linear part
f_min = 0.0
f_sp = 200.0 / 3
mels = (freq - f_min) / f_sp
# Fill in the log-scale part
min_log_hz = 1000.0
min_log_mel = (min_log_hz - f_min) / f_sp
logstep = math.log(6.4) / 27.0
if freq >= min_log_hz:
mels = min_log_mel + math.log(freq / min_log_hz) / logstep
return mels
def _mel_to_hz(mels: Tensor, mel_scale: str = "htk") -> Tensor:
"""Convert mel bin numbers to frequencies.
Args:
mels (Tensor): Mel frequencies
mel_scale (str, optional): Scale to use: ``htk`` or ``slaney``. (Default: ``htk``)
Returns:
freqs (Tensor): Mels converted in Hz
"""
if mel_scale not in ['slaney', 'htk']:
raise ValueError('mel_scale should be one of "htk" or "slaney".')
if mel_scale == "htk":
return 700.0 * (10.0**(mels / 2595.0) - 1.0)
# Fill in the linear scale
f_min = 0.0
f_sp = 200.0 / 3
freqs = f_min + f_sp * mels
# And now the nonlinear scale
min_log_hz = 1000.0
min_log_mel = (min_log_hz - f_min) / f_sp
logstep = math.log(6.4) / 27.0
log_t = (mels >= min_log_mel)
freqs[log_t] = min_log_hz * torch.exp(logstep * (mels[log_t] - min_log_mel))
return freqs
def create_fb_matrix( def create_fb_matrix(
n_freqs: int, n_freqs: int,
f_min: float, f_min: float,
f_max: float, f_max: float,
n_mels: int, n_mels: int,
sample_rate: int, sample_rate: int,
norm: Optional[str] = None norm: Optional[str] = None,
mel_scale: str = "htk",
) -> Tensor: ) -> Tensor:
r"""Create a frequency bin conversion matrix. r"""Create a frequency bin conversion matrix.
...@@ -314,6 +382,7 @@ def create_fb_matrix( ...@@ -314,6 +382,7 @@ def create_fb_matrix(
sample_rate (int): Sample rate of the audio waveform sample_rate (int): Sample rate of the audio waveform
norm (Optional[str]): If 'slaney', divide the triangular mel weights by the width of the mel band norm (Optional[str]): If 'slaney', divide the triangular mel weights by the width of the mel band
(area normalization). (Default: ``None``) (area normalization). (Default: ``None``)
mel_scale (str, optional): Scale to use: ``htk`` or ``slaney``. (Default: ``htk``)
Returns: Returns:
Tensor: Triangular filter banks (fb matrix) of size (``n_freqs``, ``n_mels``) Tensor: Triangular filter banks (fb matrix) of size (``n_freqs``, ``n_mels``)
...@@ -331,12 +400,12 @@ def create_fb_matrix( ...@@ -331,12 +400,12 @@ def create_fb_matrix(
all_freqs = torch.linspace(0, sample_rate // 2, n_freqs) all_freqs = torch.linspace(0, sample_rate // 2, n_freqs)
# calculate mel freq bins # calculate mel freq bins
# hertz to mel(f) is 2595. * math.log10(1. + (f / 700.)) m_min = _hz_to_mel(f_min, mel_scale=mel_scale)
m_min = 2595.0 * math.log10(1.0 + (f_min / 700.0)) m_max = _hz_to_mel(f_max, mel_scale=mel_scale)
m_max = 2595.0 * math.log10(1.0 + (f_max / 700.0))
m_pts = torch.linspace(m_min, m_max, n_mels + 2) m_pts = torch.linspace(m_min, m_max, n_mels + 2)
# mel to hertz(mel) is 700. * (10**(mel / 2595.) - 1.) f_pts = _mel_to_hz(m_pts, mel_scale=mel_scale)
f_pts = 700.0 * (10 ** (m_pts / 2595.0) - 1.0)
# calculate the difference between each mel point and each stft freq point in hertz # calculate the difference between each mel point and each stft freq point in hertz
f_diff = f_pts[1:] - f_pts[:-1] # (n_mels + 1) f_diff = f_pts[1:] - f_pts[:-1] # (n_mels + 1)
slopes = f_pts.unsqueeze(0) - all_freqs.unsqueeze(1) # (n_freqs, n_mels + 2) slopes = f_pts.unsqueeze(0) - all_freqs.unsqueeze(1) # (n_freqs, n_mels + 2)
......
...@@ -249,6 +249,7 @@ class MelScale(torch.nn.Module): ...@@ -249,6 +249,7 @@ class MelScale(torch.nn.Module):
if None is given. See ``n_fft`` in :class:`Spectrogram`. (Default: ``None``) if None is given. See ``n_fft`` in :class:`Spectrogram`. (Default: ``None``)
norm (Optional[str]): If 'slaney', divide the triangular mel weights by the width of the mel band norm (Optional[str]): If 'slaney', divide the triangular mel weights by the width of the mel band
(area normalization). (Default: ``None``) (area normalization). (Default: ``None``)
mel_scale (str, optional): Scale to use: ``htk`` or ``slaney``. (Default: ``htk``)
""" """
__constants__ = ['n_mels', 'sample_rate', 'f_min', 'f_max'] __constants__ = ['n_mels', 'sample_rate', 'f_min', 'f_max']
...@@ -258,18 +259,21 @@ class MelScale(torch.nn.Module): ...@@ -258,18 +259,21 @@ class MelScale(torch.nn.Module):
f_min: float = 0., f_min: float = 0.,
f_max: Optional[float] = None, f_max: Optional[float] = None,
n_stft: Optional[int] = None, n_stft: Optional[int] = None,
norm: Optional[str] = None) -> None: norm: Optional[str] = None,
mel_scale: str = "htk") -> None:
super(MelScale, self).__init__() super(MelScale, self).__init__()
self.n_mels = n_mels self.n_mels = n_mels
self.sample_rate = sample_rate self.sample_rate = sample_rate
self.f_max = f_max if f_max is not None else float(sample_rate // 2) self.f_max = f_max if f_max is not None else float(sample_rate // 2)
self.f_min = f_min self.f_min = f_min
self.norm = norm self.norm = norm
self.mel_scale = mel_scale
assert f_min <= self.f_max, 'Require f_min: {} < f_max: {}'.format(f_min, self.f_max) assert f_min <= self.f_max, 'Require f_min: {} < f_max: {}'.format(f_min, self.f_max)
fb = torch.empty(0) if n_stft is None else F.create_fb_matrix( fb = torch.empty(0) if n_stft is None else F.create_fb_matrix(
n_stft, self.f_min, self.f_max, self.n_mels, self.sample_rate, self.norm) n_stft, self.f_min, self.f_max, self.n_mels, self.sample_rate, self.norm,
self.mel_scale)
self.register_buffer('fb', fb) self.register_buffer('fb', fb)
def forward(self, specgram: Tensor) -> Tensor: def forward(self, specgram: Tensor) -> Tensor:
...@@ -287,7 +291,8 @@ class MelScale(torch.nn.Module): ...@@ -287,7 +291,8 @@ class MelScale(torch.nn.Module):
if self.fb.numel() == 0: if self.fb.numel() == 0:
tmp_fb = F.create_fb_matrix(specgram.size(1), self.f_min, self.f_max, tmp_fb = F.create_fb_matrix(specgram.size(1), self.f_min, self.f_max,
self.n_mels, self.sample_rate, self.norm) self.n_mels, self.sample_rate, self.norm,
self.mel_scale)
# Attributes cannot be reassigned outside __init__ so workaround # Attributes cannot be reassigned outside __init__ so workaround
self.fb.resize_(tmp_fb.size()) self.fb.resize_(tmp_fb.size())
self.fb.copy_(tmp_fb) self.fb.copy_(tmp_fb)
...@@ -321,6 +326,7 @@ class InverseMelScale(torch.nn.Module): ...@@ -321,6 +326,7 @@ class InverseMelScale(torch.nn.Module):
sgdargs (dict or None, optional): Arguments for the SGD optimizer. (Default: ``None``) sgdargs (dict or None, optional): Arguments for the SGD optimizer. (Default: ``None``)
norm (Optional[str]): If 'slaney', divide the triangular mel weights by the width of the mel band norm (Optional[str]): If 'slaney', divide the triangular mel weights by the width of the mel band
(area normalization). (Default: ``None``) (area normalization). (Default: ``None``)
mel_scale (str, optional): Scale to use: ``htk`` or ``slaney``. (Default: ``htk``)
""" """
__constants__ = ['n_stft', 'n_mels', 'sample_rate', 'f_min', 'f_max', 'max_iter', 'tolerance_loss', __constants__ = ['n_stft', 'n_mels', 'sample_rate', 'f_min', 'f_max', 'max_iter', 'tolerance_loss',
'tolerance_change', 'sgdargs'] 'tolerance_change', 'sgdargs']
...@@ -335,7 +341,8 @@ class InverseMelScale(torch.nn.Module): ...@@ -335,7 +341,8 @@ class InverseMelScale(torch.nn.Module):
tolerance_loss: float = 1e-5, tolerance_loss: float = 1e-5,
tolerance_change: float = 1e-8, tolerance_change: float = 1e-8,
sgdargs: Optional[dict] = None, sgdargs: Optional[dict] = None,
norm: Optional[str] = None) -> None: norm: Optional[str] = None,
mel_scale: str = "htk") -> None:
super(InverseMelScale, self).__init__() super(InverseMelScale, self).__init__()
self.n_mels = n_mels self.n_mels = n_mels
self.sample_rate = sample_rate self.sample_rate = sample_rate
...@@ -348,7 +355,8 @@ class InverseMelScale(torch.nn.Module): ...@@ -348,7 +355,8 @@ class InverseMelScale(torch.nn.Module):
assert f_min <= self.f_max, 'Require f_min: {} < f_max: {}'.format(f_min, self.f_max) assert f_min <= self.f_max, 'Require f_min: {} < f_max: {}'.format(f_min, self.f_max)
fb = F.create_fb_matrix(n_stft, self.f_min, self.f_max, self.n_mels, self.sample_rate, norm) fb = F.create_fb_matrix(n_stft, self.f_min, self.f_max, self.n_mels, self.sample_rate, norm,
mel_scale)
self.register_buffer('fb', fb) self.register_buffer('fb', fb)
def forward(self, melspec: Tensor) -> Tensor: def forward(self, melspec: Tensor) -> Tensor:
...@@ -427,6 +435,7 @@ class MelSpectrogram(torch.nn.Module): ...@@ -427,6 +435,7 @@ class MelSpectrogram(torch.nn.Module):
avoid redundancy. Default: ``True`` avoid redundancy. Default: ``True``
norm (Optional[str]): If 'slaney', divide the triangular mel weights by the width of the mel band norm (Optional[str]): If 'slaney', divide the triangular mel weights by the width of the mel band
(area normalization). (Default: ``None``) (area normalization). (Default: ``None``)
mel_scale (str, optional): Scale to use: ``htk`` or ``slaney``. (Default: ``htk``)
Example Example
>>> waveform, sample_rate = torchaudio.load('test.wav', normalization=True) >>> waveform, sample_rate = torchaudio.load('test.wav', normalization=True)
...@@ -450,7 +459,8 @@ class MelSpectrogram(torch.nn.Module): ...@@ -450,7 +459,8 @@ class MelSpectrogram(torch.nn.Module):
center: bool = True, center: bool = True,
pad_mode: str = "reflect", pad_mode: str = "reflect",
onesided: bool = True, onesided: bool = True,
norm: Optional[str] = None) -> None: norm: Optional[str] = None,
mel_scale: str = "htk") -> None:
super(MelSpectrogram, self).__init__() super(MelSpectrogram, self).__init__()
self.sample_rate = sample_rate self.sample_rate = sample_rate
self.n_fft = n_fft self.n_fft = n_fft
...@@ -467,7 +477,15 @@ class MelSpectrogram(torch.nn.Module): ...@@ -467,7 +477,15 @@ class MelSpectrogram(torch.nn.Module):
pad=self.pad, window_fn=window_fn, power=self.power, pad=self.pad, window_fn=window_fn, power=self.power,
normalized=self.normalized, wkwargs=wkwargs, normalized=self.normalized, wkwargs=wkwargs,
center=center, pad_mode=pad_mode, onesided=onesided) center=center, pad_mode=pad_mode, onesided=onesided)
self.mel_scale = MelScale(self.n_mels, self.sample_rate, self.f_min, self.f_max, self.n_fft // 2 + 1, norm) self.mel_scale = MelScale(
self.n_mels,
self.sample_rate,
self.f_min,
self.f_max,
self.n_fft // 2 + 1,
norm,
mel_scale
)
def forward(self, waveform: Tensor) -> Tensor: def forward(self, waveform: Tensor) -> Tensor:
r""" r"""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment