Commit b8e597b0 authored by Sachin Kadyan's avatar Sachin Kadyan
Browse files

Added test for make_msa_feat

parent cdc28a07
...@@ -11,7 +11,8 @@ import unittest ...@@ -11,7 +11,8 @@ import unittest
from data.data_transforms import make_seq_mask, add_distillation_flag, make_all_atom_aatype, fix_templates_aatype, \ from data.data_transforms import make_seq_mask, add_distillation_flag, make_all_atom_aatype, fix_templates_aatype, \
correct_msa_restypes, squeeze_features, randomly_replace_msa_with_unknown, MSA_FEATURE_NAMES, sample_msa, \ correct_msa_restypes, squeeze_features, randomly_replace_msa_with_unknown, MSA_FEATURE_NAMES, sample_msa, \
crop_extra_msa, delete_extra_msa, nearest_neighbor_clusters, make_msa_mask, make_hhblits_profile, make_masked_msa crop_extra_msa, delete_extra_msa, nearest_neighbor_clusters, make_msa_mask, make_hhblits_profile, make_masked_msa, \
make_msa_feat
from tests.config import config from tests.config import config
...@@ -205,6 +206,20 @@ class TestDataTransforms(unittest.TestCase): ...@@ -205,6 +206,20 @@ class TestDataTransforms(unittest.TestCase):
assert torch.all(torch.eq( assert torch.all(torch.eq(
protein['true_msa'] * (1-protein['bert_mask']), protein['msa'] * (1-protein['bert_mask']))) protein['true_msa'] * (1-protein['bert_mask']), protein['msa'] * (1-protein['bert_mask'])))
def test_make_msa_feat(self):
with open('../test_data/features.pkl', 'rb') as file:
features = pickle.load(file)
protein = {'between_segment_residues': torch.tensor(features['between_segment_residues']),
'msa': torch.tensor(features['msa'], dtype=torch.int64),
'deletion_matrix': torch.tensor(features['deletion_matrix_int']),
'aatype': torch.argmax(torch.tensor(features['aatype']), dim=1)}
protein = make_msa_feat.__wrapped__(protein)
assert 'msa_feat' in protein
assert 'target_feat' in protein
assert protein['target_feat'].shape == torch.Size((protein['msa'].shape[1], 22))
assert protein['msa_feat'].shape == torch.Size((*protein['msa'].shape, 25))
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment