"docs/source/en/api/models/auto_model.md" did not exist on "c6ae9b7df65bbeee93e49fc11e6044f9ce8b56e1"
Unverified Commit 6105e441 authored by Minjie Wang's avatar Minjie Wang Committed by GitHub
Browse files

Many fixes and updates (#47)

* subgraph copy from

* WIP

* cached members

* Change all usage of id tensor to the new Index object; remove set device in DGLGraph;

* subgraph merge API tested

* add dict type reduced msg test
parent 3721822e
......@@ -154,8 +154,6 @@ def main(args):
# create GCN model
g = DGLGraph(data.graph)
if cuda:
g.set_device(dgl.gpu(args.gpu))
# create model
model = GAT(g,
......
......@@ -85,8 +85,6 @@ def main(args):
# create GCN model
g = DGLGraph(data.graph)
if cuda:
g.set_device(dgl.gpu(args.gpu))
model = GCN(g,
in_feats,
args.n_hidden,
......
......@@ -79,8 +79,6 @@ def main(args):
# create GCN model
g = DGLGraph(data.graph)
if cuda:
g.set_device(dgl.gpu(args.gpu))
model = GCN(g,
in_feats,
args.n_hidden,
......
......@@ -213,7 +213,6 @@ def main(args):
count, label, node_list, mask, active, label1, label1_tensor = ground_truth[0]
label, node_list, mask, label1_tensor = move2cuda((label, node_list, mask, label1_tensor))
ground_truth[0] = (count, label, node_list, mask, active, label1, label1_tensor)
ground_truth[1][0].set_device(dgl.gpu(args.gpu))
optimizer.zero_grad()
# create new empty graphs
......
......@@ -2,6 +2,7 @@ from __future__ import absolute_import
import torch as th
import scipy.sparse
import dgl.context as context
# Tensor types
Tensor = th.Tensor
......@@ -73,3 +74,9 @@ def to_context(x, ctx):
return x.cpu()
else:
raise RuntimeError('Invalid context', ctx)
def get_context(x):
    """Return the dgl Context matching the device that tensor ``x`` lives on."""
    dev = x.device
    # CPU tensors have no meaningful device index; GPU tensors carry one.
    return context.cpu() if dev.type == 'cpu' else context.gpu(dev.index)
......@@ -87,7 +87,7 @@ def unbatch(graph_batch):
num_graphs = len(graph_list)
# split and set node attrs
attrs = [{} for _ in range(num_graphs)] # node attr dict for each graph
for key in graph_batch.get_n_attr_list():
for key in graph_batch.node_attr_schemes():
vals = F.unpack(graph_batch.pop_n_repr(key), graph_batch.num_nodes)
for attr, val in zip(attrs, vals):
attr[key] = val
......@@ -96,7 +96,7 @@ def unbatch(graph_batch):
# split and set edge attrs
attrs = [{} for _ in range(num_graphs)] # edge attr dict for each graph
for key in graph_batch.get_e_attr_list():
for key in graph_batch.edge_attr_schemes():
vals = F.unpack(graph_batch.pop_e_repr(key), graph_batch.num_edges)
for attr, val in zip(attrs, vals):
attr[key] = val
......
......@@ -8,6 +8,9 @@ def message_from_src(src, edge):
def reduce_sum(node, msgs):
    """Sum the incoming messages.

    Handles three message layouts: a list of dicts (sum each field
    independently), a plain list (sum the elements), or a batched
    tensor (reduce along dimension 1).
    """
    if not isinstance(msgs, list):
        # Batched tensor of messages: reduce over the message axis.
        return F.sum(msgs, 1)
    if isinstance(msgs[0], dict):
        # Dict-typed messages: field-wise sum keyed by the first message.
        return {key: sum(msg[key] for msg in msgs) for key in msgs[0]}
    return sum(msgs)
......
......@@ -14,15 +14,21 @@ import dgl.utils as utils
class CachedGraph:
def __init__(self):
self._graph = igraph.Graph(directed=True)
self._adjmat = None # cached adjacency matrix
self._freeze = False
def add_nodes(self, num_nodes):
if self._freeze:
raise RuntimeError('Freezed cached graph cannot be mutated.')
self._graph.add_vertices(num_nodes)
def add_edge(self, u, v):
if self._freeze:
raise RuntimeError('Freezed cached graph cannot be mutated.')
self._graph.add_edge(u, v)
def add_edges(self, u, v):
if self._freeze:
raise RuntimeError('Freezed cached graph cannot be mutated.')
# The edge will be assigned ids equal to the order.
uvs = list(utils.edge_iter(u, v))
self._graph.add_edges(uvs)
......@@ -30,7 +36,7 @@ class CachedGraph:
def get_edge_id(self, u, v):
uvs = list(utils.edge_iter(u, v))
eids = self._graph.get_eids(uvs)
return utils.convert_to_id_tensor(eids)
return utils.toindex(eids)
def in_edges(self, v):
src = []
......@@ -39,8 +45,8 @@ class CachedGraph:
uu = self._graph.predecessors(vv)
src += uu
dst += [vv] * len(uu)
src = utils.convert_to_id_tensor(src)
dst = utils.convert_to_id_tensor(dst)
src = utils.toindex(src)
dst = utils.toindex(dst)
return src, dst
def out_edges(self, u):
......@@ -50,44 +56,51 @@ class CachedGraph:
vv = self._graph.successors(uu)
src += [uu] * len(vv)
dst += vv
src = utils.convert_to_id_tensor(src)
dst = utils.convert_to_id_tensor(dst)
src = utils.toindex(src)
dst = utils.toindex(dst)
return src, dst
def in_degrees(self, v):
degs = self._graph.indegree(list(v))
return utils.toindex(degs)
def num_edges(self):
return self._graph.ecount()
@utils.cached_member
def edges(self):
elist = self._graph.get_edgelist()
src = [u for u, _ in elist]
dst = [v for _, v in elist]
src = utils.convert_to_id_tensor(src)
dst = utils.convert_to_id_tensor(dst)
src = utils.toindex(src)
dst = utils.toindex(dst)
return src, dst
def in_degrees(self, v):
degs = self._graph.indegree(list(v))
return utils.convert_to_id_tensor(degs)
@utils.ctx_cached_member
def adjmat(self, ctx):
"""Return a sparse adjacency matrix.
The row dimension represents the dst nodes; the column dimension
represents the src nodes.
"""
if self._adjmat is None:
elist = self._graph.get_edgelist()
src = [u for u, _ in elist]
dst = [v for _, v in elist]
src = F.unsqueeze(utils.convert_to_id_tensor(src), 0)
dst = F.unsqueeze(utils.convert_to_id_tensor(dst), 0)
src = F.tensor([u for u, _ in elist], dtype=F.int64)
dst = F.tensor([v for _, v in elist], dtype=F.int64)
src = F.unsqueeze(src, 0)
dst = F.unsqueeze(dst, 0)
idx = F.pack([dst, src])
n = self._graph.vcount()
dat = F.ones((len(elist),))
self._adjmat = F.sparse_tensor(idx, dat, [n, n])
# TODO(minjie): manually convert adjmat to context
self._adjmat = F.to_context(self._adjmat, ctx)
return self._adjmat
mat = F.sparse_tensor(idx, dat, [n, n])
mat = F.to_context(mat, ctx)
return mat
def freeze(self):
self._freeze = True
def create_cached_graph(dglgraph):
cg = CachedGraph()
cg.add_nodes(dglgraph.number_of_nodes())
cg._graph.add_edges(dglgraph.edge_list)
cg.freeze()
return cg
......@@ -8,6 +8,12 @@ class Context(object):
def __str__(self):
return '{}:{}'.format(self.device, self.device_id)
def __eq__(self, other):
return self.device == other.device and self.device_id == other.device_id
def __hash__(self):
return hash((self.device, self.device_id))
def gpu(gpuid):
    """Return a Context for the GPU device numbered ``gpuid``."""
    return Context('gpu', gpuid)
......
......@@ -6,7 +6,7 @@ import numpy as np
import dgl.backend as F
from dgl.backend import Tensor
from dgl.utils import LazyDict
import dgl.utils as utils
class Frame(MutableMapping):
def __init__(self, data=None):
......@@ -77,15 +77,24 @@ class Frame(MutableMapping):
return self.num_columns
class FrameRef(MutableMapping):
"""Frame reference
Parameters
----------
frame : dgl.frame.Frame
The underlying frame.
index : iterable of int
The rows that are referenced in the underlying frame.
"""
def __init__(self, frame=None, index=None):
self._frame = frame if frame is not None else Frame()
if index is None:
self._index = slice(0, self._frame.num_rows)
self._index_data = slice(0, self._frame.num_rows)
else:
# check no duplicate index
# check no duplication
assert len(index) == len(np.unique(index))
self._index = index
self._index_tensor = None
self._index_data = index
self._index = None
@property
def schemes(self):
......@@ -97,10 +106,10 @@ class FrameRef(MutableMapping):
@property
def num_rows(self):
if isinstance(self._index, slice):
return self._index.stop
if isinstance(self._index_data, slice):
return self._index_data.stop
else:
return len(self._index)
return len(self._index_data)
def __contains__(self, key):
return key in self._frame
......@@ -114,15 +123,17 @@ class FrameRef(MutableMapping):
def select_rows(self, query):
rowids = self._getrowid(query)
def _lazy_select(key):
return F.gather_row(self._frame[key], rowids)
return LazyDict(_lazy_select, keys=self.schemes)
idx = rowids.totensor(F.get_context(self._frame[key]))
return F.gather_row(self._frame[key], idx)
return utils.LazyDict(_lazy_select, keys=self.schemes)
def get_column(self, name):
col = self._frame[name]
if self.is_span_whole_column():
return col
else:
return F.gather_row(col, self.index_tensor())
idx = self.index().totensor(F.get_context(col))
return F.gather_row(col, idx)
def __setitem__(self, key, val):
if isinstance(key, str):
......@@ -134,22 +145,26 @@ class FrameRef(MutableMapping):
shp = F.shape(col)
if self.is_span_whole_column():
if self.num_columns == 0:
self._index = slice(0, shp[0])
self._index_data = slice(0, shp[0])
self._clear_cache()
assert shp[0] == self.num_rows
self._frame[name] = col
else:
colctx = F.get_context(col)
if name in self._frame:
fcol = self._frame[name]
else:
fcol = F.zeros((self._frame.num_rows,) + shp[1:])
newfcol = F.scatter_row(fcol, self.index_tensor(), col)
fcol = F.to_context(fcol, colctx)
idx = self.index().totensor(colctx)
newfcol = F.scatter_row(fcol, idx, col)
self._frame[name] = newfcol
def update_rows(self, query, other):
rowids = self._getrowid(query)
for key, col in other.items():
self._frame[key] = F.scatter_row(self._frame[key], rowids, col)
idx = rowids.totensor(F.get_context(self._frame[key]))
self._frame[key] = F.scatter_row(self._frame[key], idx, col)
def __delitem__(self, key):
if isinstance(key, str):
......@@ -161,10 +176,10 @@ class FrameRef(MutableMapping):
def delete_rows(self, query):
query = F.asnumpy(query)
if isinstance(self._index, slice):
self._index = list(range(self._index.start, self._index.stop))
arr = np.array(self._index, dtype=np.int32)
self._index = list(np.delete(arr, query))
if isinstance(self._index_data, slice):
self._index_data = list(range(self._index_data.start, self._index_data.stop))
arr = np.array(self._index_data, dtype=np.int32)
self._index_data = list(np.delete(arr, query))
self._clear_cache()
def append(self, other):
......@@ -174,16 +189,16 @@ class FrameRef(MutableMapping):
self._frame.append(other)
# update index
if span_whole:
self._index = slice(0, self._frame.num_rows)
else:
new_idx = list(range(self._index.start, self._index.stop))
self._index_data = slice(0, self._frame.num_rows)
elif contiguous:
new_idx = list(range(self._index_data.start, self._index_data.stop))
new_idx += list(range(old_nrows, self._frame.num_rows))
self._index = new_idx
self._index_data = new_idx
self._clear_cache()
def clear(self):
self._frame.clear()
self._index = slice(0, 0)
self._index_data = slice(0, 0)
self._clear_cache()
def __iter__(self):
......@@ -194,26 +209,73 @@ class FrameRef(MutableMapping):
def is_contiguous(self):
# NOTE: this check could have false negative
return isinstance(self._index, slice)
return isinstance(self._index_data, slice)
def is_span_whole_column(self):
return self.is_contiguous() and self.num_rows == self._frame.num_rows
def _getrowid(self, query):
if isinstance(self._index, slice):
if self.is_contiguous():
# shortcut for identical mapping
return query
else:
return F.gather_row(self.index_tensor(), query)
idxtensor = self.index().totensor()
return utils.toindex(F.gather_row(idxtensor, query.totensor()))
def index_tensor(self):
# TODO(minjie): context
if self._index_tensor is None:
def index(self):
if self._index is None:
if self.is_contiguous():
self._index_tensor = F.arange(self._index.stop, dtype=F.int64)
self._index = utils.toindex(
F.arange(self._index_data.stop, dtype=F.int64))
else:
self._index_tensor = F.tensor(self._index, dtype=F.int64)
return self._index_tensor
self._index = utils.toindex(self._index_data)
return self._index
def _clear_cache(self):
self._index_tensor = None
def merge_frames(frames, indices, max_index, reduce_func):
    """Merge a list of frames.

    The result frame contains `max_index` number of rows. For each frame in
    the given list, its row is merged as follows:

        merged[indices[i][row]] += frames[i][row]

    Parameters
    ----------
    frames : iterator of dgl.frame.FrameRef
        A list of frames to be merged.
    indices : iterator of dgl.utils.Index
        The indices of the frame rows.
    max_index : int
        The number of rows of the merged frame.
    reduce_func : str
        The reduce function (only 'sum' is supported currently)

    Returns
    -------
    merged : FrameRef
        The merged frame.
    """
    assert reduce_func == 'sum'
    assert len(frames) > 0
    schemes = frames[0].schemes
    # create an adj to merge
    # row index is equal to the concatenation of all the indices.
    row = sum([idx.tolist() for idx in indices], [])
    col = list(range(len(row)))
    n = max_index
    m = len(row)
    row = F.unsqueeze(F.tensor(row, dtype=F.int64), 0)
    col = F.unsqueeze(F.tensor(col, dtype=F.int64), 0)
    idx = F.pack([row, col])
    dat = F.ones((m,))
    # (n x m) scatter matrix: multiplying it with the packed feature rows
    # sums every row that maps to the same destination index.
    adjmat = F.sparse_tensor(idx, dat, [n, m])
    # lazily keep one copy of the scatter matrix per device context
    ctx_adjmat = utils.CtxCachedObject(lambda ctx: F.to_context(adjmat, ctx))
    merged = {}
    for key in schemes:
        # the rhs of the spmv is the concatenation of all the frame columns
        feats = F.pack([fr[key] for fr in frames])
        merged_feats = F.spmm(ctx_adjmat.get(F.get_context(feats)), feats)
        merged[key] = merged_feats
    merged = FrameRef(Frame(merged))
    return merged
......@@ -12,7 +12,7 @@ from dgl.backend import Tensor
import dgl.builtin as builtin
from dgl.cached_graph import CachedGraph, create_cached_graph
import dgl.context as context
from dgl.frame import FrameRef
from dgl.frame import FrameRef, merge_frames
from dgl.nx_adapt import nx_init
import dgl.scheduler as scheduler
import dgl.utils as utils
......@@ -62,12 +62,11 @@ class DGLGraph(DiGraph):
self._reduce_func = None
self._update_func = None
self._edge_func = None
self._context = context.cpu()
def get_n_attr_list(self):
def node_attr_schemes(self):
return self._node_frame.schemes
def get_e_attr_list(self):
def edge_attr_schemes(self):
return self._edge_frame.schemes
def set_n_repr(self, hu, u=ALL):
......@@ -92,7 +91,7 @@ class DGLGraph(DiGraph):
if is_all(u):
num_nodes = self.number_of_nodes()
else:
u = utils.convert_to_id_tensor(u, self.context)
u = utils.toindex(u)
num_nodes = len(u)
if isinstance(hu, dict):
for key, val in hu.items():
......@@ -108,10 +107,9 @@ class DGLGraph(DiGraph):
self._node_frame[__REPR__] = hu
else:
if isinstance(hu, dict):
for key, val in hu.items():
self._node_frame[key] = F.scatter_row(self._node_frame[key], u, val)
self._node_frame[u] = hu
else:
self._node_frame[__REPR__] = F.scatter_row(self._node_frame[__REPR__], u, hu)
self._node_frame[u] = {__REPR__ : hu}
def get_n_repr(self, u=ALL):
"""Get node(s) representation.
......@@ -127,9 +125,9 @@ class DGLGraph(DiGraph):
else:
return dict(self._node_frame)
else:
u = utils.convert_to_id_tensor(u, self.context)
u = utils.toindex(u)
if len(self._node_frame) == 1 and __REPR__ in self._node_frame:
return self._node_frame[__REPR__][u]
return self._node_frame.select_rows(u)[__REPR__]
else:
return self._node_frame.select_rows(u)
......@@ -168,10 +166,10 @@ class DGLGraph(DiGraph):
v_is_all = is_all(v)
assert u_is_all == v_is_all
if u_is_all:
num_edges = self.number_of_edges()
num_edges = self.cached_graph.num_edges()
else:
u = utils.convert_to_id_tensor(u, self.context)
v = utils.convert_to_id_tensor(v, self.context)
u = utils.toindex(u)
v = utils.toindex(v)
num_edges = max(len(u), len(v))
if isinstance(h_uv, dict):
for key, val in h_uv.items():
......@@ -188,10 +186,9 @@ class DGLGraph(DiGraph):
else:
eid = self.cached_graph.get_edge_id(u, v)
if isinstance(h_uv, dict):
for key, val in h_uv.items():
self._edge_frame[key] = F.scatter_row(self._edge_frame[key], eid, val)
self._edge_frame[eid] = h_uv
else:
self._edge_frame[__REPR__] = F.scatter_row(self._edge_frame[__REPR__], eid, h_uv)
self._edge_frame[eid] = {__REPR__ : h_uv}
def set_e_repr_by_id(self, h_uv, eid=ALL):
"""Set edge(s) representation by edge id.
......@@ -205,9 +202,9 @@ class DGLGraph(DiGraph):
"""
# sanity check
if is_all(eid):
num_edges = self.number_of_edges()
num_edges = self.cached_graph.num_edges()
else:
eid = utils.convert_to_id_tensor(eid, self.context)
eid = utils.toindex(eid)
num_edges = len(eid)
if isinstance(h_uv, dict):
for key, val in h_uv.items():
......@@ -223,10 +220,9 @@ class DGLGraph(DiGraph):
self._edge_frame[__REPR__] = h_uv
else:
if isinstance(h_uv, dict):
for key, val in h_uv.items():
self._edge_frame[key] = F.scatter_row(self._edge_frame[key], eid, val)
self._edge_frame[eid] = h_uv
else:
self._edge_frame[__REPR__] = F.scatter_row(self._edge_frame[__REPR__], eid, h_uv)
self._edge_frame[eid] = {__REPR__ : h_uv}
def get_e_repr(self, u=ALL, v=ALL):
"""Get node(s) representation.
......@@ -247,11 +243,11 @@ class DGLGraph(DiGraph):
else:
return dict(self._edge_frame)
else:
u = utils.convert_to_id_tensor(u, self.context)
v = utils.convert_to_id_tensor(v, self.context)
u = utils.toindex(u)
v = utils.toindex(v)
eid = self.cached_graph.get_edge_id(u, v)
if len(self._edge_frame) == 1 and __REPR__ in self._edge_frame:
return self._edge_frame[__REPR__][eid]
return self._edge_frame.select_rows(eid)[__REPR__]
else:
return self._edge_frame.select_rows(eid)
......@@ -279,27 +275,12 @@ class DGLGraph(DiGraph):
else:
return dict(self._edge_frame)
else:
eid = utils.convert_to_id_tensor(eid, self.context)
eid = utils.toindex(eid)
if len(self._edge_frame) == 1 and __REPR__ in self._edge_frame:
return self._edge_frame[__REPR__][eid]
return self._edge_frame.select_rows(eid)[__REPR__]
else:
return self._edge_frame.select_rows(eid)
def set_device(self, ctx):
"""Set device context for this graph.
Parameters
----------
ctx : dgl.context.Context
The device context.
"""
self._context = ctx
@property
def context(self):
"""Get the device context of this graph."""
return self._context
def register_message_func(self,
message_func,
batchable=False):
......@@ -356,27 +337,6 @@ class DGLGraph(DiGraph):
"""
self._update_func = (update_func, batchable)
def readout(self,
readout_func,
nodes=ALL,
edges=ALL):
"""Trigger the readout function on the specified nodes/edges.
Parameters
----------
readout_func : callable
Readout function.
nodes : str, node, container or tensor
The nodes to get reprs from.
edges : str, pair of nodes, pair of containers or pair of tensors
The edges to get reprs from.
"""
nodes = self._nodes_or_all(nodes)
edges = self._edges_or_all(edges)
nstates = [self.nodes[n] for n in nodes]
estates = [self.edges[e] for e in edges]
return readout_func(nstates, estates)
def sendto(self, u, v, message_func=None, batchable=False):
"""Trigger the message function on edge u->v
......@@ -413,6 +373,9 @@ class DGLGraph(DiGraph):
f_msg = _get_message_func(message_func)
if is_all(u) and is_all(v):
u, v = self.cached_graph.edges()
else:
u = utils.toindex(u)
v = utils.toindex(v)
for uu, vv in utils.edge_iter(u, v):
ret = f_msg(_get_repr(self.nodes[uu]),
_get_repr(self.edges[uu, vv]))
......@@ -428,8 +391,8 @@ class DGLGraph(DiGraph):
edge_reprs = self.get_e_repr()
msgs = message_func(src_reprs, edge_reprs)
else:
u = utils.convert_to_id_tensor(u)
v = utils.convert_to_id_tensor(v)
u = utils.toindex(u)
v = utils.toindex(v)
u, v = utils.edge_broadcasting(u, v)
eid = self.cached_graph.get_edge_id(u, v)
self.msg_graph.add_edges(u, v)
......@@ -475,6 +438,9 @@ class DGLGraph(DiGraph):
def _nonbatch_update_edge(self, u, v, edge_func):
if is_all(u) and is_all(v):
u, v = self.cached_graph.edges()
else:
u = utils.toindex(u)
v = utils.toindex(v)
for uu, vv in utils.edge_iter(u, v):
ret = edge_func(_get_repr(self.nodes[uu]),
_get_repr(self.nodes[vv]),
......@@ -491,8 +457,8 @@ class DGLGraph(DiGraph):
new_edge_reprs = edge_func(src_reprs, dst_reprs, edge_reprs)
self.set_e_repr(new_edge_reprs)
else:
u = utils.convert_to_id_tensor(u)
v = utils.convert_to_id_tensor(v)
u = utils.toindex(u)
v = utils.toindex(v)
u, v = utils.edge_broadcasting(u, v)
eid = self.cached_graph.get_edge_id(u, v)
# call the UDF
......@@ -559,6 +525,8 @@ class DGLGraph(DiGraph):
f_update = update_func
if is_all(u):
u = list(range(0, self.number_of_nodes()))
else:
u = utils.toindex(u)
for i, uu in enumerate(utils.node_iter(u)):
# reduce phase
msgs_batch = [self.edges[vv, uu].pop(__MSG__)
......@@ -586,14 +554,16 @@ class DGLGraph(DiGraph):
new_ns = f_update(reordered_ns, all_reduced_msgs)
if is_all(v):
# First do reorder and then replace the whole column.
_, indices = F.sort(reordered_v)
# TODO(minjie): manually convert ids to context.
indices = F.to_context(indices, self.context)
_, indices = F.sort(reordered_v.totensor())
indices = utils.toindex(indices)
# TODO(minjie): following code should be included in Frame somehow.
if isinstance(new_ns, dict):
for key, val in new_ns.items():
self._node_frame[key] = F.gather_row(val, indices)
idx = indices.totensor(F.get_context(val))
self._node_frame[key] = F.gather_row(val, idx)
else:
self._node_frame[__REPR__] = F.gather_row(new_ns, indices)
idx = indices.totensor(F.get_context(new_ns))
self._node_frame[__REPR__] = F.gather_row(new_ns, idx)
else:
# Use setter to do reorder.
self.set_n_repr(new_ns, reordered_v)
......@@ -605,9 +575,14 @@ class DGLGraph(DiGraph):
if is_all(v):
v = list(range(self.number_of_nodes()))
# freeze message graph
self.msg_graph.freeze()
# sanity checks
v = utils.convert_to_id_tensor(v)
v = utils.toindex(v)
f_reduce = _get_reduce_func(reduce_func)
# degree bucketing
degrees, v_buckets = scheduler.degree_bucketing(self.msg_graph, v)
reduced_msgs = []
......@@ -617,8 +592,6 @@ class DGLGraph(DiGraph):
bkt_len = len(v_bkt)
uu, vv = self.msg_graph.in_edges(v_bkt)
in_msg_ids = self.msg_graph.get_edge_id(uu, vv)
# TODO(minjie): manually convert ids to context.
in_msg_ids = F.to_context(in_msg_ids, self.context)
in_msgs = self._msg_frame.select_rows(in_msg_ids)
# Reshape the column tensor to (B, Deg, ...).
def _reshape_fn(msg):
......@@ -641,10 +614,14 @@ class DGLGraph(DiGraph):
self.clear_messages()
# Read the node states in the degree-bucketing order.
reordered_v = F.pack(v_buckets)
reordered_v = utils.toindex(F.pack(
[v_bkt.totensor() for v_bkt in v_buckets]))
# Pack all reduced msgs together
if isinstance(reduced_msgs[0], dict):
all_reduced_msgs = {key : F.pack(val) for key, val in reduced_msgs.items()}
keys = reduced_msgs[0].keys()
all_reduced_msgs = {
key : F.pack([msg[key] for msg in reduced_msgs])
for key in keys}
else:
all_reduced_msgs = F.pack(reduced_msgs)
......@@ -697,6 +674,9 @@ class DGLGraph(DiGraph):
update_func):
if is_all(u) and is_all(v):
u, v = self.cached_graph.edges()
else:
u = utils.toindex(u)
v = utils.toindex(v)
self._nonbatch_sendto(u, v, message_func)
dst = set()
for uu, vv in utils.edge_iter(u, v):
......@@ -713,12 +693,14 @@ class DGLGraph(DiGraph):
self.update_all(message_func, reduce_func, update_func, True)
elif message_func == 'from_src' and reduce_func == 'sum':
# TODO(minjie): check the validity of edges u->v
u = utils.convert_to_id_tensor(u)
v = utils.convert_to_id_tensor(v)
u = utils.toindex(u)
v = utils.toindex(v)
# TODO(minjie): broadcasting is optional for many-one input.
u, v = utils.edge_broadcasting(u, v)
# relabel destination nodes.
new2old, old2new = utils.build_relabel_map(v)
u = u.totensor()
v = v.totensor()
# TODO(minjie): should not directly use []
new_v = old2new[v]
# create adj mat
......@@ -726,8 +708,8 @@ class DGLGraph(DiGraph):
dat = F.ones((len(u),))
n = self.number_of_nodes()
m = len(new2old)
# TODO(minjie): context
adjmat = F.sparse_tensor(idx, dat, [m, n])
adjmat = F.to_context(adjmat, self.context)
# TODO(minjie): use lazy dict for reduced_msgs
reduced_msgs = {}
for key in self._node_frame.schemes:
......@@ -739,10 +721,10 @@ class DGLGraph(DiGraph):
new_node_repr = update_func(node_repr, reduced_msgs)
self.set_n_repr(new_node_repr, new2old)
else:
u = utils.convert_to_id_tensor(u, self.context)
v = utils.convert_to_id_tensor(v, self.context)
u = utils.toindex(u)
v = utils.toindex(v)
self._batch_sendto(u, v, message_func)
unique_v = F.unique(v)
unique_v = F.unique(v.totensor())
self._batch_recv(unique_v, reduce_func, update_func)
def update_to(self,
......@@ -776,15 +758,17 @@ class DGLGraph(DiGraph):
assert reduce_func is not None
assert update_func is not None
if batchable:
v = utils.toindex(v)
uu, vv = self.cached_graph.in_edges(v)
self.update_by_edge(uu, vv, message_func,
reduce_func, update_func, batchable)
self._batch_update_by_edge(uu, vv, message_func,
reduce_func, update_func)
else:
v = utils.toindex(v)
for vv in utils.node_iter(v):
assert vv in self.nodes
uu = list(self.pred[vv])
self.sendto(uu, vv, message_func, batchable)
self.recv(vv, reduce_func, update_func, batchable)
self._nonbatch_sendto(uu, vv, message_func)
self._nonbatch_recv(vv, reduce_func, update_func)
def update_from(self,
u,
......@@ -817,15 +801,17 @@ class DGLGraph(DiGraph):
assert reduce_func is not None
assert update_func is not None
if batchable:
u = utils.toindex(u)
uu, vv = self.cached_graph.out_edges(u)
self.update_by_edge(uu, vv, message_func,
reduce_func, update_func, batchable)
self._batch_update_by_edge(uu, vv, message_func,
reduce_func, update_func)
else:
u = utils.toindex(u)
for uu in utils.node_iter(u):
assert uu in self.nodes
for v in self.succ[uu]:
self.update_by_edge(uu, v,
message_func, reduce_func, update_func, batchable)
self._nonbatch_update_by_edge(uu, v,
message_func, reduce_func, update_func)
def update_all(self,
message_func=None,
......@@ -857,10 +843,10 @@ class DGLGraph(DiGraph):
if batchable:
if message_func == 'from_src' and reduce_func == 'sum':
# TODO(minjie): use lazy dict for reduced_msgs
adjmat = self.cached_graph.adjmat(self.context)
reduced_msgs = {}
for key in self._node_frame.schemes:
col = self._node_frame[key]
adjmat = self.cached_graph.adjmat(F.get_context(col))
reduced_msgs[key] = F.spmm(adjmat, col)
if len(reduced_msgs) == 1 and __REPR__ in reduced_msgs:
reduced_msgs = reduced_msgs[__REPR__]
......@@ -930,23 +916,52 @@ class DGLGraph(DiGraph):
Returns
-------
G : DGLGraph
G : DGLSubGraph
The subgraph.
"""
return dgl.DGLSubGraph(self, nodes)
def copy_from(self, graph):
"""Copy node/edge features from the given graph.
All old features will be removed.
def merge(self, subgraphs, reduce_func='sum'):
"""Merge subgraph features back to this parent graph.
Parameters
----------
graph : DGLGraph
The graph to copy from.
subgraphs : iterator of DGLSubGraph
The subgraphs to be merged.
reduce_func : str
The reduce function (only 'sum' is supported currently)
"""
# TODO
pass
# sanity check: all the subgraphs and the parent graph
# should have the same node/edge feature schemes.
# merge node features
to_merge = []
for sg in subgraphs:
if len(sg.node_attr_schemes()) == 0:
continue
if sg.node_attr_schemes() != self.node_attr_schemes():
raise RuntimeError('Subgraph and parent graph do not '
'have the same node attribute schemes.')
to_merge.append(sg)
self._node_frame = merge_frames(
[sg._node_frame for sg in to_merge],
[sg._parent_nid for sg in to_merge],
self._node_frame.num_rows,
reduce_func)
# merge edge features
to_merge.clear()
for sg in subgraphs:
if len(sg.edge_attr_schemes()) == 0:
continue
if sg.edge_attr_schemes() != self.edge_attr_schemes():
raise RuntimeError('Subgraph and parent graph do not '
'have the same edge attribute schemes.')
to_merge.append(sg)
self._edge_frame = merge_frames(
[sg._edge_frame for sg in to_merge],
[sg._parent_eid for sg in to_merge],
self._edge_frame.num_rows,
reduce_func)
def draw(self):
"""Plot the graph using dot."""
......@@ -996,14 +1011,10 @@ class DGLGraph(DiGraph):
eid : tensor
The tensor contains edge id(s).
"""
u = utils.toindex(u)
v = utils.toindex(v)
return self.cached_graph.get_edge_id(u, v)
def _nodes_or_all(self, nodes):
return self.nodes() if nodes == ALL else nodes
def _edges_or_all(self, edges):
return self.edges() if edges == ALL else edges
def _add_node_callback(self, node):
#print('New node:', node)
self._cached_graph = None
......
"""Schedule policies for graph computation."""
from __future__ import absolute_import
import dgl.backend as F
import numpy as np
import dgl.backend as F
import dgl.utils as utils
def degree_bucketing(cached_graph, v):
degrees = F.asnumpy(cached_graph.in_degrees(v))
"""Create degree bucketing scheduling policy.
Parameters
----------
cached_graph : dgl.cached_graph.CachedGraph
the graph
v : dgl.utils.Index
the nodes to gather messages
Returns
-------
unique_degrees : list of int
list of unique degrees
v_bkt : list of dgl.utils.Index
list of node id buckets; nodes belong to the same bucket have
the same degree
"""
degrees = F.asnumpy(cached_graph.in_degrees(v).totensor())
unique_degrees = list(np.unique(degrees))
v_np = np.array(v.tolist())
v_bkt = []
for deg in unique_degrees:
idx = np.where(degrees == deg)
v_bkt.append(v[idx])
v_bkt.append(utils.Index(v_np[idx]))
return unique_degrees, v_bkt
......@@ -13,46 +13,30 @@ class DGLSubGraph(DGLGraph):
def __init__(self,
parent,
nodes):
# create subgraph and relabel
nx_sg = nx.DiGraph.subgraph(parent, nodes)
# node id
# TODO(minjie): context
nid = F.tensor(nodes, dtype=F.int64)
# edge id
# TODO(minjie): slow, context
u, v = zip(*nx_sg.edges)
u = list(u)
v = list(v)
eid = parent.cached_graph.get_edge_id(u, v)
# relabel
super(DGLSubGraph, self).__init__()
# relabel nodes
self._node_mapping = utils.build_relabel_dict(nodes)
nx_sg = nx.relabel.relabel_nodes(nx_sg, self._node_mapping)
self._parent_nid = utils.toindex(nodes)
eids = []
# create subgraph
for eid, (u, v) in enumerate(parent.edge_list):
if u in self._node_mapping and v in self._node_mapping:
self.add_edge(self._node_mapping[u],
self._node_mapping[v])
eids.append(eid)
self._parent_eid = utils.toindex(eids)
def copy_from(self, parent):
"""Copy node/edge features from the parent graph.
All old features will be removed.
# init
self._edge_list = []
nx_init(self,
self._add_node_callback,
self._add_edge_callback,
self._del_node_callback,
self._del_edge_callback,
nx_sg,
**parent.graph)
# cached graph and storage
self._cached_graph = None
if parent._node_frame.num_rows == 0:
self._node_frame = FrameRef()
else:
self._node_frame = FrameRef(Frame(parent._node_frame[nid]))
if parent._edge_frame.num_rows == 0:
self._edge_frame = FrameRef()
else:
self._edge_frame = FrameRef(Frame(parent._edge_frame[eid]))
# other class members
self._msg_graph = None
self._msg_frame = FrameRef()
self._message_func = parent._message_func
self._reduce_func = parent._reduce_func
self._update_func = parent._update_func
self._edge_func = parent._edge_func
self._context = parent._context
Parameters
----------
parent : DGLGraph
The parent graph to copy from.
"""
if parent._node_frame.num_rows != 0:
self._node_frame = FrameRef(Frame(parent._node_frame[self._parent_nid]))
if parent._edge_frame.num_rows != 0:
self._edge_frame = FrameRef(Frame(parent._edge_frame[self._parent_eid]))
......@@ -2,6 +2,9 @@
from __future__ import absolute_import
from collections import Mapping
from functools import wraps
import numpy as np
import dgl.backend as F
from dgl.backend import Tensor, SparseTensor
......@@ -11,18 +14,77 @@ def is_id_tensor(u):
def is_id_container(u):
"""Return whether the input is a supported id container."""
return isinstance(u, list)
return (getattr(u, '__iter__', None) is not None
and getattr(u, '__len__', None) is not None)
class Index(object):
    """Index class that can be easily converted to list/tensor.

    Wraps node/edge id data given as a tensor, an iterable container, or a
    single integer, and lazily materializes list and tensor views of it.
    """
    def __init__(self, data):
        # Lazily-built list / tensor views of the same id data.
        self._list_data = None
        self._tensor_data = None
        # Per-context cache of the tensor copy (keyed by context object;
        # presumably dgl.context.Context, which defines __eq__/__hash__).
        self._ctx_data = dict()
        self._dispatch(data)

    def _dispatch(self, data):
        """Store ``data`` as a tensor, a list, or a single-element id list.

        Raises
        ------
        TypeError
            If ``data`` is none of the supported kinds.
        """
        if is_id_tensor(data):
            self._tensor_data = data
        elif is_id_container(data):
            self._list_data = data
        else:
            try:
                self._list_data = [int(data)]
            except (TypeError, ValueError):
                # BUG FIX: the original formatted the message with the
                # undefined name ``x``, so this path raised NameError
                # instead of the intended TypeError.
                raise TypeError('Error index data: %s' % str(data))

    def tolist(self):
        """Return the ids as a python list (computed once, then cached)."""
        if self._list_data is None:
            self._list_data = list(F.asnumpy(self._tensor_data))
        return self._list_data

    def totensor(self, ctx=None):
        """Return the ids as an int64 tensor, optionally copied to ``ctx``."""
        if self._tensor_data is None:
            self._tensor_data = F.tensor(self._list_data, dtype=F.int64)
        if ctx is None:
            return self._tensor_data
        if ctx not in self._ctx_data:
            self._ctx_data[ctx] = F.to_context(self._tensor_data, ctx)
        return self._ctx_data[ctx]

    def __iter__(self):
        return iter(self.tolist())

    def __len__(self):
        if self._list_data is not None:
            return len(self._list_data)
        else:
            return len(self._tensor_data)

    def __getitem__(self, i):
        return self.tolist()[i]
def toindex(x):
    """Wrap ``x`` in an Index; returns ``x`` unchanged if it already is one."""
    if isinstance(x, Index):
        return x
    return Index(x)
def node_iter(n):
"""Return an iterator that loops over the given nodes."""
n = convert_to_id_container(n)
for nn in n:
yield nn
"""Return an iterator that loops over the given nodes.
Parameters
----------
n : iterable
The node ids.
"""
return iter(n)
def edge_iter(u, v):
"""Return an iterator that loops over the given edges."""
u = convert_to_id_container(u)
v = convert_to_id_container(v)
"""Return an iterator that loops over the given edges.
Parameters
----------
u : iterable
The src ids.
v : iterable
The dst ids.
"""
if len(u) == len(v):
# many-many
for uu, vv in zip(u, v):
......@@ -38,8 +100,33 @@ def edge_iter(u, v):
else:
raise ValueError('Error edges:', u, v)
def edge_broadcasting(u, v):
    """Convert one-many and many-one edges to many-many.

    Parameters
    ----------
    u : Index
        The src id(s)
    v : Index
        The dst id(s)

    Returns
    -------
    uu : Index
        The src id(s) after broadcasting
    vv : Index
        The dst id(s) after broadcasting
    """
    # Already many-many: nothing to do.
    if len(u) == len(v):
        return u, v
    if len(u) == 1:
        # one-many: replicate the single src id.
        u = toindex(F.broadcast_to(u.totensor(), v.totensor()))
    elif len(v) == 1:
        # many-one: replicate the single dst id.
        v = toindex(F.broadcast_to(v.totensor(), u.totensor()))
    else:
        # Mismatched lengths with neither side singular is an error.
        assert len(u) == len(v)
    return u, v
'''
def convert_to_id_container(x):
"""Convert the input to id container."""
if is_id_container(x):
return x
elif is_id_tensor(x):
......@@ -52,7 +139,6 @@ def convert_to_id_container(x):
return None
def convert_to_id_tensor(x, ctx=None):
"""Convert the input to id tensor."""
if is_id_container(x):
ret = F.tensor(x, dtype=F.int64)
elif is_id_tensor(x):
......@@ -64,6 +150,7 @@ def convert_to_id_tensor(x, ctx=None):
raise TypeError('Error node: %s' % str(x))
ret = F.to_context(ret, ctx)
return ret
'''
class LazyDict(Mapping):
"""A readonly dictionary that does not materialize the storage."""
......@@ -110,7 +197,7 @@ def build_relabel_map(x):
Parameters
----------
x : int, tensor or container
x : Index
The input ids.
Returns
......@@ -122,7 +209,7 @@ def build_relabel_map(x):
One can use advanced indexing to convert an old id tensor to a
new id tensor: new_id = old_to_new[old_id]
"""
x = convert_to_id_tensor(x)
x = x.totensor()
unique_x, _ = F.sort(F.unique(x))
map_len = int(F.max(unique_x)) + 1
old_to_new = F.zeros(map_len, dtype=F.int64)
......@@ -150,12 +237,55 @@ def build_relabel_dict(x):
relabel_dict[v] = i
return relabel_dict
def edge_broadcasting(u, v):
"""Convert one-many and many-one edges to many-many."""
if len(u) != len(v) and len(u) == 1:
u = F.broadcast_to(u, v)
elif len(u) != len(v) and len(v) == 1:
v = F.broadcast_to(v, u)
class CtxCachedObject(object):
    """A wrapper to cache objects generated per context.

    Note: such a wrapper may incur significant overhead if the wrapped
    object is very light.

    Parameters
    ----------
    generator : callable
        A callable that creates the object given ctx as its only argument.
    """
    def __init__(self, generator):
        self._generator = generator
        self._ctx_dict = {}

    def get(self, ctx):
        """Return the object for ctx, generating and caching it on first use."""
        # Idiomatic membership test (was ``not ctx in ...``).
        if ctx not in self._ctx_dict:
            self._ctx_dict[ctx] = self._generator(ctx)
        return self._ctx_dict[ctx]
def ctx_cached_member(func):
    """Class-member decorator that caches the result per context.

    The wrapped function must take exactly two arguments: ``self`` and
    ``ctx``. If the object is frozen (its ``_freeze`` member is truthy),
    the result is cached in a CtxCachedObject stored on a field named
    '_CACHED_' + the function name; otherwise the function is called
    directly every time.
    """
    cache_name = '_CACHED_' + func.__name__
    @wraps(func)
    def wrapper(self, ctx):
        if self._freeze:
            # Lazily create the per-context cache on first use.
            # (Stray lines from a deleted function were removed here.)
            if getattr(self, cache_name, None) is None:
                bind_func = lambda _ctx: func(self, _ctx)
                setattr(self, cache_name, CtxCachedObject(bind_func))
            return getattr(self, cache_name).get(ctx)
        else:
            return func(self, ctx)
    return wrapper
def cached_member(func):
    """Class-member decorator that caches the function result.

    The wrapped function must take only ``self``. When the object is
    frozen (``self._freeze`` is truthy) the result is memoized in a
    field named '_CACHED_' + the function name; otherwise the function
    runs on every call.
    """
    cache_name = '_CACHED_' + func.__name__
    @wraps(func)
    def wrapper(self):
        if not self._freeze:
            # Not frozen: no caching at all.
            return func(self)
        cached = getattr(self, cache_name, None)
        if cached is None:
            cached = func(self)
            setattr(self, cache_name, cached)
        return cached
    return wrapper
......@@ -26,6 +26,17 @@ def update_func(node, accum):
assert node['h'].shape == accum.shape
return {'h' : node['h'] + accum}
def reduce_dict_func(node, msgs):
    """Reduce function returning a dict: sums messages over the neighbor axis.

    Records the incoming message shape in the module-level
    ``reduce_msg_shapes`` set for later assertions.
    """
    m = msgs['m']
    reduce_msg_shapes.add(tuple(m.shape))
    assert len(m.shape) == 3
    assert m.shape[2] == D
    return {'m' : th.sum(m, 1)}
def update_dict_func(node, accum):
    """Update function: add the reduced message 'm' onto node feature 'h'."""
    h = node['h']
    m = accum['m']
    assert h.shape == m.shape
    return {'h' : h + m}
def generate_graph(grad=False):
g = DGLGraph()
for i in range(10):
......@@ -149,7 +160,8 @@ def test_batch_send():
v = th.tensor([9])
g.sendto(u, v)
def test_batch_recv():
def test_batch_recv1():
# basic recv test
g = generate_graph()
g.register_message_func(message_func, batchable=True)
g.register_reduce_func(reduce_func, batchable=True)
......@@ -162,6 +174,20 @@ def test_batch_recv():
assert(reduce_msg_shapes == {(1, 3, D), (3, 1, D)})
reduce_msg_shapes.clear()
def test_batch_recv2():
    """Batched recv where the reduce function returns a dict-typed message."""
    g = generate_graph()
    g.register_message_func(message_func, batchable=True)
    g.register_reduce_func(reduce_dict_func, batchable=True)
    g.register_update_func(update_dict_func, batchable=True)
    src = th.tensor([0, 0, 0, 4, 5, 6])
    dst = th.tensor([1, 2, 3, 9, 9, 9])
    reduce_msg_shapes.clear()
    g.sendto(src, dst)
    g.recv(th.unique(dst))
    # Degree bucketing: one node with 3 msgs, three nodes with 1 msg each.
    assert(reduce_msg_shapes == {(1, 3, D), (3, 1, D)})
    reduce_msg_shapes.clear()
def test_update_routines():
g = generate_graph()
g.register_message_func(message_func, batchable=True)
......@@ -210,6 +236,7 @@ if __name__ == '__main__':
test_batch_setter_getter()
test_batch_setter_autograd()
test_batch_send()
test_batch_recv()
test_batch_recv1()
test_batch_recv2()
test_update_routines()
#test_delete()
......@@ -3,6 +3,7 @@ import numpy as np
import networkx as nx
from dgl import DGLGraph
from dgl.cached_graph import *
from dgl.utils import Index
def check_eq(a, b):
assert a.shape == b.shape
......@@ -15,22 +16,18 @@ def test_basics():
g.add_edge(1, 3)
g.add_edge(2, 4)
g.add_edge(2, 5)
g.add_edge(0, 2)
cg = create_cached_graph(g)
u = th.tensor([0, 1, 1, 2, 2])
v = th.tensor([1, 2, 3, 4, 5])
check_eq(cg.get_edge_id(u, v), th.tensor([0, 1, 2, 3, 4]))
cg.add_edges(0, 2)
assert cg.get_edge_id(0, 2) == 5
query = th.tensor([1, 2])
u = Index(th.tensor([0, 0, 1, 1, 2, 2]))
v = Index(th.tensor([1, 2, 2, 3, 4, 5]))
check_eq(cg.get_edge_id(u, v).totensor(), th.tensor([0, 5, 1, 2, 3, 4]))
query = Index(th.tensor([1, 2]))
s, d = cg.in_edges(query)
check_eq(s, th.tensor([0, 0, 1]))
check_eq(d, th.tensor([1, 2, 2]))
check_eq(s.totensor(), th.tensor([0, 0, 1]))
check_eq(d.totensor(), th.tensor([1, 2, 2]))
s, d = cg.out_edges(query)
check_eq(s, th.tensor([1, 1, 2, 2]))
check_eq(d, th.tensor([2, 3, 4, 5]))
print(cg._graph.get_adjacency())
print(cg._graph.get_adjacency(eids=True))
check_eq(s.totensor(), th.tensor([1, 1, 2, 2]))
check_eq(d.totensor(), th.tensor([2, 3, 4, 5]))
if __name__ == '__main__':
test_basics()
......@@ -2,6 +2,7 @@ import torch as th
from torch.autograd import Variable
import numpy as np
from dgl.frame import Frame, FrameRef
from dgl.utils import Index
N = 10
D = 5
......@@ -112,7 +113,7 @@ def test_append2():
assert not f.is_span_whole_column()
assert f.num_rows == 3 * N
new_idx = list(range(N)) + list(range(2*N, 4*N))
assert check_eq(f.index_tensor(), th.tensor(new_idx))
assert check_eq(f.index().totensor(), th.tensor(new_idx))
assert data.num_rows == 4 * N
def test_row1():
......@@ -122,20 +123,20 @@ def test_row1():
# getter
# test non-duplicate keys
rowid = th.tensor([0, 2])
rowid = Index(th.tensor([0, 2]))
rows = f[rowid]
for k, v in rows.items():
assert v.shape == (len(rowid), D)
assert check_eq(v, data[k][rowid])
# test duplicate keys
rowid = th.tensor([8, 2, 2, 1])
rowid = Index(th.tensor([8, 2, 2, 1]))
rows = f[rowid]
for k, v in rows.items():
assert v.shape == (len(rowid), D)
assert check_eq(v, data[k][rowid])
# setter
rowid = th.tensor([0, 2, 4])
rowid = Index(th.tensor([0, 2, 4]))
vals = {'a1' : th.zeros((len(rowid), D)),
'a2' : th.zeros((len(rowid), D)),
'a3' : th.zeros((len(rowid), D)),
......@@ -152,13 +153,13 @@ def test_row2():
# getter
c1 = f['a1']
# test non-duplicate keys
rowid = th.tensor([0, 2])
rowid = Index(th.tensor([0, 2]))
rows = f[rowid]
rows['a1'].backward(th.ones((len(rowid), D)))
assert check_eq(c1.grad[:,0], th.tensor([1., 0., 1., 0., 0., 0., 0., 0., 0., 0.]))
c1.grad.data.zero_()
# test duplicate keys
rowid = th.tensor([8, 2, 2, 1])
rowid = Index(th.tensor([8, 2, 2, 1]))
rows = f[rowid]
rows['a1'].backward(th.ones((len(rowid), D)))
assert check_eq(c1.grad[:,0], th.tensor([0., 1., 2., 0., 0., 0., 0., 0., 1., 0.]))
......@@ -166,7 +167,7 @@ def test_row2():
# setter
c1 = f['a1']
rowid = th.tensor([0, 2, 4])
rowid = Index(th.tensor([0, 2, 4]))
vals = {'a1' : Variable(th.zeros((len(rowid), D)), requires_grad=True),
'a2' : Variable(th.zeros((len(rowid), D)), requires_grad=True),
'a3' : Variable(th.zeros((len(rowid), D)), requires_grad=True),
......@@ -210,14 +211,14 @@ def test_sharing():
f2_a1 = f2['a1']
# test write
# update own ref should not been seen by the other.
f1[th.tensor([0, 1])] = {
f1[Index(th.tensor([0, 1]))] = {
'a1' : th.zeros([2, D]),
'a2' : th.zeros([2, D]),
'a3' : th.zeros([2, D]),
}
assert check_eq(f2['a1'], f2_a1)
# update shared space should been seen by the other.
f1[th.tensor([2, 3])] = {
f1[Index(th.tensor([2, 3]))] = {
'a1' : th.ones([2, D]),
'a2' : th.ones([2, D]),
'a3' : th.ones([2, D]),
......
......@@ -24,35 +24,71 @@ def generate_graph(grad=False):
g.set_e_repr({'l' : ecol})
return g
def test_basics():
    """Subgraph creation and explicit feature copy from the parent graph."""
    g = generate_graph()
    h = g.get_n_repr()['h']
    l = g.get_e_repr()['l']
    nid = [0, 2, 3, 6, 7, 9]
    # Parent-graph edge ids induced by the node set above.
    eid = [2, 3, 4, 5, 10, 11, 12, 13, 16]
    sg = g.subgraph(nid)
    # the subgraph is empty initially
    assert len(sg.get_n_repr()) == 0
    assert len(sg.get_e_repr()) == 0
    # the data is copied after explicit copy from
    sg.copy_from(g)
    assert len(sg.get_n_repr()) == 1
    assert len(sg.get_e_repr()) == 1
    sh = sg.get_n_repr()['h']
    assert check_eq(h[nid], sh)
    assert check_eq(l[eid], sg.get_e_repr()['l'])
    # update the node/edge features on the subgraph should NOT
    # reflect to the parent graph.
    sg.set_n_repr({'h' : th.zeros((6, D))})
    assert check_eq(h, g.get_n_repr()['h'])
def test_merge():
    """Merging features from overlapping subgraphs back into the parent."""
    g = generate_graph()
    g.set_n_repr({'h' : th.zeros((10, D))})
    g.set_e_repr({'l' : th.zeros((17, D))})
    # Three (partially overlapping) subgraphs with distinct feature values.
    sub1 = g.subgraph([0, 2, 3, 6, 7, 9])
    sub1.set_n_repr({'h' : th.ones((6, D))})
    sub1.set_e_repr({'l' : th.ones((9, D))})
    sub2 = g.subgraph([0, 2, 3, 4])
    sub2.set_n_repr({'h' : th.ones((4, D)) * 2})
    sub3 = g.subgraph([5, 6, 7, 8, 9])
    sub3.set_e_repr({'l' : th.ones((4, D)) * 3})
    g.merge([sub1, sub2, sub3])
    merged_h = g.get_n_repr()['h'][:,0]
    merged_l = g.get_e_repr()['l'][:,0]
    assert check_eq(merged_h, th.tensor([3., 0., 3., 3., 2., 0., 1., 1., 0., 1.]))
    assert check_eq(merged_l,
            th.tensor([0., 0., 1., 1., 1., 1., 0., 0., 0., 3., 1., 4., 1., 4., 0., 3., 1.]))
if __name__ == '__main__':
    # The diff-merged original also called the renamed test_subgraph().
    test_basics()
    test_merge()
......@@ -6,6 +6,12 @@ def message_func(src, edge):
def update_func(node, accum):
return {'h' : node['h'] + accum}
def message_dict_func(src, edge):
    """Message function returning the source feature 'h' under key 'm'."""
    feat = src['h']
    return {'m' : feat}
def update_dict_func(node, accum):
    """Update function adding the reduced message 'm' onto feature 'h'."""
    new_h = node['h'] + accum['m']
    return {'h' : new_h}
def generate_graph():
g = DGLGraph()
for i in range(10):
......@@ -23,12 +29,18 @@ def check(g, h):
h = [str(x) for x in h]
assert nh == h, "nh=[%s], h=[%s]" % (' '.join(nh), ' '.join(h))
def test_sendrecv():
g = generate_graph()
check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
def register1(g):
    # Register the plain (non-dict) message/update functions together with
    # the builtin 'sum' reducer on graph g.
    g.register_message_func(message_func)
    g.register_update_func(update_func)
    g.register_reduce_func('sum')
def register2(g):
    # Register the dict-returning message/update functions together with
    # the builtin 'sum' reducer on graph g.
    g.register_message_func(message_dict_func)
    g.register_update_func(update_dict_func)
    g.register_reduce_func('sum')
def _test_sendrecv(g):
check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
g.sendto(0, 1)
g.recv(1)
check(g, [1, 3, 3, 4, 5, 6, 7, 8, 9, 10])
......@@ -37,12 +49,8 @@ def test_sendrecv():
g.recv(9)
check(g, [1, 3, 3, 4, 5, 6, 7, 8, 9, 23])
def test_multi_sendrecv():
g = generate_graph()
def _test_multi_sendrecv(g):
check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
g.register_message_func(message_func)
g.register_update_func(update_func)
g.register_reduce_func('sum')
# one-many
g.sendto(0, [1, 2, 3])
g.recv([1, 2, 3])
......@@ -56,12 +64,8 @@ def test_multi_sendrecv():
g.recv([4, 5, 9])
check(g, [1, 3, 4, 5, 6, 7, 7, 8, 9, 45])
def test_update_routines():
g = generate_graph()
def _test_update_routines(g):
check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
g.register_message_func(message_func)
g.register_update_func(update_func)
g.register_reduce_func('sum')
g.update_by_edge(0, 1)
check(g, [1, 3, 3, 4, 5, 6, 7, 8, 9, 10])
g.update_to(9)
......@@ -71,6 +75,30 @@ def test_update_routines():
g.update_all()
check(g, [56, 5, 5, 6, 7, 8, 9, 10, 11, 108])
def test_sendrecv():
    """Run the send/recv checks under both function registrations."""
    for register in (register1, register2):
        g = generate_graph()
        register(g)
        _test_sendrecv(g)
def test_multi_sendrecv():
    """Run the multi send/recv checks under both function registrations."""
    for register in (register1, register2):
        g = generate_graph()
        register(g)
        _test_multi_sendrecv(g)
def test_update_routines():
    """Run the update-routine checks under both function registrations."""
    for register in (register1, register2):
        g = generate_graph()
        register(g)
        _test_update_routines(g)
if __name__ == '__main__':
test_sendrecv()
test_multi_sendrecv()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment