# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import partialmethod
import torch
import torch.nn as nn

from openfold.model.primitives import Linear, Attention
from openfold.utils.tensor_utils import (
    chunk_layer,
    permute_final_dims,
)


class TriangleAttention(nn.Module):
    def __init__(
        self, c_in, c_hidden, no_heads, starting, inf=1e9
    ):
        """
        Args:
            c_in:
                Input channel dimension
            c_hidden:
                Overall hidden channel dimension (not per-head)
            no_heads:
                Number of attention heads
            starting:
                If True, attention is performed over the rows of the
                pair representation (around the starting node);
                otherwise over the columns (around the ending node)
            inf:
                Large value used to build the additive attention mask
        """
        super(TriangleAttention, self).__init__()

        self.c_in = c_in
        self.c_hidden = c_hidden
        self.no_heads = no_heads
        self.starting = starting
        self.inf = inf

        self.layer_norm = nn.LayerNorm(self.c_in)

        # Projects each pair embedding to a single logit per head; this
        # becomes the triangle bias added to the attention logits
        self.linear = Linear(c_in, self.no_heads, bias=False, init="normal")

        self.mha = Attention(
            self.c_in, self.c_in, self.c_in, self.c_hidden, self.no_heads
        )

    def forward(self, x, chunk_size, mask=None):
        """
        Args:
            x:
                [*, I, J, C_in] input tensor (e.g. the pair representation)
            chunk_size:
                Chunk size for the low-memory attention path, or None to
                attend over the full input in one pass
            mask:
                [*, I, J] input mask (1 for valid positions, 0 for masked
                ones); defaults to all ones
        Returns:
            [*, I, J, C_in] output tensor
        """
        if mask is None:
            # [*, I, J]
            mask = x.new_ones(
                x.shape[:-1],
            )

        # Shape annotations assume self.starting. Else, I and J are flipped
        if not self.starting:
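            # The ending-node variant swaps I and J so that the row-wise
            # attention below effectively attends over columns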
            x = x.transpose(-2, -3)
            mask = mask.transpose(-1, -2)

        # [*, I, J, C_in]
        x = self.layer_norm(x)

        # Turn the 0/1 mask into an additive bias: masked (0) positions
        # become -self.inf, valid (1) positions become 0
        # [*, I, 1, 1, J]
        mask_bias = (self.inf * (mask - 1))[..., :, None, None, :]

        # Per-head triangle bias derived from the pair representation
        # [*, H, I, J]
        triangle_bias = permute_final_dims(self.linear(x), (2, 0, 1))

        # [*, 1, H, I, J]; the added singleton broadcasts the same bias
        # across every row of the attention batch
        triangle_bias = triangle_bias.unsqueeze(-4)

        mha_inputs = {
            "q_x": x,
            "k_x": x,
            "v_x": x,
            "biases": [mask_bias, triangle_bias],
        }
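        # chunk_layer applies self.mha over fixed-size slices of the
        # flattened batch dims (which here include the row dim I),
        # trading some speed for a lower peak memory footprint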
        if chunk_size is not None:
            x = chunk_layer(
                self.mha,
                mha_inputs,
                chunk_size=chunk_size,
                no_batch_dims=len(x.shape[:-2]),
            )
        else:
            x = self.mha(**mha_inputs)

        if not self.starting:
            # Undo the earlier transpose for the ending-node variant
            x = x.transpose(-2, -3)

        return x


class TriangleAttentionStartingNode(TriangleAttention):
    """
    Implements Algorithm 13 of the AlphaFold 2 supplement (triangular
    gated self-attention around the starting node).
    """

    # partialmethod pins starting=True; all other logic is inherited
    __init__ = partialmethod(TriangleAttention.__init__, starting=True)


class TriangleAttentionEndingNode(TriangleAttention):
    """
    Implements Algorithm 14 of the AlphaFold 2 supplement (triangular
    gated self-attention around the ending node).
    """

    __init__ = partialmethod(TriangleAttention.__init__, starting=False)
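

# Minimal usage sketch (illustrative only): the shapes and channel sizes
# below are arbitrary assumptions rather than values mandated by the
# model, and running it requires the openfold package to be importable.
if __name__ == "__main__":
    c_in, no_heads = 16, 4
    attn = TriangleAttentionStartingNode(
        c_in=c_in, c_hidden=8, no_heads=no_heads
    )
    # [*, I, J, C_in] pair representation
    z = torch.rand(2, 10, 10, c_in)
    out = attn(z, chunk_size=None)  # output keeps the input shape
    assert out.shape == z.shape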