Commit 9a6eb649 authored by Dingquan Yu's avatar Dingquan Yu Committed by Jennifer Wei
Browse files

update comments;fixed typos

parent bc240326
......@@ -32,7 +32,7 @@ def compute_rmsd(
return torch.sqrt(msd + eps) # prevent sqrt 0
def kabsch_rotation(P:torch.Tensor, Q:torch.Tensor) -> torch.Tensor:
def kabsch_rotation(P: torch.Tensor, Q: torch.Tensor) -> torch.Tensor:
"""
Calculate the best rotation that minimises the RMSD between P and Q.
......@@ -44,7 +44,7 @@ def kabsch_rotation(P:torch.Tensor, Q:torch.Tensor) -> torch.Tensor:
Q: [N * 3] the same dimension as P
return:
one 3*3 rotation matrix
one 3*3 rotation matrix that best aligns the source and target atoms
"""
assert P.shape == torch.Size([Q.shape[0], Q.shape[1]])
......@@ -187,9 +187,16 @@ def greedy_align(
true_ca_masks: a list of tensors, corresponding to the masks of c-alpha positions of the ground truth structure. If there are 5 chains, this list will have a length of 5
Return:
A list of tuple(int,int) that provides instructions of how the ground truth chains should be permuated
A list of tuple(int,int) that provides instructions of how the ground truth chains should be permuted
e.g. if 3 chains in the input model have the same sequences, an example return would be:
[(0,2),(1,1),(2,0)], meaning the 1st chain in the predicted structure should be aligned to the 3rd chain in the ground truth,
and the 2nd chain in the predicted structure is ok to stay with the 2nd chain in the ground truth.
Note: the tuples in the returned list are 0-indexed, whereas asym_id starts from 1. The tuples are 0-indexed because,
at the stage of loss calculation, the ground truth atom positions (true_ca_poses) have already been split up into a list of matrices.
Hence, this function needs to return tuples that provide the index to select from the list true_ca_poses, and list indices start from 0.
"""
used = [False for _ in range(len(true_ca_poses))]
used = [False for _ in range(len(true_ca_poses))] # a list that keeps track of whether each ground truth chain has been used or not
align = []
unique_asym_ids = [i for i in torch.unique(batch["asym_id"]) if i != 0]
for cur_asym_id in unique_asym_ids:
......@@ -326,22 +333,22 @@ def get_per_asym_residue_index(features: dict) -> Dict[int, list]:
return per_asym_residue_index
def get_entity_2_asym_list(batch: dict) -> Dict[int, list]:
def get_entity_2_asym_list(features: dict) -> Dict[int, list]:
"""
Generates a dictionary mapping unique entity IDs to lists of unique asymmetry IDs (asym_id) for each entity.
Args:
batch (dict): A dictionary containing data batches, including "entity_id" and "asym_id" tensors.
features (dict): A dictionary containing data features, including "entity_id" and "asym_id" tensors.
Returns:
entity_2_asym_list (dict): A dictionary where keys are unique entity IDs, and values are lists of unique asymmetry IDs
associated with each entity.
"""
entity_2_asym_list = {}
unique_entity_ids = torch.unique(batch["entity_id"])
unique_entity_ids = torch.unique(features["entity_id"])
for cur_ent_id in unique_entity_ids:
ent_mask = batch["entity_id"] == cur_ent_id
cur_asym_id = torch.unique(batch["asym_id"][ent_mask])
ent_mask = features["entity_id"] == cur_ent_id
cur_asym_id = torch.unique(features["asym_id"][ent_mask])
entity_2_asym_list[int(cur_ent_id)] = cur_asym_id
return entity_2_asym_list
......@@ -428,7 +435,10 @@ def compute_permutation_alignment(out: Dict[str,torch.Tensor],
because the mapping between the predicted and ground-truth will become arbitrary.
The model cannot be assumed to predict chains in the same order as the ground truth.
Thus, this function picks the optimal permutation of predicted chains that best matches the ground truth,
by minimising the RMSD.
by minimising the RMSD, i.e. the permutation of ground truth chains with the lowest RMSD is selected.
Details are described in Section 7.3 in the Supplementary of AlphaFold-Multimer paper:
https://www.biorxiv.org/content/10.1101/2021.10.04.463034v2
Args:
out: a dictionary of output tensors from model.forward()
......@@ -438,10 +448,6 @@ def compute_permutation_alignment(out: Dict[str,torch.Tensor],
Returns:
a list of tuple(int,int) that instructs how ground truth chains should be permuted
a dictionary recording which residues belong to which asym_id
Details are described in Section 7.3 in the Supplementary of AlphaFold-Multimer paper:
https://www.biorxiv.org/content/10.1101/2021.10.04.463034v2
"""
unique_asym_ids = set(torch.unique(features['asym_id']).tolist())
unique_asym_ids.discard(0) # Remove padding asym_id
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment