Commit 15f1fa63 authored by Geoffrey Yu
Browse files

cleaned up and split into smaller functions

parent ea7fcced
......@@ -1830,7 +1830,6 @@ def get_least_asym_entity_or_longest_length(batch):
def greedy_align(
batch,
unique_asym_ids,
entity_2_asym_list,
pred_ca_pos,
pred_ca_mask,
......@@ -1841,6 +1840,7 @@ def greedy_align(
Implement Algorithm 4 in the Supplementary Information of AlphaFold-Multimer paper:
Evans,R et al., 2022 Protein complex prediction with AlphaFold-Multimer, bioRxiv 2021.10.04.463034; doi: https://doi.org/10.1101/2021.10.04.463034
"""
unique_asym_ids = [i for i in torch.unique(batch["asym_id"]) if i!=0]
used = [False for _ in range(len(true_ca_poses))]
align = []
for cur_asym_id in unique_asym_ids:
......@@ -1860,17 +1860,13 @@ def greedy_align(
if cropped_pos.shape==cur_pred_pos.shape:
mask = true_ca_masks[j]
mask = torch.squeeze(mask,0)
print(f"cropped_pos shape: {cropped_pos.shape} cur_pred_pos shape: {cur_pred_pos.shape}")
print(f"mask shape: {mask.shape} and cur_pred_mask shape: {cur_pred_mask.shape} ")
rmsd = compute_rmsd(
torch.squeeze(cropped_pos,0), torch.squeeze(cur_pred_pos,0),
(cur_pred_mask * mask).bool()
)
print(f"rmsd is {rmsd}")
if (rmsd is not None) and (rmsd < best_rmsd):
best_rmsd = rmsd
best_idx = j
print(f"best_idx is {best_idx}")
assert best_idx is not None
used[best_idx] = True
......@@ -1887,9 +1883,10 @@ def pad_features(feature_tensor,nres_pad,pad_dim):
return torch.concat((feature_tensor,padding_tensor),dim=pad_dim)
def merge_labels(per_asym_residue_index, labels, align,original_nres):
def merge_labels(labels, align,original_nres):
"""
per_asym_residue_index: A dictionary that record which asym_id corresponds to which regions of residues in the multimer complex.
Merge ground truth labels according to the permutation results
labels: list of original ground truth feats
align: list of tuples, each entry specify the corresponding label of the asym.
......@@ -2051,7 +2048,8 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
"""
Splits ground truth features according to chains
Returns a list of feature dictionaries with only necessary ground truth features
Returns:
a list of feature dictionaries with only necessary ground truth features
required to finish multi-chain permutation
"""
unique_asym_ids, asym_id_counts = torch.unique(batch["asym_id"], sorted=False,return_counts=True)
......@@ -2066,6 +2064,70 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
labels = list(map(dict, zip(*[[(k, v) for v in torch.split(value, asym_id_counts, dim=dim_dict[k])] for k, value in batch.items() if k in REQUIRED_FEATURES])))
return labels
@staticmethod
def get_entity_2_asym_list(batch):
    """
    Map every unique entity ID in the batch to the asym IDs that occur with it.

    Args:
        batch (dict): feature dictionary containing the "entity_id" and "asym_id" tensors.
            "asym_id" is indexed with a boolean mask built from "entity_id", so the two
            tensors must have matching shapes.
    Returns:
        entity_2_asym_list (dict): keys are the unique entity IDs converted to Python ints;
            each value is the tensor returned by torch.unique over the asym IDs found at the
            positions where "entity_id" equals that key (a tensor, not a Python list).
            No entity ID is filtered out here.
    """
    entity_ids = batch["entity_id"]
    asym_ids = batch["asym_id"]
    # One entry per distinct entity; the boolean mask selects that entity's residues.
    return {
        int(ent_id): torch.unique(asym_ids[entity_ids == ent_id])
        for ent_id in torch.unique(entity_ids)
    }
@staticmethod
def calculate_input_mask(true_ca_masks, anchor_gt_idx,
                         asym_mask, pred_ca_mask):
    """
    Calculate an input mask for downstream optimal transformation computation

    Args:
        true_ca_masks: indexable collection of ground-truth CA masks, one entry per chain.
        anchor_gt_idx: index of the selected ground-truth anchor inside true_ca_masks
            (the caller in multi_chain_perm_align passes a Python int).
        asym_mask (Tensor): boolean tensor marking the residues of the selected predicted anchor.
        pred_ca_mask (Tensor): CA mask of the predicted structure.
    Returns:
        input_mask (Tensor): boolean tensor that is True only where both the ground-truth
            anchor mask and the predicted anchor mask are non-zero.
    """
    # torch.squeeze(..., 0) removes the leading dimension only when it has size 1,
    # so the boolean mask can index the residue dimension directly.
    pred_ca_mask = torch.squeeze(pred_ca_mask, 0)
    asym_mask = torch.squeeze(asym_mask, 0)

    anchor_pred_mask = pred_ca_mask[asym_mask]
    anchor_true_mask = true_ca_masks[anchor_gt_idx]

    # The element-wise product is non-zero only where both masks are set.
    return (anchor_true_mask * anchor_pred_mask).bool()
@staticmethod
def calculate_optimal_transform(true_ca_poses,
                                anchor_gt_idx,
                                true_ca_masks, pred_ca_mask,
                                asym_mask,
                                pred_ca_pos):
    """
    Compute the transform that relates the predicted anchor chain and the ground-truth anchor chain.

    Args:
        true_ca_poses: indexable collection of ground-truth CA positions, one entry per chain.
        anchor_gt_idx: index of the selected ground-truth anchor in true_ca_poses / true_ca_masks.
        true_ca_masks: indexable collection of ground-truth CA masks, one entry per chain.
        pred_ca_mask (Tensor): CA mask of the predicted structure.
        asym_mask (Tensor): boolean tensor marking the residues of the selected predicted anchor.
        pred_ca_pos (Tensor): CA positions of the predicted structure.
    Returns:
        r, x: the pair returned by get_optimal_transform when called with the predicted anchor
            positions, the ground-truth anchor positions and the joint validity mask.
            multi_chain_perm_align applies it to ground-truth coordinates as ``ca @ r + x``.
    """
    # Residues that are valid in both the ground-truth anchor and the predicted anchor.
    input_mask = AlphaFoldMultimerLoss.calculate_input_mask(true_ca_masks,
                                                            anchor_gt_idx, asym_mask,
                                                            pred_ca_mask)
    # Drop leading size-1 dimensions so positions and masks line up per residue.
    input_mask = torch.squeeze(input_mask, 0)
    pred_ca_pos = torch.squeeze(pred_ca_pos, 0)
    asym_mask = torch.squeeze(asym_mask, 0)

    anchor_true_pos = true_ca_poses[anchor_gt_idx]
    anchor_pred_pos = pred_ca_pos[asym_mask]

    r, x = get_optimal_transform(
        anchor_pred_pos, torch.squeeze(anchor_true_pos, 0),
        mask=input_mask
    )
    return r, x
@staticmethod
def multi_chain_perm_align(out, batch, dim_dict,permutate_chains=False):
"""
......@@ -2078,6 +2140,15 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
labels = AlphaFoldMultimerLoss.split_ground_truth_labels(batch,dim_dict=dim_dict,
REQUIRED_FEATURES=["all_atom_mask","all_atom_positions"])
assert isinstance(labels, list)
if permutate_chains:
# First select anchors from predicted structures and ground truths
anchor_gt_asym, anchor_pred_asym = get_least_asym_entity_or_longest_length(batch)
anchor_gt_idx = int(anchor_gt_asym) - 1
asym_mask = (batch["asym_id"] == anchor_pred_asym).bool()
# Then calculate optimal transform by aligning anchors
ca_idx = rc.atom_order["CA"]
pred_ca_pos = out["final_atom_positions"][..., ca_idx, :] # [bsz, nres, 3]
pred_ca_mask = out["final_atom_mask"][..., ca_idx].to(dtype=pred_ca_pos.dtype) # [bsz, nres]
......@@ -2085,56 +2156,22 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
true_ca_poses = [
l["all_atom_positions"][..., ca_idx, :] for l in labels
] # list([nres, 3])
true_ca_masks = [
l["all_atom_mask"][..., ca_idx].long() for l in labels
] # list([nres,])
unique_asym_ids = [i for i in torch.unique(batch["asym_id"]) if i!=0]
per_asym_residue_index = {}
for cur_asym_id in unique_asym_ids:
asym_mask = (batch["asym_id"] == cur_asym_id).bool()
per_asym_residue_index[int(cur_asym_id)] = torch.masked_select(batch["residue_index"], asym_mask)
if permutate_chains:
anchor_gt_asym, anchor_pred_asym = get_least_asym_entity_or_longest_length(batch)
print(f"anchor_gt_asym:{anchor_gt_asym} anchor_pred_asym:{anchor_pred_asym}")
anchor_gt_idx = int(anchor_gt_asym) - 1
unique_entity_ids = torch.unique(batch["entity_id"])
entity_2_asym_list = {}
for cur_ent_id in unique_entity_ids:
ent_mask = batch["entity_id"] == cur_ent_id
cur_asym_id = torch.unique(batch["asym_id"][ent_mask])
entity_2_asym_list[int(cur_ent_id)] = cur_asym_id
asym_mask = (batch["asym_id"] == anchor_pred_asym).bool()
anchor_true_pos = true_ca_poses[anchor_gt_idx]
anchor_pred_pos = pred_ca_pos[0][asym_mask[0]]
anchor_true_mask = true_ca_masks[anchor_gt_idx]
anchor_pred_mask = pred_ca_mask[0][asym_mask[0]]
input_mask = (anchor_true_mask * anchor_pred_mask).bool()
r, x = get_optimal_transform(
anchor_pred_pos, anchor_true_pos[0],
mask=input_mask[0]
r, x = AlphaFoldMultimerLoss.calculate_optimal_transform(true_ca_poses,
anchor_gt_idx,
true_ca_masks,pred_ca_mask,
asym_mask,
pred_ca_pos
)
del input_mask # just to save memory
del anchor_pred_mask
del anchor_true_mask
gc.collect()
aligned_true_ca_poses = [ca.to(r.dtype) @ r + x for ca in true_ca_poses] # apply transforms
del true_ca_poses
del true_ca_poses,r,x
gc.collect()
entity_2_asym_list = AlphaFoldMultimerLoss.get_entity_2_asym_list(batch)
align = greedy_align(
batch,
unique_asym_ids,
entity_2_asym_list,
pred_ca_pos,
pred_ca_mask,
......@@ -2142,16 +2179,13 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
true_ca_masks,
)
del aligned_true_ca_poses, true_ca_masks
del r, x
del true_ca_masks,aligned_true_ca_poses
del pred_ca_pos, pred_ca_mask
del anchor_pred_pos, anchor_true_pos
gc.collect()
print(f"finished multi-chain permutation and final align is {align}")
else:
align = list(enumerate(range(len(labels))))
return align, per_asym_residue_index
return align
def forward(self, out, features, _return_breakdown=False,permutate_chains=True):
"""
......@@ -2170,13 +2204,13 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
dim_dict = AlphaFoldMultimerLoss.determine_split_dim(features)
# Then permutate ground truth chains before calculating the loss
align, per_asym_residue_index = AlphaFoldMultimerLoss.multi_chain_perm_align(out, features,dim_dict=dim_dict,
align = AlphaFoldMultimerLoss.multi_chain_perm_align(out, features,dim_dict=dim_dict,
permutate_chains=permutate_chains)
labels = AlphaFoldMultimerLoss.split_ground_truth_labels(features,dim_dict=dim_dict,
REQUIRED_FEATURES=[i for i in features.keys() if i in dim_dict])
# reorder ground truth labels according to permutation results
labels = merge_labels(per_asym_residue_index,labels,align,
labels = merge_labels(labels,align,
original_nres=features['aatype'].shape[-1])
features.update(labels)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment