chenpangpang / transformers · Commit 0eaa5ea3 (unverified)

Authored Jan 18, 2024 by Sanchit Gandhi; committed by GitHub on Jan 18, 2024.

[ASR Pipe] Update init to set model type and subsequently call parent init method (#28486)

* add image processor arg
* super
* rm args
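In short, the commit deletes the pipeline's hand-rolled `__init__` (framework inference, device placement, task-specific config) and delegates all of that to the parent initializer, keeping only the ASR-specific step of setting `self.type` first. The ordering matters: as the removed code below shows, the base pipeline init ends by calling `self._sanitize_parameters(**kwargs)`, and the ASR pipeline's `_sanitize_parameters` branches on `self.type`, so the type must exist before `super().__init__` runs. A minimal sketch of this "set state, then delegate" pattern, with illustrative class names and a simplified `_sanitize_parameters` rather than the actual transformers code:

```python
from types import SimpleNamespace


class BasePipeline:
    """Stand-in for the transformers Pipeline/ChunkPipeline base init."""

    def __init__(self, model, **kwargs):
        self.model = model
        # The parent init sanitizes kwargs right away, so any attribute read
        # inside _sanitize_parameters (here: self.type) must already be set.
        self._preprocess_params, self._forward_params, self._postprocess_params = self._sanitize_parameters(**kwargs)

    def _sanitize_parameters(self, **kwargs):
        return {}, {}, {}


class AsrLikePipeline(BasePipeline):
    def __init__(self, model, **kwargs):
        # 1) Set the model type first, as the commit does...
        self.type = "seq2seq_whisper" if model.config.model_type == "whisper" else "ctc"
        # 2) ...then delegate device/dtype/framework bookkeeping to the parent.
        super().__init__(model, **kwargs)

    def _sanitize_parameters(self, **kwargs):
        preprocess = {}
        if self.type == "seq2seq_whisper":  # would raise AttributeError if unset
            preprocess["return_timestamps"] = kwargs.get("return_timestamps", False)
        return preprocess, {}, {}


# Tiny smoke test with a fake model object:
model = SimpleNamespace(config=SimpleNamespace(model_type="whisper"))
pipe = AsrLikePipeline(model, return_timestamps=True)
assert pipe.type == "seq2seq_whisper" and pipe._preprocess_params["return_timestamps"]
```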
Parent: c662c78c
Changes: 1 changed file with 14 additions and 79 deletions.

src/transformers/pipelines/automatic_speech_recognition.py: +14 −79
```diff
--- a/src/transformers/pipelines/automatic_speech_recognition.py
+++ b/src/transformers/pipelines/automatic_speech_recognition.py
@@ -17,11 +17,10 @@ from typing import TYPE_CHECKING, Dict, Optional, Union
 import numpy as np
 import requests
 
-from ..modelcard import ModelCard
 from ..tokenization_utils import PreTrainedTokenizer
 from ..utils import is_torch_available, is_torchaudio_available, logging
 from .audio_utils import ffmpeg_read
-from .base import ArgumentHandler, ChunkPipeline, infer_framework_load_model
+from .base import ChunkPipeline
 
 
 if TYPE_CHECKING:
@@ -35,7 +34,7 @@ logger = logging.get_logger(__name__)
 if is_torch_available():
     import torch
 
-    from ..models.auto.modeling_auto import MODEL_FOR_CTC_MAPPING_NAMES, MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES
+    from ..models.auto.modeling_auto import MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES
 
 
 def rescale_stride(stride, ratio):
@@ -155,11 +154,15 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
         model ([`PreTrainedModel`] or [`TFPreTrainedModel`]):
             The model that will be used by the pipeline to make predictions. This needs to be a model inheriting from
             [`PreTrainedModel`] for PyTorch and [`TFPreTrainedModel`] for TensorFlow.
-        feature_extractor ([`SequenceFeatureExtractor`]):
-            The feature extractor that will be used by the pipeline to encode waveform for the model.
         tokenizer ([`PreTrainedTokenizer`]):
             The tokenizer that will be used by the pipeline to encode data for the model. This object inherits from
             [`PreTrainedTokenizer`].
+        feature_extractor ([`SequenceFeatureExtractor`]):
+            The feature extractor that will be used by the pipeline to encode waveform for the model.
+        decoder (`pyctcdecode.BeamSearchDecoderCTC`, *optional*):
+            [PyCTCDecode's
+            BeamSearchDecoderCTC](https://github.com/kensho-technologies/pyctcdecode/blob/2fd33dc37c4111417e08d89ccd23d28e9b308d19/pyctcdecode/decoder.py#L180)
+            can be passed for language model boosted decoding. See [`Wav2Vec2ProcessorWithLM`] for more information.
         chunk_length_s (`float`, *optional*, defaults to 0):
             The input length for in each chunk. If `chunk_length_s = 0` then chunking is disabled (default).
 
@@ -190,10 +193,9 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
         device (Union[`int`, `torch.device`], *optional*):
             Device ordinal for CPU/GPU supports. Setting this to `None` will leverage CPU, a positive will run the
             model on the associated CUDA device id.
-        decoder (`pyctcdecode.BeamSearchDecoderCTC`, *optional*):
-            [PyCTCDecode's
-            BeamSearchDecoderCTC](https://github.com/kensho-technologies/pyctcdecode/blob/2fd33dc37c4111417e08d89ccd23d28e9b308d19/pyctcdecode/decoder.py#L180)
-            can be passed for language model boosted decoding. See [`Wav2Vec2ProcessorWithLM`] for more information.
+        torch_dtype (Union[`int`, `torch.dtype`], *optional*):
+            The data-type (dtype) of the computation. Setting this to `None` will use float32 precision. Set to
+            `torch.float16` or `torch.bfloat16` to use half-precision in the respective dtypes.
     """
 
     def __init__(
@@ -203,77 +205,14 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
         feature_extractor: Union["SequenceFeatureExtractor", str] = None,
         tokenizer: Optional[PreTrainedTokenizer] = None,
         decoder: Optional[Union["BeamSearchDecoderCTC", str]] = None,
-        modelcard: Optional[ModelCard] = None,
-        framework: Optional[str] = None,
-        task: str = "",
-        args_parser: ArgumentHandler = None,
         device: Union[int, "torch.device"] = None,
         torch_dtype: Optional[Union[str, "torch.dtype"]] = None,
-        binary_output: bool = False,
         **kwargs,
     ):
-        if framework is None:
-            framework, model = infer_framework_load_model(model, config=model.config)
-
-        self.task = task
-        self.model = model
-        self.tokenizer = tokenizer
-        self.feature_extractor = feature_extractor
-        self.modelcard = modelcard
-        self.framework = framework
-
-        # `accelerate` device map
-        hf_device_map = getattr(self.model, "hf_device_map", None)
-
-        if hf_device_map is not None and device is not None:
-            raise ValueError(
-                "The model has been loaded with `accelerate` and therefore cannot be moved to a specific device. Please "
-                "discard the `device` argument when creating your pipeline object."
-            )
-
-        if self.framework == "tf":
-            raise ValueError("The AutomaticSpeechRecognitionPipeline is only available in PyTorch.")
-
-        # We shouldn't call `model.to()` for models loaded with accelerate
-        if device is not None and not (isinstance(device, int) and device < 0):
-            self.model.to(device)
-
-        if device is None:
-            if hf_device_map is not None:
-                # Take the first device used by `accelerate`.
-                device = next(iter(hf_device_map.values()))
-            else:
-                device = -1
-
-        if is_torch_available() and self.framework == "pt":
-            if isinstance(device, torch.device):
-                self.device = device
-            elif isinstance(device, str):
-                self.device = torch.device(device)
-            elif device < 0:
-                self.device = torch.device("cpu")
-            else:
-                self.device = torch.device(f"cuda:{device}")
-        else:
-            self.device = device if device is not None else -1
-        self.torch_dtype = torch_dtype
-        self.binary_output = binary_output
-
-        # Update config and generation_config with task specific parameters
-        task_specific_params = self.model.config.task_specific_params
-        if task_specific_params is not None and task in task_specific_params:
-            self.model.config.update(task_specific_params.get(task))
-            if self.model.can_generate():
-                self.model.generation_config.update(**task_specific_params.get(task))
-
-        self.call_count = 0
-        self._batch_size = kwargs.pop("batch_size", None)
-        self._num_workers = kwargs.pop("num_workers", None)
-
         # set the model type so we can check we have the right pre- and post-processing parameters
-        if self.model.config.model_type == "whisper":
+        if model.config.model_type == "whisper":
             self.type = "seq2seq_whisper"
-        elif self.model.__class__.__name__ in MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES.values():
+        elif model.__class__.__name__ in MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES.values():
             self.type = "seq2seq"
         elif (
             feature_extractor._processor_class
@@ -285,11 +224,7 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
         else:
             self.type = "ctc"
 
-        self._preprocess_params, self._forward_params, self._postprocess_params = self._sanitize_parameters(**kwargs)
-
-        mapping = MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES.copy()
-        mapping.update(MODEL_FOR_CTC_MAPPING_NAMES)
-        self.check_model_type(mapping)
+        super().__init__(model, tokenizer, feature_extractor, device=device, torch_dtype=torch_dtype, **kwargs)
 
     def __call__(
         self,
```
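For context, the public construction path is unchanged by this refactor; a pipeline built through the `pipeline()` factory should behave as before. A minimal usage sketch (the checkpoint name, options, and file name are illustrative, not taken from this commit):

```python
from transformers import pipeline

# Build the ASR pipeline as usual. With a Whisper checkpoint, the refactored
# __init__ sets self.type = "seq2seq_whisper" and then hands model, tokenizer,
# feature extractor, device, and dtype over to the parent pipeline init.
asr = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-tiny",  # illustrative checkpoint
    chunk_length_s=30,            # enable chunked long-form inference (see docstring above)
)

print(asr.type)                   # "seq2seq_whisper"
result = asr("sample.flac")       # local path, URL, or raw waveform
print(result["text"])
```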