test_template.py 10.5 KB
Newer Older
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
# Copyright 2021 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Christina Floristean's avatar
Christina Floristean committed
15
import re
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
16
17
18
import torch
import numpy as np
import unittest
19
20
21
22
23
24
25
26
from openfold.model.template import (
    TemplatePointwiseAttention,
    TemplatePairStack,
)
import tests.compare_utils as compare_utils
from tests.config import consts
from tests.data_utils import random_template_feats

Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
27
# AlphaFold (plus jax/haiku) is an optional dependency used only by the
# parity tests below; import it lazily so the shape tests can run without it.
if compare_utils.alphafold_is_installed():
    alphafold = compare_utils.import_alphafold()
    import jax
    import haiku as hk
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
31
32
33


class TestTemplatePointwiseAttention(unittest.TestCase):
    """Shape tests for the OpenFold TemplatePointwiseAttention module."""

    def test_shape(self):
        """The pair-representation update must have the same shape as z."""
        batch_size = consts.batch_size
        n_seq = consts.n_seq
        c_t = consts.c_t
        c_z = consts.c_z
        # Arbitrary hidden width / head count for the attention module.
        c = 26
        no_heads = 13
        n_res = consts.n_res
        inf = 1e7

        tpa = TemplatePointwiseAttention(
            c_t, c_z, c, no_heads, inf=inf
        )

        # t: per-template pair activations; z: pair representation.
        t = torch.rand((batch_size, n_seq, n_res, n_res, c_t))
        z = torch.rand((batch_size, n_res, n_res, c_z))

        z_update = tpa(t, z, chunk_size=None)

        # assertEqual reports both shapes on failure, unlike
        # assertTrue(a == b) which only prints "False is not true".
        self.assertEqual(z_update.shape, z.shape)


class TestTemplatePairStack(unittest.TestCase):
    """Tests for TemplatePairStack: shape preservation and AlphaFold parity."""

    @classmethod
    def setUpClass(cls):
        # Bind the AlphaFold reference modules only when AlphaFold is
        # available so test_shape still runs without it.
        if compare_utils.alphafold_is_installed():
            if consts.is_multimer:
                cls.am_atom = alphafold.model.all_atom_multimer
                cls.am_fold = alphafold.model.folding_multimer
                cls.am_modules = alphafold.model.modules_multimer
                cls.am_rigid = alphafold.model.geometry
            else:
                cls.am_atom = alphafold.model.all_atom
                cls.am_fold = alphafold.model.folding
                cls.am_modules = alphafold.model.modules
                cls.am_rigid = alphafold.model.r3

    def test_shape(self):
        """Running the stack must leave the template tensor's shape unchanged."""
        batch_size = consts.batch_size
        c_t = consts.c_t
        # Arbitrary module hyperparameters for the shape check.
        c_hidden_tri_att = 7
        c_hidden_tri_mul = 7
        no_blocks = 2
        no_heads = 4
        pt_inner_dim = 15
        dropout = 0.25
        n_templ = consts.n_templ
        n_res = consts.n_res
        tri_mul_first = consts.is_multimer
        # v3 multimer models fuse the triangular-multiplication projection
        # weights.  re.fullmatch already anchors at both ends, so the
        # explicit ^...$ anchors were redundant; bool() replaces the
        # redundant "True if ... else False" ternary.
        fuse_projection_weights = bool(
            re.fullmatch("model_[1-5]_multimer_v3", consts.model)
        )
        blocks_per_ckpt = None
        chunk_size = 4
        inf = 1e7
        eps = 1e-7

        tpe = TemplatePairStack(
            c_t,
            c_hidden_tri_att=c_hidden_tri_att,
            c_hidden_tri_mul=c_hidden_tri_mul,
            no_blocks=no_blocks,
            no_heads=no_heads,
            pair_transition_n=pt_inner_dim,
            dropout_rate=dropout,
            tri_mul_first=tri_mul_first,
            fuse_projection_weights=fuse_projection_weights,
            # Pass the local instead of a hard-coded None so the setting
            # lives in one place with the other hyperparameters.
            blocks_per_ckpt=blocks_per_ckpt,
            inf=inf,
            eps=eps,
        )

        t = torch.rand((batch_size, n_templ, n_res, n_res, c_t))
        mask = torch.randint(0, 2, (batch_size, n_templ, n_res, n_res))
        shape_before = t.shape
        t = tpe(t, mask, chunk_size=chunk_size)
        shape_after = t.shape

        # assertEqual prints both shapes on failure.
        self.assertEqual(shape_before, shape_after)

    @compare_utils.skip_unless_alphafold_installed()
    def test_compare(self):
        """OpenFold's TemplatePairStack must match AlphaFold within consts.eps."""

        def run_template_pair_stack(pair_act, pair_mask):
            # Ground-truth forward pass through the AlphaFold reference
            # implementation (multimer uses a layer_stack of embedding
            # iterations; monomer exposes TemplatePairStack directly).
            config = compare_utils.get_alphafold_config()
            c_ee = config.model.embeddings_and_evoformer

            if consts.is_multimer:
                safe_key = alphafold.model.prng.SafeKey(hk.next_rng_key())
                template_iteration = self.am_modules.TemplateEmbeddingIteration(
                    c_ee.template.template_pair_stack,
                    config.model.global_config,
                    name='template_embedding_iteration')

                def template_iteration_fn(x):
                    act, safe_key = x

                    safe_key, safe_subkey = safe_key.split()
                    act = template_iteration(
                        act=act,
                        pair_mask=pair_mask,
                        is_training=False,
                        safe_key=safe_subkey)
                    return (act, safe_key)

                if config.model.global_config.use_remat:
                    template_iteration_fn = hk.remat(template_iteration_fn)

                safe_key, safe_subkey = safe_key.split()
                template_stack = alphafold.model.layer_stack.layer_stack(
                    c_ee.template.template_pair_stack.num_block)(
                    template_iteration_fn)
                act, _ = template_stack((pair_act, safe_subkey))
            else:
                tps = self.am_modules.TemplatePairStack(
                    c_ee.template.template_pair_stack,
                    config.model.global_config,
                    name="template_pair_stack",
                )
                act = tps(pair_act, pair_mask, is_training=False)

            # AlphaFold applies a final layer norm outside the stack proper.
            ln = hk.LayerNorm([-1], True, True, name="output_layer_norm")
            act = ln(act)
            return act

        f = hk.transform(run_template_pair_stack)

        n_res = consts.n_res

        pair_act = np.random.rand(n_res, n_res, consts.c_t).astype(np.float32)
        pair_mask = np.random.randint(
            low=0, high=2, size=(n_res, n_res)
        ).astype(np.float32)

        # The two model variants store the stack weights under different
        # scopes in the pretrained parameter tree.
        if consts.is_multimer:
            params = compare_utils.fetch_alphafold_module_weights(
                "alphafold/alphafold_iteration/evoformer/template_embedding/"
                + "single_template_embedding/template_embedding_iteration"
            )
        else:
            params = compare_utils.fetch_alphafold_module_weights(
                "alphafold/alphafold_iteration/evoformer/template_embedding/"
                + "single_template_embedding/template_pair_stack"
            )

        params.update(
            compare_utils.fetch_alphafold_module_weights(
                "alphafold/alphafold_iteration/evoformer/template_embedding/"
                + "single_template_embedding/output_layer_norm"
            )
        )

        out_gt = f.apply(
            params, jax.random.PRNGKey(42), pair_act, pair_mask
        ).block_until_ready()
        out_gt = torch.as_tensor(np.array(out_gt))

        model = compare_utils.get_global_pretrained_openfold()
        # unsqueeze adds the template dimension the OpenFold stack expects.
        out_repro = model.template_embedder.template_pair_stack(
            torch.as_tensor(pair_act).unsqueeze(-4).cuda(),
            torch.as_tensor(pair_mask).unsqueeze(-3).cuda(),
            chunk_size=None,
            _mask_trans=False,
        ).cpu()

        self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)

Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
196

197
class Template(unittest.TestCase):
    """Parity test: OpenFold's full template embedder vs. AlphaFold's."""

    @classmethod
    def setUpClass(cls):
        # Bind the AlphaFold reference modules only when AlphaFold is
        # installed; test_compare is skipped otherwise.
        if compare_utils.alphafold_is_installed():
            if consts.is_multimer:
                cls.am_atom = alphafold.model.all_atom_multimer
                cls.am_fold = alphafold.model.folding_multimer
                cls.am_modules = alphafold.model.modules_multimer
                cls.am_rigid = alphafold.model.geometry
            else:
                cls.am_atom = alphafold.model.all_atom
                cls.am_fold = alphafold.model.folding
                cls.am_modules = alphafold.model.modules
                cls.am_rigid = alphafold.model.r3

    @compare_utils.skip_unless_alphafold_installed()
    def test_compare(self):
        """Template embedding outputs must agree to within consts.eps."""

        def test_template_embedding(pair, batch, mask_2d, mc_mask_2d):
            # Ground-truth forward pass; multimer additionally takes a
            # multichain mask restricting attention to same-chain pairs.
            config = compare_utils.get_alphafold_config()
            te = self.am_modules.TemplateEmbedding(
                config.model.embeddings_and_evoformer.template,
                config.model.global_config,
            )

            if consts.is_multimer:
                act = te(pair, batch, mask_2d, multichain_mask_2d=mc_mask_2d, is_training=False)
            else:
                act = te(pair, batch, mask_2d, is_training=False)
            return act

        f = hk.transform(test_template_embedding)

        n_res = consts.n_res
        n_templ = consts.n_templ

        pair_act = np.random.rand(n_res, n_res, consts.c_z).astype(np.float32)
        batch = random_template_feats(n_templ, n_res)
        # AlphaFold's feature name is pluralized ("masks"); alias it.
        batch["template_all_atom_masks"] = batch["template_all_atom_mask"]

        multichain_mask_2d = None
        if consts.is_multimer:
            # Pairs belong to the same chain iff their asym_ids match;
            # asym_id is taken from the first template.
            asym_id = batch['asym_id'][0]
            multichain_mask_2d = (
                    asym_id[..., None] == asym_id[..., None, :]
            ).astype(np.float32)

        pair_mask = np.random.randint(0, 2, (n_res, n_res)).astype(np.float32)
        # Fetch pretrained parameters (but only from one block)
        params = compare_utils.fetch_alphafold_module_weights(
            "alphafold/alphafold_iteration/evoformer/template_embedding"
        )

        out_gt = f.apply(
            params, jax.random.PRNGKey(42), pair_act, batch, pair_mask, multichain_mask_2d
        ).block_until_ready()
        out_gt = torch.as_tensor(np.array(out_gt))

        # One-hot target features appended after the ground-truth run so
        # they only feed the OpenFold side.
        inds = np.random.randint(0, 21, (n_res,))
        batch["target_feat"] = np.eye(22)[inds]

        model = compare_utils.get_global_pretrained_openfold()

        template_feats = {k: torch.as_tensor(v).cuda() for k, v in batch.items()}
        if consts.is_multimer:
            out_repro_all = model.template_embedder(
                template_feats,
                torch.as_tensor(pair_act).cuda(),
                torch.as_tensor(pair_mask).cuda(),
                templ_dim=0,
                chunk_size=consts.chunk_size,
                multichain_mask_2d=torch.as_tensor(multichain_mask_2d).cuda(),
                _mask_trans=False,
                use_lma=False,
                inplace_safe=False
            )
        else:
            out_repro_all = model.template_embedder(
                template_feats,
                torch.as_tensor(pair_act).cuda(),
                torch.as_tensor(pair_mask).cuda(),
                templ_dim=0,
                chunk_size=consts.chunk_size,
                # NOTE(review): the multimer branch above passes
                # `_mask_trans` while this branch passes `mask_trans` —
                # presumably the two embedder variants expose different
                # keyword names; confirm against the openfold API.
                mask_trans=False,
                use_lma=False,
                inplace_safe=False
            )

        out_repro = out_repro_all["template_pair_embedding"]
        out_repro = out_repro.cpu()

        self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
288
289
290


# Allow running this test module directly (python test_template.py).
if __name__ == "__main__":
    unittest.main()