Unverified commit 814619f5, authored by Sanchit Gandhi and committed by GitHub
Browse files

[Whisper] Use torch for stft if available (#26119)

* [Whisper] Use torch for stft if available

* update docstring

* mock patch decorator

* fit on one line
parent 7e93ce40
...@@ -19,12 +19,16 @@ from typing import List, Optional, Union ...@@ -19,12 +19,16 @@ from typing import List, Optional, Union
import numpy as np import numpy as np
from ... import is_torch_available
from ...audio_utils import mel_filter_bank, spectrogram, window_function from ...audio_utils import mel_filter_bank, spectrogram, window_function
from ...feature_extraction_sequence_utils import SequenceFeatureExtractor from ...feature_extraction_sequence_utils import SequenceFeatureExtractor
from ...feature_extraction_utils import BatchFeature from ...feature_extraction_utils import BatchFeature
from ...utils import TensorType, logging from ...utils import TensorType, logging
if is_torch_available():
import torch
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
...@@ -109,6 +113,24 @@ class WhisperFeatureExtractor(SequenceFeatureExtractor): ...@@ -109,6 +113,24 @@ class WhisperFeatureExtractor(SequenceFeatureExtractor):
log_spec = (log_spec + 4.0) / 4.0 log_spec = (log_spec + 4.0) / 4.0
return log_spec return log_spec
def _torch_extract_fbank_features(self, waveform: np.ndarray) -> np.ndarray:
    """
    Compute the log-mel spectrogram of the provided audio using the PyTorch STFT implementation.

    Args:
        waveform (`np.ndarray`):
            Audio samples of a single example. The samples are cast to `float32` before the STFT is computed.

    Returns:
        `np.ndarray`: The log-mel spectrogram, obtained by projecting the STFT power spectrum with
        `self.mel_filters`, taking a clamped `log10`, flooring at `max - 8.0` and rescaling with `(x + 4.0) / 4.0`.
    """
    # `torch.from_numpy` refuses negative-stride views and non-array inputs; normalising to a contiguous
    # float32 array first is a no-op (no copy) when the input is already in that form.
    waveform = torch.from_numpy(np.ascontiguousarray(waveform, dtype=np.float32))

    window = torch.hann_window(self.n_fft)
    stft = torch.stft(waveform, self.n_fft, self.hop_length, window=window, return_complex=True)
    # Power spectrum; the last STFT frame is dropped.
    magnitudes = stft[..., :-1].abs() ** 2

    mel_filters = torch.from_numpy(self.mel_filters).type(torch.float32)
    mel_spec = mel_filters.T @ magnitudes

    # Clamp before the log so zero-energy bins do not produce -inf.
    log_spec = torch.clamp(mel_spec, min=1e-10).log10()
    # Floor everything at 8.0 below the maximum, then rescale.
    log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
    log_spec = (log_spec + 4.0) / 4.0
    return log_spec.numpy()
@staticmethod @staticmethod
# Copied from transformers.models.wav2vec2.feature_extraction_wav2vec2.Wav2Vec2FeatureExtractor.zero_mean_unit_var_norm # Copied from transformers.models.wav2vec2.feature_extraction_wav2vec2.Wav2Vec2FeatureExtractor.zero_mean_unit_var_norm
def zero_mean_unit_var_norm( def zero_mean_unit_var_norm(
...@@ -146,7 +168,8 @@ class WhisperFeatureExtractor(SequenceFeatureExtractor): ...@@ -146,7 +168,8 @@ class WhisperFeatureExtractor(SequenceFeatureExtractor):
**kwargs, **kwargs,
) -> BatchFeature: ) -> BatchFeature:
""" """
Main method to featurize and prepare for the model one or several sequence(s). Main method to featurize and prepare for the model one or several sequence(s). Implementation uses PyTorch for
the STFT computation if available, otherwise a slower NumPy based one.
Args: Args:
raw_speech (`np.ndarray`, `List[float]`, `List[np.ndarray]`, `List[List[float]]`): raw_speech (`np.ndarray`, `List[float]`, `List[np.ndarray]`, `List[List[float]]`):
...@@ -246,7 +269,10 @@ class WhisperFeatureExtractor(SequenceFeatureExtractor): ...@@ -246,7 +269,10 @@ class WhisperFeatureExtractor(SequenceFeatureExtractor):
# make sure list is in array format # make sure list is in array format
input_features = padded_inputs.get("input_features").transpose(2, 0, 1) input_features = padded_inputs.get("input_features").transpose(2, 0, 1)
input_features = [self._np_extract_fbank_features(waveform) for waveform in input_features[0]] extract_fbank_features = (
self._torch_extract_fbank_features if is_torch_available() else self._np_extract_fbank_features
)
input_features = [extract_fbank_features(waveform) for waveform in input_features[0]]
if isinstance(input_features[0], List): if isinstance(input_features[0], List):
padded_inputs["input_features"] = [np.asarray(feature, dtype=np.float32) for feature in input_features] padded_inputs["input_features"] = [np.asarray(feature, dtype=np.float32) for feature in input_features]
......
...@@ -23,16 +23,13 @@ import unittest ...@@ -23,16 +23,13 @@ import unittest
import numpy as np import numpy as np
from datasets import load_dataset from datasets import load_dataset
from transformers import is_speech_available from transformers import WhisperFeatureExtractor
from transformers.testing_utils import check_json_file_has_correct_format, require_torch, require_torchaudio from transformers.testing_utils import check_json_file_has_correct_format, require_torch
from transformers.utils.import_utils import is_torch_available from transformers.utils.import_utils import is_torch_available
from ...test_sequence_feature_extraction_common import SequenceFeatureExtractionTestMixin from ...test_sequence_feature_extraction_common import SequenceFeatureExtractionTestMixin
if is_speech_available():
from transformers import WhisperFeatureExtractor
if is_torch_available(): if is_torch_available():
import torch import torch
...@@ -53,8 +50,6 @@ def floats_list(shape, scale=1.0, rng=None, name=None): ...@@ -53,8 +50,6 @@ def floats_list(shape, scale=1.0, rng=None, name=None):
return values return values
@require_torch
@require_torchaudio
class WhisperFeatureExtractionTester(unittest.TestCase): class WhisperFeatureExtractionTester(unittest.TestCase):
def __init__( def __init__(
self, self,
...@@ -111,10 +106,8 @@ class WhisperFeatureExtractionTester(unittest.TestCase): ...@@ -111,10 +106,8 @@ class WhisperFeatureExtractionTester(unittest.TestCase):
return speech_inputs return speech_inputs
@require_torch
@require_torchaudio
class WhisperFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.TestCase): class WhisperFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.TestCase):
feature_extraction_class = WhisperFeatureExtractor if is_speech_available() else None feature_extraction_class = WhisperFeatureExtractor
def setUp(self): def setUp(self):
self.feat_extract_tester = WhisperFeatureExtractionTester(self) self.feat_extract_tester = WhisperFeatureExtractionTester(self)
...@@ -193,6 +186,7 @@ class WhisperFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest. ...@@ -193,6 +186,7 @@ class WhisperFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.
for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2): for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2):
self.assertTrue(np.allclose(enc_seq_1, enc_seq_2, atol=1e-3)) self.assertTrue(np.allclose(enc_seq_1, enc_seq_2, atol=1e-3))
@require_torch
def test_double_precision_pad(self): def test_double_precision_pad(self):
import torch import torch
...@@ -213,7 +207,8 @@ class WhisperFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest. ...@@ -213,7 +207,8 @@ class WhisperFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.
return [x["array"] for x in speech_samples] return [x["array"] for x in speech_samples]
def test_integration(self): @require_torch
def test_torch_integration(self):
# fmt: off # fmt: off
EXPECTED_INPUT_FEATURES = torch.tensor( EXPECTED_INPUT_FEATURES = torch.tensor(
[ [
...@@ -231,6 +226,25 @@ class WhisperFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest. ...@@ -231,6 +226,25 @@ class WhisperFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.
self.assertEqual(input_features.shape, (1, 80, 3000)) self.assertEqual(input_features.shape, (1, 80, 3000))
self.assertTrue(torch.allclose(input_features[0, 0, :30], EXPECTED_INPUT_FEATURES, atol=1e-4)) self.assertTrue(torch.allclose(input_features[0, 0, :30], EXPECTED_INPUT_FEATURES, atol=1e-4))
def test_numpy_integration(self):
    """
    Check the feature extractor output against reference values while `is_torch_available` is patched to
    return `False` inside the Whisper feature-extraction module.
    """
    # Import `mock` explicitly: a bare `import unittest` does not load the `unittest.mock` submodule, so
    # `unittest.mock.patch` only resolves if something else happened to import it earlier.
    from unittest import mock

    # fmt: off
    EXPECTED_INPUT_FEATURES = np.array(
        [
            0.1193, -0.0946, -0.1098, -0.0196, 0.0225, -0.0690, -0.1736, 0.0951,
            0.0971, -0.0817, -0.0702, 0.0162, 0.0260, 0.0017, -0.0192, -0.1678,
            0.0709, -0.1867, -0.0655, -0.0274, -0.0234, -0.1884, -0.0516, -0.0554,
            -0.0274, -0.1425, -0.1423, 0.0837, 0.0377, -0.0854
        ]
    )
    # fmt: on

    # The patch covers the same statements the original decorator covered (the whole test body).
    with mock.patch("transformers.models.whisper.feature_extraction_whisper.is_torch_available", lambda: False):
        input_speech = self._load_datasamples(1)
        feature_extractor = WhisperFeatureExtractor()
        input_features = feature_extractor(input_speech, return_tensors="np").input_features
        self.assertEqual(input_features.shape, (1, 80, 3000))
        self.assertTrue(np.allclose(input_features[0, 0, :30], EXPECTED_INPUT_FEATURES, atol=1e-4))
def test_zero_mean_unit_variance_normalization_trunc_np_longest(self): def test_zero_mean_unit_variance_normalization_trunc_np_longest(self):
feat_extract = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict()) feat_extract = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
audio = self._load_datasamples(1)[0] audio = self._load_datasamples(1)[0]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment