Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
174c1d55
Commit
174c1d55
authored
Jun 19, 2018
by
zzhang-cn
Browse files
add edge repr func; modify mgcn.py
parent
eb507e4f
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
74 additions
and
6 deletions
+74
-6
examples/pytorch/mgcn.py
examples/pytorch/mgcn.py
+4
-2
python/dgl/graph.py
python/dgl/graph.py
+70
-4
No files found.
examples/pytorch/mgcn.py
View file @
174c1d55
...
@@ -53,6 +53,7 @@ class EdgeUpdateModule(nn.Module):
...
@@ -53,6 +53,7 @@ class EdgeUpdateModule(nn.Module):
new_he
=
self
.
net1
(
src
[
'hv'
])
+
self
.
net2
(
dst
[
'hv'
])
+
self
.
net3
(
edge
[
'he'
])
new_he
=
self
.
net1
(
src
[
'hv'
])
+
self
.
net2
(
dst
[
'hv'
])
+
self
.
net3
(
edge
[
'he'
])
return
{
'he'
:
new_he
}
return
{
'he'
:
new_he
}
# TODO: we don't need this one anymore
class
EdgeModule
(
nn
.
Module
):
class
EdgeModule
(
nn
.
Module
):
def
__init__
(
self
,
he_dims
):
def
__init__
(
self
,
he_dims
):
# use a flag to trigger either message module or edge update module.
# use a flag to trigger either message module or edge update module.
...
@@ -70,7 +71,8 @@ class EdgeModule(nn.Module):
...
@@ -70,7 +71,8 @@ class EdgeModule(nn.Module):
def
train
(
g
):
def
train
(
g
):
# TODO(minjie): finish the complete training algorithm.
# TODO(minjie): finish the complete training algorithm.
g
=
dgl
.
DGLGraph
(
g
)
g
=
dgl
.
DGLGraph
(
g
)
g
.
register_message_func
(
EdgeModule
())
g
.
register_message_func
(
MessageModule
())
g
.
register_edge_func
(
EdgeUpdateModule
())
g
.
register_update_func
(
NodeUpdateModule
())
g
.
register_update_func
(
NodeUpdateModule
())
# TODO(minjie): init hv and he
# TODO(minjie): init hv and he
num_iter
=
10
num_iter
=
10
...
@@ -78,4 +80,4 @@ def train(g):
...
@@ -78,4 +80,4 @@ def train(g):
# The first call triggers message function and update all the nodes.
# The first call triggers message function and update all the nodes.
g
.
update_all
()
g
.
update_all
()
# The second sendall updates all the edge features.
# The second sendall updates all the edge features.
g
.
send_all
()
#
g.send_all()
python/dgl/graph.py
View file @
174c1d55
...
@@ -10,8 +10,10 @@ from dgl.backend import Tensor
...
@@ -10,8 +10,10 @@ from dgl.backend import Tensor
import
dgl.utils
as
utils
import
dgl.utils
as
utils
__MSG__
=
"__msg__"
__MSG__
=
"__msg__"
__REPR__
=
"__repr__"
__E_REPR__
=
"__e_repr__"
__N_REPR__
=
"__n_repr__"
__MFUNC__
=
"__mfunc__"
__MFUNC__
=
"__mfunc__"
__EFUNC__
=
"__efunc__"
__UFUNC__
=
"__ufunc__"
__UFUNC__
=
"__ufunc__"
class
DGLGraph
(
DiGraph
):
class
DGLGraph
(
DiGraph
):
...
@@ -30,6 +32,7 @@ class DGLGraph(DiGraph):
...
@@ -30,6 +32,7 @@ class DGLGraph(DiGraph):
super
(
DGLGraph
,
self
).
__init__
(
graph_data
,
**
attr
)
super
(
DGLGraph
,
self
).
__init__
(
graph_data
,
**
attr
)
self
.
m_func
=
None
self
.
m_func
=
None
self
.
u_func
=
None
self
.
u_func
=
None
self
.
e_func
=
None
self
.
readout_func
=
None
self
.
readout_func
=
None
def
init_reprs
(
self
,
h_init
=
None
):
def
init_reprs
(
self
,
h_init
=
None
):
...
@@ -38,14 +41,14 @@ class DGLGraph(DiGraph):
...
@@ -38,14 +41,14 @@ class DGLGraph(DiGraph):
for
n
in
self
.
nodes
:
for
n
in
self
.
nodes
:
self
.
set_repr
(
n
,
h_init
)
self
.
set_repr
(
n
,
h_init
)
def
set_repr
(
self
,
u
,
h_u
,
name
=
__REPR__
):
def
set_repr
(
self
,
u
,
h_u
,
name
=
__
N_
REPR__
):
print
(
"[DEPRECATED]: please directly set node attrs "
print
(
"[DEPRECATED]: please directly set node attrs "
"(e.g. g.nodes[node]['x'] = val)."
)
"(e.g. g.nodes[node]['x'] = val)."
)
assert
u
in
self
.
nodes
assert
u
in
self
.
nodes
kwarg
=
{
name
:
h_u
}
kwarg
=
{
name
:
h_u
}
self
.
add_node
(
u
,
**
kwarg
)
self
.
add_node
(
u
,
**
kwarg
)
def
get_repr
(
self
,
u
,
name
=
__REPR__
):
def
get_repr
(
self
,
u
,
name
=
__
N_
REPR__
):
print
(
"[DEPRECATED]: please directly get node attrs "
print
(
"[DEPRECATED]: please directly get node attrs "
"(e.g. g.nodes[node]['x'])."
)
"(e.g. g.nodes[node]['x'])."
)
assert
u
in
self
.
nodes
assert
u
in
self
.
nodes
...
@@ -58,7 +61,7 @@ class DGLGraph(DiGraph):
...
@@ -58,7 +61,7 @@ class DGLGraph(DiGraph):
(node_reprs, node_reprs, edge_reprs) -> edge_reprs
(node_reprs, node_reprs, edge_reprs) -> edge_reprs
It computes the
new edge
representation
s (the same concept
a
s
message
s)
It computes the representation
of
a message
using the representations of the source node, target node and the edge
using the representations of the source node, target node and the edge
itself. All node_reprs and edge_reprs are dictionaries.
itself. All node_reprs and edge_reprs are dictionaries.
...
@@ -93,6 +96,48 @@ class DGLGraph(DiGraph):
...
@@ -93,6 +96,48 @@ class DGLGraph(DiGraph):
for
e
in
edges
:
for
e
in
edges
:
self
.
edges
[
e
][
__MFUNC__
]
=
message_func
self
.
edges
[
e
][
__MFUNC__
]
=
message_func
def
register_edge_func
(
self
,
edge_func
,
edges
=
'all'
,
batchable
=
False
):
"""Register computation on edges.
The edge function should be compatible with following signature:
(node_reprs, node_reprs, edge_reprs) -> edge_reprs
It computes the new edge representations (the same concept as messages)
using the representations of the source node, target node and the edge
itself. All node_reprs and edge_reprs are dictionaries.
Parameters
----------
edge_func : callable
Message function on the edge.
edges : str, pair of nodes, pair of containers, pair of tensors
The edges for which the message function is registered. Default is
registering for all the edges. Registering for multiple edges is
supported.
batchable : bool
Whether the provided message function allows batch computing.
Examples
--------
Register for all edges.
>>> g.register_edge_func(efunc)
Register for a specific edge.
>>> g.register_edge_func(efunc, (u, v))
Register for multiple edges.
>>> u = [u1, u2, u3, ...]
>>> v = [v1, v2, v3, ...]
>>> g.register_edge_func(mfunc, (u, v))
"""
if
edges
==
'all'
:
self
.
e_func
=
edge_func
else
:
for
e
in
edges
:
self
.
edges
[
e
][
__EFUNC__
]
=
edge_func
def
register_update_func
(
self
,
update_func
,
nodes
=
'all'
,
batchable
=
False
):
def
register_update_func
(
self
,
update_func
,
nodes
=
'all'
,
batchable
=
False
):
"""Register computation on nodes.
"""Register computation on nodes.
...
@@ -198,6 +243,24 @@ class DGLGraph(DiGraph):
...
@@ -198,6 +243,24 @@ class DGLGraph(DiGraph):
m
=
f_msg
(
self
.
nodes
[
uu
],
self
.
nodes
[
vv
],
self
.
edges
[
uu
,
vv
])
m
=
f_msg
(
self
.
nodes
[
uu
],
self
.
nodes
[
vv
],
self
.
edges
[
uu
,
vv
])
self
.
edges
[
uu
,
vv
][
__MSG__
]
=
m
self
.
edges
[
uu
,
vv
][
__MSG__
]
=
m
def
update_edge
(
self
,
u
,
v
):
"""Update representation on edge u->v
Parameters
----------
u : node, container or tensor
The source node(s).
v : node, container or tensor
The destination node(s).
"""
# TODO(minjie): tensorize the loop.
for
uu
,
vv
in
utils
.
edge_iter
(
u
,
v
):
f_edge
=
self
.
edges
[
uu
,
vv
].
get
(
__EFUNC__
,
self
.
m_func
)
assert
f_edge
is
not
None
,
\
"edge function not registered for edge (%s->%s)"
%
(
uu
,
vv
)
m
=
f_edge
(
self
.
nodes
[
uu
],
self
.
nodes
[
vv
],
self
.
edges
[
uu
,
vv
])
self
.
edges
[
uu
,
vv
][
__E_REPR__
]
=
m
def
recvfrom
(
self
,
u
,
preds
=
None
):
def
recvfrom
(
self
,
u
,
preds
=
None
):
"""Trigger the update function on node u.
"""Trigger the update function on node u.
...
@@ -292,6 +355,9 @@ class DGLGraph(DiGraph):
...
@@ -292,6 +355,9 @@ class DGLGraph(DiGraph):
v
=
[
vv
for
_
,
vv
in
self
.
edges
]
v
=
[
vv
for
_
,
vv
in
self
.
edges
]
self
.
sendto
(
u
,
v
)
self
.
sendto
(
u
,
v
)
self
.
recvfrom
(
list
(
self
.
nodes
()))
self
.
recvfrom
(
list
(
self
.
nodes
()))
# TODO(zz): this is a hack
if
self
.
e_func
:
self
.
update_edge
(
u
,
v
)
def
propagate
(
self
,
iterator
=
'bfs'
,
**
kwargs
):
def
propagate
(
self
,
iterator
=
'bfs'
,
**
kwargs
):
"""Propagate messages and update nodes using iterator.
"""Propagate messages and update nodes using iterator.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment