Commit 734ce915 authored by Sachin Kadyan's avatar Sachin Kadyan
Browse files

Added test for delete_extra_msa

parent 09564595
...@@ -11,7 +11,7 @@ import unittest ...@@ -11,7 +11,7 @@ import unittest
from data.data_transforms import make_seq_mask, add_distillation_flag, make_all_atom_aatype, fix_templates_aatype, \ from data.data_transforms import make_seq_mask, add_distillation_flag, make_all_atom_aatype, fix_templates_aatype, \
correct_msa_restypes, squeeze_features, randomly_replace_msa_with_unknown, MSA_FEATURE_NAMES, sample_msa, \ correct_msa_restypes, squeeze_features, randomly_replace_msa_with_unknown, MSA_FEATURE_NAMES, sample_msa, \
crop_extra_msa crop_extra_msa, delete_extra_msa
from openfold.config import model_config from openfold.config import model_config
...@@ -146,6 +146,17 @@ class TestDataTransforms(unittest.TestCase): ...@@ -146,6 +146,17 @@ class TestDataTransforms(unittest.TestCase):
if "extra_" + k in protein: if "extra_" + k in protein:
assert protein["extra_" + k].shape[0] == min(max_extra_msa, num_seq) assert protein["extra_" + k].shape[0] == min(max_extra_msa, num_seq)
def test_delete_extra_msa(self):
protein = {'extra_msa': torch.rand((512, 100, 23))}
extra_msa_has_deletion_shape = list(protein['extra_msa'].shape)
extra_msa_has_deletion_shape[2] = 1
protein['extra_deletion_matrix'] = torch.rand(extra_msa_has_deletion_shape)
protein = delete_extra_msa(protein)
print(protein)
for k in MSA_FEATURE_NAMES:
assert 'extra_' + k not in protein
assert 'extra_msa' not in protein
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment