Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
7d04c8c9
Commit
7d04c8c9
authored
Oct 03, 2018
by
Minjie Wang
Browse files
remove nonbatchable mode
parent
3a3e5d48
Changes
10
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
122 additions
and
644 deletions
+122
-644
python/dgl/graph.py
python/dgl/graph.py
+56
-190
tests/pytorch/test_batching.py
tests/pytorch/test_batching.py
+11
-11
tests/pytorch/test_batching_anonymous.py
tests/pytorch/test_batching_anonymous.py
+5
-5
tests/pytorch/test_function.py
tests/pytorch/test_function.py
+26
-26
tests/pytorch/test_graph_batch.py
tests/pytorch/test_graph_batch.py
+4
-4
tests/pytorch/test_specialization.py
tests/pytorch/test_specialization.py
+20
-20
tests/test_anonymous_repr.py
tests/test_anonymous_repr.py
+0
-62
tests/test_basics.py
tests/test_basics.py
+0
-111
tests/test_basics2.py
tests/test_basics2.py
+0
-74
tests/test_function.py
tests/test_function.py
+0
-141
No files found.
python/dgl/graph.py
View file @
7d04c8c9
...
@@ -50,11 +50,11 @@ class DGLGraph(object):
...
@@ -50,11 +50,11 @@ class DGLGraph(object):
self
.
_msg_frame
=
FrameRef
()
self
.
_msg_frame
=
FrameRef
()
self
.
reset_messages
()
self
.
reset_messages
()
# registered functions
# registered functions
self
.
_message_func
=
(
None
,
None
)
self
.
_message_func
=
None
self
.
_reduce_func
=
(
None
,
None
)
self
.
_reduce_func
=
None
self
.
_edge_func
=
(
None
,
None
)
self
.
_edge_func
=
None
self
.
_apply_node_func
=
(
None
,
None
)
self
.
_apply_node_func
=
None
self
.
_apply_edge_func
=
(
None
,
None
)
self
.
_apply_edge_func
=
None
def
add_nodes
(
self
,
num
,
reprs
=
None
):
def
add_nodes
(
self
,
num
,
reprs
=
None
):
"""Add nodes.
"""Add nodes.
...
@@ -710,77 +710,57 @@ class DGLGraph(object):
...
@@ -710,77 +710,57 @@ class DGLGraph(object):
else
:
else
:
return
self
.
_edge_frame
.
select_rows
(
eid
)
return
self
.
_edge_frame
.
select_rows
(
eid
)
def
register_edge_func
(
self
,
def
register_edge_func
(
self
,
edge_func
):
edge_func
,
batchable
=
False
):
"""Register global edge update function.
"""Register global edge update function.
Parameters
Parameters
----------
----------
edge_func : callable
edge_func : callable
Message function on the edge.
Message function on the edge.
batchable : bool
Whether the provided message function allows batch computing.
"""
"""
self
.
_edge_func
=
(
edge_func
,
batchable
)
self
.
_edge_func
=
edge_func
def
register_message_func
(
self
,
def
register_message_func
(
self
,
message_func
):
message_func
,
batchable
=
False
):
"""Register global message function.
"""Register global message function.
Parameters
Parameters
----------
----------
message_func : callable
message_func : callable
Message function on the edge.
Message function on the edge.
batchable : bool
Whether the provided message function allows batch computing.
"""
"""
self
.
_message_func
=
(
message_func
,
batchable
)
self
.
_message_func
=
message_func
def
register_reduce_func
(
self
,
def
register_reduce_func
(
self
,
reduce_func
):
reduce_func
,
batchable
=
False
):
"""Register global message reduce function.
"""Register global message reduce function.
Parameters
Parameters
----------
----------
reduce_func : str or callable
reduce_func : str or callable
Reduce function on incoming edges.
Reduce function on incoming edges.
batchable : bool
Whether the provided reduce function allows batch computing.
"""
"""
self
.
_reduce_func
=
(
reduce_func
,
batchable
)
self
.
_reduce_func
=
reduce_func
def
register_apply_node_func
(
self
,
def
register_apply_node_func
(
self
,
apply_node_func
):
apply_node_func
,
batchable
=
False
):
"""Register global node apply function.
"""Register global node apply function.
Parameters
Parameters
----------
----------
apply_node_func : callable
apply_node_func : callable
Apply function on the node.
Apply function on the node.
batchable : bool
Whether the provided function allows batch computing.
"""
"""
self
.
_apply_node_func
=
(
apply_node_func
,
batchable
)
self
.
_apply_node_func
=
apply_node_func
def
register_apply_edge_func
(
self
,
def
register_apply_edge_func
(
self
,
apply_edge_func
):
apply_edge_func
,
batchable
=
False
):
"""Register global edge apply function.
"""Register global edge apply function.
Parameters
Parameters
----------
----------
apply_edge_func : callable
apply_edge_func : callable
Apply function on the edge.
Apply function on the edge.
batchable : bool
Whether the provided function allows batch computing.
"""
"""
self
.
_apply_edge_func
=
(
apply_edge_func
,
batchable
)
self
.
_apply_edge_func
=
apply_edge_func
def
apply_nodes
(
self
,
v
,
apply_node_func
=
"default"
,
batchable
=
False
):
def
apply_nodes
(
self
,
v
,
apply_node_func
=
"default"
):
"""Apply the function on node representations.
"""Apply the function on node representations.
Parameters
Parameters
...
@@ -789,27 +769,16 @@ class DGLGraph(object):
...
@@ -789,27 +769,16 @@ class DGLGraph(object):
The node id(s).
The node id(s).
apply_node_func : callable
apply_node_func : callable
The apply node function.
The apply node function.
batchable : bool
Whether the provided function allows batch computing.
"""
"""
if
apply_node_func
==
"default"
:
if
apply_node_func
==
"default"
:
apply_node_func
,
batchable
=
self
.
_apply_node_func
apply_node_func
=
self
.
_apply_node_func
if
not
apply_node_func
:
if
not
apply_node_func
:
# Skip none function call.
# Skip none function call.
return
return
if
batchable
:
new_repr
=
apply_node_func
(
self
.
get_n_repr
(
v
))
new_repr
=
apply_node_func
(
self
.
get_n_repr
(
v
))
self
.
set_n_repr
(
new_repr
,
v
)
self
.
set_n_repr
(
new_repr
,
v
)
else
:
raise
RuntimeError
(
'Disabled'
)
if
is_all
(
v
):
v
=
self
.
nodes
()
v
=
utils
.
toindex
(
v
)
for
vv
in
utils
.
node_iter
(
v
):
ret
=
apply_node_func
(
_get_repr
(
self
.
nodes
[
vv
]))
_set_repr
(
self
.
nodes
[
vv
],
ret
)
def
apply_edges
(
self
,
u
,
v
,
apply_edge_func
=
"default"
,
batchable
=
False
):
def
apply_edges
(
self
,
u
,
v
,
apply_edge_func
=
"default"
):
"""Apply the function on edge representations.
"""Apply the function on edge representations.
Parameters
Parameters
...
@@ -820,27 +789,16 @@ class DGLGraph(object):
...
@@ -820,27 +789,16 @@ class DGLGraph(object):
The dst node id(s).
The dst node id(s).
apply_edge_func : callable
apply_edge_func : callable
The apply edge function.
The apply edge function.
batchable : bool
Whether the provided function allows batch computing.
"""
"""
if
apply_edge_func
==
"default"
:
if
apply_edge_func
==
"default"
:
apply_edge_func
,
batchable
=
self
.
_apply_edge_func
apply_edge_func
=
self
.
_apply_edge_func
if
not
apply_edge_func
:
if
not
apply_edge_func
:
# Skip none function call.
# Skip none function call.
return
return
if
batchable
:
new_repr
=
apply_edge_func
(
self
.
get_e_repr
(
u
,
v
))
new_repr
=
apply_edge_func
(
self
.
get_e_repr
(
u
,
v
))
self
.
set_e_repr
(
new_repr
,
u
,
v
)
self
.
set_e_repr
(
new_repr
,
u
,
v
)
else
:
if
is_all
(
u
)
==
is_all
(
v
):
u
,
v
=
zip
(
*
self
.
edges
)
u
=
utils
.
toindex
(
u
)
v
=
utils
.
toindex
(
v
)
for
uu
,
vv
in
utils
.
edge_iter
(
u
,
v
):
ret
=
apply_edge_func
(
_get_repr
(
self
.
edges
[
uu
,
vv
]))
_set_repr
(
self
.
edges
[
uu
,
vv
],
ret
)
def
send
(
self
,
u
,
v
,
message_func
=
"default"
,
batchable
=
False
):
def
send
(
self
,
u
,
v
,
message_func
=
"default"
):
"""Trigger the message function on edge u->v
"""Trigger the message function on edge u->v
The message function should be compatible with following signature:
The message function should be compatible with following signature:
...
@@ -861,30 +819,13 @@ class DGLGraph(object):
...
@@ -861,30 +819,13 @@ class DGLGraph(object):
The destination node(s).
The destination node(s).
message_func : callable
message_func : callable
The message function.
The message function.
batchable : bool
Whether the function allows batched computation.
"""
"""
if
message_func
==
"default"
:
if
message_func
==
"default"
:
message_func
,
batchable
=
self
.
_message_func
message_func
=
self
.
_message_func
assert
message_func
is
not
None
assert
message_func
is
not
None
if
isinstance
(
message_func
,
(
tuple
,
list
)):
if
isinstance
(
message_func
,
(
tuple
,
list
)):
message_func
=
BundledMessageFunction
(
message_func
)
message_func
=
BundledMessageFunction
(
message_func
)
if
batchable
:
self
.
_batch_send
(
u
,
v
,
message_func
)
self
.
_batch_send
(
u
,
v
,
message_func
)
else
:
self
.
_nonbatch_send
(
u
,
v
,
message_func
)
def
_nonbatch_send
(
self
,
u
,
v
,
message_func
):
raise
RuntimeError
(
'Disabled'
)
if
is_all
(
u
)
and
is_all
(
v
):
u
,
v
=
self
.
cached_graph
.
edges
()
else
:
u
=
utils
.
toindex
(
u
)
v
=
utils
.
toindex
(
v
)
for
uu
,
vv
in
utils
.
edge_iter
(
u
,
v
):
ret
=
message_func
(
_get_repr
(
self
.
nodes
[
uu
]),
_get_repr
(
self
.
edges
[
uu
,
vv
]))
self
.
edges
[
uu
,
vv
][
__MSG__
]
=
ret
def
_batch_send
(
self
,
u
,
v
,
message_func
):
def
_batch_send
(
self
,
u
,
v
,
message_func
):
if
is_all
(
u
)
and
is_all
(
v
):
if
is_all
(
u
)
and
is_all
(
v
):
...
@@ -908,7 +849,7 @@ class DGLGraph(object):
...
@@ -908,7 +849,7 @@ class DGLGraph(object):
else
:
else
:
self
.
_msg_frame
.
append
({
__MSG__
:
msgs
})
self
.
_msg_frame
.
append
({
__MSG__
:
msgs
})
def
update_edge
(
self
,
u
=
ALL
,
v
=
ALL
,
edge_func
=
"default"
,
batchable
=
False
):
def
update_edge
(
self
,
u
=
ALL
,
v
=
ALL
,
edge_func
=
"default"
):
"""Update representation on edge u->v
"""Update representation on edge u->v
The edge function should be compatible with following signature:
The edge function should be compatible with following signature:
...
@@ -927,29 +868,11 @@ class DGLGraph(object):
...
@@ -927,29 +868,11 @@ class DGLGraph(object):
The destination node(s).
The destination node(s).
edge_func : callable
edge_func : callable
The update function.
The update function.
batchable : bool
Whether the function allows batched computation.
"""
"""
if
edge_func
==
"default"
:
if
edge_func
==
"default"
:
edge_func
,
batchable
=
self
.
_edge_func
edge_func
=
self
.
_edge_func
assert
edge_func
is
not
None
assert
edge_func
is
not
None
if
batchable
:
self
.
_batch_update_edge
(
u
,
v
,
edge_func
)
self
.
_batch_update_edge
(
u
,
v
,
edge_func
)
else
:
self
.
_nonbatch_update_edge
(
u
,
v
,
edge_func
)
def
_nonbatch_update_edge
(
self
,
u
,
v
,
edge_func
):
raise
RuntimeError
(
'Disabled'
)
if
is_all
(
u
)
and
is_all
(
v
):
u
,
v
=
self
.
cached_graph
.
edges
()
else
:
u
=
utils
.
toindex
(
u
)
v
=
utils
.
toindex
(
v
)
for
uu
,
vv
in
utils
.
edge_iter
(
u
,
v
):
ret
=
edge_func
(
_get_repr
(
self
.
nodes
[
uu
]),
_get_repr
(
self
.
nodes
[
vv
]),
_get_repr
(
self
.
edges
[
uu
,
vv
]))
_set_repr
(
self
.
edges
[
uu
,
vv
],
ret
)
def
_batch_update_edge
(
self
,
u
,
v
,
edge_func
):
def
_batch_update_edge
(
self
,
u
,
v
,
edge_func
):
if
is_all
(
u
)
and
is_all
(
v
):
if
is_all
(
u
)
and
is_all
(
v
):
...
@@ -975,8 +898,7 @@ class DGLGraph(object):
...
@@ -975,8 +898,7 @@ class DGLGraph(object):
def
recv
(
self
,
def
recv
(
self
,
u
,
u
,
reduce_func
=
"default"
,
reduce_func
=
"default"
,
apply_node_func
=
"default"
,
apply_node_func
=
"default"
):
batchable
=
False
):
"""Receive and reduce in-coming messages and update representation on node u.
"""Receive and reduce in-coming messages and update representation on node u.
It computes the new node state using the messages sent from the predecessors
It computes the new node state using the messages sent from the predecessors
...
@@ -1006,34 +928,15 @@ class DGLGraph(object):
...
@@ -1006,34 +928,15 @@ class DGLGraph(object):
The reduce function.
The reduce function.
apply_node_func : callable, optional
apply_node_func : callable, optional
The update function.
The update function.
batchable : bool, optional
Whether the reduce and update function allows batched computation.
"""
"""
if
reduce_func
==
"default"
:
if
reduce_func
==
"default"
:
reduce_func
,
batchable
=
self
.
_reduce_func
reduce_func
=
self
.
_reduce_func
assert
reduce_func
is
not
None
assert
reduce_func
is
not
None
if
isinstance
(
reduce_func
,
(
list
,
tuple
)):
if
isinstance
(
reduce_func
,
(
list
,
tuple
)):
reduce_func
=
BundledReduceFunction
(
reduce_func
)
reduce_func
=
BundledReduceFunction
(
reduce_func
)
if
batchable
:
self
.
_batch_recv
(
u
,
reduce_func
)
self
.
_batch_recv
(
u
,
reduce_func
)
else
:
self
.
_nonbatch_recv
(
u
,
reduce_func
)
# optional apply nodes
# optional apply nodes
self
.
apply_nodes
(
u
,
apply_node_func
,
batchable
)
self
.
apply_nodes
(
u
,
apply_node_func
)
def
_nonbatch_recv
(
self
,
u
,
reduce_func
):
raise
RuntimeError
(
'Disabled'
)
if
is_all
(
u
):
u
=
list
(
range
(
0
,
self
.
number_of_nodes
()))
else
:
u
=
utils
.
toindex
(
u
)
for
i
,
uu
in
enumerate
(
utils
.
node_iter
(
u
)):
# reduce phase
msgs_batch
=
[
self
.
edges
[
vv
,
uu
].
pop
(
__MSG__
)
for
vv
in
self
.
pred
[
uu
]
if
__MSG__
in
self
.
edges
[
vv
,
uu
]]
if
len
(
msgs_batch
)
!=
0
:
new_repr
=
reduce_func
(
_get_repr
(
self
.
nodes
[
uu
]),
msgs_batch
)
_set_repr
(
self
.
nodes
[
uu
],
new_repr
)
def
_batch_recv
(
self
,
v
,
reduce_func
):
def
_batch_recv
(
self
,
v
,
reduce_func
):
if
self
.
_msg_frame
.
num_rows
==
0
:
if
self
.
_msg_frame
.
num_rows
==
0
:
...
@@ -1105,8 +1008,7 @@ class DGLGraph(object):
...
@@ -1105,8 +1008,7 @@ class DGLGraph(object):
u
,
v
,
u
,
v
,
message_func
=
"default"
,
message_func
=
"default"
,
reduce_func
=
"default"
,
reduce_func
=
"default"
,
apply_node_func
=
"default"
,
apply_node_func
=
"default"
):
batchable
=
False
):
"""Trigger the message function on u->v and update v.
"""Trigger the message function on u->v and update v.
Parameters
Parameters
...
@@ -1121,8 +1023,6 @@ class DGLGraph(object):
...
@@ -1121,8 +1023,6 @@ class DGLGraph(object):
The reduce function.
The reduce function.
apply_node_func : callable, optional
apply_node_func : callable, optional
The update function.
The update function.
batchable : bool
Whether the reduce and update function allows batched computation.
"""
"""
u
=
utils
.
toindex
(
u
)
u
=
utils
.
toindex
(
u
)
v
=
utils
.
toindex
(
v
)
v
=
utils
.
toindex
(
v
)
...
@@ -1132,34 +1032,28 @@ class DGLGraph(object):
...
@@ -1132,34 +1032,28 @@ class DGLGraph(object):
return
return
unique_v
=
utils
.
toindex
(
F
.
unique
(
v
.
tousertensor
()))
unique_v
=
utils
.
toindex
(
F
.
unique
(
v
.
tousertensor
()))
# TODO(minjie): better way to figure out `batchable` flag
if
message_func
==
"default"
:
if
message_func
==
"default"
:
message_func
,
batchable
=
self
.
_message_func
message_func
=
self
.
_message_func
if
reduce_func
==
"default"
:
if
reduce_func
==
"default"
:
reduce_func
,
_
=
self
.
_reduce_func
reduce_func
=
self
.
_reduce_func
assert
message_func
is
not
None
assert
message_func
is
not
None
assert
reduce_func
is
not
None
assert
reduce_func
is
not
None
if
batchable
:
executor
=
scheduler
.
get_executor
(
executor
=
scheduler
.
get_executor
(
'send_and_recv'
,
self
,
src
=
u
,
dst
=
v
,
'send_and_recv'
,
self
,
src
=
u
,
dst
=
v
,
message_func
=
message_func
,
reduce_func
=
reduce_func
)
message_func
=
message_func
,
reduce_func
=
reduce_func
)
else
:
executor
=
None
if
executor
:
if
executor
:
executor
.
run
()
executor
.
run
()
else
:
else
:
self
.
send
(
u
,
v
,
message_func
,
batchable
=
batchable
)
self
.
send
(
u
,
v
,
message_func
)
self
.
recv
(
unique_v
,
reduce_func
,
None
,
batchable
=
batchable
)
self
.
recv
(
unique_v
,
reduce_func
,
None
)
self
.
apply_nodes
(
unique_v
,
apply_node_func
,
batchable
=
batchable
)
self
.
apply_nodes
(
unique_v
,
apply_node_func
)
def
pull
(
self
,
def
pull
(
self
,
v
,
v
,
message_func
=
"default"
,
message_func
=
"default"
,
reduce_func
=
"default"
,
reduce_func
=
"default"
,
apply_node_func
=
"default"
,
apply_node_func
=
"default"
):
batchable
=
False
):
"""Pull messages from the node's predecessors and then update it.
"""Pull messages from the node's predecessors and then update it.
Parameters
Parameters
...
@@ -1172,24 +1066,20 @@ class DGLGraph(object):
...
@@ -1172,24 +1066,20 @@ class DGLGraph(object):
The reduce function.
The reduce function.
apply_node_func : callable, optional
apply_node_func : callable, optional
The update function.
The update function.
batchable : bool
Whether the reduce and update function allows batched computation.
"""
"""
v
=
utils
.
toindex
(
v
)
v
=
utils
.
toindex
(
v
)
if
len
(
v
)
==
0
:
if
len
(
v
)
==
0
:
return
return
uu
,
vv
,
_
=
self
.
_graph
.
in_edges
(
v
)
uu
,
vv
,
_
=
self
.
_graph
.
in_edges
(
v
)
self
.
send_and_recv
(
uu
,
vv
,
message_func
,
reduce_func
,
self
.
send_and_recv
(
uu
,
vv
,
message_func
,
reduce_func
,
apply_node_func
=
None
)
apply_node_func
=
None
,
batchable
=
batchable
)
unique_v
=
F
.
unique
(
v
.
tousertensor
())
unique_v
=
F
.
unique
(
v
.
tousertensor
())
self
.
apply_nodes
(
unique_v
,
apply_node_func
,
batchable
=
batchable
)
self
.
apply_nodes
(
unique_v
,
apply_node_func
)
def
push
(
self
,
def
push
(
self
,
u
,
u
,
message_func
=
"default"
,
message_func
=
"default"
,
reduce_func
=
"default"
,
reduce_func
=
"default"
,
apply_node_func
=
"default"
,
apply_node_func
=
"default"
):
batchable
=
False
):
"""Send message from the node to its successors and update them.
"""Send message from the node to its successors and update them.
Parameters
Parameters
...
@@ -1202,21 +1092,18 @@ class DGLGraph(object):
...
@@ -1202,21 +1092,18 @@ class DGLGraph(object):
The reduce function.
The reduce function.
apply_node_func : callable
apply_node_func : callable
The update function.
The update function.
batchable : bool
Whether the reduce and update function allows batched computation.
"""
"""
u
=
utils
.
toindex
(
u
)
u
=
utils
.
toindex
(
u
)
if
len
(
u
)
==
0
:
if
len
(
u
)
==
0
:
return
return
uu
,
vv
,
_
=
self
.
_graph
.
out_edges
(
u
)
uu
,
vv
,
_
=
self
.
_graph
.
out_edges
(
u
)
self
.
send_and_recv
(
uu
,
vv
,
message_func
,
self
.
send_and_recv
(
uu
,
vv
,
message_func
,
reduce_func
,
apply_node_func
,
batchable
=
batchable
)
reduce_func
,
apply_node_func
)
def
update_all
(
self
,
def
update_all
(
self
,
message_func
=
"default"
,
message_func
=
"default"
,
reduce_func
=
"default"
,
reduce_func
=
"default"
,
apply_node_func
=
"default"
,
apply_node_func
=
"default"
):
batchable
=
False
):
"""Send messages through all the edges and update all nodes.
"""Send messages through all the edges and update all nodes.
Parameters
Parameters
...
@@ -1227,35 +1114,28 @@ class DGLGraph(object):
...
@@ -1227,35 +1114,28 @@ class DGLGraph(object):
The reduce function.
The reduce function.
apply_node_func : callable, optional
apply_node_func : callable, optional
The update function.
The update function.
batchable : bool
Whether the reduce and update function allows batched computation.
"""
"""
if
message_func
==
"default"
:
if
message_func
==
"default"
:
message_func
,
batchable
=
self
.
_message_func
message_func
=
self
.
_message_func
if
reduce_func
==
"default"
:
if
reduce_func
==
"default"
:
reduce_func
,
_
=
self
.
_reduce_func
reduce_func
=
self
.
_reduce_func
assert
message_func
is
not
None
assert
message_func
is
not
None
assert
reduce_func
is
not
None
assert
reduce_func
is
not
None
if
batchable
:
executor
=
scheduler
.
get_executor
(
executor
=
scheduler
.
get_executor
(
"update_all"
,
self
,
message_func
=
message_func
,
reduce_func
=
reduce_func
)
"update_all"
,
self
,
message_func
=
message_func
,
reduce_func
=
reduce_func
)
else
:
executor
=
None
if
executor
:
if
executor
:
executor
.
run
()
executor
.
run
()
else
:
else
:
self
.
send
(
ALL
,
ALL
,
message_func
,
batchable
=
batchable
)
self
.
send
(
ALL
,
ALL
,
message_func
)
self
.
recv
(
ALL
,
reduce_func
,
None
,
batchable
=
batchable
)
self
.
recv
(
ALL
,
reduce_func
,
None
)
self
.
apply_nodes
(
ALL
,
apply_node_func
,
batchable
=
batchable
)
self
.
apply_nodes
(
ALL
,
apply_node_func
)
def
propagate
(
self
,
def
propagate
(
self
,
iterator
=
'bfs'
,
iterator
=
'bfs'
,
message_func
=
"default"
,
message_func
=
"default"
,
reduce_func
=
"default"
,
reduce_func
=
"default"
,
apply_node_func
=
"default"
,
apply_node_func
=
"default"
,
batchable
=
False
,
**
kwargs
):
**
kwargs
):
"""Propagate messages and update nodes using iterator.
"""Propagate messages and update nodes using iterator.
...
@@ -1274,8 +1154,6 @@ class DGLGraph(object):
...
@@ -1274,8 +1154,6 @@ class DGLGraph(object):
The reduce function.
The reduce function.
apply_node_func : str or callable
apply_node_func : str or callable
The update function.
The update function.
batchable : bool
Whether the reduce and update function allows batched computation.
iterator : str or generator of steps.
iterator : str or generator of steps.
The iterator of the graph.
The iterator of the graph.
kwargs : keyword arguments, optional
kwargs : keyword arguments, optional
...
@@ -1288,7 +1166,7 @@ class DGLGraph(object):
...
@@ -1288,7 +1166,7 @@ class DGLGraph(object):
# NOTE: the iteration can return multiple edges at each step.
# NOTE: the iteration can return multiple edges at each step.
for
u
,
v
in
iterator
:
for
u
,
v
in
iterator
:
self
.
send_and_recv
(
u
,
v
,
self
.
send_and_recv
(
u
,
v
,
message_func
,
reduce_func
,
apply_node_func
,
batchable
)
message_func
,
reduce_func
,
apply_node_func
)
def
subgraph
(
self
,
nodes
):
def
subgraph
(
self
,
nodes
):
"""Generate the subgraph among the given nodes.
"""Generate the subgraph among the given nodes.
...
@@ -1350,15 +1228,3 @@ class DGLGraph(object):
...
@@ -1350,15 +1228,3 @@ class DGLGraph(object):
[
sg
.
_parent_eid
for
sg
in
to_merge
],
[
sg
.
_parent_eid
for
sg
in
to_merge
],
self
.
_edge_frame
.
num_rows
,
self
.
_edge_frame
.
num_rows
,
reduce_func
)
reduce_func
)
def
_get_repr
(
attr_dict
):
if
len
(
attr_dict
)
==
1
and
__REPR__
in
attr_dict
:
return
attr_dict
[
__REPR__
]
else
:
return
attr_dict
def
_set_repr
(
attr_dict
,
attr
):
if
utils
.
is_dict_like
(
attr
):
attr_dict
.
update
(
attr
)
else
:
attr_dict
[
__REPR__
]
=
attr
tests/pytorch/test_batching.py
View file @
7d04c8c9
...
@@ -133,7 +133,7 @@ def test_batch_send():
...
@@ -133,7 +133,7 @@ def test_batch_send():
def
_fmsg
(
src
,
edge
):
def
_fmsg
(
src
,
edge
):
assert
src
[
'h'
].
shape
==
(
5
,
D
)
assert
src
[
'h'
].
shape
==
(
5
,
D
)
return
{
'm'
:
src
[
'h'
]}
return
{
'm'
:
src
[
'h'
]}
g
.
register_message_func
(
_fmsg
,
batchable
=
True
)
g
.
register_message_func
(
_fmsg
)
# many-many send
# many-many send
u
=
th
.
tensor
([
0
,
0
,
0
,
0
,
0
])
u
=
th
.
tensor
([
0
,
0
,
0
,
0
,
0
])
v
=
th
.
tensor
([
1
,
2
,
3
,
4
,
5
])
v
=
th
.
tensor
([
1
,
2
,
3
,
4
,
5
])
...
@@ -150,9 +150,9 @@ def test_batch_send():
...
@@ -150,9 +150,9 @@ def test_batch_send():
def
test_batch_recv
():
def
test_batch_recv
():
# basic recv test
# basic recv test
g
=
generate_graph
()
g
=
generate_graph
()
g
.
register_message_func
(
message_func
,
batchable
=
True
)
g
.
register_message_func
(
message_func
)
g
.
register_reduce_func
(
reduce_func
,
batchable
=
True
)
g
.
register_reduce_func
(
reduce_func
)
g
.
register_apply_node_func
(
apply_node_func
,
batchable
=
True
)
g
.
register_apply_node_func
(
apply_node_func
)
u
=
th
.
tensor
([
0
,
0
,
0
,
4
,
5
,
6
])
u
=
th
.
tensor
([
0
,
0
,
0
,
4
,
5
,
6
])
v
=
th
.
tensor
([
1
,
2
,
3
,
9
,
9
,
9
])
v
=
th
.
tensor
([
1
,
2
,
3
,
9
,
9
,
9
])
reduce_msg_shapes
.
clear
()
reduce_msg_shapes
.
clear
()
...
@@ -163,9 +163,9 @@ def test_batch_recv():
...
@@ -163,9 +163,9 @@ def test_batch_recv():
def
test_update_routines
():
def
test_update_routines
():
g
=
generate_graph
()
g
=
generate_graph
()
g
.
register_message_func
(
message_func
,
batchable
=
True
)
g
.
register_message_func
(
message_func
)
g
.
register_reduce_func
(
reduce_func
,
batchable
=
True
)
g
.
register_reduce_func
(
reduce_func
)
g
.
register_apply_node_func
(
apply_node_func
,
batchable
=
True
)
g
.
register_apply_node_func
(
apply_node_func
)
# send_and_recv
# send_and_recv
reduce_msg_shapes
.
clear
()
reduce_msg_shapes
.
clear
()
...
@@ -209,7 +209,7 @@ def test_reduce_0deg():
...
@@ -209,7 +209,7 @@ def test_reduce_0deg():
return
node
+
msgs
.
sum
(
1
)
return
node
+
msgs
.
sum
(
1
)
old_repr
=
th
.
randn
(
5
,
5
)
old_repr
=
th
.
randn
(
5
,
5
)
g
.
set_n_repr
(
old_repr
)
g
.
set_n_repr
(
old_repr
)
g
.
update_all
(
_message
,
_reduce
,
batchable
=
True
)
g
.
update_all
(
_message
,
_reduce
)
new_repr
=
g
.
get_n_repr
()
new_repr
=
g
.
get_n_repr
()
assert
th
.
allclose
(
new_repr
[
1
:],
old_repr
[
1
:])
assert
th
.
allclose
(
new_repr
[
1
:],
old_repr
[
1
:])
...
@@ -227,17 +227,17 @@ def test_pull_0deg():
...
@@ -227,17 +227,17 @@ def test_pull_0deg():
old_repr
=
th
.
randn
(
2
,
5
)
old_repr
=
th
.
randn
(
2
,
5
)
g
.
set_n_repr
(
old_repr
)
g
.
set_n_repr
(
old_repr
)
g
.
pull
(
0
,
_message
,
_reduce
,
batchable
=
True
)
g
.
pull
(
0
,
_message
,
_reduce
)
new_repr
=
g
.
get_n_repr
()
new_repr
=
g
.
get_n_repr
()
assert
th
.
allclose
(
new_repr
[
0
],
old_repr
[
0
])
assert
th
.
allclose
(
new_repr
[
0
],
old_repr
[
0
])
assert
th
.
allclose
(
new_repr
[
1
],
old_repr
[
1
])
assert
th
.
allclose
(
new_repr
[
1
],
old_repr
[
1
])
g
.
pull
(
1
,
_message
,
_reduce
,
batchable
=
True
)
g
.
pull
(
1
,
_message
,
_reduce
)
new_repr
=
g
.
get_n_repr
()
new_repr
=
g
.
get_n_repr
()
assert
th
.
allclose
(
new_repr
[
1
],
old_repr
[
0
])
assert
th
.
allclose
(
new_repr
[
1
],
old_repr
[
0
])
old_repr
=
th
.
randn
(
2
,
5
)
old_repr
=
th
.
randn
(
2
,
5
)
g
.
set_n_repr
(
old_repr
)
g
.
set_n_repr
(
old_repr
)
g
.
pull
([
0
,
1
],
_message
,
_reduce
,
batchable
=
True
)
g
.
pull
([
0
,
1
],
_message
,
_reduce
)
new_repr
=
g
.
get_n_repr
()
new_repr
=
g
.
get_n_repr
()
assert
th
.
allclose
(
new_repr
[
0
],
old_repr
[
0
])
assert
th
.
allclose
(
new_repr
[
0
],
old_repr
[
0
])
assert
th
.
allclose
(
new_repr
[
1
],
old_repr
[
0
])
assert
th
.
allclose
(
new_repr
[
1
],
old_repr
[
0
])
...
...
tests/pytorch/test_batching_anonymous.py
View file @
7d04c8c9
...
@@ -129,7 +129,7 @@ def test_batch_send():
...
@@ -129,7 +129,7 @@ def test_batch_send():
def
_fmsg
(
hu
,
edge
):
def
_fmsg
(
hu
,
edge
):
assert
hu
.
shape
==
(
5
,
D
)
assert
hu
.
shape
==
(
5
,
D
)
return
hu
return
hu
g
.
register_message_func
(
_fmsg
,
batchable
=
True
)
g
.
register_message_func
(
_fmsg
)
# many-many send
# many-many send
u
=
th
.
tensor
([
0
,
0
,
0
,
0
,
0
])
u
=
th
.
tensor
([
0
,
0
,
0
,
0
,
0
])
v
=
th
.
tensor
([
1
,
2
,
3
,
4
,
5
])
v
=
th
.
tensor
([
1
,
2
,
3
,
4
,
5
])
...
@@ -145,8 +145,8 @@ def test_batch_send():
...
@@ -145,8 +145,8 @@ def test_batch_send():
def
test_batch_recv
():
def
test_batch_recv
():
g
=
generate_graph
()
g
=
generate_graph
()
g
.
register_message_func
(
message_func
,
batchable
=
True
)
g
.
register_message_func
(
message_func
)
g
.
register_reduce_func
(
reduce_func
,
batchable
=
True
)
g
.
register_reduce_func
(
reduce_func
)
u
=
th
.
tensor
([
0
,
0
,
0
,
4
,
5
,
6
])
u
=
th
.
tensor
([
0
,
0
,
0
,
4
,
5
,
6
])
v
=
th
.
tensor
([
1
,
2
,
3
,
9
,
9
,
9
])
v
=
th
.
tensor
([
1
,
2
,
3
,
9
,
9
,
9
])
reduce_msg_shapes
.
clear
()
reduce_msg_shapes
.
clear
()
...
@@ -157,8 +157,8 @@ def test_batch_recv():
...
@@ -157,8 +157,8 @@ def test_batch_recv():
def
test_update_routines
():
def
test_update_routines
():
g
=
generate_graph
()
g
=
generate_graph
()
g
.
register_message_func
(
message_func
,
batchable
=
True
)
g
.
register_message_func
(
message_func
)
g
.
register_reduce_func
(
reduce_func
,
batchable
=
True
)
g
.
register_reduce_func
(
reduce_func
)
# send_and_recv
# send_and_recv
reduce_msg_shapes
.
clear
()
reduce_msg_shapes
.
clear
()
...
...
tests/pytorch/test_function.py
View file @
7d04c8c9
...
@@ -51,32 +51,32 @@ def reducer_none(node, msgs):
...
@@ -51,32 +51,32 @@ def reducer_none(node, msgs):
def
test_copy_src
():
def
test_copy_src
():
# copy_src with both fields
# copy_src with both fields
g
=
generate_graph
()
g
=
generate_graph
()
g
.
register_message_func
(
fn
.
copy_src
(
src
=
'h'
,
out
=
'm'
)
,
batchable
=
True
)
g
.
register_message_func
(
fn
.
copy_src
(
src
=
'h'
,
out
=
'm'
))
g
.
register_reduce_func
(
reducer_both
,
batchable
=
True
)
g
.
register_reduce_func
(
reducer_both
)
g
.
update_all
()
g
.
update_all
()
assert
th
.
allclose
(
g
.
get_n_repr
()[
'h'
],
assert
th
.
allclose
(
g
.
get_n_repr
()[
'h'
],
th
.
tensor
([
10.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
44.
]))
th
.
tensor
([
10.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
44.
]))
# copy_src with only src field; the out field should use anonymous repr
# copy_src with only src field; the out field should use anonymous repr
g
=
generate_graph
()
g
=
generate_graph
()
g
.
register_message_func
(
fn
.
copy_src
(
src
=
'h'
)
,
batchable
=
True
)
g
.
register_message_func
(
fn
.
copy_src
(
src
=
'h'
))
g
.
register_reduce_func
(
reducer_out
,
batchable
=
True
)
g
.
register_reduce_func
(
reducer_out
)
g
.
update_all
()
g
.
update_all
()
assert
th
.
allclose
(
g
.
get_n_repr
()[
'h'
],
assert
th
.
allclose
(
g
.
get_n_repr
()[
'h'
],
th
.
tensor
([
10.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
44.
]))
th
.
tensor
([
10.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
44.
]))
# copy_src with no src field; should use anonymous repr
# copy_src with no src field; should use anonymous repr
g
=
generate_graph1
()
g
=
generate_graph1
()
g
.
register_message_func
(
fn
.
copy_src
(
out
=
'm'
)
,
batchable
=
True
)
g
.
register_message_func
(
fn
.
copy_src
(
out
=
'm'
))
g
.
register_reduce_func
(
reducer_both
,
batchable
=
True
)
g
.
register_reduce_func
(
reducer_both
)
g
.
update_all
()
g
.
update_all
()
assert
th
.
allclose
(
g
.
get_n_repr
()[
'h'
],
assert
th
.
allclose
(
g
.
get_n_repr
()[
'h'
],
th
.
tensor
([
10.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
44.
]))
th
.
tensor
([
10.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
44.
]))
# copy src with no fields;
# copy src with no fields;
g
=
generate_graph1
()
g
=
generate_graph1
()
g
.
register_message_func
(
fn
.
copy_src
()
,
batchable
=
True
)
g
.
register_message_func
(
fn
.
copy_src
())
g
.
register_reduce_func
(
reducer_out
,
batchable
=
True
)
g
.
register_reduce_func
(
reducer_out
)
g
.
update_all
()
g
.
update_all
()
assert
th
.
allclose
(
g
.
get_n_repr
()[
'h'
],
assert
th
.
allclose
(
g
.
get_n_repr
()[
'h'
],
th
.
tensor
([
10.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
44.
]))
th
.
tensor
([
10.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
44.
]))
...
@@ -84,32 +84,32 @@ def test_copy_src():
...
@@ -84,32 +84,32 @@ def test_copy_src():
def
test_copy_edge
():
def
test_copy_edge
():
# copy_edge with both fields
# copy_edge with both fields
g
=
generate_graph
()
g
=
generate_graph
()
g
.
register_message_func
(
fn
.
copy_edge
(
edge
=
'h'
,
out
=
'm'
)
,
batchable
=
True
)
g
.
register_message_func
(
fn
.
copy_edge
(
edge
=
'h'
,
out
=
'm'
))
g
.
register_reduce_func
(
reducer_both
,
batchable
=
True
)
g
.
register_reduce_func
(
reducer_both
)
g
.
update_all
()
g
.
update_all
()
assert
th
.
allclose
(
g
.
get_n_repr
()[
'h'
],
assert
th
.
allclose
(
g
.
get_n_repr
()[
'h'
],
th
.
tensor
([
10.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
44.
]))
th
.
tensor
([
10.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
44.
]))
# copy_edge with only edge field; the out field should use anonymous repr
# copy_edge with only edge field; the out field should use anonymous repr
g
=
generate_graph
()
g
=
generate_graph
()
g
.
register_message_func
(
fn
.
copy_edge
(
edge
=
'h'
)
,
batchable
=
True
)
g
.
register_message_func
(
fn
.
copy_edge
(
edge
=
'h'
))
g
.
register_reduce_func
(
reducer_out
,
batchable
=
True
)
g
.
register_reduce_func
(
reducer_out
)
g
.
update_all
()
g
.
update_all
()
assert
th
.
allclose
(
g
.
get_n_repr
()[
'h'
],
assert
th
.
allclose
(
g
.
get_n_repr
()[
'h'
],
th
.
tensor
([
10.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
44.
]))
th
.
tensor
([
10.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
44.
]))
# copy_edge with no edge field; should use anonymous repr
# copy_edge with no edge field; should use anonymous repr
g
=
generate_graph1
()
g
=
generate_graph1
()
g
.
register_message_func
(
fn
.
copy_edge
(
out
=
'm'
)
,
batchable
=
True
)
g
.
register_message_func
(
fn
.
copy_edge
(
out
=
'm'
))
g
.
register_reduce_func
(
reducer_both
,
batchable
=
True
)
g
.
register_reduce_func
(
reducer_both
)
g
.
update_all
()
g
.
update_all
()
assert
th
.
allclose
(
g
.
get_n_repr
()[
'h'
],
assert
th
.
allclose
(
g
.
get_n_repr
()[
'h'
],
th
.
tensor
([
10.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
44.
]))
th
.
tensor
([
10.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
44.
]))
# copy edge with no fields;
# copy edge with no fields;
g
=
generate_graph1
()
g
=
generate_graph1
()
g
.
register_message_func
(
fn
.
copy_edge
()
,
batchable
=
True
)
g
.
register_message_func
(
fn
.
copy_edge
())
g
.
register_reduce_func
(
reducer_out
,
batchable
=
True
)
g
.
register_reduce_func
(
reducer_out
)
g
.
update_all
()
g
.
update_all
()
assert
th
.
allclose
(
g
.
get_n_repr
()[
'h'
],
assert
th
.
allclose
(
g
.
get_n_repr
()[
'h'
],
th
.
tensor
([
10.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
44.
]))
th
.
tensor
([
10.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
44.
]))
...
@@ -117,36 +117,36 @@ def test_copy_edge():
...
@@ -117,36 +117,36 @@ def test_copy_edge():
def
test_src_mul_edge
():
def
test_src_mul_edge
():
# src_mul_edge with all fields
# src_mul_edge with all fields
g
=
generate_graph
()
g
=
generate_graph
()
g
.
register_message_func
(
fn
.
src_mul_edge
(
src
=
'h'
,
edge
=
'h'
,
out
=
'm'
)
,
batchable
=
True
)
g
.
register_message_func
(
fn
.
src_mul_edge
(
src
=
'h'
,
edge
=
'h'
,
out
=
'm'
))
g
.
register_reduce_func
(
reducer_both
,
batchable
=
True
)
g
.
register_reduce_func
(
reducer_both
)
g
.
update_all
()
g
.
update_all
()
assert
th
.
allclose
(
g
.
get_n_repr
()[
'h'
],
assert
th
.
allclose
(
g
.
get_n_repr
()[
'h'
],
th
.
tensor
([
100.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
284.
]))
th
.
tensor
([
100.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
284.
]))
g
=
generate_graph
()
g
=
generate_graph
()
g
.
register_message_func
(
fn
.
src_mul_edge
(
src
=
'h'
,
edge
=
'h'
)
,
batchable
=
True
)
g
.
register_message_func
(
fn
.
src_mul_edge
(
src
=
'h'
,
edge
=
'h'
))
g
.
register_reduce_func
(
reducer_out
,
batchable
=
True
)
g
.
register_reduce_func
(
reducer_out
)
g
.
update_all
()
g
.
update_all
()
assert
th
.
allclose
(
g
.
get_n_repr
()[
'h'
],
assert
th
.
allclose
(
g
.
get_n_repr
()[
'h'
],
th
.
tensor
([
100.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
284.
]))
th
.
tensor
([
100.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
284.
]))
g
=
generate_graph1
()
g
=
generate_graph1
()
g
.
register_message_func
(
fn
.
src_mul_edge
(
out
=
'm'
)
,
batchable
=
True
)
g
.
register_message_func
(
fn
.
src_mul_edge
(
out
=
'm'
))
g
.
register_reduce_func
(
reducer_both
,
batchable
=
True
)
g
.
register_reduce_func
(
reducer_both
)
g
.
update_all
()
g
.
update_all
()
assert
th
.
allclose
(
g
.
get_n_repr
()[
'h'
],
assert
th
.
allclose
(
g
.
get_n_repr
()[
'h'
],
th
.
tensor
([
100.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
284.
]))
th
.
tensor
([
100.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
284.
]))
g
=
generate_graph1
()
g
=
generate_graph1
()
g
.
register_message_func
(
fn
.
src_mul_edge
()
,
batchable
=
True
)
g
.
register_message_func
(
fn
.
src_mul_edge
())
g
.
register_reduce_func
(
reducer_out
,
batchable
=
True
)
g
.
register_reduce_func
(
reducer_out
)
g
.
update_all
()
g
.
update_all
()
assert
th
.
allclose
(
g
.
get_n_repr
()[
'h'
],
assert
th
.
allclose
(
g
.
get_n_repr
()[
'h'
],
th
.
tensor
([
100.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
284.
]))
th
.
tensor
([
100.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
284.
]))
g
=
generate_graph1
()
g
=
generate_graph1
()
g
.
register_message_func
(
fn
.
src_mul_edge
()
,
batchable
=
True
)
g
.
register_message_func
(
fn
.
src_mul_edge
())
g
.
register_reduce_func
(
reducer_none
,
batchable
=
True
)
g
.
register_reduce_func
(
reducer_none
)
g
.
update_all
()
g
.
update_all
()
assert
th
.
allclose
(
g
.
get_n_repr
(),
assert
th
.
allclose
(
g
.
get_n_repr
(),
th
.
tensor
([
100.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
284.
]))
th
.
tensor
([
100.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
284.
]))
...
...
tests/pytorch/test_graph_batch.py
View file @
7d04c8c9
...
@@ -71,8 +71,8 @@ def test_batch_sendrecv():
...
@@ -71,8 +71,8 @@ def test_batch_sendrecv():
t2
=
tree2
()
t2
=
tree2
()
bg
=
dgl
.
batch
([
t1
,
t2
])
bg
=
dgl
.
batch
([
t1
,
t2
])
bg
.
register_message_func
(
lambda
src
,
edge
:
src
,
batchable
=
True
)
bg
.
register_message_func
(
lambda
src
,
edge
:
src
)
bg
.
register_reduce_func
(
lambda
node
,
msgs
:
torch
.
sum
(
msgs
,
1
)
,
batchable
=
True
)
bg
.
register_reduce_func
(
lambda
node
,
msgs
:
torch
.
sum
(
msgs
,
1
))
e1
=
[(
3
,
1
),
(
4
,
1
)]
e1
=
[(
3
,
1
),
(
4
,
1
)]
e2
=
[(
2
,
4
),
(
0
,
4
)]
e2
=
[(
2
,
4
),
(
0
,
4
)]
...
@@ -94,8 +94,8 @@ def test_batch_propagate():
...
@@ -94,8 +94,8 @@ def test_batch_propagate():
t2
=
tree2
()
t2
=
tree2
()
bg
=
dgl
.
batch
([
t1
,
t2
])
bg
=
dgl
.
batch
([
t1
,
t2
])
bg
.
register_message_func
(
lambda
src
,
edge
:
src
,
batchable
=
True
)
bg
.
register_message_func
(
lambda
src
,
edge
:
src
)
bg
.
register_reduce_func
(
lambda
node
,
msgs
:
torch
.
sum
(
msgs
,
1
)
,
batchable
=
True
)
bg
.
register_reduce_func
(
lambda
node
,
msgs
:
torch
.
sum
(
msgs
,
1
))
# get leaves.
# get leaves.
order
=
[]
order
=
[]
...
...
tests/pytorch/test_specialization.py
View file @
7d04c8c9
...
@@ -38,23 +38,23 @@ def test_update_all():
...
@@ -38,23 +38,23 @@ def test_update_all():
g
=
generate_graph
()
g
=
generate_graph
()
# update all
# update all
v1
=
g
.
get_n_repr
()[
fld
]
v1
=
g
.
get_n_repr
()[
fld
]
g
.
update_all
(
fn
.
copy_src
(
src
=
fld
),
fn
.
sum
(
out
=
fld
),
apply_func
,
batchable
=
True
)
g
.
update_all
(
fn
.
copy_src
(
src
=
fld
),
fn
.
sum
(
out
=
fld
),
apply_func
)
v2
=
g
.
get_n_repr
()[
fld
]
v2
=
g
.
get_n_repr
()[
fld
]
g
.
set_n_repr
({
fld
:
v1
})
g
.
set_n_repr
({
fld
:
v1
})
g
.
update_all
(
message_func
,
reduce_func
,
apply_func
,
batchable
=
True
)
g
.
update_all
(
message_func
,
reduce_func
,
apply_func
)
v3
=
g
.
get_n_repr
()[
fld
]
v3
=
g
.
get_n_repr
()[
fld
]
assert
th
.
allclose
(
v2
,
v3
)
assert
th
.
allclose
(
v2
,
v3
)
# update all with edge weights
# update all with edge weights
v1
=
g
.
get_n_repr
()[
fld
]
v1
=
g
.
get_n_repr
()[
fld
]
g
.
update_all
(
fn
.
src_mul_edge
(
src
=
fld
,
edge
=
'e1'
),
g
.
update_all
(
fn
.
src_mul_edge
(
src
=
fld
,
edge
=
'e1'
),
fn
.
sum
(
out
=
fld
),
apply_func
,
batchable
=
True
)
fn
.
sum
(
out
=
fld
),
apply_func
)
v2
=
g
.
get_n_repr
()[
fld
]
v2
=
g
.
get_n_repr
()[
fld
]
g
.
set_n_repr
({
fld
:
v1
})
g
.
set_n_repr
({
fld
:
v1
})
g
.
update_all
(
fn
.
src_mul_edge
(
src
=
fld
,
edge
=
'e2'
),
g
.
update_all
(
fn
.
src_mul_edge
(
src
=
fld
,
edge
=
'e2'
),
fn
.
sum
(
out
=
fld
),
apply_func
,
batchable
=
True
)
fn
.
sum
(
out
=
fld
),
apply_func
)
v3
=
g
.
get_n_repr
()[
fld
]
v3
=
g
.
get_n_repr
()[
fld
]
g
.
set_n_repr
({
fld
:
v1
})
g
.
set_n_repr
({
fld
:
v1
})
g
.
update_all
(
message_func_edge
,
reduce_func
,
apply_func
,
batchable
=
True
)
g
.
update_all
(
message_func_edge
,
reduce_func
,
apply_func
)
v4
=
g
.
get_n_repr
()[
fld
]
v4
=
g
.
get_n_repr
()[
fld
]
assert
th
.
allclose
(
v2
,
v3
)
assert
th
.
allclose
(
v2
,
v3
)
assert
th
.
allclose
(
v3
,
v4
)
assert
th
.
allclose
(
v3
,
v4
)
...
@@ -85,25 +85,25 @@ def test_send_and_recv():
...
@@ -85,25 +85,25 @@ def test_send_and_recv():
# send and recv
# send and recv
v1
=
g
.
get_n_repr
()[
fld
]
v1
=
g
.
get_n_repr
()[
fld
]
g
.
send_and_recv
(
u
,
v
,
fn
.
copy_src
(
src
=
fld
),
g
.
send_and_recv
(
u
,
v
,
fn
.
copy_src
(
src
=
fld
),
fn
.
sum
(
out
=
fld
),
apply_func
,
batchable
=
True
)
fn
.
sum
(
out
=
fld
),
apply_func
)
v2
=
g
.
get_n_repr
()[
fld
]
v2
=
g
.
get_n_repr
()[
fld
]
g
.
set_n_repr
({
fld
:
v1
})
g
.
set_n_repr
({
fld
:
v1
})
g
.
send_and_recv
(
u
,
v
,
message_func
,
g
.
send_and_recv
(
u
,
v
,
message_func
,
reduce_func
,
apply_func
,
batchable
=
True
)
reduce_func
,
apply_func
)
v3
=
g
.
get_n_repr
()[
fld
]
v3
=
g
.
get_n_repr
()[
fld
]
assert
th
.
allclose
(
v2
,
v3
)
assert
th
.
allclose
(
v2
,
v3
)
# send and recv with edge weights
# send and recv with edge weights
v1
=
g
.
get_n_repr
()[
fld
]
v1
=
g
.
get_n_repr
()[
fld
]
g
.
send_and_recv
(
u
,
v
,
fn
.
src_mul_edge
(
src
=
fld
,
edge
=
'e1'
),
g
.
send_and_recv
(
u
,
v
,
fn
.
src_mul_edge
(
src
=
fld
,
edge
=
'e1'
),
fn
.
sum
(
out
=
fld
),
apply_func
,
batchable
=
True
)
fn
.
sum
(
out
=
fld
),
apply_func
)
v2
=
g
.
get_n_repr
()[
fld
]
v2
=
g
.
get_n_repr
()[
fld
]
g
.
set_n_repr
({
fld
:
v1
})
g
.
set_n_repr
({
fld
:
v1
})
g
.
send_and_recv
(
u
,
v
,
fn
.
src_mul_edge
(
src
=
fld
,
edge
=
'e2'
),
g
.
send_and_recv
(
u
,
v
,
fn
.
src_mul_edge
(
src
=
fld
,
edge
=
'e2'
),
fn
.
sum
(
out
=
fld
),
apply_func
,
batchable
=
True
)
fn
.
sum
(
out
=
fld
),
apply_func
)
v3
=
g
.
get_n_repr
()[
fld
]
v3
=
g
.
get_n_repr
()[
fld
]
g
.
set_n_repr
({
fld
:
v1
})
g
.
set_n_repr
({
fld
:
v1
})
g
.
send_and_recv
(
u
,
v
,
message_func_edge
,
g
.
send_and_recv
(
u
,
v
,
message_func_edge
,
reduce_func
,
apply_func
,
batchable
=
True
)
reduce_func
,
apply_func
)
v4
=
g
.
get_n_repr
()[
fld
]
v4
=
g
.
get_n_repr
()[
fld
]
assert
th
.
allclose
(
v2
,
v3
)
assert
th
.
allclose
(
v2
,
v3
)
assert
th
.
allclose
(
v3
,
v4
)
assert
th
.
allclose
(
v3
,
v4
)
...
@@ -127,18 +127,18 @@ def test_update_all_multi_fn():
...
@@ -127,18 +127,18 @@ def test_update_all_multi_fn():
# update all, mix of builtin and UDF
# update all, mix of builtin and UDF
g
.
update_all
([
fn
.
copy_src
(
src
=
fld
,
out
=
'm1'
),
message_func
],
g
.
update_all
([
fn
.
copy_src
(
src
=
fld
,
out
=
'm1'
),
message_func
],
[
fn
.
sum
(
msgs
=
'm1'
,
out
=
'v1'
),
reduce_func
],
[
fn
.
sum
(
msgs
=
'm1'
,
out
=
'v1'
),
reduce_func
],
None
,
batchable
=
True
)
None
)
v1
=
g
.
get_n_repr
()[
'v1'
]
v1
=
g
.
get_n_repr
()[
'v1'
]
v2
=
g
.
get_n_repr
()[
'v2'
]
v2
=
g
.
get_n_repr
()[
'v2'
]
assert
th
.
allclose
(
v1
,
v2
)
assert
th
.
allclose
(
v1
,
v2
)
# run builtin with single message and reduce
# run builtin with single message and reduce
g
.
update_all
(
fn
.
copy_src
(
src
=
fld
),
fn
.
sum
(
out
=
'v1'
),
None
,
batchable
=
True
)
g
.
update_all
(
fn
.
copy_src
(
src
=
fld
),
fn
.
sum
(
out
=
'v1'
),
None
)
v1
=
g
.
get_n_repr
()[
'v1'
]
v1
=
g
.
get_n_repr
()[
'v1'
]
assert
th
.
allclose
(
v1
,
v2
)
assert
th
.
allclose
(
v1
,
v2
)
# 1 message, 2 reduces, using anonymous repr
# 1 message, 2 reduces, using anonymous repr
g
.
update_all
(
fn
.
copy_src
(
src
=
fld
),
[
fn
.
sum
(
out
=
'v2'
),
fn
.
sum
(
out
=
'v3'
)],
None
,
batchable
=
True
)
g
.
update_all
(
fn
.
copy_src
(
src
=
fld
),
[
fn
.
sum
(
out
=
'v2'
),
fn
.
sum
(
out
=
'v3'
)],
None
)
v2
=
g
.
get_n_repr
()[
'v2'
]
v2
=
g
.
get_n_repr
()[
'v2'
]
v3
=
g
.
get_n_repr
()[
'v3'
]
v3
=
g
.
get_n_repr
()[
'v3'
]
assert
th
.
allclose
(
v1
,
v2
)
assert
th
.
allclose
(
v1
,
v2
)
...
@@ -147,7 +147,7 @@ def test_update_all_multi_fn():
...
@@ -147,7 +147,7 @@ def test_update_all_multi_fn():
# update all with edge weights, 2 message, 3 reduces
# update all with edge weights, 2 message, 3 reduces
g
.
update_all
([
fn
.
src_mul_edge
(
src
=
fld
,
edge
=
'e1'
,
out
=
'm1'
),
fn
.
src_mul_edge
(
src
=
fld
,
edge
=
'e2'
,
out
=
'm2'
)],
g
.
update_all
([
fn
.
src_mul_edge
(
src
=
fld
,
edge
=
'e1'
,
out
=
'm1'
),
fn
.
src_mul_edge
(
src
=
fld
,
edge
=
'e2'
,
out
=
'm2'
)],
[
fn
.
sum
(
msgs
=
'm1'
,
out
=
'v1'
),
fn
.
sum
(
msgs
=
'm2'
,
out
=
'v2'
),
fn
.
sum
(
msgs
=
'm1'
,
out
=
'v3'
)],
[
fn
.
sum
(
msgs
=
'm1'
,
out
=
'v1'
),
fn
.
sum
(
msgs
=
'm2'
,
out
=
'v2'
),
fn
.
sum
(
msgs
=
'm1'
,
out
=
'v3'
)],
None
,
batchable
=
True
)
None
)
v1
=
g
.
get_n_repr
()[
'v1'
]
v1
=
g
.
get_n_repr
()[
'v1'
]
v2
=
g
.
get_n_repr
()[
'v2'
]
v2
=
g
.
get_n_repr
()[
'v2'
]
v3
=
g
.
get_n_repr
()[
'v3'
]
v3
=
g
.
get_n_repr
()[
'v3'
]
...
@@ -155,7 +155,7 @@ def test_update_all_multi_fn():
...
@@ -155,7 +155,7 @@ def test_update_all_multi_fn():
assert
th
.
allclose
(
v1
,
v3
)
assert
th
.
allclose
(
v1
,
v3
)
# run UDF with single message and reduce
# run UDF with single message and reduce
g
.
update_all
(
message_func_edge
,
reduce_func
,
None
,
batchable
=
True
)
g
.
update_all
(
message_func_edge
,
reduce_func
,
None
)
v2
=
g
.
get_n_repr
()[
'v2'
]
v2
=
g
.
get_n_repr
()[
'v2'
]
assert
th
.
allclose
(
v1
,
v2
)
assert
th
.
allclose
(
v1
,
v2
)
...
@@ -179,19 +179,19 @@ def test_send_and_recv_multi_fn():
...
@@ -179,19 +179,19 @@ def test_send_and_recv_multi_fn():
g
.
send_and_recv
(
u
,
v
,
g
.
send_and_recv
(
u
,
v
,
[
fn
.
copy_src
(
src
=
fld
,
out
=
'm1'
),
message_func
],
[
fn
.
copy_src
(
src
=
fld
,
out
=
'm1'
),
message_func
],
[
fn
.
sum
(
msgs
=
'm1'
,
out
=
'v1'
),
reduce_func
],
[
fn
.
sum
(
msgs
=
'm1'
,
out
=
'v1'
),
reduce_func
],
None
,
batchable
=
True
)
None
)
v1
=
g
.
get_n_repr
()[
'v1'
]
v1
=
g
.
get_n_repr
()[
'v1'
]
v2
=
g
.
get_n_repr
()[
'v2'
]
v2
=
g
.
get_n_repr
()[
'v2'
]
assert
th
.
allclose
(
v1
,
v2
)
assert
th
.
allclose
(
v1
,
v2
)
# run builtin with single message and reduce
# run builtin with single message and reduce
g
.
send_and_recv
(
u
,
v
,
fn
.
copy_src
(
src
=
fld
),
fn
.
sum
(
out
=
'v1'
),
g
.
send_and_recv
(
u
,
v
,
fn
.
copy_src
(
src
=
fld
),
fn
.
sum
(
out
=
'v1'
),
None
,
batchable
=
True
)
None
)
v1
=
g
.
get_n_repr
()[
'v1'
]
v1
=
g
.
get_n_repr
()[
'v1'
]
assert
th
.
allclose
(
v1
,
v2
)
assert
th
.
allclose
(
v1
,
v2
)
# 1 message, 2 reduces, using anonymous repr
# 1 message, 2 reduces, using anonymous repr
g
.
send_and_recv
(
u
,
v
,
fn
.
copy_src
(
src
=
fld
),
[
fn
.
sum
(
out
=
'v2'
),
fn
.
sum
(
out
=
'v3'
)],
None
,
batchable
=
True
)
g
.
send_and_recv
(
u
,
v
,
fn
.
copy_src
(
src
=
fld
),
[
fn
.
sum
(
out
=
'v2'
),
fn
.
sum
(
out
=
'v3'
)],
None
)
v2
=
g
.
get_n_repr
()[
'v2'
]
v2
=
g
.
get_n_repr
()[
'v2'
]
v3
=
g
.
get_n_repr
()[
'v3'
]
v3
=
g
.
get_n_repr
()[
'v3'
]
assert
th
.
allclose
(
v1
,
v2
)
assert
th
.
allclose
(
v1
,
v2
)
...
@@ -201,7 +201,7 @@ def test_send_and_recv_multi_fn():
...
@@ -201,7 +201,7 @@ def test_send_and_recv_multi_fn():
g
.
send_and_recv
(
u
,
v
,
g
.
send_and_recv
(
u
,
v
,
[
fn
.
src_mul_edge
(
src
=
fld
,
edge
=
'e1'
,
out
=
'm1'
),
fn
.
src_mul_edge
(
src
=
fld
,
edge
=
'e2'
,
out
=
'm2'
)],
[
fn
.
src_mul_edge
(
src
=
fld
,
edge
=
'e1'
,
out
=
'm1'
),
fn
.
src_mul_edge
(
src
=
fld
,
edge
=
'e2'
,
out
=
'm2'
)],
[
fn
.
sum
(
msgs
=
'm1'
,
out
=
'v1'
),
fn
.
sum
(
msgs
=
'm2'
,
out
=
'v2'
),
fn
.
sum
(
msgs
=
'm1'
,
out
=
'v3'
)],
[
fn
.
sum
(
msgs
=
'm1'
,
out
=
'v1'
),
fn
.
sum
(
msgs
=
'm2'
,
out
=
'v2'
),
fn
.
sum
(
msgs
=
'm1'
,
out
=
'v3'
)],
None
,
batchable
=
True
)
None
)
v1
=
g
.
get_n_repr
()[
'v1'
]
v1
=
g
.
get_n_repr
()[
'v1'
]
v2
=
g
.
get_n_repr
()[
'v2'
]
v2
=
g
.
get_n_repr
()[
'v2'
]
v3
=
g
.
get_n_repr
()[
'v3'
]
v3
=
g
.
get_n_repr
()[
'v3'
]
...
@@ -210,7 +210,7 @@ def test_send_and_recv_multi_fn():
...
@@ -210,7 +210,7 @@ def test_send_and_recv_multi_fn():
# run UDF with single message and reduce
# run UDF with single message and reduce
g
.
send_and_recv
(
u
,
v
,
message_func_edge
,
g
.
send_and_recv
(
u
,
v
,
message_func_edge
,
reduce_func
,
None
,
batchable
=
True
)
reduce_func
,
None
)
v2
=
g
.
get_n_repr
()[
'v2'
]
v2
=
g
.
get_n_repr
()[
'v2'
]
assert
th
.
allclose
(
v1
,
v2
)
assert
th
.
allclose
(
v1
,
v2
)
...
...
tests/test_anonymous_repr.py
deleted
100644 → 0
View file @
3a3e5d48
from
dgl
import
DGLGraph
from
dgl.graph
import
__REPR__
def
message_func
(
hu
,
e_uv
):
return
hu
+
e_uv
def
reduce_func
(
h
,
msgs
):
return
h
+
sum
(
msgs
)
def
generate_graph
():
g
=
DGLGraph
()
for
i
in
range
(
10
):
g
.
add_node
(
i
,
__REPR__
=
i
+
1
)
# 10 nodes.
# create a graph where 0 is the source and 9 is the sink
for
i
in
range
(
1
,
9
):
g
.
add_edge
(
0
,
i
,
__REPR__
=
1
)
g
.
add_edge
(
i
,
9
,
__REPR__
=
1
)
# add a back flow from 9 to 0
g
.
add_edge
(
9
,
0
)
return
g
def
check
(
g
,
h
):
nh
=
[
str
(
g
.
nodes
[
i
][
__REPR__
])
for
i
in
range
(
10
)]
h
=
[
str
(
x
)
for
x
in
h
]
assert
nh
==
h
,
"nh=[%s], h=[%s]"
%
(
' '
.
join
(
nh
),
' '
.
join
(
h
))
def
test_sendrecv
():
g
=
generate_graph
()
check
(
g
,
[
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
])
g
.
register_message_func
(
message_func
)
g
.
register_reduce_func
(
reduce_func
)
g
.
send
(
0
,
1
)
g
.
recv
(
1
)
check
(
g
,
[
1
,
4
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
])
g
.
send
(
5
,
9
)
g
.
send
(
6
,
9
)
g
.
recv
(
9
)
check
(
g
,
[
1
,
4
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
25
])
def
message_func_hybrid
(
src
,
edge
):
return
src
[
__REPR__
]
+
edge
def
reduce_func_hybrid
(
node
,
msgs
):
return
node
[
__REPR__
]
+
sum
(
msgs
)
def
test_hybridrepr
():
g
=
generate_graph
()
for
i
in
range
(
10
):
g
.
nodes
[
i
][
'id'
]
=
-
i
g
.
register_message_func
(
message_func_hybrid
)
g
.
register_reduce_func
(
reduce_func_hybrid
)
g
.
send
(
0
,
1
)
g
.
recv
(
1
)
check
(
g
,
[
1
,
4
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
])
g
.
send
(
5
,
9
)
g
.
send
(
6
,
9
)
g
.
recv
(
9
)
check
(
g
,
[
1
,
4
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
25
])
if
__name__
==
'__main__'
:
test_sendrecv
()
test_hybridrepr
()
tests/test_basics.py
deleted
100644 → 0
View file @
3a3e5d48
from
dgl.graph
import
DGLGraph
def
message_func
(
src
,
edge
):
return
src
[
'h'
]
def
reduce_func
(
node
,
msgs
):
return
{
'm'
:
sum
(
msgs
)}
def
apply_func
(
node
):
return
{
'h'
:
node
[
'h'
]
+
node
[
'm'
]}
def
message_dict_func
(
src
,
edge
):
return
{
'm'
:
src
[
'h'
]}
def
reduce_dict_func
(
node
,
msgs
):
return
{
'm'
:
sum
([
msg
[
'm'
]
for
msg
in
msgs
])}
def
apply_dict_func
(
node
):
return
{
'h'
:
node
[
'h'
]
+
node
[
'm'
]}
def
generate_graph
():
g
=
DGLGraph
()
for
i
in
range
(
10
):
g
.
add_node
(
i
,
h
=
i
+
1
)
# 10 nodes.
# create a graph where 0 is the source and 9 is the sink
for
i
in
range
(
1
,
9
):
g
.
add_edge
(
0
,
i
)
g
.
add_edge
(
i
,
9
)
# add a back flow from 9 to 0
g
.
add_edge
(
9
,
0
)
return
g
def
check
(
g
,
h
):
nh
=
[
str
(
g
.
nodes
[
i
][
'h'
])
for
i
in
range
(
10
)]
h
=
[
str
(
x
)
for
x
in
h
]
assert
nh
==
h
,
"nh=[%s], h=[%s]"
%
(
' '
.
join
(
nh
),
' '
.
join
(
h
))
def
register1
(
g
):
g
.
register_message_func
(
message_func
)
g
.
register_reduce_func
(
reduce_func
)
g
.
register_apply_node_func
(
apply_func
)
def
register2
(
g
):
g
.
register_message_func
(
message_dict_func
)
g
.
register_reduce_func
(
reduce_dict_func
)
g
.
register_apply_node_func
(
apply_dict_func
)
def
_test_sendrecv
(
g
):
check
(
g
,
[
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
])
g
.
send
(
0
,
1
)
g
.
recv
(
1
)
check
(
g
,
[
1
,
3
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
])
g
.
send
(
5
,
9
)
g
.
send
(
6
,
9
)
g
.
recv
(
9
)
check
(
g
,
[
1
,
3
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
23
])
def
_test_multi_sendrecv
(
g
):
check
(
g
,
[
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
])
# one-many
g
.
send
(
0
,
[
1
,
2
,
3
])
g
.
recv
([
1
,
2
,
3
])
check
(
g
,
[
1
,
3
,
4
,
5
,
5
,
6
,
7
,
8
,
9
,
10
])
# many-one
g
.
send
([
6
,
7
,
8
],
9
)
g
.
recv
(
9
)
check
(
g
,
[
1
,
3
,
4
,
5
,
5
,
6
,
7
,
8
,
9
,
34
])
# many-many
g
.
send
([
0
,
0
,
4
,
5
],
[
4
,
5
,
9
,
9
])
g
.
recv
([
4
,
5
,
9
])
check
(
g
,
[
1
,
3
,
4
,
5
,
6
,
7
,
7
,
8
,
9
,
45
])
def
_test_update_routines
(
g
):
check
(
g
,
[
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
])
g
.
send_and_recv
(
0
,
1
)
check
(
g
,
[
1
,
3
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
])
g
.
pull
(
9
)
check
(
g
,
[
1
,
3
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
55
])
g
.
push
(
0
)
check
(
g
,
[
1
,
4
,
4
,
5
,
6
,
7
,
8
,
9
,
10
,
55
])
g
.
update_all
()
check
(
g
,
[
56
,
5
,
5
,
6
,
7
,
8
,
9
,
10
,
11
,
108
])
def
test_sendrecv
():
g
=
generate_graph
()
register1
(
g
)
_test_sendrecv
(
g
)
g
=
generate_graph
()
register2
(
g
)
_test_sendrecv
(
g
)
def
test_multi_sendrecv
():
g
=
generate_graph
()
register1
(
g
)
_test_multi_sendrecv
(
g
)
g
=
generate_graph
()
register2
(
g
)
_test_multi_sendrecv
(
g
)
def
test_update_routines
():
g
=
generate_graph
()
register1
(
g
)
_test_update_routines
(
g
)
g
=
generate_graph
()
register2
(
g
)
_test_update_routines
(
g
)
if
__name__
==
'__main__'
:
test_sendrecv
()
test_multi_sendrecv
()
test_update_routines
()
tests/test_basics2.py
deleted
100644 → 0
View file @
3a3e5d48
from
dgl
import
DGLGraph
from
dgl.graph
import
__REPR__
def
message_func
(
hu
,
e_uv
):
return
hu
def
message_not_called
(
hu
,
e_uv
):
assert
False
return
hu
def
reduce_not_called
(
h
,
msgs
):
assert
False
return
0
def
reduce_func
(
h
,
msgs
):
return
h
+
sum
(
msgs
)
def
check
(
g
,
h
):
nh
=
[
str
(
g
.
nodes
[
i
][
__REPR__
])
for
i
in
range
(
10
)]
h
=
[
str
(
x
)
for
x
in
h
]
assert
nh
==
h
,
"nh=[%s], h=[%s]"
%
(
' '
.
join
(
nh
),
' '
.
join
(
h
))
def
generate_graph
():
g
=
DGLGraph
()
for
i
in
range
(
10
):
g
.
add_node
(
i
,
__REPR__
=
i
+
1
)
# 10 nodes.
# create a graph where 0 is the source and 9 is the sink
for
i
in
range
(
1
,
9
):
g
.
add_edge
(
0
,
i
)
g
.
add_edge
(
i
,
9
)
return
g
def
test_no_msg_recv
():
g
=
generate_graph
()
check
(
g
,
[
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
])
g
.
register_message_func
(
message_not_called
)
g
.
register_reduce_func
(
reduce_not_called
)
g
.
register_apply_node_func
(
lambda
h
:
h
+
1
)
for
i
in
range
(
10
):
g
.
recv
(
i
)
check
(
g
,
[
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
,
11
])
def
test_double_recv
():
g
=
generate_graph
()
check
(
g
,
[
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
])
g
.
register_message_func
(
message_func
)
g
.
register_reduce_func
(
reduce_func
)
g
.
send
(
1
,
9
)
g
.
send
(
2
,
9
)
g
.
recv
(
9
)
check
(
g
,
[
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
15
])
g
.
register_reduce_func
(
reduce_not_called
)
g
.
recv
(
9
)
def
test_pull_0deg
():
g
=
DGLGraph
()
g
.
add_node
(
0
,
h
=
2
)
g
.
add_node
(
1
,
h
=
1
)
g
.
add_edge
(
0
,
1
)
def
_message
(
src
,
edge
):
assert
False
return
src
def
_reduce
(
node
,
msgs
):
assert
False
return
node
def
_update
(
node
):
return
{
'h'
:
node
[
'h'
]
*
2
}
g
.
pull
(
0
,
_message
,
_reduce
,
_update
)
assert
g
.
nodes
[
0
][
'h'
]
==
4
if
__name__
==
'__main__'
:
test_no_msg_recv
()
test_double_recv
()
test_pull_0deg
()
tests/test_function.py
deleted
100644 → 0
View file @
3a3e5d48
import
dgl
import
dgl.function
as
fn
from
dgl.graph
import
__REPR__
def
generate_graph
():
g
=
dgl
.
DGLGraph
()
for
i
in
range
(
10
):
g
.
add_node
(
i
,
h
=
i
+
1
)
# 10 nodes.
# create a graph where 0 is the source and 9 is the sink
for
i
in
range
(
1
,
9
):
g
.
add_edge
(
0
,
i
,
h
=
1
)
g
.
add_edge
(
i
,
9
,
h
=
i
+
1
)
# add a back flow from 9 to 0
g
.
add_edge
(
9
,
0
,
h
=
10
)
return
g
def
check
(
g
,
h
,
fld
):
nh
=
[
str
(
g
.
nodes
[
i
][
fld
])
for
i
in
range
(
10
)]
h
=
[
str
(
x
)
for
x
in
h
]
assert
nh
==
h
,
"nh=[%s], h=[%s]"
%
(
' '
.
join
(
nh
),
' '
.
join
(
h
))
def
generate_graph1
():
"""graph with anonymous repr"""
g
=
dgl
.
DGLGraph
()
for
i
in
range
(
10
):
g
.
add_node
(
i
,
__REPR__
=
i
+
1
)
# 10 nodes.
# create a graph where 0 is the source and 9 is the sink
for
i
in
range
(
1
,
9
):
g
.
add_edge
(
0
,
i
,
__REPR__
=
1
)
g
.
add_edge
(
i
,
9
,
__REPR__
=
i
+
1
)
# add a back flow from 9 to 0
g
.
add_edge
(
9
,
0
,
__REPR__
=
10
)
return
g
def
test_copy_src
():
# copy_src with both fields
g
=
generate_graph
()
g
.
register_message_func
(
fn
.
copy_src
(
src
=
'h'
,
out
=
'm'
),
batchable
=
False
)
g
.
register_reduce_func
(
fn
.
sum
(
msgs
=
'm'
,
out
=
'h'
),
batchable
=
False
)
check
(
g
,
[
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
],
'h'
)
g
.
update_all
()
check
(
g
,
[
10
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
44
],
'h'
)
# copy_src with only src field; the out field should use anonymous repr
g
=
generate_graph
()
g
.
register_message_func
(
fn
.
copy_src
(
src
=
'h'
),
batchable
=
False
)
g
.
register_reduce_func
(
fn
.
sum
(
out
=
'h'
),
batchable
=
False
)
check
(
g
,
[
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
],
'h'
)
g
.
update_all
()
check
(
g
,
[
10
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
44
],
'h'
)
# copy_src with no src field; should use anonymous repr
g
=
generate_graph1
()
g
.
register_message_func
(
fn
.
copy_src
(
out
=
'm'
),
batchable
=
False
)
g
.
register_reduce_func
(
fn
.
sum
(
msgs
=
'm'
,
out
=
'h'
),
batchable
=
False
)
check
(
g
,
[
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
],
__REPR__
)
g
.
update_all
()
check
(
g
,
[
10
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
44
],
'h'
)
# copy src with no fields;
g
=
generate_graph1
()
g
.
register_message_func
(
fn
.
copy_src
(),
batchable
=
False
)
g
.
register_reduce_func
(
fn
.
sum
(
out
=
'h'
),
batchable
=
False
)
check
(
g
,
[
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
],
__REPR__
)
g
.
update_all
()
check
(
g
,
[
10
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
44
],
'h'
)
def
test_copy_edge
():
# copy_edge with both fields
g
=
generate_graph
()
g
.
register_message_func
(
fn
.
copy_edge
(
edge
=
'h'
,
out
=
'm'
),
batchable
=
False
)
g
.
register_reduce_func
(
fn
.
sum
(
msgs
=
'm'
,
out
=
'h'
),
batchable
=
False
)
check
(
g
,
[
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
],
'h'
)
g
.
update_all
()
check
(
g
,
[
10
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
44
],
'h'
)
# copy_edge with only edge field; the out field should use anonymous repr
g
=
generate_graph
()
g
.
register_message_func
(
fn
.
copy_edge
(
edge
=
'h'
),
batchable
=
False
)
g
.
register_reduce_func
(
fn
.
sum
(
out
=
'h'
),
batchable
=
False
)
check
(
g
,
[
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
],
'h'
)
g
.
update_all
()
check
(
g
,
[
10
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
44
],
'h'
)
# copy_edge with no edge field; should use anonymous repr
g
=
generate_graph1
()
g
.
register_message_func
(
fn
.
copy_edge
(
out
=
'm'
),
batchable
=
False
)
g
.
register_reduce_func
(
fn
.
sum
(
msgs
=
'm'
,
out
=
'h'
),
batchable
=
False
)
check
(
g
,
[
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
],
__REPR__
)
g
.
update_all
()
check
(
g
,
[
10
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
44
],
'h'
)
# copy edge with no fields;
g
=
generate_graph1
()
g
.
register_message_func
(
fn
.
copy_edge
(),
batchable
=
False
)
g
.
register_reduce_func
(
fn
.
sum
(
out
=
'h'
),
batchable
=
False
)
check
(
g
,
[
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
],
__REPR__
)
g
.
update_all
()
check
(
g
,
[
10
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
44
],
'h'
)
def
test_src_mul_edge
():
# src_mul_edge with all fields
g
=
generate_graph
()
g
.
register_message_func
(
fn
.
src_mul_edge
(
src
=
'h'
,
edge
=
'h'
,
out
=
'm'
),
batchable
=
False
)
g
.
register_reduce_func
(
fn
.
sum
(
msgs
=
'm'
,
out
=
'h'
),
batchable
=
False
)
check
(
g
,
[
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
],
'h'
)
g
.
update_all
()
check
(
g
,
[
100
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
284
],
'h'
)
g
=
generate_graph
()
g
.
register_message_func
(
fn
.
src_mul_edge
(
src
=
'h'
,
edge
=
'h'
),
batchable
=
False
)
g
.
register_reduce_func
(
fn
.
sum
(
out
=
'h'
),
batchable
=
False
)
check
(
g
,
[
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
],
'h'
)
g
.
update_all
()
check
(
g
,
[
100
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
284
],
'h'
)
g
=
generate_graph1
()
g
.
register_message_func
(
fn
.
src_mul_edge
(
out
=
'm'
),
batchable
=
False
)
g
.
register_reduce_func
(
fn
.
sum
(
msgs
=
'm'
,
out
=
'h'
),
batchable
=
False
)
check
(
g
,
[
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
],
__REPR__
)
g
.
update_all
()
check
(
g
,
[
100
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
284
],
'h'
)
g
=
generate_graph1
()
g
.
register_message_func
(
fn
.
src_mul_edge
(),
batchable
=
False
)
g
.
register_reduce_func
(
fn
.
sum
(
out
=
'h'
),
batchable
=
False
)
check
(
g
,
[
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
],
__REPR__
)
g
.
update_all
()
check
(
g
,
[
100
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
284
],
'h'
)
g
=
generate_graph1
()
g
.
register_message_func
(
fn
.
src_mul_edge
(),
batchable
=
False
)
g
.
register_reduce_func
(
fn
.
sum
(),
batchable
=
False
)
check
(
g
,
[
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
],
__REPR__
)
g
.
update_all
()
check
(
g
,
[
100
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
284
],
__REPR__
)
if
__name__
==
'__main__'
:
test_copy_src
()
test_copy_edge
()
test_src_mul_edge
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment