test_data_transforms.py 2.26 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
import copy
import gzip

import os

import pickle

import numpy
import torch
import unittest

12
from data.data_transforms import make_seq_mask, add_distillation_flag, make_all_atom_aatype, fix_templates_aatype
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
from openfold.config import model_config


class TestDataTransforms(unittest.TestCase):
    def test_make_seq_mask(self):
        """Check that ``make_seq_mask`` adds a ``seq_mask`` entry of shape (20, 20).

        The input ``aatype`` is a 20x20 float32 one-hot matrix in which row
        ``i`` has its single 1 in column ``i``.
        """
        num_classes = 20
        # Column vector of the indices 0..19; used as the scatter targets.
        seq = torch.tensor([range(20)], dtype=torch.int64).transpose(0, 1)
        seq_one_hot = torch.zeros(
            seq.shape[0], num_classes, dtype=torch.float32
        )
        seq_one_hot.scatter_(1, seq, 1)
        # clone().detach() is the documented way to copy a tensor; calling
        # torch.tensor() on an existing tensor emits a UserWarning.
        protein = {'aatype': seq_one_hot.clone().detach()}

        protein = make_seq_mask(protein)

        # unittest assertions are not stripped under ``python -O`` and give
        # informative failure messages, unlike bare ``assert`` statements.
        self.assertIn('seq_mask', protein)
        self.assertEqual(
            protein['seq_mask'].shape,
            torch.Size((seq.shape[0], num_classes)),
        )

28
29
30
31
32
33
34
    def test_add_distillation_flag(self):
        """Check that ``add_distillation_flag`` stores the flag it is given.

        Starting from an empty dict and passing ``True``, the result must
        contain ``'is_distillation'`` bound to the ``True`` singleton.
        """
        # NOTE(review): ``__wrapped__`` presumably reaches the undecorated
        # function so the dict and the flag can be passed in a single call;
        # confirm against the decorator used in data.data_transforms.
        protein = add_distillation_flag.__wrapped__({}, True)

        # unittest assertions are not stripped under ``python -O`` and give
        # informative failure messages, unlike bare ``assert`` statements.
        self.assertIn('is_distillation', protein)
        self.assertIs(protein['is_distillation'], True)

35
36
37
38
39
40
41
42
43
44
45
    def test_make_all_atom_aatype(self):
        """Check that ``make_all_atom_aatype`` adds a matching ``all_atom_aatype``.

        The input ``aatype`` is a 20x20 float32 one-hot matrix in which row
        ``i`` has its single 1 in column ``i``; the new ``all_atom_aatype``
        entry must have the same shape as ``aatype`` in the returned dict.
        """
        num_classes = 20
        # Column vector of the indices 0..19; used as the scatter targets.
        seq = torch.tensor([range(20)], dtype=torch.int64).transpose(0, 1)
        seq_one_hot = torch.zeros(
            seq.shape[0], num_classes, dtype=torch.float32
        )
        seq_one_hot.scatter_(1, seq, 1)
        # clone().detach() is the documented way to copy a tensor; calling
        # torch.tensor() on an existing tensor emits a UserWarning.
        protein = {'aatype': seq_one_hot.clone().detach()}

        protein = make_all_atom_aatype(protein)

        # unittest assertions are not stripped under ``python -O`` and give
        # informative failure messages, unlike bare ``assert`` statements.
        self.assertIn('all_atom_aatype', protein)
        self.assertEqual(
            protein['all_atom_aatype'].shape, protein['aatype'].shape
        )

46
47
48
49
50
51
52
53
54
55
56
57
    def test_fix_templates_aatype(self):
        """Check the index values ``fix_templates_aatype`` produces.

        The input ``template_aatype`` is a (1, 40, 20) float32 one-hot tensor
        encoding the indices 0..19 twice in a row. After the call,
        ``template_aatype`` must equal the hard-coded expected index row.
        """
        num_classes = 20
        template_seq = torch.tensor(list(range(20)) * 2, dtype=torch.int64)
        # Reshape (40,) -> (1, 40) -> (40, 1) so it can be used as a scatter index.
        template_seq = template_seq.unsqueeze(0).transpose(0, 1)
        template_seq_one_hot = torch.zeros(
            template_seq.shape[0], num_classes, dtype=torch.float32
        )
        template_seq_one_hot.scatter_(1, template_seq, 1)
        # clone().detach() is the documented way to copy a tensor; calling
        # torch.tensor() on an existing tensor emits a UserWarning.
        template_aatype = template_seq_one_hot.clone().detach().unsqueeze(0)
        protein = {'template_aatype': template_aatype}

        protein = fix_templates_aatype(protein)

        # Expected output index for each input index 0..19, repeated twice.
        # NOTE(review): presumably a remapping between two residue orderings;
        # confirm against data.data_transforms.
        expected = torch.tensor(
            [[0, 4, 3, 6, 13, 7, 8, 9, 11, 10, 12, 2, 14, 5, 1, 15, 16, 19, 17, 18] * 2]
        )
        # unittest assertions are not stripped under ``python -O``, unlike
        # bare ``assert`` statements.
        self.assertTrue(
            torch.all(torch.eq(protein['template_aatype'], expected))
        )

58
59
60
61

# Run this module's test cases when the file is executed directly as a script.
if __name__ == '__main__':
    unittest.main()