Unverified commit 3a1392e6, authored by Da Zheng, committed by GitHub

[Model] add multiprocessing training with sampling. (#484)

* reorganize sampling code.

* add multi-process training.

* speed up gcn_cv

* fix graphsage_cv.

* add new API in graph store.

* update barrier impl.

* support both local and distributed training.

* fix multiprocess train.

* fix.

* fix barrier.

* add script for loading data.

* multiprocessing sampling.

* accel training.

* replace pull with spmv for speedup.

* nodeflow copy from parent with context.

* enable GPU.

* fix a bug in graph store.

* enable multi-GPU training.

* fix lint.

* add comments.

* rename to run_store_server.py

* fix gcn_cv.

* fix a minor bug in sampler.

* handle error better in graph store.

* improve graphsage_cv for distributed mode.

* update README.

* fix.

* update.
parent e8951915
@@ -64,3 +64,14 @@ reddit: test accuracy 96.1% with `--num-neighbors 1` and `--batch-size 1000`, ~9
 DGLBACKEND=mxnet python examples/mxnet/sampling/train.py --model graphsage_cv --batch-size 1000 --test-batch-size 5000 --n-epochs 50 --dataset reddit --num-neighbors 1 --n-hidden 128 --dropout 0.2 --weight-decay 0
 ```
### Run multi-processing training
Start the graph store server, which loads the reddit dataset and serves four workers.
```
python3 examples/mxnet/sampling/run_store_server.py --dataset reddit --num-workers 4
```
Then launch four workers to train GraphSAGE on the reddit dataset.
```
python3 ../incubator-mxnet/tools/launch.py -n 4 -s 1 --launcher local python3 examples/mxnet/sampling/multi_process_train.py --model graphsage_cv --batch-size 1000 --test-batch-size 5000 --n-epochs 1 --graph-name reddit --num-neighbors 1 --n-hidden 128 --dropout 0.2 --weight-decay 0
```
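Each trainer picks its MXNet context from its worker id, so GPU training only requires the `--num-gpus` flag defined in `multi_process_train.py`. A sketch of the same launch on a machine that (hypothetically) has four GPUs:
```
python3 ../incubator-mxnet/tools/launch.py -n 4 -s 1 --launcher local python3 examples/mxnet/sampling/multi_process_train.py --model graphsage_cv --batch-size 1000 --test-batch-size 5000 --n-epochs 1 --graph-name reddit --num-neighbors 1 --n-hidden 128 --dropout 0.2 --weight-decay 0 --num-gpus 4
```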
@@ -139,25 +139,37 @@ class GCNInfer(gluon.Block):
         return h

-def gcn_cv_train(g, ctx, args, n_classes, train_nid, test_nid, n_test_samples):
+def gcn_cv_train(g, ctx, args, n_classes, train_nid, test_nid, n_test_samples, distributed):
     features = g.ndata['features']
     labels = g.ndata['labels']
     in_feats = features.shape[1]
+    g_ctx = features.context

     norm = mx.nd.expand_dims(1./g.in_degrees().astype('float32'), 1)
-    g.ndata['norm'] = norm.as_in_context(ctx)
+    g.ndata['norm'] = norm.as_in_context(g_ctx)
     degs = g.in_degrees().astype('float32').asnumpy()
     degs[degs > args.num_neighbors] = args.num_neighbors
-    g.ndata['subg_norm'] = mx.nd.expand_dims(mx.nd.array(1./degs, ctx=ctx), 1)
-    g.update_all(fn.copy_src(src='features', out='m'),
-                 fn.sum(msg='m', out='preprocess'),
-                 lambda node : {'preprocess': node.data['preprocess'] * node.data['norm']})
+    g.ndata['subg_norm'] = mx.nd.expand_dims(mx.nd.array(1./degs, ctx=g_ctx), 1)
     n_layers = args.n_layers
-    for i in range(n_layers):
-        g.ndata['h_{}'.format(i)] = mx.nd.zeros((features.shape[0], args.n_hidden), ctx=ctx)
-    g.ndata['h_{}'.format(n_layers-1)] = mx.nd.zeros((features.shape[0], 2*args.n_hidden), ctx=ctx)
+
+    if distributed:
+        g.dist_update_all(fn.copy_src(src='features', out='m'),
+                          fn.sum(msg='m', out='preprocess'),
+                          lambda node : {'preprocess': node.data['preprocess'] * node.data['norm']})
+        for i in range(n_layers - 1):
+            g.init_ndata('h_{}'.format(i), (features.shape[0], args.n_hidden), 'float32')
+            g.init_ndata('agg_h_{}'.format(i), (features.shape[0], args.n_hidden), 'float32')
+        g.init_ndata('h_{}'.format(n_layers-1), (features.shape[0], 2*args.n_hidden), 'float32')
+        g.init_ndata('agg_h_{}'.format(n_layers-1), (features.shape[0], 2*args.n_hidden), 'float32')
+    else:
+        g.update_all(fn.copy_src(src='features', out='m'),
+                     fn.sum(msg='m', out='preprocess'),
+                     lambda node : {'preprocess': node.data['preprocess'] * node.data['norm']})
+        for i in range(n_layers):
+            g.ndata['h_{}'.format(i)] = mx.nd.zeros((features.shape[0], args.n_hidden), ctx=g_ctx)
+            g.ndata['agg_h_{}'.format(i)] = mx.nd.zeros((features.shape[0], args.n_hidden), ctx=g_ctx)
+        g.ndata['h_{}'.format(n_layers-1)] = mx.nd.zeros((features.shape[0], 2*args.n_hidden), ctx=g_ctx)
+        g.ndata['agg_h_{}'.format(n_layers-1)] = mx.nd.zeros((features.shape[0], 2*args.n_hidden), ctx=g_ctx)

     model = GCNSampling(in_feats,
                         args.n_hidden,
@@ -182,12 +194,14 @@ def gcn_cv_train(g, ctx, args, n_classes, train_nid, test_nid, n_test_samples):
     # use optimizer
     print(model.collect_params())
+    kv_type = 'dist_sync' if distributed else 'local'
     trainer = gluon.Trainer(model.collect_params(), 'adam',
                             {'learning_rate': args.lr, 'wd': args.weight_decay},
-                            kvstore=mx.kv.create('local'))
+                            kvstore=mx.kv.create(kv_type))

     # initialize graph
     dur = []
+    adj = g.adjacency_matrix().as_in_context(g_ctx)
     for epoch in range(args.n_epochs):
         for nf in dgl.contrib.sampling.NeighborSampler(g, args.batch_size,
                                                        args.num_neighbors,
@@ -198,20 +212,23 @@ def gcn_cv_train(g, ctx, args, n_classes, train_nid, test_nid, n_test_samples):
                                                        seed_nodes=train_nid):
             for i in range(n_layers):
                 agg_history_str = 'agg_h_{}'.format(i)
-                g.pull(nf.layer_parent_nid(i+1), fn.copy_src(src='h_{}'.format(i), out='m'),
-                       fn.sum(msg='m', out=agg_history_str))
+                dests = nf.layer_parent_nid(i+1).as_in_context(g_ctx)
+                # TODO we could use DGLGraph.pull to implement this, but the current
+                # implementation of pull is very slow. Let's manually do it for now.
+                g.ndata[agg_history_str][dests] = mx.nd.dot(mx.nd.take(adj, dests),
+                                                            g.ndata['h_{}'.format(i)])

             node_embed_names = [['preprocess', 'h_0']]
             for i in range(1, n_layers):
                 node_embed_names.append(['h_{}'.format(i), 'agg_h_{}'.format(i-1), 'subg_norm', 'norm'])
             node_embed_names.append(['agg_h_{}'.format(n_layers-1), 'subg_norm', 'norm'])
-            nf.copy_from_parent(node_embed_names=node_embed_names)
+            nf.copy_from_parent(node_embed_names=node_embed_names, ctx=ctx)

             # forward
             with mx.autograd.record():
                 pred = model(nf)
-                batch_nids = nf.layer_parent_nid(-1).as_in_context(ctx)
-                batch_labels = labels[batch_nids]
+                batch_nids = nf.layer_parent_nid(-1)
+                batch_labels = labels[batch_nids].as_in_context(ctx)
                 loss = loss_fcn(pred, batch_labels)
                 loss = loss.sum() / len(batch_nids)
@@ -241,14 +258,12 @@ def gcn_cv_train(g, ctx, args, n_classes, train_nid, test_nid, n_test_samples):
         for i in range(n_layers):
             node_embed_names.append(['norm'])
-        nf.copy_from_parent(node_embed_names=node_embed_names)
+        nf.copy_from_parent(node_embed_names=node_embed_names, ctx=ctx)

         pred = infer_model(nf)
-        batch_nids = nf.layer_parent_nid(-1).as_in_context(ctx)
-        batch_labels = labels[batch_nids]
+        batch_nids = nf.layer_parent_nid(-1)
+        batch_labels = labels[batch_nids].as_in_context(ctx)
         num_acc += (pred.argmax(axis=1) == batch_labels).sum().asscalar()
         num_tests += nf.layer_size(-1)
         break

     print("Test Accuracy {:.4f}". format(num_acc/num_tests))
@@ -112,8 +112,9 @@ class GCNInfer(gluon.Block):
 def gcn_ns_train(g, ctx, args, n_classes, train_nid, test_nid, n_test_samples):
     in_feats = g.ndata['features'].shape[1]
     labels = g.ndata['labels']
+    g_ctx = labels.context

-    degs = g.in_degrees().astype('float32').as_in_context(ctx)
+    degs = g.in_degrees().astype('float32').as_in_context(g_ctx)
     norm = mx.nd.expand_dims(1./degs, 1)
     g.ndata['norm'] = norm
@@ -153,12 +154,12 @@ def gcn_ns_train(g, ctx, args, n_classes, train_nid, test_nid, n_test_samples):
                                                        num_workers=32,
                                                        num_hops=args.n_layers+1,
                                                        seed_nodes=train_nid):
-            nf.copy_from_parent()
+            nf.copy_from_parent(ctx=ctx)
             # forward
             with mx.autograd.record():
                 pred = model(nf)
-                batch_nids = nf.layer_parent_nid(-1).astype('int64').as_in_context(ctx)
-                batch_labels = labels[batch_nids]
+                batch_nids = nf.layer_parent_nid(-1)
+                batch_labels = labels[batch_nids].as_in_context(ctx)
                 loss = loss_fcn(pred, batch_labels)
                 loss = loss.sum() / len(batch_nids)
@@ -179,10 +180,10 @@ def gcn_ns_train(g, ctx, args, n_classes, train_nid, test_nid, n_test_samples):
                                                    neighbor_type='in',
                                                    num_hops=args.n_layers+1,
                                                    seed_nodes=test_nid):
-        nf.copy_from_parent()
+        nf.copy_from_parent(ctx=ctx)
         pred = infer_model(nf)
-        batch_nids = nf.layer_parent_nid(-1).astype('int64').as_in_context(ctx)
-        batch_labels = labels[batch_nids]
+        batch_nids = nf.layer_parent_nid(-1)
+        batch_labels = labels[batch_nids].as_in_context(ctx)
         num_acc += (pred.argmax(axis=1) == batch_labels).sum().asscalar()
         num_tests += nf.layer_size(-1)
         break
@@ -3,9 +3,6 @@ import numpy as np
 import mxnet as mx
 from mxnet import gluon
 import argparse, time, math
-import numpy as np
-import mxnet as mx
-from mxnet import gluon
 import dgl
 import dgl.function as fn
 from dgl import DGLGraph
@@ -181,25 +178,33 @@ class GraphSAGEInfer(gluon.Block):
         return h

-def graphsage_cv_train(g, ctx, args, n_classes, train_nid, test_nid, n_test_samples):
+def graphsage_cv_train(g, ctx, args, n_classes, train_nid, test_nid, n_test_samples, distributed):
     features = g.ndata['features']
     labels = g.ndata['labels']
     in_feats = g.ndata['features'].shape[1]
+    g_ctx = features.context

     norm = mx.nd.expand_dims(1./g.in_degrees().astype('float32'), 1)
-    g.ndata['norm'] = norm.as_in_context(ctx)
+    g.ndata['norm'] = norm.as_in_context(g_ctx)
     degs = g.in_degrees().astype('float32').asnumpy()
     degs[degs > args.num_neighbors] = args.num_neighbors
-    g.ndata['subg_norm'] = mx.nd.expand_dims(mx.nd.array(1./degs, ctx=ctx), 1)
-    g.update_all(fn.copy_src(src='features', out='m'),
-                 fn.sum(msg='m', out='preprocess'),
-                 lambda node : {'preprocess': node.data['preprocess'] * node.data['norm']})
+    g.ndata['subg_norm'] = mx.nd.expand_dims(mx.nd.array(1./degs, ctx=g_ctx), 1)
     n_layers = args.n_layers
-    for i in range(n_layers):
-        g.ndata['h_{}'.format(i)] = mx.nd.zeros((features.shape[0], args.n_hidden), ctx=ctx)
+
+    if distributed:
+        g.dist_update_all(fn.copy_src(src='features', out='m'),
+                          fn.sum(msg='m', out='preprocess'),
+                          lambda node : {'preprocess': node.data['preprocess'] * node.data['norm']})
+        for i in range(n_layers):
+            g.init_ndata('h_{}'.format(i), (features.shape[0], args.n_hidden), 'float32')
+            g.init_ndata('agg_h_{}'.format(i), (features.shape[0], args.n_hidden), 'float32')
+    else:
+        g.update_all(fn.copy_src(src='features', out='m'),
+                     fn.sum(msg='m', out='preprocess'),
+                     lambda node : {'preprocess': node.data['preprocess'] * node.data['norm']})
+        for i in range(n_layers):
+            g.ndata['h_{}'.format(i)] = mx.nd.zeros((features.shape[0], args.n_hidden), ctx=g_ctx)
+            g.ndata['agg_h_{}'.format(i)] = mx.nd.zeros((features.shape[0], args.n_hidden), ctx=g_ctx)

     model = GraphSAGETrain(in_feats,
                            args.n_hidden,
@@ -222,13 +227,21 @@ def graphsage_cv_train(g, ctx, args, n_classes, train_nid, test_nid, n_test_samp
     # use optimizer
     print(model.collect_params())
+    kv_type = 'dist_sync' if distributed else 'local'
     trainer = gluon.Trainer(model.collect_params(), 'adam',
                             {'learning_rate': args.lr, 'wd': args.weight_decay},
-                            kvstore=mx.kv.create('local'))
+                            kvstore=mx.kv.create(kv_type))

     # initialize graph
     dur = []
+    adj = g.adjacency_matrix().as_in_context(g_ctx)
     for epoch in range(args.n_epochs):
+        start = time.time()
+        if distributed:
+            msg_head = "Worker {:d}, epoch {:d}".format(g.worker_id, epoch)
+        else:
+            msg_head = "epoch {:d}".format(epoch)
         for nf in dgl.contrib.sampling.NeighborSampler(g, args.batch_size,
                                                        args.num_neighbors,
                                                        neighbor_type='in',
@@ -239,22 +252,28 @@ def graphsage_cv_train(g, ctx, args, n_classes, train_nid, test_nid, n_test_samp
                                                        seed_nodes=train_nid):
             for i in range(n_layers):
                 agg_history_str = 'agg_h_{}'.format(i)
-                g.pull(nf.layer_parent_nid(i+1), fn.copy_src(src='h_{}'.format(i), out='m'),
-                       fn.sum(msg='m', out=agg_history_str))
+                dests = nf.layer_parent_nid(i+1).as_in_context(g_ctx)
+                # TODO we could use DGLGraph.pull to implement this, but the current
+                # implementation of pull is very slow. Let's manually do it for now.
+                g.ndata[agg_history_str][dests] = mx.nd.dot(mx.nd.take(adj, dests),
+                                                            g.ndata['h_{}'.format(i)])

             node_embed_names = [['preprocess', 'features', 'h_0']]
             for i in range(1, n_layers):
                 node_embed_names.append(['h_{}'.format(i), 'agg_h_{}'.format(i-1), 'subg_norm', 'norm'])
             node_embed_names.append(['agg_h_{}'.format(n_layers-1), 'subg_norm', 'norm'])
-            nf.copy_from_parent(node_embed_names=node_embed_names)
+            nf.copy_from_parent(node_embed_names=node_embed_names, ctx=ctx)

             # forward
             with mx.autograd.record():
                 pred = model(nf)
-                batch_nids = nf.layer_parent_nid(-1).as_in_context(ctx)
-                batch_labels = labels[batch_nids]
+                batch_nids = nf.layer_parent_nid(-1)
+                batch_labels = labels[batch_nids].as_in_context(ctx)
                 loss = loss_fcn(pred, batch_labels)
-                loss = loss.sum() / len(batch_nids)
+                if distributed:
+                    loss = loss.sum() / (len(batch_nids) * g.num_workers)
+                else:
+                    loss = loss.sum() / (len(batch_nids))

             loss.backward()
             trainer.step(batch_size=1)
@@ -263,6 +282,7 @@ def graphsage_cv_train(g, ctx, args, n_classes, train_nid, test_nid, n_test_samp
                 node_embed_names.append([])
             nf.copy_to_parent(node_embed_names=node_embed_names)
+        print(msg_head + ': training takes ' + str(time.time() - start))

     infer_params = infer_model.collect_params()
@@ -273,22 +293,27 @@ def graphsage_cv_train(g, ctx, args, n_classes, train_nid, test_nid, n_test_samp
     num_acc = 0.
     num_tests = 0
-    for nf in dgl.contrib.sampling.NeighborSampler(g, args.test_batch_size,
-                                                   g.number_of_nodes(),
-                                                   neighbor_type='in',
-                                                   num_hops=n_layers,
-                                                   seed_nodes=test_nid,
-                                                   add_self_loop=True):
-        node_embed_names = [['preprocess', 'features']]
-        for i in range(n_layers):
-            node_embed_names.append(['norm', 'subg_norm'])
-        nf.copy_from_parent(node_embed_names=node_embed_names)
-
-        pred = infer_model(nf)
-        batch_nids = nf.layer_parent_nid(-1).as_in_context(ctx)
-        batch_labels = labels[batch_nids]
-        num_acc += (pred.argmax(axis=1) == batch_labels).sum().asscalar()
-        num_tests += nf.layer_size(-1)
-        break
-
-    print("Test Accuracy {:.4f}". format(num_acc/num_tests))
+    if not distributed or g.worker_id == 0:
+        start = time.time()
+        for nf in dgl.contrib.sampling.NeighborSampler(g, args.test_batch_size,
+                                                       g.number_of_nodes(),
+                                                       neighbor_type='in',
+                                                       num_hops=n_layers,
+                                                       seed_nodes=test_nid,
+                                                       add_self_loop=True):
+            node_embed_names = [['preprocess', 'features']]
+            for i in range(n_layers):
+                node_embed_names.append(['norm', 'subg_norm'])
+            nf.copy_from_parent(node_embed_names=node_embed_names, ctx=ctx)

+            pred = infer_model(nf)
+            batch_nids = nf.layer_parent_nid(-1)
+            batch_labels = labels[batch_nids].as_in_context(ctx)
+            num_acc += (pred.argmax(axis=1) == batch_labels).sum().asscalar()
+            num_tests += nf.layer_size(-1)
+            if distributed:
+                g._sync_barrier()
+            print(msg_head + ": Test Accuracy {:.4f}". format(num_acc/num_tests))
+            break
+    elif distributed:
+        g._sync_barrier()
from multiprocessing import Process
import argparse, time, math
import numpy as np
import os
os.environ['OMP_NUM_THREADS'] = '16'
import mxnet as mx
from mxnet import gluon
import dgl
from dgl import DGLGraph
from dgl.data import register_data_args, load_data
from gcn_ns_sc import gcn_ns_train
from gcn_cv_sc import gcn_cv_train
from graphsage_cv import graphsage_cv_train
def main(args):
g = dgl.contrib.graph_store.create_graph_from_store(args.graph_name, "shared_mem")
features = g.ndata['features']
labels = g.ndata['labels']
train_mask = g.ndata['train_mask']
val_mask = g.ndata['val_mask']
test_mask = g.ndata['test_mask']
if args.num_gpus > 0:
ctx = mx.gpu(g.worker_id % args.num_gpus)
else:
ctx = mx.cpu()
train_nid = mx.nd.array(np.nonzero(train_mask.asnumpy())[0]).astype(np.int64)
test_nid = mx.nd.array(np.nonzero(test_mask.asnumpy())[0]).astype(np.int64)
n_classes = len(np.unique(labels.asnumpy()))
n_train_samples = train_mask.sum().asscalar()
n_val_samples = val_mask.sum().asscalar()
n_test_samples = test_mask.sum().asscalar()
if args.model == "gcn_ns":
gcn_ns_train(g, ctx, args, n_classes, train_nid, test_nid, n_test_samples)
elif args.model == "gcn_cv":
gcn_cv_train(g, ctx, args, n_classes, train_nid, test_nid, n_test_samples, True)
elif args.model == "graphsage_cv":
graphsage_cv_train(g, ctx, args, n_classes, train_nid, test_nid, n_test_samples, True)
else:
print("unknown model. Please choose from gcn_ns, gcn_cv, graphsage_cv")
print("parent ends")
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='GCN')
register_data_args(parser)
parser.add_argument("--model", type=str,
help="select a model. Valid models: gcn_ns, gcn_cv, graphsage_cv")
parser.add_argument("--graph-name", type=str, default="",
help="graph name")
parser.add_argument("--num-feats", type=int, default=100,
help="the number of features")
parser.add_argument("--dropout", type=float, default=0.5,
help="dropout probability")
parser.add_argument("--num-gpus", type=int, default=0,
help="the number of GPUs to train")
parser.add_argument("--lr", type=float, default=3e-2,
help="learning rate")
parser.add_argument("--n-epochs", type=int, default=200,
help="number of training epochs")
parser.add_argument("--batch-size", type=int, default=1000,
help="batch size")
parser.add_argument("--test-batch-size", type=int, default=1000,
help="test batch size")
parser.add_argument("--num-neighbors", type=int, default=3,
help="number of neighbors to be sampled")
parser.add_argument("--n-hidden", type=int, default=16,
help="number of hidden gcn units")
parser.add_argument("--n-layers", type=int, default=1,
help="number of hidden gcn layers")
parser.add_argument("--self-loop", action='store_true',
help="graph self-loop (default=False)")
parser.add_argument("--weight-decay", type=float, default=5e-4,
help="Weight for L2 loss")
args = parser.parse_args()
print(args)
main(args)
import argparse, time, math
import numpy as np
from scipy import sparse as spsp
import mxnet as mx
import dgl
from dgl import DGLGraph
from dgl.data import register_data_args, load_data
class GraphData:
def __init__(self, csr, num_feats):
num_nodes = csr.shape[0]
num_edges = mx.nd.contrib.getnnz(csr).asnumpy()[0]
edge_ids = np.arange(0, num_edges, step=1, dtype=np.int64)
csr = spsp.csr_matrix((edge_ids, csr.indices.asnumpy(), csr.indptr.asnumpy()),
shape=csr.shape, dtype=np.int64)
self.graph = dgl.graph_index.GraphIndex(multigraph=False, readonly=True)
self.graph.from_csr_matrix(csr.indptr, csr.indices, "in")
self.features = mx.nd.random.normal(shape=(csr.shape[0], num_feats))
self.num_labels = 10
self.labels = mx.nd.floor(mx.nd.random.uniform(low=0, high=self.num_labels,
shape=(csr.shape[0])))
self.train_mask = np.zeros((num_nodes,))
self.train_mask[np.arange(0, int(num_nodes/2), dtype=np.int64)] = 1
self.val_mask = np.zeros((num_nodes,))
self.val_mask[np.arange(int(num_nodes/2), int(num_nodes/4*3), dtype=np.int64)] = 1
self.test_mask = np.zeros((num_nodes,))
self.test_mask[np.arange(int(num_nodes/4*3), int(num_nodes), dtype=np.int64)] = 1
def main(args):
# load and preprocess dataset
if args.graph_file != '':
csr = mx.nd.load(args.graph_file)[0]
n_edges = csr.shape[0]
data = GraphData(csr, args.num_feats)
csr = None
graph_name = args.graph_file
else:
data = load_data(args)
n_edges = data.graph.number_of_edges()
graph_name = args.dataset
if args.self_loop and not args.dataset.startswith('reddit'):
data.graph.add_edges_from([(i,i) for i in range(len(data.graph))])
mem_ctx = mx.cpu()
features = mx.nd.array(data.features, ctx=mem_ctx)
labels = mx.nd.array(data.labels, ctx=mem_ctx)
train_mask = mx.nd.array(data.train_mask, ctx=mem_ctx)
val_mask = mx.nd.array(data.val_mask, ctx=mem_ctx)
test_mask = mx.nd.array(data.test_mask, ctx=mem_ctx)
n_classes = data.num_labels
n_train_samples = train_mask.sum().asscalar()
n_val_samples = val_mask.sum().asscalar()
n_test_samples = test_mask.sum().asscalar()
print("""----Data statistics------'
#Edges %d
#Classes %d
#Train samples %d
#Val samples %d
#Test samples %d""" %
(n_edges, n_classes,
n_train_samples,
n_val_samples,
n_test_samples))
# create GCN model
g = dgl.contrib.graph_store.create_graph_store_server(data.graph, graph_name, "shared_mem",
args.num_workers, False)
g.ndata['features'] = features
g.ndata['labels'] = labels
g.ndata['train_mask'] = train_mask
g.ndata['val_mask'] = val_mask
g.ndata['test_mask'] = test_mask
g.run()
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='GCN')
register_data_args(parser)
parser.add_argument("--graph-file", type=str, default="",
help="graph file")
parser.add_argument("--num-feats", type=int, default=100,
help="the number of features")
parser.add_argument("--self-loop", action='store_true',
help="graph self-loop (default=False)")
parser.add_argument("--num-workers", type=int, default=1,
help="the number of workers")
args = parser.parse_args()
main(args)
@@ -24,14 +24,14 @@ def main(args):
     if args.self_loop and not args.dataset.startswith('reddit'):
         data.graph.add_edges_from([(i,i) for i in range(len(data.graph))])

-    train_nid = mx.nd.array(np.nonzero(data.train_mask)[0]).astype(np.int64).as_in_context(ctx)
-    test_nid = mx.nd.array(np.nonzero(data.test_mask)[0]).astype(np.int64).as_in_context(ctx)
+    train_nid = mx.nd.array(np.nonzero(data.train_mask)[0]).astype(np.int64)
+    test_nid = mx.nd.array(np.nonzero(data.test_mask)[0]).astype(np.int64)

-    features = mx.nd.array(data.features).as_in_context(ctx)
-    labels = mx.nd.array(data.labels).as_in_context(ctx)
-    train_mask = mx.nd.array(data.train_mask).as_in_context(ctx)
-    val_mask = mx.nd.array(data.val_mask).as_in_context(ctx)
-    test_mask = mx.nd.array(data.test_mask).as_in_context(ctx)
+    features = mx.nd.array(data.features)
+    labels = mx.nd.array(data.labels)
+    train_mask = mx.nd.array(data.train_mask)
+    val_mask = mx.nd.array(data.val_mask)
+    test_mask = mx.nd.array(data.test_mask)
     in_feats = features.shape[1]
     n_classes = data.num_labels
     n_edges = data.graph.number_of_edges()
@@ -59,9 +59,9 @@ def main(args):
     if args.model == "gcn_ns":
         gcn_ns_train(g, ctx, args, n_classes, train_nid, test_nid, n_test_samples)
     elif args.model == "gcn_cv":
-        gcn_cv_train(g, ctx, args, n_classes, train_nid, test_nid, n_test_samples)
+        gcn_cv_train(g, ctx, args, n_classes, train_nid, test_nid, n_test_samples, False)
     elif args.model == "graphsage_cv":
-        graphsage_cv_train(g, ctx, args, n_classes, train_nid, test_nid, n_test_samples)
+        graphsage_cv_train(g, ctx, args, n_classes, train_nid, test_nid, n_test_samples, False)
     else:
         print("unknown model. Please choose from gcn_ns, gcn_cv, graphsage_cv")
@@ -127,6 +127,92 @@ def _to_csr(graph_data, edge_dir, multigraph):
     csr = idx.adjacency_matrix_scipy(transpose, 'csr')
     return csr.indptr, csr.indices
class Barrier(object):
""" A barrier in the KVStore server used for one synchronization.
All workers have to enter the barrier before any of them can proceed
with any further computation.
Parameters
----------
num_workers: int
The number of workers that will enter the barrier.
"""
def __init__(self, num_workers):
self.num_enters = 0
self.num_leaves = 0
self.num_workers = num_workers
def enter(self):
""" A worker enters the barrier.
"""
self.num_enters += 1
def leave(self):
""" A worker notifies the server that it's going to leave the barrier.
"""
self.num_leaves += 1
def all_enter(self):
""" Indicate that all workers have entered the barrier.
"""
return self.num_enters == self.num_workers
def all_leave(self):
""" Indicate that all workers have left the barrier.
"""
return self.num_leaves == self.num_workers
class BarrierManager(object):
""" The manager of barriers
When a worker wants to enter a barrier, it creates the barrier if it doesn't
exist. Otherwise, the worker will enter an existing barrier.
The manager needs to know the number of workers in advance so that it can
keep track of barriers and workers.
Parameters
----------
num_workers: int
The number of workers that need to synchronize with barriers.
"""
def __init__(self, num_workers):
self.num_workers = num_workers
self.barrier_ids = [0] * num_workers
self.barriers = {}
def enter(self, worker_id):
""" A worker enters a barrier.
Parameters
----------
worker_id : int
The worker that wants to enter a barrier.
"""
bid = self.barrier_ids[worker_id]
self.barrier_ids[worker_id] += 1
if bid in self.barriers:
self.barriers[bid].enter()
else:
self.barriers.update({bid : Barrier(self.num_workers)})
self.barriers[bid].enter()
return bid
def all_enter(self, worker_id, barrier_id):
""" Indicate whether all workers have entered a specified barrier.
"""
return self.barriers[barrier_id].all_enter()
def leave(self, worker_id, barrier_id):
""" A worker leaves a barrier.
This is useful for garbage collection of used barriers.
"""
self.barriers[barrier_id].leave()
if self.barriers[barrier_id].all_leave():
del self.barriers[barrier_id]
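A minimal sketch of how these two classes are meant to be driven (illustrative only; a two-worker setup is assumed): each worker's `enter()` returns a barrier id, workers poll `all_enter()` until everyone has arrived, and `leave()` lets the manager recycle the barrier.
```
# Hypothetical two-worker walk-through of the barrier protocol.
mgr = BarrierManager(num_workers=2)

bid0 = mgr.enter(worker_id=0)   # worker 0 arrives; barrier 0 is created
bid1 = mgr.enter(worker_id=1)   # worker 1 arrives at the same barrier
assert bid0 == bid1             # both workers share barrier id 0
assert mgr.all_enter(0, bid0)   # both have entered, so everyone may proceed

mgr.leave(0, bid0)              # workers report that they are done ...
mgr.leave(1, bid1)              # ... and the barrier is garbage-collected
```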
 class SharedMemoryStoreServer(object):
     """The graph store server.
@@ -158,9 +244,23 @@ class SharedMemoryStoreServer(object):
         self._num_workers = num_workers
         self._graph_name = graph_name
         self._edge_dir = edge_dir
+        self._registered_nworkers = 0
+        self._barrier = BarrierManager(num_workers)
+
+        # RPC command: register a graph to the graph store server.
+        def register(graph_name):
+            if graph_name != self._graph_name:
+                print("graph store has %s, but the worker wants %s"
+                      % (self._graph_name, graph_name))
+                return (-1, -1)
+            worker_id = self._registered_nworkers
+            self._registered_nworkers += 1
+            return worker_id, self._num_workers

         # RPC command: get the graph information from the graph store server.
-        def get_graph_info():
+        def get_graph_info(graph_name):
+            assert graph_name == self._graph_name
             return self._graph.number_of_nodes(), self._graph.number_of_edges(), \
                    self._graph.is_multigraph, edge_dir
@@ -205,13 +305,30 @@ class SharedMemoryStoreServer(object):
             self._num_workers -= 1
             return 0

+        # RPC command: a worker enters a barrier.
+        def enter_barrier(worker_id):
+            return self._barrier.enter(worker_id)
+
+        # RPC command: a worker leaves a barrier.
+        def leave_barrier(worker_id, barrier_id):
+            self._barrier.leave(worker_id, barrier_id)
+            return 0
+
+        # RPC command: test if all workers have entered a barrier.
+        def all_enter(worker_id, barrier_id):
+            return self._barrier.all_enter(worker_id, barrier_id)
+
         self.server = SimpleXMLRPCServer(("localhost", port))
+        self.server.register_function(register, "register")
         self.server.register_function(get_graph_info, "get_graph_info")
         self.server.register_function(init_ndata, "init_ndata")
         self.server.register_function(init_edata, "init_edata")
         self.server.register_function(terminate, "terminate")
         self.server.register_function(list_ndata, "list_ndata")
         self.server.register_function(list_edata, "list_edata")
+        self.server.register_function(enter_barrier, "enter_barrier")
+        self.server.register_function(leave_barrier, "leave_barrier")
+        self.server.register_function(all_enter, "all_enter")

     def __del__(self):
         self._graph = None
@@ -267,7 +384,10 @@ class SharedMemoryDGLGraph(DGLGraph):
         self._graph_name = graph_name
         self._pid = os.getpid()
         self.proxy = xmlrpc.client.ServerProxy("http://localhost:" + str(port) + "/")
-        num_nodes, num_edges, multigraph, edge_dir = self.proxy.get_graph_info()
+        self._worker_id, self._num_workers = self.proxy.register(graph_name)
+        if self._worker_id < 0:
+            raise Exception('fail to get graph ' + graph_name + ' from the graph store')
+        num_nodes, num_edges, multigraph, edge_dir = self.proxy.get_graph_info(graph_name)

         graph_idx = GraphIndex(multigraph=multigraph, readonly=True)
         graph_idx.from_shared_mem_csr_matrix(_get_graph_path(graph_name), num_nodes, num_edges, edge_dir)
@@ -324,6 +444,105 @@ class SharedMemoryDGLGraph(DGLGraph):
         dlpack = data.to_dlpack()
         self.edata[edata_name] = F.zerocopy_from_dlpack(dlpack)
@property
def num_workers(self):
""" The number of workers using the graph store.
"""
return self._num_workers
@property
def worker_id(self):
""" The id of the current worker using the graph store.
When a worker connects to a graph store, it is assigned a worker id.
This lets the graph store server identify which worker is sending requests.
The worker id is a unique integer between 0 and num_workers - 1.
It is also useful in user code, e.g., for deciding which GPU to assign to
each worker in multi-process training.
"""
return self._worker_id
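For instance, the multi-process training script above uses exactly this id to pick a device for each worker; a sketch of that pattern (`args.num_gpus` is that script's own argument):
```
# One trainer process: choose a GPU by worker id, or fall back to CPU.
ctx = mx.gpu(g.worker_id % args.num_gpus) if args.num_gpus > 0 else mx.cpu()
```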
def _sync_barrier(self):
# Here we implement a multi-process barrier manually on top of RPC.
# Workers busy-wait on the server, but every all_enter call involves
# a context switch, so the polling doesn't burn CPU too badly.
bid = self.proxy.enter_barrier(self._worker_id)
while not self.proxy.all_enter(self._worker_id, bid):
continue
self.proxy.leave_barrier(self._worker_id, bid)
def init_ndata(self, ndata_name, shape, dtype):
"""Create node embedding.
It first creates the node embedding in the server and maps it to the current process
with shared memory.
Parameters
----------
ndata_name : string
The name of node embedding
shape : tuple
The shape of the node embedding
dtype : string
The data type of the node embedding. The currently supported data types
are "float32" and "int32".
"""
self.proxy.init_ndata(ndata_name, shape, dtype)
self._init_ndata(ndata_name, shape, dtype)
def init_edata(self, edata_name, shape, dtype):
"""Create edge embedding.
It first creates the edge embedding in the server and maps it to the current process
with shared memory.
Parameters
----------
edata_name : string
The name of edge embedding
shape : tuple
The shape of the edge embedding
dtype : string
The data type of the edge embedding. The currently supported data types
are "float32" and "int32".
"""
self.proxy.init_edata(edata_name, shape, dtype)
self._init_edata(edata_name, shape, dtype)
def dist_update_all(self, message_func="default",
reduce_func="default",
apply_node_func="default"):
""" Distribute the computation in update_all among all pre-defined workers.
dist_update_all requires that all workers invoke this method and will
return only when all workers finish their own portion of computation.
The number of workers is pre-defined. If one of them doesn't invoke the method,
the call never returns because part of the computation is never finished.
Parameters
----------
message_func : callable, optional
Message function on the edges. The function should be
an :mod:`Edge UDF <dgl.udf>`.
reduce_func : callable, optional
Reduce function on the node. The function should be
a :mod:`Node UDF <dgl.udf>`.
apply_node_func : callable, optional
Apply function on the nodes. The function should be
a :mod:`Node UDF <dgl.udf>`.
"""
num_worker_nodes = int(self.number_of_nodes() / self.num_workers) + 1
start_node = self.worker_id * num_worker_nodes
end_node = min((self.worker_id + 1) * num_worker_nodes, self.number_of_nodes())
worker_nodes = np.arange(start_node, end_node, dtype=np.int64)
self.pull(worker_nodes, message_func, reduce_func, apply_node_func, inplace=True)
self._sync_barrier()
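Every registered worker issues the same call and blocks at the internal barrier until all node partitions have been processed; a usage sketch mirroring the trainer code above:
```
# Run by each of the num_workers trainer processes on the shared graph `g`.
g.dist_update_all(fn.copy_src(src='features', out='m'),
                  fn.sum(msg='m', out='preprocess'),
                  lambda node: {'preprocess': node.data['preprocess'] * node.data['norm']})
```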
def destroy(self):
"""Destroy the graph store.
@@ -153,7 +153,7 @@ class NodeFlowSampler(object):
         if self.immutable_only and not g._graph.is_readonly():
             raise NotImplementedError("This loader only support read-only graphs.")

-        self._batch_size = batch_size
+        self._batch_size = int(batch_size)

         if seed_nodes is None:
             self._seed_nodes = F.arange(0, g.number_of_nodes())
@@ -152,7 +152,7 @@ class NodeFlow(DGLBaseGraph):
         block_id = self._get_block_id(block_id)
         return int(self._block_offsets[block_id + 1]) - int(self._block_offsets[block_id])

-    def copy_from_parent(self, node_embed_names=ALL, edge_embed_names=ALL):
+    def copy_from_parent(self, node_embed_names=ALL, edge_embed_names=ALL, ctx=F.cpu()):
         """Copy node/edge features from the parent graph.

         Parameters
@@ -166,7 +166,8 @@ class NodeFlow(DGLBaseGraph):
         if is_all(node_embed_names):
             for i in range(self.num_layers):
                 nid = utils.toindex(self.layer_parent_nid(i))
-                self._node_frames[i] = FrameRef(Frame(self._parent._node_frame[nid]))
+                self._node_frames[i] = FrameRef(Frame(_copy_frame(
+                    self._parent._node_frame[nid], ctx)))
         elif node_embed_names is not None:
             assert isinstance(node_embed_names, list) \
                     and len(node_embed_names) == self.num_layers, \
@@ -174,13 +175,14 @@ class NodeFlow(DGLBaseGraph):
             for i in range(self.num_layers):
                 nid = self.layer_parent_nid(i)
                 self._node_frames[i] = _get_frame(self._parent._node_frame,
-                                                  node_embed_names[i], nid)
+                                                  node_embed_names[i], nid, ctx)

         if self._parent._edge_frame.num_rows != 0 and self._parent._edge_frame.num_columns != 0:
             if is_all(edge_embed_names):
                 for i in range(self.num_blocks):
                     eid = utils.toindex(self.block_parent_eid(i))
-                    self._edge_frames[i] = FrameRef(Frame(self._parent._edge_frame[eid]))
+                    self._edge_frames[i] = FrameRef(Frame(_copy_frame(
+                        self._parent._edge_frame[eid], ctx)))
             elif edge_embed_names is not None:
                 assert isinstance(edge_embed_names, list) \
                         and len(edge_embed_names) == self.num_blocks, \
@@ -188,7 +190,7 @@ class NodeFlow(DGLBaseGraph):
                 for i in range(self.num_blocks):
                     eid = self.block_parent_eid(i)
                     self._edge_frames[i] = _get_frame(self._parent._edge_frame,
-                                                      edge_embed_names[i], eid)
+                                                      edge_embed_names[i], eid, ctx)

     def copy_to_parent(self, node_embed_names=ALL, edge_embed_names=ALL):
         """Copy node/edge embeddings to the parent graph.
@@ -902,13 +904,17 @@ class NodeFlow(DGLBaseGraph):
 def _copy_to_like(arr1, arr2):
     return F.copy_to(arr1, F.context(arr2))

-def _get_frame(frame, names, ids):
-    col_dict = {name: frame[name][_copy_to_like(ids, frame[name])] for name in names}
+def _get_frame(frame, names, ids, ctx):
+    col_dict = {name: F.copy_to(frame[name][_copy_to_like(ids, frame[name])], \
+                                ctx) for name in names}
     if len(col_dict) == 0:
         return FrameRef(Frame(num_rows=len(ids)))
     else:
         return FrameRef(Frame(col_dict))

+def _copy_frame(frame, ctx):
+    return {name: F.copy_to(frame[name], ctx) for name in frame}
+
 def _update_frame(frame, names, ids, new_frame):
     col_dict = {name: new_frame[name] for name in names}
@@ -28,10 +28,12 @@ def worker_func(worker_id):
     g.edata['test4'] = mx.nd.zeros((g.number_of_edges(), 10))
     if worker_id == 0:
         time.sleep(3)
+        print(g.worker_id)
         g.ndata['test4'][0] = 1
         g.edata['test4'][0] = 2
     else:
         time.sleep(5)
+        print(g.worker_id)
         assert np.all(g.ndata['test4'][0].asnumpy() == 1)
         assert np.all(g.edata['test4'][0].asnumpy() == 2)
     g.destroy()