Unverified Commit 86d60a1f authored by Chao Ma, committed by GitHub

[DEMO] Add Pytorch demo for distributed sampler (#562)

* update

* update

* update

* add sender

* update

* remove duplicate code
parent ce27ebbb
@@ -14,34 +14,48 @@ pip install torch requests
 ```
 ### Neighbor Sampling & Skip Connection
-cora: test accuracy ~83% with --num-neighbors 2, ~84% by training on the full graph
+#### cora
+Test accuracy ~83% with --num-neighbors 2, ~84% by training on the full graph
 ```
-python3 gcn_ns_sc.py --dataset cora --self-loop --num-neighbors 2 --batch-size 1000000 --test-batch-size 1000000 --gpu 0
+DGLBACKEND=pytorch python3 gcn_ns_sc.py --dataset cora --self-loop --num-neighbors 2 --batch-size 1000000 --test-batch-size 1000000
 ```
-citeseer: test accuracy ~69% with --num-neighbors 2, ~70% by training on the full graph
+#### citeseer
+Test accuracy ~69% with --num-neighbors 2, ~70% by training on the full graph
 ```
-python3 gcn_ns_sc.py --dataset citeseer --self-loop --num-neighbors 2 --batch-size 1000000 --test-batch-size 1000000 --gpu 0
+DGLBACKEND=pytorch python3 gcn_ns_sc.py --dataset citeseer --self-loop --num-neighbors 2 --batch-size 1000000 --test-batch-size 1000000
 ```
-pubmed: test accuracy ~76% with --num-neighbors 3, ~77% by training on the full graph
+#### pubmed
+Test accuracy ~76% with --num-neighbors 3, ~77% by training on the full graph
 ```
-python3 gcn_ns_sc.py --dataset pubmed --self-loop --num-neighbors 3 --batch-size 1000000 --test-batch-size 1000000 --gpu 0
+DGLBACKEND=pytorch python3 gcn_ns_sc.py --dataset pubmed --self-loop --num-neighbors 3 --batch-size 1000000 --test-batch-size 1000000
 ```
 ### Control Variate & Skip Connection
-cora: test accuracy ~84% with --num-neighbors 1, ~84% by training on the full graph
+#### cora
+Test accuracy ~84% with --num-neighbors 1, ~84% by training on the full graph
 ```
-python3 gcn_cv_sc.py --dataset cora --self-loop --num-neighbors 1 --batch-size 1000000 --test-batch-size 1000000 --gpu 0
+DGLBACKEND=pytorch python3 gcn_cv_sc.py --dataset cora --self-loop --num-neighbors 1 --batch-size 1000000 --test-batch-size 1000000
 ```
-citeseer: test accuracy ~69% with --num-neighbors 1, ~70% by training on the full graph
+#### citeseer
+Test accuracy ~69% with --num-neighbors 1, ~70% by training on the full graph
 ```
-python3 gcn_cv_sc.py --dataset citeseer --self-loop --num-neighbors 1 --batch-size 1000000 --test-batch-size 1000000 --gpu 0
+DGLBACKEND=pytorch python3 gcn_cv_sc.py --dataset citeseer --self-loop --num-neighbors 1 --batch-size 1000000 --test-batch-size 1000000
 ```
-pubmed: test accuracy ~77% with --num-neighbors 1, ~77% by training on the full graph
+#### pubmed
+Test accuracy ~77% with --num-neighbors 1, ~77% by training on the full graph
 ```
-python3 gcn_cv_sc.py --dataset pubmed --self-loop --num-neighbors 1 --batch-size 1000000 --test-batch-size 1000000 --gpu 0
+DGLBACKEND=pytorch python3 gcn_cv_sc.py --dataset pubmed --self-loop --num-neighbors 1 --batch-size 1000000 --test-batch-size 1000000
 ```
# Stochastic Training for Graph Convolutional Networks Using Distributed Sampler
* Paper: [Control Variate](https://arxiv.org/abs/1710.10568)
* Paper: [Skip Connection](https://arxiv.org/abs/1809.05343)
* Author's code: [https://github.com/thu-ml/stochastic_gcn](https://github.com/thu-ml/stochastic_gcn)
Dependencies
------------
- PyTorch 0.4.1+
- requests
```bash
pip install torch requests
```
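Each example below runs two processes that share the address passed as `--ip`: the trainer creates a `SamplerReceiver` at that address and consumes incoming NodeFlows, while the sampler pushes them through a `SamplerSender`. Run the trainer and sampler commands in two separate shells.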
### Neighbor Sampling & Skip Connection
#### cora
Test accuracy ~83% with --num-neighbors 2, ~84% by training on the full graph
Trainer side:
```
DGLBACKEND=pytorch python3 gcn_ns_sc_train.py --dataset cora --self-loop --num-neighbors 2 --batch-size 1000000 --test-batch-size 1000000 --ip 127.0.0.1:50051 --num-sampler 1
```
Sampler side:
```
DGLBACKEND=pytorch python3 sampler.py --model gcn_ns --dataset cora --self-loop --num-neighbors 2 --batch-size 1000000 --ip 127.0.0.1:50051
```
#### citeseer
Test accuracy ~69% with --num-neighbors 2, ~70% by training on the full graph
Trainer side:
```
DGLBACKEND=pytorch python3 gcn_ns_sc_train.py --dataset citeseer --self-loop --num-neighbors 2 --batch-size 1000000 --test-batch-size 1000000 --ip 127.0.0.1:50051 --num-sampler 1
```
Sampler side:
```
DGLBACKEND=pytorch python3 sampler.py --model gcn_ns --dataset citeseer --self-loop --num-neighbors 2 --batch-size 1000000 --ip 127.0.0.1:50051
```
#### pubmed
Test accuracy ~76% with --num-neighbors 3, ~77% by training on the full graph
Trainer side:
```
DGLBACKEND=pytorch python3 gcn_ns_sc_train.py --dataset pubmed --self-loop --num-neighbors 3 --batch-size 1000000 --test-batch-size 1000000 --ip 127.0.0.1:50051 --num-sampler 1
```
Sampler side:
```
DGLBACKEND=pytorch python3 sampler.py --model gcn_ns --dataset pubmed --self-loop --num-neighbors 3 --batch-size 1000000 --ip 127.0.0.1:50051
```
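(The `--model` flag controls how many hops `sampler.py` samples per NodeFlow: `gcn_ns` uses `--n-layers + 1` hops, while `gcn_cv` in the next section uses `--n-layers`; see `MySamplerPool.worker` in `sampler.py` below.)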
### Control Variate & Skip Connection
#### cora
Test accuracy ~84% with --num-neighbors 1, ~84% by training on the full graph
Trainer side:
```
DGLBACKEND=pytorch python3 gcn_cv_sc_train.py --dataset cora --self-loop --num-neighbors 1 --batch-size 1000000 --test-batch-size 1000000 --ip 127.0.0.1:50051 --num-sampler 1
```
Sampler side:
```
DGLBACKEND=pytorch python3 sampler.py --model gcn_cv --dataset cora --self-loop --num-neighbors 1 --batch-size 1000000 --ip 127.0.0.1:50051
```
#### citeseer
Test accuracy ~69% with --num-neighbors 1, ~70% by training on the full graph
Trainer side:
```
DGLBACKEND=pytorch python3 gcn_cv_sc_train.py --dataset citeseer --self-loop --num-neighbors 1 --batch-size 1000000 --test-batch-size 1000000 --ip 127.0.0.1:50051 --num-sampler 1
```
Sampler side:
```
DGLBACKEND=pytorch python3 sampler.py --model gcn_cv --dataset citeseer --self-loop --num-neighbors 1 --batch-size 1000000 --ip 127.0.0.1:50051
```
#### pubmed
Test accuracy ~77% with --num-neighbors 1, ~77% by training on the full graph
Trainer side:
```
DGLBACKEND=pytorch python3 gcn_cv_sc_train.py --dataset pubmed --self-loop --num-neighbors 1 --batch-size 1000000 --test-batch-size 1000000 --ip 127.0.0.1:50051 --num-sampler 1
```
Sampler side:
```
DGLBACKEND=pytorch python3 sampler.py --model gcn_cv --dataset pubmed --self-loop --num-neighbors 1 --batch-size 1000000 --ip 127.0.0.1:50051
```
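The three scripts below implement this demo: `gcn_ns_sc_train.py` and `gcn_cv_sc_train.py` are the trainer side, and `sampler.py` is the sampler side. As a quick orientation, here is a minimal sketch of the sender/receiver pairing they rely on. It is a sketch only: `g`, `train_nid`, and `train_on` are placeholders for a read-only `DGLGraph`, the training node IDs, and a training step, and the calls follow the `dgl.contrib.sampling` API exactly as used in these scripts.
```
import dgl

ADDR = '127.0.0.1:50051'  # must match --ip on both sides

# --- sampler process ---
# The namebook maps receiver id -> address; a single trainer here.
sender = dgl.contrib.sampling.SamplerSender({0: ADDR})
for nf in dgl.contrib.sampling.NeighborSampler(g, 1000, 2,  # batch size, fan-out
                                               neighbor_type='in',
                                               num_hops=2,
                                               seed_nodes=train_nid):
    sender.send(nf, 0)   # push one sampled NodeFlow to receiver 0
sender.signal(0)         # mark the end of one epoch

# --- trainer process ---
receiver = dgl.contrib.sampling.SamplerReceiver(graph=g, addr=ADDR, num_sender=1)
for nf in receiver:      # yields NodeFlows until every sender has signaled
    train_on(nf)         # placeholder for one training step
```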
# gcn_cv_sc_train.py — trainer for the control-variate & skip-connection demo
import os, sys
import argparse, time, math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl
import dgl.function as fn
from dgl import DGLGraph
from dgl.data import register_data_args, load_data
parentdir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, parentdir)
from gcn_cv_sc import NodeUpdate, GCNSampling, GCNInfer
def main(args):
    # load and preprocess dataset
    data = load_data(args)

    if args.self_loop and not args.dataset.startswith('reddit'):
        data.graph.add_edges_from([(i,i) for i in range(len(data.graph))])

    train_nid = np.nonzero(data.train_mask)[0].astype(np.int64)
    test_nid = np.nonzero(data.test_mask)[0].astype(np.int64)

    features = torch.FloatTensor(data.features)
    labels = torch.LongTensor(data.labels)
    train_mask = torch.ByteTensor(data.train_mask)
    val_mask = torch.ByteTensor(data.val_mask)
    test_mask = torch.ByteTensor(data.test_mask)
    in_feats = features.shape[1]
    n_classes = data.num_labels
    n_edges = data.graph.number_of_edges()

    n_train_samples = train_mask.sum().item()
    n_val_samples = val_mask.sum().item()
    n_test_samples = test_mask.sum().item()

    print("""----Data statistics------
      #Edges %d
      #Classes %d
      #Train samples %d
      #Val samples %d
      #Test samples %d""" %
          (n_edges, n_classes,
           n_train_samples,
           n_val_samples,
           n_test_samples))

    # create GCN model
    g = DGLGraph(data.graph, readonly=True)
    norm = 1. / g.in_degrees().float().unsqueeze(1)

    if args.gpu < 0:
        cuda = False
    else:
        cuda = True
        torch.cuda.set_device(args.gpu)
        features = features.cuda()
        labels = labels.cuda()
        train_mask = train_mask.cuda()
        val_mask = val_mask.cuda()
        test_mask = test_mask.cuda()
        norm = norm.cuda()

    g.ndata['features'] = features

    num_neighbors = args.num_neighbors
    n_layers = args.n_layers

    g.ndata['norm'] = norm

    # Precompute the normalized first-hop aggregation (mean of in-neighbor
    # features) once on the full graph; every batch reuses it.
    g.update_all(fn.copy_src(src='features', out='m'),
                 fn.sum(msg='m', out='preprocess'),
                 lambda node : {'preprocess': node.data['preprocess'] * node.data['norm']})

    # History buffers for the control variate, one per hidden layer; the
    # last one is wider to hold the skip connection's concatenation.
    for i in range(n_layers):
        g.ndata['h_{}'.format(i)] = torch.zeros(features.shape[0], args.n_hidden).to(device=features.device)

    g.ndata['h_{}'.format(n_layers-1)] = torch.zeros(features.shape[0], 2*args.n_hidden).to(device=features.device)

    model = GCNSampling(in_feats,
                        args.n_hidden,
                        n_classes,
                        n_layers,
                        F.relu,
                        args.dropout)

    loss_fcn = nn.CrossEntropyLoss()

    infer_model = GCNInfer(in_feats,
                           args.n_hidden,
                           n_classes,
                           n_layers,
                           F.relu)

    if cuda:
        model.cuda()
        infer_model.cuda()

    # use optimizer
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.lr,
                                 weight_decay=args.weight_decay)

    # Create sampler receiver; NodeFlows are pushed to args.ip by
    # args.num_sampler remote sender process(es).
    sampler = dgl.contrib.sampling.SamplerReceiver(graph=g, addr=args.ip, num_sender=args.num_sampler)

    for epoch in range(args.n_epochs):
        for nf in sampler:
            # Aggregate each layer's stale history on the nodes this
            # NodeFlow touches; the control-variate update combines it
            # with the freshly sampled estimate.
            for i in range(n_layers):
                agg_history_str = 'agg_h_{}'.format(i)
                g.pull(nf.layer_parent_nid(i+1).long(), fn.copy_src(src='h_{}'.format(i), out='m'),
                       fn.sum(msg='m', out=agg_history_str),
                       lambda node : {agg_history_str: node.data[agg_history_str] * node.data['norm']})

            node_embed_names = [['preprocess', 'h_0']]
            for i in range(1, n_layers):
                node_embed_names.append(['h_{}'.format(i), 'agg_h_{}'.format(i-1)])
            node_embed_names.append(['agg_h_{}'.format(n_layers-1)])
            nf.copy_from_parent(node_embed_names=node_embed_names)

            model.train()
            # forward
            pred = model(nf)
            batch_nids = nf.layer_parent_nid(-1).to(device=pred.device).long()
            batch_labels = labels[batch_nids]
            loss = loss_fcn(pred, batch_labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Write the new activations back into the parent graph so
            # later batches use up-to-date history.
            node_embed_names = [['h_{}'.format(i)] for i in range(n_layers)]
            node_embed_names.append([])
            nf.copy_to_parent(node_embed_names=node_embed_names)

        # Sync the trained weights into the full-graph inference model.
        for infer_param, param in zip(infer_model.parameters(), model.parameters()):
            infer_param.data.copy_(param.data)

        num_acc = 0.

        for nf in dgl.contrib.sampling.NeighborSampler(g, args.test_batch_size,
                                                       g.number_of_nodes(),
                                                       neighbor_type='in',
                                                       num_workers=32,
                                                       num_hops=n_layers,
                                                       seed_nodes=test_nid):
            node_embed_names = [['preprocess']]
            for i in range(n_layers):
                node_embed_names.append(['norm'])
            nf.copy_from_parent(node_embed_names=node_embed_names)

            infer_model.eval()
            with torch.no_grad():
                pred = infer_model(nf)
                batch_nids = nf.layer_parent_nid(-1).to(device=pred.device).long()
                batch_labels = labels[batch_nids]
                num_acc += (pred.argmax(dim=1) == batch_labels).sum().cpu().item()

        print("Test Accuracy {:.4f}".format(num_acc/n_test_samples))


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='GCN')
    register_data_args(parser)
    parser.add_argument("--dropout", type=float, default=0.5,
                        help="dropout probability")
    parser.add_argument("--gpu", type=int, default=-1,
                        help="gpu")
    parser.add_argument("--lr", type=float, default=3e-2,
                        help="learning rate")
    parser.add_argument("--n-epochs", type=int, default=200,
                        help="number of training epochs")
    parser.add_argument("--batch-size", type=int, default=1000,
                        help="train batch size")
    parser.add_argument("--test-batch-size", type=int, default=1000,
                        help="test batch size")
    parser.add_argument("--num-neighbors", type=int, default=2,
                        help="number of neighbors to be sampled")
    parser.add_argument("--n-hidden", type=int, default=16,
                        help="number of hidden gcn units")
    parser.add_argument("--n-layers", type=int, default=1,
                        help="number of hidden gcn layers")
    parser.add_argument("--self-loop", action='store_true',
                        help="graph self-loop (default=False)")
    parser.add_argument("--weight-decay", type=float, default=5e-4,
                        help="Weight for L2 loss")
    parser.add_argument("--ip", type=str, default='127.0.0.1:50051',
                        help="IP address of the trainer-side SamplerReceiver")
    parser.add_argument("--num-sampler", type=int, default=1,
                        help="number of samplers")
    args = parser.parse_args()
    print(args)
    main(args)
# gcn_ns_sc_train.py — trainer for the neighbor-sampling & skip-connection demo
import os, sys
import argparse, time, math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial
import dgl
import dgl.function as fn
from dgl import DGLGraph
from dgl.data import register_data_args, load_data
parentdir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, parentdir)
from gcn_ns_sc import NodeUpdate, GCNSampling, GCNInfer
def main(args):
    # load and preprocess dataset
    data = load_data(args)

    if args.self_loop and not args.dataset.startswith('reddit'):
        data.graph.add_edges_from([(i,i) for i in range(len(data.graph))])

    train_nid = np.nonzero(data.train_mask)[0].astype(np.int64)
    test_nid = np.nonzero(data.test_mask)[0].astype(np.int64)

    features = torch.FloatTensor(data.features)
    labels = torch.LongTensor(data.labels)
    train_mask = torch.ByteTensor(data.train_mask)
    val_mask = torch.ByteTensor(data.val_mask)
    test_mask = torch.ByteTensor(data.test_mask)
    in_feats = features.shape[1]
    n_classes = data.num_labels
    n_edges = data.graph.number_of_edges()

    n_train_samples = train_mask.sum().item()
    n_val_samples = val_mask.sum().item()
    n_test_samples = test_mask.sum().item()

    print("""----Data statistics------
      #Edges %d
      #Classes %d
      #Train samples %d
      #Val samples %d
      #Test samples %d""" %
          (n_edges, n_classes,
           n_train_samples,
           n_val_samples,
           n_test_samples))

    # create GCN model
    g = DGLGraph(data.graph, readonly=True)
    norm = 1. / g.in_degrees().float().unsqueeze(1)

    if args.gpu < 0:
        cuda = False
    else:
        cuda = True
        torch.cuda.set_device(args.gpu)
        features = features.cuda()
        labels = labels.cuda()
        train_mask = train_mask.cuda()
        val_mask = val_mask.cuda()
        test_mask = test_mask.cuda()
        norm = norm.cuda()

    g.ndata['features'] = features

    num_neighbors = args.num_neighbors

    g.ndata['norm'] = norm

    model = GCNSampling(in_feats,
                        args.n_hidden,
                        n_classes,
                        args.n_layers,
                        F.relu,
                        args.dropout)

    if cuda:
        model.cuda()

    loss_fcn = nn.CrossEntropyLoss()

    infer_model = GCNInfer(in_feats,
                           args.n_hidden,
                           n_classes,
                           args.n_layers,
                           F.relu)

    if cuda:
        infer_model.cuda()

    # use optimizer
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.lr,
                                 weight_decay=args.weight_decay)

    # Create sampler receiver; NodeFlows arrive pre-sampled from the
    # remote sampler process(es) at args.ip.
    sampler = dgl.contrib.sampling.SamplerReceiver(graph=g, addr=args.ip, num_sender=args.num_sampler)

    # initialize graph
    dur = []
    for epoch in range(args.n_epochs):
        for nf in sampler:
            # Fetch the input features of this NodeFlow from the parent graph.
            nf.copy_from_parent()
            model.train()
            # forward
            pred = model(nf)
            batch_nids = nf.layer_parent_nid(-1).to(device=pred.device, dtype=torch.long)
            batch_labels = labels[batch_nids]
            loss = loss_fcn(pred, batch_labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Sync the trained weights into the full-graph inference model.
        for infer_param, param in zip(infer_model.parameters(), model.parameters()):
            infer_param.data.copy_(param.data)

        num_acc = 0.

        for nf in dgl.contrib.sampling.NeighborSampler(g, args.test_batch_size,
                                                       g.number_of_nodes(),
                                                       neighbor_type='in',
                                                       num_workers=32,
                                                       num_hops=args.n_layers+1,
                                                       seed_nodes=test_nid):
            nf.copy_from_parent()
            infer_model.eval()
            with torch.no_grad():
                pred = infer_model(nf)
                batch_nids = nf.layer_parent_nid(-1).to(device=pred.device, dtype=torch.long)
                batch_labels = labels[batch_nids]
                num_acc += (pred.argmax(dim=1) == batch_labels).sum().cpu().item()

        print("Test Accuracy {:.4f}".format(num_acc/n_test_samples))


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='GCN')
    register_data_args(parser)
    parser.add_argument("--dropout", type=float, default=0.5,
                        help="dropout probability")
    parser.add_argument("--gpu", type=int, default=-1,
                        help="gpu")
    parser.add_argument("--lr", type=float, default=3e-2,
                        help="learning rate")
    parser.add_argument("--n-epochs", type=int, default=200,
                        help="number of training epochs")
    parser.add_argument("--batch-size", type=int, default=1000,
                        help="batch size")
    parser.add_argument("--test-batch-size", type=int, default=1000,
                        help="test batch size")
    parser.add_argument("--num-neighbors", type=int, default=3,
                        help="number of neighbors to be sampled")
    parser.add_argument("--n-hidden", type=int, default=16,
                        help="number of hidden gcn units")
    parser.add_argument("--n-layers", type=int, default=1,
                        help="number of hidden gcn layers")
    parser.add_argument("--self-loop", action='store_true',
                        help="graph self-loop (default=False)")
    parser.add_argument("--weight-decay", type=float, default=5e-4,
                        help="Weight for L2 loss")
    parser.add_argument("--ip", type=str, default='127.0.0.1:50051',
                        help="IP address of the trainer-side SamplerReceiver")
    parser.add_argument("--num-sampler", type=int, default=1,
                        help="number of samplers")
    args = parser.parse_args()
    print(args)
    main(args)
# sampler.py — remote sampler that streams NodeFlows to the trainer
import argparse, time, math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial
import dgl
import dgl.function as fn
from dgl import DGLGraph
from dgl.data import register_data_args, load_data
from dgl.contrib.sampling import SamplerPool
class MySamplerPool(SamplerPool):
    def worker(self, args):
        # The hop count must match what the trainer-side model expects.
        number_hops = 1
        if args.model == "gcn_ns":
            number_hops = args.n_layers + 1
        elif args.model == "gcn_cv":
            number_hops = args.n_layers
        else:
            print("unknown model. Please choose from gcn_ns and gcn_cv")
            return  # bail out instead of sampling with a wrong hop count

        # Start sender
        namebook = { 0:args.ip }
        sender = dgl.contrib.sampling.SamplerSender(namebook)

        # load and preprocess dataset
        data = load_data(args)
        if args.self_loop and not args.dataset.startswith('reddit'):
            data.graph.add_edges_from([(i,i) for i in range(len(data.graph))])
        train_nid = np.nonzero(data.train_mask)[0].astype(np.int64)
        test_nid = np.nonzero(data.test_mask)[0].astype(np.int64)

        # create GCN model
        g = DGLGraph(data.graph, readonly=True)

        # Sample indefinitely; each full pass over train_nid feeds one
        # training epoch on the trainer side.
        while True:
            idx = 0
            for nf in dgl.contrib.sampling.NeighborSampler(g, args.batch_size,
                                                           args.num_neighbors,
                                                           neighbor_type='in',
                                                           shuffle=True,
                                                           num_workers=32,
                                                           num_hops=number_hops,
                                                           seed_nodes=train_nid):
                print("send train nodeflow: %d" % (idx))
                sender.send(nf, 0)
                idx += 1
            # Tell the receiver that this epoch's stream is complete.
            sender.signal(0)


def main(args):
    pool = MySamplerPool()
    pool.start(args.num_sampler, args)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='GCN')
    register_data_args(parser)
    parser.add_argument("--model", type=str,
                        help="select a model. Valid models: gcn_ns, gcn_cv")
    parser.add_argument("--dropout", type=float, default=0.5,
                        help="dropout probability")
    parser.add_argument("--gpu", type=int, default=-1,
                        help="gpu")
    parser.add_argument("--lr", type=float, default=3e-2,
                        help="learning rate")
    parser.add_argument("--n-epochs", type=int, default=200,
                        help="number of training epochs")
    parser.add_argument("--batch-size", type=int, default=1000,
                        help="batch size")
    parser.add_argument("--test-batch-size", type=int, default=1000,
                        help="test batch size")
    parser.add_argument("--num-neighbors", type=int, default=3,
                        help="number of neighbors to be sampled")
    parser.add_argument("--n-hidden", type=int, default=16,
                        help="number of hidden gcn units")
    parser.add_argument("--n-layers", type=int, default=1,
                        help="number of hidden gcn layers")
    parser.add_argument("--self-loop", action='store_true',
                        help="graph self-loop (default=False)")
    parser.add_argument("--weight-decay", type=float, default=5e-4,
                        help="Weight for L2 loss")
    parser.add_argument("--ip", type=str, default='127.0.0.1:50051',
                        help="IP address of the trainer-side SamplerReceiver")
    parser.add_argument("--num-sampler", type=int, default=1,
                        help="number of samplers")
    args = parser.parse_args()
    print(args)
    main(args)
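Because `pool.start(args.num_sampler, args)` forks `--num-sampler` worker processes, a single `sampler.py` invocation can feed one trainer with several senders, and the trainer's `--num-sampler` must match so its `SamplerReceiver` expects the right number of senders. A plausible (untested here) cora run with two samplers:
```
DGLBACKEND=pytorch python3 gcn_ns_sc_train.py --dataset cora --self-loop --num-neighbors 2 --batch-size 1000000 --test-batch-size 1000000 --ip 127.0.0.1:50051 --num-sampler 2
DGLBACKEND=pytorch python3 sampler.py --model gcn_ns --dataset cora --self-loop --num-neighbors 2 --batch-size 1000000 --ip 127.0.0.1:50051 --num-sampler 2
```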