Unverified Commit 7f919509 authored by Matthijs Hollemans, committed by GitHub

audio_utils improvements (#21998)

* silly change to allow making a PR

* clean up doc comments

* simplify hertz_to_mel and mel_to_hertz

* fixup

* clean up power_to_db

* also add amplitude_to_db

* move functions

* clean up mel_filter_bank

* fixup

* credit librosa & torchaudio authors

* add unit tests

* tests for power_to_db and amplitude_to_db

* add mel_filter_bank tests

* rewrite STFT

* add convenience spectrogram function

* missing transpose

* fewer transposes

* add integration test to M-CTC-T

* frame length can be either window or FFT length

* rewrite stft API

* add preemphasis coefficient

* move argument

* add log option to spectrogram

* replace M-CTC-T feature extractor

* fix api thing

* replace whisper STFT

* replace whisper mel filters

* replace tvlt's stft

* allow alternate window names

* replace speecht5 stft

* fixup

* fix integration tests

* fix doc comments

* remove manual FFT length calculation

* fix docs

* go away, deprecation warnings

* combine everything into spectrogram function

* add deprecated functions back

* fixup
parent 431b04d8
@@ -12,10 +12,9 @@ specific language governing permissions and limitations under the License.
 # Utilities for `FeatureExtractors`

-This page lists all the utility functions that can be used by the audio [`FeatureExtractor`] in order to compute special features from a raw audio using common algorithms such as *Short Time Fourier Transform* or *Mel log spectrogram*.
+This page lists all the utility functions that can be used by the audio [`FeatureExtractor`] in order to compute special features from a raw audio using common algorithms such as *Short Time Fourier Transform* or *log mel spectrogram*.

-Most of those are only useful if you are studying the code of the image processors in the library.
+Most of those are only useful if you are studying the code of the audio processors in the library.

 ## Audio Transformations
@@ -23,12 +22,14 @@ Most of those are only useful if you are studying the code of the image processo
 [[autodoc]] audio_utils.mel_to_hertz

-[[autodoc]] audio_utils.get_mel_filter_banks
-[[autodoc]] audio_utils.stft
-[[autodoc]] audio_utils.power_to_db
-[[autodoc]] audio_utils.fram_wave
+[[autodoc]] audio_utils.mel_filter_bank
+[[autodoc]] audio_utils.optimal_fft_length
+[[autodoc]] audio_utils.window_function
+[[autodoc]] audio_utils.spectrogram
+[[autodoc]] audio_utils.power_to_db
+[[autodoc]] audio_utils.amplitude_to_db
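For reference, the two mel scales used throughout this PR differ as follows: HTK uses mel = 2595 · log10(1 + f/700), while the Slaney scale is linear below 1 kHz and logarithmic above it (the constants are visible in the removed `get_mel_filters` code further down). A sketch of the hertz-to-mel direction, scalar inputs only:

```python
import numpy as np

def hertz_to_mel_htk(freq: float) -> float:
    # HTK mel scale
    return 2595.0 * np.log10(1.0 + freq / 700.0)

def hertz_to_mel_slaney(freq: float) -> float:
    # Slaney mel scale: linear below 1000 Hz, logarithmic above
    min_log_hertz, min_log_mel = 1000.0, 15.0  # 15.0 == 1000 / (200 / 3)
    logstep = 27.0 / np.log(6.4)               # inverse of the log(6.4) / 27 step below
    if freq >= min_log_hertz:
        return min_log_mel + np.log(freq / min_log_hertz) * logstep
    return 3.0 * freq / 200.0
```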
This diff is collapsed.
@@ -21,7 +21,7 @@ from typing import Any, Dict, List, Optional, Union
 import numpy as np
 import torch

-from ...audio_utils import fram_wave, get_mel_filter_banks, power_to_db, stft
+from ...audio_utils import mel_filter_bank, spectrogram, window_function
 from ...feature_extraction_sequence_utils import SequenceFeatureExtractor
 from ...feature_extraction_utils import BatchFeature
 from ...utils import TensorType, logging
@@ -116,21 +116,21 @@ class ClapFeatureExtractor(SequenceFeatureExtractor):
         self.sampling_rate = sampling_rate
         self.frequency_min = frequency_min
         self.frequency_max = frequency_max
-        self.mel_filters = get_mel_filter_banks(
-            nb_frequency_bins=self.nb_frequency_bins,
-            nb_mel_filters=feature_size,
-            frequency_min=frequency_min,
-            frequency_max=frequency_max,
-            sample_rate=sampling_rate,
+        self.mel_filters = mel_filter_bank(
+            num_frequency_bins=self.nb_frequency_bins,
+            num_mel_filters=feature_size,
+            min_frequency=frequency_min,
+            max_frequency=frequency_max,
+            sampling_rate=sampling_rate,
             norm=None,
             mel_scale="htk",
         )
-        self.mel_filters_slaney = get_mel_filter_banks(
-            nb_frequency_bins=self.nb_frequency_bins,
-            nb_mel_filters=feature_size,
-            frequency_min=frequency_min,
-            frequency_max=frequency_max,
-            sample_rate=sampling_rate,
+        self.mel_filters_slaney = mel_filter_bank(
+            num_frequency_bins=self.nb_frequency_bins,
+            num_mel_filters=feature_size,
+            min_frequency=frequency_min,
+            max_frequency=frequency_max,
+            sampling_rate=sampling_rate,
             norm="slaney",
             mel_scale="slaney",
         )
@@ -153,24 +153,25 @@ class ClapFeatureExtractor(SequenceFeatureExtractor):
     def _np_extract_fbank_features(self, waveform: np.array, mel_filters: Optional[np.array] = None) -> np.ndarray:
         """
-        Compute the log-Mel spectrogram of the provided `waveform` using the `hanning` window. In CLAP, two different
-        filter banks are used depending on the truncation pattern:
-            - `self.mel_filters`: they correspond to the defaults parameters of `torchaduio` which can be obtained from
+        Compute the log-mel spectrogram of the provided `waveform` using the Hann window. In CLAP, two different filter
+        banks are used depending on the truncation pattern:
+            - `self.mel_filters`: they correspond to the default parameters of `torchaudio` which can be obtained from
               calling `torchaudio.transforms.MelSpectrogram().mel_scale.fb`. These filters are used when `truncation`
               is set to `"fusion"`.
-            - `self.mel_filteres_slaney` : they correspond to the defaults parameters of `torchlibrosa` which used
+            - `self.mel_filteres_slaney` : they correspond to the default parameters of `librosa` which used
               `librosa.filters.mel` when computing the mel spectrogram. These filters were only used in the original
               implementation when the truncation mode is not `"fusion"`.
         """
-        window = np.hanning(self.fft_window_size + 1)[:-1]
-        frames = fram_wave(waveform, self.hop_length, self.fft_window_size)
-        spectrogram = stft(frames, window, fft_window_size=self.fft_window_size)
-
-        magnitudes = np.abs(spectrogram) ** 2
-        mel_spectrogram = np.matmul(mel_filters.T, magnitudes)
-        log_mel_spectrogram = power_to_db(mel_spectrogram).T
-        log_mel_spectrogram = np.asarray(log_mel_spectrogram, np.float32)
-        return log_mel_spectrogram
+        log_mel_spectrogram = spectrogram(
+            waveform,
+            window_function(self.fft_window_size, "hann"),
+            frame_length=self.fft_window_size,
+            hop_length=self.hop_length,
+            power=2.0,
+            mel_filters=mel_filters,
+            log_mel="dB",
+        )
+        return log_mel_spectrogram.T

     def _random_mel_fusion(self, mel, total_frames, chunk_frames):
         ranges = np.array_split(list(range(0, total_frames - chunk_frames + 1)), 3)
...
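The new call pattern above can be exercised standalone. All keyword names below come straight from this diff; the waveform and the sizes are illustrative only, and the exact output depends on the installed transformers version:

```python
import numpy as np
from transformers.audio_utils import mel_filter_bank, spectrogram, window_function

# one second of a 440 Hz tone at 48 kHz, just as dummy input
waveform = np.sin(2 * np.pi * 440.0 * np.arange(48_000) / 48_000).astype(np.float32)

filters = mel_filter_bank(
    num_frequency_bins=1 + 1024 // 2,
    num_mel_filters=64,
    min_frequency=50.0,
    max_frequency=14_000.0,
    sampling_rate=48_000,
    norm=None,
    mel_scale="htk",
)
log_mel = spectrogram(
    waveform,
    window_function(1024, "hann"),
    frame_length=1024,
    hop_length=480,
    power=2.0,
    mel_filters=filters,
    log_mel="dB",
)
print(log_mel.shape)  # (num_mel_filters, num_frames)
```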
@@ -20,9 +20,8 @@ from typing import List, Optional, Union
 import numpy as np
 import torch
-import torchaudio
-from packaging import version

+from ...audio_utils import mel_filter_bank, optimal_fft_length, spectrogram
 from ...feature_extraction_sequence_utils import SequenceFeatureExtractor
 from ...feature_extraction_utils import BatchFeature
 from ...file_utils import PaddingStrategy, TensorType
@@ -31,13 +30,6 @@ from ...utils import logging
 logger = logging.get_logger(__name__)

-parsed_torchaudio_version_base = version.parse(version.parse(torchaudio.__version__).base_version)
-if not parsed_torchaudio_version_base >= version.parse("0.10"):
-    logger.warning(
-        f"You are using torchaudio=={torchaudio.__version__}, but torchaudio>=0.10.0 is required to use "
-        "MCTCTFeatureExtractor. This requires torch>=1.10.0. Please upgrade torch and torchaudio."
-    )

 class MCTCTFeatureExtractor(SequenceFeatureExtractor):
     r"""
@@ -110,68 +102,9 @@ class MCTCTFeatureExtractor(SequenceFeatureExtractor):
         self.sample_size = win_length * sampling_rate // 1000
         self.sample_stride = hop_length * sampling_rate // 1000
-        self.n_fft = 2 ** int(np.ceil(np.log2(self.sample_size)))
+        self.n_fft = optimal_fft_length(self.sample_size)
         self.n_freqs = (self.n_fft // 2) + 1

-    @staticmethod
-    def _num_frames_calc(in_size, frame_size, frame_stride):
-        return int(1 + np.floor((in_size - frame_size) * 1 / frame_stride))
-
-    @staticmethod
-    def _frame_signal(one_waveform, n_frames, frame_signal_scale, window_length, sample_stride):
-        scale = frame_signal_scale
-        frames = np.zeros(n_frames * window_length)
-        for frame_idx in range(n_frames):
-            start = frame_idx * window_length
-            end = (frame_idx + 1) * window_length
-            wave_start = frame_idx * sample_stride
-            wave_end = frame_idx * sample_stride + window_length
-            frames[start:end] = scale * one_waveform[wave_start:wave_end]
-        return frames
-
-    @staticmethod
-    def _apply_preemphasis_inplace(frames, window_length, preemphasis_coeff):
-        if frames.size % window_length != 0:
-            raise ValueError(
-                f"`frames` is supposed to have length divisble by `window_length`, but is {frames.size} with"
-                f" window_length={window_length}."
-            )
-        n_frames = frames.size // window_length
-        for frame_idx in range(n_frames, 0, -1):
-            start = (frame_idx - 1) * window_length
-            end = frame_idx * window_length - 1
-            frames[start + 1 : end + 1] -= preemphasis_coeff * frames[start:end]
-            frames[start] *= 1 - preemphasis_coeff
-
-    @staticmethod
-    def _windowing(frames, window_length, window):
-        if frames.size % window_length != 0:
-            raise ValueError(
-                f"`frames` is supposed to have length divisble by `window_length`, but is {frames.size} with"
-                f" window_length={window_length}."
-            )
-        shaped = frames.reshape(-1, window_length)
-        shaped = window * shaped
-        return shaped
-
-    @staticmethod
-    def _dft(frames, K, n_frames, n_samples, n_fft):
-        dft = np.zeros([n_frames, K])
-        for frame in range(n_frames):
-            begin = frame * n_samples
-            inwards_buffer = frames[begin : begin + n_samples]
-            inwards_buffer = np.pad(inwards_buffer, (0, n_fft - n_samples), "constant")
-            out = np.fft.rfft(inwards_buffer)
-            dft[frame] = np.abs(out[:K])
-        return dft
-
     def _extract_mfsc_features(self, one_waveform: np.array) -> np.ndarray:
         """
         Extracts MFSC Features for one waveform vector (unbatched). Adapted from Flashlight's C++ MFSC code.
@@ -183,36 +116,27 @@ class MCTCTFeatureExtractor(SequenceFeatureExtractor):
         window = window.numpy()

-        fbanks = torchaudio.functional.melscale_fbanks(
-            n_freqs=self.n_freqs,
-            f_min=0.0,  # change this to zeros
-            f_max=self.sampling_rate / 2.0,
-            n_mels=self.feature_size,
-            sample_rate=self.sampling_rate,
+        fbanks = mel_filter_bank(
+            num_frequency_bins=self.n_freqs,
+            num_mel_filters=self.feature_size,
+            min_frequency=0.0,
+            max_frequency=self.sampling_rate / 2.0,
+            sampling_rate=self.sampling_rate,
         )
-        fbanks = fbanks.numpy()
-
-        n_frames = self._num_frames_calc(one_waveform.size, self.sample_size, self.sample_stride)
-
-        frames = self._frame_signal(
-            one_waveform, n_frames, self.frame_signal_scale, self.sample_size, self.sample_stride
-        )
-
-        self._apply_preemphasis_inplace(frames, self.sample_size, self.preemphasis_coeff)
-
-        frames = self._windowing(frames, self.sample_size, window)
-
-        dft_out = self._dft(frames.flatten(), self.n_freqs, n_frames, self.sample_size, self.n_fft)
-
-        # msfc_features = STFT * mel frequency banks.
-        msfc_features = np.einsum("...tf,fm->...tm", dft_out, fbanks)
-
-        # clamp feature values then log scale, as implemented in flashlight
-        msfc_features = np.maximum(msfc_features, self.mel_floor)
-        msfc_features = np.log(msfc_features)
-
-        return msfc_features
+        msfc_features = spectrogram(
+            one_waveform * self.frame_signal_scale,
+            window=window,
+            frame_length=self.sample_size,
+            hop_length=self.sample_stride,
+            fft_length=self.n_fft,
+            center=False,
+            preemphasis=self.preemphasis_coeff,
+            mel_filters=fbanks,
+            mel_floor=self.mel_floor,
+            log_mel="log",
+        )
+        return msfc_features.T

     def _normalize_one(self, x, input_length, padding_value):
         # make sure we normalize float32 arrays
...
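Two of the removed M-CTC-T helpers map directly onto new arguments. The replaced FFT-length expression is the smallest power of two that fits the frame, and the removed `_apply_preemphasis_inplace` loop is the standard first-order high-pass y[n] = x[n] − k·x[n−1], now covered by the `preemphasis` argument. Condensed sketches of exactly what the removed code computed:

```python
import numpy as np

def optimal_fft_length_sketch(frame_length: int) -> int:
    # smallest power of two >= frame_length, as in the replaced
    # `2 ** int(np.ceil(np.log2(self.sample_size)))` expression
    return 2 ** int(np.ceil(np.log2(frame_length)))

def preemphasize(frame: np.ndarray, coeff: float) -> np.ndarray:
    # y[0] = (1 - k) * x[0]; y[n] = x[n] - k * x[n-1] for n >= 1,
    # matching the boundary handling of the removed _apply_preemphasis_inplace
    out = frame.astype(np.float64).copy()
    out[1:] -= coeff * frame[:-1]
    out[0] *= 1.0 - coeff
    return out
```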
@@ -20,7 +20,7 @@ from typing import Any, Dict, List, Optional, Union
 import numpy as np
 import torch

-from ...audio_utils import get_mel_filter_banks
+from ...audio_utils import mel_filter_bank, optimal_fft_length, spectrogram
 from ...feature_extraction_sequence_utils import SequenceFeatureExtractor
 from ...feature_extraction_utils import BatchFeature
 from ...utils import PaddingStrategy, TensorType, logging
@@ -110,18 +110,18 @@ class SpeechT5FeatureExtractor(SequenceFeatureExtractor):
         self.sample_size = win_length * sampling_rate // 1000
         self.sample_stride = hop_length * sampling_rate // 1000
-        self.n_fft = 2 ** int(np.ceil(np.log2(self.sample_size)))
+        self.n_fft = optimal_fft_length(self.sample_size)
         self.n_freqs = (self.n_fft // 2) + 1

         window = getattr(torch, self.win_function)(window_length=self.sample_size, periodic=True)
         self.window = window.numpy().astype(np.float64)

-        self.mel_filters = get_mel_filter_banks(
-            nb_frequency_bins=self.n_freqs,
-            nb_mel_filters=self.num_mel_bins,
-            frequency_min=self.fmin,
-            frequency_max=self.fmax,
-            sample_rate=self.sampling_rate,
+        self.mel_filters = mel_filter_bank(
+            num_frequency_bins=self.n_freqs,
+            num_mel_filters=self.num_mel_bins,
+            min_frequency=self.fmin,
+            max_frequency=self.fmax,
+            sampling_rate=self.sampling_rate,
             norm="slaney",
             mel_scale="slaney",
         )
@@ -160,31 +160,6 @@ class SpeechT5FeatureExtractor(SequenceFeatureExtractor):
         return normed_input_values

-    @staticmethod
-    def _stft(waveform: np.ndarray, fft_length: int, hop_length: int, window: np.ndarray) -> np.ndarray:
-        """
-        Calculates the magnitude spectrogram over one waveform array.
-        """
-        # center pad the waveform
-        padding = [(int(fft_length // 2), int(fft_length // 2))]
-        waveform = np.pad(waveform, padding, mode="reflect")
-        waveform_size = waveform.size
-
-        # promote to float64, since np.fft uses float64 internally
-        waveform = waveform.astype(np.float64)
-
-        num_frames = int(1 + np.floor((waveform_size - fft_length) / hop_length))
-        num_frequency_bins = (fft_length // 2) + 1
-        spectrogram = np.empty((num_frames, num_frequency_bins))
-
-        start = 0
-        for frame_idx in range(num_frames):
-            frame = waveform[start : start + fft_length] * window
-            spectrogram[frame_idx] = np.abs(np.fft.rfft(frame))
-            start += hop_length
-
-        return spectrogram
-
     def _extract_mel_features(
         self,
         one_waveform: np.ndarray,
@@ -192,14 +167,17 @@ class SpeechT5FeatureExtractor(SequenceFeatureExtractor):
         """
         Extracts log-mel filterbank features for one waveform array (unbatched).
         """
-        if self.n_fft != self.sample_size:
-            raise NotImplementedError(
-                f"Currently the STFT frame size must be a power of two, but got {self.sample_size} for a window length of {self.win_length} and sampling rate of {self.sampling_rate}. Ensure `win_length * sampling_rate // 1000` is divisible by two."
-            )
-
-        stft_out = self._stft(one_waveform, self.n_fft, self.sample_stride, self.window)
-
-        return np.log10(np.maximum(self.mel_floor, np.dot(stft_out, self.mel_filters)))
+        log_mel_spec = spectrogram(
+            one_waveform,
+            window=self.window,
+            frame_length=self.sample_size,
+            hop_length=self.sample_stride,
+            fft_length=self.n_fft,
+            mel_filters=self.mel_filters,
+            mel_floor=self.mel_floor,
+            log_mel="log10",
+        )
+        return log_mel_spec.T

     def __call__(
         self,
...
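The deleted SpeechT5 `_stft` computed a plain magnitude spectrogram (the new call omits `power`, which matches the removed code taking `np.abs` without squaring), and `mel_floor` is applied before `log10`. The old math, condensed into a vectorized sketch that assumes `len(window) == fft_length`, as the removed code did:

```python
import numpy as np

def magnitude_stft(waveform: np.ndarray, fft_length: int, hop_length: int, window: np.ndarray) -> np.ndarray:
    # condensed version of the removed SpeechT5._stft: center-pad by reflection,
    # slide a window over the signal, keep the magnitude of the one-sided FFT
    waveform = np.pad(waveform.astype(np.float64), [(fft_length // 2, fft_length // 2)], mode="reflect")
    num_frames = 1 + (waveform.size - fft_length) // hop_length
    frames = np.lib.stride_tricks.sliding_window_view(waveform, fft_length)[::hop_length][:num_frames]
    return np.abs(np.fft.rfft(frames * window, axis=-1))  # (num_frames, fft_length // 2 + 1)
```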
@@ -18,8 +18,8 @@ from math import ceil
 from typing import List, Optional, Union

 import numpy as np
-from numpy.fft import fft

+from ...audio_utils import mel_filter_bank, spectrogram, window_function
 from ...feature_extraction_sequence_utils import BatchFeature, SequenceFeatureExtractor
 from ...utils import TensorType, logging
@@ -83,143 +83,34 @@ class TvltFeatureExtractor(SequenceFeatureExtractor):
         self.hop_length = sampling_rate // hop_length_to_sampling_rate
         self.sampling_rate = sampling_rate
         self.padding_value = padding_value
-        self.mel_filters = self.get_mel_filters(sampling_rate, n_fft, n_mels=feature_size)
-
-    # Copied from transformers.models.whisper.feature_extraction_whisper.WhisperFeatureExtractor.get_mel_filters with 45.245640471924965->59.99247463746737
-    def get_mel_filters(self, sr, n_fft, n_mels=128, dtype=np.float32):
-        # Initialize the weights
-        n_mels = int(n_mels)
-        weights = np.zeros((n_mels, int(1 + n_fft // 2)), dtype=dtype)
-
-        # Center freqs of each FFT bin
-        fftfreqs = np.fft.rfftfreq(n=n_fft, d=1.0 / sr)
-
-        # 'Center freqs' of mel bands - uniformly spaced between limits
-        min_mel = 0.0
-        max_mel = 59.99247463746737
-        mels = np.linspace(min_mel, max_mel, n_mels + 2)
-        mels = np.asanyarray(mels)
-
-        # Fill in the linear scale
-        f_min = 0.0
-        f_sp = 200.0 / 3
-        freqs = f_min + f_sp * mels
-
-        # And now the nonlinear scale
-        min_log_hz = 1000.0  # beginning of log region (Hz)
-        min_log_mel = (min_log_hz - f_min) / f_sp  # same (Mels)
-        logstep = np.log(6.4) / 27.0  # step size for log region
-
-        # If we have vector data, vectorize
-        log_t = mels >= min_log_mel
-        freqs[log_t] = min_log_hz * np.exp(logstep * (mels[log_t] - min_log_mel))
-
-        mel_f = freqs
-
-        fdiff = np.diff(mel_f)
-        ramps = np.subtract.outer(mel_f, fftfreqs)
-
-        for i in range(n_mels):
-            # lower and upper slopes for all bins
-            lower = -ramps[i] / fdiff[i]
-            upper = ramps[i + 2] / fdiff[i + 1]
-
-            # .. then intersect them with each other and zero
-            weights[i] = np.maximum(0, np.minimum(lower, upper))
-
-        # Slaney-style mel is scaled to be approx constant energy per channel
-        enorm = 2.0 / (mel_f[2 : n_mels + 2] - mel_f[:n_mels])
-        weights *= enorm[:, np.newaxis]
-
-        return weights
-
-    # Copied from transformers.models.whisper.feature_extraction_whisper.WhisperFeatureExtractor.fram_wave
-    def fram_wave(self, waveform, center=True):
-        """
-        Transform a raw waveform into a list of smaller waveforms. The window length defines how much of the signal is
-        contain in each frame (smalle waveform), while the hope length defines the step between the beginning of each
-        new frame.
-
-        Centering is done by reflecting the waveform which is first centered around `frame_idx * hop_length`.
-        """
-        frames = []
-        for i in range(0, waveform.shape[0] + 1, self.hop_length):
-            half_window = (self.n_fft - 1) // 2 + 1
-            if center:
-                start = i - half_window if i > half_window else 0
-                end = i + half_window if i < waveform.shape[0] - half_window else waveform.shape[0]
-
-                frame = waveform[start:end]
-
-                if start == 0:
-                    padd_width = (-i + half_window, 0)
-                    frame = np.pad(frame, pad_width=padd_width, mode="reflect")
-
-                elif end == waveform.shape[0]:
-                    padd_width = (0, (i - waveform.shape[0] + half_window))
-                    frame = np.pad(frame, pad_width=padd_width, mode="reflect")
-
-            else:
-                frame = waveform[i : i + self.n_fft]
-                frame_width = frame.shape[0]
-                if frame_width < waveform.shape[0]:
-                    frame = np.lib.pad(
-                        frame, pad_width=(0, self.n_fft - frame_width), mode="constant", constant_values=0
-                    )
-
-            frames.append(frame)
-        return np.stack(frames, 0)
-
-    # Copied from transformers.models.whisper.feature_extraction_whisper.WhisperFeatureExtractor.stft
-    def stft(self, frames, window):
-        """
-        Calculates the complex Short-Time Fourier Transform (STFT) of the given framed signal. Should give the same
-        results as `torch.stft`.
-        """
-        frame_size = frames.shape[1]
-        fft_size = self.n_fft
-
-        if fft_size is None:
-            fft_size = frame_size
-
-        if fft_size < frame_size:
-            raise ValueError("FFT size must greater or equal the frame size")
-
-        # number of FFT bins to store
-        num_fft_bins = (fft_size >> 1) + 1
-
-        data = np.empty((len(frames), num_fft_bins), dtype=np.complex64)
-        fft_signal = np.zeros(fft_size)
-
-        for f, frame in enumerate(frames):
-            if window is not None:
-                np.multiply(frame, window, out=fft_signal[:frame_size])
-            else:
-                fft_signal[:frame_size] = frame
-            data[f] = fft(fft_signal, axis=0)[:num_fft_bins]
-        return data.T
+        self.mel_filters = mel_filter_bank(
+            num_frequency_bins=1 + n_fft // 2,
+            num_mel_filters=feature_size,
+            min_frequency=0.0,
+            max_frequency=22050.0,
+            sampling_rate=sampling_rate,
+            norm="slaney",
+            mel_scale="slaney",
+        ).T

     def _np_extract_fbank_features(self, waveform: np.array) -> np.ndarray:
         """
-        Compute the log-Mel spectrogram of the provided audio, gives similar results whisper's original torch
+        Compute the log-mel spectrogram of the provided audio, gives similar results to Whisper's original torch
         implementation with 1e-5 tolerance.
         """
-        window = np.hanning(self.n_fft + 1)[:-1]
-
-        frames = self.fram_wave(waveform)
-        stft = self.stft(frames, window=window)
-        magnitudes = np.abs(stft[:, :-1]) ** 2
-
-        filters = self.mel_filters
-        mel_spec = filters @ magnitudes
-
-        log_spec = 10.0 * np.log10(np.maximum(1e-10, mel_spec))
-        log_spec -= 10.0 * np.log10(np.maximum(1e-10, 1.0))
-        log_spec = np.maximum(log_spec, log_spec.max() - 80.0)
+        log_spec = spectrogram(
+            waveform,
+            window_function(self.n_fft, "hann"),
+            frame_length=self.n_fft,
+            hop_length=self.hop_length,
+            power=2.0,
+            mel_filters=self.mel_filters.T,
+            log_mel="dB",
+            db_range=80.0,
+        )
+        log_spec = log_spec[:, :-1]
         log_spec = log_spec - 20.0
         log_spec = np.clip(log_spec / 40.0, -2.0, 0.0) + 1.0
         return log_spec

     def __call__(
...
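In the TVLT hunk above, `log_mel="dB"` with `db_range=80.0` reproduces the removed arithmetic: 10·log10 of the power mel spectrogram with a small floor, then clamping to 80 dB below the peak. (The removed second line subtracted 10·log10(max(1e-10, 1.0)), i.e. zero, which is why it could simply be dropped.) A condensed sketch of what the old inline code computed:

```python
import numpy as np

def power_to_db_sketch(mel_spec: np.ndarray, db_range: float = 80.0, floor: float = 1e-10) -> np.ndarray:
    # 10 * log10(power), floored to avoid log(0), then clipped to `db_range`
    # below the maximum -- what the removed TVLT code spelled out inline
    log_spec = 10.0 * np.log10(np.maximum(floor, mel_spec))
    return np.maximum(log_spec, log_spec.max() - db_range)
```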
@@ -19,8 +19,8 @@ import copy
 from typing import Any, Dict, List, Optional, Union

 import numpy as np
-from numpy.fft import fft

+from ...audio_utils import mel_filter_bank, spectrogram, window_function
 from ...feature_extraction_sequence_utils import SequenceFeatureExtractor
 from ...feature_extraction_utils import BatchFeature
 from ...utils import TensorType, logging
@@ -81,138 +81,33 @@ class WhisperFeatureExtractor(SequenceFeatureExtractor):
         self.n_samples = chunk_length * sampling_rate
         self.nb_max_frames = self.n_samples // hop_length
         self.sampling_rate = sampling_rate
-        self.mel_filters = self.get_mel_filters(sampling_rate, n_fft, n_mels=feature_size)
-
-    def get_mel_filters(self, sr, n_fft, n_mels=128, dtype=np.float32):
-        # Initialize the weights
-        n_mels = int(n_mels)
-        weights = np.zeros((n_mels, int(1 + n_fft // 2)), dtype=dtype)
-
-        # Center freqs of each FFT bin
-        fftfreqs = np.fft.rfftfreq(n=n_fft, d=1.0 / sr)
-
-        # 'Center freqs' of mel bands - uniformly spaced between limits
-        min_mel = 0.0
-        max_mel = 45.245640471924965
-        mels = np.linspace(min_mel, max_mel, n_mels + 2)
-        mels = np.asanyarray(mels)
-
-        # Fill in the linear scale
-        f_min = 0.0
-        f_sp = 200.0 / 3
-        freqs = f_min + f_sp * mels
-
-        # And now the nonlinear scale
-        min_log_hz = 1000.0  # beginning of log region (Hz)
-        min_log_mel = (min_log_hz - f_min) / f_sp  # same (Mels)
-        logstep = np.log(6.4) / 27.0  # step size for log region
-
-        # If we have vector data, vectorize
-        log_t = mels >= min_log_mel
-        freqs[log_t] = min_log_hz * np.exp(logstep * (mels[log_t] - min_log_mel))
-
-        mel_f = freqs
-
-        fdiff = np.diff(mel_f)
-        ramps = np.subtract.outer(mel_f, fftfreqs)
-
-        for i in range(n_mels):
-            # lower and upper slopes for all bins
-            lower = -ramps[i] / fdiff[i]
-            upper = ramps[i + 2] / fdiff[i + 1]
-
-            # .. then intersect them with each other and zero
-            weights[i] = np.maximum(0, np.minimum(lower, upper))
-
-        # Slaney-style mel is scaled to be approx constant energy per channel
-        enorm = 2.0 / (mel_f[2 : n_mels + 2] - mel_f[:n_mels])
-        weights *= enorm[:, np.newaxis]
-
-        return weights
-
-    def fram_wave(self, waveform, center=True):
-        """
-        Transform a raw waveform into a list of smaller waveforms. The window length defines how much of the signal is
-        contain in each frame (smalle waveform), while the hope length defines the step between the beginning of each
-        new frame.
-
-        Centering is done by reflecting the waveform which is first centered around `frame_idx * hop_length`.
-        """
-        frames = []
-        for i in range(0, waveform.shape[0] + 1, self.hop_length):
-            half_window = (self.n_fft - 1) // 2 + 1
-            if center:
-                start = i - half_window if i > half_window else 0
-                end = i + half_window if i < waveform.shape[0] - half_window else waveform.shape[0]
-
-                frame = waveform[start:end]
-
-                if start == 0:
-                    padd_width = (-i + half_window, 0)
-                    frame = np.pad(frame, pad_width=padd_width, mode="reflect")
-
-                elif end == waveform.shape[0]:
-                    padd_width = (0, (i - waveform.shape[0] + half_window))
-                    frame = np.pad(frame, pad_width=padd_width, mode="reflect")
-
-            else:
-                frame = waveform[i : i + self.n_fft]
-                frame_width = frame.shape[0]
-                if frame_width < waveform.shape[0]:
-                    frame = np.lib.pad(
-                        frame, pad_width=(0, self.n_fft - frame_width), mode="constant", constant_values=0
-                    )
-
-            frames.append(frame)
-        return np.stack(frames, 0)
-
-    def stft(self, frames, window):
-        """
-        Calculates the complex Short-Time Fourier Transform (STFT) of the given framed signal. Should give the same
-        results as `torch.stft`.
-        """
-        frame_size = frames.shape[1]
-        fft_size = self.n_fft
-
-        if fft_size is None:
-            fft_size = frame_size
-
-        if fft_size < frame_size:
-            raise ValueError("FFT size must greater or equal the frame size")
-
-        # number of FFT bins to store
-        num_fft_bins = (fft_size >> 1) + 1
-
-        data = np.empty((len(frames), num_fft_bins), dtype=np.complex64)
-        fft_signal = np.zeros(fft_size)
-
-        for f, frame in enumerate(frames):
-            if window is not None:
-                np.multiply(frame, window, out=fft_signal[:frame_size])
-            else:
-                fft_signal[:frame_size] = frame
-            data[f] = fft(fft_signal, axis=0)[:num_fft_bins]
-        return data.T
+        self.mel_filters = mel_filter_bank(
+            num_frequency_bins=1 + n_fft // 2,
+            num_mel_filters=feature_size,
+            min_frequency=0.0,
+            max_frequency=8000.0,
+            sampling_rate=sampling_rate,
+            norm="slaney",
+            mel_scale="slaney",
+        )

     def _np_extract_fbank_features(self, waveform: np.array) -> np.ndarray:
         """
-        Compute the log-Mel spectrogram of the provided audio, gives similar results whisper's original torch
+        Compute the log-mel spectrogram of the provided audio, gives similar results to Whisper's original torch
        implementation with 1e-5 tolerance.
         """
-        window = np.hanning(self.n_fft + 1)[:-1]
-
-        frames = self.fram_wave(waveform)
-        stft = self.stft(frames, window=window)
-        magnitudes = np.abs(stft[:, :-1]) ** 2
-
-        filters = self.mel_filters
-        mel_spec = filters @ magnitudes
-
-        log_spec = np.log10(np.clip(mel_spec, a_min=1e-10, a_max=None))
+        log_spec = spectrogram(
+            waveform,
+            window_function(self.n_fft, "hann"),
+            frame_length=self.n_fft,
+            hop_length=self.hop_length,
+            power=2.0,
+            mel_filters=self.mel_filters,
+            log_mel="log10",
+        )
+        log_spec = log_spec[:, :-1]
         log_spec = np.maximum(log_spec, log_spec.max() - 8.0)
         log_spec = (log_spec + 4.0) / 4.0
         return log_spec

     @staticmethod
...
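Whisper's trailing normalization stays outside the shared `spectrogram()` call, unchanged by this PR: clamp to 8 log10-units below the peak, then map roughly into [-1, 1]. Condensed:

```python
import numpy as np

def whisper_normalize(log_spec: np.ndarray) -> np.ndarray:
    # dynamic-range compression kept in the feature extractor itself,
    # exactly the two lines left untouched in the diff above
    log_spec = np.maximum(log_spec, log_spec.max() - 8.0)
    return (log_spec + 4.0) / 4.0
```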
@@ -160,6 +160,7 @@ class ASTFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.Test
         # fmt: on

         input_speech = self._load_datasamples(1)
-        feaure_extractor = ASTFeatureExtractor()
-        input_values = feaure_extractor(input_speech, return_tensors="pt").input_values
+        feature_extractor = ASTFeatureExtractor()
+        input_values = feature_extractor(input_speech, return_tensors="pt").input_values
+        self.assertEquals(input_values.shape, (1, 1024, 128))
         self.assertTrue(torch.allclose(input_values[0, 0, :30], EXPECTED_INPUT_VALUES, atol=1e-4))
@@ -21,7 +21,7 @@ import unittest
 import numpy as np

 from transformers import is_speech_available
-from transformers.testing_utils import require_torch, require_torchaudio
+from transformers.testing_utils import require_torch

 from ...test_sequence_feature_extraction_common import SequenceFeatureExtractionTestMixin
@@ -47,7 +47,6 @@ def floats_list(shape, scale=1.0, rng=None, name=None):
 @require_torch
-@require_torchaudio
 class MCTCTFeatureExtractionTester(unittest.TestCase):
     def __init__(
         self,
@@ -102,7 +101,6 @@ class MCTCTFeatureExtractionTester(unittest.TestCase):
 @require_torch
-@require_torchaudio
 class MCTCTFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.TestCase):
     feature_extraction_class = MCTCTFeatureExtractor if is_speech_available() else None
@@ -271,3 +269,38 @@ class MCTCTFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.Te
         self.assertTrue(np_processed.input_features.dtype == np.float32)
         pt_processed = feature_extractor.pad([{"input_features": inputs}], return_tensors="pt")
         self.assertTrue(pt_processed.input_features.dtype == torch.float32)
+
+    def _load_datasamples(self, num_samples):
+        from datasets import load_dataset
+
+        ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
+        # automatic decoding with librispeech
+        speech_samples = ds.sort("id").select(range(num_samples))[:num_samples]["audio"]
+
+        return [x["array"] for x in speech_samples]
+
+    def test_integration(self):
+        # fmt: off
+        expected = np.array([
+            [
+                1.1280, 1.1319, 1.2744, 1.4369, 1.4328, 1.3671, 1.2889, 1.3046,
+                1.4419, 0.8387, 0.2995, 0.0404, 0.1068, 0.0472, 0.3728, 1.3356,
+                1.4491, 0.4770, 0.3997, 0.2776, 0.3184, -0.1243, -0.1170, -0.0828
+            ],
+            [
+                1.0826, 1.0565, 1.2110, 1.3886, 1.3416, 1.2009, 1.1894, 1.2707,
+                1.5153, 0.7005, 0.4916, 0.4017, 0.3743, 0.1935, 0.4228, 1.1084,
+                0.9768, 0.0608, 0.2044, 0.1723, 0.0433, -0.2360, -0.2478, -0.2643
+            ],
+            [
+                1.0590, 0.9923, 1.1185, 1.3309, 1.1971, 1.0067, 1.0080, 1.2036,
+                1.5397, 1.0383, 0.7672, 0.7551, 0.4878, 0.8771, 0.7565, 0.8775,
+                0.9042, 0.4595, 0.6157, 0.4954, 0.1857, 0.0307, 0.0199, 0.1033
+            ],
+        ])
+        # fmt: on
+
+        input_speech = self._load_datasamples(1)
+        feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
+        input_features = feature_extractor(input_speech, sampling_rate=16000, return_tensors="pt").input_features
+        self.assertTrue(np.allclose(input_features[0, 100:103], expected, atol=1e-4))
@@ -247,3 +247,27 @@ class Speech2TextFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unitt
         self.assertTrue(np_processed.input_features.dtype == np.float32)
         pt_processed = feature_extractor.pad([{"input_features": inputs}], return_tensors="pt")
         self.assertTrue(pt_processed.input_features.dtype == torch.float32)
+
+    def _load_datasamples(self, num_samples):
+        from datasets import load_dataset
+
+        ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
+        # automatic decoding with librispeech
+        speech_samples = ds.sort("id").select(range(num_samples))[:num_samples]["audio"]
+
+        return [x["array"] for x in speech_samples]
+
+    def test_integration(self):
+        # fmt: off
+        expected = np.array([
+            -1.5745, -1.7713, -1.7020, -1.6069, -1.2250, -1.1105, -0.9072, -0.8241,
+            -1.2310, -0.8098, -0.3320, -0.4101, -0.7985, -0.4996, -0.8213, -0.9128,
+            -1.0420, -1.1286, -1.0440, -0.7999, -0.8405, -1.2275, -1.5443, -1.4625,
+        ])
+        # fmt: on
+
+        input_speech = self._load_datasamples(1)
+        feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
+        input_features = feature_extractor(input_speech, return_tensors="pt").input_features
+        self.assertEquals(input_features.shape, (1, 584, 24))
+        self.assertTrue(np.allclose(input_features[0, 0, :30], expected, atol=1e-4))
@@ -395,7 +395,8 @@ class SpeechT5FeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest
         input_speech = self._load_datasamples(1)
         feature_extractor = SpeechT5FeatureExtractor()
         input_values = feature_extractor(input_speech, return_tensors="pt").input_values
-        self.assertTrue(torch.allclose(input_values[0, :30], EXPECTED_INPUT_VALUES, atol=1e-4))
+        self.assertEquals(input_values.shape, (1, 93680))
+        self.assertTrue(torch.allclose(input_values[0, :30], EXPECTED_INPUT_VALUES, atol=1e-6))

     def test_integration_target(self):
         # fmt: off
@@ -410,4 +411,5 @@ class SpeechT5FeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest
         input_speech = self._load_datasamples(1)
         feature_extractor = SpeechT5FeatureExtractor()
         input_values = feature_extractor(audio_target=input_speech, return_tensors="pt").input_values
+        self.assertEquals(input_values.shape, (1, 366, 80))
         self.assertTrue(torch.allclose(input_values[0, 0, :30], EXPECTED_INPUT_VALUES, atol=1e-4))
@@ -198,10 +198,10 @@ class TvltFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.Tes
     def test_integration(self):
         input_speech = self._load_datasamples(1)
-        feaure_extractor = TvltFeatureExtractor()
-        audio_values = feaure_extractor(input_speech, return_tensors="pt").audio_values
+        feature_extractor = TvltFeatureExtractor()
+        audio_values = feature_extractor(input_speech, return_tensors="pt").audio_values

-        self.assertTrue(audio_values.shape, [1, 1, 192, 128])
+        self.assertEquals(audio_values.shape, (1, 1, 192, 128))

         expected_slice = torch.tensor([[-0.3032, -0.2708], [-0.4434, -0.4007]])
         self.assertTrue(torch.allclose(audio_values[0, 0, :2, :2], expected_slice, atol=1e-4))
@@ -218,8 +218,9 @@ class WhisperFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.
         # fmt: on

         input_speech = self._load_datasamples(1)
-        feaure_extractor = WhisperFeatureExtractor()
-        input_features = feaure_extractor(input_speech, return_tensors="pt").input_features
+        feature_extractor = WhisperFeatureExtractor()
+        input_features = feature_extractor(input_speech, return_tensors="pt").input_features
+        self.assertEqual(input_features.shape, (1, 80, 3000))
         self.assertTrue(torch.allclose(input_features[0, 0, :30], EXPECTED_INPUT_FEATURES, atol=1e-4))

     def test_zero_mean_unit_variance_normalization_trunc_np_longest(self):
...
This diff is collapsed.