"container/git@developer.sourcefind.cn:OpenDAS/dynamo.git" did not exist on "f122aa4ec1ce10f10919e608572a7e12f24243aa"
Commit 15f1fa63 authored by Geoffrey Yu's avatar Geoffrey Yu
Browse files

cleaned up and split into smaller functions

parent ea7fcced
...@@ -1830,7 +1830,6 @@ def get_least_asym_entity_or_longest_length(batch): ...@@ -1830,7 +1830,6 @@ def get_least_asym_entity_or_longest_length(batch):
def greedy_align( def greedy_align(
batch, batch,
unique_asym_ids,
entity_2_asym_list, entity_2_asym_list,
pred_ca_pos, pred_ca_pos,
pred_ca_mask, pred_ca_mask,
...@@ -1841,6 +1840,7 @@ def greedy_align( ...@@ -1841,6 +1840,7 @@ def greedy_align(
Implement Algorithm 4 in the Supplementary Information of AlphaFold-Multimer paper: Implement Algorithm 4 in the Supplementary Information of AlphaFold-Multimer paper:
Evans,R et al., 2022 Protein complex prediction with AlphaFold-Multimer, bioRxiv 2021.10.04.463034; doi: https://doi.org/10.1101/2021.10.04.463034 Evans,R et al., 2022 Protein complex prediction with AlphaFold-Multimer, bioRxiv 2021.10.04.463034; doi: https://doi.org/10.1101/2021.10.04.463034
""" """
unique_asym_ids = [i for i in torch.unique(batch["asym_id"]) if i!=0]
used = [False for _ in range(len(true_ca_poses))] used = [False for _ in range(len(true_ca_poses))]
align = [] align = []
for cur_asym_id in unique_asym_ids: for cur_asym_id in unique_asym_ids:
...@@ -1860,17 +1860,13 @@ def greedy_align( ...@@ -1860,17 +1860,13 @@ def greedy_align(
if cropped_pos.shape==cur_pred_pos.shape: if cropped_pos.shape==cur_pred_pos.shape:
mask = true_ca_masks[j] mask = true_ca_masks[j]
mask = torch.squeeze(mask,0) mask = torch.squeeze(mask,0)
print(f"cropped_pos shape: {cropped_pos.shape} cur_pred_pos shape: {cur_pred_pos.shape}")
print(f"mask shape: {mask.shape} and cur_pred_mask shape: {cur_pred_mask.shape} ")
rmsd = compute_rmsd( rmsd = compute_rmsd(
torch.squeeze(cropped_pos,0), torch.squeeze(cur_pred_pos,0), torch.squeeze(cropped_pos,0), torch.squeeze(cur_pred_pos,0),
(cur_pred_mask * mask).bool() (cur_pred_mask * mask).bool()
) )
print(f"rmsd is {rmsd}")
if (rmsd is not None) and (rmsd < best_rmsd): if (rmsd is not None) and (rmsd < best_rmsd):
best_rmsd = rmsd best_rmsd = rmsd
best_idx = j best_idx = j
print(f"best_idx is {best_idx}")
assert best_idx is not None assert best_idx is not None
used[best_idx] = True used[best_idx] = True
...@@ -1887,9 +1883,10 @@ def pad_features(feature_tensor,nres_pad,pad_dim): ...@@ -1887,9 +1883,10 @@ def pad_features(feature_tensor,nres_pad,pad_dim):
return torch.concat((feature_tensor,padding_tensor),dim=pad_dim) return torch.concat((feature_tensor,padding_tensor),dim=pad_dim)
def merge_labels(per_asym_residue_index, labels, align,original_nres): def merge_labels(labels, align,original_nres):
""" """
per_asym_residue_index: A dictionary that record which asym_id corresponds to which regions of residues in the multimer complex. Merge ground truth labels according to the permutation results
labels: list of original ground truth feats labels: list of original ground truth feats
align: list of tuples, each entry specify the corresponding label of the asym. align: list of tuples, each entry specify the corresponding label of the asym.
...@@ -2051,7 +2048,8 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss): ...@@ -2051,7 +2048,8 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
""" """
Splits ground truth features according to chains Splits ground truth features according to chains
Returns a list of feature dictionaries with only necessary ground truth features Returns:
a list of feature dictionaries with only necessary ground truth features
required to finish multi-chain permutation required to finish multi-chain permutation
""" """
unique_asym_ids, asym_id_counts = torch.unique(batch["asym_id"], sorted=False,return_counts=True) unique_asym_ids, asym_id_counts = torch.unique(batch["asym_id"], sorted=False,return_counts=True)
...@@ -2066,6 +2064,70 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss): ...@@ -2066,6 +2064,70 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
labels = list(map(dict, zip(*[[(k, v) for v in torch.split(value, asym_id_counts, dim=dim_dict[k])] for k, value in batch.items() if k in REQUIRED_FEATURES]))) labels = list(map(dict, zip(*[[(k, v) for v in torch.split(value, asym_id_counts, dim=dim_dict[k])] for k, value in batch.items() if k in REQUIRED_FEATURES])))
return labels return labels
@staticmethod
def get_entity_2_asym_list(batch):
"""
Generates a dictionary mapping unique entity IDs to lists of unique asymmetry IDs (asym_id) for each entity.
Args:
batch (dict): A dictionary containing data batches, including "entity_id" and "asym_id" tensors.
Returns:
entity_2_asym_list (dict): A dictionary where keys are unique entity IDs, and values are lists of unique asymmetry IDs
associated with each entity.
"""
entity_2_asym_list = {}
unique_entity_ids = torch.unique(batch["entity_id"])
for cur_ent_id in unique_entity_ids:
ent_mask = batch["entity_id"] == cur_ent_id
cur_asym_id = torch.unique(batch["asym_id"][ent_mask])
entity_2_asym_list[int(cur_ent_id)] = cur_asym_id
return entity_2_asym_list
@staticmethod
def calculate_input_mask(true_ca_masks,anchor_gt_idx,
asym_mask,pred_ca_mask):
"""
Calculate an input mask for downstream optimal transformation computation
Args:
true_ca_masks (Tensor): ca mask from ground truth.
anchor_gt_idx (Tensor): The index of selected ground truth anchor.
asym_mask (Tensor): Boolean tensor indicating which regions are selected predicted anchor.
pred_ca_mask (Tensor): ca mask from predicted structure.
Returns:
input_mask (Tensor): A boolean mask
"""
pred_ca_mask = torch.squeeze(pred_ca_mask,0)
asym_mask = torch.squeeze(asym_mask,0)
anchor_pred_mask = pred_ca_mask[asym_mask]
anchor_true_mask = true_ca_masks[anchor_gt_idx]
input_mask = (anchor_true_mask * anchor_pred_mask).bool()
return input_mask
@staticmethod
def calculate_optimal_transform(true_ca_poses,
                                anchor_gt_idx,
                                true_ca_masks, pred_ca_mask,
                                asym_mask,
                                pred_ca_pos):
    """
    Compute the rigid transform aligning the selected ground-truth anchor
    chain with the selected predicted anchor chain.

    Args:
        true_ca_poses (list[Tensor]): per-chain ground-truth CA positions.
        anchor_gt_idx (int): index of the selected ground-truth anchor chain.
        true_ca_masks (list[Tensor]): per-chain ground-truth CA masks.
        pred_ca_mask (Tensor): predicted CA mask with a leading batch dim of 1.
        asym_mask (Tensor): boolean tensor selecting the predicted anchor
            chain, leading batch dim of 1.
        pred_ca_pos (Tensor): predicted CA positions, leading batch dim of 1.

    Returns:
        tuple(Tensor, Tensor): rotation ``r`` and translation ``x`` as
        produced by ``get_optimal_transform`` — presumably applied downstream
        as ``pos @ r + x`` to bring ground truth into the predicted frame
        (TODO confirm against get_optimal_transform's convention).
    """
    # Residues valid in both the ground-truth and predicted anchors.
    joint_mask = torch.squeeze(
        AlphaFoldMultimerLoss.calculate_input_mask(
            true_ca_masks, anchor_gt_idx, asym_mask, pred_ca_mask
        ),
        0,
    )
    chain_selector = torch.squeeze(asym_mask, 0)
    anchor_pred_pos = torch.squeeze(pred_ca_pos, 0)[chain_selector]
    anchor_true_pos = torch.squeeze(true_ca_poses[anchor_gt_idx], 0)
    return get_optimal_transform(anchor_pred_pos, anchor_true_pos, mask=joint_mask)
@staticmethod @staticmethod
def multi_chain_perm_align(out, batch, dim_dict,permutate_chains=False): def multi_chain_perm_align(out, batch, dim_dict,permutate_chains=False):
""" """
...@@ -2078,63 +2140,38 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss): ...@@ -2078,63 +2140,38 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
labels = AlphaFoldMultimerLoss.split_ground_truth_labels(batch,dim_dict=dim_dict, labels = AlphaFoldMultimerLoss.split_ground_truth_labels(batch,dim_dict=dim_dict,
REQUIRED_FEATURES=["all_atom_mask","all_atom_positions"]) REQUIRED_FEATURES=["all_atom_mask","all_atom_positions"])
assert isinstance(labels, list) assert isinstance(labels, list)
ca_idx = rc.atom_order["CA"]
pred_ca_pos = out["final_atom_positions"][..., ca_idx, :] # [bsz, nres, 3]
pred_ca_mask = out["final_atom_mask"][..., ca_idx].to(dtype=pred_ca_pos.dtype) # [bsz, nres]
true_ca_poses = [
l["all_atom_positions"][..., ca_idx, :] for l in labels
] # list([nres, 3])
true_ca_masks = [
l["all_atom_mask"][..., ca_idx].long() for l in labels
] # list([nres,])
unique_asym_ids = [i for i in torch.unique(batch["asym_id"]) if i!=0]
per_asym_residue_index = {}
for cur_asym_id in unique_asym_ids:
asym_mask = (batch["asym_id"] == cur_asym_id).bool()
per_asym_residue_index[int(cur_asym_id)] = torch.masked_select(batch["residue_index"], asym_mask)
if permutate_chains: if permutate_chains:
# First select anchors from predicted structures and ground truths
anchor_gt_asym, anchor_pred_asym = get_least_asym_entity_or_longest_length(batch) anchor_gt_asym, anchor_pred_asym = get_least_asym_entity_or_longest_length(batch)
print(f"anchor_gt_asym:{anchor_gt_asym} anchor_pred_asym:{anchor_pred_asym}")
anchor_gt_idx = int(anchor_gt_asym) - 1 anchor_gt_idx = int(anchor_gt_asym) - 1
unique_entity_ids = torch.unique(batch["entity_id"])
entity_2_asym_list = {}
for cur_ent_id in unique_entity_ids:
ent_mask = batch["entity_id"] == cur_ent_id
cur_asym_id = torch.unique(batch["asym_id"][ent_mask])
entity_2_asym_list[int(cur_ent_id)] = cur_asym_id
asym_mask = (batch["asym_id"] == anchor_pred_asym).bool() asym_mask = (batch["asym_id"] == anchor_pred_asym).bool()
anchor_true_pos = true_ca_poses[anchor_gt_idx] # Then calculate optimal transform by aligning anchors
ca_idx = rc.atom_order["CA"]
anchor_pred_pos = pred_ca_pos[0][asym_mask[0]] pred_ca_pos = out["final_atom_positions"][..., ca_idx, :] # [bsz, nres, 3]
pred_ca_mask = out["final_atom_mask"][..., ca_idx].to(dtype=pred_ca_pos.dtype) # [bsz, nres]
anchor_true_mask = true_ca_masks[anchor_gt_idx]
anchor_pred_mask = pred_ca_mask[0][asym_mask[0]] true_ca_poses = [
l["all_atom_positions"][..., ca_idx, :] for l in labels
input_mask = (anchor_true_mask * anchor_pred_mask).bool() ] # list([nres, 3])
r, x = get_optimal_transform( true_ca_masks = [
anchor_pred_pos, anchor_true_pos[0], l["all_atom_mask"][..., ca_idx].long() for l in labels
mask=input_mask[0] ] # list([nres,])
r, x = AlphaFoldMultimerLoss.calculate_optimal_transform(true_ca_poses,
anchor_gt_idx,
true_ca_masks,pred_ca_mask,
asym_mask,
pred_ca_pos
) )
del input_mask # just to save memory
del anchor_pred_mask
del anchor_true_mask
gc.collect()
aligned_true_ca_poses = [ca.to(r.dtype) @ r + x for ca in true_ca_poses] # apply transforms aligned_true_ca_poses = [ca.to(r.dtype) @ r + x for ca in true_ca_poses] # apply transforms
del true_ca_poses del true_ca_poses,r,x
gc.collect() gc.collect()
entity_2_asym_list = AlphaFoldMultimerLoss.get_entity_2_asym_list(batch)
align = greedy_align( align = greedy_align(
batch, batch,
unique_asym_ids,
entity_2_asym_list, entity_2_asym_list,
pred_ca_pos, pred_ca_pos,
pred_ca_mask, pred_ca_mask,
...@@ -2142,16 +2179,13 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss): ...@@ -2142,16 +2179,13 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
true_ca_masks, true_ca_masks,
) )
del aligned_true_ca_poses, true_ca_masks del true_ca_masks,aligned_true_ca_poses
del r, x
del pred_ca_pos, pred_ca_mask del pred_ca_pos, pred_ca_mask
del anchor_pred_pos, anchor_true_pos
gc.collect() gc.collect()
print(f"finished multi-chain permutation and final align is {align}")
else: else:
align = list(enumerate(range(len(labels)))) align = list(enumerate(range(len(labels))))
return align, per_asym_residue_index return align
def forward(self, out, features, _return_breakdown=False,permutate_chains=True): def forward(self, out, features, _return_breakdown=False,permutate_chains=True):
""" """
...@@ -2170,13 +2204,13 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss): ...@@ -2170,13 +2204,13 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
dim_dict = AlphaFoldMultimerLoss.determine_split_dim(features) dim_dict = AlphaFoldMultimerLoss.determine_split_dim(features)
# Then permutate ground truth chains before calculating the loss # Then permutate ground truth chains before calculating the loss
align, per_asym_residue_index = AlphaFoldMultimerLoss.multi_chain_perm_align(out, features,dim_dict=dim_dict, align = AlphaFoldMultimerLoss.multi_chain_perm_align(out, features,dim_dict=dim_dict,
permutate_chains=permutate_chains) permutate_chains=permutate_chains)
labels = AlphaFoldMultimerLoss.split_ground_truth_labels(features,dim_dict=dim_dict, labels = AlphaFoldMultimerLoss.split_ground_truth_labels(features,dim_dict=dim_dict,
REQUIRED_FEATURES=[i for i in features.keys() if i in dim_dict]) REQUIRED_FEATURES=[i for i in features.keys() if i in dim_dict])
# reorder ground truth labels according to permutation results # reorder ground truth labels according to permutation results
labels = merge_labels(per_asym_residue_index,labels,align, labels = merge_labels(labels,align,
original_nres=features['aatype'].shape[-1]) original_nres=features['aatype'].shape[-1])
features.update(labels) features.update(labels)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment