Implements Algorithm 4 from the Supplementary Information of the AlphaFold-Multimer paper:
Evans, R. et al. (2022). Protein complex prediction with AlphaFold-Multimer. bioRxiv 2021.10.04.463034; doi: https://doi.org/10.1101/2021.10.04.463034
Args:
batch: a dictionary of ground truth features
per_asym_residue_index: a dictionary recording which residues belong to which asym_id
entity_2_asym_list: a dictionary recording which asym_id(s) belong to which entity_id
pred_ca_pos: predicted C-alpha positions taken from the output of model.forward()
pred_ca_mask: a boolean tensor that masks pred_ca_pos
true_ca_poses: a list of tensors holding the C-alpha positions of each ground truth chain, e.g. a list of length 5 if there are 5 chains
true_ca_masks: a list of tensors holding the C-alpha masks of each ground truth chain, e.g. a list of length 5 if there are 5 chains
Return:
A list of tuple(int, int) describing how the ground truth chains should be permuted.
e.g. if 3 chains in the input model have the same sequence, an example return would be:
[(0,2),(1,1),(2,0)], meaning the 1st chain in the predicted structure should be aligned to the 3rd chain in the ground truth,
the 2nd chain in the predicted structure stays with the 2nd chain in the ground truth,
and the 3rd chain in the predicted structure should be aligned to the 1st chain in the ground truth.
Note: the tuples in the returned list are 0-indexed, whereas asym_id starts from 1. They are 0-indexed because,
by the time the loss is calculated, the ground truth atom positions (true_ca_poses) have already been split into a list of tensors,
so each tuple must index into the list true_ca_poses, and list indices start from 0.
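To make the return format concrete, here is a minimal sketch of how the list of (pred_idx, gt_idx) tuples could be used to reorder the ground truth chains. The helper name apply_permutation is hypothetical and not part of the original code, and strings stand in for the per-chain tensors.

```python
def apply_permutation(true_chains, align):
    """Reorder ground truth chains so that entry i pairs with predicted chain i.

    apply_permutation is a hypothetical helper used only for illustration.
    """
    ordered = [None] * len(align)
    for pred_idx, gt_idx in align:
        # Both indices are 0-based positions in the per-chain lists.
        ordered[pred_idx] = true_chains[gt_idx]
    return ordered


# Placeholder strings stand in for the per-chain tensors in true_ca_poses.
true_ca_poses = ["gt_chain_0", "gt_chain_1", "gt_chain_2"]
align = [(0, 2), (1, 1), (2, 0)]
reordered = apply_permutation(true_ca_poses, align)
```

With the example permutation from the docstring, the 1st and 3rd ground truth chains swap places while the 2nd stays where it is.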
"""
"""
used = [False for _ in range(len(true_ca_poses))]  # records whether each ground truth chain has already been used
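The role of the used list can be illustrated with a small greedy matching sketch. This is not the original implementation: it assumes the matching cost is the distance between chain centroids, that every predicted chain has at least one unused ground truth chain of the same entity, and that the prediction has already been aligned to the ground truth anchor. The names centroid, greedy_assign and entity_of_chain are hypothetical.

```python
import math


def centroid(coords):
    """Mean (x, y, z) position of a list of 3D points."""
    n = len(coords)
    return tuple(sum(p[i] for p in coords) / n for i in range(3))


def greedy_assign(pred_chains, true_chains, entity_of_chain):
    """Greedily pair each predicted chain with the nearest unused ground truth chain.

    entity_of_chain[i] is the (hypothetical) entity id of chain i; only chains
    of the same entity, i.e. the same sequence, may be paired.
    """
    used = [False for _ in range(len(true_chains))]
    align = []
    for pred_idx, pred in enumerate(pred_chains):
        pred_center = centroid(pred)
        best_idx, best_dist = None, math.inf
        for gt_idx, gt in enumerate(true_chains):
            if used[gt_idx] or entity_of_chain[gt_idx] != entity_of_chain[pred_idx]:
                continue
            dist = math.dist(pred_center, centroid(gt))
            if dist < best_dist:
                best_idx, best_dist = gt_idx, dist
        used[best_idx] = True  # each ground truth chain is assigned at most once
        align.append((pred_idx, best_idx))
    return align


pred_chains = [[(10, 0, 0), (11, 0, 0)], [(5, 0, 0), (6, 0, 0)], [(0, 0, 0), (1, 0, 0)]]
true_chains = [[(0, 0, 0), (1, 0, 0)], [(5, 0, 0), (6, 0, 0)], [(10, 0, 0), (11, 0, 0)]]
align = greedy_assign(pred_chains, true_chains, entity_of_chain=[1, 1, 1])
```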
gt_features: a dictionary yielded while iterating over the PyTorch Dataset, as returned by the upstream DataLoader.iter() method.
In the DataLoader pipeline, the tensors of all ground truth chains are concatenated so that the input stays in the same format as the monomer data pipeline.
This function is therefore needed to 1) detect the number of chains, i.e. unique(asym_id), and
2) split the concatenated tensors back into individual ones that correspond to individual asym_ids
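The two steps above can be sketched with plain Python lists standing in for the concatenated tensors. The function name split_by_asym_id is hypothetical; the original code operates on PyTorch tensors, which this sketch does not attempt to reproduce.

```python
def split_by_asym_id(asym_id, values):
    """Group per-residue values by chain and return one list per chain.

    The chains are returned in ascending order of asym_id.
    """
    chains = {}
    for chain_id, value in zip(asym_id, values):
        chains.setdefault(chain_id, []).append(value)
    return [chains[chain_id] for chain_id in sorted(chains)]


# Six residues spread over three chains (asym_id starts from 1).
asym_id = [1, 1, 1, 2, 2, 3]
values = ["a", "b", "c", "d", "e", "f"]
per_chain = split_by_asym_id(asym_id, values)
num_chains = len(per_chain)
```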
Returns:
a list of feature dictionaries with only the ground truth features
required to finish multi-chain permutation, e.g. it will be a list of 5 elements if there are 5 chains in the ground truth structure
Takes the C-alpha positions of the selected ground truth anchor and
of the selected predicted anchor, then calculates the optimal rotation matrix
that aligns the ground truth anchor and the predicted anchor
Args:
true_ca_poses: a list of tensors holding the C-alpha positions of each ground truth chain, e.g. a list of length 5 if there are 5 chains
anchor_gt_idx (Tensor): a tensor holding one integer, the index of the selected ground truth anchor
anchor_gt_residue: a 1D tensor of residue indexes that belong to the selected ground truth anchor
true_ca_masks: a list of masks from the ground truth chains, e.g. a list of length 5 if there are 5 chains in the ground truth structure
pred_ca_mask: a boolean tensor used to mask the predicted features
asym_mask: a boolean tensor that masks out the elements of a tensor that do not belong to this asym_id
pred_ca_pos: a tensor of shape [nres, 3] holding the predicted C-alpha atom positions
Process:
1) select an anchor chain from the ground truth, denoted by anchor_gt_idx, and
an anchor chain from the predicted structure. Both anchor_gt and anchor_pred have exactly the same sequence
2) obtain the C-alpha positions corresponding to the selected anchor_gt, done by slicing true_ca_poses according to anchor_gt_residue
3) calculate the optimal transformation that best aligns the C-alpha atoms of anchor_pred to those of anchor_gt,
done by the Kabsch algorithm: source https://en.wikipedia.org/wiki/Kabsch_algorithm
Returns:
a rotation matrix r that records the optimal rotation
that best aligns the selected anchor prediction to the selected anchor truth
a matrix that records how the atoms should be shifted after applying r, i.e. the optimal alignment requires 1) rotating and 2) shifting the positions
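The Kabsch step can be sketched as follows. This is an illustrative version written with NumPy rather than the PyTorch tensors used by the original code; the function name kabsch, the translation name x, and the row-vector convention (pred @ r + x) are assumptions made for this example.

```python
import numpy as np


def kabsch(pred, true, mask):
    """Return (r, x) so that pred @ r + x is the least-squares fit to true.

    pred, true: arrays of shape [n, 3]; mask: boolean array of shape [n]
    selecting the atoms that take part in the fit.
    """
    pred, true = pred[mask], true[mask]
    pred_center = pred.mean(axis=0)
    true_center = true.mean(axis=0)
    # 3x3 covariance matrix between the two centered point sets.
    h = (pred - pred_center).T @ (true - true_center)
    u, _, vt = np.linalg.svd(h)
    # Flip the last axis if needed so that r is a proper rotation (det = +1).
    d = np.sign(np.linalg.det(u @ vt))
    r = u @ np.diag([1.0, 1.0, d]) @ vt
    x = true_center - pred_center @ r
    return r, x


# Build a ground truth anchor by rotating and shifting the predicted anchor.
theta = np.pi / 2
rot_z = np.array([[np.cos(theta), np.sin(theta), 0.0],
                  [-np.sin(theta), np.cos(theta), 0.0],
                  [0.0, 0.0, 1.0]])
anchor_pred = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]])
anchor_gt = anchor_pred @ rot_z + np.array([1.0, 2.0, 3.0])
mask = np.ones(4, dtype=bool)
r, x = kabsch(anchor_pred, anchor_gt, mask)
```

Applying the rotation first and the shift second, as described in the Returns section, maps anchor_pred onto anchor_gt in this noise-free example.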
out: a dictionary of output tensors from model.forward()
features: a dictionary of feature tensors that are used as input for model.forward()
ground_truth: a list of feature dictionaries, one per chain in the ground truth structure, e.g. a list of length 5 if there are 5 chains in the ground truth structure
Returns:
a list of tuple(int, int) that describes how the ground truth chains should be permuted
a dictionary recording which residues belong to which asym_id
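As an illustration of the dictionaries mentioned in these docstrings, the sketch below builds both the residue-per-chain mapping and the entity-to-chain mapping from flat per-residue lists. The function name build_chain_maps and the input names entity_id and residue_index are assumptions for this example; the original code may construct these dictionaries differently.

```python
def build_chain_maps(asym_id, entity_id, residue_index):
    """Build per_asym_residue_index and entity_2_asym_list from per-residue ids."""
    per_asym_residue_index = {}
    entity_2_asym_list = {}
    for a, e, r in zip(asym_id, entity_id, residue_index):
        # Residues that belong to chain a.
        per_asym_residue_index.setdefault(a, []).append(r)
        # Chains that share entity e, i.e. have the same sequence.
        asyms = entity_2_asym_list.setdefault(e, [])
        if a not in asyms:
            asyms.append(a)
    return per_asym_residue_index, entity_2_asym_list


asym_id = [1, 1, 2, 2, 3]
entity_id = [1, 1, 1, 1, 2]
residue_index = [0, 1, 0, 1, 0]
per_asym_residue_index, entity_2_asym_list = build_chain_maps(asym_id, entity_id, residue_index)
```

Here chains 1 and 2 share entity 1, so they are the only chains that may be swapped with each other during permutation.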
out: a dictionary of output tensors from model.forward()
features: a dictionary of feature tensors that are used as input for model.forward()
ground_truth: a list of feature dictionaries, one per chain in the ground truth structure, e.g. a list of length 5 if there are 5 chains in the ground truth structure
Returns:
features: a dictionary with updated ground truth feature tensors, ready for downstream loss calculations.