"lib/vscode:/vscode.git/clone" did not exist on "1fb31d6a7024202f45e71aff8dc0c0fbc79b20e2"
test_embedders.py 2.83 KB
Newer Older
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# Copyright 2021 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import numpy as np
import unittest
18
19
20
21
22
23
from openfold.model.embedders import (
    InputEmbedder,
    RecyclingEmbedder,
    TemplateAngleEmbedder,
    TemplatePairEmbedder,
)
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
24
25
26


class TestInputEmbedder(unittest.TestCase):
    """Shape checks for ``InputEmbedder``."""

    def test_shape(self):
        """The two outputs have the expected MSA-like and pair-like shapes."""
        # Constructor arguments, passed positionally below.
        tf_dim = 2
        msa_dim = 3
        c_z = 5
        c_m = 7
        relpos_k = 11

        # Every size here is distinct, so an output with swapped axes would
        # fail the shape comparison instead of passing by coincidence.
        b = 13
        n_res = 17
        n_clust = 19

        tf = torch.rand((b, n_res, tf_dim))
        ri = torch.rand((b, n_res))
        msa = torch.rand((b, n_clust, n_res, msa_dim))

        embedder = InputEmbedder(tf_dim, msa_dim, c_z, c_m, relpos_k)
        msa_emb, pair_emb = embedder(tf, ri, msa)

        # assertEqual reports both shapes on failure, unlike
        # assertTrue(a == b), which only says "False is not true".
        self.assertEqual(msa_emb.shape, (b, n_clust, n_res, c_m))
        self.assertEqual(pair_emb.shape, (b, n_res, n_res, c_z))


class TestRecyclingEmbedder(unittest.TestCase):
    """Shape checks for ``RecyclingEmbedder``."""

    def test_shape(self):
        """Both returned tensors keep the shapes of the ``m_1``/``z`` inputs."""
        batch_size = 2
        n = 3
        c_z = 5
        c_m = 7
        min_bin = 0
        max_bin = 10
        no_bins = 9

        # Named `embedder` rather than `re` to avoid shadowing the
        # standard-library module name inside this method.
        embedder = RecyclingEmbedder(c_m, c_z, min_bin, max_bin, no_bins)

        m_1 = torch.rand((batch_size, n, c_m))
        z = torch.rand((batch_size, n, n, c_z))
        x = torch.rand((batch_size, n, 3))

        m_1, z = embedder(m_1, z, x)

        # assertEqual reports both shapes on failure, unlike
        # assertTrue(a == b), which only says "False is not true".
        self.assertEqual(z.shape, (batch_size, n, n, c_z))
        self.assertEqual(m_1.shape, (batch_size, n, c_m))
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
69

Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
70
71
72
73
74
75
76
77
78
79
80
81
82

class TestTemplateAngleEmbedder(unittest.TestCase):
    """Shape checks for ``TemplateAngleEmbedder``."""

    def test_shape(self):
        """The last dimension goes from ``template_angle_dim`` to ``c_m``."""
        template_angle_dim = 51
        c_m = 256
        batch_size = 4
        n_templ = 4
        # NOTE(review): n_res equals c_m and batch_size equals n_templ, so
        # this shape check cannot detect those axes being swapped.
        n_res = 256

        embedder = TemplateAngleEmbedder(
            template_angle_dim,
            c_m,
        )

        x = torch.rand((batch_size, n_templ, n_res, template_angle_dim))
        x = embedder(x)

        # assertEqual reports both shapes on failure, unlike
        # assertTrue(a == b), which only says "False is not true".
        self.assertEqual(x.shape, (batch_size, n_templ, n_res, c_m))
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
88
89
90
91
92
93
94
95
96


class TestTemplatePairEmbedder(unittest.TestCase):
    """Shape checks for ``TemplatePairEmbedder``."""

    def test_shape(self):
        """The last dimension goes from ``template_pair_dim`` to ``c_t``."""
        batch_size = 2
        n_templ = 3
        n_res = 5
        template_pair_dim = 7
        c_t = 11

        embedder = TemplatePairEmbedder(
            template_pair_dim,
            c_t,
        )

        x = torch.rand((batch_size, n_templ, n_res, n_res, template_pair_dim))
        x = embedder(x)

        # assertEqual reports both shapes on failure, unlike
        # assertTrue(a == b), which only says "False is not true".
        self.assertEqual(x.shape, (batch_size, n_templ, n_res, n_res, c_t))
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
107
108
109
110


# Run every TestCase in this module when the file is executed directly.
if __name__ == "__main__":
    unittest.main()