"""RGCN layer implementation""" from collections import defaultdict import torch as th import torch.nn as nn import torch.nn.functional as F import dgl import dgl.nn as dglnn import tqdm class RelGraphConvLayer(nn.Module): r"""Relational graph convolution layer. Parameters ---------- in_feat : int Input feature size. out_feat : int Output feature size. rel_names : list[str] Relation names. num_bases : int, optional Number of bases. If is none, use number of relations. Default: None. weight : bool, optional True if a linear layer is applied after message passing. Default: True bias : bool, optional True if bias is added. Default: True activation : callable, optional Activation function. Default: None self_loop : bool, optional True to include self loop message. Default: False dropout : float, optional Dropout rate. Default: 0.0 """ def __init__(self, in_feat, out_feat, rel_names, num_bases, *, weight=True, bias=True, activation=None, self_loop=False, dropout=0.0): super(RelGraphConvLayer, self).__init__() self.in_feat = in_feat self.out_feat = out_feat self.rel_names = rel_names self.num_bases = num_bases self.bias = bias self.activation = activation self.self_loop = self_loop self.conv = dglnn.HeteroGraphConv({ rel : dglnn.GraphConv(in_feat, out_feat, norm='right', weight=False, bias=False) for rel in rel_names }) self.use_weight = weight self.use_basis = num_bases < len(self.rel_names) and weight if self.use_weight: if self.use_basis: self.basis = dglnn.WeightBasis((in_feat, out_feat), num_bases, len(self.rel_names)) else: self.weight = nn.Parameter(th.Tensor(len(self.rel_names), in_feat, out_feat)) nn.init.xavier_uniform_(self.weight, gain=nn.init.calculate_gain('relu')) # bias if bias: self.h_bias = nn.Parameter(th.Tensor(out_feat)) nn.init.zeros_(self.h_bias) # weight for self loop if self.self_loop: self.loop_weight = nn.Parameter(th.Tensor(in_feat, out_feat)) nn.init.xavier_uniform_(self.loop_weight, gain=nn.init.calculate_gain('relu')) self.dropout = nn.Dropout(dropout) def forward(self, g, inputs): """Forward computation Parameters ---------- g : DGLHeteroGraph Input graph. inputs : dict[str, torch.Tensor] Node feature for each node type. Returns ------- dict[str, torch.Tensor] New node features for each node type. 
""" g = g.local_var() if self.use_weight: weight = self.basis() if self.use_basis else self.weight wdict = {self.rel_names[i] : {'weight' : w.squeeze(0)} for i, w in enumerate(th.split(weight, 1, dim=0))} else: wdict = {} inputs_src, inputs_dst = dglnn.expand_as_pair(inputs, g) hs = self.conv(g, inputs, mod_kwargs=wdict) def _apply(ntype, h): if self.self_loop: h = h + th.matmul(inputs_dst[ntype], self.loop_weight) if self.bias: h = h + self.h_bias if self.activation: h = self.activation(h) return self.dropout(h) return {ntype : _apply(ntype, h) for ntype, h in hs.items()} class RelGraphEmbed(nn.Module): r"""Embedding layer for featureless heterograph.""" def __init__(self, g, embed_size, embed_name='embed', activation=None, dropout=0.0): super(RelGraphEmbed, self).__init__() self.g = g self.embed_size = embed_size self.embed_name = embed_name self.activation = activation self.dropout = nn.Dropout(dropout) # create weight embeddings for each node for each relation self.embeds = nn.ParameterDict() for ntype in g.ntypes: embed = nn.Parameter(th.Tensor(g.number_of_nodes(ntype), self.embed_size)) nn.init.xavier_uniform_(embed, gain=nn.init.calculate_gain('relu')) self.embeds[ntype] = embed def forward(self, block=None): """Forward computation Parameters ---------- block : DGLHeteroGraph, optional If not specified, directly return the full graph with embeddings stored in :attr:`embed_name`. Otherwise, extract and store the embeddings to the block graph and return. Returns ------- DGLHeteroGraph The block graph fed with embeddings. """ return self.embeds class EntityClassify(nn.Module): def __init__(self, g, h_dim, out_dim, num_bases, num_hidden_layers=1, dropout=0, use_self_loop=False): super(EntityClassify, self).__init__() self.g = g self.h_dim = h_dim self.out_dim = out_dim self.rel_names = list(set(g.etypes)) self.rel_names.sort() if num_bases < 0 or num_bases > len(self.rel_names): self.num_bases = len(self.rel_names) else: self.num_bases = num_bases self.num_hidden_layers = num_hidden_layers self.dropout = dropout self.use_self_loop = use_self_loop self.embed_layer = RelGraphEmbed(g, self.h_dim) self.layers = nn.ModuleList() # i2h self.layers.append(RelGraphConvLayer( self.h_dim, self.h_dim, self.rel_names, self.num_bases, activation=F.relu, self_loop=self.use_self_loop, dropout=self.dropout, weight=False)) # h2h for i in range(self.num_hidden_layers): self.layers.append(RelGraphConvLayer( self.h_dim, self.h_dim, self.rel_names, self.num_bases, activation=F.relu, self_loop=self.use_self_loop, dropout=self.dropout)) # h2o self.layers.append(RelGraphConvLayer( self.h_dim, self.out_dim, self.rel_names, self.num_bases, activation=None, self_loop=self.use_self_loop)) def forward(self, h=None, blocks=None): if h is None: # full graph training h = self.embed_layer() if blocks is None: # full graph training for layer in self.layers: h = layer(self.g, h) else: # minibatch training for layer, block in zip(self.layers, blocks): h = layer(block, h) return h def inference(self, g, batch_size, device, num_workers, x=None): """Minibatch inference of final representation over all node types. ***NOTE*** For node classification, the model is trained to predict on only one node type's label. Therefore, only that type's final representation is meaningful. 
""" if x is None: x = self.embed_layer() for l, layer in enumerate(self.layers): y = { k: th.zeros( g.number_of_nodes(k), self.h_dim if l != len(self.layers) - 1 else self.out_dim) for k in g.ntypes} sampler = dgl.sampling.MultiLayerNeighborSampler([None]) dataloader = dgl.sampling.NodeDataLoader( g, {k: th.arange(g.number_of_nodes(k)) for k in g.ntypes}, sampler, batch_size=batch_size, shuffle=True, drop_last=False, num_workers=num_workers) for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader): block = blocks[0] h = {k: x[k][input_nodes[k]].to(device) for k in input_nodes.keys()} h = layer(block, h) for k in h.keys(): y[k][output_nodes[k]] = h[k].cpu() x = y return y