Unverified Commit 52ed09a3 authored by Lingfan Yu's avatar Lingfan Yu Committed by GitHub
Browse files

[Bug] Fix inplace update (#221)

* inplace write row op and executor

* update scheduler and graph to use inplace write

* fix

* fix bug

* test case for inplace

* fix bugs for inplace apply node/edge

* fix comments

* th.allclose -> U.allclose
parent dd26ff10
...@@ -611,6 +611,9 @@ class FrameRef(MutableMapping): ...@@ -611,6 +611,9 @@ class FrameRef(MutableMapping):
return utils.LazyDict(lambda key: self._frame[key][rows], keys=self.keys()) return utils.LazyDict(lambda key: self._frame[key][rows], keys=self.keys())
def __setitem__(self, key, val): def __setitem__(self, key, val):
self.set_item_inplace(key, val, inplace=False)
def set_item_inplace(self, key, val, inplace):
"""Update the data in the frame. """Update the data in the frame.
If the provided key is string, the corresponding column data will be updated. If the provided key is string, the corresponding column data will be updated.
...@@ -629,9 +632,11 @@ class FrameRef(MutableMapping): ...@@ -629,9 +632,11 @@ class FrameRef(MutableMapping):
The key. The key.
val : Tensor or dict of tensors val : Tensor or dict of tensors
The value. The value.
inplace: bool
If True, update will be done in place
""" """
if isinstance(key, str): if isinstance(key, str):
self.update_column(key, val, inplace=False) self.update_column(key, val, inplace=inplace)
elif isinstance(key, slice) and key == slice(0, self.num_rows): elif isinstance(key, slice) and key == slice(0, self.num_rows):
# shortcut for updating all the rows # shortcut for updating all the rows
return self.update(val) return self.update(val)
...@@ -639,7 +644,7 @@ class FrameRef(MutableMapping): ...@@ -639,7 +644,7 @@ class FrameRef(MutableMapping):
# shortcut for selecting all the rows # shortcut for selecting all the rows
return self.update(val) return self.update(val)
else: else:
self.update_rows(key, val, inplace=False) self.update_rows(key, val, inplace=inplace)
def update_column(self, name, data, inplace): def update_column(self, name, data, inplace):
"""Update the column. """Update the column.
......
...@@ -1163,8 +1163,8 @@ class DGLGraph(object): ...@@ -1163,8 +1163,8 @@ class DGLGraph(object):
and (D1, D2, ...) be the shape of the node representation tensor. The and (D1, D2, ...) be the shape of the node representation tensor. The
length of the given node ids must match B (i.e, len(u) == B). length of the given node ids must match B (i.e, len(u) == B).
All update will be done out-placely to work with autograd unless the inplace All update will be done out of place to work with autograd unless the
flag is true. inplace flag is true.
Parameters Parameters
---------- ----------
...@@ -1173,7 +1173,7 @@ class DGLGraph(object): ...@@ -1173,7 +1173,7 @@ class DGLGraph(object):
u : node, container or tensor u : node, container or tensor
The node(s). The node(s).
inplace : bool inplace : bool
True if the update is done inplacely If True, update will be done in place, but autograd will break.
""" """
# sanity check # sanity check
if not utils.is_dict_like(hu): if not utils.is_dict_like(hu):
...@@ -1241,8 +1241,8 @@ class DGLGraph(object): ...@@ -1241,8 +1241,8 @@ class DGLGraph(object):
is of shape (B, D1, D2, ...), where B is the number of edges to be updated, is of shape (B, D1, D2, ...), where B is the number of edges to be updated,
and (D1, D2, ...) be the shape of the edge representation tensor. and (D1, D2, ...) be the shape of the edge representation tensor.
All update will be done out-placely to work with autograd unless the inplace All update will be done out of place to work with autograd unless the
flag is true. inplace flag is true.
Parameters Parameters
---------- ----------
...@@ -1252,7 +1252,7 @@ class DGLGraph(object): ...@@ -1252,7 +1252,7 @@ class DGLGraph(object):
Edges can be a pair of endpoint nodes (u, v), or a Edges can be a pair of endpoint nodes (u, v), or a
tensor of edge ids. The default value is all the edges. tensor of edge ids. The default value is all the edges.
inplace : bool inplace : bool
True if the update is done inplacely If True, update will be done in place, but autograd will break.
""" """
# parse argument # parse argument
if is_all(edges): if is_all(edges):
...@@ -1390,6 +1390,8 @@ class DGLGraph(object): ...@@ -1390,6 +1390,8 @@ class DGLGraph(object):
The UDF applied on the node features. The UDF applied on the node features.
v : int, iterable of int, tensor, optional v : int, iterable of int, tensor, optional
The node id(s). The node id(s).
inplace: bool, optional
If True, update will be done in place, but autograd will break.
""" """
if func == "default": if func == "default":
func = self._apply_node_func func = self._apply_node_func
...@@ -1398,10 +1400,13 @@ class DGLGraph(object): ...@@ -1398,10 +1400,13 @@ class DGLGraph(object):
else: else:
v = utils.toindex(v) v = utils.toindex(v)
with ir.prog() as prog: with ir.prog() as prog:
scheduler.schedule_apply_nodes(graph=self, v=v, apply_func=func) scheduler.schedule_apply_nodes(graph=self,
v=v,
apply_func=func,
inplace=inplace)
Runtime.run(prog) Runtime.run(prog)
def apply_edges(self, func="default", edges=ALL): def apply_edges(self, func="default", edges=ALL, inplace=False):
"""Apply the function on the edge features. """Apply the function on the edge features.
Parameters Parameters
...@@ -1411,6 +1416,8 @@ class DGLGraph(object): ...@@ -1411,6 +1416,8 @@ class DGLGraph(object):
edges : edges, optional edges : edges, optional
Edges can be a pair of endpoint nodes (u, v), or a Edges can be a pair of endpoint nodes (u, v), or a
tensor of edge ids. The default value is all the edges. tensor of edge ids. The default value is all the edges.
inplace: bool, optional
If True, update will be done in place, but autograd will break.
Notes Notes
----- -----
...@@ -1422,7 +1429,8 @@ class DGLGraph(object): ...@@ -1422,7 +1429,8 @@ class DGLGraph(object):
assert func is not None assert func is not None
if is_all(edges): if is_all(edges):
u, v, eid = self._graph.edges() u, v, _ = self._graph.edges()
eid = utils.toindex(slice(0, self.number_of_edges()))
elif isinstance(edges, tuple): elif isinstance(edges, tuple):
u, v = edges u, v = edges
u = utils.toindex(u) u = utils.toindex(u)
...@@ -1434,8 +1442,12 @@ class DGLGraph(object): ...@@ -1434,8 +1442,12 @@ class DGLGraph(object):
u, v, _ = self._graph.find_edges(eid) u, v, _ = self._graph.find_edges(eid)
with ir.prog() as prog: with ir.prog() as prog:
scheduler.schedule_apply_edges(graph=self, u=u, v=v, scheduler.schedule_apply_edges(graph=self,
eid=eid, apply_func=func) u=u,
v=v,
eid=eid,
apply_func=func,
inplace=inplace)
Runtime.run(prog) Runtime.run(prog)
def send(self, edges, message_func="default"): def send(self, edges, message_func="default"):
...@@ -1481,7 +1493,8 @@ class DGLGraph(object): ...@@ -1481,7 +1493,8 @@ class DGLGraph(object):
def recv(self, def recv(self,
v, v,
reduce_func="default", reduce_func="default",
apply_node_func="default"): apply_node_func="default",
inplace=False):
"""Receive and reduce in-coming messages and update representation on node v. """Receive and reduce in-coming messages and update representation on node v.
TODO(minjie): document on zero-in-degree case TODO(minjie): document on zero-in-degree case
...@@ -1496,6 +1509,8 @@ class DGLGraph(object): ...@@ -1496,6 +1509,8 @@ class DGLGraph(object):
The reduce function. The reduce function.
apply_node_func : callable, optional apply_node_func : callable, optional
The update function. The update function.
inplace: bool, optional
If True, update will be done in place, but autograd will break.
""" """
if reduce_func == "default": if reduce_func == "default":
reduce_func = self._reduce_func reduce_func = self._reduce_func
...@@ -1518,8 +1533,11 @@ class DGLGraph(object): ...@@ -1518,8 +1533,11 @@ class DGLGraph(object):
return return
with ir.prog() as prog: with ir.prog() as prog:
scheduler.schedule_recv(graph=self, recv_nodes=v, scheduler.schedule_recv(graph=self,
reduce_func=reduce_func, apply_func=apply_node_func) recv_nodes=v,
reduce_func=reduce_func,
apply_func=apply_node_func,
inplace=inplace)
Runtime.run(prog) Runtime.run(prog)
# FIXME(minjie): multi send bug # FIXME(minjie): multi send bug
...@@ -1529,7 +1547,8 @@ class DGLGraph(object): ...@@ -1529,7 +1547,8 @@ class DGLGraph(object):
edges, edges,
message_func="default", message_func="default",
reduce_func="default", reduce_func="default",
apply_node_func="default"): apply_node_func="default",
inplace=False):
"""Send messages along edges and receive them on the targets. """Send messages along edges and receive them on the targets.
Parameters Parameters
...@@ -1546,6 +1565,8 @@ class DGLGraph(object): ...@@ -1546,6 +1565,8 @@ class DGLGraph(object):
apply_node_func : callable, optional apply_node_func : callable, optional
The update function. Registered function will be used if not The update function. Registered function will be used if not
specified. specified.
inplace: bool, optional
If True, update will be done in place, but autograd will break.
Notes Notes
----- -----
...@@ -1577,15 +1598,20 @@ class DGLGraph(object): ...@@ -1577,15 +1598,20 @@ class DGLGraph(object):
return return
with ir.prog() as prog: with ir.prog() as prog:
scheduler.schedule_snr(self, (u, v, eid), scheduler.schedule_snr(graph=self,
message_func, reduce_func, apply_node_func) edge_tuples=(u, v, eid),
message_func=message_func,
reduce_func=reduce_func,
apply_func=apply_node_func,
inplace=inplace)
Runtime.run(prog) Runtime.run(prog)
def pull(self, def pull(self,
v, v,
message_func="default", message_func="default",
reduce_func="default", reduce_func="default",
apply_node_func="default"): apply_node_func="default",
inplace=False):
"""Pull messages from the node's predecessors and then update it. """Pull messages from the node's predecessors and then update it.
Parameters Parameters
...@@ -1598,6 +1624,8 @@ class DGLGraph(object): ...@@ -1598,6 +1624,8 @@ class DGLGraph(object):
The reduce function. The reduce function.
apply_node_func : callable, optional apply_node_func : callable, optional
The update function. The update function.
inplace: bool, optional
If True, update will be done in place, but autograd will break.
""" """
if message_func == "default": if message_func == "default":
message_func = self._message_func message_func = self._message_func
...@@ -1613,16 +1641,20 @@ class DGLGraph(object): ...@@ -1613,16 +1641,20 @@ class DGLGraph(object):
if len(v) == 0: if len(v) == 0:
return return
with ir.prog() as prog: with ir.prog() as prog:
scheduler.schedule_pull(graph=self, pull_nodes=v, scheduler.schedule_pull(graph=self,
message_func=message_func, reduce_func=reduce_func, pull_nodes=v,
apply_func=apply_node_func) message_func=message_func,
reduce_func=reduce_func,
apply_func=apply_node_func,
inplace=inplace)
Runtime.run(prog) Runtime.run(prog)
def push(self, def push(self,
u, u,
message_func="default", message_func="default",
reduce_func="default", reduce_func="default",
apply_node_func="default"): apply_node_func="default",
inplace=False):
"""Send message from the node to its successors and update them. """Send message from the node to its successors and update them.
Parameters Parameters
...@@ -1635,6 +1667,8 @@ class DGLGraph(object): ...@@ -1635,6 +1667,8 @@ class DGLGraph(object):
The reduce function. The reduce function.
apply_node_func : callable apply_node_func : callable
The update function. The update function.
inplace: bool, optional
If True, update will be done in place, but autograd will break.
""" """
if message_func == "default": if message_func == "default":
message_func = self._message_func message_func = self._message_func
...@@ -1650,9 +1684,12 @@ class DGLGraph(object): ...@@ -1650,9 +1684,12 @@ class DGLGraph(object):
if len(u) == 0: if len(u) == 0:
return return
with ir.prog() as prog: with ir.prog() as prog:
scheduler.schedule_push(graph=self, u=u, scheduler.schedule_push(graph=self,
message_func=message_func, reduce_func=reduce_func, u=u,
apply_func=apply_node_func) message_func=message_func,
reduce_func=reduce_func,
apply_func=apply_node_func,
inplace=inplace)
Runtime.run(prog) Runtime.run(prog)
def update_all(self, def update_all(self,
...@@ -1680,8 +1717,10 @@ class DGLGraph(object): ...@@ -1680,8 +1717,10 @@ class DGLGraph(object):
assert reduce_func is not None assert reduce_func is not None
with ir.prog() as prog: with ir.prog() as prog:
scheduler.schedule_update_all(graph=self, message_func=message_func, scheduler.schedule_update_all(graph=self,
reduce_func=reduce_func, apply_func=apply_node_func) message_func=message_func,
reduce_func=reduce_func,
apply_func=apply_node_func)
Runtime.run(prog) Runtime.run(prog)
def prop_nodes(self, def prop_nodes(self,
......
...@@ -33,6 +33,7 @@ class OpCode(object): ...@@ -33,6 +33,7 @@ class OpCode(object):
WRITE_ROW_ = 23 WRITE_ROW_ = 23
WRITE_DICT_ = 24 WRITE_DICT_ = 24
APPEND_ROW_ = 25 APPEND_ROW_ = 25
WRITE_ROW_INPLACE_ = 26
class Executor(object): class Executor(object):
@abstractmethod @abstractmethod
...@@ -553,6 +554,38 @@ def WRITE_ROW_(fd, row, val): ...@@ -553,6 +554,38 @@ def WRITE_ROW_(fd, row, val):
reg = IR_REGISTRY[OpCode.WRITE_ROW_] reg = IR_REGISTRY[OpCode.WRITE_ROW_]
get_current_prog().issue(reg['executor_cls'](fd, row, val)) get_current_prog().issue(reg['executor_cls'](fd, row, val))
class WriteRowInplace_Executor(Executor):
    """Executor that writes rows into a feature dict with ``inplace=True``.

    Parameters
    ----------
    fd : variable
        Variable whose ``data`` is the feature dict to write into.
    row : variable
        Variable whose ``data`` is the index of the rows to write.
    val : variable
        Variable whose ``data`` holds the values to write.
    """

    def __init__(self, fd, row, val):
        self.fd, self.row, self.val = fd, row, val

    def opcode(self):
        """Return the opcode of this executor."""
        return OpCode.WRITE_ROW_INPLACE_

    def arg_vars(self):
        """Return the argument variables in (fd, row, val) order."""
        return [self.fd, self.row, self.val]

    def ret_var(self):
        """This op writes its result into ``fd``; it has no return variable."""
        return None

    def run(self):
        """Write ``val`` into the rows ``row`` of ``fd`` in place."""
        target = self.fd.data    # feature dict
        index = self.row.data    # idx
        values = self.val.data
        target.set_item_inplace(index, values, inplace=True)
# Registry entry for the in-place row write op. It takes a feature dict, a row
# index and a feature dict of new values, and produces no return variable.
IR_REGISTRY[OpCode.WRITE_ROW_INPLACE_] = {
    'name' : 'WRITE_ROW_INPLACE_',
    'args_type' : [VarType.FEAT_DICT, VarType.IDX, VarType.FEAT_DICT],
    'ret_type' : None,
    'executor_cls' : WriteRowInplace_Executor,
}
def WRITE_ROW_INPLACE_(fd, row, val):
    """Issue an in-place row write to the current program.

    Parameters
    ----------
    fd : variable
        The feature dict variable to write into.
    row : variable
        The index variable selecting the rows.
    val : variable
        The feature dict variable holding the new values.
    """
    entry = IR_REGISTRY[OpCode.WRITE_ROW_INPLACE_]
    prog = get_current_prog()
    prog.issue(entry['executor_cls'](fd, row, val))
class WriteDict_Executor(Executor): class WriteDict_Executor(Executor):
def __init__(self, fd1, fd2): def __init__(self, fd1, fd2):
self.fd1 = fd1 self.fd1 = fd1
......
...@@ -54,25 +54,31 @@ def schedule_send(graph, u, v, eid, message_func): ...@@ -54,25 +54,31 @@ def schedule_send(graph, u, v, eid, message_func):
# TODO: handle duplicate messages # TODO: handle duplicate messages
ir.APPEND_ROW_(mf, msg) ir.APPEND_ROW_(mf, msg)
def schedule_recv(graph, recv_nodes, reduce_func, apply_func): def schedule_recv(graph,
recv_nodes,
reduce_func,
apply_func,
inplace):
"""Schedule recv. """Schedule recv.
Parameters Parameters
---------- ----------
graph: DGLGraph graph: DGLGraph
The DGLGraph to use The DGLGraph to use
v : utils.Index recv_nodes: utils.Index
Nodes to recv. Nodes to recv.
reduce_func: callable or list of callable reduce_func: callable or list of callable
The reduce function The reduce function
apply_func: callable apply_func: callable
The apply node function The apply node function
inplace: bool
If True, the update will be done in place
""" """
src, dst, mid = graph._msg_graph.in_edges(recv_nodes) src, dst, mid = graph._msg_graph.in_edges(recv_nodes)
if len(mid) == 0: if len(mid) == 0:
# All recv nodes are 0-degree nodes; downgrade to apply nodes. # All recv nodes are 0-degree nodes; downgrade to apply nodes.
if apply_func is not None: if apply_func is not None:
schedule_apply_nodes(graph, recv_nodes, apply_func) schedule_apply_nodes(graph, recv_nodes, apply_func, inplace)
else: else:
var_nf = var.FEAT_DICT(graph._node_frame, name='nf') var_nf = var.FEAT_DICT(graph._node_frame, name='nf')
# sort and unique the argument # sort and unique the argument
...@@ -83,13 +89,35 @@ def schedule_recv(graph, recv_nodes, reduce_func, apply_func): ...@@ -83,13 +89,35 @@ def schedule_recv(graph, recv_nodes, reduce_func, apply_func):
reduced_feat = _gen_reduce(graph, reduce_func, (src, dst, mid), recv_nodes) reduced_feat = _gen_reduce(graph, reduce_func, (src, dst, mid), recv_nodes)
# apply # apply
final_feat = _apply_with_accum(graph, var_recv_nodes, var_nf, reduced_feat, apply_func) final_feat = _apply_with_accum(graph, var_recv_nodes, var_nf, reduced_feat, apply_func)
if inplace:
ir.WRITE_ROW_INPLACE_(var_nf, var_recv_nodes, final_feat)
else:
ir.WRITE_ROW_(var_nf, var_recv_nodes, final_feat) ir.WRITE_ROW_(var_nf, var_recv_nodes, final_feat)
def schedule_snr(graph, def schedule_snr(graph,
edge_tuples, edge_tuples,
message_func, message_func,
reduce_func, reduce_func,
apply_func): apply_func,
inplace):
"""Schedule send_and_recv.
Parameters
----------
graph: DGLGraph
The DGLGraph to use
edge_tuples: tuple
A tuple of (src ids, dst ids, edge ids) representing edges to perform
send_and_recv
message_func: callable or list of callable
The message function
reduce_func: callable or list of callable
The reduce function
apply_func: callable
The apply node function
inplace: bool
If True, the update will be done in place
"""
call_type = 'send_and_recv' call_type = 'send_and_recv'
u, v, eid = edge_tuples u, v, eid = edge_tuples
recv_nodes, _ = F.sort_1d(F.unique(v.tousertensor())) recv_nodes, _ = F.sort_1d(F.unique(v.tousertensor()))
...@@ -110,9 +138,15 @@ def schedule_snr(graph, ...@@ -110,9 +138,15 @@ def schedule_snr(graph,
uv_getter, adj_creator, inc_creator) uv_getter, adj_creator, inc_creator)
# generate apply schedule # generate apply schedule
final_feat = _apply_with_accum(graph, var_recv_nodes, var_nf, reduced_feat, apply_func) final_feat = _apply_with_accum(graph, var_recv_nodes, var_nf, reduced_feat, apply_func)
if inplace:
ir.WRITE_ROW_INPLACE_(var_nf, var_recv_nodes, final_feat)
else:
ir.WRITE_ROW_(var_nf, var_recv_nodes, final_feat) ir.WRITE_ROW_(var_nf, var_recv_nodes, final_feat)
def schedule_update_all(graph, message_func, reduce_func, apply_func): def schedule_update_all(graph,
message_func,
reduce_func,
apply_func):
"""get send and recv schedule """get send and recv schedule
Parameters Parameters
...@@ -130,7 +164,7 @@ def schedule_update_all(graph, message_func, reduce_func, apply_func): ...@@ -130,7 +164,7 @@ def schedule_update_all(graph, message_func, reduce_func, apply_func):
# All the nodes are zero degree; downgrade to apply nodes # All the nodes are zero degree; downgrade to apply nodes
if apply_func is not None: if apply_func is not None:
nodes = utils.toindex(slice(0, graph.number_of_nodes())) nodes = utils.toindex(slice(0, graph.number_of_nodes()))
schedule_apply_nodes(graph, nodes, apply_func) schedule_apply_nodes(graph, nodes, apply_func, inplace=False)
else: else:
call_type = 'update_all' call_type = 'update_all'
eid = utils.toindex(slice(0, graph.number_of_edges())) # shortcut for ALL eid = utils.toindex(slice(0, graph.number_of_edges())) # shortcut for ALL
...@@ -153,7 +187,10 @@ def schedule_update_all(graph, message_func, reduce_func, apply_func): ...@@ -153,7 +187,10 @@ def schedule_update_all(graph, message_func, reduce_func, apply_func):
final_feat = _apply_with_accum(graph, var_recv_nodes, var_nf, reduced_feat, apply_func) final_feat = _apply_with_accum(graph, var_recv_nodes, var_nf, reduced_feat, apply_func)
ir.WRITE_DICT_(var_nf, final_feat) ir.WRITE_DICT_(var_nf, final_feat)
def schedule_apply_nodes(graph, v, apply_func): def schedule_apply_nodes(graph,
v,
apply_func,
inplace):
"""get apply nodes schedule """get apply nodes schedule
Parameters Parameters
...@@ -164,6 +201,8 @@ def schedule_apply_nodes(graph, v, apply_func): ...@@ -164,6 +201,8 @@ def schedule_apply_nodes(graph, v, apply_func):
Nodes to apply Nodes to apply
apply_func: callable apply_func: callable
The apply node function The apply node function
inplace: bool
If True, the update will be done in place
Returns Returns
------- -------
...@@ -177,9 +216,15 @@ def schedule_apply_nodes(graph, v, apply_func): ...@@ -177,9 +216,15 @@ def schedule_apply_nodes(graph, v, apply_func):
return apply_func(nb) return apply_func(nb)
afunc = var.FUNC(_afunc_wrapper) afunc = var.FUNC(_afunc_wrapper)
applied_feat = ir.NODE_UDF(afunc, v_nf) applied_feat = ir.NODE_UDF(afunc, v_nf)
if inplace:
ir.WRITE_ROW_INPLACE_(var_nf, var_v, applied_feat)
else:
ir.WRITE_ROW_(var_nf, var_v, applied_feat) ir.WRITE_ROW_(var_nf, var_v, applied_feat)
def schedule_apply_edges(graph, u, v, eid, apply_func): def schedule_apply_edges(graph,
u, v, eid,
apply_func,
inplace):
"""get apply edges schedule """get apply edges schedule
Parameters Parameters
...@@ -194,6 +239,8 @@ def schedule_apply_edges(graph, u, v, eid, apply_func): ...@@ -194,6 +239,8 @@ def schedule_apply_edges(graph, u, v, eid, apply_func):
Ids of sending edges Ids of sending edges
apply_func: callable apply_func: callable
The apply edge function The apply edge function
inplace: bool
If True, the update will be done in place
Returns Returns
------- -------
...@@ -215,9 +262,17 @@ def schedule_apply_edges(graph, u, v, eid, apply_func): ...@@ -215,9 +262,17 @@ def schedule_apply_edges(graph, u, v, eid, apply_func):
return apply_func(eb) return apply_func(eb)
_efunc = var.FUNC(_efunc_wrapper) _efunc = var.FUNC(_efunc_wrapper)
new_fdedge = ir.EDGE_UDF(_efunc, fdsrc, fdedge, fddst) new_fdedge = ir.EDGE_UDF(_efunc, fdsrc, fdedge, fddst)
if inplace:
ir.WRITE_ROW_INPLACE_(var_ef, var_eid, new_fdedge)
else:
ir.WRITE_ROW_(var_ef, var_eid, new_fdedge) ir.WRITE_ROW_(var_ef, var_eid, new_fdedge)
def schedule_push(graph, u, message_func, reduce_func, apply_func): def schedule_push(graph,
u,
message_func,
reduce_func,
apply_func,
inplace):
"""get push schedule """get push schedule
Parameters Parameters
...@@ -232,14 +287,22 @@ def schedule_push(graph, u, message_func, reduce_func, apply_func): ...@@ -232,14 +287,22 @@ def schedule_push(graph, u, message_func, reduce_func, apply_func):
The reduce function The reduce function
apply_func: callable apply_func: callable
The apply node function The apply node function
inplace: bool
If True, the update will be done in place
""" """
u, v, eid = graph._graph.out_edges(u) u, v, eid = graph._graph.out_edges(u)
if len(eid) == 0: if len(eid) == 0:
# All the pushing nodes have no out edges. No computation is scheduled. # All the pushing nodes have no out edges. No computation is scheduled.
return return
schedule_snr(graph, (u, v, eid), message_func, reduce_func, apply_func) schedule_snr(graph, (u, v, eid),
message_func, reduce_func, apply_func, inplace)
def schedule_pull(graph, pull_nodes, message_func, reduce_func, apply_func): def schedule_pull(graph,
pull_nodes,
message_func,
reduce_func,
apply_func,
inplace):
"""get pull schedule """get pull schedule
Parameters Parameters
...@@ -254,6 +317,8 @@ def schedule_pull(graph, pull_nodes, message_func, reduce_func, apply_func): ...@@ -254,6 +317,8 @@ def schedule_pull(graph, pull_nodes, message_func, reduce_func, apply_func):
The reduce function The reduce function
apply_func: callable apply_func: callable
The apply node function The apply node function
inplace: bool
If True, the update will be done in place
""" """
# TODO(minjie): `in_edges` can be omitted if message and reduce func pairs # TODO(minjie): `in_edges` can be omitted if message and reduce func pairs
# can be specialized to SPMV. This needs support for creating adjmat # can be specialized to SPMV. This needs support for creating adjmat
...@@ -262,7 +327,7 @@ def schedule_pull(graph, pull_nodes, message_func, reduce_func, apply_func): ...@@ -262,7 +327,7 @@ def schedule_pull(graph, pull_nodes, message_func, reduce_func, apply_func):
if len(eid) == 0: if len(eid) == 0:
# All the nodes are 0deg; downgrades to apply. # All the nodes are 0deg; downgrades to apply.
if apply_func is not None: if apply_func is not None:
schedule_apply_nodes(graph, pull_nodes, apply_func) schedule_apply_nodes(graph, pull_nodes, apply_func, inplace)
else: else:
call_type = 'send_and_recv' call_type = 'send_and_recv'
pull_nodes, _ = F.sort_1d(F.unique(pull_nodes.tousertensor())) pull_nodes, _ = F.sort_1d(F.unique(pull_nodes.tousertensor()))
...@@ -283,6 +348,9 @@ def schedule_pull(graph, pull_nodes, message_func, reduce_func, apply_func): ...@@ -283,6 +348,9 @@ def schedule_pull(graph, pull_nodes, message_func, reduce_func, apply_func):
uv_getter, adj_creator, inc_creator) uv_getter, adj_creator, inc_creator)
# generate optional apply # generate optional apply
final_feat = _apply_with_accum(graph, var_pull_nodes, var_nf, reduced_feat, apply_func) final_feat = _apply_with_accum(graph, var_pull_nodes, var_nf, reduced_feat, apply_func)
if inplace:
ir.WRITE_ROW_INPLACE_(var_nf, var_pull_nodes, final_feat)
else:
ir.WRITE_ROW_(var_nf, var_pull_nodes, final_feat) ir.WRITE_ROW_(var_nf, var_pull_nodes, final_feat)
def _check_builtin_func_list(func_list): def _check_builtin_func_list(func_list):
......
import torch as th
import numpy as np
import scipy.sparse as sp
import dgl
import dgl.function as fn
import utils as U
# Width of each random node/edge feature row created by generate_graph().
D = 5
def generate_graph():
    """Build the 10-node fixture graph shared by the in-place tests.

    Node 0 points to each of nodes 1..8, each of those points to node 9, and
    one extra edge goes from 9 back to 0 (17 edges in total). Node feature
    'f' and edge feature 'e' are filled with random rows of width ``D``.
    """
    num_nodes = 10
    graph = dgl.DGLGraph()
    graph.add_nodes(num_nodes)
    # 0 is the source and 9 is the sink
    for mid in range(1, 9):
        graph.add_edge(0, mid)
        graph.add_edge(mid, 9)
    # back flow from the sink to the source
    graph.add_edge(9, 0)
    graph.ndata['f'] = th.randn(num_nodes, D)
    graph.edata['e'] = th.randn(17, D)
    return graph
def test_inplace_recv():
    """In-place ``recv`` must match the out-of-place result and must write
    into the tensor that was installed as the node feature."""
    src = th.tensor([0, 0, 0, 3, 4, 9])
    dst = th.tensor([1, 2, 3, 9, 9, 0])

    def message_func(edges):
        return {'m' : edges.src['f'] + edges.dst['f']}

    def reduce_func(nodes):
        return {'f' : th.sum(nodes.mailbox['m'], 1)}

    def apply_func(nodes):
        return {'f' : 2 * nodes.data['f']}

    def _test(afunc):
        graph = generate_graph()
        original = graph.ndata['f']

        # reference: one out-of-place send + recv
        graph.send((src, dst), message_func)
        graph.recv([0,1,2,3,9], reduce_func, afunc)
        expected = graph.get_n_repr()['f']

        # in-place run with a user-defined reducer (degree bucketing)
        buf = original.clone()
        graph.ndata['f'] = buf
        graph.send((src, dst), message_func)
        graph.recv([0,1,2,3,9], reduce_func, afunc, inplace=True)
        updated = graph.get_n_repr()['f']
        # same numbers as the reference run
        assert U.allclose(updated, expected)
        # the tensor we handed in carries the update
        assert U.allclose(buf, updated)

        # in-place run with a builtin reducer (e2v)
        buf = original.clone()
        graph.ndata['f'] = buf
        graph.send((src, dst), message_func)
        graph.recv([0,1,2,3,9], fn.sum(msg='m', out='f'), afunc, inplace=True)
        updated = graph.ndata['f']
        # same numbers as the reference run
        assert U.allclose(updated, expected)
        # the tensor we handed in carries the update
        assert U.allclose(buf, updated)

    # recv with an apply function
    _test(apply_func)
    # recv without an apply function
    _test(None)
def test_inplace_snr():
    """In-place ``send_and_recv`` must match the out-of-place result and must
    write into the tensor that was installed as the node feature."""
    src = th.tensor([0, 0, 0, 3, 4, 9])
    dst = th.tensor([1, 2, 3, 9, 9, 0])

    def message_func(edges):
        return {'m' : edges.src['f']}

    def reduce_func(nodes):
        return {'f' : th.sum(nodes.mailbox['m'], 1)}

    def apply_func(nodes):
        return {'f' : 2 * nodes.data['f']}

    def _test(afunc):
        graph = generate_graph()
        original = graph.ndata['f']

        # reference: an out-of-place run with builtin functions
        graph.send_and_recv((src, dst), fn.copy_src(src='f', out='m'),
                            fn.sum(msg='m', out='f'), afunc)
        expected = graph.ndata['f']

        def _check_inplace(mfunc, rfunc):
            # start every run from a fresh copy of the original feature
            buf = original.clone()
            graph.ndata['f'] = buf
            graph.send_and_recv((src, dst), mfunc, rfunc, afunc, inplace=True)
            updated = graph.ndata['f']
            # same numbers as the reference run
            assert U.allclose(updated, expected)
            # the tensor we handed in carries the update
            assert U.allclose(buf, updated)

        # degree bucketing: user-defined message and reduce
        _check_inplace(message_func, reduce_func)
        # v2v spmv: builtin message and reduce
        _check_inplace(fn.copy_src(src='f', out='m'), fn.sum(msg='m', out='f'))
        # e2v spmv: user-defined message, builtin reduce
        _check_inplace(message_func, fn.sum(msg='m', out='f'))

    # send_and_recv with an apply function
    _test(apply_func)
    # send_and_recv without an apply function
    _test(None)
def test_inplace_push():
    """In-place ``push`` must match the out-of-place result and must write
    into the tensor that was installed as the node feature."""
    push_nodes = th.tensor([0, 3, 4, 9])

    def message_func(edges):
        return {'m' : edges.src['f']}

    def reduce_func(nodes):
        return {'f' : th.sum(nodes.mailbox['m'], 1)}

    def apply_func(nodes):
        return {'f' : 2 * nodes.data['f']}

    def _test(afunc):
        graph = generate_graph()
        original = graph.ndata['f']

        # reference: an out-of-place run with builtin functions
        graph.push(push_nodes,
                   fn.copy_src(src='f', out='m'), fn.sum(msg='m', out='f'), afunc)
        expected = graph.ndata['f']

        def _check_inplace(mfunc, rfunc):
            # start every run from a fresh copy of the original feature
            buf = original.clone()
            graph.ndata['f'] = buf
            graph.push(push_nodes, mfunc, rfunc, afunc, inplace=True)
            updated = graph.ndata['f']
            # same numbers as the reference run
            assert U.allclose(updated, expected)
            # the tensor we handed in carries the update
            assert U.allclose(buf, updated)

        # degree bucketing: user-defined message and reduce
        _check_inplace(message_func, reduce_func)
        # v2v spmv: builtin message and reduce
        _check_inplace(fn.copy_src(src='f', out='m'), fn.sum(msg='m', out='f'))
        # e2v spmv: user-defined message, builtin reduce
        _check_inplace(message_func, fn.sum(msg='m', out='f'))

    # push with an apply function
    _test(apply_func)
    # push without an apply function
    _test(None)
def test_inplace_pull():
    """In-place ``pull`` must match the out-of-place result and must write
    into the tensor that was installed as the node feature."""
    pull_nodes = th.tensor([1, 2, 3, 9])

    def message_func(edges):
        return {'m' : edges.src['f']}

    def reduce_func(nodes):
        return {'f' : th.sum(nodes.mailbox['m'], 1)}

    def apply_func(nodes):
        return {'f' : 2 * nodes.data['f']}

    def _test(afunc):
        graph = generate_graph()
        original = graph.ndata['f']

        # reference: an out-of-place run with builtin functions
        graph.pull(pull_nodes,
                   fn.copy_src(src='f', out='m'), fn.sum(msg='m', out='f'), afunc)
        expected = graph.ndata['f']

        def _check_inplace(mfunc, rfunc):
            # start every run from a fresh copy of the original feature
            buf = original.clone()
            graph.ndata['f'] = buf
            graph.pull(pull_nodes, mfunc, rfunc, afunc, inplace=True)
            updated = graph.ndata['f']
            # same numbers as the reference run
            assert U.allclose(updated, expected)
            # the tensor we handed in carries the update
            assert U.allclose(buf, updated)

        # degree bucketing: user-defined message and reduce
        _check_inplace(message_func, reduce_func)
        # v2v spmv: builtin message and reduce
        _check_inplace(fn.copy_src(src='f', out='m'), fn.sum(msg='m', out='f'))
        # e2v spmv: user-defined message, builtin reduce
        _check_inplace(message_func, fn.sum(msg='m', out='f'))

    # pull with an apply function
    _test(apply_func)
    # pull without an apply function
    _test(None)
def test_inplace_apply():
    """Check the ``inplace`` flag of ``apply_nodes`` and ``apply_edges``.

    For a subset of nodes/edges the in-place run must produce the same values
    as the out-of-place run and must write them into the tensor that was
    installed as the feature. When applied to all nodes/edges, the test
    expects the installed tensor to be left untouched even with
    ``inplace=True``.
    """
    def apply_node_func(nodes):
        return {'f': nodes.data['f'] * 2}

    def apply_edge_func(edges):
        return {'e': edges.data['e'] * 2}

    g = generate_graph()

    nodes = [1, 2, 3, 9]
    nf = g.ndata['f']
    # out place run
    g.apply_nodes(apply_node_func, nodes)
    new_nf = g.ndata['f']
    # in place run
    g.ndata['f'] = nf
    g.apply_nodes(apply_node_func, nodes, inplace=True)
    # check results correct and in place
    assert U.allclose(nf, new_nf)
    # test apply all nodes, should not be done in place
    g.ndata['f'] = nf
    g.apply_nodes(apply_node_func, inplace=True)
    assert not U.allclose(nf, g.ndata['f'])

    edges = [3, 5, 7, 10]
    ef = g.edata['e']
    # out place run
    g.apply_edges(apply_edge_func, edges)
    new_ef = g.edata['e']
    # in place run
    g.edata['e'] = ef
    g.apply_edges(apply_edge_func, edges, inplace=True)
    # check results correct and in place
    assert U.allclose(ef, new_ef)
    # test apply all edges, should not be done in place
    # (this must be an assignment: the original `==` only compared the two
    # tensors and threw the result away)
    g.edata['e'] = ef
    g.apply_edges(apply_edge_func, inplace=True)
    assert not U.allclose(ef, g.edata['e'])
if __name__ == '__main__':
    # Run every in-place test when the file is executed as a script.
    for _case in (test_inplace_recv,
                  test_inplace_snr,
                  test_inplace_push,
                  test_inplace_pull,
                  test_inplace_apply):
        _case()
...@@ -274,7 +274,7 @@ def test_e2v_update_all_multi_fn(): ...@@ -274,7 +274,7 @@ def test_e2v_update_all_multi_fn():
apply_func_2) apply_func_2)
v3 = g.get_n_repr()[fld] v3 = g.get_n_repr()[fld]
assert th.allclose(v2, v3) assert U.allclose(v2, v3)
# test 1d node features # test 1d node features
_test('f1') _test('f1')
...@@ -312,7 +312,7 @@ def test_e2v_snr_multi_fn(): ...@@ -312,7 +312,7 @@ def test_e2v_snr_multi_fn():
apply_func_2) apply_func_2)
v3 = g.get_n_repr()[fld] v3 = g.get_n_repr()[fld]
assert th.allclose(v2, v3) assert U.allclose(v2, v3)
# test 1d node features # test 1d node features
_test('f1') _test('f1')
...@@ -352,7 +352,7 @@ def test_e2v_recv_multi_fn(): ...@@ -352,7 +352,7 @@ def test_e2v_recv_multi_fn():
apply_func_2) apply_func_2)
v3 = g.get_n_repr()[fld] v3 = g.get_n_repr()[fld]
assert th.allclose(v2, v3) assert U.allclose(v2, v3)
# test 1d node features # test 1d node features
_test('f1') _test('f1')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment