Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
2e17db8a
Unverified
Commit
2e17db8a
authored
Nov 22, 2022
by
Younes Belkada
Committed by
GitHub
Nov 22, 2022
Browse files
[ESM] fix `accelerate` tests for esmfold (#20387)
* fix `accelerate` tests for esmfold
* cleaner solution
parent
d2357a01
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
7 additions
and
4 deletions
+7
-4
src/transformers/models/esm/modeling_esm.py
src/transformers/models/esm/modeling_esm.py
+1
-1
src/transformers/models/esm/modeling_esmfold.py
src/transformers/models/esm/modeling_esmfold.py
+6
-3
No files found.
src/transformers/models/esm/modeling_esm.py
View file @
2e17db8a
...
...
@@ -638,7 +638,7 @@ class EsmPreTrainedModel(PreTrainedModel):
config_class
=
EsmConfig
base_model_prefix
=
"esm"
_no_split_modules
=
[
"EsmLayer"
]
_no_split_modules
=
[
"EsmLayer"
,
"EsmFoldTriangularSelfAttentionBlock"
]
# Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights
def
_init_weights
(
self
,
module
):
...
...
src/transformers/models/esm/modeling_esmfold.py
View file @
2e17db8a
...
...
@@ -1956,9 +1956,9 @@ class EsmFoldingTrunk(nn.Module):
for
recycle_idx
in
range
(
no_recycles
):
with
ContextManagers
([]
if
recycle_idx
==
no_recycles
-
1
else
[
torch
.
no_grad
()]):
# === Recycling ===
recycle_s
=
self
.
recycle_s_norm
(
recycle_s
.
detach
())
recycle_z
=
self
.
recycle_z_norm
(
recycle_z
.
detach
())
recycle_z
+=
self
.
recycle_disto
(
recycle_bins
.
detach
())
recycle_s
=
self
.
recycle_s_norm
(
recycle_s
.
detach
())
.
to
(
device
)
recycle_z
=
self
.
recycle_z_norm
(
recycle_z
.
detach
())
.
to
(
device
)
recycle_z
+=
self
.
recycle_disto
(
recycle_bins
.
detach
())
.
to
(
device
)
s_s
,
s_z
=
trunk_iter
(
s_s_0
+
recycle_s
,
s_z_0
+
recycle_z
,
residx
,
mask
)
...
...
@@ -2207,6 +2207,9 @@ class EsmForProteinFolding(EsmPreTrainedModel):
return
EsmForProteinFoldingOutput
(
**
structure
)
def af2_idx_to_esm_idx(self, aa, mask):
    """Translate AlphaFold2 residue indices into ESM vocabulary indices.

    Shifts ``aa`` by one (reserving 0 for masked/padding positions), zeroes
    out positions where ``mask != 1``, and maps the result through the
    ``self.af2_to_esm`` lookup table.

    Args:
        aa: integer tensor of AF2 residue indices.
        mask: tensor of the same shape; positions with value 1 are kept.

    Returns:
        Tensor of ESM token indices, same shape as ``aa``.
    """
    # Move the lookup table alongside the input once, so that the advanced
    # indexing below never mixes tensors living on different devices.
    if self.af2_to_esm.device != aa.device:
        self.af2_to_esm = self.af2_to_esm.to(aa.device)
    shifted = (aa + 1).masked_fill(mask != 1, 0)
    return self.af2_to_esm[shifted]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment