Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
697ac82a
Unverified
Commit
697ac82a
authored
Jun 24, 2023
by
Dingquan Yu
Committed by
GitHub
Jun 24, 2023
Browse files
Merge pull request #2 from dingquanyu/cleanup-permutation
added multi-chain permutation to AlphaFoldMultimerLoss
parents
85cd91f6
a49e4896
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
323 additions
and
2 deletions
+323
-2
openfold/utils/loss.py
openfold/utils/loss.py
+323
-2
No files found.
openfold/utils/loss.py
View file @
697ac82a
...
...
@@ -34,6 +34,8 @@ from openfold.utils.tensor_utils import (
permute_final_dims
,
batched_gather
,
)
import
random
from
openfold.np
import
residue_constants
as
rc
import
logging
logger
=
logging
.
getLogger
(
__name__
)
...
...
@@ -1669,6 +1671,226 @@ def chain_center_of_mass_loss(
loss
=
masked_mean
(
loss_mask
,
losses
,
dim
=
(
-
1
,
-
2
))
return
loss
# #
# below are the functions required for permutations
# #
def kabsch_rotation(P, Q):
    """
    Using the Kabsch algorithm with two sets of paired point P and Q, centered
    around the centroid. Each vector set is represented as an NxD
    matrix, where D is the the dimension of the space.
    The algorithm works in three steps:
    - a centroid translation of P and Q (assumed done before this function
    call)
    - the computation of a covariance matrix C
    - computation of the optimal rotation matrix U
    For more info see http://en.wikipedia.org/wiki/Kabsch_algorithm

    Parameters
    ----------
    P : array
        (N,D) matrix, where N is points and D is dimension.
    Q : array
        (N,D) matrix, where N is points and D is dimension.

    Returns
    -------
    U : matrix
        Rotation matrix (D,D)
    """
    # Work on CPU so the SVD does not take up extra GPU memory.
    P = P.to('cpu')
    Q = Q.to('cpu')
    # Covariance matrix between the two centered point sets.
    covariance = P.transpose(-1, -2) @ Q
    # Solve via singular value decomposition; the sign of
    # det(left)*det(right) tells us whether the raw solution is a
    # reflection rather than a proper rotation.
    left, _, right = torch.linalg.svd(covariance)
    is_reflection = (torch.linalg.det(left) * torch.linalg.det(right)) < 0.0
    if is_reflection:
        # Flip the last column to restore a right-handed coordinate system.
        left[:, -1] = -left[:, -1]
    # Optimal rotation matrix.
    return left @ right
def get_optimal_transform(
    src_atoms: torch.Tensor,
    tgt_atoms: torch.Tensor,
    mask: torch.Tensor = None,
):
    """
    Find the rigid transform (rotation, translation) that best superimposes
    ``src_atoms`` onto ``tgt_atoms``, i.e. ``src_atoms @ r + x ~ tgt_atoms``.

    Parameters
    ----------
    src_atoms : (N, 3) tensor of source coordinates.
    tgt_atoms : (N, 3) tensor of target coordinates, same shape as src.
    mask : optional (N,) boolean tensor selecting the atoms to align on.

    Returns
    -------
    (r, x) : rotation matrix (3, 3) and translation row vector, both on CPU.
    """
    assert src_atoms.shape == tgt_atoms.shape, (src_atoms.shape, tgt_atoms.shape)
    assert src_atoms.shape[-1] == 3
    if mask is not None:
        assert mask.dtype == torch.bool
        assert mask.shape[-1] == src_atoms.shape[-2]
        if mask.sum() == 0:
            # No atom survives the mask: fall back to a single dummy point at
            # the origin so the Kabsch solve below stays well-defined.
            src_atoms = torch.zeros((1, 3), device=src_atoms.device).float()
            tgt_atoms = src_atoms
        else:
            src_atoms = src_atoms[mask, :]
            tgt_atoms = tgt_atoms[mask, :]
    src_center = src_atoms.mean(-2, keepdim=True)
    tgt_center = tgt_atoms.mean(-2, keepdim=True)

    rotation = kabsch_rotation(src_atoms - src_center, tgt_atoms - tgt_center)

    # kabsch_rotation returns CPU tensors; keep the centers on CPU as well
    # just in case, so the translation below mixes no devices.
    tgt_center = tgt_center.to('cpu')
    src_center = src_center.to('cpu')
    translation = tgt_center - src_center @ rotation
    return rotation, translation
def compute_rmsd(
    true_atom_pos: torch.Tensor,
    pred_atom_pos: torch.Tensor,
    atom_mask: torch.Tensor = None,
    eps: float = 1e-6,
) -> torch.Tensor:
    """
    Root-mean-square deviation between two coordinate sets.

    Parameters
    ----------
    true_atom_pos : (..., N, 3) reference coordinates.
    pred_atom_pos : (..., N, 3) predicted coordinates, same shape.
    atom_mask : optional boolean tensor selecting atoms to include.
    eps : small constant added under the square root for stability.

    Returns
    -------
    Scalar tensor with the RMSD over the (masked) atoms.
    """
    # Per-atom squared distance between the two structures.
    per_atom_sq = (true_atom_pos - pred_atom_pos).pow(2).sum(dim=-1)
    if atom_mask is not None:
        per_atom_sq = per_atom_sq[atom_mask]
    # An empty selection yields NaN from the mean; map it to a huge value so
    # callers comparing RMSDs treat it as "worst possible".
    mean_sq = torch.nan_to_num(per_atom_sq.mean(), nan=1e8)
    return torch.sqrt(mean_sq + eps)
def kabsch_rmsd(
    true_atom_pos: torch.Tensor,
    pred_atom_pos: torch.Tensor,
    atom_mask: torch.Tensor,
):
    """
    RMSD between the two coordinate sets after optimally superimposing
    ``true_atom_pos`` onto ``pred_atom_pos`` with a Kabsch rigid transform.

    Parameters
    ----------
    true_atom_pos : (N, 3) reference coordinates.
    pred_atom_pos : (N, 3) predicted coordinates.
    atom_mask : boolean tensor selecting the atoms used for both the
        alignment and the RMSD.
    """
    rotation, translation = get_optimal_transform(
        true_atom_pos,
        pred_atom_pos,
        atom_mask,
    )
    superimposed_true = true_atom_pos @ rotation + translation
    return compute_rmsd(superimposed_true, pred_atom_pos, atom_mask)
def get_least_asym_entity_or_longest_length(batch):
    """
    Pick the anchor entity for multi-chain permutation alignment.

    First check how many subunit(s) one sequence has; if there is no tie,
    e.g. AABBB, then select one of the A as anchor.
    If there is a tie, e.g. AABB, first check which sequence is the
    longer/longest, then choose one of the corresponding subunits as anchor.

    Args:
        batch: feature dict with per-residue 1-D tensors "entity_id" and
            "asym_id".

    Returns:
        Tuple of (anchor entity id as int, tensor of asym ids of that entity).
    """
    unique_entity_ids = torch.unique(batch["entity_id"])
    entity_asym_count = {}
    entity_length = {}

    for entity_id in unique_entity_ids:
        entity_mask = batch["entity_id"] == entity_id
        # Number of chains (asym ids) this entity occurs as.
        asym_ids = torch.unique(batch["asym_id"][entity_mask])
        entity_asym_count[int(entity_id)] = len(asym_ids)
        # Entity length in residues.
        entity_length[int(entity_id)] = entity_mask.sum().item()

    min_asym_count = min(entity_asym_count.values())
    least_asym_entities = [
        entity for entity, count in entity_asym_count.items()
        if count == min_asym_count
    ]

    # If multiple entities have the least asym_id count, keep those with the
    # longest length (the original comment said "shortest", but max() is used).
    if len(least_asym_entities) > 1:
        max_length = max(entity_length[entity] for entity in least_asym_entities)
        least_asym_entities = [
            entity for entity in least_asym_entities
            if entity_length[entity] == max_length
        ]

    # If still multiple entities, pick one at random.
    # BUG FIX: random.choice returns a bare element; it must stay wrapped in a
    # list, otherwise the assertion and indexing below raise TypeError.
    if len(least_asym_entities) > 1:
        least_asym_entities = [random.choice(least_asym_entities)]

    assert len(least_asym_entities) == 1
    anchor_entity = least_asym_entities[0]
    best_pred_asym = torch.unique(
        batch["asym_id"][batch["entity_id"] == anchor_entity]
    )
    return anchor_entity, best_pred_asym
def greedy_align(
    batch,
    per_asym_residue_index,
    unique_asym_ids,
    entity_2_asym_list,
    pred_ca_pos,
    pred_ca_mask,
    true_ca_poses,
    true_ca_masks,
):
    """
    Greedily assign each predicted chain to an unused ground-truth chain of
    the same entity, choosing the candidate with the lowest CA RMSD.

    Args:
        batch: feature dict with per-residue "asym_id" and "entity_id".
        per_asym_residue_index: dict mapping int asym_id -> residue indices.
        unique_asym_ids: (possibly shuffled) tensor of asym ids to process.
        entity_2_asym_list: dict mapping int entity_id -> tensor of asym ids.
        pred_ca_pos: predicted CA positions, indexed by residue.
        pred_ca_mask: predicted CA mask, indexed by residue.
        true_ca_poses: list of ground-truth CA positions, one per chain.
        true_ca_masks: list of ground-truth CA masks, one per chain.

    Returns:
        List of (pred_chain_index, label_index) pairs.
    """
    used = [False] * len(true_ca_poses)
    align = []
    for cur_asym_id in unique_asym_ids:
        # skip padding
        if cur_asym_id == 0:
            continue
        i = int(cur_asym_id - 1)
        asym_mask = batch["asym_id"] == cur_asym_id
        entity_id = batch["entity_id"][asym_mask][0]
        # don't need to align
        # NOTE(review): this exempts entity id 1 from permutation and maps it
        # to itself — confirm entity 1 is meant to be fixed rather than this
        # being intended as a "single chain in entity" check.
        if (entity_id) == 1:
            align.append((i, i))
            assert not used[i]
            used[i] = True
            continue
        cur_entity_ids = batch["entity_id"][asym_mask][0]
        best_rmsd = 1e20
        best_idx = None
        cur_asym_list = entity_2_asym_list[int(cur_entity_ids)]
        cur_residue_index = per_asym_residue_index[int(cur_asym_id)]
        cur_pred_pos = pred_ca_pos[asym_mask]
        # only need the first 1 column of asym_mask
        cur_pred_mask = pred_ca_mask[asym_mask]
        for next_asym_id in cur_asym_list:
            if next_asym_id == 0:
                continue
            j = int(next_asym_id - 1)
            if not used[j]:
                # possible candidate
                # BUG FIX: positions are now cropped with cur_residue_index
                # exactly like the mask below; previously the full chain was
                # compared against the cropped prediction, which breaks (or
                # silently misbroadcasts) whenever the chain is cropped. The
                # anchor-selection code indexes positions the same way.
                cropped_pos = true_ca_poses[j][cur_residue_index]
                mask = true_ca_masks[j][cur_residue_index]
                rmsd = compute_rmsd(
                    cropped_pos,
                    cur_pred_pos,
                    (cur_pred_mask.to('cpu') * mask.to('cpu')).bool(),
                )
                if rmsd < best_rmsd:
                    best_rmsd = rmsd
                    best_idx = j
        assert best_idx is not None
        used[best_idx] = True
        align.append((i, best_idx))
    # Use the module logger (lazy %-formatting) instead of a stray print().
    logger.debug("align is %s", align)
    return align
def merge_labels(batch, per_asym_residue_index, labels, align):
    """
    Reassemble per-chain ground-truth labels into a single label dict,
    following a chain permutation.

    batch:
    labels: list of label dicts, each with shape [nk, *]
    align: list of int, such as [2, None, 0, 1], each entry specify the corresponding label of the asym.
    """
    num_res = batch["msa_mask"].shape[-1]
    merged = {}
    for key, _ in labels[0].items():
        # Per-example metadata that must not be concatenated per chain.
        if key in ["resolution"]:
            continue
        pieces = {}
        for asym_idx, label_idx in align:
            # per_asym_residue_index is keyed by 1-based asym id.
            residue_idx = per_asym_residue_index[asym_idx + 1]
            pieces[asym_idx] = labels[label_idx][key][residue_idx]
        # Restore chain order regardless of the order align was produced in.
        ordered = [piece for _, piece in sorted(pieces.items())]
        stacked = torch.concat(ordered, dim=0)
        merged_nres = stacked.shape[0]
        assert (
            merged_nres <= num_res
        ), f"bad merged num res: {merged_nres} > {num_res}. something is wrong."
        if merged_nres < num_res:
            # must pad back up to the full (cropped) residue count
            trailing_shape = stacked.shape[1:]
            padding = stacked.new_zeros((num_res - merged_nres, *trailing_shape))
            stacked = torch.concat((stacked, padding), dim=0)
        merged[key] = stacked
    return merged
class
AlphaFoldLoss
(
nn
.
Module
):
"""Aggregation of the various losses described in the supplement"""
...
...
@@ -1782,12 +2004,111 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
AlphaFoldLoss
"""
def
__init__
(
self
,
config
):
super
(
AlphaFoldMultimerLoss
,
self
).
__init__
()
super
(
AlphaFoldMultimerLoss
,
self
).
__init__
(
config
)
self
.
config
=
config
    def multi_chain_perm_align(self, out, batch, labels, shuffle_times=2):
        """
        A class method that first permutate chains in ground truth first
        before calculating the loss.

        Details are described in Section 7.3 in the Supplementary of AlphaFold-Multimer paper:
        https://www.biorxiv.org/content/10.1101/2021.10.04.463034v2

        Args:
            out: model output dict; reads "final_atom_positions" and
                "final_atom_mask".
            batch: feature dict; reads "asym_id", "entity_id", "residue_index".
            labels: list of per-chain ground-truth label dicts; reads
                "all_atom_positions" and "all_atom_mask".
            shuffle_times: number of random chain orders tried per anchor
                alignment.

        Returns:
            The merged label dict with the lowest CA RMSD found, or None if
            no candidate was evaluated.
        """
        assert isinstance(labels, list)
        ca_idx = rc.atom_order["CA"]
        pred_ca_pos = out["final_atom_positions"][..., ca_idx, :].float()  # [bsz, nres, 3]
        pred_ca_mask = out["final_atom_mask"][..., ca_idx].float()  # [bsz, nres]
        true_ca_poses = [
            l["all_atom_positions"][..., ca_idx, :].float() for l in labels
        ]  # list([nres, 3])
        true_ca_masks = [
            l["all_atom_mask"][..., ca_idx].float() for l in labels
        ]  # list([nres,])

        # Map each (1-based) asym id to the residue indices of that chain.
        unique_asym_ids = torch.unique(batch["asym_id"])
        per_asym_residue_index = {}
        for cur_asym_id in unique_asym_ids:
            asym_mask = (batch["asym_id"] == cur_asym_id).bool()
            per_asym_residue_index[int(cur_asym_id)] = batch["residue_index"][asym_mask]

        # Choose the anchor entity/chains used to fix the global superposition.
        anchor_gt_asym, anchor_pred_asym = get_least_asym_entity_or_longest_length(batch)
        logger.info(f"anchor_gt_asym is chosen to be: {anchor_gt_asym}")
        # NOTE(review): this treats the anchor entity id as a 0-based index
        # into the labels list — assumes entity/asym ids are contiguous from 1
        # and labels are ordered accordingly; confirm against the data pipeline.
        anchor_gt_idx = int(anchor_gt_asym) - 1

        best_rmsd = 1e20
        best_labels = None

        # Map each entity id to the asym ids (chains) belonging to it.
        unique_entity_ids = torch.unique(batch["entity_id"])
        entity_2_asym_list = {}
        for cur_ent_id in unique_entity_ids:
            ent_mask = batch["entity_id"] == cur_ent_id
            cur_asym_id = torch.unique(batch["asym_id"][ent_mask])
            entity_2_asym_list[int(cur_ent_id)] = cur_asym_id

        # Try anchoring on every predicted chain of the anchor entity.
        for cur_asym_id in anchor_pred_asym:
            asym_mask = (batch["asym_id"] == cur_asym_id).bool()
            anchor_residue_idx = per_asym_residue_index[int(cur_asym_id)]
            anchor_true_pos = true_ca_poses[anchor_gt_idx][anchor_residue_idx]
            anchor_pred_pos = pred_ca_pos[asym_mask]
            anchor_true_mask = true_ca_masks[anchor_gt_idx][anchor_residue_idx]
            anchor_pred_mask = pred_ca_mask[asym_mask]
            # Rigid transform that superimposes the ground-truth anchor chain
            # onto the predicted anchor chain (r, x live on CPU).
            r, x = get_optimal_transform(
                anchor_true_pos,
                anchor_pred_pos,
                (anchor_true_mask.to('cpu') * anchor_pred_mask.to('cpu')).bool(),
            )

            # Apply the anchor transform to every ground-truth chain.
            aligned_true_ca_poses = [
                ca.to('cpu') @ r.to('cpu') + x.to('cpu') for ca in true_ca_poses
            ]  # apply transforms

            # Greedy assignment is order-dependent, so retry with shuffled
            # chain orders and keep the best outcome.
            for _ in range(shuffle_times):
                shuffle_idx = torch.randperm(
                    unique_asym_ids.shape[0], device=unique_asym_ids.device
                )
                shuffled_asym_ids = unique_asym_ids[shuffle_idx]
                align = greedy_align(
                    batch,
                    per_asym_residue_index,
                    shuffled_asym_ids,
                    entity_2_asym_list,
                    pred_ca_pos,
                    pred_ca_mask,
                    aligned_true_ca_poses,
                    true_ca_masks,
                )

                # Reassemble the ground-truth labels in the permuted order.
                merged_labels = merge_labels(
                    batch,
                    per_asym_residue_index,
                    labels,
                    align,
                )

                # Score the permutation by CA RMSD after superposition.
                rmsd = kabsch_rmsd(
                    merged_labels["all_atom_positions"][..., ca_idx, :].to('cpu')
                    @ r.to('cpu') + x.to('cpu'),
                    pred_ca_pos,
                    (pred_ca_mask.to('cpu')
                     * merged_labels["all_atom_mask"][..., ca_idx].to('cpu')).bool(),
                )

                if rmsd < best_rmsd:
                    best_rmsd = rmsd
                    best_labels = merged_labels

        return best_labels
    def forward(self, out, batch, _return_breakdown=False):
        """
        Overwrite AlphaFoldLoss forward function so that
        it first compute multi-chain permutation

        args:
            out: the output of model.forward()
            batch: a pair of input features and its corresponding ground truth structure

        Note:
            As written, this returns the permuted ground-truth labels and
            does NOT compute the loss yet (see the TODO below); the
            _return_breakdown flag is currently unused.
        """
        logger.info(f"out is {type(out)} and batch is {type(batch)}")
        features, labels = batch
        # first remove the recycling dimention of input features
        features = tensor_tree_map(lambda t: t[..., -1], features)
        # then permutate ground truth chains before calculating the loss
        permutated_labels = self.multi_chain_perm_align(out, features, labels)
        logger.info("finished multi-chain permutation")
        return permutated_labels
        ## TODO next need to check how the ground truth label is used
        # in loss calculation.
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment