OpenDAS / OpenFold
"git@developer.sourcefind.cn:OpenDAS/openpcdet.git" did not exist on "96dfcbd05101d0bc3cc6ad87548c4d69b475b989"
Commit f34bef8e, authored Oct 02, 2021 by Gustaf Ahdritz

Sidestep AMP bug

parent 304b5ff7
Showing 1 changed file with 14 additions and 23 deletions

openfold/model/model.py  +14  -23
@@ -235,6 +235,9 @@ class AlphaFold(nn.Module):
         # [*, N, N, C_z]
         z = z + z_prev_emb
 
+        # This can matter during inference when N_res is very large
+        del m_1_prev_emb, z_prev_emb
+
         # Embed the templates + merge with MSA/pair embeddings
         if(self.config.template.enabled):
             template_feats = {
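The added `del` is purely a memory optimization: by this point both recycling embeddings have already been added into the current representations, and the pair-shaped tensor scales with N_res squared, so dropping the last references lets that buffer be freed before the template stack allocates its own pair-sized tensors. A minimal sketch of the effect (illustrative shapes, not OpenFold code):

    import torch

    n_res, c_z = 256, 128                        # hypothetical sizes
    z = torch.randn(n_res, n_res, c_z)           # current pair representation
    z_prev_emb = torch.randn(n_res, n_res, c_z)  # recycled pair embedding

    z = z + z_prev_emb  # z_prev_emb has served its purpose
    del z_prev_emb      # drop the last reference so the [N, N, C_z] buffer can be freed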
@@ -332,16 +335,6 @@ class AlphaFold(nn.Module):
                 self.config.extra_msa.extra_msa_stack.blocks_per_ckpt
             )
 
-    def _disable_grad(self):
-        vals = [p.requires_grad for p in self.parameters()]
-        for p in self.parameters():
-            p.requires_grad_(False)
-        return vals
-
-    def _enable_grad(self, vals):
-        for p, v in zip(self.parameters(), vals):
-            p.requires_grad_(v)
-
     def forward(self, batch):
         """
         Args:
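The two deleted helpers controlled gradients by flipping `requires_grad` on every parameter and restoring the saved flags afterwards. As of this commit, `forward` scopes grad mode with `torch.set_grad_enabled` instead (next hunk), which leaves the parameter flags untouched and is undone automatically when the block exits. A rough sketch of the difference, using a toy module rather than OpenFold's model:

    import torch
    import torch.nn as nn

    model = nn.Linear(8, 8)
    x = torch.randn(2, 8)

    # Deleted approach: mutate every parameter's requires_grad flag, then restore it.
    saved = [p.requires_grad for p in model.parameters()]
    for p in model.parameters():
        p.requires_grad_(False)
    _ = model(x)  # parameters excluded from the autograd graph
    for p, v in zip(model.parameters(), saved):
        p.requires_grad_(v)

    # Replacement: scope grad mode with a context manager; parameter state is untouched.
    with torch.set_grad_enabled(False):
        _ = model(x)  # no autograd graph is recorded at all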
@@ -394,27 +387,25 @@ class AlphaFold(nn.Module):
         """
         # Initialize recycling embeddings
         m_1_prev, z_prev, x_prev = None, None, None
 
-        # Disable activation checkpointing until the final recycling layer
-        self._disable_activation_checkpointing()
-        grad_vals = self._disable_grad()
+        is_grad_enabled = torch.is_grad_enabled()
 
         # Main recycling loop
         for cycle_no in range(self.config.no_cycles):
             # Select the features for the current recycling cycle
             fetch_cur_batch = lambda t: t[..., cycle_no]
             feats = tensor_tree_map(fetch_cur_batch, batch)
 
             # Enable grad iff we're training and it's the final recycling layer
             is_final_iter = (cycle_no == (self.config.no_cycles - 1))
-            if(is_final_iter):
-                self._enable_activation_checkpointing()
-                self._enable_grad(grad_vals)
-
-            # Run the next iteration of the model
-            outputs, m_1_prev, z_prev, x_prev = self.iteration(
-                feats, m_1_prev, z_prev, x_prev,
-            )
+            with torch.set_grad_enabled(is_grad_enabled and is_final_iter):
+                # Sidestep AMP bug discussed in pytorch issue #65766
+                if(is_final_iter and torch.is_autocast_enabled()):
+                    torch.clear_autocast_cache()
+
+                # Run the next iteration of the model
+                outputs, m_1_prev, z_prev, x_prev = self.iteration(
+                    feats, m_1_prev, z_prev, x_prev,
+                )
 
         # Run auxiliary heads
         outputs.update(self.aux_heads(outputs))
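The restructured loop runs every recycling iteration except the last with grad disabled, then enables grad for the final pass (and only if it was enabled outside `forward`). The new `torch.clear_autocast_cache()` call is the actual sidestep: the referenced PyTorch issue (#65766) is, roughly, that downcast weight copies cached by autocast during the earlier no-grad iterations can be reused in the grad-enabled one, so the cache is emptied just before the final iteration to force fresh casts there. A self-contained sketch of the same pattern (the toy model, `num_cycles`, and the input tensor are placeholders, not OpenFold APIs; a CUDA device is assumed for autocast):

    import torch
    import torch.nn as nn

    model = nn.Linear(16, 16).cuda()
    feats = torch.randn(4, 16, device="cuda")
    num_cycles = 4  # stand-in for config.no_cycles

    is_grad_enabled = torch.is_grad_enabled()

    with torch.cuda.amp.autocast():
        out = feats
        for cycle_no in range(num_cycles):
            is_final_iter = cycle_no == (num_cycles - 1)
            # Grad is tracked only on the final iteration, and only if the caller had it on.
            with torch.set_grad_enabled(is_grad_enabled and is_final_iter):
                if is_final_iter and torch.is_autocast_enabled():
                    # Earlier no-grad iterations may have cached detached fp16 weight
                    # casts; clearing the cache makes the final iteration re-cast them.
                    torch.clear_autocast_cache()
                out = model(out)  # stand-in for self.iteration(...)

    out.float().sum().backward()  # parameters receive gradients from the final pass only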