Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
33941e46
Commit
33941e46
authored
Sep 23, 2021
by
Gustaf Ahdritz
Browse files
Fix loss bugs
parent
1bc68426
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
56 additions
and
66 deletions
+56
-66
config.py
config.py
+1
-1
openfold/model/heads.py
openfold/model/heads.py
+1
-1
openfold/model/structure_module.py
openfold/model/structure_module.py
+1
-1
openfold/utils/feats.py
openfold/utils/feats.py
+16
-25
openfold/utils/loss.py
openfold/utils/loss.py
+33
-34
openfold/utils/tensor_utils.py
openfold/utils/tensor_utils.py
+4
-4
No files found.
config.py
View file @
33941e46
...
@@ -211,7 +211,7 @@ config = mlc.ConfigDict({
...
@@ -211,7 +211,7 @@ config = mlc.ConfigDict({
"min_resolution"
:
0.1
,
"min_resolution"
:
0.1
,
"max_resolution"
:
3.0
,
"max_resolution"
:
3.0
,
"cutoff"
:
15.
,
"cutoff"
:
15.
,
"n
um
_bins"
:
50
,
"n
o
_bins"
:
50
,
"eps"
:
1e-10
,
"eps"
:
1e-10
,
"weight"
:
0.01
,
"weight"
:
0.01
,
},
},
...
...
openfold/model/heads.py
View file @
33941e46
...
@@ -49,7 +49,7 @@ class AuxiliaryHeads(nn.Module):
...
@@ -49,7 +49,7 @@ class AuxiliaryHeads(nn.Module):
def
forward
(
self
,
outputs
):
def
forward
(
self
,
outputs
):
aux_out
=
{}
aux_out
=
{}
lddt_logits
=
self
.
plddt
(
outputs
[
"single"
])
lddt_logits
=
self
.
plddt
(
outputs
[
"
sm"
][
"
single"
])
aux_out
[
"lddt_logits"
]
=
lddt_logits
aux_out
[
"lddt_logits"
]
=
lddt_logits
# Required for relaxation later on
# Required for relaxation later on
...
...
openfold/model/structure_module.py
View file @
33941e46
...
@@ -751,7 +751,7 @@ class StructureModule(nn.Module):
...
@@ -751,7 +751,7 @@ class StructureModule(nn.Module):
t
=
t
.
stop_rot_gradient
()
t
=
t
.
stop_rot_gradient
()
outputs
=
stack_tensor_dicts
(
outputs
)
outputs
=
stack_tensor_dicts
(
outputs
)
outputs
[
"single
_act
"
]
=
s
outputs
[
"single"
]
=
s
return
outputs
return
outputs
...
...
openfold/utils/feats.py
View file @
33941e46
...
@@ -74,6 +74,7 @@ def get_chi_atom_indices():
...
@@ -74,6 +74,7 @@ def get_chi_atom_indices():
def
compute_residx
(
batch
):
def
compute_residx
(
batch
):
float_type
=
batch
[
"seq_mask"
].
dtype
aatype
=
batch
[
"aatype"
]
aatype
=
batch
[
"aatype"
]
restype_atom14_to_atom37
=
[]
# mapping (restype, atom37) --> atom14
restype_atom14_to_atom37
=
[]
# mapping (restype, atom37) --> atom14
...
@@ -104,34 +105,28 @@ def compute_residx(batch):
...
@@ -104,34 +105,28 @@ def compute_residx(batch):
restype_atom37_to_atom14
.
append
([
0
]
*
37
)
restype_atom37_to_atom14
.
append
([
0
]
*
37
)
restype_atom14_mask
.
append
([
0.
]
*
14
)
restype_atom14_mask
.
append
([
0.
]
*
14
)
restype_atom14_to_atom37
=
np
.
array
(
restype_atom14_to_atom37
,
dtype
=
np
.
int32
)
restype_atom14_to_atom37
=
aatype
.
new_tensor
(
restype_atom37_to_atom14
=
np
.
array
(
restype_atom37_to_atom14
,
dtype
=
np
.
int32
)
restype_atom14_to_atom37
restype_atom14_mask
=
np
.
array
(
restype_atom14_mask
,
dtype
=
np
.
float32
)
)
restype_atom37_to_atom14
=
aatype
.
new_tensor
(
residx_atom14_to_atom37
=
np
.
take_along_axis
(
restype_atom37_to_atom14
restype_atom14_to_atom37
,
aatype
[...,
None
],
axis
=
0
)
)
residx_atom14_mask
=
np
.
take_along_axis
(
restype_atom14_mask
=
aatype
.
new_tensor
(
restype_atom14_mask
,
restype_atom14_mask
,
dtype
=
float_type
aatype
[...,
None
],
axis
=
0
,
)
)
residx_atom14_to_atom37
=
restype_atom14_to_atom37
[
aatype
]
residx_atom14_mask
=
restype_atom14_mask
[
aatype
]
batch
[
'atom14_atom_exists'
]
=
residx_atom14_mask
batch
[
'atom14_atom_exists'
]
=
residx_atom14_mask
batch
[
'residx_atom14_to_atom37'
]
=
residx_atom14_to_atom37
.
long
()
batch
[
'residx_atom14_to_atom37'
]
=
residx_atom14_to_atom37
# create the gather indices for mapping back
# create the gather indices for mapping back
residx_atom37_to_atom14
=
np
.
take_along_axis
(
residx_atom37_to_atom14
=
restype_atom37_to_atom14
[
aatype
]
restype_atom37_to_atom14
,
batch
[
'residx_atom37_to_atom14'
]
=
residx_atom37_to_atom14
aatype
[...,
None
],
axis
=
0
,
)
batch
[
'residx_atom37_to_atom14'
]
=
residx_atom37_to_atom14
.
long
()
# create the corresponding mask
# create the corresponding mask
restype_atom37_mask
=
np
.
zeros
([
21
,
37
],
dtype
=
np
.
float
32
)
restype_atom37_mask
=
torch
.
zeros
([
21
,
37
],
dtype
=
float
_type
)
for
restype
,
restype_letter
in
enumerate
(
residue_constants
.
restypes
):
for
restype
,
restype_letter
in
enumerate
(
residue_constants
.
restypes
):
restype_name
=
residue_constants
.
restype_1to3
[
restype_letter
]
restype_name
=
residue_constants
.
restype_1to3
[
restype_letter
]
atom_names
=
residue_constants
.
residue_atoms
[
restype_name
]
atom_names
=
residue_constants
.
residue_atoms
[
restype_name
]
...
@@ -139,11 +134,7 @@ def compute_residx(batch):
...
@@ -139,11 +134,7 @@ def compute_residx(batch):
atom_type
=
residue_constants
.
atom_order
[
atom_name
]
atom_type
=
residue_constants
.
atom_order
[
atom_name
]
restype_atom37_mask
[
restype
,
atom_type
]
=
1
restype_atom37_mask
[
restype
,
atom_type
]
=
1
residx_atom37_mask
=
np
.
take_along_axis
(
residx_atom37_mask
=
restype_atom37_mask
[
aatype
]
restype_atom37_mask
,
aatype
[...,
None
],
axis
=
0
,
)
batch
[
'atom37_atom_exists'
]
=
residx_atom37_mask
batch
[
'atom37_atom_exists'
]
=
residx_atom37_mask
...
...
openfold/utils/loss.py
View file @
33941e46
...
@@ -27,6 +27,7 @@ from openfold.utils.tensor_utils import (
...
@@ -27,6 +27,7 @@ from openfold.utils.tensor_utils import (
tree_map
,
tree_map
,
tensor_tree_map
,
tensor_tree_map
,
masked_mean
,
masked_mean
,
permute_final_dims
,
)
)
...
@@ -289,28 +290,25 @@ def lddt_loss(
...
@@ -289,28 +290,25 @@ def lddt_loss(
all_atom_mask
:
torch
.
Tensor
,
all_atom_mask
:
torch
.
Tensor
,
resolution
:
torch
.
Tensor
,
resolution
:
torch
.
Tensor
,
cutoff
:
float
=
15.
,
cutoff
:
float
=
15.
,
n
um
_bins
:
int
=
50
,
n
o
_bins
:
int
=
50
,
min_resolution
:
float
=
0.1
,
min_resolution
:
float
=
0.1
,
max_resolution
:
float
=
3.0
,
max_resolution
:
float
=
3.0
,
eps
:
float
=
1e-10
,
eps
:
float
=
1e-10
,
**
kwargs
,
**
kwargs
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
all_atom_positions
=
batch
[
"all_atom_positions"
]
n
=
all_atom_mask
.
shape
[
-
2
]
all_atom_mask
=
batch
[
"all_atom_mask"
]
n
=
all_atom_mask
.
shape
[
-
1
]
ca_pos
=
residue_constants
.
atom_order
[
"CA"
]
ca_pos
=
residue_constants
.
atom_order
[
"CA"
]
all_atom_pred_pos
=
all_atom_pred_pos
[...,
:,
ca_pos
,
:]
all_atom_pred_pos
=
all_atom_pred_pos
[...,
ca_pos
,
:]
all_atom_positions
=
all_atom_positions
[...,
:,
ca_pos
,
:]
all_atom_positions
=
all_atom_positions
[...,
ca_pos
,
:]
all_atom_mask
=
all_atom_mask
[...,
:,
ca_pos
:(
ca_pos
+
1
)]
# keep dim
all_atom_mask
=
all_atom_mask
[...,
ca_pos
:(
ca_pos
+
1
)]
# keep dim
dmat_true
=
torch
.
sqrt
(
dmat_true
=
torch
.
sqrt
(
eps
+
eps
+
torch
.
sum
(
torch
.
sum
(
(
(
all_atom_positions
[...,
None
]
-
all_atom_positions
[...,
None
,
:
]
-
all_atom_positions
[...,
None
,
:]
all_atom_positions
[...,
None
,
:,
:]
)
**
2
,
)
**
2
,
dim
=-
1
,
dim
=-
1
,
)
)
...
@@ -320,14 +318,13 @@ def lddt_loss(
...
@@ -320,14 +318,13 @@ def lddt_loss(
eps
+
eps
+
torch
.
sum
(
torch
.
sum
(
(
(
all_atom_pred_pos
[...,
None
]
-
all_atom_pred_pos
[...,
None
,
:
]
-
all_atom_pred_pos
[...,
None
,
:]
all_atom_pred_pos
[...,
None
,
:,
:]
)
**
2
,
)
**
2
,
dim
=-
1
,
dim
=-
1
,
)
)
)
)
dists_to_score
=
(
dists_to_score
=
(
(
dmat_true
<
cutoff
)
*
all_atom_mask
*
(
dmat_true
<
cutoff
)
*
all_atom_mask
*
permute_final_dims
(
all_atom_mask
,
1
,
0
)
*
permute_final_dims
(
all_atom_mask
,
1
,
0
)
*
...
@@ -337,29 +334,31 @@ def lddt_loss(
...
@@ -337,29 +334,31 @@ def lddt_loss(
dist_l1
=
torch
.
abs
(
dmat_true
-
dmat_pred
)
dist_l1
=
torch
.
abs
(
dmat_true
-
dmat_pred
)
score
=
(
score
=
(
(
dist_l1
<
0.5
)
+
(
dist_l1
<
0.5
)
.
type
(
dist_l1
.
dtype
)
+
(
dist_l1
<
1.0
)
+
(
dist_l1
<
1.0
)
.
type
(
dist_l1
.
dtype
)
+
(
dist_l1
<
2.0
)
+
(
dist_l1
<
2.0
)
.
type
(
dist_l1
.
dtype
)
+
(
dist_l1
<
4.0
)
(
dist_l1
<
4.0
)
.
type
(
dist_l1
.
dtype
)
)
)
score
*=
0.25
score
*=
0.25
norm
=
1.
/
(
eps
+
torch
.
sum
(
dists_to_score
,
dim
=-
1
))
norm
=
1.
/
(
eps
+
torch
.
sum
(
dists_to_score
,
dim
=-
1
))
score
=
norm
*
(
eps
+
torch
.
sum
(
dists_to_score
*
score
,
dim
=-
1
))
score
=
norm
*
(
eps
+
torch
.
sum
(
dists_to_score
*
score
,
dim
=-
1
))
# TODO: this feels a bit weird, but it's in the source
score
=
score
.
detach
()
score
=
score
.
detach
()
bin_index
=
torch
.
floor
(
lddt_ca
*
n
um
_bins
)
bin_index
=
torch
.
floor
(
score
*
n
o
_bins
)
.
long
()
bin_index
=
torch
.
clamp
(
bin_index
,
max
=
(
no_bins
-
1
))
lddt_ca_one_hot
=
torch
.
nn
.
functional
.
one_hot
(
lddt_ca_one_hot
=
torch
.
nn
.
functional
.
one_hot
(
bin_index
,
num_classes
=
n
um
_bins
bin_index
,
num_classes
=
n
o
_bins
)
)
errors
=
softmax_cross_entropy
(
logits
,
lddt_ca_one_hot
)
errors
=
softmax_cross_entropy
(
logits
,
lddt_ca_one_hot
)
loss
=
torch
.
sum
(
errors
*
all_atom_mask
)
/
(
torch
.
sum
(
mask_ca
)
+
eps
)
all_atom_mask
=
all_atom_mask
.
squeeze
(
-
1
)
loss
=
(
torch
.
sum
(
errors
*
all_atom_mask
)
/
(
torch
.
sum
(
all_atom_mask
)
+
1e-8
)
)
loss
*=
(
loss
*=
(
(
resolution
>=
min_resolution
)
&
(
resolution
>=
min_resolution
)
&
(
resolution
<=
max_resolution
)
(
resolution
<=
max_resolution
)
...
@@ -917,17 +916,16 @@ def find_structural_violations(
...
@@ -917,17 +916,16 @@ def find_structural_violations(
overlap_tolerance
=
clash_overlap_tolerance
,
overlap_tolerance
=
clash_overlap_tolerance
,
bond_length_tolerance_factor
=
violation_tolerance_factor
bond_length_tolerance_factor
=
violation_tolerance_factor
)
)
atom14_dists_lower_bound
=
restype_atom14_bounds
[
"lower_bound"
][
atom14_atom_exists
=
batch
[
"atom14_atom_exists"
]
batch
[
"aatype"
]
atom14_dists_lower_bound
=
(
]
atom14_pred_positions
.
new_tensor
(
restype_atom14_bounds
[
"lower_bound"
])[
atom14_dists_upper_bound
=
restype_atom14_bounds
[
"upper_bound"
][
batch
[
"aatype"
]
batch
[
"aatype"
]
]
]
atom14_dists_lower_bound
=
atom14_pred_positions
.
new_tensor
(
atom14_dists_lower_bound
)
)
atom14_dists_upper_bound
=
atom14_pred_positions
.
new_tensor
(
atom14_dists_upper_bound
=
(
atom14_dists_upper_bound
atom14_pred_positions
.
new_tensor
(
restype_atom14_bounds
[
"upper_bound"
])[
batch
[
"aatype"
]
]
)
)
residue_violations
=
within_residue_violations
(
residue_violations
=
within_residue_violations
(
atom14_pred_positions
=
atom14_pred_positions
,
atom14_pred_positions
=
atom14_pred_positions
,
...
@@ -1102,6 +1100,7 @@ def violation_loss(
...
@@ -1102,6 +1100,7 @@ def violation_loss(
violations
:
Dict
[
str
,
torch
.
Tensor
],
violations
:
Dict
[
str
,
torch
.
Tensor
],
atom14_atom_exists
:
torch
.
Tensor
,
atom14_atom_exists
:
torch
.
Tensor
,
eps
=
1e-6
,
eps
=
1e-6
,
**
kwargs
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
num_atoms
=
torch
.
sum
(
atom14_atom_exists
)
num_atoms
=
torch
.
sum
(
atom14_atom_exists
)
l_clash
=
torch
.
sum
(
l_clash
=
torch
.
sum
(
...
...
openfold/utils/tensor_utils.py
View file @
33941e46
...
@@ -49,7 +49,7 @@ def dict_multimap(fn, dicts):
...
@@ -49,7 +49,7 @@ def dict_multimap(fn, dicts):
for
k
,
v
in
first
.
items
():
for
k
,
v
in
first
.
items
():
all_v
=
[
d
[
k
]
for
d
in
dicts
]
all_v
=
[
d
[
k
]
for
d
in
dicts
]
if
(
type
(
v
)
is
dict
):
if
(
type
(
v
)
is
dict
):
new_dict
[
k
]
=
dict_multimap
(
all_v
)
new_dict
[
k
]
=
dict_multimap
(
fn
,
all_v
)
else
:
else
:
new_dict
[
k
]
=
fn
(
all_v
)
new_dict
[
k
]
=
fn
(
all_v
)
...
@@ -122,9 +122,9 @@ def chunk_layer(layer, inputs, chunk_size, no_batch_dims):
...
@@ -122,9 +122,9 @@ def chunk_layer(layer, inputs, chunk_size, no_batch_dims):
"""
"""
Implements the "chunking" procedure described in section 1.11.8.
Implements the "chunking" procedure described in section 1.11.8.
Layer outputs and inputs are
interpreted as
simpl
ified
"pytrees,"
Layer outputs and inputs are
assumed to be
simpl
e
"pytrees,"
consisting only of (nested) lists, tuples, and dicts with
tensor
consisting only of (
arbitrarily
nested) lists, tuples, and dicts with
leaves.
torch.Tensor
leaves.
Args:
Args:
layer:
layer:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment