test_evoformer.py 9.35 KB
Newer Older
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
# Copyright 2021 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Christina Floristean's avatar
Christina Floristean committed
15
import re
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
16
17
18
import torch
import numpy as np
import unittest
19
20
21
22
23
24
25
26
27
from openfold.model.evoformer import (
    MSATransition,
    EvoformerStack,
    ExtraMSAStack,
)
from openfold.utils.tensor_utils import tree_map
import tests.compare_utils as compare_utils
from tests.config import consts

Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
28
# AlphaFold and its JAX/Haiku stack are only needed by the comparison tests,
# which are decorated with compare_utils.skip_unless_alphafold_installed().
if compare_utils.alphafold_is_installed():
    alphafold = compare_utils.import_alphafold()
    import jax
    import haiku as hk


class TestEvoformerStack(unittest.TestCase):
    """Tests for ``EvoformerStack`` and a single pretrained Evoformer block."""

    def test_shape(self):
        """EvoformerStack keeps the MSA/pair shapes and returns a single
        representation of shape ``(batch_size, n_res, c_s)``."""
        batch_size = consts.batch_size
        n_seq = consts.n_seq
        n_res = consts.n_res
        c_m = consts.c_m
        c_z = consts.c_z
        c_hidden_msa_att = 12
        c_hidden_opm = 17
        c_hidden_mul = 19
        c_hidden_pair_att = 14
        c_s = consts.c_s
        no_heads_msa = 3
        no_heads_pair = 7
        no_blocks = 2
        transition_n = 2
        msa_dropout = 0.15
        pair_stack_dropout = 0.25
        # Enabled exactly when the test config targets a multimer model.
        opm_first = consts.is_multimer
        # Enabled only for the model_[1-5]_multimer_v3 presets.
        fuse_projection_weights = bool(
            re.fullmatch(r"model_[1-5]_multimer_v3", consts.model)
        )
        inf = 1e9
        eps = 1e-10

        es = EvoformerStack(
            c_m,
            c_z,
            c_hidden_msa_att,
            c_hidden_opm,
            c_hidden_mul,
            c_hidden_pair_att,
            c_s,
            no_heads_msa,
            no_heads_pair,
            no_blocks,
            transition_n,
            msa_dropout,
            pair_stack_dropout,
            opm_first,
            fuse_projection_weights,
            blocks_per_ckpt=None,
            inf=inf,
            eps=eps,
        ).eval()

        m = torch.rand((batch_size, n_seq, n_res, c_m))
        z = torch.rand((batch_size, n_res, n_res, c_z))
        msa_mask = torch.randint(0, 2, size=(batch_size, n_seq, n_res))
        pair_mask = torch.randint(0, 2, size=(batch_size, n_res, n_res))

        shape_m_before = m.shape
        shape_z_before = z.shape

        m, z, s = es(
            m, z, chunk_size=4, msa_mask=msa_mask, pair_mask=pair_mask
        )

        self.assertEqual(m.shape, shape_m_before)
        self.assertEqual(z.shape, shape_z_before)
        self.assertEqual(s.shape, (batch_size, n_res, c_s))

    @compare_utils.skip_unless_alphafold_installed()
    def test_compare(self):
        """The first pretrained OpenFold Evoformer block matches AlphaFold's
        ``EvoformerIteration``, both with and without ``inplace_safe``."""

        def run_ei(activations, masks):
            config = compare_utils.get_alphafold_config()
            c_e = config.model.embeddings_and_evoformer.evoformer
            ei = alphafold.model.modules.EvoformerIteration(
                c_e, config.model.global_config, is_extra_msa=False
            )
            return ei(activations, masks, is_training=False)

        f = hk.transform(run_ei)

        n_res = consts.n_res
        n_seq = consts.n_seq

        activations = {
            "msa": np.random.rand(n_seq, n_res, consts.c_m).astype(np.float32),
            "pair": np.random.rand(n_res, n_res, consts.c_z).astype(np.float32),
        }

        masks = {
            "msa": np.random.randint(0, 2, (n_seq, n_res)).astype(np.float32),
            "pair": np.random.randint(0, 2, (n_res, n_res)).astype(np.float32),
        }

        # Take index 0 of every parameter leaf, i.e. only one block's weights.
        params = compare_utils.fetch_alphafold_module_weights(
            "alphafold/alphafold_iteration/evoformer/evoformer_iteration"
        )
        params = tree_map(lambda n: n[0], params, jax.numpy.DeviceArray)

        key = jax.random.PRNGKey(42)
        out_gt = f.apply(params, key, activations, masks)
        jax.tree_util.tree_map(lambda x: x.block_until_ready(), out_gt)
        out_gt_msa = torch.as_tensor(np.array(out_gt["msa"]))
        out_gt_pair = torch.as_tensor(np.array(out_gt["pair"]))

        model = compare_utils.get_global_pretrained_openfold()
        block = model.evoformer.blocks[0]

        for inplace_safe in (False, True):
            with self.subTest(inplace_safe=inplace_safe):
                # Build fresh device copies on every pass: with
                # inplace_safe=True the block may overwrite its inputs.
                out_repro_msa, out_repro_pair = block(
                    torch.as_tensor(activations["msa"]).cuda(),
                    torch.as_tensor(activations["pair"]).cuda(),
                    torch.as_tensor(masks["msa"]).cuda(),
                    torch.as_tensor(masks["pair"]).cuda(),
                    chunk_size=4,
                    _mask_trans=False,
                    inplace_safe=inplace_safe,
                )

                out_repro_msa = out_repro_msa.cpu()
                out_repro_pair = out_repro_pair.cpu()

                # NOTE(review): the MSA check bounds the *mean* absolute error
                # while the pair check bounds the *max*; confirm the looser MSA
                # tolerance is intentional.
                msa_err = torch.mean(
                    torch.abs(out_repro_msa - out_gt_msa)
                ).item()
                pair_err = torch.max(
                    torch.abs(out_repro_pair - out_gt_pair)
                ).item()
                self.assertLess(msa_err, consts.eps)
                self.assertLess(pair_err, consts.eps)

class TestExtraMSAStack(unittest.TestCase):
    """Shape test for ``ExtraMSAStack``."""

    def test_shape(self):
        """ExtraMSAStack returns a pair representation with the same shape as
        the pair representation it was given."""
        # Use the GPU when one is present, otherwise fall back to the CPU so
        # this pure shape check does not hard-require CUDA.
        device = "cuda" if torch.cuda.is_available() else "cpu"

        batch_size = 2
        s_t = 23
        n_res = 5
        c_m = 7
        c_z = 11
        c_hidden_msa_att = 12
        c_hidden_opm = 17
        c_hidden_mul = 19
        c_hidden_tri_att = 16
        no_heads_msa = 3
        no_heads_pair = 8
        no_blocks = 2
        transition_n = 5
        msa_dropout = 0.15
        pair_stack_dropout = 0.25
        # Enabled exactly when the test config targets a multimer model.
        opm_first = consts.is_multimer
        # Enabled only for the model_[1-5]_multimer_v3 presets.
        fuse_projection_weights = bool(
            re.fullmatch(r"model_[1-5]_multimer_v3", consts.model)
        )
        inf = 1e9
        eps = 1e-10

        es = ExtraMSAStack(
            c_m,
            c_z,
            c_hidden_msa_att,
            c_hidden_opm,
            c_hidden_mul,
            c_hidden_tri_att,
            no_heads_msa,
            no_heads_pair,
            no_blocks,
            transition_n,
            msa_dropout,
            pair_stack_dropout,
            opm_first,
            fuse_projection_weights,
            ckpt=False,
            inf=inf,
            eps=eps,
        ).eval().to(device)

        m = torch.rand((batch_size, s_t, n_res, c_m), device=device)
        z = torch.rand((batch_size, n_res, n_res, c_z), device=device)
        msa_mask = torch.randint(
            0, 2, size=(batch_size, s_t, n_res), device=device
        )
        pair_mask = torch.randint(
            0, 2, size=(batch_size, n_res, n_res), device=device
        )

        shape_z_before = z.shape

        z = es(m, z, chunk_size=4, msa_mask=msa_mask, pair_mask=pair_mask)

        self.assertEqual(z.shape, shape_z_before)


class TestMSATransition(unittest.TestCase):
    """Tests for ``MSATransition``."""

    def test_shape(self):
        """MSATransition preserves the shape of the MSA representation."""
        batch_size = 2
        s_t = 3
        n_r = 5
        c_m = 7
        n = 11

        mt = MSATransition(c_m, n)

        m = torch.rand((batch_size, s_t, n_r, c_m))

        shape_before = m.shape
        m = mt(m, chunk_size=4)

        self.assertEqual(m.shape, shape_before)

    @compare_utils.skip_unless_alphafold_installed()
    def test_compare(self):
        """The MSA transition of the first pretrained OpenFold Evoformer block
        matches AlphaFold's ``msa_transition`` module."""

        def run_msa_transition(msa_act, msa_mask):
            config = compare_utils.get_alphafold_config()
            c_e = config.model.embeddings_and_evoformer.evoformer
            msa_trans = alphafold.model.modules.Transition(
                c_e.msa_transition,
                config.model.global_config,
                name="msa_transition",
            )
            act = msa_trans(act=msa_act, mask=msa_mask)
            return act

        f = hk.transform(run_msa_transition)

        n_res = consts.n_res
        n_seq = consts.n_seq

        msa_act = np.random.rand(n_seq, n_res, consts.c_m).astype(np.float32)
        # All-ones mask: every MSA position is treated as valid.
        msa_mask = np.ones((n_seq, n_res), dtype=np.float32)

        # Fetch pretrained parameters, keeping only one block's weights
        # (index 0 of every parameter leaf).
        params = compare_utils.fetch_alphafold_module_weights(
            "alphafold/alphafold_iteration/evoformer/evoformer_iteration/"
            + "msa_transition"
        )
        params = tree_map(lambda n: n[0], params, jax.numpy.DeviceArray)

        # No RNG key is passed to Haiku's apply for this module.
        out_gt = f.apply(params, None, msa_act, msa_mask).block_until_ready()
        out_gt = torch.as_tensor(np.array(out_gt))

        model = compare_utils.get_global_pretrained_openfold()

        out_repro = (
            model.evoformer.blocks[0].msa_transition(
                torch.as_tensor(msa_act, dtype=torch.float32).cuda(),
                mask=torch.as_tensor(msa_mask, dtype=torch.float32).cuda(),
            )
            .cpu()
        )

        max_err = torch.max(torch.abs(out_gt - out_repro)).item()
        self.assertLess(max_err, consts.eps)

if __name__ == "__main__":
    # Allow running this file directly, e.g. `python test_evoformer.py`.
    unittest.main(module="__main__")