Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
68389359
Commit
68389359
authored
Sep 28, 2023
by
Geoffrey Yu
Browse files
Update permutation logic so that it checks all valid anchor pairs
parent
2da285aa
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
36 additions
and
31 deletions
+36
-31
openfold/utils/loss.py
openfold/utils/loss.py
+36
-31
No files found.
openfold/utils/loss.py
View file @
68389359
...
@@ -1828,7 +1828,6 @@ def get_least_asym_entity_or_longest_length(batch,input_asym_id):
...
@@ -1828,7 +1828,6 @@ def get_least_asym_entity_or_longest_length(batch,input_asym_id):
anchor_pred_asym_ids
=
[
id
for
id
in
entity_2_asym_list
[
least_asym_entities
]
if
id
in
input_asym_id
]
anchor_pred_asym_ids
=
[
id
for
id
in
entity_2_asym_list
[
least_asym_entities
]
if
id
in
input_asym_id
]
return
anchor_gt_asym_id
,
anchor_pred_asym_ids
return
anchor_gt_asym_id
,
anchor_pred_asym_ids
def
greedy_align
(
def
greedy_align
(
batch
,
batch
,
per_asym_residue_index
,
per_asym_residue_index
,
...
@@ -1897,15 +1896,12 @@ def merge_labels(per_asym_residue_index,labels, align,original_nres):
...
@@ -1897,15 +1896,12 @@ def merge_labels(per_asym_residue_index,labels, align,original_nres):
cur_out
=
{}
cur_out
=
{}
for
i
,
j
in
align
:
for
i
,
j
in
align
:
label
=
labels
[
j
][
k
]
label
=
labels
[
j
][
k
]
cur_num_res
=
labels
[
j
][
'aatype'
].
shape
[
-
1
]
# to 1-based
# to 1-based
cur_residue_index
=
per_asym_residue_index
[
i
+
1
]
cur_residue_index
=
per_asym_residue_index
[
i
+
1
]
if
len
(
v
.
shape
)
<=
1
or
"template"
in
k
or
"row_mask"
in
k
:
if
len
(
v
.
shape
)
<=
1
or
"template"
in
k
or
"row_mask"
in
k
:
continue
continue
else
:
else
:
dimension_to_merge
=
label
.
shape
.
index
(
cur_num_res
)
if
cur_num_res
in
label
.
shape
else
0
dimension_to_merge
=
1
if
k
==
'all_atom_positions'
:
dimension_to_merge
=
1
cur_out
[
i
]
=
label
.
index_select
(
dimension_to_merge
,
cur_residue_index
)
cur_out
[
i
]
=
label
.
index_select
(
dimension_to_merge
,
cur_residue_index
)
cur_out
=
[
x
[
1
]
for
x
in
sorted
(
cur_out
.
items
())]
cur_out
=
[
x
[
1
]
for
x
in
sorted
(
cur_out
.
items
())]
if
len
(
cur_out
)
>
0
:
if
len
(
cur_out
)
>
0
:
...
@@ -2144,16 +2140,17 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
...
@@ -2144,16 +2140,17 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
feature
,
ground_truth
=
batch
feature
,
ground_truth
=
batch
del
batch
del
batch
if
permutate_chains
:
if
permutate_chains
:
best_rmsd
=
float
(
'inf'
)
best_align
=
None
# First select anchors from predicted structures and ground truths
# First select anchors from predicted structures and ground truths
anchor_gt_asym
,
anchor_pred_asym
=
get_least_asym_entity_or_longest_length
(
ground_truth
,
feature
[
'asym_id'
])
anchor_gt_asym
,
anchor_pred_asym_ids
=
get_least_asym_entity_or_longest_length
(
ground_truth
,
feature
[
'asym_id'
])
print
(
f
"########## line 2147 anchor_pred_asym_ids is
{
anchor_pred_asym_ids
}
and gt_asym is
{
anchor_gt_asym
}
"
)
entity_2_asym_list
=
AlphaFoldMultimerLoss
.
get_entity_2_asym_list
(
ground_truth
)
entity_2_asym_list
=
AlphaFoldMultimerLoss
.
get_entity_2_asym_list
(
ground_truth
)
labels
=
AlphaFoldMultimerLoss
.
split_ground_truth_labels
(
ground_truth
,
labels
=
AlphaFoldMultimerLoss
.
split_ground_truth_labels
(
ground_truth
,
REQUIRED_FEATURES
=
[
"all_atom_mask"
,
"all_atom_positions"
])
REQUIRED_FEATURES
=
[
"all_atom_mask"
,
"all_atom_positions"
])
assert
isinstance
(
labels
,
list
)
assert
isinstance
(
labels
,
list
)
del
ground_truth
del
ground_truth
anchor_gt_idx
=
int
(
anchor_gt_asym
)
-
1
anchor_gt_idx
=
int
(
anchor_gt_asym
)
-
1
asym_mask
=
(
feature
[
"asym_id"
]
==
anchor_pred_asym
).
bool
()
# Then calculate optimal transform by aligning anchors
# Then calculate optimal transform by aligning anchors
ca_idx
=
rc
.
atom_order
[
"CA"
]
ca_idx
=
rc
.
atom_order
[
"CA"
]
pred_ca_pos
=
out
[
"final_atom_positions"
][...,
ca_idx
,
:]
# [bsz, nres, 3]
pred_ca_pos
=
out
[
"final_atom_positions"
][...,
ca_idx
,
:]
# [bsz, nres, 3]
...
@@ -2165,28 +2162,36 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
...
@@ -2165,28 +2162,36 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
true_ca_masks
=
[
true_ca_masks
=
[
l
[
"all_atom_mask"
][...,
ca_idx
].
long
()
for
l
in
labels
l
[
"all_atom_mask"
][...,
ca_idx
].
long
()
for
l
in
labels
]
# list([nres,])
]
# list([nres,])
per_asym_residue_index
=
AlphaFoldMultimerLoss
.
get_per_asym_residue_index
(
feature
)
per_asym_residue_index
=
AlphaFoldMultimerLoss
.
get_per_asym_residue_index
(
feature
)
anchor_gt_residue
=
per_asym_residue_index
[
int
(
anchor_gt_asym
)]
for
candidate_pred_anchor
in
anchor_pred_asym_ids
:
r
,
x
=
AlphaFoldMultimerLoss
.
calculate_optimal_transform
(
true_ca_poses
,
asym_mask
=
(
feature
[
"asym_id"
]
==
candidate_pred_anchor
).
bool
()
anchor_gt_idx
,
anchor_gt_residue
,
anchor_gt_residue
=
per_asym_residue_index
[
int
(
candidate_pred_anchor
)]
true_ca_masks
,
pred_ca_mask
,
r
,
x
=
AlphaFoldMultimerLoss
.
calculate_optimal_transform
(
true_ca_poses
,
asym_mask
,
anchor_gt_idx
,
anchor_gt_residue
,
pred_ca_pos
true_ca_masks
,
pred_ca_mask
,
)
asym_mask
,
aligned_true_ca_poses
=
[
ca
.
to
(
r
.
dtype
)
@
r
+
x
for
ca
in
true_ca_poses
]
# apply transforms
pred_ca_pos
del
true_ca_poses
,
r
,
x
)
gc
.
collect
()
aligned_true_ca_poses
=
[
ca
.
to
(
r
.
dtype
)
@
r
+
x
for
ca
in
true_ca_poses
]
# apply transforms
align
=
greedy_align
(
align
=
greedy_align
(
feature
,
feature
,
per_asym_residue_index
,
per_asym_residue_index
,
entity_2_asym_list
,
entity_2_asym_list
,
pred_ca_pos
,
pred_ca_pos
,
pred_ca_mask
,
pred_ca_mask
,
aligned_true_ca_poses
,
aligned_true_ca_poses
,
true_ca_masks
,
true_ca_masks
,
)
)
merged_labels
=
merge_labels
(
per_asym_residue_index
,
labels
,
align
,
original_nres
=
feature
[
'aatype'
].
shape
[
-
1
])
rmsd
=
compute_rmsd
(
true_atom_pos
=
merged_labels
[
'all_atom_positions'
][...,
ca_idx
,
:].
to
(
r
.
dtype
)
@
r
+
x
,
pred_atom_pos
=
pred_ca_pos
,
atom_mask
=
(
pred_ca_mask
*
merged_labels
[
'all_atom_mask'
][...,
ca_idx
].
long
()).
bool
())
if
rmsd
<
best_rmsd
:
best_rmsd
=
rmsd
best_align
=
align
print
(
f
"##### 2193 rmsd is
{
rmsd
}
and anchor_gt_asym is
{
anchor_gt_asym
}
and candidate_pred_anchor is
{
candidate_pred_anchor
}
"
)
del
r
,
x
del
true_ca_masks
,
aligned_true_ca_poses
del
true_ca_masks
,
aligned_true_ca_poses
del
pred_ca_pos
,
pred_ca_mask
del
pred_ca_pos
,
pred_ca_mask
gc
.
collect
()
gc
.
collect
()
...
@@ -2195,9 +2200,9 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
...
@@ -2195,9 +2200,9 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
per_asym_residue_index
=
AlphaFoldMultimerLoss
.
get_per_asym_residue_index
(
feature
)
per_asym_residue_index
=
AlphaFoldMultimerLoss
.
get_per_asym_residue_index
(
feature
)
labels
=
AlphaFoldMultimerLoss
.
split_ground_truth_labels
(
ground_truth
,
labels
=
AlphaFoldMultimerLoss
.
split_ground_truth_labels
(
ground_truth
,
REQUIRED_FEATURES
=
[
"all_atom_mask"
,
"all_atom_positions"
])
REQUIRED_FEATURES
=
[
"all_atom_mask"
,
"all_atom_positions"
])
align
=
list
(
enumerate
(
range
(
len
(
labels
))))
best_
align
=
list
(
enumerate
(
range
(
len
(
labels
))))
return
align
,
per_asym_residue_index
return
best_
align
,
per_asym_residue_index
def
forward
(
self
,
out
,
batch
,
_return_breakdown
=
False
,
permutate_chains
=
True
):
def
forward
(
self
,
out
,
batch
,
_return_breakdown
=
False
,
permutate_chains
=
True
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment