import torch.nn as nn


class BaseRGCN(nn.Module):
    """Skeleton for a stack of RGCN layers assembled from overridable hooks.

    The constructor stores the configuration on the instance and then calls
    :meth:`build_model`, which fills ``self.layers`` (an ``nn.ModuleList``)
    with an optional input layer, ``num_hidden_layers`` hidden layers and an
    optional output layer, in that order.

    Subclasses must override :meth:`build_hidden_layer` (unless
    ``num_hidden_layers`` is 0) and may override :meth:`build_input_layer`
    and :meth:`build_output_layer`. Because the hooks run inside
    ``__init__``, they can only rely on the attributes set here.

    Args:
        num_nodes: Stored as ``self.num_nodes``; not used by this base class.
        h_dim: Stored as ``self.h_dim`` (presumably the hidden feature size;
            interpretation is left to subclasses).
        out_dim: Stored as ``self.out_dim`` (presumably the output size).
        num_rels: Stored as ``self.num_rels`` (presumably the number of
            relation types).
        num_bases: Stored as ``self.num_bases``. ``None`` or any negative
            value is normalised to ``None``.
        num_hidden_layers: How many times :meth:`build_hidden_layer` is called.
        dropout: Stored as ``self.dropout``; not used by this base class.
        use_self_loop: Stored as ``self.use_self_loop``; not used here.
        use_cuda: Stored as ``self.use_cuda``; not used here.
    """

    def __init__(self, num_nodes, h_dim, out_dim, num_rels, num_bases,
                 num_hidden_layers=1, dropout=0,
                 use_self_loop=False, use_cuda=False):
        super().__init__()
        self.num_nodes = num_nodes
        self.h_dim = h_dim
        self.out_dim = out_dim
        self.num_rels = num_rels
        # ``None`` is accepted as well as a negative number: comparing
        # ``None < 0`` would raise TypeError, yet None is the very value a
        # negative input is normalised to.
        if num_bases is None or num_bases < 0:
            self.num_bases = None
        else:
            self.num_bases = num_bases
        self.num_hidden_layers = num_hidden_layers
        self.dropout = dropout
        self.use_self_loop = use_self_loop
        self.use_cuda = use_cuda

        # create rgcn layers
        self.build_model()

    def build_model(self):
        """(Re)create ``self.layers`` from the three ``build_*`` hooks."""
        self.layers = nn.ModuleList()

        # input -> hidden (optional: skipped when the hook returns None)
        i2h = self.build_input_layer()
        if i2h is not None:
            self.layers.append(i2h)

        # hidden -> hidden
        for idx in range(self.num_hidden_layers):
            self.layers.append(self.build_hidden_layer(idx))

        # hidden -> output (optional: skipped when the hook returns None)
        h2o = self.build_output_layer()
        if h2o is not None:
            self.layers.append(h2o)

    def build_input_layer(self):
        """Return the input layer, or ``None`` (the default) to omit it."""
        return None

    def build_hidden_layer(self, idx):
        """Return the hidden layer for position ``idx``.

        Raises:
            NotImplementedError: Always, in this base class; subclasses
                must override.
        """
        raise NotImplementedError

    def build_output_layer(self):
        """Return the output layer, or ``None`` (the default) to omit it."""
        return None

    def forward(self, g, h, r, norm):
        """Thread ``h`` through every layer in order and return the result.

        Each layer is called as ``layer(g, h, r, norm)`` and its return value
        becomes the ``h`` passed to the next layer; ``g``, ``r`` and ``norm``
        are forwarded unchanged. With no layers, ``h`` is returned as-is.

        NOTE(review): presumably ``g`` is the graph, ``h`` node features,
        ``r`` edge relation types and ``norm`` an edge normaliser -- confirm
        against the concrete layer implementations.
        """
        for layer in self.layers:
            h = layer(g, h, r, norm)
        return h