Unverified Commit bc3f852d authored by Gan Quan's avatar Gan Quan Committed by GitHub
Browse files

[GRAPHINDEX] Multigraph support (#79)

* multigraph support on graph index

* more tests

* multigraph flag, bugfix on clear & copy

* networkx interfaces

* including graph index tests in Jenkins

* node subgraph test

* edge subgraphs

* removing duplicates in pred/succ

* more explicit test and doc

* query source and destination from edge id

* subgraphindex

* renaming has_edge to has_edge_between, apply_edges adding eid

* send_on and send_and_recv_on

* DGLGraph edge subgraph

* merged send_on and send_and_recv_on

* change request

* removing hashmap

* creating multigraph by flag; mingw support

* changes per request

* reverting networkx auto multigraph discovery

* notes on send/send_and_recv on multigraphs

* changing test reducer from sum to max

* added a fixme note in spmv scheduler
parent 750e5037
...@@ -35,6 +35,7 @@ pipeline { ...@@ -35,6 +35,7 @@ pipeline {
withEnv(["DGL_LIBRARY_PATH=${env.WORKSPACE}/build"]) { withEnv(["DGL_LIBRARY_PATH=${env.WORKSPACE}/build"]) {
sh 'nosetests tests -v --with-xunit' sh 'nosetests tests -v --with-xunit'
sh 'nosetests tests/pytorch -v --with-xunit' sh 'nosetests tests/pytorch -v --with-xunit'
sh 'nosetests tests/graph_index -v --with-xunit'
} }
} }
} }
...@@ -77,6 +78,7 @@ pipeline { ...@@ -77,6 +78,7 @@ pipeline {
withEnv(["DGL_LIBRARY_PATH=${env.WORKSPACE}/build"]) { withEnv(["DGL_LIBRARY_PATH=${env.WORKSPACE}/build"]) {
sh 'nosetests tests -v --with-xunit' sh 'nosetests tests -v --with-xunit'
sh 'nosetests tests/pytorch -v --with-xunit' sh 'nosetests tests/pytorch -v --with-xunit'
sh 'nosetests tests/graph_index -v --with-xunit'
} }
} }
} }
......
...@@ -41,7 +41,7 @@ class Graph { ...@@ -41,7 +41,7 @@ class Graph {
} EdgeArray; } EdgeArray;
/*! \brief default constructor */ /*! \brief default constructor */
Graph() {} Graph(bool multigraph = false) : is_multigraph_(multigraph) {}
/*! \brief default copy constructor */ /*! \brief default copy constructor */
Graph(const Graph& other) = default; Graph(const Graph& other) = default;
...@@ -56,8 +56,9 @@ class Graph { ...@@ -56,8 +56,9 @@ class Graph {
all_edges_src_ = other.all_edges_src_; all_edges_src_ = other.all_edges_src_;
all_edges_dst_ = other.all_edges_dst_; all_edges_dst_ = other.all_edges_dst_;
read_only_ = other.read_only_; read_only_ = other.read_only_;
is_multigraph_ = other.is_multigraph_;
num_edges_ = other.num_edges_; num_edges_ = other.num_edges_;
other.clear(); other.Clear();
} }
#endif // _MSC_VER #endif // _MSC_VER
...@@ -101,6 +102,14 @@ class Graph { ...@@ -101,6 +102,14 @@ class Graph {
num_edges_ = 0; num_edges_ = 0;
} }
/*!
* \note not const since we have caches
* \return whether the graph is a multigraph
*/
bool IsMultigraph() const {
return is_multigraph_;
}
/*! \return the number of vertices in the graph.*/ /*! \return the number of vertices in the graph.*/
uint64_t NumVertices() const { uint64_t NumVertices() const {
return adjlist_.size(); return adjlist_.size();
...@@ -120,10 +129,10 @@ class Graph { ...@@ -120,10 +129,10 @@ class Graph {
BoolArray HasVertices(IdArray vids) const; BoolArray HasVertices(IdArray vids) const;
/*! \return true if the given edge is in the graph.*/ /*! \return true if the given edge is in the graph.*/
bool HasEdge(dgl_id_t src, dgl_id_t dst) const; bool HasEdgeBetween(dgl_id_t src, dgl_id_t dst) const;
/*! \return a 0-1 array indicating whether the given edges are in the graph.*/ /*! \return a 0-1 array indicating whether the given edges are in the graph.*/
BoolArray HasEdges(IdArray src_ids, IdArray dst_ids) const; BoolArray HasEdgesBetween(IdArray src_ids, IdArray dst_ids) const;
/*! /*!
* \brief Find the predecessors of a vertex. * \brief Find the predecessors of a vertex.
...@@ -142,22 +151,32 @@ class Graph { ...@@ -142,22 +151,32 @@ class Graph {
IdArray Successors(dgl_id_t vid, uint64_t radius = 1) const; IdArray Successors(dgl_id_t vid, uint64_t radius = 1) const;
/*! /*!
* \brief Get the edge id using the two endpoints * \brief Get all edge ids between the two given endpoints
* \note Edges are associated with an integer id start from zero. * \note Edges are associated with an integer id start from zero.
* The id is assigned when the edge is being added to the graph. * The id is assigned when the edge is being added to the graph.
* \param src The source vertex. * \param src The source vertex.
* \param dst The destination vertex. * \param dst The destination vertex.
* \return the edge id. * \return the edge id array.
*/ */
dgl_id_t EdgeId(dgl_id_t src, dgl_id_t dst) const; IdArray EdgeId(dgl_id_t src, dgl_id_t dst) const;
/*! /*!
* \brief Get the edge id using the two endpoints * \brief Get all edge ids between the given endpoint pairs.
* \note Edges are associated with an integer id start from zero. * \note Edges are associated with an integer id start from zero.
* The id is assigned when the edge is being added to the graph. * The id is assigned when the edge is being added to the graph.
* \return the edge id array. * If duplicate pairs exist, the returned edge IDs will also duplicate.
* The order of returned edge IDs will follow the order of src-dst pairs
* first, and ties are broken by the order of edge ID.
* \return EdgeArray containing all edges between all pairs.
*/ */
IdArray EdgeIds(IdArray src, IdArray dst) const; EdgeArray EdgeIds(IdArray src, IdArray dst) const;
/*!
* \brief Find the edge IDs and return their source and target node IDs.
* \param eids The edge ID array.
* \return EdgeArray containing all edges with id in eid. The order is preserved.
*/
EdgeArray FindEdges(IdArray eids) const;
/*! /*!
* \brief Get the in edges of the vertex. * \brief Get the in edges of the vertex.
...@@ -263,10 +282,10 @@ class Graph { ...@@ -263,10 +282,10 @@ class Graph {
* *
* The result subgraph is read-only. * The result subgraph is read-only.
* *
* \param vids The edges in the subgraph. * \param eids The edges in the subgraph.
* \return the induced edge subgraph * \return the induced edge subgraph
*/ */
Subgraph EdgeSubgraph(IdArray src, IdArray dst) const; Subgraph EdgeSubgraph(IdArray eids) const;
/*! /*!
* \brief Return a new graph with all the edges reversed. * \brief Return a new graph with all the edges reversed.
...@@ -300,6 +319,12 @@ class Graph { ...@@ -300,6 +319,12 @@ class Graph {
/*! \brief read only flag */ /*! \brief read only flag */
bool read_only_ = false; bool read_only_ = false;
/*!
* \brief Whether if this is a multigraph.
*
* When a multiedge is added, this flag switches to true.
*/
bool is_multigraph_ = false;
/*! \brief number of edges */ /*! \brief number of edges */
uint64_t num_edges_ = 0; uint64_t num_edges_ = 0;
}; };
......
...@@ -32,19 +32,23 @@ class DGLGraph(object): ...@@ -32,19 +32,23 @@ class DGLGraph(object):
Node feature storage. Node feature storage.
edge_frame : FrameRef edge_frame : FrameRef
Edge feature storage. Edge feature storage.
multigraph : bool, optional
Whether the graph would be a multigraph (default: False)
""" """
def __init__(self, def __init__(self,
graph_data=None, graph_data=None,
node_frame=None, node_frame=None,
edge_frame=None): edge_frame=None,
multigraph=False):
# graph # graph
self._graph = create_graph_index(graph_data) self._graph = create_graph_index(graph_data, multigraph)
# frame # frame
self._node_frame = node_frame if node_frame is not None else FrameRef() self._node_frame = node_frame if node_frame is not None else FrameRef()
self._edge_frame = edge_frame if edge_frame is not None else FrameRef() self._edge_frame = edge_frame if edge_frame is not None else FrameRef()
# msg graph & frame # msg graph & frame
self._msg_graph = create_graph_index() self._msg_graph = create_graph_index(multigraph=multigraph)
self._msg_frame = FrameRef() self._msg_frame = FrameRef()
self._msg_edges = []
self.reset_messages() self.reset_messages()
# registered functions # registered functions
self._message_func = None self._message_func = None
...@@ -109,11 +113,13 @@ class DGLGraph(object): ...@@ -109,11 +113,13 @@ class DGLGraph(object):
self._edge_frame.clear() self._edge_frame.clear()
self._msg_graph.clear() self._msg_graph.clear()
self._msg_frame.clear() self._msg_frame.clear()
self._msg_edges.clear()
def reset_messages(self): def reset_messages(self):
"""Clear all messages.""" """Clear all messages."""
self._msg_graph.clear() self._msg_graph.clear()
self._msg_frame.clear() self._msg_frame.clear()
self._msg_edges.clear()
self._msg_graph.add_nodes(self.number_of_nodes()) self._msg_graph.add_nodes(self.number_of_nodes())
def number_of_nodes(self): def number_of_nodes(self):
...@@ -130,6 +136,12 @@ class DGLGraph(object): ...@@ -130,6 +136,12 @@ class DGLGraph(object):
"""Return the number of nodes.""" """Return the number of nodes."""
return self.number_of_nodes() return self.number_of_nodes()
@property
def is_multigraph(self):
"""Whether the graph is a multigraph.
"""
return self._graph.is_multigraph()
def number_of_edges(self): def number_of_edges(self):
"""Return the number of edges. """Return the number of edges.
...@@ -176,7 +188,7 @@ class DGLGraph(object): ...@@ -176,7 +188,7 @@ class DGLGraph(object):
rst = self._graph.has_nodes(vids) rst = self._graph.has_nodes(vids)
return rst.tousertensor() return rst.tousertensor()
def has_edge(self, u, v): def has_edge_between(self, u, v):
"""Return true if the edge exists. """Return true if the edge exists.
Parameters Parameters
...@@ -191,9 +203,9 @@ class DGLGraph(object): ...@@ -191,9 +203,9 @@ class DGLGraph(object):
bool bool
True if the edge exists True if the edge exists
""" """
return self._graph.has_edge(u, v) return self._graph.has_edge_between(u, v)
def has_edges(self, u, v): def has_edges_between(self, u, v):
"""Return true if the edge exists. """Return true if the edge exists.
Parameters Parameters
...@@ -210,7 +222,7 @@ class DGLGraph(object): ...@@ -210,7 +222,7 @@ class DGLGraph(object):
""" """
u = utils.toindex(u) u = utils.toindex(u)
v = utils.toindex(v) v = utils.toindex(v)
rst = self._graph.has_edges(u, v) rst = self._graph.has_edges_between(u, v)
return rst.tousertensor() return rst.tousertensor()
def predecessors(self, v, radius=1): def predecessors(self, v, radius=1):
...@@ -247,7 +259,7 @@ class DGLGraph(object): ...@@ -247,7 +259,7 @@ class DGLGraph(object):
""" """
return self._graph.successors(v).tousertensor() return self._graph.successors(v).tousertensor()
def edge_id(self, u, v): def edge_id(self, u, v, force_multi=False):
"""Return the id of the edge. """Return the id of the edge.
Parameters Parameters
...@@ -256,15 +268,20 @@ class DGLGraph(object): ...@@ -256,15 +268,20 @@ class DGLGraph(object):
The src node. The src node.
v : int v : int
The dst node. The dst node.
force_multi : bool
If False, will return a single edge ID if the graph is a simple graph.
If True, will always return an array.
Returns Returns
------- -------
int int or tensor
The edge id. The edge id if force_multi == True and the graph is a simple graph.
The edge id array otherwise.
""" """
return self._graph.edge_id(u, v) idx = self._graph.edge_id(u, v)
return idx.tousertensor() if force_multi or self.is_multigraph else idx[0]
def edge_ids(self, u, v): def edge_ids(self, u, v, force_multi=False):
"""Return the edge ids. """Return the edge ids.
Parameters Parameters
...@@ -273,16 +290,23 @@ class DGLGraph(object): ...@@ -273,16 +290,23 @@ class DGLGraph(object):
The src nodes. The src nodes.
v : list, tensor v : list, tensor
The dst nodes. The dst nodes.
force_multi : bool
If False, will return a single edge ID array if the graph is a simple graph.
If True, will always return 3 arrays (src nodes, dst nodes, edge ids).
Returns Returns
------- -------
tensor tensor, or (tensor, tensor, tensor)
The edge id array. If force_multi is True or the graph is multigraph, return (src nodes, dst nodes, edge ids)
Otherwise, return a single tensor of edge ids.
""" """
u = utils.toindex(u) u = utils.toindex(u)
v = utils.toindex(v) v = utils.toindex(v)
rst = self._graph.edge_ids(u, v) src, dst, eid = self._graph.edge_ids(u, v)
return rst.tousertensor() if force_multi or self.is_multigraph:
return src.tousertensor(), dst.tousertensor(), eid.tousertensor()
else:
return eid.tousertensor()
def in_edges(self, v): def in_edges(self, v):
"""Return the in edges of the node(s). """Return the in edges of the node(s).
...@@ -612,7 +636,7 @@ class DGLGraph(object): ...@@ -612,7 +636,7 @@ class DGLGraph(object):
else: else:
u = utils.toindex(u) u = utils.toindex(u)
v = utils.toindex(v) v = utils.toindex(v)
eid = self._graph.edge_ids(u, v) _, _, eid = self._graph.edge_ids(u, v)
self.set_e_repr_by_id(h_uv, eid=eid) self.set_e_repr_by_id(h_uv, eid=eid)
def set_e_repr_by_id(self, h_uv, eid=ALL): def set_e_repr_by_id(self, h_uv, eid=ALL):
...@@ -674,7 +698,7 @@ class DGLGraph(object): ...@@ -674,7 +698,7 @@ class DGLGraph(object):
else: else:
u = utils.toindex(u) u = utils.toindex(u)
v = utils.toindex(v) v = utils.toindex(v)
eid = self._graph.edge_ids(u, v) _, _, eid = self._graph.edge_ids(u, v)
return self.get_e_repr_by_id(eid=eid) return self.get_e_repr_by_id(eid=eid)
def pop_e_repr(self, key=__REPR__): def pop_e_repr(self, key=__REPR__):
...@@ -809,28 +833,35 @@ class DGLGraph(object): ...@@ -809,28 +833,35 @@ class DGLGraph(object):
new_repr = reduce_accum new_repr = reduce_accum
self.set_n_repr(new_repr, v) self.set_n_repr(new_repr, v)
def apply_edges(self, u, v, apply_edge_func="default"): def apply_edges(self, u=None, v=None, apply_edge_func="default", eid=None):
"""Apply the function on edge representations. """Apply the function on edge representations.
Parameters Parameters
---------- ----------
u : int, iterable of int, tensor u : optional, int, iterable of int, tensor
The src node id(s). The src node id(s).
v : int, iterable of int, tensor v : optional, int, iterable of int, tensor
The dst node id(s). The dst node id(s).
apply_edge_func : callable apply_edge_func : callable
The apply edge function. The apply edge function.
eid : None, edge, container or tensor
The edge to update on. If eid is not None then u and v are ignored.
""" """
if apply_edge_func == "default": if apply_edge_func == "default":
apply_edge_func = self._apply_edge_func apply_edge_func = self._apply_edge_func
if not apply_edge_func: if not apply_edge_func:
# Skip none function call. # Skip none function call.
return return
new_repr = apply_edge_func(self.get_e_repr(u, v))
self.set_e_repr(new_repr, u, v)
def send(self, u, v, message_func="default"): if eid is None:
"""Trigger the message function on edge u->v new_repr = apply_edge_func(self.get_e_repr(u, v))
self.set_e_repr(new_repr, u, v)
else:
new_repr = apply_edge_func(self.get_e_repr_by_id(eid))
self.set_e_repr_by_id(new_repr, eid)
def send(self, u=None, v=None, message_func="default", eid=None):
"""Trigger the message function on edge u->v or eid
The message function should be compatible with following signature: The message function should be compatible with following signature:
...@@ -842,45 +873,105 @@ class DGLGraph(object): ...@@ -842,45 +873,105 @@ class DGLGraph(object):
The message function can be any of the pre-defined functions The message function can be any of the pre-defined functions
('from_src'). ('from_src').
Currently, we require the message functions of consecutive send's and
send_on's to return the same keys. Otherwise the behavior will be
undefined.
Parameters Parameters
---------- ----------
u : node, container or tensor u : optional, node, container or tensor
The source node(s). The source node(s).
v : node, container or tensor v : optional, node, container or tensor
The destination node(s). The destination node(s).
message_func : callable message_func : callable
The message function. The message function.
eid : optional, edge, container or tensor
The edge to update on. If eid is not None then u and v are ignored.
Notes
-----
On multigraphs, if u and v are specified, then the messages will be sent
along all edges between u and v.
""" """
if message_func == "default": if message_func == "default":
message_func = self._message_func message_func = self._message_func
assert message_func is not None assert message_func is not None
if isinstance(message_func, (tuple, list)): if isinstance(message_func, (tuple, list)):
message_func = BundledMessageFunction(message_func) message_func = BundledMessageFunction(message_func)
self._batch_send(u, v, message_func) self._batch_send(u, v, eid, message_func)
def _batch_send(self, u, v, message_func): def _batch_send(self, u, v, eid, message_func):
if is_all(u) and is_all(v): if is_all(u) and is_all(v) and eid is None:
u, v, _ = self._graph.edges() u, v, eid = self._graph.edges()
self._msg_graph.add_edges(u, v) # TODO(minjie): can be optimized
# call UDF # call UDF
src_reprs = self.get_n_repr(u) src_reprs = self.get_n_repr(u)
edge_reprs = self.get_e_repr() edge_reprs = self.get_e_repr()
msgs = message_func(src_reprs, edge_reprs) msgs = message_func(src_reprs, edge_reprs)
elif eid is not None:
eid = utils.toindex(eid)
u, v, _ = self._graph.find_edges(eid)
# call UDF
src_reprs = self.get_n_repr(u)
edge_reprs = self.get_e_repr_by_id(eid)
msgs = message_func(src_reprs, edge_reprs)
else: else:
u = utils.toindex(u) u = utils.toindex(u)
v = utils.toindex(v) v = utils.toindex(v)
u, v = utils.edge_broadcasting(u, v) u, v, eid = self._graph.edge_ids(u, v)
self._msg_graph.add_edges(u, v)
# call UDF # call UDF
src_reprs = self.get_n_repr(u) src_reprs = self.get_n_repr(u)
edge_reprs = self.get_e_repr(u, v) edge_reprs = self.get_e_repr_by_id(eid)
msgs = message_func(src_reprs, edge_reprs) msgs = message_func(src_reprs, edge_reprs)
new_uv = []
msg_target_rows = []
msg_update_rows = []
msg_append_rows = []
for i, (_u, _v, _eid) in enumerate(zip(u, v, eid)):
if _eid in self._msg_edges:
msg_target_rows.append(self._msg_edges.index(_eid))
msg_update_rows.append(i)
else:
new_uv.append((_u, _v))
self._msg_edges.append(_eid)
msg_append_rows.append(i)
msg_target_rows = utils.toindex(msg_target_rows)
msg_update_rows = utils.toindex(msg_update_rows)
msg_append_rows = utils.toindex(msg_append_rows)
if utils.is_dict_like(msgs): if utils.is_dict_like(msgs):
self._msg_frame.append(msgs) if len(msg_target_rows) > 0:
self._msg_frame.update_rows(
msg_target_rows,
{k: F.gather_row(msgs[k], msg_update_rows.tousertensor())
for k in msgs}
)
if len(msg_append_rows) > 0:
new_u, new_v = zip(*new_uv)
new_u = utils.toindex(new_u)
new_v = utils.toindex(new_v)
self._msg_graph.add_edges(new_u, new_v)
self._msg_frame.append(
{k: F.gather_row(msgs[k], msg_append_rows.tousertensor())
for k in msgs}
)
else: else:
self._msg_frame.append({__MSG__ : msgs}) if len(msg_target_rows) > 0:
self._msg_frame.update_rows(
def update_edge(self, u=ALL, v=ALL, edge_func="default"): msg_target_rows,
{__MSG__: F.gather_row(msgs, msg_update_rows.tousertensor())}
)
if len(msg_append_rows) > 0:
new_u, new_v = zip(*new_uv)
new_u = utils.toindex(new_u)
new_v = utils.toindex(new_v)
self._msg_graph.add_edges(new_u, new_v)
self._msg_frame.append(
{__MSG__: F.gather_row(msgs, msg_append_rows.tousertensor())}
)
def update_edge(self, u=ALL, v=ALL, edge_func="default", eid=None):
"""Update representation on edge u->v """Update representation on edge u->v
The edge function should be compatible with following signature: The edge function should be compatible with following signature:
...@@ -899,15 +990,17 @@ class DGLGraph(object): ...@@ -899,15 +990,17 @@ class DGLGraph(object):
The destination node(s). The destination node(s).
edge_func : callable edge_func : callable
The update function. The update function.
eid : optional, edge, container or tensor
The edge to update on. If eid is not None then u and v are ignored.
""" """
if edge_func == "default": if edge_func == "default":
edge_func = self._edge_func edge_func = self._edge_func
assert edge_func is not None assert edge_func is not None
self._batch_update_edge(u, v, edge_func) self._batch_update_edge(u, v, eid, edge_func)
def _batch_update_edge(self, u, v, edge_func): def _batch_update_edge(self, u, v, eid, edge_func):
if is_all(u) and is_all(v): if is_all(u) and is_all(v) and eid is None:
u, v, _ = self._graph.edges() u, v, eid = self._graph.edges()
# call the UDF # call the UDF
src_reprs = self.get_n_repr(u) src_reprs = self.get_n_repr(u)
dst_reprs = self.get_n_repr(v) dst_reprs = self.get_n_repr(v)
...@@ -915,10 +1008,12 @@ class DGLGraph(object): ...@@ -915,10 +1008,12 @@ class DGLGraph(object):
new_edge_reprs = edge_func(src_reprs, dst_reprs, edge_reprs) new_edge_reprs = edge_func(src_reprs, dst_reprs, edge_reprs)
self.set_e_repr(new_edge_reprs) self.set_e_repr(new_edge_reprs)
else: else:
u = utils.toindex(u) if eid is None:
v = utils.toindex(v) u = utils.toindex(u)
u, v = utils.edge_broadcasting(u, v) v = utils.toindex(v)
eid = self._graph.edge_ids(u, v) u, v = utils.edge_broadcasting(u, v)
_, _, eid = self._graph.edge_ids(u, v)
# call the UDF # call the UDF
src_reprs = self.get_n_repr(u) src_reprs = self.get_n_repr(u)
dst_reprs = self.get_n_repr(v) dst_reprs = self.get_n_repr(v)
...@@ -1036,17 +1131,19 @@ class DGLGraph(object): ...@@ -1036,17 +1131,19 @@ class DGLGraph(object):
self.set_n_repr(new_reprs, reordered_v) self.set_n_repr(new_reprs, reordered_v)
def send_and_recv(self, def send_and_recv(self,
u, v, u=None, v=None,
message_func="default", message_func="default",
reduce_func="default", reduce_func="default",
apply_node_func="default"): apply_node_func="default",
"""Trigger the message function on u->v and update v. eid=None):
"""Trigger the message function on u->v and update v, or on edge eid
and update the destination nodes.
Parameters Parameters
---------- ----------
u : node, container or tensor u : optional, node, container or tensor
The source node(s). The source node(s).
v : node, container or tensor v : optional, node, container or tensor
The destination node(s). The destination node(s).
message_func : callable message_func : callable
The message function. The message function.
...@@ -1054,14 +1151,12 @@ class DGLGraph(object): ...@@ -1054,14 +1151,12 @@ class DGLGraph(object):
The reduce function. The reduce function.
apply_node_func : callable, optional apply_node_func : callable, optional
The update function. The update function.
"""
u = utils.toindex(u)
v = utils.toindex(v)
if len(u) == 0:
# no edges to be triggered
assert len(v) == 0
return
Notes
-----
On multigraphs, if u and v are specified, then the messages will be sent
and received along all edges between u and v.
"""
if message_func == "default": if message_func == "default":
message_func = self._message_func message_func = self._message_func
if reduce_func == "default": if reduce_func == "default":
...@@ -1069,14 +1164,42 @@ class DGLGraph(object): ...@@ -1069,14 +1164,42 @@ class DGLGraph(object):
assert message_func is not None assert message_func is not None
assert reduce_func is not None assert reduce_func is not None
executor = scheduler.get_executor( if eid is None:
'send_and_recv', self, src=u, dst=v, if u is None or v is None:
message_func=message_func, reduce_func=reduce_func) raise ValueError('u and v must be given if eid is None')
u = utils.toindex(u)
v = utils.toindex(v)
if len(u) == 0:
# no edges to be triggered
assert len(v) == 0
return
unique_v = utils.toindex(F.unique(v.tousertensor()))
executor = scheduler.get_executor(
'send_and_recv', self, src=u, dst=v,
message_func=message_func, reduce_func=reduce_func)
else:
eid = utils.toindex(eid)
if len(eid) == 0:
# no edges to be triggered
return
executor = None
if executor: if executor:
new_reprs = executor.run() new_reprs = executor.run()
if not utils.is_dict_like(new_reprs): if not utils.is_dict_like(new_reprs):
new_reprs = {__REPR__: new_reprs} new_reprs = {__REPR__: new_reprs}
unique_v = executor.recv_nodes unique_v = executor.recv_nodes
self._apply_nodes(unique_v, apply_node_func, reduce_accum=new_reprs)
elif eid is not None:
_, v, _ = self._graph.find_edges(eid)
unique_v = utils.toindex(F.unique(v.tousertensor()))
# TODO: replace with the new DegreeBucketingScheduler
self.send(eid=eid, message_func=message_func)
self.recv(unique_v, reduce_func, apply_node_func)
else: else:
# handle multiple message and reduce func # handle multiple message and reduce func
if isinstance(message_func, (tuple, list)): if isinstance(message_func, (tuple, list)):
...@@ -1103,7 +1226,7 @@ class DGLGraph(object): ...@@ -1103,7 +1226,7 @@ class DGLGraph(object):
new_reprs = executor.run() new_reprs = executor.run()
unique_v = executor.recv_nodes unique_v = executor.recv_nodes
self._apply_nodes(unique_v, apply_node_func, reduce_accum=new_reprs) self._apply_nodes(unique_v, apply_node_func, reduce_accum=new_reprs)
def pull(self, def pull(self,
v, v,
...@@ -1238,8 +1361,25 @@ class DGLGraph(object): ...@@ -1238,8 +1361,25 @@ class DGLGraph(object):
The subgraph. The subgraph.
""" """
induced_nodes = utils.toindex(nodes) induced_nodes = utils.toindex(nodes)
gi, induced_edges = self._graph.node_subgraph(induced_nodes) sgi = self._graph.node_subgraph(induced_nodes)
return dgl.DGLSubGraph(self, induced_nodes, induced_edges, gi) return dgl.DGLSubGraph(self, sgi.induced_nodes, sgi.induced_edges, sgi)
def edge_subgraph(self, edges):
"""Generate the subgraph among the given edges.
Parameters
----------
edges : list, or iterable
A container of the edges to construct subgraph.
Returns
-------
G : DGLSubGraph
The subgraph.
"""
induced_edges = utils.toindex(edges)
sgi = self._graph.edge_subgraph(induced_edges)
return dgl.DGLSubGraph(self, sgi.induced_nodes, sgi.induced_edges, sgi)
def merge(self, subgraphs, reduce_func='sum'): def merge(self, subgraphs, reduce_func='sum'):
"""Merge subgraph features back to this parent graph. """Merge subgraph features back to this parent graph.
......
...@@ -72,6 +72,16 @@ class GraphIndex(object): ...@@ -72,6 +72,16 @@ class GraphIndex(object):
_CAPI_DGLGraphClear(self._handle) _CAPI_DGLGraphClear(self._handle)
self._cache.clear() self._cache.clear()
def is_multigraph(self):
"""Return whether the graph is a multigraph
Returns
-------
bool
True if it is a multigraph, False otherwise.
"""
return bool(_CAPI_DGLGraphIsMultigraph(self._handle))
def number_of_nodes(self): def number_of_nodes(self):
"""Return the number of nodes. """Return the number of nodes.
...@@ -103,9 +113,9 @@ class GraphIndex(object): ...@@ -103,9 +113,9 @@ class GraphIndex(object):
Returns Returns
------- -------
bool bool
True if the node exists True if the node exists, False otherwise.
""" """
return _CAPI_DGLGraphHasVertex(self._handle, vid) return bool(_CAPI_DGLGraphHasVertex(self._handle, vid))
def has_nodes(self, vids): def has_nodes(self, vids):
"""Return true if the nodes exist. """Return true if the nodes exist.
...@@ -123,7 +133,7 @@ class GraphIndex(object): ...@@ -123,7 +133,7 @@ class GraphIndex(object):
vid_array = vids.todgltensor() vid_array = vids.todgltensor()
return utils.toindex(_CAPI_DGLGraphHasVertices(self._handle, vid_array)) return utils.toindex(_CAPI_DGLGraphHasVertices(self._handle, vid_array))
def has_edge(self, u, v): def has_edge_between(self, u, v):
"""Return true if the edge exists. """Return true if the edge exists.
Parameters Parameters
...@@ -136,11 +146,11 @@ class GraphIndex(object): ...@@ -136,11 +146,11 @@ class GraphIndex(object):
Returns Returns
------- -------
bool bool
True if the edge exists True if the edge exists, False otherwise
""" """
return _CAPI_DGLGraphHasEdge(self._handle, u, v) return bool(_CAPI_DGLGraphHasEdgeBetween(self._handle, u, v))
def has_edges(self, u, v): def has_edges_between(self, u, v):
"""Return true if the edge exists. """Return true if the edge exists.
Parameters Parameters
...@@ -157,7 +167,7 @@ class GraphIndex(object): ...@@ -157,7 +167,7 @@ class GraphIndex(object):
""" """
u_array = u.todgltensor() u_array = u.todgltensor()
v_array = v.todgltensor() v_array = v.todgltensor()
return utils.toindex(_CAPI_DGLGraphHasEdges(self._handle, u_array, v_array)) return utils.toindex(_CAPI_DGLGraphHasEdgesBetween(self._handle, u_array, v_array))
def predecessors(self, v, radius=1): def predecessors(self, v, radius=1):
"""Return the predecessors of the node. """Return the predecessors of the node.
...@@ -194,7 +204,7 @@ class GraphIndex(object): ...@@ -194,7 +204,7 @@ class GraphIndex(object):
return utils.toindex(_CAPI_DGLGraphSuccessors(self._handle, v, radius)) return utils.toindex(_CAPI_DGLGraphSuccessors(self._handle, v, radius))
def edge_id(self, u, v): def edge_id(self, u, v):
"""Return the id of the edge. """Return the id array of all edges between u and v.
Parameters Parameters
---------- ----------
...@@ -205,13 +215,13 @@ class GraphIndex(object): ...@@ -205,13 +215,13 @@ class GraphIndex(object):
Returns Returns
------- -------
int utils.Index
The edge id. The edge id array.
""" """
return _CAPI_DGLGraphEdgeId(self._handle, u, v) return utils.toindex(_CAPI_DGLGraphEdgeId(self._handle, u, v))
def edge_ids(self, u, v): def edge_ids(self, u, v):
"""Return the edge ids. """Return a triplet of arrays that contains the edge IDs.
Parameters Parameters
---------- ----------
...@@ -223,11 +233,47 @@ class GraphIndex(object): ...@@ -223,11 +233,47 @@ class GraphIndex(object):
Returns Returns
------- -------
utils.Index utils.Index
Teh edge id array. The src nodes.
utils.Index
The dst nodes.
utils.Index
The edge ids.
""" """
u_array = u.todgltensor() u_array = u.todgltensor()
v_array = v.todgltensor() v_array = v.todgltensor()
return utils.toindex(_CAPI_DGLGraphEdgeIds(self._handle, u_array, v_array)) edge_array = _CAPI_DGLGraphEdgeIds(self._handle, u_array, v_array)
src = utils.toindex(edge_array(0))
dst = utils.toindex(edge_array(1))
eid = utils.toindex(edge_array(2))
return src, dst, eid
def find_edges(self, eid):
"""Return a triplet of arrays that contains the edge IDs.
Parameters
----------
eid : utils.Index
The edge ids.
Returns
-------
utils.Index
The src nodes.
utils.Index
The dst nodes.
utils.Index
The edge ids.
"""
eid_array = eid.todgltensor()
edge_array = _CAPI_DGLGraphFindEdges(self._handle, eid_array)
src = utils.toindex(edge_array(0))
dst = utils.toindex(edge_array(1))
eid = utils.toindex(edge_array(2))
return src, dst, eid
def in_edges(self, v): def in_edges(self, v):
"""Return the in edges of the node(s). """Return the in edges of the node(s).
...@@ -378,16 +424,32 @@ class GraphIndex(object): ...@@ -378,16 +424,32 @@ class GraphIndex(object):
Returns Returns
------- -------
GraphIndex SubgraphIndex
The subgraph index. The subgraph index.
utils.Index
The induced edge ids. This is also a map from new edge id to parent edge id.
""" """
v_array = v.todgltensor() v_array = v.todgltensor()
rst = _CAPI_DGLGraphVertexSubgraph(self._handle, v_array) rst = _CAPI_DGLGraphVertexSubgraph(self._handle, v_array)
gi = GraphIndex(rst(0))
induced_edges = utils.toindex(rst(2)) induced_edges = utils.toindex(rst(2))
return gi, induced_edges return SubgraphIndex(rst(0), self, v, induced_edges)
def edge_subgraph(self, e):
"""Return the induced edge subgraph.
Parameters
----------
e : utils.Index
The edges.
Returns
-------
SubgraphIndex
The subgraph index.
"""
e_array = e.todgltensor()
rst = _CAPI_DGLGraphEdgeSubgraph(self._handle, e_array)
gi = GraphIndex(rst(0))
induced_nodes = utils.toindex(rst(1))
return SubgraphIndex(rst(0), self, induced_nodes, e)
def adjacency_matrix(self): def adjacency_matrix(self):
"""Return the adjacency matrix representation of this graph. """Return the adjacency matrix representation of this graph.
...@@ -460,7 +522,7 @@ class GraphIndex(object): ...@@ -460,7 +522,7 @@ class GraphIndex(object):
The nx graph The nx graph
""" """
src, dst, eid = self.edges() src, dst, eid = self.edges()
ret = nx.DiGraph() ret = nx.MultiDiGraph() if self.is_multigraph() else nx.DiGraph()
for u, v, id in zip(src, dst, eid): for u, v, id in zip(src, dst, eid):
ret.add_edge(u, v, id=id) ret.add_edge(u, v, id=id)
return ret return ret
...@@ -477,8 +539,13 @@ class GraphIndex(object): ...@@ -477,8 +539,13 @@ class GraphIndex(object):
The nx graph The nx graph
""" """
self.clear() self.clear()
if not isinstance(nx_graph, nx.DiGraph):
nx_graph = nx.DiGraph(nx_graph) if not isinstance(nx_graph, nx.Graph):
nx_graph = (nx.MultiDiGraph(nx_graph) if self.is_multigraph()
else nx.DiGraph(nx_graph))
else:
nx_graph = nx_graph.to_directed()
num_nodes = nx_graph.number_of_nodes() num_nodes = nx_graph.number_of_nodes()
self.add_nodes(num_nodes) self.add_nodes(num_nodes)
has_edge_id = 'id' in next(iter(nx_graph.edges)) has_edge_id = 'id' in next(iter(nx_graph.edges))
...@@ -487,16 +554,16 @@ class GraphIndex(object): ...@@ -487,16 +554,16 @@ class GraphIndex(object):
src = np.zeros((num_edges,), dtype=np.int64) src = np.zeros((num_edges,), dtype=np.int64)
dst = np.zeros((num_edges,), dtype=np.int64) dst = np.zeros((num_edges,), dtype=np.int64)
for e, attr in nx_graph.edges.items: for e, attr in nx_graph.edges.items:
u, v = e # MultiDiGraph returns a triplet in e while DiGraph returns a pair
eid = attr['id'] eid = attr['id']
src[eid] = u src[eid] = e[0]
dst[eid] = v dst[eid] = e[1]
else: else:
src = [] src = []
dst = [] dst = []
for u, v in nx_graph.edges: for e in nx_graph.edges:
src.append(u) src.append(e[0])
dst.append(v) dst.append(e[1])
src = utils.toindex(src) src = utils.toindex(src)
dst = utils.toindex(dst) dst = utils.toindex(dst)
self.add_edges(src, dst) self.add_edges(src, dst)
...@@ -531,7 +598,32 @@ class GraphIndex(object): ...@@ -531,7 +598,32 @@ class GraphIndex(object):
""" """
handle = _CAPI_DGLGraphLineGraph(self._handle, backtracking) handle = _CAPI_DGLGraphLineGraph(self._handle, backtracking)
return GraphIndex(handle) return GraphIndex(handle)
class SubgraphIndex(GraphIndex):
def __init__(self, handle, parent, induced_nodes, induced_edges):
super().__init__(handle)
self._parent = parent
self._induced_nodes = induced_nodes
self._induced_edges = induced_edges
def add_nodes(self, num):
raise RuntimeError('Readonly graph. Mutation is not allowed.')
def add_edge(self, u, v):
raise RuntimeError('Readonly graph. Mutation is not allowed.')
def add_edges(self, u, v):
raise RuntimeError('Readonly graph. Mutation is not allowed.')
@property
def induced_edges(self):
return self._induced_edges
@property
def induced_nodes(self):
return self._induced_nodes
def disjoint_union(graphs): def disjoint_union(graphs):
"""Return a disjoint union of the input graphs. """Return a disjoint union of the input graphs.
...@@ -590,17 +682,20 @@ def disjoint_partition(graph, num_or_size_splits): ...@@ -590,17 +682,20 @@ def disjoint_partition(graph, num_or_size_splits):
graphs.append(GraphIndex(handle)) graphs.append(GraphIndex(handle))
return graphs return graphs
def create_graph_index(graph_data=None): def create_graph_index(graph_data=None, multigraph=False):
"""Create a graph index object. """Create a graph index object.
Parameters Parameters
---------- ----------
graph_data : graph data, optional graph_data : graph data, optional
Data to initialize graph. Same as networkx's semantics. Data to initialize graph. Same as networkx's semantics.
multigraph : bool, optional
Whether the graph is multigraph (default is False)
""" """
if isinstance(graph_data, GraphIndex): if isinstance(graph_data, GraphIndex):
return graph_data return graph_data
handle = _CAPI_DGLGraphCreate()
handle = _CAPI_DGLGraphCreate(multigraph)
gi = GraphIndex(handle) gi = GraphIndex(handle)
if graph_data is not None: if graph_data is not None:
gi.from_networkx(graph_data) gi.from_networkx(graph_data)
......
...@@ -416,6 +416,14 @@ class BundledSendRecvExecutor(BundledExecutor, SendRecvExecutor): ...@@ -416,6 +416,14 @@ class BundledSendRecvExecutor(BundledExecutor, SendRecvExecutor):
BundledExecutor.__init__(self, graph, mfunc, rfunc) BundledExecutor.__init__(self, graph, mfunc, rfunc)
def _is_spmv_supported(fn, graph=None): def _is_spmv_supported(fn, graph=None):
# FIXME: also take into account
# (1) which backend DGL is under.
# (2) whether the graph is a multigraph.
#
# Current SPMV optimizer assumes that duplicate entries are summed up
# in sparse matrices, which is the case for PyTorch but not MXNet.
# The result is that on multigraphs, SPMV can still work for reducer=sum
# and message=copy_src/src_mul_edge *only in PyTorch*.
if isinstance(fn, fmsg.MessageFunction): if isinstance(fn, fmsg.MessageFunction):
return fn.is_spmv_supported(graph) return fn.is_spmv_supported(graph)
elif isinstance(fn, fred.ReduceFunction): elif isinstance(fn, fred.ReduceFunction):
......
// Graph class implementation // Graph class implementation
#include <algorithm> #include <algorithm>
#include <unordered_map> #include <unordered_map>
#include <set>
#include <functional>
#include <dgl/graph.h> #include <dgl/graph.h>
namespace dgl { namespace dgl {
...@@ -21,11 +23,14 @@ void Graph::AddEdge(dgl_id_t src, dgl_id_t dst) { ...@@ -21,11 +23,14 @@ void Graph::AddEdge(dgl_id_t src, dgl_id_t dst) {
CHECK(!read_only_) << "Graph is read-only. Mutations are not allowed."; CHECK(!read_only_) << "Graph is read-only. Mutations are not allowed.";
CHECK(HasVertex(src) && HasVertex(dst)) CHECK(HasVertex(src) && HasVertex(dst))
<< "Invalid vertices: src=" << src << " dst=" << dst; << "Invalid vertices: src=" << src << " dst=" << dst;
dgl_id_t eid = num_edges_++; dgl_id_t eid = num_edges_++;
adjlist_[src].succ.push_back(dst); adjlist_[src].succ.push_back(dst);
adjlist_[src].edge_id.push_back(eid); adjlist_[src].edge_id.push_back(eid);
reverse_adjlist_[dst].succ.push_back(src); reverse_adjlist_[dst].succ.push_back(src);
reverse_adjlist_[dst].edge_id.push_back(eid); reverse_adjlist_[dst].edge_id.push_back(eid);
all_edges_src_.push_back(src); all_edges_src_.push_back(src);
all_edges_dst_.push_back(dst); all_edges_dst_.push_back(dst);
} }
...@@ -71,14 +76,14 @@ BoolArray Graph::HasVertices(IdArray vids) const { ...@@ -71,14 +76,14 @@ BoolArray Graph::HasVertices(IdArray vids) const {
} }
// O(E) // O(E)
bool Graph::HasEdge(dgl_id_t src, dgl_id_t dst) const { bool Graph::HasEdgeBetween(dgl_id_t src, dgl_id_t dst) const {
if (!HasVertex(src) || !HasVertex(dst)) return false; if (!HasVertex(src) || !HasVertex(dst)) return false;
const auto& succ = adjlist_[src].succ; const auto& succ = adjlist_[src].succ;
return std::find(succ.begin(), succ.end(), dst) != succ.end(); return std::find(succ.begin(), succ.end(), dst) != succ.end();
} }
// O(E*K) pretty slow // O(E*k) pretty slow
BoolArray Graph::HasEdges(IdArray src_ids, IdArray dst_ids) const { BoolArray Graph::HasEdgesBetween(IdArray src_ids, IdArray dst_ids) const {
CHECK(IsValidIdArray(src_ids)) << "Invalid src id array."; CHECK(IsValidIdArray(src_ids)) << "Invalid src id array.";
CHECK(IsValidIdArray(dst_ids)) << "Invalid dst id array."; CHECK(IsValidIdArray(dst_ids)) << "Invalid dst id array.";
const auto srclen = src_ids->shape[0]; const auto srclen = src_ids->shape[0];
...@@ -91,18 +96,18 @@ BoolArray Graph::HasEdges(IdArray src_ids, IdArray dst_ids) const { ...@@ -91,18 +96,18 @@ BoolArray Graph::HasEdges(IdArray src_ids, IdArray dst_ids) const {
if (srclen == 1) { if (srclen == 1) {
// one-many // one-many
for (int64_t i = 0; i < dstlen; ++i) { for (int64_t i = 0; i < dstlen; ++i) {
rst_data[i] = HasEdge(src_data[0], dst_data[i])? 1 : 0; rst_data[i] = HasEdgeBetween(src_data[0], dst_data[i])? 1 : 0;
} }
} else if (dstlen == 1) { } else if (dstlen == 1) {
// many-one // many-one
for (int64_t i = 0; i < srclen; ++i) { for (int64_t i = 0; i < srclen; ++i) {
rst_data[i] = HasEdge(src_data[i], dst_data[0])? 1 : 0; rst_data[i] = HasEdgeBetween(src_data[i], dst_data[0])? 1 : 0;
} }
} else { } else {
// many-many // many-many
CHECK(srclen == dstlen) << "Invalid src and dst id array."; CHECK(srclen == dstlen) << "Invalid src and dst id array.";
for (int64_t i = 0; i < srclen; ++i) { for (int64_t i = 0; i < srclen; ++i) {
rst_data[i] = HasEdge(src_data[i], dst_data[i])? 1 : 0; rst_data[i] = HasEdgeBetween(src_data[i], dst_data[i])? 1 : 0;
} }
} }
return rst; return rst;
...@@ -112,13 +117,16 @@ BoolArray Graph::HasEdges(IdArray src_ids, IdArray dst_ids) const { ...@@ -112,13 +117,16 @@ BoolArray Graph::HasEdges(IdArray src_ids, IdArray dst_ids) const {
IdArray Graph::Predecessors(dgl_id_t vid, uint64_t radius) const { IdArray Graph::Predecessors(dgl_id_t vid, uint64_t radius) const {
CHECK(HasVertex(vid)) << "invalid vertex: " << vid; CHECK(HasVertex(vid)) << "invalid vertex: " << vid;
CHECK(radius >= 1) << "invalid radius: " << radius; CHECK(radius >= 1) << "invalid radius: " << radius;
const auto& pred = reverse_adjlist_[vid].succ; std::set<dgl_id_t> vset;
const int64_t len = pred.size();
for (auto& it : reverse_adjlist_[vid].succ)
vset.insert(it);
const int64_t len = vset.size();
IdArray rst = IdArray::Empty({len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0}); IdArray rst = IdArray::Empty({len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
int64_t* rst_data = static_cast<int64_t*>(rst->data); int64_t* rst_data = static_cast<int64_t*>(rst->data);
for (int64_t i = 0; i < len; ++i) {
rst_data[i] = pred[i]; std::copy(vset.begin(), vset.end(), rst_data);
}
return rst; return rst;
} }
...@@ -126,58 +134,109 @@ IdArray Graph::Predecessors(dgl_id_t vid, uint64_t radius) const { ...@@ -126,58 +134,109 @@ IdArray Graph::Predecessors(dgl_id_t vid, uint64_t radius) const {
IdArray Graph::Successors(dgl_id_t vid, uint64_t radius) const { IdArray Graph::Successors(dgl_id_t vid, uint64_t radius) const {
CHECK(HasVertex(vid)) << "invalid vertex: " << vid; CHECK(HasVertex(vid)) << "invalid vertex: " << vid;
CHECK(radius >= 1) << "invalid radius: " << radius; CHECK(radius >= 1) << "invalid radius: " << radius;
const auto& succ = adjlist_[vid].succ; std::set<dgl_id_t> vset;
const int64_t len = succ.size();
for (auto& it : adjlist_[vid].succ)
vset.insert(it);
const int64_t len = vset.size();
IdArray rst = IdArray::Empty({len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0}); IdArray rst = IdArray::Empty({len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
int64_t* rst_data = static_cast<int64_t*>(rst->data); int64_t* rst_data = static_cast<int64_t*>(rst->data);
for (int64_t i = 0; i < len; ++i) {
rst_data[i] = succ[i]; std::copy(vset.begin(), vset.end(), rst_data);
}
return rst; return rst;
} }
// O(E) // O(E)
dgl_id_t Graph::EdgeId(dgl_id_t src, dgl_id_t dst) const { IdArray Graph::EdgeId(dgl_id_t src, dgl_id_t dst) const {
CHECK(HasVertex(src)) << "invalid edge: " << src << " -> " << dst; CHECK(HasVertex(src) && HasVertex(dst)) << "invalid edge: " << src << " -> " << dst;
const auto& succ = adjlist_[src].succ; const auto& succ = adjlist_[src].succ;
std::vector<dgl_id_t> edgelist;
for (size_t i = 0; i < succ.size(); ++i) { for (size_t i = 0; i < succ.size(); ++i) {
if (succ[i] == dst) { if (succ[i] == dst)
return adjlist_[src].edge_id[i]; edgelist.push_back(adjlist_[src].edge_id[i]);
}
} }
LOG(FATAL) << "invalid edge: " << src << " -> " << dst;
return 0; // FIXME: signed? Also it seems that we are using int64_t everywhere...
const int64_t len = edgelist.size();
IdArray rst = IdArray::Empty({len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
// FIXME: signed?
int64_t* rst_data = static_cast<int64_t*>(rst->data);
std::copy(edgelist.begin(), edgelist.end(), rst_data);
return rst;
} }
// O(E*k) pretty slow // O(E*k) pretty slow
IdArray Graph::EdgeIds(IdArray src_ids, IdArray dst_ids) const { Graph::EdgeArray Graph::EdgeIds(IdArray src_ids, IdArray dst_ids) const {
CHECK(IsValidIdArray(src_ids)) << "Invalid src id array."; CHECK(IsValidIdArray(src_ids)) << "Invalid src id array.";
CHECK(IsValidIdArray(dst_ids)) << "Invalid dst id array."; CHECK(IsValidIdArray(dst_ids)) << "Invalid dst id array.";
const auto srclen = src_ids->shape[0]; const auto srclen = src_ids->shape[0];
const auto dstlen = dst_ids->shape[0]; const auto dstlen = dst_ids->shape[0];
const auto rstlen = std::max(srclen, dstlen); int64_t i, j;
IdArray rst = IdArray::Empty({rstlen}, src_ids->dtype, src_ids->ctx);
int64_t* rst_data = static_cast<int64_t*>(rst->data); CHECK((srclen == dstlen) || (srclen == 1) || (dstlen == 1))
<< "Invalid src and dst id array.";
const int64_t src_stride = (srclen == 1 && dstlen != 1) ? 0 : 1;
const int64_t dst_stride = (dstlen == 1 && srclen != 1) ? 0 : 1;
const int64_t* src_data = static_cast<int64_t*>(src_ids->data); const int64_t* src_data = static_cast<int64_t*>(src_ids->data);
const int64_t* dst_data = static_cast<int64_t*>(dst_ids->data); const int64_t* dst_data = static_cast<int64_t*>(dst_ids->data);
if (srclen == 1) {
// one-many std::vector<dgl_id_t> src, dst, eid;
for (int64_t i = 0; i < dstlen; ++i) {
rst_data[i] = EdgeId(src_data[0], dst_data[i]); for (i = 0, j = 0; i < srclen && j < dstlen; i += src_stride, j += dst_stride) {
} const dgl_id_t src_id = src_data[i], dst_id = dst_data[j];
} else if (dstlen == 1) { const auto& succ = adjlist_[src_id].succ;
// many-one for (size_t k = 0; k < succ.size(); ++k) {
for (int64_t i = 0; i < srclen; ++i) { if (succ[k] == dst_id) {
rst_data[i] = EdgeId(src_data[i], dst_data[0]); src.push_back(src_id);
} dst.push_back(dst_id);
} else { eid.push_back(adjlist_[src_id].edge_id[k]);
// many-many }
CHECK(srclen == dstlen) << "Invalid src and dst id array.";
for (int64_t i = 0; i < srclen; ++i) {
rst_data[i] = EdgeId(src_data[i], dst_data[i]);
} }
} }
return rst;
int64_t rstlen = src.size();
IdArray rst_src = IdArray::Empty({rstlen}, src_ids->dtype, src_ids->ctx);
IdArray rst_dst = IdArray::Empty({rstlen}, src_ids->dtype, src_ids->ctx);
IdArray rst_eid = IdArray::Empty({rstlen}, src_ids->dtype, src_ids->ctx);
int64_t* rst_src_data = static_cast<int64_t*>(rst_src->data);
int64_t* rst_dst_data = static_cast<int64_t*>(rst_dst->data);
int64_t* rst_eid_data = static_cast<int64_t*>(rst_eid->data);
std::copy(src.begin(), src.end(), rst_src_data);
std::copy(dst.begin(), dst.end(), rst_dst_data);
std::copy(eid.begin(), eid.end(), rst_eid_data);
return EdgeArray{rst_src, rst_dst, rst_eid};
}
Graph::EdgeArray Graph::FindEdges(IdArray eids) const {
int64_t len = eids->shape[0];
IdArray rst_src = IdArray::Empty({len}, eids->dtype, eids->ctx);
IdArray rst_dst = IdArray::Empty({len}, eids->dtype, eids->ctx);
IdArray rst_eid = IdArray::Empty({len}, eids->dtype, eids->ctx);
int64_t* eid_data = static_cast<int64_t*>(eids->data);
int64_t* rst_src_data = static_cast<int64_t*>(rst_src->data);
int64_t* rst_dst_data = static_cast<int64_t*>(rst_dst->data);
int64_t* rst_eid_data = static_cast<int64_t*>(rst_eid->data);
for (uint64_t i = 0; i < (uint64_t)len; ++i) {
dgl_id_t eid = eid_data[i];
if (eid >= num_edges_)
LOG(FATAL) << "invalid edge id:" << eid;
rst_src_data[i] = all_edges_src_[eid];
rst_dst_data[i] = all_edges_dst_[eid];
rst_eid_data[i] = eid;
}
return EdgeArray{rst_src, rst_dst, rst_eid};
} }
// O(E) // O(E)
...@@ -375,9 +434,37 @@ Subgraph Graph::VertexSubgraph(IdArray vids) const { ...@@ -375,9 +434,37 @@ Subgraph Graph::VertexSubgraph(IdArray vids) const {
return rst; return rst;
} }
Subgraph Graph::EdgeSubgraph(IdArray src, IdArray dst) const { Subgraph Graph::EdgeSubgraph(IdArray eids) const {
LOG(FATAL) << "not implemented"; CHECK(IsValidIdArray(eids)) << "Invalid vertex id array.";
return Subgraph();
const auto len = eids->shape[0];
std::unordered_map<dgl_id_t, dgl_id_t> oldv2newv;
std::vector<dgl_id_t> nodes;
const int64_t* eid_data = static_cast<int64_t*>(eids->data);
for (int64_t i = 0; i < len; ++i) {
dgl_id_t src_id = all_edges_src_[eid_data[i]];
dgl_id_t dst_id = all_edges_dst_[eid_data[i]];
if (oldv2newv.insert(std::make_pair(src_id, oldv2newv.size())).second)
nodes.push_back(src_id);
if (oldv2newv.insert(std::make_pair(dst_id, oldv2newv.size())).second)
nodes.push_back(dst_id);
}
Subgraph rst;
rst.induced_edges = eids;
rst.graph.AddVertices(nodes.size());
for (int64_t i = 0; i < len; ++i) {
dgl_id_t src_id = all_edges_src_[eid_data[i]];
dgl_id_t dst_id = all_edges_dst_[eid_data[i]];
rst.graph.AddEdge(oldv2newv[src_id], oldv2newv[dst_id]);
}
rst.induced_vertices = IdArray::Empty({static_cast<int64_t>(nodes.size())}, eids->dtype, eids->ctx);
std::copy(nodes.begin(), nodes.end(), static_cast<int64_t*>(rst.induced_vertices->data));
return rst;
} }
Graph Graph::Reverse() const { Graph Graph::Reverse() const {
......
...@@ -52,7 +52,8 @@ PackedFunc ConvertSubgraphToPackedFunc(const Subgraph& sg) { ...@@ -52,7 +52,8 @@ PackedFunc ConvertSubgraphToPackedFunc(const Subgraph& sg) {
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphCreate") TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphCreate")
.set_body([] (TVMArgs args, TVMRetValue* rv) { .set_body([] (TVMArgs args, TVMRetValue* rv) {
GraphHandle ghandle = new Graph(); bool multigraph = static_cast<bool>(args[0]);
GraphHandle ghandle = new Graph(multigraph);
*rv = ghandle; *rv = ghandle;
}); });
...@@ -96,6 +97,14 @@ TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphClear") ...@@ -96,6 +97,14 @@ TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphClear")
gptr->Clear(); gptr->Clear();
}); });
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphIsMultigraph")
.set_body([] (TVMArgs args, TVMRetValue *rv) {
GraphHandle ghandle = args[0];
// NOTE: not const since we have caches
const Graph* gptr = static_cast<Graph*>(ghandle);
*rv = gptr->IsMultigraph();
});
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphNumVertices") TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphNumVertices")
.set_body([] (TVMArgs args, TVMRetValue* rv) { .set_body([] (TVMArgs args, TVMRetValue* rv) {
GraphHandle ghandle = args[0]; GraphHandle ghandle = args[0];
...@@ -126,22 +135,22 @@ TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphHasVertices") ...@@ -126,22 +135,22 @@ TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphHasVertices")
*rv = gptr->HasVertices(vids); *rv = gptr->HasVertices(vids);
}); });
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphHasEdge") TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphHasEdgeBetween")
.set_body([] (TVMArgs args, TVMRetValue* rv) { .set_body([] (TVMArgs args, TVMRetValue* rv) {
GraphHandle ghandle = args[0]; GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle); const Graph* gptr = static_cast<Graph*>(ghandle);
const dgl_id_t src = args[1]; const dgl_id_t src = args[1];
const dgl_id_t dst = args[2]; const dgl_id_t dst = args[2];
*rv = gptr->HasEdge(src, dst); *rv = gptr->HasEdgeBetween(src, dst);
}); });
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphHasEdges") TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphHasEdgesBetween")
.set_body([] (TVMArgs args, TVMRetValue* rv) { .set_body([] (TVMArgs args, TVMRetValue* rv) {
GraphHandle ghandle = args[0]; GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle); const Graph* gptr = static_cast<Graph*>(ghandle);
const IdArray src = IdArray::FromDLPack(CreateTmpDLManagedTensor(args[1])); const IdArray src = IdArray::FromDLPack(CreateTmpDLManagedTensor(args[1]));
const IdArray dst = IdArray::FromDLPack(CreateTmpDLManagedTensor(args[2])); const IdArray dst = IdArray::FromDLPack(CreateTmpDLManagedTensor(args[2]));
*rv = gptr->HasEdges(src, dst); *rv = gptr->HasEdgesBetween(src, dst);
}); });
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphPredecessors") TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphPredecessors")
...@@ -168,7 +177,7 @@ TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphEdgeId") ...@@ -168,7 +177,7 @@ TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphEdgeId")
const Graph* gptr = static_cast<Graph*>(ghandle); const Graph* gptr = static_cast<Graph*>(ghandle);
const dgl_id_t src = args[1]; const dgl_id_t src = args[1];
const dgl_id_t dst = args[2]; const dgl_id_t dst = args[2];
*rv = static_cast<int64_t>(gptr->EdgeId(src, dst)); *rv = gptr->EdgeId(src, dst);
}); });
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphEdgeIds") TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphEdgeIds")
...@@ -177,7 +186,15 @@ TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphEdgeIds") ...@@ -177,7 +186,15 @@ TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphEdgeIds")
const Graph* gptr = static_cast<Graph*>(ghandle); const Graph* gptr = static_cast<Graph*>(ghandle);
const IdArray src = IdArray::FromDLPack(CreateTmpDLManagedTensor(args[1])); const IdArray src = IdArray::FromDLPack(CreateTmpDLManagedTensor(args[1]));
const IdArray dst = IdArray::FromDLPack(CreateTmpDLManagedTensor(args[2])); const IdArray dst = IdArray::FromDLPack(CreateTmpDLManagedTensor(args[2]));
*rv = gptr->EdgeIds(src, dst); *rv = ConvertEdgeArrayToPackedFunc(gptr->EdgeIds(src, dst));
});
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphFindEdges")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle);
const IdArray eids = IdArray::FromDLPack(CreateTmpDLManagedTensor(args[1]));
*rv = ConvertEdgeArrayToPackedFunc(gptr->FindEdges(eids));
}); });
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphInEdges_1") TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphInEdges_1")
...@@ -260,6 +277,14 @@ TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphVertexSubgraph") ...@@ -260,6 +277,14 @@ TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphVertexSubgraph")
*rv = ConvertSubgraphToPackedFunc(gptr->VertexSubgraph(vids)); *rv = ConvertSubgraphToPackedFunc(gptr->VertexSubgraph(vids));
}); });
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphEdgeSubgraph")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
GraphHandle ghandle = args[0];
const Graph *gptr = static_cast<Graph*>(ghandle);
const IdArray eids = IdArray::FromDLPack(CreateTmpDLManagedTensor(args[1]));
*rv = ConvertSubgraphToPackedFunc(gptr->EdgeSubgraph(eids));
});
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLDisjointUnion") TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLDisjointUnion")
.set_body([] (TVMArgs args, TVMRetValue* rv) { .set_body([] (TVMArgs args, TVMRetValue* rv) {
void* list = args[0]; void* list = args[0];
......
...@@ -25,7 +25,7 @@ class CPUDeviceAPI final : public DeviceAPI { ...@@ -25,7 +25,7 @@ class CPUDeviceAPI final : public DeviceAPI {
size_t alignment, size_t alignment,
TVMType type_hint) final { TVMType type_hint) final {
void* ptr; void* ptr;
#if _MSC_VER #if _MSC_VER || defined(__MINGW32__)
ptr = _aligned_malloc(nbytes, alignment); ptr = _aligned_malloc(nbytes, alignment);
if (ptr == nullptr) throw std::bad_alloc(); if (ptr == nullptr) throw std::bad_alloc();
#elif defined(_LIBCPP_SGX_CONFIG) #elif defined(_LIBCPP_SGX_CONFIG)
...@@ -39,7 +39,7 @@ class CPUDeviceAPI final : public DeviceAPI { ...@@ -39,7 +39,7 @@ class CPUDeviceAPI final : public DeviceAPI {
} }
void FreeDataSpace(TVMContext ctx, void* ptr) final { void FreeDataSpace(TVMContext ctx, void* ptr) final {
#if _MSC_VER #if _MSC_VER || defined(__MINGW32__)
_aligned_free(ptr); _aligned_free(ptr);
#else #else
free(ptr); free(ptr);
......
...@@ -136,8 +136,8 @@ DLManagedTensor* NDArray::ToDLPack() const { ...@@ -136,8 +136,8 @@ DLManagedTensor* NDArray::ToDLPack() const {
} }
NDArray NDArray::Empty(std::vector<int64_t> shape, NDArray NDArray::Empty(std::vector<int64_t> shape,
DLDataType dtype, DLDataType dtype,
DLContext ctx) { DLContext ctx) {
NDArray ret = Internal::Create(shape, dtype, ctx); NDArray ret = Internal::Create(shape, dtype, ctx);
// setup memory content // setup memory content
size_t size = GetDataSize(ret.data_->dl_tensor); size_t size = GetDataSize(ret.data_->dl_tensor);
......
from dgl import DGLError
from dgl.utils import toindex
from dgl.graph_index import create_graph_index
import networkx as nx
def test_edge_id():
gi = create_graph_index(multigraph=False)
assert not gi.is_multigraph()
gi = create_graph_index(multigraph=True)
gi.add_nodes(4)
gi.add_edge(0, 1)
eid = gi.edge_id(0, 1).tolist()
assert len(eid) == 1
assert eid[0] == 0
assert gi.is_multigraph()
# multiedges
gi.add_edge(0, 1)
eid = gi.edge_id(0, 1).tolist()
assert len(eid) == 2
assert eid[0] == 0
assert eid[1] == 1
gi.add_edges(toindex([0, 1, 1, 2]), toindex([2, 2, 2, 3]))
src, dst, eid = gi.edge_ids(toindex([0, 0, 2, 1]), toindex([2, 1, 3, 2]))
eid_answer = [2, 0, 1, 5, 3, 4]
assert len(eid) == 6
assert all(e == ea for e, ea in zip(eid, eid_answer))
# find edges
src, dst, eid = gi.find_edges(toindex([1, 3, 5]))
assert len(src) == len(dst) == len(eid) == 3
assert src[0] == 0 and src[1] == 1 and src[2] == 2
assert dst[0] == 1 and dst[1] == 2 and dst[2] == 3
assert eid[0] == 1 and eid[1] == 3 and eid[2] == 5
# source broadcasting
src, dst, eid = gi.edge_ids(toindex([0]), toindex([1, 2]))
eid_answer = [0, 1, 2]
assert len(eid) == 3
assert all(e == ea for e, ea in zip(eid, eid_answer))
# destination broadcasting
src, dst, eid = gi.edge_ids(toindex([1, 0]), toindex([2]))
eid_answer = [3, 4, 2]
assert len(eid) == 3
assert all(e == ea for e, ea in zip(eid, eid_answer))
gi.clear()
# the following assumes that grabbing nonexistent edge will throw an error
try:
gi.edge_id(0, 1)
fail = True
except DGLError:
fail = False
finally:
assert not fail
gi.add_nodes(4)
gi.add_edge(0, 1)
eid = gi.edge_id(0, 1).tolist()
assert len(eid) == 1
assert eid[0] == 0
def test_nx():
gi = create_graph_index(multigraph=True)
gi.add_nodes(2)
gi.add_edge(0, 1)
nxg = gi.to_networkx()
assert len(nxg.nodes) == 2
assert len(nxg.edges(0, 1)) == 1
gi.add_edge(0, 1)
nxg = gi.to_networkx()
assert len(nxg.edges(0, 1)) == 2
nxg = nx.DiGraph()
nxg.add_edge(0, 1)
gi = create_graph_index(nxg)
assert not gi.is_multigraph()
assert gi.number_of_nodes() == 2
assert gi.number_of_edges() == 1
assert gi.edge_id(0, 1)[0] == 0
nxg = nx.MultiDiGraph()
nxg.add_edge(0, 1)
nxg.add_edge(0, 1)
gi = create_graph_index(nxg, True)
assert gi.is_multigraph()
assert gi.number_of_nodes() == 2
assert gi.number_of_edges() == 2
assert 0 in gi.edge_id(0, 1)
assert 1 in gi.edge_id(0, 1)
def test_predsucc():
gi = create_graph_index(multigraph=True)
gi.add_nodes(4)
gi.add_edge(0, 1)
gi.add_edge(0, 1)
gi.add_edge(0, 2)
gi.add_edge(2, 0)
gi.add_edge(3, 0)
gi.add_edge(0, 0)
gi.add_edge(0, 0)
pred = gi.predecessors(0)
assert len(pred) == 3
assert 2 in pred
assert 3 in pred
assert 0 in pred
succ = gi.successors(0)
assert len(succ) == 3
assert 1 in succ
assert 2 in succ
assert 0 in succ
if __name__ == '__main__':
test_edge_id()
test_nx()
test_predsucc()
from dgl import DGLError
from dgl.utils import toindex
from dgl.graph_index import create_graph_index
def test_node_subgraph():
gi = create_graph_index()
gi.add_nodes(4)
gi.add_edge(0, 1)
gi.add_edge(0, 2)
gi.add_edge(0, 2)
gi.add_edge(0, 3)
sub2par_nodemap = [2, 0, 3]
sgi = gi.node_subgraph(toindex(sub2par_nodemap))
for s, d, e in zip(*sgi.edges()):
assert sgi.induced_edges[e] in gi.edge_id(
sgi.induced_nodes[s], sgi.induced_nodes[d])
def test_edge_subgraph():
gi = create_graph_index()
gi.add_nodes(4)
gi.add_edge(0, 1)
gi.add_edge(0, 1)
gi.add_edge(0, 2)
gi.add_edge(2, 3)
sub2par_edgemap = [3, 2]
sgi = gi.edge_subgraph(toindex(sub2par_edgemap))
for s, d, e in zip(*sgi.edges()):
assert sgi.induced_edges[e] in gi.edge_id(
sgi.induced_nodes[s], sgi.induced_nodes[d])
if __name__ == '__main__':
test_node_subgraph()
test_edge_subgraph()
...@@ -242,6 +242,104 @@ def test_pull_0deg(): ...@@ -242,6 +242,104 @@ def test_pull_0deg():
assert th.allclose(new_repr[0], old_repr[0]) assert th.allclose(new_repr[0], old_repr[0])
assert th.allclose(new_repr[1], old_repr[0]) assert th.allclose(new_repr[1], old_repr[0])
def test_send_twice():
g = DGLGraph()
g.add_nodes(3)
g.add_edge(0, 1)
g.add_edge(2, 1)
def _message_a(src, edge):
return {'a': src['a']}
def _message_b(src, edge):
return {'a': src['a'] * 3}
def _reduce(node, msgs):
assert msgs is not None
return {'a': msgs['a'].max(1)[0]}
old_repr = th.randn(3, 5)
g.set_n_repr({'a': old_repr})
g.send(0, 1, _message_a)
g.send(0, 1, _message_b)
g.recv([1], _reduce)
new_repr = g.get_n_repr()['a']
assert th.allclose(new_repr[1], old_repr[0] * 3)
g.set_n_repr({'a': old_repr})
g.send(0, 1, _message_a)
g.send(2, 1, _message_b)
g.recv([1], _reduce)
new_repr = g.get_n_repr()['a']
assert th.allclose(new_repr[1], th.stack([old_repr[0], old_repr[2] * 3], 0).max(0)[0])
def test_send_multigraph():
g = DGLGraph(multigraph=True)
g.add_nodes(3)
g.add_edge(0, 1)
g.add_edge(0, 1)
g.add_edge(0, 1)
g.add_edge(2, 1)
def _message_a(src, edge):
return {'a': edge['a']}
def _message_b(src, edge):
return {'a': edge['a'] * 3}
def _reduce(node, msgs):
assert msgs is not None
return {'a': msgs['a'].max(1)[0]}
def answer(*args):
return th.stack(args, 0).max(0)[0]
# send by eid
old_repr = th.randn(4, 5)
g.set_n_repr({'a': th.zeros(3, 5)})
g.set_e_repr({'a': old_repr})
g.send(eid=[0, 2], message_func=_message_a)
g.recv([1], _reduce)
new_repr = g.get_n_repr()['a']
assert th.allclose(new_repr[1], answer(old_repr[0], old_repr[2]))
g.set_n_repr({'a': th.zeros(3, 5)})
g.set_e_repr({'a': old_repr})
g.send(eid=[0, 2, 3], message_func=_message_a)
g.recv([1], _reduce)
new_repr = g.get_n_repr()['a']
assert th.allclose(new_repr[1], answer(old_repr[0], old_repr[2], old_repr[3]))
# send on multigraph
g.set_n_repr({'a': th.zeros(3, 5)})
g.set_e_repr({'a': old_repr})
g.send([0, 2], [1, 1], _message_a)
g.recv([1], _reduce)
new_repr = g.get_n_repr()['a']
assert th.allclose(new_repr[1], old_repr.max(0)[0])
# consecutive send and send_on
g.set_n_repr({'a': th.zeros(3, 5)})
g.set_e_repr({'a': old_repr})
g.send(2, 1, _message_a)
g.send(eid=[0, 1], message_func=_message_b)
g.recv([1], _reduce)
new_repr = g.get_n_repr()['a']
assert th.allclose(new_repr[1], answer(old_repr[0] * 3, old_repr[1] * 3, old_repr[3]))
# consecutive send_on
g.set_n_repr({'a': th.zeros(3, 5)})
g.set_e_repr({'a': old_repr})
g.send(eid=0, message_func=_message_a)
g.send(eid=1, message_func=_message_b)
g.recv([1], _reduce)
new_repr = g.get_n_repr()['a']
assert th.allclose(new_repr[1], answer(old_repr[0], old_repr[1] * 3))
# send_and_recv_on
g.set_n_repr({'a': th.zeros(3, 5)})
g.set_e_repr({'a': old_repr})
g.send_and_recv(eid=[0, 2, 3], message_func=_message_a, reduce_func=_reduce)
new_repr = g.get_n_repr()['a']
assert th.allclose(new_repr[1], answer(old_repr[0], old_repr[2], old_repr[3]))
assert th.allclose(new_repr[[0, 2]], th.zeros(2, 5))
if __name__ == '__main__': if __name__ == '__main__':
test_batch_setter_getter() test_batch_setter_getter()
test_batch_setter_autograd() test_batch_setter_autograd()
...@@ -250,3 +348,5 @@ if __name__ == '__main__': ...@@ -250,3 +348,5 @@ if __name__ == '__main__':
test_update_routines() test_update_routines()
test_reduce_0deg() test_reduce_0deg()
test_pull_0deg() test_pull_0deg()
test_send_twice()
test_send_multigraph()
...@@ -39,8 +39,8 @@ def test_no_backtracking(): ...@@ -39,8 +39,8 @@ def test_no_backtracking():
for i in range(1, N): for i in range(1, N):
e1 = G.edge_id(0, i) e1 = G.edge_id(0, i)
e2 = G.edge_id(i, 0) e2 = G.edge_id(i, 0)
assert not L.has_edge(e1, e2) assert not L.has_edge_between(e1, e2)
assert not L.has_edge(e2, e1) assert not L.has_edge_between(e2, e1)
if __name__ == '__main__': if __name__ == '__main__':
test_line_graph() test_line_graph()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment