OpenDAS / dgl / Commits

Commit 3e76bcc0
authored Oct 18, 2018 by Minjie Wang
parent fb6be9fb

    remove anonymous repr

Showing 13 changed files with 276 additions and 600 deletions (+276 −600):
examples/pytorch/gcn/gcn_spmv.py        +1   -1
python/dgl/__init__.py                  +1   -1
python/dgl/base.py                      +0   -3
python/dgl/function/message.py          +59  -46
python/dgl/function/reducer.py          +37  -35
python/dgl/graph.py                     +74  -104
python/dgl/scheduler.py                 +22  -40
tests/pytorch/test_basics.py            +13  -14
tests/pytorch/test_basics_anonymous.py  +0   -198
tests/pytorch/test_batched_graph.py     +30  -30
tests/pytorch/test_function.py          +0   -86
tests/pytorch/test_line_graph.py        +5   -9
tests/pytorch/test_specialization.py    +34  -33
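The theme across all 13 files: the hidden anonymous keys __REPR__ and __MSG__ are removed, so every node, edge, and message feature must be stored and addressed under an explicit name. A minimal before/after sketch of the user-facing change (assuming the DGL API at this commit; th is torch):

    import torch as th
    import dgl
    import dgl.function as fn

    g = dgl.DGLGraph()
    g.add_nodes(3)
    g.add_edges([0, 1], [1, 2])

    # Before this commit a bare tensor was accepted and stored under the
    # hidden __REPR__ key:  g.set_n_repr(th.randn(3, 4))
    # Now features must be a dict from feature name to tensor:
    g.set_n_repr({'h': th.randn(3, 4)})

    # Builtin message/reduce functions name their fields explicitly.
    g.update_all(fn.copy_src(src='h', out='m'),
                 fn.sum(msg='m', out='h'))

    h = g.get_n_repr()['h']   # get_n_repr() now always returns a dict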
examples/pytorch/gcn/gcn_spmv.py

@@ -56,7 +56,7 @@ class GCN(nn.Module):
             g.apply_nodes(apply_node_func=
                 lambda node: F.dropout(node['h'], p=self.dropout))
         self.g.update_all(fn.copy_src(src='h', out='m'),
-                          fn.sum(msgs='m', out='h'),
+                          fn.sum(msg='m', out='h'),
                           layer)
         return self.g.pop_n_repr('h')
python/dgl/__init__.py

@@ -11,5 +11,5 @@ from ._ffi.base import DGLError, __version__
 from .base import ALL
 from .batched_graph import *
 from .generator import *
-from .graph import DGLGraph, __MSG__, __REPR__
+from .graph import DGLGraph
 from .subgraph import DGLSubGraph
python/dgl/base.py

@@ -11,7 +11,4 @@ ALL = "__ALL__"
 def is_all(arg):
     return isinstance(arg, str) and arg == ALL

-__MSG__ = "__MSG__"
-__REPR__ = "__REPR__"
-
 dgl_warning = warnings.warn
python/dgl/function/message.py

@@ -4,17 +4,25 @@ from __future__ import absolute_import
 import operator

 import dgl.backend as F

-__all__ = ["MessageFunction", "src_mul_edge", "copy_src", "copy_edge"]
+__all__ = ["src_mul_edge", "copy_src", "copy_edge"]

 class MessageFunction(object):
+    """Base builtin message function class."""
+
     def __call__(self, src, edge):
+        """Regular computation of this builtin.
+
+        This will be used when optimization is not available.
+        """
         raise NotImplementedError

     def name(self):
+        """Return the name of this builtin function."""
         raise NotImplementedError

     def is_spmv_supported(self, g):
+        """Return whether the SPMV optimization is supported."""
         raise NotImplementedError

@@ -22,12 +30,6 @@ class BundledMessageFunction(MessageFunction):
     def __init__(self, fn_list):
         if not isinstance(fn_list, (list, tuple)):
             fn_list = [fn_list]
-        else:
-            # sanity check on out field
-            for fn in fn_list:
-                # cannot perform check for udf
-                if isinstance(fn, MessageFunction) and fn.out_field is None:
-                    raise RuntimeError("Not specifying out field for multiple message is ambiguous")
         self.fn_list = fn_list

     def is_spmv_supported(self, g):

@@ -43,11 +45,8 @@ class BundledMessageFunction(MessageFunction):
             if ret is None:
                 ret = msg
             else:
-                try:
-                    # ret and msg must be dict
-                    ret.update(msg)
-                except:
-                    raise RuntimeError("Must specify out field for multiple message")
+                # ret and msg must be dict
+                ret.update(msg)
         return ret

     def name(self):

@@ -55,25 +54,26 @@ class BundledMessageFunction(MessageFunction):
 def _is_spmv_supported_node_feat(g, field):
-    if field is None:
-        feat = g.get_n_repr()
-    else:
-        feat = g.get_n_repr()[field]
+    """Return whether the node feature shape supports SPMV optimization.
+
+    Only scalar and vector features are supported currently.
+    """
+    feat = g.get_n_repr()[field]
     shape = F.shape(feat)
     return len(shape) == 1 or len(shape) == 2

 def _is_spmv_supported_edge_feat(g, field):
-    # check shape, only scalar edge feature can be optimized at the moment
-    if field is None:
-        feat = g.get_e_repr()
-    else:
-        feat = g.get_e_repr()[field]
+    """Return whether the edge feature shape supports SPMV optimization.
+
+    Only scalar feature is supported currently.
+    """
+    feat = g.get_e_repr()[field]
     shape = F.shape(feat)
     return len(shape) == 1 or (len(shape) == 2 and shape[1] == 1)

 class SrcMulEdgeMessageFunction(MessageFunction):
-    def __init__(self, mul_op, src_field=None, edge_field=None, out_field=None):
+    def __init__(self, mul_op, src_field, edge_field, out_field):
         self.mul_op = mul_op
         self.src_field = src_field
         self.edge_field = edge_field

@@ -84,21 +84,14 @@ class SrcMulEdgeMessageFunction(MessageFunction):
                and _is_spmv_supported_edge_feat(g, self.edge_field)

     def __call__(self, src, edge):
-        if self.src_field is not None:
-            src = src[self.src_field]
-        if self.edge_field is not None:
-            edge = edge[self.edge_field]
-        ret = self.mul_op(src, edge)
-        if self.out_field is None:
-            return ret
-        else:
-            return {self.out_field : ret}
+        ret = self.mul_op(src[self.src_field], edge[self.edge_field])
+        return {self.out_field : ret}

     def name(self):
         return "src_mul_edge"

 class CopySrcMessageFunction(MessageFunction):
-    def __init__(self, src_field=None, out_field=None):
+    def __init__(self, src_field, out_field):
         self.src_field = src_field
         self.out_field = out_field

@@ -106,14 +99,7 @@ class CopySrcMessageFunction(MessageFunction):
         return _is_spmv_supported_node_feat(g, self.src_field)

     def __call__(self, src, edge):
-        if self.src_field is not None:
-            ret = src[self.src_field]
-        else:
-            ret = src
-        if self.out_field is None:
-            return ret
-        else:
-            return {self.out_field : ret}
+        return {self.out_field : src[self.src_field]}

     def name(self):
         return "copy_src"

@@ -142,14 +128,41 @@ class CopyEdgeMessageFunction(MessageFunction):
         return "copy_edge"

-def src_mul_edge(src=None, edge=None, out=None):
-    """TODO(minjie): docstring """
+def src_mul_edge(src, edge, out):
+    """Builtin message function that computes message by multiplying source node features
+    with edge features.
+
+    Parameters
+    ----------
+    src : str
+        The source feature name.
+    edge : str
+        The edge feature name.
+    out : str
+        The output message name.
+    """
     return SrcMulEdgeMessageFunction(operator.mul, src, edge, out)

-def copy_src(src=None, out=None):
-    """TODO(minjie): docstring """
+def copy_src(src, out):
+    """Builtin message function that computes message using source node feature.
+
+    Parameters
+    ----------
+    src : str
+        The source feature name.
+    out : str
+        The output message name.
+    """
     return CopySrcMessageFunction(src, out)

-def copy_edge(edge=None, out=None):
-    """TODO(minjie): docstring """
+def copy_edge(edge, out):
+    """Builtin message function that computes message using edge feature.
+
+    Parameters
+    ----------
+    edge : str
+        The edge feature name.
+    out : str
+        The output message name.
+    """
     return CopyEdgeMessageFunction(edge, out)
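Under the new signatures the field names are positional and mandatory, so forgetting one fails loudly instead of silently falling back to an anonymous representation. A small sketch of the resulting behavior (a hedged illustration, assuming this commit's dgl.function module):

    import dgl.function as fn

    # Every builtin now returns its message in a dict keyed by `out`.
    fn.copy_src(src='h', out='m')                # message is {'m': src['h']}
    fn.copy_edge(edge='w', out='m')              # message is {'m': edge['w']}
    fn.src_mul_edge(src='h', edge='w', out='m')  # {'m': src['h'] * edge['w']}

    # Omitting a field is now a plain TypeError at construction time,
    # since the parameters no longer default to None:
    try:
        fn.copy_src(src='h')                     # no `out` given
    except TypeError:
        pass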
python/dgl/function/reducer.py

@@ -3,27 +3,30 @@ from __future__ import absolute_import
 from .. import backend as F

-__all__ = ["ReduceFunction", "sum", "max"]
+__all__ = ["sum", "max"]

 class ReduceFunction(object):
+    """Base builtin reduce function class."""
+
     def __call__(self, node, msgs):
+        """Regular computation of this builtin.
+
+        This will be used when optimization is not available.
+        """
         raise NotImplementedError

     def name(self):
+        """Return the name of this builtin function."""
         raise NotImplementedError

     def is_spmv_supported(self):
+        """Return whether the SPMV optimization is supported."""
         raise NotImplementedError

 class BundledReduceFunction(ReduceFunction):
     def __init__(self, fn_list):
         if not isinstance(fn_list, (list, tuple)):
             fn_list = [fn_list]
-        else:
-            # sanity check on out field
-            for fn in fn_list:
-                if isinstance(fn, ReduceFunction) and fn.out_field is None:
-                    raise RuntimeError("Not specifying out field for multiple reduce is ambiguous")
         self.fn_list = fn_list

     def is_spmv_supported(self):

@@ -39,51 +42,50 @@ class BundledReduceFunction(ReduceFunction):
                 if ret is None:
                     ret = rpr
                 else:
-                    try:
-                        # ret and rpr must be dict
-                        ret.update(rpr)
-                    except:
-                        raise RuntimeError("Must specify out field for multiple reudce")
+                    # ret and rpr must be dict
+                    ret.update(rpr)
         return ret

     def name(self):
         return "bundled"

 class ReducerFunctionTemplate(ReduceFunction):
-    def __init__(self, name, batch_op, nonbatch_op, msg_field=None, out_field=None):
+    def __init__(self, name, op, msg_field, out_field):
         self.name = name
-        self.batch_op = batch_op
-        self.nonbatch_op = nonbatch_op
+        self.op = op
         self.msg_field = msg_field
         self.out_field = out_field

     def is_spmv_supported(self):
-        # TODO: support max
+        # NOTE: only sum is supported right now.
         return self.name == "sum"

     def __call__(self, node, msgs):
-        if isinstance(msgs, list):
-            if self.msg_field is None:
-                ret = self.nonbatch_op(msgs)
-            else:
-                ret = self.nonbatch_op([msg[self.msg_field] for msg in msgs])
-        else:
-            if self.msg_field is None:
-                ret = self.batch_op(msgs, 1)
-            else:
-                ret = self.batch_op(msgs[self.msg_field], 1)
-        if self.out_field is None:
-            return ret
-        else:
-            return {self.out_field : ret}
+        return {self.out_field : self.op(msgs[self.msg_field], 1)}

     def name(self):
         return self.name

-_python_sum = sum
-def sum(msgs=None, out=None):
-    return ReducerFunctionTemplate("sum", F.sum, _python_sum, msgs, out)
+def sum(msg, out):
+    """Builtin reduce function that aggregates messages by sum.
+
+    Parameters
+    ----------
+    msg : str
+        The message name.
+    out : str
+        The output node feature name.
+    """
+    return ReducerFunctionTemplate("sum", F.sum, msg, out)

-_python_max = max
-def max(msgs=None, out=None):
-    return ReducerFunctionTemplate("max", F.max, _python_max, msgs, out)
+def max(msg, out):
+    """Builtin reduce function that aggregates messages by max.
+
+    Parameters
+    ----------
+    msg : str
+        The message name.
+    out : str
+        The output node feature name.
+    """
+    return ReducerFunctionTemplate("max", F.max, msg, out)
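The reducer side mirrors the message side: msg and out are required, and dropping the nonbatch_op means the builtin always reduces a batched message tensor, of shape (num_nodes, deg, ...) within each degree bucket, along dimension 1. A usage sketch (assuming this commit's API; the comment paraphrases the new __call__ above):

    import dgl.function as fn

    # Reduce incoming 'm' messages into the 'h' node feature by sum,
    # i.e. roughly {'h': F.sum(msgs['m'], 1)} per degree bucket.
    rfunc = fn.sum(msg='m', out='h')

    # max has the same shape contract but, per is_spmv_supported, is not
    # yet eligible for the SPMV fast path.
    rfunc_max = fn.max(msg='m', out='h_max')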
python/dgl/graph.py

@@ -6,7 +6,7 @@ import networkx as nx
 import numpy as np

 import dgl
-from .base import ALL, is_all, __MSG__, __REPR__
+from .base import ALL, is_all, DGLError, dgl_warning
 from . import backend as F
 from .backend import Tensor
 from .frame import FrameRef, merge_frames

@@ -22,7 +22,6 @@ class DGLGraph(object):
     """Base graph class specialized for neural networks on graphs.

     TODO(minjie): document of batching semantics
-    TODO(minjie): document of __REPR__ semantics

     Parameters
     ----------

@@ -448,7 +447,9 @@ class DGLGraph(object):
             The nx graph
         """
         nx_graph = self._graph.to_networkx()
-        #TODO: attributes
+        #TODO(minjie): attributes
+        dgl_warning('to_networkx currently does not support converting'
+                    ' node/edge features automatically.')
         return nx_graph

     def from_networkx(self, nx_graph, node_attrs=None, edge_attrs=None):

@@ -550,20 +551,17 @@ class DGLGraph(object):
     def set_n_repr(self, hu, u=ALL, inplace=False):
         """Set node(s) representation.

-        To set multiple node representations at once, pass `u` with a tensor or
-        a supported container of node ids. In this case, `hu` must be a tensor
-        of shape (B, D1, D2, ...), where B is the number of the nodes and
-        (D1, D2, ...) is the shape of the node representation tensor.
-
-        Dictionary type is also supported for `hu`. In this case, each item
-        will be treated as separate attribute of the nodes.
+        `hu` is a dictionary from the feature name to feature tensor. Each tensor
+        is of shape (B, D1, D2, ...), where B is the number of nodes to be updated,
+        and (D1, D2, ...) be the shape of the node representation tensor. The
+        length of the given node ids must match B (i.e, len(u) == B).

         All update will be done out-placely to work with autograd unless the inplace
         flag is true.

         Parameters
         ----------
-        hu : tensor or dict of tensor
+        hu : dict of tensor
             Node representation.
         u : node, container or tensor
             The node(s).

@@ -571,32 +569,31 @@ class DGLGraph(object):
             True if the update is done inplacely
         """
         # sanity check
+        if not utils.is_dict_like(hu):
+            raise DGLError('Expect dictionary type for feature data.'
+                           ' Got "%s" instead.' % type(hu))
         if is_all(u):
             num_nodes = self.number_of_nodes()
         else:
             u = utils.toindex(u)
             num_nodes = len(u)
-        if utils.is_dict_like(hu):
-            for key, val in hu.items():
-                assert F.shape(val)[0] == num_nodes
-        else:
-            assert F.shape(hu)[0] == num_nodes
+        for key, val in hu.items():
+            nfeats = F.shape(val)[0]
+            if nfeats != num_nodes:
+                raise DGLError('Expect number of features to match number of nodes (len(u)).'
+                               ' Got %d and %d instead.' % (nfeats, num_nodes))
         # set
         if is_all(u):
-            if utils.is_dict_like(hu):
-                for key, val in hu.items():
-                    self._node_frame[key] = val
-            else:
-                self._node_frame[__REPR__] = hu
+            for key, val in hu.items():
+                self._node_frame[key] = val
         else:
-            if utils.is_dict_like(hu):
-                self._node_frame.update_rows(u, hu, inplace=inplace)
-            else:
-                self._node_frame.update_rows(u, {__REPR__ : hu}, inplace=inplace)
+            self._node_frame.update_rows(u, hu, inplace=inplace)

     def get_n_repr(self, u=ALL):
         """Get node(s) representation.

+        The returned feature tensor batches multiple node features on the first dimension.
+
         Parameters
         ----------
         u : node, container or tensor

@@ -605,23 +602,17 @@ class DGLGraph(object):
         Returns
         -------
         dict
-            Representation dict
+            Representation dict from feature name to feature tensor.
         """
         if len(self.node_attr_schemes()) == 0:
             return dict()
         if is_all(u):
-            if len(self._node_frame) == 1 and __REPR__ in self._node_frame:
-                return self._node_frame[__REPR__]
-            else:
-                return dict(self._node_frame)
+            return dict(self._node_frame)
         else:
             u = utils.toindex(u)
-            if len(self._node_frame) == 1 and __REPR__ in self._node_frame:
-                return self._node_frame.select_rows(u)[__REPR__]
-            else:
-                return self._node_frame.select_rows(u)
+            return self._node_frame.select_rows(u)

-    def pop_n_repr(self, key=__REPR__):
+    def pop_n_repr(self, key):
         """Get and remove the specified node repr.

         Parameters

@@ -636,23 +627,19 @@ class DGLGraph(object):
         """
         return self._node_frame.pop(key)

-    def set_e_repr(self, h_uv, u=ALL, v=ALL, inplace=False):
+    def set_e_repr(self, he, u=ALL, v=ALL, inplace=False):
         """Set edge(s) representation.

-        To set multiple edge representations at once, pass `u` and `v` with tensors or
-        supported containers of node ids. In this case, `h_uv` must be a tensor
-        of shape (B, D1, D2, ...), where B is the number of the edges and
-        (D1, D2, ...) is the shape of the edge representation tensor.
-
-        Dictionary type is also supported for `h_uv`. In this case, each item
-        will be treated as separate attribute of the edges.
+        `he` is a dictionary from the feature name to feature tensor. Each tensor
+        is of shape (B, D1, D2, ...), where B is the number of edges to be updated,
+        and (D1, D2, ...) be the shape of the edge representation tensor.

         All update will be done out-placely to work with autograd unless the inplace
         flag is true.

         Parameters
         ----------
-        h_uv : tensor or dict of tensor
+        he : tensor or dict of tensor
             Edge representation.
         u : node, container or tensor
             The source node(s).

@@ -662,26 +649,33 @@ class DGLGraph(object):
             True if the update is done inplacely
         """
         # sanity check
+        if not utils.is_dict_like(he):
+            raise DGLError('Expect dictionary type for feature data.'
+                           ' Got "%s" instead.' % type(he))
         u_is_all = is_all(u)
         v_is_all = is_all(v)
         assert u_is_all == v_is_all
         if u_is_all:
-            self.set_e_repr_by_id(h_uv, eid=ALL, inplace=inplace)
+            self.set_e_repr_by_id(he, eid=ALL, inplace=inplace)
         else:
             u = utils.toindex(u)
             v = utils.toindex(v)
             _, _, eid = self._graph.edge_ids(u, v)
-            self.set_e_repr_by_id(h_uv, eid=eid, inplace=inplace)
+            self.set_e_repr_by_id(he, eid=eid, inplace=inplace)

-    def set_e_repr_by_id(self, h_uv, eid=ALL, inplace=False):
+    def set_e_repr_by_id(self, he, eid=ALL, inplace=False):
         """Set edge(s) representation by edge id.

+        `he` is a dictionary from the feature name to feature tensor. Each tensor
+        is of shape (B, D1, D2, ...), where B is the number of edges to be updated,
+        and (D1, D2, ...) be the shape of the edge representation tensor.
+
         All update will be done out-placely to work with autograd unless the inplace
         flag is true.

         Parameters
         ----------
-        h_uv : tensor or dict of tensor
+        he : tensor or dict of tensor
             Edge representation.
         eid : int, container or tensor
             The edge id(s).

@@ -689,30 +683,27 @@ class DGLGraph(object):
             True if the update is done inplacely
         """
         # sanity check
+        if not utils.is_dict_like(he):
+            raise DGLError('Expect dictionary type for feature data.'
+                           ' Got "%s" instead.' % type(he))
         if is_all(eid):
             num_edges = self.number_of_edges()
         else:
             eid = utils.toindex(eid)
             num_edges = len(eid)
-        if utils.is_dict_like(h_uv):
-            for key, val in h_uv.items():
-                assert F.shape(val)[0] == num_edges
-        else:
-            assert F.shape(h_uv)[0] == num_edges
+        for key, val in he.items():
+            nfeats = F.shape(val)[0]
+            if nfeats != num_edges:
+                raise DGLError('Expect number of features to match number of edges.'
+                               ' Got %d and %d instead.' % (nfeats, num_edges))
         # set
         if is_all(eid):
             # update column
-            if utils.is_dict_like(h_uv):
-                for key, val in h_uv.items():
-                    self._edge_frame[key] = val
-            else:
-                self._edge_frame[__REPR__] = h_uv
+            for key, val in he.items():
+                self._edge_frame[key] = val
         else:
             # update row
-            if utils.is_dict_like(h_uv):
-                self._edge_frame.update_rows(eid, h_uv, inplace=inplace)
-            else:
-                self._edge_frame.update_rows(eid, {__REPR__ : h_uv}, inplace=inplace)
+            self._edge_frame.update_rows(eid, he, inplace=inplace)

     def get_e_repr(self, u=ALL, v=ALL):
         """Get node(s) representation.

@@ -742,7 +733,7 @@ class DGLGraph(object):
         _, _, eid = self._graph.edge_ids(u, v)
         return self.get_e_repr_by_id(eid=eid)

-    def pop_e_repr(self, key=__REPR__):
+    def pop_e_repr(self, key):
         """Get and remove the specified edge repr.

         Parameters

@@ -768,21 +759,15 @@ class DGLGraph(object):
         Returns
         -------
         dict
-            Representation dict
+            Representation dict from feature name to feature tensor.
         """
         if len(self.edge_attr_schemes()) == 0:
             return dict()
         if is_all(eid):
-            if len(self._edge_frame) == 1 and __REPR__ in self._edge_frame:
-                return self._edge_frame[__REPR__]
-            else:
-                return dict(self._edge_frame)
+            return dict(self._edge_frame)
         else:
             eid = utils.toindex(eid)
-            if len(self._edge_frame) == 1 and __REPR__ in self._edge_frame:
-                return self._edge_frame.select_rows(eid)[__REPR__]
-            else:
-                return self._edge_frame.select_rows(eid)
+            return self._edge_frame.select_rows(eid)

     def register_edge_func(self, edge_func):
         """Register global edge update function.

@@ -837,6 +822,8 @@ class DGLGraph(object):
     def apply_nodes(self, v=ALL, apply_node_func="default"):
         """Apply the function on node representations.

+        Applying a None function will be ignored.
+
         Parameters
         ----------
         v : int, iterable of int, tensor, optional

@@ -868,7 +855,7 @@ class DGLGraph(object):
             # merge current node_repr with reduce output
             curr_repr = utils.HybridDict(reduce_accum, curr_repr)
         new_repr = apply_node_func(curr_repr)
-        if reduce_accum is not None and utils.is_dict_like(new_repr):
+        if reduce_accum is not None:
             # merge new node_repr with reduce output
             reduce_accum.update(new_repr)
             new_repr = reduce_accum

@@ -877,6 +864,8 @@ class DGLGraph(object):
     def apply_edges(self, u=None, v=None, apply_edge_func="default", eid=None):
         """Apply the function on edge representations.

+        Applying a None function will be ignored.
+
         Parameters
         ----------
         u : optional, int, iterable of int, tensor

@@ -893,7 +882,6 @@ class DGLGraph(object):
         if not apply_edge_func:
             # Skip none function call.
             return
-
         if eid is None:
             new_repr = apply_edge_func(self.get_e_repr(u, v))
             self.set_e_repr(new_repr, u, v)

@@ -914,9 +902,8 @@ class DGLGraph(object):
         The message function can be any of the pre-defined functions
         ('from_src').

-        Currently, we require the message functions of consecutive send's and
-        send_on's to return the same keys. Otherwise the behavior will be
-        undefined.
+        Currently, we require the message functions of consecutive send's to
+        return the same keys. Otherwise the behavior will be undefined.

         Parameters
         ----------

@@ -964,10 +951,7 @@ class DGLGraph(object):
         edge_reprs = self.get_e_repr_by_id(eid)
         msgs = message_func(src_reprs, edge_reprs)
         self._msg_graph.add_edges(u, v)
-        if utils.is_dict_like(msgs):
-            self._msg_frame.append(msgs)
-        else:
-            self._msg_frame.append({__MSG__ : msgs})
+        self._msg_frame.append(msgs)

         # TODO(minjie): Fix these codes in next PR.
         """

@@ -1061,7 +1045,6 @@ class DGLGraph(object):
         v = utils.toindex(v)
         u, v = utils.edge_broadcasting(u, v)
         _, _, eid = self._graph.edge_ids(u, v)
-
         # call the UDF
         src_reprs = self.get_n_repr(u)
         dst_reprs = self.get_n_repr(v)

@@ -1148,25 +1131,19 @@ class DGLGraph(object):
                 msg_shape = F.shape(msg)
                 new_shape = (bkt_len, deg) + msg_shape[1:]
                 return F.reshape(msg, new_shape)
-            if len(in_msgs) == 1 and __MSG__ in in_msgs:
-                reshaped_in_msgs = _reshape_fn(in_msgs[__MSG__])
-            else:
-                reshaped_in_msgs = utils.LazyDict(
-                        lambda key: _reshape_fn(in_msgs[key]), self._msg_frame.schemes)
+            reshaped_in_msgs = utils.LazyDict(
+                    lambda key: _reshape_fn(in_msgs[key]), self._msg_frame.schemes)
             reordered_v.append(v_bkt.tousertensor())
             new_reprs.append(reduce_func(dst_reprs, reshaped_in_msgs))

-        # TODO: clear partial messages
+        # TODO(minjie): clear partial messages
         self.reset_messages()

         # Pack all reducer results together
         reordered_v = F.pack(reordered_v)
-        if utils.is_dict_like(new_reprs[0]):
-            keys = new_reprs[0].keys()
-            new_reprs = {key : F.pack([repr[key] for repr in new_reprs])
-                         for key in keys}
-        else:
-            new_reprs = {__REPR__ : F.pack(new_reprs)}
+        keys = new_reprs[0].keys()
+        new_reprs = {key : F.pack([repr[key] for repr in new_reprs])
+                     for key in keys}
         if v_is_all and not has_zero_degree:
             # First do reorder and then replace the whole column.

@@ -1237,15 +1214,13 @@ class DGLGraph(object):
         if executor:
             new_reprs = executor.run()
-            if not utils.is_dict_like(new_reprs):
-                new_reprs = {__REPR__ : new_reprs}
             unique_v = executor.recv_nodes
             self._apply_nodes(unique_v, apply_node_func, reduce_accum=new_reprs)
         elif eid is not None:
             _, v, _ = self._graph.find_edges(eid)
             unique_v = utils.toindex(F.unique(v.tousertensor()))
-            # TODO: replace with the new DegreeBucketingScheduler
+            # TODO(quan): replace with the new DegreeBucketingScheduler
             self.send(eid=eid, message_func=message_func)
             self.recv(unique_v, reduce_func, apply_node_func)
         else:

@@ -1261,10 +1236,7 @@ class DGLGraph(object):
             edge_reprs = self.get_e_repr(u, v)
             msgs = message_func(src_reprs, edge_reprs)
             msg_frame = FrameRef()
-            if utils.is_dict_like(msgs):
-                msg_frame.append(msgs)
-            else:
-                msg_frame.append({__MSG__ : msgs})
+            msg_frame.append(msgs)
             # recv with degree bucketing
             executor = scheduler.get_recv_executor(graph=self,

@@ -1353,8 +1325,6 @@ class DGLGraph(object):
                 "update_all", self, message_func=message_func, reduce_func=reduce_func)
         if executor:
             new_reprs = executor.run()
-            if not utils.is_dict_like(new_reprs):
-                new_reprs = {__REPR__ : new_reprs}
             self._apply_nodes(ALL, apply_node_func, reduce_accum=new_reprs)
         else:
             self.send(ALL, ALL, message_func)

@@ -1387,7 +1357,7 @@ class DGLGraph(object):
             Arguments for pre-defined iterators.
         """
         if isinstance(traverser, str):
-            # TODO Call pre-defined routine to unroll the computation.
+            # TODO(minjie): Call pre-defined routine to unroll the computation.
             raise RuntimeError('Not implemented.')
         else:
             # NOTE: the iteration can return multiple edges at each step.
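With these checks, feeding a bare tensor or a mis-sized feature now raises DGLError instead of tripping an assert or being rewrapped under __REPR__. A sketch of the new failure modes (assuming DGLError is importable from the top-level dgl package, as the __init__.py hunk's context line suggests):

    import torch as th
    import dgl
    from dgl import DGLError

    g = dgl.DGLGraph()
    g.add_nodes(4)
    g.set_n_repr({'h': th.zeros(4, 2)})       # OK: dict of name -> tensor

    try:
        g.set_n_repr(th.zeros(4, 2))          # bare tensor is rejected now
    except DGLError:
        pass

    try:
        g.set_n_repr({'h': th.zeros(3, 2)})   # first dim must match len(u)
    except DGLError:
        pass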
python/dgl/scheduler.py

@@ -3,7 +3,7 @@ from __future__ import absolute_import
 import numpy as np

-from .base import ALL, __MSG__, __REPR__
+from .base import ALL, DGLError
 from . import backend as F
 from .function import message as fmsg
 from .function import reducer as fred

@@ -111,7 +111,15 @@ def light_degree_bucketing_for_graph(graph):
 class Executor(object):
+    """Base class for executing graph computation."""
+
     def run(self):
+        """Run this executor.
+
+        This should return the new node features.
+
+        TODO(minjie): extend this to support computation on edges.
+        """
         raise NotImplementedError

 class SPMVOperator(Executor):

@@ -126,10 +134,7 @@ class SPMVOperator(Executor):
     def run(self):
         # get src col
-        if self.src_field is None:
-            srccol = self.node_repr
-        else:
-            srccol = self.node_repr[self.src_field]
+        srccol = self.node_repr[self.src_field]
         ctx = F.get_context(srccol)
         # build adjmat

@@ -142,10 +147,7 @@ class SPMVOperator(Executor):
             dstcol = F.squeeze(dstcol)
         else:
             dstcol = F.spmm(adjmat, srccol)
-        if self.dst_field is None:
-            return dstcol
-        else:
-            return {self.dst_field : dstcol}
+        return {self.dst_field : dstcol}

 # FIXME: refactorize in scheduler/executor redesign

@@ -180,20 +182,14 @@ class DegreeBucketingExecutor(Executor):
             msg_shape = F.shape(msg)
             new_shape = (len(vv), deg) + msg_shape[1:]
             return F.reshape(msg, new_shape)
-        if len(in_msgs) == 1 and __MSG__ in in_msgs:
-            reshaped_in_msgs = _reshape_fn(in_msgs[__MSG__])
-        else:
-            reshaped_in_msgs = utils.LazyDict(
-                    lambda key: _reshape_fn(in_msgs[key]), self.msg_frame.schemes)
+        reshaped_in_msgs = utils.LazyDict(
+                lambda key: _reshape_fn(in_msgs[key]), self.msg_frame.schemes)
         new_reprs.append(self.rfunc(dst_reprs, reshaped_in_msgs))

         # Pack all reducer results together
-        if utils.is_dict_like(new_reprs[0]):
-            keys = new_reprs[0].keys()
-            new_reprs = {key : F.pack([repr[key] for repr in new_reprs])
-                         for key in keys}
-        else:
-            new_reprs = {__REPR__ : F.pack(new_reprs)}
+        keys = new_reprs[0].keys()
+        new_reprs = {key : F.pack([repr[key] for repr in new_reprs])
+                     for key in keys}
         return new_reprs

@@ -249,12 +245,6 @@ class UpdateAllExecutor(BasicExecutor):
         self._graph_shape = None
         self._recv_nodes = None

-    @property
-    def graph_idx(self):
-        if self._graph_idx is None:
-            self._graph_idx = self.g._graph.adjacency_matrix()
-        return self._graph_idx
-
     @property
     def graph_shape(self):
         if self._graph_shape is None:

@@ -280,16 +270,13 @@ class UpdateAllExecutor(BasicExecutor):
     def _adj_build_fn(self, edge_field, ctx, use_edge_feat):
         if use_edge_feat:
-            if edge_field is None:
-                dat = self.edge_repr
-            else:
-                dat = self.edge_repr[edge_field]
+            dat = self.edge_repr[edge_field]
             dat = F.squeeze(dat)
             # TODO(minjie): should not directly use _indices
-            idx = self.graph_idx.get(ctx)._indices()
+            idx = self.g.adjacency_matrix(ctx)._indices()
             adjmat = F.sparse_tensor(idx, dat, self.graph_shape)
         else:
-            adjmat = self.graph_idx.get(ctx)
+            adjmat = self.g.adjacency_matrix(ctx)
         return adjmat

@@ -351,10 +338,7 @@ class SendRecvExecutor(BasicExecutor):
     def _adj_build_fn(self, edge_field, ctx, use_edge_feat):
         if use_edge_feat:
-            if edge_field is None:
-                dat = self.edge_repr
-            else:
-                dat = self.edge_repr[edge_field]
+            dat = self.edge_repr[edge_field]
             dat = F.squeeze(dat)
         else:
             dat = F.ones((len(self.u),))

@@ -386,9 +370,8 @@ class BundledExecutor(BasicExecutor):
         func_pairs = []
         for rfn in rfunc.fn_list:
             mfn = out2mfunc.get(rfn.msg_field, None)
-            # field check
-            assert mfn is not None, \
-                "cannot find message func for reduce func in-field {}".format(rfn.msg_field)
+            if mfn is None:
+                raise DGLError('Cannot find message field "%s".' % rfn.msg_field)
             func_pairs.append((mfn, rfn))
         return func_pairs

@@ -409,7 +392,6 @@ class BundledUpdateAllExecutor(BundledExecutor, UpdateAllExecutor):
         self._init_state()
         BundledExecutor.__init__(self, graph, mfunc, rfunc)
-
 class BundledSendRecvExecutor(BundledExecutor, SendRecvExecutor):
     def __init__(self, graph, src, dst, mfunc, rfunc):
         self._init_state(src, dst)
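The BundledExecutor change tightens what happens when several builtins are fused: each reducer's msg field must match the out field of some message function, otherwise the executor now raises DGLError instead of an assert. A sketch of a well-paired bundle (this diff shows the Bundled* classes but not the call path that wraps lists into them, so treat the update_all call below as an assumption about this commit's API):

    import torch as th
    import dgl
    import dgl.function as fn

    g = dgl.DGLGraph()
    g.add_nodes(2)
    g.add_edge(0, 1)
    g.set_n_repr({'h': th.randn(2, 3)})
    g.set_e_repr({'w': th.randn(1, 1)})

    mfuncs = [fn.copy_src(src='h', out='m1'),
              fn.copy_edge(edge='w', out='m2')]
    rfuncs = [fn.sum(msg='m1', out='h1'),   # paired with copy_src via 'm1'
              fn.max(msg='m2', out='h2')]   # paired with copy_edge via 'm2'

    # A reducer reading a field no message produces, e.g. fn.sum(msg='m3',
    # out='h3'), would now trigger: DGLError('Cannot find message field "m3".')
    g.update_all(mfuncs, rfuncs)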
tests/pytorch/test_basics.py

@@ -209,14 +209,13 @@ def test_reduce_0deg():
     g.add_edge(3, 0)
     g.add_edge(4, 0)
     def _message(src, edge):
-        return src
+        return {'m' : src['h']}
     def _reduce(node, msgs):
-        assert msgs is not None
-        return node + msgs.sum(1)
+        return {'h' : node['h'] + msgs['m'].sum(1)}

     old_repr = th.randn(5, 5)
-    g.set_n_repr(old_repr)
+    g.set_n_repr({'h' : old_repr})
     g.update_all(_message, _reduce)
-    new_repr = g.get_n_repr()
+    new_repr = g.get_n_repr()['h']

     assert th.allclose(new_repr[1:], old_repr[1:])
     assert th.allclose(new_repr[0], old_repr.sum(0))

@@ -226,25 +225,25 @@ def test_pull_0deg():
     g.add_nodes(2)
     g.add_edge(0, 1)
     def _message(src, edge):
-        return src
+        return {'m' : src['h']}
     def _reduce(node, msgs):
-        assert msgs is not None
-        return msgs.sum(1)
+        return {'h' : msgs['m'].sum(1)}

     old_repr = th.randn(2, 5)
-    g.set_n_repr(old_repr)
+    g.set_n_repr({'h' : old_repr})
     g.pull(0, _message, _reduce)
-    new_repr = g.get_n_repr()
+    new_repr = g.get_n_repr()['h']
     assert th.allclose(new_repr[0], old_repr[0])
     assert th.allclose(new_repr[1], old_repr[1])
     g.pull(1, _message, _reduce)
-    new_repr = g.get_n_repr()
+    new_repr = g.get_n_repr()['h']
     assert th.allclose(new_repr[1], old_repr[0])

     old_repr = th.randn(2, 5)
-    g.set_n_repr(old_repr)
+    g.set_n_repr({'h' : old_repr})
     g.pull([0, 1], _message, _reduce)
-    new_repr = g.get_n_repr()
+    new_repr = g.get_n_repr()['h']
     assert th.allclose(new_repr[0], old_repr[0])
     assert th.allclose(new_repr[1], old_repr[0])
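The updated tests spell out the new contract for user-defined functions as well: messages and reduced values are dicts keyed by feature name. A minimal sketch distilled from test_pull_0deg above (assuming the same API):

    import torch as th
    import dgl

    g = dgl.DGLGraph()
    g.add_nodes(2)
    g.add_edge(0, 1)
    g.set_n_repr({'h': th.randn(2, 5)})

    def _message(src, edge):
        return {'m': src['h']}            # dict in, dict out

    def _reduce(node, msgs):
        return {'h': msgs['m'].sum(1)}    # msgs['m'] has shape (nodes, deg, 5)

    g.pull(1, _message, _reduce)
    h = g.get_n_repr()['h']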
tests/pytorch/test_basics_anonymous.py deleted 100644 → 0

-import torch as th
-from torch.autograd import Variable
-import numpy as np
-from dgl.graph import DGLGraph, __REPR__
-
-D = 32
-reduce_msg_shapes = set()
-
-def check_eq(a, b):
-    assert a.shape == b.shape
-    assert th.sum(a == b) == int(np.prod(list(a.shape)))
-
-def message_func(hu, e_uv):
-    assert len(hu.shape) == 2
-    assert hu.shape[1] == D
-    return hu
-
-def reduce_func(hv, msgs):
-    reduce_msg_shapes.add(tuple(msgs.shape))
-    assert len(msgs.shape) == 3
-    assert msgs.shape[2] == D
-    return hv + th.sum(msgs, 1)
-
-def generate_graph(grad=False):
-    g = DGLGraph()
-    g.add_nodes(10)
-    # create a graph where 0 is the source and 9 is the sink
-    for i in range(1, 9):
-        g.add_edge(0, i)
-        g.add_edge(i, 9)
-    # add a back flow from 9 to 0
-    g.add_edge(9, 0)
-    ncol = Variable(th.randn(10, D), requires_grad=grad)
-    ecol = Variable(th.randn(17, D), requires_grad=grad)
-    g.set_n_repr(ncol)
-    g.set_e_repr(ecol)
-    return g
-
-def test_batch_setter_getter():
-    def _pfc(x):
-        return list(x.numpy()[:,0])
-    g = generate_graph()
-    # set all nodes
-    g.set_n_repr(th.zeros((10, D)))
-    assert _pfc(g.get_n_repr()) == [0.] * 10
-    # pop nodes
-    assert _pfc(g.pop_n_repr()) == [0.] * 10
-    assert len(g.get_n_repr()) == 0
-    g.set_n_repr(th.zeros((10, D)))
-    # set partial nodes
-    u = th.tensor([1, 3, 5])
-    g.set_n_repr(th.ones((3, D)), u)
-    assert _pfc(g.get_n_repr()) == [0., 1., 0., 1., 0., 1., 0., 0., 0., 0.]
-    # get partial nodes
-    u = th.tensor([1, 2, 3])
-    assert _pfc(g.get_n_repr(u)) == [1., 0., 1.]
-    '''
-    s, d, eid
-    0, 1, 0
-    1, 9, 1
-    0, 2, 2
-    2, 9, 3
-    0, 3, 4
-    3, 9, 5
-    0, 4, 6
-    4, 9, 7
-    0, 5, 8
-    5, 9, 9
-    0, 6, 10
-    6, 9, 11
-    0, 7, 12
-    7, 9, 13
-    0, 8, 14
-    8, 9, 15
-    9, 0, 16
-    '''
-    # set all edges
-    g.set_e_repr(th.zeros((17, D)))
-    assert _pfc(g.get_e_repr()) == [0.] * 17
-    # pop edges
-    assert _pfc(g.pop_e_repr()) == [0.] * 17
-    assert len(g.get_e_repr()) == 0
-    g.set_e_repr(th.zeros((17, D)))
-    # set partial edges (many-many)
-    u = th.tensor([0, 0, 2, 5, 9])
-    v = th.tensor([1, 3, 9, 9, 0])
-    g.set_e_repr(th.ones((5, D)), u, v)
-    truth = [0.] * 17
-    truth[0] = truth[4] = truth[3] = truth[9] = truth[16] = 1.
-    assert _pfc(g.get_e_repr()) == truth
-    # set partial edges (many-one)
-    u = th.tensor([3, 4, 6])
-    v = th.tensor([9])
-    g.set_e_repr(th.ones((3, D)), u, v)
-    truth[5] = truth[7] = truth[11] = 1.
-    assert _pfc(g.get_e_repr()) == truth
-    # set partial edges (one-many)
-    u = th.tensor([0])
-    v = th.tensor([4, 5, 6])
-    g.set_e_repr(th.ones((3, D)), u, v)
-    truth[6] = truth[8] = truth[10] = 1.
-    assert _pfc(g.get_e_repr()) == truth
-    # get partial edges (many-many)
-    u = th.tensor([0, 6, 0])
-    v = th.tensor([6, 9, 7])
-    assert _pfc(g.get_e_repr(u, v)) == [1., 1., 0.]
-    # get partial edges (many-one)
-    u = th.tensor([5, 6, 7])
-    v = th.tensor([9])
-    assert _pfc(g.get_e_repr(u, v)) == [1., 1., 0.]
-    # get partial edges (one-many)
-    u = th.tensor([0])
-    v = th.tensor([3, 4, 5])
-    assert _pfc(g.get_e_repr(u, v)) == [1., 1., 1.]
-
-def test_batch_setter_autograd():
-    g = generate_graph(grad=True)
-    h1 = g.get_n_repr()
-    # partial set
-    v = th.tensor([1, 2, 8])
-    hh = Variable(th.zeros((len(v), D)), requires_grad=True)
-    g.set_n_repr(hh, v)
-    h2 = g.get_n_repr()
-    h2.backward(th.ones((10, D)) * 2)
-    check_eq(h1.grad[:,0], th.tensor([2., 0., 0., 2., 2., 2., 2., 2., 0., 2.]))
-    check_eq(hh.grad[:,0], th.tensor([2., 2., 2.]))
-
-def test_batch_send():
-    g = generate_graph()
-    def _fmsg(hu, edge):
-        assert hu.shape == (5, D)
-        return hu
-    g.register_message_func(_fmsg)
-    # many-many send
-    u = th.tensor([0, 0, 0, 0, 0])
-    v = th.tensor([1, 2, 3, 4, 5])
-    g.send(u, v)
-    # one-many send
-    u = th.tensor([0])
-    v = th.tensor([1, 2, 3, 4, 5])
-    g.send(u, v)
-    # many-one send
-    u = th.tensor([1, 2, 3, 4, 5])
-    v = th.tensor([9])
-    g.send(u, v)
-
-def test_batch_recv():
-    g = generate_graph()
-    g.register_message_func(message_func)
-    g.register_reduce_func(reduce_func)
-    u = th.tensor([0, 0, 0, 4, 5, 6])
-    v = th.tensor([1, 2, 3, 9, 9, 9])
-    reduce_msg_shapes.clear()
-    g.send(u, v)
-    g.recv(th.unique(v))
-    assert(reduce_msg_shapes == {(1, 3, D), (3, 1, D)})
-    reduce_msg_shapes.clear()
-
-def test_update_routines():
-    g = generate_graph()
-    g.register_message_func(message_func)
-    g.register_reduce_func(reduce_func)
-
-    # send_and_recv
-    reduce_msg_shapes.clear()
-    u = th.tensor([0, 0, 0, 4, 5, 6])
-    v = th.tensor([1, 2, 3, 9, 9, 9])
-    g.send_and_recv(u, v)
-    assert(reduce_msg_shapes == {(1, 3, D), (3, 1, D)})
-    reduce_msg_shapes.clear()
-
-    # pull
-    v = th.tensor([1, 2, 3, 9])
-    reduce_msg_shapes.clear()
-    g.pull(v)
-    assert(reduce_msg_shapes == {(1, 8, D), (3, 1, D)})
-    reduce_msg_shapes.clear()
-
-    # push
-    v = th.tensor([0, 1, 2, 3])
-    reduce_msg_shapes.clear()
-    g.push(v)
-    assert(reduce_msg_shapes == {(1, 3, D), (8, 1, D)})
-    reduce_msg_shapes.clear()
-
-    # update_all
-    reduce_msg_shapes.clear()
-    g.update_all()
-    assert(reduce_msg_shapes == {(1, 8, D), (9, 1, D)})
-    reduce_msg_shapes.clear()
-
-if __name__ == '__main__':
-    test_batch_setter_getter()
-    test_batch_setter_autograd()
-    test_batch_send()
-    test_batch_recv()
-    test_update_routines()
tests/pytorch/test_batched_graph.py

@@ -18,8 +18,8 @@ def tree1():
     g.add_edge(4, 1)
     g.add_edge(1, 0)
     g.add_edge(2, 0)
-    g.set_n_repr(th.Tensor([0, 1, 2, 3, 4]))
-    g.set_e_repr(th.randn(4, 10))
+    g.set_n_repr({'h' : th.Tensor([0, 1, 2, 3, 4])})
+    g.set_e_repr({'h' : th.randn(4, 10)})
     return g

 def tree2():

@@ -37,17 +37,17 @@ def tree2():
     g.add_edge(0, 4)
     g.add_edge(4, 1)
     g.add_edge(3, 1)
-    g.set_n_repr(th.Tensor([0, 1, 2, 3, 4]))
-    g.set_e_repr(th.randn(4, 10))
+    g.set_n_repr({'h' : th.Tensor([0, 1, 2, 3, 4])})
+    g.set_e_repr({'h' : th.randn(4, 10)})
     return g

 def test_batch_unbatch():
     t1 = tree1()
     t2 = tree2()
-    n1 = t1.get_n_repr()
-    n2 = t2.get_n_repr()
-    e1 = t1.get_e_repr()
-    e2 = t2.get_e_repr()
+    n1 = t1.get_n_repr()['h']
+    n2 = t2.get_n_repr()['h']
+    e1 = t1.get_e_repr()['h']
+    e2 = t2.get_e_repr()['h']
     bg = dgl.batch([t1, t2])
     assert bg.number_of_nodes() == 10

@@ -57,10 +57,10 @@ def test_batch_unbatch():
     assert bg.batch_num_edges == [4, 4]
     tt1, tt2 = dgl.unbatch(bg)
-    assert th.allclose(t1.get_n_repr(), tt1.get_n_repr())
-    assert th.allclose(t1.get_e_repr(), tt1.get_e_repr())
-    assert th.allclose(t2.get_n_repr(), tt2.get_n_repr())
-    assert th.allclose(t2.get_e_repr(), tt2.get_e_repr())
+    assert th.allclose(t1.get_n_repr()['h'], tt1.get_n_repr()['h'])
+    assert th.allclose(t1.get_e_repr()['h'], tt1.get_e_repr()['h'])
+    assert th.allclose(t2.get_n_repr()['h'], tt2.get_n_repr()['h'])
+    assert th.allclose(t2.get_e_repr()['h'], tt2.get_e_repr()['h'])

 def test_batch_unbatch1():
     t1 = tree1()

@@ -74,20 +74,20 @@ def test_batch_unbatch1():
     assert b2.batch_num_edges == [4, 4, 4]
     s1, s2, s3 = dgl.unbatch(b2)
-    assert th.allclose(t2.get_n_repr(), s1.get_n_repr())
-    assert th.allclose(t2.get_e_repr(), s1.get_e_repr())
-    assert th.allclose(t1.get_n_repr(), s2.get_n_repr())
-    assert th.allclose(t1.get_e_repr(), s2.get_e_repr())
-    assert th.allclose(t2.get_n_repr(), s3.get_n_repr())
-    assert th.allclose(t2.get_e_repr(), s3.get_e_repr())
+    assert th.allclose(t2.get_n_repr()['h'], s1.get_n_repr()['h'])
+    assert th.allclose(t2.get_e_repr()['h'], s1.get_e_repr()['h'])
+    assert th.allclose(t1.get_n_repr()['h'], s2.get_n_repr()['h'])
+    assert th.allclose(t1.get_e_repr()['h'], s2.get_e_repr()['h'])
+    assert th.allclose(t2.get_n_repr()['h'], s3.get_n_repr()['h'])
+    assert th.allclose(t2.get_e_repr()['h'], s3.get_e_repr()['h'])

 def test_batch_sendrecv():
     t1 = tree1()
     t2 = tree2()
     bg = dgl.batch([t1, t2])
-    bg.register_message_func(lambda src, edge: src)
-    bg.register_reduce_func(lambda node, msgs: th.sum(msgs, 1))
+    bg.register_message_func(lambda src, edge: {'m' : src['h']})
+    bg.register_reduce_func(lambda node, msgs: {'h' : th.sum(msgs['m'], 1)})
     u = [3, 4, 2 + 5, 0 + 5]
     v = [1, 1, 4 + 5, 4 + 5]

@@ -95,8 +95,8 @@ def test_batch_sendrecv():
     bg.recv(v)
     t1, t2 = dgl.unbatch(bg)
-    assert t1.get_n_repr()[1] == 7
-    assert t2.get_n_repr()[4] == 2
+    assert t1.get_n_repr()['h'][1] == 7
+    assert t2.get_n_repr()['h'][4] == 2

 def test_batch_propagate():

@@ -104,8 +104,8 @@ def test_batch_propagate():
     t2 = tree2()
     bg = dgl.batch([t1, t2])
-    bg.register_message_func(lambda src, edge: src)
-    bg.register_reduce_func(lambda node, msgs: th.sum(msgs, 1))
+    bg.register_message_func(lambda src, edge: {'m' : src['h']})
+    bg.register_reduce_func(lambda node, msgs: {'h' : th.sum(msgs['m'], 1)})
     # get leaves.
     order = []

@@ -123,23 +123,23 @@ def test_batch_propagate():
     bg.propagate(traverser=order)
     t1, t2 = dgl.unbatch(bg)
-    assert t1.get_n_repr()[0] == 9
-    assert t2.get_n_repr()[1] == 5
+    assert t1.get_n_repr()['h'][0] == 9
+    assert t2.get_n_repr()['h'][1] == 5

 def test_batched_edge_ordering():
     g1 = dgl.DGLGraph()
     g1.add_nodes(6)
     g1.add_edges([4, 4, 2, 2, 0], [5, 3, 3, 1, 1])
     e1 = th.randn(5, 10)
-    g1.set_e_repr(e1)
+    g1.set_e_repr({'h' : e1})
     g2 = dgl.DGLGraph()
     g2.add_nodes(6)
     g2.add_edges([0, 1, 2, 5, 4, 5], [1, 2, 3, 4, 3, 0])
     e2 = th.randn(6, 10)
-    g2.set_e_repr(e2)
+    g2.set_e_repr({'h' : e2})
     g = dgl.batch([g1, g2])
-    r1 = g.get_e_repr()[g.edge_id(4, 5)]
-    r2 = g1.get_e_repr()[g1.edge_id(4, 5)]
+    r1 = g.get_e_repr()['h'][g.edge_id(4, 5)]
+    r2 = g1.get_e_repr()['h'][g1.edge_id(4, 5)]
     assert th.equal(r1, r2)

 def test_batch_no_edge():
tests/pytorch/test_function.py View file @ 3e76bcc0
 import torch as th
 import dgl
 import dgl.function as fn
-from dgl.graph import __REPR__

 def generate_graph():
     g = dgl.DGLGraph()
...
@@ -37,18 +36,9 @@ def generate_graph1():
     g.set_e_repr(h)
     return g

-def reducer_msg(node, msgs):
-    return th.sum(msgs['m'], 1)
-def reducer_out(node, msgs):
-    return {'h': th.sum(msgs, 1)}
 def reducer_both(node, msgs):
     return {'h': th.sum(msgs['m'], 1)}
-def reducer_none(node, msgs):
-    return th.sum(msgs, 1)

 def test_copy_src():
     # copy_src with both fields
     g = generate_graph()
...
@@ -58,30 +48,6 @@ def test_copy_src():
     assert th.allclose(g.get_n_repr()['h'],
             th.tensor([10., 1., 1., 1., 1., 1., 1., 1., 1., 44.]))
-    # copy_src with only src field; the out field should use anonymous repr
-    g = generate_graph()
-    g.register_message_func(fn.copy_src(src='h'))
-    g.register_reduce_func(reducer_out)
-    g.update_all()
-    assert th.allclose(g.get_n_repr()['h'],
-            th.tensor([10., 1., 1., 1., 1., 1., 1., 1., 1., 44.]))
-    # copy_src with no src field; should use anonymous repr
-    g = generate_graph1()
-    g.register_message_func(fn.copy_src(out='m'))
-    g.register_reduce_func(reducer_both)
-    g.update_all()
-    assert th.allclose(g.get_n_repr()['h'],
-            th.tensor([10., 1., 1., 1., 1., 1., 1., 1., 1., 44.]))
-    # copy src with no fields;
-    g = generate_graph1()
-    g.register_message_func(fn.copy_src())
-    g.register_reduce_func(reducer_out)
-    g.update_all()
-    assert th.allclose(g.get_n_repr()['h'],
-            th.tensor([10., 1., 1., 1., 1., 1., 1., 1., 1., 44.]))

 def test_copy_edge():
     # copy_edge with both fields
     g = generate_graph()
...
@@ -91,30 +57,6 @@ def test_copy_edge():
     assert th.allclose(g.get_n_repr()['h'],
             th.tensor([10., 1., 1., 1., 1., 1., 1., 1., 1., 44.]))
-    # copy_edge with only edge field; the out field should use anonymous repr
-    g = generate_graph()
-    g.register_message_func(fn.copy_edge(edge='h'))
-    g.register_reduce_func(reducer_out)
-    g.update_all()
-    assert th.allclose(g.get_n_repr()['h'],
-            th.tensor([10., 1., 1., 1., 1., 1., 1., 1., 1., 44.]))
-    # copy_edge with no edge field; should use anonymous repr
-    g = generate_graph1()
-    g.register_message_func(fn.copy_edge(out='m'))
-    g.register_reduce_func(reducer_both)
-    g.update_all()
-    assert th.allclose(g.get_n_repr()['h'],
-            th.tensor([10., 1., 1., 1., 1., 1., 1., 1., 1., 44.]))
-    # copy edge with no fields;
-    g = generate_graph1()
-    g.register_message_func(fn.copy_edge())
-    g.register_reduce_func(reducer_out)
-    g.update_all()
-    assert th.allclose(g.get_n_repr()['h'],
-            th.tensor([10., 1., 1., 1., 1., 1., 1., 1., 1., 44.]))

 def test_src_mul_edge():
     # src_mul_edge with all fields
     g = generate_graph()
...
@@ -124,34 +66,6 @@ def test_src_mul_edge():
     assert th.allclose(g.get_n_repr()['h'],
             th.tensor([100., 1., 1., 1., 1., 1., 1., 1., 1., 284.]))
-    g = generate_graph()
-    g.register_message_func(fn.src_mul_edge(src='h', edge='h'))
-    g.register_reduce_func(reducer_out)
-    g.update_all()
-    assert th.allclose(g.get_n_repr()['h'],
-            th.tensor([100., 1., 1., 1., 1., 1., 1., 1., 1., 284.]))
-    g = generate_graph1()
-    g.register_message_func(fn.src_mul_edge(out='m'))
-    g.register_reduce_func(reducer_both)
-    g.update_all()
-    assert th.allclose(g.get_n_repr()['h'],
-            th.tensor([100., 1., 1., 1., 1., 1., 1., 1., 1., 284.]))
-    g = generate_graph1()
-    g.register_message_func(fn.src_mul_edge())
-    g.register_reduce_func(reducer_out)
-    g.update_all()
-    assert th.allclose(g.get_n_repr()['h'],
-            th.tensor([100., 1., 1., 1., 1., 1., 1., 1., 1., 284.]))
-    g = generate_graph1()
-    g.register_message_func(fn.src_mul_edge())
-    g.register_reduce_func(reducer_none)
-    g.update_all()
-    assert th.allclose(g.get_n_repr(),
-            th.tensor([100., 1., 1., 1., 1., 1., 1., 1., 1., 284.]))

 if __name__ == '__main__':
     test_copy_src()
     test_copy_edge()
...
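The builtins lose their defaults in the same way: every fn.copy_src / fn.copy_edge / fn.src_mul_edge must name its input field(s) and its out= message field, and fn.sum now takes msg= instead of msgs=. A short sketch of how the surviving spellings compose, assuming the signatures shown in this diff (the graph and values are made up):

import torch as th
import dgl
import dgl.function as fn

g = dgl.DGLGraph()
g.add_nodes(2)
g.add_edges([0, 1], [1, 0])              # 2-cycle: each node gets one message
g.set_n_repr({'h': th.ones(2, 4)})
g.set_e_repr({'w': 2 * th.ones(2, 4)})

# copy_src reads node field 'h' into message field 'm'; fn.sum reduces the
# 'm' messages into node field 'h_sum'.
g.update_all(fn.copy_src(src='h', out='m'),
             fn.sum(msg='m', out='h_sum'), None)
# src_mul_edge weights the source field by an edge field before sending.
g.update_all(fn.src_mul_edge(src='h', edge='w', out='m'),
             fn.sum(msg='m', out='h_wsum'), None)

assert th.allclose(g.get_n_repr()['h_sum'], th.ones(2, 4))
assert th.allclose(g.get_n_repr()['h_wsum'], 2 * th.ones(2, 4))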
tests/pytorch/test_line_graph.py View file @ 3e76bcc0
...
@@ -5,35 +5,31 @@ import dgl
 D = 5

-def check_eq(a, b):
-    return a.shape == b.shape and np.allclose(a.numpy(), b.numpy())

 def test_line_graph():
     N = 5
     G = dgl.DGLGraph(nx.star_graph(N))
-    G.set_e_repr(th.randn((2 * N, D)))
+    G.set_e_repr({'h': th.randn((2 * N, D))})
     n_edges = G.number_of_edges()
     L = G.line_graph(shared=True)
     assert L.number_of_nodes() == 2 * N
-    L.set_n_repr(th.randn((2 * N, D)))
+    L.set_n_repr({'h': th.randn((2 * N, D))})
     # update node features on line graph should reflect to edge features on
     # original graph.
     u = [0, 0, 2, 3]
     v = [1, 2, 0, 0]
     eid = G.edge_ids(u, v)
-    L.set_n_repr(th.zeros((4, D)), eid)
+    L.set_n_repr({'h': th.zeros((4, D))}, eid)
-    assert check_eq(G.get_e_repr(u, v), th.zeros((4, D)))
+    assert th.allclose(G.get_e_repr(u, v)['h'], th.zeros((4, D)))
     # adding a new node feature on line graph should also reflect to a new
     # edge feature on original graph
     data = th.randn(n_edges, D)
     L.set_n_repr({'w': data})
-    assert check_eq(G.get_e_repr()['w'], data)
+    assert th.allclose(G.get_e_repr()['w'], data)

 def test_no_backtracking():
     N = 5
     G = dgl.DGLGraph(nx.star_graph(N))
     G.set_e_repr(th.randn((2 * N, D)))
     L = G.line_graph(backtracking=False)
     assert L.number_of_nodes() == 2 * N
     for i in range(1, N):
...
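Aside: the updated assertions above still rely on line_graph(shared=True) aliasing L's node storage to G's edge storage; only the access style changed from positional to named. A condensed sketch of that behavior, under the same constants as the test and assuming the API at this commit:

import networkx as nx
import torch as th
import dgl

N, D = 5, 5
G = dgl.DGLGraph(nx.star_graph(N))
G.set_e_repr({'h': th.randn(2 * N, D)})

L = G.line_graph(shared=True)            # L's nodes alias G's edges
L.set_n_repr({'h': th.zeros(2 * N, D)})
# Writing node field 'h' on L is visible as edge field 'h' on G.
assert th.allclose(G.get_e_repr()['h'], th.zeros(2 * N, D))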
tests/pytorch/test_specialization.py View file @ 3e76bcc0
...
@@ -22,23 +22,23 @@ def generate_graph():
 def test_update_all():
     def _test(fld):
         def message_func(hu, edge):
-            return hu[fld]
+            return {'m': hu[fld]}
         def message_func_edge(hu, edge):
             if len(hu[fld].shape) == 1:
-                return hu[fld] * edge['e1']
+                return {'m': hu[fld] * edge['e1']}
             else:
-                return hu[fld] * edge['e2']
+                return {'m': hu[fld] * edge['e2']}
         def reduce_func(hv, msgs):
-            return {fld: th.sum(msgs, 1)}
+            return {fld: th.sum(msgs['m'], 1)}
         def apply_func(hu):
             return {fld: 2 * hu[fld]}
         g = generate_graph()
         # update all
         v1 = g.get_n_repr()[fld]
-        g.update_all(fn.copy_src(src=fld), fn.sum(out=fld), apply_func)
+        g.update_all(fn.copy_src(src=fld, out='m'), fn.sum(msg='m', out=fld), apply_func)
         v2 = g.get_n_repr()[fld]
         g.set_n_repr({fld: v1})
         g.update_all(message_func, reduce_func, apply_func)
...
@@ -46,12 +46,12 @@ def test_update_all():
         assert th.allclose(v2, v3)
         # update all with edge weights
         v1 = g.get_n_repr()[fld]
-        g.update_all(fn.src_mul_edge(src=fld, edge='e1'),
-                fn.sum(out=fld), apply_func)
+        g.update_all(fn.src_mul_edge(src=fld, edge='e1', out='m'),
+                fn.sum(msg='m', out=fld), apply_func)
         v2 = g.get_n_repr()[fld]
         g.set_n_repr({fld: v1})
-        g.update_all(fn.src_mul_edge(src=fld, edge='e2'),
-                fn.sum(out=fld), apply_func)
+        g.update_all(fn.src_mul_edge(src=fld, edge='e2', out='m'),
+                fn.sum(msg='m', out=fld), apply_func)
         v3 = g.get_n_repr()[fld]
         g.set_n_repr({fld: v1})
         g.update_all(message_func_edge, reduce_func, apply_func)
...
@@ -68,42 +68,40 @@ def test_send_and_recv():
     v = th.tensor([1, 2, 3, 9, 9, 0])
     def _test(fld):
         def message_func(hu, edge):
-            return hu[fld]
+            return {'m': hu[fld]}
         def message_func_edge(hu, edge):
             if len(hu[fld].shape) == 1:
-                return hu[fld] * edge['e1']
+                return {'m': hu[fld] * edge['e1']}
             else:
-                return hu[fld] * edge['e2']
+                return {'m': hu[fld] * edge['e2']}
         def reduce_func(hv, msgs):
-            return {fld: th.sum(msgs, 1)}
+            return {fld: th.sum(msgs['m'], 1)}
         def apply_func(hu):
             return {fld: 2 * hu[fld]}
         g = generate_graph()
         # send and recv
         v1 = g.get_n_repr()[fld]
-        g.send_and_recv(u, v, fn.copy_src(src=fld),
-                fn.sum(out=fld), apply_func)
+        g.send_and_recv(u, v, fn.copy_src(src=fld, out='m'),
+                fn.sum(msg='m', out=fld), apply_func)
         v2 = g.get_n_repr()[fld]
         g.set_n_repr({fld: v1})
         g.send_and_recv(u, v, message_func,
                 reduce_func, apply_func)
         v3 = g.get_n_repr()[fld]
         assert th.allclose(v2, v3)
         # send and recv with edge weights
         v1 = g.get_n_repr()[fld]
-        g.send_and_recv(u, v, fn.src_mul_edge(src=fld, edge='e1'),
-                fn.sum(out=fld), apply_func)
+        g.send_and_recv(u, v, fn.src_mul_edge(src=fld, edge='e1', out='m'),
+                fn.sum(msg='m', out=fld), apply_func)
         v2 = g.get_n_repr()[fld]
         g.set_n_repr({fld: v1})
-        g.send_and_recv(u, v, fn.src_mul_edge(src=fld, edge='e2'),
-                fn.sum(out=fld), apply_func)
+        g.send_and_recv(u, v, fn.src_mul_edge(src=fld, edge='e2', out='m'),
+                fn.sum(msg='m', out=fld), apply_func)
         v3 = g.get_n_repr()[fld]
         g.set_n_repr({fld: v1})
         g.send_and_recv(u, v, message_func_edge,
                 reduce_func, apply_func)
         v4 = g.get_n_repr()[fld]
         assert th.allclose(v2, v3)
         assert th.allclose(v3, v4)
...
@@ -127,19 +125,19 @@ def test_update_all_multi_fn():
     fld = 'f2'
     # update all, mix of builtin and UDF
     g.update_all([fn.copy_src(src=fld, out='m1'), message_func],
-            [fn.sum(msgs='m1', out='v1'), reduce_func],
+            [fn.sum(msg='m1', out='v1'), reduce_func],
             None)
     v1 = g.get_n_repr()['v1']
     v2 = g.get_n_repr()['v2']
     assert th.allclose(v1, v2)
     # run builtin with single message and reduce
-    g.update_all(fn.copy_src(src=fld), fn.sum(out='v1'), None)
+    g.update_all(fn.copy_src(src=fld, out='m'), fn.sum(msg='m', out='v1'), None)
     v1 = g.get_n_repr()['v1']
     assert th.allclose(v1, v2)
-    # 1 message, 2 reduces, using anonymous repr
-    g.update_all(fn.copy_src(src=fld), [fn.sum(out='v2'), fn.sum(out='v3')], None)
+    # 1 message, 2 reduces
+    g.update_all(fn.copy_src(src=fld, out='m'), [fn.sum(msg='m', out='v2'), fn.sum(msg='m', out='v3')], None)
     v2 = g.get_n_repr()['v2']
     v3 = g.get_n_repr()['v3']
     assert th.allclose(v1, v2)
...
@@ -147,7 +145,7 @@ def test_update_all_multi_fn():
     # update all with edge weights, 2 message, 3 reduces
     g.update_all([fn.src_mul_edge(src=fld, edge='e1', out='m1'),
                   fn.src_mul_edge(src=fld, edge='e2', out='m2')],
-            [fn.sum(msgs='m1', out='v1'), fn.sum(msgs='m2', out='v2'), fn.sum(msgs='m1', out='v3')],
+            [fn.sum(msg='m1', out='v1'), fn.sum(msg='m2', out='v2'), fn.sum(msg='m1', out='v3')],
             None)
     v1 = g.get_n_repr()['v1']
     v2 = g.get_n_repr()['v2']
...
@@ -181,20 +179,23 @@ def test_send_and_recv_multi_fn():
     # send and recv, mix of builtin and UDF
     g.send_and_recv(u, v,
             [fn.copy_src(src=fld, out='m1'), message_func],
-            [fn.sum(msgs='m1', out='v1'), reduce_func],
+            [fn.sum(msg='m1', out='v1'), reduce_func],
             None)
     v1 = g.get_n_repr()['v1']
     v2 = g.get_n_repr()['v2']
     assert th.allclose(v1, v2)
     # run builtin with single message and reduce
-    g.send_and_recv(u, v, fn.copy_src(src=fld), fn.sum(out='v1'),
+    g.send_and_recv(u, v, fn.copy_src(src=fld, out='m'), fn.sum(msg='m', out='v1'),
             None)
     v1 = g.get_n_repr()['v1']
     assert th.allclose(v1, v2)
-    # 1 message, 2 reduces, using anonymous repr
-    g.send_and_recv(u, v, fn.copy_src(src=fld), [fn.sum(out='v2'), fn.sum(out='v3')], None)
+    # 1 message, 2 reduces
+    g.send_and_recv(u, v, fn.copy_src(src=fld, out='m'), [fn.sum(msg='m', out='v2'), fn.sum(msg='m', out='v3')], None)
     v2 = g.get_n_repr()['v2']
     v3 = g.get_n_repr()['v3']
     assert th.allclose(v1, v2)
...
@@ -203,7 +204,7 @@ def test_send_and_recv_multi_fn():
     # send and recv with edge weights, 2 message, 3 reduces
     g.send_and_recv(u, v,
             [fn.src_mul_edge(src=fld, edge='e1', out='m1'),
              fn.src_mul_edge(src=fld, edge='e2', out='m2')],
-            [fn.sum(msgs='m1', out='v1'), fn.sum(msgs='m2', out='v2'), fn.sum(msgs='m1', out='v3')],
+            [fn.sum(msg='m1', out='v1'), fn.sum(msg='m2', out='v2'), fn.sum(msg='m1', out='v3')],
             None)
     v1 = g.get_n_repr()['v1']
     v2 = g.get_n_repr()['v2']
...
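In the multi-function cases the message field names are what pair each producer with its consumer: two builtins write disjoint fields 'm1' and 'm2', and each reducer selects the field it consumes via msg=. A compact sketch under the same assumptions as above (toy graph, invented values):

import torch as th
import dgl
import dgl.function as fn

g = dgl.DGLGraph()
g.add_nodes(3)
g.add_edges([0, 1, 2], [1, 2, 0])        # every node receives one message
g.set_n_repr({'f': th.ones(3, 4)})
g.set_e_repr({'e1': 3 * th.ones(3, 4)})

# Two message builtins write distinct fields; the reducers consume them
# independently and write separate node fields.
g.update_all([fn.copy_src(src='f', out='m1'),
              fn.src_mul_edge(src='f', edge='e1', out='m2')],
             [fn.sum(msg='m1', out='v1'), fn.sum(msg='m2', out='v2')],
             None)
assert th.allclose(g.get_n_repr()['v1'], th.ones(3, 4))
assert th.allclose(g.get_n_repr()['v2'], 3 * th.ones(3, 4))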