Unverified Commit 14af8402 authored by Da Zheng's avatar Da Zheng Committed by GitHub
Browse files

[Perf] lazily create msg_index. (#563)

* lazily create msg_index.

* update test.
parent de54891f
...@@ -910,7 +910,7 @@ class DGLGraph(DGLBaseGraph): ...@@ -910,7 +910,7 @@ class DGLGraph(DGLBaseGraph):
self._edge_frame = edge_frame self._edge_frame = edge_frame
# message indicator: # message indicator:
# if self._msg_index[eid] == 1, then edge eid has message # if self._msg_index[eid] == 1, then edge eid has message
self._msg_index = utils.zero_index(size=self.number_of_edges()) self._msg_index = None
# message frame # message frame
self._msg_frame = FrameRef(Frame(num_rows=self.number_of_edges())) self._msg_frame = FrameRef(Frame(num_rows=self.number_of_edges()))
# set initializer for message frame # set initializer for message frame
...@@ -921,6 +921,14 @@ class DGLGraph(DGLBaseGraph): ...@@ -921,6 +921,14 @@ class DGLGraph(DGLBaseGraph):
self._apply_node_func = None self._apply_node_func = None
self._apply_edge_func = None self._apply_edge_func = None
def _get_msg_index(self):
if self._msg_index is None:
self._msg_index = utils.zero_index(size=self.number_of_edges())
return self._msg_index
def _set_msg_index(self, index):
self._msg_index = index
def add_nodes(self, num, data=None): def add_nodes(self, num, data=None):
"""Add multiple new nodes. """Add multiple new nodes.
...@@ -1026,6 +1034,7 @@ class DGLGraph(DGLBaseGraph): ...@@ -1026,6 +1034,7 @@ class DGLGraph(DGLBaseGraph):
else: else:
self._edge_frame.append(data) self._edge_frame.append(data)
# resize msg_index and msg_frame # resize msg_index and msg_frame
if self._msg_index is not None:
self._msg_index = self._msg_index.append_zeros(1) self._msg_index = self._msg_index.append_zeros(1)
self._msg_frame.add_rows(1) self._msg_frame.add_rows(1)
...@@ -1086,6 +1095,7 @@ class DGLGraph(DGLBaseGraph): ...@@ -1086,6 +1095,7 @@ class DGLGraph(DGLBaseGraph):
else: else:
self._edge_frame.append(data) self._edge_frame.append(data)
# initialize feature placeholder for messages # initialize feature placeholder for messages
if self._msg_index is not None:
self._msg_index = self._msg_index.append_zeros(num) self._msg_index = self._msg_index.append_zeros(num)
self._msg_frame.add_rows(num) self._msg_frame.add_rows(num)
...@@ -1111,7 +1121,7 @@ class DGLGraph(DGLBaseGraph): ...@@ -1111,7 +1121,7 @@ class DGLGraph(DGLBaseGraph):
self._graph.clear() self._graph.clear()
self._node_frame.clear() self._node_frame.clear()
self._edge_frame.clear() self._edge_frame.clear()
self._msg_index = utils.zero_index(0) self._msg_index = None
self._msg_frame.clear() self._msg_frame.clear()
def clear_cache(self): def clear_cache(self):
...@@ -1218,7 +1228,6 @@ class DGLGraph(DGLBaseGraph): ...@@ -1218,7 +1228,6 @@ class DGLGraph(DGLBaseGraph):
self._graph.from_networkx(nx_graph) self._graph.from_networkx(nx_graph)
self._node_frame.add_rows(self.number_of_nodes()) self._node_frame.add_rows(self.number_of_nodes())
self._edge_frame.add_rows(self.number_of_edges()) self._edge_frame.add_rows(self.number_of_edges())
self._msg_index = utils.zero_index(self.number_of_edges())
self._msg_frame.add_rows(self.number_of_edges()) self._msg_frame.add_rows(self.number_of_edges())
# copy attributes # copy attributes
...@@ -1285,7 +1294,6 @@ class DGLGraph(DGLBaseGraph): ...@@ -1285,7 +1294,6 @@ class DGLGraph(DGLBaseGraph):
self._graph.from_scipy_sparse_matrix(spmat) self._graph.from_scipy_sparse_matrix(spmat)
self._node_frame.add_rows(self.number_of_nodes()) self._node_frame.add_rows(self.number_of_nodes())
self._edge_frame.add_rows(self.number_of_edges()) self._edge_frame.add_rows(self.number_of_edges())
self._msg_index = utils.zero_index(self.number_of_edges())
self._msg_frame.add_rows(self.number_of_edges()) self._msg_frame.add_rows(self.number_of_edges())
def node_attr_schemes(self): def node_attr_schemes(self):
......
...@@ -56,7 +56,7 @@ def schedule_send(graph, u, v, eid, message_func): ...@@ -56,7 +56,7 @@ def schedule_send(graph, u, v, eid, message_func):
msg = _gen_send(graph, var_nf, var_nf, var_ef, var_u, var_v, var_eid, message_func) msg = _gen_send(graph, var_nf, var_nf, var_ef, var_u, var_v, var_eid, message_func)
ir.WRITE_ROW_(var_mf, var_eid, msg) ir.WRITE_ROW_(var_mf, var_eid, msg)
# set message indicator to 1 # set message indicator to 1
graph._msg_index = graph._msg_index.set_items(eid, 1) graph._set_msg_index(graph._get_msg_index().set_items(eid, 1))
def schedule_recv(graph, def schedule_recv(graph,
recv_nodes, recv_nodes,
...@@ -80,7 +80,7 @@ def schedule_recv(graph, ...@@ -80,7 +80,7 @@ def schedule_recv(graph,
""" """
src, dst, eid = graph._graph.in_edges(recv_nodes) src, dst, eid = graph._graph.in_edges(recv_nodes)
if len(eid) > 0: if len(eid) > 0:
nonzero_idx = graph._msg_index.get_items(eid).nonzero() nonzero_idx = graph._get_msg_index().get_items(eid).nonzero()
eid = eid.get_items(nonzero_idx) eid = eid.get_items(nonzero_idx)
src = src.get_items(nonzero_idx) src = src.get_items(nonzero_idx)
dst = dst.get_items(nonzero_idx) dst = dst.get_items(nonzero_idx)
...@@ -107,8 +107,8 @@ def schedule_recv(graph, ...@@ -107,8 +107,8 @@ def schedule_recv(graph,
else: else:
ir.WRITE_ROW_(var_nf, var_recv_nodes, final_feat) ir.WRITE_ROW_(var_nf, var_recv_nodes, final_feat)
# set message indicator to 0 # set message indicator to 0
graph._msg_index = graph._msg_index.set_items(eid, 0) graph._set_msg_index(graph._get_msg_index().set_items(eid, 0))
if not graph._msg_index.has_nonzero(): if not graph._get_msg_index().has_nonzero():
ir.CLEAR_FRAME_(var.FEAT_DICT(graph._msg_frame, name='mf')) ir.CLEAR_FRAME_(var.FEAT_DICT(graph._msg_frame, name='mf'))
def schedule_snr(graph, def schedule_snr(graph,
......
...@@ -64,7 +64,7 @@ def test_multi_send(): ...@@ -64,7 +64,7 @@ def test_multi_send():
eid = g.edge_ids([0, 0, 0, 0, 0, 1, 2, 3, 4, 5], eid = g.edge_ids([0, 0, 0, 0, 0, 1, 2, 3, 4, 5],
[1, 2, 3, 4, 5, 9, 9, 9, 9, 9]) [1, 2, 3, 4, 5, 9, 9, 9, 9, 9])
expected[eid] = 1 expected[eid] = 1
assert F.array_equal(g._msg_index.tousertensor(), expected) assert F.array_equal(g._get_msg_index().tousertensor(), expected)
def test_multi_recv(): def test_multi_recv():
# basic recv test # basic recv test
...@@ -80,20 +80,20 @@ def test_multi_recv(): ...@@ -80,20 +80,20 @@ def test_multi_recv():
g.send((u, v)) g.send((u, v))
eid = g.edge_ids(u, v) eid = g.edge_ids(u, v)
expected[eid] = 1 expected[eid] = 1
assert F.array_equal(g._msg_index.tousertensor(), expected) assert F.array_equal(g._get_msg_index().tousertensor(), expected)
g.recv(v) g.recv(v)
expected[eid] = 0 expected[eid] = 0
assert F.array_equal(g._msg_index.tousertensor(), expected) assert F.array_equal(g._get_msg_index().tousertensor(), expected)
u = [0] u = [0]
v = [1, 2, 3] v = [1, 2, 3]
g.send((u, v)) g.send((u, v))
eid = g.edge_ids(u, v) eid = g.edge_ids(u, v)
expected[eid] = 1 expected[eid] = 1
assert F.array_equal(g._msg_index.tousertensor(), expected) assert F.array_equal(g._get_msg_index().tousertensor(), expected)
g.recv(v) g.recv(v)
expected[eid] = 0 expected[eid] = 0
assert F.array_equal(g._msg_index.tousertensor(), expected) assert F.array_equal(g._get_msg_index().tousertensor(), expected)
h1 = g.ndata['h'] h1 = g.ndata['h']
...@@ -104,19 +104,19 @@ def test_multi_recv(): ...@@ -104,19 +104,19 @@ def test_multi_recv():
g.send((u, v)) g.send((u, v))
eid = g.edge_ids(u, v) eid = g.edge_ids(u, v)
expected[eid] = 1 expected[eid] = 1
assert F.array_equal(g._msg_index.tousertensor(), expected) assert F.array_equal(g._get_msg_index().tousertensor(), expected)
u = [4, 5, 6] u = [4, 5, 6]
v = [9] v = [9]
g.recv(v) g.recv(v)
eid = g.edge_ids(u, v) eid = g.edge_ids(u, v)
expected[eid] = 0 expected[eid] = 0
assert F.array_equal(g._msg_index.tousertensor(), expected) assert F.array_equal(g._get_msg_index().tousertensor(), expected)
u = [0] u = [0]
v = [1, 2, 3] v = [1, 2, 3]
g.recv(v) g.recv(v)
eid = g.edge_ids(u, v) eid = g.edge_ids(u, v)
expected[eid] = 0 expected[eid] = 0
assert F.array_equal(g._msg_index.tousertensor(), expected) assert F.array_equal(g._get_msg_index().tousertensor(), expected)
h2 = g.ndata['h'] h2 = g.ndata['h']
assert F.allclose(h1, h2) assert F.allclose(h1, h2)
...@@ -250,7 +250,7 @@ def test_dynamic_addition(): ...@@ -250,7 +250,7 @@ def test_dynamic_addition():
'h2': F.randn((2, D))}) 'h2': F.randn((2, D))})
g.send() g.send()
expected = F.ones((g.number_of_edges(),), dtype=F.int64) expected = F.ones((g.number_of_edges(),), dtype=F.int64)
assert F.array_equal(g._msg_index.tousertensor(), expected) assert F.array_equal(g._get_msg_index().tousertensor(), expected)
# add more edges # add more edges
g.add_edges([0, 2], [2, 0], {'h1': F.randn((2, D))}) g.add_edges([0, 2], [2, 0], {'h1': F.randn((2, D))})
...@@ -281,10 +281,10 @@ def test_recv_no_send(): ...@@ -281,10 +281,10 @@ def test_recv_no_send():
g.send((1, 2), message_func) g.send((1, 2), message_func)
expected = F.zeros((2,), dtype=F.int64) expected = F.zeros((2,), dtype=F.int64)
expected[1] = 1 expected[1] = 1
assert F.array_equal(g._msg_index.tousertensor(), expected) assert F.array_equal(g._get_msg_index().tousertensor(), expected)
g.recv(2, reduce_func) g.recv(2, reduce_func)
expected[1] = 0 expected[1] = 0
assert F.array_equal(g._msg_index.tousertensor(), expected) assert F.array_equal(g._get_msg_index().tousertensor(), expected)
def test_send_recv_after_conversion(): def test_send_recv_after_conversion():
# test send and recv after converting from a graph with edges # test send and recv after converting from a graph with edges
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment