"deploy/vscode:/vscode.git/clone" did not exist on "1f6ccc7f0e356e0589a381fa48b2f25931257eac"
Commit f33faf84 authored by Sachin Kadyan's avatar Sachin Kadyan
Browse files

Added test for crop_templates

parent b8e597b0
...@@ -12,7 +12,7 @@ import unittest ...@@ -12,7 +12,7 @@ import unittest
from data.data_transforms import make_seq_mask, add_distillation_flag, make_all_atom_aatype, fix_templates_aatype, \ from data.data_transforms import make_seq_mask, add_distillation_flag, make_all_atom_aatype, fix_templates_aatype, \
correct_msa_restypes, squeeze_features, randomly_replace_msa_with_unknown, MSA_FEATURE_NAMES, sample_msa, \ correct_msa_restypes, squeeze_features, randomly_replace_msa_with_unknown, MSA_FEATURE_NAMES, sample_msa, \
crop_extra_msa, delete_extra_msa, nearest_neighbor_clusters, make_msa_mask, make_hhblits_profile, make_masked_msa, \ crop_extra_msa, delete_extra_msa, nearest_neighbor_clusters, make_msa_mask, make_hhblits_profile, make_masked_msa, \
make_msa_feat make_msa_feat, crop_templates
from tests.config import config from tests.config import config
...@@ -220,6 +220,17 @@ class TestDataTransforms(unittest.TestCase): ...@@ -220,6 +220,17 @@ class TestDataTransforms(unittest.TestCase):
assert protein['target_feat'].shape == torch.Size((protein['msa'].shape[1], 22)) assert protein['target_feat'].shape == torch.Size((protein['msa'].shape[1], 22))
assert protein['msa_feat'].shape == torch.Size((*protein['msa'].shape, 25)) assert protein['msa_feat'].shape == torch.Size((*protein['msa'].shape, 25))
def test_crop_templates(self):
with gzip.open('../test_data/sample_feats.pickle.gz', 'rb') as f:
features = pickle.load(f)
protein = {'template_aatype': torch.tensor(features['true_msa'][0]),
'template_all_atom_masks': torch.tensor(features['msa_mask'][0])}
max_templates = 2
protein = crop_templates.__wrapped__(protein, max_templates)
assert protein['template_aatype'].shape[0] == max_templates
assert protein['template_all_atom_masks'].shape[0] == max_templates
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment