Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
51556d52
"...git@developer.sourcefind.cn:OpenDAS/openfold.git" did not exist on "e7eadc440d3745b7f7cf1ca8565249472d016efd"
Commit
51556d52
authored
Jun 14, 2023
by
Christina Floristean
Browse files
Added multimer changes for loss functions
parent
fbfbd808
Changes
10
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
424 additions
and
63 deletions
+424
-63
openfold/config.py
openfold/config.py
+86
-1
openfold/model/heads.py
openfold/model/heads.py
+9
-1
openfold/model/model.py
openfold/model/model.py
+3
-0
openfold/model/triangular_multiplicative_update.py
openfold/model/triangular_multiplicative_update.py
+6
-3
openfold/utils/geometry/vector.py
openfold/utils/geometry/vector.py
+1
-1
openfold/utils/import_weights.py
openfold/utils/import_weights.py
+1
-1
openfold/utils/loss.py
openfold/utils/loss.py
+177
-38
tests/config.py
tests/config.py
+1
-1
tests/data_utils.py
tests/data_utils.py
+3
-1
tests/test_loss.py
tests/test_loss.py
+137
-16
No files found.
openfold/config.py
View file @
51556d52
...
@@ -648,6 +648,7 @@ config = mlc.ConfigDict(
...
@@ -648,6 +648,7 @@ config = mlc.ConfigDict(
"violation"
:
{
"violation"
:
{
"violation_tolerance_factor"
:
12.0
,
"violation_tolerance_factor"
:
12.0
,
"clash_overlap_tolerance"
:
1.5
,
"clash_overlap_tolerance"
:
1.5
,
"average_clashes"
:
False
,
"eps"
:
eps
,
# 1e-6,
"eps"
:
eps
,
# 1e-6,
"weight"
:
0.0
,
"weight"
:
0.0
,
},
},
...
@@ -660,6 +661,12 @@ config = mlc.ConfigDict(
...
@@ -660,6 +661,12 @@ config = mlc.ConfigDict(
"weight"
:
0.
,
"weight"
:
0.
,
"enabled"
:
tm_enabled
,
"enabled"
:
tm_enabled
,
},
},
"chain_center_of_mass"
:
{
"clamp_distance"
:
-
4.0
,
"weight"
:
0.
,
"eps"
:
eps
,
"enabled"
:
False
,
},
"eps"
:
eps
,
"eps"
:
eps
,
},
},
"ema"
:
{
"decay"
:
0.999
},
"ema"
:
{
"decay"
:
0.999
},
...
@@ -802,7 +809,9 @@ multimer_model_config_update = {
...
@@ -802,7 +809,9 @@ multimer_model_config_update = {
"tm"
:
{
"tm"
:
{
"c_z"
:
c_z
,
"c_z"
:
c_z
,
"no_bins"
:
aux_distogram_bins
,
"no_bins"
:
aux_distogram_bins
,
"enabled"
:
tm_enabled
,
"ptm_weight"
:
0.2
,
"iptm_weight"
:
0.8
,
"enabled"
:
True
,
},
},
"masked_msa"
:
{
"masked_msa"
:
{
"c_m"
:
c_m
,
"c_m"
:
c_m
,
...
@@ -813,5 +822,81 @@ multimer_model_config_update = {
...
@@ -813,5 +822,81 @@ multimer_model_config_update = {
"c_out"
:
37
,
"c_out"
:
37
,
},
},
},
},
"loss"
:
{
"distogram"
:
{
"min_bin"
:
2.3125
,
"max_bin"
:
21.6875
,
"no_bins"
:
64
,
"eps"
:
eps
,
# 1e-6,
"weight"
:
0.3
,
},
"experimentally_resolved"
:
{
"eps"
:
eps
,
# 1e-8,
"min_resolution"
:
0.1
,
"max_resolution"
:
3.0
,
"weight"
:
0.0
,
},
"fape"
:
{
"intra_chain_backbone"
:
{
"clamp_distance"
:
10.0
,
"loss_unit_distance"
:
10.0
,
"weight"
:
0.5
,
},
"interface"
:
{
"clamp_distance"
:
30.0
,
"loss_unit_distance"
:
20.0
,
"weight"
:
0.5
,
},
"sidechain"
:
{
"clamp_distance"
:
10.0
,
"length_scale"
:
10.0
,
"weight"
:
0.5
,
},
"eps"
:
1e-4
,
"weight"
:
1.0
,
},
"plddt_loss"
:
{
"min_resolution"
:
0.1
,
"max_resolution"
:
3.0
,
"cutoff"
:
15.0
,
"no_bins"
:
50
,
"eps"
:
eps
,
# 1e-10,
"weight"
:
0.01
,
},
"masked_msa"
:
{
"num_classes"
:
23
,
"eps"
:
eps
,
# 1e-8,
"weight"
:
2.0
,
},
"supervised_chi"
:
{
"chi_weight"
:
0.5
,
"angle_norm_weight"
:
0.01
,
"eps"
:
eps
,
# 1e-6,
"weight"
:
1.0
,
},
"violation"
:
{
"violation_tolerance_factor"
:
12.0
,
"clash_overlap_tolerance"
:
1.5
,
"average_clashes"
:
True
,
"eps"
:
eps
,
# 1e-6,
"weight"
:
0.03
,
# Not finetuning
},
"tm"
:
{
"max_bin"
:
31
,
"no_bins"
:
64
,
"min_resolution"
:
0.1
,
"max_resolution"
:
3.0
,
"eps"
:
eps
,
# 1e-8,
"weight"
:
0.1
,
"enabled"
:
True
,
},
"chain_center_of_mass"
:
{
"clamp_distance"
:
-
4.0
,
"weight"
:
0.05
,
"eps"
:
eps
,
"enabled"
:
True
,
},
"eps"
:
eps
,
},
"recycle_early_stop_tolerance"
:
0.5
"recycle_early_stop_tolerance"
:
0.5
}
}
openfold/model/heads.py
View file @
51556d52
...
@@ -76,9 +76,17 @@ class AuxiliaryHeads(nn.Module):
...
@@ -76,9 +76,17 @@ class AuxiliaryHeads(nn.Module):
if
self
.
config
.
tm
.
enabled
:
if
self
.
config
.
tm
.
enabled
:
tm_logits
=
self
.
tm
(
outputs
[
"pair"
])
tm_logits
=
self
.
tm
(
outputs
[
"pair"
])
aux_out
[
"tm_logits"
]
=
tm_logits
aux_out
[
"tm_logits"
]
=
tm_logits
aux_out
[
"p
redicted_
tm_score"
]
=
compute_tm
(
aux_out
[
"ptm_score"
]
=
compute_tm
(
tm_logits
,
**
self
.
config
.
tm
tm_logits
,
**
self
.
config
.
tm
)
)
asym_id
=
outputs
.
get
(
"asym_id"
)
if
asym_id
is
not
None
:
aux_out
[
"iptm_score"
]
=
compute_tm
(
tm_logits
,
asym_id
=
asym_id
,
interface
=
True
,
**
self
.
config
.
tm
)
aux_out
[
"weighted_ptm_score"
]
=
(
self
.
config
.
tm
[
"iptm_weight"
]
*
aux_out
[
"iptm_score"
]
+
self
.
config
.
tm
[
"ptm_weight"
]
*
aux_out
[
"ptm_score"
])
aux_out
.
update
(
aux_out
.
update
(
compute_predicted_aligned_error
(
compute_predicted_aligned_error
(
tm_logits
,
tm_logits
,
...
...
openfold/model/model.py
View file @
51556d52
...
@@ -555,6 +555,9 @@ class AlphaFold(nn.Module):
...
@@ -555,6 +555,9 @@ class AlphaFold(nn.Module):
else
:
else
:
break
break
if
"asym_id"
in
batch
:
outputs
[
"asym_id"
]
=
feats
[
"asym_id"
]
# Run auxiliary heads
# Run auxiliary heads
outputs
.
update
(
self
.
aux_heads
(
outputs
))
outputs
.
update
(
self
.
aux_heads
(
outputs
))
...
...
openfold/model/triangular_multiplicative_update.py
View file @
51556d52
...
@@ -435,7 +435,7 @@ class TriangleMultiplicativeUpdate(BaseTriangleMultiplicativeUpdate):
...
@@ -435,7 +435,7 @@ class TriangleMultiplicativeUpdate(BaseTriangleMultiplicativeUpdate):
# reduced-precision modes
# reduced-precision modes
a_std
=
a
.
std
()
a_std
=
a
.
std
()
b_std
=
b
.
std
()
b_std
=
b
.
std
()
if
(
a_std
!=
0.
and
b_std
!=
0.
):
if
(
is_fp16_enabled
()
and
a_std
!=
0.
and
b_std
!=
0.
):
a
=
a
/
a
.
std
()
a
=
a
/
a
.
std
()
b
=
b
/
b
.
std
()
b
=
b
/
b
.
std
()
...
@@ -589,8 +589,11 @@ class FusedTriangleMultiplicativeUpdate(BaseTriangleMultiplicativeUpdate):
...
@@ -589,8 +589,11 @@ class FusedTriangleMultiplicativeUpdate(BaseTriangleMultiplicativeUpdate):
# Prevents overflow of torch.matmul in combine projections in
# Prevents overflow of torch.matmul in combine projections in
# reduced-precision modes
# reduced-precision modes
a
=
a
/
a
.
std
()
a_std
=
a
.
std
()
b
=
b
/
b
.
std
()
b_std
=
b
.
std
()
if
(
is_fp16_enabled
()
and
a_std
!=
0.
and
b_std
!=
0.
):
a
=
a
/
a
.
std
()
b
=
b
/
b
.
std
()
if
(
is_fp16_enabled
()):
if
(
is_fp16_enabled
()):
with
torch
.
cuda
.
amp
.
autocast
(
enabled
=
False
):
with
torch
.
cuda
.
amp
.
autocast
(
enabled
=
False
):
...
...
openfold/utils/geometry/vector.py
View file @
51556d52
...
@@ -193,7 +193,7 @@ def square_euclidean_distance(
...
@@ -193,7 +193,7 @@ def square_euclidean_distance(
difference
=
vec1
-
vec2
difference
=
vec1
-
vec2
distance
=
difference
.
dot
(
difference
)
distance
=
difference
.
dot
(
difference
)
if
epsilon
:
if
epsilon
:
distance
=
torch
.
maximum
(
distance
,
epsilon
)
distance
=
torch
.
clamp
(
distance
,
min
=
epsilon
)
return
distance
return
distance
...
...
openfold/utils/import_weights.py
View file @
51556d52
...
@@ -617,7 +617,7 @@ def generate_translation_dict(model, version, is_multimer=False):
...
@@ -617,7 +617,7 @@ def generate_translation_dict(model, version, is_multimer=False):
translations
[
"evoformer"
].
update
(
template_param_dict
)
translations
[
"evoformer"
].
update
(
template_param_dict
)
if
"_ptm"
in
version
:
if
is_multimer
or
"_ptm"
in
version
:
translations
[
"predicted_aligned_error_head"
]
=
{
translations
[
"predicted_aligned_error_head"
]
=
{
"logits"
:
LinearParams
(
model
.
aux_heads
.
tm
.
linear
)
"logits"
:
LinearParams
(
model
.
aux_heads
.
tm
.
linear
)
}
}
...
...
openfold/utils/loss.py
View file @
51556d52
...
@@ -25,6 +25,8 @@ from typing import Dict, Optional, Tuple
...
@@ -25,6 +25,8 @@ from typing import Dict, Optional, Tuple
from
openfold.np
import
residue_constants
from
openfold.np
import
residue_constants
from
openfold.utils
import
feats
from
openfold.utils
import
feats
from
openfold.utils.rigid_utils
import
Rotation
,
Rigid
from
openfold.utils.rigid_utils
import
Rotation
,
Rigid
from
openfold.utils.geometry.vector
import
Vec3Array
,
euclidean_distance
from
openfold.utils.all_atom_multimer
import
get_rc_tensor
from
openfold.utils.tensor_utils
import
(
from
openfold.utils.tensor_utils
import
(
tree_map
,
tree_map
,
tensor_tree_map
,
tensor_tree_map
,
...
@@ -87,6 +89,7 @@ def compute_fape(
...
@@ -87,6 +89,7 @@ def compute_fape(
target_positions
:
torch
.
Tensor
,
target_positions
:
torch
.
Tensor
,
positions_mask
:
torch
.
Tensor
,
positions_mask
:
torch
.
Tensor
,
length_scale
:
float
,
length_scale
:
float
,
pair_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
l1_clamp_distance
:
Optional
[
float
]
=
None
,
l1_clamp_distance
:
Optional
[
float
]
=
None
,
eps
=
1e-8
,
eps
=
1e-8
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
...
@@ -108,6 +111,9 @@ def compute_fape(
...
@@ -108,6 +111,9 @@ def compute_fape(
[*, N_pts] positions mask
[*, N_pts] positions mask
length_scale:
length_scale:
Length scale by which the loss is divided
Length scale by which the loss is divided
pair_mask:
[*, N_frames, N_pts] mask to use for
separating intra- from inter-chain losses.
l1_clamp_distance:
l1_clamp_distance:
Cutoff above which distance errors are disregarded
Cutoff above which distance errors are disregarded
eps:
eps:
...
@@ -134,21 +140,30 @@ def compute_fape(
...
@@ -134,21 +140,30 @@ def compute_fape(
normed_error
=
normed_error
*
frames_mask
[...,
None
]
normed_error
=
normed_error
*
frames_mask
[...,
None
]
normed_error
=
normed_error
*
positions_mask
[...,
None
,
:]
normed_error
=
normed_error
*
positions_mask
[...,
None
,
:]
# FP16-friendly averaging. Roughly equivalent to:
if
pair_mask
is
not
None
:
#
normed_error
=
normed_error
*
pair_mask
# norm_factor = (
normed_error
=
torch
.
sum
(
normed_error
,
dim
=
(
-
1
,
-
2
))
# torch.sum(frames_mask, dim=-1) *
# torch.sum(positions_mask, dim=-1)
mask
=
frames_mask
[...,
None
]
*
positions_mask
[...,
None
,
:]
*
pair_mask
# )
norm_factor
=
torch
.
sum
(
mask
,
dim
=
(
-
2
,
-
1
))
# normed_error = torch.sum(normed_error, dim=(-1, -2)) / (eps + norm_factor)
#
normed_error
=
normed_error
/
(
eps
+
norm_factor
)
# ("roughly" because eps is necessarily duplicated in the latter)
else
:
normed_error
=
torch
.
sum
(
normed_error
,
dim
=-
1
)
# FP16-friendly averaging. Roughly equivalent to:
normed_error
=
(
#
normed_error
/
(
eps
+
torch
.
sum
(
frames_mask
,
dim
=-
1
))[...,
None
]
# norm_factor = (
)
# torch.sum(frames_mask, dim=-1) *
normed_error
=
torch
.
sum
(
normed_error
,
dim
=-
1
)
# torch.sum(positions_mask, dim=-1)
normed_error
=
normed_error
/
(
eps
+
torch
.
sum
(
positions_mask
,
dim
=-
1
))
# )
# normed_error = torch.sum(normed_error, dim=(-1, -2)) / (eps + norm_factor)
#
# ("roughly" because eps is necessarily duplicated in the latter)
normed_error
=
torch
.
sum
(
normed_error
,
dim
=-
1
)
normed_error
=
(
normed_error
/
(
eps
+
torch
.
sum
(
frames_mask
,
dim
=-
1
))[...,
None
]
)
normed_error
=
torch
.
sum
(
normed_error
,
dim
=-
1
)
normed_error
=
normed_error
/
(
eps
+
torch
.
sum
(
positions_mask
,
dim
=-
1
))
return
normed_error
return
normed_error
...
@@ -157,6 +172,7 @@ def backbone_loss(
...
@@ -157,6 +172,7 @@ def backbone_loss(
backbone_rigid_tensor
:
torch
.
Tensor
,
backbone_rigid_tensor
:
torch
.
Tensor
,
backbone_rigid_mask
:
torch
.
Tensor
,
backbone_rigid_mask
:
torch
.
Tensor
,
traj
:
torch
.
Tensor
,
traj
:
torch
.
Tensor
,
pair_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
use_clamped_fape
:
Optional
[
torch
.
Tensor
]
=
None
,
use_clamped_fape
:
Optional
[
torch
.
Tensor
]
=
None
,
clamp_distance
:
float
=
10.0
,
clamp_distance
:
float
=
10.0
,
loss_unit_distance
:
float
=
10.0
,
loss_unit_distance
:
float
=
10.0
,
...
@@ -184,6 +200,7 @@ def backbone_loss(
...
@@ -184,6 +200,7 @@ def backbone_loss(
pred_aff
.
get_trans
(),
pred_aff
.
get_trans
(),
gt_aff
[
None
].
get_trans
(),
gt_aff
[
None
].
get_trans
(),
backbone_rigid_mask
[
None
],
backbone_rigid_mask
[
None
],
pair_mask
=
pair_mask
,
l1_clamp_distance
=
clamp_distance
,
l1_clamp_distance
=
clamp_distance
,
length_scale
=
loss_unit_distance
,
length_scale
=
loss_unit_distance
,
eps
=
eps
,
eps
=
eps
,
...
@@ -196,6 +213,7 @@ def backbone_loss(
...
@@ -196,6 +213,7 @@ def backbone_loss(
pred_aff
.
get_trans
(),
pred_aff
.
get_trans
(),
gt_aff
[
None
].
get_trans
(),
gt_aff
[
None
].
get_trans
(),
backbone_rigid_mask
[
None
],
backbone_rigid_mask
[
None
],
pair_mask
=
pair_mask
,
l1_clamp_distance
=
None
,
l1_clamp_distance
=
None
,
length_scale
=
loss_unit_distance
,
length_scale
=
loss_unit_distance
,
eps
=
eps
,
eps
=
eps
,
...
@@ -253,6 +271,7 @@ def sidechain_loss(
...
@@ -253,6 +271,7 @@ def sidechain_loss(
sidechain_atom_pos
,
sidechain_atom_pos
,
renamed_atom14_gt_positions
,
renamed_atom14_gt_positions
,
renamed_atom14_gt_exists
,
renamed_atom14_gt_exists
,
pair_mask
=
None
,
l1_clamp_distance
=
clamp_distance
,
l1_clamp_distance
=
clamp_distance
,
length_scale
=
length_scale
,
length_scale
=
length_scale
,
eps
=
eps
,
eps
=
eps
,
...
@@ -266,10 +285,29 @@ def fape_loss(
...
@@ -266,10 +285,29 @@ def fape_loss(
batch
:
Dict
[
str
,
torch
.
Tensor
],
batch
:
Dict
[
str
,
torch
.
Tensor
],
config
:
ml_collections
.
ConfigDict
,
config
:
ml_collections
.
ConfigDict
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
bb_loss
=
backbone_loss
(
traj
=
out
[
"sm"
][
"frames"
],
traj
=
out
[
"sm"
][
"frames"
]
**
{
**
batch
,
**
config
.
backbone
},
asym_id
=
batch
.
get
(
"asym_id"
)
)
if
asym_id
is
not
None
:
intra_chain_mask
=
(
asym_id
[...,
None
]
==
asym_id
[...,
None
,
:]).
to
(
dtype
=
traj
.
dtype
)
intra_chain_bb_loss
=
backbone_loss
(
traj
=
traj
,
pair_mask
=
intra_chain_mask
,
**
{
**
batch
,
**
config
.
intra_chain_backbone
},
)
interface_bb_loss
=
backbone_loss
(
traj
=
traj
,
pair_mask
=
1.
-
intra_chain_mask
,
**
{
**
batch
,
**
config
.
interface_backbone
},
)
weighted_bb_loss
=
(
intra_chain_bb_loss
*
config
.
intra_chain_backbone
.
weight
+
interface_bb_loss
*
config
.
interface_backbone
.
weight
)
else
:
bb_loss
=
backbone_loss
(
traj
=
traj
,
**
{
**
batch
,
**
config
.
backbone
},
)
weighted_bb_loss
=
bb_loss
*
config
.
backbone
.
weight
sc_loss
=
sidechain_loss
(
sc_loss
=
sidechain_loss
(
out
[
"sm"
][
"sidechain_frames"
],
out
[
"sm"
][
"sidechain_frames"
],
...
@@ -277,7 +315,7 @@ def fape_loss(
...
@@ -277,7 +315,7 @@ def fape_loss(
**
{
**
batch
,
**
config
.
sidechain
},
**
{
**
batch
,
**
config
.
sidechain
},
)
)
loss
=
config
.
backbone
.
weight
*
bb_loss
+
config
.
sidechain
.
weight
*
sc_loss
loss
=
weight
ed_
bb_loss
+
config
.
sidechain
.
weight
*
sc_loss
# Average over the batch dimension
# Average over the batch dimension
loss
=
torch
.
mean
(
loss
)
loss
=
torch
.
mean
(
loss
)
...
@@ -654,7 +692,7 @@ def compute_tm(
...
@@ -654,7 +692,7 @@ def compute_tm(
n
=
residue_weights
.
shape
[
-
1
]
n
=
residue_weights
.
shape
[
-
1
]
pair_mask
=
residue_weights
.
new_ones
((
n
,
n
),
dtype
=
torch
.
int32
)
pair_mask
=
residue_weights
.
new_ones
((
n
,
n
),
dtype
=
torch
.
int32
)
if
interface
:
if
interface
:
pair_mask
*=
(
asym_id
[...,
None
]
!=
asym_id
[...,
None
,
:])
pair_mask
*=
(
asym_id
[...,
None
]
!=
asym_id
[...,
None
,
:])
.
to
(
dtype
=
pair_mask
.
dtype
)
predicted_tm_term
*=
pair_mask
predicted_tm_term
*=
pair_mask
...
@@ -891,6 +929,7 @@ def between_residue_clash_loss(
...
@@ -891,6 +929,7 @@ def between_residue_clash_loss(
atom14_atom_exists
:
torch
.
Tensor
,
atom14_atom_exists
:
torch
.
Tensor
,
atom14_atom_radius
:
torch
.
Tensor
,
atom14_atom_radius
:
torch
.
Tensor
,
residue_index
:
torch
.
Tensor
,
residue_index
:
torch
.
Tensor
,
asym_id
:
Optional
[
torch
.
Tensor
]
=
None
,
overlap_tolerance_soft
=
1.5
,
overlap_tolerance_soft
=
1.5
,
overlap_tolerance_hard
=
1.5
,
overlap_tolerance_hard
=
1.5
,
eps
=
1e-10
,
eps
=
1e-10
,
...
@@ -966,9 +1005,13 @@ def between_residue_clash_loss(
...
@@ -966,9 +1005,13 @@ def between_residue_clash_loss(
)
)
n_one_hot
=
n_one_hot
.
type
(
fp_type
)
n_one_hot
=
n_one_hot
.
type
(
fp_type
)
neighbour_mask
=
(
neighbour_mask
=
(
residue_index
[...,
:,
None
]
+
1
)
==
residue_index
[...,
None
,
:]
residue_index
[...,
:,
None
,
None
,
None
]
+
1
)
==
residue_index
[...,
None
,
:,
None
,
None
]
if
asym_id
is
not
None
:
neighbour_mask
=
neighbour_mask
&
(
asym_id
[...,
:,
None
]
==
asym_id
[...,
None
,
:])
neighbour_mask
=
neighbour_mask
[...,
None
,
None
]
c_n_bonds
=
(
c_n_bonds
=
(
neighbour_mask
neighbour_mask
*
c_one_hot
[...,
None
,
None
,
:,
None
]
*
c_one_hot
[...,
None
,
None
,
:,
None
]
...
@@ -1010,26 +1053,29 @@ def between_residue_clash_loss(
...
@@ -1010,26 +1053,29 @@ def between_residue_clash_loss(
# Compute the per atom loss sum.
# Compute the per atom loss sum.
# shape (N, 14)
# shape (N, 14)
per_atom_loss_sum
=
torch
.
sum
(
dists_to_low_error
,
dim
=
(
-
4
,
-
2
))
+
torch
.
sum
(
per_atom_loss_sum
=
torch
.
sum
(
dists_to_low_error
,
dim
=
(
-
4
,
-
2
))
+
torch
.
sum
(
dists_to_low_error
,
axis
=
(
-
3
,
-
1
)
dists_to_low_error
,
dim
=
(
-
3
,
-
1
)
)
)
# Compute the hard clash mask.
# Compute the hard clash mask.
# shape (N, N, 14, 14)
# shape (N, N, 14, 14)
clash_mask
=
dists_mask
*
(
clash_mask
=
dists_mask
*
(
dists
<
(
dists_lower_bound
-
overlap_tolerance_hard
)
dists
<
(
dists_lower_bound
-
overlap_tolerance_hard
)
)
)
per_atom_num_clash
=
torch
.
sum
(
clash_mask
,
dim
=
(
-
4
,
-
2
))
+
torch
.
sum
(
clash_mask
,
dim
=
(
-
3
,
-
1
))
# Compute the per atom clash.
# Compute the per atom clash.
# shape (N, 14)
# shape (N, 14)
per_atom_clash_mask
=
torch
.
maximum
(
per_atom_clash_mask
=
torch
.
maximum
(
torch
.
amax
(
clash_mask
,
axis
=
(
-
4
,
-
2
)),
torch
.
amax
(
clash_mask
,
dim
=
(
-
4
,
-
2
)),
torch
.
amax
(
clash_mask
,
axis
=
(
-
3
,
-
1
)),
torch
.
amax
(
clash_mask
,
dim
=
(
-
3
,
-
1
)),
)
)
return
{
return
{
"mean_loss"
:
mean_loss
,
# shape ()
"mean_loss"
:
mean_loss
,
# shape ()
"per_atom_loss_sum"
:
per_atom_loss_sum
,
# shape (N, 14)
"per_atom_loss_sum"
:
per_atom_loss_sum
,
# shape (N, 14)
"per_atom_clash_mask"
:
per_atom_clash_mask
,
# shape (N, 14)
"per_atom_clash_mask"
:
per_atom_clash_mask
,
# shape (N, 14)
"per_atom_num_clash"
:
per_atom_num_clash
# shape (N, 14)
}
}
...
@@ -1109,6 +1155,8 @@ def within_residue_violations(
...
@@ -1109,6 +1155,8 @@ def within_residue_violations(
(
dists
<
atom14_dists_lower_bound
)
|
(
dists
>
atom14_dists_upper_bound
)
(
dists
<
atom14_dists_lower_bound
)
|
(
dists
>
atom14_dists_upper_bound
)
)
)
per_atom_num_clash
=
torch
.
sum
(
violations
,
dim
=-
2
)
+
torch
.
sum
(
violations
,
dim
=-
1
)
# Compute the per atom violations.
# Compute the per atom violations.
per_atom_violations
=
torch
.
maximum
(
per_atom_violations
=
torch
.
maximum
(
torch
.
max
(
violations
,
dim
=-
2
)[
0
],
torch
.
max
(
violations
,
axis
=-
1
)[
0
]
torch
.
max
(
violations
,
dim
=-
2
)[
0
],
torch
.
max
(
violations
,
axis
=-
1
)[
0
]
...
@@ -1117,6 +1165,7 @@ def within_residue_violations(
...
@@ -1117,6 +1165,7 @@ def within_residue_violations(
return
{
return
{
"per_atom_loss_sum"
:
per_atom_loss_sum
,
"per_atom_loss_sum"
:
per_atom_loss_sum
,
"per_atom_violations"
:
per_atom_violations
,
"per_atom_violations"
:
per_atom_violations
,
"per_atom_num_clash"
:
per_atom_num_clash
}
}
...
@@ -1146,11 +1195,24 @@ def find_structural_violations(
...
@@ -1146,11 +1195,24 @@ def find_structural_violations(
residue_constants
.
van_der_waals_radius
[
name
[
0
]]
residue_constants
.
van_der_waals_radius
[
name
[
0
]]
for
name
in
residue_constants
.
atom_types
for
name
in
residue_constants
.
atom_types
]
]
atomtype_radius
=
atom14_pred_positions
.
new_tensor
(
atomtype_radius
)
atomtype_radius
=
atom14_pred_positions
.
new_tensor
(
atomtype_radius
)
atom14_atom_radius
=
(
batch
[
"atom14_atom_exists"
]
#TODO: Consolidate monomer/multimer modes
*
atomtype_radius
[
batch
[
"residx_atom14_to_atom37"
]]
asym_id
=
batch
.
get
(
"asym_id"
)
)
if
asym_id
is
not
None
:
residx_atom14_to_atom37
=
get_rc_tensor
(
residue_constants
.
RESTYPE_ATOM14_TO_ATOM37
,
batch
[
"aatype"
]
)
atom14_atom_radius
=
(
batch
[
"atom14_atom_exists"
]
*
atomtype_radius
[
residx_atom14_to_atom37
.
long
()]
)
else
:
atom14_atom_radius
=
(
batch
[
"atom14_atom_exists"
]
*
atomtype_radius
[
batch
[
"residx_atom14_to_atom37"
]]
)
# Compute the between residue clash loss.
# Compute the between residue clash loss.
between_residue_clashes
=
between_residue_clash_loss
(
between_residue_clashes
=
between_residue_clash_loss
(
...
@@ -1158,6 +1220,7 @@ def find_structural_violations(
...
@@ -1158,6 +1220,7 @@ def find_structural_violations(
atom14_atom_exists
=
batch
[
"atom14_atom_exists"
],
atom14_atom_exists
=
batch
[
"atom14_atom_exists"
],
atom14_atom_radius
=
atom14_atom_radius
,
atom14_atom_radius
=
atom14_atom_radius
,
residue_index
=
batch
[
"residue_index"
],
residue_index
=
batch
[
"residue_index"
],
asym_id
=
asym_id
,
overlap_tolerance_soft
=
clash_overlap_tolerance
,
overlap_tolerance_soft
=
clash_overlap_tolerance
,
overlap_tolerance_hard
=
clash_overlap_tolerance
,
overlap_tolerance_hard
=
clash_overlap_tolerance
,
)
)
...
@@ -1220,6 +1283,9 @@ def find_structural_violations(
...
@@ -1220,6 +1283,9 @@ def find_structural_violations(
"clashes_per_atom_clash_mask"
:
between_residue_clashes
[
"clashes_per_atom_clash_mask"
:
between_residue_clashes
[
"per_atom_clash_mask"
"per_atom_clash_mask"
],
# (N, 14)
],
# (N, 14)
"clashes_per_atom_num_clash"
:
between_residue_clashes
[
"per_atom_num_clash"
],
# (N, 14)
},
},
"within_residues"
:
{
"within_residues"
:
{
"per_atom_loss_sum"
:
residue_violations
[
"per_atom_loss_sum"
:
residue_violations
[
...
@@ -1228,6 +1294,9 @@ def find_structural_violations(
...
@@ -1228,6 +1294,9 @@ def find_structural_violations(
"per_atom_violations"
:
residue_violations
[
"per_atom_violations"
:
residue_violations
[
"per_atom_violations"
"per_atom_violations"
],
# (N, 14),
],
# (N, 14),
"per_atom_num_clash"
:
residue_violations
[
"per_atom_num_clash"
],
# (N, 14)
},
},
"total_per_residue_violations_mask"
:
per_residue_violations_mask
,
# (N)
"total_per_residue_violations_mask"
:
per_residue_violations_mask
,
# (N)
}
}
...
@@ -1349,15 +1418,21 @@ def compute_violation_metrics_np(
...
@@ -1349,15 +1418,21 @@ def compute_violation_metrics_np(
def
violation_loss
(
def
violation_loss
(
violations
:
Dict
[
str
,
torch
.
Tensor
],
violations
:
Dict
[
str
,
torch
.
Tensor
],
atom14_atom_exists
:
torch
.
Tensor
,
atom14_atom_exists
:
torch
.
Tensor
,
average_clashes
:
bool
=
False
,
eps
=
1e-6
,
eps
=
1e-6
,
**
kwargs
,
**
kwargs
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
num_atoms
=
torch
.
sum
(
atom14_atom_exists
)
num_atoms
=
torch
.
sum
(
atom14_atom_exists
)
l_clash
=
torch
.
sum
(
violations
[
"between_residues"
][
"clashes_per_atom_loss_sum"
]
per_atom_clash
=
(
violations
[
"between_residues"
][
"clashes_per_atom_loss_sum"
]
+
+
violations
[
"within_residues"
][
"per_atom_loss_sum"
]
violations
[
"within_residues"
][
"per_atom_loss_sum"
])
)
l_clash
=
l_clash
/
(
eps
+
num_atoms
)
if
average_clashes
:
num_clash
=
(
violations
[
"between_residues"
][
"clashes_per_atom_num_clash"
]
+
violations
[
"within_residues"
][
"per_atom_num_clash"
])
per_atom_clash
=
per_atom_clash
/
(
num_clash
+
eps
)
l_clash
=
torch
.
sum
(
per_atom_clash
)
/
(
eps
+
num_atoms
)
loss
=
(
loss
=
(
violations
[
"between_residues"
][
"bonds_c_n_loss_mean"
]
violations
[
"between_residues"
][
"bonds_c_n_loss_mean"
]
+
violations
[
"between_residues"
][
"angles_ca_c_n_loss_mean"
]
+
violations
[
"between_residues"
][
"angles_ca_c_n_loss_mean"
]
...
@@ -1533,6 +1608,64 @@ def masked_msa_loss(logits, true_msa, bert_mask, num_classes, eps=1e-8, **kwargs
...
@@ -1533,6 +1608,64 @@ def masked_msa_loss(logits, true_msa, bert_mask, num_classes, eps=1e-8, **kwargs
return
loss
return
loss
def
chain_center_of_mass_loss
(
all_atom_pred_pos
:
torch
.
Tensor
,
all_atom_positions
:
torch
.
Tensor
,
all_atom_mask
:
torch
.
Tensor
,
asym_id
:
torch
.
Tensor
,
clamp_distance
:
float
=
-
4.0
,
weight
:
float
=
0.05
,
eps
:
float
=
1e-10
)
->
torch
.
Tensor
:
"""
Computes chain centre-of-mass loss. Implements section 2.5, eqn 1 in the Multimer paper.
Args:
all_atom_pred_pos:
[*, N_pts, 37, 3] All-atom predicted atom positions
all_atom_positions:
[*, N_pts, 37, 3] Ground truth all-atom positions
all_atom_mask:
[*, N_pts, 37] All-atom positions mask
asym_id:
[*, N_pts] Chain asym IDs
clamp_distance:
Cutoff above which distance errors are disregarded
weight:
Weight for loss
eps:
Small value used to regularize denominators
Returns:
[*] loss tensor
"""
ca_pos
=
residue_constants
.
atom_order
[
"CA"
]
all_atom_pred_pos
=
all_atom_pred_pos
[...,
ca_pos
,
:]
all_atom_positions
=
all_atom_positions
[...,
ca_pos
,
:]
all_atom_mask
=
all_atom_mask
[...,
ca_pos
:
(
ca_pos
+
1
)]
# keep dim
chains
,
_
=
asym_id
.
unique
(
return_counts
=
True
)
one_hot
=
torch
.
nn
.
functional
.
one_hot
(
asym_id
,
num_classes
=
chains
.
shape
[
0
]).
to
(
dtype
=
all_atom_mask
.
dtype
)
one_hot
=
one_hot
*
all_atom_mask
chain_pos_mask
=
one_hot
.
transpose
(
-
2
,
-
1
)
chain_exists
=
torch
.
any
(
chain_pos_mask
,
dim
=-
1
).
float
()
def
get_chain_center_of_mass
(
pos
):
center_sum
=
(
chain_pos_mask
[...,
None
]
*
pos
[...,
None
,
:,
:]).
sum
(
dim
=-
2
)
centers
=
center_sum
/
(
torch
.
sum
(
chain_pos_mask
,
dim
=-
1
,
keepdim
=
True
)
+
eps
)
return
Vec3Array
.
from_array
(
centers
)
pred_centers
=
get_chain_center_of_mass
(
all_atom_pred_pos
)
# [B, NC, 3]
true_centers
=
get_chain_center_of_mass
(
all_atom_positions
)
# [B, NC, 3]
pred_dists
=
euclidean_distance
(
pred_centers
[...,
None
,
:],
pred_centers
[...,
:,
None
],
epsilon
=
eps
)
true_dists
=
euclidean_distance
(
true_centers
[...,
None
,
:],
true_centers
[...,
:,
None
],
epsilon
=
eps
)
losses
=
torch
.
clamp
((
weight
*
(
pred_dists
-
true_dists
-
clamp_distance
)),
max
=
0
)
**
2
loss_mask
=
chain_exists
[...,
:,
None
]
*
chain_exists
[...,
None
,
:]
loss
=
masked_mean
(
loss_mask
,
losses
,
dim
=
(
-
1
,
-
2
))
return
loss
class
AlphaFoldLoss
(
nn
.
Module
):
class
AlphaFoldLoss
(
nn
.
Module
):
"""Aggregation of the various losses described in the supplement"""
"""Aggregation of the various losses described in the supplement"""
def
__init__
(
self
,
config
):
def
__init__
(
self
,
config
):
...
@@ -1585,7 +1718,7 @@ class AlphaFoldLoss(nn.Module):
...
@@ -1585,7 +1718,7 @@ class AlphaFoldLoss(nn.Module):
),
),
"violation"
:
lambda
:
violation_loss
(
"violation"
:
lambda
:
violation_loss
(
out
[
"violation"
],
out
[
"violation"
],
**
batch
,
**
{
**
batch
,
**
self
.
config
.
violation
},
),
),
}
}
...
@@ -1595,6 +1728,12 @@ class AlphaFoldLoss(nn.Module):
...
@@ -1595,6 +1728,12 @@ class AlphaFoldLoss(nn.Module):
**
{
**
batch
,
**
out
,
**
self
.
config
.
tm
},
**
{
**
batch
,
**
out
,
**
self
.
config
.
tm
},
)
)
if
(
self
.
config
.
chain_center_of_mass
.
enabled
):
loss_fns
[
"chain_center_of_mass"
]
=
lambda
:
chain_center_of_mass_loss
(
all_atom_pred_pos
=
out
[
"final_atom_positions"
],
**
{
**
batch
,
**
self
.
config
.
chain_center_of_mass
},
)
cum_loss
=
0.
cum_loss
=
0.
losses
=
{}
losses
=
{}
for
loss_name
,
loss_fn
in
loss_fns
.
items
():
for
loss_name
,
loss_fn
in
loss_fns
.
items
():
...
...
tests/config.py
View file @
51556d52
...
@@ -6,7 +6,7 @@ consts = mlc.ConfigDict(
...
@@ -6,7 +6,7 @@ consts = mlc.ConfigDict(
"is_multimer"
:
True
,
# monomer: False, multimer: True
"is_multimer"
:
True
,
# monomer: False, multimer: True
"chunk_size"
:
4
,
"chunk_size"
:
4
,
"batch_size"
:
2
,
"batch_size"
:
2
,
"n_res"
:
11
,
"n_res"
:
22
,
"n_seq"
:
13
,
"n_seq"
:
13
,
"n_templ"
:
3
,
"n_templ"
:
3
,
"n_extra"
:
17
,
"n_extra"
:
17
,
...
...
tests/data_utils.py
View file @
51556d52
...
@@ -29,14 +29,16 @@ def random_asym_ids(n_res, split_chains=True, min_chain_len=4):
...
@@ -29,14 +29,16 @@ def random_asym_ids(n_res, split_chains=True, min_chain_len=4):
pieces
=
[]
pieces
=
[]
asym_ids
=
[]
asym_ids
=
[]
final_idx
=
n_chain
-
1
for
idx
in
range
(
n_chain
-
1
):
for
idx
in
range
(
n_chain
-
1
):
n_stop
=
(
n_res
-
sum
(
pieces
)
-
n_chain
+
idx
-
min_chain_len
)
n_stop
=
(
n_res
-
sum
(
pieces
)
-
n_chain
+
idx
-
min_chain_len
)
if
n_stop
<=
min_chain_len
:
if
n_stop
<=
min_chain_len
:
final_idx
=
idx
break
break
piece
=
randint
(
min_chain_len
,
n_stop
)
piece
=
randint
(
min_chain_len
,
n_stop
)
pieces
.
append
(
piece
)
pieces
.
append
(
piece
)
asym_ids
.
extend
(
piece
*
[
idx
])
asym_ids
.
extend
(
piece
*
[
idx
])
asym_ids
.
extend
((
n_res
-
sum
(
pieces
))
*
[
n_chain
-
1
])
asym_ids
.
extend
((
n_res
-
sum
(
pieces
))
*
[
final_idx
])
return
np
.
array
(
asym_ids
).
astype
(
np
.
int64
)
return
np
.
array
(
asym_ids
).
astype
(
np
.
int64
)
...
...
tests/test_loss.py
View file @
51556d52
...
@@ -20,6 +20,7 @@ import unittest
...
@@ -20,6 +20,7 @@ import unittest
import
ml_collections
as
mlc
import
ml_collections
as
mlc
from
openfold.data
import
data_transforms
from
openfold.data
import
data_transforms
from
openfold.np
import
residue_constants
from
openfold.utils.rigid_utils
import
(
from
openfold.utils.rigid_utils
import
(
Rotation
,
Rotation
,
Rigid
,
Rigid
,
...
@@ -42,6 +43,8 @@ from openfold.utils.loss import (
...
@@ -42,6 +43,8 @@ from openfold.utils.loss import (
sidechain_loss
,
sidechain_loss
,
tm_loss
,
tm_loss
,
compute_plddt
,
compute_plddt
,
compute_tm
,
chain_center_of_mass_loss
)
)
from
openfold.utils.tensor_utils
import
(
from
openfold.utils.tensor_utils
import
(
tree_map
,
tree_map
,
...
@@ -233,11 +236,24 @@ class TestLoss(unittest.TestCase):
...
@@ -233,11 +236,24 @@ class TestLoss(unittest.TestCase):
pred_pos
=
np
.
random
.
rand
(
n_res
,
14
,
3
).
astype
(
np
.
float32
)
pred_pos
=
np
.
random
.
rand
(
n_res
,
14
,
3
).
astype
(
np
.
float32
)
atom_exists
=
np
.
random
.
randint
(
0
,
2
,
(
n_res
,
14
)).
astype
(
np
.
float32
)
atom_exists
=
np
.
random
.
randint
(
0
,
2
,
(
n_res
,
14
)).
astype
(
np
.
float32
)
atom_radius
=
np
.
random
.
rand
(
n_res
,
14
).
astype
(
np
.
float32
)
res_ind
=
np
.
arange
(
res_ind
=
np
.
arange
(
n_res
,
n_res
,
)
)
asym_id
=
random_asym_ids
(
n_res
)
residx_atom14_to_atom37
=
np
.
random
.
randint
(
0
,
37
,
(
n_res
,
14
)).
astype
(
np
.
int64
)
atomtype_radius
=
[
residue_constants
.
van_der_waals_radius
[
name
[
0
]]
for
name
in
residue_constants
.
atom_types
]
atomtype_radius
=
np
.
array
(
atomtype_radius
).
astype
(
np
.
float32
)
atom_radius
=
(
atom_exists
*
atomtype_radius
[
residx_atom14_to_atom37
]
)
asym_id
=
None
if
consts
.
is_multimer
:
asym_id
=
random_asym_ids
(
n_res
)
out_gt
=
f
.
apply
(
out_gt
=
f
.
apply
(
{},
{},
...
@@ -256,6 +272,7 @@ class TestLoss(unittest.TestCase):
...
@@ -256,6 +272,7 @@ class TestLoss(unittest.TestCase):
torch
.
tensor
(
atom_exists
).
cuda
(),
torch
.
tensor
(
atom_exists
).
cuda
(),
torch
.
tensor
(
atom_radius
).
cuda
(),
torch
.
tensor
(
atom_radius
).
cuda
(),
torch
.
tensor
(
res_ind
).
cuda
(),
torch
.
tensor
(
res_ind
).
cuda
(),
torch
.
tensor
(
asym_id
).
cuda
()
if
asym_id
is
not
None
else
None
,
)
)
out_repro
=
tensor_tree_map
(
lambda
x
:
x
.
cpu
(),
out_repro
)
out_repro
=
tensor_tree_map
(
lambda
x
:
x
.
cpu
(),
out_repro
)
...
@@ -279,6 +296,36 @@ class TestLoss(unittest.TestCase):
...
@@ -279,6 +296,36 @@ class TestLoss(unittest.TestCase):
torch
.
max
(
torch
.
abs
(
out_gt
-
out_repro
))
<
consts
.
eps
torch
.
max
(
torch
.
abs
(
out_gt
-
out_repro
))
<
consts
.
eps
)
)
@compare_utils.skip_unless_alphafold_installed()
def test_compute_ptm_compare(self):
    """Compare compute_tm against AlphaFold's predicted_tm_score, and the
    interface-pTM (ipTM) variant when running a multimer config."""
    n_res = consts.n_res
    max_bin = 31
    no_bins = 64

    logits = np.random.rand(n_res, n_res, no_bins)
    boundaries = np.linspace(0, max_bin, num=(no_bins - 1))

    # Ground truth from the reference AlphaFold implementation.
    ptm_gt = torch.tensor(
        alphafold.common.confidence.predicted_tm_score(logits, boundaries)
    )

    logits_t = torch.tensor(logits)
    ptm_repro = compute_tm(logits_t, no_bins=no_bins, max_bin=max_bin)

    self.assertTrue(torch.max(torch.abs(ptm_gt - ptm_repro)) < consts.eps)

    if consts.is_multimer:
        # The interface score additionally needs per-residue chain IDs.
        asym_id = random_asym_ids(n_res)
        iptm_gt = torch.tensor(
            alphafold.common.confidence.predicted_tm_score(
                logits, boundaries, asym_id=asym_id, interface=True
            )
        )
        iptm_repro = compute_tm(
            logits_t,
            no_bins=no_bins,
            max_bin=max_bin,
            asym_id=torch.tensor(asym_id),
            interface=True,
        )
        self.assertTrue(
            torch.max(torch.abs(iptm_gt - iptm_repro)) < consts.eps
        )
def
test_find_structural_violations
(
self
):
def
test_find_structural_violations
(
self
):
n
=
consts
.
n_res
n
=
consts
.
n_res
...
@@ -335,9 +382,11 @@ class TestLoss(unittest.TestCase):
...
@@ -335,9 +382,11 @@ class TestLoss(unittest.TestCase):
"residx_atom14_to_atom37"
:
np
.
random
.
randint
(
"residx_atom14_to_atom37"
:
np
.
random
.
randint
(
0
,
37
,
(
n_res
,
14
)
0
,
37
,
(
n_res
,
14
)
).
astype
(
np
.
int64
),
).
astype
(
np
.
int64
),
"asym_id"
:
random_asym_ids
(
n_res
)
}
}
if
consts
.
is_multimer
:
batch
[
"asym_id"
]
=
random_asym_ids
(
n_res
)
pred_pos
=
np
.
random
.
rand
(
n_res
,
14
,
3
)
pred_pos
=
np
.
random
.
rand
(
n_res
,
14
,
3
)
config
=
mlc
.
ConfigDict
(
config
=
mlc
.
ConfigDict
(
...
@@ -632,6 +681,40 @@ class TestLoss(unittest.TestCase):
...
@@ -632,6 +681,40 @@ class TestLoss(unittest.TestCase):
self
.
assertTrue
(
torch
.
max
(
torch
.
abs
(
out_gt
-
out_repro
))
<
consts
.
eps
)
self
.
assertTrue
(
torch
.
max
(
torch
.
abs
(
out_gt
-
out_repro
))
<
consts
.
eps
)
@compare_utils.skip_unless_alphafold_installed()
def test_violation_loss(self):
    """Smoke-test violation_loss under both clash-reduction modes
    (summed vs. averaged clashes)."""
    config = compare_utils.get_alphafold_config()
    c_viol = config.model.heads.structure_module
    n_res = consts.n_res

    batch = {
        "seq_mask": np.random.randint(0, 2, (n_res,)).astype(np.float32),
        "residue_index": np.arange(n_res),
        "aatype": np.random.randint(0, 21, (n_res,)),
    }
    if consts.is_multimer:
        batch["asym_id"] = random_asym_ids(n_res)

    batch = tree_map(lambda n: torch.tensor(n).cuda(), batch, np.ndarray)

    atom14_pred_pos = torch.tensor(
        np.random.rand(n_res, 14, 3).astype(np.float32)
    ).cuda()

    batch = data_transforms.make_atom14_masks(batch)

    # NOTE(review): no value assertions here -- this only verifies that
    # both average_clashes settings run to completion on random input.
    loss_sum_clash = violation_loss(
        find_structural_violations(batch, atom14_pred_pos, **c_viol),
        average_clashes=False,
        **batch,
    ).cpu()

    loss_avg_clash = violation_loss(
        find_structural_violations(batch, atom14_pred_pos, **c_viol),
        average_clashes=True,
        **batch,
    ).cpu()
@
compare_utils
.
skip_unless_alphafold_installed
()
@
compare_utils
.
skip_unless_alphafold_installed
()
def
test_violation_loss_compare
(
self
):
def
test_violation_loss_compare
(
self
):
config
=
compare_utils
.
get_alphafold_config
()
config
=
compare_utils
.
get_alphafold_config
()
...
@@ -680,10 +763,12 @@ class TestLoss(unittest.TestCase):
...
@@ -680,10 +763,12 @@ class TestLoss(unittest.TestCase):
batch
=
{
batch
=
{
"seq_mask"
:
np
.
random
.
randint
(
0
,
2
,
(
n_res
,)).
astype
(
np
.
float32
),
"seq_mask"
:
np
.
random
.
randint
(
0
,
2
,
(
n_res
,)).
astype
(
np
.
float32
),
"residue_index"
:
np
.
arange
(
n_res
),
"residue_index"
:
np
.
arange
(
n_res
),
"aatype"
:
np
.
random
.
randint
(
0
,
21
,
(
n_res
,)),
"aatype"
:
np
.
random
.
randint
(
0
,
21
,
(
n_res
,))
"asym_id"
:
random_asym_ids
(
n_res
)
}
}
if
consts
.
is_multimer
:
batch
[
"asym_id"
]
=
random_asym_ids
(
n_res
)
atom14_pred_pos
=
np
.
random
.
rand
(
n_res
,
14
,
3
).
astype
(
np
.
float32
)
atom14_pred_pos
=
np
.
random
.
rand
(
n_res
,
14
,
3
).
astype
(
np
.
float32
)
alphafold
.
model
.
tf
.
data_transforms
.
make_atom14_masks
(
batch
)
alphafold
.
model
.
tf
.
data_transforms
.
make_atom14_masks
(
batch
)
...
@@ -801,8 +886,7 @@ class TestLoss(unittest.TestCase):
...
@@ -801,8 +886,7 @@ class TestLoss(unittest.TestCase):
"backbone_affine_mask"
:
np
.
random
.
randint
(
0
,
2
,
(
n_res
,)).
astype
(
"backbone_affine_mask"
:
np
.
random
.
randint
(
0
,
2
,
(
n_res
,)).
astype
(
np
.
float32
np
.
float32
),
),
"use_clamped_fape"
:
np
.
array
(
0.0
),
"use_clamped_fape"
:
np
.
array
(
0.0
)
"asym_id"
:
random_asym_ids
(
n_res
)
}
}
value
=
{
value
=
{
...
@@ -814,6 +898,9 @@ class TestLoss(unittest.TestCase):
...
@@ -814,6 +898,9 @@ class TestLoss(unittest.TestCase):
),
),
}
}
if
consts
.
is_multimer
:
batch
[
"asym_id"
]
=
random_asym_ids
(
n_res
)
out_gt
=
f
.
apply
({},
None
,
batch
,
value
)
out_gt
=
f
.
apply
({},
None
,
batch
,
value
)
out_gt
=
torch
.
tensor
(
np
.
array
(
out_gt
.
block_until_ready
()))
out_gt
=
torch
.
tensor
(
np
.
array
(
out_gt
.
block_until_ready
()))
...
@@ -826,8 +913,18 @@ class TestLoss(unittest.TestCase):
...
@@ -826,8 +913,18 @@ class TestLoss(unittest.TestCase):
)
)
batch
[
"backbone_rigid_mask"
]
=
batch
[
"backbone_affine_mask"
]
batch
[
"backbone_rigid_mask"
]
=
batch
[
"backbone_affine_mask"
]
out_repro
=
backbone_loss
(
traj
=
value
[
"traj"
],
**
{
**
batch
,
**
c_sm
})
if
consts
.
is_multimer
:
out_repro
=
out_repro
.
cpu
()
intra_chain_mask
=
(
batch
[
"asym_id"
][...,
None
]
==
batch
[
"asym_id"
][...,
None
,
:]).
to
(
dtype
=
value
[
"traj"
].
dtype
)
intra_chain_out
=
backbone_loss
(
traj
=
value
[
"traj"
],
pair_mask
=
intra_chain_mask
,
**
{
**
batch
,
**
c_sm
.
intra_chain_fape
})
interface_out
=
backbone_loss
(
traj
=
value
[
"traj"
],
pair_mask
=
1.
-
intra_chain_mask
,
**
{
**
batch
,
**
c_sm
.
interface_fape
})
out_repro
=
intra_chain_out
+
interface_out
out_repro
=
out_repro
.
cpu
()
else
:
out_repro
=
backbone_loss
(
traj
=
value
[
"traj"
],
**
{
**
batch
,
**
c_sm
})
out_repro
=
out_repro
.
cpu
()
self
.
assertTrue
(
torch
.
max
(
torch
.
abs
(
out_gt
-
out_repro
))
<
consts
.
eps
)
self
.
assertTrue
(
torch
.
max
(
torch
.
abs
(
out_gt
-
out_repro
))
<
consts
.
eps
)
...
@@ -869,14 +966,14 @@ class TestLoss(unittest.TestCase):
...
@@ -869,14 +966,14 @@ class TestLoss(unittest.TestCase):
v
[
"sidechains"
]
=
{}
v
[
"sidechains"
]
=
{}
v
[
"sidechains"
][
v
[
"sidechains"
][
"frames"
"frames"
]
=
alphafold
.
model
.
r3
.
rigids_from_tensor4x4
(
]
=
self
.
am_rigid
.
rigids_from_tensor4x4
(
value
[
"sidechains"
][
"frames"
]
value
[
"sidechains"
][
"frames"
]
)
)
v
[
"sidechains"
][
"atom_pos"
]
=
alphafold
.
model
.
r3
.
vecs_from_tensor
(
v
[
"sidechains"
][
"atom_pos"
]
=
self
.
am_rigid
.
vecs_from_tensor
(
value
[
"sidechains"
][
"atom_pos"
]
value
[
"sidechains"
][
"atom_pos"
]
)
)
v
.
update
(
v
.
update
(
alphafold
.
model
.
fold
ing
.
compute_renamed_ground_truth
(
self
.
am_
fold
.
compute_renamed_ground_truth
(
batch
,
batch
,
atom14_pred_positions
,
atom14_pred_positions
,
)
)
...
@@ -907,9 +1004,6 @@ class TestLoss(unittest.TestCase):
...
@@ -907,9 +1004,6 @@ class TestLoss(unittest.TestCase):
),
),
}
}
if
consts
.
is_multimer
:
batch
[
"asym_id"
]
=
random_asym_ids
(
n_res
)
def
_build_extra_feats_np
():
def
_build_extra_feats_np
():
b
=
tree_map
(
lambda
n
:
torch
.
tensor
(
n
),
batch
,
np
.
ndarray
)
b
=
tree_map
(
lambda
n
:
torch
.
tensor
(
n
),
batch
,
np
.
ndarray
)
b
=
data_transforms
.
make_atom14_masks
(
b
)
b
=
data_transforms
.
make_atom14_masks
(
b
)
...
@@ -950,7 +1044,7 @@ class TestLoss(unittest.TestCase):
...
@@ -950,7 +1044,7 @@ class TestLoss(unittest.TestCase):
self
.
assertTrue
(
torch
.
max
(
torch
.
abs
(
out_gt
-
out_repro
))
<
consts
.
eps
)
self
.
assertTrue
(
torch
.
max
(
torch
.
abs
(
out_gt
-
out_repro
))
<
consts
.
eps
)
@
compare_utils
.
skip_unless_alphafold_installed
()
@
compare_utils
.
skip_unless_alphafold_installed
()
@
unittest
.
skipIf
(
consts
.
is_multimer
or
"ptm"
not
in
consts
.
model
,
"Not enabled for non-ptm models."
)
@
unittest
.
skipIf
(
not
consts
.
is_multimer
and
"ptm"
not
in
consts
.
model
,
"Not enabled for non-ptm models."
)
def
test_tm_loss_compare
(
self
):
def
test_tm_loss_compare
(
self
):
config
=
compare_utils
.
get_alphafold_config
()
config
=
compare_utils
.
get_alphafold_config
()
c_tm
=
config
.
model
.
heads
.
predicted_aligned_error
c_tm
=
config
.
model
.
heads
.
predicted_aligned_error
...
@@ -1017,6 +1111,33 @@ class TestLoss(unittest.TestCase):
...
@@ -1017,6 +1111,33 @@ class TestLoss(unittest.TestCase):
self
.
assertTrue
(
torch
.
max
(
torch
.
abs
(
out_gt
-
out_repro
))
<
consts
.
eps
)
self
.
assertTrue
(
torch
.
max
(
torch
.
abs
(
out_gt
-
out_repro
))
<
consts
.
eps
)
@compare_utils.skip_unless_alphafold_installed()
def test_chain_center_of_mass_loss(self):
    """Smoke-test chain_center_of_mass_loss on a random multi-chain batch."""
    batch_size = consts.batch_size
    n_res = consts.n_res

    batch = {
        # presumably atom37 representation; positions scaled up to ~10 units
        "all_atom_positions":
            np.random.rand(batch_size, n_res, 37, 3).astype(np.float32)
            * 10.0,
        "all_atom_mask":
            np.random.randint(
                0, 2, (batch_size, n_res, 37)
            ).astype(np.float32),
        "asym_id":
            np.stack([random_asym_ids(n_res) for _ in range(batch_size)]),
    }
    config = {
        "weight": 0.05,
        "clamp_distance": -4.0,
    }

    final_atom_positions = torch.rand(batch_size, n_res, 37, 3).cuda()
    batch = tree_map(lambda t: torch.tensor(t).cuda(), batch, np.ndarray)

    # NOTE(review): no value assertion -- checks only that the loss runs.
    out_repro = chain_center_of_mass_loss(
        all_atom_pred_pos=final_atom_positions,
        **{**batch, **config},
    ).cpu()
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
unittest
.
main
()
unittest
.
main
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment