"vscode:/vscode.git/clone" did not exist on "3fdecc873991c7696a55eb518225b2bd85cbaac2"
Commit 02bf2e83 authored by Sachin Kadyan's avatar Sachin Kadyan
Browse files

Added test for make_all_atom_aatype

parent bc7864b1
......@@ -9,7 +9,7 @@ import numpy
import torch
import unittest
from data.data_transforms import make_seq_mask, add_distillation_flag
from data.data_transforms import make_seq_mask, add_distillation_flag, make_all_atom_aatype
from openfold.config import model_config
......@@ -32,6 +32,17 @@ class TestDataTransforms(unittest.TestCase):
assert 'is_distillation' in protein
assert protein['is_distillation'] is True
def test_make_all_atom_aatype(self):
seq = torch.tensor([range(20)], dtype=torch.int64).transpose(0, 1)
seq_one_hot = torch.FloatTensor(seq.shape[0], 20).zero_()
seq_one_hot.scatter_(1, seq, 1)
protein_aatype = torch.tensor(seq_one_hot)
protein = {'aatype': protein_aatype}
protein = make_all_atom_aatype(protein)
print(protein)
assert 'all_atom_aatype' in protein
assert protein['all_atom_aatype'].shape == protein['aatype'].shape
if __name__ == '__main__':
unittest.main()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment