Commit a9c1715b authored by Sachin Kadyan's avatar Sachin Kadyan
Browse files

Added test for fix_templates_aatype

parent 02bf2e83
......@@ -9,7 +9,7 @@ import numpy
import torch
import unittest
from data.data_transforms import make_seq_mask, add_distillation_flag, make_all_atom_aatype
from data.data_transforms import make_seq_mask, add_distillation_flag, make_all_atom_aatype, fix_templates_aatype
from openfold.config import model_config
......@@ -43,6 +43,18 @@ class TestDataTransforms(unittest.TestCase):
assert 'all_atom_aatype' in protein
assert protein['all_atom_aatype'].shape == protein['aatype'].shape
def test_fix_templates_aatype(self):
template_seq = torch.tensor(list(range(20))*2, dtype=torch.int64)
template_seq = template_seq.unsqueeze(0).transpose(0, 1)
template_seq_one_hot = torch.FloatTensor(template_seq.shape[0], 20).zero_()
template_seq_one_hot.scatter_(1, template_seq, 1)
template_aatype = torch.tensor(template_seq_one_hot).unsqueeze(0)
protein = {'template_aatype': template_aatype}
protein = fix_templates_aatype(protein)
print(protein)
template_seq_ours = torch.tensor([[0, 4, 3, 6, 13, 7, 8, 9, 11, 10, 12, 2, 14, 5, 1, 15, 16, 19, 17, 18]*2])
assert torch.all(torch.eq(protein['template_aatype'], template_seq_ours))
if __name__ == '__main__':
unittest.main()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment