Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
14af8402
Unverified
Commit
14af8402
authored
May 24, 2019
by
Da Zheng
Committed by
GitHub
May 24, 2019
Browse files
[Perf] lazily create msg_index. (#563)
* lazily create msg_index. * update test.
parent
de54891f
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
29 additions
and
21 deletions
+29
-21
python/dgl/graph.py
python/dgl/graph.py
+14
-6
python/dgl/runtime/scheduler.py
python/dgl/runtime/scheduler.py
+4
-4
tests/compute/test_multi_send_recv.py
tests/compute/test_multi_send_recv.py
+11
-11
No files found.
python/dgl/graph.py
View file @
14af8402
...
...
@@ -910,7 +910,7 @@ class DGLGraph(DGLBaseGraph):
self
.
_edge_frame
=
edge_frame
# message indicator:
# if self._msg_index[eid] == 1, then edge eid has message
self
.
_msg_index
=
utils
.
zero_index
(
size
=
self
.
number_of_edges
())
self
.
_msg_index
=
None
# message frame
self
.
_msg_frame
=
FrameRef
(
Frame
(
num_rows
=
self
.
number_of_edges
()))
# set initializer for message frame
...
...
@@ -921,6 +921,14 @@ class DGLGraph(DGLBaseGraph):
self
.
_apply_node_func
=
None
self
.
_apply_edge_func
=
None
def
_get_msg_index
(
self
):
if
self
.
_msg_index
is
None
:
self
.
_msg_index
=
utils
.
zero_index
(
size
=
self
.
number_of_edges
())
return
self
.
_msg_index
def
_set_msg_index
(
self
,
index
):
self
.
_msg_index
=
index
def
add_nodes
(
self
,
num
,
data
=
None
):
"""Add multiple new nodes.
...
...
@@ -1026,7 +1034,8 @@ class DGLGraph(DGLBaseGraph):
else
:
self
.
_edge_frame
.
append
(
data
)
# resize msg_index and msg_frame
self
.
_msg_index
=
self
.
_msg_index
.
append_zeros
(
1
)
if
self
.
_msg_index
is
not
None
:
self
.
_msg_index
=
self
.
_msg_index
.
append_zeros
(
1
)
self
.
_msg_frame
.
add_rows
(
1
)
def
add_edges
(
self
,
u
,
v
,
data
=
None
):
...
...
@@ -1086,7 +1095,8 @@ class DGLGraph(DGLBaseGraph):
else
:
self
.
_edge_frame
.
append
(
data
)
# initialize feature placeholder for messages
self
.
_msg_index
=
self
.
_msg_index
.
append_zeros
(
num
)
if
self
.
_msg_index
is
not
None
:
self
.
_msg_index
=
self
.
_msg_index
.
append_zeros
(
num
)
self
.
_msg_frame
.
add_rows
(
num
)
def
clear
(
self
):
...
...
@@ -1111,7 +1121,7 @@ class DGLGraph(DGLBaseGraph):
self
.
_graph
.
clear
()
self
.
_node_frame
.
clear
()
self
.
_edge_frame
.
clear
()
self
.
_msg_index
=
utils
.
zero_index
(
0
)
self
.
_msg_index
=
None
self
.
_msg_frame
.
clear
()
def
clear_cache
(
self
):
...
...
@@ -1218,7 +1228,6 @@ class DGLGraph(DGLBaseGraph):
self
.
_graph
.
from_networkx
(
nx_graph
)
self
.
_node_frame
.
add_rows
(
self
.
number_of_nodes
())
self
.
_edge_frame
.
add_rows
(
self
.
number_of_edges
())
self
.
_msg_index
=
utils
.
zero_index
(
self
.
number_of_edges
())
self
.
_msg_frame
.
add_rows
(
self
.
number_of_edges
())
# copy attributes
...
...
@@ -1285,7 +1294,6 @@ class DGLGraph(DGLBaseGraph):
self
.
_graph
.
from_scipy_sparse_matrix
(
spmat
)
self
.
_node_frame
.
add_rows
(
self
.
number_of_nodes
())
self
.
_edge_frame
.
add_rows
(
self
.
number_of_edges
())
self
.
_msg_index
=
utils
.
zero_index
(
self
.
number_of_edges
())
self
.
_msg_frame
.
add_rows
(
self
.
number_of_edges
())
def
node_attr_schemes
(
self
):
...
...
python/dgl/runtime/scheduler.py
View file @
14af8402
...
...
@@ -56,7 +56,7 @@ def schedule_send(graph, u, v, eid, message_func):
msg
=
_gen_send
(
graph
,
var_nf
,
var_nf
,
var_ef
,
var_u
,
var_v
,
var_eid
,
message_func
)
ir
.
WRITE_ROW_
(
var_mf
,
var_eid
,
msg
)
# set message indicator to 1
graph
.
_msg_index
=
graph
.
_msg_index
.
set_items
(
eid
,
1
)
graph
.
_
set_
msg_index
(
graph
.
_
get_
msg_index
()
.
set_items
(
eid
,
1
)
)
def
schedule_recv
(
graph
,
recv_nodes
,
...
...
@@ -80,7 +80,7 @@ def schedule_recv(graph,
"""
src
,
dst
,
eid
=
graph
.
_graph
.
in_edges
(
recv_nodes
)
if
len
(
eid
)
>
0
:
nonzero_idx
=
graph
.
_msg_index
.
get_items
(
eid
).
nonzero
()
nonzero_idx
=
graph
.
_
get_
msg_index
()
.
get_items
(
eid
).
nonzero
()
eid
=
eid
.
get_items
(
nonzero_idx
)
src
=
src
.
get_items
(
nonzero_idx
)
dst
=
dst
.
get_items
(
nonzero_idx
)
...
...
@@ -107,8 +107,8 @@ def schedule_recv(graph,
else
:
ir
.
WRITE_ROW_
(
var_nf
,
var_recv_nodes
,
final_feat
)
# set message indicator to 0
graph
.
_msg_index
=
graph
.
_msg_index
.
set_items
(
eid
,
0
)
if
not
graph
.
_msg_index
.
has_nonzero
():
graph
.
_
set_
msg_index
(
graph
.
_
get_
msg_index
()
.
set_items
(
eid
,
0
)
)
if
not
graph
.
_
get_
msg_index
()
.
has_nonzero
():
ir
.
CLEAR_FRAME_
(
var
.
FEAT_DICT
(
graph
.
_msg_frame
,
name
=
'mf'
))
def
schedule_snr
(
graph
,
...
...
tests/compute/test_multi_send_recv.py
View file @
14af8402
...
...
@@ -64,7 +64,7 @@ def test_multi_send():
eid
=
g
.
edge_ids
([
0
,
0
,
0
,
0
,
0
,
1
,
2
,
3
,
4
,
5
],
[
1
,
2
,
3
,
4
,
5
,
9
,
9
,
9
,
9
,
9
])
expected
[
eid
]
=
1
assert
F
.
array_equal
(
g
.
_msg_index
.
tousertensor
(),
expected
)
assert
F
.
array_equal
(
g
.
_
get_
msg_index
()
.
tousertensor
(),
expected
)
def
test_multi_recv
():
# basic recv test
...
...
@@ -80,20 +80,20 @@ def test_multi_recv():
g
.
send
((
u
,
v
))
eid
=
g
.
edge_ids
(
u
,
v
)
expected
[
eid
]
=
1
assert
F
.
array_equal
(
g
.
_msg_index
.
tousertensor
(),
expected
)
assert
F
.
array_equal
(
g
.
_
get_
msg_index
()
.
tousertensor
(),
expected
)
g
.
recv
(
v
)
expected
[
eid
]
=
0
assert
F
.
array_equal
(
g
.
_msg_index
.
tousertensor
(),
expected
)
assert
F
.
array_equal
(
g
.
_
get_
msg_index
()
.
tousertensor
(),
expected
)
u
=
[
0
]
v
=
[
1
,
2
,
3
]
g
.
send
((
u
,
v
))
eid
=
g
.
edge_ids
(
u
,
v
)
expected
[
eid
]
=
1
assert
F
.
array_equal
(
g
.
_msg_index
.
tousertensor
(),
expected
)
assert
F
.
array_equal
(
g
.
_
get_
msg_index
()
.
tousertensor
(),
expected
)
g
.
recv
(
v
)
expected
[
eid
]
=
0
assert
F
.
array_equal
(
g
.
_msg_index
.
tousertensor
(),
expected
)
assert
F
.
array_equal
(
g
.
_
get_
msg_index
()
.
tousertensor
(),
expected
)
h1
=
g
.
ndata
[
'h'
]
...
...
@@ -104,19 +104,19 @@ def test_multi_recv():
g
.
send
((
u
,
v
))
eid
=
g
.
edge_ids
(
u
,
v
)
expected
[
eid
]
=
1
assert
F
.
array_equal
(
g
.
_msg_index
.
tousertensor
(),
expected
)
assert
F
.
array_equal
(
g
.
_
get_
msg_index
()
.
tousertensor
(),
expected
)
u
=
[
4
,
5
,
6
]
v
=
[
9
]
g
.
recv
(
v
)
eid
=
g
.
edge_ids
(
u
,
v
)
expected
[
eid
]
=
0
assert
F
.
array_equal
(
g
.
_msg_index
.
tousertensor
(),
expected
)
assert
F
.
array_equal
(
g
.
_
get_
msg_index
()
.
tousertensor
(),
expected
)
u
=
[
0
]
v
=
[
1
,
2
,
3
]
g
.
recv
(
v
)
eid
=
g
.
edge_ids
(
u
,
v
)
expected
[
eid
]
=
0
assert
F
.
array_equal
(
g
.
_msg_index
.
tousertensor
(),
expected
)
assert
F
.
array_equal
(
g
.
_
get_
msg_index
()
.
tousertensor
(),
expected
)
h2
=
g
.
ndata
[
'h'
]
assert
F
.
allclose
(
h1
,
h2
)
...
...
@@ -250,7 +250,7 @@ def test_dynamic_addition():
'h2'
:
F
.
randn
((
2
,
D
))})
g
.
send
()
expected
=
F
.
ones
((
g
.
number_of_edges
(),),
dtype
=
F
.
int64
)
assert
F
.
array_equal
(
g
.
_msg_index
.
tousertensor
(),
expected
)
assert
F
.
array_equal
(
g
.
_
get_
msg_index
()
.
tousertensor
(),
expected
)
# add more edges
g
.
add_edges
([
0
,
2
],
[
2
,
0
],
{
'h1'
:
F
.
randn
((
2
,
D
))})
...
...
@@ -281,10 +281,10 @@ def test_recv_no_send():
g
.
send
((
1
,
2
),
message_func
)
expected
=
F
.
zeros
((
2
,),
dtype
=
F
.
int64
)
expected
[
1
]
=
1
assert
F
.
array_equal
(
g
.
_msg_index
.
tousertensor
(),
expected
)
assert
F
.
array_equal
(
g
.
_
get_
msg_index
()
.
tousertensor
(),
expected
)
g
.
recv
(
2
,
reduce_func
)
expected
[
1
]
=
0
assert
F
.
array_equal
(
g
.
_msg_index
.
tousertensor
(),
expected
)
assert
F
.
array_equal
(
g
.
_
get_
msg_index
()
.
tousertensor
(),
expected
)
def
test_send_recv_after_conversion
():
# test send and recv after converting from a graph with edges
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment