chenpangpang / transformers · Commits

Commit 186aa6be (unverified), authored Jan 18, 2024 by Sanchit Gandhi, committed by GitHub on Jan 18, 2024

[Whisper] Fix audio classification with weighted layer sum (#28563)

* fix
* tests
* fix test

Parent: 619ecfe2
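For reference, the snippet below sketches the user-facing call path this commit fixes: a forward pass through WhisperForAudioClassification with use_weighted_layer_sum enabled. It is not part of the commit; the checkpoint name is only an illustrative audio-classification checkpoint, and the silent input merely exercises the forward pass.

```python
# Minimal sketch (not part of this commit) of the path the fix targets:
# WhisperForAudioClassification with use_weighted_layer_sum enabled.
import torch
from transformers import AutoFeatureExtractor, WhisperForAudioClassification

# Illustrative audio-classification checkpoint; any Whisper classification model would do.
checkpoint = "sanchit-gandhi/whisper-medium-fleurs-lang-id"
feature_extractor = AutoFeatureExtractor.from_pretrained(checkpoint)
model = WhisperForAudioClassification.from_pretrained(checkpoint, use_weighted_layer_sum=True)

# 30 seconds of silence stands in for real audio, purely to exercise the forward pass.
waveform = torch.zeros(16_000 * 30)
inputs = feature_extractor(waveform.numpy(), sampling_rate=16_000, return_tensors="pt")

with torch.no_grad():
    # Before this commit, the weighted-layer-sum branch failed during this call.
    logits = model(**inputs).logits

predicted_class = model.config.id2label[logits.argmax(-1).item()]
print(predicted_class)
```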
Showing 2 changed files with 23 additions and 8 deletions.
src/transformers/models/whisper/modeling_whisper.py   +9 -1
tests/models/whisper/test_modeling_whisper.py         +14 -7
src/transformers/models/whisper/modeling_whisper.py

@@ -57,6 +57,8 @@ if is_flash_attn_2_available():

 logger = logging.get_logger(__name__)

+_HIDDEN_STATES_START_POSITION = 1
+
 _CONFIG_FOR_DOC = "WhisperConfig"
 _CHECKPOINT_FOR_DOC = "openai/whisper-tiny"

@@ -2957,6 +2959,11 @@ class WhisperForAudioClassification(WhisperPreTrainedModel):
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
         )
+        if self.config.use_weighted_layer_sum:
+            output_hidden_states = True
+        elif output_hidden_states is None:
+            output_hidden_states = self.config.output_hidden_states
+
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict

         if encoder_outputs is None:

@@ -2969,7 +2976,8 @@ class WhisperForAudioClassification(WhisperPreTrainedModel):
         )

         if self.config.use_weighted_layer_sum:
-            hidden_states = torch.stack(encoder_outputs, dim=1)
+            hidden_states = encoder_outputs[_HIDDEN_STATES_START_POSITION]
+            hidden_states = torch.stack(hidden_states, dim=1)
             norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
             hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
         else:
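To make the fixed branch concrete, here is a standalone sketch of the weighted layer sum with toy shapes. It assumes the encoder returns a tuple of per-layer hidden states (what encoder_outputs[_HIDDEN_STATES_START_POSITION] selects) and that layer_weights holds one learnable weight per returned state; the dimensions below are illustrative only.

```python
import torch
import torch.nn as nn

# Toy dimensions, for illustration only.
batch, seq_len, hidden_size = 2, 10, 8
num_states = 5  # e.g. embedding output + 4 encoder layers

# Stand-in for encoder_outputs[_HIDDEN_STATES_START_POSITION]: a tuple containing
# one (batch, seq_len, hidden_size) tensor per returned hidden state.
hidden_states = tuple(torch.randn(batch, seq_len, hidden_size) for _ in range(num_states))

# One learnable weight per hidden state, initialised uniformly.
layer_weights = nn.Parameter(torch.ones(num_states) / num_states)

stacked = torch.stack(hidden_states, dim=1)                  # (batch, num_states, seq_len, hidden_size)
norm_weights = nn.functional.softmax(layer_weights, dim=-1)  # (num_states,)
pooled = (stacked * norm_weights.view(-1, 1, 1)).sum(dim=1)  # (batch, seq_len, hidden_size)

print(pooled.shape)  # torch.Size([2, 10, 8])
```

The pre-fix code called torch.stack on the encoder output object itself; indexing the tuple of hidden states first, as the patch does, is what makes the stack and the broadcasted weighting above well defined.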
tests/models/whisper/test_modeling_whisper.py

@@ -2292,16 +2292,15 @@ class WhisperEncoderModelTester:
     def encoder_seq_length(self):
         return self.get_subsampled_output_lengths(self.seq_length)

-    def create_and_check_model_forward(self, config, inputs_dict, freeze_encoder=False):
-        model = WhisperForAudioClassification(config=config).to(torch_device).eval()
-
-        if freeze_encoder:
-            model.freeze_encoder()
+    def create_and_check_model_forward(self, config, inputs_dict, use_weighted_layer_sum=False):
+        config.use_weighted_layer_sum = use_weighted_layer_sum
+        model = WhisperForAudioClassification(config=config)
+        model.to(torch_device).eval()

         input_features = inputs_dict["input_features"]

         # first forward pass
-        last_hidden_state = model(input_features).logits
+        with torch.no_grad():
+            last_hidden_state = model(input_features).logits

         self.parent.assertTrue(last_hidden_state.shape, (13, 2))

@@ -2336,6 +2335,14 @@ class WhisperEncoderModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.
         expected_arg_names = ["input_features", "head_mask", "encoder_outputs"]
         self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names)

+    def test_forward_pass(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_model_forward(*config_and_inputs)
+
+    def test_forward_pass_weighted_layer_sum(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_model_forward(*config_and_inputs, use_weighted_layer_sum=True)
+
     @unittest.skip(reason="Some undefined behavior encountered with tiny versions of this model. Skip for now.")
     def test_cpu_offload(self):
         pass
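As a rough standalone counterpart to the new test_forward_pass_weighted_layer_sum, the sketch below builds a deliberately tiny configuration and runs a single forward pass with use_weighted_layer_sum=True. The config values are arbitrary small numbers chosen for speed, not the settings used by WhisperEncoderModelTester.

```python
import torch
from transformers import WhisperConfig, WhisperForAudioClassification

# Deliberately tiny, illustrative config values (not the tester's actual settings).
config = WhisperConfig(
    d_model=16,
    encoder_layers=2,
    encoder_attention_heads=4,
    decoder_layers=2,
    decoder_attention_heads=4,
    num_mel_bins=80,
    max_source_positions=30,
    num_labels=2,
    use_weighted_layer_sum=True,
)
model = WhisperForAudioClassification(config).eval()

# The Whisper encoder expects log-mel features of shape (batch, num_mel_bins, 2 * max_source_positions).
input_features = torch.randn(1, config.num_mel_bins, 2 * config.max_source_positions)

with torch.no_grad():
    # This is the branch the commit repairs; it now runs without requiring the
    # caller to pass output_hidden_states=True explicitly.
    logits = model(input_features).logits

print(logits.shape)  # torch.Size([1, 2])
```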