OpenDAS / OpenFold

Commit 33941e46, authored Sep 23, 2021 by Gustaf Ahdritz
Parent: 1bc68426

    Fix loss bugs
Showing 6 changed files with 56 additions and 66 deletions:
config.py                            +1   -1
openfold/model/heads.py              +1   -1
openfold/model/structure_module.py   +1   -1
openfold/utils/feats.py              +16  -25
openfold/utils/loss.py               +33  -34
openfold/utils/tensor_utils.py       +4   -4
config.py
@@ -211,7 +211,7 @@ config = mlc.ConfigDict({
             "min_resolution": 0.1,
             "max_resolution": 3.0,
             "cutoff": 15.,
-            "num_bins": 50,
+            "no_bins": 50,
             "eps": 1e-10,
             "weight": 0.01,
         },
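The key rename matters because loss hyperparameters are typically forwarded from the config into the loss function as keyword arguments (lddt_loss's **kwargs, shown later in this commit, absorbs any extras), so the key name must match the renamed parameter exactly. A minimal sketch of that splat-style plumbing, with a hypothetical lddt_loss_stub standing in for the real function:

```python
# Minimal sketch (not the actual OpenFold plumbing) of why the rename matters:
# config keys are splatted into the loss as keyword arguments, so the key must
# match the parameter name. lddt_loss_stub is hypothetical.
def lddt_loss_stub(logits, no_bins: int = 50, eps: float = 1e-10, **kwargs):
    return no_bins

config = {"no_bins": 50, "eps": 1e-10, "weight": 0.01}
assert lddt_loss_stub(None, **config) == 50
# With the old key, "num_bins" would be silently swallowed by **kwargs and
# the default value used instead.
```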
openfold/model/heads.py
@@ -49,7 +49,7 @@ class AuxiliaryHeads(nn.Module):
     def forward(self, outputs):
         aux_out = {}
-        lddt_logits = self.plddt(outputs["single"])
+        lddt_logits = self.plddt(outputs["sm"]["single"])
         aux_out["lddt_logits"] = lddt_logits

         # Required for relaxation later on
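The fix points the pLDDT head at the structure module's single representation, which lives under the nested outputs["sm"] dict, rather than the top-level Evoformer output. A toy sketch with hypothetical shapes, using an nn.Linear as a stand-in for the real per-residue lDDT predictor:

```python
import torch
import torch.nn as nn

# Hypothetical sizes, for illustration only.
c_s, no_bins = 8, 50
plddt = nn.Linear(c_s, no_bins)  # stand-in for the real pLDDT predictor

outputs = {
    "single": torch.randn(1, 10, c_s),          # Evoformer single rep
    "sm": {"single": torch.randn(1, 10, c_s)},  # structure-module single rep
}
lddt_logits = plddt(outputs["sm"]["single"])    # the corrected input
assert lddt_logits.shape == (1, 10, no_bins)
```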
openfold/model/structure_module.py
@@ -751,7 +751,7 @@ class StructureModule(nn.Module):
             t = t.stop_rot_gradient()

         outputs = stack_tensor_dicts(outputs)
-        outputs["single_act"] = s
+        outputs["single"] = s

         return outputs
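This rename pairs with the heads.py fix above: AuxiliaryHeads now reads the structure module's single representation as outputs["sm"]["single"], so the module must publish the key as "single" rather than "single_act".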
openfold/utils/feats.py
@@ -74,6 +74,7 @@ def get_chi_atom_indices():
 def compute_residx(batch):
+    float_type = batch["seq_mask"].dtype
     aatype = batch["aatype"]

     restype_atom14_to_atom37 = []  # mapping (restype, atom37) --> atom14
@@ -104,34 +105,28 @@ def compute_residx(batch):
         restype_atom37_to_atom14.append([0] * 37)
         restype_atom14_mask.append([0.] * 14)

     restype_atom14_to_atom37 = np.array(restype_atom14_to_atom37, dtype=np.int32)
     restype_atom37_to_atom14 = np.array(restype_atom37_to_atom14, dtype=np.int32)
     restype_atom14_mask = np.array(restype_atom14_mask, dtype=np.float32)

-    residx_atom14_to_atom37 = np.take_along_axis(
-        restype_atom14_to_atom37, aatype[..., None], axis=0
-    )
-    residx_atom14_mask = np.take_along_axis(
-        restype_atom14_mask, aatype[..., None], axis=0,
-    )
+    restype_atom14_to_atom37 = aatype.new_tensor(restype_atom14_to_atom37)
+    restype_atom37_to_atom14 = aatype.new_tensor(restype_atom37_to_atom14)
+    restype_atom14_mask = aatype.new_tensor(restype_atom14_mask, dtype=float_type)
+
+    residx_atom14_to_atom37 = restype_atom14_to_atom37[aatype]
+    residx_atom14_mask = restype_atom14_mask[aatype]

     batch['atom14_atom_exists'] = residx_atom14_mask
-    batch['residx_atom14_to_atom37'] = residx_atom14_to_atom37.long()
+    batch['residx_atom14_to_atom37'] = residx_atom14_to_atom37

     # create the gather indices for mapping back
-    residx_atom37_to_atom14 = np.take_along_axis(
-        restype_atom37_to_atom14, aatype[..., None], axis=0,
-    )
-    batch['residx_atom37_to_atom14'] = residx_atom37_to_atom14.long()
+    residx_atom37_to_atom14 = restype_atom37_to_atom14[aatype]
+    batch['residx_atom37_to_atom14'] = residx_atom37_to_atom14

     # create the corresponding mask
-    restype_atom37_mask = np.zeros([21, 37], dtype=np.float32)
+    restype_atom37_mask = torch.zeros([21, 37], dtype=float_type)

     for restype, restype_letter in enumerate(residue_constants.restypes):
         restype_name = residue_constants.restype_1to3[restype_letter]
         atom_names = residue_constants.residue_atoms[restype_name]
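The rewrite swaps NumPy gather operations for plain torch advanced indexing on tensors created with new_tensor, which also lands the lookup tables on the same device as the batch. A toy-sized demonstration of the equivalence, using 4 restypes and 3 atoms instead of OpenFold's real 21 x 14 tables:

```python
import numpy as np
import torch

# Indexing a per-restype lookup table with the aatype tensor does the same
# job as np.take_along_axis, and new_tensor places the table on aatype's
# device with aatype's dtype.
table_np = np.arange(12, dtype=np.int32).reshape(4, 3)  # [n_restypes, n_atoms]
aatype_np = np.array([2, 0, 1])

old = np.take_along_axis(table_np, aatype_np[..., None], axis=0)

aatype = torch.tensor(aatype_np)
table = aatype.new_tensor(table_np)   # same device/dtype as aatype
new = table[aatype]                   # [n_res, n_atoms]

assert np.array_equal(old, new.numpy())
```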
@@ -139,11 +134,7 @@ def compute_residx(batch):
             atom_type = residue_constants.atom_order[atom_name]
             restype_atom37_mask[restype, atom_type] = 1

-    residx_atom37_mask = np.take_along_axis(
-        restype_atom37_mask, aatype[..., None], axis=0,
-    )
+    residx_atom37_mask = restype_atom37_mask[aatype]
     batch['atom37_atom_exists'] = residx_atom37_mask
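Threading float_type = batch["seq_mask"].dtype through the mask construction plausibly keeps the masks consistent with the batch's floating-point precision (for example under fp16/bf16 training) instead of hard-coding float32. A two-line illustration under that assumption:

```python
import torch

# Assumed motivation: masks inherit whatever floating dtype the batch uses.
seq_mask = torch.ones(8, dtype=torch.bfloat16)  # stand-in batch entry
float_type = seq_mask.dtype
restype_atom37_mask = torch.zeros([21, 37], dtype=float_type)
assert restype_atom37_mask.dtype == torch.bfloat16
```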
openfold/utils/loss.py
@@ -27,6 +27,7 @@ from openfold.utils.tensor_utils import (
     tree_map,
     tensor_tree_map,
     masked_mean,
+    permute_final_dims,
 )
@@ -289,28 +290,25 @@ def lddt_loss(
all_atom_mask
:
torch
.
Tensor
,
resolution
:
torch
.
Tensor
,
cutoff
:
float
=
15.
,
n
um
_bins
:
int
=
50
,
n
o
_bins
:
int
=
50
,
min_resolution
:
float
=
0.1
,
max_resolution
:
float
=
3.0
,
eps
:
float
=
1e-10
,
**
kwargs
,
)
->
torch
.
Tensor
:
all_atom_positions
=
batch
[
"all_atom_positions"
]
all_atom_mask
=
batch
[
"all_atom_mask"
]
n
=
all_atom_mask
.
shape
[
-
1
]
n
=
all_atom_mask
.
shape
[
-
2
]
ca_pos
=
residue_constants
.
atom_order
[
"CA"
]
all_atom_pred_pos
=
all_atom_pred_pos
[...,
:,
ca_pos
,
:]
all_atom_positions
=
all_atom_positions
[...,
:,
ca_pos
,
:]
all_atom_mask
=
all_atom_mask
[...,
:,
ca_pos
:(
ca_pos
+
1
)]
# keep dim
all_atom_pred_pos
=
all_atom_pred_pos
[...,
ca_pos
,
:]
all_atom_positions
=
all_atom_positions
[...,
ca_pos
,
:]
all_atom_mask
=
all_atom_mask
[...,
ca_pos
:(
ca_pos
+
1
)]
# keep dim
dmat_true
=
torch
.
sqrt
(
eps
+
torch
.
sum
(
(
all_atom_positions
[...,
None
]
-
all_atom_positions
[...,
None
,
:]
all_atom_positions
[...,
None
,
:
]
-
all_atom_positions
[...,
None
,
:,
:]
)
**
2
,
dim
=-
1
,
)
...
...
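Two fixes here: n now counts residues (dim -2 of the [*, N_res, 37] atom mask, not dim -1), and the pairwise distance broadcasting is corrected. After slicing out the CA atom the position tensors are [*, N_res, 3], and the new indexing inserts both residue axes before the coordinate axis; the old pattern broadcast against the coordinate dimension and produced a [*, N_res, 3, 3] difference tensor instead of pairwise residue differences. A small shape check:

```python
import torch

# Corrected pairwise CA distance matrix on a toy batch.
eps = 1e-10
pos = torch.randn(2, 5, 3)  # [*, N_res, 3] after CA slicing
dmat = torch.sqrt(
    eps + torch.sum(
        (pos[..., None, :] - pos[..., None, :, :]) ** 2,
        dim=-1,
    )
)
assert dmat.shape == (2, 5, 5)  # [*, N_res, N_res]

# The old broadcasting gave the wrong shape entirely:
wrong = (pos[..., None] - pos[..., None, :]) ** 2
assert wrong.shape == (2, 5, 3, 3)
```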
@@ -320,14 +318,13 @@ def lddt_loss(
         eps + torch.sum(
             (
-                all_atom_pred_pos[..., None] -
-                all_atom_pred_pos[..., None, :]
+                all_atom_pred_pos[..., None, :] -
+                all_atom_pred_pos[..., None, :, :]
             ) ** 2,
             dim=-1,
         )
     )

     dists_to_score = (
         (dmat_true < cutoff) * all_atom_mask
         * permute_final_dims(all_atom_mask, 1, 0) *
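The import hunk at the top of this file adds permute_final_dims, which this otherwise unchanged dists_to_score expression already uses, so a plausible reading is that the old code raised a NameError the first time lddt_loss ran. A minimal varargs implementation consistent with the call site permute_final_dims(all_atom_mask, 1, 0), sketched here rather than copied from OpenFold:

```python
import torch

# Sketch of a varargs permute_final_dims matching the call site above;
# the real OpenFold helper may differ in signature.
def permute_final_dims(tensor: torch.Tensor, *inds: int) -> torch.Tensor:
    zero_index = -1 * len(inds)
    first_inds = list(range(len(tensor.shape[:zero_index])))
    return tensor.permute(first_inds + [zero_index + i for i in inds])

mask = torch.ones(2, 5, 1)               # [*, N_res, 1] CA mask
mask_t = permute_final_dims(mask, 1, 0)  # [*, 1, N_res]
pair_mask = mask * mask_t                # [*, N_res, N_res] via broadcasting
assert pair_mask.shape == (2, 5, 5)
```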
@@ -337,28 +334,30 @@ def lddt_loss(
     dist_l1 = torch.abs(dmat_true - dmat_pred)

     score = (
-        (dist_l1 < 0.5) +
-        (dist_l1 < 1.0) +
-        (dist_l1 < 2.0) +
-        (dist_l1 < 4.0)
+        (dist_l1 < 0.5).type(dist_l1.dtype) +
+        (dist_l1 < 1.0).type(dist_l1.dtype) +
+        (dist_l1 < 2.0).type(dist_l1.dtype) +
+        (dist_l1 < 4.0).type(dist_l1.dtype)
     )
     score *= 0.25

     norm = 1. / (eps + torch.sum(dists_to_score, dim=-1))
     score = norm * (eps + torch.sum(dists_to_score * score, dim=-1))

     # TODO: this feels a bit weird, but it's in the source
     score = score.detach()

-    bin_index = torch.floor(lddt_ca * num_bins)
+    bin_index = torch.floor(score * no_bins).long()
+    bin_index = torch.clamp(bin_index, max=(no_bins - 1))

     lddt_ca_one_hot = torch.nn.functional.one_hot(
-        bin_index, num_classes=num_bins
+        bin_index, num_classes=no_bins
     )

     errors = softmax_cross_entropy(logits, lddt_ca_one_hot)
-    loss = torch.sum(errors * all_atom_mask) / (torch.sum(mask_ca) + eps)
+    all_atom_mask = all_atom_mask.squeeze(-1)
+    loss = (
+        torch.sum(errors * all_atom_mask) /
+        (torch.sum(all_atom_mask) + 1e-8)
+    )

     loss *= ((resolution >= min_resolution) &
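Several bugs are fixed in this hunk: the threshold comparisons are cast before summing, the stale names lddt_ca, num_bins, and mask_ca are replaced with score, no_bins, and the squeezed mask, and a clamp keeps a perfect score from producing an out-of-range bin. The casts are load-bearing because in PyTorch "+" between two bool tensors behaves as a logical OR, collapsing the four threshold tests into a 0/1 indicator instead of a 0-4 count:

```python
import torch

dist_l1 = torch.tensor([0.3, 1.5, 3.0, 5.0])

# With the casts, each residue pair gets credit per threshold it passes.
score = (
    (dist_l1 < 0.5).type(dist_l1.dtype)
    + (dist_l1 < 1.0).type(dist_l1.dtype)
    + (dist_l1 < 2.0).type(dist_l1.dtype)
    + (dist_l1 < 4.0).type(dist_l1.dtype)
) * 0.25
assert score.tolist() == [1.0, 0.5, 0.25, 0.0]

# Why the new clamp matters: a perfect score of 1.0 floors to bin no_bins,
# one past the last valid class index for one_hot.
no_bins = 50
bins = torch.clamp(
    torch.floor(torch.tensor([0.0, 0.5, 1.0]) * no_bins).long(),
    max=no_bins - 1,
)
assert bins.tolist() == [0, 25, 49]
```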
@@ -917,17 +916,16 @@ def find_structural_violations(
         overlap_tolerance=clash_overlap_tolerance,
         bond_length_tolerance_factor=violation_tolerance_factor
     )

-    atom14_dists_lower_bound = restype_atom14_bounds["lower_bound"][
+    atom14_atom_exists = batch["atom14_atom_exists"]
+    atom14_dists_lower_bound = (
+        atom14_pred_positions.new_tensor(restype_atom14_bounds["lower_bound"])[
             batch["aatype"]
         ]
-    atom14_dists_upper_bound = restype_atom14_bounds["upper_bound"][
+    )
+    atom14_dists_upper_bound = (
+        atom14_pred_positions.new_tensor(restype_atom14_bounds["upper_bound"])[
             batch["aatype"]
         ]
-    atom14_dists_lower_bound = atom14_pred_positions.new_tensor(atom14_dists_lower_bound)
-    atom14_dists_upper_bound = atom14_pred_positions.new_tensor(atom14_dists_upper_bound)
+    )

     residue_violations = within_residue_violations(
         atom14_pred_positions=atom14_pred_positions,
@@ -1102,6 +1100,7 @@ def violation_loss(
     violations: Dict[str, torch.Tensor],
+    atom14_atom_exists: torch.Tensor,
     eps=1e-6,
     **kwargs,
 ) -> torch.Tensor:
     num_atoms = torch.sum(atom14_atom_exists)

     l_clash = torch.sum(
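Converting the NumPy bounds tables with new_tensor before indexing keeps the whole lookup in torch on the predictions' device; the old order indexed the raw NumPy table with batch["aatype"], which fails outright when aatype lives on a GPU. The first hunk also reads atom14_atom_exists out of the batch, matching violation_loss's new explicit parameter. A toy-sized sketch of the fixed lookup, not the real 21 x 14 x 14 tables:

```python
import numpy as np
import torch

# Hypothetical sizes: 4 restypes, 3 atoms, 3 residues.
restype_bounds = np.random.rand(4, 3, 3).astype(np.float32)
aatype = torch.tensor([1, 3, 0])
atom14_pred_positions = torch.zeros(3, 3, 3)  # stand-in predictions

# Convert first, then index with the torch tensor.
lower = atom14_pred_positions.new_tensor(restype_bounds)[aatype]
assert lower.shape == (3, 3, 3)
```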
openfold/utils/tensor_utils.py
@@ -49,7 +49,7 @@ def dict_multimap(fn, dicts):
     for k, v in first.items():
         all_v = [d[k] for d in dicts]
         if(type(v) is dict):
-            new_dict[k] = dict_multimap(all_v)
+            new_dict[k] = dict_multimap(fn, all_v)
         else:
             new_dict[k] = fn(all_v)
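The recursive call was dropping its first argument, so any nested dict raised a TypeError (dict_multimap takes two arguments). A self-contained sketch of the fixed helper, reconstructed from the context lines shown here and assumed to mirror the real one:

```python
import torch

def dict_multimap(fn, dicts):
    first = dicts[0]
    new_dict = {}
    for k, v in first.items():
        all_v = [d[k] for d in dicts]
        if type(v) is dict:
            new_dict[k] = dict_multimap(fn, all_v)  # fn was missing here
        else:
            new_dict[k] = fn(all_v)
    return new_dict

# Stacking a nested dict of tensors across three replicas.
stacked = dict_multimap(
    lambda ts: torch.stack(ts),
    [{"a": torch.zeros(2), "b": {"c": torch.ones(2)}} for _ in range(3)],
)
assert stacked["b"]["c"].shape == (3, 2)
```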
@@ -122,9 +122,9 @@ def chunk_layer(layer, inputs, chunk_size, no_batch_dims):
     """
     Implements the "chunking" procedure described in section 1.11.8.

-    Layer outputs and inputs are interpreted as simplified "pytrees,"
-    consisting only of (nested) lists, tuples, and dicts with tensor leaves.
+    Layer outputs and inputs are assumed to be simple "pytrees,"
+    consisting only of (arbitrarily nested) lists, tuples, and dicts with
+    torch.Tensor leaves.

     Args:
         layer: