Commit 52d4535b authored by Quan (Andy) Gan, committed by Minjie Wang

[Hetero][RFC] Heterogeneous graph Python interfaces & Message Passing (#752)

* moving heterograph index to another file

* node view

* python interfaces

* heterograph init

* bug fixes

* docstring for readonly

* more docstring

* unit tests & lint

* oops

* oops x2

* removed node/edge addition

* addressed comments

* lint

* rw on frames with one node/edge type

* homograph with underlying heterograph demo

* view is not necessary

* bugfix

* replace

* scheduler, builtins not working yet

* moving bipartite.h to header

* moving back bipartite to bipartite.h

* oops

* asbits and copyto for bipartite

* tested update_all and send_and_recv

* lightweight node & edge type retrieval

* oops

* sorry

* removing obsolete code

* oops

* lint

* various bug fixes & more tests

* UDF tests

* multiple type number_of_nodes and number_of_edges

* docstring fixes

* more tests

* going for dict in initialization

* lint

* updated api as per discussions

* lint

* bug

* bugfix

* moving back bipartite impl to cc

* note on views

* fix
parent 66971c1a
@@ -116,6 +116,11 @@ IdArray IndexSelect(IdArray array, IdArray index);
  */
 IdArray Relabel_(const std::vector<IdArray>& arrays);

+/*!\brief Return whether the array is a valid 1D int array*/
+inline bool IsValidIdArray(const dgl::runtime::NDArray& arr) {
+  return arr->ndim == 1 && arr->dtype.code == kDLInt;
+}
+
 //////////////////////////////////////////////////////////////////////
 // Sparse matrix
 //////////////////////////////////////////////////////////////////////
...
@@ -17,6 +17,7 @@ from .base import ALL
 from .backend import load_backend
 from .batched_graph import *
 from .graph import DGLGraph
+from .heterograph import DGLHeteroGraph
 from .nodeflow import *
 from .traversal import *
 from .transform import *
...
@@ -49,6 +49,14 @@ class DGLBaseGraph(object):
         """
         return self._graph.number_of_nodes()

+    def _number_of_src_nodes(self):
+        """Return number of source nodes (only used in scheduler)"""
+        return self.number_of_nodes()
+
+    def _number_of_dst_nodes(self):
+        """Return number of destination nodes (only used in scheduler)"""
+        return self.number_of_nodes()
+
     def __len__(self):
         """Return the number of nodes in the graph."""
         return self.number_of_nodes()
@@ -65,6 +73,10 @@ class DGLBaseGraph(object):
         """
         return self._graph.is_readonly()

+    def _number_of_edges(self):
+        """Return number of edges in the current view (only used for scheduler)"""
+        return self.number_of_edges()
+
     def number_of_edges(self):
         """Return the number of edges in the graph.
@@ -939,6 +951,14 @@ class DGLGraph(DGLBaseGraph):
     def _set_msg_index(self, index):
         self._msg_index = index

+    @property
+    def _src_frame(self):
+        return self._node_frame
+
+    @property
+    def _dst_frame(self):
+        return self._node_frame
+
     def add_nodes(self, num, data=None):
         """Add multiple new nodes.
...
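This hook is what lets the scheduler stay agnostic to the graph kind: `_src_frame` and `_dst_frame` (together with `_number_of_src_nodes`/`_number_of_dst_nodes`) both resolve to the single node frame on a homogeneous DGLGraph, while a bipartite view of a heterograph can hand back two different frames. A minimal sketch of that contract follows; the `FakeBipartiteView` class and its attribute names are hypothetical, written only to illustrate the hook, not the code in the collapsed heterograph.py diff.

class FakeBipartiteView(object):
    """Hypothetical per-edge-type view with distinct source/destination frames."""
    def __init__(self, src_frame, dst_frame, num_src, num_dst):
        self._sframe, self._dframe = src_frame, dst_frame
        self._num_src, self._num_dst = num_src, num_dst

    @property
    def _src_frame(self):
        return self._sframe

    @property
    def _dst_frame(self):
        return self._dframe

    def _number_of_src_nodes(self):
        return self._num_src

    def _number_of_dst_nodes(self):
        return self._num_dst

# Scheduling code only touches these hooks, e.g.
#   var_src_nf = var.FEAT_DICT(graph._src_frame)
#   var_dst_nf = var.FEAT_DICT(graph._dst_frame)
# so the same schedule covers a DGLGraph and a single bipartite relation alike.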
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
@@ -161,7 +161,8 @@ def gen_group_apply_edge_schedule(
         apply_func,
         u, v, eid,
         group_by,
-        var_nf,
+        var_src_nf,
+        var_dst_nf,
         var_ef,
         var_out):
     """Create degree bucketing schedule for group_apply_edge
@@ -186,8 +187,10 @@ def gen_group_apply_edge_schedule(
         Edges to apply
     group_by: str
         If "src", group by u. If "dst", group by v
-    var_nf : var.FEAT_DICT
-        The variable for node feature frame.
+    var_src_nf : var.FEAT_DICT
+        The variable for source feature frame.
+    var_dst_nf : var.FEAT_DICT
+        The variable for destination feature frame.
     var_ef : var.FEAT_DICT
         The variable for edge frame.
     var_out : var.FEAT_DICT
@@ -213,8 +216,8 @@ def gen_group_apply_edge_schedule(
         var_v = var.IDX(v_bkt)
         var_eid = var.IDX(eid_bkt)
         # apply edge UDF on each bucket
-        fdsrc = ir.READ_ROW(var_nf, var_u)
-        fddst = ir.READ_ROW(var_nf, var_v)
+        fdsrc = ir.READ_ROW(var_src_nf, var_u)
+        fddst = ir.READ_ROW(var_dst_nf, var_v)
         fdedge = ir.READ_ROW(var_ef, var_eid)
         fdedge = ir.EDGE_UDF(_efunc, fdsrc, fdedge, fddst, ret=fdedge)  # reuse var
         # save for merge
...
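The schedule now threads two node-feature variables because an edge UDF reads source and destination rows that may live in different frames. With the EdgeBatch interface the runtime already uses, such a UDF looks like the sketch below; the feature names 'hu', 'hv' and 'w' are made up for illustration, and a PyTorch-like backend is assumed for the arithmetic.

def score_edges(edges):
    # edges.src rows are read from the source frame (var_src_nf above),
    # edges.dst rows from the destination frame (var_dst_nf),
    # and edges.data from the edge frame (var_ef).
    return {'score': edges.src['hu'] * edges.dst['hv'] + edges.data['w']}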
@@ -8,6 +8,8 @@ from .. import backend as F
 from ..frame import frame_like, FrameRef
 from ..function.base import BuiltinFunction
 from ..udf import EdgeBatch, NodeBatch
+from ..graph_index import GraphIndex
+from ..heterograph_index import HeteroGraphIndex
 from . import ir
 from .ir import var
@@ -28,6 +30,15 @@ __all__ = [
     "schedule_pull"
 ]

+def _dispatch(graph, method, *args, **kwargs):
+    graph_index = graph._graph
+    if isinstance(graph_index, GraphIndex):
+        return getattr(graph._graph, method)(*args, **kwargs)
+    elif isinstance(graph_index, HeteroGraphIndex):
+        return getattr(graph._graph, method)(graph._current_etype_idx, *args, **kwargs)
+    else:
+        raise TypeError('unknown type %s' % type(graph_index))
+
 def schedule_send(graph, u, v, eid, message_func):
     """get send schedule
@@ -45,7 +56,8 @@ def schedule_send(graph, u, v, eid, message_func):
         The message function
     """
     var_mf = var.FEAT_DICT(graph._msg_frame)
-    var_nf = var.FEAT_DICT(graph._node_frame)
+    var_src_nf = var.FEAT_DICT(graph._src_frame)
+    var_dst_nf = var.FEAT_DICT(graph._dst_frame)
     var_ef = var.FEAT_DICT(graph._edge_frame)
     var_eid = var.IDX(eid)
@@ -54,8 +66,8 @@ def schedule_send(graph, u, v, eid, message_func):
                         v=v,
                         eid=eid,
                         mfunc=message_func,
-                        var_src_nf=var_nf,
-                        var_dst_nf=var_nf,
+                        var_src_nf=var_src_nf,
+                        var_dst_nf=var_dst_nf,
                         var_ef=var_ef)
     # write tmp msg back
@@ -83,7 +95,7 @@ def schedule_recv(graph,
     inplace: bool
         If True, the update will be done in place
     """
-    src, dst, eid = graph._graph.in_edges(recv_nodes)
+    src, dst, eid = _dispatch(graph, 'in_edges', recv_nodes)
     if len(eid) > 0:
         nonzero_idx = graph._get_msg_index().get_items(eid).nonzero()
         eid = eid.get_items(nonzero_idx)
@@ -96,7 +108,7 @@ def schedule_recv(graph,
         if apply_func is not None:
             schedule_apply_nodes(graph, recv_nodes, apply_func, inplace)
     else:
-        var_nf = var.FEAT_DICT(graph._node_frame, name='nf')
+        var_dst_nf = var.FEAT_DICT(graph._dst_frame, name='nf')
         # sort and unique the argument
         recv_nodes, _ = F.sort_1d(F.unique(recv_nodes.tousertensor()))
         recv_nodes = utils.toindex(recv_nodes)
@@ -105,12 +117,12 @@ def schedule_recv(graph,
         reduced_feat = _gen_reduce(graph, reduce_func, (src, dst, eid),
                                    recv_nodes)
         # apply
-        final_feat = _apply_with_accum(graph, var_recv_nodes, var_nf,
+        final_feat = _apply_with_accum(graph, var_recv_nodes, var_dst_nf,
                                        reduced_feat, apply_func)
         if inplace:
-            ir.WRITE_ROW_INPLACE_(var_nf, var_recv_nodes, final_feat)
+            ir.WRITE_ROW_INPLACE_(var_dst_nf, var_recv_nodes, final_feat)
         else:
-            ir.WRITE_ROW_(var_nf, var_recv_nodes, final_feat)
+            ir.WRITE_ROW_(var_dst_nf, var_recv_nodes, final_feat)
         # set message indicator to 0
         graph._set_msg_index(graph._get_msg_index().set_items(eid, 0))
         if not graph._get_msg_index().has_nonzero():
@@ -148,7 +160,7 @@ def schedule_snr(graph,
     recv_nodes, _ = F.sort_1d(F.unique(v.tousertensor()))
     recv_nodes = utils.toindex(recv_nodes)
     # create vars
-    var_nf = var.FEAT_DICT(graph._node_frame, name='nf')
+    var_dst_nf = var.FEAT_DICT(graph._dst_frame, name='dst_nf')
     var_u = var.IDX(u)
     var_v = var.IDX(v)
     var_eid = var.IDX(eid)
@@ -156,11 +168,11 @@ def schedule_snr(graph,
     # generate send and reduce schedule
     uv_getter = lambda: (var_u, var_v)
     adj_creator = lambda: spmv.build_gidx_and_mapping_uv(
-        edge_tuples, graph.number_of_nodes())
+        edge_tuples, graph._number_of_src_nodes(), graph._number_of_dst_nodes())
     out_map_creator = lambda nbits: _build_idx_map(recv_nodes, nbits)
     reduced_feat = _gen_send_reduce(graph=graph,
-                                    src_node_frame=graph._node_frame,
-                                    dst_node_frame=graph._node_frame,
+                                    src_node_frame=graph._src_frame,
+                                    dst_node_frame=graph._dst_frame,
                                     edge_frame=graph._edge_frame,
                                     message_func=message_func,
                                     reduce_func=reduce_func,
@@ -170,12 +182,12 @@ def schedule_snr(graph,
                                     adj_creator=adj_creator,
                                     out_map_creator=out_map_creator)
     # generate apply schedule
-    final_feat = _apply_with_accum(graph, var_recv_nodes, var_nf, reduced_feat,
+    final_feat = _apply_with_accum(graph, var_recv_nodes, var_dst_nf, reduced_feat,
                                    apply_func)
     if inplace:
-        ir.WRITE_ROW_INPLACE_(var_nf, var_recv_nodes, final_feat)
+        ir.WRITE_ROW_INPLACE_(var_dst_nf, var_recv_nodes, final_feat)
     else:
-        ir.WRITE_ROW_(var_nf, var_recv_nodes, final_feat)
+        ir.WRITE_ROW_(var_dst_nf, var_recv_nodes, final_feat)

 def schedule_update_all(graph,
                         message_func,
@@ -194,27 +206,27 @@ def schedule_update_all(graph,
     apply_func: callable
         The apply node function
     """
-    if graph.number_of_edges() == 0:
+    if graph._number_of_edges() == 0:
         # All the nodes are zero degree; downgrade to apply nodes
         if apply_func is not None:
-            nodes = utils.toindex(slice(0, graph.number_of_nodes()))
+            nodes = utils.toindex(slice(0, graph._number_of_dst_nodes()))
             schedule_apply_nodes(graph, nodes, apply_func, inplace=False)
     else:
-        eid = utils.toindex(slice(0, graph.number_of_edges()))  # ALL
-        recv_nodes = utils.toindex(slice(0, graph.number_of_nodes()))  # ALL
+        eid = utils.toindex(slice(0, graph._number_of_edges()))  # ALL
+        recv_nodes = utils.toindex(slice(0, graph._number_of_dst_nodes()))  # ALL
         # create vars
-        var_nf = var.FEAT_DICT(graph._node_frame, name='nf')
+        var_dst_nf = var.FEAT_DICT(graph._dst_frame, name='nf')
         var_recv_nodes = var.IDX(recv_nodes, name='recv_nodes')
         var_eid = var.IDX(eid)
         # generate send + reduce
         def uv_getter():
-            src, dst, _ = graph._graph.edges('eid')
+            src, dst, _ = _dispatch(graph, 'edges', 'eid')
             return var.IDX(src), var.IDX(dst)
         adj_creator = lambda: spmv.build_gidx_and_mapping_graph(graph)
         out_map_creator = lambda nbits: None
         reduced_feat = _gen_send_reduce(graph=graph,
-                                        src_node_frame=graph._node_frame,
-                                        dst_node_frame=graph._node_frame,
+                                        src_node_frame=graph._src_frame,
+                                        dst_node_frame=graph._dst_frame,
                                         edge_frame=graph._edge_frame,
                                         message_func=message_func,
                                         reduce_func=reduce_func,
@@ -224,9 +236,9 @@ def schedule_update_all(graph,
                                         adj_creator=adj_creator,
                                         out_map_creator=out_map_creator)
         # generate optional apply
-        final_feat = _apply_with_accum(graph, var_recv_nodes, var_nf,
+        final_feat = _apply_with_accum(graph, var_recv_nodes, var_dst_nf,
                                        reduced_feat, apply_func)
-        ir.WRITE_DICT_(var_nf, final_feat)
+        ir.WRITE_DICT_(var_dst_nf, final_feat)

 def schedule_apply_nodes(graph,
                          v,
@@ -326,10 +338,12 @@ def schedule_apply_edges(graph,
         A list of executors for DGL Runtime
     """
     # vars
-    var_nf = var.FEAT_DICT(graph._node_frame)
+    var_src_nf = var.FEAT_DICT(graph._src_frame)
+    var_dst_nf = var.FEAT_DICT(graph._dst_frame)
     var_ef = var.FEAT_DICT(graph._edge_frame)
     var_out = _gen_send(graph=graph, u=u, v=v, eid=eid, mfunc=apply_func,
-                        var_src_nf=var_nf, var_dst_nf=var_nf, var_ef=var_ef)
+                        var_src_nf=var_src_nf, var_dst_nf=var_dst_nf,
+                        var_ef=var_ef)
     var_ef = var.FEAT_DICT(graph._edge_frame, name='ef')
     var_eid = var.IDX(eid)
     # schedule apply edges
@@ -401,7 +415,7 @@ def schedule_push(graph,
     inplace: bool
         If True, the update will be done in place
     """
-    u, v, eid = graph._graph.out_edges(u)
+    u, v, eid = _dispatch(graph, 'out_edges', u)
     if len(eid) == 0:
         # All the pushing nodes have no out edges. No computation is scheduled.
         return
@@ -434,7 +448,7 @@ def schedule_pull(graph,
     # TODO(minjie): `in_edges` can be omitted if message and reduce func pairs
     # can be specialized to SPMV. This needs support for creating adjmat
     # directly from pull node frontier.
-    u, v, eid = graph._graph.in_edges(pull_nodes)
+    u, v, eid = _dispatch(graph, 'in_edges', pull_nodes)
     if len(eid) == 0:
         # All the nodes are 0deg; downgrades to apply.
         if apply_func is not None:
@@ -443,27 +457,27 @@ def schedule_pull(graph,
         pull_nodes, _ = F.sort_1d(F.unique(pull_nodes.tousertensor()))
         pull_nodes = utils.toindex(pull_nodes)
         # create vars
-        var_nf = var.FEAT_DICT(graph._node_frame, name='nf')
+        var_dst_nf = var.FEAT_DICT(graph._dst_frame, name='nf')
         var_pull_nodes = var.IDX(pull_nodes, name='pull_nodes')
         var_u = var.IDX(u)
         var_v = var.IDX(v)
         var_eid = var.IDX(eid)
         # generate send and reduce schedule
         uv_getter = lambda: (var_u, var_v)
-        num_nodes = graph.number_of_nodes()
-        adj_creator = lambda: spmv.build_gidx_and_mapping_uv((u, v, eid), num_nodes)
+        adj_creator = lambda: spmv.build_gidx_and_mapping_uv(
+            (u, v, eid), graph._number_of_src_nodes(), graph._number_of_dst_nodes())
         out_map_creator = lambda nbits: _build_idx_map(pull_nodes, nbits)
-        reduced_feat = _gen_send_reduce(graph, graph._node_frame,
-                                        graph._node_frame, graph._edge_frame,
+        reduced_feat = _gen_send_reduce(graph, graph._src_frame,
+                                        graph._dst_frame, graph._edge_frame,
                                         message_func, reduce_func, var_eid,
                                         var_pull_nodes, uv_getter, adj_creator,
                                         out_map_creator)
         # generate optional apply
-        final_feat = _apply_with_accum(graph, var_pull_nodes, var_nf, reduced_feat, apply_func)
+        final_feat = _apply_with_accum(graph, var_pull_nodes, var_dst_nf, reduced_feat, apply_func)
         if inplace:
-            ir.WRITE_ROW_INPLACE_(var_nf, var_pull_nodes, final_feat)
+            ir.WRITE_ROW_INPLACE_(var_dst_nf, var_pull_nodes, final_feat)
         else:
-            ir.WRITE_ROW_(var_nf, var_pull_nodes, final_feat)
+            ir.WRITE_ROW_(var_dst_nf, var_pull_nodes, final_feat)

 def schedule_group_apply_edge(graph,
                               u, v, eid,
@@ -494,11 +508,12 @@ def schedule_group_apply_edge(graph,
         A list of executors for DGL Runtime
     """
     # vars
-    var_nf = var.FEAT_DICT(graph._node_frame, name='nf')
+    var_src_nf = var.FEAT_DICT(graph._src_frame, name='src_nf')
+    var_dst_nf = var.FEAT_DICT(graph._dst_frame, name='dst_nf')
     var_ef = var.FEAT_DICT(graph._edge_frame, name='ef')
     var_out = var.FEAT_DICT(name='new_ef')
     db.gen_group_apply_edge_schedule(graph, apply_func, u, v, eid, group_by,
-                                     var_nf, var_ef, var_out)
+                                     var_src_nf, var_dst_nf, var_ef, var_out)
     var_eid = var.IDX(eid)
     if inplace:
         ir.WRITE_ROW_INPLACE_(var_ef, var_eid, var_out)
@@ -719,17 +734,16 @@ def _gen_reduce(graph, reduce_func, edge_tuples, recv_nodes):
     # node frame.
     # TODO(minjie): should replace this with an IR call to make the program
     # stateless.
-    tmpframe = FrameRef(frame_like(graph._node_frame._frame, len(recv_nodes)))
+    tmpframe = FrameRef(frame_like(graph._dst_frame._frame, len(recv_nodes)))

     # vars
     var_msg = var.FEAT_DICT(graph._msg_frame, 'msg')
-    var_nf = var.FEAT_DICT(graph._node_frame, 'nf')
+    var_dst_nf = var.FEAT_DICT(graph._dst_frame, 'nf')
     var_out = var.FEAT_DICT(data=tmpframe)

     if rfunc_is_list:
-        num_nodes = graph.number_of_nodes()
         adj, edge_map, nbits = spmv.build_gidx_and_mapping_uv(
-            (src, dst, eid), num_nodes)
+            (src, dst, eid), graph._number_of_src_nodes(), graph._number_of_dst_nodes())
         # using edge map instead of message map because messages are in global
         # message frame
         var_out_map = _build_idx_map(recv_nodes, nbits)
@@ -744,7 +758,7 @@ def _gen_reduce(graph, reduce_func, edge_tuples, recv_nodes):
     else:
         # gen degree bucketing schedule for UDF recv
         db.gen_degree_bucketing_schedule(graph, rfunc, eid, dst, recv_nodes,
                                          var_dst_nf, var_msg, var_out)
     return var_out

 def _gen_send_reduce(
@@ -930,12 +944,12 @@ def _gen_send(graph, u, v, eid, mfunc, var_src_nf, var_dst_nf, var_ef):
     var_eid = var.IDX(eid)
     if mfunc_is_list:
-        if eid.is_slice(0, graph.number_of_edges()):
+        if eid.is_slice(0, graph._number_of_edges()):
             # full graph case
             res = spmv.build_gidx_and_mapping_graph(graph)
         else:
-            num_nodes = graph.number_of_nodes()
-            res = spmv.build_gidx_and_mapping_uv((u, v, eid), num_nodes)
+            res = spmv.build_gidx_and_mapping_uv(
+                (u, v, eid), graph._number_of_src_nodes(), graph._number_of_dst_nodes())
         adj, edge_map, _ = res
         # create a tmp message frame
         tmp_mfr = FrameRef(frame_like(graph._edge_frame._frame, len(eid)))
...
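`_dispatch` is the single point where the scheduler tells the two index types apart: a plain GraphIndex exposes `edges`/`in_edges`/`out_edges` directly, while a HeteroGraphIndex takes the edge-type id (carried by the view as `graph._current_etype_idx`) as its first argument. A standalone toy rendition of the same routing, with made-up index classes rather than DGL's:

class ToyHomoIndex(object):
    def in_edges(self, nodes):
        return 'in_edges(%r)' % (nodes,)

class ToyHeteroIndex(object):
    def in_edges(self, etype_idx, nodes):
        return 'in_edges(etype=%d, %r)' % (etype_idx, nodes)

def toy_dispatch(index, current_etype_idx, method, *args):
    """Inject the edge-type id only when the index is heterogeneous."""
    if isinstance(index, ToyHomoIndex):
        return getattr(index, method)(*args)
    elif isinstance(index, ToyHeteroIndex):
        return getattr(index, method)(current_etype_idx, *args)
    raise TypeError('unknown index type %s' % type(index))

print(toy_dispatch(ToyHomoIndex(), None, 'in_edges', [0, 1]))   # in_edges([0, 1])
print(toy_dispatch(ToyHeteroIndex(), 2, 'in_edges', [0, 1]))    # in_edges(etype=2, [0, 1])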
"""Module for SPMV rules.""" """Module for SPMV rules."""
from __future__ import absolute_import from __future__ import absolute_import
from functools import partial
from ..base import DGLError from ..base import DGLError
from .. import backend as F from .. import backend as F
from .. import utils from .. import utils
from .. import ndarray as nd from .. import ndarray as nd
from ..graph_index import from_coo from ..graph_index import GraphIndex
from ..heterograph_index import HeteroGraphIndex, create_bipartite_from_coo
from . import ir from . import ir
from .ir import var from .ir import var
...@@ -127,8 +129,8 @@ def build_gidx_and_mapping_graph(graph): ...@@ -127,8 +129,8 @@ def build_gidx_and_mapping_graph(graph):
Parameters Parameters
---------- ----------
graph : DGLGraph graph : DGLGraph or DGLHeteroGraph
The graph The homogeneous graph, or a bipartite view of the heterogeneous graph.
Returns Returns
------- -------
...@@ -141,10 +143,17 @@ def build_gidx_and_mapping_graph(graph): ...@@ -141,10 +143,17 @@ def build_gidx_and_mapping_graph(graph):
Number of ints needed to represent the graph Number of ints needed to represent the graph
""" """
gidx = graph._graph gidx = graph._graph
if isinstance(gidx, GraphIndex):
return gidx.get_immutable_gidx, None, gidx.bits_needed() return gidx.get_immutable_gidx, None, gidx.bits_needed()
elif isinstance(gidx, HeteroGraphIndex):
return (partial(gidx.get_bipartite, graph._current_etype_idx),
None,
gidx.bits_needed(graph._current_etype_idx))
else:
raise TypeError('unknown graph index type %s' % type(gidx))
def build_gidx_and_mapping_uv(edge_tuples, num_nodes): def build_gidx_and_mapping_uv(edge_tuples, num_src, num_dst):
"""Build immutable graph index and mapping using the given (u, v) edges """Build immutable graph index and mapping using the given (u, v) edges
The matrix is of shape (len(reduce_nodes), n), where n is the number of The matrix is of shape (len(reduce_nodes), n), where n is the number of
...@@ -155,8 +164,8 @@ def build_gidx_and_mapping_uv(edge_tuples, num_nodes): ...@@ -155,8 +164,8 @@ def build_gidx_and_mapping_uv(edge_tuples, num_nodes):
--------- ---------
edge_tuples : tuple of three utils.Index edge_tuples : tuple of three utils.Index
A tuple of (u, v, eid) A tuple of (u, v, eid)
num_nodes : int num_src, num_dst : int
The number of nodes. The number of source and destination nodes.
Returns Returns
------- -------
...@@ -169,10 +178,10 @@ def build_gidx_and_mapping_uv(edge_tuples, num_nodes): ...@@ -169,10 +178,10 @@ def build_gidx_and_mapping_uv(edge_tuples, num_nodes):
Number of ints needed to represent the graph Number of ints needed to represent the graph
""" """
u, v, eid = edge_tuples u, v, eid = edge_tuples
gidx = from_coo(num_nodes, u, v, None, True) gidx = create_bipartite_from_coo(num_src, num_dst, u, v)
forward, backward = gidx.get_csr_shuffle_order() forward, backward = gidx.get_csr_shuffle_order(0)
eid = eid.tousertensor() eid = eid.tousertensor()
nbits = gidx.bits_needed() nbits = gidx.bits_needed(0)
forward_map = utils.to_nbits_int(eid[forward.tousertensor()], nbits) forward_map = utils.to_nbits_int(eid[forward.tousertensor()], nbits)
backward_map = utils.to_nbits_int(eid[backward.tousertensor()], nbits) backward_map = utils.to_nbits_int(eid[backward.tousertensor()], nbits)
forward_map = F.zerocopy_to_dgl_ndarray(forward_map) forward_map = F.zerocopy_to_dgl_ndarray(forward_map)
...@@ -180,7 +189,7 @@ def build_gidx_and_mapping_uv(edge_tuples, num_nodes): ...@@ -180,7 +189,7 @@ def build_gidx_and_mapping_uv(edge_tuples, num_nodes):
edge_map = utils.CtxCachedObject( edge_map = utils.CtxCachedObject(
lambda ctx: (nd.array(forward_map, ctx=ctx), lambda ctx: (nd.array(forward_map, ctx=ctx),
nd.array(backward_map, ctx=ctx))) nd.array(backward_map, ctx=ctx)))
return gidx.get_immutable_gidx, edge_map, nbits return partial(gidx.get_bipartite, None), edge_map, nbits
def build_gidx_and_mapping_block(graph, block_id, edge_tuples=None): def build_gidx_and_mapping_block(graph, block_id, edge_tuples=None):
...@@ -212,6 +221,6 @@ def build_gidx_and_mapping_block(graph, block_id, edge_tuples=None): ...@@ -212,6 +221,6 @@ def build_gidx_and_mapping_block(graph, block_id, edge_tuples=None):
eid = utils.toindex(eid) eid = utils.toindex(eid)
else: else:
u, v, eid = edge_tuples u, v, eid = edge_tuples
num_nodes = max(graph.layer_size(block_id), graph.layer_size(block_id + 1)) num_src, num_dst = graph.layer_size(block_id), graph.layer_size(block_id + 1)
gidx, edge_map, nbits = build_gidx_and_mapping_uv((u, v, eid), num_nodes) gidx, edge_map, nbits = build_gidx_and_mapping_uv((u, v, eid), num_src, num_dst)
return gidx, edge_map, nbits return gidx, edge_map, nbits
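The extra argument exists because the adjacency used for SPMV-style message passing is no longer square: for one relation of a heterograph it has shape (number of destination nodes, number of source nodes). A small self-contained illustration of that shape using scipy rather than DGL internals; the toy edge list and feature size are arbitrary.

import numpy as np
from scipy.sparse import coo_matrix

# A toy relation with 3 source nodes and 2 destination nodes.
u = np.array([0, 1, 2, 2])           # source endpoint of each edge
v = np.array([0, 0, 1, 0])           # destination endpoint of each edge
num_src, num_dst = 3, 2

# Message passing as SPMV: a (num_dst, num_src) incidence matrix, which is why
# build_gidx_and_mapping_uv now takes both counts instead of a single num_nodes.
adj = coo_matrix((np.ones(len(u)), (v, u)), shape=(num_dst, num_src))
h_src = np.random.randn(num_src, 4)  # source node features
h_dst = adj @ h_src                   # sum of incoming messages per destination
print(h_dst.shape)                    # (2, 4)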
This diff is collapsed.
@@ -52,11 +52,6 @@ enum BoolFlag {
 dgl::runtime::PackedFunc ConvertNDArrayVectorToPackedFunc(
     const std::vector<dgl::runtime::NDArray>& vec);

-/*!\brief Return whether the array is a valid 1D int array*/
-inline bool IsValidIdArray(const dgl::runtime::NDArray& arr) {
-  return arr->ndim == 1 && arr->dtype.code == kDLInt;
-}
-
 /*!
  * \brief Copy a vector to an int64_t NDArray.
  *
...
This diff is collapsed.
@@ -7,12 +7,14 @@
 #ifndef DGL_GRAPH_BIPARTITE_H_
 #define DGL_GRAPH_BIPARTITE_H_

-#include <dgl/graph_interface.h>
 #include <dgl/base_heterograph.h>
-#include <vector>
-#include <string>
+#include <dgl/lazy.h>
+#include <dgl/array.h>
 #include <utility>
-#include <memory>
+#include <string>
+#include <vector>
+
+#include "../c_api_common.h"

 namespace dgl {
@@ -32,6 +34,12 @@ class Bipartite : public BaseHeteroGraph {
   /*! \brief edge group type */
   static constexpr dgl_type_t kEType = 0;

+  // internal data structure
+  class COO;
+  class CSR;
+  typedef std::shared_ptr<COO> COOPtr;
+  typedef std::shared_ptr<CSR> CSRPtr;
+
   uint64_t NumVertexTypes() const override {
     return 2;
   }
@@ -140,14 +148,11 @@ class Bipartite : public BaseHeteroGraph {
       int64_t num_src, int64_t num_dst,
       IdArray indptr, IdArray indices, IdArray edge_ids);

- private:
-  // internal data structure
-  class COO;
-  class CSR;
-  typedef std::shared_ptr<COO> COOPtr;
-  typedef std::shared_ptr<CSR> CSRPtr;
-
-  Bipartite(CSRPtr in_csr, CSRPtr out_csr, COOPtr coo);
+  /*! \brief Convert the graph to use the given number of bits for storage */
+  static HeteroGraphPtr AsNumBits(HeteroGraphPtr g, uint8_t bits);
+
+  /*! \brief Copy the data to another context */
+  static HeteroGraphPtr CopyTo(HeteroGraphPtr g, const DLContext& ctx);

   /*! \return Return the in-edge CSR format. Create from other format if not exist. */
   CSRPtr GetInCSR() const;
@@ -158,6 +163,18 @@ class Bipartite : public BaseHeteroGraph {
   /*! \return Return the COO format. Create from other format if not exist. */
   COOPtr GetCOO() const;

+  /*! \return Return the in-edge CSR in the matrix form */
+  aten::CSRMatrix GetInCSRMatrix() const;
+
+  /*! \return Return the out-edge CSR in the matrix form */
+  aten::CSRMatrix GetOutCSRMatrix() const;
+
+  /*! \return Return the COO matrix form */
+  aten::COOMatrix GetCOOMatrix() const;
+
+ private:
+  Bipartite(CSRPtr in_csr, CSRPtr out_csr, COOPtr coo);
+
   /*! \return Return any existing format. */
   HeteroGraphPtr GetAny() const;
...
This diff is collapsed.
@@ -249,8 +249,8 @@ std::vector<GraphPtr> GraphOp::DisjointPartitionBySizes(
 }

 IdArray GraphOp::MapParentIdToSubgraphId(IdArray parent_vids, IdArray query) {
-  CHECK(IsValidIdArray(parent_vids)) << "Invalid parent id array.";
-  CHECK(IsValidIdArray(query)) << "Invalid query id array.";
+  CHECK(aten::IsValidIdArray(parent_vids)) << "Invalid parent id array.";
+  CHECK(aten::IsValidIdArray(query)) << "Invalid query id array.";
   const auto parent_len = parent_vids->shape[0];
   const auto query_len = query->shape[0];
   const dgl_id_t* parent_data = static_cast<dgl_id_t*>(parent_vids->data);
...
This diff is collapsed.
This diff is collapsed.
@@ -749,7 +749,7 @@ std::vector<NodeFlow> NeighborSamplingImpl(const ImmutableGraphPtr gptr,
                                            const bool add_self_loop,
                                            const ValueType *probability) {
   // process args
-  CHECK(IsValidIdArray(seed_nodes));
+  CHECK(aten::IsValidIdArray(seed_nodes));
   const dgl_id_t* seed_nodes_data = static_cast<dgl_id_t*>(seed_nodes->data);
   const int64_t num_seeds = seed_nodes->shape[0];
   const int64_t num_workers = std::min(max_num_workers,
@@ -859,7 +859,7 @@ DGL_REGISTER_GLOBAL("sampling._CAPI_LayerSampling")
     // process args
     auto gptr = std::dynamic_pointer_cast<ImmutableGraph>(g.sptr());
     CHECK(gptr) << "sampling isn't implemented in mutable graph";
-    CHECK(IsValidIdArray(seed_nodes));
+    CHECK(aten::IsValidIdArray(seed_nodes));
     const dgl_id_t* seed_nodes_data = static_cast<dgl_id_t*>(seed_nodes->data);
     const int64_t num_seeds = seed_nodes->shape[0];
     const int64_t num_workers = std::min(max_num_workers,
@@ -1017,7 +1017,7 @@ DGL_REGISTER_GLOBAL("sampling._CAPI_UniformEdgeSampling")
     // process args
     auto gptr = std::dynamic_pointer_cast<ImmutableGraph>(g.sptr());
     CHECK(gptr) << "sampling isn't implemented in mutable graph";
-    CHECK(IsValidIdArray(seed_edges));
+    CHECK(aten::IsValidIdArray(seed_edges));
     BuildCoo(*gptr);
     const int64_t num_seeds = seed_edges->shape[0];
...
@@ -197,7 +197,7 @@ DGL_REGISTER_GLOBAL("traversal._CAPI_DGLDFSEdges")
     GraphRef g = args[0];
     const IdArray source = args[1];
     const bool reversed = args[2];
-    CHECK(IsValidIdArray(source)) << "Invalid source node id array.";
+    CHECK(aten::IsValidIdArray(source)) << "Invalid source node id array.";
     const int64_t len = source->shape[0];
     const int64_t* src_data = static_cast<int64_t*>(source->data);
     std::vector<std::vector<dgl_id_t>> edges(len);
@@ -219,7 +219,7 @@ DGL_REGISTER_GLOBAL("traversal._CAPI_DGLDFSLabeledEdges")
     const bool has_nontree_edge = args[4];
     const bool return_labels = args[5];
-    CHECK(IsValidIdArray(source)) << "Invalid source node id array.";
+    CHECK(aten::IsValidIdArray(source)) << "Invalid source node id array.";
     const int64_t len = source->shape[0];
     const int64_t* src_data = static_cast<int64_t*>(source->data);
...
This diff is collapsed.