Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
186aa6be
Unverified
Commit
186aa6be
authored
Jan 18, 2024
by
Sanchit Gandhi
Committed by
GitHub
Jan 18, 2024
Browse files
[Whisper] Fix audio classification with weighted layer sum (#28563)
* fix * tests * fix test
parent
619ecfe2
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
23 additions
and
8 deletions
+23
-8
src/transformers/models/whisper/modeling_whisper.py
src/transformers/models/whisper/modeling_whisper.py
+9
-1
tests/models/whisper/test_modeling_whisper.py
tests/models/whisper/test_modeling_whisper.py
+14
-7
No files found.
src/transformers/models/whisper/modeling_whisper.py
View file @
186aa6be
...
@@ -57,6 +57,8 @@ if is_flash_attn_2_available():
...
@@ -57,6 +57,8 @@ if is_flash_attn_2_available():
logger
=
logging
.
get_logger
(
__name__
)
logger
=
logging
.
get_logger
(
__name__
)
_HIDDEN_STATES_START_POSITION
=
1
_CONFIG_FOR_DOC
=
"WhisperConfig"
_CONFIG_FOR_DOC
=
"WhisperConfig"
_CHECKPOINT_FOR_DOC
=
"openai/whisper-tiny"
_CHECKPOINT_FOR_DOC
=
"openai/whisper-tiny"
...
@@ -2957,6 +2959,11 @@ class WhisperForAudioClassification(WhisperPreTrainedModel):
...
@@ -2957,6 +2959,11 @@ class WhisperForAudioClassification(WhisperPreTrainedModel):
output_hidden_states
=
(
output_hidden_states
=
(
output_hidden_states
if
output_hidden_states
is
not
None
else
self
.
config
.
output_hidden_states
output_hidden_states
if
output_hidden_states
is
not
None
else
self
.
config
.
output_hidden_states
)
)
if
self
.
config
.
use_weighted_layer_sum
:
output_hidden_states
=
True
elif
output_hidden_states
is
None
:
output_hidden_states
=
self
.
config
.
output_hidden_states
return_dict
=
return_dict
if
return_dict
is
not
None
else
self
.
config
.
use_return_dict
return_dict
=
return_dict
if
return_dict
is
not
None
else
self
.
config
.
use_return_dict
if
encoder_outputs
is
None
:
if
encoder_outputs
is
None
:
...
@@ -2969,7 +2976,8 @@ class WhisperForAudioClassification(WhisperPreTrainedModel):
...
@@ -2969,7 +2976,8 @@ class WhisperForAudioClassification(WhisperPreTrainedModel):
)
)
if
self
.
config
.
use_weighted_layer_sum
:
if
self
.
config
.
use_weighted_layer_sum
:
hidden_states
=
torch
.
stack
(
encoder_outputs
,
dim
=
1
)
hidden_states
=
encoder_outputs
[
_HIDDEN_STATES_START_POSITION
]
hidden_states
=
torch
.
stack
(
hidden_states
,
dim
=
1
)
norm_weights
=
nn
.
functional
.
softmax
(
self
.
layer_weights
,
dim
=-
1
)
norm_weights
=
nn
.
functional
.
softmax
(
self
.
layer_weights
,
dim
=-
1
)
hidden_states
=
(
hidden_states
*
norm_weights
.
view
(
-
1
,
1
,
1
)).
sum
(
dim
=
1
)
hidden_states
=
(
hidden_states
*
norm_weights
.
view
(
-
1
,
1
,
1
)).
sum
(
dim
=
1
)
else
:
else
:
...
...
tests/models/whisper/test_modeling_whisper.py
View file @
186aa6be
...
@@ -2292,16 +2292,15 @@ class WhisperEncoderModelTester:
...
@@ -2292,16 +2292,15 @@ class WhisperEncoderModelTester:
def
encoder_seq_length
(
self
):
def
encoder_seq_length
(
self
):
return
self
.
get_subsampled_output_lengths
(
self
.
seq_length
)
return
self
.
get_subsampled_output_lengths
(
self
.
seq_length
)
def
create_and_check_model_forward
(
self
,
config
,
inputs_dict
,
freeze_encoder
=
False
):
def
create_and_check_model_forward
(
self
,
config
,
inputs_dict
,
use_weighted_layer_sum
=
False
):
model
=
WhisperForAudioClassification
(
config
=
config
).
to
(
torch_device
).
eval
()
config
.
use_weighted_layer_sum
=
use_weighted_layer_sum
model
=
WhisperForAudioClassification
(
config
=
config
)
if
freeze_encoder
:
model
.
to
(
torch_device
).
eval
()
model
.
freeze_encoder
()
input_features
=
inputs_dict
[
"input_features"
]
input_features
=
inputs_dict
[
"input_features"
]
# first forward pass
with
torch
.
no_grad
():
last_hidden_state
=
model
(
input_features
).
logits
last_hidden_state
=
model
(
input_features
).
logits
self
.
parent
.
assertTrue
(
last_hidden_state
.
shape
,
(
13
,
2
))
self
.
parent
.
assertTrue
(
last_hidden_state
.
shape
,
(
13
,
2
))
...
@@ -2336,6 +2335,14 @@ class WhisperEncoderModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.
...
@@ -2336,6 +2335,14 @@ class WhisperEncoderModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.
expected_arg_names
=
[
"input_features"
,
"head_mask"
,
"encoder_outputs"
]
expected_arg_names
=
[
"input_features"
,
"head_mask"
,
"encoder_outputs"
]
self
.
assertListEqual
(
arg_names
[:
len
(
expected_arg_names
)],
expected_arg_names
)
self
.
assertListEqual
(
arg_names
[:
len
(
expected_arg_names
)],
expected_arg_names
)
def
test_forward_pass
(
self
):
config_and_inputs
=
self
.
model_tester
.
prepare_config_and_inputs
()
self
.
model_tester
.
create_and_check_model_forward
(
*
config_and_inputs
)
def
test_forward_pass_weighted_layer_sum
(
self
):
config_and_inputs
=
self
.
model_tester
.
prepare_config_and_inputs
()
self
.
model_tester
.
create_and_check_model_forward
(
*
config_and_inputs
,
use_weighted_layer_sum
=
True
)
@
unittest
.
skip
(
reason
=
"Some undefined behavior encountered with tiny versions of this model. Skip for now."
)
@
unittest
.
skip
(
reason
=
"Some undefined behavior encountered with tiny versions of this model. Skip for now."
)
def
test_cpu_offload
(
self
):
def
test_cpu_offload
(
self
):
pass
pass
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment