Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
2a1028f0
Commit
2a1028f0
authored
Aug 31, 2023
by
Geoffrey Yu
Browse files
update AlphaFoldMultimerLoss to accomodate new way of data_module loading procudure
parent
dff973ab
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
39 additions
and
24 deletions
+39
-24
openfold/utils/loss.py
openfold/utils/loss.py
+39
-24
No files found.
openfold/utils/loss.py
View file @
2a1028f0
...
...
@@ -1848,8 +1848,6 @@ def greedy_align(
used
=
[
False
for
_
in
range
(
len
(
true_ca_poses
))]
align
=
[]
for
cur_asym_id
in
unique_asym_ids
:
if
cur_asym_id
==
0
:
continue
i
=
int
(
cur_asym_id
-
1
)
asym_mask
=
batch
[
"asym_id"
]
==
cur_asym_id
cur_entity_ids
=
batch
[
"entity_id"
][
asym_mask
][
0
]
...
...
@@ -1878,7 +1876,15 @@ def greedy_align(
return
align
def
merge_labels
(
per_asym_residue_index
,
labels
,
align
):
def
pad_features
(
feature_tensor
,
nres_pad
,
pad_dim
):
"""Pad input feature tensor"""
pad_shape
=
list
(
feature_tensor
.
shape
)
pad_shape
[
pad_dim
]
=
nres_pad
padding_tensor
=
feature_tensor
.
new_zeros
(
pad_shape
,
device
=
feature_tensor
.
device
)
return
torch
.
concat
((
feature_tensor
,
padding_tensor
),
dim
=
pad_dim
)
def
merge_labels
(
per_asym_residue_index
,
labels
,
align
,
original_nres
):
"""
per_asym_residue_index: A dictionary that record which asym_id corresponds to which regions of residues in the multimer complex.
labels: list of original ground truth feats
...
...
@@ -1905,6 +1911,10 @@ def merge_labels(per_asym_residue_index, labels, align):
cur_out
=
[
x
[
1
]
for
x
in
sorted
(
cur_out
.
items
())]
if
len
(
cur_out
)
>
0
:
new_v
=
torch
.
concat
(
cur_out
,
dim
=
dimension_to_merge
)
# below check whether padding is needed
if
new_v
.
shape
[
dimension_to_merge
]
!=
original_nres
:
nres_pad
=
original_nres
-
new_v
.
shape
[
dimension_to_merge
]
new_v
=
pad_features
(
new_v
,
nres_pad
,
pad_dim
=
dimension_to_merge
)
outs
[
k
]
=
new_v
return
outs
...
...
@@ -2027,9 +2037,15 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
def
__init__
(
self
,
config
):
super
(
AlphaFoldMultimerLoss
,
self
).
__init__
(
config
)
self
.
config
=
config
@
staticmethod
def
determine_split_dim
(
batch
)
->
dict
:
"""A method to determine which dimension to split in split_ground_truth_labels"""
padded_dim
=
batch
[
'aatype'
].
shape
[
-
1
]
dim_dict
=
{
k
:
list
(
v
.
shape
).
index
(
padded_dim
)
for
k
,
v
in
batch
.
items
()
if
padded_dim
in
v
.
shape
}
return
dim_dict
@
staticmethod
def
split_ground_truth_labels
(
batch
,
REQUIRED_FEATURES
):
def
split_ground_truth_labels
(
batch
,
REQUIRED_FEATURES
,
dim_dict
):
"""
Splits ground truth features according to chains
...
...
@@ -2044,11 +2060,12 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
padding_asym_counts
=
asym_id_counts
.
pop
(
pop_idx
)
unique_asym_ids
.
append
(
padding_asym_id
)
asym_id_counts
.
append
(
padding_asym_counts
)
labels
=
list
(
map
(
dict
,
zip
(
*
[[(
k
,
v
)
for
v
in
torch
.
split
(
value
,
asym_id_counts
,
dim
=
1
)]
for
k
,
value
in
batch
.
items
()
if
k
in
REQUIRED_FEATURES
])))
labels
=
list
(
map
(
dict
,
zip
(
*
[[(
k
,
v
)
for
v
in
torch
.
split
(
value
,
asym_id_counts
,
dim
=
dim_dict
[
k
])]
for
k
,
value
in
batch
.
items
()
if
k
in
REQUIRED_FEATURES
])))
return
labels
@
staticmethod
def
multi_chain_perm_align
(
out
,
batch
,
permutate_chains
=
True
):
def
multi_chain_perm_align
(
out
,
batch
,
dim_dict
,
permutate_chains
=
True
):
"""
A class method that first permutate chains in ground truth first
before calculating the loss.
...
...
@@ -2056,7 +2073,7 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
Details are described in Section 7.3 in the Supplementary of AlphaFold-Multimer paper:
https://www.biorxiv.org/content/10.1101/2021.10.04.463034v2
"""
labels
=
AlphaFoldMultimerLoss
.
split_ground_truth_labels
(
batch
,
labels
=
AlphaFoldMultimerLoss
.
split_ground_truth_labels
(
batch
,
dim_dict
=
dim_dict
,
REQUIRED_FEATURES
=
[
"all_atom_mask"
,
"all_atom_positions"
])
assert
isinstance
(
labels
,
list
)
ca_idx
=
rc
.
atom_order
[
"CA"
]
...
...
@@ -2070,7 +2087,7 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
true_ca_masks
=
[
l
[
"all_atom_mask"
][...,
ca_idx
].
long
()
for
l
in
labels
]
# list([nres,])
unique_asym_ids
=
torch
.
unique
(
batch
[
"asym_id"
])
unique_asym_ids
=
[
i
for
i
in
torch
.
unique
(
batch
[
"asym_id"
])
if
i
!=
0
]
per_asym_residue_index
=
{}
for
cur_asym_id
in
unique_asym_ids
:
...
...
@@ -2129,7 +2146,7 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
return
align
,
per_asym_residue_index
def
forward
(
self
,
out
,
features
,
_return_breakdown
=
False
,
permutate_chains
=
True
):
def
forward
(
self
,
out
,
features
,
_return_breakdown
=
False
):
"""
Overwrite AlphaFoldLoss forward function so that
it first compute multi-chain permutation
...
...
@@ -2138,22 +2155,20 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
out: the output of model.forward()
batch: a pair of input features and its corresponding ground truth structure
"""
# permutate ground truth chains before calculating the loss
# align, per_asym_residue_index = AlphaFoldMultimerLoss.multi_chain_perm_align(out, features, labels,
# permutate_chains=permutate_chains)
# permutated_labels = merge_labels(per_asym_residue_index, labels, align)
# permutated_labels.pop('aatype')
# features.update(permutated_labels)
print
(
f
"########## line 2154 loss.py features is
{
type
(
features
)
}
"
)
for
k
,
v
in
features
.
items
():
print
(
f
"
{
k
}
:
{
v
.
shape
}
"
)
# permutate ground truth chains before calculating the loss
align
,
per_asym_residue_index
=
AlphaFoldMultimerLoss
.
multi_chain_perm_align
(
out
,
features
,
# first determin which dimension in the tensor to split into individual ground truth labels
dim_dict
=
AlphaFoldMultimerLoss
.
determine_split_dim
(
features
)
# Then permutate ground truth chains before calculating the loss
align
,
per_asym_residue_index
=
AlphaFoldMultimerLoss
.
multi_chain_perm_align
(
out
,
features
,
dim_dict
=
dim_dict
,
permutate_chains
=
permutate_chains
)
import
sys
sys
.
exit
()
# permutated_labels = merge_labels(per_asym_residue_index, labels, align)
labels
=
AlphaFoldMultimerLoss
.
split_ground_truth_labels
(
features
,
dim_dict
=
dim_dict
,
REQUIRED_FEATURES
=
[
i
for
i
in
features
.
keys
()
if
i
in
dim_dict
])
# reorder ground truth labels according to permutation results
labels
=
merge_labels
(
per_asym_residue_index
,
labels
,
align
,
original_nres
=
features
[
'aatype'
].
shape
[
-
1
])
features
.
update
(
labels
)
if
(
not
_return_breakdown
):
cum_loss
=
self
.
loss
(
out
,
features
,
_return_breakdown
)
print
(
f
"cum_loss:
{
cum_loss
}
"
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment