"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "45e11091e5332809972aa64148ac0a42aa775841"
Unverified Commit 48fa42e5 authored by Patrick von Platen, committed by GitHub

Add Speech AutoModels (#13655)

* upload

* correct

* correct

* correct

* finish

* up

* up

* up again
parent ea921365
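For context, here is a minimal usage sketch of the two auto classes this commit adds. `facebook/wav2vec2-base-960h` is the pipeline default registered in this diff; the Speech2Text checkpoint name is an illustrative assumption, not taken from the diff:

```python
from transformers import AutoModelForCTC, AutoModelForSpeechSeq2Seq

# Resolved through MODEL_FOR_CTC_MAPPING, e.g. Wav2Vec2Config -> Wav2Vec2ForCTC.
ctc_model = AutoModelForCTC.from_pretrained("facebook/wav2vec2-base-960h")

# Resolved through MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING, e.g.
# Speech2TextConfig -> Speech2TextForConditionalGeneration.
# The checkpoint below is an assumed example, not part of this commit.
seq2seq_model = AutoModelForSpeechSeq2Seq.from_pretrained("facebook/s2t-small-librispeech-asr")
```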
@@ -142,6 +142,20 @@ AutoModelForAudioClassification
:members:
AutoModelForCTC
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.AutoModelForCTC
:members:
AutoModelForSpeechSeq2Seq
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.AutoModelForSpeechSeq2Seq
:members:
AutoModelForObjectDetection
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
...
@@ -557,6 +557,7 @@ if is_torch_available():
"AutoModel",
"AutoModelForAudioClassification",
"AutoModelForCausalLM",
"AutoModelForCTC",
"AutoModelForImageClassification", "AutoModelForImageClassification",
"AutoModelForMaskedLM", "AutoModelForMaskedLM",
"AutoModelForMultipleChoice", "AutoModelForMultipleChoice",
...@@ -566,6 +567,7 @@ if is_torch_available(): ...@@ -566,6 +567,7 @@ if is_torch_available():
"AutoModelForQuestionAnswering", "AutoModelForQuestionAnswering",
"AutoModelForSeq2SeqLM", "AutoModelForSeq2SeqLM",
"AutoModelForSequenceClassification", "AutoModelForSequenceClassification",
"AutoModelForSpeechSeq2Seq",
"AutoModelForTableQuestionAnswering", "AutoModelForTableQuestionAnswering",
"AutoModelForTokenClassification", "AutoModelForTokenClassification",
"AutoModelWithLMHead", "AutoModelWithLMHead",
...@@ -2320,6 +2322,7 @@ if TYPE_CHECKING: ...@@ -2320,6 +2322,7 @@ if TYPE_CHECKING:
AutoModel, AutoModel,
AutoModelForAudioClassification, AutoModelForAudioClassification,
AutoModelForCausalLM, AutoModelForCausalLM,
AutoModelForCTC,
AutoModelForImageClassification,
AutoModelForMaskedLM,
AutoModelForMultipleChoice,
@@ -2329,6 +2332,7 @@ if TYPE_CHECKING:
AutoModelForQuestionAnswering,
AutoModelForSeq2SeqLM,
AutoModelForSequenceClassification,
AutoModelForSpeechSeq2Seq,
AutoModelForTableQuestionAnswering,
AutoModelForTokenClassification,
AutoModelWithLMHead,
...
@@ -32,6 +32,7 @@ if is_torch_available():
_import_structure["modeling_auto"] = [
"MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING",
"MODEL_FOR_CAUSAL_LM_MAPPING",
"MODEL_FOR_CTC_MAPPING",
"MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING", "MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING",
"MODEL_FOR_MASKED_LM_MAPPING", "MODEL_FOR_MASKED_LM_MAPPING",
"MODEL_FOR_MULTIPLE_CHOICE_MAPPING", "MODEL_FOR_MULTIPLE_CHOICE_MAPPING",
...@@ -41,6 +42,7 @@ if is_torch_available(): ...@@ -41,6 +42,7 @@ if is_torch_available():
"MODEL_FOR_QUESTION_ANSWERING_MAPPING", "MODEL_FOR_QUESTION_ANSWERING_MAPPING",
"MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING", "MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING",
"MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING", "MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING",
"MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING",
"MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING", "MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING",
"MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING", "MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING",
"MODEL_MAPPING", "MODEL_MAPPING",
...@@ -48,6 +50,7 @@ if is_torch_available(): ...@@ -48,6 +50,7 @@ if is_torch_available():
"AutoModel", "AutoModel",
"AutoModelForAudioClassification", "AutoModelForAudioClassification",
"AutoModelForCausalLM", "AutoModelForCausalLM",
"AutoModelForCTC",
"AutoModelForImageClassification", "AutoModelForImageClassification",
"AutoModelForMaskedLM", "AutoModelForMaskedLM",
"AutoModelForMultipleChoice", "AutoModelForMultipleChoice",
...@@ -57,6 +60,7 @@ if is_torch_available(): ...@@ -57,6 +60,7 @@ if is_torch_available():
"AutoModelForQuestionAnswering", "AutoModelForQuestionAnswering",
"AutoModelForSeq2SeqLM", "AutoModelForSeq2SeqLM",
"AutoModelForSequenceClassification", "AutoModelForSequenceClassification",
"AutoModelForSpeechSeq2Seq",
"AutoModelForTableQuestionAnswering", "AutoModelForTableQuestionAnswering",
"AutoModelForTokenClassification", "AutoModelForTokenClassification",
"AutoModelWithLMHead", "AutoModelWithLMHead",
...@@ -124,6 +128,7 @@ if TYPE_CHECKING: ...@@ -124,6 +128,7 @@ if TYPE_CHECKING:
from .modeling_auto import ( from .modeling_auto import (
MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING, MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING,
MODEL_FOR_CAUSAL_LM_MAPPING, MODEL_FOR_CAUSAL_LM_MAPPING,
MODEL_FOR_CTC_MAPPING,
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
MODEL_FOR_MASKED_LM_MAPPING,
MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
@@ -133,6 +138,7 @@ if TYPE_CHECKING:
MODEL_FOR_QUESTION_ANSWERING_MAPPING,
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING,
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
MODEL_MAPPING,
@@ -140,6 +146,7 @@ if TYPE_CHECKING:
AutoModel,
AutoModelForAudioClassification,
AutoModelForCausalLM,
AutoModelForCTC,
AutoModelForImageClassification,
AutoModelForMaskedLM,
AutoModelForMultipleChoice,
@@ -149,6 +156,7 @@ if TYPE_CHECKING:
AutoModelForQuestionAnswering,
AutoModelForSeq2SeqLM,
AutoModelForSequenceClassification,
AutoModelForSpeechSeq2Seq,
AutoModelForTableQuestionAnswering,
AutoModelForTokenClassification,
AutoModelWithLMHead,
...
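Both `__init__.py` hunks above follow the repository's lazy-import convention: every public name is listed in `_import_structure` so it is only imported on first attribute access, and mirrored under `TYPE_CHECKING` so static analyzers still see real imports. A simplified sketch of the mechanism (PEP 562 module-level `__getattr__`; the library's actual implementation is the `_LazyModule` class):

```python
import importlib
from typing import TYPE_CHECKING

_import_structure = {"modeling_auto": ["AutoModelForCTC", "AutoModelForSpeechSeq2Seq"]}

if TYPE_CHECKING:
    # Type checkers resolve the symbols eagerly...
    from .modeling_auto import AutoModelForCTC, AutoModelForSpeechSeq2Seq
else:
    def __getattr__(name):
        # ...while at runtime the submodule is imported only when first needed.
        for module_name, names in _import_structure.items():
            if name in names:
                module = importlib.import_module(f".{module_name}", __name__)
                return getattr(module, name)
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
```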
@@ -291,6 +291,13 @@ MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
]
)
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES = OrderedDict(
[
("speech-encoder-decoder", "SpeechEncoderDecoderModel"),
("speech_to_text", "Speech2TextForConditionalGeneration"),
]
)
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
[
# Model for Sequence Classification mapping
@@ -462,6 +469,14 @@ MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
]
)
MODEL_FOR_CTC_MAPPING_NAMES = OrderedDict(
[
# Model for Connectionist Temporal Classification (CTC) mapping
("wav2vec2", "Wav2Vec2ForCTC"),
("hubert", "HubertForCTC"),
]
)
MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_MAPPING_NAMES)
MODEL_FOR_PRETRAINING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_PRETRAINING_MAPPING_NAMES)
MODEL_WITH_LM_HEAD_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_WITH_LM_HEAD_MAPPING_NAMES)
@@ -493,6 +508,8 @@ MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = _LazyAutoMapping(
MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES
)
MODEL_FOR_CTC_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_CTC_MAPPING_NAMES)
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES)
class AutoModel(_BaseAutoModelClass):
@@ -611,6 +628,22 @@ class AutoModelForAudioClassification(_BaseAutoModelClass):
AutoModelForAudioClassification = auto_class_update(AutoModelForAudioClassification, head_doc="audio classification")
class AutoModelForCTC(_BaseAutoModelClass):
_model_mapping = MODEL_FOR_CTC_MAPPING
AutoModelForCTC = auto_class_update(AutoModelForCTC, head_doc="connectionist temporal classification")
class AutoModelForSpeechSeq2Seq(_BaseAutoModelClass):
_model_mapping = MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING
AutoModelForSpeechSeq2Seq = auto_class_update(
AutoModelForSpeechSeq2Seq, head_doc="sequence-to-sequence speech-to-text modeling"
)
class AutoModelWithLMHead(_AutoModelWithLMHead):
@classmethod
def from_config(cls, config):
...
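The two new classes only declare a `_model_mapping`; loading logic is inherited from `_BaseAutoModelClass`, and the docstring comes from `auto_class_update`. Roughly, dispatch works like the hypothetical helper below (ignoring `_LazyAutoMapping`'s deferred imports and the extra kwargs plumbing):

```python
from transformers import AutoConfig

def sketch_from_pretrained(auto_cls, checkpoint, **kwargs):
    # The checkpoint's config type selects the concrete class, e.g. a
    # Wav2Vec2Config keys into MODEL_FOR_CTC_MAPPING and yields Wav2Vec2ForCTC.
    config = AutoConfig.from_pretrained(checkpoint)
    model_class = auto_cls._model_mapping[type(config)]
    return model_class.from_pretrained(checkpoint, config=config, **kwargs)
```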
@@ -90,12 +90,14 @@ if is_torch_available():
AutoModel,
AutoModelForAudioClassification,
AutoModelForCausalLM,
AutoModelForCTC,
AutoModelForImageClassification,
AutoModelForMaskedLM,
AutoModelForObjectDetection,
AutoModelForQuestionAnswering,
AutoModelForSeq2SeqLM,
AutoModelForSequenceClassification,
AutoModelForSpeechSeq2Seq,
AutoModelForTableQuestionAnswering,
AutoModelForTokenClassification,
)
@@ -121,9 +123,7 @@ SUPPORTED_TASKS = {
"automatic-speech-recognition": {
"impl": AutomaticSpeechRecognitionPipeline,
"tf": (),
- # Only load from `config.architectures`, AutoModelForCTC and AutoModelForConditionalGeneration
- # do not exist yet.
- "pt": () if is_torch_available() else (),
+ "pt": (AutoModelForCTC, AutoModelForSpeechSeq2Seq) if is_torch_available() else (),
"default": {"model": {"pt": "facebook/wav2vec2-base-960h"}}, "default": {"model": {"pt": "facebook/wav2vec2-base-960h"}},
}, },
"feature-extraction": { "feature-extraction": {
......
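With the `pt` tuple populated, the task can instantiate models through the new auto classes instead of refusing to load anything. For example (the audio path is illustrative; output follows the pipeline's `{"text": ...}` format):

```python
from transformers import pipeline

# With no model argument, this falls back to the default declared above,
# facebook/wav2vec2-base-960h, loaded via AutoModelForCTC.
asr = pipeline("automatic-speech-recognition")
print(asr("sample.flac"))  # e.g. {'text': 'HELLO WORLD'}
```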
@@ -16,6 +16,7 @@ from typing import TYPE_CHECKING, Union
import numpy as np
from ..file_utils import is_torch_available
from ..utils import logging
from .base import Pipeline
@@ -25,6 +26,9 @@ if TYPE_CHECKING:
logger = logging.get_logger(__name__)
if is_torch_available():
from ..models.auto.modeling_auto import MODEL_FOR_CTC_MAPPING, MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING
def ffmpeg_read(bpayload: bytes, sampling_rate: int) -> np.array:
"""
@@ -102,6 +106,8 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
if self.framework == "tf":
raise ValueError("The AutomaticSpeechRecognitionPipeline is only available in PyTorch.")
self.check_model_type(MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING.items() + MODEL_FOR_CTC_MAPPING.items())
def __call__(
self,
inputs: Union[np.ndarray, bytes, str],
@@ -149,8 +155,8 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
return processed
def _forward(self, model_inputs):
- name = self.model.__class__.__name__
- if name.endswith("ForConditionalGeneration") or name.endswith("EncoderDecoderModel"):
+ model_class = self.model.__class__
+ if model_class in MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING.values():
encoder = self.model.get_encoder()
# we need to pass `processed.get("attention_mask")` here since audio encoder
# attention mask length is different from expected text decoder `encoder_attention_mask` length
@@ -160,7 +166,7 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
encoder_outputs=encoder(**model_inputs), attention_mask=model_inputs.get("attention_mask")
)
tokens = tokens.squeeze(0)
elif name.endswith("ForCTC"): elif model_class in MODEL_FOR_CTC_MAPPING.values():
outputs = self.model(**model_inputs)
tokens = outputs.logits.squeeze(0).argmax(dim=-1)
return tokens
...
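Dispatching on membership in the mapping values is more robust than the old suffix check on the class name: any class can happen to end in `ForCTC` without being a registered CTC architecture. A hypothetical illustration (importing the mapping requires torch):

```python
from transformers.models.auto.modeling_auto import MODEL_FOR_CTC_MAPPING

class MyTokenClassifierForCTC:  # misleadingly named, not a registered CTC model
    pass

# The old name-based check is fooled by the suffix:
assert MyTokenClassifierForCTC.__name__.endswith("ForCTC")

# The new mapping-based check is not; only registered classes such as
# Wav2Vec2ForCTC and HubertForCTC appear in the mapping's values:
assert MyTokenClassifierForCTC not in MODEL_FOR_CTC_MAPPING.values()
```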
@@ -379,6 +379,15 @@ class AutoModelForCausalLM:
requires_backends(cls, ["torch"])
class AutoModelForCTC:
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class AutoModelForImageClassification:
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@@ -460,6 +469,15 @@ class AutoModelForSequenceClassification:
requires_backends(cls, ["torch"])
class AutoModelForSpeechSeq2Seq:
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class AutoModelForTableQuestionAnswering:
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
...
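These dummy classes keep `from transformers import AutoModelForCTC` importable in a torch-free environment while failing loudly on use. Roughly what a user would see without PyTorch installed (the exact wording comes from `requires_backends` and may differ):

```python
# In an environment without PyTorch:
from transformers import AutoModelForCTC  # succeeds, resolves to the dummy class

AutoModelForCTC.from_pretrained("facebook/wav2vec2-base-960h")
# -> raises an error along the lines of:
# ImportError: AutoModelForCTC requires the PyTorch library but it was not found ...
```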
@@ -49,10 +49,10 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
@require_torch
def test_torch_small_no_tokenizer_files(self):
# test that model without tokenizer file cannot be loaded
- with pytest.raises(ValueError):
+ with pytest.raises(OSError):
pipeline(
task="automatic-speech-recognition",
model="hf-internal-testing/tiny-random-wav2vec2", model="patrickvonplaten/tiny-wav2vec2-no-tokenizer",
framework="pt", framework="pt",
) )
......
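Side note on the test change: the expected exception becomes `OSError`, presumably because the failure now surfaces when `from_pretrained` cannot find tokenizer files in the dedicated no-tokenizer checkpoint, rather than as a later pipeline-level `ValueError`.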