OpenDAS / OpenFold

Commit 9ce8713c, authored Sep 22, 2021 by Gustaf Ahdritz
Parent: 754d2ba8

Improve memory efficiency, fix OpenMM CUDA + loss bugs

Showing 7 changed files with 135 additions and 111 deletions.
openfold/model/model.py                  +73  −37
openfold/model/msa.py                     +3   −3
openfold/model/triangular_attention.py    +0  −11
openfold/utils/feats.py                   +3  −20
openfold/utils/loss.py                   +12  −10
openfold/utils/tensor_utils.py           +17   −5
run_pretrained_alphafold.py              +27  −25
openfold/model/model.py

@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from functools import partial

 import torch
 import torch.nn as nn
@@ -44,6 +45,7 @@ from openfold.utils.loss import (
 )
 from openfold.utils.tensor_utils import (
+    dict_multimap,
     tensor_tree_map,
 )
@@ -103,48 +105,70 @@ class AlphaFold(nn.Module):
         self.config = config

-    def embed_templates(self, batch, z, pair_mask):
-        # Build template angle feats
-        angle_feats = atom37_to_torsion_angles(
-            batch["template_aatype"],
-            batch["template_all_atom_positions"],
-            batch["template_all_atom_masks"],
-            eps=1e-8
-        )
-
-        # Stow this away for later
-        batch["torsion_angles_mask"] = angle_feats["torsion_angles_mask"]
-
-        template_angle_feat = build_template_angle_feat(
-            angle_feats,
-            batch["template_aatype"],
-        )
-
-        # [*, S_t, N, C_m]
-        a = self.template_angle_embedder(template_angle_feat)
-
-        # [*, S_t, N, N, C_t]
-        t = build_template_pair_feat(
-            batch,
-            eps=self.config.template.eps,
-            **self.config.template.distogram
-        )
-        t = self.template_pair_embedder(t)
-        t = self.template_pair_stack(
-            t,
-            pair_mask.unsqueeze(-3),
-            _mask_trans=self.config._mask_trans
-        )
+    def embed_templates(self, batch, z, pair_mask, templ_dim):
+        # Embed the templates one at a time (with a poor man's vmap)
+        template_embeds = []
+        n_templ = batch["template_aatype"].shape[-2]
+        for i in range(n_templ):
+            idx = batch["template_aatype"].new_tensor(i)
+            single_template_feats = tensor_tree_map(
+                lambda t: torch.index_select(t, templ_dim, idx),
+                batch,
+            )
+
+            # Build template angle feats
+            angle_feats = atom37_to_torsion_angles(
+                single_template_feats["template_aatype"],
+                single_template_feats["template_all_atom_positions"],
+                single_template_feats["template_all_atom_masks"],
+                eps=1e-8
+            )
+
+            template_angle_feat = build_template_angle_feat(
+                angle_feats,
+                single_template_feats["template_aatype"],
+            )
+
+            # [*, S_t, N, C_m]
+            a = self.template_angle_embedder(template_angle_feat)
+
+            # [*, S_t, N, N, C_t]
+            t = build_template_pair_feat(
+                single_template_feats,
+                eps=self.config.template.eps,
+                **self.config.template.distogram
+            )
+            t = self.template_pair_embedder(t)
+            t = self.template_pair_stack(
+                t,
+                pair_mask.unsqueeze(-3),
+                _mask_trans=self.config._mask_trans
+            )
+
+            template_embeds.append({
+                "angle": a,
+                "pair": t,
+                "torsion_mask": angle_feats["torsion_angles_mask"],
+            })
+
+        template_embeds = dict_multimap(
+            partial(torch.cat, dim=templ_dim),
+            template_embeds,
+        )

         # [*, N, N, C_z]
         t = self.template_pointwise_att(
-            t,
+            template_embeds["pair"],
             z,
             template_mask=batch["template_mask"]
         )
         t *= torch.sum(batch["template_mask"]) > 0

-        return a, t
+        return {
+            "template_angle_embedding": template_embeds["angle"],
+            "template_pair_embedding": t,
+            "torsion_angles_mask": template_embeds["torsion_mask"],
+        }

     def forward(self, batch):
         """
@@ -210,6 +234,7 @@ class AlphaFold(nn.Module):
         # Grab some data about the input
         batch_dims = feats["target_feat"].shape[:-2]
+        no_batch_dims = len(batch_dims)
         n = feats["target_feat"].shape[-2]
         n_seq = feats["msa_feat"].shape[-3]
         device = feats["target_feat"].device
@@ -257,7 +282,7 @@ class AlphaFold(nn.Module):
             x_prev = pseudo_beta_fn(
                 feats["aatype"],
                 x_prev,
-                None  # TODO: figure this part out
+                None
             )

         # m_1_prev_emb: [*, N, C_m]
@@ -276,17 +301,28 @@ class AlphaFold(nn.Module):
         # Embed the templates + merge with MSA/pair embeddings
         if(self.config.template.enabled):
-            a, t = self.embed_templates(feats, z, pair_mask)
+            template_feats = {
+                k: v for k, v in feats.items() if "template_" in k
+            }
+            template_embeds = self.embed_templates(
+                template_feats,
+                z,
+                pair_mask,
+                no_batch_dims,
+            )

             # [*, N, N, C_z]
-            z += t
+            z += template_embeds["template_pair_embedding"]

             if(self.config.template.embed_angles):
                 # [*, S = S_c + S_t, N, C_m]
-                m = torch.cat([m, a], dim=-3)
+                m = torch.cat(
+                    [m, template_embeds["template_angle_embedding"]],
+                    dim=-3
+                )

                 # [*, S, N]
-                torsion_angles_mask = feats["torsion_angles_mask"]
+                torsion_angles_mask = template_embeds["torsion_angles_mask"]
                 msa_mask = torch.cat(
                     [feats["msa_mask"], torsion_angles_mask[..., 2]],
                     axis=-2
                 )
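Handing embed_templates only the "template_" keys keeps the per-template tensor_tree_map/index_select in the loop from traversing and slicing MSA-sized tensors. A quick sketch of the filter:

import torch

feats = {
    "msa_feat": torch.randn(128, 8, 49),           # large, left untouched
    "template_aatype": torch.zeros(4, 8).long(),
    "template_mask": torch.ones(4),
}
template_feats = {k: v for k, v in feats.items() if "template_" in k}
print(sorted(template_feats))  # ['template_aatype', 'template_mask']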
openfold/model/msa.py

@@ -106,6 +106,7 @@ class MSAAttention(nn.Module):
             (*((-1,) * len(bias.shape[:-4])), -1, self.no_heads, n_res, -1)
         )

+        biases = [bias]
         if(self.pair_bias):
             # [*, N_res, N_res, C_z]
             z = self.layer_norm_z(z)
@@ -116,14 +117,13 @@ class MSAAttention(nn.Module):
             # [*, 1, no_heads, N_res, N_res]
             z = permute_final_dims(z, 2, 0, 1).unsqueeze(-4)

             # [*, N_seq, no_heads, N_res, N_res]
-            bias = bias + z
+            biases.append(z)

         mha_inputs = {
             "q_x": m,
             "k_x": m,
             "v_x": m,
-            "biases": [bias]
+            "biases": biases,
         }
         if(not self.training and self.chunk_size is not None):
             m = chunk_layer(
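Summing the mask bias and the pair bias eagerly (the old bias = bias + z) materialized a full [*, N_seq, no_heads, N_res, N_res] tensor; passing the attention a list of bias terms lets the addition broadcast lazily, chunk by chunk. A sketch of the difference, with hypothetical shapes:

import torch

n_seq, no_heads, n_res = 64, 8, 128
mask_bias = torch.zeros(n_seq, 1, 1, n_res)         # broadcasts over heads/rows
pair_bias = torch.zeros(1, no_heads, n_res, n_res)  # broadcasts over sequences

# Old: the eager sum allocates the full broadcast result
summed = mask_bias + pair_bias
print(summed.numel())                  # 8388608 elements

# New: keep the factors and add them inside attention
biases = [mask_bias, pair_bias]
print(sum(b.numel() for b in biases))  # 139264 elements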
openfold/model/triangular_attention.py

@@ -96,17 +96,6 @@ class TriangleAttention(nn.Module):
         # [*, 1, H, I, J]
         triangle_bias = triangle_bias.unsqueeze(-4)

-        # Broadcasting and chunking doesn't really work yet (TODO)
-        # [*, I, H, I, J]
-        i = x.shape[-3]
-        triangle_bias = triangle_bias.expand(
-            (*((-1,) * len(triangle_bias.shape[:-4])), i, -1, -1, -1)
-        )
-
-        #print(x.shape)
-        #print(mask_bias.shape)
-        #print(triangle_bias.shape)
-
         mha_inputs = {
             "q_x": x,
             "k_x": x,
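The deleted expand was itself only a view, but downstream chunking had to reshape it, which forces a real [*, I, H, I, J] copy. Keeping the singleton dimension lets the bias apply by broadcasting instead, roughly:

import torch

h, i, j = 4, 64, 64
triangle_bias = torch.randn(1, h, i, j)  # [1, H, I, J], singleton kept
scores = torch.randn(i, h, i, j)         # [I, H, I, J] attention logits

# Broadcasting applies the bias without allocating an I-fold copy
out = scores + triangle_bias
print(out.shape)  # torch.Size([64, 4, 64, 64])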
openfold/utils/feats.py

@@ -486,7 +486,7 @@ def build_template_pair_feat(batch, min_bin, max_bin, no_bins, eps=1e-6, inf=1e8
     to_concat = [dgram, template_mask_2d[..., None]]

     aatype_one_hot = nn.functional.one_hot(
-        batch["template_aatype"], batch["target_feat"].shape[-1]
+        batch["template_aatype"], residue_constants.restype_num + 2,
     )

     n_res = batch["template_aatype"].shape[-1]

@@ -502,34 +502,17 @@
     )

     n, ca, c = [residue_constants.atom_order[a] for a in ['N', 'CA', 'C']]
-    #t_aa_pos = batch["template_all_atom_positions"]
-    #affines = T.make_transform_from_reference(
-    #    n_xyz=t_aa_pos[..., n],
-    #    ca_xyz=t_aa_pos[..., ca],
-    #    c_xyz=t_aa_pos[..., c],
-    #)
-    #rots = affines.rots
-    #trans = affines.trans
-    #affine_vec = rot_mul_vec(
-    #    rots.transpose(-1, -2),
-    #    trans[..., None, :, :] - trans[..., None, :],
-    #)
-    #inverted_dists = torch.rsqrt(eps + torch.sum(inverted_dists**2, dim=-1))

     t_aa_masks = batch["template_all_atom_masks"]
     template_mask = (
         t_aa_masks[..., n] * t_aa_masks[..., ca] * t_aa_masks[..., c]
     )
     template_mask_2d = template_mask[..., None] * template_mask[..., None, :]

-    #inverted_dists *= template_mask_2d
-    #unit_vector = affine_vec * inverted_dists.unsqueeze(-1)
-    #unit_vector = unit_vector.unsqueeze(-2)
     unit_vector = template_mask_2d.new_zeros(*template_mask_2d.shape, 3)
     to_concat.append(unit_vector)

     to_concat.append(template_mask_2d[..., None])

     act = torch.cat(to_concat, dim=-1)
     act *= template_mask_2d[..., None]
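With target_feat no longer present in the batch handed to build_template_pair_feat, the one-hot class count comes from residue_constants rather than being inferred from a tensor shape. A sketch, assuming restype_num = 20 as in AlphaFold's residue_constants (20 standard residues, plus unknown and gap for templates):

import torch
import torch.nn as nn

restype_num = 20
template_aatype = torch.randint(0, restype_num + 2, (4, 8))

aatype_one_hot = nn.functional.one_hot(template_aatype, restype_num + 2)
print(aatype_one_hot.shape)  # torch.Size([4, 8, 22])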
openfold/utils/loss.py

@@ -32,7 +32,7 @@ from openfold.utils.tensor_utils import (
 def softmax_cross_entropy(logits, labels):
     loss = -1 * torch.sum(
-        labels * torch.nn.functional.log_softmax(logits),
+        labels * torch.nn.functional.log_softmax(logits, dim=-1),
         dim=-1,
     )
     return loss
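Calling log_softmax without dim relies on PyTorch's deprecated implicit dimension choice, which picks dim 0 for 3-D inputs rather than the bin dimension that the outer sum reduces over. Pinning dim=-1 makes the two reductions agree:

import torch

logits = torch.randn(2, 5, 64)  # [..., no_bins]
labels = torch.nn.functional.one_hot(torch.zeros(2, 5, dtype=torch.long), 64)

loss = -1 * torch.sum(
    labels * torch.nn.functional.log_softmax(logits, dim=-1),
    dim=-1,
)
print(loss.shape)  # torch.Size([2, 5]): one cross entropy per entry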
@@ -219,18 +219,19 @@ def supervised_chi_loss(
     chi_weight: float,
     angle_norm_weight: float,
     eps=1e-6,
     **kwargs,
 ) -> torch.Tensor:
     pred_angles = angles_sin_cos[..., 3:, :]
     residue_type_one_hot = torch.nn.functional.one_hot(
         aatype,
         residue_constants.restype_num + 1,
-    ).unsqueeze(-3)
+    )
     chi_pi_periodic = torch.einsum(
         "...ij,jk->ik",
-        residue_type_one_hot,
-        aatype.new_tensor(residue_constants.chi_pi_periodic)
+        residue_type_one_hot.type(angles_sin_cos.dtype),
+        angles_sin_cos.new_tensor(residue_constants.chi_pi_periodic),
     )

-    true_chi = chi_angles.unsqueeze(-3)
+    true_chi = chi_angles
     sin_true_chi = torch.sin(true_chi)
     cos_true_chi = torch.cos(true_chi)
     sin_cos_true_chi = torch.stack([sin_true_chi, cos_true_chi], dim=-1)
@@ -247,9 +248,9 @@ def supervised_chi_loss(
     sq_chi_error = torch.minimum(sq_chi_error, sq_chi_error_shifted)

     sq_chi_loss = masked_mean(
-        sq_chi_error, chi_mask.unsqueeze(-3), dim=(-1, -2, -3)
+        chi_mask, sq_chi_error, dim=(-1, -2)
     )

     loss = 0
     loss += chi_weight * sq_chi_loss
@@ -258,7 +259,7 @@ def supervised_chi_loss(
     )
     norm_error = torch.abs(angle_norm - 1.)
     angle_norm_loss = masked_mean(
-        norm_error, sequence_mask[..., None, :, None], dim=(-1, -2, -3)
+        seq_mask[..., None], norm_error, dim=(-1, -2)
     )

     loss += angle_norm_weight * angle_norm_loss
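Both call sites now pass the mask first, matching a masked_mean(mask, value, dim) signature and dropping the extra unsqueezed dimension the old (value, mask) order needed. A hedged reimplementation of a helper consistent with that call order (OpenFold's actual version lives in tensor_utils):

import torch

def masked_mean(mask, value, dim, eps=1e-10):
    # Mean of value over dim, counting only positions where mask is 1
    mask = mask.expand(value.shape)
    return torch.sum(mask * value, dim=dim) / (eps + torch.sum(mask, dim=dim))

value = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
mask = torch.tensor([[1.0, 0.0], [1.0, 1.0]])
print(masked_mean(mask, value, dim=(-1, -2)))  # tensor(2.6667) = (1 + 3 + 4) / 3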
@@ -390,11 +391,11 @@ def distogram_loss(
         keepdims=True
     )

-    true_bins = torch.sum(dists > sq_breaks, dim=-1)
+    true_bins = torch.sum(dists > boundaries, dim=-1)

     errors = softmax_cross_entropy(
         logits,
-        torch.nn.functional.one_hot(true_bins, num_bins),
+        torch.nn.functional.one_hot(true_bins, no_bins),
     )

     square_mask = pseudo_beta_mask[..., None] * pseudo_beta_mask[..., None, :]
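sq_breaks and num_bins were stale names from an earlier revision (likely NameErrors at runtime); boundaries and no_bins are what distogram_loss actually defines. The comparison-and-sum idiom assigns each squared distance its bin index by counting the boundaries it exceeds:

import torch

boundaries = torch.tensor([1.0, 4.0, 9.0, 16.0])  # squared bin edges
dists = torch.tensor([[0.5], [5.0], [20.0]])      # squared distances

true_bins = torch.sum(dists > boundaries, dim=-1)
print(true_bins)  # tensor([0, 2, 4])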
@@ -1240,6 +1241,7 @@ def experimentally_resolved_loss(
         (resolution >= min_resolution) & (resolution <= max_resolution)
     )

     return loss
openfold/utils/tensor_utils.py

@@ -43,6 +43,19 @@ def pts_to_distogram(pts, min_bin=2.3125, max_bin=21.6875, no_bins=64):
     return torch.bucketize(dists, boundaries)


+def dict_multimap(fn, dicts):
+    first = dicts[0]
+    new_dict = {}
+    for k, v in first.items():
+        all_v = [d[k] for d in dicts]
+        if(type(v) is dict):
+            new_dict[k] = dict_multimap(fn, all_v)
+        else:
+            new_dict[k] = fn(all_v)
+
+    return new_dict
+
+
 def stack_tensor_dicts(dicts):
     first = dicts[0]
     new_dict = {}
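dict_multimap assumes every dict shares the first dict's keys and recurses into nested dicts; model.py pairs it with partial(torch.cat, dim=templ_dim) to merge the per-template embedding dicts. Usage, with the helper as defined above:

from functools import partial
import torch

dicts = [
    {"pair": torch.ones(1, 2), "stats": {"mask": torch.zeros(1)}},
    {"pair": torch.ones(1, 2), "stats": {"mask": torch.zeros(1)}},
]
merged = dict_multimap(partial(torch.cat, dim=0), dicts)
print(merged["pair"].shape)           # torch.Size([2, 2])
print(merged["stats"]["mask"].shape)  # torch.Size([2])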
@@ -154,13 +167,12 @@ def chunk_layer(layer, inputs, chunk_size, no_batch_dims):
     orig_batch_dims = [max(s) for s in zip(*initial_dims)]

     def prep_inputs(t):
-        t = t.expand(*orig_batch_dims, *t.shape[no_batch_dims:])
+        # TODO: make this more memory efficient. This sucks
+        if(not sum(t.shape[:no_batch_dims]) == no_batch_dims):
+            t = t.expand(*orig_batch_dims, *t.shape[no_batch_dims:])
         t = t.reshape(-1, *t.shape[no_batch_dims:])
         return t

-    #shape = lambda t: t.shape
-    #print(tensor_tree_map(shape, inputs))
-
     flattened_inputs = tensor_tree_map(prep_inputs, inputs)

     flat_batch_dim = 1
@@ -175,7 +187,7 @@ def chunk_layer(layer, inputs, chunk_size, no_batch_dims):
     out = None
     for _ in range(no_chunks):
         # Chunk the input
-        select_chunk = lambda t: t[i : i + chunk_size]
+        select_chunk = lambda t: t[i : i + chunk_size] if t.shape[0] != 1 else t
         chunks = tensor_tree_map(select_chunk, flattened_inputs)

         # Run the layer on the chunk
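prep_inputs now skips the expand when a tensor's batch dims are all singletons, and select_chunk leaves those size-1 tensors whole, so broadcast-only inputs (such as biases) are shared across chunks instead of copied. A sketch of the slicing rule:

import torch

chunk_size, i = 2, 0
flattened_inputs = {
    "m": torch.randn(6, 8),     # real batch dim: sliced per chunk
    "bias": torch.randn(1, 8),  # broadcast dim: reused by every chunk
}

select_chunk = lambda t: t[i : i + chunk_size] if t.shape[0] != 1 else t
chunks = {k: select_chunk(t) for k, t in flattened_inputs.items()}
print(chunks["m"].shape, chunks["bias"].shape)
# torch.Size([2, 8]) torch.Size([1, 8])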
run_pretrained_alphafold.py

@@ -14,8 +14,8 @@
 # limitations under the License.
 import os
-import sys
-sys.path.append("lib/conda/lib/python3.9/site-packages")
+# import sys
+# sys.path.append("lib/conda/lib/python3.9/site-packages")

 import math
 import pickle
@@ -26,9 +26,13 @@ import numpy as np
 from config import model_config
 from openfold.model.model import AlphaFold
-import openfold.np.protein as protein
+from openfold.np import residue_constants, protein
+
+#os.environ["OPENMM_DEFAULT_PLATFORM"] = "CPU"
+os.environ["OPENMM_DEFAULT_PLATFORM"] = "OpenCL"
+#os.environ["OPENMM_CPU_THREADS"] = "16"
 import openfold.np.relax.relax as relax
-from openfold.np import residue_constants
 from openfold.utils.import_weights import (
     import_jax_weights_,
 )
@@ -41,37 +45,35 @@ from openfold.utils.tensor_utils import (

 MODEL_NAME = "model_1"
+MODEL_DEVICE = "cuda:1"
 PARAM_PATH = "openfold/resources/params/params_model_1.npz"
-FEAT_PATH = "tests/test_data/sample_feats.pickle"
+# FEAT_PATH = "tests/test_data/sample_feats.pickle"
+FEAT_PATH = "prediction/1OJN_feats.pickle"

 config = model_config(MODEL_NAME)
 model = AlphaFold(config.model)
 model = model.eval()
 import_jax_weights_(model, PARAM_PATH)
-model_device = 'cuda:1'
-model = model.to(model_device)
+model = model.to(MODEL_DEVICE)

 with open(FEAT_PATH, "rb") as f:
     batch = pickle.load(f)

-batch = {k: torch.as_tensor(v, device=model_device) for k, v in batch.items()}
-longs = [
-    "aatype",
-    "template_aatype",
-    "extra_msa",
-    "residx_atom37_to_atom14",
-    "residx_atom14_to_atom37",
-]
-for l in longs:
-    batch[l] = batch[l].long()
-
-# Move the recycling dimension to the end
-move_dim = lambda t: t.permute(*range(len(t.shape))[1:], 0).contiguous()
-batch = tensor_tree_map(move_dim, batch)
-
 with torch.no_grad():
+    batch = {
+        k: torch.as_tensor(v, device=MODEL_DEVICE) for k, v in batch.items()
+    }
+    longs = [
+        "aatype",
+        "template_aatype",
+        "extra_msa",
+        "residx_atom37_to_atom14",
+        "residx_atom14_to_atom37",
+    ]
+    for l in longs:
+        batch[l] = batch[l].long()
+
+    # Move the recycling dimension to the end
+    move_dim = lambda t: t.permute(*range(len(t.shape))[1:], 0).contiguous()
+    batch = tensor_tree_map(move_dim, batch)
+
     t = time.time()
     out = model(batch)
     print(f"Inference time: {time.time() - t}")
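Moving the preprocessing under torch.no_grad() keeps the tensor conversions and permutes from recording autograd history on the GPU. The move_dim lambda rotates the leading recycling dimension to the end, where the model expects it:

import torch

# Features arrive as [n_recycle, ...]; the model wants [..., n_recycle]
move_dim = lambda t: t.permute(*range(len(t.shape))[1:], 0).contiguous()

feat = torch.randn(4, 16, 22)  # e.g. [n_recycle, N_res, C]
print(move_dim(feat).shape)    # torch.Size([16, 22, 4])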