Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
15f1fa63
"container/git@developer.sourcefind.cn:OpenDAS/dynamo.git" did not exist on "f122aa4ec1ce10f10919e608572a7e12f24243aa"
Commit
15f1fa63
authored
Sep 21, 2023
by
Geoffrey Yu
Browse files
cleaned up and split into smaller functions
parent
ea7fcced
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
95 additions
and
61 deletions
+95
-61
openfold/utils/loss.py
openfold/utils/loss.py
+95
-61
No files found.
openfold/utils/loss.py
View file @
15f1fa63
...
@@ -1830,7 +1830,6 @@ def get_least_asym_entity_or_longest_length(batch):
...
@@ -1830,7 +1830,6 @@ def get_least_asym_entity_or_longest_length(batch):
def
greedy_align
(
def
greedy_align
(
batch
,
batch
,
unique_asym_ids
,
entity_2_asym_list
,
entity_2_asym_list
,
pred_ca_pos
,
pred_ca_pos
,
pred_ca_mask
,
pred_ca_mask
,
...
@@ -1841,6 +1840,7 @@ def greedy_align(
...
@@ -1841,6 +1840,7 @@ def greedy_align(
Implement Algorithm 4 in the Supplementary Information of AlphaFold-Multimer paper:
Implement Algorithm 4 in the Supplementary Information of AlphaFold-Multimer paper:
Evans,R et al., 2022 Protein complex prediction with AlphaFold-Multimer, bioRxiv 2021.10.04.463034; doi: https://doi.org/10.1101/2021.10.04.463034
Evans,R et al., 2022 Protein complex prediction with AlphaFold-Multimer, bioRxiv 2021.10.04.463034; doi: https://doi.org/10.1101/2021.10.04.463034
"""
"""
unique_asym_ids
=
[
i
for
i
in
torch
.
unique
(
batch
[
"asym_id"
])
if
i
!=
0
]
used
=
[
False
for
_
in
range
(
len
(
true_ca_poses
))]
used
=
[
False
for
_
in
range
(
len
(
true_ca_poses
))]
align
=
[]
align
=
[]
for
cur_asym_id
in
unique_asym_ids
:
for
cur_asym_id
in
unique_asym_ids
:
...
@@ -1860,17 +1860,13 @@ def greedy_align(
...
@@ -1860,17 +1860,13 @@ def greedy_align(
if
cropped_pos
.
shape
==
cur_pred_pos
.
shape
:
if
cropped_pos
.
shape
==
cur_pred_pos
.
shape
:
mask
=
true_ca_masks
[
j
]
mask
=
true_ca_masks
[
j
]
mask
=
torch
.
squeeze
(
mask
,
0
)
mask
=
torch
.
squeeze
(
mask
,
0
)
print
(
f
"cropped_pos shape:
{
cropped_pos
.
shape
}
cur_pred_pos shape:
{
cur_pred_pos
.
shape
}
"
)
print
(
f
"mask shape:
{
mask
.
shape
}
and cur_pred_mask shape:
{
cur_pred_mask
.
shape
}
"
)
rmsd
=
compute_rmsd
(
rmsd
=
compute_rmsd
(
torch
.
squeeze
(
cropped_pos
,
0
),
torch
.
squeeze
(
cur_pred_pos
,
0
),
torch
.
squeeze
(
cropped_pos
,
0
),
torch
.
squeeze
(
cur_pred_pos
,
0
),
(
cur_pred_mask
*
mask
).
bool
()
(
cur_pred_mask
*
mask
).
bool
()
)
)
print
(
f
"rmsd is
{
rmsd
}
"
)
if
(
rmsd
is
not
None
)
and
(
rmsd
<
best_rmsd
):
if
(
rmsd
is
not
None
)
and
(
rmsd
<
best_rmsd
):
best_rmsd
=
rmsd
best_rmsd
=
rmsd
best_idx
=
j
best_idx
=
j
print
(
f
"best_idx is
{
best_idx
}
"
)
assert
best_idx
is
not
None
assert
best_idx
is
not
None
used
[
best_idx
]
=
True
used
[
best_idx
]
=
True
...
@@ -1887,9 +1883,10 @@ def pad_features(feature_tensor,nres_pad,pad_dim):
...
@@ -1887,9 +1883,10 @@ def pad_features(feature_tensor,nres_pad,pad_dim):
return
torch
.
concat
((
feature_tensor
,
padding_tensor
),
dim
=
pad_dim
)
return
torch
.
concat
((
feature_tensor
,
padding_tensor
),
dim
=
pad_dim
)
def
merge_labels
(
per_asym_residue_index
,
labels
,
align
,
original_nres
):
def
merge_labels
(
labels
,
align
,
original_nres
):
"""
"""
per_asym_residue_index: A dictionary that record which asym_id corresponds to which regions of residues in the multimer complex.
Merge ground truth labels according to the permutation results
labels: list of original ground truth feats
labels: list of original ground truth feats
align: list of tuples, each entry specify the corresponding label of the asym.
align: list of tuples, each entry specify the corresponding label of the asym.
...
@@ -2051,7 +2048,8 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
...
@@ -2051,7 +2048,8 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
"""
"""
Splits ground truth features according to chains
Splits ground truth features according to chains
Returns a list of feature dictionaries with only necessary ground truth features
Returns:
a list of feature dictionaries with only necessary ground truth features
required to finish multi-chain permutation
required to finish multi-chain permutation
"""
"""
unique_asym_ids
,
asym_id_counts
=
torch
.
unique
(
batch
[
"asym_id"
],
sorted
=
False
,
return_counts
=
True
)
unique_asym_ids
,
asym_id_counts
=
torch
.
unique
(
batch
[
"asym_id"
],
sorted
=
False
,
return_counts
=
True
)
...
@@ -2066,6 +2064,70 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
...
@@ -2066,6 +2064,70 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
labels
=
list
(
map
(
dict
,
zip
(
*
[[(
k
,
v
)
for
v
in
torch
.
split
(
value
,
asym_id_counts
,
dim
=
dim_dict
[
k
])]
for
k
,
value
in
batch
.
items
()
if
k
in
REQUIRED_FEATURES
])))
labels
=
list
(
map
(
dict
,
zip
(
*
[[(
k
,
v
)
for
v
in
torch
.
split
(
value
,
asym_id_counts
,
dim
=
dim_dict
[
k
])]
for
k
,
value
in
batch
.
items
()
if
k
in
REQUIRED_FEATURES
])))
return
labels
return
labels
@
staticmethod
def
get_entity_2_asym_list
(
batch
):
"""
Generates a dictionary mapping unique entity IDs to lists of unique asymmetry IDs (asym_id) for each entity.
Args:
batch (dict): A dictionary containing data batches, including "entity_id" and "asym_id" tensors.
Returns:
entity_2_asym_list (dict): A dictionary where keys are unique entity IDs, and values are lists of unique asymmetry IDs
associated with each entity.
"""
entity_2_asym_list
=
{}
unique_entity_ids
=
torch
.
unique
(
batch
[
"entity_id"
])
for
cur_ent_id
in
unique_entity_ids
:
ent_mask
=
batch
[
"entity_id"
]
==
cur_ent_id
cur_asym_id
=
torch
.
unique
(
batch
[
"asym_id"
][
ent_mask
])
entity_2_asym_list
[
int
(
cur_ent_id
)]
=
cur_asym_id
return
entity_2_asym_list
@
staticmethod
def
calculate_input_mask
(
true_ca_masks
,
anchor_gt_idx
,
asym_mask
,
pred_ca_mask
):
"""
Calculate an input mask for downstream optimal transformation computation
Args:
true_ca_masks (Tensor): ca mask from ground truth.
anchor_gt_idx (Tensor): The index of selected ground truth anchor.
asym_mask (Tensor): Boolean tensor indicating which regions are selected predicted anchor.
pred_ca_mask (Tensor): ca mask from predicted structure.
Returns:
input_mask (Tensor): A boolean mask
"""
pred_ca_mask
=
torch
.
squeeze
(
pred_ca_mask
,
0
)
asym_mask
=
torch
.
squeeze
(
asym_mask
,
0
)
anchor_pred_mask
=
pred_ca_mask
[
asym_mask
]
anchor_true_mask
=
true_ca_masks
[
anchor_gt_idx
]
input_mask
=
(
anchor_true_mask
*
anchor_pred_mask
).
bool
()
return
input_mask
@
staticmethod
def
calculate_optimal_transform
(
true_ca_poses
,
anchor_gt_idx
,
true_ca_masks
,
pred_ca_mask
,
asym_mask
,
pred_ca_pos
):
input_mask
=
AlphaFoldMultimerLoss
.
calculate_input_mask
(
true_ca_masks
,
anchor_gt_idx
,
asym_mask
,
pred_ca_mask
)
input_mask
=
torch
.
squeeze
(
input_mask
,
0
)
pred_ca_pos
=
torch
.
squeeze
(
pred_ca_pos
,
0
)
asym_mask
=
torch
.
squeeze
(
asym_mask
,
0
)
anchor_true_pos
=
true_ca_poses
[
anchor_gt_idx
]
anchor_pred_pos
=
pred_ca_pos
[
asym_mask
]
r
,
x
=
get_optimal_transform
(
anchor_pred_pos
,
torch
.
squeeze
(
anchor_true_pos
,
0
),
mask
=
input_mask
)
return
r
,
x
@
staticmethod
@
staticmethod
def
multi_chain_perm_align
(
out
,
batch
,
dim_dict
,
permutate_chains
=
False
):
def
multi_chain_perm_align
(
out
,
batch
,
dim_dict
,
permutate_chains
=
False
):
"""
"""
...
@@ -2078,63 +2140,38 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
...
@@ -2078,63 +2140,38 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
labels
=
AlphaFoldMultimerLoss
.
split_ground_truth_labels
(
batch
,
dim_dict
=
dim_dict
,
labels
=
AlphaFoldMultimerLoss
.
split_ground_truth_labels
(
batch
,
dim_dict
=
dim_dict
,
REQUIRED_FEATURES
=
[
"all_atom_mask"
,
"all_atom_positions"
])
REQUIRED_FEATURES
=
[
"all_atom_mask"
,
"all_atom_positions"
])
assert
isinstance
(
labels
,
list
)
assert
isinstance
(
labels
,
list
)
ca_idx
=
rc
.
atom_order
[
"CA"
]
pred_ca_pos
=
out
[
"final_atom_positions"
][...,
ca_idx
,
:]
# [bsz, nres, 3]
pred_ca_mask
=
out
[
"final_atom_mask"
][...,
ca_idx
].
to
(
dtype
=
pred_ca_pos
.
dtype
)
# [bsz, nres]
true_ca_poses
=
[
l
[
"all_atom_positions"
][...,
ca_idx
,
:]
for
l
in
labels
]
# list([nres, 3])
true_ca_masks
=
[
l
[
"all_atom_mask"
][...,
ca_idx
].
long
()
for
l
in
labels
]
# list([nres,])
unique_asym_ids
=
[
i
for
i
in
torch
.
unique
(
batch
[
"asym_id"
])
if
i
!=
0
]
per_asym_residue_index
=
{}
for
cur_asym_id
in
unique_asym_ids
:
asym_mask
=
(
batch
[
"asym_id"
]
==
cur_asym_id
).
bool
()
per_asym_residue_index
[
int
(
cur_asym_id
)]
=
torch
.
masked_select
(
batch
[
"residue_index"
],
asym_mask
)
if
permutate_chains
:
if
permutate_chains
:
# First select anchors from predicted structures and ground truths
anchor_gt_asym
,
anchor_pred_asym
=
get_least_asym_entity_or_longest_length
(
batch
)
anchor_gt_asym
,
anchor_pred_asym
=
get_least_asym_entity_or_longest_length
(
batch
)
print
(
f
"anchor_gt_asym:
{
anchor_gt_asym
}
anchor_pred_asym:
{
anchor_pred_asym
}
"
)
anchor_gt_idx
=
int
(
anchor_gt_asym
)
-
1
anchor_gt_idx
=
int
(
anchor_gt_asym
)
-
1
unique_entity_ids
=
torch
.
unique
(
batch
[
"entity_id"
])
entity_2_asym_list
=
{}
for
cur_ent_id
in
unique_entity_ids
:
ent_mask
=
batch
[
"entity_id"
]
==
cur_ent_id
cur_asym_id
=
torch
.
unique
(
batch
[
"asym_id"
][
ent_mask
])
entity_2_asym_list
[
int
(
cur_ent_id
)]
=
cur_asym_id
asym_mask
=
(
batch
[
"asym_id"
]
==
anchor_pred_asym
).
bool
()
asym_mask
=
(
batch
[
"asym_id"
]
==
anchor_pred_asym
).
bool
()
anchor_true_pos
=
true_ca_poses
[
anchor_gt_idx
]
# Then calculate optimal transform by aligning anchors
ca_idx
=
rc
.
atom_order
[
"CA"
]
anchor_pred_pos
=
pred_ca_pos
[
0
][
asym_mask
[
0
]]
pred_ca_pos
=
out
[
"final_atom_positions"
][...,
ca_idx
,
:]
# [bsz, nres, 3]
pred_ca_mask
=
out
[
"final_atom_mask"
][...,
ca_idx
].
to
(
dtype
=
pred_ca_pos
.
dtype
)
# [bsz, nres]
anchor_true_mask
=
true_ca_masks
[
anchor_gt_idx
]
anchor_pred_mask
=
pred_ca_mask
[
0
][
asym_mask
[
0
]]
true_ca_poses
=
[
l
[
"all_atom_positions"
][...,
ca_idx
,
:]
for
l
in
labels
input_mask
=
(
anchor_true_mask
*
anchor_pred_mask
).
bool
()
]
# list([nres, 3])
r
,
x
=
get_optimal_transform
(
true_ca_masks
=
[
anchor_pred_pos
,
anchor_true_pos
[
0
],
l
[
"all_atom_mask"
][...,
ca_idx
].
long
()
for
l
in
labels
mask
=
input_mask
[
0
]
]
# list([nres,])
r
,
x
=
AlphaFoldMultimerLoss
.
calculate_optimal_transform
(
true_ca_poses
,
anchor_gt_idx
,
true_ca_masks
,
pred_ca_mask
,
asym_mask
,
pred_ca_pos
)
)
del
input_mask
# just to save memory
del
anchor_pred_mask
del
anchor_true_mask
gc
.
collect
()
aligned_true_ca_poses
=
[
ca
.
to
(
r
.
dtype
)
@
r
+
x
for
ca
in
true_ca_poses
]
# apply transforms
aligned_true_ca_poses
=
[
ca
.
to
(
r
.
dtype
)
@
r
+
x
for
ca
in
true_ca_poses
]
# apply transforms
del
true_ca_poses
del
true_ca_poses
,
r
,
x
gc
.
collect
()
gc
.
collect
()
entity_2_asym_list
=
AlphaFoldMultimerLoss
.
get_entity_2_asym_list
(
batch
)
align
=
greedy_align
(
align
=
greedy_align
(
batch
,
batch
,
unique_asym_ids
,
entity_2_asym_list
,
entity_2_asym_list
,
pred_ca_pos
,
pred_ca_pos
,
pred_ca_mask
,
pred_ca_mask
,
...
@@ -2142,16 +2179,13 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
...
@@ -2142,16 +2179,13 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
true_ca_masks
,
true_ca_masks
,
)
)
del
aligned_true_ca_poses
,
true_ca_masks
del
true_ca_masks
,
aligned_true_ca_poses
del
r
,
x
del
pred_ca_pos
,
pred_ca_mask
del
pred_ca_pos
,
pred_ca_mask
del
anchor_pred_pos
,
anchor_true_pos
gc
.
collect
()
gc
.
collect
()
print
(
f
"finished multi-chain permutation and final align is
{
align
}
"
)
else
:
else
:
align
=
list
(
enumerate
(
range
(
len
(
labels
))))
align
=
list
(
enumerate
(
range
(
len
(
labels
))))
return
align
,
per_asym_residue_index
return
align
def
forward
(
self
,
out
,
features
,
_return_breakdown
=
False
,
permutate_chains
=
True
):
def
forward
(
self
,
out
,
features
,
_return_breakdown
=
False
,
permutate_chains
=
True
):
"""
"""
...
@@ -2170,13 +2204,13 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
...
@@ -2170,13 +2204,13 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
dim_dict
=
AlphaFoldMultimerLoss
.
determine_split_dim
(
features
)
dim_dict
=
AlphaFoldMultimerLoss
.
determine_split_dim
(
features
)
# Then permutate ground truth chains before calculating the loss
# Then permutate ground truth chains before calculating the loss
align
,
per_asym_residue_index
=
AlphaFoldMultimerLoss
.
multi_chain_perm_align
(
out
,
features
,
dim_dict
=
dim_dict
,
align
=
AlphaFoldMultimerLoss
.
multi_chain_perm_align
(
out
,
features
,
dim_dict
=
dim_dict
,
permutate_chains
=
permutate_chains
)
permutate_chains
=
permutate_chains
)
labels
=
AlphaFoldMultimerLoss
.
split_ground_truth_labels
(
features
,
dim_dict
=
dim_dict
,
labels
=
AlphaFoldMultimerLoss
.
split_ground_truth_labels
(
features
,
dim_dict
=
dim_dict
,
REQUIRED_FEATURES
=
[
i
for
i
in
features
.
keys
()
if
i
in
dim_dict
])
REQUIRED_FEATURES
=
[
i
for
i
in
features
.
keys
()
if
i
in
dim_dict
])
# reorder ground truth labels according to permutation results
# reorder ground truth labels according to permutation results
labels
=
merge_labels
(
per_asym_residue_index
,
labels
,
align
,
labels
=
merge_labels
(
labels
,
align
,
original_nres
=
features
[
'aatype'
].
shape
[
-
1
])
original_nres
=
features
[
'aatype'
].
shape
[
-
1
])
features
.
update
(
labels
)
features
.
update
(
labels
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment