"...AutoBuildImmortalWrt.git" did not exist on "958f8d4079dcaad8302d66e2a3de638eec076178"
test_embedders.py 3.78 KB
Newer Older
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
# Copyright 2021 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

15
import random
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
16
17
import torch
import unittest
18
19
from tests.config import consts
from tests.data_utils import random_asym_ids
20
21
from openfold.model.embedders import (
    InputEmbedder,
22
    InputEmbedderMultimer,
23
24
    RecyclingEmbedder,
    TemplateAngleEmbedder,
25
    TemplatePairEmbedder
26
)
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
27
28
29


class TestInputEmbedder(unittest.TestCase):
    """Shape test for the input embedder selected by ``consts.is_multimer``."""

    def test_shape(self):
        """Embed random features and check the MSA and pair output shapes.

        ``InputEmbedderMultimer`` is exercised when ``consts.is_multimer`` is
        set; otherwise the plain ``InputEmbedder`` is used.
        """
        # Every size below is distinct, so an output with transposed or
        # mixed-up axes cannot accidentally match the expected shape.
        tf_dim = 2
        msa_dim = 3
        c_z = 5
        c_m = 7
        relpos_k = 11  # only consumed by the non-multimer embedder

        b = 13
        n_res = 17
        n_clust = 19

        # Only consumed by the multimer embedder.
        max_relative_chain = 2
        max_relative_idx = 32
        use_chain_relative = True

        tf = torch.rand((b, n_res, tf_dim))
        ri = torch.rand((b, n_res))
        msa = torch.rand((b, n_clust, n_res, msa_dim))

        # One row of asym ids, repeated for every batch element.
        asym_ids_flat = torch.Tensor(random_asym_ids(n_res))
        asym_id = torch.tile(asym_ids_flat.unsqueeze(0), (b, 1))
        entity_id = asym_id
        sym_id = torch.zeros_like(entity_id)

        if consts.is_multimer:
            ie = InputEmbedderMultimer(
                tf_dim,
                msa_dim,
                c_z,
                c_m,
                max_relative_idx=max_relative_idx,
                use_chain_relative=use_chain_relative,
                max_relative_chain=max_relative_chain,
            )
            batch = {
                "target_feat": tf,
                "residue_index": ri,
                "msa_feat": msa,
                "asym_id": asym_id,
                "entity_id": entity_id,
                "sym_id": sym_id,
            }
            msa_emb, pair_emb = ie(batch)
        else:
            ie = InputEmbedder(tf_dim, msa_dim, c_z, c_m, relpos_k)
            msa_emb, pair_emb = ie(tf=tf, ri=ri, msa=msa, inplace_safe=False)

        # assertEqual (rather than assertTrue on an ``==`` expression) makes a
        # failure report both the actual and the expected shape.
        self.assertEqual(tuple(msa_emb.shape), (b, n_clust, n_res, c_m))
        self.assertEqual(tuple(pair_emb.shape), (b, n_res, n_res, c_z))


class TestRecyclingEmbedder(unittest.TestCase):
    """Shape test for ``RecyclingEmbedder``."""

    def test_shape(self):
        """Run the embedder on random inputs and check both output shapes.

        The outputs are expected to have the same shapes as the ``m_1`` and
        ``z`` inputs respectively.
        """
        batch_size = 2
        n = 3
        c_z = 5
        c_m = 7
        min_bin = 0
        max_bin = 10
        no_bins = 9

        re = RecyclingEmbedder(c_m, c_z, min_bin, max_bin, no_bins)

        m_1 = torch.rand((batch_size, n, c_m))
        z = torch.rand((batch_size, n, n, c_z))
        x = torch.rand((batch_size, n, 3))

        m_1_out, z_out = re(m_1, z, x)

        # assertEqual (rather than assertTrue on an ``==`` expression) makes a
        # failure report both the actual and the expected shape.
        self.assertEqual(tuple(z_out.shape), (batch_size, n, n, c_z))
        self.assertEqual(tuple(m_1_out.shape), (batch_size, n, c_m))
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
89

Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
90
91
92
93
94
95
96
97
98
99
100
101
102

class TestTemplateAngleEmbedder(unittest.TestCase):
    """Shape test for ``TemplateAngleEmbedder``."""

    def test_shape(self):
        """Check that the last axis goes from ``template_angle_dim`` to ``c_m``.

        All leading axes (batch, template, residue) are expected to be
        unchanged in the output.
        """
        template_angle_dim = 51
        c_m = 256
        batch_size = 4
        n_templ = 4
        n_res = 256

        tae = TemplateAngleEmbedder(
            template_angle_dim,
            c_m,
        )

        angles = torch.rand((batch_size, n_templ, n_res, template_angle_dim))
        embedded = tae(angles)

        # assertEqual (rather than assertTrue on an ``==`` expression) makes a
        # failure report both the actual and the expected shape.
        self.assertEqual(
            tuple(embedded.shape), (batch_size, n_templ, n_res, c_m)
        )
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
108
109
110
111
112
113
114
115
116


class TestTemplatePairEmbedder(unittest.TestCase):
    """Shape test for ``TemplatePairEmbedder``."""

    def test_shape(self):
        """Check that the last axis goes from ``template_pair_dim`` to ``c_t``.

        All leading axes (batch, template, residue, residue) are expected to
        be unchanged in the output.
        """
        batch_size = 2
        n_templ = 3
        n_res = 5
        template_pair_dim = 7
        c_t = 11

        tpe = TemplatePairEmbedder(
            template_pair_dim,
            c_t,
        )

        pair_feats = torch.rand(
            (batch_size, n_templ, n_res, n_res, template_pair_dim)
        )
        embedded = tpe(pair_feats)

        # assertEqual (rather than assertTrue on an ``==`` expression) makes a
        # failure report both the actual and the expected shape.
        self.assertEqual(
            tuple(embedded.shape), (batch_size, n_templ, n_res, n_res, c_t)
        )
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
127
128
129
130


# Run this module's test cases when the file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()