Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
174c1d55
Commit
174c1d55
authored
Jun 19, 2018
by
zzhang-cn
Browse files
add edge repr func; modify mgcn.py
parent
eb507e4f
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
74 additions
and
6 deletions
+74
-6
examples/pytorch/mgcn.py
examples/pytorch/mgcn.py
+4
-2
python/dgl/graph.py
python/dgl/graph.py
+70
-4
No files found.
examples/pytorch/mgcn.py
View file @
174c1d55
...
...
@@ -53,6 +53,7 @@ class EdgeUpdateModule(nn.Module):
new_he
=
self
.
net1
(
src
[
'hv'
])
+
self
.
net2
(
dst
[
'hv'
])
+
self
.
net3
(
edge
[
'he'
])
return
{
'he'
:
new_he
}
# TODO: we don't need this one anymore
class
EdgeModule
(
nn
.
Module
):
def
__init__
(
self
,
he_dims
):
# use a flag to trigger either message module or edge update module.
...
...
@@ -70,7 +71,8 @@ class EdgeModule(nn.Module):
def
train
(
g
):
# TODO(minjie): finish the complete training algorithm.
g
=
dgl
.
DGLGraph
(
g
)
g
.
register_message_func
(
EdgeModule
())
g
.
register_message_func
(
MessageModule
())
g
.
register_edge_func
(
EdgeUpdateModule
())
g
.
register_update_func
(
NodeUpdateModule
())
# TODO(minjie): init hv and he
num_iter
=
10
...
...
@@ -78,4 +80,4 @@ def train(g):
# The first call triggers message function and update all the nodes.
g
.
update_all
()
# The second sendall updates all the edge features.
g
.
send_all
()
#
g.send_all()
python/dgl/graph.py
View file @
174c1d55
...
...
@@ -10,8 +10,10 @@ from dgl.backend import Tensor
import
dgl.utils
as
utils
__MSG__
=
"__msg__"
__REPR__
=
"__repr__"
__E_REPR__
=
"__e_repr__"
__N_REPR__
=
"__n_repr__"
__MFUNC__
=
"__mfunc__"
__EFUNC__
=
"__efunc__"
__UFUNC__
=
"__ufunc__"
class
DGLGraph
(
DiGraph
):
...
...
@@ -30,6 +32,7 @@ class DGLGraph(DiGraph):
super
(
DGLGraph
,
self
).
__init__
(
graph_data
,
**
attr
)
self
.
m_func
=
None
self
.
u_func
=
None
self
.
e_func
=
None
self
.
readout_func
=
None
def
init_reprs
(
self
,
h_init
=
None
):
...
...
@@ -38,14 +41,14 @@ class DGLGraph(DiGraph):
for
n
in
self
.
nodes
:
self
.
set_repr
(
n
,
h_init
)
def
set_repr
(
self
,
u
,
h_u
,
name
=
__REPR__
):
def
set_repr
(
self
,
u
,
h_u
,
name
=
__
N_
REPR__
):
print
(
"[DEPRECATED]: please directly set node attrs "
"(e.g. g.nodes[node]['x'] = val)."
)
assert
u
in
self
.
nodes
kwarg
=
{
name
:
h_u
}
self
.
add_node
(
u
,
**
kwarg
)
def
get_repr
(
self
,
u
,
name
=
__REPR__
):
def
get_repr
(
self
,
u
,
name
=
__
N_
REPR__
):
print
(
"[DEPRECATED]: please directly get node attrs "
"(e.g. g.nodes[node]['x'])."
)
assert
u
in
self
.
nodes
...
...
@@ -58,7 +61,7 @@ class DGLGraph(DiGraph):
(node_reprs, node_reprs, edge_reprs) -> edge_reprs
It computes the
new edge
representation
s (the same concept
a
s
message
s)
It computes the representation
of
a message
using the representations of the source node, target node and the edge
itself. All node_reprs and edge_reprs are dictionaries.
...
...
@@ -93,6 +96,48 @@ class DGLGraph(DiGraph):
for
e
in
edges
:
self
.
edges
[
e
][
__MFUNC__
]
=
message_func
def
register_edge_func
(
self
,
edge_func
,
edges
=
'all'
,
batchable
=
False
):
"""Register computation on edges.
The edge function should be compatible with following signature:
(node_reprs, node_reprs, edge_reprs) -> edge_reprs
It computes the new edge representations (the same concept as messages)
using the representations of the source node, target node and the edge
itself. All node_reprs and edge_reprs are dictionaries.
Parameters
----------
edge_func : callable
Message function on the edge.
edges : str, pair of nodes, pair of containers, pair of tensors
The edges for which the message function is registered. Default is
registering for all the edges. Registering for multiple edges is
supported.
batchable : bool
Whether the provided message function allows batch computing.
Examples
--------
Register for all edges.
>>> g.register_edge_func(efunc)
Register for a specific edge.
>>> g.register_edge_func(efunc, (u, v))
Register for multiple edges.
>>> u = [u1, u2, u3, ...]
>>> v = [v1, v2, v3, ...]
>>> g.register_edge_func(mfunc, (u, v))
"""
if
edges
==
'all'
:
self
.
e_func
=
edge_func
else
:
for
e
in
edges
:
self
.
edges
[
e
][
__EFUNC__
]
=
edge_func
def
register_update_func
(
self
,
update_func
,
nodes
=
'all'
,
batchable
=
False
):
"""Register computation on nodes.
...
...
@@ -198,6 +243,24 @@ class DGLGraph(DiGraph):
m
=
f_msg
(
self
.
nodes
[
uu
],
self
.
nodes
[
vv
],
self
.
edges
[
uu
,
vv
])
self
.
edges
[
uu
,
vv
][
__MSG__
]
=
m
def
update_edge
(
self
,
u
,
v
):
"""Update representation on edge u->v
Parameters
----------
u : node, container or tensor
The source node(s).
v : node, container or tensor
The destination node(s).
"""
# TODO(minjie): tensorize the loop.
for
uu
,
vv
in
utils
.
edge_iter
(
u
,
v
):
f_edge
=
self
.
edges
[
uu
,
vv
].
get
(
__EFUNC__
,
self
.
m_func
)
assert
f_edge
is
not
None
,
\
"edge function not registered for edge (%s->%s)"
%
(
uu
,
vv
)
m
=
f_edge
(
self
.
nodes
[
uu
],
self
.
nodes
[
vv
],
self
.
edges
[
uu
,
vv
])
self
.
edges
[
uu
,
vv
][
__E_REPR__
]
=
m
def
recvfrom
(
self
,
u
,
preds
=
None
):
"""Trigger the update function on node u.
...
...
@@ -292,6 +355,9 @@ class DGLGraph(DiGraph):
v
=
[
vv
for
_
,
vv
in
self
.
edges
]
self
.
sendto
(
u
,
v
)
self
.
recvfrom
(
list
(
self
.
nodes
()))
# TODO(zz): this is a hack
if
self
.
e_func
:
self
.
update_edge
(
u
,
v
)
def
propagate
(
self
,
iterator
=
'bfs'
,
**
kwargs
):
"""Propagate messages and update nodes using iterator.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment