Unverified Commit 52ed09a3 authored by Lingfan Yu's avatar Lingfan Yu Committed by GitHub
Browse files

[Bug] Fix inplace update (#221)

* inplace write row op and executor

* update scheduler and graph to use inplace write

* fix

* fix bug

* test case for inplace

* fix bugs for inplace apply node/edge

* fix comments

* th.allclose -> U.allclose
parent dd26ff10
......@@ -31,7 +31,7 @@ class Scheme(namedtuple('Scheme', ['shape', 'dtype'])):
def __reduce__(self):
state = (self.shape, F.reverse_data_type_dict[self.dtype])
return self._reconstruct_scheme, state
@classmethod
def _reconstruct_scheme(cls, shape, dtype_str):
......@@ -219,7 +219,7 @@ class Frame(MutableMapping):
callable
The initializer
"""
return self._initializers.get(column, self._default_initializer)
return self._initializers.get(column, self._default_initializer)
def set_initializer(self, initializer, column=None):
"""Set the initializer for empty values, for a given column or all future
......@@ -287,7 +287,7 @@ class Frame(MutableMapping):
def __delitem__(self, name):
"""Delete the whole column.
Parameters
----------
name : str
......@@ -435,7 +435,7 @@ class FrameRef(MutableMapping):
@property
def schemes(self):
"""Return the frame schemes.
Returns
-------
dict of str to Scheme
......@@ -541,7 +541,7 @@ class FrameRef(MutableMapping):
If the provided key is an index or a slice, the corresponding rows will be selected.
The returned rows are saved in a lazy dictionary so only the real selection happens
when the explicit column name is provided.
Examples (using pytorch)
------------------------
>>> # create a frame of two columns and five rows
......@@ -550,7 +550,7 @@ class FrameRef(MutableMapping):
>>> # select the row 1 and 2, the returned `rows` is a lazy dictionary.
>>> rows = fr[Index([1, 2])]
>>> rows['c1'] # only select rows for 'c1' column; 'c2' column is not sliced.
Parameters
----------
key : str or utils.Index or slice
......@@ -611,6 +611,9 @@ class FrameRef(MutableMapping):
return utils.LazyDict(lambda key: self._frame[key][rows], keys=self.keys())
def __setitem__(self, key, val):
self.set_item_inplace(key, val, inplace=False)
def set_item_inplace(self, key, val, inplace):
"""Update the data in the frame.
If the provided key is string, the corresponding column data will be updated.
......@@ -629,9 +632,11 @@ class FrameRef(MutableMapping):
The key.
val : Tensor or dict of tensors
The value.
inplace: bool
If True, update will be done in place
"""
if isinstance(key, str):
self.update_column(key, val, inplace=False)
self.update_column(key, val, inplace=inplace)
elif isinstance(key, slice) and key == slice(0, self.num_rows):
# shortcut for updating all the rows
return self.update(val)
......@@ -639,7 +644,7 @@ class FrameRef(MutableMapping):
# shortcut for selecting all the rows
return self.update(val)
else:
self.update_rows(key, val, inplace=False)
self.update_rows(key, val, inplace=inplace)
def update_column(self, name, data, inplace):
"""Update the column.
......
......@@ -1163,8 +1163,8 @@ class DGLGraph(object):
and (D1, D2, ...) be the shape of the node representation tensor. The
length of the given node ids must match B (i.e, len(u) == B).
All update will be done out-placely to work with autograd unless the inplace
flag is true.
All update will be done out of place to work with autograd unless the
inplace flag is true.
Parameters
----------
......@@ -1173,7 +1173,7 @@ class DGLGraph(object):
u : node, container or tensor
The node(s).
inplace : bool
True if the update is done inplacely
If True, update will be done in place, but autograd will break.
"""
# sanity check
if not utils.is_dict_like(hu):
......@@ -1241,8 +1241,8 @@ class DGLGraph(object):
is of shape (B, D1, D2, ...), where B is the number of edges to be updated,
and (D1, D2, ...) be the shape of the edge representation tensor.
All update will be done out-placely to work with autograd unless the inplace
flag is true.
All update will be done out of place to work with autograd unless the
inplace flag is true.
Parameters
----------
......@@ -1252,7 +1252,7 @@ class DGLGraph(object):
Edges can be a pair of endpoint nodes (u, v), or a
tensor of edge ids. The default value is all the edges.
inplace : bool
True if the update is done inplacely
If True, update will be done in place, but autograd will break.
"""
# parse argument
if is_all(edges):
......@@ -1390,6 +1390,8 @@ class DGLGraph(object):
The UDF applied on the node features.
v : int, iterable of int, tensor, optional
The node id(s).
inplace: bool, optional
If True, update will be done in place, but autograd will break.
"""
if func == "default":
func = self._apply_node_func
......@@ -1398,10 +1400,13 @@ class DGLGraph(object):
else:
v = utils.toindex(v)
with ir.prog() as prog:
scheduler.schedule_apply_nodes(graph=self, v=v, apply_func=func)
scheduler.schedule_apply_nodes(graph=self,
v=v,
apply_func=func,
inplace=inplace)
Runtime.run(prog)
def apply_edges(self, func="default", edges=ALL):
def apply_edges(self, func="default", edges=ALL, inplace=False):
"""Apply the function on the edge features.
Parameters
......@@ -1411,6 +1416,8 @@ class DGLGraph(object):
edges : edges, optional
Edges can be a pair of endpoint nodes (u, v), or a
tensor of edge ids. The default value is all the edges.
inplace: bool, optional
If True, update will be done in place, but autograd will break.
Notes
-----
......@@ -1422,7 +1429,8 @@ class DGLGraph(object):
assert func is not None
if is_all(edges):
u, v, eid = self._graph.edges()
u, v, _ = self._graph.edges()
eid = utils.toindex(slice(0, self.number_of_edges()))
elif isinstance(edges, tuple):
u, v = edges
u = utils.toindex(u)
......@@ -1434,8 +1442,12 @@ class DGLGraph(object):
u, v, _ = self._graph.find_edges(eid)
with ir.prog() as prog:
scheduler.schedule_apply_edges(graph=self, u=u, v=v,
eid=eid, apply_func=func)
scheduler.schedule_apply_edges(graph=self,
u=u,
v=v,
eid=eid,
apply_func=func,
inplace=inplace)
Runtime.run(prog)
def send(self, edges, message_func="default"):
......@@ -1481,7 +1493,8 @@ class DGLGraph(object):
def recv(self,
v,
reduce_func="default",
apply_node_func="default"):
apply_node_func="default",
inplace=False):
"""Receive and reduce in-coming messages and update representation on node v.
TODO(minjie): document on zero-in-degree case
......@@ -1496,6 +1509,8 @@ class DGLGraph(object):
The reduce function.
apply_node_func : callable, optional
The update function.
inplace: bool, optional
If True, update will be done in place, but autograd will break.
"""
if reduce_func == "default":
reduce_func = self._reduce_func
......@@ -1518,8 +1533,11 @@ class DGLGraph(object):
return
with ir.prog() as prog:
scheduler.schedule_recv(graph=self, recv_nodes=v,
reduce_func=reduce_func, apply_func=apply_node_func)
scheduler.schedule_recv(graph=self,
recv_nodes=v,
reduce_func=reduce_func,
apply_func=apply_node_func,
inplace=inplace)
Runtime.run(prog)
# FIXME(minjie): multi send bug
......@@ -1529,7 +1547,8 @@ class DGLGraph(object):
edges,
message_func="default",
reduce_func="default",
apply_node_func="default"):
apply_node_func="default",
inplace=False):
"""Send messages along edges and receive them on the targets.
Parameters
......@@ -1546,6 +1565,8 @@ class DGLGraph(object):
apply_node_func : callable, optional
The update function. Registered function will be used if not
specified.
inplace: bool, optional
If True, update will be done in place, but autograd will break.
Notes
-----
......@@ -1577,15 +1598,20 @@ class DGLGraph(object):
return
with ir.prog() as prog:
scheduler.schedule_snr(self, (u, v, eid),
message_func, reduce_func, apply_node_func)
scheduler.schedule_snr(graph=self,
edge_tuples=(u, v, eid),
message_func=message_func,
reduce_func=reduce_func,
apply_func=apply_node_func,
inplace=inplace)
Runtime.run(prog)
def pull(self,
v,
message_func="default",
reduce_func="default",
apply_node_func="default"):
apply_node_func="default",
inplace=False):
"""Pull messages from the node's predecessors and then update it.
Parameters
......@@ -1598,6 +1624,8 @@ class DGLGraph(object):
The reduce function.
apply_node_func : callable, optional
The update function.
inplace: bool, optional
If True, update will be done in place, but autograd will break.
"""
if message_func == "default":
message_func = self._message_func
......@@ -1613,16 +1641,20 @@ class DGLGraph(object):
if len(v) == 0:
return
with ir.prog() as prog:
scheduler.schedule_pull(graph=self, pull_nodes=v,
message_func=message_func, reduce_func=reduce_func,
apply_func=apply_node_func)
scheduler.schedule_pull(graph=self,
pull_nodes=v,
message_func=message_func,
reduce_func=reduce_func,
apply_func=apply_node_func,
inplace=inplace)
Runtime.run(prog)
def push(self,
u,
message_func="default",
reduce_func="default",
apply_node_func="default"):
apply_node_func="default",
inplace=False):
"""Send message from the node to its successors and update them.
Parameters
......@@ -1635,6 +1667,8 @@ class DGLGraph(object):
The reduce function.
apply_node_func : callable
The update function.
inplace: bool, optional
If True, update will be done in place, but autograd will break.
"""
if message_func == "default":
message_func = self._message_func
......@@ -1650,9 +1684,12 @@ class DGLGraph(object):
if len(u) == 0:
return
with ir.prog() as prog:
scheduler.schedule_push(graph=self, u=u,
message_func=message_func, reduce_func=reduce_func,
apply_func=apply_node_func)
scheduler.schedule_push(graph=self,
u=u,
message_func=message_func,
reduce_func=reduce_func,
apply_func=apply_node_func,
inplace=inplace)
Runtime.run(prog)
def update_all(self,
......@@ -1680,8 +1717,10 @@ class DGLGraph(object):
assert reduce_func is not None
with ir.prog() as prog:
scheduler.schedule_update_all(graph=self, message_func=message_func,
reduce_func=reduce_func, apply_func=apply_node_func)
scheduler.schedule_update_all(graph=self,
message_func=message_func,
reduce_func=reduce_func,
apply_func=apply_node_func)
Runtime.run(prog)
def prop_nodes(self,
......
......@@ -33,6 +33,7 @@ class OpCode(object):
WRITE_ROW_ = 23
WRITE_DICT_ = 24
APPEND_ROW_ = 25
WRITE_ROW_INPLACE_ = 26
class Executor(object):
@abstractmethod
......@@ -553,6 +554,38 @@ def WRITE_ROW_(fd, row, val):
reg = IR_REGISTRY[OpCode.WRITE_ROW_]
get_current_prog().issue(reg['executor_cls'](fd, row, val))
class WriteRowInplace_Executor(Executor):
def __init__(self, fd, row, val):
self.fd = fd
self.row = row
self.val = val
def opcode(self):
return OpCode.WRITE_ROW_INPLACE_
def arg_vars(self):
return [self.fd, self.row, self.val]
def ret_var(self):
return None
def run(self):
fd_data = self.fd.data # feature dict
row_data = self.row.data # idx
val_data = self.val.data
fd_data.set_item_inplace(row_data, val_data, inplace=True)
IR_REGISTRY[OpCode.WRITE_ROW_INPLACE_] = {
'name' : 'WRITE_ROW_INPLACE_',
'args_type' : [VarType.FEAT_DICT, VarType.IDX, VarType.FEAT_DICT],
'ret_type' : None,
'executor_cls' : WriteRowInplace_Executor,
}
def WRITE_ROW_INPLACE_(fd, row, val):
reg = IR_REGISTRY[OpCode.WRITE_ROW_INPLACE_]
get_current_prog().issue(reg['executor_cls'](fd, row, val))
class WriteDict_Executor(Executor):
def __init__(self, fd1, fd2):
self.fd1 = fd1
......
......@@ -54,25 +54,31 @@ def schedule_send(graph, u, v, eid, message_func):
# TODO: handle duplicate messages
ir.APPEND_ROW_(mf, msg)
def schedule_recv(graph, recv_nodes, reduce_func, apply_func):
def schedule_recv(graph,
recv_nodes,
reduce_func,
apply_func,
inplace):
"""Schedule recv.
Parameters
----------
graph: DGLGraph
The DGLGraph to use
v : utils.Index
recv_nodes: utils.Index
Nodes to recv.
reduce_func: callable or list of callable
The reduce function
apply_func: callable
The apply node function
inplace: bool
If True, the update will be done in place
"""
src, dst, mid = graph._msg_graph.in_edges(recv_nodes)
if len(mid) == 0:
# All recv nodes are 0-degree nodes; downgrade to apply nodes.
if apply_func is not None:
schedule_apply_nodes(graph, recv_nodes, apply_func)
schedule_apply_nodes(graph, recv_nodes, apply_func, inplace)
else:
var_nf = var.FEAT_DICT(graph._node_frame, name='nf')
# sort and unique the argument
......@@ -83,13 +89,35 @@ def schedule_recv(graph, recv_nodes, reduce_func, apply_func):
reduced_feat = _gen_reduce(graph, reduce_func, (src, dst, mid), recv_nodes)
# apply
final_feat = _apply_with_accum(graph, var_recv_nodes, var_nf, reduced_feat, apply_func)
ir.WRITE_ROW_(var_nf, var_recv_nodes, final_feat)
if inplace:
ir.WRITE_ROW_INPLACE_(var_nf, var_recv_nodes, final_feat)
else:
ir.WRITE_ROW_(var_nf, var_recv_nodes, final_feat)
def schedule_snr(graph,
edge_tuples,
message_func,
reduce_func,
apply_func):
apply_func,
inplace):
"""Schedule send_and_recv.
Parameters
----------
graph: DGLGraph
The DGLGraph to use
edge_tuple: tuple
A tuple of (src ids, dst ids, edge ids) representing edges to perform
send_and_recv
message_func: callable or list of callable
The message function
reduce_func: callable or list of callable
The reduce function
apply_func: callable
The apply node function
inplace: bool
If True, the update will be done in place
"""
call_type = 'send_and_recv'
u, v, eid = edge_tuples
recv_nodes, _ = F.sort_1d(F.unique(v.tousertensor()))
......@@ -110,9 +138,15 @@ def schedule_snr(graph,
uv_getter, adj_creator, inc_creator)
# generate apply schedule
final_feat = _apply_with_accum(graph, var_recv_nodes, var_nf, reduced_feat, apply_func)
ir.WRITE_ROW_(var_nf, var_recv_nodes, final_feat)
if inplace:
ir.WRITE_ROW_INPLACE_(var_nf, var_recv_nodes, final_feat)
else:
ir.WRITE_ROW_(var_nf, var_recv_nodes, final_feat)
def schedule_update_all(graph, message_func, reduce_func, apply_func):
def schedule_update_all(graph,
message_func,
reduce_func,
apply_func):
"""get send and recv schedule
Parameters
......@@ -130,7 +164,7 @@ def schedule_update_all(graph, message_func, reduce_func, apply_func):
# All the nodes are zero degree; downgrade to apply nodes
if apply_func is not None:
nodes = utils.toindex(slice(0, graph.number_of_nodes()))
schedule_apply_nodes(graph, nodes, apply_func)
schedule_apply_nodes(graph, nodes, apply_func, inplace=False)
else:
call_type = 'update_all'
eid = utils.toindex(slice(0, graph.number_of_edges())) # shortcut for ALL
......@@ -153,7 +187,10 @@ def schedule_update_all(graph, message_func, reduce_func, apply_func):
final_feat = _apply_with_accum(graph, var_recv_nodes, var_nf, reduced_feat, apply_func)
ir.WRITE_DICT_(var_nf, final_feat)
def schedule_apply_nodes(graph, v, apply_func):
def schedule_apply_nodes(graph,
v,
apply_func,
inplace):
"""get apply nodes schedule
Parameters
......@@ -164,6 +201,8 @@ def schedule_apply_nodes(graph, v, apply_func):
Nodes to apply
apply_func: callable
The apply node function
inplace: bool
If True, the update will be done in place
Returns
-------
......@@ -177,9 +216,15 @@ def schedule_apply_nodes(graph, v, apply_func):
return apply_func(nb)
afunc = var.FUNC(_afunc_wrapper)
applied_feat = ir.NODE_UDF(afunc, v_nf)
ir.WRITE_ROW_(var_nf, var_v, applied_feat)
if inplace:
ir.WRITE_ROW_INPLACE_(var_nf, var_v, applied_feat)
else:
ir.WRITE_ROW_(var_nf, var_v, applied_feat)
def schedule_apply_edges(graph, u, v, eid, apply_func):
def schedule_apply_edges(graph,
u, v, eid,
apply_func,
inplace):
"""get apply edges schedule
Parameters
......@@ -194,6 +239,8 @@ def schedule_apply_edges(graph, u, v, eid, apply_func):
Ids of sending edges
apply_func: callable
The apply edge function
inplace: bool
If True, the update will be done in place
Returns
-------
......@@ -215,9 +262,17 @@ def schedule_apply_edges(graph, u, v, eid, apply_func):
return apply_func(eb)
_efunc = var.FUNC(_efunc_wrapper)
new_fdedge = ir.EDGE_UDF(_efunc, fdsrc, fdedge, fddst)
ir.WRITE_ROW_(var_ef, var_eid, new_fdedge)
def schedule_push(graph, u, message_func, reduce_func, apply_func):
if inplace:
ir.WRITE_ROW_INPLACE_(var_ef, var_eid, new_fdedge)
else:
ir.WRITE_ROW_(var_ef, var_eid, new_fdedge)
def schedule_push(graph,
u,
message_func,
reduce_func,
apply_func,
inplace):
"""get push schedule
Parameters
......@@ -232,14 +287,22 @@ def schedule_push(graph, u, message_func, reduce_func, apply_func):
The reduce function
apply_func: callable
The apply node function
inplace: bool
If True, the update will be done in place
"""
u, v, eid = graph._graph.out_edges(u)
if len(eid) == 0:
# All the pushing nodes have no out edges. No computation is scheduled.
return
schedule_snr(graph, (u, v, eid), message_func, reduce_func, apply_func)
def schedule_pull(graph, pull_nodes, message_func, reduce_func, apply_func):
schedule_snr(graph, (u, v, eid),
message_func, reduce_func, apply_func, inplace)
def schedule_pull(graph,
pull_nodes,
message_func,
reduce_func,
apply_func,
inplace):
"""get pull schedule
Parameters
......@@ -254,6 +317,8 @@ def schedule_pull(graph, pull_nodes, message_func, reduce_func, apply_func):
The reduce function
apply_func: callable
The apply node function
inplace: bool
If True, the update will be done in place
"""
# TODO(minjie): `in_edges` can be omitted if message and reduce func pairs
# can be specialized to SPMV. This needs support for creating adjmat
......@@ -262,7 +327,7 @@ def schedule_pull(graph, pull_nodes, message_func, reduce_func, apply_func):
if len(eid) == 0:
# All the nodes are 0deg; downgrades to apply.
if apply_func is not None:
schedule_apply_nodes(graph, pull_nodes, apply_func)
schedule_apply_nodes(graph, pull_nodes, apply_func, inplace)
else:
call_type = 'send_and_recv'
pull_nodes, _ = F.sort_1d(F.unique(pull_nodes.tousertensor()))
......@@ -283,7 +348,10 @@ def schedule_pull(graph, pull_nodes, message_func, reduce_func, apply_func):
uv_getter, adj_creator, inc_creator)
# generate optional apply
final_feat = _apply_with_accum(graph, var_pull_nodes, var_nf, reduced_feat, apply_func)
ir.WRITE_ROW_(var_nf, var_pull_nodes, final_feat)
if inplace:
ir.WRITE_ROW_INPLACE_(var_nf, var_pull_nodes, final_feat)
else:
ir.WRITE_ROW_(var_nf, var_pull_nodes, final_feat)
def _check_builtin_func_list(func_list):
"""Check whether func_list only contains builtin functions."""
......@@ -370,7 +438,7 @@ def _gen_reduce(graph, reduce_func, edge_tuples, recv_nodes):
# vars
msg = var.FEAT_DICT(graph._msg_frame, 'msg')
nf = var.FEAT_DICT(graph._node_frame, 'nf')
out = var.FEAT_DICT(data=tmpframe)
out = var.FEAT_DICT(data=tmpframe)
if rfunc_is_list:
# UDF message + builtin reducer
......@@ -469,7 +537,7 @@ def _gen_send_reduce(
# fall through from the v2v spmv analysis.
# In both cases, convert the mfunc to UDF.
mfunc = BundledFunction(mfunc)
# generate UDF send schedule
var_u, var_v = uv_getter()
var_mf = _gen_send(graph, var_nf, var_ef, var_u, var_v, var_eid, mfunc)
......
import torch as th
import numpy as np
import scipy.sparse as sp
import dgl
import dgl.function as fn
import utils as U
D = 5
def generate_graph():
g = dgl.DGLGraph()
g.add_nodes(10)
# create a graph where 0 is the source and 9 is the sink
for i in range(1, 9):
g.add_edge(0, i)
g.add_edge(i, 9)
# add a back flow from 9 to 0
g.add_edge(9, 0)
g.ndata['f'] = th.randn(10, D)
g.edata['e'] = th.randn(17, D)
return g
def test_inplace_recv():
u = th.tensor([0, 0, 0, 3, 4, 9])
v = th.tensor([1, 2, 3, 9, 9, 0])
def message_func(edges):
return {'m' : edges.src['f'] + edges.dst['f']}
def reduce_func(nodes):
return {'f' : th.sum(nodes.mailbox['m'], 1)}
def apply_func(nodes):
return {'f' : 2 * nodes.data['f']}
def _test(apply_func):
g = generate_graph()
f = g.ndata['f']
# one out place run to get result
g.send((u, v), message_func)
g.recv([0,1,2,3,9], reduce_func, apply_func)
result = g.get_n_repr()['f']
# inplace deg bucket run
v1 = f.clone()
g.ndata['f'] = v1
g.send((u, v), message_func)
g.recv([0,1,2,3,9], reduce_func, apply_func, inplace=True)
r1 = g.get_n_repr()['f']
# check result
assert U.allclose(r1, result)
# check inplace
assert U.allclose(v1, r1)
# inplace e2v
v1 = f.clone()
g.ndata['f'] = v1
g.send((u, v), message_func)
g.recv([0,1,2,3,9], fn.sum(msg='m', out='f'), apply_func, inplace=True)
r1 = g.ndata['f']
# check result
assert U.allclose(r1, result)
# check inplace
assert U.allclose(v1, r1)
# test send_and_recv with apply_func
_test(apply_func)
# test send_and_recv without apply_func
_test(None)
def test_inplace_snr():
u = th.tensor([0, 0, 0, 3, 4, 9])
v = th.tensor([1, 2, 3, 9, 9, 0])
def message_func(edges):
return {'m' : edges.src['f']}
def reduce_func(nodes):
return {'f' : th.sum(nodes.mailbox['m'], 1)}
def apply_func(nodes):
return {'f' : 2 * nodes.data['f']}
def _test(apply_func):
g = generate_graph()
f = g.ndata['f']
# an out place run to get result
g.send_and_recv((u, v), fn.copy_src(src='f', out='m'),
fn.sum(msg='m', out='f'), apply_func)
result = g.ndata['f']
# inplace deg bucket
v1 = f.clone()
g.ndata['f'] = v1
g.send_and_recv((u, v), message_func, reduce_func, apply_func, inplace=True)
r1 = g.ndata['f']
# check result
assert U.allclose(r1, result)
# check inplace
assert U.allclose(v1, r1)
# inplace v2v spmv
v1 = f.clone()
g.ndata['f'] = v1
g.send_and_recv((u, v), fn.copy_src(src='f', out='m'),
fn.sum(msg='m', out='f'), apply_func, inplace=True)
r1 = g.ndata['f']
# check result
assert U.allclose(r1, result)
# check inplace
assert U.allclose(v1, r1)
# inplace e2v spmv
v1 = f.clone()
g.ndata['f'] = v1
g.send_and_recv((u, v), message_func,
fn.sum(msg='m', out='f'), apply_func, inplace=True)
r1 = g.ndata['f']
# check result
assert U.allclose(r1, result)
# check inplace
assert U.allclose(v1, r1)
# test send_and_recv with apply_func
_test(apply_func)
# test send_and_recv without apply_func
_test(None)
def test_inplace_push():
nodes = th.tensor([0, 3, 4, 9])
def message_func(edges):
return {'m' : edges.src['f']}
def reduce_func(nodes):
return {'f' : th.sum(nodes.mailbox['m'], 1)}
def apply_func(nodes):
return {'f' : 2 * nodes.data['f']}
def _test(apply_func):
g = generate_graph()
f = g.ndata['f']
# an out place run to get result
g.push(nodes,
fn.copy_src(src='f', out='m'), fn.sum(msg='m', out='f'), apply_func)
result = g.ndata['f']
# inplace deg bucket
v1 = f.clone()
g.ndata['f'] = v1
g.push(nodes, message_func, reduce_func, apply_func, inplace=True)
r1 = g.ndata['f']
# check result
assert U.allclose(r1, result)
# check inplace
assert U.allclose(v1, r1)
# inplace v2v spmv
v1 = f.clone()
g.ndata['f'] = v1
g.push(nodes, fn.copy_src(src='f', out='m'),
fn.sum(msg='m', out='f'), apply_func, inplace=True)
r1 = g.ndata['f']
# check result
assert U.allclose(r1, result)
# check inplace
assert U.allclose(v1, r1)
# inplace e2v spmv
v1 = f.clone()
g.ndata['f'] = v1
g.push(nodes,
message_func, fn.sum(msg='m', out='f'), apply_func, inplace=True)
r1 = g.ndata['f']
# check result
assert U.allclose(r1, result)
# check inplace
assert U.allclose(v1, r1)
# test send_and_recv with apply_func
_test(apply_func)
# test send_and_recv without apply_func
_test(None)
def test_inplace_pull():
nodes = th.tensor([1, 2, 3, 9])
def message_func(edges):
return {'m' : edges.src['f']}
def reduce_func(nodes):
return {'f' : th.sum(nodes.mailbox['m'], 1)}
def apply_func(nodes):
return {'f' : 2 * nodes.data['f']}
def _test(apply_func):
g = generate_graph()
f = g.ndata['f']
# an out place run to get result
g.pull(nodes,
fn.copy_src(src='f', out='m'), fn.sum(msg='m', out='f'), apply_func)
result = g.ndata['f']
# inplace deg bucket
v1 = f.clone()
g.ndata['f'] = v1
g.pull(nodes, message_func, reduce_func, apply_func, inplace=True)
r1 = g.ndata['f']
# check result
assert U.allclose(r1, result)
# check inplace
assert U.allclose(v1, r1)
# inplace v2v spmv
v1 = f.clone()
g.ndata['f'] = v1
g.pull(nodes, fn.copy_src(src='f', out='m'),
fn.sum(msg='m', out='f'), apply_func, inplace=True)
r1 = g.ndata['f']
# check result
assert U.allclose(r1, result)
# check inplace
assert U.allclose(v1, r1)
# inplace e2v spmv
v1 = f.clone()
g.ndata['f'] = v1
g.pull(nodes,
message_func, fn.sum(msg='m', out='f'), apply_func, inplace=True)
r1 = g.ndata['f']
# check result
assert U.allclose(r1, result)
# check inplace
assert U.allclose(v1, r1)
# test send_and_recv with apply_func
_test(apply_func)
# test send_and_recv without apply_func
_test(None)
def test_inplace_apply():
def apply_node_func(nodes):
return {'f': nodes.data['f'] * 2}
def apply_edge_func(edges):
return {'e': edges.data['e'] * 2}
g = generate_graph()
nodes = [1, 2, 3, 9]
nf = g.ndata['f']
# out place run
g.apply_nodes(apply_node_func, nodes)
new_nf = g.ndata['f']
# in place run
g.ndata['f'] = nf
g.apply_nodes(apply_node_func, nodes, inplace=True)
# check results correct and in place
assert U.allclose(nf, new_nf)
# test apply all nodes, should not be done in place
g.ndata['f'] = nf
g.apply_nodes(apply_node_func, inplace=True)
assert U.allclose(nf, g.ndata['f']) == False
edges = [3, 5, 7, 10]
ef = g.edata['e']
# out place run
g.apply_edges(apply_edge_func, edges)
new_ef = g.edata['e']
# in place run
g.edata['e'] = ef
g.apply_edges(apply_edge_func, edges, inplace=True)
g.edata['e'] = ef
assert U.allclose(ef, new_ef)
# test apply all edges, should not be done in place
g.edata['e'] == ef
g.apply_edges(apply_edge_func, inplace=True)
assert U.allclose(ef, g.edata['e']) == False
if __name__ == '__main__':
test_inplace_recv()
test_inplace_snr()
test_inplace_push()
test_inplace_pull()
test_inplace_apply()
......@@ -274,7 +274,7 @@ def test_e2v_update_all_multi_fn():
apply_func_2)
v3 = g.get_n_repr()[fld]
assert th.allclose(v2, v3)
assert U.allclose(v2, v3)
# test 1d node features
_test('f1')
......@@ -312,7 +312,7 @@ def test_e2v_snr_multi_fn():
apply_func_2)
v3 = g.get_n_repr()[fld]
assert th.allclose(v2, v3)
assert U.allclose(v2, v3)
# test 1d node features
_test('f1')
......@@ -352,7 +352,7 @@ def test_e2v_recv_multi_fn():
apply_func_2)
v3 = g.get_n_repr()[fld]
assert th.allclose(v2, v3)
assert U.allclose(v2, v3)
# test 1d node features
_test('f1')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment