"tests/vscode:/vscode.git/clone" did not exist on "5412a3341fa5d0211629ee87899015f98a62e0cc"
Commit 916d375b authored by Minjie Wang's avatar Minjie Wang
Browse files

Merge branch 'master' into cpp

parents a1038eb1 9b0a01db
......@@ -2,9 +2,11 @@
from __future__ import absolute_import
import operator
import dgl.backend as F
__all__ = ["MessageFunction", "src_mul_edge", "copy_src", "copy_edge"]
class MessageFunction(object):
def __call__(self, src, edge):
raise NotImplementedError
......@@ -12,10 +14,28 @@ class MessageFunction(object):
def name(self):
raise NotImplementedError
def is_spmv_supported(self, g):
raise NotImplementedError
class BundledMessageFunction(MessageFunction):
def __init__(self, fn_list):
    """Wrap one or more message functions for bundled execution.

    A bare function is promoted to a one-element list.  When a real
    list/tuple is given, every builtin message function must declare an
    out field so that the merged results are unambiguous; UDFs cannot be
    checked and are accepted as-is.
    """
    if isinstance(fn_list, (list, tuple)):
        # builtin functions must name their output; no check possible for UDFs
        for func in fn_list:
            if isinstance(func, MessageFunction) and func.out_field is None:
                raise RuntimeError("Not specifying out field for multiple message is ambiguous")
    else:
        fn_list = [fn_list]
    self.fn_list = fn_list
def is_spmv_supported(self, g):
    """Return True iff every bundled function is a builtin message
    function that supports SPMV optimization on graph ``g``."""
    return all(isinstance(fn, MessageFunction) and fn.is_spmv_supported(g)
               for fn in self.fn_list)
def __call__(self, src, edge):
ret = None
for fn in self.fn_list:
......@@ -24,16 +44,34 @@ class BundledMessageFunction(MessageFunction):
ret = msg
else:
try:
# ret and msg must be dict
ret.update(msg)
except e:
raise RuntimeError("Failed to merge results of two builtin"
" message functions. Please specify out_field"
" for the builtin message function.")
except:
raise RuntimeError("Must specify out field for multiple message")
return ret
def name(self):
    """Symbolic name identifying this bundled message function."""
    return "bundled"
def _is_spmv_supported_node_feat(g, field):
    """Check whether node feature ``field`` of ``g`` has an SPMV-friendly shape.

    ``field is None`` means the anonymous node representation.  A feature
    qualifies when it is a 1-D or 2-D tensor.
    """
    node_repr = g.get_n_repr()
    feat = node_repr if field is None else node_repr[field]
    ndim = len(F.shape(feat))
    return ndim in (1, 2)
def _is_spmv_supported_edge_feat(g, field):
    """Check whether edge feature ``field`` of ``g`` can be used for SPMV.

    Only scalar edge features are supported at the moment: 1-D tensors,
    or 2-D tensors whose trailing dimension is 1.
    """
    edge_repr = g.get_e_repr()
    feat = edge_repr if field is None else edge_repr[field]
    shape = F.shape(feat)
    if len(shape) == 1:
        return True
    return len(shape) == 2 and shape[1] == 1
class SrcMulEdgeMessageFunction(MessageFunction):
def __init__(self, mul_op, src_field=None, edge_field=None, out_field=None):
self.mul_op = mul_op
......@@ -41,6 +79,10 @@ class SrcMulEdgeMessageFunction(MessageFunction):
self.edge_field = edge_field
self.out_field = out_field
def is_spmv_supported(self, g):
    """SPMV applies only when both the source-node feature and the edge
    feature have compatible shapes."""
    node_ok = _is_spmv_supported_node_feat(g, self.src_field)
    return node_ok and _is_spmv_supported_edge_feat(g, self.edge_field)
def __call__(self, src, edge):
if self.src_field is not None:
src = src[self.src_field]
......@@ -60,6 +102,9 @@ class CopySrcMessageFunction(MessageFunction):
self.src_field = src_field
self.out_field = out_field
def is_spmv_supported(self, g):
    """copy_src is SPMV-ready whenever the copied node feature's shape allows it."""
    return _is_spmv_supported_node_feat(g, self.src_field)
def __call__(self, src, edge):
if self.src_field is not None:
ret = src[self.src_field]
......@@ -78,6 +123,11 @@ class CopyEdgeMessageFunction(MessageFunction):
self.edge_field = edge_field
self.out_field = out_field
def is_spmv_supported(self, g):
    """SPMV for copy_edge is not implemented yet, so always report False."""
    # TODO: support this with g-spmv
    return False
    # return _is_spmv_supported_edge_feat(g, self.edge_field)
def __call__(self, src, edge):
if self.edge_field is not None:
ret = edge[self.edge_field]
......@@ -91,6 +141,7 @@ class CopyEdgeMessageFunction(MessageFunction):
def name(self):
    """Symbolic name identifying this builtin message function."""
    return "copy_edge"
def src_mul_edge(src=None, edge=None, out=None):
    """Builtin message function: multiply a source-node feature by an edge feature.

    Parameters
    ----------
    src : str, optional
        Source-node feature field; None uses the anonymous node representation.
    edge : str, optional
        Edge feature field; None uses the anonymous edge representation.
    out : str, optional
        Output message field; None produces an anonymous message.

    Returns
    -------
    SrcMulEdgeMessageFunction
        A message function computing ``src_feat * edge_feat`` per edge.
    """
    return SrcMulEdgeMessageFunction(operator.mul, src, edge, out)
......
......@@ -12,10 +12,26 @@ class ReduceFunction(object):
def name(self):
raise NotImplementedError
def is_spmv_supported(self):
    """Return True if this reducer can be fused into an SPMV kernel; abstract."""
    raise NotImplementedError
class BundledReduceFunction(ReduceFunction):
def __init__(self, fn_list):
    """Wrap one or more reduce functions for bundled execution.

    A bare function is promoted to a one-element list.  When a real
    list/tuple is given, every builtin reduce function must declare an
    out field so that the merged results are unambiguous; UDFs cannot be
    checked and are accepted as-is.
    """
    if isinstance(fn_list, (list, tuple)):
        # builtin reducers must name their output; no check possible for UDFs
        for func in fn_list:
            if isinstance(func, ReduceFunction) and func.out_field is None:
                raise RuntimeError("Not specifying out field for multiple reduce is ambiguous")
    else:
        fn_list = [fn_list]
    self.fn_list = fn_list
def is_spmv_supported(self):
    """Return True iff every bundled function is a builtin reducer that
    supports SPMV optimization."""
    return all(isinstance(fn, ReduceFunction) and fn.is_spmv_supported()
               for fn in self.fn_list)
def __call__(self, node, msgs):
ret = None
for fn in self.fn_list:
......@@ -24,46 +40,50 @@ class BundledReduceFunction(ReduceFunction):
ret = rpr
else:
try:
# ret and rpr must be dict
ret.update(rpr)
except e:
raise RuntimeError("Failed to merge results of two builtin"
" reduce functions. Please specify out_field"
" for the builtin reduce function.")
except:
raise RuntimeError("Must specify out field for multiple reudce")
return ret
def name(self):
    """Symbolic name identifying this bundled reduce function."""
    return "bundled"
class ReducerFunctionTemplate(ReduceFunction):
    """Generic builtin reducer parameterized by its batched/non-batched ops.

    Replaces the removed per-op ``SumReducerFunction``; the concrete
    reducer is selected by ``name`` (e.g. "sum", "max").  The flattened
    diff residue from the old class (duplicate return statements and
    stale ``*_sum_op`` attribute lines) has been removed.
    """
    def __init__(self, name, batch_op, nonbatch_op, msg_field=None, out_field=None):
        # NOTE(review): this attribute shadows the name() method on
        # instances — confirm this is intended.
        self.name = name
        self.batch_op = batch_op
        self.nonbatch_op = nonbatch_op
        self.msg_field = msg_field
        self.out_field = out_field

    def is_spmv_supported(self):
        # TODO: support max
        return self.name == "sum"

    def __call__(self, node, msgs):
        """Reduce incoming messages.

        ``msgs`` is either a Python list of per-edge message dicts/tensors
        (non-batched path) or a batched tensor reduced along dimension 1.
        When ``out_field`` is set, the result is wrapped in a dict.
        """
        if isinstance(msgs, list):
            if self.msg_field is None:
                ret = self.nonbatch_op(msgs)
            else:
                ret = self.nonbatch_op([msg[self.msg_field] for msg in msgs])
        else:
            # batched tensor: reduce along the message dimension
            if self.msg_field is None:
                ret = self.batch_op(msgs, 1)
            else:
                ret = self.batch_op(msgs[self.msg_field], 1)
        if self.out_field is None:
            return ret
        else:
            return {self.out_field : ret}

    def name(self):
        return self.name

# keep references to the builtins before shadowing them below
_python_sum = sum
def sum(msgs=None, out=None):
    """Builtin reducer that sums messages (field ``msgs``) into field ``out``."""
    return ReducerFunctionTemplate("sum", F.sum, _python_sum, msgs, out)

_python_max = max
def max(msgs=None, out=None):
    """Builtin reducer that takes the elementwise max of messages (field
    ``msgs``) into field ``out``."""
    return ReducerFunctionTemplate("max", F.max, _python_max, msgs, out)
......@@ -12,6 +12,8 @@ from .graph_index import GraphIndex
from .frame import FrameRef, merge_frames
from . import scheduler
from . import utils
from .function.message import BundledMessageFunction
from .function.reducer import BundledReduceFunction
class DGLGraph(object):
"""Base graph class specialized for neural networks on graphs.
......@@ -431,6 +433,8 @@ class DGLGraph(object):
if message_func == "default":
message_func, batchable = self._message_func
assert message_func is not None
if isinstance(message_func, (tuple, list)):
message_func = BundledMessageFunction(message_func)
if batchable:
self._batch_send(u, v, message_func)
else:
......@@ -470,7 +474,7 @@ class DGLGraph(object):
else:
self._msg_frame.append({__MSG__ : msgs})
def update_edge(self, u, v, edge_func="default", batchable=False):
def update_edge(self, u=ALL, v=ALL, edge_func="default", batchable=False):
"""Update representation on edge u->v
The edge function should be compatible with following signature:
......@@ -573,6 +577,8 @@ class DGLGraph(object):
if reduce_func == "default":
reduce_func, batchable = self._reduce_func
assert reduce_func is not None
if isinstance(reduce_func, (list, tuple)):
reduce_func = BundledReduceFunction(reduce_func)
if batchable:
self._batch_recv(u, reduce_func)
else:
......
......@@ -3,6 +3,7 @@ from __future__ import absolute_import
import numpy as np
from .base import ALL
from . import backend as F
from .function import message as fmsg
from .function import reducer as fred
......@@ -38,37 +39,32 @@ def degree_bucketing(cached_graph, v):
#print('degree-bucketing:', unique_degrees, [len(b) for b in v_bkt])
return unique_degrees, v_bkt
class Executor(object):
def run(self, graph):
def run(self):
raise NotImplementedError
class UpdateAllSPMVExecutor(Executor):
def __init__(self, graph, src_field, dst_field, edge_field, use_adj):
self.graph = graph
class SPMVOperator(Executor):
def __init__(self, src_field, edge_field, dst_field, use_edge_feat,
             node_repr, adj_build_fn):
    """Store the SPMV configuration.

    The flattened diff left a duplicate ``self.dst_field`` assignment and
    a stale ``self.use_adj = use_adj`` line referencing a removed
    parameter (a NameError at runtime); both are removed here.

    Parameters
    ----------
    src_field, edge_field, dst_field : str or None
        Feature fields for source nodes, edges, and the reduced output.
    use_edge_feat : bool
        Whether the adjacency matrix carries edge-feature values.
    node_repr : node representation used as the SPMV dense operand.
    adj_build_fn : callable
        ``adj_build_fn(edge_field, ctx, use_edge_feat)`` returning the
        sparse adjacency matrix.
    """
    self.src_field = src_field
    self.edge_field = edge_field
    self.dst_field = dst_field
    self.use_edge_feat = use_edge_feat
    self.node_repr = node_repr
    self.adj_build_fn = adj_build_fn
def run(self):
g = self.graph
# get src col
if self.src_field is None:
srccol = g.get_n_repr()
srccol = self.node_repr
else:
srccol = g.get_n_repr()[self.src_field]
srccol = self.node_repr[self.src_field]
ctx = F.get_context(srccol)
if self.use_adj:
adjmat = g.cached_graph.adjmat().get(ctx)
else:
if self.edge_field is None:
dat = g.get_e_repr()
else:
dat = g.get_e_repr()[self.edge_field]
dat = F.squeeze(dat)
# TODO(minjie): should not directly use _indices
idx = g.cached_graph.adjmat().get(ctx)._indices()
n = g.number_of_nodes()
adjmat = F.sparse_tensor(idx, dat, [n, n])
# build adjmat
adjmat = self.adj_build_fn(self.edge_field, ctx, self.use_edge_feat)
# spmm
if len(F.shape(srccol)) == 1:
srccol = F.unsqueeze(srccol, 1)
......@@ -77,104 +73,249 @@ class UpdateAllSPMVExecutor(Executor):
else:
dstcol = F.spmm(adjmat, srccol)
if self.dst_field is None:
g.set_n_repr(dstcol)
return dstcol
else:
g.set_n_repr({self.dst_field : dstcol})
return {self.dst_field : dstcol}
class SendRecvSPMVExecutor(Executor):
def __init__(self, graph, src, dst, src_field, dst_field, edge_field, use_edge_dat):
self.graph = graph
self.src = src
self.dst = dst
self.src_field = src_field
self.dst_field = dst_field
self.edge_field = edge_field
self.use_edge_dat = use_edge_dat
def run(self):
# get src col
g = self.graph
if self.src_field is None:
srccol = g.get_n_repr()
class BasicExecutor(Executor):
def __init__(self, graph, mfunc, rfunc):
self.g = graph
self.exe = self._build_exec(mfunc, rfunc)
@property
def node_repr(self):
raise NotImplementedError
@property
def edge_repr(self):
raise NotImplementedError
@property
def graph_mapping(self):
raise NotImplementedError
def _build_exec(self, mfunc, rfunc):
if isinstance(mfunc, fmsg.CopySrcMessageFunction):
exe = SPMVOperator(src_field=mfunc.src_field,
edge_field=None,
dst_field=rfunc.out_field,
use_edge_feat=False,
node_repr=self.node_repr,
adj_build_fn=self._adj_build_fn)
elif isinstance(mfunc, fmsg.SrcMulEdgeMessageFunction):
exe = SPMVOperator(src_field=mfunc.src_field,
edge_field=mfunc.edge_field,
dst_field=rfunc.out_field,
use_edge_feat=True,
node_repr=self.node_repr,
adj_build_fn=self._adj_build_fn)
else:
srccol = g.get_n_repr()[self.src_field]
ctx = F.get_context(srccol)
raise NotImplementedError("message func type {}".format(type(mfunc)))
return exe
# build adjmat
# build adjmat dat
u, v = utils.edge_broadcasting(self.src, self.dst)
if self.use_edge_dat:
if self.edge_field is None:
dat = g.get_e_repr(u, v)
def run(self):
attr = self.exe.run()
self.g.set_n_repr(attr, self.graph_mapping)
class UpdateAllExecutor(BasicExecutor):
def __init__(self, graph, mfunc, rfunc):
self._init_state()
super(UpdateAllExecutor, self).__init__(graph, mfunc, rfunc)
def _init_state(self):
    """Reset the lazily-computed caches backing the executor properties."""
    for cache in ('_node_repr', '_edge_repr', '_graph_idx',
                  '_graph_shape', '_graph_mapping'):
        setattr(self, cache, None)
@property
def graph_idx(self):
    # Lazily fetch and cache the full adjacency index of the graph.
    if self._graph_idx is None:
        self._graph_idx = self.g.cached_graph.adjmat()
    return self._graph_idx
@property
def graph_shape(self):
if self._graph_shape is None:
n = self.g.number_of_nodes()
self._graph_shape = [n, n]
return self._graph_shape
@property
def graph_mapping(self):
    # update_all writes back to every node, so no relabel mapping is needed.
    return ALL
@property
def node_repr(self):
if self._node_repr is None:
self._node_repr = self.g.get_n_repr()
return self._node_repr
@property
def edge_repr(self):
if self._edge_repr is None:
self._edge_repr = self.g.get_e_repr()
return self._edge_repr
def _adj_build_fn(self, edge_field, ctx, use_edge_feat):
if use_edge_feat:
if edge_field is None:
dat = self.edge_repr
else:
dat = g.get_e_repr(u, v)[self.edge_field]
dat = self.edge_repr[edge_field]
dat = F.squeeze(dat)
# TODO(minjie): should not directly use _indices
idx = self.graph_idx.get(ctx)._indices()
adjmat = F.sparse_tensor(idx, dat, self.graph_shape)
else:
dat = F.ones((len(u),))
# build adjmat index
new2old, old2new = utils.build_relabel_map(v)
u = u.totensor()
v = v.totensor()
adjmat = self.graph_idx.get(ctx)
return adjmat
class SendRecvExecutor(BasicExecutor):
def __init__(self, graph, src, dst, mfunc, rfunc):
self._init_state(src, dst)
super(SendRecvExecutor, self).__init__(graph, mfunc, rfunc)
def _init_state(self, src, dst):
    """Broadcast the endpoint lists and reset the lazy caches."""
    self.u, self.v = utils.edge_broadcasting(src, dst)
    for cache in ('_node_repr', '_edge_repr', '_graph_idx',
                  '_graph_shape', '_graph_mapping'):
        setattr(self, cache, None)
@property
def graph_idx(self):
if self._graph_idx is None:
self._build_adjmat()
return self._graph_idx
@property
def graph_shape(self):
if self._graph_shape is None:
self._build_adjmat()
return self._graph_shape
@property
def graph_mapping(self):
if self._graph_mapping is None:
self._build_adjmat()
return self._graph_mapping
@property
def node_repr(self):
if self._node_repr is None:
self._node_repr = self.g.get_n_repr()
return self._node_repr
@property
def edge_repr(self):
if self._edge_repr is None:
self._edge_repr = self.g.get_e_repr(self.u, self.v)
return self._edge_repr
def _build_adjmat(self):
# handle graph index
new2old, old2new = utils.build_relabel_map(self.v)
u = self.u.totensor()
v = self.v.totensor()
# TODO(minjie): should not directly use []
new_v = old2new[v]
idx = F.pack([F.unsqueeze(new_v, 0), F.unsqueeze(u, 0)])
n = g.number_of_nodes()
n = self.g.number_of_nodes()
m = len(new2old)
adjmat = F.sparse_tensor(idx, dat, [m, n])
adjmat = F.to_context(adjmat, ctx)
# spmm
if len(F.shape(srccol)) == 1:
srccol = F.unsqueeze(srccol, 1)
dstcol = F.spmm(adjmat, srccol)
dstcol = F.squeeze(dstcol)
self._graph_idx = F.pack([F.unsqueeze(new_v, 0), F.unsqueeze(u, 0)])
self._graph_shape = [m, n]
self._graph_mapping = new2old
def _adj_build_fn(self, edge_field, ctx, use_edge_feat):
if use_edge_feat:
if edge_field is None:
dat = self.edge_repr
else:
dstcol = F.spmm(adjmat, srccol)
if self.dst_field is None:
g.set_n_repr(dstcol, new2old)
dat = self.edge_repr[edge_field]
dat = F.squeeze(dat)
else:
g.set_n_repr({self.dst_field : dstcol}, new2old)
dat = F.ones((len(self.u), ))
adjmat = F.sparse_tensor(self.graph_idx, dat, self.graph_shape)
return F.to_context(adjmat, ctx)
def _is_spmv_supported_node_feat(g, field):
if field is None:
feat = g.get_n_repr()
class BundledExecutor(BasicExecutor):
"""
Base class for Bundled execution
All shared structure like graph index should be cached in this class or its subclass
BundledUpdateAllExecutor and BundledSendRecvExecutor should subclass BundledExecutor
"""
def __init__(self, graph, mfunc, rfunc):
    """Pair up the bundled message/reduce functions and build one
    executor per matched pair."""
    self.g = graph
    self.executors = self._build_executors(
        self._match_message_with_reduce(mfunc, rfunc))
def _build_executors(self, func_pairs):
    """Create one executor for each matched (message, reduce) pair."""
    return [self._build_exec(mfn, rfn) for mfn, rfn in func_pairs]
def _match_message_with_reduce(self, mfunc, rfunc):
    """Pair each reduce function with the message function whose out
    field produces the reducer's input field."""
    by_out_field = {fn.out_field: fn for fn in mfunc.fn_list}
    pairs = []
    for rfn in rfunc.fn_list:
        msg_fn = by_out_field.get(rfn.msg_field)
        # every reducer must have a matching message producer
        assert msg_fn is not None, \
            "cannot find message func for reduce func in-field {}".format(rfn.msg_field)
        pairs.append((msg_fn, rfn))
    return pairs
def run(self):
attr = None
for exe in self.executors:
res = exe.run()
if attr is None:
attr = res
else:
feat = g.get_n_repr()[field]
shape = F.shape(feat)
return (len(shape) == 1 or len(shape) == 2)
def _is_spmv_supported_edge_feat(g, field):
# check shape, only scalar edge feature can be optimized at the moment.
if field is None:
feat = g.get_e_repr()
# attr and res must be dict
attr.update(res)
self.g.set_n_repr(attr, self.graph_mapping)
class BundledUpdateAllExecutor(BundledExecutor, UpdateAllExecutor):
    # Diamond mixin: bundled dispatch from BundledExecutor, update-all
    # graph state (lazy caches, adjacency, ALL mapping) from
    # UpdateAllExecutor.
    def __init__(self, graph, mfunc, rfunc):
        # Deliberately calls BundledExecutor.__init__ explicitly (not
        # super()) so UpdateAllExecutor.__init__ is skipped; only its
        # _init_state cache reset is wanted here.
        self._init_state()
        BundledExecutor.__init__(self, graph, mfunc, rfunc)
class BundledSendRecvExecutor(BundledExecutor, SendRecvExecutor):
    # Diamond mixin: bundled dispatch from BundledExecutor, send/recv
    # graph state (edge broadcasting, relabel mapping) from
    # SendRecvExecutor.
    def __init__(self, graph, src, dst, mfunc, rfunc):
        # Explicit base call skips SendRecvExecutor.__init__; only its
        # _init_state (edge broadcasting + cache reset) is wanted here.
        self._init_state(src, dst)
        BundledExecutor.__init__(self, graph, mfunc, rfunc)
def _is_spmv_supported(fn, graph=None):
    """Return True if ``fn`` (a message or reduce function) can run as SPMV.

    Message functions need ``graph`` to inspect feature shapes; reduce
    functions decide from their own configuration.  Anything else (e.g. a
    user-defined function, or a list that was not bundled) is not
    SPMV-capable.  Unreachable diff-residue statements referencing an
    undefined ``g`` were removed from the else branch.
    """
    if isinstance(fn, fmsg.MessageFunction):
        return fn.is_spmv_supported(graph)
    elif isinstance(fn, fred.ReduceFunction):
        return fn.is_spmv_supported()
    else:
        return False
def _create_update_all_exec(graph, **kwargs):
    """Create an executor for update_all, or None when SPMV cannot be used.

    Lists/tuples of functions are wrapped into bundled functions and run
    by BundledUpdateAllExecutor; single builtins use UpdateAllExecutor.
    Returning None tells the caller to fall back to the non-SPMV
    (degree-bucketing) path.  Pre-merge residue dispatching to the removed
    UpdateAllSPMVExecutor/SumReducerFunction classes was removed.
    """
    mfunc = kwargs.pop('message_func')
    rfunc = kwargs.pop('reduce_func')
    if isinstance(mfunc, (list, tuple)) or isinstance(rfunc, (list, tuple)):
        mfunc = fmsg.BundledMessageFunction(mfunc)
        rfunc = fred.BundledReduceFunction(rfunc)
        exec_cls = BundledUpdateAllExecutor
    else:
        exec_cls = UpdateAllExecutor
    if _is_spmv_supported(mfunc, graph) and _is_spmv_supported(rfunc):
        return exec_cls(graph, mfunc=mfunc, rfunc=rfunc)
    else:
        return None
......@@ -183,28 +324,14 @@ def _create_send_and_recv_exec(graph, **kwargs):
dst = kwargs.pop('dst')
mfunc = kwargs.pop('message_func')
rfunc = kwargs.pop('reduce_func')
if (isinstance(mfunc, fmsg.CopySrcMessageFunction)
and isinstance(rfunc, fred.SumReducerFunction)
and _is_spmv_supported_node_feat(graph, mfunc.src_field)):
# TODO(minjie): more sanity check on field names
return SendRecvSPMVExecutor(graph,
src=src,
dst=dst,
src_field=mfunc.src_field,
dst_field=rfunc.out_field,
edge_field=None,
use_edge_dat=False)
elif (isinstance(mfunc, fmsg.SrcMulEdgeMessageFunction)
and isinstance(rfunc, fred.SumReducerFunction)
and _is_spmv_supported_node_feat(graph, mfunc.src_field)
and _is_spmv_supported_edge_feat(graph, mfunc.edge_field)):
return SendRecvSPMVExecutor(graph,
src=src,
dst=dst,
src_field=mfunc.src_field,
dst_field=rfunc.out_field,
edge_field=mfunc.edge_field,
use_edge_dat=True)
if isinstance(mfunc, (list, tuple)) or isinstance(rfunc, (list, tuple)):
mfunc = fmsg.BundledMessageFunction(mfunc)
rfunc = fred.BundledReduceFunction(rfunc)
exec_cls = BundledSendRecvExecutor
else:
exec_cls = SendRecvExecutor
if _is_spmv_supported(mfunc, graph) and _is_spmv_supported(rfunc):
return exec_cls(graph, src=src, dst=dst, mfunc=mfunc, rfunc=rfunc)
else:
return None
......
......@@ -113,6 +113,110 @@ def test_send_and_recv():
# test 2d node features
_test('f2')
def test_update_all_multi_fn():
    """Exercise update_all with lists of message/reduce functions,
    mixing builtins and user-defined functions (UDFs)."""
    def message_func(hu, edge):
        # UDF mirroring copy_src on field 'f2'
        return {'m2': hu['f2']}
    def message_func_edge(hu, edge):
        # UDF mirroring src_mul_edge on fields 'f2' * 'e2'
        return {'m2': hu['f2'] * edge['e2']}
    def reduce_func(hv, msgs):
        # UDF mirroring sum over the 'm2' messages
        return {'v2': th.sum(msgs['m2'], 1)}
    g = generate_graph()
    fld = 'f2'
    # update all, mix of builtin and UDF: both paths must agree
    g.update_all([fn.copy_src(src=fld, out='m1'), message_func],
                 [fn.sum(msgs='m1', out='v1'), reduce_func],
                 None, batchable=True)
    v1 = g.get_n_repr()['v1']
    v2 = g.get_n_repr()['v2']
    assert th.allclose(v1, v2)
    # run builtin with single message and reduce
    g.update_all(fn.copy_src(src=fld), fn.sum(out='v1'), None, batchable=True)
    v1 = g.get_n_repr()['v1']
    assert th.allclose(v1, v2)
    # 1 message, 2 reduces, using anonymous repr
    g.update_all(fn.copy_src(src=fld), [fn.sum(out='v2'), fn.sum(out='v3')], None, batchable=True)
    v2 = g.get_n_repr()['v2']
    v3 = g.get_n_repr()['v3']
    assert th.allclose(v1, v2)
    assert th.allclose(v1, v3)
    # update all with edge weights, 2 message, 3 reduces
    # (assumes generate_graph makes 'e1' and 'e2' equivalent — TODO confirm)
    g.update_all([fn.src_mul_edge(src=fld, edge='e1', out='m1'), fn.src_mul_edge(src=fld, edge='e2', out='m2')],
                 [fn.sum(msgs='m1', out='v1'), fn.sum(msgs='m2', out='v2'), fn.sum(msgs='m1', out='v3')],
                 None, batchable=True)
    v1 = g.get_n_repr()['v1']
    v2 = g.get_n_repr()['v2']
    v3 = g.get_n_repr()['v3']
    assert th.allclose(v1, v2)
    assert th.allclose(v1, v3)
    # run UDF with single message and reduce
    g.update_all(message_func_edge, reduce_func, None, batchable=True)
    v2 = g.get_n_repr()['v2']
    assert th.allclose(v1, v2)
def test_send_and_recv_multi_fn():
    """Exercise send_and_recv on a fixed edge set with lists of
    message/reduce functions, mixing builtins and UDFs."""
    u = th.tensor([0, 0, 0, 3, 4, 9])
    v = th.tensor([1, 2, 3, 9, 9, 0])
    def message_func(hu, edge):
        # UDF mirroring copy_src on field 'f2'
        return {'m2': hu['f2']}
    def message_func_edge(hu, edge):
        # UDF mirroring src_mul_edge on fields 'f2' * 'e2'
        return {'m2': hu['f2'] * edge['e2']}
    def reduce_func(hv, msgs):
        # UDF mirroring sum over the 'm2' messages
        return {'v2' : th.sum(msgs['m2'], 1)}
    g = generate_graph()
    fld = 'f2'
    # send and recv, mix of builtin and UDF: both paths must agree
    g.send_and_recv(u, v,
                    [fn.copy_src(src=fld, out='m1'), message_func],
                    [fn.sum(msgs='m1', out='v1'), reduce_func],
                    None, batchable=True)
    v1 = g.get_n_repr()['v1']
    v2 = g.get_n_repr()['v2']
    assert th.allclose(v1, v2)
    # run builtin with single message and reduce
    g.send_and_recv(u, v, fn.copy_src(src=fld), fn.sum(out='v1'),
                    None, batchable=True)
    v1 = g.get_n_repr()['v1']
    assert th.allclose(v1, v2)
    # 1 message, 2 reduces, using anonymous repr
    g.send_and_recv(u, v, fn.copy_src(src=fld), [fn.sum(out='v2'), fn.sum(out='v3')], None, batchable=True)
    v2 = g.get_n_repr()['v2']
    v3 = g.get_n_repr()['v3']
    assert th.allclose(v1, v2)
    assert th.allclose(v1, v3)
    # send and recv with edge weights, 2 message, 3 reduces
    # (assumes generate_graph makes 'e1' and 'e2' equivalent — TODO confirm)
    g.send_and_recv(u, v,
                    [fn.src_mul_edge(src=fld, edge='e1', out='m1'), fn.src_mul_edge(src=fld, edge='e2', out='m2')],
                    [fn.sum(msgs='m1', out='v1'), fn.sum(msgs='m2', out='v2'), fn.sum(msgs='m1', out='v3')],
                    None, batchable=True)
    v1 = g.get_n_repr()['v1']
    v2 = g.get_n_repr()['v2']
    v3 = g.get_n_repr()['v3']
    assert th.allclose(v1, v2)
    assert th.allclose(v1, v3)
    # run UDF with single message and reduce
    g.send_and_recv(u, v, message_func_edge,
                    reduce_func, None, batchable=True)
    v2 = g.get_n_repr()['v2']
    assert th.allclose(v1, v2)
if __name__ == '__main__':
#test_update_all()
test_update_all()
test_send_and_recv()
test_update_all_multi_fn()
test_send_and_recv_multi_fn()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment