Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
5f1918a4
"git@developer.sourcefind.cn:wuxk1/megatron-lm.git" did not exist on "2b699d4432632ddde5bd4aeed728abf9e4c9a0b8"
Unverified
Commit
5f1918a4
authored
Feb 07, 2022
by
Patrick von Platen
Committed by
GitHub
Feb 07, 2022
Browse files
[ASR pipeline] correct asr pipeline for seq2seq models (#15541)
parent
e02bdce7
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
30 additions
and
3 deletions
+30
-3
src/transformers/pipelines/automatic_speech_recognition.py
src/transformers/pipelines/automatic_speech_recognition.py
+12
-3
tests/test_pipelines_automatic_speech_recognition.py
tests/test_pipelines_automatic_speech_recognition.py
+18
-0
No files found.
src/transformers/pipelines/automatic_speech_recognition.py
View file @
5f1918a4
...
...
@@ -265,10 +265,19 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
# it here.
# Consume values so we can let extra information flow freely through
# the pipeline (important for `partial` in microphone)
input_features
=
model_inputs
.
pop
(
"input_features"
)
attention_mask
=
model_inputs
.
pop
(
"attention_mask"
)
if
"input_features"
in
model_inputs
:
inputs
=
model_inputs
.
pop
(
"input_features"
)
elif
"input_values"
in
model_inputs
:
inputs
=
model_inputs
.
pop
(
"input_values"
)
else
:
raise
ValueError
(
"Seq2Seq speech recognition model requires either a "
f
"`input_features` or `input_values` key, but only has
{
model_inputs
.
keys
()
}
"
)
attention_mask
=
model_inputs
.
pop
(
"attention_mask"
,
None
)
tokens
=
self
.
model
.
generate
(
encoder_outputs
=
encoder
(
input
_features
=
input_feature
s
,
attention_mask
=
attention_mask
),
encoder_outputs
=
encoder
(
inputs
,
attention_mask
=
attention_mask
),
attention_mask
=
attention_mask
,
)
out
=
{
"tokens"
:
tokens
}
...
...
tests/test_pipelines_automatic_speech_recognition.py
View file @
5f1918a4
...
...
@@ -107,6 +107,24 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
output
=
speech_recognizer
(
waveform
)
self
.
assertEqual
(
output
,
{
"text"
:
"(Applaudissements)"
})
@
require_torch
def
test_small_model_pt_seq2seq
(
self
):
model_id
=
"hf-internal-testing/tiny-random-speech-encoder-decoder"
tokenizer
=
AutoTokenizer
.
from_pretrained
(
model_id
)
feature_extractor
=
AutoFeatureExtractor
.
from_pretrained
(
model_id
)
speech_recognizer
=
pipeline
(
task
=
"automatic-speech-recognition"
,
model
=
model_id
,
tokenizer
=
tokenizer
,
feature_extractor
=
feature_extractor
,
framework
=
"pt"
,
)
waveform
=
np
.
tile
(
np
.
arange
(
1000
,
dtype
=
np
.
float32
),
34
)
output
=
speech_recognizer
(
waveform
)
self
.
assertEqual
(
output
,
{
"text"
:
"あл ش 湯 清 ه ܬ া लᆨしث ल eか u w 全 u"
})
@
slow
@
require_torch
@
require_pyctcdecode
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment