"src/vscode:/vscode.git/clone" did not exist on "be52be721563d155b66dac67593b7d01c0cc78a8"
Unverified Commit 37d992ec authored by Minjie Wang's avatar Minjie Wang Committed by GitHub
Browse files

[Bugfix] Fix no attribute num_edges bug in Nodeflow (#1289)

* fix nodeflow bug when using builtin on edge data

* fix
parent b05cb84a
...@@ -407,7 +407,7 @@ def schedule_nodeflow_apply_edges(graph, block_id, ...@@ -407,7 +407,7 @@ def schedule_nodeflow_apply_edges(graph, block_id,
name='out_nf') name='out_nf')
var_ef = var.FEAT_DICT(graph._get_edge_frame(block_id), name='ef') var_ef = var.FEAT_DICT(graph._get_edge_frame(block_id), name='ef')
var_out = _gen_send(graph, u, v, eid, apply_func, in_var_nf, out_var_nf, var_out = _gen_send(graph, u, v, eid, apply_func, in_var_nf, out_var_nf,
var_ef) var_ef, block_id=block_id)
var_eid = var.IDX(eid) var_eid = var.IDX(eid)
if inplace: if inplace:
ir.WRITE_ROW_INPLACE_(var_ef, var_eid, var_out) ir.WRITE_ROW_INPLACE_(var_ef, var_eid, var_out)
...@@ -967,13 +967,14 @@ def _gen_udf_send(var_src_nf, var_dst_nf, var_ef, u, v, eid, mfunc, ...@@ -967,13 +967,14 @@ def _gen_udf_send(var_src_nf, var_dst_nf, var_ef, u, v, eid, mfunc,
fdedge = ir.READ_ROW(var_ef, eid) fdedge = ir.READ_ROW(var_ef, eid)
def _mfunc_wrapper(src_data, edge_data, dst_data): def _mfunc_wrapper(src_data, edge_data, dst_data):
ebatch = EdgeBatch((u.data, v.data, eid.data), ebatch = EdgeBatch((u.data, v.data, eid.data),
src_data, edge_data, dst_data, canonical_etype=canonical_etype) src_data, edge_data, dst_data,
canonical_etype=canonical_etype)
return mfunc(ebatch) return mfunc(ebatch)
_mfunc_wrapper = var.FUNC(_mfunc_wrapper) _mfunc_wrapper = var.FUNC(_mfunc_wrapper)
msg = ir.EDGE_UDF(_mfunc_wrapper, fdsrc, fdedge, fddst) msg = ir.EDGE_UDF(_mfunc_wrapper, fdsrc, fdedge, fddst)
return msg return msg
def _gen_send(graph, u, v, eid, mfunc, var_src_nf, var_dst_nf, var_ef): def _gen_send(graph, u, v, eid, mfunc, var_src_nf, var_dst_nf, var_ef, block_id=None):
"""Internal function to generate send schedule""" """Internal function to generate send schedule"""
mfunc = _standardize_func_usage(mfunc, 'message') mfunc = _standardize_func_usage(mfunc, 'message')
mfunc_is_list = utils.is_iterable(mfunc) mfunc_is_list = utils.is_iterable(mfunc)
...@@ -983,7 +984,10 @@ def _gen_send(graph, u, v, eid, mfunc, var_src_nf, var_dst_nf, var_ef): ...@@ -983,7 +984,10 @@ def _gen_send(graph, u, v, eid, mfunc, var_src_nf, var_dst_nf, var_ef):
var_eid = var.IDX(eid) var_eid = var.IDX(eid)
if mfunc_is_list: if mfunc_is_list:
if eid.is_slice(0, graph.num_edges()): if not hasattr(graph, 'num_edges'):
# XXX(minjie): a temporary hack to detect Nodeflow object
res = spmv.build_gidx_and_mapping_block(graph, block_id)
elif eid.is_slice(0, graph.num_edges()):
# full graph case # full graph case
res = spmv.build_gidx_and_mapping_graph(graph) res = spmv.build_gidx_and_mapping_graph(graph)
else: else:
...@@ -991,7 +995,7 @@ def _gen_send(graph, u, v, eid, mfunc, var_src_nf, var_dst_nf, var_ef): ...@@ -991,7 +995,7 @@ def _gen_send(graph, u, v, eid, mfunc, var_src_nf, var_dst_nf, var_ef):
(u, v, eid), graph.num_src(), graph.num_dst()) (u, v, eid), graph.num_src(), graph.num_dst())
adj, edge_map, _ = res adj, edge_map, _ = res
# create a tmp message frame # create a tmp message frame
tmp_mfr = FrameRef(frame_like(graph.edgeframe._frame, len(eid))) tmp_mfr = FrameRef(frame_like(var_ef.data._frame, len(eid)))
var_out = var.FEAT_DICT(data=tmp_mfr) var_out = var.FEAT_DICT(data=tmp_mfr)
spmv.gen_v2e_spmv_schedule(graph=adj, spmv.gen_v2e_spmv_schedule(graph=adj,
mfunc=mfunc, mfunc=mfunc,
......
...@@ -219,6 +219,13 @@ def check_apply_edges(create_node_flow): ...@@ -219,6 +219,13 @@ def check_apply_edges(create_node_flow):
assert_array_equal( assert_array_equal(
F.asnumpy(nf.blocks[i].data['f2']), F.asnumpy(expected_f_sum)) F.asnumpy(nf.blocks[i].data['f2']), F.asnumpy(expected_f_sum))
# test built-in
nf.apply_block(i, fn.u_add_v('f', 'f', 'f2'))
eids = nf.block_parent_eid(i)
srcs, dsts = g.find_edges(eids)
expected_f_sum = g.nodes[srcs].data["f"] + g.nodes[dsts].data["f"]
assert_array_equal(
F.asnumpy(nf.blocks[i].data['f2']), F.asnumpy(expected_f_sum))
def check_apply_edges1(create_node_flow): def check_apply_edges1(create_node_flow):
num_layers = 2 num_layers = 2
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment