Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
51556d52
Commit
51556d52
authored
Jun 14, 2023
by
Christina Floristean
Browse files
Added multimer changes for loss functions
parent
fbfbd808
Changes
10
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
424 additions
and
63 deletions
+424
-63
openfold/config.py
openfold/config.py
+86
-1
openfold/model/heads.py
openfold/model/heads.py
+9
-1
openfold/model/model.py
openfold/model/model.py
+3
-0
openfold/model/triangular_multiplicative_update.py
openfold/model/triangular_multiplicative_update.py
+6
-3
openfold/utils/geometry/vector.py
openfold/utils/geometry/vector.py
+1
-1
openfold/utils/import_weights.py
openfold/utils/import_weights.py
+1
-1
openfold/utils/loss.py
openfold/utils/loss.py
+177
-38
tests/config.py
tests/config.py
+1
-1
tests/data_utils.py
tests/data_utils.py
+3
-1
tests/test_loss.py
tests/test_loss.py
+137
-16
No files found.
openfold/config.py
View file @
51556d52
...
...
@@ -648,6 +648,7 @@ config = mlc.ConfigDict(
"violation"
:
{
"violation_tolerance_factor"
:
12.0
,
"clash_overlap_tolerance"
:
1.5
,
"average_clashes"
:
False
,
"eps"
:
eps
,
# 1e-6,
"weight"
:
0.0
,
},
...
...
@@ -660,6 +661,12 @@ config = mlc.ConfigDict(
"weight"
:
0.
,
"enabled"
:
tm_enabled
,
},
"chain_center_of_mass"
:
{
"clamp_distance"
:
-
4.0
,
"weight"
:
0.
,
"eps"
:
eps
,
"enabled"
:
False
,
},
"eps"
:
eps
,
},
"ema"
:
{
"decay"
:
0.999
},
...
...
@@ -802,7 +809,9 @@ multimer_model_config_update = {
"tm"
:
{
"c_z"
:
c_z
,
"no_bins"
:
aux_distogram_bins
,
"enabled"
:
tm_enabled
,
"ptm_weight"
:
0.2
,
"iptm_weight"
:
0.8
,
"enabled"
:
True
,
},
"masked_msa"
:
{
"c_m"
:
c_m
,
...
...
@@ -813,5 +822,81 @@ multimer_model_config_update = {
"c_out"
:
37
,
},
},
"loss"
:
{
"distogram"
:
{
"min_bin"
:
2.3125
,
"max_bin"
:
21.6875
,
"no_bins"
:
64
,
"eps"
:
eps
,
# 1e-6,
"weight"
:
0.3
,
},
"experimentally_resolved"
:
{
"eps"
:
eps
,
# 1e-8,
"min_resolution"
:
0.1
,
"max_resolution"
:
3.0
,
"weight"
:
0.0
,
},
"fape"
:
{
"intra_chain_backbone"
:
{
"clamp_distance"
:
10.0
,
"loss_unit_distance"
:
10.0
,
"weight"
:
0.5
,
},
"interface"
:
{
"clamp_distance"
:
30.0
,
"loss_unit_distance"
:
20.0
,
"weight"
:
0.5
,
},
"sidechain"
:
{
"clamp_distance"
:
10.0
,
"length_scale"
:
10.0
,
"weight"
:
0.5
,
},
"eps"
:
1e-4
,
"weight"
:
1.0
,
},
"plddt_loss"
:
{
"min_resolution"
:
0.1
,
"max_resolution"
:
3.0
,
"cutoff"
:
15.0
,
"no_bins"
:
50
,
"eps"
:
eps
,
# 1e-10,
"weight"
:
0.01
,
},
"masked_msa"
:
{
"num_classes"
:
23
,
"eps"
:
eps
,
# 1e-8,
"weight"
:
2.0
,
},
"supervised_chi"
:
{
"chi_weight"
:
0.5
,
"angle_norm_weight"
:
0.01
,
"eps"
:
eps
,
# 1e-6,
"weight"
:
1.0
,
},
"violation"
:
{
"violation_tolerance_factor"
:
12.0
,
"clash_overlap_tolerance"
:
1.5
,
"average_clashes"
:
True
,
"eps"
:
eps
,
# 1e-6,
"weight"
:
0.03
,
# Not finetuning
},
"tm"
:
{
"max_bin"
:
31
,
"no_bins"
:
64
,
"min_resolution"
:
0.1
,
"max_resolution"
:
3.0
,
"eps"
:
eps
,
# 1e-8,
"weight"
:
0.1
,
"enabled"
:
True
,
},
"chain_center_of_mass"
:
{
"clamp_distance"
:
-
4.0
,
"weight"
:
0.05
,
"eps"
:
eps
,
"enabled"
:
True
,
},
"eps"
:
eps
,
},
"recycle_early_stop_tolerance"
:
0.5
}
openfold/model/heads.py
View file @
51556d52
...
...
@@ -76,9 +76,17 @@ class AuxiliaryHeads(nn.Module):
if
self
.
config
.
tm
.
enabled
:
tm_logits
=
self
.
tm
(
outputs
[
"pair"
])
aux_out
[
"tm_logits"
]
=
tm_logits
aux_out
[
"p
redicted_
tm_score"
]
=
compute_tm
(
aux_out
[
"ptm_score"
]
=
compute_tm
(
tm_logits
,
**
self
.
config
.
tm
)
asym_id
=
outputs
.
get
(
"asym_id"
)
if
asym_id
is
not
None
:
aux_out
[
"iptm_score"
]
=
compute_tm
(
tm_logits
,
asym_id
=
asym_id
,
interface
=
True
,
**
self
.
config
.
tm
)
aux_out
[
"weighted_ptm_score"
]
=
(
self
.
config
.
tm
[
"iptm_weight"
]
*
aux_out
[
"iptm_score"
]
+
self
.
config
.
tm
[
"ptm_weight"
]
*
aux_out
[
"ptm_score"
])
aux_out
.
update
(
compute_predicted_aligned_error
(
tm_logits
,
...
...
openfold/model/model.py
View file @
51556d52
...
...
@@ -555,6 +555,9 @@ class AlphaFold(nn.Module):
else
:
break
if
"asym_id"
in
batch
:
outputs
[
"asym_id"
]
=
feats
[
"asym_id"
]
# Run auxiliary heads
outputs
.
update
(
self
.
aux_heads
(
outputs
))
...
...
openfold/model/triangular_multiplicative_update.py
View file @
51556d52
...
...
@@ -435,7 +435,7 @@ class TriangleMultiplicativeUpdate(BaseTriangleMultiplicativeUpdate):
# reduced-precision modes
a_std
=
a
.
std
()
b_std
=
b
.
std
()
if
(
a_std
!=
0.
and
b_std
!=
0.
):
if
(
is_fp16_enabled
()
and
a_std
!=
0.
and
b_std
!=
0.
):
a
=
a
/
a
.
std
()
b
=
b
/
b
.
std
()
...
...
@@ -589,8 +589,11 @@ class FusedTriangleMultiplicativeUpdate(BaseTriangleMultiplicativeUpdate):
# Prevents overflow of torch.matmul in combine projections in
# reduced-precision modes
a
=
a
/
a
.
std
()
b
=
b
/
b
.
std
()
a_std
=
a
.
std
()
b_std
=
b
.
std
()
if
(
is_fp16_enabled
()
and
a_std
!=
0.
and
b_std
!=
0.
):
a
=
a
/
a
.
std
()
b
=
b
/
b
.
std
()
if
(
is_fp16_enabled
()):
with
torch
.
cuda
.
amp
.
autocast
(
enabled
=
False
):
...
...
openfold/utils/geometry/vector.py
View file @
51556d52
...
...
@@ -193,7 +193,7 @@ def square_euclidean_distance(
difference
=
vec1
-
vec2
distance
=
difference
.
dot
(
difference
)
if
epsilon
:
distance
=
torch
.
maximum
(
distance
,
epsilon
)
distance
=
torch
.
clamp
(
distance
,
min
=
epsilon
)
return
distance
...
...
openfold/utils/import_weights.py
View file @
51556d52
...
...
@@ -617,7 +617,7 @@ def generate_translation_dict(model, version, is_multimer=False):
translations
[
"evoformer"
].
update
(
template_param_dict
)
if
"_ptm"
in
version
:
if
is_multimer
or
"_ptm"
in
version
:
translations
[
"predicted_aligned_error_head"
]
=
{
"logits"
:
LinearParams
(
model
.
aux_heads
.
tm
.
linear
)
}
...
...
openfold/utils/loss.py
View file @
51556d52
...
...
@@ -25,6 +25,8 @@ from typing import Dict, Optional, Tuple
from
openfold.np
import
residue_constants
from
openfold.utils
import
feats
from
openfold.utils.rigid_utils
import
Rotation
,
Rigid
from
openfold.utils.geometry.vector
import
Vec3Array
,
euclidean_distance
from
openfold.utils.all_atom_multimer
import
get_rc_tensor
from
openfold.utils.tensor_utils
import
(
tree_map
,
tensor_tree_map
,
...
...
@@ -87,6 +89,7 @@ def compute_fape(
target_positions
:
torch
.
Tensor
,
positions_mask
:
torch
.
Tensor
,
length_scale
:
float
,
pair_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
l1_clamp_distance
:
Optional
[
float
]
=
None
,
eps
=
1e-8
,
)
->
torch
.
Tensor
:
...
...
@@ -108,6 +111,9 @@ def compute_fape(
[*, N_pts] positions mask
length_scale:
Length scale by which the loss is divided
pair_mask:
[*, N_frames, N_pts] mask to use for
separating intra- from inter-chain losses.
l1_clamp_distance:
Cutoff above which distance errors are disregarded
eps:
...
...
@@ -134,21 +140,30 @@ def compute_fape(
normed_error
=
normed_error
*
frames_mask
[...,
None
]
normed_error
=
normed_error
*
positions_mask
[...,
None
,
:]
# FP16-friendly averaging. Roughly equivalent to:
#
# norm_factor = (
# torch.sum(frames_mask, dim=-1) *
# torch.sum(positions_mask, dim=-1)
# )
# normed_error = torch.sum(normed_error, dim=(-1, -2)) / (eps + norm_factor)
#
# ("roughly" because eps is necessarily duplicated in the latter)
normed_error
=
torch
.
sum
(
normed_error
,
dim
=-
1
)
normed_error
=
(
normed_error
/
(
eps
+
torch
.
sum
(
frames_mask
,
dim
=-
1
))[...,
None
]
)
normed_error
=
torch
.
sum
(
normed_error
,
dim
=-
1
)
normed_error
=
normed_error
/
(
eps
+
torch
.
sum
(
positions_mask
,
dim
=-
1
))
if
pair_mask
is
not
None
:
normed_error
=
normed_error
*
pair_mask
normed_error
=
torch
.
sum
(
normed_error
,
dim
=
(
-
1
,
-
2
))
mask
=
frames_mask
[...,
None
]
*
positions_mask
[...,
None
,
:]
*
pair_mask
norm_factor
=
torch
.
sum
(
mask
,
dim
=
(
-
2
,
-
1
))
normed_error
=
normed_error
/
(
eps
+
norm_factor
)
else
:
# FP16-friendly averaging. Roughly equivalent to:
#
# norm_factor = (
# torch.sum(frames_mask, dim=-1) *
# torch.sum(positions_mask, dim=-1)
# )
# normed_error = torch.sum(normed_error, dim=(-1, -2)) / (eps + norm_factor)
#
# ("roughly" because eps is necessarily duplicated in the latter)
normed_error
=
torch
.
sum
(
normed_error
,
dim
=-
1
)
normed_error
=
(
normed_error
/
(
eps
+
torch
.
sum
(
frames_mask
,
dim
=-
1
))[...,
None
]
)
normed_error
=
torch
.
sum
(
normed_error
,
dim
=-
1
)
normed_error
=
normed_error
/
(
eps
+
torch
.
sum
(
positions_mask
,
dim
=-
1
))
return
normed_error
...
...
@@ -157,6 +172,7 @@ def backbone_loss(
backbone_rigid_tensor
:
torch
.
Tensor
,
backbone_rigid_mask
:
torch
.
Tensor
,
traj
:
torch
.
Tensor
,
pair_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
use_clamped_fape
:
Optional
[
torch
.
Tensor
]
=
None
,
clamp_distance
:
float
=
10.0
,
loss_unit_distance
:
float
=
10.0
,
...
...
@@ -184,6 +200,7 @@ def backbone_loss(
pred_aff
.
get_trans
(),
gt_aff
[
None
].
get_trans
(),
backbone_rigid_mask
[
None
],
pair_mask
=
pair_mask
,
l1_clamp_distance
=
clamp_distance
,
length_scale
=
loss_unit_distance
,
eps
=
eps
,
...
...
@@ -196,6 +213,7 @@ def backbone_loss(
pred_aff
.
get_trans
(),
gt_aff
[
None
].
get_trans
(),
backbone_rigid_mask
[
None
],
pair_mask
=
pair_mask
,
l1_clamp_distance
=
None
,
length_scale
=
loss_unit_distance
,
eps
=
eps
,
...
...
@@ -253,6 +271,7 @@ def sidechain_loss(
sidechain_atom_pos
,
renamed_atom14_gt_positions
,
renamed_atom14_gt_exists
,
pair_mask
=
None
,
l1_clamp_distance
=
clamp_distance
,
length_scale
=
length_scale
,
eps
=
eps
,
...
...
@@ -266,10 +285,29 @@ def fape_loss(
batch
:
Dict
[
str
,
torch
.
Tensor
],
config
:
ml_collections
.
ConfigDict
,
)
->
torch
.
Tensor
:
bb_loss
=
backbone_loss
(
traj
=
out
[
"sm"
][
"frames"
],
**
{
**
batch
,
**
config
.
backbone
},
)
traj
=
out
[
"sm"
][
"frames"
]
asym_id
=
batch
.
get
(
"asym_id"
)
if
asym_id
is
not
None
:
intra_chain_mask
=
(
asym_id
[...,
None
]
==
asym_id
[...,
None
,
:]).
to
(
dtype
=
traj
.
dtype
)
intra_chain_bb_loss
=
backbone_loss
(
traj
=
traj
,
pair_mask
=
intra_chain_mask
,
**
{
**
batch
,
**
config
.
intra_chain_backbone
},
)
interface_bb_loss
=
backbone_loss
(
traj
=
traj
,
pair_mask
=
1.
-
intra_chain_mask
,
**
{
**
batch
,
**
config
.
interface_backbone
},
)
weighted_bb_loss
=
(
intra_chain_bb_loss
*
config
.
intra_chain_backbone
.
weight
+
interface_bb_loss
*
config
.
interface_backbone
.
weight
)
else
:
bb_loss
=
backbone_loss
(
traj
=
traj
,
**
{
**
batch
,
**
config
.
backbone
},
)
weighted_bb_loss
=
bb_loss
*
config
.
backbone
.
weight
sc_loss
=
sidechain_loss
(
out
[
"sm"
][
"sidechain_frames"
],
...
...
@@ -277,7 +315,7 @@ def fape_loss(
**
{
**
batch
,
**
config
.
sidechain
},
)
loss
=
config
.
backbone
.
weight
*
bb_loss
+
config
.
sidechain
.
weight
*
sc_loss
loss
=
weight
ed_
bb_loss
+
config
.
sidechain
.
weight
*
sc_loss
# Average over the batch dimension
loss
=
torch
.
mean
(
loss
)
...
...
@@ -654,7 +692,7 @@ def compute_tm(
n
=
residue_weights
.
shape
[
-
1
]
pair_mask
=
residue_weights
.
new_ones
((
n
,
n
),
dtype
=
torch
.
int32
)
if
interface
:
pair_mask
*=
(
asym_id
[...,
None
]
!=
asym_id
[...,
None
,
:])
pair_mask
*=
(
asym_id
[...,
None
]
!=
asym_id
[...,
None
,
:])
.
to
(
dtype
=
pair_mask
.
dtype
)
predicted_tm_term
*=
pair_mask
...
...
@@ -891,6 +929,7 @@ def between_residue_clash_loss(
atom14_atom_exists
:
torch
.
Tensor
,
atom14_atom_radius
:
torch
.
Tensor
,
residue_index
:
torch
.
Tensor
,
asym_id
:
Optional
[
torch
.
Tensor
]
=
None
,
overlap_tolerance_soft
=
1.5
,
overlap_tolerance_hard
=
1.5
,
eps
=
1e-10
,
...
...
@@ -966,9 +1005,13 @@ def between_residue_clash_loss(
)
n_one_hot
=
n_one_hot
.
type
(
fp_type
)
neighbour_mask
=
(
residue_index
[...,
:,
None
,
None
,
None
]
+
1
)
==
residue_index
[...,
None
,
:,
None
,
None
]
neighbour_mask
=
(
residue_index
[...,
:,
None
]
+
1
)
==
residue_index
[...,
None
,
:]
if
asym_id
is
not
None
:
neighbour_mask
=
neighbour_mask
&
(
asym_id
[...,
:,
None
]
==
asym_id
[...,
None
,
:])
neighbour_mask
=
neighbour_mask
[...,
None
,
None
]
c_n_bonds
=
(
neighbour_mask
*
c_one_hot
[...,
None
,
None
,
:,
None
]
...
...
@@ -1010,26 +1053,29 @@ def between_residue_clash_loss(
# Compute the per atom loss sum.
# shape (N, 14)
per_atom_loss_sum
=
torch
.
sum
(
dists_to_low_error
,
dim
=
(
-
4
,
-
2
))
+
torch
.
sum
(
dists_to_low_error
,
axis
=
(
-
3
,
-
1
)
dists_to_low_error
,
dim
=
(
-
3
,
-
1
)
)
# Compute the hard clash mask.
# shape (N, N, 14, 14)
clash_mask
=
dists_mask
*
(
dists
<
(
dists_lower_bound
-
overlap_tolerance_hard
)
dists
<
(
dists_lower_bound
-
overlap_tolerance_hard
)
)
per_atom_num_clash
=
torch
.
sum
(
clash_mask
,
dim
=
(
-
4
,
-
2
))
+
torch
.
sum
(
clash_mask
,
dim
=
(
-
3
,
-
1
))
# Compute the per atom clash.
# shape (N, 14)
per_atom_clash_mask
=
torch
.
maximum
(
torch
.
amax
(
clash_mask
,
axis
=
(
-
4
,
-
2
)),
torch
.
amax
(
clash_mask
,
axis
=
(
-
3
,
-
1
)),
torch
.
amax
(
clash_mask
,
dim
=
(
-
4
,
-
2
)),
torch
.
amax
(
clash_mask
,
dim
=
(
-
3
,
-
1
)),
)
return
{
"mean_loss"
:
mean_loss
,
# shape ()
"per_atom_loss_sum"
:
per_atom_loss_sum
,
# shape (N, 14)
"per_atom_clash_mask"
:
per_atom_clash_mask
,
# shape (N, 14)
"per_atom_num_clash"
:
per_atom_num_clash
# shape (N, 14)
}
...
...
@@ -1109,6 +1155,8 @@ def within_residue_violations(
(
dists
<
atom14_dists_lower_bound
)
|
(
dists
>
atom14_dists_upper_bound
)
)
per_atom_num_clash
=
torch
.
sum
(
violations
,
dim
=-
2
)
+
torch
.
sum
(
violations
,
dim
=-
1
)
# Compute the per atom violations.
per_atom_violations
=
torch
.
maximum
(
torch
.
max
(
violations
,
dim
=-
2
)[
0
],
torch
.
max
(
violations
,
axis
=-
1
)[
0
]
...
...
@@ -1117,6 +1165,7 @@ def within_residue_violations(
return
{
"per_atom_loss_sum"
:
per_atom_loss_sum
,
"per_atom_violations"
:
per_atom_violations
,
"per_atom_num_clash"
:
per_atom_num_clash
}
...
...
@@ -1146,11 +1195,24 @@ def find_structural_violations(
residue_constants
.
van_der_waals_radius
[
name
[
0
]]
for
name
in
residue_constants
.
atom_types
]
atomtype_radius
=
atom14_pred_positions
.
new_tensor
(
atomtype_radius
)
atom14_atom_radius
=
(
batch
[
"atom14_atom_exists"
]
*
atomtype_radius
[
batch
[
"residx_atom14_to_atom37"
]]
)
#TODO: Consolidate monomer/multimer modes
asym_id
=
batch
.
get
(
"asym_id"
)
if
asym_id
is
not
None
:
residx_atom14_to_atom37
=
get_rc_tensor
(
residue_constants
.
RESTYPE_ATOM14_TO_ATOM37
,
batch
[
"aatype"
]
)
atom14_atom_radius
=
(
batch
[
"atom14_atom_exists"
]
*
atomtype_radius
[
residx_atom14_to_atom37
.
long
()]
)
else
:
atom14_atom_radius
=
(
batch
[
"atom14_atom_exists"
]
*
atomtype_radius
[
batch
[
"residx_atom14_to_atom37"
]]
)
# Compute the between residue clash loss.
between_residue_clashes
=
between_residue_clash_loss
(
...
...
@@ -1158,6 +1220,7 @@ def find_structural_violations(
atom14_atom_exists
=
batch
[
"atom14_atom_exists"
],
atom14_atom_radius
=
atom14_atom_radius
,
residue_index
=
batch
[
"residue_index"
],
asym_id
=
asym_id
,
overlap_tolerance_soft
=
clash_overlap_tolerance
,
overlap_tolerance_hard
=
clash_overlap_tolerance
,
)
...
...
@@ -1220,6 +1283,9 @@ def find_structural_violations(
"clashes_per_atom_clash_mask"
:
between_residue_clashes
[
"per_atom_clash_mask"
],
# (N, 14)
"clashes_per_atom_num_clash"
:
between_residue_clashes
[
"per_atom_num_clash"
],
# (N, 14)
},
"within_residues"
:
{
"per_atom_loss_sum"
:
residue_violations
[
...
...
@@ -1228,6 +1294,9 @@ def find_structural_violations(
"per_atom_violations"
:
residue_violations
[
"per_atom_violations"
],
# (N, 14),
"per_atom_num_clash"
:
residue_violations
[
"per_atom_num_clash"
],
# (N, 14)
},
"total_per_residue_violations_mask"
:
per_residue_violations_mask
,
# (N)
}
...
...
@@ -1349,15 +1418,21 @@ def compute_violation_metrics_np(
def
violation_loss
(
violations
:
Dict
[
str
,
torch
.
Tensor
],
atom14_atom_exists
:
torch
.
Tensor
,
average_clashes
:
bool
=
False
,
eps
=
1e-6
,
**
kwargs
,
)
->
torch
.
Tensor
:
num_atoms
=
torch
.
sum
(
atom14_atom_exists
)
l_clash
=
torch
.
sum
(
violations
[
"between_residues"
][
"clashes_per_atom_loss_sum"
]
+
violations
[
"within_residues"
][
"per_atom_loss_sum"
]
)
l_clash
=
l_clash
/
(
eps
+
num_atoms
)
per_atom_clash
=
(
violations
[
"between_residues"
][
"clashes_per_atom_loss_sum"
]
+
violations
[
"within_residues"
][
"per_atom_loss_sum"
])
if
average_clashes
:
num_clash
=
(
violations
[
"between_residues"
][
"clashes_per_atom_num_clash"
]
+
violations
[
"within_residues"
][
"per_atom_num_clash"
])
per_atom_clash
=
per_atom_clash
/
(
num_clash
+
eps
)
l_clash
=
torch
.
sum
(
per_atom_clash
)
/
(
eps
+
num_atoms
)
loss
=
(
violations
[
"between_residues"
][
"bonds_c_n_loss_mean"
]
+
violations
[
"between_residues"
][
"angles_ca_c_n_loss_mean"
]
...
...
@@ -1533,6 +1608,64 @@ def masked_msa_loss(logits, true_msa, bert_mask, num_classes, eps=1e-8, **kwargs
return
loss
def
chain_center_of_mass_loss
(
all_atom_pred_pos
:
torch
.
Tensor
,
all_atom_positions
:
torch
.
Tensor
,
all_atom_mask
:
torch
.
Tensor
,
asym_id
:
torch
.
Tensor
,
clamp_distance
:
float
=
-
4.0
,
weight
:
float
=
0.05
,
eps
:
float
=
1e-10
)
->
torch
.
Tensor
:
"""
Computes chain centre-of-mass loss. Implements section 2.5, eqn 1 in the Multimer paper.
Args:
all_atom_pred_pos:
[*, N_pts, 37, 3] All-atom predicted atom positions
all_atom_positions:
[*, N_pts, 37, 3] Ground truth all-atom positions
all_atom_mask:
[*, N_pts, 37] All-atom positions mask
asym_id:
[*, N_pts] Chain asym IDs
clamp_distance:
Cutoff above which distance errors are disregarded
weight:
Weight for loss
eps:
Small value used to regularize denominators
Returns:
[*] loss tensor
"""
ca_pos
=
residue_constants
.
atom_order
[
"CA"
]
all_atom_pred_pos
=
all_atom_pred_pos
[...,
ca_pos
,
:]
all_atom_positions
=
all_atom_positions
[...,
ca_pos
,
:]
all_atom_mask
=
all_atom_mask
[...,
ca_pos
:
(
ca_pos
+
1
)]
# keep dim
chains
,
_
=
asym_id
.
unique
(
return_counts
=
True
)
one_hot
=
torch
.
nn
.
functional
.
one_hot
(
asym_id
,
num_classes
=
chains
.
shape
[
0
]).
to
(
dtype
=
all_atom_mask
.
dtype
)
one_hot
=
one_hot
*
all_atom_mask
chain_pos_mask
=
one_hot
.
transpose
(
-
2
,
-
1
)
chain_exists
=
torch
.
any
(
chain_pos_mask
,
dim
=-
1
).
float
()
def
get_chain_center_of_mass
(
pos
):
center_sum
=
(
chain_pos_mask
[...,
None
]
*
pos
[...,
None
,
:,
:]).
sum
(
dim
=-
2
)
centers
=
center_sum
/
(
torch
.
sum
(
chain_pos_mask
,
dim
=-
1
,
keepdim
=
True
)
+
eps
)
return
Vec3Array
.
from_array
(
centers
)
pred_centers
=
get_chain_center_of_mass
(
all_atom_pred_pos
)
# [B, NC, 3]
true_centers
=
get_chain_center_of_mass
(
all_atom_positions
)
# [B, NC, 3]
pred_dists
=
euclidean_distance
(
pred_centers
[...,
None
,
:],
pred_centers
[...,
:,
None
],
epsilon
=
eps
)
true_dists
=
euclidean_distance
(
true_centers
[...,
None
,
:],
true_centers
[...,
:,
None
],
epsilon
=
eps
)
losses
=
torch
.
clamp
((
weight
*
(
pred_dists
-
true_dists
-
clamp_distance
)),
max
=
0
)
**
2
loss_mask
=
chain_exists
[...,
:,
None
]
*
chain_exists
[...,
None
,
:]
loss
=
masked_mean
(
loss_mask
,
losses
,
dim
=
(
-
1
,
-
2
))
return
loss
class
AlphaFoldLoss
(
nn
.
Module
):
"""Aggregation of the various losses described in the supplement"""
def
__init__
(
self
,
config
):
...
...
@@ -1585,7 +1718,7 @@ class AlphaFoldLoss(nn.Module):
),
"violation"
:
lambda
:
violation_loss
(
out
[
"violation"
],
**
batch
,
**
{
**
batch
,
**
self
.
config
.
violation
},
),
}
...
...
@@ -1595,6 +1728,12 @@ class AlphaFoldLoss(nn.Module):
**
{
**
batch
,
**
out
,
**
self
.
config
.
tm
},
)
if
(
self
.
config
.
chain_center_of_mass
.
enabled
):
loss_fns
[
"chain_center_of_mass"
]
=
lambda
:
chain_center_of_mass_loss
(
all_atom_pred_pos
=
out
[
"final_atom_positions"
],
**
{
**
batch
,
**
self
.
config
.
chain_center_of_mass
},
)
cum_loss
=
0.
losses
=
{}
for
loss_name
,
loss_fn
in
loss_fns
.
items
():
...
...
tests/config.py
View file @
51556d52
...
...
@@ -6,7 +6,7 @@ consts = mlc.ConfigDict(
"is_multimer"
:
True
,
# monomer: False, multimer: True
"chunk_size"
:
4
,
"batch_size"
:
2
,
"n_res"
:
11
,
"n_res"
:
22
,
"n_seq"
:
13
,
"n_templ"
:
3
,
"n_extra"
:
17
,
...
...
tests/data_utils.py
View file @
51556d52
...
...
@@ -29,14 +29,16 @@ def random_asym_ids(n_res, split_chains=True, min_chain_len=4):
pieces
=
[]
asym_ids
=
[]
final_idx
=
n_chain
-
1
for
idx
in
range
(
n_chain
-
1
):
n_stop
=
(
n_res
-
sum
(
pieces
)
-
n_chain
+
idx
-
min_chain_len
)
if
n_stop
<=
min_chain_len
:
final_idx
=
idx
break
piece
=
randint
(
min_chain_len
,
n_stop
)
pieces
.
append
(
piece
)
asym_ids
.
extend
(
piece
*
[
idx
])
asym_ids
.
extend
((
n_res
-
sum
(
pieces
))
*
[
n_chain
-
1
])
asym_ids
.
extend
((
n_res
-
sum
(
pieces
))
*
[
final_idx
])
return
np
.
array
(
asym_ids
).
astype
(
np
.
int64
)
...
...
tests/test_loss.py
View file @
51556d52
...
...
@@ -20,6 +20,7 @@ import unittest
import
ml_collections
as
mlc
from
openfold.data
import
data_transforms
from
openfold.np
import
residue_constants
from
openfold.utils.rigid_utils
import
(
Rotation
,
Rigid
,
...
...
@@ -42,6 +43,8 @@ from openfold.utils.loss import (
sidechain_loss
,
tm_loss
,
compute_plddt
,
compute_tm
,
chain_center_of_mass_loss
)
from
openfold.utils.tensor_utils
import
(
tree_map
,
...
...
@@ -233,11 +236,24 @@ class TestLoss(unittest.TestCase):
pred_pos
=
np
.
random
.
rand
(
n_res
,
14
,
3
).
astype
(
np
.
float32
)
atom_exists
=
np
.
random
.
randint
(
0
,
2
,
(
n_res
,
14
)).
astype
(
np
.
float32
)
atom_radius
=
np
.
random
.
rand
(
n_res
,
14
).
astype
(
np
.
float32
)
res_ind
=
np
.
arange
(
n_res
,
)
asym_id
=
random_asym_ids
(
n_res
)
residx_atom14_to_atom37
=
np
.
random
.
randint
(
0
,
37
,
(
n_res
,
14
)).
astype
(
np
.
int64
)
atomtype_radius
=
[
residue_constants
.
van_der_waals_radius
[
name
[
0
]]
for
name
in
residue_constants
.
atom_types
]
atomtype_radius
=
np
.
array
(
atomtype_radius
).
astype
(
np
.
float32
)
atom_radius
=
(
atom_exists
*
atomtype_radius
[
residx_atom14_to_atom37
]
)
asym_id
=
None
if
consts
.
is_multimer
:
asym_id
=
random_asym_ids
(
n_res
)
out_gt
=
f
.
apply
(
{},
...
...
@@ -256,6 +272,7 @@ class TestLoss(unittest.TestCase):
torch
.
tensor
(
atom_exists
).
cuda
(),
torch
.
tensor
(
atom_radius
).
cuda
(),
torch
.
tensor
(
res_ind
).
cuda
(),
torch
.
tensor
(
asym_id
).
cuda
()
if
asym_id
is
not
None
else
None
,
)
out_repro
=
tensor_tree_map
(
lambda
x
:
x
.
cpu
(),
out_repro
)
...
...
@@ -279,6 +296,36 @@ class TestLoss(unittest.TestCase):
torch
.
max
(
torch
.
abs
(
out_gt
-
out_repro
))
<
consts
.
eps
)
@
compare_utils
.
skip_unless_alphafold_installed
()
def
test_compute_ptm_compare
(
self
):
n_res
=
consts
.
n_res
max_bin
=
31
no_bins
=
64
logits
=
np
.
random
.
rand
(
n_res
,
n_res
,
no_bins
)
boundaries
=
np
.
linspace
(
0
,
max_bin
,
num
=
(
no_bins
-
1
))
ptm_gt
=
alphafold
.
common
.
confidence
.
predicted_tm_score
(
logits
,
boundaries
)
ptm_gt
=
torch
.
tensor
(
ptm_gt
)
logits_t
=
torch
.
tensor
(
logits
)
ptm_repro
=
compute_tm
(
logits_t
,
no_bins
=
no_bins
,
max_bin
=
max_bin
)
self
.
assertTrue
(
torch
.
max
(
torch
.
abs
(
ptm_gt
-
ptm_repro
))
<
consts
.
eps
)
if
consts
.
is_multimer
:
asym_id
=
random_asym_ids
(
n_res
)
iptm_gt
=
alphafold
.
common
.
confidence
.
predicted_tm_score
(
logits
,
boundaries
,
asym_id
=
asym_id
,
interface
=
True
)
iptm_gt
=
torch
.
tensor
(
iptm_gt
)
iptm_repro
=
compute_tm
(
logits_t
,
no_bins
=
no_bins
,
max_bin
=
max_bin
,
asym_id
=
torch
.
tensor
(
asym_id
),
interface
=
True
)
self
.
assertTrue
(
torch
.
max
(
torch
.
abs
(
iptm_gt
-
iptm_repro
))
<
consts
.
eps
)
def
test_find_structural_violations
(
self
):
n
=
consts
.
n_res
...
...
@@ -335,9 +382,11 @@ class TestLoss(unittest.TestCase):
"residx_atom14_to_atom37"
:
np
.
random
.
randint
(
0
,
37
,
(
n_res
,
14
)
).
astype
(
np
.
int64
),
"asym_id"
:
random_asym_ids
(
n_res
)
}
if
consts
.
is_multimer
:
batch
[
"asym_id"
]
=
random_asym_ids
(
n_res
)
pred_pos
=
np
.
random
.
rand
(
n_res
,
14
,
3
)
config
=
mlc
.
ConfigDict
(
...
...
@@ -632,6 +681,40 @@ class TestLoss(unittest.TestCase):
self
.
assertTrue
(
torch
.
max
(
torch
.
abs
(
out_gt
-
out_repro
))
<
consts
.
eps
)
@
compare_utils
.
skip_unless_alphafold_installed
()
def
test_violation_loss
(
self
):
config
=
compare_utils
.
get_alphafold_config
()
c_viol
=
config
.
model
.
heads
.
structure_module
n_res
=
consts
.
n_res
batch
=
{
"seq_mask"
:
np
.
random
.
randint
(
0
,
2
,
(
n_res
,)).
astype
(
np
.
float32
),
"residue_index"
:
np
.
arange
(
n_res
),
"aatype"
:
np
.
random
.
randint
(
0
,
21
,
(
n_res
,)),
}
if
consts
.
is_multimer
:
batch
[
"asym_id"
]
=
random_asym_ids
(
n_res
)
batch
=
tree_map
(
lambda
n
:
torch
.
tensor
(
n
).
cuda
(),
batch
,
np
.
ndarray
)
atom14_pred_pos
=
np
.
random
.
rand
(
n_res
,
14
,
3
).
astype
(
np
.
float32
)
atom14_pred_pos
=
torch
.
tensor
(
atom14_pred_pos
).
cuda
()
batch
=
data_transforms
.
make_atom14_masks
(
batch
)
loss_sum_clash
=
violation_loss
(
find_structural_violations
(
batch
,
atom14_pred_pos
,
**
c_viol
),
average_clashes
=
False
,
**
batch
)
loss_sum_clash
=
loss_sum_clash
.
cpu
()
loss_avg_clash
=
violation_loss
(
find_structural_violations
(
batch
,
atom14_pred_pos
,
**
c_viol
),
average_clashes
=
True
,
**
batch
)
loss_avg_clash
=
loss_avg_clash
.
cpu
()
@
compare_utils
.
skip_unless_alphafold_installed
()
def
test_violation_loss_compare
(
self
):
config
=
compare_utils
.
get_alphafold_config
()
...
...
@@ -680,10 +763,12 @@ class TestLoss(unittest.TestCase):
batch
=
{
"seq_mask"
:
np
.
random
.
randint
(
0
,
2
,
(
n_res
,)).
astype
(
np
.
float32
),
"residue_index"
:
np
.
arange
(
n_res
),
"aatype"
:
np
.
random
.
randint
(
0
,
21
,
(
n_res
,)),
"asym_id"
:
random_asym_ids
(
n_res
)
"aatype"
:
np
.
random
.
randint
(
0
,
21
,
(
n_res
,))
}
if
consts
.
is_multimer
:
batch
[
"asym_id"
]
=
random_asym_ids
(
n_res
)
atom14_pred_pos
=
np
.
random
.
rand
(
n_res
,
14
,
3
).
astype
(
np
.
float32
)
alphafold
.
model
.
tf
.
data_transforms
.
make_atom14_masks
(
batch
)
...
...
@@ -801,8 +886,7 @@ class TestLoss(unittest.TestCase):
"backbone_affine_mask"
:
np
.
random
.
randint
(
0
,
2
,
(
n_res
,)).
astype
(
np
.
float32
),
"use_clamped_fape"
:
np
.
array
(
0.0
),
"asym_id"
:
random_asym_ids
(
n_res
)
"use_clamped_fape"
:
np
.
array
(
0.0
)
}
value
=
{
...
...
@@ -814,6 +898,9 @@ class TestLoss(unittest.TestCase):
),
}
if
consts
.
is_multimer
:
batch
[
"asym_id"
]
=
random_asym_ids
(
n_res
)
out_gt
=
f
.
apply
({},
None
,
batch
,
value
)
out_gt
=
torch
.
tensor
(
np
.
array
(
out_gt
.
block_until_ready
()))
...
...
@@ -826,8 +913,18 @@ class TestLoss(unittest.TestCase):
)
batch
[
"backbone_rigid_mask"
]
=
batch
[
"backbone_affine_mask"
]
out_repro
=
backbone_loss
(
traj
=
value
[
"traj"
],
**
{
**
batch
,
**
c_sm
})
out_repro
=
out_repro
.
cpu
()
if
consts
.
is_multimer
:
intra_chain_mask
=
(
batch
[
"asym_id"
][...,
None
]
==
batch
[
"asym_id"
][...,
None
,
:]).
to
(
dtype
=
value
[
"traj"
].
dtype
)
intra_chain_out
=
backbone_loss
(
traj
=
value
[
"traj"
],
pair_mask
=
intra_chain_mask
,
**
{
**
batch
,
**
c_sm
.
intra_chain_fape
})
interface_out
=
backbone_loss
(
traj
=
value
[
"traj"
],
pair_mask
=
1.
-
intra_chain_mask
,
**
{
**
batch
,
**
c_sm
.
interface_fape
})
out_repro
=
intra_chain_out
+
interface_out
out_repro
=
out_repro
.
cpu
()
else
:
out_repro
=
backbone_loss
(
traj
=
value
[
"traj"
],
**
{
**
batch
,
**
c_sm
})
out_repro
=
out_repro
.
cpu
()
self
.
assertTrue
(
torch
.
max
(
torch
.
abs
(
out_gt
-
out_repro
))
<
consts
.
eps
)
...
...
@@ -869,14 +966,14 @@ class TestLoss(unittest.TestCase):
v
[
"sidechains"
]
=
{}
v
[
"sidechains"
][
"frames"
]
=
alphafold
.
model
.
r3
.
rigids_from_tensor4x4
(
]
=
self
.
am_rigid
.
rigids_from_tensor4x4
(
value
[
"sidechains"
][
"frames"
]
)
v
[
"sidechains"
][
"atom_pos"
]
=
alphafold
.
model
.
r3
.
vecs_from_tensor
(
v
[
"sidechains"
][
"atom_pos"
]
=
self
.
am_rigid
.
vecs_from_tensor
(
value
[
"sidechains"
][
"atom_pos"
]
)
v
.
update
(
alphafold
.
model
.
fold
ing
.
compute_renamed_ground_truth
(
self
.
am_
fold
.
compute_renamed_ground_truth
(
batch
,
atom14_pred_positions
,
)
...
...
@@ -907,9 +1004,6 @@ class TestLoss(unittest.TestCase):
),
}
if
consts
.
is_multimer
:
batch
[
"asym_id"
]
=
random_asym_ids
(
n_res
)
def
_build_extra_feats_np
():
b
=
tree_map
(
lambda
n
:
torch
.
tensor
(
n
),
batch
,
np
.
ndarray
)
b
=
data_transforms
.
make_atom14_masks
(
b
)
...
...
@@ -950,7 +1044,7 @@ class TestLoss(unittest.TestCase):
self
.
assertTrue
(
torch
.
max
(
torch
.
abs
(
out_gt
-
out_repro
))
<
consts
.
eps
)
@
compare_utils
.
skip_unless_alphafold_installed
()
@
unittest
.
skipIf
(
consts
.
is_multimer
or
"ptm"
not
in
consts
.
model
,
"Not enabled for non-ptm models."
)
@
unittest
.
skipIf
(
not
consts
.
is_multimer
and
"ptm"
not
in
consts
.
model
,
"Not enabled for non-ptm models."
)
def
test_tm_loss_compare
(
self
):
config
=
compare_utils
.
get_alphafold_config
()
c_tm
=
config
.
model
.
heads
.
predicted_aligned_error
...
...
@@ -1017,6 +1111,33 @@ class TestLoss(unittest.TestCase):
self
.
assertTrue
(
torch
.
max
(
torch
.
abs
(
out_gt
-
out_repro
))
<
consts
.
eps
)
@
compare_utils
.
skip_unless_alphafold_installed
()
def
test_chain_center_of_mass_loss
(
self
):
batch_size
=
consts
.
batch_size
n_res
=
consts
.
n_res
batch
=
{
"all_atom_positions"
:
np
.
random
.
rand
(
batch_size
,
n_res
,
37
,
3
).
astype
(
np
.
float32
)
*
10.0
,
"all_atom_mask"
:
np
.
random
.
randint
(
0
,
2
,
(
batch_size
,
n_res
,
37
)).
astype
(
np
.
float32
),
"asym_id"
:
np
.
stack
([
random_asym_ids
(
n_res
)
for
_
in
range
(
batch_size
)])
}
config
=
{
"weight"
:
0.05
,
"clamp_distance"
:
-
4.0
,
}
final_atom_positions
=
torch
.
rand
(
batch_size
,
n_res
,
37
,
3
).
cuda
()
to_tensor
=
lambda
t
:
torch
.
tensor
(
t
).
cuda
()
batch
=
tree_map
(
to_tensor
,
batch
,
np
.
ndarray
)
out_repro
=
chain_center_of_mass_loss
(
all_atom_pred_pos
=
final_atom_positions
,
**
{
**
batch
,
**
config
},
)
out_repro
=
out_repro
.
cpu
()
if
__name__
==
"__main__"
:
unittest
.
main
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment