Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
68ba77e5
"...git@developer.sourcefind.cn:OpenDAS/openpcdet.git" did not exist on "43baf787bcc3ceefc007c0bafb3b122a56911cc9"
Commit
68ba77e5
authored
Sep 24, 2021
by
Gustaf Ahdritz
Browse files
Continue fixing loss bugs, clean up structure module docs
parent
33941e46
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
237 additions
and
105 deletions
+237
-105
openfold/model/structure_module.py
openfold/model/structure_module.py
+18
-10
openfold/np/residue_constants.py
openfold/np/residue_constants.py
+23
-0
openfold/utils/affine_utils.py
openfold/utils/affine_utils.py
+12
-0
openfold/utils/feats.py
openfold/utils/feats.py
+113
-58
openfold/utils/loss.py
openfold/utils/loss.py
+38
-37
tests/utils.py
tests/utils.py
+33
-0
No files found.
openfold/model/structure_module.py
View file @
68ba77e5
...
...
@@ -486,12 +486,12 @@ def _frames_and_literature_positions_to_atom14_pos(
):
# [*, N, 14, 4, 4]
default_4x4
=
default_frames
[
f
,...]
default_4x4
=
default_frames
[
f
,
...]
# [*, N, 14]
group_mask
=
group_idx
[
f
,...]
group_mask
=
group_idx
[
f
,
...]
# [N, 14, 8]
# [
*,
N, 14, 8]
group_mask
=
nn
.
functional
.
one_hot
(
group_mask
,
num_classes
=
default_frames
.
shape
[
-
3
],
)
...
...
@@ -504,11 +504,11 @@ def _frames_and_literature_positions_to_atom14_pos(
lambda
x
:
torch
.
sum
(
x
,
dim
=-
1
)
)
# [N, 14, 1]
# [
*,
N, 14, 1]
atom_mask
=
atom_mask
[
f
,...].
unsqueeze
(
-
1
)
# [N, 14, 3]
lit_positions
=
lit_positions
[
f
,...]
# [
*,
N, 14, 3]
lit_positions
=
lit_positions
[
f
,
...]
pred_positions
=
t_atoms_to_global
.
apply
(
lit_positions
)
pred_positions
*=
atom_mask
...
...
@@ -758,19 +758,27 @@ class StructureModule(nn.Module):
def
_init_residue_constants
(
self
,
device
):
if
(
self
.
default_frames
is
None
):
self
.
default_frames
=
torch
.
tensor
(
restype_rigid_group_default_frame
,
device
=
device
,
restype_rigid_group_default_frame
,
device
=
device
,
requires_grad
=
False
,
)
if
(
self
.
group_idx
is
None
):
self
.
group_idx
=
torch
.
tensor
(
restype_atom14_to_rigid_group
,
device
=
device
,
restype_atom14_to_rigid_group
,
device
=
device
,
requires_grad
=
False
,
)
if
(
self
.
atom_mask
is
None
):
self
.
atom_mask
=
torch
.
tensor
(
restype_atom14_mask
,
device
=
device
,
restype_atom14_mask
,
device
=
device
,
requires_grad
=
False
,
)
if
(
self
.
lit_positions
is
None
):
self
.
lit_positions
=
torch
.
tensor
(
restype_atom14_rigid_group_positions
,
device
=
device
,
restype_atom14_rigid_group_positions
,
device
=
device
,
requires_grad
=
False
,
)
def
torsion_angles_to_frames
(
self
,
t
,
alpha
,
f
):
...
...
openfold/np/residue_constants.py
View file @
68ba77e5
...
...
@@ -366,6 +366,7 @@ residue_atoms = {
# (The LDDT paper lists 7 amino acids as ambiguous, but the naming ambiguities
# in LEU, VAL and ARG can be resolved by using the 3d constellations of
# the 'ambiguous' atoms and their neighbours)
# TODO: ^ interpret this
residue_atom_renaming_swaps
=
{
'ASP'
:
{
'OD1'
:
'OD2'
},
'GLU'
:
{
'OE1'
:
'OE2'
},
...
...
@@ -895,3 +896,25 @@ def make_atom14_dists_bounds(overlap_tolerance=1.5,
'upper_bound'
:
restype_atom14_bond_upper_bound
,
# shape (21,14,14)
'stddev'
:
restype_atom14_bond_stddev
,
# shape (21,14,14)
}
restype_atom14_ambiguous_atoms
=
np
.
zeros
((
21
,
14
),
dtype
=
np
.
float32
)
restype_atom14_ambiguous_atoms_swap_idx
=
(
np
.
tile
(
np
.
arange
(
14
,
dtype
=
np
.
int
),
(
21
,
1
))
)
def
_make_atom14_ambiguity_feats
():
for
res
,
pairs
in
residue_atom_renaming_swaps
.
items
():
res_idx
=
restype_order
[
restype_3to1
[
res
]]
for
atom1
,
atom2
in
pairs
.
items
():
atom1_idx
=
restype_name_to_atom14_names
[
res
].
index
(
atom1
)
atom2_idx
=
restype_name_to_atom14_names
[
res
].
index
(
atom2
)
restype_atom14_ambiguous_atoms
[
res_idx
,
atom1_idx
]
=
1
restype_atom14_ambiguous_atoms
[
res_idx
,
atom2_idx
]
=
1
restype_atom14_ambiguous_atoms_swap_idx
[
res_idx
,
atom1_idx
]
=
(
atom2_idx
)
restype_atom14_ambiguous_atoms_swap_idx
[
res_idx
,
atom2_idx
]
=
(
atom1_idx
)
_make_atom14_ambiguity_feats
()
openfold/utils/affine_utils.py
View file @
68ba77e5
...
...
@@ -335,3 +335,15 @@ def quat_to_rot(
# [*, 3, 3]
return
torch
.
sum
(
quat
,
dim
=
(
-
3
,
-
4
))
def
affine_vector_to_4x4
(
vector
):
quats
=
vector
[...,
:
4
]
trans
=
vector
[...,
4
:]
four_by_four
=
torch
.
zeros
(
(
*
vector
.
shape
[:
-
1
],
4
,
4
),
device
=
vector
.
device
)
four_by_four
[...,
:
3
,
:
3
]
=
quat_to_rot
(
quats
)
four_by_four
[...,
:
3
,
3
]
=
trans
four_by_four
[...,
3
,
3
]
=
1
return
four_by_four
openfold/utils/feats.py
View file @
68ba77e5
...
...
@@ -18,7 +18,7 @@ import torch
import
torch.nn
as
nn
from
typing
import
Dict
import
openfold.np.residue_constants
as
r
esidue_constants
import
openfold.np.residue_constants
as
r
c
from
openfold.utils.affine_utils
import
T
from
openfold.utils.tensor_utils
import
(
batched_gather
,
...
...
@@ -27,9 +27,9 @@ from openfold.utils.tensor_utils import (
def
pseudo_beta_fn
(
aatype
,
all_atom_positions
,
all_atom_masks
):
is_gly
=
(
aatype
==
r
esidue_constants
.
restype_order
[
'G'
])
ca_idx
=
r
esidue_constants
.
atom_order
[
'CA'
]
cb_idx
=
r
esidue_constants
.
atom_order
[
'CB'
]
is_gly
=
(
aatype
==
r
c
.
restype_order
[
'G'
])
ca_idx
=
r
c
.
atom_order
[
'CA'
]
cb_idx
=
r
c
.
atom_order
[
'CB'
]
pseudo_beta
=
torch
.
where
(
is_gly
[...,
None
].
expand
(
*
((
-
1
,)
*
len
(
is_gly
.
shape
)),
3
),
all_atom_positions
[...,
ca_idx
,
:],
...
...
@@ -52,18 +52,18 @@ def get_chi_atom_indices():
Returns:
A tensor of shape [residue_types=21, chis=4, atoms=4]. The residue types are
in the order specified in r
esidue_constants
.restypes + unknown residue type
in the order specified in r
c
.restypes + unknown residue type
at the end. For chi angles which are not defined on the residue, the
positions indices are by default set to 0.
"""
chi_atom_indices
=
[]
for
residue_name
in
r
esidue_constants
.
restypes
:
residue_name
=
r
esidue_constants
.
restype_1to3
[
residue_name
]
residue_chi_angles
=
r
esidue_constants
.
chi_angles_atoms
[
residue_name
]
for
residue_name
in
r
c
.
restypes
:
residue_name
=
r
c
.
restype_1to3
[
residue_name
]
residue_chi_angles
=
r
c
.
chi_angles_atoms
[
residue_name
]
atom_indices
=
[]
for
chi_angle
in
residue_chi_angles
:
atom_indices
.
append
(
[
r
esidue_constants
.
atom_order
[
atom
]
for
atom
in
chi_angle
])
[
r
c
.
atom_order
[
atom
]
for
atom
in
chi_angle
])
for
_
in
range
(
4
-
len
(
atom_indices
)):
atom_indices
.
append
([
0
,
0
,
0
,
0
])
# For chi angles not defined on the AA.
chi_atom_indices
.
append
(
atom_indices
)
...
...
@@ -74,6 +74,7 @@ def get_chi_atom_indices():
def
compute_residx
(
batch
):
out
=
{}
float_type
=
batch
[
"seq_mask"
].
dtype
aatype
=
batch
[
"aatype"
]
...
...
@@ -81,19 +82,20 @@ def compute_residx(batch):
restype_atom37_to_atom14
=
[]
# mapping (restype, atom37) --> atom14
restype_atom14_mask
=
[]
for
rt
in
residue_constants
.
restypes
:
atom_names
=
residue_constants
.
restype_name_to_atom14_names
[
residue_constants
.
restype_1to3
[
rt
]]
for
rt
in
rc
.
restypes
:
atom_names
=
rc
.
restype_name_to_atom14_names
[
rc
.
restype_1to3
[
rt
]
]
restype_atom14_to_atom37
.
append
([
(
r
esidue_constants
.
atom_order
[
name
]
if
name
else
0
)
(
r
c
.
atom_order
[
name
]
if
name
else
0
)
for
name
in
atom_names
])
atom_name_to_idx14
=
{
name
:
i
for
i
,
name
in
enumerate
(
atom_names
)}
restype_atom37_to_atom14
.
append
([
(
atom_name_to_idx14
[
name
]
if
name
in
atom_name_to_idx14
else
0
)
for
name
in
r
esidue_constants
.
atom_types
for
name
in
r
c
.
atom_types
])
restype_atom14_mask
.
append
(
...
...
@@ -118,24 +120,27 @@ def compute_residx(batch):
residx_atom14_to_atom37
=
restype_atom14_to_atom37
[
aatype
]
residx_atom14_mask
=
restype_atom14_mask
[
aatype
]
batch
[
'
atom14_
a
to
m_exists'
]
=
residx_atom14_
mask
batch
[
'residx_
atom14_
to_
atom
37'
]
=
residx_atom14_
to_atom37
out
[
"residx_
atom14_to
_atom37"
]
=
residx_atom14_
to_atom37
out
[
"
atom14_atom
_exists"
]
=
residx_atom14_
mask
# create the gather indices for mapping back
residx_atom37_to_atom14
=
restype_atom37_to_atom14
[
aatype
]
batch
[
'residx_atom37_to_atom14'
]
=
residx_atom37_to_atom14
out
[
"residx_atom37_to_atom14"
]
=
residx_atom37_to_atom14
# create the corresponding mask
restype_atom37_mask
=
torch
.
zeros
([
21
,
37
],
dtype
=
float_type
)
for
restype
,
restype_letter
in
enumerate
(
r
esidue_constants
.
restypes
):
restype_name
=
r
esidue_constants
.
restype_1to3
[
restype_letter
]
atom_names
=
r
esidue_constants
.
residue_atoms
[
restype_name
]
for
restype
,
restype_letter
in
enumerate
(
r
c
.
restypes
):
restype_name
=
r
c
.
restype_1to3
[
restype_letter
]
atom_names
=
r
c
.
residue_atoms
[
restype_name
]
for
atom_name
in
atom_names
:
atom_type
=
r
esidue_constants
.
atom_order
[
atom_name
]
atom_type
=
r
c
.
atom_order
[
atom_name
]
restype_atom37_mask
[
restype
,
atom_type
]
=
1
residx_atom37_mask
=
restype_atom37_mask
[
aatype
]
batch
[
'atom37_atom_exists'
]
=
residx_atom37_mask
out
[
"atom37_atom_exists"
]
=
residx_atom37_mask
return
out
def
atom14_to_atom37
(
atom14
,
batch
):
...
...
@@ -225,9 +230,9 @@ def atom37_to_torsion_angles(
all_atom_pos
,
atom_indices
,
-
2
,
len
(
atom_indices
.
shape
[:
-
2
])
)
chi_angles_mask
=
list
(
r
esidue_constants
.
chi_angles_mask
)
chi_angles_mask
=
list
(
r
c
.
chi_angles_mask
)
chi_angles_mask
.
append
([
0.
,
0.
,
0.
,
0.
])
chi_angles_mask
=
all_atom_
pos
.
new_tensor
(
chi_angles_mask
)
chi_angles_mask
=
all_atom_
mask
.
new_tensor
(
chi_angles_mask
)
chis_mask
=
chi_angles_mask
[
aatype
,
:]
...
...
@@ -282,7 +287,7 @@ def atom37_to_torsion_angles(
)[((
None
,)
*
len
(
torsion_angles_sin_cos
.
shape
[:
-
2
]))
+
(
slice
(
None
),
None
)]
chi_is_ambiguous
=
torsion_angles_sin_cos
.
new_tensor
(
r
esidue_constants
.
chi_pi_periodic
,
r
c
.
chi_pi_periodic
,
)[
aatype
,
...]
mirror_torsion_angles
=
torch
.
cat
(
...
...
@@ -307,6 +312,7 @@ def atom37_to_frames(
aatype
:
torch
.
Tensor
,
all_atom_positions
:
torch
.
Tensor
,
all_atom_mask
:
torch
.
Tensor
,
**
kwargs
,
)
->
Dict
[
str
,
torch
.
Tensor
]:
batch_dims
=
len
(
aatype
.
shape
[:
-
1
])
...
...
@@ -314,13 +320,14 @@ def atom37_to_frames(
restype_rigidgroup_base_atom_names
[:,
0
,
:]
=
[
'C'
,
'CA'
,
'N'
]
restype_rigidgroup_base_atom_names
[:,
3
,
:]
=
[
'CA'
,
'C'
,
'O'
]
for
restype
,
restype_letter
in
enumerate
(
r
esidue_constants
.
restypes
):
resname
=
r
esidue_constants
.
restype_1to3
[
restype_letter
]
for
restype
,
restype_letter
in
enumerate
(
r
c
.
restypes
):
resname
=
r
c
.
restype_1to3
[
restype_letter
]
for
chi_idx
in
range
(
4
):
if
(
r
esidue_constants
.
chi_angles_mask
[
restype
][
chi_idx
]):
names
=
r
esidue_constants
.
chi_angles_atoms
[
resname
][
chi_idx
]
if
(
r
c
.
chi_angles_mask
[
restype
][
chi_idx
]):
names
=
r
c
.
chi_angles_atoms
[
resname
][
chi_idx
]
restype_rigidgroup_base_atom_names
[
restype
,
chi_idx
+
4
,
:]
=
atom_names
[
1
:]
restype
,
chi_idx
+
4
,
:
]
=
names
[
1
:]
restype_rigidgroup_mask
=
torch
.
zeros
(
(
*
aatype
.
shape
[:
-
1
],
21
,
8
),
...
...
@@ -330,9 +337,11 @@ def atom37_to_frames(
)
restype_rigidgroup_mask
[:,
0
]
=
1
restype_rigidgroup_mask
[:,
3
]
=
1
restype_rigidgroup_mask
[:
20
,
4
:]
=
residue_constants
.
chi_angles_mask
restype_rigidgroup_mask
[:
20
,
4
:]
=
(
all_atom_mask
.
new_tensor
(
rc
.
chi_angles_mask
)
)
lookuptable
=
r
esidue_constants
.
atom_order
.
copy
()
lookuptable
=
r
c
.
atom_order
.
copy
()
lookuptable
[
''
]
=
0
lookup
=
np
.
vectorize
(
lambda
x
:
lookuptable
[
x
])
restype_rigidgroup_base_atom37_idx
=
lookup
(
...
...
@@ -349,7 +358,7 @@ def atom37_to_frames(
)
residx_rigidgroup_base_atom37_idx
=
batched_gather
(
res
idx
_rigidgroup_base_atom37_idx
,
res
type
_rigidgroup_base_atom37_idx
,
aatype
,
dim
=-
3
,
no_batch_dims
=
batch_dims
,
...
...
@@ -363,9 +372,9 @@ def atom37_to_frames(
)
gt_frames
=
T
.
from_3_points
(
p
oint_on
_neg_x_axis
=
base_atom_pos
[...,
0
,
:],
p_neg_x_axis
=
base_atom_pos
[...,
0
,
:],
origin
=
base_atom_pos
[...,
1
,
:],
p
oint_on
_xy_plane
=
base_atom_pos
[...,
2
,
:],
p_xy_plane
=
base_atom_pos
[...,
2
,
:],
)
group_exists
=
batched_gather
(
...
...
@@ -381,33 +390,31 @@ def atom37_to_frames(
dim
=-
1
,
no_batch_dims
=
len
(
all_atom_mask
.
shape
[:
-
1
])
)
gt_exists
=
torch
.
min
(
gt_atoms_exist
,
dim
=-
1
)
*
group_exists
gt_exists
=
torch
.
min
(
gt_atoms_exist
,
dim
=-
1
)
[
0
]
*
group_exists
rots
=
torch
.
eye
(
3
,
device
=
aatype
.
device
,
requires_grad
=
False
)
rots
=
rots
.
view
(
*
((
1
,)
*
batch_dims
),
1
,
3
,
3
)
rots
=
rots
.
expand
(
*
((
-
1
,)
*
batch_dims
),
8
,
-
1
,
-
1
)
rots
=
torch
.
tile
(
rots
,
(
*
((
1
,)
*
batch_dims
),
8
,
1
,
1
))
rots
[...,
0
,
0
,
0
]
=
-
1
rots
[...,
0
,
2
,
2
]
=
-
1
gt_frames
=
gt_frames
.
compose
(
T
(
rots
,
None
))
gt_frames
=
gt_frames
.
compose
(
T
(
rots
,
None
))
restype_rigidgroup_is_ambiguous
=
all_atom_mask
.
new_zeros
(
*
((
1
,)
*
batch_dims
),
21
,
8
)
restype_rigidgroup_rots
=
torch
.
eye
(
3
,
device
=
aatype
.
device
,
requires_grad
=
False
)
restype_rigidgroup_rots
=
restype_rigidgroup_rots
.
view
(
*
((
1
,)
*
batch_dims
),
1
,
1
,
3
,
3
)
restype_rigidgroup_rots
=
restype_rigidgroup_rots
.
expand
(
*
((
-
1
,)
*
batch_dims
),
21
,
8
,
3
,
3
restype_rigidgroup_rots
=
torch
.
tile
(
restype_rigidgroup_rots
,
(
*
((
1
,)
*
batch_dims
),
21
,
8
,
1
,
1
),
)
for
resname
,
_
in
r
esidue_constants
.
residue_atom_renaming_swaps
.
items
():
restype
=
r
esidue_constants
.
restype_order
[
r
esidue_constants
.
restype3to1
[
resname
]
for
resname
,
_
in
r
c
.
residue_atom_renaming_swaps
.
items
():
restype
=
r
c
.
restype_order
[
r
c
.
restype
_
3to1
[
resname
]
]
chi_idx
=
int
(
sum
(
r
esidue_constants
.
chi_angles_mask
[
restype
])
-
1
)
chi_idx
=
int
(
sum
(
r
c
.
chi_angles_mask
[
restype
])
-
1
)
restype_rigidgroup_is_ambiguous
[...,
restype
,
chi_idx
+
4
]
=
1
restype_rigidgroup_rots
[...,
restype
,
chi_idx
+
4
,
1
,
1
]
=
-
1
restype_rigidgroup_rots
[...,
restype
,
chi_idx
+
4
,
2
,
2
]
=
-
1
...
...
@@ -419,18 +426,17 @@ def atom37_to_frames(
no_batch_dims
=
batch_dims
,
)
residx_rigidgroup_ambiguity_rot
=
utils
.
batched_gather
(
residx_rigidgroup_ambiguity_rot
=
batched_gather
(
restype_rigidgroup_rots
,
aatype
,
dim
=-
4
,
no_batch_dims
=
batch_dims
,
)
alt_gt_frames
=
gt_frames
.
apply
(
T
(
residx_rigidgroup_ambiguity_rot
,
None
))
alt_gt_frames
=
gt_frames
.
compose
(
T
(
residx_rigidgroup_ambiguity_rot
,
None
))
# TODO: Verify that I can get away with skipping the flat12 format
gt_frames_tensor
=
gt_frames
.
to_tensor
()
alt_gt_frames_tensor
=
alt_gt_frames
.
to_tensor
()
gt_frames_tensor
=
gt_frames
.
to_4x4
()
alt_gt_frames_tensor
=
alt_gt_frames
.
to_4x4
()
return
{
'rigidgroups_gt_frames'
:
gt_frames_tensor
,
...
...
@@ -477,7 +483,7 @@ def build_template_pair_feat(batch, min_bin, max_bin, no_bins, eps=1e-6, inf=1e8
to_concat
=
[
dgram
,
template_mask_2d
[...,
None
]]
aatype_one_hot
=
nn
.
functional
.
one_hot
(
batch
[
"template_aatype"
],
r
esidue_constants
.
restype_num
+
2
,
batch
[
"template_aatype"
],
r
c
.
restype_num
+
2
,
)
n_res
=
batch
[
"template_aatype"
].
shape
[
-
1
]
...
...
@@ -492,7 +498,7 @@ def build_template_pair_feat(batch, min_bin, max_bin, no_bins, eps=1e-6, inf=1e8
)
)
n
,
ca
,
c
=
[
r
esidue_constants
.
atom_order
[
a
]
for
a
in
[
'N'
,
'CA'
,
'C'
]]
n
,
ca
,
c
=
[
r
c
.
atom_order
[
a
]
for
a
in
[
'N'
,
'CA'
,
'C'
]]
t_aa_masks
=
batch
[
"template_all_atom_masks"
]
template_mask
=
(
...
...
@@ -522,7 +528,7 @@ def build_extra_msa_feat(batch):
# adapted from model/tf/data_transforms.py
def
build_msa_feat
(
protein
):
def
build_msa_feat
(
batch
):
"""Create and concatenate MSA features."""
# Whether there is a domain break. Always zero for chains, but keeping
# for compatibility with domain datasets.
...
...
@@ -544,7 +550,7 @@ def build_msa_feat(protein):
deletion_value
.
unsqueeze
(
-
1
),
]
if
'cluster_profile'
in
protein
:
if
'cluster_profile'
in
batch
:
deletion_mean_value
=
(
tf
.
atan
(
batch
[
'cluster_deletion_mean'
]
/
3.
)
*
(
2.
/
np
.
pi
))
msa_feat
.
extend
([
...
...
@@ -560,4 +566,53 @@ def build_msa_feat(protein):
batch
[
'msa_feat'
]
=
torch
.
cat
(
msa_feat
,
dim
=-
1
)
batch
[
'target_feat'
]
=
torch
.
cat
(
target_feat
,
dim
=-
1
)
return
protein
return
batch
def
build_ambiguity_feats
(
batch
:
Dict
[
str
,
torch
.
Tensor
])
->
None
:
"""
Compute features required by compute_renamed_ground_truth (Alg. 26)
Args:
batch:
str/tensor dictionary containing:
* atom14_gt_positions: [*, N, 14, 3] ground truth pos.
* atom14_gt_exists: [*, N, 14] atom mask
* aatype: [*, N] residue indices
Returns:
str/tensor dictionary containing:
* atom14_atom_is_ambiguous: [*, N, 14] mask of ambiguous atoms
* atom14_alt_gt_positions: [*, N, 14, 3] renamed positions
"""
ambiguous_atoms
=
(
batch
[
"atom14_gt_positions"
].
new_tensor
(
rc
.
restype_atom14_ambiguous_atoms
,
requires_grad
=
False
,
)
)
atom14_atom_is_ambiguous
=
ambiguous_atoms
[
batch
[
"aatype"
],
...]
# Swap pairs of ambiguous positions
swap_idx
=
rc
.
restype_atom14_ambiguous_atoms_swap_idx
swap_mat
=
np
.
eye
(
swap_idx
.
shape
[
-
1
])[
swap_idx
]
# one-hot swap_idx
swap_mat
=
batch
[
"atom14_gt_positions"
].
new_tensor
(
swap_mat
,
requires_grad
=
False
)
swap_mat
=
swap_mat
[
batch
[
"aatype"
],
...]
atom14_alt_gt_positions
=
(
torch
.
sum
(
batch
[
"atom14_gt_positions"
][...,
None
,
:]
*
swap_mat
[...,
None
],
dim
=-
3
)
)
atom14_alt_gt_exists
=
(
torch
.
sum
(
batch
[
"atom14_gt_exists"
][...,
None
]
*
swap_mat
,
dim
=-
2
)
)
return
{
"atom14_atom_is_ambiguous"
:
atom14_atom_is_ambiguous
,
"atom14_alt_gt_positions"
:
atom14_alt_gt_positions
,
"atom14_alt_gt_exists"
:
atom14_alt_gt_exists
,
}
openfold/utils/loss.py
View file @
68ba77e5
...
...
@@ -89,15 +89,15 @@ def compute_fape(
target_positions
[...,
None
,
:,
:],
)
error_dist
=
torch
.
sqrt
(
(
pred_positions
-
target_positions
)
**
2
+
eps
torch
.
sum
((
local_pred_pos
-
local_target_pos
)
**
2
,
dim
=-
1
)
+
eps
)
if
(
l1_clamp_distance
is
not
None
):
error_dist
=
torch
.
clamp
(
error_dist
,
min
=
0
,
max
=
l1_clamp_distance
)
normed_error
=
error_dist
/
length_scale
normed_error
*=
frames_mask
.
unsqueeze
(
-
1
)
normed_error
*=
positions_mask
.
unsqueeze
(
-
2
)
normed_error
*=
frames_mask
[...,
None
]
normed_error
*=
positions_mask
[...,
None
,
:]
norm_factor
=
(
torch
.
sum
(
frames_mask
,
dim
=-
1
)
*
...
...
@@ -109,67 +109,71 @@ def compute_fape(
return
normed_error
# DISCREPANCY: figure out if loss clamping happens in 90% of each bach or in 90% of batches
def
backbone_loss
(
batch
:
Dict
[
str
,
torch
.
Tensor
],
pred_aff_tensor
:
torch
.
Tensor
,
backbone_affine_tensor
:
torch
.
Tensor
,
backbone_affine_mask
:
torch
.
Tensor
,
traj
:
torch
.
Tensor
,
use_clamped_fape
:
Optional
[
torch
.
Tensor
]
=
None
,
clamp_distance
:
float
=
10.
,
loss_unit_distance
:
float
=
10.
,
**
kwargs
,
)
->
torch
.
Tensor
:
pred_aff
=
T
.
from_tensor
(
pred_aff_tensor
)
gt_aff
=
T
.
from_tensor
(
batch
[
"backbone_affine_tensor"
])
backbone_mask
=
batch
[
"backbone_affine_mask"
]
pred_aff
=
T
.
from_tensor
(
traj
)
gt_aff
=
T
.
from_tensor
(
backbone_affine_tensor
)
fape_loss
=
compute_fape
(
pred_aff
,
gt_aff
,
backbone_
mask
,
gt_aff
[...,
None
,
:]
,
backbone_
affine_mask
[...,
None
,
:]
,
pred_aff
.
get_trans
(),
gt_aff
.
get_trans
(),
backbone_
mask
,
gt_aff
[...,
None
,
:]
.
get_trans
(),
backbone_
affine_mask
[...,
None
,
:]
,
l1_clamp_distance
=
clamp_distance
,
length_scale
=
loss_unit_distance
,
)
if
(
'use_clamped_fape'
in
batch
):
use_clamped_fape
=
batch
[
"use_clamped_fape"
]
if
(
use_clamped_fape
is
not
None
):
unclamped_fape_loss
=
compute_fape
(
pred_aff
,
gt_aff
,
backbone_
mask
,
gt_aff
[...,
None
,
:]
,
backbone_
affine_mask
[...,
None
,
:]
,
pred_aff
.
get_trans
(),
gt_aff
.
get_trans
(),
backbone_
mask
,
gt_aff
[...,
None
,
:]
.
get_trans
(),
backbone_
affine_mask
[...,
None
,
:]
,
l1_clamp_distance
=
None
,
length_scale
=
loss_unit_distance
,
)
fape_loss
=
(
fape_loss
*
use_clamped_fape
+
fape_loss_unclamped
*
(
1
-
use_clamped_fape
)
unclamped_fape_loss
*
(
1
-
use_clamped_fape
)
)
return
torch
.
mean
(
fape_loss
,
dim
=-
1
)
def
sidechain_loss
(
sidechain_frames
,
sidechain_atom_pos
,
rigidgroups_gt_frames
,
rigidgroups_alt_gt_frames
,
rigidgroups_gt_exists
,
renamed_atom14_gt_positions
,
renamed_atom14_gt_exists
,
alt_naming_is_better
,
clamp_distance
=
10.
,
length_scale
=
10.
,
):
sidechain_frames
:
torch
.
Tensor
,
sidechain_atom_pos
:
torch
.
Tensor
,
rigidgroups_gt_frames
:
torch
.
Tensor
,
rigidgroups_alt_gt_frames
:
torch
.
Tensor
,
rigidgroups_gt_exists
:
torch
.
Tensor
,
renamed_atom14_gt_positions
:
torch
.
Tensor
,
renamed_atom14_gt_exists
:
torch
.
Tensor
,
alt_naming_is_better
:
torch
.
Tensor
,
clamp_distance
:
float
=
10.
,
length_scale
:
float
=
10.
,
**
kwargs
,
)
->
torch
.
Tensor
:
renamed_gt_frames
=
(
(
1.
-
alt_naming_is_better
[...,
None
,
None
,
None
,
None
])
*
gt_frames
+
rigidgroups_
gt_frames
+
alt_naming_is_better
[...,
None
,
None
,
None
,
None
]
*
alt_gt_frames
rigidgroups_
alt_gt_frames
)
sidechain_frames
=
T
.
from_4x4
(
sidechain_frames
)
renamed_gt_frames
=
T
.
from_4x4
(
renamed_gt_frames
)
fape
=
compute_fape
(
...
...
@@ -192,16 +196,13 @@ def fape_loss(
config
:
ml_collections
.
ConfigDict
,
)
->
torch
.
Tensor
:
bb_loss
=
backbone_loss
(
batch
,
out
[
"sm"
][
"frames"
]
[
-
1
]
,
**
config
.
backbone
traj
=
out
[
"sm"
][
"frames"
]
,
**
{
**
batch
,
**
config
.
backbone
},
)
sc_loss
=
sidechain_loss
(
out
[
"sm"
][
"sidechain_frames"
],
out
[
"sm"
][
"positions"
],
{
**
batch
,
**
config
.
sidechain
,
},
**
{
**
batch
,
**
config
.
sidechain
}
)
return
(
...
...
tests/utils.py
View file @
68ba77e5
...
...
@@ -13,6 +13,7 @@
# limitations under the License.
import
numpy
as
np
from
scipy.spatial.transform
import
Rotation
def
random_template_feats
(
n_templ
,
n
,
batch_size
=
None
):
...
...
@@ -35,6 +36,7 @@ def random_template_feats(n_templ, n, batch_size=None):
batch
[
"template_aatype"
]
=
batch
[
"template_aatype"
].
astype
(
np
.
int64
)
return
batch
def
random_extra_msa_feats
(
n_extra
,
n
,
batch_size
=
None
):
b
=
[]
if
(
batch_size
is
not
None
):
...
...
@@ -50,3 +52,34 @@ def random_extra_msa_feats(n_extra, n, batch_size=None):
np
.
random
.
randint
(
0
,
2
,
(
*
b
,
n_extra
,
n
)).
astype
(
np
.
float32
),
}
return
batch
def
random_affine_vectors
(
dim
):
prod_dim
=
1
for
d
in
dim
:
prod_dim
*=
d
affines
=
np
.
zeros
((
prod_dim
,
7
))
for
i
in
range
(
prod_dim
):
affines
[
i
,
:
4
]
=
Rotation
.
random
(
random_state
=
42
).
as_quat
()
affines
[
i
,
4
:]
=
np
.
random
.
rand
(
3
,)
return
affines
.
reshape
(
*
dim
,
7
)
def
random_affine_4x4s
(
dim
):
prod_dim
=
1
for
d
in
dim
:
prod_dim
*=
d
affines
=
np
.
zeros
((
prod_dim
,
4
,
4
))
for
i
in
range
(
prod_dim
):
affines
[
i
,
:
3
,
:
3
]
=
Rotation
.
random
(
random_state
=
42
).
as_matrix
()
affines
[
i
,
:
3
,
3
]
=
np
.
random
.
rand
(
3
,)
affines
[:,
3
,
3
]
=
1
return
affines
.
reshape
(
*
dim
,
4
,
4
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment