Commit 974fe5a9 authored by Sachin Kadyan's avatar Sachin Kadyan
Browse files

Added test for nearest_neighbor_clusters

parent 734ce915
...@@ -11,7 +11,7 @@ import unittest ...@@ -11,7 +11,7 @@ import unittest
from data.data_transforms import make_seq_mask, add_distillation_flag, make_all_atom_aatype, fix_templates_aatype, \ from data.data_transforms import make_seq_mask, add_distillation_flag, make_all_atom_aatype, fix_templates_aatype, \
correct_msa_restypes, squeeze_features, randomly_replace_msa_with_unknown, MSA_FEATURE_NAMES, sample_msa, \ correct_msa_restypes, squeeze_features, randomly_replace_msa_with_unknown, MSA_FEATURE_NAMES, sample_msa, \
crop_extra_msa, delete_extra_msa crop_extra_msa, delete_extra_msa, nearest_neighbor_clusters
from openfold.config import model_config from openfold.config import model_config
...@@ -157,6 +157,18 @@ class TestDataTransforms(unittest.TestCase): ...@@ -157,6 +157,18 @@ class TestDataTransforms(unittest.TestCase):
assert 'extra_' + k not in protein assert 'extra_' + k not in protein
assert 'extra_msa' not in protein assert 'extra_msa' not in protein
def test_nearest_neighbor_clusters(self):
with gzip.open('../test_data/sample_feats.pickle.gz', 'rb') as f:
features = pickle.load(f)
protein = {'msa': torch.tensor(features['true_msa'][0], dtype=torch.int64),
'msa_mask': torch.tensor(features['msa_mask'][0], dtype=torch.int64),
'extra_msa': torch.tensor(features['extra_msa'][0], dtype=torch.int64),
'extra_msa_mask': torch.tensor(features['extra_msa_mask'][0], dtype=torch.int64)}
protein = nearest_neighbor_clusters.__wrapped__(protein, 0)
print(protein)
assert 'extra_cluster_assignment' in protein
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment