Unverified Commit 00add9f2 authored by Minjie Wang's avatar Minjie Wang Committed by GitHub
Browse files

Merge pull request #90 from jermainewang/cpp

[GraphIndex] Graph index and many related changes
parents ec4216dd dce1f44d
...@@ -4,17 +4,25 @@ from __future__ import absolute_import ...@@ -4,17 +4,25 @@ from __future__ import absolute_import
import operator import operator
import dgl.backend as F import dgl.backend as F
__all__ = ["MessageFunction", "src_mul_edge", "copy_src", "copy_edge"] __all__ = ["src_mul_edge", "copy_src", "copy_edge"]
class MessageFunction(object): class MessageFunction(object):
"""Base builtin message function class."""
def __call__(self, src, edge): def __call__(self, src, edge):
"""Regular computation of this builtin.
This will be used when optimization is not available.
"""
raise NotImplementedError raise NotImplementedError
def name(self): def name(self):
"""Return the name of this builtin function."""
raise NotImplementedError raise NotImplementedError
def is_spmv_supported(self, g): def is_spmv_supported(self, g):
"""Return whether the SPMV optimization is supported."""
raise NotImplementedError raise NotImplementedError
...@@ -22,12 +30,6 @@ class BundledMessageFunction(MessageFunction): ...@@ -22,12 +30,6 @@ class BundledMessageFunction(MessageFunction):
def __init__(self, fn_list): def __init__(self, fn_list):
if not isinstance(fn_list, (list, tuple)): if not isinstance(fn_list, (list, tuple)):
fn_list = [fn_list] fn_list = [fn_list]
else:
# sanity check on out field
for fn in fn_list:
# cannot perform check for udf
if isinstance(fn, MessageFunction) and fn.out_field is None:
raise RuntimeError("Not specifying out field for multiple message is ambiguous")
self.fn_list = fn_list self.fn_list = fn_list
def is_spmv_supported(self, g): def is_spmv_supported(self, g):
...@@ -43,11 +45,8 @@ class BundledMessageFunction(MessageFunction): ...@@ -43,11 +45,8 @@ class BundledMessageFunction(MessageFunction):
if ret is None: if ret is None:
ret = msg ret = msg
else: else:
try:
# ret and msg must be dict # ret and msg must be dict
ret.update(msg) ret.update(msg)
except:
raise RuntimeError("Must specify out field for multiple message")
return ret return ret
def name(self): def name(self):
...@@ -55,25 +54,26 @@ class BundledMessageFunction(MessageFunction): ...@@ -55,25 +54,26 @@ class BundledMessageFunction(MessageFunction):
def _is_spmv_supported_node_feat(g, field): def _is_spmv_supported_node_feat(g, field):
if field is None: """Return whether the node feature shape supports SPMV optimization.
feat = g.get_n_repr()
else: Only scalar and vector features are supported currently.
"""
feat = g.get_n_repr()[field] feat = g.get_n_repr()[field]
shape = F.shape(feat) shape = F.shape(feat)
return len(shape) == 1 or len(shape) == 2 return len(shape) == 1 or len(shape) == 2
def _is_spmv_supported_edge_feat(g, field): def _is_spmv_supported_edge_feat(g, field):
# check shape, only scalar edge feature can be optimized at the moment """Return whether the edge feature shape supports SPMV optimization.
if field is None:
feat = g.get_e_repr() Only scalar feature is supported currently.
else: """
feat = g.get_e_repr()[field] feat = g.get_e_repr()[field]
shape = F.shape(feat) shape = F.shape(feat)
return len(shape) == 1 or (len(shape) == 2 and shape[1] == 1) return len(shape) == 1 or (len(shape) == 2 and shape[1] == 1)
class SrcMulEdgeMessageFunction(MessageFunction): class SrcMulEdgeMessageFunction(MessageFunction):
def __init__(self, mul_op, src_field=None, edge_field=None, out_field=None): def __init__(self, mul_op, src_field, edge_field, out_field):
self.mul_op = mul_op self.mul_op = mul_op
self.src_field = src_field self.src_field = src_field
self.edge_field = edge_field self.edge_field = edge_field
...@@ -84,21 +84,14 @@ class SrcMulEdgeMessageFunction(MessageFunction): ...@@ -84,21 +84,14 @@ class SrcMulEdgeMessageFunction(MessageFunction):
and _is_spmv_supported_edge_feat(g, self.edge_field) and _is_spmv_supported_edge_feat(g, self.edge_field)
def __call__(self, src, edge): def __call__(self, src, edge):
if self.src_field is not None: ret = self.mul_op(src[self.src_field], edge[self.edge_field])
src = src[self.src_field]
if self.edge_field is not None:
edge = edge[self.edge_field]
ret = self.mul_op(src, edge)
if self.out_field is None:
return ret
else:
return {self.out_field : ret} return {self.out_field : ret}
def name(self): def name(self):
return "src_mul_edge" return "src_mul_edge"
class CopySrcMessageFunction(MessageFunction): class CopySrcMessageFunction(MessageFunction):
def __init__(self, src_field=None, out_field=None): def __init__(self, src_field, out_field):
self.src_field = src_field self.src_field = src_field
self.out_field = out_field self.out_field = out_field
...@@ -106,14 +99,7 @@ class CopySrcMessageFunction(MessageFunction): ...@@ -106,14 +99,7 @@ class CopySrcMessageFunction(MessageFunction):
return _is_spmv_supported_node_feat(g, self.src_field) return _is_spmv_supported_node_feat(g, self.src_field)
def __call__(self, src, edge): def __call__(self, src, edge):
if self.src_field is not None: return {self.out_field : src[self.src_field]}
ret = src[self.src_field]
else:
ret = src
if self.out_field is None:
return ret
else:
return {self.out_field : ret}
def name(self): def name(self):
return "copy_src" return "copy_src"
...@@ -142,14 +128,41 @@ class CopyEdgeMessageFunction(MessageFunction): ...@@ -142,14 +128,41 @@ class CopyEdgeMessageFunction(MessageFunction):
return "copy_edge" return "copy_edge"
def src_mul_edge(src=None, edge=None, out=None): def src_mul_edge(src, edge, out):
"""TODO(minjie): docstring """ """Builtin message function that computes message by multiplying source node features
with edge features.
Parameters
----------
src : str
The source feature name.
edge : str
The edge feature name.
out : str
The output message name.
"""
return SrcMulEdgeMessageFunction(operator.mul, src, edge, out) return SrcMulEdgeMessageFunction(operator.mul, src, edge, out)
def copy_src(src=None, out=None): def copy_src(src, out):
"""TODO(minjie): docstring """ """Builtin message function that computes message using source node feature.
Parameters
----------
src : str
The source feature name.
out : str
The output message name.
"""
return CopySrcMessageFunction(src, out) return CopySrcMessageFunction(src, out)
def copy_edge(edge=None, out=None): def copy_edge(edge, out):
"""TODO(minjie): docstring """ """Builtin message function that computes message using edge feature.
Parameters
----------
edge : str
The edge feature name.
out : str
The output message name.
"""
return CopyEdgeMessageFunction(edge, out) return CopyEdgeMessageFunction(edge, out)
"""Built-in reducer function.""" """Built-in reducer function."""
from __future__ import absolute_import from __future__ import absolute_import
import dgl.backend as F from .. import backend as F
__all__ = ["ReduceFunction", "sum", "max"] __all__ = ["sum", "max"]
class ReduceFunction(object): class ReduceFunction(object):
"""Base builtin reduce function class."""
def __call__(self, node, msgs): def __call__(self, node, msgs):
"""Regular computation of this builtin.
This will be used when optimization is not available.
"""
raise NotImplementedError raise NotImplementedError
def name(self): def name(self):
"""Return the name of this builtin function."""
raise NotImplementedError raise NotImplementedError
def is_spmv_supported(self): def is_spmv_supported(self):
"""Return whether the SPMV optimization is supported."""
raise NotImplementedError raise NotImplementedError
class BundledReduceFunction(ReduceFunction): class BundledReduceFunction(ReduceFunction):
def __init__(self, fn_list): def __init__(self, fn_list):
if not isinstance(fn_list, (list, tuple)): if not isinstance(fn_list, (list, tuple)):
fn_list = [fn_list] fn_list = [fn_list]
else:
# sanity check on out field
for fn in fn_list:
if isinstance(fn, ReduceFunction) and fn.out_field is None:
raise RuntimeError("Not specifying out field for multiple reduce is ambiguous")
self.fn_list = fn_list self.fn_list = fn_list
def is_spmv_supported(self): def is_spmv_supported(self):
...@@ -39,51 +42,50 @@ class BundledReduceFunction(ReduceFunction): ...@@ -39,51 +42,50 @@ class BundledReduceFunction(ReduceFunction):
if ret is None: if ret is None:
ret = rpr ret = rpr
else: else:
try:
# ret and rpr must be dict # ret and rpr must be dict
ret.update(rpr) ret.update(rpr)
except:
raise RuntimeError("Must specify out field for multiple reudce")
return ret return ret
def name(self): def name(self):
return "bundled" return "bundled"
class ReducerFunctionTemplate(ReduceFunction): class ReducerFunctionTemplate(ReduceFunction):
def __init__(self, name, batch_op, nonbatch_op, msg_field=None, out_field=None): def __init__(self, name, op, msg_field, out_field):
self.name = name self.name = name
self.batch_op = batch_op self.op = op
self.nonbatch_op = nonbatch_op
self.msg_field = msg_field self.msg_field = msg_field
self.out_field = out_field self.out_field = out_field
def is_spmv_supported(self): def is_spmv_supported(self):
# TODO: support max # NOTE: only sum is supported right now.
return self.name == "sum" return self.name == "sum"
def __call__(self, node, msgs): def __call__(self, node, msgs):
if isinstance(msgs, list): return {self.out_field : self.op(msgs[self.msg_field], 1)}
if self.msg_field is None:
ret = self.nonbatch_op(msgs)
else:
ret = self.nonbatch_op([msg[self.msg_field] for msg in msgs])
else:
if self.msg_field is None:
ret = self.batch_op(msgs, 1)
else:
ret = self.batch_op(msgs[self.msg_field], 1)
if self.out_field is None:
return ret
else:
return {self.out_field : ret}
def name(self): def name(self):
return self.name return self.name
_python_sum = sum def sum(msg, out):
def sum(msgs=None, out=None): """Builtin reduce function that aggregates messages by sum.
return ReducerFunctionTemplate("sum", F.sum, _python_sum, msgs, out)
Parameters
----------
msg : str
The message name.
out : str
The output node feature name.
"""
return ReducerFunctionTemplate("sum", F.sum, msg, out)
def max(msg, out):
"""Builtin reduce function that aggregates messages by max.
_python_max = max Parameters
def max(msgs=None, out=None): ----------
return ReducerFunctionTemplate("max", F.max, _python_max, msgs, out) msg : str
The message name.
out : str
The output node feature name.
"""
return ReducerFunctionTemplate("max", F.max, msg, out)
"""Line graph generator."""
from __future__ import absolute_import
import networkx as nx
import numpy as np
import dgl.backend as F
from dgl.graph import DGLGraph
from dgl.frame import FrameRef
def line_graph(G, no_backtracking=False):
"""Create the line graph that shares the underlying features.
The node features of the result line graph will share the edge features
of the given graph.
Parameters
----------
G : DGLGraph
The input graph.
no_backtracking : bool
Whether the backtracking edges are included in the line graph.
If i~j and j~i are two edges in original graph G, then
(i,j)~(j,i) and (j,i)~(i,j) are the "backtracking" edges on
the line graph.
"""
L = nx.DiGraph()
for eid, from_node in enumerate(G.edge_list):
L.add_node(from_node)
for to_node in G.edges(from_node[1]):
if no_backtracking and to_node[1] == from_node[0]:
continue
L.add_edge(from_node, to_node)
relabel_map = {}
for i, e in enumerate(G.edge_list):
relabel_map[e] = i
nx.relabel.relabel_nodes(L, relabel_map, copy=False)
return DGLGraph(L, node_frame=G._edge_frame)
...@@ -3,225 +3,707 @@ ...@@ -3,225 +3,707 @@
from __future__ import absolute_import from __future__ import absolute_import
import networkx as nx import networkx as nx
from networkx.classes.digraph import DiGraph import numpy as np
import dgl import dgl
from dgl.base import ALL, is_all, __MSG__, __REPR__ from .base import ALL, is_all, DGLError, dgl_warning
import dgl.backend as F from . import backend as F
from dgl.backend import Tensor from .backend import Tensor
from dgl.cached_graph import CachedGraph, create_cached_graph from .frame import FrameRef, merge_frames
import dgl.context as context from .function.message import BundledMessageFunction
from dgl.frame import FrameRef, merge_frames from .function.reducer import BundledReduceFunction
from dgl.nx_adapt import nx_init from .graph_index import GraphIndex, create_graph_index
import dgl.scheduler as scheduler from . import scheduler
import dgl.utils as utils from . import utils
from dgl.function.message import BundledMessageFunction
from dgl.function.reducer import BundledReduceFunction __all__ = ['DLGraph']
class DGLGraph(DiGraph): class DGLGraph(object):
"""Base graph class specialized for neural networks on graphs. """Base graph class specialized for neural networks on graphs.
TODO(minjie): document of batching semantics TODO(minjie): document of batching semantics
TODO(minjie): document of __REPR__ semantics
Parameters Parameters
---------- ----------
graph_data : graph data graph_data : graph data
Data to initialize graph. Same as networkx's semantics. Data to initialize graph. Same as networkx's semantics.
node_frame : dgl.frame.Frame node_frame : FrameRef
Node feature storage. Node feature storage.
edge_frame : dgl.frame.Frame edge_frame : FrameRef
Edge feature storage. Edge feature storage.
attr : keyword arguments, optional multigraph : bool, optional
Attributes to add to graph as key=value pairs. Whether the graph would be a multigraph (default: False)
""" """
def __init__(self, def __init__(self,
graph_data=None, graph_data=None,
node_frame=None, node_frame=None,
edge_frame=None, edge_frame=None,
**attr): multigraph=False):
# TODO(minjie): maintaining node/edge list is costly when graph is large. # graph
self._edge_list = [] self._graph = create_graph_index(graph_data, multigraph)
nx_init(self, # frame
self._add_node_callback,
self._add_edge_callback,
self._del_node_callback,
self._del_edge_callback,
graph_data,
**attr)
# cached graph and storage
self._cached_graph = None
self._node_frame = node_frame if node_frame is not None else FrameRef() self._node_frame = node_frame if node_frame is not None else FrameRef()
self._edge_frame = edge_frame if edge_frame is not None else FrameRef() self._edge_frame = edge_frame if edge_frame is not None else FrameRef()
# other class members # msg graph & frame
self._msg_graph = None self._msg_graph = create_graph_index(multigraph=multigraph)
self._msg_frame = FrameRef() self._msg_frame = FrameRef()
self._message_func = (None, None) self._msg_edges = []
self._reduce_func = (None, None) self.reset_messages()
self._edge_func = (None, None) # registered functions
self._apply_node_func = (None, None) self._message_func = None
self._apply_edge_func = (None, None) self._reduce_func = None
self._edge_func = None
self._apply_node_func = None
self._apply_edge_func = None
def add_nodes(self, num, reprs=None):
"""Add nodes.
Parameters
----------
num : int
Number of nodes to be added.
reprs : dict
Optional node representations.
"""
self._graph.add_nodes(num)
self._msg_graph.add_nodes(num)
#TODO(minjie): change frames
assert reprs is None
def add_edge(self, u, v, reprs=None):
"""Add one edge.
Parameters
----------
u : int
The src node.
v : int
The dst node.
reprs : dict
Optional edge representation.
"""
self._graph.add_edge(u, v)
#TODO(minjie): change frames
assert reprs is None
def add_edges(self, u, v, reprs=None):
"""Add many edges.
Parameters
----------
u : list, tensor
The src nodes.
v : list, tensor
The dst nodes.
reprs : dict
Optional node representations.
"""
u = utils.toindex(u)
v = utils.toindex(v)
self._graph.add_edges(u, v)
#TODO(minjie): change frames
assert reprs is None
def clear(self):
"""Clear the graph and its storage."""
self._graph.clear()
self._node_frame.clear()
self._edge_frame.clear()
self._msg_graph.clear()
self._msg_frame.clear()
self._msg_edges.clear()
def reset_messages(self):
"""Clear all messages."""
self._msg_graph.clear()
self._msg_frame.clear()
self._msg_edges.clear()
self._msg_graph.add_nodes(self.number_of_nodes())
def number_of_nodes(self):
"""Return the number of nodes.
Returns
-------
int
The number of nodes
"""
return self._graph.number_of_nodes()
def __len__(self):
"""Return the number of nodes."""
return self.number_of_nodes()
@property
def is_multigraph(self):
"""Whether the graph is a multigraph.
"""
return self._graph.is_multigraph()
def number_of_edges(self):
"""Return the number of edges.
Returns
-------
int
The number of edges
"""
return self._graph.number_of_edges()
def has_node(self, vid):
"""Return true if the node exists.
Parameters
----------
vid : int
The nodes
Returns
-------
bool
True if the node exists
"""
return self.has_node(vid)
def __contains__(self, vid):
"""Same as has_node."""
return self.has_node(vid)
def has_nodes(self, vids):
"""Return true if the nodes exist.
Parameters
----------
vid : list, tensor
The nodes
Returns
-------
tensor
0-1 array indicating existence
"""
vids = utils.toindex(vids)
rst = self._graph.has_nodes(vids)
return rst.tousertensor()
def has_edge_between(self, u, v):
"""Return true if the edge exists.
Parameters
----------
u : int
The src node.
v : int
The dst node.
Returns
-------
bool
True if the edge exists
"""
return self._graph.has_edge_between(u, v)
def has_edges_between(self, u, v):
"""Return true if the edge exists.
Parameters
----------
u : list, tensor
The src nodes.
v : list, tensor
The dst nodes.
Returns
-------
tensor
0-1 array indicating existence
"""
u = utils.toindex(u)
v = utils.toindex(v)
rst = self._graph.has_edges_between(u, v)
return rst.tousertensor()
def predecessors(self, v, radius=1):
"""Return the predecessors of the node.
Parameters
----------
v : int
The node.
radius : int, optional
The radius of the neighborhood.
Returns
-------
tensor
Array of predecessors
"""
return self._graph.predecessors(v).tousertensor()
def successors(self, v, radius=1):
"""Return the successors of the node.
Parameters
----------
v : int
The node.
radius : int, optional
The radius of the neighborhood.
Returns
-------
tensor
Array of successors
"""
return self._graph.successors(v).tousertensor()
def edge_id(self, u, v, force_multi=False):
"""Return the id of the edge.
Parameters
----------
u : int
The src node.
v : int
The dst node.
force_multi : bool
If False, will return a single edge ID if the graph is a simple graph.
If True, will always return an array.
Returns
-------
int or tensor
The edge id if force_multi == True and the graph is a simple graph.
The edge id array otherwise.
"""
idx = self._graph.edge_id(u, v)
return idx.tousertensor() if force_multi or self.is_multigraph else idx[0]
def edge_ids(self, u, v, force_multi=False):
"""Return the edge ids.
Parameters
----------
u : list, tensor
The src nodes.
v : list, tensor
The dst nodes.
force_multi : bool
If False, will return a single edge ID array if the graph is a simple graph.
If True, will always return 3 arrays (src nodes, dst nodes, edge ids).
Returns
-------
tensor, or (tensor, tensor, tensor)
If force_multi is True or the graph is multigraph, return (src nodes, dst nodes, edge ids)
Otherwise, return a single tensor of edge ids.
"""
u = utils.toindex(u)
v = utils.toindex(v)
src, dst, eid = self._graph.edge_ids(u, v)
if force_multi or self.is_multigraph:
return src.tousertensor(), dst.tousertensor(), eid.tousertensor()
else:
return eid.tousertensor()
def in_edges(self, v):
"""Return the in edges of the node(s).
Parameters
----------
v : int, list, tensor
The node(s).
Returns
-------
tensor
The src nodes.
tensor
The dst nodes.
tensor
The edge ids.
"""
v = utils.toindex(v)
src, dst, eid = self._graph.in_edges(v)
return src.tousertensor(), dst.tousertensor(), eid.tousertensor()
def out_edges(self, v):
"""Return the out edges of the node(s).
Parameters
----------
v : int, list, tensor
The node(s).
Returns
-------
tensor
The src nodes.
tensor
The dst nodes.
tensor
The edge ids.
"""
v = utils.toindex(v)
src, dst, eid = self._graph.out_edges(v)
return src.tousertensor(), dst.tousertensor(), eid.tousertensor()
def edges(self, sorted=False):
"""Return all the edges.
Parameters
----------
sorted : bool
True if the returned edges are sorted by their src and dst ids.
Returns
-------
tensor
The src nodes.
tensor
The dst nodes.
tensor
The edge ids.
"""
src, dst, eid = self._graph.edges(sorted)
return src.tousertensor(), dst.tousertensor(), eid.tousertensor()
def in_degree(self, v):
"""Return the in degree of the node.
Parameters
----------
v : int
The node.
Returns
-------
int
The in degree.
"""
return self._graph.in_degree(v)
def in_degrees(self, v):
"""Return the in degrees of the nodes.
Parameters
----------
v : list, tensor
The nodes.
Returns
-------
tensor
The in degree array.
"""
return self._graph.in_degrees(v).tousertensor()
def out_degree(self, v):
"""Return the out degree of the node.
Parameters
----------
v : int
The node.
Returns
-------
int
The out degree.
"""
return self._graph.out_degree(v)
def out_degrees(self, v):
"""Return the out degrees of the nodes.
Parameters
----------
v : list, tensor
The nodes.
Returns
-------
tensor
The out degree array.
"""
return self._graph.out_degrees(v).tousertensor()
def to_networkx(self, node_attrs=None, edge_attrs=None):
"""Convert to networkx graph.
The edge id will be saved as the 'id' edge attribute.
Parameters
----------
node_attrs : iterable of str, optional
The node attributes to be copied.
edge_attrs : iterable of str, optional
The edge attributes to be copied.
Returns
-------
networkx.DiGraph
The nx graph
"""
nx_graph = self._graph.to_networkx()
#TODO(minjie): attributes
dgl_warning('to_networkx currently does not support converting'
' node/edge features automatically.')
return nx_graph
def from_networkx(self, nx_graph, node_attrs=None, edge_attrs=None):
"""Convert from networkx graph.
If 'id' edge attribute exists, the edge will be added follows
the edge id order. Otherwise, order is undefined.
Parameters
----------
nx_graph : networkx.DiGraph
The nx graph
node_attrs : iterable of str, optional
The node attributes needs to be copied.
edge_attrs : iterable of str, optional
The edge attributes needs to be copied.
"""
self.clear()
self._graph.from_networkx(nx_graph)
self._msg_graph.add_nodes(self._graph.number_of_nodes())
# copy attributes
def _batcher(lst):
if isinstance(lst[0], Tensor):
return F.pack([F.unsqueeze(x, 0) for x in lst])
else:
return F.tensor(lst)
if node_attrs is not None:
attr_dict = {attr : [] for attr in node_attrs}
for nid in range(self.number_of_nodes()):
for attr in node_attrs:
attr_dict[attr].append(nx_graph.nodes[nid][attr])
for attr in node_attrs:
self._node_frame[attr] = _batcher(attr_dict[attr])
if edge_attrs is not None:
attr_dict = {attr : [] for attr in edge_attrs}
src, dst, _ = self._graph.edges()
for u, v in zip(src.tolist(), dst.tolist()):
for attr in edge_attrs:
attr_dict[attr].append(nx_graph.edges[u, v][attr])
for attr in edge_attrs:
self._edge_frame[attr] = _batcher(attr_dict[attr])
def from_scipy_sparse_matrix(self, a):
""" Convert from scipy sparse matrix.
Parameters
----------
a : scipy sparse matrix
The graph's adjacency matrix
"""
self.clear()
self._graph.from_scipy_sparse_matrix(a)
self._msg_graph.add_nodes(self._graph.number_of_nodes())
def node_attr_schemes(self): def node_attr_schemes(self):
"""Return the node feature schemes.
Returns
-------
dict of str to schemes
The schemes of node feature columns.
"""
return self._node_frame.schemes return self._node_frame.schemes
def edge_attr_schemes(self): def edge_attr_schemes(self):
"""Return the edge feature schemes.
Returns
-------
dict of str to schemes
The schemes of edge feature columns.
"""
return self._edge_frame.schemes return self._edge_frame.schemes
def set_n_repr(self, hu, u=ALL): def set_n_initializer(self, initializer):
"""Set the initializer for empty node features.
Initializer is a callable that returns a tensor given the shape and data type.
Parameters
----------
initializer : callable
The initializer.
"""
self._node_frame.set_initializer(initializer)
def set_e_initializer(self, initializer):
"""Set the initializer for empty edge features.
Initializer is a callable that returns a tensor given the shape and data type.
Parameters
----------
initializer : callable
The initializer.
"""
self._edge_frame.set_initializer(initializer)
def set_n_repr(self, hu, u=ALL, inplace=False):
"""Set node(s) representation. """Set node(s) representation.
To set multiple node representations at once, pass `u` with a tensor or `hu` is a dictionary from the feature name to feature tensor. Each tensor
a supported container of node ids. In this case, `hu` must be a tensor is of shape (B, D1, D2, ...), where B is the number of nodes to be updated,
of shape (B, D1, D2, ...), where B is the number of the nodes and and (D1, D2, ...) be the shape of the node representation tensor. The
(D1, D2, ...) is the shape of the node representation tensor. length of the given node ids must match B (i.e, len(u) == B).
Dictionary type is also supported for `hu`. In this case, each item All update will be done out-placely to work with autograd unless the inplace
will be treated as separate attribute of the nodes. flag is true.
Parameters Parameters
---------- ----------
hu : tensor or dict of tensor hu : dict of tensor
Node representation. Node representation.
u : node, container or tensor u : node, container or tensor
The node(s). The node(s).
inplace : bool
True if the update is done inplacely
""" """
# sanity check # sanity check
if not utils.is_dict_like(hu):
raise DGLError('Expect dictionary type for feature data.'
' Got "%s" instead.' % type(hu))
if is_all(u): if is_all(u):
num_nodes = self.number_of_nodes() num_nodes = self.number_of_nodes()
else: else:
u = utils.toindex(u) u = utils.toindex(u)
num_nodes = len(u) num_nodes = len(u)
if utils.is_dict_like(hu):
for key, val in hu.items(): for key, val in hu.items():
assert F.shape(val)[0] == num_nodes nfeats = F.shape(val)[0]
else: if nfeats != num_nodes:
assert F.shape(hu)[0] == num_nodes raise DGLError('Expect number of features to match number of nodes (len(u)).'
' Got %d and %d instead.' % (nfeats, num_nodes))
# set # set
if is_all(u): if is_all(u):
if utils.is_dict_like(hu):
for key, val in hu.items(): for key, val in hu.items():
self._node_frame[key] = val self._node_frame[key] = val
else: else:
self._node_frame[__REPR__] = hu self._node_frame.update_rows(u, hu, inplace=inplace)
else:
if utils.is_dict_like(hu):
self._node_frame[u] = hu
else:
self._node_frame[u] = {__REPR__ : hu}
def get_n_repr(self, u=ALL): def get_n_repr(self, u=ALL):
"""Get node(s) representation. """Get node(s) representation.
The returned feature tensor batches multiple node features on the first dimension.
Parameters Parameters
---------- ----------
u : node, container or tensor u : node, container or tensor
The node(s). The node(s).
Returns
-------
dict
Representation dict from feature name to feature tensor.
""" """
if len(self.node_attr_schemes()) == 0:
return dict()
if is_all(u): if is_all(u):
if len(self._node_frame) == 1 and __REPR__ in self._node_frame:
return self._node_frame[__REPR__]
else:
return dict(self._node_frame) return dict(self._node_frame)
else: else:
u = utils.toindex(u) u = utils.toindex(u)
if len(self._node_frame) == 1 and __REPR__ in self._node_frame:
return self._node_frame.select_rows(u)[__REPR__]
else:
return self._node_frame.select_rows(u) return self._node_frame.select_rows(u)
def pop_n_repr(self, key=__REPR__): def pop_n_repr(self, key):
"""Get and remove the specified node repr. """Get and remove the specified node repr.
Parameters Parameters
---------- ----------
key : str key : str
The attribute name. The attribute name.
Returns
-------
Tensor
The popped representation
""" """
return self._node_frame.pop(key) return self._node_frame.pop(key)
def set_e_repr(self, h_uv, u=ALL, v=ALL): def set_e_repr(self, he, u=ALL, v=ALL, inplace=False):
"""Set edge(s) representation. """Set edge(s) representation.
To set multiple edge representations at once, pass `u` and `v` with tensors or `he` is a dictionary from the feature name to feature tensor. Each tensor
supported containers of node ids. In this case, `h_uv` must be a tensor is of shape (B, D1, D2, ...), where B is the number of edges to be updated,
of shape (B, D1, D2, ...), where B is the number of the edges and and (D1, D2, ...) be the shape of the edge representation tensor.
(D1, D2, ...) is the shape of the edge representation tensor.
Dictionary type is also supported for `h_uv`. In this case, each item All update will be done out-placely to work with autograd unless the inplace
will be treated as separate attribute of the edges. flag is true.
Parameters Parameters
---------- ----------
h_uv : tensor or dict of tensor he : tensor or dict of tensor
Edge representation. Edge representation.
u : node, container or tensor u : node, container or tensor
The source node(s). The source node(s).
v : node, container or tensor v : node, container or tensor
The destination node(s). The destination node(s).
inplace : bool
True if the update is done inplacely
""" """
# sanity check # sanity check
if not utils.is_dict_like(he):
raise DGLError('Expect dictionary type for feature data.'
' Got "%s" instead.' % type(he))
u_is_all = is_all(u) u_is_all = is_all(u)
v_is_all = is_all(v) v_is_all = is_all(v)
assert u_is_all == v_is_all assert u_is_all == v_is_all
if u_is_all: if u_is_all:
num_edges = self.cached_graph.num_edges() self.set_e_repr_by_id(he, eid=ALL, inplace=inplace)
else: else:
u = utils.toindex(u) u = utils.toindex(u)
v = utils.toindex(v) v = utils.toindex(v)
num_edges = max(len(u), len(v)) _, _, eid = self._graph.edge_ids(u, v)
if utils.is_dict_like(h_uv): self.set_e_repr_by_id(he, eid=eid, inplace=inplace)
for key, val in h_uv.items():
assert F.shape(val)[0] == num_edges
else:
assert F.shape(h_uv)[0] == num_edges
# set
if u_is_all:
if utils.is_dict_like(h_uv):
for key, val in h_uv.items():
self._edge_frame[key] = val
else:
self._edge_frame[__REPR__] = h_uv
else:
eid = self.cached_graph.get_edge_id(u, v)
if utils.is_dict_like(h_uv):
self._edge_frame[eid] = h_uv
else:
self._edge_frame[eid] = {__REPR__ : h_uv}
def set_e_repr_by_id(self, h_uv, eid=ALL): def set_e_repr_by_id(self, he, eid=ALL, inplace=False):
"""Set edge(s) representation by edge id. """Set edge(s) representation by edge id.
`he` is a dictionary from the feature name to feature tensor. Each tensor
is of shape (B, D1, D2, ...), where B is the number of edges to be updated,
and (D1, D2, ...) be the shape of the edge representation tensor.
All update will be done out-placely to work with autograd unless the inplace
flag is true.
Parameters Parameters
---------- ----------
h_uv : tensor or dict of tensor he : tensor or dict of tensor
Edge representation. Edge representation.
eid : int, container or tensor eid : int, container or tensor
The edge id(s). The edge id(s).
inplace : bool
True if the update is done inplacely
""" """
# sanity check # sanity check
if not utils.is_dict_like(he):
raise DGLError('Expect dictionary type for feature data.'
' Got "%s" instead.' % type(he))
if is_all(eid): if is_all(eid):
num_edges = self.cached_graph.num_edges() num_edges = self.number_of_edges()
else: else:
eid = utils.toindex(eid) eid = utils.toindex(eid)
num_edges = len(eid) num_edges = len(eid)
if utils.is_dict_like(h_uv): for key, val in he.items():
for key, val in h_uv.items(): nfeats = F.shape(val)[0]
assert F.shape(val)[0] == num_edges if nfeats != num_edges:
else: raise DGLError('Expect number of features to match number of edges.'
assert F.shape(h_uv)[0] == num_edges ' Got %d and %d instead.' % (nfeats, num_edges))
# set # set
if is_all(eid): if is_all(eid):
if utils.is_dict_like(h_uv): # update column
for key, val in h_uv.items(): for key, val in he.items():
self._edge_frame[key] = val self._edge_frame[key] = val
else: else:
self._edge_frame[__REPR__] = h_uv # update row
else: self._edge_frame.update_rows(eid, he, inplace=inplace)
if utils.is_dict_like(h_uv):
self._edge_frame[eid] = h_uv
else:
self._edge_frame[eid] = {__REPR__ : h_uv}
def get_e_repr(self, u=ALL, v=ALL): def get_e_repr(self, u=ALL, v=ALL):
"""Get node(s) representation. """Get node(s) representation.
...@@ -232,31 +714,37 @@ class DGLGraph(DiGraph): ...@@ -232,31 +714,37 @@ class DGLGraph(DiGraph):
The source node(s). The source node(s).
v : node, container or tensor v : node, container or tensor
The destination node(s). The destination node(s).
Returns
-------
dict
Representation dict
""" """
u_is_all = is_all(u) u_is_all = is_all(u)
v_is_all = is_all(v) v_is_all = is_all(v)
assert u_is_all == v_is_all assert u_is_all == v_is_all
if len(self.edge_attr_schemes()) == 0:
return dict()
if u_is_all: if u_is_all:
if len(self._edge_frame) == 1 and __REPR__ in self._edge_frame: return self.get_e_repr_by_id(eid=ALL)
return self._edge_frame[__REPR__]
else:
return dict(self._edge_frame)
else: else:
u = utils.toindex(u) u = utils.toindex(u)
v = utils.toindex(v) v = utils.toindex(v)
eid = self.cached_graph.get_edge_id(u, v) _, _, eid = self._graph.edge_ids(u, v)
if len(self._edge_frame) == 1 and __REPR__ in self._edge_frame: return self.get_e_repr_by_id(eid=eid)
return self._edge_frame.select_rows(eid)[__REPR__]
else:
return self._edge_frame.select_rows(eid)
def pop_e_repr(self, key=__REPR__): def pop_e_repr(self, key):
"""Get and remove the specified edge repr. """Get and remove the specified edge repr.
Parameters Parameters
---------- ----------
key : str key : str
The attribute name. The attribute name.
Returns
-------
Tensor
The popped representation
""" """
return self._edge_frame.pop(key) return self._edge_frame.pop(key)
...@@ -267,150 +755,142 @@ class DGLGraph(DiGraph): ...@@ -267,150 +755,142 @@ class DGLGraph(DiGraph):
---------- ----------
eid : int, container or tensor eid : int, container or tensor
The edge id(s). The edge id(s).
Returns
-------
dict
Representation dict from feature name to feature tensor.
""" """
if len(self.edge_attr_schemes()) == 0:
return dict()
if is_all(eid): if is_all(eid):
if len(self._edge_frame) == 1 and __REPR__ in self._edge_frame:
return self._edge_frame[__REPR__]
else:
return dict(self._edge_frame) return dict(self._edge_frame)
else: else:
eid = utils.toindex(eid) eid = utils.toindex(eid)
if len(self._edge_frame) == 1 and __REPR__ in self._edge_frame:
return self._edge_frame.select_rows(eid)[__REPR__]
else:
return self._edge_frame.select_rows(eid) return self._edge_frame.select_rows(eid)
def register_edge_func(self, def register_edge_func(self, edge_func):
edge_func,
batchable=False):
"""Register global edge update function. """Register global edge update function.
Parameters Parameters
---------- ----------
edge_func : callable edge_func : callable
Message function on the edge. Message function on the edge.
batchable : bool
Whether the provided message function allows batch computing.
""" """
self._edge_func = (edge_func, batchable) self._edge_func = edge_func
def register_message_func(self, def register_message_func(self, message_func):
message_func,
batchable=False):
"""Register global message function. """Register global message function.
Parameters Parameters
---------- ----------
message_func : callable message_func : callable
Message function on the edge. Message function on the edge.
batchable : bool
Whether the provided message function allows batch computing.
""" """
self._message_func = (message_func, batchable) self._message_func = message_func
def register_reduce_func(self, def register_reduce_func(self, reduce_func):
reduce_func,
batchable=False):
"""Register global message reduce function. """Register global message reduce function.
Parameters Parameters
---------- ----------
reduce_func : str or callable reduce_func : str or callable
Reduce function on incoming edges. Reduce function on incoming edges.
batchable : bool
Whether the provided reduce function allows batch computing.
""" """
self._reduce_func = (reduce_func, batchable) self._reduce_func = reduce_func
def register_apply_node_func(self, def register_apply_node_func(self, apply_node_func):
apply_node_func,
batchable=False):
"""Register global node apply function. """Register global node apply function.
Parameters Parameters
---------- ----------
apply_node_func : callable apply_node_func : callable
Apply function on the node. Apply function on the node.
batchable : bool
Whether the provided function allows batch computing.
""" """
self._apply_node_func = (apply_node_func, batchable) self._apply_node_func = apply_node_func
def register_apply_edge_func(self, def register_apply_edge_func(self, apply_edge_func):
apply_edge_func,
batchable=False):
"""Register global edge apply function. """Register global edge apply function.
Parameters Parameters
---------- ----------
apply_edge_func : callable apply_edge_func : callable
Apply function on the edge. Apply function on the edge.
batchable : bool
Whether the provided function allows batch computing.
""" """
self._apply_edge_func = (apply_edge_func, batchable) self._apply_edge_func = apply_edge_func
def apply_nodes(self, v, apply_node_func="default", batchable=False): def apply_nodes(self, v=ALL, apply_node_func="default"):
"""Apply the function on node representations. """Apply the function on node representations.
Applying a None function will be ignored.
Parameters Parameters
---------- ----------
v : int, iterable of int, tensor v : int, iterable of int, tensor, optional
The node id(s). The node id(s).
apply_node_func : callable apply_node_func : callable
The apply node function. The apply node function.
batchable : bool """
Whether the provided function allows batch computing. self._apply_nodes(v, apply_node_func)
def _apply_nodes(self, v, apply_node_func="default", reduce_accum=None):
"""Internal apply nodes
Parameters
----------
reduce_accum: dict-like
The output of reduce func
""" """
if apply_node_func == "default": if apply_node_func == "default":
apply_node_func, batchable = self._apply_node_func apply_node_func = self._apply_node_func
if not apply_node_func: if not apply_node_func:
# Skip none function call. # Skip none function call.
if reduce_accum is not None:
# write reduce result back
self.set_n_repr(reduce_accum, v)
return return
if batchable: # take out current node repr
new_repr = apply_node_func(self.get_n_repr(v)) curr_repr = self.get_n_repr(v)
if reduce_accum is not None:
# merge current node_repr with reduce output
curr_repr = utils.HybridDict(reduce_accum, curr_repr)
new_repr = apply_node_func(curr_repr)
if reduce_accum is not None:
# merge new node_repr with reduce output
reduce_accum.update(new_repr)
new_repr = reduce_accum
self.set_n_repr(new_repr, v) self.set_n_repr(new_repr, v)
else:
if is_all(v):
v = self.nodes()
v = utils.toindex(v)
for vv in utils.node_iter(v):
ret = apply_node_func(_get_repr(self.nodes[vv]))
_set_repr(self.nodes[vv], ret)
def apply_edges(self, u, v, apply_edge_func="default", batchable=False): def apply_edges(self, u=None, v=None, apply_edge_func="default", eid=None):
"""Apply the function on edge representations. """Apply the function on edge representations.
Applying a None function will be ignored.
Parameters Parameters
---------- ----------
u : int, iterable of int, tensor u : optional, int, iterable of int, tensor
The src node id(s). The src node id(s).
v : int, iterable of int, tensor v : optional, int, iterable of int, tensor
The dst node id(s). The dst node id(s).
apply_edge_func : callable apply_edge_func : callable
The apply edge function. The apply edge function.
batchable : bool eid : None, edge, container or tensor
Whether the provided function allows batch computing. The edge to update on. If eid is not None then u and v are ignored.
""" """
if apply_edge_func == "default": if apply_edge_func == "default":
apply_edge_func, batchable = self._apply_edge_func apply_edge_func = self._apply_edge_func
if not apply_edge_func: if not apply_edge_func:
# Skip none function call. # Skip none function call.
return return
if batchable: if eid is None:
new_repr = apply_edge_func(self.get_e_repr(u, v)) new_repr = apply_edge_func(self.get_e_repr(u, v))
self.set_e_repr(new_repr, u, v) self.set_e_repr(new_repr, u, v)
else: else:
if is_all(u) == is_all(v): new_repr = apply_edge_func(self.get_e_repr_by_id(eid))
u, v = zip(*self.edges) self.set_e_repr_by_id(new_repr, eid)
u = utils.toindex(u)
v = utils.toindex(v)
for uu, vv in utils.edge_iter(u, v):
ret = apply_edge_func(_get_repr(self.edges[uu, vv]))
_set_repr(self.edges[uu, vv], ret)
def send(self, u, v, message_func="default", batchable=False): def send(self, u=None, v=None, message_func="default", eid=None):
"""Trigger the message function on edge u->v """Trigger the message function on edge u->v or eid
The message function should be compatible with following signature: The message function should be compatible with following signature:
...@@ -422,62 +902,108 @@ class DGLGraph(DiGraph): ...@@ -422,62 +902,108 @@ class DGLGraph(DiGraph):
The message function can be any of the pre-defined functions The message function can be any of the pre-defined functions
('from_src'). ('from_src').
Currently, we require the message functions of consecutive send's to
return the same keys. Otherwise the behavior will be undefined.
Parameters Parameters
---------- ----------
u : node, container or tensor u : optional, node, container or tensor
The source node(s). The source node(s).
v : node, container or tensor v : optional, node, container or tensor
The destination node(s). The destination node(s).
message_func : callable message_func : callable
The message function. The message function.
batchable : bool eid : optional, edge, container or tensor
Whether the function allows batched computation. The edge to update on. If eid is not None then u and v are ignored.
Notes
-----
On multigraphs, if u and v are specified, then the messages will be sent
along all edges between u and v.
""" """
if message_func == "default": if message_func == "default":
message_func, batchable = self._message_func message_func = self._message_func
assert message_func is not None assert message_func is not None
if isinstance(message_func, (tuple, list)): if isinstance(message_func, (tuple, list)):
message_func = BundledMessageFunction(message_func) message_func = BundledMessageFunction(message_func)
if batchable: self._batch_send(u, v, eid, message_func)
self._batch_send(u, v, message_func)
else:
self._nonbatch_send(u, v, message_func)
def _nonbatch_send(self, u, v, message_func): def _batch_send(self, u, v, eid, message_func):
if is_all(u) and is_all(v): if is_all(u) and is_all(v) and eid is None:
u, v = self.cached_graph.edges() u, v, eid = self._graph.edges()
else:
u = utils.toindex(u)
v = utils.toindex(v)
for uu, vv in utils.edge_iter(u, v):
ret = message_func(_get_repr(self.nodes[uu]),
_get_repr(self.edges[uu, vv]))
self.edges[uu, vv][__MSG__] = ret
def _batch_send(self, u, v, message_func):
if is_all(u) and is_all(v):
u, v = self.cached_graph.edges()
self.msg_graph.add_edges(u, v)
# call UDF # call UDF
src_reprs = self.get_n_repr(u) src_reprs = self.get_n_repr(u)
edge_reprs = self.get_e_repr() edge_reprs = self.get_e_repr()
msgs = message_func(src_reprs, edge_reprs) msgs = message_func(src_reprs, edge_reprs)
elif eid is not None:
eid = utils.toindex(eid)
u, v, _ = self._graph.find_edges(eid)
# call UDF
src_reprs = self.get_n_repr(u)
edge_reprs = self.get_e_repr_by_id(eid)
msgs = message_func(src_reprs, edge_reprs)
else: else:
u = utils.toindex(u) u = utils.toindex(u)
v = utils.toindex(v) v = utils.toindex(v)
u, v = utils.edge_broadcasting(u, v) u, v, eid = self._graph.edge_ids(u, v)
eid = self.cached_graph.get_edge_id(u, v)
self.msg_graph.add_edges(u, v)
# call UDF # call UDF
src_reprs = self.get_n_repr(u) src_reprs = self.get_n_repr(u)
edge_reprs = self.get_e_repr_by_id(eid) edge_reprs = self.get_e_repr_by_id(eid)
msgs = message_func(src_reprs, edge_reprs) msgs = message_func(src_reprs, edge_reprs)
if utils.is_dict_like(msgs): self._msg_graph.add_edges(u, v)
self._msg_frame.append(msgs) self._msg_frame.append(msgs)
# TODO(minjie): Fix these codes in next PR.
"""
new_uv = []
msg_target_rows = []
msg_update_rows = []
msg_append_rows = []
for i, (_u, _v, _eid) in enumerate(zip(u, v, eid)):
if _eid in self._msg_edges:
msg_target_rows.append(self._msg_edges.index(_eid))
msg_update_rows.append(i)
else:
new_uv.append((_u, _v))
self._msg_edges.append(_eid)
msg_append_rows.append(i)
msg_target_rows = utils.toindex(msg_target_rows)
msg_update_rows = utils.toindex(msg_update_rows)
msg_append_rows = utils.toindex(msg_append_rows)
if utils.is_dict_like(msgs):
if len(msg_target_rows) > 0:
self._msg_frame.update_rows(
msg_target_rows,
{k: F.gather_row(msgs[k], msg_update_rows.tousertensor())
for k in msgs},
inplace=False)
if len(msg_append_rows) > 0:
new_u, new_v = zip(*new_uv)
new_u = utils.toindex(new_u)
new_v = utils.toindex(new_v)
self._msg_graph.add_edges(new_u, new_v)
self._msg_frame.append(
{k: F.gather_row(msgs[k], msg_append_rows.tousertensor())
for k in msgs})
else: else:
self._msg_frame.append({__MSG__ : msgs}) if len(msg_target_rows) > 0:
self._msg_frame.update_rows(
msg_target_rows,
{__MSG__: F.gather_row(msgs, msg_update_rows.tousertensor())},
inplace=False)
if len(msg_append_rows) > 0:
new_u, new_v = zip(*new_uv)
new_u = utils.toindex(new_u)
new_v = utils.toindex(new_v)
self._msg_graph.add_edges(new_u, new_v)
self._msg_frame.append(
{__MSG__: F.gather_row(msgs, msg_append_rows.tousertensor())}
)
"""
def update_edge(self, u=ALL, v=ALL, edge_func="default", batchable=False): def update_edge(self, u=ALL, v=ALL, edge_func="default", eid=None):
"""Update representation on edge u->v """Update representation on edge u->v
The edge function should be compatible with following signature: The edge function should be compatible with following signature:
...@@ -496,32 +1022,17 @@ class DGLGraph(DiGraph): ...@@ -496,32 +1022,17 @@ class DGLGraph(DiGraph):
The destination node(s). The destination node(s).
edge_func : callable edge_func : callable
The update function. The update function.
batchable : bool eid : optional, edge, container or tensor
Whether the function allows batched computation. The edge to update on. If eid is not None then u and v are ignored.
""" """
if edge_func == "default": if edge_func == "default":
edge_func, batchable = self._edge_func edge_func = self._edge_func
assert edge_func is not None assert edge_func is not None
if batchable: self._batch_update_edge(u, v, eid, edge_func)
self._batch_update_edge(u, v, edge_func)
else:
self._nonbatch_update_edge(u, v, edge_func)
def _nonbatch_update_edge(self, u, v, edge_func): def _batch_update_edge(self, u, v, eid, edge_func):
if is_all(u) and is_all(v): if is_all(u) and is_all(v) and eid is None:
u, v = self.cached_graph.edges() u, v, eid = self._graph.edges()
else:
u = utils.toindex(u)
v = utils.toindex(v)
for uu, vv in utils.edge_iter(u, v):
ret = edge_func(_get_repr(self.nodes[uu]),
_get_repr(self.nodes[vv]),
_get_repr(self.edges[uu, vv]))
_set_repr(self.edges[uu, vv], ret)
def _batch_update_edge(self, u, v, edge_func):
if is_all(u) and is_all(v):
u, v = self.cached_graph.edges()
# call the UDF # call the UDF
src_reprs = self.get_n_repr(u) src_reprs = self.get_n_repr(u)
dst_reprs = self.get_n_repr(v) dst_reprs = self.get_n_repr(v)
...@@ -529,10 +1040,11 @@ class DGLGraph(DiGraph): ...@@ -529,10 +1040,11 @@ class DGLGraph(DiGraph):
new_edge_reprs = edge_func(src_reprs, dst_reprs, edge_reprs) new_edge_reprs = edge_func(src_reprs, dst_reprs, edge_reprs)
self.set_e_repr(new_edge_reprs) self.set_e_repr(new_edge_reprs)
else: else:
if eid is None:
u = utils.toindex(u) u = utils.toindex(u)
v = utils.toindex(v) v = utils.toindex(v)
u, v = utils.edge_broadcasting(u, v) u, v = utils.edge_broadcasting(u, v)
eid = self.cached_graph.get_edge_id(u, v) _, _, eid = self._graph.edge_ids(u, v)
# call the UDF # call the UDF
src_reprs = self.get_n_repr(u) src_reprs = self.get_n_repr(u)
dst_reprs = self.get_n_repr(v) dst_reprs = self.get_n_repr(v)
...@@ -543,8 +1055,7 @@ class DGLGraph(DiGraph): ...@@ -543,8 +1055,7 @@ class DGLGraph(DiGraph):
def recv(self, def recv(self,
u, u,
reduce_func="default", reduce_func="default",
apply_node_func="default", apply_node_func="default"):
batchable=False):
"""Receive and reduce in-coming messages and update representation on node u. """Receive and reduce in-coming messages and update representation on node u.
It computes the new node state using the messages sent from the predecessors It computes the new node state using the messages sent from the predecessors
...@@ -574,33 +1085,15 @@ class DGLGraph(DiGraph): ...@@ -574,33 +1085,15 @@ class DGLGraph(DiGraph):
The reduce function. The reduce function.
apply_node_func : callable, optional apply_node_func : callable, optional
The update function. The update function.
batchable : bool, optional
Whether the reduce and update function allows batched computation.
""" """
if reduce_func == "default": if reduce_func == "default":
reduce_func, batchable = self._reduce_func reduce_func = self._reduce_func
assert reduce_func is not None assert reduce_func is not None
if isinstance(reduce_func, (list, tuple)): if isinstance(reduce_func, (list, tuple)):
reduce_func = BundledReduceFunction(reduce_func) reduce_func = BundledReduceFunction(reduce_func)
if batchable:
self._batch_recv(u, reduce_func) self._batch_recv(u, reduce_func)
else:
self._nonbatch_recv(u, reduce_func)
# optional apply nodes # optional apply nodes
self.apply_nodes(u, apply_node_func, batchable) self.apply_nodes(u, apply_node_func)
def _nonbatch_recv(self, u, reduce_func):
if is_all(u):
u = list(range(0, self.number_of_nodes()))
else:
u = utils.toindex(u)
for i, uu in enumerate(utils.node_iter(u)):
# reduce phase
msgs_batch = [self.edges[vv, uu].pop(__MSG__)
for vv in self.pred[uu] if __MSG__ in self.edges[vv, uu]]
if len(msgs_batch) != 0:
new_repr = reduce_func(_get_repr(self.nodes[uu]), msgs_batch)
_set_repr(self.nodes[uu], new_repr)
def _batch_recv(self, v, reduce_func): def _batch_recv(self, v, reduce_func):
if self._msg_frame.num_rows == 0: if self._msg_frame.num_rows == 0:
...@@ -616,7 +1109,7 @@ class DGLGraph(DiGraph): ...@@ -616,7 +1109,7 @@ class DGLGraph(DiGraph):
v = utils.toindex(v) v = utils.toindex(v)
# degree bucketing # degree bucketing
degrees, v_buckets = scheduler.degree_bucketing(self.msg_graph, v) degrees, v_buckets = scheduler.degree_bucketing(self._msg_graph, v)
if degrees == [0]: if degrees == [0]:
# no message has been sent to the specified node # no message has been sent to the specified node
return return
...@@ -631,33 +1124,26 @@ class DGLGraph(DiGraph): ...@@ -631,33 +1124,26 @@ class DGLGraph(DiGraph):
continue continue
bkt_len = len(v_bkt) bkt_len = len(v_bkt)
dst_reprs = self.get_n_repr(v_bkt) dst_reprs = self.get_n_repr(v_bkt)
uu, vv, _ = self.msg_graph.in_edges(v_bkt) uu, vv, in_msg_ids = self._msg_graph.in_edges(v_bkt)
in_msg_ids = self.msg_graph.get_edge_id(uu, vv)
in_msgs = self._msg_frame.select_rows(in_msg_ids) in_msgs = self._msg_frame.select_rows(in_msg_ids)
# Reshape the column tensor to (B, Deg, ...). # Reshape the column tensor to (B, Deg, ...).
def _reshape_fn(msg): def _reshape_fn(msg):
msg_shape = F.shape(msg) msg_shape = F.shape(msg)
new_shape = (bkt_len, deg) + msg_shape[1:] new_shape = (bkt_len, deg) + msg_shape[1:]
return F.reshape(msg, new_shape) return F.reshape(msg, new_shape)
if len(in_msgs) == 1 and __MSG__ in in_msgs:
reshaped_in_msgs = _reshape_fn(in_msgs[__MSG__])
else:
reshaped_in_msgs = utils.LazyDict( reshaped_in_msgs = utils.LazyDict(
lambda key: _reshape_fn(in_msgs[key]), self._msg_frame.schemes) lambda key: _reshape_fn(in_msgs[key]), self._msg_frame.schemes)
reordered_v.append(v_bkt.totensor()) reordered_v.append(v_bkt.tousertensor())
new_reprs.append(reduce_func(dst_reprs, reshaped_in_msgs)) new_reprs.append(reduce_func(dst_reprs, reshaped_in_msgs))
# TODO: clear partial messages # TODO(minjie): clear partial messages
self.clear_messages() self.reset_messages()
# Pack all reducer results together # Pack all reducer results together
reordered_v = F.pack(reordered_v) reordered_v = F.pack(reordered_v)
if utils.is_dict_like(new_reprs[0]):
keys = new_reprs[0].keys() keys = new_reprs[0].keys()
new_reprs = {key : F.pack([repr[key] for repr in new_reprs]) new_reprs = {key : F.pack([repr[key] for repr in new_reprs])
for key in keys} for key in keys}
else:
new_reprs = {__REPR__ : F.pack(new_reprs)}
if v_is_all and not has_zero_degree: if v_is_all and not has_zero_degree:
# First do reorder and then replace the whole column. # First do reorder and then replace the whole column.
...@@ -670,18 +1156,19 @@ class DGLGraph(DiGraph): ...@@ -670,18 +1156,19 @@ class DGLGraph(DiGraph):
self.set_n_repr(new_reprs, reordered_v) self.set_n_repr(new_reprs, reordered_v)
def send_and_recv(self, def send_and_recv(self,
u, v, u=None, v=None,
message_func="default", message_func="default",
reduce_func="default", reduce_func="default",
apply_node_func="default", apply_node_func="default",
batchable=False): eid=None):
"""Trigger the message function on u->v and update v. """Trigger the message function on u->v and update v, or on edge eid
and update the destination nodes.
Parameters Parameters
---------- ----------
u : node, container or tensor u : optional, node, container or tensor
The source node(s). The source node(s).
v : node, container or tensor v : optional, node, container or tensor
The destination node(s). The destination node(s).
message_func : callable message_func : callable
The message function. The message function.
...@@ -689,45 +1176,83 @@ class DGLGraph(DiGraph): ...@@ -689,45 +1176,83 @@ class DGLGraph(DiGraph):
The reduce function. The reduce function.
apply_node_func : callable, optional apply_node_func : callable, optional
The update function. The update function.
batchable : bool
Whether the reduce and update function allows batched computation. Notes
-----
On multigraphs, if u and v are specified, then the messages will be sent
and received along all edges between u and v.
""" """
if message_func == "default":
message_func = self._message_func
if reduce_func == "default":
reduce_func = self._reduce_func
assert message_func is not None
assert reduce_func is not None
if eid is None:
if u is None or v is None:
raise ValueError('u and v must be given if eid is None')
u = utils.toindex(u) u = utils.toindex(u)
v = utils.toindex(v) v = utils.toindex(v)
if len(u) == 0: if len(u) == 0:
# no edges to be triggered # no edges to be triggered
assert len(v) == 0 assert len(v) == 0
return return
unique_v = utils.toindex(F.unique(v.totensor())) unique_v = utils.toindex(F.unique(v.tousertensor()))
# TODO(minjie): better way to figure out `batchable` flag
if message_func == "default":
message_func, batchable = self._message_func
if reduce_func == "default":
reduce_func, _ = self._reduce_func
assert message_func is not None
assert reduce_func is not None
if batchable:
executor = scheduler.get_executor( executor = scheduler.get_executor(
'send_and_recv', self, src=u, dst=v, 'send_and_recv', self, src=u, dst=v,
message_func=message_func, reduce_func=reduce_func) message_func=message_func, reduce_func=reduce_func)
else: else:
eid = utils.toindex(eid)
if len(eid) == 0:
# no edges to be triggered
return
executor = None executor = None
if executor: if executor:
executor.run() new_reprs = executor.run()
unique_v = executor.recv_nodes
self._apply_nodes(unique_v, apply_node_func, reduce_accum=new_reprs)
elif eid is not None:
_, v, _ = self._graph.find_edges(eid)
unique_v = utils.toindex(F.unique(v.tousertensor()))
# TODO(quan): replace with the new DegreeBucketingScheduler
self.send(eid=eid, message_func=message_func)
self.recv(unique_v, reduce_func, apply_node_func)
else: else:
self.send(u, v, message_func, batchable=batchable) # handle multiple message and reduce func
self.recv(unique_v, reduce_func, None, batchable=batchable) if isinstance(message_func, (tuple, list)):
self.apply_nodes(unique_v, apply_node_func, batchable=batchable) message_func = BundledMessageFunction(message_func)
if isinstance(reduce_func, (list, tuple)):
reduce_func = BundledReduceFunction(reduce_func)
# message func
u, v = utils.edge_broadcasting(u, v)
src_reprs = self.get_n_repr(u)
edge_reprs = self.get_e_repr(u, v)
msgs = message_func(src_reprs, edge_reprs)
msg_frame = FrameRef()
msg_frame.append(msgs)
# recv with degree bucketing
executor = scheduler.get_recv_executor(graph=self,
reduce_func=reduce_func,
message_frame=msg_frame,
edges=(u, v))
new_reprs = executor.run()
unique_v = executor.recv_nodes
self._apply_nodes(unique_v, apply_node_func, reduce_accum=new_reprs)
def pull(self, def pull(self,
v, v,
message_func="default", message_func="default",
reduce_func="default", reduce_func="default",
apply_node_func="default", apply_node_func="default"):
batchable=False):
"""Pull messages from the node's predecessors and then update it. """Pull messages from the node's predecessors and then update it.
Parameters Parameters
...@@ -740,24 +1265,20 @@ class DGLGraph(DiGraph): ...@@ -740,24 +1265,20 @@ class DGLGraph(DiGraph):
The reduce function. The reduce function.
apply_node_func : callable, optional apply_node_func : callable, optional
The update function. The update function.
batchable : bool
Whether the reduce and update function allows batched computation.
""" """
v = utils.toindex(v) v = utils.toindex(v)
if len(v) == 0: if len(v) == 0:
return return
uu, vv, _ = self.cached_graph.in_edges(v) uu, vv, _ = self._graph.in_edges(v)
self.send_and_recv(uu, vv, message_func, reduce_func, self.send_and_recv(uu, vv, message_func, reduce_func, apply_node_func=None)
apply_node_func=None, batchable=batchable) unique_v = F.unique(v.tousertensor())
unique_v = F.unique(v.totensor()) self.apply_nodes(unique_v, apply_node_func)
self.apply_nodes(unique_v, apply_node_func, batchable=batchable)
def push(self, def push(self,
u, u,
message_func="default", message_func="default",
reduce_func="default", reduce_func="default",
apply_node_func="default", apply_node_func="default"):
batchable=False):
"""Send message from the node to its successors and update them. """Send message from the node to its successors and update them.
Parameters Parameters
...@@ -770,21 +1291,18 @@ class DGLGraph(DiGraph): ...@@ -770,21 +1291,18 @@ class DGLGraph(DiGraph):
The reduce function. The reduce function.
apply_node_func : callable apply_node_func : callable
The update function. The update function.
batchable : bool
Whether the reduce and update function allows batched computation.
""" """
u = utils.toindex(u) u = utils.toindex(u)
if len(u) == 0: if len(u) == 0:
return return
uu, vv, _ = self.cached_graph.out_edges(u) uu, vv, _ = self._graph.out_edges(u)
self.send_and_recv(uu, vv, message_func, self.send_and_recv(uu, vv, message_func,
reduce_func, apply_node_func, batchable=batchable) reduce_func, apply_node_func)
def update_all(self, def update_all(self,
message_func="default", message_func="default",
reduce_func="default", reduce_func="default",
apply_node_func="default", apply_node_func="default"):
batchable=False):
"""Send messages through all the edges and update all nodes. """Send messages through all the edges and update all nodes.
Parameters Parameters
...@@ -795,76 +1313,61 @@ class DGLGraph(DiGraph): ...@@ -795,76 +1313,61 @@ class DGLGraph(DiGraph):
The reduce function. The reduce function.
apply_node_func : callable, optional apply_node_func : callable, optional
The update function. The update function.
batchable : bool
Whether the reduce and update function allows batched computation.
""" """
if message_func == "default": if message_func == "default":
message_func, batchable = self._message_func message_func = self._message_func
if reduce_func == "default": if reduce_func == "default":
reduce_func, _ = self._reduce_func reduce_func = self._reduce_func
assert message_func is not None assert message_func is not None
assert reduce_func is not None assert reduce_func is not None
if batchable:
executor = scheduler.get_executor( executor = scheduler.get_executor(
"update_all", self, message_func=message_func, reduce_func=reduce_func) "update_all", self, message_func=message_func, reduce_func=reduce_func)
else:
executor = None
if executor: if executor:
executor.run() new_reprs = executor.run()
self._apply_nodes(ALL, apply_node_func, reduce_accum=new_reprs)
else: else:
self.send(ALL, ALL, message_func, batchable=batchable) self.send(ALL, ALL, message_func)
self.recv(ALL, reduce_func, None, batchable=batchable) self.recv(ALL, reduce_func, apply_node_func)
self.apply_nodes(ALL, apply_node_func, batchable=batchable)
def propagate(self, def propagate(self,
iterator='bfs', traverser='topo',
message_func="default", message_func="default",
reduce_func="default", reduce_func="default",
apply_node_func="default", apply_node_func="default",
batchable=False,
**kwargs): **kwargs):
"""Propagate messages and update nodes using iterator. """Propagate messages and update nodes using graph traversal.
A convenient function for passing messages and updating A convenient function for passing messages and updating
nodes according to the iterator. The iterator can be nodes according to the traverser. The traverser can be
any of the pre-defined iterators ('bfs', 'dfs', 'pre-order', any of the pre-defined traverser (e.g. 'topo'). User can also provide custom
'mid-order', 'post-order'). The computation will be unrolled traverser that generates the edges and nodes.
in the backend efficiently. User can also provide custom
iterator that generates the edges and nodes.
Parameters Parameters
---------- ----------
traverser : str or generator of edges.
The traverser of the graph.
message_func : str or callable message_func : str or callable
The message function. The message function.
reduce_func : str or callable reduce_func : str or callable
The reduce function. The reduce function.
apply_node_func : str or callable apply_node_func : str or callable
The update function. The update function.
batchable : bool
Whether the reduce and update function allows batched computation.
iterator : str or generator of steps.
The iterator of the graph.
kwargs : keyword arguments, optional kwargs : keyword arguments, optional
Arguments for pre-defined iterators. Arguments for pre-defined iterators.
""" """
if isinstance(iterator, str): if isinstance(traverser, str):
# TODO Call pre-defined routine to unroll the computation. # TODO(minjie): Call pre-defined routine to unroll the computation.
raise RuntimeError('Not implemented.') raise RuntimeError('Not implemented.')
else: else:
# NOTE: the iteration can return multiple edges at each step. # NOTE: the iteration can return multiple edges at each step.
for u, v in iterator: for u, v in traverser:
self.send_and_recv(u, v, self.send_and_recv(u, v,
message_func, reduce_func, apply_node_func, batchable) message_func, reduce_func, apply_node_func)
def subgraph(self, nodes): def subgraph(self, nodes):
"""Generate the subgraph among the given nodes. """Generate the subgraph among the given nodes.
The generated graph contains only the graph structure. The node/edge
features are not shared implicitly. Use `copy_from` to get node/edge
features from parent graph.
Parameters Parameters
---------- ----------
nodes : list, or iterable nodes : list, or iterable
...@@ -875,7 +1378,26 @@ class DGLGraph(DiGraph): ...@@ -875,7 +1378,26 @@ class DGLGraph(DiGraph):
G : DGLSubGraph G : DGLSubGraph
The subgraph. The subgraph.
""" """
return dgl.DGLSubGraph(self, nodes) induced_nodes = utils.toindex(nodes)
sgi = self._graph.node_subgraph(induced_nodes)
return dgl.DGLSubGraph(self, sgi.induced_nodes, sgi.induced_edges, sgi)
def edge_subgraph(self, edges):
"""Generate the subgraph among the given edges.
Parameters
----------
edges : list, or iterable
A container of the edges to construct subgraph.
Returns
-------
G : DGLSubGraph
The subgraph.
"""
induced_edges = utils.toindex(edges)
sgi = self._graph.edge_subgraph(induced_edges)
return dgl.DGLSubGraph(self, sgi.induced_nodes, sgi.induced_edges, sgi)
def merge(self, subgraphs, reduce_func='sum'): def merge(self, subgraphs, reduce_func='sum'):
"""Merge subgraph features back to this parent graph. """Merge subgraph features back to this parent graph.
...@@ -919,91 +1441,109 @@ class DGLGraph(DiGraph): ...@@ -919,91 +1441,109 @@ class DGLGraph(DiGraph):
self._edge_frame.num_rows, self._edge_frame.num_rows,
reduce_func) reduce_func)
def draw(self): def adjacency_matrix(self, ctx=None):
"""Plot the graph using dot.""" """Return the adjacency matrix representation of this graph.
from networkx.drawing.nx_agraph import graphviz_layout
pos = graphviz_layout(self, prog='dot') Parameters
nx.draw(self, pos, with_labels=True) ----------
ctx : optional
The context of returned adjacency matrix.
@property Returns
def cached_graph(self): -------
# TODO: dirty flag when mutated sparse_tensor
if self._cached_graph is None: The adjacency matrix.
self._cached_graph = create_cached_graph(self) """
return self._cached_graph return self._graph.adjacency_matrix().get(ctx)
@property def incidence_matrix(self, oriented=False, ctx=None):
def msg_graph(self): """Return the incidence matrix representation of this graph.
# TODO: dirty flag when mutated
if self._msg_graph is None:
self._msg_graph = CachedGraph()
self._msg_graph.add_nodes(self.number_of_nodes())
return self._msg_graph
def clear_messages(self): Parameters
if self._msg_graph is not None: ----------
self._msg_graph = CachedGraph() oriented : bool, optional
self._msg_graph.add_nodes(self.number_of_nodes()) Whether the returned incidence matrix is oriented.
self._msg_frame.clear()
@property ctx : optional
def edge_list(self): The context of returned incidence matrix.
"""Return edges in the addition order."""
return self._edge_list
def get_edge_id(self, u, v): Returns
"""Return the continuous edge id(s) assigned. -------
sparse_tensor
The incidence matrix.
"""
return self._graph.incidence_matrix(oriented).get(ctx)
def line_graph(self, backtracking=True, shared=False):
"""Return the line graph of this graph.
Parameters Parameters
---------- ----------
u : node, container or tensor backtracking : bool, optional
The source node(s). Whether the returned line graph is backtracking.
v : node, container or tensor
The destination node(s). shared : bool, optional
Whether the returned line graph shares representations with `self`.
Returns Returns
------- -------
eid : tensor DGLGraph
The tensor contains edge id(s). The line graph of this graph.
""" """
u = utils.toindex(u) graph_data = self._graph.line_graph(backtracking)
v = utils.toindex(v) node_frame = self._edge_frame if shared else None
return self.cached_graph.get_edge_id(u, v) return DGLGraph(graph_data, node_frame)
def _add_node_callback(self, node): def filter_nodes(self, predicate, nodes=ALL):
#print('New node:', node) """Return a tensor of node IDs that satisfy the given predicate.
self._cached_graph = None
Parameters
def _del_node_callback(self, node): ----------
#print('Del node:', node) predicate : callable
raise RuntimeError('Node removal is not supported currently.') The predicate should take in a dict of tensors whose values
node = utils.convert_to_id_tensor(node) are concatenation of node representations by node ID (same as
self._node_frame.delete_rows(node) get_n_repr()), and return a boolean tensor with N elements
self._cached_graph = None indicating which node satisfy the predicate.
nodes : container or tensor
def _add_edge_callback(self, u, v): The nodes to filter on
#print('New edge:', u, v)
self._edge_list.append((u, v)) Returns
self._cached_graph = None -------
tensor
def _del_edge_callback(self, u, v): The filtered nodes
#print('Del edge:', u, v) """
raise RuntimeError('Edge removal is not supported currently.') n_repr = self.get_n_repr(nodes)
u = utils.convert_to_id_tensor(u) n_mask = predicate(n_repr)
v = utils.convert_to_id_tensor(v)
eid = self.get_edge_id(u, v) if is_all(nodes):
self._edge_frame.delete_rows(eid) return F.nonzero_1d(n_mask)
self._cached_graph = None
def _get_repr(attr_dict):
if len(attr_dict) == 1 and __REPR__ in attr_dict:
return attr_dict[__REPR__]
else: else:
return attr_dict nodes = F.Tensor(nodes)
return nodes[n_mask]
def filter_edges(self, predicate, edges=ALL):
"""Return a tensor of edge IDs that satisfy the given predicate.
Parameters
----------
predicate : callable
The predicate should take in a dict of tensors whose values
are concatenation of edge representations by edge ID (same as
get_e_repr_by_id()), and return a boolean tensor with N elements
indicating which node satisfy the predicate.
edges : container or tensor
The edges to filter on
Returns
-------
tensor
The filtered edges
"""
e_repr = self.get_e_repr_by_id(edges)
e_mask = predicate(e_repr)
def _set_repr(attr_dict, attr): if is_all(edges):
if utils.is_dict_like(attr): return F.nonzero_1d(e_mask)
attr_dict.update(attr)
else: else:
attr_dict[__REPR__] = attr edges = F.Tensor(edges)
return edges[e_mask]
from __future__ import absolute_import
import ctypes
import numpy as np
import networkx as nx
import scipy
from ._ffi.base import c_array
from ._ffi.function import _init_api
from . import backend as F
from . import utils
GraphIndexHandle = ctypes.c_void_p
class GraphIndex(object):
    """Graph index object.

    Thin Python wrapper around the C graph index; every method forwards to
    the corresponding ``_CAPI_DGLGraph*`` function.

    Parameters
    ----------
    handle : GraphIndexHandle
        Handle of the underlying C graph index object.
    """
    def __init__(self, handle):
        self._handle = handle
        # Cache of derived structures (adjacency/incidence matrices).
        # Invalidated on every mutation.
        self._cache = {}

    def __del__(self):
        """Free this graph index object."""
        _CAPI_DGLGraphFree(self._handle)

    def add_nodes(self, num):
        """Add nodes.

        Parameters
        ----------
        num : int
            Number of nodes to be added.
        """
        # NOTE: removed the stray C-style trailing semicolon.
        _CAPI_DGLGraphAddVertices(self._handle, num)
        self._cache.clear()

    def add_edge(self, u, v):
        """Add one edge.

        Parameters
        ----------
        u : int
            The src node.
        v : int
            The dst node.
        """
        _CAPI_DGLGraphAddEdge(self._handle, u, v)
        self._cache.clear()

    def add_edges(self, u, v):
        """Add many edges.

        Parameters
        ----------
        u : utils.Index
            The src nodes.
        v : utils.Index
            The dst nodes.
        """
        u_array = u.todgltensor()
        v_array = v.todgltensor()
        _CAPI_DGLGraphAddEdges(self._handle, u_array, v_array)
        self._cache.clear()

    def clear(self):
        """Clear the graph."""
        _CAPI_DGLGraphClear(self._handle)
        self._cache.clear()

    def is_multigraph(self):
        """Return whether the graph is a multigraph.

        Returns
        -------
        bool
            True if it is a multigraph, False otherwise.
        """
        return bool(_CAPI_DGLGraphIsMultigraph(self._handle))

    def number_of_nodes(self):
        """Return the number of nodes.

        Returns
        -------
        int
            The number of nodes
        """
        return _CAPI_DGLGraphNumVertices(self._handle)

    def number_of_edges(self):
        """Return the number of edges.

        Returns
        -------
        int
            The number of edges
        """
        return _CAPI_DGLGraphNumEdges(self._handle)

    def has_node(self, vid):
        """Return true if the node exists.

        Parameters
        ----------
        vid : int
            The node.

        Returns
        -------
        bool
            True if the node exists, False otherwise.
        """
        return bool(_CAPI_DGLGraphHasVertex(self._handle, vid))

    def has_nodes(self, vids):
        """Return true if the nodes exist.

        Parameters
        ----------
        vids : utils.Index
            The nodes

        Returns
        -------
        utils.Index
            0-1 array indicating existence
        """
        vid_array = vids.todgltensor()
        return utils.toindex(_CAPI_DGLGraphHasVertices(self._handle, vid_array))

    def has_edge_between(self, u, v):
        """Return true if the edge exists.

        Parameters
        ----------
        u : int
            The src node.
        v : int
            The dst node.

        Returns
        -------
        bool
            True if the edge exists, False otherwise
        """
        return bool(_CAPI_DGLGraphHasEdgeBetween(self._handle, u, v))

    def has_edges_between(self, u, v):
        """Return true if the edges exist.

        Parameters
        ----------
        u : utils.Index
            The src nodes.
        v : utils.Index
            The dst nodes.

        Returns
        -------
        utils.Index
            0-1 array indicating existence
        """
        u_array = u.todgltensor()
        v_array = v.todgltensor()
        return utils.toindex(_CAPI_DGLGraphHasEdgesBetween(self._handle, u_array, v_array))

    def predecessors(self, v, radius=1):
        """Return the predecessors of the node.

        Parameters
        ----------
        v : int
            The node.
        radius : int, optional
            The radius of the neighborhood.

        Returns
        -------
        utils.Index
            Array of predecessors
        """
        return utils.toindex(_CAPI_DGLGraphPredecessors(self._handle, v, radius))

    def successors(self, v, radius=1):
        """Return the successors of the node.

        Parameters
        ----------
        v : int
            The node.
        radius : int, optional
            The radius of the neighborhood.

        Returns
        -------
        utils.Index
            Array of successors
        """
        return utils.toindex(_CAPI_DGLGraphSuccessors(self._handle, v, radius))

    def edge_id(self, u, v):
        """Return the id array of all edges between u and v.

        Parameters
        ----------
        u : int
            The src node.
        v : int
            The dst node.

        Returns
        -------
        utils.Index
            The edge id array.
        """
        return utils.toindex(_CAPI_DGLGraphEdgeId(self._handle, u, v))

    def edge_ids(self, u, v):
        """Return a triplet of arrays that contains the edge IDs.

        Parameters
        ----------
        u : utils.Index
            The src nodes.
        v : utils.Index
            The dst nodes.

        Returns
        -------
        utils.Index
            The src nodes.
        utils.Index
            The dst nodes.
        utils.Index
            The edge ids.
        """
        u_array = u.todgltensor()
        v_array = v.todgltensor()
        edge_array = _CAPI_DGLGraphEdgeIds(self._handle, u_array, v_array)
        src = utils.toindex(edge_array(0))
        dst = utils.toindex(edge_array(1))
        eid = utils.toindex(edge_array(2))
        return src, dst, eid

    def find_edges(self, eid):
        """Return a triplet of arrays that contains the edge IDs.

        Parameters
        ----------
        eid : utils.Index
            The edge ids.

        Returns
        -------
        utils.Index
            The src nodes.
        utils.Index
            The dst nodes.
        utils.Index
            The edge ids.
        """
        eid_array = eid.todgltensor()
        edge_array = _CAPI_DGLGraphFindEdges(self._handle, eid_array)
        src = utils.toindex(edge_array(0))
        dst = utils.toindex(edge_array(1))
        eid = utils.toindex(edge_array(2))
        return src, dst, eid

    def in_edges(self, v):
        """Return the in edges of the node(s).

        Parameters
        ----------
        v : utils.Index
            The node(s).

        Returns
        -------
        utils.Index
            The src nodes.
        utils.Index
            The dst nodes.
        utils.Index
            The edge ids.
        """
        # Use the scalar CAPI for the single-node case to avoid building
        # a one-element tensor.
        if len(v) == 1:
            edge_array = _CAPI_DGLGraphInEdges_1(self._handle, v[0])
        else:
            v_array = v.todgltensor()
            edge_array = _CAPI_DGLGraphInEdges_2(self._handle, v_array)
        src = utils.toindex(edge_array(0))
        dst = utils.toindex(edge_array(1))
        eid = utils.toindex(edge_array(2))
        return src, dst, eid

    def out_edges(self, v):
        """Return the out edges of the node(s).

        Parameters
        ----------
        v : utils.Index
            The node(s).

        Returns
        -------
        utils.Index
            The src nodes.
        utils.Index
            The dst nodes.
        utils.Index
            The edge ids.
        """
        if len(v) == 1:
            edge_array = _CAPI_DGLGraphOutEdges_1(self._handle, v[0])
        else:
            v_array = v.todgltensor()
            edge_array = _CAPI_DGLGraphOutEdges_2(self._handle, v_array)
        src = utils.toindex(edge_array(0))
        dst = utils.toindex(edge_array(1))
        eid = utils.toindex(edge_array(2))
        return src, dst, eid

    def edges(self, sorted=False):
        """Return all the edges.

        Parameters
        ----------
        sorted : bool
            True if the returned edges are sorted by their src and dst ids.

        Returns
        -------
        utils.Index
            The src nodes.
        utils.Index
            The dst nodes.
        utils.Index
            The edge ids.
        """
        edge_array = _CAPI_DGLGraphEdges(self._handle, sorted)
        src = utils.toindex(edge_array(0))
        dst = utils.toindex(edge_array(1))
        eid = utils.toindex(edge_array(2))
        return src, dst, eid

    def in_degree(self, v):
        """Return the in degree of the node.

        Parameters
        ----------
        v : int
            The node.

        Returns
        -------
        int
            The in degree.
        """
        return _CAPI_DGLGraphInDegree(self._handle, v)

    def in_degrees(self, v):
        """Return the in degrees of the nodes.

        Parameters
        ----------
        v : utils.Index
            The nodes.

        Returns
        -------
        utils.Index
            The in degree array.
        """
        v_array = v.todgltensor()
        return utils.toindex(_CAPI_DGLGraphInDegrees(self._handle, v_array))

    def out_degree(self, v):
        """Return the out degree of the node.

        Parameters
        ----------
        v : int
            The node.

        Returns
        -------
        int
            The out degree.
        """
        return _CAPI_DGLGraphOutDegree(self._handle, v)

    def out_degrees(self, v):
        """Return the out degrees of the nodes.

        Parameters
        ----------
        v : utils.Index
            The nodes.

        Returns
        -------
        utils.Index
            The out degree array.
        """
        v_array = v.todgltensor()
        return utils.toindex(_CAPI_DGLGraphOutDegrees(self._handle, v_array))

    def node_subgraph(self, v):
        """Return the induced node subgraph.

        Parameters
        ----------
        v : utils.Index
            The nodes.

        Returns
        -------
        SubgraphIndex
            The subgraph index.
        """
        v_array = v.todgltensor()
        rst = _CAPI_DGLGraphVertexSubgraph(self._handle, v_array)
        induced_edges = utils.toindex(rst(2))
        return SubgraphIndex(rst(0), self, v, induced_edges)

    def edge_subgraph(self, e):
        """Return the induced edge subgraph.

        Parameters
        ----------
        e : utils.Index
            The edges.

        Returns
        -------
        SubgraphIndex
            The subgraph index.
        """
        e_array = e.todgltensor()
        rst = _CAPI_DGLGraphEdgeSubgraph(self._handle, e_array)
        # BUG FIX: the original also built a throwaway `GraphIndex(rst(0))`
        # whose __del__ would free the same handle handed to SubgraphIndex.
        induced_nodes = utils.toindex(rst(1))
        return SubgraphIndex(rst(0), self, induced_nodes, e)

    def adjacency_matrix(self):
        """Return the adjacency matrix representation of this graph.

        Returns
        -------
        utils.CtxCachedObject
            An object that returns tensor given context.
        """
        if 'adj' not in self._cache:
            src, dst, _ = self.edges(sorted=False)
            src = F.unsqueeze(src.tousertensor(), 0)
            dst = F.unsqueeze(dst.tousertensor(), 0)
            # Index rows are (dst, src): entry (i, j) corresponds to an
            # edge j -> i.
            idx = F.pack([dst, src])
            n = self.number_of_nodes()
            dat = F.ones((self.number_of_edges(),))
            mat = F.sparse_tensor(idx, dat, [n, n])
            self._cache['adj'] = utils.CtxCachedObject(lambda ctx: F.to_context(mat, ctx))
        return self._cache['adj']

    def incidence_matrix(self, oriented=False):
        """Return the incidence matrix representation of this graph.

        Parameters
        ----------
        oriented : bool, optional (default=False)
            Whether the returned incidence matrix is oriented.

        Returns
        -------
        utils.CtxCachedObject
            An object that returns tensor given context.
        """
        key = ('oriented ' if oriented else '') + 'incidence matrix'
        if key not in self._cache:
            src, dst, _ = self.edges(sorted=False)
            src = src.tousertensor()
            dst = dst.tousertensor()
            m = self.number_of_edges()
            eid = F.arange(m, dtype=F.int64)
            row = F.pack([src, dst])
            col = F.pack([eid, eid])
            idx = F.stack([row, col])
            # Self-loop entries are zeroed out below.
            diagonal = (src == dst)
            if oriented:
                x = -F.ones((m,))
                y = F.ones((m,))
                x[diagonal] = 0
                y[diagonal] = 0
                dat = F.pack([x, y])
            else:
                x = F.ones((m,))
                x[diagonal] = 0
                dat = F.pack([x, x])
            n = self.number_of_nodes()
            mat = F.sparse_tensor(idx, dat, [n, m])
            self._cache[key] = utils.CtxCachedObject(lambda ctx: F.to_context(mat, ctx))
        return self._cache[key]

    def to_networkx(self):
        """Convert to networkx graph.

        The edge id will be saved as the 'id' edge attribute.

        Returns
        -------
        networkx.DiGraph
            The nx graph
        """
        src, dst, eid = self.edges()
        ret = nx.MultiDiGraph() if self.is_multigraph() else nx.DiGraph()
        for u, v, id in zip(src, dst, eid):
            ret.add_edge(u, v, id=id)
        return ret

    def from_networkx(self, nx_graph):
        """Convert from networkx graph.

        If 'id' edge attribute exists, the edge will be added follows
        the edge id order. Otherwise, order is undefined.

        Parameters
        ----------
        nx_graph : networkx.DiGraph
            The nx graph
        """
        self.clear()
        if not isinstance(nx_graph, nx.Graph):
            nx_graph = (nx.MultiDiGraph(nx_graph) if self.is_multigraph()
                        else nx.DiGraph(nx_graph))
        else:
            nx_graph = nx_graph.to_directed()
        num_nodes = nx_graph.number_of_nodes()
        self.add_nodes(num_nodes)
        num_edges = nx_graph.number_of_edges()
        # BUG FIX: iterating `nx_graph.edges` yields edge *tuples*, so
        # `'id' in next(iter(nx_graph.edges))` inspected node ids instead of
        # the attribute dict (and raised StopIteration on edgeless graphs).
        # Peek at one attribute dict via the EdgeView mapping interface.
        if num_edges > 0:
            has_edge_id = 'id' in next(iter(nx_graph.edges.values()))
        else:
            has_edge_id = False
        if has_edge_id:
            src = np.zeros((num_edges,), dtype=np.int64)
            dst = np.zeros((num_edges,), dtype=np.int64)
            # BUG FIX: `.items` was missing the call parentheses.
            for e, attr in nx_graph.edges.items():
                # MultiDiGraph returns a triplet in e while DiGraph returns a pair
                eid = attr['id']
                src[eid] = e[0]
                dst[eid] = e[1]
        else:
            src = []
            dst = []
            for e in nx_graph.edges:
                src.append(e[0])
                dst.append(e[1])
        src = utils.toindex(src)
        dst = utils.toindex(dst)
        self.add_edges(src, dst)

    def from_scipy_sparse_matrix(self, adj):
        """Convert from scipy sparse matrix.

        Parameters
        ----------
        adj : scipy sparse matrix
            The adjacency matrix; rows are sources, columns destinations.
        """
        self.clear()
        self.add_nodes(adj.shape[0])
        adj_coo = adj.tocoo()
        src = utils.toindex(adj_coo.row)
        dst = utils.toindex(adj_coo.col)
        self.add_edges(src, dst)

    def line_graph(self, backtracking=True):
        """Return the line graph of this graph.

        Parameters
        ----------
        backtracking : bool, optional (default=True)
            Whether (i, j) ~ (j, i) in L(G).
            (i, j) ~ (j, i) is the behavior of networkx.line_graph.

        Returns
        -------
        GraphIndex
            The line graph of this graph.
        """
        handle = _CAPI_DGLGraphLineGraph(self._handle, backtracking)
        return GraphIndex(handle)
class SubgraphIndex(GraphIndex):
    """Graph index for a subgraph.

    A read-only graph index that additionally remembers the mapping back
    to its parent graph.

    Parameters
    ----------
    handle : GraphIndexHandle
        The capi handle.
    parent : GraphIndex
        The parent graph index.
    induced_nodes : utils.Index
        The parent node ids in this subgraph.
    induced_edges : utils.Index
        The parent edge ids in this subgraph.
    """
    def __init__(self, handle, parent, induced_nodes, induced_edges):
        super(SubgraphIndex, self).__init__(handle)
        self._parent = parent
        self._induced_nodes = induced_nodes
        self._induced_edges = induced_edges

    def _readonly_error(self):
        # Every mutating operation funnels through this single raise site.
        raise RuntimeError('Readonly graph. Mutation is not allowed.')

    def add_nodes(self, num):
        """Add nodes. Disabled because SubgraphIndex is read-only."""
        self._readonly_error()

    def add_edge(self, u, v):
        """Add edges. Disabled because SubgraphIndex is read-only."""
        self._readonly_error()

    def add_edges(self, u, v):
        """Add edges. Disabled because SubgraphIndex is read-only."""
        self._readonly_error()

    @property
    def induced_nodes(self):
        """utils.Index : the parent node ids of this subgraph's nodes."""
        return self._induced_nodes

    @property
    def induced_edges(self):
        """utils.Index : the parent edge ids of this subgraph's edges."""
        return self._induced_edges
def disjoint_union(graphs):
    """Return a disjoint union of the input graphs.

    The new graph will include all the nodes/edges in the given graphs.
    Nodes/Edges will be relabeled by adding the cumsum of the previous graph
    sizes in the given sequence order. For example, given input [g1, g2, g3]
    with 5, 6, 7 nodes respectively, node#2 of g2 becomes node#7 in the
    result graph. Edge ids are re-assigned similarly.

    Parameters
    ----------
    graphs : iterable of GraphIndex
        The input graphs

    Returns
    -------
    GraphIndex
        The disjoint union
    """
    handles = [gr._handle for gr in graphs]
    arr = ctypes.cast(c_array(GraphIndexHandle, handles), ctypes.c_void_p)
    return GraphIndex(_CAPI_DGLDisjointUnion(arr, len(graphs)))
def disjoint_partition(graph, num_or_size_splits):
    """Partition the graph disjointly.

    This is a reverse operation of disjoint_union. The graph is split into
    several graphs. An integer argument must evenly divide the number of
    nodes in the graph; a size list must sum to it.

    Parameters
    ----------
    graph : GraphIndex
        The graph to be partitioned
    num_or_size_splits : int or utils.Index
        The partition number or size splits

    Returns
    -------
    list of GraphIndex
        The partitioned graphs
    """
    if isinstance(num_or_size_splits, utils.Index):
        # Split by an explicit list of sizes.
        rst = _CAPI_DGLDisjointPartitionBySizes(
            graph._handle,
            num_or_size_splits.todgltensor())
    else:
        # Split into a fixed number of equally-sized parts.
        rst = _CAPI_DGLDisjointPartitionByNum(
            graph._handle,
            int(num_or_size_splits))
    # The CAPI returns an array of raw handle values; wrap each one.
    return [GraphIndex(ctypes.cast(int(val), ctypes.c_void_p))
            for val in rst.asnumpy()]
def create_graph_index(graph_data=None, multigraph=False):
    """Create a graph index object.

    Parameters
    ----------
    graph_data : graph data, optional
        Data to initialize graph. Same as networkx's semantics.
        May be an existing GraphIndex (returned as-is), a scipy sparse
        matrix, or anything networkx accepts.
    multigraph : bool, optional
        Whether the graph is multigraph (default is False)

    Returns
    -------
    GraphIndex
        The created graph index.
    """
    if isinstance(graph_data, GraphIndex):
        return graph_data
    handle = _CAPI_DGLGraphCreate(multigraph)
    gi = GraphIndex(handle)
    if graph_data is None:
        return gi
    # scipy format
    if isinstance(graph_data, scipy.sparse.spmatrix):
        try:
            gi.from_scipy_sparse_matrix(graph_data)
            return gi
        # BUG FIX: bare `except:` also swallowed KeyboardInterrupt/SystemExit.
        except Exception:
            raise Exception('Graph data is not a valid scipy sparse matrix.')
    # networkx - any format
    try:
        gi.from_networkx(graph_data)
    except Exception:
        raise Exception('Error while creating graph from input of type "%s".'
                        % type(graph_data))
    return gi
_init_api("dgl.graph_index")
"""DGL Runtime NDArray API.
dgl.ndarray provides a minimum runtime array structure to be
used with C++ library.
"""
# pylint: disable=invalid-name,unused-import
from __future__ import absolute_import as _abs
import ctypes
import functools
import operator
import numpy as _np
from ._ffi.ndarray import TVMContext, TVMType, NDArrayBase
from ._ffi.ndarray import context, empty, from_dlpack, numpyasarray
from ._ffi.ndarray import _set_class_ndarray
from . import backend as F
class NDArray(NDArrayBase):
    """Lightweight NDArray class for DGL framework."""
    def __len__(self):
        # Total number of elements: product of all dimension sizes
        # (1 for a scalar/empty shape).
        size = 1
        for dim in self.shape:
            size *= dim
        return size
def cpu(dev_id=0):
    """Construct a CPU device

    Parameters
    ----------
    dev_id : int, optional
        The integer device id

    Returns
    -------
    ctx : TVMContext
        The created context
    """
    cpu_device_type = 1  # TVM device type code for CPU
    return TVMContext(cpu_device_type, dev_id)
def gpu(dev_id=0):
    """Construct a GPU device

    Parameters
    ----------
    dev_id : int, optional
        The integer device id

    Returns
    -------
    ctx : TVMContext
        The created context
    """
    # BUG FIX: docstring said "CPU" (copy-paste from cpu()); device type 2
    # is the GPU device code.
    return TVMContext(2, dev_id)
def array(arr, ctx=cpu(0)):
    """Create an array from source arr.

    Parameters
    ----------
    arr : numpy.ndarray
        The array to be copied from
    ctx : TVMContext, optional
        The device context to create the array

    Returns
    -------
    ret : NDArray
        The created array
    """
    # Coerce anything that is not already an ndarray/NDArray first.
    source = arr
    if not isinstance(source, (_np.ndarray, NDArray)):
        source = _np.array(source)
    ret = empty(source.shape, source.dtype, ctx)
    return ret.copyfrom(source)
def zerocopy_from_numpy(np_data):
    """Create an array that shares the given numpy data.

    No data is copied: the NDArray is constructed with ``is_view=True``
    over a C struct describing the numpy buffer.
    NOTE(review): since the view aliases the numpy buffer, the caller
    presumably must keep ``np_data`` alive while the returned array is in
    use — confirm against the FFI's ownership rules.

    Parameters
    ----------
    np_data : numpy.ndarray
        The numpy data

    Returns
    -------
    NDArray
        The array
    """
    # Build a C array descriptor over the numpy buffer (no copy).
    arr, _ = numpyasarray(np_data)
    handle = ctypes.pointer(arr)
    return NDArray(handle, is_view=True)
_set_class_ndarray(NDArray)
"""Package nn modules"""
from __future__ import absolute_import
import os import os
__backend__ = os.environ.get('DGLBACKEND', 'pytorch').lower() __backend__ = os.environ.get('DGLBACKEND', 'pytorch').lower()
if __backend__ == 'numpy': if __backend__ == 'numpy':
pass pass
elif __backend__ == 'pytorch': elif __backend__ == 'pytorch':
from .pytorch import * from .pytorch import *
else: elif __backend__ != 'mxnet':
raise Exception("Unsupported backend %s" % __backend__) raise Exception("Unsupported backend %s" % __backend__)
...@@ -7,9 +7,8 @@ GCN with SPMV specialization. ...@@ -7,9 +7,8 @@ GCN with SPMV specialization.
""" """
import torch.nn as nn import torch.nn as nn
import dgl from ... import function as fn
import dgl.function as fn from ...base import ALL, is_all
from dgl.base import ALL, is_all
class NodeUpdateModule(nn.Module): class NodeUpdateModule(nn.Module):
def __init__(self, in_feats, out_feats, activation=None): def __init__(self, in_feats, out_feats, activation=None):
......
"""Utility functions for networkx adapter."""
from __future__ import absolute_import
from collections import MutableMapping
import networkx as nx
import networkx.convert as convert
class NodeDict(MutableMapping):
    """Mapping for node storage that fires callbacks on add/delete.

    Parameters
    ----------
    add_cb : callable
        Invoked with the node key before it is stored.
    del_cb : callable
        Invoked with the node key before it is removed.
    """
    def __init__(self, add_cb, del_cb):
        self._dict = {}
        self._on_add = add_cb
        self._on_del = del_cb

    def __len__(self):
        return len(self._dict)

    def __iter__(self):
        return iter(self._dict)

    def __getitem__(self, key):
        return self._dict[key]

    def __setitem__(self, key, val):
        # Notify first, then store, mirroring the deletion order below.
        self._on_add(key)
        self._dict[key] = val

    def __delitem__(self, key):
        self._on_del(key)
        del self._dict[key]
class AdjOuterDict(MutableMapping):
    """Outer adjacency mapping: node key -> inner adjacency dict.

    Tags each stored value with its source key (``val.src``) so the inner
    dict can report edge mutations; fires ``del_cb`` per neighbor when a
    source node's adjacency is removed.
    """
    def __init__(self, add_cb, del_cb):
        self._dict = {}
        self._on_add = add_cb
        self._on_del = del_cb

    def __len__(self):
        return len(self._dict)

    def __iter__(self):
        return iter(self._dict)

    def __getitem__(self, key):
        return self._dict[key]

    def __setitem__(self, key, val):
        # Attach the source id; edge callbacks are fired by the inner dict.
        val.src = key
        self._dict[key] = val

    def __delitem__(self, key):
        # Report every outgoing edge of `key` before dropping its row.
        for neighbor in self._dict[key]:
            self._on_del(key, neighbor)
        del self._dict[key]
class AdjInnerDict(MutableMapping):
    """Inner adjacency mapping: neighbor key -> edge attribute dict.

    Fires the add/del callbacks with ``(src, neighbor)`` once ``src`` has
    been attached by AdjOuterDict; mutations made before attachment are
    silent.
    """
    def __init__(self, add_cb, del_cb):
        self._dict = {}
        self.src = None  # set by AdjOuterDict when this row is attached
        self._on_add = add_cb
        self._on_del = del_cb

    def __len__(self):
        return len(self._dict)

    def __iter__(self):
        return iter(self._dict)

    def __getitem__(self, key):
        return self._dict[key]

    def __setitem__(self, key, val):
        # Only a brand-new neighbor of an attached row is a new edge.
        is_new_edge = self.src is not None and key not in self._dict
        if is_new_edge:
            self._on_add(self.src, key)
        self._dict[key] = val

    def __delitem__(self, key):
        if self.src is not None:
            self._on_del(self.src, key)
        del self._dict[key]
class AdjInnerDictFactory(object):
    """Callable factory producing AdjInnerDict bound to fixed callbacks."""
    def __init__(self, cb1, cb2):
        self._cb1 = cb1
        self._cb2 = cb2

    def __call__(self):
        # networkx invokes the factory with no arguments for each new row.
        return AdjInnerDict(self._cb1, self._cb2)
def nx_init(obj,
            add_node_cb,
            add_edge_cb,
            del_node_cb,
            del_edge_cb,
            graph_data,
            **attr):
    """Init the object to be compatible with networkx's DiGraph.

    Installs callback-aware dict factories so node/edge mutations made
    through the networkx API are reported back to the caller.

    Parameters
    ----------
    obj : any
        The object to be init.
    add_node_cb : callable
        The callback function when node is added.
    add_edge_cb : callable
        The callback function when edge is added.
    del_node_cb : callable
        The callback function when node is deleted.
    del_edge_cb : callable
        The callback function when edge is deleted.
    graph_data : graph data
        Data to initialize graph. Same as networkx's semantics.
    attr : keyword arguments, optional
        Attributes to add to graph as key=value pairs.
    """
    # The following codes work for networkx 2.1.
    obj.adjlist_outer_dict_factory = None
    obj.adjlist_inner_dict_factory = AdjInnerDictFactory(add_edge_cb, del_edge_cb)
    obj.edge_attr_dict_factory = dict
    obj.root_graph = obj
    obj.graph = {}
    # Callback-aware replacements for DiGraph's internal node/adjacency dicts.
    obj._node = NodeDict(add_node_cb, del_node_cb)
    obj._adj = AdjOuterDict(add_edge_cb, del_edge_cb)
    # NOTE(review): _pred is a plain dict, so mutations made through the
    # predecessor view presumably do not fire callbacks — confirm intended.
    obj._pred = dict()
    # As in DiGraph, the successor view aliases the adjacency dict.
    obj._succ = obj._adj
    if graph_data is not None:
        convert.to_networkx_graph(graph_data, create_using=obj)
    obj.graph.update(attr)
...@@ -3,20 +3,23 @@ from __future__ import absolute_import ...@@ -3,20 +3,23 @@ from __future__ import absolute_import
import numpy as np import numpy as np
import dgl.backend as F from .base import ALL, DGLError
import dgl.function.message as fmsg from . import backend as F
import dgl.function.reducer as fred from .function import message as fmsg
import dgl.utils as utils from .function import reducer as fred
from dgl.base import ALL from . import utils
from collections import defaultdict as ddict
__all__ = ["degree_bucketing", "get_executor"] from ._ffi.function import _init_api
def degree_bucketing(cached_graph, v): __all__ = ["degree_bucketing", "get_recv_executor", "get_executor"]
def degree_bucketing(graph, v):
"""Create degree bucketing scheduling policy. """Create degree bucketing scheduling policy.
Parameters Parameters
---------- ----------
cached_graph : dgl.cached_graph.CachedGraph graph : dgl.graph_index.GraphIndex
the graph the graph
v : dgl.utils.Index v : dgl.utils.Index
the nodes to gather messages the nodes to gather messages
...@@ -29,7 +32,7 @@ def degree_bucketing(cached_graph, v): ...@@ -29,7 +32,7 @@ def degree_bucketing(cached_graph, v):
list of node id buckets; nodes belong to the same bucket have list of node id buckets; nodes belong to the same bucket have
the same degree the same degree
""" """
degrees = F.asnumpy(cached_graph.in_degrees(v).totensor()) degrees = np.array(graph.in_degrees(v).tolist())
unique_degrees = list(np.unique(degrees)) unique_degrees = list(np.unique(degrees))
v_np = np.array(v.tolist()) v_np = np.array(v.tolist())
v_bkt = [] v_bkt = []
...@@ -39,9 +42,84 @@ def degree_bucketing(cached_graph, v): ...@@ -39,9 +42,84 @@ def degree_bucketing(cached_graph, v):
#print('degree-bucketing:', unique_degrees, [len(b) for b in v_bkt]) #print('degree-bucketing:', unique_degrees, [len(b) for b in v_bkt])
return unique_degrees, v_bkt return unique_degrees, v_bkt
def _process_buckets(buckets):
    """Unpack the degree-bucketing auxiliary data returned by the C API.

    Parameters
    ----------
    buckets : callable
        CAPI result; buckets(i) returns the i-th output tensor.

    Returns
    -------
    unique_v, degs, dsts, msg_ids
        Unique destination nodes, per-bucket degrees, per-bucket node id
        lists, and per-bucket message id lists.
    """
    # Fetch the five outputs in order.
    degs = utils.toindex(buckets(0))
    unique_v = utils.toindex(buckets(1)).tousertensor()
    # TODO: convert directly from ndarray to python list?
    v_section = buckets(2).asnumpy().tolist()
    raw_msg_ids = utils.toindex(buckets(3)).tousertensor()
    msg_section = buckets(4).asnumpy().tolist()
    # Split the flat node/message id tensors into per-bucket chunks.
    dsts = [utils.toindex(chunk) for chunk in F.unpack(unique_v, v_section)]
    msg_ids = [utils.toindex(chunk) for chunk in F.unpack(raw_msg_ids, msg_section)]
    return utils.toindex(unique_v), degs, dsts, msg_ids
def light_degree_bucketing(v):
    """Return the bucketing-by-degree schedule for message destinations.

    Parameters
    ----------
    v : utils.Index
        Destination node for each message.

    Returns
    -------
    unique_v : utils.Index
        Unique destination nodes.
    degrees : utils.Index
        A list of degree for each bucket.
    v_bkt : list of utils.Index
        A list of node id buckets; nodes in each bucket have the same degree.
    msg_ids : list of utils.Index
        A list of message id buckets; each node in the ith node id bucket
        has degree[i] messages in the ith message id bucket.
    """
    raw = _CAPI_DGLDegreeBucketing(v.todgltensor())
    return _process_buckets(raw)
def light_degree_bucketing_for_graph(graph):
    """Return the bucketing by degree scheduling for the entire graph.

    Parameters
    ----------
    graph : GraphIndex
        The graph to schedule on.

    Returns
    -------
    unique_v : utils.Index
        Unique destination nodes.
    degrees : utils.Index
        A list of degree for each bucket.
    v_bkt : list of utils.Index
        A list of node id buckets; nodes in each bucket have the same degree.
    msg_ids : list of utils.Index
        A list of message id buckets; each node in the ith node id bucket
        has degree[i] messages in the ith message id bucket.
    """
    # BUG FIX: the original passed `self._handle`, but this is a module-level
    # function with no `self` — it raised NameError on every call. Use the
    # handle of the `graph` argument instead.
    buckets = _CAPI_DGLDegreeBucketingFromGraph(graph._handle)
    return _process_buckets(buckets)
class Executor(object): class Executor(object):
"""Base class for executing graph computation."""
def run(self): def run(self):
"""Run this executor.
This should return the new node features.
TODO(minjie): extend this to support computation on edges.
"""
raise NotImplementedError raise NotImplementedError
class SPMVOperator(Executor): class SPMVOperator(Executor):
...@@ -56,9 +134,6 @@ class SPMVOperator(Executor): ...@@ -56,9 +134,6 @@ class SPMVOperator(Executor):
def run(self): def run(self):
# get src col # get src col
if self.src_field is None:
srccol = self.node_repr
else:
srccol = self.node_repr[self.src_field] srccol = self.node_repr[self.src_field]
ctx = F.get_context(srccol) ctx = F.get_context(srccol)
...@@ -72,12 +147,52 @@ class SPMVOperator(Executor): ...@@ -72,12 +147,52 @@ class SPMVOperator(Executor):
dstcol = F.squeeze(dstcol) dstcol = F.squeeze(dstcol)
else: else:
dstcol = F.spmm(adjmat, srccol) dstcol = F.spmm(adjmat, srccol)
if self.dst_field is None:
return dstcol
else:
return {self.dst_field : dstcol} return {self.dst_field : dstcol}
# FIXME: refactorize in scheduler/executor redesign
class DegreeBucketingExecutor(Executor):
    # Executes a reduce function over messages using degree bucketing:
    # destination nodes are grouped by in-degree so each bucket's messages
    # can be reshaped into a dense (num_nodes, degree, ...) tensor.
    def __init__(self, g, rfunc, message_frame, edges=None):
        # g : the DGLGraph; rfunc : reduce function applied per bucket;
        # message_frame : frame holding all pending messages;
        # edges : optional (src, dst) pair — when given, bucket only the
        # destinations of these edges, otherwise bucket the whole graph.
        self.g = g
        self.rfunc = rfunc
        self.msg_frame = message_frame
        # calc degree bucketing schedule
        if edges is not None:
            unique_v, degs, dsts, msg_ids = light_degree_bucketing(edges[1])
        else:
            unique_v, degs, dsts, msg_ids = light_degree_bucketing_for_graph(g._graph)
        self._recv_nodes = unique_v
        self.degrees = degs
        self.dsts = dsts
        self.msg_ids = msg_ids

    @property
    def recv_nodes(self):
        # Nodes that will receive new representations from run().
        return self._recv_nodes

    def run(self):
        # Returns a dict of new node representations covering recv_nodes,
        # ordered by bucket (matching self._recv_nodes).
        new_reprs = []
        # loop over each bucket
        # FIXME (lingfan): handle zero-degree case
        for deg, vv, msg_id in zip(self.degrees, self.dsts, self.msg_ids):
            dst_reprs = self.g.get_n_repr(vv)
            in_msgs = self.msg_frame.select_rows(msg_id)
            # Reshape the flat message rows into (num_nodes, degree, ...)
            # so the reducer sees one row of `deg` messages per node.
            def _reshape_fn(msg):
                msg_shape = F.shape(msg)
                new_shape = (len(vv), deg) + msg_shape[1:]
                return F.reshape(msg, new_shape)
            # Lazily reshape each message field only when the reducer
            # actually accesses it; consumed within this iteration.
            reshaped_in_msgs = utils.LazyDict(
                lambda key: _reshape_fn(in_msgs[key]), self.msg_frame.schemes)
            new_reprs.append(self.rfunc(dst_reprs, reshaped_in_msgs))

        # Pack all reducer results together
        keys = new_reprs[0].keys()
        new_reprs = {key : F.pack([repr[key] for repr in new_reprs])
                     for key in keys}
        return new_reprs
class BasicExecutor(Executor): class BasicExecutor(Executor):
def __init__(self, graph, mfunc, rfunc): def __init__(self, graph, mfunc, rfunc):
self.g = graph self.g = graph
...@@ -92,7 +207,7 @@ class BasicExecutor(Executor): ...@@ -92,7 +207,7 @@ class BasicExecutor(Executor):
raise NotImplementedError raise NotImplementedError
@property @property
def graph_mapping(self): def recv_nodes(self):
raise NotImplementedError raise NotImplementedError
def _build_exec(self, mfunc, rfunc): def _build_exec(self, mfunc, rfunc):
...@@ -115,8 +230,7 @@ class BasicExecutor(Executor): ...@@ -115,8 +230,7 @@ class BasicExecutor(Executor):
return exe return exe
def run(self): def run(self):
attr = self.exe.run() return self.exe.run()
self.g.set_n_repr(attr, self.graph_mapping)
class UpdateAllExecutor(BasicExecutor): class UpdateAllExecutor(BasicExecutor):
...@@ -129,13 +243,7 @@ class UpdateAllExecutor(BasicExecutor): ...@@ -129,13 +243,7 @@ class UpdateAllExecutor(BasicExecutor):
self._edge_repr = None self._edge_repr = None
self._graph_idx = None self._graph_idx = None
self._graph_shape = None self._graph_shape = None
self._graph_mapping = None self._recv_nodes = None
@property
def graph_idx(self):
if self._graph_idx is None:
self._graph_idx = self.g.cached_graph.adjmat()
return self._graph_idx
@property @property
def graph_shape(self): def graph_shape(self):
...@@ -145,7 +253,7 @@ class UpdateAllExecutor(BasicExecutor): ...@@ -145,7 +253,7 @@ class UpdateAllExecutor(BasicExecutor):
return self._graph_shape return self._graph_shape
@property @property
def graph_mapping(self): def recv_nodes(self):
return ALL return ALL
@property @property
...@@ -162,16 +270,13 @@ class UpdateAllExecutor(BasicExecutor): ...@@ -162,16 +270,13 @@ class UpdateAllExecutor(BasicExecutor):
def _adj_build_fn(self, edge_field, ctx, use_edge_feat): def _adj_build_fn(self, edge_field, ctx, use_edge_feat):
if use_edge_feat: if use_edge_feat:
if edge_field is None:
dat = self.edge_repr
else:
dat = self.edge_repr[edge_field] dat = self.edge_repr[edge_field]
dat = F.squeeze(dat) dat = F.squeeze(dat)
# TODO(minjie): should not directly use _indices # TODO(minjie): should not directly use _indices
idx = self.graph_idx.get(ctx)._indices() idx = self.g.adjacency_matrix(ctx)._indices()
adjmat = F.sparse_tensor(idx, dat, self.graph_shape) adjmat = F.sparse_tensor(idx, dat, self.graph_shape)
else: else:
adjmat = self.graph_idx.get(ctx) adjmat = self.g.adjacency_matrix(ctx)
return adjmat return adjmat
...@@ -186,7 +291,7 @@ class SendRecvExecutor(BasicExecutor): ...@@ -186,7 +291,7 @@ class SendRecvExecutor(BasicExecutor):
self._edge_repr = None self._edge_repr = None
self._graph_idx = None self._graph_idx = None
self._graph_shape = None self._graph_shape = None
self._graph_mapping = None self._recv_nodes = None
@property @property
def graph_idx(self): def graph_idx(self):
...@@ -201,10 +306,10 @@ class SendRecvExecutor(BasicExecutor): ...@@ -201,10 +306,10 @@ class SendRecvExecutor(BasicExecutor):
return self._graph_shape return self._graph_shape
@property @property
def graph_mapping(self): def recv_nodes(self):
if self._graph_mapping is None: if self._recv_nodes is None:
self._build_adjmat() self._build_adjmat()
return self._graph_mapping return self._recv_nodes
@property @property
def node_repr(self): def node_repr(self):
...@@ -221,21 +326,18 @@ class SendRecvExecutor(BasicExecutor): ...@@ -221,21 +326,18 @@ class SendRecvExecutor(BasicExecutor):
def _build_adjmat(self): def _build_adjmat(self):
# handle graph index # handle graph index
new2old, old2new = utils.build_relabel_map(self.v) new2old, old2new = utils.build_relabel_map(self.v)
u = self.u.totensor() u = self.u.tousertensor()
v = self.v.totensor() v = self.v.tousertensor()
# TODO(minjie): should not directly use [] # TODO(minjie): should not directly use []
new_v = old2new[v] new_v = old2new[v]
n = self.g.number_of_nodes() n = self.g.number_of_nodes()
m = len(new2old) m = len(new2old)
self._graph_idx = F.pack([F.unsqueeze(new_v, 0), F.unsqueeze(u, 0)]) self._graph_idx = F.pack([F.unsqueeze(new_v, 0), F.unsqueeze(u, 0)])
self._graph_shape = [m, n] self._graph_shape = [m, n]
self._graph_mapping = new2old self._recv_nodes = new2old
def _adj_build_fn(self, edge_field, ctx, use_edge_feat): def _adj_build_fn(self, edge_field, ctx, use_edge_feat):
if use_edge_feat: if use_edge_feat:
if edge_field is None:
dat = self.edge_repr
else:
dat = self.edge_repr[edge_field] dat = self.edge_repr[edge_field]
dat = F.squeeze(dat) dat = F.squeeze(dat)
else: else:
...@@ -268,9 +370,8 @@ class BundledExecutor(BasicExecutor): ...@@ -268,9 +370,8 @@ class BundledExecutor(BasicExecutor):
func_pairs = [] func_pairs = []
for rfn in rfunc.fn_list: for rfn in rfunc.fn_list:
mfn = out2mfunc.get(rfn.msg_field, None) mfn = out2mfunc.get(rfn.msg_field, None)
# field check if mfn is None:
assert mfn is not None, \ raise DGLError('Cannot find message field "%s".' % rfn.msg_field)
"cannot find message func for reduce func in-field {}".format(rfn.msg_field)
func_pairs.append((mfn, rfn)) func_pairs.append((mfn, rfn))
return func_pairs return func_pairs
...@@ -283,7 +384,7 @@ class BundledExecutor(BasicExecutor): ...@@ -283,7 +384,7 @@ class BundledExecutor(BasicExecutor):
else: else:
# attr and res must be dict # attr and res must be dict
attr.update(res) attr.update(res)
self.g.set_n_repr(attr, self.graph_mapping) return attr
class BundledUpdateAllExecutor(BundledExecutor, UpdateAllExecutor): class BundledUpdateAllExecutor(BundledExecutor, UpdateAllExecutor):
...@@ -291,13 +392,20 @@ class BundledUpdateAllExecutor(BundledExecutor, UpdateAllExecutor): ...@@ -291,13 +392,20 @@ class BundledUpdateAllExecutor(BundledExecutor, UpdateAllExecutor):
self._init_state() self._init_state()
BundledExecutor.__init__(self, graph, mfunc, rfunc) BundledExecutor.__init__(self, graph, mfunc, rfunc)
class BundledSendRecvExecutor(BundledExecutor, SendRecvExecutor): class BundledSendRecvExecutor(BundledExecutor, SendRecvExecutor):
def __init__(self, graph, src, dst, mfunc, rfunc): def __init__(self, graph, src, dst, mfunc, rfunc):
self._init_state(src, dst) self._init_state(src, dst)
BundledExecutor.__init__(self, graph, mfunc, rfunc) BundledExecutor.__init__(self, graph, mfunc, rfunc)
def _is_spmv_supported(fn, graph=None): def _is_spmv_supported(fn, graph=None):
# FIXME: also take into account
# (1) which backend DGL is under.
# (2) whether the graph is a multigraph.
#
# Current SPMV optimizer assumes that duplicate entries are summed up
# in sparse matrices, which is the case for PyTorch but not MXNet.
# The result is that on multigraphs, SPMV can still work for reducer=sum
# and message=copy_src/src_mul_edge *only in PyTorch*.
if isinstance(fn, fmsg.MessageFunction): if isinstance(fn, fmsg.MessageFunction):
return fn.is_spmv_supported(graph) return fn.is_spmv_supported(graph)
elif isinstance(fn, fred.ReduceFunction): elif isinstance(fn, fred.ReduceFunction):
...@@ -342,3 +450,24 @@ def get_executor(call_type, graph, **kwargs): ...@@ -342,3 +450,24 @@ def get_executor(call_type, graph, **kwargs):
return _create_send_and_recv_exec(graph, **kwargs) return _create_send_and_recv_exec(graph, **kwargs)
else: else:
return None return None
def get_recv_executor(graph, reduce_func, message_frame, edges=None):
    """Create executor for recv phase.

    Parameters
    ----------
    graph: DGLGraph
        DGLGraph on which to perform recv
    reduce_func: callable
        The reduce function
    message_frame: FrameRef
        Message frame
    edges: tuple/list of utils.Index
        src and dst Index representing edges along which messages are sent
        If not specified, all edges of graph are used instead

    Returns
    -------
    DegreeBucketingExecutor
        Executor that reduces the given messages via degree bucketing.
    """
    # FIXME: handle builtin spmv executor case
    return DegreeBucketingExecutor(graph, reduce_func, message_frame, edges)
_init_api("dgl.scheduler")
"""DGLSubGraph""" """Class for subgraph data structure."""
from __future__ import absolute_import from __future__ import absolute_import
import networkx as nx import networkx as nx
import dgl.backend as F
from dgl.frame import Frame, FrameRef from . import backend as F
from dgl.graph import DGLGraph from .frame import Frame, FrameRef
from dgl.nx_adapt import nx_init from .graph import DGLGraph
import dgl.utils as utils from . import utils
class DGLSubGraph(DGLGraph): class DGLSubGraph(DGLGraph):
# TODO(gaiyu): ReadOnlyGraph """The subgraph class.
def __init__(self,
parent,
nodes):
super(DGLSubGraph, self).__init__()
# relabel nodes
self._node_mapping = utils.build_relabel_dict(nodes)
self._parent_nid = utils.toindex(nodes)
eids = []
# create subgraph
for eid, (u, v) in enumerate(parent.edge_list):
if u in self._node_mapping and v in self._node_mapping:
self.add_edge(self._node_mapping[u],
self._node_mapping[v])
eids.append(eid)
self._parent_eid = utils.toindex(eids)
def copy_from(self, parent):
"""Copy node/edge features from the parent graph.
All old features will be removed. There are two subgraph modes: shared and non-shared.
For the "non-shared" mode, the user needs to explicitly call
``copy_from_parent`` to copy node/edge features from its parent graph.
* If the user tries to get node/edge features before ``copy_from_parent``,
s/he will get nothing.
* If the subgraph already has its own node/edge features, ``copy_from_parent``
will override them.
* Any update on the subgraph's node/edge features will not be seen
by the parent graph. As such, the memory consumption is of the order
of the subgraph size.
* To write the subgraph's node/edge features back to parent graph. There are two options:
(1) Use ``copy_to_parent`` API to write node/edge features back.
(2) [TODO] Use ``dgl.merge`` to merge multiple subgraphs back to one parent.
The "shared" mode is currently not supported.
The subgraph is read-only so mutation is not allowed.
Parameters Parameters
---------- ----------
parent : DGLGraph parent : DGLGraph
The parent graph to copy from. The parent graph
parent_nid : utils.Index
The induced parent node ids in this subgraph.
parent_eid : utils.Index
The induced parent edge ids in this subgraph.
graph_idx : GraphIndex
The graph index.
shared : bool, optional
Whether the subgraph shares node/edge features with the parent graph.
"""
    def __init__(self, parent, parent_nid, parent_eid, graph_idx, shared=False):
        # Build directly on the given subgraph index. Node/edge features
        # are NOT copied here; the caller copies them explicitly.
        super(DGLSubGraph, self).__init__(graph_data=graph_idx)
        self._parent = parent
        self._parent_nid = parent_nid
        self._parent_eid = parent_eid
        # NOTE(review): `shared` is accepted but never read in this
        # constructor -- shared mode looks unsupported; confirm before use.
# override APIs
    def add_nodes(self, num, reprs=None):
        """Add nodes. Disabled because DGLSubGraph is read-only."""
        raise RuntimeError('Readonly graph. Mutation is not allowed.')
    def add_edge(self, u, v, reprs=None):
        """Add one edge. Disabled because DGLSubGraph is read-only."""
        raise RuntimeError('Readonly graph. Mutation is not allowed.')
    def add_edges(self, u, v, reprs=None):
        """Add many edges. Disabled because DGLSubGraph is read-only."""
        raise RuntimeError('Readonly graph. Mutation is not allowed.')
    @property
    def parent_nid(self):
        """Get the parent node ids.

        The returned tensor can be used as a map from the node id
        in this subgraph to the node id in the parent graph.

        Returns
        -------
        Tensor
            The parent node id array (converted from the stored Index).
        """
        return self._parent_nid.tousertensor()
    @property
    def parent_eid(self):
        """Get the parent edge ids.

        The returned tensor can be used as a map from the edge id
        in this subgraph to the edge id in the parent graph.

        Returns
        -------
        Tensor
            The parent edge id array (converted from the stored Index).
        """
        return self._parent_eid.tousertensor()
    def copy_to_parent(self, inplace=False):
        """Write node/edge features to the parent graph.

        Rows are addressed by the induced parent node/edge ids, so each
        subgraph feature row lands on its corresponding parent node/edge.

        Parameters
        ----------
        inplace : bool
            If true, use inplace write (no gradient but faster)
        """
        self._parent._node_frame.update_rows(
            self._parent_nid, self._node_frame, inplace=inplace)
        self._parent._edge_frame.update_rows(
            self._parent_eid, self._edge_frame, inplace=inplace)
def copy_from_parent(self):
"""Copy node/edge features from the parent graph.
All old features will be removed.
""" """
if parent._node_frame.num_rows != 0: if self._parent._node_frame.num_rows != 0:
self._node_frame = FrameRef(Frame(parent._node_frame[self._parent_nid])) self._node_frame = FrameRef(Frame(
if parent._edge_frame.num_rows != 0: self._parent._node_frame[self._parent_nid]))
self._edge_frame = FrameRef(Frame(parent._edge_frame[self._parent_eid])) if self._parent._edge_frame.num_rows != 0:
self._edge_frame = FrameRef(Frame(
self._parent._edge_frame[self._parent_eid]))
...@@ -5,50 +5,70 @@ from collections import Mapping ...@@ -5,50 +5,70 @@ from collections import Mapping
from functools import wraps from functools import wraps
import numpy as np import numpy as np
import dgl.backend as F from . import backend as F
from dgl.backend import Tensor, SparseTensor from .backend import Tensor, SparseTensor
from . import ndarray as nd
def is_id_tensor(u):
"""Return whether the input is a supported id tensor."""
return isinstance(u, Tensor) and F.isinteger(u) and len(F.shape(u)) == 1
def is_id_container(u):
"""Return whether the input is a supported id container."""
return (getattr(u, '__iter__', None) is not None
and getattr(u, '__len__', None) is not None)
class Index(object): class Index(object):
"""Index class that can be easily converted to list/tensor.""" """Index class that can be easily converted to list/tensor."""
def __init__(self, data): def __init__(self, data):
self._list_data = None self._list_data = None # a numpy type data
self._tensor_data = None self._user_tensor_data = dict() # dictionary of user tensors
self._ctx_data = dict() self._dgl_tensor_data = None # a dgl ndarray
self._dispatch(data) self._dispatch(data)
def _dispatch(self, data): def _dispatch(self, data):
if is_id_tensor(data): """Store data based on its type."""
self._tensor_data = data if isinstance(data, Tensor):
elif is_id_container(data): if not (F.dtype(data) == F.int64 and len(F.shape(data)) == 1):
self._list_data = data raise ValueError('Index data must be 1D int64 vector, but got: %s' % str(data))
self._user_tensor_data[F.get_context(data)] = data
elif isinstance(data, nd.NDArray):
if not (data.dtype == 'int64' and len(data.shape) == 1):
raise ValueError('Index data must be 1D int64 vector, but got: %s' % str(data))
self._dgl_tensor_data = data
else: else:
try: try:
self._list_data = [int(data)] self._list_data = np.array([int(data)]).astype(np.int64)
except:
try:
self._list_data = np.array(data).astype(np.int64)
except: except:
raise TypeError('Error index data: %s' % str(x)) raise ValueError('Error index data: %s' % str(data))
self._user_tensor_data[nd.cpu()] = F.zerocopy_from_numpy(self._list_data)
def tolist(self): def tolist(self):
"""Convert to a python-list compatible object."""
if self._list_data is None: if self._list_data is None:
self._list_data = list(F.asnumpy(self._tensor_data)) if self._dgl_tensor_data is not None:
self._list_data = self._dgl_tensor_data.asnumpy()
else:
data = self.tousertensor()
self._list_data = F.zerocopy_to_numpy(data)
return self._list_data return self._list_data
def totensor(self, ctx=None): def tousertensor(self, ctx=None):
if self._tensor_data is None: """Convert to user tensor (defined in `backend`)."""
self._tensor_data = F.tensor(self._list_data, dtype=F.int64)
if ctx is None: if ctx is None:
return self._tensor_data ctx = nd.cpu()
if ctx not in self._ctx_data: if len(self._user_tensor_data) == 0:
self._ctx_data[ctx] = F.to_context(self._tensor_data, ctx) # zero copy from dgl tensor
return self._ctx_data[ctx] dl = self._dgl_tensor_data.to_dlpack()
self._user_tensor_data[nd.cpu()] = F.zerocopy_from_dlpack(dl)
if ctx not in self._user_tensor_data:
# copy from cpu to another device
data = next(iter(self._user_tensor_data.values()))
self._user_tensor_data[ctx] = F.to_context(data, ctx)
return self._user_tensor_data[ctx]
def todgltensor(self):
"""Convert to dgl.NDArray."""
if self._dgl_tensor_data is None:
# zero copy from user tensor
tsor = self.tousertensor()
dl = F.zerocopy_to_dlpack(tsor)
self._dgl_tensor_data = nd.from_dlpack(dl)
return self._dgl_tensor_data
def __iter__(self): def __iter__(self):
return iter(self.tolist()) return iter(self.tolist())
...@@ -56,8 +76,11 @@ class Index(object): ...@@ -56,8 +76,11 @@ class Index(object):
def __len__(self): def __len__(self):
if self._list_data is not None: if self._list_data is not None:
return len(self._list_data) return len(self._list_data)
elif len(self._user_tensor_data) > 0:
data = next(iter(self._user_tensor_data.values()))
return len(data)
else: else:
return len(self._tensor_data) return len(self._dgl_tensor_data)
def __getitem__(self, i): def __getitem__(self, i):
return self.tolist()[i] return self.tolist()[i]
...@@ -118,40 +141,13 @@ def edge_broadcasting(u, v): ...@@ -118,40 +141,13 @@ def edge_broadcasting(u, v):
The dst id(s) after broadcasting The dst id(s) after broadcasting
""" """
if len(u) != len(v) and len(u) == 1: if len(u) != len(v) and len(u) == 1:
u = toindex(F.broadcast_to(u.totensor(), v.totensor())) u = toindex(F.broadcast_to(u.tousertensor(), v.tousertensor()))
elif len(u) != len(v) and len(v) == 1: elif len(u) != len(v) and len(v) == 1:
v = toindex(F.broadcast_to(v.totensor(), u.totensor())) v = toindex(F.broadcast_to(v.tousertensor(), u.tousertensor()))
else: else:
assert len(u) == len(v) assert len(u) == len(v)
return u, v return u, v
'''
def convert_to_id_container(x):
if is_id_container(x):
return x
elif is_id_tensor(x):
return F.asnumpy(x)
else:
try:
return [int(x)]
except:
raise TypeError('Error node: %s' % str(x))
return None
def convert_to_id_tensor(x, ctx=None):
if is_id_container(x):
ret = F.tensor(x, dtype=F.int64)
elif is_id_tensor(x):
ret = x
else:
try:
ret = F.tensor([int(x)], dtype=F.int64)
except:
raise TypeError('Error node: %s' % str(x))
ret = F.to_context(ret, ctx)
return ret
'''
class LazyDict(Mapping): class LazyDict(Mapping):
"""A readonly dictionary that does not materialize the storage.""" """A readonly dictionary that does not materialize the storage."""
def __init__(self, fn, keys): def __init__(self, fn, keys):
...@@ -172,6 +168,34 @@ class LazyDict(Mapping): ...@@ -172,6 +168,34 @@ class LazyDict(Mapping):
def __len__(self): def __len__(self):
return len(self._keys) return len(self._keys)
class HybridDict(Mapping):
    """A readonly dictionary that merges several dict-likes (python dict, LazyDict).

    If there are duplicate keys, earlier dicts have priority over later ones.

    Parameters
    ----------
    *dict_like_list : dict-like
        The dict-like objects to merge, in decreasing priority order.
    """
    def __init__(self, *dict_like_list):
        self._dict_like_list = dict_like_list
        self._keys = None  # lazily computed union of all keys

    def keys(self):
        """Return the union of the keys of all merged dicts."""
        if self._keys is None:
            # BUGFIX: the original used sum([...], set()), which raises
            # TypeError because set does not support the `+` operator.
            keyset = set()
            for d in self._dict_like_list:
                keyset.update(d.keys())
            self._keys = list(keyset)
        return self._keys

    def __getitem__(self, key):
        # First dict containing the key wins (priority order).
        for d in self._dict_like_list:
            if key in d:
                return d[key]
        # BUGFIX: raise instead of silently returning None so the Mapping
        # contract (and e.g. .get() with a default) behaves correctly.
        raise KeyError(key)

    def __contains__(self, key):
        return key in self.keys()

    def __iter__(self):
        return iter(self.keys())

    def __len__(self):
        return len(self.keys())
class ReadOnlyDict(Mapping): class ReadOnlyDict(Mapping):
"""A readonly dictionary wrapper.""" """A readonly dictionary wrapper."""
def __init__(self, dict_like): def __init__(self, dict_like):
...@@ -209,7 +233,7 @@ def build_relabel_map(x): ...@@ -209,7 +233,7 @@ def build_relabel_map(x):
One can use advanced indexing to convert an old id tensor to a One can use advanced indexing to convert an old id tensor to a
new id tensor: new_id = old_to_new[old_id] new id tensor: new_id = old_to_new[old_id]
""" """
x = x.totensor() x = x.tousertensor()
unique_x, _ = F.sort(F.unique(x)) unique_x, _ = F.sort(F.unique(x))
map_len = int(F.max(unique_x)) + 1 map_len = int(F.max(unique_x)) + 1
old_to_new = F.zeros(map_len, dtype=F.int64) old_to_new = F.zeros(map_len, dtype=F.int64)
...@@ -316,6 +340,6 @@ def reorder(dict_like, index): ...@@ -316,6 +340,6 @@ def reorder(dict_like, index):
""" """
new_dict = {} new_dict = {}
for key, val in dict_like.items(): for key, val in dict_like.items():
idx_ctx = index.totensor(F.get_context(val)) idx_ctx = index.tousertensor(F.get_context(val))
new_dict[key] = F.gather_row(val, idx_ctx) new_dict[key] = F.gather_row(val, idx_ctx)
return new_dict return new_dict
...@@ -17,7 +17,6 @@ setuptools.setup( ...@@ -17,7 +17,6 @@ setuptools.setup(
'numpy>=1.14.0', 'numpy>=1.14.0',
'scipy>=1.1.0', 'scipy>=1.1.0',
'networkx>=2.1', 'networkx>=2.1',
'python-igraph>=0.7.0',
], ],
data_files=[('', ['VERSION'])], data_files=[('', ['VERSION'])],
url='https://github.com/jermainewang/dgl-1') url='https://github.com/jermainewang/dgl')
/*!
* Copyright (c) 2018 by Contributors
* \file c_runtime_api.cc
* \brief DGL C API common implementations
*/
#include "c_api_common.h"
using tvm::runtime::TVMArgs;
using tvm::runtime::TVMArgValue;
using tvm::runtime::TVMRetValue;
using tvm::runtime::PackedFunc;
using tvm::runtime::NDArray;
namespace dgl {
DLManagedTensor* CreateTmpDLManagedTensor(const TVMArgValue& arg) {
  // Wrap the DLTensor in a DLManagedTensor that does NOT own the
  // underlying buffer: the deleter frees only the wrapper object.
  const DLTensor* dl_tensor = arg;
  DLManagedTensor* ret = new DLManagedTensor();
  ret->deleter = [] (DLManagedTensor* self) { delete self; };
  ret->manager_ctx = nullptr;
  // Shallow copy: the data pointer still belongs to the caller, so the
  // wrapper must not outlive the argument's storage.
  ret->dl_tensor = *dl_tensor;
  return ret;
}
PackedFunc ConvertNDArrayVectorToPackedFunc(const std::vector<NDArray>& vec) {
  // The returned function acts as an indexed getter: f(i) -> vec[i].
  // The vector is captured by value so the arrays stay alive as long as
  // the PackedFunc does.
  auto body = [vec](TVMArgs args, TVMRetValue* rv) {
      size_t which = args[0];
      if (which >= vec.size()) {
        LOG(FATAL) << "invalid choice";
      } else {
        // NOTE: the lambda call operator is const, so this std::move
        // degrades to a copy of the (ref-counted) NDArray handle.
        *rv = std::move(vec[which]);
      }
    };
  return PackedFunc(body);
}
} // namespace dgl
/*!
* Copyright (c) 2018 by Contributors
* \file c_api_common.h
* \brief DGL C API common util functions
*/
#ifndef DGL_C_API_COMMON_H_
#define DGL_C_API_COMMON_H_
#include <dgl/runtime/ndarray.h>
#include <dgl/runtime/packed_func.h>
#include <dgl/runtime/registry.h>
#include <vector>
namespace dgl {
// Graph handler type
typedef void* GraphHandle;
/*!
* \brief Convert the given DLTensor to DLManagedTensor.
*
* Return a temporary DLManagedTensor that does not own memory.
*/
DLManagedTensor* CreateTmpDLManagedTensor(
const tvm::runtime::TVMArgValue& arg);
/*!
* \brief Convert a vector of NDArray to PackedFunc.
*/
tvm::runtime::PackedFunc ConvertNDArrayVectorToPackedFunc(
const std::vector<tvm::runtime::NDArray>& vec);
} // namespace dgl
#endif // DGL_C_API_COMMON_H_
/*!
* Copyright (c) 2018 by Contributors
* \file graph/graph.cc
* \brief DGL graph index implementation
*/
#include <dgl/graph.h>
#include <algorithm>
#include <unordered_map>
#include <set>
#include <functional>
namespace dgl {
namespace {
// An id array is valid iff it is a 1D int64 tensor living on CPU.
inline bool IsValidIdArray(const IdArray& arr) {
  return arr->ctx.device_type == kDLCPU && arr->ndim == 1
    && arr->dtype.code == kDLInt && arr->dtype.bits == 64;
}
}  // namespace
// Append `num_vertices` isolated vertices; ids are assigned contiguously
// after the existing ones (implicit in the adjlist resize).
void Graph::AddVertices(uint64_t num_vertices) {
  CHECK(!read_only_) << "Graph is read-only. Mutations are not allowed.";
  adjlist_.resize(adjlist_.size() + num_vertices);
  reverse_adjlist_.resize(reverse_adjlist_.size() + num_vertices);
}
// Add one edge src->dst. Edge ids are assigned in insertion order and the
// edge is recorded in both forward and reverse adjacency lists plus the
// flat per-edge src/dst arrays (used by FindEdges/Edges).
void Graph::AddEdge(dgl_id_t src, dgl_id_t dst) {
  CHECK(!read_only_) << "Graph is read-only. Mutations are not allowed.";
  CHECK(HasVertex(src) && HasVertex(dst))
    << "Invalid vertices: src=" << src << " dst=" << dst;

  dgl_id_t eid = num_edges_++;

  adjlist_[src].succ.push_back(dst);
  adjlist_[src].edge_id.push_back(eid);
  reverse_adjlist_[dst].succ.push_back(src);
  reverse_adjlist_[dst].edge_id.push_back(eid);

  all_edges_src_.push_back(src);
  all_edges_dst_.push_back(dst);
}
// Add many edges with broadcasting: one-to-many, many-to-one, or pairwise
// (src and dst arrays of equal length).
void Graph::AddEdges(IdArray src_ids, IdArray dst_ids) {
  CHECK(!read_only_) << "Graph is read-only. Mutations are not allowed.";
  CHECK(IsValidIdArray(src_ids)) << "Invalid src id array.";
  CHECK(IsValidIdArray(dst_ids)) << "Invalid dst id array.";
  const auto srclen = src_ids->shape[0];
  const auto dstlen = dst_ids->shape[0];
  const int64_t* src_data = static_cast<int64_t*>(src_ids->data);
  const int64_t* dst_data = static_cast<int64_t*>(dst_ids->data);
  if (srclen == 1) {
    // one-many
    for (int64_t i = 0; i < dstlen; ++i) {
      AddEdge(src_data[0], dst_data[i]);
    }
  } else if (dstlen == 1) {
    // many-one
    for (int64_t i = 0; i < srclen; ++i) {
      AddEdge(src_data[i], dst_data[0]);
    }
  } else {
    // many-many
    CHECK(srclen == dstlen) << "Invalid src and dst id array.";
    for (int64_t i = 0; i < srclen; ++i) {
      AddEdge(src_data[i], dst_data[i]);
    }
  }
}
// Vectorized vertex-existence check.
//
// \param vids 1D int64 CPU array of vertex ids to test.
// \return int64 array of the same length; 1 iff the id is a valid vertex.
BoolArray Graph::HasVertices(IdArray vids) const {
  CHECK(IsValidIdArray(vids)) << "Invalid vertex id array.";
  const auto len = vids->shape[0];
  BoolArray rst = BoolArray::Empty({len}, vids->dtype, vids->ctx);
  const int64_t* vid_data = static_cast<int64_t*>(vids->data);
  int64_t* rst_data = static_cast<int64_t*>(rst->data);
  const int64_t nverts = NumVertices();
  for (int64_t i = 0; i < len; ++i) {
    // BUGFIX: also reject negative ids; the original `vid < nverts` test
    // reported any negative id as an existing vertex.
    rst_data[i] = (vid_data[i] >= 0 && vid_data[i] < nverts)? 1 : 0;
  }
  return rst;
}
// Return true iff at least one edge src->dst exists.
// Linear scan of src's out-neighbor list, so O(out-degree).
bool Graph::HasEdgeBetween(dgl_id_t src, dgl_id_t dst) const {
  if (!HasVertex(src) || !HasVertex(dst)) {
    return false;
  }
  const auto& out = adjlist_[src].succ;
  return std::any_of(out.begin(), out.end(),
                     [dst](dgl_id_t nbr) { return nbr == dst; });
}
// Vectorized HasEdgeBetween with one-many/many-one/pairwise broadcasting.
// O(E*k) pretty slow
BoolArray Graph::HasEdgesBetween(IdArray src_ids, IdArray dst_ids) const {
  CHECK(IsValidIdArray(src_ids)) << "Invalid src id array.";
  CHECK(IsValidIdArray(dst_ids)) << "Invalid dst id array.";
  const auto srclen = src_ids->shape[0];
  const auto dstlen = dst_ids->shape[0];
  // result length follows the broadcast side
  const auto rstlen = std::max(srclen, dstlen);
  BoolArray rst = BoolArray::Empty({rstlen}, src_ids->dtype, src_ids->ctx);
  int64_t* rst_data = static_cast<int64_t*>(rst->data);
  const int64_t* src_data = static_cast<int64_t*>(src_ids->data);
  const int64_t* dst_data = static_cast<int64_t*>(dst_ids->data);
  if (srclen == 1) {
    // one-many
    for (int64_t i = 0; i < dstlen; ++i) {
      rst_data[i] = HasEdgeBetween(src_data[0], dst_data[i])? 1 : 0;
    }
  } else if (dstlen == 1) {
    // many-one
    for (int64_t i = 0; i < srclen; ++i) {
      rst_data[i] = HasEdgeBetween(src_data[i], dst_data[0])? 1 : 0;
    }
  } else {
    // many-many
    CHECK(srclen == dstlen) << "Invalid src and dst id array.";
    for (int64_t i = 0; i < srclen; ++i) {
      rst_data[i] = HasEdgeBetween(src_data[i], dst_data[i])? 1 : 0;
    }
  }
  return rst;
}
// Return the unique 1-hop predecessors of `vid` in ascending id order.
// NOTE(review): `radius` is validated but otherwise unused -- only the
// 1-hop neighborhood is returned even for radius > 1; confirm intent.
// The data is copy-out; support zero-copy?
IdArray Graph::Predecessors(dgl_id_t vid, uint64_t radius) const {
  CHECK(HasVertex(vid)) << "invalid vertex: " << vid;
  CHECK(radius >= 1) << "invalid radius: " << radius;
  // std::set both deduplicates (multi-edges) and sorts the ids
  std::set<dgl_id_t> vset;
  for (auto& it : reverse_adjlist_[vid].succ)
    vset.insert(it);
  const int64_t len = vset.size();
  IdArray rst = IdArray::Empty({len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
  int64_t* rst_data = static_cast<int64_t*>(rst->data);
  std::copy(vset.begin(), vset.end(), rst_data);
  return rst;
}
// Return the unique 1-hop successors of `vid` in ascending id order.
// NOTE(review): `radius` is validated but otherwise unused -- only the
// 1-hop neighborhood is returned even for radius > 1; confirm intent.
// The data is copy-out; support zero-copy?
IdArray Graph::Successors(dgl_id_t vid, uint64_t radius) const {
  CHECK(HasVertex(vid)) << "invalid vertex: " << vid;
  CHECK(radius >= 1) << "invalid radius: " << radius;
  // std::set both deduplicates (multi-edges) and sorts the ids
  std::set<dgl_id_t> vset;
  for (auto& it : adjlist_[vid].succ)
    vset.insert(it);
  const int64_t len = vset.size();
  IdArray rst = IdArray::Empty({len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
  int64_t* rst_data = static_cast<int64_t*>(rst->data);
  std::copy(vset.begin(), vset.end(), rst_data);
  return rst;
}
// Return all edge ids connecting src->dst (a multigraph may have several),
// in insertion order. O(out-degree of src).
// O(E)
IdArray Graph::EdgeId(dgl_id_t src, dgl_id_t dst) const {
  CHECK(HasVertex(src) && HasVertex(dst)) << "invalid edge: " << src << " -> " << dst;
  const auto& succ = adjlist_[src].succ;
  std::vector<dgl_id_t> edgelist;
  for (size_t i = 0; i < succ.size(); ++i) {
    if (succ[i] == dst)
      edgelist.push_back(adjlist_[src].edge_id[i]);
  }
  // FIXME: signed? Also it seems that we are using int64_t everywhere...
  const int64_t len = edgelist.size();
  IdArray rst = IdArray::Empty({len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
  // FIXME: signed?
  int64_t* rst_data = static_cast<int64_t*>(rst->data);
  std::copy(edgelist.begin(), edgelist.end(), rst_data);
  return rst;
}
// Vectorized EdgeId with one-many/many-one/pairwise broadcasting.
// Returns (src, dst, eid) triplets; a query pair contributes one triplet
// per matching edge (multigraph), and zero if no edge exists.
// O(E*k) pretty slow
Graph::EdgeArray Graph::EdgeIds(IdArray src_ids, IdArray dst_ids) const {
  CHECK(IsValidIdArray(src_ids)) << "Invalid src id array.";
  CHECK(IsValidIdArray(dst_ids)) << "Invalid dst id array.";
  const auto srclen = src_ids->shape[0];
  const auto dstlen = dst_ids->shape[0];

  int64_t i, j;

  CHECK((srclen == dstlen) || (srclen == 1) || (dstlen == 1))
    << "Invalid src and dst id array.";

  // a stride of 0 pins the broadcast (length-1) side
  const int64_t src_stride = (srclen == 1 && dstlen != 1) ? 0 : 1;
  const int64_t dst_stride = (dstlen == 1 && srclen != 1) ? 0 : 1;
  const int64_t* src_data = static_cast<int64_t*>(src_ids->data);
  const int64_t* dst_data = static_cast<int64_t*>(dst_ids->data);

  std::vector<dgl_id_t> src, dst, eid;

  for (i = 0, j = 0; i < srclen && j < dstlen; i += src_stride, j += dst_stride) {
    const dgl_id_t src_id = src_data[i], dst_id = dst_data[j];
    const auto& succ = adjlist_[src_id].succ;
    for (size_t k = 0; k < succ.size(); ++k) {
      if (succ[k] == dst_id) {
        src.push_back(src_id);
        dst.push_back(dst_id);
        eid.push_back(adjlist_[src_id].edge_id[k]);
      }
    }
  }

  // copy the collected triplets into freshly-allocated arrays
  int64_t rstlen = src.size();
  IdArray rst_src = IdArray::Empty({rstlen}, src_ids->dtype, src_ids->ctx);
  IdArray rst_dst = IdArray::Empty({rstlen}, src_ids->dtype, src_ids->ctx);
  IdArray rst_eid = IdArray::Empty({rstlen}, src_ids->dtype, src_ids->ctx);
  int64_t* rst_src_data = static_cast<int64_t*>(rst_src->data);
  int64_t* rst_dst_data = static_cast<int64_t*>(rst_dst->data);
  int64_t* rst_eid_data = static_cast<int64_t*>(rst_eid->data);
  std::copy(src.begin(), src.end(), rst_src_data);
  std::copy(dst.begin(), dst.end(), rst_dst_data);
  std::copy(eid.begin(), eid.end(), rst_eid_data);

  return EdgeArray{rst_src, rst_dst, rst_eid};
}
// Look up the (src, dst) endpoints for each edge id via the flat per-edge
// arrays. O(len). Aborts on an out-of-range edge id.
//
// \param eids 1D int64 CPU array of edge ids.
// \return EdgeArray of (src, dst, eid), row-aligned with `eids`.
Graph::EdgeArray Graph::FindEdges(IdArray eids) const {
  // BUGFIX: validate the input array like every sibling method does;
  // previously a non-int64/non-1D/non-CPU array was dereferenced blindly.
  CHECK(IsValidIdArray(eids)) << "Invalid edge id array.";
  const int64_t len = eids->shape[0];
  IdArray rst_src = IdArray::Empty({len}, eids->dtype, eids->ctx);
  IdArray rst_dst = IdArray::Empty({len}, eids->dtype, eids->ctx);
  IdArray rst_eid = IdArray::Empty({len}, eids->dtype, eids->ctx);
  const int64_t* eid_data = static_cast<int64_t*>(eids->data);
  int64_t* rst_src_data = static_cast<int64_t*>(rst_src->data);
  int64_t* rst_dst_data = static_cast<int64_t*>(rst_dst->data);
  int64_t* rst_eid_data = static_cast<int64_t*>(rst_eid->data);
  for (int64_t i = 0; i < len; ++i) {
    const dgl_id_t eid = eid_data[i];
    if (eid >= num_edges_)
      LOG(FATAL) << "invalid edge id:" << eid;
    rst_src_data[i] = all_edges_src_[eid];
    rst_dst_data[i] = all_edges_dst_[eid];
    rst_eid_data[i] = eid;
  }
  return EdgeArray{rst_src, rst_dst, rst_eid};
}
// Return all in-edges of `vid` as (src, dst, eid); dst is constant `vid`.
// O(in-degree).
// O(E)
Graph::EdgeArray Graph::InEdges(dgl_id_t vid) const {
  CHECK(HasVertex(vid)) << "invalid vertex: " << vid;
  const int64_t len = reverse_adjlist_[vid].succ.size();
  IdArray src = IdArray::Empty({len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
  IdArray dst = IdArray::Empty({len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
  IdArray eid = IdArray::Empty({len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
  int64_t* src_data = static_cast<int64_t*>(src->data);
  int64_t* dst_data = static_cast<int64_t*>(dst->data);
  int64_t* eid_data = static_cast<int64_t*>(eid->data);
  for (int64_t i = 0; i < len; ++i) {
    // reverse adjlist stores predecessors in its `succ` field
    src_data[i] = reverse_adjlist_[vid].succ[i];
    eid_data[i] = reverse_adjlist_[vid].edge_id[i];
  }
  std::fill(dst_data, dst_data + len, vid);
  return EdgeArray{src, dst, eid};
}
// Return the concatenated in-edges of all vertices in `vids`, grouped in
// the order the vertices are given. Two passes: size, then fill.
// O(E)
Graph::EdgeArray Graph::InEdges(IdArray vids) const {
  CHECK(IsValidIdArray(vids)) << "Invalid vertex id array.";
  const auto len = vids->shape[0];
  const int64_t* vid_data = static_cast<int64_t*>(vids->data);
  // first pass: total number of in-edges (also validates the ids)
  int64_t rstlen = 0;
  for (int64_t i = 0; i < len; ++i) {
    CHECK(HasVertex(vid_data[i])) << "Invalid vertex: " << vid_data[i];
    rstlen += reverse_adjlist_[vid_data[i]].succ.size();
  }
  IdArray src = IdArray::Empty({rstlen}, vids->dtype, vids->ctx);
  IdArray dst = IdArray::Empty({rstlen}, vids->dtype, vids->ctx);
  IdArray eid = IdArray::Empty({rstlen}, vids->dtype, vids->ctx);
  int64_t* src_ptr = static_cast<int64_t*>(src->data);
  int64_t* dst_ptr = static_cast<int64_t*>(dst->data);
  int64_t* eid_ptr = static_cast<int64_t*>(eid->data);
  // second pass: write the triplets with bump pointers
  for (int64_t i = 0; i < len; ++i) {
    const auto& pred = reverse_adjlist_[vid_data[i]].succ;
    const auto& eids = reverse_adjlist_[vid_data[i]].edge_id;
    for (size_t j = 0; j < pred.size(); ++j) {
      *(src_ptr++) = pred[j];
      *(dst_ptr++) = vid_data[i];
      *(eid_ptr++) = eids[j];
    }
  }
  return EdgeArray{src, dst, eid};
}
// Return all out-edges of `vid` as (src, dst, eid); src is constant `vid`.
// O(out-degree).
Graph::EdgeArray Graph::OutEdges(dgl_id_t vid) const {
  CHECK(HasVertex(vid)) << "invalid vertex: " << vid;
  const auto& out = adjlist_[vid];
  const int64_t num_out = out.succ.size();
  IdArray src = IdArray::Empty({num_out}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
  IdArray dst = IdArray::Empty({num_out}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
  IdArray eid = IdArray::Empty({num_out}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
  int64_t* src_out = static_cast<int64_t*>(src->data);
  int64_t* dst_out = static_cast<int64_t*>(dst->data);
  int64_t* eid_out = static_cast<int64_t*>(eid->data);
  // single pass: the source column is simply `vid` repeated
  for (int64_t k = 0; k < num_out; ++k) {
    src_out[k] = vid;
    dst_out[k] = out.succ[k];
    eid_out[k] = out.edge_id[k];
  }
  return EdgeArray{src, dst, eid};
}
// Return the concatenated out-edges of all vertices in `vids`, grouped in
// the order the vertices are given. Two passes: size, then fill.
// O(E)
Graph::EdgeArray Graph::OutEdges(IdArray vids) const {
  CHECK(IsValidIdArray(vids)) << "Invalid vertex id array.";
  const auto len = vids->shape[0];
  const int64_t* vid_data = static_cast<int64_t*>(vids->data);
  // first pass: total number of out-edges (also validates the ids)
  int64_t rstlen = 0;
  for (int64_t i = 0; i < len; ++i) {
    CHECK(HasVertex(vid_data[i])) << "Invalid vertex: " << vid_data[i];
    rstlen += adjlist_[vid_data[i]].succ.size();
  }
  IdArray src = IdArray::Empty({rstlen}, vids->dtype, vids->ctx);
  IdArray dst = IdArray::Empty({rstlen}, vids->dtype, vids->ctx);
  IdArray eid = IdArray::Empty({rstlen}, vids->dtype, vids->ctx);
  int64_t* src_ptr = static_cast<int64_t*>(src->data);
  int64_t* dst_ptr = static_cast<int64_t*>(dst->data);
  int64_t* eid_ptr = static_cast<int64_t*>(eid->data);
  // second pass: write the triplets with bump pointers
  for (int64_t i = 0; i < len; ++i) {
    const auto& succ = adjlist_[vid_data[i]].succ;
    const auto& eids = adjlist_[vid_data[i]].edge_id;
    for (size_t j = 0; j < succ.size(); ++j) {
      *(src_ptr++) = vid_data[i];
      *(dst_ptr++) = succ[j];
      *(eid_ptr++) = eids[j];
    }
  }
  return EdgeArray{src, dst, eid};
}
// O(E*log(E)) if sort is required; otherwise, O(E)
Graph::EdgeArray Graph::Edges(bool sorted) const {
  const int64_t len = num_edges_;
  IdArray src = IdArray::Empty({len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
  IdArray dst = IdArray::Empty({len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
  IdArray eid = IdArray::Empty({len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
  int64_t* src_out = static_cast<int64_t*>(src->data);
  int64_t* dst_out = static_cast<int64_t*>(dst->data);
  int64_t* eid_out = static_cast<int64_t*>(eid->data);
  if (sorted) {
    // Materialize (src, dst, eid) triples and order them by (src, dst); the
    // edge id is carried along but deliberately excluded from the sort key.
    typedef std::tuple<int64_t, int64_t, int64_t> Triple;
    std::vector<Triple> triples;
    triples.reserve(len);
    for (uint64_t e = 0; e < num_edges_; ++e) {
      triples.emplace_back(all_edges_src_[e], all_edges_dst_[e], e);
    }
    std::sort(triples.begin(), triples.end(),
        [] (const Triple& lhs, const Triple& rhs) {
          return std::tie(std::get<0>(lhs), std::get<1>(lhs))
              < std::tie(std::get<0>(rhs), std::get<1>(rhs));
        });
    for (size_t i = 0; i < triples.size(); ++i) {
      src_out[i] = std::get<0>(triples[i]);
      dst_out[i] = std::get<1>(triples[i]);
      eid_out[i] = std::get<2>(triples[i]);
    }
  } else {
    // Unsorted: edges are emitted in insertion order, so edge ids are just
    // the running index 0..E-1.
    std::copy(all_edges_src_.begin(), all_edges_src_.end(), src_out);
    std::copy(all_edges_dst_.begin(), all_edges_dst_.end(), dst_out);
    for (uint64_t e = 0; e < num_edges_; ++e) {
      eid_out[e] = e;
    }
  }
  return EdgeArray{src, dst, eid};
}
// O(V)
DegreeArray Graph::InDegrees(IdArray vids) const {
  CHECK(IsValidIdArray(vids)) << "Invalid vertex id array.";
  const auto num_query = vids->shape[0];
  const int64_t* query = static_cast<int64_t*>(vids->data);
  DegreeArray rst = DegreeArray::Empty({num_query}, vids->dtype, vids->ctx);
  int64_t* out = static_cast<int64_t*>(rst->data);
  // The in-degree of v is the length of v's predecessor list.
  for (int64_t i = 0; i < num_query; ++i) {
    const auto v = query[i];
    CHECK(HasVertex(v)) << "Invalid vertex: " << v;
    out[i] = reverse_adjlist_[v].succ.size();
  }
  return rst;
}
// O(V)
DegreeArray Graph::OutDegrees(IdArray vids) const {
  CHECK(IsValidIdArray(vids)) << "Invalid vertex id array.";
  const auto num_query = vids->shape[0];
  const int64_t* query = static_cast<int64_t*>(vids->data);
  DegreeArray rst = DegreeArray::Empty({num_query}, vids->dtype, vids->ctx);
  int64_t* out = static_cast<int64_t*>(rst->data);
  // The out-degree of v is the length of v's successor list.
  for (int64_t i = 0; i < num_query; ++i) {
    const auto v = query[i];
    CHECK(HasVertex(v)) << "Invalid vertex: " << v;
    out[i] = adjlist_[v].succ.size();
  }
  return rst;
}
// Build the subgraph induced by the given vertices.
//
// Vertices are relabeled to 0..len-1 following their order in `vids`.  An
// edge of the parent graph is kept iff both endpoints are in `vids`;
// `induced_edges` records the parent edge id of every kept edge, in the
// order the edges are added to the subgraph.
Subgraph Graph::VertexSubgraph(IdArray vids) const {
  CHECK(IsValidIdArray(vids)) << "Invalid vertex id array.";
  const auto len = vids->shape[0];
  std::unordered_map<dgl_id_t, dgl_id_t> oldv2newv;
  std::vector<dgl_id_t> edges;
  const int64_t* vid_data = static_cast<int64_t*>(vids->data);
  for (int64_t i = 0; i < len; ++i) {
    oldv2newv[vid_data[i]] = i;
  }
  Subgraph rst;
  rst.induced_vertices = vids;
  rst.graph.AddVertices(len);
  for (int64_t i = 0; i < len; ++i) {
    const dgl_id_t oldvid = vid_data[i];
    const dgl_id_t newvid = i;
    for (size_t j = 0; j < adjlist_[oldvid].succ.size(); ++j) {
      const dgl_id_t oldsucc = adjlist_[oldvid].succ[j];
      // Single find() instead of count() followed by operator[] — one hash
      // lookup per successor rather than two.
      const auto it = oldv2newv.find(oldsucc);
      if (it != oldv2newv.end()) {
        const dgl_id_t newsucc = it->second;
        edges.push_back(adjlist_[oldvid].edge_id[j]);
        rst.graph.AddEdge(newvid, newsucc);
      }
    }
  }
  rst.induced_edges = IdArray::Empty({static_cast<int64_t>(edges.size())}, vids->dtype, vids->ctx);
  std::copy(edges.begin(), edges.end(), static_cast<int64_t*>(rst.induced_edges->data));
  return rst;
}
// Build the subgraph induced by the given edges.
//
// Each edge id in `eids` contributes its two endpoints to the node set (in
// first-appearance order) and one edge to the subgraph, preserving the
// order of `eids`.  `induced_vertices` maps new vertex ids back to parent
// vertex ids.
Subgraph Graph::EdgeSubgraph(IdArray eids) const {
  // Fixed copy-pasted message: this validates *edge* ids, not vertex ids.
  CHECK(IsValidIdArray(eids)) << "Invalid edge id array.";
  const auto len = eids->shape[0];
  std::unordered_map<dgl_id_t, dgl_id_t> oldv2newv;
  std::vector<dgl_id_t> nodes;
  const int64_t* eid_data = static_cast<int64_t*>(eids->data);
  for (int64_t i = 0; i < len; ++i) {
    dgl_id_t src_id = all_edges_src_[eid_data[i]];
    dgl_id_t dst_id = all_edges_dst_[eid_data[i]];
    // The pair argument (and thus size()) is evaluated before insert runs,
    // so each endpoint receives the next consecutive new id on first sight.
    if (oldv2newv.insert(std::make_pair(src_id, oldv2newv.size())).second)
      nodes.push_back(src_id);
    if (oldv2newv.insert(std::make_pair(dst_id, oldv2newv.size())).second)
      nodes.push_back(dst_id);
  }
  Subgraph rst;
  rst.induced_edges = eids;
  rst.graph.AddVertices(nodes.size());
  for (int64_t i = 0; i < len; ++i) {
    dgl_id_t src_id = all_edges_src_[eid_data[i]];
    dgl_id_t dst_id = all_edges_dst_[eid_data[i]];
    // at() instead of operator[]: both keys are guaranteed present and the
    // map must not grow here.
    rst.graph.AddEdge(oldv2newv.at(src_id), oldv2newv.at(dst_id));
  }
  rst.induced_vertices = IdArray::Empty(
      {static_cast<int64_t>(nodes.size())}, eids->dtype, eids->ctx);
  std::copy(nodes.begin(), nodes.end(), static_cast<int64_t*>(rst.induced_vertices->data));
  return rst;
}
// Return a graph with all edges reversed.
// TODO: not implemented yet — calling this aborts via LOG(FATAL).
Graph Graph::Reverse() const {
  LOG(FATAL) << "not implemented";
  return *this;  // unreachable; keeps the signature's return path well-formed
}
} // namespace dgl
/*!
* Copyright (c) 2018 by Contributors
* \file graph/graph.cc
* \brief DGL graph index APIs
*/
#include <dgl/graph.h>
#include <dgl/graph_op.h>
#include "../c_api_common.h"
using tvm::runtime::TVMArgs;
using tvm::runtime::TVMArgValue;
using tvm::runtime::TVMRetValue;
using tvm::runtime::PackedFunc;
using tvm::runtime::NDArray;
namespace dgl {
namespace {
// Convert EdgeArray structure to PackedFunc.
//
// The returned PackedFunc is a getter: invoking it with 0/1/2 returns the
// src/dst/id array respectively.  The lambda captures `ea` by value so the
// arrays stay alive for the lifetime of the PackedFunc.
PackedFunc ConvertEdgeArrayToPackedFunc(const Graph::EdgeArray& ea) {
  auto body = [ea] (TVMArgs args, TVMRetValue* rv) {
    const int which = args[0];
    // NOTE: the previous code wrapped these in std::move(); since `ea` is a
    // const capture in a non-mutable lambda that "move" was silently a copy.
    // A plain copy is intended — the getter must stay callable repeatedly.
    if (which == 0) {
      *rv = ea.src;
    } else if (which == 1) {
      *rv = ea.dst;
    } else if (which == 2) {
      *rv = ea.id;
    } else {
      LOG(FATAL) << "invalid choice";
    }
  };
  return PackedFunc(body);
}
// Convert Subgraph structure to PackedFunc.
//
// The returned PackedFunc is a getter: 0 hands out a heap-allocated copy of
// the subgraph (caller owns the handle and frees it via _CAPI_DGLGraphFree),
// 1 returns the induced vertex ids, 2 the induced edge ids.
PackedFunc ConvertSubgraphToPackedFunc(const Subgraph& sg) {
  auto body = [sg] (TVMArgs args, TVMRetValue* rv) {
    const int which = args[0];
    if (which == 0) {
      Graph* gptr = new Graph();
      // The previous std::move(sg.graph) was a no-op: `sg` is a const
      // capture, so a copy happened anyway.  Copying is required here —
      // the getter may be invoked more than once.
      *gptr = sg.graph;
      GraphHandle ghandle = gptr;
      *rv = ghandle;
    } else if (which == 1) {
      *rv = sg.induced_vertices;
    } else if (which == 2) {
      *rv = sg.induced_edges;
    } else {
      LOG(FATAL) << "invalid choice";
    }
  };
  return PackedFunc(body);
}
} // namespace
// Create an empty graph; args = (is-multigraph flag).  Returns a new
// GraphHandle owned by the caller (release with _CAPI_DGLGraphFree).
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphCreate")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
    bool multigraph = static_cast<bool>(args[0]);
    GraphHandle ghandle = new Graph(multigraph);
    *rv = ghandle;
  });

// Destroy a graph handle previously returned by one of these C API entries.
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphFree")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
    GraphHandle ghandle = args[0];
    Graph* gptr = static_cast<Graph*>(ghandle);
    delete gptr;
  });

// Append args[1] new (isolated) vertices to the graph.
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphAddVertices")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
    GraphHandle ghandle = args[0];
    Graph* gptr = static_cast<Graph*>(ghandle);
    uint64_t num_vertices = args[1];
    gptr->AddVertices(num_vertices);
  });

// Add a single edge src -> dst.
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphAddEdge")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
    GraphHandle ghandle = args[0];
    Graph* gptr = static_cast<Graph*>(ghandle);
    const dgl_id_t src = args[1];
    const dgl_id_t dst = args[2];
    gptr->AddEdge(src, dst);
  });

// Add a batch of edges given as two id arrays (passed as DLPack tensors).
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphAddEdges")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
    GraphHandle ghandle = args[0];
    Graph* gptr = static_cast<Graph*>(ghandle);
    const IdArray src = IdArray::FromDLPack(CreateTmpDLManagedTensor(args[1]));
    const IdArray dst = IdArray::FromDLPack(CreateTmpDLManagedTensor(args[2]));
    gptr->AddEdges(src, dst);
  });

// Remove all vertices and edges from the graph.
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphClear")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
    GraphHandle ghandle = args[0];
    Graph* gptr = static_cast<Graph*>(ghandle);
    gptr->Clear();
  });

// Query whether the graph allows parallel (duplicate) edges.
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphIsMultigraph")
.set_body([] (TVMArgs args, TVMRetValue *rv) {
    GraphHandle ghandle = args[0];
    // Read-only query; the stale "not const since we have caches" note that
    // used to be here contradicted the const pointer below.
    const Graph* gptr = static_cast<Graph*>(ghandle);
    *rv = gptr->IsMultigraph();
  });
// Number of vertices, returned as int64.
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphNumVertices")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
    GraphHandle ghandle = args[0];
    const Graph* gptr = static_cast<Graph*>(ghandle);
    *rv = static_cast<int64_t>(gptr->NumVertices());
  });

// Number of edges, returned as int64.
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphNumEdges")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
    GraphHandle ghandle = args[0];
    const Graph* gptr = static_cast<Graph*>(ghandle);
    *rv = static_cast<int64_t>(gptr->NumEdges());
  });

// Whether vertex args[1] exists.
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphHasVertex")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
    GraphHandle ghandle = args[0];
    const Graph* gptr = static_cast<Graph*>(ghandle);
    const dgl_id_t vid = args[1];
    *rv = gptr->HasVertex(vid);
  });

// Vectorized existence check over an id array.
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphHasVertices")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
    GraphHandle ghandle = args[0];
    const Graph* gptr = static_cast<Graph*>(ghandle);
    const IdArray vids = IdArray::FromDLPack(CreateTmpDLManagedTensor(args[1]));
    *rv = gptr->HasVertices(vids);
  });

// Whether at least one edge src -> dst exists.
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphHasEdgeBetween")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
    GraphHandle ghandle = args[0];
    const Graph* gptr = static_cast<Graph*>(ghandle);
    const dgl_id_t src = args[1];
    const dgl_id_t dst = args[2];
    *rv = gptr->HasEdgeBetween(src, dst);
  });

// Vectorized pairwise edge-existence check.
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphHasEdgesBetween")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
    GraphHandle ghandle = args[0];
    const Graph* gptr = static_cast<Graph*>(ghandle);
    const IdArray src = IdArray::FromDLPack(CreateTmpDLManagedTensor(args[1]));
    const IdArray dst = IdArray::FromDLPack(CreateTmpDLManagedTensor(args[2]));
    *rv = gptr->HasEdgesBetween(src, dst);
  });

// Predecessors of vertex args[1] within args[2] hops.
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphPredecessors")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
    GraphHandle ghandle = args[0];
    const Graph* gptr = static_cast<Graph*>(ghandle);
    const dgl_id_t vid = args[1];
    const uint64_t radius = args[2];
    *rv = gptr->Predecessors(vid, radius);
  });

// Successors of vertex args[1] within args[2] hops.
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphSuccessors")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
    GraphHandle ghandle = args[0];
    const Graph* gptr = static_cast<Graph*>(ghandle);
    const dgl_id_t vid = args[1];
    const uint64_t radius = args[2];
    *rv = gptr->Successors(vid, radius);
  });
// Id of the edge src -> dst.
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphEdgeId")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
    GraphHandle ghandle = args[0];
    const Graph* gptr = static_cast<Graph*>(ghandle);
    const dgl_id_t src = args[1];
    const dgl_id_t dst = args[2];
    *rv = gptr->EdgeId(src, dst);
  });

// Vectorized edge id lookup; returns a getter PackedFunc over the
// (src, dst, id) triple arrays (see ConvertEdgeArrayToPackedFunc).
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphEdgeIds")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
    GraphHandle ghandle = args[0];
    const Graph* gptr = static_cast<Graph*>(ghandle);
    const IdArray src = IdArray::FromDLPack(CreateTmpDLManagedTensor(args[1]));
    const IdArray dst = IdArray::FromDLPack(CreateTmpDLManagedTensor(args[2]));
    *rv = ConvertEdgeArrayToPackedFunc(gptr->EdgeIds(src, dst));
  });

// Resolve edge ids back to their (src, dst) endpoints.
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphFindEdges")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
    GraphHandle ghandle = args[0];
    const Graph* gptr = static_cast<Graph*>(ghandle);
    const IdArray eids = IdArray::FromDLPack(CreateTmpDLManagedTensor(args[1]));
    *rv = ConvertEdgeArrayToPackedFunc(gptr->FindEdges(eids));
  });

// In-edges of a single vertex (the _1 suffix marks the scalar overload).
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphInEdges_1")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
    GraphHandle ghandle = args[0];
    const Graph* gptr = static_cast<Graph*>(ghandle);
    const dgl_id_t vid = args[1];
    *rv = ConvertEdgeArrayToPackedFunc(gptr->InEdges(vid));
  });

// In-edges of a batch of vertices (the _2 suffix marks the array overload).
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphInEdges_2")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
    GraphHandle ghandle = args[0];
    const Graph* gptr = static_cast<Graph*>(ghandle);
    const IdArray vids = IdArray::FromDLPack(CreateTmpDLManagedTensor(args[1]));
    *rv = ConvertEdgeArrayToPackedFunc(gptr->InEdges(vids));
  });

// Out-edges of a single vertex.
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphOutEdges_1")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
    GraphHandle ghandle = args[0];
    const Graph* gptr = static_cast<Graph*>(ghandle);
    const dgl_id_t vid = args[1];
    *rv = ConvertEdgeArrayToPackedFunc(gptr->OutEdges(vid));
  });

// Out-edges of a batch of vertices.
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphOutEdges_2")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
    GraphHandle ghandle = args[0];
    const Graph* gptr = static_cast<Graph*>(ghandle);
    const IdArray vids = IdArray::FromDLPack(CreateTmpDLManagedTensor(args[1]));
    *rv = ConvertEdgeArrayToPackedFunc(gptr->OutEdges(vids));
  });

// All edges; args[1] requests (src, dst)-sorted order.
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphEdges")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
    GraphHandle ghandle = args[0];
    const Graph* gptr = static_cast<Graph*>(ghandle);
    const bool sorted = args[1];
    *rv = ConvertEdgeArrayToPackedFunc(gptr->Edges(sorted));
  });
// In-degree of a single vertex.
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphInDegree")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
    GraphHandle ghandle = args[0];
    const Graph* gptr = static_cast<Graph*>(ghandle);
    const dgl_id_t vid = args[1];
    *rv = static_cast<int64_t>(gptr->InDegree(vid));
  });

// In-degrees of a batch of vertices.
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphInDegrees")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
    GraphHandle ghandle = args[0];
    const Graph* gptr = static_cast<Graph*>(ghandle);
    const IdArray vids = IdArray::FromDLPack(CreateTmpDLManagedTensor(args[1]));
    *rv = gptr->InDegrees(vids);
  });

// Out-degree of a single vertex.
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphOutDegree")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
    GraphHandle ghandle = args[0];
    const Graph* gptr = static_cast<Graph*>(ghandle);
    const dgl_id_t vid = args[1];
    *rv = static_cast<int64_t>(gptr->OutDegree(vid));
  });

// Out-degrees of a batch of vertices.
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphOutDegrees")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
    GraphHandle ghandle = args[0];
    const Graph* gptr = static_cast<Graph*>(ghandle);
    const IdArray vids = IdArray::FromDLPack(CreateTmpDLManagedTensor(args[1]));
    *rv = gptr->OutDegrees(vids);
  });

// Vertex-induced subgraph; returns a getter PackedFunc (see
// ConvertSubgraphToPackedFunc) whose graph handle the caller owns.
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphVertexSubgraph")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
    GraphHandle ghandle = args[0];
    const Graph* gptr = static_cast<Graph*>(ghandle);
    const IdArray vids = IdArray::FromDLPack(CreateTmpDLManagedTensor(args[1]));
    *rv = ConvertSubgraphToPackedFunc(gptr->VertexSubgraph(vids));
  });

// Edge-induced subgraph; same return convention as VertexSubgraph.
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphEdgeSubgraph")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
    GraphHandle ghandle = args[0];
    const Graph *gptr = static_cast<Graph*>(ghandle);
    const IdArray eids = IdArray::FromDLPack(CreateTmpDLManagedTensor(args[1]));
    *rv = ConvertSubgraphToPackedFunc(gptr->EdgeSubgraph(eids));
  });
// Disjoint union of a list of graphs; args = (handle array, count).
// The new handle is owned by the caller (free with _CAPI_DGLGraphFree).
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLDisjointUnion")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
    void* list = args[0];
    GraphHandle* inhandles = static_cast<GraphHandle*>(list);
    int list_size = args[1];
    std::vector<const Graph*> graphs;
    for (int i = 0; i < list_size; ++i) {
      const Graph* gr = static_cast<const Graph*>(inhandles[i]);
      graphs.push_back(gr);
    }
    Graph* gptr = new Graph();
    *gptr = GraphOp::DisjointUnion(std::move(graphs));
    GraphHandle ghandle = gptr;
    *rv = ghandle;
  });

// Split a graph into args[1] equal-size partitions.
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLDisjointPartitionByNum")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
    GraphHandle ghandle = args[0];
    const Graph* gptr = static_cast<Graph*>(ghandle);
    int64_t num = args[1];
    std::vector<Graph>&& rst = GraphOp::DisjointPartitionByNum(gptr, num);
    // return the pointer array as an integer array
    // NOTE(review): each element is a heap-allocated Graph* smuggled as
    // int64; the caller is responsible for freeing every handle.
    const int64_t len = rst.size();
    NDArray ptr_array = NDArray::Empty({len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
    int64_t* ptr_array_data = static_cast<int64_t*>(ptr_array->data);
    for (size_t i = 0; i < rst.size(); ++i) {
      Graph* ptr = new Graph();
      *ptr = std::move(rst[i]);
      ptr_array_data[i] = reinterpret_cast<std::intptr_t>(ptr);
    }
    *rv = ptr_array;
  });

// Split a graph into partitions with the given vertex counts.
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLDisjointPartitionBySizes")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
    GraphHandle ghandle = args[0];
    const Graph* gptr = static_cast<Graph*>(ghandle);
    const IdArray sizes = IdArray::FromDLPack(CreateTmpDLManagedTensor(args[1]));
    std::vector<Graph>&& rst = GraphOp::DisjointPartitionBySizes(gptr, sizes);
    // return the pointer array as an integer array
    // Same ownership convention as DisjointPartitionByNum above.
    const int64_t len = rst.size();
    NDArray ptr_array = NDArray::Empty({len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
    int64_t* ptr_array_data = static_cast<int64_t*>(ptr_array->data);
    for (size_t i = 0; i < rst.size(); ++i) {
      Graph* ptr = new Graph();
      *ptr = std::move(rst[i]);
      ptr_array_data[i] = reinterpret_cast<std::intptr_t>(ptr);
    }
    *rv = ptr_array;
  });

// Line graph of the given graph; args[1] toggles backtracking edges.
// The returned handle is owned by the caller.
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphLineGraph")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
    GraphHandle ghandle = args[0];
    bool backtracking = args[1];
    const Graph* gptr = static_cast<Graph*>(ghandle);
    Graph* lgptr = new Graph();
    *lgptr = GraphOp::LineGraph(gptr, backtracking);
    GraphHandle lghandle = lgptr;
    *rv = lghandle;
  });
} // namespace dgl
/*!
* Copyright (c) 2018 by Contributors
* \file graph/graph.cc
* \brief Graph operation implementation
*/
#include <dgl/graph_op.h>
#include <algorithm>
namespace dgl {
// Construct the line graph L(g): one vertex per edge of g, and an edge from
// edge i=(u, v) to edge j=(v, w) whenever i's destination equals j's source.
// When `backtracking` is false, the successor edge (v, u) that immediately
// returns to u is excluded.
Graph GraphOp::LineGraph(const Graph* g, bool backtracking) {
  typedef std::pair<dgl_id_t, dgl_id_t> entry;
  typedef std::map<dgl_id_t, std::vector<entry>> csm;  // Compressed Sparse Matrix
  // adj[u] lists (v, eid) for every edge u -> v of g.
  csm adj;
  std::vector<entry> vec;
  for (size_t i = 0; i != g->all_edges_src_.size(); ++i) {
    auto u = g->all_edges_src_[i];
    auto v = g->all_edges_dst_[i];
    auto ret = adj.insert(csm::value_type(u, vec));
    (ret.first)->second.push_back(std::make_pair(v, i));
  }
  // For each edge (u, v), connect it to every edge leaving v.
  std::vector<dgl_id_t> lg_src, lg_dst;
  for (size_t i = 0; i != g->all_edges_src_.size(); ++i) {
    auto u = g->all_edges_src_[i];
    auto v = g->all_edges_dst_[i];
    auto j = adj.find(v);
    if (j != adj.end()) {
      for (size_t k = 0; k != j->second.size(); ++k) {
        // Simplified from `backtracking || (!backtracking && cond)`,
        // which is logically equivalent.
        if (backtracking || j->second[k].first != u) {
          lg_src.push_back(i);
          lg_dst.push_back(j->second[k].second);
        }
      }
    }
  }
  const int64_t len = lg_src.size();
  IdArray src = IdArray::Empty({len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
  IdArray dst = IdArray::Empty({len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
  int64_t* src_ptr = static_cast<int64_t*>(src->data);
  int64_t* dst_ptr = static_cast<int64_t*>(dst->data);
  std::copy(lg_src.begin(), lg_src.end(), src_ptr);
  std::copy(lg_dst.begin(), lg_dst.end(), dst_ptr);
  Graph lg;
  lg.AddVertices(g->NumEdges());
  lg.AddEdges(src, dst);
  return lg;
}
// Concatenate the given graphs into one; the vertex ids of graph k are
// shifted up by the total vertex count of graphs 0..k-1.
Graph GraphOp::DisjointUnion(std::vector<const Graph*> graphs) {
  Graph rst;
  uint64_t offset = 0;
  for (const Graph* gr : graphs) {
    rst.AddVertices(gr->NumVertices());
    const uint64_t num_edges = gr->NumEdges();
    for (uint64_t e = 0; e < num_edges; ++e) {
      rst.AddEdge(gr->all_edges_src_[e] + offset, gr->all_edges_dst_[e] + offset);
    }
    offset += gr->NumVertices();
  }
  return rst;
}
// Split `graph` into `num` partitions of equal vertex count.
//
// \param graph The graph to partition (read-only).
// \param num Number of partitions; must be positive and evenly divide
//            NumVertices().
// \return One Graph per partition (see DisjointPartitionBySizes).
std::vector<Graph> GraphOp::DisjointPartitionByNum(const Graph* graph, int64_t num) {
  // Require num > 0 explicitly: the previous `num != 0` check let negative
  // values fall through into a signed/unsigned modulo below.
  CHECK(num > 0 && graph->NumVertices() % num == 0)
    << "Number of partitions must evenly divide the number of nodes.";
  IdArray sizes = IdArray::Empty({num}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
  int64_t* sizes_data = static_cast<int64_t*>(sizes->data);
  std::fill(sizes_data, sizes_data + num, graph->NumVertices() / num);
  return DisjointPartitionBySizes(graph, sizes);
}
// Partition the graph into consecutive vertex ranges of the given sizes.
//
// Vertices [cumsum[i], cumsum[i+1]) and the edges among them become
// partition i; vertex and edge ids are relabeled to start at 0 within each
// partition.  NOTE(review): this assumes the edges of partition i occupy the
// contiguous global-edge range [edge_offset, edge_offset + num_edges), i.e.
// the graph was built partition-by-partition (as DisjointUnion does) —
// confirm before using on arbitrarily constructed graphs.
std::vector<Graph> GraphOp::DisjointPartitionBySizes(const Graph* graph, IdArray sizes) {
  const int64_t len = sizes->shape[0];
  const int64_t* sizes_data = static_cast<int64_t*>(sizes->data);
  // Prefix sums of the sizes: cumsum[i] is the first parent-graph vertex id
  // of partition i.
  std::vector<int64_t> cumsum;
  cumsum.push_back(0);
  for (int64_t i = 0; i < len; ++i) {
    cumsum.push_back(cumsum[i] + sizes_data[i]);
  }
  CHECK_EQ(cumsum[len], graph->NumVertices())
    << "Sum of the given sizes must equal to the number of nodes.";
  dgl_id_t node_offset = 0, edge_offset = 0;
  std::vector<Graph> rst(len);
  for (int64_t i = 0; i < len; ++i) {
    // copy adj: slice this partition's adjacency rows out of the parent.
    rst[i].adjlist_.insert(rst[i].adjlist_.end(),
      graph->adjlist_.begin() + node_offset,
      graph->adjlist_.begin() + node_offset + sizes_data[i]);
    rst[i].reverse_adjlist_.insert(rst[i].reverse_adjlist_.end(),
      graph->reverse_adjlist_.begin() + node_offset,
      graph->reverse_adjlist_.begin() + node_offset + sizes_data[i]);
    // relabel adjs: shift vertex/edge ids so the partition is 0-based;
    // count this partition's edges while walking the forward lists.
    size_t num_edges = 0;
    for (auto& elist : rst[i].adjlist_) {
      for (size_t j = 0; j < elist.succ.size(); ++j) {
        elist.succ[j] -= node_offset;
        elist.edge_id[j] -= edge_offset;
      }
      num_edges += elist.succ.size();
    }
    for (auto& elist : rst[i].reverse_adjlist_) {
      for (size_t j = 0; j < elist.succ.size(); ++j) {
        elist.succ[j] -= node_offset;
        elist.edge_id[j] -= edge_offset;
      }
    }
    // copy edges: the partition's edges form a contiguous slice of the
    // parent's global edge list (see the note in the header comment).
    rst[i].all_edges_src_.reserve(num_edges);
    rst[i].all_edges_dst_.reserve(num_edges);
    rst[i].num_edges_ = num_edges;
    for (size_t j = edge_offset; j < edge_offset + num_edges; ++j) {
      rst[i].all_edges_src_.push_back(graph->all_edges_src_[j] - node_offset);
      rst[i].all_edges_dst_.push_back(graph->all_edges_dst_[j] - node_offset);
    }
    // update offset
    CHECK_EQ(rst[i].NumVertices(), sizes_data[i]);
    CHECK_EQ(rst[i].NumEdges(), num_edges);
    node_offset += sizes_data[i];
    edge_offset += num_edges;
  }
  return rst;
}
} // namespace dgl
/*!
* Copyright (c) 2016 by Contributors
* \file c_runtime_api.cc
* \brief Runtime API implementation
*/
#include <dmlc/thread_local.h>
#include <dgl/runtime/c_runtime_api.h>
#include <dgl/runtime/c_backend_api.h>
#include <dgl/runtime/packed_func.h>
#include <dgl/runtime/module.h>
#include <dgl/runtime/registry.h>
#include <dgl/runtime/device_api.h>
#include <array>
#include <algorithm>
#include <string>
#include <cstdlib>
#include "runtime_base.h"
namespace tvm {
namespace runtime {
/*!
 * \brief The name of Device API factory.
 * \param type The device type.
 * \return The registry suffix, e.g. kDLCPU maps to "cpu" and is later
 *         resolved as the global factory "device_api.cpu".
 * \note Aborts via LOG(FATAL) for an unknown device type.
 */
inline std::string DeviceName(int type) {
  switch (type) {
    case kDLCPU: return "cpu";
    case kDLGPU: return "gpu";
    case kDLOpenCL: return "opencl";
    case kDLSDAccel: return "sdaccel";
    case kDLAOCL: return "aocl";
    case kDLVulkan: return "vulkan";
    case kDLMetal: return "metal";
    case kDLVPI: return "vpi";
    case kDLROCM: return "rocm";
    case kOpenGL: return "opengl";
    case kExtDev: return "ext_dev";
    default: LOG(FATAL) << "unknown type =" << type; return "Unknown";
  }
}
// Lazily constructs and caches one DeviceAPI instance per device type,
// resolved through the global registry entry "device_api.<name>".
class DeviceAPIManager {
 public:
  static const int kMaxDeviceAPI = 32;
  // Get API
  static DeviceAPI* Get(const TVMContext& ctx) {
    return Get(ctx.device_type);
  }
  static DeviceAPI* Get(int dev_type, bool allow_missing = false) {
    return Global()->GetAPI(dev_type, allow_missing);
  }
 private:
  std::array<DeviceAPI*, kMaxDeviceAPI> api_;  // cache indexed by device type
  DeviceAPI* rpc_api_{nullptr};                // cache for the RPC backend
  std::mutex mutex_;                           // guards cache initialization
  // constructor
  DeviceAPIManager() {
    std::fill(api_.begin(), api_.end(), nullptr);
  }
  // Global static variable.
  static DeviceAPIManager* Global() {
    static DeviceAPIManager inst;
    return &inst;
  }
  // Get or initialize API.
  // NOTE(review): the first cache read happens outside the mutex
  // (double-checked locking on plain pointers, no atomics) — presumably
  // benign on the supported platforms, but worth confirming.
  DeviceAPI* GetAPI(int type, bool allow_missing) {
    if (type < kRPCSessMask) {
      if (api_[type] != nullptr) return api_[type];
      std::lock_guard<std::mutex> lock(mutex_);
      if (api_[type] != nullptr) return api_[type];
      api_[type] = GetAPI(DeviceName(type), allow_missing);
      return api_[type];
    } else {
      // Device types at or above kRPCSessMask share the single "rpc" API.
      if (rpc_api_ != nullptr) return rpc_api_;
      std::lock_guard<std::mutex> lock(mutex_);
      if (rpc_api_ != nullptr) return rpc_api_;
      rpc_api_ = GetAPI("rpc", allow_missing);
      return rpc_api_;
    }
  }
  // Resolve a DeviceAPI factory from the registry; returns nullptr when the
  // backend is absent and allow_missing is set, otherwise CHECK-fails.
  DeviceAPI* GetAPI(const std::string name, bool allow_missing) {
    std::string factory = "device_api." + name;
    auto* f = Registry::Get(factory);
    if (f == nullptr) {
      CHECK(allow_missing)
        << "Device API " << name << " is not enabled.";
      return nullptr;
    }
    void* ptr = (*f)();
    return static_cast<DeviceAPI*>(ptr);
  }
};
// Resolve the DeviceAPI implementation for the device type of `ctx`.
// With allow_missing set, an absent backend yields nullptr instead of a
// CHECK failure (see DeviceAPIManager::GetAPI).
DeviceAPI* DeviceAPI::Get(TVMContext ctx, bool allow_missing) {
  const int dev_type = static_cast<int>(ctx.device_type);
  return DeviceAPIManager::Get(dev_type, allow_missing);
}
// Default workspace allocation: delegate to AllocDataSpace using the
// temporary-allocation alignment.
void* DeviceAPI::AllocWorkspace(TVMContext ctx,
                                size_t size,
                                TVMType type_hint) {
  return AllocDataSpace(ctx, size, kTempAllocaAlignment, type_hint);
}

// Default workspace free: mirror of AllocWorkspace.
void DeviceAPI::FreeWorkspace(TVMContext ctx, void* ptr) {
  FreeDataSpace(ctx, ptr);
}

// Stream APIs are optional; the base-class defaults abort for devices that
// do not override them.
TVMStreamHandle DeviceAPI::CreateStream(TVMContext ctx) {
  LOG(FATAL) << "Device does not support stream api.";
  return 0;
}

void DeviceAPI::FreeStream(TVMContext ctx, TVMStreamHandle stream) {
  LOG(FATAL) << "Device does not support stream api.";
}

void DeviceAPI::SyncStreamFromTo(TVMContext ctx,
                                 TVMStreamHandle event_src,
                                 TVMStreamHandle event_dst) {
  LOG(FATAL) << "Device does not support stream api.";
}
} // namespace runtime
} // namespace tvm
using namespace tvm::runtime;
// Per-thread storage that keeps API return payloads (strings/bytes) and the
// last error message alive across the C boundary.
struct TVMRuntimeEntry {
  std::string ret_str;     // backing store for string-typed returns
  std::string last_error;  // message set via TVMAPISetLastError
  TVMByteArray ret_bytes;  // backing store for byte-array returns
};

typedef dmlc::ThreadLocalStore<TVMRuntimeEntry> TVMAPIRuntimeStore;

// Return the last error recorded on the calling thread ("" if none).
const char *TVMGetLastError() {
  return TVMAPIRuntimeStore::Get()->last_error.c_str();
}
// Record an error message for the current thread; retrieved later through
// TVMGetLastError.
void TVMAPISetLastError(const char* msg) {
#ifndef _LIBCPP_SGX_CONFIG
  TVMAPIRuntimeStore::Get()->last_error = msg;
#else
  // Inside an SGX enclave there is no thread-local store; forward the
  // message to the untrusted host instead.
  sgx::OCallPackedFunc("__sgx_set_last_error__", msg);
#endif
}
// Load a module from file; ownership of the new Module is transferred to
// the caller via *out (release with TVMModFree).
int TVMModLoadFromFile(const char* file_name,
                       const char* format,
                       TVMModuleHandle* out) {
  API_BEGIN();
  Module m = Module::LoadFromFile(file_name, format);
  *out = new Module(m);
  API_END();
}

// Record `dep` as an import of `mod`.
int TVMModImport(TVMModuleHandle mod,
                 TVMModuleHandle dep) {
  API_BEGIN();
  static_cast<Module*>(mod)->Import(
    *static_cast<Module*>(dep));
  API_END();
}

// Look up a function in a module; *func is set to nullptr when the name is
// absent (this is not treated as an error).
int TVMModGetFunction(TVMModuleHandle mod,
                      const char* func_name,
                      int query_imports,
                      TVMFunctionHandle *func) {
  API_BEGIN();
  PackedFunc pf = static_cast<Module*>(mod)->GetFunction(
    func_name, query_imports != 0);
  if (pf != nullptr) {
    *func = new PackedFunc(pf);
  } else {
    *func = nullptr;
  }
  API_END();
}

// Destroy a module handle created by TVMModLoadFromFile.
int TVMModFree(TVMModuleHandle mod) {
  API_BEGIN();
  delete static_cast<Module*>(mod);
  API_END();
}
// Backend helper for generated code: look up a function from the module's
// environment by name.
int TVMBackendGetFuncFromEnv(void* mod_node,
                             const char* func_name,
                             TVMFunctionHandle *func) {
  API_BEGIN();
  *func = (TVMFunctionHandle)(
    static_cast<ModuleNode*>(mod_node)->GetFuncFromEnv(func_name));
  API_END();
}

// Allocate temporary workspace memory on the given device.  The dtype code
// and bit-width hints are forwarded so device APIs can pick a suitable
// allocation strategy; lanes is fixed to 1.
void* TVMBackendAllocWorkspace(int device_type,
                               int device_id,
                               uint64_t size,
                               int dtype_code_hint,
                               int dtype_bits_hint) {
  TVMContext ctx;
  ctx.device_type = static_cast<DLDeviceType>(device_type);
  ctx.device_id = device_id;
  TVMType type_hint;
  type_hint.code = static_cast<decltype(type_hint.code)>(dtype_code_hint);
  type_hint.bits = static_cast<decltype(type_hint.bits)>(dtype_bits_hint);
  type_hint.lanes = 1;
  return DeviceAPIManager::Get(ctx)->AllocWorkspace(ctx,
                                                    static_cast<size_t>(size),
                                                    type_hint);
}

// Free memory obtained from TVMBackendAllocWorkspace.  Always returns 0.
int TVMBackendFreeWorkspace(int device_type,
                            int device_id,
                            void* ptr) {
  TVMContext ctx;
  ctx.device_type = static_cast<DLDeviceType>(device_type);
  ctx.device_id = device_id;
  DeviceAPIManager::Get(ctx)->FreeWorkspace(ctx, ptr);
  return 0;
}
/*!
 * \brief Run f(cdata) exactly once per handle.
 *
 * The handle doubles as the "already ran" flag: a null handle means the
 * body has not executed yet.  Returns f's result on the first call and 0 on
 * every subsequent call.  `nbytes` is unused.
 */
int TVMBackendRunOnce(void** handle,
                      int (*f)(void*),
                      void* cdata,
                      int nbytes) {
  // Fast path: a non-null handle marks a previous execution.
  if (*handle != nullptr) {
    return 0;
  }
  // Tag the handle before invoking f, matching the original ordering.
  *handle = reinterpret_cast<void*>(1);
  return f(cdata);
}
// Destroy a PackedFunc handle created by the runtime.
int TVMFuncFree(TVMFunctionHandle func) {
  API_BEGIN();
  delete static_cast<PackedFunc*>(func);
  API_END();
}

// Invoke a PackedFunc through the C ABI.
//
// String / TVMType / byte-array returns cannot be handed to C callers as
// temporaries, so their contents are copied into this thread's
// TVMRuntimeEntry and returned as pointers into that thread-local storage
// (valid until the next call on the same thread).  Other values are moved
// out directly via MoveToCHost.
int TVMFuncCall(TVMFunctionHandle func,
                TVMValue* args,
                int* arg_type_codes,
                int num_args,
                TVMValue* ret_val,
                int* ret_type_code) {
  API_BEGIN();
  TVMRetValue rv;
  (*static_cast<const PackedFunc*>(func)).CallPacked(
    TVMArgs(args, arg_type_codes, num_args), &rv);
  // handle return string.
  if (rv.type_code() == kStr ||
    rv.type_code() == kTVMType ||
    rv.type_code() == kBytes) {
    TVMRuntimeEntry* e = TVMAPIRuntimeStore::Get();
    // kTVMType is rendered through its string conversion operator; the
    // other two codes already carry a std::string payload.
    if (rv.type_code() != kTVMType) {
      e->ret_str = *rv.ptr<std::string>();
    } else {
      e->ret_str = rv.operator std::string();
    }
    if (rv.type_code() == kBytes) {
      // Byte arrays point into ret_str through the ret_bytes descriptor.
      e->ret_bytes.data = e->ret_str.c_str();
      e->ret_bytes.size = e->ret_str.length();
      *ret_type_code = kBytes;
      ret_val->v_handle = &(e->ret_bytes);
    } else {
      *ret_type_code = kStr;
      ret_val->v_str = e->ret_str.c_str();
    }
  } else {
    rv.MoveToCHost(ret_val, ret_type_code);
  }
  API_END();
}

// Set the return value of a callback from C; exactly one value is allowed.
int TVMCFuncSetReturn(TVMRetValueHandle ret,
                      TVMValue* value,
                      int* type_code,
                      int num_ret) {
  API_BEGIN();
  CHECK_EQ(num_ret, 1);
  TVMRetValue* rv = static_cast<TVMRetValue*>(ret);
  *rv = TVMArgValue(value[0], type_code[0]);
  API_END();
}
// Wrap a C callback into a PackedFunc.
//
// \param func The C function to invoke on every call.
// \param resource_handle Opaque state passed back to `func`.
// \param fin Optional finalizer for `resource_handle`.
// \param out Receives a new PackedFunc handle (release with TVMFuncFree).
//
// A non-zero return from `func` is converted into a dmlc::Error carrying
// the thread's last error message.
int TVMFuncCreateFromCFunc(TVMPackedCFunc func,
                           void* resource_handle,
                           TVMPackedCFuncFinalizer fin,
                           TVMFunctionHandle *out) {
  API_BEGIN();
  if (fin == nullptr) {
    // No finalizer: capture the raw handle directly.
    *out = new PackedFunc(
      [func, resource_handle](TVMArgs args, TVMRetValue* rv) {
        int ret = func((TVMValue*)args.values, (int*)args.type_codes, // NOLINT(*)
                       args.num_args, rv, resource_handle);
        if (ret != 0) {
          std::string err = "TVMCall CFunc Error:\n";
          err += TVMGetLastError();
          throw dmlc::Error(err);
        }
      });
  } else {
    // wrap it in a shared_ptr, with fin as deleter.
    // so fin will be called when the lambda went out of scope.
    std::shared_ptr<void> rpack(resource_handle, fin);
    *out = new PackedFunc(
      [func, rpack](TVMArgs args, TVMRetValue* rv) {
        int ret = func((TVMValue*)args.values, (int*)args.type_codes, // NOLINT(*)
                       args.num_args, rv, rpack.get());
        if (ret != 0) {
          std::string err = "TVMCall CFunc Error:\n";
          err += TVMGetLastError();
          throw dmlc::Error(err);
        }
      });
  }
  API_END();
}
// Create a device stream; aborts for devices without stream support (see
// DeviceAPI::CreateStream).
int TVMStreamCreate(int device_type, int device_id, TVMStreamHandle* out) {
  API_BEGIN();
  TVMContext ctx;
  ctx.device_type = static_cast<DLDeviceType>(device_type);
  ctx.device_id = device_id;
  *out = DeviceAPIManager::Get(ctx)->CreateStream(ctx);
  API_END();
}

// Destroy a stream created by TVMStreamCreate.
int TVMStreamFree(int device_type, int device_id, TVMStreamHandle stream) {
  API_BEGIN();
  TVMContext ctx;
  ctx.device_type = static_cast<DLDeviceType>(device_type);
  ctx.device_id = device_id;
  DeviceAPIManager::Get(ctx)->FreeStream(ctx, stream);
  API_END();
}

// Select the stream subsequent work on this device is submitted to.
int TVMSetStream(int device_type, int device_id, TVMStreamHandle stream) {
  API_BEGIN();
  TVMContext ctx;
  ctx.device_type = static_cast<DLDeviceType>(device_type);
  ctx.device_id = device_id;
  DeviceAPIManager::Get(ctx)->SetStream(ctx, stream);
  API_END();
}

// Block until the given stream's pending work completes.
int TVMSynchronize(int device_type, int device_id, TVMStreamHandle stream) {
  API_BEGIN();
  TVMContext ctx;
  ctx.device_type = static_cast<DLDeviceType>(device_type);
  ctx.device_id = device_id;
  DeviceAPIManager::Get(ctx)->StreamSync(ctx, stream);
  API_END();
}

// Make `dst` wait for the work currently queued on `src`.
int TVMStreamStreamSynchronize(int device_type,
                               int device_id,
                               TVMStreamHandle src,
                               TVMStreamHandle dst) {
  API_BEGIN();
  TVMContext ctx;
  ctx.device_type = static_cast<DLDeviceType>(device_type);
  ctx.device_id = device_id;
  DeviceAPIManager::Get(ctx)->SyncStreamFromTo(ctx, src, dst);
  API_END();
}

// Convert a callback argument in place into a return value; the round trip
// through TVMRetValue is verified to preserve the type code.
int TVMCbArgToReturn(TVMValue* value, int code) {
  API_BEGIN();
  tvm::runtime::TVMRetValue rv;
  rv = tvm::runtime::TVMArgValue(*value, code);
  int tcode;
  rv.MoveToCHost(value, &tcode);
  CHECK_EQ(tcode, code);
  API_END();
}
// set device api
TVM_REGISTER_GLOBAL(tvm::runtime::symbol::tvm_set_device)
.set_body([](TVMArgs args, TVMRetValue *ret) {
    TVMContext ctx;
    ctx.device_type = static_cast<DLDeviceType>(args[0].operator int());
    ctx.device_id = args[1];
    DeviceAPIManager::Get(ctx)->SetDevice(ctx);
  });

// Query a device attribute.  (The old "set device api" comment here was a
// copy-paste.)  A kExist query probes backend availability without
// triggering the "not enabled" CHECK failure.
TVM_REGISTER_GLOBAL("_GetDeviceAttr")
.set_body([](TVMArgs args, TVMRetValue *ret) {
    TVMContext ctx;
    ctx.device_type = static_cast<DLDeviceType>(args[0].operator int());
    ctx.device_id = args[1];
    DeviceAttrKind kind = static_cast<DeviceAttrKind>(args[2].operator int());
    if (kind == kExist) {
      DeviceAPI* api = DeviceAPIManager::Get(ctx.device_type, true);
      if (api != nullptr) {
        api->GetAttr(ctx, kind, ret);
      } else {
        *ret = 0;
      }
    } else {
      DeviceAPIManager::Get(ctx)->GetAttr(ctx, kind, ret);
    }
  });
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment