Unverified Commit be74b2ea authored by Yoach Lacombe's avatar Yoach Lacombe Committed by GitHub
Browse files

Add numpy alternative to FE using torchaudio (#26339)

* add audio_utils usage in the FE of SpeechToText

* clean unnecessary parameters of AudioSpectrogramTransformer FE

* add audio_utils usage in AST

* add serialization tests and function to FEs

* make style

* remove use_torchaudio and move to_dict to FE

* test audio_utils usage

* make style and fix import (remove torchaudio dependency import)

* fix torch dependency for jax and tensor tests

* fix typo

* clean tests with suggestions

* add lines to test if is_speech_available is False
parent e2647450
...@@ -146,6 +146,7 @@ _import_structure = { ...@@ -146,6 +146,7 @@ _import_structure = {
"models.audio_spectrogram_transformer": [ "models.audio_spectrogram_transformer": [
"AUDIO_SPECTROGRAM_TRANSFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "AUDIO_SPECTROGRAM_TRANSFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP",
"ASTConfig", "ASTConfig",
"ASTFeatureExtractor",
], ],
"models.auto": [ "models.auto": [
"ALL_PRETRAINED_CONFIG_ARCHIVE_MAP", "ALL_PRETRAINED_CONFIG_ARCHIVE_MAP",
...@@ -535,6 +536,7 @@ _import_structure = { ...@@ -535,6 +536,7 @@ _import_structure = {
"models.speech_to_text": [ "models.speech_to_text": [
"SPEECH_TO_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP", "SPEECH_TO_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP",
"Speech2TextConfig", "Speech2TextConfig",
"Speech2TextFeatureExtractor",
"Speech2TextProcessor", "Speech2TextProcessor",
], ],
"models.speech_to_text_2": [ "models.speech_to_text_2": [
...@@ -913,20 +915,6 @@ except OptionalDependencyNotAvailable: ...@@ -913,20 +915,6 @@ except OptionalDependencyNotAvailable:
else: else:
_import_structure["convert_slow_tokenizer"] = ["SLOW_TO_FAST_CONVERTERS", "convert_slow_tokenizer"] _import_structure["convert_slow_tokenizer"] = ["SLOW_TO_FAST_CONVERTERS", "convert_slow_tokenizer"]
# Speech-specific objects
try:
if not is_speech_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from .utils import dummy_speech_objects
_import_structure["utils.dummy_speech_objects"] = [
name for name in dir(dummy_speech_objects) if not name.startswith("_")
]
else:
_import_structure["models.audio_spectrogram_transformer"].append("ASTFeatureExtractor")
_import_structure["models.speech_to_text"].append("Speech2TextFeatureExtractor")
# Tensorflow-text-specific objects # Tensorflow-text-specific objects
try: try:
if not is_tensorflow_text_available(): if not is_tensorflow_text_available():
...@@ -4352,6 +4340,7 @@ if TYPE_CHECKING: ...@@ -4352,6 +4340,7 @@ if TYPE_CHECKING:
from .models.audio_spectrogram_transformer import ( from .models.audio_spectrogram_transformer import (
AUDIO_SPECTROGRAM_TRANSFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, AUDIO_SPECTROGRAM_TRANSFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP,
ASTConfig, ASTConfig,
ASTFeatureExtractor,
) )
from .models.auto import ( from .models.auto import (
ALL_PRETRAINED_CONFIG_ARCHIVE_MAP, ALL_PRETRAINED_CONFIG_ARCHIVE_MAP,
...@@ -4722,6 +4711,7 @@ if TYPE_CHECKING: ...@@ -4722,6 +4711,7 @@ if TYPE_CHECKING:
from .models.speech_to_text import ( from .models.speech_to_text import (
SPEECH_TO_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP, SPEECH_TO_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP,
Speech2TextConfig, Speech2TextConfig,
Speech2TextFeatureExtractor,
Speech2TextProcessor, Speech2TextProcessor,
) )
from .models.speech_to_text_2 import ( from .models.speech_to_text_2 import (
...@@ -5067,15 +5057,6 @@ if TYPE_CHECKING: ...@@ -5067,15 +5057,6 @@ if TYPE_CHECKING:
else: else:
from .convert_slow_tokenizer import SLOW_TO_FAST_CONVERTERS, convert_slow_tokenizer from .convert_slow_tokenizer import SLOW_TO_FAST_CONVERTERS, convert_slow_tokenizer
try:
if not is_speech_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from .utils.dummy_speech_objects import *
else:
from .models.audio_spectrogram_transformer import ASTFeatureExtractor
from .models.speech_to_text import Speech2TextFeatureExtractor
try: try:
if not is_tensorflow_text_available(): if not is_tensorflow_text_available():
raise OptionalDependencyNotAvailable() raise OptionalDependencyNotAvailable()
......
...@@ -584,14 +584,15 @@ class FeatureExtractionMixin(PushToHubMixin): ...@@ -584,14 +584,15 @@ class FeatureExtractionMixin(PushToHubMixin):
def to_dict(self) -> Dict[str, Any]: def to_dict(self) -> Dict[str, Any]:
""" """
Serializes this instance to a Python dictionary. Serializes this instance to a Python dictionary. Returns:
`Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
Returns:
`Dict[str, Any]`: Dictionary of all the attributes that make up this feature extractor instance.
""" """
output = copy.deepcopy(self.__dict__) output = copy.deepcopy(self.__dict__)
output["feature_extractor_type"] = self.__class__.__name__ output["feature_extractor_type"] = self.__class__.__name__
if "mel_filters" in output:
del output["mel_filters"]
if "window" in output:
del output["window"]
return output return output
@classmethod @classmethod
......
...@@ -13,14 +13,15 @@ ...@@ -13,14 +13,15 @@
# limitations under the License. # limitations under the License.
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_speech_available, is_torch_available from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available
_import_structure = { _import_structure = {
"configuration_audio_spectrogram_transformer": [ "configuration_audio_spectrogram_transformer": [
"AUDIO_SPECTROGRAM_TRANSFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "AUDIO_SPECTROGRAM_TRANSFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP",
"ASTConfig", "ASTConfig",
] ],
"feature_extraction_audio_spectrogram_transformer": ["ASTFeatureExtractor"],
} }
try: try:
...@@ -36,19 +37,13 @@ else: ...@@ -36,19 +37,13 @@ else:
"ASTPreTrainedModel", "ASTPreTrainedModel",
] ]
try:
if not is_speech_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["feature_extraction_audio_spectrogram_transformer"] = ["ASTFeatureExtractor"]
if TYPE_CHECKING: if TYPE_CHECKING:
from .configuration_audio_spectrogram_transformer import ( from .configuration_audio_spectrogram_transformer import (
AUDIO_SPECTROGRAM_TRANSFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, AUDIO_SPECTROGRAM_TRANSFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP,
ASTConfig, ASTConfig,
) )
from .feature_extraction_audio_spectrogram_transformer import ASTFeatureExtractor
try: try:
if not is_torch_available(): if not is_torch_available():
...@@ -63,14 +58,6 @@ if TYPE_CHECKING: ...@@ -63,14 +58,6 @@ if TYPE_CHECKING:
ASTPreTrainedModel, ASTPreTrainedModel,
) )
try:
if not is_speech_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .feature_extraction_audio_spectrogram_transformer import ASTFeatureExtractor
else: else:
import sys import sys
......
...@@ -19,12 +19,18 @@ Feature extractor class for Audio Spectrogram Transformer. ...@@ -19,12 +19,18 @@ Feature extractor class for Audio Spectrogram Transformer.
from typing import List, Optional, Union from typing import List, Optional, Union
import numpy as np import numpy as np
import torch
import torchaudio.compliance.kaldi as ta_kaldi
from ...audio_utils import mel_filter_bank, spectrogram, window_function
from ...feature_extraction_sequence_utils import SequenceFeatureExtractor from ...feature_extraction_sequence_utils import SequenceFeatureExtractor
from ...feature_extraction_utils import BatchFeature from ...feature_extraction_utils import BatchFeature
from ...utils import TensorType, logging from ...utils import TensorType, is_speech_available, is_torch_available, logging
if is_speech_available():
import torchaudio.compliance.kaldi as ta_kaldi
if is_torch_available():
import torch
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
...@@ -37,8 +43,8 @@ class ASTFeatureExtractor(SequenceFeatureExtractor): ...@@ -37,8 +43,8 @@ class ASTFeatureExtractor(SequenceFeatureExtractor):
This feature extractor inherits from [`~feature_extraction_sequence_utils.SequenceFeatureExtractor`] which contains This feature extractor inherits from [`~feature_extraction_sequence_utils.SequenceFeatureExtractor`] which contains
most of the main methods. Users should refer to this superclass for more information regarding those methods. most of the main methods. Users should refer to this superclass for more information regarding those methods.
This class extracts mel-filter bank features from raw speech using TorchAudio, pads/truncates them to a fixed This class extracts mel-filter bank features from raw speech using TorchAudio if installed or using numpy
length and normalizes them using a mean and standard deviation. otherwise, pads/truncates them to a fixed length and normalizes them using a mean and standard deviation.
Args: Args:
feature_size (`int`, *optional*, defaults to 1): feature_size (`int`, *optional*, defaults to 1):
...@@ -83,6 +89,21 @@ class ASTFeatureExtractor(SequenceFeatureExtractor): ...@@ -83,6 +89,21 @@ class ASTFeatureExtractor(SequenceFeatureExtractor):
self.std = std self.std = std
self.return_attention_mask = return_attention_mask self.return_attention_mask = return_attention_mask
if not is_speech_available():
mel_filters = mel_filter_bank(
num_frequency_bins=256,
num_mel_filters=self.num_mel_bins,
min_frequency=20,
max_frequency=sampling_rate // 2,
sampling_rate=sampling_rate,
norm=None,
mel_scale="kaldi",
triangularize_in_mel_space=True,
)
self.mel_filters = np.pad(mel_filters, ((0, 1), (0, 0)))
self.window = window_function(400, "hann", periodic=False)
def _extract_fbank_features( def _extract_fbank_features(
self, self,
waveform: np.ndarray, waveform: np.ndarray,
...@@ -93,17 +114,32 @@ class ASTFeatureExtractor(SequenceFeatureExtractor): ...@@ -93,17 +114,32 @@ class ASTFeatureExtractor(SequenceFeatureExtractor):
and hence the waveform should not be normalized before feature extraction. and hence the waveform should not be normalized before feature extraction.
""" """
# waveform = waveform * (2**15) # Kaldi compliance: 16-bit signed integers # waveform = waveform * (2**15) # Kaldi compliance: 16-bit signed integers
if is_speech_available():
waveform = torch.from_numpy(waveform).unsqueeze(0) waveform = torch.from_numpy(waveform).unsqueeze(0)
fbank = ta_kaldi.fbank( fbank = ta_kaldi.fbank(
waveform, waveform,
htk_compat=True,
sample_frequency=self.sampling_rate, sample_frequency=self.sampling_rate,
use_energy=False,
window_type="hanning", window_type="hanning",
num_mel_bins=self.num_mel_bins, num_mel_bins=self.num_mel_bins,
dither=0.0,
frame_shift=10,
) )
else:
waveform = np.squeeze(waveform)
fbank = spectrogram(
waveform,
self.window,
frame_length=400,
hop_length=160,
fft_length=512,
power=2.0,
center=False,
preemphasis=0.97,
mel_filters=self.mel_filters,
log_mel="log",
mel_floor=1.192092955078125e-07,
remove_dc_offset=True,
).T
fbank = torch.from_numpy(fbank)
n_frames = fbank.shape[0] n_frames = fbank.shape[0]
difference = max_length - n_frames difference = max_length - n_frames
......
...@@ -17,7 +17,6 @@ from ...utils import ( ...@@ -17,7 +17,6 @@ from ...utils import (
OptionalDependencyNotAvailable, OptionalDependencyNotAvailable,
_LazyModule, _LazyModule,
is_sentencepiece_available, is_sentencepiece_available,
is_speech_available,
is_tf_available, is_tf_available,
is_torch_available, is_torch_available,
) )
...@@ -25,6 +24,7 @@ from ...utils import ( ...@@ -25,6 +24,7 @@ from ...utils import (
_import_structure = { _import_structure = {
"configuration_speech_to_text": ["SPEECH_TO_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP", "Speech2TextConfig"], "configuration_speech_to_text": ["SPEECH_TO_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP", "Speech2TextConfig"],
"feature_extraction_speech_to_text": ["Speech2TextFeatureExtractor"],
"processing_speech_to_text": ["Speech2TextProcessor"], "processing_speech_to_text": ["Speech2TextProcessor"],
} }
...@@ -36,14 +36,6 @@ except OptionalDependencyNotAvailable: ...@@ -36,14 +36,6 @@ except OptionalDependencyNotAvailable:
else: else:
_import_structure["tokenization_speech_to_text"] = ["Speech2TextTokenizer"] _import_structure["tokenization_speech_to_text"] = ["Speech2TextTokenizer"]
try:
if not is_speech_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["feature_extraction_speech_to_text"] = ["Speech2TextFeatureExtractor"]
try: try:
if not is_tf_available(): if not is_tf_available():
raise OptionalDependencyNotAvailable() raise OptionalDependencyNotAvailable()
...@@ -73,6 +65,7 @@ else: ...@@ -73,6 +65,7 @@ else:
if TYPE_CHECKING: if TYPE_CHECKING:
from .configuration_speech_to_text import SPEECH_TO_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP, Speech2TextConfig from .configuration_speech_to_text import SPEECH_TO_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP, Speech2TextConfig
from .feature_extraction_speech_to_text import Speech2TextFeatureExtractor
from .processing_speech_to_text import Speech2TextProcessor from .processing_speech_to_text import Speech2TextProcessor
try: try:
...@@ -83,14 +76,6 @@ if TYPE_CHECKING: ...@@ -83,14 +76,6 @@ if TYPE_CHECKING:
else: else:
from .tokenization_speech_to_text import Speech2TextTokenizer from .tokenization_speech_to_text import Speech2TextTokenizer
try:
if not is_speech_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .feature_extraction_speech_to_text import Speech2TextFeatureExtractor
try: try:
if not is_tf_available(): if not is_tf_available():
raise OptionalDependencyNotAvailable() raise OptionalDependencyNotAvailable()
......
...@@ -19,14 +19,17 @@ Feature extractor class for Speech2Text ...@@ -19,14 +19,17 @@ Feature extractor class for Speech2Text
from typing import List, Optional, Union from typing import List, Optional, Union
import numpy as np import numpy as np
import torch
import torchaudio.compliance.kaldi as ta_kaldi
from ...audio_utils import mel_filter_bank, spectrogram, window_function
from ...feature_extraction_sequence_utils import SequenceFeatureExtractor from ...feature_extraction_sequence_utils import SequenceFeatureExtractor
from ...feature_extraction_utils import BatchFeature from ...feature_extraction_utils import BatchFeature
from ...utils import PaddingStrategy, TensorType, logging from ...utils import PaddingStrategy, TensorType, is_speech_available, logging
if is_speech_available():
import torch
import torchaudio.compliance.kaldi as ta_kaldi
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
...@@ -37,8 +40,8 @@ class Speech2TextFeatureExtractor(SequenceFeatureExtractor): ...@@ -37,8 +40,8 @@ class Speech2TextFeatureExtractor(SequenceFeatureExtractor):
This feature extractor inherits from [`Speech2TextFeatureExtractor`] which contains most of the main methods. Users This feature extractor inherits from [`Speech2TextFeatureExtractor`] which contains most of the main methods. Users
should refer to this superclass for more information regarding those methods. should refer to this superclass for more information regarding those methods.
This class extracts mel-filter bank features from raw speech using TorchAudio and applies utterance-level cepstral This class extracts mel-filter bank features from raw speech using TorchAudio if installed or using numpy
mean and variance normalization to the extracted features. otherwise, and applies utterance-level cepstral mean and variance normalization to the extracted features.
Args: Args:
feature_size (`int`, *optional*, defaults to 80): feature_size (`int`, *optional*, defaults to 80):
...@@ -77,6 +80,21 @@ class Speech2TextFeatureExtractor(SequenceFeatureExtractor): ...@@ -77,6 +80,21 @@ class Speech2TextFeatureExtractor(SequenceFeatureExtractor):
self.normalize_vars = normalize_vars self.normalize_vars = normalize_vars
self.return_attention_mask = True self.return_attention_mask = True
if not is_speech_available():
mel_filters = mel_filter_bank(
num_frequency_bins=256,
num_mel_filters=self.num_mel_bins,
min_frequency=20,
max_frequency=sampling_rate // 2,
sampling_rate=sampling_rate,
norm=None,
mel_scale="kaldi",
triangularize_in_mel_space=True,
)
self.mel_filters = np.pad(mel_filters, ((0, 1), (0, 0)))
self.window = window_function(400, "povey", periodic=False)
def _extract_fbank_features( def _extract_fbank_features(
self, self,
waveform: np.ndarray, waveform: np.ndarray,
...@@ -86,9 +104,27 @@ class Speech2TextFeatureExtractor(SequenceFeatureExtractor): ...@@ -86,9 +104,27 @@ class Speech2TextFeatureExtractor(SequenceFeatureExtractor):
and hence the waveform should not be normalized before feature extraction. and hence the waveform should not be normalized before feature extraction.
""" """
waveform = waveform * (2**15) # Kaldi compliance: 16-bit signed integers waveform = waveform * (2**15) # Kaldi compliance: 16-bit signed integers
if is_speech_available():
waveform = torch.from_numpy(waveform).unsqueeze(0) waveform = torch.from_numpy(waveform).unsqueeze(0)
features = ta_kaldi.fbank(waveform, num_mel_bins=self.num_mel_bins, sample_frequency=self.sampling_rate) features = ta_kaldi.fbank(waveform, num_mel_bins=self.num_mel_bins, sample_frequency=self.sampling_rate)
return features.numpy() features = features.numpy()
else:
waveform = np.squeeze(waveform)
features = spectrogram(
waveform,
self.window,
frame_length=400,
hop_length=160,
fft_length=512,
power=2.0,
center=False,
preemphasis=0.97,
mel_filters=self.mel_filters,
log_mel="log",
mel_floor=1.192092955078125e-07,
remove_dc_offset=True,
).T
return features
@staticmethod @staticmethod
def utterance_cmvn( def utterance_cmvn(
......
...@@ -15,13 +15,15 @@ ...@@ -15,13 +15,15 @@
import itertools import itertools
import os
import random import random
import tempfile
import unittest import unittest
import numpy as np import numpy as np
from transformers import ASTFeatureExtractor from transformers import ASTFeatureExtractor
from transformers.testing_utils import require_torch, require_torchaudio from transformers.testing_utils import check_json_file_has_correct_format, require_torch, require_torchaudio
from transformers.utils.import_utils import is_torch_available from transformers.utils.import_utils import is_torch_available
from ...test_sequence_feature_extraction_common import SequenceFeatureExtractionTestMixin from ...test_sequence_feature_extraction_common import SequenceFeatureExtractionTestMixin
...@@ -173,3 +175,48 @@ class ASTFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.Test ...@@ -173,3 +175,48 @@ class ASTFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.Test
input_values = feature_extractor(input_speech, return_tensors="pt").input_values input_values = feature_extractor(input_speech, return_tensors="pt").input_values
self.assertEquals(input_values.shape, (1, 1024, 128)) self.assertEquals(input_values.shape, (1, 1024, 128))
self.assertTrue(torch.allclose(input_values[0, 0, :30], EXPECTED_INPUT_VALUES, atol=1e-4)) self.assertTrue(torch.allclose(input_values[0, 0, :30], EXPECTED_INPUT_VALUES, atol=1e-4))
def test_feat_extract_from_and_save_pretrained(self):
feat_extract_first = self.feature_extraction_class(**self.feat_extract_dict)
with tempfile.TemporaryDirectory() as tmpdirname:
saved_file = feat_extract_first.save_pretrained(tmpdirname)[0]
check_json_file_has_correct_format(saved_file)
feat_extract_second = self.feature_extraction_class.from_pretrained(tmpdirname)
dict_first = feat_extract_first.to_dict()
dict_second = feat_extract_second.to_dict()
self.assertDictEqual(dict_first, dict_second)
def test_feat_extract_to_json_file(self):
feat_extract_first = self.feature_extraction_class(**self.feat_extract_dict)
with tempfile.TemporaryDirectory() as tmpdirname:
json_file_path = os.path.join(tmpdirname, "feat_extract.json")
feat_extract_first.to_json_file(json_file_path)
feat_extract_second = self.feature_extraction_class.from_json_file(json_file_path)
dict_first = feat_extract_first.to_dict()
dict_second = feat_extract_second.to_dict()
self.assertEqual(dict_first, dict_second)
# exact same tests as before, except that we simulate that torchaudio is not available
@require_torch
@unittest.mock.patch(
"transformers.models.audio_spectrogram_transformer.feature_extraction_audio_spectrogram_transformer.is_speech_available",
lambda: False,
)
class ASTFeatureExtractionWithoutTorchaudioTest(ASTFeatureExtractionTest):
def test_using_audio_utils(self):
# Tests that it uses audio_utils instead of torchaudio
feat_extract = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
self.assertTrue(hasattr(feat_extract, "window"))
self.assertTrue(hasattr(feat_extract, "mel_filters"))
from transformers.models.audio_spectrogram_transformer.feature_extraction_audio_spectrogram_transformer import (
is_speech_available,
)
self.assertFalse(is_speech_available())
...@@ -15,20 +15,19 @@ ...@@ -15,20 +15,19 @@
import itertools import itertools
import os
import random import random
import tempfile
import unittest import unittest
import numpy as np import numpy as np
from transformers import is_speech_available from transformers import Speech2TextFeatureExtractor
from transformers.testing_utils import require_torch, require_torchaudio from transformers.testing_utils import check_json_file_has_correct_format, require_torch, require_torchaudio
from ...test_sequence_feature_extraction_common import SequenceFeatureExtractionTestMixin from ...test_sequence_feature_extraction_common import SequenceFeatureExtractionTestMixin
if is_speech_available():
from transformers import Speech2TextFeatureExtractor
global_rng = random.Random() global_rng = random.Random()
...@@ -105,7 +104,7 @@ class Speech2TextFeatureExtractionTester(unittest.TestCase): ...@@ -105,7 +104,7 @@ class Speech2TextFeatureExtractionTester(unittest.TestCase):
@require_torch @require_torch
@require_torchaudio @require_torchaudio
class Speech2TextFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.TestCase): class Speech2TextFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.TestCase):
feature_extraction_class = Speech2TextFeatureExtractor if is_speech_available() else None feature_extraction_class = Speech2TextFeatureExtractor
def setUp(self): def setUp(self):
self.feat_extract_tester = Speech2TextFeatureExtractionTester(self) self.feat_extract_tester = Speech2TextFeatureExtractionTester(self)
...@@ -280,3 +279,45 @@ class Speech2TextFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unitt ...@@ -280,3 +279,45 @@ class Speech2TextFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unitt
input_features = feature_extractor(input_speech, return_tensors="pt").input_features input_features = feature_extractor(input_speech, return_tensors="pt").input_features
self.assertEquals(input_features.shape, (1, 584, 24)) self.assertEquals(input_features.shape, (1, 584, 24))
self.assertTrue(np.allclose(input_features[0, 0, :30], expected, atol=1e-4)) self.assertTrue(np.allclose(input_features[0, 0, :30], expected, atol=1e-4))
def test_feat_extract_from_and_save_pretrained(self):
feat_extract_first = self.feature_extraction_class(**self.feat_extract_dict)
with tempfile.TemporaryDirectory() as tmpdirname:
saved_file = feat_extract_first.save_pretrained(tmpdirname)[0]
check_json_file_has_correct_format(saved_file)
feat_extract_second = self.feature_extraction_class.from_pretrained(tmpdirname)
dict_first = feat_extract_first.to_dict()
dict_second = feat_extract_second.to_dict()
self.assertDictEqual(dict_first, dict_second)
def test_feat_extract_to_json_file(self):
feat_extract_first = self.feature_extraction_class(**self.feat_extract_dict)
with tempfile.TemporaryDirectory() as tmpdirname:
json_file_path = os.path.join(tmpdirname, "feat_extract.json")
feat_extract_first.to_json_file(json_file_path)
feat_extract_second = self.feature_extraction_class.from_json_file(json_file_path)
dict_first = feat_extract_first.to_dict()
dict_second = feat_extract_second.to_dict()
self.assertEqual(dict_first, dict_second)
# exact same tests as before, except that we simulate that torchaudio is not available
@require_torch
@unittest.mock.patch(
"transformers.models.speech_to_text.feature_extraction_speech_to_text.is_speech_available", lambda: False
)
class Speech2TextFeatureExtractionWithoutTorchaudioTest(Speech2TextFeatureExtractionTest):
def test_using_audio_utils(self):
# Tests that it uses audio_utils instead of torchaudio
feat_extract = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
self.assertTrue(hasattr(feat_extract, "window"))
self.assertTrue(hasattr(feat_extract, "mel_filters"))
from transformers.models.speech_to_text.feature_extraction_speech_to_text import is_speech_available
self.assertFalse(is_speech_available())
...@@ -18,7 +18,7 @@ import unittest ...@@ -18,7 +18,7 @@ import unittest
from pathlib import Path from pathlib import Path
from shutil import copyfile from shutil import copyfile
from transformers import Speech2TextTokenizer, is_speech_available from transformers import Speech2TextFeatureExtractor, Speech2TextProcessor, Speech2TextTokenizer
from transformers.models.speech_to_text.tokenization_speech_to_text import VOCAB_FILES_NAMES, save_json from transformers.models.speech_to_text.tokenization_speech_to_text import VOCAB_FILES_NAMES, save_json
from transformers.testing_utils import get_tests_dir, require_sentencepiece, require_torch, require_torchaudio from transformers.testing_utils import get_tests_dir, require_sentencepiece, require_torch, require_torchaudio
from transformers.utils import FEATURE_EXTRACTOR_NAME from transformers.utils import FEATURE_EXTRACTOR_NAME
...@@ -26,10 +26,6 @@ from transformers.utils import FEATURE_EXTRACTOR_NAME ...@@ -26,10 +26,6 @@ from transformers.utils import FEATURE_EXTRACTOR_NAME
from .test_feature_extraction_speech_to_text import floats_list from .test_feature_extraction_speech_to_text import floats_list
if is_speech_available():
from transformers import Speech2TextFeatureExtractor, Speech2TextProcessor
SAMPLE_SP = get_tests_dir("fixtures/test_sentencepiece.model") SAMPLE_SP = get_tests_dir("fixtures/test_sentencepiece.model")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment