Commit 67f873e7 authored by Geoffrey Yu's avatar Geoffrey Yu
Browse files

updated optimal transform function now

parent bd82338e
...@@ -1789,12 +1789,6 @@ def get_least_asym_entity_or_longest_length(batch): ...@@ -1789,12 +1789,6 @@ def get_least_asym_entity_or_longest_length(batch):
If there is a tie, e.g. AABB, first check which sequence is the longer/longest, If there is a tie, e.g. AABB, first check which sequence is the longer/longest,
then choose one of the corresponding subunits as anchor then choose one of the corresponding subunits as anchor
""" """
REQUIRED_FEATURES = ['entity_id','asym_id']
seq_length = batch['seq_length'].item()
# remove padding part before selecting candidate
remove_padding = lambda t: torch.index_select(t,dim=1,index=torch.arange(seq_length,device=t.device))
batch = {k:tensor_tree_map(remove_padding,batch[k]) for k in REQUIRED_FEATURES}
unique_entity_ids = torch.unique(batch["entity_id"]) unique_entity_ids = torch.unique(batch["entity_id"])
entity_asym_count = {} entity_asym_count = {}
entity_length = {} entity_length = {}
...@@ -1819,13 +1813,13 @@ def get_least_asym_entity_or_longest_length(batch): ...@@ -1819,13 +1813,13 @@ def get_least_asym_entity_or_longest_length(batch):
if len(least_asym_entities) > 1: if len(least_asym_entities) > 1:
least_asym_entities = random.choice(least_asym_entities) least_asym_entities = random.choice(least_asym_entities)
assert len(least_asym_entities)==1 assert len(least_asym_entities)==1
# best_pred_asym = torch.unique(batch["asym_id"][batch["entity_id"] == least_asym_entities[0]]) best_pred_asym = torch.unique(batch["asym_id"][batch["entity_id"] == least_asym_entities[0]])
# # If there is more than one chain in the predicted output that has the same sequence # # If there is more than one chain in the predicted output that has the same sequence
# # as the chosen ground truth anchor, then randomly picke one # # as the chosen ground truth anchor, then randomly picke one
# if len(best_pred_asym) > 1: if len(best_pred_asym) > 1:
# best_pred_asym = random.choice(best_pred_asym) best_pred_asym = random.choice(best_pred_asym)
best_pred_asym = least_asym_entities[0]
return least_asym_entities[0], best_pred_asym return least_asym_entities[0], best_pred_asym
...@@ -2037,15 +2031,9 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss): ...@@ -2037,15 +2031,9 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
def __init__(self, config): def __init__(self, config):
super(AlphaFoldMultimerLoss, self).__init__(config) super(AlphaFoldMultimerLoss, self).__init__(config)
self.config = config self.config = config
@staticmethod
def determine_split_dim(batch)->dict:
"""A method to determine which dimension to split in split_ground_truth_labels"""
padded_dim = batch['aatype'].shape[-1]
dim_dict = {k:list(v.shape).index(padded_dim) for k,v in batch.items() if padded_dim in v.shape}
return dim_dict
@staticmethod @staticmethod
def split_ground_truth_labels(batch,REQUIRED_FEATURES,dim_dict): def split_ground_truth_labels(batch,REQUIRED_FEATURES,split_dim=1):
""" """
Splits ground truth features according to chains Splits ground truth features according to chains
...@@ -2062,9 +2050,19 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss): ...@@ -2062,9 +2050,19 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
unique_asym_ids.append(padding_asym_id) unique_asym_ids.append(padding_asym_id)
asym_id_counts.append(padding_asym_counts) asym_id_counts.append(padding_asym_counts)
labels = list(map(dict, zip(*[[(k, v) for v in torch.split(value, asym_id_counts, dim=dim_dict[k])] for k, value in batch.items() if k in REQUIRED_FEATURES]))) labels = list(map(dict, zip(*[[(k, v) for v in torch.split(value, asym_id_counts, dim=split_dim)] for k, value in batch.items() if k in REQUIRED_FEATURES])))
return labels return labels
@staticmethod
def get_per_asym_residue_index(features):
unique_asym_ids = [i for i in torch.unique(features["asym_id"]) if i!=0]
per_asym_residue_index = {}
for cur_asym_id in unique_asym_ids:
asym_mask = (features["asym_id"] == cur_asym_id).bool()
per_asym_residue_index[int(cur_asym_id)] = torch.masked_select(features["residue_index"], asym_mask)
return per_asym_residue_index
@staticmethod @staticmethod
def get_entity_2_asym_list(batch): def get_entity_2_asym_list(batch):
""" """
...@@ -2086,7 +2084,7 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss): ...@@ -2086,7 +2084,7 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
return entity_2_asym_list return entity_2_asym_list
@staticmethod @staticmethod
def calculate_input_mask(true_ca_masks,anchor_gt_idx, def calculate_input_mask(true_ca_masks,anchor_gt_idx,anchor_gt_residue,
asym_mask,pred_ca_mask): asym_mask,pred_ca_mask):
""" """
Calculate an input mask for downstream optimal transformation computation Calculate an input mask for downstream optimal transformation computation
...@@ -2103,24 +2101,24 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss): ...@@ -2103,24 +2101,24 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
pred_ca_mask = torch.squeeze(pred_ca_mask,0) pred_ca_mask = torch.squeeze(pred_ca_mask,0)
asym_mask = torch.squeeze(asym_mask,0) asym_mask = torch.squeeze(asym_mask,0)
anchor_pred_mask = pred_ca_mask[asym_mask] anchor_pred_mask = pred_ca_mask[asym_mask]
anchor_true_mask = true_ca_masks[anchor_gt_idx] anchor_true_mask = torch.index_select(true_ca_masks[anchor_gt_idx],1,anchor_gt_residue)
input_mask = (anchor_true_mask * anchor_pred_mask).bool() input_mask = (anchor_true_mask * anchor_pred_mask).bool()
return input_mask return input_mask
@staticmethod @staticmethod
def calculate_optimal_transform(true_ca_poses, def calculate_optimal_transform(true_ca_poses,
anchor_gt_idx, anchor_gt_idx,anchor_gt_residue,
true_ca_masks,pred_ca_mask, true_ca_masks,pred_ca_mask,
asym_mask, asym_mask,
pred_ca_pos): pred_ca_pos):
input_mask = AlphaFoldMultimerLoss.calculate_input_mask(true_ca_masks, input_mask = AlphaFoldMultimerLoss.calculate_input_mask(true_ca_masks,
anchor_gt_idx,asym_mask, anchor_gt_idx,anchor_gt_residue,
asym_mask,
pred_ca_mask) pred_ca_mask)
input_mask = torch.squeeze(input_mask,0) input_mask = torch.squeeze(input_mask,0)
pred_ca_pos = torch.squeeze(pred_ca_pos,0) pred_ca_pos = torch.squeeze(pred_ca_pos,0)
asym_mask = torch.squeeze(asym_mask,0) asym_mask = torch.squeeze(asym_mask,0)
anchor_true_pos = true_ca_poses[anchor_gt_idx] anchor_true_pos = torch.index_select(true_ca_poses[anchor_gt_idx],1,anchor_gt_residue)
anchor_pred_pos = pred_ca_pos[asym_mask] anchor_pred_pos = pred_ca_pos[asym_mask]
r, x = get_optimal_transform( r, x = get_optimal_transform(
anchor_pred_pos, torch.squeeze(anchor_true_pos,0), anchor_pred_pos, torch.squeeze(anchor_true_pos,0),
...@@ -2130,7 +2128,7 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss): ...@@ -2130,7 +2128,7 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
return r, x return r, x
@staticmethod @staticmethod
def multi_chain_perm_align(out, batch, dim_dict,permutate_chains=False): def multi_chain_perm_align(out, batch,permutate_chains=False):
""" """
A class method that first permutate chains in ground truth first A class method that first permutate chains in ground truth first
before calculating the loss. before calculating the loss.
...@@ -2138,17 +2136,21 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss): ...@@ -2138,17 +2136,21 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
Details are described in Section 7.3 in the Supplementary of AlphaFold-Multimer paper: Details are described in Section 7.3 in the Supplementary of AlphaFold-Multimer paper:
https://www.biorxiv.org/content/10.1101/2021.10.04.463034v2 https://www.biorxiv.org/content/10.1101/2021.10.04.463034v2
""" """
labels = AlphaFoldMultimerLoss.split_ground_truth_labels(batch,dim_dict=dim_dict, feature, ground_truth = batch
del batch
labels = AlphaFoldMultimerLoss.split_ground_truth_labels(ground_truth,
REQUIRED_FEATURES=["all_atom_mask","all_atom_positions"]) REQUIRED_FEATURES=["all_atom_mask","all_atom_positions"])
assert isinstance(labels, list) assert isinstance(labels, list)
print(f"successfully split ground truth labels")
if permutate_chains: if permutate_chains:
# First select anchors from predicted structures and ground truths # First select anchors from predicted structures and ground truths
anchor_gt_asym, anchor_pred_asym = get_least_asym_entity_or_longest_length(batch) anchor_gt_asym, anchor_pred_asym = get_least_asym_entity_or_longest_length(ground_truth)
print(f"###### anchor gt asym is: {anchor_gt_asym} and anchor pred asym is {anchor_pred_asym}")
anchor_gt_idx = int(anchor_gt_asym) - 1 anchor_gt_idx = int(anchor_gt_asym) - 1
asym_mask = (batch["asym_id"] == anchor_pred_asym).bool() asym_mask = (feature["asym_id"] == anchor_pred_asym).bool()
print(f"###### asym_mask is {asym_mask}")
# Then calculate optimal transform by aligning anchors # Then calculate optimal transform by aligning anchors
ca_idx = rc.atom_order["CA"] ca_idx = rc.atom_order["CA"]
pred_ca_pos = out["final_atom_positions"][..., ca_idx, :] # [bsz, nres, 3] pred_ca_pos = out["final_atom_positions"][..., ca_idx, :] # [bsz, nres, 3]
...@@ -2161,8 +2163,11 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss): ...@@ -2161,8 +2163,11 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
l["all_atom_mask"][..., ca_idx].long() for l in labels l["all_atom_mask"][..., ca_idx].long() for l in labels
] # list([nres,]) ] # list([nres,])
per_asym_residue_index = AlphaFoldMultimerLoss.get_per_asym_residue_index(feature)
anchor_gt_residue = per_asym_residue_index[int(anchor_gt_asym)]
print(f"######## per_asym_residue_index is {per_asym_residue_index}")
r, x = AlphaFoldMultimerLoss.calculate_optimal_transform(true_ca_poses, r, x = AlphaFoldMultimerLoss.calculate_optimal_transform(true_ca_poses,
anchor_gt_idx, anchor_gt_idx,anchor_gt_residue,
true_ca_masks,pred_ca_mask, true_ca_masks,pred_ca_mask,
asym_mask, asym_mask,
pred_ca_pos pred_ca_pos
...@@ -2170,7 +2175,9 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss): ...@@ -2170,7 +2175,9 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
aligned_true_ca_poses = [ca.to(r.dtype) @ r + x for ca in true_ca_poses] # apply transforms aligned_true_ca_poses = [ca.to(r.dtype) @ r + x for ca in true_ca_poses] # apply transforms
del true_ca_poses,r,x del true_ca_poses,r,x
gc.collect() gc.collect()
print(f"$$$$$$$ successfully calculated r and x")
import sys
sys.exit()
entity_2_asym_list = AlphaFoldMultimerLoss.get_entity_2_asym_list(batch) entity_2_asym_list = AlphaFoldMultimerLoss.get_entity_2_asym_list(batch)
align = greedy_align( align = greedy_align(
batch, batch,
...@@ -2189,7 +2196,7 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss): ...@@ -2189,7 +2196,7 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
return align return align
def forward(self, out, features, _return_breakdown=False,permutate_chains=True): def forward(self, out, batch, _return_breakdown=False,permutate_chains=True):
""" """
Overwrite AlphaFoldLoss forward function so that Overwrite AlphaFoldLoss forward function so that
it first compute multi-chain permutation it first compute multi-chain permutation
...@@ -2199,18 +2206,18 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss): ...@@ -2199,18 +2206,18 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
batch: a pair of input features and its corresponding ground truth structure batch: a pair of input features and its corresponding ground truth structure
""" """
# first check if it is a monomer # first check if it is a monomer
features, ground_truth = batch
del batch
is_monomer = len(torch.unique(features['asym_id']))==1 or torch.unique(features['asym_id']).tolist()==[0,1] is_monomer = len(torch.unique(features['asym_id']))==1 or torch.unique(features['asym_id']).tolist()==[0,1]
print(f"asym_id is {features['asym_id']}")
if not is_monomer: if not is_monomer:
permutate_chains = True permutate_chains = True
# first determin which dimension in the tensor to split into individual ground truth labels
dim_dict = AlphaFoldMultimerLoss.determine_split_dim(features)
# Then permutate ground truth chains before calculating the loss # Then permutate ground truth chains before calculating the loss
align = AlphaFoldMultimerLoss.multi_chain_perm_align(out, features,dim_dict=dim_dict, align = AlphaFoldMultimerLoss.multi_chain_perm_align(out, (features,ground_truth),permutate_chains=permutate_chains)
permutate_chains=permutate_chains)
labels = AlphaFoldMultimerLoss.split_ground_truth_labels(features,dim_dict=dim_dict, labels = AlphaFoldMultimerLoss.split_ground_truth_labels(features,
REQUIRED_FEATURES=[i for i in features.keys() if i in dim_dict]) REQUIRED_FEATURES=[i for i in features.keys() if i in dim_dict])
# reorder ground truth labels according to permutation results # reorder ground truth labels according to permutation results
labels = merge_labels(labels,align, labels = merge_labels(labels,align,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment