Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
02ce77c5
Commit
02ce77c5
authored
Sep 24, 2023
by
Geoffrey Yu
Browse files
fixed error when selected anchor_aysm is not in the cropped input features
parent
67f873e7
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
49 additions
and
43 deletions
+49
-43
openfold/utils/loss.py
openfold/utils/loss.py
+49
-43
No files found.
openfold/utils/loss.py
View file @
02ce77c5
...
@@ -1781,7 +1781,7 @@ def get_optimal_transform(
...
@@ -1781,7 +1781,7 @@ def get_optimal_transform(
return
r
,
x
return
r
,
x
def
get_least_asym_entity_or_longest_length
(
batch
):
def
get_least_asym_entity_or_longest_length
(
batch
,
input_asym_id
):
"""
"""
First check how many subunit(s) one sequence has, if there is no tie, e.g. AABBB then select
First check how many subunit(s) one sequence has, if there is no tie, e.g. AABBB then select
one of the A as anchor
one of the A as anchor
...
@@ -1818,6 +1818,7 @@ def get_least_asym_entity_or_longest_length(batch):
...
@@ -1818,6 +1818,7 @@ def get_least_asym_entity_or_longest_length(batch):
# # If there is more than one chain in the predicted output that has the same sequence
# # If there is more than one chain in the predicted output that has the same sequence
# # as the chosen ground truth anchor, then randomly picke one
# # as the chosen ground truth anchor, then randomly picke one
if
len
(
best_pred_asym
)
>
1
:
if
len
(
best_pred_asym
)
>
1
:
while
best_pred_asym
not
in
input_asym_id
:
best_pred_asym
=
random
.
choice
(
best_pred_asym
)
best_pred_asym
=
random
.
choice
(
best_pred_asym
)
return
least_asym_entities
[
0
],
best_pred_asym
return
least_asym_entities
[
0
],
best_pred_asym
...
@@ -1825,6 +1826,7 @@ def get_least_asym_entity_or_longest_length(batch):
...
@@ -1825,6 +1826,7 @@ def get_least_asym_entity_or_longest_length(batch):
def
greedy_align
(
def
greedy_align
(
batch
,
batch
,
per_asym_residue_index
,
entity_2_asym_list
,
entity_2_asym_list
,
pred_ca_pos
,
pred_ca_pos
,
pred_ca_mask
,
pred_ca_mask
,
...
@@ -1835,9 +1837,9 @@ def greedy_align(
...
@@ -1835,9 +1837,9 @@ def greedy_align(
Implement Algorithm 4 in the Supplementary Information of AlphaFold-Multimer paper:
Implement Algorithm 4 in the Supplementary Information of AlphaFold-Multimer paper:
Evans,R et al., 2022 Protein complex prediction with AlphaFold-Multimer, bioRxiv 2021.10.04.463034; doi: https://doi.org/10.1101/2021.10.04.463034
Evans,R et al., 2022 Protein complex prediction with AlphaFold-Multimer, bioRxiv 2021.10.04.463034; doi: https://doi.org/10.1101/2021.10.04.463034
"""
"""
unique_asym_ids
=
[
i
for
i
in
torch
.
unique
(
batch
[
"asym_id"
])
if
i
!=
0
]
used
=
[
False
for
_
in
range
(
len
(
true_ca_poses
))]
used
=
[
False
for
_
in
range
(
len
(
true_ca_poses
))]
align
=
[]
align
=
[]
unique_asym_ids
=
[
i
for
i
in
torch
.
unique
(
batch
[
"asym_id"
])
if
i
!=
0
]
for
cur_asym_id
in
unique_asym_ids
:
for
cur_asym_id
in
unique_asym_ids
:
i
=
int
(
cur_asym_id
-
1
)
i
=
int
(
cur_asym_id
-
1
)
asym_mask
=
batch
[
"asym_id"
]
==
cur_asym_id
asym_mask
=
batch
[
"asym_id"
]
==
cur_asym_id
...
@@ -1845,16 +1847,14 @@ def greedy_align(
...
@@ -1845,16 +1847,14 @@ def greedy_align(
best_rmsd
=
torch
.
inf
best_rmsd
=
torch
.
inf
best_idx
=
None
best_idx
=
None
cur_asym_list
=
entity_2_asym_list
[
int
(
cur_entity_ids
)]
cur_asym_list
=
entity_2_asym_list
[
int
(
cur_entity_ids
)]
cur_residue_index
=
per_asym_residue_index
[
int
(
cur_asym_id
)]
cur_pred_pos
=
pred_ca_pos
[
asym_mask
]
cur_pred_pos
=
pred_ca_pos
[
asym_mask
]
cur_pred_mask
=
pred_ca_mask
[
asym_mask
]
cur_pred_mask
=
pred_ca_mask
[
asym_mask
]
for
next_asym_id
in
cur_asym_list
:
for
next_asym_id
in
cur_asym_list
:
j
=
int
(
next_asym_id
-
1
)
j
=
int
(
next_asym_id
-
1
)
if
not
used
[
j
]:
# possible candidate
if
not
used
[
j
]:
# possible candidate
cropped_pos
=
true_ca_poses
[
j
]
cropped_pos
=
torch
.
index_select
(
true_ca_poses
[
j
],
1
,
cur_residue_index
)
cropped_pos
=
torch
.
squeeze
(
cropped_pos
,
0
)
mask
=
torch
.
index_select
(
true_ca_masks
[
j
],
1
,
cur_residue_index
)
if
cropped_pos
.
shape
==
cur_pred_pos
.
shape
:
mask
=
true_ca_masks
[
j
]
mask
=
torch
.
squeeze
(
mask
,
0
)
rmsd
=
compute_rmsd
(
rmsd
=
compute_rmsd
(
torch
.
squeeze
(
cropped_pos
,
0
),
torch
.
squeeze
(
cur_pred_pos
,
0
),
torch
.
squeeze
(
cropped_pos
,
0
),
torch
.
squeeze
(
cur_pred_pos
,
0
),
(
cur_pred_mask
*
mask
).
bool
()
(
cur_pred_mask
*
mask
).
bool
()
...
@@ -1866,7 +1866,6 @@ def greedy_align(
...
@@ -1866,7 +1866,6 @@ def greedy_align(
assert
best_idx
is
not
None
assert
best_idx
is
not
None
used
[
best_idx
]
=
True
used
[
best_idx
]
=
True
align
.
append
((
i
,
best_idx
))
align
.
append
((
i
,
best_idx
))
return
align
return
align
...
@@ -1878,7 +1877,7 @@ def pad_features(feature_tensor,nres_pad,pad_dim):
...
@@ -1878,7 +1877,7 @@ def pad_features(feature_tensor,nres_pad,pad_dim):
return
torch
.
concat
((
feature_tensor
,
padding_tensor
),
dim
=
pad_dim
)
return
torch
.
concat
((
feature_tensor
,
padding_tensor
),
dim
=
pad_dim
)
def
merge_labels
(
labels
,
align
,
original_nres
):
def
merge_labels
(
per_asym_residue_index
,
labels
,
align
,
original_nres
):
"""
"""
Merge ground truth labels according to the permutation results
Merge ground truth labels according to the permutation results
...
@@ -1895,13 +1894,14 @@ def merge_labels(labels, align,original_nres):
...
@@ -1895,13 +1894,14 @@ def merge_labels(labels, align,original_nres):
label
=
labels
[
j
][
k
]
label
=
labels
[
j
][
k
]
cur_num_res
=
labels
[
j
][
'aatype'
].
shape
[
-
1
]
cur_num_res
=
labels
[
j
][
'aatype'
].
shape
[
-
1
]
# to 1-based
# to 1-based
cur_residue_index
=
per_asym_residue_index
[
i
+
1
]
if
len
(
v
.
shape
)
<=
1
or
"template"
in
k
or
"row_mask"
in
k
:
if
len
(
v
.
shape
)
<=
1
or
"template"
in
k
or
"row_mask"
in
k
:
continue
continue
else
:
else
:
dimension_to_merge
=
label
.
shape
.
index
(
cur_num_res
)
if
cur_num_res
in
label
.
shape
else
0
dimension_to_merge
=
label
.
shape
.
index
(
cur_num_res
)
if
cur_num_res
in
label
.
shape
else
0
if
k
==
'all_atom_positions'
:
if
k
==
'all_atom_positions'
:
dimension_to_merge
=
1
dimension_to_merge
=
1
cur_out
[
i
]
=
label
cur_out
[
i
]
=
label
.
index_select
(
dimension_to_merge
,
cur_residue_index
)
cur_out
=
[
x
[
1
]
for
x
in
sorted
(
cur_out
.
items
())]
cur_out
=
[
x
[
1
]
for
x
in
sorted
(
cur_out
.
items
())]
if
len
(
cur_out
)
>
0
:
if
len
(
cur_out
)
>
0
:
new_v
=
torch
.
concat
(
cur_out
,
dim
=
dimension_to_merge
)
new_v
=
torch
.
concat
(
cur_out
,
dim
=
dimension_to_merge
)
...
@@ -2100,8 +2100,10 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
...
@@ -2100,8 +2100,10 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
"""
"""
pred_ca_mask
=
torch
.
squeeze
(
pred_ca_mask
,
0
)
pred_ca_mask
=
torch
.
squeeze
(
pred_ca_mask
,
0
)
asym_mask
=
torch
.
squeeze
(
asym_mask
,
0
)
asym_mask
=
torch
.
squeeze
(
asym_mask
,
0
)
print
(
f
"##### line 2102 asym_mask is
{
asym_mask
}
and shape:
{
asym_mask
.
shape
}
"
)
anchor_pred_mask
=
pred_ca_mask
[
asym_mask
]
anchor_pred_mask
=
pred_ca_mask
[
asym_mask
]
anchor_true_mask
=
torch
.
index_select
(
true_ca_masks
[
anchor_gt_idx
],
1
,
anchor_gt_residue
)
anchor_true_mask
=
torch
.
index_select
(
true_ca_masks
[
anchor_gt_idx
],
1
,
anchor_gt_residue
)
print
(
f
"##### line 2104 anchor_pred_mask:
{
anchor_pred_mask
.
shape
}
and anchor_true_mask :
{
anchor_true_mask
.
shape
}
"
)
input_mask
=
(
anchor_true_mask
*
anchor_pred_mask
).
bool
()
input_mask
=
(
anchor_true_mask
*
anchor_pred_mask
).
bool
()
return
input_mask
return
input_mask
...
@@ -2137,20 +2139,20 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
...
@@ -2137,20 +2139,20 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
https://www.biorxiv.org/content/10.1101/2021.10.04.463034v2
https://www.biorxiv.org/content/10.1101/2021.10.04.463034v2
"""
"""
feature
,
ground_truth
=
batch
feature
,
ground_truth
=
batch
print
(
f
"###### line 2140 feature asym_id is :
{
feature
[
'asym_id'
]
}
"
)
del
batch
del
batch
if
permutate_chains
:
# First select anchors from predicted structures and ground truths
anchor_gt_asym
,
anchor_pred_asym
=
get_least_asym_entity_or_longest_length
(
ground_truth
,
feature
[
'asym_id'
])
print
(
f
"###### anchor_gt_asym:
{
anchor_gt_asym
}
and anchor_pred_asym:
{
anchor_pred_asym
}
"
)
entity_2_asym_list
=
AlphaFoldMultimerLoss
.
get_entity_2_asym_list
(
ground_truth
)
labels
=
AlphaFoldMultimerLoss
.
split_ground_truth_labels
(
ground_truth
,
labels
=
AlphaFoldMultimerLoss
.
split_ground_truth_labels
(
ground_truth
,
REQUIRED_FEATURES
=
[
"all_atom_mask"
,
"all_atom_positions"
])
REQUIRED_FEATURES
=
[
"all_atom_mask"
,
"all_atom_positions"
])
assert
isinstance
(
labels
,
list
)
assert
isinstance
(
labels
,
list
)
print
(
f
"successfully split ground truth labels"
)
del
ground_truth
if
permutate_chains
:
# First select anchors from predicted structures and ground truths
anchor_gt_asym
,
anchor_pred_asym
=
get_least_asym_entity_or_longest_length
(
ground_truth
)
print
(
f
"###### anchor gt asym is:
{
anchor_gt_asym
}
and anchor pred asym is
{
anchor_pred_asym
}
"
)
anchor_gt_idx
=
int
(
anchor_gt_asym
)
-
1
anchor_gt_idx
=
int
(
anchor_gt_asym
)
-
1
asym_mask
=
(
feature
[
"asym_id"
]
==
anchor_pred_asym
).
bool
()
asym_mask
=
(
feature
[
"asym_id"
]
==
anchor_pred_asym
).
bool
()
print
(
f
"###### asym_mask is
{
asym_mask
}
"
)
# Then calculate optimal transform by aligning anchors
# Then calculate optimal transform by aligning anchors
ca_idx
=
rc
.
atom_order
[
"CA"
]
ca_idx
=
rc
.
atom_order
[
"CA"
]
pred_ca_pos
=
out
[
"final_atom_positions"
][...,
ca_idx
,
:]
# [bsz, nres, 3]
pred_ca_pos
=
out
[
"final_atom_positions"
][...,
ca_idx
,
:]
# [bsz, nres, 3]
...
@@ -2165,7 +2167,6 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
...
@@ -2165,7 +2167,6 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
per_asym_residue_index
=
AlphaFoldMultimerLoss
.
get_per_asym_residue_index
(
feature
)
per_asym_residue_index
=
AlphaFoldMultimerLoss
.
get_per_asym_residue_index
(
feature
)
anchor_gt_residue
=
per_asym_residue_index
[
int
(
anchor_gt_asym
)]
anchor_gt_residue
=
per_asym_residue_index
[
int
(
anchor_gt_asym
)]
print
(
f
"######## per_asym_residue_index is
{
per_asym_residue_index
}
"
)
r
,
x
=
AlphaFoldMultimerLoss
.
calculate_optimal_transform
(
true_ca_poses
,
r
,
x
=
AlphaFoldMultimerLoss
.
calculate_optimal_transform
(
true_ca_poses
,
anchor_gt_idx
,
anchor_gt_residue
,
anchor_gt_idx
,
anchor_gt_residue
,
true_ca_masks
,
pred_ca_mask
,
true_ca_masks
,
pred_ca_mask
,
...
@@ -2175,12 +2176,9 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
...
@@ -2175,12 +2176,9 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
aligned_true_ca_poses
=
[
ca
.
to
(
r
.
dtype
)
@
r
+
x
for
ca
in
true_ca_poses
]
# apply transforms
aligned_true_ca_poses
=
[
ca
.
to
(
r
.
dtype
)
@
r
+
x
for
ca
in
true_ca_poses
]
# apply transforms
del
true_ca_poses
,
r
,
x
del
true_ca_poses
,
r
,
x
gc
.
collect
()
gc
.
collect
()
print
(
f
"$$$$$$$ successfully calculated r and x"
)
import
sys
sys
.
exit
()
entity_2_asym_list
=
AlphaFoldMultimerLoss
.
get_entity_2_asym_list
(
batch
)
align
=
greedy_align
(
align
=
greedy_align
(
batch
,
feature
,
per_asym_residue_index
,
entity_2_asym_list
,
entity_2_asym_list
,
pred_ca_pos
,
pred_ca_pos
,
pred_ca_mask
,
pred_ca_mask
,
...
@@ -2191,10 +2189,16 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
...
@@ -2191,10 +2189,16 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
del
true_ca_masks
,
aligned_true_ca_poses
del
true_ca_masks
,
aligned_true_ca_poses
del
pred_ca_pos
,
pred_ca_mask
del
pred_ca_pos
,
pred_ca_mask
gc
.
collect
()
gc
.
collect
()
print
(
f
"finished permutation align. Align is
{
align
}
"
)
else
:
else
:
per_asym_residue_index
=
AlphaFoldMultimerLoss
.
get_per_asym_residue_index
(
feature
)
labels
=
AlphaFoldMultimerLoss
.
split_ground_truth_labels
(
ground_truth
,
REQUIRED_FEATURES
=
[
"all_atom_mask"
,
"all_atom_positions"
])
align
=
list
(
enumerate
(
range
(
len
(
labels
))))
align
=
list
(
enumerate
(
range
(
len
(
labels
))))
return
align
return
align
,
per_asym_residue_index
def
forward
(
self
,
out
,
batch
,
_return_breakdown
=
False
,
permutate_chains
=
True
):
def
forward
(
self
,
out
,
batch
,
_return_breakdown
=
False
,
permutate_chains
=
True
):
"""
"""
...
@@ -2209,18 +2213,20 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
...
@@ -2209,18 +2213,20 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
features
,
ground_truth
=
batch
features
,
ground_truth
=
batch
del
batch
del
batch
is_monomer
=
len
(
torch
.
unique
(
features
[
'asym_id'
]))
==
1
or
torch
.
unique
(
features
[
'asym_id'
]).
tolist
()
==
[
0
,
1
]
is_monomer
=
len
(
torch
.
unique
(
features
[
'asym_id'
]))
==
1
or
torch
.
unique
(
features
[
'asym_id'
]).
tolist
()
==
[
0
,
1
]
print
(
f
"asym_id is
{
features
[
'asym_id'
]
}
"
)
if
not
is_monomer
:
if
not
is_monomer
:
permutate_chains
=
True
permutate_chains
=
True
# Then permutate ground truth chains before calculating the loss
# Then permutate ground truth chains before calculating the loss
align
=
AlphaFoldMultimerLoss
.
multi_chain_perm_align
(
out
,
(
features
,
ground_truth
),
permutate_chains
=
permutate_chains
)
align
,
per_asym_residue_index
=
AlphaFoldMultimerLoss
.
multi_chain_perm_align
(
out
,
(
features
,
ground_truth
),
permutate_chains
=
permutate_chains
)
labels
=
AlphaFoldMultimerLoss
.
split_ground_truth_labels
(
ground_truth
,
REQUIRED_FEATURES
=
[
i
for
i
in
ground_truth
.
keys
()])
labels
=
AlphaFoldMultimerLoss
.
split_ground_truth_labels
(
features
,
REQUIRED_FEATURES
=
[
i
for
i
in
features
.
keys
()
if
i
in
dim_dict
])
# reorder ground truth labels according to permutation results
# reorder ground truth labels according to permutation results
labels
=
merge_labels
(
labels
,
align
,
labels
=
merge_labels
(
per_asym_residue_index
,
labels
,
align
,
original_nres
=
features
[
'aatype'
].
shape
[
-
1
])
original_nres
=
features
[
'aatype'
].
shape
[
-
1
])
features
.
update
(
labels
)
features
.
update
(
labels
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment