OpenDAS / dgl / Commits / b355d1ed

Unverified commit b355d1ed, authored Nov 04, 2018 by Minjie Wang and committed via GitHub on Nov 04, 2018.
[API] Apply nodes & apply edges (#117)
parent 1a2b306f
Showing 3 changed files, with 90 additions and 90 deletions:

    examples/pytorch/gcn/gcn.py       +1   -1
    python/dgl/graph.py               +84  -84
    tests/pytorch/test_basics.py      +5   -5
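For orientation, here is a minimal sketch of the registration and apply calls as they look after this commit, based on the signatures changed in python/dgl/graph.py below. The graph construction (dgl.DGLGraph, add_nodes, add_edges) and the direct ndata/edata assignments are assumptions taken from the surrounding library, not from this diff.

import torch as th
import dgl

# Assumed setup: a tiny graph with node feature 'h' and edge feature 'w'.
g = dgl.DGLGraph()
g.add_nodes(3)
g.add_edges(th.tensor([0, 1]), th.tensor([1, 2]))
g.ndata['h'] = th.ones(3, 4)
g.edata['w'] = th.ones(2, 4)

# After this commit the UDF comes first and the node/edge selection second.
g.apply_nodes(lambda nodes: {'h': nodes.data['h'] * 2})                                  # all nodes
g.apply_edges(lambda edges: {'w': edges.data['w'] * 0.}, (th.tensor([0]), th.tensor([1])))  # edges given as (u, v)

# A global default can be registered and is picked up when func == "default",
# mirroring the pattern used in tests/pytorch/test_basics.py below.
g.register_apply_edge_func(lambda edges: {'w': edges.data['w'] + 1})
g.apply_edges()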
examples/pytorch/gcn/gcn.py

...
@@ -68,7 +68,7 @@ class GCN(nn.Module):
                 self.g.apply_nodes(apply_node_func=
                         lambda nodes: {'h': self.dropout(nodes.data['h'])})
             self.g.update_all(gcn_msg, gcn_reduce, layer)
-        return self.g.pop_n_repr('h')
+        return self.g.ndata.pop('h')
 
 def main(args):
     # load and preprocess dataset
...
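The single change in this example swaps the old pop_n_repr('h') helper for the dict-like ndata view. A minimal sketch of the new access pattern; the graph setup and the direct ndata assignment are assumptions, only ndata.pop itself comes from this diff.

import torch as th
import dgl

g = dgl.DGLGraph()
g.add_nodes(4)
g.ndata['h'] = th.randn(4, 16)

h = g.ndata.pop('h')   # replaces g.pop_n_repr('h'): removes the 'h' field and returns it
print(h.shape)         # torch.Size([4, 16])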
python/dgl/graph.py

...
@@ -58,8 +58,8 @@ class DGLGraph(object):
         # registered functions
         self._message_func = None
         self._reduce_func = None
-        self._edge_func = None
         self._apply_node_func = None
+        self._apply_edge_func = None
 
     def add_nodes(self, num, reprs=None):
         """Add nodes.
...
@@ -815,110 +815,79 @@ class DGLGraph(object):
         """
         return self._edge_frame.pop(key)
 
-    def register_edge_func(self, edge_func):
-        """Register global edge update function.
-
-        Parameters
-        ----------
-        edge_func : callable
-          Message function on the edge.
-        """
-        self._edge_func = edge_func
-
-    def register_message_func(self, message_func):
+    def register_message_func(self, func):
         """Register global message function.
 
         Parameters
         ----------
-        message_func : callable
+        func : callable
           Message function on the edge.
         """
-        self._message_func = message_func
+        self._message_func = func
 
-    def register_reduce_func(self, reduce_func):
+    def register_reduce_func(self, func):
         """Register global message reduce function.
 
         Parameters
         ----------
-        reduce_func : str or callable
+        func : str or callable
           Reduce function on incoming edges.
         """
-        self._reduce_func = reduce_func
+        self._reduce_func = func
 
-    def register_apply_node_func(self, apply_node_func):
+    def register_apply_node_func(self, func):
         """Register global node apply function.
 
         Parameters
         ----------
-        apply_node_func : callable
+        func : callable
           Apply function on the node.
         """
-        self._apply_node_func = apply_node_func
+        self._apply_node_func = func
 
-    def apply_nodes(self, v=ALL, apply_node_func="default"):
-        """Apply the function on node representations.
-
-        Applying a None function will be ignored.
+    def register_apply_edge_func(self, func):
+        """Register global edge apply function.
 
         Parameters
         ----------
-        v : int, iterable of int, tensor, optional
-          The node id(s).
-        apply_node_func : callable
-          The apply node function.
+        edge_func : callable
+          Apply function on the edge.
         """
-        self._apply_nodes(v, apply_node_func)
+        self._apply_edge_func = func
 
-    def _apply_nodes(self, v, apply_node_func="default", reduce_accum=None):
-        """Internal apply nodes
+    def apply_nodes(self, func="default", v=ALL):
+        """Apply the function on the node features.
+
+        Applying a None function will be ignored.
 
         Parameters
         ----------
-        reduce_accum: dict-like
-          The output of reduce func
+        func : callable, optional
+          The UDF applied on the node features.
+        v : int, iterable of int, tensor, optional
+          The node id(s).
         """
-        if apply_node_func == "default":
-            apply_node_func = self._apply_node_func
-        if not apply_node_func:
-            # Skip none function call.
-            if reduce_accum is not None:
-                # write reduce result back
-                self.set_n_repr(reduce_accum, v)
-            return
-        # take out current node repr
-        curr_repr = self.get_n_repr(v)
-        if reduce_accum is not None:
-            # merge current node_repr with reduce output
-            curr_repr = utils.HybridDict(reduce_accum, curr_repr)
-        nb = NodeBatch(self, v, curr_repr)
-        new_repr = apply_node_func(nb)
-        if reduce_accum is not None:
-            # merge new node_repr with reduce output
-            reduce_accum.update(new_repr)
-            new_repr = reduce_accum
-        self.set_n_repr(new_repr, v)
+        self._internal_apply_nodes(v, func)
 
-    def send(self, edges=ALL, message_func="default"):
-        """Send messages along the given edges.
+    def apply_edges(self, func="default", edges=ALL):
+        """Apply the function on the edge features.
 
         Parameters
        ----------
+        func : callable, optional
+          The UDF applied on the edge features.
         edges : edges, optional
           Edges can be a pair of endpoint nodes (u, v), or a
           tensor of edge ids. The default value is all the edges.
-        message_func : callable
-          The message function.
 
         Notes
         -----
-        On multigraphs, if u and v are specified, then the messages will be sent
-        along all edges between u and v.
+        On multigraphs, if u and v are specified, then all the edges
+        between u and v will be updated.
         """
-        if message_func == "default":
-            message_func = self._message_func
-        assert message_func is not None
-        if isinstance(message_func, (tuple, list)):
-            message_func = BundledMessageFunction(message_func)
+        if func == "default":
+            func = self._apply_edge_func
+        assert func is not None
         if is_all(edges):
             eid = ALL
...
@@ -938,29 +907,29 @@ class DGLGraph(object):
         dst_data = self.get_n_repr(v)
         eb = EdgeBatch(self, (u, v, eid),
                 src_data, edge_data, dst_data)
-        msgs = message_func(eb)
-        self._msg_graph.add_edges(u, v)
-        self._msg_frame.append(msgs)
+        self.set_e_repr(func(eb), eid)
 
-    def update_edges(self, edges=ALL, edge_func="default"):
-        """Update features on the given edges.
+    def send(self, edges, message_func="default"):
+        """Send messages along the given edges.
 
         Parameters
         ----------
         edges : edges, optional
           Edges can be a pair of endpoint nodes (u, v), or a
-          tensor of edge ids. The default value is all the edges.
-        edge_func : callable
-          The update function.
+          tensor of edge ids.
+        message_func : callable
+          The message function.
 
         Notes
         -----
-        On multigraphs, if u and v are specified, then all the edges
-        between u and v will be updated.
+        On multigraphs, if u and v are specified, then the messages will be sent
+        along all edges between u and v.
         """
-        if edge_func == "default":
-            edge_func = self._edge_func
-        assert edge_func is not None
+        if message_func == "default":
+            message_func = self._message_func
+        assert message_func is not None
+        if isinstance(message_func, (tuple, list)):
+            message_func = BundledMessageFunction(message_func)
         if is_all(edges):
             eid = ALL
...
@@ -980,7 +949,9 @@ class DGLGraph(object):
         dst_data = self.get_n_repr(v)
         eb = EdgeBatch(self, (u, v, eid),
                 src_data, edge_data, dst_data)
-        self.set_e_repr(edge_func(eb), eid)
+        msgs = message_func(eb)
+        self._msg_graph.add_edges(u, v)
+        self._msg_frame.append(msgs)
 
     def recv(self,
              u,
...
@@ -1008,7 +979,7 @@ class DGLGraph(object):
             reduce_func = BundledReduceFunction(reduce_func)
         self._batch_recv(u, reduce_func)
         # optional apply nodes
-        self.apply_nodes(u, apply_node_func)
+        self.apply_nodes(apply_node_func, u)
 
     def _batch_recv(self, v, reduce_func):
         if self._msg_frame.num_rows == 0:
...
@@ -1153,7 +1124,7 @@ class DGLGraph(object):
         accum = executor.run()
         unique_v = executor.recv_nodes
-        self._apply_nodes(unique_v, apply_node_func, reduce_accum=accum)
+        self._internal_apply_nodes(unique_v, apply_node_func, reduce_accum=accum)
 
     def pull(self,
              v,
...
@@ -1179,7 +1150,7 @@ class DGLGraph(object):
         uu, vv, _ = self._graph.in_edges(v)
         self.send_and_recv((uu, vv), message_func, reduce_func, apply_node_func=None)
         unique_v = F.unique(v.tousertensor())
-        self.apply_nodes(unique_v, apply_node_func)
+        self.apply_nodes(apply_node_func, unique_v)
 
     def push(self,
              u,
...
@@ -1232,7 +1203,7 @@ class DGLGraph(object):
                 "update_all", self, message_func=message_func, reduce_func=reduce_func)
         if executor:
             new_reprs = executor.run()
-            self._apply_nodes(ALL, apply_node_func, reduce_accum=new_reprs)
+            self._internal_apply_nodes(ALL, apply_node_func, reduce_accum=new_reprs)
         else:
             self.send(ALL, message_func)
             self.recv(ALL, reduce_func, apply_node_func)
...
@@ -1474,3 +1445,32 @@ class DGLGraph(object):
         else:
             edges = F.Tensor(edges)
         return edges[e_mask]
+
+    def _internal_apply_nodes(self, v, apply_node_func="default", reduce_accum=None):
+        """Internal apply nodes
+
+        Parameters
+        ----------
+        reduce_accum: dict-like
+          The output of reduce func
+        """
+        if apply_node_func == "default":
+            apply_node_func = self._apply_node_func
+        if not apply_node_func:
+            # Skip none function call.
+            if reduce_accum is not None:
+                # write reduce result back
+                self.set_n_repr(reduce_accum, v)
+            return
+        # take out current node repr
+        curr_repr = self.get_n_repr(v)
+        if reduce_accum is not None:
+            # merge current node_repr with reduce output
+            curr_repr = utils.HybridDict(reduce_accum, curr_repr)
+        nb = NodeBatch(self, v, curr_repr)
+        new_repr = apply_node_func(nb)
+        if reduce_accum is not None:
+            # merge new node_repr with reduce output
+            reduce_accum.update(new_repr)
+            new_repr = reduce_accum
+        self.set_n_repr(new_repr, v)
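Note that the public apply_nodes flips its argument order in this change set: the UDF now comes first and the node selection second, which is why the internal call sites in recv, pull, and update_all are reordered above. A minimal sketch of the new calling convention; the graph setup and ndata assignment are again assumptions for illustration.

import torch as th
import dgl

g = dgl.DGLGraph()
g.add_nodes(4)
g.ndata['h'] = th.ones(4, 2)

# Old call sites read g.apply_nodes(v, apply_node_func); the new order is (func, v).
g.apply_nodes(lambda nodes: {'h': nodes.data['h'] + 1}, v=[0, 2])

# With no arguments, the globally registered apply function is used on all nodes;
# if none is registered, the call is skipped (None functions are ignored).
g.register_apply_node_func(lambda nodes: {'h': nodes.data['h'] * 2})
g.apply_nodes()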
tests/pytorch/test_basics.py

...
@@ -168,17 +168,17 @@ def test_batch_recv():
     assert(reduce_msg_shapes == {(1, 3, D), (3, 1, D)})
     reduce_msg_shapes.clear()
 
-def test_update_edges():
+def test_apply_edges():
     def _upd(edges):
         return {'w': edges.data['w'] * 2}
     g = generate_graph()
-    g.register_edge_func(_upd)
+    g.register_apply_edge_func(_upd)
     old = g.edata['w']
-    g.update_edges()
+    g.apply_edges()
     assert th.allclose(old * 2, g.edata['w'])
     u = th.tensor([0, 0, 0, 4, 5, 6])
     v = th.tensor([1, 2, 3, 9, 9, 9])
-    g.update_edges((u, v), lambda edges: {'w': edges.data['w'] * 0.})
+    g.apply_edges(lambda edges: {'w': edges.data['w'] * 0.}, (u, v))
     eid = g.edge_ids(u, v)
     assert th.allclose(g.edata['w'][eid], th.zeros((6, D)))
...
@@ -392,7 +392,7 @@ if __name__ == '__main__':
     test_batch_setter_autograd()
     test_batch_send()
     test_batch_recv()
-    test_update_edges()
+    test_apply_edges()
     test_update_routines()
     test_reduce_0deg()
     test_pull_0deg()
...