# Copyright 2021 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import numpy as np
import unittest
from openfold.model.template import (
    TemplatePointwiseAttention,
    TemplatePairStack,
)
import tests.compare_utils as compare_utils
from tests.config import consts
from tests.data_utils import random_template_feats

# AlphaFold (and its JAX/Haiku stack) is only needed by the parity tests
# below; import it lazily so the shape-only tests can run without it.
if compare_utils.alphafold_is_installed():
    alphafold = compare_utils.import_alphafold()
    import jax
    import haiku as hk


class TestTemplatePointwiseAttention(unittest.TestCase):
    """Shape tests for OpenFold's TemplatePointwiseAttention module."""

    def test_shape(self):
        """A forward pass must return a pair update with the same shape as z.

        Attention over the template dimension of t should produce an
        update tensor matching the pair representation z exactly.
        """
        batch_size = consts.batch_size
        n_seq = consts.n_seq
        c_t = consts.c_t
        c_z = consts.c_z
        # Arbitrary hidden/head sizes, deliberately different from c_t/c_z
        # so that dimension mix-ups inside the module surface as errors.
        c = 26
        no_heads = 13
        n_res = consts.n_res
        inf = 1e7

        tpa = TemplatePointwiseAttention(
            c_t, c_z, c, no_heads, inf=inf
        )

        # t: [batch, n_templ, n_res, n_res, c_t]; z: [batch, n_res, n_res, c_z]
        t = torch.rand((batch_size, n_seq, n_res, n_res, c_t))
        z = torch.rand((batch_size, n_res, n_res, c_z))

        z_update = tpa(t, z, chunk_size=None)

        # assertEqual (rather than assertTrue on ==) reports both shapes
        # on failure.
        self.assertEqual(z_update.shape, z.shape)


class TestTemplatePairStack(unittest.TestCase):
    """Tests for OpenFold's TemplatePairStack: output shape and numerical
    parity against the original AlphaFold implementation."""

    @classmethod
    def setUpClass(cls):
        # The module-level `alphafold` name only exists when AlphaFold is
        # installed (see the conditional import at the top of the file).
        # Without this guard, setUpClass raises NameError and errors out
        # even tests (e.g. test_shape) that never touch AlphaFold.
        if not compare_utils.alphafold_is_installed():
            return
        if consts.is_multimer:
            cls.am_atom = alphafold.model.all_atom_multimer
            cls.am_fold = alphafold.model.folding_multimer
            cls.am_modules = alphafold.model.modules_multimer
            cls.am_rigid = alphafold.model.geometry
        else:
            cls.am_atom = alphafold.model.all_atom
            cls.am_fold = alphafold.model.folding
            cls.am_modules = alphafold.model.modules
            cls.am_rigid = alphafold.model.r3

    def test_shape(self):
        """The stack must return a tensor of the same shape as its input."""
        batch_size = consts.batch_size
        c_t = consts.c_t
        # Small, arbitrary hyperparameters — this test only checks shapes.
        c_hidden_tri_att = 7
        c_hidden_tri_mul = 7
        no_blocks = 2
        no_heads = 4
        pt_inner_dim = 15
        dropout = 0.25
        n_templ = consts.n_templ
        n_res = consts.n_res
        tri_mul_first = consts.is_multimer
        blocks_per_ckpt = None
        chunk_size = 4
        inf = 1e7
        eps = 1e-7

        tpe = TemplatePairStack(
            c_t,
            c_hidden_tri_att=c_hidden_tri_att,
            c_hidden_tri_mul=c_hidden_tri_mul,
            no_blocks=no_blocks,
            no_heads=no_heads,
            pair_transition_n=pt_inner_dim,
            dropout_rate=dropout,
            tri_mul_first=tri_mul_first,
            # Pass the local rather than a literal None so the setting is
            # defined in exactly one place.
            blocks_per_ckpt=blocks_per_ckpt,
            inf=inf,
            eps=eps,
        )

        # t: [batch, n_templ, n_res, n_res, c_t]; mask is a random 0/1 grid.
        t = torch.rand((batch_size, n_templ, n_res, n_res, c_t))
        mask = torch.randint(0, 2, (batch_size, n_templ, n_res, n_res))
        shape_before = t.shape
        t = tpe(t, mask, chunk_size=chunk_size)
        shape_after = t.shape

        self.assertEqual(shape_before, shape_after)

    @compare_utils.skip_unless_alphafold_installed()
    def test_compare(self):
        """Compare OpenFold's pair stack output to AlphaFold's, using the
        same pretrained weights, within consts.eps elementwise."""

        def run_template_pair_stack(pair_act, pair_mask):
            # Haiku-transformed reference computation built from the
            # original AlphaFold modules.
            config = compare_utils.get_alphafold_config()
            c_ee = config.model.embeddings_and_evoformer

            if consts.is_multimer:
                # Multimer has no monolithic pair stack; replicate its
                # layer_stack of TemplateEmbeddingIteration blocks.
                safe_key = alphafold.model.prng.SafeKey(hk.next_rng_key())
                template_iteration = self.am_modules.TemplateEmbeddingIteration(
                    c_ee.template.template_pair_stack,
                    config.model.global_config,
                    name='template_embedding_iteration')

                def template_iteration_fn(x):
                    act, safe_key = x

                    safe_key, safe_subkey = safe_key.split()
                    act = template_iteration(
                        act=act,
                        pair_mask=pair_mask,
                        is_training=False,
                        safe_key=safe_subkey)
                    return (act, safe_key)

                if config.model.global_config.use_remat:
                    template_iteration_fn = hk.remat(template_iteration_fn)

                safe_key, safe_subkey = safe_key.split()
                template_stack = alphafold.model.layer_stack.layer_stack(
                    c_ee.template.template_pair_stack.num_block)(
                    template_iteration_fn)
                act, _ = template_stack((pair_act, safe_subkey))
            else:
                tps = self.am_modules.TemplatePairStack(
                    c_ee.template.template_pair_stack,
                    config.model.global_config,
                    name="template_pair_stack",
                )
                act = tps(pair_act, pair_mask, is_training=False)

            # AlphaFold applies a final layer norm outside the stack proper;
            # include it so outputs are directly comparable.
            ln = hk.LayerNorm([-1], True, True, name="output_layer_norm")
            act = ln(act)
            return act

        f = hk.transform(run_template_pair_stack)

        n_res = consts.n_res

        pair_act = np.random.rand(n_res, n_res, consts.c_t).astype(np.float32)
        pair_mask = np.random.randint(
            low=0, high=2, size=(n_res, n_res)
        ).astype(np.float32)

        # Fetch the pretrained weights for whichever stack variant applies.
        if consts.is_multimer:
            params = compare_utils.fetch_alphafold_module_weights(
                "alphafold/alphafold_iteration/evoformer/template_embedding/"
                + "single_template_embedding/template_embedding_iteration"
            )
        else:
            params = compare_utils.fetch_alphafold_module_weights(
                "alphafold/alphafold_iteration/evoformer/template_embedding/"
                + "single_template_embedding/template_pair_stack"
            )

        params.update(
            compare_utils.fetch_alphafold_module_weights(
                "alphafold/alphafold_iteration/evoformer/template_embedding/"
                + "single_template_embedding/output_layer_norm"
            )
        )

        out_gt = f.apply(
            params, jax.random.PRNGKey(42), pair_act, pair_mask
        ).block_until_ready()
        out_gt = torch.as_tensor(np.array(out_gt))

        model = compare_utils.get_global_pretrained_openfold()
        # unsqueeze adds the template dimension OpenFold expects.
        out_repro = model.template_embedder.template_pair_stack(
            torch.as_tensor(pair_act).unsqueeze(-4).cuda(),
            torch.as_tensor(pair_mask).unsqueeze(-3).cuda(),
            chunk_size=None,
            _mask_trans=False,
        ).cpu()

        self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)

class Template(unittest.TestCase):
    """Numerical parity test for the full template embedder against
    AlphaFold's TemplateEmbedding, using shared pretrained weights."""

    @classmethod
    def setUpClass(cls):
        # The module-level `alphafold` name only exists when AlphaFold is
        # installed (see the conditional import at the top of the file);
        # guard so class setup does not NameError when it is absent.
        if not compare_utils.alphafold_is_installed():
            return
        if consts.is_multimer:
            cls.am_atom = alphafold.model.all_atom_multimer
            cls.am_fold = alphafold.model.folding_multimer
            cls.am_modules = alphafold.model.modules_multimer
            cls.am_rigid = alphafold.model.geometry
        else:
            cls.am_atom = alphafold.model.all_atom
            cls.am_fold = alphafold.model.folding
            cls.am_modules = alphafold.model.modules
            cls.am_rigid = alphafold.model.r3

    @compare_utils.skip_unless_alphafold_installed()
    def test_compare(self):
        """OpenFold's template embedder must match AlphaFold elementwise
        to within consts.eps on identical random inputs and weights."""

        def test_template_embedding(pair, batch, mask_2d):
            config = compare_utils.get_alphafold_config()
            te = self.am_modules.TemplateEmbedding(
                config.model.embeddings_and_evoformer.template,
                config.model.global_config,
            )
            if consts.is_multimer:
                # NOTE: multichain_mask_2d is captured from the enclosing
                # scope by late binding; it is assigned below before
                # f.apply is called (and only in the multimer branch).
                act = te(pair, batch, mask_2d, multichain_mask_2d=multichain_mask_2d, is_training=False)
            else:
                act = te(pair, batch, mask_2d, is_training=False)
            return act

        f = hk.transform(test_template_embedding)

        n_res = consts.n_res
        n_templ = consts.n_templ

        pair_act = np.random.rand(n_res, n_res, consts.c_z).astype(np.float32)
        batch = random_template_feats(n_templ, n_res)
        # AlphaFold uses the plural key; alias it from the singular one.
        batch["template_all_atom_masks"] = batch["template_all_atom_mask"]

        if consts.is_multimer:
            # Pairwise same-chain mask derived from the asym_id of the
            # first template.
            asym_id = batch['asym_id'][0]
            multichain_mask_2d = (
                    asym_id[..., None] == asym_id[..., None, :]
            ).astype(np.float32)
            batch["multichain_mask_2d"] = multichain_mask_2d

        pair_mask = np.random.randint(0, 2, (n_res, n_res)).astype(np.float32)
        # Fetch pretrained parameters (but only from one block)
        params = compare_utils.fetch_alphafold_module_weights(
            "alphafold/alphafold_iteration/evoformer/template_embedding"
        )

        out_gt = f.apply(
            params, jax.random.PRNGKey(42), pair_act, batch, pair_mask
        ).block_until_ready()
        out_gt = torch.as_tensor(np.array(out_gt))

        # OpenFold additionally expects a one-hot target_feat (22 classes).
        inds = np.random.randint(0, 21, (n_res,))
        batch["target_feat"] = np.eye(22)[inds]

        model = compare_utils.get_global_pretrained_openfold()

        template_feats = {k: torch.as_tensor(v).cuda() for k, v in batch.items()}
        if consts.is_multimer:
            out_repro = model.template_embedder(
                template_feats,
                torch.as_tensor(pair_act).cuda(),
                torch.as_tensor(pair_mask).cuda(),
                templ_dim=0,
                chunk_size=consts.chunk_size,
                multichain_mask_2d=multichain_mask_2d,
            )
        else:
            out_repro = model.template_embedder(
                template_feats,
                torch.as_tensor(pair_act).cuda(),
                torch.as_tensor(pair_mask).cuda(),
                templ_dim=0,
                chunk_size=consts.chunk_size
            )

        out_repro = out_repro["template_pair_embedding"]
        out_repro = out_repro.cpu()

        # BUG FIX: the comparison was previously written as
        # torch.max(torch.abs(diff) < eps), i.e. the max of a *boolean*
        # tensor, which is truthy if ANY single element is within eps.
        # The max-abs-difference must be computed first, then compared,
        # matching the sibling test in TestTemplatePairStack.
        self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)


# Allow running this test module directly (python test_template.py).
if __name__ == "__main__":
    unittest.main()