import dgl
import dgl.function as fn
import torch as th
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from utils import ccorr
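

# Hedged reference sketch (an assumption, not the repo's utils.ccorr): the
# "ccorr" composition in CompGCN is circular correlation as in HolE, computable
# in the Fourier domain as irfft(conj(rfft(a)) * rfft(b)).
def _ccorr_sketch(a, b):
    # a, b: real tensors of shape (..., d); correlate along the last dim
    return th.fft.irfft(th.conj(th.fft.rfft(a)) * th.fft.rfft(b), n=a.shape[-1])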


class CompGraphConv(nn.Module):
    """One layer of CompGCN."""

    def __init__(
        self, in_dim, out_dim, comp_fn="sub", batchnorm=True, dropout=0.1
    ):
        super(CompGraphConv, self).__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.comp_fn = comp_fn
        self.activation = th.tanh
        self.batchnorm = batchnorm

        # define dropout layer
        self.dropout = nn.Dropout(dropout)

        # define batch norm layer
        if self.batchnorm:
            self.bn = nn.BatchNorm1d(out_dim)

        # define in/out/loop transform layer
        self.W_O = nn.Linear(self.in_dim, self.out_dim)
        self.W_I = nn.Linear(self.in_dim, self.out_dim)
        self.W_S = nn.Linear(self.in_dim, self.out_dim)

        # define relation transform layer
        self.W_R = nn.Linear(self.in_dim, self.out_dim)

        # self loop embedding
        self.loop_rel = nn.Parameter(th.Tensor(1, self.in_dim))
        nn.init.xavier_normal_(self.loop_rel)

    def forward(self, g, n_in_feats, r_feats):
        with g.local_scope():
            # Assign features to source nodes. In a homogeneous graph, this is
            # equivalent to assigning them to all nodes.
            g.srcdata["h"] = n_in_feats
            # append loop_rel embedding to r_feats
            r_feats = th.cat((r_feats, self.loop_rel), 0)
            # Assign features to all edges with the corresponding relation embeddings
            g.edata["h"] = r_feats[g.edata["etype"]] * g.edata["norm"]

            # Compute the composition function in 4 steps
            # Step 1: compose source-node features with edge (relation)
            # features along each edge, and store the result on the edges.
            if self.comp_fn == "sub":
                g.apply_edges(fn.u_sub_e("h", "h", out="comp_h"))
            elif self.comp_fn == "mul":
                g.apply_edges(fn.u_mul_e("h", "h", out="comp_h"))
            elif self.comp_fn == "ccorr":
                g.apply_edges(
                    lambda edges: {
                        "comp_h": ccorr(edges.src["h"], edges.data["h"])
                    }
                )
            else:
                raise ValueError("Only supports sub, mul, and ccorr")

            # Step 2: split the edges by direction using the precomputed masks
            # and apply the direction-specific transforms W_I / W_O
            comp_h = g.edata["comp_h"]

            in_edges_idx = th.nonzero(
                g.edata["in_edges_mask"], as_tuple=False
            ).squeeze()
            out_edges_idx = th.nonzero(
                g.edata["out_edges_mask"], as_tuple=False
            ).squeeze()

            comp_h_O = self.W_O(comp_h[out_edges_idx])
            comp_h_I = self.W_I(comp_h[in_edges_idx])

            # Scatter the direction-specific results back into one per-edge tensor
            new_comp_h = th.zeros(comp_h.shape[0], self.out_dim).to(
                comp_h.device
            )
            new_comp_h[out_edges_idx] = comp_h_O
            new_comp_h[in_edges_idx] = comp_h_I

            g.edata["new_comp_h"] = new_comp_h

            # Step 3: aggregate the transformed edge features at their
            # destination nodes (the graph holds both edge directions, so each
            # node receives both in- and out-edge messages)
            g.update_all(fn.copy_e("new_comp_h", "m"), fn.sum("m", "comp_edge"))

            # Step 4: add results of self-loop
            if self.comp_fn == "sub":
                comp_h_s = n_in_feats - r_feats[-1]
            elif self.comp_fn == "mul":
                comp_h_s = n_in_feats * r_feats[-1]
            elif self.comp_fn == "ccorr":
                comp_h_s = ccorr(n_in_feats, r_feats[-1])
            else:
                raise ValueError("Only supports sub, mul, and ccorr")

            # Combine the self-loop result with the aggregated edge messages;
            # the 1/3 factor averages the three directions (in, out, loop), and
            # dropout is applied to the aggregated messages
            n_out_feats = (
                self.W_S(comp_h_s) + self.dropout(g.ndata["comp_edge"])
            ) * (1 / 3)

            # Compute relation output
            r_out_feats = self.W_R(r_feats)

            # Batch norm
            if self.batchnorm:
                n_out_feats = self.bn(n_out_feats)

            # Activation function
            if self.activation is not None:
                n_out_feats = self.activation(n_out_feats)

        # Strip the appended self-loop relation before returning
        return n_out_feats, r_out_feats[:-1]
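

# Hedged usage sketch (toy sizes assumed, not part of the original file): build
# a tiny graph carrying the edge fields CompGraphConv.forward expects
# ("etype", "norm", "in_edges_mask", "out_edges_mask") and run one layer.
def _toy_compgcn_graph():
    g = dgl.graph(([0, 1], [1, 2]), num_nodes=3)
    g.edata["etype"] = th.tensor([0, 1])                  # relation id per edge
    g.edata["norm"] = th.ones(2, 1)                       # edge normalization
    g.edata["in_edges_mask"] = th.tensor([True, False])   # assumed: reversed edges
    g.edata["out_edges_mask"] = th.tensor([False, True])  # assumed: original edges
    return g


def _demo_comp_graph_conv():
    layer = CompGraphConv(in_dim=4, out_dim=8, comp_fn="sub")
    n_feats = th.randn(3, 4)  # one row per node
    r_feats = th.randn(2, 4)  # one row per relation (self-loop appended inside)
    n_out, r_out = layer(_toy_compgcn_graph(), n_feats, r_feats)
    print(n_out.shape, r_out.shape)  # torch.Size([3, 8]) torch.Size([2, 8])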


class CompGCN(nn.Module):
    """CompGCN: stacked CompGraphConv layers over learned entity and relation embeddings."""

    def __init__(
        self,
        num_bases,
        num_rel,
        num_ent,
        in_dim=100,
        layer_size=[200],
        comp_fn="sub",
        batchnorm=True,
        dropout=0.1,
        layer_dropout=[0.3],
    ):
        super(CompGCN, self).__init__()

        self.num_bases = num_bases
        self.num_rel = num_rel
        self.num_ent = num_ent
        self.in_dim = in_dim
        self.layer_size = layer_size
        self.comp_fn = comp_fn
        self.batchnorm = batchnorm
        self.dropout = dropout
        self.layer_dropout = layer_dropout
        self.num_layer = len(layer_size)

        # CompGCN layers
        self.layers = nn.ModuleList()
        self.layers.append(
            CompGraphConv(
                self.in_dim,
                self.layer_size[0],
                comp_fn=self.comp_fn,
                batchnorm=self.batchnorm,
                dropout=self.dropout,
            )
        )
        for i in range(self.num_layer - 1):
            self.layers.append(
                CompGraphConv(
                    self.layer_size[i],
                    self.layer_size[i + 1],
                    comp_fn=self.comp_fn,
                    batchnorm=self.batchnorm,
                    dropout=self.dropout,
                )
            )

        # Initial relation embeddings: either composed from num_bases shared
        # basis vectors (r = weights @ basis) or learned directly per relation
        if self.num_bases > 0:
            self.basis = nn.Parameter(th.Tensor(self.num_bases, self.in_dim))
            self.weights = nn.Parameter(th.Tensor(self.num_rel, self.num_bases))
            nn.init.xavier_normal_(self.basis)
            nn.init.xavier_normal_(self.weights)
        else:
            self.rel_embds = nn.Parameter(th.Tensor(self.num_rel, self.in_dim))
            nn.init.xavier_normal_(self.rel_embds)

        # Node embeddings
        self.n_embds = nn.Parameter(th.Tensor(self.num_ent, self.in_dim))
        nn.init.xavier_normal_(self.n_embds)

        # Dropout after compGCN layers
        self.dropouts = nn.ModuleList()
        for i in range(self.num_layer):
            self.dropouts.append(nn.Dropout(self.layer_dropout[i]))

    def forward(self, graph):
        # node and relation features
        n_feats = self.n_embds
        if self.num_bases > 0:
            # compose each relation embedding from the shared bases
            r_feats = th.mm(self.weights, self.basis)
        else:
            r_feats = self.rel_embds

        for layer, dropout in zip(self.layers, self.dropouts):
            n_feats, r_feats = layer(graph, n_feats, r_feats)
            n_feats = dropout(n_feats)

        return n_feats, r_feats
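

# Hedged usage sketch (assumed toy sizes, reusing the toy graph helper above):
# two stacked layers with basis-decomposed relation embeddings; with
# num_bases=2, each of the num_rel=2 relation vectors is a learned mixture of
# 2 shared basis vectors (r = weights @ basis).
def _demo_comp_gcn():
    model = CompGCN(num_bases=2, num_rel=2, num_ent=3, in_dim=4,
                    layer_size=[8, 8], layer_dropout=[0.3, 0.3])
    n_feats, r_feats = model(_toy_compgcn_graph())
    print(n_feats.shape, r_feats.shape)  # torch.Size([3, 8]) torch.Size([2, 8])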


# Use ConvE as the score function
class CompGCN_ConvE(nn.Module):
    """CompGCN encoder with a ConvE decoder for link prediction."""

    def __init__(
        self,
        num_bases,
        num_rel,
        num_ent,
        in_dim,
        layer_size,
        comp_fn="sub",
        batchnorm=True,
        dropout=0.1,
        layer_dropout=[0.3],
        num_filt=200,
        hid_drop=0.3,
        feat_drop=0.3,
        ker_sz=5,
        k_w=5,
        k_h=5,
    ):
        super(CompGCN_ConvE, self).__init__()

        self.embed_dim = layer_size[-1]
        self.hid_drop = hid_drop
        self.feat_drop = feat_drop
        self.ker_sz = ker_sz
        self.k_w = k_w
        self.k_h = k_h
        self.num_filt = num_filt

        # compGCN model to get sub/rel embs
        self.compGCN_Model = CompGCN(
            num_bases,
            num_rel,
            num_ent,
            in_dim,
            layer_size,
            comp_fn,
            batchnorm,
            dropout,
            layer_dropout,
        )

        # batch norms: bn0 on the stacked input image, bn1 after the conv,
        # bn2 after the fully connected layer
        self.bn0 = th.nn.BatchNorm2d(1)
        self.bn1 = th.nn.BatchNorm2d(self.num_filt)
        self.bn2 = th.nn.BatchNorm1d(self.embed_dim)

        # dropouts and the conv module applied to the combined (sub+rel) emb
        self.hidden_drop = th.nn.Dropout(self.hid_drop)
        self.feature_drop = th.nn.Dropout(self.feat_drop)
        self.m_conv1 = th.nn.Conv2d(
            in_channels=1,
            out_channels=self.num_filt,
            kernel_size=(self.ker_sz, self.ker_sz),
            stride=1,
            padding=0,
            bias=False,
        )

        # spatial size of the conv output on the (2*k_w) x k_h input image
        flat_sz_h = int(2 * self.k_w) - self.ker_sz + 1
        flat_sz_w = self.k_h - self.ker_sz + 1
        self.flat_sz = flat_sz_h * flat_sz_w * self.num_filt
        self.fc = th.nn.Linear(self.flat_sz, self.embed_dim)

        # bias to the score
        self.bias = nn.Parameter(th.zeros(num_ent))

    # Stack subject and relation embeddings into a single-channel
    # (2*k_w) x k_h "image" for the conv decoder (requires embed_dim == k_w * k_h)
    def concat(self, e1_embed, rel_embed):
        e1_embed = e1_embed.view(-1, 1, self.embed_dim)
        rel_embed = rel_embed.view(-1, 1, self.embed_dim)
        stack_inp = th.cat([e1_embed, rel_embed], 1)
        stack_inp = th.transpose(stack_inp, 2, 1).reshape(
            (-1, 1, 2 * self.k_w, self.k_h)
        )
        return stack_inp

    def forward(self, graph, sub, rel):
        # get sub_emb and rel_emb via compGCN
        n_feats, r_feats = self.compGCN_Model(graph)
        sub_emb = n_feats[sub, :]
        rel_emb = r_feats[rel, :]

        # combine the sub_emb and rel_emb
        stk_inp = self.concat(sub_emb, rel_emb)
        # score the combined embedding with the ConvE decoder
        x = self.bn0(stk_inp)
        x = self.m_conv1(x)
        x = self.bn1(x)
        x = F.relu(x)
        x = self.feature_drop(x)
        x = x.view(-1, self.flat_sz)
        x = self.fc(x)
        x = self.hidden_drop(x)
        x = self.bn2(x)
        x = F.relu(x)
        # compute score
        x = th.mm(x, n_feats.transpose(1, 0))
        # add in bias
        x += self.bias.expand_as(x)
        score = th.sigmoid(x)
        return score
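

# Hedged end-to-end sketch (assumed toy sizes, not the original training setup):
# score two (subject, relation) queries against all entities. Shape contract:
# layer_size[-1] must equal k_w * k_h so concat() can reshape each (sub, rel)
# pair into a single-channel (2*k_w) x k_h image for the convolution.
if __name__ == "__main__":
    model = CompGCN_ConvE(num_bases=0, num_rel=2, num_ent=3, in_dim=4,
                          layer_size=[25], num_filt=4, ker_sz=5, k_w=5, k_h=5)
    model.eval()  # use running BN statistics; disable dropout for the sketch
    sub = th.tensor([0, 2])  # subject entity ids
    rel = th.tensor([1, 0])  # relation ids
    score = model(_toy_compgcn_graph(), sub, rel)
    print(score.shape)  # torch.Size([2, 3]): a probability per candidate object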