Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
d355741e
Unverified
Commit
d355741e
authored
May 27, 2024
by
Yoach Lacombe
Committed by
GitHub
May 27, 2024
Browse files
Fix pad_to_max_length Whisper (#30787)
* fix pad_to_max_length Whisper * add tests * make style
parent
b84cd675
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
84 additions
and
3 deletions
+84
-3
src/transformers/models/whisper/generation_whisper.py
src/transformers/models/whisper/generation_whisper.py
+8
-3
tests/models/whisper/test_modeling_whisper.py
tests/models/whisper/test_modeling_whisper.py
+76
-0
No files found.
src/transformers/models/whisper/generation_whisper.py
View file @
d355741e
...
...
@@ -122,7 +122,9 @@ def _get_attr_from_logit_processors(logits_processor, logit_processor_class, att
return
None
def
_pad_to_max_length
(
current_segments
,
pad_token_id
,
padding
=
"right"
,
bos_token_tensor
=
None
,
cut_off_length
=
None
):
def
_pad_to_max_length
(
current_segments
,
pad_token_id
,
device
,
padding
=
"right"
,
bos_token_tensor
=
None
,
cut_off_length
=
None
):
max_total_length
=
0
sequences
=
[]
if
padding
not
in
[
"right"
,
"left"
]:
...
...
@@ -143,7 +145,7 @@ def _pad_to_max_length(current_segments, pad_token_id, padding="right", bos_toke
elif
bos_token_tensor
is
not
None
:
sequences
.
append
(
bos_token_tensor
)
else
:
sequences
.
append
(
torch
.
tensor
([]))
sequences
.
append
(
torch
.
tensor
([]
,
device
=
device
))
for
i
in
range
(
len
(
current_segments
)):
pad_length
=
max_total_length
-
len
(
sequences
[
i
])
...
...
@@ -733,7 +735,9 @@ class WhisperGenerationMixin:
if
(
prompt_ids
is
not
None
and
generation_config
.
prompt_condition_type
==
"first-segment"
)
else
current_segments
)
sequences
=
_pad_to_max_length
(
final_segments
,
generation_config
.
pad_token_id
,
padding
=
"right"
)
sequences
=
_pad_to_max_length
(
final_segments
,
generation_config
.
pad_token_id
,
device
=
self
.
device
,
padding
=
"right"
)
# 8. If we return all segments, the predicted output sequences are put under `"sequences"`.
if
return_segments
:
...
...
@@ -1506,6 +1510,7 @@ class WhisperGenerationMixin:
prev_tokens
=
_pad_to_max_length
(
active_segments
,
generation_config
.
pad_token_id
,
device
=
device
,
padding
=
"left"
,
bos_token_tensor
=
prev_ids
,
cut_off_length
=
cut_off_length
,
...
...
tests/models/whisper/test_modeling_whisper.py
View file @
d355741e
...
...
@@ -35,6 +35,7 @@ from transformers.testing_utils import (
require_torch
,
require_torch_fp16
,
require_torch_gpu
,
require_torch_multi_gpu
,
require_torchaudio
,
slow
,
torch_device
,
...
...
@@ -2866,6 +2867,81 @@ class WhisperModelIntegrationTests(unittest.TestCase):
for
i
in
range
(
num_samples
):
assert
decoded_all
[
i
]
==
EXPECTED_TEXT
[
i
]
@
require_torch_gpu
@
slow
def
test_whisper_empty_longform
(
self
):
processor
=
WhisperProcessor
.
from_pretrained
(
"openai/whisper-tiny"
)
model
=
WhisperForConditionalGeneration
.
from_pretrained
(
"openai/whisper-tiny"
)
model
=
model
.
to
(
torch_device
)
ds
=
load_dataset
(
"distil-whisper/meanwhile"
,
"default"
)[
"test"
]
ds
=
ds
.
cast_column
(
"audio"
,
Audio
(
sampling_rate
=
16000
))
num_samples
=
8
audio
=
ds
[:
num_samples
][
"audio"
]
audios
=
[
x
[
"array"
]
for
x
in
audio
]
audios
[
0
][:]
=
np
.
zeros
(
audios
[
0
].
shape
)
inputs
=
processor
(
audios
,
return_tensors
=
"pt"
,
truncation
=
False
,
padding
=
"longest"
,
return_attention_mask
=
True
,
sampling_rate
=
16_000
,
)
inputs
=
inputs
.
to
(
device
=
torch_device
)
gen_kwargs
=
{
"no_speech_threshold"
:
0.2
,
"temperature"
:
(
0.0
,),
"logprob_threshold"
:
0.0
,
# Ignore logprob, use only no-speech prob
"num_beams"
:
5
,
"language"
:
"fr"
,
"task"
:
"transcribe"
,
}
torch
.
manual_seed
(
0
)
model
.
generate
(
**
inputs
,
**
gen_kwargs
)
@
require_torch_multi_gpu
@
slow
def
test_whisper_empty_longform_multi_gpu
(
self
):
processor
=
WhisperProcessor
.
from_pretrained
(
"openai/whisper-tiny"
)
model
=
WhisperForConditionalGeneration
.
from_pretrained
(
"openai/whisper-tiny"
,
device_map
=
"auto"
)
ds
=
load_dataset
(
"distil-whisper/meanwhile"
,
"default"
)[
"test"
]
ds
=
ds
.
cast_column
(
"audio"
,
Audio
(
sampling_rate
=
16000
))
num_samples
=
8
audio
=
ds
[:
num_samples
][
"audio"
]
audios
=
[
x
[
"array"
]
for
x
in
audio
]
audios
[
0
][:]
=
np
.
zeros
(
audios
[
0
].
shape
)
inputs
=
processor
(
audios
,
return_tensors
=
"pt"
,
truncation
=
False
,
padding
=
"longest"
,
return_attention_mask
=
True
,
sampling_rate
=
16_000
,
)
inputs
=
inputs
.
to
(
device
=
model
.
device
)
gen_kwargs
=
{
"no_speech_threshold"
:
0.2
,
"temperature"
:
(
0.0
,),
"logprob_threshold"
:
0.0
,
# Ignore logprob, use only no-speech prob
"num_beams"
:
5
,
"language"
:
"fr"
,
"task"
:
"transcribe"
,
}
torch
.
manual_seed
(
0
)
model
.
generate
(
**
inputs
,
**
gen_kwargs
)
def
prepare_whisper_encoder_inputs_dict
(
config
,
input_features
,
head_mask
=
None
):
if
head_mask
is
None
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment