Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
bb6f6145
Commit
bb6f6145
authored
Dec 30, 2021
by
Gustaf Ahdritz
Browse files
Slightly refactor recycling code to improve FP16 stability
parent
b47138dc
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
41 additions
and
31 deletions
+41
-31
openfold/model/model.py
openfold/model/model.py
+41
-31
No files found.
openfold/model/model.py
View file @
bb6f6145
...
@@ -197,43 +197,52 @@ class AlphaFold(nn.Module):
...
@@ -197,43 +197,52 @@ class AlphaFold(nn.Module):
feats
[
"msa_feat"
],
feats
[
"msa_feat"
],
)
)
# Inject information from previous recycling iterations
# Initialize the recycling embeddings, if needs be
if
_recycle
:
if
None
in
[
m_1_prev
,
z_prev
,
x_prev
]:
# Initialize the recycling embeddings, if needs be
# [*, N, C_m]
if
None
in
[
m_1_prev
,
z_prev
,
x_prev
]:
m_1_prev
=
m
.
new_zeros
(
# [*, N, C_m]
(
*
batch_dims
,
n
,
self
.
config
.
input_embedder
.
c_m
),
m_1_prev
=
m
.
new_zeros
(
requires_grad
=
False
,
(
*
batch_dims
,
n
,
self
.
config
.
input_embedder
.
c_m
),
)
)
# [*, N, N, C_z]
# [*, N, N, C_z]
z_prev
=
z
.
new_zeros
(
z_prev
=
z
.
new_zeros
(
(
*
batch_dims
,
n
,
n
,
self
.
config
.
input_embedder
.
c_z
),
(
*
batch_dims
,
n
,
n
,
self
.
config
.
input_embedder
.
c_z
),
)
requires_grad
=
False
,
)
# [*, N, 3]
# [*, N, 3]
x_prev
=
z
.
new_zeros
(
x_prev
=
z
.
new_zeros
(
(
*
batch_dims
,
n
,
residue_constants
.
atom_type_num
,
3
),
(
*
batch_dims
,
n
,
residue_constants
.
atom_type_num
,
3
),
)
requires_grad
=
False
,
)
x_prev
=
pseudo_beta_fn
(
feats
[
"aatype"
],
x_prev
,
None
)
x_prev
=
pseudo_beta_fn
(
feats
[
"aatype"
],
x_prev
,
None
)
# m_1_prev_emb: [*, N, C_m]
# m_1_prev_emb: [*, N, C_m]
# z_prev_emb: [*, N, N, C_z]
# z_prev_emb: [*, N, N, C_z]
m_1_prev_emb
,
z_prev_emb
=
self
.
recycling_embedder
(
m_1_prev_emb
,
z_prev_emb
=
self
.
recycling_embedder
(
m_1_prev
,
m_1_prev
,
z_prev
,
z_prev
,
x_prev
,
x_prev
,
)
)
# [*, S_c, N, C_m]
# If the number of recycling iterations is 0, skip recycling
m
[...,
0
,
:,
:]
=
m
[...,
0
,
:,
:]
+
m_1_prev_emb
# altogether. We zero them this way instead of computing them
# conditionally to avoid leaving parameters unused, which has annoying
# implications for DDP training.
if
(
not
_recycle
):
m_1_prev_emb
*=
0
z_prev_emb
*=
0
# [*,
N
, N, C_
z
]
# [*,
S_c
, N, C_
m
]
z
=
z
+
z
_prev_emb
m
[...,
0
,
:,
:]
+=
m_1
_prev_emb
# Possibly prevents memory fragmentation
# [*, N, N, C_z]
del
m_1_prev
,
z_prev
,
x_prev
,
m_1_prev_emb
,
z_prev_emb
z
+=
z_prev_emb
# Possibly prevents memory fragmentation
del
m_1_prev
,
z_prev
,
x_prev
,
m_1_prev_emb
,
z_prev_emb
# Embed the templates + merge with MSA/pair embeddings
# Embed the templates + merge with MSA/pair embeddings
if
self
.
config
.
template
.
enabled
:
if
self
.
config
.
template
.
enabled
:
...
@@ -408,11 +417,12 @@ class AlphaFold(nn.Module):
...
@@ -408,11 +417,12 @@ class AlphaFold(nn.Module):
# Enable grad iff we're training and it's the final recycling layer
# Enable grad iff we're training and it's the final recycling layer
is_final_iter
=
cycle_no
==
(
num_iters
-
1
)
is_final_iter
=
cycle_no
==
(
num_iters
-
1
)
with
torch
.
set_grad_enabled
(
is_grad_enabled
and
is_final_iter
):
with
torch
.
set_grad_enabled
(
is_grad_enabled
and
is_final_iter
):
# Sidestep AMP bug (PyTorch issue #65766)
if
is_final_iter
:
if
is_final_iter
:
self
.
_enable_activation_checkpointing
()
self
.
_enable_activation_checkpointing
()
# Sidestep AMP bug (PyTorch issue #65766)
if
torch
.
is_autocast_enabled
():
if
torch
.
is_autocast_enabled
():
torch
.
clear_autocast_cache
()
torch
.
clear_autocast_cache
()
# Run the next iteration of the model
# Run the next iteration of the model
outputs
,
m_1_prev
,
z_prev
,
x_prev
=
self
.
iteration
(
outputs
,
m_1_prev
,
z_prev
,
x_prev
=
self
.
iteration
(
feats
,
feats
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment