Unverified Commit 61fa3c6c authored by Minjie Wang's avatar Minjie Wang Committed by GitHub
Browse files

Builtin function and API changes (#53)

* WIP: API renaming

* API rewrite and node function refactor

* builtin functions

* builtin functions tested

* fix test

* send and recv spmv test

* WIP: fix examples

* Fix examples using new APIs
parent 8c71f3f8
......@@ -4,8 +4,12 @@ from __future__ import absolute_import
import numpy as np
import dgl.backend as F
import dgl.function.message as fmsg
import dgl.function.reducer as fred
import dgl.utils as utils
__all__ = ["degree_bucketing", "get_executor"]
def degree_bucketing(cached_graph, v):
"""Create degree bucketing scheduling policy.
......@@ -33,3 +37,181 @@ def degree_bucketing(cached_graph, v):
v_bkt.append(utils.Index(v_np[idx]))
#print('degree-bucketing:', unique_degrees, [len(b) for b in v_bkt])
return unique_degrees, v_bkt
class Executor(object):
    """Base class for specialized execution plans produced by the scheduler."""

    def run(self, graph):
        # Subclasses must override this with the actual computation.
        # NOTE(review): the concrete executors below implement run(self) with
        # no `graph` argument and read self.graph instead — this parameter
        # looks stale; confirm before relying on it.
        raise NotImplementedError
class UpdateAllSPMVExecutor(Executor):
    """Execute ``update_all`` as a single sparse-dense matmul (spmm).

    When ``use_adj`` is True the graph's cached adjacency matrix is used
    directly; otherwise a new sparse matrix is assembled from the adjacency
    indices re-weighted with per-edge scalar data.
    """

    def __init__(self, graph, src_field, dst_field, edge_field, use_adj):
        self.graph = graph
        self.src_field = src_field      # node field to read, or None for whole repr
        self.dst_field = dst_field      # node field to write, or None for whole repr
        self.edge_field = edge_field    # edge field for weights (only if not use_adj)
        self.use_adj = use_adj          # True: plain adjacency, False: edge-weighted

    def _build_adjmat(self, ctx):
        # Either the structural adjacency, or one re-weighted by edge data.
        adj = self.graph.cached_graph.adjmat().get(ctx)
        if self.use_adj:
            return adj
        edata = self.graph.get_e_repr()
        if self.edge_field is not None:
            edata = edata[self.edge_field]
        dat = F.squeeze(edata)
        # TODO(minjie): should not directly use _indices
        idx = adj._indices()
        num_nodes = self.graph.number_of_nodes()
        return F.sparse_tensor(idx, dat, [num_nodes, num_nodes])

    def run(self):
        g = self.graph
        nrepr = g.get_n_repr()
        srccol = nrepr if self.src_field is None else nrepr[self.src_field]
        ctx = F.get_context(srccol)
        adjmat = self._build_adjmat(ctx)
        # spmm; 1-D features are promoted to a column vector and squeezed back.
        if len(F.shape(srccol)) == 1:
            dstcol = F.squeeze(F.spmm(adjmat, F.unsqueeze(srccol, 1)))
        else:
            dstcol = F.spmm(adjmat, srccol)
        if self.dst_field is None:
            g.set_n_repr(dstcol)
        else:
            g.set_n_repr({self.dst_field : dstcol})
class SendRecvSPMVExecutor(Executor):
    """Execute ``send_and_recv`` over an edge set via spmm.

    Builds a (num_receivers x num_nodes) sparse matrix from the given
    src/dst edge list — relabeling receiver ids to a compact range — and
    multiplies it with the source feature column.
    """

    def __init__(self, graph, src, dst, src_field, dst_field, edge_field, use_edge_dat):
        self.graph = graph
        self.src = src                      # source node index (dgl.utils.Index)
        self.dst = dst                      # destination node index
        self.src_field = src_field          # node field to read, or None
        self.dst_field = dst_field          # node field to write, or None
        self.edge_field = edge_field        # edge field for weights, or None
        self.use_edge_dat = use_edge_dat    # False: unit weights

    def _source_column(self):
        # Source node features; the whole repr when no field name was given.
        nrepr = self.graph.get_n_repr()
        return nrepr if self.src_field is None else nrepr[self.src_field]

    def _edge_values(self, u, v):
        # Per-edge scalar weights, or all-ones when edge data is unused.
        if not self.use_edge_dat:
            return F.ones((len(u),))
        edata = self.graph.get_e_repr(u, v)
        if self.edge_field is not None:
            edata = edata[self.edge_field]
        return F.squeeze(edata)

    def run(self):
        g = self.graph
        srccol = self._source_column()
        ctx = F.get_context(srccol)
        # Broadcast src/dst to equal-length edge endpoint lists.
        u, v = utils.edge_broadcasting(self.src, self.dst)
        dat = self._edge_values(u, v)
        # Relabel receiver ids into a dense [0, m) range for the matrix rows.
        new2old, old2new = utils.build_relabel_map(v)
        u_tensor = u.totensor()
        v_tensor = v.totensor()
        # TODO(minjie): should not directly use []
        row_ids = old2new[v_tensor]
        idx = F.pack([F.unsqueeze(row_ids, 0), F.unsqueeze(u_tensor, 0)])
        adjmat = F.sparse_tensor(idx, dat, [len(new2old), g.number_of_nodes()])
        adjmat = F.to_context(adjmat, ctx)
        # spmm; 1-D features are promoted to a column vector and squeezed back.
        if len(F.shape(srccol)) == 1:
            dstcol = F.squeeze(F.spmm(adjmat, F.unsqueeze(srccol, 1)))
        else:
            dstcol = F.spmm(adjmat, srccol)
        # Write results only to the receiver nodes (new2old maps rows back).
        if self.dst_field is None:
            g.set_n_repr(dstcol, new2old)
        else:
            g.set_n_repr({self.dst_field : dstcol}, new2old)
def _is_spmv_supported_node_feat(g, field):
    """Return True when the node feature (whole repr if *field* is None)
    is 1-D or 2-D, i.e. usable as the dense operand of spmm."""
    nrepr = g.get_n_repr()
    feat = nrepr if field is None else nrepr[field]
    ndim = len(F.shape(feat))
    return ndim == 1 or ndim == 2
def _is_spmv_supported_edge_feat(g, field):
    """Return True when the edge feature is scalar per edge.

    Only shape (E,) or (E, 1) can be optimized at the moment, since the
    values become entries of a sparse matrix.
    """
    erepr = g.get_e_repr()
    feat = erepr if field is None else erepr[field]
    shape = F.shape(feat)
    if len(shape) == 1:
        return True
    return len(shape) == 2 and shape[1] == 1
def _create_update_all_exec(graph, **kwargs):
    """Create a specialized executor for ``update_all``, or return None
    when the (message_func, reduce_func) pair cannot be mapped to spmm."""
    mfunc = kwargs.pop('message_func')
    rfunc = kwargs.pop('reduce_func')
    # Only summation reducers can be expressed as a matmul.
    if not isinstance(rfunc, fred.SumReducerFunction):
        return None
    if (isinstance(mfunc, fmsg.CopySrcMessageFunction)
            and _is_spmv_supported_node_feat(graph, mfunc.src_field)):
        # copy_src + sum == adjacency @ features
        # TODO(minjie): more sanity check on field names
        return UpdateAllSPMVExecutor(graph,
                                     src_field=mfunc.src_field,
                                     dst_field=rfunc.out_field,
                                     edge_field=None,
                                     use_adj=True)
    if (isinstance(mfunc, fmsg.SrcMulEdgeMessageFunction)
            and _is_spmv_supported_node_feat(graph, mfunc.src_field)
            and _is_spmv_supported_edge_feat(graph, mfunc.edge_field)):
        # src_mul_edge + sum == edge-weighted adjacency @ features
        return UpdateAllSPMVExecutor(graph,
                                     src_field=mfunc.src_field,
                                     dst_field=rfunc.out_field,
                                     edge_field=mfunc.edge_field,
                                     use_adj=False)
    # copy_edge + sum (and anything else) falls back to the generic path.
    return None
def _create_send_and_recv_exec(graph, **kwargs):
    """Create a specialized executor for ``send_and_recv``, or return None
    when the (message_func, reduce_func) pair cannot be mapped to spmm."""
    src = kwargs.pop('src')
    dst = kwargs.pop('dst')
    mfunc = kwargs.pop('message_func')
    rfunc = kwargs.pop('reduce_func')
    # Only summation reducers can be expressed as a matmul.
    if not isinstance(rfunc, fred.SumReducerFunction):
        return None
    if (isinstance(mfunc, fmsg.CopySrcMessageFunction)
            and _is_spmv_supported_node_feat(graph, mfunc.src_field)):
        # copy_src + sum: unit-weight incidence matrix.
        # TODO(minjie): more sanity check on field names
        return SendRecvSPMVExecutor(graph,
                                    src=src,
                                    dst=dst,
                                    src_field=mfunc.src_field,
                                    dst_field=rfunc.out_field,
                                    edge_field=None,
                                    use_edge_dat=False)
    if (isinstance(mfunc, fmsg.SrcMulEdgeMessageFunction)
            and _is_spmv_supported_node_feat(graph, mfunc.src_field)
            and _is_spmv_supported_edge_feat(graph, mfunc.edge_field)):
        # src_mul_edge + sum: edge-data-weighted incidence matrix.
        return SendRecvSPMVExecutor(graph,
                                    src=src,
                                    dst=dst,
                                    src_field=mfunc.src_field,
                                    dst_field=rfunc.out_field,
                                    edge_field=mfunc.edge_field,
                                    use_edge_dat=True)
    return None
def get_executor(call_type, graph, **kwargs):
if call_type == "update_all":
return _create_update_all_exec(graph, **kwargs)
elif call_type == "send_and_recv":
return _create_send_and_recv_exec(graph, **kwargs)
else:
return None
......@@ -303,3 +303,19 @@ def pack2(a, b):
return {k: F.pack([a[k], b[k]]) for k in a}
else:
return F.pack([a, b])
def reorder(dict_like, index):
    """Reorder each column in the dict according to the index.

    Parameters
    ----------
    dict_like : dict of tensors
        The dict to be reordered.
    index : dgl.utils.Index
        The reorder index.

    Returns
    -------
    dict of tensors
        A new dict with every value gathered row-wise by *index*.
    """
    return {
        key: F.gather_row(val, index.totensor(F.get_context(val)))
        for key, val in dict_like.items()
    }
......@@ -16,26 +16,14 @@ def message_func(src, edge):
return {'m' : src['h']}
def reduce_func(node, msgs):
msgs = msgs['m']
reduce_msg_shapes.add(tuple(msgs.shape))
assert len(msgs.shape) == 3
assert msgs.shape[2] == D
return th.sum(msgs, 1)
def update_func(node, accum):
assert node['h'].shape == accum.shape
return {'h' : node['h'] + accum}
def reduce_dict_func(node, msgs):
msgs = msgs['m']
reduce_msg_shapes.add(tuple(msgs.shape))
assert len(msgs.shape) == 3
assert msgs.shape[2] == D
return {'m' : th.sum(msgs, 1)}
def update_dict_func(node, accum):
assert node['h'].shape == accum['m'].shape
return {'h' : node['h'] + accum['m']}
def apply_node_func(node):
return {'h' : node['h'] + node['m']}
def generate_graph(grad=False):
g = DGLGraph()
......@@ -147,43 +135,29 @@ def test_batch_send():
assert src['h'].shape == (5, D)
return {'m' : src['h']}
g.register_message_func(_fmsg, batchable=True)
# many-many sendto
# many-many send
u = th.tensor([0, 0, 0, 0, 0])
v = th.tensor([1, 2, 3, 4, 5])
g.sendto(u, v)
# one-many sendto
g.send(u, v)
# one-many send
u = th.tensor([0])
v = th.tensor([1, 2, 3, 4, 5])
g.sendto(u, v)
# many-one sendto
g.send(u, v)
# many-one send
u = th.tensor([1, 2, 3, 4, 5])
v = th.tensor([9])
g.sendto(u, v)
g.send(u, v)
def test_batch_recv1():
def test_batch_recv():
# basic recv test
g = generate_graph()
g.register_message_func(message_func, batchable=True)
g.register_reduce_func(reduce_func, batchable=True)
g.register_update_func(update_func, batchable=True)
u = th.tensor([0, 0, 0, 4, 5, 6])
v = th.tensor([1, 2, 3, 9, 9, 9])
reduce_msg_shapes.clear()
g.sendto(u, v)
g.recv(th.unique(v))
assert(reduce_msg_shapes == {(1, 3, D), (3, 1, D)})
reduce_msg_shapes.clear()
def test_batch_recv2():
# recv test with dict type reduce message
g = generate_graph()
g.register_message_func(message_func, batchable=True)
g.register_reduce_func(reduce_dict_func, batchable=True)
g.register_update_func(update_dict_func, batchable=True)
g.register_apply_node_func(apply_node_func, batchable=True)
u = th.tensor([0, 0, 0, 4, 5, 6])
v = th.tensor([1, 2, 3, 9, 9, 9])
reduce_msg_shapes.clear()
g.sendto(u, v)
g.send(u, v)
g.recv(th.unique(v))
assert(reduce_msg_shapes == {(1, 3, D), (3, 1, D)})
reduce_msg_shapes.clear()
......@@ -192,27 +166,27 @@ def test_update_routines():
g = generate_graph()
g.register_message_func(message_func, batchable=True)
g.register_reduce_func(reduce_func, batchable=True)
g.register_update_func(update_func, batchable=True)
g.register_apply_node_func(apply_node_func, batchable=True)
# update_by_edge
# send_and_recv
reduce_msg_shapes.clear()
u = th.tensor([0, 0, 0, 4, 5, 6])
v = th.tensor([1, 2, 3, 9, 9, 9])
g.update_by_edge(u, v)
g.send_and_recv(u, v)
assert(reduce_msg_shapes == {(1, 3, D), (3, 1, D)})
reduce_msg_shapes.clear()
# update_to
# pull
v = th.tensor([1, 2, 3, 9])
reduce_msg_shapes.clear()
g.update_to(v)
g.pull(v)
assert(reduce_msg_shapes == {(1, 8, D), (3, 1, D)})
reduce_msg_shapes.clear()
# update_from
# push
v = th.tensor([0, 1, 2, 3])
reduce_msg_shapes.clear()
g.update_from(v)
g.push(v)
assert(reduce_msg_shapes == {(1, 3, D), (8, 1, D)})
reduce_msg_shapes.clear()
......@@ -233,24 +207,16 @@ def test_reduce_0deg():
return src
def _reduce(node, msgs):
assert msgs is not None
return msgs.sum(1)
def _update(node, accum):
if node.shape[0] == 4:
assert accum is None
return node
else:
assert accum is not None
return node + accum
return node + msgs.sum(1)
old_repr = th.randn(5, 5)
g.set_n_repr(old_repr)
g.update_all(_message, _reduce, _update, True)
g.update_all(_message, _reduce, batchable=True)
new_repr = g.get_n_repr()
assert th.allclose(new_repr[1:], old_repr[1:])
assert th.allclose(new_repr[0], old_repr.sum(0))
def test_update_to_0deg():
def test_pull_0deg():
g = DGLGraph()
g.add_nodes_from([0, 1])
g.add_edge(0, 1)
......@@ -259,24 +225,22 @@ def test_update_to_0deg():
def _reduce(node, msgs):
assert msgs is not None
return msgs.sum(1)
def _update(node, accum):
return node * 2 if accum is None else accum
old_repr = th.randn(2, 5)
g.set_n_repr(old_repr)
g.update_to(0, _message, _reduce, _update, True)
g.pull(0, _message, _reduce, batchable=True)
new_repr = g.get_n_repr()
assert th.allclose(new_repr[0], old_repr[0] * 2)
assert th.allclose(new_repr[0], old_repr[0])
assert th.allclose(new_repr[1], old_repr[1])
g.update_to(1, _message, _reduce, _update, True)
g.pull(1, _message, _reduce, batchable=True)
new_repr = g.get_n_repr()
assert th.allclose(new_repr[1], old_repr[0] * 2)
assert th.allclose(new_repr[1], old_repr[0])
old_repr = th.randn(2, 5)
g.set_n_repr(old_repr)
g.update_to([0, 1], _message, _reduce, _update, True)
g.pull([0, 1], _message, _reduce, batchable=True)
new_repr = g.get_n_repr()
assert th.allclose(new_repr[0], old_repr[0] * 2)
assert th.allclose(new_repr[0], old_repr[0])
assert th.allclose(new_repr[1], old_repr[0])
def _test_delete():
......@@ -293,9 +257,8 @@ if __name__ == '__main__':
test_batch_setter_getter()
test_batch_setter_autograd()
test_batch_send()
test_batch_recv1()
test_batch_recv2()
test_batch_recv()
test_update_routines()
test_reduce_0deg()
test_update_to_0deg()
test_pull_0deg()
#test_delete()
......@@ -19,11 +19,7 @@ def reduce_func(hv, msgs):
reduce_msg_shapes.add(tuple(msgs.shape))
assert len(msgs.shape) == 3
assert msgs.shape[2] == D
return th.sum(msgs, 1)
def update_func(hv, accum):
assert hv.shape == accum.shape
return hv + accum
return hv + th.sum(msgs, 1)
def generate_graph(grad=False):
g = DGLGraph()
......@@ -135,28 +131,27 @@ def test_batch_send():
assert hu.shape == (5, D)
return hu
g.register_message_func(_fmsg, batchable=True)
# many-many sendto
# many-many send
u = th.tensor([0, 0, 0, 0, 0])
v = th.tensor([1, 2, 3, 4, 5])
g.sendto(u, v)
# one-many sendto
g.send(u, v)
# one-many send
u = th.tensor([0])
v = th.tensor([1, 2, 3, 4, 5])
g.sendto(u, v)
# many-one sendto
g.send(u, v)
# many-one send
u = th.tensor([1, 2, 3, 4, 5])
v = th.tensor([9])
g.sendto(u, v)
g.send(u, v)
def test_batch_recv():
g = generate_graph()
g.register_message_func(message_func, batchable=True)
g.register_reduce_func(reduce_func, batchable=True)
g.register_update_func(update_func, batchable=True)
u = th.tensor([0, 0, 0, 4, 5, 6])
v = th.tensor([1, 2, 3, 9, 9, 9])
reduce_msg_shapes.clear()
g.sendto(u, v)
g.send(u, v)
g.recv(th.unique(v))
assert(reduce_msg_shapes == {(1, 3, D), (3, 1, D)})
reduce_msg_shapes.clear()
......@@ -165,27 +160,26 @@ def test_update_routines():
g = generate_graph()
g.register_message_func(message_func, batchable=True)
g.register_reduce_func(reduce_func, batchable=True)
g.register_update_func(update_func, batchable=True)
# update_by_edge
# send_and_recv
reduce_msg_shapes.clear()
u = th.tensor([0, 0, 0, 4, 5, 6])
v = th.tensor([1, 2, 3, 9, 9, 9])
g.update_by_edge(u, v)
g.send_and_recv(u, v)
assert(reduce_msg_shapes == {(1, 3, D), (3, 1, D)})
reduce_msg_shapes.clear()
# update_to
# pull
v = th.tensor([1, 2, 3, 9])
reduce_msg_shapes.clear()
g.update_to(v)
g.pull(v)
assert(reduce_msg_shapes == {(1, 8, D), (3, 1, D)})
reduce_msg_shapes.clear()
# update_from
# push
v = th.tensor([0, 1, 2, 3])
reduce_msg_shapes.clear()
g.update_from(v)
g.push(v)
assert(reduce_msg_shapes == {(1, 3, D), (8, 1, D)})
reduce_msg_shapes.clear()
......
......@@ -145,6 +145,11 @@ def test_row1():
for k, v in f[rowid].items():
assert check_eq(v, th.zeros((len(rowid), D)))
# setting rows with new column should automatically add a new column
vals['a4'] = th.ones((len(rowid), D))
f[rowid] = vals
assert len(f) == 4
def test_row2():
# test row getter/setter autograd compatibility
data = create_test_data(grad=True)
......
import torch as th
import dgl
import dgl.function as fn
from dgl.graph import __REPR__
def generate_graph():
g = dgl.DGLGraph()
for i in range(10):
g.add_node(i) # 10 nodes.
h = th.arange(1, 11)
g.set_n_repr({'h': h})
# create a graph where 0 is the source and 9 is the sink
for i in range(1, 9):
g.add_edge(0, i)
g.add_edge(i, 9)
# add a back flow from 9 to 0
g.add_edge(9, 0)
h = th.tensor([1., 2., 1., 3., 1., 4., 1., 5., 1., 6.,\
1., 7., 1., 8., 1., 9., 10.])
g.set_e_repr({'h' : h})
return g
def generate_graph1():
"""graph with anonymous repr"""
g = dgl.DGLGraph()
for i in range(10):
g.add_node(i) # 10 nodes.
h = th.arange(1, 11)
g.set_n_repr(h)
# create a graph where 0 is the source and 9 is the sink
for i in range(1, 9):
g.add_edge(0, i)
g.add_edge(i, 9)
# add a back flow from 9 to 0
g.add_edge(9, 0)
h = th.tensor([1., 2., 1., 3., 1., 4., 1., 5., 1., 6.,\
1., 7., 1., 8., 1., 9., 10.])
g.set_e_repr(h)
return g
def reducer_msg(node, msgs):
    """Sum the 'm' message field along the message axis; return a bare tensor."""
    incoming = msgs['m']
    return th.sum(incoming, 1)
def reducer_out(node, msgs):
    """Sum bare-tensor messages along the message axis into the 'h' field."""
    summed = th.sum(msgs, 1)
    return {'h' : summed}
def reducer_both(node, msgs):
    """Sum the 'm' message field along the message axis into the 'h' field."""
    summed = th.sum(msgs['m'], 1)
    return {'h' : summed}
def reducer_none(node, msgs):
    """Sum bare-tensor messages along the message axis; return a bare tensor."""
    return th.sum(msgs, dim=1)
def test_copy_src():
# copy_src with both fields
g = generate_graph()
g.register_message_func(fn.copy_src(src='h', out='m'), batchable=True)
g.register_reduce_func(reducer_both, batchable=True)
g.update_all()
assert th.allclose(g.get_n_repr()['h'],
th.tensor([10., 1., 1., 1., 1., 1., 1., 1., 1., 44.]))
# copy_src with only src field; the out field should use anonymous repr
g = generate_graph()
g.register_message_func(fn.copy_src(src='h'), batchable=True)
g.register_reduce_func(reducer_out, batchable=True)
g.update_all()
assert th.allclose(g.get_n_repr()['h'],
th.tensor([10., 1., 1., 1., 1., 1., 1., 1., 1., 44.]))
# copy_src with no src field; should use anonymous repr
g = generate_graph1()
g.register_message_func(fn.copy_src(out='m'), batchable=True)
g.register_reduce_func(reducer_both, batchable=True)
g.update_all()
assert th.allclose(g.get_n_repr()['h'],
th.tensor([10., 1., 1., 1., 1., 1., 1., 1., 1., 44.]))
# copy src with no fields;
g = generate_graph1()
g.register_message_func(fn.copy_src(), batchable=True)
g.register_reduce_func(reducer_out, batchable=True)
g.update_all()
assert th.allclose(g.get_n_repr()['h'],
th.tensor([10., 1., 1., 1., 1., 1., 1., 1., 1., 44.]))
def test_copy_edge():
# copy_edge with both fields
g = generate_graph()
g.register_message_func(fn.copy_edge(edge='h', out='m'), batchable=True)
g.register_reduce_func(reducer_both, batchable=True)
g.update_all()
assert th.allclose(g.get_n_repr()['h'],
th.tensor([10., 1., 1., 1., 1., 1., 1., 1., 1., 44.]))
# copy_edge with only edge field; the out field should use anonymous repr
g = generate_graph()
g.register_message_func(fn.copy_edge(edge='h'), batchable=True)
g.register_reduce_func(reducer_out, batchable=True)
g.update_all()
assert th.allclose(g.get_n_repr()['h'],
th.tensor([10., 1., 1., 1., 1., 1., 1., 1., 1., 44.]))
# copy_edge with no edge field; should use anonymous repr
g = generate_graph1()
g.register_message_func(fn.copy_edge(out='m'), batchable=True)
g.register_reduce_func(reducer_both, batchable=True)
g.update_all()
assert th.allclose(g.get_n_repr()['h'],
th.tensor([10., 1., 1., 1., 1., 1., 1., 1., 1., 44.]))
# copy edge with no fields;
g = generate_graph1()
g.register_message_func(fn.copy_edge(), batchable=True)
g.register_reduce_func(reducer_out, batchable=True)
g.update_all()
assert th.allclose(g.get_n_repr()['h'],
th.tensor([10., 1., 1., 1., 1., 1., 1., 1., 1., 44.]))
def test_src_mul_edge():
# src_mul_edge with all fields
g = generate_graph()
g.register_message_func(fn.src_mul_edge(src='h', edge='h', out='m'), batchable=True)
g.register_reduce_func(reducer_both, batchable=True)
g.update_all()
assert th.allclose(g.get_n_repr()['h'],
th.tensor([100., 1., 1., 1., 1., 1., 1., 1., 1., 284.]))
g = generate_graph()
g.register_message_func(fn.src_mul_edge(src='h', edge='h'), batchable=True)
g.register_reduce_func(reducer_out, batchable=True)
g.update_all()
assert th.allclose(g.get_n_repr()['h'],
th.tensor([100., 1., 1., 1., 1., 1., 1., 1., 1., 284.]))
g = generate_graph1()
g.register_message_func(fn.src_mul_edge(out='m'), batchable=True)
g.register_reduce_func(reducer_both, batchable=True)
g.update_all()
assert th.allclose(g.get_n_repr()['h'],
th.tensor([100., 1., 1., 1., 1., 1., 1., 1., 1., 284.]))
g = generate_graph1()
g.register_message_func(fn.src_mul_edge(), batchable=True)
g.register_reduce_func(reducer_out, batchable=True)
g.update_all()
assert th.allclose(g.get_n_repr()['h'],
th.tensor([100., 1., 1., 1., 1., 1., 1., 1., 1., 284.]))
g = generate_graph1()
g.register_message_func(fn.src_mul_edge(), batchable=True)
g.register_reduce_func(reducer_none, batchable=True)
g.update_all()
assert th.allclose(g.get_n_repr(),
th.tensor([100., 1., 1., 1., 1., 1., 1., 1., 1., 284.]))
if __name__ == '__main__':
test_copy_src()
test_copy_edge()
test_src_mul_edge()
......@@ -73,7 +73,6 @@ def test_batch_sendrecv():
bg = dgl.batch([t1, t2])
bg.register_message_func(lambda src, edge: src, batchable=True)
bg.register_reduce_func(lambda node, msgs: torch.sum(msgs, 1), batchable=True)
bg.register_update_func(lambda node, accum: accum, batchable=True)
e1 = [(3, 1), (4, 1)]
e2 = [(2, 4), (0, 4)]
......@@ -82,7 +81,7 @@ def test_batch_sendrecv():
u = np.concatenate((u1, u2)).tolist()
v = np.concatenate((v1, v2)).tolist()
bg.sendto(u, v)
bg.send(u, v)
bg.recv(v)
dgl.unbatch(bg)
......@@ -97,7 +96,6 @@ def test_batch_propagate():
bg = dgl.batch([t1, t2])
bg.register_message_func(lambda src, edge: src, batchable=True)
bg.register_reduce_func(lambda node, msgs: torch.sum(msgs, 1), batchable=True)
bg.register_update_func(lambda node, accum: accum, batchable=True)
# get leaves.
order = []
......
import torch as th
import numpy as np
from dgl.graph import DGLGraph
import dgl
import dgl.function as fn
D = 5
def check_eq(a, b):
    """Assert two tensors are element-wise close.

    Prints both operands for debugging, then raises AssertionError on
    mismatch — the original version only printed, so a failed comparison
    could never actually fail the test run.
    """
    if not np.allclose(a.numpy(), b.numpy()):
        print(a, b)
        raise AssertionError("tensors are not element-wise close")
def message_func(hu, edge):
return hu
def reduce_func(hv, msgs):
return th.sum(msgs, 1)
def update_func(hv, accum):
assert hv.shape == accum.shape
return hv + accum
def generate_graph():
g = DGLGraph()
g = dgl.DGLGraph()
for i in range(10):
g.add_node(i) # 10 nodes.
# create a graph where 0 is the source and 9 is the sink
......@@ -28,30 +15,104 @@ def generate_graph():
g.add_edge(i, 9)
# add a back flow from 9 to 0
g.add_edge(9, 0)
col = th.randn(10, D)
g.set_n_repr(col)
g.set_n_repr({'f1' : th.randn(10,), 'f2' : th.randn(10, D)})
weights = th.randn(17,)
g.set_e_repr({'e1': weights, 'e2': th.unsqueeze(weights, 1)})
return g
def test_spmv_specialize():
g = generate_graph()
# update all
v1 = g.get_n_repr()
g.update_all('from_src', 'sum', update_func, batchable=True)
v2 = g.get_n_repr()
g.set_n_repr(v1)
g.update_all(message_func, reduce_func, update_func, batchable=True)
v3 = g.get_n_repr()
check_eq(v2, v3)
# partial update
def test_update_all():
def _test(fld):
def message_func(hu, edge):
return hu[fld]
def message_func_edge(hu, edge):
if len(hu[fld].shape) == 1:
return hu[fld] * edge['e1']
else:
return hu[fld] * edge['e2']
def reduce_func(hv, msgs):
return {fld : th.sum(msgs, 1)}
def apply_func(hu):
return {fld : 2 * hu[fld]}
g = generate_graph()
# update all
v1 = g.get_n_repr()[fld]
g.update_all(fn.copy_src(src=fld), fn.sum(out=fld), apply_func, batchable=True)
v2 = g.get_n_repr()[fld]
g.set_n_repr({fld : v1})
g.update_all(message_func, reduce_func, apply_func, batchable=True)
v3 = g.get_n_repr()[fld]
assert th.allclose(v2, v3)
# update all with edge weights
v1 = g.get_n_repr()[fld]
g.update_all(fn.src_mul_edge(src=fld, edge='e1'),
fn.sum(out=fld), apply_func, batchable=True)
v2 = g.get_n_repr()[fld]
g.set_n_repr({fld : v1})
g.update_all(fn.src_mul_edge(src=fld, edge='e2'),
fn.sum(out=fld), apply_func, batchable=True)
v3 = g.get_n_repr()[fld]
g.set_n_repr({fld : v1})
g.update_all(message_func_edge, reduce_func, apply_func, batchable=True)
v4 = g.get_n_repr()[fld]
assert th.allclose(v2, v3)
assert th.allclose(v3, v4)
# test 1d node features
_test('f1')
# test 2d node features
_test('f2')
def test_send_and_recv():
u = th.tensor([0, 0, 0, 3, 4, 9])
v = th.tensor([1, 2, 3, 9, 9, 0])
v1 = g.get_n_repr()
g.update_by_edge(u, v, 'from_src', 'sum', update_func, batchable=True)
v2 = g.get_n_repr()
g.set_n_repr(v1)
g.update_by_edge(u, v, message_func, reduce_func, update_func, batchable=True)
v3 = g.get_n_repr()
check_eq(v2, v3)
def _test(fld):
def message_func(hu, edge):
return hu[fld]
def message_func_edge(hu, edge):
if len(hu[fld].shape) == 1:
return hu[fld] * edge['e1']
else:
return hu[fld] * edge['e2']
def reduce_func(hv, msgs):
return {fld : th.sum(msgs, 1)}
def apply_func(hu):
return {fld : 2 * hu[fld]}
g = generate_graph()
# send and recv
v1 = g.get_n_repr()[fld]
g.send_and_recv(u, v, fn.copy_src(src=fld),
fn.sum(out=fld), apply_func, batchable=True)
v2 = g.get_n_repr()[fld]
g.set_n_repr({fld : v1})
g.send_and_recv(u, v, message_func,
reduce_func, apply_func, batchable=True)
v3 = g.get_n_repr()[fld]
assert th.allclose(v2, v3)
# send and recv with edge weights
v1 = g.get_n_repr()[fld]
g.send_and_recv(u, v, fn.src_mul_edge(src=fld, edge='e1'),
fn.sum(out=fld), apply_func, batchable=True)
v2 = g.get_n_repr()[fld]
g.set_n_repr({fld : v1})
g.send_and_recv(u, v, fn.src_mul_edge(src=fld, edge='e2'),
fn.sum(out=fld), apply_func, batchable=True)
v3 = g.get_n_repr()[fld]
g.set_n_repr({fld : v1})
g.send_and_recv(u, v, message_func_edge,
reduce_func, apply_func, batchable=True)
v4 = g.get_n_repr()[fld]
assert th.allclose(v2, v3)
assert th.allclose(v3, v4)
# test 1d node features
_test('f1')
# test 2d node features
_test('f2')
if __name__ == '__main__':
test_spmv_specialize()
#test_update_all()
test_send_and_recv()
......@@ -4,8 +4,8 @@ from dgl.graph import __REPR__
def message_func(hu, e_uv):
return hu + e_uv
def update_func(h, accum):
return h + accum
def reduce_func(h, msgs):
return h + sum(msgs)
def generate_graph():
g = DGLGraph()
......@@ -28,34 +28,32 @@ def test_sendrecv():
g = generate_graph()
check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
g.register_message_func(message_func)
g.register_update_func(update_func)
g.register_reduce_func('sum')
g.sendto(0, 1)
g.register_reduce_func(reduce_func)
g.send(0, 1)
g.recv(1)
check(g, [1, 4, 3, 4, 5, 6, 7, 8, 9, 10])
g.sendto(5, 9)
g.sendto(6, 9)
g.send(5, 9)
g.send(6, 9)
g.recv(9)
check(g, [1, 4, 3, 4, 5, 6, 7, 8, 9, 25])
def message_func_hybrid(src, edge):
return src[__REPR__] + edge
def update_func_hybrid(node, accum):
return node[__REPR__] + accum
def reduce_func_hybrid(node, msgs):
return node[__REPR__] + sum(msgs)
def test_hybridrepr():
g = generate_graph()
for i in range(10):
g.nodes[i]['id'] = -i
g.register_message_func(message_func_hybrid)
g.register_update_func(update_func_hybrid)
g.register_reduce_func('sum')
g.sendto(0, 1)
g.register_reduce_func(reduce_func_hybrid)
g.send(0, 1)
g.recv(1)
check(g, [1, 4, 3, 4, 5, 6, 7, 8, 9, 10])
g.sendto(5, 9)
g.sendto(6, 9)
g.send(5, 9)
g.send(6, 9)
g.recv(9)
check(g, [1, 4, 3, 4, 5, 6, 7, 8, 9, 25])
......
......@@ -3,14 +3,20 @@ from dgl.graph import DGLGraph
def message_func(src, edge):
return src['h']
def update_func(node, accum):
return {'h' : node['h'] + accum}
def reduce_func(node, msgs):
return {'m' : sum(msgs)}
def apply_func(node):
return {'h' : node['h'] + node['m']}
def message_dict_func(src, edge):
return {'m' : src['h']}
def update_dict_func(node, accum):
return {'h' : node['h'] + accum['m']}
def reduce_dict_func(node, msgs):
return {'m' : sum([msg['m'] for msg in msgs])}
def apply_dict_func(node):
return {'h' : node['h'] + node['m']}
def generate_graph():
g = DGLGraph()
......@@ -31,66 +37,50 @@ def check(g, h):
def register1(g):
g.register_message_func(message_func)
g.register_update_func(update_func)
g.register_reduce_func('sum')
g.register_reduce_func(reduce_func)
g.register_apply_node_func(apply_func)
def register2(g):
g.register_message_func(message_dict_func)
g.register_update_func(update_dict_func)
g.register_reduce_func('sum')
g.register_reduce_func(reduce_dict_func)
g.register_apply_node_func(apply_dict_func)
def _test_sendrecv(g):
check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
g.sendto(0, 1)
g.send(0, 1)
g.recv(1)
check(g, [1, 3, 3, 4, 5, 6, 7, 8, 9, 10])
g.sendto(5, 9)
g.sendto(6, 9)
g.send(5, 9)
g.send(6, 9)
g.recv(9)
check(g, [1, 3, 3, 4, 5, 6, 7, 8, 9, 23])
def _test_multi_sendrecv(g):
check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
# one-many
g.sendto(0, [1, 2, 3])
g.send(0, [1, 2, 3])
g.recv([1, 2, 3])
check(g, [1, 3, 4, 5, 5, 6, 7, 8, 9, 10])
# many-one
g.sendto([6, 7, 8], 9)
g.send([6, 7, 8], 9)
g.recv(9)
check(g, [1, 3, 4, 5, 5, 6, 7, 8, 9, 34])
# many-many
g.sendto([0, 0, 4, 5], [4, 5, 9, 9])
g.send([0, 0, 4, 5], [4, 5, 9, 9])
g.recv([4, 5, 9])
check(g, [1, 3, 4, 5, 6, 7, 7, 8, 9, 45])
def _test_update_routines(g):
check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
g.update_by_edge(0, 1)
g.send_and_recv(0, 1)
check(g, [1, 3, 3, 4, 5, 6, 7, 8, 9, 10])
g.update_to(9)
g.pull(9)
check(g, [1, 3, 3, 4, 5, 6, 7, 8, 9, 55])
g.update_from(0)
g.push(0)
check(g, [1, 4, 4, 5, 6, 7, 8, 9, 10, 55])
g.update_all()
check(g, [56, 5, 5, 6, 7, 8, 9, 10, 11, 108])
def _test_update_to_0deg():
g = DGLGraph()
g.add_node(0, h=2)
g.add_node(1, h=1)
g.add_edge(0, 1)
def _message(src, edge):
return src
def _reduce(node, msgs):
assert msgs is not None
return msgs.sum(1)
def _update(node, accum):
assert accum is None
return {'h': node['h'] * 2}
g.update_to(0, _message, _reduce, _update)
assert g.nodes[0]['h'] == 4
def test_sendrecv():
g = generate_graph()
register1(g)
......@@ -115,8 +105,6 @@ def test_update_routines():
register2(g)
_test_update_routines(g)
_test_update_to_0deg()
if __name__ == '__main__':
test_sendrecv()
test_multi_sendrecv()
......
......@@ -12,13 +12,8 @@ def reduce_not_called(h, msgs):
assert False
return 0
def update_no_msg(h, accum):
assert accum is None
return h + 1
def update_func(h, accum):
assert accum is not None
return h + accum
def reduce_func(h, msgs):
return h + sum(msgs)
def check(g, h):
nh = [str(g.nodes[i][__REPR__]) for i in range(10)]
......@@ -35,12 +30,12 @@ def generate_graph():
g.add_edge(i, 9)
return g
def test_no_msg_update():
def test_no_msg_recv():
g = generate_graph()
check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
g.register_message_func(message_not_called)
g.register_reduce_func(reduce_not_called)
g.register_update_func(update_no_msg)
g.register_apply_node_func(lambda h : h + 1)
for i in range(10):
g.recv(i)
check(g, [2, 3, 4, 5, 6, 7, 8, 9, 10, 11])
......@@ -49,28 +44,31 @@ def test_double_recv():
g = generate_graph()
check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
g.register_message_func(message_func)
g.register_reduce_func('sum')
g.register_update_func(update_func)
g.sendto(1, 9)
g.sendto(2, 9)
g.register_reduce_func(reduce_func)
g.send(1, 9)
g.send(2, 9)
g.recv(9)
check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 15])
try:
# The second recv should have a None message
g.recv(9)
except:
return
assert False
def test_recv_no_pred():
g = generate_graph()
check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
g.register_message_func(message_not_called)
g.register_reduce_func(reduce_not_called)
g.register_update_func(update_no_msg)
g.recv(0)
g.recv(9)
def test_pull_0deg():
g = DGLGraph()
g.add_node(0, h=2)
g.add_node(1, h=1)
g.add_edge(0, 1)
def _message(src, edge):
assert False
return src
def _reduce(node, msgs):
assert False
return node
def _update(node):
return {'h': node['h'] * 2}
g.pull(0, _message, _reduce, _update)
assert g.nodes[0]['h'] == 4
if __name__ == '__main__':
test_no_msg_update()
test_no_msg_recv()
test_double_recv()
test_recv_no_pred()
test_pull_0deg()
import dgl
import dgl.function as fn
from dgl.graph import __REPR__
def generate_graph():
g = dgl.DGLGraph()
for i in range(10):
g.add_node(i, h=i+1) # 10 nodes.
# create a graph where 0 is the source and 9 is the sink
for i in range(1, 9):
g.add_edge(0, i, h=1)
g.add_edge(i, 9, h=i+1)
# add a back flow from 9 to 0
g.add_edge(9, 0, h=10)
return g
def check(g, h, fld):
nh = [str(g.nodes[i][fld]) for i in range(10)]
h = [str(x) for x in h]
assert nh == h, "nh=[%s], h=[%s]" % (' '.join(nh), ' '.join(h))
def generate_graph1():
"""graph with anonymous repr"""
g = dgl.DGLGraph()
for i in range(10):
g.add_node(i, __REPR__=i+1) # 10 nodes.
# create a graph where 0 is the source and 9 is the sink
for i in range(1, 9):
g.add_edge(0, i, __REPR__=1)
g.add_edge(i, 9, __REPR__=i+1)
# add a back flow from 9 to 0
g.add_edge(9, 0, __REPR__=10)
return g
def test_copy_src():
# copy_src with both fields
g = generate_graph()
g.register_message_func(fn.copy_src(src='h', out='m'), batchable=False)
g.register_reduce_func(fn.sum(msgs='m', out='h'), batchable=False)
check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], 'h')
g.update_all()
check(g, [10, 1, 1, 1, 1, 1, 1, 1, 1, 44], 'h')
# copy_src with only src field; the out field should use anonymous repr
g = generate_graph()
g.register_message_func(fn.copy_src(src='h'), batchable=False)
g.register_reduce_func(fn.sum(out='h'), batchable=False)
check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], 'h')
g.update_all()
check(g, [10, 1, 1, 1, 1, 1, 1, 1, 1, 44], 'h')
# copy_src with no src field; should use anonymous repr
g = generate_graph1()
g.register_message_func(fn.copy_src(out='m'), batchable=False)
g.register_reduce_func(fn.sum(msgs='m', out='h'), batchable=False)
check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], __REPR__)
g.update_all()
check(g, [10, 1, 1, 1, 1, 1, 1, 1, 1, 44], 'h')
# copy src with no fields;
g = generate_graph1()
g.register_message_func(fn.copy_src(), batchable=False)
g.register_reduce_func(fn.sum(out='h'), batchable=False)
check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], __REPR__)
g.update_all()
check(g, [10, 1, 1, 1, 1, 1, 1, 1, 1, 44], 'h')
def test_copy_edge():
# copy_edge with both fields
g = generate_graph()
g.register_message_func(fn.copy_edge(edge='h', out='m'), batchable=False)
g.register_reduce_func(fn.sum(msgs='m', out='h'), batchable=False)
check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], 'h')
g.update_all()
check(g, [10, 1, 1, 1, 1, 1, 1, 1, 1, 44], 'h')
# copy_edge with only edge field; the out field should use anonymous repr
g = generate_graph()
g.register_message_func(fn.copy_edge(edge='h'), batchable=False)
g.register_reduce_func(fn.sum(out='h'), batchable=False)
check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], 'h')
g.update_all()
check(g, [10, 1, 1, 1, 1, 1, 1, 1, 1, 44], 'h')
# copy_edge with no edge field; should use anonymous repr
g = generate_graph1()
g.register_message_func(fn.copy_edge(out='m'), batchable=False)
g.register_reduce_func(fn.sum(msgs='m', out='h'), batchable=False)
check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], __REPR__)
g.update_all()
check(g, [10, 1, 1, 1, 1, 1, 1, 1, 1, 44], 'h')
# copy edge with no fields;
g = generate_graph1()
g.register_message_func(fn.copy_edge(), batchable=False)
g.register_reduce_func(fn.sum(out='h'), batchable=False)
check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], __REPR__)
g.update_all()
check(g, [10, 1, 1, 1, 1, 1, 1, 1, 1, 44], 'h')
def test_src_mul_edge():
# src_mul_edge with all fields
g = generate_graph()
g.register_message_func(fn.src_mul_edge(src='h', edge='h', out='m'), batchable=False)
g.register_reduce_func(fn.sum(msgs='m', out='h'), batchable=False)
check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], 'h')
g.update_all()
check(g, [100, 1, 1, 1, 1, 1, 1, 1, 1, 284], 'h')
g = generate_graph()
g.register_message_func(fn.src_mul_edge(src='h', edge='h'), batchable=False)
g.register_reduce_func(fn.sum(out='h'), batchable=False)
check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], 'h')
g.update_all()
check(g, [100, 1, 1, 1, 1, 1, 1, 1, 1, 284], 'h')
g = generate_graph1()
g.register_message_func(fn.src_mul_edge(out='m'), batchable=False)
g.register_reduce_func(fn.sum(msgs='m', out='h'), batchable=False)
check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], __REPR__)
g.update_all()
check(g, [100, 1, 1, 1, 1, 1, 1, 1, 1, 284], 'h')
g = generate_graph1()
g.register_message_func(fn.src_mul_edge(), batchable=False)
g.register_reduce_func(fn.sum(out='h'), batchable=False)
check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], __REPR__)
g.update_all()
check(g, [100, 1, 1, 1, 1, 1, 1, 1, 1, 284], 'h')
g = generate_graph1()
g.register_message_func(fn.src_mul_edge(), batchable=False)
g.register_reduce_func(fn.sum(), batchable=False)
check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], __REPR__)
g.update_all()
check(g, [100, 1, 1, 1, 1, 1, 1, 1, 1, 284], __REPR__)
if __name__ == '__main__':
test_copy_src()
test_copy_edge()
test_src_mul_edge()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment