Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
37d992ec
Unverified
Commit
37d992ec
authored
Feb 26, 2020
by
Minjie Wang
Committed by
GitHub
Feb 26, 2020
Browse files
[Bugfix] Fix no attribute num_edges bug in Nodeflow (#1289)
* fix nodeflow bug when using builtin on edge data * fix
parent
b05cb84a
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
16 additions
and
5 deletions
+16
-5
python/dgl/runtime/scheduler.py
python/dgl/runtime/scheduler.py
+9
-5
tests/compute/test_nodeflow.py
tests/compute/test_nodeflow.py
+7
-0
No files found.
python/dgl/runtime/scheduler.py
View file @
37d992ec
...
...
@@ -407,7 +407,7 @@ def schedule_nodeflow_apply_edges(graph, block_id,
name
=
'out_nf'
)
var_ef
=
var
.
FEAT_DICT
(
graph
.
_get_edge_frame
(
block_id
),
name
=
'ef'
)
var_out
=
_gen_send
(
graph
,
u
,
v
,
eid
,
apply_func
,
in_var_nf
,
out_var_nf
,
var_ef
)
var_ef
,
block_id
=
block_id
)
var_eid
=
var
.
IDX
(
eid
)
if
inplace
:
ir
.
WRITE_ROW_INPLACE_
(
var_ef
,
var_eid
,
var_out
)
...
...
@@ -967,13 +967,14 @@ def _gen_udf_send(var_src_nf, var_dst_nf, var_ef, u, v, eid, mfunc,
fdedge
=
ir
.
READ_ROW
(
var_ef
,
eid
)
def
_mfunc_wrapper
(
src_data
,
edge_data
,
dst_data
):
ebatch
=
EdgeBatch
((
u
.
data
,
v
.
data
,
eid
.
data
),
src_data
,
edge_data
,
dst_data
,
canonical_etype
=
canonical_etype
)
src_data
,
edge_data
,
dst_data
,
canonical_etype
=
canonical_etype
)
return
mfunc
(
ebatch
)
_mfunc_wrapper
=
var
.
FUNC
(
_mfunc_wrapper
)
msg
=
ir
.
EDGE_UDF
(
_mfunc_wrapper
,
fdsrc
,
fdedge
,
fddst
)
return
msg
def
_gen_send
(
graph
,
u
,
v
,
eid
,
mfunc
,
var_src_nf
,
var_dst_nf
,
var_ef
):
def
_gen_send
(
graph
,
u
,
v
,
eid
,
mfunc
,
var_src_nf
,
var_dst_nf
,
var_ef
,
block_id
=
None
):
"""Internal function to generate send schedule"""
mfunc
=
_standardize_func_usage
(
mfunc
,
'message'
)
mfunc_is_list
=
utils
.
is_iterable
(
mfunc
)
...
...
@@ -983,7 +984,10 @@ def _gen_send(graph, u, v, eid, mfunc, var_src_nf, var_dst_nf, var_ef):
var_eid
=
var
.
IDX
(
eid
)
if
mfunc_is_list
:
if
eid
.
is_slice
(
0
,
graph
.
num_edges
()):
if
not
hasattr
(
graph
,
'num_edges'
):
# XXX(minjie): a temporary hack to detect Nodeflow object
res
=
spmv
.
build_gidx_and_mapping_block
(
graph
,
block_id
)
elif
eid
.
is_slice
(
0
,
graph
.
num_edges
()):
# full graph case
res
=
spmv
.
build_gidx_and_mapping_graph
(
graph
)
else
:
...
...
@@ -991,7 +995,7 @@ def _gen_send(graph, u, v, eid, mfunc, var_src_nf, var_dst_nf, var_ef):
(
u
,
v
,
eid
),
graph
.
num_src
(),
graph
.
num_dst
())
adj
,
edge_map
,
_
=
res
# create a tmp message frame
tmp_mfr
=
FrameRef
(
frame_like
(
graph
.
edgeframe
.
_frame
,
len
(
eid
)))
tmp_mfr
=
FrameRef
(
frame_like
(
var_ef
.
data
.
_frame
,
len
(
eid
)))
var_out
=
var
.
FEAT_DICT
(
data
=
tmp_mfr
)
spmv
.
gen_v2e_spmv_schedule
(
graph
=
adj
,
mfunc
=
mfunc
,
...
...
tests/compute/test_nodeflow.py
View file @
37d992ec
...
...
@@ -219,6 +219,13 @@ def check_apply_edges(create_node_flow):
assert_array_equal
(
F
.
asnumpy
(
nf
.
blocks
[
i
].
data
[
'f2'
]),
F
.
asnumpy
(
expected_f_sum
))
# test built-in
nf
.
apply_block
(
i
,
fn
.
u_add_v
(
'f'
,
'f'
,
'f2'
))
eids
=
nf
.
block_parent_eid
(
i
)
srcs
,
dsts
=
g
.
find_edges
(
eids
)
expected_f_sum
=
g
.
nodes
[
srcs
].
data
[
"f"
]
+
g
.
nodes
[
dsts
].
data
[
"f"
]
assert_array_equal
(
F
.
asnumpy
(
nf
.
blocks
[
i
].
data
[
'f2'
]),
F
.
asnumpy
(
expected_f_sum
))
def
check_apply_edges1
(
create_node_flow
):
num_layers
=
2
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment