Unverified commit 9a307534, authored by Yoach Lacombe and committed by GitHub
Browse files

Porting the torchaudio kaldi fbank implementation to audio_utils (#26182)



* add kaldi fbank

* make style

* add hertz_to_mel kaldi tests

* add mel to hertz kaldi test

* integration tests

* correct test and remove comment

* make style

* Apply suggestions from code review
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* change parameter name

* Apply suggestions from Arthur review
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update remove_dc_offset description

* fix bug  + make style

* fix error in using np.exp instead of np.power

* make style

---------
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
parent b132c170
...@@ -30,17 +30,19 @@ def hertz_to_mel(freq: Union[float, np.ndarray], mel_scale: str = "htk") -> Unio ...@@ -30,17 +30,19 @@ def hertz_to_mel(freq: Union[float, np.ndarray], mel_scale: str = "htk") -> Unio
freq (`float` or `np.ndarray`): freq (`float` or `np.ndarray`):
The frequency, or multiple frequencies, in hertz (Hz). The frequency, or multiple frequencies, in hertz (Hz).
mel_scale (`str`, *optional*, defaults to `"htk"`): mel_scale (`str`, *optional*, defaults to `"htk"`):
The mel frequency scale to use, `"htk"` or `"slaney"`. The mel frequency scale to use, `"htk"`, `"kaldi"` or `"slaney"`.
Returns: Returns:
`float` or `np.ndarray`: The frequencies on the mel scale. `float` or `np.ndarray`: The frequencies on the mel scale.
""" """
if mel_scale not in ["slaney", "htk"]: if mel_scale not in ["slaney", "htk", "kaldi"]:
raise ValueError('mel_scale should be one of "htk" or "slaney".') raise ValueError('mel_scale should be one of "htk", "slaney" or "kaldi".')
if mel_scale == "htk": if mel_scale == "htk":
return 2595.0 * np.log10(1.0 + (freq / 700.0)) return 2595.0 * np.log10(1.0 + (freq / 700.0))
elif mel_scale == "kaldi":
return 1127.0 * np.log(1.0 + (freq / 700.0))
min_log_hertz = 1000.0 min_log_hertz = 1000.0
min_log_mel = 15.0 min_log_mel = 15.0
...@@ -64,17 +66,19 @@ def mel_to_hertz(mels: Union[float, np.ndarray], mel_scale: str = "htk") -> Unio ...@@ -64,17 +66,19 @@ def mel_to_hertz(mels: Union[float, np.ndarray], mel_scale: str = "htk") -> Unio
mels (`float` or `np.ndarray`): mels (`float` or `np.ndarray`):
The frequency, or multiple frequencies, in mels. The frequency, or multiple frequencies, in mels.
mel_scale (`str`, *optional*, `"htk"`): mel_scale (`str`, *optional*, `"htk"`):
The mel frequency scale to use, `"htk"` or `"slaney"`. The mel frequency scale to use, `"htk"`, `"kaldi"` or `"slaney"`.
Returns: Returns:
`float` or `np.ndarray`: The frequencies in hertz. `float` or `np.ndarray`: The frequencies in hertz.
""" """
if mel_scale not in ["slaney", "htk"]: if mel_scale not in ["slaney", "htk", "kaldi"]:
raise ValueError('mel_scale should be one of "htk" or "slaney".') raise ValueError('mel_scale should be one of "htk", "slaney" or "kaldi".')
if mel_scale == "htk": if mel_scale == "htk":
return 700.0 * (10.0 ** (mels / 2595.0) - 1.0) return 700.0 * (np.power(10, mels / 2595.0) - 1.0)
elif mel_scale == "kaldi":
return 700.0 * (np.exp(mels / 1127.0) - 1.0)
min_log_hertz = 1000.0 min_log_hertz = 1000.0
min_log_mel = 15.0 min_log_mel = 15.0
...@@ -120,6 +124,7 @@ def mel_filter_bank( ...@@ -120,6 +124,7 @@ def mel_filter_bank(
sampling_rate: int, sampling_rate: int,
norm: Optional[str] = None, norm: Optional[str] = None,
mel_scale: str = "htk", mel_scale: str = "htk",
triangularize_in_mel_space: bool = False,
) -> np.ndarray: ) -> np.ndarray:
""" """
Creates a frequency bin conversion matrix used to obtain a mel spectrogram. This is called a *mel filter bank*, and Creates a frequency bin conversion matrix used to obtain a mel spectrogram. This is called a *mel filter bank*, and
...@@ -155,7 +160,10 @@ def mel_filter_bank( ...@@ -155,7 +160,10 @@ def mel_filter_bank(
norm (`str`, *optional*): norm (`str`, *optional*):
If `"slaney"`, divide the triangular mel weights by the width of the mel band (area normalization). If `"slaney"`, divide the triangular mel weights by the width of the mel band (area normalization).
mel_scale (`str`, *optional*, defaults to `"htk"`): mel_scale (`str`, *optional*, defaults to `"htk"`):
The mel frequency scale to use, `"htk"` or `"slaney"`. The mel frequency scale to use, `"htk"`, `"kaldi"` or `"slaney"`.
triangularize_in_mel_space (`bool`, *optional*, defaults to `False`):
If this option is enabled, the triangular filter is applied in mel space rather than frequency space. This
should be set to `True` in order to get the same results as `torchaudio` when computing mel filters.
Returns: Returns:
`np.ndarray` of shape (`num_frequency_bins`, `num_mel_filters`): Triangular filter bank matrix. This is a `np.ndarray` of shape (`num_frequency_bins`, `num_mel_filters`): Triangular filter bank matrix. This is a
...@@ -164,15 +172,21 @@ def mel_filter_bank( ...@@ -164,15 +172,21 @@ def mel_filter_bank(
if norm is not None and norm != "slaney": if norm is not None and norm != "slaney":
raise ValueError('norm must be one of None or "slaney"') raise ValueError('norm must be one of None or "slaney"')
# frequencies of FFT bins in Hz
fft_freqs = np.linspace(0, sampling_rate // 2, num_frequency_bins)
# center points of the triangular mel filters # center points of the triangular mel filters
mel_min = hertz_to_mel(min_frequency, mel_scale=mel_scale) mel_min = hertz_to_mel(min_frequency, mel_scale=mel_scale)
mel_max = hertz_to_mel(max_frequency, mel_scale=mel_scale) mel_max = hertz_to_mel(max_frequency, mel_scale=mel_scale)
mel_freqs = np.linspace(mel_min, mel_max, num_mel_filters + 2) mel_freqs = np.linspace(mel_min, mel_max, num_mel_filters + 2)
filter_freqs = mel_to_hertz(mel_freqs, mel_scale=mel_scale) filter_freqs = mel_to_hertz(mel_freqs, mel_scale=mel_scale)
if triangularize_in_mel_space:
# frequencies of FFT bins in Hz, but filters triangularized in mel space
fft_bin_width = sampling_rate / (num_frequency_bins * 2)
fft_freqs = hertz_to_mel(fft_bin_width * np.arange(num_frequency_bins), mel_scale=mel_scale)
filter_freqs = mel_freqs
else:
# frequencies of FFT bins in Hz
fft_freqs = np.linspace(0, sampling_rate // 2, num_frequency_bins)
mel_filters = _create_triangular_filter_bank(fft_freqs, filter_freqs) mel_filters = _create_triangular_filter_bank(fft_freqs, filter_freqs)
if norm is not None and norm == "slaney": if norm is not None and norm == "slaney":
...@@ -218,6 +232,7 @@ def window_function( ...@@ -218,6 +232,7 @@ def window_function(
- `"boxcar"`: a rectangular window - `"boxcar"`: a rectangular window
- `"hamming"`: the Hamming window - `"hamming"`: the Hamming window
- `"hann"`: the Hann window - `"hann"`: the Hann window
- `"povey"`: the Povey window
Args: Args:
window_length (`int`): window_length (`int`):
...@@ -243,6 +258,8 @@ def window_function( ...@@ -243,6 +258,8 @@ def window_function(
window = np.hamming(length) window = np.hamming(length)
elif name in ["hann", "hann_window"]: elif name in ["hann", "hann_window"]:
window = np.hanning(length) window = np.hanning(length)
elif name in ["povey"]:
window = np.power(np.hanning(length), 0.85)
else: else:
raise ValueError(f"Unknown window function '{name}'") raise ValueError(f"Unknown window function '{name}'")
...@@ -281,6 +298,7 @@ def spectrogram( ...@@ -281,6 +298,7 @@ def spectrogram(
reference: float = 1.0, reference: float = 1.0,
min_value: float = 1e-10, min_value: float = 1e-10,
db_range: Optional[float] = None, db_range: Optional[float] = None,
remove_dc_offset: Optional[bool] = None,
dtype: np.dtype = np.float32, dtype: np.dtype = np.float32,
) -> np.ndarray: ) -> np.ndarray:
""" """
...@@ -363,6 +381,9 @@ def spectrogram( ...@@ -363,6 +381,9 @@ def spectrogram(
db_range (`float`, *optional*): db_range (`float`, *optional*):
Sets the maximum dynamic range in decibels. For example, if `db_range = 80`, the difference between the Sets the maximum dynamic range in decibels. For example, if `db_range = 80`, the difference between the
peak value and the smallest value will never be more than 80 dB. Must be greater than zero. peak value and the smallest value will never be more than 80 dB. Must be greater than zero.
remove_dc_offset (`bool`, *optional*):
Subtract mean from waveform on each frame, applied before pre-emphasis. This should be set to `True` in
order to get the same results as `torchaudio.compliance.kaldi.fbank` when computing the spectrogram.
dtype (`np.dtype`, *optional*, defaults to `np.float32`): dtype (`np.dtype`, *optional*, defaults to `np.float32`):
Data type of the spectrogram tensor. If `power` is None, this argument is ignored and the dtype will be Data type of the spectrogram tensor. If `power` is None, this argument is ignored and the dtype will be
`np.complex64`. `np.complex64`.
...@@ -414,6 +435,9 @@ def spectrogram( ...@@ -414,6 +435,9 @@ def spectrogram(
for frame_idx in range(num_frames): for frame_idx in range(num_frames):
buffer[:frame_length] = waveform[timestep : timestep + frame_length] buffer[:frame_length] = waveform[timestep : timestep + frame_length]
if remove_dc_offset:
buffer[:frame_length] = buffer[:frame_length] - buffer[:frame_length].mean()
if preemphasis is not None: if preemphasis is not None:
buffer[1:frame_length] -= preemphasis * buffer[: frame_length - 1] buffer[1:frame_length] -= preemphasis * buffer[: frame_length - 1]
buffer[0] *= 1 - preemphasis buffer[0] *= 1 - preemphasis
......
...@@ -45,6 +45,10 @@ class AudioUtilsFunctionTester(unittest.TestCase): ...@@ -45,6 +45,10 @@ class AudioUtilsFunctionTester(unittest.TestCase):
expected = np.array([0.9, 1.5, 3.0, 15.0, 15.01453781, 25.08188016]) expected = np.array([0.9, 1.5, 3.0, 15.0, 15.01453781, 25.08188016])
self.assertTrue(np.allclose(hertz_to_mel(inputs, "slaney"), expected)) self.assertTrue(np.allclose(hertz_to_mel(inputs, "slaney"), expected))
inputs = np.array([60, 100, 200, 1000, 1001, 2000])
expected = np.array([92.6824, 150.4899, 283.2313, 999.9907, 1000.6534, 1521.3674])
self.assertTrue(np.allclose(hertz_to_mel(inputs, "kaldi"), expected))
with pytest.raises(ValueError): with pytest.raises(ValueError):
hertz_to_mel(100, mel_scale=None) hertz_to_mel(100, mel_scale=None)
...@@ -63,6 +67,10 @@ class AudioUtilsFunctionTester(unittest.TestCase): ...@@ -63,6 +67,10 @@ class AudioUtilsFunctionTester(unittest.TestCase):
expected = np.array([60, 100, 200, 1000, 1001, 2000]) expected = np.array([60, 100, 200, 1000, 1001, 2000])
self.assertTrue(np.allclose(mel_to_hertz(inputs, "slaney"), expected)) self.assertTrue(np.allclose(mel_to_hertz(inputs, "slaney"), expected))
inputs = np.array([92.6824, 150.4899, 283.2313, 999.9907, 1000.6534, 1521.3674])
expected = np.array([60, 100, 200, 1000, 1001, 2000])
self.assertTrue(np.allclose(mel_to_hertz(inputs, "kaldi"), expected))
with pytest.raises(ValueError): with pytest.raises(ValueError):
mel_to_hertz(100, mel_scale=None) mel_to_hertz(100, mel_scale=None)
...@@ -89,6 +97,18 @@ class AudioUtilsFunctionTester(unittest.TestCase): ...@@ -89,6 +97,18 @@ class AudioUtilsFunctionTester(unittest.TestCase):
) )
self.assertEqual(mel_filters.shape, (513, 13)) self.assertEqual(mel_filters.shape, (513, 13))
mel_filters = mel_filter_bank(
num_frequency_bins=513,
num_mel_filters=13,
min_frequency=100,
max_frequency=4000,
sampling_rate=16000,
norm="slaney",
mel_scale="slaney",
triangularize_in_mel_space=True,
)
self.assertEqual(mel_filters.shape, (513, 13))
def test_mel_filter_bank_htk(self): def test_mel_filter_bank_htk(self):
mel_filters = mel_filter_bank( mel_filters = mel_filter_bank(
num_frequency_bins=16, num_frequency_bins=16,
...@@ -153,6 +173,39 @@ class AudioUtilsFunctionTester(unittest.TestCase): ...@@ -153,6 +173,39 @@ class AudioUtilsFunctionTester(unittest.TestCase):
# fmt: on # fmt: on
self.assertTrue(np.allclose(mel_filters, expected)) self.assertTrue(np.allclose(mel_filters, expected))
def test_mel_filter_bank_kaldi(self):
    """Check the `"kaldi"` mel-scale filter bank, triangularized in mel space, against reference values."""
    kaldi_filters = mel_filter_bank(
        num_frequency_bins=16,
        num_mel_filters=4,
        min_frequency=0,
        max_frequency=2000,
        sampling_rate=4000,
        norm=None,
        mel_scale="kaldi",
        triangularize_in_mel_space=True,
    )

    # One row per frequency bin (16), one column per mel filter (4).
    # fmt: off
    reference = np.array([
        [0.0000, 0.0000, 0.0000, 0.0000],
        [0.6086, 0.0000, 0.0000, 0.0000],
        [0.8689, 0.1311, 0.0000, 0.0000],
        [0.4110, 0.5890, 0.0000, 0.0000],
        [0.0036, 0.9964, 0.0000, 0.0000],
        [0.0000, 0.6366, 0.3634, 0.0000],
        [0.0000, 0.3027, 0.6973, 0.0000],
        [0.0000, 0.0000, 0.9964, 0.0036],
        [0.0000, 0.0000, 0.7135, 0.2865],
        [0.0000, 0.0000, 0.4507, 0.5493],
        [0.0000, 0.0000, 0.2053, 0.7947],
        [0.0000, 0.0000, 0.0000, 0.9752],
        [0.0000, 0.0000, 0.0000, 0.7585],
        [0.0000, 0.0000, 0.0000, 0.5539],
        [0.0000, 0.0000, 0.0000, 0.3599],
        [0.0000, 0.0000, 0.0000, 0.1756],
    ])
    # fmt: on

    # The reference values are written with four decimals, hence the absolute tolerance.
    matches_reference = np.allclose(kaldi_filters, reference, atol=5e-5)
    self.assertTrue(matches_reference)
def test_mel_filter_bank_slaney_norm(self): def test_mel_filter_bank_slaney_norm(self):
mel_filters = mel_filter_bank( mel_filters = mel_filter_bank(
num_frequency_bins=16, num_frequency_bins=16,
...@@ -271,6 +324,58 @@ class AudioUtilsFunctionTester(unittest.TestCase): ...@@ -271,6 +324,58 @@ class AudioUtilsFunctionTester(unittest.TestCase):
self.assertEqual(spec.shape, (257, 732)) self.assertEqual(spec.shape, (257, 732))
self.assertTrue(np.allclose(spec[:64, 400], expected)) self.assertTrue(np.allclose(spec[:64, 400], expected))
mel_filters = mel_filter_bank(
num_frequency_bins=256,
num_mel_filters=400,
min_frequency=20,
max_frequency=8000,
sampling_rate=16000,
norm=None,
mel_scale="kaldi",
triangularize_in_mel_space=True,
)
mel_filters = np.pad(mel_filters, ((0, 1), (0, 0)))
spec = spectrogram(
waveform,
window_function(400, "povey", periodic=False),
frame_length=400,
hop_length=160,
fft_length=512,
power=2.0,
center=False,
pad_mode="reflect",
onesided=True,
preemphasis=0.97,
mel_filters=mel_filters,
log_mel="log",
mel_floor=1.1920928955078125e-07,
remove_dc_offset=True,
)
self.assertEqual(spec.shape, (400, 584))
# fmt: off
expected = np.array([-15.94238515, -8.20712299, -8.22704352, -15.94238515,
-15.94238515, -15.94238515, -15.94238515, -15.94238515,
-6.52463769, -7.73677889, -15.94238515, -15.94238515,
-15.94238515, -15.94238515, -4.18650018, -3.37195286,
-15.94238515, -15.94238515, -15.94238515, -15.94238515,
-4.70190154, -2.4217066 , -15.94238515, -15.94238515,
-15.94238515, -15.94238515, -5.62755239, -3.53385194,
-15.94238515, -15.94238515, -15.94238515, -15.94238515,
-9.43303023, -8.77480925, -15.94238515, -15.94238515,
-15.94238515, -15.94238515, -4.2951092 , -5.51585994,
-15.94238515, -15.94238515, -15.94238515, -4.40151721,
-3.95228878, -15.94238515, -15.94238515, -15.94238515,
-6.10365415, -4.59494697, -15.94238515, -15.94238515,
-15.94238515, -8.10727767, -6.2585298 , -15.94238515,
-15.94238515, -15.94238515, -5.60161702, -4.47217004,
-15.94238515, -15.94238515, -15.94238515, -5.91641988]
)
# fmt: on
self.assertTrue(np.allclose(spec[:64, 400], expected, atol=1e-5))
def test_spectrogram_center_padding(self): def test_spectrogram_center_padding(self):
waveform = self._load_datasamples(1)[0] waveform = self._load_datasamples(1)[0]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment