Commit 0fe5ea31 authored by Jennifer's avatar Jennifer
Browse files

Add fixes to permutation test after refactor

parents 0aa69474 bf8788c7
...@@ -90,15 +90,15 @@ def get_optimal_transform( ...@@ -90,15 +90,15 @@ def get_optimal_transform(
def get_least_asym_entity_or_longest_length(batch, input_asym_id): def get_least_asym_entity_or_longest_length(batch, input_asym_id):
""" """
First check how many subunit(s) one sequence has, if there is no tie, e.g. AABBB then select First check how many subunit(s) one sequence has. Select the subunit that is less
one of the A as anchor common, e.g. if the protein was AABBB then select one of the A as anchor
If there is a tie, e.g. AABB, first check which sequence is the longer/longest, If there is a tie, e.g. AABB, first check which sequence is the longer/longest,
then choose one of the corresponding subunits as anchor then choose one of the corresponding subunits as anchor
Args: Args:
batch: in this funtion batch is the full ground truth features batch: in this function batch is the full ground truth features
input_asym_id: A list of aym_ids that are in the cropped input features input_asym_id: A list of asym_ids that are in the cropped input features
Return: Return:
anchor_gt_asym_id: Tensor(int) selected ground truth asym_id anchor_gt_asym_id: Tensor(int) selected ground truth asym_id
...@@ -126,7 +126,7 @@ def get_least_asym_entity_or_longest_length(batch, input_asym_id): ...@@ -126,7 +126,7 @@ def get_least_asym_entity_or_longest_length(batch, input_asym_id):
min_asym_count = min(entity_asym_count.values()) min_asym_count = min(entity_asym_count.values())
least_asym_entities = [entity for entity, count in entity_asym_count.items() if count == min_asym_count] least_asym_entities = [entity for entity, count in entity_asym_count.items() if count == min_asym_count]
# If multiple entities have the least asym_id count, return those with the shortest length # If multiple entities have the least asym_id count, return those with the longest length
if len(least_asym_entities) > 1: if len(least_asym_entities) > 1:
max_length = max([entity_length[entity] for entity in least_asym_entities]) max_length = max([entity_length[entity] for entity in least_asym_entities])
least_asym_entities = [entity for entity in least_asym_entities if entity_length[entity] == max_length] least_asym_entities = [entity for entity in least_asym_entities if entity_length[entity] == max_length]
......
...@@ -21,7 +21,6 @@ from openfold.utils.multi_chain_permutation import (pad_features, get_least_asym ...@@ -21,7 +21,6 @@ from openfold.utils.multi_chain_permutation import (pad_features, get_least_asym
merge_labels) merge_labels)
@unittest.skip("Tests need to be fixed post-refactor")
class TestPermutation(unittest.TestCase): class TestPermutation(unittest.TestCase):
def setUp(self): def setUp(self):
""" """
...@@ -65,10 +64,16 @@ class TestPermutation(unittest.TestCase): ...@@ -65,10 +64,16 @@ class TestPermutation(unittest.TestCase):
'seq_length': torch.tensor([57]) 'seq_length': torch.tensor([57])
} }
anchor_gt_asym, anchor_pred_asym = get_least_asym_entity_or_longest_length(batch, batch['asym_id']) anchor_gt_asym, anchor_pred_asym = get_least_asym_entity_or_longest_length(batch, batch['asym_id'])
self.assertIn(int(anchor_gt_asym), [1, 2]) anchor_gt_asym = int(anchor_gt_asym)
self.assertNotIn(int(anchor_gt_asym), [3, 4, 5]) anchor_pred_asym = {int(i) for i in anchor_pred_asym}
self.assertIn(int(anchor_pred_asym), [1, 2]) expected_anchors = {1, 2}
self.assertNotIn(int(anchor_pred_asym), [3, 4, 5]) expected_non_anchors = {3, 4, 5}
self.assertIn(anchor_gt_asym, expected_anchors)
self.assertNotIn(anchor_gt_asym, expected_non_anchors)
# Check that predicted anchors are within expected anchor set
self.assertEqual(anchor_pred_asym, expected_anchors & anchor_pred_asym)
self.assertEqual(set(), anchor_pred_asym & expected_non_anchors)
def test_2_permutation_pentamer(self): def test_2_permutation_pentamer(self):
batch = { batch = {
...@@ -114,6 +119,7 @@ class TestPermutation(unittest.TestCase): ...@@ -114,6 +119,7 @@ class TestPermutation(unittest.TestCase):
self.assertIn(aligns, possible_outcome) self.assertIn(aligns, possible_outcome)
self.assertNotIn(aligns, wrong_outcome) self.assertNotIn(aligns, wrong_outcome)
@unittest.skip("Test needs to be fixed post-refactor")
def test_3_merge_labels(self): def test_3_merge_labels(self):
nres_pad = 325 - 57 # suppose the cropping size is 325 nres_pad = 325 - 57 # suppose the cropping size is 325
batch = { batch = {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment