Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
3e76bcc0
Commit
3e76bcc0
authored
Oct 18, 2018
by
Minjie Wang
Browse files
remove anonymous repr
parent
fb6be9fb
Changes
13
Hide whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
276 additions
and
600 deletions
+276
-600
examples/pytorch/gcn/gcn_spmv.py
examples/pytorch/gcn/gcn_spmv.py
+1
-1
python/dgl/__init__.py
python/dgl/__init__.py
+1
-1
python/dgl/base.py
python/dgl/base.py
+0
-3
python/dgl/function/message.py
python/dgl/function/message.py
+59
-46
python/dgl/function/reducer.py
python/dgl/function/reducer.py
+37
-35
python/dgl/graph.py
python/dgl/graph.py
+74
-104
python/dgl/scheduler.py
python/dgl/scheduler.py
+22
-40
tests/pytorch/test_basics.py
tests/pytorch/test_basics.py
+13
-14
tests/pytorch/test_basics_anonymous.py
tests/pytorch/test_basics_anonymous.py
+0
-198
tests/pytorch/test_batched_graph.py
tests/pytorch/test_batched_graph.py
+30
-30
tests/pytorch/test_function.py
tests/pytorch/test_function.py
+0
-86
tests/pytorch/test_line_graph.py
tests/pytorch/test_line_graph.py
+5
-9
tests/pytorch/test_specialization.py
tests/pytorch/test_specialization.py
+34
-33
No files found.
examples/pytorch/gcn/gcn_spmv.py
View file @
3e76bcc0
...
...
@@ -56,7 +56,7 @@ class GCN(nn.Module):
g
.
apply_nodes
(
apply_node_func
=
lambda
node
:
F
.
dropout
(
node
[
'h'
],
p
=
self
.
dropout
))
self
.
g
.
update_all
(
fn
.
copy_src
(
src
=
'h'
,
out
=
'm'
),
fn
.
sum
(
msg
s
=
'm'
,
out
=
'h'
),
fn
.
sum
(
msg
=
'm'
,
out
=
'h'
),
layer
)
return
self
.
g
.
pop_n_repr
(
'h'
)
...
...
python/dgl/__init__.py
View file @
3e76bcc0
...
...
@@ -11,5 +11,5 @@ from ._ffi.base import DGLError, __version__
from
.base
import
ALL
from
.batched_graph
import
*
from
.generator
import
*
from
.graph
import
DGLGraph
,
__MSG__
,
__REPR__
from
.graph
import
DGLGraph
from
.subgraph
import
DGLSubGraph
python/dgl/base.py
View file @
3e76bcc0
...
...
@@ -11,7 +11,4 @@ ALL = "__ALL__"
def
is_all
(
arg
):
return
isinstance
(
arg
,
str
)
and
arg
==
ALL
__MSG__
=
"__MSG__"
__REPR__
=
"__REPR__"
dgl_warning
=
warnings
.
warn
python/dgl/function/message.py
View file @
3e76bcc0
...
...
@@ -4,17 +4,25 @@ from __future__ import absolute_import
import
operator
import
dgl.backend
as
F
__all__
=
[
"MessageFunction"
,
"src_mul_edge"
,
"copy_src"
,
"copy_edge"
]
__all__
=
[
"src_mul_edge"
,
"copy_src"
,
"copy_edge"
]
class
MessageFunction
(
object
):
"""Base builtin message function class."""
def
__call__
(
self
,
src
,
edge
):
"""Regular computation of this builtin.
This will be used when optimization is not available.
"""
raise
NotImplementedError
def
name
(
self
):
"""Return the name of this builtin function."""
raise
NotImplementedError
def
is_spmv_supported
(
self
,
g
):
"""Return whether the SPMV optimization is supported."""
raise
NotImplementedError
...
...
@@ -22,12 +30,6 @@ class BundledMessageFunction(MessageFunction):
def
__init__
(
self
,
fn_list
):
if
not
isinstance
(
fn_list
,
(
list
,
tuple
)):
fn_list
=
[
fn_list
]
else
:
# sanity check on out field
for
fn
in
fn_list
:
# cannot perform check for udf
if
isinstance
(
fn
,
MessageFunction
)
and
fn
.
out_field
is
None
:
raise
RuntimeError
(
"Not specifying out field for multiple message is ambiguous"
)
self
.
fn_list
=
fn_list
def
is_spmv_supported
(
self
,
g
):
...
...
@@ -43,11 +45,8 @@ class BundledMessageFunction(MessageFunction):
if
ret
is
None
:
ret
=
msg
else
:
try
:
# ret and msg must be dict
ret
.
update
(
msg
)
except
:
raise
RuntimeError
(
"Must specify out field for multiple message"
)
# ret and msg must be dict
ret
.
update
(
msg
)
return
ret
def
name
(
self
):
...
...
@@ -55,25 +54,26 @@ class BundledMessageFunction(MessageFunction):
def
_is_spmv_supported_node_feat
(
g
,
field
):
if
field
is
None
:
feat
=
g
.
get_n_repr
()
else
:
feat
=
g
.
get_n_repr
()[
field
]
"""Return whether the node feature shape supports SPMV optimization.
Only scalar and vector features are supported currently.
"""
feat
=
g
.
get_n_repr
()[
field
]
shape
=
F
.
shape
(
feat
)
return
len
(
shape
)
==
1
or
len
(
shape
)
==
2
def
_is_spmv_supported_edge_feat
(
g
,
field
):
# check shape, only scalar edge feature can be optimized at the moment
if
field
is
None
:
feat
=
g
.
get_e_repr
()
else
:
feat
=
g
.
get_e_repr
()[
field
]
"""Return whether the edge feature shape supports SPMV optimization.
Only scalar feature is supported currently.
"""
feat
=
g
.
get_e_repr
()[
field
]
shape
=
F
.
shape
(
feat
)
return
len
(
shape
)
==
1
or
(
len
(
shape
)
==
2
and
shape
[
1
]
==
1
)
class
SrcMulEdgeMessageFunction
(
MessageFunction
):
def
__init__
(
self
,
mul_op
,
src_field
=
None
,
edge_field
=
None
,
out_field
=
None
):
def
__init__
(
self
,
mul_op
,
src_field
,
edge_field
,
out_field
):
self
.
mul_op
=
mul_op
self
.
src_field
=
src_field
self
.
edge_field
=
edge_field
...
...
@@ -84,21 +84,14 @@ class SrcMulEdgeMessageFunction(MessageFunction):
and
_is_spmv_supported_edge_feat
(
g
,
self
.
edge_field
)
def
__call__
(
self
,
src
,
edge
):
if
self
.
src_field
is
not
None
:
src
=
src
[
self
.
src_field
]
if
self
.
edge_field
is
not
None
:
edge
=
edge
[
self
.
edge_field
]
ret
=
self
.
mul_op
(
src
,
edge
)
if
self
.
out_field
is
None
:
return
ret
else
:
return
{
self
.
out_field
:
ret
}
ret
=
self
.
mul_op
(
src
[
self
.
src_field
],
edge
[
self
.
edge_field
])
return
{
self
.
out_field
:
ret
}
def
name
(
self
):
return
"src_mul_edge"
class
CopySrcMessageFunction
(
MessageFunction
):
def
__init__
(
self
,
src_field
=
None
,
out_field
=
None
):
def
__init__
(
self
,
src_field
,
out_field
):
self
.
src_field
=
src_field
self
.
out_field
=
out_field
...
...
@@ -106,14 +99,7 @@ class CopySrcMessageFunction(MessageFunction):
return
_is_spmv_supported_node_feat
(
g
,
self
.
src_field
)
def
__call__
(
self
,
src
,
edge
):
if
self
.
src_field
is
not
None
:
ret
=
src
[
self
.
src_field
]
else
:
ret
=
src
if
self
.
out_field
is
None
:
return
ret
else
:
return
{
self
.
out_field
:
ret
}
return
{
self
.
out_field
:
src
[
self
.
src_field
]}
def
name
(
self
):
return
"copy_src"
...
...
@@ -142,14 +128,41 @@ class CopyEdgeMessageFunction(MessageFunction):
return
"copy_edge"
def
src_mul_edge
(
src
=
None
,
edge
=
None
,
out
=
None
):
"""TODO(minjie): docstring """
def
src_mul_edge
(
src
,
edge
,
out
):
"""Builtin message function that computes message by multiplying source node features
with edge features.
Parameters
----------
src : str
The source feature name.
edge : str
The edge feature name.
out : str
The output message name.
"""
return
SrcMulEdgeMessageFunction
(
operator
.
mul
,
src
,
edge
,
out
)
def
copy_src
(
src
=
None
,
out
=
None
):
"""TODO(minjie): docstring """
def
copy_src
(
src
,
out
):
"""Builtin message function that computes message using source node feature.
Parameters
----------
src : str
The source feature name.
out : str
The output message name.
"""
return
CopySrcMessageFunction
(
src
,
out
)
def
copy_edge
(
edge
=
None
,
out
=
None
):
"""TODO(minjie): docstring """
def
copy_edge
(
edge
,
out
):
"""Builtin message function that computes message using edge feature.
Parameters
----------
edge : str
The edge feature name.
out : str
The output message name.
"""
return
CopyEdgeMessageFunction
(
edge
,
out
)
python/dgl/function/reducer.py
View file @
3e76bcc0
...
...
@@ -3,27 +3,30 @@ from __future__ import absolute_import
from
..
import
backend
as
F
__all__
=
[
"ReduceFunction"
,
"sum"
,
"max"
]
__all__
=
[
"sum"
,
"max"
]
class
ReduceFunction
(
object
):
"""Base builtin reduce function class."""
def
__call__
(
self
,
node
,
msgs
):
"""Regular computation of this builtin.
This will be used when optimization is not available.
"""
raise
NotImplementedError
def
name
(
self
):
"""Return the name of this builtin function."""
raise
NotImplementedError
def
is_spmv_supported
(
self
):
"""Return whether the SPMV optimization is supported."""
raise
NotImplementedError
class
BundledReduceFunction
(
ReduceFunction
):
def
__init__
(
self
,
fn_list
):
if
not
isinstance
(
fn_list
,
(
list
,
tuple
)):
fn_list
=
[
fn_list
]
else
:
# sanity check on out field
for
fn
in
fn_list
:
if
isinstance
(
fn
,
ReduceFunction
)
and
fn
.
out_field
is
None
:
raise
RuntimeError
(
"Not specifying out field for multiple reduce is ambiguous"
)
self
.
fn_list
=
fn_list
def
is_spmv_supported
(
self
):
...
...
@@ -39,51 +42,50 @@ class BundledReduceFunction(ReduceFunction):
if
ret
is
None
:
ret
=
rpr
else
:
try
:
# ret and rpr must be dict
ret
.
update
(
rpr
)
except
:
raise
RuntimeError
(
"Must specify out field for multiple reudce"
)
# ret and rpr must be dict
ret
.
update
(
rpr
)
return
ret
def
name
(
self
):
return
"bundled"
class
ReducerFunctionTemplate
(
ReduceFunction
):
def
__init__
(
self
,
name
,
batch_op
,
nonbatch_
op
,
msg_field
=
None
,
out_field
=
None
):
def
__init__
(
self
,
name
,
op
,
msg_field
,
out_field
):
self
.
name
=
name
self
.
batch_op
=
batch_op
self
.
nonbatch_op
=
nonbatch_op
self
.
op
=
op
self
.
msg_field
=
msg_field
self
.
out_field
=
out_field
def
is_spmv_supported
(
self
):
#
TODO: support max
#
NOTE: only sum is supported right now.
return
self
.
name
==
"sum"
def
__call__
(
self
,
node
,
msgs
):
if
isinstance
(
msgs
,
list
):
if
self
.
msg_field
is
None
:
ret
=
self
.
nonbatch_op
(
msgs
)
else
:
ret
=
self
.
nonbatch_op
([
msg
[
self
.
msg_field
]
for
msg
in
msgs
])
else
:
if
self
.
msg_field
is
None
:
ret
=
self
.
batch_op
(
msgs
,
1
)
else
:
ret
=
self
.
batch_op
(
msgs
[
self
.
msg_field
],
1
)
if
self
.
out_field
is
None
:
return
ret
else
:
return
{
self
.
out_field
:
ret
}
return
{
self
.
out_field
:
self
.
op
(
msgs
[
self
.
msg_field
],
1
)}
def
name
(
self
):
return
self
.
name
_python_sum
=
sum
def
sum
(
msgs
=
None
,
out
=
None
):
return
ReducerFunctionTemplate
(
"sum"
,
F
.
sum
,
_python_sum
,
msgs
,
out
)
def
sum
(
msg
,
out
):
"""Builtin reduce function that aggregates messages by sum.
Parameters
----------
msg : str
The message name.
out : str
The output node feature name.
"""
return
ReducerFunctionTemplate
(
"sum"
,
F
.
sum
,
msg
,
out
)
def
max
(
msg
,
out
):
"""Builtin reduce function that aggregates messages by max.
_python_max
=
max
def
max
(
msgs
=
None
,
out
=
None
):
return
ReducerFunctionTemplate
(
"max"
,
F
.
max
,
_python_max
,
msgs
,
out
)
Parameters
----------
msg : str
The message name.
out : str
The output node feature name.
"""
return
ReducerFunctionTemplate
(
"max"
,
F
.
max
,
msg
,
out
)
python/dgl/graph.py
View file @
3e76bcc0
...
...
@@ -6,7 +6,7 @@ import networkx as nx
import
numpy
as
np
import
dgl
from
.base
import
ALL
,
is_all
,
__MSG__
,
__REPR__
from
.base
import
ALL
,
is_all
,
DGLError
,
dgl_warning
from
.
import
backend
as
F
from
.backend
import
Tensor
from
.frame
import
FrameRef
,
merge_frames
...
...
@@ -22,7 +22,6 @@ class DGLGraph(object):
"""Base graph class specialized for neural networks on graphs.
TODO(minjie): document of batching semantics
TODO(minjie): document of __REPR__ semantics
Parameters
----------
...
...
@@ -448,7 +447,9 @@ class DGLGraph(object):
The nx graph
"""
nx_graph
=
self
.
_graph
.
to_networkx
()
#TODO: attributes
#TODO(minjie): attributes
dgl_warning
(
'to_networkx currently does not support converting'
' node/edge features automatically.'
)
return
nx_graph
def
from_networkx
(
self
,
nx_graph
,
node_attrs
=
None
,
edge_attrs
=
None
):
...
...
@@ -550,20 +551,17 @@ class DGLGraph(object):
def
set_n_repr
(
self
,
hu
,
u
=
ALL
,
inplace
=
False
):
"""Set node(s) representation.
To set multiple node representations at once, pass `u` with a tensor or
a supported container of node ids. In this case, `hu` must be a tensor
of shape (B, D1, D2, ...), where B is the number of the nodes and
(D1, D2, ...) is the shape of the node representation tensor.
Dictionary type is also supported for `hu`. In this case, each item
will be treated as separate attribute of the nodes.
`hu` is a dictionary from the feature name to feature tensor. Each tensor
is of shape (B, D1, D2, ...), where B is the number of nodes to be updated,
and (D1, D2, ...) be the shape of the node representation tensor. The
length of the given node ids must match B (i.e, len(u) == B).
All update will be done out-placely to work with autograd unless the inplace
flag is true.
Parameters
----------
hu :
tensor or
dict of tensor
hu : dict of tensor
Node representation.
u : node, container or tensor
The node(s).
...
...
@@ -571,32 +569,31 @@ class DGLGraph(object):
True if the update is done inplacely
"""
# sanity check
if
not
utils
.
is_dict_like
(
hu
):
raise
DGLError
(
'Expect dictionary type for feature data.'
' Got "%s" instead.'
%
type
(
hu
))
if
is_all
(
u
):
num_nodes
=
self
.
number_of_nodes
()
else
:
u
=
utils
.
toindex
(
u
)
num_nodes
=
len
(
u
)
if
utils
.
is_dict_like
(
hu
):
for
key
,
val
in
hu
.
items
():
assert
F
.
shape
(
val
)[
0
]
=
=
num_nodes
else
:
assert
F
.
shape
(
hu
)[
0
]
==
num_nodes
for
key
,
val
in
hu
.
items
(
):
nfeats
=
F
.
shape
(
val
)[
0
]
if
nfeats
!
=
num_nodes
:
raise
DGLError
(
'Expect number of features to match number of nodes (len(u)).'
' Got %d and %d instead.'
%
(
nfeats
,
num_nodes
))
# set
if
is_all
(
u
):
if
utils
.
is_dict_like
(
hu
):
for
key
,
val
in
hu
.
items
():
self
.
_node_frame
[
key
]
=
val
else
:
self
.
_node_frame
[
__REPR__
]
=
hu
for
key
,
val
in
hu
.
items
():
self
.
_node_frame
[
key
]
=
val
else
:
if
utils
.
is_dict_like
(
hu
):
self
.
_node_frame
.
update_rows
(
u
,
hu
,
inplace
=
inplace
)
else
:
self
.
_node_frame
.
update_rows
(
u
,
{
__REPR__
:
hu
},
inplace
=
inplace
)
self
.
_node_frame
.
update_rows
(
u
,
hu
,
inplace
=
inplace
)
def
get_n_repr
(
self
,
u
=
ALL
):
"""Get node(s) representation.
The returned feature tensor batches multiple node features on the first dimension.
Parameters
----------
u : node, container or tensor
...
...
@@ -605,23 +602,17 @@ class DGLGraph(object):
Returns
-------
dict
Representation dict
Representation dict
from feature name to feature tensor.
"""
if
len
(
self
.
node_attr_schemes
())
==
0
:
return
dict
()
if
is_all
(
u
):
if
len
(
self
.
_node_frame
)
==
1
and
__REPR__
in
self
.
_node_frame
:
return
self
.
_node_frame
[
__REPR__
]
else
:
return
dict
(
self
.
_node_frame
)
return
dict
(
self
.
_node_frame
)
else
:
u
=
utils
.
toindex
(
u
)
if
len
(
self
.
_node_frame
)
==
1
and
__REPR__
in
self
.
_node_frame
:
return
self
.
_node_frame
.
select_rows
(
u
)[
__REPR__
]
else
:
return
self
.
_node_frame
.
select_rows
(
u
)
return
self
.
_node_frame
.
select_rows
(
u
)
def
pop_n_repr
(
self
,
key
=
__REPR__
):
def
pop_n_repr
(
self
,
key
):
"""Get and remove the specified node repr.
Parameters
...
...
@@ -636,23 +627,19 @@ class DGLGraph(object):
"""
return
self
.
_node_frame
.
pop
(
key
)
def
set_e_repr
(
self
,
h
_uv
,
u
=
ALL
,
v
=
ALL
,
inplace
=
False
):
def
set_e_repr
(
self
,
h
e
,
u
=
ALL
,
v
=
ALL
,
inplace
=
False
):
"""Set edge(s) representation.
To set multiple edge representations at once, pass `u` and `v` with tensors or
supported containers of node ids. In this case, `h_uv` must be a tensor
of shape (B, D1, D2, ...), where B is the number of the edges and
(D1, D2, ...) is the shape of the edge representation tensor.
Dictionary type is also supported for `h_uv`. In this case, each item
will be treated as separate attribute of the edges.
`he` is a dictionary from the feature name to feature tensor. Each tensor
is of shape (B, D1, D2, ...), where B is the number of edges to be updated,
and (D1, D2, ...) be the shape of the edge representation tensor.
All update will be done out-placely to work with autograd unless the inplace
flag is true.
Parameters
----------
h
_uv
: tensor or dict of tensor
h
e
: tensor or dict of tensor
Edge representation.
u : node, container or tensor
The source node(s).
...
...
@@ -662,26 +649,33 @@ class DGLGraph(object):
True if the update is done inplacely
"""
# sanity check
if
not
utils
.
is_dict_like
(
he
):
raise
DGLError
(
'Expect dictionary type for feature data.'
' Got "%s" instead.'
%
type
(
he
))
u_is_all
=
is_all
(
u
)
v_is_all
=
is_all
(
v
)
assert
u_is_all
==
v_is_all
if
u_is_all
:
self
.
set_e_repr_by_id
(
h
_uv
,
eid
=
ALL
,
inplace
=
inplace
)
self
.
set_e_repr_by_id
(
h
e
,
eid
=
ALL
,
inplace
=
inplace
)
else
:
u
=
utils
.
toindex
(
u
)
v
=
utils
.
toindex
(
v
)
_
,
_
,
eid
=
self
.
_graph
.
edge_ids
(
u
,
v
)
self
.
set_e_repr_by_id
(
h
_uv
,
eid
=
eid
,
inplace
=
inplace
)
self
.
set_e_repr_by_id
(
h
e
,
eid
=
eid
,
inplace
=
inplace
)
def
set_e_repr_by_id
(
self
,
h
_uv
,
eid
=
ALL
,
inplace
=
False
):
def
set_e_repr_by_id
(
self
,
h
e
,
eid
=
ALL
,
inplace
=
False
):
"""Set edge(s) representation by edge id.
`he` is a dictionary from the feature name to feature tensor. Each tensor
is of shape (B, D1, D2, ...), where B is the number of edges to be updated,
and (D1, D2, ...) be the shape of the edge representation tensor.
All update will be done out-placely to work with autograd unless the inplace
flag is true.
Parameters
----------
h
_uv
: tensor or dict of tensor
h
e
: tensor or dict of tensor
Edge representation.
eid : int, container or tensor
The edge id(s).
...
...
@@ -689,30 +683,27 @@ class DGLGraph(object):
True if the update is done inplacely
"""
# sanity check
if
not
utils
.
is_dict_like
(
he
):
raise
DGLError
(
'Expect dictionary type for feature data.'
' Got "%s" instead.'
%
type
(
he
))
if
is_all
(
eid
):
num_edges
=
self
.
number_of_edges
()
else
:
eid
=
utils
.
toindex
(
eid
)
num_edges
=
len
(
eid
)
if
utils
.
is_dict_like
(
h_uv
):
for
key
,
val
in
h_uv
.
items
():
assert
F
.
shape
(
val
)[
0
]
=
=
num_edges
else
:
assert
F
.
shape
(
h_uv
)[
0
]
==
num_edges
for
key
,
val
in
he
.
items
(
):
nfeats
=
F
.
shape
(
val
)[
0
]
if
nfeats
!
=
num_edges
:
raise
DGLError
(
'Expect number of features to match number of edges.'
' Got %d and %d instead.'
%
(
nfeats
,
num_edges
))
# set
if
is_all
(
eid
):
# update column
if
utils
.
is_dict_like
(
h_uv
):
for
key
,
val
in
h_uv
.
items
():
self
.
_edge_frame
[
key
]
=
val
else
:
self
.
_edge_frame
[
__REPR__
]
=
h_uv
for
key
,
val
in
he
.
items
():
self
.
_edge_frame
[
key
]
=
val
else
:
# update row
if
utils
.
is_dict_like
(
h_uv
):
self
.
_edge_frame
.
update_rows
(
eid
,
h_uv
,
inplace
=
inplace
)
else
:
self
.
_edge_frame
.
update_rows
(
eid
,
{
__REPR__
:
h_uv
},
inplace
=
inplace
)
self
.
_edge_frame
.
update_rows
(
eid
,
he
,
inplace
=
inplace
)
def
get_e_repr
(
self
,
u
=
ALL
,
v
=
ALL
):
"""Get node(s) representation.
...
...
@@ -742,7 +733,7 @@ class DGLGraph(object):
_
,
_
,
eid
=
self
.
_graph
.
edge_ids
(
u
,
v
)
return
self
.
get_e_repr_by_id
(
eid
=
eid
)
def
pop_e_repr
(
self
,
key
=
__REPR__
):
def
pop_e_repr
(
self
,
key
):
"""Get and remove the specified edge repr.
Parameters
...
...
@@ -768,21 +759,15 @@ class DGLGraph(object):
Returns
-------
dict
Representation dict
Representation dict
from feature name to feature tensor.
"""
if
len
(
self
.
edge_attr_schemes
())
==
0
:
return
dict
()
if
is_all
(
eid
):
if
len
(
self
.
_edge_frame
)
==
1
and
__REPR__
in
self
.
_edge_frame
:
return
self
.
_edge_frame
[
__REPR__
]
else
:
return
dict
(
self
.
_edge_frame
)
return
dict
(
self
.
_edge_frame
)
else
:
eid
=
utils
.
toindex
(
eid
)
if
len
(
self
.
_edge_frame
)
==
1
and
__REPR__
in
self
.
_edge_frame
:
return
self
.
_edge_frame
.
select_rows
(
eid
)[
__REPR__
]
else
:
return
self
.
_edge_frame
.
select_rows
(
eid
)
return
self
.
_edge_frame
.
select_rows
(
eid
)
def
register_edge_func
(
self
,
edge_func
):
"""Register global edge update function.
...
...
@@ -837,6 +822,8 @@ class DGLGraph(object):
def
apply_nodes
(
self
,
v
=
ALL
,
apply_node_func
=
"default"
):
"""Apply the function on node representations.
Applying a None function will be ignored.
Parameters
----------
v : int, iterable of int, tensor, optional
...
...
@@ -868,7 +855,7 @@ class DGLGraph(object):
# merge current node_repr with reduce output
curr_repr
=
utils
.
HybridDict
(
reduce_accum
,
curr_repr
)
new_repr
=
apply_node_func
(
curr_repr
)
if
reduce_accum
is
not
None
and
utils
.
is_dict_like
(
new_repr
)
:
if
reduce_accum
is
not
None
:
# merge new node_repr with reduce output
reduce_accum
.
update
(
new_repr
)
new_repr
=
reduce_accum
...
...
@@ -877,6 +864,8 @@ class DGLGraph(object):
def
apply_edges
(
self
,
u
=
None
,
v
=
None
,
apply_edge_func
=
"default"
,
eid
=
None
):
"""Apply the function on edge representations.
Applying a None function will be ignored.
Parameters
----------
u : optional, int, iterable of int, tensor
...
...
@@ -893,7 +882,6 @@ class DGLGraph(object):
if
not
apply_edge_func
:
# Skip none function call.
return
if
eid
is
None
:
new_repr
=
apply_edge_func
(
self
.
get_e_repr
(
u
,
v
))
self
.
set_e_repr
(
new_repr
,
u
,
v
)
...
...
@@ -914,9 +902,8 @@ class DGLGraph(object):
The message function can be any of the pre-defined functions
('from_src').
Currently, we require the message functions of consecutive send's and
send_on's to return the same keys. Otherwise the behavior will be
undefined.
Currently, we require the message functions of consecutive send's to
return the same keys. Otherwise the behavior will be undefined.
Parameters
----------
...
...
@@ -964,10 +951,7 @@ class DGLGraph(object):
edge_reprs
=
self
.
get_e_repr_by_id
(
eid
)
msgs
=
message_func
(
src_reprs
,
edge_reprs
)
self
.
_msg_graph
.
add_edges
(
u
,
v
)
if
utils
.
is_dict_like
(
msgs
):
self
.
_msg_frame
.
append
(
msgs
)
else
:
self
.
_msg_frame
.
append
({
__MSG__
:
msgs
})
self
.
_msg_frame
.
append
(
msgs
)
# TODO(minjie): Fix these codes in next PR.
"""
...
...
@@ -1061,7 +1045,6 @@ class DGLGraph(object):
v
=
utils
.
toindex
(
v
)
u
,
v
=
utils
.
edge_broadcasting
(
u
,
v
)
_
,
_
,
eid
=
self
.
_graph
.
edge_ids
(
u
,
v
)
# call the UDF
src_reprs
=
self
.
get_n_repr
(
u
)
dst_reprs
=
self
.
get_n_repr
(
v
)
...
...
@@ -1148,25 +1131,19 @@ class DGLGraph(object):
msg_shape
=
F
.
shape
(
msg
)
new_shape
=
(
bkt_len
,
deg
)
+
msg_shape
[
1
:]
return
F
.
reshape
(
msg
,
new_shape
)
if
len
(
in_msgs
)
==
1
and
__MSG__
in
in_msgs
:
reshaped_in_msgs
=
_reshape_fn
(
in_msgs
[
__MSG__
])
else
:
reshaped_in_msgs
=
utils
.
LazyDict
(
lambda
key
:
_reshape_fn
(
in_msgs
[
key
]),
self
.
_msg_frame
.
schemes
)
reshaped_in_msgs
=
utils
.
LazyDict
(
lambda
key
:
_reshape_fn
(
in_msgs
[
key
]),
self
.
_msg_frame
.
schemes
)
reordered_v
.
append
(
v_bkt
.
tousertensor
())
new_reprs
.
append
(
reduce_func
(
dst_reprs
,
reshaped_in_msgs
))
# TODO: clear partial messages
# TODO
(minjie)
: clear partial messages
self
.
reset_messages
()
# Pack all reducer results together
reordered_v
=
F
.
pack
(
reordered_v
)
if
utils
.
is_dict_like
(
new_reprs
[
0
]):
keys
=
new_reprs
[
0
].
keys
()
new_reprs
=
{
key
:
F
.
pack
([
repr
[
key
]
for
repr
in
new_reprs
])
for
key
in
keys
}
else
:
new_reprs
=
{
__REPR__
:
F
.
pack
(
new_reprs
)}
keys
=
new_reprs
[
0
].
keys
()
new_reprs
=
{
key
:
F
.
pack
([
repr
[
key
]
for
repr
in
new_reprs
])
for
key
in
keys
}
if
v_is_all
and
not
has_zero_degree
:
# First do reorder and then replace the whole column.
...
...
@@ -1237,15 +1214,13 @@ class DGLGraph(object):
if
executor
:
new_reprs
=
executor
.
run
()
if
not
utils
.
is_dict_like
(
new_reprs
):
new_reprs
=
{
__REPR__
:
new_reprs
}
unique_v
=
executor
.
recv_nodes
self
.
_apply_nodes
(
unique_v
,
apply_node_func
,
reduce_accum
=
new_reprs
)
elif
eid
is
not
None
:
_
,
v
,
_
=
self
.
_graph
.
find_edges
(
eid
)
unique_v
=
utils
.
toindex
(
F
.
unique
(
v
.
tousertensor
()))
# TODO: replace with the new DegreeBucketingScheduler
# TODO
(quan)
: replace with the new DegreeBucketingScheduler
self
.
send
(
eid
=
eid
,
message_func
=
message_func
)
self
.
recv
(
unique_v
,
reduce_func
,
apply_node_func
)
else
:
...
...
@@ -1261,10 +1236,7 @@ class DGLGraph(object):
edge_reprs
=
self
.
get_e_repr
(
u
,
v
)
msgs
=
message_func
(
src_reprs
,
edge_reprs
)
msg_frame
=
FrameRef
()
if
utils
.
is_dict_like
(
msgs
):
msg_frame
.
append
(
msgs
)
else
:
msg_frame
.
append
({
__MSG__
:
msgs
})
msg_frame
.
append
(
msgs
)
# recv with degree bucketing
executor
=
scheduler
.
get_recv_executor
(
graph
=
self
,
...
...
@@ -1353,8 +1325,6 @@ class DGLGraph(object):
"update_all"
,
self
,
message_func
=
message_func
,
reduce_func
=
reduce_func
)
if
executor
:
new_reprs
=
executor
.
run
()
if
not
utils
.
is_dict_like
(
new_reprs
):
new_reprs
=
{
__REPR__
:
new_reprs
}
self
.
_apply_nodes
(
ALL
,
apply_node_func
,
reduce_accum
=
new_reprs
)
else
:
self
.
send
(
ALL
,
ALL
,
message_func
)
...
...
@@ -1387,7 +1357,7 @@ class DGLGraph(object):
Arguments for pre-defined iterators.
"""
if
isinstance
(
traverser
,
str
):
# TODO Call pre-defined routine to unroll the computation.
# TODO
(minjie):
Call pre-defined routine to unroll the computation.
raise
RuntimeError
(
'Not implemented.'
)
else
:
# NOTE: the iteration can return multiple edges at each step.
...
...
python/dgl/scheduler.py
View file @
3e76bcc0
...
...
@@ -3,7 +3,7 @@ from __future__ import absolute_import
import
numpy
as
np
from
.base
import
ALL
,
__MSG__
,
__REPR__
from
.base
import
ALL
,
DGLError
from
.
import
backend
as
F
from
.function
import
message
as
fmsg
from
.function
import
reducer
as
fred
...
...
@@ -111,7 +111,15 @@ def light_degree_bucketing_for_graph(graph):
class
Executor
(
object
):
"""Base class for executing graph computation."""
def
run
(
self
):
"""Run this executor.
This should return the new node features.
TODO(minjie): extend this to support computation on edges.
"""
raise
NotImplementedError
class
SPMVOperator
(
Executor
):
...
...
@@ -126,10 +134,7 @@ class SPMVOperator(Executor):
def
run
(
self
):
# get src col
if
self
.
src_field
is
None
:
srccol
=
self
.
node_repr
else
:
srccol
=
self
.
node_repr
[
self
.
src_field
]
srccol
=
self
.
node_repr
[
self
.
src_field
]
ctx
=
F
.
get_context
(
srccol
)
# build adjmat
...
...
@@ -142,10 +147,7 @@ class SPMVOperator(Executor):
dstcol
=
F
.
squeeze
(
dstcol
)
else
:
dstcol
=
F
.
spmm
(
adjmat
,
srccol
)
if
self
.
dst_field
is
None
:
return
dstcol
else
:
return
{
self
.
dst_field
:
dstcol
}
return
{
self
.
dst_field
:
dstcol
}
# FIXME: refactorize in scheduler/executor redesign
...
...
@@ -180,20 +182,14 @@ class DegreeBucketingExecutor(Executor):
msg_shape
=
F
.
shape
(
msg
)
new_shape
=
(
len
(
vv
),
deg
)
+
msg_shape
[
1
:]
return
F
.
reshape
(
msg
,
new_shape
)
if
len
(
in_msgs
)
==
1
and
__MSG__
in
in_msgs
:
reshaped_in_msgs
=
_reshape_fn
(
in_msgs
[
__MSG__
])
else
:
reshaped_in_msgs
=
utils
.
LazyDict
(
lambda
key
:
_reshape_fn
(
in_msgs
[
key
]),
self
.
msg_frame
.
schemes
)
reshaped_in_msgs
=
utils
.
LazyDict
(
lambda
key
:
_reshape_fn
(
in_msgs
[
key
]),
self
.
msg_frame
.
schemes
)
new_reprs
.
append
(
self
.
rfunc
(
dst_reprs
,
reshaped_in_msgs
))
# Pack all reducer results together
if
utils
.
is_dict_like
(
new_reprs
[
0
]):
keys
=
new_reprs
[
0
].
keys
()
new_reprs
=
{
key
:
F
.
pack
([
repr
[
key
]
for
repr
in
new_reprs
])
for
key
in
keys
}
else
:
new_reprs
=
{
__REPR__
:
F
.
pack
(
new_reprs
)}
keys
=
new_reprs
[
0
].
keys
()
new_reprs
=
{
key
:
F
.
pack
([
repr
[
key
]
for
repr
in
new_reprs
])
for
key
in
keys
}
return
new_reprs
...
...
@@ -249,12 +245,6 @@ class UpdateAllExecutor(BasicExecutor):
self
.
_graph_shape
=
None
self
.
_recv_nodes
=
None
@
property
def
graph_idx
(
self
):
if
self
.
_graph_idx
is
None
:
self
.
_graph_idx
=
self
.
g
.
_graph
.
adjacency_matrix
()
return
self
.
_graph_idx
@
property
def
graph_shape
(
self
):
if
self
.
_graph_shape
is
None
:
...
...
@@ -280,16 +270,13 @@ class UpdateAllExecutor(BasicExecutor):
def
_adj_build_fn
(
self
,
edge_field
,
ctx
,
use_edge_feat
):
if
use_edge_feat
:
if
edge_field
is
None
:
dat
=
self
.
edge_repr
else
:
dat
=
self
.
edge_repr
[
edge_field
]
dat
=
self
.
edge_repr
[
edge_field
]
dat
=
F
.
squeeze
(
dat
)
# TODO(minjie): should not directly use _indices
idx
=
self
.
g
raph_idx
.
get
(
ctx
).
_indices
()
idx
=
self
.
g
.
adjacency_matrix
(
ctx
).
_indices
()
adjmat
=
F
.
sparse_tensor
(
idx
,
dat
,
self
.
graph_shape
)
else
:
adjmat
=
self
.
g
raph_idx
.
get
(
ctx
)
adjmat
=
self
.
g
.
adjacency_matrix
(
ctx
)
return
adjmat
...
...
@@ -351,10 +338,7 @@ class SendRecvExecutor(BasicExecutor):
def
_adj_build_fn
(
self
,
edge_field
,
ctx
,
use_edge_feat
):
if
use_edge_feat
:
if
edge_field
is
None
:
dat
=
self
.
edge_repr
else
:
dat
=
self
.
edge_repr
[
edge_field
]
dat
=
self
.
edge_repr
[
edge_field
]
dat
=
F
.
squeeze
(
dat
)
else
:
dat
=
F
.
ones
((
len
(
self
.
u
),
))
...
...
@@ -386,9 +370,8 @@ class BundledExecutor(BasicExecutor):
func_pairs
=
[]
for
rfn
in
rfunc
.
fn_list
:
mfn
=
out2mfunc
.
get
(
rfn
.
msg_field
,
None
)
# field check
assert
mfn
is
not
None
,
\
"cannot find message func for reduce func in-field {}"
.
format
(
rfn
.
msg_field
)
if
mfn
is
None
:
raise
DGLError
(
'Cannot find message field "%s".'
%
rfn
.
msg_field
)
func_pairs
.
append
((
mfn
,
rfn
))
return
func_pairs
...
...
@@ -409,7 +392,6 @@ class BundledUpdateAllExecutor(BundledExecutor, UpdateAllExecutor):
self
.
_init_state
()
BundledExecutor
.
__init__
(
self
,
graph
,
mfunc
,
rfunc
)
class
BundledSendRecvExecutor
(
BundledExecutor
,
SendRecvExecutor
):
def
__init__
(
self
,
graph
,
src
,
dst
,
mfunc
,
rfunc
):
self
.
_init_state
(
src
,
dst
)
...
...
tests/pytorch/test_basics.py
View file @
3e76bcc0
...
...
@@ -209,14 +209,13 @@ def test_reduce_0deg():
g
.
add_edge
(
3
,
0
)
g
.
add_edge
(
4
,
0
)
def
_message
(
src
,
edge
):
return
src
return
{
'm'
:
src
[
'h'
]}
def
_reduce
(
node
,
msgs
):
assert
msgs
is
not
None
return
node
+
msgs
.
sum
(
1
)
return
{
'h'
:
node
[
'h'
]
+
msgs
[
'm'
].
sum
(
1
)}
old_repr
=
th
.
randn
(
5
,
5
)
g
.
set_n_repr
(
old_repr
)
g
.
set_n_repr
(
{
'h'
:
old_repr
}
)
g
.
update_all
(
_message
,
_reduce
)
new_repr
=
g
.
get_n_repr
()
new_repr
=
g
.
get_n_repr
()
[
'h'
]
assert
th
.
allclose
(
new_repr
[
1
:],
old_repr
[
1
:])
assert
th
.
allclose
(
new_repr
[
0
],
old_repr
.
sum
(
0
))
...
...
@@ -226,25 +225,25 @@ def test_pull_0deg():
g
.
add_nodes
(
2
)
g
.
add_edge
(
0
,
1
)
def
_message
(
src
,
edge
):
return
src
return
{
'm'
:
src
[
'h'
]}
def
_reduce
(
node
,
msgs
):
assert
msgs
is
not
None
return
msgs
.
sum
(
1
)
return
{
'h'
:
msgs
[
'm'
].
sum
(
1
)}
old_repr
=
th
.
randn
(
2
,
5
)
g
.
set_n_repr
(
old_repr
)
g
.
set_n_repr
({
'h'
:
old_repr
})
g
.
pull
(
0
,
_message
,
_reduce
)
new_repr
=
g
.
get_n_repr
()
new_repr
=
g
.
get_n_repr
()
[
'h'
]
assert
th
.
allclose
(
new_repr
[
0
],
old_repr
[
0
])
assert
th
.
allclose
(
new_repr
[
1
],
old_repr
[
1
])
g
.
pull
(
1
,
_message
,
_reduce
)
new_repr
=
g
.
get_n_repr
()
new_repr
=
g
.
get_n_repr
()
[
'h'
]
assert
th
.
allclose
(
new_repr
[
1
],
old_repr
[
0
])
old_repr
=
th
.
randn
(
2
,
5
)
g
.
set_n_repr
(
old_repr
)
g
.
set_n_repr
(
{
'h'
:
old_repr
}
)
g
.
pull
([
0
,
1
],
_message
,
_reduce
)
new_repr
=
g
.
get_n_repr
()
new_repr
=
g
.
get_n_repr
()
[
'h'
]
assert
th
.
allclose
(
new_repr
[
0
],
old_repr
[
0
])
assert
th
.
allclose
(
new_repr
[
1
],
old_repr
[
0
])
...
...
tests/pytorch/test_basics_anonymous.py
deleted
100644 → 0
View file @
fb6be9fb
import
torch
as
th
from
torch.autograd
import
Variable
import
numpy
as
np
from
dgl.graph
import
DGLGraph
,
__REPR__
D
=
32
reduce_msg_shapes
=
set
()
def
check_eq
(
a
,
b
):
assert
a
.
shape
==
b
.
shape
assert
th
.
sum
(
a
==
b
)
==
int
(
np
.
prod
(
list
(
a
.
shape
)))
def
message_func
(
hu
,
e_uv
):
assert
len
(
hu
.
shape
)
==
2
assert
hu
.
shape
[
1
]
==
D
return
hu
def
reduce_func
(
hv
,
msgs
):
reduce_msg_shapes
.
add
(
tuple
(
msgs
.
shape
))
assert
len
(
msgs
.
shape
)
==
3
assert
msgs
.
shape
[
2
]
==
D
return
hv
+
th
.
sum
(
msgs
,
1
)
def
generate_graph
(
grad
=
False
):
g
=
DGLGraph
()
g
.
add_nodes
(
10
)
# create a graph where 0 is the source and 9 is the sink
for
i
in
range
(
1
,
9
):
g
.
add_edge
(
0
,
i
)
g
.
add_edge
(
i
,
9
)
# add a back flow from 9 to 0
g
.
add_edge
(
9
,
0
)
ncol
=
Variable
(
th
.
randn
(
10
,
D
),
requires_grad
=
grad
)
ecol
=
Variable
(
th
.
randn
(
17
,
D
),
requires_grad
=
grad
)
g
.
set_n_repr
(
ncol
)
g
.
set_e_repr
(
ecol
)
return
g
def
test_batch_setter_getter
():
def
_pfc
(
x
):
return
list
(
x
.
numpy
()[:,
0
])
g
=
generate_graph
()
# set all nodes
g
.
set_n_repr
(
th
.
zeros
((
10
,
D
)))
assert
_pfc
(
g
.
get_n_repr
())
==
[
0.
]
*
10
# pop nodes
assert
_pfc
(
g
.
pop_n_repr
())
==
[
0.
]
*
10
assert
len
(
g
.
get_n_repr
())
==
0
g
.
set_n_repr
(
th
.
zeros
((
10
,
D
)))
# set partial nodes
u
=
th
.
tensor
([
1
,
3
,
5
])
g
.
set_n_repr
(
th
.
ones
((
3
,
D
)),
u
)
assert
_pfc
(
g
.
get_n_repr
())
==
[
0.
,
1.
,
0.
,
1.
,
0.
,
1.
,
0.
,
0.
,
0.
,
0.
]
# get partial nodes
u
=
th
.
tensor
([
1
,
2
,
3
])
assert
_pfc
(
g
.
get_n_repr
(
u
))
==
[
1.
,
0.
,
1.
]
'''
s, d, eid
0, 1, 0
1, 9, 1
0, 2, 2
2, 9, 3
0, 3, 4
3, 9, 5
0, 4, 6
4, 9, 7
0, 5, 8
5, 9, 9
0, 6, 10
6, 9, 11
0, 7, 12
7, 9, 13
0, 8, 14
8, 9, 15
9, 0, 16
'''
# set all edges
g
.
set_e_repr
(
th
.
zeros
((
17
,
D
)))
assert
_pfc
(
g
.
get_e_repr
())
==
[
0.
]
*
17
# pop edges
assert
_pfc
(
g
.
pop_e_repr
())
==
[
0.
]
*
17
assert
len
(
g
.
get_e_repr
())
==
0
g
.
set_e_repr
(
th
.
zeros
((
17
,
D
)))
# set partial edges (many-many)
u
=
th
.
tensor
([
0
,
0
,
2
,
5
,
9
])
v
=
th
.
tensor
([
1
,
3
,
9
,
9
,
0
])
g
.
set_e_repr
(
th
.
ones
((
5
,
D
)),
u
,
v
)
truth
=
[
0.
]
*
17
truth
[
0
]
=
truth
[
4
]
=
truth
[
3
]
=
truth
[
9
]
=
truth
[
16
]
=
1.
assert
_pfc
(
g
.
get_e_repr
())
==
truth
# set partial edges (many-one)
u
=
th
.
tensor
([
3
,
4
,
6
])
v
=
th
.
tensor
([
9
])
g
.
set_e_repr
(
th
.
ones
((
3
,
D
)),
u
,
v
)
truth
[
5
]
=
truth
[
7
]
=
truth
[
11
]
=
1.
assert
_pfc
(
g
.
get_e_repr
())
==
truth
# set partial edges (one-many)
u
=
th
.
tensor
([
0
])
v
=
th
.
tensor
([
4
,
5
,
6
])
g
.
set_e_repr
(
th
.
ones
((
3
,
D
)),
u
,
v
)
truth
[
6
]
=
truth
[
8
]
=
truth
[
10
]
=
1.
assert
_pfc
(
g
.
get_e_repr
())
==
truth
# get partial edges (many-many)
u
=
th
.
tensor
([
0
,
6
,
0
])
v
=
th
.
tensor
([
6
,
9
,
7
])
assert
_pfc
(
g
.
get_e_repr
(
u
,
v
))
==
[
1.
,
1.
,
0.
]
# get partial edges (many-one)
u
=
th
.
tensor
([
5
,
6
,
7
])
v
=
th
.
tensor
([
9
])
assert
_pfc
(
g
.
get_e_repr
(
u
,
v
))
==
[
1.
,
1.
,
0.
]
# get partial edges (one-many)
u
=
th
.
tensor
([
0
])
v
=
th
.
tensor
([
3
,
4
,
5
])
assert
_pfc
(
g
.
get_e_repr
(
u
,
v
))
==
[
1.
,
1.
,
1.
]
def
test_batch_setter_autograd
():
g
=
generate_graph
(
grad
=
True
)
h1
=
g
.
get_n_repr
()
# partial set
v
=
th
.
tensor
([
1
,
2
,
8
])
hh
=
Variable
(
th
.
zeros
((
len
(
v
),
D
)),
requires_grad
=
True
)
g
.
set_n_repr
(
hh
,
v
)
h2
=
g
.
get_n_repr
()
h2
.
backward
(
th
.
ones
((
10
,
D
))
*
2
)
check_eq
(
h1
.
grad
[:,
0
],
th
.
tensor
([
2.
,
0.
,
0.
,
2.
,
2.
,
2.
,
2.
,
2.
,
0.
,
2.
]))
check_eq
(
hh
.
grad
[:,
0
],
th
.
tensor
([
2.
,
2.
,
2.
]))
def
test_batch_send
():
g
=
generate_graph
()
def
_fmsg
(
hu
,
edge
):
assert
hu
.
shape
==
(
5
,
D
)
return
hu
g
.
register_message_func
(
_fmsg
)
# many-many send
u
=
th
.
tensor
([
0
,
0
,
0
,
0
,
0
])
v
=
th
.
tensor
([
1
,
2
,
3
,
4
,
5
])
g
.
send
(
u
,
v
)
# one-many send
u
=
th
.
tensor
([
0
])
v
=
th
.
tensor
([
1
,
2
,
3
,
4
,
5
])
g
.
send
(
u
,
v
)
# many-one send
u
=
th
.
tensor
([
1
,
2
,
3
,
4
,
5
])
v
=
th
.
tensor
([
9
])
g
.
send
(
u
,
v
)
def
test_batch_recv
():
g
=
generate_graph
()
g
.
register_message_func
(
message_func
)
g
.
register_reduce_func
(
reduce_func
)
u
=
th
.
tensor
([
0
,
0
,
0
,
4
,
5
,
6
])
v
=
th
.
tensor
([
1
,
2
,
3
,
9
,
9
,
9
])
reduce_msg_shapes
.
clear
()
g
.
send
(
u
,
v
)
g
.
recv
(
th
.
unique
(
v
))
assert
(
reduce_msg_shapes
==
{(
1
,
3
,
D
),
(
3
,
1
,
D
)})
reduce_msg_shapes
.
clear
()
def
test_update_routines
():
g
=
generate_graph
()
g
.
register_message_func
(
message_func
)
g
.
register_reduce_func
(
reduce_func
)
# send_and_recv
reduce_msg_shapes
.
clear
()
u
=
th
.
tensor
([
0
,
0
,
0
,
4
,
5
,
6
])
v
=
th
.
tensor
([
1
,
2
,
3
,
9
,
9
,
9
])
g
.
send_and_recv
(
u
,
v
)
assert
(
reduce_msg_shapes
==
{(
1
,
3
,
D
),
(
3
,
1
,
D
)})
reduce_msg_shapes
.
clear
()
# pull
v
=
th
.
tensor
([
1
,
2
,
3
,
9
])
reduce_msg_shapes
.
clear
()
g
.
pull
(
v
)
assert
(
reduce_msg_shapes
==
{(
1
,
8
,
D
),
(
3
,
1
,
D
)})
reduce_msg_shapes
.
clear
()
# push
v
=
th
.
tensor
([
0
,
1
,
2
,
3
])
reduce_msg_shapes
.
clear
()
g
.
push
(
v
)
assert
(
reduce_msg_shapes
==
{(
1
,
3
,
D
),
(
8
,
1
,
D
)})
reduce_msg_shapes
.
clear
()
# update_all
reduce_msg_shapes
.
clear
()
g
.
update_all
()
assert
(
reduce_msg_shapes
==
{(
1
,
8
,
D
),
(
9
,
1
,
D
)})
reduce_msg_shapes
.
clear
()
if
__name__
==
'__main__'
:
test_batch_setter_getter
()
test_batch_setter_autograd
()
test_batch_send
()
test_batch_recv
()
test_update_routines
()
tests/pytorch/test_batched_graph.py
View file @
3e76bcc0
...
...
@@ -18,8 +18,8 @@ def tree1():
g
.
add_edge
(
4
,
1
)
g
.
add_edge
(
1
,
0
)
g
.
add_edge
(
2
,
0
)
g
.
set_n_repr
(
th
.
Tensor
([
0
,
1
,
2
,
3
,
4
]))
g
.
set_e_repr
(
th
.
randn
(
4
,
10
))
g
.
set_n_repr
(
{
'h'
:
th
.
Tensor
([
0
,
1
,
2
,
3
,
4
])
}
)
g
.
set_e_repr
(
{
'h'
:
th
.
randn
(
4
,
10
)
}
)
return
g
def
tree2
():
...
...
@@ -37,17 +37,17 @@ def tree2():
g
.
add_edge
(
0
,
4
)
g
.
add_edge
(
4
,
1
)
g
.
add_edge
(
3
,
1
)
g
.
set_n_repr
(
th
.
Tensor
([
0
,
1
,
2
,
3
,
4
]))
g
.
set_e_repr
(
th
.
randn
(
4
,
10
))
g
.
set_n_repr
(
{
'h'
:
th
.
Tensor
([
0
,
1
,
2
,
3
,
4
])
}
)
g
.
set_e_repr
(
{
'h'
:
th
.
randn
(
4
,
10
)
}
)
return
g
def
test_batch_unbatch
():
t1
=
tree1
()
t2
=
tree2
()
n1
=
t1
.
get_n_repr
()
n2
=
t2
.
get_n_repr
()
e1
=
t1
.
get_e_repr
()
e2
=
t2
.
get_e_repr
()
n1
=
t1
.
get_n_repr
()
[
'h'
]
n2
=
t2
.
get_n_repr
()
[
'h'
]
e1
=
t1
.
get_e_repr
()
[
'h'
]
e2
=
t2
.
get_e_repr
()
[
'h'
]
bg
=
dgl
.
batch
([
t1
,
t2
])
assert
bg
.
number_of_nodes
()
==
10
...
...
@@ -57,10 +57,10 @@ def test_batch_unbatch():
assert
bg
.
batch_num_edges
==
[
4
,
4
]
tt1
,
tt2
=
dgl
.
unbatch
(
bg
)
assert
th
.
allclose
(
t1
.
get_n_repr
(),
tt1
.
get_n_repr
())
assert
th
.
allclose
(
t1
.
get_e_repr
(),
tt1
.
get_e_repr
())
assert
th
.
allclose
(
t2
.
get_n_repr
(),
tt2
.
get_n_repr
())
assert
th
.
allclose
(
t2
.
get_e_repr
(),
tt2
.
get_e_repr
())
assert
th
.
allclose
(
t1
.
get_n_repr
()
[
'h'
]
,
tt1
.
get_n_repr
()
[
'h'
]
)
assert
th
.
allclose
(
t1
.
get_e_repr
()
[
'h'
]
,
tt1
.
get_e_repr
()
[
'h'
]
)
assert
th
.
allclose
(
t2
.
get_n_repr
()
[
'h'
]
,
tt2
.
get_n_repr
()
[
'h'
]
)
assert
th
.
allclose
(
t2
.
get_e_repr
()
[
'h'
]
,
tt2
.
get_e_repr
()
[
'h'
]
)
def
test_batch_unbatch1
():
t1
=
tree1
()
...
...
@@ -74,20 +74,20 @@ def test_batch_unbatch1():
assert
b2
.
batch_num_edges
==
[
4
,
4
,
4
]
s1
,
s2
,
s3
=
dgl
.
unbatch
(
b2
)
assert
th
.
allclose
(
t2
.
get_n_repr
(),
s1
.
get_n_repr
())
assert
th
.
allclose
(
t2
.
get_e_repr
(),
s1
.
get_e_repr
())
assert
th
.
allclose
(
t1
.
get_n_repr
(),
s2
.
get_n_repr
())
assert
th
.
allclose
(
t1
.
get_e_repr
(),
s2
.
get_e_repr
())
assert
th
.
allclose
(
t2
.
get_n_repr
(),
s3
.
get_n_repr
())
assert
th
.
allclose
(
t2
.
get_e_repr
(),
s3
.
get_e_repr
())
assert
th
.
allclose
(
t2
.
get_n_repr
()
[
'h'
]
,
s1
.
get_n_repr
()
[
'h'
]
)
assert
th
.
allclose
(
t2
.
get_e_repr
()
[
'h'
]
,
s1
.
get_e_repr
()
[
'h'
]
)
assert
th
.
allclose
(
t1
.
get_n_repr
()
[
'h'
]
,
s2
.
get_n_repr
()
[
'h'
]
)
assert
th
.
allclose
(
t1
.
get_e_repr
()
[
'h'
]
,
s2
.
get_e_repr
()
[
'h'
]
)
assert
th
.
allclose
(
t2
.
get_n_repr
()
[
'h'
]
,
s3
.
get_n_repr
()
[
'h'
]
)
assert
th
.
allclose
(
t2
.
get_e_repr
()
[
'h'
]
,
s3
.
get_e_repr
()
[
'h'
]
)
def
test_batch_sendrecv
():
t1
=
tree1
()
t2
=
tree2
()
bg
=
dgl
.
batch
([
t1
,
t2
])
bg
.
register_message_func
(
lambda
src
,
edge
:
src
)
bg
.
register_reduce_func
(
lambda
node
,
msgs
:
th
.
sum
(
msgs
,
1
))
bg
.
register_message_func
(
lambda
src
,
edge
:
{
'm'
:
src
[
'h'
]}
)
bg
.
register_reduce_func
(
lambda
node
,
msgs
:
{
'h'
:
th
.
sum
(
msgs
[
'm'
]
,
1
)
}
)
u
=
[
3
,
4
,
2
+
5
,
0
+
5
]
v
=
[
1
,
1
,
4
+
5
,
4
+
5
]
...
...
@@ -95,8 +95,8 @@ def test_batch_sendrecv():
bg
.
recv
(
v
)
t1
,
t2
=
dgl
.
unbatch
(
bg
)
assert
t1
.
get_n_repr
()[
1
]
==
7
assert
t2
.
get_n_repr
()[
4
]
==
2
assert
t1
.
get_n_repr
()[
'h'
][
1
]
==
7
assert
t2
.
get_n_repr
()[
'h'
][
4
]
==
2
def
test_batch_propagate
():
...
...
@@ -104,8 +104,8 @@ def test_batch_propagate():
t2
=
tree2
()
bg
=
dgl
.
batch
([
t1
,
t2
])
bg
.
register_message_func
(
lambda
src
,
edge
:
src
)
bg
.
register_reduce_func
(
lambda
node
,
msgs
:
th
.
sum
(
msgs
,
1
))
bg
.
register_message_func
(
lambda
src
,
edge
:
{
'm'
:
src
[
'h'
]}
)
bg
.
register_reduce_func
(
lambda
node
,
msgs
:
{
'h'
:
th
.
sum
(
msgs
[
'm'
]
,
1
)
}
)
# get leaves.
order
=
[]
...
...
@@ -123,23 +123,23 @@ def test_batch_propagate():
bg
.
propagate
(
traverser
=
order
)
t1
,
t2
=
dgl
.
unbatch
(
bg
)
assert
t1
.
get_n_repr
()[
0
]
==
9
assert
t2
.
get_n_repr
()[
1
]
==
5
assert
t1
.
get_n_repr
()[
'h'
][
0
]
==
9
assert
t2
.
get_n_repr
()[
'h'
][
1
]
==
5
def
test_batched_edge_ordering
():
g1
=
dgl
.
DGLGraph
()
g1
.
add_nodes
(
6
)
g1
.
add_edges
([
4
,
4
,
2
,
2
,
0
],
[
5
,
3
,
3
,
1
,
1
])
e1
=
th
.
randn
(
5
,
10
)
g1
.
set_e_repr
(
e1
)
g1
.
set_e_repr
(
{
'h'
:
e1
}
)
g2
=
dgl
.
DGLGraph
()
g2
.
add_nodes
(
6
)
g2
.
add_edges
([
0
,
1
,
2
,
5
,
4
,
5
],
[
1
,
2
,
3
,
4
,
3
,
0
])
e2
=
th
.
randn
(
6
,
10
)
g2
.
set_e_repr
(
e2
)
g2
.
set_e_repr
(
{
'h'
:
e2
}
)
g
=
dgl
.
batch
([
g1
,
g2
])
r1
=
g
.
get_e_repr
()[
g
.
edge_id
(
4
,
5
)]
r2
=
g1
.
get_e_repr
()[
g1
.
edge_id
(
4
,
5
)]
r1
=
g
.
get_e_repr
()[
'h'
][
g
.
edge_id
(
4
,
5
)]
r2
=
g1
.
get_e_repr
()[
'h'
][
g1
.
edge_id
(
4
,
5
)]
assert
th
.
equal
(
r1
,
r2
)
def
test_batch_no_edge
():
...
...
tests/pytorch/test_function.py
View file @
3e76bcc0
import
torch
as
th
import
dgl
import
dgl.function
as
fn
from
dgl.graph
import
__REPR__
def
generate_graph
():
g
=
dgl
.
DGLGraph
()
...
...
@@ -37,18 +36,9 @@ def generate_graph1():
g
.
set_e_repr
(
h
)
return
g
def
reducer_msg
(
node
,
msgs
):
return
th
.
sum
(
msgs
[
'm'
],
1
)
def
reducer_out
(
node
,
msgs
):
return
{
'h'
:
th
.
sum
(
msgs
,
1
)}
def
reducer_both
(
node
,
msgs
):
return
{
'h'
:
th
.
sum
(
msgs
[
'm'
],
1
)}
def
reducer_none
(
node
,
msgs
):
return
th
.
sum
(
msgs
,
1
)
def
test_copy_src
():
# copy_src with both fields
g
=
generate_graph
()
...
...
@@ -58,30 +48,6 @@ def test_copy_src():
assert
th
.
allclose
(
g
.
get_n_repr
()[
'h'
],
th
.
tensor
([
10.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
44.
]))
# copy_src with only src field; the out field should use anonymous repr
g
=
generate_graph
()
g
.
register_message_func
(
fn
.
copy_src
(
src
=
'h'
))
g
.
register_reduce_func
(
reducer_out
)
g
.
update_all
()
assert
th
.
allclose
(
g
.
get_n_repr
()[
'h'
],
th
.
tensor
([
10.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
44.
]))
# copy_src with no src field; should use anonymous repr
g
=
generate_graph1
()
g
.
register_message_func
(
fn
.
copy_src
(
out
=
'm'
))
g
.
register_reduce_func
(
reducer_both
)
g
.
update_all
()
assert
th
.
allclose
(
g
.
get_n_repr
()[
'h'
],
th
.
tensor
([
10.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
44.
]))
# copy src with no fields;
g
=
generate_graph1
()
g
.
register_message_func
(
fn
.
copy_src
())
g
.
register_reduce_func
(
reducer_out
)
g
.
update_all
()
assert
th
.
allclose
(
g
.
get_n_repr
()[
'h'
],
th
.
tensor
([
10.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
44.
]))
def
test_copy_edge
():
# copy_edge with both fields
g
=
generate_graph
()
...
...
@@ -91,30 +57,6 @@ def test_copy_edge():
assert
th
.
allclose
(
g
.
get_n_repr
()[
'h'
],
th
.
tensor
([
10.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
44.
]))
# copy_edge with only edge field; the out field should use anonymous repr
g
=
generate_graph
()
g
.
register_message_func
(
fn
.
copy_edge
(
edge
=
'h'
))
g
.
register_reduce_func
(
reducer_out
)
g
.
update_all
()
assert
th
.
allclose
(
g
.
get_n_repr
()[
'h'
],
th
.
tensor
([
10.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
44.
]))
# copy_edge with no edge field; should use anonymous repr
g
=
generate_graph1
()
g
.
register_message_func
(
fn
.
copy_edge
(
out
=
'm'
))
g
.
register_reduce_func
(
reducer_both
)
g
.
update_all
()
assert
th
.
allclose
(
g
.
get_n_repr
()[
'h'
],
th
.
tensor
([
10.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
44.
]))
# copy edge with no fields;
g
=
generate_graph1
()
g
.
register_message_func
(
fn
.
copy_edge
())
g
.
register_reduce_func
(
reducer_out
)
g
.
update_all
()
assert
th
.
allclose
(
g
.
get_n_repr
()[
'h'
],
th
.
tensor
([
10.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
44.
]))
def
test_src_mul_edge
():
# src_mul_edge with all fields
g
=
generate_graph
()
...
...
@@ -124,34 +66,6 @@ def test_src_mul_edge():
assert
th
.
allclose
(
g
.
get_n_repr
()[
'h'
],
th
.
tensor
([
100.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
284.
]))
g
=
generate_graph
()
g
.
register_message_func
(
fn
.
src_mul_edge
(
src
=
'h'
,
edge
=
'h'
))
g
.
register_reduce_func
(
reducer_out
)
g
.
update_all
()
assert
th
.
allclose
(
g
.
get_n_repr
()[
'h'
],
th
.
tensor
([
100.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
284.
]))
g
=
generate_graph1
()
g
.
register_message_func
(
fn
.
src_mul_edge
(
out
=
'm'
))
g
.
register_reduce_func
(
reducer_both
)
g
.
update_all
()
assert
th
.
allclose
(
g
.
get_n_repr
()[
'h'
],
th
.
tensor
([
100.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
284.
]))
g
=
generate_graph1
()
g
.
register_message_func
(
fn
.
src_mul_edge
())
g
.
register_reduce_func
(
reducer_out
)
g
.
update_all
()
assert
th
.
allclose
(
g
.
get_n_repr
()[
'h'
],
th
.
tensor
([
100.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
284.
]))
g
=
generate_graph1
()
g
.
register_message_func
(
fn
.
src_mul_edge
())
g
.
register_reduce_func
(
reducer_none
)
g
.
update_all
()
assert
th
.
allclose
(
g
.
get_n_repr
(),
th
.
tensor
([
100.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
1.
,
284.
]))
if
__name__
==
'__main__'
:
test_copy_src
()
test_copy_edge
()
...
...
tests/pytorch/test_line_graph.py
View file @
3e76bcc0
...
...
@@ -5,35 +5,31 @@ import dgl
D
=
5
def
check_eq
(
a
,
b
):
return
a
.
shape
==
b
.
shape
and
np
.
allclose
(
a
.
numpy
(),
b
.
numpy
())
def
test_line_graph
():
N
=
5
G
=
dgl
.
DGLGraph
(
nx
.
star_graph
(
N
))
G
.
set_e_repr
(
th
.
randn
((
2
*
N
,
D
)))
G
.
set_e_repr
(
{
'h'
:
th
.
randn
((
2
*
N
,
D
))
}
)
n_edges
=
G
.
number_of_edges
()
L
=
G
.
line_graph
(
shared
=
True
)
assert
L
.
number_of_nodes
()
==
2
*
N
L
.
set_n_repr
(
th
.
randn
((
2
*
N
,
D
)))
L
.
set_n_repr
(
{
'h'
:
th
.
randn
((
2
*
N
,
D
))
}
)
# update node features on line graph should reflect to edge features on
# original graph.
u
=
[
0
,
0
,
2
,
3
]
v
=
[
1
,
2
,
0
,
0
]
eid
=
G
.
edge_ids
(
u
,
v
)
L
.
set_n_repr
(
th
.
zeros
((
4
,
D
)),
eid
)
assert
check_eq
(
G
.
get_e_repr
(
u
,
v
),
th
.
zeros
((
4
,
D
)))
L
.
set_n_repr
(
{
'h'
:
th
.
zeros
((
4
,
D
))
}
,
eid
)
assert
th
.
allclose
(
G
.
get_e_repr
(
u
,
v
)
[
'h'
]
,
th
.
zeros
((
4
,
D
)))
# adding a new node feature on line graph should also reflect to a new
# edge feature on original graph
data
=
th
.
randn
(
n_edges
,
D
)
L
.
set_n_repr
({
'w'
:
data
})
assert
check_eq
(
G
.
get_e_repr
()[
'w'
],
data
)
assert
th
.
allclose
(
G
.
get_e_repr
()[
'w'
],
data
)
def
test_no_backtracking
():
N
=
5
G
=
dgl
.
DGLGraph
(
nx
.
star_graph
(
N
))
G
.
set_e_repr
(
th
.
randn
((
2
*
N
,
D
)))
L
=
G
.
line_graph
(
backtracking
=
False
)
assert
L
.
number_of_nodes
()
==
2
*
N
for
i
in
range
(
1
,
N
):
...
...
tests/pytorch/test_specialization.py
View file @
3e76bcc0
...
...
@@ -22,23 +22,23 @@ def generate_graph():
def
test_update_all
():
def
_test
(
fld
):
def
message_func
(
hu
,
edge
):
return
hu
[
fld
]
return
{
'm'
:
hu
[
fld
]
}
def
message_func_edge
(
hu
,
edge
):
if
len
(
hu
[
fld
].
shape
)
==
1
:
return
hu
[
fld
]
*
edge
[
'e1'
]
return
{
'm'
:
hu
[
fld
]
*
edge
[
'e1'
]
}
else
:
return
hu
[
fld
]
*
edge
[
'e2'
]
return
{
'm'
:
hu
[
fld
]
*
edge
[
'e2'
]
}
def
reduce_func
(
hv
,
msgs
):
return
{
fld
:
th
.
sum
(
msgs
,
1
)}
return
{
fld
:
th
.
sum
(
msgs
[
'm'
]
,
1
)}
def
apply_func
(
hu
):
return
{
fld
:
2
*
hu
[
fld
]}
g
=
generate_graph
()
# update all
v1
=
g
.
get_n_repr
()[
fld
]
g
.
update_all
(
fn
.
copy_src
(
src
=
fld
),
fn
.
sum
(
out
=
fld
),
apply_func
)
g
.
update_all
(
fn
.
copy_src
(
src
=
fld
,
out
=
'm'
),
fn
.
sum
(
msg
=
'm'
,
out
=
fld
),
apply_func
)
v2
=
g
.
get_n_repr
()[
fld
]
g
.
set_n_repr
({
fld
:
v1
})
g
.
update_all
(
message_func
,
reduce_func
,
apply_func
)
...
...
@@ -46,12 +46,12 @@ def test_update_all():
assert
th
.
allclose
(
v2
,
v3
)
# update all with edge weights
v1
=
g
.
get_n_repr
()[
fld
]
g
.
update_all
(
fn
.
src_mul_edge
(
src
=
fld
,
edge
=
'e1'
),
fn
.
sum
(
out
=
fld
),
apply_func
)
g
.
update_all
(
fn
.
src_mul_edge
(
src
=
fld
,
edge
=
'e1'
,
out
=
'm'
),
fn
.
sum
(
msg
=
'm'
,
out
=
fld
),
apply_func
)
v2
=
g
.
get_n_repr
()[
fld
]
g
.
set_n_repr
({
fld
:
v1
})
g
.
update_all
(
fn
.
src_mul_edge
(
src
=
fld
,
edge
=
'e2'
),
fn
.
sum
(
out
=
fld
),
apply_func
)
g
.
update_all
(
fn
.
src_mul_edge
(
src
=
fld
,
edge
=
'e2'
,
out
=
'm'
),
fn
.
sum
(
msg
=
'm'
,
out
=
fld
),
apply_func
)
v3
=
g
.
get_n_repr
()[
fld
]
g
.
set_n_repr
({
fld
:
v1
})
g
.
update_all
(
message_func_edge
,
reduce_func
,
apply_func
)
...
...
@@ -68,42 +68,40 @@ def test_send_and_recv():
v
=
th
.
tensor
([
1
,
2
,
3
,
9
,
9
,
0
])
def
_test
(
fld
):
def
message_func
(
hu
,
edge
):
return
hu
[
fld
]
return
{
'm'
:
hu
[
fld
]
}
def
message_func_edge
(
hu
,
edge
):
if
len
(
hu
[
fld
].
shape
)
==
1
:
return
hu
[
fld
]
*
edge
[
'e1'
]
return
{
'm'
:
hu
[
fld
]
*
edge
[
'e1'
]
}
else
:
return
hu
[
fld
]
*
edge
[
'e2'
]
return
{
'm'
:
hu
[
fld
]
*
edge
[
'e2'
]
}
def
reduce_func
(
hv
,
msgs
):
return
{
fld
:
th
.
sum
(
msgs
,
1
)}
return
{
fld
:
th
.
sum
(
msgs
[
'm'
]
,
1
)}
def
apply_func
(
hu
):
return
{
fld
:
2
*
hu
[
fld
]}
g
=
generate_graph
()
# send and recv
v1
=
g
.
get_n_repr
()[
fld
]
g
.
send_and_recv
(
u
,
v
,
fn
.
copy_src
(
src
=
fld
),
fn
.
sum
(
out
=
fld
),
apply_func
)
g
.
send_and_recv
(
u
,
v
,
fn
.
copy_src
(
src
=
fld
,
out
=
'm'
),
fn
.
sum
(
msg
=
'm'
,
out
=
fld
),
apply_func
)
v2
=
g
.
get_n_repr
()[
fld
]
g
.
set_n_repr
({
fld
:
v1
})
g
.
send_and_recv
(
u
,
v
,
message_func
,
reduce_func
,
apply_func
)
g
.
send_and_recv
(
u
,
v
,
message_func
,
reduce_func
,
apply_func
)
v3
=
g
.
get_n_repr
()[
fld
]
assert
th
.
allclose
(
v2
,
v3
)
# send and recv with edge weights
v1
=
g
.
get_n_repr
()[
fld
]
g
.
send_and_recv
(
u
,
v
,
fn
.
src_mul_edge
(
src
=
fld
,
edge
=
'e1'
),
fn
.
sum
(
out
=
fld
),
apply_func
)
g
.
send_and_recv
(
u
,
v
,
fn
.
src_mul_edge
(
src
=
fld
,
edge
=
'e1'
,
out
=
'm'
),
fn
.
sum
(
msg
=
'm'
,
out
=
fld
),
apply_func
)
v2
=
g
.
get_n_repr
()[
fld
]
g
.
set_n_repr
({
fld
:
v1
})
g
.
send_and_recv
(
u
,
v
,
fn
.
src_mul_edge
(
src
=
fld
,
edge
=
'e2'
),
fn
.
sum
(
out
=
fld
),
apply_func
)
g
.
send_and_recv
(
u
,
v
,
fn
.
src_mul_edge
(
src
=
fld
,
edge
=
'e2'
,
out
=
'm'
),
fn
.
sum
(
msg
=
'm'
,
out
=
fld
),
apply_func
)
v3
=
g
.
get_n_repr
()[
fld
]
g
.
set_n_repr
({
fld
:
v1
})
g
.
send_and_recv
(
u
,
v
,
message_func_edge
,
reduce_func
,
apply_func
)
g
.
send_and_recv
(
u
,
v
,
message_func_edge
,
reduce_func
,
apply_func
)
v4
=
g
.
get_n_repr
()[
fld
]
assert
th
.
allclose
(
v2
,
v3
)
assert
th
.
allclose
(
v3
,
v4
)
...
...
@@ -127,19 +125,19 @@ def test_update_all_multi_fn():
fld
=
'f2'
# update all, mix of builtin and UDF
g
.
update_all
([
fn
.
copy_src
(
src
=
fld
,
out
=
'm1'
),
message_func
],
[
fn
.
sum
(
msg
s
=
'm1'
,
out
=
'v1'
),
reduce_func
],
[
fn
.
sum
(
msg
=
'm1'
,
out
=
'v1'
),
reduce_func
],
None
)
v1
=
g
.
get_n_repr
()[
'v1'
]
v2
=
g
.
get_n_repr
()[
'v2'
]
assert
th
.
allclose
(
v1
,
v2
)
# run builtin with single message and reduce
g
.
update_all
(
fn
.
copy_src
(
src
=
fld
),
fn
.
sum
(
out
=
'v1'
),
None
)
g
.
update_all
(
fn
.
copy_src
(
src
=
fld
,
out
=
'm'
),
fn
.
sum
(
msg
=
'm'
,
out
=
'v1'
),
None
)
v1
=
g
.
get_n_repr
()[
'v1'
]
assert
th
.
allclose
(
v1
,
v2
)
# 1 message, 2 reduces
, using anonymous repr
g
.
update_all
(
fn
.
copy_src
(
src
=
fld
),
[
fn
.
sum
(
out
=
'v2'
),
fn
.
sum
(
out
=
'v3'
)],
None
)
# 1 message, 2 reduces
g
.
update_all
(
fn
.
copy_src
(
src
=
fld
,
out
=
'm'
),
[
fn
.
sum
(
msg
=
'm'
,
out
=
'v2'
),
fn
.
sum
(
msg
=
'm'
,
out
=
'v3'
)],
None
)
v2
=
g
.
get_n_repr
()[
'v2'
]
v3
=
g
.
get_n_repr
()[
'v3'
]
assert
th
.
allclose
(
v1
,
v2
)
...
...
@@ -147,7 +145,7 @@ def test_update_all_multi_fn():
# update all with edge weights, 2 message, 3 reduces
g
.
update_all
([
fn
.
src_mul_edge
(
src
=
fld
,
edge
=
'e1'
,
out
=
'm1'
),
fn
.
src_mul_edge
(
src
=
fld
,
edge
=
'e2'
,
out
=
'm2'
)],
[
fn
.
sum
(
msg
s
=
'm1'
,
out
=
'v1'
),
fn
.
sum
(
msg
s
=
'm2'
,
out
=
'v2'
),
fn
.
sum
(
msg
s
=
'm1'
,
out
=
'v3'
)],
[
fn
.
sum
(
msg
=
'm1'
,
out
=
'v1'
),
fn
.
sum
(
msg
=
'm2'
,
out
=
'v2'
),
fn
.
sum
(
msg
=
'm1'
,
out
=
'v3'
)],
None
)
v1
=
g
.
get_n_repr
()[
'v1'
]
v2
=
g
.
get_n_repr
()[
'v2'
]
...
...
@@ -181,20 +179,23 @@ def test_send_and_recv_multi_fn():
# send and recv, mix of builtin and UDF
g
.
send_and_recv
(
u
,
v
,
[
fn
.
copy_src
(
src
=
fld
,
out
=
'm1'
),
message_func
],
[
fn
.
sum
(
msg
s
=
'm1'
,
out
=
'v1'
),
reduce_func
],
[
fn
.
sum
(
msg
=
'm1'
,
out
=
'v1'
),
reduce_func
],
None
)
v1
=
g
.
get_n_repr
()[
'v1'
]
v2
=
g
.
get_n_repr
()[
'v2'
]
assert
th
.
allclose
(
v1
,
v2
)
# run builtin with single message and reduce
g
.
send_and_recv
(
u
,
v
,
fn
.
copy_src
(
src
=
fld
),
fn
.
sum
(
out
=
'v1'
),
g
.
send_and_recv
(
u
,
v
,
fn
.
copy_src
(
src
=
fld
,
out
=
'm'
),
fn
.
sum
(
msg
=
'm'
,
out
=
'v1'
),
None
)
v1
=
g
.
get_n_repr
()[
'v1'
]
assert
th
.
allclose
(
v1
,
v2
)
# 1 message, 2 reduces, using anonymous repr
g
.
send_and_recv
(
u
,
v
,
fn
.
copy_src
(
src
=
fld
),
[
fn
.
sum
(
out
=
'v2'
),
fn
.
sum
(
out
=
'v3'
)],
None
)
# 1 message, 2 reduces
g
.
send_and_recv
(
u
,
v
,
fn
.
copy_src
(
src
=
fld
,
out
=
'm'
),
[
fn
.
sum
(
msg
=
'm'
,
out
=
'v2'
),
fn
.
sum
(
msg
=
'm'
,
out
=
'v3'
)],
None
)
v2
=
g
.
get_n_repr
()[
'v2'
]
v3
=
g
.
get_n_repr
()[
'v3'
]
assert
th
.
allclose
(
v1
,
v2
)
...
...
@@ -203,7 +204,7 @@ def test_send_and_recv_multi_fn():
# send and recv with edge weights, 2 message, 3 reduces
g
.
send_and_recv
(
u
,
v
,
[
fn
.
src_mul_edge
(
src
=
fld
,
edge
=
'e1'
,
out
=
'm1'
),
fn
.
src_mul_edge
(
src
=
fld
,
edge
=
'e2'
,
out
=
'm2'
)],
[
fn
.
sum
(
msg
s
=
'm1'
,
out
=
'v1'
),
fn
.
sum
(
msg
s
=
'm2'
,
out
=
'v2'
),
fn
.
sum
(
msg
s
=
'm1'
,
out
=
'v3'
)],
[
fn
.
sum
(
msg
=
'm1'
,
out
=
'v1'
),
fn
.
sum
(
msg
=
'm2'
,
out
=
'v2'
),
fn
.
sum
(
msg
=
'm1'
,
out
=
'v3'
)],
None
)
v1
=
g
.
get_n_repr
()[
'v1'
]
v2
=
g
.
get_n_repr
()[
'v2'
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment