Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
f310e586
Commit
f310e586
authored
Jun 14, 2018
by
Minjie Wang
Browse files
Add test/example; fix bug in graph.py
parent
15a2c22c
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
118 additions
and
6 deletions
+118
-6
examples/pagerank.py
examples/pagerank.py
+34
-0
python/dgl/graph.py
python/dgl/graph.py
+9
-6
tests/test_basics.py
tests/test_basics.py
+75
-0
No files found.
examples/pagerank.py
0 → 100644
View file @
f310e586
from
__future__
import
division
import
networkx
as
nx
from
dgl.graph
import
DGLGraph
DAMP
=
0.85
N
=
100
K
=
10
def
message_func
(
src
,
dst
,
edge
):
return
src
[
'pv'
]
/
src
[
'deg'
]
def
update_func
(
node
,
msgs
):
pv
=
(
1
-
DAMP
)
/
N
+
DAMP
*
sum
(
msgs
)
return
{
'pv'
:
pv
}
def
compute_pagerank
(
g
):
g
=
DGLGraph
(
g
)
print
(
g
.
number_of_edges
(),
g
.
number_of_nodes
())
g
.
register_message_func
(
message_func
)
g
.
register_update_func
(
update_func
)
# init pv value
for
n
in
g
.
nodes
():
g
.
node
[
n
][
'pv'
]
=
1
/
N
g
.
node
[
n
][
'deg'
]
=
g
.
out_degree
(
n
)
# pagerank
for
k
in
range
(
K
):
g
.
update_all
()
return
[
g
.
node
[
n
][
'pv'
]
for
n
in
g
.
nodes
()]
if
__name__
==
'__main__'
:
g
=
nx
.
erdos_renyi_graph
(
N
,
0.05
)
pv
=
compute_pagerank
(
g
)
print
(
pv
)
python/dgl/graph.py
View file @
f310e586
"""Base graph class specialized for neural networks on graphs.
"""Base graph class specialized for neural networks on graphs.
"""
"""
from
collections
import
defaultdict
import
networkx
as
nx
import
networkx
as
nx
from
networkx.classes.digraph
import
DiGraph
from
networkx.classes.digraph
import
DiGraph
...
@@ -170,7 +171,7 @@ class DGLGraph(DiGraph):
...
@@ -170,7 +171,7 @@ class DGLGraph(DiGraph):
"""
"""
nodes
=
self
.
_nodes_or_all
(
nodes
)
nodes
=
self
.
_nodes_or_all
(
nodes
)
edges
=
self
.
_nodes_or_all
(
nodes
)
edges
=
self
.
_nodes_or_all
(
nodes
)
assert
self
.
readout_func
is
not
None
,
assert
self
.
readout_func
is
not
None
,
\
"Readout function is not registered."
"Readout function is not registered."
# TODO(minjie): tensorize following loop.
# TODO(minjie): tensorize following loop.
nstates
=
[
self
.
nodes
[
n
]
for
n
in
nodes
]
nstates
=
[
self
.
nodes
[
n
]
for
n
in
nodes
]
...
@@ -190,7 +191,7 @@ class DGLGraph(DiGraph):
...
@@ -190,7 +191,7 @@ class DGLGraph(DiGraph):
# TODO(minjie): tensorize the loop.
# TODO(minjie): tensorize the loop.
for
uu
,
vv
in
utils
.
edge_iter
(
u
,
v
):
for
uu
,
vv
in
utils
.
edge_iter
(
u
,
v
):
f_msg
=
self
.
edges
[
uu
,
vv
].
get
(
__MFUNC__
,
self
.
m_func
)
f_msg
=
self
.
edges
[
uu
,
vv
].
get
(
__MFUNC__
,
self
.
m_func
)
assert
f_msg
is
not
None
,
assert
f_msg
is
not
None
,
\
"message function not registered for edge (%s->%s)"
%
(
uu
,
vv
)
"message function not registered for edge (%s->%s)"
%
(
uu
,
vv
)
m
=
f_msg
(
self
.
nodes
[
uu
],
self
.
nodes
[
vv
],
self
.
edges
[
uu
,
vv
])
m
=
f_msg
(
self
.
nodes
[
uu
],
self
.
nodes
[
vv
],
self
.
edges
[
uu
,
vv
])
self
.
edges
[
uu
,
vv
][
__MSG__
]
=
m
self
.
edges
[
uu
,
vv
][
__MSG__
]
=
m
...
@@ -224,9 +225,9 @@ class DGLGraph(DiGraph):
...
@@ -224,9 +225,9 @@ class DGLGraph(DiGraph):
# TODO(minjie): tensorize the message batching
# TODO(minjie): tensorize the message batching
m
=
[
self
.
edges
[
vv
,
uu
][
__MSG__
]
for
vv
in
v
]
m
=
[
self
.
edges
[
vv
,
uu
][
__MSG__
]
for
vv
in
v
]
f_update
=
self
.
nodes
[
uu
].
get
(
__UFUNC__
,
self
.
u_func
)
f_update
=
self
.
nodes
[
uu
].
get
(
__UFUNC__
,
self
.
u_func
)
assert
f_update
is
not
None
,
assert
f_update
is
not
None
,
\
"Update function not registered for node %s"
%
uu
"Update function not registered for node %s"
%
uu
self
.
node
s
[
uu
]
=
f_update
(
self
.
nodes
[
uu
],
m
)
self
.
node
[
uu
]
.
update
(
f_update
(
self
.
nodes
[
uu
],
m
)
)
def
update_by_edge
(
self
,
u
,
v
):
def
update_by_edge
(
self
,
u
,
v
):
"""Trigger the message function on u->v and update v.
"""Trigger the message function on u->v and update v.
...
@@ -283,9 +284,9 @@ class DGLGraph(DiGraph):
...
@@ -283,9 +284,9 @@ class DGLGraph(DiGraph):
u
=
[
uu
for
uu
,
_
in
self
.
edges
]
u
=
[
uu
for
uu
,
_
in
self
.
edges
]
v
=
[
vv
for
_
,
vv
in
self
.
edges
]
v
=
[
vv
for
_
,
vv
in
self
.
edges
]
self
.
sendto
(
u
,
v
)
self
.
sendto
(
u
,
v
)
self
.
recvfrom
(
v
)
self
.
recvfrom
(
list
(
self
.
nodes
())
)
def
propagate
(
self
,
iterator
=
'bfs'
):
def
propagate
(
self
,
iterator
=
'bfs'
,
**
kwargs
):
"""Propagate messages and update nodes using iterator.
"""Propagate messages and update nodes using iterator.
A convenient function for passing messages and updating
A convenient function for passing messages and updating
...
@@ -299,6 +300,8 @@ class DGLGraph(DiGraph):
...
@@ -299,6 +300,8 @@ class DGLGraph(DiGraph):
----------
----------
iterator : str or generator of steps.
iterator : str or generator of steps.
The iterator of the graph.
The iterator of the graph.
kwargs : keyword arguments, optional
Arguments for pre-defined iterators.
"""
"""
if
isinstance
(
iterator
,
str
):
if
isinstance
(
iterator
,
str
):
# TODO Call pre-defined routine to unroll the computation.
# TODO Call pre-defined routine to unroll the computation.
...
...
tests/test_basics.py
0 → 100644
View file @
f310e586
from
dgl.graph
import
DGLGraph
def
message_func
(
src
,
dst
,
edge
):
return
src
[
'h'
]
def
update_func
(
node
,
msgs
):
m
=
sum
(
msgs
)
return
{
'h'
:
node
[
'h'
]
+
m
}
def
generate_graph
():
g
=
DGLGraph
()
for
i
in
range
(
10
):
g
.
add_node
(
i
,
h
=
i
+
1
)
# 10 nodes.
# create a graph where 0 is the source and 9 is the sink
for
i
in
range
(
1
,
9
):
g
.
add_edge
(
0
,
i
)
g
.
add_edge
(
i
,
9
)
# add a back flow from 9 to 0
g
.
add_edge
(
9
,
0
)
return
g
def
check
(
g
,
h
):
nh
=
[
str
(
g
.
nodes
[
i
][
'h'
])
for
i
in
range
(
10
)]
h
=
[
str
(
x
)
for
x
in
h
]
assert
nh
==
h
,
"nh=[%s], h=[%s]"
%
(
' '
.
join
(
nh
),
' '
.
join
(
h
))
def
test_sendrecv
():
g
=
generate_graph
()
check
(
g
,
[
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
])
g
.
register_message_func
(
message_func
)
g
.
register_update_func
(
update_func
)
g
.
sendto
(
0
,
1
)
g
.
recvfrom
(
1
,
[
0
])
check
(
g
,
[
1
,
3
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
])
g
.
sendto
(
5
,
9
)
g
.
sendto
(
6
,
9
)
g
.
recvfrom
(
9
,
[
5
,
6
])
check
(
g
,
[
1
,
3
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
23
])
def
test_multi_sendrecv
():
g
=
generate_graph
()
check
(
g
,
[
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
])
g
.
register_message_func
(
message_func
)
g
.
register_update_func
(
update_func
)
# one-many
g
.
sendto
(
0
,
[
1
,
2
,
3
])
g
.
recvfrom
([
1
,
2
,
3
],
[[
0
],
[
0
],
[
0
]])
check
(
g
,
[
1
,
3
,
4
,
5
,
5
,
6
,
7
,
8
,
9
,
10
])
# many-one
g
.
sendto
([
6
,
7
,
8
],
9
)
g
.
recvfrom
(
9
,
[
6
,
7
,
8
])
check
(
g
,
[
1
,
3
,
4
,
5
,
5
,
6
,
7
,
8
,
9
,
34
])
# many-many
g
.
sendto
([
0
,
0
,
4
,
5
],
[
4
,
5
,
9
,
9
])
g
.
recvfrom
([
4
,
5
,
9
],
[[
0
],
[
0
],
[
4
,
5
]])
check
(
g
,
[
1
,
3
,
4
,
5
,
6
,
7
,
7
,
8
,
9
,
45
])
def
test_update_routines
():
g
=
generate_graph
()
check
(
g
,
[
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
])
g
.
register_message_func
(
message_func
)
g
.
register_update_func
(
update_func
)
g
.
update_by_edge
(
0
,
1
)
check
(
g
,
[
1
,
3
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
])
g
.
update_to
(
9
)
check
(
g
,
[
1
,
3
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
55
])
g
.
update_from
(
0
)
check
(
g
,
[
1
,
4
,
4
,
5
,
6
,
7
,
8
,
9
,
10
,
55
])
g
.
update_all
()
check
(
g
,
[
56
,
5
,
5
,
6
,
7
,
8
,
9
,
10
,
11
,
108
])
if
__name__
==
'__main__'
:
test_sendrecv
()
test_multi_sendrecv
()
test_update_routines
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment