Unverified commit 6105e441, authored by Minjie Wang, committed by GitHub

Many fixes and updates (#47)

* subgraph copy from

* WIP

* cached members

* Change all usage of id tensor to the new Index object; remove set device in DGLGraph;

* subgraph merge API tested

* add dict type reduced msg test
parent 3721822e
@@ -154,8 +154,6 @@ def main(args):
     # create GCN model
     g = DGLGraph(data.graph)
-    if cuda:
-        g.set_device(dgl.gpu(args.gpu))
     # create model
     model = GAT(g,
...
@@ -85,8 +85,6 @@ def main(args):
     # create GCN model
     g = DGLGraph(data.graph)
-    if cuda:
-        g.set_device(dgl.gpu(args.gpu))
     model = GCN(g,
                 in_feats,
                 args.n_hidden,
...
@@ -79,8 +79,6 @@ def main(args):
     # create GCN model
     g = DGLGraph(data.graph)
-    if cuda:
-        g.set_device(dgl.gpu(args.gpu))
     model = GCN(g,
                 in_feats,
                 args.n_hidden,
...
@@ -213,7 +213,6 @@ def main(args):
         count, label, node_list, mask, active, label1, label1_tensor = ground_truth[0]
         label, node_list, mask, label1_tensor = move2cuda((label, node_list, mask, label1_tensor))
         ground_truth[0] = (count, label, node_list, mask, active, label1, label1_tensor)
-        ground_truth[1][0].set_device(dgl.gpu(args.gpu))
         optimizer.zero_grad()
         # create new empty graphs
...
@@ -2,6 +2,7 @@ from __future__ import absolute_import
 import torch as th
 import scipy.sparse
+import dgl.context as context

 # Tensor types
 Tensor = th.Tensor
@@ -73,3 +74,9 @@ def to_context(x, ctx):
         return x.cpu()
     else:
         raise RuntimeError('Invalid context', ctx)
+
+def get_context(x):
+    if x.device.type == 'cpu':
+        return context.cpu()
+    else:
+        return context.gpu(x.device.index)
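The new get_context helper maps a tensor's device back to a dgl context; combined with the Context equality added below, it can be sanity-checked like this (a sketch, assuming get_context is re-exported by dgl.backend, as the F.get_context calls in frame.py suggest):

    import torch as th
    import dgl.backend as F
    import dgl.context as context

    assert F.get_context(th.zeros(3)) == context.cpu()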
@@ -87,7 +87,7 @@ def unbatch(graph_batch):
     num_graphs = len(graph_list)
     # split and set node attrs
     attrs = [{} for _ in range(num_graphs)]  # node attr dict for each graph
-    for key in graph_batch.get_n_attr_list():
+    for key in graph_batch.node_attr_schemes():
         vals = F.unpack(graph_batch.pop_n_repr(key), graph_batch.num_nodes)
         for attr, val in zip(attrs, vals):
             attr[key] = val
@@ -96,7 +96,7 @@ def unbatch(graph_batch):
     # split and set edge attrs
     attrs = [{} for _ in range(num_graphs)]  # edge attr dict for each graph
-    for key in graph_batch.get_e_attr_list():
+    for key in graph_batch.edge_attr_schemes():
         vals = F.unpack(graph_batch.pop_e_repr(key), graph_batch.num_edges)
         for attr, val in zip(attrs, vals):
             attr[key] = val
...
@@ -8,7 +8,10 @@ def message_from_src(src, edge):

 def reduce_sum(node, msgs):
     if isinstance(msgs, list):
-        return sum(msgs)
+        if isinstance(msgs[0], dict):
+            return {k : sum(m[k] for m in msgs) for k in msgs[0].keys()}
+        else:
+            return sum(msgs)
     else:
         return F.sum(msgs, 1)
...
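With this change, a message can be a dict of tensors and the builtin sum reduces each field independently. The new branch in plain torch (standalone illustration):

    import torch as th

    msgs = [{'m': th.ones(2)}, {'m': 2 * th.ones(2)}]
    out = {k: sum(m[k] for m in msgs) for k in msgs[0].keys()}
    assert th.equal(out['m'], th.tensor([3., 3.]))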
@@ -14,15 +14,21 @@ import dgl.utils as utils

 class CachedGraph:
     def __init__(self):
         self._graph = igraph.Graph(directed=True)
-        self._adjmat = None  # cached adjacency matrix
+        self._freeze = False

     def add_nodes(self, num_nodes):
+        if self._freeze:
+            raise RuntimeError('Freezed cached graph cannot be mutated.')
         self._graph.add_vertices(num_nodes)

     def add_edge(self, u, v):
+        if self._freeze:
+            raise RuntimeError('Freezed cached graph cannot be mutated.')
         self._graph.add_edge(u, v)

     def add_edges(self, u, v):
+        if self._freeze:
+            raise RuntimeError('Freezed cached graph cannot be mutated.')
         # The edge will be assigned ids equal to the order.
         uvs = list(utils.edge_iter(u, v))
         self._graph.add_edges(uvs)
@@ -30,7 +36,7 @@ class CachedGraph:
     def get_edge_id(self, u, v):
         uvs = list(utils.edge_iter(u, v))
         eids = self._graph.get_eids(uvs)
-        return utils.convert_to_id_tensor(eids)
+        return utils.toindex(eids)

     def in_edges(self, v):
         src = []
@@ -39,8 +45,8 @@ class CachedGraph:
             uu = self._graph.predecessors(vv)
             src += uu
             dst += [vv] * len(uu)
-        src = utils.convert_to_id_tensor(src)
-        dst = utils.convert_to_id_tensor(dst)
+        src = utils.toindex(src)
+        dst = utils.toindex(dst)
         return src, dst

     def out_edges(self, u):
@@ -50,44 +56,51 @@ class CachedGraph:
             vv = self._graph.successors(uu)
             src += [uu] * len(vv)
             dst += vv
-        src = utils.convert_to_id_tensor(src)
-        dst = utils.convert_to_id_tensor(dst)
+        src = utils.toindex(src)
+        dst = utils.toindex(dst)
         return src, dst

+    def in_degrees(self, v):
+        degs = self._graph.indegree(list(v))
+        return utils.toindex(degs)
+
+    def num_edges(self):
+        return self._graph.ecount()
+
+    @utils.cached_member
     def edges(self):
         elist = self._graph.get_edgelist()
         src = [u for u, _ in elist]
         dst = [v for _, v in elist]
-        src = utils.convert_to_id_tensor(src)
-        dst = utils.convert_to_id_tensor(dst)
+        src = utils.toindex(src)
+        dst = utils.toindex(dst)
         return src, dst

-    def in_degrees(self, v):
-        degs = self._graph.indegree(list(v))
-        return utils.convert_to_id_tensor(degs)
-
+    @utils.ctx_cached_member
     def adjmat(self, ctx):
         """Return a sparse adjacency matrix.

         The row dimension represents the dst nodes; the column dimension
         represents the src nodes.
         """
-        if self._adjmat is None:
-            elist = self._graph.get_edgelist()
-            src = [u for u, _ in elist]
-            dst = [v for _, v in elist]
-            src = F.unsqueeze(utils.convert_to_id_tensor(src), 0)
-            dst = F.unsqueeze(utils.convert_to_id_tensor(dst), 0)
-            idx = F.pack([dst, src])
-            n = self._graph.vcount()
-            dat = F.ones((len(elist),))
-            self._adjmat = F.sparse_tensor(idx, dat, [n, n])
-        # TODO(minjie): manually convert adjmat to context
-        self._adjmat = F.to_context(self._adjmat, ctx)
-        return self._adjmat
+        elist = self._graph.get_edgelist()
+        src = F.tensor([u for u, _ in elist], dtype=F.int64)
+        dst = F.tensor([v for _, v in elist], dtype=F.int64)
+        src = F.unsqueeze(src, 0)
+        dst = F.unsqueeze(dst, 0)
+        idx = F.pack([dst, src])
+        n = self._graph.vcount()
+        dat = F.ones((len(elist),))
+        mat = F.sparse_tensor(idx, dat, [n, n])
+        mat = F.to_context(mat, ctx)
+        return mat
+
+    def freeze(self):
+        self._freeze = True

 def create_cached_graph(dglgraph):
     cg = CachedGraph()
     cg.add_nodes(dglgraph.number_of_nodes())
     cg._graph.add_edges(dglgraph.edge_list)
+    cg.freeze()
     return cg
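Freezing is what makes the cached_member/ctx_cached_member decorators safe: once frozen, the topology can no longer change, so cached results never go stale. A hypothetical usage sketch (requires python-igraph):

    cg = CachedGraph()
    cg.add_nodes(3)
    cg.add_edges([0, 1], [1, 2])
    cg.freeze()
    src, dst = cg.edges()      # computed once; cached from now on
    try:
        cg.add_edge(2, 0)      # mutation after freeze raises
    except RuntimeError:
        pass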
@@ -8,6 +8,12 @@ class Context(object):
     def __str__(self):
         return '{}:{}'.format(self.device, self.device_id)

+    def __eq__(self, other):
+        return self.device == other.device and self.device_id == other.device_id
+
+    def __hash__(self):
+        return hash((self.device, self.device_id))
+
 def gpu(gpuid):
     return Context('gpu', gpuid)
...
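__eq__ and __hash__ make Context usable as a dictionary key, which is exactly what the per-context caches (Index, CtxCachedObject) rely on. A quick check:

    import dgl.context as context

    cache = {context.cpu(): 'cpu copy'}
    assert context.cpu() == context.cpu()
    assert context.cpu() in cache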
@@ -6,7 +6,7 @@ import numpy as np
 import dgl.backend as F
 from dgl.backend import Tensor
-from dgl.utils import LazyDict
+import dgl.utils as utils

 class Frame(MutableMapping):
     def __init__(self, data=None):
@@ -77,15 +77,24 @@ class Frame(MutableMapping):
         return self.num_columns

 class FrameRef(MutableMapping):
+    """Frame reference
+
+    Parameters
+    ----------
+    frame : dgl.frame.Frame
+        The underlying frame.
+    index : iterable of int
+        The rows that are referenced in the underlying frame.
+    """
     def __init__(self, frame=None, index=None):
         self._frame = frame if frame is not None else Frame()
         if index is None:
-            self._index = slice(0, self._frame.num_rows)
+            self._index_data = slice(0, self._frame.num_rows)
         else:
-            # check no duplicate index
+            # check no duplication
             assert len(index) == len(np.unique(index))
-            self._index = index
-        self._index_tensor = None
+            self._index_data = index
+        self._index = None

     @property
     def schemes(self):
@@ -97,10 +106,10 @@ class FrameRef(MutableMapping):
     @property
     def num_rows(self):
-        if isinstance(self._index, slice):
-            return self._index.stop
+        if isinstance(self._index_data, slice):
+            return self._index_data.stop
         else:
-            return len(self._index)
+            return len(self._index_data)

     def __contains__(self, key):
         return key in self._frame
@@ -114,15 +123,17 @@ class FrameRef(MutableMapping):
     def select_rows(self, query):
         rowids = self._getrowid(query)
         def _lazy_select(key):
-            return F.gather_row(self._frame[key], rowids)
-        return LazyDict(_lazy_select, keys=self.schemes)
+            idx = rowids.totensor(F.get_context(self._frame[key]))
+            return F.gather_row(self._frame[key], idx)
+        return utils.LazyDict(_lazy_select, keys=self.schemes)

     def get_column(self, name):
         col = self._frame[name]
         if self.is_span_whole_column():
             return col
         else:
-            return F.gather_row(col, self.index_tensor())
+            idx = self.index().totensor(F.get_context(col))
+            return F.gather_row(col, idx)

     def __setitem__(self, key, val):
         if isinstance(key, str):
@@ -134,22 +145,26 @@ class FrameRef(MutableMapping):
         shp = F.shape(col)
         if self.is_span_whole_column():
             if self.num_columns == 0:
-                self._index = slice(0, shp[0])
+                self._index_data = slice(0, shp[0])
                 self._clear_cache()
             assert shp[0] == self.num_rows
             self._frame[name] = col
         else:
+            colctx = F.get_context(col)
             if name in self._frame:
                 fcol = self._frame[name]
             else:
                 fcol = F.zeros((self._frame.num_rows,) + shp[1:])
-            newfcol = F.scatter_row(fcol, self.index_tensor(), col)
+            fcol = F.to_context(fcol, colctx)
+            idx = self.index().totensor(colctx)
+            newfcol = F.scatter_row(fcol, idx, col)
             self._frame[name] = newfcol

     def update_rows(self, query, other):
         rowids = self._getrowid(query)
         for key, col in other.items():
-            self._frame[key] = F.scatter_row(self._frame[key], rowids, col)
+            idx = rowids.totensor(F.get_context(self._frame[key]))
+            self._frame[key] = F.scatter_row(self._frame[key], idx, col)

     def __delitem__(self, key):
         if isinstance(key, str):
@@ -161,10 +176,10 @@ class FrameRef(MutableMapping):
     def delete_rows(self, query):
         query = F.asnumpy(query)
-        if isinstance(self._index, slice):
-            self._index = list(range(self._index.start, self._index.stop))
-        arr = np.array(self._index, dtype=np.int32)
-        self._index = list(np.delete(arr, query))
+        if isinstance(self._index_data, slice):
+            self._index_data = list(range(self._index_data.start, self._index_data.stop))
+        arr = np.array(self._index_data, dtype=np.int32)
+        self._index_data = list(np.delete(arr, query))
         self._clear_cache()

     def append(self, other):
@@ -174,16 +189,16 @@ class FrameRef(MutableMapping):
         self._frame.append(other)
         # update index
         if span_whole:
-            self._index = slice(0, self._frame.num_rows)
-        else:
-            new_idx = list(range(self._index.start, self._index.stop))
+            self._index_data = slice(0, self._frame.num_rows)
+        elif contiguous:
+            new_idx = list(range(self._index_data.start, self._index_data.stop))
             new_idx += list(range(old_nrows, self._frame.num_rows))
-            self._index = new_idx
+            self._index_data = new_idx
         self._clear_cache()

     def clear(self):
         self._frame.clear()
-        self._index = slice(0, 0)
+        self._index_data = slice(0, 0)
         self._clear_cache()

     def __iter__(self):
@@ -194,26 +209,73 @@ class FrameRef(MutableMapping):
     def is_contiguous(self):
         # NOTE: this check could have false negative
-        return isinstance(self._index, slice)
+        return isinstance(self._index_data, slice)

     def is_span_whole_column(self):
         return self.is_contiguous() and self.num_rows == self._frame.num_rows

     def _getrowid(self, query):
-        if isinstance(self._index, slice):
+        if self.is_contiguous():
             # shortcut for identical mapping
             return query
         else:
-            return F.gather_row(self.index_tensor(), query)
+            idxtensor = self.index().totensor()
+            return utils.toindex(F.gather_row(idxtensor, query.totensor()))

-    def index_tensor(self):
-        # TODO(minjie): context
-        if self._index_tensor is None:
+    def index(self):
+        if self._index is None:
             if self.is_contiguous():
-                self._index_tensor = F.arange(self._index.stop, dtype=F.int64)
+                self._index = utils.toindex(
+                    F.arange(self._index_data.stop, dtype=F.int64))
             else:
-                self._index_tensor = F.tensor(self._index, dtype=F.int64)
-        return self._index_tensor
+                self._index = utils.toindex(self._index_data)
+        return self._index

     def _clear_cache(self):
         self._index_tensor = None

+def merge_frames(frames, indices, max_index, reduce_func):
+    """Merge a list of frames.
+
+    The result frame contains `max_index` number of rows. For each frame in
+    the given list, its row is merged as follows:
+
+        merged[indices[i][row]] += frames[i][row]
+
+    Parameters
+    ----------
+    frames : iterator of dgl.frame.FrameRef
+        A list of frames to be merged.
+    indices : iterator of dgl.utils.Index
+        The indices of the frame rows.
+    reduce_func : str
+        The reduce function (only 'sum' is supported currently)
+
+    Returns
+    -------
+    merged : FrameRef
+        The merged frame.
+    """
+    assert reduce_func == 'sum'
+    assert len(frames) > 0
+    schemes = frames[0].schemes
+    # create an adj to merge
+    # row index is equal to the concatenation of all the indices.
+    row = sum([idx.tolist() for idx in indices], [])
+    col = list(range(len(row)))
+    n = max_index
+    m = len(row)
+    row = F.unsqueeze(F.tensor(row, dtype=F.int64), 0)
+    col = F.unsqueeze(F.tensor(col, dtype=F.int64), 0)
+    idx = F.pack([row, col])
+    dat = F.ones((m,))
+    adjmat = F.sparse_tensor(idx, dat, [n, m])
+    ctx_adjmat = utils.CtxCachedObject(lambda ctx: F.to_context(adjmat, ctx))
+    merged = {}
+    for key in schemes:
+        # the rhs of the spmv is the concatenation of all the frame columns
+        feats = F.pack([fr[key] for fr in frames])
+        merged_feats = F.spmm(ctx_adjmat.get(F.get_context(feats)), feats)
+        merged[key] = merged_feats
+    merged = FrameRef(Frame(merged))
+    return merged
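merge_frames reduces to one sparse matmul per feature: an n x m one-hot matrix (n = parent rows, m = total subgraph rows) scatter-adds the stacked subgraph rows into their parent rows. A standalone sketch of the same trick using current PyTorch names (th.sparse_coo_tensor / th.sparse.mm; the commit itself goes through the F.* backend wrappers):

    import torch as th

    feats = th.tensor([[1.], [2.], [3.]])   # stacked subgraph rows
    row = th.tensor([0, 2, 2])              # parent row of each stacked row
    col = th.arange(3)
    adj = th.sparse_coo_tensor(th.stack([row, col]), th.ones(3), (4, 3))
    merged = th.sparse.mm(adj, feats)
    assert merged[2].item() == 5.0          # rows sharing a parent id are summed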
@@ -12,7 +12,7 @@ from dgl.backend import Tensor
 import dgl.builtin as builtin
 from dgl.cached_graph import CachedGraph, create_cached_graph
 import dgl.context as context
-from dgl.frame import FrameRef
+from dgl.frame import FrameRef, merge_frames
 from dgl.nx_adapt import nx_init
 import dgl.scheduler as scheduler
 import dgl.utils as utils
@@ -62,12 +62,11 @@ class DGLGraph(DiGraph):
         self._reduce_func = None
         self._update_func = None
         self._edge_func = None
-        self._context = context.cpu()

-    def get_n_attr_list(self):
+    def node_attr_schemes(self):
         return self._node_frame.schemes

-    def get_e_attr_list(self):
+    def edge_attr_schemes(self):
         return self._edge_frame.schemes

     def set_n_repr(self, hu, u=ALL):
@@ -92,7 +91,7 @@ class DGLGraph(DiGraph):
         if is_all(u):
             num_nodes = self.number_of_nodes()
         else:
-            u = utils.convert_to_id_tensor(u, self.context)
+            u = utils.toindex(u)
             num_nodes = len(u)
         if isinstance(hu, dict):
             for key, val in hu.items():
@@ -108,10 +107,9 @@ class DGLGraph(DiGraph):
                 self._node_frame[__REPR__] = hu
         else:
             if isinstance(hu, dict):
-                for key, val in hu.items():
-                    self._node_frame[key] = F.scatter_row(self._node_frame[key], u, val)
+                self._node_frame[u] = hu
             else:
-                self._node_frame[__REPR__] = F.scatter_row(self._node_frame[__REPR__], u, hu)
+                self._node_frame[u] = {__REPR__ : hu}

     def get_n_repr(self, u=ALL):
         """Get node(s) representation.
@@ -127,9 +125,9 @@ class DGLGraph(DiGraph):
             else:
                 return dict(self._node_frame)
         else:
-            u = utils.convert_to_id_tensor(u, self.context)
+            u = utils.toindex(u)
             if len(self._node_frame) == 1 and __REPR__ in self._node_frame:
-                return self._node_frame[__REPR__][u]
+                return self._node_frame.select_rows(u)[__REPR__]
             else:
                 return self._node_frame.select_rows(u)
@@ -168,10 +166,10 @@ class DGLGraph(DiGraph):
         v_is_all = is_all(v)
         assert u_is_all == v_is_all
         if u_is_all:
-            num_edges = self.number_of_edges()
+            num_edges = self.cached_graph.num_edges()
         else:
-            u = utils.convert_to_id_tensor(u, self.context)
-            v = utils.convert_to_id_tensor(v, self.context)
+            u = utils.toindex(u)
+            v = utils.toindex(v)
             num_edges = max(len(u), len(v))
         if isinstance(h_uv, dict):
             for key, val in h_uv.items():
@@ -188,10 +186,9 @@ class DGLGraph(DiGraph):
         else:
             eid = self.cached_graph.get_edge_id(u, v)
             if isinstance(h_uv, dict):
-                for key, val in h_uv.items():
-                    self._edge_frame[key] = F.scatter_row(self._edge_frame[key], eid, val)
+                self._edge_frame[eid] = h_uv
             else:
-                self._edge_frame[__REPR__] = F.scatter_row(self._edge_frame[__REPR__], eid, h_uv)
+                self._edge_frame[eid] = {__REPR__ : h_uv}

     def set_e_repr_by_id(self, h_uv, eid=ALL):
         """Set edge(s) representation by edge id.
@@ -205,9 +202,9 @@ class DGLGraph(DiGraph):
         """
         # sanity check
         if is_all(eid):
-            num_edges = self.number_of_edges()
+            num_edges = self.cached_graph.num_edges()
         else:
-            eid = utils.convert_to_id_tensor(eid, self.context)
+            eid = utils.toindex(eid)
             num_edges = len(eid)
         if isinstance(h_uv, dict):
             for key, val in h_uv.items():
@@ -223,10 +220,9 @@ class DGLGraph(DiGraph):
                 self._edge_frame[__REPR__] = h_uv
         else:
             if isinstance(h_uv, dict):
-                for key, val in h_uv.items():
-                    self._edge_frame[key] = F.scatter_row(self._edge_frame[key], eid, val)
+                self._edge_frame[eid] = h_uv
             else:
-                self._edge_frame[__REPR__] = F.scatter_row(self._edge_frame[__REPR__], eid, h_uv)
+                self._edge_frame[eid] = {__REPR__ : h_uv}

     def get_e_repr(self, u=ALL, v=ALL):
         """Get node(s) representation.
@@ -247,11 +243,11 @@ class DGLGraph(DiGraph):
             else:
                 return dict(self._edge_frame)
         else:
-            u = utils.convert_to_id_tensor(u, self.context)
-            v = utils.convert_to_id_tensor(v, self.context)
+            u = utils.toindex(u)
+            v = utils.toindex(v)
             eid = self.cached_graph.get_edge_id(u, v)
             if len(self._edge_frame) == 1 and __REPR__ in self._edge_frame:
-                return self._edge_frame[__REPR__][eid]
+                return self._edge_frame.select_rows(eid)[__REPR__]
             else:
                 return self._edge_frame.select_rows(eid)
@@ -279,27 +275,12 @@ class DGLGraph(DiGraph):
             else:
                 return dict(self._edge_frame)
         else:
-            eid = utils.convert_to_id_tensor(eid, self.context)
+            eid = utils.toindex(eid)
             if len(self._edge_frame) == 1 and __REPR__ in self._edge_frame:
-                return self._edge_frame[__REPR__][eid]
+                return self._edge_frame.select_rows(eid)[__REPR__]
             else:
                 return self._edge_frame.select_rows(eid)

-    def set_device(self, ctx):
-        """Set device context for this graph.
-
-        Parameters
-        ----------
-        ctx : dgl.context.Context
-            The device context.
-        """
-        self._context = ctx
-
-    @property
-    def context(self):
-        """Get the device context of this graph."""
-        return self._context
-
     def register_message_func(self,
                               message_func,
                               batchable=False):
@@ -356,27 +337,6 @@ class DGLGraph(DiGraph):
         """
         self._update_func = (update_func, batchable)

-    def readout(self,
-                readout_func,
-                nodes=ALL,
-                edges=ALL):
-        """Trigger the readout function on the specified nodes/edges.
-
-        Parameters
-        ----------
-        readout_func : callable
-            Readout function.
-        nodes : str, node, container or tensor
-            The nodes to get reprs from.
-        edges : str, pair of nodes, pair of containers or pair of tensors
-            The edges to get reprs from.
-        """
-        nodes = self._nodes_or_all(nodes)
-        edges = self._edges_or_all(edges)
-        nstates = [self.nodes[n] for n in nodes]
-        estates = [self.edges[e] for e in edges]
-        return readout_func(nstates, estates)
-
     def sendto(self, u, v, message_func=None, batchable=False):
         """Trigger the message function on edge u->v
@@ -413,6 +373,9 @@ class DGLGraph(DiGraph):
         f_msg = _get_message_func(message_func)
         if is_all(u) and is_all(v):
             u, v = self.cached_graph.edges()
+        else:
+            u = utils.toindex(u)
+            v = utils.toindex(v)
         for uu, vv in utils.edge_iter(u, v):
             ret = f_msg(_get_repr(self.nodes[uu]),
                         _get_repr(self.edges[uu, vv]))
@@ -428,8 +391,8 @@ class DGLGraph(DiGraph):
             edge_reprs = self.get_e_repr()
             msgs = message_func(src_reprs, edge_reprs)
         else:
-            u = utils.convert_to_id_tensor(u)
-            v = utils.convert_to_id_tensor(v)
+            u = utils.toindex(u)
+            v = utils.toindex(v)
             u, v = utils.edge_broadcasting(u, v)
             eid = self.cached_graph.get_edge_id(u, v)
             self.msg_graph.add_edges(u, v)
@@ -475,6 +438,9 @@ class DGLGraph(DiGraph):
     def _nonbatch_update_edge(self, u, v, edge_func):
         if is_all(u) and is_all(v):
             u, v = self.cached_graph.edges()
+        else:
+            u = utils.toindex(u)
+            v = utils.toindex(v)
         for uu, vv in utils.edge_iter(u, v):
             ret = edge_func(_get_repr(self.nodes[uu]),
                             _get_repr(self.nodes[vv]),
@@ -491,8 +457,8 @@ class DGLGraph(DiGraph):
             new_edge_reprs = edge_func(src_reprs, dst_reprs, edge_reprs)
             self.set_e_repr(new_edge_reprs)
         else:
-            u = utils.convert_to_id_tensor(u)
-            v = utils.convert_to_id_tensor(v)
+            u = utils.toindex(u)
+            v = utils.toindex(v)
             u, v = utils.edge_broadcasting(u, v)
             eid = self.cached_graph.get_edge_id(u, v)
             # call the UDF
@@ -559,6 +525,8 @@ class DGLGraph(DiGraph):
             f_update = update_func
         if is_all(u):
             u = list(range(0, self.number_of_nodes()))
+        else:
+            u = utils.toindex(u)
         for i, uu in enumerate(utils.node_iter(u)):
             # reduce phase
             msgs_batch = [self.edges[vv, uu].pop(__MSG__)
@@ -586,14 +554,16 @@ class DGLGraph(DiGraph):
         new_ns = f_update(reordered_ns, all_reduced_msgs)
         if is_all(v):
             # First do reorder and then replace the whole column.
-            _, indices = F.sort(reordered_v)
-            # TODO(minjie): manually convert ids to context.
-            indices = F.to_context(indices, self.context)
+            _, indices = F.sort(reordered_v.totensor())
+            indices = utils.toindex(indices)
+            # TODO(minjie): following code should be included in Frame somehow.
             if isinstance(new_ns, dict):
                 for key, val in new_ns.items():
-                    self._node_frame[key] = F.gather_row(val, indices)
+                    idx = indices.totensor(F.get_context(val))
+                    self._node_frame[key] = F.gather_row(val, idx)
             else:
-                self._node_frame[__REPR__] = F.gather_row(new_ns, indices)
+                idx = indices.totensor(F.get_context(new_ns))
+                self._node_frame[__REPR__] = F.gather_row(new_ns, idx)
         else:
             # Use setter to do reorder.
             self.set_n_repr(new_ns, reordered_v)
@@ -605,9 +575,14 @@ class DGLGraph(DiGraph):
         if is_all(v):
             v = list(range(self.number_of_nodes()))
+        # freeze message graph
+        self.msg_graph.freeze()
         # sanity checks
-        v = utils.convert_to_id_tensor(v)
+        v = utils.toindex(v)
         f_reduce = _get_reduce_func(reduce_func)
         # degree bucketing
         degrees, v_buckets = scheduler.degree_bucketing(self.msg_graph, v)
         reduced_msgs = []
@@ -617,8 +592,6 @@ class DGLGraph(DiGraph):
             bkt_len = len(v_bkt)
             uu, vv = self.msg_graph.in_edges(v_bkt)
             in_msg_ids = self.msg_graph.get_edge_id(uu, vv)
-            # TODO(minjie): manually convert ids to context.
-            in_msg_ids = F.to_context(in_msg_ids, self.context)
             in_msgs = self._msg_frame.select_rows(in_msg_ids)
             # Reshape the column tensor to (B, Deg, ...).
             def _reshape_fn(msg):
@@ -641,10 +614,14 @@ class DGLGraph(DiGraph):
             self.clear_messages()

         # Read the node states in the degree-bucketing order.
-        reordered_v = F.pack(v_buckets)
+        reordered_v = utils.toindex(F.pack(
+            [v_bkt.totensor() for v_bkt in v_buckets]))
         # Pack all reduced msgs together
         if isinstance(reduced_msgs[0], dict):
-            all_reduced_msgs = {key : F.pack(val) for key, val in reduced_msgs.items()}
+            keys = reduced_msgs[0].keys()
+            all_reduced_msgs = {
+                key : F.pack([msg[key] for msg in reduced_msgs])
+                for key in keys}
         else:
             all_reduced_msgs = F.pack(reduced_msgs)
@@ -697,6 +674,9 @@ class DGLGraph(DiGraph):
                                  update_func):
         if is_all(u) and is_all(v):
             u, v = self.cached_graph.edges()
+        else:
+            u = utils.toindex(u)
+            v = utils.toindex(v)
         self._nonbatch_sendto(u, v, message_func)
         dst = set()
         for uu, vv in utils.edge_iter(u, v):
@@ -713,12 +693,14 @@ class DGLGraph(DiGraph):
             self.update_all(message_func, reduce_func, update_func, True)
         elif message_func == 'from_src' and reduce_func == 'sum':
             # TODO(minjie): check the validity of edges u->v
-            u = utils.convert_to_id_tensor(u)
-            v = utils.convert_to_id_tensor(v)
+            u = utils.toindex(u)
+            v = utils.toindex(v)
             # TODO(minjie): broadcasting is optional for many-one input.
             u, v = utils.edge_broadcasting(u, v)
             # relabel destination nodes.
             new2old, old2new = utils.build_relabel_map(v)
+            u = u.totensor()
+            v = v.totensor()
             # TODO(minjie): should not directly use []
             new_v = old2new[v]
             # create adj mat
@@ -726,8 +708,8 @@ class DGLGraph(DiGraph):
             dat = F.ones((len(u),))
             n = self.number_of_nodes()
             m = len(new2old)
-            # TODO(minjie): context
             adjmat = F.sparse_tensor(idx, dat, [m, n])
+            adjmat = F.to_context(adjmat, self.context)
             # TODO(minjie): use lazy dict for reduced_msgs
             reduced_msgs = {}
             for key in self._node_frame.schemes:
@@ -739,10 +721,10 @@ class DGLGraph(DiGraph):
             new_node_repr = update_func(node_repr, reduced_msgs)
             self.set_n_repr(new_node_repr, new2old)
         else:
-            u = utils.convert_to_id_tensor(u, self.context)
-            v = utils.convert_to_id_tensor(v, self.context)
+            u = utils.toindex(u)
+            v = utils.toindex(v)
             self._batch_sendto(u, v, message_func)
-            unique_v = F.unique(v)
+            unique_v = F.unique(v.totensor())
             self._batch_recv(unique_v, reduce_func, update_func)

     def update_to(self,
@@ -776,15 +758,17 @@ class DGLGraph(DiGraph):
         assert reduce_func is not None
         assert update_func is not None
         if batchable:
+            v = utils.toindex(v)
             uu, vv = self.cached_graph.in_edges(v)
-            self.update_by_edge(uu, vv, message_func,
-                    reduce_func, update_func, batchable)
+            self._batch_update_by_edge(uu, vv, message_func,
+                    reduce_func, update_func)
         else:
+            v = utils.toindex(v)
             for vv in utils.node_iter(v):
                 assert vv in self.nodes
                 uu = list(self.pred[vv])
-                self.sendto(uu, vv, message_func, batchable)
-                self.recv(vv, reduce_func, update_func, batchable)
+                self._nonbatch_sendto(uu, vv, message_func)
+                self._nonbatch_recv(vv, reduce_func, update_func)

     def update_from(self,
                     u,
@@ -817,15 +801,17 @@ class DGLGraph(DiGraph):
         assert reduce_func is not None
         assert update_func is not None
         if batchable:
+            u = utils.toindex(u)
             uu, vv = self.cached_graph.out_edges(u)
-            self.update_by_edge(uu, vv, message_func,
-                    reduce_func, update_func, batchable)
+            self._batch_update_by_edge(uu, vv, message_func,
+                    reduce_func, update_func)
         else:
+            u = utils.toindex(u)
             for uu in utils.node_iter(u):
                 assert uu in self.nodes
                 for v in self.succ[uu]:
-                    self.update_by_edge(uu, v,
-                            message_func, reduce_func, update_func, batchable)
+                    self._nonbatch_update_by_edge(uu, v,
+                            message_func, reduce_func, update_func)

     def update_all(self,
                    message_func=None,
@@ -857,10 +843,10 @@ class DGLGraph(DiGraph):
         if batchable:
             if message_func == 'from_src' and reduce_func == 'sum':
                 # TODO(minjie): use lazy dict for reduced_msgs
-                adjmat = self.cached_graph.adjmat(self.context)
                 reduced_msgs = {}
                 for key in self._node_frame.schemes:
                     col = self._node_frame[key]
+                    adjmat = self.cached_graph.adjmat(F.get_context(col))
                     reduced_msgs[key] = F.spmm(adjmat, col)
                 if len(reduced_msgs) == 1 and __REPR__ in reduced_msgs:
                     reduced_msgs = reduced_msgs[__REPR__]
@@ -930,23 +916,52 @@ class DGLGraph(DiGraph):
         Returns
         -------
-        G : DGLGraph
+        G : DGLSubGraph
             The subgraph.
         """
         return dgl.DGLSubGraph(self, nodes)

-    def copy_from(self, graph):
-        """Copy node/edge features from the given graph.
-
-        All old features will be removed.
+    def merge(self, subgraphs, reduce_func='sum'):
+        """Merge subgraph features back to this parent graph.

         Parameters
         ----------
-        graph : DGLGraph
-            The graph to copy from.
+        subgraphs : iterator of DGLSubGraph
+            The subgraphs to be merged.
+        reduce_func : str
+            The reduce function (only 'sum' is supported currently)
         """
-        # TODO
-        pass
+        # sanity check: all the subgraphs and the parent graph
+        # should have the same node/edge feature schemes.
+        # merge node features
+        to_merge = []
+        for sg in subgraphs:
+            if len(sg.node_attr_schemes()) == 0:
+                continue
+            if sg.node_attr_schemes() != self.node_attr_schemes():
+                raise RuntimeError('Subgraph and parent graph do not '
+                                   'have the same node attribute schemes.')
+            to_merge.append(sg)
+        self._node_frame = merge_frames(
+            [sg._node_frame for sg in to_merge],
+            [sg._parent_nid for sg in to_merge],
+            self._node_frame.num_rows,
+            reduce_func)
+
+        # merge edge features
+        to_merge.clear()
+        for sg in subgraphs:
+            if len(sg.edge_attr_schemes()) == 0:
+                continue
+            if sg.edge_attr_schemes() != self.edge_attr_schemes():
+                raise RuntimeError('Subgraph and parent graph do not '
+                                   'have the same edge attribute schemes.')
+            to_merge.append(sg)
+        self._edge_frame = merge_frames(
+            [sg._edge_frame for sg in to_merge],
+            [sg._parent_eid for sg in to_merge],
+            self._edge_frame.num_rows,
+            reduce_func)

     def draw(self):
         """Plot the graph using dot."""
@@ -996,14 +1011,10 @@ class DGLGraph(DiGraph):
         eid : tensor
             The tensor contains edge id(s).
         """
+        u = utils.toindex(u)
+        v = utils.toindex(v)
         return self.cached_graph.get_edge_id(u, v)

-    def _nodes_or_all(self, nodes):
-        return self.nodes() if nodes == ALL else nodes
-
-    def _edges_or_all(self, edges):
-        return self.edges() if edges == ALL else edges
-
     def _add_node_callback(self, node):
         #print('New node:', node)
         self._cached_graph = None
...
"""Schedule policies for graph computation.""" """Schedule policies for graph computation."""
from __future__ import absolute_import from __future__ import absolute_import
import dgl.backend as F
import numpy as np import numpy as np
import dgl.backend as F
import dgl.utils as utils
def degree_bucketing(cached_graph, v): def degree_bucketing(cached_graph, v):
degrees = F.asnumpy(cached_graph.in_degrees(v)) """Create degree bucketing scheduling policy.
Parameters
----------
cached_graph : dgl.cached_graph.CachedGraph
the graph
v : dgl.utils.Index
the nodes to gather messages
Returns
-------
unique_degrees : list of int
list of unique degrees
v_bkt : list of dgl.utils.Index
list of node id buckets; nodes belong to the same bucket have
the same degree
"""
degrees = F.asnumpy(cached_graph.in_degrees(v).totensor())
unique_degrees = list(np.unique(degrees)) unique_degrees = list(np.unique(degrees))
v_np = np.array(v.tolist())
v_bkt = [] v_bkt = []
for deg in unique_degrees: for deg in unique_degrees:
idx = np.where(degrees == deg) idx = np.where(degrees == deg)
v_bkt.append(v[idx]) v_bkt.append(utils.Index(v_np[idx]))
return unique_degrees, v_bkt return unique_degrees, v_bkt
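Degree bucketing groups destination nodes by in-degree so each bucket's messages can be reduced with one batched (B, Deg, ...) operation. A trace of the policy on toy data (plain numpy, mirroring the loop above):

    import numpy as np

    v = np.array([10, 11, 12, 13])
    degrees = np.array([1, 2, 1, 3])        # in-degree of each node in v
    for deg in np.unique(degrees):
        print(deg, v[np.where(degrees == deg)])
    # 1 [10 12]
    # 2 [11]
    # 3 [13]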
@@ -13,46 +13,30 @@ class DGLSubGraph(DGLGraph):
     def __init__(self,
                  parent,
                  nodes):
-        # create subgraph and relabel
-        nx_sg = nx.DiGraph.subgraph(parent, nodes)
-        # node id
-        # TODO(minjie): context
-        nid = F.tensor(nodes, dtype=F.int64)
-        # edge id
-        # TODO(minjie): slow, context
-        u, v = zip(*nx_sg.edges)
-        u = list(u)
-        v = list(v)
-        eid = parent.cached_graph.get_edge_id(u, v)
-        # relabel
+        super(DGLSubGraph, self).__init__()
+        # relabel nodes
         self._node_mapping = utils.build_relabel_dict(nodes)
-        nx_sg = nx.relabel.relabel_nodes(nx_sg, self._node_mapping)
+        self._parent_nid = utils.toindex(nodes)
+        eids = []
+        # create subgraph
+        for eid, (u, v) in enumerate(parent.edge_list):
+            if u in self._node_mapping and v in self._node_mapping:
+                self.add_edge(self._node_mapping[u],
+                              self._node_mapping[v])
+                eids.append(eid)
+        self._parent_eid = utils.toindex(eids)

-        # init
-        self._edge_list = []
-        nx_init(self,
-                self._add_node_callback,
-                self._add_edge_callback,
-                self._del_node_callback,
-                self._del_edge_callback,
-                nx_sg,
-                **parent.graph)
-        # cached graph and storage
-        self._cached_graph = None
-        if parent._node_frame.num_rows == 0:
-            self._node_frame = FrameRef()
-        else:
-            self._node_frame = FrameRef(Frame(parent._node_frame[nid]))
-        if parent._edge_frame.num_rows == 0:
-            self._edge_frame = FrameRef()
-        else:
-            self._edge_frame = FrameRef(Frame(parent._edge_frame[eid]))
-        # other class members
-        self._msg_graph = None
-        self._msg_frame = FrameRef()
-        self._message_func = parent._message_func
-        self._reduce_func = parent._reduce_func
-        self._update_func = parent._update_func
-        self._edge_func = parent._edge_func
-        self._context = parent._context
+    def copy_from(self, parent):
+        """Copy node/edge features from the parent graph.
+
+        All old features will be removed.
+
+        Parameters
+        ----------
+        parent : DGLGraph
+            The parent graph to copy from.
+        """
+        if parent._node_frame.num_rows != 0:
+            self._node_frame = FrameRef(Frame(parent._node_frame[self._parent_nid]))
+        if parent._edge_frame.num_rows != 0:
+            self._edge_frame = FrameRef(Frame(parent._edge_frame[self._parent_eid]))
@@ -2,6 +2,9 @@
 from __future__ import absolute_import

 from collections import Mapping
+from functools import wraps
+
+import numpy as np

 import dgl.backend as F
 from dgl.backend import Tensor, SparseTensor
@@ -11,18 +14,77 @@ def is_id_tensor(u):

 def is_id_container(u):
     """Return whether the input is a supported id container."""
-    return isinstance(u, list)
+    return (getattr(u, '__iter__', None) is not None
+            and getattr(u, '__len__', None) is not None)

+class Index(object):
+    """Index class that can be easily converted to list/tensor."""
+    def __init__(self, data):
+        self._list_data = None
+        self._tensor_data = None
+        self._ctx_data = dict()
+        self._dispatch(data)
+
+    def _dispatch(self, data):
+        if is_id_tensor(data):
+            self._tensor_data = data
+        elif is_id_container(data):
+            self._list_data = data
+        else:
+            try:
+                self._list_data = [int(data)]
+            except:
+                raise TypeError('Error index data: %s' % str(x))
+
+    def tolist(self):
+        if self._list_data is None:
+            self._list_data = list(F.asnumpy(self._tensor_data))
+        return self._list_data
+
+    def totensor(self, ctx=None):
+        if self._tensor_data is None:
+            self._tensor_data = F.tensor(self._list_data, dtype=F.int64)
+        if ctx is None:
+            return self._tensor_data
+        if ctx not in self._ctx_data:
+            self._ctx_data[ctx] = F.to_context(self._tensor_data, ctx)
+        return self._ctx_data[ctx]
+
+    def __iter__(self):
+        return iter(self.tolist())
+
+    def __len__(self):
+        if self._list_data is not None:
+            return len(self._list_data)
+        else:
+            return len(self._tensor_data)
+
+    def __getitem__(self, i):
+        return self.tolist()[i]
+
+def toindex(x):
+    return x if isinstance(x, Index) else Index(x)
+
 def node_iter(n):
-    """Return an iterator that loops over the given nodes."""
-    n = convert_to_id_container(n)
-    for nn in n:
-        yield nn
+    """Return an iterator that loops over the given nodes.
+
+    Parameters
+    ----------
+    n : iterable
+        The node ids.
+    """
+    return iter(n)

 def edge_iter(u, v):
-    """Return an iterator that loops over the given edges."""
-    u = convert_to_id_container(u)
-    v = convert_to_id_container(v)
+    """Return an iterator that loops over the given edges.
+
+    Parameters
+    ----------
+    u : iterable
+        The src ids.
+    v : iterable
+        The dst ids.
+    """
     if len(u) == len(v):
         # many-many
         for uu, vv in zip(u, v):
@@ -38,8 +100,33 @@ def edge_iter(u, v):
     else:
         raise ValueError('Error edges:', u, v)

+def edge_broadcasting(u, v):
+    """Convert one-many and many-one edges to many-many.
+
+    Parameters
+    ----------
+    u : Index
+        The src id(s)
+    v : Index
+        The dst id(s)
+
+    Returns
+    -------
+    uu : Index
+        The src id(s) after broadcasting
+    vv : Index
+        The dst id(s) after broadcasting
+    """
+    if len(u) != len(v) and len(u) == 1:
+        u = toindex(F.broadcast_to(u.totensor(), v.totensor()))
+    elif len(u) != len(v) and len(v) == 1:
+        v = toindex(F.broadcast_to(v.totensor(), u.totensor()))
+    else:
+        assert len(u) == len(v)
+    return u, v
+
+'''
 def convert_to_id_container(x):
-    """Convert the input to id container."""
     if is_id_container(x):
         return x
     elif is_id_tensor(x):
@@ -52,7 +139,6 @@ def convert_to_id_container(x):
     return None

 def convert_to_id_tensor(x, ctx=None):
-    """Convert the input to id tensor."""
     if is_id_container(x):
         ret = F.tensor(x, dtype=F.int64)
     elif is_id_tensor(x):
@@ -64,6 +150,7 @@ def convert_to_id_tensor(x, ctx=None):
         raise TypeError('Error node: %s' % str(x))
     ret = F.to_context(ret, ctx)
     return ret
+'''

 class LazyDict(Mapping):
     """A readonly dictionary that does not materialize the storage."""
@@ -110,7 +197,7 @@ def build_relabel_map(x):
     Parameters
     ----------
-    x : int, tensor or container
+    x : Index
         The input ids.

     Returns
@@ -122,7 +209,7 @@ def build_relabel_map(x):
     One can use advanced indexing to convert an old id tensor to a
     new id tensor: new_id = old_to_new[old_id]
     """
-    x = convert_to_id_tensor(x)
+    x = x.totensor()
     unique_x, _ = F.sort(F.unique(x))
     map_len = int(F.max(unique_x)) + 1
     old_to_new = F.zeros(map_len, dtype=F.int64)
@@ -150,12 +237,55 @@ def build_relabel_dict(x):
         relabel_dict[v] = i
     return relabel_dict

-def edge_broadcasting(u, v):
-    """Convert one-many and many-one edges to many-many."""
-    if len(u) != len(v) and len(u) == 1:
-        u = F.broadcast_to(u, v)
-    elif len(u) != len(v) and len(v) == 1:
-        v = F.broadcast_to(v, u)
-    else:
-        assert len(u) == len(v)
-    return u, v
+class CtxCachedObject(object):
+    """A wrapper to cache object generated by different context.
+
+    Note: such wrapper may incur significant overhead if the wrapped object is very light.
+
+    Parameters
+    ----------
+    generator : callable
+        A callable function that can create the object given ctx as the only argument.
+    """
+    def __init__(self, generator):
+        self._generator = generator
+        self._ctx_dict = {}
+
+    def get(self, ctx):
+        if not ctx in self._ctx_dict:
+            self._ctx_dict[ctx] = self._generator(ctx)
+        return self._ctx_dict[ctx]
+
+def ctx_cached_member(func):
+    """Convenient class member function wrapper to cache the function result.
+
+    The wrapped function must only have two arguments: `self` and `ctx`. The former is the
+    class object and the later is the context. It will check whether the class object is
+    freezed (by checking the `_freeze` member). If yes, it caches the function result in
+    the field prefixed by '_CACHED_' before the function name.
+    """
+    cache_name = '_CACHED_' + func.__name__
+    @wraps(func)
+    def wrapper(self, ctx):
+        if self._freeze:
+            # cache
+            if getattr(self, cache_name, None) is None:
+                bind_func = lambda _ctx : func(self, _ctx)
+                setattr(self, cache_name, CtxCachedObject(bind_func))
+            return getattr(self, cache_name).get(ctx)
+        else:
+            return func(self, ctx)
+    return wrapper
+
+def cached_member(func):
+    cache_name = '_CACHED_' + func.__name__
+    @wraps(func)
+    def wrapper(self):
+        if self._freeze:
+            # cache
+            if getattr(self, cache_name, None) is None:
+                setattr(self, cache_name, func(self))
+            return getattr(self, cache_name)
+        else:
+            return func(self)
+    return wrapper
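Both decorators only cache after _freeze is set, mirroring the CachedGraph lifecycle. A self-contained check of cached_member (Toy is a made-up class for illustration):

    from dgl.utils import cached_member

    class Toy(object):
        def __init__(self):
            self._freeze = False
            self.calls = 0

        @cached_member
        def edges(self):
            self.calls += 1
            return [1, 2, 3]

    t = Toy()
    t.edges(); t.edges()
    assert t.calls == 2    # not frozen: recomputed each call
    t._freeze = True
    t.edges(); t.edges()
    assert t.calls == 3    # frozen: computed once, then cached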
@@ -26,6 +26,17 @@ def update_func(node, accum):
     assert node['h'].shape == accum.shape
     return {'h' : node['h'] + accum}

+def reduce_dict_func(node, msgs):
+    msgs = msgs['m']
+    reduce_msg_shapes.add(tuple(msgs.shape))
+    assert len(msgs.shape) == 3
+    assert msgs.shape[2] == D
+    return {'m' : th.sum(msgs, 1)}
+
+def update_dict_func(node, accum):
+    assert node['h'].shape == accum['m'].shape
+    return {'h' : node['h'] + accum['m']}
+
 def generate_graph(grad=False):
     g = DGLGraph()
     for i in range(10):
@@ -149,7 +160,8 @@ def test_batch_send():
     v = th.tensor([9])
     g.sendto(u, v)

-def test_batch_recv():
+def test_batch_recv1():
+    # basic recv test
     g = generate_graph()
     g.register_message_func(message_func, batchable=True)
     g.register_reduce_func(reduce_func, batchable=True)
@@ -162,6 +174,20 @@ def test_batch_recv1():
     assert(reduce_msg_shapes == {(1, 3, D), (3, 1, D)})
     reduce_msg_shapes.clear()

+def test_batch_recv2():
+    # recv test with dict type reduce message
+    g = generate_graph()
+    g.register_message_func(message_func, batchable=True)
+    g.register_reduce_func(reduce_dict_func, batchable=True)
+    g.register_update_func(update_dict_func, batchable=True)
+    u = th.tensor([0, 0, 0, 4, 5, 6])
+    v = th.tensor([1, 2, 3, 9, 9, 9])
+    reduce_msg_shapes.clear()
+    g.sendto(u, v)
+    g.recv(th.unique(v))
+    assert(reduce_msg_shapes == {(1, 3, D), (3, 1, D)})
+    reduce_msg_shapes.clear()
+
 def test_update_routines():
     g = generate_graph()
     g.register_message_func(message_func, batchable=True)
@@ -210,6 +236,7 @@ if __name__ == '__main__':
     test_batch_setter_getter()
     test_batch_setter_autograd()
     test_batch_send()
-    test_batch_recv()
+    test_batch_recv1()
+    test_batch_recv2()
     test_update_routines()
     #test_delete()
@@ -3,6 +3,7 @@ import numpy as np
 import networkx as nx
 from dgl import DGLGraph
 from dgl.cached_graph import *
+from dgl.utils import Index

 def check_eq(a, b):
     assert a.shape == b.shape
@@ -15,22 +16,18 @@ def test_basics():
     g.add_edge(1, 3)
     g.add_edge(2, 4)
     g.add_edge(2, 5)
+    g.add_edge(0, 2)
     cg = create_cached_graph(g)

-    u = th.tensor([0, 1, 1, 2, 2])
-    v = th.tensor([1, 2, 3, 4, 5])
-    check_eq(cg.get_edge_id(u, v), th.tensor([0, 1, 2, 3, 4]))
-    cg.add_edges(0, 2)
-    assert cg.get_edge_id(0, 2) == 5
+    u = Index(th.tensor([0, 0, 1, 1, 2, 2]))
+    v = Index(th.tensor([1, 2, 2, 3, 4, 5]))
+    check_eq(cg.get_edge_id(u, v).totensor(), th.tensor([0, 5, 1, 2, 3, 4]))

-    query = th.tensor([1, 2])
+    query = Index(th.tensor([1, 2]))
     s, d = cg.in_edges(query)
-    check_eq(s, th.tensor([0, 0, 1]))
-    check_eq(d, th.tensor([1, 2, 2]))
+    check_eq(s.totensor(), th.tensor([0, 0, 1]))
+    check_eq(d.totensor(), th.tensor([1, 2, 2]))
     s, d = cg.out_edges(query)
-    check_eq(s, th.tensor([1, 1, 2, 2]))
-    check_eq(d, th.tensor([2, 3, 4, 5]))
-    print(cg._graph.get_adjacency())
-    print(cg._graph.get_adjacency(eids=True))
+    check_eq(s.totensor(), th.tensor([1, 1, 2, 2]))
+    check_eq(d.totensor(), th.tensor([2, 3, 4, 5]))

 if __name__ == '__main__':
     test_basics()
...@@ -2,6 +2,7 @@ import torch as th ...@@ -2,6 +2,7 @@ import torch as th
from torch.autograd import Variable from torch.autograd import Variable
import numpy as np import numpy as np
from dgl.frame import Frame, FrameRef from dgl.frame import Frame, FrameRef
from dgl.utils import Index
N = 10 N = 10
D = 5 D = 5
...@@ -112,7 +113,7 @@ def test_append2(): ...@@ -112,7 +113,7 @@ def test_append2():
assert not f.is_span_whole_column() assert not f.is_span_whole_column()
assert f.num_rows == 3 * N assert f.num_rows == 3 * N
new_idx = list(range(N)) + list(range(2*N, 4*N)) new_idx = list(range(N)) + list(range(2*N, 4*N))
assert check_eq(f.index_tensor(), th.tensor(new_idx)) assert check_eq(f.index().totensor(), th.tensor(new_idx))
assert data.num_rows == 4 * N assert data.num_rows == 4 * N
def test_row1(): def test_row1():
...@@ -122,20 +123,20 @@ def test_row1(): ...@@ -122,20 +123,20 @@ def test_row1():
# getter # getter
# test non-duplicate keys # test non-duplicate keys
rowid = th.tensor([0, 2]) rowid = Index(th.tensor([0, 2]))
rows = f[rowid] rows = f[rowid]
for k, v in rows.items(): for k, v in rows.items():
assert v.shape == (len(rowid), D) assert v.shape == (len(rowid), D)
assert check_eq(v, data[k][rowid]) assert check_eq(v, data[k][rowid])
# test duplicate keys # test duplicate keys
rowid = th.tensor([8, 2, 2, 1]) rowid = Index(th.tensor([8, 2, 2, 1]))
rows = f[rowid] rows = f[rowid]
for k, v in rows.items(): for k, v in rows.items():
assert v.shape == (len(rowid), D) assert v.shape == (len(rowid), D)
assert check_eq(v, data[k][rowid]) assert check_eq(v, data[k][rowid])
# setter # setter
rowid = th.tensor([0, 2, 4]) rowid = Index(th.tensor([0, 2, 4]))
vals = {'a1' : th.zeros((len(rowid), D)), vals = {'a1' : th.zeros((len(rowid), D)),
'a2' : th.zeros((len(rowid), D)), 'a2' : th.zeros((len(rowid), D)),
'a3' : th.zeros((len(rowid), D)), 'a3' : th.zeros((len(rowid), D)),
...@@ -152,13 +153,13 @@ def test_row2(): ...@@ -152,13 +153,13 @@ def test_row2():
# getter # getter
c1 = f['a1'] c1 = f['a1']
# test non-duplicate keys # test non-duplicate keys
rowid = th.tensor([0, 2]) rowid = Index(th.tensor([0, 2]))
rows = f[rowid] rows = f[rowid]
rows['a1'].backward(th.ones((len(rowid), D))) rows['a1'].backward(th.ones((len(rowid), D)))
assert check_eq(c1.grad[:,0], th.tensor([1., 0., 1., 0., 0., 0., 0., 0., 0., 0.])) assert check_eq(c1.grad[:,0], th.tensor([1., 0., 1., 0., 0., 0., 0., 0., 0., 0.]))
c1.grad.data.zero_() c1.grad.data.zero_()
# test duplicate keys # test duplicate keys
rowid = th.tensor([8, 2, 2, 1]) rowid = Index(th.tensor([8, 2, 2, 1]))
rows = f[rowid] rows = f[rowid]
rows['a1'].backward(th.ones((len(rowid), D))) rows['a1'].backward(th.ones((len(rowid), D)))
assert check_eq(c1.grad[:,0], th.tensor([0., 1., 2., 0., 0., 0., 0., 0., 1., 0.])) assert check_eq(c1.grad[:,0], th.tensor([0., 1., 2., 0., 0., 0., 0., 0., 1., 0.]))
...@@ -166,7 +167,7 @@ def test_row2(): ...@@ -166,7 +167,7 @@ def test_row2():
# setter # setter
c1 = f['a1'] c1 = f['a1']
rowid = th.tensor([0, 2, 4]) rowid = Index(th.tensor([0, 2, 4]))
vals = {'a1' : Variable(th.zeros((len(rowid), D)), requires_grad=True), vals = {'a1' : Variable(th.zeros((len(rowid), D)), requires_grad=True),
'a2' : Variable(th.zeros((len(rowid), D)), requires_grad=True), 'a2' : Variable(th.zeros((len(rowid), D)), requires_grad=True),
'a3' : Variable(th.zeros((len(rowid), D)), requires_grad=True), 'a3' : Variable(th.zeros((len(rowid), D)), requires_grad=True),
...@@ -210,14 +211,14 @@ def test_sharing(): ...@@ -210,14 +211,14 @@ def test_sharing():
f2_a1 = f2['a1'] f2_a1 = f2['a1']
# test write # test write
# update own ref should not be seen by the other. # update own ref should not be seen by the other.
f1[th.tensor([0, 1])] = { f1[Index(th.tensor([0, 1]))] = {
'a1' : th.zeros([2, D]), 'a1' : th.zeros([2, D]),
'a2' : th.zeros([2, D]), 'a2' : th.zeros([2, D]),
'a3' : th.zeros([2, D]), 'a3' : th.zeros([2, D]),
} }
assert check_eq(f2['a1'], f2_a1) assert check_eq(f2['a1'], f2_a1)
# update shared space should be seen by the other. # update shared space should be seen by the other.
f1[th.tensor([2, 3])] = { f1[Index(th.tensor([2, 3]))] = {
'a1' : th.ones([2, D]), 'a1' : th.ones([2, D]),
'a2' : th.ones([2, D]), 'a2' : th.ones([2, D]),
'a3' : th.ones([2, D]), 'a3' : th.ones([2, D]),
......
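The frame tests follow the same migration: FrameRef row getters and setters are now keyed by Index rather than raw tensors, duplicate row ids fetch duplicated rows, and (per test_row2) a backward pass through rows fetched with a repeated id accumulates gradient at that row. A hedged sketch of the access pattern; the Frame/FrameRef constructor usage is an assumption based on the surrounding tests:

import torch as th
from dgl.frame import Frame, FrameRef
from dgl.utils import Index

N, D = 10, 5
f = FrameRef(Frame({'a1': th.randn(N, D)}))

# Getter: duplicate row ids (row 2 appears twice) yield duplicated rows.
rowid = Index(th.tensor([8, 2, 2, 1]))
rows = f[rowid]
assert rows['a1'].shape == (len(rowid), D)

# Setter: writes go through the same Index-keyed interface.
f[Index(th.tensor([0, 2]))] = {'a1': th.zeros((2, D))}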
...@@ -24,35 +24,71 @@ def generate_graph(grad=False): ...@@ -24,35 +24,71 @@ def generate_graph(grad=False):
g.set_e_repr({'l' : ecol}) g.set_e_repr({'l' : ecol})
return g return g
def test_subgraph(): def test_basics():
g = generate_graph() g = generate_graph()
h = g.get_n_repr()['h'] h = g.get_n_repr()['h']
l = g.get_e_repr()['l'] l = g.get_e_repr()['l']
sg = g.subgraph([0, 2, 3, 6, 7, 9]) nid = [0, 2, 3, 6, 7, 9]
eid = [2, 3, 4, 5, 10, 11, 12, 13, 16]
sg = g.subgraph(nid)
# the subgraph's feature frames are empty initially
assert len(sg.get_n_repr()) == 0
assert len(sg.get_e_repr()) == 0
# the data is copied only after an explicit copy_from
sg.copy_from(g)
assert len(sg.get_n_repr()) == 1
assert len(sg.get_e_repr()) == 1
sh = sg.get_n_repr()['h'] sh = sg.get_n_repr()['h']
check_eq(h[th.tensor([0, 2, 3, 6, 7, 9])], sh) assert check_eq(h[nid], sh)
''' '''
s, d, eid s, d, eid (trailing 1/3 marks edges induced by sg1/sg3 below)
0, 1, 0 0, 1, 0
1, 9, 1 1, 9, 1
0, 2, 2 0, 2, 2 1
2, 9, 3 2, 9, 3 1
0, 3, 4 0, 3, 4 1
3, 9, 5 3, 9, 5 1
0, 4, 6 0, 4, 6
4, 9, 7 4, 9, 7
0, 5, 8 0, 5, 8
5, 9, 9 5, 9, 9 3
0, 6, 10 0, 6, 10 1
6, 9, 11 6, 9, 11 1 3
0, 7, 12 0, 7, 12 1
7, 9, 13 7, 9, 13 1 3
0, 8, 14 0, 8, 14
8, 9, 15 8, 9, 15 3
9, 0, 16 9, 0, 16 1
''' '''
eid = th.tensor([2, 3, 4, 5, 10, 11, 12, 13, 16]) assert check_eq(l[eid], sg.get_e_repr()['l'])
check_eq(l[eid], sg.get_e_repr()['l']) # updating the node/edge features on the subgraph should NOT
# be reflected in the parent graph.
sg.set_n_repr({'h' : th.zeros((6, D))})
assert check_eq(h, g.get_n_repr()['h'])
def test_merge():
g = generate_graph()
g.set_n_repr({'h' : th.zeros((10, D))})
g.set_e_repr({'l' : th.zeros((17, D))})
# subgraphs
sg1 = g.subgraph([0, 2, 3, 6, 7, 9])
sg1.set_n_repr({'h' : th.ones((6, D))})
sg1.set_e_repr({'l' : th.ones((9, D))})
sg2 = g.subgraph([0, 2, 3, 4])
sg2.set_n_repr({'h' : th.ones((4, D)) * 2})
sg3 = g.subgraph([5, 6, 7, 8, 9])
sg3.set_e_repr({'l' : th.ones((4, D)) * 3})
g.merge([sg1, sg2, sg3])
h = g.get_n_repr()['h'][:,0]
l = g.get_e_repr()['l'][:,0]
assert check_eq(h, th.tensor([3., 0., 3., 3., 2., 0., 1., 1., 0., 1.]))
assert check_eq(l,
th.tensor([0., 0., 1., 1., 1., 1., 0., 0., 0., 3., 1., 4., 1., 4., 0., 3., 1.]))
if __name__ == '__main__': if __name__ == '__main__':
test_subgraph() test_basics()
test_merge()
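Together, test_basics and test_merge pin down the new subgraph lifecycle: g.subgraph(nid) starts with empty feature frames, copy_from(g) pulls the parent's node/edge features in, local writes stay local, and g.merge([...]) folds per-subgraph features back into the parent, summing where subgraphs overlap (node 0 above ends at 3 = 1 from sg1 + 2 from sg2). A sketch of the flow, assuming this early DGLGraph is built node-by-node as in generate_graph:

import torch as th
from dgl import DGLGraph

D = 5
g = DGLGraph()
for i in range(4):
    g.add_node(i)
for u, v in [(0, 1), (1, 2), (2, 3)]:
    g.add_edge(u, v)
g.set_n_repr({'h': th.zeros((4, D))})

sg = g.subgraph([0, 2, 3])
assert len(sg.get_n_repr()) == 0        # empty until explicitly copied
sg.copy_from(g)                         # now holds rows [0, 2, 3] of 'h'
sg.set_n_repr({'h': th.ones((3, D))})   # local write; parent is untouched
assert th.equal(g.get_n_repr()['h'], th.zeros((4, D)))
g.merge([sg])                           # rows 0, 2, 3 of the parent get 1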
...@@ -6,6 +6,12 @@ def message_func(src, edge): ...@@ -6,6 +6,12 @@ def message_func(src, edge):
def update_func(node, accum): def update_func(node, accum):
return {'h' : node['h'] + accum} return {'h' : node['h'] + accum}
def message_dict_func(src, edge):
return {'m' : src['h']}
def update_dict_func(node, accum):
return {'h' : node['h'] + accum['m']}
def generate_graph(): def generate_graph():
g = DGLGraph() g = DGLGraph()
for i in range(10): for i in range(10):
...@@ -23,12 +29,18 @@ def check(g, h): ...@@ -23,12 +29,18 @@ def check(g, h):
h = [str(x) for x in h] h = [str(x) for x in h]
assert nh == h, "nh=[%s], h=[%s]" % (' '.join(nh), ' '.join(h)) assert nh == h, "nh=[%s], h=[%s]" % (' '.join(nh), ' '.join(h))
def test_sendrecv(): def register1(g):
g = generate_graph()
check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
g.register_message_func(message_func) g.register_message_func(message_func)
g.register_update_func(update_func) g.register_update_func(update_func)
g.register_reduce_func('sum') g.register_reduce_func('sum')
def register2(g):
g.register_message_func(message_dict_func)
g.register_update_func(update_dict_func)
g.register_reduce_func('sum')
def _test_sendrecv(g):
check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
g.sendto(0, 1) g.sendto(0, 1)
g.recv(1) g.recv(1)
check(g, [1, 3, 3, 4, 5, 6, 7, 8, 9, 10]) check(g, [1, 3, 3, 4, 5, 6, 7, 8, 9, 10])
...@@ -37,12 +49,8 @@ def test_sendrecv(): ...@@ -37,12 +49,8 @@ def test_sendrecv():
g.recv(9) g.recv(9)
check(g, [1, 3, 3, 4, 5, 6, 7, 8, 9, 23]) check(g, [1, 3, 3, 4, 5, 6, 7, 8, 9, 23])
def test_multi_sendrecv(): def _test_multi_sendrecv(g):
g = generate_graph()
check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
g.register_message_func(message_func)
g.register_update_func(update_func)
g.register_reduce_func('sum')
# one-many # one-many
g.sendto(0, [1, 2, 3]) g.sendto(0, [1, 2, 3])
g.recv([1, 2, 3]) g.recv([1, 2, 3])
...@@ -56,12 +64,8 @@ def test_multi_sendrecv(): ...@@ -56,12 +64,8 @@ def test_multi_sendrecv():
g.recv([4, 5, 9]) g.recv([4, 5, 9])
check(g, [1, 3, 4, 5, 6, 7, 7, 8, 9, 45]) check(g, [1, 3, 4, 5, 6, 7, 7, 8, 9, 45])
def test_update_routines(): def _test_update_routines(g):
g = generate_graph()
check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
g.register_message_func(message_func)
g.register_update_func(update_func)
g.register_reduce_func('sum')
g.update_by_edge(0, 1) g.update_by_edge(0, 1)
check(g, [1, 3, 3, 4, 5, 6, 7, 8, 9, 10]) check(g, [1, 3, 3, 4, 5, 6, 7, 8, 9, 10])
g.update_to(9) g.update_to(9)
...@@ -71,6 +75,30 @@ def test_update_routines(): ...@@ -71,6 +75,30 @@ def test_update_routines():
g.update_all() g.update_all()
check(g, [56, 5, 5, 6, 7, 8, 9, 10, 11, 108]) check(g, [56, 5, 5, 6, 7, 8, 9, 10, 11, 108])
def test_sendrecv():
g = generate_graph()
register1(g)
_test_sendrecv(g)
g = generate_graph()
register2(g)
_test_sendrecv(g)
def test_multi_sendrecv():
g = generate_graph()
register1(g)
_test_multi_sendrecv(g)
g = generate_graph()
register2(g)
_test_multi_sendrecv(g)
def test_update_routines():
g = generate_graph()
register1(g)
_test_update_routines(g)
g = generate_graph()
register2(g)
_test_update_routines(g)
if __name__ == '__main__': if __name__ == '__main__':
test_sendrecv() test_sendrecv()
test_multi_sendrecv() test_multi_sendrecv()
......
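The refactor above splits each test into a register helper plus a _test body so every routine runs twice: once with plain-tensor messages (register1) and once with the dict-typed messages this commit adds (register2), where the built-in 'sum' reducer aggregates each key of the message dict independently. A sketch of the dict convention; the tiny graph and scalar 'h' reprs are illustrative, while the registration calls are taken from the diff:

import torch as th
from dgl import DGLGraph

def message_dict_func(src, edge):
    return {'m': src['h']}                    # message is a dict of tensors

def update_dict_func(node, accum):
    return {'h': node['h'] + accum['m']}      # reduced msg keeps the keys

g = DGLGraph()
for i in range(3):
    g.add_node(i)
g.add_edge(0, 2)
g.add_edge(1, 2)
g.set_n_repr({'h': th.tensor([1., 2., 3.])})
g.register_message_func(message_dict_func)
g.register_update_func(update_dict_func)
g.register_reduce_func('sum')                 # sums accum['m'] over senders
g.sendto([0, 1], 2)                           # many-one send, as in the tests
g.recv(2)                                     # node 2: 3 + (1 + 2) = 6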