msa.py 3.09 KB
Newer Older
oahzxl's avatar
oahzxl committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch.nn import LayerNorm

from .kernel import bias_dropout_add
from .ops import SelfAttention, Transition


class MSARowAttentionWithPairBias(nn.Module):
    """Gated self-attention over ``M_raw`` with an additive bias derived from ``Z``.

    The pair input ``Z`` is layer-normalised and projected (without a bias
    term) from ``d_pair`` channels to one channel per attention head; that
    projection is handed to the attention module as its second argument.
    The attention output is combined with the residual input through
    ``bias_dropout_add``.

    Args:
        d_node: Channel size of the node (MSA) input; also the attention
            output size.
        d_pair: Channel size of the pair input.
        c: Forwarded to ``SelfAttention`` as ``c`` (presumably the per-head
            hidden size -- confirm against ``SelfAttention``).
        n_head: Number of attention heads, and the output size of the pair
            projection.
        p_drop: Probability forwarded to ``bias_dropout_add`` as ``prob``.
    """

    def __init__(self, d_node, d_pair, c=32, n_head=8, p_drop=0.15):
        super().__init__()
        self.d_node, self.d_pair = d_node, d_pair
        self.c, self.n_head, self.p_drop = c, n_head, p_drop

        self.layernormM = LayerNorm(d_node)
        self.layernormZ = LayerNorm(d_pair)

        # Weight of the bias-free pair -> per-head projection, drawn from
        # N(0, 1 / d_pair).
        pair_proj = torch.zeros([n_head, d_pair])
        torch.nn.init.normal_(pair_proj, std=1.0 / math.sqrt(d_pair))
        self.linear_b_weights = nn.Parameter(pair_proj, requires_grad=True)

        # NOTE(review): last_bias_fuse=True presumably makes SelfAttention
        # leave its output bias to the caller, which is why `out_bias` lives
        # here -- confirm against SelfAttention.
        self.attention = SelfAttention(
            qkv_dim=d_node,
            c=c,
            n_head=n_head,
            out_dim=d_node,
            gating=True,
            last_bias_fuse=True,
        )

        self.out_bias = nn.Parameter(torch.zeros((d_node,)), requires_grad=True)

    def forward(self, M_raw, Z):
        """Return the updated node representation.

        Args:
            M_raw: Node input; indexed with four indices below, so it must
                have at least four dimensions. Also used as the residual.
            Z: Pair input; the projection is permuted with four axes, so it
                must be 4-D with ``d_pair`` as its last dimension.

        Returns:
            Whatever ``bias_dropout_add`` returns for the attention output,
            ``out_bias``, the all-ones mask and the residual ``M_raw``.
        """
        normed_node = self.layernormM(M_raw)
        normed_pair = self.layernormZ(Z)

        # Project to one channel per head, then move the head axis to
        # position 1: 'b q k h -> b h q k'.
        pair_bias = F.linear(normed_pair, self.linear_b_weights).permute(0, 3, 1, 2)

        attended = self.attention(normed_node, pair_bias)

        # All-ones mask with dim 1 collapsed to size 1. ones_like already
        # matches the dtype and device of its argument.
        # NOTE(review): presumably bias_dropout_add applies dropout to this
        # mask so that it is shared along dim 1 -- confirm against the kernel.
        dropout_mask = torch.ones_like(attended[:, 0:1, :, :])

        return bias_dropout_add(attended, self.out_bias, dropout_mask, M_raw, prob=self.p_drop)


class MSAColumnAttention(nn.Module):
    """Gated self-attention applied with dims ``-2`` and ``-3`` swapped.

    The input is transposed, layer-normalised, run through the attention
    module, transposed back and added to the untouched input.

    Args:
        d_node: Channel size of the node (MSA) input; also the attention
            output size.
        c: Forwarded to ``SelfAttention`` as ``c`` (presumably the per-head
            hidden size -- confirm against ``SelfAttention``).
        n_head: Number of attention heads.
    """

    def __init__(self, d_node, c=32, n_head=8):
        super().__init__()
        self.d_node, self.c, self.n_head = d_node, c, n_head

        self.layernormM = LayerNorm(d_node)

        attention_kwargs = dict(qkv_dim=d_node, c=c, n_head=n_head, out_dim=d_node, gating=True)
        self.attention = SelfAttention(**attention_kwargs)

    def forward(self, M_raw):
        """Return ``M_raw`` plus the attention update computed on the transposed view.

        Args:
            M_raw: Node input with at least three dimensions and ``d_node``
                as its last dimension.

        Returns:
            Tensor with the same shape as ``M_raw``, assuming the attention
            module preserves the shape of its input.
        """
        # Swap the two axes in front of the channel axis so that the
        # attention module sees the other axis in position -2.
        swapped = M_raw.transpose(-3, -2)
        update = self.attention(self.layernormM(swapped))

        # Undo the swap before the residual addition.
        return M_raw + update.transpose(-3, -2)


class MSAStack(nn.Module):
    """Three-stage update of the node representation.

    The stages run in this order: row attention with pair bias, column
    attention, then the transition module.

    Args:
        d_node: Channel size of the node (MSA) representation.
        d_pair: Channel size of the pair representation.
        p_drop: Dropout probability handed to the row-attention stage only.
    """

    def __init__(self, d_node, d_pair, p_drop=0.15):
        super().__init__()

        row_attention = MSARowAttentionWithPairBias(d_node=d_node, d_pair=d_pair, p_drop=p_drop)
        self.MSARowAttentionWithPairBias = row_attention

        self.MSAColumnAttention = MSAColumnAttention(d_node=d_node)
        self.MSATransition = Transition(d=d_node)

    def forward(self, node, pair):
        """Run ``node`` through the three stages and return the result.

        Args:
            node: Node (MSA) representation, passed through every stage.
            pair: Pair representation; only the row-attention stage reads it.

        Returns:
            The output of the transition stage.
        """
        after_rows = self.MSARowAttentionWithPairBias(node, pair)
        after_columns = self.MSAColumnAttention(after_rows)
        return self.MSATransition(after_columns)