Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
7d04c8c9
Commit
7d04c8c9
authored
Oct 03, 2018
by
Minjie Wang
Browse files
remove nonbatchable mode
parent
3a3e5d48
Changes
10
Show whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
122 additions
and
644 deletions
+122
-644
python/dgl/graph.py
python/dgl/graph.py
+56
-190
tests/pytorch/test_batching.py
tests/pytorch/test_batching.py
+11
-11
tests/pytorch/test_batching_anonymous.py
tests/pytorch/test_batching_anonymous.py
+5
-5
tests/pytorch/test_function.py
tests/pytorch/test_function.py
+26
-26
tests/pytorch/test_graph_batch.py
tests/pytorch/test_graph_batch.py
+4
-4
tests/pytorch/test_specialization.py
tests/pytorch/test_specialization.py
+20
-20
tests/test_anonymous_repr.py
tests/test_anonymous_repr.py
+0
-62
tests/test_basics.py
tests/test_basics.py
+0
-111
tests/test_basics2.py
tests/test_basics2.py
+0
-74
tests/test_function.py
tests/test_function.py
+0
-141
No files found.
python/dgl/graph.py
View file @
7d04c8c9
...
...
@@ -50,11 +50,11 @@ class DGLGraph(object):
self
.
_msg_frame
=
FrameRef
()
self
.
reset_messages
()
# registered functions
self
.
_message_func
=
(
None
,
None
)
self
.
_reduce_func
=
(
None
,
None
)
self
.
_edge_func
=
(
None
,
None
)
self
.
_apply_node_func
=
(
None
,
None
)
self
.
_apply_edge_func
=
(
None
,
None
)
self
.
_message_func
=
None
self
.
_reduce_func
=
None
self
.
_edge_func
=
None
self
.
_apply_node_func
=
None
self
.
_apply_edge_func
=
None
def
add_nodes
(
self
,
num
,
reprs
=
None
):
"""Add nodes.
...
...
@@ -710,77 +710,57 @@ class DGLGraph(object):
else
:
return
self
.
_edge_frame
.
select_rows
(
eid
)
def
register_edge_func
(
self
,
edge_func
,
batchable
=
False
):
def
register_edge_func
(
self
,
edge_func
):
"""Register global edge update function.
Parameters
----------
edge_func : callable
Message function on the edge.
batchable : bool
Whether the provided message function allows batch computing.
"""
self
.
_edge_func
=
(
edge_func
,
batchable
)
self
.
_edge_func
=
edge_func
def
register_message_func
(
self
,
message_func
,
batchable
=
False
):
def
register_message_func
(
self
,
message_func
):
"""Register global message function.
Parameters
----------
message_func : callable
Message function on the edge.
batchable : bool
Whether the provided message function allows batch computing.
"""
self
.
_message_func
=
(
message_func
,
batchable
)
self
.
_message_func
=
message_func
def
register_reduce_func
(
self
,
reduce_func
,
batchable
=
False
):
def
register_reduce_func
(
self
,
reduce_func
):
"""Register global message reduce function.
Parameters
----------
reduce_func : str or callable
Reduce function on incoming edges.
batchable : bool
Whether the provided reduce function allows batch computing.
"""
self
.
_reduce_func
=
(
reduce_func
,
batchable
)
self
.
_reduce_func
=
reduce_func
def
register_apply_node_func
(
self
,
apply_node_func
,
batchable
=
False
):
def
register_apply_node_func
(
self
,
apply_node_func
):
"""Register global node apply function.
Parameters
----------
apply_node_func : callable
Apply function on the node.
batchable : bool
Whether the provided function allows batch computing.
"""
self
.
_apply_node_func
=
(
apply_node_func
,
batchable
)
self
.
_apply_node_func
=
apply_node_func
def
register_apply_edge_func
(
self
,
apply_edge_func
,
batchable
=
False
):
def
register_apply_edge_func
(
self
,
apply_edge_func
):
"""Register global edge apply function.
Parameters
----------
apply_edge_func : callable
Apply function on the edge.
batchable : bool
Whether the provided function allows batch computing.
"""
self
.
_apply_edge_func
=
(
apply_edge_func
,
batchable
)
self
.
_apply_edge_func
=
apply_edge_func
def
apply_nodes
(
self
,
v
,
apply_node_func
=
"default"
,
batchable
=
False
):
def
apply_nodes
(
self
,
v
,
apply_node_func
=
"default"
):
"""Apply the function on node representations.
Parameters
...
...
@@ -789,27 +769,16 @@ class DGLGraph(object):
The node id(s).
apply_node_func : callable
The apply node function.
batchable : bool
Whether the provided function allows batch computing.
"""
if
apply_node_func
==
"default"
:
apply_node_func
,
batchable
=
self
.
_apply_node_func
apply_node_func
=
self
.
_apply_node_func
if
not
apply_node_func
:
# Skip none function call.
return
if
batchable
:
new_repr
=
apply_node_func
(
self
.
get_n_repr
(
v
))
self
.
set_n_repr
(
new_repr
,
v
)
else
:
raise
RuntimeError
(
'Disabled'
)
if
is_all
(
v
):
v
=
self
.
nodes
()
v
=
utils
.
toindex
(
v
)
for
vv
in
utils
.
node_iter
(
v
):
ret
=
apply_node_func
(
_get_repr
(
self
.
nodes
[
vv
]))
_set_repr
(
self
.
nodes
[
vv
],
ret
)
def
apply_edges
(
self
,
u
,
v
,
apply_edge_func
=
"default"
,
batchable
=
False
):
def
apply_edges
(
self
,
u
,
v
,
apply_edge_func
=
"default"
):
"""Apply the function on edge representations.
Parameters
...
...
@@ -820,27 +789,16 @@ class DGLGraph(object):
The dst node id(s).
apply_edge_func : callable
The apply edge function.
batchable : bool
Whether the provided function allows batch computing.
"""
if
apply_edge_func
==
"default"
:
apply_edge_func
,
batchable
=
self
.
_apply_edge_func
apply_edge_func
=
self
.
_apply_edge_func
if
not
apply_edge_func
:
# Skip none function call.
return
if
batchable
:
new_repr
=
apply_edge_func
(
self
.
get_e_repr
(
u
,
v
))
self
.
set_e_repr
(
new_repr
,
u
,
v
)
else
:
if
is_all
(
u
)
==
is_all
(
v
):
u
,
v
=
zip
(
*
self
.
edges
)
u
=
utils
.
toindex
(
u
)
v
=
utils
.
toindex
(
v
)
for
uu
,
vv
in
utils
.
edge_iter
(
u
,
v
):
ret
=
apply_edge_func
(
_get_repr
(
self
.
edges
[
uu
,
vv
]))
_set_repr
(
self
.
edges
[
uu
,
vv
],
ret
)
def
send
(
self
,
u
,
v
,
message_func
=
"default"
,
batchable
=
False
):
def
send
(
self
,
u
,
v
,
message_func
=
"default"
):
"""Trigger the message function on edge u->v
The message function should be compatible with following signature:
...
...
@@ -861,30 +819,13 @@ class DGLGraph(object):
The destination node(s).
message_func : callable
The message function.
batchable : bool
Whether the function allows batched computation.
"""
if
message_func
==
"default"
:
message_func
,
batchable
=
self
.
_message_func
message_func
=
self
.
_message_func
assert
message_func
is
not
None
if
isinstance
(
message_func
,
(
tuple
,
list
)):
message_func
=
BundledMessageFunction
(
message_func
)
if
batchable
:
self
.
_batch_send
(
u
,
v
,
message_func
)
else
:
self
.
_nonbatch_send
(
u
,
v
,
message_func
)
def
_nonbatch_send
(
self
,
u
,
v
,
message_func
):
raise
RuntimeError
(
'Disabled'
)
if
is_all
(
u
)
and
is_all
(
v
):
u
,
v
=
self
.
cached_graph
.
edges
()
else
:
u
=
utils
.
toindex
(
u
)
v
=
utils
.
toindex
(
v
)
for
uu
,
vv
in
utils
.
edge_iter
(
u
,
v
):
ret
=
message_func
(
_get_repr
(
self
.
nodes
[
uu
]),
_get_repr
(
self
.
edges
[
uu
,
vv
]))
self
.
edges
[
uu
,
vv
][
__MSG__
]
=
ret
def
_batch_send
(
self
,
u
,
v
,
message_func
):
if
is_all
(
u
)
and
is_all
(
v
):
...
...
@@ -908,7 +849,7 @@ class DGLGraph(object):
else
:
self
.
_msg_frame
.
append
({
__MSG__
:
msgs
})
def
update_edge
(
self
,
u
=
ALL
,
v
=
ALL
,
edge_func
=
"default"
,
batchable
=
False
):
def
update_edge
(
self
,
u
=
ALL
,
v
=
ALL
,
edge_func
=
"default"
):
"""Update representation on edge u->v
The edge function should be compatible with following signature:
...
...
@@ -927,29 +868,11 @@ class DGLGraph(object):
The destination node(s).
edge_func : callable
The update function.
batchable : bool
Whether the function allows batched computation.
"""
if
edge_func
==
"default"
:
edge_func
,
batchable
=
self
.
_edge_func
edge_func
=
self
.
_edge_func
assert
edge_func
is
not
None
if
batchable
:
self
.
_batch_update_edge
(
u
,
v
,
edge_func
)
else
:
self
.
_nonbatch_update_edge
(
u
,
v
,
edge_func
)
def
_nonbatch_update_edge
(
self
,
u
,
v
,
edge_func
):
raise
RuntimeError
(
'Disabled'
)
if
is_all
(
u
)
and
is_all
(
v
):
u
,
v
=
self
.
cached_graph
.
edges
()
else
:
u
=
utils
.
toindex
(
u
)
v
=
utils
.
toindex
(
v
)
for
uu
,
vv
in
utils
.
edge_iter
(
u
,
v
):
ret
=
edge_func
(
_get_repr
(
self
.
nodes
[
uu
]),
_get_repr
(
self
.
nodes
[
vv
]),
_get_repr
(
self
.
edges
[
uu
,
vv
]))
_set_repr
(
self
.
edges
[
uu
,
vv
],
ret
)
def
_batch_update_edge
(
self
,
u
,
v
,
edge_func
):
if
is_all
(
u
)
and
is_all
(
v
):
...
...
@@ -975,8 +898,7 @@ class DGLGraph(object):
def
recv
(
self
,
u
,
reduce_func
=
"default"
,
apply_node_func
=
"default"
,
batchable
=
False
):
apply_node_func
=
"default"
):
"""Receive and reduce in-coming messages and update representation on node u.
It computes the new node state using the messages sent from the predecessors
...
...
@@ -1006,34 +928,15 @@ class DGLGraph(object):
The reduce function.
apply_node_func : callable, optional
The update function.
batchable : bool, optional
Whether the reduce and update function allows batched computation.
"""
if
reduce_func
==
"default"
:
reduce_func
,
batchable
=
self
.
_reduce_func
reduce_func
=
self
.
_reduce_func
assert
reduce_func
is
not
None
if
isinstance
(
reduce_func
,
(
list
,
tuple
)):
reduce_func
=
BundledReduceFunction
(
reduce_func
)
if
batchable
:
self
.
_batch_recv
(
u
,
reduce_func
)
else
:
self
.
_nonbatch_recv
(
u
,
reduce_func
)
# optional apply nodes
self
.
apply_nodes
(
u
,
apply_node_func
,
batchable
)
def
_nonbatch_recv
(
self
,
u
,
reduce_func
):
raise
RuntimeError
(
'Disabled'
)
if
is_all
(
u
):
u
=
list
(
range
(
0
,
self
.
number_of_nodes
()))
else
:
u
=
utils
.
toindex
(
u
)
for
i
,
uu
in
enumerate
(
utils
.
node_iter
(
u
)):
# reduce phase
msgs_batch
=
[
self
.
edges
[
vv
,
uu
].
pop
(
__MSG__
)
for
vv
in
self
.
pred
[
uu
]
if
__MSG__
in
self
.
edges
[
vv
,
uu
]]
if
len
(
msgs_batch
)
!=
0
:
new_repr
=
reduce_func
(
_get_repr
(
self
.
nodes
[
uu
]),
msgs_batch
)
_set_repr
(
self
.
nodes
[
uu
],
new_repr
)
self
.
apply_nodes
(
u
,
apply_node_func
)
def
_batch_recv
(
self
,
v
,
reduce_func
):
if
self
.
_msg_frame
.
num_rows
==
0
:
...
...
@@ -1105,8 +1008,7 @@ class DGLGraph(object):
u
,
v
,
message_func
=
"default"
,
reduce_func
=
"default"
,
apply_node_func
=
"default"
,
batchable
=
False
):
apply_node_func
=
"default"
):
"""Trigger the message function on u->v and update v.
Parameters
...
...
@@ -1121,8 +1023,6 @@ class DGLGraph(object):
The reduce function.
apply_node_func : callable, optional
The update function.
batchable : bool
Whether the reduce and update function allows batched computation.
"""
u
=
utils
.
toindex
(
u
)
v
=
utils
.
toindex
(
v
)
...
...
@@ -1132,34 +1032,28 @@ class DGLGraph(object):
return
unique_v
=
utils
.
toindex
(
F
.
unique
(
v
.
tousertensor
()))
# TODO(minjie): better way to figure out `batchable` flag
if
message_func
==
"default"
:
message_func
,
batchable
=
self
.
_message_func
message_func
=
self
.
_message_func
if
reduce_func
==
"default"
:
reduce_func
,
_
=
self
.
_reduce_func
reduce_func
=
self
.
_reduce_func
assert
message_func
is
not
None
assert
reduce_func
is
not
None
if
batchable
:
executor
=
scheduler
.
get_executor
(
'send_and_recv'
,
self
,
src
=
u
,
dst
=
v
,
message_func
=
message_func
,
reduce_func
=
reduce_func
)
else
:
executor
=
None
if
executor
:
executor
.
run
()
else
:
self
.
send
(
u
,
v
,
message_func
,
batchable
=
batchable
)
self
.
recv
(
unique_v
,
reduce_func
,
None
,
batchable
=
batchable
)
self
.
apply_nodes
(
unique_v
,
apply_node_func
,
batchable
=
batchable
)
self
.
send
(
u
,
v
,
message_func
)
self
.
recv
(
unique_v
,
reduce_func
,
None
)
self
.
apply_nodes
(
unique_v
,
apply_node_func
)
def
pull
(
self
,
v
,
message_func
=
"default"
,
reduce_func
=
"default"
,
apply_node_func
=
"default"
,
batchable
=
False
):
apply_node_func
=
"default"
):
"""Pull messages from the node's predecessors and then update it.
Parameters
...
...
@@ -1172,24 +1066,20 @@ class DGLGraph(object):
The reduce function.
apply_node_func : callable, optional
The update function.
batchable : bool
Whether the reduce and update function allows batched computation.
"""
v
=
utils
.
toindex
(
v
)
if
len
(
v
)
==
0
:
return
uu
,
vv
,
_
=
self
.
_graph
.
in_edges
(
v
)
self
.
send_and_recv
(
uu
,
vv
,
message_func
,
reduce_func
,
apply_node_func
=
None
,
batchable
=
batchable
)
self
.
send_and_recv
(
uu
,
vv
,
message_func
,
reduce_func
,
apply_node_func
=
None
)
unique_v
=
F
.
unique
(
v
.
tousertensor
())
self
.
apply_nodes
(
unique_v
,
apply_node_func
,
batchable
=
batchable
)
self
.
apply_nodes
(
unique_v
,
apply_node_func
)
def
push
(
self
,
u
,
message_func
=
"default"
,
reduce_func
=
"default"
,
apply_node_func
=
"default"
,
batchable
=
False
):
apply_node_func
=
"default"
):
"""Send message from the node to its successors and update them.
Parameters
...
...
@@ -1202,21 +1092,18 @@ class DGLGraph(object):
The reduce function.
apply_node_func : callable
The update function.
batchable : bool
Whether the reduce and update function allows batched computation.
"""
u
=
utils
.
toindex
(
u
)
if
len
(
u
)
==
0
:
return
uu
,
vv
,
_
=
self
.
_graph
.
out_edges
(
u
)
self
.
send_and_recv
(
uu
,
vv
,
message_func
,
reduce_func
,
apply_node_func
,
batchable
=
batchable
)
reduce_func
,
apply_node_func
)
def
update_all
(
self
,
message_func
=
"default"
,
reduce_func
=
"default"
,
apply_node_func
=
"default"
,
batchable
=
False
):
apply_node_func
=
"default"
):
"""Send messages through all the edges and update all nodes.
Parameters
...
...
@@ -1227,35 +1114,28 @@ class DGLGraph(object):
The reduce function.
apply_node_func : callable, optional
The update function.
batchable : bool
Whether the reduce and update function allows batched computation.
"""
if
message_func
==
"default"
:
message_func
,
batchable
=
self
.
_message_func
message_func
=
self
.
_message_func
if
reduce_func
==
"default"
:
reduce_func
,
_
=
self
.
_reduce_func
reduce_func
=
self
.
_reduce_func
assert
message_func
is
not
None
assert
reduce_func
is
not
None
if
batchable
:
executor
=
scheduler
.
get_executor
(
"update_all"
,
self
,
message_func
=
message_func
,
reduce_func
=
reduce_func
)
else
:
executor
=
None
if
executor
:
executor
.
run
()
else
:
self
.
send
(
ALL
,
ALL
,
message_func
,
batchable
=
batchable
)
self
.
recv
(
ALL
,
reduce_func
,
None
,
batchable
=
batchable
)
self
.
apply_nodes
(
ALL
,
apply_node_func
,
batchable
=
batchable
)
self
.
send
(
ALL
,
ALL
,
message_func
)
self
.
recv
(
ALL
,
reduce_func
,
None
)
self
.
apply_nodes
(
ALL
,
apply_node_func
)
def
propagate
(
self
,
iterator
=
'bfs'
,
message_func
=
"default"
,
reduce_func
=
"default"
,
apply_node_func
=
"default"
,
batchable
=
False
,
**
kwargs
):
"""Propagate messages and update nodes using iterator.
...
...
@@ -1274,8 +1154,6 @@ class DGLGraph(object):
The reduce function.
apply_node_func : str or callable
The update function.
batchable : bool
Whether the reduce and update function allows batched computation.
iterator : str or generator of steps.
The iterator of the graph.
kwargs : keyword arguments, optional
...
...
@@ -1288,7 +1166,7 @@ class DGLGraph(object):
# NOTE: the iteration can return multiple edges at each step.
for
u
,
v
in
iterator
:
self
.
send_and_recv
(
u
,
v
,
message_func
,
reduce_func
,
apply_node_func
,
batchable
)
message_func
,
reduce_func
,
apply_node_func
)
def
subgraph
(
self
,
nodes
):
"""Generate the subgraph among the given nodes.
...
...
@@ -1350,15 +1228,3 @@ class DGLGraph(object):
[
sg
.
_parent_eid
for
sg
in
to_merge
],
self
.
_edge_frame
.
num_rows
,
reduce_func
)
def
_get_repr
(
attr_dict
):
if
len
(
attr_dict
)
==
1
and
__REPR__
in
attr_dict
:
return
attr_dict
[
__REPR__
]
else
:
return
attr_dict
def
_set_repr
(
attr_dict
,
attr
):
if
utils
.
is_dict_like
(
attr
):
attr_dict
.
update
(
attr
)
else
:
attr_dict
[
__REPR__
]
=
attr
tests/pytorch/test_batching.py
View file @
7d04c8c9
...
...
@@ -133,7 +133,7 @@ def test_batch_send():
def
_fmsg
(
src
,
edge
):
assert
src
[
'h'
].
shape
==
(
5
,
D
)
return
{
'm'
:
src
[
'h'
]}
g
.
register_message_func
(
_fmsg
,
batchable
=
True
)
g
.
register_message_func
(
_fmsg
)
# many-many send
u
=
th
.
tensor
([
0
,
0
,
0
,
0
,
0
])
v
=
th
.
tensor
([
1
,
2
,
3
,
4
,
5
])
...
...
@@ -150,9 +150,9 @@ def test_batch_send():
def
test_batch_recv
():
# basic recv test
g
=
generate_graph
()
g
.
register_message_func
(
message_func
,
batchable
=
True
)
g
.
register_reduce_func
(
reduce_func
,
batchable
=
True
)
g
.
register_apply_node_func
(
apply_node_func
,
batchable
=
True
)
g
.
register_message_func
(
message_func
)
g
.
register_reduce_func
(
reduce_func
)
g
.
register_apply_node_func
(
apply_node_func
)
u
=
th
.
tensor
([
0
,
0
,
0
,
4
,
5
,
6
])
v
=
th
.
tensor
([
1
,
2
,
3
,
9
,
9
,
9
])
reduce_msg_shapes
.
clear
()
...
...
@@ -163,9 +163,9 @@ def test_batch_recv():
def
test_update_routines
():
g
=
generate_graph
()
g
.
register_message_func
(
message_func
,
batchable
=
True
)
g
.
register_reduce_func
(
reduce_func
,
batchable
=
True
)
g
.
register_apply_node_func
(
apply_node_func
,
batchable
=
True
)
g
.
register_message_func
(
message_func
)
g
.
register_reduce_func
(
reduce_func
)
g
.
register_apply_node_func
(
apply_node_func
)
# send_and_recv
reduce_msg_shapes
.
clear
()
...
...
@@ -209,7 +209,7 @@ def test_reduce_0deg():
return
node
+
msgs
.
sum
(
1
)
old_repr
=
th
.
randn
(
5
,
5
)
g
.
set_n_repr
(
old_repr
)
g
.
update_all
(
_message
,
_reduce
,
batchable
=
True
)
g
.
update_all
(
_message
,
_reduce
)
new_repr
=
g
.
get_n_repr
()
assert
th
.
allclose
(
new_repr
[
1
:],
old_repr
[
1
:])
...
...
@@ -227,17 +227,17 @@ def test_pull_0deg():
old_repr
=
th
.
randn
(
2
,
5
)
g
.
set_n_repr
(
old_repr
)
g
.
pull
(
0
,
_message
,
_reduce
,
batchable
=
True
)
g
.
pull
(
0
,
_message
,
_reduce
)
new_repr
=
g
.
get_n_repr
()
assert
th
.
allclose
(
new_repr
[
0
],
old_repr
[
0
])
assert
th
.
allclose
(
new_repr
[
1
],
old_repr
[
1
])
g
.
pull
(
1
,
_message
,
_reduce
,
batchable
=
True
)
g
.
pull
(
1
,
_message
,
_reduce
)
new_repr
=
g
.
get_n_repr
()
assert
th
.
allclose
(
new_repr
[
1
],
old_repr
[
0
])
old_repr
=
th
.
randn
(
2
,
5
)
g
.
set_n_repr
(
old_repr
)
g
.
pull
([
0
,
1
],
_message
,
_reduce
,
batchable
=
True
)
g
.
pull
([
0
,
1
],
_message
,
_reduce
)
new_repr
=
g
.
get_n_repr
()
assert
th
.
allclose
(
new_repr
[
0
],
old_repr
[
0
])
assert
th
.
allclose
(
new_repr
[
1
],
old_repr
[
0
])
...
...
tests/pytorch/test_batching_anonymous.py
View file @
7d04c8c9
...
...
@@ -129,7 +129,7 @@ def test_batch_send():
def
_fmsg
(
hu
,
edge
):
assert
hu
.
shape
==
(
5
,
D
)
return
hu
g
.
register_message_func
(
_fmsg
,
batchable
=
True
)
g
.
register_message_func
(
_fmsg
)
# many-many send
u
=
th
.
tensor
([
0
,
0
,
0
,
0
,
0
])
v
=
th
.
tensor
([
1
,
2
,
3
,
4
,
5
])
...
...
@@ -145,8 +145,8 @@ def test_batch_send():
def
test_batch_recv
():
g
=
generate_graph
()
g
.
register_message_func
(
message_func
,
batchable
=
True
)
g
.
register_reduce_func
(
reduce_func
,
batchable
=
True
)
g
.
register_message_func
(
message_func
)
g
.
register_reduce_func
(
reduce_func
)
u
=
th
.
tensor
([
0
,
0
,
0
,
4
,
5
,
6
])
v
=
th
.
tensor
([
1
,
2
,
3
,
9
,
9
,
9
])
reduce_msg_shapes
.
clear
()
...
...
@@ -157,8 +157,8 @@ def test_batch_recv():
def
test_update_routines
():
g
=
generate_graph
()
g
.
register_message_func
(
message_func
,
batchable
=
True
)
g
.
register_reduce_func
(
reduce_func
,
batchable
=
True
)
g
.
register_message_func
(
message_func
)
g
.
register_reduce_func
(
reduce_func
)
# send_and_recv
reduce_msg_shapes
.
clear
()
...
...
tests/pytorch/test_function.py
View file @
7d04c8c9
...
...
@@ -51,32 +51,32 @@ def reducer_none(node, msgs):
def
test_copy_src
():
# copy_src with both fields
g
=
generate_graph
()
g
.
register_message_func
(
fn
.
copy_src
(
src
=
'h'
,
out
=
'm'
)
,
batchable
=
True
)
g
.
register_reduce_func
(
reducer_both
,
batchable
=
True
)
g
.
register_message_func
(
fn
.
copy_src
(
src
=
'h'
,
out
=
'm'
))
g
.
register_reduce_func
(
reducer_both
)
g
.
update_all
()
assert
th
.
allclose
(
g
.
get_n_repr
()[
'h'
],
th
.
tensor
([
10.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
44.
]))
# copy_src with only src field; the out field should use anonymous repr
g
=
generate_graph
()
g
.
register_message_func
(
fn
.
copy_src
(
src
=
'h'
)
,
batchable
=
True
)
g
.
register_reduce_func
(
reducer_out
,
batchable
=
True
)
g
.
register_message_func
(
fn
.
copy_src
(
src
=
'h'
))
g
.
register_reduce_func
(
reducer_out
)
g
.
update_all
()
assert
th
.
allclose
(
g
.
get_n_repr
()[
'h'
],
th
.
tensor
([
10.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
44.
]))
# copy_src with no src field; should use anonymous repr
g
=
generate_graph1
()
g
.
register_message_func
(
fn
.
copy_src
(
out
=
'm'
)
,
batchable
=
True
)
g
.
register_reduce_func
(
reducer_both
,
batchable
=
True
)
g
.
register_message_func
(
fn
.
copy_src
(
out
=
'm'
))
g
.
register_reduce_func
(
reducer_both
)
g
.
update_all
()
assert
th
.
allclose
(
g
.
get_n_repr
()[
'h'
],
th
.
tensor
([
10.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
44.
]))
# copy src with no fields;
g
=
generate_graph1
()
g
.
register_message_func
(
fn
.
copy_src
()
,
batchable
=
True
)
g
.
register_reduce_func
(
reducer_out
,
batchable
=
True
)
g
.
register_message_func
(
fn
.
copy_src
())
g
.
register_reduce_func
(
reducer_out
)
g
.
update_all
()
assert
th
.
allclose
(
g
.
get_n_repr
()[
'h'
],
th
.
tensor
([
10.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
44.
]))
...
...
@@ -84,32 +84,32 @@ def test_copy_src():
def
test_copy_edge
():
# copy_edge with both fields
g
=
generate_graph
()
g
.
register_message_func
(
fn
.
copy_edge
(
edge
=
'h'
,
out
=
'm'
)
,
batchable
=
True
)
g
.
register_reduce_func
(
reducer_both
,
batchable
=
True
)
g
.
register_message_func
(
fn
.
copy_edge
(
edge
=
'h'
,
out
=
'm'
))
g
.
register_reduce_func
(
reducer_both
)
g
.
update_all
()
assert
th
.
allclose
(
g
.
get_n_repr
()[
'h'
],
th
.
tensor
([
10.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
44.
]))
# copy_edge with only edge field; the out field should use anonymous repr
g
=
generate_graph
()
g
.
register_message_func
(
fn
.
copy_edge
(
edge
=
'h'
)
,
batchable
=
True
)
g
.
register_reduce_func
(
reducer_out
,
batchable
=
True
)
g
.
register_message_func
(
fn
.
copy_edge
(
edge
=
'h'
))
g
.
register_reduce_func
(
reducer_out
)
g
.
update_all
()
assert
th
.
allclose
(
g
.
get_n_repr
()[
'h'
],
th
.
tensor
([
10.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
44.
]))
# copy_edge with no edge field; should use anonymous repr
g
=
generate_graph1
()
g
.
register_message_func
(
fn
.
copy_edge
(
out
=
'm'
)
,
batchable
=
True
)
g
.
register_reduce_func
(
reducer_both
,
batchable
=
True
)
g
.
register_message_func
(
fn
.
copy_edge
(
out
=
'm'
))
g
.
register_reduce_func
(
reducer_both
)
g
.
update_all
()
assert
th
.
allclose
(
g
.
get_n_repr
()[
'h'
],
th
.
tensor
([
10.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
44.
]))
# copy edge with no fields;
g
=
generate_graph1
()
g
.
register_message_func
(
fn
.
copy_edge
()
,
batchable
=
True
)
g
.
register_reduce_func
(
reducer_out
,
batchable
=
True
)
g
.
register_message_func
(
fn
.
copy_edge
())
g
.
register_reduce_func
(
reducer_out
)
g
.
update_all
()
assert
th
.
allclose
(
g
.
get_n_repr
()[
'h'
],
th
.
tensor
([
10.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
44.
]))
...
...
@@ -117,36 +117,36 @@ def test_copy_edge():
def
test_src_mul_edge
():
# src_mul_edge with all fields
g
=
generate_graph
()
g
.
register_message_func
(
fn
.
src_mul_edge
(
src
=
'h'
,
edge
=
'h'
,
out
=
'm'
)
,
batchable
=
True
)
g
.
register_reduce_func
(
reducer_both
,
batchable
=
True
)
g
.
register_message_func
(
fn
.
src_mul_edge
(
src
=
'h'
,
edge
=
'h'
,
out
=
'm'
))
g
.
register_reduce_func
(
reducer_both
)
g
.
update_all
()
assert
th
.
allclose
(
g
.
get_n_repr
()[
'h'
],
th
.
tensor
([
100.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
284.
]))
g
=
generate_graph
()
g
.
register_message_func
(
fn
.
src_mul_edge
(
src
=
'h'
,
edge
=
'h'
)
,
batchable
=
True
)
g
.
register_reduce_func
(
reducer_out
,
batchable
=
True
)
g
.
register_message_func
(
fn
.
src_mul_edge
(
src
=
'h'
,
edge
=
'h'
))
g
.
register_reduce_func
(
reducer_out
)
g
.
update_all
()
assert
th
.
allclose
(
g
.
get_n_repr
()[
'h'
],
th
.
tensor
([
100.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
284.
]))
g
=
generate_graph1
()
g
.
register_message_func
(
fn
.
src_mul_edge
(
out
=
'm'
)
,
batchable
=
True
)
g
.
register_reduce_func
(
reducer_both
,
batchable
=
True
)
g
.
register_message_func
(
fn
.
src_mul_edge
(
out
=
'm'
))
g
.
register_reduce_func
(
reducer_both
)
g
.
update_all
()
assert
th
.
allclose
(
g
.
get_n_repr
()[
'h'
],
th
.
tensor
([
100.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
284.
]))
g
=
generate_graph1
()
g
.
register_message_func
(
fn
.
src_mul_edge
()
,
batchable
=
True
)
g
.
register_reduce_func
(
reducer_out
,
batchable
=
True
)
g
.
register_message_func
(
fn
.
src_mul_edge
())
g
.
register_reduce_func
(
reducer_out
)
g
.
update_all
()
assert
th
.
allclose
(
g
.
get_n_repr
()[
'h'
],
th
.
tensor
([
100.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
284.
]))
g
=
generate_graph1
()
g
.
register_message_func
(
fn
.
src_mul_edge
()
,
batchable
=
True
)
g
.
register_reduce_func
(
reducer_none
,
batchable
=
True
)
g
.
register_message_func
(
fn
.
src_mul_edge
())
g
.
register_reduce_func
(
reducer_none
)
g
.
update_all
()
assert
th
.
allclose
(
g
.
get_n_repr
(),
th
.
tensor
([
100.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
284.
]))
...
...
tests/pytorch/test_graph_batch.py
View file @
7d04c8c9
...
...
@@ -71,8 +71,8 @@ def test_batch_sendrecv():
t2
=
tree2
()
bg
=
dgl
.
batch
([
t1
,
t2
])
bg
.
register_message_func
(
lambda
src
,
edge
:
src
,
batchable
=
True
)
bg
.
register_reduce_func
(
lambda
node
,
msgs
:
torch
.
sum
(
msgs
,
1
)
,
batchable
=
True
)
bg
.
register_message_func
(
lambda
src
,
edge
:
src
)
bg
.
register_reduce_func
(
lambda
node
,
msgs
:
torch
.
sum
(
msgs
,
1
))
e1
=
[(
3
,
1
),
(
4
,
1
)]
e2
=
[(
2
,
4
),
(
0
,
4
)]
...
...
@@ -94,8 +94,8 @@ def test_batch_propagate():
t2
=
tree2
()
bg
=
dgl
.
batch
([
t1
,
t2
])
bg
.
register_message_func
(
lambda
src
,
edge
:
src
,
batchable
=
True
)
bg
.
register_reduce_func
(
lambda
node
,
msgs
:
torch
.
sum
(
msgs
,
1
)
,
batchable
=
True
)
bg
.
register_message_func
(
lambda
src
,
edge
:
src
)
bg
.
register_reduce_func
(
lambda
node
,
msgs
:
torch
.
sum
(
msgs
,
1
))
# get leaves.
order
=
[]
...
...
tests/pytorch/test_specialization.py
View file @
7d04c8c9
...
...
@@ -38,23 +38,23 @@ def test_update_all():
g
=
generate_graph
()
# update all
v1
=
g
.
get_n_repr
()[
fld
]
g
.
update_all
(
fn
.
copy_src
(
src
=
fld
),
fn
.
sum
(
out
=
fld
),
apply_func
,
batchable
=
True
)
g
.
update_all
(
fn
.
copy_src
(
src
=
fld
),
fn
.
sum
(
out
=
fld
),
apply_func
)
v2
=
g
.
get_n_repr
()[
fld
]
g
.
set_n_repr
({
fld
:
v1
})
g
.
update_all
(
message_func
,
reduce_func
,
apply_func
,
batchable
=
True
)
g
.
update_all
(
message_func
,
reduce_func
,
apply_func
)
v3
=
g
.
get_n_repr
()[
fld
]
assert
th
.
allclose
(
v2
,
v3
)
# update all with edge weights
v1
=
g
.
get_n_repr
()[
fld
]
g
.
update_all
(
fn
.
src_mul_edge
(
src
=
fld
,
edge
=
'e1'
),
fn
.
sum
(
out
=
fld
),
apply_func
,
batchable
=
True
)
fn
.
sum
(
out
=
fld
),
apply_func
)
v2
=
g
.
get_n_repr
()[
fld
]
g
.
set_n_repr
({
fld
:
v1
})
g
.
update_all
(
fn
.
src_mul_edge
(
src
=
fld
,
edge
=
'e2'
),
fn
.
sum
(
out
=
fld
),
apply_func
,
batchable
=
True
)
fn
.
sum
(
out
=
fld
),
apply_func
)
v3
=
g
.
get_n_repr
()[
fld
]
g
.
set_n_repr
({
fld
:
v1
})
g
.
update_all
(
message_func_edge
,
reduce_func
,
apply_func
,
batchable
=
True
)
g
.
update_all
(
message_func_edge
,
reduce_func
,
apply_func
)
v4
=
g
.
get_n_repr
()[
fld
]
assert
th
.
allclose
(
v2
,
v3
)
assert
th
.
allclose
(
v3
,
v4
)
...
...
@@ -85,25 +85,25 @@ def test_send_and_recv():
# send and recv
v1
=
g
.
get_n_repr
()[
fld
]
g
.
send_and_recv
(
u
,
v
,
fn
.
copy_src
(
src
=
fld
),
fn
.
sum
(
out
=
fld
),
apply_func
,
batchable
=
True
)
fn
.
sum
(
out
=
fld
),
apply_func
)
v2
=
g
.
get_n_repr
()[
fld
]
g
.
set_n_repr
({
fld
:
v1
})
g
.
send_and_recv
(
u
,
v
,
message_func
,
reduce_func
,
apply_func
,
batchable
=
True
)
reduce_func
,
apply_func
)
v3
=
g
.
get_n_repr
()[
fld
]
assert
th
.
allclose
(
v2
,
v3
)
# send and recv with edge weights
v1
=
g
.
get_n_repr
()[
fld
]
g
.
send_and_recv
(
u
,
v
,
fn
.
src_mul_edge
(
src
=
fld
,
edge
=
'e1'
),
fn
.
sum
(
out
=
fld
),
apply_func
,
batchable
=
True
)
fn
.
sum
(
out
=
fld
),
apply_func
)
v2
=
g
.
get_n_repr
()[
fld
]
g
.
set_n_repr
({
fld
:
v1
})
g
.
send_and_recv
(
u
,
v
,
fn
.
src_mul_edge
(
src
=
fld
,
edge
=
'e2'
),
fn
.
sum
(
out
=
fld
),
apply_func
,
batchable
=
True
)
fn
.
sum
(
out
=
fld
),
apply_func
)
v3
=
g
.
get_n_repr
()[
fld
]
g
.
set_n_repr
({
fld
:
v1
})
g
.
send_and_recv
(
u
,
v
,
message_func_edge
,
reduce_func
,
apply_func
,
batchable
=
True
)
reduce_func
,
apply_func
)
v4
=
g
.
get_n_repr
()[
fld
]
assert
th
.
allclose
(
v2
,
v3
)
assert
th
.
allclose
(
v3
,
v4
)
...
...
@@ -127,18 +127,18 @@ def test_update_all_multi_fn():
# update all, mix of builtin and UDF
g
.
update_all
([
fn
.
copy_src
(
src
=
fld
,
out
=
'm1'
),
message_func
],
[
fn
.
sum
(
msgs
=
'm1'
,
out
=
'v1'
),
reduce_func
],
None
,
batchable
=
True
)
None
)
v1
=
g
.
get_n_repr
()[
'v1'
]
v2
=
g
.
get_n_repr
()[
'v2'
]
assert
th
.
allclose
(
v1
,
v2
)
# run builtin with single message and reduce
g
.
update_all
(
fn
.
copy_src
(
src
=
fld
),
fn
.
sum
(
out
=
'v1'
),
None
,
batchable
=
True
)
g
.
update_all
(
fn
.
copy_src
(
src
=
fld
),
fn
.
sum
(
out
=
'v1'
),
None
)
v1
=
g
.
get_n_repr
()[
'v1'
]
assert
th
.
allclose
(
v1
,
v2
)
# 1 message, 2 reduces, using anonymous repr
g
.
update_all
(
fn
.
copy_src
(
src
=
fld
),
[
fn
.
sum
(
out
=
'v2'
),
fn
.
sum
(
out
=
'v3'
)],
None
,
batchable
=
True
)
g
.
update_all
(
fn
.
copy_src
(
src
=
fld
),
[
fn
.
sum
(
out
=
'v2'
),
fn
.
sum
(
out
=
'v3'
)],
None
)
v2
=
g
.
get_n_repr
()[
'v2'
]
v3
=
g
.
get_n_repr
()[
'v3'
]
assert
th
.
allclose
(
v1
,
v2
)
...
...
@@ -147,7 +147,7 @@ def test_update_all_multi_fn():
# update all with edge weights, 2 message, 3 reduces
g
.
update_all
([
fn
.
src_mul_edge
(
src
=
fld
,
edge
=
'e1'
,
out
=
'm1'
),
fn
.
src_mul_edge
(
src
=
fld
,
edge
=
'e2'
,
out
=
'm2'
)],
[
fn
.
sum
(
msgs
=
'm1'
,
out
=
'v1'
),
fn
.
sum
(
msgs
=
'm2'
,
out
=
'v2'
),
fn
.
sum
(
msgs
=
'm1'
,
out
=
'v3'
)],
None
,
batchable
=
True
)
None
)
v1
=
g
.
get_n_repr
()[
'v1'
]
v2
=
g
.
get_n_repr
()[
'v2'
]
v3
=
g
.
get_n_repr
()[
'v3'
]
...
...
@@ -155,7 +155,7 @@ def test_update_all_multi_fn():
assert
th
.
allclose
(
v1
,
v3
)
# run UDF with single message and reduce
g
.
update_all
(
message_func_edge
,
reduce_func
,
None
,
batchable
=
True
)
g
.
update_all
(
message_func_edge
,
reduce_func
,
None
)
v2
=
g
.
get_n_repr
()[
'v2'
]
assert
th
.
allclose
(
v1
,
v2
)
...
...
@@ -179,19 +179,19 @@ def test_send_and_recv_multi_fn():
g
.
send_and_recv
(
u
,
v
,
[
fn
.
copy_src
(
src
=
fld
,
out
=
'm1'
),
message_func
],
[
fn
.
sum
(
msgs
=
'm1'
,
out
=
'v1'
),
reduce_func
],
None
,
batchable
=
True
)
None
)
v1
=
g
.
get_n_repr
()[
'v1'
]
v2
=
g
.
get_n_repr
()[
'v2'
]
assert
th
.
allclose
(
v1
,
v2
)
# run builtin with single message and reduce
g
.
send_and_recv
(
u
,
v
,
fn
.
copy_src
(
src
=
fld
),
fn
.
sum
(
out
=
'v1'
),
None
,
batchable
=
True
)
None
)
v1
=
g
.
get_n_repr
()[
'v1'
]
assert
th
.
allclose
(
v1
,
v2
)
# 1 message, 2 reduces, using anonymous repr
g
.
send_and_recv
(
u
,
v
,
fn
.
copy_src
(
src
=
fld
),
[
fn
.
sum
(
out
=
'v2'
),
fn
.
sum
(
out
=
'v3'
)],
None
,
batchable
=
True
)
g
.
send_and_recv
(
u
,
v
,
fn
.
copy_src
(
src
=
fld
),
[
fn
.
sum
(
out
=
'v2'
),
fn
.
sum
(
out
=
'v3'
)],
None
)
v2
=
g
.
get_n_repr
()[
'v2'
]
v3
=
g
.
get_n_repr
()[
'v3'
]
assert
th
.
allclose
(
v1
,
v2
)
...
...
@@ -201,7 +201,7 @@ def test_send_and_recv_multi_fn():
g
.
send_and_recv
(
u
,
v
,
[
fn
.
src_mul_edge
(
src
=
fld
,
edge
=
'e1'
,
out
=
'm1'
),
fn
.
src_mul_edge
(
src
=
fld
,
edge
=
'e2'
,
out
=
'm2'
)],
[
fn
.
sum
(
msgs
=
'm1'
,
out
=
'v1'
),
fn
.
sum
(
msgs
=
'm2'
,
out
=
'v2'
),
fn
.
sum
(
msgs
=
'm1'
,
out
=
'v3'
)],
None
,
batchable
=
True
)
None
)
v1
=
g
.
get_n_repr
()[
'v1'
]
v2
=
g
.
get_n_repr
()[
'v2'
]
v3
=
g
.
get_n_repr
()[
'v3'
]
...
...
@@ -210,7 +210,7 @@ def test_send_and_recv_multi_fn():
# run UDF with single message and reduce
g
.
send_and_recv
(
u
,
v
,
message_func_edge
,
reduce_func
,
None
,
batchable
=
True
)
reduce_func
,
None
)
v2
=
g
.
get_n_repr
()[
'v2'
]
assert
th
.
allclose
(
v1
,
v2
)
...
...
tests/test_anonymous_repr.py
deleted
100644 → 0
View file @
3a3e5d48
from
dgl
import
DGLGraph
from
dgl.graph
import
__REPR__
def
message_func
(
hu
,
e_uv
):
return
hu
+
e_uv
def
reduce_func
(
h
,
msgs
):
return
h
+
sum
(
msgs
)
def
generate_graph
():
g
=
DGLGraph
()
for
i
in
range
(
10
):
g
.
add_node
(
i
,
__REPR__
=
i
+
1
)
# 10 nodes.
# create a graph where 0 is the source and 9 is the sink
for
i
in
range
(
1
,
9
):
g
.
add_edge
(
0
,
i
,
__REPR__
=
1
)
g
.
add_edge
(
i
,
9
,
__REPR__
=
1
)
# add a back flow from 9 to 0
g
.
add_edge
(
9
,
0
)
return
g
def
check
(
g
,
h
):
nh
=
[
str
(
g
.
nodes
[
i
][
__REPR__
])
for
i
in
range
(
10
)]
h
=
[
str
(
x
)
for
x
in
h
]
assert
nh
==
h
,
"nh=[%s], h=[%s]"
%
(
' '
.
join
(
nh
),
' '
.
join
(
h
))
def
test_sendrecv
():
g
=
generate_graph
()
check
(
g
,
[
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
])
g
.
register_message_func
(
message_func
)
g
.
register_reduce_func
(
reduce_func
)
g
.
send
(
0
,
1
)
g
.
recv
(
1
)
check
(
g
,
[
1
,
4
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
])
g
.
send
(
5
,
9
)
g
.
send
(
6
,
9
)
g
.
recv
(
9
)
check
(
g
,
[
1
,
4
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
25
])
def
message_func_hybrid
(
src
,
edge
):
return
src
[
__REPR__
]
+
edge
def
reduce_func_hybrid
(
node
,
msgs
):
return
node
[
__REPR__
]
+
sum
(
msgs
)
def
test_hybridrepr
():
g
=
generate_graph
()
for
i
in
range
(
10
):
g
.
nodes
[
i
][
'id'
]
=
-
i
g
.
register_message_func
(
message_func_hybrid
)
g
.
register_reduce_func
(
reduce_func_hybrid
)
g
.
send
(
0
,
1
)
g
.
recv
(
1
)
check
(
g
,
[
1
,
4
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
])
g
.
send
(
5
,
9
)
g
.
send
(
6
,
9
)
g
.
recv
(
9
)
check
(
g
,
[
1
,
4
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
25
])
if
__name__
==
'__main__'
:
test_sendrecv
()
test_hybridrepr
()
tests/test_basics.py
deleted
100644 → 0
View file @
3a3e5d48
from
dgl.graph
import
DGLGraph
def
message_func
(
src
,
edge
):
return
src
[
'h'
]
def
reduce_func
(
node
,
msgs
):
return
{
'm'
:
sum
(
msgs
)}
def
apply_func
(
node
):
return
{
'h'
:
node
[
'h'
]
+
node
[
'm'
]}
def
message_dict_func
(
src
,
edge
):
return
{
'm'
:
src
[
'h'
]}
def
reduce_dict_func
(
node
,
msgs
):
return
{
'm'
:
sum
([
msg
[
'm'
]
for
msg
in
msgs
])}
def
apply_dict_func
(
node
):
return
{
'h'
:
node
[
'h'
]
+
node
[
'm'
]}
def
generate_graph
():
g
=
DGLGraph
()
for
i
in
range
(
10
):
g
.
add_node
(
i
,
h
=
i
+
1
)
# 10 nodes.
# create a graph where 0 is the source and 9 is the sink
for
i
in
range
(
1
,
9
):
g
.
add_edge
(
0
,
i
)
g
.
add_edge
(
i
,
9
)
# add a back flow from 9 to 0
g
.
add_edge
(
9
,
0
)
return
g
def
check
(
g
,
h
):
nh
=
[
str
(
g
.
nodes
[
i
][
'h'
])
for
i
in
range
(
10
)]
h
=
[
str
(
x
)
for
x
in
h
]
assert
nh
==
h
,
"nh=[%s], h=[%s]"
%
(
' '
.
join
(
nh
),
' '
.
join
(
h
))
def
register1
(
g
):
g
.
register_message_func
(
message_func
)
g
.
register_reduce_func
(
reduce_func
)
g
.
register_apply_node_func
(
apply_func
)
def
register2
(
g
):
g
.
register_message_func
(
message_dict_func
)
g
.
register_reduce_func
(
reduce_dict_func
)
g
.
register_apply_node_func
(
apply_dict_func
)
def
_test_sendrecv
(
g
):
check
(
g
,
[
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
])
g
.
send
(
0
,
1
)
g
.
recv
(
1
)
check
(
g
,
[
1
,
3
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
])
g
.
send
(
5
,
9
)
g
.
send
(
6
,
9
)
g
.
recv
(
9
)
check
(
g
,
[
1
,
3
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
23
])
def
_test_multi_sendrecv
(
g
):
check
(
g
,
[
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
])
# one-many
g
.
send
(
0
,
[
1
,
2
,
3
])
g
.
recv
([
1
,
2
,
3
])
check
(
g
,
[
1
,
3
,
4
,
5
,
5
,
6
,
7
,
8
,
9
,
10
])
# many-one
g
.
send
([
6
,
7
,
8
],
9
)
g
.
recv
(
9
)
check
(
g
,
[
1
,
3
,
4
,
5
,
5
,
6
,
7
,
8
,
9
,
34
])
# many-many
g
.
send
([
0
,
0
,
4
,
5
],
[
4
,
5
,
9
,
9
])
g
.
recv
([
4
,
5
,
9
])
check
(
g
,
[
1
,
3
,
4
,
5
,
6
,
7
,
7
,
8
,
9
,
45
])
def
_test_update_routines
(
g
):
check
(
g
,
[
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
])
g
.
send_and_recv
(
0
,
1
)
check
(
g
,
[
1
,
3
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
])
g
.
pull
(
9
)
check
(
g
,
[
1
,
3
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
55
])
g
.
push
(
0
)
check
(
g
,
[
1
,
4
,
4
,
5
,
6
,
7
,
8
,
9
,
10
,
55
])
g
.
update_all
()
check
(
g
,
[
56
,
5
,
5
,
6
,
7
,
8
,
9
,
10
,
11
,
108
])
def
test_sendrecv
():
g
=
generate_graph
()
register1
(
g
)
_test_sendrecv
(
g
)
g
=
generate_graph
()
register2
(
g
)
_test_sendrecv
(
g
)
def
test_multi_sendrecv
():
g
=
generate_graph
()
register1
(
g
)
_test_multi_sendrecv
(
g
)
g
=
generate_graph
()
register2
(
g
)
_test_multi_sendrecv
(
g
)
def
test_update_routines
():
g
=
generate_graph
()
register1
(
g
)
_test_update_routines
(
g
)
g
=
generate_graph
()
register2
(
g
)
_test_update_routines
(
g
)
if
__name__
==
'__main__'
:
test_sendrecv
()
test_multi_sendrecv
()
test_update_routines
()
tests/test_basics2.py
deleted
100644 → 0
View file @
3a3e5d48
from
dgl
import
DGLGraph
from
dgl.graph
import
__REPR__
def
message_func
(
hu
,
e_uv
):
return
hu
def
message_not_called
(
hu
,
e_uv
):
assert
False
return
hu
def
reduce_not_called
(
h
,
msgs
):
assert
False
return
0
def
reduce_func
(
h
,
msgs
):
return
h
+
sum
(
msgs
)
def
check
(
g
,
h
):
nh
=
[
str
(
g
.
nodes
[
i
][
__REPR__
])
for
i
in
range
(
10
)]
h
=
[
str
(
x
)
for
x
in
h
]
assert
nh
==
h
,
"nh=[%s], h=[%s]"
%
(
' '
.
join
(
nh
),
' '
.
join
(
h
))
def
generate_graph
():
g
=
DGLGraph
()
for
i
in
range
(
10
):
g
.
add_node
(
i
,
__REPR__
=
i
+
1
)
# 10 nodes.
# create a graph where 0 is the source and 9 is the sink
for
i
in
range
(
1
,
9
):
g
.
add_edge
(
0
,
i
)
g
.
add_edge
(
i
,
9
)
return
g
def
test_no_msg_recv
():
g
=
generate_graph
()
check
(
g
,
[
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
])
g
.
register_message_func
(
message_not_called
)
g
.
register_reduce_func
(
reduce_not_called
)
g
.
register_apply_node_func
(
lambda
h
:
h
+
1
)
for
i
in
range
(
10
):
g
.
recv
(
i
)
check
(
g
,
[
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
,
11
])
def
test_double_recv
():
g
=
generate_graph
()
check
(
g
,
[
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
])
g
.
register_message_func
(
message_func
)
g
.
register_reduce_func
(
reduce_func
)
g
.
send
(
1
,
9
)
g
.
send
(
2
,
9
)
g
.
recv
(
9
)
check
(
g
,
[
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
15
])
g
.
register_reduce_func
(
reduce_not_called
)
g
.
recv
(
9
)
def
test_pull_0deg
():
g
=
DGLGraph
()
g
.
add_node
(
0
,
h
=
2
)
g
.
add_node
(
1
,
h
=
1
)
g
.
add_edge
(
0
,
1
)
def
_message
(
src
,
edge
):
assert
False
return
src
def
_reduce
(
node
,
msgs
):
assert
False
return
node
def
_update
(
node
):
return
{
'h'
:
node
[
'h'
]
*
2
}
g
.
pull
(
0
,
_message
,
_reduce
,
_update
)
assert
g
.
nodes
[
0
][
'h'
]
==
4
if
__name__
==
'__main__'
:
test_no_msg_recv
()
test_double_recv
()
test_pull_0deg
()
tests/test_function.py
deleted
100644 → 0
View file @
3a3e5d48
import
dgl
import
dgl.function
as
fn
from
dgl.graph
import
__REPR__
def
generate_graph
():
g
=
dgl
.
DGLGraph
()
for
i
in
range
(
10
):
g
.
add_node
(
i
,
h
=
i
+
1
)
# 10 nodes.
# create a graph where 0 is the source and 9 is the sink
for
i
in
range
(
1
,
9
):
g
.
add_edge
(
0
,
i
,
h
=
1
)
g
.
add_edge
(
i
,
9
,
h
=
i
+
1
)
# add a back flow from 9 to 0
g
.
add_edge
(
9
,
0
,
h
=
10
)
return
g
def
check
(
g
,
h
,
fld
):
nh
=
[
str
(
g
.
nodes
[
i
][
fld
])
for
i
in
range
(
10
)]
h
=
[
str
(
x
)
for
x
in
h
]
assert
nh
==
h
,
"nh=[%s], h=[%s]"
%
(
' '
.
join
(
nh
),
' '
.
join
(
h
))
def
generate_graph1
():
"""graph with anonymous repr"""
g
=
dgl
.
DGLGraph
()
for
i
in
range
(
10
):
g
.
add_node
(
i
,
__REPR__
=
i
+
1
)
# 10 nodes.
# create a graph where 0 is the source and 9 is the sink
for
i
in
range
(
1
,
9
):
g
.
add_edge
(
0
,
i
,
__REPR__
=
1
)
g
.
add_edge
(
i
,
9
,
__REPR__
=
i
+
1
)
# add a back flow from 9 to 0
g
.
add_edge
(
9
,
0
,
__REPR__
=
10
)
return
g
def
test_copy_src
():
# copy_src with both fields
g
=
generate_graph
()
g
.
register_message_func
(
fn
.
copy_src
(
src
=
'h'
,
out
=
'm'
),
batchable
=
False
)
g
.
register_reduce_func
(
fn
.
sum
(
msgs
=
'm'
,
out
=
'h'
),
batchable
=
False
)
check
(
g
,
[
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
],
'h'
)
g
.
update_all
()
check
(
g
,
[
10
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
44
],
'h'
)
# copy_src with only src field; the out field should use anonymous repr
g
=
generate_graph
()
g
.
register_message_func
(
fn
.
copy_src
(
src
=
'h'
),
batchable
=
False
)
g
.
register_reduce_func
(
fn
.
sum
(
out
=
'h'
),
batchable
=
False
)
check
(
g
,
[
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
],
'h'
)
g
.
update_all
()
check
(
g
,
[
10
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
44
],
'h'
)
# copy_src with no src field; should use anonymous repr
g
=
generate_graph1
()
g
.
register_message_func
(
fn
.
copy_src
(
out
=
'm'
),
batchable
=
False
)
g
.
register_reduce_func
(
fn
.
sum
(
msgs
=
'm'
,
out
=
'h'
),
batchable
=
False
)
check
(
g
,
[
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
],
__REPR__
)
g
.
update_all
()
check
(
g
,
[
10
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
44
],
'h'
)
# copy src with no fields;
g
=
generate_graph1
()
g
.
register_message_func
(
fn
.
copy_src
(),
batchable
=
False
)
g
.
register_reduce_func
(
fn
.
sum
(
out
=
'h'
),
batchable
=
False
)
check
(
g
,
[
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
],
__REPR__
)
g
.
update_all
()
check
(
g
,
[
10
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
44
],
'h'
)
def
test_copy_edge
():
# copy_edge with both fields
g
=
generate_graph
()
g
.
register_message_func
(
fn
.
copy_edge
(
edge
=
'h'
,
out
=
'm'
),
batchable
=
False
)
g
.
register_reduce_func
(
fn
.
sum
(
msgs
=
'm'
,
out
=
'h'
),
batchable
=
False
)
check
(
g
,
[
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
],
'h'
)
g
.
update_all
()
check
(
g
,
[
10
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
44
],
'h'
)
# copy_edge with only edge field; the out field should use anonymous repr
g
=
generate_graph
()
g
.
register_message_func
(
fn
.
copy_edge
(
edge
=
'h'
),
batchable
=
False
)
g
.
register_reduce_func
(
fn
.
sum
(
out
=
'h'
),
batchable
=
False
)
check
(
g
,
[
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
],
'h'
)
g
.
update_all
()
check
(
g
,
[
10
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
44
],
'h'
)
# copy_edge with no edge field; should use anonymous repr
g
=
generate_graph1
()
g
.
register_message_func
(
fn
.
copy_edge
(
out
=
'm'
),
batchable
=
False
)
g
.
register_reduce_func
(
fn
.
sum
(
msgs
=
'm'
,
out
=
'h'
),
batchable
=
False
)
check
(
g
,
[
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
],
__REPR__
)
g
.
update_all
()
check
(
g
,
[
10
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
44
],
'h'
)
# copy edge with no fields;
g
=
generate_graph1
()
g
.
register_message_func
(
fn
.
copy_edge
(),
batchable
=
False
)
g
.
register_reduce_func
(
fn
.
sum
(
out
=
'h'
),
batchable
=
False
)
check
(
g
,
[
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
],
__REPR__
)
g
.
update_all
()
check
(
g
,
[
10
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
44
],
'h'
)
def
test_src_mul_edge
():
# src_mul_edge with all fields
g
=
generate_graph
()
g
.
register_message_func
(
fn
.
src_mul_edge
(
src
=
'h'
,
edge
=
'h'
,
out
=
'm'
),
batchable
=
False
)
g
.
register_reduce_func
(
fn
.
sum
(
msgs
=
'm'
,
out
=
'h'
),
batchable
=
False
)
check
(
g
,
[
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
],
'h'
)
g
.
update_all
()
check
(
g
,
[
100
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
284
],
'h'
)
g
=
generate_graph
()
g
.
register_message_func
(
fn
.
src_mul_edge
(
src
=
'h'
,
edge
=
'h'
),
batchable
=
False
)
g
.
register_reduce_func
(
fn
.
sum
(
out
=
'h'
),
batchable
=
False
)
check
(
g
,
[
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
],
'h'
)
g
.
update_all
()
check
(
g
,
[
100
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
284
],
'h'
)
g
=
generate_graph1
()
g
.
register_message_func
(
fn
.
src_mul_edge
(
out
=
'm'
),
batchable
=
False
)
g
.
register_reduce_func
(
fn
.
sum
(
msgs
=
'm'
,
out
=
'h'
),
batchable
=
False
)
check
(
g
,
[
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
],
__REPR__
)
g
.
update_all
()
check
(
g
,
[
100
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
284
],
'h'
)
g
=
generate_graph1
()
g
.
register_message_func
(
fn
.
src_mul_edge
(),
batchable
=
False
)
g
.
register_reduce_func
(
fn
.
sum
(
out
=
'h'
),
batchable
=
False
)
check
(
g
,
[
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
],
__REPR__
)
g
.
update_all
()
check
(
g
,
[
100
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
284
],
'h'
)
g
=
generate_graph1
()
g
.
register_message_func
(
fn
.
src_mul_edge
(),
batchable
=
False
)
g
.
register_reduce_func
(
fn
.
sum
(),
batchable
=
False
)
check
(
g
,
[
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
],
__REPR__
)
g
.
update_all
()
check
(
g
,
[
100
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
284
],
__REPR__
)
if
__name__
==
'__main__'
:
test_copy_src
()
test_copy_edge
()
test_src_mul_edge
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment