Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
a49e4896
Commit
a49e4896
authored
Jun 24, 2023
by
Geoffrey Yu
Browse files
added multi-chain permutation to AlphaFoldMultimerLoss
parent
85cd91f6
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
323 additions
and
2 deletions
+323
-2
openfold/utils/loss.py
openfold/utils/loss.py
+323
-2
No files found.
openfold/utils/loss.py
View file @
a49e4896
...
@@ -34,6 +34,8 @@ from openfold.utils.tensor_utils import (
...
@@ -34,6 +34,8 @@ from openfold.utils.tensor_utils import (
permute_final_dims
,
permute_final_dims
,
batched_gather
,
batched_gather
,
)
)
import
random
from
openfold.np
import
residue_constants
as
rc
import
logging
import
logging
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
...
@@ -1669,6 +1671,226 @@ def chain_center_of_mass_loss(
...
@@ -1669,6 +1671,226 @@ def chain_center_of_mass_loss(
loss
=
masked_mean
(
loss_mask
,
losses
,
dim
=
(
-
1
,
-
2
))
loss
=
masked_mean
(
loss_mask
,
losses
,
dim
=
(
-
1
,
-
2
))
return
loss
return
loss
# #
# below are the functions required for permutations
# #
def kabsch_rotation(P, Q):
    """Compute the optimal rotation matrix aligning P onto Q (Kabsch algorithm).

    Both point sets are assumed to already be centered on their centroids
    before this function is called.

    Args:
        P: (N, D) tensor, N points of dimension D.
        Q: (N, D) tensor, N points of dimension D.

    Returns:
        (D, D) rotation matrix U such that P @ U best matches Q.

    For more info see http://en.wikipedia.org/wiki/Kabsch_algorithm
    """
    # Work on CPU so the decomposition does not hold extra GPU memory.
    P = P.to('cpu')
    Q = Q.to('cpu')
    # Covariance matrix of the paired point sets.
    covariance = P.transpose(-1, -2) @ Q
    # SVD yields the factors of the optimal rotation.
    V, _, W = torch.linalg.svd(covariance)
    # If det(V) * det(W) is negative the raw product would be an improper
    # rotation (reflection); flip the last column of V to keep the
    # coordinate system right-handed.
    if (torch.linalg.det(V) * torch.linalg.det(W)) < 0.0:
        V[:, -1] = -V[:, -1]
    return V @ W
def get_optimal_transform(
    src_atoms: torch.Tensor,
    tgt_atoms: torch.Tensor,
    mask: torch.Tensor = None,
):
    """Find the rigid transform (rotation, translation) mapping src onto tgt.

    Args:
        src_atoms: [N, 3] source coordinates.
        tgt_atoms: [N, 3] target coordinates, same shape as src_atoms.
        mask: optional [N] boolean mask selecting atoms used for the fit.

    Returns:
        Tuple (r, x) where r is a [3, 3] rotation and x a [1, 3] translation
        such that src_atoms @ r + x approximates tgt_atoms. Both live on CPU
        (kabsch_rotation computes on CPU, and the centers are moved there).
    """
    assert src_atoms.shape == tgt_atoms.shape, (src_atoms.shape, tgt_atoms.shape)
    assert src_atoms.shape[-1] == 3
    if mask is not None:
        assert mask.dtype == torch.bool
        assert mask.shape[-1] == src_atoms.shape[-2]
        if mask.sum() == 0:
            # Nothing selected: fall back to a single dummy point at the
            # origin for both sets, which degenerates to an identity-like fit.
            src_atoms = torch.zeros((1, 3), device=src_atoms.device).float()
            tgt_atoms = src_atoms
        else:
            src_atoms = src_atoms[mask, :]
            tgt_atoms = tgt_atoms[mask, :]
    src_center = src_atoms.mean(-2, keepdim=True)
    tgt_center = tgt_atoms.mean(-2, keepdim=True)
    # Rotation fitted on the centered point clouds.
    r = kabsch_rotation(src_atoms - src_center, tgt_atoms - tgt_center)
    # Move centers to CPU memory just in case (r is already on CPU).
    tgt_center = tgt_center.to('cpu')
    src_center = src_center.to('cpu')
    x = tgt_center - src_center @ r
    return r, x
def compute_rmsd(
    true_atom_pos: torch.Tensor,
    pred_atom_pos: torch.Tensor,
    atom_mask: torch.Tensor = None,
    eps: float = 1e-6,
) -> torch.Tensor:
    """Root-mean-square deviation between two coordinate sets.

    Args:
        true_atom_pos: [..., 3] reference coordinates.
        pred_atom_pos: [..., 3] predicted coordinates, same shape as reference.
        atom_mask: optional boolean mask selecting the atoms that contribute.
        eps: small constant added under the square root for stability.

    Returns:
        Scalar RMSD tensor. A NaN mean (e.g. an all-False mask selects no
        atoms) is replaced by the large sentinel 1e8 before the square root.
    """
    per_atom_sq = ((true_atom_pos - pred_atom_pos) ** 2).sum(dim=-1)
    if atom_mask is not None:
        per_atom_sq = per_atom_sq[atom_mask]
    mean_sq = torch.nan_to_num(per_atom_sq.mean(), nan=1e8)
    return torch.sqrt(mean_sq + eps)
def kabsch_rmsd(
    true_atom_pos: torch.Tensor,
    pred_atom_pos: torch.Tensor,
    atom_mask: torch.Tensor,
):
    """RMSD after optimally superimposing the true coordinates onto the
    predicted ones with a Kabsch fit.

    Args:
        true_atom_pos: [N, 3] reference coordinates.
        pred_atom_pos: [N, 3] predicted coordinates.
        atom_mask: [N] boolean mask of atoms used for both fit and RMSD.

    Returns:
        Scalar RMSD tensor over the masked atoms.
    """
    rotation, translation = get_optimal_transform(
        true_atom_pos,
        pred_atom_pos,
        atom_mask,
    )
    superimposed = true_atom_pos @ rotation + translation
    return compute_rmsd(superimposed, pred_atom_pos, atom_mask)
def get_least_asym_entity_or_longest_length(batch):
    """Pick the anchor entity for multi-chain permutation alignment.

    First check how many subunit(s) one sequence has. If there is no tie,
    e.g. AABBB, then select one of the A as anchor. If there is a tie,
    e.g. AABB, pick the longer/longest sequence; if still tied, pick one
    at random.

    Args:
        batch: feature dict with per-residue "entity_id" and "asym_id" tensors.

    Returns:
        Tuple of (anchor entity id as int, tensor of asym ids belonging to it).
    """
    unique_entity_ids = torch.unique(batch["entity_id"])
    entity_asym_count = {}
    entity_length = {}

    for entity_id in unique_entity_ids:
        entity_mask = batch["entity_id"] == entity_id
        # Number of chains (asym ids) carrying this entity.
        asym_ids = torch.unique(batch["asym_id"][entity_mask])
        entity_asym_count[int(entity_id)] = len(asym_ids)
        # Total residue count of this entity.
        entity_length[int(entity_id)] = entity_mask.sum().item()

    min_asym_count = min(entity_asym_count.values())
    least_asym_entities = [
        entity
        for entity, count in entity_asym_count.items()
        if count == min_asym_count
    ]

    # If multiple entities have the least asym_id count, keep those with the
    # longest length (the earlier comment said "shortest", but the code and
    # the docstring both select the maximum).
    if len(least_asym_entities) > 1:
        max_length = max(entity_length[entity] for entity in least_asym_entities)
        least_asym_entities = [
            entity
            for entity in least_asym_entities
            if entity_length[entity] == max_length
        ]

    # If still multiple entities, pick a random one.
    # BUGFIX: random.choice returns a single element, not a list — wrap it so
    # the invariant asserted below holds instead of raising TypeError.
    if len(least_asym_entities) > 1:
        least_asym_entities = [random.choice(least_asym_entities)]

    assert len(least_asym_entities) == 1
    best_pred_asym = torch.unique(
        batch["asym_id"][batch["entity_id"] == least_asym_entities[0]]
    )
    return least_asym_entities[0], best_pred_asym
def greedy_align(
    batch,
    per_asym_residue_index,
    unique_asym_ids,
    entity_2_asym_list,
    pred_ca_pos,
    pred_ca_mask,
    true_ca_poses,
    true_ca_masks,
):
    """Greedily match each predicted chain to an unused ground-truth chain of
    the same entity, minimizing CA RMSD.

    Args:
        batch: feature dict with per-residue "asym_id" and "entity_id" tensors.
        per_asym_residue_index: dict mapping int asym_id -> residue_index tensor.
        unique_asym_ids: iterable of asym ids (possibly shuffled) processed in order.
        entity_2_asym_list: dict mapping int entity_id -> tensor of asym ids.
        pred_ca_pos: predicted CA positions, indexed by the per-chain asym mask.
        pred_ca_mask: predicted CA mask, same leading shape as pred_ca_pos.
        true_ca_poses: list of per-chain ground-truth CA position tensors.
        true_ca_masks: list of per-chain ground-truth CA mask tensors.

    Returns:
        List of (pred_chain_idx, true_chain_idx) 0-based index pairs.
    """
    used = [False for _ in range(len(true_ca_poses))]
    align = []
    for cur_asym_id in unique_asym_ids:
        # skip padding (asym_id 0)
        if cur_asym_id == 0:
            continue
        i = int(cur_asym_id - 1)
        asym_mask = batch["asym_id"] == cur_asym_id
        entity_id = batch["entity_id"][asym_mask][0]
        # entity 1 keeps the identity assignment — don't need to align
        if (entity_id) == 1:
            align.append((i, i))
            assert used[i] == False
            used[i] = True
            continue
        best_rmsd = 1e20
        best_idx = None
        cur_asym_list = entity_2_asym_list[int(entity_id)]
        cur_residue_index = per_asym_residue_index[int(cur_asym_id)]
        cur_pred_pos = pred_ca_pos[asym_mask]
        # only need the first 1 column of asym_mask
        cur_pred_mask = pred_ca_mask[asym_mask]
        for next_asym_id in cur_asym_list:
            if next_asym_id == 0:
                continue
            j = int(next_asym_id - 1)
            if not used[j]:
                # possible candidate.
                # BUGFIX: crop the candidate ground-truth chain to the residues
                # present in the prediction (matching how the mask below and the
                # anchor chain in multi_chain_perm_align are cropped); previously
                # the full-length chain was compared against the cropped
                # prediction, mismatching shapes inside compute_rmsd.
                cropped_pos = true_ca_poses[j][cur_residue_index]
                mask = true_ca_masks[j][cur_residue_index]
                rmsd = compute_rmsd(
                    cropped_pos,
                    cur_pred_pos,
                    (cur_pred_mask.to('cpu') * mask.to('cpu')).bool()
                )
                if rmsd < best_rmsd:
                    best_rmsd = rmsd
                    best_idx = j
        assert best_idx is not None
        used[best_idx] = True
        align.append((i, best_idx))
    # use the module logger instead of print so output goes through the
    # configured logging handlers
    logger.info(f"align is {align}")
    return align
def merge_labels(batch, per_asym_residue_index, labels, align):
    """Re-assemble per-chain label dicts into one label dict following `align`.

    Args:
        batch: feature dict; "msa_mask" supplies the target residue count.
        per_asym_residue_index: dict mapping 1-based asym id -> residue_index tensor.
        labels: list of per-chain label dicts, each value shaped [nk, *].
        align: list of (pred_idx, label_idx) pairs, e.g. from greedy_align,
            giving the label chain assigned to each predicted chain.

    Returns:
        Dict with each label key concatenated in prediction order and
        zero-padded up to the batch residue count. "resolution" is skipped.
    """
    num_res = batch["msa_mask"].shape[-1]
    merged = {}
    for key in labels[0]:
        if key == "resolution":
            continue
        pieces = {}
        for pred_idx, label_idx in align:
            chain_label = labels[label_idx][key]
            # per_asym_residue_index is keyed by 1-based asym id
            residue_index = per_asym_residue_index[pred_idx + 1]
            pieces[pred_idx] = chain_label[residue_index]
        ordered = [pieces[idx] for idx in sorted(pieces)]
        stacked = torch.concat(ordered, dim=0)
        n_merged = stacked.shape[0]
        assert n_merged <= num_res, (
            f"bad merged num res: {n_merged} > {num_res}. something is wrong."
        )
        if n_merged < num_res:
            # pad with zeros up to the full residue count
            padding = stacked.new_zeros((num_res - n_merged, *stacked.shape[1:]))
            stacked = torch.concat((stacked, padding), dim=0)
        merged[key] = stacked
    return merged
class
AlphaFoldLoss
(
nn
.
Module
):
class
AlphaFoldLoss
(
nn
.
Module
):
"""Aggregation of the various losses described in the supplement"""
"""Aggregation of the various losses described in the supplement"""
...
@@ -1782,12 +2004,111 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
...
@@ -1782,12 +2004,111 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
AlphaFoldLoss
AlphaFoldLoss
"""
"""
def __init__(self, config):
    """Initialize the multimer loss.

    Args:
        config: loss configuration; forwarded to the AlphaFoldLoss parent
            and kept on the instance for later use.
    """
    super().__init__(config)
    self.config = config
def multi_chain_perm_align(self, out, batch, labels, shuffle_times=2):
    """
    Permute the ground-truth chains to best match the prediction before the
    loss is calculated.

    Details are described in Section 7.3 in the Supplementary of the
    AlphaFold-Multimer paper:
    https://www.biorxiv.org/content/10.1101/2021.10.04.463034v2

    Args:
        out: model output dict containing "final_atom_positions" and
            "final_atom_mask".
        batch: feature dict with per-residue "asym_id", "entity_id" and
            "residue_index" tensors.
        labels: list of per-chain ground-truth label dicts, each containing
            "all_atom_positions" and "all_atom_mask".
        shuffle_times: number of random chain orderings to try per anchor.

    Returns:
        The merged label dict with the lowest anchor-aligned CA RMSD, or
        None if no candidate was evaluated.
    """
    assert isinstance(labels, list)
    ca_idx = rc.atom_order["CA"]
    # Extract CA coordinates/masks from prediction and per-chain labels.
    pred_ca_pos = out["final_atom_positions"][..., ca_idx, :].float()  # [bsz, nres, 3]
    pred_ca_mask = out["final_atom_mask"][..., ca_idx].float()  # [bsz, nres]
    true_ca_poses = [
        l["all_atom_positions"][..., ca_idx, :].float() for l in labels
    ]  # list([nres, 3])
    true_ca_masks = [
        l["all_atom_mask"][..., ca_idx].float() for l in labels
    ]  # list([nres,])

    unique_asym_ids = torch.unique(batch["asym_id"])

    # Map each asym id to the residue indices of its chain.
    per_asym_residue_index = {}
    for cur_asym_id in unique_asym_ids:
        asym_mask = (batch["asym_id"] == cur_asym_id).bool()
        per_asym_residue_index[int(cur_asym_id)] = batch["residue_index"][asym_mask]

    # Choose the anchor entity (fewest copies, then longest) and the set of
    # predicted chains carrying it.
    anchor_gt_asym, anchor_pred_asym = get_least_asym_entity_or_longest_length(batch)
    logger.info(f"anchor_gt_asym is chosen to be: {anchor_gt_asym}")
    # NOTE(review): this converts the chosen *entity* id to a 0-based index
    # into the per-chain label lists — assumes entity ids line up with label
    # chain order; confirm against the data pipeline.
    anchor_gt_idx = int(anchor_gt_asym) - 1

    best_rmsd = 1e20
    best_labels = None

    # Map each entity id to the asym ids that carry it.
    unique_entity_ids = torch.unique(batch["entity_id"])
    entity_2_asym_list = {}
    for cur_ent_id in unique_entity_ids:
        ent_mask = batch["entity_id"] == cur_ent_id
        cur_asym_id = torch.unique(batch["asym_id"][ent_mask])
        entity_2_asym_list[int(cur_ent_id)] = cur_asym_id

    # Try each predicted chain of the anchor entity as the alignment anchor.
    for cur_asym_id in anchor_pred_asym:
        asym_mask = (batch["asym_id"] == cur_asym_id).bool()
        anchor_residue_idx = per_asym_residue_index[int(cur_asym_id)]
        anchor_true_pos = true_ca_poses[anchor_gt_idx][anchor_residue_idx]
        anchor_pred_pos = pred_ca_pos[asym_mask]
        anchor_true_mask = true_ca_masks[anchor_gt_idx][anchor_residue_idx]
        anchor_pred_mask = pred_ca_mask[asym_mask]
        # Rigid transform mapping the anchor true chain onto its prediction.
        r, x = get_optimal_transform(
            anchor_true_pos,
            anchor_pred_pos,
            (anchor_true_mask.to('cpu') * anchor_pred_mask.to('cpu')).bool(),
        )

        # Apply the anchor transform to every ground-truth chain (on CPU).
        aligned_true_ca_poses = [
            ca.to('cpu') @ r.to('cpu') + x.to('cpu') for ca in true_ca_poses
        ]  # apply transforms

        # Several random chain orderings, since greedy matching is
        # order-dependent.
        for _ in range(shuffle_times):
            shuffle_idx = torch.randperm(
                unique_asym_ids.shape[0], device=unique_asym_ids.device
            )
            shuffled_asym_ids = unique_asym_ids[shuffle_idx]
            align = greedy_align(
                batch,
                per_asym_residue_index,
                shuffled_asym_ids,
                entity_2_asym_list,
                pred_ca_pos,
                pred_ca_mask,
                aligned_true_ca_poses,
                true_ca_masks,
            )

            # Stitch the per-chain labels together under this permutation.
            merged_labels = merge_labels(
                batch,
                per_asym_residue_index,
                labels,
                align,
            )

            # Score the permutation by anchor-transformed CA RMSD.
            rmsd = kabsch_rmsd(
                merged_labels["all_atom_positions"][..., ca_idx, :].to('cpu') @ r.to('cpu') + x.to('cpu'),
                pred_ca_pos,
                (pred_ca_mask.to('cpu') * merged_labels["all_atom_mask"][..., ca_idx].to('cpu')).bool(),
            )

            if rmsd < best_rmsd:
                best_rmsd = rmsd
                best_labels = merged_labels
    return best_labels
def forward(self, out, batch, _return_breakdown=False):
    """
    Overwrite the AlphaFoldLoss forward function so that it first computes
    the multi-chain permutation of the ground truth.

    Args:
        out: the output of model.forward()
        batch: a pair of input features and its corresponding ground truth
            structure (features, labels)
        _return_breakdown: kept for interface compatibility with
            AlphaFoldLoss.forward; currently unused here.

    Returns:
        NOTE(review): currently returns the permuted labels rather than a
        loss value — per the author's TODO below this method is still
        work-in-progress and does not yet call the parent loss computation.
    """
    logger.info(f"out is {type(out)} and batch is {type(batch)}")
    features, labels = batch
    # first remove the recycling dimention of input features
    features = tensor_tree_map(lambda t: t[..., -1], features)
    # then permutate ground truth chains before calculating the loss
    permutated_labels = self.multi_chain_perm_align(out, features, labels)
    logger.info("finished multi-chain permutation")
    return permutated_labels
    ## TODO next need to check how the ground truth label is used
    # in loss calculation.
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment