Unverified Commit b1eeb934 authored by Minjie Wang's avatar Minjie Wang Committed by GitHub
Browse files

[Bugfix][Runtime] Zero degree behaviors (#177)

* fix recv nodes are all 0deg; fix hybriddict does not through keyerror properly

* fallback to apply_nodes when all nodes are 0deg; WIP on pull spmv 0deg

* new 0deg behavior

* new 0deg behavior

* update mx utest for pull-0deg

* fix mx

* fix mx

* get rid of unnecessary sort-n-unique
parent 4dfe7547
...@@ -852,8 +852,9 @@ def frame_like(other, num_rows): ...@@ -852,8 +852,9 @@ def frame_like(other, num_rows):
other._warn_and_set_initializer() other._warn_and_set_initializer()
newf._default_initializer = other._default_initializer newf._default_initializer = other._default_initializer
# set per-col initializer # set per-col initializer
for key in other.keys(): # TODO(minjie): hack; cannot rely on keys as the _initializers
newf.set_initializer(other.get_initializer(key), key) # now supports non-exist columns.
newf._initializers = other._initializers
return newf return newf
def merge_frames(frames, indices, max_index, reduce_func): def merge_frames(frames, indices, max_index, reduce_func):
......
...@@ -1094,7 +1094,6 @@ class DGLGraph(object): ...@@ -1094,7 +1094,6 @@ class DGLGraph(object):
""" """
if message_func == "default": if message_func == "default":
message_func = self._message_func message_func = self._message_func
assert message_func is not None
if is_all(edges): if is_all(edges):
eid = ALL eid = ALL
...@@ -1252,7 +1251,7 @@ class DGLGraph(object): ...@@ -1252,7 +1251,7 @@ class DGLGraph(object):
if len(v) == 0: if len(v) == 0:
return return
with ir.prog() as prog: with ir.prog() as prog:
scheduler.schedule_pull(graph=self, v=v, scheduler.schedule_pull(graph=self, pull_nodes=v,
message_func=message_func, reduce_func=reduce_func, message_func=message_func, reduce_func=reduce_func,
apply_func=apply_node_func) apply_func=apply_node_func)
Runtime.run(prog) Runtime.run(prog)
......
...@@ -70,18 +70,18 @@ def gen_degree_bucketing_schedule( ...@@ -70,18 +70,18 @@ def gen_degree_bucketing_schedule(
# save for merge # save for merge
idx_list.append(vb) idx_list.append(vb)
fd_list.append(fdvb) fd_list.append(fdvb)
if zero_deg_nodes is not None:
# NOTE: there must be at least one non-zero-deg node; otherwise,
# degree bucketing should not be called.
var_0deg = var.IDX(zero_deg_nodes)
zero_feat = ir.NEW_DICT(var_out, var_0deg, fd_list[0])
idx_list.append(var_0deg)
fd_list.append(zero_feat)
# merge buckets according to the ascending order of the node ids. # merge buckets according to the ascending order of the node ids.
all_idx = F.cat([idx.data.tousertensor() for idx in idx_list], dim=0) all_idx = F.cat([idx.data.tousertensor() for idx in idx_list], dim=0)
sorted_idx, order = F.sort_1d(all_idx) _, order = F.sort_1d(all_idx)
var_sorted_idx = var.IDX(utils.toindex(sorted_idx))
var_order = var.IDX(utils.toindex(order)) var_order = var.IDX(utils.toindex(order))
reduced_feat = ir.MERGE_ROW(var_order, fd_list) reduced_feat = ir.MERGE_ROW(var_order, fd_list)
if zero_deg_nodes is not None:
# If has zero degrees, scatter the result back to the frame. As
# a result, the features for zero degree nodes will be initialized
# correctly.
ir.WRITE_ROW_(var_out, var_sorted_idx, reduced_feat)
else:
ir.WRITE_DICT_(var_out, reduced_feat) ir.WRITE_DICT_(var_out, reduced_feat)
def _degree_bucketing_schedule(mids, dsts, v): def _degree_bucketing_schedule(mids, dsts, v):
......
...@@ -2,6 +2,7 @@ from __future__ import absolute_import ...@@ -2,6 +2,7 @@ from __future__ import absolute_import
from abc import abstractmethod from abc import abstractmethod
from ...base import DGLError
from ... import backend as F from ... import backend as F
from ...frame import FrameRef, Frame from ...frame import FrameRef, Frame
from ... import utils from ... import utils
...@@ -22,6 +23,7 @@ class OpCode(object): ...@@ -22,6 +23,7 @@ class OpCode(object):
READ_ROW = 6 READ_ROW = 6
MERGE_ROW = 7 MERGE_ROW = 7
UPDATE_DICT = 8 UPDATE_DICT = 8
NEW_DICT = 9
# mutable op (no return) # mutable op (no return)
# remember the name is suffixed with "_" # remember the name is suffixed with "_"
WRITE_ = 21 WRITE_ = 21
...@@ -391,6 +393,49 @@ def UPDATE_DICT(fd1, fd2, ret=None): ...@@ -391,6 +393,49 @@ def UPDATE_DICT(fd1, fd2, ret=None):
get_current_prog().issue(reg['executor_cls'](fd1, fd2, ret)) get_current_prog().issue(reg['executor_cls'](fd1, fd2, ret))
return ret return ret
class NewDictExecutor(Executor):
def __init__(self, fd_init, idx, fd_scheme, ret):
self.fd_init = fd_init # the feat dict to borrow initializer
self.idx = idx # the index to look for number or rows
self.fd_scheme = fd_scheme # the feat dict to look for column scheme
self.ret = ret # the result
def opcode(self):
return OpCode.NEW_DICT
def arg_vars(self):
return [self.fd_init, self.idx, self.fd_scheme]
def ret_var(self):
return self.ret
def run(self):
fd_init_data = self.fd_init.data
idx_data = self.idx.data
fd_scheme_data = self.fd_scheme.data
schemes = fd_scheme_data.schemes
ret_dict = {}
for key, sch in schemes.items():
initializer = fd_init_data.get_initializer(key)
ctx = F.context(fd_scheme_data[key])
shape = (len(idx_data),) + sch.shape
# FIXME: the last argument here can only be idx; range
# is meaningless. Need to rethink the signature.
ret_dict[key] = initializer(shape, sch.dtype, ctx, idx_data)
self.ret.data = FrameRef(Frame(ret_dict))
IR_REGISTRY[OpCode.NEW_DICT] = {
'name' : 'NEW_DICT',
'args_type' : [VarType.FEAT_DICT, VarType.IDX, VarType.FEAT_DICT],
'ret_type' : VarType.FEAT_DICT,
'executor_cls' : NewDictExecutor,
}
def NEW_DICT(fd_init, idx, fd_scheme, ret=None):
reg = IR_REGISTRY[OpCode.NEW_DICT]
ret = var.new(reg['ret_type']) if ret is None else ret
get_current_prog().issue(reg['executor_cls'](fd_init, idx, fd_scheme, ret))
return ret
class Write_Executor(Executor): class Write_Executor(Executor):
def __init__(self, fd, row, col, val): def __init__(self, fd, row, col, val):
self.fd = fd self.fd = fd
......
...@@ -41,6 +41,8 @@ def schedule_send(graph, u, v, eid, message_func): ...@@ -41,6 +41,8 @@ def schedule_send(graph, u, v, eid, message_func):
message_func: callable or list of callable message_func: callable or list of callable
The message function The message function
""" """
# TODO(minjie): support builtin message func
message_func = _standardize_func_usage(message_func, 'message')
# vars # vars
nf = var.FEAT_DICT(graph._node_frame) nf = var.FEAT_DICT(graph._node_frame)
ef = var.FEAT_DICT(graph._edge_frame) ef = var.FEAT_DICT(graph._edge_frame)
...@@ -66,67 +68,22 @@ def schedule_recv(graph, recv_nodes, reduce_func, apply_func): ...@@ -66,67 +68,22 @@ def schedule_recv(graph, recv_nodes, reduce_func, apply_func):
apply_func: callable apply_func: callable
The apply node function The apply node function
""" """
nf = var.FEAT_DICT(graph._node_frame, name='nf') src, dst, mid = graph._msg_graph.in_edges(recv_nodes)
if len(mid) == 0:
# All recv nodes are 0-degree nodes; downgrade to apply nodes.
if apply_func is not None:
schedule_apply_nodes(graph, recv_nodes, apply_func)
else:
var_nf = var.FEAT_DICT(graph._node_frame, name='nf')
# sort and unique the argument # sort and unique the argument
recv_nodes, _ = F.sort_1d(F.unique(recv_nodes.tousertensor())) recv_nodes, _ = F.sort_1d(F.unique(recv_nodes.tousertensor()))
recv_nodes = utils.toindex(recv_nodes) recv_nodes = utils.toindex(recv_nodes)
reduced_feat = _gen_reduce(graph, reduce_func, recv_nodes)
var_recv_nodes = var.IDX(recv_nodes, name='recv_nodes') var_recv_nodes = var.IDX(recv_nodes, name='recv_nodes')
if apply_func: # reduce
# To avoid writing reduced features back to node frame and reading reduced_feat = _gen_reduce(graph, reduce_func, (src, dst, mid), recv_nodes)
# it again for apply phase. Instead, we first read the the node # apply
# features and "merge" it with the reduced features. final_feat = _apply_with_accum(graph, var_recv_nodes, var_nf, reduced_feat, apply_func)
v_nf = ir.READ_ROW(nf, var_recv_nodes) ir.WRITE_ROW_(var_nf, var_recv_nodes, final_feat)
v_nf = ir.UPDATE_DICT(v_nf, reduced_feat)
def _afunc_wrapper(node_data):
nb = NodeBatch(graph, recv_nodes, node_data)
return apply_func(nb)
afunc = var.FUNC(_afunc_wrapper)
applied_feat = ir.NODE_UDF(afunc, v_nf)
final_feat = ir.UPDATE_DICT(reduced_feat, applied_feat)
else:
final_feat = reduced_feat
ir.WRITE_ROW_(nf, var_recv_nodes, final_feat)
def _gen_reduce(graph, reduce_func, recv_nodes):
"""
graph : DGLGraph
reduce_func : callable
recv_nodes : utils.Index
"""
call_type = "recv"
_, dst, mid = graph._msg_graph.in_edges(recv_nodes)
rfunc = _standardize_func_usage(reduce_func)
rfunc_is_list = utils.is_iterable(rfunc)
# Create a tmp frame to hold the feature data.
# The frame has the same size and schemes of the
# node frame.
# TODO(minjie): should replace this with an IR call to make the program stateless.
tmpframe = FrameRef(frame_like(graph._node_frame._frame, len(recv_nodes)))
# vars
msg = var.FEAT_DICT(graph._msg_frame, 'msg')
nf = var.FEAT_DICT(graph._node_frame, 'nf')
out = var.FEAT_DICT(data=tmpframe)
if rfunc_is_list:
# UDF message + builtin reducer
# analyze e2v spmv
spmv_rfunc, rfunc = spmv.analyze_e2v_spmv(graph, rfunc)
inc = spmv.build_inc_matrix(call_type, graph, mid, dst)
spmv.gen_e2v_spmv_schedule(inc, spmv_rfunc, msg, out)
if len(rfunc) == 0:
# All mfunc and rfunc has been processed.
return out
# convert the remaining rfunc to UDFs
rfunc = BundledFunction(rfunc)
# gen degree bucketing schedule for UDF recv
db.gen_degree_bucketing_schedule(graph, rfunc, mid, dst,
recv_nodes, nf, msg, out)
return out
def schedule_snr(graph, def schedule_snr(graph,
edge_tuples, edge_tuples,
...@@ -147,114 +104,9 @@ def schedule_snr(graph, ...@@ -147,114 +104,9 @@ def schedule_snr(graph,
reduced_feat = _gen_send_reduce(call_type, graph, reduced_feat = _gen_send_reduce(call_type, graph,
message_func, reduce_func, (var_u, var_v, var_eid), recv_nodes) message_func, reduce_func, (var_u, var_v, var_eid), recv_nodes)
# generate apply schedule # generate apply schedule
if apply_func: final_feat = _apply_with_accum(graph, var_recv_nodes, var_nf, reduced_feat, apply_func)
# To avoid writing reduced features back to node frame and reading
# it again for apply phase. Instead, we first read the the node
# features and "merge" it with the reduced features.
v_nf = ir.READ_ROW(var_nf, var_recv_nodes)
v_nf = ir.UPDATE_DICT(v_nf, reduced_feat)
def _afunc_wrapper(node_data):
nb = NodeBatch(graph, recv_nodes, node_data)
return apply_func(nb)
afunc = var.FUNC(_afunc_wrapper)
applied_feat = ir.NODE_UDF(afunc, v_nf)
final_feat = ir.UPDATE_DICT(reduced_feat, applied_feat)
else:
final_feat = reduced_feat
ir.WRITE_ROW_(var_nf, var_recv_nodes, final_feat) ir.WRITE_ROW_(var_nf, var_recv_nodes, final_feat)
def _gen_send_reduce(
call_type,
graph,
message_func,
reduce_func,
edge_tuples,
recv_nodes):
"""Generate send and reduce schedule.
This guarantees that the returned reduced features are batched
in the *unique-ascending* order of the edge destination node ids.
call_type : str
graph : DGLGraph
message_func : callable, list of builtins
reduce_func : callable, list of builtins
edge_tuples : (u, v, eid) tuple of var.Var
recv_nodes : utils.index
"""
# arg vars
var_u, var_v, var_eid = edge_tuples
var_nf = var.FEAT_DICT(graph._node_frame, name='nf')
var_ef = var.FEAT_DICT(graph._edge_frame, name='ef')
# format the input functions
mfunc = _standardize_func_usage(message_func)
rfunc = _standardize_func_usage(reduce_func)
mfunc_is_list = utils.is_iterable(mfunc)
rfunc_is_list = utils.is_iterable(rfunc)
# Create a tmp frame to hold the feature data.
# The frame has the same size and schemes of the
# node frame.
# TODO(minjie): should replace this with an IR call to make the program stateless.
tmpframe = FrameRef(frame_like(graph._node_frame._frame, len(recv_nodes)))
var_out = var.FEAT_DICT(data=tmpframe)
if mfunc_is_list and rfunc_is_list:
# builtin message + builtin reducer
# analyze v2v spmv
spmv_pairs, mfunc, rfunc = spmv.analyze_v2v_spmv(graph, mfunc, rfunc)
adj = spmv.build_adj_matrix(call_type, graph, var_u.data, var_v.data)
spmv.gen_v2v_spmv_schedule(adj, spmv_pairs, var_nf, var_ef, var_eid, var_out)
if len(mfunc) == 0:
# All mfunc and rfunc have been converted to v2v spmv.
return var_out
if mfunc_is_list:
# Two cases:
# - mfunc is builtin while rfunc is UDF.
# - mfunc and rfunc are both builtin but some combinations
# fall through from the v2v spmv analysis.
# In both cases, convert the mfunc to UDF.
mfunc = BundledFunction(mfunc)
# generate UDF send schedule
var_mf = _gen_send(graph, var_nf, var_ef, var_u, var_v, var_eid, mfunc)
if rfunc_is_list:
# UDF message + builtin reducer
# analyze e2v spmv
spmv_rfunc, rfunc = spmv.analyze_e2v_spmv(graph, rfunc)
inc = spmv.build_inc_matrix(call_type, graph, var_eid.data, var_v.data)
spmv.gen_e2v_spmv_schedule(inc, spmv_rfunc, var_mf, var_out)
if len(rfunc) == 0:
# All mfunc and rfunc has been processed.
return var_out
# convert the remaining rfunc to UDFs
rfunc = BundledFunction(rfunc)
# gen degree bucketing schedule for UDF recv
mid = utils.toindex(slice(0, len(var_v.data))) # message id is from 0~|dst|
db.gen_degree_bucketing_schedule(graph, rfunc,
mid, var_v.data, recv_nodes,
var_nf, var_mf, var_out)
return var_out
def _gen_send(graph, nf, ef, u, v, eid, mfunc):
fdsrc = ir.READ_ROW(nf, u)
fddst = ir.READ_ROW(nf, v)
fdedge = ir.READ_ROW(ef, eid)
def _mfunc_wrapper(src_data, edge_data, dst_data):
eb = EdgeBatch(graph, (u.data, v.data, eid.data),
src_data, edge_data, dst_data)
return mfunc(eb)
_mfunc_wrapper = var.FUNC(_mfunc_wrapper)
msg = ir.EDGE_UDF(_mfunc_wrapper, fdsrc, fdedge, fddst)
return msg
def schedule_update_all(graph, message_func, reduce_func, apply_func): def schedule_update_all(graph, message_func, reduce_func, apply_func):
"""get send and recv schedule """get send and recv schedule
...@@ -269,6 +121,12 @@ def schedule_update_all(graph, message_func, reduce_func, apply_func): ...@@ -269,6 +121,12 @@ def schedule_update_all(graph, message_func, reduce_func, apply_func):
apply_func: callable apply_func: callable
The apply node function The apply node function
""" """
if graph.number_of_edges() == 0:
# All the nodes are zero degree; downgrade to apply nodes
if apply_func is not None:
nodes = utils.toindex(slice(0, graph.number_of_nodes()))
schedule_apply_nodes(graph, nodes, apply_func)
else:
call_type = 'update_all' call_type = 'update_all'
src, dst, _ = graph._graph.edges() src, dst, _ = graph._graph.edges()
eid = utils.toindex(slice(0, graph.number_of_edges())) # shortcut for ALL eid = utils.toindex(slice(0, graph.number_of_edges())) # shortcut for ALL
...@@ -283,20 +141,7 @@ def schedule_update_all(graph, message_func, reduce_func, apply_func): ...@@ -283,20 +141,7 @@ def schedule_update_all(graph, message_func, reduce_func, apply_func):
reduced_feat = _gen_send_reduce(call_type, graph, reduced_feat = _gen_send_reduce(call_type, graph,
message_func, reduce_func, (var_src, var_dst, var_eid), recv_nodes) message_func, reduce_func, (var_src, var_dst, var_eid), recv_nodes)
# generate optional apply # generate optional apply
if apply_func: final_feat = _apply_with_accum(graph, var_recv_nodes, var_nf, reduced_feat, apply_func)
# To avoid writing reduced features back to node frame and reading
# it again for apply phase. Instead, we first read the the node
# features and "merge" it with the reduced features.
v_nf = ir.READ_ROW(var_nf, var_recv_nodes)
v_nf = ir.UPDATE_DICT(v_nf, reduced_feat)
def _afunc_wrapper(node_data):
nb = NodeBatch(graph, recv_nodes, node_data)
return apply_func(nb)
afunc = var.FUNC(_afunc_wrapper)
applied_feat = ir.NODE_UDF(afunc, v_nf)
final_feat = ir.UPDATE_DICT(reduced_feat, applied_feat)
else:
final_feat = reduced_feat
ir.WRITE_DICT_(var_nf, final_feat) ir.WRITE_DICT_(var_nf, final_feat)
def schedule_apply_nodes(graph, v, apply_func): def schedule_apply_nodes(graph, v, apply_func):
...@@ -378,25 +223,21 @@ def schedule_push(graph, u, message_func, reduce_func, apply_func): ...@@ -378,25 +223,21 @@ def schedule_push(graph, u, message_func, reduce_func, apply_func):
The reduce function The reduce function
apply_func: callable apply_func: callable
The apply node function The apply node function
Returns
-------
A list of executors for DGL Runtime
""" """
# FIXME: for now, use send_and_recv to implement push
u, v, eid = graph._graph.out_edges(u) u, v, eid = graph._graph.out_edges(u)
if len(eid) == 0: if len(eid) == 0:
return [] # All the pushing nodes have no out edges. No computation is scheduled.
return
schedule_snr(graph, (u, v, eid), message_func, reduce_func, apply_func) schedule_snr(graph, (u, v, eid), message_func, reduce_func, apply_func)
def schedule_pull(graph, v, message_func, reduce_func, apply_func): def schedule_pull(graph, pull_nodes, message_func, reduce_func, apply_func):
"""get pull schedule """get pull schedule
Parameters Parameters
---------- ----------
graph: DGLGraph graph: DGLGraph
The DGLGraph to use The DGLGraph to use
v : utils.Index pull_nodes : utils.Index
Destination nodes for pull Destination nodes for pull
message_func: callable or list of callable message_func: callable or list of callable
The message function The message function
...@@ -404,16 +245,31 @@ def schedule_pull(graph, v, message_func, reduce_func, apply_func): ...@@ -404,16 +245,31 @@ def schedule_pull(graph, v, message_func, reduce_func, apply_func):
The reduce function The reduce function
apply_func: callable apply_func: callable
The apply node function The apply node function
Returns
-------
A list of executors for DGL Runtime
""" """
# FIXME: for now, use send_and_recv to implement pull # TODO(minjie): `in_edges` can be omitted if message and reduce func pairs
u, v, eid = graph._graph.in_edges(v) # can be specialized to SPMV. This needs support for creating adjmat
# directly from dst node frontier.
u, v, eid = graph._graph.in_edges(pull_nodes)
if len(eid) == 0: if len(eid) == 0:
return [] # All the nodes are 0deg; downgrades to apply.
schedule_snr(graph, (u, v, eid), message_func, reduce_func, apply_func) if apply_func is not None:
schedule_apply_nodes(graph, pull_nodes, apply_func)
else:
call_type = 'send_and_recv'
pull_nodes, _ = F.sort_1d(F.unique(pull_nodes.tousertensor()))
pull_nodes = utils.toindex(pull_nodes)
# create vars
var_nf = var.FEAT_DICT(graph._node_frame, name='nf')
var_pull_nodes = var.IDX(pull_nodes, name='pull_nodes')
var_u = var.IDX(u)
var_v = var.IDX(v)
var_eid = var.IDX(eid)
# generate send and reduce schedule
reduced_feat = _gen_send_reduce(call_type, graph,
message_func, reduce_func, (var_u, var_v, var_eid), pull_nodes)
# generate optional apply
final_feat = _apply_with_accum(graph, var_pull_nodes, var_nf, reduced_feat, apply_func)
ir.WRITE_ROW_(var_nf, var_pull_nodes, final_feat)
def _check_builtin_func_list(func_list): def _check_builtin_func_list(func_list):
"""Check whether func_list only contains builtin functions.""" """Check whether func_list only contains builtin functions."""
...@@ -422,7 +278,7 @@ def _check_builtin_func_list(func_list): ...@@ -422,7 +278,7 @@ def _check_builtin_func_list(func_list):
raise DGLError("If specify multiple message/reduce functions, \ raise DGLError("If specify multiple message/reduce functions, \
all of them must be builtin") all of them must be builtin")
def _standardize_func_usage(func): def _standardize_func_usage(func, func_name):
"""Standardize usages of message and reduce functions """Standardize usages of message and reduce functions
Message or reduce funtion can be: Message or reduce funtion can be:
1. a UDF 1. a UDF
...@@ -437,14 +293,183 @@ def _standardize_func_usage(func): ...@@ -437,14 +293,183 @@ def _standardize_func_usage(func):
""" """
if utils.is_iterable(func): if utils.is_iterable(func):
# rfunc is a list of builtin # func is a list of builtin
_check_builtin_func_list(func) _check_builtin_func_list(func)
return func return func
elif isinstance(func, BuiltinFunction): elif isinstance(func, BuiltinFunction):
# func is one builtin-in # func is one builtin-in
return [func] return [func]
else: else:
# rfunc is one UDF # func is one UDF
if not callable(func):
raise DGLError('User-defined %s function must be callable.'
' Got: %s' % (func_name, str(func)))
return func return func
def _apply_with_accum(graph, var_nodes, var_nf, var_accum, apply_func):
"""Apply with accumulated features.
Paramters
---------
var_nodes : var.IDX
The nodes.
var_nf : var.FEAT_DICT
The node features.
var_accum : var.FEAT_DICT
The accumulated features.
apply_func : callable, None
The apply function.
"""
if apply_func:
# To avoid writing reduced features back to node frame and reading
# it again for apply phase. Instead, we first read the the node
# features and "merge" it with the reduced features.
v_nf = ir.READ_ROW(var_nf, var_nodes)
v_nf = ir.UPDATE_DICT(v_nf, var_accum)
def _afunc_wrapper(node_data):
nb = NodeBatch(graph, var_nodes.data, node_data)
return apply_func(nb)
afunc = var.FUNC(_afunc_wrapper)
applied_feat = ir.NODE_UDF(afunc, v_nf)
final_feat = ir.UPDATE_DICT(var_accum, applied_feat)
else:
final_feat = var_accum
return final_feat
def _gen_reduce(graph, reduce_func, edge_tuples, recv_nodes):
"""
graph : DGLGraph
reduce_func : callable
edge_tuples : tuple of utils.Index
recv_nodes : utils.Index
"""
call_type = "recv"
_, dst, mid = edge_tuples
rfunc = _standardize_func_usage(reduce_func, 'reduce')
rfunc_is_list = utils.is_iterable(rfunc)
# Create a tmp frame to hold the feature data.
# The frame has the same size and schemes of the
# node frame.
# TODO(minjie): should replace this with an IR call to make the program stateless.
tmpframe = FrameRef(frame_like(graph._node_frame._frame, len(recv_nodes)))
# vars
msg = var.FEAT_DICT(graph._msg_frame, 'msg')
nf = var.FEAT_DICT(graph._node_frame, 'nf')
out = var.FEAT_DICT(data=tmpframe)
if rfunc_is_list:
# UDF message + builtin reducer
# analyze e2v spmv
spmv_rfunc, rfunc = spmv.analyze_e2v_spmv(graph, rfunc)
# FIXME: refactor this when fixing the multi-recv bug
mat = spmv.build_incmat_by_eid(graph._msg_frame.num_rows, mid, dst, recv_nodes)
inc = utils.CtxCachedObject(lambda ctx : F.copy_to(mat, ctx))
spmv.gen_e2v_spmv_schedule(inc, spmv_rfunc, msg, out)
if len(rfunc) == 0:
# All mfunc and rfunc has been processed.
return out
# convert the remaining rfunc to UDFs
rfunc = BundledFunction(rfunc)
# gen degree bucketing schedule for UDF recv
db.gen_degree_bucketing_schedule(graph, rfunc, mid, dst,
recv_nodes, nf, msg, out)
return out
def _gen_send_reduce(
call_type,
graph,
message_func,
reduce_func,
edge_tuples,
recv_nodes):
"""Generate send and reduce schedule.
This guarantees that the returned reduced features are batched
in the *unique-ascending* order of the edge destination node ids.
call_type : str
graph : DGLGraph
message_func : callable, list of builtins
reduce_func : callable, list of builtins
edge_tuples : (u, v, eid) tuple of var.Var
recv_nodes : utils.index
"""
# arg vars
var_u, var_v, var_eid = edge_tuples
var_nf = var.FEAT_DICT(graph._node_frame, name='nf')
var_ef = var.FEAT_DICT(graph._edge_frame, name='ef')
# format the input functions
mfunc = _standardize_func_usage(message_func, 'message')
rfunc = _standardize_func_usage(reduce_func, 'reduce')
mfunc_is_list = utils.is_iterable(mfunc)
rfunc_is_list = utils.is_iterable(rfunc)
# Create a tmp frame to hold the feature data.
# The frame has the same size and schemes of the
# node frame.
# TODO(minjie): should replace this with an IR call to make the program stateless.
tmpframe = FrameRef(frame_like(graph._node_frame._frame, len(recv_nodes)))
var_out = var.FEAT_DICT(data=tmpframe)
if mfunc_is_list and rfunc_is_list:
# builtin message + builtin reducer
# analyze v2v spmv
spmv_pairs, mfunc, rfunc = spmv.analyze_v2v_spmv(graph, mfunc, rfunc)
adj = spmv.build_adj_matrix(call_type, graph,
(var_u.data, var_v.data), recv_nodes)
spmv.gen_v2v_spmv_schedule(adj, spmv_pairs, var_nf, var_ef, var_eid, var_out)
if len(mfunc) == 0:
# All mfunc and rfunc have been converted to v2v spmv.
return var_out
if mfunc_is_list:
# Two cases:
# - mfunc is builtin while rfunc is UDF.
# - mfunc and rfunc are both builtin but some combinations
# fall through from the v2v spmv analysis.
# In both cases, convert the mfunc to UDF.
mfunc = BundledFunction(mfunc)
# generate UDF send schedule
var_mf = _gen_send(graph, var_nf, var_ef, var_u, var_v, var_eid, mfunc)
if rfunc_is_list:
# UDF message + builtin reducer
# analyze e2v spmv
spmv_rfunc, rfunc = spmv.analyze_e2v_spmv(graph, rfunc)
inc = spmv.build_inc_matrix(call_type, graph, var_v.data, recv_nodes)
spmv.gen_e2v_spmv_schedule(inc, spmv_rfunc, var_mf, var_out)
if len(rfunc) == 0:
# All mfunc and rfunc has been processed.
return var_out
# convert the remaining rfunc to UDFs
rfunc = BundledFunction(rfunc)
# gen degree bucketing schedule for UDF recv
mid = utils.toindex(slice(0, len(var_v.data))) # message id is from 0~|dst|
db.gen_degree_bucketing_schedule(graph, rfunc,
mid, var_v.data, recv_nodes,
var_nf, var_mf, var_out)
return var_out
def _gen_send(graph, nf, ef, u, v, eid, mfunc):
fdsrc = ir.READ_ROW(nf, u)
fddst = ir.READ_ROW(nf, v)
fdedge = ir.READ_ROW(ef, eid)
def _mfunc_wrapper(src_data, edge_data, dst_data):
eb = EdgeBatch(graph, (u.data, v.data, eid.data),
src_data, edge_data, dst_data)
return mfunc(eb)
_mfunc_wrapper = var.FUNC(_mfunc_wrapper)
msg = ir.EDGE_UDF(_mfunc_wrapper, fdsrc, fdedge, fddst)
return msg
_init_api("dgl.runtime.scheduler") _init_api("dgl.runtime.scheduler")
...@@ -43,7 +43,7 @@ def analyze_v2v_spmv(graph, mfunc, rfunc): ...@@ -43,7 +43,7 @@ def analyze_v2v_spmv(graph, mfunc, rfunc):
raise DGLError('Reduce function requires message field "%s",' raise DGLError('Reduce function requires message field "%s",'
' but no message function generates it.' % mfld) ' but no message function generates it.' % mfld)
mfn = fld2mfunc[mfld] mfn = fld2mfunc[mfld]
# FIXME: should pre-compile a look up table # TODO(minjie): should pre-compile a look up table
if mfn.is_spmv_supported(graph) and rfn.is_spmv_supported(): if mfn.is_spmv_supported(graph) and rfn.is_spmv_supported():
spmv_pairs.append((mfn, rfn)) spmv_pairs.append((mfn, rfn))
else: else:
...@@ -122,27 +122,40 @@ def gen_e2v_spmv_schedule(inc, spmv_rfunc, mf, out): ...@@ -122,27 +122,40 @@ def gen_e2v_spmv_schedule(inc, spmv_rfunc, mf, out):
ftdst = ir.SPMV(inc_var, ftmsg) ftdst = ir.SPMV(inc_var, ftmsg)
ir.WRITE_COL_(out, var.STR(rfn.out_field), ftdst) ir.WRITE_COL_(out, var.STR(rfn.out_field), ftdst)
def build_adj_matrix(call_type, graph, u, v): def build_adj_matrix(call_type, graph, edges, reduce_nodes):
""" """Build adjacency matrix.
Parameters
----------
call_type : str call_type : str
Can be 'update_all', 'send_and_recv'
graph : DGLGraph graph : DGLGraph
u : utils.Index The graph
v : utils.Index edges : tuple of utils.Index
(u, v)
reduce_nodes : utils.Index
The nodes to reduce messages, which will be target dimension
of the adjmat. The nodes include unique(v) and zero-degree-nodes.
Returns
-------
utils.CtxCachedObject
Get be used to get adjacency matrix on the provided ctx.
""" """
if call_type == "update_all": if call_type == "update_all":
# full graph case # full graph case
return utils.CtxCachedObject(lambda ctx : graph.adjacency_matrix(ctx=ctx)) return utils.CtxCachedObject(lambda ctx : graph.adjacency_matrix(ctx=ctx))
elif call_type == "send_and_recv": elif call_type == "send_and_recv":
# edgeset case # edgeset case
mat = build_adj_matrix_uv(graph, u, v) mat = build_adj_matrix_uv(graph, edges, reduce_nodes)
return utils.CtxCachedObject(lambda ctx : F.copy_to(mat, ctx)) return utils.CtxCachedObject(lambda ctx : F.copy_to(mat, ctx))
else: else:
raise DGLError('Invalid call type:', call_type) raise DGLError('Invalid call type:', call_type)
def build_adj_matrix_index_uv(graph, u, v): def build_adj_matrix_index_uv(graph, edges, reduce_nodes):
"""Build adj matrix index and shape using the given (u, v) edges. """Build adj matrix index and shape using the given (u, v) edges.
The matrix is of shape (len(unique(v)), n), where n is the number of nodes The matrix is of shape (len(reduce_nodes), n), where n is the number of nodes
in the graph. Therefore, when doing SPMV, the src node data in the graph. Therefore, when doing SPMV, the src node data
should be all the node features. should be all the node features.
...@@ -154,10 +167,11 @@ def build_adj_matrix_index_uv(graph, u, v): ...@@ -154,10 +167,11 @@ def build_adj_matrix_index_uv(graph, u, v):
--------- ---------
graph : DGLGraph graph : DGLGraph
The graph The graph
u : utils.Index edges : tuple of utils.Index
Src nodes. (u, v)
v : utils.Index reduce_nodes : utils.Index
Dst nodes. The nodes to reduce messages, which will be target dimension
of the adjmat. The nodes include unique(v) and zero-degree-nodes.
Returns Returns
------- -------
...@@ -166,21 +180,23 @@ def build_adj_matrix_index_uv(graph, u, v): ...@@ -166,21 +180,23 @@ def build_adj_matrix_index_uv(graph, u, v):
tupe of int tupe of int
The dense shape. The dense shape.
""" """
new2old, old2new = utils.build_relabel_map(v) # TODO(minjie): add node frontier for this
new2old, old2new = utils.build_relabel_map(reduce_nodes, sorted=True)
u, v = edges
u = u.tousertensor() u = u.tousertensor()
v = v.tousertensor() v = v.tousertensor()
new_v = old2new[v] # FIXME(minjie): no use [] new_v = old2new[v] # FIXME(minjie): no use []
n = graph.number_of_nodes() n = graph.number_of_nodes()
m = len(new2old) m = len(reduce_nodes)
row = F.unsqueeze(new_v, 0) row = F.unsqueeze(new_v, 0)
col = F.unsqueeze(u, 0) col = F.unsqueeze(u, 0)
idx = F.cat([row, col], dim=0) idx = F.cat([row, col], dim=0)
return ('coo', idx), (m, n) return ('coo', idx), (m, n)
def build_adj_matrix_uv(graph, u, v): def build_adj_matrix_uv(graph, edges, reduce_nodes):
"""Build adj matrix using the given (u, v) edges. """Build adj matrix using the given (u, v) edges and target nodes.
The matrix is of shape (len(v), n), where n is the number of nodes The matrix is of shape (len(reduce_nodes), n), where n is the number of nodes
in the graph. Therefore, when doing SPMV, the src node data in the graph. Therefore, when doing SPMV, the src node data
should be all the node features. should be all the node features.
...@@ -188,64 +204,109 @@ def build_adj_matrix_uv(graph, u, v): ...@@ -188,64 +204,109 @@ def build_adj_matrix_uv(graph, u, v):
--------- ---------
graph : DGLGraph graph : DGLGraph
The graph The graph
u : utils.Index edges : tuple of utils.Index
Src nodes. (u, v)
v : utils.Index reduce_nodes : utils.Index
Dst nodes. The nodes to reduce messages, which will be target dimension
of the adjmat. The nodes include unique(v) and zero-degree-nodes.
Returns Returns
------- -------
Sparse matrix Sparse matrix
The adjacency matrix on CPU The adjacency matrix on CPU
""" """
sp_idx, shape = build_adj_matrix_index_uv(graph, u, v) sp_idx, shape = build_adj_matrix_index_uv(graph, edges, reduce_nodes)
u, v = edges
nnz = len(u) nnz = len(u)
# FIXME(minjie): data type # FIXME(minjie): data type
dat = F.ones((nnz,), dtype=F.float32, ctx=F.cpu()) dat = F.ones((nnz,), dtype=F.float32, ctx=F.cpu())
mat = F.sparse_matrix(dat, sp_idx, shape) mat = F.sparse_matrix(dat, sp_idx, shape)
return mat return mat
def build_inc_matrix(call_type, graph, eid, v): def build_inc_matrix(call_type, graph, dst, reduce_nodes):
""" """Build incidence matrix.
Parameters
----------
call_type : str call_type : str
Can be 'update_all', 'send_and_recv'.
graph : DGLGraph graph : DGLGraph
eid : utils.Index The graph.
v : utils.Index dst : utils.Index
The destination nodes of the edges.
reduce_nodes : utils.Index
The nodes to reduce messages, which will be target dimension
of the incmat. The nodes include unique(dst) and zero-degree-nodes.
Returns
-------
utils.CtxCachedObject
Get be used to get incidence matrix on the provided ctx.
""" """
if call_type == "update_all": if call_type == "update_all":
# full graph case # full graph case
return utils.CtxCachedObject(lambda ctx : graph.incidence_matrix(type='in', ctx=ctx)) return utils.CtxCachedObject(lambda ctx : graph.incidence_matrix(type='in', ctx=ctx))
elif call_type == "send_and_recv": elif call_type == "send_and_recv":
# edgeset case # edgeset case
mat = build_inc_matrix_v(v) mat = build_incmat_by_dst(dst, reduce_nodes)
return utils.CtxCachedObject(lambda ctx : F.copy_to(mat, ctx))
elif call_type == "recv":
# dst nodeset case
mat = build_inc_matrix_eid(eid, v)
return utils.CtxCachedObject(lambda ctx : F.copy_to(mat, ctx)) return utils.CtxCachedObject(lambda ctx : F.copy_to(mat, ctx))
else: else:
raise DGLError('Invalid call type:', call_type) raise DGLError('Invalid call type:', call_type)
def build_inc_matrix_eid(eid, v): def build_incmat_by_eid(m, eid, dst, reduce_nodes):
"""A spmat of shape (n, m), where n=len(unique(v)), m=len(eid). """Build incidence matrix using edge id and edge dst nodes.
The incidence matrix is of shape (n, m), where n=len(reduce_nodes).
The nnz is equal to len(eid).
Invariant: len(eid) == len(v) Invariant: len(eid) == len(dst)
The dst nodes will be sorted in the *unique-ascending* order of The dst nodes will be sorted in the *unique-ascending* order of
their ids. This is compatible with other reduce scheduler such as their ids. This is compatible with other reduce scheduler such as
degree-bucketing scheduler. degree-bucketing scheduler.
Examples
--------
Total of seven edges. Three edges point to node 1 (eid=0,1,2);
two point to node 3 (eid=3,4); two point to node 4 (eid=5,6).
Only five edges should be included in the result incmat (eid=1,2,3,5,6).
There are five nodes in the final target dimension (0~4),
where node 0 and 2 are two 0-deg nodes.
>>> m = 7
>>> eid = [1, 2, 3, 5, 6]
>>> dst = [1, 1, 3, 4, 4]
>>> reduce_nodes = [0, 1, 2, 3, 4]
>>> build_incmat_by_eid(m, eid, dst, reduce_nodes)
tensor([[0, 0, 0, 0, 0, 0, 0],
[0, 1, 1, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 1, 0, 0, 0],
[0, 0, 0, 0, 0, 1, 1]], shape=(5, 7))
Paramters
---------
m : int
The source dimension size of the incidence matrix.
eid : utils.Index eid : utils.Index
v : utils.Index The edge ids. All ids must be in range [0, m).
dst : utils.Index
The edge destination nodes. len(eid) == len(dst).
reduce_nodes : utils.Index
The nodes to reduce messages, which will be target dimension
of the incmat. The nodes include unique(dst) and zero-degree-nodes.
Returns
-------
Sparse matrix
The incidence matrix on CPU
""" """
# relabel v to range(0, len(unique(v))) new2old, old2new = utils.build_relabel_map(reduce_nodes, sorted=True)
new2old, old2new = utils.build_relabel_map(v) dst = dst.tousertensor()
v = v.tousertensor()
eid = eid.tousertensor() eid = eid.tousertensor()
new_v = old2new[v] # FIXME(minjie): no use [] # relabel edges dsts
new_v = old2new[dst] # FIXME(minjie): no use []
# create sparse index tensor # create sparse index tensor
m = len(eid) n = len(reduce_nodes)
n = len(new2old)
row = F.unsqueeze(new_v, 0) row = F.unsqueeze(new_v, 0)
col = F.unsqueeze(eid, 0) col = F.unsqueeze(eid, 0)
idx = F.cat([row, col], dim=0) idx = F.cat([row, col], dim=0)
...@@ -254,10 +315,38 @@ def build_inc_matrix_eid(eid, v): ...@@ -254,10 +315,38 @@ def build_inc_matrix_eid(eid, v):
dat = F.ones((nnz,), dtype=F.float32, ctx=F.cpu()) dat = F.ones((nnz,), dtype=F.float32, ctx=F.cpu())
return F.sparse_matrix(dat, ('coo', idx), (n, m)) return F.sparse_matrix(dat, ('coo', idx), (n, m))
def build_inc_matrix_v(v): def build_incmat_by_dst(dst, reduce_nodes):
"""A spmat of shape (n, m), where n=len(unique(v)), m=len(v). """Build incidence matrix using only edge destinations.
v : utils.Index The incidence matrix is of shape (n, m), where n=len(reduce_nodes), m=len(dst).
The nnz is equal to len(dst).
Examples
--------
Five edges. Two edges point to node 1; one points to node 3;
two point to node 4. There are five nodes in the final
target dimension (0~4), where node 0 and 2 are two 0-deg nodes.
>>> dst = [1, 1, 3, 4, 4]
>>> reduce_nodes = [0, 1, 2, 3, 4]
>>> build_incmat_by_dst(dst, reduced_nodes)
tensor([[0, 0, 0, 0, 0],
[1, 1, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 1, 0, 0],
[0, 0, 0, 1, 1]], shape=(5, 5))
Parameters
----------
dst : utils.Index
The edge destinations.
reduce_nodes : utils.Index
The nodes to reduce messages, which will be target dimension
of the incmat. The nodes include unique(dst) and zero-degree-nodes.
Returns
-------
Sparse matrix
The incidence matrix on CPU
""" """
eid = utils.toindex(F.arange(0, len(v))) eid = utils.toindex(F.arange(0, len(dst)))
return build_inc_matrix_eid(eid, v) return build_incmat_by_eid(len(eid), eid, dst, reduce_nodes)
...@@ -168,6 +168,7 @@ class HybridDict(Mapping): ...@@ -168,6 +168,7 @@ class HybridDict(Mapping):
for d in self._dict_like_list: for d in self._dict_like_list:
if key in d: if key in d:
return d[key] return d[key]
raise KeyError(key)
def __contains__(self, key): def __contains__(self, key):
return key in self.keys() return key in self.keys()
...@@ -198,7 +199,7 @@ class ReadOnlyDict(Mapping): ...@@ -198,7 +199,7 @@ class ReadOnlyDict(Mapping):
def __len__(self): def __len__(self):
return len(self._dict_like) return len(self._dict_like)
def build_relabel_map(x): def build_relabel_map(x, sorted=False):
"""Relabel the input ids to continuous ids that starts from zero. """Relabel the input ids to continuous ids that starts from zero.
Ids are assigned new ids according to their ascending order. Ids are assigned new ids according to their ascending order.
...@@ -218,6 +219,8 @@ def build_relabel_map(x): ...@@ -218,6 +219,8 @@ def build_relabel_map(x):
---------- ----------
x : Index x : Index
The input ids. The input ids.
sorted : bool, default=False
Whether the input has already been unique and sorted.
Returns Returns
------- -------
...@@ -229,7 +232,10 @@ def build_relabel_map(x): ...@@ -229,7 +232,10 @@ def build_relabel_map(x):
new id tensor: new_id = old_to_new[old_id] new id tensor: new_id = old_to_new[old_id]
""" """
x = x.tousertensor() x = x.tousertensor()
if not sorted:
unique_x, _ = F.sort_1d(F.unique(x)) unique_x, _ = F.sort_1d(F.unique(x))
else:
unique_x = x
map_len = int(F.max(unique_x, dim=0)) + 1 map_len = int(F.max(unique_x, dim=0)) + 1
old_to_new = F.zeros(map_len, dtype=F.int64, ctx=F.cpu()) old_to_new = F.zeros(map_len, dtype=F.int64, ctx=F.cpu())
F.scatter_row_inplace(old_to_new, unique_x, F.arange(0, len(unique_x))) F.scatter_row_inplace(old_to_new, unique_x, F.arange(0, len(unique_x)))
......
...@@ -282,27 +282,40 @@ def check_pull_0deg(readonly): ...@@ -282,27 +282,40 @@ def check_pull_0deg(readonly):
return {'m' : edges.src['h']} return {'m' : edges.src['h']}
def _reduce(nodes): def _reduce(nodes):
return {'h' : nodes.mailbox['m'].sum(1)} return {'h' : nodes.mailbox['m'].sum(1)}
def _apply(nodes):
return {'h' : nodes.data['h'] * 2}
def _init2(shape, dtype, ctx, ids):
return 2 + mx.nd.zeros(shape, dtype=dtype, ctx=ctx)
g.set_n_initializer(_init2, 'h')
old_repr = mx.nd.random.normal(shape=(2, 5)) old_repr = mx.nd.random.normal(shape=(2, 5))
g.set_n_repr({'h' : old_repr})
g.pull(0, _message, _reduce) # test#1: pull only 0-deg node
g.ndata['h'] = old_repr
g.pull(0, _message, _reduce, _apply)
new_repr = g.ndata['h'] new_repr = g.ndata['h']
# TODO(minjie): this is not the intended behavior. Pull node#0 # 0deg check: equal to apply_nodes
# should reset node#0 to the initial value. The bug is because assert np.allclose(new_repr[0].asnumpy(), old_repr[0].asnumpy() * 2)
# current pull is implemented using send_and_recv. Since there # non-0deg check: untouched
# is no edge to node#0 so the send_and_recv is skipped. Fix this
# behavior when optimizing the pull scheduler.
assert np.allclose(new_repr[0].asnumpy(), old_repr[0].asnumpy())
assert np.allclose(new_repr[1].asnumpy(), old_repr[1].asnumpy()) assert np.allclose(new_repr[1].asnumpy(), old_repr[1].asnumpy())
g.pull(1, _message, _reduce)
new_repr = g.ndata['h']
assert np.allclose(new_repr[1].asnumpy(), old_repr[0].asnumpy())
old_repr = mx.nd.random.normal(shape=(2, 5)) # test#2: pull only non-deg node
g.set_n_repr({'h' : old_repr}) g.ndata['h'] = old_repr
g.pull([0, 1], _message, _reduce) g.pull(1, _message, _reduce, _apply)
new_repr = g.ndata['h'] new_repr = g.ndata['h']
# 0deg check: untouched
assert np.allclose(new_repr[0].asnumpy(), old_repr[0].asnumpy()) assert np.allclose(new_repr[0].asnumpy(), old_repr[0].asnumpy())
assert np.allclose(new_repr[1].asnumpy(), old_repr[0].asnumpy()) # non-0deg check: recved node0 and got applied
assert np.allclose(new_repr[1].asnumpy(), old_repr[0].asnumpy() * 2)
# test#3: pull only both nodes
g.ndata['h'] = old_repr
g.pull([0, 1], _message, _reduce, _apply)
new_repr = g.ndata['h']
# 0deg check: init and applied
t = mx.nd.zeros(shape=(2,5)) + 4
assert np.allclose(new_repr[0].asnumpy(), t.asnumpy())
# non-0deg check: recv node0 and applied
assert np.allclose(new_repr[1].asnumpy(), old_repr[0].asnumpy() * 2)
def test_pull_0deg(): def test_pull_0deg():
check_pull_0deg(True) check_pull_0deg(True)
......
...@@ -235,7 +235,87 @@ def test_update_routines(): ...@@ -235,7 +235,87 @@ def test_update_routines():
assert(reduce_msg_shapes == {(1, 8, D), (9, 1, D)}) assert(reduce_msg_shapes == {(1, 8, D), (9, 1, D)})
reduce_msg_shapes.clear() reduce_msg_shapes.clear()
def test_reduce_0deg(): def test_recv_0deg():
# test recv with 0deg nodes;
g = DGLGraph()
g.add_nodes(2)
g.add_edge(0, 1)
def _message(edges):
return {'m' : edges.src['h']}
def _reduce(nodes):
return {'h' : nodes.data['h'] + nodes.mailbox['m'].sum(1)}
def _apply(nodes):
return {'h' : nodes.data['h'] * 2}
def _init2(shape, dtype, ctx, ids):
return 2 + th.zeros(shape, dtype=dtype, device=ctx)
g.register_message_func(_message)
g.register_reduce_func(_reduce)
g.register_apply_node_func(_apply)
g.set_n_initializer(_init2, 'h')
# test#1: recv both 0deg and non-0deg nodes
old = th.randn((2, 5))
g.ndata['h'] = old
g.send((0, 1))
g.recv([0, 1])
new = g.ndata.pop('h')
# 0deg check: initialized with the func and got applied
assert U.allclose(new[0], th.full((5,), 4))
# non-0deg check
assert U.allclose(new[1], th.sum(old, 0) * 2)
# test#2: recv only 0deg node is equal to apply
old = th.randn((2, 5))
g.ndata['h'] = old
g.send((0, 1))
g.recv(0)
new = g.ndata.pop('h')
# 0deg check: equal to apply_nodes
assert U.allclose(new[0], 2 * old[0])
# non-0deg check: untouched
assert U.allclose(new[1], old[1])
def test_recv_0deg_newfld():
# test recv with 0deg nodes; the reducer also creates a new field
g = DGLGraph()
g.add_nodes(2)
g.add_edge(0, 1)
def _message(edges):
return {'m' : edges.src['h']}
def _reduce(nodes):
return {'h1' : nodes.data['h'] + nodes.mailbox['m'].sum(1)}
def _apply(nodes):
return {'h1' : nodes.data['h1'] * 2}
def _init2(shape, dtype, ctx, ids):
return 2 + th.zeros(shape, dtype=dtype, device=ctx)
g.register_message_func(_message)
g.register_reduce_func(_reduce)
g.register_apply_node_func(_apply)
# test#1: recv both 0deg and non-0deg nodes
old = th.randn((2, 5))
g.set_n_initializer(_init2, 'h1')
g.ndata['h'] = old
g.send((0, 1))
g.recv([0, 1])
new = g.ndata.pop('h1')
# 0deg check: initialized with the func and got applied
assert U.allclose(new[0], th.full((5,), 4))
# non-0deg check
assert U.allclose(new[1], th.sum(old, 0) * 2)
# test#2: recv only 0deg node
old = th.randn((2, 5))
g.ndata['h'] = old
g.ndata['h1'] = th.full((2, 5), -1) # this is necessary
g.send((0, 1))
g.recv(0)
new = g.ndata.pop('h1')
# 0deg check: fallback to apply
assert U.allclose(new[0], th.full((5,), -2))
# non-0deg check: not changed
assert U.allclose(new[1], th.full((5,), -1))
def test_update_all_0deg():
# test#1
g = DGLGraph() g = DGLGraph()
g.add_nodes(5) g.add_nodes(5)
g.add_edge(1, 0) g.add_edge(1, 0)
...@@ -246,18 +326,30 @@ def test_reduce_0deg(): ...@@ -246,18 +326,30 @@ def test_reduce_0deg():
return {'m' : edges.src['h']} return {'m' : edges.src['h']}
def _reduce(nodes): def _reduce(nodes):
return {'h' : nodes.data['h'] + nodes.mailbox['m'].sum(1)} return {'h' : nodes.data['h'] + nodes.mailbox['m'].sum(1)}
def _apply(nodes):
return {'h' : nodes.data['h'] * 2}
def _init2(shape, dtype, ctx, ids): def _init2(shape, dtype, ctx, ids):
return 2 + th.zeros(shape, dtype=dtype, device=ctx) return 2 + th.zeros(shape, dtype=dtype, device=ctx)
g.set_n_initializer(_init2, 'h') g.set_n_initializer(_init2, 'h')
old_repr = th.randn(5, 5) old_repr = th.randn(5, 5)
g.ndata['h'] = old_repr g.ndata['h'] = old_repr
g.update_all(_message, _reduce) g.update_all(_message, _reduce, _apply)
new_repr = g.ndata['h'] new_repr = g.ndata['h']
# the first row of the new_repr should be the sum of all the node # the first row of the new_repr should be the sum of all the node
# features; while the 0-deg nodes should be initialized by the # features; while the 0-deg nodes should be initialized by the
# initializer. # initializer and applied with UDF.
assert U.allclose(new_repr[1:], 2+th.zeros((4,5))) assert U.allclose(new_repr[1:], 2*(2+th.zeros((4,5))))
assert U.allclose(new_repr[0], old_repr.sum(0)) assert U.allclose(new_repr[0], 2 * old_repr.sum(0))
# test#2: graph with no edge
g = DGLGraph()
g.add_nodes(5)
g.set_n_initializer(_init2, 'h')
g.ndata['h'] = old_repr
g.update_all(_message, _reduce, _apply)
new_repr = g.ndata['h']
# should fallback to apply
assert U.allclose(new_repr, 2*old_repr)
def test_pull_0deg(): def test_pull_0deg():
g = DGLGraph() g = DGLGraph()
...@@ -266,25 +358,34 @@ def test_pull_0deg(): ...@@ -266,25 +358,34 @@ def test_pull_0deg():
def _message(edges): def _message(edges):
return {'m' : edges.src['h']} return {'m' : edges.src['h']}
def _reduce(nodes): def _reduce(nodes):
return {'h' : nodes.mailbox['m'].sum(1)} return {'h' : nodes.data['h'] + nodes.mailbox['m'].sum(1)}
old_repr = th.randn(2, 5) def _apply(nodes):
g.ndata['h'] = old_repr return {'h' : nodes.data['h'] * 2}
def _init2(shape, dtype, ctx, ids):
g.pull(0, _message, _reduce) return 2 + th.zeros(shape, dtype=dtype, device=ctx)
new_repr = g.ndata['h'] g.register_message_func(_message)
assert U.allclose(new_repr[0], old_repr[0]) g.register_reduce_func(_reduce)
assert U.allclose(new_repr[1], old_repr[1]) g.register_apply_node_func(_apply)
g.set_n_initializer(_init2, 'h')
g.pull(1, _message, _reduce) # test#1: pull both 0deg and non-0deg nodes
new_repr = g.ndata['h'] old = th.randn((2, 5))
assert U.allclose(new_repr[1], old_repr[0]) g.ndata['h'] = old
g.pull([0, 1])
old_repr = th.randn(2, 5) new = g.ndata.pop('h')
g.ndata['h'] = old_repr # 0deg check: initialized with the func and got applied
g.pull([0, 1], _message, _reduce) assert U.allclose(new[0], th.full((5,), 4))
new_repr = g.ndata['h'] # non-0deg check
assert U.allclose(new_repr[0], old_repr[0]) assert U.allclose(new[1], th.sum(old, 0) * 2)
assert U.allclose(new_repr[1], old_repr[0])
# test#2: pull only 0deg node
old = th.randn((2, 5))
g.ndata['h'] = old
g.pull(0)
new = g.ndata.pop('h')
# 0deg check: fallback to apply
assert U.allclose(new[0], 2*old[0])
# non-0deg check: not touched
assert U.allclose(new[1], old[1])
def _disabled_test_send_twice(): def _disabled_test_send_twice():
# TODO(minjie): please re-enable this unittest after the send code problem is fixed. # TODO(minjie): please re-enable this unittest after the send code problem is fixed.
...@@ -419,7 +520,9 @@ if __name__ == '__main__': ...@@ -419,7 +520,9 @@ if __name__ == '__main__':
test_apply_nodes() test_apply_nodes()
test_apply_edges() test_apply_edges()
test_update_routines() test_update_routines()
test_reduce_0deg() test_recv_0deg()
test_recv_0deg_newfld()
test_update_all_0deg()
test_pull_0deg() test_pull_0deg()
test_send_multigraph() test_send_multigraph()
test_dynamic_addition() test_dynamic_addition()
...@@ -110,6 +110,52 @@ def test_v2v_snr(): ...@@ -110,6 +110,52 @@ def test_v2v_snr():
# test 2d node features # test 2d node features
_test('f2') _test('f2')
def test_v2v_pull():
nodes = th.tensor([1, 2, 3, 9])
def _test(fld):
def message_func(edges):
return {'m' : edges.src[fld]}
def message_func_edge(edges):
if len(edges.src[fld].shape) == 1:
return {'m' : edges.src[fld] * edges.data['e1']}
else:
return {'m' : edges.src[fld] * edges.data['e2']}
def reduce_func(nodes):
return {fld : th.sum(nodes.mailbox['m'], 1)}
def apply_func(nodes):
return {fld : 2 * nodes.data[fld]}
g = generate_graph()
# send and recv
v1 = g.ndata[fld]
g.pull(nodes, fn.copy_src(src=fld, out='m'), fn.sum(msg='m', out=fld), apply_func)
v2 = g.ndata[fld]
g.ndata[fld] = v1
g.pull(nodes, message_func, reduce_func, apply_func)
v3 = g.ndata[fld]
assert U.allclose(v2, v3)
# send and recv with edge weights
v1 = g.ndata[fld]
g.pull(nodes, fn.src_mul_edge(src=fld, edge='e1', out='m'),
fn.sum(msg='m', out=fld), apply_func)
v2 = g.ndata[fld]
g.ndata[fld] = v1
g.pull(nodes, fn.src_mul_edge(src=fld, edge='e2', out='m'),
fn.sum(msg='m', out=fld), apply_func)
v3 = g.ndata[fld]
g.ndata[fld] = v1
g.pull(nodes, message_func_edge, reduce_func, apply_func)
v4 = g.ndata[fld]
assert U.allclose(v2, v3)
assert U.allclose(v3, v4)
# test 1d node features
_test('f1')
# test 2d node features
_test('f2')
def test_v2v_update_all_multi_fn(): def test_v2v_update_all_multi_fn():
def message_func(edges): def message_func(edges):
return {'m2': edges.src['f2']} return {'m2': edges.src['f2']}
...@@ -311,7 +357,7 @@ def test_e2v_recv_multi_fn(): ...@@ -311,7 +357,7 @@ def test_e2v_recv_multi_fn():
# test 2d node features # test 2d node features
_test('f2') _test('f2')
def test_multi_fn_fallback(): def test_update_all_multi_fallback():
# create a graph with zero in degree nodes # create a graph with zero in degree nodes
g = dgl.DGLGraph() g = dgl.DGLGraph()
g.add_nodes(10) g.add_nodes(10)
...@@ -383,12 +429,98 @@ def test_multi_fn_fallback(): ...@@ -383,12 +429,98 @@ def test_multi_fn_fallback():
assert U.allclose(o2, g.ndata.pop('o2')) assert U.allclose(o2, g.ndata.pop('o2'))
assert U.allclose(o3, g.ndata.pop('o3')) assert U.allclose(o3, g.ndata.pop('o3'))
def test_pull_multi_fallback():
# create a graph with zero in degree nodes
g = dgl.DGLGraph()
g.add_nodes(10)
for i in range(1, 9):
g.add_edge(0, i)
g.add_edge(i, 9)
g.ndata['h'] = th.randn(10, D)
g.edata['w1'] = th.randn(16,)
g.edata['w2'] = th.randn(16, D)
def _mfunc_hxw1(edges):
return {'m1' : edges.src['h'] * th.unsqueeze(edges.data['w1'], 1)}
def _mfunc_hxw2(edges):
return {'m2' : edges.src['h'] * edges.data['w2']}
def _rfunc_m1(nodes):
return {'o1' : th.sum(nodes.mailbox['m1'], 1)}
def _rfunc_m2(nodes):
return {'o2' : th.sum(nodes.mailbox['m2'], 1)}
def _rfunc_m1max(nodes):
return {'o3' : th.max(nodes.mailbox['m1'], 1)[0]}
def _afunc(nodes):
ret = {}
for k, v in nodes.data.items():
if k.startswith('o'):
ret[k] = 2 * v
return ret
# nodes to pull
def _pull_nodes(nodes):
# compute ground truth
g.pull(nodes, _mfunc_hxw1, _rfunc_m1, _afunc)
o1 = g.ndata.pop('o1')
g.pull(nodes, _mfunc_hxw2, _rfunc_m2, _afunc)
o2 = g.ndata.pop('o2')
g.pull(nodes, _mfunc_hxw1, _rfunc_m1max, _afunc)
o3 = g.ndata.pop('o3')
# v2v spmv
g.pull(nodes, fn.src_mul_edge(src='h', edge='w1', out='m1'),
fn.sum(msg='m1', out='o1'),
_afunc)
assert U.allclose(o1, g.ndata.pop('o1'))
# v2v fallback to e2v
g.pull(nodes, fn.src_mul_edge(src='h', edge='w2', out='m2'),
fn.sum(msg='m2', out='o2'),
_afunc)
assert U.allclose(o2, g.ndata.pop('o2'))
# v2v fallback to degree bucketing
g.pull(nodes, fn.src_mul_edge(src='h', edge='w1', out='m1'),
fn.max(msg='m1', out='o3'),
_afunc)
assert U.allclose(o3, g.ndata.pop('o3'))
# multi builtins, both v2v spmv
g.pull(nodes,
[fn.src_mul_edge(src='h', edge='w1', out='m1'), fn.src_mul_edge(src='h', edge='w1', out='m2')],
[fn.sum(msg='m1', out='o1'), fn.sum(msg='m2', out='o2')],
_afunc)
assert U.allclose(o1, g.ndata.pop('o1'))
assert U.allclose(o1, g.ndata.pop('o2'))
# multi builtins, one v2v spmv, one fallback to e2v
g.pull(nodes,
[fn.src_mul_edge(src='h', edge='w1', out='m1'), fn.src_mul_edge(src='h', edge='w2', out='m2')],
[fn.sum(msg='m1', out='o1'), fn.sum(msg='m2', out='o2')],
_afunc)
assert U.allclose(o1, g.ndata.pop('o1'))
assert U.allclose(o2, g.ndata.pop('o2'))
# multi builtins, one v2v spmv, one fallback to e2v, one fallback to degree-bucketing
g.pull(nodes,
[fn.src_mul_edge(src='h', edge='w1', out='m1'),
fn.src_mul_edge(src='h', edge='w2', out='m2'),
fn.src_mul_edge(src='h', edge='w1', out='m3')],
[fn.sum(msg='m1', out='o1'),
fn.sum(msg='m2', out='o2'),
fn.max(msg='m3', out='o3')],
_afunc)
assert U.allclose(o1, g.ndata.pop('o1'))
assert U.allclose(o2, g.ndata.pop('o2'))
assert U.allclose(o3, g.ndata.pop('o3'))
# test#1: non-0deg nodes
nodes = [1, 2, 9]
_pull_nodes(nodes)
# test#2: 0deg nodes + non-0deg nodes
nodes = [0, 1, 2, 9]
_pull_nodes(nodes)
if __name__ == '__main__': if __name__ == '__main__':
test_v2v_update_all() test_v2v_update_all()
test_v2v_snr() test_v2v_snr()
test_v2v_pull()
test_v2v_update_all_multi_fn() test_v2v_update_all_multi_fn()
test_v2v_snr_multi_fn() test_v2v_snr_multi_fn()
test_e2v_update_all_multi_fn() test_e2v_update_all_multi_fn()
test_e2v_snr_multi_fn() test_e2v_snr_multi_fn()
test_e2v_recv_multi_fn() test_e2v_recv_multi_fn()
test_multi_fn_fallback() test_update_all_multi_fallback()
test_pull_multi_fallback()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment