Commit 2be55fb5 authored by Minjie Wang

graph batch and unbatch

parent 7d04c8c9
@@ -31,14 +31,28 @@ class GraphOp {
   /*!
    * \brief Partition the graph into several subgraphs.
    *
-   * The graph will be partitioned by the node ids. Edges between partitions
-   * will be ignored. This requires the given number of partitions to evenly
-   * divides the number of nodes in the graph.
+   * This is a reverse operation of DisjointUnion. The graph will be partitioned
+   * into num graphs. This requires the given number of partitions to evenly
+   * divide the number of nodes in the graph.
    *
+   * \param graph The graph to be partitioned.
    * \param num The number of partitions.
    * \return a list of partitioned graphs
    */
-  static std::vector<Graph> PartitionByNum(const Graph* graph, size_t num);
+  static std::vector<Graph> DisjointPartitionByNum(const Graph* graph, int64_t num);
+
+  /*!
+   * \brief Partition the graph into several subgraphs.
+   *
+   * This is a reverse operation of DisjointUnion. The graph will be partitioned
+   * based on the given sizes. This requires that the sum of the given sizes
+   * equals the number of nodes in the graph.
+   *
+   * \param graph The graph to be partitioned.
+   * \param sizes The sizes of the partitions.
+   * \return a list of partitioned graphs
+   */
+  static std::vector<Graph> DisjointPartitionBySizes(const Graph* graph, IdArray sizes);
 };
 }  // namespace dgl
...
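The contract of the two new partition routines mirrors DisjointUnion: DisjointPartitionByNum requires the partition count to divide the node count evenly, while DisjointPartitionBySizes requires the sizes to sum to the node count. A minimal illustration of valid arguments, assuming a 10-node graph (the numbers are made up for the example):

num = 5            # valid for DisjointPartitionByNum: 10 % 5 == 0
sizes = [3, 3, 4]  # valid for DisjointPartitionBySizes: sum(sizes) == 10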
@@ -8,6 +8,7 @@ from .frame import FrameRef
 from .graph import DGLGraph
 from . import graph_index as gi
 from . import backend as F
+from . import utils

 class BatchedDGLGraph(DGLGraph):
     """The batched DGL graph.
@@ -24,7 +25,6 @@ class BatchedDGLGraph(DGLGraph):
         The edge attributes to also be batched.
     """
     def __init__(self, graph_list, node_attrs, edge_attrs):
-        # TODO(minjie): handle the input is again a batched graph.
         # create batched graph index
         batched_index = gi.disjoint_union([g._graph for g in graph_list])
         # create batched node and edge frames
@@ -43,9 +43,19 @@ class BatchedDGLGraph(DGLGraph):
                          edge_frame=batched_edge_frame)
         # extra members
-        self._batch_size = len(graph_list)
-        self._batch_num_nodes = [gr.number_of_nodes() for gr in graph_list]
-        self._batch_num_edges = [gr.number_of_edges() for gr in graph_list]
+        self._batch_size = 0
+        self._batch_num_nodes = []
+        self._batch_num_edges = []
+        for gr in graph_list:
+            if isinstance(gr, BatchedDGLGraph):
+                # handle the case where the input is again a batched graph.
+                self._batch_size += gr._batch_size
+                self._batch_num_nodes += gr._batch_num_nodes
+                self._batch_num_edges += gr._batch_num_edges
+            else:
+                self._batch_size += 1
+                self._batch_num_nodes.append(gr.number_of_nodes())
+                self._batch_num_edges.append(gr.number_of_edges())

     @property
     def batch_size(self):
@@ -78,10 +88,12 @@ class BatchedDGLGraph(DGLGraph):
     # new APIs
     def __getitem__(self, idx):
         """Slice the batch and return the batch of graphs specified by the idx."""
+        # TODO
         pass

     def __setitem__(self, idx, val):
         """Set the value of the slice. The graph size cannot be changed."""
+        # TODO
         pass
 '''
@@ -114,37 +126,36 @@ def split(graph_batch, num_or_size_splits):
     # TODO(minjie): could follow torch.split syntax
     pass

-def unbatch(graph_batch):
+def unbatch(graph):
     """Unbatch the graph and return a list of subgraphs.

     Parameters
     ----------
-    graph_batch : DGLGraph
+    graph : BatchedDGLGraph
         The batched graph.
     """
-    assert False, "disabled for now"
-    graph_list = graph_batch.graph_list
-    num_graphs = len(graph_list)
-    # split and set node attrs
-    attrs = [{} for _ in range(num_graphs)]  # node attr dict for each graph
-    for key in graph_batch.node_attr_schemes():
-        vals = F.unpack(graph_batch.pop_n_repr(key), graph_batch.num_nodes)
-        for attr, val in zip(attrs, vals):
-            attr[key] = val
-    for attr, g in zip(attrs, graph_list):
-        g.set_n_repr(attr)
-    # split and set edge attrs
-    attrs = [{} for _ in range(num_graphs)]  # edge attr dict for each graph
-    for key in graph_batch.edge_attr_schemes():
-        vals = F.unpack(graph_batch.pop_e_repr(key), graph_batch.num_edges)
-        for attr, val in zip(attrs, vals):
-            attr[key] = val
-    for attr, g in zip(attrs, graph_list):
-        g.set_e_repr(attr)
-    return graph_list
+    assert isinstance(graph, BatchedDGLGraph)
+    bsize = graph.batch_size
+    bn = graph.batch_num_nodes
+    be = graph.batch_num_edges
+    pttns = gi.disjoint_partition(graph._graph, utils.toindex(bn))
+    # split the frames
+    node_frames = [FrameRef() for i in range(bsize)]
+    edge_frames = [FrameRef() for i in range(bsize)]
+    for attr, col in graph._node_frame.items():
+        # TODO: device context
+        col_splits = F.unpack(col, bn)
+        for i in range(bsize):
+            node_frames[i][attr] = col_splits[i]
+    for attr, col in graph._edge_frame.items():
+        # TODO: device context
+        col_splits = F.unpack(col, be)
+        for i in range(bsize):
+            edge_frames[i][attr] = col_splits[i]
+    return [DGLGraph(graph_data=pttns[i],
+                     node_frame=node_frames[i],
+                     edge_frame=edge_frames[i]) for i in range(bsize)]

 def batch(graph_list, node_attrs=ALL, edge_attrs=ALL):
     """Batch a list of DGLGraphs into one single graph.
...
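At the Python level, the round trip is dgl.batch followed by dgl.unbatch. A small usage sketch, modeled on the tests added in this commit (the graph shapes and feature sizes here are illustrative only):

import torch as th
import dgl

g1 = dgl.DGLGraph()
g1.add_nodes(3)
g1.add_edge(1, 0)
g1.add_edge(2, 0)
g1.set_n_repr(th.randn(3, 4))
g1.set_e_repr(th.randn(2, 4))

g2 = dgl.DGLGraph()
g2.add_nodes(2)
g2.add_edge(0, 1)
g2.set_n_repr(th.randn(2, 4))
g2.set_e_repr(th.randn(1, 4))

bg = dgl.batch([g1, g2])          # BatchedDGLGraph with 5 nodes and 3 edges
assert bg.batch_size == 2
assert bg.batch_num_nodes == [3, 2]
assert bg.batch_num_edges == [2, 1]

h1, h2 = dgl.unbatch(bg)          # splits the graph index and the frames back apart
assert th.allclose(g1.get_n_repr(), h1.get_n_repr())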
@@ -483,6 +483,40 @@ def disjoint_union(graphs):
     handle = _CAPI_DGLDisjointUnion(inputs, len(graphs))
     return GraphIndex(handle)

+def disjoint_partition(graph, num_or_size_splits):
+    """Partition the graph disjointly.
+
+    This is a reverse operation of DisjointUnion. If an integer is given, the
+    graph will be partitioned into that many graphs, which requires the number
+    of partitions to evenly divide the number of nodes. If a size list is
+    given, the sum of the given sizes must equal the number of nodes in the
+    graph.
+
+    Parameters
+    ----------
+    graph : GraphIndex
+        The graph to be partitioned
+    num_or_size_splits : int or utils.Index
+        The number of partitions, or the size of each partition
+
+    Returns
+    -------
+    list of GraphIndex
+        The partitioned graphs
+    """
+    if isinstance(num_or_size_splits, utils.Index):
+        rst = _CAPI_DGLDisjointPartitionBySizes(
+                graph._handle,
+                num_or_size_splits.todgltensor())
+    else:
+        rst = _CAPI_DGLDisjointPartitionByNum(
+                graph._handle,
+                int(num_or_size_splits))
+    graphs = []
+    for val in rst.asnumpy():
+        handle = ctypes.cast(int(val), ctypes.c_void_p)
+        graphs.append(GraphIndex(handle))
+    return graphs

 def create_graph_index(graph_data=None):
     """Create a graph index object.
...
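unbatch drives the partition through this graph_index layer. A sketch of calling disjoint_partition directly, assuming idx is a GraphIndex previously produced by gi.disjoint_union of three 4-node graphs (the variable name and the sizes are hypothetical):

from dgl import graph_index as gi
from dgl import utils

parts = gi.disjoint_partition(idx, 3)                         # equal split: 12 % 3 == 0
parts = gi.disjoint_partition(idx, utils.toindex([4, 4, 4]))  # or explicit per-part sizes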
@@ -7,6 +7,7 @@ using tvm::runtime::TVMArgs;
 using tvm::runtime::TVMArgValue;
 using tvm::runtime::TVMRetValue;
 using tvm::runtime::PackedFunc;
+using tvm::runtime::NDArray;

 namespace dgl {
@@ -289,4 +290,39 @@ TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLDisjointUnion")
     *rv = ghandle;
   });

+TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLDisjointPartitionByNum")
+.set_body([] (TVMArgs args, TVMRetValue* rv) {
+    GraphHandle ghandle = args[0];
+    const Graph* gptr = static_cast<Graph*>(ghandle);
+    int64_t num = args[1];
+    std::vector<Graph>&& rst = GraphOp::DisjointPartitionByNum(gptr, num);
+    // return the pointer array as an integer array
+    const int64_t len = rst.size();
+    NDArray ptr_array = NDArray::Empty({len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
+    int64_t* ptr_array_data = static_cast<int64_t*>(ptr_array->data);
+    for (size_t i = 0; i < rst.size(); ++i) {
+      Graph* ptr = new Graph();
+      *ptr = std::move(rst[i]);
+      ptr_array_data[i] = reinterpret_cast<std::intptr_t>(ptr);
+    }
+    *rv = ptr_array;
+  });
+
+TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLDisjointPartitionBySizes")
+.set_body([] (TVMArgs args, TVMRetValue* rv) {
+    GraphHandle ghandle = args[0];
+    const Graph* gptr = static_cast<Graph*>(ghandle);
+    const IdArray sizes = IdArray::FromDLPack(CreateTmpDLManagedTensor(args[1]));
+    std::vector<Graph>&& rst = GraphOp::DisjointPartitionBySizes(gptr, sizes);
+    // return the pointer array as an integer array
+    const int64_t len = rst.size();
+    NDArray ptr_array = NDArray::Empty({len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
+    int64_t* ptr_array_data = static_cast<int64_t*>(ptr_array->data);
+    for (size_t i = 0; i < rst.size(); ++i) {
+      Graph* ptr = new Graph();
+      *ptr = std::move(rst[i]);
+      ptr_array_data[i] = reinterpret_cast<std::intptr_t>(ptr);
+    }
+    *rv = ptr_array;
+  });
+
 }  // namespace dgl
 // Graph operation implementation
 #include <dgl/graph_op.h>
+#include <algorithm>

 namespace dgl {
@@ -16,4 +17,91 @@ Graph GraphOp::DisjointUnion(std::vector<const Graph*> graphs) {
   return rst;
 }

+std::vector<Graph> GraphOp::DisjointPartitionByNum(const Graph* graph, int64_t num) {
+  CHECK(num != 0 && graph->NumVertices() % num == 0)
+    << "Number of partitions must evenly divide the number of nodes.";
+  IdArray sizes = IdArray::Empty({num}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
+  int64_t* sizes_data = static_cast<int64_t*>(sizes->data);
+  std::fill(sizes_data, sizes_data + num, graph->NumVertices() / num);
+  return DisjointPartitionBySizes(graph, sizes);
+}
+
+std::vector<Graph> GraphOp::DisjointPartitionBySizes(const Graph* graph, IdArray sizes) {
+  const int64_t len = sizes->shape[0];
+  const int64_t* sizes_data = static_cast<int64_t*>(sizes->data);
+  std::vector<int64_t> cumsum;
+  cumsum.push_back(0);
+  for (int64_t i = 0; i < len; ++i) {
+    cumsum.push_back(cumsum[i] + sizes_data[i]);
+  }
+  CHECK_EQ(cumsum[len], graph->NumVertices())
+    << "Sum of the given sizes must equal the number of nodes.";
+  dgl_id_t node_offset = 0, edge_offset = 0;
+  std::vector<Graph> rst(len);
+  for (int64_t i = 0; i < len; ++i) {
+    // copy adj
+    rst[i].adjlist_.insert(rst[i].adjlist_.end(),
+        graph->adjlist_.begin() + node_offset,
+        graph->adjlist_.begin() + node_offset + sizes_data[i]);
+    rst[i].reverse_adjlist_.insert(rst[i].reverse_adjlist_.end(),
+        graph->reverse_adjlist_.begin() + node_offset,
+        graph->reverse_adjlist_.begin() + node_offset + sizes_data[i]);
+    // relabel adjs
+    size_t num_edges = 0;
+    for (auto& elist : rst[i].adjlist_) {
+      for (size_t j = 0; j < elist.succ.size(); ++j) {
+        elist.succ[j] -= node_offset;
+        elist.edge_id[j] -= edge_offset;
+      }
+      num_edges += elist.succ.size();
+    }
+    for (auto& elist : rst[i].reverse_adjlist_) {
+      for (size_t j = 0; j < elist.succ.size(); ++j) {
+        elist.succ[j] -= node_offset;
+        elist.edge_id[j] -= edge_offset;
+      }
+    }
+    // copy edges
+    rst[i].all_edges_src_.reserve(num_edges);
+    rst[i].all_edges_dst_.reserve(num_edges);
+    rst[i].num_edges_ = num_edges;
+    for (size_t j = edge_offset; j < edge_offset + num_edges; ++j) {
+      rst[i].all_edges_src_.push_back(graph->all_edges_src_[j] - node_offset);
+      rst[i].all_edges_dst_.push_back(graph->all_edges_dst_[j] - node_offset);
+    }
+    // update offset
+    CHECK_EQ(rst[i].NumVertices(), sizes_data[i]);
+    CHECK_EQ(rst[i].NumEdges(), num_edges);
+    node_offset += sizes_data[i];
+    edge_offset += num_edges;
+  }
+  /*for (int64_t i = 0; i < len; ++i) {
+    rst[i].AddVertices(sizes_data[i]);
+  }
+  for (dgl_id_t eid = 0; eid < graph->num_edges_; ++eid) {
+    const dgl_id_t src = graph->all_edges_src_[eid];
+    const dgl_id_t dst = graph->all_edges_dst_[eid];
+    size_t src_select = 0, dst_select = 0;
+    for (size_t i = 1; i < cumsum.size(); ++i) {  // TODO: replace with binary search
+      if (cumsum[i] > src) {
+        src_select = i;
+        break;
+      }
+    }
+    for (size_t i = 1; i < cumsum.size(); ++i) {  // TODO: replace with binary search
+      if (cumsum[i] > dst) {
+        dst_select = i;
+        break;
+      }
+    }
+    if (src_select != dst_select) {
+      // the edge is ignored if across two partitions
+      continue;
+    }
+    const int64_t offset = cumsum[src_select - 1];
+    rst[src_select - 1].AddEdge(src - offset, dst - offset);
+  }*/
+  return rst;
+}
+
 }  // namespace dgl
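The C++ loop relies on the fact that a disjoint union lays out each component's node ids and edge ids contiguously, so each partition is recovered by slicing and subtracting the running node and edge offsets. A minimal NumPy sketch of that bookkeeping (not DGL code; the helper name and inputs are made up for illustration):

import numpy as np

def relabel_partitions(src, dst, node_sizes, edge_sizes):
    # cumulative offsets, mirroring node_offset / edge_offset in DisjointPartitionBySizes
    node_off = np.concatenate(([0], np.cumsum(node_sizes)))
    edge_off = np.concatenate(([0], np.cumsum(edge_sizes)))
    parts = []
    for i in range(len(node_sizes)):
        lo, hi = edge_off[i], edge_off[i + 1]
        s = src[lo:hi] - node_off[i]   # relabel endpoints to partition-local ids
        d = dst[lo:hi] - node_off[i]
        parts.append((s, d))
    return parts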
 import networkx as nx
 import dgl
-import torch
+import torch as th
 import numpy as np

 def tree1():
@@ -13,17 +13,13 @@ def tree1():
     Edges are from leaves to root.
     """
     g = dgl.DGLGraph()
-    g.add_node(0)
-    g.add_node(1)
-    g.add_node(2)
-    g.add_node(3)
-    g.add_node(4)
+    g.add_nodes(5)
     g.add_edge(3, 1)
     g.add_edge(4, 1)
     g.add_edge(1, 0)
     g.add_edge(2, 0)
-    g.set_n_repr(torch.Tensor([0, 1, 2, 3, 4]))
-    g.set_e_repr(torch.randn(4, 10))
+    g.set_n_repr(th.Tensor([0, 1, 2, 3, 4]))
+    g.set_e_repr(th.randn(4, 10))
     return g

 def tree2():
@@ -36,17 +32,13 @@ def tree2():
     Edges are from leaves to root.
     """
     g = dgl.DGLGraph()
-    g.add_node(0)
-    g.add_node(1)
-    g.add_node(2)
-    g.add_node(3)
-    g.add_node(4)
+    g.add_nodes(5)
     g.add_edge(2, 4)
     g.add_edge(0, 4)
     g.add_edge(4, 1)
     g.add_edge(3, 1)
-    g.set_n_repr(torch.Tensor([0, 1, 2, 3, 4]))
-    g.set_e_repr(torch.randn(4, 10))
+    g.set_n_repr(th.Tensor([0, 1, 2, 3, 4]))
+    g.set_e_repr(th.randn(4, 10))
     return g

 def test_batch_unbatch():
@@ -58,13 +50,36 @@ def test_batch_unbatch():
     e2 = t2.get_e_repr()

     bg = dgl.batch([t1, t2])
-    dgl.unbatch(bg)
-    assert(n1.equal(t1.get_n_repr()))
-    assert(n2.equal(t2.get_n_repr()))
-    assert(e1.equal(t1.get_e_repr()))
-    assert(e2.equal(t2.get_e_repr()))
+    assert bg.number_of_nodes() == 10
+    assert bg.number_of_edges() == 8
+    assert bg.batch_size == 2
+    assert bg.batch_num_nodes == [5, 5]
+    assert bg.batch_num_edges == [4, 4]
+    tt1, tt2 = dgl.unbatch(bg)
+    assert th.allclose(t1.get_n_repr(), tt1.get_n_repr())
+    assert th.allclose(t1.get_e_repr(), tt1.get_e_repr())
+    assert th.allclose(t2.get_n_repr(), tt2.get_n_repr())
+    assert th.allclose(t2.get_e_repr(), tt2.get_e_repr())
+
+def test_batch_unbatch1():
+    t1 = tree1()
+    t2 = tree2()
+    b1 = dgl.batch([t1, t2])
+    b2 = dgl.batch([t2, b1])
+    assert b2.number_of_nodes() == 15
+    assert b2.number_of_edges() == 12
+    assert b2.batch_size == 3
+    assert b2.batch_num_nodes == [5, 5, 5]
+    assert b2.batch_num_edges == [4, 4, 4]
+    s1, s2, s3 = dgl.unbatch(b2)
+    assert th.allclose(t2.get_n_repr(), s1.get_n_repr())
+    assert th.allclose(t2.get_e_repr(), s1.get_e_repr())
+    assert th.allclose(t1.get_n_repr(), s2.get_n_repr())
+    assert th.allclose(t1.get_e_repr(), s2.get_e_repr())
+    assert th.allclose(t2.get_n_repr(), s3.get_n_repr())
+    assert th.allclose(t2.get_e_repr(), s3.get_e_repr())

 def test_batch_sendrecv():
     t1 = tree1()
@@ -72,7 +87,7 @@ def test_batch_sendrecv():
     bg = dgl.batch([t1, t2])
     bg.register_message_func(lambda src, edge: src)
-    bg.register_reduce_func(lambda node, msgs: torch.sum(msgs, 1))
+    bg.register_reduce_func(lambda node, msgs: th.sum(msgs, 1))

     e1 = [(3, 1), (4, 1)]
     e2 = [(2, 4), (0, 4)]
@@ -95,7 +110,7 @@ def test_batch_propagate():
     bg = dgl.batch([t1, t2])
     bg.register_message_func(lambda src, edge: src)
-    bg.register_reduce_func(lambda node, msgs: torch.sum(msgs, 1))
+    bg.register_reduce_func(lambda node, msgs: th.sum(msgs, 1))

     # get leaves.
     order = []
@@ -129,20 +144,21 @@ def test_batched_edge_ordering():
     g1.add_nodes_from([0,1,2, 3, 4, 5])
     g1.add_edges_from([(4, 5), (4, 3), (2, 3), (2, 1), (0, 1)])
     g1.edge_list
-    e1 = torch.randn(5, 10)
+    e1 = th.randn(5, 10)
     g1.set_e_repr(e1)

     g2 = dgl.DGLGraph()
     g2.add_nodes_from([0, 1, 2, 3, 4, 5])
     g2.add_edges_from([(0, 1), (1, 2), (2, 3), (5, 4), (4, 3), (5, 0)])
-    e2 = torch.randn(6, 10)
+    e2 = th.randn(6, 10)
     g2.set_e_repr(e2)

     g = dgl.batch([g1, g2])
     r1 = g.get_e_repr()[g.get_edge_id(4, 5)]
     r2 = g1.get_e_repr()[g1.get_edge_id(4, 5)]
-    assert torch.equal(r1, r2)
+    assert th.equal(r1, r2)

 if __name__ == '__main__':
     test_batch_unbatch()
-    test_batched_edge_ordering()
-    test_batch_sendrecv()
-    test_batch_propagate()
+    test_batch_unbatch1()
+    #test_batched_edge_ordering()
+    #test_batch_sendrecv()
+    #test_batch_propagate()