Unverified Commit 6111ea46 authored by Minjie Wang, committed by GitHub

[Model] RGCN mini-batch training; Better bipartite graph support (#1337)

* change the model to use node embedding only

* minibatch training

* add readme

* small fix

* backward-compatible

* backward compatible

* modify to_block; rgcn changed

* fix

* fix transform

* fix bug in unittest script

* docstring

* fix lint

* add tests

* address comments; fix offline eval

* gitignore
parent 9120ee60
......@@ -34,7 +34,7 @@ Example code was tested with rdflib 4.2.2 and pandas 0.23.4
### Entity Classification
All experiments use one-hot encoding as featureless input. Best accuracy reported.
AIFB: accuracy 97.22% (DGL), 95.83% (paper)
```
......@@ -56,6 +56,30 @@ AM: accuracy 91.41% (DGL), 89.29% (paper)
python3 entity_classify.py -d am --l2norm 5e-4 --n-bases 40 --testing --gpu 0
```
### Entity Classification w/ minibatch training
Accuracy numbers are reported over 10 runs.
AIFB: accuracy best=97.22% avg=93.33%
```
python3 entity_classify_mb.py -d aifb --testing --gpu 0 --fanout=8
```
MUTAG: accuracy best=76.47% avg=68.38%
```
python3 entity_classify_mb.py -d mutag --l2norm 5e-4 --n-bases 30 --testing --gpu 0 --batch-size=50 --fanout=8
```
BGS: accuracy best=96.55% avg=92.41%
```
python3 entity_classify_mb.py -d bgs --l2norm 5e-4 --n-bases 40 --testing --gpu 0
```
AM: accuracy best=90.91% avg=88.43%
```
python3 entity_classify_mb.py -d am --l2norm 5e-4 --n-bases 40 --testing --gpu 0
```
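Omitting `--testing` (the default) holds out one fifth of the training set as a validation split, which is handy when tuning `--fanout` and `--batch-size`, e.g.:
```
python3 entity_classify_mb.py -d aifb --gpu 0 --fanout=8
```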
### Offline Inference
A trained model can be exported by passing the '--model\_path <PATH>' parameter to entity\_classify.py. test\_classify.py can then load the saved model and run the evaluation offline.
......@@ -81,4 +105,4 @@ AM:
```
python3 entity_classify.py -d am --l2norm 5e-4 --n-bases 40 --testing --gpu 0 --model_path "am.pt"
python3 test_classify.py -d am --n-bases 40 --gpu 0 --model_path "am.pt"
```
......@@ -34,6 +34,9 @@ class RelGraphConvHetero(nn.Module):
Activation function. Default: None
self_loop : bool, optional
True to include self loop message. Default: False
use_weight : bool, optional
    If True, multiply the input node features by a learnable weight matrix
    before message passing. Default: True
dropout : float, optional
Dropout rate. Default: 0.0
"""
......@@ -46,6 +49,7 @@ class RelGraphConvHetero(nn.Module):
bias=True,
activation=None,
self_loop=False,
use_weight=True,
dropout=0.0):
super(RelGraphConvHetero, self).__init__()
self.in_feat = in_feat
......@@ -60,18 +64,20 @@ class RelGraphConvHetero(nn.Module):
self.activation = activation
self.self_loop = self_loop
if regularizer == "basis":
# add basis weights
self.weight = nn.Parameter(th.Tensor(self.num_bases, self.in_feat, self.out_feat))
if self.num_bases < self.num_rels:
# linear combination coefficients
self.w_comp = nn.Parameter(th.Tensor(self.num_rels, self.num_bases))
nn.init.xavier_uniform_(self.weight, gain=nn.init.calculate_gain('relu'))
if self.num_bases < self.num_rels:
nn.init.xavier_uniform_(self.w_comp,
gain=nn.init.calculate_gain('relu'))
else:
raise ValueError("Only basis regularizer is supported.")
self.use_weight = use_weight
if use_weight:
if regularizer == "basis":
# add basis weights
self.weight = nn.Parameter(th.Tensor(self.num_bases, self.in_feat, self.out_feat))
if self.num_bases < self.num_rels:
# linear combination coefficients
self.w_comp = nn.Parameter(th.Tensor(self.num_rels, self.num_bases))
nn.init.xavier_uniform_(self.weight, gain=nn.init.calculate_gain('relu'))
if self.num_bases < self.num_rels:
nn.init.xavier_uniform_(self.w_comp,
gain=nn.init.calculate_gain('relu'))
else:
raise ValueError("Only basis regularizer is supported.")
# bias
if self.bias:
......@@ -99,13 +105,13 @@ class RelGraphConvHetero(nn.Module):
return {self.rel_names[i] : w.squeeze(0) for i, w in enumerate(th.split(weight, 1, dim=0))}
def forward(self, g, xs):
""" Forward computation
"""Forward computation
Parameters
----------
g : DGLHeteroGraph
Input graph.
        xs : dict[str, torch.Tensor]
Node feature for each node type.
Returns
......@@ -114,98 +120,76 @@ class RelGraphConvHetero(nn.Module):
New node features for each node type.
"""
g = g.local_var()
        for ntype in g.ntypes:
            g.nodes[ntype].data['x'] = xs[ntype]
        if self.use_weight:
            ws = self.basis_weight()
            funcs = {}
            for i, (srctype, etype, dsttype) in enumerate(g.canonical_etypes):
                g.nodes[srctype].data['h%d' % i] = th.matmul(
                    g.nodes[srctype].data['x'], ws[etype])
                funcs[(srctype, etype, dsttype)] = (fn.copy_u('h%d' % i, 'm'), fn.mean('m', 'h'))
        else:
            funcs = {}
            for i, (srctype, etype, dsttype) in enumerate(g.canonical_etypes):
                g.nodes[srctype].data['h%d' % i] = g.nodes[srctype].data['x']
                funcs[(srctype, etype, dsttype)] = (fn.copy_u('h%d' % i, 'm'), fn.mean('m', 'h'))
        # message passing
        g.multi_update_all(funcs, 'sum')
        hs = {ntype : g.nodes[ntype].data['h'] for ntype in g.ntypes}
        new_hs = {}
        for ntype, h in hs.items():
            # apply bias and activation
            if self.self_loop:
                h = h + th.matmul(xs[ntype], self.loop_weight)
            if self.bias:
                h = h + self.h_bias
            if self.activation:
                h = self.activation(h)
            h = self.dropout(h)
            new_hs[ntype] = h
        return new_hs
class RelGraphEmbed(nn.Module):
r"""Embedding layer for featureless heterograph."""
    def __init__(self,
                 g,
                 embed_size,
                 embed_name='embed',
                 activation=None,
                 dropout=0.0):
        super(RelGraphEmbed, self).__init__()
        self.g = g
        self.embed_size = embed_size
        self.embed_name = embed_name
        self.activation = activation
        self.dropout = nn.Dropout(dropout)

        # create a learnable embedding table for each node type
        self.embeds = nn.ParameterDict()
        for ntype in g.ntypes:
            embed = nn.Parameter(th.Tensor(g.number_of_nodes(ntype), self.embed_size))
            nn.init.xavier_uniform_(embed, gain=nn.init.calculate_gain('relu'))
            self.embeds[ntype] = embed
    def forward(self, block=None):
        """Forward computation

        Parameters
        ----------
        block : DGLHeteroGraph, optional
            Unused; kept so that the embedding layer can be called like the
            convolution layers.

        Returns
        -------
        dict[str, torch.Tensor]
            The embeddings for each node type.
        """
        return self.embeds
class EntityClassify(nn.Module):
def __init__(self,
......@@ -226,10 +210,13 @@ class EntityClassify(nn.Module):
self.dropout = dropout
self.use_self_loop = use_self_loop
        self.embed_layer = RelGraphEmbed(g, self.h_dim)
self.layers = nn.ModuleList()
# i2h
self.layers.append(RelGraphConvHetero(
self.h_dim, self.h_dim, self.rel_names, "basis",
self.num_bases, activation=F.relu, self_loop=self.use_self_loop,
dropout=self.dropout, use_weight=False))
# h2h
for i in range(self.num_hidden_layers):
self.layers.append(RelGraphConvHetero(
......@@ -310,7 +297,7 @@ def main(args):
optimizer.zero_grad()
if epoch > 5:
t0 = time.time()
        logits = model()[category]
loss = F.cross_entropy(logits[train_idx], labels[train_idx])
loss.backward()
optimizer.step()
......@@ -328,7 +315,7 @@ def main(args):
th.save(model.state_dict(), args.model_path)
model.eval()
    logits = model.forward()[category]
test_loss = F.cross_entropy(logits[test_idx], labels[test_idx])
test_acc = th.sum(logits[test_idx].argmax(dim=1) == labels[test_idx]).item() / len(test_idx)
print("Test Acc: {:.4f} | Test loss: {:.4f}".format(test_acc, test_loss.item()))
......
"""Modeling Relational Data with Graph Convolutional Networks
Paper: https://arxiv.org/abs/1703.06103
Reference Code: https://github.com/tkipf/relational-gcn
"""
import argparse
import itertools
import numpy as np
import time
import torch as th
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from functools import partial
import dgl
import dgl.function as fn
from dgl.data.rdf import AIFB, MUTAG, BGS, AM
class RelGraphConvHetero(nn.Module):
r"""Relational graph convolution layer.
Parameters
----------
in_feat : int
Input feature size.
out_feat : int
Output feature size.
    rel_names : list of str
        Relation names.
    regularizer : str
        Which weight regularizer to use: "basis" or "bdd".
    num_bases : int, optional
        Number of bases. If None, use the number of relations. Default: None.
bias : bool, optional
True if bias is added. Default: True
activation : callable, optional
Activation function. Default: None
self_loop : bool, optional
True to include self loop message. Default: False
    use_weight : bool, optional
        If True, multiply the input node features by a learnable weight matrix
        before message passing. Default: True
dropout : float, optional
Dropout rate. Default: 0.0
"""
def __init__(self,
in_feat,
out_feat,
rel_names,
regularizer="basis",
num_bases=None,
bias=True,
activation=None,
self_loop=False,
use_weight=True,
dropout=0.0):
super(RelGraphConvHetero, self).__init__()
self.in_feat = in_feat
self.out_feat = out_feat
self.rel_names = rel_names
self.num_rels = len(rel_names)
self.regularizer = regularizer
self.num_bases = num_bases
if self.num_bases is None or self.num_bases > self.num_rels or self.num_bases < 0:
self.num_bases = self.num_rels
self.bias = bias
self.activation = activation
self.self_loop = self_loop
self.use_weight = use_weight
if use_weight:
if regularizer == "basis":
# add basis weights
self.weight = nn.Parameter(th.Tensor(self.num_bases, self.in_feat, self.out_feat))
if self.num_bases < self.num_rels:
# linear combination coefficients
self.w_comp = nn.Parameter(th.Tensor(self.num_rels, self.num_bases))
nn.init.xavier_uniform_(self.weight, gain=nn.init.calculate_gain('relu'))
if self.num_bases < self.num_rels:
nn.init.xavier_uniform_(self.w_comp,
gain=nn.init.calculate_gain('relu'))
else:
raise ValueError("Only basis regularizer is supported.")
# bias
if self.bias:
self.h_bias = nn.Parameter(th.Tensor(out_feat))
nn.init.zeros_(self.h_bias)
# weight for self loop
if self.self_loop:
self.loop_weight = nn.Parameter(th.Tensor(in_feat, out_feat))
nn.init.xavier_uniform_(self.loop_weight,
gain=nn.init.calculate_gain('relu'))
self.dropout = nn.Dropout(dropout)
def basis_weight(self):
"""Message function for basis regularizer"""
if self.num_bases < self.num_rels:
# generate all weights from bases
weight = self.weight.view(self.num_bases,
self.in_feat * self.out_feat)
weight = th.matmul(self.w_comp, weight).view(
self.num_rels, self.in_feat, self.out_feat)
else:
weight = self.weight
return {self.rel_names[i] : w.squeeze(0) for i, w in enumerate(th.split(weight, 1, dim=0))}
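    # Note: this implements the basis decomposition of Schlichtkrull et al.,
    # where each relation weight is a learned combination of shared bases:
    #     W_r = sum_b w_comp[r, b] * V_b
    # The matmul above computes all W_r at once on the flattened bases.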
def forward(self, g, xs):
"""Forward computation
Parameters
----------
g : DGLHeteroGraph
Input block graph.
xs : dict[str, torch.Tensor]
Node feature for each node type.
Returns
-------
        dict[str, torch.Tensor]
            New node features for each node type.
"""
g = g.local_var()
for ntype, x in xs.items():
g.srcnodes[ntype].data['x'] = x
if self.use_weight:
ws = self.basis_weight()
funcs = {}
for i, (srctype, etype, dsttype) in enumerate(g.canonical_etypes):
if srctype not in xs:
continue
g.srcnodes[srctype].data['h%d' % i] = th.matmul(
g.srcnodes[srctype].data['x'], ws[etype])
funcs[(srctype, etype, dsttype)] = (fn.copy_u('h%d' % i, 'm'), fn.mean('m', 'h'))
else:
funcs = {}
for i, (srctype, etype, dsttype) in enumerate(g.canonical_etypes):
if srctype not in xs:
continue
g.srcnodes[srctype].data['h%d' % i] = g.srcnodes[srctype].data['x']
funcs[(srctype, etype, dsttype)] = (fn.copy_u('h%d' % i, 'm'), fn.mean('m', 'h'))
# message passing
g.multi_update_all(funcs, 'sum')
hs = {}
for ntype in g.dsttypes:
if 'h' in g.dstnodes[ntype].data:
hs[ntype] = g.dstnodes[ntype].data['h']
        def _apply(ntype, h):
            # apply self-loop, bias, and activation
            if self.self_loop:
                h = h + th.matmul(xs[ntype][:h.shape[0]], self.loop_weight)
            if self.bias:
                h = h + self.h_bias
            if self.activation:
                h = self.activation(h)
            h = self.dropout(h)
            return h
hs = {ntype : _apply(ntype, h) for ntype, h in hs.items()}
return hs
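# A minimal usage sketch for RelGraphConvHetero on a sampled block
# (hypothetical feature sizes; 'user' is a placeholder node type):
#
#   conv = RelGraphConvHetero(16, 16, sorted(set(g.etypes)), "basis", 2)
#   feats = {'user': th.randn(block.number_of_nodes('SRC/user'), 16)}
#   out = conv(block, feats)   # dict keyed by the block's DST node types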
class RelGraphEmbed(nn.Module):
r"""Embedding layer for featureless heterograph."""
def __init__(self,
g,
embed_size,
activation=None,
dropout=0.0):
super(RelGraphEmbed, self).__init__()
self.g = g
self.embed_size = embed_size
self.activation = activation
self.dropout = nn.Dropout(dropout)
        # create a learnable embedding table for each node type
self.embeds = nn.ParameterDict()
for ntype in g.ntypes:
embed = nn.Parameter(th.Tensor(g.number_of_nodes(ntype), self.embed_size))
nn.init.xavier_uniform_(embed, gain=nn.init.calculate_gain('relu'))
self.embeds[ntype] = embed
    def forward(self, block=None):
        """Forward computation

        Parameters
        ----------
        block : DGLHeteroGraph, optional
            Unused; kept so that the embedding layer can be called like the
            convolution layers.

        Returns
        -------
        dict[str, torch.Tensor]
            The embeddings for each node type.
        """
        return self.embeds
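# Sketch of how the embedding layer is combined with the sampler below:
#
#   embed_layer = RelGraphEmbed(g, 16)
#   node_embed = embed_layer()                 # ntype -> (num_nodes, 16)
#   emb = extract_embed(node_embed, blocks[0]) # rows needed by the first block
#
# Only the sliced rows participate in a mini-batch, so the full embedding
# tables stay on CPU while the sliced copies can be moved to GPU.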
class EntityClassify(nn.Module):
def __init__(self,
g,
h_dim, out_dim,
num_bases,
num_hidden_layers=1,
dropout=0,
use_self_loop=False):
super(EntityClassify, self).__init__()
self.g = g
self.h_dim = h_dim
self.out_dim = out_dim
self.rel_names = list(set(g.etypes))
self.rel_names.sort()
self.num_bases = None if num_bases < 0 else num_bases
self.num_hidden_layers = num_hidden_layers
self.dropout = dropout
self.use_self_loop = use_self_loop
self.layers = nn.ModuleList()
# i2h
self.layers.append(RelGraphConvHetero(
self.h_dim, self.h_dim, self.rel_names, "basis",
self.num_bases, activation=F.relu, self_loop=self.use_self_loop,
dropout=self.dropout, use_weight=False))
# h2h
for i in range(self.num_hidden_layers):
self.layers.append(RelGraphConvHetero(
self.h_dim, self.h_dim, self.rel_names, "basis",
self.num_bases, activation=F.relu, self_loop=self.use_self_loop,
dropout=self.dropout))
# h2o
self.layers.append(RelGraphConvHetero(
self.h_dim, self.out_dim, self.rel_names, "basis",
self.num_bases, activation=None,
self_loop=self.use_self_loop))
def forward(self, h, blocks):
for layer, block in zip(self.layers, blocks):
h = layer(block, h)
return h
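# Note: the model consumes exactly one block per layer, so the sampler below
# is constructed with [args.fanout] * args.n_layers to match
# len(self.layers) == n_layers (i2h + (n_layers - 2) hidden + h2o).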
class HeteroNeighborSampler:
"""Neighbor sampler on heterogeneous graphs
Parameters
----------
g : DGLHeteroGraph
Full graph
category : str
Category name of the seed nodes.
fanouts : list of int
        Fanout of each hop starting from the seed nodes. If a fanout is None,
        all in-neighbors are included.
"""
def __init__(self, g, category, fanouts):
self.g = g
self.category = category
self.fanouts = fanouts
def sample_blocks(self, seeds):
blocks = []
seeds = {self.category : th.tensor(seeds).long()}
cur = seeds
for fanout in self.fanouts:
if fanout is None:
frontier = dgl.in_subgraph(self.g, cur)
else:
frontier = dgl.sampling.sample_neighbors(self.g, cur, fanout)
block = dgl.to_block(frontier, cur)
cur = {}
for ntype in block.srctypes:
cur[ntype] = block.srcnodes[ntype].data[dgl.NID]
blocks.insert(0, block)
return seeds, blocks
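# Usage sketch (hypothetical seed ids): the DataLoader below calls this as
# its collate_fn, but it can also be invoked directly:
#
#   sampler = HeteroNeighborSampler(g, category, [4, 4])
#   seeds, blocks = sampler.sample_blocks([0, 1, 2])
#
# blocks[0] is the outermost (input) block; blocks[-1] is the block whose
# DST nodes are the seeds.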
def extract_embed(node_embed, block):
emb = {}
for ntype in block.srctypes:
nid = block.srcnodes[ntype].data[dgl.NID]
emb[ntype] = node_embed[ntype][nid]
return emb
def evaluate(model, seeds, blocks, node_embed, labels, category, use_cuda):
model.eval()
emb = extract_embed(node_embed, blocks[0])
lbl = labels[seeds]
if use_cuda:
emb = {k : e.cuda() for k, e in emb.items()}
lbl = lbl.cuda()
logits = model(emb, blocks)[category]
loss = F.cross_entropy(logits, lbl)
acc = th.sum(logits.argmax(dim=1) == lbl).item() / len(seeds)
return loss, acc
def main(args):
# load graph data
if args.dataset == 'aifb':
dataset = AIFB()
elif args.dataset == 'mutag':
dataset = MUTAG()
elif args.dataset == 'bgs':
dataset = BGS()
elif args.dataset == 'am':
dataset = AM()
    else:
        raise ValueError('Unknown dataset: {}'.format(args.dataset))
g = dataset.graph
category = dataset.predict_category
num_classes = dataset.num_classes
train_idx = dataset.train_idx
test_idx = dataset.test_idx
labels = dataset.labels
# split dataset into train, validate, test
if args.validation:
val_idx = train_idx[:len(train_idx) // 5]
train_idx = train_idx[len(train_idx) // 5:]
else:
val_idx = train_idx
# check cuda
use_cuda = args.gpu >= 0 and th.cuda.is_available()
if use_cuda:
th.cuda.set_device(args.gpu)
train_label = labels[train_idx]
val_label = labels[val_idx]
test_label = labels[test_idx]
# create embeddings
embed_layer = RelGraphEmbed(g, args.n_hidden)
node_embed = embed_layer()
# create model
model = EntityClassify(g,
args.n_hidden,
num_classes,
num_bases=args.n_bases,
num_hidden_layers=args.n_layers - 2,
dropout=args.dropout,
use_self_loop=args.use_self_loop)
if use_cuda:
model.cuda()
# train sampler
sampler = HeteroNeighborSampler(g, category, [args.fanout] * args.n_layers)
loader = DataLoader(dataset=train_idx.numpy(),
batch_size=args.batch_size,
collate_fn=sampler.sample_blocks,
shuffle=True,
num_workers=0)
# validation sampler
val_sampler = HeteroNeighborSampler(g, category, [None] * args.n_layers)
_, val_blocks = val_sampler.sample_blocks(val_idx)
# test sampler
test_sampler = HeteroNeighborSampler(g, category, [None] * args.n_layers)
_, test_blocks = test_sampler.sample_blocks(test_idx)
# optimizer
all_params = itertools.chain(model.parameters(), embed_layer.parameters())
optimizer = th.optim.Adam(all_params, lr=args.lr, weight_decay=args.l2norm)
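    # The embedding layer is separate from `model`, so its parameters must be
    # chained in explicitly; otherwise the embeddings would never be updated.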
# training loop
print("start training...")
dur = []
for epoch in range(args.n_epochs):
model.train()
optimizer.zero_grad()
if epoch > 3:
t0 = time.time()
for i, (seeds, blocks) in enumerate(loader):
batch_tic = time.time()
emb = extract_embed(node_embed, blocks[0])
lbl = labels[seeds[category]]
if use_cuda:
emb = {k : e.cuda() for k, e in emb.items()}
lbl = lbl.cuda()
logits = model(emb, blocks)[category]
loss = F.cross_entropy(logits, lbl)
loss.backward()
optimizer.step()
train_acc = th.sum(logits.argmax(dim=1) == lbl).item() / len(seeds[category])
print("Epoch {:05d} | Batch {:03d} | Train Acc: {:.4f} | Train Loss: {:.4f} | Time: {:.4f}".
format(epoch, i, train_acc, loss.item(), time.time() - batch_tic))
if epoch > 3:
dur.append(time.time() - t0)
val_loss, val_acc = evaluate(model, val_idx, val_blocks, node_embed, labels, category, use_cuda)
print("Epoch {:05d} | Valid Acc: {:.4f} | Valid loss: {:.4f} | Time: {:.4f}".
format(epoch, val_acc, val_loss.item(), np.average(dur)))
print()
if args.model_path is not None:
th.save(model.state_dict(), args.model_path)
test_loss, test_acc = evaluate(model, test_idx, test_blocks, node_embed, labels, category, use_cuda)
print("Test Acc: {:.4f} | Test loss: {:.4f}".format(test_acc, test_loss.item()))
print()
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='RGCN')
parser.add_argument("--dropout", type=float, default=0,
help="dropout probability")
parser.add_argument("--n-hidden", type=int, default=16,
help="number of hidden units")
parser.add_argument("--gpu", type=int, default=-1,
help="gpu")
parser.add_argument("--lr", type=float, default=1e-2,
help="learning rate")
parser.add_argument("--n-bases", type=int, default=-1,
help="number of filter weight matrices, default: -1 [use all]")
parser.add_argument("--n-layers", type=int, default=2,
help="number of propagation rounds")
parser.add_argument("-e", "--n-epochs", type=int, default=20,
help="number of training epochs")
parser.add_argument("-d", "--dataset", type=str, required=True,
help="dataset to use")
parser.add_argument("--model_path", type=str, default=None,
help='path for save the model')
parser.add_argument("--l2norm", type=float, default=0,
help="l2 norm coef")
parser.add_argument("--use-self-loop", default=False, action='store_true',
help="include self feature as a special relation")
parser.add_argument("--batch-size", type=int, default=100,
help="Mini-batch size. If -1, use full graph training.")
parser.add_argument("--fanout", type=int, default=4,
help="Fan-out of neighbor sampling.")
fp = parser.add_mutually_exclusive_group(required=False)
fp.add_argument('--validation', dest='validation', action='store_true')
fp.add_argument('--testing', dest='validation', action='store_false')
parser.set_defaults(validation=True)
args = parser.parse_args()
print(args)
main(args)
......@@ -26,10 +26,6 @@ def main(args):
num_classes = dataset.num_classes
test_idx = dataset.test_idx
labels = dataset.labels
# check cuda
use_cuda = args.gpu >= 0 and th.cuda.is_available()
......@@ -52,7 +48,7 @@ def main(args):
print("start testing...")
model.eval()
    logits = model.forward()[category]
test_loss = F.cross_entropy(logits[test_idx], labels[test_idx])
test_acc = th.sum(logits[test_idx].argmax(dim=1) == labels[test_idx]).item() / len(test_idx)
print("Test Acc: {:.4f} | Test loss: {:.4f}".format(test_acc, test_loss.item()))
......@@ -79,4 +75,4 @@ if __name__ == '__main__':
args = parser.parse_args()
print(args)
    main(args)
......@@ -149,6 +149,8 @@ class DGLHeteroGraph(object):
>>> g['plays'].number_of_nodes() # ERROR!! There are two types 'user' and 'game'.
>>> g['plays'].number_of_edges() # OK!! because there is only one edge type 'plays'
TODO(minjie): docstring about uni-directional bipartite graph
Metagraph
---------
For each heterogeneous graph, one can often infer the *metagraph*, the template of
......@@ -171,8 +173,10 @@ class DGLHeteroGraph(object):
----------
gidx : HeteroGraphIndex
Graph index object.
ntypes : list of str
    ntypes : list of str, or pair of lists of str
        Node type list. ``ntypes[i]`` stores the name of node type i.
        If a pair is given, the graph created is a uni-directional bipartite graph,
        and its SRC node types and DST node types are given as in the pair.
etypes : list of str
Edge type list. ``etypes[i]`` stores the name of edge type i.
node_frames : list of FrameRef, optional
......@@ -196,22 +200,49 @@ class DGLHeteroGraph(object):
def _init(self, gidx, ntypes, etypes, node_frames, edge_frames):
"""Init internal states."""
self._graph = gidx
        # Handle node types
if isinstance(ntypes, tuple):
if len(ntypes) != 2:
errmsg = 'Invalid input. Expect a pair (srctypes, dsttypes) but got {}'.format(
ntypes)
raise TypeError(errmsg)
if not is_unibipartite(self._graph.metagraph):
raise ValueError('Invalid input. The metagraph must be a uni-directional'
' bipartite graph.')
self._ntypes = ntypes[0] + ntypes[1]
self._srctypes_invmap = {t : i for i, t in enumerate(ntypes[0])}
self._dsttypes_invmap = {t : i + len(ntypes[0]) for i, t in enumerate(ntypes[1])}
self._is_unibipartite = True
else:
self._ntypes = ntypes
src_dst_map = find_src_dst_ntypes(self._ntypes, self._graph.metagraph)
self._is_unibipartite = (src_dst_map is not None)
if self._is_unibipartite:
self._srctypes_invmap, self._dsttypes_invmap = src_dst_map
else:
self._srctypes_invmap = {t : i for i, t in enumerate(self._ntypes)}
self._dsttypes_invmap = self._srctypes_invmap
# Handle edge types
        self._etypes = etypes
self._canonical_etypes = make_canonical_etypes(
self._etypes, self._ntypes, self._graph.metagraph)
# An internal map from etype to canonical etype tuple.
        # If two etypes have the same name, an empty tuple is stored instead to indicate
        # ambiguity.
self._etype2canonical = {}
for i, ety in enumerate(self._etypes):
if ety in self._etype2canonical:
self._etype2canonical[ety] = tuple()
else:
self._etype2canonical[ety] = self._canonical_etypes[i]
        self._etypes_invmap = {t : i for i, t in enumerate(self._canonical_etypes)}
# Cached metagraph in networkx
self._nx_metagraph = None
# node and edge frame
if node_frames is None:
node_frames = [None] * len(self._ntypes)
......@@ -304,6 +335,24 @@ class DGLHeteroGraph(object):
# Metagraph query
#################################################################
@property
def is_unibipartite(self):
"""Return whether the graph is a uni-bipartite graph.
A uni-bipartite heterograph can further divide its node types into two sets:
SRC and DST. All edges are from nodes in SRC to nodes in DST. The following APIs
can be used to get the nodes and types that belong to SRC and DST sets:
        * :func:`srctypes` and :func:`dsttypes`
* :func:`srcdata` and :func:`dstdata`
* :func:`srcnodes` and :func:`dstnodes`
Note that we allow two node types to have the same name as long as one
belongs to SRC while the other belongs to DST. To distinguish them, prepend
the name with ``"SRC/"`` or ``"DST/"`` when specifying a node type.
"""
return self._is_unibipartite
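    # Example sketch (PyTorch backend assumed):
    #   g = dgl.bipartite([(0, 1)], 'user', 'plays', 'game')
    #   g.is_unibipartite          # True
    #   g.srctypes, g.dsttypes     # ['user'], ['game']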
@property
def ntypes(self):
"""Return the list of node types of this graph.
......@@ -364,22 +413,24 @@ class DGLHeteroGraph(object):
return self._canonical_etypes
@property
def ntype(self):
"""Return the node type if the graph has only one node type."""
assert len(self.ntypes) == 1, "The graph has more than one node type."
return self.ntypes[0]
@property
def srctypes(self):
"""Return the node types in the SRC category. Return :attr:``ntypes`` if
the graph is not a uni-bipartite graph.
"""
if self.is_unibipartite:
return sorted(list(self._srctypes_invmap.keys()))
else:
return self.ntypes
@property
def dsttypes(self):
"""Return the node types in the DST category. Return :attr:``ntypes`` if
the graph is not a uni-bipartite graph.
"""
if self.is_unibipartite:
return sorted(list(self._dsttypes_invmap.keys()))
else:
return self.ntypes
@property
def metagraph(self):
......@@ -479,16 +530,75 @@ class DGLHeteroGraph(object):
-------
int
"""
        if self.is_unibipartite and ntype is not None:
            # Only check the 'SRC/' and 'DST/' prefixes when the graph is uni-bipartite.
            if ntype.startswith('SRC/'):
                return self.get_ntype_id_from_src(ntype[4:])
            elif ntype.startswith('DST/'):
                return self.get_ntype_id_from_dst(ntype[4:])
            # If there is no prefix, fall back to the normal lookup below,
            # which checks both SRC and DST.
if ntype is None:
            if self.is_unibipartite or len(self._srctypes_invmap) != 1:
                raise DGLError('Node type name must be specified '
                               'if there is more than one node type.')
return 0
        ntid = self._srctypes_invmap.get(ntype, self._dsttypes_invmap.get(ntype, None))
if ntid is None:
raise DGLError('Node type "{}" does not exist.'.format(ntype))
return ntid
def get_ntype_id_from_src(self, ntype):
"""Return the id of the given SRC node type.
        ``ntype`` can also be None. If so, there must be only one node type in the
        SRC category. Callable even when the graph is not uni-bipartite.
Parameters
----------
ntype : str
Node type
Returns
-------
int
"""
if ntype is None:
if len(self._srctypes_invmap) != 1:
                raise DGLError('SRC node type name must be specified '
                               'if there is more than one SRC node type.')
return 0
ntid = self._srctypes_invmap.get(ntype, None)
if ntid is None:
raise DGLError('SRC node type "{}" does not exist.'.format(ntype))
return ntid
def get_ntype_id_from_dst(self, ntype):
"""Return the id of the given DST node type.
        ``ntype`` can also be None. If so, there must be only one node type in the
        DST category. Callable even when the graph is not uni-bipartite.
Parameters
----------
ntype : str
Node type
Returns
-------
int
"""
if ntype is None:
if len(self._dsttypes_invmap) != 1:
                raise DGLError('DST node type name must be specified '
                               'if there is more than one DST node type.')
return 0
ntid = self._dsttypes_invmap.get(ntype, None)
if ntid is None:
raise DGLError('DST node type "{}" does not exist.'.format(ntype))
return ntid
def get_etype_id(self, etype):
"""Return the id of the given edge type.
......@@ -536,7 +646,47 @@ class DGLHeteroGraph(object):
--------
ndata
"""
        return HeteroNodeView(self, self.get_ntype_id)
@property
def srcnodes(self):
"""Return a SRC node view that can be used to set/get feature
data of a single node type.
Examples
--------
The following example uses PyTorch backend.
To set features of all users
        >>> g = dgl.bipartite([(0, 1), (1, 2)], 'user', 'plays', 'game')
>>> g.srcnodes['user'].data['h'] = torch.zeros(2, 5)
See Also
--------
srcdata
"""
return HeteroNodeView(self, self.get_ntype_id_from_src)
@property
def dstnodes(self):
"""Return a DST node view that can be used to set/get feature
data of a single node type.
Examples
--------
The following example uses PyTorch backend.
To set features of all games
        >>> g = dgl.bipartite([(0, 1), (1, 2)], 'user', 'plays', 'game')
>>> g.dstnodes['game'].data['h'] = torch.zeros(3, 5)
See Also
--------
dstdata
"""
return HeteroNodeView(self, self.get_ntype_id_from_dst)
@property
def ndata(self):
......@@ -558,13 +708,16 @@ class DGLHeteroGraph(object):
--------
nodes
"""
        ntid = self.get_ntype_id(None)
        ntype = self.ntypes[0]
        return HeteroNodeDataView(self, ntype, ntid, ALL)
@property
def srcdata(self):
"""Return the data view of all source nodes.
"""Return the data view of all nodes in the SRC category.
**Only works if the graph has only one edge type.**
**Only works if the graph is uni-bipartite and has one node type in the
SRC category.**
Examples
--------
......@@ -579,6 +732,16 @@ class DGLHeteroGraph(object):
>>> g.nodes['user'].data['h'] = torch.zeros(2, 5)
        This also works on a more complex uni-bipartite graph:

        >>> g = dgl.heterograph({
        ...     ('user', 'plays', 'game'): [(0, 1), (1, 2)],
        ...     ('user', 'reads', 'book'): [(0, 1), (1, 0)],
        ... })
        >>> print(g.is_unibipartite)
        True
        >>> g.srcdata['h'] = torch.zeros(2, 5)
Notes
-----
This is identical to :any:`DGLHeteroGraph.ndata` if the graph is homogeneous.
......@@ -587,15 +750,18 @@ class DGLHeteroGraph(object):
--------
nodes
"""
        assert self.is_unibipartite, 'srcdata is only allowed for uni-bipartite graph.'
        assert len(self.srctypes) == 1, 'srcdata is only allowed when there is only one SRC type.'
        ntype = self.srctypes[0]
        ntid = self.get_ntype_id_from_src(ntype)
        return HeteroNodeDataView(self, ntype, ntid, ALL)
@property
def dstdata(self):
"""Return the data view of all destination nodes.
**Only works if the graph has only one edge type.**
**Only works if the graph is uni-bipartite and has one node type in the
DST category.**
Examples
--------
......@@ -610,6 +776,16 @@ class DGLHeteroGraph(object):
>>> g.nodes['game'].data['h'] = torch.zeros(3, 5)
        This also works on a more complex uni-bipartite graph:

        >>> g = dgl.heterograph({
        ...     ('user', 'plays', 'game'): [(0, 1), (1, 2)],
        ...     ('store', 'sells', 'game'): [(0, 1), (1, 0)],
        ... })
        >>> print(g.is_unibipartite)
        True
        >>> g.dstdata['h'] = torch.zeros(3, 5)
Notes
-----
This is identical to :any:`DGLHeteroGraph.ndata` if the graph is homogeneous.
......@@ -618,9 +794,11 @@ class DGLHeteroGraph(object):
--------
nodes
"""
        assert self.is_unibipartite, 'dstdata is only allowed for uni-bipartite graph.'
        assert len(self.dsttypes) == 1, 'dstdata is only allowed when there is only one DST type.'
        ntype = self.dsttypes[0]
        ntid = self.get_ntype_id_from_dst(ntype)
        return HeteroNodeDataView(self, ntype, ntid, ALL)
@property
def edges(self):
......@@ -715,9 +893,9 @@ class DGLHeteroGraph(object):
if len(etypes) == 1:
# no ambiguity: return the unitgraph itself
srctype, etype, dsttype = self._canonical_etypes[etypes[0]]
            stid = self.get_ntype_id_from_src(srctype)
            etid = self.get_etype_id((srctype, etype, dsttype))
            dtid = self.get_ntype_id_from_dst(dsttype)
new_g = self._graph.get_relation_graph(etid)
if stid == dtid:
......@@ -2728,7 +2906,7 @@ class DGLHeteroGraph(object):
"""
# infer receive node type
ntype = infer_ntype_from_dict(self, reducer_dict)
        ntid = self.get_ntype_id_from_dst(ntype)
if is_all(v):
v = F.arange(0, self.number_of_nodes(ntid))
elif isinstance(v, int):
......@@ -2937,7 +3115,7 @@ class DGLHeteroGraph(object):
"""
# infer receive node type
ntype = infer_ntype_from_dict(self, etype_dict)
        dtid = self.get_ntype_id_from_dst(ntype)
# TODO(minjie): currently loop over each edge type and reuse the old schedule.
# Should replace it with fused kernel.
......@@ -3129,7 +3307,7 @@ class DGLHeteroGraph(object):
return
# infer receive node type
ntype = infer_ntype_from_dict(self, etype_dict)
        dtid = self.get_ntype_id_from_dst(ntype)
# TODO(minjie): currently loop over each edge type and reuse the old schedule.
# Should replace it with fused kernel.
all_out = []
......@@ -3864,6 +4042,58 @@ def make_canonical_etypes(etypes, ntypes, metagraph):
rst = [(ntypes[sid], etypes[eid], ntypes[did]) for sid, did, eid in zip(src, dst, eid)]
return rst
def is_unibipartite(graph):
"""Internal function that returns whether the given graph is a uni-directional
bipartite graph.
Parameters
----------
graph : GraphIndex
Input graph
Returns
-------
bool
        True if the graph is uni-bipartite.
"""
src, dst, _ = graph.edges()
return set(src.tonumpy()).isdisjoint(set(dst.tonumpy()))
def find_src_dst_ntypes(ntypes, metagraph):
"""Internal function to split ntypes into SRC and DST categories.
    If the metagraph is not a uni-bipartite graph (so that the SRC and DST categories
    are not well-defined), return None.

    Isolated node types (i.e., those with no associated relation) are assigned to
    the SRC category.
Parameters
----------
ntypes : list of str
Node type list
metagraph : GraphIndex
Meta graph.
Returns
-------
    (dict[str, int], dict[str, int]) or None
        Node types belonging to SRC and DST categories. Types are stored in
        a dictionary from type name to type id. Return None if the graph is
        not uni-bipartite.
"""
src, dst, _ = metagraph.edges()
if set(src.tonumpy()).isdisjoint(set(dst.tonumpy())):
srctypes = {ntypes[tid] : tid for tid in src}
dsttypes = {ntypes[tid] : tid for tid in dst}
# handle isolated node types
for ntid, ntype in enumerate(ntypes):
if ntype not in srctypes and ntype not in dsttypes:
srctypes[ntype] = ntid
return srctypes, dsttypes
else:
return None
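# For example, with ntypes ['user', 'game', 'store'] and metagraph edges
# user->game and store->game, the function returns
# ({'user': 0, 'store': 2}, {'game': 1}); with an extra edge game->user,
# it returns None because the SRC and DST sets overlap.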
def infer_ntype_from_dict(graph, etype_dict):
"""Infer node type from dictionary of edge type to values.
......
......@@ -743,38 +743,46 @@ def compact_graphs(graphs, always_preserve=None):
return new_graphs
def to_block(g, dst_nodes=None):
"""Convert a graph into a bipartite-structured "block" for message passing.
    A block graph is a uni-directional bipartite graph consisting of two sets of nodes
    SRC and DST. Each set can have many node types while all the edges are from SRC
    nodes to DST nodes.
Specifically, for each relation graph of canonical edge type ``(utype, etype, vtype)``,
node type ``utype`` belongs to SRC while ``vtype`` belongs to DST.
Edges from node type ``utype`` to node type ``vtype`` are preserved. If
``utype == vtype``, the result graph will have two node types of the same name ``utype``,
but one belongs to SRC while the other belongs to DST. This is because although
they have the same name, their node ids are relabeled differently (see below). In
both cases, the canonical edge type in the new graph is still
``(utype, etype, vtype)``, so there is no difference when referring to it.
    We refer to such bipartite-structured graphs as **blocks**.
Moreover, the function also relabels node ids in each type to make the graph more compact.
Specifically, the nodes of type ``vtype`` would contain the nodes that have at least one
inbound edge of any type, while ``utype`` would contain all the DST nodes of type ``utype``,
as well as the nodes that have at least one outbound edge to any DST node.
    Since DST nodes are included in SRC nodes, a common requirement is to fetch
    the DST node features from the SRC node features. To avoid an expensive sparse
    lookup, the function ensures that the DST nodes in both SRC and DST sets have
    the same ids.
As a result, given the node feature tensor ``X`` of type ``utype``,
the following code finds the corresponding DST node features of type ``vtype``:
.. code::
X[:block.number_of_nodes('DST/vtype')]
If the ``dst_nodes`` argument is given, the DST nodes would contain the given nodes.
Otherwise, the DST nodes would be determined by DGL via the rules above.
Parameters
----------
    g : DGLHeteroGraph
The graph.
    dst_nodes : Tensor or dict[str, Tensor], optional
        Optional DST nodes. If a tensor is given, the graph must have only one node type.
Returns
-------
......@@ -786,12 +794,9 @@ def to_block(g, rhs_nodes=None, lhs_suffix="_l", rhs_suffix="_r"):
The edge IDs induced for each type would be stored in feature ``dgl.EID``.
Notes
-----
    This function is primarily for creating the structures for efficient
computation of message passing. See [TODO] for a detailed example.
Examples
......@@ -830,63 +835,64 @@ def to_block(g, rhs_nodes=None, lhs_suffix="_l", rhs_suffix="_r"):
the right hand side nodes, you have to give a dict:
>>> g = dgl.bipartite([(0, 1), (1, 2), (2, 3)], utype='A', vtype='B')
    If you don't specify any node of type A on the right hand side, the node type ``A``
in the block would have zero nodes.
>>> block = dgl.to_block(g, {'B': torch.LongTensor([3, 2])})
    >>> block.number_of_nodes('A')
0
    >>> block.number_of_nodes('B')
2
    >>> block.nodes['B'].data[dgl.NID]
tensor([3, 2])
The left hand side would contain all the nodes on the right hand side:
    >>> block.nodes['B'].data[dgl.NID]
tensor([3, 2])
As well as all the nodes that have connections to the nodes on the right hand side:
    >>> block.nodes['A'].data[dgl.NID]
tensor([2, 1])
"""
    if dst_nodes is None:
# Find all nodes that appeared as destinations
        dst_nodes = defaultdict(list)
for etype in g.canonical_etypes:
_, dst = g.edges(etype=etype)
            dst_nodes[etype[2]].append(dst)
        dst_nodes = {ntype: F.unique(F.cat(values, 0)) for ntype, values in dst_nodes.items()}
    elif not isinstance(dst_nodes, Mapping):
        # dst_nodes is a Tensor, check if the g has only one type.
if len(g.ntypes) > 1:
raise ValueError(
                'Graph has more than one node type; please specify a dict for dst_nodes.')
        dst_nodes = {g.ntypes[0]: dst_nodes}
    # dst_nodes is now a dict
    dst_nodes_nd = []
for ntype in g.ntypes:
        nodes = dst_nodes.get(ntype, None)
if nodes is not None:
            dst_nodes_nd.append(F.zerocopy_to_dgl_ndarray(nodes))
else:
            dst_nodes_nd.append(nd.null())
    new_graph_index, src_nodes_nd, induced_edges_nd = _CAPI_DGLToBlock(g._graph, dst_nodes_nd)
    src_nodes = [F.zerocopy_from_dgl_ndarray(nodes_nd.data) for nodes_nd in src_nodes_nd]
    dst_nodes = [F.zerocopy_from_dgl_ndarray(nodes_nd) for nodes_nd in dst_nodes_nd]
    # The new graph duplicates the original node types to SRC and DST sets.
    new_ntypes = ([ntype for ntype in g.ntypes], [ntype for ntype in g.ntypes])
new_graph = DGLHeteroGraph(new_graph_index, new_ntypes, g.etypes)
assert new_graph.is_unibipartite # sanity check
for i, ntype in enumerate(g.ntypes):
        new_graph.srcnodes[ntype].data[NID] = src_nodes[i]
        new_graph.dstnodes[ntype].data[NID] = dst_nodes[i]
for i, canonical_etype in enumerate(g.canonical_etypes):
induced_edges = F.zerocopy_from_dgl_ndarray(induced_edges_nd[i].data)
utype, etype, vtype = canonical_etype
        new_canonical_etype = (utype, etype, vtype)
new_graph.edges[new_canonical_etype].data[EID] = induced_edges
return new_graph
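# A typical sampling loop built on top of to_block (sketch; `g` is assumed
# homogeneous and `seeds` a tensor of seed node ids):
#
#   frontier = dgl.sampling.sample_neighbors(g, seeds, fanout)
#   block = dgl.to_block(frontier, seeds)
#   seeds = block.srcdata[dgl.NID]   # ids to expand in the next hop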
......
......@@ -250,10 +250,11 @@ class BlockDataView(MutableMapping):
class HeteroNodeView(object):
"""A NodeView class to act as G.nodes for a DGLHeteroGraph."""
    __slots__ = ['_graph', '_typeid_getter']

    def __init__(self, graph, typeid_getter):
self._graph = graph
self._typeid_getter = typeid_getter
def __getitem__(self, key):
if isinstance(key, slice):
......@@ -271,7 +272,8 @@ class HeteroNodeView(object):
else:
nodes = key
ntype = None
        ntid = self._typeid_getter(ntype)
        return NodeSpace(data=HeteroNodeDataView(self._graph, ntype, ntid, nodes))
def __call__(self, ntype=None):
"""Return the nodes."""
......@@ -281,10 +283,10 @@ class HeteroNodeDataView(MutableMapping):
"""The data view class when G.ndata[ntype] is called."""
__slots__ = ['_graph', '_ntype', '_ntid', '_nodes']
    def __init__(self, graph, ntype, ntid, nodes):
self._graph = graph
self._ntype = ntype
        self._ntid = ntid
self._nodes = nodes
def __getitem__(self, key):
......
......@@ -1473,6 +1473,43 @@ def test_isolated_ntype():
assert g.number_of_nodes('B') == 4
assert g.number_of_nodes('C') == 4
def test_bipartite():
g1 = dgl.bipartite([(0, 1), (0, 2), (1, 5)], 'A', 'AB', 'B')
assert g1.is_unibipartite
assert len(g1.ntypes) == 2
assert g1.etypes == ['AB']
assert g1.srctypes == ['A']
assert g1.dsttypes == ['B']
assert g1.number_of_nodes('A') == 2
assert g1.number_of_nodes('B') == 6
assert g1.number_of_edges() == 3
g1.srcdata['h'] = F.randn((2, 5))
assert F.array_equal(g1.srcnodes['A'].data['h'], g1.srcdata['h'])
assert F.array_equal(g1.nodes['A'].data['h'], g1.srcdata['h'])
assert F.array_equal(g1.nodes['SRC/A'].data['h'], g1.srcdata['h'])
g1.dstdata['h'] = F.randn((6, 3))
assert F.array_equal(g1.dstnodes['B'].data['h'], g1.dstdata['h'])
assert F.array_equal(g1.nodes['B'].data['h'], g1.dstdata['h'])
assert F.array_equal(g1.nodes['DST/B'].data['h'], g1.dstdata['h'])
# more complicated bipartite
g2 = dgl.bipartite([(1, 0), (0, 0)], 'A', 'AC', 'C')
g3 = dgl.hetero_from_relations([g1, g2])
assert g3.is_unibipartite
assert g3.srctypes == ['A']
assert set(g3.dsttypes) == {'B', 'C'}
assert g3.number_of_nodes('A') == 2
assert g3.number_of_nodes('B') == 6
assert g3.number_of_nodes('C') == 1
g3.srcdata['h'] = F.randn((2, 5))
assert F.array_equal(g3.srcnodes['A'].data['h'], g3.srcdata['h'])
assert F.array_equal(g3.nodes['A'].data['h'], g3.srcdata['h'])
assert F.array_equal(g3.nodes['SRC/A'].data['h'], g3.srcdata['h'])
g4 = dgl.graph([(0, 0), (1, 1)], 'A', 'AA')
g5 = dgl.hetero_from_relations([g1, g2, g4])
assert not g5.is_unibipartite
if __name__ == '__main__':
test_create()
test_query()
......@@ -1496,3 +1533,4 @@ if __name__ == '__main__':
test_types_in_function()
test_stack_reduce()
test_isolated_ntype()
test_bipartite()
......@@ -427,13 +427,13 @@ def test_to_simple():
@unittest.skipIf(F._default_context_str == 'gpu', reason="GPU compaction not implemented")
def test_to_block():
    def check(g, bg, ntype, etype, dst_nodes):
        if dst_nodes is not None:
            assert F.array_equal(bg.dstnodes[ntype].data[dgl.NID], dst_nodes)
        n_dst_nodes = bg.number_of_nodes('DST/' + ntype)
assert F.array_equal(
            bg.srcnodes[ntype].data[dgl.NID][:n_dst_nodes],
            bg.dstnodes[ntype].data[dgl.NID])
g = g[etype]
bg = bg[etype]
......@@ -451,11 +451,11 @@ def test_to_block():
assert F.array_equal(induced_src_bg, induced_src_ans)
assert F.array_equal(induced_dst_bg, induced_dst_ans)
    def checkall(g, bg, dst_nodes):
for etype in g.etypes:
ntype = g.to_canonical_etype(etype)[2]
            if dst_nodes is not None and ntype in dst_nodes:
                check(g, bg, ntype, etype, dst_nodes[ntype])
else:
check(g, bg, ntype, etype, None)
......@@ -468,36 +468,36 @@ def test_to_block():
bg = dgl.to_block(g_a)
check(g_a, bg, 'A', 'AA', None)
    dst_nodes = F.tensor([3, 4], dtype=F.int64)
    bg = dgl.to_block(g_a, dst_nodes)
    check(g_a, bg, 'A', 'AA', dst_nodes)

    dst_nodes = F.tensor([4, 3, 2, 1], dtype=F.int64)
    bg = dgl.to_block(g_a, dst_nodes)
    check(g_a, bg, 'A', 'AA', dst_nodes)
g_ab = g['AB']
bg = dgl.to_block(g_ab)
    assert bg.number_of_nodes('SRC/B') == 4
    assert F.array_equal(bg.srcnodes['B'].data[dgl.NID], bg.dstnodes['B'].data[dgl.NID])
    assert bg.number_of_nodes('DST/A') == 0
checkall(g_ab, bg, None)
    dst_nodes = {'B': F.tensor([5, 6], dtype=F.int64)}
    bg = dgl.to_block(g, dst_nodes)
    assert bg.number_of_nodes('SRC/B') == 2
    assert F.array_equal(bg.srcnodes['B'].data[dgl.NID], bg.dstnodes['B'].data[dgl.NID])
    assert bg.number_of_nodes('DST/A') == 0
    checkall(g, bg, dst_nodes)
    dst_nodes = {'A': F.tensor([3, 4], dtype=F.int64), 'B': F.tensor([5, 6], dtype=F.int64)}
    bg = dgl.to_block(g, dst_nodes)
    checkall(g, bg, dst_nodes)
    dst_nodes = {'A': F.tensor([4, 3, 2, 1], dtype=F.int64), 'B': F.tensor([3, 5, 6, 1], dtype=F.int64)}
    bg = dgl.to_block(g, dst_nodes=dst_nodes)
    checkall(g, bg, dst_nodes)
@unittest.skipIf(F._default_context_str == 'gpu', reason="GPU not implemented")
def test_remove_edges():
......
......@@ -37,6 +37,6 @@ python3 -m pytest -v --junitxml=pytest_gindex.xml tests/graph_index || fail "gra
python3 -m pytest -v --junitxml=pytest_backend.xml tests/$DGLBACKEND || fail "backend-specific"
export OMP_NUM_THREADS=1
if [ $2 != "gpu" ] && [ $1 != "tensorflow"]; then
if [ $2 != "gpu" ] && [ $1 != "tensorflow" ]; then
python3 -m pytest -v --junitxml=pytest_distributed.xml tests/distributed || fail "distributed"
fi