Unverified Commit 708765f0 authored by Minjie Wang, committed by GitHub

[NN] RGCN modules (#744)

* rgcn module

* support id input

* WIP: model codes

* use faster index select

* dropout

* self loop

* WIP: link prediction

* fix lint

* WIP: docs

* docstring

* docstring

* merge two child classes

* mxnet rgcn module

* fix lint

* fix lint

* fix rename bug

* add uniform edge sampler

* fix fn name

* docstring

* fix mxnet rgcn module

* fix mx rgcn

* enable test on cuda
parent 52d4535b
...@@ -5,6 +5,7 @@ Requirements ...@@ -5,6 +5,7 @@ Requirements
------------ ------------
* sphinx * sphinx
* sphinx-gallery * sphinx-gallery
* sphinx_rtd_theme
* Both pytorch and mxnet installed. * Both pytorch and mxnet installed.
Build documents Build documents
......
...@@ -12,6 +12,10 @@ dgl.nn.mxnet.conv ...@@ -12,6 +12,10 @@ dgl.nn.mxnet.conv
:members: weight, bias, forward :members: weight, bias, forward
:show-inheritance: :show-inheritance:
.. autoclass:: dgl.nn.mxnet.conv.RelGraphConv
:members: forward
:show-inheritance:
dgl.nn.mxnet.glob dgl.nn.mxnet.glob
----------------- -----------------
......
...@@ -12,6 +12,10 @@ dgl.nn.pytorch.conv ...@@ -12,6 +12,10 @@ dgl.nn.pytorch.conv
:members: weight, bias, forward, reset_parameters :members: weight, bias, forward, reset_parameters
:show-inheritance: :show-inheritance:
.. autoclass:: dgl.nn.pytorch.conv.RelGraphConv
:members: forward
:show-inheritance:
dgl.nn.pytorch.glob dgl.nn.pytorch.glob
------------------- -------------------
.. automodule:: dgl.nn.pytorch.glob .. automodule:: dgl.nn.pytorch.glob
......
...@@ -25,12 +25,12 @@ AIFB: accuracy 97.22% (DGL), 95.83% (paper) ...@@ -25,12 +25,12 @@ AIFB: accuracy 97.22% (DGL), 95.83% (paper)
DGLBACKEND=mxnet python3 entity_classify.py -d aifb --testing --gpu 0 DGLBACKEND=mxnet python3 entity_classify.py -d aifb --testing --gpu 0
``` ```
MUTAG: accuracy 76.47% (DGL), 73.23% (paper) MUTAG: accuracy 73.53% (DGL), 73.23% (paper)
``` ```
DGLBACKEND=mxnet python3 entity_classify.py -d mutag --l2norm 5e-4 --n-bases 40 --testing --gpu 0 DGLBACKEND=mxnet python3 entity_classify.py -d mutag --l2norm 5e-4 --n-bases 40 --testing --gpu 0
``` ```
BGS: accuracy 79.31% (DGL, n-bases=20, OOM when >20), 83.10% (paper) BGS: accuracy 75.86% (DGL, n-bases=20, OOM when >20), 83.10% (paper)
``` ```
DGLBACKEND=mxnet python3 entity_classify.py -d bgs --l2norm 5e-4 --n-bases 20 --testing --gpu 0 --relabel DGLBACKEND=mxnet python3 entity_classify.py -d bgs --l2norm 5e-4 --n-bases 20 --testing --gpu 0 --relabel
``` ```
...@@ -15,32 +15,27 @@ import mxnet as mx ...@@ -15,32 +15,27 @@ import mxnet as mx
from mxnet import gluon from mxnet import gluon
import mxnet.ndarray as F import mxnet.ndarray as F
from dgl import DGLGraph from dgl import DGLGraph
from dgl.nn.mxnet import RelGraphConv
from dgl.contrib.data import load_data from dgl.contrib.data import load_data
from functools import partial from functools import partial
from model import BaseRGCN from model import BaseRGCN
from layers import RGCNBasisLayer as RGCNLayer
class EntityClassify(BaseRGCN): class EntityClassify(BaseRGCN):
def create_features(self):
features = mx.nd.arange(self.num_nodes)
if self.gpu_id >= 0:
features = features.as_in_context(mx.gpu(self.gpu_id))
return features
def build_input_layer(self): def build_input_layer(self):
return RGCNLayer(self.num_nodes, self.h_dim, self.num_rels, self.num_bases, return RelGraphConv(self.num_nodes, self.h_dim, self.num_rels, "basis",
activation=F.relu, is_input_layer=True) self.num_bases, activation=F.relu, self_loop=self.use_self_loop,
dropout=self.dropout)
def build_hidden_layer(self, idx): def build_hidden_layer(self, idx):
return RGCNLayer(self.h_dim, self.h_dim, self.num_rels, self.num_bases, return RelGraphConv(self.h_dim, self.h_dim, self.num_rels, "basis",
activation=F.relu) self.num_bases, activation=F.relu, self_loop=self.use_self_loop,
dropout=self.dropout)
def build_output_layer(self): def build_output_layer(self):
return RGCNLayer(self.h_dim, self.out_dim, self.num_rels,self.num_bases, return RelGraphConv(self.h_dim, self.out_dim, self.num_rels, "basis",
activation=partial(F.softmax, axis=1)) self.num_bases, activation=partial(F.softmax, axis=1),
self_loop=self.use_self_loop)
def main(args): def main(args):
# load graph data # load graph data
...@@ -60,8 +55,10 @@ def main(args): ...@@ -60,8 +55,10 @@ def main(args):
val_idx = train_idx val_idx = train_idx
train_idx = mx.nd.array(train_idx) train_idx = mx.nd.array(train_idx)
# since the nodes are featureless, the input feature is then the node id.
feats = mx.nd.arange(num_nodes, dtype='int32')
# edge type and normalization factor # edge type and normalization factor
edge_type = mx.nd.array(data.edge_type) edge_type = mx.nd.array(data.edge_type, dtype='int32')
edge_norm = mx.nd.array(data.edge_norm).expand_dims(1) edge_norm = mx.nd.array(data.edge_norm).expand_dims(1)
labels = mx.nd.array(labels).reshape((-1)) labels = mx.nd.array(labels).reshape((-1))
...@@ -69,6 +66,7 @@ def main(args): ...@@ -69,6 +66,7 @@ def main(args):
use_cuda = args.gpu >= 0 use_cuda = args.gpu >= 0
if use_cuda: if use_cuda:
ctx = mx.gpu(args.gpu) ctx = mx.gpu(args.gpu)
feats = feats.as_in_context(ctx)
edge_type = edge_type.as_in_context(ctx) edge_type = edge_type.as_in_context(ctx)
edge_norm = edge_norm.as_in_context(ctx) edge_norm = edge_norm.as_in_context(ctx)
labels = labels.as_in_context(ctx) labels = labels.as_in_context(ctx)
...@@ -80,7 +78,6 @@ def main(args): ...@@ -80,7 +78,6 @@ def main(args):
g = DGLGraph() g = DGLGraph()
g.add_nodes(num_nodes) g.add_nodes(num_nodes)
g.add_edges(data.edge_src, data.edge_dst) g.add_edges(data.edge_src, data.edge_dst)
g.edata.update({'type': edge_type, 'norm': edge_norm})
# create model # create model
model = EntityClassify(len(g), model = EntityClassify(len(g),
...@@ -90,6 +87,7 @@ def main(args): ...@@ -90,6 +87,7 @@ def main(args):
num_bases=args.n_bases, num_bases=args.n_bases,
num_hidden_layers=args.n_layers - 2, num_hidden_layers=args.n_layers - 2,
dropout=args.dropout, dropout=args.dropout,
use_self_loop=args.use_self_loop,
gpu_id=args.gpu) gpu_id=args.gpu)
model.initialize(ctx=ctx) model.initialize(ctx=ctx)
...@@ -104,7 +102,7 @@ def main(args): ...@@ -104,7 +102,7 @@ def main(args):
for epoch in range(args.n_epochs): for epoch in range(args.n_epochs):
t0 = time.time() t0 = time.time()
with mx.autograd.record(): with mx.autograd.record():
pred = model(g) pred = model(g, feats, edge_type, edge_norm)
loss = loss_fcn(pred[train_idx], labels[train_idx]) loss = loss_fcn(pred[train_idx], labels[train_idx])
t1 = time.time() t1 = time.time()
loss.backward() loss.backward()
...@@ -120,7 +118,7 @@ def main(args): ...@@ -120,7 +118,7 @@ def main(args):
print("Train Accuracy: {:.4f} | Validation Accuracy: {:.4f}".format(train_acc, val_acc)) print("Train Accuracy: {:.4f} | Validation Accuracy: {:.4f}".format(train_acc, val_acc))
print() print()
logits = model(g) logits = model.forward(g, feats, edge_type, edge_norm)
test_acc = F.sum(logits[test_idx].argmax(axis=1) == labels[test_idx]).asscalar() / len(test_idx) test_acc = F.sum(logits[test_idx].argmax(axis=1) == labels[test_idx]).asscalar() / len(test_idx)
print("Test Accuracy: {:.4f}".format(test_acc)) print("Test Accuracy: {:.4f}".format(test_acc))
print() print()
...@@ -151,6 +149,8 @@ if __name__ == '__main__': ...@@ -151,6 +149,8 @@ if __name__ == '__main__':
help="l2 norm coef") help="l2 norm coef")
parser.add_argument("--relabel", default=False, action='store_true', parser.add_argument("--relabel", default=False, action='store_true',
help="remove untouched nodes and relabel") help="remove untouched nodes and relabel")
parser.add_argument("--use-self-loop", default=False, action='store_true',
help="include self feature as a special relation")
fp = parser.add_mutually_exclusive_group(required=False) fp = parser.add_mutually_exclusive_group(required=False)
fp.add_argument('--validation', dest='validation', action='store_true') fp.add_argument('--validation', dest='validation', action='store_true')
fp.add_argument('--testing', dest='validation', action='store_false') fp.add_argument('--testing', dest='validation', action='store_false')
......
import math
import mxnet as mx
from mxnet import gluon
import mxnet.ndarray as F
import dgl.function as fn
class RGCNLayer(gluon.Block):
def __init__(self, in_feat, out_feat, bias=None, activation=None,
self_loop=False, dropout=0.0):
super(RGCNLayer, self).__init__()
self.bias = bias
self.activation = activation
self.self_loop = self_loop
if self.bias == True:
self.bias = self.params.get('bias', shape=(out_feat,),
init=mx.init.Xavier(magnitude=math.sqrt(2.0)))
# weight for self loop
if self.self_loop:
self.loop_weight = self.params.get('loop_weight', shape=(in_feat, out_feat),
init=mx.init.Xavier(magnitude=math.sqrt(2.0)))
if dropout:
self.dropout = gluon.nn.Dropout(dropout)
else:
self.dropout = None
# define how propagation is done in subclass
def propagate(self, g):
raise NotImplementedError
def forward(self, g):
if self.self_loop:
loop_message = F.dot(g.ndata['h'], self.loop_weight)
if self.dropout is not None:
loop_message = self.dropout(loop_message)
self.propagate(g)
# apply bias and activation
node_repr = g.ndata['h']
if self.bias:
node_repr = node_repr + self.bias
if self.self_loop:
node_repr = node_repr + loop_message
if self.activation:
node_repr = self.activation(node_repr)
g.ndata['h'] = node_repr
class RGCNBasisLayer(RGCNLayer):
def __init__(self, in_feat, out_feat, num_rels, num_bases=-1, bias=None,
activation=None, is_input_layer=False):
super(RGCNBasisLayer, self).__init__(in_feat, out_feat, bias, activation)
self.in_feat = in_feat
self.out_feat = out_feat
self.num_rels = num_rels
self.num_bases = num_bases
self.is_input_layer = is_input_layer
if self.num_bases <= 0 or self.num_bases > self.num_rels:
self.num_bases = self.num_rels
# add basis weights
if self.num_bases < self.num_rels:
# linear combination coefficients
self.weight = self.params.get('weight', shape=(self.num_bases, self.in_feat * self.out_feat))
self.w_comp = self.params.get('w_comp', shape=(self.num_rels, self.num_bases),
init=mx.init.Xavier(magnitude=math.sqrt(2.0)))
else:
self.weight = self.params.get('weight', shape=(self.num_bases, self.in_feat, self.out_feat),
init=mx.init.Xavier(magnitude=math.sqrt(2.0)))
def propagate(self, g):
if self.num_bases < self.num_rels:
# generate all weights from bases
weight = F.dot(self.w_comp.data(), self.weight.data()).reshape((self.num_rels, self.in_feat, self.out_feat))
else:
weight = self.weight.data()
if self.is_input_layer:
def msg_func(edges):
# for input layer, matrix multiply can be converted to be
# an embedding lookup using source node id
embed = F.reshape(weight, (-1, self.out_feat))
index = edges.data['type'] * self.in_feat + edges.src['id']
return {'msg': embed[index] * edges.data['norm']}
else:
def msg_func(edges):
w = weight[edges.data['type']]
msg = F.batch_dot(edges.src['h'].expand_dims(1), w).reshape(-1, self.out_feat)
msg = msg * edges.data['norm']
return {'msg': msg}
g.update_all(msg_func, fn.sum(msg='msg', out='h'), None)
\ No newline at end of file
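The "embedding lookup" comment above is worth unpacking. Below is a small self-contained sketch (not part of the commit; the sizes and variable names are made up) showing that, for a one-hot node-ID input, multiplying by the relation's weight matrix reduces to reading one row of the flattened weight, which is exactly the `type * in_feat + id` index used in the removed layer:

```
from mxnet import nd

num_rels, in_feat, out_feat = 3, 4, 2          # hypothetical sizes
weight = nd.random.uniform(shape=(num_rels, in_feat, out_feat))

rel, src_id = 2, 1                             # one edge: relation 2, source node id 1

# dense path: one-hot(src_id) multiplied by the relation's weight matrix
onehot = nd.one_hot(nd.array([src_id]), in_feat)
dense_msg = nd.dot(onehot, weight[rel])

# lookup path used by the input layer above
embed = weight.reshape((-1, out_feat))         # (num_rels * in_feat, out_feat)
lookup_msg = embed[rel * in_feat + src_id]

print(dense_msg, lookup_msg)                   # same values
```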
...@@ -3,7 +3,8 @@ from mxnet import gluon ...@@ -3,7 +3,8 @@ from mxnet import gluon
class BaseRGCN(gluon.Block): class BaseRGCN(gluon.Block):
def __init__(self, num_nodes, h_dim, out_dim, num_rels, num_bases=-1, def __init__(self, num_nodes, h_dim, out_dim, num_rels, num_bases=-1,
num_hidden_layers=1, dropout=0, gpu_id=-1): num_hidden_layers=1, dropout=0,
use_self_loop=False, gpu_id=-1):
super(BaseRGCN, self).__init__() super(BaseRGCN, self).__init__()
self.num_nodes = num_nodes self.num_nodes = num_nodes
self.h_dim = h_dim self.h_dim = h_dim
...@@ -12,14 +13,12 @@ class BaseRGCN(gluon.Block): ...@@ -12,14 +13,12 @@ class BaseRGCN(gluon.Block):
self.num_bases = num_bases self.num_bases = num_bases
self.num_hidden_layers = num_hidden_layers self.num_hidden_layers = num_hidden_layers
self.dropout = dropout self.dropout = dropout
self.use_self_loop = use_self_loop
self.gpu_id = gpu_id self.gpu_id = gpu_id
# create rgcn layers # create rgcn layers
self.build_model() self.build_model()
# create initial features
self.features = self.create_features()
def build_model(self): def build_model(self):
self.layers = gluon.nn.Sequential() self.layers = gluon.nn.Sequential()
# i2h # i2h
...@@ -35,10 +34,6 @@ class BaseRGCN(gluon.Block): ...@@ -35,10 +34,6 @@ class BaseRGCN(gluon.Block):
if h2o is not None: if h2o is not None:
self.layers.add(h2o) self.layers.add(h2o)
# initialize feature for each node
def create_features(self):
return None
def build_input_layer(self): def build_input_layer(self):
return None return None
...@@ -48,10 +43,7 @@ class BaseRGCN(gluon.Block): ...@@ -48,10 +43,7 @@ class BaseRGCN(gluon.Block):
def build_output_layer(self): def build_output_layer(self):
return None return None
def forward(self, g): def forward(self, g, h, r, norm):
if self.features is not None:
g.ndata['id'] = self.features
for layer in self.layers: for layer in self.layers:
layer(g) h = layer(g, h, r, norm)
return g.ndata.pop('h') return h
...@@ -14,11 +14,10 @@ import time ...@@ -14,11 +14,10 @@ import time
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from dgl import DGLGraph from dgl import DGLGraph
from dgl.nn.pytorch import RelGraphConv
from dgl.contrib.data import load_data from dgl.contrib.data import load_data
import dgl.function as fn
from functools import partial from functools import partial
from layers import RGCNBasisLayer as RGCNLayer
from model import BaseRGCN from model import BaseRGCN
class EntityClassify(BaseRGCN): class EntityClassify(BaseRGCN):
...@@ -29,16 +28,19 @@ class EntityClassify(BaseRGCN): ...@@ -29,16 +28,19 @@ class EntityClassify(BaseRGCN):
return features return features
def build_input_layer(self): def build_input_layer(self):
return RGCNLayer(self.num_nodes, self.h_dim, self.num_rels, self.num_bases, return RelGraphConv(self.num_nodes, self.h_dim, self.num_rels, "basis",
activation=F.relu, is_input_layer=True) self.num_bases, activation=F.relu, self_loop=self.use_self_loop,
dropout=self.dropout)
def build_hidden_layer(self, idx): def build_hidden_layer(self, idx):
return RGCNLayer(self.h_dim, self.h_dim, self.num_rels, self.num_bases, return RelGraphConv(self.h_dim, self.h_dim, self.num_rels, "basis",
activation=F.relu) self.num_bases, activation=F.relu, self_loop=self.use_self_loop,
dropout=self.dropout)
def build_output_layer(self): def build_output_layer(self):
return RGCNLayer(self.h_dim, self.out_dim, self.num_rels,self.num_bases, return RelGraphConv(self.h_dim, self.out_dim, self.num_rels, "basis",
activation=partial(F.softmax, dim=1)) self.num_bases, activation=partial(F.softmax, dim=1),
self_loop=self.use_self_loop)
def main(args): def main(args):
# load graph data # load graph data
...@@ -57,6 +59,9 @@ def main(args): ...@@ -57,6 +59,9 @@ def main(args):
else: else:
val_idx = train_idx val_idx = train_idx
# since the nodes are featureless, the input feature is then the node id.
feats = torch.arange(num_nodes)
# edge type and normalization factor # edge type and normalization factor
edge_type = torch.from_numpy(data.edge_type) edge_type = torch.from_numpy(data.edge_type)
edge_norm = torch.from_numpy(data.edge_norm).unsqueeze(1) edge_norm = torch.from_numpy(data.edge_norm).unsqueeze(1)
...@@ -66,6 +71,7 @@ def main(args): ...@@ -66,6 +71,7 @@ def main(args):
use_cuda = args.gpu >= 0 and torch.cuda.is_available() use_cuda = args.gpu >= 0 and torch.cuda.is_available()
if use_cuda: if use_cuda:
torch.cuda.set_device(args.gpu) torch.cuda.set_device(args.gpu)
feats = feats.cuda()
edge_type = edge_type.cuda() edge_type = edge_type.cuda()
edge_norm = edge_norm.cuda() edge_norm = edge_norm.cuda()
labels = labels.cuda() labels = labels.cuda()
...@@ -74,7 +80,6 @@ def main(args): ...@@ -74,7 +80,6 @@ def main(args):
g = DGLGraph() g = DGLGraph()
g.add_nodes(num_nodes) g.add_nodes(num_nodes)
g.add_edges(data.edge_src, data.edge_dst) g.add_edges(data.edge_src, data.edge_dst)
g.edata.update({'type': edge_type, 'norm': edge_norm})
# create model # create model
model = EntityClassify(len(g), model = EntityClassify(len(g),
...@@ -84,6 +89,7 @@ def main(args): ...@@ -84,6 +89,7 @@ def main(args):
num_bases=args.n_bases, num_bases=args.n_bases,
num_hidden_layers=args.n_layers - 2, num_hidden_layers=args.n_layers - 2,
dropout=args.dropout, dropout=args.dropout,
use_self_loop=args.use_self_loop,
use_cuda=use_cuda) use_cuda=use_cuda)
if use_cuda: if use_cuda:
...@@ -100,7 +106,7 @@ def main(args): ...@@ -100,7 +106,7 @@ def main(args):
for epoch in range(args.n_epochs): for epoch in range(args.n_epochs):
optimizer.zero_grad() optimizer.zero_grad()
t0 = time.time() t0 = time.time()
logits = model.forward(g) logits = model(g, feats, edge_type, edge_norm)
loss = F.cross_entropy(logits[train_idx], labels[train_idx]) loss = F.cross_entropy(logits[train_idx], labels[train_idx])
t1 = time.time() t1 = time.time()
loss.backward() loss.backward()
...@@ -119,7 +125,7 @@ def main(args): ...@@ -119,7 +125,7 @@ def main(args):
print() print()
model.eval() model.eval()
logits = model.forward(g) logits = model.forward(g, feats, edge_type, edge_norm)
test_loss = F.cross_entropy(logits[test_idx], labels[test_idx]) test_loss = F.cross_entropy(logits[test_idx], labels[test_idx])
test_acc = torch.sum(logits[test_idx].argmax(dim=1) == labels[test_idx]).item() / len(test_idx) test_acc = torch.sum(logits[test_idx].argmax(dim=1) == labels[test_idx]).item() / len(test_idx)
print("Test Accuracy: {:.4f} | Test loss: {:.4f}".format(test_acc, test_loss.item())) print("Test Accuracy: {:.4f} | Test loss: {:.4f}".format(test_acc, test_loss.item()))
...@@ -151,6 +157,8 @@ if __name__ == '__main__': ...@@ -151,6 +157,8 @@ if __name__ == '__main__':
help="l2 norm coef") help="l2 norm coef")
parser.add_argument("--relabel", default=False, action='store_true', parser.add_argument("--relabel", default=False, action='store_true',
help="remove untouched nodes and relabel") help="remove untouched nodes and relabel")
parser.add_argument("--use-self-loop", default=False, action='store_true',
help="include self feature as a special relation")
fp = parser.add_mutually_exclusive_group(required=False) fp = parser.add_mutually_exclusive_group(required=False)
fp.add_argument('--validation', dest='validation', action='store_true') fp.add_argument('--validation', dest='validation', action='store_true')
fp.add_argument('--testing', dest='validation', action='store_false') fp.add_argument('--testing', dest='validation', action='store_false')
...@@ -160,4 +168,3 @@ if __name__ == '__main__': ...@@ -160,4 +168,3 @@ if __name__ == '__main__':
print(args) print(args)
args.bfs_level = args.n_layers + 1 # pruning used nodes for memory args.bfs_level = args.n_layers + 1 # pruning used nodes for memory
main(args) main(args)
import torch
import torch.nn as nn
import dgl.function as fn
class RGCNLayer(nn.Module):
def __init__(self, in_feat, out_feat, bias=None, activation=None,
self_loop=False, dropout=0.0):
super(RGCNLayer, self).__init__()
self.bias = bias
self.activation = activation
self.self_loop = self_loop
if self.bias == True:
self.bias = nn.Parameter(torch.Tensor(out_feat))
nn.init.xavier_uniform_(self.bias,
gain=nn.init.calculate_gain('relu'))
# weight for self loop
if self.self_loop:
self.loop_weight = nn.Parameter(torch.Tensor(in_feat, out_feat))
nn.init.xavier_uniform_(self.loop_weight,
gain=nn.init.calculate_gain('relu'))
if dropout:
self.dropout = nn.Dropout(dropout)
else:
self.dropout = None
# define how propagation is done in subclass
def propagate(self, g):
raise NotImplementedError
def forward(self, g):
if self.self_loop:
loop_message = torch.mm(g.ndata['h'], self.loop_weight)
if self.dropout is not None:
loop_message = self.dropout(loop_message)
self.propagate(g)
# apply bias and activation
node_repr = g.ndata['h']
if self.bias:
node_repr = node_repr + self.bias
if self.self_loop:
node_repr = node_repr + loop_message
if self.activation:
node_repr = self.activation(node_repr)
g.ndata['h'] = node_repr
class RGCNBasisLayer(RGCNLayer):
def __init__(self, in_feat, out_feat, num_rels, num_bases=-1, bias=None,
activation=None, is_input_layer=False):
super(RGCNBasisLayer, self).__init__(in_feat, out_feat, bias, activation)
self.in_feat = in_feat
self.out_feat = out_feat
self.num_rels = num_rels
self.num_bases = num_bases
self.is_input_layer = is_input_layer
if self.num_bases <= 0 or self.num_bases > self.num_rels:
self.num_bases = self.num_rels
# add basis weights
self.weight = nn.Parameter(torch.Tensor(self.num_bases, self.in_feat,
self.out_feat))
if self.num_bases < self.num_rels:
# linear combination coefficients
self.w_comp = nn.Parameter(torch.Tensor(self.num_rels,
self.num_bases))
nn.init.xavier_uniform_(self.weight, gain=nn.init.calculate_gain('relu'))
if self.num_bases < self.num_rels:
nn.init.xavier_uniform_(self.w_comp,
gain=nn.init.calculate_gain('relu'))
def propagate(self, g):
if self.num_bases < self.num_rels:
# generate all weights from bases
weight = self.weight.view(self.num_bases,
self.in_feat * self.out_feat)
weight = torch.matmul(self.w_comp, weight).view(
self.num_rels, self.in_feat, self.out_feat)
else:
weight = self.weight
if self.is_input_layer:
def msg_func(edges):
# for input layer, matrix multiply can be converted to be
# an embedding lookup using source node id
embed = weight.view(-1, self.out_feat)
index = edges.data['type'] * self.in_feat + edges.src['id']
return {'msg': embed.index_select(0, index) * edges.data['norm']}
else:
def msg_func(edges):
w = weight.index_select(0, edges.data['type'])
msg = torch.bmm(edges.src['h'].unsqueeze(1), w).squeeze()
msg = msg * edges.data['norm']
return {'msg': msg}
g.update_all(msg_func, fn.sum(msg='msg', out='h'), None)
class RGCNBlockLayer(RGCNLayer):
def __init__(self, in_feat, out_feat, num_rels, num_bases, bias=None,
activation=None, self_loop=False, dropout=0.0):
super(RGCNBlockLayer, self).__init__(in_feat, out_feat, bias,
activation, self_loop=self_loop,
dropout=dropout)
self.num_rels = num_rels
self.num_bases = num_bases
assert self.num_bases > 0
self.out_feat = out_feat
self.submat_in = in_feat // self.num_bases
self.submat_out = out_feat // self.num_bases
# assuming in_feat and out_feat are both divisible by num_bases
self.weight = nn.Parameter(torch.Tensor(
self.num_rels, self.num_bases * self.submat_in * self.submat_out))
nn.init.xavier_uniform_(self.weight, gain=nn.init.calculate_gain('relu'))
def msg_func(self, edges):
weight = self.weight.index_select(0, edges.data['type']).view(
-1, self.submat_in, self.submat_out)
node = edges.src['h'].view(-1, 1, self.submat_in)
msg = torch.bmm(node, weight).view(-1, self.out_feat)
return {'msg': msg}
def propagate(self, g):
g.update_all(self.msg_func, fn.sum(msg='msg', out='h'), self.apply_func)
def apply_func(self, nodes):
return {'h': nodes.data['h'] * nodes.data['norm']}
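The removed RGCNBlockLayer (and the new "bdd" regularizer that replaces it) multiplies each chunk of the source feature by its own small block instead of materializing the full relation weight. A self-contained sketch (PyTorch, hypothetical sizes; torch.block_diag from recent PyTorch is used only to build the reference matrix) checking that the per-block bmm equals a plain matmul with the assembled block-diagonal matrix:

```
import torch as th

num_bases, submat_in, submat_out = 2, 3, 2
in_feat, out_feat = num_bases * submat_in, num_bases * submat_out

blocks = th.randn(num_bases, submat_in, submat_out)   # parameters for one relation
h = th.randn(in_feat)                                  # one source node feature

# per-block computation, as in msg_func: split h into num_bases chunks and bmm
msg = th.bmm(h.view(num_bases, 1, submat_in), blocks).view(out_feat)

# reference: assemble the full block-diagonal W_r and do a plain matmul
W = th.block_diag(*blocks)                             # (in_feat, out_feat)
assert th.allclose(msg, h @ W)
```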
...@@ -4,7 +4,12 @@ Paper: https://arxiv.org/abs/1703.06103 ...@@ -4,7 +4,12 @@ Paper: https://arxiv.org/abs/1703.06103
Code: https://github.com/MichSchli/RelationPrediction Code: https://github.com/MichSchli/RelationPrediction
Difference compared to MichSchli/RelationPrediction Difference compared to MichSchli/RelationPrediction
* report raw metrics instead of filtered metrics * Report raw metrics instead of filtered metrics.
* By default, we use uniform edge sampling instead of the neighbor-based edge
sampling used in the author's code. In practice, it achieves a similar MRR,
probably because the model only uses one GNN layer, so messages are propagated
among immediate neighbors. Users can specify "--edge-sampler=neighbor" to switch
to neighbor-based edge sampling.
""" """
import argparse import argparse
...@@ -15,8 +20,8 @@ import torch.nn as nn ...@@ -15,8 +20,8 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import random import random
from dgl.contrib.data import load_data from dgl.contrib.data import load_data
from dgl.nn.pytorch import RelGraphConv
from layers import RGCNBlockLayer as RGCNLayer
from model import BaseRGCN from model import BaseRGCN
import utils import utils
...@@ -26,9 +31,8 @@ class EmbeddingLayer(nn.Module): ...@@ -26,9 +31,8 @@ class EmbeddingLayer(nn.Module):
super(EmbeddingLayer, self).__init__() super(EmbeddingLayer, self).__init__()
self.embedding = torch.nn.Embedding(num_nodes, h_dim) self.embedding = torch.nn.Embedding(num_nodes, h_dim)
def forward(self, g): def forward(self, g, h, r, norm):
node_id = g.ndata['id'].squeeze() return self.embedding(h.squeeze())
g.ndata['h'] = self.embedding(node_id)
class RGCN(BaseRGCN): class RGCN(BaseRGCN):
def build_input_layer(self): def build_input_layer(self):
...@@ -36,8 +40,9 @@ class RGCN(BaseRGCN): ...@@ -36,8 +40,9 @@ class RGCN(BaseRGCN):
def build_hidden_layer(self, idx): def build_hidden_layer(self, idx):
act = F.relu if idx < self.num_hidden_layers - 1 else None act = F.relu if idx < self.num_hidden_layers - 1 else None
return RGCNLayer(self.h_dim, self.h_dim, self.num_rels, self.num_bases, return RelGraphConv(self.h_dim, self.h_dim, self.num_rels, "bdd",
activation=act, self_loop=True, dropout=self.dropout) self.num_bases, activation=act, self_loop=True,
dropout=self.dropout)
class LinkPredict(nn.Module): class LinkPredict(nn.Module):
def __init__(self, in_dim, h_dim, num_rels, num_bases=-1, def __init__(self, in_dim, h_dim, num_rels, num_bases=-1,
...@@ -58,26 +63,26 @@ class LinkPredict(nn.Module): ...@@ -58,26 +63,26 @@ class LinkPredict(nn.Module):
score = torch.sum(s * r * o, dim=1) score = torch.sum(s * r * o, dim=1)
return score return score
def forward(self, g): def forward(self, g, h, r, norm):
return self.rgcn.forward(g) return self.rgcn.forward(g, h, r, norm)
def evaluate(self, g):
# get embedding and relation weight without grad
embedding = self.forward(g)
return embedding, self.w_relation
def regularization_loss(self, embedding): def regularization_loss(self, embedding):
return torch.mean(embedding.pow(2)) + torch.mean(self.w_relation.pow(2)) return torch.mean(embedding.pow(2)) + torch.mean(self.w_relation.pow(2))
def get_loss(self, g, triplets, labels): def get_loss(self, g, embed, triplets, labels):
# triplets is a list of data samples (positive and negative) # triplets is a list of data samples (positive and negative)
# each row in the triplets is a 3-tuple of (source, relation, destination) # each row in the triplets is a 3-tuple of (source, relation, destination)
embedding = self.forward(g) score = self.calc_score(embed, triplets)
score = self.calc_score(embedding, triplets)
predict_loss = F.binary_cross_entropy_with_logits(score, labels) predict_loss = F.binary_cross_entropy_with_logits(score, labels)
reg_loss = self.regularization_loss(embedding) reg_loss = self.regularization_loss(embed)
return predict_loss + self.reg_param * reg_loss return predict_loss + self.reg_param * reg_loss
def node_norm_to_edge_norm(g, node_norm):
g = g.local_var()
# convert to edge norm
g.ndata['norm'] = node_norm
g.apply_edges(lambda edges : {'norm' : edges.dst['norm']})
return g.edata['norm']
def main(args): def main(args):
# load graph data # load graph data
...@@ -114,9 +119,7 @@ def main(args): ...@@ -114,9 +119,7 @@ def main(args):
range(test_graph.number_of_nodes())).float().view(-1,1) range(test_graph.number_of_nodes())).float().view(-1,1)
test_node_id = torch.arange(0, num_nodes, dtype=torch.long).view(-1, 1) test_node_id = torch.arange(0, num_nodes, dtype=torch.long).view(-1, 1)
test_rel = torch.from_numpy(test_rel) test_rel = torch.from_numpy(test_rel)
test_norm = torch.from_numpy(test_norm).view(-1, 1) test_norm = node_norm_to_edge_norm(test_graph, torch.from_numpy(test_norm).view(-1, 1))
test_graph.ndata.update({'id': test_node_id, 'norm': test_norm})
test_graph.edata['type'] = test_rel
if use_cuda: if use_cuda:
model.cuda() model.cuda()
...@@ -144,24 +147,24 @@ def main(args): ...@@ -144,24 +147,24 @@ def main(args):
g, node_id, edge_type, node_norm, data, labels = \ g, node_id, edge_type, node_norm, data, labels = \
utils.generate_sampled_graph_and_labels( utils.generate_sampled_graph_and_labels(
train_data, args.graph_batch_size, args.graph_split_size, train_data, args.graph_batch_size, args.graph_split_size,
num_rels, adj_list, degrees, args.negative_sample) num_rels, adj_list, degrees, args.negative_sample,
args.edge_sampler)
print("Done edge sampling") print("Done edge sampling")
# set node/edge feature # set node/edge feature
node_id = torch.from_numpy(node_id).view(-1, 1).long() node_id = torch.from_numpy(node_id).view(-1, 1).long()
edge_type = torch.from_numpy(edge_type) edge_type = torch.from_numpy(edge_type)
node_norm = torch.from_numpy(node_norm).view(-1, 1) edge_norm = node_norm_to_edge_norm(g, torch.from_numpy(node_norm).view(-1, 1))
data, labels = torch.from_numpy(data), torch.from_numpy(labels) data, labels = torch.from_numpy(data), torch.from_numpy(labels)
deg = g.in_degrees(range(g.number_of_nodes())).float().view(-1, 1) deg = g.in_degrees(range(g.number_of_nodes())).float().view(-1, 1)
if use_cuda: if use_cuda:
node_id, deg = node_id.cuda(), deg.cuda() node_id, deg = node_id.cuda(), deg.cuda()
edge_type, node_norm = edge_type.cuda(), node_norm.cuda() edge_type, edge_norm = edge_type.cuda(), edge_norm.cuda()
data, labels = data.cuda(), labels.cuda() data, labels = data.cuda(), labels.cuda()
g.ndata.update({'id': node_id, 'norm': node_norm})
g.edata['type'] = edge_type
t0 = time.time() t0 = time.time()
loss = model.get_loss(g, data, labels) embed = model(g, node_id, edge_type, edge_norm)
loss = model.get_loss(g, embed, data, labels)
t1 = time.time() t1 = time.time()
loss.backward() loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_norm) # clip gradients torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_norm) # clip gradients
...@@ -182,7 +185,8 @@ def main(args): ...@@ -182,7 +185,8 @@ def main(args):
model.cpu() model.cpu()
model.eval() model.eval()
print("start eval") print("start eval")
mrr = utils.evaluate(test_graph, model, valid_data, embed = model(test_graph, test_node_id, test_rel, test_norm)
mrr = utils.calc_mrr(embed, model.w_relation, valid_data,
hits=[1, 3, 10], eval_bz=args.eval_batch_size) hits=[1, 3, 10], eval_bz=args.eval_batch_size)
# save best model # save best model
if mrr < best_mrr: if mrr < best_mrr:
...@@ -207,9 +211,9 @@ def main(args): ...@@ -207,9 +211,9 @@ def main(args):
model.eval() model.eval()
model.load_state_dict(checkpoint['state_dict']) model.load_state_dict(checkpoint['state_dict'])
print("Using best epoch: {}".format(checkpoint['epoch'])) print("Using best epoch: {}".format(checkpoint['epoch']))
utils.evaluate(test_graph, model, test_data, hits=[1, 3, 10], embed = model(test_graph, test_node_id, test_rel, test_norm)
eval_bz=args.eval_batch_size) utils.calc_mrr(embed, model.w_relation, test_data,
hits=[1, 3, 10], eval_bz=args.eval_batch_size)
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser(description='RGCN') parser = argparse.ArgumentParser(description='RGCN')
...@@ -243,8 +247,9 @@ if __name__ == '__main__': ...@@ -243,8 +247,9 @@ if __name__ == '__main__':
help="number of negative samples per positive sample") help="number of negative samples per positive sample")
parser.add_argument("--evaluate-every", type=int, default=500, parser.add_argument("--evaluate-every", type=int, default=500,
help="perform evaluation every n epochs") help="perform evaluation every n epochs")
parser.add_argument("--edge-sampler", type=str, default="uniform",
help="type of edge sampler: 'uniform' or 'neighbor'")
args = parser.parse_args() args = parser.parse_args()
print(args) print(args)
main(args) main(args)
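For reference, a tiny self-contained sketch (toy graph, PyTorch backend) of what the new node_norm_to_edge_norm helper produces: each edge receives the 1/in-degree normalizer of its destination node, matching comp_deg_norm in utils.py. The helper is copied here verbatim so the snippet runs on its own:

```
import torch
from dgl import DGLGraph

def node_norm_to_edge_norm(g, node_norm):
    g = g.local_var()
    # convert to edge norm
    g.ndata['norm'] = node_norm
    g.apply_edges(lambda edges: {'norm': edges.dst['norm']})
    return g.edata['norm']

g = DGLGraph()
g.add_nodes(3)
g.add_edges([0, 1, 2], [1, 2, 2])       # node 1 has in-degree 1, node 2 has in-degree 2

in_deg = g.in_degrees(range(g.number_of_nodes())).float()
node_norm = 1.0 / in_deg
node_norm[torch.isinf(node_norm)] = 0.  # mirror comp_deg_norm: zero for isolated nodes
node_norm = node_norm.view(-1, 1)

edge_norm = node_norm_to_edge_norm(g, node_norm)
print(edge_norm)                        # [[1.0], [0.5], [0.5]], one entry per edge
```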
import torch.nn as nn import torch.nn as nn
class BaseRGCN(nn.Module): class BaseRGCN(nn.Module):
def __init__(self, num_nodes, h_dim, out_dim, num_rels, num_bases=-1, def __init__(self, num_nodes, h_dim, out_dim, num_rels, num_bases,
num_hidden_layers=1, dropout=0, use_cuda=False): num_hidden_layers=1, dropout=0,
use_self_loop=False, use_cuda=False):
super(BaseRGCN, self).__init__() super(BaseRGCN, self).__init__()
self.num_nodes = num_nodes self.num_nodes = num_nodes
self.h_dim = h_dim self.h_dim = h_dim
self.out_dim = out_dim self.out_dim = out_dim
self.num_rels = num_rels self.num_rels = num_rels
self.num_bases = num_bases self.num_bases = None if num_bases < 0 else num_bases
self.num_hidden_layers = num_hidden_layers self.num_hidden_layers = num_hidden_layers
self.dropout = dropout self.dropout = dropout
self.use_self_loop = use_self_loop
self.use_cuda = use_cuda self.use_cuda = use_cuda
# create rgcn layers # create rgcn layers
self.build_model() self.build_model()
# create initial features
self.features = self.create_features()
def build_model(self): def build_model(self):
self.layers = nn.ModuleList() self.layers = nn.ModuleList()
# i2h # i2h
...@@ -34,10 +33,6 @@ class BaseRGCN(nn.Module): ...@@ -34,10 +33,6 @@ class BaseRGCN(nn.Module):
if h2o is not None: if h2o is not None:
self.layers.append(h2o) self.layers.append(h2o)
# initialize feature for each node
def create_features(self):
return None
def build_input_layer(self): def build_input_layer(self):
return None return None
...@@ -47,10 +42,7 @@ class BaseRGCN(nn.Module): ...@@ -47,10 +42,7 @@ class BaseRGCN(nn.Module):
def build_output_layer(self): def build_output_layer(self):
return None return None
def forward(self, g): def forward(self, g, h, r, norm):
if self.features is not None:
g.ndata['id'] = self.features
for layer in self.layers: for layer in self.layers:
layer(g) h = layer(g, h, r, norm)
return g.ndata.pop('h') return h
...@@ -28,9 +28,11 @@ def get_adj_and_degrees(num_nodes, triplets): ...@@ -28,9 +28,11 @@ def get_adj_and_degrees(num_nodes, triplets):
return adj_list, degrees return adj_list, degrees
def sample_edge_neighborhood(adj_list, degrees, n_triplets, sample_size): def sample_edge_neighborhood(adj_list, degrees, n_triplets, sample_size):
""" Edge neighborhood sampling to reduce training graph size """Sample edges by neighborhool expansion.
"""
This guarantees that the sampled edges form a connected graph, which
may help deeper GNNs that require information from more than one hop.
"""
edges = np.zeros((sample_size), dtype=np.int32) edges = np.zeros((sample_size), dtype=np.int32)
#initialize #initialize
...@@ -69,16 +71,25 @@ def sample_edge_neighborhood(adj_list, degrees, n_triplets, sample_size): ...@@ -69,16 +71,25 @@ def sample_edge_neighborhood(adj_list, degrees, n_triplets, sample_size):
return edges return edges
def sample_edge_uniform(adj_list, degrees, n_triplets, sample_size):
"""Sample edges uniformly from all the edges."""
all_edges = np.arange(n_triplets)
return np.random.choice(all_edges, sample_size, replace=False)
def generate_sampled_graph_and_labels(triplets, sample_size, split_size, def generate_sampled_graph_and_labels(triplets, sample_size, split_size,
num_rels, adj_list, degrees, num_rels, adj_list, degrees,
negative_rate): negative_rate, sampler="uniform"):
"""Get training graph and signals """Get training graph and signals
First perform edge neighborhood sampling on graph, then perform negative First perform edge neighborhood sampling on graph, then perform negative
sampling to generate negative samples sampling to generate negative samples
""" """
# perform edge neighbor sampling # perform edge neighbor sampling
edges = sample_edge_neighborhood(adj_list, degrees, len(triplets), if sampler == "uniform":
sample_size) edges = sample_edge_uniform(adj_list, degrees, len(triplets), sample_size)
elif sampler == "neighbor":
edges = sample_edge_neighborhood(adj_list, degrees, len(triplets), sample_size)
else:
raise ValueError("Sampler type must be either 'uniform' or 'neighbor'.")
# relabel nodes to have consecutive node ids # relabel nodes to have consecutive node ids
edges = triplets[edges] edges = triplets[edges]
...@@ -108,6 +119,7 @@ def generate_sampled_graph_and_labels(triplets, sample_size, split_size, ...@@ -108,6 +119,7 @@ def generate_sampled_graph_and_labels(triplets, sample_size, split_size,
return g, uniq_v, rel, norm, samples, labels return g, uniq_v, rel, norm, samples, labels
def comp_deg_norm(g): def comp_deg_norm(g):
g = g.local_var()
in_deg = g.in_degrees(range(g.number_of_nodes())).float().numpy() in_deg = g.in_degrees(range(g.number_of_nodes())).float().numpy()
norm = 1.0 / in_deg norm = 1.0 / in_deg
norm[np.isinf(norm)] = 0 norm[np.isinf(norm)] = 0
...@@ -187,9 +199,8 @@ def perturb_and_get_rank(embedding, w, a, r, b, test_size, batch_size=100): ...@@ -187,9 +199,8 @@ def perturb_and_get_rank(embedding, w, a, r, b, test_size, batch_size=100):
# TODO (lingfan): implement filtered metrics # TODO (lingfan): implement filtered metrics
# return MRR (raw), and Hits @ (1, 3, 10) # return MRR (raw), and Hits @ (1, 3, 10)
def evaluate(test_graph, model, test_triplets, hits=[], eval_bz=100): def calc_mrr(embedding, w, test_triplets, hits=[], eval_bz=100):
with torch.no_grad(): with torch.no_grad():
embedding, w = model.evaluate(test_graph)
s = test_triplets[:, 0] s = test_triplets[:, 0]
r = test_triplets[:, 1] r = test_triplets[:, 1]
o = test_triplets[:, 2] o = test_triplets[:, 2]
...@@ -210,4 +221,3 @@ def evaluate(test_graph, model, test_triplets, hits=[], eval_bz=100): ...@@ -210,4 +221,3 @@ def evaluate(test_graph, model, test_triplets, hits=[], eval_bz=100):
avg_count = torch.mean((ranks <= hit).float()) avg_count = torch.mean((ranks <= hit).float())
print("Hits (raw) @ {}: {:.6f}".format(hit, avg_count.item())) print("Hits (raw) @ {}: {:.6f}".format(hit, avg_count.item()))
return mrr.item() return mrr.item()
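As a reminder of what the raw metrics mean, here is a short self-contained sketch (hypothetical ranks, not tied to any dataset) of how raw MRR and Hits@k are computed:

```
import torch

# hypothetical ranks of the true entities among all candidates (1 = best)
ranks = torch.tensor([1, 3, 12, 2]).float()

mrr = torch.mean(1.0 / ranks)                     # raw mean reciprocal rank
print("MRR (raw): {:.6f}".format(mrr.item()))     # (1 + 1/3 + 1/12 + 1/2) / 4 ~= 0.479167

for hit in [1, 3, 10]:
    avg_count = torch.mean((ranks <= hit).float())
    print("Hits (raw) @ {}: {:.6f}".format(hit, avg_count.item()))
```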
"""MXNet modules for graph convolutions.""" """MXNet modules for graph convolutions."""
# pylint: disable= no-member, arguments-differ # pylint: disable= no-member, arguments-differ
import math
import mxnet as mx import mxnet as mx
from mxnet import gluon from mxnet import gluon, nd
from mxnet.gluon import nn
import numpy as np
from . import utils
from ... import function as fn from ... import function as fn
__all__ = ['GraphConv'] __all__ = ['GraphConv', 'RelGraphConv']
class GraphConv(gluon.Block): class GraphConv(gluon.Block):
r"""Apply graph convolution over an input signal. r"""Apply graph convolution over an input signal.
...@@ -142,3 +146,191 @@ class GraphConv(gluon.Block): ...@@ -142,3 +146,191 @@ class GraphConv(gluon.Block):
self._norm, self._activation) self._norm, self._activation)
summary += '\n)' summary += '\n)'
return summary return summary
class RelGraphConv(gluon.Block):
r"""Relational graph convolution layer.
Relational graph convolution is introduced in "`Modeling Relational Data with Graph
Convolutional Networks <https://arxiv.org/abs/1703.06103>`__"
and can be described as follows:
.. math::
h_i^{(l+1)} = \sigma(\sum_{r\in\mathcal{R}}
\sum_{j\in\mathcal{N}^r(i)}\frac{1}{c_{i,r}}W_r^{(l)}h_j^{(l)}+W_0^{(l)}h_i^{(l)})
where :math:`\mathcal{N}^r(i)` is the neighbor set of node :math:`i` w.r.t. relation
:math:`r`. :math:`c_{i,r}` is the normalizer equal
to :math:`|\mathcal{N}^r(i)|`. :math:`\sigma` is an activation function. :math:`W_0`
is the self-loop weight.
The basis regularization decomposes :math:`W_r` by:
.. math::
W_r^{(l)} = \sum_{b=1}^B a_{rb}^{(l)}V_b^{(l)}
where :math:`B` is the number of bases.
The block-diagonal-decomposition regularization decomposes :math:`W_r` into :math:`B`
diagonal blocks. We refer to :math:`B` as the number of bases.
Parameters
----------
in_feat : int
Input feature size.
out_feat : int
Output feature size.
num_rels : int
Number of relations.
regularizer : str
Which weight regularizer to use ("basis" or "bdd").
num_bases : int, optional
Number of bases. If None, use the number of relations. Default: None.
bias : bool, optional
True if bias is added. Default: True
activation : callable, optional
Activation function. Default: None
self_loop : bool, optional
True to include self loop message. Default: False
dropout : float, optional
Dropout rate. Default: 0.0
"""
def __init__(self,
in_feat,
out_feat,
num_rels,
regularizer="basis",
num_bases=None,
bias=True,
activation=None,
self_loop=False,
dropout=0.0):
super(RelGraphConv, self).__init__()
self.in_feat = in_feat
self.out_feat = out_feat
self.num_rels = num_rels
self.regularizer = regularizer
self.num_bases = num_bases
if self.num_bases is None or self.num_bases > self.num_rels or self.num_bases < 0:
self.num_bases = self.num_rels
self.bias = bias
self.activation = activation
self.self_loop = self_loop
if regularizer == "basis":
# add basis weights
self.weight = self.params.get(
'weight', shape=(self.num_bases, self.in_feat, self.out_feat),
init=mx.init.Xavier(magnitude=math.sqrt(2.0)))
if self.num_bases < self.num_rels:
# linear combination coefficients
self.w_comp = self.params.get(
'w_comp', shape=(self.num_rels, self.num_bases),
init=mx.init.Xavier(magnitude=math.sqrt(2.0)))
# message func
self.message_func = self.basis_message_func
elif regularizer == "bdd":
if in_feat % num_bases != 0 or out_feat % num_bases != 0:
raise ValueError('Feature size must be a multiple of num_bases.')
# add block diagonal weights
self.submat_in = in_feat // self.num_bases
self.submat_out = out_feat // self.num_bases
# assuming in_feat and out_feat are both divisible by num_bases
self.weight = self.params.get(
'weight',
shape=(self.num_rels, self.num_bases * self.submat_in * self.submat_out),
init=mx.init.Xavier(magnitude=math.sqrt(2.0)))
# message func
self.message_func = self.bdd_message_func
else:
raise ValueError("Regularizer must be either 'basis' or 'bdd'")
# bias
if self.bias:
self.h_bias = self.params.get('bias', shape=(out_feat,),
init=mx.init.Zero())
# weight for self loop
if self.self_loop:
self.loop_weight = self.params.get(
'W_0', shape=(in_feat, out_feat),
init=mx.init.Xavier(magnitude=math.sqrt(2.0)))
self.dropout = nn.Dropout(dropout)
def basis_message_func(self, edges):
"""Message function for basis regularizer"""
ctx = edges.src['h'].context
if self.num_bases < self.num_rels:
# generate all weights from bases
weight = self.weight.data(ctx).reshape(
self.num_bases, self.in_feat * self.out_feat)
weight = nd.dot(self.w_comp.data(ctx), weight).reshape(
self.num_rels, self.in_feat, self.out_feat)
else:
weight = self.weight.data(ctx)
msg = utils.bmm_maybe_select(edges.src['h'], weight, edges.data['type'])
if 'norm' in edges.data:
msg = msg * edges.data['norm']
return {'msg': msg}
def bdd_message_func(self, edges):
"""Message function for block-diagonal-decomposition regularizer"""
ctx = edges.src['h'].context
if edges.src['h'].dtype in (np.int32, np.int64) and len(edges.src['h'].shape) == 1:
raise TypeError('Block decomposition does not allow integer ID feature.')
weight = self.weight.data(ctx)[edges.data['type'], :].reshape(
-1, self.submat_in, self.submat_out)
node = edges.src['h'].reshape(-1, 1, self.submat_in)
msg = nd.batch_dot(node, weight).reshape(-1, self.out_feat)
if 'norm' in edges.data:
msg = msg * edges.data['norm']
return {'msg': msg}
def forward(self, g, x, etypes, norm=None):
"""Forward computation
Parameters
----------
g : DGLGraph
The graph.
x : mx.ndarray.NDArray
Input node features. Could be either
- (|V|, D) dense tensor
- (|V|,) int64 vector, representing the categorical values of each
node. We then treat the input feature as a one-hot encoding feature.
etypes : mx.ndarray.NDArray
Edge type tensor. Shape: (|E|,)
norm : mx.ndarray.NDArray
Optional edge normalizer tensor. Shape: (|E|, 1)
Returns
-------
mx.ndarray.NDArray
New node features.
"""
g = g.local_var()
g.ndata['h'] = x
g.edata['type'] = etypes
if norm is not None:
g.edata['norm'] = norm
if self.self_loop:
loop_message = utils.matmul_maybe_select(x, self.loop_weight.data(x.context))
# message passing
g.update_all(self.message_func, fn.sum(msg='msg', out='h'))
# apply bias and activation
node_repr = g.ndata['h']
if self.bias:
node_repr = node_repr + self.h_bias.data(x.context)
if self.self_loop:
node_repr = node_repr + loop_message
if self.activation:
node_repr = self.activation(node_repr)
node_repr = self.dropout(node_repr)
return node_repr
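A minimal usage sketch for the MXNet RelGraphConv (the random graph, sizes, and relation count are arbitrary; the pattern follows the unit test added in this commit):

```
import scipy.sparse as spsp
from mxnet import nd
import dgl
from dgl.nn.mxnet import RelGraphConv

# random readonly graph with 5 relation types
g = dgl.DGLGraph(spsp.random(100, 100, density=0.1), readonly=True)
etypes = nd.array([i % 5 for i in range(g.number_of_edges())])

conv = RelGraphConv(10, 8, 5, regularizer="basis", num_bases=2)
conv.initialize()

# dense node features of size 10 -> output of shape (100, 8)
h = nd.random.randn(100, 10)
h_new = conv(g, h, etypes)

# featureless case: integer node IDs are treated as one-hot inputs
ids = nd.random.randint(0, 10, (100,))
h_new_id = conv(g, ids, etypes)
```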
"""Utilities for pytorch NN package"""
#pylint: disable=no-member, invalid-name
from mxnet import nd
import numpy as np
def matmul_maybe_select(A, B):
"""Perform Matrix multiplication C = A * B but A could be an integer id vector.
If A is an integer vector, we treat it as multiplying a one-hot encoded tensor.
In this case, the expensive dense matrix multiply can be replaced by a much
cheaper index lookup.
For example,
::
A = [2, 0, 1],
B = [[0.1, 0.2],
[0.3, 0.4],
[0.5, 0.6]]
then matmul_maybe_select(A, B) is equivalent to
::
[[0, 0, 1], [[0.1, 0.2],
[1, 0, 0], * [0.3, 0.4],
[0, 1, 0]] [0.5, 0.6]]
In all other cases, perform a normal matmul.
Parameters
----------
A : mx.ndarray.NDArray
lhs tensor
B : mx.ndarray.NDArray
rhs tensor
Returns
-------
C : mx.ndarray.NDArray
result tensor
"""
if A.dtype in (np.int32, np.int64) and len(A.shape) == 1:
return nd.take(B, A, axis=0)
else:
return nd.dot(A, B)
def bmm_maybe_select(A, B, index):
"""Slice submatrices of A by the given index and perform bmm.
B is a 3D tensor of shape (N, D1, D2), which can be viewed as a stack of
N matrices of shape (D1, D2). The input index is an integer vector of length M.
A could be either:
(1) a dense tensor of shape (M, D1),
(2) an integer vector of length M.
The result C is a 2D matrix of shape (M, D2)
For case (1), C is computed by bmm:
::
C[i, :] = matmul(A[i, :], B[index[i], :, :])
For case (2), C is computed by index select:
::
C[i, :] = B[index[i], A[i], :]
Parameters
----------
A : mx.ndarray.NDArray
lhs tensor
B : mx.ndarray.NDArray
rhs tensor
index : mx.ndarray.NDArray
index tensor
Returns
-------
C : mx.ndarray.NDArray
return tensor
"""
if A.dtype in (np.int32, np.int64) and len(A.shape) == 1:
return B[index, A, :]
else:
BB = nd.take(B, index, axis=0)
return nd.batch_dot(A.expand_dims(1), BB).squeeze()
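The docstring example for matmul_maybe_select can be checked directly. A short sketch, assuming the helper is importable as dgl.nn.mxnet.utils (as the relative import in conv.py suggests):

```
import numpy as np
from mxnet import nd
from dgl.nn.mxnet.utils import matmul_maybe_select

A = nd.array([2, 0, 1], dtype='int64')
B = nd.array([[0.1, 0.2],
              [0.3, 0.4],
              [0.5, 0.6]])

print(matmul_maybe_select(A, B))      # rows 2, 0, 1 of B
print(nd.take(B, A, axis=0))          # same result, written explicitly
```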
...@@ -4,9 +4,10 @@ import torch as th ...@@ -4,9 +4,10 @@ import torch as th
from torch import nn from torch import nn
from torch.nn import init from torch.nn import init
from . import utils
from ... import function as fn from ... import function as fn
__all__ = ['GraphConv'] __all__ = ['GraphConv', 'RelGraphConv']
class GraphConv(nn.Module): class GraphConv(nn.Module):
r"""Apply graph convolution over an input signal. r"""Apply graph convolution over an input signal.
...@@ -148,3 +149,188 @@ class GraphConv(nn.Module): ...@@ -148,3 +149,188 @@ class GraphConv(nn.Module):
if '_activation' in self.__dict__: if '_activation' in self.__dict__:
summary += ', activation={_activation}' summary += ', activation={_activation}'
return summary.format(**self.__dict__) return summary.format(**self.__dict__)
class RelGraphConv(nn.Module):
r"""Relational graph convolution layer.
Relational graph convolution is introduced in "`Modeling Relational Data with Graph
Convolutional Networks <https://arxiv.org/abs/1703.06103>`__"
and can be described as follows:
.. math::
h_i^{(l+1)} = \sigma(\sum_{r\in\mathcal{R}}
\sum_{j\in\mathcal{N}^r(i)}\frac{1}{c_{i,r}}W_r^{(l)}h_j^{(l)}+W_0^{(l)}h_i^{(l)})
where :math:`\mathcal{N}^r(i)` is the neighbor set of node :math:`i` w.r.t. relation
:math:`r`. :math:`c_{i,r}` is the normalizer equal
to :math:`|\mathcal{N}^r(i)|`. :math:`\sigma` is an activation function. :math:`W_0`
is the self-loop weight.
The basis regularization decomposes :math:`W_r` by:
.. math::
W_r^{(l)} = \sum_{b=1}^B a_{rb}^{(l)}V_b^{(l)}
where :math:`B` is the number of bases.
The block-diagonal-decomposition regularization decomposes :math:`W_r` into :math:`B`
diagonal blocks. We refer to :math:`B` as the number of bases.
Parameters
----------
in_feat : int
Input feature size.
out_feat : int
Output feature size.
num_rels : int
Number of relations.
regularizer : str
Which weight regularizer to use ("basis" or "bdd").
num_bases : int, optional
Number of bases. If None, use the number of relations. Default: None.
bias : bool, optional
True if bias is added. Default: True
activation : callable, optional
Activation function. Default: None
self_loop : bool, optional
True to include self loop message. Default: False
dropout : float, optional
Dropout rate. Default: 0.0
"""
def __init__(self,
in_feat,
out_feat,
num_rels,
regularizer="basis",
num_bases=None,
bias=True,
activation=None,
self_loop=False,
dropout=0.0):
super(RelGraphConv, self).__init__()
self.in_feat = in_feat
self.out_feat = out_feat
self.num_rels = num_rels
self.regularizer = regularizer
self.num_bases = num_bases
if self.num_bases is None or self.num_bases > self.num_rels or self.num_bases < 0:
self.num_bases = self.num_rels
self.bias = bias
self.activation = activation
self.self_loop = self_loop
if regularizer == "basis":
# add basis weights
self.weight = nn.Parameter(th.Tensor(self.num_bases, self.in_feat, self.out_feat))
if self.num_bases < self.num_rels:
# linear combination coefficients
self.w_comp = nn.Parameter(th.Tensor(self.num_rels, self.num_bases))
nn.init.xavier_uniform_(self.weight, gain=nn.init.calculate_gain('relu'))
if self.num_bases < self.num_rels:
nn.init.xavier_uniform_(self.w_comp,
gain=nn.init.calculate_gain('relu'))
# message func
self.message_func = self.basis_message_func
elif regularizer == "bdd":
if in_feat % num_bases != 0 or out_feat % num_bases != 0:
raise ValueError('Feature size must be a multiple of num_bases.')
# add block diagonal weights
self.submat_in = in_feat // self.num_bases
self.submat_out = out_feat // self.num_bases
# assuming in_feat and out_feat are both divisible by num_bases
self.weight = nn.Parameter(th.Tensor(
self.num_rels, self.num_bases * self.submat_in * self.submat_out))
nn.init.xavier_uniform_(self.weight, gain=nn.init.calculate_gain('relu'))
# message func
self.message_func = self.bdd_message_func
else:
raise ValueError("Regularizer must be either 'basis' or 'bdd'")
# bias
if self.bias:
self.h_bias = nn.Parameter(th.Tensor(out_feat))
nn.init.zeros_(self.h_bias)
# weight for self loop
if self.self_loop:
self.loop_weight = nn.Parameter(th.Tensor(in_feat, out_feat))
nn.init.xavier_uniform_(self.loop_weight,
gain=nn.init.calculate_gain('relu'))
self.dropout = nn.Dropout(dropout)
def basis_message_func(self, edges):
"""Message function for basis regularizer"""
if self.num_bases < self.num_rels:
# generate all weights from bases
weight = self.weight.view(self.num_bases,
self.in_feat * self.out_feat)
weight = th.matmul(self.w_comp, weight).view(
self.num_rels, self.in_feat, self.out_feat)
else:
weight = self.weight
msg = utils.bmm_maybe_select(edges.src['h'], weight, edges.data['type'])
if 'norm' in edges.data:
msg = msg * edges.data['norm']
return {'msg': msg}
def bdd_message_func(self, edges):
"""Message function for block-diagonal-decomposition regularizer"""
if edges.src['h'].dtype == th.int64 and len(edges.src['h'].shape) == 1:
raise TypeError('Block decomposition does not allow integer ID feature.')
weight = self.weight.index_select(0, edges.data['type']).view(
-1, self.submat_in, self.submat_out)
node = edges.src['h'].view(-1, 1, self.submat_in)
msg = th.bmm(node, weight).view(-1, self.out_feat)
if 'norm' in edges.data:
msg = msg * edges.data['norm']
return {'msg': msg}
def forward(self, g, x, etypes, norm=None):
"""Forward computation
Parameters
----------
g : DGLGraph
The graph.
x : torch.Tensor
Input node features. Could be either
- (|V|, D) dense tensor
- (|V|,) int64 vector, representing the categorical values of each
node. We then treat the input feature as a one-hot encoding feature.
etypes : torch.Tensor
Edge type tensor. Shape: (|E|,)
norm : torch.Tensor
Optional edge normalizer tensor. Shape: (|E|, 1)
Returns
-------
torch.Tensor
New node features.
"""
g = g.local_var()
g.ndata['h'] = x
g.edata['type'] = etypes
if norm is not None:
g.edata['norm'] = norm
if self.self_loop:
loop_message = utils.matmul_maybe_select(x, self.loop_weight)
# message passing
g.update_all(self.message_func, fn.sum(msg='msg', out='h'))
# apply bias and activation
node_repr = g.ndata['h']
if self.bias:
node_repr = node_repr + self.h_bias
if self.self_loop:
node_repr = node_repr + loop_message
if self.activation:
node_repr = self.activation(node_repr)
node_repr = self.dropout(node_repr)
return node_repr
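A minimal usage sketch for the PyTorch RelGraphConv (random graph and sizes are arbitrary, mirroring the new unit test). Note that "bdd" requires the feature sizes to be divisible by num_bases, and integer node-ID inputs are only supported by "basis":

```
import torch as th
import scipy.sparse as spsp
import dgl
from dgl.nn.pytorch import RelGraphConv

g = dgl.DGLGraph(spsp.random(100, 100, density=0.1), readonly=True)
etypes = th.tensor([i % 5 for i in range(g.number_of_edges())])

# basis decomposition: 5 relations share 2 bases
conv_basis = RelGraphConv(10, 8, 5, regularizer="basis", num_bases=2)
h = th.randn(100, 10)
out = conv_basis(g, h, etypes)            # (100, 8)

# block-diagonal decomposition: each W_r is 2 blocks of shape (5, 4)
conv_bdd = RelGraphConv(10, 8, 5, regularizer="bdd", num_bases=2)
out = conv_bdd(g, h, etypes)              # (100, 8)
```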
"""Utilities for pytorch NN package"""
#pylint: disable=no-member, invalid-name
import torch as th
def matmul_maybe_select(A, B):
"""Perform Matrix multiplication C = A * B but A could be an integer id vector.
If A is an integer vector, we treat it as multiplying a one-hot encoded tensor.
In this case, the expensive dense matrix multiply can be replaced by a much
cheaper index lookup.
For example,
::
A = [2, 0, 1],
B = [[0.1, 0.2],
[0.3, 0.4],
[0.5, 0.6]]
then matmul_maybe_select(A, B) is equivalent to
::
[[0, 0, 1], [[0.1, 0.2],
[1, 0, 0], * [0.3, 0.4],
[0, 1, 0]] [0.5, 0.6]]
In all other cases, perform a normal matmul.
Parameters
----------
A : torch.Tensor
lhs tensor
B : torch.Tensor
rhs tensor
Returns
-------
C : torch.Tensor
result tensor
"""
if A.dtype == th.int64 and len(A.shape) == 1:
return B.index_select(0, A)
else:
return th.matmul(A, B)
def bmm_maybe_select(A, B, index):
"""Slice submatrices of A by the given index and perform bmm.
B is a 3D tensor of shape (N, D1, D2), which can be viewed as a stack of
N matrices of shape (D1, D2). The input index is an integer vector of length M.
A could be either:
(1) a dense tensor of shape (M, D1),
(2) an integer vector of length M.
The result C is a 2D matrix of shape (M, D2)
For case (1), C is computed by bmm:
::
C[i, :] = matmul(A[i, :], B[index[i], :, :])
For case (2), C is computed by index select:
::
C[i, :] = B[index[i], A[i], :]
Parameters
----------
A : torch.Tensor
lhs tensor
B : torch.Tensor
rhs tensor
index : torch.Tensor
index tensor
Returns
-------
C : torch.Tensor
return tensor
"""
if A.dtype == th.int64 and len(A.shape) == 1:
# following is a faster version of B[index, A, :]
flatidx = index * B.shape[1] + A
return B.view(-1, B.shape[2]).index_select(0, flatidx)
else:
BB = B.index_select(0, index)
return th.bmm(A.unsqueeze(1), BB).squeeze()
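The "faster version of B[index, A, :]" comment can be verified directly; a short self-contained check with hypothetical sizes:

```
import torch as th

N, D1, D2, M = 4, 5, 3, 6                 # hypothetical sizes
B = th.randn(N, D1, D2)
index = th.randint(0, N, (M,))            # which matrix to use per row
A = th.randint(0, D1, (M,))               # integer "ID" features

# fast path: flatten B and gather rows index * D1 + A
fast = B.view(-1, D2).index_select(0, index * D1 + A)

# reference: plain advanced indexing
assert th.equal(fast, B[index, A, :])
```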
import mxnet as mx import mxnet as mx
import networkx as nx import networkx as nx
import numpy as np import numpy as np
import scipy as sp
import dgl import dgl
import dgl.nn.mxnet as nn import dgl.nn.mxnet as nn
import backend as F import backend as F
from mxnet import autograd, gluon from mxnet import autograd, gluon, nd
def check_close(a, b): def check_close(a, b):
assert np.allclose(a.asnumpy(), b.asnumpy(), rtol=1e-4, atol=1e-4) assert np.allclose(a.asnumpy(), b.asnumpy(), rtol=1e-4, atol=1e-4)
...@@ -182,9 +183,61 @@ def test_edge_softmax(): ...@@ -182,9 +183,61 @@ def test_edge_softmax():
assert np.allclose(a.asnumpy(), uniform_attention(g, a.shape).asnumpy(), assert np.allclose(a.asnumpy(), uniform_attention(g, a.shape).asnumpy(),
1e-4, 1e-4) 1e-4, 1e-4)
def test_rgcn():
ctx = F.ctx()
etype = []
g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True)
# 5 etypes
R = 5
for i in range(g.number_of_edges()):
etype.append(i % 5)
B = 2
I = 10
O = 8
rgc_basis = nn.RelGraphConv(I, O, R, "basis", B)
rgc_basis.initialize(ctx=ctx)
h = nd.random.randn(100, I, ctx=ctx)
r = nd.array(etype, ctx=ctx)
h_new = rgc_basis(g, h, r)
assert list(h_new.shape) == [100, O]
rgc_bdd = nn.RelGraphConv(I, O, R, "bdd", B)
rgc_bdd.initialize(ctx=ctx)
h = nd.random.randn(100, I, ctx=ctx)
r = nd.array(etype, ctx=ctx)
h_new = rgc_bdd(g, h, r)
assert list(h_new.shape) == [100, O]
# with norm
norm = nd.zeros((g.number_of_edges(), 1), ctx=ctx)
rgc_basis = nn.RelGraphConv(I, O, R, "basis", B)
rgc_basis.initialize(ctx=ctx)
h = nd.random.randn(100, I, ctx=ctx)
r = nd.array(etype, ctx=ctx)
h_new = rgc_basis(g, h, r, norm)
assert list(h_new.shape) == [100, O]
rgc_bdd = nn.RelGraphConv(I, O, R, "bdd", B)
rgc_bdd.initialize(ctx=ctx)
h = nd.random.randn(100, I, ctx=ctx)
r = nd.array(etype, ctx=ctx)
h_new = rgc_bdd(g, h, r, norm)
assert list(h_new.shape) == [100, O]
# id input
rgc_basis = nn.RelGraphConv(I, O, R, "basis", B)
rgc_basis.initialize(ctx=ctx)
h = nd.random.randint(0, I, (100,), ctx=ctx)
r = nd.array(etype, ctx=ctx)
h_new = rgc_basis(g, h, r)
assert list(h_new.shape) == [100, O]
if __name__ == '__main__': if __name__ == '__main__':
test_graph_conv() test_graph_conv()
test_edge_softmax() test_edge_softmax()
test_set2set() test_set2set()
test_glob_att_pool() test_glob_att_pool()
test_simple_pool() test_simple_pool()
test_rgcn()
...@@ -260,6 +260,51 @@ def test_edge_softmax(): ...@@ -260,6 +260,51 @@ def test_edge_softmax():
assert len(g.edata) == 2 assert len(g.edata) == 2
assert F.allclose(a1.grad, a2.grad, rtol=1e-4, atol=1e-4) # Follow tolerance in unittest backend assert F.allclose(a1.grad, a2.grad, rtol=1e-4, atol=1e-4) # Follow tolerance in unittest backend
def test_rgcn():
ctx = F.ctx()
etype = []
g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True)
# 5 etypes
R = 5
for i in range(g.number_of_edges()):
etype.append(i % 5)
B = 2
I = 10
O = 8
rgc_basis = nn.RelGraphConv(I, O, R, "basis", B).to(ctx)
h = th.randn((100, I)).to(ctx)
r = th.tensor(etype).to(ctx)
h_new = rgc_basis(g, h, r)
assert list(h_new.shape) == [100, O]
rgc_bdd = nn.RelGraphConv(I, O, R, "bdd", B).to(ctx)
h = th.randn((100, I)).to(ctx)
r = th.tensor(etype).to(ctx)
h_new = rgc_bdd(g, h, r)
assert list(h_new.shape) == [100, O]
# with norm
norm = th.zeros((g.number_of_edges(), 1)).to(ctx)
rgc_basis = nn.RelGraphConv(I, O, R, "basis", B).to(ctx)
h = th.randn((100, I)).to(ctx)
r = th.tensor(etype).to(ctx)
h_new = rgc_basis(g, h, r, norm)
assert list(h_new.shape) == [100, O]
rgc_bdd = nn.RelGraphConv(I, O, R, "bdd", B).to(ctx)
h = th.randn((100, I)).to(ctx)
r = th.tensor(etype).to(ctx)
h_new = rgc_bdd(g, h, r, norm)
assert list(h_new.shape) == [100, O]
# id input
rgc_basis = nn.RelGraphConv(I, O, R, "basis", B).to(ctx)
h = th.randint(0, I, (100,)).to(ctx)
r = th.tensor(etype).to(ctx)
h_new = rgc_basis(g, h, r)
assert list(h_new.shape) == [100, O]
if __name__ == '__main__': if __name__ == '__main__':
test_graph_conv() test_graph_conv()
...@@ -268,3 +313,4 @@ if __name__ == '__main__': ...@@ -268,3 +313,4 @@ if __name__ == '__main__':
test_glob_att_pool() test_glob_att_pool()
test_simple_pool() test_simple_pool()
test_set_trans() test_set_trans()
test_rgcn()