"docs/source/vscode:/vscode.git/clone" did not exist on "1c1a2ffbff2052100053cddb3a87d45fb9d210ca"
Unverified Commit 7f919509 authored by Matthijs Hollemans, committed by GitHub

audio_utils improvements (#21998)

* silly change to allow making a PR

* clean up doc comments

* simplify hertz_to_mel and mel_to_hertz

* fixup

* clean up power_to_db

* also add amplitude_to_db

* move functions

* clean up mel_filter_bank

* fixup

* credit librosa & torchaudio authors

* add unit tests

* tests for power_to_db and amplitude_to_db

* add mel_filter_bank tests

* rewrite STFT

* add convenience spectrogram function

* missing transpose

* fewer transposes

* add integration test to M-CTC-T

* frame length can be either window or FFT length

* rewrite stft API

* add preemphasis coefficient

* move argument

* add log option to spectrogram

* replace M-CTC-T feature extractor

* fix api thing

* replace whisper STFT

* replace whisper mel filters

* replace tvlt's stft

* allow alternate window names

* replace speecht5 stft

* fixup

* fix integration tests

* fix doc comments

* remove manual FFT length calculation

* fix docs

* go away, deprecation warnings

* combine everything into spectrogram function

* add deprecated functions back

* fixup
parent 431b04d8
@@ -12,10 +12,9 @@ specific language governing permissions and limitations under the License.
# Utilities for `FeatureExtractors`
This page lists all the utility functions that can be used by the audio [`FeatureExtractor`] in order to compute special features from a raw audio using common algorithms such as *Short Time Fourier Transform* or *Mel log spectrogram*.
This page lists all the utility functions that can be used by the audio [`FeatureExtractor`] to compute special features from raw audio using common algorithms such as the *Short Time Fourier Transform* or *log mel spectrogram*.
Most of those are only useful if you are studying the code of the image processors in the library.
Most of those are only useful if you are studying the code of the audio processors in the library.
## Audio Transformations
@@ -23,12 +22,14 @@ Most of those are only useful if you are studying the code of the image processo
[[autodoc]] audio_utils.mel_to_hertz
[[autodoc]] audio_utils.get_mel_filter_banks
[[autodoc]] audio_utils.mel_filter_bank
[[autodoc]] audio_utils.stft
[[autodoc]] audio_utils.optimal_fft_length
[[autodoc]] audio_utils.power_to_db
[[autodoc]] audio_utils.window_function
[[autodoc]] audio_utils.fram_wave
[[autodoc]] audio_utils.spectrogram
[[autodoc]] audio_utils.power_to_db
[[autodoc]] audio_utils.amplitude_to_db
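For orientation, here is a minimal sketch of how the functions listed above compose into a log-mel pipeline. The names and keyword arguments are taken from this diff; the concrete values (sampling rate, window size, number of filters) are illustrative only.

```python
import numpy as np

from transformers.audio_utils import (
    mel_filter_bank,
    optimal_fft_length,
    spectrogram,
    window_function,
)

sampling_rate = 16000
waveform = np.random.randn(sampling_rate)  # 1 second of noise, illustrative

frame_length = 400                             # 25 ms analysis window
fft_length = optimal_fft_length(frame_length)  # next power of two: 512

# Shape (num_frequency_bins, num_mel_filters); pass it to spectrogram() as-is.
mel_filters = mel_filter_bank(
    num_frequency_bins=fft_length // 2 + 1,
    num_mel_filters=80,
    min_frequency=0.0,
    max_frequency=sampling_rate / 2.0,
    sampling_rate=sampling_rate,
    norm="slaney",
    mel_scale="slaney",
)

# power=2.0 requests a power spectrogram; log_mel="dB" then applies
# power_to_db (with power=1.0 it would apply amplitude_to_db instead).
log_mel = spectrogram(
    waveform,
    window_function(frame_length, "hann"),
    frame_length=frame_length,
    hop_length=160,  # 10 ms hop
    fft_length=fft_length,
    power=2.0,
    mel_filters=mel_filters,
    log_mel="dB",
)
print(log_mel.shape)  # (num_mel_filters, num_frames)
```

The returned array is frequency-major, which is why several extractors in this diff transpose the result before returning it.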
This diff is collapsed.
@@ -21,7 +21,7 @@ from typing import Any, Dict, List, Optional, Union
import numpy as np
import torch
from ...audio_utils import fram_wave, get_mel_filter_banks, power_to_db, stft
from ...audio_utils import mel_filter_bank, spectrogram, window_function
from ...feature_extraction_sequence_utils import SequenceFeatureExtractor
from ...feature_extraction_utils import BatchFeature
from ...utils import TensorType, logging
@@ -116,21 +116,21 @@ class ClapFeatureExtractor(SequenceFeatureExtractor):
self.sampling_rate = sampling_rate
self.frequency_min = frequency_min
self.frequency_max = frequency_max
self.mel_filters = get_mel_filter_banks(
nb_frequency_bins=self.nb_frequency_bins,
nb_mel_filters=feature_size,
frequency_min=frequency_min,
frequency_max=frequency_max,
sample_rate=sampling_rate,
self.mel_filters = mel_filter_bank(
num_frequency_bins=self.nb_frequency_bins,
num_mel_filters=feature_size,
min_frequency=frequency_min,
max_frequency=frequency_max,
sampling_rate=sampling_rate,
norm=None,
mel_scale="htk",
)
self.mel_filters_slaney = get_mel_filter_banks(
nb_frequency_bins=self.nb_frequency_bins,
nb_mel_filters=feature_size,
frequency_min=frequency_min,
frequency_max=frequency_max,
sample_rate=sampling_rate,
self.mel_filters_slaney = mel_filter_bank(
num_frequency_bins=self.nb_frequency_bins,
num_mel_filters=feature_size,
min_frequency=frequency_min,
max_frequency=frequency_max,
sampling_rate=sampling_rate,
norm="slaney",
mel_scale="slaney",
)
@@ -153,24 +153,25 @@ class ClapFeatureExtractor(SequenceFeatureExtractor):
def _np_extract_fbank_features(self, waveform: np.array, mel_filters: Optional[np.array] = None) -> np.ndarray:
"""
Compute the log-Mel spectrogram of the provided `waveform` using the `hanning` window. In CLAP, two different
filter banks are used depending on the truncation pattern:
- `self.mel_filters`: they correspond to the defaults parameters of `torchaduio` which can be obtained from
Compute the log-mel spectrogram of the provided `waveform` using the Hann window. In CLAP, two different filter
banks are used depending on the truncation pattern:
- `self.mel_filters`: they correspond to the default parameters of `torchaudio` which can be obtained from
calling `torchaudio.transforms.MelSpectrogram().mel_scale.fb`. These filters are used when `truncation`
is set to `"fusion"`.
- `self.mel_filteres_slaney` : they correspond to the defaults parameters of `torchlibrosa` which used
- `self.mel_filters_slaney`: they correspond to the default parameters of `librosa` which used
`librosa.filters.mel` when computing the mel spectrogram. These filters were only used in the original
implementation when the truncation mode is not `"fusion"`.
"""
window = np.hanning(self.fft_window_size + 1)[:-1]
frames = fram_wave(waveform, self.hop_length, self.fft_window_size)
spectrogram = stft(frames, window, fft_window_size=self.fft_window_size)
magnitudes = np.abs(spectrogram) ** 2
mel_spectrogram = np.matmul(mel_filters.T, magnitudes)
log_mel_spectrogram = power_to_db(mel_spectrogram).T
log_mel_spectrogram = np.asarray(log_mel_spectrogram, np.float32)
return log_mel_spectrogram
log_mel_spectrogram = spectrogram(
waveform,
window_function(self.fft_window_size, "hann"),
frame_length=self.fft_window_size,
hop_length=self.hop_length,
power=2.0,
mel_filters=mel_filters,
log_mel="dB",
)
return log_mel_spectrogram.T
def _random_mel_fusion(self, mel, total_frames, chunk_frames):
ranges = np.array_split(list(range(0, total_frames - chunk_frames + 1)), 3)
......
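A note on the two CLAP filter banks above: the `norm`/`mel_scale` pair is what distinguishes the torchaudio-style bank from the librosa-style one. A sketch follows, with illustrative frequency values rather than CLAP's actual configuration:

```python
import numpy as np

from transformers.audio_utils import mel_filter_bank

num_frequency_bins, num_mel_filters, sampling_rate = 513, 64, 48000

# torchaudio-style defaults: HTK mel scale, no area normalization.
# CLAP uses this bank when truncation == "fusion".
fb_htk = mel_filter_bank(
    num_frequency_bins=num_frequency_bins,
    num_mel_filters=num_mel_filters,
    min_frequency=0.0,
    max_frequency=14000.0,
    sampling_rate=sampling_rate,
    norm=None,
    mel_scale="htk",
)

# librosa-style defaults: Slaney mel scale with Slaney area normalization.
# CLAP uses this bank for the other truncation modes.
fb_slaney = mel_filter_bank(
    num_frequency_bins=num_frequency_bins,
    num_mel_filters=num_mel_filters,
    min_frequency=0.0,
    max_frequency=14000.0,
    sampling_rate=sampling_rate,
    norm="slaney",
    mel_scale="slaney",
)

assert fb_htk.shape == fb_slaney.shape == (num_frequency_bins, num_mel_filters)
assert not np.allclose(fb_htk, fb_slaney)  # different mel scales, different banks
```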
@@ -20,9 +20,8 @@ from typing import List, Optional, Union
import numpy as np
import torch
import torchaudio
from packaging import version
from ...audio_utils import mel_filter_bank, optimal_fft_length, spectrogram
from ...feature_extraction_sequence_utils import SequenceFeatureExtractor
from ...feature_extraction_utils import BatchFeature
from ...file_utils import PaddingStrategy, TensorType
@@ -31,13 +30,6 @@ from ...utils import logging
logger = logging.get_logger(__name__)
parsed_torchaudio_version_base = version.parse(version.parse(torchaudio.__version__).base_version)
if not parsed_torchaudio_version_base >= version.parse("0.10"):
logger.warning(
f"You are using torchaudio=={torchaudio.__version__}, but torchaudio>=0.10.0 is required to use "
"MCTCTFeatureExtractor. This requires torch>=1.10.0. Please upgrade torch and torchaudio."
)
class MCTCTFeatureExtractor(SequenceFeatureExtractor):
r"""
@@ -110,68 +102,9 @@ class MCTCTFeatureExtractor(SequenceFeatureExtractor):
self.sample_size = win_length * sampling_rate // 1000
self.sample_stride = hop_length * sampling_rate // 1000
self.n_fft = 2 ** int(np.ceil(np.log2(self.sample_size)))
self.n_fft = optimal_fft_length(self.sample_size)
self.n_freqs = (self.n_fft // 2) + 1
@staticmethod
def _num_frames_calc(in_size, frame_size, frame_stride):
return int(1 + np.floor((in_size - frame_size) * 1 / frame_stride))
@staticmethod
def _frame_signal(one_waveform, n_frames, frame_signal_scale, window_length, sample_stride):
scale = frame_signal_scale
frames = np.zeros(n_frames * window_length)
for frame_idx in range(n_frames):
start = frame_idx * window_length
end = (frame_idx + 1) * window_length
wave_start = frame_idx * sample_stride
wave_end = frame_idx * sample_stride + window_length
frames[start:end] = scale * one_waveform[wave_start:wave_end]
return frames
@staticmethod
def _apply_preemphasis_inplace(frames, window_length, preemphasis_coeff):
if frames.size % window_length != 0:
raise ValueError(
f"`frames` is supposed to have length divisble by `window_length`, but is {frames.size} with"
f" window_length={window_length}."
)
n_frames = frames.size // window_length
for frame_idx in range(n_frames, 0, -1):
start = (frame_idx - 1) * window_length
end = frame_idx * window_length - 1
frames[start + 1 : end + 1] -= preemphasis_coeff * frames[start:end]
frames[start] *= 1 - preemphasis_coeff
@staticmethod
def _windowing(frames, window_length, window):
if frames.size % window_length != 0:
raise ValueError(
f"`frames` is supposed to have length divisble by `window_length`, but is {frames.size} with"
f" window_length={window_length}."
)
shaped = frames.reshape(-1, window_length)
shaped = window * shaped
return shaped
@staticmethod
def _dft(frames, K, n_frames, n_samples, n_fft):
dft = np.zeros([n_frames, K])
for frame in range(n_frames):
begin = frame * n_samples
inwards_buffer = frames[begin : begin + n_samples]
inwards_buffer = np.pad(inwards_buffer, (0, n_fft - n_samples), "constant")
out = np.fft.rfft(inwards_buffer)
dft[frame] = np.abs(out[:K])
return dft
def _extract_mfsc_features(self, one_waveform: np.array) -> np.ndarray:
"""
Extracts MFSC Features for one waveform vector (unbatched). Adapted from Flashlight's C++ MFSC code.
@@ -183,36 +116,27 @@ class MCTCTFeatureExtractor(SequenceFeatureExtractor):
window = window.numpy()
fbanks = torchaudio.functional.melscale_fbanks(
n_freqs=self.n_freqs,
f_min=0.0, # change this to zeros
f_max=self.sampling_rate / 2.0,
n_mels=self.feature_size,
sample_rate=self.sampling_rate,
fbanks = mel_filter_bank(
num_frequency_bins=self.n_freqs,
num_mel_filters=self.feature_size,
min_frequency=0.0,
max_frequency=self.sampling_rate / 2.0,
sampling_rate=self.sampling_rate,
)
fbanks = fbanks.numpy()
n_frames = self._num_frames_calc(one_waveform.size, self.sample_size, self.sample_stride)
frames = self._frame_signal(
one_waveform, n_frames, self.frame_signal_scale, self.sample_size, self.sample_stride
msfc_features = spectrogram(
one_waveform * self.frame_signal_scale,
window=window,
frame_length=self.sample_size,
hop_length=self.sample_stride,
fft_length=self.n_fft,
center=False,
preemphasis=self.preemphasis_coeff,
mel_filters=fbanks,
mel_floor=self.mel_floor,
log_mel="log",
)
self._apply_preemphasis_inplace(frames, self.sample_size, self.preemphasis_coeff)
frames = self._windowing(frames, self.sample_size, window)
dft_out = self._dft(frames.flatten(), self.n_freqs, n_frames, self.sample_size, self.n_fft)
# msfc_features = STFT * mel frequency banks.
msfc_features = np.einsum("...tf,fm->...tm", dft_out, fbanks)
# clamp feature values then log scale, as implemented in flashlight
msfc_features = np.maximum(msfc_features, self.mel_floor)
msfc_features = np.log(msfc_features)
return msfc_features
return msfc_features.T
def _normalize_one(self, x, input_length, padding_value):
# make sure we normalize float32 arrays
......
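The M-CTC-T change above folds five hand-written helpers (`_num_frames_calc`, `_frame_signal`, `_apply_preemphasis_inplace`, `_windowing`, `_dft`) into one `spectrogram` call. A sketch of that call path, assuming a 25 ms Hamming window, a 10 ms hop, preemphasis 0.97, and a frame scale of 32768.0 (the real values come from the extractor's config):

```python
import numpy as np

from transformers.audio_utils import (
    mel_filter_bank,
    optimal_fft_length,
    spectrogram,
    window_function,
)

sampling_rate = 16000
sample_size = 25 * sampling_rate // 1000    # 400-sample (25 ms) frames
sample_stride = 10 * sampling_rate // 1000  # 160-sample (10 ms) hop
n_fft = optimal_fft_length(sample_size)     # 512; replaces 2**ceil(log2(...))

fbanks = mel_filter_bank(
    num_frequency_bins=n_fft // 2 + 1,
    num_mel_filters=8,  # feature_size, assumed value
    min_frequency=0.0,
    max_frequency=sampling_rate / 2.0,
    sampling_rate=sampling_rate,
)

waveform = np.random.randn(sampling_rate)
msfc_features = spectrogram(
    waveform * 32768.0,  # frame_signal_scale, replaces _frame_signal's scaling
    window=window_function(sample_size, "hamming"),  # assumed window type
    frame_length=sample_size,
    hop_length=sample_stride,
    fft_length=n_fft,
    center=False,      # Flashlight-style framing: no edge padding
    preemphasis=0.97,  # replaces the _apply_preemphasis_inplace loop
    mel_filters=fbanks,
    mel_floor=1.0,     # clamp before the log, as Flashlight does
    log_mel="log",     # natural log, matching the replaced np.log
)
# Shape (num_mel_filters, num_frames); the extractor returns the transpose.
```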
@@ -20,7 +20,7 @@ from typing import Any, Dict, List, Optional, Union
import numpy as np
import torch
from ...audio_utils import get_mel_filter_banks
from ...audio_utils import mel_filter_bank, optimal_fft_length, spectrogram
from ...feature_extraction_sequence_utils import SequenceFeatureExtractor
from ...feature_extraction_utils import BatchFeature
from ...utils import PaddingStrategy, TensorType, logging
@@ -110,18 +110,18 @@ class SpeechT5FeatureExtractor(SequenceFeatureExtractor):
self.sample_size = win_length * sampling_rate // 1000
self.sample_stride = hop_length * sampling_rate // 1000
self.n_fft = 2 ** int(np.ceil(np.log2(self.sample_size)))
self.n_fft = optimal_fft_length(self.sample_size)
self.n_freqs = (self.n_fft // 2) + 1
window = getattr(torch, self.win_function)(window_length=self.sample_size, periodic=True)
self.window = window.numpy().astype(np.float64)
self.mel_filters = get_mel_filter_banks(
nb_frequency_bins=self.n_freqs,
nb_mel_filters=self.num_mel_bins,
frequency_min=self.fmin,
frequency_max=self.fmax,
sample_rate=self.sampling_rate,
self.mel_filters = mel_filter_bank(
num_frequency_bins=self.n_freqs,
num_mel_filters=self.num_mel_bins,
min_frequency=self.fmin,
max_frequency=self.fmax,
sampling_rate=self.sampling_rate,
norm="slaney",
mel_scale="slaney",
)
@@ -160,31 +160,6 @@ class SpeechT5FeatureExtractor(SequenceFeatureExtractor):
return normed_input_values
@staticmethod
def _stft(waveform: np.ndarray, fft_length: int, hop_length: int, window: np.ndarray) -> np.ndarray:
"""
Calculates the magnitude spectrogram over one waveform array.
"""
# center pad the waveform
padding = [(int(fft_length // 2), int(fft_length // 2))]
waveform = np.pad(waveform, padding, mode="reflect")
waveform_size = waveform.size
# promote to float64, since np.fft uses float64 internally
waveform = waveform.astype(np.float64)
num_frames = int(1 + np.floor((waveform_size - fft_length) / hop_length))
num_frequency_bins = (fft_length // 2) + 1
spectrogram = np.empty((num_frames, num_frequency_bins))
start = 0
for frame_idx in range(num_frames):
frame = waveform[start : start + fft_length] * window
spectrogram[frame_idx] = np.abs(np.fft.rfft(frame))
start += hop_length
return spectrogram
def _extract_mel_features(
self,
one_waveform: np.ndarray,
@@ -192,14 +167,17 @@ class SpeechT5FeatureExtractor(SequenceFeatureExtractor):
"""
Extracts log-mel filterbank features for one waveform array (unbatched).
"""
if self.n_fft != self.sample_size:
raise NotImplementedError(
f"Currently the STFT frame size must be a power of two, but got {self.sample_size} for a window length of {self.win_length} and sampling rate of {self.sampling_rate}. Ensure `win_length * sampling_rate // 1000` is divisible by two."
)
stft_out = self._stft(one_waveform, self.n_fft, self.sample_stride, self.window)
return np.log10(np.maximum(self.mel_floor, np.dot(stft_out, self.mel_filters)))
log_mel_spec = spectrogram(
one_waveform,
window=self.window,
frame_length=self.sample_size,
hop_length=self.sample_stride,
fft_length=self.n_fft,
mel_filters=self.mel_filters,
mel_floor=self.mel_floor,
log_mel="log10",
)
return log_mel_spec.T
def __call__(
self,
......
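As a sanity check on the removed `_stft` above: with `center=True` and `pad_mode="reflect"` (both defaults), `spectrogram` reproduces the manual reflect-padded magnitude STFT. A sketch assuming a plain Hann window, whereas SpeechT5 actually builds its window from `self.win_function`:

```python
import numpy as np

from transformers.audio_utils import spectrogram

fft_length, hop_length = 512, 128
window = np.hanning(fft_length)  # assumed; SpeechT5 uses a periodic torch window
waveform = np.random.randn(4096)

# The removed _stft, condensed: center-pad by reflection, frame, window, |rfft|.
padded = np.pad(waveform, fft_length // 2, mode="reflect").astype(np.float64)
num_frames = 1 + (padded.size - fft_length) // hop_length
manual = np.array([
    np.abs(np.fft.rfft(padded[i * hop_length : i * hop_length + fft_length] * window))
    for i in range(num_frames)
])

# The replacement: magnitude (power=1.0), centering, and reflect padding are defaults.
stft_out = spectrogram(
    waveform,
    window=window,
    frame_length=fft_length,
    hop_length=hop_length,
    fft_length=fft_length,
)

# spectrogram() returns (num_frequency_bins, num_frames), hence the transpose.
assert np.allclose(manual.T, stft_out, atol=1e-5)
```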
@@ -18,8 +18,8 @@ from math import ceil
from typing import List, Optional, Union
import numpy as np
from numpy.fft import fft
from ...audio_utils import mel_filter_bank, spectrogram, window_function
from ...feature_extraction_sequence_utils import BatchFeature, SequenceFeatureExtractor
from ...utils import TensorType, logging
@@ -83,143 +83,34 @@ class TvltFeatureExtractor(SequenceFeatureExtractor):
self.hop_length = sampling_rate // hop_length_to_sampling_rate
self.sampling_rate = sampling_rate
self.padding_value = padding_value
self.mel_filters = self.get_mel_filters(sampling_rate, n_fft, n_mels=feature_size)
# Copied from transformers.models.whisper.feature_extraction_whisper.WhisperFeatureExtractor.get_mel_filters with 45.245640471924965->59.99247463746737
def get_mel_filters(self, sr, n_fft, n_mels=128, dtype=np.float32):
# Initialize the weights
n_mels = int(n_mels)
weights = np.zeros((n_mels, int(1 + n_fft // 2)), dtype=dtype)
# Center freqs of each FFT bin
fftfreqs = np.fft.rfftfreq(n=n_fft, d=1.0 / sr)
# 'Center freqs' of mel bands - uniformly spaced between limits
min_mel = 0.0
max_mel = 59.99247463746737
mels = np.linspace(min_mel, max_mel, n_mels + 2)
mels = np.asanyarray(mels)
# Fill in the linear scale
f_min = 0.0
f_sp = 200.0 / 3
freqs = f_min + f_sp * mels
# And now the nonlinear scale
min_log_hz = 1000.0 # beginning of log region (Hz)
min_log_mel = (min_log_hz - f_min) / f_sp # same (Mels)
logstep = np.log(6.4) / 27.0 # step size for log region
# If we have vector data, vectorize
log_t = mels >= min_log_mel
freqs[log_t] = min_log_hz * np.exp(logstep * (mels[log_t] - min_log_mel))
mel_f = freqs
fdiff = np.diff(mel_f)
ramps = np.subtract.outer(mel_f, fftfreqs)
for i in range(n_mels):
# lower and upper slopes for all bins
lower = -ramps[i] / fdiff[i]
upper = ramps[i + 2] / fdiff[i + 1]
# .. then intersect them with each other and zero
weights[i] = np.maximum(0, np.minimum(lower, upper))
# Slaney-style mel is scaled to be approx constant energy per channel
enorm = 2.0 / (mel_f[2 : n_mels + 2] - mel_f[:n_mels])
weights *= enorm[:, np.newaxis]
return weights
# Copied from transformers.models.whisper.feature_extraction_whisper.WhisperFeatureExtractor.fram_wave
def fram_wave(self, waveform, center=True):
"""
Transform a raw waveform into a list of smaller waveforms. The window length defines how much of the signal is
contain in each frame (smalle waveform), while the hope length defines the step between the beginning of each
new frame.
Centering is done by reflecting the waveform which is first centered around `frame_idx * hop_length`.
"""
frames = []
for i in range(0, waveform.shape[0] + 1, self.hop_length):
half_window = (self.n_fft - 1) // 2 + 1
if center:
start = i - half_window if i > half_window else 0
end = i + half_window if i < waveform.shape[0] - half_window else waveform.shape[0]
frame = waveform[start:end]
if start == 0:
padd_width = (-i + half_window, 0)
frame = np.pad(frame, pad_width=padd_width, mode="reflect")
elif end == waveform.shape[0]:
padd_width = (0, (i - waveform.shape[0] + half_window))
frame = np.pad(frame, pad_width=padd_width, mode="reflect")
else:
frame = waveform[i : i + self.n_fft]
frame_width = frame.shape[0]
if frame_width < waveform.shape[0]:
frame = np.lib.pad(
frame, pad_width=(0, self.n_fft - frame_width), mode="constant", constant_values=0
)
frames.append(frame)
return np.stack(frames, 0)
# Copied from transformers.models.whisper.feature_extraction_whisper.WhisperFeatureExtractor.stft
def stft(self, frames, window):
"""
Calculates the complex Short-Time Fourier Transform (STFT) of the given framed signal. Should give the same
results as `torch.stft`.
"""
frame_size = frames.shape[1]
fft_size = self.n_fft
if fft_size is None:
fft_size = frame_size
if fft_size < frame_size:
raise ValueError("FFT size must greater or equal the frame size")
# number of FFT bins to store
num_fft_bins = (fft_size >> 1) + 1
data = np.empty((len(frames), num_fft_bins), dtype=np.complex64)
fft_signal = np.zeros(fft_size)
for f, frame in enumerate(frames):
if window is not None:
np.multiply(frame, window, out=fft_signal[:frame_size])
else:
fft_signal[:frame_size] = frame
data[f] = fft(fft_signal, axis=0)[:num_fft_bins]
return data.T
self.mel_filters = mel_filter_bank(
num_frequency_bins=1 + n_fft // 2,
num_mel_filters=feature_size,
min_frequency=0.0,
max_frequency=22050.0,
sampling_rate=sampling_rate,
norm="slaney",
mel_scale="slaney",
).T
def _np_extract_fbank_features(self, waveform: np.array) -> np.ndarray:
"""
Compute the log-Mel spectrogram of the provided audio, gives similar results whisper's original torch
Compute the log-mel spectrogram of the provided audio, giving results similar to Whisper's original torch
implementation with 1e-5 tolerance.
"""
window = np.hanning(self.n_fft + 1)[:-1]
frames = self.fram_wave(waveform)
stft = self.stft(frames, window=window)
magnitudes = np.abs(stft[:, :-1]) ** 2
filters = self.mel_filters
mel_spec = filters @ magnitudes
log_spec = 10.0 * np.log10(np.maximum(1e-10, mel_spec))
log_spec -= 10.0 * np.log10(np.maximum(1e-10, 1.0))
log_spec = np.maximum(log_spec, log_spec.max() - 80.0)
log_spec = spectrogram(
waveform,
window_function(self.n_fft, "hann"),
frame_length=self.n_fft,
hop_length=self.hop_length,
power=2.0,
mel_filters=self.mel_filters.T,
log_mel="dB",
db_range=80.0,
)
log_spec = log_spec[:, :-1]
log_spec = log_spec - 20.0
log_spec = np.clip(log_spec / 40.0, -2.0, 0.0) + 1.0
return log_spec
def __call__(
......
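On the TVLT rewrite above: `log_mel="dB"` with `db_range=80.0` is the `power_to_db` path, and it subsumes the old max-minus-80 clamp. A toy sketch of just that scaling step:

```python
import numpy as np

from transformers.audio_utils import power_to_db

# Toy power mel spectrogram with a wide dynamic range.
mel_power = np.array([[1e-12, 1e-4, 1.0, 100.0]])

# Inside spectrogram(), log_mel="dB" on a power=2.0 spectrogram calls
# power_to_db; db_range=80.0 clamps values below (max - 80) dB, which is
# what the removed `log_spec.max() - 80.0` line did by hand.
log_spec = power_to_db(mel_power, db_range=80.0)

# TVLT then shifts and rescales the result into [-1, 1]:
log_spec = np.clip((log_spec - 20.0) / 40.0, -2.0, 0.0) + 1.0
```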
@@ -19,8 +19,8 @@ import copy
from typing import Any, Dict, List, Optional, Union
import numpy as np
from numpy.fft import fft
from ...audio_utils import mel_filter_bank, spectrogram, window_function
from ...feature_extraction_sequence_utils import SequenceFeatureExtractor
from ...feature_extraction_utils import BatchFeature
from ...utils import TensorType, logging
@@ -81,138 +81,33 @@ class WhisperFeatureExtractor(SequenceFeatureExtractor):
self.n_samples = chunk_length * sampling_rate
self.nb_max_frames = self.n_samples // hop_length
self.sampling_rate = sampling_rate
self.mel_filters = self.get_mel_filters(sampling_rate, n_fft, n_mels=feature_size)
def get_mel_filters(self, sr, n_fft, n_mels=128, dtype=np.float32):
# Initialize the weights
n_mels = int(n_mels)
weights = np.zeros((n_mels, int(1 + n_fft // 2)), dtype=dtype)
# Center freqs of each FFT bin
fftfreqs = np.fft.rfftfreq(n=n_fft, d=1.0 / sr)
# 'Center freqs' of mel bands - uniformly spaced between limits
min_mel = 0.0
max_mel = 45.245640471924965
mels = np.linspace(min_mel, max_mel, n_mels + 2)
mels = np.asanyarray(mels)
# Fill in the linear scale
f_min = 0.0
f_sp = 200.0 / 3
freqs = f_min + f_sp * mels
# And now the nonlinear scale
min_log_hz = 1000.0 # beginning of log region (Hz)
min_log_mel = (min_log_hz - f_min) / f_sp # same (Mels)
logstep = np.log(6.4) / 27.0 # step size for log region
# If we have vector data, vectorize
log_t = mels >= min_log_mel
freqs[log_t] = min_log_hz * np.exp(logstep * (mels[log_t] - min_log_mel))
mel_f = freqs
fdiff = np.diff(mel_f)
ramps = np.subtract.outer(mel_f, fftfreqs)
for i in range(n_mels):
# lower and upper slopes for all bins
lower = -ramps[i] / fdiff[i]
upper = ramps[i + 2] / fdiff[i + 1]
# .. then intersect them with each other and zero
weights[i] = np.maximum(0, np.minimum(lower, upper))
# Slaney-style mel is scaled to be approx constant energy per channel
enorm = 2.0 / (mel_f[2 : n_mels + 2] - mel_f[:n_mels])
weights *= enorm[:, np.newaxis]
return weights
def fram_wave(self, waveform, center=True):
"""
Transform a raw waveform into a list of smaller waveforms. The window length defines how much of the signal is
contain in each frame (smalle waveform), while the hope length defines the step between the beginning of each
new frame.
Centering is done by reflecting the waveform which is first centered around `frame_idx * hop_length`.
"""
frames = []
for i in range(0, waveform.shape[0] + 1, self.hop_length):
half_window = (self.n_fft - 1) // 2 + 1
if center:
start = i - half_window if i > half_window else 0
end = i + half_window if i < waveform.shape[0] - half_window else waveform.shape[0]
frame = waveform[start:end]
if start == 0:
padd_width = (-i + half_window, 0)
frame = np.pad(frame, pad_width=padd_width, mode="reflect")
elif end == waveform.shape[0]:
padd_width = (0, (i - waveform.shape[0] + half_window))
frame = np.pad(frame, pad_width=padd_width, mode="reflect")
else:
frame = waveform[i : i + self.n_fft]
frame_width = frame.shape[0]
if frame_width < waveform.shape[0]:
frame = np.lib.pad(
frame, pad_width=(0, self.n_fft - frame_width), mode="constant", constant_values=0
)
frames.append(frame)
return np.stack(frames, 0)
def stft(self, frames, window):
"""
Calculates the complex Short-Time Fourier Transform (STFT) of the given framed signal. Should give the same
results as `torch.stft`.
"""
frame_size = frames.shape[1]
fft_size = self.n_fft
if fft_size is None:
fft_size = frame_size
if fft_size < frame_size:
raise ValueError("FFT size must greater or equal the frame size")
# number of FFT bins to store
num_fft_bins = (fft_size >> 1) + 1
data = np.empty((len(frames), num_fft_bins), dtype=np.complex64)
fft_signal = np.zeros(fft_size)
for f, frame in enumerate(frames):
if window is not None:
np.multiply(frame, window, out=fft_signal[:frame_size])
else:
fft_signal[:frame_size] = frame
data[f] = fft(fft_signal, axis=0)[:num_fft_bins]
return data.T
self.mel_filters = mel_filter_bank(
num_frequency_bins=1 + n_fft // 2,
num_mel_filters=feature_size,
min_frequency=0.0,
max_frequency=8000.0,
sampling_rate=sampling_rate,
norm="slaney",
mel_scale="slaney",
)
def _np_extract_fbank_features(self, waveform: np.array) -> np.ndarray:
"""
Compute the log-Mel spectrogram of the provided audio, gives similar results whisper's original torch
Compute the log-mel spectrogram of the provided audio, giving results similar to Whisper's original torch
implementation with 1e-5 tolerance.
"""
window = np.hanning(self.n_fft + 1)[:-1]
frames = self.fram_wave(waveform)
stft = self.stft(frames, window=window)
magnitudes = np.abs(stft[:, :-1]) ** 2
filters = self.mel_filters
mel_spec = filters @ magnitudes
log_spec = np.log10(np.clip(mel_spec, a_min=1e-10, a_max=None))
log_spec = spectrogram(
waveform,
window_function(self.n_fft, "hann"),
frame_length=self.n_fft,
hop_length=self.hop_length,
power=2.0,
mel_filters=self.mel_filters,
log_mel="log10",
)
log_spec = log_spec[:, :-1]
log_spec = np.maximum(log_spec, log_spec.max() - 8.0)
log_spec = (log_spec + 4.0) / 4.0
return log_spec
@staticmethod
......
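Putting the Whisper changes together, here is a sketch of the complete replacement path as this diff reads (not the extractor's verbatim code), using Whisper's standard values of n_fft=400, hop_length=160, and 80 mel bins at 16 kHz; the trailing-frame drop and dynamic-range normalization still happen outside `spectrogram`:

```python
import numpy as np

from transformers.audio_utils import mel_filter_bank, spectrogram, window_function

sampling_rate, n_fft, hop_length, n_mels = 16000, 400, 160, 80

mel_filters = mel_filter_bank(
    num_frequency_bins=1 + n_fft // 2,
    num_mel_filters=n_mels,
    min_frequency=0.0,
    max_frequency=8000.0,
    sampling_rate=sampling_rate,
    norm="slaney",
    mel_scale="slaney",
)

waveform = np.random.randn(30 * sampling_rate)  # Whisper pads/trims to 30 s

log_spec = spectrogram(
    waveform,
    window_function(n_fft, "hann"),
    frame_length=n_fft,
    hop_length=hop_length,
    power=2.0,
    mel_filters=mel_filters,
    log_mel="log10",  # mel_floor defaults to 1e-10, matching the old clip
)
log_spec = log_spec[:, :-1]                            # drop the trailing frame
log_spec = np.maximum(log_spec, log_spec.max() - 8.0)  # keep an 80 dB range
log_spec = (log_spec + 4.0) / 4.0                      # normalize roughly to [-1, 1]
```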
@@ -160,6 +160,7 @@ class ASTFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.Test
# fmt: on
input_speech = self._load_datasamples(1)
feaure_extractor = ASTFeatureExtractor()
input_values = feaure_extractor(input_speech, return_tensors="pt").input_values
feature_extractor = ASTFeatureExtractor()
input_values = feature_extractor(input_speech, return_tensors="pt").input_values
self.assertEquals(input_values.shape, (1, 1024, 128))
self.assertTrue(torch.allclose(input_values[0, 0, :30], EXPECTED_INPUT_VALUES, atol=1e-4))
@@ -21,7 +21,7 @@ import unittest
import numpy as np
from transformers import is_speech_available
from transformers.testing_utils import require_torch, require_torchaudio
from transformers.testing_utils import require_torch
from ...test_sequence_feature_extraction_common import SequenceFeatureExtractionTestMixin
@@ -47,7 +47,6 @@ def floats_list(shape, scale=1.0, rng=None, name=None):
@require_torch
@require_torchaudio
class MCTCTFeatureExtractionTester(unittest.TestCase):
def __init__(
self,
@@ -102,7 +101,6 @@ class MCTCTFeatureExtractionTester(unittest.TestCase):
@require_torch
@require_torchaudio
class MCTCTFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.TestCase):
feature_extraction_class = MCTCTFeatureExtractor if is_speech_available() else None
@@ -271,3 +269,38 @@ class MCTCTFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.Te
self.assertTrue(np_processed.input_features.dtype == np.float32)
pt_processed = feature_extractor.pad([{"input_features": inputs}], return_tensors="pt")
self.assertTrue(pt_processed.input_features.dtype == torch.float32)
def _load_datasamples(self, num_samples):
from datasets import load_dataset
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
# automatic decoding with librispeech
speech_samples = ds.sort("id").select(range(num_samples))[:num_samples]["audio"]
return [x["array"] for x in speech_samples]
def test_integration(self):
# fmt: off
expected = np.array([
[
1.1280, 1.1319, 1.2744, 1.4369, 1.4328, 1.3671, 1.2889, 1.3046,
1.4419, 0.8387, 0.2995, 0.0404, 0.1068, 0.0472, 0.3728, 1.3356,
1.4491, 0.4770, 0.3997, 0.2776, 0.3184, -0.1243, -0.1170, -0.0828
],
[
1.0826, 1.0565, 1.2110, 1.3886, 1.3416, 1.2009, 1.1894, 1.2707,
1.5153, 0.7005, 0.4916, 0.4017, 0.3743, 0.1935, 0.4228, 1.1084,
0.9768, 0.0608, 0.2044, 0.1723, 0.0433, -0.2360, -0.2478, -0.2643
],
[
1.0590, 0.9923, 1.1185, 1.3309, 1.1971, 1.0067, 1.0080, 1.2036,
1.5397, 1.0383, 0.7672, 0.7551, 0.4878, 0.8771, 0.7565, 0.8775,
0.9042, 0.4595, 0.6157, 0.4954, 0.1857, 0.0307, 0.0199, 0.1033
],
])
# fmt: on
input_speech = self._load_datasamples(1)
feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
input_features = feature_extractor(input_speech, sampling_rate=16000, return_tensors="pt").input_features
self.assertTrue(np.allclose(input_features[0, 100:103], expected, atol=1e-4))
@@ -247,3 +247,27 @@ class Speech2TextFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unitt
self.assertTrue(np_processed.input_features.dtype == np.float32)
pt_processed = feature_extractor.pad([{"input_features": inputs}], return_tensors="pt")
self.assertTrue(pt_processed.input_features.dtype == torch.float32)
def _load_datasamples(self, num_samples):
from datasets import load_dataset
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
# automatic decoding with librispeech
speech_samples = ds.sort("id").select(range(num_samples))[:num_samples]["audio"]
return [x["array"] for x in speech_samples]
def test_integration(self):
# fmt: off
expected = np.array([
-1.5745, -1.7713, -1.7020, -1.6069, -1.2250, -1.1105, -0.9072, -0.8241,
-1.2310, -0.8098, -0.3320, -0.4101, -0.7985, -0.4996, -0.8213, -0.9128,
-1.0420, -1.1286, -1.0440, -0.7999, -0.8405, -1.2275, -1.5443, -1.4625,
])
# fmt: on
input_speech = self._load_datasamples(1)
feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
input_features = feature_extractor(input_speech, return_tensors="pt").input_features
self.assertEquals(input_features.shape, (1, 584, 24))
self.assertTrue(np.allclose(input_features[0, 0, :30], expected, atol=1e-4))
@@ -395,7 +395,8 @@ class SpeechT5FeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest
input_speech = self._load_datasamples(1)
feature_extractor = SpeechT5FeatureExtractor()
input_values = feature_extractor(input_speech, return_tensors="pt").input_values
self.assertTrue(torch.allclose(input_values[0, :30], EXPECTED_INPUT_VALUES, atol=1e-4))
self.assertEquals(input_values.shape, (1, 93680))
self.assertTrue(torch.allclose(input_values[0, :30], EXPECTED_INPUT_VALUES, atol=1e-6))
def test_integration_target(self):
# fmt: off
@@ -410,4 +411,5 @@ class SpeechT5FeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest
input_speech = self._load_datasamples(1)
feature_extractor = SpeechT5FeatureExtractor()
input_values = feature_extractor(audio_target=input_speech, return_tensors="pt").input_values
self.assertEquals(input_values.shape, (1, 366, 80))
self.assertTrue(torch.allclose(input_values[0, 0, :30], EXPECTED_INPUT_VALUES, atol=1e-4))
@@ -198,10 +198,10 @@ class TvltFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.Tes
def test_integration(self):
input_speech = self._load_datasamples(1)
feaure_extractor = TvltFeatureExtractor()
audio_values = feaure_extractor(input_speech, return_tensors="pt").audio_values
feature_extractor = TvltFeatureExtractor()
audio_values = feature_extractor(input_speech, return_tensors="pt").audio_values
self.assertTrue(audio_values.shape, [1, 1, 192, 128])
self.assertEquals(audio_values.shape, (1, 1, 192, 128))
expected_slice = torch.tensor([[-0.3032, -0.2708], [-0.4434, -0.4007]])
self.assertTrue(torch.allclose(audio_values[0, 0, :2, :2], expected_slice, atol=1e-4))
@@ -218,8 +218,9 @@ class WhisperFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.
# fmt: on
input_speech = self._load_datasamples(1)
feaure_extractor = WhisperFeatureExtractor()
input_features = feaure_extractor(input_speech, return_tensors="pt").input_features
feature_extractor = WhisperFeatureExtractor()
input_features = feature_extractor(input_speech, return_tensors="pt").input_features
self.assertEqual(input_features.shape, (1, 80, 3000))
self.assertTrue(torch.allclose(input_features[0, 0, :30], EXPECTED_INPUT_FEATURES, atol=1e-4))
def test_zero_mean_unit_variance_normalization_trunc_np_longest(self):
......
This diff is collapsed.