OpenDAS / OpenFold · Commits

Commit 68ba77e5, authored Sep 24, 2021 by Gustaf Ahdritz
Parent: 33941e46

    Continue fixing loss bugs, clean up structure module docs

Showing 6 changed files with 237 additions and 105 deletions (+237 -105).
openfold/model/structure_module.py   +18  -10
openfold/np/residue_constants.py     +23   -0
openfold/utils/affine_utils.py       +12   -0
openfold/utils/feats.py             +113  -58
openfold/utils/loss.py               +38  -37
tests/utils.py                       +33   -0
openfold/model/structure_module.py

@@ -486,12 +486,12 @@ def _frames_and_literature_positions_to_atom14_pos(
 ):
     # [*, N, 14, 4, 4]
-    default_4x4 = default_frames[f,...]
+    default_4x4 = default_frames[f, ...]

     # [*, N, 14]
-    group_mask = group_idx[f,...]
+    group_mask = group_idx[f, ...]

-    # [N, 14, 8]
+    # [*, N, 14, 8]
     group_mask = nn.functional.one_hot(
         group_mask,
         num_classes=default_frames.shape[-3],
     )

@@ -504,11 +504,11 @@ def _frames_and_literature_positions_to_atom14_pos(
         lambda x: torch.sum(x, dim=-1)
     )

-    # [N, 14, 1]
+    # [*, N, 14, 1]
     atom_mask = atom_mask[f,...].unsqueeze(-1)

-    # [N, 14, 3]
-    lit_positions = lit_positions[f,...]
+    # [*, N, 14, 3]
+    lit_positions = lit_positions[f, ...]
     pred_positions = t_atoms_to_global.apply(lit_positions)
     pred_positions *= atom_mask
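Note: the corrected shape comments track a leading batch dimension `*`. As a minimal sketch (not OpenFold code; names and shapes assumed), this is how a per-atom group index becomes a one-hot mask that selects one of the eight rigid-group frames per atom:

    import torch
    import torch.nn as nn

    N, A, G = 5, 14, 8                       # residues, atom14 slots, rigid groups
    frames = torch.randn(N, A, G, 4, 4)      # candidate frames per atom
    group_idx = torch.randint(0, G, (N, A))  # which group each atom belongs to

    # [N, 14] -> [N, 14, 8], matching the corrected "[*, N, 14, 8]" comment
    one_hot = nn.functional.one_hot(group_idx, num_classes=G)

    # broadcast the mask over the trailing 4x4 dims, reduce over groups
    selected = torch.sum(frames * one_hot[..., None, None], dim=-3)  # [N, 14, 4, 4]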
@@ -758,19 +758,27 @@ class StructureModule(nn.Module):
     def _init_residue_constants(self, device):
         if(self.default_frames is None):
             self.default_frames = torch.tensor(
-                restype_rigid_group_default_frame, device=device,
+                restype_rigid_group_default_frame,
+                device=device,
+                requires_grad=False,
             )
         if(self.group_idx is None):
             self.group_idx = torch.tensor(
-                restype_atom14_to_rigid_group, device=device,
+                restype_atom14_to_rigid_group,
+                device=device,
+                requires_grad=False,
             )
         if(self.atom_mask is None):
             self.atom_mask = torch.tensor(
-                restype_atom14_mask, device=device,
+                restype_atom14_mask,
+                device=device,
+                requires_grad=False,
             )
         if(self.lit_positions is None):
             self.lit_positions = torch.tensor(
-                restype_atom14_rigid_group_positions, device=device,
+                restype_atom14_rigid_group_positions,
+                device=device,
+                requires_grad=False,
             )

     def torsion_angles_to_frames(self, t, alpha, f):
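Note: the hunk above also pins `requires_grad=False` on each lookup table. A hedged sketch of the underlying pattern — constants built lazily on first call, on the device of the incoming data, with placeholder values standing in for the real tables:

    import torch

    class LazyConstants:
        def __init__(self):
            self.default_frames = None  # not built until the device is known

        def _init_residue_constants(self, device):
            # build once, as a non-trainable tensor on the right device
            if self.default_frames is None:
                self.default_frames = torch.tensor(
                    [[1.0, 0.0], [0.0, 1.0]],  # placeholder, not the real table
                    device=device,
                    requires_grad=False,
                )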
openfold/np/residue_constants.py

@@ -366,6 +366,7 @@ residue_atoms = {
 # (The LDDT paper lists 7 amino acids as ambiguous, but the naming ambiguities
 # in LEU, VAL and ARG can be resolved by using the 3d constellations of
 # the 'ambiguous' atoms and their neighbours)
+# TODO: ^ interpret this
 residue_atom_renaming_swaps = {
     'ASP': {'OD1': 'OD2'},
     'GLU': {'OE1': 'OE2'},

@@ -895,3 +896,25 @@ def make_atom14_dists_bounds(overlap_tolerance=1.5,
         'upper_bound': restype_atom14_bond_upper_bound,  # shape (21,14,14)
         'stddev': restype_atom14_bond_stddev,  # shape (21,14,14)
     }
+
+restype_atom14_ambiguous_atoms = np.zeros((21, 14), dtype=np.float32)
+restype_atom14_ambiguous_atoms_swap_idx = (
+    np.tile(np.arange(14, dtype=np.int), (21, 1))
+)
+
+def _make_atom14_ambiguity_feats():
+    for res, pairs in residue_atom_renaming_swaps.items():
+        res_idx = restype_order[restype_3to1[res]]
+        for atom1, atom2 in pairs.items():
+            atom1_idx = restype_name_to_atom14_names[res].index(atom1)
+            atom2_idx = restype_name_to_atom14_names[res].index(atom2)
+            restype_atom14_ambiguous_atoms[res_idx, atom1_idx] = 1
+            restype_atom14_ambiguous_atoms[res_idx, atom2_idx] = 1
+            restype_atom14_ambiguous_atoms_swap_idx[res_idx, atom1_idx] = (
+                atom2_idx
+            )
+            restype_atom14_ambiguous_atoms_swap_idx[res_idx, atom2_idx] = (
+                atom1_idx
+            )
+
+_make_atom14_ambiguity_feats()
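Note: the two new tables mark which atom14 slots are ambiguous and where each slot swaps to. A small demo of the mechanics, assuming ASP's OD1/OD2 occupy atom14 slots 6 and 7 (the slot numbers are an assumption of this sketch):

    import numpy as np

    # swap_idx maps every slot to itself, except the ambiguous pair
    swap_idx = np.arange(14)
    swap_idx[6], swap_idx[7] = 7, 6  # assumed OD1 <-> OD2 slots for ASP

    # one-hot rows of swap_idx give a permutation matrix, the same
    # np.eye(...)[swap_idx] trick used downstream in feats.py
    swap_mat = np.eye(14)[swap_idx]

    coords = np.random.rand(14, 3)
    swapped = swap_mat.T @ coords
    assert np.allclose(swapped, coords[swap_idx])  # pair exchanged, rest fixed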
openfold/utils/affine_utils.py

@@ -335,3 +335,15 @@ def quat_to_rot(
     # [*, 3, 3]
     return torch.sum(quat, dim=(-3, -4))

+
+def affine_vector_to_4x4(vector):
+    quats = vector[..., :4]
+    trans = vector[..., 4:]
+
+    four_by_four = torch.zeros(
+        (*vector.shape[:-1], 4, 4), device=vector.device
+    )
+    four_by_four[..., :3, :3] = quat_to_rot(quats)
+    four_by_four[..., :3, 3] = trans
+    four_by_four[..., 3, 3] = 1
+
+    return four_by_four
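Note: a hedged usage sketch for the new helper. A 7-vector splits into a quaternion and a translation, so an identity rotation plus a translation should yield a pure-translation matrix (the quaternion component ordering here is an assumption, not taken from the diff):

    import torch

    vec = torch.tensor([1., 0., 0., 0.,   # quaternion, assumed identity
                        1., 2., 3.])      # translation
    m = affine_vector_to_4x4(vec)
    # if quat_to_rot maps this quaternion to eye(3), then:
    #   m[:3, :3] == I,  m[:3, 3] == (1., 2., 3.),  m[3, 3] == 1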
openfold/utils/feats.py

@@ -18,7 +18,7 @@ import torch
 import torch.nn as nn
 from typing import Dict

-import openfold.np.residue_constants as residue_constants
+import openfold.np.residue_constants as rc
 from openfold.utils.affine_utils import T
 from openfold.utils.tensor_utils import (
     batched_gather,

@@ -27,9 +27,9 @@ from openfold.utils.tensor_utils import (
 def pseudo_beta_fn(aatype, all_atom_positions, all_atom_masks):
-    is_gly = (aatype == residue_constants.restype_order['G'])
+    is_gly = (aatype == rc.restype_order['G'])
-    ca_idx = residue_constants.atom_order['CA']
+    ca_idx = rc.atom_order['CA']
-    cb_idx = residue_constants.atom_order['CB']
+    cb_idx = rc.atom_order['CB']
     pseudo_beta = torch.where(
         is_gly[..., None].expand(*((-1,) * len(is_gly.shape)), 3),
         all_atom_positions[..., ca_idx, :],

@@ -52,18 +52,18 @@ def get_chi_atom_indices():
     Returns:
         A tensor of shape [residue_types=21, chis=4, atoms=4]. The residue types are
-        in the order specified in residue_constants.restypes + unknown residue type
+        in the order specified in rc.restypes + unknown residue type
         at the end. For chi angles which are not defined on the residue, the
         positions indices are by default set to 0.
     """
     chi_atom_indices = []
-    for residue_name in residue_constants.restypes:
+    for residue_name in rc.restypes:
-        residue_name = residue_constants.restype_1to3[residue_name]
+        residue_name = rc.restype_1to3[residue_name]
-        residue_chi_angles = residue_constants.chi_angles_atoms[residue_name]
+        residue_chi_angles = rc.chi_angles_atoms[residue_name]
         atom_indices = []
         for chi_angle in residue_chi_angles:
             atom_indices.append(
-                [residue_constants.atom_order[atom] for atom in chi_angle])
+                [rc.atom_order[atom] for atom in chi_angle])
         for _ in range(4 - len(atom_indices)):
             atom_indices.append([0, 0, 0, 0])  # For chi angles not defined on the AA.
         chi_atom_indices.append(atom_indices)
@@ -74,6 +74,7 @@ def get_chi_atom_indices():
 def compute_residx(batch):
+    out = {}
     float_type = batch["seq_mask"].dtype
     aatype = batch["aatype"]

@@ -81,19 +82,20 @@ def compute_residx(batch):
     restype_atom37_to_atom14 = []  # mapping (restype, atom37) --> atom14
     restype_atom14_mask = []

-    for rt in residue_constants.restypes:
+    for rt in rc.restypes:
-        atom_names = residue_constants.restype_name_to_atom14_names[
-            residue_constants.restype_1to3[rt]]
+        atom_names = rc.restype_name_to_atom14_names[
+            rc.restype_1to3[rt]
+        ]
         restype_atom14_to_atom37.append([
-            (residue_constants.atom_order[name] if name else 0)
+            (rc.atom_order[name] if name else 0)
             for name in atom_names
         ])
         atom_name_to_idx14 = {name: i for i, name in enumerate(atom_names)}
         restype_atom37_to_atom14.append([
             (atom_name_to_idx14[name] if name in atom_name_to_idx14 else 0)
-            for name in residue_constants.atom_types
+            for name in rc.atom_types
         ])
         restype_atom14_mask.append(

@@ -118,24 +120,27 @@ def compute_residx(batch):
     residx_atom14_to_atom37 = restype_atom14_to_atom37[aatype]
     residx_atom14_mask = restype_atom14_mask[aatype]

-    batch['atom14_atom_exists'] = residx_atom14_mask
-    batch['residx_atom14_to_atom37'] = residx_atom14_to_atom37
+    out["residx_atom14_to_atom37"] = residx_atom14_to_atom37
+    out["atom14_atom_exists"] = residx_atom14_mask

     # create the gather indices for mapping back
     residx_atom37_to_atom14 = restype_atom37_to_atom14[aatype]
-    batch['residx_atom37_to_atom14'] = residx_atom37_to_atom14
+    out["residx_atom37_to_atom14"] = residx_atom37_to_atom14

     # create the corresponding mask
     restype_atom37_mask = torch.zeros([21, 37], dtype=float_type)
-    for restype, restype_letter in enumerate(residue_constants.restypes):
+    for restype, restype_letter in enumerate(rc.restypes):
-        restype_name = residue_constants.restype_1to3[restype_letter]
+        restype_name = rc.restype_1to3[restype_letter]
-        atom_names = residue_constants.residue_atoms[restype_name]
+        atom_names = rc.residue_atoms[restype_name]
         for atom_name in atom_names:
-            atom_type = residue_constants.atom_order[atom_name]
+            atom_type = rc.atom_order[atom_name]
             restype_atom37_mask[restype, atom_type] = 1

     residx_atom37_mask = restype_atom37_mask[aatype]
-    batch['atom37_atom_exists'] = residx_atom37_mask
+    out["atom37_atom_exists"] = residx_atom37_mask
+
+    return out

 def atom14_to_atom37(atom14, batch):
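Note: `compute_residx` now returns its gather indices and masks in a fresh `out` dict instead of mutating `batch`. A sketch of how such indices are typically consumed downstream (this helper is illustrative, not part of the diff):

    import torch

    def gather_atom14(atom37_pos, residx_atom14_to_atom37, atom14_atom_exists):
        # atom37_pos: [N, 37, 3]; indices (torch.long) and mask: [N, 14]
        idx = residx_atom14_to_atom37[..., None].expand(-1, -1, 3)  # [N, 14, 3]
        atom14_pos = torch.gather(atom37_pos, -2, idx)
        return atom14_pos * atom14_atom_exists[..., None]  # zero absent slots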
@@ -225,9 +230,9 @@ def atom37_to_torsion_angles(
         all_atom_pos, atom_indices, -2, len(atom_indices.shape[:-2])
     )

-    chi_angles_mask = list(residue_constants.chi_angles_mask)
+    chi_angles_mask = list(rc.chi_angles_mask)
     chi_angles_mask.append([0., 0., 0., 0.])
-    chi_angles_mask = all_atom_pos.new_tensor(chi_angles_mask)
+    chi_angles_mask = all_atom_mask.new_tensor(chi_angles_mask)

     chis_mask = chi_angles_mask[aatype, :]
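Note: `Tensor.new_tensor` copies data onto the source tensor's device with the source's dtype, so after this fix the chi mask derives both from `all_atom_mask` rather than from the positions tensor. In miniature (illustrative only):

    import torch

    mask = torch.ones(2, 37)                 # float32, CPU
    chi_angles_mask = [[1., 1., 0., 0.]]
    t = mask.new_tensor(chi_angles_mask)
    assert t.dtype == mask.dtype and t.device == mask.device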
@@ -282,7 +287,7 @@ def atom37_to_torsion_angles(
     )[((None,) * len(torsion_angles_sin_cos.shape[:-2])) + (slice(None), None)]

     chi_is_ambiguous = torsion_angles_sin_cos.new_tensor(
-        residue_constants.chi_pi_periodic,
+        rc.chi_pi_periodic,
     )[aatype, ...]

     mirror_torsion_angles = torch.cat(

@@ -307,6 +312,7 @@ def atom37_to_frames(
     aatype: torch.Tensor,
     all_atom_positions: torch.Tensor,
     all_atom_mask: torch.Tensor,
+    **kwargs,
 ) -> Dict[str, torch.Tensor]:
     batch_dims = len(aatype.shape[:-1])

@@ -314,13 +320,14 @@ def atom37_to_frames(
     restype_rigidgroup_base_atom_names[:, 0, :] = ['C', 'CA', 'N']
     restype_rigidgroup_base_atom_names[:, 3, :] = ['CA', 'C', 'O']
-    for restype, restype_letter in enumerate(residue_constants.restypes):
+    for restype, restype_letter in enumerate(rc.restypes):
-        resname = residue_constants.restype_1to3[restype_letter]
+        resname = rc.restype_1to3[restype_letter]
         for chi_idx in range(4):
-            if(residue_constants.chi_angles_mask[restype][chi_idx]):
+            if(rc.chi_angles_mask[restype][chi_idx]):
-                names = residue_constants.chi_angles_atoms[resname][chi_idx]
+                names = rc.chi_angles_atoms[resname][chi_idx]
                 restype_rigidgroup_base_atom_names[
-                    restype, chi_idx + 4, :] = atom_names[1:]
+                    restype, chi_idx + 4, :
+                ] = names[1:]

     restype_rigidgroup_mask = torch.zeros(
         (*aatype.shape[:-1], 21, 8),

@@ -330,9 +337,11 @@ def atom37_to_frames(
     )
     restype_rigidgroup_mask[:, 0] = 1
     restype_rigidgroup_mask[:, 3] = 1
-    restype_rigidgroup_mask[:20, 4:] = residue_constants.chi_angles_mask
+    restype_rigidgroup_mask[:20, 4:] = (
+        all_atom_mask.new_tensor(rc.chi_angles_mask)
+    )

-    lookuptable = residue_constants.atom_order.copy()
+    lookuptable = rc.atom_order.copy()
     lookuptable[''] = 0
     lookup = np.vectorize(lambda x: lookuptable[x])
     restype_rigidgroup_base_atom37_idx = lookup(

@@ -349,7 +358,7 @@ def atom37_to_frames(
     )

     residx_rigidgroup_base_atom37_idx = batched_gather(
-        residx_rigidgroup_base_atom37_idx,
+        restype_rigidgroup_base_atom37_idx,
         aatype,
         dim=-3,
         no_batch_dims=batch_dims,

@@ -363,9 +372,9 @@ def atom37_to_frames(
     )

     gt_frames = T.from_3_points(
-        point_on_neg_x_axis=base_atom_pos[..., 0, :],
+        p_neg_x_axis=base_atom_pos[..., 0, :],
         origin=base_atom_pos[..., 1, :],
-        point_on_xy_plane=base_atom_pos[..., 2, :],
+        p_xy_plane=base_atom_pos[..., 2, :],
     )

     group_exists = batched_gather(

@@ -381,33 +390,31 @@ def atom37_to_frames(
         dim=-1,
         no_batch_dims=len(all_atom_mask.shape[:-1])
     )

-    gt_exists = torch.min(gt_atoms_exist, dim=-1) * group_exists
+    gt_exists = torch.min(gt_atoms_exist, dim=-1)[0] * group_exists

     rots = torch.eye(3, device=aatype.device, requires_grad=False)
-    rots = rots.view(*((1,) * batch_dims), 1, 3, 3)
-    rots = rots.expand(*((-1,) * batch_dims), 8, -1, -1)
+    rots = torch.tile(rots, (*((1,) * batch_dims), 8, 1, 1))
     rots[..., 0, 0, 0] = -1
     rots[..., 0, 2, 2] = -1

     gt_frames = gt_frames.compose(T(rots, None))

     restype_rigidgroup_is_ambiguous = all_atom_mask.new_zeros(
         *((1,) * batch_dims), 21, 8
     )
     restype_rigidgroup_rots = torch.eye(
         3, device=aatype.device, requires_grad=False
     )
-    restype_rigidgroup_rots = restype_rigidgroup_rots.view(
-        *((1,) * batch_dims), 1, 1, 3, 3
-    )
-    restype_rigidgroup_rots = restype_rigidgroup_rots.expand(
-        *((-1,) * batch_dims), 21, 8, 3, 3
-    )
+    restype_rigidgroup_rots = torch.tile(
+        restype_rigidgroup_rots,
+        (*((1,) * batch_dims), 21, 8, 1, 1),
+    )

-    for resname, _ in residue_constants.residue_atom_renaming_swaps.items():
+    for resname, _ in rc.residue_atom_renaming_swaps.items():
-        restype = residue_constants.restype_order[
-            residue_constants.restype3to1[resname]
-        ]
+        restype = rc.restype_order[
+            rc.restype_3to1[resname]
+        ]
-        chi_idx = int(sum(residue_constants.chi_angles_mask[restype]) - 1)
+        chi_idx = int(sum(rc.chi_angles_mask[restype]) - 1)
         restype_rigidgroup_is_ambiguous[..., restype, chi_idx + 4] = 1
         restype_rigidgroup_rots[..., restype, chi_idx + 4, 1, 1] = -1
         restype_rigidgroup_rots[..., restype, chi_idx + 4, 2, 2] = -1
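Note: two of the fixes in the `atom37_to_frames` hunks above are generic PyTorch pitfalls, shown here in miniature:

    import torch

    # 1) torch.min with dim= returns a (values, indices) tuple, hence the
    #    added [0] before multiplying by group_exists:
    x = torch.rand(4, 8)
    vals = torch.min(x, dim=-1)[0]      # [4] tensor; without [0] it's a tuple

    # 2) expand() yields a broadcast view whose entries alias one another,
    #    which makes later index assignments like rots[..., 0, 0, 0] = -1
    #    unsafe; torch.tile() materializes independent copies instead:
    rots = torch.eye(3)
    rots = torch.tile(rots, (8, 1, 1))  # [8, 3, 3], real memory
    rots[0, 0, 0] = -1.0                # safe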
@@ -419,18 +426,17 @@ def atom37_to_frames(
         no_batch_dims=batch_dims,
     )

-    residx_rigidgroup_ambiguity_rot = utils.batched_gather(
+    residx_rigidgroup_ambiguity_rot = batched_gather(
         restype_rigidgroup_rots,
         aatype,
         dim=-4,
         no_batch_dims=batch_dims,
     )

-    alt_gt_frames = gt_frames.apply(T(residx_rigidgroup_ambiguity_rot, None))
+    alt_gt_frames = gt_frames.compose(T(residx_rigidgroup_ambiguity_rot, None))

-    # TODO: Verify that I can get away with skipping the flat12 format
-    gt_frames_tensor = gt_frames.to_4x4()
+    gt_frames_tensor = gt_frames.to_tensor()
-    alt_gt_frames_tensor = alt_gt_frames.to_4x4()
+    alt_gt_frames_tensor = alt_gt_frames.to_tensor()

     return {
         'rigidgroups_gt_frames': gt_frames_tensor,

@@ -477,7 +483,7 @@ def build_template_pair_feat(batch, min_bin, max_bin, no_bins, eps=1e-6, inf=1e8
     to_concat = [dgram, template_mask_2d[..., None]]

     aatype_one_hot = nn.functional.one_hot(
-        batch["template_aatype"], residue_constants.restype_num + 2,
+        batch["template_aatype"], rc.restype_num + 2,
     )

     n_res = batch["template_aatype"].shape[-1]

@@ -492,7 +498,7 @@ def build_template_pair_feat(batch, min_bin, max_bin, no_bins, eps=1e-6, inf=1e8
         )
     )

-    n, ca, c = [residue_constants.atom_order[a] for a in ['N', 'CA', 'C']]
+    n, ca, c = [rc.atom_order[a] for a in ['N', 'CA', 'C']]
     t_aa_masks = batch["template_all_atom_masks"]
     template_mask = (

@@ -522,7 +528,7 @@ def build_extra_msa_feat(batch):
 # adapted from model/tf/data_transforms.py
-def build_msa_feat(protein):
+def build_msa_feat(batch):
     """Create and concatenate MSA features."""
     # Whether there is a domain break. Always zero for chains, but keeping
     # for compatibility with domain datasets.

@@ -544,7 +550,7 @@ def build_msa_feat(protein):
         deletion_value.unsqueeze(-1),
     ]

-    if 'cluster_profile' in protein:
+    if 'cluster_profile' in batch:
         deletion_mean_value = (
             tf.atan(batch['cluster_deletion_mean'] / 3.) * (2. / np.pi))
         msa_feat.extend([

@@ -560,4 +566,53 @@ def build_msa_feat(protein):
     batch['msa_feat'] = torch.cat(msa_feat, dim=-1)
     batch['target_feat'] = torch.cat(target_feat, dim=-1)

-    return protein
+    return batch
+
+
+def build_ambiguity_feats(batch: Dict[str, torch.Tensor]) -> None:
+    """
+        Compute features required by compute_renamed_ground_truth (Alg. 26)
+
+        Args:
+            batch:
+                str/tensor dictionary containing:
+                    * atom14_gt_positions: [*, N, 14, 3] ground truth pos.
+                    * atom14_gt_exists: [*, N, 14] atom mask
+                    * aatype: [*, N] residue indices
+        Returns:
+            str/tensor dictionary containing:
+                * atom14_atom_is_ambiguous: [*, N, 14] mask of ambiguous atoms
+                * atom14_alt_gt_positions: [*, N, 14, 3] renamed positions
+    """
+    ambiguous_atoms = (
+        batch["atom14_gt_positions"].new_tensor(
+            rc.restype_atom14_ambiguous_atoms,
+            requires_grad=False,
+        )
+    )
+    atom14_atom_is_ambiguous = ambiguous_atoms[batch["aatype"], ...]
+
+    # Swap pairs of ambiguous positions
+    swap_idx = rc.restype_atom14_ambiguous_atoms_swap_idx
+    swap_mat = np.eye(swap_idx.shape[-1])[swap_idx]  # one-hot swap_idx
+    swap_mat = batch["atom14_gt_positions"].new_tensor(
+        swap_mat, requires_grad=False
+    )
+    swap_mat = swap_mat[batch["aatype"], ...]
+
+    atom14_alt_gt_positions = (
+        torch.sum(
+            batch["atom14_gt_positions"][..., None, :] * swap_mat[..., None],
+            dim=-3,
+        )
+    )
+
+    atom14_alt_gt_exists = (
+        torch.sum(batch["atom14_gt_exists"][..., None] * swap_mat, dim=-2)
+    )
+
+    return {
+        "atom14_atom_is_ambiguous": atom14_atom_is_ambiguous,
+        "atom14_alt_gt_positions": atom14_alt_gt_positions,
+        "atom14_alt_gt_exists": atom14_alt_gt_exists,
+    }
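Note: a hedged usage sketch for the new `build_ambiguity_feats`, with toy shapes; the `aatype` value assumes ASP sits at index 3 of the standard restype order:

    import torch

    batch = {
        "atom14_gt_positions": torch.rand(1, 14, 3),
        "atom14_gt_exists": torch.ones(1, 14),
        "aatype": torch.tensor([3]),  # assumed index of ASP in rc.restypes
    }
    feats = build_ambiguity_feats(batch)
    # atom14_alt_gt_positions is the ground truth with each ambiguous pair
    # (here ASP OD1/OD2) exchanged; unambiguous slots pass through unchanged.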
openfold/utils/loss.py

@@ -89,15 +89,15 @@ def compute_fape(
         target_positions[..., None, :, :],
     )

     error_dist = torch.sqrt(
-        (pred_positions - target_positions) ** 2 + eps
+        torch.sum((local_pred_pos - local_target_pos) ** 2, dim=-1) + eps
     )

     if(l1_clamp_distance is not None):
         error_dist = torch.clamp(error_dist, min=0, max=l1_clamp_distance)

     normed_error = error_dist / length_scale
-    normed_error *= frames_mask.unsqueeze(-1)
+    normed_error *= frames_mask[..., None]
-    normed_error *= positions_mask.unsqueeze(-2)
+    normed_error *= positions_mask[..., None, :]

     norm_factor = (
         torch.sum(frames_mask, dim=-1) *
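Note: the `error_dist` fix is the core loss bug here — a Euclidean distance needs a sum over the coordinate axis before the square root, otherwise the "distance" keeps a stray per-coordinate dimension. In miniature:

    import torch

    eps = 1e-8
    local_pred_pos = torch.rand(2, 5, 3)
    local_target_pos = torch.rand(2, 5, 3)

    # old form: sqrt of per-coordinate squares, shape [2, 5, 3] -- not a distance
    # fixed form: proper [2, 5] pairwise distances
    error_dist = torch.sqrt(
        torch.sum((local_pred_pos - local_target_pos) ** 2, dim=-1) + eps
    )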
@@ -109,67 +109,71 @@ def compute_fape(
     return normed_error

-# DISCREPANCY: figure out if loss clamping happens in 90% of each bach or in 90% of batches
 def backbone_loss(
-    batch: Dict[str, torch.Tensor],
-    pred_aff_tensor: torch.Tensor,
+    backbone_affine_tensor: torch.Tensor,
+    backbone_affine_mask: torch.Tensor,
+    traj: torch.Tensor,
+    use_clamped_fape: Optional[torch.Tensor] = None,
     clamp_distance: float = 10.,
     loss_unit_distance: float = 10.,
+    **kwargs,
 ) -> torch.Tensor:
-    pred_aff = T.from_tensor(pred_aff_tensor)
+    pred_aff = T.from_tensor(traj)
-    gt_aff = T.from_tensor(batch["backbone_affine_tensor"])
+    gt_aff = T.from_tensor(backbone_affine_tensor)
-    backbone_mask = batch["backbone_affine_mask"]

     fape_loss = compute_fape(
         pred_aff,
-        gt_aff,
+        gt_aff[..., None, :],
-        backbone_mask,
+        backbone_affine_mask[..., None, :],
         pred_aff.get_trans(),
-        gt_aff.get_trans(),
+        gt_aff[..., None, :].get_trans(),
-        backbone_mask,
+        backbone_affine_mask[..., None, :],
         l1_clamp_distance=clamp_distance,
         length_scale=loss_unit_distance,
     )

-    if('use_clamped_fape' in batch):
+    if(use_clamped_fape is not None):
-        use_clamped_fape = batch["use_clamped_fape"]
         unclamped_fape_loss = compute_fape(
             pred_aff,
-            gt_aff,
+            gt_aff[..., None, :],
-            backbone_mask,
+            backbone_affine_mask[..., None, :],
             pred_aff.get_trans(),
-            gt_aff.get_trans(),
+            gt_aff[..., None, :].get_trans(),
-            backbone_mask,
+            backbone_affine_mask[..., None, :],
             l1_clamp_distance=None,
             length_scale=loss_unit_distance,
         )

         fape_loss = (
             fape_loss * use_clamped_fape +
-            fape_loss_unclamped * (1 - use_clamped_fape)
+            unclamped_fape_loss * (1 - use_clamped_fape)
         )

     return torch.mean(fape_loss, dim=-1)

 def sidechain_loss(
-    sidechain_frames,
+    sidechain_frames: torch.Tensor,
-    sidechain_atom_pos,
+    sidechain_atom_pos: torch.Tensor,
-    rigidgroups_gt_frames,
+    rigidgroups_gt_frames: torch.Tensor,
-    rigidgroups_alt_gt_frames,
+    rigidgroups_alt_gt_frames: torch.Tensor,
-    rigidgroups_gt_exists,
+    rigidgroups_gt_exists: torch.Tensor,
-    renamed_atom14_gt_positions,
+    renamed_atom14_gt_positions: torch.Tensor,
-    renamed_atom14_gt_exists,
+    renamed_atom14_gt_exists: torch.Tensor,
-    alt_naming_is_better,
+    alt_naming_is_better: torch.Tensor,
-    clamp_distance=10.,
+    clamp_distance: float = 10.,
-    length_scale=10.,
+    length_scale: float = 10.,
-):
+    **kwargs,
+) -> torch.Tensor:
     renamed_gt_frames = (
         (1. - alt_naming_is_better[..., None, None, None, None]) *
-        gt_frames +
+        rigidgroups_gt_frames +
         alt_naming_is_better[..., None, None, None, None] *
-        alt_gt_frames
+        rigidgroups_alt_gt_frames
     )

+    sidechain_frames = T.from_4x4(sidechain_frames)
     renamed_gt_frames = T.from_4x4(renamed_gt_frames)

     fape = compute_fape(

@@ -192,16 +196,13 @@ def fape_loss(
     config: ml_collections.ConfigDict,
 ) -> torch.Tensor:
     bb_loss = backbone_loss(
-        batch,
-        out["sm"]["frames"][-1],
-        **config.backbone
+        traj=out["sm"]["frames"],
+        **{**batch, **config.backbone},
     )

     sc_loss = sidechain_loss(
         out["sm"]["sidechain_frames"],
         out["sm"]["positions"],
-        {**batch, **config.sidechain}
+        **{**batch, **config.sidechain},
     )

     return (
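Note: the refactored losses now take named tensors plus a `**kwargs` sink, so `fape_loss` can splat the whole feature dict merged with the config and let each loss pick out only what it names. A sketch of the pattern, with hypothetical keys:

    def backbone_loss_stub(backbone_affine_tensor, clamp_distance=10., **kwargs):
        # **kwargs silently absorbs every feature this loss doesn't use
        return backbone_affine_tensor * 0.0 + clamp_distance

    batch = {"backbone_affine_tensor": 1.0, "seq_mask": None}  # extra key ignored
    config = {"clamp_distance": 10.0}
    loss = backbone_loss_stub(**{**batch, **config})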
tests/utils.py

@@ -13,6 +13,7 @@
 # limitations under the License.

 import numpy as np
+from scipy.spatial.transform import Rotation

 def random_template_feats(n_templ, n, batch_size=None):

@@ -35,6 +36,7 @@ def random_template_feats(n_templ, n, batch_size=None):
     batch["template_aatype"] = batch["template_aatype"].astype(np.int64)
     return batch
+
 def random_extra_msa_feats(n_extra, n, batch_size=None):
     b = []
     if(batch_size is not None):

@@ -50,3 +52,34 @@ def random_extra_msa_feats(n_extra, n, batch_size=None):
         np.random.randint(0, 2, (*b, n_extra, n)).astype(np.float32),
     }
     return batch
+
+def random_affine_vectors(dim):
+    prod_dim = 1
+    for d in dim:
+        prod_dim *= d
+
+    affines = np.zeros((prod_dim, 7))
+    for i in range(prod_dim):
+        affines[i, :4] = Rotation.random(random_state=42).as_quat()
+        affines[i, 4:] = np.random.rand(3,)
+
+    return affines.reshape(*dim, 7)
+
+
+def random_affine_4x4s(dim):
+    prod_dim = 1
+    for d in dim:
+        prod_dim *= d
+
+    affines = np.zeros((prod_dim, 4, 4))
+    for i in range(prod_dim):
+        affines[i, :3, :3] = Rotation.random(random_state=42).as_matrix()
+        affines[i, :3, 3] = np.random.rand(3,)
+
+    affines[:, 3, 3] = 1
+
+    return affines.reshape(*dim, 4, 4)
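Note: a quick sanity-check sketch for the new test helpers — the rotation block of each generated affine should be orthogonal with unit determinant. Since `Rotation.random(random_state=42)` reseeds on every call, every entry receives the same rotation; presumably fine for shape tests.

    import numpy as np

    affines = random_affine_4x4s((2, 3))   # -> shape [2, 3, 4, 4]
    R = affines[0, 0, :3, :3]
    assert np.allclose(R @ R.T, np.eye(3), atol=1e-6)
    assert np.isclose(np.linalg.det(R), 1.0)
    assert affines[0, 0, 3, 3] == 1.0      # homogeneous corner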