Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
bbd150e9
Unverified
Commit
bbd150e9
authored
Oct 13, 2022
by
Sanchit Gandhi
Committed by
GitHub
Oct 13, 2022
Browse files
[Whisper] Freeze params of encoder (#19527)
* [Whisper] Freeze params of encoder * add tests
parent
504cd71a
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
43 additions
and
1 deletion
+43
-1
src/transformers/models/whisper/modeling_whisper.py
src/transformers/models/whisper/modeling_whisper.py
+19
-0
tests/models/whisper/test_modeling_whisper.py
tests/models/whisper/test_modeling_whisper.py
+24
-1
No files found.
src/transformers/models/whisper/modeling_whisper.py
View file @
bbd150e9
...
@@ -609,6 +609,11 @@ class WhisperEncoder(WhisperPreTrainedModel):
...
@@ -609,6 +609,11 @@ class WhisperEncoder(WhisperPreTrainedModel):
# Initialize weights and apply final processing
# Initialize weights and apply final processing
self
.
post_init
()
self
.
post_init
()
def
_freeze_parameters
(
self
):
for
param
in
self
.
parameters
():
param
.
requires_grad
=
False
self
.
_requires_grad
=
False
def
forward
(
def
forward
(
self
,
self
,
input_features
,
input_features
,
...
@@ -991,6 +996,13 @@ class WhisperModel(WhisperPreTrainedModel):
...
@@ -991,6 +996,13 @@ class WhisperModel(WhisperPreTrainedModel):
def
get_decoder
(
self
):
def
get_decoder
(
self
):
return
self
.
decoder
return
self
.
decoder
def freeze_encoder(self):
    """
    Calling this function will disable the gradient computation for the Whisper encoder so that its parameters will
    not be updated during training.

    Delegates to `self.encoder._freeze_parameters()`.
    """
    self.encoder._freeze_parameters()
@
add_start_docstrings_to_model_forward
(
WHISPER_INPUTS_DOCSTRING
)
@
add_start_docstrings_to_model_forward
(
WHISPER_INPUTS_DOCSTRING
)
@
add_code_sample_docstrings
(
@
add_code_sample_docstrings
(
processor_class
=
_PROCESSOR_FOR_DOC
,
processor_class
=
_PROCESSOR_FOR_DOC
,
...
@@ -1109,6 +1121,13 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
...
@@ -1109,6 +1121,13 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
def
set_output_embeddings
(
self
,
new_embeddings
):
def
set_output_embeddings
(
self
,
new_embeddings
):
self
.
proj_out
=
new_embeddings
self
.
proj_out
=
new_embeddings
def freeze_encoder(self):
    """
    Calling this function will disable the gradient computation for the Whisper encoder so that its parameters will
    not be updated during training.

    Delegates to `self.model.encoder._freeze_parameters()`; the encoder lives on the wrapped `self.model`.
    """
    self.model.encoder._freeze_parameters()
@
add_start_docstrings_to_model_forward
(
WHISPER_INPUTS_DOCSTRING
)
@
add_start_docstrings_to_model_forward
(
WHISPER_INPUTS_DOCSTRING
)
@
replace_return_docstrings
(
output_type
=
Seq2SeqLMOutput
,
config_class
=
_CONFIG_FOR_DOC
)
@
replace_return_docstrings
(
output_type
=
Seq2SeqLMOutput
,
config_class
=
_CONFIG_FOR_DOC
)
def
forward
(
def
forward
(
...
...
tests/models/whisper/test_modeling_whisper.py
View file @
bbd150e9
...
@@ -182,9 +182,12 @@ class WhisperModelTester:
...
@@ -182,9 +182,12 @@ class WhisperModelTester:
return
input_lengths
return
input_lengths
def
create_and_check_model_forward
(
self
,
config
,
inputs_dict
):
def
create_and_check_model_forward
(
self
,
config
,
inputs_dict
,
freeze_encoder
=
False
):
model
=
WhisperModel
(
config
=
config
).
to
(
torch_device
).
eval
()
model
=
WhisperModel
(
config
=
config
).
to
(
torch_device
).
eval
()
if
freeze_encoder
:
model
.
freeze_encoder
()
input_features
=
inputs_dict
[
"input_features"
]
input_features
=
inputs_dict
[
"input_features"
]
decoder_input_ids
=
inputs_dict
[
"decoder_input_ids"
]
decoder_input_ids
=
inputs_dict
[
"decoder_input_ids"
]
...
@@ -289,6 +292,26 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCas
...
@@ -289,6 +292,26 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCas
config_and_inputs
=
self
.
model_tester
.
prepare_config_and_inputs
()
config_and_inputs
=
self
.
model_tester
.
prepare_config_and_inputs
()
self
.
model_tester
.
create_and_check_model_forward
(
*
config_and_inputs
)
self
.
model_tester
.
create_and_check_model_forward
(
*
config_and_inputs
)
def test_model_forward_with_frozen_encoder(self):
    """Run the model tester's forward check with `freeze_encoder=True`."""
    config_and_inputs = self.model_tester.prepare_config_and_inputs()
    self.model_tester.create_and_check_model_forward(*config_and_inputs, freeze_encoder=True)
def test_requires_grad_with_frozen_encoder(self):
    """
    Check that `freeze_encoder()` freezes every encoder parameter and leaves the decoder trainable.

    For each class in `self.all_model_classes`, a model is built from the tester's config, its encoder is
    frozen, and the `requires_grad` flags of the encoder and decoder parameters are inspected.
    """
    config = self.model_tester.get_config()
    for model_class in self.all_model_classes:
        model = model_class(config)
        model.freeze_encoder()

        # Some classes expose `encoder`/`decoder` directly; others hold them on a wrapped `.model`.
        base_model = model if hasattr(model, "encoder") else model.model
        encoder_grads = [param.requires_grad for param in base_model.encoder.parameters()]
        decoder_grads = [param.requires_grad for param in base_model.decoder.parameters()]

        # `any`, not `all`: `assertFalse(all(...))` would pass as soon as a single encoder
        # parameter is frozen, which does not prove that the whole encoder is frozen.
        self.assertFalse(any(encoder_grads))
        self.assertTrue(all(decoder_grads))
def
test_decoder_model_past_with_large_inputs
(
self
):
def
test_decoder_model_past_with_large_inputs
(
self
):
config_and_inputs
=
self
.
model_tester
.
prepare_config_and_inputs
()
config_and_inputs
=
self
.
model_tester
.
prepare_config_and_inputs
()
self
.
model_tester
.
create_and_check_decoder_model_past_large_inputs
(
*
config_and_inputs
)
self
.
model_tester
.
create_and_check_decoder_model_past_large_inputs
(
*
config_and_inputs
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment