Commit 04d63782 authored by Geoffrey Yu's avatar Geoffrey Yu Committed by Jennifer Wei
Browse files

fixed typing errors; added more comments

parent 6d418384
import logging import logging
import random import random
import torch import torch
from typing import Tuple, List,Dict from typing import Tuple, List, Dict
from openfold.np import residue_constants as rc from openfold.np import residue_constants as rc
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -74,6 +74,11 @@ def get_optimal_transform( ...@@ -74,6 +74,11 @@ def get_optimal_transform(
src_atoms: predicted CA positions, shape:[num_res,3] src_atoms: predicted CA positions, shape:[num_res,3]
tgt_atoms: ground-truth CA positions, shape:[num_res,3] tgt_atoms: ground-truth CA positions, shape:[num_res,3]
mask: a vector of boolean values, shape:[num_res] mask: a vector of boolean values, shape:[num_res]
Returns:
a rotation matrix that records the optimal rotation
that will best align the selected anchor prediction to the selected anchor truth
a matrix that records how the atoms should be shifted after applying r, i.e. optimal alignment requires 1) rotating and 2) shifting the positions
""" """
assert src_atoms.shape == tgt_atoms.shape, (src_atoms.shape, tgt_atoms.shape) assert src_atoms.shape == tgt_atoms.shape, (src_atoms.shape, tgt_atoms.shape)
assert src_atoms.shape[-1] == 3 assert src_atoms.shape[-1] == 3
...@@ -103,7 +108,7 @@ def get_optimal_transform( ...@@ -103,7 +108,7 @@ def get_optimal_transform(
return r, x return r, x
def get_least_asym_entity_or_longest_length(batch:dict, input_asym_id:list)->Tuple[torch.Tensor, List[torch.Tensor]]: def get_least_asym_entity_or_longest_length(batch: dict, input_asym_id: list) -> Tuple[torch.Tensor, List[torch.Tensor]]:
""" """
First check how many subunit(s) one sequence has. Select the subunit that is less First check how many subunit(s) one sequence has. Select the subunit that is less
common, e.g. if the protein was AABBB then select one of the A as anchor common, e.g. if the protein was AABBB then select one of the A as anchor
...@@ -160,14 +165,14 @@ def get_least_asym_entity_or_longest_length(batch:dict, input_asym_id:list)->Tup ...@@ -160,14 +165,14 @@ def get_least_asym_entity_or_longest_length(batch:dict, input_asym_id:list)->Tup
def greedy_align( def greedy_align(
batch:dict, batch: dict,
per_asym_residue_index:dict, per_asym_residue_index: dict,
entity_2_asym_list:dict, entity_2_asym_list: dict,
pred_ca_pos:torch.Tensor, pred_ca_pos: torch.Tensor,
pred_ca_mask:torch.Tensor, pred_ca_mask: torch.Tensor,
true_ca_poses:list, true_ca_poses: list,
true_ca_masks:list true_ca_masks: list
) -> List[Tuple[int,int]]: ) -> List[Tuple[int, int]]:
""" """
Implement Algorithm 4 in the Supplementary Information of AlphaFold-Multimer paper: Implement Algorithm 4 in the Supplementary Information of AlphaFold-Multimer paper:
Evans,R et al., 2022 Protein complex prediction with AlphaFold-Multimer, bioRxiv 2021.10.04.463034; doi: https://doi.org/10.1101/2021.10.04.463034 Evans,R et al., 2022 Protein complex prediction with AlphaFold-Multimer, bioRxiv 2021.10.04.463034; doi: https://doi.org/10.1101/2021.10.04.463034
...@@ -216,7 +221,7 @@ def greedy_align( ...@@ -216,7 +221,7 @@ def greedy_align(
return align return align
def pad_features(feature_tensor:torch.Tensor, nres_pad:int, pad_dim:int) -> torch.Tensor: def pad_features(feature_tensor: torch.Tensor, nres_pad: int, pad_dim: int) -> torch.Tensor:
""" """
Pad input feature tensor. Padding values will be 0 and put behind the true feature values Pad input feature tensor. Padding values will be 0 and put behind the true feature values
...@@ -234,9 +239,9 @@ def pad_features(feature_tensor:torch.Tensor, nres_pad:int, pad_dim:int) -> torc ...@@ -234,9 +239,9 @@ def pad_features(feature_tensor:torch.Tensor, nres_pad:int, pad_dim:int) -> torc
return torch.concat((feature_tensor, padding_tensor), dim=pad_dim) return torch.concat((feature_tensor, padding_tensor), dim=pad_dim)
def merge_labels(per_asym_residue_index:Dict[int,List[int]], def merge_labels(per_asym_residue_index: Dict[int,List[int]],
labels:dict, align:List[Tuple[int, int]], labels: List[Dict], align: List[Tuple[int, int]],
original_nres:int) -> Dict[str,torch.Tensor]: original_nres: int) -> Dict[str, torch.Tensor]:
""" """
Merge ground truth labels according to the permutation results Merge ground truth labels according to the permutation results
...@@ -274,13 +279,20 @@ def merge_labels(per_asym_residue_index:Dict[int,List[int]], ...@@ -274,13 +279,20 @@ def merge_labels(per_asym_residue_index:Dict[int,List[int]],
return outs return outs
def split_ground_truth_labels(gt_features:dict) -> List[Dict]: def split_ground_truth_labels(gt_features: dict) -> List[Dict]:
""" """
Splits ground truth features according to chains Splits ground truth features according to chains
Args:
gt_features: A dictionary from the PyTorch Dataset iteration, which is returned by the upstream DataLoader.iter() method.
In the DataLoader pipeline, all tensors belonging to all the ground truth chains are concatenated so that it stays the same as the monomer data input format/pipeline;
thus, this function is needed to 1) detect the number of chains, i.e. unique(asym_id), and
2) split the concatenated tensors back into individual ones that correspond to individual asym_ids
Returns: Returns:
a list of feature dictionaries with only necessary ground truth features a list of feature dictionaries with only necessary ground truth features
required to finish multi-chain permutation required to finish multi-chain permutation, e.g. it will be a list of 5 elements if there
are 5 chains in total.
""" """
unique_asym_ids, asym_id_counts = torch.unique(gt_features["asym_id"], sorted=True, return_counts=True) unique_asym_ids, asym_id_counts = torch.unique(gt_features["asym_id"], sorted=True, return_counts=True)
n_res = gt_features["asym_id"].shape[-1] n_res = gt_features["asym_id"].shape[-1]
...@@ -295,7 +307,7 @@ def split_ground_truth_labels(gt_features:dict) -> List[Dict]: ...@@ -295,7 +307,7 @@ def split_ground_truth_labels(gt_features:dict) -> List[Dict]:
return labels return labels
def get_per_asym_residue_index(features: dict) -> Dict[int,list]: def get_per_asym_residue_index(features: dict) -> Dict[int, list]:
""" """
A function that retrieve which residues belong to which asym_id A function that retrieve which residues belong to which asym_id
...@@ -314,7 +326,7 @@ def get_per_asym_residue_index(features: dict) -> Dict[int,list]: ...@@ -314,7 +326,7 @@ def get_per_asym_residue_index(features: dict) -> Dict[int,list]:
return per_asym_residue_index return per_asym_residue_index
def get_entity_2_asym_list(batch: dict) -> Dict[int,list]: def get_entity_2_asym_list(batch: dict) -> Dict[int, list]:
""" """
Generates a dictionary mapping unique entity IDs to lists of unique asymmetry IDs (asym_id) for each entity. Generates a dictionary mapping unique entity IDs to lists of unique asymmetry IDs (asym_id) for each entity.
...@@ -334,9 +346,9 @@ def get_entity_2_asym_list(batch: dict) -> Dict[int,list]: ...@@ -334,9 +346,9 @@ def get_entity_2_asym_list(batch: dict) -> Dict[int,list]:
return entity_2_asym_list return entity_2_asym_list
def calculate_input_mask(true_ca_masks:List[torch.Tensor], anchor_gt_idx:torch.Tensor, def calculate_input_mask(true_ca_masks: List[torch.Tensor], anchor_gt_idx: torch.Tensor,
anchor_gt_residue:list, anchor_gt_residue: list,
asym_mask:torch.Tensor, pred_ca_mask:torch.Tensor) -> torch.Tensor: asym_mask: torch.Tensor, pred_ca_mask: torch.Tensor) -> torch.Tensor:
""" """
Calculate an input mask for downstream optimal transformation computation Calculate an input mask for downstream optimal transformation computation
...@@ -358,11 +370,11 @@ def calculate_input_mask(true_ca_masks:List[torch.Tensor], anchor_gt_idx:torch.T ...@@ -358,11 +370,11 @@ def calculate_input_mask(true_ca_masks:List[torch.Tensor], anchor_gt_idx:torch.T
return input_mask return input_mask
def calculate_optimal_transform(true_ca_poses:List[torch.Tensor], def calculate_optimal_transform(true_ca_poses: List[torch.Tensor],
anchor_gt_idx:int, anchor_gt_residue:list, anchor_gt_idx: int, anchor_gt_residue: list,
true_ca_masks:List[torch.Tensor], pred_ca_mask:torch.Tensor, true_ca_masks: List[torch.Tensor], pred_ca_mask: torch.Tensor,
asym_mask:torch.Tensor, asym_mask: torch.Tensor,
pred_ca_pos:torch.Tensor): pred_ca_pos: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
""" """
Takes selected anchor ground truth c-alpha positions and Takes selected anchor ground truth c-alpha positions and
...@@ -377,6 +389,18 @@ def calculate_optimal_transform(true_ca_poses:List[torch.Tensor], ...@@ -377,6 +389,18 @@ def calculate_optimal_transform(true_ca_poses:List[torch.Tensor],
pred_ca_mask: A boolean tensor corresponds to the mask to mask the predicted features pred_ca_mask: A boolean tensor corresponds to the mask to mask the predicted features
asym_mask: A boolean tensor that masks out other elements in a tensor if they do not belong to this asym_id asym_mask: A boolean tensor that masks out other elements in a tensor if they do not belong to this asym_id
pred_ca_pos: a [nres*3] tensor of predicted c-alpha atom positions pred_ca_pos: a [nres*3] tensor of predicted c-alpha atom positions
Process:
1) select an anchor chain from the ground truth, denoted by anchor_gt_idx, and
an anchor chain from the predicted structure. Both anchor_gt and anchor_pred have exactly the same sequence
2) obtain the C-alpha positions corresponding to the selected anchor_gt, done by slicing true_ca_poses according to anchor_gt_residue
3) calculate the optimal transformation that can best align the C-alpha atoms of anchor_pred to those of anchor_gt,
done by Kabsch algorithm: source https://en.wikipedia.org/wiki/Kabsch_algorithm
Returns:
a rotation matrix that records the optimal rotation
that will best align the selected anchor prediction to the selected anchor truth
a matrix that records how the atoms should be shifted after applying r, i.e. optimal alignment requires 1) rotating and 2) shifting the positions
""" """
input_mask = calculate_input_mask(true_ca_masks, input_mask = calculate_input_mask(true_ca_masks,
anchor_gt_idx, anchor_gt_idx,
...@@ -396,11 +420,11 @@ def calculate_optimal_transform(true_ca_poses:List[torch.Tensor], ...@@ -396,11 +420,11 @@ def calculate_optimal_transform(true_ca_poses:List[torch.Tensor],
return r, x return r, x
def compute_permutation_alignment(out:Dict[str,torch.Tensor], def compute_permutation_alignment(out: Dict[str,torch.Tensor],
features:Dict[str,torch.Tensor], features: Dict[str,torch.Tensor],
ground_truth:List[Dict[str, torch.Tensor]]) -> Tuple[List[Tuple[int,int]], Dict[int,List[int]]]: ground_truth: List[Dict[str, torch.Tensor]]) -> Tuple[List[Tuple[int, int]], Dict[int, List[int]]]:
""" """
Permutates chains in ground truth first A method that permutes chains in ground truth
before calculating the loss. before calculating the loss.
Args: Args:
...@@ -409,8 +433,8 @@ def compute_permutation_alignment(out:Dict[str,torch.Tensor], ...@@ -409,8 +433,8 @@ def compute_permutation_alignment(out:Dict[str,torch.Tensor],
ground_truth: a list of dictionaries of features corresponding to chains in ground truth structure e.g. it will be a length of 5 if there are 5 chains in ground truth structure ground_truth: a list of dictionaries of features corresponding to chains in ground truth structure e.g. it will be a length of 5 if there are 5 chains in ground truth structure
Returns: Returns:
best_align: a list of tuple(int,int) that instructs how ground truth chains should be permutated a list of tuple(int,int) that instructs how ground truth chains should be permutated
per_asym_residue_index: per_asym_residue_index: a dictionary recording which residues belong to which asym_id a dictionary recording which residues belong to which asym_id
Details are described in Section 7.3 in the Supplementary of AlphaFold-Multimer paper: Details are described in Section 7.3 in the Supplementary of AlphaFold-Multimer paper:
https://www.biorxiv.org/content/10.1101/2021.10.04.463034v2 https://www.biorxiv.org/content/10.1101/2021.10.04.463034v2
...@@ -479,9 +503,9 @@ def compute_permutation_alignment(out:Dict[str,torch.Tensor], ...@@ -479,9 +503,9 @@ def compute_permutation_alignment(out:Dict[str,torch.Tensor],
return best_align, per_asym_residue_index return best_align, per_asym_residue_index
def multi_chain_permutation_align(out:Dict[str,torch.Tensor], def multi_chain_permutation_align(out: Dict[str, torch.Tensor],
features:Dict[str,torch.Tensor], features: Dict[str, torch.Tensor],
ground_truth:List[Dict[str, torch.Tensor]])->Dict[str,torch.Tensor]: ground_truth: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
""" """
Compute multi-chain permutation alignment. Compute multi-chain permutation alignment.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment