Commit 7e30382e authored by Ziyue Huang, committed by Da Zheng

[Model][MXNet] neighbor sampling & skip connection & control variate & graphsage (#322)

* neighbor sampling draft

* val/test acc

* control variate draft

* control variate

* update

* fix new_history

* maintain aggregated history while updating new history

* preprocess the first layer, change push to pull

* update

* fix subg_degree

* nodeflow

* clear

* readme

* doc and unittest for self loop

* address comments

* rename

* update

* fix

* Update node_flow.py

* Update node_flow.py
parent 3f891b64
# Stochastic Training for Graph Convolutional Networks
* Paper: [Control Variate](https://arxiv.org/abs/1710.10568)
* Paper: [Skip Connection](https://arxiv.org/abs/1809.05343)
* Author's code: [https://github.com/thu-ml/stochastic_gcn](https://github.com/thu-ml/stochastic_gcn)
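
In brief, the control-variate technique keeps a stale history $\bar{h}_u^{(l)}$ of every node's layer-$l$ activation and samples only the *change* relative to that history. A sketch of the estimator in our notation (paraphrasing the paper; $\hat{\mathcal{N}}(v)$ is the sampled neighborhood, $\mathcal{N}(v)$ the full one):

$$
h_{\mathcal{N}(v)}^{(l)} \approx \frac{1}{|\hat{\mathcal{N}}(v)|}\sum_{u\in\hat{\mathcal{N}}(v)}\left(h_u^{(l)}-\bar{h}_u^{(l)}\right) \;+\; \frac{1}{|\mathcal{N}(v)|}\sum_{u\in\mathcal{N}(v)}\bar{h}_u^{(l)}
$$

The second term is precomputed exactly on the full graph, and the first shrinks as training converges ($h \to \bar{h}$), which is why sampling a single neighbor (`--num-neighbors 1` below) is enough.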
### Dependencies
- MXNet nightly build
```bash
pip install mxnet --pre
```
### Neighbor Sampling & Skip Connection
cora: test accuracy ~83% with `--num-neighbors 2`, ~84% by training on the full graph
```bash
DGLBACKEND=mxnet python gcn_ns_sc.py --dataset cora --self-loop --num-neighbors 2 --batch-size 1000000 --test-batch-size 1000000 --gpu 0
```
citeseer: test accuracy ~69% with `--num-neighbors 2`, ~70% by training on the full graph
```bash
DGLBACKEND=mxnet python gcn_ns_sc.py --dataset citeseer --self-loop --num-neighbors 2 --batch-size 1000000 --test-batch-size 1000000 --gpu 0
```
pubmed: test accuracy ~76% with `--num-neighbors 3`, ~77% by training on the full graph
```bash
DGLBACKEND=mxnet python gcn_ns_sc.py --dataset pubmed --self-loop --num-neighbors 3 --batch-size 1000000 --test-batch-size 1000000 --gpu 0
```
reddit: test accuracy ~91% with `--num-neighbors 2` and `--batch-size 1000`, ~93% by training on the full graph
```bash
DGLBACKEND=mxnet python gcn_ns_sc.py --dataset reddit-self-loop --num-neighbors 2 --batch-size 1000 --test-batch-size 500 --n-hidden 64
```
### Control Variate & Skip Connection
cora: test accuracy ~84% with `--num-neighbors 1`, ~84% by training on the full graph
```bash
DGLBACKEND=mxnet python gcn_cv_sc.py --dataset cora --self-loop --num-neighbors 1 --batch-size 1000000 --test-batch-size 1000000 --gpu 0
```
citeseer: test accuracy ~69% with `--num-neighbors 1`, ~70% by training on the full graph
```bash
DGLBACKEND=mxnet python gcn_cv_sc.py --dataset citeseer --self-loop --num-neighbors 1 --batch-size 1000000 --test-batch-size 1000000 --gpu 0
```
pubmed: test accuracy ~77% with `--num-neighbors 1`, ~77% by training on the full graph
```bash
DGLBACKEND=mxnet python gcn_cv_sc.py --dataset pubmed --self-loop --num-neighbors 1 --batch-size 1000000 --test-batch-size 1000000 --gpu 0
```
reddit: test accuracy ~93% with `--num-neighbors 1` and `--batch-size 1000`, ~93% by training on the full graph
```bash
DGLBACKEND=mxnet python gcn_cv_sc.py --dataset reddit-self-loop --num-neighbors 1 --batch-size 1000 --test-batch-size 500 --n-hidden 64
```
### Control Variate & GraphSAGE-mean
Following [Control Variate](https://arxiv.org/abs/1710.10568), we use the mean-pooling architecture GraphSAGE-mean: each graph convolution layer concatenates a node's aggregated neighborhood with its own activation and passes the result through two linear layers, each followed by layer normalization.
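
Schematically, a hidden layer here computes roughly (LN = layer normalization; $\hat{h}_{\mathcal{N}(v)}$ is the control-variate estimate of the neighborhood mean; the output layer drops the second LN/ReLU):

$$
h_v^{(l+1)} = \mathrm{ReLU}\!\left(\mathrm{LN}\!\left(W_2\,\mathrm{ReLU}\!\left(\mathrm{LN}\!\left(W_1\left[\hat{h}_{\mathcal{N}(v)}^{(l)} \,\|\, h_v^{(l)}\right]\right)\right)\right)\right)
$$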
reddit: test accuracy ~96.1% with `--num-neighbors 1` and `--batch-size 1000`; [Control Variate](https://arxiv.org/abs/1710.10568) reports ~96.2% with `--num-neighbors 2` and `--batch-size 1000`
```bash
DGLBACKEND=mxnet python graphsage_cv.py --batch-size 1000 --test-batch-size 500 --n-epochs 50 --dataset reddit --num-neighbors 1 --n-hidden 128 --dropout 0.2 --weight-decay 0
```
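
All three scripts below follow the same minibatch pattern: sample a NodeFlow, copy the required node data from the parent graph, run the model block by block, and take an optimizer step. A minimal sketch of that loop (not a standalone script; it assumes a read-only `DGLGraph` `g`, `labels`, a Gluon `model`, `loss_fcn`, `trainer`, and the hyperparameters as defined in the listings that follow):

```python
import mxnet as mx
import dgl

for nf, aux in dgl.contrib.sampling.NeighborSampler(g, batch_size,
                                                     num_neighbors,
                                                     neighbor_type='in',
                                                     shuffle=True,
                                                     num_hops=n_layers,
                                                     seed_nodes=train_nid):
    nf.copy_from_parent()                   # pull node data into the NodeFlow
    with mx.autograd.record():
        pred = model(nf)                    # runs nf.block_compute layer by layer
        batch_nids = nf.layer_parent_nid(-1)
        loss = loss_fcn(pred, labels[batch_nids]).mean()
    loss.backward()
    trainer.step(batch_size=1)
```

The full listings follow. First, `gcn_cv_sc.py` (control variate & skip connection):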
import argparse, time, math
import numpy as np
import mxnet as mx
from mxnet import gluon
import dgl
import dgl.function as fn
from dgl import DGLGraph
from dgl.data import register_data_args, load_data


class NodeUpdate(gluon.Block):
    def __init__(self, layer_id, in_feats, out_feats, dropout, activation=None, test=False, concat=False):
        super(NodeUpdate, self).__init__()
        self.layer_id = layer_id
        self.dropout = dropout
        self.test = test
        self.concat = concat
        with self.name_scope():
            self.dense = gluon.nn.Dense(out_feats, in_units=in_feats)
            self.activation = activation

    def forward(self, node):
        h = node.data['h']
        if self.test:
            norm = node.data['norm']
            h = h * norm
        else:
            agg_history_str = 'agg_h_{}'.format(self.layer_id-1)
            agg_history = node.data[agg_history_str]
            # control variate
            h = h + agg_history
            if self.dropout:
                h = mx.nd.Dropout(h, p=self.dropout)
        h = self.dense(h)
        if self.concat:
            h = mx.nd.concat(h, self.activation(h))
        elif self.activation:
            h = self.activation(h)
        return {'activation': h}


class GCNSampling(gluon.Block):
    def __init__(self,
                 in_feats,
                 n_hidden,
                 n_classes,
                 n_layers,
                 activation,
                 dropout,
                 **kwargs):
        super(GCNSampling, self).__init__(**kwargs)
        self.dropout = dropout
        self.n_layers = n_layers
        with self.name_scope():
            self.layers = gluon.nn.Sequential()
            # input layer
            self.dense = gluon.nn.Dense(n_hidden, in_units=in_feats)
            self.activation = activation
            # hidden layers
            for i in range(1, n_layers):
                skip_start = (i == self.n_layers-1)
                self.layers.add(NodeUpdate(i, n_hidden, n_hidden, dropout, activation, concat=skip_start))
            # output layer
            self.layers.add(NodeUpdate(n_layers, 2*n_hidden, n_classes, dropout))

    def forward(self, nf):
        h = nf.layers[0].data['preprocess']
        if self.dropout:
            h = mx.nd.Dropout(h, p=self.dropout)
        h = self.dense(h)

        skip_start = (0 == self.n_layers-1)
        if skip_start:
            h = mx.nd.concat(h, self.activation(h))
        else:
            h = self.activation(h)

        for i, layer in enumerate(self.layers):
            new_history = h.copy().detach()
            history_str = 'h_{}'.format(i)
            history = nf.layers[i].data[history_str]
            h = h - history

            nf.layers[i].data['h'] = h
            nf.block_compute(i,
                             fn.copy_src(src='h', out='m'),
                             lambda node: {'h': node.mailbox['m'].mean(axis=1)},
                             layer)
            h = nf.layers[i+1].data.pop('activation')
            # update history
            if i < nf.num_layers-1:
                nf.layers[i].data[history_str] = new_history

        return h
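
# Note on the loop above: each block aggregates (h - history) over the sampled
# neighbors only; NodeUpdate then adds agg_h_{layer-1}, the exact aggregate of
# the stale histories precomputed on the parent graph, completing the
# control-variate estimate of the full-neighborhood aggregation.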


class GCNInfer(gluon.Block):
    def __init__(self,
                 in_feats,
                 n_hidden,
                 n_classes,
                 n_layers,
                 activation,
                 **kwargs):
        super(GCNInfer, self).__init__(**kwargs)
        self.n_layers = n_layers
        with self.name_scope():
            self.layers = gluon.nn.Sequential()
            # input layer
            self.dense = gluon.nn.Dense(n_hidden, in_units=in_feats)
            self.activation = activation
            # hidden layers
            for i in range(1, n_layers):
                skip_start = (i == self.n_layers-1)
                self.layers.add(NodeUpdate(i, n_hidden, n_hidden, 0, activation, True, concat=skip_start))
            # output layer
            self.layers.add(NodeUpdate(n_layers, 2*n_hidden, n_classes, 0, None, True))

    def forward(self, nf):
        h = nf.layers[0].data['preprocess']
        h = self.dense(h)

        skip_start = (0 == self.n_layers-1)
        if skip_start:
            h = mx.nd.concat(h, self.activation(h))
        else:
            h = self.activation(h)

        for i, layer in enumerate(self.layers):
            nf.layers[i].data['h'] = h
            nf.block_compute(i,
                             fn.copy_src(src='h', out='m'),
                             fn.sum(msg='m', out='h'),
                             layer)
            h = nf.layers[i+1].data.pop('activation')

        return h


def main(args):
    # load and preprocess dataset
    data = load_data(args)

    if args.gpu >= 0:
        ctx = mx.gpu(args.gpu)
    else:
        ctx = mx.cpu()

    if args.self_loop and not args.dataset.startswith('reddit'):
        data.graph.add_edges_from([(i, i) for i in range(len(data.graph))])

    train_nid = mx.nd.array(np.nonzero(data.train_mask)[0]).astype(np.int64)
    test_nid = mx.nd.array(np.nonzero(data.test_mask)[0]).astype(np.int64)

    num_neighbors = args.num_neighbors
    n_layers = args.n_layers

    features = mx.nd.array(data.features).as_in_context(ctx)
    labels = mx.nd.array(data.labels).as_in_context(ctx)
    train_mask = mx.nd.array(data.train_mask).as_in_context(ctx)
    val_mask = mx.nd.array(data.val_mask).as_in_context(ctx)
    test_mask = mx.nd.array(data.test_mask).as_in_context(ctx)
    in_feats = features.shape[1]
    n_classes = data.num_labels
    n_edges = data.graph.number_of_edges()

    n_train_samples = train_mask.sum().asscalar()
    n_test_samples = test_mask.sum().asscalar()
    n_val_samples = val_mask.sum().asscalar()

    print("""----Data statistics------
      #Edges %d
      #Classes %d
      #Train samples %d
      #Val samples %d
      #Test samples %d""" %
          (n_edges, n_classes,
           n_train_samples,
           n_val_samples,
           n_test_samples))

    # create GCN model
    g = DGLGraph(data.graph, readonly=True)
    g.ndata['features'] = features

    norm = mx.nd.expand_dims(1. / g.in_degrees().astype('float32'), 1)
    g.ndata['norm'] = norm.as_in_context(ctx)

    g.update_all(fn.copy_src(src='features', out='m'),
                 fn.sum(msg='m', out='preprocess'),
                 lambda node: {'preprocess': node.data['preprocess'] * node.data['norm']})

    for i in range(n_layers):
        g.ndata['h_{}'.format(i)] = mx.nd.zeros((features.shape[0], args.n_hidden), ctx=ctx)
    g.ndata['h_{}'.format(n_layers-1)] = mx.nd.zeros((features.shape[0], 2*args.n_hidden), ctx=ctx)

    model = GCNSampling(in_feats,
                        args.n_hidden,
                        n_classes,
                        n_layers,
                        mx.nd.relu,
                        args.dropout,
                        prefix='GCN')
    model.initialize(ctx=ctx)
    loss_fcn = gluon.loss.SoftmaxCELoss()

    infer_model = GCNInfer(in_feats,
                           args.n_hidden,
                           n_classes,
                           n_layers,
                           mx.nd.relu,
                           prefix='GCN')
    infer_model.initialize(ctx=ctx)

    # use optimizer
    print(model.collect_params())
    trainer = gluon.Trainer(model.collect_params(), 'adam',
                            {'learning_rate': args.lr, 'wd': args.weight_decay},
                            kvstore=mx.kv.create('local'))

    # initialize graph
    dur = []
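    # Each sampled NodeFlow first pre-aggregates the stale histories h_i on
    # the parent graph (g.pull below) into agg_h_i, then copies them into the
    # NodeFlow so NodeUpdate can add them back to the sampled sum of deltas.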
    for epoch in range(args.n_epochs):
        for nf, aux in dgl.contrib.sampling.NeighborSampler(g, args.batch_size,
                                                            num_neighbors,
                                                            neighbor_type='in',
                                                            shuffle=True,
                                                            num_hops=n_layers,
                                                            seed_nodes=train_nid):
            for i in range(n_layers):
                agg_history_str = 'agg_h_{}'.format(i)
                g.pull(nf.layer_parent_nid(i+1),
                       fn.copy_src(src='h_{}'.format(i), out='m'),
                       fn.sum(msg='m', out=agg_history_str),
                       lambda node: {agg_history_str: node.data[agg_history_str] * node.data['norm']})

            node_embed_names = [['preprocess', 'h_0']]
            for i in range(1, n_layers):
                node_embed_names.append(['h_{}'.format(i), 'agg_h_{}'.format(i-1)])
            node_embed_names.append(['agg_h_{}'.format(n_layers-1)])
            nf.copy_from_parent(node_embed_names=node_embed_names)

            # forward
            with mx.autograd.record():
                pred = model(nf)
                batch_nids = nf.layer_parent_nid(-1).as_in_context(ctx)
                batch_labels = labels[batch_nids]
                loss = loss_fcn(pred, batch_labels)
                loss = loss.sum() / len(batch_nids)

            loss.backward()
            trainer.step(batch_size=1)

            node_embed_names = [['h_{}'.format(i)] for i in range(n_layers)]
            node_embed_names.append([])
            nf.copy_to_parent(node_embed_names=node_embed_names)

        infer_params = infer_model.collect_params()
        for key in infer_params:
            idx = trainer._param2idx[key]
            trainer._kvstore.pull(idx, out=infer_params[key].data())

        num_acc = 0.
        for nf, aux in dgl.contrib.sampling.NeighborSampler(g, args.test_batch_size,
                                                            g.number_of_nodes(),
                                                            neighbor_type='in',
                                                            num_hops=n_layers,
                                                            seed_nodes=test_nid):
            node_embed_names = [['preprocess']]
            for i in range(n_layers):
                node_embed_names.append(['norm'])
            nf.copy_from_parent(node_embed_names=node_embed_names)

            pred = infer_model(nf)
            batch_nids = nf.layer_parent_nid(-1).as_in_context(ctx)
            batch_labels = labels[batch_nids]
            num_acc += (pred.argmax(axis=1) == batch_labels).sum().asscalar()

        print("Test Accuracy {:.4f}".format(num_acc / n_test_samples))


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='GCN')
    register_data_args(parser)
    parser.add_argument("--dropout", type=float, default=0.5,
                        help="dropout probability")
    parser.add_argument("--gpu", type=int, default=-1,
                        help="gpu")
    parser.add_argument("--lr", type=float, default=3e-2,
                        help="learning rate")
    parser.add_argument("--n-epochs", type=int, default=200,
                        help="number of training epochs")
    parser.add_argument("--batch-size", type=int, default=1000,
                        help="train batch size")
    parser.add_argument("--test-batch-size", type=int, default=1000,
                        help="test batch size")
    parser.add_argument("--num-neighbors", type=int, default=2,
                        help="number of neighbors to be sampled")
    parser.add_argument("--n-hidden", type=int, default=16,
                        help="number of hidden gcn units")
    parser.add_argument("--n-layers", type=int, default=1,
                        help="number of hidden gcn layers")
    parser.add_argument("--self-loop", action='store_true',
                        help="graph self-loop (default=False)")
    parser.add_argument("--weight-decay", type=float, default=5e-4,
                        help="Weight for L2 loss")
    args = parser.parse_args()

    print(args)
    main(args)
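
Next, `gcn_ns_sc.py` (neighbor sampling & skip connection):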
import argparse, time, math
import numpy as np
import mxnet as mx
from mxnet import gluon
from functools import partial
import dgl
import dgl.function as fn
from dgl import DGLGraph
from dgl.data import register_data_args, load_data


class NodeUpdate(gluon.Block):
    def __init__(self, in_feats, out_feats, activation=None, test=False, concat=False):
        super(NodeUpdate, self).__init__()
        self.dense = gluon.nn.Dense(out_feats, in_units=in_feats)
        self.activation = activation
        self.concat = concat
        self.test = test

    def forward(self, node):
        h = node.data['h']
        if self.test:
            h = h * node.data['norm']
        h = self.dense(h)
        # skip connection
        if self.concat:
            h = mx.nd.concat(h, self.activation(h))
        elif self.activation:
            h = self.activation(h)
        return {'activation': h}


class GCNSampling(gluon.Block):
    def __init__(self,
                 in_feats,
                 n_hidden,
                 n_classes,
                 n_layers,
                 activation,
                 dropout,
                 **kwargs):
        super(GCNSampling, self).__init__(**kwargs)
        self.dropout = dropout
        self.n_layers = n_layers
        with self.name_scope():
            self.layers = gluon.nn.Sequential()
            # input layer
            skip_start = (0 == n_layers-1)
            self.layers.add(NodeUpdate(in_feats, n_hidden, activation, concat=skip_start))
            # hidden layers
            for i in range(1, n_layers):
                skip_start = (i == n_layers-1)
                self.layers.add(NodeUpdate(n_hidden, n_hidden, activation, concat=skip_start))
            # output layer
            self.layers.add(NodeUpdate(2*n_hidden, n_classes))

    def forward(self, nf):
        nf.layers[0].data['activation'] = nf.layers[0].data['features']

        for i, layer in enumerate(self.layers):
            h = nf.layers[i].data.pop('activation')
            if self.dropout:
                h = mx.nd.Dropout(h, p=self.dropout)
            nf.layers[i].data['h'] = h
            nf.block_compute(i,
                             fn.copy_src(src='h', out='m'),
                             lambda node: {'h': node.mailbox['m'].mean(axis=1)},
                             layer)

        h = nf.layers[-1].data.pop('activation')
        return h


class GCNInfer(gluon.Block):
    def __init__(self,
                 in_feats,
                 n_hidden,
                 n_classes,
                 n_layers,
                 activation,
                 **kwargs):
        super(GCNInfer, self).__init__(**kwargs)
        self.n_layers = n_layers
        with self.name_scope():
            self.layers = gluon.nn.Sequential()
            # input layer
            skip_start = (0 == n_layers-1)
            self.layers.add(NodeUpdate(in_feats, n_hidden, activation, test=True, concat=skip_start))
            # hidden layers
            for i in range(1, n_layers):
                skip_start = (i == n_layers-1)
                self.layers.add(NodeUpdate(n_hidden, n_hidden, activation, test=True, concat=skip_start))
            # output layer
            self.layers.add(NodeUpdate(2*n_hidden, n_classes, test=True))

    def forward(self, nf):
        nf.layers[0].data['activation'] = nf.layers[0].data['features']

        for i, layer in enumerate(self.layers):
            h = nf.layers[i].data.pop('activation')
            nf.layers[i].data['h'] = h
            nf.block_compute(i,
                             fn.copy_src(src='h', out='m'),
                             fn.sum(msg='m', out='h'),
                             layer)

        h = nf.layers[-1].data.pop('activation')
        return h


def main(args):
    # load and preprocess dataset
    data = load_data(args)

    if args.gpu >= 0:
        ctx = mx.gpu(args.gpu)
    else:
        ctx = mx.cpu()

    if args.self_loop and not args.dataset.startswith('reddit'):
        data.graph.add_edges_from([(i, i) for i in range(len(data.graph))])

    train_nid = mx.nd.array(np.nonzero(data.train_mask)[0]).astype(np.int64).as_in_context(ctx)
    test_nid = mx.nd.array(np.nonzero(data.test_mask)[0]).astype(np.int64).as_in_context(ctx)

    features = mx.nd.array(data.features).as_in_context(ctx)
    labels = mx.nd.array(data.labels).as_in_context(ctx)
    train_mask = mx.nd.array(data.train_mask).as_in_context(ctx)
    val_mask = mx.nd.array(data.val_mask).as_in_context(ctx)
    test_mask = mx.nd.array(data.test_mask).as_in_context(ctx)
    in_feats = features.shape[1]
    n_classes = data.num_labels
    n_edges = data.graph.number_of_edges()

    n_train_samples = train_mask.sum().asscalar()
    n_val_samples = val_mask.sum().asscalar()
    n_test_samples = test_mask.sum().asscalar()

    print("""----Data statistics------
      #Edges %d
      #Classes %d
      #Train samples %d
      #Val samples %d
      #Test samples %d""" %
          (n_edges, n_classes,
           n_train_samples,
           n_val_samples,
           n_test_samples))

    # create GCN model
    g = DGLGraph(data.graph, readonly=True)
    g.ndata['features'] = features

    num_neighbors = args.num_neighbors

    degs = g.in_degrees().astype('float32').as_in_context(ctx)
    norm = mx.nd.expand_dims(1. / degs, 1)
    g.ndata['norm'] = norm

    model = GCNSampling(in_feats,
                        args.n_hidden,
                        n_classes,
                        args.n_layers,
                        mx.nd.relu,
                        args.dropout,
                        prefix='GCN')
    model.initialize(ctx=ctx)
    loss_fcn = gluon.loss.SoftmaxCELoss()

    infer_model = GCNInfer(in_feats,
                           args.n_hidden,
                           n_classes,
                           args.n_layers,
                           mx.nd.relu,
                           prefix='GCN')
    infer_model.initialize(ctx=ctx)

    # use optimizer
    print(model.collect_params())
    trainer = gluon.Trainer(model.collect_params(), 'adam',
                            {'learning_rate': args.lr, 'wd': args.weight_decay},
                            kvstore=mx.kv.create('local'))

    # initialize graph
    dur = []
    for epoch in range(args.n_epochs):
        for nf, aux in dgl.contrib.sampling.NeighborSampler(g, args.batch_size,
                                                            args.num_neighbors,
                                                            neighbor_type='in',
                                                            shuffle=True,
                                                            num_hops=args.n_layers+1,
                                                            seed_nodes=train_nid):
            nf.copy_from_parent()
            # forward
            with mx.autograd.record():
                pred = model(nf)
                batch_nids = nf.layer_parent_nid(-1).astype('int64').as_in_context(ctx)
                batch_labels = labels[batch_nids]
                loss = loss_fcn(pred, batch_labels)
                loss = loss.sum() / len(batch_nids)

            loss.backward()
            trainer.step(batch_size=1)

        infer_params = infer_model.collect_params()
        for key in infer_params:
            idx = trainer._param2idx[key]
            trainer._kvstore.pull(idx, out=infer_params[key].data())

        num_acc = 0.
        for nf, aux in dgl.contrib.sampling.NeighborSampler(g, args.test_batch_size,
                                                            g.number_of_nodes(),
                                                            neighbor_type='in',
                                                            num_hops=args.n_layers+1,
                                                            seed_nodes=test_nid):
            nf.copy_from_parent()
            pred = infer_model(nf)
            batch_nids = nf.layer_parent_nid(-1).astype('int64').as_in_context(ctx)
            batch_labels = labels[batch_nids]
            num_acc += (pred.argmax(axis=1) == batch_labels).sum().asscalar()

        print("Test Accuracy {:.4f}".format(num_acc / n_test_samples))


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='GCN')
    register_data_args(parser)
    parser.add_argument("--dropout", type=float, default=0.5,
                        help="dropout probability")
    parser.add_argument("--gpu", type=int, default=-1,
                        help="gpu")
    parser.add_argument("--lr", type=float, default=3e-2,
                        help="learning rate")
    parser.add_argument("--n-epochs", type=int, default=200,
                        help="number of training epochs")
    parser.add_argument("--batch-size", type=int, default=1000,
                        help="batch size")
    parser.add_argument("--test-batch-size", type=int, default=1000,
                        help="test batch size")
    parser.add_argument("--num-neighbors", type=int, default=3,
                        help="number of neighbors to be sampled")
    parser.add_argument("--n-hidden", type=int, default=16,
                        help="number of hidden gcn units")
    parser.add_argument("--n-layers", type=int, default=1,
                        help="number of hidden gcn layers")
    parser.add_argument("--self-loop", action='store_true',
                        help="graph self-loop (default=False)")
    parser.add_argument("--weight-decay", type=float, default=5e-4,
                        help="Weight for L2 loss")
    args = parser.parse_args()

    print(args)
    main(args)
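
Finally, `graphsage_cv.py` (control variate & GraphSAGE-mean):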
import argparse, time, math
import numpy as np
import mxnet as mx
from mxnet import gluon
import dgl
import dgl.function as fn
from dgl import DGLGraph
from dgl.data import register_data_args, load_data
from dgl.graph_index import map_to_nodeflow_nid


class GraphSAGELayer(gluon.Block):
    def __init__(self,
                 in_feats,
                 hidden,
                 out_feats,
                 dropout,
                 last=False,
                 **kwargs):
        super(GraphSAGELayer, self).__init__(**kwargs)
        self.last = last
        self.dropout = dropout
        with self.name_scope():
            self.dense1 = gluon.nn.Dense(hidden, in_units=in_feats)
            self.layer_norm1 = gluon.nn.LayerNorm(in_channels=hidden)
            self.dense2 = gluon.nn.Dense(out_feats, in_units=hidden)
            if not self.last:
                self.layer_norm2 = gluon.nn.LayerNorm(in_channels=out_feats)

    def forward(self, h):
        h = self.dense1(h)
        h = self.layer_norm1(h)
        h = mx.nd.relu(h)
        if self.dropout:
            h = mx.nd.Dropout(h, p=self.dropout)
        h = self.dense2(h)
        if not self.last:
            h = self.layer_norm2(h)
            h = mx.nd.relu(h)
        return h


class NodeUpdate(gluon.Block):
    def __init__(self, layer_id, in_feats, out_feats, hidden, dropout,
                 test=False, last=False):
        super(NodeUpdate, self).__init__()
        self.layer_id = layer_id
        self.dropout = dropout
        self.test = test
        self.last = last
        with self.name_scope():
            self.layer = GraphSAGELayer(in_feats, hidden, out_feats, dropout, last)

    def forward(self, node):
        h = node.data['h']
        norm = node.data['norm']
        # activation from previous layer of myself
        self_h = node.data['self_h']

        if self.test:
            h = (h - self_h) * norm
            # graphsage
            h = mx.nd.concat(h, self_h)
        else:
            agg_history_str = 'agg_h_{}'.format(self.layer_id-1)
            agg_history = node.data[agg_history_str]
            # normalization constant
            subg_norm = node.data['subg_norm']
            # delta_h (h - history) from previous layer of myself
            self_delta_h = node.data['self_delta_h']
            # control variate
            h = (h - self_delta_h) * subg_norm + agg_history * norm
            # graphsage
            h = mx.nd.concat(h, self_h)
            if self.dropout:
                h = mx.nd.Dropout(h, p=self.dropout)

        h = self.layer(h)
        return {'activation': h}
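
# Note on NodeUpdate above: since the sampler runs with add_self_loop=True,
# the aggregated sum h includes the node's own message. (h - self_delta_h)
# removes that self term, subg_norm (1/#sampled neighbors) turns the sum of
# deltas into a mean, and agg_history * norm adds the exact mean of the stale
# histories. Concatenating self_h afterwards is the GraphSAGE-mean combine.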


class GraphSAGETrain(gluon.Block):
    def __init__(self,
                 in_feats,
                 n_hidden,
                 n_classes,
                 n_layers,
                 dropout,
                 **kwargs):
        super(GraphSAGETrain, self).__init__(**kwargs)
        self.dropout = dropout
        with self.name_scope():
            self.layers = gluon.nn.Sequential()
            # input layer
            self.input_layer = GraphSAGELayer(2*in_feats, n_hidden, n_hidden, dropout)
            # hidden layers
            for i in range(1, n_layers):
                self.layers.add(NodeUpdate(i, 2*n_hidden, n_hidden, n_hidden, dropout))
            # output layer
            self.layers.add(NodeUpdate(n_layers, 2*n_hidden, n_classes, n_hidden, dropout, last=True))

    def forward(self, nf):
        h = nf.layers[0].data['preprocess']
        features = nf.layers[0].data['features']
        h = mx.nd.concat(h, features)
        if self.dropout:
            h = mx.nd.Dropout(h, p=self.dropout)
        h = self.input_layer(h)

        for i, layer in enumerate(self.layers):
            parent_nid = dgl.utils.toindex(nf.layer_parent_nid(i+1))
            layer_nid = map_to_nodeflow_nid(nf._graph, i, parent_nid).tousertensor()
            self_h = h[layer_nid]
            # activation from previous layer of myself, used in graphSAGE
            nf.layers[i+1].data['self_h'] = self_h

            new_history = h.copy().detach()
            history_str = 'h_{}'.format(i)
            history = nf.layers[i].data[history_str]
            # delta_h used in control variate
            delta_h = h - history
            # delta_h from previous layer of the nodes in (i+1)-th layer, used in control variate
            nf.layers[i+1].data['self_delta_h'] = delta_h[layer_nid]

            nf.layers[i].data['h'] = delta_h
            nf.block_compute(i,
                             fn.copy_src(src='h', out='m'),
                             fn.sum(msg='m', out='h'),
                             layer)
            h = nf.layers[i+1].data.pop('activation')
            # update history
            if i < nf.num_layers-1:
                nf.layers[i].data[history_str] = new_history

        return h


class GraphSAGEInfer(gluon.Block):
    def __init__(self,
                 in_feats,
                 n_hidden,
                 n_classes,
                 n_layers,
                 **kwargs):
        super(GraphSAGEInfer, self).__init__(**kwargs)
        with self.name_scope():
            self.layers = gluon.nn.Sequential()
            # input layer
            self.input_layer = GraphSAGELayer(2*in_feats, n_hidden, n_hidden, 0)
            # hidden layers
            for i in range(1, n_layers):
                self.layers.add(NodeUpdate(i, 2*n_hidden, n_hidden, n_hidden, 0, True))
            # output layer
            self.layers.add(NodeUpdate(n_layers, 2*n_hidden, n_classes, n_hidden, 0, True, last=True))

    def forward(self, nf):
        h = nf.layers[0].data['preprocess']
        features = nf.layers[0].data['features']
        h = mx.nd.concat(h, features)
        h = self.input_layer(h)

        for i, layer in enumerate(self.layers):
            nf.layers[i].data['h'] = h
            parent_nid = dgl.utils.toindex(nf.layer_parent_nid(i+1))
            layer_nid = map_to_nodeflow_nid(nf._graph, i, parent_nid).tousertensor()
            # activation from previous layer of the nodes in (i+1)-th layer, used in graphSAGE
            self_h = h[layer_nid]
            nf.layers[i+1].data['self_h'] = self_h

            nf.block_compute(i,
                             fn.copy_src(src='h', out='m'),
                             fn.sum(msg='m', out='h'),
                             layer)
            h = nf.layers[i+1].data.pop('activation')

        return h


def main(args):
    # load and preprocess dataset
    data = load_data(args)

    if args.gpu >= 0:
        ctx = mx.gpu(args.gpu)
    else:
        ctx = mx.cpu()

    if args.self_loop and not args.dataset.startswith('reddit'):
        data.graph.add_edges_from([(i, i) for i in range(len(data.graph))])

    train_nid = mx.nd.array(np.nonzero(data.train_mask)[0]).astype(np.int64)
    test_nid = mx.nd.array(np.nonzero(data.test_mask)[0]).astype(np.int64)

    num_neighbors = args.num_neighbors
    n_layers = args.n_layers

    features = mx.nd.array(data.features).as_in_context(ctx)
    labels = mx.nd.array(data.labels).as_in_context(ctx)
    train_mask = mx.nd.array(data.train_mask).as_in_context(ctx)
    val_mask = mx.nd.array(data.val_mask).as_in_context(ctx)
    test_mask = mx.nd.array(data.test_mask).as_in_context(ctx)
    in_feats = features.shape[1]
    n_classes = data.num_labels
    n_edges = data.graph.number_of_edges()

    n_train_samples = train_mask.sum().asscalar()
    n_test_samples = test_mask.sum().asscalar()
    n_val_samples = val_mask.sum().asscalar()

    print("""----Data statistics------
      #Edges %d
      #Classes %d
      #Train samples %d
      #Val samples %d
      #Test samples %d""" %
          (n_edges, n_classes,
           n_train_samples,
           n_val_samples,
           n_test_samples))

    g = DGLGraph(data.graph, readonly=True)
    g.ndata['features'] = features

    norm = mx.nd.expand_dims(1. / g.in_degrees().astype('float32'), 1)
    g.ndata['norm'] = norm.as_in_context(ctx)

    degs = g.in_degrees().astype('float32').asnumpy()
    degs[degs > num_neighbors] = num_neighbors
    g.ndata['subg_norm'] = mx.nd.expand_dims(mx.nd.array(1. / degs, ctx=ctx), 1)

    g.update_all(fn.copy_src(src='features', out='m'),
                 fn.sum(msg='m', out='preprocess'),
                 lambda node: {'preprocess': node.data['preprocess'] * node.data['norm']})

    for i in range(n_layers):
        g.ndata['h_{}'.format(i)] = mx.nd.zeros((features.shape[0], args.n_hidden), ctx=ctx)

    model = GraphSAGETrain(in_feats,
                           args.n_hidden,
                           n_classes,
                           n_layers,
                           args.dropout,
                           prefix='GraphSAGE')
    model.initialize(ctx=ctx)
    loss_fcn = gluon.loss.SoftmaxCELoss()

    infer_model = GraphSAGEInfer(in_feats,
                                 args.n_hidden,
                                 n_classes,
                                 n_layers,
                                 prefix='GraphSAGE')
    infer_model.initialize(ctx=ctx)

    # use optimizer
    print(model.collect_params())
    trainer = gluon.Trainer(model.collect_params(), 'adam',
                            {'learning_rate': args.lr, 'wd': args.weight_decay},
                            kvstore=mx.kv.create('local'))

    # initialize graph
    dur = []
    for epoch in range(args.n_epochs):
        for nf, aux in dgl.contrib.sampling.NeighborSampler(g, args.batch_size,
                                                            num_neighbors,
                                                            neighbor_type='in',
                                                            shuffle=True,
                                                            num_hops=n_layers,
                                                            add_self_loop=True,
                                                            seed_nodes=train_nid):
            for i in range(n_layers):
                agg_history_str = 'agg_h_{}'.format(i)
                g.pull(nf.layer_parent_nid(i+1),
                       fn.copy_src(src='h_{}'.format(i), out='m'),
                       fn.sum(msg='m', out=agg_history_str))

            node_embed_names = [['preprocess', 'features', 'h_0']]
            for i in range(1, n_layers):
                node_embed_names.append(['h_{}'.format(i), 'agg_h_{}'.format(i-1), 'subg_norm', 'norm'])
            node_embed_names.append(['agg_h_{}'.format(n_layers-1), 'subg_norm', 'norm'])
            nf.copy_from_parent(node_embed_names=node_embed_names)

            # forward
            with mx.autograd.record():
                pred = model(nf)
                batch_nids = nf.layer_parent_nid(-1).as_in_context(ctx)
                batch_labels = labels[batch_nids]
                loss = loss_fcn(pred, batch_labels)
                loss = loss.sum() / len(batch_nids)

            loss.backward()
            trainer.step(batch_size=1)

            node_embed_names = [['h_{}'.format(i)] for i in range(n_layers)]
            node_embed_names.append([])
            nf.copy_to_parent(node_embed_names=node_embed_names)

        infer_params = infer_model.collect_params()
        for key in infer_params:
            idx = trainer._param2idx[key]
            trainer._kvstore.pull(idx, out=infer_params[key].data())

        num_acc = 0.
        for nf, aux in dgl.contrib.sampling.NeighborSampler(g, args.test_batch_size,
                                                            g.number_of_nodes(),
                                                            neighbor_type='in',
                                                            num_hops=n_layers,
                                                            seed_nodes=test_nid,
                                                            add_self_loop=True):
            node_embed_names = [['preprocess', 'features']]
            for i in range(n_layers):
                node_embed_names.append(['norm', 'subg_norm'])
            nf.copy_from_parent(node_embed_names=node_embed_names)

            pred = infer_model(nf)
            batch_nids = nf.layer_parent_nid(-1).as_in_context(ctx)
            batch_labels = labels[batch_nids]
            num_acc += (pred.argmax(axis=1) == batch_labels).sum().asscalar()

        print("Test Accuracy {:.4f}".format(num_acc / n_test_samples))


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='GraphSAGE with Control Variate')
    register_data_args(parser)
    parser.add_argument("--dropout", type=float, default=0.5,
                        help="dropout probability")
    parser.add_argument("--gpu", type=int, default=-1,
                        help="gpu")
    parser.add_argument("--lr", type=float, default=3e-2,
                        help="learning rate")
    parser.add_argument("--n-epochs", type=int, default=200,
                        help="number of training epochs")
    parser.add_argument("--batch-size", type=int, default=1000,
                        help="train batch size")
    parser.add_argument("--test-batch-size", type=int, default=1000,
                        help="test batch size")
    parser.add_argument("--num-neighbors", type=int, default=3,
                        help="number of neighbors to be sampled")
    parser.add_argument("--n-hidden", type=int, default=16,
                        help="number of hidden GraphSAGE units")
    parser.add_argument("--n-layers", type=int, default=1,
                        help="number of hidden GraphSAGE layers")
    parser.add_argument("--self-loop", action='store_true',
                        help="graph self-loop (default=False)")
    parser.add_argument("--weight-decay", type=float, default=5e-4,
                        help="Weight for L2 loss")
    args = parser.parse_args()

    print(args)
    main(args)

```diff
@@ -53,11 +53,13 @@ class SamplerOp {
  * \param edge_type the type of edges we should sample neighbors.
  * \param num_hops the number of hops to sample neighbors.
  * \param expand_factor the max number of neighbors to sample.
+ * \param add_self_loop whether to add self loop to the sampled subgraph
  * \return a NodeFlow graph.
  */
 static NodeFlow NeighborUniformSample(const ImmutableGraph *graph, IdArray seeds,
                                       const std::string &edge_type,
-                                      int num_hops, int expand_factor);
+                                      int num_hops, int expand_factor,
+                                      const bool add_self_loop);
 /*!
  * \brief Batch-generate random walk traces
```
```diff
@@ -19,7 +19,8 @@ __all__ = ['NeighborSampler']
 class NSSubgraphLoader(object):
     def __init__(self, g, batch_size, expand_factor, num_hops=1,
                  neighbor_type='in', node_prob=None, seed_nodes=None,
-                 shuffle=False, num_workers=1, return_seed_id=False):
+                 shuffle=False, num_workers=1, return_seed_id=False,
+                 add_self_loop=False):
         self._g = g
         if not g._graph.is_readonly():
             raise NotImplementedError("NodeFlow loader only support read-only graphs.")
@@ -28,6 +29,7 @@ class NSSubgraphLoader(object):
         self._num_hops = num_hops
         self._node_prob = node_prob
         self._return_seed_id = return_seed_id
+        self._add_self_loop = add_self_loop
         if self._node_prob is not None:
             assert self._node_prob.shape[0] == g.number_of_nodes(), \
                 "We need to know the sampling probability of every node"
@@ -56,7 +58,7 @@ class NSSubgraphLoader(object):
             self._nflow_idx += 1
         sgi = self._g._graph.neighbor_sampling(seed_ids, self._expand_factor,
                                                self._num_hops, self._neighbor_type,
-                                               self._node_prob)
+                                               self._node_prob, self._add_self_loop)
         nflows = [NodeFlow(self._g, i) for i in sgi]
         self._nflows.extend(nflows)
         if self._return_seed_id:
@@ -194,8 +196,8 @@ class _PrefetchingLoader(object):
 def NeighborSampler(g, batch_size, expand_factor, num_hops=1,
                     neighbor_type='in', node_prob=None, seed_nodes=None,
-                    shuffle=False, num_workers=1,
-                    return_seed_id=False, prefetch=False):
+                    shuffle=False, num_workers=1, return_seed_id=False,
+                    prefetch=False, add_self_loop=False):
     '''Create a sampler that samples neighborhood.
     This creates a NodeFlow loader that samples subgraphs from the input graph
@@ -241,6 +243,9 @@ def NeighborSampler(g, batch_size, expand_factor, num_hops=1,
         The seed Ids are in the parent graph.
     prefetch : bool, default False
         Whether to prefetch the samples in the next batch.
+    add_self_loop : bool, default False
+        Whether to add self loop to the sampled NodeFlow.
+        If True, the edge IDs of the self loop edges are -1.
     Returns
     -------
@@ -249,7 +254,7 @@ def NeighborSampler(g, batch_size, expand_factor, num_hops=1,
         information about the NodeFlows.
     '''
     loader = NSSubgraphLoader(g, batch_size, expand_factor, num_hops, neighbor_type, node_prob,
-                              seed_nodes, shuffle, num_workers, return_seed_id)
+                              seed_nodes, shuffle, num_workers, return_seed_id, add_self_loop)
     if not prefetch:
         return loader
     else:
```
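
For reference, a minimal sketch of how the new flag is used from Python (assuming a read-only `DGLGraph` `g` and seed node IDs `train_nid` as in the training scripts above; the batch size and fanout are illustrative):

```python
import dgl

for nf, aux in dgl.contrib.sampling.NeighborSampler(g, batch_size=1000,
                                                    expand_factor=2,
                                                    neighbor_type='in',
                                                    num_hops=2,
                                                    add_self_loop=True,
                                                    seed_nodes=train_nid):
    # Each destination node now also receives its own message. Self-loop
    # edges carry parent edge ID -1, so APIs that map NodeFlow edges back
    # to parent edges (e.g. block_parent_eid) must not be used on such flows.
    nf.copy_from_parent()
```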
```diff
@@ -674,7 +674,8 @@ class GraphIndex(object):
         shuffle_idx = utils.toindex(shuffle_idx) if shuffle_idx is not None else None
         return inc, shuffle_idx
-    def neighbor_sampling(self, seed_ids, expand_factor, num_hops, neighbor_type, node_prob):
+    def neighbor_sampling(self, seed_ids, expand_factor, num_hops, neighbor_type,
+                          node_prob, add_self_loop=False):
         """Neighborhood sampling"""
         if len(seed_ids) == 0:
             return []
@@ -682,7 +683,8 @@ class GraphIndex(object):
         seed_ids = [v.todgltensor() for v in seed_ids]
         num_subgs = len(seed_ids)
         if node_prob is None:
-            rst = _uniform_sampling(self, seed_ids, neighbor_type, num_hops, expand_factor)
+            rst = _uniform_sampling(self, seed_ids, neighbor_type, num_hops,
+                                    expand_factor, add_self_loop)
         else:
             rst = _nonuniform_sampling(self, node_prob, seed_ids, neighbor_type, num_hops,
                                        expand_factor)
@@ -993,7 +995,7 @@ def map_to_nodeflow_nid(nflow, layer_id, parent_nids):
         The graph index of a NodeFlow.
     layer_id : int
-        The layer Id
+        The layer Id.
     parent_nids: utils.Index
         Node Ids in the parent graph.
@@ -1138,7 +1140,7 @@ _NEIGHBOR_SAMPLING_APIS = {
 _EMPTY_ARRAYS = [utils.toindex(F.ones(shape=(0), dtype=F.int64, ctx=F.cpu()))]
-def _uniform_sampling(gidx, seed_ids, neigh_type, num_hops, expand_factor):
+def _uniform_sampling(gidx, seed_ids, neigh_type, num_hops, expand_factor, add_self_loop):
     num_seeds = len(seed_ids)
     empty_ids = []
     if len(seed_ids) > 1 and len(seed_ids) not in _NEIGHBOR_SAMPLING_APIS.keys():
@@ -1147,4 +1149,5 @@ def _uniform_sampling(gidx, seed_ids, neigh_type, num_hops, expand_factor):
     seed_ids.extend([empty.todgltensor() for empty in empty_ids])
     assert len(seed_ids) in _NEIGHBOR_SAMPLING_APIS.keys()
     return _NEIGHBOR_SAMPLING_APIS[len(seed_ids)](gidx._handle, *seed_ids, neigh_type,
-                                                  num_hops, expand_factor, num_seeds)
+                                                  num_hops, expand_factor, num_seeds,
+                                                  add_self_loop)
```
```diff
@@ -354,7 +354,11 @@ class NodeFlow(DGLBaseGraph):
         block_id = self._get_block_id(block_id)
         start = self._block_offsets[block_id]
         end = self._block_offsets[block_id + 1]
-        return self._edge_mapping.tousertensor()[start:end]
+        ret = self._edge_mapping.tousertensor()[start:end]
+        # If `add_self_loop` is enabled, the returned parent eid can be -1.
+        # We have to make sure this case doesn't happen.
+        assert F.asnumpy(F.sum(ret == -1, 0)) == 0, "The eid in the parent graph is invalid."
+        return ret
     def set_n_initializer(self, initializer, layer_id=ALL, field=None):
         """Set the initializer for empty node features.
@@ -703,7 +707,7 @@ class NodeFlow(DGLBaseGraph):
                              inplace=inplace)
-def create_full_node_flow(g, num_layers):
+def create_full_node_flow(g, num_layers, add_self_loop=False):
     """Convert a full graph to NodeFlow to run a L-layer GNN model.
     Parameters
@@ -712,6 +716,9 @@ def create_full_node_flow(g, num_layers):
         a DGL graph
     num_layers : int
         The number of layers
+    add_self_loop : bool, default False
+        Whether to add self loop to the sampled NodeFlow.
+        If True, the edge IDs of the self loop edges are -1.
     Returns
     -------
@@ -719,5 +726,6 @@ def create_full_node_flow(g, num_layers):
         a NodeFlow with a specified number of layers.
     """
     seeds = [utils.toindex(F.arange(0, g.number_of_nodes()))]
-    nfi = g._graph.neighbor_sampling(seeds, g.number_of_nodes(), num_layers, 'in', None)
+    nfi = g._graph.neighbor_sampling(seeds, g.number_of_nodes(), num_layers,
+                                     'in', None, add_self_loop)
     return NodeFlow(g, nfi[0])
```
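
A sketch of the updated helper in use (assuming a read-only `DGLGraph` `g`; `num_layers` matches the model depth):

```python
from dgl.node_flow import create_full_node_flow

# Every node is a seed and every neighborhood is taken in full, so running
# a sampled model on this NodeFlow reproduces full-graph computation.
nf = create_full_node_flow(g, num_layers=2, add_self_loop=True)
```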
```diff
@@ -443,6 +443,7 @@ void CAPI_NeighborUniformSample(DGLArgs args, DGLRetValue* rv) {
   const int num_hops = args[num_seeds + 2];
   const int num_neighbors = args[num_seeds + 3];
   const int num_valid_seeds = args[num_seeds + 4];
+  const bool add_self_loop = args[num_seeds + 5];
   const GraphInterface *ptr = static_cast<const GraphInterface *>(ghandle);
   const ImmutableGraph *gptr = dynamic_cast<const ImmutableGraph*>(ptr);
   CHECK(gptr) << "sampling isn't implemented in mutable graph";
@@ -450,8 +451,8 @@ void CAPI_NeighborUniformSample(DGLArgs args, DGLRetValue* rv) {
   std::vector<NodeFlow> subgs(seeds.size());
 #pragma omp parallel for
   for (int i = 0; i < num_valid_seeds; i++) {
-    subgs[i] = SamplerOp::NeighborUniformSample(gptr, seeds[i],
-                                                neigh_type, num_hops, num_neighbors);
+    subgs[i] = SamplerOp::NeighborUniformSample(gptr, seeds[i], neigh_type, num_hops,
+                                                num_neighbors, add_self_loop);
   }
   *rv = ConvertSubgraphToPackedFunc(subgs);
 }
```
```diff
@@ -376,7 +376,8 @@ NodeFlow SampleSubgraph(const ImmutableGraph *graph,
                         const float* probability,
                         const std::string &edge_type,
                         int num_hops,
-                        size_t num_neighbor) {
+                        size_t num_neighbor,
+                        const bool add_self_loop) {
   unsigned int time_seed = time(nullptr);
   size_t num_seeds = seed_arr->shape[0];
   auto orig_csr = edge_type == "in" ? graph->GetInCSR() : graph->GetOutCSR();
@@ -440,6 +441,10 @@ NodeFlow SampleSubgraph(const ImmutableGraph *graph,
                           &tmp_sampled_edge_list,
                           &time_seed);
     }
+    if (add_self_loop) {
+      tmp_sampled_src_list.push_back(dst_id);
+      tmp_sampled_edge_list.push_back(-1);
+    }
     CHECK_EQ(tmp_sampled_src_list.size(), tmp_sampled_edge_list.size());
     neigh_pos.emplace_back(dst_id, neighbor_list.size(), tmp_sampled_src_list.size());
     // Then push the vertices
@@ -474,13 +479,15 @@ NodeFlow SampleSubgraph(const ImmutableGraph *graph,
 NodeFlow SamplerOp::NeighborUniformSample(const ImmutableGraph *graph, IdArray seeds,
                                           const std::string &edge_type,
-                                          int num_hops, int expand_factor) {
+                                          int num_hops, int expand_factor,
+                                          const bool add_self_loop) {
   return SampleSubgraph(graph,
                         seeds,      // seed vector
                         nullptr,    // sample_id_probability
                         edge_type,
                         num_hops + 1,
-                        expand_factor);
+                        expand_factor,
+                        add_self_loop);
 }
 IdArray SamplerOp::RandomWalk(
```
```diff
@@ -6,23 +6,44 @@ from dgl.node_flow import create_full_node_flow
 from dgl import utils
 import dgl.function as fn
 from functools import partial
+import itertools
 
-def generate_rand_graph(n, connect_more=False):
-    arr = (sp.sparse.random(n, n, density=0.1, format='coo') != 0).astype(np.int64)
-    # having one node to connect to all other nodes.
-    if connect_more:
-        arr[0] = 1
-        arr[:,0] = 1
+def generate_rand_graph(n, connect_more=False, complete=False):
+    if complete:
+        cord = [(i,j) for i, j in itertools.product(range(n), range(n)) if i != j]
+        row = [t[0] for t in cord]
+        col = [t[1] for t in cord]
+        data = np.ones((len(row),))
+        arr = sp.sparse.coo_matrix((data, (row, col)), shape=(n, n))
+    else:
+        arr = (sp.sparse.random(n, n, density=0.1, format='coo') != 0).astype(np.int64)
+        # having one node to connect to all other nodes.
+        if connect_more:
+            arr[0] = 1
+            arr[:,0] = 1
     g = dgl.DGLGraph(arr, readonly=True)
     g.ndata['h1'] = F.randn((g.number_of_nodes(), 10))
     g.edata['h2'] = F.randn((g.number_of_edges(), 3))
     return g
 
-def create_mini_batch(g, num_hops):
+def test_self_loop():
+    n = 100
+    num_hops = 2
+    g = generate_rand_graph(n, complete=True)
+    nf = create_mini_batch(g, num_hops, add_self_loop=True)
+    for i in range(1, nf.num_layers):
+        in_deg = nf.layer_in_degree(i)
+        deg = F.ones(in_deg.shape, dtype=F.int64) * n
+        assert F.array_equal(in_deg, deg)
+
+def create_mini_batch(g, num_hops, add_self_loop=False):
     seed_ids = np.array([0, 1, 2, 3])
     seed_ids = utils.toindex(seed_ids)
-    sgi = g._graph.neighbor_sampling([seed_ids], g.number_of_nodes(), num_hops, "in", None)
+    sgi = g._graph.neighbor_sampling([seed_ids], g.number_of_nodes(), num_hops,
+                                     "in", None, add_self_loop)
     assert len(sgi) == 1
     return dgl.node_flow.NodeFlow(g, sgi[0])
@@ -239,3 +260,4 @@ if __name__ == '__main__':
     test_apply_edges()
     test_flow_compute()
     test_prop_flows()
+    test_self_loop()
```