Unverified Commit 5da57663 authored by Da Zheng's avatar Da Zheng Committed by GitHub
Browse files

[BUGFIX] fix sampler. (#616)

* fix sampler.

* update doc.

* fix.
parent 70ee8664
......@@ -255,7 +255,6 @@ class NeighborSampler(NodeFlowSampler):
* "in": the neighbors on the in-edges.
* "out": the neighbors on the out-edges.
* "both": the neighbors on both types of edges.
Default: "in"
node_prob : Tensor, optional
......@@ -333,17 +332,35 @@ class LayerSampler(NodeFlowSampler):
Parameters
----------
g: the DGLGraph where we sample NodeFlows.
batch_size: The number of NodeFlows in a batch.
layer_size: A list of layer sizes.
node_prob: the probability that a neighbor node is sampled.
Not implemented.
seed_nodes: a list of nodes where we sample NodeFlows from.
If it's None, the seed vertices are all vertices in the graph.
shuffle: indicates the sampled NodeFlows are shuffled.
num_workers: the number of worker threads that sample NodeFlows in parallel.
prefetch : bool, default False
Whether to prefetch the samples in the next batch.
g : DGLGraph
The DGLGraph where we sample NodeFlows.
batch_size : int
The batch size (i.e., the number of nodes in the last layer).
layer_size : list of int
A list of layer sizes.
neighbor_type: str, optional
Indicates the neighbors on different types of edges.
* "in": the neighbors on the in-edges.
* "out": the neighbors on the out-edges.
Default: "in"
node_prob : Tensor, optional
A 1D tensor for the probability that a neighbor node is sampled.
None means uniform sampling. Otherwise, the number of elements
should be equal to the number of vertices in the graph.
This parameter is not implemented yet.
Default: None
seed_nodes : Tensor, optional
A 1D tensor of the nodes from which we sample NodeFlows.
If None, the seed vertices are all the vertices in the graph.
Default: None
shuffle : bool, optional
Indicates whether the sampled NodeFlows are shuffled. Default: False
num_workers : int, optional
The number of worker threads that sample NodeFlows in parallel. Default: 1
prefetch : bool, optional
If true, prefetch the samples in the next batch. Default: False
'''
immutable_only = True
......
......@@ -700,6 +700,18 @@ NodeFlow SamplerOp::LayerUniformSample(const ImmutableGraph *graph,
return nf;
}
// Requests the CSR that matches `neigh_type` from the graph and verifies that
// one was returned.
//
// g          : the immutable graph whose CSR is requested.
// neigh_type : "in" requests the in-edge CSR (GetInCSR), "out" requests the
//              out-edge CSR (GetOutCSR). Any other value is reported through
//              LOG(FATAL).
//
// The string is taken by const reference so that calling this helper does not
// copy it.
void BuildCsr(const ImmutableGraph &g, const std::string &neigh_type) {
  if (neigh_type == "in") {
    const auto csr = g.GetInCSR();
    // Checked unconditionally: assert() is compiled out when NDEBUG is set,
    // which would let a missing CSR slip through in release builds.
    if (!csr) {
      LOG(FATAL) << "Failed to get the in-edge CSR of the graph";
    }
  } else if (neigh_type == "out") {
    const auto csr = g.GetOutCSR();
    if (!csr) {
      LOG(FATAL) << "Failed to get the out-edge CSR of the graph";
    }
  } else {
    LOG(FATAL) << "We don't support sample from neighbor type " << neigh_type;
  }
}
DGL_REGISTER_GLOBAL("sampling._CAPI_UniformSampling")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
// arguments
......@@ -721,6 +733,8 @@ DGL_REGISTER_GLOBAL("sampling._CAPI_UniformSampling")
const int64_t num_seeds = seed_nodes->shape[0];
const int64_t num_workers = std::min(max_num_workers,
(num_seeds + batch_size - 1) / batch_size - batch_start_id);
// We need to make sure we have the right CSR before we enter parallel sampling.
BuildCsr(*gptr, neigh_type);
// generate node flows
std::vector<NodeFlow*> nflows(num_workers);
#pragma omp parallel for
......@@ -758,6 +772,8 @@ DGL_REGISTER_GLOBAL("sampling._CAPI_LayerSampling")
const int64_t num_seeds = seed_nodes->shape[0];
const int64_t num_workers = std::min(max_num_workers,
(num_seeds + batch_size - 1) / batch_size - batch_start_id);
// We need to make sure we have the right CSR before we enter parallel sampling.
BuildCsr(*gptr, neigh_type);
// generate node flows
std::vector<NodeFlow*> nflows(num_workers);
#pragma omp parallel for
......
......@@ -13,14 +13,14 @@ def generate_rand_graph(n):
def test_create_full():
    """Check the node and edge counts of the NodeFlow built by create_full_nodeflow(g, 5)."""
    g = generate_rand_graph(100)
    full_nf = dgl.contrib.sampling.sampler.create_full_nodeflow(g, 5)
    # Expected sizes are derived from the parent graph instead of a hard-coded
    # node count, so the check does not depend on the size passed above.
    assert full_nf.number_of_nodes() == g.number_of_nodes() * 6
    assert full_nf.number_of_edges() == g.number_of_edges() * 5
def test_1neighbor_sampler_all():
g = generate_rand_graph(100)
# In this case, NeighborSampling simply gets the neighborhood of a single vertex.
for i, subg in enumerate(dgl.contrib.sampling.NeighborSampler(
g, 1, 100, neighbor_type='in', num_workers=4)):
g, 1, g.number_of_nodes(), neighbor_type='in', num_workers=4)):
seed_ids = subg.layer_parent_nid(-1)
assert len(seed_ids) == 1
src, dst, eid = g.in_edges(seed_ids, form='all')
......@@ -80,8 +80,8 @@ def test_prefetch_neighbor_sampler():
def test_10neighbor_sampler_all():
g = generate_rand_graph(100)
# In this case, NeighborSampling simply gets the neighborhood of a single vertex.
for subg in dgl.contrib.sampling.NeighborSampler(g, 10, 100, neighbor_type='in',
num_workers=4):
for subg in dgl.contrib.sampling.NeighborSampler(g, 10, g.number_of_nodes(),
neighbor_type='in', num_workers=4):
seed_ids = subg.layer_parent_nid(-1)
assert F.array_equal(seed_ids, subg.map_to_parent_nid(subg.layer_nid(-1)))
......@@ -151,11 +151,14 @@ def _test_layer_sampler(prefetch=False):
sub_m = sub_g.number_of_edges()
assert sum(F.shape(sub_g.block_eid(i))[0] for i in range(n_blocks)) == sub_m
def test_layer_sampler():
    """Run the layer-sampler checks with prefetching disabled, then enabled."""
    for use_prefetch in (False, True):
        _test_layer_sampler(prefetch=use_prefetch)
# Run the sampler tests one after another when this file is executed as a script.
if __name__ == '__main__':
    test_create_full()
    test_1neighbor_sampler_all()
    test_10neighbor_sampler_all()
    test_1neighbor_sampler()
    test_10neighbor_sampler()
    # test_layer_sampler() takes no arguments; it runs the layer-sampler checks
    # both without and with prefetching.
    test_layer_sampler()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment