import torch.nn as nn

class BaseRGCN(nn.Module):
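    """Skeleton R-GCN module: stacks an optional input layer, a number of
    hidden layers, and an optional output layer. Subclasses supply the
    concrete layers by overriding the build_*_layer hooks below."""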
    def __init__(self, num_nodes, h_dim, out_dim, num_rels, num_bases,
                 num_hidden_layers=1, dropout=0,
                 use_self_loop=False, use_cuda=False):
        super(BaseRGCN, self).__init__()
        self.num_nodes = num_nodes
        self.h_dim = h_dim
        self.out_dim = out_dim
        self.num_rels = num_rels
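        # interpret a negative num_bases as "unspecified" and store it as None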
        self.num_bases = None if num_bases < 0 else num_bases
        self.num_hidden_layers = num_hidden_layers
        self.dropout = dropout
        self.use_self_loop = use_self_loop
        self.use_cuda = use_cuda

        # create rgcn layers
        self.build_model()

    def build_model(self):
        self.layers = nn.ModuleList()
        # i2h
        i2h = self.build_input_layer()
        if i2h is not None:
            self.layers.append(i2h)
        # h2h
        for idx in range(self.num_hidden_layers):
            h2h = self.build_hidden_layer(idx)
            self.layers.append(h2h)
        # h2o
        h2o = self.build_output_layer()
        if h2o is not None:
            self.layers.append(h2o)

    def build_input_layer(self):
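        # hook: return an input (i2h) layer, or None to skip it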
        return None

    def build_hidden_layer(self, idx):
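        # hook: return the idx-th hidden (h2h) layer; subclasses must implement this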
        raise NotImplementedError

    def build_output_layer(self):
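        # hook: return an output (h2o) layer, or None to skip it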
        return None

    def forward(self, g, h, r, norm):
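        # g: input graph, h: node features, r: edge relation types, norm: edge normalization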
        for layer in self.layers:
            h = layer(g, h, r, norm)
        return h
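

# A minimal sketch of how the hooks above are typically filled in. The subclass
# name and the use of dgl.nn's RelGraphConv (and its exact arguments) are
# illustrative assumptions, not part of this file; adjust to your DGL version.
#
# from dgl.nn.pytorch import RelGraphConv
# import torch.nn.functional as F
#
# class MyRGCN(BaseRGCN):
#     def build_hidden_layer(self, idx):
#         return RelGraphConv(self.h_dim, self.h_dim, self.num_rels,
#                             regularizer="basis", num_bases=self.num_bases,
#                             activation=F.relu, self_loop=self.use_self_loop,
#                             dropout=self.dropout)
#
#     def build_output_layer(self):
#         return RelGraphConv(self.h_dim, self.out_dim, self.num_rels,
#                             regularizer="basis", num_bases=self.num_bases,
#                             self_loop=self.use_self_loop)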