Commit cec5a426 authored by Geoffrey Yu's avatar Geoffrey Yu
Browse files

added test3 to test if permuted tensors end up as expected

parent c5f16efc
......@@ -2095,6 +2095,7 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
per_asym_residue_index[int(cur_asym_id)] = torch.masked_select(batch["residue_index"], asym_mask)
if permutate_chains:
anchor_gt_asym, anchor_pred_asym = get_least_asym_entity_or_longest_length(batch)
print(f"anchor_gt_asym:{anchor_gt_asym} anchor_pred_asym:{anchor_pred_asym}")
anchor_gt_idx = int(anchor_gt_asym) - 1
unique_entity_ids = torch.unique(batch["entity_id"])
......@@ -2154,22 +2155,19 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
out: the output of model.forward()
batch: a pair of input features and its corresponding ground truth structure
"""
_is_monomer = len(torch.unique(features['asym_id']))==1 or torch.unique(features['asym_id']).tolist()==[0,1]
# first determine which dimension in the tensor to split into individual ground truth labels
dim_dict = AlphaFoldMultimerLoss.determine_split_dim(features)
# Then permutate ground truth chains before calculating the loss
align, per_asym_residue_index = AlphaFoldMultimerLoss.multi_chain_perm_align(out, features,dim_dict=dim_dict,
permutate_chains=permutate_chains)
if not _is_monomer:
# first determine which dimension in the tensor to split into individual ground truth labels
dim_dict = AlphaFoldMultimerLoss.determine_split_dim(features)
# Then permutate ground truth chains before calculating the loss
align, per_asym_residue_index = AlphaFoldMultimerLoss.multi_chain_perm_align(out, features,dim_dict=dim_dict,
permutate_chains=True)
labels = AlphaFoldMultimerLoss.split_ground_truth_labels(features,dim_dict=dim_dict,
REQUIRED_FEATURES=[i for i in features.keys() if i in dim_dict])
# reorder ground truth labels according to permutation results
labels = merge_labels(per_asym_residue_index,labels,align,
original_nres=features['aatype'].shape[-1])
features.update(labels)
labels = AlphaFoldMultimerLoss.split_ground_truth_labels(features,dim_dict=dim_dict,
REQUIRED_FEATURES=[i for i in features.keys() if i in dim_dict])
# reorder ground truth labels according to permutation results
labels = merge_labels(per_asym_residue_index,labels,align,
original_nres=features['aatype'].shape[-1])
features.update(labels)
if (not _return_breakdown):
cum_loss = self.loss(out, features, _return_breakdown)
......
......@@ -151,7 +151,19 @@ class TestPermutation(unittest.TestCase):
tensor_to_cuda = lambda t: t.to('cuda')
batch = tensor_tree_map(tensor_to_cuda,batch)
dim_dict = AlphaFoldMultimerLoss.determine_split_dim(batch)
aligns,_ = AlphaFoldMultimerLoss.multi_chain_perm_align(out,
aligns,per_asym_residue_index = AlphaFoldMultimerLoss.multi_chain_perm_align(out,
batch,
dim_dict,
permutate_chains=True)
\ No newline at end of file
permutate_chains=True)
labels = AlphaFoldMultimerLoss.split_ground_truth_labels(batch,dim_dict=dim_dict,
REQUIRED_FEATURES=[i for i in batch.keys() if i in dim_dict])
labels = merge_labels(per_asym_residue_index,labels,aligns,
original_nres=batch['aatype'].shape[-1])
self.assertTrue(torch.equal(labels['residue_index'],batch['residue_index']))
expected_permutated_gt_pos = torch.cat((chain_a2_pos,chain_a1_pos,chain_b2_pos,chain_b3_pos,chain_b1_pos),dim=1)
expected_permutated_gt_pos = pad_features(expected_permutated_gt_pos,nres_pad,pad_dim=1)
self.assertTrue(torch.equal(labels['all_atom_positions'],expected_permutated_gt_pos))
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment