OpenDAS / OpenFold — commit bb6f6145

Commit bb6f6145, authored Dec 30, 2021 by Gustaf Ahdritz

    Slightly refactor recycling code to improve FP16 stability

Parent: b47138dc

Showing 1 changed file with 41 additions and 31 deletions.

openfold/model/model.py  (+41, -31)
@@ -197,23 +197,24 @@ class AlphaFold(nn.Module):
             feats["msa_feat"],
         )
 
-        # Inject information from previous recycling iterations
-        if _recycle:
         # Initialize the recycling embeddings, if needs be
         if None in [m_1_prev, z_prev, x_prev]:
             # [*, N, C_m]
             m_1_prev = m.new_zeros(
                 (*batch_dims, n, self.config.input_embedder.c_m),
+                requires_grad=False,
             )
 
             # [*, N, N, C_z]
             z_prev = z.new_zeros(
                 (*batch_dims, n, n, self.config.input_embedder.c_z),
+                requires_grad=False,
             )
 
             # [*, N, 3]
             x_prev = z.new_zeros(
                 (*batch_dims, n, residue_constants.atom_type_num, 3),
+                requires_grad=False,
             )
 
         x_prev = pseudo_beta_fn(feats["aatype"], x_prev, None)
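A note on the initialization above: torch.Tensor.new_zeros creates the placeholder with the same dtype and device as the reference tensor, so under FP16/autocast the recycling inputs start out in the working precision without explicit casts. A minimal sketch of that behavior (illustrative only, not part of the commit; the shapes are hypothetical):

    import torch

    # Reference tensor standing in for the MSA embedding "m"
    m = torch.randn(1, 64, 256, dtype=torch.float16)

    # new_zeros inherits dtype and device from `m`, as in the hunk above
    m_1_prev = m.new_zeros((1, 256, 256), requires_grad=False)

    print(m_1_prev.dtype)   # torch.float16
    print(m_1_prev.device)  # same device as m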
@@ -226,11 +227,19 @@ class AlphaFold(nn.Module):
             x_prev,
         )
 
+        # If the number of recycling iterations is 0, skip recycling
+        # altogether. We zero them this way instead of computing them
+        # conditionally to avoid leaving parameters unused, which has annoying
+        # implications for DDP training.
+        if(not _recycle):
+            m_1_prev_emb *= 0
+            z_prev_emb *= 0
+
         # [*, S_c, N, C_m]
-        m[..., 0, :, :] = m[..., 0, :, :] + m_1_prev_emb
+        m[..., 0, :, :] += m_1_prev_emb
 
         # [*, N, N, C_z]
-        z = z + z_prev_emb
+        z += z_prev_emb
 
         # Possibly prevents memory fragmentation
         del m_1_prev, z_prev, x_prev, m_1_prev_emb, z_prev_emb
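The added comment refers to a DistributedDataParallel constraint: parameters whose outputs never reach the loss in a given forward pass produce no gradients, which DDP (with its default settings) treats as an error about unused parameters. Always running the recycling embedder and zeroing its output keeps every parameter in the autograd graph. A minimal sketch of the same pattern, using a hypothetical nn.Linear in place of the recycling embedder:

    import torch
    import torch.nn as nn

    emb = nn.Linear(8, 8)    # stands in for the recycling embedder
    x = torch.randn(4, 8)
    _recycle = False         # pretend recycling is disabled this step

    out = emb(x)             # run the module unconditionally
    if not _recycle:
        out = out * 0        # cancel the contribution; gradients still flow (as zeros)

    out.sum().backward()
    print(emb.weight.grad is not None)  # True: no "unused parameter" for DDP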
@@ -408,11 +417,12 @@ class AlphaFold(nn.Module):
             # Enable grad iff we're training and it's the final recycling layer
             is_final_iter = cycle_no == (num_iters - 1)
             with torch.set_grad_enabled(is_grad_enabled and is_final_iter):
-                # Sidestep AMP bug (PyTorch issue #65766)
                 if is_final_iter:
                     self._enable_activation_checkpointing()
+
+                    # Sidestep AMP bug (PyTorch issue #65766)
                     if torch.is_autocast_enabled():
                         torch.clear_autocast_cache()
 
                 # Run the next iteration of the model
                 outputs, m_1_prev, z_prev, x_prev = self.iteration(
                     feats,
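For context, the recycling loop in forward() only enables gradients on the final iteration; PyTorch issue #65766 concerns autocast's cast cache interacting badly with that change of grad mode, which the code above sidesteps by clearing the cache first. A rough, simplified sketch of the loop structure (not the actual forward pass):

    import torch

    is_grad_enabled = torch.is_grad_enabled()
    num_iters = 3  # hypothetical number of recycling iterations

    for cycle_no in range(num_iters):
        # Gradients only for the last recycling iteration (and only when training)
        is_final_iter = cycle_no == (num_iters - 1)
        with torch.set_grad_enabled(is_grad_enabled and is_final_iter):
            if is_final_iter and torch.is_autocast_enabled():
                # Drop autocast's cached casts made while grad was disabled
                torch.clear_autocast_cache()
            # ... run one iteration of the model here ...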