Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
67f873e7
Commit
67f873e7
authored
Sep 24, 2023
by
Geoffrey Yu
Browse files
updated optimal transform function now
parent
bd82338e
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
46 additions
and
39 deletions
+46
-39
openfold/utils/loss.py
openfold/utils/loss.py
+46
-39
No files found.
openfold/utils/loss.py
View file @
67f873e7
...
...
@@ -1789,12 +1789,6 @@ def get_least_asym_entity_or_longest_length(batch):
If there is a tie, e.g. AABB, first check which sequence is the longer/longest,
then choose one of the corresponding subunits as anchor
"""
REQUIRED_FEATURES
=
[
'entity_id'
,
'asym_id'
]
seq_length
=
batch
[
'seq_length'
].
item
()
# remove padding part before selecting candidate
remove_padding
=
lambda
t
:
torch
.
index_select
(
t
,
dim
=
1
,
index
=
torch
.
arange
(
seq_length
,
device
=
t
.
device
))
batch
=
{
k
:
tensor_tree_map
(
remove_padding
,
batch
[
k
])
for
k
in
REQUIRED_FEATURES
}
unique_entity_ids
=
torch
.
unique
(
batch
[
"entity_id"
])
entity_asym_count
=
{}
entity_length
=
{}
...
...
@@ -1819,13 +1813,13 @@ def get_least_asym_entity_or_longest_length(batch):
if
len
(
least_asym_entities
)
>
1
:
least_asym_entities
=
random
.
choice
(
least_asym_entities
)
assert
len
(
least_asym_entities
)
==
1
#
best_pred_asym = torch.unique(batch["asym_id"][batch["entity_id"] == least_asym_entities[0]])
best_pred_asym
=
torch
.
unique
(
batch
[
"asym_id"
][
batch
[
"entity_id"
]
==
least_asym_entities
[
0
]])
# # If there is more than one chain in the predicted output that has the same sequence
# # as the chosen ground truth anchor, then randomly picke one
#
if len(best_pred_asym) > 1:
#
best_pred_asym = random.choice(best_pred_asym)
best_pred_asym
=
least_asym_entities
[
0
]
if
len
(
best_pred_asym
)
>
1
:
best_pred_asym
=
random
.
choice
(
best_pred_asym
)
return
least_asym_entities
[
0
],
best_pred_asym
...
...
@@ -2037,15 +2031,9 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
def
__init__
(
self
,
config
):
super
(
AlphaFoldMultimerLoss
,
self
).
__init__
(
config
)
self
.
config
=
config
@
staticmethod
def
determine_split_dim
(
batch
)
->
dict
:
"""A method to determine which dimension to split in split_ground_truth_labels"""
padded_dim
=
batch
[
'aatype'
].
shape
[
-
1
]
dim_dict
=
{
k
:
list
(
v
.
shape
).
index
(
padded_dim
)
for
k
,
v
in
batch
.
items
()
if
padded_dim
in
v
.
shape
}
return
dim_dict
@
staticmethod
def
split_ground_truth_labels
(
batch
,
REQUIRED_FEATURES
,
dim_dict
):
def
split_ground_truth_labels
(
batch
,
REQUIRED_FEATURES
,
split_dim
=
1
):
"""
Splits ground truth features according to chains
...
...
@@ -2062,9 +2050,19 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
unique_asym_ids
.
append
(
padding_asym_id
)
asym_id_counts
.
append
(
padding_asym_counts
)
labels
=
list
(
map
(
dict
,
zip
(
*
[[(
k
,
v
)
for
v
in
torch
.
split
(
value
,
asym_id_counts
,
dim
=
dim_dict
[
k
]
)]
for
k
,
value
in
batch
.
items
()
if
k
in
REQUIRED_FEATURES
])))
labels
=
list
(
map
(
dict
,
zip
(
*
[[(
k
,
v
)
for
v
in
torch
.
split
(
value
,
asym_id_counts
,
dim
=
split_dim
)]
for
k
,
value
in
batch
.
items
()
if
k
in
REQUIRED_FEATURES
])))
return
labels
@
staticmethod
def
get_per_asym_residue_index
(
features
):
unique_asym_ids
=
[
i
for
i
in
torch
.
unique
(
features
[
"asym_id"
])
if
i
!=
0
]
per_asym_residue_index
=
{}
for
cur_asym_id
in
unique_asym_ids
:
asym_mask
=
(
features
[
"asym_id"
]
==
cur_asym_id
).
bool
()
per_asym_residue_index
[
int
(
cur_asym_id
)]
=
torch
.
masked_select
(
features
[
"residue_index"
],
asym_mask
)
return
per_asym_residue_index
@
staticmethod
def
get_entity_2_asym_list
(
batch
):
"""
...
...
@@ -2086,7 +2084,7 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
return
entity_2_asym_list
@
staticmethod
def
calculate_input_mask
(
true_ca_masks
,
anchor_gt_idx
,
def
calculate_input_mask
(
true_ca_masks
,
anchor_gt_idx
,
anchor_gt_residue
,
asym_mask
,
pred_ca_mask
):
"""
Calculate an input mask for downstream optimal transformation computation
...
...
@@ -2103,24 +2101,24 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
pred_ca_mask
=
torch
.
squeeze
(
pred_ca_mask
,
0
)
asym_mask
=
torch
.
squeeze
(
asym_mask
,
0
)
anchor_pred_mask
=
pred_ca_mask
[
asym_mask
]
anchor_true_mask
=
true_ca_masks
[
anchor_gt_idx
]
anchor_true_mask
=
torch
.
index_select
(
true_ca_masks
[
anchor_gt_idx
]
,
1
,
anchor_gt_residue
)
input_mask
=
(
anchor_true_mask
*
anchor_pred_mask
).
bool
()
return
input_mask
@
staticmethod
def
calculate_optimal_transform
(
true_ca_poses
,
anchor_gt_idx
,
anchor_gt_idx
,
anchor_gt_residue
,
true_ca_masks
,
pred_ca_mask
,
asym_mask
,
pred_ca_pos
):
input_mask
=
AlphaFoldMultimerLoss
.
calculate_input_mask
(
true_ca_masks
,
anchor_gt_idx
,
asym_mask
,
anchor_gt_idx
,
anchor_gt_residue
,
asym_mask
,
pred_ca_mask
)
input_mask
=
torch
.
squeeze
(
input_mask
,
0
)
pred_ca_pos
=
torch
.
squeeze
(
pred_ca_pos
,
0
)
asym_mask
=
torch
.
squeeze
(
asym_mask
,
0
)
anchor_true_pos
=
true_ca_poses
[
anchor_gt_idx
]
anchor_true_pos
=
torch
.
index_select
(
true_ca_poses
[
anchor_gt_idx
]
,
1
,
anchor_gt_residue
)
anchor_pred_pos
=
pred_ca_pos
[
asym_mask
]
r
,
x
=
get_optimal_transform
(
anchor_pred_pos
,
torch
.
squeeze
(
anchor_true_pos
,
0
),
...
...
@@ -2130,7 +2128,7 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
return
r
,
x
@
staticmethod
def
multi_chain_perm_align
(
out
,
batch
,
dim_dict
,
permutate_chains
=
False
):
def
multi_chain_perm_align
(
out
,
batch
,
permutate_chains
=
False
):
"""
A class method that first permutate chains in ground truth first
before calculating the loss.
...
...
@@ -2138,17 +2136,21 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
Details are described in Section 7.3 in the Supplementary of AlphaFold-Multimer paper:
https://www.biorxiv.org/content/10.1101/2021.10.04.463034v2
"""
labels
=
AlphaFoldMultimerLoss
.
split_ground_truth_labels
(
batch
,
dim_dict
=
dim_dict
,
feature
,
ground_truth
=
batch
del
batch
labels
=
AlphaFoldMultimerLoss
.
split_ground_truth_labels
(
ground_truth
,
REQUIRED_FEATURES
=
[
"all_atom_mask"
,
"all_atom_positions"
])
assert
isinstance
(
labels
,
list
)
print
(
f
"successfully split ground truth labels"
)
if
permutate_chains
:
# First select anchors from predicted structures and ground truths
anchor_gt_asym
,
anchor_pred_asym
=
get_least_asym_entity_or_longest_length
(
batch
)
anchor_gt_asym
,
anchor_pred_asym
=
get_least_asym_entity_or_longest_length
(
ground_truth
)
print
(
f
"###### anchor gt asym is:
{
anchor_gt_asym
}
and anchor pred asym is
{
anchor_pred_asym
}
"
)
anchor_gt_idx
=
int
(
anchor_gt_asym
)
-
1
asym_mask
=
(
batch
[
"asym_id"
]
==
anchor_pred_asym
).
bool
()
asym_mask
=
(
feature
[
"asym_id"
]
==
anchor_pred_asym
).
bool
()
print
(
f
"###### asym_mask is
{
asym_mask
}
"
)
# Then calculate optimal transform by aligning anchors
ca_idx
=
rc
.
atom_order
[
"CA"
]
pred_ca_pos
=
out
[
"final_atom_positions"
][...,
ca_idx
,
:]
# [bsz, nres, 3]
...
...
@@ -2161,8 +2163,11 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
l
[
"all_atom_mask"
][...,
ca_idx
].
long
()
for
l
in
labels
]
# list([nres,])
per_asym_residue_index
=
AlphaFoldMultimerLoss
.
get_per_asym_residue_index
(
feature
)
anchor_gt_residue
=
per_asym_residue_index
[
int
(
anchor_gt_asym
)]
print
(
f
"######## per_asym_residue_index is
{
per_asym_residue_index
}
"
)
r
,
x
=
AlphaFoldMultimerLoss
.
calculate_optimal_transform
(
true_ca_poses
,
anchor_gt_idx
,
anchor_gt_idx
,
anchor_gt_residue
,
true_ca_masks
,
pred_ca_mask
,
asym_mask
,
pred_ca_pos
...
...
@@ -2170,7 +2175,9 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
aligned_true_ca_poses
=
[
ca
.
to
(
r
.
dtype
)
@
r
+
x
for
ca
in
true_ca_poses
]
# apply transforms
del
true_ca_poses
,
r
,
x
gc
.
collect
()
print
(
f
"$$$$$$$ successfully calculated r and x"
)
import
sys
sys
.
exit
()
entity_2_asym_list
=
AlphaFoldMultimerLoss
.
get_entity_2_asym_list
(
batch
)
align
=
greedy_align
(
batch
,
...
...
@@ -2189,7 +2196,7 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
return
align
def
forward
(
self
,
out
,
features
,
_return_breakdown
=
False
,
permutate_chains
=
True
):
def
forward
(
self
,
out
,
batch
,
_return_breakdown
=
False
,
permutate_chains
=
True
):
"""
Overwrite AlphaFoldLoss forward function so that
it first compute multi-chain permutation
...
...
@@ -2199,18 +2206,18 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
batch: a pair of input features and its corresponding ground truth structure
"""
# first check if it is a monomer
features
,
ground_truth
=
batch
del
batch
is_monomer
=
len
(
torch
.
unique
(
features
[
'asym_id'
]))
==
1
or
torch
.
unique
(
features
[
'asym_id'
]).
tolist
()
==
[
0
,
1
]
print
(
f
"asym_id is
{
features
[
'asym_id'
]
}
"
)
if
not
is_monomer
:
permutate_chains
=
True
# first determin which dimension in the tensor to split into individual ground truth labels
dim_dict
=
AlphaFoldMultimerLoss
.
determine_split_dim
(
features
)
# Then permutate ground truth chains before calculating the loss
align
=
AlphaFoldMultimerLoss
.
multi_chain_perm_align
(
out
,
features
,
dim_dict
=
dim_dict
,
permutate_chains
=
permutate_chains
)
align
=
AlphaFoldMultimerLoss
.
multi_chain_perm_align
(
out
,
(
features
,
ground_truth
),
permutate_chains
=
permutate_chains
)
labels
=
AlphaFoldMultimerLoss
.
split_ground_truth_labels
(
features
,
dim_dict
=
dim_dict
,
labels
=
AlphaFoldMultimerLoss
.
split_ground_truth_labels
(
features
,
REQUIRED_FEATURES
=
[
i
for
i
in
features
.
keys
()
if
i
in
dim_dict
])
# reorder ground truth labels according to permutation results
labels
=
merge_labels
(
labels
,
align
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment