"docs/source/guide_ko/nn-construction.rst" did not exist on "863c249568058ab7c56cb5165de73c1b7b750e6e"
Unverified Commit 8801154b authored by VoVAllen's avatar VoVAllen Committed by GitHub
Browse files

Merge pull request #1 from jermainewang/cpp

Cpp
parents b46abb09 b2c1c4fa
"""Package for graph generators"""
from __future__ import absolute_import
from .line import *
......@@ -4,9 +4,9 @@ from __future__ import absolute_import
import networkx as nx
import numpy as np
import dgl.backend as F
from dgl.graph import DGLGraph
from dgl.frame import FrameRef
from .. import backend as F
from ..graph import DGLGraph
from ..frame import FrameRef
def line_graph(G, no_backtracking=False):
"""Create the line graph that shares the underlying features.
......
......@@ -3,20 +3,22 @@
from __future__ import absolute_import
import networkx as nx
from networkx.classes.digraph import DiGraph
import numpy as np
import dgl
from dgl.base import ALL, is_all, __MSG__, __REPR__
import dgl.backend as F
from dgl.backend import Tensor
from dgl.cached_graph import CachedGraph, create_cached_graph
import dgl.context as context
from dgl.frame import FrameRef, merge_frames
from dgl.nx_adapt import nx_init
import dgl.scheduler as scheduler
import dgl.utils as utils
class DGLGraph(DiGraph):
from .base import ALL, is_all, __MSG__, __REPR__
from . import backend as F
from .backend import Tensor
from .frame import FrameRef, merge_frames
from .function.message import BundledMessageFunction
from .function.reducer import BundledReduceFunction
from .graph_index import GraphIndex, create_graph_index
from . import scheduler
from . import utils
__all__ = ['DGLGraph']
class DGLGraph(object):
"""Base graph class specialized for neural networks on graphs.
TODO(minjie): document of batching semantics
......@@ -26,47 +28,478 @@ class DGLGraph(DiGraph):
----------
graph_data : graph data
Data to initialize graph. Same as networkx's semantics.
node_frame : dgl.frame.Frame
node_frame : FrameRef
Node feature storage.
edge_frame : dgl.frame.Frame
edge_frame : FrameRef
Edge feature storage.
attr : keyword arguments, optional
Attributes to add to graph as key=value pairs.
"""
def __init__(self,
graph_data=None,
node_frame=None,
edge_frame=None,
**attr):
# TODO(minjie): maintaining node/edge list is costly when graph is large.
self._edge_list = []
nx_init(self,
self._add_node_callback,
self._add_edge_callback,
self._del_node_callback,
self._del_edge_callback,
graph_data,
**attr)
# cached graph and storage
self._cached_graph = None
edge_frame=None):
# graph
self._graph = create_graph_index(graph_data)
# frame
self._node_frame = node_frame if node_frame is not None else FrameRef()
self._edge_frame = edge_frame if edge_frame is not None else FrameRef()
# other class members
self._msg_graph = None
# msg graph & frame
self._msg_graph = create_graph_index()
self._msg_frame = FrameRef()
self._message_func = (None, None)
self._reduce_func = (None, None)
self._edge_func = (None, None)
self._apply_node_func = (None, None)
self._apply_edge_func = (None, None)
self.reset_messages()
# registered functions
self._message_func = None
self._reduce_func = None
self._edge_func = None
self._apply_node_func = None
self._apply_edge_func = None
def add_nodes(self, num, reprs=None):
"""Add nodes.
Parameters
----------
num : int
Number of nodes to be added.
reprs : dict
Optional node representations.
"""
self._graph.add_nodes(num)
self._msg_graph.add_nodes(num)
#TODO(minjie): change frames
assert reprs is None
def add_edge(self, u, v, reprs=None):
"""Add one edge.
Parameters
----------
u : int
The src node.
v : int
The dst node.
reprs : dict
Optional edge representation.
"""
self._graph.add_edge(u, v)
#TODO(minjie): change frames
assert reprs is None
def add_edges(self, u, v, reprs=None):
"""Add many edges.
Parameters
----------
u : list, tensor
The src nodes.
v : list, tensor
The dst nodes.
reprs : dict
Optional edge representations.
"""
u = utils.toindex(u)
v = utils.toindex(v)
self._graph.add_edges(u, v)
#TODO(minjie): change frames
assert reprs is None
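# Usage sketch for the mutation APIs above (illustrative, assuming this
# package is importable as `dgl` and exports DGLGraph; structure only,
# since `reprs` must currently be None):
import dgl

g = dgl.DGLGraph()
g.add_nodes(4)                 # creates nodes 0..3
g.add_edge(0, 1)               # one edge 0->1
g.add_edges([1, 2], [2, 3])    # vectorized: edges 1->2 and 2->3
assert g.number_of_nodes() == 4 and g.number_of_edges() == 3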
def clear(self):
"""Clear the graph and its storage."""
self._graph.clear()
self._node_frame.clear()
self._edge_frame.clear()
self._msg_graph.clear()
self._msg_frame.clear()
def reset_messages(self):
"""Clear all messages."""
self._msg_graph.clear()
self._msg_frame.clear()
self._msg_graph.add_nodes(self.number_of_nodes())
def number_of_nodes(self):
"""Return the number of nodes.
Returns
-------
int
The number of nodes
"""
return self._graph.number_of_nodes()
def __len__(self):
"""Return the number of nodes."""
return self.number_of_nodes()
def number_of_edges(self):
"""Return the number of edges.
Returns
-------
int
The number of edges
"""
return self._graph.number_of_edges()
def has_node(self, vid):
"""Return true if the node exists.
Parameters
----------
vid : int
The node.
Returns
-------
bool
True if the node exists
"""
return self._graph.has_node(vid)
def __contains__(self, vid):
"""Same as has_node."""
return self.has_node(vid)
def has_nodes(self, vids):
"""Return true if the nodes exist.
Parameters
----------
vids : list, tensor
The nodes
Returns
-------
tensor
0-1 array indicating existence
"""
vids = utils.toindex(vids)
rst = self._graph.has_nodes(vids)
return rst.tousertensor()
def has_edge(self, u, v):
"""Return true if the edge exists.
Parameters
----------
u : int
The src node.
v : int
The dst node.
Returns
-------
bool
True if the edge exists
"""
return self._graph.has_edge(u, v)
def has_edges(self, u, v):
"""Return true if the edge exists.
Parameters
----------
u : list, tensor
The src nodes.
v : list, tensor
The dst nodes.
Returns
-------
tensor
0-1 array indicating existence
"""
u = utils.toindex(u)
v = utils.toindex(v)
rst = self._graph.has_edges(u, v)
return rst.tousertensor()
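# Usage sketch for the existence queries above: scalar forms return a bool,
# vectorized forms return a 0-1 tensor (continuing the small graph g from
# the earlier sketch):
assert g.has_node(0) and 0 in g
assert g.has_edge(0, 1)
exist = g.has_edges([0, 1, 3], [1, 2, 0])   # e.g. a 0-1 tensor [1, 1, 0]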
def predecessors(self, v, radius=1):
"""Return the predecessors of the node.
Parameters
----------
v : int
The node.
radius : int, optional
The radius of the neighborhood.
Returns
-------
tensor
Array of predecessors
"""
return self._graph.predecessors(v).tousertensor()
def successors(self, v, radius=1):
"""Return the successors of the node.
Parameters
----------
v : int
The node.
radius : int, optional
The radius of the neighborhood.
Returns
-------
tensor
Array of successors
"""
return self._graph.successors(v).tousertensor()
def edge_id(self, u, v):
"""Return the id of the edge.
Parameters
----------
u : int
The src node.
v : int
The dst node.
Returns
-------
int
The edge id.
"""
return self._graph.edge_id(u, v)
def edge_ids(self, u, v):
"""Return the edge ids.
Parameters
----------
u : list, tensor
The src nodes.
v : list, tensor
The dst nodes.
Returns
-------
tensor
The edge id array.
"""
u = utils.toindex(u)
v = utils.toindex(v)
rst = self._graph.edge_ids(u, v)
return rst.tousertensor()
def in_edges(self, v):
"""Return the in edges of the node(s).
Parameters
----------
v : int, list, tensor
The node(s).
Returns
-------
tensor
The src nodes.
tensor
The dst nodes.
tensor
The edge ids.
"""
v = utils.toindex(v)
src, dst, eid = self._graph.in_edges(v)
return src.tousertensor(), dst.tousertensor(), eid.tousertensor()
def out_edges(self, v):
"""Return the out edges of the node(s).
Parameters
----------
v : int, list, tensor
The node(s).
Returns
-------
tensor
The src nodes.
tensor
The dst nodes.
tensor
The edge ids.
"""
v = utils.toindex(v)
src, dst, eid = self._graph.out_edges(v)
return src.tousertensor(), dst.tousertensor(), eid.tousertensor()
def edges(self, sorted=False):
"""Return all the edges.
Parameters
----------
sorted : bool
True if the returned edges are sorted by their src and dst ids.
Returns
-------
tensor
The src nodes.
tensor
The dst nodes.
tensor
The edge ids.
"""
src, dst, eid = self._graph.edges(sorted)
return src.tousertensor(), dst.tousertensor(), eid.tousertensor()
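# Usage sketch: edges() returns parallel src/dst/eid tensors, so zipping
# them recovers (u, v, eid) triples (PyTorch backend assumed for tolist()):
src, dst, eid = g.edges()
for u, v, e in zip(src.tolist(), dst.tolist(), eid.tolist()):
    print('%d -> %d (eid %d)' % (u, v, e))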
def in_degree(self, v):
"""Return the in degree of the node.
Parameters
----------
v : int
The node.
Returns
-------
int
The in degree.
"""
return self._graph.in_degree(v)
def in_degrees(self, v):
"""Return the in degrees of the nodes.
Parameters
----------
v : list, tensor
The nodes.
Returns
-------
tensor
The in degree array.
"""
return self._graph.in_degrees(v).tousertensor()
def out_degree(self, v):
"""Return the out degree of the node.
Parameters
----------
v : int
The node.
Returns
-------
int
The out degree.
"""
return self._graph.out_degree(v)
def out_degrees(self, v):
"""Return the out degrees of the nodes.
Parameters
----------
v : list, tensor
The nodes.
Returns
-------
tensor
The out degree array.
"""
return self._graph.out_degrees(v).tousertensor()
def to_networkx(self, node_attrs=None, edge_attrs=None):
"""Convert to networkx graph.
The edge id will be saved as the 'id' edge attribute.
Parameters
----------
node_attrs : iterable of str, optional
The node attributes to be copied.
edge_attrs : iterable of str, optional
The edge attributes to be copied.
Returns
-------
networkx.DiGraph
The nx graph
"""
nx_graph = self._graph.to_networkx()
#TODO: attributes
return nx_graph
def from_networkx(self, nx_graph, node_attrs=None, edge_attrs=None):
"""Convert from networkx graph.
If the 'id' edge attribute exists, the edges will be added following
the edge id order. Otherwise, the order is undefined.
Parameters
----------
nx_graph : networkx.DiGraph
The nx graph
node_attrs : iterable of str, optional
The node attributes to be copied.
edge_attrs : iterable of str, optional
The edge attributes to be copied.
"""
self.clear()
self._graph.from_networkx(nx_graph)
self._msg_graph.add_nodes(self._graph.number_of_nodes())
# copy attributes
def _batcher(lst):
if isinstance(lst[0], Tensor):
return F.pack([F.unsqueeze(x, 0) for x in lst])
else:
return F.tensor(lst)
if node_attrs is not None:
attr_dict = {attr : [] for attr in node_attrs}
for nid in range(self.number_of_nodes()):
for attr in node_attrs:
attr_dict[attr].append(nx_graph.nodes[nid][attr])
for attr in node_attrs:
self._node_frame[attr] = _batcher(attr_dict[attr])
if edge_attrs is not None:
attr_dict = {attr : [] for attr in edge_attrs}
src, dst, _ = self._graph.edges()
for u, v in zip(src.tolist(), dst.tolist()):
for attr in edge_attrs:
attr_dict[attr].append(nx_graph.edges[u, v][attr])
for attr in edge_attrs:
self._edge_frame[attr] = _batcher(attr_dict[attr])
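# Usage sketch for the conversion above: copy a named node feature from a
# networkx graph, assuming the PyTorch backend so features are torch tensors:
import networkx as nx
import torch

nxg = nx.path_graph(3, create_using=nx.DiGraph())
for n in nxg.nodes:
    nxg.nodes[n]['h'] = torch.ones(2)
g2 = dgl.DGLGraph()
g2.from_networkx(nxg, node_attrs=['h'])   # node frame gains a (3, 2) 'h' column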
def from_scipy_sparse_matrix(self, a):
""" Convert from scipy sparse matrix.
Parameters
----------
a : scipy sparse matrix
The graph's adjacency matrix
"""
self.clear()
self._graph.from_scipy_sparse_matrix(a)
self._msg_graph.add_nodes(self._graph.number_of_nodes())
def node_attr_schemes(self):
"""Return the node attribute schemes.
Returns
-------
iterable
The set of attribute names
"""
return self._node_frame.schemes
def edge_attr_schemes(self):
"""Return the edge attribute schemes.
Returns
-------
iterable
The set of attribute names
"""
return self._edge_frame.schemes
def set_n_repr(self, hu, u=ALL):
def set_n_repr(self, hu, u=ALL, inplace=False):
"""Set node(s) representation.
To set multiple node representations at once, pass `u` with a tensor or
......@@ -104,9 +537,9 @@ class DGLGraph(DiGraph):
self._node_frame[__REPR__] = hu
else:
if utils.is_dict_like(hu):
self._node_frame[u] = hu
self._node_frame.update_rows(u, hu, inplace=inplace)
else:
self._node_frame[u] = {__REPR__ : hu}
self._node_frame.update_rows(u, {__REPR__ : hu}, inplace=inplace)
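# Usage sketch for set_n_repr/get_n_repr: a dict sets named columns and a
# bare tensor goes under the anonymous __REPR__ column; row i of a column is
# node i's feature (PyTorch backend assumed):
import torch

g.set_n_repr({'h': torch.zeros(g.number_of_nodes(), 5)})    # all nodes
g.set_n_repr({'h': torch.ones(2, 5)}, u=[0, 2])             # rows 0 and 2
h = g.get_n_repr()['h']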
def get_n_repr(self, u=ALL):
"""Get node(s) representation.
......@@ -114,8 +547,15 @@ class DGLGraph(DiGraph):
Parameters
----------
u : node, container or tensor
The node(s).
The node(s).
Returns
-------
dict
Representation dict
"""
if len(self.node_attr_schemes()) == 0:
return dict()
if is_all(u):
if len(self._node_frame) == 1 and __REPR__ in self._node_frame:
return self._node_frame[__REPR__]
......@@ -134,7 +574,12 @@ class DGLGraph(DiGraph):
Parameters
----------
key : str
The attribute name.
The attribute name.
Returns
-------
Tensor
The popped representation
"""
return self._node_frame.pop(key)
......@@ -163,29 +608,12 @@ class DGLGraph(DiGraph):
v_is_all = is_all(v)
assert u_is_all == v_is_all
if u_is_all:
num_edges = self.cached_graph.num_edges()
self.set_e_repr_by_id(h_uv, eid=ALL)
else:
u = utils.toindex(u)
v = utils.toindex(v)
num_edges = max(len(u), len(v))
if utils.is_dict_like(h_uv):
for key, val in h_uv.items():
assert F.shape(val)[0] == num_edges
else:
assert F.shape(h_uv)[0] == num_edges
# set
if u_is_all:
if utils.is_dict_like(h_uv):
for key, val in h_uv.items():
self._edge_frame[key] = val
else:
self._edge_frame[__REPR__] = h_uv
else:
eid = self.cached_graph.get_edge_id(u, v)
if utils.is_dict_like(h_uv):
self._edge_frame[eid] = h_uv
else:
self._edge_frame[eid] = {__REPR__ : h_uv}
eid = self._graph.edge_ids(u, v)
self.set_e_repr_by_id(h_uv, eid=eid)
def set_e_repr_by_id(self, h_uv, eid=ALL):
"""Set edge(s) representation by edge id.
......@@ -199,7 +627,7 @@ class DGLGraph(DiGraph):
"""
# sanity check
if is_all(eid):
num_edges = self.cached_graph.num_edges()
num_edges = self.number_of_edges()
else:
eid = utils.toindex(eid)
num_edges = len(eid)
......@@ -230,23 +658,24 @@ class DGLGraph(DiGraph):
The source node(s).
v : node, container or tensor
The destination node(s).
Returns
-------
dict
Representation dict
"""
u_is_all = is_all(u)
v_is_all = is_all(v)
assert u_is_all == v_is_all
if len(self.edge_attr_schemes()) == 0:
return dict()
if u_is_all:
if len(self._edge_frame) == 1 and __REPR__ in self._edge_frame:
return self._edge_frame[__REPR__]
else:
return dict(self._edge_frame)
return self.get_e_repr_by_id(eid=ALL)
else:
u = utils.toindex(u)
v = utils.toindex(v)
eid = self.cached_graph.get_edge_id(u, v)
if len(self._edge_frame) == 1 and __REPR__ in self._edge_frame:
return self._edge_frame.select_rows(eid)[__REPR__]
else:
return self._edge_frame.select_rows(eid)
eid = self._graph.edge_ids(u, v)
return self.get_e_repr_by_id(eid=eid)
def pop_e_repr(self, key=__REPR__):
"""Get and remove the specified edge repr.
......@@ -255,6 +684,11 @@ class DGLGraph(DiGraph):
----------
key : str
The attribute name.
Returns
-------
Tensor
The popped representation
"""
return self._edge_frame.pop(key)
......@@ -265,7 +699,14 @@ class DGLGraph(DiGraph):
----------
eid : int, container or tensor
The edge id(s).
Returns
-------
dict
Representation dict
"""
if len(self.edge_attr_schemes()) == 0:
return dict()
if is_all(eid):
if len(self._edge_frame) == 1 and __REPR__ in self._edge_frame:
return self._edge_frame[__REPR__]
......@@ -278,77 +719,57 @@ class DGLGraph(DiGraph):
else:
return self._edge_frame.select_rows(eid)
def register_edge_func(self,
edge_func,
batchable=False):
def register_edge_func(self, edge_func):
"""Register global edge update function.
Parameters
----------
edge_func : callable
Message function on the edge.
batchable : bool
Whether the provided message function allows batch computing.
"""
self._edge_func = (edge_func, batchable)
self._edge_func = edge_func
def register_message_func(self,
message_func,
batchable=False):
def register_message_func(self, message_func):
"""Register global message function.
Parameters
----------
message_func : callable
Message function on the edge.
batchable : bool
Whether the provided message function allows batch computing.
"""
self._message_func = (message_func, batchable)
self._message_func = message_func
def register_reduce_func(self,
reduce_func,
batchable=False):
def register_reduce_func(self, reduce_func):
"""Register global message reduce function.
Parameters
----------
reduce_func : str or callable
Reduce function on incoming edges.
batchable : bool
Whether the provided reduce function allows batch computing.
"""
self._reduce_func = (reduce_func, batchable)
self._reduce_func = reduce_func
def register_apply_node_func(self,
apply_node_func,
batchable=False):
def register_apply_node_func(self, apply_node_func):
"""Register global node apply function.
Parameters
----------
apply_node_func : callable
Apply function on the node.
batchable : bool
Whether the provided function allows batch computing.
"""
self._apply_node_func = (apply_node_func, batchable)
self._apply_node_func = apply_node_func
def register_apply_edge_func(self,
apply_edge_func,
batchable=False):
def register_apply_edge_func(self, apply_edge_func):
"""Register global edge apply function.
Parameters
----------
apply_edge_func : callable
Apply function on the edge.
batchable : bool
Whether the provided function allows batch computing.
"""
self._apply_edge_func = (apply_edge_func, batchable)
self._apply_edge_func = apply_edge_func
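# Usage sketch: register defaults once, then later message-passing calls can
# pass "default". Signatures follow send/recv below: message_func(src, edge)
# and reduce_func(node, msgs), with msgs batched to shape (B, Deg, ...):
g.register_message_func(lambda src, edge: {'m': src['h']})
g.register_reduce_func(lambda node, msgs: {'h': msgs['m'].sum(1)})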
def apply_nodes(self, v, apply_node_func="default", batchable=False):
def apply_nodes(self, v, apply_node_func="default"):
"""Apply the function on node representations.
Parameters
......@@ -357,26 +778,16 @@ class DGLGraph(DiGraph):
The node id(s).
apply_node_func : callable
The apply node function.
batchable : bool
Whether the provided function allows batch computing.
"""
if apply_node_func == "default":
apply_node_func, batchable = self._apply_node_func
apply_node_func = self._apply_node_func
if not apply_node_func:
# Skip none function call.
return
if batchable:
new_repr = apply_node_func(self.get_n_repr(v))
self.set_n_repr(new_repr, v)
else:
if is_all(v):
v = self.nodes()
v = utils.toindex(v)
for vv in utils.node_iter(v):
ret = apply_node_func(_get_repr(self.nodes[vv]))
_set_repr(self.nodes[vv], ret)
new_repr = apply_node_func(self.get_n_repr(v))
self.set_n_repr(new_repr, v)
def apply_edges(self, u, v, apply_edge_func="default", batchable=False):
def apply_edges(self, u, v, apply_edge_func="default"):
"""Apply the function on edge representations.
Parameters
......@@ -387,27 +798,16 @@ class DGLGraph(DiGraph):
The dst node id(s).
apply_edge_func : callable
The apply edge function.
batchable : bool
Whether the provided function allows batch computing.
"""
if apply_edge_func == "default":
apply_edge_func, batchable = self._apply_edge_func
apply_edge_func = self._apply_edge_func
if not apply_edge_func:
# Skip none function call.
return
if batchable:
new_repr = apply_edge_func(self.get_e_repr(u, v))
self.set_e_repr(new_repr, u, v)
else:
if is_all(u) == is_all(v):
u, v = zip(*self.edges)
u = utils.toindex(u)
v = utils.toindex(v)
for uu, vv in utils.edge_iter(u, v):
ret = apply_edge_func(_get_repr(self.edges[uu, vv]))
_set_repr(self.edges[uu, vv], ret)
new_repr = apply_edge_func(self.get_e_repr(u, v))
self.set_e_repr(new_repr, u, v)
def send(self, u, v, message_func="default", batchable=False):
def send(self, u, v, message_func="default"):
"""Trigger the message function on edge u->v
The message function should be compatible with following signature:
......@@ -428,32 +828,18 @@ class DGLGraph(DiGraph):
The destination node(s).
message_func : callable
The message function.
batchable : bool
Whether the function allows batched computation.
"""
if message_func == "default":
message_func, batchable = self._message_func
message_func = self._message_func
assert message_func is not None
if batchable:
self._batch_send(u, v, message_func)
else:
self._nonbatch_send(u, v, message_func)
def _nonbatch_send(self, u, v, message_func):
if is_all(u) and is_all(v):
u, v = self.cached_graph.edges()
else:
u = utils.toindex(u)
v = utils.toindex(v)
for uu, vv in utils.edge_iter(u, v):
ret = message_func(_get_repr(self.nodes[uu]),
_get_repr(self.edges[uu, vv]))
self.edges[uu, vv][__MSG__] = ret
if isinstance(message_func, (tuple, list)):
message_func = BundledMessageFunction(message_func)
self._batch_send(u, v, message_func)
def _batch_send(self, u, v, message_func):
if is_all(u) and is_all(v):
u, v = self.cached_graph.edges()
self.msg_graph.add_edges(u, v)
u, v, _ = self._graph.edges()
self._msg_graph.add_edges(u, v) # TODO(minjie): can be optimized
# call UDF
src_reprs = self.get_n_repr(u)
edge_reprs = self.get_e_repr()
......@@ -462,18 +848,17 @@ class DGLGraph(DiGraph):
u = utils.toindex(u)
v = utils.toindex(v)
u, v = utils.edge_broadcasting(u, v)
eid = self.cached_graph.get_edge_id(u, v)
self.msg_graph.add_edges(u, v)
self._msg_graph.add_edges(u, v)
# call UDF
src_reprs = self.get_n_repr(u)
edge_reprs = self.get_e_repr_by_id(eid)
edge_reprs = self.get_e_repr(u, v)
msgs = message_func(src_reprs, edge_reprs)
if utils.is_dict_like(msgs):
self._msg_frame.append(msgs)
else:
self._msg_frame.append({__MSG__ : msgs})
def update_edge(self, u, v, edge_func="default", batchable=False):
def update_edge(self, u=ALL, v=ALL, edge_func="default"):
"""Update representation on edge u->v
The edge function should be compatible with following signature:
......@@ -492,32 +877,15 @@ class DGLGraph(DiGraph):
The destination node(s).
edge_func : callable
The update function.
batchable : bool
Whether the function allows batched computation.
"""
if edge_func == "default":
edge_func, batchable = self._edge_func
edge_func = self._edge_func
assert edge_func is not None
if batchable:
self._batch_update_edge(u, v, edge_func)
else:
self._nonbatch_update_edge(u, v, edge_func)
def _nonbatch_update_edge(self, u, v, edge_func):
if is_all(u) and is_all(v):
u, v = self.cached_graph.edges()
else:
u = utils.toindex(u)
v = utils.toindex(v)
for uu, vv in utils.edge_iter(u, v):
ret = edge_func(_get_repr(self.nodes[uu]),
_get_repr(self.nodes[vv]),
_get_repr(self.edges[uu, vv]))
_set_repr(self.edges[uu, vv], ret)
self._batch_update_edge(u, v, edge_func)
def _batch_update_edge(self, u, v, edge_func):
if is_all(u) and is_all(v):
u, v = self.cached_graph.edges()
u, v, _ = self._graph.edges()
# call the UDF
src_reprs = self.get_n_repr(u)
dst_reprs = self.get_n_repr(v)
......@@ -528,7 +896,7 @@ class DGLGraph(DiGraph):
u = utils.toindex(u)
v = utils.toindex(v)
u, v = utils.edge_broadcasting(u, v)
eid = self.cached_graph.get_edge_id(u, v)
eid = self._graph.edge_ids(u, v)
# call the UDF
src_reprs = self.get_n_repr(u)
dst_reprs = self.get_n_repr(v)
......@@ -539,8 +907,7 @@ class DGLGraph(DiGraph):
def recv(self,
u,
reduce_func="default",
apply_node_func="default",
batchable=False):
apply_node_func="default"):
"""Receive and reduce in-coming messages and update representation on node u.
It computes the new node state using the messages sent from the predecessors
......@@ -570,31 +937,15 @@ class DGLGraph(DiGraph):
The reduce function.
apply_node_func : callable, optional
The update function.
batchable : bool, optional
Whether the reduce and update function allows batched computation.
"""
if reduce_func == "default":
reduce_func, batchable = self._reduce_func
reduce_func = self._reduce_func
assert reduce_func is not None
if batchable:
self._batch_recv(u, reduce_func)
else:
self._nonbatch_recv(u, reduce_func)
if isinstance(reduce_func, (list, tuple)):
reduce_func = BundledReduceFunction(reduce_func)
self._batch_recv(u, reduce_func)
# optional apply nodes
self.apply_nodes(u, apply_node_func, batchable)
def _nonbatch_recv(self, u, reduce_func):
if is_all(u):
u = list(range(0, self.number_of_nodes()))
else:
u = utils.toindex(u)
for i, uu in enumerate(utils.node_iter(u)):
# reduce phase
msgs_batch = [self.edges[vv, uu].pop(__MSG__)
for vv in self.pred[uu] if __MSG__ in self.edges[vv, uu]]
if len(msgs_batch) != 0:
new_repr = reduce_func(_get_repr(self.nodes[uu]), msgs_batch)
_set_repr(self.nodes[uu], new_repr)
self.apply_nodes(u, apply_node_func)
def _batch_recv(self, v, reduce_func):
if self._msg_frame.num_rows == 0:
......@@ -610,7 +961,7 @@ class DGLGraph(DiGraph):
v = utils.toindex(v)
# degree bucketing
degrees, v_buckets = scheduler.degree_bucketing(self.msg_graph, v)
degrees, v_buckets = scheduler.degree_bucketing(self._msg_graph, v)
if degrees == [0]:
# no message has been sent to the specified node
return
......@@ -625,8 +976,7 @@ class DGLGraph(DiGraph):
continue
bkt_len = len(v_bkt)
dst_reprs = self.get_n_repr(v_bkt)
uu, vv, _ = self.msg_graph.in_edges(v_bkt)
in_msg_ids = self.msg_graph.get_edge_id(uu, vv)
uu, vv, in_msg_ids = self._msg_graph.in_edges(v_bkt)
in_msgs = self._msg_frame.select_rows(in_msg_ids)
# Reshape the column tensor to (B, Deg, ...).
def _reshape_fn(msg):
......@@ -638,11 +988,11 @@ class DGLGraph(DiGraph):
else:
reshaped_in_msgs = utils.LazyDict(
lambda key: _reshape_fn(in_msgs[key]), self._msg_frame.schemes)
reordered_v.append(v_bkt.totensor())
reordered_v.append(v_bkt.tousertensor())
new_reprs.append(reduce_func(dst_reprs, reshaped_in_msgs))
# TODO: clear partial messages
self.clear_messages()
self.reset_messages()
# Pack all reducer results together
reordered_v = F.pack(reordered_v)
......@@ -667,8 +1017,7 @@ class DGLGraph(DiGraph):
u, v,
message_func="default",
reduce_func="default",
apply_node_func="default",
batchable=False):
apply_node_func="default"):
"""Trigger the message function on u->v and update v.
Parameters
......@@ -683,8 +1032,6 @@ class DGLGraph(DiGraph):
The reduce function.
apply_node_func : callable, optional
The update function.
batchable : bool
Whether the reduce and update function allows batched computation.
"""
u = utils.toindex(u)
v = utils.toindex(v)
......@@ -692,36 +1039,30 @@ class DGLGraph(DiGraph):
# no edges to be triggered
assert len(v) == 0
return
unique_v = utils.toindex(F.unique(v.totensor()))
unique_v = utils.toindex(F.unique(v.tousertensor()))
# TODO(minjie): better way to figure out `batchable` flag
if message_func == "default":
message_func, batchable = self._message_func
message_func = self._message_func
if reduce_func == "default":
reduce_func, _ = self._reduce_func
reduce_func = self._reduce_func
assert message_func is not None
assert reduce_func is not None
if batchable:
executor = scheduler.get_executor(
'send_and_recv', self, src=u, dst=v,
message_func=message_func, reduce_func=reduce_func)
else:
executor = None
executor = scheduler.get_executor(
'send_and_recv', self, src=u, dst=v,
message_func=message_func, reduce_func=reduce_func)
if executor:
executor.run()
else:
self.send(u, v, message_func, batchable=batchable)
self.recv(unique_v, reduce_func, None, batchable=batchable)
self.apply_nodes(unique_v, apply_node_func, batchable=batchable)
self.send(u, v, message_func)
self.recv(unique_v, reduce_func, None)
self.apply_nodes(unique_v, apply_node_func)
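# Usage sketch: trigger the registered defaults on two edges; messages flow
# along 0->1 and 2->3, then the unique destinations {1, 3} are reduced (the
# degree-bucketing path is used when no specialized executor applies):
g.send_and_recv([0, 2], [1, 3])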
def pull(self,
v,
message_func="default",
reduce_func="default",
apply_node_func="default",
batchable=False):
apply_node_func="default"):
"""Pull messages from the node's predecessors and then update it.
Parameters
......@@ -734,24 +1075,20 @@ class DGLGraph(DiGraph):
The reduce function.
apply_node_func : callable, optional
The update function.
batchable : bool
Whether the reduce and update function allows batched computation.
"""
v = utils.toindex(v)
if len(v) == 0:
return
uu, vv, _ = self.cached_graph.in_edges(v)
self.send_and_recv(uu, vv, message_func, reduce_func,
apply_node_func=None, batchable=batchable)
unique_v = F.unique(v.totensor())
self.apply_nodes(unique_v, apply_node_func, batchable=batchable)
uu, vv, _ = self._graph.in_edges(v)
self.send_and_recv(uu, vv, message_func, reduce_func, apply_node_func=None)
unique_v = F.unique(v.tousertensor())
self.apply_nodes(unique_v, apply_node_func)
def push(self,
u,
message_func="default",
reduce_func="default",
apply_node_func="default",
batchable=False):
apply_node_func="default"):
"""Send message from the node to its successors and update them.
Parameters
......@@ -764,21 +1101,18 @@ class DGLGraph(DiGraph):
The reduce function.
apply_node_func : callable
The update function.
batchable : bool
Whether the reduce and update function allows batched computation.
"""
u = utils.toindex(u)
if len(u) == 0:
return
uu, vv, _ = self.cached_graph.out_edges(u)
uu, vv, _ = self._graph.out_edges(u)
self.send_and_recv(uu, vv, message_func,
reduce_func, apply_node_func, batchable=batchable)
reduce_func, apply_node_func)
def update_all(self,
message_func="default",
reduce_func="default",
apply_node_func="default",
batchable=False):
apply_node_func="default"):
"""Send messages through all the edges and update all nodes.
Parameters
......@@ -789,76 +1123,61 @@ class DGLGraph(DiGraph):
The reduce function.
apply_node_func : callable, optional
The update function.
batchable : bool
Whether the reduce and update function allows batched computation.
"""
if message_func == "default":
message_func, batchable = self._message_func
message_func = self._message_func
if reduce_func == "default":
reduce_func, _ = self._reduce_func
reduce_func = self._reduce_func
assert message_func is not None
assert reduce_func is not None
if batchable:
executor = scheduler.get_executor(
"update_all", self, message_func=message_func, reduce_func=reduce_func)
else:
executor = None
executor = scheduler.get_executor(
"update_all", self, message_func=message_func, reduce_func=reduce_func)
if executor:
executor.run()
else:
self.send(ALL, ALL, message_func, batchable=batchable)
self.recv(ALL, reduce_func, None, batchable=batchable)
self.apply_nodes(ALL, apply_node_func, batchable=batchable)
self.send(ALL, ALL, message_func)
self.recv(ALL, reduce_func, None)
self.apply_nodes(ALL, apply_node_func)
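# Usage sketch: full-graph propagation. With builtin functions like these
# the executor can specialize the whole step into one sparse matmul (see the
# SPMV scheduler at the end of this diff); assumes dgl.function exports
# copy_src/sum helpers at this commit:
import dgl.function as fn

g.update_all(fn.copy_src(src='h', out='m'), fn.sum(msg='m', out='h'))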
def propagate(self,
iterator='bfs',
traverser='topo',
message_func="default",
reduce_func="default",
apply_node_func="default",
batchable=False,
**kwargs):
"""Propagate messages and update nodes using iterator.
"""Propagate messages and update nodes using graph traversal.
A convenient function for passing messages and updating
nodes according to the iterator. The iterator can be
any of the pre-defined iterators ('bfs', 'dfs', 'pre-order',
'mid-order', 'post-order'). The computation will be unrolled
in the backend efficiently. User can also provide custom
iterator that generates the edges and nodes.
nodes according to the traverser. The traverser can be
any of the pre-defined traverser (e.g. 'topo'). User can also provide custom
traverser that generates the edges and nodes.
Parameters
----------
traverser : str or generator of edges.
The traverser of the graph.
message_func : str or callable
The message function.
reduce_func : str or callable
The reduce function.
apply_node_func : str or callable
The update function.
batchable : bool
Whether the reduce and update function allows batched computation.
iterator : str or generator of steps.
The iterator of the graph.
kwargs : keyword arguments, optional
Arguments for pre-defined iterators.
"""
if isinstance(iterator, str):
if isinstance(traverser, str):
# TODO Call pre-defined routine to unroll the computation.
raise RuntimeError('Not implemented.')
else:
# NOTE: the iteration can return multiple edges at each step.
for u, v in iterator:
for u, v in traverser:
self.send_and_recv(u, v,
message_func, reduce_func, apply_node_func, batchable)
message_func, reduce_func, apply_node_func)
def subgraph(self, nodes):
"""Generate the subgraph among the given nodes.
The generated graph contains only the graph structure. The node/edge
features are not shared implicitly. Use `copy_from` to get node/edge
features from parent graph.
Parameters
----------
nodes : list, or iterable
......@@ -869,7 +1188,9 @@ class DGLGraph(DiGraph):
G : DGLSubGraph
The subgraph.
"""
return dgl.DGLSubGraph(self, nodes)
induced_nodes = utils.toindex(nodes)
gi, induced_edges = self._graph.node_subgraph(induced_nodes)
return dgl.DGLSubGraph(self, induced_nodes, induced_edges, gi)
def merge(self, subgraphs, reduce_func='sum'):
"""Merge subgraph features back to this parent graph.
......@@ -913,91 +1234,109 @@ class DGLGraph(DiGraph):
self._edge_frame.num_rows,
reduce_func)
def draw(self):
"""Plot the graph using dot."""
from networkx.drawing.nx_agraph import graphviz_layout
pos = graphviz_layout(self, prog='dot')
nx.draw(self, pos, with_labels=True)
def adjacency_matrix(self, ctx=None):
"""Return the adjacency matrix representation of this graph.
Parameters
----------
ctx : optional
The context of returned adjacency matrix.
Returns
-------
sparse_tensor
The adjacency matrix.
"""
return self._graph.adjacency_matrix().get(ctx)
def incidence_matrix(self, oriented=False, ctx=None):
"""Return the incidence matrix representation of this graph.
Parameters
----------
oriented : bool, optional
Whether the returned incidence matrix is oriented.
ctx : optional
The context of returned incidence matrix.
Returns
-------
sparse_tensor
The incidence matrix.
"""
return self._graph.incidence_matrix(oriented).get(ctx)
def line_graph(self, backtracking=True, shared=False):
"""Return the line graph of this graph.
Parameters
----------
backtracking : bool, optional
Whether the returned line graph is backtracking.
shared : bool, optional
Whether the returned line graph shares representations with `self`.
Returns
-------
DGLGraph
The line graph of this graph.
"""
graph_data = self._graph.line_graph(backtracking)
node_frame = self._edge_frame if shared else None
return DGLGraph(graph_data, node_frame)
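# Usage sketch: with shared=True the edge frame of this graph becomes the
# node frame of the line graph, which is what the line_graph generator at
# the top of this diff builds on:
lg = g.line_graph(backtracking=False, shared=True)
assert lg.number_of_nodes() == g.number_of_edges()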
@property
def cached_graph(self):
# TODO: dirty flag when mutated
if self._cached_graph is None:
self._cached_graph = create_cached_graph(self)
return self._cached_graph
@property
def msg_graph(self):
# TODO: dirty flag when mutated
if self._msg_graph is None:
self._msg_graph = CachedGraph()
self._msg_graph.add_nodes(self.number_of_nodes())
return self._msg_graph
def clear_messages(self):
if self._msg_graph is not None:
self._msg_graph = CachedGraph()
self._msg_graph.add_nodes(self.number_of_nodes())
self._msg_frame.clear()
@property
def edge_list(self):
"""Return edges in the addition order."""
return self._edge_list
def get_edge_id(self, u, v):
"""Return the continuous edge id(s) assigned.
Parameters
----------
u : node, container or tensor
The source node(s).
v : node, container or tensor
The destination node(s).
Returns
-------
eid : tensor
The tensor contains edge id(s).
"""
u = utils.toindex(u)
v = utils.toindex(v)
return self.cached_graph.get_edge_id(u, v)
def _add_node_callback(self, node):
#print('New node:', node)
self._cached_graph = None
def _del_node_callback(self, node):
#print('Del node:', node)
raise RuntimeError('Node removal is not supported currently.')
node = utils.convert_to_id_tensor(node)
self._node_frame.delete_rows(node)
self._cached_graph = None
def _add_edge_callback(self, u, v):
#print('New edge:', u, v)
self._edge_list.append((u, v))
self._cached_graph = None
def _del_edge_callback(self, u, v):
#print('Del edge:', u, v)
raise RuntimeError('Edge removal is not supported currently.')
u = utils.convert_to_id_tensor(u)
v = utils.convert_to_id_tensor(v)
eid = self.get_edge_id(u, v)
self._edge_frame.delete_rows(eid)
self._cached_graph = None
def _get_repr(attr_dict):
if len(attr_dict) == 1 and __REPR__ in attr_dict:
return attr_dict[__REPR__]
else:
return attr_dict
def _set_repr(attr_dict, attr):
if utils.is_dict_like(attr):
attr_dict.update(attr)
else:
attr_dict[__REPR__] = attr
def filter_nodes(self, predicate, nodes=ALL):
"""Return a tensor of node IDs that satisfy the given predicate.
Parameters
----------
predicate : callable
The predicate should take in a dict of tensors whose values
are concatenation of node representations by node ID (same as
get_n_repr()), and return a boolean tensor with N elements
indicating which nodes satisfy the predicate.
nodes : container or tensor
The nodes to filter on.
Returns
-------
tensor
The filtered nodes.
"""
n_repr = self.get_n_repr(nodes)
n_mask = predicate(n_repr)
if is_all(nodes):
return F.nonzero_1d(n_mask)
else:
nodes = F.Tensor(nodes)
return nodes[n_mask]
def filter_edges(self, predicate, edges=ALL):
"""Return a tensor of edge IDs that satisfy the given predicate.
Parameters
----------
predicate : callable
The predicate should take in a dict of tensors whose values
are concatenation of edge representations by edge ID (same as
get_e_repr_by_id()), and return a boolean tensor with N elements
indicating which edges satisfy the predicate.
edges : container or tensor
The edges to filter on.
Returns
-------
tensor
The filtered edges.
"""
e_repr = self.get_e_repr_by_id(edges)
e_mask = predicate(e_repr)
if is_all(edges):
return F.nonzero_1d(e_mask)
else:
edges = F.Tensor(edges)
return edges[e_mask]
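# Usage sketch for the predicate filters above (PyTorch backend assumed,
# with the named node column 'h' set in an earlier sketch):
active = g.filter_nodes(lambda n: n['h'].sum(1) > 0)   # boolean mask -> ids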
from __future__ import absolute_import
import ctypes
import numpy as np
import networkx as nx
import scipy.sparse as sp
from ._ffi.base import c_array
from ._ffi.function import _init_api
from . import backend as F
from . import utils
GraphIndexHandle = ctypes.c_void_p
class GraphIndex(object):
"""Graph index object.
Parameters
----------
handle : GraphIndexHandle
Handler
"""
def __init__(self, handle):
self._handle = handle
self._cache = {}
def __del__(self):
"""Free this graph index object."""
_CAPI_DGLGraphFree(self._handle)
def add_nodes(self, num):
"""Add nodes.
Parameters
----------
num : int
Number of nodes to be added.
"""
_CAPI_DGLGraphAddVertices(self._handle, num)
self._cache.clear()
def add_edge(self, u, v):
"""Add one edge.
Parameters
----------
u : int
The src node.
v : int
The dst node.
"""
_CAPI_DGLGraphAddEdge(self._handle, u, v)
self._cache.clear()
def add_edges(self, u, v):
"""Add many edges.
Parameters
----------
u : utils.Index
The src nodes.
v : utils.Index
The dst nodes.
"""
u_array = u.todgltensor()
v_array = v.todgltensor()
_CAPI_DGLGraphAddEdges(self._handle, u_array, v_array)
self._cache.clear()
def clear(self):
"""Clear the graph."""
_CAPI_DGLGraphClear(self._handle)
self._cache.clear()
def number_of_nodes(self):
"""Return the number of nodes.
Returns
-------
int
The number of nodes
"""
return _CAPI_DGLGraphNumVertices(self._handle)
def number_of_edges(self):
"""Return the number of edges.
Returns
-------
int
The number of edges
"""
return _CAPI_DGLGraphNumEdges(self._handle)
def has_node(self, vid):
"""Return true if the node exists.
Parameters
----------
vid : int
The node.
Returns
-------
bool
True if the node exists
"""
return _CAPI_DGLGraphHasVertex(self._handle, vid)
def has_nodes(self, vids):
"""Return true if the nodes exist.
Parameters
----------
vids : utils.Index
The nodes
Returns
-------
utils.Index
0-1 array indicating existence
"""
vid_array = vids.todgltensor()
return utils.toindex(_CAPI_DGLGraphHasVertices(self._handle, vid_array))
def has_edge(self, u, v):
"""Return true if the edge exists.
Parameters
----------
u : int
The src node.
v : int
The dst node.
Returns
-------
bool
True if the edge exists
"""
return _CAPI_DGLGraphHasEdge(self._handle, u, v)
def has_edges(self, u, v):
"""Return true if the edge exists.
Parameters
----------
u : utils.Index
The src nodes.
v : utils.Index
The dst nodes.
Returns
-------
utils.Index
0-1 array indicating existence
"""
u_array = u.todgltensor()
v_array = v.todgltensor()
return utils.toindex(_CAPI_DGLGraphHasEdges(self._handle, u_array, v_array))
def predecessors(self, v, radius=1):
"""Return the predecessors of the node.
Parameters
----------
v : int
The node.
radius : int, optional
The radius of the neighborhood.
Returns
-------
utils.Index
Array of predecessors
"""
return utils.toindex(_CAPI_DGLGraphPredecessors(self._handle, v, radius))
def successors(self, v, radius=1):
"""Return the successors of the node.
Parameters
----------
v : int
The node.
radius : int, optional
The radius of the neighborhood.
Returns
-------
utils.Index
Array of successors
"""
return utils.toindex(_CAPI_DGLGraphSuccessors(self._handle, v, radius))
def edge_id(self, u, v):
"""Return the id of the edge.
Parameters
----------
u : int
The src node.
v : int
The dst node.
Returns
-------
int
The edge id.
"""
return _CAPI_DGLGraphEdgeId(self._handle, u, v)
def edge_ids(self, u, v):
"""Return the edge ids.
Parameters
----------
u : utils.Index
The src nodes.
v : utils.Index
The dst nodes.
Returns
-------
utils.Index
The edge id array.
"""
u_array = u.todgltensor()
v_array = v.todgltensor()
return utils.toindex(_CAPI_DGLGraphEdgeIds(self._handle, u_array, v_array))
def in_edges(self, v):
"""Return the in edges of the node(s).
Parameters
----------
v : utils.Index
The node(s).
Returns
-------
utils.Index
The src nodes.
utils.Index
The dst nodes.
utils.Index
The edge ids.
"""
if len(v) == 1:
edge_array = _CAPI_DGLGraphInEdges_1(self._handle, v[0])
else:
v_array = v.todgltensor()
edge_array = _CAPI_DGLGraphInEdges_2(self._handle, v_array)
src = utils.toindex(edge_array(0))
dst = utils.toindex(edge_array(1))
eid = utils.toindex(edge_array(2))
return src, dst, eid
def out_edges(self, v):
"""Return the out edges of the node(s).
Parameters
----------
v : utils.Index
The node(s).
Returns
-------
utils.Index
The src nodes.
utils.Index
The dst nodes.
utils.Index
The edge ids.
"""
if len(v) == 1:
edge_array = _CAPI_DGLGraphOutEdges_1(self._handle, v[0])
else:
v_array = v.todgltensor()
edge_array = _CAPI_DGLGraphOutEdges_2(self._handle, v_array)
src = utils.toindex(edge_array(0))
dst = utils.toindex(edge_array(1))
eid = utils.toindex(edge_array(2))
return src, dst, eid
def edges(self, sorted=False):
"""Return all the edges
Parameters
----------
sorted : bool
True if the returned edges are sorted by their src and dst ids.
Returns
-------
utils.Index
The src nodes.
utils.Index
The dst nodes.
utils.Index
The edge ids.
"""
edge_array = _CAPI_DGLGraphEdges(self._handle, sorted)
src = utils.toindex(edge_array(0))
dst = utils.toindex(edge_array(1))
eid = utils.toindex(edge_array(2))
return src, dst, eid
def in_degree(self, v):
"""Return the in degree of the node.
Parameters
----------
v : int
The node.
Returns
-------
int
The in degree.
"""
return _CAPI_DGLGraphInDegree(self._handle, v)
def in_degrees(self, v):
"""Return the in degrees of the nodes.
Parameters
----------
v : utils.Index
The nodes.
Returns
-------
utils.Index
The in degree array.
"""
v_array = v.todgltensor()
return utils.toindex(_CAPI_DGLGraphInDegrees(self._handle, v_array))
def out_degree(self, v):
"""Return the out degree of the node.
Parameters
----------
v : int
The node.
Returns
-------
int
The out degree.
"""
return _CAPI_DGLGraphOutDegree(self._handle, v)
def out_degrees(self, v):
"""Return the out degrees of the nodes.
Parameters
----------
v : utils.Index
The nodes.
Returns
-------
utils.Index
The out degree array.
"""
v_array = v.todgltensor()
return utils.toindex(_CAPI_DGLGraphOutDegrees(self._handle, v_array))
def node_subgraph(self, v):
"""Return the induced node subgraph.
Parameters
----------
v : utils.Index
The nodes.
Returns
-------
GraphIndex
The subgraph index.
utils.Index
The induced edge ids. This is also a map from new edge id to parent edge id.
"""
v_array = v.todgltensor()
rst = _CAPI_DGLGraphVertexSubgraph(self._handle, v_array)
gi = GraphIndex(rst(0))
induced_edges = utils.toindex(rst(2))
return gi, induced_edges
def adjacency_matrix(self):
"""Return the adjacency matrix representation of this graph.
Returns
-------
utils.CtxCachedObject
An object that returns tensor given context.
"""
if 'adj' not in self._cache:
src, dst, _ = self.edges(sorted=False)
src = F.unsqueeze(src.tousertensor(), 0)
dst = F.unsqueeze(dst.tousertensor(), 0)
idx = F.pack([dst, src])
n = self.number_of_nodes()
dat = F.ones((self.number_of_edges(),))
mat = F.sparse_tensor(idx, dat, [n, n])
self._cache['adj'] = utils.CtxCachedObject(lambda ctx: F.to_context(mat, ctx))
return self._cache['adj']
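# Note on the layout above: idx packs [dst, src], so the cached matrix has
# adj[v, u] = 1 for every edge u->v; F.spmm(adj, X) therefore aggregates
# source-node rows of X into each destination row.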
def incidence_matrix(self, oriented=False):
"""Return the incidence matrix representation of this graph.
Parameters
----------
oriented : bool, optional (default=False)
Whether the returned incidence matrix is oriented.
Returns
-------
utils.CtxCachedObject
An object that returns tensor given context.
"""
key = ('oriented ' if oriented else '') + 'incidence matrix'
if key not in self._cache:
src, dst, _ = self.edges(sorted=False)
src = src.tousertensor()
dst = dst.tousertensor()
m = self.number_of_edges()
eid = F.arange(m, dtype=F.int64)
row = F.pack([src, dst])
col = F.pack([eid, eid])
idx = F.stack([row, col])
diagonal = (src == dst)
if oriented:
x = -F.ones((m,))
y = F.ones((m,))
x[diagonal] = 0
y[diagonal] = 0
dat = F.pack([x, y])
else:
x = F.ones((m,))
x[diagonal] = 0
dat = F.pack([x, x])
n = self.number_of_nodes()
mat = F.sparse_tensor(idx, dat, [n, m])
self._cache[key] = utils.CtxCachedObject(lambda ctx: F.to_context(mat, ctx))
return self._cache[key]
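# Sketch of the layout above: for edge e = (u, v) the oriented matrix stores
# B[u, e] = -1 and B[v, e] = +1 (both zeroed for self-loops), while the
# unoriented one stores 1 at both endpoints. A single edge 0->1 thus yields
# the dense 2x1 column [[-1.], [1.]] (oriented) or [[1.], [1.]].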
def to_networkx(self):
"""Convert to networkx graph.
The edge id will be saved as the 'id' edge attribute.
Returns
-------
networkx.DiGraph
The nx graph
"""
src, dst, eid = self.edges()
ret = nx.DiGraph()
for u, v, id in zip(src, dst, eid):
ret.add_edge(u, v, id=id)
return ret
def from_networkx(self, nx_graph):
"""Convert from networkx graph.
If the 'id' edge attribute exists, the edges will be added following
the edge id order. Otherwise, the order is undefined.
Parameters
----------
nx_graph : networkx.DiGraph
The nx graph
"""
self.clear()
if not isinstance(nx_graph, nx.DiGraph):
nx_graph = nx.DiGraph(nx_graph)
num_nodes = nx_graph.number_of_nodes()
self.add_nodes(num_nodes)
# check the first edge's attribute dict for an explicit id
first_edge = next(iter(nx_graph.edges(data=True)), (None, None, {}))
has_edge_id = 'id' in first_edge[-1]
if has_edge_id:
num_edges = nx_graph.number_of_edges()
src = np.zeros((num_edges,), dtype=np.int64)
dst = np.zeros((num_edges,), dtype=np.int64)
for e, attr in nx_graph.edges.items():
u, v = e
eid = attr['id']
src[eid] = u
dst[eid] = v
else:
src = []
dst = []
for u, v in nx_graph.edges:
src.append(u)
dst.append(v)
src = utils.toindex(src)
dst = utils.toindex(dst)
self.add_edges(src, dst)
def from_scipy_sparse_matrix(self, adj):
"""Convert from scipy sparse matrix.
Parameters
----------
adj : scipy sparse matrix
"""
self.clear()
self.add_nodes(adj.shape[0])
adj_coo = adj.tocoo()
src = utils.toindex(adj_coo.row)
dst = utils.toindex(adj_coo.col)
self.add_edges(src, dst)
def line_graph(self, backtracking=True):
"""Return the line graph of this graph.
Parameters
----------
backtracking : bool, optional (default=True)
Whether (i, j) ~ (j, i) in L(G).
(i, j) ~ (j, i) is the behavior of networkx.line_graph.
Returns
-------
GraphIndex
The line graph of this graph.
"""
handle = _CAPI_DGLGraphLineGraph(self._handle, backtracking)
return GraphIndex(handle)
def disjoint_union(graphs):
"""Return a disjoint union of the input graphs.
The new graph will include all the nodes/edges in the given graphs.
Nodes/Edges will be relabeled by adding the cumulative sum of the previous graph sizes
in the given sequence order. For example, given input [g1, g2, g3], where
they have 5, 6, 7 nodes respectively, node#2 of g2 will become node#7
in the result graph. Edge ids are re-assigned similarly.
Parameters
----------
graphs : iterable of GraphIndex
The input graphs
Returns
-------
GraphIndex
The disjoint union
"""
inputs = c_array(GraphIndexHandle, [gr._handle for gr in graphs])
inputs = ctypes.cast(inputs, ctypes.c_void_p)
handle = _CAPI_DGLDisjointUnion(inputs, len(graphs))
return GraphIndex(handle)
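# Usage sketch matching the docstring above: indices with 5, 6 and 7 nodes
# merge into one 18-node index (node 2 of the second becomes node 7):
g1, g2, g3 = (create_graph_index() for _ in range(3))
g1.add_nodes(5); g2.add_nodes(6); g3.add_nodes(7)
big = disjoint_union([g1, g2, g3])
assert big.number_of_nodes() == 18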
def disjoint_partition(graph, num_or_size_splits):
"""Partition the graph disjointly.
This is the reverse operation of disjoint_union. The graph will be partitioned
into num graphs. This requires the given number of partitions to evenly
divide the number of nodes in the graph. If a size list is given,
the sum of the given sizes must equal the number of nodes.
Parameters
----------
graph : GraphIndex
The graph to be partitioned
num_or_size_splits : int or utils.Index
The number of partitions, or the sizes of each partition.
Returns
-------
list of GraphIndex
The partitioned graphs
"""
if isinstance(num_or_size_splits, utils.Index):
rst = _CAPI_DGLDisjointPartitionBySizes(
graph._handle,
num_or_size_splits.todgltensor())
else:
rst = _CAPI_DGLDisjointPartitionByNum(
graph._handle,
int(num_or_size_splits))
graphs = []
for val in rst.asnumpy():
handle = ctypes.cast(int(val), ctypes.c_void_p)
graphs.append(GraphIndex(handle))
return graphs
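# Usage sketch: splitting the union back by explicit sizes (which must sum
# to the node count) or by an integer that evenly divides it:
parts = disjoint_partition(big, utils.toindex([5, 6, 7]))
assert [p.number_of_nodes() for p in parts] == [5, 6, 7]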
def create_graph_index(graph_data=None):
"""Create a graph index object.
Parameters
----------
graph_data : graph data, optional
Data to initialize graph. Same as networkx's semantics.
"""
if isinstance(graph_data, GraphIndex):
return graph_data
handle = _CAPI_DGLGraphCreate()
gi = GraphIndex(handle)
if graph_data is not None:
gi.from_networkx(graph_data)
return gi
_init_api("dgl.graph_index")
"""DGL Runtime NDArray API.
dgl.ndarray provides a minimum runtime array structure to be
used with the C++ library.
"""
# pylint: disable=invalid-name,unused-import
from __future__ import absolute_import as _abs
import ctypes
import functools
import operator
import numpy as _np
from ._ffi.ndarray import TVMContext, TVMType, NDArrayBase
from ._ffi.ndarray import context, empty, from_dlpack, numpyasarray
from ._ffi.ndarray import _set_class_ndarray
from . import backend as F
class NDArray(NDArrayBase):
"""Lightweight NDArray class for DGL framework."""
def __len__(self):
return functools.reduce(operator.mul, self.shape, 1)
def cpu(dev_id=0):
"""Construct a CPU device
Parameters
----------
dev_id : int, optional
The integer device id
Returns
-------
ctx : TVMContext
The created context
"""
return TVMContext(1, dev_id)
def gpu(dev_id=0):
"""Construct a CPU device
Parameters
----------
dev_id : int, optional
The integer device id
Returns
-------
ctx : TVMContext
The created context
"""
return TVMContext(2, dev_id)
def array(arr, ctx=cpu(0)):
"""Create an array from source arr.
Parameters
----------
arr : numpy.ndarray
The array to be copied from
ctx : TVMContext, optional
The device context to create the array
Returns
-------
ret : NDArray
The created array
"""
if not isinstance(arr, (_np.ndarray, NDArray)):
arr = _np.array(arr)
return empty(arr.shape, arr.dtype, ctx).copyfrom(arr)
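# Usage sketch: construct a CPU-resident NDArray from numpy (copied), or
# wrap existing numpy memory with the zero-copy helper defined below:
a = array(_np.arange(4, dtype='float32'))               # copies onto cpu(0)
b = zerocopy_from_numpy(_np.ones(3, dtype='float32'))   # shares memory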
def zerocopy_from_numpy(np_data):
"""Create an array that shares the given numpy data.
Parameters
----------
np_data : numpy.ndarray
The numpy data
Returns
-------
NDArray
The array
"""
arr, _ = numpyasarray(np_data)
handle = ctypes.pointer(arr)
return NDArray(handle, is_view=True)
_set_class_ndarray(NDArray)
"""Package nn modules"""
from __future__ import absolute_import
import os
__backend__ = os.environ.get('DGLBACKEND', 'pytorch').lower()
if __backend__ == 'numpy':
pass
elif __backend__ == 'pytorch':
from .pytorch import *
else:
elif __backend__ != 'mxnet':
raise Exception("Unsupported backend %s" % __backend__)
......@@ -7,9 +7,8 @@ GCN with SPMV specialization.
"""
import torch.nn as nn
import dgl
import dgl.function as fn
from dgl.base import ALL, is_all
from ... import function as fn
from ...base import ALL, is_all
class NodeUpdateModule(nn.Module):
def __init__(self, in_feats, out_feats, activation=None):
......
"""Utility functions for networkx adapter."""
from __future__ import absolute_import
from collections import MutableMapping
import networkx as nx
import networkx.convert as convert
class NodeDict(MutableMapping):
def __init__(self, add_cb, del_cb):
self._dict = {}
self._add_cb = add_cb
self._del_cb = del_cb
def __setitem__(self, key, val):
self._add_cb(key)
self._dict[key] = val
def __getitem__(self, key):
return self._dict[key]
def __delitem__(self, key):
self._del_cb(key)
del self._dict[key]
def __len__(self):
return len(self._dict)
def __iter__(self):
return iter(self._dict)
class AdjOuterDict(MutableMapping):
def __init__(self, add_cb, del_cb):
self._dict = {}
self._add_cb = add_cb
self._del_cb = del_cb
def __setitem__(self, key, val):
val.src = key
self._dict[key] = val
def __getitem__(self, key):
return self._dict[key]
def __delitem__(self, key):
for val in self._dict[key]:
self._del_cb(key, val)
del self._dict[key]
def __len__(self):
return len(self._dict)
def __iter__(self):
return iter(self._dict)
class AdjInnerDict(MutableMapping):
def __init__(self, add_cb, del_cb):
self._dict = {}
self.src = None
self._add_cb = add_cb
self._del_cb = del_cb
def __setitem__(self, key, val):
if self.src is not None and key not in self._dict:
self._add_cb(self.src, key)
self._dict[key] = val
def __getitem__(self, key):
return self._dict[key]
def __delitem__(self, key):
if self.src is not None:
self._del_cb(self.src, key)
del self._dict[key]
def __len__(self):
return len(self._dict)
def __iter__(self):
return iter(self._dict)
class AdjInnerDictFactory(object):
def __init__(self, cb1, cb2):
self._cb1 = cb1
self._cb2 = cb2
def __call__(self):
return AdjInnerDict(self._cb1, self._cb2)
def nx_init(obj,
add_node_cb,
add_edge_cb,
del_node_cb,
del_edge_cb,
graph_data,
**attr):
"""Init the object to be compatible with networkx's DiGraph.
Parameters
----------
obj : any
The object to be init.
add_node_cb : callable
The callback function when node is added.
add_edge_cb : callable
The callback function when edge is added.
graph_data : graph data
Data to initialize graph. Same as networkx's semantics.
attr : keyword arguments, optional
Attributes to add to graph as key=value pairs.
"""
# The following codes work for networkx 2.1.
obj.adjlist_outer_dict_factory = None
obj.adjlist_inner_dict_factory = AdjInnerDictFactory(add_edge_cb, del_edge_cb)
obj.edge_attr_dict_factory = dict
obj.root_graph = obj
obj.graph = {}
obj._node = NodeDict(add_node_cb, del_node_cb)
obj._adj = AdjOuterDict(add_edge_cb, del_edge_cb)
obj._pred = dict()
obj._succ = obj._adj
if graph_data is not None:
convert.to_networkx_graph(graph_data, create_using=obj)
obj.graph.update(attr)
......@@ -3,19 +3,20 @@ from __future__ import absolute_import
import numpy as np
import dgl.backend as F
import dgl.function.message as fmsg
import dgl.function.reducer as fred
import dgl.utils as utils
from .base import ALL
from . import backend as F
from .function import message as fmsg
from .function import reducer as fred
from . import utils
__all__ = ["degree_bucketing", "get_executor"]
def degree_bucketing(cached_graph, v):
def degree_bucketing(graph, v):
"""Create degree bucketing scheduling policy.
Parameters
----------
cached_graph : dgl.cached_graph.CachedGraph
graph : dgl.graph_index.GraphIndex
the graph
v : dgl.utils.Index
the nodes to gather messages
......@@ -28,7 +29,7 @@ def degree_bucketing(cached_graph, v):
list of node id buckets; nodes belong to the same bucket have
the same degree
"""
degrees = F.asnumpy(cached_graph.in_degrees(v).totensor())
degrees = np.array(graph.in_degrees(v).tolist())
unique_degrees = list(np.unique(degrees))
v_np = np.array(v.tolist())
v_bkt = []
......@@ -38,37 +39,32 @@ def degree_bucketing(cached_graph, v):
#print('degree-bucketing:', unique_degrees, [len(b) for b in v_bkt])
return unique_degrees, v_bkt
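# Worked example: for nodes v = [0, 1, 2, 3] with in-degrees [1, 2, 2, 1],
# this returns unique_degrees = [1, 2] and buckets [[0, 3], [1, 2]]; each
# bucket's messages can then be reshaped to a dense (B, Deg, ...) tensor
# before the reduce UDF runs (see _batch_recv in graph.py above).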
class Executor(object):
def run(self, graph):
def run(self):
raise NotImplementedError
class UpdateAllSPMVExecutor(Executor):
def __init__(self, graph, src_field, dst_field, edge_field, use_adj):
self.graph = graph
class SPMVOperator(Executor):
def __init__(self, src_field, edge_field, dst_field, use_edge_feat,
node_repr, adj_build_fn):
self.src_field = src_field
self.dst_field = dst_field
self.edge_field = edge_field
self.use_adj = use_adj
self.dst_field = dst_field
self.use_edge_feat = use_edge_feat
self.node_repr = node_repr
self.adj_build_fn = adj_build_fn
def run(self):
g = self.graph
# get src col
if self.src_field is None:
srccol = g.get_n_repr()
srccol = self.node_repr
else:
srccol = g.get_n_repr()[self.src_field]
srccol = self.node_repr[self.src_field]
ctx = F.get_context(srccol)
if self.use_adj:
adjmat = g.cached_graph.adjmat().get(ctx)
else:
if self.edge_field is None:
dat = g.get_e_repr()
else:
dat = g.get_e_repr()[self.edge_field]
dat = F.squeeze(dat)
# TODO(minjie): should not directly use _indices
idx = g.cached_graph.adjmat().get(ctx)._indices()
n = g.number_of_nodes()
adjmat = F.sparse_tensor(idx, dat, [n, n])
# build adjmat
adjmat = self.adj_build_fn(self.edge_field, ctx, self.use_edge_feat)
# spmm
if len(F.shape(srccol)) == 1:
srccol = F.unsqueeze(srccol, 1)
......@@ -77,104 +73,249 @@ class UpdateAllSPMVExecutor(Executor):
else:
dstcol = F.spmm(adjmat, srccol)
if self.dst_field is None:
g.set_n_repr(dstcol)
return dstcol
else:
g.set_n_repr({self.dst_field : dstcol})
return {self.dst_field : dstcol}
class SendRecvSPMVExecutor(Executor):
def __init__(self, graph, src, dst, src_field, dst_field, edge_field, use_edge_dat):
self.graph = graph
self.src = src
self.dst = dst
self.src_field = src_field
self.dst_field = dst_field
self.edge_field = edge_field
self.use_edge_dat = use_edge_dat
def run(self):
# get src col
g = self.graph
if self.src_field is None:
srccol = g.get_n_repr()
class BasicExecutor(Executor):
def __init__(self, graph, mfunc, rfunc):
self.g = graph
self.exe = self._build_exec(mfunc, rfunc)
@property
def node_repr(self):
raise NotImplementedError
@property
def edge_repr(self):
raise NotImplementedError
@property
def graph_mapping(self):
raise NotImplementedError
def _build_exec(self, mfunc, rfunc):
if isinstance(mfunc, fmsg.CopySrcMessageFunction):
exe = SPMVOperator(src_field=mfunc.src_field,
edge_field=None,
dst_field=rfunc.out_field,
use_edge_feat=False,
node_repr=self.node_repr,
adj_build_fn=self._adj_build_fn)
elif isinstance(mfunc, fmsg.SrcMulEdgeMessageFunction):
exe = SPMVOperator(src_field=mfunc.src_field,
edge_field=mfunc.edge_field,
dst_field=rfunc.out_field,
use_edge_feat=True,
node_repr=self.node_repr,
adj_build_fn=self._adj_build_fn)
else:
srccol = g.get_n_repr()[self.src_field]
ctx = F.get_context(srccol)
raise NotImplementedError("message func type {}".format(type(mfunc)))
return exe
# build adjmat
# build adjmat dat
u, v = utils.edge_broadcasting(self.src, self.dst)
if self.use_edge_dat:
if self.edge_field is None:
dat = g.get_e_repr(u, v)
def run(self):
attr = self.exe.run()
self.g.set_n_repr(attr, self.graph_mapping)
class UpdateAllExecutor(BasicExecutor):
def __init__(self, graph, mfunc, rfunc):
self._init_state()
super(UpdateAllExecutor, self).__init__(graph, mfunc, rfunc)
def _init_state(self):
self._node_repr = None
self._edge_repr = None
self._graph_idx = None
self._graph_shape = None
self._graph_mapping = None
@property
def graph_idx(self):
if self._graph_idx is None:
self._graph_idx = self.g._graph.adjacency_matrix()
return self._graph_idx
@property
def graph_shape(self):
if self._graph_shape is None:
n = self.g.number_of_nodes()
self._graph_shape = [n, n]
return self._graph_shape
@property
def graph_mapping(self):
return ALL
@property
def node_repr(self):
if self._node_repr is None:
self._node_repr = self.g.get_n_repr()
return self._node_repr
@property
def edge_repr(self):
if self._edge_repr is None:
self._edge_repr = self.g.get_e_repr()
return self._edge_repr
def _adj_build_fn(self, edge_field, ctx, use_edge_feat):
if use_edge_feat:
if edge_field is None:
dat = self.edge_repr
else:
dat = g.get_e_repr(u, v)[self.edge_field]
dat = self.edge_repr[edge_field]
dat = F.squeeze(dat)
# TODO(minjie): should not directly use _indices
idx = self.graph_idx.get(ctx)._indices()
adjmat = F.sparse_tensor(idx, dat, self.graph_shape)
else:
dat = F.ones((len(u),))
# build adjmat index
new2old, old2new = utils.build_relabel_map(v)
u = u.totensor()
v = v.totensor()
adjmat = self.graph_idx.get(ctx)
return adjmat
class SendRecvExecutor(BasicExecutor):
def __init__(self, graph, src, dst, mfunc, rfunc):
self._init_state(src, dst)
super(SendRecvExecutor, self).__init__(graph, mfunc, rfunc)
def _init_state(self, src, dst):
self.u, self.v = utils.edge_broadcasting(src, dst)
self._node_repr = None
self._edge_repr = None
self._graph_idx = None
self._graph_shape = None
self._graph_mapping = None
@property
def graph_idx(self):
if self._graph_idx is None:
self._build_adjmat()
return self._graph_idx
@property
def graph_shape(self):
if self._graph_shape is None:
self._build_adjmat()
return self._graph_shape
@property
def graph_mapping(self):
if self._graph_mapping is None:
self._build_adjmat()
return self._graph_mapping
@property
def node_repr(self):
if self._node_repr is None:
self._node_repr = self.g.get_n_repr()
return self._node_repr
@property
def edge_repr(self):
if self._edge_repr is None:
self._edge_repr = self.g.get_e_repr(self.u, self.v)
return self._edge_repr
def _build_adjmat(self):
# handle graph index
new2old, old2new = utils.build_relabel_map(self.v)
u = self.u.tousertensor()
v = self.v.tousertensor()
# TODO(minjie): should not directly use []
new_v = old2new[v]
idx = F.pack([F.unsqueeze(new_v, 0), F.unsqueeze(u, 0)])
n = g.number_of_nodes()
n = self.g.number_of_nodes()
m = len(new2old)
adjmat = F.sparse_tensor(idx, dat, [m, n])
adjmat = F.to_context(adjmat, ctx)
# spmm
if len(F.shape(srccol)) == 1:
srccol = F.unsqueeze(srccol, 1)
dstcol = F.spmm(adjmat, srccol)
dstcol = F.squeeze(dstcol)
else:
dstcol = F.spmm(adjmat, srccol)
if self.dst_field is None:
g.set_n_repr(dstcol, new2old)
self._graph_idx = F.pack([F.unsqueeze(new_v, 0), F.unsqueeze(u, 0)])
self._graph_shape = [m, n]
self._graph_mapping = new2old
def _adj_build_fn(self, edge_field, ctx, use_edge_feat):
if use_edge_feat:
if edge_field is None:
dat = self.edge_repr
else:
dat = self.edge_repr[edge_field]
dat = F.squeeze(dat)
else:
g.set_n_repr({self.dst_field : dstcol}, new2old)
dat = F.ones((len(self.u), ))
adjmat = F.sparse_tensor(self.graph_idx, dat, self.graph_shape)
return F.to_context(adjmat, ctx)
def _is_spmv_supported_node_feat(g, field):
if field is None:
feat = g.get_n_repr()
else:
feat = g.get_n_repr()[field]
shape = F.shape(feat)
return (len(shape) == 1 or len(shape) == 2)
def _is_spmv_supported_edge_feat(g, field):
# Check shape: only scalar edge features can be optimized at the moment.
if field is None:
feat = g.get_e_repr()
class BundledExecutor(BasicExecutor):
"""
Base class for bundled execution.
All shared structures, such as the graph index, should be cached in this class or its subclasses.
BundledUpdateAllExecutor and BundledSendRecvExecutor should subclass BundledExecutor.
"""
def __init__(self, graph, mfunc, rfunc):
self.g = graph
func_pairs = self._match_message_with_reduce(mfunc, rfunc)
# create all executors
self.executors = self._build_executors(func_pairs)
def _build_executors(self, func_pairs):
executors = []
for mfunc, rfunc in func_pairs:
exe = self._build_exec(mfunc, rfunc)
executors.append(exe)
return executors
def _match_message_with_reduce(self, mfunc, rfunc):
out2mfunc = {fn.out_field: fn for fn in mfunc.fn_list}
func_pairs = []
for rfn in rfunc.fn_list:
mfn = out2mfunc.get(rfn.msg_field, None)
# field check
assert mfn is not None, \
"cannot find message func for reduce func in-field {}".format(rfn.msg_field)
func_pairs.append((mfn, rfn))
return func_pairs
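The pairing rule above is purely name-based: a reduce function consumes whatever message field some message function declared as its output. A standalone sketch with illustrative dict stand-ins for the function objects:

def match_message_with_reduce_sketch(mfuncs, rfuncs):
    # Pair each reduce fn with the message fn that produces its input field.
    out2mfunc = {m['out_field']: m for m in mfuncs}
    pairs = []
    for r in rfuncs:
        m = out2mfunc.get(r['msg_field'])
        assert m is not None, 'no message func produces %s' % r['msg_field']
        pairs.append((m, r))
    return pairs

mfuncs = [{'out_field': 'm1'}, {'out_field': 'm2'}]
rfuncs = [{'msg_field': 'm2'}, {'msg_field': 'm1'}]
print(match_message_with_reduce_sketch(mfuncs, rfuncs))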
def run(self):
attr = None
for exe in self.executors:
res = exe.run()
if attr is None:
attr = res
else:
# attr and res must be dict
attr.update(res)
self.g.set_n_repr(attr, self.graph_mapping)
class BundledUpdateAllExecutor(BundledExecutor, UpdateAllExecutor):
def __init__(self, graph, mfunc, rfunc):
self._init_state()
BundledExecutor.__init__(self, graph, mfunc, rfunc)
class BundledSendRecvExecutor(BundledExecutor, SendRecvExecutor):
def __init__(self, graph, src, dst, mfunc, rfunc):
self._init_state(src, dst)
BundledExecutor.__init__(self, graph, mfunc, rfunc)
def _is_spmv_supported(fn, graph=None):
if isinstance(fn, fmsg.MessageFunction):
return fn.is_spmv_supported(graph)
elif isinstance(fn, fred.ReduceFunction):
return fn.is_spmv_supported()
else:
feat = g.get_e_repr()[field]
shape = F.shape(feat)
return len(shape) == 1 or (len(shape) == 2 and shape[1] == 1)
return False
def _create_update_all_exec(graph, **kwargs):
mfunc = kwargs.pop('message_func')
rfunc = kwargs.pop('reduce_func')
if (isinstance(mfunc, fmsg.CopySrcMessageFunction)
and isinstance(rfunc, fred.SumReducerFunction)
and _is_spmv_supported_node_feat(graph, mfunc.src_field)):
# TODO(minjie): more sanity check on field names
return UpdateAllSPMVExecutor(graph,
src_field=mfunc.src_field,
dst_field=rfunc.out_field,
edge_field=None,
use_adj=True)
elif (isinstance(mfunc, fmsg.SrcMulEdgeMessageFunction)
and isinstance(rfunc, fred.SumReducerFunction)
and _is_spmv_supported_node_feat(graph, mfunc.src_field)
and _is_spmv_supported_edge_feat(graph, mfunc.edge_field)):
return UpdateAllSPMVExecutor(graph,
src_field=mfunc.src_field,
dst_field=rfunc.out_field,
edge_field=mfunc.edge_field,
use_adj=False)
elif (isinstance(mfunc, fmsg.CopyEdgeMessageFunction)
and isinstance(rfunc, fred.SumReducerFunction)):
return None
if isinstance(mfunc, (list, tuple)) or isinstance(rfunc, (list, tuple)):
mfunc = fmsg.BundledMessageFunction(mfunc)
rfunc = fred.BundledReduceFunction(rfunc)
exec_cls = BundledUpdateAllExecutor
else:
exec_cls = UpdateAllExecutor
if _is_spmv_supported(mfunc, graph) and _is_spmv_supported(rfunc):
return exec_cls(graph, mfunc=mfunc, rfunc=rfunc)
else:
return None
......@@ -183,28 +324,14 @@ def _create_send_and_recv_exec(graph, **kwargs):
dst = kwargs.pop('dst')
mfunc = kwargs.pop('message_func')
rfunc = kwargs.pop('reduce_func')
if (isinstance(mfunc, fmsg.CopySrcMessageFunction)
and isinstance(rfunc, fred.SumReducerFunction)
and _is_spmv_supported_node_feat(graph, mfunc.src_field)):
# TODO(minjie): more sanity check on field names
return SendRecvSPMVExecutor(graph,
src=src,
dst=dst,
src_field=mfunc.src_field,
dst_field=rfunc.out_field,
edge_field=None,
use_edge_dat=False)
elif (isinstance(mfunc, fmsg.SrcMulEdgeMessageFunction)
and isinstance(rfunc, fred.SumReducerFunction)
and _is_spmv_supported_node_feat(graph, mfunc.src_field)
and _is_spmv_supported_edge_feat(graph, mfunc.edge_field)):
return SendRecvSPMVExecutor(graph,
src=src,
dst=dst,
src_field=mfunc.src_field,
dst_field=rfunc.out_field,
edge_field=mfunc.edge_field,
use_edge_dat=True)
if isinstance(mfunc, (list, tuple)) or isinstance(rfunc, (list, tuple)):
mfunc = fmsg.BundledMessageFunction(mfunc)
rfunc = fred.BundledReduceFunction(rfunc)
exec_cls = BundledSendRecvExecutor
else:
exec_cls = SendRecvExecutor
if _is_spmv_supported(mfunc, graph) and _is_spmv_supported(rfunc):
return exec_cls(graph, src=src, dst=dst, mfunc=mfunc, rfunc=rfunc)
else:
return None
......
"""DGLSubGraph"""
"""Class for subgraph data structure."""
from __future__ import absolute_import
import networkx as nx
import dgl.backend as F
from dgl.frame import Frame, FrameRef
from dgl.graph import DGLGraph
from dgl.nx_adapt import nx_init
import dgl.utils as utils
from . import backend as F
from .frame import Frame, FrameRef
from .graph import DGLGraph
from . import utils
class DGLSubGraph(DGLGraph):
# TODO(gaiyu): ReadOnlyGraph
def __init__(self,
parent,
nodes):
super(DGLSubGraph, self).__init__()
# relabel nodes
self._node_mapping = utils.build_relabel_dict(nodes)
self._parent_nid = utils.toindex(nodes)
eids = []
# create subgraph
for eid, (u, v) in enumerate(parent.edge_list):
if u in self._node_mapping and v in self._node_mapping:
self.add_edge(self._node_mapping[u],
self._node_mapping[v])
eids.append(eid)
self._parent_eid = utils.toindex(eids)
def copy_from(self, parent):
"""Copy node/edge features from the parent graph.
All old features will be removed.
"""
"""The subgraph class.
There are two subgraph modes: shared and non-shared.
For the "non-shared" mode, the user needs to explicitly call
``copy_from_parent`` to copy node/edge features from its parent graph.
* If the user tries to get node/edge features before ``copy_from_parent``,
s/he will get nothing.
* If the subgraph already has its own node/edge features, ``copy_from_parent``
will overwrite them.
* Any update on the subgraph's node/edge features will not be seen
by the parent graph. As such, the memory consumption is of the order
of the subgraph size.
* To write the subgraph's node/edge features back to the parent graph, there are two options:
(1) Use ``copy_to_parent`` API to write node/edge features back.
(2) [TODO] Use ``dgl.merge`` to merge multiple subgraphs back to one parent.
The "shared" mode is currently not supported.
The subgraph is read-only so mutation is not allowed.
Parameters
----------
parent : DGLGraph
The parent graph
parent_nid : utils.Index
The induced parent node ids in this subgraph.
parent_eid : utils.Index
The induced parent edge ids in this subgraph.
graph_idx : GraphIndex
The graph index.
shared : bool, optional
Whether the subgraph shares node/edge features with the parent graph.
"""
def __init__(self, parent, parent_nid, parent_eid, graph_idx, shared=False):
super(DGLSubGraph, self).__init__(graph_data=graph_idx)
self._parent = parent
self._parent_nid = parent_nid
self._parent_eid = parent_eid
# override APIs
def add_nodes(self, num, reprs=None):
"""Add nodes. Disabled because BatchedDGLGraph is read-only."""
raise RuntimeError('Readonly graph. Mutation is not allowed.')
def add_edge(self, u, v, reprs=None):
"""Add one edge. Disabled because BatchedDGLGraph is read-only."""
raise RuntimeError('Readonly graph. Mutation is not allowed.')
def add_edges(self, u, v, reprs=None):
"""Add many edges. Disabled because BatchedDGLGraph is read-only."""
raise RuntimeError('Readonly graph. Mutation is not allowed.')
@property
def parent_nid(self):
"""Get the parent node ids.
The returned tensor can be used as a map from the node id
in this subgraph to the node id in the parent graph.
Returns
-------
Tensor
The parent node id array.
"""
return self._parent_nid.tousertensor()
@property
def parent_eid(self):
"""Get the parent edge ids.
The returned tensor can be used as a map from the edge id
in this subgraph to the edge id in the parent graph.
Returns
-------
Tensor
The parent edge id array.
"""
return self._parent_eid.tousertensor()
def copy_to_parent(self, inplace=False):
"""Write node/edge features to the parent graph.
Parameters
----------
inplace : bool
If true, use inplace write (no gradient but faster).
"""
self._parent._node_frame.update_rows(
self._parent_nid, self._node_frame, inplace=inplace)
self._parent._edge_frame.update_rows(
self._parent_eid, self._edge_frame, inplace=inplace)
def copy_from_parent(self):
"""Copy node/edge features from the parent graph.
All old features will be removed.
"""
if parent._node_frame.num_rows != 0:
self._node_frame = FrameRef(Frame(parent._node_frame[self._parent_nid]))
if parent._edge_frame.num_rows != 0:
self._edge_frame = FrameRef(Frame(parent._edge_frame[self._parent_eid]))
if self._parent._node_frame.num_rows != 0:
self._node_frame = FrameRef(Frame(
self._parent._node_frame[self._parent_nid]))
if self._parent._edge_frame.num_rows != 0:
self._edge_frame = FrameRef(Frame(
self._parent._edge_frame[self._parent_eid]))
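A round trip under the non-shared mode described above amounts to gather-rows, compute, scatter-rows. A hedged sketch using a plain numpy array as a stand-in for the node frame (toy data, not DGL's actual storage):

import numpy as np

# parent "frame": one row of features per parent node
parent_feat = np.arange(10.0).reshape(5, 2)
parent_nid = np.array([1, 3, 4])        # induced parent node ids

# copy_from_parent: materialize the subgraph's own rows
sub_feat = parent_feat[parent_nid].copy()

# local update is invisible to the parent until written back
sub_feat += 100.0
assert parent_feat[1, 0] == 2.0

# copy_to_parent: write the rows back at the induced ids
parent_feat[parent_nid] = sub_feat
print(parent_feat)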
......@@ -5,50 +5,70 @@ from collections import Mapping
from functools import wraps
import numpy as np
import dgl.backend as F
from dgl.backend import Tensor, SparseTensor
def is_id_tensor(u):
"""Return whether the input is a supported id tensor."""
return isinstance(u, Tensor) and F.isinteger(u) and len(F.shape(u)) == 1
def is_id_container(u):
"""Return whether the input is a supported id container."""
return (getattr(u, '__iter__', None) is not None
and getattr(u, '__len__', None) is not None)
from . import backend as F
from .backend import Tensor, SparseTensor
from . import ndarray as nd
class Index(object):
"""Index class that can be easily converted to list/tensor."""
def __init__(self, data):
self._list_data = None
self._tensor_data = None
self._ctx_data = dict()
self._list_data = None # numpy array data
self._user_tensor_data = dict() # per-context cache of user tensors
self._dgl_tensor_data = None # a dgl ndarray
self._dispatch(data)
def _dispatch(self, data):
if is_id_tensor(data):
self._tensor_data = data
elif is_id_container(data):
self._list_data = data
"""Store data based on its type."""
if isinstance(data, Tensor):
if not (F.dtype(data) == F.int64 and len(F.shape(data)) == 1):
raise ValueError('Index data must be 1D int64 vector, but got: %s' % str(data))
self._user_tensor_data[F.get_context(data)] = data
elif isinstance(data, nd.NDArray):
if not (data.dtype == 'int64' and len(data.shape) == 1):
raise ValueError('Index data must be 1D int64 vector, but got: %s' % str(data))
self._dgl_tensor_data = data
else:
try:
self._list_data = [int(data)]
self._list_data = np.array([int(data)]).astype(np.int64)
except:
raise TypeError('Error index data: %s' % str(data))
try:
self._list_data = np.array(data).astype(np.int64)
except:
raise ValueError('Error index data: %s' % str(data))
self._user_tensor_data[nd.cpu()] = F.zerocopy_from_numpy(self._list_data)
def tolist(self):
"""Convert to a python-list compatible object."""
if self._list_data is None:
self._list_data = list(F.asnumpy(self._tensor_data))
if self._dgl_tensor_data is not None:
self._list_data = self._dgl_tensor_data.asnumpy()
else:
data = self.tousertensor()
self._list_data = F.zerocopy_to_numpy(data)
return self._list_data
def totensor(self, ctx=None):
if self._tensor_data is None:
self._tensor_data = F.tensor(self._list_data, dtype=F.int64)
def tousertensor(self, ctx=None):
"""Convert to user tensor (defined in `backend`)."""
if ctx is None:
return self._tensor_data
if ctx not in self._ctx_data:
self._ctx_data[ctx] = F.to_context(self._tensor_data, ctx)
return self._ctx_data[ctx]
ctx = nd.cpu()
if len(self._user_tensor_data) == 0:
# zero copy from dgl tensor
dl = self._dgl_tensor_data.to_dlpack()
self._user_tensor_data[nd.cpu()] = F.zerocopy_from_dlpack(dl)
if ctx not in self._user_tensor_data:
# copy from cpu to another device
data = next(iter(self._user_tensor_data.values()))
self._user_tensor_data[ctx] = F.to_context(data, ctx)
return self._user_tensor_data[ctx]
def todgltensor(self):
"""Convert to dgl.NDArray."""
if self._dgl_tensor_data is None:
# zero copy from user tensor
tsor = self.tousertensor()
dl = F.zerocopy_to_dlpack(tsor)
self._dgl_tensor_data = nd.from_dlpack(dl)
return self._dgl_tensor_data
def __iter__(self):
return iter(self.tolist())
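Every representation the Index produces is cached, so repeated conversions are free after the first call. A pared-down, numpy-only sketch of the lazy-caching pattern (the DLPack hand-off to dgl.NDArray is omitted; names are illustrative):

import numpy as np

class LazyIndex:
    def __init__(self, data):
        self._np = np.asarray(data, dtype=np.int64)
        self._list = None
        self._by_ctx = {}          # stand-in for the per-device tensor cache

    def tolist(self):
        if self._list is None:     # convert once, reuse afterwards
            self._list = self._np.tolist()
        return self._list

    def tousertensor(self, ctx='cpu'):
        if ctx not in self._by_ctx:
            self._by_ctx[ctx] = self._np   # real code copies to the device
        return self._by_ctx[ctx]

idx = LazyIndex([3, 1, 2])
assert idx.tolist() is idx.tolist()        # second call hits the cache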
......@@ -56,8 +76,11 @@ class Index(object):
def __len__(self):
if self._list_data is not None:
return len(self._list_data)
elif len(self._user_tensor_data) > 0:
data = next(iter(self._user_tensor_data.values()))
return len(data)
else:
return len(self._tensor_data)
return len(self._dgl_tensor_data)
def __getitem__(self, i):
return self.tolist()[i]
......@@ -118,40 +141,13 @@ def edge_broadcasting(u, v):
The dst id(s) after broadcasting
"""
if len(u) != len(v) and len(u) == 1:
u = toindex(F.broadcast_to(u.totensor(), v.totensor()))
u = toindex(F.broadcast_to(u.tousertensor(), v.tousertensor()))
elif len(u) != len(v) and len(v) == 1:
v = toindex(F.broadcast_to(v.totensor(), u.totensor()))
v = toindex(F.broadcast_to(v.tousertensor(), u.tousertensor()))
else:
assert len(u) == len(v)
return u, v
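Edge broadcasting mirrors numpy broadcasting: a length-1 endpoint array is expanded to match the other side, so one-to-many and many-to-one edge sets can share a code path. A numpy sketch:

import numpy as np

def edge_broadcasting_sketch(u, v):
    # Broadcast a single src (or dst) id against the other endpoint array.
    u, v = np.asarray(u), np.asarray(v)
    if len(u) != len(v) and len(u) == 1:
        u = np.broadcast_to(u, v.shape)
    elif len(u) != len(v) and len(v) == 1:
        v = np.broadcast_to(v, u.shape)
    else:
        assert len(u) == len(v)
    return u, v

print(edge_broadcasting_sketch([0], [1, 2, 3]))
# (array([0, 0, 0]), array([1, 2, 3]))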
'''
def convert_to_id_container(x):
if is_id_container(x):
return x
elif is_id_tensor(x):
return F.asnumpy(x)
else:
try:
return [int(x)]
except:
raise TypeError('Error node: %s' % str(x))
return None
def convert_to_id_tensor(x, ctx=None):
if is_id_container(x):
ret = F.tensor(x, dtype=F.int64)
elif is_id_tensor(x):
ret = x
else:
try:
ret = F.tensor([int(x)], dtype=F.int64)
except:
raise TypeError('Error node: %s' % str(x))
ret = F.to_context(ret, ctx)
return ret
'''
class LazyDict(Mapping):
"""A readonly dictionary that does not materialize the storage."""
def __init__(self, fn, keys):
......@@ -209,7 +205,7 @@ def build_relabel_map(x):
One can use advanced indexing to convert an old id tensor to a
new id tensor: new_id = old_to_new[old_id]
"""
x = x.totensor()
x = x.tousertensor()
unique_x, _ = F.sort(F.unique(x))
map_len = int(F.max(unique_x)) + 1
old_to_new = F.zeros(map_len, dtype=F.int64)
......@@ -316,6 +312,6 @@ def reorder(dict_like, index):
"""
new_dict = {}
for key, val in dict_like.items():
idx_ctx = index.totensor(F.get_context(val))
idx_ctx = index.tousertensor(F.get_context(val))
new_dict[key] = F.gather_row(val, idx_ctx)
return new_dict
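The relabel map described above turns id translation into advanced indexing: new_id = old_to_new[old_id]. A numpy sketch of building both directions of the map:

import numpy as np

def build_relabel_map_sketch(x):
    # Map the unique ids in x to consecutive new ids [0, n).
    new_to_old = np.unique(x)                  # sorted unique old ids
    old_to_new = np.zeros(int(new_to_old.max()) + 1, dtype=np.int64)
    old_to_new[new_to_old] = np.arange(len(new_to_old))
    return new_to_old, old_to_new

new_to_old, old_to_new = build_relabel_map_sketch(np.array([7, 3, 7, 5]))
print(old_to_new[[3, 5, 7]])    # [0 1 2]
print(new_to_old)               # [3 5 7]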
......@@ -17,7 +17,6 @@ setuptools.setup(
'numpy>=1.14.0',
'scipy>=1.1.0',
'networkx>=2.1',
'python-igraph>=0.7.0',
],
data_files=[('', ['VERSION'])],
url='https://github.com/jermainewang/dgl-1')
url='https://github.com/jermainewang/dgl')
// Graph class implementation
#include <algorithm>
#include <unordered_map>
#include <dgl/graph.h>
namespace dgl {
namespace {
inline bool IsValidIdArray(const IdArray& arr) {
return arr->ctx.device_type == kDLCPU && arr->ndim == 1
&& arr->dtype.code == kDLInt && arr->dtype.bits == 64;
}
} // namespace
void Graph::AddVertices(uint64_t num_vertices) {
CHECK(!read_only_) << "Graph is read-only. Mutations are not allowed.";
adjlist_.resize(adjlist_.size() + num_vertices);
reverse_adjlist_.resize(reverse_adjlist_.size() + num_vertices);
}
void Graph::AddEdge(dgl_id_t src, dgl_id_t dst) {
CHECK(!read_only_) << "Graph is read-only. Mutations are not allowed.";
CHECK(HasVertex(src) && HasVertex(dst))
<< "Invalid vertices: src=" << src << " dst=" << dst;
dgl_id_t eid = num_edges_++;
adjlist_[src].succ.push_back(dst);
adjlist_[src].edge_id.push_back(eid);
reverse_adjlist_[dst].succ.push_back(src);
reverse_adjlist_[dst].edge_id.push_back(eid);
all_edges_src_.push_back(src);
all_edges_dst_.push_back(dst);
}
void Graph::AddEdges(IdArray src_ids, IdArray dst_ids) {
CHECK(!read_only_) << "Graph is read-only. Mutations are not allowed.";
CHECK(IsValidIdArray(src_ids)) << "Invalid src id array.";
CHECK(IsValidIdArray(dst_ids)) << "Invalid dst id array.";
const auto srclen = src_ids->shape[0];
const auto dstlen = dst_ids->shape[0];
const int64_t* src_data = static_cast<int64_t*>(src_ids->data);
const int64_t* dst_data = static_cast<int64_t*>(dst_ids->data);
if (srclen == 1) {
// one-many
for (int64_t i = 0; i < dstlen; ++i) {
AddEdge(src_data[0], dst_data[i]);
}
} else if (dstlen == 1) {
// many-one
for (int64_t i = 0; i < srclen; ++i) {
AddEdge(src_data[i], dst_data[0]);
}
} else {
// many-many
CHECK(srclen == dstlen) << "Invalid src and dst id array.";
for (int64_t i = 0; i < srclen; ++i) {
AddEdge(src_data[i], dst_data[i]);
}
}
}
BoolArray Graph::HasVertices(IdArray vids) const {
CHECK(IsValidIdArray(vids)) << "Invalid vertex id array.";
const auto len = vids->shape[0];
BoolArray rst = BoolArray::Empty({len}, vids->dtype, vids->ctx);
const int64_t* vid_data = static_cast<int64_t*>(vids->data);
int64_t* rst_data = static_cast<int64_t*>(rst->data);
const int64_t nverts = NumVertices();
for (int64_t i = 0; i < len; ++i) {
rst_data[i] = (vid_data[i] < nverts)? 1 : 0;
}
return rst;
}
// O(E)
bool Graph::HasEdge(dgl_id_t src, dgl_id_t dst) const {
if (!HasVertex(src) || !HasVertex(dst)) return false;
const auto& succ = adjlist_[src].succ;
return std::find(succ.begin(), succ.end(), dst) != succ.end();
}
// O(E*K) pretty slow
BoolArray Graph::HasEdges(IdArray src_ids, IdArray dst_ids) const {
CHECK(IsValidIdArray(src_ids)) << "Invalid src id array.";
CHECK(IsValidIdArray(dst_ids)) << "Invalid dst id array.";
const auto srclen = src_ids->shape[0];
const auto dstlen = dst_ids->shape[0];
const auto rstlen = std::max(srclen, dstlen);
BoolArray rst = BoolArray::Empty({rstlen}, src_ids->dtype, src_ids->ctx);
int64_t* rst_data = static_cast<int64_t*>(rst->data);
const int64_t* src_data = static_cast<int64_t*>(src_ids->data);
const int64_t* dst_data = static_cast<int64_t*>(dst_ids->data);
if (srclen == 1) {
// one-many
for (int64_t i = 0; i < dstlen; ++i) {
rst_data[i] = HasEdge(src_data[0], dst_data[i])? 1 : 0;
}
} else if (dstlen == 1) {
// many-one
for (int64_t i = 0; i < srclen; ++i) {
rst_data[i] = HasEdge(src_data[i], dst_data[0])? 1 : 0;
}
} else {
// many-many
CHECK(srclen == dstlen) << "Invalid src and dst id array.";
for (int64_t i = 0; i < srclen; ++i) {
rst_data[i] = HasEdge(src_data[i], dst_data[i])? 1 : 0;
}
}
return rst;
}
// The data is copied out; TODO: support zero-copy.
IdArray Graph::Predecessors(dgl_id_t vid, uint64_t radius) const {
CHECK(HasVertex(vid)) << "invalid vertex: " << vid;
CHECK(radius >= 1) << "invalid radius: " << radius;
const auto& pred = reverse_adjlist_[vid].succ;
const int64_t len = pred.size();
IdArray rst = IdArray::Empty({len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
int64_t* rst_data = static_cast<int64_t*>(rst->data);
for (int64_t i = 0; i < len; ++i) {
rst_data[i] = pred[i];
}
return rst;
}
// The data is copied out; TODO: support zero-copy.
IdArray Graph::Successors(dgl_id_t vid, uint64_t radius) const {
CHECK(HasVertex(vid)) << "invalid vertex: " << vid;
CHECK(radius >= 1) << "invalid radius: " << radius;
const auto& succ = adjlist_[vid].succ;
const int64_t len = succ.size();
IdArray rst = IdArray::Empty({len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
int64_t* rst_data = static_cast<int64_t*>(rst->data);
for (int64_t i = 0; i < len; ++i) {
rst_data[i] = succ[i];
}
return rst;
}
// O(E)
dgl_id_t Graph::EdgeId(dgl_id_t src, dgl_id_t dst) const {
CHECK(HasVertex(src)) << "invalid edge: " << src << " -> " << dst;
const auto& succ = adjlist_[src].succ;
for (size_t i = 0; i < succ.size(); ++i) {
if (succ[i] == dst) {
return adjlist_[src].edge_id[i];
}
}
LOG(FATAL) << "invalid edge: " << src << " -> " << dst;
return 0;
}
// O(E*k) pretty slow
IdArray Graph::EdgeIds(IdArray src_ids, IdArray dst_ids) const {
CHECK(IsValidIdArray(src_ids)) << "Invalid src id array.";
CHECK(IsValidIdArray(dst_ids)) << "Invalid dst id array.";
const auto srclen = src_ids->shape[0];
const auto dstlen = dst_ids->shape[0];
const auto rstlen = std::max(srclen, dstlen);
IdArray rst = IdArray::Empty({rstlen}, src_ids->dtype, src_ids->ctx);
int64_t* rst_data = static_cast<int64_t*>(rst->data);
const int64_t* src_data = static_cast<int64_t*>(src_ids->data);
const int64_t* dst_data = static_cast<int64_t*>(dst_ids->data);
if (srclen == 1) {
// one-many
for (int64_t i = 0; i < dstlen; ++i) {
rst_data[i] = EdgeId(src_data[0], dst_data[i]);
}
} else if (dstlen == 1) {
// many-one
for (int64_t i = 0; i < srclen; ++i) {
rst_data[i] = EdgeId(src_data[i], dst_data[0]);
}
} else {
// many-many
CHECK(srclen == dstlen) << "Invalid src and dst id array.";
for (int64_t i = 0; i < srclen; ++i) {
rst_data[i] = EdgeId(src_data[i], dst_data[i]);
}
}
return rst;
}
// O(E)
Graph::EdgeArray Graph::InEdges(dgl_id_t vid) const {
CHECK(HasVertex(vid)) << "invalid vertex: " << vid;
const int64_t len = reverse_adjlist_[vid].succ.size();
IdArray src = IdArray::Empty({len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
IdArray dst = IdArray::Empty({len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
IdArray eid = IdArray::Empty({len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
int64_t* src_data = static_cast<int64_t*>(src->data);
int64_t* dst_data = static_cast<int64_t*>(dst->data);
int64_t* eid_data = static_cast<int64_t*>(eid->data);
for (int64_t i = 0; i < len; ++i) {
src_data[i] = reverse_adjlist_[vid].succ[i];
eid_data[i] = reverse_adjlist_[vid].edge_id[i];
}
std::fill(dst_data, dst_data + len, vid);
return EdgeArray{src, dst, eid};
}
// O(E)
Graph::EdgeArray Graph::InEdges(IdArray vids) const {
CHECK(IsValidIdArray(vids)) << "Invalid vertex id array.";
const auto len = vids->shape[0];
const int64_t* vid_data = static_cast<int64_t*>(vids->data);
int64_t rstlen = 0;
for (int64_t i = 0; i < len; ++i) {
CHECK(HasVertex(vid_data[i])) << "Invalid vertex: " << vid_data[i];
rstlen += reverse_adjlist_[vid_data[i]].succ.size();
}
IdArray src = IdArray::Empty({rstlen}, vids->dtype, vids->ctx);
IdArray dst = IdArray::Empty({rstlen}, vids->dtype, vids->ctx);
IdArray eid = IdArray::Empty({rstlen}, vids->dtype, vids->ctx);
int64_t* src_ptr = static_cast<int64_t*>(src->data);
int64_t* dst_ptr = static_cast<int64_t*>(dst->data);
int64_t* eid_ptr = static_cast<int64_t*>(eid->data);
for (int64_t i = 0; i < len; ++i) {
const auto& pred = reverse_adjlist_[vid_data[i]].succ;
const auto& eids = reverse_adjlist_[vid_data[i]].edge_id;
for (size_t j = 0; j < pred.size(); ++j) {
*(src_ptr++) = pred[j];
*(dst_ptr++) = vid_data[i];
*(eid_ptr++) = eids[j];
}
}
return EdgeArray{src, dst, eid};
}
// O(E)
Graph::EdgeArray Graph::OutEdges(dgl_id_t vid) const {
CHECK(HasVertex(vid)) << "invalid vertex: " << vid;
const int64_t len = adjlist_[vid].succ.size();
IdArray src = IdArray::Empty({len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
IdArray dst = IdArray::Empty({len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
IdArray eid = IdArray::Empty({len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
int64_t* src_data = static_cast<int64_t*>(src->data);
int64_t* dst_data = static_cast<int64_t*>(dst->data);
int64_t* eid_data = static_cast<int64_t*>(eid->data);
for (int64_t i = 0; i < len; ++i) {
dst_data[i] = adjlist_[vid].succ[i];
eid_data[i] = adjlist_[vid].edge_id[i];
}
std::fill(src_data, src_data + len, vid);
return EdgeArray{src, dst, eid};
}
// O(E)
Graph::EdgeArray Graph::OutEdges(IdArray vids) const {
CHECK(IsValidIdArray(vids)) << "Invalid vertex id array.";
const auto len = vids->shape[0];
const int64_t* vid_data = static_cast<int64_t*>(vids->data);
int64_t rstlen = 0;
for (int64_t i = 0; i < len; ++i) {
CHECK(HasVertex(vid_data[i])) << "Invalid vertex: " << vid_data[i];
rstlen += adjlist_[vid_data[i]].succ.size();
}
IdArray src = IdArray::Empty({rstlen}, vids->dtype, vids->ctx);
IdArray dst = IdArray::Empty({rstlen}, vids->dtype, vids->ctx);
IdArray eid = IdArray::Empty({rstlen}, vids->dtype, vids->ctx);
int64_t* src_ptr = static_cast<int64_t*>(src->data);
int64_t* dst_ptr = static_cast<int64_t*>(dst->data);
int64_t* eid_ptr = static_cast<int64_t*>(eid->data);
for (int64_t i = 0; i < len; ++i) {
const auto& succ = adjlist_[vid_data[i]].succ;
const auto& eids = adjlist_[vid_data[i]].edge_id;
for (size_t j = 0; j < succ.size(); ++j) {
*(src_ptr++) = vid_data[i];
*(dst_ptr++) = succ[j];
*(eid_ptr++) = eids[j];
}
}
return EdgeArray{src, dst, eid};
}
// O(E*log(E)) if sort is required; otherwise, O(E)
Graph::EdgeArray Graph::Edges(bool sorted) const {
const int64_t len = num_edges_;
IdArray src = IdArray::Empty({len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
IdArray dst = IdArray::Empty({len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
IdArray eid = IdArray::Empty({len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
if (sorted) {
typedef std::tuple<int64_t, int64_t, int64_t> Tuple;
std::vector<Tuple> tuples;
tuples.reserve(len);
for (uint64_t e = 0; e < num_edges_; ++e) {
tuples.emplace_back(all_edges_src_[e], all_edges_dst_[e], e);
}
// sort according to src and dst ids
std::sort(tuples.begin(), tuples.end(),
[] (const Tuple& t1, const Tuple& t2) {
return std::get<0>(t1) < std::get<0>(t2)
|| (std::get<0>(t1) == std::get<0>(t2) && std::get<1>(t1) < std::get<1>(t2));
});
// make return arrays
int64_t* src_ptr = static_cast<int64_t*>(src->data);
int64_t* dst_ptr = static_cast<int64_t*>(dst->data);
int64_t* eid_ptr = static_cast<int64_t*>(eid->data);
for (size_t i = 0; i < tuples.size(); ++i) {
src_ptr[i] = std::get<0>(tuples[i]);
dst_ptr[i] = std::get<1>(tuples[i]);
eid_ptr[i] = std::get<2>(tuples[i]);
}
} else {
int64_t* src_ptr = static_cast<int64_t*>(src->data);
int64_t* dst_ptr = static_cast<int64_t*>(dst->data);
int64_t* eid_ptr = static_cast<int64_t*>(eid->data);
std::copy(all_edges_src_.begin(), all_edges_src_.end(), src_ptr);
std::copy(all_edges_dst_.begin(), all_edges_dst_.end(), dst_ptr);
for (uint64_t e = 0; e < num_edges_; ++e) {
eid_ptr[e] = e;
}
}
return EdgeArray{src, dst, eid};
}
// O(V)
DegreeArray Graph::InDegrees(IdArray vids) const {
CHECK(IsValidIdArray(vids)) << "Invalid vertex id array.";
const auto len = vids->shape[0];
const int64_t* vid_data = static_cast<int64_t*>(vids->data);
DegreeArray rst = DegreeArray::Empty({len}, vids->dtype, vids->ctx);
int64_t* rst_data = static_cast<int64_t*>(rst->data);
for (int64_t i = 0; i < len; ++i) {
const auto vid = vid_data[i];
CHECK(HasVertex(vid)) << "Invalid vertex: " << vid;
rst_data[i] = reverse_adjlist_[vid].succ.size();
}
return rst;
}
// O(V)
DegreeArray Graph::OutDegrees(IdArray vids) const {
CHECK(IsValidIdArray(vids)) << "Invalid vertex id array.";
const auto len = vids->shape[0];
const int64_t* vid_data = static_cast<int64_t*>(vids->data);
DegreeArray rst = DegreeArray::Empty({len}, vids->dtype, vids->ctx);
int64_t* rst_data = static_cast<int64_t*>(rst->data);
for (int64_t i = 0; i < len; ++i) {
const auto vid = vid_data[i];
CHECK(HasVertex(vid)) << "Invalid vertex: " << vid;
rst_data[i] = adjlist_[vid].succ.size();
}
return rst;
}
Subgraph Graph::VertexSubgraph(IdArray vids) const {
CHECK(IsValidIdArray(vids)) << "Invalid vertex id array.";
const auto len = vids->shape[0];
std::unordered_map<dgl_id_t, dgl_id_t> oldv2newv;
std::vector<dgl_id_t> edges;
const int64_t* vid_data = static_cast<int64_t*>(vids->data);
for (int64_t i = 0; i < len; ++i) {
oldv2newv[vid_data[i]] = i;
}
Subgraph rst;
rst.induced_vertices = vids;
rst.graph.AddVertices(len);
for (int64_t i = 0; i < len; ++i) {
const dgl_id_t oldvid = vid_data[i];
const dgl_id_t newvid = i;
for (size_t j = 0; j < adjlist_[oldvid].succ.size(); ++j) {
const dgl_id_t oldsucc = adjlist_[oldvid].succ[j];
if (oldv2newv.count(oldsucc)) {
const dgl_id_t newsucc = oldv2newv[oldsucc];
edges.push_back(adjlist_[oldvid].edge_id[j]);
rst.graph.AddEdge(newvid, newsucc);
}
}
}
rst.induced_edges = IdArray::Empty({static_cast<int64_t>(edges.size())}, vids->dtype, vids->ctx);
std::copy(edges.begin(), edges.end(), static_cast<int64_t*>(rst.induced_edges->data));
return rst;
}
Subgraph Graph::EdgeSubgraph(IdArray src, IdArray dst) const {
LOG(FATAL) << "not implemented";
return Subgraph();
}
Graph Graph::Reverse() const {
LOG(FATAL) << "not implemented";
return *this;
}
} // namespace dgl
#include <dgl/runtime/packed_func.h>
#include <dgl/runtime/registry.h>
#include <dgl/graph.h>
#include <dgl/graph_op.h>
using tvm::runtime::TVMArgs;
using tvm::runtime::TVMArgValue;
using tvm::runtime::TVMRetValue;
using tvm::runtime::PackedFunc;
using tvm::runtime::NDArray;
namespace dgl {
// Graph handler type
typedef void* GraphHandle;
namespace {
// Convert EdgeArray structure to PackedFunc.
PackedFunc ConvertEdgeArrayToPackedFunc(const Graph::EdgeArray& ea) {
auto body = [ea] (TVMArgs args, TVMRetValue* rv) {
int which = args[0];
if (which == 0) {
*rv = std::move(ea.src);
} else if (which == 1) {
*rv = std::move(ea.dst);
} else if (which == 2) {
*rv = std::move(ea.id);
} else {
LOG(FATAL) << "invalid choice";
}
};
return PackedFunc(body);
}
// Convert Subgraph structure to PackedFunc.
PackedFunc ConvertSubgraphToPackedFunc(const Subgraph& sg) {
auto body = [sg] (TVMArgs args, TVMRetValue* rv) {
int which = args[0];
if (which == 0) {
Graph* gptr = new Graph();
*gptr = std::move(sg.graph);
GraphHandle ghandle = gptr;
*rv = ghandle;
} else if (which == 1) {
*rv = std::move(sg.induced_vertices);
} else if (which == 2) {
*rv = std::move(sg.induced_edges);
} else {
LOG(FATAL) << "invalid choice";
}
};
return PackedFunc(body);
}
// Convert the given DLTensor to a temporary DLManagedTensor that does not own memory.
DLManagedTensor* CreateTmpDLManagedTensor(const TVMArgValue& arg) {
const DLTensor* dl_tensor = arg;
DLManagedTensor* ret = new DLManagedTensor();
ret->deleter = [] (DLManagedTensor* self) { delete self; };
ret->manager_ctx = nullptr;
ret->dl_tensor = *dl_tensor;
return ret;
}
} // namespace
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphCreate")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
GraphHandle ghandle = new Graph();
*rv = ghandle;
});
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphFree")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
GraphHandle ghandle = args[0];
Graph* gptr = static_cast<Graph*>(ghandle);
delete gptr;
});
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphAddVertices")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
GraphHandle ghandle = args[0];
Graph* gptr = static_cast<Graph*>(ghandle);
uint64_t num_vertices = args[1];
gptr->AddVertices(num_vertices);
});
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphAddEdge")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
GraphHandle ghandle = args[0];
Graph* gptr = static_cast<Graph*>(ghandle);
const dgl_id_t src = args[1];
const dgl_id_t dst = args[2];
gptr->AddEdge(src, dst);
});
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphAddEdges")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
GraphHandle ghandle = args[0];
Graph* gptr = static_cast<Graph*>(ghandle);
const IdArray src = IdArray::FromDLPack(CreateTmpDLManagedTensor(args[1]));
const IdArray dst = IdArray::FromDLPack(CreateTmpDLManagedTensor(args[2]));
gptr->AddEdges(src, dst);
});
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphClear")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
GraphHandle ghandle = args[0];
Graph* gptr = static_cast<Graph*>(ghandle);
gptr->Clear();
});
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphNumVertices")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle);
*rv = static_cast<int64_t>(gptr->NumVertices());
});
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphNumEdges")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle);
*rv = static_cast<int64_t>(gptr->NumEdges());
});
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphHasVertex")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle);
const dgl_id_t vid = args[1];
*rv = gptr->HasVertex(vid);
});
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphHasVertices")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle);
const IdArray vids = IdArray::FromDLPack(CreateTmpDLManagedTensor(args[1]));
*rv = gptr->HasVertices(vids);
});
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphHasEdge")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle);
const dgl_id_t src = args[1];
const dgl_id_t dst = args[2];
*rv = gptr->HasEdge(src, dst);
});
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphHasEdges")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle);
const IdArray src = IdArray::FromDLPack(CreateTmpDLManagedTensor(args[1]));
const IdArray dst = IdArray::FromDLPack(CreateTmpDLManagedTensor(args[2]));
*rv = gptr->HasEdges(src, dst);
});
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphPredecessors")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle);
const dgl_id_t vid = args[1];
const uint64_t radius = args[2];
*rv = gptr->Predecessors(vid, radius);
});
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphSuccessors")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle);
const dgl_id_t vid = args[1];
const uint64_t radius = args[2];
*rv = gptr->Successors(vid, radius);
});
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphEdgeId")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle);
const dgl_id_t src = args[1];
const dgl_id_t dst = args[2];
*rv = static_cast<int64_t>(gptr->EdgeId(src, dst));
});
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphEdgeIds")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle);
const IdArray src = IdArray::FromDLPack(CreateTmpDLManagedTensor(args[1]));
const IdArray dst = IdArray::FromDLPack(CreateTmpDLManagedTensor(args[2]));
*rv = gptr->EdgeIds(src, dst);
});
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphInEdges_1")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle);
const dgl_id_t vid = args[1];
*rv = ConvertEdgeArrayToPackedFunc(gptr->InEdges(vid));
});
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphInEdges_2")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle);
const IdArray vids = IdArray::FromDLPack(CreateTmpDLManagedTensor(args[1]));
*rv = ConvertEdgeArrayToPackedFunc(gptr->InEdges(vids));
});
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphOutEdges_1")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle);
const dgl_id_t vid = args[1];
*rv = ConvertEdgeArrayToPackedFunc(gptr->OutEdges(vid));
});
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphOutEdges_2")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle);
const IdArray vids = IdArray::FromDLPack(CreateTmpDLManagedTensor(args[1]));
*rv = ConvertEdgeArrayToPackedFunc(gptr->OutEdges(vids));
});
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphEdges")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle);
const bool sorted = args[1];
*rv = ConvertEdgeArrayToPackedFunc(gptr->Edges(sorted));
});
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphInDegree")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle);
const dgl_id_t vid = args[1];
*rv = static_cast<int64_t>(gptr->InDegree(vid));
});
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphInDegrees")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle);
const IdArray vids = IdArray::FromDLPack(CreateTmpDLManagedTensor(args[1]));
*rv = gptr->InDegrees(vids);
});
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphOutDegree")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle);
const dgl_id_t vid = args[1];
*rv = static_cast<int64_t>(gptr->OutDegree(vid));
});
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphOutDegrees")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle);
const IdArray vids = IdArray::FromDLPack(CreateTmpDLManagedTensor(args[1]));
*rv = gptr->OutDegrees(vids);
});
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphVertexSubgraph")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle);
const IdArray vids = IdArray::FromDLPack(CreateTmpDLManagedTensor(args[1]));
*rv = ConvertSubgraphToPackedFunc(gptr->VertexSubgraph(vids));
});
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLDisjointUnion")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
void* list = args[0];
GraphHandle* inhandles = static_cast<GraphHandle*>(list);
int list_size = args[1];
std::vector<const Graph*> graphs;
for (int i = 0; i < list_size; ++i) {
const Graph* gr = static_cast<const Graph*>(inhandles[i]);
graphs.push_back(gr);
}
Graph* gptr = new Graph();
*gptr = GraphOp::DisjointUnion(std::move(graphs));
GraphHandle ghandle = gptr;
*rv = ghandle;
});
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLDisjointPartitionByNum")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle);
int64_t num = args[1];
std::vector<Graph>&& rst = GraphOp::DisjointPartitionByNum(gptr, num);
// return the pointer array as an integer array
const int64_t len = rst.size();
NDArray ptr_array = NDArray::Empty({len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
int64_t* ptr_array_data = static_cast<int64_t*>(ptr_array->data);
for (size_t i = 0; i < rst.size(); ++i) {
Graph* ptr = new Graph();
*ptr = std::move(rst[i]);
ptr_array_data[i] = reinterpret_cast<std::intptr_t>(ptr);
}
*rv = ptr_array;
});
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLDisjointPartitionBySizes")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle);
const IdArray sizes = IdArray::FromDLPack(CreateTmpDLManagedTensor(args[1]));
std::vector<Graph>&& rst = GraphOp::DisjointPartitionBySizes(gptr, sizes);
// return the pointer array as an integer array
const int64_t len = rst.size();
NDArray ptr_array = NDArray::Empty({len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
int64_t* ptr_array_data = static_cast<int64_t*>(ptr_array->data);
for (size_t i = 0; i < rst.size(); ++i) {
Graph* ptr = new Graph();
*ptr = std::move(rst[i]);
ptr_array_data[i] = reinterpret_cast<std::intptr_t>(ptr);
}
*rv = ptr_array;
});
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphLineGraph")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
GraphHandle ghandle = args[0];
bool backtracking = args[1];
const Graph* gptr = static_cast<Graph*>(ghandle);
Graph* lgptr = new Graph();
*lgptr = GraphOp::LineGraph(gptr, backtracking);
GraphHandle lghandle = lgptr;
*rv = lghandle;
});
} // namespace dgl
// Graph operation implementation
#include <dgl/graph_op.h>
#include <algorithm>
namespace dgl {
Graph GraphOp::LineGraph(const Graph* g, bool backtracking) {
typedef std::pair<dgl_id_t, dgl_id_t> entry;
typedef std::map<dgl_id_t, std::vector<entry>> csm; // Compressed Sparse Matrix
csm adj;
std::vector<entry> vec;
for (size_t i = 0; i != g->all_edges_src_.size(); ++i) {
auto u = g->all_edges_src_[i];
auto v = g->all_edges_dst_[i];
auto ret = adj.insert(csm::value_type(u, vec));
(ret.first)->second.push_back(std::make_pair(v, i));
}
std::vector<dgl_id_t> lg_src, lg_dst;
for (size_t i = 0; i != g->all_edges_src_.size(); ++i) {
auto u = g->all_edges_src_[i];
auto v = g->all_edges_dst_[i];
auto j = adj.find(v);
if (j != adj.end()) {
for (size_t k = 0; k != j->second.size(); ++k) {
if (backtracking || j->second[k].first != u) {
lg_src.push_back(i);
lg_dst.push_back(j->second[k].second);
}
}
}
}
const int64_t len = lg_src.size();
IdArray src = IdArray::Empty({len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
IdArray dst = IdArray::Empty({len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
int64_t* src_ptr = static_cast<int64_t*>(src->data);
int64_t* dst_ptr = static_cast<int64_t*>(dst->data);
std::copy(lg_src.begin(), lg_src.end(), src_ptr);
std::copy(lg_dst.begin(), lg_dst.end(), dst_ptr);
Graph lg;
lg.AddVertices(g->NumEdges());
lg.AddEdges(src, dst);
return lg;
}
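In words: edge i connects to edge j in the line graph whenever i's destination equals j's source, and with backtracking disabled the immediate reversal u->v followed by v->u is dropped. A plain-Python sketch of the same rule:

def line_graph_sketch(src, dst, backtracking=True):
    # out[v] holds (successor, edge_id) pairs, like the CSM in the C++ code
    out = {}
    for eid, (u, v) in enumerate(zip(src, dst)):
        out.setdefault(u, []).append((v, eid))
    lg_src, lg_dst = [], []
    for i, (u, v) in enumerate(zip(src, dst)):
        for w, j in out.get(v, []):
            if backtracking or w != u:     # skip u->v followed by v->u
                lg_src.append(i)
                lg_dst.append(j)
    return lg_src, lg_dst

# triangle 0->1->2->0: every edge follows exactly one other edge
print(line_graph_sketch([0, 1, 2], [1, 2, 0]))   # ([0, 1, 2], [1, 2, 0])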
Graph GraphOp::DisjointUnion(std::vector<const Graph*> graphs) {
Graph rst;
uint64_t cumsum = 0;
for (const Graph* gr : graphs) {
rst.AddVertices(gr->NumVertices());
for (uint64_t i = 0; i < gr->NumEdges(); ++i) {
rst.AddEdge(gr->all_edges_src_[i] + cumsum, gr->all_edges_dst_[i] + cumsum);
}
cumsum += gr->NumVertices();
}
return rst;
}
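Disjoint union just concatenates the graphs, shifting each graph's node ids by the running vertex count so the id spaces never collide. A plain-Python sketch:

def disjoint_union_sketch(graphs):
    # graphs: list of (num_nodes, edge_list); returns the merged pair
    total, edges = 0, []
    for n, es in graphs:
        edges.extend((u + total, v + total) for u, v in es)
        total += n
    return total, edges

g1 = (2, [(0, 1)])
g2 = (3, [(0, 2), (1, 2)])
print(disjoint_union_sketch([g1, g2]))   # (5, [(0, 1), (2, 4), (3, 4)])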
std::vector<Graph> GraphOp::DisjointPartitionByNum(const Graph* graph, int64_t num) {
CHECK(num != 0 && graph->NumVertices() % num == 0)
<< "Number of partitions must evenly divide the number of nodes.";
IdArray sizes = IdArray::Empty({num}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
int64_t* sizes_data = static_cast<int64_t*>(sizes->data);
std::fill(sizes_data, sizes_data + num, graph->NumVertices() / num);
return DisjointPartitionBySizes(graph, sizes);
}
std::vector<Graph> GraphOp::DisjointPartitionBySizes(const Graph* graph, IdArray sizes) {
const int64_t len = sizes->shape[0];
const int64_t* sizes_data = static_cast<int64_t*>(sizes->data);
std::vector<int64_t> cumsum;
cumsum.push_back(0);
for (int64_t i = 0; i < len; ++i) {
cumsum.push_back(cumsum[i] + sizes_data[i]);
}
CHECK_EQ(cumsum[len], graph->NumVertices())
<< "Sum of the given sizes must equal to the number of nodes.";
dgl_id_t node_offset = 0, edge_offset = 0;
std::vector<Graph> rst(len);
for (int64_t i = 0; i < len; ++i) {
// copy adj
rst[i].adjlist_.insert(rst[i].adjlist_.end(),
graph->adjlist_.begin() + node_offset,
graph->adjlist_.begin() + node_offset + sizes_data[i]);
rst[i].reverse_adjlist_.insert(rst[i].reverse_adjlist_.end(),
graph->reverse_adjlist_.begin() + node_offset,
graph->reverse_adjlist_.begin() + node_offset + sizes_data[i]);
// relabel adjs
size_t num_edges = 0;
for (auto& elist : rst[i].adjlist_) {
for (size_t j = 0; j < elist.succ.size(); ++j) {
elist.succ[j] -= node_offset;
elist.edge_id[j] -= edge_offset;
}
num_edges += elist.succ.size();
}
for (auto& elist : rst[i].reverse_adjlist_) {
for (size_t j = 0; j < elist.succ.size(); ++j) {
elist.succ[j] -= node_offset;
elist.edge_id[j] -= edge_offset;
}
}
// copy edges
rst[i].all_edges_src_.reserve(num_edges);
rst[i].all_edges_dst_.reserve(num_edges);
rst[i].num_edges_ = num_edges;
for (size_t j = edge_offset; j < edge_offset + num_edges; ++j) {
rst[i].all_edges_src_.push_back(graph->all_edges_src_[j] - node_offset);
rst[i].all_edges_dst_.push_back(graph->all_edges_dst_[j] - node_offset);
}
// update offset
CHECK_EQ(rst[i].NumVertices(), sizes_data[i]);
CHECK_EQ(rst[i].NumEdges(), num_edges);
node_offset += sizes_data[i];
edge_offset += num_edges;
}
/*for (int64_t i = 0; i < len; ++i) {
rst[i].AddVertices(sizes_data[i]);
}
for (dgl_id_t eid = 0; eid < graph->num_edges_; ++eid) {
const dgl_id_t src = graph->all_edges_src_[eid];
const dgl_id_t dst = graph->all_edges_dst_[eid];
size_t src_select = 0, dst_select = 0;
for (size_t i = 1; i < cumsum.size(); ++i) { // TODO: replace with binary search
if (cumsum[i] > src) {
src_select = i;
break;
}
}
for (size_t i = 1; i < cumsum.size(); ++i) { // TODO: replace with binary search
if (cumsum[i] > dst) {
dst_select = i;
break;
}
}
if (src_select != dst_select) {
// the edge is ignored if across two partitions
continue;
}
const int64_t offset = cumsum[src_select - 1];
rst[src_select - 1].AddEdge(src - offset, dst - offset);
}*/
return rst;
}
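Partitioning inverts the union: ids inside partition i are shifted down by the cumulative size offset, and (as the commented-out fallback makes explicit) an edge survives only when both endpoints land in the same partition. A numpy sketch of that rule; note the live code above instead slices the adjacency lists directly, assuming edges never cross partitions:

import numpy as np

def disjoint_partition_by_sizes_sketch(edges, sizes):
    # Split an edge list into per-partition edge lists with local ids.
    cumsum = np.concatenate([[0], np.cumsum(sizes)])
    parts = [[] for _ in sizes]
    for u, v in edges:
        pu = np.searchsorted(cumsum, u, side='right') - 1
        pv = np.searchsorted(cumsum, v, side='right') - 1
        if pu == pv:                       # drop edges that cross partitions
            parts[pu].append((u - cumsum[pu], v - cumsum[pv]))
    return parts

print(disjoint_partition_by_sizes_sketch([(0, 1), (1, 2), (3, 4)], [2, 3]))
# [[(0, 1)], [(1, 2)]]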
} // namespace dgl
# C API and runtime
Borrowed and adapted from the TVM project.
/*!
* Copyright (c) 2016 by Contributors
* \file c_runtime_api.cc
* \brief Runtime API implementation
*/
#include <dmlc/thread_local.h>
#include <dgl/runtime/c_runtime_api.h>
#include <dgl/runtime/c_backend_api.h>
#include <dgl/runtime/packed_func.h>
#include <dgl/runtime/module.h>
#include <dgl/runtime/registry.h>
#include <dgl/runtime/device_api.h>
#include <array>
#include <algorithm>
#include <string>
#include <cstdlib>
#include "runtime_base.h"
namespace tvm {
namespace runtime {
/*!
* \brief The name of Device API factory.
* \param type The device type.
*/
inline std::string DeviceName(int type) {
switch (type) {
case kDLCPU: return "cpu";
case kDLGPU: return "gpu";
case kDLOpenCL: return "opencl";
case kDLSDAccel: return "sdaccel";
case kDLAOCL: return "aocl";
case kDLVulkan: return "vulkan";
case kDLMetal: return "metal";
case kDLVPI: return "vpi";
case kDLROCM: return "rocm";
case kOpenGL: return "opengl";
case kExtDev: return "ext_dev";
default: LOG(FATAL) << "unknown type =" << type; return "Unknown";
}
}
class DeviceAPIManager {
public:
static const int kMaxDeviceAPI = 32;
// Get API
static DeviceAPI* Get(const TVMContext& ctx) {
return Get(ctx.device_type);
}
static DeviceAPI* Get(int dev_type, bool allow_missing = false) {
return Global()->GetAPI(dev_type, allow_missing);
}
private:
std::array<DeviceAPI*, kMaxDeviceAPI> api_;
DeviceAPI* rpc_api_{nullptr};
std::mutex mutex_;
// constructor
DeviceAPIManager() {
std::fill(api_.begin(), api_.end(), nullptr);
}
// Global static variable.
static DeviceAPIManager* Global() {
static DeviceAPIManager inst;
return &inst;
}
// Get or initialize API.
DeviceAPI* GetAPI(int type, bool allow_missing) {
if (type < kRPCSessMask) {
if (api_[type] != nullptr) return api_[type];
std::lock_guard<std::mutex> lock(mutex_);
if (api_[type] != nullptr) return api_[type];
api_[type] = GetAPI(DeviceName(type), allow_missing);
return api_[type];
} else {
if (rpc_api_ != nullptr) return rpc_api_;
std::lock_guard<std::mutex> lock(mutex_);
if (rpc_api_ != nullptr) return rpc_api_;
rpc_api_ = GetAPI("rpc", allow_missing);
return rpc_api_;
}
}
DeviceAPI* GetAPI(const std::string name, bool allow_missing) {
std::string factory = "device_api." + name;
auto* f = Registry::Get(factory);
if (f == nullptr) {
CHECK(allow_missing)
<< "Device API " << name << " is not enabled.";
return nullptr;
}
void* ptr = (*f)();
return static_cast<DeviceAPI*>(ptr);
}
};
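GetAPI uses classic double-checked locking: a lock-free read on the fast path, then a re-check under the mutex before the slow registry lookup, so each device API is resolved exactly once. The same pattern in Python (illustrative only; real Python code would usually lean on the GIL or a module-level dict):

import threading

class ApiRegistry:
    _apis = {}
    _lock = threading.Lock()

    @classmethod
    def get(cls, dev_type):
        api = cls._apis.get(dev_type)          # fast path, no lock
        if api is not None:
            return api
        with cls._lock:
            if dev_type not in cls._apis:      # re-check under the lock
                cls._apis[dev_type] = object() # stand-in for the factory call
            return cls._apis[dev_type]

print(ApiRegistry.get('cpu') is ApiRegistry.get('cpu'))   # True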
DeviceAPI* DeviceAPI::Get(TVMContext ctx, bool allow_missing) {
return DeviceAPIManager::Get(
static_cast<int>(ctx.device_type), allow_missing);
}
void* DeviceAPI::AllocWorkspace(TVMContext ctx,
size_t size,
TVMType type_hint) {
return AllocDataSpace(ctx, size, kTempAllocaAlignment, type_hint);
}
void DeviceAPI::FreeWorkspace(TVMContext ctx, void* ptr) {
FreeDataSpace(ctx, ptr);
}
TVMStreamHandle DeviceAPI::CreateStream(TVMContext ctx) {
LOG(FATAL) << "Device does not support stream api.";
return 0;
}
void DeviceAPI::FreeStream(TVMContext ctx, TVMStreamHandle stream) {
LOG(FATAL) << "Device does not support stream api.";
}
void DeviceAPI::SyncStreamFromTo(TVMContext ctx,
TVMStreamHandle event_src,
TVMStreamHandle event_dst) {
LOG(FATAL) << "Device does not support stream api.";
}
} // namespace runtime
} // namespace tvm
using namespace tvm::runtime;
struct TVMRuntimeEntry {
std::string ret_str;
std::string last_error;
TVMByteArray ret_bytes;
};
typedef dmlc::ThreadLocalStore<TVMRuntimeEntry> TVMAPIRuntimeStore;
const char *TVMGetLastError() {
return TVMAPIRuntimeStore::Get()->last_error.c_str();
}
void TVMAPISetLastError(const char* msg) {
#ifndef _LIBCPP_SGX_CONFIG
TVMAPIRuntimeStore::Get()->last_error = msg;
#else
sgx::OCallPackedFunc("__sgx_set_last_error__", msg);
#endif
}
int TVMModLoadFromFile(const char* file_name,
const char* format,
TVMModuleHandle* out) {
API_BEGIN();
Module m = Module::LoadFromFile(file_name, format);
*out = new Module(m);
API_END();
}
int TVMModImport(TVMModuleHandle mod,
TVMModuleHandle dep) {
API_BEGIN();
static_cast<Module*>(mod)->Import(
*static_cast<Module*>(dep));
API_END();
}
int TVMModGetFunction(TVMModuleHandle mod,
const char* func_name,
int query_imports,
TVMFunctionHandle *func) {
API_BEGIN();
PackedFunc pf = static_cast<Module*>(mod)->GetFunction(
func_name, query_imports != 0);
if (pf != nullptr) {
*func = new PackedFunc(pf);
} else {
*func = nullptr;
}
API_END();
}
int TVMModFree(TVMModuleHandle mod) {
API_BEGIN();
delete static_cast<Module*>(mod);
API_END();
}
int TVMBackendGetFuncFromEnv(void* mod_node,
const char* func_name,
TVMFunctionHandle *func) {
API_BEGIN();
*func = (TVMFunctionHandle)(
static_cast<ModuleNode*>(mod_node)->GetFuncFromEnv(func_name));
API_END();
}
void* TVMBackendAllocWorkspace(int device_type,
int device_id,
uint64_t size,
int dtype_code_hint,
int dtype_bits_hint) {
TVMContext ctx;
ctx.device_type = static_cast<DLDeviceType>(device_type);
ctx.device_id = device_id;
TVMType type_hint;
type_hint.code = static_cast<decltype(type_hint.code)>(dtype_code_hint);
type_hint.bits = static_cast<decltype(type_hint.bits)>(dtype_bits_hint);
type_hint.lanes = 1;
return DeviceAPIManager::Get(ctx)->AllocWorkspace(ctx,
static_cast<size_t>(size),
type_hint);
}
int TVMBackendFreeWorkspace(int device_type,
int device_id,
void* ptr) {
TVMContext ctx;
ctx.device_type = static_cast<DLDeviceType>(device_type);
ctx.device_id = device_id;
DeviceAPIManager::Get(ctx)->FreeWorkspace(ctx, ptr);
return 0;
}
int TVMBackendRunOnce(void** handle,
int (*f)(void*),
void* cdata,
int nbytes) {
if (*handle == nullptr) {
*handle = reinterpret_cast<void*>(1);
return (*f)(cdata);
}
return 0;
}
int TVMFuncFree(TVMFunctionHandle func) {
API_BEGIN();
delete static_cast<PackedFunc*>(func);
API_END();
}
int TVMFuncCall(TVMFunctionHandle func,
TVMValue* args,
int* arg_type_codes,
int num_args,
TVMValue* ret_val,
int* ret_type_code) {
API_BEGIN();
TVMRetValue rv;
(*static_cast<const PackedFunc*>(func)).CallPacked(
TVMArgs(args, arg_type_codes, num_args), &rv);
// handle return string.
if (rv.type_code() == kStr ||
rv.type_code() == kTVMType ||
rv.type_code() == kBytes) {
TVMRuntimeEntry* e = TVMAPIRuntimeStore::Get();
if (rv.type_code() != kTVMType) {
e->ret_str = *rv.ptr<std::string>();
} else {
e->ret_str = rv.operator std::string();
}
if (rv.type_code() == kBytes) {
e->ret_bytes.data = e->ret_str.c_str();
e->ret_bytes.size = e->ret_str.length();
*ret_type_code = kBytes;
ret_val->v_handle = &(e->ret_bytes);
} else {
*ret_type_code = kStr;
ret_val->v_str = e->ret_str.c_str();
}
} else {
rv.MoveToCHost(ret_val, ret_type_code);
}
API_END();
}
int TVMCFuncSetReturn(TVMRetValueHandle ret,
TVMValue* value,
int* type_code,
int num_ret) {
API_BEGIN();
CHECK_EQ(num_ret, 1);
TVMRetValue* rv = static_cast<TVMRetValue*>(ret);
*rv = TVMArgValue(value[0], type_code[0]);
API_END();
}
int TVMFuncCreateFromCFunc(TVMPackedCFunc func,
void* resource_handle,
TVMPackedCFuncFinalizer fin,
TVMFunctionHandle *out) {
API_BEGIN();
if (fin == nullptr) {
*out = new PackedFunc(
[func, resource_handle](TVMArgs args, TVMRetValue* rv) {
int ret = func((TVMValue*)args.values, (int*)args.type_codes, // NOLINT(*)
args.num_args, rv, resource_handle);
if (ret != 0) {
std::string err = "TVMCall CFunc Error:\n";
err += TVMGetLastError();
throw dmlc::Error(err);
}
});
} else {
// wrap it in a shared_ptr, with fin as deleter.
// so fin will be called when the lambda goes out of scope.
std::shared_ptr<void> rpack(resource_handle, fin);
*out = new PackedFunc(
[func, rpack](TVMArgs args, TVMRetValue* rv) {
int ret = func((TVMValue*)args.values, (int*)args.type_codes, // NOLINT(*)
args.num_args, rv, rpack.get());
if (ret != 0) {
std::string err = "TVMCall CFunc Error:\n";
err += TVMGetLastError();
throw dmlc::Error(err);
}
});
}
API_END();
}
int TVMStreamCreate(int device_type, int device_id, TVMStreamHandle* out) {
API_BEGIN();
TVMContext ctx;
ctx.device_type = static_cast<DLDeviceType>(device_type);
ctx.device_id = device_id;
*out = DeviceAPIManager::Get(ctx)->CreateStream(ctx);
API_END();
}
int TVMStreamFree(int device_type, int device_id, TVMStreamHandle stream) {
API_BEGIN();
TVMContext ctx;
ctx.device_type = static_cast<DLDeviceType>(device_type);
ctx.device_id = device_id;
DeviceAPIManager::Get(ctx)->FreeStream(ctx, stream);
API_END();
}
int TVMSetStream(int device_type, int device_id, TVMStreamHandle stream) {
API_BEGIN();
TVMContext ctx;
ctx.device_type = static_cast<DLDeviceType>(device_type);
ctx.device_id = device_id;
DeviceAPIManager::Get(ctx)->SetStream(ctx, stream);
API_END();
}
int TVMSynchronize(int device_type, int device_id, TVMStreamHandle stream) {
API_BEGIN();
TVMContext ctx;
ctx.device_type = static_cast<DLDeviceType>(device_type);
ctx.device_id = device_id;
DeviceAPIManager::Get(ctx)->StreamSync(ctx, stream);
API_END();
}
int TVMStreamStreamSynchronize(int device_type,
int device_id,
TVMStreamHandle src,
TVMStreamHandle dst) {
API_BEGIN();
TVMContext ctx;
ctx.device_type = static_cast<DLDeviceType>(device_type);
ctx.device_id = device_id;
DeviceAPIManager::Get(ctx)->SyncStreamFromTo(ctx, src, dst);
API_END();
}
int TVMCbArgToReturn(TVMValue* value, int code) {
API_BEGIN();
tvm::runtime::TVMRetValue rv;
rv = tvm::runtime::TVMArgValue(*value, code);
int tcode;
rv.MoveToCHost(value, &tcode);
CHECK_EQ(tcode, code);
API_END();
}
// set device api
TVM_REGISTER_GLOBAL(tvm::runtime::symbol::tvm_set_device)
.set_body([](TVMArgs args, TVMRetValue *ret) {
TVMContext ctx;
ctx.device_type = static_cast<DLDeviceType>(args[0].operator int());
ctx.device_id = args[1];
DeviceAPIManager::Get(ctx)->SetDevice(ctx);
});
// get device attribute
TVM_REGISTER_GLOBAL("_GetDeviceAttr")
.set_body([](TVMArgs args, TVMRetValue *ret) {
TVMContext ctx;
ctx.device_type = static_cast<DLDeviceType>(args[0].operator int());
ctx.device_id = args[1];
DeviceAttrKind kind = static_cast<DeviceAttrKind>(args[2].operator int());
if (kind == kExist) {
DeviceAPI* api = DeviceAPIManager::Get(ctx.device_type, true);
if (api != nullptr) {
api->GetAttr(ctx, kind, ret);
} else {
*ret = 0;
}
} else {
DeviceAPIManager::Get(ctx)->GetAttr(ctx, kind, ret);
}
});
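// Editorial note: probing device presence via the global above; for
// kind == kExist a missing device API yields 0 rather than an error:
//   const PackedFunc* f = Registry::Get("_GetDeviceAttr");
//   int exists = (*f)(static_cast<int>(kDLCPU), 0, static_cast<int>(kExist));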
/*!
* Copyright (c) 2016 by Contributors
* \file cpu_device_api.cc
*/
#include <dmlc/logging.h>
#include <dmlc/thread_local.h>
#include <dgl/runtime/registry.h>
#include <dgl/runtime/device_api.h>
#include <cstdlib>
#include <cstring>
#include "workspace_pool.h"
namespace tvm {
namespace runtime {
class CPUDeviceAPI final : public DeviceAPI {
public:
void SetDevice(TVMContext ctx) final {}
void GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) final {
if (kind == kExist) {
*rv = 1;
}
}
void* AllocDataSpace(TVMContext ctx,
size_t nbytes,
size_t alignment,
TVMType type_hint) final {
    void* ptr;
#if _MSC_VER
    // MSVC: use the CRT's aligned allocator (paired with _aligned_free).
    ptr = _aligned_malloc(nbytes, alignment);
    if (ptr == nullptr) throw std::bad_alloc();
#elif defined(_LIBCPP_SGX_CONFIG)
    // The SGX enclave libc lacks posix_memalign; use memalign instead.
    ptr = memalign(alignment, nbytes);
    if (ptr == nullptr) throw std::bad_alloc();
#else
    int ret = posix_memalign(&ptr, alignment, nbytes);
    if (ret != 0) throw std::bad_alloc();
#endif
return ptr;
}
void FreeDataSpace(TVMContext ctx, void* ptr) final {
#if _MSC_VER
_aligned_free(ptr);
#else
free(ptr);
#endif
}
void CopyDataFromTo(const void* from,
size_t from_offset,
void* to,
size_t to_offset,
size_t size,
TVMContext ctx_from,
TVMContext ctx_to,
TVMType type_hint,
TVMStreamHandle stream) final {
memcpy(static_cast<char*>(to) + to_offset,
static_cast<const char*>(from) + from_offset,
size);
}
void StreamSync(TVMContext ctx, TVMStreamHandle stream) final {
}
void* AllocWorkspace(TVMContext ctx, size_t size, TVMType type_hint) final;
void FreeWorkspace(TVMContext ctx, void* data) final;
static const std::shared_ptr<CPUDeviceAPI>& Global() {
static std::shared_ptr<CPUDeviceAPI> inst =
std::make_shared<CPUDeviceAPI>();
return inst;
}
};
struct CPUWorkspacePool : public WorkspacePool {
CPUWorkspacePool() :
WorkspacePool(kDLCPU, CPUDeviceAPI::Global()) {}
};
void* CPUDeviceAPI::AllocWorkspace(TVMContext ctx,
size_t size,
TVMType type_hint) {
return dmlc::ThreadLocalStore<CPUWorkspacePool>::Get()
->AllocWorkspace(ctx, size);
}
void CPUDeviceAPI::FreeWorkspace(TVMContext ctx, void* data) {
dmlc::ThreadLocalStore<CPUWorkspacePool>::Get()->FreeWorkspace(ctx, data);
}
TVM_REGISTER_GLOBAL("device_api.cpu")
.set_body([](TVMArgs args, TVMRetValue* rv) {
DeviceAPI* ptr = CPUDeviceAPI::Global().get();
*rv = static_cast<void*>(ptr);
});
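// Editorial sketch (not part of the original file): how a caller can reach
// the API registered above without going through DeviceAPIManager. The
// function name is illustrative.
static DeviceAPI* ExampleGetCPUDeviceAPI() {
  const PackedFunc* f = Registry::Get("device_api.cpu");
  CHECK(f != nullptr) << "device_api.cpu is not registered";
  void* ptr = (*f)();
  return static_cast<DeviceAPI*>(ptr);
}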
} // namespace runtime
} // namespace tvm
/*!
* Copyright (c) 2017 by Contributors
* \file dso_dll_module.cc
* \brief Module to load from dynamic shared library.
*/
#include <dgl/runtime/module.h>
#include <dgl/runtime/registry.h>
#include <dgl/runtime/packed_func.h>
#include "module_util.h"
#if defined(_WIN32)
#include <windows.h>
#else
#include <dlfcn.h>
#endif
namespace tvm {
namespace runtime {
// Module to load from a dynamic shared library.
// This is the default module TVM uses for host-side AOT compilation output.
class DSOModuleNode final : public ModuleNode {
public:
~DSOModuleNode() {
if (lib_handle_) Unload();
}
const char* type_key() const final {
return "dso";
}
PackedFunc GetFunction(
const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) final {
BackendPackedCFunc faddr;
if (name == runtime::symbol::tvm_module_main) {
const char* entry_name = reinterpret_cast<const char*>(
GetSymbol(runtime::symbol::tvm_module_main));
      CHECK(entry_name != nullptr)
          << "Symbol " << runtime::symbol::tvm_module_main << " is not present";
faddr = reinterpret_cast<BackendPackedCFunc>(GetSymbol(entry_name));
} else {
faddr = reinterpret_cast<BackendPackedCFunc>(GetSymbol(name.c_str()));
}
if (faddr == nullptr) return PackedFunc();
return WrapPackedFunc(faddr, sptr_to_self);
}
void Init(const std::string& name) {
Load(name);
if (auto *ctx_addr =
reinterpret_cast<void**>(GetSymbol(runtime::symbol::tvm_module_ctx))) {
*ctx_addr = this;
}
InitContextFunctions([this](const char* fname) {
return GetSymbol(fname);
});
// Load the imported modules
const char* dev_mblob =
reinterpret_cast<const char*>(
GetSymbol(runtime::symbol::tvm_dev_mblob));
if (dev_mblob != nullptr) {
ImportModuleBlob(dev_mblob, &imports_);
}
}
private:
// Platform dependent handling.
#if defined(_WIN32)
// library handle
HMODULE lib_handle_{nullptr};
// Load the library
void Load(const std::string& name) {
    // Use the wide-string version of the API, as needed by LLVM.
std::wstring wname(name.begin(), name.end());
lib_handle_ = LoadLibraryW(wname.c_str());
CHECK(lib_handle_ != nullptr)
<< "Failed to load dynamic shared library " << name;
}
void* GetSymbol(const char* name) {
return reinterpret_cast<void*>(
GetProcAddress(lib_handle_, (LPCSTR)name)); // NOLINT(*)
}
void Unload() {
FreeLibrary(lib_handle_);
}
#else
// Library handle
void* lib_handle_{nullptr};
// load the library
void Load(const std::string& name) {
lib_handle_ = dlopen(name.c_str(), RTLD_LAZY | RTLD_LOCAL);
CHECK(lib_handle_ != nullptr)
<< "Failed to load dynamic shared library " << name
<< " " << dlerror();
}
void* GetSymbol(const char* name) {
return dlsym(lib_handle_, name);
}
void Unload() {
dlclose(lib_handle_);
}
#endif
};
TVM_REGISTER_GLOBAL("module.loadfile_so")
.set_body([](TVMArgs args, TVMRetValue* rv) {
std::shared_ptr<DSOModuleNode> n = std::make_shared<DSOModuleNode>();
n->Init(args[0]);
*rv = runtime::Module(n);
});
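// Editorial sketch (not part of the original file): a typical client path.
// Module::LoadFromFile dispatches on the file extension, reaching the
// "module.loadfile_so" global above for ".so" files; names are illustrative.
static PackedFunc ExampleLoadEntry(const std::string& path,
                                   const std::string& fname) {
  Module m = Module::LoadFromFile(path);
  return m.GetFunction(fname);
}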
} // namespace runtime
} // namespace tvm
/*!
* Copyright (c) 2017 by Contributors
* \file file_util.cc
*/
#include <dmlc/json.h>
#include <dmlc/logging.h>
#include <dgl/runtime/serializer.h>
#include <fstream>
#include <vector>
#include "file_util.h"
namespace tvm {
namespace runtime {
void FunctionInfo::Save(dmlc::JSONWriter* writer) const {
std::vector<std::string> sarg_types(arg_types.size());
for (size_t i = 0; i < arg_types.size(); ++i) {
sarg_types[i] = TVMType2String(arg_types[i]);
}
writer->BeginObject();
writer->WriteObjectKeyValue("name", name);
writer->WriteObjectKeyValue("arg_types", sarg_types);
writer->WriteObjectKeyValue("thread_axis_tags", thread_axis_tags);
writer->EndObject();
}
void FunctionInfo::Load(dmlc::JSONReader* reader) {
dmlc::JSONObjectReadHelper helper;
std::vector<std::string> sarg_types;
helper.DeclareField("name", &name);
helper.DeclareField("arg_types", &sarg_types);
helper.DeclareField("thread_axis_tags", &thread_axis_tags);
helper.ReadAllFields(reader);
arg_types.resize(sarg_types.size());
for (size_t i = 0; i < arg_types.size(); ++i) {
arg_types[i] = String2TVMType(sarg_types[i]);
}
}
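// Editorial note: the JSON handled by the two methods above has the shape
// (field values illustrative):
//   {
//     "name": "fadd",
//     "arg_types": ["float32", "float32", "int32"],
//     "thread_axis_tags": ["blockIdx.x", "threadIdx.x"]
//   }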
void FunctionInfo::Save(dmlc::Stream* writer) const {
writer->Write(name);
writer->Write(arg_types);
writer->Write(thread_axis_tags);
}
bool FunctionInfo::Load(dmlc::Stream* reader) {
if (!reader->Read(&name)) return false;
if (!reader->Read(&arg_types)) return false;
if (!reader->Read(&thread_axis_tags)) return false;
return true;
}
std::string GetFileFormat(const std::string& file_name,
const std::string& format) {
std::string fmt = format;
if (fmt.length() == 0) {
if (file_name.find(".signed.so") != std::string::npos) return "sgx";
size_t pos = file_name.find_last_of(".");
if (pos != std::string::npos) {
return file_name.substr(pos + 1, file_name.length() - pos - 1);
} else {
return "";
}
} else {
return format;
}
}
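// Editorial examples of the resolution above (arguments -> result):
//   GetFileFormat("mod.so", "")        -> "so"   (extension wins)
//   GetFileFormat("mod.signed.so", "") -> "sgx"  (SGX special case)
//   GetFileFormat("mod", "")           -> ""     (no extension)
//   GetFileFormat("mod.bin", "ll")     -> "ll"   (explicit format wins)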
std::string GetCacheDir() {
char* env_cache_dir;
if ((env_cache_dir = getenv("TVM_CACHE_DIR"))) return env_cache_dir;
if ((env_cache_dir = getenv("XDG_CACHE_HOME"))) {
return std::string(env_cache_dir) + "/tvm";
}
if ((env_cache_dir = getenv("HOME"))) {
return std::string(env_cache_dir) + "/.cache/tvm";
}
return ".";
}
std::string GetFileBasename(const std::string& file_name) {
size_t last_slash = file_name.find_last_of("/");
if (last_slash == std::string::npos) return file_name;
return file_name.substr(last_slash + 1);
}
std::string GetMetaFilePath(const std::string& file_name) {
size_t pos = file_name.find_last_of(".");
if (pos != std::string::npos) {
return file_name.substr(0, pos) + ".tvm_meta.json";
} else {
return file_name + ".tvm_meta.json";
}
}
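// Editorial examples:
//   GetMetaFilePath("build/mod.so") -> "build/mod.tvm_meta.json"
//   GetMetaFilePath("mod")          -> "mod.tvm_meta.json"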
void LoadBinaryFromFile(const std::string& file_name,
std::string* data) {
std::ifstream fs(file_name, std::ios::in | std::ios::binary);
CHECK(!fs.fail()) << "Cannot open " << file_name;
// get its size:
fs.seekg(0, std::ios::end);
size_t size = static_cast<size_t>(fs.tellg());
fs.seekg(0, std::ios::beg);
data->resize(size);
fs.read(&(*data)[0], size);
}
void SaveBinaryToFile(
const std::string& file_name,
const std::string& data) {
std::ofstream fs(file_name, std::ios::out | std::ios::binary);
CHECK(!fs.fail()) << "Cannot open " << file_name;
fs.write(&data[0], data.length());
}
void SaveMetaDataToFile(
const std::string& file_name,
const std::unordered_map<std::string, FunctionInfo>& fmap) {
std::string version = "0.1.0";
std::ofstream fs(file_name.c_str());
CHECK(!fs.fail()) << "Cannot open file " << file_name;
dmlc::JSONWriter writer(&fs);
writer.BeginObject();
writer.WriteObjectKeyValue("tvm_version", version);
writer.WriteObjectKeyValue("func_info", fmap);
writer.EndObject();
fs.close();
}
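// Editorial note: the metadata file written above looks like
// (entries illustrative):
//   {
//     "tvm_version": "0.1.0",
//     "func_info": {
//       "fadd": {"name": "fadd",
//                "arg_types": ["float32", "float32", "int32"],
//                "thread_axis_tags": []}
//     }
//   }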
void LoadMetaDataFromFile(
const std::string& file_name,
std::unordered_map<std::string, FunctionInfo>* fmap) {
std::ifstream fs(file_name.c_str());
CHECK(!fs.fail()) << "Cannot open file " << file_name;
std::string version;
dmlc::JSONReader reader(&fs);
dmlc::JSONObjectReadHelper helper;
helper.DeclareField("tvm_version", &version);
helper.DeclareField("func_info", fmap);
helper.ReadAllFields(&reader);
fs.close();
}
} // namespace runtime
} // namespace tvm