Unverified Commit b82b8e2d authored by Minjie Wang's avatar Minjie Wang Committed by GitHub
Browse files

[Schedule] get rid of unnecessary graph.edges() call for v2v-spmv update_all (#179)

parent 524e656d
...@@ -101,8 +101,13 @@ def schedule_snr(graph, ...@@ -101,8 +101,13 @@ def schedule_snr(graph,
var_eid = var.IDX(eid) var_eid = var.IDX(eid)
var_recv_nodes = var.IDX(recv_nodes, name='recv_nodes') var_recv_nodes = var.IDX(recv_nodes, name='recv_nodes')
# generate send and reduce schedule # generate send and reduce schedule
reduced_feat = _gen_send_reduce(call_type, graph, uv_getter = lambda : (var_u, var_v)
message_func, reduce_func, (var_u, var_v, var_eid), recv_nodes) adj_creator = lambda : spmv.build_adj_matrix_uv(graph, (u, v), recv_nodes)
inc_creator = lambda : spmv.build_inc_matrix_dst(v, recv_nodes)
reduced_feat = _gen_send_reduce(
graph, message_func, reduce_func,
var_eid, var_recv_nodes,
uv_getter, adj_creator, inc_creator)
# generate apply schedule # generate apply schedule
final_feat = _apply_with_accum(graph, var_recv_nodes, var_nf, reduced_feat, apply_func) final_feat = _apply_with_accum(graph, var_recv_nodes, var_nf, reduced_feat, apply_func)
ir.WRITE_ROW_(var_nf, var_recv_nodes, final_feat) ir.WRITE_ROW_(var_nf, var_recv_nodes, final_feat)
...@@ -128,18 +133,22 @@ def schedule_update_all(graph, message_func, reduce_func, apply_func): ...@@ -128,18 +133,22 @@ def schedule_update_all(graph, message_func, reduce_func, apply_func):
schedule_apply_nodes(graph, nodes, apply_func) schedule_apply_nodes(graph, nodes, apply_func)
else: else:
call_type = 'update_all' call_type = 'update_all'
src, dst, _ = graph._graph.edges()
eid = utils.toindex(slice(0, graph.number_of_edges())) # shortcut for ALL eid = utils.toindex(slice(0, graph.number_of_edges())) # shortcut for ALL
recv_nodes = utils.toindex(slice(0, graph.number_of_nodes())) # shortcut for ALL recv_nodes = utils.toindex(slice(0, graph.number_of_nodes())) # shortcut for ALL
# create vars # create vars
var_nf = var.FEAT_DICT(graph._node_frame, name='nf') var_nf = var.FEAT_DICT(graph._node_frame, name='nf')
var_recv_nodes = var.IDX(recv_nodes, name='recv_nodes') var_recv_nodes = var.IDX(recv_nodes, name='recv_nodes')
var_src = var.IDX(src)
var_dst = var.IDX(dst)
var_eid = var.IDX(eid) var_eid = var.IDX(eid)
# generate send + reduce # generate send + reduce
reduced_feat = _gen_send_reduce(call_type, graph, def uv_getter():
message_func, reduce_func, (var_src, var_dst, var_eid), recv_nodes) src, dst, _ = graph._graph.edges()
return var.IDX(src), var.IDX(dst)
adj_creator = lambda : spmv.build_adj_matrix_graph(graph)
inc_creator = lambda : spmv.build_inc_matrix_graph(graph)
reduced_feat = _gen_send_reduce(
graph, message_func, reduce_func,
var_eid, var_recv_nodes,
uv_getter, adj_creator, inc_creator)
# generate optional apply # generate optional apply
final_feat = _apply_with_accum(graph, var_recv_nodes, var_nf, reduced_feat, apply_func) final_feat = _apply_with_accum(graph, var_recv_nodes, var_nf, reduced_feat, apply_func)
ir.WRITE_DICT_(var_nf, final_feat) ir.WRITE_DICT_(var_nf, final_feat)
...@@ -248,7 +257,7 @@ def schedule_pull(graph, pull_nodes, message_func, reduce_func, apply_func): ...@@ -248,7 +257,7 @@ def schedule_pull(graph, pull_nodes, message_func, reduce_func, apply_func):
""" """
# TODO(minjie): `in_edges` can be omitted if message and reduce func pairs # TODO(minjie): `in_edges` can be omitted if message and reduce func pairs
# can be specialized to SPMV. This needs support for creating adjmat # can be specialized to SPMV. This needs support for creating adjmat
# directly from dst node frontier. # directly from pull node frontier.
u, v, eid = graph._graph.in_edges(pull_nodes) u, v, eid = graph._graph.in_edges(pull_nodes)
if len(eid) == 0: if len(eid) == 0:
# All the nodes are 0deg; downgrades to apply. # All the nodes are 0deg; downgrades to apply.
...@@ -265,8 +274,13 @@ def schedule_pull(graph, pull_nodes, message_func, reduce_func, apply_func): ...@@ -265,8 +274,13 @@ def schedule_pull(graph, pull_nodes, message_func, reduce_func, apply_func):
var_v = var.IDX(v) var_v = var.IDX(v)
var_eid = var.IDX(eid) var_eid = var.IDX(eid)
# generate send and reduce schedule # generate send and reduce schedule
reduced_feat = _gen_send_reduce(call_type, graph, uv_getter = lambda : (var_u, var_v)
message_func, reduce_func, (var_u, var_v, var_eid), pull_nodes) adj_creator = lambda : spmv.build_adj_matrix_uv(graph, (u, v), pull_nodes)
inc_creator = lambda : spmv.build_inc_matrix_dst(v, pull_nodes)
reduced_feat = _gen_send_reduce(
graph, message_func, reduce_func,
var_eid, var_pull_nodes,
uv_getter, adj_creator, inc_creator)
# generate optional apply # generate optional apply
final_feat = _apply_with_accum(graph, var_pull_nodes, var_nf, reduced_feat, apply_func) final_feat = _apply_with_accum(graph, var_pull_nodes, var_nf, reduced_feat, apply_func)
ir.WRITE_ROW_(var_nf, var_pull_nodes, final_feat) ir.WRITE_ROW_(var_nf, var_pull_nodes, final_feat)
...@@ -363,8 +377,7 @@ def _gen_reduce(graph, reduce_func, edge_tuples, recv_nodes): ...@@ -363,8 +377,7 @@ def _gen_reduce(graph, reduce_func, edge_tuples, recv_nodes):
# analyze e2v spmv # analyze e2v spmv
spmv_rfunc, rfunc = spmv.analyze_e2v_spmv(graph, rfunc) spmv_rfunc, rfunc = spmv.analyze_e2v_spmv(graph, rfunc)
# FIXME: refactor this when fixing the multi-recv bug # FIXME: refactor this when fixing the multi-recv bug
mat = spmv.build_incmat_by_eid(graph._msg_frame.num_rows, mid, dst, recv_nodes) inc = spmv.build_inc_matrix_eid(graph._msg_frame.num_rows, mid, dst, recv_nodes)
inc = utils.CtxCachedObject(lambda ctx : F.copy_to(mat, ctx))
spmv.gen_e2v_spmv_schedule(inc, spmv_rfunc, msg, out) spmv.gen_e2v_spmv_schedule(inc, spmv_rfunc, msg, out)
if len(rfunc) == 0: if len(rfunc) == 0:
...@@ -380,28 +393,50 @@ def _gen_reduce(graph, reduce_func, edge_tuples, recv_nodes): ...@@ -380,28 +393,50 @@ def _gen_reduce(graph, reduce_func, edge_tuples, recv_nodes):
return out return out
def _gen_send_reduce( def _gen_send_reduce(
call_type,
graph, graph,
message_func, message_func,
reduce_func, reduce_func,
edge_tuples, var_send_edges,
recv_nodes): var_reduce_nodes,
uv_getter,
adj_creator,
inc_creator):
"""Generate send and reduce schedule. """Generate send and reduce schedule.
This guarantees that the returned reduced features are batched This guarantees that the returned reduced features are batched
in the *unique-ascending* order of the edge destination node ids. in the *unique-ascending* order of the edge destination node ids.
call_type : str Parameters
----------
graph : DGLGraph graph : DGLGraph
The graph
message_func : callable, list of builtins message_func : callable, list of builtins
The message func(s).
reduce_func : callable, list of builtins reduce_func : callable, list of builtins
edge_tuples : (u, v, eid) tuple of var.Var The reduce func(s).
recv_nodes : utils.index var_send_edges : var.IDX
The edges (ids) to perform send.
var_reduce_nodes : var.IDX
The nodes to perform reduce. This should include unique(v) + 0deg nodes.
uv_getter : callable
A function that returns a pair of var.IDX (u, v) for the triggered edges.
adj_creator : callable
A function that returns var.SPMAT that represents the adjmat.
inc_creator : callable
A function that returns var.SPMAT that represents the incmat.
Returns
-------
var.FEAT_DICT
The reduced feature dict.
""" """
# NOTE: currently, this function requires all var.IDX to contain concrete data.
reduce_nodes = var_reduce_nodes.data
# arg vars # arg vars
var_u, var_v, var_eid = edge_tuples
var_nf = var.FEAT_DICT(graph._node_frame, name='nf') var_nf = var.FEAT_DICT(graph._node_frame, name='nf')
var_ef = var.FEAT_DICT(graph._edge_frame, name='ef') var_ef = var.FEAT_DICT(graph._edge_frame, name='ef')
var_eid = var_send_edges
# format the input functions # format the input functions
mfunc = _standardize_func_usage(message_func, 'message') mfunc = _standardize_func_usage(message_func, 'message')
...@@ -413,15 +448,14 @@ def _gen_send_reduce( ...@@ -413,15 +448,14 @@ def _gen_send_reduce(
# The frame has the same size and schemes of the # The frame has the same size and schemes of the
# node frame. # node frame.
# TODO(minjie): should replace this with an IR call to make the program stateless. # TODO(minjie): should replace this with an IR call to make the program stateless.
tmpframe = FrameRef(frame_like(graph._node_frame._frame, len(recv_nodes))) tmpframe = FrameRef(frame_like(graph._node_frame._frame, len(reduce_nodes)))
var_out = var.FEAT_DICT(data=tmpframe) var_out = var.FEAT_DICT(data=tmpframe)
if mfunc_is_list and rfunc_is_list: if mfunc_is_list and rfunc_is_list:
# builtin message + builtin reducer # builtin message + builtin reducer
# analyze v2v spmv # analyze v2v spmv
spmv_pairs, mfunc, rfunc = spmv.analyze_v2v_spmv(graph, mfunc, rfunc) spmv_pairs, mfunc, rfunc = spmv.analyze_v2v_spmv(graph, mfunc, rfunc)
adj = spmv.build_adj_matrix(call_type, graph, adj = adj_creator()
(var_u.data, var_v.data), recv_nodes)
spmv.gen_v2v_spmv_schedule(adj, spmv_pairs, var_nf, var_ef, var_eid, var_out) spmv.gen_v2v_spmv_schedule(adj, spmv_pairs, var_nf, var_ef, var_eid, var_out)
if len(mfunc) == 0: if len(mfunc) == 0:
...@@ -437,13 +471,14 @@ def _gen_send_reduce( ...@@ -437,13 +471,14 @@ def _gen_send_reduce(
mfunc = BundledFunction(mfunc) mfunc = BundledFunction(mfunc)
# generate UDF send schedule # generate UDF send schedule
var_u, var_v = uv_getter()
var_mf = _gen_send(graph, var_nf, var_ef, var_u, var_v, var_eid, mfunc) var_mf = _gen_send(graph, var_nf, var_ef, var_u, var_v, var_eid, mfunc)
if rfunc_is_list: if rfunc_is_list:
# UDF message + builtin reducer # UDF message + builtin reducer
# analyze e2v spmv # analyze e2v spmv
spmv_rfunc, rfunc = spmv.analyze_e2v_spmv(graph, rfunc) spmv_rfunc, rfunc = spmv.analyze_e2v_spmv(graph, rfunc)
inc = spmv.build_inc_matrix(call_type, graph, var_v.data, recv_nodes) inc = inc_creator()
spmv.gen_e2v_spmv_schedule(inc, spmv_rfunc, var_mf, var_out) spmv.gen_e2v_spmv_schedule(inc, spmv_rfunc, var_mf, var_out)
if len(rfunc) == 0: if len(rfunc) == 0:
...@@ -456,7 +491,7 @@ def _gen_send_reduce( ...@@ -456,7 +491,7 @@ def _gen_send_reduce(
# gen degree bucketing schedule for UDF recv # gen degree bucketing schedule for UDF recv
mid = utils.toindex(slice(0, len(var_v.data))) # message id is from 0~|dst| mid = utils.toindex(slice(0, len(var_v.data))) # message id is from 0~|dst|
db.gen_degree_bucketing_schedule(graph, rfunc, db.gen_degree_bucketing_schedule(graph, rfunc,
mid, var_v.data, recv_nodes, mid, var_v.data, reduce_nodes,
var_nf, var_mf, var_out) var_nf, var_mf, var_out)
return var_out return var_out
......
...@@ -122,35 +122,20 @@ def gen_e2v_spmv_schedule(inc, spmv_rfunc, mf, out): ...@@ -122,35 +122,20 @@ def gen_e2v_spmv_schedule(inc, spmv_rfunc, mf, out):
ftdst = ir.SPMV(inc_var, ftmsg) ftdst = ir.SPMV(inc_var, ftmsg)
ir.WRITE_COL_(out, var.STR(rfn.out_field), ftdst) ir.WRITE_COL_(out, var.STR(rfn.out_field), ftdst)
def build_adj_matrix(call_type, graph, edges, reduce_nodes): def build_adj_matrix_graph(graph):
"""Build adjacency matrix. """Build adjacency matrix of the whole graph.
Parameters Parameters
---------- ----------
call_type : str
Can be 'update_all', 'send_and_recv'
graph : DGLGraph graph : DGLGraph
The graph The graph
edges : tuple of utils.Index
(u, v)
reduce_nodes : utils.Index
The nodes to reduce messages, which will be target dimension
of the adjmat. The nodes include unique(v) and zero-degree-nodes.
Returns Returns
------- -------
utils.CtxCachedObject utils.CtxCachedObject
Get be used to get adjacency matrix on the provided ctx. Get be used to get adjacency matrix on the provided ctx.
""" """
if call_type == "update_all": return utils.CtxCachedObject(lambda ctx : graph.adjacency_matrix(ctx=ctx))
# full graph case
return utils.CtxCachedObject(lambda ctx : graph.adjacency_matrix(ctx=ctx))
elif call_type == "send_and_recv":
# edgeset case
mat = build_adj_matrix_uv(graph, edges, reduce_nodes)
return utils.CtxCachedObject(lambda ctx : F.copy_to(mat, ctx))
else:
raise DGLError('Invalid call type:', call_type)
def build_adj_matrix_index_uv(graph, edges, reduce_nodes): def build_adj_matrix_index_uv(graph, edges, reduce_nodes):
"""Build adj matrix index and shape using the given (u, v) edges. """Build adj matrix index and shape using the given (u, v) edges.
...@@ -212,8 +197,8 @@ def build_adj_matrix_uv(graph, edges, reduce_nodes): ...@@ -212,8 +197,8 @@ def build_adj_matrix_uv(graph, edges, reduce_nodes):
Returns Returns
------- -------
Sparse matrix utils.CtxCachedObject
The adjacency matrix on CPU Get be used to get adjacency matrix on the provided ctx.
""" """
sp_idx, shape = build_adj_matrix_index_uv(graph, edges, reduce_nodes) sp_idx, shape = build_adj_matrix_index_uv(graph, edges, reduce_nodes)
u, v = edges u, v = edges
...@@ -221,39 +206,24 @@ def build_adj_matrix_uv(graph, edges, reduce_nodes): ...@@ -221,39 +206,24 @@ def build_adj_matrix_uv(graph, edges, reduce_nodes):
# FIXME(minjie): data type # FIXME(minjie): data type
dat = F.ones((nnz,), dtype=F.float32, ctx=F.cpu()) dat = F.ones((nnz,), dtype=F.float32, ctx=F.cpu())
mat = F.sparse_matrix(dat, sp_idx, shape) mat = F.sparse_matrix(dat, sp_idx, shape)
return mat return utils.CtxCachedObject(lambda ctx : F.copy_to(mat, ctx))
def build_inc_matrix(call_type, graph, dst, reduce_nodes): def build_inc_matrix_graph(graph):
"""Build incidence matrix. """Build incidence matrix.
Parameters Parameters
---------- ----------
call_type : str
Can be 'update_all', 'send_and_recv'.
graph : DGLGraph graph : DGLGraph
The graph. The graph.
dst : utils.Index
The destination nodes of the edges.
reduce_nodes : utils.Index
The nodes to reduce messages, which will be target dimension
of the incmat. The nodes include unique(dst) and zero-degree-nodes.
Returns Returns
------- -------
utils.CtxCachedObject utils.CtxCachedObject
Get be used to get incidence matrix on the provided ctx. Get be used to get incidence matrix on the provided ctx.
""" """
if call_type == "update_all": return utils.CtxCachedObject(lambda ctx : graph.incidence_matrix(type='in', ctx=ctx))
# full graph case
return utils.CtxCachedObject(lambda ctx : graph.incidence_matrix(type='in', ctx=ctx)) def build_inc_matrix_eid(m, eid, dst, reduce_nodes):
elif call_type == "send_and_recv":
# edgeset case
mat = build_incmat_by_dst(dst, reduce_nodes)
return utils.CtxCachedObject(lambda ctx : F.copy_to(mat, ctx))
else:
raise DGLError('Invalid call type:', call_type)
def build_incmat_by_eid(m, eid, dst, reduce_nodes):
"""Build incidence matrix using edge id and edge dst nodes. """Build incidence matrix using edge id and edge dst nodes.
The incidence matrix is of shape (n, m), where n=len(reduce_nodes). The incidence matrix is of shape (n, m), where n=len(reduce_nodes).
...@@ -276,7 +246,7 @@ def build_incmat_by_eid(m, eid, dst, reduce_nodes): ...@@ -276,7 +246,7 @@ def build_incmat_by_eid(m, eid, dst, reduce_nodes):
>>> eid = [1, 2, 3, 5, 6] >>> eid = [1, 2, 3, 5, 6]
>>> dst = [1, 1, 3, 4, 4] >>> dst = [1, 1, 3, 4, 4]
>>> reduce_nodes = [0, 1, 2, 3, 4] >>> reduce_nodes = [0, 1, 2, 3, 4]
>>> build_incmat_by_eid(m, eid, dst, reduce_nodes) >>> build_inc_matrix_eid(m, eid, dst, reduce_nodes)
tensor([[0, 0, 0, 0, 0, 0, 0], tensor([[0, 0, 0, 0, 0, 0, 0],
[0, 1, 1, 0, 0, 0, 0], [0, 1, 1, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0],
...@@ -297,8 +267,8 @@ def build_incmat_by_eid(m, eid, dst, reduce_nodes): ...@@ -297,8 +267,8 @@ def build_incmat_by_eid(m, eid, dst, reduce_nodes):
Returns Returns
------- -------
Sparse matrix utils.CtxCachedObject
The incidence matrix on CPU Get be used to get incidence matrix on the provided ctx.
""" """
new2old, old2new = utils.build_relabel_map(reduce_nodes, sorted=True) new2old, old2new = utils.build_relabel_map(reduce_nodes, sorted=True)
dst = dst.tousertensor() dst = dst.tousertensor()
...@@ -313,9 +283,10 @@ def build_incmat_by_eid(m, eid, dst, reduce_nodes): ...@@ -313,9 +283,10 @@ def build_incmat_by_eid(m, eid, dst, reduce_nodes):
# create dat tensor # create dat tensor
nnz = len(eid) nnz = len(eid)
dat = F.ones((nnz,), dtype=F.float32, ctx=F.cpu()) dat = F.ones((nnz,), dtype=F.float32, ctx=F.cpu())
return F.sparse_matrix(dat, ('coo', idx), (n, m)) mat = F.sparse_matrix(dat, ('coo', idx), (n, m))
return utils.CtxCachedObject(lambda ctx : F.copy_to(mat, ctx))
def build_incmat_by_dst(dst, reduce_nodes): def build_inc_matrix_dst(dst, reduce_nodes):
"""Build incidence matrix using only edge destinations. """Build incidence matrix using only edge destinations.
The incidence matrix is of shape (n, m), where n=len(reduce_nodes), m=len(dst). The incidence matrix is of shape (n, m), where n=len(reduce_nodes), m=len(dst).
...@@ -328,7 +299,7 @@ def build_incmat_by_dst(dst, reduce_nodes): ...@@ -328,7 +299,7 @@ def build_incmat_by_dst(dst, reduce_nodes):
target dimension (0~4), where node 0 and 2 are two 0-deg nodes. target dimension (0~4), where node 0 and 2 are two 0-deg nodes.
>>> dst = [1, 1, 3, 4, 4] >>> dst = [1, 1, 3, 4, 4]
>>> reduce_nodes = [0, 1, 2, 3, 4] >>> reduce_nodes = [0, 1, 2, 3, 4]
>>> build_incmat_by_dst(dst, reduced_nodes) >>> build_inc_matrix_dst(dst, reduced_nodes)
tensor([[0, 0, 0, 0, 0], tensor([[0, 0, 0, 0, 0],
[1, 1, 0, 0, 0], [1, 1, 0, 0, 0],
[0, 0, 0, 0, 0], [0, 0, 0, 0, 0],
...@@ -345,8 +316,8 @@ def build_incmat_by_dst(dst, reduce_nodes): ...@@ -345,8 +316,8 @@ def build_incmat_by_dst(dst, reduce_nodes):
Returns Returns
------- -------
Sparse matrix utils.CtxCachedObject
The incidence matrix on CPU Get be used to get incidence matrix on the provided ctx.
""" """
eid = utils.toindex(F.arange(0, len(dst))) eid = utils.toindex(F.arange(0, len(dst)))
return build_incmat_by_eid(len(eid), eid, dst, reduce_nodes) return build_inc_matrix_eid(len(eid), eid, dst, reduce_nodes)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment