OpenDAS / OpenFold, commit de0fa7b1
Authored Sep 27, 2021 by Gustaf Ahdritz
Parent: 7d53297c

Fix tensor casting, improve TorchScript compatibility

Showing 6 changed files, with 38 additions and 14 deletions:

- openfold/model/evoformer.py (+1, -1)
- openfold/model/model.py (+24, -4)
- openfold/model/primitives.py (+1, -1)
- openfold/utils/deepspeed.py (+2, -1)
- openfold/utils/feats.py (+8, -5)
- openfold/utils/loss.py (+2, -2)
openfold/model/evoformer.py

@@ -430,7 +430,7 @@ class ExtraMSAStack(nn.Module):
                Optional [*, N_res, N_res] pair mask
        Returns:
            [*, N_res, N_res, C_z] pair update
        """
        _, z, _ = self.stack(m, z, ...
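A hedged reading of the `_, z, _` unpacking, given the commit message: TorchScript requires a function to have one static return type, so a stack that conditionally produced extra outputs is made to always return the full tuple, and callers discard what they do not need. A minimal sketch of that pattern (the module, names, and shapes are illustrative, not the repo's code):

    import torch
    import torch.nn as nn

    class ToyStack(nn.Module):
        def forward(self, m: torch.Tensor, z: torch.Tensor):
            # Always return the same 3-tuple so TorchScript can infer a
            # single static return type for every call site.
            s = m.mean(dim=-3)
            return m, z, s

    stack = torch.jit.script(ToyStack())
    _, z, _ = stack(torch.randn(2, 4, 8), torch.randn(4, 4, 8))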
openfold/model/model.py

@@ -43,7 +43,6 @@ from openfold.model.template import (
from openfold.utils.loss import (
    compute_plddt,
)
from openfold.utils.tensor_utils import (
    dict_multimap,
    tensor_tree_map,
...
@@ -162,7 +161,7 @@ class AlphaFold(nn.Module):
                z,
                template_mask=batch["template_mask"]
            )
-            t = t * torch.sum(batch["template_mask"]) > 0
+            t = t * (torch.sum(batch["template_mask"]) > 0)

        return {
            "template_angle_embedding": a,
...
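The added parentheses matter because Python's `*` binds tighter than `>`: the old line computed `(t * sum) > 0`, silently casting the template embedding to a boolean tensor. A standalone demonstration with toy tensors (not the model's):

    import torch

    t = torch.full((2, 2), 0.5)
    template_mask = torch.tensor([1.0, 0.0])

    # Old form: the product is evaluated first, then `> 0` turns the whole
    # result into a boolean tensor.
    wrong = t * torch.sum(template_mask) > 0
    # Fixed form: the boolean gate is computed first, then broadcast into
    # the product, preserving t's floating-point dtype.
    right = t * (torch.sum(template_mask) > 0)

    print(wrong.dtype)   # torch.bool
    print(right.dtype)   # torch.float32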
@@ -318,6 +317,22 @@ class AlphaFold(nn.Module):
        return outputs, m_1_prev, z_prev, x_prev

+    def _disable_activation_checkpointing(self):
+        self.template_pair_stack.blocks_per_ckpt = None
+        self.evoformer.blocks_per_ckpt = None
+        self.extra_msa_stack.stack.blocks_per_ckpt = None
+
+    def _enable_activation_checkpointing(self):
+        self.template_pair_stack.blocks_per_ckpt = (
+            self.config.template.template_pair_stack.blocks_per_ckpt
+        )
+        self.evoformer.blocks_per_ckpt = (
+            self.config.evoformer_stack.blocks_per_ckpt
+        )
+        self.extra_msa_stack.stack.blocks_per_ckpt = (
+            self.config.extra_msa.extra_msa_stack.blocks_per_ckpt
+        )
+
    def forward(self, batch):
        """
        Args:
...
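These helpers flip `blocks_per_ckpt` between `None` and the configured value. The rationale, hedged: activation checkpointing saves memory by re-running blocks during the backward pass, so in grad-free recycling iterations it would only add recompute overhead. A self-contained sketch of the same toggle pattern (the `ToyStack` module is illustrative, not the repo's):

    import torch
    import torch.nn as nn
    from torch.utils.checkpoint import checkpoint

    class ToyStack(nn.Module):
        def __init__(self):
            super().__init__()
            self.blocks = nn.ModuleList(nn.Linear(8, 8) for _ in range(4))
            self.blocks_per_ckpt = 1  # None disables checkpointing

        def forward(self, x):
            for block in self.blocks:
                if self.blocks_per_ckpt is not None and torch.is_grad_enabled():
                    # Store no activations here; recompute them in backward.
                    x = checkpoint(block, x)
                else:
                    x = block(x)
            return x

    model = ToyStack()
    model.blocks_per_ckpt = None            # cheap early recycling passes
    with torch.no_grad():
        _ = model(torch.randn(2, 8))
    model.blocks_per_ckpt = 1               # final, grad-tracked pass
    out = model(torch.randn(2, 8, requires_grad=True))
    out.sum().backward()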
@@ -368,9 +383,12 @@ class AlphaFold(nn.Module):
            "template_pseudo_beta_mask" ([*, N_templ, N_res])
                Pseudo-beta mask
        """
-        # Recycling embeddings
+        # Initialize recycling embeddings
        m_1_prev, z_prev, x_prev = None, None, None

+        # Disable activation checkpointing until the final recycling layer
+        self._disable_activation_checkpointing()
+
        # Main recycling loop
        for cycle_no in range(self.config.no_cycles):
            # Select the features for the current recycling cycle
...
@@ -379,9 +397,11 @@ class AlphaFold(nn.Module):
            # Enable grad iff we're training and it's the final recycling layer
            is_final_iter = (cycle_no == self.config.no_cycles - 1)
+            if(self.training and is_final_iter):
+                self._enable_activation_checkpointing()
            with torch.set_grad_enabled(self.training and is_final_iter):
                outputs, m_1_prev, z_prev, x_prev = self.iteration(
                    feats, m_1_prev, z_prev, x_prev,
                )

        outputs.update(self.aux_heads(outputs))
...
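Putting the pieces together: only the final recycling iteration runs under gradient tracking, and per the added `if`, only then is checkpointing re-enabled. A simplified skeleton of the loop under those assumptions, with `iteration` standing in for the model's per-cycle function:

    import torch

    def run_recycling(iteration, feats, no_cycles: int, training: bool):
        m_1_prev, z_prev, x_prev = None, None, None
        outputs = None
        for cycle_no in range(no_cycles):
            is_final_iter = cycle_no == no_cycles - 1
            # Gradients (and checkpointing) are only worth paying for on the
            # final pass; earlier cycles run at inference cost.
            with torch.set_grad_enabled(training and is_final_iter):
                outputs, m_1_prev, z_prev, x_prev = iteration(
                    feats, m_1_prev, z_prev, x_prev
                )
        return outputs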
openfold/model/primitives.py

@@ -313,7 +313,7 @@ class GlobalAttention(nn.Module):
        # [*, N_res, H * C_hidden]
        q = self.linear_q(q)

-        q = q * self.c_hidden ** (-0.5)
+        q = q * (self.c_hidden ** (-0.5))

        # [*, N_res, H, C_hidden]
        q = q.view(q.shape[:-1] + (self.no_heads, -1))
...
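Unlike the `>` fix above, this change is not behavioral: `**` already binds tighter than `*`, so both forms scale the queries by c_hidden^(-1/2). The extra parentheses make the grouping explicit, which plausibly helps TorchScript's static handling of the mixed int/float expression (a hedged guess at intent, given the commit message). A quick check in plain Python:

    c_hidden = 32
    q = 2.0
    implicit = q * c_hidden ** (-0.5)    # parsed as q * (c_hidden ** -0.5)
    explicit = q * (c_hidden ** (-0.5))
    assert implicit == explicit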
openfold/utils/deepspeed.py

@@ -70,7 +70,8 @@ def checkpoint_blocks(
    for s in range(0, len(blocks), blocks_per_ckpt):
        e = s + blocks_per_ckpt
-        args = deepspeed.checkpointing.checkpoint(chunker(s, e), *args)
+        args = checkpoint(chunker(s, e), *args)
+        #args = deepspeed.checkpointing.checkpoint(chunker(s, e), *args)
        args = wrap(args)

    return args
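The edit swaps `deepspeed.checkpointing.checkpoint` for a `checkpoint` function, presumably `torch.utils.checkpoint.checkpoint` imported outside this hunk, keeping the DeepSpeed call commented for reference. A self-contained sketch of the same chunked-checkpointing idea, assuming each block maps a tuple of tensors to a tuple of tensors (this is not the repo's `checkpoint_blocks`):

    import torch
    from torch.utils.checkpoint import checkpoint

    def checkpoint_block_groups(blocks, args, blocks_per_ckpt):
        def chunker(s, e):
            def run(*a):
                for block in blocks[s:e]:
                    a = block(*a)
                return a
            return run

        # Checkpoint `blocks_per_ckpt` blocks at a time: activations inside
        # a group are recomputed during backward instead of being stored.
        for s in range(0, len(blocks), blocks_per_ckpt):
            e = s + blocks_per_ckpt
            args = checkpoint(chunker(s, e), *args)
        return args

    blocks = [lambda x: (torch.relu(x) * 2,) for _ in range(4)]
    x = torch.randn(2, 3, requires_grad=True)
    (out,) = checkpoint_block_groups(blocks, (x,), blocks_per_ckpt=2)
    out.sum().backward()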
openfold/utils/feats.py

@@ -279,7 +279,8 @@ def atom37_to_torsion_angles(
    )

    torsion_angles_sin_cos = torch.stack(
-        [fourth_atom_rel_pos[..., 2], fourth_atom_rel_pos[..., 1]], dim=-1
+        [fourth_atom_rel_pos[..., 2], fourth_atom_rel_pos[..., 1]],
+        dim=-1,
    )

    denom = torch.sqrt(
        torch.sum(
            torch.square(torsion_angles_sin_cos),
            dim=-1,
            keepdims=True
...
@@ -336,7 +337,7 @@ def atom37_to_frames(
    restype_rigidgroup_mask = torch.zeros(
        (*aatype.shape[:-1], 21, 8),
-        dtype=torch.float,
+        dtype=all_atom_mask.dtype,
        device=aatype.device,
        requires_grad=False,
    )
...
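This is the "tensor casting" half of the commit message: a hard-coded `torch.float` (like the `.float()` call removed in the next hunk) would silently up-cast when the surrounding model runs in half precision. Deriving the dtype from an input tensor keeps newly created tensors consistent. A small illustration (`make_group_mask` is a hypothetical helper, not the repo's):

    import torch

    def make_group_mask(all_atom_mask: torch.Tensor) -> torch.Tensor:
        # Inherit fp16/bf16/fp32 from the caller instead of forcing fp32.
        return torch.zeros(
            (*all_atom_mask.shape[:-1], 21, 8),
            dtype=all_atom_mask.dtype,
            device=all_atom_mask.device,
        )

    mask = make_group_mask(torch.ones(4, 37, dtype=torch.half))
    print(mask.dtype)  # torch.float16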
@@ -390,14 +391,16 @@ def atom37_to_frames(
    )

    gt_atoms_exist = batched_gather(
-        all_atom_mask.float(),
+        all_atom_mask,
        residx_rigidgroup_base_atom37_idx,
        dim=-1,
        no_batch_dims=len(all_atom_mask.shape[:-1]),
    )

    gt_exists = torch.min(gt_atoms_exist, dim=-1)[0] * group_exists

-    rots = torch.eye(3, device=aatype.device, requires_grad=False)
+    rots = torch.eye(
+        3, dtype=all_atom_mask.dtype, device=aatype.device, requires_grad=False
+    )
    rots = torch.tile(rots, (*((1,) * batch_dims), 8, 1, 1))
    rots[..., 0, 0, 0] = -1
    rots[..., 0, 2, 2] = -1
...
@@ -408,7 +411,7 @@ def atom37_to_frames(
        *((1,) * batch_dims), 21, 8
    )
    restype_rigidgroup_rots = torch.eye(
-        3, device=aatype.device, requires_grad=False
+        3, dtype=all_atom_mask.dtype, device=aatype.device, requires_grad=False
    )
    restype_rigidgroup_rots = torch.tile(
        restype_rigidgroup_rots,
...
openfold/utils/loss.py

@@ -1385,9 +1385,9 @@ class AlphaFoldLoss(nn.Module):
        for k, loss_fn in loss_fns.items():
            weight = self.config[k].weight
            if(weight):
-                print(k)
                loss = loss_fn()
-                print(loss)
+                #print(k)
+                #print(loss)
                cum_loss += weight * loss

        return cum_loss