Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
48fa42e5
Unverified
Commit
48fa42e5
authored
Sep 21, 2021
by
Patrick von Platen
Committed by
GitHub
Sep 21, 2021
Browse files
Add Speech AutoModels (#13655)
* upload * correct * correct * correct * finish * up * up * up again
parent
ea921365
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
91 additions
and
8 deletions
+91
-8
docs/source/model_doc/auto.rst
docs/source/model_doc/auto.rst
+14
-0
src/transformers/__init__.py
src/transformers/__init__.py
+4
-0
src/transformers/models/auto/__init__.py
src/transformers/models/auto/__init__.py
+8
-0
src/transformers/models/auto/modeling_auto.py
src/transformers/models/auto/modeling_auto.py
+33
-0
src/transformers/pipelines/__init__.py
src/transformers/pipelines/__init__.py
+3
-3
src/transformers/pipelines/automatic_speech_recognition.py
src/transformers/pipelines/automatic_speech_recognition.py
+9
-3
src/transformers/utils/dummy_pt_objects.py
src/transformers/utils/dummy_pt_objects.py
+18
-0
tests/test_pipelines_automatic_speech_recognition.py
tests/test_pipelines_automatic_speech_recognition.py
+2
-2
No files found.
docs/source/model_doc/auto.rst
View file @
48fa42e5
...
@@ -142,6 +142,20 @@ AutoModelForAudioClassification
...
@@ -142,6 +142,20 @@ AutoModelForAudioClassification
:
members
:
:
members
:
AutoModelForCTC
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
..
autoclass
::
transformers
.
AutoModelForCTC
:
members
:
AutoModelForSpeechSeq2Seq
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
..
autoclass
::
transformers
.
AutoModelForSpeechSeq2Seq
:
members
:
AutoModelForObjectDetection
AutoModelForObjectDetection
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
...
...
src/transformers/__init__.py
View file @
48fa42e5
...
@@ -557,6 +557,7 @@ if is_torch_available():
...
@@ -557,6 +557,7 @@ if is_torch_available():
"AutoModel"
,
"AutoModel"
,
"AutoModelForAudioClassification"
,
"AutoModelForAudioClassification"
,
"AutoModelForCausalLM"
,
"AutoModelForCausalLM"
,
"AutoModelForCTC"
,
"AutoModelForImageClassification"
,
"AutoModelForImageClassification"
,
"AutoModelForMaskedLM"
,
"AutoModelForMaskedLM"
,
"AutoModelForMultipleChoice"
,
"AutoModelForMultipleChoice"
,
...
@@ -566,6 +567,7 @@ if is_torch_available():
...
@@ -566,6 +567,7 @@ if is_torch_available():
"AutoModelForQuestionAnswering"
,
"AutoModelForQuestionAnswering"
,
"AutoModelForSeq2SeqLM"
,
"AutoModelForSeq2SeqLM"
,
"AutoModelForSequenceClassification"
,
"AutoModelForSequenceClassification"
,
"AutoModelForSpeechSeq2Seq"
,
"AutoModelForTableQuestionAnswering"
,
"AutoModelForTableQuestionAnswering"
,
"AutoModelForTokenClassification"
,
"AutoModelForTokenClassification"
,
"AutoModelWithLMHead"
,
"AutoModelWithLMHead"
,
...
@@ -2320,6 +2322,7 @@ if TYPE_CHECKING:
...
@@ -2320,6 +2322,7 @@ if TYPE_CHECKING:
AutoModel
,
AutoModel
,
AutoModelForAudioClassification
,
AutoModelForAudioClassification
,
AutoModelForCausalLM
,
AutoModelForCausalLM
,
AutoModelForCTC
,
AutoModelForImageClassification
,
AutoModelForImageClassification
,
AutoModelForMaskedLM
,
AutoModelForMaskedLM
,
AutoModelForMultipleChoice
,
AutoModelForMultipleChoice
,
...
@@ -2329,6 +2332,7 @@ if TYPE_CHECKING:
...
@@ -2329,6 +2332,7 @@ if TYPE_CHECKING:
AutoModelForQuestionAnswering
,
AutoModelForQuestionAnswering
,
AutoModelForSeq2SeqLM
,
AutoModelForSeq2SeqLM
,
AutoModelForSequenceClassification
,
AutoModelForSequenceClassification
,
AutoModelForSpeechSeq2Seq
,
AutoModelForTableQuestionAnswering
,
AutoModelForTableQuestionAnswering
,
AutoModelForTokenClassification
,
AutoModelForTokenClassification
,
AutoModelWithLMHead
,
AutoModelWithLMHead
,
...
...
src/transformers/models/auto/__init__.py
View file @
48fa42e5
...
@@ -32,6 +32,7 @@ if is_torch_available():
...
@@ -32,6 +32,7 @@ if is_torch_available():
_import_structure
[
"modeling_auto"
]
=
[
_import_structure
[
"modeling_auto"
]
=
[
"MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING"
,
"MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING"
,
"MODEL_FOR_CAUSAL_LM_MAPPING"
,
"MODEL_FOR_CAUSAL_LM_MAPPING"
,
"MODEL_FOR_CTC_MAPPING"
,
"MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING"
,
"MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING"
,
"MODEL_FOR_MASKED_LM_MAPPING"
,
"MODEL_FOR_MASKED_LM_MAPPING"
,
"MODEL_FOR_MULTIPLE_CHOICE_MAPPING"
,
"MODEL_FOR_MULTIPLE_CHOICE_MAPPING"
,
...
@@ -41,6 +42,7 @@ if is_torch_available():
...
@@ -41,6 +42,7 @@ if is_torch_available():
"MODEL_FOR_QUESTION_ANSWERING_MAPPING"
,
"MODEL_FOR_QUESTION_ANSWERING_MAPPING"
,
"MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING"
,
"MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING"
,
"MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING"
,
"MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING"
,
"MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING"
,
"MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING"
,
"MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING"
,
"MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING"
,
"MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING"
,
"MODEL_MAPPING"
,
"MODEL_MAPPING"
,
...
@@ -48,6 +50,7 @@ if is_torch_available():
...
@@ -48,6 +50,7 @@ if is_torch_available():
"AutoModel"
,
"AutoModel"
,
"AutoModelForAudioClassification"
,
"AutoModelForAudioClassification"
,
"AutoModelForCausalLM"
,
"AutoModelForCausalLM"
,
"AutoModelForCTC"
,
"AutoModelForImageClassification"
,
"AutoModelForImageClassification"
,
"AutoModelForMaskedLM"
,
"AutoModelForMaskedLM"
,
"AutoModelForMultipleChoice"
,
"AutoModelForMultipleChoice"
,
...
@@ -57,6 +60,7 @@ if is_torch_available():
...
@@ -57,6 +60,7 @@ if is_torch_available():
"AutoModelForQuestionAnswering"
,
"AutoModelForQuestionAnswering"
,
"AutoModelForSeq2SeqLM"
,
"AutoModelForSeq2SeqLM"
,
"AutoModelForSequenceClassification"
,
"AutoModelForSequenceClassification"
,
"AutoModelForSpeechSeq2Seq"
,
"AutoModelForTableQuestionAnswering"
,
"AutoModelForTableQuestionAnswering"
,
"AutoModelForTokenClassification"
,
"AutoModelForTokenClassification"
,
"AutoModelWithLMHead"
,
"AutoModelWithLMHead"
,
...
@@ -124,6 +128,7 @@ if TYPE_CHECKING:
...
@@ -124,6 +128,7 @@ if TYPE_CHECKING:
from
.modeling_auto
import
(
from
.modeling_auto
import
(
MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING
,
MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING
,
MODEL_FOR_CAUSAL_LM_MAPPING
,
MODEL_FOR_CAUSAL_LM_MAPPING
,
MODEL_FOR_CTC_MAPPING
,
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
,
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
,
MODEL_FOR_MASKED_LM_MAPPING
,
MODEL_FOR_MASKED_LM_MAPPING
,
MODEL_FOR_MULTIPLE_CHOICE_MAPPING
,
MODEL_FOR_MULTIPLE_CHOICE_MAPPING
,
...
@@ -133,6 +138,7 @@ if TYPE_CHECKING:
...
@@ -133,6 +138,7 @@ if TYPE_CHECKING:
MODEL_FOR_QUESTION_ANSWERING_MAPPING
,
MODEL_FOR_QUESTION_ANSWERING_MAPPING
,
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
,
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
,
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING
,
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING
,
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING
,
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING
,
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING
,
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
,
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
,
MODEL_MAPPING
,
MODEL_MAPPING
,
...
@@ -140,6 +146,7 @@ if TYPE_CHECKING:
...
@@ -140,6 +146,7 @@ if TYPE_CHECKING:
AutoModel
,
AutoModel
,
AutoModelForAudioClassification
,
AutoModelForAudioClassification
,
AutoModelForCausalLM
,
AutoModelForCausalLM
,
AutoModelForCTC
,
AutoModelForImageClassification
,
AutoModelForImageClassification
,
AutoModelForMaskedLM
,
AutoModelForMaskedLM
,
AutoModelForMultipleChoice
,
AutoModelForMultipleChoice
,
...
@@ -149,6 +156,7 @@ if TYPE_CHECKING:
...
@@ -149,6 +156,7 @@ if TYPE_CHECKING:
AutoModelForQuestionAnswering
,
AutoModelForQuestionAnswering
,
AutoModelForSeq2SeqLM
,
AutoModelForSeq2SeqLM
,
AutoModelForSequenceClassification
,
AutoModelForSequenceClassification
,
AutoModelForSpeechSeq2Seq
,
AutoModelForTableQuestionAnswering
,
AutoModelForTableQuestionAnswering
,
AutoModelForTokenClassification
,
AutoModelForTokenClassification
,
AutoModelWithLMHead
,
AutoModelWithLMHead
,
...
...
src/transformers/models/auto/modeling_auto.py
View file @
48fa42e5
...
@@ -291,6 +291,13 @@ MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
...
@@ -291,6 +291,13 @@ MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
]
]
)
)
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES
=
OrderedDict
(
[
(
"speech-encoder-decoder"
,
"SpeechEncoderDecoderModel"
),
(
"speech_to_text"
,
"Speech2TextForConditionalGeneration"
),
]
)
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES
=
OrderedDict
(
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES
=
OrderedDict
(
[
[
# Model for Sequence Classification mapping
# Model for Sequence Classification mapping
...
@@ -462,6 +469,14 @@ MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
...
@@ -462,6 +469,14 @@ MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
]
]
)
)
MODEL_FOR_CTC_MAPPING_NAMES
=
OrderedDict
(
[
# Model for Connectionist temporal classification (CTC) mapping
(
"wav2vec2"
,
"Wav2Vec2ForCTC"
),
(
"hubert"
,
"HubertForCTC"
),
]
)
MODEL_MAPPING
=
_LazyAutoMapping
(
CONFIG_MAPPING_NAMES
,
MODEL_MAPPING_NAMES
)
MODEL_MAPPING
=
_LazyAutoMapping
(
CONFIG_MAPPING_NAMES
,
MODEL_MAPPING_NAMES
)
MODEL_FOR_PRETRAINING_MAPPING
=
_LazyAutoMapping
(
CONFIG_MAPPING_NAMES
,
MODEL_FOR_PRETRAINING_MAPPING_NAMES
)
MODEL_FOR_PRETRAINING_MAPPING
=
_LazyAutoMapping
(
CONFIG_MAPPING_NAMES
,
MODEL_FOR_PRETRAINING_MAPPING_NAMES
)
MODEL_WITH_LM_HEAD_MAPPING
=
_LazyAutoMapping
(
CONFIG_MAPPING_NAMES
,
MODEL_WITH_LM_HEAD_MAPPING_NAMES
)
MODEL_WITH_LM_HEAD_MAPPING
=
_LazyAutoMapping
(
CONFIG_MAPPING_NAMES
,
MODEL_WITH_LM_HEAD_MAPPING_NAMES
)
...
@@ -493,6 +508,8 @@ MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = _LazyAutoMapping(
...
@@ -493,6 +508,8 @@ MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = _LazyAutoMapping(
MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING
=
_LazyAutoMapping
(
MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING
=
_LazyAutoMapping
(
CONFIG_MAPPING_NAMES
,
MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES
CONFIG_MAPPING_NAMES
,
MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES
)
)
MODEL_FOR_CTC_MAPPING
=
_LazyAutoMapping
(
CONFIG_MAPPING_NAMES
,
MODEL_FOR_CTC_MAPPING_NAMES
)
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING
=
_LazyAutoMapping
(
CONFIG_MAPPING_NAMES
,
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES
)
class
AutoModel
(
_BaseAutoModelClass
):
class
AutoModel
(
_BaseAutoModelClass
):
...
@@ -611,6 +628,22 @@ class AutoModelForAudioClassification(_BaseAutoModelClass):
...
@@ -611,6 +628,22 @@ class AutoModelForAudioClassification(_BaseAutoModelClass):
AutoModelForAudioClassification
=
auto_class_update
(
AutoModelForAudioClassification
,
head_doc
=
"audio classification"
)
AutoModelForAudioClassification
=
auto_class_update
(
AutoModelForAudioClassification
,
head_doc
=
"audio classification"
)
class
AutoModelForCTC
(
_BaseAutoModelClass
):
_model_mapping
=
MODEL_FOR_CTC_MAPPING
AutoModelForCTC
=
auto_class_update
(
AutoModelForCTC
,
head_doc
=
"connectionist temporal classification"
)
class
AutoModelForSpeechSeq2Seq
(
_BaseAutoModelClass
):
_model_mapping
=
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING
AutoModelForSpeechSeq2Seq
=
auto_class_update
(
AutoModelForSpeechSeq2Seq
,
head_doc
=
"sequence-to-sequence speech-to-text modeing"
)
class
AutoModelWithLMHead
(
_AutoModelWithLMHead
):
class
AutoModelWithLMHead
(
_AutoModelWithLMHead
):
@
classmethod
@
classmethod
def
from_config
(
cls
,
config
):
def
from_config
(
cls
,
config
):
...
...
src/transformers/pipelines/__init__.py
View file @
48fa42e5
...
@@ -90,12 +90,14 @@ if is_torch_available():
...
@@ -90,12 +90,14 @@ if is_torch_available():
AutoModel
,
AutoModel
,
AutoModelForAudioClassification
,
AutoModelForAudioClassification
,
AutoModelForCausalLM
,
AutoModelForCausalLM
,
AutoModelForCTC
,
AutoModelForImageClassification
,
AutoModelForImageClassification
,
AutoModelForMaskedLM
,
AutoModelForMaskedLM
,
AutoModelForObjectDetection
,
AutoModelForObjectDetection
,
AutoModelForQuestionAnswering
,
AutoModelForQuestionAnswering
,
AutoModelForSeq2SeqLM
,
AutoModelForSeq2SeqLM
,
AutoModelForSequenceClassification
,
AutoModelForSequenceClassification
,
AutoModelForSpeechSeq2Seq
,
AutoModelForTableQuestionAnswering
,
AutoModelForTableQuestionAnswering
,
AutoModelForTokenClassification
,
AutoModelForTokenClassification
,
)
)
...
@@ -121,9 +123,7 @@ SUPPORTED_TASKS = {
...
@@ -121,9 +123,7 @@ SUPPORTED_TASKS = {
"automatic-speech-recognition"
:
{
"automatic-speech-recognition"
:
{
"impl"
:
AutomaticSpeechRecognitionPipeline
,
"impl"
:
AutomaticSpeechRecognitionPipeline
,
"tf"
:
(),
"tf"
:
(),
# Only load from `config.architectures`, AutoModelForCTC and AutoModelForConditionalGeneration
"pt"
:
(
AutoModelForCTC
,
AutoModelForSpeechSeq2Seq
)
if
is_torch_available
()
else
(),
# do not exist yet.
"pt"
:
()
if
is_torch_available
()
else
(),
"default"
:
{
"model"
:
{
"pt"
:
"facebook/wav2vec2-base-960h"
}},
"default"
:
{
"model"
:
{
"pt"
:
"facebook/wav2vec2-base-960h"
}},
},
},
"feature-extraction"
:
{
"feature-extraction"
:
{
...
...
src/transformers/pipelines/automatic_speech_recognition.py
View file @
48fa42e5
...
@@ -16,6 +16,7 @@ from typing import TYPE_CHECKING, Union
...
@@ -16,6 +16,7 @@ from typing import TYPE_CHECKING, Union
import
numpy
as
np
import
numpy
as
np
from
..file_utils
import
is_torch_available
from
..utils
import
logging
from
..utils
import
logging
from
.base
import
Pipeline
from
.base
import
Pipeline
...
@@ -25,6 +26,9 @@ if TYPE_CHECKING:
...
@@ -25,6 +26,9 @@ if TYPE_CHECKING:
logger
=
logging
.
get_logger
(
__name__
)
logger
=
logging
.
get_logger
(
__name__
)
if
is_torch_available
():
from
..models.auto.modeling_auto
import
MODEL_FOR_CTC_MAPPING
,
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING
def
ffmpeg_read
(
bpayload
:
bytes
,
sampling_rate
:
int
)
->
np
.
array
:
def
ffmpeg_read
(
bpayload
:
bytes
,
sampling_rate
:
int
)
->
np
.
array
:
"""
"""
...
@@ -102,6 +106,8 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
...
@@ -102,6 +106,8 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
if
self
.
framework
==
"tf"
:
if
self
.
framework
==
"tf"
:
raise
ValueError
(
"The AutomaticSpeechRecognitionPipeline is only available in PyTorch."
)
raise
ValueError
(
"The AutomaticSpeechRecognitionPipeline is only available in PyTorch."
)
self
.
check_model_type
(
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING
.
items
()
+
MODEL_FOR_CTC_MAPPING
.
items
())
def
__call__
(
def
__call__
(
self
,
self
,
inputs
:
Union
[
np
.
ndarray
,
bytes
,
str
],
inputs
:
Union
[
np
.
ndarray
,
bytes
,
str
],
...
@@ -149,8 +155,8 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
...
@@ -149,8 +155,8 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
return
processed
return
processed
def
_forward
(
self
,
model_inputs
):
def
_forward
(
self
,
model_inputs
):
name
=
self
.
model
.
__class__
.
__name__
model_class
=
self
.
model
.
__class__
if
name
.
endswith
(
"ForConditionalGeneration"
)
or
name
.
endswith
(
"EncoderDecoderModel"
):
if
model_class
in
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING
.
values
(
):
encoder
=
self
.
model
.
get_encoder
()
encoder
=
self
.
model
.
get_encoder
()
# we need to pass `processed.get("attention_mask")` here since audio encoder
# we need to pass `processed.get("attention_mask")` here since audio encoder
# attention mask length is different from expected text decoder `encoder_attention_mask` length
# attention mask length is different from expected text decoder `encoder_attention_mask` length
...
@@ -160,7 +166,7 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
...
@@ -160,7 +166,7 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
encoder_outputs
=
encoder
(
**
model_inputs
),
attention_mask
=
model_inputs
.
get
(
"attention_mask"
)
encoder_outputs
=
encoder
(
**
model_inputs
),
attention_mask
=
model_inputs
.
get
(
"attention_mask"
)
)
)
tokens
=
tokens
.
squeeze
(
0
)
tokens
=
tokens
.
squeeze
(
0
)
elif
name
.
endswith
(
"ForCTC"
):
elif
model_class
in
MODEL_FOR_CTC_MAPPING
.
values
(
):
outputs
=
self
.
model
(
**
model_inputs
)
outputs
=
self
.
model
(
**
model_inputs
)
tokens
=
outputs
.
logits
.
squeeze
(
0
).
argmax
(
dim
=-
1
)
tokens
=
outputs
.
logits
.
squeeze
(
0
).
argmax
(
dim
=-
1
)
return
tokens
return
tokens
...
...
src/transformers/utils/dummy_pt_objects.py
View file @
48fa42e5
...
@@ -379,6 +379,15 @@ class AutoModelForCausalLM:
...
@@ -379,6 +379,15 @@ class AutoModelForCausalLM:
requires_backends
(
cls
,
[
"torch"
])
requires_backends
(
cls
,
[
"torch"
])
class
AutoModelForCTC
:
def
__init__
(
self
,
*
args
,
**
kwargs
):
requires_backends
(
self
,
[
"torch"
])
@
classmethod
def
from_pretrained
(
cls
,
*
args
,
**
kwargs
):
requires_backends
(
cls
,
[
"torch"
])
class
AutoModelForImageClassification
:
class
AutoModelForImageClassification
:
def
__init__
(
self
,
*
args
,
**
kwargs
):
def
__init__
(
self
,
*
args
,
**
kwargs
):
requires_backends
(
self
,
[
"torch"
])
requires_backends
(
self
,
[
"torch"
])
...
@@ -460,6 +469,15 @@ class AutoModelForSequenceClassification:
...
@@ -460,6 +469,15 @@ class AutoModelForSequenceClassification:
requires_backends
(
cls
,
[
"torch"
])
requires_backends
(
cls
,
[
"torch"
])
class
AutoModelForSpeechSeq2Seq
:
def
__init__
(
self
,
*
args
,
**
kwargs
):
requires_backends
(
self
,
[
"torch"
])
@
classmethod
def
from_pretrained
(
cls
,
*
args
,
**
kwargs
):
requires_backends
(
cls
,
[
"torch"
])
class
AutoModelForTableQuestionAnswering
:
class
AutoModelForTableQuestionAnswering
:
def
__init__
(
self
,
*
args
,
**
kwargs
):
def
__init__
(
self
,
*
args
,
**
kwargs
):
requires_backends
(
self
,
[
"torch"
])
requires_backends
(
self
,
[
"torch"
])
...
...
tests/test_pipelines_automatic_speech_recognition.py
View file @
48fa42e5
...
@@ -49,10 +49,10 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
...
@@ -49,10 +49,10 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
@
require_torch
@
require_torch
def
test_torch_small_no_tokenizer_files
(
self
):
def
test_torch_small_no_tokenizer_files
(
self
):
# test that model without tokenizer file cannot be loaded
# test that model without tokenizer file cannot be loaded
with
pytest
.
raises
(
Value
Error
):
with
pytest
.
raises
(
OS
Error
):
pipeline
(
pipeline
(
task
=
"automatic-speech-recognition"
,
task
=
"automatic-speech-recognition"
,
model
=
"
hf-internal-testing/tiny-random-wav2vec2
"
,
model
=
"
patrickvonplaten/tiny-wav2vec2-no-tokenizer
"
,
framework
=
"pt"
,
framework
=
"pt"
,
)
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment