# test_triangular_multiplicative_update.py
# Copyright 2021 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
Christina Floristean's avatar
Christina Floristean committed
16
import re
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
17
18
import numpy as np
import unittest
19
20
21
22
23
from openfold.model.triangular_multiplicative_update import *
from openfold.utils.tensor_utils import tree_map
import tests.compare_utils as compare_utils
from tests.config import consts

# AlphaFold and its JAX/Haiku stack are optional: they are only used by the
# reference-comparison tests, which are skipped when AlphaFold is missing.
if compare_utils.alphafold_is_installed():
    alphafold = compare_utils.import_alphafold()
    import jax
    import haiku as hk


class TestTriangularMultiplicativeUpdate(unittest.TestCase):
    """Tests for OpenFold's triangular multiplicative update modules.

    Covers the output shape, numerical agreement of the pretrained modules
    with AlphaFold's JAX/Haiku reference implementation, and agreement
    between the stock and in-place (``inplace_safe=True``) forward paths.
    """

    @staticmethod
    def _random_pair_inputs(n_res):
        """Return random ``(pair_act, pair_mask)`` float32 numpy arrays.

        ``pair_act`` has shape ``(n_res, n_res, consts.c_z)`` with values in
        [0, 1); ``pair_mask`` is a 0/1 mask of shape ``(n_res, n_res)``.
        """
        pair_act = np.random.rand(n_res, n_res, consts.c_z).astype(np.float32)
        pair_mask = np.random.randint(low=0, high=2, size=(n_res, n_res))
        pair_mask = pair_mask.astype(np.float32)
        return pair_act, pair_mask

    @staticmethod
    def _pretrained_tri_mul(incoming):
        """Return a triangle multiplication module of the pretrained model.

        Uses the first Evoformer block; ``incoming`` selects ``tri_mul_in``,
        otherwise ``tri_mul_out`` is returned.
        """
        model = compare_utils.get_global_pretrained_openfold()
        pair_stack = model.evoformer.blocks[0].pair_stack
        return pair_stack.tri_mul_in if incoming else pair_stack.tri_mul_out

    @staticmethod
    def _forward_on_cuda(module, pair_act, pair_mask, **kwargs):
        """Run ``module`` on CUDA copies of the numpy inputs.

        Fresh device tensors are built on every call, so the in-place path
        never writes into the caller's numpy arrays. Extra keyword arguments
        (e.g. ``inplace_safe``) are forwarded to the module. Returns the
        module output moved back to the CPU.
        """
        # Forward-only comparison: no autograd graph is needed.
        with torch.no_grad():
            out = module(
                torch.as_tensor(pair_act, dtype=torch.float32).cuda(),
                mask=torch.as_tensor(pair_mask, dtype=torch.float32).cuda(),
                **kwargs,
            )
        return out.cpu()

    def test_shape(self):
        """The outgoing update must preserve the pair representation's shape."""
        c_z = consts.c_z
        c = 11

        # Pick the module variant for the configured model: multimer v3
        # models use the fused implementation.
        if re.fullmatch(r"model_[1-5]_multimer_v3", consts.model):
            tm = FusedTriangleMultiplicationOutgoing(c_z, c)
        else:
            tm = TriangleMultiplicationOutgoing(c_z, c)

        # Previously this read consts.c_z (a copy-paste slip), which made the
        # pair tensor needlessly large.
        n_res = consts.n_res
        batch_size = consts.batch_size

        x = torch.rand((batch_size, n_res, n_res, c_z))
        # Binary mask as floats, matching the masks used by the other tests.
        mask = torch.randint(0, 2, size=(batch_size, n_res, n_res)).float()
        shape_before = x.shape
        x = tm(x, mask)
        shape_after = x.shape

        self.assertEqual(shape_before, shape_after)

    def _tri_mul_compare(self, incoming=False):
        """Compare OpenFold's pretrained tri-mul module with AlphaFold's.

        Args:
            incoming: If True, test the "incoming" update; otherwise the
                "outgoing" one.
        """
        name = "triangle_multiplication_" + (
            "incoming" if incoming else "outgoing"
        )

        def run_tri_mul(pair_act, pair_mask):
            # Build AlphaFold's reference module from its own config.
            config = compare_utils.get_alphafold_config()
            c_e = config.model.embeddings_and_evoformer.evoformer
            tri_mul = alphafold.model.modules.TriangleMultiplication(
                c_e.triangle_multiplication_incoming
                if incoming
                else c_e.triangle_multiplication_outgoing,
                config.model.global_config,
                name=name,
            )
            return tri_mul(pair_act, pair_mask)

        f = hk.transform(run_tri_mul)

        pair_act, pair_mask = self._random_pair_inputs(consts.n_res)

        # Fetch pretrained parameters (but only from one block)
        params = compare_utils.fetch_alphafold_module_weights(
            "alphafold/alphafold_iteration/evoformer/evoformer_iteration/"
            + name
        )
        # jax.numpy.DeviceArray was removed in newer JAX releases in favour
        # of jax.Array; use whichever this installation provides.
        array_type = getattr(jax, "Array", None)
        if array_type is None:
            array_type = jax.numpy.DeviceArray
        # Presumably the weights are stacked over Evoformer blocks on axis 0;
        # n[0] keeps the first block, matching blocks[0] on the OpenFold side.
        params = tree_map(lambda n: n[0], params, array_type)

        out_gt = f.apply(params, None, pair_act, pair_mask).block_until_ready()
        out_gt = torch.as_tensor(np.array(out_gt))

        out_repro = self._forward_on_cuda(
            self._pretrained_tri_mul(incoming),
            pair_act,
            pair_mask,
            inplace_safe=True,
            _inplace_chunk_size=4,
        )

        err = torch.mean(torch.abs(out_gt - out_repro)).item()
        self.assertLess(err, consts.eps)

    @compare_utils.skip_unless_alphafold_installed()
    @unittest.skipUnless(torch.cuda.is_available(), "requires a CUDA device")
    def test_tri_mul_out_compare(self):
        self._tri_mul_compare()

    @compare_utils.skip_unless_alphafold_installed()
    @unittest.skipUnless(torch.cuda.is_available(), "requires a CUDA device")
    def test_tri_mul_in_compare(self):
        self._tri_mul_compare(incoming=True)

    def _tri_mul_inplace(self, incoming=False):
        """Check the in-place (inference) path against the stock forward.

        Args:
            incoming: If True, test the "incoming" update; otherwise the
                "outgoing" one.
        """
        pair_act, pair_mask = self._random_pair_inputs(consts.n_res)
        module = self._pretrained_tri_mul(incoming)

        out_stock = self._forward_on_cuda(
            module, pair_act, pair_mask, inplace_safe=False,
        )

        # This has to come second because inference mode is in-place
        out_inplace = self._forward_on_cuda(
            module,
            pair_act,
            pair_mask,
            inplace_safe=True,
            _inplace_chunk_size=2,
        )

        err = torch.mean(torch.abs(out_stock - out_inplace)).item()
        self.assertLess(err, consts.eps)

    @unittest.skipUnless(torch.cuda.is_available(), "requires a CUDA device")
    def test_tri_mul_out_inference(self):
        self._tri_mul_inplace()

    @unittest.skipUnless(torch.cuda.is_available(), "requires a CUDA device")
    def test_tri_mul_in_inference(self):
        self._tri_mul_inplace(incoming=True)

if __name__ == "__main__":
    # Running this file directly discovers and runs the TestCase classes
    # defined in this module.
    unittest.main(module="__main__")