Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
4b4186f8
Unverified
Commit
4b4186f8
authored
Mar 30, 2020
by
Quan (Andy) Gan
Committed by
GitHub
Mar 30, 2020
Browse files
[Sampler] Change argument type of fanout from list to dict (#1403)
parent
97b08fbb
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
17 additions
and
9 deletions
+17
-9
python/dgl/sampling/neighbor.py
python/dgl/sampling/neighbor.py
+12
-8
tests/compute/test_sampling.py
tests/compute/test_sampling.py
+5
-1
No files found.
python/dgl/sampling/neighbor.py
View file @
4b4186f8
...
...
@@ -30,8 +30,8 @@ def sample_neighbors(g, nodes, fanout, edge_dir='in', prob=None, replace=False):
Node ids to sample neighbors from. The allowed types
are dictionary of node types to node id tensors, or simply node id tensor if
the given graph g has only one type of nodes.
fanout : int or
list[
int]
The number of sampled neighbors for each node on each edge type. Provide a
lis
t
fanout : int or
dict[etype,
int]
The number of sampled neighbors for each node on each edge type. Provide a
dic
t
to specify different fanout values for each edge type.
edge_dir : str, optional
Edge direction ('in' or 'out'). If is 'in', sample from in edges. Otherwise,
...
...
@@ -60,11 +60,15 @@ def sample_neighbors(g, nodes, fanout, edge_dir='in', prob=None, replace=False):
else
:
nodes_all_types
.
append
(
nd
.
array
([],
ctx
=
nd
.
cpu
()))
if
not
isinstance
(
fanout
,
list
):
fanout
=
[
int
(
fanout
)]
*
len
(
g
.
etypes
)
if
len
(
fanout
)
!=
len
(
g
.
etypes
):
raise
DGLError
(
'Fan-out must be specified for each edge type '
'if a list is provided.'
)
if
not
isinstance
(
fanout
,
dict
):
fanout_array
=
[
int
(
fanout
)]
*
len
(
g
.
etypes
)
else
:
if
len
(
fanout
)
!=
len
(
g
.
etypes
):
raise
DGLError
(
'Fan-out must be specified for each edge type '
'if a dict is provided.'
)
fanout_array
=
[
None
]
*
len
(
g
.
etypes
)
for
etype
,
value
in
fanout
.
items
():
fanout_array
[
g
.
get_etype_id
(
etype
)]
=
value
if
prob
is
None
:
prob_arrays
=
[
nd
.
array
([],
ctx
=
nd
.
cpu
())]
*
len
(
g
.
etypes
)
...
...
@@ -76,7 +80,7 @@ def sample_neighbors(g, nodes, fanout, edge_dir='in', prob=None, replace=False):
else
:
prob_arrays
.
append
(
nd
.
array
([],
ctx
=
nd
.
cpu
()))
subgidx
=
_CAPI_DGLSampleNeighbors
(
g
.
_graph
,
nodes_all_types
,
fanout
,
subgidx
=
_CAPI_DGLSampleNeighbors
(
g
.
_graph
,
nodes_all_types
,
fanout
_array
,
edge_dir
,
prob_arrays
,
replace
)
induced_edges
=
subgidx
.
induced_edges
ret
=
DGLHeteroGraph
(
subgidx
.
graph
,
g
.
ntypes
,
g
.
etypes
)
...
...
tests/compute/test_sampling.py
View file @
4b4186f8
...
...
@@ -271,7 +271,11 @@ def _test_sample_neighbors(hypersparse):
# test different fanouts for different relations
for
i
in
range
(
10
):
subg
=
dgl
.
sampling
.
sample_neighbors
(
hg
,
{
'user'
:
[
0
,
1
],
'game'
:
0
},
[
1
,
2
,
0
,
2
],
replace
=
True
)
subg
=
dgl
.
sampling
.
sample_neighbors
(
hg
,
{
'user'
:
[
0
,
1
],
'game'
:
0
},
{
'follow'
:
1
,
'play'
:
2
,
'liked-by'
:
0
,
'flips'
:
2
},
replace
=
True
)
assert
len
(
subg
.
ntypes
)
==
3
assert
len
(
subg
.
etypes
)
==
4
assert
subg
[
'follow'
].
number_of_edges
()
==
2
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment