Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
dff973ab
Commit
dff973ab
authored
Aug 30, 2023
by
Geoffrey Yu
Browse files
updated get_least_asym_entity_or_longest_length and added split_ground_truth_labels
parent
ab09ded4
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
43 additions
and
4 deletions
+43
-4
openfold/utils/loss.py
openfold/utils/loss.py
+43
-4
No files found.
openfold/utils/loss.py
View file @
dff973ab
...
...
@@ -1770,8 +1770,8 @@ def get_optimal_transform(
else
:
src_atoms
=
src_atoms
[
mask
,
:]
tgt_atoms
=
tgt_atoms
[
mask
,
:]
src_center
=
src_atoms
.
mean
(
-
2
,
keepdim
=
True
)
tgt_center
=
tgt_atoms
.
mean
(
-
2
,
keepdim
=
True
)
src_center
=
src_atoms
.
mean
(
-
2
,
keepdim
=
True
,
dtype
=
src_atoms
.
dtype
)
tgt_center
=
tgt_atoms
.
mean
(
-
2
,
keepdim
=
True
,
dtype
=
src_atoms
.
dtype
)
r
=
kabsch_rotation
(
src_atoms
,
tgt_atoms
)
del
src_atoms
,
tgt_atoms
,
gc
.
collect
()
...
...
@@ -1792,6 +1792,12 @@ def get_least_asym_entity_or_longest_length(batch):
If there is a tie, e.g. AABB, first check which sequence is longer/longest,
then choose one of the corresponding subunits as anchor
"""
REQUIRED_FEATURES
=
[
'entity_id'
,
'asym_id'
]
seq_length
=
batch
[
'seq_length'
].
item
()
# remove padding part before selecting candidate
remove_padding
=
lambda
t
:
torch
.
index_select
(
t
,
dim
=
1
,
index
=
torch
.
arange
(
seq_length
,
device
=
t
.
device
))
batch
=
{
k
:
tensor_tree_map
(
remove_padding
,
batch
[
k
])
for
k
in
REQUIRED_FEATURES
}
unique_entity_ids
=
torch
.
unique
(
batch
[
"entity_id"
])
entity_asym_count
=
{}
entity_length
=
{}
...
...
@@ -1842,6 +1848,8 @@ def greedy_align(
used
=
[
False
for
_
in
range
(
len
(
true_ca_poses
))]
align
=
[]
for
cur_asym_id
in
unique_asym_ids
:
if
cur_asym_id
==
0
:
continue
i
=
int
(
cur_asym_id
-
1
)
asym_mask
=
batch
[
"asym_id"
]
==
cur_asym_id
cur_entity_ids
=
batch
[
"entity_id"
][
asym_mask
][
0
]
...
...
@@ -2021,7 +2029,26 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
self
.
config
=
config
@
staticmethod
def
multi_chain_perm_align
(
out
,
batch
,
labels
,
permutate_chains
=
True
):
def
split_ground_truth_labels
(
batch
,
REQUIRED_FEATURES
):
"""
Splits ground truth features according to chains
Returns a list of feature dictionaries with only necessary ground truth features
required to finish multi-chain permutation
"""
unique_asym_ids
,
asym_id_counts
=
torch
.
unique
(
batch
[
"asym_id"
],
sorted
=
False
,
return_counts
=
True
)
unique_asym_ids
,
asym_id_counts
=
unique_asym_ids
.
tolist
(),
asym_id_counts
.
tolist
()
if
0
in
unique_asym_ids
:
pop_idx
=
unique_asym_ids
.
index
(
0
)
padding_asym_id
=
unique_asym_ids
.
pop
(
pop_idx
)
padding_asym_counts
=
asym_id_counts
.
pop
(
pop_idx
)
unique_asym_ids
.
append
(
padding_asym_id
)
asym_id_counts
.
append
(
padding_asym_counts
)
labels
=
list
(
map
(
dict
,
zip
(
*
[[(
k
,
v
)
for
v
in
torch
.
split
(
value
,
asym_id_counts
,
dim
=
1
)]
for
k
,
value
in
batch
.
items
()
if
k
in
REQUIRED_FEATURES
])))
return
labels
@
staticmethod
def
multi_chain_perm_align
(
out
,
batch
,
permutate_chains
=
True
):
"""
A class method that permutes chains in the ground truth
before calculating the loss.
...
...
@@ -2029,6 +2056,8 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
Details are described in Section 7.3 in the Supplementary of AlphaFold-Multimer paper:
https://www.biorxiv.org/content/10.1101/2021.10.04.463034v2
"""
labels
=
AlphaFoldMultimerLoss
.
split_ground_truth_labels
(
batch
,
REQUIRED_FEATURES
=
[
"all_atom_mask"
,
"all_atom_positions"
])
assert
isinstance
(
labels
,
list
)
ca_idx
=
rc
.
atom_order
[
"CA"
]
pred_ca_pos
=
out
[
"final_atom_positions"
][...,
ca_idx
,
:]
# [bsz, nres, 3]
...
...
@@ -2049,6 +2078,7 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
per_asym_residue_index
[
int
(
cur_asym_id
)]
=
torch
.
masked_select
(
batch
[
"residue_index"
],
asym_mask
)
if
permutate_chains
:
anchor_gt_asym
,
anchor_pred_asym
=
get_least_asym_entity_or_longest_length
(
batch
)
print
(
f
"anchor_gt_asym:
{
anchor_gt_asym
}
anchor_pred_asym:
{
anchor_pred_asym
}
"
)
anchor_gt_idx
=
int
(
anchor_gt_asym
)
-
1
unique_entity_ids
=
torch
.
unique
(
batch
[
"entity_id"
])
...
...
@@ -2074,7 +2104,7 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
del
anchor_pred_mask
del
anchor_true_mask
gc
.
collect
()
aligned_true_ca_poses
=
[
ca
@
r
+
x
for
ca
in
true_ca_poses
]
# apply transforms
aligned_true_ca_poses
=
[
ca
.
to
(
r
.
dtype
)
@
r
+
x
for
ca
in
true_ca_poses
]
# apply transforms
del
true_ca_poses
gc
.
collect
()
align
=
greedy_align
(
...
...
@@ -2114,7 +2144,16 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
# permutated_labels = merge_labels(per_asym_residue_index, labels, align)
# permutated_labels.pop('aatype')
# features.update(permutated_labels)
print
(
f
"########## line 2154 loss.py features is
{
type
(
features
)
}
"
)
for
k
,
v
in
features
.
items
():
print
(
f
"
{
k
}
:
{
v
.
shape
}
"
)
# permutate ground truth chains before calculating the loss
align
,
per_asym_residue_index
=
AlphaFoldMultimerLoss
.
multi_chain_perm_align
(
out
,
features
,
permutate_chains
=
permutate_chains
)
import
sys
sys
.
exit
()
# permutated_labels = merge_labels(per_asym_residue_index, labels, align)
if
(
not
_return_breakdown
):
cum_loss
=
self
.
loss
(
out
,
features
,
_return_breakdown
)
print
(
f
"cum_loss:
{
cum_loss
}
"
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment