import math

import torch
import torch.nn.functional as F
from torch import nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
from utlis import *  # project helper module (spelled "utlis" in this repo); provides pad() and NODE_TYPE used below

import dgl.function as fn
from dgl.nn.functional import edge_softmax


class MSA(nn.Module):
    # multi-head self-attention, used in three ways:
    # 1. "copy" mode: returns raw attention scores over the entities (which entity should be copied)
    # 2. "normal" mode: standard attention between two sequence inputs
    # 3. attentive pooling (gather): "normal" mode with a single query token, i.e. a 2-D inp1

    def __init__(self, args, mode="normal"):
        super(MSA, self).__init__()
        if mode == "copy":
            nhead, head_dim = 1, args.nhid
            qninp, kninp = args.dec_ninp, args.nhid
        if mode == "normal":
            nhead, head_dim = args.nhead, args.head_dim
            qninp, kninp = args.nhid, args.nhid
        self.attn_drop = nn.Dropout(0.1)
        self.WQ = nn.Linear(
            qninp, nhead * head_dim, bias=True if mode == "copy" else False
        )
        if mode != "copy":
            self.WK = nn.Linear(kninp, nhead * head_dim, bias=False)
            self.WV = nn.Linear(kninp, nhead * head_dim, bias=False)
        self.args, self.nhead, self.head_dim, self.mode = (
            args,
            nhead,
            head_dim,
            mode,
        )

    def forward(self, inp1, inp2, mask=None):
        B, L2, H = inp2.shape
        NH, HD = self.nhead, self.head_dim
        if self.mode == "copy":
            q, k, v = self.WQ(inp1), inp2, inp2
        else:
            q, k, v = self.WQ(inp1), self.WK(inp2), self.WV(inp2)
        # a 2-D inp1 is treated as a single query token (attentive pooling)
        L1 = 1 if inp1.ndim == 2 else inp1.shape[1]
        if self.mode != "copy":
            q = q / math.sqrt(H)
        q = q.view(B, L1, NH, HD).permute(0, 2, 1, 3)
        k = k.view(B, L2, NH, HD).permute(0, 2, 3, 1)
        v = v.view(B, L2, NH, HD).permute(0, 2, 1, 3)
        pre_attn = torch.matmul(q, k)
        if mask is not None:
            pre_attn = pre_attn.masked_fill(mask[:, None, None, :], -1e8)
        if self.mode == "copy":
            return pre_attn.squeeze(1)
        else:
            alpha = self.attn_drop(torch.softmax(pre_attn, -1))
            attn = (
                torch.matmul(alpha, v)
                .permute(0, 2, 1, 3)
                .contiguous()
                .view(B, L1, NH * HD)
            )
            ret = attn
            if inp1.ndim == 2:
                return ret.squeeze(1)
            else:
                return ret
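

# Usage sketch (added for illustration, not part of the original model): shows the
# expected tensor shapes for MSA in "normal" mode and for attentive pooling. The
# `args` fields (nhid, nhead, head_dim, dec_ninp) mirror the constructor above; the
# concrete values are assumptions chosen only to make the example self-contained.
# In "copy" mode the module would instead return the raw (batch, query_len, memory_len) scores.
def _msa_usage_example():
    from types import SimpleNamespace

    args = SimpleNamespace(nhid=8, nhead=2, head_dim=4, dec_ninp=8)
    attn = MSA(args, mode="normal")
    query = torch.randn(3, 5, args.nhid)            # (batch, query_len, nhid)
    memory = torch.randn(3, 7, args.nhid)           # (batch, memory_len, nhid)
    pad_mask = torch.zeros(3, 7, dtype=torch.bool)  # True would mark padded memory positions
    out = attn(query, memory, pad_mask)             # (batch, query_len, nhead * head_dim)
    pooled = attn(query[:, 0], memory, pad_mask)    # 2-D query -> attentive pooling, (batch, nhead * head_dim)
    return out.shape, pooled.shape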


class BiLSTM(nn.Module):
    # encodes either the entity token sequences or the title sequence, depending on enc_type
    def __init__(self, args, enc_type="title"):
        super(BiLSTM, self).__init__()
        self.enc_type = enc_type
        self.drop = nn.Dropout(args.emb_drop)
        self.bilstm = nn.LSTM(
            args.nhid,
            args.nhid // 2,
            bidirectional=True,
            num_layers=args.enc_lstm_layers,
            batch_first=True,
        )

    def forward(self, inp, mask, ent_len=None):
        inp = self.drop(inp)
        # mask == 0 marks real (non-padded) tokens; count them to get per-sequence lengths
        lens = (mask == 0).sum(-1).long().tolist()
        pad_seq = pack_padded_sequence(
            inp, lens, batch_first=True, enforce_sorted=False
        )
        y, (_h, _c) = self.bilstm(pad_seq)
        if self.enc_type == "title":
            y = pad_packed_sequence(y, batch_first=True)[0]
            return y
        if self.enc_type == "entity":
            _h = _h.transpose(0, 1).contiguous()
            _h = _h[:, -2:].view(
                _h.size(0), -1
            )  # two directions of the top-layer
            ret = pad(_h.split(ent_len), out_type="tensor")
            return ret
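

# Usage sketch (illustrative only): runs the BiLSTM as a title encoder on padded input.
# The `args` fields (nhid, emb_drop, enc_lstm_layers) match the constructor above; the
# values and tensor sizes are assumptions made just for this example.
def _bilstm_usage_example():
    from types import SimpleNamespace

    args = SimpleNamespace(nhid=8, emb_drop=0.1, enc_lstm_layers=2)
    enc = BiLSTM(args, enc_type="title")
    inp = torch.randn(2, 6, args.nhid)          # (batch, max_len, nhid)
    mask = torch.tensor([[0, 0, 0, 0, 1, 1],    # 0 marks real tokens, 1 marks padding
                         [0, 0, 0, 1, 1, 1]])
    out = enc(inp, mask)                        # (batch, longest_unpadded_len, nhid)
    return out.shape                            # torch.Size([2, 4, 8])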


class GAT(nn.Module):
    # a graph attention network with dot-product attention
    def __init__(
        self,
        in_feats,
        out_feats,
        num_heads,
        ffn_drop=0.0,
        attn_drop=0.0,
        trans=True,
    ):
        super(GAT, self).__init__()
        self._num_heads = num_heads
        self._in_feats = in_feats
        self._out_feats = out_feats
        self.q_proj = nn.Linear(in_feats, num_heads * out_feats, bias=False)
        self.k_proj = nn.Linear(in_feats, num_heads * out_feats, bias=False)
        self.v_proj = nn.Linear(in_feats, num_heads * out_feats, bias=False)
        # note: the attn_drop / ffn_drop constructor arguments are not used below;
        # the dropout probabilities are hard-coded to 0.1
        self.attn_drop = nn.Dropout(0.1)
        self.ln1 = nn.LayerNorm(in_feats)
        self.ln2 = nn.LayerNorm(in_feats)
        if trans:
            self.FFN = nn.Sequential(
                nn.Linear(in_feats, 4 * in_feats),
                nn.PReLU(4 * in_feats),
                nn.Linear(4 * in_feats, in_feats),
                nn.Dropout(0.1),
            )
            # unusual FFN layout (dropout only after the output projection), kept as in the author's code
        self._trans = trans

    def forward(self, graph, feat):
        graph = graph.local_var()
        # keys and values are computed from a detached copy of the node features,
        # so gradients flow back only through the queries
        feat_c = feat.clone().detach().requires_grad_(False)
        q, k, v = self.q_proj(feat), self.k_proj(feat_c), self.v_proj(feat_c)
        q = q.view(-1, self._num_heads, self._out_feats)
        k = k.view(-1, self._num_heads, self._out_feats)
        v = v.view(-1, self._num_heads, self._out_feats)
        # "el" holds the keys (read from the source side) and "er" the queries (read from
        # the destination side): each node attends over its in-neighbours, and edge_softmax
        # below normalizes over incoming edges
        graph.ndata.update({"ft": v, "el": k, "er": q})
        # compute edge attention
        graph.apply_edges(fn.u_dot_v("el", "er", "e"))
        e = graph.edata.pop("e") / math.sqrt(self._out_feats * self._num_heads)
        graph.edata["a"] = edge_softmax(graph, e)
        # message passing
        graph.update_all(fn.u_mul_e("ft", "a", "m"), fn.sum("m", "ft2"))
        rst = graph.ndata["ft2"]
        # residual
        rst = rst.view(feat.shape) + feat
        if self._trans:
            rst = self.ln1(rst)
            rst = self.ln1(rst + self.FFN(rst))
            # the same LayerNorm (ln1) is reused for both sub-layers and self.ln2 is left unused, as in the author's code
        return rst
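

# Usage sketch (illustrative only): applies one GAT layer to a small DGL graph. The
# graph, feature sizes, and edge list are assumptions for demonstration; note that the
# residual connection requires num_heads * out_feats == in_feats.
def _gat_usage_example():
    import dgl

    g = dgl.graph(([0, 1, 2, 3], [1, 2, 3, 0]))     # 4 nodes on a directed cycle
    layer = GAT(in_feats=8, out_feats=2, num_heads=4, trans=True)
    feat = torch.randn(4, 8)                        # one feature row per node
    out = layer(g, feat)                            # (num_nodes, in_feats)
    return out.shape                                # torch.Size([4, 8])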


class GraphTrans(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.args = args
        if args.graph_enc == "gat":
            # untested branch: only the "gtrans" path below is actually supported
            self.gat = nn.ModuleList(
                [
                    GAT(
                        args.nhid,
                        args.nhid // 4,
                        4,
                        attn_drop=args.attn_drop,
                        trans=False,
                    )
                    for _ in range(args.prop)
                ]
            )  # untested
        else:
            self.gat = nn.ModuleList(
                [
                    GAT(
                        args.nhid,
                        args.nhid // 4,
                        4,
                        attn_drop=args.attn_drop,
                        ffn_drop=args.drop,
                        trans=True,
                    )
                    for _ in range(args.prop)
                ]
            )
        self.prop = args.prop

    def forward(self, ent, ent_mask, ent_len, rel, rel_mask, graphs):
        device = ent.device
        graphs = graphs.to(device)
        ent_mask = ent_mask == 0  # invert the masks: True now marks valid (non-padded) positions
        rel_mask = rel_mask == 0
        # per sample, gather the valid entity and relation states as the initial node
        # features; the concatenation order must line up with the node order of the batched graph
        init_h = []
        for i in range(graphs.batch_size):
            init_h.append(ent[i][ent_mask[i]])
            init_h.append(rel[i][rel_mask[i]])
        init_h = torch.cat(init_h, 0)
        feats = init_h
        # propagate node states through the stacked graph-attention layers
        for i in range(self.prop):
            feats = self.gat[i](graphs, feats)
        # read out the per-graph root node and the (padded) per-sample entity node states
        g_root = feats.index_select(
            0,
            graphs.filter_nodes(
                lambda x: x.data["type"] == NODE_TYPE["root"]
            ).to(device),
        )
        g_ent = pad(
            feats.index_select(
                0,
                graphs.filter_nodes(
                    lambda x: x.data["type"] == NODE_TYPE["entity"]
                ).to(device),
            ).split(ent_len),
            out_type="tensor",
        )
        return g_ent, g_root
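

# Usage sketch (illustrative only): wires GraphTrans together on a toy batch with one
# sample containing two entity nodes and a single relation node that doubles as the
# root. The `args` fields mirror the constructor above; all sizes, the edge list, and
# the placement of NODE_TYPE["entity"] / NODE_TYPE["root"] labels are assumptions made
# so the example is self-contained, not a description of the real data pipeline.
def _graphtrans_usage_example():
    import dgl
    from types import SimpleNamespace

    args = SimpleNamespace(nhid=8, graph_enc="gtrans", attn_drop=0.1, drop=0.1, prop=2)
    model = GraphTrans(args)

    ent = torch.randn(1, 2, args.nhid)   # (batch, max_entities, nhid)
    rel = torch.randn(1, 1, args.nhid)   # (batch, max_relations, nhid)
    ent_mask = torch.zeros(1, 2)         # 0 marks valid positions
    rel_mask = torch.zeros(1, 1)
    ent_len = [2]                        # number of entities per sample

    # node order must match forward(): entities first, then relation/root nodes
    g = dgl.graph(([0, 1, 2, 2], [2, 2, 0, 1]), num_nodes=3)
    g.ndata["type"] = torch.tensor(
        [NODE_TYPE["entity"], NODE_TYPE["entity"], NODE_TYPE["root"]]
    )
    graphs = dgl.batch([g])

    g_ent, g_root = model(ent, ent_mask, ent_len, rel, rel_mask, graphs)
    return g_ent.shape, g_root.shape     # padded per-sample entity states and per-graph root state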