Unverified commit a5a35d1f authored by Da Zheng, committed by GitHub

[Sampler] improve random shuffle performance in sampler (#228)

* fix.

* make it generic.

* add the API.

* fix.

* remove mxnet.
parent 78269ced
@@ -158,13 +158,12 @@ def copy_to_gpu(subg, ctx):
         subg.ndata[key] = frame[key].as_in_context(ctx)
 
 class CachedSubgraph(object):
-    def __init__(self, subg, seeds, subg_seeds):
+    def __init__(self, subg, seeds):
         # We can't cache the input subgraph because it contains node frames
         # and data frames.
         self.subg = dgl.DGLSubGraph(subg._parent, subg._parent_nid, subg._parent_eid,
                                     subg._graph)
         self.seeds = seeds
-        self.subg_seeds = subg_seeds
 
 class CachedSubgraphLoader(object):
     def __init__(self, loader, shuffle):
@@ -184,14 +183,13 @@ class CachedSubgraphLoader(object):
     def __next__(self):
         if len(self._subgraphs) > 0:
             s = self._subgraphs.pop(0)
-            subg, seeds, subg_seeds = s.subg, s.seeds, s.subg_seeds
+            subg, seeds = s.subg, s.seeds
         elif self._gen_subgraph:
             subg, seeds = self._loader.__next__()
-            subg_seeds = subg.map_to_subgraph_nid(seeds)
         else:
             raise StopIteration
-        self._cached.append(CachedSubgraph(subg, seeds, subg_seeds))
-        return subg, seeds, subg_seeds
+        self._cached.append(CachedSubgraph(subg, seeds))
+        return subg, seeds
 
 def main(args, data):
     if isinstance(data.features, mx.nd.NDArray):
@@ -266,6 +264,7 @@ def main(args, data):
     sampler = dgl.contrib.sampling.NeighborSampler(g, args.batch_size, neigh_expand,
             neighbor_type='in', num_workers=args.num_parallel_subgraphs, seed_nodes=train_vs,
             shuffle=True)
+    if args.cache_subgraph:
         sampler = CachedSubgraphLoader(sampler, shuffle=True)
     for epoch in range(args.n_epochs):
         t0 = time.time()
@@ -273,8 +272,8 @@ def main(args, data):
         i = 0
         num_batches = len(train_vs) / args.batch_size
         start1 = time.time()
-        sampler.restart()
-        for subg, seeds, subg_seeds in sampler:
+        for subg, seeds in sampler:
+            subg_seeds = subg.map_to_subgraph_nid(seeds)
             subg.copy_from_parent()
             losses = []
@@ -308,6 +307,14 @@ def main(args, data):
             if i > num_batches / 3:
                 break
+        if args.cache_subgraph:
+            sampler.restart()
+        else:
+            sampler = dgl.contrib.sampling.NeighborSampler(g, args.batch_size, neigh_expand,
+                                                           neighbor_type='in',
+                                                           num_workers=args.num_parallel_subgraphs,
+                                                           seed_nodes=train_vs, shuffle=True)
+
         # prediction.
         logits = model_infer(g, eval_vs)
         eval_loss = mx.nd.softmax_cross_entropy(logits, eval_labels)
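The restart() call above relies on the caching loader replaying the subgraphs it has accumulated. The method body is outside the context shown in this diff, so the following is only a minimal sketch of the replay logic it presumably implements (the attribute names, self._shuffle in particular, are assumptions):

    import random

    class CachedSubgraphLoader(object):
        ...
        def restart(self):
            # Move everything sampled so far back into the pending queue
            # and re-randomize the replay order between epochs.
            self._subgraphs = self._cached
            self._cached = []
            if self._shuffle:
                random.shuffle(self._subgraphs)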
@@ -381,11 +388,13 @@ if __name__ == '__main__':
     parser.add_argument("--use-spmv", action="store_true",
             help="use SpMV for faster speed.")
     parser.add_argument("--dgl", action="store_true")
+    parser.add_argument("--cache-subgraph", default=False, action="store_false")
     parser.add_argument("--num-parallel-subgraphs", type=int, default=1,
             help="the number of subgraphs to construct in parallel.")
     parser.add_argument("--neigh-expand", type=int, default=16,
             help="the number of neighbors to sample.")
     args = parser.parse_args()
+    print("cache: " + str(args.cache_subgraph))
 
     # load and preprocess dataset
     if args.graph_file != '':
......
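One thing worth flagging in the hunk above: with default=False and action="store_false", passing --cache-subgraph stores False, and omitting it leaves the default False, so the cached-subgraph path can never be switched on from the command line. Presumably store_true was intended:

    # Presumed intent, not what the commit actually adds:
    parser.add_argument("--cache-subgraph", default=False, action="store_true")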
@@ -681,6 +681,23 @@ def arange(start, stop):
     """
     pass
 
+def rand_shuffle(arr):
+    """Randomly shuffle the data in the first dimension of the array.
+
+    The shuffled data is stored in a new array.
+
+    Parameters
+    ----------
+    arr : Tensor
+        The data tensor
+
+    Returns
+    -------
+    Tensor
+        The result tensor
+    """
+    pass
+
 def zerocopy_to_dlpack(input):
     """Create a dlpack tensor that shares the input memory.
......
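For reference, this is how backend-agnostic code would call the new API, following the same "backend as F" convention this commit introduces in the sampler below; a sketch, not part of the diff:

    from dgl import backend as F

    seeds = F.arange(0, 10)        # int64 index tensor in the active backend
    seeds = F.rand_shuffle(seeds)  # new tensor, shuffled along dim 0; input untouched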
@@ -179,6 +179,9 @@ def sort_1d(input):
 def arange(start, stop):
     return nd.arange(start, stop, dtype=np.int64)
 
+def rand_shuffle(arr):
+    return mx.nd.random.shuffle(arr)
+
 def zerocopy_to_dlpack(arr):
     return arr.to_dlpack_for_read()
......
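mx.nd.random.shuffle shuffles along the first axis and, like other MXNet operators, returns a new NDArray rather than mutating its input, so it meets the new-array contract of the docstring without an explicit copy. A quick check:

    import mxnet as mx

    arr = mx.nd.arange(5)
    out = mx.nd.random.shuffle(arr)
    print(arr.asnumpy())  # unchanged: [0. 1. 2. 3. 4.]
    print(out.asnumpy())  # a permutation of the same values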
@@ -128,6 +128,11 @@ def sort_1d(input):
 def arange(start, stop):
     return np.arange(start, stop, dtype=np.int64)
 
+def rand_shuffle(arr):
+    copy = np.copy(arr)
+    np.random.shuffle(copy)
+    return copy
+
 # zerocopy_to_dlpack not enabled
 # zerocopy_from_dlpack not enabled
......
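The explicit copy in the NumPy backend is what upholds the docstring's promise of a new array: np.random.shuffle works strictly in place and returns None, as this snippet demonstrates:

    import numpy as np

    arr = np.arange(5)
    result = np.random.shuffle(arr)  # shuffles arr itself
    print(result)                    # None, hence the np.copy above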
@@ -136,6 +136,10 @@ def sort_1d(input):
 def arange(start, stop):
     return th.arange(start, stop, dtype=th.int64)
 
+def rand_shuffle(arr):
+    idx = th.randperm(len(arr))
+    return arr[idx]
+
 def zerocopy_to_dlpack(input):
     return dlpack.to_dlpack(input.contiguous())
......
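The PyTorch variant gets the same out-of-place behavior for free: th.randperm draws a random permutation of indices, and advanced indexing with an index tensor materializes a new tensor, leaving the input untouched:

    import torch as th

    arr = th.arange(5, dtype=th.int64)
    idx = th.randperm(len(arr))  # random permutation of 0..len(arr)-1
    out = arr[idx]               # advanced indexing copies into a new tensor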
@@ -4,6 +4,7 @@ import numpy as np
 
 from ... import utils
 from ...subgraph import DGLSubGraph
+from ... import backend as F
 
 __all__ = ['NeighborSampler']
@@ -22,11 +23,11 @@ class NSSubgraphLoader(object):
         assert self._node_prob.shape[0] == g.number_of_nodes(), \
                "We need to know the sampling probability of every node"
         if seed_nodes is None:
-            self._seed_nodes = np.arange(0, g.number_of_nodes(), dtype=np.int64)
+            self._seed_nodes = F.arange(0, g.number_of_nodes())
         else:
            self._seed_nodes = seed_nodes
         if shuffle:
-            np.random.shuffle(self._seed_nodes)
+            self._seed_nodes = F.rand_shuffle(self._seed_nodes)
         self._num_workers = num_workers
         if max_subgraph_size is None:
             # This size is set temporarily.
......
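This hunk is where the performance improvement in the commit title comes from. Previously, seed nodes supplied as a framework tensor (e.g. an MXNet NDArray) were shuffled with np.random.shuffle, which either swaps elements one by one through the tensor's indexing interface or forces a round trip through NumPy; F.rand_shuffle dispatches to the backend's native shuffle instead. A rough way to compare the two paths, assuming the PyTorch backend (illustrative, not from the commit):

    import time
    import numpy as np
    import torch as th

    seeds = th.arange(0, 5_000_000, dtype=th.int64)

    t0 = time.time()
    shuffled = seeds[th.randperm(len(seeds))]  # backend-native path
    t1 = time.time()

    np_seeds = seeds.numpy().copy()            # NumPy path: convert first
    np.random.shuffle(np_seeds)
    back = th.from_numpy(np_seeds)
    t2 = time.time()

    print("torch: %.3fs, numpy round trip: %.3fs" % (t1 - t0, t2 - t1))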