Commit 68389359 authored by Geoffrey Yu's avatar Geoffrey Yu
Browse files

update permutation logic so that it check all valid anchor pairs

parent 2da285aa
...@@ -1828,7 +1828,6 @@ def get_least_asym_entity_or_longest_length(batch,input_asym_id): ...@@ -1828,7 +1828,6 @@ def get_least_asym_entity_or_longest_length(batch,input_asym_id):
anchor_pred_asym_ids = [id for id in entity_2_asym_list[least_asym_entities] if id in input_asym_id] anchor_pred_asym_ids = [id for id in entity_2_asym_list[least_asym_entities] if id in input_asym_id]
return anchor_gt_asym_id, anchor_pred_asym_ids return anchor_gt_asym_id, anchor_pred_asym_ids
def greedy_align( def greedy_align(
batch, batch,
per_asym_residue_index, per_asym_residue_index,
...@@ -1897,15 +1896,12 @@ def merge_labels(per_asym_residue_index,labels, align,original_nres): ...@@ -1897,15 +1896,12 @@ def merge_labels(per_asym_residue_index,labels, align,original_nres):
cur_out = {} cur_out = {}
for i, j in align: for i, j in align:
label = labels[j][k] label = labels[j][k]
cur_num_res = labels[j]['aatype'].shape[-1]
# to 1-based # to 1-based
cur_residue_index = per_asym_residue_index[i + 1] cur_residue_index = per_asym_residue_index[i + 1]
if len(v.shape)<=1 or "template" in k or "row_mask" in k : if len(v.shape)<=1 or "template" in k or "row_mask" in k :
continue continue
else: else:
dimension_to_merge = label.shape.index(cur_num_res) if cur_num_res in label.shape else 0 dimension_to_merge = 1
if k =='all_atom_positions':
dimension_to_merge=1
cur_out[i] = label.index_select(dimension_to_merge,cur_residue_index) cur_out[i] = label.index_select(dimension_to_merge,cur_residue_index)
cur_out = [x[1] for x in sorted(cur_out.items())] cur_out = [x[1] for x in sorted(cur_out.items())]
if len(cur_out)>0: if len(cur_out)>0:
...@@ -2144,16 +2140,17 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss): ...@@ -2144,16 +2140,17 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
feature, ground_truth = batch feature, ground_truth = batch
del batch del batch
if permutate_chains: if permutate_chains:
best_rmsd = float('inf')
best_align = None
# First select anchors from predicted structures and ground truths # First select anchors from predicted structures and ground truths
anchor_gt_asym, anchor_pred_asym = get_least_asym_entity_or_longest_length(ground_truth,feature['asym_id']) anchor_gt_asym, anchor_pred_asym_ids = get_least_asym_entity_or_longest_length(ground_truth,feature['asym_id'])
print(f"########## line 2147 anchor_pred_asym_ids is {anchor_pred_asym_ids} and gt_asym is {anchor_gt_asym}")
entity_2_asym_list = AlphaFoldMultimerLoss.get_entity_2_asym_list(ground_truth) entity_2_asym_list = AlphaFoldMultimerLoss.get_entity_2_asym_list(ground_truth)
labels = AlphaFoldMultimerLoss.split_ground_truth_labels(ground_truth, labels = AlphaFoldMultimerLoss.split_ground_truth_labels(ground_truth,
REQUIRED_FEATURES=["all_atom_mask","all_atom_positions"]) REQUIRED_FEATURES=["all_atom_mask","all_atom_positions"])
assert isinstance(labels, list) assert isinstance(labels, list)
del ground_truth del ground_truth
anchor_gt_idx = int(anchor_gt_asym) - 1 anchor_gt_idx = int(anchor_gt_asym) - 1
asym_mask = (feature["asym_id"] == anchor_pred_asym).bool()
# Then calculate optimal transform by aligning anchors # Then calculate optimal transform by aligning anchors
ca_idx = rc.atom_order["CA"] ca_idx = rc.atom_order["CA"]
pred_ca_pos = out["final_atom_positions"][..., ca_idx, :] # [bsz, nres, 3] pred_ca_pos = out["final_atom_positions"][..., ca_idx, :] # [bsz, nres, 3]
...@@ -2165,28 +2162,36 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss): ...@@ -2165,28 +2162,36 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
true_ca_masks = [ true_ca_masks = [
l["all_atom_mask"][..., ca_idx].long() for l in labels l["all_atom_mask"][..., ca_idx].long() for l in labels
] # list([nres,]) ] # list([nres,])
per_asym_residue_index = AlphaFoldMultimerLoss.get_per_asym_residue_index(feature) per_asym_residue_index = AlphaFoldMultimerLoss.get_per_asym_residue_index(feature)
anchor_gt_residue = per_asym_residue_index[int(anchor_gt_asym)] for candidate_pred_anchor in anchor_pred_asym_ids:
r, x = AlphaFoldMultimerLoss.calculate_optimal_transform(true_ca_poses, asym_mask = (feature["asym_id"] == candidate_pred_anchor).bool()
anchor_gt_idx,anchor_gt_residue, anchor_gt_residue = per_asym_residue_index[int(candidate_pred_anchor)]
true_ca_masks,pred_ca_mask, r, x = AlphaFoldMultimerLoss.calculate_optimal_transform(true_ca_poses,
asym_mask, anchor_gt_idx,anchor_gt_residue,
pred_ca_pos true_ca_masks,pred_ca_mask,
) asym_mask,
aligned_true_ca_poses = [ca.to(r.dtype) @ r + x for ca in true_ca_poses] # apply transforms pred_ca_pos
del true_ca_poses,r,x )
gc.collect() aligned_true_ca_poses = [ca.to(r.dtype) @ r + x for ca in true_ca_poses] # apply transforms
align = greedy_align( align = greedy_align(
feature, feature,
per_asym_residue_index, per_asym_residue_index,
entity_2_asym_list, entity_2_asym_list,
pred_ca_pos, pred_ca_pos,
pred_ca_mask, pred_ca_mask,
aligned_true_ca_poses, aligned_true_ca_poses,
true_ca_masks, true_ca_masks,
) )
merged_labels = merge_labels(per_asym_residue_index,labels,align,
original_nres=feature['aatype'].shape[-1])
rmsd = compute_rmsd(true_atom_pos = merged_labels['all_atom_positions'][..., ca_idx, :].to(r.dtype) @ r + x,
pred_atom_pos = pred_ca_pos,
atom_mask = (pred_ca_mask * merged_labels['all_atom_mask'][..., ca_idx].long()).bool())
if rmsd < best_rmsd:
best_rmsd = rmsd
best_align = align
print(f"##### 2193 rmsd is {rmsd} and anchor_gt_asym is {anchor_gt_asym} and candidate_pred_anchor is {candidate_pred_anchor}")
del r,x
del true_ca_masks,aligned_true_ca_poses del true_ca_masks,aligned_true_ca_poses
del pred_ca_pos, pred_ca_mask del pred_ca_pos, pred_ca_mask
gc.collect() gc.collect()
...@@ -2195,9 +2200,9 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss): ...@@ -2195,9 +2200,9 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
per_asym_residue_index = AlphaFoldMultimerLoss.get_per_asym_residue_index(feature) per_asym_residue_index = AlphaFoldMultimerLoss.get_per_asym_residue_index(feature)
labels = AlphaFoldMultimerLoss.split_ground_truth_labels(ground_truth, labels = AlphaFoldMultimerLoss.split_ground_truth_labels(ground_truth,
REQUIRED_FEATURES=["all_atom_mask","all_atom_positions"]) REQUIRED_FEATURES=["all_atom_mask","all_atom_positions"])
align = list(enumerate(range(len(labels)))) best_align = list(enumerate(range(len(labels))))
return align, per_asym_residue_index return best_align, per_asym_residue_index
def forward(self, out, batch, _return_breakdown=False,permutate_chains=True): def forward(self, out, batch, _return_breakdown=False,permutate_chains=True):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment