OpenDAS / OpenFold / Commits / eb49136d

Commit eb49136d authored Sep 30, 2021 by Gustaf Ahdritz

Finish purge of in-place ops, get grads working, add TM

parent 1d47c1e7
Showing 16 changed files with 152 additions and 141 deletions (+152 / -141)
config.py                                             +15  -7
openfold/model/dropout.py                              +1  -1
openfold/model/embedders.py                            +0  -1
openfold/model/heads.py                                +6  -6
openfold/model/model.py                               +24 -20
openfold/model/msa.py                                  +1  -3
openfold/model/outer_product_mean.py                   +1  -1
openfold/model/pair_transition.py                      +1  -1
openfold/model/primitives.py                           +3  -3
openfold/model/structure_module.py                    +13 -14
openfold/model/triangular_attention.py                 +1  -3
openfold/model/triangular_multiplicative_update.py     +1  -1
openfold/utils/affine_utils.py                         +5 -11
openfold/utils/feats.py                               +23 -18
openfold/utils/loss.py                                +56 -38
openfold/utils/tensor_utils.py                         +1 -13
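Most hunks below follow one pattern: augmented assignments such as `a += b` or `a *= b` are rewritten out-of-place. As a hedged, standalone illustration (not code from this repository) of why that matters for autograd: modifying a tensor in place after it has been saved for the backward pass bumps its version counter and makes `backward()` raise.

import torch

w = torch.randn(3, requires_grad=True)
a = torch.exp(w)      # exp() saves its output for the backward pass
a += 1                # in-place edit bumps the saved tensor's version counter
try:
    a.sum().backward()
except RuntimeError as err:
    print("in-place op broke backward:", err)

w = torch.randn(3, requires_grad=True)
a = torch.exp(w)
a = a + 1             # out-of-place form used throughout this commit
a.sum().backward()
print(w.grad)         # equals exp(w), as expected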
config.py

@@ -23,8 +23,8 @@ blocks_per_ckpt = mlc.FieldReference(1, field_type=int)
 chunk_size = mlc.FieldReference(None, field_type=int)
 aux_distogram_bins = mlc.FieldReference(64)
-eps = 1e-4
-inf = 1e4
+eps = 1e-8
+inf = 1e8

 config = mlc.ConfigDict({
     "model": {
@@ -33,7 +33,7 @@ config = mlc.ConfigDict({
         "c_t": c_t,
         "c_e": c_e,
         "c_s": c_s,
-        "no_cycles": 4,
+        "no_cycles": 2, # 4,
         "_mask_trans": False,
         "input_embedder": {
             "tf_dim": 22,
@@ -117,7 +117,7 @@ config = mlc.ConfigDict({
                 "inf": inf,  #1e9,
                 "eps": eps,  #1e-10,
             },
-            "enabled": False, # True,
+            "enabled": True,
         },
         "evoformer_stack": {
             "c_m": c_m,
@@ -147,7 +147,7 @@ config = mlc.ConfigDict({
             "no_qk_points": 4,
             "no_v_points": 8,
             "dropout_rate": 0.1,
-            "no_blocks": 8,
+            "no_blocks": 2, # 8,
             "no_transition_layers": 1,
             "no_resnet_blocks": 2,
             "no_angles": 7,
@@ -165,10 +165,10 @@ config = mlc.ConfigDict({
                 "c_z": c_z,
                 "no_bins": aux_distogram_bins,
             },
-            "tm_score": {
+            "tm": {
                 "c_z": c_z,
                 "no_bins": aux_distogram_bins,
-                "enabled": False,
+                "enabled": True,
             },
             "masked_msa": {
                 "c_m": c_m,
@@ -239,6 +239,14 @@ config = mlc.ConfigDict({
             "eps": eps,  #1e-6,
             "weight": 0.,
         },
+        "tm": {
+            "max_bin": 31,
+            "no_bins": 64,
+            "min_resolution": 0.1,
+            "max_resolution": 3.0,
+            "eps": eps,  #1e-8,
+            "weight": 1.0,
+        },
         "eps": eps,
     },
 })
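A hedged side note on the config mechanics (standalone sketch, not repository code): because `eps` and `inf` are `mlc.FieldReference` objects rather than literals, the single change from 1e-4/1e4 to 1e-8/1e8 above propagates to every ConfigDict entry that references them.

import ml_collections as mlc

eps = mlc.FieldReference(1e-8, field_type=float)

config = mlc.ConfigDict({
    "globals": {"eps": eps},
    "loss": {"tm": {"eps": eps, "weight": 1.0}},
})

# Both entries resolve to the same underlying reference.
print(config.globals.eps, config.loss.tm.eps)  # 1e-08 1e-08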
openfold/model/dropout.py

@@ -53,7 +53,7 @@ class Dropout(nn.Module):
         if(self.batch_dim is not None):
             for bd in self.batch_dim:
                 shape[bd] = 1
-        mask = x.new_ones(shape, requires_grad=False)
+        mask = x.new_ones(shape)
         mask = self.dropout(mask)
         x = x * mask
         return x
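The `requires_grad=False` argument dropped here was redundant (it is already the default for `new_ones`). For context, a hedged standalone sketch of what this Dropout wrapper does: it samples one dropout mask on a broadcastable shape so whole rows or columns share the same dropped positions.

import torch

x = torch.rand(4, 8, 16)    # e.g. [N_seq, N_res, C]
shape = list(x.shape)
shape[0] = 1                # share one mask across the first (batch-like) dim
mask = x.new_ones(shape)
mask = torch.nn.functional.dropout(mask, p=0.25, training=True)
x = x * mask                # every slice along dim 0 sees the same mask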
openfold/model/embedders.py

@@ -194,7 +194,6 @@ class RecyclingEmbedder(nn.Module):
             self.max_bin,
             self.no_bins,
             dtype=x.dtype,
-            requires_grad=False,
             device=x.device
         )
openfold/model/heads.py

@@ -40,9 +40,9 @@ class AuxiliaryHeads(nn.Module):
             **config["experimentally_resolved"],
         )

-        if(config.tm_score.enabled):
-            self.tm_score = TMScoreHead(
-                **config["tm_score"],
+        if(config.tm.enabled):
+            self.tm = TMScoreHead(
+                **config.tm,
             )

         self.config = config
@@ -68,9 +68,9 @@ class AuxiliaryHeads(nn.Module):
                 experimentally_resolved_logits
             )

-        if(self.config.tm_score.enabled):
-            tm_score_logits = self.tm_score(outputs["pair"])
-            aux_out["tm_score_logits"] = tm_score_logits
+        if(self.config.tm.enabled):
+            tm_logits = self.tm(outputs["pair"])
+            aux_out["tm_logits"] = tm_logits

         return aux_out
openfold/model/model.py

@@ -115,10 +115,6 @@ class AlphaFold(nn.Module):
             batch,
         )

-        #tensor_dtype = (
-        #    single_template_feats["template_all_atom_masks"].dtype
-        #)
-
         # Build template angle feats
         angle_feats = atom37_to_torsion_angles(
             single_template_feats["template_aatype"],
@@ -127,10 +123,6 @@ class AlphaFold(nn.Module):
             eps=self.config.template.eps,
         )

-        #angle_feats = tensor_tree_map(
-        #    lambda t: t.type(tensor_dtype), angle_feats
-        #)
-
         template_angle_feat = build_template_angle_feat(
             angle_feats,
             single_template_feats["template_aatype"],
@@ -211,19 +203,16 @@ class AlphaFold(nn.Module):
             # [*, N, C_m]
             m_1_prev = m.new_zeros(
                 (*batch_dims, n, self.config.c_m),
-                requires_grad=False,
             )

             # [*, N, N, C_z]
             z_prev = z.new_zeros(
                 (*batch_dims, n, n, self.config.c_z),
-                requires_grad=False,
             )

             # [*, N, 3]
             x_prev = z.new_zeros(
                 (*batch_dims, n, residue_constants.atom_type_num, 3),
-                requires_grad=False,
             )

         x_prev = pseudo_beta_fn(
@@ -241,7 +230,7 @@ class AlphaFold(nn.Module):
         )

         # [*, S_c, N, C_m]
-        m[..., 0, :, :] += m_1_prev_emb
+        m[..., 0, :, :] = m[..., 0, :, :] + m_1_prev_emb

         # [*, N, N, C_z]
         z = z + z_prev_emb
@@ -312,6 +301,7 @@ class AlphaFold(nn.Module):
             outputs["sm"]["positions"][-1],
             feats
         )
         outputs["final_atom_mask"] = feats["atom37_atom_exists"]
+        outputs["final_affine_tensor"] = outputs["sm"]["frames"][-1]

         # Save embeddings for use during the next recycling iteration
@@ -342,6 +332,16 @@ class AlphaFold(nn.Module):
             self.config.extra_msa.extra_msa_stack.blocks_per_ckpt
         )

+    def _disable_grad(self):
+        vals = [p.requires_grad for p in self.parameters()]
+        for p in self.parameters():
+            p.requires_grad_(False)
+        return vals
+
+    def _enable_grad(self, vals):
+        for p, v in zip(self.parameters(), vals):
+            p.requires_grad_(v)
+
     def forward(self, batch):
         """
             Args:
@@ -391,12 +391,13 @@
                     for which C_alpha is used instead)
                 "template_pseudo_beta_mask" ([*, N_templ, N_res])
                     Pseudo-beta mask
         """
         # Initialize recycling embeddings
         m_1_prev, z_prev, x_prev = None, None, None

         # Disable activation checkpointing until the final recycling layer
         self._disable_activation_checkpointing()
+        grad_vals = self._disable_grad()

         # Main recycling loop
         for cycle_no in range(self.config.no_cycles):
@@ -405,14 +406,17 @@
             feats = tensor_tree_map(fetch_cur_batch, batch)

             # Enable grad iff we're training and it's the final recycling layer
-            is_final_iter = (cycle_no == self.config.no_cycles - 1)
-            if(self.training and is_final_iter):
+            is_final_iter = (cycle_no == (self.config.no_cycles - 1))
+            if(is_final_iter):
                 self._enable_activation_checkpointing()
+                self._enable_grad(grad_vals)

-            # Run the next iteration of the model
-            outputs, m_1_prev, z_prev, x_prev = self.iteration(
-                feats, m_1_prev, z_prev, x_prev,
-            )
+            with torch.set_grad_enabled(self.training and is_final_iter):
+                outputs, m_1_prev, z_prev, x_prev = self.iteration(
+                    feats, m_1_prev, z_prev, x_prev,
+                )

         # Run auxiliary heads
         outputs.update(self.aux_heads(outputs))

         return outputs
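The "get grads working" half of the commit is concentrated here: parameters have gradients switched off for all but the final recycling iteration, and each iteration runs under `torch.set_grad_enabled`. A simplified, hedged sketch of the control flow (toy names, not the repository's class):

import torch

def run_recycling(iteration_fn, feats, no_cycles, training=True):
    prev = None
    for cycle_no in range(no_cycles):
        is_final_iter = cycle_no == (no_cycles - 1)
        # Only the final pass builds an autograd graph; earlier passes just
        # produce the recycled embeddings.
        with torch.set_grad_enabled(training and is_final_iter):
            outputs, prev = iteration_fn(feats, prev)
    return outputs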
openfold/model/msa.py

@@ -94,10 +94,8 @@ class MSAAttention(nn.Module):
         n_seq, n_res = m.shape[-3:-1]
         if(mask is None):
             # [*, N_seq, N_res]
-            mask = torch.ones(
+            mask = m.new_ones(
                 m.shape[:-3] + (n_seq, n_res),
-                device=m.device,
-                requires_grad=False
             )

         # [*, N_seq, 1, 1, N_res]
openfold/model/outer_product_mean.py

@@ -70,7 +70,7 @@ class OuterProductMean(nn.Module):
                 [*, N_res, N_res, C_z] pair embedding update
         """
         if(mask is None):
-            mask = m.new_ones(m.shape[:-1], requires_grad=False)
+            mask = m.new_ones(m.shape[:-1])

         # [*, N_seq, N_res, C_m]
         m = self.layer_norm(m)
openfold/model/pair_transition.py

@@ -64,7 +64,7 @@ class PairTransition(nn.Module):
         """
         # DISCREPANCY: DeepMind forgets to apply the mask in this module.
         if(mask is None):
-            mask = z.new_ones(z.shape[:-1], requires_grad=False)
+            mask = z.new_ones(z.shape[:-1])

         # [*, N_res, N_res, 1]
         mask = mask.unsqueeze(-1)
openfold/model/primitives.py

@@ -251,10 +251,10 @@ class Attention(nn.Module):
             permute_final_dims(k, (0, 2, 3, 1)),  # [*, H, C_hidden, K]
         )

         norm = 1 / math.sqrt(self.c_hidden)  # [1]
-        a *= norm
+        a = a * norm
         if(biases is not None):
             for b in biases:
-                a += b
+                a = a + b
         a = self.softmax(a)

         #print(torch.any(torch.isnan(a)))
@@ -330,7 +330,7 @@ class GlobalAttention(nn.Module):
             k.transpose(-1, -2),  # [*, N_res, C_hidden, N_seq]
         )

         bias = (self.inf * (mask - 1))[..., :, None, :]
-        a += bias
+        a = a + bias
         a = self.softmax(a)

         # [*, N_res, H, C_hidden]
openfold/model/structure_module.py

@@ -27,7 +27,7 @@ from openfold.np.residue_constants import (
 )
 from openfold.utils.affine_utils import T, quat_to_rot
 from openfold.utils.tensor_utils import (
-    stack_tensor_dicts,
+    dict_multimap,
     permute_final_dims,
     flatten_final_dims,
 )
@@ -337,10 +337,15 @@ class InvariantPointAttention(nn.Module):
         # [*, N_res, H * C_hidden]
         o = flatten_final_dims(o, 2)

-        # As DeepMind explains, this manual matmul ensures that the operation
-        # happens in float32.
         # [*, H, 3, N_res, P_v]
-        o_pt = torch.matmul(
-            a.unsqueeze(-3),  # [*, H, 1, N_res, N_res]
-            permute_final_dims(v_pts, (1, 3, 0, 2)),  # [*, H, 3, N_res, P_v]
-        )
+        o_pt = torch.sum(
+            (
+                a[..., None, :, :, None] *
+                permute_final_dims(v_pts, (1, 3, 0, 2))[..., None, :, :]
+            ),
+            dim=-2,
+        )

         # [*, N_res, H, P_v, 3]
@@ -702,7 +707,7 @@ class StructureModule(nn.Module):
         """
         if(mask is None):
             # [*, N]
-            mask = s.new_ones(s.shape[:-1], requires_grad=False)
+            mask = s.new_ones(s.shape[:-1])

         # [*, N, C_s]
         s = self.layer_norm_s(s)
@@ -718,7 +723,7 @@ class StructureModule(nn.Module):
         t = T.identity(s.shape[:-1], s.dtype, s.device, self.training)

         outputs = []
-        for l in range(self.no_blocks):
+        for i in range(self.no_blocks):
             # [*, N, C_s]
             s = s + self.ipa(s, z, t, mask)
             s = self.ipa_dropout(s)
@@ -751,10 +756,10 @@ class StructureModule(nn.Module):
             outputs.append(preds)

-            if(l < self.no_blocks - 1):
+            if(i < (self.no_blocks - 1)):
                 t = t.stop_rot_gradient()

-        outputs = stack_tensor_dicts(outputs)
+        outputs = dict_multimap(torch.stack, outputs)
         outputs["single"] = s

         return outputs
@@ -765,27 +770,23 @@ class StructureModule(nn.Module):
                 restype_rigid_group_default_frame,
                 dtype=float_dtype,
                 device=device,
-                requires_grad=False,
             )
         if(self.group_idx is None):
             self.group_idx = torch.tensor(
                 restype_atom14_to_rigid_group,
                 device=device,
-                requires_grad=False,
             )
         if(self.atom_mask is None):
             self.atom_mask = torch.tensor(
                 restype_atom14_mask,
                 dtype=float_dtype,
                 device=device,
-                requires_grad=False,
             )
         if(self.lit_positions is None):
             self.lit_positions = torch.tensor(
                 restype_atom14_rigid_group_positions,
                 dtype=float_dtype,
                 device=device,
-                requires_grad=False,
             )

     def torsion_angles_to_frames(self, t, alpha, f):
@@ -799,8 +800,6 @@ class StructureModule(nn.Module):
         f,  # [*, N]
     ):
         # Lazily initialize the residue constants on the correct device
-        # TODO: Maybe this stuff should be done on CPU instead (so these
-        # arrays
         self._init_residue_constants(t.rots.dtype, t.rots.device)
         return _frames_and_literature_positions_to_atom14_pos(
             t,
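The InvariantPointAttention hunk above replaces a batched `torch.matmul` over the point values with an explicit broadcast-multiply-and-sum. A hedged standalone check (not repository code) that the two formulations agree on random inputs:

import torch

H, N, P_v = 2, 5, 3
a = torch.rand(H, N, N)        # attention weights, [H, N_res, N_res]
v = torch.rand(H, 3, N, P_v)   # permuted point values, [H, 3, N_res, P_v]

o_matmul = torch.matmul(a.unsqueeze(-3), v)
o_sum = torch.sum(
    a[..., None, :, :, None] * v[..., None, :, :],
    dim=-2,
)
print(torch.allclose(o_matmul, o_sum, atol=1e-6))  # True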
openfold/model/triangular_attention.py

@@ -73,10 +73,8 @@ class TriangleAttention(nn.Module):
         """
         if(mask is None):
             # [*, I, J]
-            mask = torch.ones(
+            mask = x.new_ones(
                 x.shape[:-1],
-                device=x.device,
-                requires_grad=False,
             )

         # Shape annotations assume self.starting. Else, I and J are flipped
openfold/model/triangular_multiplicative_update.py

@@ -91,7 +91,7 @@ class TriangleMultiplicativeUpdate(nn.Module):
                 [*, N_res, N_res, C_z] output tensor
         """
         if(mask is None):
-            mask = z.new_ones(z.shape[:-1], requires_grad=False)
+            mask = z.new_ones(z.shape[:-1])

         mask = mask.unsqueeze(-1)
openfold/utils/affine_utils.py

@@ -163,7 +163,7 @@ class T:
         return trans

     @staticmethod
-    def identity(shape, dtype, device, requires_grad=False):
+    def identity(shape, dtype, device, requires_grad=True):
         return T(
             T.identity_rot(shape, dtype, device, requires_grad),
             T.identity_trans(shape, dtype, device, requires_grad),
@@ -191,11 +191,6 @@ class T:
         e0 = origin - p_neg_x_axis
         e1 = p_xy_plane - origin

-        # Angle norming is very sensitive to floating point imprecisions
-        #float_type = e0.dtype
-        #e0 = e0.float()
-        #e1 = e1.float()
-
         e0 = e0 / torch.sqrt(torch.sum(e0 ** 2, dim=-1, keepdims=True) + eps)
         e1 = e1 - e0 * torch.sum(e0 * e1, dim=-1, keepdims=True)
         e1 = e1 / torch.sqrt(torch.sum(e1 ** 2, dim=-1, keepdims=True) + eps)
@@ -203,8 +198,6 @@ class T:
         rots = torch.stack([e0, e1, e2], dim=-1)

-        #rots = rots.type(float_type)
-
         return T(rots, origin)

     @staticmethod
@@ -221,7 +214,8 @@ class T:
         return T(rots, trans)

     def map_tensor_fn(self, fn):
-        """ Apply a function that takes a tensor as its only argument to the
+        """
+            Apply a function that takes a tensor as its only argument to the
             rotations and translations, treating the final two/one
             dimension(s), respectively, as batch dimensions.
@@ -253,7 +247,7 @@ class T:
         n_xyz = n_xyz + translation
         c_xyz = c_xyz + translation

         c_x, c_y, c_z = [c_xyz[..., i] for i in range(3)]
         norm = torch.sqrt(eps + c_x ** 2 + c_y ** 2)
         sin_c1 = -c_y / norm
         cos_c1 = c_x / norm
@@ -278,7 +272,7 @@ class T:
         c1_rots[..., 2, 0] = -1 * sin_c2
         c1_rots[..., 2, 2] = cos_c2

-        c_rots = rot_matmul(c2_rot_matrix, c1_rot_matrix)
+        c_rots = rot_matmul(c2_rots, c1_rots)
         n_xyz = rot_vec_mul(c_rots, n_xyz)

         _, n_y, n_z = [n_xyz[..., i] for i in range(3)]
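For context on the `make_transform_from_reference` hunk above (a hedged standalone sketch with arbitrary example points, not repository code): `e0` is normalized, `e1` is made orthogonal to it by Gram-Schmidt, and their cross product completes a right-handed frame.

import torch

eps = 1e-8
origin = torch.tensor([1.0, 0.0, 0.0])
p_neg_x_axis = torch.tensor([0.0, 0.0, 0.0])
p_xy_plane = torch.tensor([1.0, 2.0, 1.0])

e0 = origin - p_neg_x_axis
e1 = p_xy_plane - origin
e0 = e0 / torch.sqrt(torch.sum(e0 ** 2, dim=-1, keepdim=True) + eps)
e1 = e1 - e0 * torch.sum(e0 * e1, dim=-1, keepdim=True)
e1 = e1 / torch.sqrt(torch.sum(e1 ** 2, dim=-1, keepdim=True) + eps)
e2 = torch.cross(e0, e1, dim=-1)

print(torch.dot(e0, e1), torch.dot(e0, e2))  # both ~0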
openfold/utils/feats.py

@@ -151,7 +151,7 @@ def atom14_to_atom37(atom14, batch):
         no_batch_dims=len(atom14.shape[:-2]),
     )

-    atom37_data *= batch["atom37_atom_exists"][..., None]
+    atom37_data = atom37_data * batch["atom37_atom_exists"][..., None]

     return atom37_data
@@ -288,7 +288,7 @@ def atom37_to_torsion_angles(
     )
     torsion_angles_sin_cos = torsion_angles_sin_cos / denom

-    torsion_angles_sin_cos *= torch.tensor(
+    torsion_angles_sin_cos = torsion_angles_sin_cos * torch.tensor(
         [1., 1., -1., 1., 1., 1., 1.],
         device=aatype.device,
     )[((None,) * len(torsion_angles_sin_cos.shape[:-2])) + (slice(None), None)]
@@ -335,11 +335,8 @@ def atom37_to_frames(
         restype, chi_idx + 4, :
     ] = names[1:]

-    restype_rigidgroup_mask = torch.zeros(
+    restype_rigidgroup_mask = all_atom_mask.new_zeros(
         (*aatype.shape[:-1], 21, 8),
-        dtype=all_atom_mask.dtype,
-        device=aatype.device,
-        requires_grad=False,
     )
     restype_rigidgroup_mask[..., 0] = 1
     restype_rigidgroup_mask[..., 3] = 1
@@ -399,7 +396,7 @@ def atom37_to_frames(
     gt_exists = torch.min(gt_atoms_exist, dim=-1)[0] * group_exists

     rots = torch.eye(
-        3, dtype=all_atom_mask.dtype, device=aatype.device, requires_grad=False
+        3, dtype=all_atom_mask.dtype, device=aatype.device
     )
     rots = torch.tile(rots, (*((1,) * batch_dims), 8, 1, 1))
     rots[..., 0, 0, 0] = -1
@@ -411,7 +408,7 @@ def atom37_to_frames(
         *((1,) * batch_dims), 21, 8
     )
     restype_rigidgroup_rots = torch.eye(
-        3, dtype=all_atom_mask.dtype, device=aatype.device, requires_grad=False
+        3, dtype=all_atom_mask.dtype, device=aatype.device
     )
     restype_rigidgroup_rots = torch.tile(
         restype_rigidgroup_rots,
@@ -476,7 +473,7 @@ def build_template_angle_feat(angle_feats, template_aatype):
     return template_angle_feat


-def build_template_pair_feat(batch, min_bin, max_bin, no_bins, eps=1e-6, inf=1e8):
+def build_template_pair_feat(batch, min_bin, max_bin, no_bins, eps=1e-20, inf=1e8):
     template_mask = batch["template_pseudo_beta_mask"]
     template_mask_2d = template_mask[..., None] * template_mask[..., None, :]
@@ -507,20 +504,30 @@ def build_template_pair_feat(batch, min_bin, max_bin, no_bins, eps=1e-6, inf=1e8
     )

     n, ca, c = [rc.atom_order[a] for a in ['N', 'CA', 'C']]
+    affines = T.make_transform_from_reference(
+        n_xyz=batch["template_all_atom_positions"][..., n, :],
+        ca_xyz=batch["template_all_atom_positions"][..., ca, :],
+        c_xyz=batch["template_all_atom_positions"][..., c, :],
+    )
+    points = affines.get_trans()[..., None, :, :]
+    affine_vec = affines[..., None].invert_apply(points)
+
+    inv_distance_scalar = torch.rsqrt(
+        eps + torch.sum(affine_vec ** 2, dim=-1)
+    )
+
+    t_aa_masks = batch["template_all_atom_masks"]
+    template_mask = (
+        t_aa_masks[..., n] * t_aa_masks[..., ca] * t_aa_masks[..., c]
+    )
+    template_mask_2d = template_mask[..., None] * template_mask[..., None, :]
+
-    unit_vector = template_mask_2d.new_zeros(*template_mask_2d.shape, 3)
-    to_concat.append(unit_vector)
+    inv_distance_scalar = inv_distance_scalar * template_mask_2d
+    unit_vector = (affine_vec * inv_distance_scalar[..., None])
+    to_concat.extend(torch.unbind(unit_vector[..., None, :], dim=-1))

     to_concat.append(template_mask_2d[..., None])

     act = torch.cat(to_concat, dim=-1)
-    act *= template_mask_2d[..., None]
+    act = act * template_mask_2d[..., None]

     return act
@@ -594,7 +601,7 @@ def build_ambiguity_feats(batch: Dict[str, torch.Tensor]) -> None:
     """
     ambiguous_atoms = (
         batch["atom14_gt_positions"].new_tensor(
-            rc.restype_atom14_ambiguous_atoms,
-            requires_grad=False,
+            rc.restype_atom14_ambiguous_atoms
         )
     )
@@ -603,9 +610,7 @@ def build_ambiguity_feats(batch: Dict[str, torch.Tensor]) -> None:
     # Swap pairs of ambiguous positions
     swap_idx = rc.restype_atom14_ambiguous_atoms_swap_idx
     swap_mat = np.eye(swap_idx.shape[-1])[swap_idx]  # one-hot swap_idx
-    swap_mat = batch["atom14_gt_positions"].new_tensor(
-        swap_mat, requires_grad=False
-    )
+    swap_mat = batch["atom14_gt_positions"].new_tensor(swap_mat)
     swap_mat = swap_mat[batch["aatype"], ...]

     atom14_alt_gt_positions = (
         torch.sum(
openfold/utils/loss.py

@@ -97,8 +97,8 @@ def compute_fape(
     error_dist = torch.clamp(error_dist, min=0, max=l1_clamp_distance)

     normed_error = error_dist / length_scale
-    normed_error *= frames_mask[..., None]
-    normed_error *= positions_mask[..., None, :]
+    normed_error = normed_error * frames_mask[..., None]
+    normed_error = normed_error * positions_mask[..., None, :]

     # FP16-friendly averaging. Roughly equivalent to:
     #
@@ -291,7 +291,7 @@ def supervised_chi_loss(
     )

     loss = 0
-    loss += chi_weight * sq_chi_loss
+    loss = loss + chi_weight * sq_chi_loss

     angle_norm = torch.sqrt(
         torch.sum(unnormalized_angles_sin_cos ** 2, dim=-1) + eps
@@ -304,7 +304,7 @@ def supervised_chi_loss(
         seq_mask[..., None, :, None],
         norm_error,
         dim=(-1, -2, -3)
     )
-    loss += angle_norm_weight * angle_norm_loss
+    loss = loss + angle_norm_weight * angle_norm_loss

     return loss
@@ -380,7 +380,7 @@ def lddt_loss(
         (dist_l1 < 2.0).type(dist_l1.dtype)
         + (dist_l1 < 4.0).type(dist_l1.dtype)
     )
-    score *= 0.25
+    score = score * 0.25

     norm = 1. / (eps + torch.sum(dists_to_score, dim=-1))
     score = norm * (eps + torch.sum(dists_to_score * score, dim=-1))
@@ -400,7 +400,7 @@ def lddt_loss(
         (eps + torch.sum(all_atom_mask, dim=-1))
     )

-    loss *= (
+    loss = loss * (
         (resolution >= min_resolution) & (resolution <= max_resolution)
     )
@@ -452,50 +452,60 @@ def distogram_loss(
     return mean


-def tm_score(
+def tm_loss(
     logits,
-    t_pred,
-    t_gt,
-    mask,
+    final_affine_tensor,
+    backbone_affine_tensor,
+    backbone_affine_mask,
+    resolution,
     max_bin=31,
     no_bins=64,
     min_resolution: float = 0.1,
     max_resolution: float = 3.0,
-    eps=1e-8
+    eps=1e-8,
+    **kwargs,
 ):
-    boundaries = torch.linspace(
-        min=0, max=max_bin, steps=(no_bins - 1), device=logits.device
-    )
-    boundaries = boundaries ** 2
+    pred_affine = T.from_4x4(final_affine_tensor)
+    backbone_affine = T.from_4x4(backbone_affine_tensor)

     def _points(affine):
-        pts = affine.trans.unsqueeze(-3)
-        return affine.invert().apply(pts, addl_dims=1)
+        pts = affine.get_trans()[..., None, :, :]
+        return affine.invert()[..., None].apply(pts)

-    sq_diff = torch.sum((_points(t_pred) - _points(t_gt)) ** 2, dim=-1)
+    sq_diff = torch.sum(
+        (_points(pred_affine) - _points(backbone_affine)) ** 2, dim=-1
+    )
+
+    sq_diff = sq_diff.detach()
+
+    boundaries = torch.linspace(
+        0, max_bin, steps=(no_bins - 1), device=logits.device
+    )
+    boundaries = boundaries ** 2

     true_bins = torch.sum(
-        sq_diff[..., None] > boundaries
-    ).float()
+        sq_diff[..., None] > boundaries, dim=-1
+    )

     errors = softmax_cross_entropy(
         logits, torch.nn.functional.one_hot(true_bins, no_bins)
     )

-    square_mask = mask[..., None] * mask[..., None, :]
-
-    loss = (
-        torch.sum(loss, dim=(-1, -2)) /
-        (eps + torch.sum(square_mask, dim=(-1, -2)))
-    )
+    square_mask = (
+        backbone_affine_mask[..., None] * backbone_affine_mask[..., None, :]
+    )

-    loss *= (
+    loss = torch.sum(errors * square_mask, dim=-1)
+    scale = 0.1  # hack to help FP16 training along
+    denom = eps + torch.sum(scale * square_mask, dim=(-1, -2))
+    loss = loss / denom[..., None]
+    loss = torch.sum(loss, dim=-1)
+    loss = loss / scale
+
+    loss = loss * (
         (resolution >= min_resolution) & (resolution <= max_resolution)
     )
@@ -729,7 +739,7 @@ def between_residue_clash_loss(
     # Mask out all the duplicate entries in the lower triangular matrix.
     # Also mask out the diagonal (atom-pairs from the same residue) -- these atoms
     # are handled separately.
-    dists_mask *= (
+    dists_mask = dists_mask * (
         residue_index[..., :, None, None, None]
         < residue_index[..., None, :, None, None]
     )
@@ -758,7 +768,7 @@ def between_residue_clash_loss(
         c_one_hot[..., None, None, :, None]
         * n_one_hot[..., None, None, None, :]
     )
-    dists_mask *= (1. - c_n_bonds)
+    dists_mask = dists_mask * (1. - c_n_bonds)

     # Disulfide bridge between two cysteines is no clash.
     cys = residue_constants.restype_name_to_atom14_names["CYS"]
@@ -773,7 +783,7 @@ def between_residue_clash_loss(
     disulfide_bonds = (
         cys_sg_one_hot[..., None, None, :, None]
         * cys_sg_one_hot[..., None, None, None, :]
     )
-    dists_mask *= (1. - disulfide_bonds)
+    dists_mask = dists_mask * (1. - disulfide_bonds)

     # Compute the lower bound for the allowed distances.
     # shape (N, N, 14, 14)
@@ -1038,7 +1048,7 @@ def find_structural_violations_np(
     atom14_pred_positions: np.ndarray,
     config: ml_collections.ConfigDict
 ) -> Dict[str, np.ndarray]:
-    to_tensor = lambda x: torch.tensor(x, requires_grad=False)
+    to_tensor = lambda x: torch.tensor(x)
     batch = tree_map(to_tensor, batch, np.ndarray)
     atom14_pred_positions = to_tensor(atom14_pred_positions)
@@ -1135,7 +1145,7 @@ def compute_violation_metrics_np(
     atom14_pred_positions: np.ndarray,
     violations: Dict[str, np.ndarray],
 ) -> Dict[str, np.ndarray]:
-    to_tensor = lambda x: torch.tensor(x, requires_grad=False)
+    to_tensor = lambda x: torch.tensor(x)
     batch = tree_map(to_tensor, batch, np.ndarray)
     atom14_pred_positions = to_tensor(atom14_pred_positions)
     violations = tree_map(to_tensor, violations, np.ndarray)
@@ -1285,10 +1295,11 @@ def experimentally_resolved_loss(
     **kwargs,
 ) -> torch.Tensor:
     errors = sigmoid_cross_entropy(logits, all_atom_mask)
-    loss_num = torch.sum(errors * atom37_atom_exists, dim=(-1, -2))
-    loss = loss_num / (eps + torch.sum(atom37_atom_exists, dim=(-1, -2)))
+    loss = torch.sum(errors * atom37_atom_exists, dim=-1)
+    loss = loss / (eps + torch.sum(atom37_atom_exists, dim=(-1, -2)))
+    loss = torch.sum(loss, dim=-1)

-    loss *= (
+    loss = loss * (
         (resolution >= min_resolution) & (resolution <= max_resolution)
     )
@@ -1307,11 +1318,13 @@ def masked_msa_loss(logits, true_msa, bert_mask, eps=1e-8, **kwargs):
     #     torch.sum(errors * bert_mask, dim=(-1, -2)) /
     #     (eps + torch.sum(bert_mask, dim=(-1, -2)))
     # )
-    denom = eps + torch.sum(bert_mask, dim=(-1, -2))
     loss = errors * bert_mask
     loss = torch.sum(loss, dim=-1)
+    scale = 0.1
+    denom = eps + torch.sum(scale * bert_mask, dim=(-1, -2))
     loss = loss / denom[..., None]
     loss = torch.sum(loss, dim=-1)
+    loss = loss / scale

     return loss
@@ -1403,6 +1416,11 @@ class AlphaFoldLoss(nn.Module):
                 out["violation"],
                 **batch,
             ),
+            "tm": lambda: tm_loss(
+                logits=out["tm_logits"],
+                **{**batch, **out, **self.config.tm},
+            ),
         }

         cum_loss = 0
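One detail of the new `tm_loss` worth spelling out (hedged standalone sketch, not repository code): the squared error between predicted and ground-truth frame translations is turned into a distogram-style bin index by counting how many squared bin boundaries it exceeds, which is what `torch.sum(sq_diff[..., None] > boundaries, dim=-1)` computes.

import torch

max_bin, no_bins = 31, 64
boundaries = torch.linspace(0, max_bin, steps=(no_bins - 1)) ** 2

sq_diff = torch.tensor([0.0, 2.3, 150.0, 2000.0])
true_bins = torch.sum(sq_diff[..., None] > boundaries, dim=-1)
print(true_bins)  # tensor([ 0,  4, 25, 63])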
openfold/utils/tensor_utils.py

@@ -57,19 +57,6 @@ def dict_multimap(fn, dicts):
     return new_dict


-def stack_tensor_dicts(dicts):
-    first = dicts[0]
-    new_dict = {}
-    for k, v in first.items():
-        all_v = [d[k] for d in dicts]
-        if(type(v) is dict):
-            new_dict[k] = stack_tensor_dicts(all_v)
-        else:
-            new_dict[k] = torch.stack(all_v)
-
-    return new_dict
-
-
 def one_hot(x, v_bins):
     reshaped_bins = v_bins.view(((1,) * len(x.shape)) + (len(v_bins),))
     diffs = x[..., None] - reshaped_bins
@@ -119,6 +106,7 @@ def tree_map(fn, tree, leaf_type):
 tensor_tree_map = partial(tree_map, leaf_type=torch.Tensor)

+
 def chunk_layer(
     layer: Callable,
     inputs: Dict[str, Any],
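`stack_tensor_dicts` is deleted here in favor of the existing `dict_multimap(torch.stack, outputs)` call in the structure module. A hedged standalone sketch of what that call does (the helper below is re-implemented from the removed function's structure; the real one lives in this file):

import torch

def dict_multimap(fn, dicts):
    first = dicts[0]
    new_dict = {}
    for k, v in first.items():
        all_v = [d[k] for d in dicts]
        if isinstance(v, dict):
            new_dict[k] = dict_multimap(fn, all_v)
        else:
            new_dict[k] = fn(all_v)
    return new_dict

outputs = [
    {"frames": torch.zeros(2), "sc": {"angles": torch.zeros(3)}}
    for _ in range(4)
]
stacked = dict_multimap(torch.stack, outputs)
print(stacked["frames"].shape)        # torch.Size([4, 2])
print(stacked["sc"]["angles"].shape)  # torch.Size([4, 3])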