Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
9a6eb649
Commit
9a6eb649
authored
May 10, 2024
by
Dingquan Yu
Committed by
Jennifer Wei
May 11, 2024
Browse files
update comments;fixed typos
parent
bc240326
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
20 additions
and
14 deletions
+20
-14
openfold/utils/multi_chain_permutation.py
openfold/utils/multi_chain_permutation.py
+20
-14
No files found.
openfold/utils/multi_chain_permutation.py
View file @
9a6eb649
...
...
@@ -32,7 +32,7 @@ def compute_rmsd(
return
torch
.
sqrt
(
msd
+
eps
)
# prevent sqrt 0
def
kabsch_rotation
(
P
:
torch
.
Tensor
,
Q
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
kabsch_rotation
(
P
:
torch
.
Tensor
,
Q
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""
Calculate the best rotation that minimises the RMSD between P and Q.
...
...
@@ -44,7 +44,7 @@ def kabsch_rotation(P:torch.Tensor, Q:torch.Tensor) -> torch.Tensor:
Q: [N * 3] the same dimension as P
return:
one 3*3 rotation matrix
one 3*3 rotation matrix
that best aligns the sorce and target atoms
"""
assert
P
.
shape
==
torch
.
Size
([
Q
.
shape
[
0
],
Q
.
shape
[
1
]])
...
...
@@ -188,8 +188,15 @@ def greedy_align(
Return:
A list of tuple(int,int) that provides instructions of how the ground truth chains should be permuated
e.g. if 3 chains in the imput model have the same sequences, an example return would be:
[(0,2),(1,1),(2,0)], meaning the 1st chain in the predicted structure should be aligned to the 3rd chain in the ground truth,
and the 2nd chain in the predicted structure is ok to stay with the 2nd chain in the ground truth.
Note: the tuples in the returned list begin with 0 indexing but aym_id begins with 1. The reason why tuples in the return are 0-indexing
is that at the stage of loss calculation, the ground truth atom positions: true_ca_poses, are already split up into a list of matrices.
Hence, now this function needs to return tuples that provide the index to select from the list: true_ca_poses, and list index starts from 0.
"""
used
=
[
False
for
_
in
range
(
len
(
true_ca_poses
))]
used
=
[
False
for
_
in
range
(
len
(
true_ca_poses
))]
# a list the keeps recording whether a ground truth chain has been used or not
align
=
[]
unique_asym_ids
=
[
i
for
i
in
torch
.
unique
(
batch
[
"asym_id"
])
if
i
!=
0
]
for
cur_asym_id
in
unique_asym_ids
:
...
...
@@ -326,22 +333,22 @@ def get_per_asym_residue_index(features: dict) -> Dict[int, list]:
return
per_asym_residue_index
def
get_entity_2_asym_list
(
batch
:
dict
)
->
Dict
[
int
,
list
]:
def
get_entity_2_asym_list
(
features
:
dict
)
->
Dict
[
int
,
list
]:
"""
Generates a dictionary mapping unique entity IDs to lists of unique asymmetry IDs (asym_id) for each entity.
Args:
batch
(dict): A dictionary containing data
batch
es, including "entity_id" and "asym_id" tensors.
features
(dict): A dictionary containing data
featur
es, including "entity_id" and "asym_id" tensors.
Returns:
entity_2_asym_list (dict): A dictionary where keys are unique entity IDs, and values are lists of unique asymmetry IDs
associated with each entity.
"""
entity_2_asym_list
=
{}
unique_entity_ids
=
torch
.
unique
(
batch
[
"entity_id"
])
unique_entity_ids
=
torch
.
unique
(
features
[
"entity_id"
])
for
cur_ent_id
in
unique_entity_ids
:
ent_mask
=
batch
[
"entity_id"
]
==
cur_ent_id
cur_asym_id
=
torch
.
unique
(
batch
[
"asym_id"
][
ent_mask
])
ent_mask
=
features
[
"entity_id"
]
==
cur_ent_id
cur_asym_id
=
torch
.
unique
(
features
[
"asym_id"
][
ent_mask
])
entity_2_asym_list
[
int
(
cur_ent_id
)]
=
cur_asym_id
return
entity_2_asym_list
...
...
@@ -428,7 +435,10 @@ def compute_permutation_alignment(out: Dict[str,torch.Tensor],
because the mapping between the predicted and ground-truth will become arbitrary.
The model cannot be assumed to predict chains in the same order as the ground truth.
Thus, this function pick the optimal permutaion of predicted chains that best matches the ground truth,
by minimising the RMSD.
by minimising the RMSD i.e. the best permutation of ground truth chains is selected based on which permutation has the lowest RMSD calculation
Details are described in Section 7.3 in the Supplementary of AlphaFold-Multimer paper:
https://www.biorxiv.org/content/10.1101/2021.10.04.463034v2
Args:
out: a dictionary of output tensors from model.forward()
...
...
@@ -438,10 +448,6 @@ def compute_permutation_alignment(out: Dict[str,torch.Tensor],
Returns:
a list of tuple(int,int) that instructs how ground truth chains should be permutated
a dictionary recording which residues belong to which aysm_id
Details are described in Section 7.3 in the Supplementary of AlphaFold-Multimer paper:
https://www.biorxiv.org/content/10.1101/2021.10.04.463034v2
"""
unique_asym_ids
=
set
(
torch
.
unique
(
features
[
'asym_id'
]).
tolist
())
unique_asym_ids
.
discard
(
0
)
# Remove padding asym_id
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment