"docs/source/vscode:/vscode.git/clone" did not exist on "097fd6dbb1a2ec306438edd1b94a29ec393fa0fa"
Commit 26dcfb5b authored by GaiYu0's avatar GaiYu0
Browse files

Merge branch 'cpp' of https://github.com/jermainewang/dgl into cpp

Conflicts:
	python/dgl/graph.py
parents cc372e37 2be55fb5
......@@ -14,4 +14,28 @@ Show below, there are three sets of APIs for different models.
- Always choose the API at the *highest* possible level.
- Refer to [the default modules](examples/pytorch/util.py) to see how to register message and node update functions as well as readout functions; note how you can control sharing of parameters by adding a counter.
## How to build (the `cpp` branch)
Before building, make sure that the submodules are cloned. If you haven't initialized the submodules, run
```sh
$ git submodule init
```
To sync the submodules, run
```sh
$ git submodule update
```
At the root directory of the repo:
```sh
$ mkdir build
$ cd build
$ cmake ..
$ make
$ export DGL_LIBRARY_PATH=$PWD
```
The `DGL_LIBRARY_PATH` environment variable should point to the library `libdgl.so` built by CMake.
......@@ -31,14 +31,28 @@ class GraphOp {
/*!
* \brief Partition the graph into several subgraphs.
*
* The graph will be partitioned by the node ids. Edges between partitions
* will be ignored. This requires the given number of partitions to evenly
* This is a reverse operation of DisjointUnion. The graph will be partitioned
* into num graphs. This requires the given number of partitions to evenly
 * divide the number of nodes in the graph.
*
* \param graph The graph to be partitioned.
* \param num The number of partitions.
* \return a list of partitioned graphs
*/
static std::vector<Graph> PartitionByNum(const Graph* graph, size_t num);
static std::vector<Graph> DisjointPartitionByNum(const Graph* graph, int64_t num);
/*!
* \brief Partition the graph into several subgraphs.
*
* This is a reverse operation of DisjointUnion. The graph will be partitioned
* based on the given sizes. This requires the sum of the given sizes is equal
* to the number of nodes in the graph.
*
* \param graph The graph to be partitioned.
 * \param sizes The size of each partition.
* \return a list of partitioned graphs
*/
static std::vector<Graph> DisjointPartitionBySizes(const Graph* graph, IdArray sizes);
};
} // namespace dgl
......
......@@ -8,6 +8,7 @@ from .frame import FrameRef
from .graph import DGLGraph
from . import graph_index as gi
from . import backend as F
from . import utils
class BatchedDGLGraph(DGLGraph):
"""The batched DGL graph.
......@@ -24,7 +25,6 @@ class BatchedDGLGraph(DGLGraph):
The edge attributes to also be batched.
"""
def __init__(self, graph_list, node_attrs, edge_attrs):
# TODO(minjie): handle the input is again a batched graph.
# create batched graph index
batched_index = gi.disjoint_union([g._graph for g in graph_list])
# create batched node and edge frames
......@@ -43,9 +43,19 @@ class BatchedDGLGraph(DGLGraph):
edge_frame=batched_edge_frame)
# extra members
self._batch_size = len(graph_list)
self._batch_num_nodes = [gr.number_of_nodes() for gr in graph_list]
self._batch_num_edges = [gr.number_of_edges() for gr in graph_list]
self._batch_size = 0
self._batch_num_nodes = []
self._batch_num_edges = []
for gr in graph_list:
if isinstance(gr, BatchedDGLGraph):
# handle the input is again a batched graph.
self._batch_size += gr._batch_size
self._batch_num_nodes += gr._batch_num_nodes
self._batch_num_edges += gr._batch_num_edges
else:
self._batch_size += 1
self._batch_num_nodes.append(gr.number_of_nodes())
self._batch_num_edges.append(gr.number_of_edges())
@property
def batch_size(self):
......@@ -78,10 +88,12 @@ class BatchedDGLGraph(DGLGraph):
# new APIs
def __getitem__(self, idx):
"""Slice the batch and return the batch of graphs specified by the idx."""
# TODO
pass
def __setitem__(self, idx, val):
"""Set the value of the slice. The graph size cannot be changed."""
# TODO
pass
'''
......@@ -114,36 +126,35 @@ def split(graph_batch, num_or_size_splits):
# TODO(minjie): could follow torch.split syntax
pass
def unbatch(graph):
    """Unbatch the graph and return a list of subgraphs.

    This is the reverse operation of batching: the batched graph index is
    disjointly partitioned by the per-graph node counts, and each node/edge
    frame column is unpacked into per-graph slices.

    Parameters
    ----------
    graph : BatchedDGLGraph
        The batched graph.

    Returns
    -------
    list of DGLGraph
        The graphs that were batched together, each carrying its own slice
        of the node and edge attributes.
    """
    assert isinstance(graph, BatchedDGLGraph)
    bsize = graph.batch_size
    bn = graph.batch_num_nodes
    be = graph.batch_num_edges
    # Partition the underlying graph index by the per-graph node counts.
    pttns = gi.disjoint_partition(graph._graph, utils.toindex(bn))
    # Split every frame column and scatter the pieces into fresh frames,
    # one per graph in the batch.
    node_frames = [FrameRef() for i in range(bsize)]
    edge_frames = [FrameRef() for i in range(bsize)]
    for attr, col in graph._node_frame.items():
        # TODO: device context
        col_splits = F.unpack(col, bn)
        for i in range(bsize):
            node_frames[i][attr] = col_splits[i]
    for attr, col in graph._edge_frame.items():
        # TODO: device context
        col_splits = F.unpack(col, be)
        for i in range(bsize):
            edge_frames[i][attr] = col_splits[i]
    return [DGLGraph(graph_data=pttns[i],
                     node_frame=node_frames[i],
                     edge_frame=edge_frames[i]) for i in range(bsize)]
def batch(graph_list, node_attrs=ALL, edge_attrs=ALL):
"""Batch a list of DGLGraphs into one single graph.
......
......@@ -51,11 +51,11 @@ class DGLGraph(object):
self._msg_frame = FrameRef()
self.reset_messages()
# registered functions
self._message_func = (None, None)
self._reduce_func = (None, None)
self._edge_func = (None, None)
self._apply_node_func = (None, None)
self._apply_edge_func = (None, None)
self._message_func = None
self._reduce_func = None
self._edge_func = None
self._apply_node_func = None
self._apply_edge_func = None
def add_nodes(self, num, reprs=None):
"""Add nodes.
......@@ -722,77 +722,57 @@ class DGLGraph(object):
else:
return self._edge_frame.select_rows(eid)
def register_edge_func(self,
edge_func,
batchable=False):
def register_edge_func(self, edge_func):
"""Register global edge update function.
Parameters
----------
edge_func : callable
Message function on the edge.
batchable : bool
Whether the provided message function allows batch computing.
"""
self._edge_func = (edge_func, batchable)
self._edge_func = edge_func
def register_message_func(self,
message_func,
batchable=False):
def register_message_func(self, message_func):
"""Register global message function.
Parameters
----------
message_func : callable
Message function on the edge.
batchable : bool
Whether the provided message function allows batch computing.
"""
self._message_func = (message_func, batchable)
self._message_func = message_func
def register_reduce_func(self,
reduce_func,
batchable=False):
def register_reduce_func(self, reduce_func):
"""Register global message reduce function.
Parameters
----------
reduce_func : str or callable
Reduce function on incoming edges.
batchable : bool
Whether the provided reduce function allows batch computing.
"""
self._reduce_func = (reduce_func, batchable)
self._reduce_func = reduce_func
def register_apply_node_func(self,
apply_node_func,
batchable=False):
def register_apply_node_func(self, apply_node_func):
"""Register global node apply function.
Parameters
----------
apply_node_func : callable
Apply function on the node.
batchable : bool
Whether the provided function allows batch computing.
"""
self._apply_node_func = (apply_node_func, batchable)
self._apply_node_func = apply_node_func
def register_apply_edge_func(self,
apply_edge_func,
batchable=False):
def register_apply_edge_func(self, apply_edge_func):
"""Register global edge apply function.
Parameters
----------
apply_edge_func : callable
Apply function on the edge.
batchable : bool
Whether the provided function allows batch computing.
"""
self._apply_edge_func = (apply_edge_func, batchable)
self._apply_edge_func = apply_edge_func
def apply_nodes(self, v, apply_node_func="default", batchable=False):
def apply_nodes(self, v, apply_node_func="default"):
"""Apply the function on node representations.
Parameters
......@@ -801,27 +781,16 @@ class DGLGraph(object):
The node id(s).
apply_node_func : callable
The apply node function.
batchable : bool
Whether the provided function allows batch computing.
"""
if apply_node_func == "default":
apply_node_func, batchable = self._apply_node_func
apply_node_func = self._apply_node_func
if not apply_node_func:
# Skip none function call.
return
if batchable:
new_repr = apply_node_func(self.get_n_repr(v))
self.set_n_repr(new_repr, v)
else:
raise RuntimeError('Disabled')
if is_all(v):
v = self.nodes()
v = utils.toindex(v)
for vv in utils.node_iter(v):
ret = apply_node_func(_get_repr(self.nodes[vv]))
_set_repr(self.nodes[vv], ret)
def apply_edges(self, u, v, apply_edge_func="default", batchable=False):
def apply_edges(self, u, v, apply_edge_func="default"):
"""Apply the function on edge representations.
Parameters
......@@ -832,27 +801,16 @@ class DGLGraph(object):
The dst node id(s).
apply_edge_func : callable
The apply edge function.
batchable : bool
Whether the provided function allows batch computing.
"""
if apply_edge_func == "default":
apply_edge_func, batchable = self._apply_edge_func
apply_edge_func = self._apply_edge_func
if not apply_edge_func:
# Skip none function call.
return
if batchable:
new_repr = apply_edge_func(self.get_e_repr(u, v))
self.set_e_repr(new_repr, u, v)
else:
if is_all(u) == is_all(v):
u, v = zip(*self.edges)
u = utils.toindex(u)
v = utils.toindex(v)
for uu, vv in utils.edge_iter(u, v):
ret = apply_edge_func(_get_repr(self.edges[uu, vv]))
_set_repr(self.edges[uu, vv], ret)
def send(self, u, v, message_func="default", batchable=False):
def send(self, u, v, message_func="default"):
"""Trigger the message function on edge u->v
The message function should be compatible with following signature:
......@@ -873,30 +831,13 @@ class DGLGraph(object):
The destination node(s).
message_func : callable
The message function.
batchable : bool
Whether the function allows batched computation.
"""
if message_func == "default":
message_func, batchable = self._message_func
message_func = self._message_func
assert message_func is not None
if isinstance(message_func, (tuple, list)):
message_func = BundledMessageFunction(message_func)
if batchable:
self._batch_send(u, v, message_func)
else:
self._nonbatch_send(u, v, message_func)
def _nonbatch_send(self, u, v, message_func):
raise RuntimeError('Disabled')
if is_all(u) and is_all(v):
u, v = self.cached_graph.edges()
else:
u = utils.toindex(u)
v = utils.toindex(v)
for uu, vv in utils.edge_iter(u, v):
ret = message_func(_get_repr(self.nodes[uu]),
_get_repr(self.edges[uu, vv]))
self.edges[uu, vv][__MSG__] = ret
def _batch_send(self, u, v, message_func):
if is_all(u) and is_all(v):
......@@ -920,7 +861,7 @@ class DGLGraph(object):
else:
self._msg_frame.append({__MSG__ : msgs})
def update_edge(self, u=ALL, v=ALL, edge_func="default", batchable=False):
def update_edge(self, u=ALL, v=ALL, edge_func="default"):
"""Update representation on edge u->v
The edge function should be compatible with following signature:
......@@ -939,29 +880,11 @@ class DGLGraph(object):
The destination node(s).
edge_func : callable
The update function.
batchable : bool
Whether the function allows batched computation.
"""
if edge_func == "default":
edge_func, batchable = self._edge_func
edge_func = self._edge_func
assert edge_func is not None
if batchable:
self._batch_update_edge(u, v, edge_func)
else:
self._nonbatch_update_edge(u, v, edge_func)
def _nonbatch_update_edge(self, u, v, edge_func):
raise RuntimeError('Disabled')
if is_all(u) and is_all(v):
u, v = self.cached_graph.edges()
else:
u = utils.toindex(u)
v = utils.toindex(v)
for uu, vv in utils.edge_iter(u, v):
ret = edge_func(_get_repr(self.nodes[uu]),
_get_repr(self.nodes[vv]),
_get_repr(self.edges[uu, vv]))
_set_repr(self.edges[uu, vv], ret)
def _batch_update_edge(self, u, v, edge_func):
if is_all(u) and is_all(v):
......@@ -987,8 +910,7 @@ class DGLGraph(object):
def recv(self,
u,
reduce_func="default",
apply_node_func="default",
batchable=False):
apply_node_func="default"):
"""Receive and reduce in-coming messages and update representation on node u.
It computes the new node state using the messages sent from the predecessors
......@@ -1018,34 +940,15 @@ class DGLGraph(object):
The reduce function.
apply_node_func : callable, optional
The update function.
batchable : bool, optional
Whether the reduce and update function allows batched computation.
"""
if reduce_func == "default":
reduce_func, batchable = self._reduce_func
reduce_func = self._reduce_func
assert reduce_func is not None
if isinstance(reduce_func, (list, tuple)):
reduce_func = BundledReduceFunction(reduce_func)
if batchable:
self._batch_recv(u, reduce_func)
else:
self._nonbatch_recv(u, reduce_func)
# optional apply nodes
self.apply_nodes(u, apply_node_func, batchable)
def _nonbatch_recv(self, u, reduce_func):
raise RuntimeError('Disabled')
if is_all(u):
u = list(range(0, self.number_of_nodes()))
else:
u = utils.toindex(u)
for i, uu in enumerate(utils.node_iter(u)):
# reduce phase
msgs_batch = [self.edges[vv, uu].pop(__MSG__)
for vv in self.pred[uu] if __MSG__ in self.edges[vv, uu]]
if len(msgs_batch) != 0:
new_repr = reduce_func(_get_repr(self.nodes[uu]), msgs_batch)
_set_repr(self.nodes[uu], new_repr)
self.apply_nodes(u, apply_node_func)
def _batch_recv(self, v, reduce_func):
if self._msg_frame.num_rows == 0:
......@@ -1117,8 +1020,7 @@ class DGLGraph(object):
u, v,
message_func="default",
reduce_func="default",
apply_node_func="default",
batchable=False):
apply_node_func="default"):
"""Trigger the message function on u->v and update v.
Parameters
......@@ -1133,8 +1035,6 @@ class DGLGraph(object):
The reduce function.
apply_node_func : callable, optional
The update function.
batchable : bool
Whether the reduce and update function allows batched computation.
"""
u = utils.toindex(u)
v = utils.toindex(v)
......@@ -1144,34 +1044,28 @@ class DGLGraph(object):
return
unique_v = utils.toindex(F.unique(v.tousertensor()))
# TODO(minjie): better way to figure out `batchable` flag
if message_func == "default":
message_func, batchable = self._message_func
message_func = self._message_func
if reduce_func == "default":
reduce_func, _ = self._reduce_func
reduce_func = self._reduce_func
assert message_func is not None
assert reduce_func is not None
if batchable:
executor = scheduler.get_executor(
'send_and_recv', self, src=u, dst=v,
message_func=message_func, reduce_func=reduce_func)
else:
executor = None
if executor:
executor.run()
else:
self.send(u, v, message_func, batchable=batchable)
self.recv(unique_v, reduce_func, None, batchable=batchable)
self.apply_nodes(unique_v, apply_node_func, batchable=batchable)
self.send(u, v, message_func)
self.recv(unique_v, reduce_func, None)
self.apply_nodes(unique_v, apply_node_func)
def pull(self,
v,
message_func="default",
reduce_func="default",
apply_node_func="default",
batchable=False):
apply_node_func="default"):
"""Pull messages from the node's predecessors and then update it.
Parameters
......@@ -1184,24 +1078,20 @@ class DGLGraph(object):
The reduce function.
apply_node_func : callable, optional
The update function.
batchable : bool
Whether the reduce and update function allows batched computation.
"""
v = utils.toindex(v)
if len(v) == 0:
return
uu, vv, _ = self._graph.in_edges(v)
self.send_and_recv(uu, vv, message_func, reduce_func,
apply_node_func=None, batchable=batchable)
self.send_and_recv(uu, vv, message_func, reduce_func, apply_node_func=None)
unique_v = F.unique(v.tousertensor())
self.apply_nodes(unique_v, apply_node_func, batchable=batchable)
self.apply_nodes(unique_v, apply_node_func)
def push(self,
u,
message_func="default",
reduce_func="default",
apply_node_func="default",
batchable=False):
apply_node_func="default"):
"""Send message from the node to its successors and update them.
Parameters
......@@ -1214,21 +1104,18 @@ class DGLGraph(object):
The reduce function.
apply_node_func : callable
The update function.
batchable : bool
Whether the reduce and update function allows batched computation.
"""
u = utils.toindex(u)
if len(u) == 0:
return
uu, vv, _ = self._graph.out_edges(u)
self.send_and_recv(uu, vv, message_func,
reduce_func, apply_node_func, batchable=batchable)
reduce_func, apply_node_func)
def update_all(self,
message_func="default",
reduce_func="default",
apply_node_func="default",
batchable=False):
apply_node_func="default"):
"""Send messages through all the edges and update all nodes.
Parameters
......@@ -1239,35 +1126,28 @@ class DGLGraph(object):
The reduce function.
apply_node_func : callable, optional
The update function.
batchable : bool
Whether the reduce and update function allows batched computation.
"""
if message_func == "default":
message_func, batchable = self._message_func
message_func = self._message_func
if reduce_func == "default":
reduce_func, _ = self._reduce_func
reduce_func = self._reduce_func
assert message_func is not None
assert reduce_func is not None
if batchable:
executor = scheduler.get_executor(
"update_all", self, message_func=message_func, reduce_func=reduce_func)
else:
executor = None
if executor:
executor.run()
else:
self.send(ALL, ALL, message_func, batchable=batchable)
self.recv(ALL, reduce_func, None, batchable=batchable)
self.apply_nodes(ALL, apply_node_func, batchable=batchable)
self.send(ALL, ALL, message_func)
self.recv(ALL, reduce_func, None)
self.apply_nodes(ALL, apply_node_func)
def propagate(self,
iterator='bfs',
message_func="default",
reduce_func="default",
apply_node_func="default",
batchable=False,
**kwargs):
"""Propagate messages and update nodes using iterator.
......@@ -1286,8 +1166,6 @@ class DGLGraph(object):
The reduce function.
apply_node_func : str or callable
The update function.
batchable : bool
Whether the reduce and update function allows batched computation.
iterator : str or generator of steps.
The iterator of the graph.
kwargs : keyword arguments, optional
......@@ -1300,7 +1178,7 @@ class DGLGraph(object):
# NOTE: the iteration can return multiple edges at each step.
for u, v in iterator:
self.send_and_recv(u, v,
message_func, reduce_func, apply_node_func, batchable)
message_func, reduce_func, apply_node_func)
def subgraph(self, nodes):
"""Generate the subgraph among the given nodes.
......@@ -1362,6 +1240,7 @@ class DGLGraph(object):
[sg._parent_eid for sg in to_merge],
self._edge_frame.num_rows,
reduce_func)
<<<<<<< HEAD
def adjacency_matrix(self):
"""Return the adjacency matrix representation of this graph.
......@@ -1404,3 +1283,5 @@ def _set_repr(attr_dict, attr):
attr_dict.update(attr)
else:
attr_dict[__REPR__] = attr
=======
>>>>>>> 2be55fb50ab08c1f0a3bbb40df8f9265d73b4d2d
......@@ -542,6 +542,40 @@ def disjoint_union(graphs):
handle = _CAPI_DGLDisjointUnion(inputs, len(graphs))
return GraphIndex(handle)
def disjoint_partition(graph, num_or_size_splits):
    """Partition the graph disjointly.

    This is the reverse operation of disjoint_union. If an integer ``num``
    is given, the graph is split into ``num`` partitions and the number of
    nodes must be evenly divisible by ``num``. If a size list is given,
    the sizes must sum to the total number of nodes.

    Parameters
    ----------
    graph : GraphIndex
        The graph to be partitioned.
    num_or_size_splits : int or utils.Index
        The number of partitions, or the size of each partition.

    Returns
    -------
    list of GraphIndex
        The partitioned graphs.
    """
    if isinstance(num_or_size_splits, utils.Index):
        rst = _CAPI_DGLDisjointPartitionBySizes(
            graph._handle,
            num_or_size_splits.todgltensor())
    else:
        rst = _CAPI_DGLDisjointPartitionByNum(
            graph._handle,
            int(num_or_size_splits))
    # The C API returns an int64 array of raw C++ Graph addresses; wrap each
    # address back into a GraphIndex handle.
    return [GraphIndex(ctypes.cast(int(val), ctypes.c_void_p))
            for val in rst.asnumpy()]
def create_graph_index(graph_data=None):
"""Create a graph index object.
......
......@@ -7,6 +7,7 @@ using tvm::runtime::TVMArgs;
using tvm::runtime::TVMArgValue;
using tvm::runtime::TVMRetValue;
using tvm::runtime::PackedFunc;
using tvm::runtime::NDArray;
namespace dgl {
......@@ -289,4 +290,39 @@ TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLDisjointUnion")
*rv = ghandle;
});
// Split a graph into `num` equal-sized disjoint parts; the new graphs are
// heap-allocated and their addresses are returned packed in an int64 NDArray.
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLDisjointPartitionByNum")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
    GraphHandle handle = args[0];
    const Graph* parent = static_cast<Graph*>(handle);
    const int64_t num_parts = args[1];
    std::vector<Graph> parts = GraphOp::DisjointPartitionByNum(parent, num_parts);
    // Transfer ownership to the caller: move each partition onto the heap and
    // hand back the pointers as plain integers.
    const int64_t count = static_cast<int64_t>(parts.size());
    NDArray ptr_array = NDArray::Empty({count}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
    int64_t* ptr_array_data = static_cast<int64_t*>(ptr_array->data);
    for (int64_t i = 0; i < count; ++i) {
      Graph* holder = new Graph();
      *holder = std::move(parts[i]);
      ptr_array_data[i] = reinterpret_cast<std::intptr_t>(holder);
    }
    *rv = ptr_array;
  });
// Split a graph into disjoint parts whose node counts are given by an id
// array; the new graphs are heap-allocated and their addresses are returned
// packed in an int64 NDArray.
TVM_REGISTER_GLOBAL("graph_index._CAPI_DGLDisjointPartitionBySizes")
.set_body([] (TVMArgs args, TVMRetValue* rv) {
    GraphHandle handle = args[0];
    const Graph* parent = static_cast<Graph*>(handle);
    const IdArray sizes = IdArray::FromDLPack(CreateTmpDLManagedTensor(args[1]));
    std::vector<Graph> parts = GraphOp::DisjointPartitionBySizes(parent, sizes);
    // Transfer ownership to the caller: move each partition onto the heap and
    // hand back the pointers as plain integers.
    const int64_t count = static_cast<int64_t>(parts.size());
    NDArray ptr_array = NDArray::Empty({count}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
    int64_t* ptr_array_data = static_cast<int64_t*>(ptr_array->data);
    for (int64_t i = 0; i < count; ++i) {
      Graph* holder = new Graph();
      *holder = std::move(parts[i]);
      ptr_array_data[i] = reinterpret_cast<std::intptr_t>(holder);
    }
    *rv = ptr_array;
  });
} // namespace dgl
// Graph operation implementation
#include <dgl/graph_op.h>
#include <algorithm>
namespace dgl {
......@@ -16,4 +17,91 @@ Graph GraphOp::DisjointUnion(std::vector<const Graph*> graphs) {
return rst;
}
std::vector<Graph> GraphOp::DisjointPartitionByNum(const Graph* graph, int64_t num) {
  // Build an equal-sizes array and delegate to DisjointPartitionBySizes.
  CHECK(num != 0 && graph->NumVertices() % num == 0)
    << "Number of partitions must evenly divide the number of nodes.";
  const int64_t part_size = graph->NumVertices() / num;
  IdArray sizes = IdArray::Empty({num}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
  int64_t* sizes_data = static_cast<int64_t*>(sizes->data);
  for (int64_t i = 0; i < num; ++i) {
    sizes_data[i] = part_size;
  }
  return DisjointPartitionBySizes(graph, sizes);
}
std::vector<Graph> GraphOp::DisjointPartitionBySizes(const Graph* graph, IdArray sizes) {
  // Reverse of DisjointUnion: split `graph` into consecutive node ranges whose
  // lengths are given by `sizes`, relabeling node and edge ids so that each
  // partition starts from zero. Assumes the DisjointUnion layout, i.e. the
  // edges of partition i occupy a contiguous id range after those of
  // partition i-1.
  const int64_t len = sizes->shape[0];
  const int64_t* sizes_data = static_cast<int64_t*>(sizes->data);
  // Prefix sums of the sizes; only the total is needed for validation.
  std::vector<int64_t> cumsum;
  cumsum.push_back(0);
  for (int64_t i = 0; i < len; ++i) {
    cumsum.push_back(cumsum[i] + sizes_data[i]);
  }
  CHECK_EQ(cumsum[len], graph->NumVertices())
    << "Sum of the given sizes must equal to the number of nodes.";
  dgl_id_t node_offset = 0, edge_offset = 0;
  std::vector<Graph> rst(len);
  for (int64_t i = 0; i < len; ++i) {
    // Copy this partition's slice of the forward and reverse adjacency lists.
    rst[i].adjlist_.insert(rst[i].adjlist_.end(),
        graph->adjlist_.begin() + node_offset,
        graph->adjlist_.begin() + node_offset + sizes_data[i]);
    rst[i].reverse_adjlist_.insert(rst[i].reverse_adjlist_.end(),
        graph->reverse_adjlist_.begin() + node_offset,
        graph->reverse_adjlist_.begin() + node_offset + sizes_data[i]);
    // Relabel the copied ids to be partition-local, counting this
    // partition's edges along the way.
    size_t num_edges = 0;
    for (auto& elist : rst[i].adjlist_) {
      for (size_t j = 0; j < elist.succ.size(); ++j) {
        elist.succ[j] -= node_offset;
        elist.edge_id[j] -= edge_offset;
      }
      num_edges += elist.succ.size();
    }
    for (auto& elist : rst[i].reverse_adjlist_) {
      for (size_t j = 0; j < elist.succ.size(); ++j) {
        elist.succ[j] -= node_offset;
        elist.edge_id[j] -= edge_offset;
      }
    }
    // Copy and relabel the flat edge endpoint arrays.
    rst[i].all_edges_src_.reserve(num_edges);
    rst[i].all_edges_dst_.reserve(num_edges);
    rst[i].num_edges_ = num_edges;
    for (size_t j = edge_offset; j < edge_offset + num_edges; ++j) {
      rst[i].all_edges_src_.push_back(graph->all_edges_src_[j] - node_offset);
      rst[i].all_edges_dst_.push_back(graph->all_edges_dst_[j] - node_offset);
    }
    // Sanity-check the partition and advance the offsets.
    CHECK_EQ(rst[i].NumVertices(), sizes_data[i]);
    CHECK_EQ(rst[i].NumEdges(), num_edges);
    node_offset += sizes_data[i];
    edge_offset += num_edges;
  }
  return rst;
}
} // namespace dgl
......@@ -133,7 +133,7 @@ def test_batch_send():
def _fmsg(src, edge):
assert src['h'].shape == (5, D)
return {'m' : src['h']}
g.register_message_func(_fmsg, batchable=True)
g.register_message_func(_fmsg)
# many-many send
u = th.tensor([0, 0, 0, 0, 0])
v = th.tensor([1, 2, 3, 4, 5])
......@@ -150,9 +150,9 @@ def test_batch_send():
def test_batch_recv():
# basic recv test
g = generate_graph()
g.register_message_func(message_func, batchable=True)
g.register_reduce_func(reduce_func, batchable=True)
g.register_apply_node_func(apply_node_func, batchable=True)
g.register_message_func(message_func)
g.register_reduce_func(reduce_func)
g.register_apply_node_func(apply_node_func)
u = th.tensor([0, 0, 0, 4, 5, 6])
v = th.tensor([1, 2, 3, 9, 9, 9])
reduce_msg_shapes.clear()
......@@ -163,9 +163,9 @@ def test_batch_recv():
def test_update_routines():
g = generate_graph()
g.register_message_func(message_func, batchable=True)
g.register_reduce_func(reduce_func, batchable=True)
g.register_apply_node_func(apply_node_func, batchable=True)
g.register_message_func(message_func)
g.register_reduce_func(reduce_func)
g.register_apply_node_func(apply_node_func)
# send_and_recv
reduce_msg_shapes.clear()
......@@ -209,7 +209,7 @@ def test_reduce_0deg():
return node + msgs.sum(1)
old_repr = th.randn(5, 5)
g.set_n_repr(old_repr)
g.update_all(_message, _reduce, batchable=True)
g.update_all(_message, _reduce)
new_repr = g.get_n_repr()
assert th.allclose(new_repr[1:], old_repr[1:])
......@@ -227,17 +227,17 @@ def test_pull_0deg():
old_repr = th.randn(2, 5)
g.set_n_repr(old_repr)
g.pull(0, _message, _reduce, batchable=True)
g.pull(0, _message, _reduce)
new_repr = g.get_n_repr()
assert th.allclose(new_repr[0], old_repr[0])
assert th.allclose(new_repr[1], old_repr[1])
g.pull(1, _message, _reduce, batchable=True)
g.pull(1, _message, _reduce)
new_repr = g.get_n_repr()
assert th.allclose(new_repr[1], old_repr[0])
old_repr = th.randn(2, 5)
g.set_n_repr(old_repr)
g.pull([0, 1], _message, _reduce, batchable=True)
g.pull([0, 1], _message, _reduce)
new_repr = g.get_n_repr()
assert th.allclose(new_repr[0], old_repr[0])
assert th.allclose(new_repr[1], old_repr[0])
......
......@@ -129,7 +129,7 @@ def test_batch_send():
def _fmsg(hu, edge):
assert hu.shape == (5, D)
return hu
g.register_message_func(_fmsg, batchable=True)
g.register_message_func(_fmsg)
# many-many send
u = th.tensor([0, 0, 0, 0, 0])
v = th.tensor([1, 2, 3, 4, 5])
......@@ -145,8 +145,8 @@ def test_batch_send():
def test_batch_recv():
g = generate_graph()
g.register_message_func(message_func, batchable=True)
g.register_reduce_func(reduce_func, batchable=True)
g.register_message_func(message_func)
g.register_reduce_func(reduce_func)
u = th.tensor([0, 0, 0, 4, 5, 6])
v = th.tensor([1, 2, 3, 9, 9, 9])
reduce_msg_shapes.clear()
......@@ -157,8 +157,8 @@ def test_batch_recv():
def test_update_routines():
g = generate_graph()
g.register_message_func(message_func, batchable=True)
g.register_reduce_func(reduce_func, batchable=True)
g.register_message_func(message_func)
g.register_reduce_func(reduce_func)
# send_and_recv
reduce_msg_shapes.clear()
......
......@@ -51,32 +51,32 @@ def reducer_none(node, msgs):
def test_copy_src():
# copy_src with both fields
g = generate_graph()
g.register_message_func(fn.copy_src(src='h', out='m'), batchable=True)
g.register_reduce_func(reducer_both, batchable=True)
g.register_message_func(fn.copy_src(src='h', out='m'))
g.register_reduce_func(reducer_both)
g.update_all()
assert th.allclose(g.get_n_repr()['h'],
th.tensor([10., 1., 1., 1., 1., 1., 1., 1., 1., 44.]))
# copy_src with only src field; the out field should use anonymous repr
g = generate_graph()
g.register_message_func(fn.copy_src(src='h'), batchable=True)
g.register_reduce_func(reducer_out, batchable=True)
g.register_message_func(fn.copy_src(src='h'))
g.register_reduce_func(reducer_out)
g.update_all()
assert th.allclose(g.get_n_repr()['h'],
th.tensor([10., 1., 1., 1., 1., 1., 1., 1., 1., 44.]))
# copy_src with no src field; should use anonymous repr
g = generate_graph1()
g.register_message_func(fn.copy_src(out='m'), batchable=True)
g.register_reduce_func(reducer_both, batchable=True)
g.register_message_func(fn.copy_src(out='m'))
g.register_reduce_func(reducer_both)
g.update_all()
assert th.allclose(g.get_n_repr()['h'],
th.tensor([10., 1., 1., 1., 1., 1., 1., 1., 1., 44.]))
# copy src with no fields;
g = generate_graph1()
g.register_message_func(fn.copy_src(), batchable=True)
g.register_reduce_func(reducer_out, batchable=True)
g.register_message_func(fn.copy_src())
g.register_reduce_func(reducer_out)
g.update_all()
assert th.allclose(g.get_n_repr()['h'],
th.tensor([10., 1., 1., 1., 1., 1., 1., 1., 1., 44.]))
......@@ -84,32 +84,32 @@ def test_copy_src():
def test_copy_edge():
# copy_edge with both fields
g = generate_graph()
g.register_message_func(fn.copy_edge(edge='h', out='m'), batchable=True)
g.register_reduce_func(reducer_both, batchable=True)
g.register_message_func(fn.copy_edge(edge='h', out='m'))
g.register_reduce_func(reducer_both)
g.update_all()
assert th.allclose(g.get_n_repr()['h'],
th.tensor([10., 1., 1., 1., 1., 1., 1., 1., 1., 44.]))
# copy_edge with only edge field; the out field should use anonymous repr
g = generate_graph()
g.register_message_func(fn.copy_edge(edge='h'), batchable=True)
g.register_reduce_func(reducer_out, batchable=True)
g.register_message_func(fn.copy_edge(edge='h'))
g.register_reduce_func(reducer_out)
g.update_all()
assert th.allclose(g.get_n_repr()['h'],
th.tensor([10., 1., 1., 1., 1., 1., 1., 1., 1., 44.]))
# copy_edge with no edge field; should use anonymous repr
g = generate_graph1()
g.register_message_func(fn.copy_edge(out='m'), batchable=True)
g.register_reduce_func(reducer_both, batchable=True)
g.register_message_func(fn.copy_edge(out='m'))
g.register_reduce_func(reducer_both)
g.update_all()
assert th.allclose(g.get_n_repr()['h'],
th.tensor([10., 1., 1., 1., 1., 1., 1., 1., 1., 44.]))
# copy edge with no fields;
g = generate_graph1()
g.register_message_func(fn.copy_edge(), batchable=True)
g.register_reduce_func(reducer_out, batchable=True)
g.register_message_func(fn.copy_edge())
g.register_reduce_func(reducer_out)
g.update_all()
assert th.allclose(g.get_n_repr()['h'],
th.tensor([10., 1., 1., 1., 1., 1., 1., 1., 1., 44.]))
......@@ -117,36 +117,36 @@ def test_copy_edge():
def test_src_mul_edge():
# src_mul_edge with all fields
g = generate_graph()
g.register_message_func(fn.src_mul_edge(src='h', edge='h', out='m'), batchable=True)
g.register_reduce_func(reducer_both, batchable=True)
g.register_message_func(fn.src_mul_edge(src='h', edge='h', out='m'))
g.register_reduce_func(reducer_both)
g.update_all()
assert th.allclose(g.get_n_repr()['h'],
th.tensor([100., 1., 1., 1., 1., 1., 1., 1., 1., 284.]))
g = generate_graph()
g.register_message_func(fn.src_mul_edge(src='h', edge='h'), batchable=True)
g.register_reduce_func(reducer_out, batchable=True)
g.register_message_func(fn.src_mul_edge(src='h', edge='h'))
g.register_reduce_func(reducer_out)
g.update_all()
assert th.allclose(g.get_n_repr()['h'],
th.tensor([100., 1., 1., 1., 1., 1., 1., 1., 1., 284.]))
g = generate_graph1()
g.register_message_func(fn.src_mul_edge(out='m'), batchable=True)
g.register_reduce_func(reducer_both, batchable=True)
g.register_message_func(fn.src_mul_edge(out='m'))
g.register_reduce_func(reducer_both)
g.update_all()
assert th.allclose(g.get_n_repr()['h'],
th.tensor([100., 1., 1., 1., 1., 1., 1., 1., 1., 284.]))
g = generate_graph1()
g.register_message_func(fn.src_mul_edge(), batchable=True)
g.register_reduce_func(reducer_out, batchable=True)
g.register_message_func(fn.src_mul_edge())
g.register_reduce_func(reducer_out)
g.update_all()
assert th.allclose(g.get_n_repr()['h'],
th.tensor([100., 1., 1., 1., 1., 1., 1., 1., 1., 284.]))
g = generate_graph1()
g.register_message_func(fn.src_mul_edge(), batchable=True)
g.register_reduce_func(reducer_none, batchable=True)
g.register_message_func(fn.src_mul_edge())
g.register_reduce_func(reducer_none)
g.update_all()
assert th.allclose(g.get_n_repr(),
th.tensor([100., 1., 1., 1., 1., 1., 1., 1., 1., 284.]))
......
import networkx as nx
import dgl
import torch
import torch as th
import numpy as np
def tree1():
......@@ -13,17 +13,13 @@ def tree1():
Edges are from leaves to root.
"""
g = dgl.DGLGraph()
g.add_node(0)
g.add_node(1)
g.add_node(2)
g.add_node(3)
g.add_node(4)
g.add_nodes(5)
g.add_edge(3, 1)
g.add_edge(4, 1)
g.add_edge(1, 0)
g.add_edge(2, 0)
g.set_n_repr(torch.Tensor([0, 1, 2, 3, 4]))
g.set_e_repr(torch.randn(4, 10))
g.set_n_repr(th.Tensor([0, 1, 2, 3, 4]))
g.set_e_repr(th.randn(4, 10))
return g
def tree2():
......@@ -36,17 +32,13 @@ def tree2():
Edges are from leaves to root.
"""
g = dgl.DGLGraph()
g.add_node(0)
g.add_node(1)
g.add_node(2)
g.add_node(3)
g.add_node(4)
g.add_nodes(5)
g.add_edge(2, 4)
g.add_edge(0, 4)
g.add_edge(4, 1)
g.add_edge(3, 1)
g.set_n_repr(torch.Tensor([0, 1, 2, 3, 4]))
g.set_e_repr(torch.randn(4, 10))
g.set_n_repr(th.Tensor([0, 1, 2, 3, 4]))
g.set_e_repr(th.randn(4, 10))
return g
def test_batch_unbatch():
......@@ -58,21 +50,44 @@ def test_batch_unbatch():
e2 = t2.get_e_repr()
bg = dgl.batch([t1, t2])
dgl.unbatch(bg)
assert(n1.equal(t1.get_n_repr()))
assert(n2.equal(t2.get_n_repr()))
assert(e1.equal(t1.get_e_repr()))
assert(e2.equal(t2.get_e_repr()))
assert bg.number_of_nodes() == 10
assert bg.number_of_edges() == 8
assert bg.batch_size == 2
assert bg.batch_num_nodes == [5, 5]
assert bg.batch_num_edges == [4, 4]
tt1, tt2 = dgl.unbatch(bg)
assert th.allclose(t1.get_n_repr(), tt1.get_n_repr())
assert th.allclose(t1.get_e_repr(), tt1.get_e_repr())
assert th.allclose(t2.get_n_repr(), tt2.get_n_repr())
assert th.allclose(t2.get_e_repr(), tt2.get_e_repr())
def test_batch_unbatch1():
t1 = tree1()
t2 = tree2()
b1 = dgl.batch([t1, t2])
b2 = dgl.batch([t2, b1])
assert b2.number_of_nodes() == 15
assert b2.number_of_edges() == 12
assert b2.batch_size == 3
assert b2.batch_num_nodes == [5, 5, 5]
assert b2.batch_num_edges == [4, 4, 4]
s1, s2, s3 = dgl.unbatch(b2)
assert th.allclose(t2.get_n_repr(), s1.get_n_repr())
assert th.allclose(t2.get_e_repr(), s1.get_e_repr())
assert th.allclose(t1.get_n_repr(), s2.get_n_repr())
assert th.allclose(t1.get_e_repr(), s2.get_e_repr())
assert th.allclose(t2.get_n_repr(), s3.get_n_repr())
assert th.allclose(t2.get_e_repr(), s3.get_e_repr())
def test_batch_sendrecv():
t1 = tree1()
t2 = tree2()
bg = dgl.batch([t1, t2])
bg.register_message_func(lambda src, edge: src, batchable=True)
bg.register_reduce_func(lambda node, msgs: torch.sum(msgs, 1), batchable=True)
bg.register_message_func(lambda src, edge: src)
bg.register_reduce_func(lambda node, msgs: th.sum(msgs, 1))
e1 = [(3, 1), (4, 1)]
e2 = [(2, 4), (0, 4)]
......@@ -94,8 +109,8 @@ def test_batch_propagate():
t2 = tree2()
bg = dgl.batch([t1, t2])
bg.register_message_func(lambda src, edge: src, batchable=True)
bg.register_reduce_func(lambda node, msgs: torch.sum(msgs, 1), batchable=True)
bg.register_message_func(lambda src, edge: src)
bg.register_reduce_func(lambda node, msgs: th.sum(msgs, 1))
# get leaves.
order = []
......@@ -129,20 +144,21 @@ def test_batched_edge_ordering():
g1.add_nodes_from([0,1,2, 3, 4, 5])
g1.add_edges_from([(4, 5), (4, 3), (2, 3), (2, 1), (0, 1)])
g1.edge_list
e1 = torch.randn(5, 10)
e1 = th.randn(5, 10)
g1.set_e_repr(e1)
g2 = dgl.DGLGraph()
g2.add_nodes_from([0, 1, 2, 3, 4, 5])
g2.add_edges_from([(0, 1), (1, 2), (2, 3), (5, 4), (4, 3), (5, 0)])
e2 = torch.randn(6, 10)
e2 = th.randn(6, 10)
g2.set_e_repr(e2)
g = dgl.batch([g1, g2])
r1 = g.get_e_repr()[g.get_edge_id(4, 5)]
r2 = g1.get_e_repr()[g1.get_edge_id(4, 5)]
assert torch.equal(r1, r2)
assert th.equal(r1, r2)
if __name__ == '__main__':
test_batch_unbatch()
test_batched_edge_ordering()
test_batch_sendrecv()
test_batch_propagate()
test_batch_unbatch1()
#test_batched_edge_ordering()
#test_batch_sendrecv()
#test_batch_propagate()
......@@ -38,23 +38,23 @@ def test_update_all():
g = generate_graph()
# update all
v1 = g.get_n_repr()[fld]
g.update_all(fn.copy_src(src=fld), fn.sum(out=fld), apply_func, batchable=True)
g.update_all(fn.copy_src(src=fld), fn.sum(out=fld), apply_func)
v2 = g.get_n_repr()[fld]
g.set_n_repr({fld : v1})
g.update_all(message_func, reduce_func, apply_func, batchable=True)
g.update_all(message_func, reduce_func, apply_func)
v3 = g.get_n_repr()[fld]
assert th.allclose(v2, v3)
# update all with edge weights
v1 = g.get_n_repr()[fld]
g.update_all(fn.src_mul_edge(src=fld, edge='e1'),
fn.sum(out=fld), apply_func, batchable=True)
fn.sum(out=fld), apply_func)
v2 = g.get_n_repr()[fld]
g.set_n_repr({fld : v1})
g.update_all(fn.src_mul_edge(src=fld, edge='e2'),
fn.sum(out=fld), apply_func, batchable=True)
fn.sum(out=fld), apply_func)
v3 = g.get_n_repr()[fld]
g.set_n_repr({fld : v1})
g.update_all(message_func_edge, reduce_func, apply_func, batchable=True)
g.update_all(message_func_edge, reduce_func, apply_func)
v4 = g.get_n_repr()[fld]
assert th.allclose(v2, v3)
assert th.allclose(v3, v4)
......@@ -85,25 +85,25 @@ def test_send_and_recv():
# send and recv
v1 = g.get_n_repr()[fld]
g.send_and_recv(u, v, fn.copy_src(src=fld),
fn.sum(out=fld), apply_func, batchable=True)
fn.sum(out=fld), apply_func)
v2 = g.get_n_repr()[fld]
g.set_n_repr({fld : v1})
g.send_and_recv(u, v, message_func,
reduce_func, apply_func, batchable=True)
reduce_func, apply_func)
v3 = g.get_n_repr()[fld]
assert th.allclose(v2, v3)
# send and recv with edge weights
v1 = g.get_n_repr()[fld]
g.send_and_recv(u, v, fn.src_mul_edge(src=fld, edge='e1'),
fn.sum(out=fld), apply_func, batchable=True)
fn.sum(out=fld), apply_func)
v2 = g.get_n_repr()[fld]
g.set_n_repr({fld : v1})
g.send_and_recv(u, v, fn.src_mul_edge(src=fld, edge='e2'),
fn.sum(out=fld), apply_func, batchable=True)
fn.sum(out=fld), apply_func)
v3 = g.get_n_repr()[fld]
g.set_n_repr({fld : v1})
g.send_and_recv(u, v, message_func_edge,
reduce_func, apply_func, batchable=True)
reduce_func, apply_func)
v4 = g.get_n_repr()[fld]
assert th.allclose(v2, v3)
assert th.allclose(v3, v4)
......@@ -127,18 +127,18 @@ def test_update_all_multi_fn():
# update all, mix of builtin and UDF
g.update_all([fn.copy_src(src=fld, out='m1'), message_func],
[fn.sum(msgs='m1', out='v1'), reduce_func],
None, batchable=True)
None)
v1 = g.get_n_repr()['v1']
v2 = g.get_n_repr()['v2']
assert th.allclose(v1, v2)
# run builtin with single message and reduce
g.update_all(fn.copy_src(src=fld), fn.sum(out='v1'), None, batchable=True)
g.update_all(fn.copy_src(src=fld), fn.sum(out='v1'), None)
v1 = g.get_n_repr()['v1']
assert th.allclose(v1, v2)
# 1 message, 2 reduces, using anonymous repr
g.update_all(fn.copy_src(src=fld), [fn.sum(out='v2'), fn.sum(out='v3')], None, batchable=True)
g.update_all(fn.copy_src(src=fld), [fn.sum(out='v2'), fn.sum(out='v3')], None)
v2 = g.get_n_repr()['v2']
v3 = g.get_n_repr()['v3']
assert th.allclose(v1, v2)
......@@ -147,7 +147,7 @@ def test_update_all_multi_fn():
# update all with edge weights, 2 message, 3 reduces
g.update_all([fn.src_mul_edge(src=fld, edge='e1', out='m1'), fn.src_mul_edge(src=fld, edge='e2', out='m2')],
[fn.sum(msgs='m1', out='v1'), fn.sum(msgs='m2', out='v2'), fn.sum(msgs='m1', out='v3')],
None, batchable=True)
None)
v1 = g.get_n_repr()['v1']
v2 = g.get_n_repr()['v2']
v3 = g.get_n_repr()['v3']
......@@ -155,7 +155,7 @@ def test_update_all_multi_fn():
assert th.allclose(v1, v3)
# run UDF with single message and reduce
g.update_all(message_func_edge, reduce_func, None, batchable=True)
g.update_all(message_func_edge, reduce_func, None)
v2 = g.get_n_repr()['v2']
assert th.allclose(v1, v2)
......@@ -179,19 +179,19 @@ def test_send_and_recv_multi_fn():
g.send_and_recv(u, v,
[fn.copy_src(src=fld, out='m1'), message_func],
[fn.sum(msgs='m1', out='v1'), reduce_func],
None, batchable=True)
None)
v1 = g.get_n_repr()['v1']
v2 = g.get_n_repr()['v2']
assert th.allclose(v1, v2)
# run builtin with single message and reduce
g.send_and_recv(u, v, fn.copy_src(src=fld), fn.sum(out='v1'),
None, batchable=True)
None)
v1 = g.get_n_repr()['v1']
assert th.allclose(v1, v2)
# 1 message, 2 reduces, using anonymous repr
g.send_and_recv(u, v, fn.copy_src(src=fld), [fn.sum(out='v2'), fn.sum(out='v3')], None, batchable=True)
g.send_and_recv(u, v, fn.copy_src(src=fld), [fn.sum(out='v2'), fn.sum(out='v3')], None)
v2 = g.get_n_repr()['v2']
v3 = g.get_n_repr()['v3']
assert th.allclose(v1, v2)
......@@ -201,7 +201,7 @@ def test_send_and_recv_multi_fn():
g.send_and_recv(u, v,
[fn.src_mul_edge(src=fld, edge='e1', out='m1'), fn.src_mul_edge(src=fld, edge='e2', out='m2')],
[fn.sum(msgs='m1', out='v1'), fn.sum(msgs='m2', out='v2'), fn.sum(msgs='m1', out='v3')],
None, batchable=True)
None)
v1 = g.get_n_repr()['v1']
v2 = g.get_n_repr()['v2']
v3 = g.get_n_repr()['v3']
......@@ -210,7 +210,7 @@ def test_send_and_recv_multi_fn():
# run UDF with single message and reduce
g.send_and_recv(u, v, message_func_edge,
reduce_func, None, batchable=True)
reduce_func, None)
v2 = g.get_n_repr()['v2']
assert th.allclose(v1, v2)
......
from dgl import DGLGraph
from dgl.graph import __REPR__
def message_func(hu, e_uv):
    """Message UDF: combine the source node repr with the edge repr."""
    combined = hu + e_uv
    return combined
def reduce_func(h, msgs):
    """Reduce UDF: add the sum of all incoming messages to the node repr."""
    incoming = sum(msgs)
    return h + incoming
def generate_graph():
    """Build a 10-node flow graph on the anonymous repr.

    Node 0 fans out to nodes 1..8, which all feed into node 9; one extra
    edge from 9 back to 0 closes the loop.  Node i carries repr i + 1,
    the fan-out/fan-in edges carry repr 1, and the back edge has no repr.
    """
    g = DGLGraph()
    for nid in range(10):
        g.add_node(nid, __REPR__=nid + 1)
    # 0 is the source, 9 is the sink
    for mid in range(1, 9):
        g.add_edge(0, mid, __REPR__=1)
        g.add_edge(mid, 9, __REPR__=1)
    # back flow from the sink to the source
    g.add_edge(9, 0)
    return g
def check(g, h):
    """Assert that the anonymous reprs of nodes 0..9 equal the list h."""
    actual = [str(g.nodes[i][__REPR__]) for i in range(10)]
    expected = [str(x) for x in h]
    assert actual == expected, "nh=[%s], h=[%s]" % (' '.join(actual), ' '.join(expected))
def test_sendrecv():
    """Single- and multi-message send/recv on the anonymous-repr graph."""
    g = generate_graph()
    check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
    g.register_message_func(message_func)
    g.register_reduce_func(reduce_func)
    # one message over edge 0 -> 1: node 1 becomes 2 + (src 1 + edge 1) = 4
    g.send(0, 1)
    g.recv(1)
    check(g, [1, 4, 3, 4, 5, 6, 7, 8, 9, 10])
    # two pending messages into node 9 are reduced by one recv:
    # 10 + (6 + 1) + (7 + 1) = 25
    g.send(5, 9)
    g.send(6, 9)
    g.recv(9)
    check(g, [1, 4, 3, 4, 5, 6, 7, 8, 9, 25])
def message_func_hybrid(src, edge):
    """Message UDF mixing dict-style node access with a raw edge repr."""
    base = src[__REPR__]
    return base + edge
def reduce_func_hybrid(node, msgs):
    """Reduce UDF: add the summed messages onto the node's anonymous repr."""
    total = sum(msgs)
    return node[__REPR__] + total
def test_hybridrepr():
    """Nodes carry both the anonymous repr and a named 'id' field.

    The hybrid UDFs read only the anonymous repr, so the expected values
    match test_sendrecv exactly.
    """
    g = generate_graph()
    # attach a second, named field alongside the anonymous repr
    for i in range(10):
        g.nodes[i]['id'] = -i
    g.register_message_func(message_func_hybrid)
    g.register_reduce_func(reduce_func_hybrid)
    g.send(0, 1)
    g.recv(1)
    check(g, [1, 4, 3, 4, 5, 6, 7, 8, 9, 10])
    g.send(5, 9)
    g.send(6, 9)
    g.recv(9)
    check(g, [1, 4, 3, 4, 5, 6, 7, 8, 9, 25])
if __name__ == '__main__':
    # Execute the whole suite when this module is run as a script.
    for case in (test_sendrecv, test_hybridrepr):
        case()
from dgl.graph import DGLGraph
def message_func(src, edge):
    """Message UDF: forward the source node's 'h' field as a raw message."""
    payload = src['h']
    return payload
def reduce_func(node, msgs):
    """Reduce UDF: sum the raw messages into field 'm'."""
    total = sum(msgs)
    return {'m': total}
def apply_func(node):
    """Apply UDF: accumulate the reduced field 'm' into 'h'."""
    updated = node['h'] + node['m']
    return {'h': updated}
def message_dict_func(src, edge):
    """Dict-style message UDF: wrap the source 'h' field under key 'm'."""
    value = src['h']
    return {'m': value}
def reduce_dict_func(node, msgs):
    """Dict-style reduce UDF: sum the 'm' entries of all incoming messages."""
    total = sum(msg['m'] for msg in msgs)
    return {'m': total}
def apply_dict_func(node):
    """Apply UDF (dict flavor): add the reduced 'm' onto 'h'."""
    new_h = node['h'] + node['m']
    return {'h': new_h}
def generate_graph():
    """10-node graph: 0 fans out to 1..8, all of which feed 9; 9 loops to 0.

    Node i starts with h = i + 1; edges carry no data.
    """
    g = DGLGraph()
    for nid in range(10):
        g.add_node(nid, h=nid + 1)
    # 0 is the source, 9 is the sink
    for mid in range(1, 9):
        g.add_edge(0, mid)
        g.add_edge(mid, 9)
    # back flow from the sink to the source
    g.add_edge(9, 0)
    return g
def check(g, h):
    """Assert that field 'h' of nodes 0..9 matches the expected list h."""
    actual = [str(g.nodes[i]['h']) for i in range(10)]
    expected = [str(x) for x in h]
    assert actual == expected, "nh=[%s], h=[%s]" % (' '.join(actual), ' '.join(expected))
def register1(g):
    """Register the plain (non-dict) message/reduce/apply UDFs on g."""
    for setter, udf in ((g.register_message_func, message_func),
                        (g.register_reduce_func, reduce_func),
                        (g.register_apply_node_func, apply_func)):
        setter(udf)
def register2(g):
    """Register the dict-flavored message/reduce/apply UDFs on g."""
    for setter, udf in ((g.register_message_func, message_dict_func),
                        (g.register_reduce_func, reduce_dict_func),
                        (g.register_apply_node_func, apply_dict_func)):
        setter(udf)
def _test_sendrecv(g):
    """Shared scenario: single-edge send/recv, then a fan-in into node 9."""
    check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
    # one message over 0 -> 1: node 1 becomes 2 + 1 = 3
    g.send(0, 1)
    g.recv(1)
    check(g, [1, 3, 3, 4, 5, 6, 7, 8, 9, 10])
    # two pending messages into node 9: 10 + 6 + 7 = 23
    g.send(5, 9)
    g.send(6, 9)
    g.recv(9)
    check(g, [1, 3, 3, 4, 5, 6, 7, 8, 9, 23])
def _test_multi_sendrecv(g):
    """Shared scenario: vectorized send/recv over lists of nodes."""
    check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
    # one-many: broadcast from node 0 to nodes 1..3
    g.send(0, [1, 2, 3])
    g.recv([1, 2, 3])
    check(g, [1, 3, 4, 5, 5, 6, 7, 8, 9, 10])
    # many-one: fan-in from nodes 6..8 into node 9
    g.send([6, 7, 8], 9)
    g.recv(9)
    check(g, [1, 3, 4, 5, 5, 6, 7, 8, 9, 34])
    # many-many: pairwise sends, including two messages into node 9
    g.send([0, 0, 4, 5], [4, 5, 9, 9])
    g.recv([4, 5, 9])
    check(g, [1, 3, 4, 5, 6, 7, 7, 8, 9, 45])
def _test_update_routines(g):
    """Shared scenario for send_and_recv / pull / push / update_all."""
    check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
    g.send_and_recv(0, 1)
    check(g, [1, 3, 3, 4, 5, 6, 7, 8, 9, 10])
    # pull gathers from all in-neighbors of node 9
    g.pull(9)
    check(g, [1, 3, 3, 4, 5, 6, 7, 8, 9, 55])
    # push sends from node 0 along all its out-edges
    g.push(0)
    check(g, [1, 4, 4, 5, 6, 7, 8, 9, 10, 55])
    # full propagation over every edge
    g.update_all()
    check(g, [56, 5, 5, 6, 7, 8, 9, 10, 11, 108])
def test_sendrecv():
    """Run the send/recv scenario under both UDF registration styles."""
    for register in (register1, register2):
        g = generate_graph()
        register(g)
        _test_sendrecv(g)
def test_multi_sendrecv():
    """Run the vectorized send/recv scenario under both registration styles."""
    for register in (register1, register2):
        g = generate_graph()
        register(g)
        _test_multi_sendrecv(g)
def test_update_routines():
    """Run the update-routine scenario under both registration styles."""
    for register in (register1, register2):
        g = generate_graph()
        register(g)
        _test_update_routines(g)
if __name__ == '__main__':
    # Execute the whole suite when this module is run as a script.
    for case in (test_sendrecv, test_multi_sendrecv, test_update_routines):
        case()
from dgl import DGLGraph
from dgl.graph import __REPR__
def message_func(hu, e_uv):
    """Message UDF: pass the source repr through, ignoring the edge repr."""
    _ = e_uv  # the edge repr is deliberately unused
    return hu
def message_not_called(hu, e_uv):
    """Sentinel message UDF: the test fails if the framework ever calls it."""
    assert False
    return hu
def reduce_not_called(h, msgs):
    """Sentinel reduce UDF: the test fails if the framework ever calls it."""
    assert False
    return 0
def reduce_func(h, msgs):
    """Reduce UDF: add the sum of incoming messages to the node repr."""
    total = sum(msgs)
    return h + total
def check(g, h):
    """Assert that the anonymous reprs of nodes 0..9 equal the list h."""
    actual = [str(g.nodes[i][__REPR__]) for i in range(10)]
    expected = [str(x) for x in h]
    assert actual == expected, "nh=[%s], h=[%s]" % (' '.join(actual), ' '.join(expected))
def generate_graph():
    """10-node graph where 0 fans out to 1..8 and 1..8 feed into node 9.

    Node i carries anonymous repr i + 1.  Unlike the sibling modules,
    there is no 9 -> 0 back edge.
    """
    g = DGLGraph()
    for nid in range(10):
        g.add_node(nid, __REPR__=nid + 1)
    # 0 is the source, 9 is the sink
    for mid in range(1, 9):
        g.add_edge(0, mid)
        g.add_edge(mid, 9)
    return g
def test_no_msg_recv():
    """recv with no pending messages skips message/reduce but still applies."""
    g = generate_graph()
    check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
    g.register_message_func(message_not_called)
    g.register_reduce_func(reduce_not_called)
    g.register_apply_node_func(lambda h : h + 1)
    # nothing has been sent, so only the apply function may run
    for i in range(10):
        g.recv(i)
    check(g, [2, 3, 4, 5, 6, 7, 8, 9, 10, 11])
def test_double_recv():
    """A second recv on the same node must find no pending messages."""
    g = generate_graph()
    check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
    g.register_message_func(message_func)
    g.register_reduce_func(reduce_func)
    g.send(1, 9)
    g.send(2, 9)
    g.recv(9)
    # node 9: 10 + 2 + 3 = 15
    check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 15])
    # the message queue is drained; the sentinel reducer must not fire
    g.register_reduce_func(reduce_not_called)
    g.recv(9)
def test_pull_0deg():
    """pull on a node with zero in-degree runs only the node-apply function."""
    g = DGLGraph()
    g.add_node(0, h=2)
    g.add_node(1, h=1)
    g.add_edge(0, 1)
    # sentinels: node 0 has no in-edges, so neither UDF may fire
    def _message(src, edge):
        assert False
        return src
    def _reduce(node, msgs):
        assert False
        return node
    def _update(node):
        return {'h': node['h'] * 2}
    g.pull(0, _message, _reduce, _update)
    # only the apply function ran: 2 * 2 = 4
    assert g.nodes[0]['h'] == 4
if __name__ == '__main__':
    # Execute the whole suite when this module is run as a script.
    for case in (test_no_msg_recv, test_double_recv, test_pull_0deg):
        case()
import dgl
import dgl.function as fn
from dgl.graph import __REPR__
def generate_graph():
    """10-node flow graph with named 'h' fields on nodes and edges.

    Node i carries h = i + 1.  Node 0 fans out to 1..8 with edge h = 1,
    nodes 1..8 feed node 9 with edge h = i + 1, and node 9 loops back
    to node 0 with edge h = 10.
    """
    g = dgl.DGLGraph()
    for nid in range(10):
        g.add_node(nid, h=nid + 1)
    # 0 is the source, 9 is the sink
    for mid in range(1, 9):
        g.add_edge(0, mid, h=1)
        g.add_edge(mid, 9, h=mid + 1)
    # back flow from the sink to the source
    g.add_edge(9, 0, h=10)
    return g
def check(g, h, fld):
    """Assert that field fld of nodes 0..9 matches the expected list h."""
    actual = [str(g.nodes[i][fld]) for i in range(10)]
    expected = [str(x) for x in h]
    assert actual == expected, "nh=[%s], h=[%s]" % (' '.join(actual), ' '.join(expected))
def generate_graph1():
    """Graph with the same edge values as generate_graph but stored on the
    anonymous repr instead of a named 'h' field."""
    g = dgl.DGLGraph()
    for nid in range(10):
        g.add_node(nid, __REPR__=nid + 1)
    # 0 is the source, 9 is the sink
    for mid in range(1, 9):
        g.add_edge(0, mid, __REPR__=1)
        g.add_edge(mid, 9, __REPR__=mid + 1)
    # back flow from the sink to the source
    g.add_edge(9, 0, __REPR__=10)
    return g
def test_copy_src():
    """Exercise fn.copy_src over every explicit/implicit field combination."""
    before = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
    after = [10, 1, 1, 1, 1, 1, 1, 1, 1, 44]

    def run(g, mfunc, rfunc, in_fld, out_fld):
        # Register the non-batchable builtins and do one full propagation.
        g.register_message_func(mfunc, batchable=False)
        g.register_reduce_func(rfunc, batchable=False)
        check(g, before, in_fld)
        g.update_all()
        check(g, after, out_fld)

    # both fields explicit
    run(generate_graph(), fn.copy_src(src='h', out='m'),
        fn.sum(msgs='m', out='h'), 'h', 'h')
    # only the src field; the out field uses the anonymous repr
    run(generate_graph(), fn.copy_src(src='h'), fn.sum(out='h'), 'h', 'h')
    # only the out field; the source side reads the anonymous repr
    run(generate_graph1(), fn.copy_src(out='m'),
        fn.sum(msgs='m', out='h'), __REPR__, 'h')
    # neither field given
    run(generate_graph1(), fn.copy_src(), fn.sum(out='h'), __REPR__, 'h')
def test_copy_edge():
    """Exercise fn.copy_edge over every explicit/implicit field combination."""
    before = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
    after = [10, 1, 1, 1, 1, 1, 1, 1, 1, 44]

    def run(g, mfunc, rfunc, in_fld, out_fld):
        # Register the non-batchable builtins and do one full propagation.
        g.register_message_func(mfunc, batchable=False)
        g.register_reduce_func(rfunc, batchable=False)
        check(g, before, in_fld)
        g.update_all()
        check(g, after, out_fld)

    # both fields explicit
    run(generate_graph(), fn.copy_edge(edge='h', out='m'),
        fn.sum(msgs='m', out='h'), 'h', 'h')
    # only the edge field; the out field uses the anonymous repr
    run(generate_graph(), fn.copy_edge(edge='h'), fn.sum(out='h'), 'h', 'h')
    # only the out field; the edge side reads the anonymous repr
    run(generate_graph1(), fn.copy_edge(out='m'),
        fn.sum(msgs='m', out='h'), __REPR__, 'h')
    # neither field given
    run(generate_graph1(), fn.copy_edge(), fn.sum(out='h'), __REPR__, 'h')
def test_src_mul_edge():
    """Exercise fn.src_mul_edge over every explicit/implicit field combination."""
    before = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
    after = [100, 1, 1, 1, 1, 1, 1, 1, 1, 284]

    def run(g, mfunc, rfunc, in_fld, out_fld):
        # Register the non-batchable builtins and do one full propagation.
        g.register_message_func(mfunc, batchable=False)
        g.register_reduce_func(rfunc, batchable=False)
        check(g, before, in_fld)
        g.update_all()
        check(g, after, out_fld)

    # all three fields explicit
    run(generate_graph(), fn.src_mul_edge(src='h', edge='h', out='m'),
        fn.sum(msgs='m', out='h'), 'h', 'h')
    # src and edge only; the out field uses the anonymous repr
    run(generate_graph(), fn.src_mul_edge(src='h', edge='h'),
        fn.sum(out='h'), 'h', 'h')
    # out only; src and edge read the anonymous repr
    run(generate_graph1(), fn.src_mul_edge(out='m'),
        fn.sum(msgs='m', out='h'), __REPR__, 'h')
    # no fields on the message side
    run(generate_graph1(), fn.src_mul_edge(), fn.sum(out='h'), __REPR__, 'h')
    # no fields anywhere: the result lands on the anonymous repr too
    run(generate_graph1(), fn.src_mul_edge(), fn.sum(), __REPR__, __REPR__)
if __name__ == '__main__':
    # Execute the whole builtin-function suite when run as a script.
    for case in (test_copy_src, test_copy_edge, test_src_mul_edge):
        case()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment