Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
68389359
"pcdet/git@developer.sourcefind.cn:OpenDAS/openpcdet.git" did not exist on "fa78c4f121f5801b223a64290978e0c5094b18fe"
Commit
68389359
authored
Sep 28, 2023
by
Geoffrey Yu
Browse files
update permutation logic so that it check all valid anchor pairs
parent
2da285aa
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
36 additions
and
31 deletions
+36
-31
openfold/utils/loss.py
openfold/utils/loss.py
+36
-31
No files found.
openfold/utils/loss.py
View file @
68389359
...
...
@@ -1828,7 +1828,6 @@ def get_least_asym_entity_or_longest_length(batch,input_asym_id):
anchor_pred_asym_ids
=
[
id
for
id
in
entity_2_asym_list
[
least_asym_entities
]
if
id
in
input_asym_id
]
return
anchor_gt_asym_id
,
anchor_pred_asym_ids
def
greedy_align
(
batch
,
per_asym_residue_index
,
...
...
@@ -1897,15 +1896,12 @@ def merge_labels(per_asym_residue_index,labels, align,original_nres):
cur_out
=
{}
for
i
,
j
in
align
:
label
=
labels
[
j
][
k
]
cur_num_res
=
labels
[
j
][
'aatype'
].
shape
[
-
1
]
# to 1-based
cur_residue_index
=
per_asym_residue_index
[
i
+
1
]
if
len
(
v
.
shape
)
<=
1
or
"template"
in
k
or
"row_mask"
in
k
:
continue
else
:
dimension_to_merge
=
label
.
shape
.
index
(
cur_num_res
)
if
cur_num_res
in
label
.
shape
else
0
if
k
==
'all_atom_positions'
:
dimension_to_merge
=
1
dimension_to_merge
=
1
cur_out
[
i
]
=
label
.
index_select
(
dimension_to_merge
,
cur_residue_index
)
cur_out
=
[
x
[
1
]
for
x
in
sorted
(
cur_out
.
items
())]
if
len
(
cur_out
)
>
0
:
...
...
@@ -2144,16 +2140,17 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
feature
,
ground_truth
=
batch
del
batch
if
permutate_chains
:
best_rmsd
=
float
(
'inf'
)
best_align
=
None
# First select anchors from predicted structures and ground truths
anchor_gt_asym
,
anchor_pred_asym
=
get_least_asym_entity_or_longest_length
(
ground_truth
,
feature
[
'asym_id'
])
anchor_gt_asym
,
anchor_pred_asym_ids
=
get_least_asym_entity_or_longest_length
(
ground_truth
,
feature
[
'asym_id'
])
print
(
f
"########## line 2147 anchor_pred_asym_ids is
{
anchor_pred_asym_ids
}
and gt_asym is
{
anchor_gt_asym
}
"
)
entity_2_asym_list
=
AlphaFoldMultimerLoss
.
get_entity_2_asym_list
(
ground_truth
)
labels
=
AlphaFoldMultimerLoss
.
split_ground_truth_labels
(
ground_truth
,
REQUIRED_FEATURES
=
[
"all_atom_mask"
,
"all_atom_positions"
])
assert
isinstance
(
labels
,
list
)
del
ground_truth
anchor_gt_idx
=
int
(
anchor_gt_asym
)
-
1
asym_mask
=
(
feature
[
"asym_id"
]
==
anchor_pred_asym
).
bool
()
# Then calculate optimal transform by aligning anchors
ca_idx
=
rc
.
atom_order
[
"CA"
]
pred_ca_pos
=
out
[
"final_atom_positions"
][...,
ca_idx
,
:]
# [bsz, nres, 3]
...
...
@@ -2165,28 +2162,36 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
true_ca_masks
=
[
l
[
"all_atom_mask"
][...,
ca_idx
].
long
()
for
l
in
labels
]
# list([nres,])
per_asym_residue_index
=
AlphaFoldMultimerLoss
.
get_per_asym_residue_index
(
feature
)
anchor_gt_residue
=
per_asym_residue_index
[
int
(
anchor_gt_asym
)]
r
,
x
=
AlphaFoldMultimerLoss
.
calculate_optimal_transform
(
true_ca_poses
,
anchor_gt_idx
,
anchor_gt_residue
,
true_ca_masks
,
pred_ca_mask
,
asym_mask
,
pred_ca_pos
)
aligned_true_ca_poses
=
[
ca
.
to
(
r
.
dtype
)
@
r
+
x
for
ca
in
true_ca_poses
]
# apply transforms
del
true_ca_poses
,
r
,
x
gc
.
collect
()
align
=
greedy_align
(
feature
,
per_asym_residue_index
,
entity_2_asym_list
,
pred_ca_pos
,
pred_ca_mask
,
aligned_true_ca_poses
,
true_ca_masks
,
)
for
candidate_pred_anchor
in
anchor_pred_asym_ids
:
asym_mask
=
(
feature
[
"asym_id"
]
==
candidate_pred_anchor
).
bool
()
anchor_gt_residue
=
per_asym_residue_index
[
int
(
candidate_pred_anchor
)]
r
,
x
=
AlphaFoldMultimerLoss
.
calculate_optimal_transform
(
true_ca_poses
,
anchor_gt_idx
,
anchor_gt_residue
,
true_ca_masks
,
pred_ca_mask
,
asym_mask
,
pred_ca_pos
)
aligned_true_ca_poses
=
[
ca
.
to
(
r
.
dtype
)
@
r
+
x
for
ca
in
true_ca_poses
]
# apply transforms
align
=
greedy_align
(
feature
,
per_asym_residue_index
,
entity_2_asym_list
,
pred_ca_pos
,
pred_ca_mask
,
aligned_true_ca_poses
,
true_ca_masks
,
)
merged_labels
=
merge_labels
(
per_asym_residue_index
,
labels
,
align
,
original_nres
=
feature
[
'aatype'
].
shape
[
-
1
])
rmsd
=
compute_rmsd
(
true_atom_pos
=
merged_labels
[
'all_atom_positions'
][...,
ca_idx
,
:].
to
(
r
.
dtype
)
@
r
+
x
,
pred_atom_pos
=
pred_ca_pos
,
atom_mask
=
(
pred_ca_mask
*
merged_labels
[
'all_atom_mask'
][...,
ca_idx
].
long
()).
bool
())
if
rmsd
<
best_rmsd
:
best_rmsd
=
rmsd
best_align
=
align
print
(
f
"##### 2193 rmsd is
{
rmsd
}
and anchor_gt_asym is
{
anchor_gt_asym
}
and candidate_pred_anchor is
{
candidate_pred_anchor
}
"
)
del
r
,
x
del
true_ca_masks
,
aligned_true_ca_poses
del
pred_ca_pos
,
pred_ca_mask
gc
.
collect
()
...
...
@@ -2195,9 +2200,9 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
per_asym_residue_index
=
AlphaFoldMultimerLoss
.
get_per_asym_residue_index
(
feature
)
labels
=
AlphaFoldMultimerLoss
.
split_ground_truth_labels
(
ground_truth
,
REQUIRED_FEATURES
=
[
"all_atom_mask"
,
"all_atom_positions"
])
align
=
list
(
enumerate
(
range
(
len
(
labels
))))
best_
align
=
list
(
enumerate
(
range
(
len
(
labels
))))
return
align
,
per_asym_residue_index
return
best_
align
,
per_asym_residue_index
def
forward
(
self
,
out
,
batch
,
_return_breakdown
=
False
,
permutate_chains
=
True
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment