Unverified Commit 8801154b authored by VoVAllen, committed by GitHub

Merge pull request #1 from jermainewang/cpp

Cpp
parents b46abb09 b2c1c4fa
"""Package for graph generators"""
from __future__ import absolute_import
from .line import *
...@@ -4,9 +4,9 @@ from __future__ import absolute_import
import networkx as nx
import numpy as np
from .. import backend as F
from ..graph import DGLGraph
from ..frame import FrameRef
def line_graph(G, no_backtracking=False):
"""Create the line graph that shares the underlying features.
...
...@@ -3,20 +3,22 @@
from __future__ import absolute_import
import networkx as nx
import numpy as np
import dgl
from .base import ALL, is_all, __MSG__, __REPR__
from . import backend as F
from .backend import Tensor
from .frame import FrameRef, merge_frames
from .function.message import BundledMessageFunction
from .function.reducer import BundledReduceFunction
from .graph_index import GraphIndex, create_graph_index
from . import scheduler
from . import utils
__all__ = ['DGLGraph']
class DGLGraph(object):
"""Base graph class specialized for neural networks on graphs.
TODO(minjie): document of batching semantics
...@@ -26,47 +28,478 @@ class DGLGraph(DiGraph):
----------
graph_data : graph data
Data to initialize graph. Same as networkx's semantics.
node_frame : FrameRef
Node feature storage.
edge_frame : FrameRef
Edge feature storage.
"""
def __init__(self,
graph_data=None,
node_frame=None,
edge_frame=None):
# graph
self._graph = create_graph_index(graph_data)
# frame
self._node_frame = node_frame if node_frame is not None else FrameRef()
self._edge_frame = edge_frame if edge_frame is not None else FrameRef()
# msg graph & frame
self._msg_graph = create_graph_index()
self._msg_frame = FrameRef()
self.reset_messages()
# registered functions
self._message_func = None
self._reduce_func = None
self._edge_func = None
self._apply_node_func = None
self._apply_edge_func = None
def add_nodes(self, num, reprs=None):
"""Add nodes.
Parameters
----------
num : int
Number of nodes to be added.
reprs : dict
Optional node representations.
"""
self._graph.add_nodes(num)
self._msg_graph.add_nodes(num)
#TODO(minjie): change frames
assert reprs is None
def add_edge(self, u, v, reprs=None):
"""Add one edge.
Parameters
----------
u : int
The src node.
v : int
The dst node.
reprs : dict
Optional edge representation.
"""
self._graph.add_edge(u, v)
#TODO(minjie): change frames
assert reprs is None
def add_edges(self, u, v, reprs=None):
"""Add many edges.
Parameters
----------
u : list, tensor
The src nodes.
v : list, tensor
The dst nodes.
reprs : dict
Optional edge representations.
"""
u = utils.toindex(u)
v = utils.toindex(v)
self._graph.add_edges(u, v)
#TODO(minjie): change frames
assert reprs is None
def clear(self):
"""Clear the graph and its storage."""
self._graph.clear()
self._node_frame.clear()
self._edge_frame.clear()
self._msg_graph.clear()
self._msg_frame.clear()
def reset_messages(self):
"""Clear all messages."""
self._msg_graph.clear()
self._msg_frame.clear()
self._msg_graph.add_nodes(self.number_of_nodes())
def number_of_nodes(self):
"""Return the number of nodes.
Returns
-------
int
The number of nodes
"""
return self._graph.number_of_nodes()
def __len__(self):
"""Return the number of nodes."""
return self.number_of_nodes()
def number_of_edges(self):
"""Return the number of edges.
Returns
-------
int
The number of edges
"""
return self._graph.number_of_edges()
def has_node(self, vid):
"""Return true if the node exists.
Parameters
----------
vid : int
The node.
Returns
-------
bool
True if the node exists
"""
return self._graph.has_node(vid)
def __contains__(self, vid):
"""Same as has_node."""
return self.has_node(vid)
def has_nodes(self, vids):
"""Return true if the nodes exist.
Parameters
----------
vids : list, tensor
The nodes
Returns
-------
tensor
0-1 array indicating existence
"""
vids = utils.toindex(vids)
rst = self._graph.has_nodes(vids)
return rst.tousertensor()
def has_edge(self, u, v):
"""Return true if the edge exists.
Parameters
----------
u : int
The src node.
v : int
The dst node.
Returns
-------
bool
True if the edge exists
"""
return self._graph.has_edge(u, v)
def has_edges(self, u, v):
"""Return true if the edge exists.
Parameters
----------
u : list, tensor
The src nodes.
v : list, tensor
The dst nodes.
Returns
-------
tensor
0-1 array indicating existence
"""
u = utils.toindex(u)
v = utils.toindex(v)
rst = self._graph.has_edges(u, v)
return rst.tousertensor()
def predecessors(self, v, radius=1):
"""Return the predecessors of the node.
Parameters
----------
v : int
The node.
radius : int, optional
The radius of the neighborhood.
Returns
-------
tensor
Array of predecessors
"""
return self._graph.predecessors(v, radius).tousertensor()
def successors(self, v, radius=1):
"""Return the successors of the node.
Parameters
----------
v : int
The node.
radius : int, optional
The radius of the neighborhood.
Returns
-------
tensor
Array of successors
"""
return self._graph.successors(v, radius).tousertensor()
def edge_id(self, u, v):
"""Return the id of the edge.
Parameters
----------
u : int
The src node.
v : int
The dst node.
Returns
-------
int
The edge id.
"""
return self._graph.edge_id(u, v)
def edge_ids(self, u, v):
"""Return the edge ids.
Parameters
----------
u : list, tensor
The src nodes.
v : list, tensor
The dst nodes.
Returns
-------
tensor
The edge id array.
"""
u = utils.toindex(u)
v = utils.toindex(v)
rst = self._graph.edge_ids(u, v)
return rst.tousertensor()
def in_edges(self, v):
"""Return the in edges of the node(s).
Parameters
----------
v : int, list, tensor
The node(s).
Returns
-------
tensor
The src nodes.
tensor
The dst nodes.
tensor
The edge ids.
"""
v = utils.toindex(v)
src, dst, eid = self._graph.in_edges(v)
return src.tousertensor(), dst.tousertensor(), eid.tousertensor()
def out_edges(self, v):
"""Return the out edges of the node(s).
Parameters
----------
v : int, list, tensor
The node(s).
Returns
-------
tensor
The src nodes.
tensor
The dst nodes.
tensor
The edge ids.
"""
v = utils.toindex(v)
src, dst, eid = self._graph.out_edges(v)
return src.tousertensor(), dst.tousertensor(), eid.tousertensor()
def edges(self, sorted=False):
"""Return all the edges.
Parameters
----------
sorted : bool
True if the returned edges are sorted by their src and dst ids.
Returns
-------
tensor
The src nodes.
tensor
The dst nodes.
tensor
The edge ids.
"""
src, dst, eid = self._graph.edges(sorted)
return src.tousertensor(), dst.tousertensor(), eid.tousertensor()
def in_degree(self, v):
"""Return the in degree of the node.
Parameters
----------
v : int
The node.
Returns
-------
int
The in degree.
"""
return self._graph.in_degree(v)
def in_degrees(self, v):
"""Return the in degrees of the nodes.
Parameters
----------
v : list, tensor
The nodes.
Returns
-------
tensor
The in degree array.
"""
return self._graph.in_degrees(v).tousertensor()
def out_degree(self, v):
"""Return the out degree of the node.
Parameters
----------
v : int
The node.
Returns
-------
int
The out degree.
"""
return self._graph.out_degree(v)
def out_degrees(self, v):
"""Return the out degrees of the nodes.
Parameters
----------
v : list, tensor
The nodes.
Returns
-------
tensor
The out degree array.
"""
return self._graph.out_degrees(v).tousertensor()
def to_networkx(self, node_attrs=None, edge_attrs=None):
"""Convert to networkx graph.
The edge id will be saved as the 'id' edge attribute.
Parameters
----------
node_attrs : iterable of str, optional
The node attributes to be copied.
edge_attrs : iterable of str, optional
The edge attributes to be copied.
Returns
-------
networkx.DiGraph
The nx graph
"""
nx_graph = self._graph.to_networkx()
#TODO: attributes
return nx_graph
def from_networkx(self, nx_graph, node_attrs=None, edge_attrs=None):
"""Convert from networkx graph.
If the 'id' edge attribute exists, the edges will be added following
the edge id order. Otherwise, the order is undefined.
Parameters
----------
nx_graph : networkx.DiGraph
The nx graph
node_attrs : iterable of str, optional
The node attributes to be copied.
edge_attrs : iterable of str, optional
The edge attributes to be copied.
"""
self.clear()
self._graph.from_networkx(nx_graph)
self._msg_graph.add_nodes(self._graph.number_of_nodes())
# copy attributes
def _batcher(lst):
if isinstance(lst[0], Tensor):
return F.pack([F.unsqueeze(x, 0) for x in lst])
else:
return F.tensor(lst)
if node_attrs is not None:
attr_dict = {attr : [] for attr in node_attrs}
for nid in range(self.number_of_nodes()):
for attr in node_attrs:
attr_dict[attr].append(nx_graph.nodes[nid][attr])
for attr in node_attrs:
self._node_frame[attr] = _batcher(attr_dict[attr])
if edge_attrs is not None:
attr_dict = {attr : [] for attr in edge_attrs}
src, dst, _ = self._graph.edges()
for u, v in zip(src.tolist(), dst.tolist()):
for attr in edge_attrs:
attr_dict[attr].append(nx_graph.edges[u, v][attr])
for attr in edge_attrs:
self._edge_frame[attr] = _batcher(attr_dict[attr])
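# Usage sketch for the networkx round-trip (PyTorch backend assumed; the
# feature name 'h' is illustrative, not part of this commit):
#
#   >>> import torch as th
#   >>> import networkx as nx
#   >>> g = DGLGraph(nx.path_graph(3))
#   >>> g.set_n_repr({'h': th.randn(3, 2)})
#   >>> nxg = g.to_networkx()   # edge ids are saved as the 'id' attribute
#   >>> g2 = DGLGraph()
#   >>> g2.from_networkx(nxg)   # features are not carried over yet (see TODO)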
def from_scipy_sparse_matrix(self, a):
""" Convert from scipy sparse matrix.
Parameters
----------
a : scipy sparse matrix
The graph's adjacency matrix
"""
self.clear()
self._graph.from_scipy_sparse_matrix(a)
self._msg_graph.add_nodes(self._graph.number_of_nodes())
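# Usage sketch (a directed 3-cycle built from a scipy COO matrix; the values
# are illustrative):
#
#   >>> import scipy.sparse as sp
#   >>> adj = sp.coo_matrix(([1, 1, 1], ([0, 1, 2], [1, 2, 0])), shape=(3, 3))
#   >>> g = DGLGraph()
#   >>> g.from_scipy_sparse_matrix(adj)
#   >>> g.number_of_edges()
#   3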
def node_attr_schemes(self):
"""Return the node attribute schemes.
Returns
-------
iterable
The set of attribute names
"""
return self._node_frame.schemes
def edge_attr_schemes(self):
"""Return the edge attribute schemes.
Returns
-------
iterable
The set of attribute names
"""
return self._edge_frame.schemes
def set_n_repr(self, hu, u=ALL, inplace=False):
"""Set node(s) representation.
To set multiple node representations at once, pass `u` with a tensor or
...@@ -104,9 +537,9 @@ class DGLGraph(DiGraph):
self._node_frame[__REPR__] = hu
else:
if utils.is_dict_like(hu):
self._node_frame.update_rows(u, hu, inplace=inplace)
else:
self._node_frame.update_rows(u, {__REPR__ : hu}, inplace=inplace)
def get_n_repr(self, u=ALL):
"""Get node(s) representation.
...@@ -114,8 +547,15 @@ class DGLGraph(DiGraph):
Parameters
----------
u : node, container or tensor
The node(s).
Returns
-------
dict
Representation dict
""" """
if len(self.node_attr_schemes()) == 0:
return dict()
if is_all(u):
if len(self._node_frame) == 1 and __REPR__ in self._node_frame:
return self._node_frame[__REPR__]
...@@ -134,7 +574,12 @@ class DGLGraph(DiGraph):
Parameters
----------
key : str
The attribute name.
Returns
-------
Tensor
The popped representation
""" """
return self._node_frame.pop(key) return self._node_frame.pop(key)
...@@ -163,29 +608,12 @@ class DGLGraph(DiGraph):
v_is_all = is_all(v)
assert u_is_all == v_is_all
if u_is_all:
self.set_e_repr_by_id(h_uv, eid=ALL)
else:
u = utils.toindex(u)
v = utils.toindex(v)
eid = self._graph.edge_ids(u, v)
self.set_e_repr_by_id(h_uv, eid=eid)
def set_e_repr_by_id(self, h_uv, eid=ALL):
"""Set edge(s) representation by edge id.
...@@ -199,7 +627,7 @@ class DGLGraph(DiGraph):
"""
# sanity check
if is_all(eid):
num_edges = self.number_of_edges()
else:
eid = utils.toindex(eid)
num_edges = len(eid)
...@@ -230,23 +658,24 @@ class DGLGraph(DiGraph):
The source node(s).
v : node, container or tensor
The destination node(s).
Returns
-------
dict
Representation dict
""" """
u_is_all = is_all(u) u_is_all = is_all(u)
v_is_all = is_all(v) v_is_all = is_all(v)
assert u_is_all == v_is_all assert u_is_all == v_is_all
if len(self.edge_attr_schemes()) == 0:
return dict()
if u_is_all: if u_is_all:
if len(self._edge_frame) == 1 and __REPR__ in self._edge_frame: return self.get_e_repr_by_id(eid=ALL)
return self._edge_frame[__REPR__]
else:
return dict(self._edge_frame)
else: else:
u = utils.toindex(u) u = utils.toindex(u)
v = utils.toindex(v) v = utils.toindex(v)
eid = self.cached_graph.get_edge_id(u, v) eid = self._graph.edge_ids(u, v)
if len(self._edge_frame) == 1 and __REPR__ in self._edge_frame: return self.get_e_repr_by_id(eid=eid)
return self._edge_frame.select_rows(eid)[__REPR__]
else:
return self._edge_frame.select_rows(eid)
def pop_e_repr(self, key=__REPR__):
"""Get and remove the specified edge repr.
...@@ -255,6 +684,11 @@ class DGLGraph(DiGraph):
----------
key : str
The attribute name.
Returns
-------
Tensor
The popped representation
""" """
return self._edge_frame.pop(key) return self._edge_frame.pop(key)
...@@ -265,7 +699,14 @@ class DGLGraph(DiGraph): ...@@ -265,7 +699,14 @@ class DGLGraph(DiGraph):
---------- ----------
eid : int, container or tensor eid : int, container or tensor
The edge id(s). The edge id(s).
Returns
-------
dict
Representation dict
""" """
if len(self.edge_attr_schemes()) == 0:
return dict()
if is_all(eid): if is_all(eid):
if len(self._edge_frame) == 1 and __REPR__ in self._edge_frame: if len(self._edge_frame) == 1 and __REPR__ in self._edge_frame:
return self._edge_frame[__REPR__] return self._edge_frame[__REPR__]
...@@ -278,77 +719,57 @@ class DGLGraph(DiGraph): ...@@ -278,77 +719,57 @@ class DGLGraph(DiGraph):
else: else:
return self._edge_frame.select_rows(eid) return self._edge_frame.select_rows(eid)
def register_edge_func(self, edge_func):
"""Register global edge update function.
Parameters
----------
edge_func : callable
Message function on the edge.
"""
self._edge_func = edge_func
def register_message_func(self, message_func):
"""Register global message function.
Parameters
----------
message_func : callable
Message function on the edge.
"""
self._message_func = message_func
def register_reduce_func(self, reduce_func):
"""Register global message reduce function.
Parameters
----------
reduce_func : str or callable
Reduce function on incoming edges.
"""
self._reduce_func = reduce_func
def register_apply_node_func(self, apply_node_func):
"""Register global node apply function.
Parameters
----------
apply_node_func : callable
Apply function on the node.
"""
self._apply_node_func = apply_node_func
def register_apply_edge_func(self, apply_edge_func):
"""Register global edge apply function.
Parameters
----------
apply_edge_func : callable
Apply function on the edge.
"""
self._apply_edge_func = apply_edge_func
def apply_nodes(self, v, apply_node_func="default"):
"""Apply the function on node representations.
Parameters
...@@ -357,26 +778,16 @@ class DGLGraph(DiGraph):
The node id(s).
apply_node_func : callable
The apply node function.
"""
if apply_node_func == "default":
apply_node_func = self._apply_node_func
if not apply_node_func:
# Skip none function call.
return
new_repr = apply_node_func(self.get_n_repr(v))
self.set_n_repr(new_repr, v)
def apply_edges(self, u, v, apply_edge_func="default"):
"""Apply the function on edge representations.
Parameters
...@@ -387,27 +798,16 @@ class DGLGraph(DiGraph):
The dst node id(s).
apply_edge_func : callable
The apply edge function.
"""
if apply_edge_func == "default":
apply_edge_func = self._apply_edge_func
if not apply_edge_func:
# Skip none function call.
return
new_repr = apply_edge_func(self.get_e_repr(u, v))
self.set_e_repr(new_repr, u, v)
def send(self, u, v, message_func="default"):
"""Trigger the message function on edge u->v
The message function should be compatible with following signature:
...@@ -428,32 +828,18 @@ class DGLGraph(DiGraph):
The destination node(s).
message_func : callable
The message function.
"""
if message_func == "default":
message_func = self._message_func
assert message_func is not None
if isinstance(message_func, (tuple, list)):
message_func = BundledMessageFunction(message_func)
self._batch_send(u, v, message_func)
def _batch_send(self, u, v, message_func):
if is_all(u) and is_all(v):
u, v, _ = self._graph.edges()
self._msg_graph.add_edges(u, v) # TODO(minjie): can be optimized
# call UDF
src_reprs = self.get_n_repr(u)
edge_reprs = self.get_e_repr()
...@@ -462,18 +848,17 @@ class DGLGraph(DiGraph):
u = utils.toindex(u)
v = utils.toindex(v)
u, v = utils.edge_broadcasting(u, v)
self._msg_graph.add_edges(u, v)
# call UDF
src_reprs = self.get_n_repr(u)
edge_reprs = self.get_e_repr(u, v)
msgs = message_func(src_reprs, edge_reprs)
if utils.is_dict_like(msgs):
self._msg_frame.append(msgs)
else:
self._msg_frame.append({__MSG__ : msgs})
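# Sketch of the batched message/reduce protocol on a single edge (PyTorch
# backend assumed; the field names 'h' and 'm' are illustrative, not part of
# the API):
#
#   >>> g.set_n_repr({'h': th.ones(g.number_of_nodes(), 4)})
#   >>> g.send(0, 1, message_func=lambda src, edge: {'m': src['h']})
#   >>> g.recv(1, reduce_func=lambda node, msgs: {'h': msgs['m'].sum(1)})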
def update_edge(self, u=ALL, v=ALL, edge_func="default"):
"""Update representation on edge u->v
The edge function should be compatible with following signature:
...@@ -492,32 +877,15 @@ class DGLGraph(DiGraph):
The destination node(s).
edge_func : callable
The update function.
"""
if edge_func == "default":
edge_func = self._edge_func
assert edge_func is not None
self._batch_update_edge(u, v, edge_func)
def _batch_update_edge(self, u, v, edge_func):
if is_all(u) and is_all(v):
u, v, _ = self._graph.edges()
# call the UDF
src_reprs = self.get_n_repr(u)
dst_reprs = self.get_n_repr(v)
...@@ -528,7 +896,7 @@ class DGLGraph(DiGraph):
u = utils.toindex(u)
v = utils.toindex(v)
u, v = utils.edge_broadcasting(u, v)
eid = self._graph.edge_ids(u, v)
# call the UDF
src_reprs = self.get_n_repr(u)
dst_reprs = self.get_n_repr(v)
...@@ -539,8 +907,7 @@ class DGLGraph(DiGraph):
def recv(self,
u,
reduce_func="default",
apply_node_func="default"):
"""Receive and reduce in-coming messages and update representation on node u.
It computes the new node state using the messages sent from the predecessors
...@@ -570,31 +937,15 @@ class DGLGraph(DiGraph):
The reduce function.
apply_node_func : callable, optional
The update function.
"""
if reduce_func == "default":
reduce_func = self._reduce_func
assert reduce_func is not None
if isinstance(reduce_func, (list, tuple)):
reduce_func = BundledReduceFunction(reduce_func)
self._batch_recv(u, reduce_func)
# optional apply nodes
self.apply_nodes(u, apply_node_func)
def _batch_recv(self, v, reduce_func):
if self._msg_frame.num_rows == 0:
...@@ -610,7 +961,7 @@ class DGLGraph(DiGraph):
v = utils.toindex(v)
# degree bucketing
degrees, v_buckets = scheduler.degree_bucketing(self._msg_graph, v)
if degrees == [0]:
# no message has been sent to the specified node
return
...@@ -625,8 +976,7 @@ class DGLGraph(DiGraph):
continue
bkt_len = len(v_bkt)
dst_reprs = self.get_n_repr(v_bkt)
uu, vv, in_msg_ids = self._msg_graph.in_edges(v_bkt)
in_msgs = self._msg_frame.select_rows(in_msg_ids)
# Reshape the column tensor to (B, Deg, ...).
def _reshape_fn(msg):
...@@ -638,11 +988,11 @@ class DGLGraph(DiGraph):
else:
reshaped_in_msgs = utils.LazyDict(
lambda key: _reshape_fn(in_msgs[key]), self._msg_frame.schemes)
reordered_v.append(v_bkt.tousertensor())
new_reprs.append(reduce_func(dst_reprs, reshaped_in_msgs))
# TODO: clear partial messages
self.reset_messages()
# Pack all reducer results together
reordered_v = F.pack(reordered_v)
...@@ -667,8 +1017,7 @@ class DGLGraph(DiGraph):
u, v,
message_func="default",
reduce_func="default",
apply_node_func="default"):
"""Trigger the message function on u->v and update v.
Parameters
...@@ -683,8 +1032,6 @@ class DGLGraph(DiGraph):
The reduce function.
apply_node_func : callable, optional
The update function.
"""
u = utils.toindex(u)
v = utils.toindex(v)
...@@ -692,36 +1039,30 @@ class DGLGraph(DiGraph):
# no edges to be triggered
assert len(v) == 0
return
unique_v = utils.toindex(F.unique(v.tousertensor()))
if message_func == "default":
message_func = self._message_func
if reduce_func == "default":
reduce_func = self._reduce_func
assert message_func is not None
assert reduce_func is not None
executor = scheduler.get_executor(
'send_and_recv', self, src=u, dst=v,
message_func=message_func, reduce_func=reduce_func)
if executor:
executor.run()
else:
self.send(u, v, message_func)
self.recv(unique_v, reduce_func, None)
self.apply_nodes(unique_v, apply_node_func)
def pull(self,
v,
message_func="default",
reduce_func="default",
apply_node_func="default"):
"""Pull messages from the node's predecessors and then update it.
Parameters
...@@ -734,24 +1075,20 @@ class DGLGraph(DiGraph):
The reduce function.
apply_node_func : callable, optional
The update function.
"""
v = utils.toindex(v)
if len(v) == 0:
return
uu, vv, _ = self._graph.in_edges(v)
self.send_and_recv(uu, vv, message_func, reduce_func, apply_node_func=None)
unique_v = F.unique(v.tousertensor())
self.apply_nodes(unique_v, apply_node_func)
def push(self,
u,
message_func="default",
reduce_func="default",
apply_node_func="default"):
"""Send message from the node to its successors and update them.
Parameters
...@@ -764,21 +1101,18 @@ class DGLGraph(DiGraph):
The reduce function.
apply_node_func : callable
The update function.
"""
u = utils.toindex(u)
if len(u) == 0:
return
uu, vv, _ = self._graph.out_edges(u)
self.send_and_recv(uu, vv, message_func,
reduce_func, apply_node_func)
def update_all(self,
message_func="default",
reduce_func="default",
apply_node_func="default"):
"""Send messages through all the edges and update all nodes.
Parameters
...@@ -789,76 +1123,61 @@ class DGLGraph(DiGraph):
The reduce function.
apply_node_func : callable, optional
The update function.
"""
if message_func == "default":
message_func = self._message_func
if reduce_func == "default":
reduce_func = self._reduce_func
assert message_func is not None
assert reduce_func is not None
executor = scheduler.get_executor(
"update_all", self, message_func=message_func, reduce_func=reduce_func)
if executor:
executor.run()
else:
self.send(ALL, ALL, message_func)
self.recv(ALL, reduce_func, None)
self.apply_nodes(ALL, apply_node_func)
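# Sketch: one full round of message passing via update_all (PyTorch backend
# assumed; 'h' and 'm' are illustrative field names). Plain lambdas fall back
# to the generic send/recv path; builtins from dgl.function can instead be
# matched by the scheduler to a specialized (e.g. SPMV) executor.
#
#   >>> g.update_all(
#   ...     message_func=lambda src, edge: {'m': src['h']},
#   ...     reduce_func=lambda node, msgs: {'h': msgs['m'].sum(1)},
#   ...     apply_node_func=None)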
def propagate(self,
traverser='topo',
message_func="default",
reduce_func="default",
apply_node_func="default",
**kwargs):
"""Propagate messages and update nodes using graph traversal.
A convenient function for passing messages and updating
nodes according to the traverser. The traverser can be
any of the pre-defined traversers (e.g. 'topo'). Users can also provide
a custom traverser that generates the edges and nodes.
Parameters
----------
traverser : str or generator of edges.
The traverser of the graph.
message_func : str or callable
The message function.
reduce_func : str or callable
The reduce function.
apply_node_func : str or callable
The update function.
kwargs : keyword arguments, optional
Arguments for pre-defined traversers.
"""
if isinstance(traverser, str):
# TODO Call pre-defined routine to unroll the computation.
raise RuntimeError('Not implemented.')
else:
# NOTE: the iteration can return multiple edges at each step.
for u, v in traverser:
self.send_and_recv(u, v,
message_func, reduce_func, apply_node_func)
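# Sketch of a custom traverser: a generator yielding (u, v) edge frontiers
# that sweep a path graph left to right (purely illustrative):
#
#   >>> def left_to_right(n):
#   ...     for i in range(n - 1):
#   ...         yield [i], [i + 1]
#   >>> g.propagate(traverser=left_to_right(g.number_of_nodes()))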
def subgraph(self, nodes):
"""Generate the subgraph among the given nodes.
The generated graph contains only the graph structure. The node/edge
features are not shared implicitly. Use `copy_from` to get node/edge
features from the parent graph.
Parameters
----------
nodes : list, or iterable
...@@ -869,7 +1188,9 @@ class DGLGraph(DiGraph):
G : DGLSubGraph
The subgraph.
"""
induced_nodes = utils.toindex(nodes)
gi, induced_edges = self._graph.node_subgraph(induced_nodes)
return dgl.DGLSubGraph(self, induced_nodes, induced_edges, gi)
def merge(self, subgraphs, reduce_func='sum'):
"""Merge subgraph features back to this parent graph.
...@@ -913,91 +1234,109 @@ class DGLGraph(DiGraph):
self._edge_frame.num_rows,
reduce_func)
def adjacency_matrix(self, ctx=None):
"""Return the adjacency matrix representation of this graph.
Parameters
----------
ctx : optional
The context of the returned adjacency matrix.
Returns
-------
sparse_tensor
The adjacency matrix.
"""
return self._graph.adjacency_matrix().get(ctx)
def incidence_matrix(self, oriented=False, ctx=None):
"""Return the incidence matrix representation of this graph.
Parameters
----------
oriented : bool, optional
Whether the returned incidence matrix is oriented.
ctx : optional
The context of the returned incidence matrix.
Returns
-------
sparse_tensor
The incidence matrix.
"""
return self._graph.incidence_matrix(oriented).get(ctx)
def line_graph(self, backtracking=True, shared=False):
"""Return the line graph of this graph.
Parameters
----------
backtracking : bool, optional
Whether the returned line graph is backtracking.
shared : bool, optional
Whether the returned line graph shares representations with `self`.
Returns
-------
DGLGraph
The line graph of this graph.
"""
graph_data = self._graph.line_graph(backtracking)
node_frame = self._edge_frame if shared else None
return DGLGraph(graph_data, node_frame)
def filter_nodes(self, predicate, nodes=ALL):
"""Return a tensor of node IDs that satisfy the given predicate.
Parameters
----------
predicate : callable
The predicate should take in a dict of tensors whose values
are concatenation of node representations by node ID (same as
get_n_repr()), and return a boolean tensor with N elements
indicating which nodes satisfy the predicate.
nodes : container or tensor
The nodes to filter on
Returns
-------
tensor
The filtered nodes
"""
n_repr = self.get_n_repr(nodes)
n_mask = predicate(n_repr)
if is_all(nodes):
return F.nonzero_1d(n_mask)
else:
nodes = F.Tensor(nodes)
return nodes[n_mask]
def filter_edges(self, predicate, edges=ALL):
"""Return a tensor of edge IDs that satisfy the given predicate.
Parameters
----------
predicate : callable
The predicate should take in a dict of tensors whose values
are concatenation of edge representations by edge ID (same as
get_e_repr_by_id()), and return a boolean tensor with N elements
indicating which edges satisfy the predicate.
edges : container or tensor
The edges to filter on
Returns
-------
tensor
The filtered edges
"""
e_repr = self.get_e_repr_by_id(edges)
e_mask = predicate(e_repr)
if is_all(edges):
return F.nonzero_1d(e_mask)
else:
edges = F.Tensor(edges)
return edges[e_mask]
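# Sketch of the predicate protocol for filter_nodes/filter_edges (PyTorch
# backend assumed; the field name 'h' is illustrative):
#
#   >>> g.set_n_repr({'h': th.tensor([0., 1., 0., 2.])})
#   >>> nonzero = g.filter_nodes(lambda n_repr: n_repr['h'] != 0)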
from __future__ import absolute_import
import ctypes
import numpy as np
import networkx as nx
import scipy.sparse as sp
from ._ffi.base import c_array
from ._ffi.function import _init_api
from . import backend as F
from . import utils
GraphIndexHandle = ctypes.c_void_p
class GraphIndex(object):
"""Graph index object.
Parameters
----------
handle : GraphIndexHandle
Handle
"""
def __init__(self, handle):
self._handle = handle
self._cache = {}
def __del__(self):
"""Free this graph index object."""
_CAPI_DGLGraphFree(self._handle)
def add_nodes(self, num):
"""Add nodes.
Parameters
----------
num : int
Number of nodes to be added.
"""
_CAPI_DGLGraphAddVertices(self._handle, num)
self._cache.clear()
def add_edge(self, u, v):
"""Add one edge.
Parameters
----------
u : int
The src node.
v : int
The dst node.
"""
_CAPI_DGLGraphAddEdge(self._handle, u, v)
self._cache.clear()
def add_edges(self, u, v):
"""Add many edges.
Parameters
----------
u : utils.Index
The src nodes.
v : utils.Index
The dst nodes.
"""
u_array = u.todgltensor()
v_array = v.todgltensor()
_CAPI_DGLGraphAddEdges(self._handle, u_array, v_array)
self._cache.clear()
def clear(self):
"""Clear the graph."""
_CAPI_DGLGraphClear(self._handle)
self._cache.clear()
def number_of_nodes(self):
"""Return the number of nodes.
Returns
-------
int
The number of nodes
"""
return _CAPI_DGLGraphNumVertices(self._handle)
def number_of_edges(self):
"""Return the number of edges.
Returns
-------
int
The number of edges
"""
return _CAPI_DGLGraphNumEdges(self._handle)
def has_node(self, vid):
"""Return true if the node exists.
Parameters
----------
vid : int
The nodes
Returns
-------
bool
True if the node exists
"""
return _CAPI_DGLGraphHasVertex(self._handle, vid)
def has_nodes(self, vids):
"""Return true if the nodes exist.
Parameters
----------
vids : utils.Index
The nodes
Returns
-------
utils.Index
0-1 array indicating existence
"""
vid_array = vids.todgltensor()
return utils.toindex(_CAPI_DGLGraphHasVertices(self._handle, vid_array))
def has_edge(self, u, v):
"""Return true if the edge exists.
Parameters
----------
u : int
The src node.
v : int
The dst node.
Returns
-------
bool
True if the edge exists
"""
return _CAPI_DGLGraphHasEdge(self._handle, u, v)
def has_edges(self, u, v):
"""Return true if the edge exists.
Parameters
----------
u : utils.Index
The src nodes.
v : utils.Index
The dst nodes.
Returns
-------
utils.Index
0-1 array indicating existence
"""
u_array = u.todgltensor()
v_array = v.todgltensor()
return utils.toindex(_CAPI_DGLGraphHasEdges(self._handle, u_array, v_array))
def predecessors(self, v, radius=1):
"""Return the predecessors of the node.
Parameters
----------
v : int
The node.
radius : int, optional
The radius of the neighborhood.
Returns
-------
utils.Index
Array of predecessors
"""
return utils.toindex(_CAPI_DGLGraphPredecessors(self._handle, v, radius))
def successors(self, v, radius=1):
"""Return the successors of the node.
Parameters
----------
v : int
The node.
radius : int, optional
The radius of the neighborhood.
Returns
-------
utils.Index
Array of successors
"""
return utils.toindex(_CAPI_DGLGraphSuccessors(self._handle, v, radius))
def edge_id(self, u, v):
"""Return the id of the edge.
Parameters
----------
u : int
The src node.
v : int
The dst node.
Returns
-------
int
The edge id.
"""
return _CAPI_DGLGraphEdgeId(self._handle, u, v)
def edge_ids(self, u, v):
"""Return the edge ids.
Parameters
----------
u : utils.Index
The src nodes.
v : utils.Index
The dst nodes.
Returns
-------
utils.Index
The edge id array.
"""
u_array = u.todgltensor()
v_array = v.todgltensor()
return utils.toindex(_CAPI_DGLGraphEdgeIds(self._handle, u_array, v_array))
def in_edges(self, v):
"""Return the in edges of the node(s).
Parameters
----------
v : utils.Index
The node(s).
Returns
-------
utils.Index
The src nodes.
utils.Index
The dst nodes.
utils.Index
The edge ids.
"""
if len(v) == 1:
edge_array = _CAPI_DGLGraphInEdges_1(self._handle, v[0])
else:
v_array = v.todgltensor()
edge_array = _CAPI_DGLGraphInEdges_2(self._handle, v_array)
src = utils.toindex(edge_array(0))
dst = utils.toindex(edge_array(1))
eid = utils.toindex(edge_array(2))
return src, dst, eid
def out_edges(self, v):
"""Return the out edges of the node(s).
Parameters
----------
v : utils.Index
The node(s).
Returns
-------
utils.Index
The src nodes.
utils.Index
The dst nodes.
utils.Index
The edge ids.
"""
if len(v) == 1:
edge_array = _CAPI_DGLGraphOutEdges_1(self._handle, v[0])
else:
v_array = v.todgltensor()
edge_array = _CAPI_DGLGraphOutEdges_2(self._handle, v_array)
src = utils.toindex(edge_array(0))
dst = utils.toindex(edge_array(1))
eid = utils.toindex(edge_array(2))
return src, dst, eid
def edges(self, sorted=False):
"""Return all the edges
Parameters
----------
sorted : bool
True if the returned edges are sorted by their src and dst ids.
Returns
-------
utils.Index
The src nodes.
utils.Index
The dst nodes.
utils.Index
The edge ids.
"""
edge_array = _CAPI_DGLGraphEdges(self._handle, sorted)
src = utils.toindex(edge_array(0))
dst = utils.toindex(edge_array(1))
eid = utils.toindex(edge_array(2))
return src, dst, eid
def in_degree(self, v):
"""Return the in degree of the node.
Parameters
----------
v : int
The node.
Returns
-------
int
The in degree.
"""
return _CAPI_DGLGraphInDegree(self._handle, v)
def in_degrees(self, v):
"""Return the in degrees of the nodes.
Parameters
----------
v : utils.Index
The nodes.
Returns
-------
utils.Index
The in degree array.
"""
v_array = v.todgltensor()
return utils.toindex(_CAPI_DGLGraphInDegrees(self._handle, v_array))
def out_degree(self, v):
"""Return the out degree of the node.
Parameters
----------
v : int
The node.
Returns
-------
int
The out degree.
"""
return _CAPI_DGLGraphOutDegree(self._handle, v)
def out_degrees(self, v):
"""Return the out degrees of the nodes.
Parameters
----------
v : utils.Index
The nodes.
Returns
-------
utils.Index
The out degree array.
"""
v_array = v.todgltensor()
return utils.toindex(_CAPI_DGLGraphOutDegrees(self._handle, v_array))
def node_subgraph(self, v):
"""Return the induced node subgraph.
Parameters
----------
v : utils.Index
The nodes.
Returns
-------
GraphIndex
The subgraph index.
utils.Index
The induced edge ids. This is also a map from new edge id to parent edge id.
"""
v_array = v.todgltensor()
rst = _CAPI_DGLGraphVertexSubgraph(self._handle, v_array)
gi = GraphIndex(rst(0))
induced_edges = utils.toindex(rst(2))
return gi, induced_edges
def adjacency_matrix(self):
"""Return the adjacency matrix representation of this graph.
Returns
-------
utils.CtxCachedObject
An object that returns tensor given context.
"""
if 'adj' not in self._cache:
src, dst, _ = self.edges(sorted=False)
src = F.unsqueeze(src.tousertensor(), 0)
dst = F.unsqueeze(dst.tousertensor(), 0)
idx = F.pack([dst, src])
n = self.number_of_nodes()
dat = F.ones((self.number_of_edges(),))
mat = F.sparse_tensor(idx, dat, [n, n])
self._cache['adj'] = utils.CtxCachedObject(lambda ctx: F.to_context(mat, ctx))
return self._cache['adj']
def incidence_matrix(self, oriented=False):
"""Return the incidence matrix representation of this graph.
Parameters
----------
oriented : bool, optional (default=False)
Whether the returned incidence matrix is oriented.
Returns
-------
utils.CtxCachedObject
An object that returns tensor given context.
"""
key = ('oriented ' if oriented else '') + 'incidence matrix'
if key not in self._cache:
src, dst, _ = self.edges(sorted=False)
src = src.tousertensor()
dst = dst.tousertensor()
m = self.number_of_edges()
eid = F.arange(m, dtype=F.int64)
row = F.pack([src, dst])
col = F.pack([eid, eid])
idx = F.stack([row, col])
diagonal = (src == dst)
if oriented:
x = -F.ones((m,))
y = F.ones((m,))
x[diagonal] = 0
y[diagonal] = 0
dat = F.pack([x, y])
else:
x = F.ones((m,))
x[diagonal] = 0
dat = F.pack([x, x])
n = self.number_of_nodes()
mat = F.sparse_tensor(idx, dat, [n, m])
self._cache[key] = utils.CtxCachedObject(lambda ctx: F.to_context(mat, ctx))
return self._cache[key]
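# Worked example (illustrative): for a 2-node graph with the single edge
# e0 = (0, 1), the oriented incidence matrix built above is the 2 x 1 matrix
#
#   [[-1.],   # row 0: node 0 is the src of e0
#    [ 1.]]   # row 1: node 1 is the dst of e0
#
# and self-loop edges (src == dst) contribute 0 in both rows.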
def to_networkx(self):
"""Convert to networkx graph.
The edge id will be saved as the 'id' edge attribute.
Returns
-------
networkx.DiGraph
The nx graph
"""
src, dst, eid = self.edges()
ret = nx.DiGraph()
for u, v, id in zip(src, dst, eid):
ret.add_edge(u, v, id=id)
return ret
def from_networkx(self, nx_graph):
"""Convert from networkx graph.
If 'id' edge attribute exists, the edge will be added follows
the edge id order. Otherwise, order is undefined.
Parameters
----------
nx_graph : networkx.DiGraph
The nx graph
"""
self.clear()
if not isinstance(nx_graph, nx.DiGraph):
nx_graph = nx.DiGraph(nx_graph)
num_nodes = nx_graph.number_of_nodes()
self.add_nodes(num_nodes)
first_edge = next(iter(nx_graph.edges))
has_edge_id = 'id' in nx_graph.edges[first_edge]
if has_edge_id:
num_edges = nx_graph.number_of_edges()
src = np.zeros((num_edges,), dtype=np.int64)
dst = np.zeros((num_edges,), dtype=np.int64)
for e, attr in nx_graph.edges.items():
u, v = e
eid = attr['id']
src[eid] = u
dst[eid] = v
else:
src = []
dst = []
for u, v in nx_graph.edges:
src.append(u)
dst.append(v)
src = utils.toindex(src)
dst = utils.toindex(dst)
self.add_edges(src, dst)
def from_scipy_sparse_matrix(self, adj):
"""Convert from scipy sparse matrix.
Parameters
----------
adj : scipy sparse matrix
"""
self.clear()
self.add_nodes(adj.shape[0])
adj_coo = adj.tocoo()
src = utils.toindex(adj_coo.row)
dst = utils.toindex(adj_coo.col)
self.add_edges(src, dst)
def line_graph(self, backtracking=True):
"""Return the line graph of this graph.
Parameters
----------
backtracking : bool, optional (default=True)
Whether (i, j) ~ (j, i) in L(G).
(i, j) ~ (j, i) is the behavior of networkx.line_graph.
Returns
-------
GraphIndex
The line graph of this graph.
"""
handle = _CAPI_DGLGraphLineGraph(self._handle, backtracking)
return GraphIndex(handle)
def disjoint_union(graphs):
"""Return a disjoint union of the input graphs.
The new graph will include all the nodes/edges in the given graphs.
Nodes/Edges will be relabeled by adding the cumsum of the previous graph sizes
in the given sequence order. For example, given input [g1, g2, g3] with
5, 6, and 7 nodes respectively, node#2 of g2 becomes node#7 in the result
graph. Edge ids are re-assigned similarly.
Parameters
----------
graphs : iterable of GraphIndex
The input graphs
Returns
-------
GraphIndex
The disjoint union
"""
inputs = c_array(GraphIndexHandle, [gr._handle for gr in graphs])
inputs = ctypes.cast(inputs, ctypes.c_void_p)
handle = _CAPI_DGLDisjointUnion(inputs, len(graphs))
return GraphIndex(handle)
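# Usage sketch (g1, g2, g3 are hypothetical GraphIndex objects with 5, 6 and
# 7 nodes, matching the docstring example above):
#
#   >>> big = disjoint_union([g1, g2, g3])
#   >>> big.number_of_nodes()
#   18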
def disjoint_partition(graph, num_or_size_splits):
"""Partition the graph disjointly.
This is the reverse operation of disjoint_union. The graph will be partitioned
into num graphs. If a number is given, it must evenly divide the number of
nodes in the graph. If a size list is given, the sizes must sum to the total
number of nodes.
Parameters
----------
graph : GraphIndex
The graph to be partitioned
num_or_size_splits : int or utils.Index
The partition number of size splits
Returns
-------
list of GraphIndex
The partitioned graphs
"""
if isinstance(num_or_size_splits, utils.Index):
rst = _CAPI_DGLDisjointPartitionBySizes(
graph._handle,
num_or_size_splits.todgltensor())
else:
rst = _CAPI_DGLDisjointPartitionByNum(
graph._handle,
int(num_or_size_splits))
graphs = []
for val in rst.asnumpy():
handle = ctypes.cast(int(val), ctypes.c_void_p)
graphs.append(GraphIndex(handle))
return graphs
def create_graph_index(graph_data=None):
"""Create a graph index object.
Parameters
----------
graph_data : graph data, optional
Data to initialize graph. Same as networkx's semantics.
"""
if isinstance(graph_data, GraphIndex):
return graph_data
handle = _CAPI_DGLGraphCreate()
gi = GraphIndex(handle)
if graph_data is not None:
gi.from_networkx(graph_data)
return gi
_init_api("dgl.graph_index")
"""DGL Runtime NDArray API.
dgl.ndarray provides a minimal runtime array structure to be
used with the C++ library.
"""
# pylint: disable=invalid-name,unused-import
from __future__ import absolute_import as _abs
import ctypes
import functools
import operator
import numpy as _np
from ._ffi.ndarray import TVMContext, TVMType, NDArrayBase
from ._ffi.ndarray import context, empty, from_dlpack, numpyasarray
from ._ffi.ndarray import _set_class_ndarray
from . import backend as F
class NDArray(NDArrayBase):
"""Lightweight NDArray class for DGL framework."""
def __len__(self):
return functools.reduce(operator.mul, self.shape, 1)
def cpu(dev_id=0):
"""Construct a CPU device
Parameters
----------
dev_id : int, optional
The integer device id
Returns
-------
ctx : TVMContext
The created context
"""
return TVMContext(1, dev_id)
def gpu(dev_id=0):
"""Construct a CPU device
Parameters
----------
dev_id : int, optional
The integer device id
Returns
-------
ctx : TVMContext
The created context
"""
return TVMContext(2, dev_id)
def array(arr, ctx=cpu(0)):
"""Create an array from source arr.
Parameters
----------
arr : numpy.ndarray
The array to be copied from
ctx : TVMContext, optional
The device context to create the array
Returns
-------
ret : NDArray
The created array
"""
if not isinstance(arr, (_np.ndarray, NDArray)):
arr = _np.array(arr)
return empty(arr.shape, arr.dtype, ctx).copyfrom(arr)
def zerocopy_from_numpy(np_data):
"""Create an array that shares the given numpy data.
Parameters
----------
np_data : numpy.ndarray
The numpy data
Returns
-------
NDArray
The array
"""
arr, _ = numpyasarray(np_data)
handle = ctypes.pointer(arr)
return NDArray(handle, is_view=True)
_set_class_ndarray(NDArray)
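A quick sketch of how the helpers above compose:

a = array([0, 1, 2])                     # copies into a CPU NDArray
assert len(a) == 3                       # __len__ multiplies out the shape
b = zerocopy_from_numpy(_np.arange(3))   # view sharing the numpy buffer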
"""Package nn modules"""
from __future__ import absolute_import
import os

__backend__ = os.environ.get('DGLBACKEND', 'pytorch').lower()
if __backend__ == 'numpy':
    pass
elif __backend__ == 'pytorch':
    from .pytorch import *
elif __backend__ != 'mxnet':
    raise Exception("Unsupported backend %s" % __backend__)
...@@ -7,9 +7,8 @@ GCN with SPMV specialization.
"""
import torch.nn as nn

from ... import function as fn
from ...base import ALL, is_all


class NodeUpdateModule(nn.Module):
    def __init__(self, in_feats, out_feats, activation=None):
...
"""Utility functions for networkx adapter."""
from __future__ import absolute_import
from collections import MutableMapping
import networkx as nx
import networkx.convert as convert
class NodeDict(MutableMapping):
def __init__(self, add_cb, del_cb):
self._dict = {}
self._add_cb = add_cb
self._del_cb = del_cb
def __setitem__(self, key, val):
self._add_cb(key)
self._dict[key] = val
def __getitem__(self, key):
return self._dict[key]
def __delitem__(self, key):
self._del_cb(key)
del self._dict[key]
def __len__(self):
return len(self._dict)
def __iter__(self):
return iter(self._dict)
class AdjOuterDict(MutableMapping):
def __init__(self, add_cb, del_cb):
self._dict = {}
self._add_cb = add_cb
self._del_cb = del_cb
def __setitem__(self, key, val):
val.src = key
self._dict[key] = val
def __getitem__(self, key):
return self._dict[key]
def __delitem__(self, key):
for val in self._dict[key]:
self._del_cb(key, val)
del self._dict[key]
def __len__(self):
return len(self._dict)
def __iter__(self):
return iter(self._dict)
class AdjInnerDict(MutableMapping):
def __init__(self, add_cb, del_cb):
self._dict = {}
self.src = None
self._add_cb = add_cb
self._del_cb = del_cb
def __setitem__(self, key, val):
if self.src is not None and key not in self._dict:
self._add_cb(self.src, key)
self._dict[key] = val
def __getitem__(self, key):
return self._dict[key]
def __delitem__(self, key):
if self.src is not None:
self._del_cb(self.src, key)
del self._dict[key]
def __len__(self):
return len(self._dict)
def __iter__(self):
return iter(self._dict)
class AdjInnerDictFactory(object):
def __init__(self, cb1, cb2):
self._cb1 = cb1
self._cb2 = cb2
def __call__(self):
return AdjInnerDict(self._cb1, self._cb2)
def nx_init(obj,
add_node_cb,
add_edge_cb,
del_node_cb,
del_edge_cb,
graph_data,
**attr):
"""Init the object to be compatible with networkx's DiGraph.
Parameters
----------
obj : any
The object to be init.
add_node_cb : callable
The callback function when node is added.
add_edge_cb : callable
The callback function when edge is added.
graph_data : graph data
Data to initialize graph. Same as networkx's semantics.
attr : keyword arguments, optional
Attributes to add to graph as key=value pairs.
"""
    # The following code works with networkx 2.1.
obj.adjlist_outer_dict_factory = None
obj.adjlist_inner_dict_factory = AdjInnerDictFactory(add_edge_cb, del_edge_cb)
obj.edge_attr_dict_factory = dict
obj.root_graph = obj
obj.graph = {}
obj._node = NodeDict(add_node_cb, del_node_cb)
obj._adj = AdjOuterDict(add_edge_cb, del_edge_cb)
obj._pred = dict()
obj._succ = obj._adj
if graph_data is not None:
convert.to_networkx_graph(graph_data, create_using=obj)
obj.graph.update(attr)
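A hypothetical sketch of the adapter in action, mirroring how DGLGraph used it: an object initialized with `nx_init` behaves like a networkx 2.1 `DiGraph` whose mutations fire the given callbacks.

class TracedGraph(nx.DiGraph):
    def __init__(self):
        events = self.events = []
        nx_init(self,
                lambda n: events.append(('add_node', n)),
                lambda u, v: events.append(('add_edge', u, v)),
                lambda n: events.append(('del_node', n)),
                lambda u, v: events.append(('del_edge', u, v)),
                None)

g = TracedGraph()
g.add_edge(0, 1)   # records two add_node events, then ('add_edge', 0, 1)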
...@@ -3,19 +3,20 @@ from __future__ import absolute_import
import numpy as np

from .base import ALL
from . import backend as F
from .function import message as fmsg
from .function import reducer as fred
from . import utils

__all__ = ["degree_bucketing", "get_executor"]

def degree_bucketing(graph, v):
    """Create degree bucketing scheduling policy.

    Parameters
    ----------
    graph : dgl.graph_index.GraphIndex
        the graph
    v : dgl.utils.Index
        the nodes to gather messages
...@@ -28,7 +29,7 @@ def degree_bucketing(cached_graph, v):
        list of node id buckets; nodes belonging to the same bucket have
        the same degree
    """
    degrees = np.array(graph.in_degrees(v).tolist())
    unique_degrees = list(np.unique(degrees))
    v_np = np.array(v.tolist())
    v_bkt = []
...@@ -38,37 +39,32 @@ def degree_bucketing(cached_graph, v):
    #print('degree-bucketing:', unique_degrees, [len(b) for b in v_bkt])
    return unique_degrees, v_bkt
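The policy in isolation, as a numpy-only sketch: destination nodes are grouped by in-degree so that each group can be reduced with a single batched kernel.

import numpy as np

degrees = np.array([2, 1, 2, 3])   # in-degrees of the nodes below
nodes = np.array([10, 11, 12, 13])
buckets = {d: nodes[degrees == d] for d in np.unique(degrees)}
# {1: [11], 2: [10, 12], 3: [13]} -- one reduce call per bucket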
class Executor(object):
    def run(self):
        raise NotImplementedError

class SPMVOperator(Executor):
    def __init__(self, src_field, edge_field, dst_field, use_edge_feat,
                 node_repr, adj_build_fn):
        self.src_field = src_field
        self.edge_field = edge_field
        self.dst_field = dst_field
        self.use_edge_feat = use_edge_feat
        self.node_repr = node_repr
        self.adj_build_fn = adj_build_fn

    def run(self):
        # get src col
        if self.src_field is None:
            srccol = self.node_repr
        else:
            srccol = self.node_repr[self.src_field]
        ctx = F.get_context(srccol)
        # build adjmat
        adjmat = self.adj_build_fn(self.edge_field, ctx, self.use_edge_feat)
        # spmm
        if len(F.shape(srccol)) == 1:
            srccol = F.unsqueeze(srccol, 1)
...@@ -77,104 +73,249 @@ class UpdateAllSPMVExecutor(Executor):
        else:
            dstcol = F.spmm(adjmat, srccol)
        if self.dst_field is None:
            return dstcol
        else:
            return {self.dst_field : dstcol}

class BasicExecutor(Executor):
    def __init__(self, graph, mfunc, rfunc):
        self.g = graph
        self.exe = self._build_exec(mfunc, rfunc)

    @property
    def node_repr(self):
        raise NotImplementedError

    @property
    def edge_repr(self):
        raise NotImplementedError

    @property
    def graph_mapping(self):
        raise NotImplementedError

    def _build_exec(self, mfunc, rfunc):
        if isinstance(mfunc, fmsg.CopySrcMessageFunction):
            exe = SPMVOperator(src_field=mfunc.src_field,
                               edge_field=None,
                               dst_field=rfunc.out_field,
                               use_edge_feat=False,
                               node_repr=self.node_repr,
                               adj_build_fn=self._adj_build_fn)
        elif isinstance(mfunc, fmsg.SrcMulEdgeMessageFunction):
            exe = SPMVOperator(src_field=mfunc.src_field,
                               edge_field=mfunc.edge_field,
                               dst_field=rfunc.out_field,
                               use_edge_feat=True,
                               node_repr=self.node_repr,
                               adj_build_fn=self._adj_build_fn)
        else:
            raise NotImplementedError("message func type {}".format(type(mfunc)))
        return exe

    def run(self):
        attr = self.exe.run()
        self.g.set_n_repr(attr, self.graph_mapping)

class UpdateAllExecutor(BasicExecutor):
    def __init__(self, graph, mfunc, rfunc):
        self._init_state()
        super(UpdateAllExecutor, self).__init__(graph, mfunc, rfunc)

    def _init_state(self):
        self._node_repr = None
        self._edge_repr = None
        self._graph_idx = None
        self._graph_shape = None
        self._graph_mapping = None

    @property
    def graph_idx(self):
        if self._graph_idx is None:
            self._graph_idx = self.g._graph.adjacency_matrix()
        return self._graph_idx

    @property
    def graph_shape(self):
        if self._graph_shape is None:
            n = self.g.number_of_nodes()
            self._graph_shape = [n, n]
        return self._graph_shape

    @property
    def graph_mapping(self):
        return ALL

    @property
    def node_repr(self):
        if self._node_repr is None:
            self._node_repr = self.g.get_n_repr()
        return self._node_repr

    @property
    def edge_repr(self):
        if self._edge_repr is None:
            self._edge_repr = self.g.get_e_repr()
        return self._edge_repr

    def _adj_build_fn(self, edge_field, ctx, use_edge_feat):
        if use_edge_feat:
            if edge_field is None:
                dat = self.edge_repr
            else:
                dat = self.edge_repr[edge_field]
            dat = F.squeeze(dat)
            # TODO(minjie): should not directly use _indices
            idx = self.graph_idx.get(ctx)._indices()
            adjmat = F.sparse_tensor(idx, dat, self.graph_shape)
        else:
            adjmat = self.graph_idx.get(ctx)
        return adjmat

class SendRecvExecutor(BasicExecutor):
    def __init__(self, graph, src, dst, mfunc, rfunc):
        self._init_state(src, dst)
        super(SendRecvExecutor, self).__init__(graph, mfunc, rfunc)

    def _init_state(self, src, dst):
        self.u, self.v = utils.edge_broadcasting(src, dst)
        self._node_repr = None
        self._edge_repr = None
        self._graph_idx = None
        self._graph_shape = None
        self._graph_mapping = None

    @property
    def graph_idx(self):
        if self._graph_idx is None:
            self._build_adjmat()
        return self._graph_idx

    @property
    def graph_shape(self):
        if self._graph_shape is None:
            self._build_adjmat()
        return self._graph_shape

    @property
    def graph_mapping(self):
        if self._graph_mapping is None:
            self._build_adjmat()
        return self._graph_mapping

    @property
    def node_repr(self):
        if self._node_repr is None:
            self._node_repr = self.g.get_n_repr()
        return self._node_repr

    @property
    def edge_repr(self):
        if self._edge_repr is None:
            self._edge_repr = self.g.get_e_repr(self.u, self.v)
        return self._edge_repr

    def _build_adjmat(self):
        # handle graph index
        new2old, old2new = utils.build_relabel_map(self.v)
        u = self.u.tousertensor()
        v = self.v.tousertensor()
        # TODO(minjie): should not directly use []
        new_v = old2new[v]
        n = self.g.number_of_nodes()
        m = len(new2old)
        self._graph_idx = F.pack([F.unsqueeze(new_v, 0), F.unsqueeze(u, 0)])
        self._graph_shape = [m, n]
        self._graph_mapping = new2old

    def _adj_build_fn(self, edge_field, ctx, use_edge_feat):
        if use_edge_feat:
            if edge_field is None:
                dat = self.edge_repr
            else:
                dat = self.edge_repr[edge_field]
            dat = F.squeeze(dat)
        else:
            dat = F.ones((len(self.u), ))
        adjmat = F.sparse_tensor(self.graph_idx, dat, self.graph_shape)
        return F.to_context(adjmat, ctx)

class BundledExecutor(BasicExecutor):
    """
    Base class for bundled execution.

    All shared structures like the graph index should be cached in this class
    or its subclasses. BundledUpdateAllExecutor and BundledSendRecvExecutor
    should subclass BundledExecutor.
    """
    def __init__(self, graph, mfunc, rfunc):
        self.g = graph
        func_pairs = self._match_message_with_reduce(mfunc, rfunc)
        # create all executors
        self.executors = self._build_executors(func_pairs)

    def _build_executors(self, func_pairs):
        executors = []
        for mfunc, rfunc in func_pairs:
            exe = self._build_exec(mfunc, rfunc)
            executors.append(exe)
        return executors

    def _match_message_with_reduce(self, mfunc, rfunc):
        out2mfunc = {fn.out_field: fn for fn in mfunc.fn_list}
        func_pairs = []
        for rfn in rfunc.fn_list:
            mfn = out2mfunc.get(rfn.msg_field, None)
            # field check
            assert mfn is not None, \
                "cannot find message func for reduce func in-field {}".format(rfn.msg_field)
            func_pairs.append((mfn, rfn))
        return func_pairs

    def run(self):
        attr = None
        for exe in self.executors:
            res = exe.run()
            if attr is None:
                attr = res
            else:
                # attr and res must be dict
                attr.update(res)
        self.g.set_n_repr(attr, self.graph_mapping)

class BundledUpdateAllExecutor(BundledExecutor, UpdateAllExecutor):
    def __init__(self, graph, mfunc, rfunc):
        self._init_state()
        BundledExecutor.__init__(self, graph, mfunc, rfunc)

class BundledSendRecvExecutor(BundledExecutor, SendRecvExecutor):
    def __init__(self, graph, src, dst, mfunc, rfunc):
        self._init_state(src, dst)
        BundledExecutor.__init__(self, graph, mfunc, rfunc)

def _is_spmv_supported(fn, graph=None):
    if isinstance(fn, fmsg.MessageFunction):
        return fn.is_spmv_supported(graph)
    elif isinstance(fn, fred.ReduceFunction):
        return fn.is_spmv_supported()
    else:
        return False

def _create_update_all_exec(graph, **kwargs):
    mfunc = kwargs.pop('message_func')
    rfunc = kwargs.pop('reduce_func')
    if isinstance(mfunc, (list, tuple)) or isinstance(rfunc, (list, tuple)):
        mfunc = fmsg.BundledMessageFunction(mfunc)
        rfunc = fred.BundledReduceFunction(rfunc)
        exec_cls = BundledUpdateAllExecutor
    else:
        exec_cls = UpdateAllExecutor
    if _is_spmv_supported(mfunc, graph) and _is_spmv_supported(rfunc):
        return exec_cls(graph, mfunc=mfunc, rfunc=rfunc)
    else:
        return None
...@@ -183,28 +324,14 @@ def _create_send_and_recv_exec(graph, **kwargs):
    dst = kwargs.pop('dst')
    mfunc = kwargs.pop('message_func')
    rfunc = kwargs.pop('reduce_func')
    if isinstance(mfunc, (list, tuple)) or isinstance(rfunc, (list, tuple)):
        mfunc = fmsg.BundledMessageFunction(mfunc)
        rfunc = fred.BundledReduceFunction(rfunc)
        exec_cls = BundledSendRecvExecutor
    else:
        exec_cls = SendRecvExecutor
    if _is_spmv_supported(mfunc, graph) and _is_spmv_supported(rfunc):
        return exec_cls(graph, src=src, dst=dst, mfunc=mfunc, rfunc=rfunc)
    else:
        return None
...
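Why the SPMV specialization is valid, as a standalone PyTorch sketch (not the DGL API): copy-src messages summed at each destination equal one sparse matmul H' = A @ H, where A[v, u] = 1 for every edge (u, v).

import torch

edges = [(0, 1), (2, 1)]                  # two messages flow into node 1
src = torch.tensor([e[0] for e in edges])
dst = torch.tensor([e[1] for e in edges])
H = torch.eye(3)
A = torch.sparse_coo_tensor(torch.stack([dst, src]),
                            torch.ones(len(edges)), (3, 3))
out = torch.sparse.mm(A, H)               # row 1 equals H[0] + H[2]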
"""DGLSubGraph""" """Class for subgraph data structure."""
from __future__ import absolute_import from __future__ import absolute_import
import networkx as nx import networkx as nx
import dgl.backend as F
from dgl.frame import Frame, FrameRef from . import backend as F
from dgl.graph import DGLGraph from .frame import Frame, FrameRef
from dgl.nx_adapt import nx_init from .graph import DGLGraph
import dgl.utils as utils from . import utils
class DGLSubGraph(DGLGraph): class DGLSubGraph(DGLGraph):
# TODO(gaiyu): ReadOnlyGraph """The subgraph class.
def __init__(self,
parent,
nodes):
super(DGLSubGraph, self).__init__()
# relabel nodes
self._node_mapping = utils.build_relabel_dict(nodes)
self._parent_nid = utils.toindex(nodes)
eids = []
# create subgraph
for eid, (u, v) in enumerate(parent.edge_list):
if u in self._node_mapping and v in self._node_mapping:
self.add_edge(self._node_mapping[u],
self._node_mapping[v])
eids.append(eid)
self._parent_eid = utils.toindex(eids)
def copy_from(self, parent):
"""Copy node/edge features from the parent graph.
All old features will be removed. There are two subgraph modes: shared and non-shared.
For the "non-shared" mode, the user needs to explicitly call
``copy_from_parent`` to copy node/edge features from its parent graph.
* If the user tries to get node/edge features before ``copy_from_parent``,
s/he will get nothing.
* If the subgraph already has its own node/edge features, ``copy_from_parent``
will override them.
* Any update on the subgraph's node/edge features will not be seen
by the parent graph. As such, the memory consumption is of the order
of the subgraph size.
* To write the subgraph's node/edge features back to parent graph. There are two options:
(1) Use ``copy_to_parent`` API to write node/edge features back.
(2) [TODO] Use ``dgl.merge`` to merge multiple subgraphs back to one parent.
The "shared" mode is currently not supported.
The subgraph is read-only so mutation is not allowed.
Parameters
----------
parent : DGLGraph
The parent graph
parent_nid : utils.Index
The induced parent node ids in this subgraph.
parent_eid : utils.Index
The induced parent edge ids in this subgraph.
graph_idx : GraphIndex
The graph index.
shared : bool, optional
Whether the subgraph shares node/edge features with the parent graph.
"""
def __init__(self, parent, parent_nid, parent_eid, graph_idx, shared=False):
super(DGLSubGraph, self).__init__(graph_data=graph_idx)
self._parent = parent
self._parent_nid = parent_nid
self._parent_eid = parent_eid
# override APIs
def add_nodes(self, num, reprs=None):
"""Add nodes. Disabled because BatchedDGLGraph is read-only."""
raise RuntimeError('Readonly graph. Mutation is not allowed.')
def add_edge(self, u, v, reprs=None):
"""Add one edge. Disabled because BatchedDGLGraph is read-only."""
raise RuntimeError('Readonly graph. Mutation is not allowed.')
def add_edges(self, u, v, reprs=None):
"""Add many edges. Disabled because BatchedDGLGraph is read-only."""
raise RuntimeError('Readonly graph. Mutation is not allowed.')
@property
def parent_nid(self):
"""Get the parent node ids.
The returned tensor can be used as a map from the node id
in this subgraph to the node id in the parent graph.
Returns
-------
Tensor
The parent node id array.
"""
return self._parent_nid.tousertensor()
@property
def parent_eid(self):
"""Get the parent edge ids.
The returned tensor can be used as a map from the edge id
in this subgraph to the edge id in the parent graph.
Returns
-------
Tensor
The parent edge id array.
"""
return self._parent_eid.tousertensor()
def copy_to_parent(self, inplace=False):
"""Write node/edge features to the parent graph.
Parameters Parameters
---------- ----------
parent : DGLGraph inplace : bool
The parent graph to copy from. If true, use inplace write (no gradient but faster)
"""
self._parent._node_frame.update_rows(
self._parent_nid, self._node_frame, inplace=inplace)
self._parent._edge_frame.update_rows(
self._parent_eid, self._edge_frame, inplace=inplace)
def copy_from_parent(self):
"""Copy node/edge features from the parent graph.
All old features will be removed.
""" """
if parent._node_frame.num_rows != 0: if self._parent._node_frame.num_rows != 0:
self._node_frame = FrameRef(Frame(parent._node_frame[self._parent_nid])) self._node_frame = FrameRef(Frame(
if parent._edge_frame.num_rows != 0: self._parent._node_frame[self._parent_nid]))
self._edge_frame = FrameRef(Frame(parent._edge_frame[self._parent_eid])) if self._parent._edge_frame.num_rows != 0:
self._edge_frame = FrameRef(Frame(
self._parent._edge_frame[self._parent_eid]))
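A hypothetical end-to-end workflow for the non-shared mode described above (assumes a parent DGLGraph `g` exposing a `subgraph` method and a node field 'h'):

sg = g.subgraph([0, 2, 5])   # a DGLSubGraph; features are not copied yet
sg.copy_from_parent()        # pull rows 0, 2, 5 of every node/edge field
h = sg.get_n_repr()['h']     # now visible on the subgraph
sg.copy_to_parent()          # write (possibly updated) rows back into g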
...@@ -5,50 +5,70 @@ from collections import Mapping
from functools import wraps

import numpy as np

from . import backend as F
from .backend import Tensor, SparseTensor
from . import ndarray as nd

class Index(object):
    """Index class that can be easily converted to list/tensor."""
    def __init__(self, data):
        self._list_data = None  # a numpy type data
        self._user_tensor_data = dict()  # dictionary of user tensors
        self._dgl_tensor_data = None  # a dgl ndarray
        self._dispatch(data)

    def _dispatch(self, data):
        """Store data based on its type."""
        if isinstance(data, Tensor):
            if not (F.dtype(data) == F.int64 and len(F.shape(data)) == 1):
                raise ValueError('Index data must be 1D int64 vector, but got: %s' % str(data))
            self._user_tensor_data[F.get_context(data)] = data
        elif isinstance(data, nd.NDArray):
            if not (data.dtype == 'int64' and len(data.shape) == 1):
                raise ValueError('Index data must be 1D int64 vector, but got: %s' % str(data))
            self._dgl_tensor_data = data
        else:
            try:
                self._list_data = np.array([int(data)]).astype(np.int64)
            except:
                try:
                    self._list_data = np.array(data).astype(np.int64)
                except:
                    raise ValueError('Error index data: %s' % str(data))
            self._user_tensor_data[nd.cpu()] = F.zerocopy_from_numpy(self._list_data)

    def tolist(self):
        """Convert to a python-list compatible object."""
        if self._list_data is None:
            if self._dgl_tensor_data is not None:
                self._list_data = self._dgl_tensor_data.asnumpy()
            else:
                data = self.tousertensor()
                self._list_data = F.zerocopy_to_numpy(data)
        return self._list_data

    def tousertensor(self, ctx=None):
        """Convert to user tensor (defined in `backend`)."""
        if ctx is None:
            ctx = nd.cpu()
        if len(self._user_tensor_data) == 0:
            # zero copy from dgl tensor
            dl = self._dgl_tensor_data.to_dlpack()
            self._user_tensor_data[nd.cpu()] = F.zerocopy_from_dlpack(dl)
        if ctx not in self._user_tensor_data:
            # copy from cpu to another device
            data = next(iter(self._user_tensor_data.values()))
            self._user_tensor_data[ctx] = F.to_context(data, ctx)
        return self._user_tensor_data[ctx]

    def todgltensor(self):
        """Convert to dgl.NDArray."""
        if self._dgl_tensor_data is None:
            # zero copy from user tensor
            tsor = self.tousertensor()
            dl = F.zerocopy_to_dlpack(tsor)
            self._dgl_tensor_data = nd.from_dlpack(dl)
        return self._dgl_tensor_data

    def __iter__(self):
        return iter(self.tolist())
...@@ -56,8 +76,11 @@ class Index(object):
    def __len__(self):
        if self._list_data is not None:
            return len(self._list_data)
        elif len(self._user_tensor_data) > 0:
            data = next(iter(self._user_tensor_data.values()))
            return len(data)
        else:
            return len(self._dgl_tensor_data)

    def __getitem__(self, i):
        return self.tolist()[i]
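The three storage views an Index maintains, sketched (assumes a tensor backend such as pytorch; `toindex` is this module's factory helper used below):

idx = toindex([0, 1, 2])
idx.tolist()          # numpy int64 array, cached
idx.tousertensor()    # backend tensor, cached per device context
idx.todgltensor()     # dgl.ndarray.NDArray, zero-copy via DLPack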
...@@ -118,40 +141,13 @@ def edge_broadcasting(u, v):
        The dst id(s) after broadcasting
    """
    if len(u) != len(v) and len(u) == 1:
        u = toindex(F.broadcast_to(u.tousertensor(), v.tousertensor()))
    elif len(u) != len(v) and len(v) == 1:
        v = toindex(F.broadcast_to(v.tousertensor(), u.tousertensor()))
    else:
        assert len(u) == len(v)
    return u, v
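For example, broadcasting a single source id against many destinations:

u, v = edge_broadcasting(toindex([3]), toindex([0, 1, 2]))
assert list(u) == [3, 3, 3]   # u stretched to v's length; v unchanged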
class LazyDict(Mapping):
    """A readonly dictionary that does not materialize the storage."""
    def __init__(self, fn, keys):
...@@ -209,7 +205,7 @@ def build_relabel_map(x):
    One can use advanced indexing to convert an old id tensor to a
    new id tensor: new_id = old_to_new[old_id]
    """
    x = x.tousertensor()
    unique_x, _ = F.sort(F.unique(x))
    map_len = int(F.max(unique_x)) + 1
    old_to_new = F.zeros(map_len, dtype=F.int64)
...@@ -316,6 +312,6 @@ def reorder(dict_like, index):
    """
    new_dict = {}
    for key, val in dict_like.items():
        idx_ctx = index.tousertensor(F.get_context(val))
        new_dict[key] = F.gather_row(val, idx_ctx)
    return new_dict
...@@ -17,7 +17,6 @@ setuptools.setup(
        'numpy>=1.14.0',
        'scipy>=1.1.0',
        'networkx>=2.1',
    ],
    data_files=[('', ['VERSION'])],
    url='https://github.com/jermainewang/dgl')
// Graph class implementation
#include <algorithm>
#include <tuple>
#include <unordered_map>
#include <vector>
#include <dgl/graph.h>
namespace dgl {
namespace {
inline bool IsValidIdArray(const IdArray& arr) {
return arr->ctx.device_type == kDLCPU && arr->ndim == 1
&& arr->dtype.code == kDLInt && arr->dtype.bits == 64;
}
} // namespace
void Graph::AddVertices(uint64_t num_vertices) {
CHECK(!read_only_) << "Graph is read-only. Mutations are not allowed.";
adjlist_.resize(adjlist_.size() + num_vertices);
reverse_adjlist_.resize(reverse_adjlist_.size() + num_vertices);
}
void Graph::AddEdge(dgl_id_t src, dgl_id_t dst) {
CHECK(!read_only_) << "Graph is read-only. Mutations are not allowed.";
CHECK(HasVertex(src) && HasVertex(dst))
<< "Invalid vertices: src=" << src << " dst=" << dst;
dgl_id_t eid = num_edges_++;
adjlist_[src].succ.push_back(dst);
adjlist_[src].edge_id.push_back(eid);
reverse_adjlist_[dst].succ.push_back(src);
reverse_adjlist_[dst].edge_id.push_back(eid);
all_edges_src_.push_back(src);
all_edges_dst_.push_back(dst);
}
void Graph::AddEdges(IdArray src_ids, IdArray dst_ids) {
CHECK(!read_only_) << "Graph is read-only. Mutations are not allowed.";
CHECK(IsValidIdArray(src_ids)) << "Invalid src id array.";
CHECK(IsValidIdArray(dst_ids)) << "Invalid dst id array.";
const auto srclen = src_ids->shape[0];
const auto dstlen = dst_ids->shape[0];
const int64_t* src_data = static_cast<int64_t*>(src_ids->data);
const int64_t* dst_data = static_cast<int64_t*>(dst_ids->data);
if (srclen == 1) {
// one-many
for (int64_t i = 0; i < dstlen; ++i) {
AddEdge(src_data[0], dst_data[i]);
}
} else if (dstlen == 1) {
// many-one
for (int64_t i = 0; i < srclen; ++i) {
AddEdge(src_data[i], dst_data[0]);
}
} else {
// many-many
CHECK(srclen == dstlen) << "Invalid src and dst id array.";
for (int64_t i = 0; i < srclen; ++i) {
AddEdge(src_data[i], dst_data[i]);
}
}
}
BoolArray Graph::HasVertices(IdArray vids) const {
CHECK(IsValidIdArray(vids)) << "Invalid vertex id array.";
const auto len = vids->shape[0];
BoolArray rst = BoolArray::Empty({len}, vids->dtype, vids->ctx);
const int64_t* vid_data = static_cast<int64_t*>(vids->data);
int64_t* rst_data = static_cast<int64_t*>(rst->data);
const int64_t nverts = NumVertices();
for (int64_t i = 0; i < len; ++i) {
rst_data[i] = (vid_data[i] < nverts)? 1 : 0;
}
return rst;
}
// O(E)
bool Graph::HasEdge(dgl_id_t src, dgl_id_t dst) const {
if (!HasVertex(src) || !HasVertex(dst)) return false;
const auto& succ = adjlist_[src].succ;
return std::find(succ.begin(), succ.end(), dst) != succ.end();
}
// O(E*K) pretty slow
BoolArray Graph::HasEdges(IdArray src_ids, IdArray dst_ids) const {
CHECK(IsValidIdArray(src_ids)) << "Invalid src id array.";
CHECK(IsValidIdArray(dst_ids)) << "Invalid dst id array.";
const auto srclen = src_ids->shape[0];
const auto dstlen = dst_ids->shape[0];
const auto rstlen = std::max(srclen, dstlen);
BoolArray rst = BoolArray::Empty({rstlen}, src_ids->dtype, src_ids->ctx);
int64_t* rst_data = static_cast<int64_t*>(rst->data);
const int64_t* src_data = static_cast<int64_t*>(src_ids->data);
const int64_t* dst_data = static_cast<int64_t*>(dst_ids->data);
if (srclen == 1) {
// one-many
for (int64_t i = 0; i < dstlen; ++i) {
rst_data[i] = HasEdge(src_data[0], dst_data[i])? 1 : 0;
}
} else if (dstlen == 1) {
// many-one
for (int64_t i = 0; i < srclen; ++i) {
rst_data[i] = HasEdge(src_data[i], dst_data[0])? 1 : 0;
}
} else {
// many-many
CHECK(srclen == dstlen) << "Invalid src and dst id array.";
for (int64_t i = 0; i < srclen; ++i) {
rst_data[i] = HasEdge(src_data[i], dst_data[i])? 1 : 0;
}
}
return rst;
}
// The data is copy-out; support zero-copy?
IdArray Graph::Predecessors(dgl_id_t vid, uint64_t radius) const {
CHECK(HasVertex(vid)) << "invalid vertex: " << vid;
CHECK(radius >= 1) << "invalid radius: " << radius;
const auto& pred = reverse_adjlist_[vid].succ;
const int64_t len = pred.size();
IdArray rst = IdArray::Empty({len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
int64_t* rst_data = static_cast<int64_t*>(rst->data);
for (int64_t i = 0; i < len; ++i) {
rst_data[i] = pred[i];
}
return rst;
}
// The data is copy-out; support zero-copy?
IdArray Graph::Successors(dgl_id_t vid, uint64_t radius) const {
CHECK(HasVertex(vid)) << "invalid vertex: " << vid;
CHECK(radius >= 1) << "invalid radius: " << radius;
const auto& succ = adjlist_[vid].succ;
const int64_t len = succ.size();
IdArray rst = IdArray::Empty({len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
int64_t* rst_data = static_cast<int64_t*>(rst->data);
for (int64_t i = 0; i < len; ++i) {
rst_data[i] = succ[i];
}
return rst;
}
// O(E)
dgl_id_t Graph::EdgeId(dgl_id_t src, dgl_id_t dst) const {
CHECK(HasVertex(src)) << "invalid edge: " << src << " -> " << dst;
const auto& succ = adjlist_[src].succ;
for (size_t i = 0; i < succ.size(); ++i) {
if (succ[i] == dst) {
return adjlist_[src].edge_id[i];
}
}
LOG(FATAL) << "invalid edge: " << src << " -> " << dst;
return 0;
}
// O(E*k) pretty slow
IdArray Graph::EdgeIds(IdArray src_ids, IdArray dst_ids) const {
CHECK(IsValidIdArray(src_ids)) << "Invalid src id array.";
CHECK(IsValidIdArray(dst_ids)) << "Invalid dst id array.";
const auto srclen = src_ids->shape[0];
const auto dstlen = dst_ids->shape[0];
const auto rstlen = std::max(srclen, dstlen);
IdArray rst = IdArray::Empty({rstlen}, src_ids->dtype, src_ids->ctx);
int64_t* rst_data = static_cast<int64_t*>(rst->data);
const int64_t* src_data = static_cast<int64_t*>(src_ids->data);
const int64_t* dst_data = static_cast<int64_t*>(dst_ids->data);
if (srclen == 1) {
// one-many
for (int64_t i = 0; i < dstlen; ++i) {
rst_data[i] = EdgeId(src_data[0], dst_data[i]);
}
} else if (dstlen == 1) {
// many-one
for (int64_t i = 0; i < srclen; ++i) {
rst_data[i] = EdgeId(src_data[i], dst_data[0]);
}
} else {
// many-many
CHECK(srclen == dstlen) << "Invalid src and dst id array.";
for (int64_t i = 0; i < srclen; ++i) {
rst_data[i] = EdgeId(src_data[i], dst_data[i]);
}
}
return rst;
}
// O(E)
Graph::EdgeArray Graph::InEdges(dgl_id_t vid) const {
CHECK(HasVertex(vid)) << "invalid vertex: " << vid;
const int64_t len = reverse_adjlist_[vid].succ.size();
IdArray src = IdArray::Empty({len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
IdArray dst = IdArray::Empty({len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
IdArray eid = IdArray::Empty({len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
int64_t* src_data = static_cast<int64_t*>(src->data);
int64_t* dst_data = static_cast<int64_t*>(dst->data);
int64_t* eid_data = static_cast<int64_t*>(eid->data);
for (int64_t i = 0; i < len; ++i) {
src_data[i] = reverse_adjlist_[vid].succ[i];
eid_data[i] = reverse_adjlist_[vid].edge_id[i];
}
std::fill(dst_data, dst_data + len, vid);
return EdgeArray{src, dst, eid};
}
// O(E)
Graph::EdgeArray Graph::InEdges(IdArray vids) const {
CHECK(IsValidIdArray(vids)) << "Invalid vertex id array.";
const auto len = vids->shape[0];
const int64_t* vid_data = static_cast<int64_t*>(vids->data);
int64_t rstlen = 0;
for (int64_t i = 0; i < len; ++i) {
CHECK(HasVertex(vid_data[i])) << "Invalid vertex: " << vid_data[i];
rstlen += reverse_adjlist_[vid_data[i]].succ.size();
}
IdArray src = IdArray::Empty({rstlen}, vids->dtype, vids->ctx);
IdArray dst = IdArray::Empty({rstlen}, vids->dtype, vids->ctx);
IdArray eid = IdArray::Empty({rstlen}, vids->dtype, vids->ctx);
int64_t* src_ptr = static_cast<int64_t*>(src->data);
int64_t* dst_ptr = static_cast<int64_t*>(dst->data);
int64_t* eid_ptr = static_cast<int64_t*>(eid->data);
for (int64_t i = 0; i < len; ++i) {
const auto& pred = reverse_adjlist_[vid_data[i]].succ;
const auto& eids = reverse_adjlist_[vid_data[i]].edge_id;
for (size_t j = 0; j < pred.size(); ++j) {
*(src_ptr++) = pred[j];
*(dst_ptr++) = vid_data[i];
*(eid_ptr++) = eids[j];
}
}
return EdgeArray{src, dst, eid};
}
// O(E)
Graph::EdgeArray Graph::OutEdges(dgl_id_t vid) const {
CHECK(HasVertex(vid)) << "invalid vertex: " << vid;
const int64_t len = adjlist_[vid].succ.size();
IdArray src = IdArray::Empty({len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
IdArray dst = IdArray::Empty({len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
IdArray eid = IdArray::Empty({len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
int64_t* src_data = static_cast<int64_t*>(src->data);
int64_t* dst_data = static_cast<int64_t*>(dst->data);
int64_t* eid_data = static_cast<int64_t*>(eid->data);
for (int64_t i = 0; i < len; ++i) {
dst_data[i] = adjlist_[vid].succ[i];
eid_data[i] = adjlist_[vid].edge_id[i];
}
std::fill(src_data, src_data + len, vid);
return EdgeArray{src, dst, eid};
}
// O(E)
Graph::EdgeArray Graph::OutEdges(IdArray vids) const {
CHECK(IsValidIdArray(vids)) << "Invalid vertex id array.";
const auto len = vids->shape[0];
const int64_t* vid_data = static_cast<int64_t*>(vids->data);
int64_t rstlen = 0;
for (int64_t i = 0; i < len; ++i) {
CHECK(HasVertex(vid_data[i])) << "Invalid vertex: " << vid_data[i];
rstlen += adjlist_[vid_data[i]].succ.size();
}
IdArray src = IdArray::Empty({rstlen}, vids->dtype, vids->ctx);
IdArray dst = IdArray::Empty({rstlen}, vids->dtype, vids->ctx);
IdArray eid = IdArray::Empty({rstlen}, vids->dtype, vids->ctx);
int64_t* src_ptr = static_cast<int64_t*>(src->data);
int64_t* dst_ptr = static_cast<int64_t*>(dst->data);
int64_t* eid_ptr = static_cast<int64_t*>(eid->data);
for (int64_t i = 0; i < len; ++i) {
const auto& succ = adjlist_[vid_data[i]].succ;
const auto& eids = adjlist_[vid_data[i]].edge_id;
for (size_t j = 0; j < succ.size(); ++j) {
*(src_ptr++) = vid_data[i];
*(dst_ptr++) = succ[j];
*(eid_ptr++) = eids[j];
}
}
return EdgeArray{src, dst, eid};
}
// O(E*log(E)) if sort is required; otherwise, O(E)
Graph::EdgeArray Graph::Edges(bool sorted) const {
const int64_t len = num_edges_;
IdArray src = IdArray::Empty({len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
IdArray dst = IdArray::Empty({len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
IdArray eid = IdArray::Empty({len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
if (sorted) {
typedef std::tuple<int64_t, int64_t, int64_t> Tuple;
std::vector<Tuple> tuples;
tuples.reserve(len);
for (uint64_t eid = 0; eid < num_edges_; ++eid) {
tuples.emplace_back(all_edges_src_[eid], all_edges_dst_[eid], eid);
}
// sort according to src and dst ids
std::sort(tuples.begin(), tuples.end(),
[] (const Tuple& t1, const Tuple& t2) {
return std::get<0>(t1) < std::get<0>(t2)
|| (std::get<0>(t1) == std::get<0>(t2) && std::get<1>(t1) < std::get<1>(t2));
});
// make return arrays
int64_t* src_ptr = static_cast<int64_t*>(src->data);
int64_t* dst_ptr = static_cast<int64_t*>(dst->data);
int64_t* eid_ptr = static_cast<int64_t*>(eid->data);
for (size_t i = 0; i < tuples.size(); ++i) {
src_ptr[i] = std::get<0>(tuples[i]);
dst_ptr[i] = std::get<1>(tuples[i]);
eid_ptr[i] = std::get<2>(tuples[i]);
}
} else {
int64_t* src_ptr = static_cast<int64_t*>(src->data);
int64_t* dst_ptr = static_cast<int64_t*>(dst->data);
int64_t* eid_ptr = static_cast<int64_t*>(eid->data);
std::copy(all_edges_src_.begin(), all_edges_src_.end(), src_ptr);
std::copy(all_edges_dst_.begin(), all_edges_dst_.end(), dst_ptr);
for (uint64_t eid = 0; eid < num_edges_; ++eid) {
eid_ptr[eid] = eid;
}
}
return EdgeArray{src, dst, eid};
}
// O(V)
DegreeArray Graph::InDegrees(IdArray vids) const {
CHECK(IsValidIdArray(vids)) << "Invalid vertex id array.";
const auto len = vids->shape[0];
const int64_t* vid_data = static_cast<int64_t*>(vids->data);
DegreeArray rst = DegreeArray::Empty({len}, vids->dtype, vids->ctx);
int64_t* rst_data = static_cast<int64_t*>(rst->data);
for (int64_t i = 0; i < len; ++i) {
const auto vid = vid_data[i];
CHECK(HasVertex(vid)) << "Invalid vertex: " << vid;
rst_data[i] = reverse_adjlist_[vid].succ.size();
}
return rst;
}
// O(V)
DegreeArray Graph::OutDegrees(IdArray vids) const {
CHECK(IsValidIdArray(vids)) << "Invalid vertex id array.";
const auto len = vids->shape[0];
const int64_t* vid_data = static_cast<int64_t*>(vids->data);
DegreeArray rst = DegreeArray::Empty({len}, vids->dtype, vids->ctx);
int64_t* rst_data = static_cast<int64_t*>(rst->data);
for (int64_t i = 0; i < len; ++i) {
const auto vid = vid_data[i];
CHECK(HasVertex(vid)) << "Invalid vertex: " << vid;
rst_data[i] = adjlist_[vid].succ.size();
}
return rst;
}
Subgraph Graph::VertexSubgraph(IdArray vids) const {
CHECK(IsValidIdArray(vids)) << "Invalid vertex id array.";
const auto len = vids->shape[0];
std::unordered_map<dgl_id_t, dgl_id_t> oldv2newv;
std::vector<dgl_id_t> edges;
const int64_t* vid_data = static_cast<int64_t*>(vids->data);
for (int64_t i = 0; i < len; ++i) {
oldv2newv[vid_data[i]] = i;
}
Subgraph rst;
rst.induced_vertices = vids;
rst.graph.AddVertices(len);
for (int64_t i = 0; i < len; ++i) {
const dgl_id_t oldvid = vid_data[i];
const dgl_id_t newvid = i;
for (size_t j = 0; j < adjlist_[oldvid].succ.size(); ++j) {
const dgl_id_t oldsucc = adjlist_[oldvid].succ[j];
if (oldv2newv.count(oldsucc)) {
const dgl_id_t newsucc = oldv2newv[oldsucc];
edges.push_back(adjlist_[oldvid].edge_id[j]);
rst.graph.AddEdge(newvid, newsucc);
}
}
}
rst.induced_edges = IdArray::Empty({static_cast<int64_t>(edges.size())}, vids->dtype, vids->ctx);
std::copy(edges.begin(), edges.end(), static_cast<int64_t*>(rst.induced_edges->data));
return rst;
}
Subgraph Graph::EdgeSubgraph(IdArray src, IdArray dst) const {
LOG(FATAL) << "not implemented";
return Subgraph();
}
Graph Graph::Reverse() const {
LOG(FATAL) << "not implemented";
return *this;
}
} // namespace dgl
#include <dgl/runtime/packed_func.h>
#include <dgl/runtime/registry.h>
#include <dgl/graph.h>
#include <dgl/graph_op.h>
#include <cstdint>
#include <vector>
using tvm::runtime::TVMArgs;
using tvm::runtime::TVMArgValue;
using tvm::runtime::TVMRetValue;
using tvm::runtime::PackedFunc;
using tvm::runtime::NDArray;
namespace dgl {
// Graph handler type
typedef void* GraphHandle;
namespace {
// Convert EdgeArray structure to PackedFunc.
PackedFunc ConvertEdgeArrayToPackedFunc(const Graph::EdgeArray& ea) {
auto body = [ea] (TVMArgs args, TVMRetValue* rv) {
int which = args[0];
if (which == 0) {
*rv = std::move(ea.src);
} else if (which == 1) {
*rv = std::move(ea.dst);
} else if (which == 2) {
*rv = std::move(ea.id);
} else {
LOG(FATAL) << "invalid choice";
}
};
return PackedFunc(body);
}
// Convert Subgraph structure to PackedFunc.
PackedFunc ConvertSubgraphToPackedFunc(const Subgraph& sg) {
auto body = [sg] (TVMArgs args, TVMRetValue* rv) {
int which = args[0];
if (which == 0) {
Graph* gptr = new Graph();
*gptr = std::move(sg.graph);
GraphHandle ghandle = gptr;
*rv = ghandle;
} else if (which == 1) {
*rv = std::move(sg.induced_vertices);
} else if (which == 2) {
*rv = std::move(sg.induced_edges);
} else {
LOG(FATAL) << "invalid choice";
}
};
return PackedFunc(body);
}
// Convert the given DLTensor to a temporary DLManagedTensor that does not own memory.
DLManagedTensor* CreateTmpDLManagedTensor(const TVMArgValue& arg) {
const DLTensor* dl_tensor = arg;
DLManagedTensor* ret = new DLManagedTensor();
ret->deleter = [] (DLManagedTensor* self) { delete self; };
ret->manager_ctx = nullptr;
ret->dl_tensor = *dl_tensor;
return ret;
}
} // namespace
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphCreate")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
GraphHandle ghandle = new Graph();
*rv = ghandle;
});
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphFree")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
GraphHandle ghandle = args[0];
Graph* gptr = static_cast<Graph*>(ghandle);
delete gptr;
});
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphAddVertices")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
GraphHandle ghandle = args[0];
Graph* gptr = static_cast<Graph*>(ghandle);
uint64_t num_vertices = args[1];
gptr->AddVertices(num_vertices);
});
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphAddEdge")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
GraphHandle ghandle = args[0];
Graph* gptr = static_cast<Graph*>(ghandle);
const dgl_id_t src = args[1];
const dgl_id_t dst = args[2];
gptr->AddEdge(src, dst);
});
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphAddEdges")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
GraphHandle ghandle = args[0];
Graph* gptr = static_cast<Graph*>(ghandle);
const IdArray src = IdArray::FromDLPack(CreateTmpDLManagedTensor(args[1]));
const IdArray dst = IdArray::FromDLPack(CreateTmpDLManagedTensor(args[2]));
gptr->AddEdges(src, dst);
});
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphClear")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
GraphHandle ghandle = args[0];
Graph* gptr = static_cast<Graph*>(ghandle);
gptr->Clear();
});
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphNumVertices")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle);
*rv = static_cast<int64_t>(gptr->NumVertices());
});
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphNumEdges")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle);
*rv = static_cast<int64_t>(gptr->NumEdges());
});
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphHasVertex")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle);
const dgl_id_t vid = args[1];
*rv = gptr->HasVertex(vid);
});
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphHasVertices")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle);
const IdArray vids = IdArray::FromDLPack(CreateTmpDLManagedTensor(args[1]));
*rv = gptr->HasVertices(vids);
});
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphHasEdge")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle);
const dgl_id_t src = args[1];
const dgl_id_t dst = args[2];
*rv = gptr->HasEdge(src, dst);
});
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphHasEdges")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle);
const IdArray src = IdArray::FromDLPack(CreateTmpDLManagedTensor(args[1]));
const IdArray dst = IdArray::FromDLPack(CreateTmpDLManagedTensor(args[2]));
*rv = gptr->HasEdges(src, dst);
});
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphPredecessors")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle);
const dgl_id_t vid = args[1];
const uint64_t radius = args[2];
*rv = gptr->Predecessors(vid, radius);
});
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphSuccessors")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle);
const dgl_id_t vid = args[1];
const uint64_t radius = args[2];
*rv = gptr->Successors(vid, radius);
});
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphEdgeId")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle);
const dgl_id_t src = args[1];
const dgl_id_t dst = args[2];
*rv = static_cast<int64_t>(gptr->EdgeId(src, dst));
});
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphEdgeIds")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle);
const IdArray src = IdArray::FromDLPack(CreateTmpDLManagedTensor(args[1]));
const IdArray dst = IdArray::FromDLPack(CreateTmpDLManagedTensor(args[2]));
*rv = gptr->EdgeIds(src, dst);
});
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphInEdges_1")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle);
const dgl_id_t vid = args[1];
*rv = ConvertEdgeArrayToPackedFunc(gptr->InEdges(vid));
});
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphInEdges_2")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle);
const IdArray vids = IdArray::FromDLPack(CreateTmpDLManagedTensor(args[1]));
*rv = ConvertEdgeArrayToPackedFunc(gptr->InEdges(vids));
});
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphOutEdges_1")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle);
const dgl_id_t vid = args[1];
*rv = ConvertEdgeArrayToPackedFunc(gptr->OutEdges(vid));
});
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphOutEdges_2")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle);
const IdArray vids = IdArray::FromDLPack(CreateTmpDLManagedTensor(args[1]));
*rv = ConvertEdgeArrayToPackedFunc(gptr->OutEdges(vids));
});
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphEdges")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle);
const bool sorted = args[1];
*rv = ConvertEdgeArrayToPackedFunc(gptr->Edges(sorted));
});
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphInDegree")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle);
const dgl_id_t vid = args[1];
*rv = static_cast<int64_t>(gptr->InDegree(vid));
});
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphInDegrees")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle);
const IdArray vids = IdArray::FromDLPack(CreateTmpDLManagedTensor(args[1]));
*rv = gptr->InDegrees(vids);
});
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphOutDegree")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle);
const dgl_id_t vid = args[1];
*rv = static_cast<int64_t>(gptr->OutDegree(vid));
});
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphOutDegrees")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle);
const IdArray vids = IdArray::FromDLPack(CreateTmpDLManagedTensor(args[1]));
*rv = gptr->OutDegrees(vids);
});
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphVertexSubgraph")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle);
const IdArray vids = IdArray::FromDLPack(CreateTmpDLManagedTensor(args[1]));
*rv = ConvertSubgraphToPackedFunc(gptr->VertexSubgraph(vids));
});
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLDisjointUnion")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
void* list = args[0];
GraphHandle* inhandles = static_cast<GraphHandle*>(list);
int list_size = args[1];
std::vector<const Graph*> graphs;
for (int i = 0; i < list_size; ++i) {
const Graph* gr = static_cast<const Graph*>(inhandles[i]);
graphs.push_back(gr);
}
Graph* gptr = new Graph();
*gptr = GraphOp::DisjointUnion(std::move(graphs));
GraphHandle ghandle = gptr;
*rv = ghandle;
});
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLDisjointPartitionByNum")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle);
int64_t num = args[1];
std::vector<Graph>&& rst = GraphOp::DisjointPartitionByNum(gptr, num);
// return the pointer array as an integer array
const int64_t len = rst.size();
NDArray ptr_array = NDArray::Empty({len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
int64_t* ptr_array_data = static_cast<int64_t*>(ptr_array->data);
for (size_t i = 0; i < rst.size(); ++i) {
Graph* ptr = new Graph();
*ptr = std::move(rst[i]);
ptr_array_data[i] = reinterpret_cast<std::intptr_t>(ptr);
}
*rv = ptr_array;
});
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLDisjointPartitionBySizes")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle);
const IdArray sizes = IdArray::FromDLPack(CreateTmpDLManagedTensor(args[1]));
std::vector<Graph>&& rst = GraphOp::DisjointPartitionBySizes(gptr, sizes);
// return the pointer array as an integer array
const int64_t len = rst.size();
NDArray ptr_array = NDArray::Empty({len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
int64_t* ptr_array_data = static_cast<int64_t*>(ptr_array->data);
for (size_t i = 0; i < rst.size(); ++i) {
Graph* ptr = new Graph();
*ptr = std::move(rst[i]);
ptr_array_data[i] = reinterpret_cast<std::intptr_t>(ptr);
}
*rv = ptr_array;
});
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphLineGraph")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
GraphHandle ghandle = args[0];
bool backtracking = args[1];
const Graph* gptr = static_cast<Graph*>(ghandle);
Graph* lgptr = new Graph();
*lgptr = GraphOp::LineGraph(gptr, backtracking);
GraphHandle lghandle = lgptr;
*rv = lghandle;
});
} // namespace dgl
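On the Python side, the `_init_api("dgl.graph_index")` call seen earlier turns each name registered above into a module-level callable; a hypothetical round trip, for illustration:

handle = _CAPI_DGLGraphCreate()
_CAPI_DGLGraphAddVertices(handle, 5)
_CAPI_DGLGraphAddEdge(handle, 0, 1)
assert _CAPI_DGLGraphNumEdges(handle) == 1
_CAPI_DGLGraphFree(handle)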
// Graph operation implementation
#include <dgl/graph_op.h>
#include <algorithm>
#include <map>
#include <utility>
#include <vector>
namespace dgl {
Graph GraphOp::LineGraph(const Graph* g, bool backtracking) {
typedef std::pair<dgl_id_t, dgl_id_t> entry;
typedef std::map<dgl_id_t, std::vector<entry>> csm; // Compressed Sparse Matrix
csm adj;
std::vector<entry> vec;
for (size_t i = 0; i != g->all_edges_src_.size(); ++i) {
auto u = g->all_edges_src_[i];
auto v = g->all_edges_dst_[i];
auto ret = adj.insert(csm::value_type(u, vec));
(ret.first)->second.push_back(std::make_pair(v, i));
}
std::vector<dgl_id_t> lg_src, lg_dst;
for (size_t i = 0; i != g->all_edges_src_.size(); ++i) {
auto u = g->all_edges_src_[i];
auto v = g->all_edges_dst_[i];
auto j = adj.find(v);
if (j != adj.end()) {
for (size_t k = 0; k != j->second.size(); ++k) {
        if (backtracking || j->second[k].first != u) {
lg_src.push_back(i);
lg_dst.push_back(j->second[k].second);
}
}
}
}
const int64_t len = lg_src.size();
IdArray src = IdArray::Empty({len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
IdArray dst = IdArray::Empty({len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
int64_t* src_ptr = static_cast<int64_t*>(src->data);
int64_t* dst_ptr = static_cast<int64_t*>(dst->data);
std::copy(lg_src.begin(), lg_src.end(), src_ptr);
std::copy(lg_dst.begin(), lg_dst.end(), dst_ptr);
Graph lg;
lg.AddVertices(g->NumEdges());
lg.AddEdges(src, dst);
return lg;
}
Graph GraphOp::DisjointUnion(std::vector<const Graph*> graphs) {
Graph rst;
uint64_t cumsum = 0;
for (const Graph* gr : graphs) {
rst.AddVertices(gr->NumVertices());
for (uint64_t i = 0; i < gr->NumEdges(); ++i) {
rst.AddEdge(gr->all_edges_src_[i] + cumsum, gr->all_edges_dst_[i] + cumsum);
}
cumsum += gr->NumVertices();
}
return rst;
}
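// For example: the disjoint union of a 2-node graph with edge 0->1 and a
// 3-node graph with edge 0->2 is a 5-node graph with edges 0->1 and 2->4;
// the second graph's node ids are shifted by the first graph's node count.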
std::vector<Graph> GraphOp::DisjointPartitionByNum(const Graph* graph, int64_t num) {
CHECK(num != 0 && graph->NumVertices() % num == 0)
<< "Number of partitions must evenly divide the number of nodes.";
IdArray sizes = IdArray::Empty({num}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
int64_t* sizes_data = static_cast<int64_t*>(sizes->data);
std::fill(sizes_data, sizes_data + num, graph->NumVertices() / num);
return DisjointPartitionBySizes(graph, sizes);
}
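// For example: a 9-node batched graph with num = 3 is partitioned exactly as
// DisjointPartitionBySizes with sizes = {3, 3, 3}.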
std::vector<Graph> GraphOp::DisjointPartitionBySizes(const Graph* graph, IdArray sizes) {
const int64_t len = sizes->shape[0];
const int64_t* sizes_data = static_cast<int64_t*>(sizes->data);
std::vector<int64_t> cumsum;
cumsum.push_back(0);
for (int64_t i = 0; i < len; ++i) {
cumsum.push_back(cumsum[i] + sizes_data[i]);
}
CHECK_EQ(cumsum[len], graph->NumVertices())
<< "Sum of the given sizes must equal to the number of nodes.";
dgl_id_t node_offset = 0, edge_offset = 0;
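  // Each partition owns a contiguous range of node and edge ids in the
  // batched graph; node_offset/edge_offset mark where the current one starts.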
std::vector<Graph> rst(len);
for (int64_t i = 0; i < len; ++i) {
// copy adj
rst[i].adjlist_.insert(rst[i].adjlist_.end(),
graph->adjlist_.begin() + node_offset,
graph->adjlist_.begin() + node_offset + sizes_data[i]);
rst[i].reverse_adjlist_.insert(rst[i].reverse_adjlist_.end(),
graph->reverse_adjlist_.begin() + node_offset,
graph->reverse_adjlist_.begin() + node_offset + sizes_data[i]);
    // relabel adjacency lists: subtract the offsets so node and edge ids in
    // each partition start at 0
size_t num_edges = 0;
for (auto& elist : rst[i].adjlist_) {
for (size_t j = 0; j < elist.succ.size(); ++j) {
elist.succ[j] -= node_offset;
elist.edge_id[j] -= edge_offset;
}
num_edges += elist.succ.size();
}
for (auto& elist : rst[i].reverse_adjlist_) {
for (size_t j = 0; j < elist.succ.size(); ++j) {
elist.succ[j] -= node_offset;
elist.edge_id[j] -= edge_offset;
}
}
// copy edges
rst[i].all_edges_src_.reserve(num_edges);
rst[i].all_edges_dst_.reserve(num_edges);
rst[i].num_edges_ = num_edges;
for (size_t j = edge_offset; j < edge_offset + num_edges; ++j) {
rst[i].all_edges_src_.push_back(graph->all_edges_src_[j] - node_offset);
rst[i].all_edges_dst_.push_back(graph->all_edges_dst_[j] - node_offset);
}
// update offset
CHECK_EQ(rst[i].NumVertices(), sizes_data[i]);
CHECK_EQ(rst[i].NumEdges(), num_edges);
node_offset += sizes_data[i];
edge_offset += num_edges;
}
  /* Alternative implementation (unused): add vertices per partition, then
     scan every edge and route it to the partition owning its endpoints.
  for (int64_t i = 0; i < len; ++i) {
rst[i].AddVertices(sizes_data[i]);
}
for (dgl_id_t eid = 0; eid < graph->num_edges_; ++eid) {
const dgl_id_t src = graph->all_edges_src_[eid];
const dgl_id_t dst = graph->all_edges_dst_[eid];
size_t src_select = 0, dst_select = 0;
for (size_t i = 1; i < cumsum.size(); ++i) { // TODO: replace with binary search
if (cumsum[i] > src) {
src_select = i;
break;
}
}
for (size_t i = 1; i < cumsum.size(); ++i) { // TODO: replace with binary search
if (cumsum[i] > dst) {
dst_select = i;
break;
}
}
if (src_select != dst_select) {
// the edge is ignored if across two partitions
continue;
}
const int64_t offset = cumsum[src_select - 1];
rst[src_select - 1].AddEdge(src - offset, dst - offset);
}*/
return rst;
}
} // namespace dgl
# C API and runtime
Borrowed and adapted from the TVM project.
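Functions implemented in C++ are exposed to the frontend through the
packed-function registry. As a minimal sketch (assuming the registered names
defined in the sources below), a global can be looked up and invoked from C++
like this:

    #include <dgl/runtime/registry.h>
    #include <dgl/runtime/packed_func.h>

    using tvm::runtime::Registry;
    using tvm::runtime::PackedFunc;
    using tvm::runtime::TVMRetValue;

    void CallLineGraph(void* graph_handle) {
      // Look up a function registered with TVM_REGISTER_GLOBAL.
      const PackedFunc* f = Registry::Get("graph_index._CAPI_DGLGraphLineGraph");
      if (f == nullptr) return;  // not registered in this build
      TVMRetValue rv = (*f)(graph_handle, false);
      void* line_graph_handle = rv;  // handle to the newly created line graph
      (void)line_graph_handle;
    }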
/*!
* Copyright (c) 2016 by Contributors
* \file c_runtime_api.cc
* \brief Runtime API implementation
*/
#include <dmlc/thread_local.h>
#include <dgl/runtime/c_runtime_api.h>
#include <dgl/runtime/c_backend_api.h>
#include <dgl/runtime/packed_func.h>
#include <dgl/runtime/module.h>
#include <dgl/runtime/registry.h>
#include <dgl/runtime/device_api.h>
#include <array>
#include <algorithm>
#include <string>
#include <cstdlib>
#include "runtime_base.h"
namespace tvm {
namespace runtime {
/*!
* \brief The name of Device API factory.
* \param type The device type.
*/
inline std::string DeviceName(int type) {
switch (type) {
case kDLCPU: return "cpu";
case kDLGPU: return "gpu";
case kDLOpenCL: return "opencl";
case kDLSDAccel: return "sdaccel";
case kDLAOCL: return "aocl";
case kDLVulkan: return "vulkan";
case kDLMetal: return "metal";
case kDLVPI: return "vpi";
case kDLROCM: return "rocm";
case kOpenGL: return "opengl";
case kExtDev: return "ext_dev";
    default: LOG(FATAL) << "unknown type = " << type; return "Unknown";
}
}
class DeviceAPIManager {
public:
static const int kMaxDeviceAPI = 32;
// Get API
static DeviceAPI* Get(const TVMContext& ctx) {
return Get(ctx.device_type);
}
static DeviceAPI* Get(int dev_type, bool allow_missing = false) {
return Global()->GetAPI(dev_type, allow_missing);
}
private:
std::array<DeviceAPI*, kMaxDeviceAPI> api_;
DeviceAPI* rpc_api_{nullptr};
std::mutex mutex_;
// constructor
DeviceAPIManager() {
std::fill(api_.begin(), api_.end(), nullptr);
}
// Global static variable.
static DeviceAPIManager* Global() {
static DeviceAPIManager inst;
return &inst;
}
// Get or initialize API.
DeviceAPI* GetAPI(int type, bool allow_missing) {
if (type < kRPCSessMask) {
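      // Double-checked locking: read once without the lock, then re-check
      // under the mutex so concurrent first calls initialize the API once.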
if (api_[type] != nullptr) return api_[type];
std::lock_guard<std::mutex> lock(mutex_);
if (api_[type] != nullptr) return api_[type];
api_[type] = GetAPI(DeviceName(type), allow_missing);
return api_[type];
} else {
if (rpc_api_ != nullptr) return rpc_api_;
std::lock_guard<std::mutex> lock(mutex_);
if (rpc_api_ != nullptr) return rpc_api_;
rpc_api_ = GetAPI("rpc", allow_missing);
return rpc_api_;
}
}
  DeviceAPI* GetAPI(const std::string& name, bool allow_missing) {
std::string factory = "device_api." + name;
auto* f = Registry::Get(factory);
if (f == nullptr) {
CHECK(allow_missing)
<< "Device API " << name << " is not enabled.";
return nullptr;
}
void* ptr = (*f)();
return static_cast<DeviceAPI*>(ptr);
}
};
DeviceAPI* DeviceAPI::Get(TVMContext ctx, bool allow_missing) {
return DeviceAPIManager::Get(
static_cast<int>(ctx.device_type), allow_missing);
}
void* DeviceAPI::AllocWorkspace(TVMContext ctx,
size_t size,
TVMType type_hint) {
return AllocDataSpace(ctx, size, kTempAllocaAlignment, type_hint);
}
void DeviceAPI::FreeWorkspace(TVMContext ctx, void* ptr) {
FreeDataSpace(ctx, ptr);
}
TVMStreamHandle DeviceAPI::CreateStream(TVMContext ctx) {
LOG(FATAL) << "Device does not support stream api.";
return 0;
}
void DeviceAPI::FreeStream(TVMContext ctx, TVMStreamHandle stream) {
LOG(FATAL) << "Device does not support stream api.";
}
void DeviceAPI::SyncStreamFromTo(TVMContext ctx,
TVMStreamHandle event_src,
TVMStreamHandle event_dst) {
LOG(FATAL) << "Device does not support stream api.";
}
} // namespace runtime
} // namespace tvm
using namespace tvm::runtime;
struct TVMRuntimeEntry {
std::string ret_str;
std::string last_error;
TVMByteArray ret_bytes;
};
typedef dmlc::ThreadLocalStore<TVMRuntimeEntry> TVMAPIRuntimeStore;
const char *TVMGetLastError() {
return TVMAPIRuntimeStore::Get()->last_error.c_str();
}
void TVMAPISetLastError(const char* msg) {
#ifndef _LIBCPP_SGX_CONFIG
TVMAPIRuntimeStore::Get()->last_error = msg;
#else
sgx::OCallPackedFunc("__sgx_set_last_error__", msg);
#endif
}
int TVMModLoadFromFile(const char* file_name,
const char* format,
TVMModuleHandle* out) {
API_BEGIN();
Module m = Module::LoadFromFile(file_name, format);
*out = new Module(m);
API_END();
}
int TVMModImport(TVMModuleHandle mod,
TVMModuleHandle dep) {
API_BEGIN();
static_cast<Module*>(mod)->Import(
*static_cast<Module*>(dep));
API_END();
}
int TVMModGetFunction(TVMModuleHandle mod,
const char* func_name,
int query_imports,
TVMFunctionHandle *func) {
API_BEGIN();
PackedFunc pf = static_cast<Module*>(mod)->GetFunction(
func_name, query_imports != 0);
if (pf != nullptr) {
*func = new PackedFunc(pf);
} else {
*func = nullptr;
}
API_END();
}
int TVMModFree(TVMModuleHandle mod) {
API_BEGIN();
delete static_cast<Module*>(mod);
API_END();
}
int TVMBackendGetFuncFromEnv(void* mod_node,
const char* func_name,
TVMFunctionHandle *func) {
API_BEGIN();
*func = (TVMFunctionHandle)(
static_cast<ModuleNode*>(mod_node)->GetFuncFromEnv(func_name));
API_END();
}
void* TVMBackendAllocWorkspace(int device_type,
int device_id,
uint64_t size,
int dtype_code_hint,
int dtype_bits_hint) {
TVMContext ctx;
ctx.device_type = static_cast<DLDeviceType>(device_type);
ctx.device_id = device_id;
TVMType type_hint;
type_hint.code = static_cast<decltype(type_hint.code)>(dtype_code_hint);
type_hint.bits = static_cast<decltype(type_hint.bits)>(dtype_bits_hint);
type_hint.lanes = 1;
return DeviceAPIManager::Get(ctx)->AllocWorkspace(ctx,
static_cast<size_t>(size),
type_hint);
}
int TVMBackendFreeWorkspace(int device_type,
int device_id,
void* ptr) {
TVMContext ctx;
ctx.device_type = static_cast<DLDeviceType>(device_type);
ctx.device_id = device_id;
DeviceAPIManager::Get(ctx)->FreeWorkspace(ctx, ptr);
return 0;
}
int TVMBackendRunOnce(void** handle,
int (*f)(void*),
void* cdata,
int nbytes) {
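  // Run f(cdata) only the first time this handle is seen; the handle itself
  // doubles as the "already ran" flag (note: the check-and-set is not atomic).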
if (*handle == nullptr) {
*handle = reinterpret_cast<void*>(1);
return (*f)(cdata);
}
return 0;
}
int TVMFuncFree(TVMFunctionHandle func) {
API_BEGIN();
delete static_cast<PackedFunc*>(func);
API_END();
}
int TVMFuncCall(TVMFunctionHandle func,
TVMValue* args,
int* arg_type_codes,
int num_args,
TVMValue* ret_val,
int* ret_type_code) {
API_BEGIN();
TVMRetValue rv;
(*static_cast<const PackedFunc*>(func)).CallPacked(
TVMArgs(args, arg_type_codes, num_args), &rv);
// handle return string.
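  // A string or bytes result would dangle once rv is destroyed, so copy it
  // into the thread-local store and return pointers into that storage.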
if (rv.type_code() == kStr ||
rv.type_code() == kTVMType ||
rv.type_code() == kBytes) {
TVMRuntimeEntry* e = TVMAPIRuntimeStore::Get();
if (rv.type_code() != kTVMType) {
e->ret_str = *rv.ptr<std::string>();
} else {
e->ret_str = rv.operator std::string();
}
if (rv.type_code() == kBytes) {
e->ret_bytes.data = e->ret_str.c_str();
e->ret_bytes.size = e->ret_str.length();
*ret_type_code = kBytes;
ret_val->v_handle = &(e->ret_bytes);
} else {
*ret_type_code = kStr;
ret_val->v_str = e->ret_str.c_str();
}
} else {
rv.MoveToCHost(ret_val, ret_type_code);
}
API_END();
}
int TVMCFuncSetReturn(TVMRetValueHandle ret,
TVMValue* value,
int* type_code,
int num_ret) {
API_BEGIN();
CHECK_EQ(num_ret, 1);
TVMRetValue* rv = static_cast<TVMRetValue*>(ret);
*rv = TVMArgValue(value[0], type_code[0]);
API_END();
}
int TVMFuncCreateFromCFunc(TVMPackedCFunc func,
void* resource_handle,
TVMPackedCFuncFinalizer fin,
TVMFunctionHandle *out) {
API_BEGIN();
if (fin == nullptr) {
*out = new PackedFunc(
[func, resource_handle](TVMArgs args, TVMRetValue* rv) {
int ret = func((TVMValue*)args.values, (int*)args.type_codes, // NOLINT(*)
args.num_args, rv, resource_handle);
if (ret != 0) {
std::string err = "TVMCall CFunc Error:\n";
err += TVMGetLastError();
throw dmlc::Error(err);
}
});
} else {
    // Wrap resource_handle in a shared_ptr with fin as the deleter,
    // so fin is called once the lambda (and its captures) is destroyed.
std::shared_ptr<void> rpack(resource_handle, fin);
*out = new PackedFunc(
[func, rpack](TVMArgs args, TVMRetValue* rv) {
int ret = func((TVMValue*)args.values, (int*)args.type_codes, // NOLINT(*)
args.num_args, rv, rpack.get());
if (ret != 0) {
std::string err = "TVMCall CFunc Error:\n";
err += TVMGetLastError();
throw dmlc::Error(err);
}
});
}
API_END();
}
int TVMStreamCreate(int device_type, int device_id, TVMStreamHandle* out) {
API_BEGIN();
TVMContext ctx;
ctx.device_type = static_cast<DLDeviceType>(device_type);
ctx.device_id = device_id;
*out = DeviceAPIManager::Get(ctx)->CreateStream(ctx);
API_END();
}
int TVMStreamFree(int device_type, int device_id, TVMStreamHandle stream) {
API_BEGIN();
TVMContext ctx;
ctx.device_type = static_cast<DLDeviceType>(device_type);
ctx.device_id = device_id;
DeviceAPIManager::Get(ctx)->FreeStream(ctx, stream);
API_END();
}
int TVMSetStream(int device_type, int device_id, TVMStreamHandle stream) {
API_BEGIN();
TVMContext ctx;
ctx.device_type = static_cast<DLDeviceType>(device_type);
ctx.device_id = device_id;
DeviceAPIManager::Get(ctx)->SetStream(ctx, stream);
API_END();
}
int TVMSynchronize(int device_type, int device_id, TVMStreamHandle stream) {
API_BEGIN();
TVMContext ctx;
ctx.device_type = static_cast<DLDeviceType>(device_type);
ctx.device_id = device_id;
DeviceAPIManager::Get(ctx)->StreamSync(ctx, stream);
API_END();
}
int TVMStreamStreamSynchronize(int device_type,
int device_id,
TVMStreamHandle src,
TVMStreamHandle dst) {
API_BEGIN();
TVMContext ctx;
ctx.device_type = static_cast<DLDeviceType>(device_type);
ctx.device_id = device_id;
DeviceAPIManager::Get(ctx)->SyncStreamFromTo(ctx, src, dst);
API_END();
}
int TVMCbArgToReturn(TVMValue* value, int code) {
API_BEGIN();
tvm::runtime::TVMRetValue rv;
rv = tvm::runtime::TVMArgValue(*value, code);
int tcode;
rv.MoveToCHost(value, &tcode);
CHECK_EQ(tcode, code);
API_END();
}
// set device api
TVM_REGISTER_GLOBAL(tvm::runtime::symbol::tvm_set_device)
.set_body([](TVMArgs args, TVMRetValue *ret) {
TVMContext ctx;
ctx.device_type = static_cast<DLDeviceType>(args[0].operator int());
ctx.device_id = args[1];
DeviceAPIManager::Get(ctx)->SetDevice(ctx);
});
// set device api
TVM_REGISTER_GLOBAL("_GetDeviceAttr")
.set_body([](TVMArgs args, TVMRetValue *ret) {
TVMContext ctx;
ctx.device_type = static_cast<DLDeviceType>(args[0].operator int());
ctx.device_id = args[1];
DeviceAttrKind kind = static_cast<DeviceAttrKind>(args[2].operator int());
if (kind == kExist) {
DeviceAPI* api = DeviceAPIManager::Get(ctx.device_type, true);
if (api != nullptr) {
api->GetAttr(ctx, kind, ret);
} else {
*ret = 0;
}
} else {
DeviceAPIManager::Get(ctx)->GetAttr(ctx, kind, ret);
}
});
/*!
* Copyright (c) 2016 by Contributors
* \file cpu_device_api.cc
*/
#include <dmlc/logging.h>
#include <dmlc/thread_local.h>
#include <dgl/runtime/registry.h>
#include <dgl/runtime/device_api.h>
#include <cstdlib>
#include <cstring>
#include "workspace_pool.h"
namespace tvm {
namespace runtime {
class CPUDeviceAPI final : public DeviceAPI {
public:
void SetDevice(TVMContext ctx) final {}
void GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) final {
if (kind == kExist) {
*rv = 1;
}
}
void* AllocDataSpace(TVMContext ctx,
size_t nbytes,
size_t alignment,
TVMType type_hint) final {
void* ptr;
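    // Aligned allocation is platform specific: _aligned_malloc on MSVC,
    // memalign inside SGX, and posix_memalign elsewhere.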
#if _MSC_VER
ptr = _aligned_malloc(nbytes, alignment);
if (ptr == nullptr) throw std::bad_alloc();
#elif defined(_LIBCPP_SGX_CONFIG)
ptr = memalign(alignment, nbytes);
if (ptr == nullptr) throw std::bad_alloc();
#else
int ret = posix_memalign(&ptr, alignment, nbytes);
if (ret != 0) throw std::bad_alloc();
#endif
return ptr;
}
void FreeDataSpace(TVMContext ctx, void* ptr) final {
#if _MSC_VER
_aligned_free(ptr);
#else
free(ptr);
#endif
}
void CopyDataFromTo(const void* from,
size_t from_offset,
void* to,
size_t to_offset,
size_t size,
TVMContext ctx_from,
TVMContext ctx_to,
TVMType type_hint,
TVMStreamHandle stream) final {
memcpy(static_cast<char*>(to) + to_offset,
static_cast<const char*>(from) + from_offset,
size);
}
void StreamSync(TVMContext ctx, TVMStreamHandle stream) final {
}
void* AllocWorkspace(TVMContext ctx, size_t size, TVMType type_hint) final;
void FreeWorkspace(TVMContext ctx, void* data) final;
static const std::shared_ptr<CPUDeviceAPI>& Global() {
static std::shared_ptr<CPUDeviceAPI> inst =
std::make_shared<CPUDeviceAPI>();
return inst;
}
};
struct CPUWorkspacePool : public WorkspacePool {
CPUWorkspacePool() :
WorkspacePool(kDLCPU, CPUDeviceAPI::Global()) {}
};
void* CPUDeviceAPI::AllocWorkspace(TVMContext ctx,
size_t size,
TVMType type_hint) {
return dmlc::ThreadLocalStore<CPUWorkspacePool>::Get()
->AllocWorkspace(ctx, size);
}
void CPUDeviceAPI::FreeWorkspace(TVMContext ctx, void* data) {
dmlc::ThreadLocalStore<CPUWorkspacePool>::Get()->FreeWorkspace(ctx, data);
}
TVM_REGISTER_GLOBAL("device_api.cpu")
.set_body([](TVMArgs args, TVMRetValue* rv) {
DeviceAPI* ptr = CPUDeviceAPI::Global().get();
*rv = static_cast<void*>(ptr);
});
} // namespace runtime
} // namespace tvm
/*!
* Copyright (c) 2017 by Contributors
* \file dso_dll_module.cc
* \brief Module to load from dynamic shared library.
*/
#include <dgl/runtime/module.h>
#include <dgl/runtime/registry.h>
#include <dgl/runtime/packed_func.h>
#include "module_util.h"
#if defined(_WIN32)
#include <windows.h>
#else
#include <dlfcn.h>
#endif
namespace tvm {
namespace runtime {
// Module to load from a dynamic shared library.
// This is the default module TVM uses for host-side AOT compilation.
class DSOModuleNode final : public ModuleNode {
public:
~DSOModuleNode() {
if (lib_handle_) Unload();
}
const char* type_key() const final {
return "dso";
}
PackedFunc GetFunction(
const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) final {
BackendPackedCFunc faddr;
if (name == runtime::symbol::tvm_module_main) {
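      // tvm_module_main stores the *name* of the real entry function;
      // read that name first, then resolve it to an address.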
const char* entry_name = reinterpret_cast<const char*>(
GetSymbol(runtime::symbol::tvm_module_main));
      CHECK(entry_name != nullptr)
          << "Symbol " << runtime::symbol::tvm_module_main << " is not present";
faddr = reinterpret_cast<BackendPackedCFunc>(GetSymbol(entry_name));
} else {
faddr = reinterpret_cast<BackendPackedCFunc>(GetSymbol(name.c_str()));
}
if (faddr == nullptr) return PackedFunc();
return WrapPackedFunc(faddr, sptr_to_self);
}
void Init(const std::string& name) {
Load(name);
if (auto *ctx_addr =
reinterpret_cast<void**>(GetSymbol(runtime::symbol::tvm_module_ctx))) {
*ctx_addr = this;
}
InitContextFunctions([this](const char* fname) {
return GetSymbol(fname);
});
// Load the imported modules
const char* dev_mblob =
reinterpret_cast<const char*>(
GetSymbol(runtime::symbol::tvm_dev_mblob));
if (dev_mblob != nullptr) {
ImportModuleBlob(dev_mblob, &imports_);
}
}
private:
// Platform dependent handling.
#if defined(_WIN32)
// library handle
HMODULE lib_handle_{nullptr};
// Load the library
void Load(const std::string& name) {
// use wstring version that is needed by LLVM.
std::wstring wname(name.begin(), name.end());
lib_handle_ = LoadLibraryW(wname.c_str());
CHECK(lib_handle_ != nullptr)
<< "Failed to load dynamic shared library " << name;
}
void* GetSymbol(const char* name) {
return reinterpret_cast<void*>(
GetProcAddress(lib_handle_, (LPCSTR)name)); // NOLINT(*)
}
void Unload() {
FreeLibrary(lib_handle_);
}
#else
// Library handle
void* lib_handle_{nullptr};
// load the library
void Load(const std::string& name) {
lib_handle_ = dlopen(name.c_str(), RTLD_LAZY | RTLD_LOCAL);
CHECK(lib_handle_ != nullptr)
<< "Failed to load dynamic shared library " << name
<< " " << dlerror();
}
void* GetSymbol(const char* name) {
return dlsym(lib_handle_, name);
}
void Unload() {
dlclose(lib_handle_);
}
#endif
};
TVM_REGISTER_GLOBAL("module.loadfile_so")
.set_body([](TVMArgs args, TVMRetValue* rv) {
std::shared_ptr<DSOModuleNode> n = std::make_shared<DSOModuleNode>();
n->Init(args[0]);
*rv = runtime::Module(n);
});
} // namespace runtime
} // namespace tvm
/*!
* Copyright (c) 2017 by Contributors
* \file file_util.cc
*/
#include <dmlc/json.h>
#include <dmlc/logging.h>
#include <dgl/runtime/serializer.h>
#include <fstream>
#include <vector>
#include "file_util.h"
namespace tvm {
namespace runtime {
void FunctionInfo::Save(dmlc::JSONWriter* writer) const {
std::vector<std::string> sarg_types(arg_types.size());
for (size_t i = 0; i < arg_types.size(); ++i) {
sarg_types[i] = TVMType2String(arg_types[i]);
}
writer->BeginObject();
writer->WriteObjectKeyValue("name", name);
writer->WriteObjectKeyValue("arg_types", sarg_types);
writer->WriteObjectKeyValue("thread_axis_tags", thread_axis_tags);
writer->EndObject();
}
void FunctionInfo::Load(dmlc::JSONReader* reader) {
dmlc::JSONObjectReadHelper helper;
std::vector<std::string> sarg_types;
helper.DeclareField("name", &name);
helper.DeclareField("arg_types", &sarg_types);
helper.DeclareField("thread_axis_tags", &thread_axis_tags);
helper.ReadAllFields(reader);
arg_types.resize(sarg_types.size());
for (size_t i = 0; i < arg_types.size(); ++i) {
arg_types[i] = String2TVMType(sarg_types[i]);
}
}
void FunctionInfo::Save(dmlc::Stream* writer) const {
writer->Write(name);
writer->Write(arg_types);
writer->Write(thread_axis_tags);
}
bool FunctionInfo::Load(dmlc::Stream* reader) {
if (!reader->Read(&name)) return false;
if (!reader->Read(&arg_types)) return false;
if (!reader->Read(&thread_axis_tags)) return false;
return true;
}
std::string GetFileFormat(const std::string& file_name,
const std::string& format) {
std::string fmt = format;
if (fmt.length() == 0) {
if (file_name.find(".signed.so") != std::string::npos) return "sgx";
size_t pos = file_name.find_last_of(".");
if (pos != std::string::npos) {
return file_name.substr(pos + 1, file_name.length() - pos - 1);
} else {
return "";
}
} else {
return format;
}
}
std::string GetCacheDir() {
char* env_cache_dir;
if ((env_cache_dir = getenv("TVM_CACHE_DIR"))) return env_cache_dir;
if ((env_cache_dir = getenv("XDG_CACHE_HOME"))) {
return std::string(env_cache_dir) + "/tvm";
}
if ((env_cache_dir = getenv("HOME"))) {
return std::string(env_cache_dir) + "/.cache/tvm";
}
return ".";
}
std::string GetFileBasename(const std::string& file_name) {
size_t last_slash = file_name.find_last_of("/");
if (last_slash == std::string::npos) return file_name;
return file_name.substr(last_slash + 1);
}
std::string GetMetaFilePath(const std::string& file_name) {
size_t pos = file_name.find_last_of(".");
if (pos != std::string::npos) {
return file_name.substr(0, pos) + ".tvm_meta.json";
} else {
return file_name + ".tvm_meta.json";
}
}
void LoadBinaryFromFile(const std::string& file_name,
std::string* data) {
std::ifstream fs(file_name, std::ios::in | std::ios::binary);
CHECK(!fs.fail()) << "Cannot open " << file_name;
// get its size:
fs.seekg(0, std::ios::end);
size_t size = static_cast<size_t>(fs.tellg());
fs.seekg(0, std::ios::beg);
data->resize(size);
fs.read(&(*data)[0], size);
}
void SaveBinaryToFile(
const std::string& file_name,
const std::string& data) {
std::ofstream fs(file_name, std::ios::out | std::ios::binary);
CHECK(!fs.fail()) << "Cannot open " << file_name;
fs.write(&data[0], data.length());
}
void SaveMetaDataToFile(
const std::string& file_name,
const std::unordered_map<std::string, FunctionInfo>& fmap) {
std::string version = "0.1.0";
std::ofstream fs(file_name.c_str());
CHECK(!fs.fail()) << "Cannot open file " << file_name;
dmlc::JSONWriter writer(&fs);
writer.BeginObject();
writer.WriteObjectKeyValue("tvm_version", version);
writer.WriteObjectKeyValue("func_info", fmap);
writer.EndObject();
fs.close();
}
void LoadMetaDataFromFile(
const std::string& file_name,
std::unordered_map<std::string, FunctionInfo>* fmap) {
std::ifstream fs(file_name.c_str());
CHECK(!fs.fail()) << "Cannot open file " << file_name;
std::string version;
dmlc::JSONReader reader(&fs);
dmlc::JSONObjectReadHelper helper;
helper.DeclareField("tvm_version", &version);
helper.DeclareField("func_info", fmap);
helper.ReadAllFields(&reader);
fs.close();
}
} // namespace runtime
} // namespace tvm