chenpangpang / transformers

Unverified commit 6ce6a5ad, authored Sep 26, 2023 by sanjeevk-os, committed by GitHub on Sep 26, 2023

added support for gradient checkpointing in ESM models (#26386)

parent a8531f3b
Showing 2 changed files with 27 additions and 6 deletions (+27 -6)
src/transformers/models/esm/modeling_esm.py    +5 -6
tests/models/esm/test_modeling_esm.py          +22 -0
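For context, gradient checkpointing trades extra compute for memory: intermediate activations are recomputed during the backward pass instead of being stored. With this commit, enabling it on an ESM model goes through the standard `gradient_checkpointing_enable()` API. A minimal sketch of the user-facing effect (the checkpoint name and toy sequence are illustrative, not part of the commit):

    from transformers import AutoTokenizer, EsmForMaskedLM

    # Illustrative ESM-2 checkpoint; any ESM checkpoint behaves the same way.
    checkpoint = "facebook/esm2_t6_8M_UR50D"
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = EsmForMaskedLM.from_pretrained(checkpoint)

    # Before this commit, ESM models did not advertise gradient-checkpointing
    # support, so this call would fail; now it flips the flag on EsmEncoder.
    model.gradient_checkpointing_enable()
    model.train()

    inputs = tokenizer("MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ", return_tensors="pt")
    outputs = model(**inputs, labels=inputs["input_ids"])
    outputs.loss.backward()  # encoder activations are recomputed during this backward pass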
src/transformers/models/esm/modeling_esm.py (view file @ 6ce6a5ad)
@@ -690,6 +690,7 @@ class EsmPreTrainedModel(PreTrainedModel):
     config_class = EsmConfig
     base_model_prefix = "esm"
+    supports_gradient_checkpointing = True
     _no_split_modules = ["EsmLayer", "EsmFoldTriangularSelfAttentionBlock", "EsmEmbeddings"]

     # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights
@@ -709,6 +710,10 @@ class EsmPreTrainedModel(PreTrainedModel):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)

+    def _set_gradient_checkpointing(self, module, value=False):
+        if isinstance(module, EsmEncoder):
+            module.gradient_checkpointing = value
+

 ESM_START_DOCSTRING = r"""
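The hook added above is not called by user code directly; around this release, `PreTrainedModel.gradient_checkpointing_enable()` walks the module tree and applies the hook to every submodule. A simplified sketch of that dispatch, under the assumption that the base class follows the `apply`-based pattern (not the verbatim library code):

    from functools import partial

    # Simplified sketch of the base-class behaviour the new hook plugs into.
    def gradient_checkpointing_enable(self):
        if not self.supports_gradient_checkpointing:
            raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
        # apply() visits every submodule; the isinstance(module, EsmEncoder) check
        # in the hook means only the encoder's gradient_checkpointing flag is set.
        self.apply(partial(self._set_gradient_checkpointing, value=True))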
@@ -785,8 +790,6 @@ class EsmModel(EsmPreTrainedModel):
     `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass.
     """

-    supports_gradient_checkpointing = False
-
     def __init__(self, config, add_pooling_layer=True):
         super().__init__(config)
         self.config = config
@@ -803,10 +806,6 @@ class EsmModel(EsmPreTrainedModel):
         # Initialize weights and apply final processing
         self.post_init()

-    def _set_gradient_checkpointing(self, module, value=False):
-        if isinstance(module, EsmEncoder):
-            module.gradient_checkpointing = value
-
     def get_input_embeddings(self):
         return self.embeddings.word_embeddings
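The flag set by the hook is consumed by `EsmEncoder.forward` (unchanged in this commit): when it is true and the model is training, each layer runs through `torch.utils.checkpoint.checkpoint`, so its activations are freed after the forward pass and recomputed during backward. A self-contained toy illustration of that mechanism, with plain `nn.Linear` layers standing in for ESM layers:

    import torch
    import torch.nn as nn
    from torch.utils.checkpoint import checkpoint

    # Toy stand-in for EsmEncoder's layer loop; not the real forward.
    layers = nn.ModuleList([nn.Linear(8, 8) for _ in range(3)])
    gradient_checkpointing = True  # what _set_gradient_checkpointing(value=True) enables

    hidden_states = torch.randn(2, 8, requires_grad=True)
    for layer_module in layers:
        if gradient_checkpointing:
            # The layer is re-run during backward instead of storing activations.
            hidden_states = checkpoint(layer_module, hidden_states)
        else:
            hidden_states = layer_module(hidden_states)

    hidden_states.sum().backward()  # gradients match the non-checkpointed path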
tests/models/esm/test_modeling_esm.py (view file @ 6ce6a5ad)
@@ -151,6 +151,24 @@ class EsmModelTester:
         result = model(input_ids, attention_mask=input_mask, labels=token_labels)
         self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.num_labels))

+    def create_and_check_forward_and_backwards(
+        self,
+        config,
+        input_ids,
+        input_mask,
+        sequence_labels,
+        token_labels,
+        choice_labels,
+        gradient_checkpointing=False,
+    ):
+        model = EsmForMaskedLM(config)
+        if gradient_checkpointing:
+            model.gradient_checkpointing_enable()
+        model.to(torch_device)
+        result = model(input_ids, attention_mask=input_mask, labels=token_labels)
+        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
+        result.loss.backward()
+
     def prepare_config_and_inputs_for_common(self):
         config_and_inputs = self.prepare_config_and_inputs()
         (
@@ -219,6 +237,10 @@ class EsmModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_for_token_classification(*config_and_inputs)

+    def test_esm_gradient_checkpointing(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_forward_and_backwards(*config_and_inputs, gradient_checkpointing=True)
+
     @slow
     def test_model_from_pretrained(self):
         for model_name in ESM_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
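The new test can be run on its own from the repository root with a standard pytest invocation, for example:

    python -m pytest tests/models/esm/test_modeling_esm.py -k test_esm_gradient_checkpointing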