"vscode:/vscode.git/clone" did not exist on "dba446122164aee843ac3d7d303f09ad1f29a0f6"
Commit bf8788c7 authored by Jennifer's avatar Jennifer
Browse files

debugging for permutation test

parent de12c0ea
......@@ -90,15 +90,15 @@ def get_optimal_transform(
def get_least_asym_entity_or_longest_length(batch, input_asym_id):
"""
First check how many subunit(s) one sequence has, if there is no tie, e.g. AABBB then select
one of the A as anchor
First check how many subunit(s) one sequence has. Select the subunit that is less
common, e.g. if the protein was AABBB then select one of the A as anchor
If there is a tie, e.g. AABB, first check which sequence is the longer/longest,
then choose one of the corresponding subunits as anchor
Args:
batch: in this funtion batch is the full ground truth features
input_asym_id: A list of aym_ids that are in the cropped input features
batch: in this function batch is the full ground truth features
input_asym_id: A list of asym_ids that are in the cropped input features
Return:
anchor_gt_asym_id: Tensor(int) selected ground truth asym_id
......@@ -126,7 +126,7 @@ def get_least_asym_entity_or_longest_length(batch, input_asym_id):
min_asym_count = min(entity_asym_count.values())
least_asym_entities = [entity for entity, count in entity_asym_count.items() if count == min_asym_count]
# If multiple entities have the least asym_id count, return those with the shortest length
# If multiple entities have the least asym_id count, return those with the longest length
if len(least_asym_entities) > 1:
max_length = max([entity_length[entity] for entity in least_asym_entities])
least_asym_entities = [entity for entity in least_asym_entities if entity_length[entity] == max_length]
......
......@@ -21,7 +21,7 @@ from openfold.utils.multi_chain_permutation import (pad_features, get_least_asym
merge_labels)
@unittest.skip("Tests need to be fixed post-refactor")
# @unittest.skip("Tests need to be fixed post-refactor")
class TestPermutation(unittest.TestCase):
def setUp(self):
"""
......@@ -65,10 +65,39 @@ class TestPermutation(unittest.TestCase):
'seq_length': torch.tensor([57])
}
anchor_gt_asym, anchor_pred_asym = get_least_asym_entity_or_longest_length(batch, batch['asym_id'])
self.assertIn(int(anchor_gt_asym), [1, 2])
self.assertNotIn(int(anchor_gt_asym), [3, 4, 5])
self.assertIn(int(anchor_pred_asym), [1, 2])
self.assertNotIn(int(anchor_pred_asym), [3, 4, 5])
anchor_gt_asym = int(anchor_gt_asym)
anchor_pred_asym = {int(i) for i in anchor_pred_asym}
expected_anchors = {1, 2}
expected_non_anchors = {3, 4, 5}
self.assertIn(anchor_gt_asym, expected_anchors)
self.assertNotIn(anchor_gt_asym, expected_non_anchors)
# Check that predicted anchors are within expected anchor set
self.assertEqual(anchor_pred_asym, expected_anchors & anchor_pred_asym)
self.assertEqual(set(), anchor_pred_asym & expected_non_anchors)
def test_1_selecting_anchors_with_padding(self):
# This test fails because it's looking for 0 as the
nres_pad = 325 - 57 # suppose the cropping size is 325
batch = {
'asym_id': pad_features(self.asym_id, nres_pad, pad_dim=1),
'sym_id': pad_features(self.sym_id, nres_pad, pad_dim=1),
'entity_id': pad_features(self.entity_id, nres_pad, pad_dim=1),
'aatype': torch.randint(21, size=(1, 325)),
'seq_length': torch.tensor([57])
}
anchor_gt_asym, anchor_pred_asym = get_least_asym_entity_or_longest_length(batch, batch['asym_id'])
anchor_gt_asym = int(anchor_gt_asym)
anchor_pred_asym = {int(i) for i in anchor_pred_asym}
expected_anchors = {1, 2}
expected_non_anchors = {3, 4, 5}
self.assertIn(anchor_gt_asym, expected_anchors)
self.assertNotIn(anchor_gt_asym, expected_non_anchors)
# Check that predicted anchors are within expected anchor set
self.assertEqual(anchor_pred_asym, expected_anchors & anchor_pred_asym)
self.assertEqual(set(), anchor_pred_asym & expected_non_anchors)
def test_2_permutation_pentamer(self):
batch = {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment