"vscode:/vscode.git/clone" did not exist on "f46fad33c6dafc754c6527dffd3ca31f284d603a"
model.py 5.09 KB
Newer Older
1
import torch as th
import torch.nn as nn

import dgl

class BaseRGCN(nn.Module):
    def __init__(self, num_nodes, h_dim, out_dim, num_rels, num_bases,
                 num_hidden_layers=1, dropout=0,
                 use_self_loop=False, use_cuda=False):
        super(BaseRGCN, self).__init__()
        self.num_nodes = num_nodes
        self.h_dim = h_dim
        self.out_dim = out_dim
        self.num_rels = num_rels
        self.num_bases = None if num_bases < 0 else num_bases
        self.num_hidden_layers = num_hidden_layers
        self.dropout = dropout
        self.use_self_loop = use_self_loop
        self.use_cuda = use_cuda

        # create rgcn layers
        self.build_model()

    def build_model(self):
        self.layers = nn.ModuleList()
        # i2h
        i2h = self.build_input_layer()
        if i2h is not None:
            self.layers.append(i2h)
        # h2h
        for idx in range(self.num_hidden_layers):
            h2h = self.build_hidden_layer(idx)
            self.layers.append(h2h)
        # h2o
        h2o = self.build_output_layer()
        if h2o is not None:
            self.layers.append(h2o)

    def build_input_layer(self):
        return None

    def build_hidden_layer(self, idx):
        raise NotImplementedError

    def build_output_layer(self):
        return None

    def forward(self, g, h, r, norm):
        # g: graph; h: node features; r: edge relation types; norm: edge norm
        for layer in self.layers:
            h = layer(g, h, r, norm)
        return h
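
# Example (sketch, not part of the original file): a minimal concrete
# subclass. It assumes dgl.nn.RelGraphConv, whose forward signature
# (g, feat, etypes, norm) matches the (g, h, r, norm) call above; the class
# name and hyper-parameter choices are illustrative only.
#
#   import torch.nn.functional as F
#   from dgl.nn import RelGraphConv
#
#   class EntityRGCN(BaseRGCN):
#       def build_hidden_layer(self, idx):
#           act = F.relu if idx < self.num_hidden_layers - 1 else None
#           return RelGraphConv(self.h_dim, self.h_dim, self.num_rels,
#                               "basis", self.num_bases, activation=act,
#                               self_loop=self.use_self_loop,
#                               dropout=self.dropout)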

def initializer(emb):
    # in-place uniform init; passed as init_func to dgl.nn.NodeEmbedding below
    emb.uniform_(-1.0, 1.0)
    return emb

class RelGraphEmbedLayer(nn.Module):
    r"""Embedding layer for featureless heterograph.
    Parameters
    ----------
    dev_id : int
        Device to run the layer.
    num_nodes : int
        Number of nodes.
    node_tides : tensor
        Storing the node type id for each node starting from 0
    num_of_ntype : int
        Number of node types
    input_size : list of int
70
        A list of input feature size for each node type. If None, we then
71
72
73
        treat certain input feature as an one-hot encoding feature.
    embed_size : int
        Output embed size
74
75
    dgl_sparse : bool, optional
        If true, use dgl.nn.NodeEmbedding otherwise use torch.nn.Embedding
76
77
78
79
80
81
82
83
    """
    def __init__(self,
                 dev_id,
                 num_nodes,
                 node_tids,
                 num_of_ntype,
                 input_size,
                 embed_size,
                 dgl_sparse=False):
        super(RelGraphEmbedLayer, self).__init__()
        self.dev_id = th.device(dev_id if dev_id >= 0 else 'cpu')
        self.embed_size = embed_size
        self.num_nodes = num_nodes
        self.dgl_sparse = dgl_sparse

        # per-node-type parameters: projection weights and node embeddings
        self.embeds = nn.ParameterDict()
        self.node_embeds = {} if dgl_sparse else nn.ModuleDict()
        self.num_of_ntype = num_of_ntype

        for ntype in range(num_of_ntype):
            if isinstance(input_size[ntype], int):
                if dgl_sparse:
                    self.node_embeds[str(ntype)] = dgl.nn.NodeEmbedding(input_size[ntype], embed_size, name=str(ntype),
                        init_func=initializer)
                else:
                    sparse_emb = nn.Embedding(input_size[ntype], embed_size, sparse=True)
                    nn.init.uniform_(sparse_emb.weight, -1.0, 1.0)
                    self.node_embeds[str(ntype)] = sparse_emb
            else:
                input_emb_size = input_size[ntype].shape[1]
                embed = nn.Parameter(th.Tensor(input_emb_size, self.embed_size))
                nn.init.xavier_uniform_(embed)
                self.embeds[str(ntype)] = embed

    @property
    def dgl_emb(self):
        """Return the dgl.nn.NodeEmbedding objects (for DGL's sparse
        optimizers) when dgl_sparse is True, otherwise an empty list.
        """
        if self.dgl_sparse:
            return list(self.node_embeds.values())
        else:
            return []
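
    # Sketch (assumption, not part of the original file): when dgl_sparse is
    # True, the embeddings returned by dgl_emb are meant to be trained with a
    # DGL sparse optimizer, e.g. for a hypothetical instance `layer`:
    #
    #   emb_optimizer = dgl.optim.SparseAdam(params=layer.dgl_emb, lr=0.01)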

    def forward(self, node_ids, node_tids, type_ids, features):
        """Forward computation
        Parameters
        ----------
        node_ids : tensor
            node ids to generate embedding for.
        node_ids : tensor
            node type ids
        features : list of features
            list of initial features for nodes belong to different node type.
            If None, the corresponding features is an one-hot encoding feature,
132
            else use the features directly as input feature and matmul a
133
134
135
136
137
138
            projection matrix.
        Returns
        -------
        tensor
            embeddings as the input of the next layer
        """
        embeds = th.empty(node_ids.shape[0], self.embed_size, device=self.dev_id)
        for ntype in range(self.num_of_ntype):
            loc = node_tids == ntype
            if isinstance(features[ntype], int):
                # featureless node type: look up the learned embedding
                if self.dgl_sparse:
                    embeds[loc] = self.node_embeds[str(ntype)](type_ids[loc], self.dev_id)
                else:
                    embeds[loc] = self.node_embeds[str(ntype)](type_ids[loc]).to(self.dev_id)
            else:
                # node type with initial features: project them to embed_size
                embeds[loc] = features[ntype][type_ids[loc]].to(self.dev_id) @ self.embeds[str(ntype)].to(self.dev_id)

        return embeds
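
# Usage sketch (illustrative only; shapes, names, and values below are made
# up, not from the original file). Type 0 is featureless (its entry is an
# int, so a one-hot embedding is learned); type 1 carries a dense feature
# tensor that is projected to embed_size:
#
#   num_nodes = 10
#   node_tids = th.tensor([0] * 6 + [1] * 4)              # type of each node
#   type_ids = th.tensor([0, 1, 2, 3, 4, 5, 0, 1, 2, 3])  # id within its type
#   feats = [6, th.randn(4, 16)]
#   layer = RelGraphEmbedLayer(dev_id=-1, num_nodes=num_nodes,
#                              node_tids=node_tids, num_of_ntype=2,
#                              input_size=feats, embed_size=8)
#   h = layer(th.arange(num_nodes), node_tids, type_ids, feats)
#   # h: a (10, 8) tensor feeding the first RGCN layer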