chenpangpang / transformers / Commits / 994b7a4e

Unverified commit 994b7a4e, authored Oct 07, 2022 by Arthur, committed by GitHub on Oct 07, 2022

update attention mask handling (#19385)

* update feature extractor params
* update attention mask handling

Parent: a26d71d6

Showing 2 changed files with 31 additions and 9 deletions (+31 −9)
src/transformers/models/whisper/feature_extraction_whisper.py   +16 −1
src/transformers/pipelines/automatic_speech_recognition.py      +15 −8
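The two files work as a pair: the Whisper feature extractor turns raw audio into `input_features` (plus an optional attention mask), and the automatic-speech-recognition pipeline feeds those to the model's `generate`. A minimal usage sketch, assuming the `openai/whisper-tiny` checkpoint and placeholder audio (neither is part of this commit):

import numpy as np

from transformers import pipeline

# Hypothetical checkpoint; any Whisper model exercises the same code paths.
asr = pipeline("automatic-speech-recognition", model="openai/whisper-tiny")

# Placeholder input: 2 seconds of silence at the 16 kHz rate Whisper expects.
audio = np.zeros(2 * 16000, dtype=np.float32)

print(asr(audio))  # {'text': '...'}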
src/transformers/models/whisper/feature_extraction_whisper.py  (view file @ 994b7a4e)
...
...
@@ -218,6 +218,7 @@ class WhisperFeatureExtractor(SequenceFeatureExtractor):
        return_attention_mask: Optional[bool] = None,
        padding: Optional[str] = "max_length",
        max_length: Optional[int] = None,
        sampling_rate: Optional[int] = None,
        **kwargs
    ) -> BatchFeature:
        """
...
...
@@ -255,11 +256,25 @@ class WhisperFeatureExtractor(SequenceFeatureExtractor):
                - `'np'`: Return Numpy `np.ndarray` objects.
            sampling_rate (`int`, *optional*):
                The sampling rate at which the `raw_speech` input was sampled. It is strongly recommended to pass
-               `sampling_rate` at the forward call to prevent silent errors.
+               `sampling_rate` at the forward call to prevent silent errors and allow automatic speech recognition
+               pipeline.
            padding_value (`float`, defaults to 0.0):
                The value that is used to fill the padding values / vectors.
        """

        if sampling_rate is not None:
            if sampling_rate != self.sampling_rate:
                raise ValueError(
                    f"The model corresponding to this feature extractor: {self} was trained using a sampling rate of"
                    f" {self.sampling_rate}. Please make sure that the provided `raw_speech` input was sampled with"
                    f" {self.sampling_rate} and not {sampling_rate}."
                )
        else:
            logger.warning(
                "It is strongly recommended to pass the `sampling_rate` argument to this function. "
                "Failing to do so can result in silent errors that might be hard to debug."
            )

        is_batched = bool(
            isinstance(raw_speech, (list, tuple))
            and (isinstance(raw_speech[0], np.ndarray) or isinstance(raw_speech[0], (tuple, list)))
        )
...
...
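The effect of the new guard in `WhisperFeatureExtractor.__call__` is that a mismatched sampling rate now fails loudly instead of producing silently wrong log-mel features. A small sketch of the intended behaviour, assuming the `openai/whisper-tiny` checkpoint (the checkpoint name is illustrative, not part of the diff):

import numpy as np

from transformers import WhisperFeatureExtractor

feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-tiny")  # trained at 16 kHz

speech = np.zeros(16000, dtype=np.float32)  # 1 second of placeholder audio

# Matching rate: returns padded log-mel `input_features`.
features = feature_extractor(speech, sampling_rate=16000, return_tensors="np")
print(features["input_features"].shape)

# Mismatched rate: the new check raises a ValueError instead of failing silently.
try:
    feature_extractor(speech, sampling_rate=8000)
except ValueError as err:
    print(err)

# Omitting `sampling_rate` only triggers the logger warning, as before.
feature_extractor(speech)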
src/transformers/pipelines/automatic_speech_recognition.py  (view file @ 994b7a4e)
...
...
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from collections import defaultdict
from typing import TYPE_CHECKING, Dict, Optional, Union
...
...
@@ -259,9 +260,9 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
            # Currently chunking is not possible at this level for `seq2seq` so
            # it's ok.
            align_to = self.model.config.inputs_to_logits_ratio
-           chunk_len = int(round(chunk_length_s * self.feature_extractor.sampling_rate / align_to)) * align_to
-           stride_left = int(round(stride_length_s[0] * self.feature_extractor.sampling_rate / align_to)) * align_to
-           stride_right = int(round(stride_length_s[1] * self.feature_extractor.sampling_rate / align_to)) * align_to
+           chunk_len = int(round(chunk_length_s * self.feature_extractor.sampling_rate / align_to) * align_to)
+           stride_left = int(round(stride_length_s[0] * self.feature_extractor.sampling_rate / align_to) * align_to)
+           stride_right = int(round(stride_length_s[1] * self.feature_extractor.sampling_rate / align_to) * align_to)

            if self.type not in {"ctc", "ctc_with_lm"}:
                raise ValueError(
...
...
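The only behavioural change in the hunk above is where the final `int(...)` cast sits: when `inputs_to_logits_ratio` is a float, the old `int(round(x / align_to)) * align_to` yields a float, while the new `int(round(x / align_to) * align_to)` keeps chunk and stride sizes as integers. A quick check with made-up numbers (values are illustrative only):

# Illustrative values: a 20 s chunk at 16 kHz with a float inputs_to_logits_ratio.
chunk_length_s = 20.0
sampling_rate = 16000
align_to = 320.0

old_chunk_len = int(round(chunk_length_s * sampling_rate / align_to)) * align_to
new_chunk_len = int(round(chunk_length_s * sampling_rate / align_to) * align_to)

print(old_chunk_len, type(old_chunk_len))  # 320000.0 <class 'float'> -- floats break later slicing
print(new_chunk_len, type(new_chunk_len))  # 320000 <class 'int'>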
@@ -304,12 +305,18 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
                    f"`input_features` or `input_values` key, but only has {model_inputs.keys()}"
                )

            accepts_attention_mask = "attention_mask" in set(inspect.signature(encoder.forward).parameters.keys())
            if accepts_attention_mask:
                attention_mask = model_inputs.pop("attention_mask", None)
                tokens = self.model.generate(
                    encoder_outputs=encoder(inputs, attention_mask=attention_mask),
                    attention_mask=attention_mask,
                )
            else:
                tokens = self.model.generate(inputs)

            out = {"tokens": tokens}
        else:
            stride = model_inputs.pop("stride", None)
            input_values = model_inputs.pop("input_values")
...
...
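The `accepts_attention_mask` guard above is a generic pattern: inspect the encoder's `forward` signature and only pass `attention_mask` through when the signature declares it, so models whose encoders take no mask still work. A standalone sketch of the same check, with toy functions standing in for real encoders (names are hypothetical):

import inspect

# Toy stand-ins for encoder.forward; real encoders differ only in their signatures.
def forward_with_mask(input_values, attention_mask=None):
    return input_values

def forward_without_mask(input_features):
    return input_features

for forward in (forward_with_mask, forward_without_mask):
    accepts_attention_mask = "attention_mask" in set(inspect.signature(forward).parameters.keys())
    print(forward.__name__, "accepts attention_mask:", accepts_attention_mask)
# forward_with_mask accepts attention_mask: True
# forward_without_mask accepts attention_mask: False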