"docs/git@developer.sourcefind.cn:OpenDAS/torchaudio.git" did not exist on "454418d2a6da6ebd5ad85e9d4b1c09ea69531ed7"
Unverified Commit 14af8402 authored by Da Zheng's avatar Da Zheng Committed by GitHub
Browse files

[Perf] lazily create msg_index. (#563)

* lazily create msg_index.

* update test.
parent de54891f
......@@ -910,7 +910,7 @@ class DGLGraph(DGLBaseGraph):
self._edge_frame = edge_frame
# message indicator:
# if self._msg_index[eid] == 1, then edge eid has message
self._msg_index = utils.zero_index(size=self.number_of_edges())
self._msg_index = None
# message frame
self._msg_frame = FrameRef(Frame(num_rows=self.number_of_edges()))
# set initializer for message frame
......@@ -921,6 +921,14 @@ class DGLGraph(DGLBaseGraph):
self._apply_node_func = None
self._apply_edge_func = None
def _get_msg_index(self):
if self._msg_index is None:
self._msg_index = utils.zero_index(size=self.number_of_edges())
return self._msg_index
def _set_msg_index(self, index):
self._msg_index = index
def add_nodes(self, num, data=None):
"""Add multiple new nodes.
......@@ -1026,7 +1034,8 @@ class DGLGraph(DGLBaseGraph):
else:
self._edge_frame.append(data)
# resize msg_index and msg_frame
self._msg_index = self._msg_index.append_zeros(1)
if self._msg_index is not None:
self._msg_index = self._msg_index.append_zeros(1)
self._msg_frame.add_rows(1)
def add_edges(self, u, v, data=None):
......@@ -1086,7 +1095,8 @@ class DGLGraph(DGLBaseGraph):
else:
self._edge_frame.append(data)
# initialize feature placeholder for messages
self._msg_index = self._msg_index.append_zeros(num)
if self._msg_index is not None:
self._msg_index = self._msg_index.append_zeros(num)
self._msg_frame.add_rows(num)
def clear(self):
......@@ -1111,7 +1121,7 @@ class DGLGraph(DGLBaseGraph):
self._graph.clear()
self._node_frame.clear()
self._edge_frame.clear()
self._msg_index = utils.zero_index(0)
self._msg_index = None
self._msg_frame.clear()
def clear_cache(self):
......@@ -1218,7 +1228,6 @@ class DGLGraph(DGLBaseGraph):
self._graph.from_networkx(nx_graph)
self._node_frame.add_rows(self.number_of_nodes())
self._edge_frame.add_rows(self.number_of_edges())
self._msg_index = utils.zero_index(self.number_of_edges())
self._msg_frame.add_rows(self.number_of_edges())
# copy attributes
......@@ -1285,7 +1294,6 @@ class DGLGraph(DGLBaseGraph):
self._graph.from_scipy_sparse_matrix(spmat)
self._node_frame.add_rows(self.number_of_nodes())
self._edge_frame.add_rows(self.number_of_edges())
self._msg_index = utils.zero_index(self.number_of_edges())
self._msg_frame.add_rows(self.number_of_edges())
def node_attr_schemes(self):
......
......@@ -56,7 +56,7 @@ def schedule_send(graph, u, v, eid, message_func):
msg = _gen_send(graph, var_nf, var_nf, var_ef, var_u, var_v, var_eid, message_func)
ir.WRITE_ROW_(var_mf, var_eid, msg)
# set message indicator to 1
graph._msg_index = graph._msg_index.set_items(eid, 1)
graph._set_msg_index(graph._get_msg_index().set_items(eid, 1))
def schedule_recv(graph,
recv_nodes,
......@@ -80,7 +80,7 @@ def schedule_recv(graph,
"""
src, dst, eid = graph._graph.in_edges(recv_nodes)
if len(eid) > 0:
nonzero_idx = graph._msg_index.get_items(eid).nonzero()
nonzero_idx = graph._get_msg_index().get_items(eid).nonzero()
eid = eid.get_items(nonzero_idx)
src = src.get_items(nonzero_idx)
dst = dst.get_items(nonzero_idx)
......@@ -107,8 +107,8 @@ def schedule_recv(graph,
else:
ir.WRITE_ROW_(var_nf, var_recv_nodes, final_feat)
# set message indicator to 0
graph._msg_index = graph._msg_index.set_items(eid, 0)
if not graph._msg_index.has_nonzero():
graph._set_msg_index(graph._get_msg_index().set_items(eid, 0))
if not graph._get_msg_index().has_nonzero():
ir.CLEAR_FRAME_(var.FEAT_DICT(graph._msg_frame, name='mf'))
def schedule_snr(graph,
......
......@@ -64,7 +64,7 @@ def test_multi_send():
eid = g.edge_ids([0, 0, 0, 0, 0, 1, 2, 3, 4, 5],
[1, 2, 3, 4, 5, 9, 9, 9, 9, 9])
expected[eid] = 1
assert F.array_equal(g._msg_index.tousertensor(), expected)
assert F.array_equal(g._get_msg_index().tousertensor(), expected)
def test_multi_recv():
# basic recv test
......@@ -80,20 +80,20 @@ def test_multi_recv():
g.send((u, v))
eid = g.edge_ids(u, v)
expected[eid] = 1
assert F.array_equal(g._msg_index.tousertensor(), expected)
assert F.array_equal(g._get_msg_index().tousertensor(), expected)
g.recv(v)
expected[eid] = 0
assert F.array_equal(g._msg_index.tousertensor(), expected)
assert F.array_equal(g._get_msg_index().tousertensor(), expected)
u = [0]
v = [1, 2, 3]
g.send((u, v))
eid = g.edge_ids(u, v)
expected[eid] = 1
assert F.array_equal(g._msg_index.tousertensor(), expected)
assert F.array_equal(g._get_msg_index().tousertensor(), expected)
g.recv(v)
expected[eid] = 0
assert F.array_equal(g._msg_index.tousertensor(), expected)
assert F.array_equal(g._get_msg_index().tousertensor(), expected)
h1 = g.ndata['h']
......@@ -104,19 +104,19 @@ def test_multi_recv():
g.send((u, v))
eid = g.edge_ids(u, v)
expected[eid] = 1
assert F.array_equal(g._msg_index.tousertensor(), expected)
assert F.array_equal(g._get_msg_index().tousertensor(), expected)
u = [4, 5, 6]
v = [9]
g.recv(v)
eid = g.edge_ids(u, v)
expected[eid] = 0
assert F.array_equal(g._msg_index.tousertensor(), expected)
assert F.array_equal(g._get_msg_index().tousertensor(), expected)
u = [0]
v = [1, 2, 3]
g.recv(v)
eid = g.edge_ids(u, v)
expected[eid] = 0
assert F.array_equal(g._msg_index.tousertensor(), expected)
assert F.array_equal(g._get_msg_index().tousertensor(), expected)
h2 = g.ndata['h']
assert F.allclose(h1, h2)
......@@ -250,7 +250,7 @@ def test_dynamic_addition():
'h2': F.randn((2, D))})
g.send()
expected = F.ones((g.number_of_edges(),), dtype=F.int64)
assert F.array_equal(g._msg_index.tousertensor(), expected)
assert F.array_equal(g._get_msg_index().tousertensor(), expected)
# add more edges
g.add_edges([0, 2], [2, 0], {'h1': F.randn((2, D))})
......@@ -281,10 +281,10 @@ def test_recv_no_send():
g.send((1, 2), message_func)
expected = F.zeros((2,), dtype=F.int64)
expected[1] = 1
assert F.array_equal(g._msg_index.tousertensor(), expected)
assert F.array_equal(g._get_msg_index().tousertensor(), expected)
g.recv(2, reduce_func)
expected[1] = 0
assert F.array_equal(g._msg_index.tousertensor(), expected)
assert F.array_equal(g._get_msg_index().tousertensor(), expected)
def test_send_recv_after_conversion():
# test send and recv after converting from a graph with edges
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment