test_template.py 10.2 KB
Newer Older
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
# Copyright 2021 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Christina Floristean's avatar
Christina Floristean committed
15
import re
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
16
17
18
import torch
import numpy as np
import unittest
19
20
21
22
23
24
25
26
from openfold.model.template import (
    TemplatePointwiseAttention,
    TemplatePairStack,
)
import tests.compare_utils as compare_utils
from tests.config import consts
from tests.data_utils import random_template_feats

Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
27
if compare_utils.alphafold_is_installed():
28
29
30
    alphafold = compare_utils.import_alphafold()
    import jax
    import haiku as hk
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
31
32
33


class TestTemplatePointwiseAttention(unittest.TestCase):
    def test_shape(self):
        """Smoke-test TemplatePointwiseAttention: the pair-representation
        update it produces must have exactly the same shape as the input
        pair representation ``z``.
        """
        # Hidden channel / head counts are arbitrary values for the shape test.
        c_hidden = 26
        n_heads = 13
        big_negative_guard = 1e7  # `inf` value used to mask attention logits

        attn = TemplatePointwiseAttention(
            consts.c_t, consts.c_z, c_hidden, n_heads, inf=big_negative_guard
        )

        # [*, N_templ, N_res, N_res, C_t] template stack and
        # [*, N_res, N_res, C_z] pair representation.
        template_embedding = torch.rand(
            (consts.batch_size, consts.n_seq, consts.n_res, consts.n_res, consts.c_t)
        )
        pair_embedding = torch.rand(
            (consts.batch_size, consts.n_res, consts.n_res, consts.c_z)
        )

        pair_update = attn(template_embedding, pair_embedding, chunk_size=None)

        self.assertTrue(pair_update.shape == pair_embedding.shape)


class TestTemplatePairStack(unittest.TestCase):
    """Shape and AlphaFold-parity tests for ``TemplatePairStack``."""

    @classmethod
    def setUpClass(cls):
        # Guard: the `alphafold` global only exists when the package is
        # installed (see the conditional import at module top). Without this
        # check, setUpClass raises NameError and also breaks test_shape,
        # which does not need AlphaFold at all.
        if not compare_utils.alphafold_is_installed():
            return
        if consts.is_multimer:
            cls.am_atom = alphafold.model.all_atom_multimer
            cls.am_fold = alphafold.model.folding_multimer
            cls.am_modules = alphafold.model.modules_multimer
            cls.am_rigid = alphafold.model.geometry
        else:
            cls.am_atom = alphafold.model.all_atom
            cls.am_fold = alphafold.model.folding
            cls.am_modules = alphafold.model.modules
            cls.am_rigid = alphafold.model.r3

    def test_shape(self):
        """The stack must be shape-preserving on its template embedding."""
        batch_size = consts.batch_size
        c_t = consts.c_t
        # Arbitrary hidden sizes / block counts for the shape test.
        c_hidden_tri_att = 7
        c_hidden_tri_mul = 7
        no_blocks = 2
        no_heads = 4
        pt_inner_dim = 15
        dropout = 0.25
        n_templ = consts.n_templ
        n_res = consts.n_res
        # Multimer models run triangle multiplication before attention.
        tri_mul_first = consts.is_multimer
        # Only the v3 multimer models fuse the triangle-mult projections.
        fuse_projection_weights = (
            re.fullmatch("^model_[1-5]_multimer_v3$", consts.model) is not None
        )
        blocks_per_ckpt = None
        chunk_size = 4
        inf = 1e7
        eps = 1e-7

        tpe = TemplatePairStack(
            c_t,
            c_hidden_tri_att=c_hidden_tri_att,
            c_hidden_tri_mul=c_hidden_tri_mul,
            no_blocks=no_blocks,
            no_heads=no_heads,
            pair_transition_n=pt_inner_dim,
            dropout_rate=dropout,
            tri_mul_first=tri_mul_first,
            fuse_projection_weights=fuse_projection_weights,
            # Fix: use the local variable instead of a hard-coded None so the
            # two stay in sync (value is unchanged).
            blocks_per_ckpt=blocks_per_ckpt,
            inf=inf,
            eps=eps,
        )

        t = torch.rand((batch_size, n_templ, n_res, n_res, c_t))
        mask = torch.randint(0, 2, (batch_size, n_templ, n_res, n_res))
        shape_before = t.shape
        t = tpe(t, mask, chunk_size=chunk_size)
        shape_after = t.shape

        self.assertTrue(shape_before == shape_after)

    @compare_utils.skip_unless_alphafold_installed()
    def test_compare(self):
        """Compare the OpenFold template pair stack against the reference
        AlphaFold implementation using pretrained weights; outputs must agree
        to within ``consts.eps``.
        """

        def run_template_pair_stack(pair_act, pair_mask):
            # Haiku-side reference computation.
            config = compare_utils.get_alphafold_config()
            c_ee = config.model.embeddings_and_evoformer

            if consts.is_multimer:
                # Multimer wraps a single iteration in a layer_stack.
                safe_key = alphafold.model.prng.SafeKey(hk.next_rng_key())
                template_iteration = self.am_modules.TemplateEmbeddingIteration(
                    c_ee.template.template_pair_stack,
                    config.model.global_config,
                    name='template_embedding_iteration')

                def template_iteration_fn(x):
                    act, safe_key = x

                    safe_key, safe_subkey = safe_key.split()
                    act = template_iteration(
                        act=act,
                        pair_mask=pair_mask,
                        is_training=False,
                        safe_key=safe_subkey)
                    return (act, safe_key)

                if config.model.global_config.use_remat:
                    template_iteration_fn = hk.remat(template_iteration_fn)

                safe_key, safe_subkey = safe_key.split()
                template_stack = alphafold.model.layer_stack.layer_stack(
                    c_ee.template.template_pair_stack.num_block)(
                    template_iteration_fn)
                act, _ = template_stack((pair_act, safe_subkey))
            else:
                tps = self.am_modules.TemplatePairStack(
                    c_ee.template.template_pair_stack,
                    config.model.global_config,
                    name="template_pair_stack",
                )
                act = tps(pair_act, pair_mask, is_training=False)

            # OpenFold's stack ends with a layer norm, so apply the
            # corresponding AlphaFold output layer norm for parity.
            ln = hk.LayerNorm([-1], True, True, name="output_layer_norm")
            act = ln(act)
            return act

        f = hk.transform(run_template_pair_stack)

        n_res = consts.n_res

        pair_act = np.random.rand(n_res, n_res, consts.c_t).astype(np.float32)
        pair_mask = np.random.randint(
            low=0, high=2, size=(n_res, n_res)
        ).astype(np.float32)

        # Weight-scope names differ between monomer and multimer checkpoints.
        if consts.is_multimer:
            params = compare_utils.fetch_alphafold_module_weights(
                "alphafold/alphafold_iteration/evoformer/template_embedding/"
                + "single_template_embedding/template_embedding_iteration"
            )
        else:
            params = compare_utils.fetch_alphafold_module_weights(
                "alphafold/alphafold_iteration/evoformer/template_embedding/"
                + "single_template_embedding/template_pair_stack"
            )

        params.update(
            compare_utils.fetch_alphafold_module_weights(
                "alphafold/alphafold_iteration/evoformer/template_embedding/"
                + "single_template_embedding/output_layer_norm"
            )
        )

        out_gt = f.apply(
            params, jax.random.PRNGKey(42), pair_act, pair_mask
        ).block_until_ready()
        out_gt = torch.as_tensor(np.array(out_gt))

        model = compare_utils.get_global_pretrained_openfold()
        # unsqueeze adds the template dimension expected by OpenFold.
        out_repro = model.template_embedder.template_pair_stack(
            torch.as_tensor(pair_act).unsqueeze(-4).cuda(),
            torch.as_tensor(pair_mask).unsqueeze(-3).cuda(),
            chunk_size=None,
            _mask_trans=False,
        ).cpu()

        self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)

Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
195

196
class Template(unittest.TestCase):
    """AlphaFold-parity test for the full template embedder."""

    @classmethod
    def setUpClass(cls):
        # Guard: the `alphafold` global only exists when the package is
        # installed (see the conditional import at module top). Without this
        # check, setUpClass raises NameError before the skip decorator on
        # test_compare can take effect.
        if not compare_utils.alphafold_is_installed():
            return
        if consts.is_multimer:
            cls.am_atom = alphafold.model.all_atom_multimer
            cls.am_fold = alphafold.model.folding_multimer
            cls.am_modules = alphafold.model.modules_multimer
            cls.am_rigid = alphafold.model.geometry
        else:
            cls.am_atom = alphafold.model.all_atom
            cls.am_fold = alphafold.model.folding
            cls.am_modules = alphafold.model.modules
            cls.am_rigid = alphafold.model.r3

    @compare_utils.skip_unless_alphafold_installed()
    def test_compare(self):
        """Compare OpenFold's template embedder against the reference
        AlphaFold TemplateEmbedding on random features with pretrained
        weights; outputs must agree to within ``consts.eps``.
        """

        def test_template_embedding(pair, batch, mask_2d, mc_mask_2d):
            config = compare_utils.get_alphafold_config()
            te = self.am_modules.TemplateEmbedding(
                config.model.embeddings_and_evoformer.template,
                config.model.global_config,
            )

            # Multimer additionally takes a per-chain mask.
            if consts.is_multimer:
                act = te(pair, batch, mask_2d, multichain_mask_2d=mc_mask_2d, is_training=False)
            else:
                act = te(pair, batch, mask_2d, is_training=False)
            return act

        f = hk.transform(test_template_embedding)

        n_res = consts.n_res
        n_templ = consts.n_templ

        pair_act = np.random.rand(n_res, n_res, consts.c_z).astype(np.float32)
        batch = random_template_feats(n_templ, n_res)
        # AlphaFold uses the plural key name; alias the OpenFold feature.
        batch["template_all_atom_masks"] = batch["template_all_atom_mask"]

        multichain_mask_2d = None
        if consts.is_multimer:
            # 1 where residue pairs belong to the same chain, else 0.
            asym_id = batch['asym_id'][0]
            multichain_mask_2d = (
                    asym_id[..., None] == asym_id[..., None, :]
            ).astype(np.float32)

        pair_mask = np.random.randint(0, 2, (n_res, n_res)).astype(np.float32)
        # Fetch pretrained parameters (but only from one block)
        params = compare_utils.fetch_alphafold_module_weights(
            "alphafold/alphafold_iteration/evoformer/template_embedding"
        )

        out_gt = f.apply(
            params, jax.random.PRNGKey(42), pair_act, batch, pair_mask, multichain_mask_2d
        ).block_until_ready()
        out_gt = torch.as_tensor(np.array(out_gt))

        # target_feat is only needed by the OpenFold side; add it after the
        # reference run so both sides see the same `batch` contents above.
        inds = np.random.randint(0, 21, (n_res,))
        batch["target_feat"] = np.eye(22)[inds]

        model = compare_utils.get_global_pretrained_openfold()

        template_feats = {k: torch.as_tensor(v).cuda() for k, v in batch.items()}
        if consts.is_multimer:
            out_repro = model.template_embedder(
                template_feats,
                torch.as_tensor(pair_act).cuda(),
                torch.as_tensor(pair_mask).cuda(),
                templ_dim=0,
                chunk_size=consts.chunk_size,
                multichain_mask_2d=torch.as_tensor(multichain_mask_2d).cuda(),
                use_lma=False,
                inplace_safe=False
            )
        else:
            out_repro = model.template_embedder(
                template_feats,
                torch.as_tensor(pair_act).cuda(),
                torch.as_tensor(pair_mask).cuda(),
                templ_dim=0,
                chunk_size=consts.chunk_size,
                use_lma=False,
                inplace_safe=False
            )

        out_repro = out_repro["template_pair_embedding"]
        out_repro = out_repro.cpu()

        # Fix: the comparison must happen OUTSIDE torch.max. The original
        # `torch.max(torch.abs(diff) < eps)` took the max of a boolean
        # tensor, passing if ANY element was within tolerance; the intended
        # check (cf. TestTemplatePairStack.test_compare) is that the maximum
        # absolute deviation is within tolerance.
        self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
284
285
286


# Allow running this test module directly (e.g. `python test_template.py`).
if __name__ == "__main__":
    unittest.main()