"vscode:/vscode.git/clone" did not exist on "b83e38a6417bd41e4ce5edd81a5a9696abc9441d"
Unverified Commit b1eeb934 authored by Minjie Wang's avatar Minjie Wang Committed by GitHub
Browse files

[Bugfix][Runtime] Zero degree behaviors (#177)

* fix recv nodes are all 0deg; fix hybriddict does not throw keyerror properly

* fallback to apply_nodes when all nodes are 0deg; WIP on pull spmv 0deg

* new 0deg behavior

* new 0deg behavior

* update mx utest for pull-0deg

* fix mx

* fix mx

* get rid of unnecessary sort-n-unique
parent 4dfe7547
......@@ -852,8 +852,9 @@ def frame_like(other, num_rows):
other._warn_and_set_initializer()
newf._default_initializer = other._default_initializer
# set per-col initializer
for key in other.keys():
newf.set_initializer(other.get_initializer(key), key)
# TODO(minjie): hack; cannot rely on keys as the _initializers
# now supports non-exist columns.
newf._initializers = other._initializers
return newf
def merge_frames(frames, indices, max_index, reduce_func):
......
......@@ -1094,7 +1094,6 @@ class DGLGraph(object):
"""
if message_func == "default":
message_func = self._message_func
assert message_func is not None
if is_all(edges):
eid = ALL
......@@ -1252,7 +1251,7 @@ class DGLGraph(object):
if len(v) == 0:
return
with ir.prog() as prog:
scheduler.schedule_pull(graph=self, v=v,
scheduler.schedule_pull(graph=self, pull_nodes=v,
message_func=message_func, reduce_func=reduce_func,
apply_func=apply_node_func)
Runtime.run(prog)
......
......@@ -70,18 +70,18 @@ def gen_degree_bucketing_schedule(
# save for merge
idx_list.append(vb)
fd_list.append(fdvb)
if zero_deg_nodes is not None:
# NOTE: there must be at least one non-zero-deg node; otherwise,
# degree bucketing should not be called.
var_0deg = var.IDX(zero_deg_nodes)
zero_feat = ir.NEW_DICT(var_out, var_0deg, fd_list[0])
idx_list.append(var_0deg)
fd_list.append(zero_feat)
# merge buckets according to the ascending order of the node ids.
all_idx = F.cat([idx.data.tousertensor() for idx in idx_list], dim=0)
sorted_idx, order = F.sort_1d(all_idx)
var_sorted_idx = var.IDX(utils.toindex(sorted_idx))
_, order = F.sort_1d(all_idx)
var_order = var.IDX(utils.toindex(order))
reduced_feat = ir.MERGE_ROW(var_order, fd_list)
if zero_deg_nodes is not None:
# If has zero degrees, scatter the result back to the frame. As
# a result, the features for zero degree nodes will be initialized
# correctly.
ir.WRITE_ROW_(var_out, var_sorted_idx, reduced_feat)
else:
ir.WRITE_DICT_(var_out, reduced_feat)
def _degree_bucketing_schedule(mids, dsts, v):
......
......@@ -2,6 +2,7 @@ from __future__ import absolute_import
from abc import abstractmethod
from ...base import DGLError
from ... import backend as F
from ...frame import FrameRef, Frame
from ... import utils
......@@ -22,6 +23,7 @@ class OpCode(object):
READ_ROW = 6
MERGE_ROW = 7
UPDATE_DICT = 8
NEW_DICT = 9
# mutable op (no return)
# remember the name is suffixed with "_"
WRITE_ = 21
......@@ -391,6 +393,49 @@ def UPDATE_DICT(fd1, fd2, ret=None):
get_current_prog().issue(reg['executor_cls'](fd1, fd2, ret))
return ret
class NewDictExecutor(Executor):
    """Executor for the NEW_DICT IR op.

    Creates a fresh feature dict with one column per scheme found in
    ``fd_scheme`` and ``len(idx)`` rows.  Each column is filled by the
    per-column initializer borrowed from ``fd_init``, so rows for
    zero-degree nodes receive the user-specified initial values.
    """
    def __init__(self, fd_init, idx, fd_scheme, ret):
        self.fd_init = fd_init      # the feat dict to borrow initializer
        self.idx = idx              # the index to look for number or rows
        self.fd_scheme = fd_scheme  # the feat dict to look for column scheme
        self.ret = ret              # the result

    def opcode(self):
        """Return the opcode of this executor."""
        return OpCode.NEW_DICT

    def arg_vars(self):
        """Return the argument variables of this op."""
        return [self.fd_init, self.idx, self.fd_scheme]

    def ret_var(self):
        """Return the result variable."""
        return self.ret

    def run(self):
        """Materialize the new frame and store it in the result var."""
        fd_init_data = self.fd_init.data
        idx_data = self.idx.data
        fd_scheme_data = self.fd_scheme.data
        schemes = fd_scheme_data.schemes
        ret_dict = {}
        for key, sch in schemes.items():
            # Borrow the initializer registered for this column from the
            # init frame; the context comes from the scheme frame's column.
            initializer = fd_init_data.get_initializer(key)
            ctx = F.context(fd_scheme_data[key])
            shape = (len(idx_data),) + sch.shape
            # FIXME: the last argument here can only be idx; range
            #   is meaningless. Need to rethink the signature.
            ret_dict[key] = initializer(shape, sch.dtype, ctx, idx_data)
        self.ret.data = FrameRef(Frame(ret_dict))
# Registry entry for NEW_DICT: (feat dict, index, feat dict) -> feat dict.
IR_REGISTRY[OpCode.NEW_DICT] = {
    'name' : 'NEW_DICT',
    'args_type' : [VarType.FEAT_DICT, VarType.IDX, VarType.FEAT_DICT],
    'ret_type' : VarType.FEAT_DICT,
    'executor_cls' : NewDictExecutor,
}

def NEW_DICT(fd_init, idx, fd_scheme, ret=None):
    """Issue a NEW_DICT op on the current program.

    Parameters
    ----------
    fd_init : var.Var
        The feature dict to borrow per-column initializers from.
    idx : var.Var
        The index whose length gives the number of rows of the new dict.
    fd_scheme : var.Var
        The feature dict providing the column schemes.
    ret : var.Var, optional
        The variable to hold the result; a fresh one is created if None.

    Returns
    -------
    var.Var
        The variable that will hold the newly created feature dict.
    """
    reg = IR_REGISTRY[OpCode.NEW_DICT]
    ret = var.new(reg['ret_type']) if ret is None else ret
    get_current_prog().issue(reg['executor_cls'](fd_init, idx, fd_scheme, ret))
    return ret
class Write_Executor(Executor):
def __init__(self, fd, row, col, val):
self.fd = fd
......
......@@ -41,6 +41,8 @@ def schedule_send(graph, u, v, eid, message_func):
message_func: callable or list of callable
The message function
"""
# TODO(minjie): support builtin message func
message_func = _standardize_func_usage(message_func, 'message')
# vars
nf = var.FEAT_DICT(graph._node_frame)
ef = var.FEAT_DICT(graph._edge_frame)
......@@ -66,67 +68,22 @@ def schedule_recv(graph, recv_nodes, reduce_func, apply_func):
apply_func: callable
The apply node function
"""
nf = var.FEAT_DICT(graph._node_frame, name='nf')
src, dst, mid = graph._msg_graph.in_edges(recv_nodes)
if len(mid) == 0:
# All recv nodes are 0-degree nodes; downgrade to apply nodes.
if apply_func is not None:
schedule_apply_nodes(graph, recv_nodes, apply_func)
else:
var_nf = var.FEAT_DICT(graph._node_frame, name='nf')
# sort and unique the argument
recv_nodes, _ = F.sort_1d(F.unique(recv_nodes.tousertensor()))
recv_nodes = utils.toindex(recv_nodes)
reduced_feat = _gen_reduce(graph, reduce_func, recv_nodes)
var_recv_nodes = var.IDX(recv_nodes, name='recv_nodes')
if apply_func:
# To avoid writing reduced features back to node frame and reading
# it again for apply phase. Instead, we first read the the node
# features and "merge" it with the reduced features.
v_nf = ir.READ_ROW(nf, var_recv_nodes)
v_nf = ir.UPDATE_DICT(v_nf, reduced_feat)
def _afunc_wrapper(node_data):
nb = NodeBatch(graph, recv_nodes, node_data)
return apply_func(nb)
afunc = var.FUNC(_afunc_wrapper)
applied_feat = ir.NODE_UDF(afunc, v_nf)
final_feat = ir.UPDATE_DICT(reduced_feat, applied_feat)
else:
final_feat = reduced_feat
ir.WRITE_ROW_(nf, var_recv_nodes, final_feat)
def _gen_reduce(graph, reduce_func, recv_nodes):
"""
graph : DGLGraph
reduce_func : callable
recv_nodes : utils.Index
"""
call_type = "recv"
_, dst, mid = graph._msg_graph.in_edges(recv_nodes)
rfunc = _standardize_func_usage(reduce_func)
rfunc_is_list = utils.is_iterable(rfunc)
# Create a tmp frame to hold the feature data.
# The frame has the same size and schemes of the
# node frame.
# TODO(minjie): should replace this with an IR call to make the program stateless.
tmpframe = FrameRef(frame_like(graph._node_frame._frame, len(recv_nodes)))
# vars
msg = var.FEAT_DICT(graph._msg_frame, 'msg')
nf = var.FEAT_DICT(graph._node_frame, 'nf')
out = var.FEAT_DICT(data=tmpframe)
if rfunc_is_list:
# UDF message + builtin reducer
# analyze e2v spmv
spmv_rfunc, rfunc = spmv.analyze_e2v_spmv(graph, rfunc)
inc = spmv.build_inc_matrix(call_type, graph, mid, dst)
spmv.gen_e2v_spmv_schedule(inc, spmv_rfunc, msg, out)
if len(rfunc) == 0:
# All mfunc and rfunc has been processed.
return out
# convert the remaining rfunc to UDFs
rfunc = BundledFunction(rfunc)
# gen degree bucketing schedule for UDF recv
db.gen_degree_bucketing_schedule(graph, rfunc, mid, dst,
recv_nodes, nf, msg, out)
return out
# reduce
reduced_feat = _gen_reduce(graph, reduce_func, (src, dst, mid), recv_nodes)
# apply
final_feat = _apply_with_accum(graph, var_recv_nodes, var_nf, reduced_feat, apply_func)
ir.WRITE_ROW_(var_nf, var_recv_nodes, final_feat)
def schedule_snr(graph,
edge_tuples,
......@@ -147,114 +104,9 @@ def schedule_snr(graph,
reduced_feat = _gen_send_reduce(call_type, graph,
message_func, reduce_func, (var_u, var_v, var_eid), recv_nodes)
# generate apply schedule
if apply_func:
# To avoid writing reduced features back to node frame and reading
# it again for apply phase. Instead, we first read the the node
# features and "merge" it with the reduced features.
v_nf = ir.READ_ROW(var_nf, var_recv_nodes)
v_nf = ir.UPDATE_DICT(v_nf, reduced_feat)
def _afunc_wrapper(node_data):
nb = NodeBatch(graph, recv_nodes, node_data)
return apply_func(nb)
afunc = var.FUNC(_afunc_wrapper)
applied_feat = ir.NODE_UDF(afunc, v_nf)
final_feat = ir.UPDATE_DICT(reduced_feat, applied_feat)
else:
final_feat = reduced_feat
final_feat = _apply_with_accum(graph, var_recv_nodes, var_nf, reduced_feat, apply_func)
ir.WRITE_ROW_(var_nf, var_recv_nodes, final_feat)
def _gen_send_reduce(
call_type,
graph,
message_func,
reduce_func,
edge_tuples,
recv_nodes):
"""Generate send and reduce schedule.
This guarantees that the returned reduced features are batched
in the *unique-ascending* order of the edge destination node ids.
call_type : str
graph : DGLGraph
message_func : callable, list of builtins
reduce_func : callable, list of builtins
edge_tuples : (u, v, eid) tuple of var.Var
recv_nodes : utils.index
"""
# arg vars
var_u, var_v, var_eid = edge_tuples
var_nf = var.FEAT_DICT(graph._node_frame, name='nf')
var_ef = var.FEAT_DICT(graph._edge_frame, name='ef')
# format the input functions
mfunc = _standardize_func_usage(message_func)
rfunc = _standardize_func_usage(reduce_func)
mfunc_is_list = utils.is_iterable(mfunc)
rfunc_is_list = utils.is_iterable(rfunc)
# Create a tmp frame to hold the feature data.
# The frame has the same size and schemes of the
# node frame.
# TODO(minjie): should replace this with an IR call to make the program stateless.
tmpframe = FrameRef(frame_like(graph._node_frame._frame, len(recv_nodes)))
var_out = var.FEAT_DICT(data=tmpframe)
if mfunc_is_list and rfunc_is_list:
# builtin message + builtin reducer
# analyze v2v spmv
spmv_pairs, mfunc, rfunc = spmv.analyze_v2v_spmv(graph, mfunc, rfunc)
adj = spmv.build_adj_matrix(call_type, graph, var_u.data, var_v.data)
spmv.gen_v2v_spmv_schedule(adj, spmv_pairs, var_nf, var_ef, var_eid, var_out)
if len(mfunc) == 0:
# All mfunc and rfunc have been converted to v2v spmv.
return var_out
if mfunc_is_list:
# Two cases:
# - mfunc is builtin while rfunc is UDF.
# - mfunc and rfunc are both builtin but some combinations
# fall through from the v2v spmv analysis.
# In both cases, convert the mfunc to UDF.
mfunc = BundledFunction(mfunc)
# generate UDF send schedule
var_mf = _gen_send(graph, var_nf, var_ef, var_u, var_v, var_eid, mfunc)
if rfunc_is_list:
# UDF message + builtin reducer
# analyze e2v spmv
spmv_rfunc, rfunc = spmv.analyze_e2v_spmv(graph, rfunc)
inc = spmv.build_inc_matrix(call_type, graph, var_eid.data, var_v.data)
spmv.gen_e2v_spmv_schedule(inc, spmv_rfunc, var_mf, var_out)
if len(rfunc) == 0:
# All mfunc and rfunc has been processed.
return var_out
# convert the remaining rfunc to UDFs
rfunc = BundledFunction(rfunc)
# gen degree bucketing schedule for UDF recv
mid = utils.toindex(slice(0, len(var_v.data))) # message id is from 0~|dst|
db.gen_degree_bucketing_schedule(graph, rfunc,
mid, var_v.data, recv_nodes,
var_nf, var_mf, var_out)
return var_out
def _gen_send(graph, nf, ef, u, v, eid, mfunc):
fdsrc = ir.READ_ROW(nf, u)
fddst = ir.READ_ROW(nf, v)
fdedge = ir.READ_ROW(ef, eid)
def _mfunc_wrapper(src_data, edge_data, dst_data):
eb = EdgeBatch(graph, (u.data, v.data, eid.data),
src_data, edge_data, dst_data)
return mfunc(eb)
_mfunc_wrapper = var.FUNC(_mfunc_wrapper)
msg = ir.EDGE_UDF(_mfunc_wrapper, fdsrc, fdedge, fddst)
return msg
def schedule_update_all(graph, message_func, reduce_func, apply_func):
"""get send and recv schedule
......@@ -269,6 +121,12 @@ def schedule_update_all(graph, message_func, reduce_func, apply_func):
apply_func: callable
The apply node function
"""
if graph.number_of_edges() == 0:
# All the nodes are zero degree; downgrade to apply nodes
if apply_func is not None:
nodes = utils.toindex(slice(0, graph.number_of_nodes()))
schedule_apply_nodes(graph, nodes, apply_func)
else:
call_type = 'update_all'
src, dst, _ = graph._graph.edges()
eid = utils.toindex(slice(0, graph.number_of_edges())) # shortcut for ALL
......@@ -283,20 +141,7 @@ def schedule_update_all(graph, message_func, reduce_func, apply_func):
reduced_feat = _gen_send_reduce(call_type, graph,
message_func, reduce_func, (var_src, var_dst, var_eid), recv_nodes)
# generate optional apply
if apply_func:
# To avoid writing reduced features back to node frame and reading
# it again for apply phase. Instead, we first read the the node
# features and "merge" it with the reduced features.
v_nf = ir.READ_ROW(var_nf, var_recv_nodes)
v_nf = ir.UPDATE_DICT(v_nf, reduced_feat)
def _afunc_wrapper(node_data):
nb = NodeBatch(graph, recv_nodes, node_data)
return apply_func(nb)
afunc = var.FUNC(_afunc_wrapper)
applied_feat = ir.NODE_UDF(afunc, v_nf)
final_feat = ir.UPDATE_DICT(reduced_feat, applied_feat)
else:
final_feat = reduced_feat
final_feat = _apply_with_accum(graph, var_recv_nodes, var_nf, reduced_feat, apply_func)
ir.WRITE_DICT_(var_nf, final_feat)
def schedule_apply_nodes(graph, v, apply_func):
......@@ -378,25 +223,21 @@ def schedule_push(graph, u, message_func, reduce_func, apply_func):
The reduce function
apply_func: callable
The apply node function
Returns
-------
A list of executors for DGL Runtime
"""
# FIXME: for now, use send_and_recv to implement push
u, v, eid = graph._graph.out_edges(u)
if len(eid) == 0:
return []
# All the pushing nodes have no out edges. No computation is scheduled.
return
schedule_snr(graph, (u, v, eid), message_func, reduce_func, apply_func)
def schedule_pull(graph, v, message_func, reduce_func, apply_func):
def schedule_pull(graph, pull_nodes, message_func, reduce_func, apply_func):
"""get pull schedule
Parameters
----------
graph: DGLGraph
The DGLGraph to use
v : utils.Index
pull_nodes : utils.Index
Destination nodes for pull
message_func: callable or list of callable
The message function
......@@ -404,16 +245,31 @@ def schedule_pull(graph, v, message_func, reduce_func, apply_func):
The reduce function
apply_func: callable
The apply node function
Returns
-------
A list of executors for DGL Runtime
"""
# FIXME: for now, use send_and_recv to implement pull
u, v, eid = graph._graph.in_edges(v)
# TODO(minjie): `in_edges` can be omitted if message and reduce func pairs
# can be specialized to SPMV. This needs support for creating adjmat
# directly from dst node frontier.
u, v, eid = graph._graph.in_edges(pull_nodes)
if len(eid) == 0:
return []
schedule_snr(graph, (u, v, eid), message_func, reduce_func, apply_func)
# All the nodes are 0deg; downgrades to apply.
if apply_func is not None:
schedule_apply_nodes(graph, pull_nodes, apply_func)
else:
call_type = 'send_and_recv'
pull_nodes, _ = F.sort_1d(F.unique(pull_nodes.tousertensor()))
pull_nodes = utils.toindex(pull_nodes)
# create vars
var_nf = var.FEAT_DICT(graph._node_frame, name='nf')
var_pull_nodes = var.IDX(pull_nodes, name='pull_nodes')
var_u = var.IDX(u)
var_v = var.IDX(v)
var_eid = var.IDX(eid)
# generate send and reduce schedule
reduced_feat = _gen_send_reduce(call_type, graph,
message_func, reduce_func, (var_u, var_v, var_eid), pull_nodes)
# generate optional apply
final_feat = _apply_with_accum(graph, var_pull_nodes, var_nf, reduced_feat, apply_func)
ir.WRITE_ROW_(var_nf, var_pull_nodes, final_feat)
def _check_builtin_func_list(func_list):
"""Check whether func_list only contains builtin functions."""
......@@ -422,7 +278,7 @@ def _check_builtin_func_list(func_list):
raise DGLError("If specify multiple message/reduce functions, \
all of them must be builtin")
def _standardize_func_usage(func):
def _standardize_func_usage(func, func_name):
    """Standardize usages of message and reduce functions.

    A message or reduce function can be:

    1. a UDF (any callable) -- returned as-is;
    2. a single builtin function -- returned wrapped in a list;
    3. a list of builtin functions -- validated and returned as-is.

    Parameters
    ----------
    func : callable, BuiltinFunction, or list of BuiltinFunction
        The user-provided function.
    func_name : str
        The role of the function (e.g. 'message' or 'reduce'); used in
        the error message.

    Returns
    -------
    callable or list of BuiltinFunction
        The standardized function.

    Raises
    ------
    DGLError
        If a list contains non-builtin functions, or a single function
        is not callable.
    """
    if utils.is_iterable(func):
        # func is a list of builtin
        _check_builtin_func_list(func)
        return func
    elif isinstance(func, BuiltinFunction):
        # func is one builtin-in
        return [func]
    else:
        # func is one UDF
        if not callable(func):
            raise DGLError('User-defined %s function must be callable.'
                           ' Got: %s' % (func_name, str(func)))
        return func
def _apply_with_accum(graph, var_nodes, var_nf, var_accum, apply_func):
    """Apply with accumulated features.

    Parameters
    ----------
    graph : DGLGraph
        The graph.
    var_nodes : var.IDX
        The nodes.
    var_nf : var.FEAT_DICT
        The node features.
    var_accum : var.FEAT_DICT
        The accumulated features.
    apply_func : callable, None
        The apply function.

    Returns
    -------
    var.FEAT_DICT
        The accumulated features merged with the apply function's
        output, or ``var_accum`` unchanged when no apply function is
        given.
    """
    if apply_func:
        # To avoid writing reduced features back to node frame and reading
        # it again for apply phase. Instead, we first read the node
        # features and "merge" it with the reduced features.
        v_nf = ir.READ_ROW(var_nf, var_nodes)
        v_nf = ir.UPDATE_DICT(v_nf, var_accum)
        def _afunc_wrapper(node_data):
            nb = NodeBatch(graph, var_nodes.data, node_data)
            return apply_func(nb)
        afunc = var.FUNC(_afunc_wrapper)
        applied_feat = ir.NODE_UDF(afunc, v_nf)
        final_feat = ir.UPDATE_DICT(var_accum, applied_feat)
    else:
        final_feat = var_accum
    return final_feat
def _gen_reduce(graph, reduce_func, edge_tuples, recv_nodes):
    """Generate the reduce schedule for a recv call.

    Parameters
    ----------
    graph : DGLGraph
        The graph.
    reduce_func : callable
        The reduce function.
    edge_tuples : tuple of utils.Index
        (src, dst, mid) of the triggered messages.
    recv_nodes : utils.Index
        The receiving nodes (the target dimension of the reduce).

    Returns
    -------
    var.FEAT_DICT
        The reduced features, one row per node in ``recv_nodes``.
    """
    call_type = "recv"
    _, dst, mid = edge_tuples
    rfunc = _standardize_func_usage(reduce_func, 'reduce')
    rfunc_is_list = utils.is_iterable(rfunc)
    # Create a tmp frame to hold the feature data.
    # The frame has the same size and schemes of the
    # node frame.
    # TODO(minjie): should replace this with an IR call to make the program stateless.
    tmpframe = FrameRef(frame_like(graph._node_frame._frame, len(recv_nodes)))
    # vars
    msg = var.FEAT_DICT(graph._msg_frame, 'msg')
    nf = var.FEAT_DICT(graph._node_frame, 'nf')
    out = var.FEAT_DICT(data=tmpframe)
    if rfunc_is_list:
        # UDF message + builtin reducer
        # analyze e2v spmv
        spmv_rfunc, rfunc = spmv.analyze_e2v_spmv(graph, rfunc)
        # FIXME: refactor this when fixing the multi-recv bug
        mat = spmv.build_incmat_by_eid(graph._msg_frame.num_rows, mid, dst, recv_nodes)
        inc = utils.CtxCachedObject(lambda ctx : F.copy_to(mat, ctx))
        spmv.gen_e2v_spmv_schedule(inc, spmv_rfunc, msg, out)
        if len(rfunc) == 0:
            # All mfunc and rfunc has been processed.
            return out
        # convert the remaining rfunc to UDFs
        rfunc = BundledFunction(rfunc)
    # gen degree bucketing schedule for UDF recv
    db.gen_degree_bucketing_schedule(graph, rfunc, mid, dst,
                                     recv_nodes, nf, msg, out)
    return out
def _gen_send_reduce(
        call_type,
        graph,
        message_func,
        reduce_func,
        edge_tuples,
        recv_nodes):
    """Generate send and reduce schedule.

    This guarantees that the returned reduced features are batched
    in the *unique-ascending* order of the edge destination node ids.

    Parameters
    ----------
    call_type : str
        The call type (e.g. 'update_all', 'send_and_recv').
    graph : DGLGraph
        The graph.
    message_func : callable, list of builtins
        The message function.
    reduce_func : callable, list of builtins
        The reduce function.
    edge_tuples : (u, v, eid) tuple of var.Var
        The triggered edges.
    recv_nodes : utils.index
        The nodes receiving messages; the target dimension of the
        reduce (includes zero-degree nodes).

    Returns
    -------
    var.FEAT_DICT
        The reduced features, one row per node in ``recv_nodes``.
    """
    # arg vars
    var_u, var_v, var_eid = edge_tuples
    var_nf = var.FEAT_DICT(graph._node_frame, name='nf')
    var_ef = var.FEAT_DICT(graph._edge_frame, name='ef')
    # format the input functions
    mfunc = _standardize_func_usage(message_func, 'message')
    rfunc = _standardize_func_usage(reduce_func, 'reduce')
    mfunc_is_list = utils.is_iterable(mfunc)
    rfunc_is_list = utils.is_iterable(rfunc)
    # Create a tmp frame to hold the feature data.
    # The frame has the same size and schemes of the
    # node frame.
    # TODO(minjie): should replace this with an IR call to make the program stateless.
    tmpframe = FrameRef(frame_like(graph._node_frame._frame, len(recv_nodes)))
    var_out = var.FEAT_DICT(data=tmpframe)
    if mfunc_is_list and rfunc_is_list:
        # builtin message + builtin reducer
        # analyze v2v spmv
        spmv_pairs, mfunc, rfunc = spmv.analyze_v2v_spmv(graph, mfunc, rfunc)
        adj = spmv.build_adj_matrix(call_type, graph,
                                    (var_u.data, var_v.data), recv_nodes)
        spmv.gen_v2v_spmv_schedule(adj, spmv_pairs, var_nf, var_ef, var_eid, var_out)
        if len(mfunc) == 0:
            # All mfunc and rfunc have been converted to v2v spmv.
            return var_out
    if mfunc_is_list:
        # Two cases:
        # - mfunc is builtin while rfunc is UDF.
        # - mfunc and rfunc are both builtin but some combinations
        #   fall through from the v2v spmv analysis.
        # In both cases, convert the mfunc to UDF.
        mfunc = BundledFunction(mfunc)
    # generate UDF send schedule
    var_mf = _gen_send(graph, var_nf, var_ef, var_u, var_v, var_eid, mfunc)
    if rfunc_is_list:
        # UDF message + builtin reducer
        # analyze e2v spmv
        spmv_rfunc, rfunc = spmv.analyze_e2v_spmv(graph, rfunc)
        inc = spmv.build_inc_matrix(call_type, graph, var_v.data, recv_nodes)
        spmv.gen_e2v_spmv_schedule(inc, spmv_rfunc, var_mf, var_out)
        if len(rfunc) == 0:
            # All mfunc and rfunc has been processed.
            return var_out
        # convert the remaining rfunc to UDFs
        rfunc = BundledFunction(rfunc)
    # gen degree bucketing schedule for UDF recv
    mid = utils.toindex(slice(0, len(var_v.data)))  # message id is from 0~|dst|
    db.gen_degree_bucketing_schedule(graph, rfunc,
                                     mid, var_v.data, recv_nodes,
                                     var_nf, var_mf, var_out)
    return var_out
def _gen_send(graph, nf, ef, u, v, eid, mfunc):
    """Generate the send schedule.

    Reads the src/dst node features and the edge features for the given
    edges, then runs the message UDF on them.

    Returns
    -------
    var.FEAT_DICT
        The computed message features.
    """
    fdsrc = ir.READ_ROW(nf, u)
    fddst = ir.READ_ROW(nf, v)
    fdedge = ir.READ_ROW(ef, eid)
    def _mfunc_wrapper(src_data, edge_data, dst_data):
        eb = EdgeBatch(graph, (u.data, v.data, eid.data),
                       src_data, edge_data, dst_data)
        return mfunc(eb)
    _mfunc_wrapper = var.FUNC(_mfunc_wrapper)
    msg = ir.EDGE_UDF(_mfunc_wrapper, fdsrc, fdedge, fddst)
    return msg
_init_api("dgl.runtime.scheduler")
......@@ -43,7 +43,7 @@ def analyze_v2v_spmv(graph, mfunc, rfunc):
raise DGLError('Reduce function requires message field "%s",'
' but no message function generates it.' % mfld)
mfn = fld2mfunc[mfld]
# FIXME: should pre-compile a look up table
# TODO(minjie): should pre-compile a look up table
if mfn.is_spmv_supported(graph) and rfn.is_spmv_supported():
spmv_pairs.append((mfn, rfn))
else:
......@@ -122,27 +122,40 @@ def gen_e2v_spmv_schedule(inc, spmv_rfunc, mf, out):
ftdst = ir.SPMV(inc_var, ftmsg)
ir.WRITE_COL_(out, var.STR(rfn.out_field), ftdst)
def build_adj_matrix(call_type, graph, u, v):
"""
def build_adj_matrix(call_type, graph, edges, reduce_nodes):
"""Build adjacency matrix.
Parameters
----------
call_type : str
Can be 'update_all', 'send_and_recv'
graph : DGLGraph
u : utils.Index
v : utils.Index
The graph
edges : tuple of utils.Index
(u, v)
reduce_nodes : utils.Index
The nodes to reduce messages, which will be target dimension
of the adjmat. The nodes include unique(v) and zero-degree-nodes.
Returns
-------
utils.CtxCachedObject
Get be used to get adjacency matrix on the provided ctx.
"""
if call_type == "update_all":
# full graph case
return utils.CtxCachedObject(lambda ctx : graph.adjacency_matrix(ctx=ctx))
elif call_type == "send_and_recv":
# edgeset case
mat = build_adj_matrix_uv(graph, u, v)
mat = build_adj_matrix_uv(graph, edges, reduce_nodes)
return utils.CtxCachedObject(lambda ctx : F.copy_to(mat, ctx))
else:
raise DGLError('Invalid call type:', call_type)
def build_adj_matrix_index_uv(graph, u, v):
def build_adj_matrix_index_uv(graph, edges, reduce_nodes):
"""Build adj matrix index and shape using the given (u, v) edges.
The matrix is of shape (len(unique(v)), n), where n is the number of nodes
The matrix is of shape (len(reduce_nodes), n), where n is the number of nodes
in the graph. Therefore, when doing SPMV, the src node data
should be all the node features.
......@@ -154,10 +167,11 @@ def build_adj_matrix_index_uv(graph, u, v):
---------
graph : DGLGraph
The graph
u : utils.Index
Src nodes.
v : utils.Index
Dst nodes.
edges : tuple of utils.Index
(u, v)
reduce_nodes : utils.Index
The nodes to reduce messages, which will be target dimension
of the adjmat. The nodes include unique(v) and zero-degree-nodes.
Returns
-------
......@@ -166,21 +180,23 @@ def build_adj_matrix_index_uv(graph, u, v):
tupe of int
The dense shape.
"""
new2old, old2new = utils.build_relabel_map(v)
# TODO(minjie): add node frontier for this
new2old, old2new = utils.build_relabel_map(reduce_nodes, sorted=True)
u, v = edges
u = u.tousertensor()
v = v.tousertensor()
new_v = old2new[v] # FIXME(minjie): no use []
n = graph.number_of_nodes()
m = len(new2old)
m = len(reduce_nodes)
row = F.unsqueeze(new_v, 0)
col = F.unsqueeze(u, 0)
idx = F.cat([row, col], dim=0)
return ('coo', idx), (m, n)
def build_adj_matrix_uv(graph, u, v):
"""Build adj matrix using the given (u, v) edges.
def build_adj_matrix_uv(graph, edges, reduce_nodes):
"""Build adj matrix using the given (u, v) edges and target nodes.
The matrix is of shape (len(v), n), where n is the number of nodes
The matrix is of shape (len(reduce_nodes), n), where n is the number of nodes
in the graph. Therefore, when doing SPMV, the src node data
should be all the node features.
......@@ -188,64 +204,109 @@ def build_adj_matrix_uv(graph, u, v):
---------
graph : DGLGraph
The graph
u : utils.Index
Src nodes.
v : utils.Index
Dst nodes.
edges : tuple of utils.Index
(u, v)
reduce_nodes : utils.Index
The nodes to reduce messages, which will be target dimension
of the adjmat. The nodes include unique(v) and zero-degree-nodes.
Returns
-------
Sparse matrix
The adjacency matrix on CPU
"""
sp_idx, shape = build_adj_matrix_index_uv(graph, u, v)
sp_idx, shape = build_adj_matrix_index_uv(graph, edges, reduce_nodes)
u, v = edges
nnz = len(u)
# FIXME(minjie): data type
dat = F.ones((nnz,), dtype=F.float32, ctx=F.cpu())
mat = F.sparse_matrix(dat, sp_idx, shape)
return mat
def build_inc_matrix(call_type, graph, eid, v):
"""
def build_inc_matrix(call_type, graph, dst, reduce_nodes):
"""Build incidence matrix.
Parameters
----------
call_type : str
Can be 'update_all', 'send_and_recv'.
graph : DGLGraph
eid : utils.Index
v : utils.Index
The graph.
dst : utils.Index
The destination nodes of the edges.
reduce_nodes : utils.Index
The nodes to reduce messages, which will be target dimension
of the incmat. The nodes include unique(dst) and zero-degree-nodes.
Returns
-------
utils.CtxCachedObject
Get be used to get incidence matrix on the provided ctx.
"""
if call_type == "update_all":
# full graph case
return utils.CtxCachedObject(lambda ctx : graph.incidence_matrix(type='in', ctx=ctx))
elif call_type == "send_and_recv":
# edgeset case
mat = build_inc_matrix_v(v)
return utils.CtxCachedObject(lambda ctx : F.copy_to(mat, ctx))
elif call_type == "recv":
# dst nodeset case
mat = build_inc_matrix_eid(eid, v)
mat = build_incmat_by_dst(dst, reduce_nodes)
return utils.CtxCachedObject(lambda ctx : F.copy_to(mat, ctx))
else:
raise DGLError('Invalid call type:', call_type)
def build_inc_matrix_eid(eid, v):
"""A spmat of shape (n, m), where n=len(unique(v)), m=len(eid).
def build_incmat_by_eid(m, eid, dst, reduce_nodes):
"""Build incidence matrix using edge id and edge dst nodes.
The incidence matrix is of shape (n, m), where n=len(reduce_nodes).
The nnz is equal to len(eid).
Invariant: len(eid) == len(v)
Invariant: len(eid) == len(dst)
The dst nodes will be sorted in the *unique-ascending* order of
their ids. This is compatible with other reduce scheduler such as
degree-bucketing scheduler.
Examples
--------
Total of seven edges. Three edges point to node 1 (eid=0,1,2);
two point to node 3 (eid=3,4); two point to node 4 (eid=5,6).
Only five edges should be included in the result incmat (eid=1,2,3,5,6).
There are five nodes in the final target dimension (0~4),
where node 0 and 2 are two 0-deg nodes.
>>> m = 7
>>> eid = [1, 2, 3, 5, 6]
>>> dst = [1, 1, 3, 4, 4]
>>> reduce_nodes = [0, 1, 2, 3, 4]
>>> build_incmat_by_eid(m, eid, dst, reduce_nodes)
tensor([[0, 0, 0, 0, 0, 0, 0],
[0, 1, 1, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 1, 0, 0, 0],
[0, 0, 0, 0, 0, 1, 1]], shape=(5, 7))
Paramters
---------
m : int
The source dimension size of the incidence matrix.
eid : utils.Index
v : utils.Index
The edge ids. All ids must be in range [0, m).
dst : utils.Index
The edge destination nodes. len(eid) == len(dst).
reduce_nodes : utils.Index
The nodes to reduce messages, which will be target dimension
of the incmat. The nodes include unique(dst) and zero-degree-nodes.
Returns
-------
Sparse matrix
The incidence matrix on CPU
"""
# relabel v to range(0, len(unique(v)))
new2old, old2new = utils.build_relabel_map(v)
v = v.tousertensor()
new2old, old2new = utils.build_relabel_map(reduce_nodes, sorted=True)
dst = dst.tousertensor()
eid = eid.tousertensor()
new_v = old2new[v] # FIXME(minjie): no use []
# relabel edges dsts
new_v = old2new[dst] # FIXME(minjie): no use []
# create sparse index tensor
m = len(eid)
n = len(new2old)
n = len(reduce_nodes)
row = F.unsqueeze(new_v, 0)
col = F.unsqueeze(eid, 0)
idx = F.cat([row, col], dim=0)
......@@ -254,10 +315,38 @@ def build_inc_matrix_eid(eid, v):
dat = F.ones((nnz,), dtype=F.float32, ctx=F.cpu())
return F.sparse_matrix(dat, ('coo', idx), (n, m))
def build_inc_matrix_v(v):
"""A spmat of shape (n, m), where n=len(unique(v)), m=len(v).
def build_incmat_by_dst(dst, reduce_nodes):
"""Build incidence matrix using only edge destinations.
v : utils.Index
The incidence matrix is of shape (n, m), where n=len(reduce_nodes), m=len(dst).
The nnz is equal to len(dst).
Examples
--------
Five edges. Two edges point to node 1; one points to node 3;
two point to node 4. There are five nodes in the final
target dimension (0~4), where node 0 and 2 are two 0-deg nodes.
>>> dst = [1, 1, 3, 4, 4]
>>> reduce_nodes = [0, 1, 2, 3, 4]
>>> build_incmat_by_dst(dst, reduced_nodes)
tensor([[0, 0, 0, 0, 0],
[1, 1, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 1, 0, 0],
[0, 0, 0, 1, 1]], shape=(5, 5))
Parameters
----------
dst : utils.Index
The edge destinations.
reduce_nodes : utils.Index
The nodes to reduce messages, which will be target dimension
of the incmat. The nodes include unique(dst) and zero-degree-nodes.
Returns
-------
Sparse matrix
The incidence matrix on CPU
"""
eid = utils.toindex(F.arange(0, len(v)))
return build_inc_matrix_eid(eid, v)
eid = utils.toindex(F.arange(0, len(dst)))
return build_incmat_by_eid(len(eid), eid, dst, reduce_nodes)
......@@ -168,6 +168,7 @@ class HybridDict(Mapping):
for d in self._dict_like_list:
if key in d:
return d[key]
raise KeyError(key)
def __contains__(self, key):
return key in self.keys()
......@@ -198,7 +199,7 @@ class ReadOnlyDict(Mapping):
def __len__(self):
return len(self._dict_like)
def build_relabel_map(x):
def build_relabel_map(x, sorted=False):
"""Relabel the input ids to continuous ids that starts from zero.
Ids are assigned new ids according to their ascending order.
......@@ -218,6 +219,8 @@ def build_relabel_map(x):
----------
x : Index
The input ids.
sorted : bool, default=False
Whether the input has already been unique and sorted.
Returns
-------
......@@ -229,7 +232,10 @@ def build_relabel_map(x):
new id tensor: new_id = old_to_new[old_id]
"""
x = x.tousertensor()
if not sorted:
unique_x, _ = F.sort_1d(F.unique(x))
else:
unique_x = x
map_len = int(F.max(unique_x, dim=0)) + 1
old_to_new = F.zeros(map_len, dtype=F.int64, ctx=F.cpu())
F.scatter_row_inplace(old_to_new, unique_x, F.arange(0, len(unique_x)))
......
......@@ -282,27 +282,40 @@ def check_pull_0deg(readonly):
return {'m' : edges.src['h']}
def _reduce(nodes):
return {'h' : nodes.mailbox['m'].sum(1)}
def _apply(nodes):
return {'h' : nodes.data['h'] * 2}
def _init2(shape, dtype, ctx, ids):
return 2 + mx.nd.zeros(shape, dtype=dtype, ctx=ctx)
g.set_n_initializer(_init2, 'h')
old_repr = mx.nd.random.normal(shape=(2, 5))
g.set_n_repr({'h' : old_repr})
g.pull(0, _message, _reduce)
# test#1: pull only 0-deg node
g.ndata['h'] = old_repr
g.pull(0, _message, _reduce, _apply)
new_repr = g.ndata['h']
# TODO(minjie): this is not the intended behavior. Pull node#0
# should reset node#0 to the initial value. The bug is because
# current pull is implemented using send_and_recv. Since there
# is no edge to node#0 so the send_and_recv is skipped. Fix this
# behavior when optimizing the pull scheduler.
assert np.allclose(new_repr[0].asnumpy(), old_repr[0].asnumpy())
# 0deg check: equal to apply_nodes
assert np.allclose(new_repr[0].asnumpy(), old_repr[0].asnumpy() * 2)
# non-0deg check: untouched
assert np.allclose(new_repr[1].asnumpy(), old_repr[1].asnumpy())
g.pull(1, _message, _reduce)
new_repr = g.ndata['h']
assert np.allclose(new_repr[1].asnumpy(), old_repr[0].asnumpy())
old_repr = mx.nd.random.normal(shape=(2, 5))
g.set_n_repr({'h' : old_repr})
g.pull([0, 1], _message, _reduce)
# test#2: pull only non-deg node
g.ndata['h'] = old_repr
g.pull(1, _message, _reduce, _apply)
new_repr = g.ndata['h']
# 0deg check: untouched
assert np.allclose(new_repr[0].asnumpy(), old_repr[0].asnumpy())
assert np.allclose(new_repr[1].asnumpy(), old_repr[0].asnumpy())
# non-0deg check: recved node0 and got applied
assert np.allclose(new_repr[1].asnumpy(), old_repr[0].asnumpy() * 2)
# test#3: pull only both nodes
g.ndata['h'] = old_repr
g.pull([0, 1], _message, _reduce, _apply)
new_repr = g.ndata['h']
# 0deg check: init and applied
t = mx.nd.zeros(shape=(2,5)) + 4
assert np.allclose(new_repr[0].asnumpy(), t.asnumpy())
# non-0deg check: recv node0 and applied
assert np.allclose(new_repr[1].asnumpy(), old_repr[0].asnumpy() * 2)
def test_pull_0deg():
    # Delegate to the shared checker defined above; the boolean flag selects
    # a graph variant inside check_pull_0deg — presumably immutable/readonly
    # mode, TODO confirm against check_pull_0deg's signature.
    check_pull_0deg(True)
......
......@@ -235,7 +235,87 @@ def test_update_routines():
assert(reduce_msg_shapes == {(1, 8, D), (9, 1, D)})
reduce_msg_shapes.clear()
def test_reduce_0deg():
def test_recv_0deg():
    """recv() on a graph with a zero-in-degree node.

    Node 0 has no incoming edge; node 1 receives exactly one message
    (the edge 0 -> 1). The 0-deg node must be filled by the registered
    initializer and then run through the apply function.
    """
    graph = DGLGraph()
    graph.add_nodes(2)
    graph.add_edge(0, 1)

    def msg_fn(edges):
        return {'m' : edges.src['h']}

    def reduce_fn(nodes):
        return {'h' : nodes.data['h'] + nodes.mailbox['m'].sum(1)}

    def apply_fn(nodes):
        return {'h' : nodes.data['h'] * 2}

    def init_to_two(shape, dtype, ctx, ids):
        # default initializer: every element starts at 2
        return 2 + th.zeros(shape, dtype=dtype, device=ctx)

    graph.register_message_func(msg_fn)
    graph.register_reduce_func(reduce_fn)
    graph.register_apply_node_func(apply_fn)
    graph.set_n_initializer(init_to_two, 'h')

    # case 1: recv on both the 0-deg node and the non-0-deg node
    before = th.randn((2, 5))
    graph.ndata['h'] = before
    graph.send((0, 1))
    graph.recv([0, 1])
    after = graph.ndata.pop('h')
    # node 0 (0-deg): initialized to 2, then doubled by apply -> 4
    assert U.allclose(after[0], th.full((5,), 4))
    # node 1: h + sum of messages, then doubled by apply
    assert U.allclose(after[1], th.sum(before, 0) * 2)

    # case 2: recv on only the 0-deg node degenerates to apply_nodes
    before = th.randn((2, 5))
    graph.ndata['h'] = before
    graph.send((0, 1))
    graph.recv(0)
    after = graph.ndata.pop('h')
    # node 0: equal to apply_nodes (doubled only)
    assert U.allclose(after[0], 2 * before[0])
    # node 1: untouched
    assert U.allclose(after[1], before[1])
def test_recv_0deg_newfld():
    """recv() with a zero-in-degree node when the reducer creates a NEW
    field ('h1') instead of overwriting an existing one."""
    graph = DGLGraph()
    graph.add_nodes(2)
    graph.add_edge(0, 1)

    def msg_fn(edges):
        return {'m' : edges.src['h']}

    def reduce_fn(nodes):
        # writes a brand-new field 'h1'
        return {'h1' : nodes.data['h'] + nodes.mailbox['m'].sum(1)}

    def apply_fn(nodes):
        return {'h1' : nodes.data['h1'] * 2}

    def init_to_two(shape, dtype, ctx, ids):
        return 2 + th.zeros(shape, dtype=dtype, device=ctx)

    graph.register_message_func(msg_fn)
    graph.register_reduce_func(reduce_fn)
    graph.register_apply_node_func(apply_fn)

    # case 1: recv on both nodes; 'h1' of the 0-deg node comes from the
    # registered initializer, then gets applied
    before = th.randn((2, 5))
    graph.set_n_initializer(init_to_two, 'h1')
    graph.ndata['h'] = before
    graph.send((0, 1))
    graph.recv([0, 1])
    after = graph.ndata.pop('h1')
    # node 0 (0-deg): initialized to 2, doubled -> 4
    assert U.allclose(after[0], th.full((5,), 4))
    # node 1: reduced then doubled
    assert U.allclose(after[1], th.sum(before, 0) * 2)

    # case 2: recv on only the 0-deg node
    before = th.randn((2, 5))
    graph.ndata['h'] = before
    graph.ndata['h1'] = th.full((2, 5), -1) # this is necessary
    graph.send((0, 1))
    graph.recv(0)
    after = graph.ndata.pop('h1')
    # node 0: falls back to apply -> -1 * 2 = -2
    assert U.allclose(after[0], th.full((5,), -2))
    # node 1: not changed
    assert U.allclose(after[1], th.full((5,), -1))
def test_update_all_0deg():
# test#1
g = DGLGraph()
g.add_nodes(5)
g.add_edge(1, 0)
......@@ -246,18 +326,30 @@ def test_reduce_0deg():
return {'m' : edges.src['h']}
def _reduce(nodes):
return {'h' : nodes.data['h'] + nodes.mailbox['m'].sum(1)}
def _apply(nodes):
return {'h' : nodes.data['h'] * 2}
def _init2(shape, dtype, ctx, ids):
return 2 + th.zeros(shape, dtype=dtype, device=ctx)
g.set_n_initializer(_init2, 'h')
old_repr = th.randn(5, 5)
g.ndata['h'] = old_repr
g.update_all(_message, _reduce)
g.update_all(_message, _reduce, _apply)
new_repr = g.ndata['h']
# the first row of the new_repr should be the sum of all the node
# features; while the 0-deg nodes should be initialized by the
# initializer.
assert U.allclose(new_repr[1:], 2+th.zeros((4,5)))
assert U.allclose(new_repr[0], old_repr.sum(0))
# initializer and applied with UDF.
assert U.allclose(new_repr[1:], 2*(2+th.zeros((4,5))))
assert U.allclose(new_repr[0], 2 * old_repr.sum(0))
# test#2: graph with no edge
g = DGLGraph()
g.add_nodes(5)
g.set_n_initializer(_init2, 'h')
g.ndata['h'] = old_repr
g.update_all(_message, _reduce, _apply)
new_repr = g.ndata['h']
# should fallback to apply
assert U.allclose(new_repr, 2*old_repr)
def test_pull_0deg():
g = DGLGraph()
......@@ -266,25 +358,34 @@ def test_pull_0deg():
def _message(edges):
return {'m' : edges.src['h']}
def _reduce(nodes):
return {'h' : nodes.mailbox['m'].sum(1)}
old_repr = th.randn(2, 5)
g.ndata['h'] = old_repr
g.pull(0, _message, _reduce)
new_repr = g.ndata['h']
assert U.allclose(new_repr[0], old_repr[0])
assert U.allclose(new_repr[1], old_repr[1])
g.pull(1, _message, _reduce)
new_repr = g.ndata['h']
assert U.allclose(new_repr[1], old_repr[0])
old_repr = th.randn(2, 5)
g.ndata['h'] = old_repr
g.pull([0, 1], _message, _reduce)
new_repr = g.ndata['h']
assert U.allclose(new_repr[0], old_repr[0])
assert U.allclose(new_repr[1], old_repr[0])
return {'h' : nodes.data['h'] + nodes.mailbox['m'].sum(1)}
def _apply(nodes):
return {'h' : nodes.data['h'] * 2}
def _init2(shape, dtype, ctx, ids):
return 2 + th.zeros(shape, dtype=dtype, device=ctx)
g.register_message_func(_message)
g.register_reduce_func(_reduce)
g.register_apply_node_func(_apply)
g.set_n_initializer(_init2, 'h')
# test#1: pull both 0deg and non-0deg nodes
old = th.randn((2, 5))
g.ndata['h'] = old
g.pull([0, 1])
new = g.ndata.pop('h')
# 0deg check: initialized with the func and got applied
assert U.allclose(new[0], th.full((5,), 4))
# non-0deg check
assert U.allclose(new[1], th.sum(old, 0) * 2)
# test#2: pull only 0deg node
old = th.randn((2, 5))
g.ndata['h'] = old
g.pull(0)
new = g.ndata.pop('h')
# 0deg check: fallback to apply
assert U.allclose(new[0], 2*old[0])
# non-0deg check: not touched
assert U.allclose(new[1], old[1])
def _disabled_test_send_twice():
# TODO(minjie): please re-enable this unittest after the send code problem is fixed.
......@@ -419,7 +520,9 @@ if __name__ == '__main__':
test_apply_nodes()
test_apply_edges()
test_update_routines()
test_reduce_0deg()
test_recv_0deg()
test_recv_0deg_newfld()
test_update_all_0deg()
test_pull_0deg()
test_send_multigraph()
test_dynamic_addition()
......@@ -110,6 +110,52 @@ def test_v2v_snr():
# test 2d node features
_test('f2')
def test_v2v_pull():
    """pull() with builtin message/reduce functions must produce exactly
    the same node features as the equivalent Python UDFs."""
    targets = th.tensor([1, 2, 3, 9])

    def _check(fld):
        def udf_msg(edges):
            return {'m' : edges.src[fld]}

        def udf_msg_edge(edges):
            # pick the edge weight whose rank matches the node feature
            if len(edges.src[fld].shape) == 1:
                return {'m' : edges.src[fld] * edges.data['e1']}
            return {'m' : edges.src[fld] * edges.data['e2']}

        def udf_reduce(nodes):
            return {fld : th.sum(nodes.mailbox['m'], 1)}

        def udf_apply(nodes):
            return {fld : 2 * nodes.data[fld]}

        g = generate_graph()
        # pull with the copy_src/sum builtins vs. the equivalent UDFs
        v1 = g.ndata[fld]
        g.pull(targets, fn.copy_src(src=fld, out='m'), fn.sum(msg='m', out=fld), udf_apply)
        v2 = g.ndata[fld]
        g.ndata[fld] = v1
        g.pull(targets, udf_msg, udf_reduce, udf_apply)
        v3 = g.ndata[fld]
        assert U.allclose(v2, v3)
        # pull with src_mul_edge builtins (scalar and vector edge weights)
        v1 = g.ndata[fld]
        g.pull(targets, fn.src_mul_edge(src=fld, edge='e1', out='m'),
               fn.sum(msg='m', out=fld), udf_apply)
        v2 = g.ndata[fld]
        g.ndata[fld] = v1
        g.pull(targets, fn.src_mul_edge(src=fld, edge='e2', out='m'),
               fn.sum(msg='m', out=fld), udf_apply)
        v3 = g.ndata[fld]
        g.ndata[fld] = v1
        g.pull(targets, udf_msg_edge, udf_reduce, udf_apply)
        v4 = g.ndata[fld]
        assert U.allclose(v2, v3)
        assert U.allclose(v3, v4)

    # 1d node features
    _check('f1')
    # 2d node features
    _check('f2')
def test_v2v_update_all_multi_fn():
def message_func(edges):
return {'m2': edges.src['f2']}
......@@ -311,7 +357,7 @@ def test_e2v_recv_multi_fn():
# test 2d node features
_test('f2')
def test_multi_fn_fallback():
def test_update_all_multi_fallback():
# create a graph with zero in degree nodes
g = dgl.DGLGraph()
g.add_nodes(10)
......@@ -383,12 +429,98 @@ def test_multi_fn_fallback():
assert U.allclose(o2, g.ndata.pop('o2'))
assert U.allclose(o3, g.ndata.pop('o3'))
def test_pull_multi_fallback():
    """Pull with (lists of) builtin message/reduce functions on a graph that
    contains a zero-in-degree node, exercising each executor path named in
    the comments below (v2v spmv, fallback to e2v, fallback to degree
    bucketing) and comparing against UDF-computed ground truth."""
    # create a graph with zero in degree nodes
    g = dgl.DGLGraph()
    g.add_nodes(10)
    # topology: 0 -> i and i -> 9 for i in 1..8, so node 0 has in-degree 0,
    # nodes 1..8 have in-degree 1, node 9 has in-degree 8 (16 edges total)
    for i in range(1, 9):
        g.add_edge(0, i)
        g.add_edge(i, 9)
    g.ndata['h'] = th.randn(10, D)
    g.edata['w1'] = th.randn(16,)    # scalar weight per edge
    g.edata['w2'] = th.randn(16, D)  # vector weight per edge
    # UDF equivalents of the builtin combinations used below
    def _mfunc_hxw1(edges):
        return {'m1' : edges.src['h'] * th.unsqueeze(edges.data['w1'], 1)}
    def _mfunc_hxw2(edges):
        return {'m2' : edges.src['h'] * edges.data['w2']}
    def _rfunc_m1(nodes):
        return {'o1' : th.sum(nodes.mailbox['m1'], 1)}
    def _rfunc_m2(nodes):
        return {'o2' : th.sum(nodes.mailbox['m2'], 1)}
    def _rfunc_m1max(nodes):
        return {'o3' : th.max(nodes.mailbox['m1'], 1)[0]}
    def _afunc(nodes):
        # apply: double every output field ('o*') produced by a reducer
        ret = {}
        for k, v in nodes.data.items():
            if k.startswith('o'):
                ret[k] = 2 * v
        return ret
    # nodes to pull
    def _pull_nodes(nodes):
        # compute ground truth with the UDF versions first; each pull writes
        # the field into g.ndata and pop() removes it for later comparison
        g.pull(nodes, _mfunc_hxw1, _rfunc_m1, _afunc)
        o1 = g.ndata.pop('o1')
        g.pull(nodes, _mfunc_hxw2, _rfunc_m2, _afunc)
        o2 = g.ndata.pop('o2')
        g.pull(nodes, _mfunc_hxw1, _rfunc_m1max, _afunc)
        o3 = g.ndata.pop('o3')
        # v2v spmv
        g.pull(nodes, fn.src_mul_edge(src='h', edge='w1', out='m1'),
               fn.sum(msg='m1', out='o1'),
               _afunc)
        assert U.allclose(o1, g.ndata.pop('o1'))
        # v2v fallback to e2v
        g.pull(nodes, fn.src_mul_edge(src='h', edge='w2', out='m2'),
               fn.sum(msg='m2', out='o2'),
               _afunc)
        assert U.allclose(o2, g.ndata.pop('o2'))
        # v2v fallback to degree bucketing
        g.pull(nodes, fn.src_mul_edge(src='h', edge='w1', out='m1'),
               fn.max(msg='m1', out='o3'),
               _afunc)
        assert U.allclose(o3, g.ndata.pop('o3'))
        # multi builtins, both v2v spmv (both use w1, so o1 == o2 here)
        g.pull(nodes,
               [fn.src_mul_edge(src='h', edge='w1', out='m1'), fn.src_mul_edge(src='h', edge='w1', out='m2')],
               [fn.sum(msg='m1', out='o1'), fn.sum(msg='m2', out='o2')],
               _afunc)
        assert U.allclose(o1, g.ndata.pop('o1'))
        assert U.allclose(o1, g.ndata.pop('o2'))
        # multi builtins, one v2v spmv, one fallback to e2v
        g.pull(nodes,
               [fn.src_mul_edge(src='h', edge='w1', out='m1'), fn.src_mul_edge(src='h', edge='w2', out='m2')],
               [fn.sum(msg='m1', out='o1'), fn.sum(msg='m2', out='o2')],
               _afunc)
        assert U.allclose(o1, g.ndata.pop('o1'))
        assert U.allclose(o2, g.ndata.pop('o2'))
        # multi builtins, one v2v spmv, one fallback to e2v, one fallback to degree-bucketing
        g.pull(nodes,
               [fn.src_mul_edge(src='h', edge='w1', out='m1'),
                fn.src_mul_edge(src='h', edge='w2', out='m2'),
                fn.src_mul_edge(src='h', edge='w1', out='m3')],
               [fn.sum(msg='m1', out='o1'),
                fn.sum(msg='m2', out='o2'),
                fn.max(msg='m3', out='o3')],
               _afunc)
        assert U.allclose(o1, g.ndata.pop('o1'))
        assert U.allclose(o2, g.ndata.pop('o2'))
        assert U.allclose(o3, g.ndata.pop('o3'))
    # test#1: non-0deg nodes
    nodes = [1, 2, 9]
    _pull_nodes(nodes)
    # test#2: 0deg nodes + non-0deg nodes
    nodes = [0, 1, 2, 9]
    _pull_nodes(nodes)
if __name__ == '__main__':
test_v2v_update_all()
test_v2v_snr()
test_v2v_pull()
test_v2v_update_all_multi_fn()
test_v2v_snr_multi_fn()
test_e2v_update_all_multi_fn()
test_e2v_snr_multi_fn()
test_e2v_recv_multi_fn()
test_multi_fn_fallback()
test_update_all_multi_fallback()
test_pull_multi_fallback()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment