Unverified Commit 4b4186f8 authored by Quan (Andy) Gan's avatar Quan (Andy) Gan Committed by GitHub
Browse files

[Sampler] Change argument type of fanout from list to dict (#1403)

parent 97b08fbb
...@@ -30,8 +30,8 @@ def sample_neighbors(g, nodes, fanout, edge_dir='in', prob=None, replace=False): ...@@ -30,8 +30,8 @@ def sample_neighbors(g, nodes, fanout, edge_dir='in', prob=None, replace=False):
Node ids to sample neighbors from. The allowed types Node ids to sample neighbors from. The allowed types
are dictionary of node types to node id tensors, or simply node id tensor if are dictionary of node types to node id tensors, or simply node id tensor if
the given graph g has only one type of nodes. the given graph g has only one type of nodes.
fanout : int or list[int] fanout : int or dict[etype, int]
The number of sampled neighbors for each node on each edge type. Provide a list The number of sampled neighbors for each node on each edge type. Provide a dict
to specify different fanout values for each edge type. to specify different fanout values for each edge type.
edge_dir : str, optional edge_dir : str, optional
Edge direction ('in' or 'out'). If is 'in', sample from in edges. Otherwise, Edge direction ('in' or 'out'). If is 'in', sample from in edges. Otherwise,
...@@ -60,11 +60,15 @@ def sample_neighbors(g, nodes, fanout, edge_dir='in', prob=None, replace=False): ...@@ -60,11 +60,15 @@ def sample_neighbors(g, nodes, fanout, edge_dir='in', prob=None, replace=False):
else: else:
nodes_all_types.append(nd.array([], ctx=nd.cpu())) nodes_all_types.append(nd.array([], ctx=nd.cpu()))
if not isinstance(fanout, list): if not isinstance(fanout, dict):
fanout = [int(fanout)] * len(g.etypes) fanout_array = [int(fanout)] * len(g.etypes)
if len(fanout) != len(g.etypes): else:
raise DGLError('Fan-out must be specified for each edge type ' if len(fanout) != len(g.etypes):
'if a list is provided.') raise DGLError('Fan-out must be specified for each edge type '
'if a dict is provided.')
fanout_array = [None] * len(g.etypes)
for etype, value in fanout.items():
fanout_array[g.get_etype_id(etype)] = value
if prob is None: if prob is None:
prob_arrays = [nd.array([], ctx=nd.cpu())] * len(g.etypes) prob_arrays = [nd.array([], ctx=nd.cpu())] * len(g.etypes)
...@@ -76,7 +80,7 @@ def sample_neighbors(g, nodes, fanout, edge_dir='in', prob=None, replace=False): ...@@ -76,7 +80,7 @@ def sample_neighbors(g, nodes, fanout, edge_dir='in', prob=None, replace=False):
else: else:
prob_arrays.append(nd.array([], ctx=nd.cpu())) prob_arrays.append(nd.array([], ctx=nd.cpu()))
subgidx = _CAPI_DGLSampleNeighbors(g._graph, nodes_all_types, fanout, subgidx = _CAPI_DGLSampleNeighbors(g._graph, nodes_all_types, fanout_array,
edge_dir, prob_arrays, replace) edge_dir, prob_arrays, replace)
induced_edges = subgidx.induced_edges induced_edges = subgidx.induced_edges
ret = DGLHeteroGraph(subgidx.graph, g.ntypes, g.etypes) ret = DGLHeteroGraph(subgidx.graph, g.ntypes, g.etypes)
......
...@@ -271,7 +271,11 @@ def _test_sample_neighbors(hypersparse): ...@@ -271,7 +271,11 @@ def _test_sample_neighbors(hypersparse):
# test different fanouts for different relations # test different fanouts for different relations
for i in range(10): for i in range(10):
subg = dgl.sampling.sample_neighbors(hg, {'user' : [0,1], 'game' : 0}, [1, 2, 0, 2], replace=True) subg = dgl.sampling.sample_neighbors(
hg,
{'user' : [0,1], 'game' : 0},
{'follow': 1, 'play': 2, 'liked-by': 0, 'flips': 2},
replace=True)
assert len(subg.ntypes) == 3 assert len(subg.ntypes) == 3
assert len(subg.etypes) == 4 assert len(subg.etypes) == 4
assert subg['follow'].number_of_edges() == 2 assert subg['follow'].number_of_edges() == 2
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment