Commit f7571e25 authored by Geoffrey Yu's avatar Geoffrey Yu Committed by Jennifer Wei
Browse files

added typing hints and fixed some comments

parent 77cb4135
import logging import logging
import random import random
import torch import torch
from typing import Tuple, List,Dict
from openfold.np import residue_constants as rc from openfold.np import residue_constants as rc
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -13,6 +13,17 @@ def compute_rmsd( ...@@ -13,6 +13,17 @@ def compute_rmsd(
atom_mask: torch.Tensor = None, atom_mask: torch.Tensor = None,
eps: float = 1e-6, eps: float = 1e-6,
) -> torch.Tensor: ) -> torch.Tensor:
"""
Function to calculate RMSD between predicted and ground truth atom position
Args:
true_atom_pos: a [nres*3] tensor
pred_atom_pos: a [nres*3] tensor
atom_mask: a [1*nres] tensor
Return:
RMSD value between true and predicted atom positions
"""
sq_diff = torch.square(true_atom_pos - pred_atom_pos).sum(dim=-1, keepdim=False) sq_diff = torch.square(true_atom_pos - pred_atom_pos).sum(dim=-1, keepdim=False)
if atom_mask is not None: if atom_mask is not None:
sq_diff = torch.masked_select(sq_diff, atom_mask.to(sq_diff.device)) sq_diff = torch.masked_select(sq_diff, atom_mask.to(sq_diff.device))
...@@ -21,7 +32,7 @@ def compute_rmsd( ...@@ -21,7 +32,7 @@ def compute_rmsd(
return torch.sqrt(msd + eps) # prevent sqrt 0 return torch.sqrt(msd + eps) # prevent sqrt 0
def kabsch_rotation(P, Q): def kabsch_rotation(P:torch.Tensor, Q:torch.Tensor) -> torch.Tensor:
""" """
Calculate the best rotation that minimises the RMSD between P and Q. Calculate the best rotation that minimises the RMSD between P and Q.
...@@ -29,11 +40,11 @@ def kabsch_rotation(P, Q): ...@@ -29,11 +40,11 @@ def kabsch_rotation(P, Q):
https://en.wikipedia.org/wiki/Kabsch_algorithm https://en.wikipedia.org/wiki/Kabsch_algorithm
Args: Args:
P: [N * 3] Nres is the number of atoms and each row corresponds to the atom's x,y,z coordinates P: [N * 3] Nres is the number of atoms and each row corresponds to the atom's x,y,z coordinates
Q: [N * 3] the same dimension as P Q: [N * 3] the same dimension as P
return: return:
A 3*3 rotation matrix one 3*3 rotation matrix
""" """
assert P.shape == torch.Size([Q.shape[0], Q.shape[1]]) assert P.shape == torch.Size([Q.shape[0], Q.shape[1]])
...@@ -54,11 +65,15 @@ def get_optimal_transform( ...@@ -54,11 +65,15 @@ def get_optimal_transform(
src_atoms: torch.Tensor, src_atoms: torch.Tensor,
tgt_atoms: torch.Tensor, tgt_atoms: torch.Tensor,
mask: torch.Tensor = None, mask: torch.Tensor = None,
): ) -> Tuple[torch.Tensor, torch.Tensor]:
""" """
src_atoms: predicted CA positions, shape:[num_res,3] A function that obtain the transformation that optimally align
tgt_atoms: ground-truth CA positions, shape:[num_res,3] src_atoms with tgt_atoms
mask: a vector of boolean values, shape:[num_res]
Args:
src_atoms: predicted CA positions, shape:[num_res,3]
tgt_atoms: ground-truth CA positions, shape:[num_res,3]
mask: a vector of boolean values, shape:[num_res]
""" """
assert src_atoms.shape == tgt_atoms.shape, (src_atoms.shape, tgt_atoms.shape) assert src_atoms.shape == tgt_atoms.shape, (src_atoms.shape, tgt_atoms.shape)
assert src_atoms.shape[-1] == 3 assert src_atoms.shape[-1] == 3
...@@ -88,7 +103,7 @@ def get_optimal_transform( ...@@ -88,7 +103,7 @@ def get_optimal_transform(
return r, x return r, x
def get_least_asym_entity_or_longest_length(batch, input_asym_id): def get_least_asym_entity_or_longest_length(batch:dict, input_asym_id:list)->Tuple[torch.Tensor, List[torch.Tensor]]:
""" """
First check how many subunit(s) one sequence has. Select the subunit that is less First check how many subunit(s) one sequence has. Select the subunit that is less
common, e.g. if the protein was AABBB then select one of the A as anchor common, e.g. if the protein was AABBB then select one of the A as anchor
...@@ -97,12 +112,12 @@ def get_least_asym_entity_or_longest_length(batch, input_asym_id): ...@@ -97,12 +112,12 @@ def get_least_asym_entity_or_longest_length(batch, input_asym_id):
then choose one of the corresponding subunits as anchor then choose one of the corresponding subunits as anchor
Args: Args:
batch: in this function batch is the full ground truth features batch: in this function batch is the full ground truth features
input_asym_id: A list of asym_ids that are in the cropped input features input_asym_id: A list of asym_ids that are in the cropped input features
Return: Return:
anchor_gt_asym_id: Tensor(int) selected ground truth asym_id anchor_gt_asym_id: Tensor(int) selected ground truth asym_id
anchor_pred_asym_ids: list(Tensor(int)) a list of all possible pred anchor candidates anchor_pred_asym_ids: list(Tensor(int)) a list of all possible pred anchor candidates
""" """
entity_2_asym_list = get_entity_2_asym_list(batch) entity_2_asym_list = get_entity_2_asym_list(batch)
unique_entity_ids = torch.unique(batch["entity_id"]) unique_entity_ids = torch.unique(batch["entity_id"])
...@@ -145,17 +160,29 @@ def get_least_asym_entity_or_longest_length(batch, input_asym_id): ...@@ -145,17 +160,29 @@ def get_least_asym_entity_or_longest_length(batch, input_asym_id):
def greedy_align( def greedy_align(
batch, batch:dict,
per_asym_residue_index, per_asym_residue_index:dict,
entity_2_asym_list, entity_2_asym_list:dict,
pred_ca_pos, pred_ca_pos:torch.Tensor,
pred_ca_mask, pred_ca_mask:torch.Tensor,
true_ca_poses, true_ca_poses:list,
true_ca_masks, true_ca_masks:list
): ) -> List[Tuple[int,int]]:
""" """
Implement Algorithm 4 in the Supplementary Information of AlphaFold-Multimer paper: Implement Algorithm 4 in the Supplementary Information of AlphaFold-Multimer paper:
Evans,R et al., 2022 Protein complex prediction with AlphaFold-Multimer, bioRxiv 2021.10.04.463034; doi: https://doi.org/10.1101/2021.10.04.463034 Evans,R et al., 2022 Protein complex prediction with AlphaFold-Multimer, bioRxiv 2021.10.04.463034; doi: https://doi.org/10.1101/2021.10.04.463034
Args:
batch: a dictionary of ground truth features
per_asym_residue_index: a dictionary recording which residues belong to which aysm_id
entity_2_asym_list: a dictionary recording which asym_id(s) belong to which entity_id
pred_ca_pos: predicted positions of c-alpha atoms from the results of model.forward()
pred_ca_mask: a boolean tensor that masks pred_ca_pos
true_ca_poses: a list of tensors, corresponding to the c-alpha positions of the ground truth structure. e.g. If there are 5 chains, this list will have a length of 5
true_ca_masks: a list of tensors, corresponding to the masks of c-alpha positions of the ground truth structure. If there are 5 chains, this list will have a length of 5
Return:
A list of tuple(int,int) that provides instructions of how the ground truth chains should be permuated
""" """
used = [False for _ in range(len(true_ca_poses))] used = [False for _ in range(len(true_ca_poses))]
align = [] align = []
...@@ -189,21 +216,38 @@ def greedy_align( ...@@ -189,21 +216,38 @@ def greedy_align(
return align return align
def pad_features(feature_tensor, nres_pad, pad_dim): def pad_features(feature_tensor:torch.Tensor, nres_pad:int, pad_dim:int) -> torch.Tensor:
"""Pad input feature tensor""" """
Pad input feature tensor. Padding values will be 0 and put behind the true feature values
Args:
feature_tensor: A feature tensor
nres_pad: number of residues to add
pad_dim: along which dimension of the feature_tensor to pad
Returns:
a padded feature tensor
"""
pad_shape = list(feature_tensor.shape) pad_shape = list(feature_tensor.shape)
pad_shape[pad_dim] = nres_pad pad_shape[pad_dim] = nres_pad
padding_tensor = feature_tensor.new_zeros(pad_shape, device=feature_tensor.device) padding_tensor = feature_tensor.new_zeros(pad_shape, device=feature_tensor.device)
return torch.concat((feature_tensor, padding_tensor), dim=pad_dim) return torch.concat((feature_tensor, padding_tensor), dim=pad_dim)
def merge_labels(per_asym_residue_index, labels, align, original_nres): def merge_labels(per_asym_residue_index:Dict[int,List[int]],
labels:dict, align:List[Tuple[int, int]],
original_nres:int) -> Dict[str,torch.Tensor]:
""" """
Merge ground truth labels according to the permutation results Merge ground truth labels according to the permutation results
labels: list of original ground truth feats Args:
align: list of tuples, each entry specify the corresponding label of the asym. per_asym_residue_index: a dictionary recording which residues belong to which aysm_id
labels: list of original ground truth feats e.g. if there're 5 chains, labels will have a length of 5
align: list of tuples, each entry specify the corresponding label of the asym.
original_nres: int, corresponding to the number of residues specified by crop_size in config.py
Returns:
A new dictionary of permuated ground truth features
modified based on UniFold: modified based on UniFold:
https://github.com/dptech-corp/Uni-Fold/blob/b1c89a2cebd4e4ee4c47b4e443f92beeb9138fbb/unifold/losses/chain_align.py#L176C1-L176C1 https://github.com/dptech-corp/Uni-Fold/blob/b1c89a2cebd4e4ee4c47b4e443f92beeb9138fbb/unifold/losses/chain_align.py#L176C1-L176C1
""" """
...@@ -230,13 +274,13 @@ def merge_labels(per_asym_residue_index, labels, align, original_nres): ...@@ -230,13 +274,13 @@ def merge_labels(per_asym_residue_index, labels, align, original_nres):
return outs return outs
def split_ground_truth_labels(gt_features): def split_ground_truth_labels(gt_features:dict) -> List[Dict]:
""" """
Splits ground truth features according to chains Splits ground truth features according to chains
Returns: Returns:
a list of feature dictionaries with only necessary ground truth features a list of feature dictionaries with only necessary ground truth features
required to finish multi-chain permutation required to finish multi-chain permutation
""" """
unique_asym_ids, asym_id_counts = torch.unique(gt_features["asym_id"], sorted=True, return_counts=True) unique_asym_ids, asym_id_counts = torch.unique(gt_features["asym_id"], sorted=True, return_counts=True)
n_res = gt_features["asym_id"].shape[-1] n_res = gt_features["asym_id"].shape[-1]
...@@ -251,7 +295,16 @@ def split_ground_truth_labels(gt_features): ...@@ -251,7 +295,16 @@ def split_ground_truth_labels(gt_features):
return labels return labels
def get_per_asym_residue_index(features): def get_per_asym_residue_index(features: dict) -> Dict[int,list]:
"""
A function that retrieve which residues belong to which asym_id
Args:
features: a dictionary that contains input features after cropping
Returns:
A dictionary that records which region of the sequence belongs to which asym_id
"""
unique_asym_ids = [i for i in torch.unique(features["asym_id"]) if i != 0] unique_asym_ids = [i for i in torch.unique(features["asym_id"]) if i != 0]
per_asym_residue_index = {} per_asym_residue_index = {}
for cur_asym_id in unique_asym_ids: for cur_asym_id in unique_asym_ids:
...@@ -261,7 +314,7 @@ def get_per_asym_residue_index(features): ...@@ -261,7 +314,7 @@ def get_per_asym_residue_index(features):
return per_asym_residue_index return per_asym_residue_index
def get_entity_2_asym_list(batch): def get_entity_2_asym_list(batch: dict) -> Dict[int,list]:
""" """
Generates a dictionary mapping unique entity IDs to lists of unique asymmetry IDs (asym_id) for each entity. Generates a dictionary mapping unique entity IDs to lists of unique asymmetry IDs (asym_id) for each entity.
...@@ -281,14 +334,16 @@ def get_entity_2_asym_list(batch): ...@@ -281,14 +334,16 @@ def get_entity_2_asym_list(batch):
return entity_2_asym_list return entity_2_asym_list
def calculate_input_mask(true_ca_masks, anchor_gt_idx, anchor_gt_residue, def calculate_input_mask(true_ca_masks:List[torch.Tensor], anchor_gt_idx:torch.Tensor,
asym_mask, pred_ca_mask): anchor_gt_residue:list,
asym_mask:torch.Tensor, pred_ca_mask:torch.Tensor) -> torch.Tensor:
""" """
Calculate an input mask for downstream optimal transformation computation Calculate an input mask for downstream optimal transformation computation
Args: Args:
true_ca_masks (Tensor): ca mask from ground truth. true_ca_masks: list of masks from ground truth chains.
anchor_gt_idx (Tensor): The index of selected ground truth anchor. anchor_gt_idx (Tensor): a tensor with one integer in it. The index of selected ground truth anchor.
anchor_gt_residue:a list of residue indexes that belongs to the selected ground truth anchor
asym_mask (Tensor): Boolean tensor indicating which regions are selected predicted anchor. asym_mask (Tensor): Boolean tensor indicating which regions are selected predicted anchor.
pred_ca_mask (Tensor): ca mask from predicted structure. pred_ca_mask (Tensor): ca mask from predicted structure.
...@@ -303,11 +358,26 @@ def calculate_input_mask(true_ca_masks, anchor_gt_idx, anchor_gt_residue, ...@@ -303,11 +358,26 @@ def calculate_input_mask(true_ca_masks, anchor_gt_idx, anchor_gt_residue,
return input_mask return input_mask
def calculate_optimal_transform(true_ca_poses, def calculate_optimal_transform(true_ca_poses:List[torch.Tensor],
anchor_gt_idx, anchor_gt_residue, anchor_gt_idx:int, anchor_gt_residue:list,
true_ca_masks, pred_ca_mask, true_ca_masks:List[torch.Tensor], pred_ca_mask:torch.Tensor,
asym_mask, asym_mask:torch.Tensor,
pred_ca_pos): pred_ca_pos:torch.Tensor):
"""
Takes selected anchor ground truth c-alpha positions and
selected predicted anchor c-alpha position then calculate the optimal rotation matrix
to align ground-truth anchor and predicted anchor
Args:
true_ca_poses: a list of tensors, corresponding to the c-alpha positions of the ground truth structure. e.g. If there are 5 chains, this list will have a length of 5
anchor_gt_idx (Tensor): a tensor with one integer in it. The index of selected ground truth anchor.
anchor_gt_residue:a list of residue indexes that belongs to the selected ground truth anchor
true_ca_masks: list of masks from ground truth chains e.g. it will be length=5 if there are 5 chains in ground truth structure
pred_ca_mask: A boolean tensor corresponds to the mask to mask the predicted features
asym_mask: A boolean tensor that mask out other elements in a tensor if they do not belong to a this asym_id
pred_ca_pos: a [nres*3] tensor of predicted c-alpha atom positions
"""
input_mask = calculate_input_mask(true_ca_masks, input_mask = calculate_input_mask(true_ca_masks,
anchor_gt_idx, anchor_gt_idx,
anchor_gt_residue, anchor_gt_residue,
...@@ -326,13 +396,25 @@ def calculate_optimal_transform(true_ca_poses, ...@@ -326,13 +396,25 @@ def calculate_optimal_transform(true_ca_poses,
return r, x return r, x
def compute_permutation_alignment(out, features, ground_truth): def compute_permutation_alignment(out:Dict[str,torch.Tensor],
features:Dict[str,torch.Tensor],
ground_truth:List[Dict[str, torch.Tensor]]) -> Tuple[List[Tuple[int,int]], Dict[int,List[int]]]:
""" """
A class method that first permutate chains in ground truth first Permutates chains in ground truth first
before calculating the loss. before calculating the loss.
Args:
out: a dictionary of output tensors from model.forward()
features: a dictionary of feature tensors that are used as input for model.forward()
ground_truth: a list of dictionaries of features corresponding to chains in ground truth structure e.g. it will be a length of 5 if there are 5 chains in ground truth structure
Returns:
best_align: a list of tuple(int,int) that instructs how ground truth chains should be permutated
per_asym_residue_index: per_asym_residue_index: a dictionary recording which residues belong to which aysm_id
Details are described in Section 7.3 in the Supplementary of AlphaFold-Multimer paper: Details are described in Section 7.3 in the Supplementary of AlphaFold-Multimer paper:
https://www.biorxiv.org/content/10.1101/2021.10.04.463034v2 https://www.biorxiv.org/content/10.1101/2021.10.04.463034v2
""" """
unique_asym_ids = set(torch.unique(features['asym_id']).tolist()) unique_asym_ids = set(torch.unique(features['asym_id']).tolist())
unique_asym_ids.discard(0) # Remove padding asym_id unique_asym_ids.discard(0) # Remove padding asym_id
...@@ -397,13 +479,19 @@ def compute_permutation_alignment(out, features, ground_truth): ...@@ -397,13 +479,19 @@ def compute_permutation_alignment(out, features, ground_truth):
return best_align, per_asym_residue_index return best_align, per_asym_residue_index
def multi_chain_permutation_align(out, features, ground_truth): def multi_chain_permutation_align(out:Dict[str,torch.Tensor],
"""Compute multi-chain permutation alignment. features:Dict[str,torch.Tensor],
ground_truth:List[Dict[str, torch.Tensor]])->Dict[str,torch.Tensor]:
"""
Compute multi-chain permutation alignment.
Args: Args:
out: The output of model.forward() out: a dictionary of output tensors from model.forward()
features: Input features features: a dictionary of feature tensors that are used as input for model.forward()
ground_truth: Ground truth features ground_truth: a list of dictionaries of features corresponding to chains in ground truth structure e.g. it will be a length of 5 if there are 5 chains in ground truth structure
Returns:
features: a dictionary with updated ground truth feature tensors, ready for downstream loss calculations.
""" """
labels = split_ground_truth_labels(ground_truth) labels = split_ground_truth_labels(ground_truth)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment