Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
916d375b
Commit
916d375b
authored
Sep 19, 2018
by
Minjie Wang
Browse files
Merge branch 'master' into cpp
parents
a1038eb1
9b0a01db
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
453 additions
and
145 deletions
+453
-145
python/dgl/function/message.py
python/dgl/function/message.py
+56
-5
python/dgl/function/reducer.py
python/dgl/function/reducer.py
+35
-15
python/dgl/graph.py
python/dgl/graph.py
+7
-1
python/dgl/scheduler.py
python/dgl/scheduler.py
+250
-123
tests/pytorch/test_specialization.py
tests/pytorch/test_specialization.py
+105
-1
No files found.
python/dgl/function/message.py
View file @
916d375b
...
@@ -2,9 +2,11 @@
...
@@ -2,9 +2,11 @@
from
__future__
import
absolute_import
from
__future__
import
absolute_import
import
operator
import
operator
import
dgl.backend
as
F
__all__
=
[
"MessageFunction"
,
"src_mul_edge"
,
"copy_src"
,
"copy_edge"
]
__all__
=
[
"MessageFunction"
,
"src_mul_edge"
,
"copy_src"
,
"copy_edge"
]
class
MessageFunction
(
object
):
class
MessageFunction
(
object
):
def
__call__
(
self
,
src
,
edge
):
def
__call__
(
self
,
src
,
edge
):
raise
NotImplementedError
raise
NotImplementedError
...
@@ -12,10 +14,28 @@ class MessageFunction(object):
...
@@ -12,10 +14,28 @@ class MessageFunction(object):
def
name
(
self
):
def
name
(
self
):
raise
NotImplementedError
raise
NotImplementedError
def
is_spmv_supported
(
self
,
g
):
raise
NotImplementedError
class
BundledMessageFunction
(
MessageFunction
):
class
BundledMessageFunction
(
MessageFunction
):
def
__init__
(
self
,
fn_list
):
def
__init__
(
self
,
fn_list
):
if
not
isinstance
(
fn_list
,
(
list
,
tuple
)):
fn_list
=
[
fn_list
]
else
:
# sanity check on out field
for
fn
in
fn_list
:
# cannot perform check for udf
if
isinstance
(
fn
,
MessageFunction
)
and
fn
.
out_field
is
None
:
raise
RuntimeError
(
"Not specifying out field for multiple message is ambiguous"
)
self
.
fn_list
=
fn_list
self
.
fn_list
=
fn_list
def
is_spmv_supported
(
self
,
g
):
for
fn
in
self
.
fn_list
:
if
not
isinstance
(
fn
,
MessageFunction
)
or
not
fn
.
is_spmv_supported
(
g
):
return
False
return
True
def
__call__
(
self
,
src
,
edge
):
def
__call__
(
self
,
src
,
edge
):
ret
=
None
ret
=
None
for
fn
in
self
.
fn_list
:
for
fn
in
self
.
fn_list
:
...
@@ -24,16 +44,34 @@ class BundledMessageFunction(MessageFunction):
...
@@ -24,16 +44,34 @@ class BundledMessageFunction(MessageFunction):
ret
=
msg
ret
=
msg
else
:
else
:
try
:
try
:
# ret and msg must be dict
ret
.
update
(
msg
)
ret
.
update
(
msg
)
except
e
:
except
:
raise
RuntimeError
(
"Failed to merge results of two builtin"
raise
RuntimeError
(
"Must specify out field for multiple message"
)
" message functions. Please specify out_field"
" for the builtin message function."
)
return
ret
return
ret
def
name
(
self
):
def
name
(
self
):
return
"bundled"
return
"bundled"
def
_is_spmv_supported_node_feat
(
g
,
field
):
if
field
is
None
:
feat
=
g
.
get_n_repr
()
else
:
feat
=
g
.
get_n_repr
()[
field
]
shape
=
F
.
shape
(
feat
)
return
len
(
shape
)
==
1
or
len
(
shape
)
==
2
def
_is_spmv_supported_edge_feat
(
g
,
field
):
# check shape, only scalar edge feature can be optimized at the moment
if
field
is
None
:
feat
=
g
.
get_e_repr
()
else
:
feat
=
g
.
get_e_repr
()[
field
]
shape
=
F
.
shape
(
feat
)
return
len
(
shape
)
==
1
or
(
len
(
shape
)
==
2
and
shape
[
1
]
==
1
)
class
SrcMulEdgeMessageFunction
(
MessageFunction
):
class
SrcMulEdgeMessageFunction
(
MessageFunction
):
def
__init__
(
self
,
mul_op
,
src_field
=
None
,
edge_field
=
None
,
out_field
=
None
):
def
__init__
(
self
,
mul_op
,
src_field
=
None
,
edge_field
=
None
,
out_field
=
None
):
self
.
mul_op
=
mul_op
self
.
mul_op
=
mul_op
...
@@ -41,6 +79,10 @@ class SrcMulEdgeMessageFunction(MessageFunction):
...
@@ -41,6 +79,10 @@ class SrcMulEdgeMessageFunction(MessageFunction):
self
.
edge_field
=
edge_field
self
.
edge_field
=
edge_field
self
.
out_field
=
out_field
self
.
out_field
=
out_field
def
is_spmv_supported
(
self
,
g
):
return
_is_spmv_supported_node_feat
(
g
,
self
.
src_field
)
\
and
_is_spmv_supported_edge_feat
(
g
,
self
.
edge_field
)
def
__call__
(
self
,
src
,
edge
):
def
__call__
(
self
,
src
,
edge
):
if
self
.
src_field
is
not
None
:
if
self
.
src_field
is
not
None
:
src
=
src
[
self
.
src_field
]
src
=
src
[
self
.
src_field
]
...
@@ -60,6 +102,9 @@ class CopySrcMessageFunction(MessageFunction):
...
@@ -60,6 +102,9 @@ class CopySrcMessageFunction(MessageFunction):
self
.
src_field
=
src_field
self
.
src_field
=
src_field
self
.
out_field
=
out_field
self
.
out_field
=
out_field
def
is_spmv_supported
(
self
,
g
):
return
_is_spmv_supported_node_feat
(
g
,
self
.
src_field
)
def
__call__
(
self
,
src
,
edge
):
def
__call__
(
self
,
src
,
edge
):
if
self
.
src_field
is
not
None
:
if
self
.
src_field
is
not
None
:
ret
=
src
[
self
.
src_field
]
ret
=
src
[
self
.
src_field
]
...
@@ -78,6 +123,11 @@ class CopyEdgeMessageFunction(MessageFunction):
...
@@ -78,6 +123,11 @@ class CopyEdgeMessageFunction(MessageFunction):
self
.
edge_field
=
edge_field
self
.
edge_field
=
edge_field
self
.
out_field
=
out_field
self
.
out_field
=
out_field
def
is_spmv_supported
(
self
,
g
):
# TODO: support this with g-spmv
return
False
# return _is_spmv_supported_edge_feat(g, self.edge_field)
def
__call__
(
self
,
src
,
edge
):
def
__call__
(
self
,
src
,
edge
):
if
self
.
edge_field
is
not
None
:
if
self
.
edge_field
is
not
None
:
ret
=
edge
[
self
.
edge_field
]
ret
=
edge
[
self
.
edge_field
]
...
@@ -91,6 +141,7 @@ class CopyEdgeMessageFunction(MessageFunction):
...
@@ -91,6 +141,7 @@ class CopyEdgeMessageFunction(MessageFunction):
def
name
(
self
):
def
name
(
self
):
return
"copy_edge"
return
"copy_edge"
def
src_mul_edge
(
src
=
None
,
edge
=
None
,
out
=
None
):
def
src_mul_edge
(
src
=
None
,
edge
=
None
,
out
=
None
):
"""TODO(minjie): docstring """
"""TODO(minjie): docstring """
return
SrcMulEdgeMessageFunction
(
operator
.
mul
,
src
,
edge
,
out
)
return
SrcMulEdgeMessageFunction
(
operator
.
mul
,
src
,
edge
,
out
)
...
...
python/dgl/function/reducer.py
View file @
916d375b
...
@@ -12,10 +12,26 @@ class ReduceFunction(object):
...
@@ -12,10 +12,26 @@ class ReduceFunction(object):
def
name
(
self
):
def
name
(
self
):
raise
NotImplementedError
raise
NotImplementedError
def
is_spmv_supported
(
self
):
raise
NotImplementedError
class
BundledReduceFunction
(
ReduceFunction
):
class
BundledReduceFunction
(
ReduceFunction
):
def
__init__
(
self
,
fn_list
):
def
__init__
(
self
,
fn_list
):
if
not
isinstance
(
fn_list
,
(
list
,
tuple
)):
fn_list
=
[
fn_list
]
else
:
# sanity check on out field
for
fn
in
fn_list
:
if
isinstance
(
fn
,
ReduceFunction
)
and
fn
.
out_field
is
None
:
raise
RuntimeError
(
"Not specifying out field for multiple reduce is ambiguous"
)
self
.
fn_list
=
fn_list
self
.
fn_list
=
fn_list
def
is_spmv_supported
(
self
):
for
fn
in
self
.
fn_list
:
if
not
isinstance
(
fn
,
ReduceFunction
)
or
not
fn
.
is_spmv_supported
():
return
False
return
True
def
__call__
(
self
,
node
,
msgs
):
def
__call__
(
self
,
node
,
msgs
):
ret
=
None
ret
=
None
for
fn
in
self
.
fn_list
:
for
fn
in
self
.
fn_list
:
...
@@ -24,46 +40,50 @@ class BundledReduceFunction(ReduceFunction):
...
@@ -24,46 +40,50 @@ class BundledReduceFunction(ReduceFunction):
ret
=
rpr
ret
=
rpr
else
:
else
:
try
:
try
:
# ret and rpr must be dict
ret
.
update
(
rpr
)
ret
.
update
(
rpr
)
except
e
:
except
:
raise
RuntimeError
(
"Failed to merge results of two builtin"
raise
RuntimeError
(
"Must specify out field for multiple reudce"
)
" reduce functions. Please specify out_field"
" for the builtin reduce function."
)
return
ret
return
ret
def
name
(
self
):
def
name
(
self
):
return
"bundled"
return
"bundled"
class
SumReducerFunction
(
ReduceFunction
):
class
ReducerFunctionTemplate
(
ReduceFunction
):
def
__init__
(
self
,
batch_sum_op
,
nonbatch_sum_op
,
msg_field
=
None
,
out_field
=
None
):
def
__init__
(
self
,
name
,
batch_op
,
nonbatch_op
,
msg_field
=
None
,
out_field
=
None
):
self
.
batch_sum_op
=
batch_sum_op
self
.
name
=
name
self
.
nonbatch_sum_op
=
nonbatch_sum_op
self
.
batch_op
=
batch_op
self
.
nonbatch_op
=
nonbatch_op
self
.
msg_field
=
msg_field
self
.
msg_field
=
msg_field
self
.
out_field
=
out_field
self
.
out_field
=
out_field
def
is_spmv_supported
(
self
):
# TODO: support max
return
self
.
name
==
"sum"
def
__call__
(
self
,
node
,
msgs
):
def
__call__
(
self
,
node
,
msgs
):
if
isinstance
(
msgs
,
list
):
if
isinstance
(
msgs
,
list
):
if
self
.
msg_field
is
None
:
if
self
.
msg_field
is
None
:
ret
=
self
.
nonbatch_
sum_
op
(
msgs
)
ret
=
self
.
nonbatch_op
(
msgs
)
else
:
else
:
ret
=
self
.
nonbatch_
sum_
op
([
msg
[
self
.
msg_field
]
for
msg
in
msgs
])
ret
=
self
.
nonbatch_op
([
msg
[
self
.
msg_field
]
for
msg
in
msgs
])
else
:
else
:
if
self
.
msg_field
is
None
:
if
self
.
msg_field
is
None
:
ret
=
self
.
batch_
sum_
op
(
msgs
,
1
)
ret
=
self
.
batch_op
(
msgs
,
1
)
else
:
else
:
ret
=
self
.
batch_
sum_
op
(
msgs
[
self
.
msg_field
],
1
)
ret
=
self
.
batch_op
(
msgs
[
self
.
msg_field
],
1
)
if
self
.
out_field
is
None
:
if
self
.
out_field
is
None
:
return
ret
return
ret
else
:
else
:
return
{
self
.
out_field
:
ret
}
return
{
self
.
out_field
:
ret
}
def
name
(
self
):
def
name
(
self
):
return
"sum"
return
self
.
name
_python_sum
=
sum
_python_sum
=
sum
def
sum
(
msgs
=
None
,
out
=
None
):
def
sum
(
msgs
=
None
,
out
=
None
):
return
Sum
ReducerFunction
(
F
.
sum
,
_python_sum
,
msgs
,
out
)
return
ReducerFunction
Template
(
"sum"
,
F
.
sum
,
_python_sum
,
msgs
,
out
)
_python_max
=
max
_python_max
=
max
def
max
(
msgs
=
None
,
out
=
None
):
def
max
(
msgs
=
None
,
out
=
None
):
return
Sum
ReducerFunction
(
F
.
max
,
_python_max
,
msgs
,
out
)
return
ReducerFunction
Template
(
"max"
,
F
.
max
,
_python_max
,
msgs
,
out
)
python/dgl/graph.py
View file @
916d375b
...
@@ -12,6 +12,8 @@ from .graph_index import GraphIndex
...
@@ -12,6 +12,8 @@ from .graph_index import GraphIndex
from
.frame
import
FrameRef
,
merge_frames
from
.frame
import
FrameRef
,
merge_frames
from
.
import
scheduler
from
.
import
scheduler
from
.
import
utils
from
.
import
utils
from
.function.message
import
BundledMessageFunction
from
.function.reducer
import
BundledReduceFunction
class
DGLGraph
(
object
):
class
DGLGraph
(
object
):
"""Base graph class specialized for neural networks on graphs.
"""Base graph class specialized for neural networks on graphs.
...
@@ -431,6 +433,8 @@ class DGLGraph(object):
...
@@ -431,6 +433,8 @@ class DGLGraph(object):
if
message_func
==
"default"
:
if
message_func
==
"default"
:
message_func
,
batchable
=
self
.
_message_func
message_func
,
batchable
=
self
.
_message_func
assert
message_func
is
not
None
assert
message_func
is
not
None
if
isinstance
(
message_func
,
(
tuple
,
list
)):
message_func
=
BundledMessageFunction
(
message_func
)
if
batchable
:
if
batchable
:
self
.
_batch_send
(
u
,
v
,
message_func
)
self
.
_batch_send
(
u
,
v
,
message_func
)
else
:
else
:
...
@@ -470,7 +474,7 @@ class DGLGraph(object):
...
@@ -470,7 +474,7 @@ class DGLGraph(object):
else
:
else
:
self
.
_msg_frame
.
append
({
__MSG__
:
msgs
})
self
.
_msg_frame
.
append
({
__MSG__
:
msgs
})
def
update_edge
(
self
,
u
,
v
,
edge_func
=
"default"
,
batchable
=
False
):
def
update_edge
(
self
,
u
=
ALL
,
v
=
ALL
,
edge_func
=
"default"
,
batchable
=
False
):
"""Update representation on edge u->v
"""Update representation on edge u->v
The edge function should be compatible with following signature:
The edge function should be compatible with following signature:
...
@@ -573,6 +577,8 @@ class DGLGraph(object):
...
@@ -573,6 +577,8 @@ class DGLGraph(object):
if
reduce_func
==
"default"
:
if
reduce_func
==
"default"
:
reduce_func
,
batchable
=
self
.
_reduce_func
reduce_func
,
batchable
=
self
.
_reduce_func
assert
reduce_func
is
not
None
assert
reduce_func
is
not
None
if
isinstance
(
reduce_func
,
(
list
,
tuple
)):
reduce_func
=
BundledReduceFunction
(
reduce_func
)
if
batchable
:
if
batchable
:
self
.
_batch_recv
(
u
,
reduce_func
)
self
.
_batch_recv
(
u
,
reduce_func
)
else
:
else
:
...
...
python/dgl/scheduler.py
View file @
916d375b
...
@@ -3,6 +3,7 @@ from __future__ import absolute_import
...
@@ -3,6 +3,7 @@ from __future__ import absolute_import
import
numpy
as
np
import
numpy
as
np
from
.base
import
ALL
from
.
import
backend
as
F
from
.
import
backend
as
F
from
.function
import
message
as
fmsg
from
.function
import
message
as
fmsg
from
.function
import
reducer
as
fred
from
.function
import
reducer
as
fred
...
@@ -38,37 +39,32 @@ def degree_bucketing(cached_graph, v):
...
@@ -38,37 +39,32 @@ def degree_bucketing(cached_graph, v):
#print('degree-bucketing:', unique_degrees, [len(b) for b in v_bkt])
#print('degree-bucketing:', unique_degrees, [len(b) for b in v_bkt])
return
unique_degrees
,
v_bkt
return
unique_degrees
,
v_bkt
class
Executor
(
object
):
class
Executor
(
object
):
def
run
(
self
,
graph
):
def
run
(
self
):
raise
NotImplementedError
raise
NotImplementedError
class
UpdateAllSPMVExecu
tor
(
Executor
):
class
SPMVOpera
tor
(
Executor
):
def
__init__
(
self
,
graph
,
src_field
,
dst
_field
,
edge
_field
,
use_
adj
):
def
__init__
(
self
,
src_field
,
edge
_field
,
dst
_field
,
use_
edge_feat
,
self
.
graph
=
graph
node_repr
,
adj_build_fn
):
self
.
src_field
=
src_field
self
.
src_field
=
src_field
self
.
dst_field
=
dst_field
self
.
edge_field
=
edge_field
self
.
edge_field
=
edge_field
self
.
use_adj
=
use_adj
self
.
dst_field
=
dst_field
self
.
use_edge_feat
=
use_edge_feat
self
.
node_repr
=
node_repr
self
.
adj_build_fn
=
adj_build_fn
def
run
(
self
):
def
run
(
self
):
g
=
self
.
graph
# get src col
if
self
.
src_field
is
None
:
if
self
.
src_field
is
None
:
srccol
=
g
.
get_n
_repr
()
srccol
=
self
.
node
_repr
else
:
else
:
srccol
=
g
.
get_n
_repr
()
[
self
.
src_field
]
srccol
=
self
.
node
_repr
[
self
.
src_field
]
ctx
=
F
.
get_context
(
srccol
)
ctx
=
F
.
get_context
(
srccol
)
if
self
.
use_adj
:
adjmat
=
g
.
cached_graph
.
adjmat
().
get
(
ctx
)
# build adjmat
else
:
adjmat
=
self
.
adj_build_fn
(
self
.
edge_field
,
ctx
,
self
.
use_edge_feat
)
if
self
.
edge_field
is
None
:
dat
=
g
.
get_e_repr
()
else
:
dat
=
g
.
get_e_repr
()[
self
.
edge_field
]
dat
=
F
.
squeeze
(
dat
)
# TODO(minjie): should not directly use _indices
idx
=
g
.
cached_graph
.
adjmat
().
get
(
ctx
).
_indices
()
n
=
g
.
number_of_nodes
()
adjmat
=
F
.
sparse_tensor
(
idx
,
dat
,
[
n
,
n
])
# spmm
# spmm
if
len
(
F
.
shape
(
srccol
))
==
1
:
if
len
(
F
.
shape
(
srccol
))
==
1
:
srccol
=
F
.
unsqueeze
(
srccol
,
1
)
srccol
=
F
.
unsqueeze
(
srccol
,
1
)
...
@@ -77,104 +73,249 @@ class UpdateAllSPMVExecutor(Executor):
...
@@ -77,104 +73,249 @@ class UpdateAllSPMVExecutor(Executor):
else
:
else
:
dstcol
=
F
.
spmm
(
adjmat
,
srccol
)
dstcol
=
F
.
spmm
(
adjmat
,
srccol
)
if
self
.
dst_field
is
None
:
if
self
.
dst_field
is
None
:
g
.
set_n_repr
(
dstcol
)
return
dstcol
else
:
else
:
g
.
set_n_repr
(
{
self
.
dst_field
:
dstcol
}
)
return
{
self
.
dst_field
:
dstcol
}
class
SendRecvSPMVExecutor
(
Executor
):
def
__init__
(
self
,
graph
,
src
,
dst
,
src_field
,
dst_field
,
edge_field
,
use_edge_dat
):
self
.
graph
=
graph
self
.
src
=
src
self
.
dst
=
dst
self
.
src_field
=
src_field
self
.
dst_field
=
dst_field
self
.
edge_field
=
edge_field
self
.
use_edge_dat
=
use_edge_dat
def
run
(
self
):
class
BasicExecutor
(
Executor
):
# get src col
def
__init__
(
self
,
graph
,
mfunc
,
rfunc
):
g
=
self
.
graph
self
.
g
=
graph
if
self
.
src_field
is
None
:
self
.
exe
=
self
.
_build_exec
(
mfunc
,
rfunc
)
srccol
=
g
.
get_n_repr
()
@
property
def
node_repr
(
self
):
raise
NotImplementedError
@
property
def
edge_repr
(
self
):
raise
NotImplementedError
@
property
def
graph_mapping
(
self
):
raise
NotImplementedError
def
_build_exec
(
self
,
mfunc
,
rfunc
):
if
isinstance
(
mfunc
,
fmsg
.
CopySrcMessageFunction
):
exe
=
SPMVOperator
(
src_field
=
mfunc
.
src_field
,
edge_field
=
None
,
dst_field
=
rfunc
.
out_field
,
use_edge_feat
=
False
,
node_repr
=
self
.
node_repr
,
adj_build_fn
=
self
.
_adj_build_fn
)
elif
isinstance
(
mfunc
,
fmsg
.
SrcMulEdgeMessageFunction
):
exe
=
SPMVOperator
(
src_field
=
mfunc
.
src_field
,
edge_field
=
mfunc
.
edge_field
,
dst_field
=
rfunc
.
out_field
,
use_edge_feat
=
True
,
node_repr
=
self
.
node_repr
,
adj_build_fn
=
self
.
_adj_build_fn
)
else
:
else
:
srccol
=
g
.
get_n_repr
()[
self
.
src_field
]
raise
NotImplementedError
(
"message func type {}"
.
format
(
type
(
mfunc
)))
ctx
=
F
.
get_context
(
srccol
)
return
exe
# build adjmat
def
run
(
self
):
# build adjmat dat
attr
=
self
.
exe
.
run
()
u
,
v
=
utils
.
edge_broadcasting
(
self
.
src
,
self
.
dst
)
self
.
g
.
set_n_repr
(
attr
,
self
.
graph_mapping
)
if
self
.
use_edge_dat
:
if
self
.
edge_field
is
None
:
dat
=
g
.
get_e_repr
(
u
,
v
)
class
UpdateAllExecutor
(
BasicExecutor
):
def
__init__
(
self
,
graph
,
mfunc
,
rfunc
):
self
.
_init_state
()
super
(
UpdateAllExecutor
,
self
).
__init__
(
graph
,
mfunc
,
rfunc
)
def
_init_state
(
self
):
self
.
_node_repr
=
None
self
.
_edge_repr
=
None
self
.
_graph_idx
=
None
self
.
_graph_shape
=
None
self
.
_graph_mapping
=
None
@
property
def
graph_idx
(
self
):
if
self
.
_graph_idx
is
None
:
self
.
_graph_idx
=
self
.
g
.
cached_graph
.
adjmat
()
return
self
.
_graph_idx
@
property
def
graph_shape
(
self
):
if
self
.
_graph_shape
is
None
:
n
=
self
.
g
.
number_of_nodes
()
self
.
_graph_shape
=
[
n
,
n
]
return
self
.
_graph_shape
@
property
def
graph_mapping
(
self
):
return
ALL
@
property
def
node_repr
(
self
):
if
self
.
_node_repr
is
None
:
self
.
_node_repr
=
self
.
g
.
get_n_repr
()
return
self
.
_node_repr
@
property
def
edge_repr
(
self
):
if
self
.
_edge_repr
is
None
:
self
.
_edge_repr
=
self
.
g
.
get_e_repr
()
return
self
.
_edge_repr
def
_adj_build_fn
(
self
,
edge_field
,
ctx
,
use_edge_feat
):
if
use_edge_feat
:
if
edge_field
is
None
:
dat
=
self
.
edge_repr
else
:
else
:
dat
=
g
.
get_e_repr
(
u
,
v
)[
self
.
edge_field
]
dat
=
self
.
edge_repr
[
edge_field
]
dat
=
F
.
squeeze
(
dat
)
dat
=
F
.
squeeze
(
dat
)
# TODO(minjie): should not directly use _indices
idx
=
self
.
graph_idx
.
get
(
ctx
).
_indices
()
adjmat
=
F
.
sparse_tensor
(
idx
,
dat
,
self
.
graph_shape
)
else
:
else
:
dat
=
F
.
ones
((
len
(
u
),))
adjmat
=
self
.
graph_idx
.
get
(
ctx
)
# build adjmat index
return
adjmat
new2old
,
old2new
=
utils
.
build_relabel_map
(
v
)
u
=
u
.
totensor
()
v
=
v
.
totensor
()
class
SendRecvExecutor
(
BasicExecutor
):
def
__init__
(
self
,
graph
,
src
,
dst
,
mfunc
,
rfunc
):
self
.
_init_state
(
src
,
dst
)
super
(
SendRecvExecutor
,
self
).
__init__
(
graph
,
mfunc
,
rfunc
)
def
_init_state
(
self
,
src
,
dst
):
self
.
u
,
self
.
v
=
utils
.
edge_broadcasting
(
src
,
dst
)
self
.
_node_repr
=
None
self
.
_edge_repr
=
None
self
.
_graph_idx
=
None
self
.
_graph_shape
=
None
self
.
_graph_mapping
=
None
@
property
def
graph_idx
(
self
):
if
self
.
_graph_idx
is
None
:
self
.
_build_adjmat
()
return
self
.
_graph_idx
@
property
def
graph_shape
(
self
):
if
self
.
_graph_shape
is
None
:
self
.
_build_adjmat
()
return
self
.
_graph_shape
@
property
def
graph_mapping
(
self
):
if
self
.
_graph_mapping
is
None
:
self
.
_build_adjmat
()
return
self
.
_graph_mapping
@
property
def
node_repr
(
self
):
if
self
.
_node_repr
is
None
:
self
.
_node_repr
=
self
.
g
.
get_n_repr
()
return
self
.
_node_repr
@
property
def
edge_repr
(
self
):
if
self
.
_edge_repr
is
None
:
self
.
_edge_repr
=
self
.
g
.
get_e_repr
(
self
.
u
,
self
.
v
)
return
self
.
_edge_repr
def
_build_adjmat
(
self
):
# handle graph index
new2old
,
old2new
=
utils
.
build_relabel_map
(
self
.
v
)
u
=
self
.
u
.
totensor
()
v
=
self
.
v
.
totensor
()
# TODO(minjie): should not directly use []
# TODO(minjie): should not directly use []
new_v
=
old2new
[
v
]
new_v
=
old2new
[
v
]
idx
=
F
.
pack
([
F
.
unsqueeze
(
new_v
,
0
),
F
.
unsqueeze
(
u
,
0
)])
n
=
self
.
g
.
number_of_nodes
()
n
=
g
.
number_of_nodes
()
m
=
len
(
new2old
)
m
=
len
(
new2old
)
adjmat
=
F
.
sparse_tensor
(
idx
,
dat
,
[
m
,
n
])
self
.
_graph_idx
=
F
.
pack
([
F
.
unsqueeze
(
new_v
,
0
),
F
.
unsqueeze
(
u
,
0
)])
adjmat
=
F
.
to_context
(
adjmat
,
ctx
)
self
.
_graph_shape
=
[
m
,
n
]
# spmm
self
.
_graph_mapping
=
new2old
if
len
(
F
.
shape
(
srccol
))
==
1
:
srccol
=
F
.
unsqueeze
(
srccol
,
1
)
def
_adj_build_fn
(
self
,
edge_field
,
ctx
,
use_edge_feat
):
dstcol
=
F
.
spmm
(
adjmat
,
srccol
)
if
use_edge_feat
:
dstcol
=
F
.
squeeze
(
dstcol
)
if
edge_field
is
None
:
dat
=
self
.
edge_repr
else
:
else
:
dstcol
=
F
.
spmm
(
adjmat
,
srccol
)
dat
=
self
.
edge_repr
[
edge_field
]
if
self
.
dst_field
is
None
:
dat
=
F
.
squeeze
(
dat
)
g
.
set_n_repr
(
dstcol
,
new2old
)
else
:
else
:
g
.
set_n_repr
({
self
.
dst_field
:
dstcol
},
new2old
)
dat
=
F
.
ones
((
len
(
self
.
u
),
))
adjmat
=
F
.
sparse_tensor
(
self
.
graph_idx
,
dat
,
self
.
graph_shape
)
return
F
.
to_context
(
adjmat
,
ctx
)
def
_is_spmv_supported_node_feat
(
g
,
field
):
class
BundledExecutor
(
BasicExecutor
):
if
field
is
None
:
"""
feat
=
g
.
get_n_repr
()
Base class for Bundled execution
All shared structure like graph index should be cached in this class or its subclass
BundledUpdateAllExecutor and BundledSendRecvExecutor should subclass BundledExecutor
"""
def
__init__
(
self
,
graph
,
mfunc
,
rfunc
):
self
.
g
=
graph
func_pairs
=
self
.
_match_message_with_reduce
(
mfunc
,
rfunc
)
# create all executors
self
.
executors
=
self
.
_build_executors
(
func_pairs
)
def
_build_executors
(
self
,
func_pairs
):
executors
=
[]
for
mfunc
,
rfunc
in
func_pairs
:
exe
=
self
.
_build_exec
(
mfunc
,
rfunc
)
executors
.
append
(
exe
)
return
executors
def
_match_message_with_reduce
(
self
,
mfunc
,
rfunc
):
out2mfunc
=
{
fn
.
out_field
:
fn
for
fn
in
mfunc
.
fn_list
}
func_pairs
=
[]
for
rfn
in
rfunc
.
fn_list
:
mfn
=
out2mfunc
.
get
(
rfn
.
msg_field
,
None
)
# field check
assert
mfn
is
not
None
,
\
"cannot find message func for reduce func in-field {}"
.
format
(
rfn
.
msg_field
)
func_pairs
.
append
((
mfn
,
rfn
))
return
func_pairs
def
run
(
self
):
attr
=
None
for
exe
in
self
.
executors
:
res
=
exe
.
run
()
if
attr
is
None
:
attr
=
res
else
:
else
:
feat
=
g
.
get_n_repr
()[
field
]
# attr and res must be dict
shape
=
F
.
shape
(
feat
)
attr
.
update
(
res
)
return
(
len
(
shape
)
==
1
or
len
(
shape
)
==
2
)
self
.
g
.
set_n_repr
(
attr
,
self
.
graph_mapping
)
def
_is_spmv_supported_edge_feat
(
g
,
field
):
# check shape, only scalar edge feature can be optimized at the moment.
class
BundledUpdateAllExecutor
(
BundledExecutor
,
UpdateAllExecutor
):
if
field
is
None
:
def
__init__
(
self
,
graph
,
mfunc
,
rfunc
):
feat
=
g
.
get_e_repr
()
self
.
_init_state
()
BundledExecutor
.
__init__
(
self
,
graph
,
mfunc
,
rfunc
)
class
BundledSendRecvExecutor
(
BundledExecutor
,
SendRecvExecutor
):
def
__init__
(
self
,
graph
,
src
,
dst
,
mfunc
,
rfunc
):
self
.
_init_state
(
src
,
dst
)
BundledExecutor
.
__init__
(
self
,
graph
,
mfunc
,
rfunc
)
def
_is_spmv_supported
(
fn
,
graph
=
None
):
if
isinstance
(
fn
,
fmsg
.
MessageFunction
):
return
fn
.
is_spmv_supported
(
graph
)
elif
isinstance
(
fn
,
fred
.
ReduceFunction
):
return
fn
.
is_spmv_supported
()
else
:
else
:
feat
=
g
.
get_e_repr
()[
field
]
return
False
shape
=
F
.
shape
(
feat
)
return
len
(
shape
)
==
1
or
(
len
(
shape
)
==
2
and
shape
[
1
]
==
1
)
def
_create_update_all_exec
(
graph
,
**
kwargs
):
def
_create_update_all_exec
(
graph
,
**
kwargs
):
mfunc
=
kwargs
.
pop
(
'message_func'
)
mfunc
=
kwargs
.
pop
(
'message_func'
)
rfunc
=
kwargs
.
pop
(
'reduce_func'
)
rfunc
=
kwargs
.
pop
(
'reduce_func'
)
if
(
isinstance
(
mfunc
,
fmsg
.
CopySrcMessageFunction
)
if
isinstance
(
mfunc
,
(
list
,
tuple
))
or
isinstance
(
rfunc
,
(
list
,
tuple
)):
and
isinstance
(
rfunc
,
fred
.
SumReducerFunction
)
mfunc
=
fmsg
.
BundledMessageFunction
(
mfunc
)
and
_is_spmv_supported_node_feat
(
graph
,
mfunc
.
src_field
)):
rfunc
=
fred
.
BundledReduceFunction
(
rfunc
)
# TODO(minjie): more sanity check on field names
exec_cls
=
BundledUpdateAllExecutor
return
UpdateAllSPMVExecutor
(
graph
,
else
:
src_field
=
mfunc
.
src_field
,
exec_cls
=
UpdateAllExecutor
dst_field
=
rfunc
.
out_field
,
if
_is_spmv_supported
(
mfunc
,
graph
)
and
_is_spmv_supported
(
rfunc
):
edge_field
=
None
,
return
exec_cls
(
graph
,
mfunc
=
mfunc
,
rfunc
=
rfunc
)
use_adj
=
True
)
elif
(
isinstance
(
mfunc
,
fmsg
.
SrcMulEdgeMessageFunction
)
and
isinstance
(
rfunc
,
fred
.
SumReducerFunction
)
and
_is_spmv_supported_node_feat
(
graph
,
mfunc
.
src_field
)
and
_is_spmv_supported_edge_feat
(
graph
,
mfunc
.
edge_field
)):
return
UpdateAllSPMVExecutor
(
graph
,
src_field
=
mfunc
.
src_field
,
dst_field
=
rfunc
.
out_field
,
edge_field
=
mfunc
.
edge_field
,
use_adj
=
False
)
elif
(
isinstance
(
mfunc
,
fmsg
.
CopyEdgeMessageFunction
)
and
isinstance
(
rfunc
,
fred
.
SumReducerFunction
)):
return
None
else
:
else
:
return
None
return
None
...
@@ -183,28 +324,14 @@ def _create_send_and_recv_exec(graph, **kwargs):
...
@@ -183,28 +324,14 @@ def _create_send_and_recv_exec(graph, **kwargs):
dst
=
kwargs
.
pop
(
'dst'
)
dst
=
kwargs
.
pop
(
'dst'
)
mfunc
=
kwargs
.
pop
(
'message_func'
)
mfunc
=
kwargs
.
pop
(
'message_func'
)
rfunc
=
kwargs
.
pop
(
'reduce_func'
)
rfunc
=
kwargs
.
pop
(
'reduce_func'
)
if
(
isinstance
(
mfunc
,
fmsg
.
CopySrcMessageFunction
)
if
isinstance
(
mfunc
,
(
list
,
tuple
))
or
isinstance
(
rfunc
,
(
list
,
tuple
)):
and
isinstance
(
rfunc
,
fred
.
SumReducerFunction
)
mfunc
=
fmsg
.
BundledMessageFunction
(
mfunc
)
and
_is_spmv_supported_node_feat
(
graph
,
mfunc
.
src_field
)):
rfunc
=
fred
.
BundledReduceFunction
(
rfunc
)
# TODO(minjie): more sanity check on field names
exec_cls
=
BundledSendRecvExecutor
return
SendRecvSPMVExecutor
(
graph
,
else
:
src
=
src
,
exec_cls
=
SendRecvExecutor
dst
=
dst
,
if
_is_spmv_supported
(
mfunc
,
graph
)
and
_is_spmv_supported
(
rfunc
):
src_field
=
mfunc
.
src_field
,
return
exec_cls
(
graph
,
src
=
src
,
dst
=
dst
,
mfunc
=
mfunc
,
rfunc
=
rfunc
)
dst_field
=
rfunc
.
out_field
,
edge_field
=
None
,
use_edge_dat
=
False
)
elif
(
isinstance
(
mfunc
,
fmsg
.
SrcMulEdgeMessageFunction
)
and
isinstance
(
rfunc
,
fred
.
SumReducerFunction
)
and
_is_spmv_supported_node_feat
(
graph
,
mfunc
.
src_field
)
and
_is_spmv_supported_edge_feat
(
graph
,
mfunc
.
edge_field
)):
return
SendRecvSPMVExecutor
(
graph
,
src
=
src
,
dst
=
dst
,
src_field
=
mfunc
.
src_field
,
dst_field
=
rfunc
.
out_field
,
edge_field
=
mfunc
.
edge_field
,
use_edge_dat
=
True
)
else
:
else
:
return
None
return
None
...
...
tests/pytorch/test_specialization.py
View file @
916d375b
...
@@ -113,6 +113,110 @@ def test_send_and_recv():
...
@@ -113,6 +113,110 @@ def test_send_and_recv():
# test 2d node features
# test 2d node features
_test
(
'f2'
)
_test
(
'f2'
)
def
test_update_all_multi_fn
():
def
message_func
(
hu
,
edge
):
return
{
'm2'
:
hu
[
'f2'
]}
def
message_func_edge
(
hu
,
edge
):
return
{
'm2'
:
hu
[
'f2'
]
*
edge
[
'e2'
]}
def
reduce_func
(
hv
,
msgs
):
return
{
'v2'
:
th
.
sum
(
msgs
[
'm2'
],
1
)}
g
=
generate_graph
()
fld
=
'f2'
# update all, mix of builtin and UDF
g
.
update_all
([
fn
.
copy_src
(
src
=
fld
,
out
=
'm1'
),
message_func
],
[
fn
.
sum
(
msgs
=
'm1'
,
out
=
'v1'
),
reduce_func
],
None
,
batchable
=
True
)
v1
=
g
.
get_n_repr
()[
'v1'
]
v2
=
g
.
get_n_repr
()[
'v2'
]
assert
th
.
allclose
(
v1
,
v2
)
# run builtin with single message and reduce
g
.
update_all
(
fn
.
copy_src
(
src
=
fld
),
fn
.
sum
(
out
=
'v1'
),
None
,
batchable
=
True
)
v1
=
g
.
get_n_repr
()[
'v1'
]
assert
th
.
allclose
(
v1
,
v2
)
# 1 message, 2 reduces, using anonymous repr
g
.
update_all
(
fn
.
copy_src
(
src
=
fld
),
[
fn
.
sum
(
out
=
'v2'
),
fn
.
sum
(
out
=
'v3'
)],
None
,
batchable
=
True
)
v2
=
g
.
get_n_repr
()[
'v2'
]
v3
=
g
.
get_n_repr
()[
'v3'
]
assert
th
.
allclose
(
v1
,
v2
)
assert
th
.
allclose
(
v1
,
v3
)
# update all with edge weights, 2 message, 3 reduces
g
.
update_all
([
fn
.
src_mul_edge
(
src
=
fld
,
edge
=
'e1'
,
out
=
'm1'
),
fn
.
src_mul_edge
(
src
=
fld
,
edge
=
'e2'
,
out
=
'm2'
)],
[
fn
.
sum
(
msgs
=
'm1'
,
out
=
'v1'
),
fn
.
sum
(
msgs
=
'm2'
,
out
=
'v2'
),
fn
.
sum
(
msgs
=
'm1'
,
out
=
'v3'
)],
None
,
batchable
=
True
)
v1
=
g
.
get_n_repr
()[
'v1'
]
v2
=
g
.
get_n_repr
()[
'v2'
]
v3
=
g
.
get_n_repr
()[
'v3'
]
assert
th
.
allclose
(
v1
,
v2
)
assert
th
.
allclose
(
v1
,
v3
)
# run UDF with single message and reduce
g
.
update_all
(
message_func_edge
,
reduce_func
,
None
,
batchable
=
True
)
v2
=
g
.
get_n_repr
()[
'v2'
]
assert
th
.
allclose
(
v1
,
v2
)
def
test_send_and_recv_multi_fn
():
u
=
th
.
tensor
([
0
,
0
,
0
,
3
,
4
,
9
])
v
=
th
.
tensor
([
1
,
2
,
3
,
9
,
9
,
0
])
def
message_func
(
hu
,
edge
):
return
{
'm2'
:
hu
[
'f2'
]}
def
message_func_edge
(
hu
,
edge
):
return
{
'm2'
:
hu
[
'f2'
]
*
edge
[
'e2'
]}
def
reduce_func
(
hv
,
msgs
):
return
{
'v2'
:
th
.
sum
(
msgs
[
'm2'
],
1
)}
g
=
generate_graph
()
fld
=
'f2'
# send and recv, mix of builtin and UDF
g
.
send_and_recv
(
u
,
v
,
[
fn
.
copy_src
(
src
=
fld
,
out
=
'm1'
),
message_func
],
[
fn
.
sum
(
msgs
=
'm1'
,
out
=
'v1'
),
reduce_func
],
None
,
batchable
=
True
)
v1
=
g
.
get_n_repr
()[
'v1'
]
v2
=
g
.
get_n_repr
()[
'v2'
]
assert
th
.
allclose
(
v1
,
v2
)
# run builtin with single message and reduce
g
.
send_and_recv
(
u
,
v
,
fn
.
copy_src
(
src
=
fld
),
fn
.
sum
(
out
=
'v1'
),
None
,
batchable
=
True
)
v1
=
g
.
get_n_repr
()[
'v1'
]
assert
th
.
allclose
(
v1
,
v2
)
# 1 message, 2 reduces, using anonymous repr
g
.
send_and_recv
(
u
,
v
,
fn
.
copy_src
(
src
=
fld
),
[
fn
.
sum
(
out
=
'v2'
),
fn
.
sum
(
out
=
'v3'
)],
None
,
batchable
=
True
)
v2
=
g
.
get_n_repr
()[
'v2'
]
v3
=
g
.
get_n_repr
()[
'v3'
]
assert
th
.
allclose
(
v1
,
v2
)
assert
th
.
allclose
(
v1
,
v3
)
# send and recv with edge weights, 2 message, 3 reduces
g
.
send_and_recv
(
u
,
v
,
[
fn
.
src_mul_edge
(
src
=
fld
,
edge
=
'e1'
,
out
=
'm1'
),
fn
.
src_mul_edge
(
src
=
fld
,
edge
=
'e2'
,
out
=
'm2'
)],
[
fn
.
sum
(
msgs
=
'm1'
,
out
=
'v1'
),
fn
.
sum
(
msgs
=
'm2'
,
out
=
'v2'
),
fn
.
sum
(
msgs
=
'm1'
,
out
=
'v3'
)],
None
,
batchable
=
True
)
v1
=
g
.
get_n_repr
()[
'v1'
]
v2
=
g
.
get_n_repr
()[
'v2'
]
v3
=
g
.
get_n_repr
()[
'v3'
]
assert
th
.
allclose
(
v1
,
v2
)
assert
th
.
allclose
(
v1
,
v3
)
# run UDF with single message and reduce
g
.
send_and_recv
(
u
,
v
,
message_func_edge
,
reduce_func
,
None
,
batchable
=
True
)
v2
=
g
.
get_n_repr
()[
'v2'
]
assert
th
.
allclose
(
v1
,
v2
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
#
test_update_all()
test_update_all
()
test_send_and_recv
()
test_send_and_recv
()
test_update_all_multi_fn
()
test_send_and_recv_multi_fn
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment