Unverified commit 4e945660, authored by Sanchit Gandhi and committed by GitHub

Fix audio feature extractor deps (#24636)

* Fix audio feature extractor deps

* Use the audio_utils window function instead of the torch window
parent cd4584e3
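
In short: the EnCodec, M-CTC-T, SpeechT5 and TVLT feature extractors no longer require the "speech" (torchaudio) backend. M-CTC-T and SpeechT5 now build their analysis window with window_function from transformers.audio_utils instead of torch, and all four extractors are registered in the base import structure. A minimal sketch of the window swap, using the same window name the M-CTC-T extractor stores in its config; sample_size is an arbitrary example value, not taken from a config:

    import numpy as np
    import torch
    from transformers.audio_utils import window_function

    sample_size = 400  # hypothetical frame length in samples

    # Old approach (removed below): build the window with torch, then convert to numpy.
    old_window = torch.hamming_window(sample_size, periodic=False, alpha=0.54, beta=0.46).numpy()

    # New approach: pure numpy via audio_utils, no torch needed to build the window.
    new_window = window_function(window_length=sample_size, name="hamming_window", periodic=False)

    print(np.allclose(old_window, new_window, atol=1e-6))  # expected: True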
@@ -285,6 +285,7 @@ _import_structure = {
     "models.encodec": [
         "ENCODEC_PRETRAINED_CONFIG_ARCHIVE_MAP",
         "EncodecConfig",
+        "EncodecFeatureExtractor",
     ],
     "models.encoder_decoder": ["EncoderDecoderConfig"],
     "models.ernie": [
@@ -388,7 +389,7 @@ _import_structure = {
     "models.maskformer": ["MASKFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "MaskFormerConfig", "MaskFormerSwinConfig"],
     "models.mbart": ["MBartConfig"],
     "models.mbart50": [],
-    "models.mctct": ["MCTCT_PRETRAINED_CONFIG_ARCHIVE_MAP", "MCTCTConfig", "MCTCTProcessor"],
+    "models.mctct": ["MCTCT_PRETRAINED_CONFIG_ARCHIVE_MAP", "MCTCTConfig", "MCTCTFeatureExtractor", "MCTCTProcessor"],
     "models.mega": ["MEGA_PRETRAINED_CONFIG_ARCHIVE_MAP", "MegaConfig"],
     "models.megatron_bert": ["MEGATRON_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "MegatronBertConfig"],
     "models.megatron_gpt2": [],
@@ -481,6 +482,7 @@ _import_structure = {
         "SPEECHT5_PRETRAINED_CONFIG_ARCHIVE_MAP",
         "SPEECHT5_PRETRAINED_HIFIGAN_CONFIG_ARCHIVE_MAP",
         "SpeechT5Config",
+        "SpeechT5FeatureExtractor",
         "SpeechT5HifiGanConfig",
         "SpeechT5Processor",
     ],
@@ -519,6 +521,7 @@ _import_structure = {
     "models.tvlt": [
         "TVLT_PRETRAINED_CONFIG_ARCHIVE_MAP",
         "TvltConfig",
+        "TvltFeatureExtractor",
         "TvltProcessor",
     ],
     "models.umt5": [],
@@ -843,11 +846,7 @@ except OptionalDependencyNotAvailable:
     ]
 else:
     _import_structure["models.audio_spectrogram_transformer"].append("ASTFeatureExtractor")
-    _import_structure["models.encodec"].append("EncodecFeatureExtractor")
-    _import_structure["models.mctct"].append("MCTCTFeatureExtractor")
     _import_structure["models.speech_to_text"].append("Speech2TextFeatureExtractor")
-    _import_structure["models.speecht5"].append("SpeechT5FeatureExtractor")
-    _import_structure["models.tvlt"].append("TvltFeatureExtractor")

 # Tensorflow-text-specific objects
 try:
@@ -4170,6 +4169,7 @@ if TYPE_CHECKING:
     from .models.encodec import (
         ENCODEC_PRETRAINED_CONFIG_ARCHIVE_MAP,
         EncodecConfig,
+        EncodecFeatureExtractor,
     )
     from .models.encoder_decoder import EncoderDecoderConfig
     from .models.ernie import ERNIE_PRETRAINED_CONFIG_ARCHIVE_MAP, ErnieConfig
@@ -4265,7 +4265,7 @@ if TYPE_CHECKING:
     from .models.mask2former import MASK2FORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, Mask2FormerConfig
     from .models.maskformer import MASKFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, MaskFormerConfig, MaskFormerSwinConfig
     from .models.mbart import MBartConfig
-    from .models.mctct import MCTCT_PRETRAINED_CONFIG_ARCHIVE_MAP, MCTCTConfig, MCTCTProcessor
+    from .models.mctct import MCTCT_PRETRAINED_CONFIG_ARCHIVE_MAP, MCTCTConfig, MCTCTFeatureExtractor, MCTCTProcessor
     from .models.mega import MEGA_PRETRAINED_CONFIG_ARCHIVE_MAP, MegaConfig
     from .models.megatron_bert import MEGATRON_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, MegatronBertConfig
     from .models.mgp_str import MGP_STR_PRETRAINED_CONFIG_ARCHIVE_MAP, MgpstrConfig, MgpstrProcessor, MgpstrTokenizer
@@ -4355,6 +4355,7 @@ if TYPE_CHECKING:
         SPEECHT5_PRETRAINED_CONFIG_ARCHIVE_MAP,
         SPEECHT5_PRETRAINED_HIFIGAN_CONFIG_ARCHIVE_MAP,
         SpeechT5Config,
+        SpeechT5FeatureExtractor,
         SpeechT5HifiGanConfig,
         SpeechT5Processor,
     )
@@ -4386,7 +4387,7 @@ if TYPE_CHECKING:
         TransfoXLTokenizer,
     )
     from .models.trocr import TROCR_PRETRAINED_CONFIG_ARCHIVE_MAP, TrOCRConfig, TrOCRProcessor
-    from .models.tvlt import TVLT_PRETRAINED_CONFIG_ARCHIVE_MAP, TvltConfig, TvltProcessor
+    from .models.tvlt import TVLT_PRETRAINED_CONFIG_ARCHIVE_MAP, TvltConfig, TvltFeatureExtractor, TvltProcessor
     from .models.unispeech import UNISPEECH_PRETRAINED_CONFIG_ARCHIVE_MAP, UniSpeechConfig
     from .models.unispeech_sat import UNISPEECH_SAT_PRETRAINED_CONFIG_ARCHIVE_MAP, UniSpeechSatConfig
     from .models.upernet import UperNetConfig
@@ -4681,11 +4682,7 @@ if TYPE_CHECKING:
         from .utils.dummy_speech_objects import *
     else:
         from .models.audio_spectrogram_transformer import ASTFeatureExtractor
-        from .models.encodec import EncodecFeatureExtractor
-        from .models.mctct import MCTCTFeatureExtractor
         from .models.speech_to_text import Speech2TextFeatureExtractor
-        from .models.speecht5 import SpeechT5FeatureExtractor
-        from .models.tvlt import TvltFeatureExtractor

     try:
         if not is_tensorflow_text_available():
...
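The hunks above move the four feature extractors out of the speech-gated branch of the lazy-import machinery and into the base _import_structure. For reference, a compressed sketch of that guard pattern, illustrative rather than a verbatim excerpt of the file:

    from transformers.utils import OptionalDependencyNotAvailable, is_speech_available

    _import_structure = {"models.speech_to_text": ["Speech2TextConfig"]}

    try:
        if not is_speech_available():  # torchaudio not installed
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass  # the gated name then resolves to a placeholder from dummy_speech_objects
    else:
        # Only names appended here require torchaudio at import time.
        _import_structure["models.speech_to_text"].append("Speech2TextFeatureExtractor")

After this commit only ASTFeatureExtractor and Speech2TextFeatureExtractor remain behind that gate; the other four resolve to their real classes even when torchaudio is absent.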
@@ -13,24 +13,16 @@
 # limitations under the License.
 from typing import TYPE_CHECKING

-from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_speech_available, is_torch_available
+from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available


 _import_structure = {
     "configuration_mctct": ["MCTCT_PRETRAINED_CONFIG_ARCHIVE_MAP", "MCTCTConfig"],
+    "feature_extraction_mctct": ["MCTCTFeatureExtractor"],
     "processing_mctct": ["MCTCTProcessor"],
 }

-try:
-    if not is_speech_available():
-        raise OptionalDependencyNotAvailable()
-except OptionalDependencyNotAvailable:
-    pass
-else:
-    _import_structure["feature_extraction_mctct"] = ["MCTCTFeatureExtractor"]
-
 try:
     if not is_torch_available():
         raise OptionalDependencyNotAvailable()
@@ -47,16 +39,9 @@ else:
 if TYPE_CHECKING:
     from .configuration_mctct import MCTCT_PRETRAINED_CONFIG_ARCHIVE_MAP, MCTCTConfig
+    from .feature_extraction_mctct import MCTCTFeatureExtractor
     from .processing_mctct import MCTCTProcessor

-    try:
-        if not is_speech_available():
-            raise OptionalDependencyNotAvailable()
-    except OptionalDependencyNotAvailable:
-        pass
-    else:
-        from .feature_extraction_mctct import MCTCTFeatureExtractor
-
     try:
         if not is_torch_available():
             raise OptionalDependencyNotAvailable()
...
@@ -19,9 +19,8 @@ Feature extractor class for M-CTC-T
 from typing import List, Optional, Union

 import numpy as np
-import torch

-from ...audio_utils import mel_filter_bank, optimal_fft_length, spectrogram
+from ...audio_utils import mel_filter_bank, optimal_fft_length, spectrogram, window_function
 from ...feature_extraction_sequence_utils import SequenceFeatureExtractor
 from ...feature_extraction_utils import BatchFeature
 from ...file_utils import PaddingStrategy, TensorType
@@ -110,11 +109,9 @@ class MCTCTFeatureExtractor(SequenceFeatureExtractor):
         Extracts MFSC Features for one waveform vector (unbatched). Adapted from Flashlight's C++ MFSC code.
         """
         if self.win_function == "hamming_window":
-            window = torch.hamming_window(window_length=self.sample_size, periodic=False, alpha=0.54, beta=0.46)
+            window = window_function(window_length=self.sample_size, name=self.win_function, periodic=False)
         else:
-            window = getattr(torch, self.win_function)()
-
-        window = window.numpy()
+            window = window_function(window_length=self.sample_size, name=self.win_function)

         fbanks = mel_filter_bank(
             num_frequency_bins=self.n_freqs,
...
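With the torch import gone, the MFSC pipeline in this file runs on numpy alone, so the extractor can be constructed and called without the "speech" backend. A hedged usage sketch; the no-argument constructor leans on what are believed to be the class defaults, so pass explicit values if yours differ:

    import numpy as np
    from transformers import MCTCTFeatureExtractor

    feature_extractor = MCTCTFeatureExtractor()  # assumed-default construction

    audio = np.zeros(16000, dtype=np.float32)  # one second of silence at 16 kHz
    inputs = feature_extractor(audio, sampling_rate=16000, return_tensors="np")
    print({name: array.shape for name, array in inputs.items()})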
@@ -17,7 +17,6 @@ from ...utils import (
     OptionalDependencyNotAvailable,
     _LazyModule,
     is_sentencepiece_available,
-    is_speech_available,
     is_torch_available,
 )
@@ -29,6 +28,7 @@ _import_structure = {
         "SpeechT5Config",
         "SpeechT5HifiGanConfig",
     ],
+    "feature_extraction_speecht5": ["SpeechT5FeatureExtractor"],
     "processing_speecht5": ["SpeechT5Processor"],
 }
@@ -40,14 +40,6 @@ except OptionalDependencyNotAvailable:
 else:
     _import_structure["tokenization_speecht5"] = ["SpeechT5Tokenizer"]

-try:
-    if not is_speech_available():
-        raise OptionalDependencyNotAvailable()
-except OptionalDependencyNotAvailable:
-    pass
-else:
-    _import_structure["feature_extraction_speecht5"] = ["SpeechT5FeatureExtractor"]
-
 try:
     if not is_torch_available():
         raise OptionalDependencyNotAvailable()
@@ -71,6 +63,7 @@ if TYPE_CHECKING:
         SpeechT5Config,
         SpeechT5HifiGanConfig,
     )
+    from .feature_extraction_speecht5 import SpeechT5FeatureExtractor
     from .processing_speecht5 import SpeechT5Processor

     try:
@@ -81,14 +74,6 @@ if TYPE_CHECKING:
     else:
         from .tokenization_speecht5 import SpeechT5Tokenizer

-    try:
-        if not is_speech_available():
-            raise OptionalDependencyNotAvailable()
-    except OptionalDependencyNotAvailable:
-        pass
-    else:
-        from .feature_extraction_speecht5 import SpeechT5FeatureExtractor
-
     try:
         if not is_torch_available():
             raise OptionalDependencyNotAvailable()
...
@@ -18,9 +18,8 @@ import warnings
 from typing import Any, Dict, List, Optional, Union

 import numpy as np
-import torch

-from ...audio_utils import mel_filter_bank, optimal_fft_length, spectrogram
+from ...audio_utils import mel_filter_bank, optimal_fft_length, spectrogram, window_function
 from ...feature_extraction_sequence_utils import SequenceFeatureExtractor
 from ...feature_extraction_utils import BatchFeature
 from ...utils import PaddingStrategy, TensorType, logging
@@ -113,8 +112,7 @@ class SpeechT5FeatureExtractor(SequenceFeatureExtractor):
         self.n_fft = optimal_fft_length(self.sample_size)
         self.n_freqs = (self.n_fft // 2) + 1

-        window = getattr(torch, self.win_function)(window_length=self.sample_size, periodic=True)
-        self.window = window.numpy().astype(np.float64)
+        self.window = window_function(window_length=self.sample_size, name=self.win_function, periodic=True)

         self.mel_filters = mel_filter_bank(
             num_frequency_bins=self.n_freqs,
...
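The same swap in the SpeechT5 extractor means its whole log-mel path now goes through audio_utils. A numpy-only sketch of that pipeline, roughly mirroring what the extractor does internally; the frame, hop and mel-filter parameters below are illustrative guesses, not values read from the SpeechT5 config:

    import numpy as np
    from transformers.audio_utils import mel_filter_bank, optimal_fft_length, spectrogram, window_function

    sampling_rate = 16000
    sample_size = 400    # 25 ms frames (assumed)
    sample_stride = 160  # 10 ms hop (assumed)
    n_fft = optimal_fft_length(sample_size)

    window = window_function(window_length=sample_size, name="hann_window", periodic=True)
    mel_filters = mel_filter_bank(
        num_frequency_bins=(n_fft // 2) + 1,
        num_mel_filters=80,
        min_frequency=80.0,
        max_frequency=7600.0,
        sampling_rate=sampling_rate,
    )

    waveform = np.random.randn(sampling_rate).astype(np.float64)  # 1 s of noise
    log_mel = spectrogram(
        waveform,
        window,
        frame_length=sample_size,
        hop_length=sample_stride,
        fft_length=n_fft,
        mel_filters=mel_filters,
        log_mel="log10",
    )
    print(log_mel.shape)  # 2-D log-mel feature matrix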
@@ -20,7 +20,6 @@ from typing import TYPE_CHECKING
 from ...utils import (
     OptionalDependencyNotAvailable,
     _LazyModule,
-    is_speech_available,
     is_torch_available,
     is_vision_available,
 )
@@ -28,6 +27,7 @@ from ...utils import (
 _import_structure = {
     "configuration_tvlt": ["TVLT_PRETRAINED_CONFIG_ARCHIVE_MAP", "TvltConfig"],
+    "feature_extraction_tvlt": ["TvltFeatureExtractor"],
     "processing_tvlt": ["TvltProcessor"],
 }
@@ -53,17 +53,11 @@ except OptionalDependencyNotAvailable:
 else:
     _import_structure["image_processing_tvlt"] = ["TvltImageProcessor"]

-try:
-    if not is_speech_available():
-        raise OptionalDependencyNotAvailable()
-except OptionalDependencyNotAvailable:
-    pass
-else:
-    _import_structure["feature_extraction_tvlt"] = ["TvltFeatureExtractor"]
-
 if TYPE_CHECKING:
     from .configuration_tvlt import TVLT_PRETRAINED_CONFIG_ARCHIVE_MAP, TvltConfig
     from .processing_tvlt import TvltProcessor
+    from .feature_extraction_tvlt import TvltFeatureExtractor

     try:
         if not is_torch_available():
@@ -87,13 +81,6 @@ if TYPE_CHECKING:
     else:
         from .image_processing_tvlt import TvltImageProcessor

-    try:
-        if not is_speech_available():
-            raise OptionalDependencyNotAvailable()
-    except OptionalDependencyNotAvailable:
-        pass
-    else:
-        from .feature_extraction_tvlt import TvltFeatureExtractor

 else:
     import sys
...
@@ -9,36 +9,8 @@ class ASTFeatureExtractor(metaclass=DummyObject):
         requires_backends(self, ["speech"])


-class EncodecFeatureExtractor(metaclass=DummyObject):
-    _backends = ["speech"]
-
-    def __init__(self, *args, **kwargs):
-        requires_backends(self, ["speech"])
-
-
-class MCTCTFeatureExtractor(metaclass=DummyObject):
-    _backends = ["speech"]
-
-    def __init__(self, *args, **kwargs):
-        requires_backends(self, ["speech"])
-
-
 class Speech2TextFeatureExtractor(metaclass=DummyObject):
     _backends = ["speech"]

     def __init__(self, *args, **kwargs):
         requires_backends(self, ["speech"])
-
-
-class SpeechT5FeatureExtractor(metaclass=DummyObject):
-    _backends = ["speech"]
-
-    def __init__(self, *args, **kwargs):
-        requires_backends(self, ["speech"])
-
-
-class TvltFeatureExtractor(metaclass=DummyObject):
-    _backends = ["speech"]
-
-    def __init__(self, *args, **kwargs):
-        requires_backends(self, ["speech"])
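These deletions mean the four classes no longer fall back to a placeholder when torchaudio is missing. For context, a sketch of the dummy-object pattern the remaining entries in dummy_speech_objects still follow; the class name below is hypothetical:

    from transformers.utils import DummyObject, requires_backends


    class SomeSpeechFeatureExtractor(metaclass=DummyObject):
        _backends = ["speech"]

        def __init__(self, *args, **kwargs):
            # If torchaudio is absent, this raises an ImportError telling the
            # user which backend to install as soon as the placeholder is used.
            requires_backends(self, ["speech"])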
@@ -20,16 +20,13 @@ import unittest
 import numpy as np

-from transformers import is_speech_available
+from transformers import EncodecFeatureExtractor
 from transformers.testing_utils import require_torch
 from transformers.utils.import_utils import is_torch_available

 from ...test_sequence_feature_extraction_common import SequenceFeatureExtractionTestMixin


-if is_speech_available():
-    from transformers import EncodecFeatureExtractor
-
 if is_torch_available():
     import torch

@@ -103,7 +100,7 @@ class EnCodecFeatureExtractionTester(unittest.TestCase):
 @require_torch
 class EnCodecFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.TestCase):
-    feature_extraction_class = EncodecFeatureExtractor if is_speech_available() else None
+    feature_extraction_class = EncodecFeatureExtractor

     def setUp(self):
         self.feat_extract_tester = EnCodecFeatureExtractionTester(self)
...
@@ -20,15 +20,12 @@ import unittest
 import numpy as np

-from transformers import is_speech_available
+from transformers import MCTCTFeatureExtractor
 from transformers.testing_utils import require_torch

 from ...test_sequence_feature_extraction_common import SequenceFeatureExtractionTestMixin


-if is_speech_available():
-    from transformers import MCTCTFeatureExtractor
-
 global_rng = random.Random()

@@ -102,7 +99,7 @@ class MCTCTFeatureExtractionTester(unittest.TestCase):
 @require_torch
 class MCTCTFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.TestCase):
-    feature_extraction_class = MCTCTFeatureExtractor if is_speech_available() else None
+    feature_extraction_class = MCTCTFeatureExtractor

     def setUp(self):
         self.feat_extract_tester = MCTCTFeatureExtractionTester(self)
...
@@ -20,16 +20,13 @@ import unittest
 import numpy as np

-from transformers import BatchFeature, is_speech_available
+from transformers import BatchFeature, SpeechT5FeatureExtractor
 from transformers.testing_utils import require_torch
 from transformers.utils.import_utils import is_torch_available

 from ...test_sequence_feature_extraction_common import SequenceFeatureExtractionTestMixin


-if is_speech_available():
-    from transformers import SpeechT5FeatureExtractor
-
 if is_torch_available():
     import torch

@@ -142,7 +139,7 @@ class SpeechT5FeatureExtractionTester(unittest.TestCase):
 @require_torch
 class SpeechT5FeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.TestCase):
-    feature_extraction_class = SpeechT5FeatureExtractor if is_speech_available() else None
+    feature_extraction_class = SpeechT5FeatureExtractor

     def setUp(self):
         self.feat_extract_tester = SpeechT5FeatureExtractionTester(self)
...
@@ -22,7 +22,7 @@ import unittest
 import numpy as np

-from transformers import is_datasets_available, is_speech_available
+from transformers import TvltFeatureExtractor, is_datasets_available
 from transformers.testing_utils import check_json_file_has_correct_format, require_torch, require_torchaudio
 from transformers.utils.import_utils import is_torch_available

@@ -35,9 +35,6 @@ if is_torch_available():
 if is_datasets_available():
     from datasets import load_dataset

-if is_speech_available():
-    from transformers import TvltFeatureExtractor
-

 global_rng = random.Random()

@@ -111,7 +108,7 @@ class TvltFeatureExtractionTester(unittest.TestCase):
 @require_torch
 @require_torchaudio
 class TvltFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.TestCase):
-    feature_extraction_class = TvltFeatureExtractor if is_speech_available() else None
+    feature_extraction_class = TvltFeatureExtractor

     def setUp(self):
         self.feat_extract_tester = TvltFeatureExtractionTester(self)
...