OpenDAS / OpenFold
Commit bb6f6145, authored Dec 30, 2021 by Gustaf Ahdritz
Slightly refactor recycling code to improve FP16 stability
Parent: b47138dc
Showing 1 changed file with 41 additions and 31 deletions.
openfold/model/model.py
@@ -197,43 +197,52 @@ class AlphaFold(nn.Module):
             feats["msa_feat"],
         )

-        # Inject information from previous recycling iterations
-        if _recycle:
-            # Initialize the recycling embeddings, if needs be
-            if None in [m_1_prev, z_prev, x_prev]:
-                # [*, N, C_m]
-                m_1_prev = m.new_zeros(
-                    (*batch_dims, n, self.config.input_embedder.c_m),
-                )
+        # Initialize the recycling embeddings, if needs be
+        if None in [m_1_prev, z_prev, x_prev]:
+            # [*, N, C_m]
+            m_1_prev = m.new_zeros(
+                (*batch_dims, n, self.config.input_embedder.c_m),
+                requires_grad=False,
+            )

-                # [*, N, N, C_z]
-                z_prev = z.new_zeros(
-                    (*batch_dims, n, n, self.config.input_embedder.c_z),
-                )
+            # [*, N, N, C_z]
+            z_prev = z.new_zeros(
+                (*batch_dims, n, n, self.config.input_embedder.c_z),
+                requires_grad=False,
+            )

-                # [*, N, 3]
-                x_prev = z.new_zeros(
-                    (*batch_dims, n, residue_constants.atom_type_num, 3),
-                )
+            # [*, N, 3]
+            x_prev = z.new_zeros(
+                (*batch_dims, n, residue_constants.atom_type_num, 3),
+                requires_grad=False,
+            )

-            x_prev = pseudo_beta_fn(
-                feats["aatype"], x_prev, None
-            )
+        x_prev = pseudo_beta_fn(
+            feats["aatype"], x_prev, None
+        )

-            # m_1_prev_emb: [*, N, C_m]
-            # z_prev_emb: [*, N, N, C_z]
-            m_1_prev_emb, z_prev_emb = self.recycling_embedder(
-                m_1_prev,
-                z_prev,
-                x_prev,
-            )
+        # m_1_prev_emb: [*, N, C_m]
+        # z_prev_emb: [*, N, N, C_z]
+        m_1_prev_emb, z_prev_emb = self.recycling_embedder(
+            m_1_prev,
+            z_prev,
+            x_prev,
+        )

-            # [*, S_c, N, C_m]
-            m[..., 0, :, :] = m[..., 0, :, :] + m_1_prev_emb
+        # If the number of recycling iterations is 0, skip recycling
+        # altogether. We zero them this way instead of computing them
+        # conditionally to avoid leaving parameters unused, which has annoying
+        # implications for DDP training.
+        if(not _recycle):
+            m_1_prev_emb *= 0
+            z_prev_emb *= 0

-            # [*, N, N, C_z]
-            z = z + z_prev_emb
+        # [*, S_c, N, C_m]
+        m[..., 0, :, :] += m_1_prev_emb

-            # Possibly prevents memory fragmentation
-            del m_1_prev, z_prev, x_prev, m_1_prev_emb, z_prev_emb
+        # [*, N, N, C_z]
+        z += z_prev_emb
+
+        # Possibly prevents memory fragmentation
+        del m_1_prev, z_prev, x_prev, m_1_prev_emb, z_prev_emb

         # Embed the templates + merge with MSA/pair embeddings
         if self.config.template.enabled:
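The new comment in this hunk explains the motivation: when recycling is disabled, the recycling embedder's outputs are multiplied by zero rather than skipped, so its parameters still participate in the backward pass and DDP never complains about unused parameters. The following standalone sketch is not part of the commit; it uses a plain nn.Linear as a stand-in for the recycling embedder to illustrate that behavior.

# Not part of this commit: toy stand-in (nn.Linear) for the recycling
# embedder, showing that multiplying its output by zero still populates
# parameter gradients, so DDP never sees "unused" parameters.
import torch
import torch.nn as nn

embedder = nn.Linear(8, 8)      # stand-in for the recycling embedder
x = torch.randn(4, 8)
_recycle = False                # e.g. zero recycling iterations

emb = embedder(x)
if not _recycle:
    emb = emb * 0               # contributes nothing numerically, stays in graph

emb.sum().backward()
print(embedder.weight.grad is None)        # False: gradients exist
print(embedder.weight.grad.abs().sum())    # tensor(0.)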
@@ -408,11 +417,12 @@ class AlphaFold(nn.Module):
             # Enable grad iff we're training and it's the final recycling layer
             is_final_iter = cycle_no == (num_iters - 1)
             with torch.set_grad_enabled(is_grad_enabled and is_final_iter):
-                # Sidestep AMP bug (PyTorch issue #65766)
                 if is_final_iter:
                     self._enable_activation_checkpointing()
+
+                    # Sidestep AMP bug (PyTorch issue #65766)
                     if torch.is_autocast_enabled():
                         torch.clear_autocast_cache()

                 # Run the next iteration of the model
                 outputs, m_1_prev, z_prev, x_prev = self.iteration(
                     feats,
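For context only (not part of the commit), here is a minimal sketch of the recycling loop this hunk lives in, with a generic iteration_fn standing in for self.iteration; the real code also enables activation checkpointing on the final iteration.

# Not part of this commit: minimal sketch of the recycling-loop pattern,
# with `iteration_fn` as a hypothetical stand-in for self.iteration.
import torch

def run_with_recycling(iteration_fn, feats_per_cycle, training=True):
    m_1_prev = z_prev = x_prev = outputs = None
    num_iters = len(feats_per_cycle)
    for cycle_no, feats in enumerate(feats_per_cycle):
        # Gradients are only needed on the final recycling iteration
        is_final_iter = cycle_no == (num_iters - 1)
        with torch.set_grad_enabled(training and is_final_iter):
            if is_final_iter and torch.is_autocast_enabled():
                # Sidestep AMP bug (PyTorch issue #65766): drop FP16 casts
                # cached under no_grad before the grad-enabled pass
                torch.clear_autocast_cache()
            outputs, m_1_prev, z_prev, x_prev = iteration_fn(
                feats, m_1_prev, z_prev, x_prev, _recycle=(num_iters > 1)
            )
    return outputs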