Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
df6b97f2
Commit
df6b97f2
authored
Sep 20, 2021
by
Gustaf Ahdritz
Browse files
First draft of loss class
parent
15895ea9
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
360 additions
and
86 deletions
+360
-86
alphafold/model/structure_module.py
alphafold/model/structure_module.py
+7
-3
alphafold/utils/loss.py
alphafold/utils/loss.py
+302
-82
config.py
config.py
+50
-0
tests/test_structure_module.py
tests/test_structure_module.py
+1
-1
No files found.
alphafold/model/structure_module.py
View file @
df6b97f2
...
@@ -129,6 +129,8 @@ class AngleResnet(nn.Module):
...
@@ -129,6 +129,8 @@ class AngleResnet(nn.Module):
# [*, no_angles * 2]
# [*, no_angles * 2]
s
=
self
.
linear_out
(
s
)
s
=
self
.
linear_out
(
s
)
unnormalized_s
=
s
# [*, no_angles, 2]
# [*, no_angles, 2]
s
=
s
.
view
(
*
s
.
shape
[:
-
1
],
-
1
,
2
)
s
=
s
.
view
(
*
s
.
shape
[:
-
1
],
-
1
,
2
)
norm_denom
=
torch
.
sqrt
(
norm_denom
=
torch
.
sqrt
(
...
@@ -139,7 +141,7 @@ class AngleResnet(nn.Module):
...
@@ -139,7 +141,7 @@ class AngleResnet(nn.Module):
)
)
s
=
s
/
norm_denom
s
=
s
/
norm_denom
return
s
return
unnormalized_s
,
s
class
InvariantPointAttention
(
nn
.
Module
):
class
InvariantPointAttention
(
nn
.
Module
):
...
@@ -723,7 +725,7 @@ class StructureModule(nn.Module):
...
@@ -723,7 +725,7 @@ class StructureModule(nn.Module):
t
=
t
.
compose
(
self
.
bb_update
(
s
))
t
=
t
.
compose
(
self
.
bb_update
(
s
))
# [*, N, 7, 2]
# [*, N, 7, 2]
a
=
self
.
angle_resnet
(
s
,
s_initial
)
unnormalized_a
,
a
=
self
.
angle_resnet
(
s
,
s_initial
)
all_frames_to_global
=
self
.
torsion_angles_to_frames
(
all_frames_to_global
=
self
.
torsion_angles_to_frames
(
t
.
scale_translation
(
self
.
trans_scale_factor
),
a
,
f
,
t
.
scale_translation
(
self
.
trans_scale_factor
),
a
,
f
,
...
@@ -735,8 +737,10 @@ class StructureModule(nn.Module):
...
@@ -735,8 +737,10 @@ class StructureModule(nn.Module):
)
)
preds
=
{
preds
=
{
"
t
ra
nsformation
s"
:
"
f
ra
me
s"
:
t
.
scale_translation
(
self
.
trans_scale_factor
).
to_4x4
(),
t
.
scale_translation
(
self
.
trans_scale_factor
).
to_4x4
(),
"sidechain_frames"
:
all_frames_to_global
,
"unnormalized_angles"
:
unnormalized_a
,
"angles"
:
a
,
"angles"
:
a
,
"positions"
:
pred_xyz
,
"positions"
:
pred_xyz
,
}
}
...
...
alphafold/utils/loss.py
View file @
df6b97f2
...
@@ -13,6 +13,7 @@
...
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
from
functools
import
partial
import
ml_collections
import
ml_collections
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
...
@@ -37,6 +38,13 @@ def softmax_cross_entropy(logits, labels):
...
@@ -37,6 +38,13 @@ def softmax_cross_entropy(logits, labels):
return
loss
return
loss
def sigmoid_cross_entropy(logits, labels):
    """Element-wise binary cross-entropy computed directly from logits.

    Args:
        logits: Tensor of unnormalized scores.
        labels: Tensor of targets, broadcastable against ``logits``.

    Returns:
        Tensor of per-element losses (no reduction is applied).
    """
    log_sigmoid = torch.nn.functional.logsigmoid
    # log(sigmoid(x)) weighted by the target, and log(1 - sigmoid(x)) =
    # log(sigmoid(-x)) weighted by its complement.
    positive_term = labels * log_sigmoid(logits)
    negative_term = (1 - labels) * log_sigmoid(-logits)
    return -positive_term - negative_term
