Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
24e470be
Commit
24e470be
authored
Jun 27, 2023
by
Geoffrey Yu
Browse files
remove shuffling
parent
2a70e080
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
40 additions
and
64 deletions
+40
-64
openfold/utils/loss.py
openfold/utils/loss.py
+40
-64
No files found.
openfold/utils/loss.py
View file @
24e470be
...
@@ -1677,7 +1677,7 @@ def chain_center_of_mass_loss(
...
@@ -1677,7 +1677,7 @@ def chain_center_of_mass_loss(
# #
# #
def
kabsch_rotation
(
P
,
Q
):
def
kabsch_rotation
(
P
,
Q
):
"""
"""
Use
scipy.spatial
package to calculate best rotation that minimises
Use
procrustes
package to calculate best rotation that minimises
the RMSD between P and Q
the RMSD between P and Q
The optimal rotation matrix was calculated using
The optimal rotation matrix was calculated using
...
@@ -1755,19 +1755,6 @@ def compute_rmsd(
...
@@ -1755,19 +1755,6 @@ def compute_rmsd(
msd
=
torch
.
nan_to_num
(
msd
,
nan
=
1e8
)
msd
=
torch
.
nan_to_num
(
msd
,
nan
=
1e8
)
return
torch
.
sqrt
(
msd
+
eps
)
return
torch
.
sqrt
(
msd
+
eps
)
def
kabsch_rmsd
(
true_atom_pos
:
torch
.
Tensor
,
pred_atom_pos
:
torch
.
Tensor
,
atom_mask
:
torch
.
Tensor
,
):
r
,
x
=
get_optimal_transform
(
true_atom_pos
,
pred_atom_pos
,
atom_mask
,
)
aligned_true_atom_pos
=
true_atom_pos
@
r
+
x
return
compute_rmsd
(
aligned_true_atom_pos
,
pred_atom_pos
,
atom_mask
)
def
get_least_asym_entity_or_longest_length
(
batch
):
def
get_least_asym_entity_or_longest_length
(
batch
):
"""
"""
...
@@ -1802,6 +1789,11 @@ def get_least_asym_entity_or_longest_length(batch):
...
@@ -1802,6 +1789,11 @@ def get_least_asym_entity_or_longest_length(batch):
least_asym_entities
=
random
.
choice
(
least_asym_entities
)
least_asym_entities
=
random
.
choice
(
least_asym_entities
)
assert
len
(
least_asym_entities
)
==
1
assert
len
(
least_asym_entities
)
==
1
best_pred_asym
=
torch
.
unique
(
batch
[
"asym_id"
][
batch
[
"entity_id"
]
==
least_asym_entities
[
0
]])
best_pred_asym
=
torch
.
unique
(
batch
[
"asym_id"
][
batch
[
"entity_id"
]
==
least_asym_entities
[
0
]])
# If there is more than one chain in the predicted output that has the same sequence
# as the chosen ground truth anchor, then randomly pick one
if
len
(
best_pred_asym
)
>
1
:
best_pred_asym
=
random
.
choice
(
best_pred_asym
)
return
least_asym_entities
[
0
],
best_pred_asym
return
least_asym_entities
[
0
],
best_pred_asym
...
@@ -2032,65 +2024,49 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
...
@@ -2032,65 +2024,49 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
per_asym_residue_index
[
int
(
cur_asym_id
)]
=
batch
[
"residue_index"
][
asym_mask
]
per_asym_residue_index
[
int
(
cur_asym_id
)]
=
batch
[
"residue_index"
][
asym_mask
]
anchor_gt_asym
,
anchor_pred_asym
=
get_least_asym_entity_or_longest_length
(
batch
)
anchor_gt_asym
,
anchor_pred_asym
=
get_least_asym_entity_or_longest_length
(
batch
)
print
(
f
"anchor_gt_asym is
chosen to be:
{
anchor_
gt
_asym
}
"
)
print
(
f
"anchor_gt_asym is
:
{
anchor_gt_asym
}
and anchor_pred_asym is
{
anchor_
pred
_asym
}
"
)
anchor_gt_idx
=
int
(
anchor_gt_asym
)
-
1
anchor_gt_idx
=
int
(
anchor_gt_asym
)
-
1
best_rmsd
=
1e20
best_labels
=
None
unique_entity_ids
=
torch
.
unique
(
batch
[
"entity_id"
])
unique_entity_ids
=
torch
.
unique
(
batch
[
"entity_id"
])
entity_2_asym_list
=
{}
entity_2_asym_list
=
{}
for
cur_ent_id
in
unique_entity_ids
:
for
cur_ent_id
in
unique_entity_ids
:
ent_mask
=
batch
[
"entity_id"
]
==
cur_ent_id
ent_mask
=
batch
[
"entity_id"
]
==
cur_ent_id
cur_asym_id
=
torch
.
unique
(
batch
[
"asym_id"
][
ent_mask
])
cur_asym_id
=
torch
.
unique
(
batch
[
"asym_id"
][
ent_mask
])
entity_2_asym_list
[
int
(
cur_ent_id
)]
=
cur_asym_id
entity_2_asym_list
[
int
(
cur_ent_id
)]
=
cur_asym_id
for
cur_asym_id
in
anchor_pred_asym
:
asym_mask
=
(
batch
[
"asym_id"
]
==
anchor_pred_asym
).
bool
()
asym_mask
=
(
batch
[
"asym_id"
]
==
cur_asym_id
).
bool
()
anchor_residue_idx
=
per_asym_residue_index
[
int
(
anchor_pred_asym
)]
anchor_residue_idx
=
per_asym_residue_index
[
int
(
cur_asym_id
)]
anchor_true_pos
=
true_ca_poses
[
anchor_gt_idx
][
anchor_residue_idx
]
anchor_true_pos
=
true_ca_poses
[
anchor_gt_idx
][
anchor_residue_idx
]
anchor_pred_pos
=
pred_ca_pos
[
asym_mask
]
anchor_pred_pos
=
pred_ca_pos
[
asym_mask
]
anchor_true_mask
=
true_ca_masks
[
anchor_gt_idx
][
anchor_residue_idx
]
anchor_true_mask
=
true_ca_masks
[
anchor_gt_idx
][
anchor_residue_idx
]
anchor_pred_mask
=
pred_ca_mask
[
asym_mask
]
anchor_pred_mask
=
pred_ca_mask
[
asym_mask
]
r
,
x
=
get_optimal_transform
(
r
,
x
=
get_optimal_transform
(
anchor_true_pos
,
anchor_true_pos
,
anchor_pred_pos
,
anchor_pred_pos
,
(
anchor_true_mask
.
to
(
'cpu'
)
*
anchor_pred_mask
.
to
(
'cpu'
)).
bool
(),
(
anchor_true_mask
.
to
(
'cpu'
)
*
anchor_pred_mask
.
to
(
'cpu'
)).
bool
(),
)
aligned_true_ca_poses
=
[
ca
.
to
(
'cpu'
)
@
r
.
to
(
'cpu'
)
+
x
.
to
(
'cpu'
)
for
ca
in
true_ca_poses
]
# apply transforms
align
=
greedy_align
(
batch
,
per_asym_residue_index
,
unique_asym_ids
,
entity_2_asym_list
,
pred_ca_pos
,
pred_ca_mask
,
aligned_true_ca_poses
,
true_ca_masks
,
)
)
merged_labels
=
merge_labels
(
aligned_true_ca_poses
=
[
ca
.
to
(
'cpu'
)
@
r
.
to
(
'cpu'
)
+
x
.
to
(
'cpu'
)
for
ca
in
true_ca_poses
]
# apply transforms
batch
,
for
_
in
range
(
shuffle_times
):
per_asym_residue_index
,
shuffle_idx
=
torch
.
randperm
(
labels
,
unique_asym_ids
.
shape
[
0
],
device
=
unique_asym_ids
.
device
align
,
)
)
shuffled_asym_ids
=
unique_asym_ids
[
shuffle_idx
]
align
=
greedy_align
(
print
(
f
"finished multi-chain permutation and final align is
{
align
}
"
)
batch
,
per_asym_residue_index
,
return
merged_labels
shuffled_asym_ids
,
entity_2_asym_list
,
pred_ca_pos
,
pred_ca_mask
,
aligned_true_ca_poses
,
true_ca_masks
,
)
merged_labels
=
merge_labels
(
batch
,
per_asym_residue_index
,
labels
,
align
,
)
rmsd
=
kabsch_rmsd
(
merged_labels
[
"all_atom_positions"
][...,
ca_idx
,
:].
to
(
'cpu'
)
@
r
.
to
(
'cpu'
)
+
x
.
to
(
'cpu'
),
pred_ca_pos
,
(
pred_ca_mask
.
to
(
'cpu'
)
*
merged_labels
[
"all_atom_mask"
][...,
ca_idx
].
to
(
'cpu'
)).
bool
(),
)
if
rmsd
<
best_rmsd
:
best_rmsd
=
rmsd
best_labels
=
merged_labels
print
(
f
"finished shuffling and final align is
{
align
}
"
)
return
best_labels
def
forward
(
self
,
out
,
batch
,
_return_breakdown
=
False
):
def
forward
(
self
,
out
,
batch
,
_return_breakdown
=
False
):
"""
"""
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment