Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
f310e586
Commit
f310e586
authored
Jun 14, 2018
by
Minjie Wang
Browse files
Add test/example; fix bug in graph.py
parent
15a2c22c
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
118 additions
and
6 deletions
+118
-6
examples/pagerank.py
examples/pagerank.py
+34
-0
python/dgl/graph.py
python/dgl/graph.py
+9
-6
tests/test_basics.py
tests/test_basics.py
+75
-0
No files found.
examples/pagerank.py
0 → 100644
View file @
f310e586
from
__future__
import
division
import
networkx
as
nx
from
dgl.graph
import
DGLGraph
DAMP
=
0.85
N
=
100
K
=
10
def
message_func
(
src
,
dst
,
edge
):
return
src
[
'pv'
]
/
src
[
'deg'
]
def
update_func
(
node
,
msgs
):
pv
=
(
1
-
DAMP
)
/
N
+
DAMP
*
sum
(
msgs
)
return
{
'pv'
:
pv
}
def
compute_pagerank
(
g
):
g
=
DGLGraph
(
g
)
print
(
g
.
number_of_edges
(),
g
.
number_of_nodes
())
g
.
register_message_func
(
message_func
)
g
.
register_update_func
(
update_func
)
# init pv value
for
n
in
g
.
nodes
():
g
.
node
[
n
][
'pv'
]
=
1
/
N
g
.
node
[
n
][
'deg'
]
=
g
.
out_degree
(
n
)
# pagerank
for
k
in
range
(
K
):
g
.
update_all
()
return
[
g
.
node
[
n
][
'pv'
]
for
n
in
g
.
nodes
()]
if
__name__
==
'__main__'
:
g
=
nx
.
erdos_renyi_graph
(
N
,
0.05
)
pv
=
compute_pagerank
(
g
)
print
(
pv
)
python/dgl/graph.py
View file @
f310e586
"""Base graph class specialized for neural networks on graphs.
"""
from
collections
import
defaultdict
import
networkx
as
nx
from
networkx.classes.digraph
import
DiGraph
...
...
@@ -170,7 +171,7 @@ class DGLGraph(DiGraph):
"""
nodes
=
self
.
_nodes_or_all
(
nodes
)
edges
=
self
.
_nodes_or_all
(
nodes
)
assert
self
.
readout_func
is
not
None
,
assert
self
.
readout_func
is
not
None
,
\
"Readout function is not registered."
# TODO(minjie): tensorize following loop.
nstates
=
[
self
.
nodes
[
n
]
for
n
in
nodes
]
...
...
@@ -190,7 +191,7 @@ class DGLGraph(DiGraph):
# TODO(minjie): tensorize the loop.
for
uu
,
vv
in
utils
.
edge_iter
(
u
,
v
):
f_msg
=
self
.
edges
[
uu
,
vv
].
get
(
__MFUNC__
,
self
.
m_func
)
assert
f_msg
is
not
None
,
assert
f_msg
is
not
None
,
\
"message function not registered for edge (%s->%s)"
%
(
uu
,
vv
)
m
=
f_msg
(
self
.
nodes
[
uu
],
self
.
nodes
[
vv
],
self
.
edges
[
uu
,
vv
])
self
.
edges
[
uu
,
vv
][
__MSG__
]
=
m
...
...
@@ -224,9 +225,9 @@ class DGLGraph(DiGraph):
# TODO(minjie): tensorize the message batching
m
=
[
self
.
edges
[
vv
,
uu
][
__MSG__
]
for
vv
in
v
]
f_update
=
self
.
nodes
[
uu
].
get
(
__UFUNC__
,
self
.
u_func
)
assert
f_update
is
not
None
,
assert
f_update
is
not
None
,
\
"Update function not registered for node %s"
%
uu
self
.
node
s
[
uu
]
=
f_update
(
self
.
nodes
[
uu
],
m
)
self
.
node
[
uu
]
.
update
(
f_update
(
self
.
nodes
[
uu
],
m
)
)
def
update_by_edge
(
self
,
u
,
v
):
"""Trigger the message function on u->v and update v.
...
...
@@ -283,9 +284,9 @@ class DGLGraph(DiGraph):
u
=
[
uu
for
uu
,
_
in
self
.
edges
]
v
=
[
vv
for
_
,
vv
in
self
.
edges
]
self
.
sendto
(
u
,
v
)
self
.
recvfrom
(
v
)
self
.
recvfrom
(
list
(
self
.
nodes
())
)
def
propagate
(
self
,
iterator
=
'bfs'
):
def
propagate
(
self
,
iterator
=
'bfs'
,
**
kwargs
):
"""Propagate messages and update nodes using iterator.
A convenient function for passing messages and updating
...
...
@@ -299,6 +300,8 @@ class DGLGraph(DiGraph):
----------
iterator : str or generator of steps.
The iterator of the graph.
kwargs : keyword arguments, optional
Arguments for pre-defined iterators.
"""
if
isinstance
(
iterator
,
str
):
# TODO Call pre-defined routine to unroll the computation.
...
...
tests/test_basics.py
0 → 100644
View file @
f310e586
from
dgl.graph
import
DGLGraph
def
message_func
(
src
,
dst
,
edge
):
return
src
[
'h'
]
def
update_func
(
node
,
msgs
):
m
=
sum
(
msgs
)
return
{
'h'
:
node
[
'h'
]
+
m
}
def
generate_graph
():
g
=
DGLGraph
()
for
i
in
range
(
10
):
g
.
add_node
(
i
,
h
=
i
+
1
)
# 10 nodes.
# create a graph where 0 is the source and 9 is the sink
for
i
in
range
(
1
,
9
):
g
.
add_edge
(
0
,
i
)
g
.
add_edge
(
i
,
9
)
# add a back flow from 9 to 0
g
.
add_edge
(
9
,
0
)
return
g
def
check
(
g
,
h
):
nh
=
[
str
(
g
.
nodes
[
i
][
'h'
])
for
i
in
range
(
10
)]
h
=
[
str
(
x
)
for
x
in
h
]
assert
nh
==
h
,
"nh=[%s], h=[%s]"
%
(
' '
.
join
(
nh
),
' '
.
join
(
h
))
def
test_sendrecv
():
g
=
generate_graph
()
check
(
g
,
[
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
])
g
.
register_message_func
(
message_func
)
g
.
register_update_func
(
update_func
)
g
.
sendto
(
0
,
1
)
g
.
recvfrom
(
1
,
[
0
])
check
(
g
,
[
1
,
3
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
])
g
.
sendto
(
5
,
9
)
g
.
sendto
(
6
,
9
)
g
.
recvfrom
(
9
,
[
5
,
6
])
check
(
g
,
[
1
,
3
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
23
])
def
test_multi_sendrecv
():
g
=
generate_graph
()
check
(
g
,
[
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
])
g
.
register_message_func
(
message_func
)
g
.
register_update_func
(
update_func
)
# one-many
g
.
sendto
(
0
,
[
1
,
2
,
3
])
g
.
recvfrom
([
1
,
2
,
3
],
[[
0
],
[
0
],
[
0
]])
check
(
g
,
[
1
,
3
,
4
,
5
,
5
,
6
,
7
,
8
,
9
,
10
])
# many-one
g
.
sendto
([
6
,
7
,
8
],
9
)
g
.
recvfrom
(
9
,
[
6
,
7
,
8
])
check
(
g
,
[
1
,
3
,
4
,
5
,
5
,
6
,
7
,
8
,
9
,
34
])
# many-many
g
.
sendto
([
0
,
0
,
4
,
5
],
[
4
,
5
,
9
,
9
])
g
.
recvfrom
([
4
,
5
,
9
],
[[
0
],
[
0
],
[
4
,
5
]])
check
(
g
,
[
1
,
3
,
4
,
5
,
6
,
7
,
7
,
8
,
9
,
45
])
def
test_update_routines
():
g
=
generate_graph
()
check
(
g
,
[
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
])
g
.
register_message_func
(
message_func
)
g
.
register_update_func
(
update_func
)
g
.
update_by_edge
(
0
,
1
)
check
(
g
,
[
1
,
3
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
])
g
.
update_to
(
9
)
check
(
g
,
[
1
,
3
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
55
])
g
.
update_from
(
0
)
check
(
g
,
[
1
,
4
,
4
,
5
,
6
,
7
,
8
,
9
,
10
,
55
])
g
.
update_all
()
check
(
g
,
[
56
,
5
,
5
,
6
,
7
,
8
,
9
,
10
,
11
,
108
])
if
__name__
==
'__main__'
:
test_sendrecv
()
test_multi_sendrecv
()
test_update_routines
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment