"git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "a6bd96aa963075a0f3b7eee0a5d5170a0c954d30"
Unverified commit 901e0c24, authored by xiang song (charlie.song), committed by GitHub

[Example][Optimization] Performance optimization for graphsage unsupervised example (#1531)



* test

* profile

* opt

* Some fix

* upd

* upd

* Add multigpu training support for graphsage unsupervised

* Add share neg

* Fix

* Add profile

* turn on eval

* upd

* Fix

* performance opt
Co-authored-by: Ubuntu <ubuntu@ip-172-31-12-103.ec2.internal>
Co-authored-by: Ubuntu <ubuntu@ip-172-31-14-53.ec2.internal>
Co-authored-by: Ubuntu <ubuntu@ip-172-31-19-29.ec2.internal>
Co-authored-by: Ubuntu <ubuntu@ip-172-31-2-50.ec2.internal>
Co-authored-by: Ubuntu <ubuntu@ip-172-31-52-181.ec2.internal>
Co-authored-by: Ubuntu <ubuntu@ip-172-31-87-240.ec2.internal>
parent 7639b5e7
@@ -10,13 +10,13 @@ import dgl.function as fn
 import dgl.nn.pytorch as dglnn
 import time
 import argparse
-from _thread import start_new_thread
-from functools import wraps
 from dgl.data import RedditDataset
 from torch.nn.parallel import DistributedDataParallel
 import tqdm
 import traceback
+from utils import thread_wrapped_func

 #### Neighbor sampler
 class NeighborSampler(object):
@@ -110,39 +110,6 @@ class SAGE(nn.Module):
             x = y
         return y

-#### Miscellaneous functions
-# According to https://github.com/pytorch/pytorch/issues/17199, this decorator
-# is necessary to make fork() and openmp work together.
-#
-# TODO: confirm if this is necessary for MXNet and Tensorflow. If so, we need
-# to standardize worker process creation since our operators are implemented with
-# OpenMP.
-def thread_wrapped_func(func):
-    """
-    Wraps a process entry point to make it work with OpenMP.
-    """
-    @wraps(func)
-    def decorated_function(*args, **kwargs):
-        queue = mp.Queue()
-        def _queue_result():
-            exception, trace, res = None, None, None
-            try:
-                res = func(*args, **kwargs)
-            except Exception as e:
-                exception = e
-                trace = traceback.format_exc()
-            queue.put((res, exception, trace))
-        start_new_thread(_queue_result, ())
-        result, exception, trace = queue.get()
-        if exception is None:
-            return result
-        else:
-            assert isinstance(exception, Exception)
-            raise exception.__class__(trace)
-    return decorated_function
-
 def prepare_mp(g):
     """
     Explicitly materialize the CSR, CSC and COO representation of the given graph
...
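The deleted helper is not gone: it moves into a shared utils.py (added at the bottom of this diff) so both graphsage examples can import it. Per pytorch/pytorch#17199, a process forked from a parent that has already initialized OpenMP can deadlock, so the wrapper runs the real entry point on a fresh native thread inside the child process and ferries the result (or exception) back through a queue. A minimal usage sketch, assuming the utils.py from this diff is importable; `train_one` is a hypothetical stand-in for the real `run` entry point:

```python
# Minimal usage sketch, assuming utils.py from this diff is on the path.
# `train_one` is a hypothetical stand-in for the real `run` entry point.
import torch.multiprocessing as mp
from utils import thread_wrapped_func

def train_one(proc_id, n_gpus):
    # anything that touches OpenMP-backed ops is safe in here
    print('worker {} of {}'.format(proc_id, n_gpus))

if __name__ == '__main__':
    procs = [mp.Process(target=thread_wrapped_func(train_one), args=(i, 2))
             for i in range(2)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()
```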
@@ -13,11 +13,14 @@ import argparse
 from _thread import start_new_thread
 from functools import wraps
 from dgl.data import RedditDataset
+from torch.nn.parallel import DistributedDataParallel
 import tqdm
 import traceback
 import sklearn.linear_model as lm
 import sklearn.metrics as skm
+from utils import thread_wrapped_func

 #### Negative sampler
 class NegativeSampler(object):
@@ -30,18 +33,26 @@ class NegativeSampler(object):

 #### Neighbor sampler
 class NeighborSampler(object):
-    def __init__(self, g, fanouts, num_negs):
+    def __init__(self, g, fanouts, num_negs, neg_share=False):
         self.g = g
         self.fanouts = fanouts
         self.neg_sampler = NegativeSampler(g)
         self.num_negs = num_negs
+        self.neg_share = neg_share

     def sample_blocks(self, seed_edges):
         n_edges = len(seed_edges)
         seed_edges = th.LongTensor(np.asarray(seed_edges))
         heads, tails = self.g.find_edges(seed_edges)
-        neg_tails = self.neg_sampler(self.num_negs * n_edges)
-        neg_heads = heads.view(-1, 1).expand(n_edges, self.num_negs).flatten()
+        if self.neg_share and n_edges % self.num_negs == 0:
+            neg_tails = self.neg_sampler(n_edges)
+            neg_tails = neg_tails.view(-1, 1, self.num_negs).expand(n_edges//self.num_negs,
+                                                                    self.num_negs,
+                                                                    self.num_negs).flatten()
+            neg_heads = heads.view(-1, 1).expand(n_edges, self.num_negs).flatten()
+        else:
+            neg_tails = self.neg_sampler(self.num_negs * n_edges)
+            neg_heads = heads.view(-1, 1).expand(n_edges, self.num_negs).flatten()

         # Maintain the correspondence between heads, tails and negative tails as two
         # graphs.
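The new `neg_share` path is the heart of the sampler speedup: instead of drawing `num_negs * n_edges` negative tails, it draws only `n_edges` and lets each group of `num_negs` consecutive positive edges share the same `num_negs` negatives, so the expanded tails still line up one-to-one with the expanded heads. A toy trace of the reshaping (node IDs made up for illustration):

```python
# Toy trace of the neg_share reshaping with made-up node IDs.
import torch as th

n_edges, num_negs = 4, 2
neg = th.tensor([10, 11, 12, 13])  # one sampled negative per seed edge
shared = neg.view(-1, 1, num_negs) \
            .expand(n_edges // num_negs, num_negs, num_negs) \
            .flatten()
print(shared.tolist())  # [10, 11, 10, 11, 12, 13, 12, 13]
# Every pair of consecutive edges reuses the same two negatives, and the
# result has n_edges * num_negs entries, matching the expanded heads.
```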
@@ -60,7 +71,7 @@ class NeighborSampler(object):
         blocks = []
         for fanout in self.fanouts:
             # For each seed node, sample ``fanout`` neighbors.
-            frontier = dgl.sampling.sample_neighbors(g, seeds, fanout, replace=True)
+            frontier = dgl.sampling.sample_neighbors(self.g, seeds, fanout, replace=True)
             # Remove all edges between heads and tails, as well as heads and neg_tails.
             _, _, edge_ids = frontier.edge_ids(
                 th.cat([heads, tails, neg_heads, neg_tails]),
@@ -69,12 +80,24 @@ class NeighborSampler(object):
             frontier = dgl.remove_edges(frontier, edge_ids)
             # Then we compact the frontier into a bipartite graph for message passing.
             block = dgl.to_block(frontier, seeds)
+            # Pre-generate the CSR format so that it can be used in training directly
+            block.in_degree(0)
             # Obtain the seed nodes for next layer.
             seeds = block.srcdata[dgl.NID]

             blocks.insert(0, block)
+        # Pre-generate the CSR format so that it can be used in training directly
         return pos_graph, neg_graph, blocks

+def load_subtensor(g, input_nodes, device):
+    """
+    Copies features of a set of nodes onto GPU.
+    """
+    batch_inputs = g.ndata['features'][input_nodes].to(device)
+    return batch_inputs
+
 class SAGE(nn.Module):
     def __init__(self,
                  in_feats,
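The `block.in_degree(0)` call looks like a no-op, but the degree query has a useful side effect: it forces DGL to materialize the block's CSC/CSR structure inside the sampler worker, so the trainer receives blocks whose sparse formats are already built instead of paying that cost on the hot path (this also replaces the old whole-graph `prepare_mp` trick, removed below). A hedged, self-contained sketch of the same idea on a toy graph:

```python
# Hedged sketch: force sparse-format creation in the worker, not the trainer.
import dgl
import torch as th

g = dgl.graph((th.tensor([0, 1, 2]), th.tensor([1, 2, 3])))
frontier = dgl.sampling.sample_neighbors(g, th.tensor([3]), 2, replace=True)
block = dgl.to_block(frontier, th.tensor([3]))
block.in_degree(0)  # side effect: materializes the block's CSC/CSR here
```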
@@ -163,18 +186,6 @@ class CrossEntropyLoss(nn.Module):
         loss = F.binary_cross_entropy_with_logits(score, label.float())
         return loss

-def prepare_mp(g):
-    """
-    Explicitly materialize the CSR, CSC and COO representation of the given graph
-    so that they could be shared via copy-on-write to sampler workers and GPU
-    trainers.
-    This is a workaround before full shared memory support on heterogeneous graphs.
-    """
-    g.in_degree(0)
-    g.out_degree(0)
-    g.find_edges([0])
-
 def compute_acc(emb, labels, train_nids, val_nids, test_nids):
     """
     Compute the accuracy of prediction given the labels.
@@ -211,20 +222,27 @@ def evaluate(model, g, inputs, labels, train_nids, val_nids, test_nids, batch_si
     """
     model.eval()
     with th.no_grad():
-        pred = model.inference(g, inputs, batch_size, device)
+        # single gpu
+        if isinstance(model, SAGE):
+            pred = model.inference(g, inputs, batch_size, device)
+        # multi gpu
+        else:
+            pred = model.module.inference(g, inputs, batch_size, device)
     model.train()
     return compute_acc(pred, labels, train_nids, val_nids, test_nids)

-def load_subtensor(g, seeds, input_nodes, device):
-    """
-    Copys features and labels of a set of nodes onto GPU.
-    """
-    batch_inputs = g.ndata['features'][input_nodes].to(device)
-    return batch_inputs
-
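The `isinstance` dispatch in `evaluate` is needed because `DistributedDataParallel` only proxies `forward`; custom methods such as `inference` stay on the wrapped module and must be reached through `.module`. An equivalent, slightly more general sketch of the same unwrap pattern (not verbatim from this diff):

```python
# Equivalent unwrap pattern (sketch, not verbatim from this diff).
from torch.nn.parallel import DistributedDataParallel

def unwrap(model):
    # DDP exposes only forward(); custom methods live on .module
    return model.module if isinstance(model, DistributedDataParallel) else model

# usage: pred = unwrap(model).inference(g, inputs, batch_size, device)
```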
 #### Entry point
-def run(args, device, data):
+def run(proc_id, n_gpus, args, devices, data):
     # Unpack data
+    device = devices[proc_id]
+    if n_gpus > 1:
+        dist_init_method = 'tcp://{master_ip}:{master_port}'.format(
+            master_ip='127.0.0.1', master_port='12345')
+        world_size = n_gpus
+        th.distributed.init_process_group(backend="nccl",
+                                          init_method=dist_init_method,
+                                          world_size=world_size,
+                                          rank=proc_id)
     train_mask, val_mask, test_mask, in_feats, labels, n_classes, g = data
     train_nid = th.LongTensor(np.nonzero(train_mask)[0])
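Every spawned worker joins the same TCP rendezvous with its own `rank`; the backend, `init_method`, and `world_size` must agree across processes. A runnable standalone sketch of the initialization (gloo backend so it also works on a CPU-only box; the example itself uses nccl):

```python
# Standalone rendezvous sketch; uses gloo so it runs without GPUs
# (the example itself uses the nccl backend).
import torch.distributed as dist
import torch.multiprocessing as mp

def init_worker(rank, world_size):
    dist.init_process_group(backend='gloo',
                            init_method='tcp://127.0.0.1:12345',
                            world_size=world_size,
                            rank=rank)
    print('rank {} of {} ready'.format(rank, world_size))
    dist.destroy_process_group()

if __name__ == '__main__':
    world = 2
    mp.spawn(init_worker, args=(world,), nprocs=world)
```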
@@ -232,27 +250,41 @@ def run(args, device, data):
     test_nid = th.LongTensor(np.nonzero(test_mask)[0])

     # Create sampler
-    sampler = NeighborSampler(g, [int(fanout) for fanout in args.fan_out.split(',')], args.num_negs)
+    sampler = NeighborSampler(g, [int(fanout) for fanout in args.fan_out.split(',')], args.num_negs, args.neg_share)

     # Create PyTorch DataLoader for constructing blocks
+    train_seeds = np.arange(g.number_of_edges())
+    if n_gpus > 0:
+        num_per_gpu = (train_seeds.shape[0] + n_gpus - 1) // n_gpus
+        train_seeds = train_seeds[proc_id * num_per_gpu :
+                                  (proc_id + 1) * num_per_gpu
+                                  if (proc_id + 1) * num_per_gpu < train_seeds.shape[0]
+                                  else train_seeds.shape[0]]
     dataloader = DataLoader(
-        dataset=np.arange(g.number_of_edges()),
+        dataset=train_seeds,
         batch_size=args.batch_size,
         collate_fn=sampler.sample_blocks,
         shuffle=True,
         drop_last=False,
+        pin_memory=True,
         num_workers=args.num_workers)

     # Define model and optimizer
     model = SAGE(in_feats, args.num_hidden, args.num_hidden, args.num_layers, F.relu, args.dropout)
     model = model.to(device)
+    if n_gpus > 1:
+        model = DistributedDataParallel(model, device_ids=[device], output_device=device)
     loss_fcn = CrossEntropyLoss()
     loss_fcn = loss_fcn.to(device)
     optimizer = optim.Adam(model.parameters(), lr=args.lr)

     # Training loop
     avg = 0
-    iter_tput = []
+    iter_pos = []
+    iter_neg = []
+    iter_d = []
+    iter_t = []
     best_eval_acc = 0
     best_test_acc = 0
     for epoch in range(args.num_epochs):
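Data parallelism here shards the seed edges, not the graph: each process takes a contiguous chunk of edge IDs of size ceil(E / n_gpus), with the slice clamped so the last rank stops at the end of the array. A toy trace of the arithmetic:

```python
# Toy trace of the per-GPU edge sharding (ceil division plus a clamp).
import numpy as np

train_seeds = np.arange(10)  # pretend edge IDs
n_gpus = 3
num_per_gpu = (train_seeds.shape[0] + n_gpus - 1) // n_gpus  # ceil(10 / 3) = 4
for proc_id in range(n_gpus):
    end = min((proc_id + 1) * num_per_gpu, train_seeds.shape[0])
    print(proc_id, train_seeds[proc_id * num_per_gpu : end])
# 0 [0 1 2 3]
# 1 [4 5 6 7]
# 2 [8 9]
```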
@@ -260,16 +292,14 @@ def run(args, device, data):
         # Loop over the dataloader to sample the computation dependency graph as a list of
         # blocks.
-        for step, (pos_graph, neg_graph, blocks) in enumerate(dataloader):
-            tic_step = time.time()
-
+        tic_step = time.time()
+        for step, (pos_graph, neg_graph, blocks) in enumerate(dataloader):
             # The nodes for input lies at the LHS side of the first block.
             # The nodes for output lies at the RHS side of the last block.
             input_nodes = blocks[0].srcdata[dgl.NID]
-            seeds = blocks[-1].dstdata[dgl.NID]
+            batch_inputs = load_subtensor(g, input_nodes, device)
+            d_step = time.time()

-            # Load the input features as well as output labels
-            batch_inputs = load_subtensor(g, seeds, input_nodes, device)
             # Compute loss and prediction
             batch_pred = model(blocks, batch_inputs)
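Moving `tic_step` outside the loop changes what gets measured: the window now covers the DataLoader's sampling work, and `d_step` (taken right after the feature copy) splits every iteration into a load phase and a train phase, which is exactly what the new per-iteration counters record. A minimal sketch of the same instrumentation pattern with sleeps as stand-in work:

```python
# Minimal sketch of the load/train timing split, with sleeps as stand-ins.
import time

tic_step = time.time()
for step in range(3):               # stand-in for enumerate(dataloader)
    time.sleep(0.02)                # stand-in for sampling + feature copy
    d_step = time.time()
    time.sleep(0.01)                # stand-in for forward/backward/step
    t = time.time()
    print('step {}: load {:.3f}s | train {:.3f}s'.format(
        step, d_step - tic_step, t - d_step))
    tic_step = time.time()
```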
@@ -278,30 +308,74 @@ def run(args, device, data):
             loss.backward()
             optimizer.step()

-            iter_tput.append(len(seeds) / (time.time() - tic_step))
+            t = time.time()
+            pos_edges = pos_graph.number_of_edges()
+            neg_edges = neg_graph.number_of_edges()
+            iter_pos.append(pos_edges / (t - tic_step))
+            iter_neg.append(neg_edges / (t - tic_step))
+            iter_d.append(d_step - tic_step)
+            iter_t.append(t - d_step)
             if step % args.log_every == 0:
                 gpu_mem_alloc = th.cuda.max_memory_allocated() / 1000000 if th.cuda.is_available() else 0
-                print('Epoch {:05d} | Step {:05d} | Loss {:.4f} | Speed (samples/sec) {:.4f} | GPU {:.1f} MiB'.format(
-                    epoch, step, loss.item(), np.mean(iter_tput[3:]), gpu_mem_alloc))
+                print('[{}]Epoch {:05d} | Step {:05d} | Loss {:.4f} | Speed (samples/sec) {:.4f}|{:.4f} | Load {:.4f} | train {:.4f} | GPU {:.1f} MiB'.format(
+                    proc_id, epoch, step, loss.item(), np.mean(iter_pos[3:]), np.mean(iter_neg[3:]), np.mean(iter_d[3:]), np.mean(iter_t[3:]), gpu_mem_alloc))
+            tic_step = time.time()

-            if step % args.eval_every == 0:
+            if step % args.eval_every == 0 and proc_id == 0:
                 eval_acc, test_acc = evaluate(model, g, g.ndata['features'], labels, train_nid, val_nid, test_nid, args.batch_size, device)
                 print('Eval Acc {:.4f} Test Acc {:.4f}'.format(eval_acc, test_acc))
                 if eval_acc > best_eval_acc:
                     best_eval_acc = eval_acc
                     best_test_acc = test_acc
                 print('Best Eval Acc {:.4f} Test Acc {:.4f}'.format(best_eval_acc, best_test_acc))

+        if n_gpus > 1:
+            th.distributed.barrier()
     print('Avg epoch time: {}'.format(avg / (epoch - 4)))
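Evaluation runs the full-graph `inference`, which is expensive, so only rank 0 performs it; the end-of-epoch barrier then keeps the other ranks from drifting ahead while rank 0 is busy. A runnable toy of the pattern (gloo backend, trivial "evaluation"):

```python
# Toy of the rank-0-evaluates / everyone-barriers pattern (gloo backend).
import torch.distributed as dist
import torch.multiprocessing as mp

def worker(rank, world):
    dist.init_process_group('gloo', init_method='tcp://127.0.0.1:23456',
                            world_size=world, rank=rank)
    for epoch in range(2):
        # ... training would go here ...
        if rank == 0:
            print('epoch {}: rank 0 evaluates'.format(epoch))
        dist.barrier()  # all ranks wait until rank 0 is done
    dist.destroy_process_group()

if __name__ == '__main__':
    mp.spawn(worker, args=(2,), nprocs=2)
```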
+def main(args, devices):
+    # load reddit data
+    data = RedditDataset(self_loop=True)
+    train_mask = data.train_mask
+    val_mask = data.val_mask
+    test_mask = data.test_mask
+    features = th.Tensor(data.features)
+    in_feats = features.shape[1]
+    labels = th.LongTensor(data.labels)
+    n_classes = data.num_labels
+    # Construct graph
+    g = dgl.graph(data.graph.all_edges())
+    g.ndata['features'] = features
+    # Pack data
+    data = train_mask, val_mask, test_mask, in_feats, labels, n_classes, g
+
+    n_gpus = len(devices)
+    if devices[0] == -1:
+        run(0, 0, args, ['cpu'], data)
+    elif n_gpus == 1:
+        run(0, n_gpus, args, devices, data)
+    else:
+        procs = []
+        for proc_id in range(n_gpus):
+            p = mp.Process(target=thread_wrapped_func(run),
+                           args=(proc_id, n_gpus, args, devices, data))
+            p.start()
+            procs.append(p)
+        for p in procs:
+            p.join()
-    run(args, device, data)
 if __name__ == '__main__':
     argparser = argparse.ArgumentParser("multi-gpu training")
-    argparser.add_argument('--gpu', type=int, default=0,
-                           help="GPU device ID. Use -1 for CPU training")
+    argparser.add_argument('--gpu', type=str, default='0',
+                           help="GPU, can be a list of gpus for multi-gpu training, e.g., 0,1,2,3; -1 for CPU")
     argparser.add_argument('--num-epochs', type=int, default=20)
     argparser.add_argument('--num-hidden', type=int, default=16)
     argparser.add_argument('--num-layers', type=int, default=2)
     argparser.add_argument('--num-negs', type=int, default=1)
+    argparser.add_argument('--neg-share', default=False, action='store_true',
+                           help="sharing neg nodes for positive nodes")
     argparser.add_argument('--fan-out', type=str, default='10,25')
     argparser.add_argument('--batch-size', type=int, default=10000)
     argparser.add_argument('--log-every', type=int, default=20)
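`--gpu` changes from a single int to a comma-separated string so one flag covers CPU, single-GPU, and multi-GPU runs; `main` derives everything else from the list's length. A quick trace of the parsing:

```python
# Quick trace of the new --gpu parsing.
for arg in ('0', '0,1,2,3', '-1'):
    devices = list(map(int, arg.split(',')))
    n_gpus = len(devices)
    mode = 'cpu' if devices[0] == -1 else '{} gpu(s)'.format(n_gpus)
    print('{!r} -> devices={} ({})'.format(arg, devices, mode))
```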
@@ -312,25 +386,6 @@ if __name__ == '__main__':
                            help="Number of sampling processes. Use 0 for no extra process.")
     args = argparser.parse_args()

-    if args.gpu >= 0:
-        device = th.device('cuda:%d' % args.gpu)
-    else:
-        device = th.device('cpu')
+    devices = list(map(int, args.gpu.split(',')))

-    # load reddit data
-    data = RedditDataset(self_loop=True)
-    train_mask = data.train_mask
-    val_mask = data.val_mask
-    test_mask = data.test_mask
-    features = th.Tensor(data.features)
-    in_feats = features.shape[1]
-    labels = th.LongTensor(data.labels)
-    n_classes = data.num_labels
-    # Construct graph
-    g = dgl.graph(data.graph.all_edges())
-    g.ndata['features'] = features
-    prepare_mp(g)
-    # Pack data
-    data = train_mask, val_mask, test_mask, in_feats, labels, n_classes, g
-
-    run(args, device, data)
+    main(args, devices)
...
new file: utils.py (the module imported above as `from utils import thread_wrapped_func`)

+#### Miscellaneous functions
+# According to https://github.com/pytorch/pytorch/issues/17199, this decorator
+# is necessary to make fork() and openmp work together.
+#
+# TODO: confirm if this is necessary for MXNet and Tensorflow. If so, we need
+# to standardize worker process creation since our operators are implemented with
+# OpenMP.
+import torch.multiprocessing as mp
+from _thread import start_new_thread
+from functools import wraps
+import traceback
+
+def thread_wrapped_func(func):
+    """
+    Wraps a process entry point to make it work with OpenMP.
+    """
+    @wraps(func)
+    def decorated_function(*args, **kwargs):
+        queue = mp.Queue()
+        def _queue_result():
+            exception, trace, res = None, None, None
+            try:
+                res = func(*args, **kwargs)
+            except Exception as e:
+                exception = e
+                trace = traceback.format_exc()
+            queue.put((res, exception, trace))
+        start_new_thread(_queue_result, ())
+        result, exception, trace = queue.get()
+        if exception is None:
+            return result
+        else:
+            assert isinstance(exception, Exception)
+            raise exception.__class__(trace)
+    return decorated_function
\ No newline at end of file
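One detail of the new utils.py worth noting: when the wrapped function raises, the decorator re-raises an exception of the same class whose message is the worker thread's formatted traceback, so the failure point inside the child is still visible to whoever started the process. A toy of the error path, assuming the utils.py above is importable:

```python
# Toy of the error path, assuming the utils.py above is importable.
from utils import thread_wrapped_func

@thread_wrapped_func
def boom():
    raise ValueError('original error')

try:
    boom()
except ValueError as e:
    # the message carries the worker thread's full traceback
    print(str(e).splitlines()[-1])  # -> ValueError: original error
```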