Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
d7062980
Unverified
Commit
d7062980
authored
Jun 11, 2019
by
Da Zheng
Committed by
GitHub
Jun 11, 2019
Browse files
fix a bug. (#646)
parent
e16e895d
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
38 additions
and
5 deletions
+38
-5
src/graph/sampler.cc
src/graph/sampler.cc
+13
-2
tests/compute/test_nodeflow.py
tests/compute/test_nodeflow.py
+25
-3
No files found.
src/graph/sampler.cc
View file @
d7062980
...
...
@@ -441,9 +441,20 @@ NodeFlow SampleSubgraph(const ImmutableGraph *graph,
&
tmp_sampled_edge_list
,
&
time_seed
);
}
if
(
add_self_loop
)
{
// If we need to add self loop and it doesn't exist in the sampled neighbor list.
if
(
add_self_loop
&&
std
::
find
(
tmp_sampled_src_list
.
begin
(),
tmp_sampled_src_list
.
end
(),
dst_id
)
==
tmp_sampled_src_list
.
end
())
{
tmp_sampled_src_list
.
push_back
(
dst_id
);
tmp_sampled_edge_list
.
push_back
(
-
1
);
const
dgl_id_t
*
src_list
=
col_list
+
*
(
indptr
+
dst_id
);
const
dgl_id_t
*
eid_list
=
val_list
+
*
(
indptr
+
dst_id
);
// TODO(zhengda) this operation has O(N) complexity. It can be pretty slow.
const
dgl_id_t
*
src
=
std
::
find
(
src_list
,
src_list
+
ver_len
,
dst_id
);
// If there doesn't exist a self loop in the graph.
// we have to add -1 as the edge id for the self-loop edge.
if
(
src
==
src_list
+
ver_len
)
tmp_sampled_edge_list
.
push_back
(
-
1
);
else
tmp_sampled_edge_list
.
push_back
(
eid_list
[
src
-
src_list
]);
}
CHECK_EQ
(
tmp_sampled_src_list
.
size
(),
tmp_sampled_edge_list
.
size
());
neigh_pos
.
emplace_back
(
dst_id
,
neighbor_list
.
size
(),
tmp_sampled_src_list
.
size
());
...
...
tests/compute/test_nodeflow.py
View file @
d7062980
...
...
@@ -10,7 +10,7 @@ import dgl.function as fn
from
functools
import
partial
import
itertools
def
generate_rand_graph
(
n
,
connect_more
=
False
,
complete
=
False
):
def
generate_rand_graph
(
n
,
connect_more
=
False
,
complete
=
False
,
add_self_loop
=
False
):
if
complete
:
cord
=
[(
i
,
j
)
for
i
,
j
in
itertools
.
product
(
range
(
n
),
range
(
n
))
if
i
!=
j
]
row
=
[
t
[
0
]
for
t
in
cord
]
...
...
@@ -23,7 +23,13 @@ def generate_rand_graph(n, connect_more=False, complete=False):
if
connect_more
:
arr
[
0
]
=
1
arr
[:,
0
]
=
1
g
=
dgl
.
DGLGraph
(
arr
,
readonly
=
True
)
if
add_self_loop
:
g
=
dgl
.
DGLGraph
(
arr
,
readonly
=
False
)
nodes
=
np
.
arange
(
g
.
number_of_nodes
())
g
.
add_edges
(
nodes
,
nodes
)
g
.
readonly
()
else
:
g
=
dgl
.
DGLGraph
(
arr
,
readonly
=
True
)
g
.
ndata
[
'h1'
]
=
F
.
randn
((
g
.
number_of_nodes
(),
10
))
g
.
edata
[
'h2'
]
=
F
.
randn
((
g
.
number_of_edges
(),
3
))
return
g
...
...
@@ -39,6 +45,18 @@ def test_self_loop():
deg
=
F
.
copy_to
(
F
.
ones
(
in_deg
.
shape
,
dtype
=
F
.
int64
),
F
.
cpu
())
*
n
assert_array_equal
(
F
.
asnumpy
(
in_deg
),
F
.
asnumpy
(
deg
))
g
=
generate_rand_graph
(
n
,
complete
=
True
,
add_self_loop
=
True
)
g
=
dgl
.
to_simple_graph
(
g
)
nf
=
create_mini_batch
(
g
,
num_hops
,
add_self_loop
=
True
)
for
i
in
range
(
nf
.
num_blocks
):
parent_eid
=
F
.
asnumpy
(
nf
.
block_parent_eid
(
i
))
parent_nid
=
F
.
asnumpy
(
nf
.
layer_parent_nid
(
i
+
1
))
# The loop eid in the parent graph must exist in the block parent eid.
parent_loop_eid
=
F
.
asnumpy
(
g
.
edge_ids
(
parent_nid
,
parent_nid
))
assert
len
(
parent_loop_eid
)
==
len
(
parent_nid
)
for
eid
in
parent_loop_eid
:
assert
eid
in
parent_eid
def
create_mini_batch
(
g
,
num_hops
,
add_self_loop
=
False
):
seed_ids
=
np
.
array
([
1
,
2
,
0
,
3
])
sampler
=
NeighborSampler
(
g
,
batch_size
=
4
,
expand_factor
=
g
.
number_of_nodes
(),
...
...
@@ -59,7 +77,6 @@ def check_basic(g, nf):
assert
nf
.
number_of_edges
()
==
num_edges
assert
len
(
nf
)
==
num_nodes
assert
nf
.
is_readonly
assert
not
nf
.
is_multigraph
assert
np
.
all
(
F
.
asnumpy
(
nf
.
has_nodes
(
list
(
range
(
num_nodes
)))))
for
i
in
range
(
num_nodes
):
...
...
@@ -131,6 +148,11 @@ def test_basic():
assert
nf
.
num_layers
==
num_layers
+
1
check_basic
(
g
,
nf
)
g
=
generate_rand_graph
(
100
,
add_self_loop
=
True
)
nf
=
create_mini_batch
(
g
,
num_layers
,
add_self_loop
=
True
)
assert
nf
.
num_layers
==
num_layers
+
1
check_basic
(
g
,
nf
)
def
check_apply_nodes
(
create_node_flow
,
use_negative_block_id
):
num_layers
=
2
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment