Commit f3c1af45 authored by Geoffrey Yu's avatar Geoffrey Yu
Browse files

fixed error with anchor selections; updated permutation unittest

parent 15f1fa63
......@@ -1819,12 +1819,13 @@ def get_least_asym_entity_or_longest_length(batch):
if len(least_asym_entities) > 1:
least_asym_entities = random.choice(least_asym_entities)
assert len(least_asym_entities)==1
best_pred_asym = torch.unique(batch["asym_id"][batch["entity_id"] == least_asym_entities[0]])
# best_pred_asym = torch.unique(batch["asym_id"][batch["entity_id"] == least_asym_entities[0]])
# If there is more than one chain in the predicted output that has the same sequence
# as the chosen ground truth anchor, then randomly picke one
if len(best_pred_asym) > 1:
best_pred_asym = random.choice(best_pred_asym)
# # If there is more than one chain in the predicted output that has the same sequence
# # as the chosen ground truth anchor, then randomly picke one
# if len(best_pred_asym) > 1:
# best_pred_asym = random.choice(best_pred_asym)
best_pred_asym = least_asym_entities[0]
return least_asym_entities[0], best_pred_asym
......@@ -2159,6 +2160,7 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
true_ca_masks = [
l["all_atom_mask"][..., ca_idx].long() for l in labels
] # list([nres,])
r, x = AlphaFoldMultimerLoss.calculate_optimal_transform(true_ca_poses,
anchor_gt_idx,
true_ca_masks,pred_ca_mask,
......@@ -2200,6 +2202,7 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
is_monomer = len(torch.unique(features['asym_id']))==1 or torch.unique(features['asym_id']).tolist()==[0,1]
if not is_monomer:
permutate_chains = True
# first determin which dimension in the tensor to split into individual ground truth labels
dim_dict = AlphaFoldMultimerLoss.determine_split_dim(features)
......
......@@ -102,9 +102,10 @@ class TestPermutation(unittest.TestCase):
batch['all_atom_mask'] = true_atom_mask
dim_dict = AlphaFoldMultimerLoss.determine_split_dim(batch)
aligns,_ = AlphaFoldMultimerLoss.multi_chain_perm_align(out,batch,
aligns = AlphaFoldMultimerLoss.multi_chain_perm_align(out,batch,
dim_dict,
permutate_chains=True)
print(f"##### aligns is {aligns}")
possible_outcome = [[(0,1),(1,0),(2,3),(3,4),(4,2)],[(0,0),(1,1),(2,3),(3,4),(4,2)]]
wrong_outcome = [[(0,1),(1,0),(2,4),(3,2),(4,3)],[(0,0),(1,1),(2,2),(3,3),(4,4)]]
self.assertIn(aligns,possible_outcome)
......@@ -151,15 +152,15 @@ class TestPermutation(unittest.TestCase):
tensor_to_cuda = lambda t: t.to('cuda')
batch = tensor_tree_map(tensor_to_cuda,batch)
dim_dict = AlphaFoldMultimerLoss.determine_split_dim(batch)
aligns,per_asym_residue_index = AlphaFoldMultimerLoss.multi_chain_perm_align(out,
aligns = AlphaFoldMultimerLoss.multi_chain_perm_align(out,
batch,
dim_dict,
permutate_chains=True)
print(f"##### aligns is {aligns}")
labels = AlphaFoldMultimerLoss.split_ground_truth_labels(batch,dim_dict=dim_dict,
REQUIRED_FEATURES=[i for i in batch.keys() if i in dim_dict])
labels = merge_labels(per_asym_residue_index,labels,aligns,
labels = merge_labels(labels,aligns,
original_nres=batch['aatype'].shape[-1])
self.assertTrue(torch.equal(labels['residue_index'],batch['residue_index']))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment