# Copyright 2021 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Standard library
import copy
import unittest

# Third-party
import numpy as np
import torch

# Project
from openfold.model.triangular_attention import TriangleAttention
from openfold.utils.tensor_utils import tree_map

import tests.compare_utils as compare_utils
from tests.config import consts

# JAX/Haiku are only needed for the AlphaFold parity tests below, so they
# are imported lazily — plain shape tests still run without AlphaFold.
if compare_utils.alphafold_is_installed():
    alphafold = compare_utils.import_alphafold()
    import jax
    import haiku as hk


class TestTriangularAttention(unittest.TestCase):
    """Tests for openfold.model.triangular_attention.TriangleAttention.

    Covers (1) a basic shape-preservation check and (2) numerical parity
    against the reference AlphaFold (JAX/Haiku) implementation, using
    pretrained weights. The parity tests are skipped when AlphaFold is not
    installed.
    """

    def test_shape(self):
        """TriangleAttention must preserve the shape of the pair activations."""
        c_z = consts.c_z
        c = 12
        no_heads = 4
        starting = True

        tan = TriangleAttention(c_z, c, no_heads, starting)

        batch_size = consts.batch_size
        n_res = consts.n_res

        # Pair representation: [batch, N_res, N_res, c_z]
        x = torch.rand((batch_size, n_res, n_res, c_z))
        shape_before = x.shape
        x = tan(x, chunk_size=None)
        shape_after = x.shape

        self.assertTrue(shape_before == shape_after)

    def _tri_att_compare(self, starting=False):
        """Compare OpenFold's TriangleAttention against AlphaFold's.

        Args:
            starting: If True, compare the "starting node" variant;
                otherwise the "ending node" variant.

        NOTE(review): requires AlphaFold, pretrained weights, and a CUDA
        device (inputs are moved with ``.cuda()``).
        """
        name = (
            "triangle_attention_"
            + ("starting" if starting else "ending")
            + "_node"
        )

        def run_tri_att(pair_act, pair_mask):
            # Build the reference AlphaFold module under Haiku.
            config = compare_utils.get_alphafold_config()
            c_e = config.model.embeddings_and_evoformer.evoformer
            tri_att = alphafold.model.modules.TriangleAttention(
                c_e.triangle_attention_starting_node
                if starting
                else c_e.triangle_attention_ending_node,
                config.model.global_config,
                name=name,
            )
            act = tri_att(pair_act=pair_act, pair_mask=pair_mask)
            return act

        f = hk.transform(run_tri_att)

        n_res = consts.n_res

        # Scale activations up so numerical differences are not hidden
        # by tiny magnitudes.
        pair_act = np.random.rand(n_res, n_res, consts.c_z) * 100
        pair_mask = np.random.randint(low=0, high=2, size=(n_res, n_res))

        # Fetch pretrained parameters (but only from one block)
        params = compare_utils.fetch_alphafold_module_weights(
            "alphafold/alphafold_iteration/evoformer/evoformer_iteration/"
            + name
        )
        # Weights are stacked across evoformer blocks; take block 0.
        params = tree_map(lambda n: n[0], params, jax.Array)

        out_gt = f.apply(params, None, pair_act, pair_mask).block_until_ready()
        out_gt = torch.as_tensor(np.array(out_gt))

        model = compare_utils.get_global_pretrained_openfold()
        module = (
            model.evoformer.blocks[0].pair_stack.tri_att_start
            if starting
            else model.evoformer.blocks[0].pair_stack.tri_att_end
        )

        # To save memory, the full model transposes inputs outside of the
        # triangle attention module. We adjust the module here.
        module = copy.deepcopy(module)
        module.starting = starting

        out_repro = module(
            torch.as_tensor(pair_act, dtype=torch.float32).cuda(),
            mask=torch.as_tensor(pair_mask, dtype=torch.float32).cuda(),
            chunk_size=None,
        ).cpu()

        compare_utils.assert_mean_abs_diff_small(out_gt, out_repro, consts.eps)

    @compare_utils.skip_unless_alphafold_installed()
    def test_tri_att_end_compare(self):
        """Parity with AlphaFold's triangle_attention_ending_node."""
        self._tri_att_compare()

    @compare_utils.skip_unless_alphafold_installed()
    def test_tri_att_start_compare(self):
        """Parity with AlphaFold's triangle_attention_starting_node."""
        self._tri_att_compare(starting=True)


if __name__ == "__main__":
    # Allow running this test module directly: `python test_triangular_attention.py`
    unittest.main()