OpenDAS / dgl / Commits / 9219349a

Unverified commit 9219349a, authored Jul 10, 2018 by Minjie Wang, committed Jul 10, 2018 via GitHub.

Use edge update to impl sendto; fix examples with missing reduce func (#22)

Parent: 68fb5f7e
Showing 3 changed files with 84 additions and 51 deletions (+84 -51):

    examples/pagerank.py    +3   -2
    python/dgl/graph.py     +76  -46
    tests/test_basics.py    +5   -3
examples/pagerank.py

@@ -10,8 +10,8 @@ K = 10
 def message_func(src, dst, edge):
     return src['pv'] / src['deg']

-def update_func(node, msgs):
-    pv = (1 - DAMP) / N + DAMP * sum(msgs)
+def update_func(node, accum):
+    pv = (1 - DAMP) / N + DAMP * accum
     return {'pv' : pv}

 def compute_pagerank(g):
@@ -19,6 +19,7 @@ def compute_pagerank(g):
     print(g.number_of_edges(), g.number_of_nodes())
     g.register_message_func(message_func)
     g.register_update_func(update_func)
+    g.register_reduce_func('sum')
     # init pv value
     for n in g.nodes():
         g.node[n]['pv'] = 1 / N
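The example change above splits what the old update_func did in one step into two stages: register_reduce_func('sum') collapses the incoming messages into a single accumulation, and update_func then turns that accumulation into the new 'pv' value. A minimal plain-Python sketch of that reduce-then-update split (the toy message values, DAMP and N below are illustrative, not part of the diff):

    # Hedged sketch: emulate the reduce-then-update split used by pagerank.py.
    # DAMP, N and the message values below are illustrative placeholders.
    DAMP = 0.85
    N = 3

    def reduce_func(msgs):
        # corresponds to register_reduce_func('sum'): collapse messages first
        return sum(msgs)

    def update_func(node, accum):
        # receives the already-reduced accumulation, not the raw message list
        pv = (1 - DAMP) / N + DAMP * accum
        return {'pv': pv}

    msgs = [0.1, 0.25, 0.05]            # messages src['pv'] / src['deg'] from neighbors
    accum = reduce_func(msgs)           # reduce stage
    print(update_func({'pv': 1 / N}, accum))  # update stage -> {'pv': ~0.39}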
python/dgl/graph.py

@@ -16,6 +16,7 @@ __MFUNC__ = "__mfunc__"
 __EFUNC__ = "__efunc__"
 __UFUNC__ = "__ufunc__"
+__RFUNC__ = "__rfunc__"
 __READOUT__ = "__readout__"

 class DGLGraph(DiGraph):
     """Base graph class specialized for neural networks on graphs.
@@ -31,10 +32,7 @@ class DGLGraph(DiGraph):
     """
     def __init__(self, graph_data=None, **attr):
         super(DGLGraph, self).__init__(graph_data, **attr)
-        self.m_func = None
-        self.u_func = None
-        self.e_func = None
-        self.readout_func = None
+        self._glb_func = {}

     def init_reprs(self, h_init=None):
         print("[DEPRECATED]: please directly set node attrs "
@@ -55,12 +53,16 @@ class DGLGraph(DiGraph):
         assert u in self.nodes
         return self.nodes[u][name]

-    def register_message_func(self, message_func, edges='all', batchable=False):
+    def register_message_func(self, message_func, edges='all', batchable=False, name=__MFUNC__):
         """Register computation on edges.

         The message function should be compatible with following signature:

-        (node_reprs, node_reprs, edge_reprs) -> edge_reprs
+        (node_reprs, node_reprs, edge_reprs) -> msg

         It computes the representation of a message
         using the representations of the source node, target node and the edge
@@ -76,6 +78,8 @@ class DGLGraph(DiGraph):
             supported.
         batchable : bool
             Whether the provided message function allows batch computing.
+        name : str
+            The name of the function.

         Examples
         --------
@@ -91,13 +95,15 @@ class DGLGraph(DiGraph):
         >>> v = [v1, v2, v3, ...]
         >>> g.register_message_func(mfunc, (u, v))
         """
-        if edges == 'all':
-            self.m_func = message_func
-        else:
-            for e in edges:
-                self.edges[e][__MFUNC__] = message_func
+        def _msg_edge_func(u, v, e_uv):
+            return {__MSG__ : message_func(u, v, e_uv)}
+        self.register_edge_func(_msg_edge_func, edges, batchable, name)

-    def register_edge_func(self, edge_func, edges='all', batchable=False):
+    def register_edge_func(self, edge_func, edges='all', batchable=False, name=__EFUNC__):
         """Register computation on edges.

         The edge function should be compatible with following signature:
@@ -118,6 +124,8 @@ class DGLGraph(DiGraph):
             supported.
         batchable : bool
             Whether the provided message function allows batch computing.
+        name : str
+            The name of the function.

         Examples
         --------
@@ -134,12 +142,16 @@ class DGLGraph(DiGraph):
         >>> g.register_edge_func(mfunc, (u, v))
         """
         if edges == 'all':
-            self.e_func = edge_func
+            self._glb_func[name] = edge_func
         else:
             for e in edges:
-                self.edges[e][__EFUNC__] = edge_func
+                self.edges[e][name] = edge_func

-    def register_reduce_func(self, reduce_func, nodes='all', batchable=False):
+    def register_reduce_func(self, reduce_func, nodes='all', batchable=False, name=__RFUNC__):
         """Register message reduce function on incoming edges.

         The reduce function should be compatible with following signature:
@@ -163,6 +175,8 @@ class DGLGraph(DiGraph):
             supported.
         batchable : bool
             Whether the provided reduce function allows batch computing.
+        name : str
+            The name of the function.

         Examples
         --------
@@ -187,12 +201,16 @@ class DGLGraph(DiGraph):
                 raise NotImplementedError(
                         "Built-in function %s not implemented" % reduce_func)
         if nodes == 'all':
-            self.r_func = reduce_func
+            self._glb_func[name] = reduce_func
         else:
             for n in nodes:
-                self.nodes[n][__RFUNC__] = reduce_func
+                self.nodes[n][name] = reduce_func

-    def register_update_func(self, update_func, nodes='all', batchable=False):
+    def register_update_func(self, update_func, nodes='all', batchable=False, name=__UFUNC__):
         """Register computation on nodes.

         The update function should be compatible with following signature:
@@ -213,6 +231,8 @@ class DGLGraph(DiGraph):
             supported.
         batchable : bool
             Whether the provided update function allows batch computing.
+        name : str
+            The name of the function.

         Examples
         --------
@@ -228,12 +248,12 @@ class DGLGraph(DiGraph):
         >>> g.register_update_func(ufunc, u)
         """
         if nodes == 'all':
-            self.u_func = update_func
+            self._glb_func[name] = update_func
         else:
             for n in nodes:
-                self.nodes[n][__UFUNC__] = update_func
+                self.nodes[n][name] = update_func

-    def register_readout_func(self, readout_func):
+    def register_readout_func(self, readout_func, name=__READOUT__):
         """Register computation on the whole graph.

         The readout_func should be compatible with following signature:
@@ -251,14 +271,20 @@ class DGLGraph(DiGraph):
         ----------
         readout_func : callable
             The readout function.
+        name : str
+            The name of the function.

         See Also
         --------
         readout
         """
-        self.readout_func = readout_func
+        self._glb_func[name] = readout_func

-    def readout(self, nodes='all', edges='all', **kwargs):
+    def readout(self, nodes='all', edges='all', name=__READOUT__, **kwargs):
         """Trigger the readout function on the specified nodes/edges.

         Parameters
@@ -267,19 +293,21 @@ class DGLGraph(DiGraph):
             The nodes to get reprs from.
         edges : str, pair of nodes, pair of containers or pair of tensors
             The edges to get reprs from.
+        name : str
+            The name of the function.
         kwargs : keyword arguments, optional
             Arguments for the readout function.
         """
         nodes = self._nodes_or_all(nodes)
         edges = self._edges_or_all(edges)
-        assert self.readout_func is not None, \
-            "Readout function is not registered."
+        assert name in self._glb_func, \
+            "Readout function \"%s\" has not been registered." % name
         # TODO(minjie): tensorize following loop.
         nstates = [self.nodes[n] for n in nodes]
         estates = [self.edges[e] for e in edges]
-        return self.readout_func(nstates, estates, **kwargs)
+        return self._glb_func[name](nstates, estates, **kwargs)

-    def sendto(self, u, v):
+    def sendto(self, u, v, name=__MFUNC__):
         """Trigger the message function on edge u->v

         Parameters
@@ -288,16 +316,12 @@ class DGLGraph(DiGraph):
             The source node(s).
         v : node, container or tensor
             The destination node(s).
+        name : str
+            The name of the function.
         """
-        # TODO(minjie): tensorize the loop.
-        for uu, vv in utils.edge_iter(u, v):
-            f_msg = self.edges[uu, vv].get(__MFUNC__, self.m_func)
-            assert f_msg is not None, \
-                "message function not registered for edge (%s->%s)" % (uu, vv)
-            m = f_msg(self.nodes[uu], self.nodes[vv], self.edges[uu, vv])
-            self.edges[uu, vv][__MSG__] = m
+        self.update_edge(u, v, name)

-    def update_edge(self, u, v):
+    def update_edge(self, u, v, name=__EFUNC__):
         """Update representation on edge u->v

         Parameters
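The sendto hunk above is the heart of the commit: register_message_func no longer stores the raw message function; it wraps it into an edge function whose output is saved on the edge under __MSG__, and sendto simply triggers update_edge with that function's name. A hedged sketch of the wrapping idea with plain dicts standing in for DGLGraph's node and edge storage (the module-level registry and helper names here are stand-ins, not the library's API):

    # Hedged sketch of the message-as-edge-function wrapping, using plain dicts
    # instead of DGLGraph internals; __MSG__ and _glb_func are stand-ins.
    __MSG__ = '__msg__'

    _glb_func = {}   # name -> callable: the global function registry

    def register_message_func(message_func, name='__mfunc__'):
        # Wrap the message function so it behaves like an edge function:
        # its result is stored on the edge under the __MSG__ key.
        def _msg_edge_func(u, v, e_uv):
            return {__MSG__: message_func(u, v, e_uv)}
        _glb_func[name] = _msg_edge_func

    def update_edge(edge_data, u, v, name):
        # Look up the named edge function and merge its output into the edge dict.
        f_edge = _glb_func[name]
        edge_data.update(f_edge(u, v, edge_data))

    def sendto(edge_data, u, v, name='__mfunc__'):
        # sendto is now just update_edge triggered with the message function's name.
        update_edge(edge_data, u, v, name)

    register_message_func(lambda src, dst, edge: src['h'])
    edge = {}
    sendto(edge, {'h': 2.0}, {'h': 5.0})
    print(edge)   # {'__msg__': 2.0}

The remaining hunks below apply the same name-based lookup to update_edge and recvfrom.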
@@ -306,16 +330,19 @@ class DGLGraph(DiGraph):
             The source node(s).
         v : node, container or tensor
             The destination node(s).
+        name : str
+            The name of the function.
         """
         # TODO(minjie): tensorize the loop.
+        efunc = self._glb_func.get(name)
         for uu, vv in utils.edge_iter(u, v):
-            f_edge = self.edges[uu, vv].get(__EFUNC__, self.m_func)
+            f_edge = self.edges[uu, vv].get(name, efunc)
             assert f_edge is not None, \
-                "edge function not registered for edge (%s->%s)" % (uu, vv)
+                "edge function \"%s\" not registered for edge (%s->%s)" % (name, uu, vv)
             m = f_edge(self.nodes[uu], self.nodes[vv], self.edges[uu, vv])
-            self.edges[uu, vv][__E_REPR__] = m
+            self.edges[uu, vv].update(m)

-    def recvfrom(self, u, preds=None):
+    def recvfrom(self, u, preds=None, rname=__RFUNC__, uname=__UFUNC__):
         """Trigger the update function on node u.

         It computes the new node state using the messages and edge
@@ -330,9 +357,15 @@ class DGLGraph(DiGraph):
         preds : container
             Nodes with pre-computed messages to u. Default is all
             the predecessors.
+        rname : str
+            The name of reduce function.
+        uname : str
+            The name of update function.
         """
         u_is_container = isinstance(u, list)
         u_is_tensor = isinstance(u, Tensor)
+        rfunc = self._glb_func.get(rname)
+        ufunc = self._glb_func.get(uname)
         # TODO(minjie): tensorize the loop.
         for i, uu in enumerate(utils.node_iter(u)):
             if preds is None:
@@ -342,12 +375,12 @@ class DGLGraph(DiGraph):
             else:
                 v = preds
-            # TODO(minjie): tensorize the message batching
-            m = [self.edges[vv, uu][__MSG__] for vv in v]
-            f_reduce = self.nodes[uu].get(__RFUNC__, self.r_func)
+            f_reduce = self.nodes[uu].get(rname, rfunc)
             assert f_reduce is not None, \
                 "Reduce function not registered for node %s" % uu
+            # TODO(minjie): tensorize the message batching
+            m = [self.edges[vv, uu][__MSG__] for vv in v]
             msgs_reduced_repr = f_reduce(m)
-            f_update = self.nodes[uu].get(__UFUNC__, self.u_func)
+            f_update = self.nodes[uu].get(uname, ufunc)
             assert f_update is not None, \
                 "Update function not registered for node %s" % uu
             self.node[uu].update(f_update(self.nodes[uu], msgs_reduced_repr))
@@ -413,9 +446,6 @@ class DGLGraph(DiGraph):
         v = [vv for _, vv in self.edges]
         self.sendto(u, v)
         self.recvfrom(list(self.nodes()))
-        # TODO(zz): this is a hack
-        if self.e_func:
-            self.update_edge(u, v)

     def propagate(self, iterator='bfs', **kwargs):
         """Propagate messages and update nodes using iterator.
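recvfrom now resolves its reduce and update functions by name: a per-node entry wins, otherwise the global _glb_func registry supplies the function registered under rname/uname. A hedged, dict-based sketch of that lookup-and-fallback order (node storage and registry keys are illustrative, not the real DGLGraph state):

    # Hedged sketch of recvfrom's reduce/update lookup, with plain dicts standing
    # in for node storage and the _glb_func registry.
    _glb_func = {
        '__rfunc__': sum,                                            # global reduce
        '__ufunc__': lambda node, accum: {'h': node['h'] + accum},   # global update
    }

    def recvfrom(node_data, msgs, rname='__rfunc__', uname='__ufunc__'):
        # Per-node override wins; otherwise fall back to the global registry.
        f_reduce = node_data.get(rname, _glb_func.get(rname))
        f_update = node_data.get(uname, _glb_func.get(uname))
        assert f_reduce is not None and f_update is not None
        accum = f_reduce(msgs)                         # reduce incoming messages first
        node_data.update(f_update(node_data, accum))   # then update the node state

    node = {'h': 1}
    recvfrom(node, [2, 3])
    print(node)   # {'h': 6}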
tests/test_basics.py

@@ -3,9 +3,8 @@ from dgl.graph import DGLGraph
 def message_func(src, dst, edge):
     return src['h']

-def update_func(node, msgs):
-    m = sum(msgs)
-    return {'h' : node['h'] + m}
+def update_func(node, accum):
+    return {'h' : node['h'] + accum}

 def generate_graph():
     g = DGLGraph()
@@ -29,6 +28,7 @@ def test_sendrecv():
     check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
     g.register_message_func(message_func)
     g.register_update_func(update_func)
+    g.register_reduce_func('sum')
     g.sendto(0, 1)
     g.recvfrom(1, [0])
     check(g, [1, 3, 3, 4, 5, 6, 7, 8, 9, 10])
@@ -42,6 +42,7 @@ def test_multi_sendrecv():
     check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
     g.register_message_func(message_func)
     g.register_update_func(update_func)
+    g.register_reduce_func('sum')
     # one-many
     g.sendto(0, [1, 2, 3])
     g.recvfrom([1, 2, 3], [[0], [0], [0]])
@@ -60,6 +61,7 @@ def test_update_routines():
     check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
     g.register_message_func(message_func)
     g.register_update_func(update_func)
+    g.register_reduce_func('sum')
     g.update_by_edge(0, 1)
     check(g, [1, 3, 3, 4, 5, 6, 7, 8, 9, 10])
     g.update_to(9)
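The test changes mirror the API change: because update_func now receives the reduced accumulation rather than the raw message list, each test must register a reduce function before calling sendto/recvfrom or update_by_edge. A hedged trace of test_sendrecv's first assertion, done outside DGLGraph (the initial node values 1..10 are taken from the check lists above):

    # Node 1 starts with h=2, receives the message src['h']=1 from node 0,
    # 'sum' reduces the single message to 1, and the update adds it: 2 + 1 = 3.
    def message_func(src, dst, edge):
        return src['h']

    def update_func(node, accum):
        return {'h': node['h'] + accum}

    msg = message_func({'h': 1}, {'h': 2}, {})   # sendto(0, 1)
    accum = sum([msg])                           # register_reduce_func('sum')
    print(update_func({'h': 2}, accum))          # recvfrom(1, [0]) -> {'h': 3}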