Commit 02bf2e83 authored by Sachin Kadyan's avatar Sachin Kadyan
Browse files

Added test for make_all_atom_aatype

parent bc7864b1
...@@ -9,7 +9,7 @@ import numpy ...@@ -9,7 +9,7 @@ import numpy
import torch import torch
import unittest import unittest
from data.data_transforms import make_seq_mask, add_distillation_flag from data.data_transforms import make_seq_mask, add_distillation_flag, make_all_atom_aatype
from openfold.config import model_config from openfold.config import model_config
...@@ -32,6 +32,17 @@ class TestDataTransforms(unittest.TestCase): ...@@ -32,6 +32,17 @@ class TestDataTransforms(unittest.TestCase):
assert 'is_distillation' in protein assert 'is_distillation' in protein
assert protein['is_distillation'] is True assert protein['is_distillation'] is True
def test_make_all_atom_aatype(self):
seq = torch.tensor([range(20)], dtype=torch.int64).transpose(0, 1)
seq_one_hot = torch.FloatTensor(seq.shape[0], 20).zero_()
seq_one_hot.scatter_(1, seq, 1)
protein_aatype = torch.tensor(seq_one_hot)
protein = {'aatype': protein_aatype}
protein = make_all_atom_aatype(protein)
print(protein)
assert 'all_atom_aatype' in protein
assert protein['all_atom_aatype'].shape == protein['aatype'].shape
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment