Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
15f1fa63
Commit
15f1fa63
authored
Sep 21, 2023
by
Geoffrey Yu
Browse files
cleaned up and split into smaller functions
parent
ea7fcced
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
95 additions
and
61 deletions
+95
-61
openfold/utils/loss.py
openfold/utils/loss.py
+95
-61
No files found.
openfold/utils/loss.py
View file @
15f1fa63
...
@@ -1830,7 +1830,6 @@ def get_least_asym_entity_or_longest_length(batch):
...
@@ -1830,7 +1830,6 @@ def get_least_asym_entity_or_longest_length(batch):
def
greedy_align
(
def
greedy_align
(
batch
,
batch
,
unique_asym_ids
,
entity_2_asym_list
,
entity_2_asym_list
,
pred_ca_pos
,
pred_ca_pos
,
pred_ca_mask
,
pred_ca_mask
,
...
@@ -1841,6 +1840,7 @@ def greedy_align(
...
@@ -1841,6 +1840,7 @@ def greedy_align(
Implement Algorithm 4 in the Supplementary Information of AlphaFold-Multimer paper:
Implement Algorithm 4 in the Supplementary Information of AlphaFold-Multimer paper:
Evans,R et al., 2022 Protein complex prediction with AlphaFold-Multimer, bioRxiv 2021.10.04.463034; doi: https://doi.org/10.1101/2021.10.04.463034
Evans,R et al., 2022 Protein complex prediction with AlphaFold-Multimer, bioRxiv 2021.10.04.463034; doi: https://doi.org/10.1101/2021.10.04.463034
"""
"""
unique_asym_ids
=
[
i
for
i
in
torch
.
unique
(
batch
[
"asym_id"
])
if
i
!=
0
]
used
=
[
False
for
_
in
range
(
len
(
true_ca_poses
))]
used
=
[
False
for
_
in
range
(
len
(
true_ca_poses
))]
align
=
[]
align
=
[]
for
cur_asym_id
in
unique_asym_ids
:
for
cur_asym_id
in
unique_asym_ids
:
...
@@ -1860,17 +1860,13 @@ def greedy_align(
...
@@ -1860,17 +1860,13 @@ def greedy_align(
if
cropped_pos
.
shape
==
cur_pred_pos
.
shape
:
if
cropped_pos
.
shape
==
cur_pred_pos
.
shape
:
mask
=
true_ca_masks
[
j
]
mask
=
true_ca_masks
[
j
]
mask
=
torch
.
squeeze
(
mask
,
0
)
mask
=
torch
.
squeeze
(
mask
,
0
)
print
(
f
"cropped_pos shape:
{
cropped_pos
.
shape
}
cur_pred_pos shape:
{
cur_pred_pos
.
shape
}
"
)
print
(
f
"mask shape:
{
mask
.
shape
}
and cur_pred_mask shape:
{
cur_pred_mask
.
shape
}
"
)
rmsd
=
compute_rmsd
(
rmsd
=
compute_rmsd
(
torch
.
squeeze
(
cropped_pos
,
0
),
torch
.
squeeze
(
cur_pred_pos
,
0
),
torch
.
squeeze
(
cropped_pos
,
0
),
torch
.
squeeze
(
cur_pred_pos
,
0
),
(
cur_pred_mask
*
mask
).
bool
()
(
cur_pred_mask
*
mask
).
bool
()
)
)
print
(
f
"rmsd is
{
rmsd
}
"
)
if
(
rmsd
is
not
None
)
and
(
rmsd
<
best_rmsd
):
if
(
rmsd
is
not
None
)
and
(
rmsd
<
best_rmsd
):
best_rmsd
=
rmsd
best_rmsd
=
rmsd
best_idx
=
j
best_idx
=
j
print
(
f
"best_idx is
{
best_idx
}
"
)
assert
best_idx
is
not
None
assert
best_idx
is
not
None
used
[
best_idx
]
=
True
used
[
best_idx
]
=
True
...
@@ -1887,9 +1883,10 @@ def pad_features(feature_tensor,nres_pad,pad_dim):
...
@@ -1887,9 +1883,10 @@ def pad_features(feature_tensor,nres_pad,pad_dim):
return
torch
.
concat
((
feature_tensor
,
padding_tensor
),
dim
=
pad_dim
)
return
torch
.
concat
((
feature_tensor
,
padding_tensor
),
dim
=
pad_dim
)
def
merge_labels
(
per_asym_residue_index
,
labels
,
align
,
original_nres
):
def
merge_labels
(
labels
,
align
,
original_nres
):
"""
"""
per_asym_residue_index: A dictionary that record which asym_id corresponds to which regions of residues in the multimer complex.
Merge ground truth labels according to the permutation results
labels: list of original ground truth feats
labels: list of original ground truth feats
align: list of tuples, each entry specify the corresponding label of the asym.
align: list of tuples, each entry specify the corresponding label of the asym.
...
@@ -2051,7 +2048,8 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
...
@@ -2051,7 +2048,8 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
"""
"""
Splits ground truth features according to chains
Splits ground truth features according to chains
Returns a list of feature dictionaries with only necessary ground truth features
Returns:
a list of feature dictionaries with only necessary ground truth features
required to finish multi-chain permutation
required to finish multi-chain permutation
"""
"""
unique_asym_ids
,
asym_id_counts
=
torch
.
unique
(
batch
[
"asym_id"
],
sorted
=
False
,
return_counts
=
True
)
unique_asym_ids
,
asym_id_counts
=
torch
.
unique
(
batch
[
"asym_id"
],
sorted
=
False
,
return_counts
=
True
)
...
@@ -2066,6 +2064,70 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
...
@@ -2066,6 +2064,70 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
labels
=
list
(
map
(
dict
,
zip
(
*
[[(
k
,
v
)
for
v
in
torch
.
split
(
value
,
asym_id_counts
,
dim
=
dim_dict
[
k
])]
for
k
,
value
in
batch
.
items
()
if
k
in
REQUIRED_FEATURES
])))
labels
=
list
(
map
(
dict
,
zip
(
*
[[(
k
,
v
)
for
v
in
torch
.
split
(
value
,
asym_id_counts
,
dim
=
dim_dict
[
k
])]
for
k
,
value
in
batch
.
items
()
if
k
in
REQUIRED_FEATURES
])))
return
labels
return
labels
@
staticmethod
def
get_entity_2_asym_list
(
batch
):
"""
Generates a dictionary mapping unique entity IDs to lists of unique asymmetry IDs (asym_id) for each entity.
Args:
batch (dict): A dictionary containing data batches, including "entity_id" and "asym_id" tensors.
Returns:
entity_2_asym_list (dict): A dictionary where keys are unique entity IDs, and values are lists of unique asymmetry IDs
associated with each entity.
"""
entity_2_asym_list
=
{}
unique_entity_ids
=
torch
.
unique
(
batch
[
"entity_id"
])
for
cur_ent_id
in
unique_entity_ids
:
ent_mask
=
batch
[
"entity_id"
]
==
cur_ent_id
cur_asym_id
=
torch
.
unique
(
batch
[
"asym_id"
][
ent_mask
])
entity_2_asym_list
[
int
(
cur_ent_id
)]
=
cur_asym_id
return
entity_2_asym_list
@
staticmethod
def
calculate_input_mask
(
true_ca_masks
,
anchor_gt_idx
,
asym_mask
,
pred_ca_mask
):
"""
Calculate an input mask for downstream optimal transformation computation
Args:
true_ca_masks (Tensor): ca mask from ground truth.
anchor_gt_idx (Tensor): The index of selected ground truth anchor.
asym_mask (Tensor): Boolean tensor indicating which regions are selected predicted anchor.
pred_ca_mask (Tensor): ca mask from predicted structure.
Returns:
input_mask (Tensor): A boolean mask
"""
pred_ca_mask
=
torch
.
squeeze
(
pred_ca_mask
,
0
)
asym_mask
=
torch
.
squeeze
(
asym_mask
,
0
)
anchor_pred_mask
=
pred_ca_mask
[
asym_mask
]
anchor_true_mask
=
true_ca_masks
[
anchor_gt_idx
]
input_mask
=
(
anchor_true_mask
*
anchor_pred_mask
).
bool
()
return
input_mask
@
staticmethod
def
calculate_optimal_transform
(
true_ca_poses
,
anchor_gt_idx
,
true_ca_masks
,
pred_ca_mask
,
asym_mask
,
pred_ca_pos
):
input_mask
=
AlphaFoldMultimerLoss
.
calculate_input_mask
(
true_ca_masks
,
anchor_gt_idx
,
asym_mask
,
pred_ca_mask
)
input_mask
=
torch
.
squeeze
(
input_mask
,
0
)
pred_ca_pos
=
torch
.
squeeze
(
pred_ca_pos
,
0
)
asym_mask
=
torch
.
squeeze
(
asym_mask
,
0
)
anchor_true_pos
=
true_ca_poses
[
anchor_gt_idx
]
anchor_pred_pos
=
pred_ca_pos
[
asym_mask
]
r
,
x
=
get_optimal_transform
(
anchor_pred_pos
,
torch
.
squeeze
(
anchor_true_pos
,
0
),
mask
=
input_mask
)
return
r
,
x
@
staticmethod
@
staticmethod
def
multi_chain_perm_align
(
out
,
batch
,
dim_dict
,
permutate_chains
=
False
):
def
multi_chain_perm_align
(
out
,
batch
,
dim_dict
,
permutate_chains
=
False
):
"""
"""
...
@@ -2078,6 +2140,15 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
...
@@ -2078,6 +2140,15 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
labels
=
AlphaFoldMultimerLoss
.
split_ground_truth_labels
(
batch
,
dim_dict
=
dim_dict
,
labels
=
AlphaFoldMultimerLoss
.
split_ground_truth_labels
(
batch
,
dim_dict
=
dim_dict
,
REQUIRED_FEATURES
=
[
"all_atom_mask"
,
"all_atom_positions"
])
REQUIRED_FEATURES
=
[
"all_atom_mask"
,
"all_atom_positions"
])
assert
isinstance
(
labels
,
list
)
assert
isinstance
(
labels
,
list
)
if
permutate_chains
:
# First select anchors from predicted structures and ground truths
anchor_gt_asym
,
anchor_pred_asym
=
get_least_asym_entity_or_longest_length
(
batch
)
anchor_gt_idx
=
int
(
anchor_gt_asym
)
-
1
asym_mask
=
(
batch
[
"asym_id"
]
==
anchor_pred_asym
).
bool
()
# Then calculate optimal transform by aligning anchors
ca_idx
=
rc
.
atom_order
[
"CA"
]
ca_idx
=
rc
.
atom_order
[
"CA"
]
pred_ca_pos
=
out
[
"final_atom_positions"
][...,
ca_idx
,
:]
# [bsz, nres, 3]
pred_ca_pos
=
out
[
"final_atom_positions"
][...,
ca_idx
,
:]
# [bsz, nres, 3]
pred_ca_mask
=
out
[
"final_atom_mask"
][...,
ca_idx
].
to
(
dtype
=
pred_ca_pos
.
dtype
)
# [bsz, nres]
pred_ca_mask
=
out
[
"final_atom_mask"
][...,
ca_idx
].
to
(
dtype
=
pred_ca_pos
.
dtype
)
# [bsz, nres]
...
@@ -2085,56 +2156,22 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
...
@@ -2085,56 +2156,22 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
true_ca_poses
=
[
true_ca_poses
=
[
l
[
"all_atom_positions"
][...,
ca_idx
,
:]
for
l
in
labels
l
[
"all_atom_positions"
][...,
ca_idx
,
:]
for
l
in
labels
]
# list([nres, 3])
]
# list([nres, 3])
true_ca_masks
=
[
true_ca_masks
=
[
l
[
"all_atom_mask"
][...,
ca_idx
].
long
()
for
l
in
labels
l
[
"all_atom_mask"
][...,
ca_idx
].
long
()
for
l
in
labels
]
# list([nres,])
]
# list([nres,])
r
,
x
=
AlphaFoldMultimerLoss
.
calculate_optimal_transform
(
true_ca_poses
,
anchor_gt_idx
,
unique_asym_ids
=
[
i
for
i
in
torch
.
unique
(
batch
[
"asym_id"
])
if
i
!=
0
]
true_ca_masks
,
pred_ca_mask
,
asym_mask
,
per_asym_residue_index
=
{}
pred_ca_pos
for
cur_asym_id
in
unique_asym_ids
:
asym_mask
=
(
batch
[
"asym_id"
]
==
cur_asym_id
).
bool
()
per_asym_residue_index
[
int
(
cur_asym_id
)]
=
torch
.
masked_select
(
batch
[
"residue_index"
],
asym_mask
)
if
permutate_chains
:
anchor_gt_asym
,
anchor_pred_asym
=
get_least_asym_entity_or_longest_length
(
batch
)
print
(
f
"anchor_gt_asym:
{
anchor_gt_asym
}
anchor_pred_asym:
{
anchor_pred_asym
}
"
)
anchor_gt_idx
=
int
(
anchor_gt_asym
)
-
1
unique_entity_ids
=
torch
.
unique
(
batch
[
"entity_id"
])
entity_2_asym_list
=
{}
for
cur_ent_id
in
unique_entity_ids
:
ent_mask
=
batch
[
"entity_id"
]
==
cur_ent_id
cur_asym_id
=
torch
.
unique
(
batch
[
"asym_id"
][
ent_mask
])
entity_2_asym_list
[
int
(
cur_ent_id
)]
=
cur_asym_id
asym_mask
=
(
batch
[
"asym_id"
]
==
anchor_pred_asym
).
bool
()
anchor_true_pos
=
true_ca_poses
[
anchor_gt_idx
]
anchor_pred_pos
=
pred_ca_pos
[
0
][
asym_mask
[
0
]]
anchor_true_mask
=
true_ca_masks
[
anchor_gt_idx
]
anchor_pred_mask
=
pred_ca_mask
[
0
][
asym_mask
[
0
]]
input_mask
=
(
anchor_true_mask
*
anchor_pred_mask
).
bool
()
r
,
x
=
get_optimal_transform
(
anchor_pred_pos
,
anchor_true_pos
[
0
],
mask
=
input_mask
[
0
]
)
)
del
input_mask
# just to save memory
del
anchor_pred_mask
del
anchor_true_mask
gc
.
collect
()
aligned_true_ca_poses
=
[
ca
.
to
(
r
.
dtype
)
@
r
+
x
for
ca
in
true_ca_poses
]
# apply transforms
aligned_true_ca_poses
=
[
ca
.
to
(
r
.
dtype
)
@
r
+
x
for
ca
in
true_ca_poses
]
# apply transforms
del
true_ca_poses
del
true_ca_poses
,
r
,
x
gc
.
collect
()
gc
.
collect
()
entity_2_asym_list
=
AlphaFoldMultimerLoss
.
get_entity_2_asym_list
(
batch
)
align
=
greedy_align
(
align
=
greedy_align
(
batch
,
batch
,
unique_asym_ids
,
entity_2_asym_list
,
entity_2_asym_list
,
pred_ca_pos
,
pred_ca_pos
,
pred_ca_mask
,
pred_ca_mask
,
...
@@ -2142,16 +2179,13 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
...
@@ -2142,16 +2179,13 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
true_ca_masks
,
true_ca_masks
,
)
)
del
aligned_true_ca_poses
,
true_ca_masks
del
true_ca_masks
,
aligned_true_ca_poses
del
r
,
x
del
pred_ca_pos
,
pred_ca_mask
del
pred_ca_pos
,
pred_ca_mask
del
anchor_pred_pos
,
anchor_true_pos
gc
.
collect
()
gc
.
collect
()
print
(
f
"finished multi-chain permutation and final align is
{
align
}
"
)
else
:
else
:
align
=
list
(
enumerate
(
range
(
len
(
labels
))))
align
=
list
(
enumerate
(
range
(
len
(
labels
))))
return
align
,
per_asym_residue_index
return
align
def
forward
(
self
,
out
,
features
,
_return_breakdown
=
False
,
permutate_chains
=
True
):
def
forward
(
self
,
out
,
features
,
_return_breakdown
=
False
,
permutate_chains
=
True
):
"""
"""
...
@@ -2170,13 +2204,13 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
...
@@ -2170,13 +2204,13 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
dim_dict
=
AlphaFoldMultimerLoss
.
determine_split_dim
(
features
)
dim_dict
=
AlphaFoldMultimerLoss
.
determine_split_dim
(
features
)
# Then permutate ground truth chains before calculating the loss
# Then permutate ground truth chains before calculating the loss
align
,
per_asym_residue_index
=
AlphaFoldMultimerLoss
.
multi_chain_perm_align
(
out
,
features
,
dim_dict
=
dim_dict
,
align
=
AlphaFoldMultimerLoss
.
multi_chain_perm_align
(
out
,
features
,
dim_dict
=
dim_dict
,
permutate_chains
=
permutate_chains
)
permutate_chains
=
permutate_chains
)
labels
=
AlphaFoldMultimerLoss
.
split_ground_truth_labels
(
features
,
dim_dict
=
dim_dict
,
labels
=
AlphaFoldMultimerLoss
.
split_ground_truth_labels
(
features
,
dim_dict
=
dim_dict
,
REQUIRED_FEATURES
=
[
i
for
i
in
features
.
keys
()
if
i
in
dim_dict
])
REQUIRED_FEATURES
=
[
i
for
i
in
features
.
keys
()
if
i
in
dim_dict
])
# reorder ground truth labels according to permutation results
# reorder ground truth labels according to permutation results
labels
=
merge_labels
(
per_asym_residue_index
,
labels
,
align
,
labels
=
merge_labels
(
labels
,
align
,
original_nres
=
features
[
'aatype'
].
shape
[
-
1
])
original_nres
=
features
[
'aatype'
].
shape
[
-
1
])
features
.
update
(
labels
)
features
.
update
(
labels
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment