Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
9219349a
Unverified
Commit
9219349a
authored
Jul 10, 2018
by
Minjie Wang
Committed by
GitHub
Jul 10, 2018
Browse files
Use edge update to impl sendto; fix examples with missing reduce func (#22)
parent
68fb5f7e
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
84 additions
and
51 deletions
+84
-51
examples/pagerank.py
examples/pagerank.py
+3
-2
python/dgl/graph.py
python/dgl/graph.py
+76
-46
tests/test_basics.py
tests/test_basics.py
+5
-3
No files found.
examples/pagerank.py
View file @
9219349a
...
@@ -10,8 +10,8 @@ K = 10
...
@@ -10,8 +10,8 @@ K = 10
def
message_func
(
src
,
dst
,
edge
):
def
message_func
(
src
,
dst
,
edge
):
return
src
[
'pv'
]
/
src
[
'deg'
]
return
src
[
'pv'
]
/
src
[
'deg'
]
def
update_func
(
node
,
msgs
):
def
update_func
(
node
,
accum
):
pv
=
(
1
-
DAMP
)
/
N
+
DAMP
*
sum
(
msgs
)
pv
=
(
1
-
DAMP
)
/
N
+
DAMP
*
accum
return
{
'pv'
:
pv
}
return
{
'pv'
:
pv
}
def
compute_pagerank
(
g
):
def
compute_pagerank
(
g
):
...
@@ -19,6 +19,7 @@ def compute_pagerank(g):
...
@@ -19,6 +19,7 @@ def compute_pagerank(g):
print
(
g
.
number_of_edges
(),
g
.
number_of_nodes
())
print
(
g
.
number_of_edges
(),
g
.
number_of_nodes
())
g
.
register_message_func
(
message_func
)
g
.
register_message_func
(
message_func
)
g
.
register_update_func
(
update_func
)
g
.
register_update_func
(
update_func
)
g
.
register_reduce_func
(
'sum'
)
# init pv value
# init pv value
for
n
in
g
.
nodes
():
for
n
in
g
.
nodes
():
g
.
node
[
n
][
'pv'
]
=
1
/
N
g
.
node
[
n
][
'pv'
]
=
1
/
N
...
...
python/dgl/graph.py
View file @
9219349a
...
@@ -16,6 +16,7 @@ __MFUNC__ = "__mfunc__"
...
@@ -16,6 +16,7 @@ __MFUNC__ = "__mfunc__"
__EFUNC__
=
"__efunc__"
__EFUNC__
=
"__efunc__"
__UFUNC__
=
"__ufunc__"
__UFUNC__
=
"__ufunc__"
__RFUNC__
=
"__rfunc__"
__RFUNC__
=
"__rfunc__"
__READOUT__
=
"__readout__"
class
DGLGraph
(
DiGraph
):
class
DGLGraph
(
DiGraph
):
"""Base graph class specialized for neural networks on graphs.
"""Base graph class specialized for neural networks on graphs.
...
@@ -31,10 +32,7 @@ class DGLGraph(DiGraph):
...
@@ -31,10 +32,7 @@ class DGLGraph(DiGraph):
"""
"""
def
__init__
(
self
,
graph_data
=
None
,
**
attr
):
def
__init__
(
self
,
graph_data
=
None
,
**
attr
):
super
(
DGLGraph
,
self
).
__init__
(
graph_data
,
**
attr
)
super
(
DGLGraph
,
self
).
__init__
(
graph_data
,
**
attr
)
self
.
m_func
=
None
self
.
_glb_func
=
{}
self
.
u_func
=
None
self
.
e_func
=
None
self
.
readout_func
=
None
def
init_reprs
(
self
,
h_init
=
None
):
def
init_reprs
(
self
,
h_init
=
None
):
print
(
"[DEPRECATED]: please directly set node attrs "
print
(
"[DEPRECATED]: please directly set node attrs "
...
@@ -55,12 +53,16 @@ class DGLGraph(DiGraph):
...
@@ -55,12 +53,16 @@ class DGLGraph(DiGraph):
assert
u
in
self
.
nodes
assert
u
in
self
.
nodes
return
self
.
nodes
[
u
][
name
]
return
self
.
nodes
[
u
][
name
]
def
register_message_func
(
self
,
message_func
,
edges
=
'all'
,
batchable
=
False
):
def
register_message_func
(
self
,
message_func
,
edges
=
'all'
,
batchable
=
False
,
name
=
__MFUNC__
):
"""Register computation on edges.
"""Register computation on edges.
The message function should be compatible with following signature:
The message function should be compatible with following signature:
(node_reprs, node_reprs, edge_reprs) ->
edge_reprs
(node_reprs, node_reprs, edge_reprs) ->
msg
It computes the representation of a message
It computes the representation of a message
using the representations of the source node, target node and the edge
using the representations of the source node, target node and the edge
...
@@ -76,6 +78,8 @@ class DGLGraph(DiGraph):
...
@@ -76,6 +78,8 @@ class DGLGraph(DiGraph):
supported.
supported.
batchable : bool
batchable : bool
Whether the provided message function allows batch computing.
Whether the provided message function allows batch computing.
name : str
The name of the function.
Examples
Examples
--------
--------
...
@@ -91,13 +95,15 @@ class DGLGraph(DiGraph):
...
@@ -91,13 +95,15 @@ class DGLGraph(DiGraph):
>>> v = [v1, v2, v3, ...]
>>> v = [v1, v2, v3, ...]
>>> g.register_message_func(mfunc, (u, v))
>>> g.register_message_func(mfunc, (u, v))
"""
"""
if
edges
==
'all'
:
def
_msg_edge_func
(
u
,
v
,
e_uv
):
self
.
m_func
=
message_func
return
{
__MSG__
:
message_func
(
u
,
v
,
e_uv
)}
else
:
self
.
register_edge_func
(
_msg_edge_func
,
edges
,
batchable
,
name
)
for
e
in
edges
:
self
.
edges
[
e
][
__MFUNC__
]
=
message_func
def
register_edge_func
(
self
,
edge_func
,
def
register_edge_func
(
self
,
edge_func
,
edges
=
'all'
,
batchable
=
False
):
edges
=
'all'
,
batchable
=
False
,
name
=
__EFUNC__
):
"""Register computation on edges.
"""Register computation on edges.
The edge function should be compatible with following signature:
The edge function should be compatible with following signature:
...
@@ -118,6 +124,8 @@ class DGLGraph(DiGraph):
...
@@ -118,6 +124,8 @@ class DGLGraph(DiGraph):
supported.
supported.
batchable : bool
batchable : bool
Whether the provided message function allows batch computing.
Whether the provided message function allows batch computing.
name : str
The name of the function.
Examples
Examples
--------
--------
...
@@ -134,12 +142,16 @@ class DGLGraph(DiGraph):
...
@@ -134,12 +142,16 @@ class DGLGraph(DiGraph):
>>> g.register_edge_func(mfunc, (u, v))
>>> g.register_edge_func(mfunc, (u, v))
"""
"""
if
edges
==
'all'
:
if
edges
==
'all'
:
self
.
e
_func
=
edge_func
self
.
_glb
_func
[
name
]
=
edge_func
else
:
else
:
for
e
in
edges
:
for
e
in
edges
:
self
.
edges
[
e
][
__EFUNC__
]
=
edge_func
self
.
edges
[
e
][
name
]
=
edge_func
def
register_reduce_func
(
self
,
reduce_func
,
nodes
=
'all'
,
batchable
=
False
):
def
register_reduce_func
(
self
,
reduce_func
,
nodes
=
'all'
,
batchable
=
False
,
name
=
__RFUNC__
):
"""Register message reduce function on incoming edges.
"""Register message reduce function on incoming edges.
The reduce function should be compatible with following signature:
The reduce function should be compatible with following signature:
...
@@ -163,6 +175,8 @@ class DGLGraph(DiGraph):
...
@@ -163,6 +175,8 @@ class DGLGraph(DiGraph):
supported.
supported.
batchable : bool
batchable : bool
Whether the provided reduce function allows batch computing.
Whether the provided reduce function allows batch computing.
name : str
The name of the function.
Examples
Examples
--------
--------
...
@@ -187,12 +201,16 @@ class DGLGraph(DiGraph):
...
@@ -187,12 +201,16 @@ class DGLGraph(DiGraph):
raise
NotImplementedError
(
raise
NotImplementedError
(
"Built-in function %s not implemented"
%
reduce_func
)
"Built-in function %s not implemented"
%
reduce_func
)
if
nodes
==
'all'
:
if
nodes
==
'all'
:
self
.
r
_func
=
reduce_func
self
.
_glb
_func
[
name
]
=
reduce_func
else
:
else
:
for
n
in
nodes
:
for
n
in
nodes
:
self
.
nodes
[
n
][
__RFUNC__
]
=
reduce_func
self
.
nodes
[
n
][
name
]
=
reduce_func
def
register_update_func
(
self
,
update_func
,
nodes
=
'all'
,
batchable
=
False
):
def
register_update_func
(
self
,
update_func
,
nodes
=
'all'
,
batchable
=
False
,
name
=
__UFUNC__
):
"""Register computation on nodes.
"""Register computation on nodes.
The update function should be compatible with following signature:
The update function should be compatible with following signature:
...
@@ -213,6 +231,8 @@ class DGLGraph(DiGraph):
...
@@ -213,6 +231,8 @@ class DGLGraph(DiGraph):
supported.
supported.
batchable : bool
batchable : bool
Whether the provided update function allows batch computing.
Whether the provided update function allows batch computing.
name : str
The name of the function.
Examples
Examples
--------
--------
...
@@ -228,12 +248,12 @@ class DGLGraph(DiGraph):
...
@@ -228,12 +248,12 @@ class DGLGraph(DiGraph):
>>> g.register_update_func(ufunc, u)
>>> g.register_update_func(ufunc, u)
"""
"""
if
nodes
==
'all'
:
if
nodes
==
'all'
:
self
.
u
_func
=
update_func
self
.
_glb
_func
[
name
]
=
update_func
else
:
else
:
for
n
in
nodes
:
for
n
in
nodes
:
self
.
nodes
[
n
][
__UFUNC__
]
=
update_func
self
.
nodes
[
n
][
name
]
=
update_func
def
register_readout_func
(
self
,
readout_func
):
def
register_readout_func
(
self
,
readout_func
,
name
=
__READOUT__
):
"""Register computation on the whole graph.
"""Register computation on the whole graph.
The readout_func should be compatible with following signature:
The readout_func should be compatible with following signature:
...
@@ -251,14 +271,20 @@ class DGLGraph(DiGraph):
...
@@ -251,14 +271,20 @@ class DGLGraph(DiGraph):
----------
----------
readout_func : callable
readout_func : callable
The readout function.
The readout function.
name : str
The name of the function.
See Also
See Also
--------
--------
readout
readout
"""
"""
self
.
readout_func
=
readout_func
self
.
_glb_func
[
name
]
=
readout_func
def
readout
(
self
,
nodes
=
'all'
,
edges
=
'all'
,
**
kwargs
):
def
readout
(
self
,
nodes
=
'all'
,
edges
=
'all'
,
name
=
__READOUT__
,
**
kwargs
):
"""Trigger the readout function on the specified nodes/edges.
"""Trigger the readout function on the specified nodes/edges.
Parameters
Parameters
...
@@ -267,19 +293,21 @@ class DGLGraph(DiGraph):
...
@@ -267,19 +293,21 @@ class DGLGraph(DiGraph):
The nodes to get reprs from.
The nodes to get reprs from.
edges : str, pair of nodes, pair of containers or pair of tensors
edges : str, pair of nodes, pair of containers or pair of tensors
The edges to get reprs from.
The edges to get reprs from.
name : str
The name of the function.
kwargs : keyword arguments, optional
kwargs : keyword arguments, optional
Arguments for the readout function.
Arguments for the readout function.
"""
"""
nodes
=
self
.
_nodes_or_all
(
nodes
)
nodes
=
self
.
_nodes_or_all
(
nodes
)
edges
=
self
.
_edges_or_all
(
edges
)
edges
=
self
.
_edges_or_all
(
edges
)
assert
self
.
readout_func
is
not
None
,
\
assert
name
in
self
.
_glb_func
,
\
"Readout function
i
s not registered."
"Readout function
\"
%s
\"
ha
s not
been
registered."
%
name
# TODO(minjie): tensorize following loop.
# TODO(minjie): tensorize following loop.
nstates
=
[
self
.
nodes
[
n
]
for
n
in
nodes
]
nstates
=
[
self
.
nodes
[
n
]
for
n
in
nodes
]
estates
=
[
self
.
edges
[
e
]
for
e
in
edges
]
estates
=
[
self
.
edges
[
e
]
for
e
in
edges
]
return
self
.
readout_func
(
nstates
,
estates
,
**
kwargs
)
return
self
.
_glb_func
[
name
]
(
nstates
,
estates
,
**
kwargs
)
def
sendto
(
self
,
u
,
v
):
def
sendto
(
self
,
u
,
v
,
name
=
__MFUNC__
):
"""Trigger the message function on edge u->v
"""Trigger the message function on edge u->v
Parameters
Parameters
...
@@ -288,16 +316,12 @@ class DGLGraph(DiGraph):
...
@@ -288,16 +316,12 @@ class DGLGraph(DiGraph):
The source node(s).
The source node(s).
v : node, container or tensor
v : node, container or tensor
The destination node(s).
The destination node(s).
name : str
The name of the function.
"""
"""
# TODO(minjie): tensorize the loop.
self
.
update_edge
(
u
,
v
,
name
)
for
uu
,
vv
in
utils
.
edge_iter
(
u
,
v
):
f_msg
=
self
.
edges
[
uu
,
vv
].
get
(
__MFUNC__
,
self
.
m_func
)
assert
f_msg
is
not
None
,
\
"message function not registered for edge (%s->%s)"
%
(
uu
,
vv
)
m
=
f_msg
(
self
.
nodes
[
uu
],
self
.
nodes
[
vv
],
self
.
edges
[
uu
,
vv
])
self
.
edges
[
uu
,
vv
][
__MSG__
]
=
m
def
update_edge
(
self
,
u
,
v
):
def
update_edge
(
self
,
u
,
v
,
name
=
__EFUNC__
):
"""Update representation on edge u->v
"""Update representation on edge u->v
Parameters
Parameters
...
@@ -306,16 +330,19 @@ class DGLGraph(DiGraph):
...
@@ -306,16 +330,19 @@ class DGLGraph(DiGraph):
The source node(s).
The source node(s).
v : node, container or tensor
v : node, container or tensor
The destination node(s).
The destination node(s).
name : str
The name of the function.
"""
"""
# TODO(minjie): tensorize the loop.
# TODO(minjie): tensorize the loop.
efunc
=
self
.
_glb_func
.
get
(
name
)
for
uu
,
vv
in
utils
.
edge_iter
(
u
,
v
):
for
uu
,
vv
in
utils
.
edge_iter
(
u
,
v
):
f_edge
=
self
.
edges
[
uu
,
vv
].
get
(
__EFUNC__
,
self
.
m_
func
)
f_edge
=
self
.
edges
[
uu
,
vv
].
get
(
name
,
e
func
)
assert
f_edge
is
not
None
,
\
assert
f_edge
is
not
None
,
\
"edge function not registered for edge (%s->%s)"
%
(
uu
,
vv
)
"edge function
\"
%s
\"
not registered for edge (%s->%s)"
%
(
name
,
uu
,
vv
)
m
=
f_edge
(
self
.
nodes
[
uu
],
self
.
nodes
[
vv
],
self
.
edges
[
uu
,
vv
])
m
=
f_edge
(
self
.
nodes
[
uu
],
self
.
nodes
[
vv
],
self
.
edges
[
uu
,
vv
])
self
.
edges
[
uu
,
vv
]
[
__E_REPR__
]
=
m
self
.
edges
[
uu
,
vv
]
.
update
(
m
)
def
recvfrom
(
self
,
u
,
preds
=
None
):
def
recvfrom
(
self
,
u
,
preds
=
None
,
rname
=
__RFUNC__
,
uname
=
__UFUNC__
):
"""Trigger the update function on node u.
"""Trigger the update function on node u.
It computes the new node state using the messages and edge
It computes the new node state using the messages and edge
...
@@ -330,9 +357,15 @@ class DGLGraph(DiGraph):
...
@@ -330,9 +357,15 @@ class DGLGraph(DiGraph):
preds : container
preds : container
Nodes with pre-computed messages to u. Default is all
Nodes with pre-computed messages to u. Default is all
the predecessors.
the predecessors.
rname : str
The name of reduce function.
uname : str
The name of update function.
"""
"""
u_is_container
=
isinstance
(
u
,
list
)
u_is_container
=
isinstance
(
u
,
list
)
u_is_tensor
=
isinstance
(
u
,
Tensor
)
u_is_tensor
=
isinstance
(
u
,
Tensor
)
rfunc
=
self
.
_glb_func
.
get
(
rname
)
ufunc
=
self
.
_glb_func
.
get
(
uname
)
# TODO(minjie): tensorize the loop.
# TODO(minjie): tensorize the loop.
for
i
,
uu
in
enumerate
(
utils
.
node_iter
(
u
)):
for
i
,
uu
in
enumerate
(
utils
.
node_iter
(
u
)):
if
preds
is
None
:
if
preds
is
None
:
...
@@ -342,12 +375,12 @@ class DGLGraph(DiGraph):
...
@@ -342,12 +375,12 @@ class DGLGraph(DiGraph):
else
:
else
:
v
=
preds
v
=
preds
# TODO(minjie): tensorize the message batching
# TODO(minjie): tensorize the message batching
m
=
[
self
.
edges
[
vv
,
uu
][
__MSG__
]
for
vv
in
v
]
f_reduce
=
self
.
nodes
[
uu
].
get
(
rname
,
rfunc
)
f_reduce
=
self
.
nodes
[
uu
].
get
(
__RFUNC__
,
self
.
r_func
)
assert
f_reduce
is
not
None
,
\
assert
f_reduce
is
not
None
,
\
"Reduce function not registered for node %s"
%
uu
"Reduce function not registered for node %s"
%
uu
m
=
[
self
.
edges
[
vv
,
uu
][
__MSG__
]
for
vv
in
v
]
msgs_reduced_repr
=
f_reduce
(
m
)
msgs_reduced_repr
=
f_reduce
(
m
)
f_update
=
self
.
nodes
[
uu
].
get
(
__UFUNC__
,
self
.
u_
func
)
f_update
=
self
.
nodes
[
uu
].
get
(
uname
,
u
func
)
assert
f_update
is
not
None
,
\
assert
f_update
is
not
None
,
\
"Update function not registered for node %s"
%
uu
"Update function not registered for node %s"
%
uu
self
.
node
[
uu
].
update
(
f_update
(
self
.
nodes
[
uu
],
msgs_reduced_repr
))
self
.
node
[
uu
].
update
(
f_update
(
self
.
nodes
[
uu
],
msgs_reduced_repr
))
...
@@ -413,9 +446,6 @@ class DGLGraph(DiGraph):
...
@@ -413,9 +446,6 @@ class DGLGraph(DiGraph):
v
=
[
vv
for
_
,
vv
in
self
.
edges
]
v
=
[
vv
for
_
,
vv
in
self
.
edges
]
self
.
sendto
(
u
,
v
)
self
.
sendto
(
u
,
v
)
self
.
recvfrom
(
list
(
self
.
nodes
()))
self
.
recvfrom
(
list
(
self
.
nodes
()))
# TODO(zz): this is a hack
if
self
.
e_func
:
self
.
update_edge
(
u
,
v
)
def
propagate
(
self
,
iterator
=
'bfs'
,
**
kwargs
):
def
propagate
(
self
,
iterator
=
'bfs'
,
**
kwargs
):
"""Propagate messages and update nodes using iterator.
"""Propagate messages and update nodes using iterator.
...
...
tests/test_basics.py
View file @
9219349a
...
@@ -3,9 +3,8 @@ from dgl.graph import DGLGraph
...
@@ -3,9 +3,8 @@ from dgl.graph import DGLGraph
def
message_func
(
src
,
dst
,
edge
):
def
message_func
(
src
,
dst
,
edge
):
return
src
[
'h'
]
return
src
[
'h'
]
def
update_func
(
node
,
msgs
):
def
update_func
(
node
,
accum
):
m
=
sum
(
msgs
)
return
{
'h'
:
node
[
'h'
]
+
accum
}
return
{
'h'
:
node
[
'h'
]
+
m
}
def
generate_graph
():
def
generate_graph
():
g
=
DGLGraph
()
g
=
DGLGraph
()
...
@@ -29,6 +28,7 @@ def test_sendrecv():
...
@@ -29,6 +28,7 @@ def test_sendrecv():
check
(
g
,
[
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
])
check
(
g
,
[
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
])
g
.
register_message_func
(
message_func
)
g
.
register_message_func
(
message_func
)
g
.
register_update_func
(
update_func
)
g
.
register_update_func
(
update_func
)
g
.
register_reduce_func
(
'sum'
)
g
.
sendto
(
0
,
1
)
g
.
sendto
(
0
,
1
)
g
.
recvfrom
(
1
,
[
0
])
g
.
recvfrom
(
1
,
[
0
])
check
(
g
,
[
1
,
3
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
])
check
(
g
,
[
1
,
3
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
])
...
@@ -42,6 +42,7 @@ def test_multi_sendrecv():
...
@@ -42,6 +42,7 @@ def test_multi_sendrecv():
check
(
g
,
[
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
])
check
(
g
,
[
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
])
g
.
register_message_func
(
message_func
)
g
.
register_message_func
(
message_func
)
g
.
register_update_func
(
update_func
)
g
.
register_update_func
(
update_func
)
g
.
register_reduce_func
(
'sum'
)
# one-many
# one-many
g
.
sendto
(
0
,
[
1
,
2
,
3
])
g
.
sendto
(
0
,
[
1
,
2
,
3
])
g
.
recvfrom
([
1
,
2
,
3
],
[[
0
],
[
0
],
[
0
]])
g
.
recvfrom
([
1
,
2
,
3
],
[[
0
],
[
0
],
[
0
]])
...
@@ -60,6 +61,7 @@ def test_update_routines():
...
@@ -60,6 +61,7 @@ def test_update_routines():
check
(
g
,
[
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
])
check
(
g
,
[
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
])
g
.
register_message_func
(
message_func
)
g
.
register_message_func
(
message_func
)
g
.
register_update_func
(
update_func
)
g
.
register_update_func
(
update_func
)
g
.
register_reduce_func
(
'sum'
)
g
.
update_by_edge
(
0
,
1
)
g
.
update_by_edge
(
0
,
1
)
check
(
g
,
[
1
,
3
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
])
check
(
g
,
[
1
,
3
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
])
g
.
update_to
(
9
)
g
.
update_to
(
9
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment