Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
02ce77c5
Commit
02ce77c5
authored
Sep 24, 2023
by
Geoffrey Yu
Browse files
fixed error when selected anchor_asym is not in the cropped input features
parent
67f873e7
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
49 additions
and
43 deletions
+49
-43
openfold/utils/loss.py
openfold/utils/loss.py
+49
-43
No files found.
openfold/utils/loss.py
View file @
02ce77c5
...
...
@@ -1781,7 +1781,7 @@ def get_optimal_transform(
return
r
,
x
def
get_least_asym_entity_or_longest_length
(
batch
):
def
get_least_asym_entity_or_longest_length
(
batch
,
input_asym_id
):
"""
First check how many subunit(s) one sequence has, if there is no tie, e.g. AABBB then select
one of the A as anchor
...
...
@@ -1818,6 +1818,7 @@ def get_least_asym_entity_or_longest_length(batch):
# # If there is more than one chain in the predicted output that has the same sequence
# # as the chosen ground truth anchor, then randomly pick one
if
len
(
best_pred_asym
)
>
1
:
while
best_pred_asym
not
in
input_asym_id
:
best_pred_asym
=
random
.
choice
(
best_pred_asym
)
return
least_asym_entities
[
0
],
best_pred_asym
...
...
@@ -1825,6 +1826,7 @@ def get_least_asym_entity_or_longest_length(batch):
def
greedy_align
(
batch
,
per_asym_residue_index
,
entity_2_asym_list
,
pred_ca_pos
,
pred_ca_mask
,
...
...
@@ -1835,9 +1837,9 @@ def greedy_align(
Implement Algorithm 4 in the Supplementary Information of AlphaFold-Multimer paper:
Evans,R et al., 2022 Protein complex prediction with AlphaFold-Multimer, bioRxiv 2021.10.04.463034; doi: https://doi.org/10.1101/2021.10.04.463034
"""
unique_asym_ids
=
[
i
for
i
in
torch
.
unique
(
batch
[
"asym_id"
])
if
i
!=
0
]
used
=
[
False
for
_
in
range
(
len
(
true_ca_poses
))]
align
=
[]
unique_asym_ids
=
[
i
for
i
in
torch
.
unique
(
batch
[
"asym_id"
])
if
i
!=
0
]
for
cur_asym_id
in
unique_asym_ids
:
i
=
int
(
cur_asym_id
-
1
)
asym_mask
=
batch
[
"asym_id"
]
==
cur_asym_id
...
...
@@ -1845,16 +1847,14 @@ def greedy_align(
best_rmsd
=
torch
.
inf
best_idx
=
None
cur_asym_list
=
entity_2_asym_list
[
int
(
cur_entity_ids
)]
cur_residue_index
=
per_asym_residue_index
[
int
(
cur_asym_id
)]
cur_pred_pos
=
pred_ca_pos
[
asym_mask
]
cur_pred_mask
=
pred_ca_mask
[
asym_mask
]
for
next_asym_id
in
cur_asym_list
:
j
=
int
(
next_asym_id
-
1
)
if
not
used
[
j
]:
# possible candidate
cropped_pos
=
true_ca_poses
[
j
]
cropped_pos
=
torch
.
squeeze
(
cropped_pos
,
0
)
if
cropped_pos
.
shape
==
cur_pred_pos
.
shape
:
mask
=
true_ca_masks
[
j
]
mask
=
torch
.
squeeze
(
mask
,
0
)
cropped_pos
=
torch
.
index_select
(
true_ca_poses
[
j
],
1
,
cur_residue_index
)
mask
=
torch
.
index_select
(
true_ca_masks
[
j
],
1
,
cur_residue_index
)
rmsd
=
compute_rmsd
(
torch
.
squeeze
(
cropped_pos
,
0
),
torch
.
squeeze
(
cur_pred_pos
,
0
),
(
cur_pred_mask
*
mask
).
bool
()
...
...
@@ -1866,7 +1866,6 @@ def greedy_align(
assert
best_idx
is
not
None
used
[
best_idx
]
=
True
align
.
append
((
i
,
best_idx
))
return
align
...
...
@@ -1878,7 +1877,7 @@ def pad_features(feature_tensor,nres_pad,pad_dim):
return
torch
.
concat
((
feature_tensor
,
padding_tensor
),
dim
=
pad_dim
)
def
merge_labels
(
labels
,
align
,
original_nres
):
def
merge_labels
(
per_asym_residue_index
,
labels
,
align
,
original_nres
):
"""
Merge ground truth labels according to the permutation results
...
...
@@ -1895,13 +1894,14 @@ def merge_labels(labels, align,original_nres):
label
=
labels
[
j
][
k
]
cur_num_res
=
labels
[
j
][
'aatype'
].
shape
[
-
1
]
# to 1-based
cur_residue_index
=
per_asym_residue_index
[
i
+
1
]
if
len
(
v
.
shape
)
<=
1
or
"template"
in
k
or
"row_mask"
in
k
:
continue
else
:
dimension_to_merge
=
label
.
shape
.
index
(
cur_num_res
)
if
cur_num_res
in
label
.
shape
else
0
if
k
==
'all_atom_positions'
:
dimension_to_merge
=
1
cur_out
[
i
]
=
label
cur_out
[
i
]
=
label
.
index_select
(
dimension_to_merge
,
cur_residue_index
)
cur_out
=
[
x
[
1
]
for
x
in
sorted
(
cur_out
.
items
())]
if
len
(
cur_out
)
>
0
:
new_v
=
torch
.
concat
(
cur_out
,
dim
=
dimension_to_merge
)
...
...
@@ -2100,8 +2100,10 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
"""
pred_ca_mask
=
torch
.
squeeze
(
pred_ca_mask
,
0
)
asym_mask
=
torch
.
squeeze
(
asym_mask
,
0
)
print
(
f
"##### line 2102 asym_mask is
{
asym_mask
}
and shape:
{
asym_mask
.
shape
}
"
)
anchor_pred_mask
=
pred_ca_mask
[
asym_mask
]
anchor_true_mask
=
torch
.
index_select
(
true_ca_masks
[
anchor_gt_idx
],
1
,
anchor_gt_residue
)
print
(
f
"##### line 2104 anchor_pred_mask:
{
anchor_pred_mask
.
shape
}
and anchor_true_mask :
{
anchor_true_mask
.
shape
}
"
)
input_mask
=
(
anchor_true_mask
*
anchor_pred_mask
).
bool
()
return
input_mask
...
...
@@ -2137,20 +2139,20 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
https://www.biorxiv.org/content/10.1101/2021.10.04.463034v2
"""
feature
,
ground_truth
=
batch
print
(
f
"###### line 2140 feature asym_id is :
{
feature
[
'asym_id'
]
}
"
)
del
batch
if
permutate_chains
:
# First select anchors from predicted structures and ground truths
anchor_gt_asym
,
anchor_pred_asym
=
get_least_asym_entity_or_longest_length
(
ground_truth
,
feature
[
'asym_id'
])
print
(
f
"###### anchor_gt_asym:
{
anchor_gt_asym
}
and anchor_pred_asym:
{
anchor_pred_asym
}
"
)
entity_2_asym_list
=
AlphaFoldMultimerLoss
.
get_entity_2_asym_list
(
ground_truth
)
labels
=
AlphaFoldMultimerLoss
.
split_ground_truth_labels
(
ground_truth
,
REQUIRED_FEATURES
=
[
"all_atom_mask"
,
"all_atom_positions"
])
assert
isinstance
(
labels
,
list
)
print
(
f
"successfully split ground truth labels"
)
if
permutate_chains
:
# First select anchors from predicted structures and ground truths
anchor_gt_asym
,
anchor_pred_asym
=
get_least_asym_entity_or_longest_length
(
ground_truth
)
print
(
f
"###### anchor gt asym is:
{
anchor_gt_asym
}
and anchor pred asym is
{
anchor_pred_asym
}
"
)
del
ground_truth
anchor_gt_idx
=
int
(
anchor_gt_asym
)
-
1
asym_mask
=
(
feature
[
"asym_id"
]
==
anchor_pred_asym
).
bool
()
print
(
f
"###### asym_mask is
{
asym_mask
}
"
)
# Then calculate optimal transform by aligning anchors
ca_idx
=
rc
.
atom_order
[
"CA"
]
pred_ca_pos
=
out
[
"final_atom_positions"
][...,
ca_idx
,
:]
# [bsz, nres, 3]
...
...
@@ -2165,7 +2167,6 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
per_asym_residue_index
=
AlphaFoldMultimerLoss
.
get_per_asym_residue_index
(
feature
)
anchor_gt_residue
=
per_asym_residue_index
[
int
(
anchor_gt_asym
)]
print
(
f
"######## per_asym_residue_index is
{
per_asym_residue_index
}
"
)
r
,
x
=
AlphaFoldMultimerLoss
.
calculate_optimal_transform
(
true_ca_poses
,
anchor_gt_idx
,
anchor_gt_residue
,
true_ca_masks
,
pred_ca_mask
,
...
...
@@ -2175,12 +2176,9 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
aligned_true_ca_poses
=
[
ca
.
to
(
r
.
dtype
)
@
r
+
x
for
ca
in
true_ca_poses
]
# apply transforms
del
true_ca_poses
,
r
,
x
gc
.
collect
()
print
(
f
"$$$$$$$ successfully calculated r and x"
)
import
sys
sys
.
exit
()
entity_2_asym_list
=
AlphaFoldMultimerLoss
.
get_entity_2_asym_list
(
batch
)
align
=
greedy_align
(
batch
,
feature
,
per_asym_residue_index
,
entity_2_asym_list
,
pred_ca_pos
,
pred_ca_mask
,
...
...
@@ -2191,10 +2189,16 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
del
true_ca_masks
,
aligned_true_ca_poses
del
pred_ca_pos
,
pred_ca_mask
gc
.
collect
()
print
(
f
"finished permutation align. Align is
{
align
}
"
)
else
:
per_asym_residue_index
=
AlphaFoldMultimerLoss
.
get_per_asym_residue_index
(
feature
)
labels
=
AlphaFoldMultimerLoss
.
split_ground_truth_labels
(
ground_truth
,
REQUIRED_FEATURES
=
[
"all_atom_mask"
,
"all_atom_positions"
])
align
=
list
(
enumerate
(
range
(
len
(
labels
))))
return
align
return
align
,
per_asym_residue_index
def
forward
(
self
,
out
,
batch
,
_return_breakdown
=
False
,
permutate_chains
=
True
):
"""
...
...
@@ -2209,18 +2213,20 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
features
,
ground_truth
=
batch
del
batch
is_monomer
=
len
(
torch
.
unique
(
features
[
'asym_id'
]))
==
1
or
torch
.
unique
(
features
[
'asym_id'
]).
tolist
()
==
[
0
,
1
]
print
(
f
"asym_id is
{
features
[
'asym_id'
]
}
"
)
if
not
is_monomer
:
permutate_chains
=
True
# Then permutate ground truth chains before calculating the loss
align
=
AlphaFoldMultimerLoss
.
multi_chain_perm_align
(
out
,
(
features
,
ground_truth
),
permutate_chains
=
permutate_chains
)
align
,
per_asym_residue_index
=
AlphaFoldMultimerLoss
.
multi_chain_perm_align
(
out
,
(
features
,
ground_truth
),
permutate_chains
=
permutate_chains
)
labels
=
AlphaFoldMultimerLoss
.
split_ground_truth_labels
(
ground_truth
,
REQUIRED_FEATURES
=
[
i
for
i
in
ground_truth
.
keys
()])
labels
=
AlphaFoldMultimerLoss
.
split_ground_truth_labels
(
features
,
REQUIRED_FEATURES
=
[
i
for
i
in
features
.
keys
()
if
i
in
dim_dict
])
# reorder ground truth labels according to permutation results
labels
=
merge_labels
(
labels
,
align
,
labels
=
merge_labels
(
per_asym_residue_index
,
labels
,
align
,
original_nres
=
features
[
'aatype'
].
shape
[
-
1
])
features
.
update
(
labels
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment