# Copyright 2021 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import numpy as np
import unittest
from openfold.model.evoformer import (
    MSATransition,
    EvoformerStack,
    ExtraMSAStack,
)
from openfold.utils.tensor_utils import tree_map
import tests.compare_utils as compare_utils
from tests.config import consts

# AlphaFold, JAX and Haiku are optional dependencies of this module: they are
# only used by the ``test_compare`` reference tests, which are decorated with
# ``compare_utils.skip_unless_alphafold_installed()`` and therefore skipped
# when AlphaFold is not available.
if compare_utils.alphafold_is_installed():
    alphafold = compare_utils.import_alphafold()
    import jax
    import haiku as hk


class TestEvoformerStack(unittest.TestCase):
    """Tests for the OpenFold ``EvoformerStack`` / Evoformer block."""

    def test_shape(self):
        """EvoformerStack keeps the m/z shapes and emits s of shape (B, N_res, c_s)."""
        batch_size = consts.batch_size
        n_seq = consts.n_seq
        n_res = consts.n_res
        c_m = consts.c_m
        c_z = consts.c_z
        c_hidden_msa_att = 12
        c_hidden_opm = 17
        c_hidden_mul = 19
        c_hidden_pair_att = 14
        c_s = consts.c_s
        no_heads_msa = 3
        no_heads_pair = 7
        no_blocks = 2
        transition_n = 2
        msa_dropout = 0.15
        pair_stack_dropout = 0.25
        # Follow the configured model flavour (monomer vs. multimer).
        opm_first = consts.is_multimer
        inf = 1e9
        eps = 1e-10

        es = EvoformerStack(
            c_m,
            c_z,
            c_hidden_msa_att,
            c_hidden_opm,
            c_hidden_mul,
            c_hidden_pair_att,
            c_s,
            no_heads_msa,
            no_heads_pair,
            no_blocks,
            transition_n,
            msa_dropout,
            pair_stack_dropout,
            opm_first,
            blocks_per_ckpt=None,
            inf=inf,
            eps=eps,
        ).eval()

        m = torch.rand((batch_size, n_seq, n_res, c_m))
        z = torch.rand((batch_size, n_res, n_res, c_z))
        msa_mask = torch.randint(0, 2, size=(batch_size, n_seq, n_res))
        pair_mask = torch.randint(0, 2, size=(batch_size, n_res, n_res))

        shape_m_before = m.shape
        shape_z_before = z.shape

        m, z, s = es(
            m, z, chunk_size=4, msa_mask=msa_mask, pair_mask=pair_mask
        )

        self.assertEqual(m.shape, shape_m_before)
        self.assertEqual(z.shape, shape_z_before)
        self.assertEqual(s.shape, (batch_size, n_res, c_s))

    @compare_utils.skip_unless_alphafold_installed()
    def test_compare(self):
        """The first pretrained Evoformer block matches AlphaFold's EvoformerIteration.

        Both the out-of-place (``inplace_safe=False``) and in-place
        (``inplace_safe=True``) paths are checked against the same AlphaFold
        reference output.
        """
        def run_ei(activations, masks):
            config = compare_utils.get_alphafold_config()
            c_e = config.model.embeddings_and_evoformer.evoformer
            ei = alphafold.model.modules.EvoformerIteration(
                c_e, config.model.global_config, is_extra_msa=False
            )
            return ei(activations, masks, is_training=False)

        f = hk.transform(run_ei)

        n_res = consts.n_res
        n_seq = consts.n_seq

        activations = {
            "msa": np.random.rand(n_seq, n_res, consts.c_m).astype(np.float32),
            "pair": np.random.rand(n_res, n_res, consts.c_z).astype(np.float32),
        }

        masks = {
            "msa": np.random.randint(0, 2, (n_seq, n_res)).astype(np.float32),
            "pair": np.random.randint(0, 2, (n_res, n_res)).astype(np.float32),
        }

        params = compare_utils.fetch_alphafold_module_weights(
            "alphafold/alphafold_iteration/evoformer/evoformer_iteration"
        )
        # jax.numpy.DeviceArray was removed in newer JAX releases; prefer it
        # when present (older JAX) and fall back to the jax.Array base type.
        array_type = getattr(jax.numpy, "DeviceArray", None) or jax.Array
        # Keep only the parameters of a single block (index 0 of the leading
        # axis).
        params = tree_map(lambda n: n[0], params, array_type)

        key = jax.random.PRNGKey(42)
        out_gt = f.apply(params, key, activations, masks)
        # jax.tree_map is deprecated; jax.tree_util.tree_map is the stable API.
        jax.tree_util.tree_map(lambda x: x.block_until_ready(), out_gt)
        out_gt_msa = torch.as_tensor(np.array(out_gt["msa"]))
        out_gt_pair = torch.as_tensor(np.array(out_gt["pair"]))

        model = compare_utils.get_global_pretrained_openfold()

        def run_repro(inplace_safe):
            # Fresh device copies are made on every call because the in-place
            # path may overwrite its inputs.
            # NOTE(review): inplace_safe=True appears intended for
            # inference-only execution, so autograd is disabled here; this
            # does not change the computed values.
            with torch.no_grad():
                msa, pair = model.evoformer.blocks[0](
                    torch.as_tensor(activations["msa"]).cuda(),
                    torch.as_tensor(activations["pair"]).cuda(),
                    torch.as_tensor(masks["msa"]).cuda(),
                    torch.as_tensor(masks["pair"]).cuda(),
                    chunk_size=4,
                    _mask_trans=False,
                    inplace_safe=inplace_safe,
                )
            return msa.cpu(), pair.cpu()

        for inplace_safe in (False, True):
            with self.subTest(inplace_safe=inplace_safe):
                out_repro_msa, out_repro_pair = run_repro(inplace_safe)
                # The MSA output is compared by mean absolute difference; the
                # pair output by max absolute difference.
                self.assertLess(
                    torch.mean(torch.abs(out_repro_msa - out_gt_msa)).item(),
                    consts.eps,
                )
                self.assertLess(
                    torch.max(torch.abs(out_repro_pair - out_gt_pair)).item(),
                    consts.eps,
                )


class TestExtraMSAStack(unittest.TestCase):
    """Shape tests for the OpenFold ``ExtraMSAStack`` module."""

    def test_shape(self):
        """ExtraMSAStack returns a pair representation with z's input shape."""
        batch_size = 2
        s_t = 23
        n_res = 5
        c_m = 7
        c_z = 11
        c_hidden_msa_att = 12
        c_hidden_opm = 17
        c_hidden_mul = 19
        c_hidden_tri_att = 16
        no_heads_msa = 3
        no_heads_pair = 8
        no_blocks = 2
        transition_n = 5
        msa_dropout = 0.15
        pair_stack_dropout = 0.25
        # Follow the configured model flavour (monomer vs. multimer).
        opm_first = consts.is_multimer
        inf = 1e9
        eps = 1e-10

        es = ExtraMSAStack(
            c_m,
            c_z,
            c_hidden_msa_att,
            c_hidden_opm,
            c_hidden_mul,
            c_hidden_tri_att,
            no_heads_msa,
            no_heads_pair,
            no_blocks,
            transition_n,
            msa_dropout,
            pair_stack_dropout,
            opm_first,
            ckpt=False,
            inf=inf,
            eps=eps,
        ).eval()

        m = torch.rand((batch_size, s_t, n_res, c_m))
        z = torch.rand((batch_size, n_res, n_res, c_z))
        msa_mask = torch.randint(0, 2, size=(batch_size, s_t, n_res))
        pair_mask = torch.randint(0, 2, size=(batch_size, n_res, n_res))

        shape_z_before = z.shape

        # Called this way, the stack returns only the updated pair
        # representation.
        z = es(m, z, chunk_size=4, msa_mask=msa_mask, pair_mask=pair_mask)

        self.assertEqual(z.shape, shape_z_before)


class TestMSATransition(unittest.TestCase):
    """Tests for the OpenFold ``MSATransition`` module."""

    def test_shape(self):
        """MSATransition leaves the shape of the MSA representation unchanged."""
        batch_size = 2
        s_t = 3
        n_r = 5
        c_m = 7
        n = 11

        mt = MSATransition(c_m, n)

        m = torch.rand((batch_size, s_t, n_r, c_m))
        shape_before = m.shape

        m = mt(m, chunk_size=4)

        self.assertEqual(m.shape, shape_before)

    @compare_utils.skip_unless_alphafold_installed()
    def test_compare(self):
        """The pretrained MSA transition matches AlphaFold's ``Transition`` module."""
        def run_msa_transition(msa_act, msa_mask):
            config = compare_utils.get_alphafold_config()
            c_e = config.model.embeddings_and_evoformer.evoformer
            msa_trans = alphafold.model.modules.Transition(
                c_e.msa_transition,
                config.model.global_config,
                name="msa_transition",
            )
            act = msa_trans(act=msa_act, mask=msa_mask)
            return act

        f = hk.transform(run_msa_transition)

        n_res = consts.n_res
        n_seq = consts.n_seq

        msa_act = np.random.rand(n_seq, n_res, consts.c_m).astype(np.float32)
        # All-ones mask: no MSA positions are masked out in this comparison.
        msa_mask = np.ones((n_seq, n_res), dtype=np.float32)

        # Fetch pretrained parameters (but only from one block).
        params = compare_utils.fetch_alphafold_module_weights(
            "alphafold/alphafold_iteration/evoformer/evoformer_iteration/"
            + "msa_transition"
        )
        # jax.numpy.DeviceArray was removed in newer JAX releases; prefer it
        # when present (older JAX) and fall back to the jax.Array base type.
        array_type = getattr(jax.numpy, "DeviceArray", None) or jax.Array
        params = tree_map(lambda n: n[0], params, array_type)

        # The reference is applied without an RNG key (rng=None).
        out_gt = f.apply(params, None, msa_act, msa_mask).block_until_ready()
        out_gt = torch.as_tensor(np.array(out_gt))

        model = compare_utils.get_global_pretrained_openfold()

        # Inference-only comparison: autograd bookkeeping is unnecessary.
        with torch.no_grad():
            out_repro = model.evoformer.blocks[0].msa_transition(
                torch.as_tensor(msa_act, dtype=torch.float32).cuda(),
                mask=torch.as_tensor(msa_mask, dtype=torch.float32).cuda(),
            )
        out_repro = out_repro.cpu()

        self.assertLess(
            torch.max(torch.abs(out_gt - out_repro)).item(), consts.eps
        )


if __name__ == "__main__":
    unittest.main()