test_evoformer.py 11.1 KB
Newer Older
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
# Copyright 2021 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Christina Floristean's avatar
Christina Floristean committed
15
import re
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
16
17
18
import torch
import numpy as np
import unittest
19
20
21
22
23
24
25
26
27
from openfold.model.evoformer import (
    MSATransition,
    EvoformerStack,
    ExtraMSAStack,
)
from openfold.utils.tensor_utils import tree_map
import tests.compare_utils as compare_utils
from tests.config import consts

Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
28
# Only pull in the AlphaFold/JAX stack when it is actually available;
# the comparison tests below are skipped otherwise.
if compare_utils.alphafold_is_installed():
    alphafold = compare_utils.import_alphafold()
    import jax
    import haiku as hk
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
32
33
34


class TestEvoformerStack(unittest.TestCase):
    """Shape and ground-truth-comparison tests for the EvoformerStack."""

    def test_shape(self):
        """A forward pass preserves the MSA and pair shapes and emits a
        (batch, n_res, c_s) single representation."""
        batch_size = consts.batch_size
        n_seq = consts.n_seq
        n_res = consts.n_res
        c_m = consts.c_m
        c_z = consts.c_z
        # Arbitrary, mutually distinct hidden dims to catch accidentally
        # hard-coded sizes inside the stack.
        c_hidden_msa_att = 12
        c_hidden_opm = 17
        c_hidden_mul = 19
        c_hidden_pair_att = 14
        c_s = consts.c_s
        no_heads_msa = 3
        no_heads_pair = 7
        no_blocks = 2
        transition_n = 2
        msa_dropout = 0.15
        pair_stack_dropout = 0.25
        opm_first = consts.is_multimer
        # Fused triangular-multiplication projection weights are only used
        # by the multimer v3 models. (fullmatch already anchors the pattern,
        # so no explicit ^...$ is needed.)
        fuse_projection_weights = (
            re.fullmatch("model_[1-5]_multimer_v3", consts.model) is not None
        )
        inf = 1e9
        eps = 1e-10

        es = EvoformerStack(
            c_m,
            c_z,
            c_hidden_msa_att,
            c_hidden_opm,
            c_hidden_mul,
            c_hidden_pair_att,
            c_s,
            no_heads_msa,
            no_heads_pair,
            no_blocks,
            transition_n,
            msa_dropout,
            pair_stack_dropout,
            no_column_attention=False,
            opm_first=opm_first,
            fuse_projection_weights=fuse_projection_weights,
            blocks_per_ckpt=None,
            inf=inf,
            eps=eps,
        ).eval()

        m = torch.rand((batch_size, n_seq, n_res, c_m))
        z = torch.rand((batch_size, n_res, n_res, c_z))
        msa_mask = torch.randint(0, 2, size=(batch_size, n_seq, n_res))
        pair_mask = torch.randint(0, 2, size=(batch_size, n_res, n_res))

        shape_m_before = m.shape
        shape_z_before = z.shape

        m, z, s = es(
            m, z, chunk_size=4, msa_mask=msa_mask, pair_mask=pair_mask
        )

        self.assertTrue(m.shape == shape_m_before)
        self.assertTrue(z.shape == shape_z_before)
        self.assertTrue(s.shape == (batch_size, n_res, c_s))

    def test_shape_without_column_attention(self):
        """Same shape checks with column attention disabled
        (no_column_attention=True)."""
        batch_size = consts.batch_size
        n_seq = consts.n_seq
        n_res = consts.n_res
        c_m = consts.c_m
        c_z = consts.c_z
        c_hidden_msa_att = 12
        c_hidden_opm = 17
        c_hidden_mul = 19
        c_hidden_pair_att = 14
        c_s = consts.c_s
        no_heads_msa = 3
        no_heads_pair = 7
        no_blocks = 2
        transition_n = 2
        msa_dropout = 0.15
        pair_stack_dropout = 0.25
        inf = 1e9
        eps = 1e-10

        es = EvoformerStack(
            c_m,
            c_z,
            c_hidden_msa_att,
            c_hidden_opm,
            c_hidden_mul,
            c_hidden_pair_att,
            c_s,
            no_heads_msa,
            no_heads_pair,
            no_blocks,
            transition_n,
            msa_dropout,
            pair_stack_dropout,
            no_column_attention=True,
            opm_first=False,
            fuse_projection_weights=False,
            blocks_per_ckpt=None,
            inf=inf,
            eps=eps,
        ).eval()

        m_init = torch.rand((batch_size, n_seq, n_res, c_m))
        z_init = torch.rand((batch_size, n_res, n_res, c_z))
        msa_mask = torch.randint(0, 2, size=(batch_size, n_seq, n_res))
        pair_mask = torch.randint(0, 2, size=(batch_size, n_res, n_res))

        shape_m_before = m_init.shape
        shape_z_before = z_init.shape

        m, z, s = es(
            m_init, z_init, chunk_size=4, msa_mask=msa_mask, pair_mask=pair_mask
        )

        self.assertTrue(m.shape == shape_m_before)
        self.assertTrue(z.shape == shape_z_before)
        self.assertTrue(s.shape == (batch_size, n_res, c_s))

    @compare_utils.skip_unless_alphafold_installed()
    def test_compare(self):
        """One OpenFold Evoformer block reproduces AlphaFold's
        EvoformerIteration outputs within consts.eps, both in the regular
        and the inplace_safe code paths."""

        def run_ei(activations, masks):
            # Build AlphaFold's reference EvoformerIteration under haiku.
            config = compare_utils.get_alphafold_config()
            c_e = config.model.embeddings_and_evoformer.evoformer
            ei = alphafold.model.modules.EvoformerIteration(
                c_e, config.model.global_config, is_extra_msa=False
            )
            return ei(activations, masks, is_training=False)

        f = hk.transform(run_ei)

        n_res = consts.n_res
        n_seq = consts.n_seq

        activations = {
            "msa": np.random.rand(n_seq, n_res, consts.c_m).astype(np.float32),
            "pair": np.random.rand(n_res, n_res, consts.c_z).astype(np.float32),
        }

        masks = {
            "msa": np.random.randint(0, 2, (n_seq, n_res)).astype(np.float32),
            "pair": np.random.randint(0, 2, (n_res, n_res)).astype(np.float32),
        }

        params = compare_utils.fetch_alphafold_module_weights(
            "alphafold/alphafold_iteration/evoformer/evoformer_iteration"
        )
        # The fetched weights are stacked over blocks; keep only block 0.
        params = tree_map(lambda n: n[0], params, jax.Array)

        key = jax.random.PRNGKey(42)
        out_gt = f.apply(params, key, activations, masks)
        jax.tree_map(lambda x: x.block_until_ready(), out_gt)
        out_gt_msa = torch.as_tensor(np.array(out_gt["msa"]))
        out_gt_pair = torch.as_tensor(np.array(out_gt["pair"]))

        model = compare_utils.get_global_pretrained_openfold()
        out_repro_msa, out_repro_pair = model.evoformer.blocks[0](
            torch.as_tensor(activations["msa"]).cuda(),
            torch.as_tensor(activations["pair"]).cuda(),
            torch.as_tensor(masks["msa"]).cuda(),
            torch.as_tensor(masks["pair"]).cuda(),
            chunk_size=4,
            _mask_trans=False,
            inplace_safe=False,
        )

        out_repro_msa = out_repro_msa.cpu()
        out_repro_pair = out_repro_pair.cpu()

        # NOTE(review): the MSA check uses the mean error while the pair
        # check uses the max — presumably deliberate tolerance loosening
        # for the MSA track; confirm before tightening.
        self.assertTrue(torch.mean(torch.abs(out_repro_msa - out_gt_msa)) < consts.eps)
        self.assertTrue(torch.max(torch.abs(out_repro_pair - out_gt_pair)) < consts.eps)

        # Inplace version
        out_repro_msa, out_repro_pair = model.evoformer.blocks[0](
            torch.as_tensor(activations["msa"]).cuda(),
            torch.as_tensor(activations["pair"]).cuda(),
            torch.as_tensor(masks["msa"]).cuda(),
            torch.as_tensor(masks["pair"]).cuda(),
            chunk_size=4,
            _mask_trans=False,
            inplace_safe=True,
        )

        out_repro_msa = out_repro_msa.cpu()
        out_repro_pair = out_repro_pair.cpu()

        self.assertTrue(torch.mean(torch.abs(out_repro_msa - out_gt_msa)) < consts.eps)
        self.assertTrue(torch.max(torch.abs(out_repro_pair - out_gt_pair)) < consts.eps)
222

Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
223
224

class TestExtraMSAStack(unittest.TestCase):
    """Shape test for the ExtraMSAStack (runs on CUDA)."""

    def test_shape(self):
        """A forward pass returns a pair representation with the input's
        shape."""
        batch_size = 2
        s_t = 23
        n_res = 5
        c_m = 7
        c_z = 11
        # Arbitrary, mutually distinct hidden dims to catch accidentally
        # hard-coded sizes inside the stack.
        c_hidden_msa_att = 12
        c_hidden_opm = 17
        c_hidden_mul = 19
        c_hidden_tri_att = 16
        no_heads_msa = 3
        no_heads_pair = 8
        no_blocks = 2
        transition_n = 5
        msa_dropout = 0.15
        pair_stack_dropout = 0.25
        opm_first = consts.is_multimer
        # Fused triangular-multiplication projection weights are only used
        # by the multimer v3 models. (fullmatch already anchors the pattern,
        # so no explicit ^...$ is needed.)
        fuse_projection_weights = (
            re.fullmatch("model_[1-5]_multimer_v3", consts.model) is not None
        )
        inf = 1e9
        eps = 1e-10

        es = ExtraMSAStack(
            c_m,
            c_z,
            c_hidden_msa_att,
            c_hidden_opm,
            c_hidden_mul,
            c_hidden_tri_att,
            no_heads_msa,
            no_heads_pair,
            no_blocks,
            transition_n,
            msa_dropout,
            pair_stack_dropout,
            opm_first,
            fuse_projection_weights,
            ckpt=False,
            inf=inf,
            eps=eps,
        ).eval().cuda()

        m = torch.rand((batch_size, s_t, n_res, c_m), device="cuda")
        z = torch.rand((batch_size, n_res, n_res, c_z), device="cuda")
        msa_mask = torch.randint(
            0,
            2,
            size=(batch_size, s_t, n_res),
            device="cuda",
        ).float()
        pair_mask = torch.randint(
            0,
            2,
            size=(batch_size, n_res, n_res),
            device="cuda",
        ).float()

        shape_z_before = z.shape

        # The extra-MSA stack only returns the updated pair representation.
        z = es(m, z, chunk_size=4, msa_mask=msa_mask, pair_mask=pair_mask)

        self.assertTrue(z.shape == shape_z_before)


class TestMSATransition(unittest.TestCase):
    """Shape and ground-truth-comparison tests for MSATransition."""

    def test_shape(self):
        """A chunked forward pass preserves the input MSA shape."""
        batch_size = 2
        s_t = 3
        n_r = 5
        c_m = 7
        n = 11

        mt = MSATransition(c_m, n)

        m = torch.rand((batch_size, s_t, n_r, c_m))

        shape_before = m.shape
        m = mt(m, chunk_size=4)
        shape_after = m.shape

        self.assertTrue(shape_before == shape_after)

    @compare_utils.skip_unless_alphafold_installed()
    def test_compare(self):
        """OpenFold's msa_transition reproduces AlphaFold's msa_transition
        output within consts.eps."""

        def run_msa_transition(msa_act, msa_mask):
            # Build AlphaFold's reference Transition module under haiku.
            config = compare_utils.get_alphafold_config()
            c_e = config.model.embeddings_and_evoformer.evoformer
            msa_trans = alphafold.model.modules.Transition(
                c_e.msa_transition,
                config.model.global_config,
                name="msa_transition",
            )
            act = msa_trans(act=msa_act, mask=msa_mask)
            return act

        f = hk.transform(run_msa_transition)

        n_res = consts.n_res
        n_seq = consts.n_seq

        msa_act = np.random.rand(n_seq, n_res, consts.c_m).astype(np.float32)
        msa_mask = np.ones((n_seq, n_res)).astype(
            np.float32
        )  # no mask here either

        # Fetch pretrained parameters (but only from one block)
        params = compare_utils.fetch_alphafold_module_weights(
            "alphafold/alphafold_iteration/evoformer/evoformer_iteration/"
            + "msa_transition"
        )
        # The fetched weights are stacked over blocks; keep only block 0.
        params = tree_map(lambda n: n[0], params, jax.Array)

        out_gt = f.apply(params, None, msa_act, msa_mask).block_until_ready()
        out_gt = torch.as_tensor(np.array(out_gt))

        model = compare_utils.get_global_pretrained_openfold()

        out_repro = (
            model.evoformer.blocks[0].msa_transition(
                torch.as_tensor(msa_act, dtype=torch.float32).cuda(),
                mask=torch.as_tensor(msa_mask, dtype=torch.float32).cuda(),
            )
            .cpu()
        )

        self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)
358

Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
359
360
361

# Allow running this test module directly via `python test_evoformer.py`.
if __name__ == "__main__":
    unittest.main()