"vscode:/vscode.git/clone" did not exist on "aba0506c8150419c5de890ba887db50e15537cc9"
Unverified Commit 4b4186f8 authored by Quan (Andy) Gan's avatar Quan (Andy) Gan Committed by GitHub
Browse files

[Sampler] Change argument type of fanout from list to dict (#1403)

parent 97b08fbb
......@@ -30,8 +30,8 @@ def sample_neighbors(g, nodes, fanout, edge_dir='in', prob=None, replace=False):
Node ids to sample neighbors from. The allowed types
are dictionary of node types to node id tensors, or simply node id tensor if
the given graph g has only one type of nodes.
fanout : int or list[int]
The number of sampled neighbors for each node on each edge type. Provide a list
fanout : int or dict[etype, int]
The number of sampled neighbors for each node on each edge type. Provide a dict
to specify different fanout values for each edge type.
edge_dir : str, optional
Edge direction ('in' or 'out'). If is 'in', sample from in edges. Otherwise,
......@@ -60,11 +60,15 @@ def sample_neighbors(g, nodes, fanout, edge_dir='in', prob=None, replace=False):
else:
nodes_all_types.append(nd.array([], ctx=nd.cpu()))
if not isinstance(fanout, list):
fanout = [int(fanout)] * len(g.etypes)
if not isinstance(fanout, dict):
fanout_array = [int(fanout)] * len(g.etypes)
else:
if len(fanout) != len(g.etypes):
raise DGLError('Fan-out must be specified for each edge type '
'if a list is provided.')
'if a dict is provided.')
fanout_array = [None] * len(g.etypes)
for etype, value in fanout.items():
fanout_array[g.get_etype_id(etype)] = value
if prob is None:
prob_arrays = [nd.array([], ctx=nd.cpu())] * len(g.etypes)
......@@ -76,7 +80,7 @@ def sample_neighbors(g, nodes, fanout, edge_dir='in', prob=None, replace=False):
else:
prob_arrays.append(nd.array([], ctx=nd.cpu()))
subgidx = _CAPI_DGLSampleNeighbors(g._graph, nodes_all_types, fanout,
subgidx = _CAPI_DGLSampleNeighbors(g._graph, nodes_all_types, fanout_array,
edge_dir, prob_arrays, replace)
induced_edges = subgidx.induced_edges
ret = DGLHeteroGraph(subgidx.graph, g.ntypes, g.etypes)
......
......@@ -271,7 +271,11 @@ def _test_sample_neighbors(hypersparse):
# test different fanouts for different relations
for i in range(10):
subg = dgl.sampling.sample_neighbors(hg, {'user' : [0,1], 'game' : 0}, [1, 2, 0, 2], replace=True)
subg = dgl.sampling.sample_neighbors(
hg,
{'user' : [0,1], 'game' : 0},
{'follow': 1, 'play': 2, 'liked-by': 0, 'flips': 2},
replace=True)
assert len(subg.ntypes) == 3
assert len(subg.etypes) == 4
assert subg['follow'].number_of_edges() == 2
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment