Unverified commit 2c5b48ab authored by Yizhi Liu, committed by GitHub

[Model][MXNet] RGCN Entity Classification (#246)

* entity classification working for examples

* add loop_msg

* remove wrong assert

* remove one reshape

* add readme

* add MRR

* remove MRR from entity task
parent 7c7cc7e0
# Relational-GCN
### Prerequisites
Two extra python packages are needed for this example:
- rdflib
- pandas

The example code was tested with rdflib 4.2.2 and pandas 0.23.4.
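One way to install them with pip (the pinned versions are simply the ones the example was tested with):
```
pip install rdflib==4.2.2 pandas==0.23.4
```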
### Entity Classification
AIFB:
```
DGLBACKEND=mxnet python entity_classify.py -d aifb --testing --gpu 0
```
MUTAG:
```
DGLBACKEND=mxnet python entity_classify.py -d mutag --l2norm 5e-4 --n-bases 40 --testing --gpu 0
```
BGS:
```
DGLBACKEND=mxnet python entity_classify.py -d bgs --l2norm 5e-4 --n-bases 20 --testing --gpu 0 --relabel
```
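`--gpu 0` runs on the first GPU; omitting it (the default is `--gpu -1`) runs on CPU, e.g.:
```
DGLBACKEND=mxnet python entity_classify.py -d aifb --testing
```
By default a fifth of the training indices is held out for validation; `--testing` instead trains on the full training split and reports test-set accuracy (see `entity_classify.py` below).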
"""
Modeling Relational Data with Graph Convolutional Networks
Paper: https://arxiv.org/abs/1703.06103
Code: https://github.com/tkipf/relational-gcn
Differences compared to tkipf/relational-gcn:
* l2norm applied to all weights
* remove nodes that won't be touched
"""
import argparse
import numpy as np
import time
import mxnet as mx
from mxnet import gluon
import mxnet.ndarray as F
from dgl import DGLGraph
from dgl.contrib.data import load_data
from functools import partial
from model import BaseRGCN
from layers import RGCNBasisLayer as RGCNLayer

class EntityClassify(BaseRGCN):
    def create_features(self):
        features = mx.nd.arange(self.num_nodes)
        if self.gpu_id >= 0:
            features = features.as_in_context(mx.gpu(self.gpu_id))
        return features

    def build_input_layer(self):
        return RGCNLayer(self.num_nodes, self.h_dim, self.num_rels, self.num_bases,
                         activation=F.relu, is_input_layer=True)

    def build_hidden_layer(self, idx):
        return RGCNLayer(self.h_dim, self.h_dim, self.num_rels, self.num_bases,
                         activation=F.relu)

    def build_output_layer(self):
        return RGCNLayer(self.h_dim, self.out_dim, self.num_rels, self.num_bases,
                         activation=partial(F.softmax, axis=1))

def main(args):
    # load graph data
    data = load_data(args.dataset, bfs_level=args.bfs_level, relabel=args.relabel)
    num_nodes = data.num_nodes
    num_rels = data.num_rels
    num_classes = data.num_classes
    labels = data.labels
    train_idx = data.train_idx
    test_idx = data.test_idx

    # split dataset into train, validate, test
    if args.validation:
        val_idx = train_idx[:len(train_idx) // 5]
        train_idx = train_idx[len(train_idx) // 5:]
    else:
        val_idx = train_idx
    train_idx = mx.nd.array(train_idx)

    # edge type and normalization factor
    edge_type = mx.nd.array(data.edge_type)
    edge_norm = mx.nd.array(data.edge_norm).expand_dims(1)
    labels = mx.nd.array(labels).reshape((-1,))

    # check cuda
    use_cuda = args.gpu >= 0
    if use_cuda:
        ctx = mx.gpu(args.gpu)
        edge_type = edge_type.as_in_context(ctx)
        edge_norm = edge_norm.as_in_context(ctx)
        labels = labels.as_in_context(ctx)
        train_idx = train_idx.as_in_context(ctx)
    else:
        ctx = mx.cpu(0)

    # create graph
    g = DGLGraph()
    g.add_nodes(num_nodes)
    g.add_edges(data.edge_src, data.edge_dst)
    g.edata.update({'type': edge_type, 'norm': edge_norm})

    # create model
    model = EntityClassify(len(g),
                           args.n_hidden,
                           num_classes,
                           num_rels,
                           num_bases=args.n_bases,
                           num_hidden_layers=args.n_layers - 2,
                           dropout=args.dropout,
                           gpu_id=args.gpu)
    model.initialize(ctx=ctx)

    # optimizer
    trainer = gluon.Trainer(model.collect_params(), 'adam',
                            {'learning_rate': args.lr, 'wd': args.l2norm})
    loss_fcn = gluon.loss.SoftmaxCELoss(from_logits=False)

    # training loop
    print("start training...")
    forward_time = []
    backward_time = []
    for epoch in range(args.n_epochs):
        t0 = time.time()
        with mx.autograd.record():
            pred = model(g)
            loss = loss_fcn(pred[train_idx], labels[train_idx])
        t1 = time.time()
        loss.backward()
        trainer.step(len(train_idx))
        t2 = time.time()

        forward_time.append(t1 - t0)
        backward_time.append(t2 - t1)
        print("Epoch {:05d} | Train Forward Time(s) {:.4f} | Backward Time(s) {:.4f}".
              format(epoch, forward_time[-1], backward_time[-1]))
        train_acc = F.sum(pred[train_idx].argmax(axis=1) == labels[train_idx]).asscalar() / train_idx.shape[0]
        val_acc = F.sum(pred[val_idx].argmax(axis=1) == labels[val_idx]).asscalar() / len(val_idx)
        print("Train Accuracy: {:.4f} | Validation Accuracy: {:.4f}".format(train_acc, val_acc))
    print()

    logits = model(g)
    test_acc = F.sum(logits[test_idx].argmax(axis=1) == labels[test_idx]).asscalar() / len(test_idx)
    print("Test Accuracy: {:.4f}".format(test_acc))
    print()
    print("Mean forward time: {:4f}".format(np.mean(forward_time[len(forward_time) // 4:])))
    print("Mean backward time: {:4f}".format(np.mean(backward_time[len(backward_time) // 4:])))

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='RGCN')
    parser.add_argument("--dropout", type=float, default=0,
                        help="dropout probability")
    parser.add_argument("--n-hidden", type=int, default=16,
                        help="number of hidden units")
    parser.add_argument("--gpu", type=int, default=-1,
                        help="gpu")
    parser.add_argument("--lr", type=float, default=1e-2,
                        help="learning rate")
    parser.add_argument("--n-bases", type=int, default=-1,
                        help="number of filter weight matrices, default: -1 [use all]")
    parser.add_argument("--n-layers", type=int, default=2,
                        help="number of propagation rounds")
    parser.add_argument("-e", "--n-epochs", type=int, default=50,
                        help="number of training epochs")
    parser.add_argument("-d", "--dataset", type=str, required=True,
                        help="dataset to use")
    parser.add_argument("--l2norm", type=float, default=0,
                        help="l2 norm coef")
    parser.add_argument("--relabel", default=False, action='store_true',
                        help="remove untouched nodes and relabel")
    fp = parser.add_mutually_exclusive_group(required=False)
    fp.add_argument('--validation', dest='validation', action='store_true')
    fp.add_argument('--testing', dest='validation', action='store_false')
    parser.set_defaults(validation=True)

    args = parser.parse_args()
    print(args)
    args.bfs_level = args.n_layers + 1  # prune unused nodes to save memory
    main(args)

# layers.py
import math

import mxnet as mx
from mxnet import gluon
import mxnet.ndarray as F

import dgl.function as fn


class RGCNLayer(gluon.Block):
    def __init__(self, in_feat, out_feat, bias=None, activation=None,
                 self_loop=False, dropout=0.0):
        super(RGCNLayer, self).__init__()
        self.bias = bias
        self.activation = activation
        self.self_loop = self_loop

        if self.bias:
            self.bias = self.params.get('bias', shape=(out_feat,),
                                        init=mx.init.Xavier(magnitude=math.sqrt(2.0)))

        # weight for self loop
        if self.self_loop:
            self.loop_weight = self.params.get('loop_weight', shape=(in_feat, out_feat),
                                               init=mx.init.Xavier(magnitude=math.sqrt(2.0)))

        if dropout:
            self.dropout = gluon.nn.Dropout(dropout)
        else:
            self.dropout = None

    # define how propagation is done in subclass
    def propagate(self, g):
        raise NotImplementedError

    def forward(self, g):
        if self.self_loop:
            loop_message = F.dot(g.ndata['h'], self.loop_weight.data())
            if self.dropout is not None:
                loop_message = self.dropout(loop_message)

        self.propagate(g)

        # apply bias and activation
        node_repr = g.ndata['h']
        if self.bias:
            node_repr = node_repr + self.bias.data()
        if self.self_loop:
            node_repr = node_repr + loop_message
        if self.activation:
            node_repr = self.activation(node_repr)

        g.ndata['h'] = node_repr

class RGCNBasisLayer(RGCNLayer):
    def __init__(self, in_feat, out_feat, num_rels, num_bases=-1, bias=None,
                 activation=None, is_input_layer=False):
        super(RGCNBasisLayer, self).__init__(in_feat, out_feat, bias, activation)
        self.in_feat = in_feat
        self.out_feat = out_feat
        self.num_rels = num_rels
        self.num_bases = num_bases
        self.is_input_layer = is_input_layer
        if self.num_bases <= 0 or self.num_bases > self.num_rels:
            self.num_bases = self.num_rels

        # add basis weights
        if self.num_bases < self.num_rels:
            # basis matrices V_b (stored flattened) plus the per-relation
            # linear-combination coefficients a_rb
            self.weight = self.params.get('weight',
                                          shape=(self.num_bases, self.in_feat * self.out_feat),
                                          init=mx.init.Xavier(magnitude=math.sqrt(2.0)))
            self.w_comp = self.params.get('w_comp', shape=(self.num_rels, self.num_bases),
                                          init=mx.init.Xavier(magnitude=math.sqrt(2.0)))
        else:
            self.weight = self.params.get('weight',
                                          shape=(self.num_bases, self.in_feat, self.out_feat),
                                          init=mx.init.Xavier(magnitude=math.sqrt(2.0)))

    def propagate(self, g):
        if self.num_bases < self.num_rels:
            # generate all weights from bases: W_r = sum_b a_rb * V_b
            weight = F.dot(self.w_comp.data(), self.weight.data()).reshape(
                (self.num_rels, self.in_feat, self.out_feat))
        else:
            weight = self.weight.data()

        if self.is_input_layer:
            def msg_func(edges):
                # for the input layer, node features are one-hot ids, so the
                # matrix multiply can be converted into an embedding lookup
                # using the source node id (see the standalone check after
                # this file)
                embed = F.reshape(weight, (-1, self.out_feat))
                index = edges.data['type'] * self.in_feat + edges.src['id']
                return {'msg': embed[index] * edges.data['norm']}
        else:
            def msg_func(edges):
                w = weight[edges.data['type']]
                msg = F.batch_dot(edges.src['h'].expand_dims(1), w).reshape((-1, self.out_feat))
                msg = msg * edges.data['norm']
                return {'msg': msg}

        g.update_all(msg_func, fn.sum(msg='msg', out='h'), None)
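The input-layer `msg_func` above relies on the fact that multiplying a one-hot node-id vector by a weight matrix is just a row lookup. A minimal standalone check of that identity (a hypothetical snippet, not part of the original files):
```
import mxnet as mx

num_nodes, out_feat = 4, 3
W = mx.nd.random.uniform(shape=(num_nodes, out_feat))
node_id = 2
one_hot = mx.nd.one_hot(mx.nd.array([node_id]), num_nodes)
# x @ W for one-hot x equals selecting row `node_id` of W
assert (mx.nd.dot(one_hot, W)[0] == W[node_id]).asnumpy().all()
```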

# model.py
import mxnet as mx
from mxnet import gluon


class BaseRGCN(gluon.Block):
    def __init__(self, num_nodes, h_dim, out_dim, num_rels, num_bases=-1,
                 num_hidden_layers=1, dropout=0, gpu_id=-1):
        super(BaseRGCN, self).__init__()
        self.num_nodes = num_nodes
        self.h_dim = h_dim
        self.out_dim = out_dim
        self.num_rels = num_rels
        self.num_bases = num_bases
        self.num_hidden_layers = num_hidden_layers
        self.dropout = dropout
        self.gpu_id = gpu_id

        # create rgcn layers
        self.build_model()

        # create initial features
        self.features = self.create_features()

    def build_model(self):
        self.layers = gluon.nn.Sequential()
        # i2h
        i2h = self.build_input_layer()
        if i2h is not None:
            self.layers.add(i2h)
        # h2h
        for idx in range(self.num_hidden_layers):
            h2h = self.build_hidden_layer(idx)
            self.layers.add(h2h)
        # h2o
        h2o = self.build_output_layer()
        if h2o is not None:
            self.layers.add(h2o)

    # initialize feature for each node
    def create_features(self):
        return None

    def build_input_layer(self):
        return None

    def build_hidden_layer(self, idx):
        raise NotImplementedError

    def build_output_layer(self):
        return None

    def forward(self, g):
        if self.features is not None:
            g.ndata['id'] = self.features
        for layer in self.layers:
            layer(g)
        return g.ndata.pop('h')
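To see how the three files fit together, here is a minimal smoke test (a hypothetical sketch, assuming the files are named `entity_classify.py`, `layers.py`, and `model.py` as the imports suggest): build a tiny typed graph and run one forward pass on CPU.
```
import mxnet as mx
from dgl import DGLGraph
from entity_classify import EntityClassify

g = DGLGraph()
g.add_nodes(4)
g.add_edges([0, 1, 2], [1, 2, 3])
g.edata.update({'type': mx.nd.array([0, 1, 0]),  # one relation id per edge
                'norm': mx.nd.ones((3, 1))})     # normalization factor per edge
model = EntityClassify(len(g), h_dim=16, out_dim=2, num_rels=2,
                       num_bases=-1, num_hidden_layers=0, dropout=0, gpu_id=-1)
model.initialize(ctx=mx.cpu(0))
print(model(g).shape)  # expected: (4, 2), class scores per node
```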