def
torsion_angle_loss
(
def
torsion_angle_loss
(
a
,
# [*, N, 7, 2]
a
,
# [*, N, 7, 2]
a_gt
,
# [*, N, 7, 2]
a_gt
,
# [*, N, 7, 2]
...
@@ -102,12 +110,13 @@ def compute_fape(
...
@@ -102,12 +110,13 @@ def compute_fape(
def
backbone_loss
(
def
backbone_loss
(
batch
:
Dict
[
str
,
torch
.
Tensor
],
batch
:
Dict
[
str
,
torch
.
Tensor
],
pred_aff
:
T
,
pred_aff
_tensor
:
torch
.
Tensor
,
clamp_distance
:
float
=
10.
,
clamp_distance
:
float
=
10.
,
loss_unit_distance
:
float
=
10.
,
loss_unit_distance
:
float
=
10.
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
gt_aff
=
T
.
from_tensor
(
batch
[
'backbone_affine_tensor'
])
pred_aff
=
T
.
from_tensor
(
pred_aff_tensor
)
backbone_mask
=
batch
[
'backbone_affine_mask'
]
gt_aff
=
T
.
from_tensor
(
batch
[
"backbone_affine_tensor"
])
backbone_mask
=
batch
[
"backbone_affine_mask"
]
fape_loss
=
compute_fape
(
fape_loss
=
compute_fape
(
pred_aff
,
pred_aff
,
...
@@ -138,15 +147,15 @@ def backbone_loss(
...
@@ -138,15 +147,15 @@ def backbone_loss(
fape_loss_unclamped
*
(
1
-
use_clamped_fape
)
fape_loss_unclamped
*
(
1
-
use_clamped_fape
)
)
)
return
torch
.
mean
(
fape_loss
,
dim
=
backbone_mask
.
shape
[:
-
1
]
)
return
torch
.
mean
(
fape_loss
,
dim
=-
1
)
def
sidechain_loss
(
def
sidechain_loss
(
sidechain_frames
,
sidechain_frames
,
sidechain_atom_pos
,
sidechain_atom_pos
,
gt_frames
,
rigidgroups_
gt_frames
,
alt_gt_frames
,
rigidgroups_
alt_gt_frames
,
gt_exists
,
rigidgroups_
gt_exists
,
renamed_atom14_gt_positions
,
renamed_atom14_gt_positions
,
renamed_atom14_gt_exists
,
renamed_atom14_gt_exists
,
alt_naming_is_better
,
alt_naming_is_better
,
...
@@ -176,6 +185,87 @@ def sidechain_loss(
...
@@ -176,6 +185,87 @@ def sidechain_loss(
return
fape
return
fape
def fape_loss(
    out: Dict[str, torch.Tensor],
    batch: Dict[str, torch.Tensor],
    config: ml_collections.ConfigDict,
) -> torch.Tensor:
    """Weighted sum of the backbone and sidechain FAPE losses.

    Args:
        out: Model outputs. Reads ``out["sm"]["frames"]`` (the last entry
            along the leading axis is used for the backbone term),
            ``out["sm"]["sidechain_frames"]`` and ``out["sm"]["positions"]``.
        batch: Ground-truth features forwarded to the individual losses.
        config: Config with ``backbone`` and ``sidechain`` sub-configs. Each
            holds a ``weight`` plus the keyword options of the respective
            loss function.

    Returns:
        ``backbone.weight * backbone_loss + sidechain.weight * sidechain_loss``.
    """
    # "weight" only scales the terms below; it is not an option of the loss
    # functions themselves, so it must not be forwarded as a keyword.
    backbone_opts = {
        k: v for k, v in config.backbone.items() if k != "weight"
    }
    sidechain_opts = {
        k: v for k, v in config.sidechain.items() if k != "weight"
    }

    bb_loss = backbone_loss(
        batch,
        out["sm"]["frames"][-1],
        **backbone_opts,
    )

    # The ground-truth features and options are passed by keyword.
    # NOTE(review): this assumes sidechain_loss tolerates batch keys it does
    # not use (e.g. via **kwargs) -- confirm against its full signature.
    sc_loss = sidechain_loss(
        out["sm"]["sidechain_frames"],
        out["sm"]["positions"],
        **{**batch, **sidechain_opts},
    )

    return (
        config.backbone.weight * bb_loss
        + config.sidechain.weight * sc_loss
    )
def supervised_chi_loss(
    angles_sin_cos: torch.Tensor,
    unnormalized_angles_sin_cos: torch.Tensor,
    aatype: torch.Tensor,
    seq_mask: torch.Tensor,
    chi_mask: torch.Tensor,
    chi_angles: torch.Tensor,
    chi_weight: float,
    angle_norm_weight: float,
    eps=1e-6,
    **kwargs,
) -> torch.Tensor:
    """Supervised chi-angle loss plus an angle-norm regularizer.

    Args:
        angles_sin_cos: Predicted, normalized (sin, cos) pairs; entries from
            index 3 onwards on the second-to-last axis are compared against
            the chi angles.
            NOTE(review): presumably [*, no_blocks, N, 7, 2] -- confirm.
        unnormalized_angles_sin_cos: The same predictions before
            normalization; their norm is pushed towards 1.
        aatype: Integer residue-type indices.
        seq_mask: Per-residue mask used for the norm term.
        chi_mask: Mask of chi angles that exist / are resolved.
        chi_angles: Ground-truth chi angles (their sin/cos is taken here, so
            presumably in radians -- confirm).
        chi_weight: Weight of the squared chi error term.
        angle_norm_weight: Weight of the norm regularizer.
        eps: Added inside the square root for numerical stability.
        **kwargs: Ignored; lets callers pass a merged batch/config dict.

    Returns:
        ``chi_weight * chi_loss + angle_norm_weight * angle_norm_loss``.
    """
    pred_angles = angles_sin_cos[..., 3:, :]

    # Look up, per residue, which chi angles are pi-periodic. The lookup is
    # done in the prediction dtype (integer matmul is not supported on every
    # backend) and keeps the leading batch dimensions: "->...ik" rather than
    # "->ik", which would sum the table entries over the batch.
    residue_type_one_hot = torch.nn.functional.one_hot(
        aatype,
        residue_constants.restype_num + 1,
    ).unsqueeze(-3)
    chi_pi_periodic = torch.einsum(
        "...ij,jk->...ik",
        residue_type_one_hot.to(pred_angles.dtype),
        pred_angles.new_tensor(residue_constants.chi_pi_periodic),
    )

    true_chi = chi_angles.unsqueeze(-3)
    sin_cos_true_chi = torch.stack(
        [torch.sin(true_chi), torch.cos(true_chi)], dim=-1
    )

    # For pi-periodic angles chi and chi + pi are equivalent, which flips the
    # sign of both sin and cos; score against whichever target is closer.
    shifted_mask = (1 - 2 * chi_pi_periodic).unsqueeze(-1)
    sin_cos_true_chi_shifted = shifted_mask * sin_cos_true_chi

    sq_chi_error = torch.sum(
        (sin_cos_true_chi - pred_angles) ** 2, dim=-1
    )
    sq_chi_error_shifted = torch.sum(
        (sin_cos_true_chi_shifted - pred_angles) ** 2, dim=-1
    )
    sq_chi_error = torch.minimum(sq_chi_error, sq_chi_error_shifted)

    # masked_mean takes (mask, value, dim); keywords avoid swapping them.
    sq_chi_loss = masked_mean(
        mask=chi_mask.unsqueeze(-3),
        value=sq_chi_error,
        dim=(-1, -2, -3),
    )

    loss = chi_weight * sq_chi_loss

    # Regularizer keeping the raw (sin, cos) predictions near unit norm.
    angle_norm = torch.sqrt(
        torch.sum(unnormalized_angles_sin_cos ** 2, dim=-1) + eps
    )
    norm_error = torch.abs(angle_norm - 1.)
    angle_norm_loss = masked_mean(
        mask=seq_mask[..., None, :, None],
        value=norm_error,
        dim=(-1, -2, -3),
    )

    loss = loss + angle_norm_weight * angle_norm_loss

    return loss
def
compute_plddt
(
logits
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
compute_plddt
(
logits
:
torch
.
Tensor
)
->
torch
.
Tensor
:
num_bins
=
logits
.
shape
[
-
1
]
num_bins
=
logits
.
shape
[
-
1
]
bin_width
=
1.
/
num_bins
bin_width
=
1.
/
num_bins
...
@@ -192,31 +282,34 @@ def compute_plddt(logits: torch.Tensor) -> torch.Tensor:
...
@@ -192,31 +282,34 @@ def compute_plddt(logits: torch.Tensor) -> torch.Tensor:
def
lddt_loss
(
def
lddt_loss
(
batch
:
Dict
[
str
,
torch
.
Tensor
],
logits
:
torch
.
Tensor
,
all_atom_pred_pos
:
torch
.
Tensor
,
all_atom_positions
:
torch
.
Tensor
,
all_atom_mask
:
torch
.
Tensor
,
resolution
:
torch
.
Tensor
,
cutoff
:
float
=
15.
,
cutoff
:
float
=
15.
,
num_bins
:
int
=
50
,
num_bins
:
int
=
50
,
min_resolution
:
float
=
0.1
,
min_resolution
:
float
=
0.1
,
max_resolution
:
float
=
3.0
,
max_resolution
:
float
=
3.0
,
eps
:
float
=
1e-10
,
eps
:
float
=
1e-10
,
**
kwargs
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
all_atom_pred_pos
=
batch
[
"sm"
][
"pred_pos"
][
-
1
]
all_atom_positions
=
batch
[
"all_atom_positions"
]
all_atom_true_pos
=
batch
[
"all_atom_positions"
]
all_atom_mask
=
batch
[
"all_atom_mask"
]
all_atom_mask
=
batch
[
"all_atom_mask"
]
logits
=
batch
[
"predicted_lddt_logits"
]
n
=
all_atom_mask
.
shape
[
-
1
]
n
=
all_atom_mask
.
shape
[
-
1
]
ca_pos
=
residue_constants
.
atom_order
[
'
CA
'
]
ca_pos
=
residue_constants
.
atom_order
[
"
CA
"
]
all_atom_pred_pos
=
all_atom_pred_pos
[...,
:,
ca_pos
,
:]
all_atom_pred_pos
=
all_atom_pred_pos
[...,
:,
ca_pos
,
:]
all_atom_
true_po
s
=
all_atom_
true_po
s
[...,
:,
ca_pos
,
:]
all_atom_
position
s
=
all_atom_
position
s
[...,
:,
ca_pos
,
:]
all_atom_mask
=
all_atom_mask
[...,
:,
ca_pos
:(
ca_pos
+
1
)]
# keep dim
all_atom_mask
=
all_atom_mask
[...,
:,
ca_pos
:(
ca_pos
+
1
)]
# keep dim
dmat_true
=
torch
.
sqrt
(
dmat_true
=
torch
.
sqrt
(
eps
+
eps
+
torch
.
sum
(
torch
.
sum
(
(
(
all_atom_
true_po
s
[...,
None
]
-
all_atom_
position
s
[...,
None
]
-
all_atom_
true_po
s
[...,
None
,
:]
all_atom_
position
s
[...,
None
,
:]
)
**
2
,
)
**
2
,
dim
=-
1
,
dim
=-
1
,
)
)
...
@@ -267,36 +360,44 @@ def lddt_loss(
...
@@ -267,36 +360,44 @@ def lddt_loss(
loss
=
torch
.
sum
(
errors
*
all_atom_mask
)
/
(
torch
.
sum
(
mask_ca
)
+
eps
)
loss
=
torch
.
sum
(
errors
*
all_atom_mask
)
/
(
torch
.
sum
(
mask_ca
)
+
eps
)
loss
*=
(
loss
*=
(
(
batch
[
"
resolution
"
]
>=
min_resolution
)
&
(
resolution
>=
min_resolution
)
&
(
batch
[
"
resolution
"
]
<=
max_resolution
)
(
resolution
<=
max_resolution
)
)
)
return
loss
return
loss
def
distogram_loss
(
def
distogram_loss
(
pred_distr
,
logits
,
gt
,
pseudo_beta
,
mask
,
pseudo_beta_mask
,
min_bin
=
2.3125
,
max_bin
=
21.6875
,
no_bins
=
64
,
eps
=
1e-6
min_bin
=
2.3125
,
max_bin
=
21.6875
,
no_bins
=
64
,
eps
=
1e-6
,
**
kwargs
,
):
):
boundaries
=
torch
.
linspace
(
boundaries
=
torch
.
linspace
(
min_bin
,
max_bin
,
no_bins
-
1
,
device
=
pred_distr
.
device
,
min_bin
,
max_bin
,
no_bins
-
1
,
device
=
logits
.
device
,
)
)
boundaries
=
boundaries
**
2
boundaries
=
boundaries
**
2
dists
=
torch
.
sum
(
dists
=
torch
.
sum
(
(
gt
[...,
None
,
:]
-
gt
[...,
None
,
:,
:])
**
2
,
dim
=-
1
,
keepdims
=
True
(
pseudo_beta
[...,
None
,
:]
-
pseudo_beta
[...,
None
,
:,
:]
)
**
2
,
dim
=-
1
,
keepdims
=
True
)
)
true_bins
=
torch
.
sum
(
dists
>
sq_breaks
,
dim
=-
1
)
true_bins
=
torch
.
sum
(
dists
>
sq_breaks
,
dim
=-
1
)
errors
=
softmax_cross_entropy
(
errors
=
softmax_cross_entropy
(
pred_distr
,
logits
,
torch
.
nn
.
functional
.
one_hot
(
true_bins
,
num_bins
),
torch
.
nn
.
functional
.
one_hot
(
true_bins
,
num_bins
),
)
)
square_mask
=
mask
[...,
None
]
*
mask
[...,
None
,
:]
square_mask
=
pseudo_beta_
mask
[...,
None
]
*
pseudo_beta_
mask
[...,
None
,
:]
mean
=
(
mean
=
(
torch
.
sum
(
errors
*
square_mask
,
dim
=
(
-
1
,
-
2
))
/
torch
.
sum
(
errors
*
square_mask
,
dim
=
(
-
1
,
-
2
))
/
...
@@ -417,7 +518,7 @@ def between_residue_bond_loss(
...
@@ -417,7 +518,7 @@ def between_residue_bond_loss(
# The C-N bond to proline has slightly different length because of the ring.
# The C-N bond to proline has slightly different length because of the ring.
next_is_proline
=
(
next_is_proline
=
(
aatype
[...,
1
:]
==
residue_constants
.
resname_to_idx
[
'
PRO
'
]
aatype
[...,
1
:]
==
residue_constants
.
resname_to_idx
[
"
PRO
"
]
)
)
gt_length
=
(
gt_length
=
(
(
~
next_is_proline
)
*
residue_constants
.
between_res_bond_length_c_n
[
0
]
(
~
next_is_proline
)
*
residue_constants
.
between_res_bond_length_c_n
[
0
]
...
@@ -609,7 +710,7 @@ def between_residue_clash_loss(
...
@@ -609,7 +710,7 @@ def between_residue_clash_loss(
dists_mask
*=
(
1.
-
c_n_bonds
)
dists_mask
*=
(
1.
-
c_n_bonds
)
# Disulfide bridge between two cysteines is no clash.
# Disulfide bridge between two cysteines is no clash.
cys
=
residue_constants
.
restype_name_to_atom14_names
[
'
CYS
'
]
cys
=
residue_constants
.
restype_name_to_atom14_names
[
"
CYS
"
]
cys_sg_idx
=
cys
.
index
(
'SG'
)
cys_sg_idx
=
cys
.
index
(
'SG'
)
cys_sg_idx
=
residue_index
.
new_tensor
(
cys_sg_idx
)
cys_sg_idx
=
residue_index
.
new_tensor
(
cys_sg_idx
)
cys_sg_idx
=
cys_sg_idx
.
reshape
(
cys_sg_idx
=
cys_sg_idx
.
reshape
(
...
@@ -768,18 +869,20 @@ def within_residue_violations(
...
@@ -768,18 +869,20 @@ def within_residue_violations(
def
find_structural_violations
(
def
find_structural_violations
(
batch
:
Dict
[
str
,
torch
.
Tensor
],
batch
:
Dict
[
str
,
torch
.
Tensor
],
atom14_pred_positions
:
torch
.
Tensor
,
atom14_pred_positions
:
torch
.
Tensor
,
config
:
ml_collections
.
ConfigDict
violation_tolerance_factor
:
float
,
clash_overlap_tolerance
:
float
,
**
kwargs
,
)
->
Dict
[
str
,
torch
.
Tensor
]:
)
->
Dict
[
str
,
torch
.
Tensor
]:
"""Computes several checks for structural violations."""
"""Computes several checks for structural violations."""
# Compute between residue backbone violations of bonds and angles.
# Compute between residue backbone violations of bonds and angles.
connection_violations
=
between_residue_bond_loss
(
connection_violations
=
between_residue_bond_loss
(
pred_atom_positions
=
atom14_pred_positions
,
pred_atom_positions
=
atom14_pred_positions
,
pred_atom_mask
=
batch
[
'
atom14_atom_exists
'
],
pred_atom_mask
=
batch
[
"
atom14_atom_exists
"
],
residue_index
=
batch
[
'
residue_index
'
],
residue_index
=
batch
[
"
residue_index
"
],
aatype
=
batch
[
'
aatype
'
],
aatype
=
batch
[
"
aatype
"
],
tolerance_factor_soft
=
config
.
violation_tolerance_factor
,
tolerance_factor_soft
=
violation_tolerance_factor
,
tolerance_factor_hard
=
config
.
violation_tolerance_factor
tolerance_factor_hard
=
violation_tolerance_factor
)
)
# Compute the Van der Waals radius for every atom
# Compute the Van der Waals radius for every atom
...
@@ -793,31 +896,31 @@ def find_structural_violations(
...
@@ -793,31 +896,31 @@ def find_structural_violations(
atomtype_radius
atomtype_radius
)
)
atom14_atom_radius
=
(
atom14_atom_radius
=
(
batch
[
'
atom14_atom_exists
'
]
*
batch
[
"
atom14_atom_exists
"
]
*
atomtype_radius
[
batch
[
'
residx_atom14_to_atom37
'
]]
atomtype_radius
[
batch
[
"
residx_atom14_to_atom37
"
]]
)
)
# Compute the between residue clash loss.
# Compute the between residue clash loss.
between_residue_clashes
=
between_residue_clash_loss
(
between_residue_clashes
=
between_residue_clash_loss
(
atom14_pred_positions
=
atom14_pred_positions
,
atom14_pred_positions
=
atom14_pred_positions
,
atom14_atom_exists
=
batch
[
'
atom14_atom_exists
'
],
atom14_atom_exists
=
batch
[
"
atom14_atom_exists
"
],
atom14_atom_radius
=
atom14_atom_radius
,
atom14_atom_radius
=
atom14_atom_radius
,
residue_index
=
batch
[
'
residue_index
'
],
residue_index
=
batch
[
"
residue_index
"
],
overlap_tolerance_soft
=
config
.
clash_overlap_tolerance
,
overlap_tolerance_soft
=
clash_overlap_tolerance
,
overlap_tolerance_hard
=
config
.
clash_overlap_tolerance
overlap_tolerance_hard
=
clash_overlap_tolerance
)
)
# Compute all within-residue violations (clashes,
# Compute all within-residue violations (clashes,
# bond length and angle violations).
# bond length and angle violations).
restype_atom14_bounds
=
residue_constants
.
make_atom14_dists_bounds
(
restype_atom14_bounds
=
residue_constants
.
make_atom14_dists_bounds
(
overlap_tolerance
=
config
.
clash_overlap_tolerance
,
overlap_tolerance
=
clash_overlap_tolerance
,
bond_length_tolerance_factor
=
config
.
violation_tolerance_factor
bond_length_tolerance_factor
=
violation_tolerance_factor
)
)
atom14_dists_lower_bound
=
restype_atom14_bounds
[
'
lower_bound
'
][
atom14_dists_lower_bound
=
restype_atom14_bounds
[
"
lower_bound
"
][
batch
[
'
aatype
'
]
batch
[
"
aatype
"
]
]
]
atom14_dists_upper_bound
=
restype_atom14_bounds
[
'
upper_bound
'
][
atom14_dists_upper_bound
=
restype_atom14_bounds
[
"
upper_bound
"
][
batch
[
'
aatype
'
]
batch
[
"
aatype
"
]
]
]
atom14_dists_lower_bound
=
atom14_pred_positions
.
new_tensor
(
atom14_dists_lower_bound
=
atom14_pred_positions
.
new_tensor
(
atom14_dists_lower_bound
atom14_dists_lower_bound
...
@@ -827,7 +930,7 @@ def find_structural_violations(
...
@@ -827,7 +930,7 @@ def find_structural_violations(
)
)
residue_violations
=
within_residue_violations
(
residue_violations
=
within_residue_violations
(
atom14_pred_positions
=
atom14_pred_positions
,
atom14_pred_positions
=
atom14_pred_positions
,
atom14_atom_exists
=
batch
[
'
atom14_atom_exists
'
],
atom14_atom_exists
=
batch
[
"
atom14_atom_exists
"
],
atom14_dists_lower_bound
=
atom14_dists_lower_bound
,
atom14_dists_lower_bound
=
atom14_dists_lower_bound
,
atom14_dists_upper_bound
=
atom14_dists_upper_bound
,
atom14_dists_upper_bound
=
atom14_dists_upper_bound
,
tighten_bounds_for_loss
=
0.0
tighten_bounds_for_loss
=
0.0
...
@@ -837,12 +940,12 @@ def find_structural_violations(
...
@@ -837,12 +940,12 @@ def find_structural_violations(
per_residue_violations_mask
=
torch
.
max
(
per_residue_violations_mask
=
torch
.
max
(
torch
.
stack
(
torch
.
stack
(
[
[
connection_violations
[
'
per_residue_violation_mask
'
],
connection_violations
[
"
per_residue_violation_mask
"
],
torch
.
max
(
torch
.
max
(
between_residue_clashes
[
'
per_atom_clash_mask
'
],
dim
=-
1
between_residue_clashes
[
"
per_atom_clash_mask
"
],
dim
=-
1
)[
0
],
)[
0
],
torch
.
max
(
torch
.
max
(
residue_violations
[
'
per_atom_violations
'
],
dim
=-
1
residue_violations
[
"
per_atom_violations
"
],
dim
=-
1
)[
0
],
)[
0
],
],
],
dim
=-
1
,
dim
=-
1
,
...
@@ -853,27 +956,27 @@ def find_structural_violations(
...
@@ -853,27 +956,27 @@ def find_structural_violations(
return
{
return
{
'between_residues'
:
{
'between_residues'
:
{
'bonds_c_n_loss_mean'
:
'bonds_c_n_loss_mean'
:
connection_violations
[
'
c_n_loss_mean
'
],
# ()
connection_violations
[
"
c_n_loss_mean
"
],
# ()
'angles_ca_c_n_loss_mean'
:
'angles_ca_c_n_loss_mean'
:
connection_violations
[
'
ca_c_n_loss_mean
'
],
# ()
connection_violations
[
"
ca_c_n_loss_mean
"
],
# ()
'angles_c_n_ca_loss_mean'
:
'angles_c_n_ca_loss_mean'
:
connection_violations
[
'
c_n_ca_loss_mean
'
],
# ()
connection_violations
[
"
c_n_ca_loss_mean
"
],
# ()
'connections_per_residue_loss_sum'
:
'connections_per_residue_loss_sum'
:
connection_violations
[
'
per_residue_loss_sum
'
],
# (N)
connection_violations
[
"
per_residue_loss_sum
"
],
# (N)
'connections_per_residue_violation_mask'
:
'connections_per_residue_violation_mask'
:
connection_violations
[
'
per_residue_violation_mask
'
],
# (N)
connection_violations
[
"
per_residue_violation_mask
"
],
# (N)
'clashes_mean_loss'
:
'clashes_mean_loss'
:
between_residue_clashes
[
'
mean_loss
'
],
# ()
between_residue_clashes
[
"
mean_loss
"
],
# ()
'clashes_per_atom_loss_sum'
:
'clashes_per_atom_loss_sum'
:
between_residue_clashes
[
'
per_atom_loss_sum
'
],
# (N, 14)
between_residue_clashes
[
"
per_atom_loss_sum
"
],
# (N, 14)
'clashes_per_atom_clash_mask'
:
'clashes_per_atom_clash_mask'
:
between_residue_clashes
[
'
per_atom_clash_mask
'
],
# (N, 14)
between_residue_clashes
[
"
per_atom_clash_mask
"
],
# (N, 14)
},
},
'within_residues'
:
{
'within_residues'
:
{
'per_atom_loss_sum'
:
'per_atom_loss_sum'
:
residue_violations
[
'
per_atom_loss_sum
'
],
# (N, 14)
residue_violations
[
"
per_atom_loss_sum
"
],
# (N, 14)
'per_atom_violations'
:
'per_atom_violations'
:
residue_violations
[
'
per_atom_violations
'
],
# (N, 14),
residue_violations
[
"
per_atom_violations
"
],
# (N, 14),
},
},
'total_per_residue_violations_mask'
:
'total_per_residue_violations_mask'
:
per_residue_violations_mask
,
# (N)
per_residue_violations_mask
,
# (N)
...
@@ -943,35 +1046,35 @@ def compute_violation_metrics(
...
@@ -943,35 +1046,35 @@ def compute_violation_metrics(
ret
=
{}
ret
=
{}
extreme_ca_ca_violations
=
extreme_ca_ca_distance_violations
(
extreme_ca_ca_violations
=
extreme_ca_ca_distance_violations
(
pred_atom_positions
=
atom14_pred_positions
,
pred_atom_positions
=
atom14_pred_positions
,
pred_atom_mask
=
batch
[
'
atom14_atom_exists
'
],
pred_atom_mask
=
batch
[
"
atom14_atom_exists
"
],
residue_index
=
batch
[
'
residue_index
'
]
residue_index
=
batch
[
"
residue_index
"
]
)
)
ret
[
'
violations_extreme_ca_ca_distance
'
]
=
extreme_ca_ca_violations
ret
[
"
violations_extreme_ca_ca_distance
"
]
=
extreme_ca_ca_violations
ret
[
'
violations_between_residue_bond
'
]
=
masked_mean
(
ret
[
"
violations_between_residue_bond
"
]
=
masked_mean
(
batch
[
'
seq_mask
'
],
batch
[
"
seq_mask
"
],
violations
[
'
between_residues
'
][
violations
[
"
between_residues
"
][
'connections_per_residue_violation_mask'
'connections_per_residue_violation_mask'
],
],
dim
=-
1
,
dim
=-
1
,
)
)
ret
[
'
violations_between_residue_clash
'
]
=
masked_mean
(
ret
[
"
violations_between_residue_clash
"
]
=
masked_mean
(
mask
=
batch
[
'
seq_mask
'
],
mask
=
batch
[
"
seq_mask
"
],
value
=
torch
.
max
(
value
=
torch
.
max
(
violations
[
'
between_residues
'
][
'
clashes_per_atom_clash_mask
'
],
violations
[
"
between_residues
"
][
"
clashes_per_atom_clash_mask
"
],
dim
=-
1
dim
=-
1
)[
0
],
)[
0
],
dim
=-
1
,
dim
=-
1
,
)
)
ret
[
'
violations_within_residue
'
]
=
masked_mean
(
ret
[
"
violations_within_residue
"
]
=
masked_mean
(
mask
=
batch
[
'
seq_mask
'
],
mask
=
batch
[
"
seq_mask
"
],
value
=
torch
.
max
(
value
=
torch
.
max
(
violations
[
'
within_residues
'
][
'
per_atom_violations
'
],
dim
=-
1
violations
[
"
within_residues
"
][
"
per_atom_violations
"
],
dim
=-
1
)[
0
],
)[
0
],
dim
=-
1
,
dim
=-
1
,
)
)
ret
[
'
violations_per_residue
'
]
=
masked_mean
(
ret
[
"
violations_per_residue
"
]
=
masked_mean
(
mask
=
batch
[
'
seq_mask
'
],
mask
=
batch
[
"
seq_mask
"
],
value
=
violations
[
'
total_per_residue_violations_mask
'
],
value
=
violations
[
"
total_per_residue_violations_mask
"
],
dim
=-
1
,
dim
=-
1
,
)
)
return
ret
return
ret
...
@@ -994,6 +1097,27 @@ def compute_violation_metrics_np(
...
@@ -994,6 +1097,27 @@ def compute_violation_metrics_np(
return
tree_map
(
to_np
,
out
,
torch
.
Tensor
)
return
tree_map
(
to_np
,
out
,
torch
.
Tensor
)
def violation_loss(
    violations: Dict[str, torch.Tensor],
    atom14_atom_exists: torch.Tensor,
    eps=1e-6,
    **kwargs,
) -> torch.Tensor:
    """Combine the structural-violation terms into a single loss.

    Args:
        violations: Nested dict with a ``"between_residues"`` entry
            (``bonds_c_n_loss_mean``, ``angles_ca_c_n_loss_mean``,
            ``angles_c_n_ca_loss_mean``, ``clashes_per_atom_loss_sum``) and a
            ``"within_residues"`` entry (``per_atom_loss_sum``).
        atom14_atom_exists: Atom-existence mask; its sum normalizes the
            clash term.
        eps: Added to the atom count to avoid division by zero.
        **kwargs: Ignored. Allows the function to be called with an entire
            feature batch unpacked as keywords.

    Returns:
        Sum of the three between-residue bond/angle means and the per-atom
        averaged clash loss.
    """
    between = violations["between_residues"]
    within = violations["within_residues"]

    # Between- and within-residue clash sums are pooled and averaged over
    # the number of existing atoms.
    num_atoms = torch.sum(atom14_atom_exists)
    l_clash = torch.sum(
        between["clashes_per_atom_loss_sum"] + within["per_atom_loss_sum"]
    )
    l_clash = l_clash / (eps + num_atoms)

    loss = (
        between["bonds_c_n_loss_mean"]
        + between["angles_ca_c_n_loss_mean"]
        + between["angles_c_n_ca_loss_mean"]
        + l_clash
    )

    return loss
def
compute_renamed_ground_truth
(
def
compute_renamed_ground_truth
(
batch
:
Dict
[
str
,
torch
.
Tensor
],
batch
:
Dict
[
str
,
torch
.
Tensor
],
atom14_pred_positions
:
torch
.
Tensor
,
atom14_pred_positions
:
torch
.
Tensor
,
...
@@ -1038,7 +1162,7 @@ def compute_renamed_ground_truth(
...
@@ -1038,7 +1162,7 @@ def compute_renamed_ground_truth(
)
)
)
)
atom14_gt_positions
=
batch
[
'
atom14_gt_positions
'
]
atom14_gt_positions
=
batch
[
"
atom14_gt_positions
"
]
gt_dists
=
torch
.
sqrt
(
gt_dists
=
torch
.
sqrt
(
eps
+
eps
+
torch
.
sum
(
torch
.
sum
(
...
@@ -1050,7 +1174,7 @@ def compute_renamed_ground_truth(
...
@@ -1050,7 +1174,7 @@ def compute_renamed_ground_truth(
)
)
)
)
atom14_alt_gt_positions
=
batch
[
'
atom14_alt_gt_positions
'
]
atom14_alt_gt_positions
=
batch
[
"
atom14_alt_gt_positions
"
]
alt_gt_dists
=
torch
.
sqrt
(
alt_gt_dists
=
torch
.
sqrt
(
eps
+
eps
+
torch
.
sum
(
torch
.
sum
(
...
@@ -1065,8 +1189,8 @@ def compute_renamed_ground_truth(
...
@@ -1065,8 +1189,8 @@ def compute_renamed_ground_truth(
lddt
=
torch
.
sqrt
(
eps
+
(
pred_dists
-
gt_dists
)
**
2
)
lddt
=
torch
.
sqrt
(
eps
+
(
pred_dists
-
gt_dists
)
**
2
)
alt_lddt
=
torch
.
sqrt
(
eps
+
(
pred_dists
-
alt_gt_dists
)
**
2
)
alt_lddt
=
torch
.
sqrt
(
eps
+
(
pred_dists
-
alt_gt_dists
)
**
2
)
atom14_gt_exists
=
batch
[
'
atom14_gt_exists
'
]
atom14_gt_exists
=
batch
[
"
atom14_gt_exists
"
]
atom14_atom_is_ambiguous
=
batch
[
'
atom14_atom_is_ambiguous
'
]
atom14_atom_is_ambiguous
=
batch
[
"
atom14_atom_is_ambiguous
"
]
mask
=
(
mask
=
(
atom14_gt_exists
[...,
None
,
:,
None
]
*
atom14_gt_exists
[...,
None
,
:,
None
]
*
atom14_atom_is_ambiguous
[...,
None
,
:,
None
]
*
atom14_atom_is_ambiguous
[...,
None
,
:,
None
]
*
...
@@ -1089,13 +1213,13 @@ def compute_renamed_ground_truth(
...
@@ -1089,13 +1213,13 @@ def compute_renamed_ground_truth(
renamed_atom14_gt_mask
=
(
renamed_atom14_gt_mask
=
(
(
1.
-
alt_naming_is_better
[...,
None
])
*
atom14_gt_exists
+
(
1.
-
alt_naming_is_better
[...,
None
])
*
atom14_gt_exists
+
alt_naming_is_better
[...,
None
]
*
batch
[
'
atom14_alt_gt_exists
'
]
alt_naming_is_better
[...,
None
]
*
batch
[
"
atom14_alt_gt_exists
"
]
)
)
return
{
return
{
'
alt_naming_is_better
'
:
alt_naming_is_better
,
"
alt_naming_is_better
"
:
alt_naming_is_better
,
'
renamed_atom14_gt_positions
'
:
renamed_atom14_gt_positions
,
"
renamed_atom14_gt_positions
"
:
renamed_atom14_gt_positions
,
'
renamed_atom14_gt_exists
'
:
renamed_atom14_gt_mask
,
"
renamed_atom14_gt_exists
"
:
renamed_atom14_gt_mask
,
}
}
...
@@ -1103,9 +1227,105 @@ def experimentally_resolved_loss(
...
@@ -1103,9 +1227,105 @@ def experimentally_resolved_loss(
logits
:
torch
.
Tensor
,
logits
:
torch
.
Tensor
,
atom37_atom_exists
:
torch
.
Tensor
,
atom37_atom_exists
:
torch
.
Tensor
,
all_atom_mask
:
torch
.
Tensor
,
all_atom_mask
:
torch
.
Tensor
,
resolution
:
torch
.
Tensor
,
min_resolution
:
float
,
max_resolution
:
float
,
eps
:
float
=
1e-8
,
eps
:
float
=
1e-8
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
errors
=
sigmoid_cross_entropy
(
logits
,
all_atom_mask
)
errors
=
sigmoid_cross_entropy
(
logits
,
all_atom_mask
)
loss_num
=
torch
.
sum
(
errors
*
atom37_atom_exists
,
dim
=
(
-
1
,
-
2
))
loss_num
=
torch
.
sum
(
errors
*
atom37_atom_exists
,
dim
=
(
-
1
,
-
2
))
loss
=
loss_num
/
(
eps
+
torch
.
sum
(
atom37_atom_exists
,
dim
=
(
-
1
,
-
2
)))
loss
=
loss_num
/
(
eps
+
torch
.
sum
(
atom37_atom_exists
,
dim
=
(
-
1
,
-
2
)))
loss
*=
(
(
resolution
>=
min_resolution
)
&
(
resolution
<=
max_resolution
)
)
return
loss
def masked_msa_loss(logits, true_msa, bert_mask, eps=1e-8, **kwargs):
    """Cross-entropy over the masked MSA positions.

    Args:
        logits: Predicted MSA logits; the last axis is compared against a
            23-class one-hot encoding of ``true_msa``.
        true_msa: Integer tensor of ground-truth MSA classes.
        bert_mask: Mask selecting the positions that contribute to the loss.
        eps: Added to the mask sum to avoid division by zero.
        **kwargs: Ignored; lets callers pass a merged batch/config dict.

    Returns:
        Masked mean of the per-position errors, reduced over the last two
        axes of the error tensor.
    """
    errors = softmax_cross_entropy(
        logits,
        torch.nn.functional.one_hot(true_msa, num_classes=23),
    )

    loss = (
        torch.sum(errors * bert_mask, dim=(-1, -2))
        / (eps + torch.sum(bert_mask, dim=(-1, -2)))
    )

    return loss
return
loss
class AlphaFoldLoss(torch.nn.Module):
    """Aggregation of the various losses described in the supplement."""

    def __init__(self, config):
        """Store the loss config.

        Args:
            config: Mapping with one sub-config per loss term
                (``distogram``, ``experimentally_resolved``, ``fape``,
                ``lddt``, ``masked_msa``, ``supervised_chi``, ``violation``).
                Each sub-config carries a ``weight`` next to the keyword
                options of its loss function.
        """
        super(AlphaFoldLoss, self).__init__()
        self.config = config

    def forward(self, out, batch):
        """Return the weighted sum of all loss terms with non-zero weight.

        Side effects: may add ``out["violation"]`` and updates ``batch`` in
        place with the renamed ground truth.

        Args:
            out: Dict of model outputs.
            batch: Dict of input / ground-truth features.

        Returns:
            The accumulated loss; the integer 0 if every weight is zero.
        """
        # Structural violations are only needed by the violation loss, so
        # they are skipped when that term is disabled.
        if (
            "violation" not in out.keys()
            and self.config.violation.weight
        ):
            out["violation"] = find_structural_violations(
                batch,
                out["sm"]["positions"][-1],
                **self.config.violation,
            )

        if "renamed_atom14_gt_positions" not in out.keys():
            batch.update(
                compute_renamed_ground_truth(
                    batch,
                    out["sm"]["positions"][-1],
                )
            )

        # Each entry is a thunk so that disabled terms are never evaluated.
        # Feature batch and per-loss options are merged and unpacked as
        # keywords (a bare dict after a keyword argument is a syntax error).
        # NOTE(review): this relies on every loss function tolerating the
        # extra keys (e.g. "weight", unrelated batch features) -- confirm
        # each callee accepts **kwargs.
        loss_fns = {
            "distogram": lambda: distogram_loss(
                logits=out["distogram_logits"],
                **{**batch, **self.config.distogram},
            ),
            "experimentally_resolved": lambda: experimentally_resolved_loss(
                logits=out["experimentally_resolved"],
                **{**batch, **self.config.experimentally_resolved},
            ),
            "fape": lambda: fape_loss(
                out,
                batch,
                self.config.fape,
            ),
            "lddt": lambda: lddt_loss(
                logits=out["lddt_logits"],
                all_atom_pred_pos=out["final_atom_positions"],
                **{**batch, **self.config.lddt},
            ),
            "masked_msa": lambda: masked_msa_loss(
                logits=out["masked_msa_logits"],
                **{**batch, **self.config.masked_msa},
            ),
            "supervised_chi": lambda: supervised_chi_loss(
                out["sm"]["angles"],
                out["sm"]["unnormalized_angles"],
                **{**batch, **self.config.supervised_chi},
            ),
            "violation": lambda: violation_loss(
                out["violation"],
                **batch,
            ),
        }

        cum_loss = 0
        for k, loss_fn in loss_fns.items():
            weight = self.config[k].weight
            if weight:
                cum_loss += weight * loss_fn()

        return cum_loss
config.py
View file @
df6b97f2
...
@@ -180,4 +180,54 @@ config = mlc.ConfigDict({
...
@@ -180,4 +180,54 @@ config = mlc.ConfigDict({
"max_outer_iterations"
:
20
,
"max_outer_iterations"
:
20
,
"exclude_residues"
:
[],
"exclude_residues"
:
[],
},
},
"loss"
:
{
"distogram"
:
{
"min_bin"
:
2.3125
,
"max_bin"
:
21.6875
,
"no_bins"
:
64
,
"eps"
:
1e-6
,
"weight"
:
0.3
,
},
"experimentally_resolved"
:
{
"eps"
:
1e-8
,
"min_resolution"
:
0.1
,
"max_resolution"
:
3.0
,
"weight"
:
0.
,
},
"fape"
:
{
"backbone"
:
{
"clamp_distance"
:
10.
,
"loss_unit_distance"
:
10.
,
"weight"
:
0.5
,
}
"sidechain"
:
{
"clamp_distance"
:
10.
,
"length_scale"
:
10.
,
"weight"
:
0.5
,
}
"weight"
:
1.0
,
},
"lddt"
:
{
"min_resolution"
:
0.1
,
"max_resolution"
:
3.0
,
"cutoff"
:
15.
,
"num_bins"
:
50
,
"eps"
:
1e-10
,
"weight"
:
0.01
,
},
"masked_msa"
:
{
"eps"
:
1e-8
,
"weight"
:
2.0
,
},
"supervised_chi"
:
{
"chi_weight"
:
0.5
,
"angle_norm_weight"
:
0.01
,
"eps"
:
1e-6
,
"weight"
:
1.0
,
},
"violation"
:
{
"eps"
:
1e-6
,
"weight"
:
0.
,
},
},
})
})
tests/test_structure_module.py
View file @
df6b97f2
...
@@ -205,7 +205,7 @@ class TestAngleResnet(unittest.TestCase):
...
@@ -205,7 +205,7 @@ class TestAngleResnet(unittest.TestCase):
a
=
torch
.
rand
((
batch_size
,
n
,
c_s
))
a
=
torch
.
rand
((
batch_size
,
n
,
c_s
))
a_initial
=
torch
.
rand
((
batch_size
,
n
,
c_s
))
a_initial
=
torch
.
rand
((
batch_size
,
n
,
c_s
))
a
=
ar
(
a
,
a_initial
)
_
,
a
=
ar
(
a
,
a_initial
)
self
.
assertTrue
(
a
.
shape
==
(
batch_size
,
n
,
no_angles
,
2
))
self
.
assertTrue
(
a
.
shape
==
(
batch_size
,
n
,
no_angles
,
2
))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment