Commit 10fc9ec0 authored by Sachin Kadyan's avatar Sachin Kadyan
Browse files

Added test for correct_msa_restypes

parent a9c1715b
...@@ -9,7 +9,8 @@ import numpy ...@@ -9,7 +9,8 @@ import numpy
import torch import torch
import unittest import unittest
from data.data_transforms import make_seq_mask, add_distillation_flag, make_all_atom_aatype, fix_templates_aatype from data.data_transforms import make_seq_mask, add_distillation_flag, make_all_atom_aatype, fix_templates_aatype, \
correct_msa_restypes
from openfold.config import model_config from openfold.config import model_config
...@@ -55,6 +56,15 @@ class TestDataTransforms(unittest.TestCase): ...@@ -55,6 +56,15 @@ class TestDataTransforms(unittest.TestCase):
template_seq_ours = torch.tensor([[0, 4, 3, 6, 13, 7, 8, 9, 11, 10, 12, 2, 14, 5, 1, 15, 16, 19, 17, 18]*2]) template_seq_ours = torch.tensor([[0, 4, 3, 6, 13, 7, 8, 9, 11, 10, 12, 2, 14, 5, 1, 15, 16, 19, 17, 18]*2])
assert torch.all(torch.eq(protein['template_aatype'], template_seq_ours)) assert torch.all(torch.eq(protein['template_aatype'], template_seq_ours))
def test_correct_msa_restypes(self):
with open('../features.pkl', 'rb') as file:
features = pickle.load(file)
protein = {'msa': torch.tensor(features['msa'], dtype=torch.int64)}
protein = correct_msa_restypes(protein)
print(protein)
assert torch.all(torch.eq(torch.tensor(features['msa'].shape), torch.tensor(protein['msa'].shape)))
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment