import torch as th
import torch.nn as nn

import dgl

class BaseRGCN(nn.Module):
    def __init__(self, num_nodes, h_dim, out_dim, num_rels, num_bases,
                 num_hidden_layers=1, dropout=0,
                 use_self_loop=False, use_cuda=False):
        super(BaseRGCN, self).__init__()
        self.num_nodes = num_nodes
        self.h_dim = h_dim
        self.out_dim = out_dim
        self.num_rels = num_rels
        self.num_bases = None if num_bases < 0 else num_bases
        self.num_hidden_layers = num_hidden_layers
        self.dropout = dropout
        self.use_self_loop = use_self_loop
        self.use_cuda = use_cuda

        # create rgcn layers
        self.build_model()

    def build_model(self):
        self.layers = nn.ModuleList()
        # i2h
        i2h = self.build_input_layer()
        if i2h is not None:
            self.layers.append(i2h)
        # h2h
        for idx in range(self.num_hidden_layers):
            h2h = self.build_hidden_layer(idx)
            self.layers.append(h2h)
        # h2o
        h2o = self.build_output_layer()
        if h2o is not None:
            self.layers.append(h2o)

    def build_input_layer(self):
        return None

    def build_hidden_layer(self, idx):
        raise NotImplementedError

    def build_output_layer(self):
        return None

    def forward(self, g, h, r, norm):
        for layer in self.layers:
            h = layer(g, h, r, norm)
        return h
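
# A minimal sketch (not part of the original file) of how BaseRGCN is meant to
# be subclassed: build_hidden_layer() is the only hook that must be overridden,
# while build_input_layer()/build_output_layer() are optional. The choice of
# dgl.nn.RelGraphConv and its arguments here is an assumption for illustration,
# not the authors' exact configuration.
class ExampleRGCN(BaseRGCN):
    def build_hidden_layer(self, idx):
        # every hidden layer maps h_dim -> h_dim; RelGraphConv's forward
        # signature matches the (g, h, r, norm) call in BaseRGCN.forward
        act = nn.functional.relu if idx < self.num_hidden_layers - 1 else None
        return dgl.nn.RelGraphConv(self.h_dim, self.h_dim, self.num_rels,
                                   "basis", self.num_bases, activation=act,
                                   self_loop=self.use_self_loop,
                                   dropout=self.dropout)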

def initializer(emb):
    emb.uniform_(-1.0, 1.0)
    return emb

class RelGraphEmbedLayer(nn.Module):
    r"""Embedding layer for featureless heterograph.
    Parameters
    ----------
    storage_dev_id : int
        The device to store the weights of the layer.
    out_dev_id : int
        Device to return the output embeddings on.
    num_nodes : int
        Number of nodes.
    node_tids : tensor
        Node type id of each node, starting from 0.
    num_of_ntype : int
        Number of node types.
    input_size : list of int or tensor
        Input description for each node type: an int giving the number of
        nodes of that type (their ids are then treated as a one-hot encoding
        and mapped through a learnable embedding table), or a feature tensor
        that is projected to embed_size with a learnable matrix.
    embed_size : int
        Output embedding size.
    dgl_sparse : bool, optional
        If True, use dgl.nn.NodeEmbedding; otherwise use torch.nn.Embedding.
    """
    def __init__(self,
                 storage_dev_id,
                 out_dev_id,
                 num_nodes,
                 node_tids,
                 num_of_ntype,
                 input_size,
                 embed_size,
                 dgl_sparse=False):
        super(RelGraphEmbedLayer, self).__init__()
        self.storage_dev_id = th.device(
            storage_dev_id if storage_dev_id >= 0 else 'cpu')
        self.out_dev_id = th.device(out_dev_id if out_dev_id >= 0 else 'cpu')
        self.embed_size = embed_size
        self.num_nodes = num_nodes
        self.dgl_sparse = dgl_sparse

        # learnable parameters, one entry per node type: a projection matrix
        # (self.embeds) for types with input features, and an embedding table
        # (self.node_embeds) for featureless types
        self.embeds = nn.ParameterDict()
        self.node_embeds = {} if dgl_sparse else nn.ModuleDict()
        self.num_of_ntype = num_of_ntype

        for ntype in range(num_of_ntype):
            if isinstance(input_size[ntype], int):
                if dgl_sparse:
                    self.node_embeds[str(ntype)] = dgl.nn.NodeEmbedding(
                        input_size[ntype], embed_size, name=str(ntype),
                        init_func=initializer, device=self.storage_dev_id)
                else:
                    sparse_emb = th.nn.Embedding(input_size[ntype], embed_size, sparse=True)
                    # move the table to the requested storage device only when
                    # it is a GPU; Module.cuda() cannot take a CPU device
                    if self.storage_dev_id.type == 'cuda':
                        sparse_emb = sparse_emb.cuda(self.storage_dev_id)
                    nn.init.uniform_(sparse_emb.weight, -1.0, 1.0)
                    self.node_embeds[str(ntype)] = sparse_emb
            else:
                input_emb_size = input_size[ntype].shape[1]
                embed = nn.Parameter(th.empty([input_emb_size, self.embed_size],
                                              device=self.storage_dev_id))
                nn.init.xavier_uniform_(embed)
                self.embeds[str(ntype)] = embed

    @property
    def dgl_emb(self):
        """
        """
        if self.dgl_sparse:
            return list(self.node_embeds.values())
        else:
            return []
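
    # Usage note (an assumption, not from the original file): when dgl_sparse
    # is True, these NodeEmbedding objects carry sparse gradients and are meant
    # to be optimized separately, e.g. with
    # dgl.optim.SparseAdam(params=embed_layer.dgl_emb, lr=0.01),
    # rather than by the dense optimizer that updates the rest of the model.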

    def forward(self, node_ids, node_tids, type_ids, features):
        """Forward computation
        Parameters
        ----------
        node_ids : tensor
            Node ids to generate embeddings for.
        node_tids : tensor
            Node type id of each node in node_ids.
        type_ids : tensor
            Local id of each node within its own node type, used to index
            that type's embedding table or feature tensor.
        features : list of tensor or int
            Initial features for the nodes of each node type. An int entry
            means the node ids of that type are treated as a one-hot encoding
            and looked up in a learnable embedding table; a tensor entry is
            used directly as input features and multiplied by a learnable
            projection matrix.
        Returns
        -------
        tensor
            Embeddings used as the input of the next layer.
        """
        embeds = th.empty(node_ids.shape[0], self.embed_size, device=self.out_dev_id)

        # transfer input to the correct device
        type_ids = type_ids.to(self.storage_dev_id)
        node_tids = node_tids.to(self.storage_dev_id)

        # build locs first
        locs = [None for i in range(self.num_of_ntype)]
        for ntype in range(self.num_of_ntype):
            locs[ntype] = (node_tids == ntype).nonzero().squeeze(-1)
        for ntype in range(self.num_of_ntype):
            loc = locs[ntype]
            if isinstance(features[ntype], int):
                if self.dgl_sparse:
                    embeds[loc] = self.node_embeds[str(ntype)](type_ids[loc], self.out_dev_id)
                else:
                    embeds[loc] = self.node_embeds[str(ntype)](type_ids[loc]).to(self.out_dev_id)
            else:
                embeds[loc] = features[ntype][type_ids[loc]].to(self.out_dev_id) @ self.embeds[str(ntype)].to(self.out_dev_id)

        return embeds
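
# A minimal, self-contained usage sketch (an assumption for illustration, not
# part of the original example): every node gets a type id and a per-type local
# id; the input_size/features entries are either an int (featureless type, ids
# become a learnable one-hot embedding) or a feature tensor (projected to
# embed_size). The helper name and the toy sizes below are hypothetical.
def _example_embed_layer_usage():
    num_nodes, num_ntypes, embed_size = 10, 2, 4
    # nodes 0-5 are type 0, nodes 6-9 are type 1
    node_tids = th.tensor([0] * 6 + [1] * 4)
    # local id of each node within its own type
    type_ids = th.tensor(list(range(6)) + list(range(4)))
    # type 0 is featureless (6 nodes -> int), type 1 has 3-dim input features
    features = [6, th.randn(4, 3)]
    input_size = [6, features[1]]
    # -1 keeps both storage and output on the CPU
    embed_layer = RelGraphEmbedLayer(-1, -1, num_nodes, node_tids, num_ntypes,
                                     input_size, embed_size, dgl_sparse=False)
    node_ids = th.arange(num_nodes)
    return embed_layer(node_ids, node_tids, type_ids, features)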