Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
3d87ef2f
Unverified
Commit
3d87ef2f
authored
Jun 27, 2023
by
Dingquan Yu
Committed by
GitHub
Jun 27, 2023
Browse files
Merge pull request #3 from dingquanyu/modify-assignment-stage
Modify assignment stage
parents
2a70e080
eeb035c2
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
42 additions
and
64 deletions
+42
-64
openfold/utils/loss.py
openfold/utils/loss.py
+42
-64
No files found.
openfold/utils/loss.py
View file @
3d87ef2f
...
...
@@ -1677,7 +1677,7 @@ def chain_center_of_mass_loss(
# #
def
kabsch_rotation
(
P
,
Q
):
"""
Use
scipy.spatial
package to calculate best rotation that minimises
Use
procrustes
package to calculate best rotation that minimises
the RMSD betwee P and Q
The optimal rotation matrix was calculated using
...
...
@@ -1755,19 +1755,6 @@ def compute_rmsd(
msd
=
torch
.
nan_to_num
(
msd
,
nan
=
1e8
)
return
torch
.
sqrt
(
msd
+
eps
)
def
kabsch_rmsd
(
true_atom_pos
:
torch
.
Tensor
,
pred_atom_pos
:
torch
.
Tensor
,
atom_mask
:
torch
.
Tensor
,
):
r
,
x
=
get_optimal_transform
(
true_atom_pos
,
pred_atom_pos
,
atom_mask
,
)
aligned_true_atom_pos
=
true_atom_pos
@
r
+
x
return
compute_rmsd
(
aligned_true_atom_pos
,
pred_atom_pos
,
atom_mask
)
def
get_least_asym_entity_or_longest_length
(
batch
):
"""
...
...
@@ -1802,6 +1789,11 @@ def get_least_asym_entity_or_longest_length(batch):
least_asym_entities
=
random
.
choice
(
least_asym_entities
)
assert
len
(
least_asym_entities
)
==
1
best_pred_asym
=
torch
.
unique
(
batch
[
"asym_id"
][
batch
[
"entity_id"
]
==
least_asym_entities
[
0
]])
# If there is more than one chain in the predicted output that has the same sequence
# as the chosen ground truth anchor, then randomly picke one
if
len
(
best_pred_asym
)
>
1
:
best_pred_asym
=
random
.
choice
(
best_pred_asym
)
return
least_asym_entities
[
0
],
best_pred_asym
...
...
@@ -2032,21 +2024,17 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
per_asym_residue_index
[
int
(
cur_asym_id
)]
=
batch
[
"residue_index"
][
asym_mask
]
anchor_gt_asym
,
anchor_pred_asym
=
get_least_asym_entity_or_longest_length
(
batch
)
print
(
f
"anchor_gt_asym is
chosen to be:
{
anchor_
gt
_asym
}
"
)
print
(
f
"anchor_gt_asym is
:
{
anchor_gt_asym
}
and anchor_pred_asym is
{
anchor_
pred
_asym
}
"
)
anchor_gt_idx
=
int
(
anchor_gt_asym
)
-
1
best_rmsd
=
1e20
best_labels
=
None
unique_entity_ids
=
torch
.
unique
(
batch
[
"entity_id"
])
entity_2_asym_list
=
{}
for
cur_ent_id
in
unique_entity_ids
:
ent_mask
=
batch
[
"entity_id"
]
==
cur_ent_id
cur_asym_id
=
torch
.
unique
(
batch
[
"asym_id"
][
ent_mask
])
entity_2_asym_list
[
int
(
cur_ent_id
)]
=
cur_asym_id
for
cur_asym_id
in
anchor_pred_asym
:
asym_mask
=
(
batch
[
"asym_id"
]
==
cur_asym_id
).
bool
()
anchor_residue_idx
=
per_asym_residue_index
[
int
(
cur_asym_id
)]
asym_mask
=
(
batch
[
"asym_id"
]
==
anchor_pred_asym
).
bool
()
anchor_residue_idx
=
per_asym_residue_index
[
int
(
anchor_pred_asym
)]
anchor_true_pos
=
true_ca_poses
[
anchor_gt_idx
][
anchor_residue_idx
]
anchor_pred_pos
=
pred_ca_pos
[
asym_mask
]
...
...
@@ -2059,15 +2047,10 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
)
aligned_true_ca_poses
=
[
ca
.
to
(
'cpu'
)
@
r
.
to
(
'cpu'
)
+
x
.
to
(
'cpu'
)
for
ca
in
true_ca_poses
]
# apply transforms
for
_
in
range
(
shuffle_times
):
shuffle_idx
=
torch
.
randperm
(
unique_asym_ids
.
shape
[
0
],
device
=
unique_asym_ids
.
device
)
shuffled_asym_ids
=
unique_asym_ids
[
shuffle_idx
]
align
=
greedy_align
(
batch
,
per_asym_residue_index
,
shuffled
_asym_ids
,
unique
_asym_ids
,
entity_2_asym_list
,
pred_ca_pos
,
pred_ca_mask
,
...
...
@@ -2080,17 +2063,10 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
labels
,
align
,
)
rmsd
=
kabsch_rmsd
(
merged_labels
[
"all_atom_positions"
][...,
ca_idx
,
:].
to
(
'cpu'
)
@
r
.
to
(
'cpu'
)
+
x
.
to
(
'cpu'
),
pred_ca_pos
,
(
pred_ca_mask
.
to
(
'cpu'
)
*
merged_labels
[
"all_atom_mask"
][...,
ca_idx
].
to
(
'cpu'
)).
bool
(),
)
if
rmsd
<
best_rmsd
:
best_rmsd
=
rmsd
best_labels
=
merged_labels
print
(
f
"finished shuffling and final align is
{
align
}
"
)
return
best_labels
print
(
f
"finished multi-chain permutation and final align is
{
align
}
"
)
return
merged_labels
def
forward
(
self
,
out
,
batch
,
_return_breakdown
=
False
):
"""
...
...
@@ -2107,6 +2083,8 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
# then permutate ground truth chains before calculating the loss
permutated_labels
=
self
.
multi_chain_perm_align
(
out
,
features
,
labels
)
logger
.
info
(
"finished multi-chain permutation"
)
# features.update(permutated_labels)
# self.loss(out,features)
return
permutated_labels
## TODO next need to check how the ground truth label is used
# in loss calculation.
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment