Commit e683dce3 authored by Sachin Kadyan's avatar Sachin Kadyan
Browse files

Added test for make_hhbits_profile

parent a3b5c162
......@@ -11,7 +11,7 @@ import unittest
from data.data_transforms import make_seq_mask, add_distillation_flag, make_all_atom_aatype, fix_templates_aatype, \
correct_msa_restypes, squeeze_features, randomly_replace_msa_with_unknown, MSA_FEATURE_NAMES, sample_msa, \
crop_extra_msa, delete_extra_msa, nearest_neighbor_clusters, make_msa_mask
crop_extra_msa, delete_extra_msa, nearest_neighbor_clusters, make_msa_mask, make_hhblits_profile
from openfold.config import model_config
......@@ -180,6 +180,15 @@ class TestDataTransforms(unittest.TestCase):
assert 'msa_row_mask' in protein
assert protein['msa_row_mask'].shape[0] == msa_mat.shape[0]
def test_make_hhblits_profile(self):
with open('../test_data/features.pkl', 'rb') as file:
features = pickle.load(file)
protein = {'msa': torch.tensor(features['msa'], dtype=torch.int64)}
protein = make_hhblits_profile(protein)
assert 'hhblits_profile' in protein
assert protein['hhblits_profile'].shape == torch.Size((protein['msa'].shape[1], 22))
if __name__ == '__main__':
unittest.main()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment