Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
bbd150e9
Unverified
Commit
bbd150e9
authored
Oct 13, 2022
by
Sanchit Gandhi
Committed by
GitHub
Oct 13, 2022
Browse files
[Whisper] Freeze params of encoder (#19527)
* [Whisper] Freeze params of encoder * add tests
parent
504cd71a
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
43 additions
and
1 deletion
+43
-1
src/transformers/models/whisper/modeling_whisper.py
src/transformers/models/whisper/modeling_whisper.py
+19
-0
tests/models/whisper/test_modeling_whisper.py
tests/models/whisper/test_modeling_whisper.py
+24
-1
No files found.
src/transformers/models/whisper/modeling_whisper.py
View file @
bbd150e9
...
...
@@ -609,6 +609,11 @@ class WhisperEncoder(WhisperPreTrainedModel):
# Initialize weights and apply final processing
self
.
post_init
()
def _freeze_parameters(self):
    """
    Mark every parameter of this module as non-trainable.

    Sets `requires_grad` to `False` on each parameter yielded by `self.parameters()` and records the frozen state
    on the module itself through the `_requires_grad` attribute.
    """
    for weight in self.parameters():
        weight.requires_grad = False
    # Module-level flag mirroring the per-parameter state set above.
    self._requires_grad = False
def
forward
(
self
,
input_features
,
...
...
@@ -991,6 +996,13 @@ class WhisperModel(WhisperPreTrainedModel):
def get_decoder(self):
    """
    Return the decoder sub-module stored on this model as `self.decoder`.
    """
    return self.decoder
def freeze_encoder(self):
    """
    Freeze the Whisper encoder: gradient computation is disabled for its parameters, so they are left untouched
    by training updates.
    """
    encoder = self.encoder
    encoder._freeze_parameters()
@
add_start_docstrings_to_model_forward
(
WHISPER_INPUTS_DOCSTRING
)
@
add_code_sample_docstrings
(
processor_class
=
_PROCESSOR_FOR_DOC
,
...
...
@@ -1109,6 +1121,13 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
def set_output_embeddings(self, new_embeddings):
    """
    Replace the output projection (`self.proj_out`) with `new_embeddings`.
    """
    self.proj_out = new_embeddings
def freeze_encoder(self):
    """
    Freeze the Whisper encoder held by the wrapped base model (`self.model.encoder`): gradient computation is
    disabled for its parameters, so they are left untouched by training updates.
    """
    encoder = self.model.encoder
    encoder._freeze_parameters()
@
add_start_docstrings_to_model_forward
(
WHISPER_INPUTS_DOCSTRING
)
@
replace_return_docstrings
(
output_type
=
Seq2SeqLMOutput
,
config_class
=
_CONFIG_FOR_DOC
)
def
forward
(
...
...
tests/models/whisper/test_modeling_whisper.py
View file @
bbd150e9
...
...
@@ -182,9 +182,12 @@ class WhisperModelTester:
return
input_lengths
def
create_and_check_model_forward
(
self
,
config
,
inputs_dict
):
def
create_and_check_model_forward
(
self
,
config
,
inputs_dict
,
freeze_encoder
=
False
):
model
=
WhisperModel
(
config
=
config
).
to
(
torch_device
).
eval
()
if
freeze_encoder
:
model
.
freeze_encoder
()
input_features
=
inputs_dict
[
"input_features"
]
decoder_input_ids
=
inputs_dict
[
"decoder_input_ids"
]
...
...
@@ -289,6 +292,26 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCas
config_and_inputs
=
self
.
model_tester
.
prepare_config_and_inputs
()
self
.
model_tester
.
create_and_check_model_forward
(
*
config_and_inputs
)
def test_model_forward_with_frozen_encoder(self):
    """
    Run the tester's forward-pass check with `freeze_encoder=True`.
    """
    tester = self.model_tester
    prepared = tester.prepare_config_and_inputs()
    tester.create_and_check_model_forward(*prepared, freeze_encoder=True)
def test_requires_grad_with_frozen_encoder(self):
    """
    Check that `freeze_encoder()` freezes every encoder parameter and leaves every decoder parameter trainable.

    Each class in `self.all_model_classes` is instantiated from the tester's config and has its encoder frozen;
    the `requires_grad` flags of the encoder and decoder parameters are then inspected.
    """
    config = self.model_tester.get_config()
    for model_class in self.all_model_classes:
        model = model_class(config)
        model.freeze_encoder()

        try:
            encoder_grads = [param.requires_grad for param in model.encoder.parameters()]
            decoder_grads = [param.requires_grad for param in model.decoder.parameters()]
        except AttributeError:
            # Model classes without a direct `encoder`/`decoder` attribute expose them on an inner `model`.
            encoder_grads = [param.requires_grad for param in model.model.encoder.parameters()]
            decoder_grads = [param.requires_grad for param in model.model.decoder.parameters()]

        # `any`, not `all`: a frozen encoder must have *no* trainable parameter. `assertFalse(all(...))` would
        # already pass when a single encoder parameter is non-trainable, hiding a partial freeze.
        self.assertFalse(any(encoder_grads))
        self.assertTrue(all(decoder_grads))
def test_decoder_model_past_with_large_inputs(self):
    """
    Run the tester's decoder past-key-values check for large inputs on freshly prepared config and inputs.
    """
    tester = self.model_tester
    prepared = tester.prepare_config_and_inputs()
    tester.create_and_check_decoder_model_past_large_inputs(*prepared)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment