Unverified Commit 9b0a01db authored by Lingfan Yu, committed by GitHub

Support list of msg or reduce func for update_all and send_and_recv (#58)

* support multiple fields in SPMV

* finish the SPMV executor

* handle the non-SPMV case

* refactor code to give a single mfunc/rfunc a shortcut

* add two test cases for multiple message/reduce funcs

* catch cases where a message uses an anonymous field

* default to ALL for update_edge

* add more corner-case tests

* print failed tests

* delete print

* fix builtin max reducer
parent 61fa3c6c
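In short, `update_all` and `send_and_recv` now accept a list of message functions and a list of reduce functions, pairing reducers with messages by field name. A usage sketch adapted from the new tests in this diff (assumes `import dgl.function as fn` and a graph `g` like the tests' `generate_graph()`, with node field `'f2'` and edge fields `'e1'`/`'e2'`):

```python
import dgl.function as fn

# message/reduce funcs may now be passed as lists; each builtin in a list
# must name its out field so the per-function results can be merged
g.update_all([fn.src_mul_edge(src='f2', edge='e1', out='m1'),
              fn.src_mul_edge(src='f2', edge='e2', out='m2')],
             [fn.sum(msgs='m1', out='v1'),
              fn.sum(msgs='m2', out='v2')],
             None, batchable=True)   # None for the update func, as in the tests
```

Reducers are matched to messages through those `out`/`msgs` names; see `_match_message_with_reduce` in the scheduler diff below.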
@@ -2,9 +2,11 @@
 from __future__ import absolute_import

 import operator
+import dgl.backend as F

 __all__ = ["MessageFunction", "src_mul_edge", "copy_src", "copy_edge"]

 class MessageFunction(object):
     def __call__(self, src, edge):
         raise NotImplementedError
@@ -12,10 +14,28 @@ class MessageFunction(object):
     def name(self):
         raise NotImplementedError

+    def is_spmv_supported(self, g):
+        raise NotImplementedError
+

 class BundledMessageFunction(MessageFunction):
     def __init__(self, fn_list):
+        if not isinstance(fn_list, (list, tuple)):
+            fn_list = [fn_list]
+        else:
+            # sanity check on out field
+            for fn in fn_list:
+                # cannot perform check for udf
+                if isinstance(fn, MessageFunction) and fn.out_field is None:
+                    raise RuntimeError("Not specifying out field for multiple messages is ambiguous")
         self.fn_list = fn_list

+    def is_spmv_supported(self, g):
+        for fn in self.fn_list:
+            if not isinstance(fn, MessageFunction) or not fn.is_spmv_supported(g):
+                return False
+        return True
+
     def __call__(self, src, edge):
         ret = None
         for fn in self.fn_list:
@@ -24,16 +44,34 @@ class BundledMessageFunction(MessageFunction):
                 ret = msg
             else:
                 try:
+                    # ret and msg must be dict
                     ret.update(msg)
-                except e:
-                    raise RuntimeError("Failed to merge results of two builtin"
-                                       " message functions. Please specify out_field"
-                                       " for the builtin message function.")
+                except:
+                    raise RuntimeError("Must specify out field for multiple message functions")
         return ret

     def name(self):
         return "bundled"

+def _is_spmv_supported_node_feat(g, field):
+    if field is None:
+        feat = g.get_n_repr()
+    else:
+        feat = g.get_n_repr()[field]
+    shape = F.shape(feat)
+    return len(shape) == 1 or len(shape) == 2
+
+def _is_spmv_supported_edge_feat(g, field):
+    # check shape; only scalar edge features can be optimized at the moment
+    if field is None:
+        feat = g.get_e_repr()
+    else:
+        feat = g.get_e_repr()[field]
+    shape = F.shape(feat)
+    return len(shape) == 1 or (len(shape) == 2 and shape[1] == 1)
+
 class SrcMulEdgeMessageFunction(MessageFunction):
     def __init__(self, mul_op, src_field=None, edge_field=None, out_field=None):
         self.mul_op = mul_op
@@ -41,6 +79,10 @@ class SrcMulEdgeMessageFunction(MessageFunction):
         self.edge_field = edge_field
         self.out_field = out_field

+    def is_spmv_supported(self, g):
+        return _is_spmv_supported_node_feat(g, self.src_field) \
+            and _is_spmv_supported_edge_feat(g, self.edge_field)
+
     def __call__(self, src, edge):
         if self.src_field is not None:
             src = src[self.src_field]
@@ -60,6 +102,9 @@ class CopySrcMessageFunction(MessageFunction):
         self.src_field = src_field
         self.out_field = out_field

+    def is_spmv_supported(self, g):
+        return _is_spmv_supported_node_feat(g, self.src_field)
+
     def __call__(self, src, edge):
         if self.src_field is not None:
             ret = src[self.src_field]
@@ -78,6 +123,11 @@ class CopyEdgeMessageFunction(MessageFunction):
         self.edge_field = edge_field
         self.out_field = out_field

+    def is_spmv_supported(self, g):
+        # TODO: support this with g-spmv
+        return False
+        # return _is_spmv_supported_edge_feat(g, self.edge_field)
+
     def __call__(self, src, edge):
         if self.edge_field is not None:
             ret = edge[self.edge_field]
@@ -90,7 +140,8 @@ class CopyEdgeMessageFunction(MessageFunction):
     def name(self):
         return "copy_edge"

+
 def src_mul_edge(src=None, edge=None, out=None):
     """TODO(minjie): docstring """
     return SrcMulEdgeMessageFunction(operator.mul, src, edge, out)
......
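A minimal sketch of the out-field rule enforced in `BundledMessageFunction.__init__` above (the field names `'h'`/`'w'` are hypothetical):

```python
import dgl.function as fn
from dgl.function.message import BundledMessageFunction

# ok: each builtin names its out field, so the merged message dict is unambiguous
BundledMessageFunction([fn.copy_src(src='h', out='m1'),
                        fn.src_mul_edge(src='h', edge='w', out='m2')])

# raises RuntimeError: anonymous builtins in a list cannot be merged
BundledMessageFunction([fn.copy_src(src='h'), fn.copy_src(src='h')])
```

UDFs are exempt from the check because their output fields cannot be inspected ahead of time; they simply return their own dict keys.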
@@ -12,10 +12,26 @@ class ReduceFunction(object):
     def name(self):
         raise NotImplementedError

+    def is_spmv_supported(self):
+        raise NotImplementedError
+

 class BundledReduceFunction(ReduceFunction):
     def __init__(self, fn_list):
+        if not isinstance(fn_list, (list, tuple)):
+            fn_list = [fn_list]
+        else:
+            # sanity check on out field
+            for fn in fn_list:
+                if isinstance(fn, ReduceFunction) and fn.out_field is None:
+                    raise RuntimeError("Not specifying out field for multiple reducers is ambiguous")
         self.fn_list = fn_list

+    def is_spmv_supported(self):
+        for fn in self.fn_list:
+            if not isinstance(fn, ReduceFunction) or not fn.is_spmv_supported():
+                return False
+        return True
+
     def __call__(self, node, msgs):
         ret = None
         for fn in self.fn_list:
@@ -24,46 +40,50 @@ class BundledReduceFunction(ReduceFunction):
                 ret = rpr
             else:
                 try:
+                    # ret and rpr must be dict
                     ret.update(rpr)
-                except e:
-                    raise RuntimeError("Failed to merge results of two builtin"
-                                       " reduce functions. Please specify out_field"
-                                       " for the builtin reduce function.")
+                except:
+                    raise RuntimeError("Must specify out field for multiple reduce functions")
         return ret

     def name(self):
         return "bundled"

-class SumReducerFunction(ReduceFunction):
-    def __init__(self, batch_sum_op, nonbatch_sum_op, msg_field=None, out_field=None):
-        self.batch_sum_op = batch_sum_op
-        self.nonbatch_sum_op = nonbatch_sum_op
+class ReducerFunctionTemplate(ReduceFunction):
+    def __init__(self, name, batch_op, nonbatch_op, msg_field=None, out_field=None):
+        self._name = name  # underscore avoids shadowing the name() method below
+        self.batch_op = batch_op
+        self.nonbatch_op = nonbatch_op
         self.msg_field = msg_field
         self.out_field = out_field

+    def is_spmv_supported(self):
+        # TODO: support max
+        return self._name == "sum"
+
     def __call__(self, node, msgs):
         if isinstance(msgs, list):
             if self.msg_field is None:
-                ret = self.nonbatch_sum_op(msgs)
+                ret = self.nonbatch_op(msgs)
             else:
-                ret = self.nonbatch_sum_op([msg[self.msg_field] for msg in msgs])
+                ret = self.nonbatch_op([msg[self.msg_field] for msg in msgs])
         else:
             if self.msg_field is None:
-                ret = self.batch_sum_op(msgs, 1)
+                ret = self.batch_op(msgs, 1)
             else:
-                ret = self.batch_sum_op(msgs[self.msg_field], 1)
+                ret = self.batch_op(msgs[self.msg_field], 1)
         if self.out_field is None:
             return ret
         else:
             return {self.out_field : ret}

     def name(self):
-        return "sum"
+        return self._name

 _python_sum = sum
 def sum(msgs=None, out=None):
-    return SumReducerFunction(F.sum, _python_sum, msgs, out)
+    return ReducerFunctionTemplate("sum", F.sum, _python_sum, msgs, out)

 _python_max = max
 def max(msgs=None, out=None):
-    return SumReducerFunction(F.max, _python_max, msgs, out)
+    return ReducerFunctionTemplate("max", F.max, _python_max, msgs, out)
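Both builtin reducers are now instances of `ReducerFunctionTemplate`, and the scheduler below asks them `is_spmv_supported()` before taking the sparse-matmul path; only `sum` currently qualifies. A small sketch (field names hypothetical):

```python
import dgl.function as fn

r_sum = fn.sum(msgs='m', out='h_sum')
r_max = fn.max(msgs='m', out='h_max')
assert r_sum.is_spmv_supported()       # "sum" can run as a sparse matmul
assert not r_max.is_spmv_supported()   # "max" falls back to the non-SPMV path
```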
@@ -15,6 +15,8 @@ from dgl.frame import FrameRef, merge_frames
 from dgl.nx_adapt import nx_init
 import dgl.scheduler as scheduler
 import dgl.utils as utils
+from dgl.function.message import BundledMessageFunction
+from dgl.function.reducer import BundledReduceFunction

 class DGLGraph(DiGraph):
     """Base graph class specialized for neural networks on graphs.
@@ -434,6 +436,8 @@ class DGLGraph(DiGraph):
         if message_func == "default":
             message_func, batchable = self._message_func
         assert message_func is not None
+        if isinstance(message_func, (tuple, list)):
+            message_func = BundledMessageFunction(message_func)
         if batchable:
             self._batch_send(u, v, message_func)
         else:
@@ -473,7 +477,7 @@ class DGLGraph(DiGraph):
         else:
             self._msg_frame.append({__MSG__ : msgs})

-    def update_edge(self, u, v, edge_func="default", batchable=False):
+    def update_edge(self, u=ALL, v=ALL, edge_func="default", batchable=False):
         """Update representation on edge u->v

         The edge function should be compatible with following signature:
@@ -576,6 +580,8 @@ class DGLGraph(DiGraph):
         if reduce_func == "default":
             reduce_func, batchable = self._reduce_func
         assert reduce_func is not None
+        if isinstance(reduce_func, (list, tuple)):
+            reduce_func = BundledReduceFunction(reduce_func)
         if batchable:
             self._batch_recv(u, reduce_func)
         else:
......
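With the new `ALL` defaults, `update_edge` can be invoked without explicit endpoints. A sketch (the edge UDF below is hypothetical; its exact signature is whatever `update_edge`'s docstring specifies):

```python
def my_edge_func(u, v, edge):
    # hypothetical UDF: scales an assumed edge field 'e1'
    return {'e1': edge['e1'] * 2}

g.update_edge(edge_func=my_edge_func, batchable=True)  # u=ALL, v=ALL by default
g.update_edge(u, v, my_edge_func, batchable=True)      # explicit endpoints still work
```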
@@ -7,6 +7,7 @@ import dgl.backend as F
 import dgl.function.message as fmsg
 import dgl.function.reducer as fred
 import dgl.utils as utils
+from dgl.base import ALL

 __all__ = ["degree_bucketing", "get_executor"]
@@ -38,37 +39,32 @@ def degree_bucketing(cached_graph, v):
     #print('degree-bucketing:', unique_degrees, [len(b) for b in v_bkt])
     return unique_degrees, v_bkt

 class Executor(object):
-    def run(self, graph):
+    def run(self):
         raise NotImplementedError

-class UpdateAllSPMVExecutor(Executor):
-    def __init__(self, graph, src_field, dst_field, edge_field, use_adj):
-        self.graph = graph
+class SPMVOperator(Executor):
+    def __init__(self, src_field, edge_field, dst_field, use_edge_feat,
+                 node_repr, adj_build_fn):
         self.src_field = src_field
-        self.dst_field = dst_field
         self.edge_field = edge_field
-        self.use_adj = use_adj
+        self.dst_field = dst_field
+        self.use_edge_feat = use_edge_feat
+        self.node_repr = node_repr
+        self.adj_build_fn = adj_build_fn

     def run(self):
-        g = self.graph
+        # get src col
         if self.src_field is None:
-            srccol = g.get_n_repr()
+            srccol = self.node_repr
         else:
-            srccol = g.get_n_repr()[self.src_field]
+            srccol = self.node_repr[self.src_field]
         ctx = F.get_context(srccol)
-        if self.use_adj:
-            adjmat = g.cached_graph.adjmat().get(ctx)
-        else:
-            if self.edge_field is None:
-                dat = g.get_e_repr()
-            else:
-                dat = g.get_e_repr()[self.edge_field]
-            dat = F.squeeze(dat)
-            # TODO(minjie): should not directly use _indices
-            idx = g.cached_graph.adjmat().get(ctx)._indices()
-            n = g.number_of_nodes()
-            adjmat = F.sparse_tensor(idx, dat, [n, n])
+
+        # build adjmat
+        adjmat = self.adj_build_fn(self.edge_field, ctx, self.use_edge_feat)
+
         # spmm
         if len(F.shape(srccol)) == 1:
             srccol = F.unsqueeze(srccol, 1)
@@ -77,104 +73,249 @@ class UpdateAllSPMVExecutor(Executor):
             dstcol = F.spmm(adjmat, srccol)
             dstcol = F.squeeze(dstcol)
         else:
             dstcol = F.spmm(adjmat, srccol)
         if self.dst_field is None:
-            g.set_n_repr(dstcol)
+            return dstcol
         else:
-            g.set_n_repr({self.dst_field : dstcol})
+            return {self.dst_field : dstcol}

-class SendRecvSPMVExecutor(Executor):
-    def __init__(self, graph, src, dst, src_field, dst_field, edge_field, use_edge_dat):
-        self.graph = graph
-        self.src = src
-        self.dst = dst
-        self.src_field = src_field
-        self.dst_field = dst_field
-        self.edge_field = edge_field
-        self.use_edge_dat = use_edge_dat
-
-    def run(self):
-        # get src col
-        g = self.graph
-        if self.src_field is None:
-            srccol = g.get_n_repr()
-        else:
-            srccol = g.get_n_repr()[self.src_field]
-        ctx = F.get_context(srccol)
-        # build adjmat
-        # build adjmat dat
-        u, v = utils.edge_broadcasting(self.src, self.dst)
-        if self.use_edge_dat:
-            if self.edge_field is None:
-                dat = g.get_e_repr(u, v)
-            else:
-                dat = g.get_e_repr(u, v)[self.edge_field]
-            dat = F.squeeze(dat)
-        else:
-            dat = F.ones((len(u),))
-        # build adjmat index
-        new2old, old2new = utils.build_relabel_map(v)
-        u = u.totensor()
-        v = v.totensor()
-        # TODO(minjie): should not directly use []
-        new_v = old2new[v]
-        idx = F.pack([F.unsqueeze(new_v, 0), F.unsqueeze(u, 0)])
-        n = g.number_of_nodes()
-        m = len(new2old)
-        adjmat = F.sparse_tensor(idx, dat, [m, n])
-        adjmat = F.to_context(adjmat, ctx)
-        # spmm
-        if len(F.shape(srccol)) == 1:
-            srccol = F.unsqueeze(srccol, 1)
-            dstcol = F.spmm(adjmat, srccol)
-            dstcol = F.squeeze(dstcol)
-        else:
-            dstcol = F.spmm(adjmat, srccol)
-        if self.dst_field is None:
-            g.set_n_repr(dstcol, new2old)
-        else:
-            g.set_n_repr({self.dst_field : dstcol}, new2old)
-
-def _is_spmv_supported_node_feat(g, field):
-    if field is None:
-        feat = g.get_n_repr()
-    else:
-        feat = g.get_n_repr()[field]
-    shape = F.shape(feat)
-    return (len(shape) == 1 or len(shape) == 2)
-
-def _is_spmv_supported_edge_feat(g, field):
-    # check shape, only scalar edge feature can be optimized at the moment.
-    if field is None:
-        feat = g.get_e_repr()
-    else:
-        feat = g.get_e_repr()[field]
-    shape = F.shape(feat)
-    return len(shape) == 1 or (len(shape) == 2 and shape[1] == 1)
+class BasicExecutor(Executor):
+    def __init__(self, graph, mfunc, rfunc):
+        self.g = graph
+        self.exe = self._build_exec(mfunc, rfunc)
+
+    @property
+    def node_repr(self):
+        raise NotImplementedError
+
+    @property
+    def edge_repr(self):
+        raise NotImplementedError
+
+    @property
+    def graph_mapping(self):
+        raise NotImplementedError
+
+    def _build_exec(self, mfunc, rfunc):
+        if isinstance(mfunc, fmsg.CopySrcMessageFunction):
+            exe = SPMVOperator(src_field=mfunc.src_field,
+                               edge_field=None,
+                               dst_field=rfunc.out_field,
+                               use_edge_feat=False,
+                               node_repr=self.node_repr,
+                               adj_build_fn=self._adj_build_fn)
+        elif isinstance(mfunc, fmsg.SrcMulEdgeMessageFunction):
+            exe = SPMVOperator(src_field=mfunc.src_field,
+                               edge_field=mfunc.edge_field,
+                               dst_field=rfunc.out_field,
+                               use_edge_feat=True,
+                               node_repr=self.node_repr,
+                               adj_build_fn=self._adj_build_fn)
+        else:
+            raise NotImplementedError("message func type {}".format(type(mfunc)))
+        return exe
+
+    def run(self):
+        attr = self.exe.run()
+        self.g.set_n_repr(attr, self.graph_mapping)
+
+class UpdateAllExecutor(BasicExecutor):
+    def __init__(self, graph, mfunc, rfunc):
+        self._init_state()
+        super(UpdateAllExecutor, self).__init__(graph, mfunc, rfunc)
+
+    def _init_state(self):
+        self._node_repr = None
+        self._edge_repr = None
+        self._graph_idx = None
+        self._graph_shape = None
+        self._graph_mapping = None
+
+    @property
+    def graph_idx(self):
+        if self._graph_idx is None:
+            self._graph_idx = self.g.cached_graph.adjmat()
+        return self._graph_idx
+
+    @property
+    def graph_shape(self):
+        if self._graph_shape is None:
+            n = self.g.number_of_nodes()
+            self._graph_shape = [n, n]
+        return self._graph_shape
+
+    @property
+    def graph_mapping(self):
+        return ALL
+
+    @property
+    def node_repr(self):
+        if self._node_repr is None:
+            self._node_repr = self.g.get_n_repr()
+        return self._node_repr
+
+    @property
+    def edge_repr(self):
+        if self._edge_repr is None:
+            self._edge_repr = self.g.get_e_repr()
+        return self._edge_repr
+
+    def _adj_build_fn(self, edge_field, ctx, use_edge_feat):
+        if use_edge_feat:
+            if edge_field is None:
+                dat = self.edge_repr
+            else:
+                dat = self.edge_repr[edge_field]
+            dat = F.squeeze(dat)
+            # TODO(minjie): should not directly use _indices
+            idx = self.graph_idx.get(ctx)._indices()
+            adjmat = F.sparse_tensor(idx, dat, self.graph_shape)
+        else:
+            adjmat = self.graph_idx.get(ctx)
+        return adjmat
+
+class SendRecvExecutor(BasicExecutor):
+    def __init__(self, graph, src, dst, mfunc, rfunc):
+        self._init_state(src, dst)
+        super(SendRecvExecutor, self).__init__(graph, mfunc, rfunc)
+
+    def _init_state(self, src, dst):
+        self.u, self.v = utils.edge_broadcasting(src, dst)
+        self._node_repr = None
+        self._edge_repr = None
+        self._graph_idx = None
+        self._graph_shape = None
+        self._graph_mapping = None
+
+    @property
+    def graph_idx(self):
+        if self._graph_idx is None:
+            self._build_adjmat()
+        return self._graph_idx
+
+    @property
+    def graph_shape(self):
+        if self._graph_shape is None:
+            self._build_adjmat()
+        return self._graph_shape
+
+    @property
+    def graph_mapping(self):
+        if self._graph_mapping is None:
+            self._build_adjmat()
+        return self._graph_mapping
+
+    @property
+    def node_repr(self):
+        if self._node_repr is None:
+            self._node_repr = self.g.get_n_repr()
+        return self._node_repr
+
+    @property
+    def edge_repr(self):
+        if self._edge_repr is None:
+            self._edge_repr = self.g.get_e_repr(self.u, self.v)
+        return self._edge_repr
+
+    def _build_adjmat(self):
+        # handle graph index
+        new2old, old2new = utils.build_relabel_map(self.v)
+        u = self.u.totensor()
+        v = self.v.totensor()
+        # TODO(minjie): should not directly use []
+        new_v = old2new[v]
+        n = self.g.number_of_nodes()
+        m = len(new2old)
+        self._graph_idx = F.pack([F.unsqueeze(new_v, 0), F.unsqueeze(u, 0)])
+        self._graph_shape = [m, n]
+        self._graph_mapping = new2old
+
+    def _adj_build_fn(self, edge_field, ctx, use_edge_feat):
+        if use_edge_feat:
+            if edge_field is None:
+                dat = self.edge_repr
+            else:
+                dat = self.edge_repr[edge_field]
+            dat = F.squeeze(dat)
+        else:
+            dat = F.ones((len(self.u), ))
+        adjmat = F.sparse_tensor(self.graph_idx, dat, self.graph_shape)
+        return F.to_context(adjmat, ctx)
+
+class BundledExecutor(BasicExecutor):
+    """
+    Base class for bundled execution.
+
+    All shared structure, such as the graph index, should be cached in this
+    class or its subclasses. BundledUpdateAllExecutor and
+    BundledSendRecvExecutor should subclass BundledExecutor.
+    """
+    def __init__(self, graph, mfunc, rfunc):
+        self.g = graph
+        func_pairs = self._match_message_with_reduce(mfunc, rfunc)
+        # create all executors
+        self.executors = self._build_executors(func_pairs)
+
+    def _build_executors(self, func_pairs):
+        executors = []
+        for mfunc, rfunc in func_pairs:
+            exe = self._build_exec(mfunc, rfunc)
+            executors.append(exe)
+        return executors
+
+    def _match_message_with_reduce(self, mfunc, rfunc):
+        out2mfunc = {fn.out_field: fn for fn in mfunc.fn_list}
+        func_pairs = []
+        for rfn in rfunc.fn_list:
+            mfn = out2mfunc.get(rfn.msg_field, None)
+            # field check
+            assert mfn is not None, \
+                "cannot find message func for reduce func in-field {}".format(rfn.msg_field)
+            func_pairs.append((mfn, rfn))
+        return func_pairs
+
+    def run(self):
+        attr = None
+        for exe in self.executors:
+            res = exe.run()
+            if attr is None:
+                attr = res
+            else:
+                # attr and res must be dict
+                attr.update(res)
+        self.g.set_n_repr(attr, self.graph_mapping)
+
+class BundledUpdateAllExecutor(BundledExecutor, UpdateAllExecutor):
+    def __init__(self, graph, mfunc, rfunc):
+        self._init_state()
+        BundledExecutor.__init__(self, graph, mfunc, rfunc)
+
+class BundledSendRecvExecutor(BundledExecutor, SendRecvExecutor):
+    def __init__(self, graph, src, dst, mfunc, rfunc):
+        self._init_state(src, dst)
+        BundledExecutor.__init__(self, graph, mfunc, rfunc)
+
+def _is_spmv_supported(fn, graph=None):
+    if isinstance(fn, fmsg.MessageFunction):
+        return fn.is_spmv_supported(graph)
+    elif isinstance(fn, fred.ReduceFunction):
+        return fn.is_spmv_supported()
+    else:
+        return False

 def _create_update_all_exec(graph, **kwargs):
     mfunc = kwargs.pop('message_func')
     rfunc = kwargs.pop('reduce_func')
-    if (isinstance(mfunc, fmsg.CopySrcMessageFunction)
-            and isinstance(rfunc, fred.SumReducerFunction)
-            and _is_spmv_supported_node_feat(graph, mfunc.src_field)):
-        # TODO(minjie): more sanity check on field names
-        return UpdateAllSPMVExecutor(graph,
-                                     src_field=mfunc.src_field,
-                                     dst_field=rfunc.out_field,
-                                     edge_field=None,
-                                     use_adj=True)
-    elif (isinstance(mfunc, fmsg.SrcMulEdgeMessageFunction)
-            and isinstance(rfunc, fred.SumReducerFunction)
-            and _is_spmv_supported_node_feat(graph, mfunc.src_field)
-            and _is_spmv_supported_edge_feat(graph, mfunc.edge_field)):
-        return UpdateAllSPMVExecutor(graph,
-                                     src_field=mfunc.src_field,
-                                     dst_field=rfunc.out_field,
-                                     edge_field=mfunc.edge_field,
-                                     use_adj=False)
-    elif (isinstance(mfunc, fmsg.CopyEdgeMessageFunction)
-            and isinstance(rfunc, fred.SumReducerFunction)):
-        return None
+    if isinstance(mfunc, (list, tuple)) or isinstance(rfunc, (list, tuple)):
+        mfunc = fmsg.BundledMessageFunction(mfunc)
+        rfunc = fred.BundledReduceFunction(rfunc)
+        exec_cls = BundledUpdateAllExecutor
+    else:
+        exec_cls = UpdateAllExecutor
+    if _is_spmv_supported(mfunc, graph) and _is_spmv_supported(rfunc):
+        return exec_cls(graph, mfunc=mfunc, rfunc=rfunc)
     else:
         return None
@@ -183,28 +324,14 @@ def _create_send_and_recv_exec(graph, **kwargs):
     dst = kwargs.pop('dst')
     mfunc = kwargs.pop('message_func')
     rfunc = kwargs.pop('reduce_func')
-    if (isinstance(mfunc, fmsg.CopySrcMessageFunction)
-            and isinstance(rfunc, fred.SumReducerFunction)
-            and _is_spmv_supported_node_feat(graph, mfunc.src_field)):
-        # TODO(minjie): more sanity check on field names
-        return SendRecvSPMVExecutor(graph,
-                                    src=src,
-                                    dst=dst,
-                                    src_field=mfunc.src_field,
-                                    dst_field=rfunc.out_field,
-                                    edge_field=None,
-                                    use_edge_dat=False)
-    elif (isinstance(mfunc, fmsg.SrcMulEdgeMessageFunction)
-            and isinstance(rfunc, fred.SumReducerFunction)
-            and _is_spmv_supported_node_feat(graph, mfunc.src_field)
-            and _is_spmv_supported_edge_feat(graph, mfunc.edge_field)):
-        return SendRecvSPMVExecutor(graph,
-                                    src=src,
-                                    dst=dst,
-                                    src_field=mfunc.src_field,
-                                    dst_field=rfunc.out_field,
-                                    edge_field=mfunc.edge_field,
-                                    use_edge_dat=True)
+    if isinstance(mfunc, (list, tuple)) or isinstance(rfunc, (list, tuple)):
+        mfunc = fmsg.BundledMessageFunction(mfunc)
+        rfunc = fred.BundledReduceFunction(rfunc)
+        exec_cls = BundledSendRecvExecutor
+    else:
+        exec_cls = SendRecvExecutor
+    if _is_spmv_supported(mfunc, graph) and _is_spmv_supported(rfunc):
+        return exec_cls(graph, src=src, dst=dst, mfunc=mfunc, rfunc=rfunc)
     else:
         return None
......
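The executor refactor above keeps the same underlying trick: when messages are copy_src (or src_mul_edge) and the reducer is sum, message passing collapses into one sparse-dense matmul. A standalone PyTorch sketch of that equivalence (toy graph, not DGL code):

```python
import torch as th

# toy graph: edges 0->1, 0->2, 1->2; adjacency is indexed [dst, src]
idx = th.tensor([[1, 2, 2],    # destination row ids
                 [0, 0, 1]])   # source column ids
adj = th.sparse_coo_tensor(idx, th.ones(3), (3, 3))

h = th.randn(3, 4)             # node features, one row per node
out = th.sparse.mm(adj, h)     # out[v] = sum of h[u] over incoming edges

# same result via explicit message passing
ref = th.zeros(3, 4)
for u, v in [(0, 1), (0, 2), (1, 2)]:
    ref[v] += h[u]
assert th.allclose(out, ref)
```

For send_and_recv, the same matmul runs on an m x n matrix whose rows are the relabeled destination nodes; that relabeling is what `_build_adjmat` constructs and what `graph_mapping` undoes when writing results back.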
@@ -113,6 +113,110 @@ def test_send_and_recv():
     # test 2d node features
     _test('f2')

+def test_update_all_multi_fn():
+    def message_func(hu, edge):
+        return {'m2': hu['f2']}
+
+    def message_func_edge(hu, edge):
+        return {'m2': hu['f2'] * edge['e2']}
+
+    def reduce_func(hv, msgs):
+        return {'v2': th.sum(msgs['m2'], 1)}
+
+    g = generate_graph()
+    fld = 'f2'
+
+    # update all, mix of builtin and UDF
+    g.update_all([fn.copy_src(src=fld, out='m1'), message_func],
+                 [fn.sum(msgs='m1', out='v1'), reduce_func],
+                 None, batchable=True)
+    v1 = g.get_n_repr()['v1']
+    v2 = g.get_n_repr()['v2']
+    assert th.allclose(v1, v2)
+
+    # run builtin with single message and reduce
+    g.update_all(fn.copy_src(src=fld), fn.sum(out='v1'), None, batchable=True)
+    v1 = g.get_n_repr()['v1']
+    assert th.allclose(v1, v2)
+
+    # 1 message, 2 reduces, using anonymous repr
+    g.update_all(fn.copy_src(src=fld), [fn.sum(out='v2'), fn.sum(out='v3')], None, batchable=True)
+    v2 = g.get_n_repr()['v2']
+    v3 = g.get_n_repr()['v3']
+    assert th.allclose(v1, v2)
+    assert th.allclose(v1, v3)
+
+    # update all with edge weights, 2 messages, 3 reduces
+    g.update_all([fn.src_mul_edge(src=fld, edge='e1', out='m1'), fn.src_mul_edge(src=fld, edge='e2', out='m2')],
+                 [fn.sum(msgs='m1', out='v1'), fn.sum(msgs='m2', out='v2'), fn.sum(msgs='m1', out='v3')],
+                 None, batchable=True)
+    v1 = g.get_n_repr()['v1']
+    v2 = g.get_n_repr()['v2']
+    v3 = g.get_n_repr()['v3']
+    assert th.allclose(v1, v2)
+    assert th.allclose(v1, v3)
+
+    # run UDF with single message and reduce
+    g.update_all(message_func_edge, reduce_func, None, batchable=True)
+    v2 = g.get_n_repr()['v2']
+    assert th.allclose(v1, v2)
+
+def test_send_and_recv_multi_fn():
+    u = th.tensor([0, 0, 0, 3, 4, 9])
+    v = th.tensor([1, 2, 3, 9, 9, 0])
+
+    def message_func(hu, edge):
+        return {'m2': hu['f2']}
+
+    def message_func_edge(hu, edge):
+        return {'m2': hu['f2'] * edge['e2']}
+
+    def reduce_func(hv, msgs):
+        return {'v2' : th.sum(msgs['m2'], 1)}
+
+    g = generate_graph()
+    fld = 'f2'
+
+    # send and recv, mix of builtin and UDF
+    g.send_and_recv(u, v,
+                    [fn.copy_src(src=fld, out='m1'), message_func],
+                    [fn.sum(msgs='m1', out='v1'), reduce_func],
+                    None, batchable=True)
+    v1 = g.get_n_repr()['v1']
+    v2 = g.get_n_repr()['v2']
+    assert th.allclose(v1, v2)
+
+    # run builtin with single message and reduce
+    g.send_and_recv(u, v, fn.copy_src(src=fld), fn.sum(out='v1'),
+                    None, batchable=True)
+    v1 = g.get_n_repr()['v1']
+    assert th.allclose(v1, v2)
+
+    # 1 message, 2 reduces, using anonymous repr
+    g.send_and_recv(u, v, fn.copy_src(src=fld), [fn.sum(out='v2'), fn.sum(out='v3')], None, batchable=True)
+    v2 = g.get_n_repr()['v2']
+    v3 = g.get_n_repr()['v3']
+    assert th.allclose(v1, v2)
+    assert th.allclose(v1, v3)
+
+    # send and recv with edge weights, 2 messages, 3 reduces
+    g.send_and_recv(u, v,
+                    [fn.src_mul_edge(src=fld, edge='e1', out='m1'), fn.src_mul_edge(src=fld, edge='e2', out='m2')],
+                    [fn.sum(msgs='m1', out='v1'), fn.sum(msgs='m2', out='v2'), fn.sum(msgs='m1', out='v3')],
+                    None, batchable=True)
+    v1 = g.get_n_repr()['v1']
+    v2 = g.get_n_repr()['v2']
+    v3 = g.get_n_repr()['v3']
+    assert th.allclose(v1, v2)
+    assert th.allclose(v1, v3)
+
+    # run UDF with single message and reduce
+    g.send_and_recv(u, v, message_func_edge,
+                    reduce_func, None, batchable=True)
+    v2 = g.get_n_repr()['v2']
+    assert th.allclose(v1, v2)
+
 if __name__ == '__main__':
-    #test_update_all()
+    test_update_all()
     test_send_and_recv()
+    test_update_all_multi_fn()
+    test_send_and_recv_multi_fn()