Commit 4aebfd7b authored by Minjie Wang's avatar Minjie Wang
Browse files

add merge graph; fix bug in from_networkx

parent 148cc048
......@@ -50,6 +50,9 @@ class Graph {
#else
Graph(Graph&& other) {
adjlist_ = other.adjlist_;
reverse_adjlist_ = other.reverse_adjlist_;
all_edges_src_ = other.all_edges_src_;
all_edges_dst_ = other.all_edges_dst_;
read_only_ = other.read_only_;
num_edges_ = other.num_edges_;
other.clear();
......@@ -90,6 +93,8 @@ class Graph {
void Clear() {
adjlist_.clear();
reverse_adjlist_.clear();
all_edges_src_.clear();
all_edges_dst_.clear();
read_only_ = false;
num_edges_ = 0;
}
......@@ -270,12 +275,22 @@ class Graph {
*/
Graph Reverse() const;
// TODO
std::vector<Graph> Split(std::vector<IdArray> vids_array) const;
/*!
* \brief Merge several graphs
* \brief Merge several graphs into one graph.
*
* The new graph will include all the nodes/edges in the given graphs.
* Nodes/Edges will be relabled by adding the cumsum of the previous graph sizes
* in the given sequence order. For example, giving input [g1, g2, g3], where
* they have 5, 6, 7 nodes respectively. Then node#2 of g2 will become node#7
* in the result graph. Edge ids are re-assigned similarly.
*
* \param graphs A list of input graphs to be merged.
* \return the merged graph
*/
static Graph Merge(std::vector<Graph> graphs);
static Graph Merge(std::vector<const Graph*> graphs);
private:
/*! \brief Internal edge list type */
......@@ -298,7 +313,7 @@ class Graph {
std::vector<dgl_id_t> all_edges_dst_;
/*! \brief read only flag */
bool read_only_{false};
bool read_only_ = false;
/*! \brief number of edges */
uint64_t num_edges_ = 0;
};
......
......@@ -11,7 +11,7 @@ from .backend import Tensor
from .frame import FrameRef, merge_frames
from .function.message import BundledMessageFunction
from .function.reducer import BundledReduceFunction
from .graph_index import GraphIndex
from .graph_index import GraphIndex, create_graph_index
from . import scheduler
from . import utils
......@@ -41,12 +41,12 @@ class DGLGraph(object):
**attr):
# TODO: keyword attr
# graph
self._graph = GraphIndex(graph_data)
self._graph = create_graph_index(graph_data)
# frame
self._node_frame = node_frame if node_frame is not None else FrameRef()
self._edge_frame = edge_frame if edge_frame is not None else FrameRef()
# other class members
self._msg_graph = GraphIndex()
self._msg_graph = create_graph_index()
self._msg_frame = FrameRef()
self._message_func = (None, None)
self._reduce_func = (None, None)
......
from __future__ import absolute_import
import ctypes
import numpy as np
import networkx as nx
from ._ffi.base import c_array
from ._ffi.function import _init_api
from . import backend as F
from . import utils
GraphIndexHandle = ctypes.c_void_p
class GraphIndex(object):
"""Graph index object.
Parameters
----------
graph_data : graph data, optional
Data to initialize graph. Same as networkx's semantics.
handle : GraphIndexHandle
Handler
"""
def __init__(self, graph_data=None):
self._handle = _CAPI_DGLGraphCreate()
if isinstance(graph_data, nx.DiGraph):
self.from_networkx(graph_data)
elif graph_data is not None:
self.from_networkx(nx.DiGraph(graph_data))
def __init__(self, handle):
self._handle = handle
self._cache = {}
def __del__(self):
......@@ -414,6 +414,8 @@ class GraphIndex(object):
The nx graph
"""
self.clear()
if not isinstance(nx_graph, nx.DiGraph):
nx_graph = nx.DiGraph(nx_graph)
num_nodes = nx_graph.number_of_nodes()
self.add_nodes(num_nodes)
has_edge_id = 'id' in next(iter(nx_graph.edges))
......@@ -436,4 +438,43 @@ class GraphIndex(object):
dst = utils.toindex(dst)
self.add_edges(src, dst)
@staticmethod
def merge(graphs):
"""Merge a list of graphs into one graph.
The new graph will include all the nodes/edges in the given graphs.
Nodes/Edges will be relabled by adding the cumsum of the previous graph sizes
in the given sequence order. For example, giving input [g1, g2, g3], where
they have 5, 6, 7 nodes respectively. Then node#2 of g2 will become node#7
in the result graph. Edge ids are re-assigned similarly.
Parameters
----------
graphs : iterable of GraphIndex
The input graphs
Returns
-------
GraphIndex
The merged graph
"""
inputs = c_array(GraphIndexHandle, [gr._handle for gr in graphs])
inputs = ctypes.cast(inputs, ctypes.c_void_p)
handle = _CAPI_DGLGraphMerge(inputs, len(graphs))
return GraphIndex(handle)
def create_graph_index(graph_data=None):
"""Create a graph index object.
Parameters
----------
graph_data : graph data, optional
Data to initialize graph. Same as networkx's semantics.
"""
handle = _CAPI_DGLGraphCreate()
gi = GraphIndex(handle)
if graph_data is not None:
gi.from_networkx(graph_data)
return gi
_init_api("dgl.graph_index")
......@@ -360,4 +360,17 @@ Graph Graph::Reverse() const {
return *this;
}
Graph Graph::Merge(std::vector<const Graph*> graphs) {
Graph rst;
uint64_t cumsum = 0;
for (const Graph* gr : graphs) {
rst.AddVertices(gr->NumVertices());
for (uint64_t i = 0; i < gr->NumEdges(); ++i) {
rst.AddEdge(gr->all_edges_src_[i] + cumsum, gr->all_edges_dst_[i] + cumsum);
}
cumsum += gr->NumVertices();
}
return rst;
}
} // namespace dgl
......@@ -242,4 +242,20 @@ TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphOutDegrees")
*rv = gptr->OutDegrees(vids);
});
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphMerge")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
void* list = args[0];
GraphHandle* inhandles = static_cast<GraphHandle*>(list);
int list_size = args[1];
std::vector<const Graph*> graphs;
for (int i = 0; i < list_size; ++i) {
const Graph* gr = static_cast<const Graph*>(inhandles[i]);
graphs.push_back(gr);
}
Graph* gptr = new Graph();
*gptr = Graph::Merge(std::move(graphs));
GraphHandle ghandle = gptr;
*rv = ghandle;
});
} // namespace dgl
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment