Commit e1c7c9e7 authored by Sachin Kadyan's avatar Sachin Kadyan
Browse files

Added test for make_atom14_masks

parent f33faf84
...@@ -612,17 +612,18 @@ def make_atom14_masks(protein): ...@@ -612,17 +612,18 @@ def make_atom14_masks(protein):
dtype=torch.float32, dtype=torch.float32,
device=protein["aatype"].device, device=protein["aatype"].device,
) )
protein_aatype = protein['aatype'].to(torch.long)
# create the mapping for (residx, atom14) --> atom37, i.e. an array # create the mapping for (residx, atom14) --> atom37, i.e. an array
# with shape (num_res, 14) containing the atom37 indices for this protein # with shape (num_res, 14) containing the atom37 indices for this protein
residx_atom14_to_atom37 = restype_atom14_to_atom37[protein["aatype"]] residx_atom14_to_atom37 = restype_atom14_to_atom37[protein_aatype]
residx_atom14_mask = restype_atom14_mask[protein["aatype"]] residx_atom14_mask = restype_atom14_mask[protein_aatype]
protein["atom14_atom_exists"] = residx_atom14_mask protein["atom14_atom_exists"] = residx_atom14_mask
protein["residx_atom14_to_atom37"] = residx_atom14_to_atom37.long() protein["residx_atom14_to_atom37"] = residx_atom14_to_atom37.long()
# create the gather indices for mapping back # create the gather indices for mapping back
residx_atom37_to_atom14 = restype_atom37_to_atom14[protein["aatype"]] residx_atom37_to_atom14 = restype_atom37_to_atom14[protein_aatype]
protein["residx_atom37_to_atom14"] = residx_atom37_to_atom14.long() protein["residx_atom37_to_atom14"] = residx_atom37_to_atom14.long()
# create the corresponding mask # create the corresponding mask
...@@ -636,7 +637,7 @@ def make_atom14_masks(protein): ...@@ -636,7 +637,7 @@ def make_atom14_masks(protein):
atom_type = rc.atom_order[atom_name] atom_type = rc.atom_order[atom_name]
restype_atom37_mask[restype, atom_type] = 1 restype_atom37_mask[restype, atom_type] = 1
residx_atom37_mask = restype_atom37_mask[protein["aatype"]] residx_atom37_mask = restype_atom37_mask[protein_aatype]
protein["atom37_atom_exists"] = residx_atom37_mask protein["atom37_atom_exists"] = residx_atom37_mask
return protein return protein
......
...@@ -12,7 +12,7 @@ import unittest ...@@ -12,7 +12,7 @@ import unittest
from data.data_transforms import make_seq_mask, add_distillation_flag, make_all_atom_aatype, fix_templates_aatype, \ from data.data_transforms import make_seq_mask, add_distillation_flag, make_all_atom_aatype, fix_templates_aatype, \
correct_msa_restypes, squeeze_features, randomly_replace_msa_with_unknown, MSA_FEATURE_NAMES, sample_msa, \ correct_msa_restypes, squeeze_features, randomly_replace_msa_with_unknown, MSA_FEATURE_NAMES, sample_msa, \
crop_extra_msa, delete_extra_msa, nearest_neighbor_clusters, make_msa_mask, make_hhblits_profile, make_masked_msa, \ crop_extra_msa, delete_extra_msa, nearest_neighbor_clusters, make_msa_mask, make_hhblits_profile, make_masked_msa, \
make_msa_feat, crop_templates make_msa_feat, crop_templates, make_atom14_masks
from tests.config import config from tests.config import config
...@@ -231,6 +231,18 @@ class TestDataTransforms(unittest.TestCase): ...@@ -231,6 +231,18 @@ class TestDataTransforms(unittest.TestCase):
assert protein['template_aatype'].shape[0] == max_templates assert protein['template_aatype'].shape[0] == max_templates
assert protein['template_all_atom_masks'].shape[0] == max_templates assert protein['template_all_atom_masks'].shape[0] == max_templates
def test_make_atom14_masks(self):
with gzip.open('../test_data/sample_feats.pickle.gz', 'rb') as file:
features = pickle.load(file)
protein = {'aatype': torch.tensor(features['aatype'][0])}
protein = make_atom14_masks(protein)
print(protein)
assert 'atom14_atom_exists' in protein
assert 'residx_atom14_to_atom37' in protein
assert 'residx_atom37_to_atom14' in protein
assert 'atom37_atom_exists' in protein
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment