"vscode:/vscode.git/clone" did not exist on "8bee20f80b47f1b79eb87e3d53b117d84d4ff948"
model.py 2.52 KB
Newer Older
Mufei Li's avatar
Mufei Li committed
1
from dgl import DGLGraph
import torch as th
import torch.nn as nn
import torch.nn.functional as F
import dgl

from dgl.nn.pytorch import RelGraphConv

class RGCN(nn.Module):
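    """Two-layer R-GCN (Schlichtkrull et al., "Modeling Relational Data with
    Graph Convolutional Networks") built from ``RelGraphConv`` layers.

    Parameters
    ----------
    in_dim : int
        Input feature size; when ``link_pred`` is True this is the number of
        nodes, since node IDs are mapped to learned embeddings.
    h_dim : int
        Hidden feature size.
    out_dim : int
        Output feature size.
    num_rels : int
        Number of relation types.
    regularizer : str
        Weight regularizer passed to ``RelGraphConv`` ("basis" or "bdd").
    num_bases : int
        Number of bases (or blocks) used by the regularizer.
    dropout : float
        Dropout rate; for entity classification (``link_pred=False``) it is
        not applied to the output layer.
    self_loop : bool
        Whether each ``RelGraphConv`` layer adds a self-loop connection.
    link_pred : bool
        Configure the model for link prediction instead of entity
        classification.
    """
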
    def __init__(self, in_dim, h_dim, out_dim, num_rels,
                 regularizer="basis", num_bases=-1, dropout=0.,
                 self_loop=False, link_pred=False):
        super(RGCN, self).__init__()

        self.layers = nn.ModuleList()
        # For link prediction, node IDs are embedded into h_dim-dimensional
        # vectors first; in_dim is then interpreted as the number of nodes.
        if link_pred:
            self.emb = nn.Embedding(in_dim, h_dim)
            in_dim = h_dim
        else:
            self.emb = None
        self.layers.append(RelGraphConv(in_dim, h_dim, num_rels, regularizer,
                                        num_bases, activation=F.relu, self_loop=self_loop,
                                        dropout=dropout))

        # For entity classification, dropout should not be applied to the output layer
        if not link_pred:
            dropout = 0.
        self.layers.append(RelGraphConv(h_dim, out_dim, num_rels, regularizer,
                                        num_bases, self_loop=self_loop, dropout=dropout))

    def forward(self, g, h):
        # Full-graph training passes a single DGLGraph reused by every layer;
        # minibatch training passes a list of blocks, one per layer.
        if isinstance(g, DGLGraph):
            blocks = [g] * len(self.layers)
        else:
            blocks = g

        # For link prediction, h carries raw node IDs; map them to embeddings.
        if self.emb is not None:
            h = self.emb(h.squeeze())

        # Each layer reads the per-edge relation types and normalizers stored
        # in the (block) graph's edge data.
        for layer, block in zip(self.layers, blocks):
            h = layer(block, h, block.edata[dgl.ETYPE], block.edata['norm'])
        return h
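

# A minimal usage sketch for entity classification on a toy graph
# (illustrative only; the graph, feature size, and relation count below are
# assumptions, not part of this module):
#
#   g = dgl.graph(([0, 1, 2], [1, 2, 0]))
#   g.edata[dgl.ETYPE] = th.tensor([0, 1, 1])        # per-edge relation type
#   g.edata['norm'] = th.ones(g.num_edges(), 1)      # per-edge normalizer
#   model = RGCN(in_dim=10, h_dim=16, out_dim=4, num_rels=2)
#   logits = model(g, th.randn(g.num_nodes(), 10))   # (num_nodes, out_dim)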

def initializer(emb):
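    """Initialize ``emb`` in place with values drawn uniformly from [-1.0, 1.0]."""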
    emb.uniform_(-1.0, 1.0)
    return emb

class RelGraphEmbedLayer(nn.Module):
    """Embedding layer for featureless heterograph.

    Parameters
    ----------
    out_dev
        Device on which the output embeddings are stored.
    num_nodes : int
        Number of nodes in the graph.
    embed_size : int
        Output embedding size.
    """
    def __init__(self,
                 out_dev,
                 num_nodes,
                 embed_size):
        super(RelGraphEmbedLayer, self).__init__()
        self.out_dev = out_dev
        self.embed_size = embed_size

        # create embeddings for all nodes
        self.node_embed = nn.Embedding(num_nodes, embed_size, sparse=True)
        nn.init.uniform_(self.node_embed.weight, -1.0, 1.0)

    def forward(self, node_ids):
        """Forward computation

        Parameters
        ----------
        node_ids : tensor
            Raw node IDs.

        Returns
        -------
        tensor
            Embeddings to be used as input to the next layer.
        """
        embeds = self.node_embed(node_ids).to(self.out_dev)

        return embeds
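

# Sketch of how RelGraphEmbedLayer can feed a featureless graph into RGCN
# (illustrative only; the sizes, device, and the graph ``g`` below are
# assumptions):
#
#   embed_layer = RelGraphEmbedLayer(out_dev='cpu', num_nodes=g.num_nodes(),
#                                    embed_size=16)
#   feats = embed_layer(g.nodes())                   # (num_nodes, 16)
#   model = RGCN(in_dim=16, h_dim=16, out_dim=4, num_rels=2)
#   logits = model(g, feats)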