"...include/git@developer.sourcefind.cn:yangql/googletest.git" did not exist on "3727a9007512507e247fdab69c62e039f29ff80b"
test_template.py 10.1 KB
Newer Older
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# Copyright 2021 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import numpy as np
import unittest
from openfold.model.template import (
    TemplatePointwiseAttention,
    TemplatePairStack,
)
import tests.compare_utils as compare_utils
from tests.config import consts
from tests.data_utils import random_template_feats

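# AlphaFold, JAX, and Haiku are imported only when an AlphaFold installation is
# available; they are needed only for the comparison tests against AlphaFold below.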
if compare_utils.alphafold_is_installed():
    alphafold = compare_utils.import_alphafold()
    import jax
    import haiku as hk


class TestTemplatePointwiseAttention(unittest.TestCase):
    def test_shape(self):
        batch_size = consts.batch_size
        n_seq = consts.n_seq
        c_t = consts.c_t
        c_z = consts.c_z
        c = 26
        no_heads = 13
        n_res = consts.n_res
        inf = 1e7

        tpa = TemplatePointwiseAttention(
            c_t, c_z, c, no_heads, inf=inf
        )

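        # Dummy template (t) and pair (z) activations:
        # t: [batch, n_seq, n_res, n_res, c_t], z: [batch, n_res, n_res, c_z]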
        t = torch.rand((batch_size, n_seq, n_res, n_res, c_t))
        z = torch.rand((batch_size, n_res, n_res, c_z))

        z_update = tpa(t, z, chunk_size=None)

        self.assertTrue(z_update.shape == z.shape)


class TestTemplatePairStack(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        if consts.is_multimer:
            cls.am_atom = alphafold.model.all_atom_multimer
            cls.am_fold = alphafold.model.folding_multimer
            cls.am_modules = alphafold.model.modules_multimer
            cls.am_rigid = alphafold.model.geometry
        else:
            cls.am_atom = alphafold.model.all_atom
            cls.am_fold = alphafold.model.folding
            cls.am_modules = alphafold.model.modules
            cls.am_rigid = alphafold.model.r3

    def test_shape(self):
        batch_size = consts.batch_size
        c_t = consts.c_t
        c_hidden_tri_att = 7
        c_hidden_tri_mul = 7
        no_blocks = 2
        no_heads = 4
        pt_inner_dim = 15
        dropout = 0.25
        n_templ = consts.n_templ
        n_res = consts.n_res
        tri_mul_first = consts.is_multimer
        blocks_per_ckpt = None
        chunk_size = 4
        inf = 1e7
        eps = 1e-7

        tpe = TemplatePairStack(
            c_t,
            c_hidden_tri_att=c_hidden_tri_att,
            c_hidden_tri_mul=c_hidden_tri_mul,
            no_blocks=no_blocks,
            no_heads=no_heads,
            pair_transition_n=pt_inner_dim,
            dropout_rate=dropout,
            tri_mul_first=tri_mul_first,
            blocks_per_ckpt=blocks_per_ckpt,
            inf=inf,
            eps=eps,
        )

        t = torch.rand((batch_size, n_templ, n_res, n_res, c_t))
        mask = torch.randint(0, 2, (batch_size, n_templ, n_res, n_res))
        shape_before = t.shape
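        # A forward pass through the stack should leave the input shape unchanged.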
        t = tpe(t, mask, chunk_size=chunk_size)
        shape_after = t.shape

        self.assertTrue(shape_before == shape_after)

    @compare_utils.skip_unless_alphafold_installed()
    def test_compare(self):
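        # Reference implementation: build AlphaFold's template pair stack
        # (multimer or monomer variant) under Haiku and run it on the given activations.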
        def run_template_pair_stack(pair_act, pair_mask):
            config = compare_utils.get_alphafold_config()
            c_ee = config.model.embeddings_and_evoformer

            if consts.is_multimer:
                safe_key = alphafold.model.prng.SafeKey(hk.next_rng_key())
                template_iteration = self.am_modules.TemplateEmbeddingIteration(
                    c_ee.template.template_pair_stack,
                    config.model.global_config,
                    name='template_embedding_iteration')

                def template_iteration_fn(x):
                    act, safe_key = x

                    safe_key, safe_subkey = safe_key.split()
                    act = template_iteration(
                        act=act,
                        pair_mask=pair_mask,
                        is_training=False,
                        safe_key=safe_subkey)
                    return (act, safe_key)

                if config.model.global_config.use_remat:
                    template_iteration_fn = hk.remat(template_iteration_fn)

                safe_key, safe_subkey = safe_key.split()
                template_stack = alphafold.model.layer_stack.layer_stack(
                    c_ee.template.template_pair_stack.num_block)(
                    template_iteration_fn)
                act, _ = template_stack((pair_act, safe_subkey))
            else:
                tps = self.am_modules.TemplatePairStack(
                    c_ee.template.template_pair_stack,
                    config.model.global_config,
                    name="template_pair_stack",
                )
                act = tps(pair_act, pair_mask, is_training=False)
            ln = hk.LayerNorm([-1], True, True, name="output_layer_norm")
            act = ln(act)
            return act

        f = hk.transform(run_template_pair_stack)

        n_res = consts.n_res

        pair_act = np.random.rand(n_res, n_res, consts.c_t).astype(np.float32)
        pair_mask = np.random.randint(
            low=0, high=2, size=(n_res, n_res)
        ).astype(np.float32)

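        # Fetch the pretrained AlphaFold weights for the appropriate
        # (multimer or monomer) template pair stack submodule.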
        if consts.is_multimer:
            params = compare_utils.fetch_alphafold_module_weights(
                "alphafold/alphafold_iteration/evoformer/template_embedding/"
                + "single_template_embedding/template_embedding_iteration"
            )
        else:
            params = compare_utils.fetch_alphafold_module_weights(
                "alphafold/alphafold_iteration/evoformer/template_embedding/"
                + "single_template_embedding/template_pair_stack"
            )
        params.update(
            compare_utils.fetch_alphafold_module_weights(
                "alphafold/alphafold_iteration/evoformer/template_embedding/"
                + "single_template_embedding/output_layer_norm"
            )
        )

        out_gt = f.apply(
            params, jax.random.PRNGKey(42), pair_act, pair_mask
        ).block_until_ready()
        out_gt = torch.as_tensor(np.array(out_gt))

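        # Run the pretrained OpenFold stack on the same inputs; unsqueeze adds the
        # template dimension expected by the PyTorch implementation.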
        model = compare_utils.get_global_pretrained_openfold()
        out_repro = model.template_embedder.template_pair_stack(
            torch.as_tensor(pair_act).unsqueeze(-4).cuda(),
            torch.as_tensor(pair_mask).unsqueeze(-3).cuda(),
            chunk_size=None,
            _mask_trans=False,
        ).cpu()

        self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)


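# Compares OpenFold's full template embedder against AlphaFold's TemplateEmbedding module.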
class Template(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        if consts.is_multimer:
            cls.am_atom = alphafold.model.all_atom_multimer
            cls.am_fold = alphafold.model.folding_multimer
            cls.am_modules = alphafold.model.modules_multimer
            cls.am_rigid = alphafold.model.geometry
        else:
            cls.am_atom = alphafold.model.all_atom
            cls.am_fold = alphafold.model.folding
            cls.am_modules = alphafold.model.modules
            cls.am_rigid = alphafold.model.r3

    @compare_utils.skip_unless_alphafold_installed()
    def test_compare(self):
        def test_template_embedding(pair, batch, mask_2d, mc_mask_2d):
            config = compare_utils.get_alphafold_config()
            te = self.am_modules.TemplateEmbedding(
                config.model.embeddings_and_evoformer.template,
                config.model.global_config,
            )

            if consts.is_multimer:
                act = te(pair, batch, mask_2d, multichain_mask_2d=mc_mask_2d, is_training=False)
            else:
                act = te(pair, batch, mask_2d, is_training=False)
            return act

        f = hk.transform(test_template_embedding)

        n_res = consts.n_res
        n_templ = consts.n_templ

        pair_act = np.random.rand(n_res, n_res, consts.c_z).astype(np.float32)
        batch = random_template_feats(n_templ, n_res)
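        # AlphaFold's template features use the plural key "template_all_atom_masks";
        # mirror the generated mask under that name.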
        batch["template_all_atom_masks"] = batch["template_all_atom_mask"]

        multichain_mask_2d = None
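        # In the multimer setting, the multichain mask flags residue pairs that
        # belong to the same chain (matching asym_id).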
        if consts.is_multimer:
            asym_id = batch['asym_id'][0]
            multichain_mask_2d = (
                    asym_id[..., None] == asym_id[..., None, :]
            ).astype(np.float32)

        pair_mask = np.random.randint(0, 2, (n_res, n_res)).astype(np.float32)
        # Fetch pretrained parameters (but only from one block)
        params = compare_utils.fetch_alphafold_module_weights(
            "alphafold/alphafold_iteration/evoformer/template_embedding"
        )

        out_gt = f.apply(
            params, jax.random.PRNGKey(42), pair_act, batch, pair_mask, multichain_mask_2d
        ).block_until_ready()
        out_gt = torch.as_tensor(np.array(out_gt))

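        # Add random one-hot target features (22 classes) before running the OpenFold embedder.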
        inds = np.random.randint(0, 21, (n_res,))
        batch["target_feat"] = np.eye(22)[inds]

        model = compare_utils.get_global_pretrained_openfold()

        template_feats = {k: torch.as_tensor(v).cuda() for k, v in batch.items()}
        if consts.is_multimer:
            out_repro = model.template_embedder(
                template_feats,
                torch.as_tensor(pair_act).cuda(),
                torch.as_tensor(pair_mask).cuda(),
                templ_dim=0,
                chunk_size=consts.chunk_size,
                multichain_mask_2d=torch.as_tensor(multichain_mask_2d).cuda(),
                use_lma=False,
                inplace_safe=False
            )
        else:
            out_repro = model.template_embedder(
                template_feats,
                torch.as_tensor(pair_act).cuda(),
                torch.as_tensor(pair_mask).cuda(),
                templ_dim=0,
                chunk_size=consts.chunk_size,
                use_lma=False,
                inplace_safe=False
            )

        out_repro = out_repro["template_pair_embedding"]
        out_repro = out_repro.cpu()

        self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)


if __name__ == "__main__":
    unittest.main()