Unverified commit 9c135fd5, authored by VoVAllen, committed by GitHub

Merge pull request #4 from jermainewang/master

Sync with latest commit
parents 9d3f299d 00add9f2
@@ -4,17 +4,25 @@ from __future__ import absolute_import
 import operator
 import dgl.backend as F
 
-__all__ = ["MessageFunction", "src_mul_edge", "copy_src", "copy_edge"]
+__all__ = ["src_mul_edge", "copy_src", "copy_edge"]
 
 class MessageFunction(object):
+    """Base builtin message function class."""
+
     def __call__(self, src, edge):
+        """Regular computation of this builtin.
+
+        This will be used when optimization is not available.
+        """
         raise NotImplementedError
 
     def name(self):
+        """Return the name of this builtin function."""
         raise NotImplementedError
 
     def is_spmv_supported(self, g):
+        """Return whether the SPMV optimization is supported."""
         raise NotImplementedError
@@ -22,12 +30,6 @@ class BundledMessageFunction(MessageFunction):
     def __init__(self, fn_list):
         if not isinstance(fn_list, (list, tuple)):
             fn_list = [fn_list]
-        else:
-            # sanity check on out field
-            for fn in fn_list:
-                # cannot perform check for udf
-                if isinstance(fn, MessageFunction) and fn.out_field is None:
-                    raise RuntimeError("Not specifying out field for multiple message is ambiguous")
         self.fn_list = fn_list
 
     def is_spmv_supported(self, g):
@@ -43,11 +45,8 @@ class BundledMessageFunction(MessageFunction):
             if ret is None:
                 ret = msg
             else:
-                try:
-                    # ret and msg must be dict
-                    ret.update(msg)
-                except:
-                    raise RuntimeError("Must specify out field for multiple message")
+                # ret and msg must be dict
+                ret.update(msg)
         return ret
 
     def name(self):
@@ -55,25 +54,26 @@ class BundledMessageFunction(MessageFunction):
 def _is_spmv_supported_node_feat(g, field):
-    if field is None:
-        feat = g.get_n_repr()
-    else:
-        feat = g.get_n_repr()[field]
+    """Return whether the node feature shape supports SPMV optimization.
+
+    Only scalar and vector features are supported currently.
+    """
+    feat = g.get_n_repr()[field]
     shape = F.shape(feat)
     return len(shape) == 1 or len(shape) == 2
 
 def _is_spmv_supported_edge_feat(g, field):
-    # check shape, only scalar edge feature can be optimized at the moment
-    if field is None:
-        feat = g.get_e_repr()
-    else:
-        feat = g.get_e_repr()[field]
+    """Return whether the edge feature shape supports SPMV optimization.
+
+    Only scalar feature is supported currently.
+    """
+    feat = g.get_e_repr()[field]
     shape = F.shape(feat)
     return len(shape) == 1 or (len(shape) == 2 and shape[1] == 1)
 
 class SrcMulEdgeMessageFunction(MessageFunction):
-    def __init__(self, mul_op, src_field=None, edge_field=None, out_field=None):
+    def __init__(self, mul_op, src_field, edge_field, out_field):
         self.mul_op = mul_op
         self.src_field = src_field
         self.edge_field = edge_field
@@ -84,21 +84,14 @@ class SrcMulEdgeMessageFunction(MessageFunction):
             and _is_spmv_supported_edge_feat(g, self.edge_field)
 
     def __call__(self, src, edge):
-        if self.src_field is not None:
-            src = src[self.src_field]
-        if self.edge_field is not None:
-            edge = edge[self.edge_field]
-        ret = self.mul_op(src, edge)
-        if self.out_field is None:
-            return ret
-        else:
-            return {self.out_field : ret}
+        ret = self.mul_op(src[self.src_field], edge[self.edge_field])
+        return {self.out_field : ret}
 
     def name(self):
         return "src_mul_edge"
 
 class CopySrcMessageFunction(MessageFunction):
-    def __init__(self, src_field=None, out_field=None):
+    def __init__(self, src_field, out_field):
         self.src_field = src_field
         self.out_field = out_field
@@ -106,14 +99,7 @@ class CopySrcMessageFunction(MessageFunction):
         return _is_spmv_supported_node_feat(g, self.src_field)
 
     def __call__(self, src, edge):
-        if self.src_field is not None:
-            ret = src[self.src_field]
-        else:
-            ret = src
-        if self.out_field is None:
-            return ret
-        else:
-            return {self.out_field : ret}
+        return {self.out_field : src[self.src_field]}
 
     def name(self):
         return "copy_src"
@@ -142,14 +128,41 @@ class CopyEdgeMessageFunction(MessageFunction):
         return "copy_edge"
 
-def src_mul_edge(src=None, edge=None, out=None):
-    """TODO(minjie): docstring """
+def src_mul_edge(src, edge, out):
+    """Builtin message function that computes message by multiplying source
+    node features with edge features.
+
+    Parameters
+    ----------
+    src : str
+        The source feature name.
+    edge : str
+        The edge feature name.
+    out : str
+        The output message name.
+    """
     return SrcMulEdgeMessageFunction(operator.mul, src, edge, out)
 
-def copy_src(src=None, out=None):
-    """TODO(minjie): docstring """
+def copy_src(src, out):
+    """Builtin message function that computes message using source node feature.
+
+    Parameters
+    ----------
+    src : str
+        The source feature name.
+    out : str
+        The output message name.
+    """
    return CopySrcMessageFunction(src, out)
 
-def copy_edge(edge=None, out=None):
-    """TODO(minjie): docstring """
+def copy_edge(edge, out):
+    """Builtin message function that computes message using edge feature.
+
+    Parameters
+    ----------
+    edge : str
+        The edge feature name.
+    out : str
+        The output message name.
+    """
     return CopyEdgeMessageFunction(edge, out)
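With the field defaults gone, every builtin now reads a named input field and writes a named output field. A minimal sketch of the new calling convention, using numpy arrays in place of backend tensors (in practice these functions are handed to graph-level APIs rather than called by hand):

```python
import numpy as np
from dgl.function.message import copy_src, src_mul_edge  # module in this diff

msg_fn = copy_src(src='h', out='m')
src = {'h': np.array([[1., 2.], [3., 4.]])}   # features of the edge sources
print(msg_fn(src, None))                      # {'m': array([[1., 2.], [3., 4.]])}

mul_fn = src_mul_edge(src='h', edge='w', out='m')
edge = {'w': np.array([[2.], [0.5]])}         # one scalar weight per edge
print(mul_fn(src, edge))                      # {'m': array([[2., 4.], [1.5, 2.]])}
```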
@@ -3,27 +3,30 @@ from __future__ import absolute_import
 from .. import backend as F
 
-__all__ = ["ReduceFunction", "sum", "max"]
+__all__ = ["sum", "max"]
 
 class ReduceFunction(object):
+    """Base builtin reduce function class."""
+
     def __call__(self, node, msgs):
+        """Regular computation of this builtin.
+
+        This will be used when optimization is not available.
+        """
         raise NotImplementedError
 
     def name(self):
+        """Return the name of this builtin function."""
         raise NotImplementedError
 
     def is_spmv_supported(self):
+        """Return whether the SPMV optimization is supported."""
         raise NotImplementedError
 
 class BundledReduceFunction(ReduceFunction):
     def __init__(self, fn_list):
         if not isinstance(fn_list, (list, tuple)):
             fn_list = [fn_list]
-        else:
-            # sanity check on out field
-            for fn in fn_list:
-                if isinstance(fn, ReduceFunction) and fn.out_field is None:
-                    raise RuntimeError("Not specifying out field for multiple reduce is ambiguous")
         self.fn_list = fn_list
 
     def is_spmv_supported(self):
@@ -39,51 +42,50 @@ class BundledReduceFunction(ReduceFunction):
             if ret is None:
                 ret = rpr
             else:
-                try:
-                    # ret and rpr must be dict
-                    ret.update(rpr)
-                except:
-                    raise RuntimeError("Must specify out field for multiple reudce")
+                # ret and rpr must be dict
+                ret.update(rpr)
         return ret
 
     def name(self):
         return "bundled"
 
 class ReducerFunctionTemplate(ReduceFunction):
-    def __init__(self, name, batch_op, nonbatch_op, msg_field=None, out_field=None):
+    def __init__(self, name, op, msg_field, out_field):
         self.name = name
-        self.batch_op = batch_op
-        self.nonbatch_op = nonbatch_op
+        self.op = op
         self.msg_field = msg_field
         self.out_field = out_field
 
     def is_spmv_supported(self):
-        # TODO: support max
+        # NOTE: only sum is supported right now.
         return self.name == "sum"
 
     def __call__(self, node, msgs):
-        if isinstance(msgs, list):
-            if self.msg_field is None:
-                ret = self.nonbatch_op(msgs)
-            else:
-                ret = self.nonbatch_op([msg[self.msg_field] for msg in msgs])
-        else:
-            if self.msg_field is None:
-                ret = self.batch_op(msgs, 1)
-            else:
-                ret = self.batch_op(msgs[self.msg_field], 1)
-        if self.out_field is None:
-            return ret
-        else:
-            return {self.out_field : ret}
+        return {self.out_field : self.op(msgs[self.msg_field], 1)}
 
     def name(self):
         return self.name
 
-_python_sum = sum
-def sum(msgs=None, out=None):
-    return ReducerFunctionTemplate("sum", F.sum, _python_sum, msgs, out)
+def sum(msg, out):
+    """Builtin reduce function that aggregates messages by sum.
+
+    Parameters
+    ----------
+    msg : str
+        The message name.
+    out : str
+        The output node feature name.
+    """
+    return ReducerFunctionTemplate("sum", F.sum, msg, out)
 
-_python_max = max
-def max(msgs=None, out=None):
-    return ReducerFunctionTemplate("max", F.max, _python_max, msgs, out)
+def max(msg, out):
+    """Builtin reduce function that aggregates messages by max.
+
+    Parameters
+    ----------
+    msg : str
+        The message name.
+    out : str
+        The output node feature name.
+    """
    return ReducerFunctionTemplate("max", F.max, msg, out)
"""Package for graph generators"""
from __future__ import absolute_import
from .line import *
"""Line graph generator."""
from __future__ import absolute_import
import networkx as nx
import numpy as np
from .. import backend as F
from ..graph import DGLGraph
from ..frame import FrameRef
def line_graph(G, no_backtracking=False):
"""Create the line graph that shares the underlying features.
The node features of the result line graph will share the edge features
of the given graph.
Parameters
----------
G : DGLGraph
The input graph.
no_backtracking : bool
Whether the backtracking edges are included in the line graph.
If i~j and j~i are two edges in original graph G, then
(i,j)~(j,i) and (j,i)~(i,j) are the "backtracking" edges on
the line graph.
"""
L = nx.DiGraph()
for eid, from_node in enumerate(G.edge_list):
L.add_node(from_node)
for to_node in G.edges(from_node[1]):
if no_backtracking and to_node[1] == from_node[0]:
continue
L.add_edge(from_node, to_node)
relabel_map = {}
for i, e in enumerate(G.edge_list):
relabel_map[e] = i
nx.relabel.relabel_nodes(L, relabel_map, copy=False)
return DGLGraph(L, node_frame=G._edge_frame)
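A hypothetical usage sketch for the new generator; the import path `dgl.generators.line` and the `DGLGraph(nx.cycle_graph(3))` conversion are assumptions, not shown in this diff:

```python
import networkx as nx
from dgl import DGLGraph
from dgl.generators.line import line_graph  # assumed module path

g = DGLGraph(nx.cycle_graph(3))        # each undirected edge becomes two arcs
lg = line_graph(g, no_backtracking=True)
# lg has one node per edge of g; with no_backtracking=True the (i,j)~(j,i)
# arcs are dropped. lg's node frame aliases g's edge frame, so edge features
# written on g show up as node features on lg.
```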
@@ -3,7 +3,7 @@ from __future__ import absolute_import
 import ctypes
 import numpy as np
 import networkx as nx
-import scipy.sparse as sp
+import scipy
 
 from ._ffi.base import c_array
 from ._ffi.function import _init_api
@@ -600,30 +600,59 @@ class GraphIndex(object):
         return GraphIndex(handle)
 
 class SubgraphIndex(GraphIndex):
-    def __init__(self, handle, parent, induced_nodes, induced_edges):
-        super().__init__(handle)
+    """Graph index for subgraph.
+
+    Parameters
+    ----------
+    handle : GraphIndexHandle
+        The capi handle.
+    parent : GraphIndex
+        The parent graph index.
+    induced_nodes : utils.Index
+        The parent node ids in this subgraph.
+    induced_edges : utils.Index
+        The parent edge ids in this subgraph.
+    """
+    def __init__(self, handle, parent, induced_nodes, induced_edges):
+        super(SubgraphIndex, self).__init__(handle)
         self._parent = parent
         self._induced_nodes = induced_nodes
         self._induced_edges = induced_edges
 
     def add_nodes(self, num):
+        """Add nodes. Disabled because SubgraphIndex is read-only."""
         raise RuntimeError('Readonly graph. Mutation is not allowed.')
 
     def add_edge(self, u, v):
+        """Add an edge. Disabled because SubgraphIndex is read-only."""
         raise RuntimeError('Readonly graph. Mutation is not allowed.')
 
     def add_edges(self, u, v):
+        """Add edges. Disabled because SubgraphIndex is read-only."""
         raise RuntimeError('Readonly graph. Mutation is not allowed.')
 
-    @property
-    def induced_edges(self):
-        return self._induced_edges
-
     @property
     def induced_nodes(self):
+        """Return parent node ids.
+
+        Returns
+        -------
+        utils.Index
+            The parent node ids.
+        """
         return self._induced_nodes
+
+    @property
+    def induced_edges(self):
+        """Return parent edge ids.
+
+        Returns
+        -------
+        utils.Index
+            The parent edge ids.
+        """
+        return self._induced_edges
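The induced-id contract, illustrated with plain Python lists (the ids below are made up): subgraph node i corresponds to parent node `induced_nodes[i]`, so parent features can be gathered by indexing.

```python
induced_nodes = [2, 5, 7]                      # hypothetical parent node ids
parent_feat = list('abcdefgh')                 # one feature per parent node
sub_feat = [parent_feat[i] for i in induced_nodes]
print(sub_feat)                                # ['c', 'f', 'h']
```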
 def disjoint_union(graphs):
     """Return a disjoint union of the input graphs.
@@ -697,8 +726,25 @@ def create_graph_index(graph_data=None, multigraph=False):
     handle = _CAPI_DGLGraphCreate(multigraph)
     gi = GraphIndex(handle)
 
-    if graph_data is not None:
-        gi.from_networkx(graph_data)
+    if graph_data is None:
+        return gi
+
+    # scipy format
+    if isinstance(graph_data, scipy.sparse.spmatrix):
+        try:
+            gi.from_scipy_sparse_matrix(graph_data)
+            return gi
+        except:
+            raise Exception('Graph data is not a valid scipy sparse matrix.')
+
+    # networkx - any format
+    try:
+        gi.from_networkx(graph_data)
+    except:
+        raise Exception('Error while creating graph from input of type "%s".'
+                        % type(graph_data))
+
     return gi
 
 _init_api("dgl.graph_index")
@@ -3,7 +3,7 @@ from __future__ import absolute_import
 import numpy as np
 
-from .base import ALL, __MSG__, __REPR__
+from .base import ALL, DGLError
 from . import backend as F
 from .function import message as fmsg
 from .function import reducer as fred
@@ -111,7 +111,15 @@ def light_degree_bucketing_for_graph(graph):
 class Executor(object):
+    """Base class for executing graph computation."""
+
     def run(self):
+        """Run this executor.
+
+        This should return the new node features.
+
+        TODO(minjie): extend this to support computation on edges.
+        """
         raise NotImplementedError
 
 class SPMVOperator(Executor):
@@ -126,10 +134,7 @@ class SPMVOperator(Executor):
     def run(self):
         # get src col
-        if self.src_field is None:
-            srccol = self.node_repr
-        else:
-            srccol = self.node_repr[self.src_field]
+        srccol = self.node_repr[self.src_field]
         ctx = F.get_context(srccol)
         # build adjmat
@@ -142,10 +147,7 @@ class SPMVOperator(Executor):
             dstcol = F.squeeze(dstcol)
         else:
             dstcol = F.spmm(adjmat, srccol)
-        if self.dst_field is None:
-            return dstcol
-        else:
-            return {self.dst_field : dstcol}
+        return {self.dst_field : dstcol}
 
 # FIXME: refactorize in scheduler/executor redesign
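The reason `copy_src` + `sum` admits this SPMV path is that gathering source features along edges and summing per destination is exactly an adjacency-matrix product. A dense numpy sketch of what `F.spmm(adjmat, srccol)` computes:

```python
import numpy as np

adj = np.array([[0., 1., 1.],   # adj[v, u] = 1 for each edge u -> v
                [1., 0., 0.],
                [0., 1., 0.]])
h = np.array([[1., 0.],
              [2., 0.],
              [3., 0.]])        # node features (the src_field column)
out = adj @ h                   # one matmul == gather + per-node sum
print(out)                      # [[5. 0.] [1. 0.] [2. 0.]]
```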
@@ -180,20 +182,14 @@ class DegreeBucketingExecutor(Executor):
             msg_shape = F.shape(msg)
             new_shape = (len(vv), deg) + msg_shape[1:]
             return F.reshape(msg, new_shape)
-        if len(in_msgs) == 1 and __MSG__ in in_msgs:
-            reshaped_in_msgs = _reshape_fn(in_msgs[__MSG__])
-        else:
-            reshaped_in_msgs = utils.LazyDict(
-                lambda key: _reshape_fn(in_msgs[key]), self.msg_frame.schemes)
+        reshaped_in_msgs = utils.LazyDict(
+            lambda key: _reshape_fn(in_msgs[key]), self.msg_frame.schemes)
         new_reprs.append(self.rfunc(dst_reprs, reshaped_in_msgs))
 
     # Pack all reducer results together
-    if utils.is_dict_like(new_reprs[0]):
-        keys = new_reprs[0].keys()
-        new_reprs = {key : F.pack([repr[key] for repr in new_reprs])
-                     for key in keys}
-    else:
-        new_reprs = {__REPR__ : F.pack(new_reprs)}
+    keys = new_reprs[0].keys()
+    new_reprs = {key : F.pack([repr[key] for repr in new_reprs])
+                 for key in keys}
     return new_reprs
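A numpy sketch of the degree-bucketing reshape this hunk keeps: messages for all nodes of one in-degree are stored contiguously, so a single reshape yields the (num_nodes, deg, feat) batch the reducer expects.

```python
import numpy as np

deg, feat = 3, 2
# Six messages for two nodes of in-degree 3, stored flat in bucket order.
msg = np.arange(2 * deg * feat, dtype=float).reshape(-1, feat)
new_shape = (len(msg) // deg, deg) + msg.shape[1:]  # mirrors (len(vv), deg) + ...
print(msg.reshape(new_shape).shape)                 # (2, 3, 2): reduce over axis 1
```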
@@ -249,12 +245,6 @@ class UpdateAllExecutor(BasicExecutor):
         self._graph_shape = None
         self._recv_nodes = None
 
-    @property
-    def graph_idx(self):
-        if self._graph_idx is None:
-            self._graph_idx = self.g._graph.adjacency_matrix()
-        return self._graph_idx
-
     @property
     def graph_shape(self):
         if self._graph_shape is None:
@@ -280,16 +270,13 @@ class UpdateAllExecutor(BasicExecutor):
     def _adj_build_fn(self, edge_field, ctx, use_edge_feat):
         if use_edge_feat:
-            if edge_field is None:
-                dat = self.edge_repr
-            else:
-                dat = self.edge_repr[edge_field]
+            dat = self.edge_repr[edge_field]
             dat = F.squeeze(dat)
             # TODO(minjie): should not directly use _indices
-            idx = self.graph_idx.get(ctx)._indices()
+            idx = self.g.adjacency_matrix(ctx)._indices()
             adjmat = F.sparse_tensor(idx, dat, self.graph_shape)
         else:
-            adjmat = self.graph_idx.get(ctx)
+            adjmat = self.g.adjacency_matrix(ctx)
         return adjmat
@@ -351,10 +338,7 @@ class SendRecvExecutor(BasicExecutor):
     def _adj_build_fn(self, edge_field, ctx, use_edge_feat):
         if use_edge_feat:
-            if edge_field is None:
-                dat = self.edge_repr
-            else:
-                dat = self.edge_repr[edge_field]
+            dat = self.edge_repr[edge_field]
             dat = F.squeeze(dat)
         else:
             dat = F.ones((len(self.u), ))
@@ -386,9 +370,8 @@ class BundledExecutor(BasicExecutor):
         func_pairs = []
         for rfn in rfunc.fn_list:
             mfn = out2mfunc.get(rfn.msg_field, None)
-            # field check
-            assert mfn is not None, \
-                "cannot find message func for reduce func in-field {}".format(rfn.msg_field)
+            if mfn is None:
+                raise DGLError('Cannot find message field "%s".' % rfn.msg_field)
             func_pairs.append((mfn, rfn))
         return func_pairs
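The pairing rule enforced here, reduced to plain dicts: each reduce function's `msg_field` must match some message function's `out_field`, otherwise a `DGLError` is raised. Names below are illustrative only.

```python
mfuncs = [('copy_src', 'm1'), ('src_mul_edge', 'm2')]  # (name, out_field)
rfuncs = [('sum', 'm1'), ('max', 'm2')]                # (name, msg_field)
out2mfunc = {out: name for name, out in mfuncs}
for rname, msg_field in rfuncs:
    if msg_field not in out2mfunc:                     # stand-in for DGLError
        raise KeyError('Cannot find message field "%s".' % msg_field)
print([(out2mfunc[mf], rn) for rn, mf in rfuncs])
# [('copy_src', 'sum'), ('src_mul_edge', 'max')]
```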
@@ -409,7 +392,6 @@ class BundledUpdateAllExecutor(BundledExecutor, UpdateAllExecutor):
         self._init_state()
         BundledExecutor.__init__(self, graph, mfunc, rfunc)
 
 class BundledSendRecvExecutor(BundledExecutor, SendRecvExecutor):
     def __init__(self, graph, src, dst, mfunc, rfunc):
         self._init_state(src, dst)
......
+/*!
+ * Copyright (c) 2018 by Contributors
+ * \file c_runtime_api.cc
+ * \brief DGL C API common implementations
+ */
 #include "c_api_common.h"
 
 using tvm::runtime::TVMArgs;
@@ -29,5 +34,5 @@ PackedFunc ConvertNDArrayVectorToPackedFunc(const std::vector<NDArray>& vec) {
   return PackedFunc(body);
 }
 }  // namespace dgl
-// DGL C API common util functions
+/*!
+ * Copyright (c) 2018 by Contributors
+ * \file c_api_common.h
+ * \brief DGL C API common util functions
+ */
 #ifndef DGL_C_API_COMMON_H_
 #define DGL_C_API_COMMON_H_
@@ -12,12 +16,20 @@ namespace dgl {
 // Graph handler type
 typedef void* GraphHandle;
 
-// Convert the given DLTensor to a temporary DLManagedTensor that does not own memory.
-DLManagedTensor* CreateTmpDLManagedTensor(const tvm::runtime::TVMArgValue& arg);
+/*!
+ * \brief Convert the given DLTensor to DLManagedTensor.
+ *
+ * Return a temporary DLManagedTensor that does not own memory.
+ */
+DLManagedTensor* CreateTmpDLManagedTensor(
+    const tvm::runtime::TVMArgValue& arg);
 
-// Convert a vector of NDArray to PackedFunc
-tvm::runtime::PackedFunc ConvertNDArrayVectorToPackedFunc(const std::vector<tvm::runtime::NDArray>& vec);
+/*!
+ * \brief Convert a vector of NDArray to PackedFunc.
+ */
+tvm::runtime::PackedFunc ConvertNDArrayVectorToPackedFunc(
+    const std::vector<tvm::runtime::NDArray>& vec);
 
 }  // namespace dgl
 #endif  // DGL_C_API_COMMON_H_
-// Graph class implementation
+/*!
+ * Copyright (c) 2018 by Contributors
+ * \file graph/graph.cc
+ * \brief DGL graph index implementation
+ */
+#include <dgl/graph.h>
 #include <algorithm>
 #include <unordered_map>
 #include <set>
 #include <functional>
-#include <dgl/graph.h>
 
 namespace dgl {
 namespace {
@@ -193,9 +197,9 @@ Graph::EdgeArray Graph::EdgeIds(IdArray src_ids, IdArray dst_ids) const {
       const auto& succ = adjlist_[src_id].succ;
       for (size_t k = 0; k < succ.size(); ++k) {
         if (succ[k] == dst_id) {
           src.push_back(src_id);
           dst.push_back(dst_id);
           eid.push_back(adjlist_[src_id].edge_id[k]);
         }
       }
     }
@@ -351,7 +355,7 @@ Graph::EdgeArray Graph::Edges(bool sorted) const {
       return std::get<0>(t1) < std::get<0>(t2)
          || (std::get<0>(t1) == std::get<0>(t2) && std::get<1>(t1) < std::get<1>(t2));
     });
     // make return arrays
     int64_t* src_ptr = static_cast<int64_t*>(src->data);
     int64_t* dst_ptr = static_cast<int64_t*>(dst->data);
@@ -461,7 +465,8 @@ Subgraph Graph::EdgeSubgraph(IdArray eids) const {
     rst.graph.AddEdge(oldv2newv[src_id], oldv2newv[dst_id]);
   }
-  rst.induced_vertices = IdArray::Empty({static_cast<int64_t>(nodes.size())}, eids->dtype, eids->ctx);
+  rst.induced_vertices = IdArray::Empty(
+      {static_cast<int64_t>(nodes.size())}, eids->dtype, eids->ctx);
   std::copy(nodes.begin(), nodes.end(), static_cast<int64_t*>(rst.induced_vertices->data));
   return rst;
......
+/*!
+ * Copyright (c) 2018 by Contributors
+ * \file graph/graph.cc
+ * \brief DGL graph index APIs
+ */
 #include <dgl/graph.h>
 #include <dgl/graph_op.h>
 #include "../c_api_common.h"
......
-// Graph operation implementation
+/*!
+ * Copyright (c) 2018 by Contributors
+ * \file graph/graph.cc
+ * \brief Graph operation implementation
+ */
 #include <dgl/graph_op.h>
 #include <algorithm>
 
 namespace dgl {
 
-Graph GraphOp::LineGraph(const Graph* g, bool backtracking){
+Graph GraphOp::LineGraph(const Graph* g, bool backtracking) {
   typedef std::pair<dgl_id_t, dgl_id_t> entry;
   typedef std::map<dgl_id_t, std::vector<entry>> csm;  // Compressed Sparse Matrix
   csm adj;
   std::vector<entry> vec;
@@ -67,7 +71,7 @@ std::vector<Graph> GraphOp::DisjointPartitionByNum(const Graph* graph, int64_t n
   std::fill(sizes_data, sizes_data + num, graph->NumVertices() / num);
   return DisjointPartitionBySizes(graph, sizes);
 }
 
 std::vector<Graph> GraphOp::DisjointPartitionBySizes(const Graph* graph, IdArray sizes) {
   const int64_t len = sizes->shape[0];
   const int64_t* sizes_data = static_cast<int64_t*>(sizes->data);
@@ -117,32 +121,6 @@ std::vector<Graph> GraphOp::DisjointPartitionBySizes(const Graph* graph, IdArray
     node_offset += sizes_data[i];
     edge_offset += num_edges;
   }
-  /*for (int64_t i = 0; i < len; ++i) {
-    rst[i].AddVertices(sizes_data[i]);
-  }
-  for (dgl_id_t eid = 0; eid < graph->num_edges_; ++eid) {
-    const dgl_id_t src = graph->all_edges_src_[eid];
-    const dgl_id_t dst = graph->all_edges_dst_[eid];
-    size_t src_select = 0, dst_select = 0;
-    for (size_t i = 1; i < cumsum.size(); ++i) {  // TODO: replace with binary search
-      if (cumsum[i] > src) {
-        src_select = i;
-        break;
-      }
-    }
-    for (size_t i = 1; i < cumsum.size(); ++i) {  // TODO: replace with binary search
-      if (cumsum[i] > dst) {
-        dst_select = i;
-        break;
-      }
-    }
-    if (src_select != dst_select) {
-      // the edge is ignored if across two partitions
-      continue;
-    }
-    const int64_t offset = cumsum[src_select - 1];
-    rst[src_select - 1].AddEdge(src - offset, dst - offset);
-  }*/
   return rst;
 }
......
# C API and runtime
Borrowed and adapted from TVM project.
@@ -3,8 +3,8 @@
  * \file file_util.h
  * \brief Minimum file manipulation util for runtime.
  */
-#ifndef TVM_RUNTIME_FILE_UTIL_H_
-#define TVM_RUNTIME_FILE_UTIL_H_
+#ifndef DGL_RUNTIME_FILE_UTIL_H_
+#define DGL_RUNTIME_FILE_UTIL_H_
 
 #include <string>
 #include "meta_data.h"
@@ -73,4 +73,4 @@ void LoadMetaDataFromFile(
     std::unordered_map<std::string, FunctionInfo>* fmap);
 }  // namespace runtime
 }  // namespace tvm
-#endif  // TVM_RUNTIME_FILE_UTIL_H_
+#endif  // DGL_RUNTIME_FILE_UTIL_H_
@@ -3,8 +3,8 @@
  * \file meta_data.h
  * \brief Meta data related utilities
  */
-#ifndef TVM_RUNTIME_META_DATA_H_
-#define TVM_RUNTIME_META_DATA_H_
+#ifndef DGL_RUNTIME_META_DATA_H_
+#define DGL_RUNTIME_META_DATA_H_
 
 #include <dmlc/json.h>
 #include <dmlc/io.h>
@@ -33,4 +33,4 @@ struct FunctionInfo {
 namespace dmlc {
 DMLC_DECLARE_TRAITS(has_saveload, ::tvm::runtime::FunctionInfo, true);
 }  // namespace dmlc
-#endif  // TVM_RUNTIME_META_DATA_H_
+#endif  // DGL_RUNTIME_META_DATA_H_
@@ -3,8 +3,8 @@
  * \file module_util.h
  * \brief Helper utilities for module building
  */
-#ifndef TVM_RUNTIME_MODULE_UTIL_H_
-#define TVM_RUNTIME_MODULE_UTIL_H_
+#ifndef DGL_RUNTIME_MODULE_UTIL_H_
+#define DGL_RUNTIME_MODULE_UTIL_H_
 
 #include <dgl/runtime/module.h>
 #include <dgl/runtime/c_runtime_api.h>
@@ -58,4 +58,4 @@ void InitContextFunctions(FLookup flookup) {
 }
 }  // namespace runtime
 }  // namespace tvm
-#endif  // TVM_RUNTIME_MODULE_UTIL_H_
+#endif  // DGL_RUNTIME_MODULE_UTIL_H_
@@ -10,8 +10,8 @@
  *  union_32bit args[N], int num_args);
  * - Pack buffer by address, pack rest parameter into 32bit union buffer.
  */
-#ifndef TVM_RUNTIME_PACK_ARGS_H_
-#define TVM_RUNTIME_PACK_ARGS_H_
+#ifndef DGL_RUNTIME_PACK_ARGS_H_
+#define DGL_RUNTIME_PACK_ARGS_H_
 
 #include <dgl/runtime/c_runtime_api.h>
 #include <vector>
@@ -307,4 +307,4 @@ inline PackedFunc PackFuncPackedArg(F f, const std::vector<TVMType>& arg_types)
 }
 }  // namespace runtime
 }  // namespace tvm
-#endif  // TVM_RUNTIME_PACK_ARGS_H_
+#endif  // DGL_RUNTIME_PACK_ARGS_H_
@@ -3,8 +3,8 @@
  * \file runtime_base.h
  * \brief Base of all C APIs
  */
-#ifndef TVM_RUNTIME_RUNTIME_BASE_H_
-#define TVM_RUNTIME_RUNTIME_BASE_H_
+#ifndef DGL_RUNTIME_RUNTIME_BASE_H_
+#define DGL_RUNTIME_RUNTIME_BASE_H_
 
 #include <dgl/runtime/c_runtime_api.h>
 #include <stdexcept>
@@ -31,4 +31,4 @@ inline int TVMAPIHandleException(const std::runtime_error &e) {
   return -1;
 }
-#endif  // TVM_RUNTIME_RUNTIME_BASE_H_
+#endif  // DGL_RUNTIME_RUNTIME_BASE_H_
@@ -3,8 +3,8 @@
  * \file thread_storage_scope.h
  * \brief Extract thread axis configuration from TVMArgs.
  */
-#ifndef TVM_RUNTIME_THREAD_STORAGE_SCOPE_H_
-#define TVM_RUNTIME_THREAD_STORAGE_SCOPE_H_
+#ifndef DGL_RUNTIME_THREAD_STORAGE_SCOPE_H_
+#define DGL_RUNTIME_THREAD_STORAGE_SCOPE_H_
 
 #include <dgl/runtime/packed_func.h>
 #include <string>
@@ -204,4 +204,4 @@ struct hash<::tvm::runtime::StorageScope> {
   }
 };
 }  // namespace std
-#endif  // TVM_RUNTIME_THREAD_STORAGE_SCOPE_H_
+#endif  // DGL_RUNTIME_THREAD_STORAGE_SCOPE_H_