model.py 10.2 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
from typing import Callable, Dict, List, Optional, Union

import dgl
import dgl.nn.pytorch as dglnn
import torch
import torch.nn as nn


class RelGraphEmbedding(nn.Module):
    """Build the per-node-type input representations for a heterograph.

    For every node type in ``hg.ntypes``:

    * if ``node_feats[ntype]`` is ``None``, a sparse ``nn.Embedding`` with
      ``num_nodes[ntype]`` rows of width ``embedding_size`` is created and
      initialised uniformly in ``[-1, 1]``;
    * if features are given and ``node_feats_projection`` is true, a
      Xavier-initialised ``(feature_width, embedding_size)`` matrix is
      created and the features are multiplied by it in ``forward``;
    * otherwise ``forward`` returns the given features unchanged.

    Args:
        hg: Graph object; only ``hg.ntypes`` and ``hg.nodes(ntype)`` are
            used by this module.
        embedding_size: Width of learned embeddings and of projected
            features.
        num_nodes: Number of nodes per node type (read only for node types
            without features).
        node_feats: Feature tensor per node type, or ``None`` for node
            types that should get a learned embedding. Must contain an
            entry for every node type in ``hg.ntypes``.
        node_feats_projection: Whether to project given features to
            ``embedding_size``.
    """

    def __init__(
        self,
        hg: 'dgl.DGLHeteroGraph',
        embedding_size: int,
        num_nodes: Dict[str, int],
        node_feats: Dict[str, Optional[torch.Tensor]],
        node_feats_projection: bool = False,
    ):
        super().__init__()
        self._hg = hg
        # Shallow copy: forward() re-binds entries when it moves features to
        # a device, and that must not leak into the caller's dictionary.
        self._node_feats = dict(node_feats)
        self._node_feats_projection = node_feats_projection
        self.node_embeddings = nn.ModuleDict()

        if node_feats_projection:
            self.embeddings = nn.ParameterDict()

        for ntype in hg.ntypes:
            feats = self._node_feats[ntype]

            if feats is None:
                node_embedding = nn.Embedding(
                    num_nodes[ntype], embedding_size, sparse=True)
                nn.init.uniform_(node_embedding.weight, -1, 1)

                self.node_embeddings[ntype] = node_embedding
            elif node_feats_projection:
                projection = nn.Parameter(
                    torch.empty(feats.shape[-1], embedding_size))
                nn.init.xavier_uniform_(projection)

                self.embeddings[ntype] = projection

    def forward(
        self,
        in_nodes: Optional[Dict[str, torch.Tensor]] = None,
        device: Optional[torch.device] = None,
    ) -> Dict[str, torch.Tensor]:
        """Return the input representation of the requested nodes.

        Args:
            in_nodes: Node IDs per node type. When ``None``, all nodes of
                every node type of the graph are used.
            device: If given, feature tensors are moved to this device
                before being indexed; the moved tensor is kept for later
                calls.

        Returns:
            A dictionary with one tensor per requested node type.
        """
        if in_nodes is None:
            in_nodes = {
                ntype: self._hg.nodes(ntype) for ntype in self._hg.ntypes}

        x = {}

        for ntype, nid in in_nodes.items():
            feats = self._node_feats[ntype]

            if feats is None:
                x[ntype] = self.node_embeddings[ntype](nid)
                continue

            if device is not None:
                # Keep the moved tensor so the transfer happens only once.
                feats = feats.to(device)
                self._node_feats[ntype] = feats

            selected = feats[nid]

            if self._node_feats_projection:
                selected = selected @ self.embeddings[ntype]

            x[ntype] = selected

        return x


class RelGraphConvLayer(nn.Module):
    """One relational graph convolution layer over a heterograph.

    A parameter-free ``dglnn.GraphConv`` is created per relation name and
    wrapped in a ``dglnn.HeteroGraphConv``. The per-relation weight matrices
    are owned by this layer (either as one ``(num_rels, in_feats,
    out_feats)`` parameter or as a ``dglnn.WeightBasis``) and are handed to
    the convolutions at call time.

    Args:
        in_feats: Input feature width.
        out_feats: Output feature width.
        rel_names: Relation (edge type) names handled by the layer.
        num_bases: Number of bases for the weight decomposition. The
            decomposition is used only when ``0 < num_bases < len(rel_names)``
            and ``weight`` is true.
        norm: Normalisation mode forwarded to ``dglnn.GraphConv``.
        weight: Whether to apply per-relation weight matrices.
        bias: Whether to add a learned bias to the output.
        activation: Optional callable applied to the output.
        dropout: Optional dropout probability applied to the output.
        self_loop: Whether to add a learned self-loop term computed from the
            destination nodes' own input features.
    """

    def __init__(
        self,
        in_feats: int,
        out_feats: int,
        rel_names: List[str],
        num_bases: int,
        norm: str = 'right',
        weight: bool = True,
        bias: bool = True,
        activation: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
        dropout: Optional[float] = None,
        self_loop: bool = False,
    ):
        super().__init__()
        self._rel_names = rel_names
        self._num_rels = len(rel_names)
        # The wrapped convolutions hold no parameters of their own; weights
        # are passed in through ``mod_kwargs`` in forward().
        self._conv = dglnn.HeteroGraphConv({
            rel: dglnn.GraphConv(
                in_feats, out_feats, norm=norm, weight=False, bias=False)
            for rel in rel_names
        })
        self._use_weight = weight
        # A non-positive number of bases cannot form a decomposition, so it
        # falls back to one full weight matrix per relation.
        self._use_basis = weight and 0 < num_bases < self._num_rels
        self._use_bias = bias
        self._activation = activation
        self._dropout = nn.Dropout(dropout) if dropout is not None else None
        self._use_self_loop = self_loop

        if weight:
            if self._use_basis:
                self.basis = dglnn.WeightBasis(
                    (in_feats, out_feats), num_bases, self._num_rels)
            else:
                self.weight = nn.Parameter(
                    torch.empty(self._num_rels, in_feats, out_feats))
                nn.init.xavier_uniform_(
                    self.weight, gain=nn.init.calculate_gain('relu'))

        if bias:
            self.bias = nn.Parameter(torch.zeros(out_feats))

        if self_loop:
            self.self_loop_weight = nn.Parameter(
                torch.empty(in_feats, out_feats))
            nn.init.xavier_uniform_(
                self.self_loop_weight, gain=nn.init.calculate_gain('relu'))

    def _apply_layers(
        self,
        ntype: str,
        inputs: torch.Tensor,
        inputs_dst: Optional[Dict[str, torch.Tensor]] = None,
    ) -> torch.Tensor:
        """Post-process the convolution output of one node type.

        Adds the self-loop term (when ``inputs_dst`` is given) and the bias,
        then applies the activation and dropout if configured. ``inputs`` is
        not modified.

        Args:
            ntype: Node type the tensor belongs to; used to look up the
                matching entry of ``inputs_dst``.
            inputs: Convolution output for ``ntype``.
            inputs_dst: Input features of the destination nodes per node
                type, or ``None`` to skip the self-loop term.

        Returns:
            The post-processed tensor.
        """
        x = inputs

        # Out-of-place additions so the caller's tensor is left untouched.
        if inputs_dst is not None:
            x = x + torch.matmul(inputs_dst[ntype], self.self_loop_weight)

        if self._use_bias:
            x = x + self.bias

        if self._activation is not None:
            x = self._activation(x)

        if self._dropout is not None:
            x = self._dropout(x)

        return x

    def forward(
        self,
        hg: 'dgl.DGLHeteroGraph',
        inputs: Dict[str, torch.Tensor],
    ) -> Dict[str, torch.Tensor]:
        """Run the layer on a graph or block.

        Args:
            hg: Heterograph or block to convolve over.
            inputs: Input features per node type.

        Returns:
            Output features per node type, as produced by the wrapped
            ``HeteroGraphConv`` and post-processed by ``_apply_layers``.
        """
        hg = hg.local_var()

        if self._use_weight:
            weight = self.basis() if self._use_basis else self.weight
            # One (in_feats, out_feats) matrix per relation, in the order of
            # ``rel_names``.
            weight_dict = {
                rel: {'weight': w}
                for rel, w in zip(self._rel_names, weight.unbind(dim=0))
            }
        else:
            weight_dict = {}

        if not self._use_self_loop:
            inputs_dst = None
        elif hg.is_block:
            # The leading rows of each input tensor are taken as the
            # destination nodes' features.
            inputs_dst = {
                ntype: h[:hg.num_dst_nodes(ntype)]
                for ntype, h in inputs.items()
            }
        else:
            inputs_dst = inputs

        x = self._conv(hg, inputs, mod_kwargs=weight_dict)

        return {
            ntype: self._apply_layers(ntype, h, inputs_dst)
            for ntype, h in x.items()
        }


class EntityClassify(nn.Module):
    """Stack of ``RelGraphConvLayer`` modules for node-level prediction.

    The stack is ``in_feats -> hidden_feats`` followed by
    ``num_layers - 2`` hidden-to-hidden layers and a final
    ``hidden_feats -> out_feats`` layer. After every layer except the last
    one, optional layer normalisation, the optional activation and dropout
    are applied.

    Args:
        hg: Heterograph whose edge type names define the relations.
        in_feats: Input feature width.
        hidden_feats: Hidden feature width.
        out_feats: Output feature width.
        num_bases: Number of bases for the per-relation weight
            decomposition. Values ``<= 0`` or larger than the number of
            relations are replaced by the number of relations.
        num_layers: Number of layers. NOTE(review): the constructor always
            creates at least an input and an output layer, so values below
            2 look unsupported; confirm with callers.
        norm: Normalisation mode forwarded to every layer.
        layer_norm: Whether to apply ``nn.LayerNorm`` between layers.
        input_dropout: Dropout probability applied to the input features.
        dropout: Dropout probability applied between layers.
        activation: Optional callable applied between layers.
        self_loop: Whether every layer adds a self-loop term.
    """

    def __init__(
        self,
        hg: 'dgl.DGLHeteroGraph',
        in_feats: int,
        hidden_feats: int,
        out_feats: int,
        num_bases: int,
        num_layers: int,
        norm: str = 'right',
        layer_norm: bool = False,
        input_dropout: float = 0,
        dropout: float = 0,
        activation: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
        self_loop: bool = False,
    ):
        super().__init__()
        self._hidden_feats = hidden_feats
        self._out_feats = out_feats
        self._num_layers = num_layers
        self._input_dropout = nn.Dropout(input_dropout)
        self._dropout = nn.Dropout(dropout)
        self._activation = activation
        self._rel_names = sorted(set(hg.etypes))
        self._num_rels = len(self._rel_names)

        # A non-positive or too large number of bases means "one weight
        # matrix per relation".
        if num_bases <= 0 or num_bases > self._num_rels:
            self._num_bases = self._num_rels
        else:
            self._num_bases = num_bases

        def make_layer(layer_in: int, layer_out: int) -> RelGraphConvLayer:
            # Activation and dropout are applied by this class between
            # layers, so they are not passed down to the layer itself.
            return RelGraphConvLayer(
                layer_in,
                layer_out,
                self._rel_names,
                self._num_bases,
                norm=norm,
                self_loop=self_loop,
            )

        self._layers = nn.ModuleList()
        self._layers.append(make_layer(in_feats, hidden_feats))

        for _ in range(1, num_layers - 1):
            self._layers.append(make_layer(hidden_feats, hidden_feats))

        self._layers.append(make_layer(hidden_feats, out_feats))

        if layer_norm:
            self._layer_norms = nn.ModuleList(
                nn.LayerNorm(hidden_feats) for _ in range(num_layers - 1))
        else:
            self._layer_norms = None

    def _apply_layers(
        self,
        layer_idx: int,
        inputs: Dict[str, torch.Tensor],
    ) -> Dict[str, torch.Tensor]:
        """Apply layer norm, activation and dropout to a layer's output.

        Args:
            layer_idx: Index of the layer that produced ``inputs``; selects
                the matching ``nn.LayerNorm`` when layer norm is enabled.
            inputs: Layer output per node type. Not modified.

        Returns:
            A new dictionary with the processed tensor per node type.
        """
        outputs = {}

        for ntype, h in inputs.items():
            if self._layer_norms is not None:
                h = self._layer_norms[layer_idx](h)

            if self._activation is not None:
                h = self._activation(h)

            outputs[ntype] = self._dropout(h)

        return outputs

    def forward(
        self,
        hg: Union['dgl.DGLHeteroGraph', List['dgl.DGLHeteroGraph']],
        inputs: Dict[str, torch.Tensor],
    ) -> Dict[str, torch.Tensor]:
        """Run all layers.

        Args:
            hg: Either a single graph used by every layer, or a list with
                one block per layer.
            inputs: Input features per node type.

        Returns:
            Output of the last layer per node type.
        """
        x = {ntype: self._input_dropout(h) for ntype, h in inputs.items()}

        # A single graph is reused for every layer; a list pairs each layer
        # with its own block.
        graphs = hg if isinstance(hg, list) else [hg] * len(self._layers)

        for i, (layer, graph) in enumerate(zip(self._layers, graphs)):
            x = layer(graph, x)

            if i < self._num_layers - 1:
                x = self._apply_layers(i, x)

        return x

    def inference(
        self,
        hg: 'dgl.DGLHeteroGraph',
        batch_size: int,
        num_workers: int,
        embedding_layer: nn.Module,
        device: torch.device,
    ) -> Dict[str, torch.Tensor]:
        """Compute the output for all nodes, one layer at a time.

        For every layer, all nodes of the graph are visited in mini-batches
        with a one-layer full-neighbour sampler, and the layer output is
        written into a tensor covering all nodes that serves as input to the
        next layer.

        This method does not disable gradient tracking or switch the module
        to evaluation mode; that is left to the caller.

        Args:
            hg: Graph to run on.
            batch_size: Mini-batch size of the dataloader.
            num_workers: Number of dataloader workers.
            embedding_layer: Module called as
                ``embedding_layer(in_nodes=..., device=...)`` to produce the
                input of the first layer.
            device: Device the computation and the results live on.

        Returns:
            Output of the last layer for every node, per node type.
        """
        # The sampler and dataloader do not depend on the layer, so they are
        # built once and iterated once per layer.
        sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1)
        dataloader = dgl.dataloading.DataLoader(
            hg,
            {ntype: hg.nodes(ntype) for ntype in hg.ntypes},
            sampler,
            batch_size=batch_size,
            shuffle=False,
            drop_last=False,
            num_workers=num_workers,
        )

        x: Dict[str, torch.Tensor] = {}

        for i, layer in enumerate(self._layers):
            is_hidden = i < self._num_layers - 1
            width = self._hidden_feats if is_hidden else self._out_feats
            y = {
                ntype: torch.zeros(hg.num_nodes(ntype), width, device=device)
                for ntype in hg.ntypes
            }

            for in_nodes, out_nodes, blocks in dataloader:
                in_nodes = {
                    ntype: nid.to(device) for ntype, nid in in_nodes.items()}
                out_nodes = {
                    ntype: nid.to(device) for ntype, nid in out_nodes.items()}
                block = blocks[0].to(device)

                if i == 0:
                    h = embedding_layer(in_nodes=in_nodes, device=device)
                else:
                    # Gather the previous layer's output for the input nodes.
                    h = {
                        ntype: x[ntype][in_nodes[ntype]]
                        for ntype in hg.ntypes
                    }

                h = layer(block, h)

                if is_hidden:
                    h = self._apply_layers(i, h)

                for ntype, out in h.items():
                    y[ntype][out_nodes[ntype]] = out

            x = y

        return x