Commit da3ab84c authored by Chao Ma, committed by Da Zheng

Add demo for distributed sampler (#474)

* add C++ rpc infrastructure and distributed sampler

* update

* update lint

* update lint

* update lint

* update

* update

* update

* update

* update

* update

* update

* update serialize and unittest

* update serialize

* lint

* update

* update

* update

* update

* update

* update

* update unittest

* put Finalize() to __del__

* update unittest

* update

* delete buffer in Finalize

* update unittest

* update unittest

* update unittest

* update unittest

* update

* update

* fix small bug

* windows socket impl

* update API

* fix bug in serialize

* fix bug in serialize

* set parent graph

* update

* update

* update

* update

* update

* update

* fix lint

* fix lint

* fix

* fix windows compilation error

* fix windows error

* change API to lower-case

* update test

* fix typo

* update

* add SamplerPool

* add SamplerPool

* update

* update test

* update

* update

* update

* update

* add example

* update

* update

* add distributed sampler demo

* add index

* update demo of distributed sampler

* fix lower-case

* print subg index

* update README.md

* update

* remove --gpu args
parent 9b4fb2fb
### Demo for Distributed Sampler
First, change the `--ip` and `--port` arguments in `run_trainer.sh` and `run_sampler.sh` to match your own environment.
Then start the trainer node:
```
./run_trainer.sh
```
When you see the message:
```
[04:48:20] .../socket_communicator.cc:68: Bind to 127.0.0.1:2049
[04:48:20] .../socket_communicator.cc:74: Listen on 127.0.0.1:2049, wait sender connect ...
```
you can start the sampler:
```
./run_sampler.sh
```
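
The two scripts below are built around a `SamplerSender`/`SamplerReceiver` pair. Here is a minimal sketch of that flow, using only calls that appear in the demo code; the toy graph and single-sender setup are placeholder assumptions, and the two halves run in separate processes (trainer first, since the receiver waits for senders to connect):
```
import networkx as nx
import dgl
from dgl import DGLGraph

# Both sides must agree on the same parent graph.
g = DGLGraph(nx.path_graph(10), readonly=True)

# --- trainer process (start this first) ---
receiver = dgl.contrib.sampling.SamplerReceiver(ip='127.0.0.1', port=2049,
                                                num_sender=1)
nf = receiver.recv(g)  # blocks until a NodeFlow arrives from a sampler

# --- sampler process (start once the trainer is listening) ---
sender = dgl.contrib.sampling.SamplerSender(ip='127.0.0.1', port=2049)
for nf in dgl.contrib.sampling.NeighborSampler(g, 2, 2, neighbor_type='in',
                                               num_hops=2):
    sender.send(nf)
```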
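# ---- gcn_ns_sampler.py: sampler side of the demo (filename inferred from the launch commands below) ----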
import argparse, time, math
import numpy as np
import mxnet as mx
from mxnet import gluon
from functools import partial
import dgl
import dgl.function as fn
from dgl import DGLGraph
from dgl.data import register_data_args, load_data
from dgl.contrib.sampling import SamplerPool
class MySamplerPool(SamplerPool):
def worker(self, args):
"""User-defined worker function
"""
# Start sender
sender = dgl.contrib.sampling.SamplerSender(ip=args.ip, port=args.port)
# load and preprocess dataset
data = load_data(args)
ctx = mx.cpu()
if args.self_loop and not args.dataset.startswith('reddit'):
data.graph.add_edges_from([(i,i) for i in range(len(data.graph))])
train_nid = mx.nd.array(np.nonzero(data.train_mask)[0]).astype(np.int64).as_in_context(ctx)
test_nid = mx.nd.array(np.nonzero(data.test_mask)[0]).astype(np.int64).as_in_context(ctx)
# create a read-only graph for sampling
g = DGLGraph(data.graph, readonly=True)
for epoch in range(args.n_epochs):
# Here we only send NodeFlows for training
idx = 0
for nf in dgl.contrib.sampling.NeighborSampler(g, args.batch_size,
args.num_neighbors,
neighbor_type='in',
shuffle=True,
num_hops=args.n_layers+1,
seed_nodes=train_nid):
print("send train nodeflow: %d" %(idx))
sender.send(nf)
idx += 1
def main(args):
pool = MySamplerPool()
pool.start(args.num_sender, args)
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='GCN')
register_data_args(parser)
parser.add_argument("--n-epochs", type=int, default=200,
help="number of training epochs")
parser.add_argument("--batch-size", type=int, default=1000,
help="batch size")
parser.add_argument("--test-batch-size", type=int, default=1000,
help="test batch size")
parser.add_argument("--num-neighbors", type=int, default=3,
help="number of neighbors to be sampled")
parser.add_argument("--self-loop", action='store_true',
help="graph self-loop (default=False)")
parser.add_argument("--n-layers", type=int, default=1,
help="number of hidden gcn layers")
parser.add_argument("--ip", type=str, default='127.0.0.1',
help="ip address of remote trainer machine")
parser.add_argument("--port", type=int, default=2049,
help="listen port of remote trainer machine")
parser.add_argument("--num-sender", type=int, default=1,
help="total number of sampler sender")
args = parser.parse_args()
print(args)
main(args)
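# ---- gcn_trainer.py: trainer side of the demo (filename inferred from the launch commands below) ----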
import argparse, time, math
import numpy as np
import mxnet as mx
from mxnet import gluon
from functools import partial
import dgl
import dgl.function as fn
from dgl import DGLGraph
from dgl.data import register_data_args, load_data
class NodeUpdate(gluon.Block):
def __init__(self, in_feats, out_feats, activation=None, test=False, concat=False):
super(NodeUpdate, self).__init__()
self.dense = gluon.nn.Dense(out_feats, in_units=in_feats)
self.activation = activation
self.concat = concat
self.test = test
def forward(self, node):
h = node.data['h']
if self.test:
h = h * node.data['norm']
h = self.dense(h)
# skip connection
if self.concat:
h = mx.nd.concat(h, self.activation(h))
elif self.activation:
h = self.activation(h)
return {'activation': h}
class GCNSampling(gluon.Block):
def __init__(self,
in_feats,
n_hidden,
n_classes,
n_layers,
activation,
dropout,
**kwargs):
super(GCNSampling, self).__init__(**kwargs)
self.dropout = dropout
self.n_layers = n_layers
with self.name_scope():
self.layers = gluon.nn.Sequential()
# input layer
skip_start = (0 == n_layers-1)
self.layers.add(NodeUpdate(in_feats, n_hidden, activation, concat=skip_start))
# hidden layers
for i in range(1, n_layers):
skip_start = (i == n_layers-1)
self.layers.add(NodeUpdate(n_hidden, n_hidden, activation, concat=skip_start))
# output layer
self.layers.add(NodeUpdate(2*n_hidden, n_classes))
def forward(self, nf):
nf.layers[0].data['activation'] = nf.layers[0].data['features']
for i, layer in enumerate(self.layers):
h = nf.layers[i].data.pop('activation')
if self.dropout:
h = mx.nd.Dropout(h, p=self.dropout)
nf.layers[i].data['h'] = h
nf.block_compute(i,
fn.copy_src(src='h', out='m'),
lambda node : {'h': node.mailbox['m'].mean(axis=1)},
layer)
h = nf.layers[-1].data.pop('activation')
return h
class GCNInfer(gluon.Block):
def __init__(self,
in_feats,
n_hidden,
n_classes,
n_layers,
activation,
**kwargs):
super(GCNInfer, self).__init__(**kwargs)
self.n_layers = n_layers
with self.name_scope():
self.layers = gluon.nn.Sequential()
# input layer
skip_start = (0 == n_layers-1)
self.layers.add(NodeUpdate(in_feats, n_hidden, activation, test=True, concat=skip_start))
# hidden layers
for i in range(1, n_layers):
skip_start = (i == n_layers-1)
self.layers.add(NodeUpdate(n_hidden, n_hidden, activation, test=True, concat=skip_start))
# output layer
self.layers.add(NodeUpdate(2*n_hidden, n_classes, test=True))
def forward(self, nf):
nf.layers[0].data['activation'] = nf.layers[0].data['features']
for i, layer in enumerate(self.layers):
h = nf.layers[i].data.pop('activation')
nf.layers[i].data['h'] = h
nf.block_compute(i,
fn.copy_src(src='h', out='m'),
fn.sum(msg='m', out='h'),
layer)
h = nf.layers[-1].data.pop('activation')
return h
def main(args):
# load and preprocess dataset
data = load_data(args)
if args.gpu >= 0:
ctx = mx.gpu(args.gpu)
else:
ctx = mx.cpu()
if args.self_loop and not args.dataset.startswith('reddit'):
data.graph.add_edges_from([(i,i) for i in range(len(data.graph))])
# Create sampler receiver
receiver = dgl.contrib.sampling.SamplerReceiver(ip=args.ip, port=args.port, num_sender=args.num_sender)
train_nid = mx.nd.array(np.nonzero(data.train_mask)[0]).astype(np.int64).as_in_context(ctx)
test_nid = mx.nd.array(np.nonzero(data.test_mask)[0]).astype(np.int64).as_in_context(ctx)
features = mx.nd.array(data.features).as_in_context(ctx)
labels = mx.nd.array(data.labels).as_in_context(ctx)
train_mask = mx.nd.array(data.train_mask).as_in_context(ctx)
val_mask = mx.nd.array(data.val_mask).as_in_context(ctx)
test_mask = mx.nd.array(data.test_mask).as_in_context(ctx)
in_feats = features.shape[1]
n_classes = data.num_labels
n_edges = data.graph.number_of_edges()
n_train_samples = train_mask.sum().asscalar()
n_val_samples = val_mask.sum().asscalar()
n_test_samples = test_mask.sum().asscalar()
print("""----Data statistics------'
#Edges %d
#Classes %d
#Train samples %d
#Val samples %d
#Test samples %d""" %
(n_edges, n_classes,
n_train_samples,
n_val_samples,
n_test_samples))
# create GCN model
g = DGLGraph(data.graph, readonly=True)
g.ndata['features'] = features
num_neighbors = args.num_neighbors
degs = g.in_degrees().astype('float32').as_in_context(ctx)
norm = mx.nd.expand_dims(1./degs, 1)
g.ndata['norm'] = norm
model = GCNSampling(in_feats,
args.n_hidden,
n_classes,
args.n_layers,
mx.nd.relu,
args.dropout,
prefix='GCN')
model.initialize(ctx=ctx)
loss_fcn = gluon.loss.SoftmaxCELoss()
infer_model = GCNInfer(in_feats,
args.n_hidden,
n_classes,
args.n_layers,
mx.nd.relu,
prefix='GCN')
infer_model.initialize(ctx=ctx)
# use optimizer
print(model.collect_params())
trainer = gluon.Trainer(model.collect_params(), 'adam',
{'learning_rate': args.lr, 'wd': args.weight_decay},
kvstore=mx.kv.create('local'))
# initialize graph
dur = []
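    # Expected number of NodeFlows per epoch; hard-coded here to match this demo's dataset and batch size.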
total_count = 153
for epoch in range(args.n_epochs):
for subg_count in range(total_count):
print(subg_count)
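            # Receive a NodeFlow from a remote sampler, reconstructed against the parent graph g.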
nf = receiver.recv(g)
nf.copy_from_parent()
# forward
with mx.autograd.record():
pred = model(nf)
batch_nids = nf.layer_parent_nid(-1).astype('int64').as_in_context(ctx)
batch_labels = labels[batch_nids]
loss = loss_fcn(pred, batch_labels)
loss = loss.sum() / len(batch_nids)
loss.backward()
trainer.step(batch_size=1)
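        # Pull the freshly updated parameters from the kvstore into the inference model.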
infer_params = infer_model.collect_params()
for key in infer_params:
idx = trainer._param2idx[key]
trainer._kvstore.pull(idx, out=infer_params[key].data())
num_acc = 0.
for nf in dgl.contrib.sampling.NeighborSampler(g, args.test_batch_size,
g.number_of_nodes(),
neighbor_type='in',
num_hops=args.n_layers+1,
seed_nodes=test_nid):
nf.copy_from_parent()
pred = infer_model(nf)
batch_nids = nf.layer_parent_nid(-1).astype('int64').as_in_context(ctx)
batch_labels = labels[batch_nids]
num_acc += (pred.argmax(axis=1) == batch_labels).sum().asscalar()
print("Test Accuracy {:.4f}". format(num_acc/n_test_samples))
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='GCN')
register_data_args(parser)
parser.add_argument("--dropout", type=float, default=0.5,
help="dropout probability")
parser.add_argument("--gpu", type=int, default=-1,
help="gpu")
parser.add_argument("--lr", type=float, default=3e-2,
help="learning rate")
parser.add_argument("--n-epochs", type=int, default=200,
help="number of training epochs")
parser.add_argument("--batch-size", type=int, default=1000,
help="batch size")
parser.add_argument("--test-batch-size", type=int, default=1000,
help="test batch size")
parser.add_argument("--num-neighbors", type=int, default=3,
help="number of neighbors to be sampled")
parser.add_argument("--n-hidden", type=int, default=16,
help="number of hidden gcn units")
parser.add_argument("--n-layers", type=int, default=1,
help="number of hidden gcn layers")
parser.add_argument("--self-loop", action='store_true',
help="graph self-loop (default=False)")
parser.add_argument("--weight-decay", type=float, default=5e-4,
help="Weight for L2 loss")
parser.add_argument("--ip", type=str, default='127.0.0.1',
help="IP address of sampler receiver machine")
parser.add_argument("--port", type=int, default=2049,
help="Listening port of sampler receiver machine")
parser.add_argument("--num-sender", type=int, default=1,
help="Number of sampler sender machine")
args = parser.parse_args()
print(args)
main(args)
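# run_sampler.sh: launch the sampler, which starts 5 sender processes (num-sender must match the trainer).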
DGLBACKEND=mxnet python3 gcn_ns_sampler.py --ip 127.0.0.1 --port 2049 --num-sender=5 --dataset reddit-self-loop --num-neighbors 2 --batch-size 1000 --test-batch-size 500
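# run_trainer.sh: launch the trainer, which waits for NodeFlows from the 5 senders.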
DGLBACKEND=mxnet python3 gcn_trainer.py --ip 127.0.0.1 --port 2049 --num-sender=5 --dataset reddit-self-loop --num-neighbors 2 --batch-size 1000 --test-batch-size 500 --n-hidden 64
@@ -189,12 +189,15 @@ def main(args):
     # initialize graph
     dur = []
     for epoch in range(args.n_epochs):
+        index = 0
         for nf in dgl.contrib.sampling.NeighborSampler(g, args.batch_size,
                                                        args.num_neighbors,
                                                        neighbor_type='in',
                                                        shuffle=True,
                                                        num_hops=args.n_layers+1,
                                                        seed_nodes=train_nid):
+            print(index)
+            index = index + 1
             nf.copy_from_parent()
             # forward
             with mx.autograd.record():
@@ -21,27 +21,33 @@ class SamplerPool(object):
             # Do anything here #

         if __name__ == '__main__':
             ...
             args = parser.parse_args()
             pool = MySamplerPool()
-            pool.start(5) # Start 5 processes
-
-    Parameters
-    ----------
-    num_worker : int
-        number of worker (child process)
+            pool.start(args.num_sender, args)
     """
     __metaclass__ = ABCMeta

-    def start(self, num_worker):
+    def start(self, num_worker, args):
         """Start sampler pool

         Parameters
         ----------
         num_worker : int
             number of worker (number of child process)
+        args : arguments
+            arguments passed by user
         """
         p = Pool()
         for i in range(num_worker):
             print("Start child process %d ..." % i)
-            p.apply_async(self.worker)
+            p.apply_async(self.worker, args=(args,))

         # Waiting for all subprocesses done ...
         p.close()
         p.join()

     @abstractmethod
-    def worker(self):
+    def worker(self, args):
         pass

 class SamplerSender(object):
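
The diff above extends `SamplerPool.start` to forward a user-supplied `args` object to each `worker` child process via `multiprocessing.Pool.apply_async`. A minimal sketch of a subclass under that API (the `PrintPool` name and toy args are placeholders):
```
from dgl.contrib.sampling import SamplerPool

class PrintPool(SamplerPool):
    """Toy pool: each child process just prints the args it receives."""
    def worker(self, args):
        print("worker got:", args)

if __name__ == '__main__':
    pool = PrintPool()
    pool.start(3, {'ip': '127.0.0.1', 'port': 2049})  # spawn 3 child processes
```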