Unverified Commit 490c5a8d authored by Israt Nisa's avatar Israt Nisa Committed by GitHub
Browse files

[Model] RGCN with new heterograph API (#3025)



* rgcn with new heterograph API

* added new apply_edge()

* optimized forward pass

* renaming from *hetero to *heteroAPI
Co-authored-by: default avatarIsrat Nisa <nisisrat@amazon.com>
Co-authored-by: default avatarDa Zheng <zhengda1936@gmail.com>
parent 66a54555
"""Modeling Relational Data with Graph Convolutional Networks
Paper: https://arxiv.org/abs/1703.06103
Reference Code: https://github.com/tkipf/relational-gcn
"""
import argparse
import numpy as np
import time
import torch as th
import torch.nn as nn
import torch.nn.functional as F
from dgl.data.rdf import AIFBDataset, MUTAGDataset, BGSDataset, AMDataset
from model import EntityClassify_HeteroAPI
def main(args):
# load graph data
if args.dataset == 'aifb':
dataset = AIFBDataset()
elif args.dataset == 'mutag':
dataset = MUTAGDataset()
elif args.dataset == 'bgs':
dataset = BGSDataset()
elif args.dataset == 'am':
dataset = AMDataset()
else:
raise ValueError()
g = dataset[0]
category = dataset.predict_category
num_classes = dataset.num_classes
train_mask = g.nodes[category].data.pop('train_mask')
test_mask = g.nodes[category].data.pop('test_mask')
train_idx = th.nonzero(train_mask, as_tuple=False).squeeze()
test_idx = th.nonzero(test_mask, as_tuple=False).squeeze()
labels = g.nodes[category].data.pop('labels')
category_id = len(g.ntypes)
for i, ntype in enumerate(g.ntypes):
if ntype == category:
category_id = i
# split dataset into train, validate, test
if args.validation:
val_idx = train_idx[:len(train_idx) // 5]
train_idx = train_idx[len(train_idx) // 5:]
else:
val_idx = train_idx
# check cuda
use_cuda = args.gpu >= 0 and th.cuda.is_available()
if use_cuda:
th.cuda.set_device(args.gpu)
g = g.to('cuda:%d' % args.gpu)
labels = labels.cuda()
train_idx = train_idx.cuda()
test_idx = test_idx.cuda()
# create model
model = EntityClassify_HeteroAPI(g,
args.n_hidden,
num_classes,
num_bases=args.n_bases,
num_hidden_layers=args.n_layers - 2,
dropout=args.dropout,
use_self_loop=args.use_self_loop)
if use_cuda:
model.cuda()
# optimizer
optimizer = th.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.l2norm)
# training loop
print("start training...")
dur = []
model.train()
for epoch in range(args.n_epochs):
optimizer.zero_grad()
t0 = time.time()
logits = model()[category]
loss = F.cross_entropy(logits[train_idx], labels[train_idx])
loss.backward()
optimizer.step()
t1 = time.time()
dur.append(t1 - t0)
train_acc = th.sum(logits[train_idx].argmax(dim=1) == labels[train_idx]).item() / len(train_idx)
val_loss = F.cross_entropy(logits[val_idx], labels[val_idx])
val_acc = th.sum(logits[val_idx].argmax(dim=1) == labels[val_idx]).item() / len(val_idx)
print("Epoch {:05d} | Train Acc: {:.4f} | Train Loss: {:.4f} | Valid Acc: {:.4f} | Valid loss: {:.4f} | Time: {:.4f}".
format(epoch, train_acc, loss.item(), val_acc, val_loss.item(), np.average(dur)))
print()
if args.model_path is not None:
th.save(model.state_dict(), args.model_path)
model.eval()
logits = model.forward()[category]
test_loss = F.cross_entropy(logits[test_idx], labels[test_idx])
test_acc = th.sum(logits[test_idx].argmax(dim=1) == labels[test_idx]).item() / len(test_idx)
print("Test Acc: {:.4f} | Test loss: {:.4f}".format(test_acc, test_loss.item()))
print()
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='RGCN')
parser.add_argument("--dropout", type=float, default=0,
help="dropout probability")
parser.add_argument("--n-hidden", type=int, default=16,
help="number of hidden units")
parser.add_argument("--gpu", type=int, default=-1,
help="gpu")
parser.add_argument("--lr", type=float, default=1e-2,
help="learning rate")
parser.add_argument("--n-bases", type=int, default=-1,
help="number of filter weight matrices, default: -1 [use all]")
parser.add_argument("--n-layers", type=int, default=2,
help="number of propagation rounds")
parser.add_argument("-e", "--n-epochs", type=int, default=50,
help="number of training epochs")
parser.add_argument("-d", "--dataset", type=str, required=True,
help="dataset to use")
parser.add_argument("--model_path", type=str, default=None,
help='path for save the model')
parser.add_argument("--l2norm", type=float, default=0,
help="l2 norm coef")
parser.add_argument("--use-self-loop", default=False, action='store_true',
help="include self feature as a special relation")
fp = parser.add_mutually_exclusive_group(required=False)
fp.add_argument('--validation', dest='validation', action='store_true')
fp.add_argument('--testing', dest='validation', action='store_false')
parser.set_defaults(validation=True)
args = parser.parse_args()
print(args)
main(args)
......@@ -5,6 +5,7 @@ import torch as th
import torch.nn as nn
import torch.nn.functional as F
import dgl
import dgl.function as fn
import dgl.nn as dglnn
import tqdm
......@@ -120,6 +121,122 @@ class RelGraphConvLayer(nn.Module):
return self.dropout(h)
return {ntype : _apply(ntype, h) for ntype, h in hs.items()}
class RelGraphConvLayerHeteroAPI(nn.Module):
r"""Relational graph convolution layer.
Parameters
----------
in_feat : int
Input feature size.
out_feat : int
Output feature size.
rel_names : list[str]
Relation names.
num_bases : int, optional
Number of bases. If is none, use number of relations. Default: None.
weight : bool, optional
True if a linear layer is applied after message passing. Default: True
bias : bool, optional
True if bias is added. Default: True
activation : callable, optional
Activation function. Default: None
self_loop : bool, optional
True to include self loop message. Default: False
dropout : float, optional
Dropout rate. Default: 0.0
"""
def __init__(self,
in_feat,
out_feat,
rel_names,
num_bases,
*,
weight=True,
bias=True,
activation=None,
self_loop=False,
dropout=0.0):
super(RelGraphConvLayerHeteroAPI, self).__init__()
self.in_feat = in_feat
self.out_feat = out_feat
self.rel_names = rel_names
self.num_bases = num_bases
self.bias = bias
self.activation = activation
self.self_loop = self_loop
self.use_weight = weight
self.use_basis = num_bases < len(self.rel_names) and weight
if self.use_weight:
if self.use_basis:
self.basis = dglnn.WeightBasis((in_feat, out_feat), num_bases, len(self.rel_names))
else:
self.weight = nn.Parameter(th.Tensor(len(self.rel_names), in_feat, out_feat))
nn.init.xavier_uniform_(self.weight, gain=nn.init.calculate_gain('relu'))
# bias
if bias:
self.h_bias = nn.Parameter(th.Tensor(out_feat))
nn.init.zeros_(self.h_bias)
# weight for self loop
if self.self_loop:
self.loop_weight = nn.Parameter(th.Tensor(in_feat, out_feat))
nn.init.xavier_uniform_(self.loop_weight,
gain=nn.init.calculate_gain('relu'))
self.dropout = nn.Dropout(dropout)
def forward(self, g, inputs):
"""Forward computation
Parameters
----------
g : DGLHeteroGraph
Input graph.
inputs : dict[str, torch.Tensor]
Node feature for each node type.
Returns
-------
dict[str, torch.Tensor]
New node features for each node type.
"""
g = g.local_var()
if self.use_weight:
weight = self.basis() if self.use_basis else self.weight
wdict = {self.rel_names[i] : {'weight' : w.squeeze(0)}
for i, w in enumerate(th.split(weight, 1, dim=0))}
else:
wdict = {}
inputs_src = inputs_dst = inputs
for srctype,_,_ in g.canonical_etypes:
g.nodes[srctype].data['h'] = inputs[srctype]
if self.use_weight:
g.apply_edges(fn.copy_u('h', 'm'))
m = g.edata['m']
for rel in g.canonical_etypes:
_, etype, _ = rel
g.edges[rel].data['h*w_r'] = th.matmul(m[rel], wdict[etype]['weight'])
else:
g.apply_edges(fn.copy_u('h', 'h*w_r'))
g.update_all(fn.copy_e('h*w_r', 'm'), fn.sum('m', 'h'))
def _apply(ntype):
h = g.nodes[ntype].data['h']
if self.self_loop:
h = h + th.matmul(inputs_dst[ntype], self.loop_weight)
if self.bias:
h = h + self.h_bias
if self.activation:
h = self.activation(h)
return self.dropout(h)
return {ntype : _apply(ntype) for ntype in g.dsttypes}
class RelGraphEmbed(nn.Module):
r"""Embedding layer for featureless heterograph."""
def __init__(self,
......@@ -253,3 +370,98 @@ class EntityClassify(nn.Module):
x = y
return y
class EntityClassify_HeteroAPI(nn.Module):
def __init__(self,
g,
h_dim, out_dim,
num_bases,
num_hidden_layers=1,
dropout=0,
use_self_loop=False):
super(EntityClassify_HeteroAPI, self).__init__()
self.g = g
self.h_dim = h_dim
self.out_dim = out_dim
self.rel_names = list(set(g.etypes))
self.rel_names.sort()
if num_bases < 0 or num_bases > len(self.rel_names):
self.num_bases = len(self.rel_names)
else:
self.num_bases = num_bases
self.num_hidden_layers = num_hidden_layers
self.dropout = dropout
self.use_self_loop = use_self_loop
self.embed_layer = RelGraphEmbed(g, self.h_dim)
self.layers = nn.ModuleList()
# i2h
self.layers.append(RelGraphConvLayerHeteroAPI(
self.h_dim, self.h_dim, self.rel_names,
self.num_bases, activation=F.relu, self_loop=self.use_self_loop,
dropout=self.dropout, weight=False))
# h2h
for i in range(self.num_hidden_layers):
self.layers.append(RelGraphConvLayerHeteroAPI(
self.h_dim, self.h_dim, self.rel_names,
self.num_bases, activation=F.relu, self_loop=self.use_self_loop,
dropout=self.dropout))
# h2o
self.layers.append(RelGraphConvLayerHeteroAPI(
self.h_dim, self.out_dim, self.rel_names,
self.num_bases, activation=None,
self_loop=self.use_self_loop))
def forward(self, h=None, blocks=None):
if h is None:
# full graph training
h = self.embed_layer()
if blocks is None:
# full graph training
for layer in self.layers:
h = layer(self.g, h)
else:
# minibatch training
for layer, block in zip(self.layers, blocks):
h = layer(block, h)
return h
def inference(self, g, batch_size, device, num_workers, x=None):
"""Minibatch inference of final representation over all node types.
***NOTE***
For node classification, the model is trained to predict on only one node type's
label. Therefore, only that type's final representation is meaningful.
"""
if x is None:
x = self.embed_layer()
for l, layer in enumerate(self.layers):
y = {
k: th.zeros(
g.number_of_nodes(k),
self.h_dim if l != len(self.layers) - 1 else self.out_dim)
for k in g.ntypes}
sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1)
dataloader = dgl.dataloading.NodeDataLoader(
g,
{k: th.arange(g.number_of_nodes(k)) for k in g.ntypes},
sampler,
batch_size=batch_size,
shuffle=True,
drop_last=False,
num_workers=num_workers)
for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader):
block = blocks[0].to(device)
h = {k: x[k][input_nodes[k]].to(device) for k in input_nodes.keys()}
h = layer(block, h)
for k in h.keys():
y[k][output_nodes[k]] = h[k].cpu()
x = y
return y
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment