Commit c99f4237 authored by Da Zheng's avatar Da Zheng Committed by Minjie Wang
Browse files

[Sampler] extend sampler (#238)

parent 455ea485
...@@ -271,7 +271,6 @@ class ImmutableGraphIndex(object): ...@@ -271,7 +271,6 @@ class ImmutableGraphIndex(object):
def neighbor_sampling(self, seed_ids, expand_factor, num_hops, neighbor_type, def neighbor_sampling(self, seed_ids, expand_factor, num_hops, neighbor_type,
node_prob, max_subgraph_size): node_prob, max_subgraph_size):
assert node_prob is None
if neighbor_type == 'in': if neighbor_type == 'in':
g = self._in_csr g = self._in_csr
elif neighbor_type == 'out': elif neighbor_type == 'out':
...@@ -280,9 +279,14 @@ class ImmutableGraphIndex(object): ...@@ -280,9 +279,14 @@ class ImmutableGraphIndex(object):
raise NotImplementedError raise NotImplementedError
num_nodes = [] num_nodes = []
num_subgs = len(seed_ids) num_subgs = len(seed_ids)
res = mx.nd.contrib.dgl_csr_neighbor_uniform_sample(g, *seed_ids, num_hops=num_hops, if node_prob is None:
num_neighbor=expand_factor, res = mx.nd.contrib.dgl_csr_neighbor_uniform_sample(g, *seed_ids, num_hops=num_hops,
max_num_vertices=max_subgraph_size) num_neighbor=expand_factor,
max_num_vertices=max_subgraph_size)
else:
res = mx.nd.contrib.dgl_csr_neighbor_non_uniform_sample(g, node_prob, *seed_ids, num_hops=num_hops,
num_neighbor=expand_factor,
max_num_vertices=max_subgraph_size)
vertices, subgraphs = res[0:num_subgs], res[num_subgs:(2*num_subgs)] vertices, subgraphs = res[0:num_subgs], res[num_subgs:(2*num_subgs)]
num_nodes = [subg_v[-1].asnumpy()[0] for subg_v in vertices] num_nodes = [subg_v[-1].asnumpy()[0] for subg_v in vertices]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment