"examples/mxnet/gin/main.py" did not exist on "a3febc061b867a30ffa80cbc7e394cb0208e7346"
Unverified Commit 20e1bb45 authored by Quan (Andy) Gan's avatar Quan (Andy) Gan Committed by GitHub
Browse files

rewrite to use dataloader (#1333)


Co-authored-by: Minjie Wang <wmjlyjemaine@gmail.com>
parent d98c71ef
@@ -5,6 +5,7 @@
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.optim as optim
 import torch.multiprocessing as mp
+from torch.utils.data import DataLoader
 import dgl.function as fn
 import dgl.nn.pytorch as dglnn
 import time
@@ -23,10 +24,11 @@ class NeighborSampler(object):
         self.fanouts = fanouts

     def sample_blocks(self, seeds):
+        seeds = th.LongTensor(np.asarray(seeds))
         blocks = []
         for fanout in self.fanouts:
             # For each seed node, sample ``fanout`` neighbors.
-            frontier = dgl.sampling.sample_neighbors(g, seeds, fanout)
+            frontier = dgl.sampling.sample_neighbors(g, seeds, fanout, replace=True)
             # Then we compact the frontier into a bipartite graph for message passing.
             block = dgl.to_block(frontier, seeds)
             # Obtain the seed nodes for the next layer.
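The new first line of `sample_blocks` converts `seeds` to a tensor because, once the method is used as a DataLoader `collate_fn` (see below), it receives a plain Python list of node IDs rather than a tensor. A minimal sketch of the multi-layer block-sampling loop on a toy graph; `toy_g`, the seeds, and the fanouts here are illustrative stand-ins, not values from the example:

```python
import dgl
import torch as th

# Toy stand-ins; the real example samples from the full training graph.
toy_g = dgl.graph(([0, 1, 2, 3], [1, 2, 3, 0]))  # 4-node directed cycle
seeds = th.LongTensor([0, 1])                    # one mini-batch of seed nodes

blocks = []
for fanout in [2, 2]:  # one fanout per GNN layer
    # Sample `fanout` in-neighbors of every seed, with replacement.
    frontier = dgl.sampling.sample_neighbors(toy_g, seeds, fanout, replace=True)
    # Compact the frontier into a bipartite block whose output side is the seeds.
    block = dgl.to_block(frontier, seeds)
    # The block's input nodes become the seeds for the next (outer) hop.
    seeds = block.srcdata[dgl.NID]
    blocks.insert(0, block)  # keep the outermost hop first
```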
@@ -91,9 +93,9 @@ class SAGE(nn.Module):
                 end = start + batch_size
                 batch_nodes = nodes[start:end]
                 block = dgl.to_block(dgl.in_subgraph(g, batch_nodes), batch_nodes)
-                induced_nodes = block.srcdata[dgl.NID]
+                input_nodes = block.srcdata[dgl.NID]

-                h = x[induced_nodes].to(device)
+                h = x[input_nodes].to(device)
                 h_dst = h[:block.number_of_nodes(block.dsttype)]
                 h = layer(block, (h, h_dst))
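The `h_dst = h[:block.number_of_nodes(block.dsttype)]` slice relies on a `dgl.to_block` invariant: the destination (seed) nodes appear as a prefix of the block's source nodes, so the first `num_dst` rows of the source features are exactly the seed features. A quick sanity check one could run on any block; `check_dst_prefix` is our illustrative helper, not part of the example:

```python
import dgl
import torch as th

def check_dst_prefix(block):
    # In a block produced by dgl.to_block, the destination nodes come
    # first among the source nodes, so slicing source features with
    # [:num_dst] yields the features of the seed nodes themselves.
    num_dst = block.number_of_nodes(block.dsttype)
    return th.equal(block.srcdata[dgl.NID][:num_dst], block.dstdata[dgl.NID])
```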
...@@ -135,6 +137,18 @@ def thread_wrapped_func(func): ...@@ -135,6 +137,18 @@ def thread_wrapped_func(func):
raise exception.__class__(trace) raise exception.__class__(trace)
return decorated_function return decorated_function
def prepare_mp(g):
"""
Explicitly materialize the CSR, CSC and COO representation of the given graph
so that they could be shared via copy-on-write to sampler workers and GPU
trainers.
This is a workaround before full shared memory support on heterogeneous graphs.
"""
g.in_degree(0)
g.out_degree(0)
g.find_edges([0])
def compute_acc(pred, labels): def compute_acc(pred, labels):
""" """
Compute the accuracy of prediction given the labels. Compute the accuracy of prediction given the labels.
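Our reading of the three calls (the commit does not spell this out): `in_degree` forces the CSC format to be built, `out_degree` the CSR, and `find_edges` the COO, so every format exists before worker processes fork. A hedged sketch of where `prepare_mp` slots into the spawn pattern this file uses; `run`, `args`, `devices`, and `data` are as defined elsewhere in the example:

```python
prepare_mp(g)  # build all sparse formats once, pre-fork, for copy-on-write
procs = []
for proc_id in range(n_gpus):
    p = mp.Process(target=run, args=(proc_id, n_gpus, args, devices, data))
    p.start()
    procs.append(p)
for p in procs:
    p.join()
```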
...@@ -157,11 +171,11 @@ def evaluate(model, g, inputs, labels, val_mask, batch_size, device): ...@@ -157,11 +171,11 @@ def evaluate(model, g, inputs, labels, val_mask, batch_size, device):
model.train() model.train()
return compute_acc(pred[val_mask], labels[val_mask]) return compute_acc(pred[val_mask], labels[val_mask])
def load_subtensor(g, labels, seeds, induced_nodes, dev_id): def load_subtensor(g, labels, seeds, input_nodes, dev_id):
""" """
Copys features and labels of a set of nodes onto GPU. Copys features and labels of a set of nodes onto GPU.
""" """
batch_inputs = g.ndata['features'][induced_nodes].to(dev_id) batch_inputs = g.ndata['features'][input_nodes].to(dev_id)
batch_labels = labels[seeds].to(dev_id) batch_labels = labels[seeds].to(dev_id)
return batch_inputs, batch_labels return batch_inputs, batch_labels
...@@ -194,7 +208,16 @@ def run(proc_id, n_gpus, args, devices, data): ...@@ -194,7 +208,16 @@ def run(proc_id, n_gpus, args, devices, data):
train_nid = th.split(train_nid, len(train_nid) // n_gpus)[dev_id] train_nid = th.split(train_nid, len(train_nid) // n_gpus)[dev_id]
# Create sampler # Create sampler
sampler = NeighborSampler(g, [args.fan_out] * args.num_layers) sampler = NeighborSampler(g, [int(fanout) for fanout in args.fan_out.split(',')])
# Create PyTorch DataLoader for constructing blocks
dataloader = DataLoader(
dataset=train_nid.numpy(),
batch_size=args.batch_size,
collate_fn=sampler.sample_blocks,
shuffle=True,
drop_last=False,
num_workers=args.num_workers_per_gpu)
# Define model and optimizer # Define model and optimizer
model = SAGE(in_feats, args.num_hidden, n_classes, args.num_layers, F.relu, dropout) model = SAGE(in_feats, args.num_hidden, n_classes, args.num_layers, F.relu, dropout)
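Because `dataset` is a plain NumPy array of node IDs, each batch arrives at the `collate_fn` as a list of NumPy scalars, which is exactly why `sample_blocks` now starts with the `LongTensor` conversion. A self-contained toy showing the mechanics; nothing here is DGL-specific, and `fake_collate` is just an illustration:

```python
import numpy as np
from torch.utils.data import DataLoader

def fake_collate(seeds):
    # `seeds` is a plain list of numpy int64 scalars, one per node ID.
    return len(seeds)

loader = DataLoader(np.arange(10), batch_size=4, collate_fn=fake_collate,
                    shuffle=True, drop_last=False)
print(list(loader))  # e.g. [4, 4, 2]
```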
...@@ -210,18 +233,20 @@ def run(proc_id, n_gpus, args, devices, data): ...@@ -210,18 +233,20 @@ def run(proc_id, n_gpus, args, devices, data):
iter_tput = [] iter_tput = []
for epoch in range(args.num_epochs): for epoch in range(args.num_epochs):
tic = time.time() tic = time.time()
train_nid_batches = train_nid[th.randperm(len(train_nid))]
n_batches = (len(train_nid_batches) + args.batch_size - 1) // args.batch_size # Loop over the dataloader to sample the computation dependency graph as a list of
for step in range(n_batches): # blocks.
seeds = train_nid_batches[step * args.batch_size:(step+1) * args.batch_size] for step, blocks in enumerate(dataloader):
if proc_id == 0: if proc_id == 0:
tic_step = time.time() tic_step = time.time()
# Sample blocks for message propagation # The nodes for input lies at the LHS side of the first block.
blocks = sampler.sample_blocks(seeds) # The nodes for output lies at the RHS side of the last block.
induced_nodes = blocks[0].srcdata[dgl.NID] input_nodes = blocks[0].srcdata[dgl.NID]
seeds = blocks[-1].dstdata[dgl.NID]
# Load the input features as well as output labels # Load the input features as well as output labels
batch_inputs, batch_labels = load_subtensor(g, labels, seeds, induced_nodes, dev_id) batch_inputs, batch_labels = load_subtensor(g, labels, seeds, input_nodes, dev_id)
# Compute loss and prediction # Compute loss and prediction
batch_pred = model(blocks, batch_inputs) batch_pred = model(blocks, batch_inputs)
...@@ -241,8 +266,8 @@ def run(proc_id, n_gpus, args, devices, data): ...@@ -241,8 +266,8 @@ def run(proc_id, n_gpus, args, devices, data):
iter_tput.append(len(seeds) * n_gpus / (time.time() - tic_step)) iter_tput.append(len(seeds) * n_gpus / (time.time() - tic_step))
if step % args.log_every == 0 and proc_id == 0: if step % args.log_every == 0 and proc_id == 0:
acc = compute_acc(batch_pred, batch_labels) acc = compute_acc(batch_pred, batch_labels)
print('Epoch {:05d} | Step {:05d} | Loss {:.4f} | Train Acc {:.4f} | Speed (samples/sec) {:.4f}'.format( print('Epoch {:05d} | Step {:05d} | Loss {:.4f} | Train Acc {:.4f} | Speed (samples/sec) {:.4f} | GPU {:.1f} MiB'.format(
epoch, step, loss.item(), acc.item(), np.mean(iter_tput[3:]))) epoch, step, loss.item(), acc.item(), np.mean(iter_tput[3:]), th.cuda.max_memory_allocated() / 1000000))
if n_gpus > 1: if n_gpus > 1:
th.distributed.barrier() th.distributed.barrier()
@@ -253,7 +278,10 @@
         if epoch >= 5:
             avg += toc - tic
         if epoch % args.eval_every == 0 and epoch != 0:
-            eval_acc = evaluate(model, g, g.ndata['features'], labels, val_mask, args.batch_size, 0)
+            if n_gpus == 1:
+                eval_acc = evaluate(model, g, g.ndata['features'], labels, val_mask, args.batch_size, 0)
+            else:
+                eval_acc = evaluate(model.module, g, g.ndata['features'], labels, val_mask, args.batch_size, 0)
             print('Eval Acc {:.4f}'.format(eval_acc))

     if n_gpus > 1:
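The `model.module` indirection is needed because, with more than one GPU, `model` is wrapped in `DistributedDataParallel`, and the underlying `SAGE` module (with its `inference` method) lives on the `.module` attribute. An equivalent, more explicit formulation as a sketch, not the commit's code:

```python
import torch as th

# Unwrap DistributedDataParallel, if present, before evaluation.
raw_model = model.module if isinstance(
    model, th.nn.parallel.DistributedDataParallel) else model
eval_acc = evaluate(raw_model, g, g.ndata['features'], labels, val_mask,
                    args.batch_size, 0)
```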
@@ -267,11 +295,12 @@ if __name__ == '__main__':
     argparser.add_argument('--num-epochs', type=int, default=20)
     argparser.add_argument('--num-hidden', type=int, default=16)
     argparser.add_argument('--num-layers', type=int, default=2)
-    argparser.add_argument('--fan-out', type=int, default=10)
+    argparser.add_argument('--fan-out', type=str, default='10,25')
     argparser.add_argument('--batch-size', type=int, default=1000)
     argparser.add_argument('--log-every', type=int, default=20)
     argparser.add_argument('--eval-every', type=int, default=5)
     argparser.add_argument('--lr', type=float, default=0.003)
+    argparser.add_argument('--num-workers-per-gpu', type=int, default=0)
     args = argparser.parse_args()
     devices = list(map(int, args.gpu.split(',')))
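`--fan-out` changes from a single integer applied to every layer to a comma-separated string with one fanout per layer, matching the sampler construction above. For instance, with the new default:

```python
fanouts = [int(fanout) for fanout in '10,25'.split(',')]
print(fanouts)  # [10, 25] -- one sampling fanout per GNN layer
```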
...@@ -288,6 +317,7 @@ if __name__ == '__main__': ...@@ -288,6 +317,7 @@ if __name__ == '__main__':
# Construct graph # Construct graph
g = dgl.graph(data.graph.all_edges()) g = dgl.graph(data.graph.all_edges())
g.ndata['features'] = features g.ndata['features'] = features
prepare_mp(g)
# Pack data # Pack data
data = train_mask, val_mask, in_feats, labels, n_classes, g data = train_mask, val_mask, in_feats, labels, n_classes, g