Unverified Commit 7f919509 authored by Matthijs Hollemans's avatar Matthijs Hollemans Committed by GitHub
Browse files

audio_utils improvements (#21998)

* silly change to allow making a PR

* clean up doc comments

* simplify hertz_to_mel and mel_to_hertz

* fixup

* clean up power_to_db

* also add amplitude_to_db

* move functions

* clean up mel_filter_bank

* fixup

* credit librosa & torchaudio authors

* add unit tests

* tests for power_to_db and amplitude_to_db

* add mel_filter_bank tests

* rewrite STFT

* add convenience spectrogram function

* missing transpose

* fewer transposes

* add integration test to M-CTC-T

* frame length can be either window or FFT length

* rewrite stft API

* add preemphasis coefficient

* move argument

* add log option to spectrogram

* replace M-CTC-T feature extractor

* fix api thing

* replace whisper STFT

* replace whisper mel filters

* replace tvlt's stft

* allow alternate window names

* replace speecht5 stft

* fixup

* fix integration tests

* fix doc comments

* remove manual FFT length calculation

* fix docs

* go away, deprecation warnings

* combine everything into spectrogram function

* add deprecated functions back

* fixup
parent 431b04d8
...@@ -12,10 +12,9 @@ specific language governing permissions and limitations under the License. ...@@ -12,10 +12,9 @@ specific language governing permissions and limitations under the License.
# Utilities for `FeatureExtractors` # Utilities for `FeatureExtractors`
This page lists all the utility functions that can be used by the audio [`FeatureExtractor`] in order to compute special features from a raw audio using common algorithms such as *Short Time Fourier Transform* or *Mel log spectrogram*. This page lists all the utility functions that can be used by the audio [`FeatureExtractor`] in order to compute special features from a raw audio using common algorithms such as *Short Time Fourier Transform* or *log mel spectrogram*.
Most of those are only useful if you are studying the code of the audio processors in the library.
Most of those are only useful if you are studying the code of the image processors in the library.
## Audio Transformations ## Audio Transformations
...@@ -23,12 +22,14 @@ Most of those are only useful if you are studying the code of the image processo ...@@ -23,12 +22,14 @@ Most of those are only useful if you are studying the code of the image processo
[[autodoc]] audio_utils.mel_to_hertz [[autodoc]] audio_utils.mel_to_hertz
[[autodoc]] audio_utils.get_mel_filter_banks [[autodoc]] audio_utils.mel_filter_bank
[[autodoc]] audio_utils.stft [[autodoc]] audio_utils.optimal_fft_length
[[autodoc]] audio_utils.power_to_db [[autodoc]] audio_utils.window_function
[[autodoc]] audio_utils.fram_wave [[autodoc]] audio_utils.spectrogram
[[autodoc]] audio_utils.power_to_db
[[autodoc]] audio_utils.amplitude_to_db
# coding=utf-8 # coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. # Copyright 2023 The HuggingFace Inc. team and the librosa & torchaudio authors.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -13,66 +13,61 @@ ...@@ -13,66 +13,61 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
""" """
Audio processing functions to extract feature from a raw audio. Should all be in numpy to support all frameworks, and Audio processing functions to extract features from audio waveforms. This code is pure numpy to support all frameworks
remmove unecessary dependencies. and remove unnecessary dependencies.
""" """
import math
import warnings import warnings
from typing import Optional from typing import Optional, Union
import numpy as np import numpy as np
from numpy.fft import fft
def hertz_to_mel(freq: float, mel_scale: str = "htk") -> float: def hertz_to_mel(freq: Union[float, np.ndarray], mel_scale: str = "htk") -> Union[float, np.ndarray]:
"""Convert Hertz to Mels. """
Convert frequency from hertz to mels.
Args: Args:
freqs (`float`): freq (`float` or `np.ndarray`):
Frequencies in Hertz The frequency, or multiple frequencies, in hertz (Hz).
mel_scale (`str`, *optional*, defaults to `"htk"`): mel_scale (`str`, *optional*, defaults to `"htk"`):
Scale to use, `htk` or `slaney`. The mel frequency scale to use, `"htk"` or `"slaney"`.
Returns: Returns:
mels (`float`): `float` or `np.ndarray`: The frequencies on the mel scale.
Frequency in Mels
""" """
if mel_scale not in ["slaney", "htk"]: if mel_scale not in ["slaney", "htk"]:
raise ValueError('mel_scale should be one of "htk" or "slaney".') raise ValueError('mel_scale should be one of "htk" or "slaney".')
if mel_scale == "htk": if mel_scale == "htk":
return 2595.0 * math.log10(1.0 + (freq / 700.0)) return 2595.0 * np.log10(1.0 + (freq / 700.0))
# Fill in the linear part
frequency_min = 0.0
f_sp = 200.0 / 3
mels = (freq - frequency_min) / f_sp
# Fill in the log-scale part
min_log_hertz = 1000.0 min_log_hertz = 1000.0
min_log_mel = (min_log_hertz - frequency_min) / f_sp min_log_mel = 15.0
logstep = math.log(6.4) / 27.0 logstep = 27.0 / np.log(6.4)
mels = 3.0 * freq / 200.0
if freq >= min_log_hertz: if isinstance(freq, np.ndarray):
mels = min_log_mel + math.log(freq / min_log_hertz) / logstep log_region = freq >= min_log_hertz
mels[log_region] = min_log_mel + np.log(freq[log_region] / min_log_hertz) * logstep
elif freq >= min_log_hertz:
mels = min_log_mel + np.log(freq / min_log_hertz) * logstep
return mels return mels
def mel_to_hertz(mels: np.array, mel_scale: str = "htk") -> np.array: def mel_to_hertz(mels: Union[float, np.ndarray], mel_scale: str = "htk") -> Union[float, np.ndarray]:
"""Convert mel bin numbers to frequencies. """
Convert frequency from mels to hertz.
Args: Args:
mels (`np.array`): mels (`float` or `np.ndarray`):
Mel frequencies The frequency, or multiple frequencies, in mels.
mel_scale (`str`, *optional*, `"htk"`): mel_scale (`str`, *optional*, `"htk"`):
Scale to use: `htk` or `slaney`. The mel frequency scale to use, `"htk"` or `"slaney"`.
Returns: Returns:
freqs (`np.array`): `float` or `np.ndarray`: The frequencies in hertz.
Mels converted to Hertz
""" """
if mel_scale not in ["slaney", "htk"]: if mel_scale not in ["slaney", "htk"]:
...@@ -81,171 +76,509 @@ def mel_to_hertz(mels: np.array, mel_scale: str = "htk") -> np.array: ...@@ -81,171 +76,509 @@ def mel_to_hertz(mels: np.array, mel_scale: str = "htk") -> np.array:
if mel_scale == "htk": if mel_scale == "htk":
return 700.0 * (10.0 ** (mels / 2595.0) - 1.0) return 700.0 * (10.0 ** (mels / 2595.0) - 1.0)
# Fill in the linear scale
frequency_min = 0.0
f_sp = 200.0 / 3
freqs = frequency_min + f_sp * mels
# And now the nonlinear scale
min_log_hertz = 1000.0 min_log_hertz = 1000.0
min_log_mel = (min_log_hertz - frequency_min) / f_sp min_log_mel = 15.0
logstep = math.log(6.4) / 27.0 logstep = np.log(6.4) / 27.0
freq = 200.0 * mels / 3.0
log_t = mels >= min_log_mel if isinstance(mels, np.ndarray):
freqs[log_t] = min_log_hertz * np.exp(logstep * (mels[log_t] - min_log_mel)) log_region = mels >= min_log_mel
freq[log_region] = min_log_hertz * np.exp(logstep * (mels[log_region] - min_log_mel))
elif mels >= min_log_mel:
freq = min_log_hertz * np.exp(logstep * (mels - min_log_mel))
return freqs return freq
def _create_triangular_filterbank( def _create_triangular_filter_bank(fft_freqs: np.ndarray, filter_freqs: np.ndarray) -> np.ndarray:
all_freqs: np.array, """
f_pts: np.array, Creates a triangular filter bank.
) -> np.array:
"""Create a triangular filter bank.
Adapted from *torchaudio* and *librosa*.
Args: Args:
all_freqs (`np.array` of shape (`nb_frequency_bins`, )): fft_freqs (`np.ndarray` of shape `(num_frequency_bins,)`):
Discrete frequencies used when the STFT was computed. Discrete frequencies of the FFT bins in Hz.
f_pts (`np.array`, of shape (`nb_mel_filters`, )): filter_freqs (`np.ndarray` of shape `(num_mel_filters,)`):
Coordinates of the middle points of the triangular filters to create. Center frequencies of the triangular filters to create, in Hz.
Returns: Returns:
fb (np.array): `np.ndarray` of shape `(num_frequency_bins, num_mel_filters)`
The filter bank of size (`nb_frequency_bins`, `nb_mel_filters`).
""" """
# Adapted from Librosa filter_diff = np.diff(filter_freqs)
# calculate the difference between each filter mid point and each stft freq point in hertz slopes = np.expand_dims(filter_freqs, 0) - np.expand_dims(fft_freqs, 1)
f_diff = f_pts[1:] - f_pts[:-1] # (n_filter + 1) down_slopes = -slopes[:, :-2] / filter_diff[:-1]
slopes = np.expand_dims(f_pts, 0) - np.expand_dims(all_freqs, 1) # (nb_frequency_bins, n_filter + 2) up_slopes = slopes[:, 2:] / filter_diff[1:]
# create overlapping triangles return np.maximum(np.zeros(1), np.minimum(down_slopes, up_slopes))
zero = np.zeros(1)
down_slopes = (-1.0 * slopes[:, :-2]) / f_diff[:-1] # (nb_frequency_bins, n_filter)
up_slopes = slopes[:, 2:] / f_diff[1:] # (nb_frequency_bins, n_filter) def mel_filter_bank(
fb = np.maximum(zero, np.minimum(down_slopes, up_slopes)) num_frequency_bins: int,
num_mel_filters: int,
return fb min_frequency: float,
max_frequency: float,
sampling_rate: int,
def get_mel_filter_banks(
nb_frequency_bins: int,
nb_mel_filters: int,
frequency_min: float,
frequency_max: float,
sample_rate: int,
norm: Optional[str] = None, norm: Optional[str] = None,
mel_scale: str = "htk", mel_scale: str = "htk",
) -> np.array: ) -> np.ndarray:
""" """
Create a frequency bin conversion matrix used to obtain the Mel Spectrogram. This is called a *mel filter bank*, Creates a frequency bin conversion matrix used to obtain a mel spectrogram. This is called a *mel filter bank*, and
and various implementation exist, which differ in the number of filters, the shape of the filters, the way the various implementation exist, which differ in the number of filters, the shape of the filters, the way the filters
filters are spaced, the bandwidth of the filters, and the manner in which the spectrum is warped. The goal of these are spaced, the bandwidth of the filters, and the manner in which the spectrum is warped. The goal of these
features is to approximate the non-linear human perception of the variation in pitch with respect to the frequency. features is to approximate the non-linear human perception of the variation in pitch with respect to the frequency.
This code is heavily inspired from the *torchaudio* implementation, see
[here](https://pytorch.org/audio/stable/transforms.html) for more details. Different banks of mel filters were introduced in the literature. The following variations are supported:
- MFCC FB-20: introduced in 1980 by Davis and Mermelstein, it assumes a sampling frequency of 10 kHz and a speech
Tips: bandwidth of `[0, 4600]` Hz.
- Different banks of Mel filters were introduced in the litterature. The following variation are supported: - MFCC FB-24 HTK: from the Cambridge HMM Toolkit (HTK) (1995) uses a filter bank of 24 filters for a speech
- MFCC FB-20: introduced in 1980 by Davis and Mermelstein, it assumes a sampling frequency of 10 kHertz bandwidth of `[0, 8000]` Hz. This assumes sampling rate ≥ 16 kHz.
and a speech bandwidth of `[0, 4600]` Hertz - MFCC FB-40: from the Auditory Toolbox for MATLAB written by Slaney in 1998, assumes a sampling rate of 16 kHz and
- MFCC FB-24 HTK: from the Cambridge HMM Toolkit (HTK) (1995) uses a filter bank of 24 filters for a speech bandwidth of `[133, 6854]` Hz. This version also includes area normalization.
speech bandwidth `[0, 8000]` Hertz (sampling rate ≥ 16 kHertz). - HFCC-E FB-29 (Human Factor Cepstral Coefficients) of Skowronski and Harris (2004), assumes a sampling rate of
- MFCC FB-40: from the Auditory Toolbox for MATLAB written by Slaney in 1998, assumes a sampling rate 12.5 kHz and speech bandwidth of `[0, 6250]` Hz.
of 16 kHertz, and speech bandwidth [133, 6854] Hertz. This version also includes an area normalization.
- HFCC-E FB-29 (Human Factor Cepstral Coefficients) of Skowronski and Harris (2004), assumes sampling This code is adapted from *torchaudio* and *librosa*. Note that the default parameters of torchaudio's
rate of 12.5 kHertz and speech bandwidth [0, 6250] Hertz `melscale_fbanks` implement the `"htk"` filters while librosa uses the `"slaney"` implementation.
- The default parameters of `torchaudio`'s mel filterbanks implement the `"htk"` filers while `torchlibrosa`
uses the `"slaney"` implementation.
Args: Args:
nb_frequency_bins (`int`): num_frequency_bins (`int`):
Number of frequencies used to compute the spectrogram (should be the same as in `stft`). Number of frequencies used to compute the spectrogram (should be the same as in `stft`).
nb_mel_filters (`int`): num_mel_filters (`int`):
Number of Mel filers to generate. Number of mel filters to generate.
frequency_min (`float`): min_frequency (`float`):
Minimum frequency of interest(Hertz). Lowest frequency of interest in Hz.
frequency_max (`float`): max_frequency (`float`):
Maximum frequency of interest(Hertz). Highest frequency of interest in Hz. This should not exceed `sampling_rate / 2`.
sample_rate (`int`): sampling_rate (`int`):
Sample rate of the audio waveform. Sample rate of the audio waveform.
norm (`str`, *optional*): norm (`str`, *optional*):
If "slaney", divide the triangular Mel weights by the width of the mel band (area normalization). If `"slaney"`, divide the triangular mel weights by the width of the mel band (area normalization).
mel_scale (`str`, *optional*, defaults to `"htk"`): mel_scale (`str`, *optional*, defaults to `"htk"`):
Scale to use: `"htk"` or `"slaney"`. The mel frequency scale to use, `"htk"` or `"slaney"`.
Returns: Returns:
`np.ndarray`: Triangular filter banks (fb matrix) of shape (`nb_frequency_bins`, `nb_mel_filters`). This matrix `np.ndarray` of shape (`num_frequency_bins`, `num_mel_filters`): Triangular filter bank matrix. This is a
is a projection matrix to go from a spectrogram to a Mel Spectrogram. projection matrix to go from a spectrogram to a mel spectrogram.
""" """
if norm is not None and norm != "slaney": if norm is not None and norm != "slaney":
raise ValueError('norm must be one of None or "slaney"') raise ValueError('norm must be one of None or "slaney"')
# freqency bins # frequencies of FFT bins in Hz
all_freqs = np.linspace(0, sample_rate // 2, nb_frequency_bins) fft_freqs = np.linspace(0, sampling_rate // 2, num_frequency_bins)
# Compute mim and max frequencies in mel scale # center points of the triangular mel filters
m_min = hertz_to_mel(frequency_min, mel_scale=mel_scale) mel_min = hertz_to_mel(min_frequency, mel_scale=mel_scale)
m_max = hertz_to_mel(frequency_max, mel_scale=mel_scale) mel_max = hertz_to_mel(max_frequency, mel_scale=mel_scale)
mel_freqs = np.linspace(mel_min, mel_max, num_mel_filters + 2)
filter_freqs = mel_to_hertz(mel_freqs, mel_scale=mel_scale)
# create the centers of the triangular mel filters. mel_filters = _create_triangular_filter_bank(fft_freqs, filter_freqs)
m_pts = np.linspace(m_min, m_max, nb_mel_filters + 2)
f_pts = mel_to_hertz(m_pts, mel_scale=mel_scale)
# create the filterbank
filterbank = _create_triangular_filterbank(all_freqs, f_pts)
if norm is not None and norm == "slaney": if norm is not None and norm == "slaney":
# Slaney-style mel is scaled to be approx constant energy per channel # Slaney-style mel is scaled to be approx constant energy per channel
enorm = 2.0 / (f_pts[2 : nb_mel_filters + 2] - f_pts[:nb_mel_filters]) enorm = 2.0 / (filter_freqs[2 : num_mel_filters + 2] - filter_freqs[:num_mel_filters])
filterbank *= np.expand_dims(enorm, 0) mel_filters *= np.expand_dims(enorm, 0)
if (filterbank.max(axis=0) == 0.0).any(): if (mel_filters.max(axis=0) == 0.0).any():
warnings.warn( warnings.warn(
"At least one mel filterbank has all zero values. " "At least one mel filter has all zero values. "
f"The value for `nb_mel_filters` ({nb_mel_filters}) may be set too high. " f"The value for `num_mel_filters` ({num_mel_filters}) may be set too high. "
f"Or, the value for `nb_frequency_bins` ({nb_frequency_bins}) may be set too low." f"Or, the value for `num_frequency_bins` ({num_frequency_bins}) may be set too low."
) )
return filterbank return mel_filters
def optimal_fft_length(window_length: int) -> int:
"""
Finds the best FFT input size for a given `window_length`. This function takes a given window length and, if not
already a power of two, rounds it up to the next power or two.
The FFT algorithm works fastest when the length of the input is a power of two, which may be larger than the size
of the window or analysis frame. For example, if the window is 400 samples, using an FFT input size of 512 samples
is more optimal than an FFT size of 400 samples. Using a larger FFT size does not affect the detected frequencies,
it simply gives a higher frequency resolution (i.e. the frequency bins are smaller).
"""
return 2 ** int(np.ceil(np.log2(window_length)))
def power_to_db(mel_spectrogram, top_db=None, a_min=1e-10, ref=1.0): def window_function(
window_length: int,
name: str = "hann",
periodic: bool = True,
frame_length: Optional[int] = None,
center: bool = True,
) -> np.ndarray:
""" """
Convert a mel spectrogram from power to db scale, this function is the numpy implementation of librosa.power_to_lb. Returns an array containing the specified window. This window is intended to be used with `stft`.
It computes `10 * log10(mel_spectrogram / ref)`, using basic log properties for stability.
Tips: The following window types are supported:
- The motivation behind applying the log function on the mel spectrogram is that humans do not hear loudness on
a - `"boxcar"`: a rectangular window
linear scale. Generally to double the percieved volume of a sound we need to put 8 times as much energy into - `"hamming"`: the Hamming window
it. - `"hann"`: the Hann window
- This means that large variations in energy may not sound all that different if the sound is loud to begin
with. This compression operation makes the mel features match more closely what humans actually hear.
Args: Args:
mel_spectrogram (`np.array`): window_length (`int`):
Input mel spectrogram. The length of the window in samples.
top_db (`int`, *optional*): name (`str`, *optional*, defaults to `"hann"`):
The maximum decibel value. The name of the window function.
a_min (`int`, *optional*, default to 1e-10): periodic (`bool`, *optional*, defaults to `True`):
Minimum value to use when cliping the mel spectrogram. Whether the window is periodic or symmetric.
ref (`float`, *optional*, default to 1.0): frame_length (`int`, *optional*):
Maximum reference value used to scale the mel_spectrogram. The length of the analysis frames in samples. Provide a value for `frame_length` if the window is smaller
than the frame length, so that it will be zero-padded.
center (`bool`, *optional*, defaults to `True`):
Whether to center the window inside the FFT buffer. Only used when `frame_length` is provided.
Returns:
`np.ndarray` of shape `(window_length,)` or `(frame_length,)` containing the window.
""" """
log_spec = 10 * np.log10(np.clip(mel_spectrogram, a_min=a_min, a_max=None)) length = window_length + 1 if periodic else window_length
log_spec -= 10.0 * np.log10(np.maximum(a_min, ref))
if top_db is not None: if name == "boxcar":
if top_db < 0: window = np.ones(length)
raise ValueError("top_db must be non-negative") elif name in ["hamming", "hamming_window"]:
log_spec = np.clip(log_spec, min=np.maximum(log_spec) - top_db, max=np.inf) window = np.hamming(length)
return log_spec elif name in ["hann", "hann_window"]:
window = np.hanning(length)
else:
raise ValueError(f"Unknown window function '{name}'")
if periodic:
window = window[:-1]
if frame_length is None:
return window
if window_length > frame_length:
raise ValueError(
f"Length of the window ({window_length}) may not be larger than frame_length ({frame_length})"
)
padded_window = np.zeros(frame_length)
offset = (frame_length - window_length) // 2 if center else 0
padded_window[offset : offset + window_length] = window
return padded_window
# TODO This method does not support batching yet as we are mainly focused on inference.
def spectrogram(
waveform: np.ndarray,
window: np.ndarray,
frame_length: int,
hop_length: int,
fft_length: Optional[int] = None,
power: Optional[float] = 1.0,
center: bool = True,
pad_mode: str = "reflect",
onesided: bool = True,
preemphasis: Optional[float] = None,
mel_filters: Optional[np.ndarray] = None,
mel_floor: float = 1e-10,
log_mel: Optional[str] = None,
reference: float = 1.0,
min_value: float = 1e-10,
db_range: Optional[float] = None,
dtype: np.dtype = np.float32,
) -> np.ndarray:
"""
Calculates a spectrogram over one waveform using the Short-Time Fourier Transform.
This function can create the following kinds of spectrograms:
- amplitude spectrogram (`power = 1.0`)
- power spectrogram (`power = 2.0`)
- complex-valued spectrogram (`power = None`)
- log spectrogram (use `log_mel` argument)
- mel spectrogram (provide `mel_filters`)
- log-mel spectrogram (provide `mel_filters` and `log_mel`)
How this works:
1. The input waveform is split into frames of size `frame_length` that are partially overlapping by `frame_length
- hop_length` samples.
2. Each frame is multiplied by the window and placed into a buffer of size `fft_length`.
3. The DFT is taken of each windowed frame.
4. The results are stacked into a spectrogram.
We make a distinction between the following "blocks" of sample data, each of which may have a different lengths:
- The analysis frame. This is the size of the time slices that the input waveform is split into.
- The window. Each analysis frame is multiplied by the window to avoid spectral leakage.
- The FFT input buffer. The length of this determines how many frequency bins are in the spectrogram.
In this implementation, the window is assumed to be zero-padded to have the same size as the analysis frame. A
padded window can be obtained from `window_function()`. The FFT input buffer may be larger than the analysis frame,
typically the next power of two.
Note: This function is not optimized for speed yet. It should be mostly compatible with `librosa.stft` and
`torchaudio.functional.transforms.Spectrogram`, although it is more flexible due to the different ways spectrograms
can be constructed.
Args:
waveform (`np.ndarray` of shape `(length,)`):
The input waveform. This must be a single real-valued, mono waveform.
window (`np.ndarray` of shape `(frame_length,)`):
The windowing function to apply, including zero-padding if necessary. The actual window length may be
shorter than `frame_length`, but we're assuming the array has already been zero-padded.
frame_length (`int`):
The length of the analysis frames in samples. With librosa this is always equal to `fft_length` but we also
allow smaller sizes.
hop_length (`int`):
The stride between successive analysis frames in samples.
fft_length (`int`, *optional*):
The size of the FFT buffer in samples. This determines how many frequency bins the spectrogram will have.
For optimal speed, this should be a power of two. If `None`, uses `frame_length`.
power (`float`, *optional*, defaults to 1.0):
If 1.0, returns the amplitude spectrogram. If 2.0, returns the power spectrogram. If `None`, returns
complex numbers.
center (`bool`, *optional*, defaults to `True`):
Whether to pad the waveform so that frame `t` is centered around time `t * hop_length`. If `False`, frame
`t` will start at time `t * hop_length`.
pad_mode (`str`, *optional*, defaults to `"reflect"`):
Padding mode used when `center` is `True`. Possible values are: `"constant"` (pad with zeros), `"edge"`
(pad with edge values), `"reflect"` (pads with mirrored values).
onesided (`bool`, *optional*, defaults to `True`):
If True, only computes the positive frequencies and returns a spectrogram containing `fft_length // 2 + 1`
frequency bins. If False, also computes the negative frequencies and returns `fft_length` frequency bins.
preemphasis (`float`, *optional*)
Coefficient for a low-pass filter that applies pre-emphasis before the DFT.
mel_filters (`np.ndarray` of shape `(num_freq_bins, num_mel_filters)`, *optional*):
The mel filter bank. If supplied, applies a this filter bank to create a mel spectrogram.
mel_floor (`float`, *optional*, defaults to 1e-10):
Minimum value of mel frequency banks.
log_mel (`str`, *optional*):
How to convert the spectrogram to log scale. Possible options are: `None` (don't convert), `"log"` (take
the natural logarithm) `"log10"` (take the base-10 logarithm), `"dB"` (convert to decibels). Can only be
used when `power` is not `None`.
reference (`float`, *optional*, defaults to 1.0):
Sets the input spectrogram value that corresponds to 0 dB. For example, use `np.max(spectrogram)` to set
the loudest part to 0 dB. Must be greater than zero.
min_value (`float`, *optional*, defaults to `1e-10`):
The spectrogram will be clipped to this minimum value before conversion to decibels, to avoid taking
`log(0)`. For a power spectrogram, the default of `1e-10` corresponds to a minimum of -100 dB. For an
amplitude spectrogram, the value `1e-5` corresponds to -100 dB. Must be greater than zero.
db_range (`float`, *optional*):
Sets the maximum dynamic range in decibels. For example, if `db_range = 80`, the difference between the
peak value and the smallest value will never be more than 80 dB. Must be greater than zero.
dtype (`np.dtype`, *optional*, defaults to `np.float32`):
Data type of the spectrogram tensor. If `power` is None, this argument is ignored and the dtype will be
`np.complex64`.
Returns:
`nd.array` containing a spectrogram of shape `(num_frequency_bins, length)` for a regular spectrogram or shape
`(num_mel_filters, length)` for a mel spectrogram.
"""
window_length = len(window)
if fft_length is None:
fft_length = frame_length
if frame_length > fft_length:
raise ValueError(f"frame_length ({frame_length}) may not be larger than fft_length ({fft_length})")
if window_length != frame_length:
raise ValueError(f"Length of the window ({window_length}) must equal frame_length ({frame_length})")
if hop_length <= 0:
raise ValueError("hop_length must be greater than zero")
if waveform.ndim != 1:
raise ValueError(f"Input waveform must have only one dimension, shape is {waveform.shape}")
if np.iscomplexobj(waveform):
raise ValueError("Complex-valued input waveforms are not currently supported")
# center pad the waveform
if center:
padding = [(int(frame_length // 2), int(frame_length // 2))]
waveform = np.pad(waveform, padding, mode=pad_mode)
# promote to float64, since np.fft uses float64 internally
waveform = waveform.astype(np.float64)
window = window.astype(np.float64)
# split waveform into frames of frame_length size
num_frames = int(1 + np.floor((waveform.size - frame_length) / hop_length))
num_frequency_bins = (fft_length // 2) + 1 if onesided else fft_length
spectrogram = np.empty((num_frames, num_frequency_bins), dtype=np.complex64)
# rfft is faster than fft
fft_func = np.fft.rfft if onesided else np.fft.fft
buffer = np.zeros(fft_length)
timestep = 0
for frame_idx in range(num_frames):
buffer[:frame_length] = waveform[timestep : timestep + frame_length]
if preemphasis is not None:
buffer[1:frame_length] -= preemphasis * buffer[: frame_length - 1]
buffer[0] *= 1 - preemphasis
buffer[:frame_length] *= window
spectrogram[frame_idx] = fft_func(buffer)
timestep += hop_length
# note: ** is much faster than np.power
if power is not None:
spectrogram = np.abs(spectrogram, dtype=np.float64) ** power
spectrogram = spectrogram.T
if mel_filters is not None:
spectrogram = np.maximum(mel_floor, np.dot(mel_filters.T, spectrogram))
if power is not None and log_mel is not None:
if log_mel == "log":
spectrogram = np.log(spectrogram)
elif log_mel == "log10":
spectrogram = np.log10(spectrogram)
elif log_mel == "dB":
if power == 1.0:
spectrogram = amplitude_to_db(spectrogram, reference, min_value, db_range)
elif power == 2.0:
spectrogram = power_to_db(spectrogram, reference, min_value, db_range)
else:
raise ValueError(f"Cannot use log_mel option '{log_mel}' with power {power}")
else:
raise ValueError(f"Unknown log_mel option: {log_mel}")
spectrogram = np.asarray(spectrogram, dtype)
return spectrogram
def power_to_db(
spectrogram: np.ndarray,
reference: float = 1.0,
min_value: float = 1e-10,
db_range: Optional[float] = None,
) -> np.ndarray:
"""
Converts a power spectrogram to the decibel scale. This computes `10 * log10(spectrogram / reference)`, using basic
logarithm properties for numerical stability.
The motivation behind applying the log function on the (mel) spectrogram is that humans do not hear loudness on a
linear scale. Generally to double the perceived volume of a sound we need to put 8 times as much energy into it.
This means that large variations in energy may not sound all that different if the sound is loud to begin with.
This compression operation makes the (mel) spectrogram features match more closely what humans actually hear.
Based on the implementation of `librosa.power_to_db`.
Args:
spectrogram (`np.ndarray`):
The input power (mel) spectrogram. Note that a power spectrogram has the amplitudes squared!
reference (`float`, *optional*, defaults to 1.0):
Sets the input spectrogram value that corresponds to 0 dB. For example, use `np.max(spectrogram)` to set
the loudest part to 0 dB. Must be greater than zero.
min_value (`float`, *optional*, defaults to `1e-10`):
The spectrogram will be clipped to this minimum value before conversion to decibels, to avoid taking
`log(0)`. The default of `1e-10` corresponds to a minimum of -100 dB. Must be greater than zero.
db_range (`float`, *optional*):
Sets the maximum dynamic range in decibels. For example, if `db_range = 80`, the difference between the
peak value and the smallest value will never be more than 80 dB. Must be greater than zero.
Returns:
`np.ndarray`: the spectrogram in decibels
"""
if reference <= 0.0:
raise ValueError("reference must be greater than zero")
if min_value <= 0.0:
raise ValueError("min_value must be greater than zero")
reference = max(min_value, reference)
spectrogram = np.clip(spectrogram, a_min=min_value, a_max=None)
spectrogram = 10.0 * (np.log10(spectrogram) - np.log10(reference))
if db_range is not None:
if db_range <= 0.0:
raise ValueError("db_range must be greater than zero")
spectrogram = np.clip(spectrogram, a_min=spectrogram.max() - db_range, a_max=None)
return spectrogram
def amplitude_to_db(
spectrogram: np.ndarray,
reference: float = 1.0,
min_value: float = 1e-5,
db_range: Optional[float] = None,
) -> np.ndarray:
"""
Converts an amplitude spectrogram to the decibel scale. This computes `20 * log10(spectrogram / reference)`, using
basic logarithm properties for numerical stability.
The motivation behind applying the log function on the (mel) spectrogram is that humans do not hear loudness on a
linear scale. Generally to double the perceived volume of a sound we need to put 8 times as much energy into it.
This means that large variations in energy may not sound all that different if the sound is loud to begin with.
This compression operation makes the (mel) spectrogram features match more closely what humans actually hear.
Args:
spectrogram (`np.ndarray`):
The input amplitude (mel) spectrogram.
reference (`float`, *optional*, defaults to 1.0):
Sets the input spectrogram value that corresponds to 0 dB. For example, use `np.max(spectrogram)` to set
the loudest part to 0 dB. Must be greater than zero.
min_value (`float`, *optional*, defaults to `1e-5`):
The spectrogram will be clipped to this minimum value before conversion to decibels, to avoid taking
`log(0)`. The default of `1e-5` corresponds to a minimum of -100 dB. Must be greater than zero.
db_range (`float`, *optional*):
Sets the maximum dynamic range in decibels. For example, if `db_range = 80`, the difference between the
peak value and the smallest value will never be more than 80 dB. Must be greater than zero.
Returns:
`np.ndarray`: the spectrogram in decibels
"""
if reference <= 0.0:
raise ValueError("reference must be greater than zero")
if min_value <= 0.0:
raise ValueError("min_value must be greater than zero")
reference = max(min_value, reference)
spectrogram = np.clip(spectrogram, a_min=min_value, a_max=None)
spectrogram = 20.0 * (np.log10(spectrogram) - np.log10(reference))
if db_range is not None:
if db_range <= 0.0:
raise ValueError("db_range must be greater than zero")
spectrogram = np.clip(spectrogram, a_min=spectrogram.max() - db_range, a_max=None)
return spectrogram
### deprecated functions below this line ###
def get_mel_filter_banks(
nb_frequency_bins: int,
nb_mel_filters: int,
frequency_min: float,
frequency_max: float,
sample_rate: int,
norm: Optional[str] = None,
mel_scale: str = "htk",
) -> np.array:
warnings.warn(
"The function `get_mel_filter_banks` is deprecated and will be removed in version 4.31.0 of Transformers",
FutureWarning,
)
return mel_filter_bank(
num_frequency_bins=nb_frequency_bins,
num_mel_filters=nb_mel_filters,
min_frequency=frequency_min,
max_frequency=frequency_max,
sampling_rate=sample_rate,
norm=norm,
mel_scale=mel_scale,
)
# TODO @ArthurZucker: This method does not support batching yet as we are mainly focus on inference.
def fram_wave(waveform: np.array, hop_length: int = 160, fft_window_size: int = 400, center: bool = True): def fram_wave(waveform: np.array, hop_length: int = 160, fft_window_size: int = 400, center: bool = True):
""" """
In order to compute the short time fourier transform, the waveform needs to be split in overlapping windowed In order to compute the short time fourier transform, the waveform needs to be split in overlapping windowed
...@@ -270,6 +603,10 @@ def fram_wave(waveform: np.array, hop_length: int = 160, fft_window_size: int = ...@@ -270,6 +603,10 @@ def fram_wave(waveform: np.array, hop_length: int = 160, fft_window_size: int =
framed_waveform (`np.array` of shape `(waveform.shape // hop_length , fft_window_size)`): framed_waveform (`np.array` of shape `(waveform.shape // hop_length , fft_window_size)`):
The framed waveforms that can be fed to `np.fft`. The framed waveforms that can be fed to `np.fft`.
""" """
warnings.warn(
"The function `fram_wave` is deprecated and will be removed in version 4.31.0 of Transformers",
FutureWarning,
)
frames = [] frames = []
for i in range(0, waveform.shape[0] + 1, hop_length): for i in range(0, waveform.shape[0] + 1, hop_length):
if center: if center:
...@@ -298,9 +635,6 @@ def fram_wave(waveform: np.array, hop_length: int = 160, fft_window_size: int = ...@@ -298,9 +635,6 @@ def fram_wave(waveform: np.array, hop_length: int = 160, fft_window_size: int =
return frames return frames
# TODO @ArthurZucker: This method does not support batching yet as we are mainly focus on inference.
def stft(frames: np.array, windowing_function: np.array, fft_window_size: int = None): def stft(frames: np.array, windowing_function: np.array, fft_window_size: int = None):
""" """
Calculates the complex Short-Time Fourier Transform (STFT) of the given framed signal. Should give the same results Calculates the complex Short-Time Fourier Transform (STFT) of the given framed signal. Should give the same results
...@@ -337,6 +671,10 @@ def stft(frames: np.array, windowing_function: np.array, fft_window_size: int = ...@@ -337,6 +671,10 @@ def stft(frames: np.array, windowing_function: np.array, fft_window_size: int =
spectrogram (`np.ndarray`): spectrogram (`np.ndarray`):
A spectrogram of shape `(num_frames, nb_frequency_bins)` obtained using the STFT algorithm A spectrogram of shape `(num_frames, nb_frequency_bins)` obtained using the STFT algorithm
""" """
warnings.warn(
"The function `stft` is deprecated and will be removed in version 4.31.0 of Transformers",
FutureWarning,
)
frame_size = frames.shape[1] frame_size = frames.shape[1]
if fft_window_size is None: if fft_window_size is None:
...@@ -355,5 +693,5 @@ def stft(frames: np.array, windowing_function: np.array, fft_window_size: int = ...@@ -355,5 +693,5 @@ def stft(frames: np.array, windowing_function: np.array, fft_window_size: int =
np.multiply(frame, windowing_function, out=fft_signal[:frame_size]) np.multiply(frame, windowing_function, out=fft_signal[:frame_size])
else: else:
fft_signal[:frame_size] = frame fft_signal[:frame_size] = frame
spectrogram[f] = fft(fft_signal, axis=0)[:nb_frequency_bins] spectrogram[f] = np.fft.fft(fft_signal, axis=0)[:nb_frequency_bins]
return spectrogram.T return spectrogram.T
...@@ -21,7 +21,7 @@ from typing import Any, Dict, List, Optional, Union ...@@ -21,7 +21,7 @@ from typing import Any, Dict, List, Optional, Union
import numpy as np import numpy as np
import torch import torch
from ...audio_utils import fram_wave, get_mel_filter_banks, power_to_db, stft from ...audio_utils import mel_filter_bank, spectrogram, window_function
from ...feature_extraction_sequence_utils import SequenceFeatureExtractor from ...feature_extraction_sequence_utils import SequenceFeatureExtractor
from ...feature_extraction_utils import BatchFeature from ...feature_extraction_utils import BatchFeature
from ...utils import TensorType, logging from ...utils import TensorType, logging
...@@ -116,21 +116,21 @@ class ClapFeatureExtractor(SequenceFeatureExtractor): ...@@ -116,21 +116,21 @@ class ClapFeatureExtractor(SequenceFeatureExtractor):
self.sampling_rate = sampling_rate self.sampling_rate = sampling_rate
self.frequency_min = frequency_min self.frequency_min = frequency_min
self.frequency_max = frequency_max self.frequency_max = frequency_max
self.mel_filters = get_mel_filter_banks( self.mel_filters = mel_filter_bank(
nb_frequency_bins=self.nb_frequency_bins, num_frequency_bins=self.nb_frequency_bins,
nb_mel_filters=feature_size, num_mel_filters=feature_size,
frequency_min=frequency_min, min_frequency=frequency_min,
frequency_max=frequency_max, max_frequency=frequency_max,
sample_rate=sampling_rate, sampling_rate=sampling_rate,
norm=None, norm=None,
mel_scale="htk", mel_scale="htk",
) )
self.mel_filters_slaney = get_mel_filter_banks( self.mel_filters_slaney = mel_filter_bank(
nb_frequency_bins=self.nb_frequency_bins, num_frequency_bins=self.nb_frequency_bins,
nb_mel_filters=feature_size, num_mel_filters=feature_size,
frequency_min=frequency_min, min_frequency=frequency_min,
frequency_max=frequency_max, max_frequency=frequency_max,
sample_rate=sampling_rate, sampling_rate=sampling_rate,
norm="slaney", norm="slaney",
mel_scale="slaney", mel_scale="slaney",
) )
...@@ -153,24 +153,25 @@ class ClapFeatureExtractor(SequenceFeatureExtractor): ...@@ -153,24 +153,25 @@ class ClapFeatureExtractor(SequenceFeatureExtractor):
def _np_extract_fbank_features(self, waveform: np.array, mel_filters: Optional[np.array] = None) -> np.ndarray: def _np_extract_fbank_features(self, waveform: np.array, mel_filters: Optional[np.array] = None) -> np.ndarray:
""" """
Compute the log-Mel spectrogram of the provided `waveform` using the `hanning` window. In CLAP, two different Compute the log-mel spectrogram of the provided `waveform` using the Hann window. In CLAP, two different filter
filter banks are used depending on the truncation pattern: banks are used depending on the truncation pattern:
- `self.mel_filters`: they correspond to the defaults parameters of `torchaduio` which can be obtained from - `self.mel_filters`: they correspond to the default parameters of `torchaudio` which can be obtained from
calling `torchaudio.transforms.MelSpectrogram().mel_scale.fb`. These filters are used when `truncation` calling `torchaudio.transforms.MelSpectrogram().mel_scale.fb`. These filters are used when `truncation`
is set to `"fusion"`. is set to `"fusion"`.
- `self.mel_filteres_slaney` : they correspond to the defaults parameters of `torchlibrosa` which used - `self.mel_filteres_slaney` : they correspond to the default parameters of `librosa` which used
`librosa.filters.mel` when computing the mel spectrogram. These filters were only used in the original `librosa.filters.mel` when computing the mel spectrogram. These filters were only used in the original
implementation when the truncation mode is not `"fusion"`. implementation when the truncation mode is not `"fusion"`.
""" """
window = np.hanning(self.fft_window_size + 1)[:-1] log_mel_spectrogram = spectrogram(
frames = fram_wave(waveform, self.hop_length, self.fft_window_size) waveform,
spectrogram = stft(frames, window, fft_window_size=self.fft_window_size) window_function(self.fft_window_size, "hann"),
frame_length=self.fft_window_size,
magnitudes = np.abs(spectrogram) ** 2 hop_length=self.hop_length,
mel_spectrogram = np.matmul(mel_filters.T, magnitudes) power=2.0,
log_mel_spectrogram = power_to_db(mel_spectrogram).T mel_filters=mel_filters,
log_mel_spectrogram = np.asarray(log_mel_spectrogram, np.float32) log_mel="dB",
return log_mel_spectrogram )
return log_mel_spectrogram.T
def _random_mel_fusion(self, mel, total_frames, chunk_frames): def _random_mel_fusion(self, mel, total_frames, chunk_frames):
ranges = np.array_split(list(range(0, total_frames - chunk_frames + 1)), 3) ranges = np.array_split(list(range(0, total_frames - chunk_frames + 1)), 3)
......
...@@ -20,9 +20,8 @@ from typing import List, Optional, Union ...@@ -20,9 +20,8 @@ from typing import List, Optional, Union
import numpy as np import numpy as np
import torch import torch
import torchaudio
from packaging import version
from ...audio_utils import mel_filter_bank, optimal_fft_length, spectrogram
from ...feature_extraction_sequence_utils import SequenceFeatureExtractor from ...feature_extraction_sequence_utils import SequenceFeatureExtractor
from ...feature_extraction_utils import BatchFeature from ...feature_extraction_utils import BatchFeature
from ...file_utils import PaddingStrategy, TensorType from ...file_utils import PaddingStrategy, TensorType
...@@ -31,13 +30,6 @@ from ...utils import logging ...@@ -31,13 +30,6 @@ from ...utils import logging
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
parsed_torchaudio_version_base = version.parse(version.parse(torchaudio.__version__).base_version)
if not parsed_torchaudio_version_base >= version.parse("0.10"):
logger.warning(
f"You are using torchaudio=={torchaudio.__version__}, but torchaudio>=0.10.0 is required to use "
"MCTCTFeatureExtractor. This requires torch>=1.10.0. Please upgrade torch and torchaudio."
)
class MCTCTFeatureExtractor(SequenceFeatureExtractor): class MCTCTFeatureExtractor(SequenceFeatureExtractor):
r""" r"""
...@@ -110,68 +102,9 @@ class MCTCTFeatureExtractor(SequenceFeatureExtractor): ...@@ -110,68 +102,9 @@ class MCTCTFeatureExtractor(SequenceFeatureExtractor):
self.sample_size = win_length * sampling_rate // 1000 self.sample_size = win_length * sampling_rate // 1000
self.sample_stride = hop_length * sampling_rate // 1000 self.sample_stride = hop_length * sampling_rate // 1000
self.n_fft = 2 ** int(np.ceil(np.log2(self.sample_size))) self.n_fft = optimal_fft_length(self.sample_size)
self.n_freqs = (self.n_fft // 2) + 1 self.n_freqs = (self.n_fft // 2) + 1
@staticmethod
def _num_frames_calc(in_size, frame_size, frame_stride):
return int(1 + np.floor((in_size - frame_size) * 1 / frame_stride))
@staticmethod
def _frame_signal(one_waveform, n_frames, frame_signal_scale, window_length, sample_stride):
scale = frame_signal_scale
frames = np.zeros(n_frames * window_length)
for frame_idx in range(n_frames):
start = frame_idx * window_length
end = (frame_idx + 1) * window_length
wave_start = frame_idx * sample_stride
wave_end = frame_idx * sample_stride + window_length
frames[start:end] = scale * one_waveform[wave_start:wave_end]
return frames
@staticmethod
def _apply_preemphasis_inplace(frames, window_length, preemphasis_coeff):
if frames.size % window_length != 0:
raise ValueError(
f"`frames` is supposed to have length divisble by `window_length`, but is {frames.size} with"
f" window_length={window_length}."
)
n_frames = frames.size // window_length
for frame_idx in range(n_frames, 0, -1):
start = (frame_idx - 1) * window_length
end = frame_idx * window_length - 1
frames[start + 1 : end + 1] -= preemphasis_coeff * frames[start:end]
frames[start] *= 1 - preemphasis_coeff
@staticmethod
def _windowing(frames, window_length, window):
if frames.size % window_length != 0:
raise ValueError(
f"`frames` is supposed to have length divisble by `window_length`, but is {frames.size} with"
f" window_length={window_length}."
)
shaped = frames.reshape(-1, window_length)
shaped = window * shaped
return shaped
@staticmethod
def _dft(frames, K, n_frames, n_samples, n_fft):
dft = np.zeros([n_frames, K])
for frame in range(n_frames):
begin = frame * n_samples
inwards_buffer = frames[begin : begin + n_samples]
inwards_buffer = np.pad(inwards_buffer, (0, n_fft - n_samples), "constant")
out = np.fft.rfft(inwards_buffer)
dft[frame] = np.abs(out[:K])
return dft
def _extract_mfsc_features(self, one_waveform: np.array) -> np.ndarray: def _extract_mfsc_features(self, one_waveform: np.array) -> np.ndarray:
""" """
Extracts MFSC Features for one waveform vector (unbatched). Adapted from Flashlight's C++ MFSC code. Extracts MFSC Features for one waveform vector (unbatched). Adapted from Flashlight's C++ MFSC code.
...@@ -183,36 +116,27 @@ class MCTCTFeatureExtractor(SequenceFeatureExtractor): ...@@ -183,36 +116,27 @@ class MCTCTFeatureExtractor(SequenceFeatureExtractor):
window = window.numpy() window = window.numpy()
fbanks = torchaudio.functional.melscale_fbanks( fbanks = mel_filter_bank(
n_freqs=self.n_freqs, num_frequency_bins=self.n_freqs,
f_min=0.0, # change this to zeros num_mel_filters=self.feature_size,
f_max=self.sampling_rate / 2.0, min_frequency=0.0,
n_mels=self.feature_size, max_frequency=self.sampling_rate / 2.0,
sample_rate=self.sampling_rate, sampling_rate=self.sampling_rate,
) )
fbanks = fbanks.numpy() msfc_features = spectrogram(
one_waveform * self.frame_signal_scale,
n_frames = self._num_frames_calc(one_waveform.size, self.sample_size, self.sample_stride) window=window,
frame_length=self.sample_size,
frames = self._frame_signal( hop_length=self.sample_stride,
one_waveform, n_frames, self.frame_signal_scale, self.sample_size, self.sample_stride fft_length=self.n_fft,
center=False,
preemphasis=self.preemphasis_coeff,
mel_filters=fbanks,
mel_floor=self.mel_floor,
log_mel="log",
) )
return msfc_features.T
self._apply_preemphasis_inplace(frames, self.sample_size, self.preemphasis_coeff)
frames = self._windowing(frames, self.sample_size, window)
dft_out = self._dft(frames.flatten(), self.n_freqs, n_frames, self.sample_size, self.n_fft)
# msfc_features = STFT * mel frequency banks.
msfc_features = np.einsum("...tf,fm->...tm", dft_out, fbanks)
# clamp feature values then log scale, as implemented in flashlight
msfc_features = np.maximum(msfc_features, self.mel_floor)
msfc_features = np.log(msfc_features)
return msfc_features
def _normalize_one(self, x, input_length, padding_value): def _normalize_one(self, x, input_length, padding_value):
# make sure we normalize float32 arrays # make sure we normalize float32 arrays
......
...@@ -20,7 +20,7 @@ from typing import Any, Dict, List, Optional, Union ...@@ -20,7 +20,7 @@ from typing import Any, Dict, List, Optional, Union
import numpy as np import numpy as np
import torch import torch
from ...audio_utils import get_mel_filter_banks from ...audio_utils import mel_filter_bank, optimal_fft_length, spectrogram
from ...feature_extraction_sequence_utils import SequenceFeatureExtractor from ...feature_extraction_sequence_utils import SequenceFeatureExtractor
from ...feature_extraction_utils import BatchFeature from ...feature_extraction_utils import BatchFeature
from ...utils import PaddingStrategy, TensorType, logging from ...utils import PaddingStrategy, TensorType, logging
...@@ -110,18 +110,18 @@ class SpeechT5FeatureExtractor(SequenceFeatureExtractor): ...@@ -110,18 +110,18 @@ class SpeechT5FeatureExtractor(SequenceFeatureExtractor):
self.sample_size = win_length * sampling_rate // 1000 self.sample_size = win_length * sampling_rate // 1000
self.sample_stride = hop_length * sampling_rate // 1000 self.sample_stride = hop_length * sampling_rate // 1000
self.n_fft = 2 ** int(np.ceil(np.log2(self.sample_size))) self.n_fft = optimal_fft_length(self.sample_size)
self.n_freqs = (self.n_fft // 2) + 1 self.n_freqs = (self.n_fft // 2) + 1
window = getattr(torch, self.win_function)(window_length=self.sample_size, periodic=True) window = getattr(torch, self.win_function)(window_length=self.sample_size, periodic=True)
self.window = window.numpy().astype(np.float64) self.window = window.numpy().astype(np.float64)
self.mel_filters = get_mel_filter_banks( self.mel_filters = mel_filter_bank(
nb_frequency_bins=self.n_freqs, num_frequency_bins=self.n_freqs,
nb_mel_filters=self.num_mel_bins, num_mel_filters=self.num_mel_bins,
frequency_min=self.fmin, min_frequency=self.fmin,
frequency_max=self.fmax, max_frequency=self.fmax,
sample_rate=self.sampling_rate, sampling_rate=self.sampling_rate,
norm="slaney", norm="slaney",
mel_scale="slaney", mel_scale="slaney",
) )
...@@ -160,31 +160,6 @@ class SpeechT5FeatureExtractor(SequenceFeatureExtractor): ...@@ -160,31 +160,6 @@ class SpeechT5FeatureExtractor(SequenceFeatureExtractor):
return normed_input_values return normed_input_values
@staticmethod
def _stft(waveform: np.ndarray, fft_length: int, hop_length: int, window: np.ndarray) -> np.ndarray:
"""
Calculates the magnitude spectrogram over one waveform array.
"""
# center pad the waveform
padding = [(int(fft_length // 2), int(fft_length // 2))]
waveform = np.pad(waveform, padding, mode="reflect")
waveform_size = waveform.size
# promote to float64, since np.fft uses float64 internally
waveform = waveform.astype(np.float64)
num_frames = int(1 + np.floor((waveform_size - fft_length) / hop_length))
num_frequency_bins = (fft_length // 2) + 1
spectrogram = np.empty((num_frames, num_frequency_bins))
start = 0
for frame_idx in range(num_frames):
frame = waveform[start : start + fft_length] * window
spectrogram[frame_idx] = np.abs(np.fft.rfft(frame))
start += hop_length
return spectrogram
def _extract_mel_features( def _extract_mel_features(
self, self,
one_waveform: np.ndarray, one_waveform: np.ndarray,
...@@ -192,14 +167,17 @@ class SpeechT5FeatureExtractor(SequenceFeatureExtractor): ...@@ -192,14 +167,17 @@ class SpeechT5FeatureExtractor(SequenceFeatureExtractor):
""" """
Extracts log-mel filterbank features for one waveform array (unbatched). Extracts log-mel filterbank features for one waveform array (unbatched).
""" """
if self.n_fft != self.sample_size: log_mel_spec = spectrogram(
raise NotImplementedError( one_waveform,
f"Currently the STFT frame size must be a power of two, but got {self.sample_size} for a window length of {self.win_length} and sampling rate of {self.sampling_rate}. Ensure `win_length * sampling_rate // 1000` is divisible by two." window=self.window,
frame_length=self.sample_size,
hop_length=self.sample_stride,
fft_length=self.n_fft,
mel_filters=self.mel_filters,
mel_floor=self.mel_floor,
log_mel="log10",
) )
return log_mel_spec.T
stft_out = self._stft(one_waveform, self.n_fft, self.sample_stride, self.window)
return np.log10(np.maximum(self.mel_floor, np.dot(stft_out, self.mel_filters)))
def __call__( def __call__(
self, self,
......
...@@ -18,8 +18,8 @@ from math import ceil ...@@ -18,8 +18,8 @@ from math import ceil
from typing import List, Optional, Union from typing import List, Optional, Union
import numpy as np import numpy as np
from numpy.fft import fft
from ...audio_utils import mel_filter_bank, spectrogram, window_function
from ...feature_extraction_sequence_utils import BatchFeature, SequenceFeatureExtractor from ...feature_extraction_sequence_utils import BatchFeature, SequenceFeatureExtractor
from ...utils import TensorType, logging from ...utils import TensorType, logging
...@@ -83,143 +83,34 @@ class TvltFeatureExtractor(SequenceFeatureExtractor): ...@@ -83,143 +83,34 @@ class TvltFeatureExtractor(SequenceFeatureExtractor):
self.hop_length = sampling_rate // hop_length_to_sampling_rate self.hop_length = sampling_rate // hop_length_to_sampling_rate
self.sampling_rate = sampling_rate self.sampling_rate = sampling_rate
self.padding_value = padding_value self.padding_value = padding_value
self.mel_filters = self.get_mel_filters(sampling_rate, n_fft, n_mels=feature_size) self.mel_filters = mel_filter_bank(
num_frequency_bins=1 + n_fft // 2,
# Copied from transformers.models.whisper.feature_extraction_whisper.WhisperFeatureExtractor.get_mel_filters with 45.245640471924965->59.99247463746737 num_mel_filters=feature_size,
def get_mel_filters(self, sr, n_fft, n_mels=128, dtype=np.float32): min_frequency=0.0,
# Initialize the weights max_frequency=22050.0,
n_mels = int(n_mels) sampling_rate=sampling_rate,
weights = np.zeros((n_mels, int(1 + n_fft // 2)), dtype=dtype) norm="slaney",
mel_scale="slaney",
# Center freqs of each FFT bin ).T
fftfreqs = np.fft.rfftfreq(n=n_fft, d=1.0 / sr)
# 'Center freqs' of mel bands - uniformly spaced between limits
min_mel = 0.0
max_mel = 59.99247463746737
mels = np.linspace(min_mel, max_mel, n_mels + 2)
mels = np.asanyarray(mels)
# Fill in the linear scale
f_min = 0.0
f_sp = 200.0 / 3
freqs = f_min + f_sp * mels
# And now the nonlinear scale
min_log_hz = 1000.0 # beginning of log region (Hz)
min_log_mel = (min_log_hz - f_min) / f_sp # same (Mels)
logstep = np.log(6.4) / 27.0 # step size for log region
# If we have vector data, vectorize
log_t = mels >= min_log_mel
freqs[log_t] = min_log_hz * np.exp(logstep * (mels[log_t] - min_log_mel))
mel_f = freqs
fdiff = np.diff(mel_f)
ramps = np.subtract.outer(mel_f, fftfreqs)
for i in range(n_mels):
# lower and upper slopes for all bins
lower = -ramps[i] / fdiff[i]
upper = ramps[i + 2] / fdiff[i + 1]
# .. then intersect them with each other and zero
weights[i] = np.maximum(0, np.minimum(lower, upper))
# Slaney-style mel is scaled to be approx constant energy per channel
enorm = 2.0 / (mel_f[2 : n_mels + 2] - mel_f[:n_mels])
weights *= enorm[:, np.newaxis]
return weights
# Copied from transformers.models.whisper.feature_extraction_whisper.WhisperFeatureExtractor.fram_wave
def fram_wave(self, waveform, center=True):
"""
Transform a raw waveform into a list of smaller waveforms. The window length defines how much of the signal is
contain in each frame (smalle waveform), while the hope length defines the step between the beginning of each
new frame.
Centering is done by reflecting the waveform which is first centered around `frame_idx * hop_length`.
"""
frames = []
for i in range(0, waveform.shape[0] + 1, self.hop_length):
half_window = (self.n_fft - 1) // 2 + 1
if center:
start = i - half_window if i > half_window else 0
end = i + half_window if i < waveform.shape[0] - half_window else waveform.shape[0]
frame = waveform[start:end]
if start == 0:
padd_width = (-i + half_window, 0)
frame = np.pad(frame, pad_width=padd_width, mode="reflect")
elif end == waveform.shape[0]:
padd_width = (0, (i - waveform.shape[0] + half_window))
frame = np.pad(frame, pad_width=padd_width, mode="reflect")
else:
frame = waveform[i : i + self.n_fft]
frame_width = frame.shape[0]
if frame_width < waveform.shape[0]:
frame = np.lib.pad(
frame, pad_width=(0, self.n_fft - frame_width), mode="constant", constant_values=0
)
frames.append(frame)
return np.stack(frames, 0)
# Copied from transformers.models.whisper.feature_extraction_whisper.WhisperFeatureExtractor.stft
def stft(self, frames, window):
"""
Calculates the complex Short-Time Fourier Transform (STFT) of the given framed signal. Should give the same
results as `torch.stft`.
"""
frame_size = frames.shape[1]
fft_size = self.n_fft
if fft_size is None:
fft_size = frame_size
if fft_size < frame_size:
raise ValueError("FFT size must greater or equal the frame size")
# number of FFT bins to store
num_fft_bins = (fft_size >> 1) + 1
data = np.empty((len(frames), num_fft_bins), dtype=np.complex64)
fft_signal = np.zeros(fft_size)
for f, frame in enumerate(frames):
if window is not None:
np.multiply(frame, window, out=fft_signal[:frame_size])
else:
fft_signal[:frame_size] = frame
data[f] = fft(fft_signal, axis=0)[:num_fft_bins]
return data.T
def _np_extract_fbank_features(self, waveform: np.array) -> np.ndarray: def _np_extract_fbank_features(self, waveform: np.array) -> np.ndarray:
""" """
Compute the log-Mel spectrogram of the provided audio, gives similar results whisper's original torch Compute the log-mel spectrogram of the provided audio, gives similar results to Whisper's original torch
implementation with 1e-5 tolerance. implementation with 1e-5 tolerance.
""" """
window = np.hanning(self.n_fft + 1)[:-1] log_spec = spectrogram(
waveform,
frames = self.fram_wave(waveform) window_function(self.n_fft, "hann"),
stft = self.stft(frames, window=window) frame_length=self.n_fft,
magnitudes = np.abs(stft[:, :-1]) ** 2 hop_length=self.hop_length,
power=2.0,
filters = self.mel_filters mel_filters=self.mel_filters.T,
mel_spec = filters @ magnitudes log_mel="dB",
db_range=80.0,
log_spec = 10.0 * np.log10(np.maximum(1e-10, mel_spec)) )
log_spec -= 10.0 * np.log10(np.maximum(1e-10, 1.0)) log_spec = log_spec[:, :-1]
log_spec = np.maximum(log_spec, log_spec.max() - 80.0)
log_spec = log_spec - 20.0 log_spec = log_spec - 20.0
log_spec = np.clip(log_spec / 40.0, -2.0, 0.0) + 1.0 log_spec = np.clip(log_spec / 40.0, -2.0, 0.0) + 1.0
return log_spec return log_spec
def __call__( def __call__(
......
...@@ -19,8 +19,8 @@ import copy ...@@ -19,8 +19,8 @@ import copy
from typing import Any, Dict, List, Optional, Union from typing import Any, Dict, List, Optional, Union
import numpy as np import numpy as np
from numpy.fft import fft
from ...audio_utils import mel_filter_bank, spectrogram, window_function
from ...feature_extraction_sequence_utils import SequenceFeatureExtractor from ...feature_extraction_sequence_utils import SequenceFeatureExtractor
from ...feature_extraction_utils import BatchFeature from ...feature_extraction_utils import BatchFeature
from ...utils import TensorType, logging from ...utils import TensorType, logging
...@@ -81,138 +81,33 @@ class WhisperFeatureExtractor(SequenceFeatureExtractor): ...@@ -81,138 +81,33 @@ class WhisperFeatureExtractor(SequenceFeatureExtractor):
self.n_samples = chunk_length * sampling_rate self.n_samples = chunk_length * sampling_rate
self.nb_max_frames = self.n_samples // hop_length self.nb_max_frames = self.n_samples // hop_length
self.sampling_rate = sampling_rate self.sampling_rate = sampling_rate
self.mel_filters = self.get_mel_filters(sampling_rate, n_fft, n_mels=feature_size) self.mel_filters = mel_filter_bank(
num_frequency_bins=1 + n_fft // 2,
def get_mel_filters(self, sr, n_fft, n_mels=128, dtype=np.float32): num_mel_filters=feature_size,
# Initialize the weights min_frequency=0.0,
n_mels = int(n_mels) max_frequency=8000.0,
weights = np.zeros((n_mels, int(1 + n_fft // 2)), dtype=dtype) sampling_rate=sampling_rate,
norm="slaney",
# Center freqs of each FFT bin mel_scale="slaney",
fftfreqs = np.fft.rfftfreq(n=n_fft, d=1.0 / sr)
# 'Center freqs' of mel bands - uniformly spaced between limits
min_mel = 0.0
max_mel = 45.245640471924965
mels = np.linspace(min_mel, max_mel, n_mels + 2)
mels = np.asanyarray(mels)
# Fill in the linear scale
f_min = 0.0
f_sp = 200.0 / 3
freqs = f_min + f_sp * mels
# And now the nonlinear scale
min_log_hz = 1000.0 # beginning of log region (Hz)
min_log_mel = (min_log_hz - f_min) / f_sp # same (Mels)
logstep = np.log(6.4) / 27.0 # step size for log region
# If we have vector data, vectorize
log_t = mels >= min_log_mel
freqs[log_t] = min_log_hz * np.exp(logstep * (mels[log_t] - min_log_mel))
mel_f = freqs
fdiff = np.diff(mel_f)
ramps = np.subtract.outer(mel_f, fftfreqs)
for i in range(n_mels):
# lower and upper slopes for all bins
lower = -ramps[i] / fdiff[i]
upper = ramps[i + 2] / fdiff[i + 1]
# .. then intersect them with each other and zero
weights[i] = np.maximum(0, np.minimum(lower, upper))
# Slaney-style mel is scaled to be approx constant energy per channel
enorm = 2.0 / (mel_f[2 : n_mels + 2] - mel_f[:n_mels])
weights *= enorm[:, np.newaxis]
return weights
def fram_wave(self, waveform, center=True):
"""
Transform a raw waveform into a list of smaller waveforms. The window length defines how much of the signal is
contain in each frame (smalle waveform), while the hope length defines the step between the beginning of each
new frame.
Centering is done by reflecting the waveform which is first centered around `frame_idx * hop_length`.
"""
frames = []
for i in range(0, waveform.shape[0] + 1, self.hop_length):
half_window = (self.n_fft - 1) // 2 + 1
if center:
start = i - half_window if i > half_window else 0
end = i + half_window if i < waveform.shape[0] - half_window else waveform.shape[0]
frame = waveform[start:end]
if start == 0:
padd_width = (-i + half_window, 0)
frame = np.pad(frame, pad_width=padd_width, mode="reflect")
elif end == waveform.shape[0]:
padd_width = (0, (i - waveform.shape[0] + half_window))
frame = np.pad(frame, pad_width=padd_width, mode="reflect")
else:
frame = waveform[i : i + self.n_fft]
frame_width = frame.shape[0]
if frame_width < waveform.shape[0]:
frame = np.lib.pad(
frame, pad_width=(0, self.n_fft - frame_width), mode="constant", constant_values=0
) )
frames.append(frame)
return np.stack(frames, 0)
def stft(self, frames, window):
"""
Calculates the complex Short-Time Fourier Transform (STFT) of the given framed signal. Should give the same
results as `torch.stft`.
"""
frame_size = frames.shape[1]
fft_size = self.n_fft
if fft_size is None:
fft_size = frame_size
if fft_size < frame_size:
raise ValueError("FFT size must greater or equal the frame size")
# number of FFT bins to store
num_fft_bins = (fft_size >> 1) + 1
data = np.empty((len(frames), num_fft_bins), dtype=np.complex64)
fft_signal = np.zeros(fft_size)
for f, frame in enumerate(frames):
if window is not None:
np.multiply(frame, window, out=fft_signal[:frame_size])
else:
fft_signal[:frame_size] = frame
data[f] = fft(fft_signal, axis=0)[:num_fft_bins]
return data.T
def _np_extract_fbank_features(self, waveform: np.array) -> np.ndarray: def _np_extract_fbank_features(self, waveform: np.array) -> np.ndarray:
""" """
Compute the log-Mel spectrogram of the provided audio, gives similar results whisper's original torch Compute the log-mel spectrogram of the provided audio, gives similar results to Whisper's original torch
implementation with 1e-5 tolerance. implementation with 1e-5 tolerance.
""" """
window = np.hanning(self.n_fft + 1)[:-1] log_spec = spectrogram(
waveform,
frames = self.fram_wave(waveform) window_function(self.n_fft, "hann"),
stft = self.stft(frames, window=window) frame_length=self.n_fft,
magnitudes = np.abs(stft[:, :-1]) ** 2 hop_length=self.hop_length,
power=2.0,
filters = self.mel_filters mel_filters=self.mel_filters,
mel_spec = filters @ magnitudes log_mel="log10",
)
log_spec = np.log10(np.clip(mel_spec, a_min=1e-10, a_max=None)) log_spec = log_spec[:, :-1]
log_spec = np.maximum(log_spec, log_spec.max() - 8.0) log_spec = np.maximum(log_spec, log_spec.max() - 8.0)
log_spec = (log_spec + 4.0) / 4.0 log_spec = (log_spec + 4.0) / 4.0
return log_spec return log_spec
@staticmethod @staticmethod
......
...@@ -160,6 +160,7 @@ class ASTFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.Test ...@@ -160,6 +160,7 @@ class ASTFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.Test
# fmt: on # fmt: on
input_speech = self._load_datasamples(1) input_speech = self._load_datasamples(1)
feaure_extractor = ASTFeatureExtractor() feature_extractor = ASTFeatureExtractor()
input_values = feaure_extractor(input_speech, return_tensors="pt").input_values input_values = feature_extractor(input_speech, return_tensors="pt").input_values
self.assertEquals(input_values.shape, (1, 1024, 128))
self.assertTrue(torch.allclose(input_values[0, 0, :30], EXPECTED_INPUT_VALUES, atol=1e-4)) self.assertTrue(torch.allclose(input_values[0, 0, :30], EXPECTED_INPUT_VALUES, atol=1e-4))
...@@ -21,7 +21,7 @@ import unittest ...@@ -21,7 +21,7 @@ import unittest
import numpy as np import numpy as np
from transformers import is_speech_available from transformers import is_speech_available
from transformers.testing_utils import require_torch, require_torchaudio from transformers.testing_utils import require_torch
from ...test_sequence_feature_extraction_common import SequenceFeatureExtractionTestMixin from ...test_sequence_feature_extraction_common import SequenceFeatureExtractionTestMixin
...@@ -47,7 +47,6 @@ def floats_list(shape, scale=1.0, rng=None, name=None): ...@@ -47,7 +47,6 @@ def floats_list(shape, scale=1.0, rng=None, name=None):
@require_torch @require_torch
@require_torchaudio
class MCTCTFeatureExtractionTester(unittest.TestCase): class MCTCTFeatureExtractionTester(unittest.TestCase):
def __init__( def __init__(
self, self,
...@@ -102,7 +101,6 @@ class MCTCTFeatureExtractionTester(unittest.TestCase): ...@@ -102,7 +101,6 @@ class MCTCTFeatureExtractionTester(unittest.TestCase):
@require_torch @require_torch
@require_torchaudio
class MCTCTFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.TestCase): class MCTCTFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.TestCase):
feature_extraction_class = MCTCTFeatureExtractor if is_speech_available() else None feature_extraction_class = MCTCTFeatureExtractor if is_speech_available() else None
...@@ -271,3 +269,38 @@ class MCTCTFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.Te ...@@ -271,3 +269,38 @@ class MCTCTFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.Te
self.assertTrue(np_processed.input_features.dtype == np.float32) self.assertTrue(np_processed.input_features.dtype == np.float32)
pt_processed = feature_extractor.pad([{"input_features": inputs}], return_tensors="pt") pt_processed = feature_extractor.pad([{"input_features": inputs}], return_tensors="pt")
self.assertTrue(pt_processed.input_features.dtype == torch.float32) self.assertTrue(pt_processed.input_features.dtype == torch.float32)
def _load_datasamples(self, num_samples):
from datasets import load_dataset
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
# automatic decoding with librispeech
speech_samples = ds.sort("id").select(range(num_samples))[:num_samples]["audio"]
return [x["array"] for x in speech_samples]
def test_integration(self):
# fmt: off
expected = np.array([
[
1.1280, 1.1319, 1.2744, 1.4369, 1.4328, 1.3671, 1.2889, 1.3046,
1.4419, 0.8387, 0.2995, 0.0404, 0.1068, 0.0472, 0.3728, 1.3356,
1.4491, 0.4770, 0.3997, 0.2776, 0.3184, -0.1243, -0.1170, -0.0828
],
[
1.0826, 1.0565, 1.2110, 1.3886, 1.3416, 1.2009, 1.1894, 1.2707,
1.5153, 0.7005, 0.4916, 0.4017, 0.3743, 0.1935, 0.4228, 1.1084,
0.9768, 0.0608, 0.2044, 0.1723, 0.0433, -0.2360, -0.2478, -0.2643
],
[
1.0590, 0.9923, 1.1185, 1.3309, 1.1971, 1.0067, 1.0080, 1.2036,
1.5397, 1.0383, 0.7672, 0.7551, 0.4878, 0.8771, 0.7565, 0.8775,
0.9042, 0.4595, 0.6157, 0.4954, 0.1857, 0.0307, 0.0199, 0.1033
],
])
# fmt: on
input_speech = self._load_datasamples(1)
feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
input_features = feature_extractor(input_speech, sampling_rate=16000, return_tensors="pt").input_features
self.assertTrue(np.allclose(input_features[0, 100:103], expected, atol=1e-4))
...@@ -247,3 +247,27 @@ class Speech2TextFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unitt ...@@ -247,3 +247,27 @@ class Speech2TextFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unitt
self.assertTrue(np_processed.input_features.dtype == np.float32) self.assertTrue(np_processed.input_features.dtype == np.float32)
pt_processed = feature_extractor.pad([{"input_features": inputs}], return_tensors="pt") pt_processed = feature_extractor.pad([{"input_features": inputs}], return_tensors="pt")
self.assertTrue(pt_processed.input_features.dtype == torch.float32) self.assertTrue(pt_processed.input_features.dtype == torch.float32)
def _load_datasamples(self, num_samples):
from datasets import load_dataset
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
# automatic decoding with librispeech
speech_samples = ds.sort("id").select(range(num_samples))[:num_samples]["audio"]
return [x["array"] for x in speech_samples]
def test_integration(self):
# fmt: off
expected = np.array([
-1.5745, -1.7713, -1.7020, -1.6069, -1.2250, -1.1105, -0.9072, -0.8241,
-1.2310, -0.8098, -0.3320, -0.4101, -0.7985, -0.4996, -0.8213, -0.9128,
-1.0420, -1.1286, -1.0440, -0.7999, -0.8405, -1.2275, -1.5443, -1.4625,
])
# fmt: on
input_speech = self._load_datasamples(1)
feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
input_features = feature_extractor(input_speech, return_tensors="pt").input_features
self.assertEquals(input_features.shape, (1, 584, 24))
self.assertTrue(np.allclose(input_features[0, 0, :30], expected, atol=1e-4))
...@@ -395,7 +395,8 @@ class SpeechT5FeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest ...@@ -395,7 +395,8 @@ class SpeechT5FeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest
input_speech = self._load_datasamples(1) input_speech = self._load_datasamples(1)
feature_extractor = SpeechT5FeatureExtractor() feature_extractor = SpeechT5FeatureExtractor()
input_values = feature_extractor(input_speech, return_tensors="pt").input_values input_values = feature_extractor(input_speech, return_tensors="pt").input_values
self.assertTrue(torch.allclose(input_values[0, :30], EXPECTED_INPUT_VALUES, atol=1e-4)) self.assertEquals(input_values.shape, (1, 93680))
self.assertTrue(torch.allclose(input_values[0, :30], EXPECTED_INPUT_VALUES, atol=1e-6))
def test_integration_target(self): def test_integration_target(self):
# fmt: off # fmt: off
...@@ -410,4 +411,5 @@ class SpeechT5FeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest ...@@ -410,4 +411,5 @@ class SpeechT5FeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest
input_speech = self._load_datasamples(1) input_speech = self._load_datasamples(1)
feature_extractor = SpeechT5FeatureExtractor() feature_extractor = SpeechT5FeatureExtractor()
input_values = feature_extractor(audio_target=input_speech, return_tensors="pt").input_values input_values = feature_extractor(audio_target=input_speech, return_tensors="pt").input_values
self.assertEquals(input_values.shape, (1, 366, 80))
self.assertTrue(torch.allclose(input_values[0, 0, :30], EXPECTED_INPUT_VALUES, atol=1e-4)) self.assertTrue(torch.allclose(input_values[0, 0, :30], EXPECTED_INPUT_VALUES, atol=1e-4))
...@@ -198,10 +198,10 @@ class TvltFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.Tes ...@@ -198,10 +198,10 @@ class TvltFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.Tes
def test_integration(self): def test_integration(self):
input_speech = self._load_datasamples(1) input_speech = self._load_datasamples(1)
feaure_extractor = TvltFeatureExtractor() feature_extractor = TvltFeatureExtractor()
audio_values = feaure_extractor(input_speech, return_tensors="pt").audio_values audio_values = feature_extractor(input_speech, return_tensors="pt").audio_values
self.assertTrue(audio_values.shape, [1, 1, 192, 128]) self.assertEquals(audio_values.shape, (1, 1, 192, 128))
expected_slice = torch.tensor([[-0.3032, -0.2708], [-0.4434, -0.4007]]) expected_slice = torch.tensor([[-0.3032, -0.2708], [-0.4434, -0.4007]])
self.assertTrue(torch.allclose(audio_values[0, 0, :2, :2], expected_slice, atol=1e-4)) self.assertTrue(torch.allclose(audio_values[0, 0, :2, :2], expected_slice, atol=1e-4))
...@@ -218,8 +218,9 @@ class WhisperFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest. ...@@ -218,8 +218,9 @@ class WhisperFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.
# fmt: on # fmt: on
input_speech = self._load_datasamples(1) input_speech = self._load_datasamples(1)
feaure_extractor = WhisperFeatureExtractor() feature_extractor = WhisperFeatureExtractor()
input_features = feaure_extractor(input_speech, return_tensors="pt").input_features input_features = feature_extractor(input_speech, return_tensors="pt").input_features
self.assertEqual(input_features.shape, (1, 80, 3000))
self.assertTrue(torch.allclose(input_features[0, 0, :30], EXPECTED_INPUT_FEATURES, atol=1e-4)) self.assertTrue(torch.allclose(input_features[0, 0, :30], EXPECTED_INPUT_FEATURES, atol=1e-4))
def test_zero_mean_unit_variance_normalization_trunc_np_longest(self): def test_zero_mean_unit_variance_normalization_trunc_np_longest(self):
......
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import pytest
from transformers.audio_utils import (
amplitude_to_db,
hertz_to_mel,
mel_filter_bank,
mel_to_hertz,
power_to_db,
spectrogram,
window_function,
)
class AudioUtilsFunctionTester(unittest.TestCase):
def test_hertz_to_mel(self):
self.assertEqual(hertz_to_mel(0.0), 0.0)
self.assertAlmostEqual(hertz_to_mel(100), 150.48910241)
inputs = np.array([100, 200])
expected = np.array([150.48910241, 283.22989816])
self.assertTrue(np.allclose(hertz_to_mel(inputs), expected))
self.assertEqual(hertz_to_mel(0.0, "slaney"), 0.0)
self.assertEqual(hertz_to_mel(100, "slaney"), 1.5)
inputs = np.array([60, 100, 200, 1000, 1001, 2000])
expected = np.array([0.9, 1.5, 3.0, 15.0, 15.01453781, 25.08188016])
self.assertTrue(np.allclose(hertz_to_mel(inputs, "slaney"), expected))
with pytest.raises(ValueError):
hertz_to_mel(100, mel_scale=None)
def test_mel_to_hertz(self):
self.assertEqual(mel_to_hertz(0.0), 0.0)
self.assertAlmostEqual(mel_to_hertz(150.48910241), 100)
inputs = np.array([150.48910241, 283.22989816])
expected = np.array([100, 200])
self.assertTrue(np.allclose(mel_to_hertz(inputs), expected))
self.assertEqual(mel_to_hertz(0.0, "slaney"), 0.0)
self.assertEqual(mel_to_hertz(1.5, "slaney"), 100)
inputs = np.array([0.9, 1.5, 3.0, 15.0, 15.01453781, 25.08188016])
expected = np.array([60, 100, 200, 1000, 1001, 2000])
self.assertTrue(np.allclose(mel_to_hertz(inputs, "slaney"), expected))
with pytest.raises(ValueError):
mel_to_hertz(100, mel_scale=None)
def test_mel_filter_bank_shape(self):
mel_filters = mel_filter_bank(
num_frequency_bins=513,
num_mel_filters=13,
min_frequency=100,
max_frequency=4000,
sampling_rate=16000,
norm=None,
mel_scale="htk",
)
self.assertEqual(mel_filters.shape, (513, 13))
mel_filters = mel_filter_bank(
num_frequency_bins=513,
num_mel_filters=13,
min_frequency=100,
max_frequency=4000,
sampling_rate=16000,
norm="slaney",
mel_scale="slaney",
)
self.assertEqual(mel_filters.shape, (513, 13))
def test_mel_filter_bank_htk(self):
mel_filters = mel_filter_bank(
num_frequency_bins=16,
num_mel_filters=4,
min_frequency=0,
max_frequency=2000,
sampling_rate=4000,
norm=None,
mel_scale="htk",
)
# fmt: off
expected = np.array([
[0.0 , 0.0 , 0.0 , 0.0 ],
[0.61454786, 0.0 , 0.0 , 0.0 ],
[0.82511046, 0.17488954, 0.0 , 0.0 ],
[0.35597035, 0.64402965, 0.0 , 0.0 ],
[0.0 , 0.91360726, 0.08639274, 0.0 ],
[0.0 , 0.55547007, 0.44452993, 0.0 ],
[0.0 , 0.19733289, 0.80266711, 0.0 ],
[0.0 , 0.0 , 0.87724349, 0.12275651],
[0.0 , 0.0 , 0.6038449 , 0.3961551 ],
[0.0 , 0.0 , 0.33044631, 0.66955369],
[0.0 , 0.0 , 0.05704771, 0.94295229],
[0.0 , 0.0 , 0.0 , 0.83483975],
[0.0 , 0.0 , 0.0 , 0.62612982],
[0.0 , 0.0 , 0.0 , 0.41741988],
[0.0 , 0.0 , 0.0 , 0.20870994],
[0.0 , 0.0 , 0.0 , 0.0 ]
])
# fmt: on
self.assertTrue(np.allclose(mel_filters, expected))
def test_mel_filter_bank_slaney(self):
mel_filters = mel_filter_bank(
num_frequency_bins=16,
num_mel_filters=4,
min_frequency=0,
max_frequency=2000,
sampling_rate=4000,
norm=None,
mel_scale="slaney",
)
# fmt: off
expected = np.array([
[0.0 , 0.0 , 0.0 , 0.0 ],
[0.39869419, 0.0 , 0.0 , 0.0 ],
[0.79738839, 0.0 , 0.0 , 0.0 ],
[0.80391742, 0.19608258, 0.0 , 0.0 ],
[0.40522322, 0.59477678, 0.0 , 0.0 ],
[0.00652903, 0.99347097, 0.0 , 0.0 ],
[0.0 , 0.60796161, 0.39203839, 0.0 ],
[0.0 , 0.20939631, 0.79060369, 0.0 ],
[0.0 , 0.0 , 0.84685344, 0.15314656],
[0.0 , 0.0 , 0.52418477, 0.47581523],
[0.0 , 0.0 , 0.2015161 , 0.7984839 ],
[0.0 , 0.0 , 0.0 , 0.9141874 ],
[0.0 , 0.0 , 0.0 , 0.68564055],
[0.0 , 0.0 , 0.0 , 0.4570937 ],
[0.0 , 0.0 , 0.0 , 0.22854685],
[0.0 , 0.0 , 0.0 , 0.0 ]
])
# fmt: on
self.assertTrue(np.allclose(mel_filters, expected))
def test_mel_filter_bank_slaney_norm(self):
mel_filters = mel_filter_bank(
num_frequency_bins=16,
num_mel_filters=4,
min_frequency=0,
max_frequency=2000,
sampling_rate=4000,
norm="slaney",
mel_scale="slaney",
)
# fmt: off
expected = np.array([
[0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00],
[1.19217795e-03, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00],
[2.38435591e-03, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00],
[2.40387905e-03, 5.86232616e-04, 0.00000000e+00, 0.00000000e+00],
[1.21170110e-03, 1.77821783e-03, 0.00000000e+00, 0.00000000e+00],
[1.95231437e-05, 2.97020305e-03, 0.00000000e+00, 0.00000000e+00],
[0.00000000e+00, 1.81763684e-03, 1.04857612e-03, 0.00000000e+00],
[0.00000000e+00, 6.26036972e-04, 2.11460963e-03, 0.00000000e+00],
[0.00000000e+00, 0.00000000e+00, 2.26505954e-03, 3.07332945e-04],
[0.00000000e+00, 0.00000000e+00, 1.40202503e-03, 9.54861093e-04],
[0.00000000e+00, 0.00000000e+00, 5.38990521e-04, 1.60238924e-03],
[0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 1.83458185e-03],
[0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 1.37593638e-03],
[0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 9.17290923e-04],
[0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 4.58645462e-04],
[0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00]
])
# fmt: on
self.assertTrue(np.allclose(mel_filters, expected))
def test_window_function(self):
window = window_function(16, "hann")
self.assertEqual(len(window), 16)
# fmt: off
expected = np.array([
0.0, 0.03806023, 0.14644661, 0.30865828, 0.5, 0.69134172, 0.85355339, 0.96193977,
1.0, 0.96193977, 0.85355339, 0.69134172, 0.5, 0.30865828, 0.14644661, 0.03806023,
])
# fmt: on
self.assertTrue(np.allclose(window, expected))
def _load_datasamples(self, num_samples):
from datasets import load_dataset
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
speech_samples = ds.sort("id").select(range(num_samples))[:num_samples]["audio"]
return [x["array"] for x in speech_samples]
def test_spectrogram_impulse(self):
waveform = np.zeros(40)
waveform[9] = 1.0 # impulse shifted in time
spec = spectrogram(
waveform,
window_function(12, "hann", frame_length=16),
frame_length=16,
hop_length=4,
power=1.0,
center=True,
pad_mode="reflect",
onesided=True,
)
self.assertEqual(spec.shape, (9, 11))
expected = np.array([[0.0, 0.0669873, 0.9330127, 0.5, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]])
self.assertTrue(np.allclose(spec, expected))
def test_spectrogram_integration_test(self):
waveform = self._load_datasamples(1)[0]
spec = spectrogram(
waveform,
window_function(400, "hann", frame_length=512),
frame_length=512,
hop_length=128,
power=1.0,
center=True,
pad_mode="reflect",
onesided=True,
)
self.assertEqual(spec.shape, (257, 732))
# fmt: off
expected = np.array([
0.02464888, 0.04648664, 0.05872392, 0.02311783, 0.0327175 ,
0.02433643, 0.01198814, 0.02055709, 0.01559287, 0.01394357,
0.01299037, 0.01728045, 0.0254554 , 0.02486533, 0.02011792,
0.01755333, 0.02100457, 0.02337024, 0.01436963, 0.01464558,
0.0211017 , 0.0193489 , 0.01272165, 0.01858462, 0.03722598,
0.0456542 , 0.03281558, 0.00620586, 0.02226466, 0.03618042,
0.03508182, 0.02271432, 0.01051649, 0.01225771, 0.02315293,
0.02331886, 0.01417785, 0.0106844 , 0.01791214, 0.017177 ,
0.02125114, 0.05028201, 0.06830665, 0.05216664, 0.01963666,
0.06941418, 0.11513043, 0.12257859, 0.10948435, 0.08568069,
0.05509328, 0.05047818, 0.047112 , 0.05060737, 0.02982424,
0.02803827, 0.02933729, 0.01760491, 0.00587815, 0.02117637,
0.0293578 , 0.03452379, 0.02194803, 0.01676056,
])
# fmt: on
self.assertTrue(np.allclose(spec[:64, 400], expected))
spec = spectrogram(
waveform,
window_function(400, "hann"),
frame_length=400,
hop_length=128,
fft_length=512,
power=1.0,
center=True,
pad_mode="reflect",
onesided=True,
)
self.assertEqual(spec.shape, (257, 732))
self.assertTrue(np.allclose(spec[:64, 400], expected))
def test_spectrogram_center_padding(self):
waveform = self._load_datasamples(1)[0]
spec = spectrogram(
waveform,
window_function(512, "hann"),
frame_length=512,
hop_length=128,
center=True,
pad_mode="reflect",
)
self.assertEqual(spec.shape, (257, 732))
# fmt: off
expected = np.array([
0.1287945 , 0.12792738, 0.08311573, 0.03155122, 0.02470202,
0.00727857, 0.00910694, 0.00686163, 0.01238981, 0.01473668,
0.00336144, 0.00370314, 0.00600871, 0.01120164, 0.01942998,
0.03132008, 0.0232842 , 0.01124642, 0.02754783, 0.02423725,
0.00147893, 0.00038027, 0.00112299, 0.00596233, 0.00571529,
0.02084235, 0.0231855 , 0.00810006, 0.01837943, 0.00651339,
0.00093931, 0.00067426, 0.01058399, 0.01270507, 0.00151734,
0.00331913, 0.00302416, 0.01081792, 0.00754549, 0.00148963,
0.00111943, 0.00152573, 0.00608017, 0.01749986, 0.01205949,
0.0143082 , 0.01910573, 0.00413786, 0.03916619, 0.09873404,
0.08302026, 0.02673891, 0.00401255, 0.01397392, 0.00751862,
0.01024884, 0.01544606, 0.00638907, 0.00623633, 0.0085103 ,
0.00217659, 0.00276204, 0.00260835, 0.00299299,
])
# fmt: on
self.assertTrue(np.allclose(spec[:64, 0], expected))
spec = spectrogram(
waveform,
window_function(512, "hann"),
frame_length=512,
hop_length=128,
center=True,
pad_mode="constant",
)
self.assertEqual(spec.shape, (257, 732))
# fmt: off
expected = np.array([
0.06558744, 0.06889656, 0.06263352, 0.04264418, 0.03404115,
0.03244197, 0.02279134, 0.01646339, 0.01452216, 0.00826055,
0.00062093, 0.0031821 , 0.00419456, 0.00689327, 0.01106367,
0.01712119, 0.01721762, 0.00977533, 0.01606626, 0.02275621,
0.01727687, 0.00992739, 0.01217688, 0.01049927, 0.01022947,
0.01302475, 0.01166873, 0.01081812, 0.01057327, 0.00767912,
0.00429567, 0.00089625, 0.00654583, 0.00912084, 0.00700984,
0.00225026, 0.00290545, 0.00667712, 0.00730663, 0.00410813,
0.00073102, 0.00219296, 0.00527618, 0.00996585, 0.01123781,
0.00872816, 0.01165121, 0.02047945, 0.03681747, 0.0514379 ,
0.05137928, 0.03960042, 0.02821562, 0.01813349, 0.01201322,
0.01260964, 0.00900654, 0.00207905, 0.00456714, 0.00850599,
0.00788239, 0.00664407, 0.00824227, 0.00628301,
])
# fmt: on
self.assertTrue(np.allclose(spec[:64, 0], expected))
spec = spectrogram(
waveform,
window_function(512, "hann"),
frame_length=512,
hop_length=128,
center=False,
)
self.assertEqual(spec.shape, (257, 728))
# fmt: off
expected = np.array([
0.00250445, 0.02161521, 0.06232229, 0.04339567, 0.00937727,
0.01080616, 0.00248685, 0.0095264 , 0.00727476, 0.0079152 ,
0.00839946, 0.00254932, 0.00716622, 0.005559 , 0.00272623,
0.00581774, 0.01896395, 0.01829788, 0.01020514, 0.01632692,
0.00870888, 0.02065827, 0.0136022 , 0.0132382 , 0.011827 ,
0.00194505, 0.0189979 , 0.026874 , 0.02194014, 0.01923883,
0.01621437, 0.00661967, 0.00289517, 0.00470257, 0.00957801,
0.00191455, 0.00431664, 0.00544359, 0.01126213, 0.00785778,
0.00423469, 0.01322504, 0.02226548, 0.02318576, 0.03428908,
0.03648811, 0.0202938 , 0.011902 , 0.03226198, 0.06347476,
0.01306318, 0.05308729, 0.05474771, 0.03127991, 0.00998512,
0.01449977, 0.01272741, 0.00868176, 0.00850386, 0.00313876,
0.00811857, 0.00538216, 0.00685749, 0.00535275,
])
# fmt: on
self.assertTrue(np.allclose(spec[:64, 0], expected))
def test_spectrogram_shapes(self):
waveform = self._load_datasamples(1)[0]
spec = spectrogram(
waveform,
window_function(400, "hann"),
frame_length=400,
hop_length=128,
power=1.0,
center=True,
pad_mode="reflect",
onesided=True,
)
self.assertEqual(spec.shape, (201, 732))
spec = spectrogram(
waveform,
window_function(400, "hann"),
frame_length=400,
hop_length=128,
power=1.0,
center=False,
pad_mode="reflect",
onesided=True,
)
self.assertEqual(spec.shape, (201, 729))
spec = spectrogram(
waveform,
window_function(400, "hann"),
frame_length=400,
hop_length=128,
fft_length=512,
power=1.0,
center=True,
pad_mode="reflect",
onesided=True,
)
self.assertEqual(spec.shape, (257, 732))
spec = spectrogram(
waveform,
window_function(400, "hann", frame_length=512),
frame_length=512,
hop_length=64,
power=1.0,
center=True,
pad_mode="reflect",
onesided=False,
)
self.assertEqual(spec.shape, (512, 1464))
spec = spectrogram(
waveform,
window_function(512, "hann"),
frame_length=512,
hop_length=64,
power=1.0,
center=True,
pad_mode="reflect",
onesided=False,
)
self.assertEqual(spec.shape, (512, 1464))
spec = spectrogram(
waveform,
window_function(512, "hann"),
frame_length=512,
hop_length=512,
power=1.0,
center=True,
pad_mode="reflect",
onesided=False,
)
self.assertEqual(spec.shape, (512, 183))
def test_mel_spectrogram(self):
waveform = self._load_datasamples(1)[0]
mel_filters = mel_filter_bank(
num_frequency_bins=513,
num_mel_filters=13,
min_frequency=100,
max_frequency=4000,
sampling_rate=16000,
norm=None,
mel_scale="htk",
)
self.assertEqual(mel_filters.shape, (513, 13))
spec = spectrogram(
waveform,
window_function(800, "hann", frame_length=1024),
frame_length=1024,
hop_length=128,
power=2.0,
)
self.assertEqual(spec.shape, (513, 732))
spec = spectrogram(
waveform,
window_function(800, "hann", frame_length=1024),
frame_length=1024,
hop_length=128,
power=2.0,
mel_filters=mel_filters,
)
self.assertEqual(spec.shape, (13, 732))
# fmt: off
expected = np.array([
1.08027889e+02, 1.48080673e+01, 7.70758213e+00, 9.57676639e-01,
8.81639061e-02, 5.26073833e-02, 1.52736155e-02, 9.95350117e-03,
7.95364356e-03, 1.01148004e-02, 4.29241020e-03, 9.90708797e-03,
9.44153646e-04
])
# fmt: on
self.assertTrue(np.allclose(spec[:, 300], expected))
def test_spectrogram_power(self):
waveform = self._load_datasamples(1)[0]
spec = spectrogram(
waveform,
window_function(400, "hann", frame_length=512),
frame_length=512,
hop_length=128,
power=None,
)
self.assertEqual(spec.shape, (257, 732))
self.assertEqual(spec.dtype, np.complex64)
# fmt: off
expected = np.array([
0.01452305+0.01820039j, -0.01737362-0.01641946j,
0.0121028 +0.01565081j, -0.02794554-0.03021514j,
0.04719803+0.04086519j, -0.04391563-0.02779365j,
0.05682834+0.01571325j, -0.08604821-0.02023657j,
0.07497991+0.0186641j , -0.06366091-0.00922475j,
0.11003416+0.0114788j , -0.13677941-0.01523552j,
0.10934535-0.00117226j, -0.11635598+0.02551187j,
0.14708674-0.03469823j, -0.1328196 +0.06034218j,
0.12667368-0.13973421j, -0.14764774+0.18912019j,
0.10235471-0.12181523j, -0.00773012+0.04730498j,
-0.01487191-0.07312611j, -0.02739162+0.09619419j,
0.02895459-0.05398273j, 0.01198589+0.05276592j,
-0.02117299-0.10123465j, 0.00666388+0.09526499j,
-0.01672773-0.05649684j, 0.02723125+0.05939891j,
-0.01879361-0.062954j , 0.03686557+0.04568823j,
-0.07394181-0.07949649j, 0.06238583+0.13905765j,
])
# fmt: on
self.assertTrue(np.allclose(spec[64:96, 321], expected))
spec = spectrogram(
waveform,
window_function(400, "hann", frame_length=512),
frame_length=512,
hop_length=128,
power=1.0,
)
self.assertEqual(spec.shape, (257, 732))
self.assertEqual(spec.dtype, np.float64)
# fmt: off
expected = np.array([
0.02328461, 0.02390484, 0.01978448, 0.04115711, 0.0624309 ,
0.05197181, 0.05896072, 0.08839577, 0.07726794, 0.06432579,
0.11063128, 0.13762532, 0.10935163, 0.11911998, 0.15112405,
0.14588428, 0.18860507, 0.23992978, 0.15910825, 0.04793241,
0.07462307, 0.10001811, 0.06125769, 0.05411011, 0.10342509,
0.09549777, 0.05892122, 0.06534349, 0.06569936, 0.05870678,
0.10856833, 0.1524107 , 0.11463385, 0.05766969, 0.12385171,
0.14472842, 0.11978184, 0.10353675, 0.07244056, 0.03461861,
0.02624896, 0.02227475, 0.01238363, 0.00885281, 0.0110049 ,
0.00807005, 0.01033663, 0.01703181, 0.01445856, 0.00585615,
0.0132431 , 0.02754132, 0.01524478, 0.0204908 , 0.07453328,
0.10716327, 0.07195779, 0.08816078, 0.18340898, 0.16449876,
0.12322842, 0.1621659 , 0.12334293, 0.06033659,
])
# fmt: on
self.assertTrue(np.allclose(spec[64:128, 321], expected))
spec = spectrogram(
waveform,
window_function(400, "hann", frame_length=512),
frame_length=512,
hop_length=128,
power=2.0,
)
self.assertEqual(spec.shape, (257, 732))
self.assertEqual(spec.dtype, np.float64)
# fmt: off
expected = np.array([
5.42173162e-04, 5.71441371e-04, 3.91425507e-04, 1.69390778e-03,
3.89761780e-03, 2.70106923e-03, 3.47636663e-03, 7.81381316e-03,
5.97033510e-03, 4.13780799e-03, 1.22392802e-02, 1.89407300e-02,
1.19577805e-02, 1.41895693e-02, 2.28384770e-02, 2.12822221e-02,
3.55718732e-02, 5.75663000e-02, 2.53154356e-02, 2.29751552e-03,
5.56860259e-03, 1.00036217e-02, 3.75250424e-03, 2.92790355e-03,
1.06967501e-02, 9.11982451e-03, 3.47171025e-03, 4.26977174e-03,
4.31640586e-03, 3.44648538e-03, 1.17870830e-02, 2.32290216e-02,
1.31409196e-02, 3.32579296e-03, 1.53392460e-02, 2.09463164e-02,
1.43476883e-02, 1.07198600e-02, 5.24763530e-03, 1.19844836e-03,
6.89007982e-04, 4.96164430e-04, 1.53354369e-04, 7.83722571e-05,
1.21107812e-04, 6.51257360e-05, 1.06845939e-04, 2.90082477e-04,
2.09049831e-04, 3.42945241e-05, 1.75379610e-04, 7.58524227e-04,
2.32403356e-04, 4.19872697e-04, 5.55520924e-03, 1.14839673e-02,
5.17792348e-03, 7.77232368e-03, 3.36388536e-02, 2.70598419e-02,
1.51852425e-02, 2.62977779e-02, 1.52134784e-02, 3.64050455e-03,
])
# fmt: on
self.assertTrue(np.allclose(spec[64:128, 321], expected))
def test_power_to_db(self):
spectrogram = np.zeros((2, 3))
spectrogram[0, 0] = 2.0
spectrogram[0, 1] = 0.5
spectrogram[0, 2] = 0.707
spectrogram[1, 1] = 1.0
output = power_to_db(spectrogram, reference=1.0)
expected = np.array([[3.01029996, -3.01029996, -1.50580586], [-100.0, 0.0, -100.0]])
self.assertTrue(np.allclose(output, expected))
output = power_to_db(spectrogram, reference=2.0)
expected = np.array([[0.0, -6.02059991, -4.51610582], [-103.01029996, -3.01029996, -103.01029996]])
self.assertTrue(np.allclose(output, expected))
output = power_to_db(spectrogram, min_value=1e-6)
expected = np.array([[3.01029996, -3.01029996, -1.50580586], [-60.0, 0.0, -60.0]])
self.assertTrue(np.allclose(output, expected))
output = power_to_db(spectrogram, db_range=80)
expected = np.array([[3.01029996, -3.01029996, -1.50580586], [-76.98970004, 0.0, -76.98970004]])
self.assertTrue(np.allclose(output, expected))
output = power_to_db(spectrogram, reference=2.0, db_range=80)
expected = np.array([[0.0, -6.02059991, -4.51610582], [-80.0, -3.01029996, -80.0]])
self.assertTrue(np.allclose(output, expected))
output = power_to_db(spectrogram, reference=2.0, min_value=1e-6, db_range=80)
expected = np.array([[0.0, -6.02059991, -4.51610582], [-63.01029996, -3.01029996, -63.01029996]])
self.assertTrue(np.allclose(output, expected))
with pytest.raises(ValueError):
power_to_db(spectrogram, reference=0.0)
with pytest.raises(ValueError):
power_to_db(spectrogram, min_value=0.0)
with pytest.raises(ValueError):
power_to_db(spectrogram, db_range=-80)
def test_amplitude_to_db(self):
spectrogram = np.zeros((2, 3))
spectrogram[0, 0] = 2.0
spectrogram[0, 1] = 0.5
spectrogram[0, 2] = 0.707
spectrogram[1, 1] = 1.0
output = amplitude_to_db(spectrogram, reference=1.0)
expected = np.array([[6.02059991, -6.02059991, -3.01161172], [-100.0, 0.0, -100.0]])
self.assertTrue(np.allclose(output, expected))
output = amplitude_to_db(spectrogram, reference=2.0)
expected = np.array([[0.0, -12.04119983, -9.03221164], [-106.02059991, -6.02059991, -106.02059991]])
self.assertTrue(np.allclose(output, expected))
output = amplitude_to_db(spectrogram, min_value=1e-3)
expected = np.array([[6.02059991, -6.02059991, -3.01161172], [-60.0, 0.0, -60.0]])
self.assertTrue(np.allclose(output, expected))
output = amplitude_to_db(spectrogram, db_range=80)
expected = np.array([[6.02059991, -6.02059991, -3.01161172], [-73.97940009, 0.0, -73.97940009]])
self.assertTrue(np.allclose(output, expected))
output = amplitude_to_db(spectrogram, reference=2.0, db_range=80)
expected = np.array([[0.0, -12.04119983, -9.03221164], [-80.0, -6.02059991, -80.0]])
self.assertTrue(np.allclose(output, expected))
output = amplitude_to_db(spectrogram, reference=2.0, min_value=1e-3, db_range=80)
expected = np.array([[0.0, -12.04119983, -9.03221164], [-66.02059991, -6.02059991, -66.02059991]])
self.assertTrue(np.allclose(output, expected))
with pytest.raises(ValueError):
amplitude_to_db(spectrogram, reference=0.0)
with pytest.raises(ValueError):
amplitude_to_db(spectrogram, min_value=0.0)
with pytest.raises(ValueError):
amplitude_to_db(spectrogram, db_range=-80)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment