Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
9c135fd5
Unverified
Commit
9c135fd5
authored
Oct 19, 2018
by
VoVAllen
Committed by
GitHub
Oct 19, 2018
Browse files
Merge pull request #4 from jermainewang/master
Sync with latest commit
parents
9d3f299d
00add9f2
Changes
73
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
896 additions
and
444 deletions
+896
-444
python/dgl/frame.py
python/dgl/frame.py
+524
-93
python/dgl/function/message.py
python/dgl/function/message.py
+59
-46
python/dgl/function/reducer.py
python/dgl/function/reducer.py
+37
-35
python/dgl/generator/__init__.py
python/dgl/generator/__init__.py
+0
-4
python/dgl/generator/line.py
python/dgl/generator/line.py
+0
-38
python/dgl/graph.py
python/dgl/graph.py
+132
-114
python/dgl/graph_index.py
python/dgl/graph_index.py
+54
-8
python/dgl/scheduler.py
python/dgl/scheduler.py
+22
-40
src/c_api_common.cc
src/c_api_common.cc
+6
-1
src/c_api_common.h
src/c_api_common.h
+19
-7
src/graph/graph.cc
src/graph/graph.cc
+12
-7
src/graph/graph_apis.cc
src/graph/graph_apis.cc
+5
-0
src/graph/graph_op.cc
src/graph/graph_op.cc
+8
-30
src/runtime/README.md
src/runtime/README.md
+0
-3
src/runtime/file_util.h
src/runtime/file_util.h
+3
-3
src/runtime/meta_data.h
src/runtime/meta_data.h
+3
-3
src/runtime/module_util.h
src/runtime/module_util.h
+3
-3
src/runtime/pack_args.h
src/runtime/pack_args.h
+3
-3
src/runtime/runtime_base.h
src/runtime/runtime_base.h
+3
-3
src/runtime/thread_storage_scope.h
src/runtime/thread_storage_scope.h
+3
-3
No files found.
python/dgl/frame.py
View file @
9c135fd5
This diff is collapsed.
Click to expand it.
python/dgl/function/message.py
View file @
9c135fd5
...
@@ -4,17 +4,25 @@ from __future__ import absolute_import
...
@@ -4,17 +4,25 @@ from __future__ import absolute_import
import
operator
import
operator
import
dgl.backend
as
F
import
dgl.backend
as
F
__all__
=
[
"MessageFunction"
,
"src_mul_edge"
,
"copy_src"
,
"copy_edge"
]
__all__
=
[
"src_mul_edge"
,
"copy_src"
,
"copy_edge"
]
class
MessageFunction
(
object
):
class
MessageFunction
(
object
):
"""Base builtin message function class."""
def
__call__
(
self
,
src
,
edge
):
def
__call__
(
self
,
src
,
edge
):
"""Regular computation of this builtin.
This will be used when optimization is not available.
"""
raise
NotImplementedError
raise
NotImplementedError
def
name
(
self
):
def
name
(
self
):
"""Return the name of this builtin function."""
raise
NotImplementedError
raise
NotImplementedError
def
is_spmv_supported
(
self
,
g
):
def
is_spmv_supported
(
self
,
g
):
"""Return whether the SPMV optimization is supported."""
raise
NotImplementedError
raise
NotImplementedError
...
@@ -22,12 +30,6 @@ class BundledMessageFunction(MessageFunction):
...
@@ -22,12 +30,6 @@ class BundledMessageFunction(MessageFunction):
def
__init__
(
self
,
fn_list
):
def
__init__
(
self
,
fn_list
):
if
not
isinstance
(
fn_list
,
(
list
,
tuple
)):
if
not
isinstance
(
fn_list
,
(
list
,
tuple
)):
fn_list
=
[
fn_list
]
fn_list
=
[
fn_list
]
else
:
# sanity check on out field
for
fn
in
fn_list
:
# cannot perform check for udf
if
isinstance
(
fn
,
MessageFunction
)
and
fn
.
out_field
is
None
:
raise
RuntimeError
(
"Not specifying out field for multiple message is ambiguous"
)
self
.
fn_list
=
fn_list
self
.
fn_list
=
fn_list
def
is_spmv_supported
(
self
,
g
):
def
is_spmv_supported
(
self
,
g
):
...
@@ -43,11 +45,8 @@ class BundledMessageFunction(MessageFunction):
...
@@ -43,11 +45,8 @@ class BundledMessageFunction(MessageFunction):
if
ret
is
None
:
if
ret
is
None
:
ret
=
msg
ret
=
msg
else
:
else
:
try
:
# ret and msg must be dict
# ret and msg must be dict
ret
.
update
(
msg
)
ret
.
update
(
msg
)
except
:
raise
RuntimeError
(
"Must specify out field for multiple message"
)
return
ret
return
ret
def
name
(
self
):
def
name
(
self
):
...
@@ -55,25 +54,26 @@ class BundledMessageFunction(MessageFunction):
...
@@ -55,25 +54,26 @@ class BundledMessageFunction(MessageFunction):
def
_is_spmv_supported_node_feat
(
g
,
field
):
def
_is_spmv_supported_node_feat
(
g
,
field
):
if
field
is
None
:
"""Return whether the node feature shape supports SPMV optimization.
feat
=
g
.
get_n_repr
()
else
:
Only scalar and vector features are supported currently.
"""
feat
=
g
.
get_n_repr
()[
field
]
feat
=
g
.
get_n_repr
()[
field
]
shape
=
F
.
shape
(
feat
)
shape
=
F
.
shape
(
feat
)
return
len
(
shape
)
==
1
or
len
(
shape
)
==
2
return
len
(
shape
)
==
1
or
len
(
shape
)
==
2
def
_is_spmv_supported_edge_feat
(
g
,
field
):
def
_is_spmv_supported_edge_feat
(
g
,
field
):
# check shape, only scalar edge feature can be optimized at the moment
"""Return whether the edge feature shape supports SPMV optimization.
if
field
is
None
:
feat
=
g
.
get_e_repr
()
Only scalar feature is supported currently.
else
:
"""
feat
=
g
.
get_e_repr
()[
field
]
feat
=
g
.
get_e_repr
()[
field
]
shape
=
F
.
shape
(
feat
)
shape
=
F
.
shape
(
feat
)
return
len
(
shape
)
==
1
or
(
len
(
shape
)
==
2
and
shape
[
1
]
==
1
)
return
len
(
shape
)
==
1
or
(
len
(
shape
)
==
2
and
shape
[
1
]
==
1
)
class
SrcMulEdgeMessageFunction
(
MessageFunction
):
class
SrcMulEdgeMessageFunction
(
MessageFunction
):
def
__init__
(
self
,
mul_op
,
src_field
=
None
,
edge_field
=
None
,
out_field
=
None
):
def
__init__
(
self
,
mul_op
,
src_field
,
edge_field
,
out_field
):
self
.
mul_op
=
mul_op
self
.
mul_op
=
mul_op
self
.
src_field
=
src_field
self
.
src_field
=
src_field
self
.
edge_field
=
edge_field
self
.
edge_field
=
edge_field
...
@@ -84,21 +84,14 @@ class SrcMulEdgeMessageFunction(MessageFunction):
...
@@ -84,21 +84,14 @@ class SrcMulEdgeMessageFunction(MessageFunction):
and
_is_spmv_supported_edge_feat
(
g
,
self
.
edge_field
)
and
_is_spmv_supported_edge_feat
(
g
,
self
.
edge_field
)
def
__call__
(
self
,
src
,
edge
):
def
__call__
(
self
,
src
,
edge
):
if
self
.
src_field
is
not
None
:
ret
=
self
.
mul_op
(
src
[
self
.
src_field
],
edge
[
self
.
edge_field
])
src
=
src
[
self
.
src_field
]
if
self
.
edge_field
is
not
None
:
edge
=
edge
[
self
.
edge_field
]
ret
=
self
.
mul_op
(
src
,
edge
)
if
self
.
out_field
is
None
:
return
ret
else
:
return
{
self
.
out_field
:
ret
}
return
{
self
.
out_field
:
ret
}
def
name
(
self
):
def
name
(
self
):
return
"src_mul_edge"
return
"src_mul_edge"
class
CopySrcMessageFunction
(
MessageFunction
):
class
CopySrcMessageFunction
(
MessageFunction
):
def
__init__
(
self
,
src_field
=
None
,
out_field
=
None
):
def
__init__
(
self
,
src_field
,
out_field
):
self
.
src_field
=
src_field
self
.
src_field
=
src_field
self
.
out_field
=
out_field
self
.
out_field
=
out_field
...
@@ -106,14 +99,7 @@ class CopySrcMessageFunction(MessageFunction):
...
@@ -106,14 +99,7 @@ class CopySrcMessageFunction(MessageFunction):
return
_is_spmv_supported_node_feat
(
g
,
self
.
src_field
)
return
_is_spmv_supported_node_feat
(
g
,
self
.
src_field
)
def
__call__
(
self
,
src
,
edge
):
def
__call__
(
self
,
src
,
edge
):
if
self
.
src_field
is
not
None
:
return
{
self
.
out_field
:
src
[
self
.
src_field
]}
ret
=
src
[
self
.
src_field
]
else
:
ret
=
src
if
self
.
out_field
is
None
:
return
ret
else
:
return
{
self
.
out_field
:
ret
}
def
name
(
self
):
def
name
(
self
):
return
"copy_src"
return
"copy_src"
...
@@ -142,14 +128,41 @@ class CopyEdgeMessageFunction(MessageFunction):
...
@@ -142,14 +128,41 @@ class CopyEdgeMessageFunction(MessageFunction):
return
"copy_edge"
return
"copy_edge"
def
src_mul_edge
(
src
=
None
,
edge
=
None
,
out
=
None
):
def
src_mul_edge
(
src
,
edge
,
out
):
"""TODO(minjie): docstring """
"""Builtin message function that computes message by multiplying source node features
with edge features.
Parameters
----------
src : str
The source feature name.
edge : str
The edge feature name.
out : str
The output message name.
"""
return
SrcMulEdgeMessageFunction
(
operator
.
mul
,
src
,
edge
,
out
)
return
SrcMulEdgeMessageFunction
(
operator
.
mul
,
src
,
edge
,
out
)
def
copy_src
(
src
=
None
,
out
=
None
):
def
copy_src
(
src
,
out
):
"""TODO(minjie): docstring """
"""Builtin message function that computes message using source node feature.
Parameters
----------
src : str
The source feature name.
out : str
The output message name.
"""
return
CopySrcMessageFunction
(
src
,
out
)
return
CopySrcMessageFunction
(
src
,
out
)
def
copy_edge
(
edge
=
None
,
out
=
None
):
def
copy_edge
(
edge
,
out
):
"""TODO(minjie): docstring """
"""Builtin message function that computes message using edge feature.
Parameters
----------
edge : str
The edge feature name.
out : str
The output message name.
"""
return
CopyEdgeMessageFunction
(
edge
,
out
)
return
CopyEdgeMessageFunction
(
edge
,
out
)
python/dgl/function/reducer.py
View file @
9c135fd5
...
@@ -3,27 +3,30 @@ from __future__ import absolute_import
...
@@ -3,27 +3,30 @@ from __future__ import absolute_import
from
..
import
backend
as
F
from
..
import
backend
as
F
__all__
=
[
"ReduceFunction"
,
"sum"
,
"max"
]
__all__
=
[
"sum"
,
"max"
]
class
ReduceFunction
(
object
):
class
ReduceFunction
(
object
):
"""Base builtin reduce function class."""
def
__call__
(
self
,
node
,
msgs
):
def
__call__
(
self
,
node
,
msgs
):
"""Regular computation of this builtin.
This will be used when optimization is not available.
"""
raise
NotImplementedError
raise
NotImplementedError
def
name
(
self
):
def
name
(
self
):
"""Return the name of this builtin function."""
raise
NotImplementedError
raise
NotImplementedError
def
is_spmv_supported
(
self
):
def
is_spmv_supported
(
self
):
"""Return whether the SPMV optimization is supported."""
raise
NotImplementedError
raise
NotImplementedError
class
BundledReduceFunction
(
ReduceFunction
):
class
BundledReduceFunction
(
ReduceFunction
):
def
__init__
(
self
,
fn_list
):
def
__init__
(
self
,
fn_list
):
if
not
isinstance
(
fn_list
,
(
list
,
tuple
)):
if
not
isinstance
(
fn_list
,
(
list
,
tuple
)):
fn_list
=
[
fn_list
]
fn_list
=
[
fn_list
]
else
:
# sanity check on out field
for
fn
in
fn_list
:
if
isinstance
(
fn
,
ReduceFunction
)
and
fn
.
out_field
is
None
:
raise
RuntimeError
(
"Not specifying out field for multiple reduce is ambiguous"
)
self
.
fn_list
=
fn_list
self
.
fn_list
=
fn_list
def
is_spmv_supported
(
self
):
def
is_spmv_supported
(
self
):
...
@@ -39,51 +42,50 @@ class BundledReduceFunction(ReduceFunction):
...
@@ -39,51 +42,50 @@ class BundledReduceFunction(ReduceFunction):
if
ret
is
None
:
if
ret
is
None
:
ret
=
rpr
ret
=
rpr
else
:
else
:
try
:
# ret and rpr must be dict
# ret and rpr must be dict
ret
.
update
(
rpr
)
ret
.
update
(
rpr
)
except
:
raise
RuntimeError
(
"Must specify out field for multiple reudce"
)
return
ret
return
ret
def
name
(
self
):
def
name
(
self
):
return
"bundled"
return
"bundled"
class
ReducerFunctionTemplate
(
ReduceFunction
):
class
ReducerFunctionTemplate
(
ReduceFunction
):
def
__init__
(
self
,
name
,
batch_op
,
nonbatch_
op
,
msg_field
=
None
,
out_field
=
None
):
def
__init__
(
self
,
name
,
op
,
msg_field
,
out_field
):
self
.
name
=
name
self
.
name
=
name
self
.
batch_op
=
batch_op
self
.
op
=
op
self
.
nonbatch_op
=
nonbatch_op
self
.
msg_field
=
msg_field
self
.
msg_field
=
msg_field
self
.
out_field
=
out_field
self
.
out_field
=
out_field
def
is_spmv_supported
(
self
):
def
is_spmv_supported
(
self
):
#
TODO: support max
#
NOTE: only sum is supported right now.
return
self
.
name
==
"sum"
return
self
.
name
==
"sum"
def
__call__
(
self
,
node
,
msgs
):
def
__call__
(
self
,
node
,
msgs
):
if
isinstance
(
msgs
,
list
):
return
{
self
.
out_field
:
self
.
op
(
msgs
[
self
.
msg_field
],
1
)}
if
self
.
msg_field
is
None
:
ret
=
self
.
nonbatch_op
(
msgs
)
else
:
ret
=
self
.
nonbatch_op
([
msg
[
self
.
msg_field
]
for
msg
in
msgs
])
else
:
if
self
.
msg_field
is
None
:
ret
=
self
.
batch_op
(
msgs
,
1
)
else
:
ret
=
self
.
batch_op
(
msgs
[
self
.
msg_field
],
1
)
if
self
.
out_field
is
None
:
return
ret
else
:
return
{
self
.
out_field
:
ret
}
def
name
(
self
):
def
name
(
self
):
return
self
.
name
return
self
.
name
_python_sum
=
sum
def
sum
(
msg
,
out
):
def
sum
(
msgs
=
None
,
out
=
None
):
"""Builtin reduce function that aggregates messages by sum.
return
ReducerFunctionTemplate
(
"sum"
,
F
.
sum
,
_python_sum
,
msgs
,
out
)
Parameters
----------
msg : str
The message name.
out : str
The output node feature name.
"""
return
ReducerFunctionTemplate
(
"sum"
,
F
.
sum
,
msg
,
out
)
def
max
(
msg
,
out
):
"""Builtin reduce function that aggregates messages by max.
_python_max
=
max
Parameters
def
max
(
msgs
=
None
,
out
=
None
):
----------
return
ReducerFunctionTemplate
(
"max"
,
F
.
max
,
_python_max
,
msgs
,
out
)
msg : str
The message name.
out : str
The output node feature name.
"""
return
ReducerFunctionTemplate
(
"max"
,
F
.
max
,
msg
,
out
)
python/dgl/generator/__init__.py
deleted
100644 → 0
View file @
9d3f299d
"""Package for graph generators"""
from
__future__
import
absolute_import
from
.line
import
*
python/dgl/generator/line.py
deleted
100644 → 0
View file @
9d3f299d
"""Line graph generator."""
from
__future__
import
absolute_import
import
networkx
as
nx
import
numpy
as
np
from
..
import
backend
as
F
from
..graph
import
DGLGraph
from
..frame
import
FrameRef
def
line_graph
(
G
,
no_backtracking
=
False
):
"""Create the line graph that shares the underlying features.
The node features of the result line graph will share the edge features
of the given graph.
Parameters
----------
G : DGLGraph
The input graph.
no_backtracking : bool
Whether the backtracking edges are included in the line graph.
If i~j and j~i are two edges in original graph G, then
(i,j)~(j,i) and (j,i)~(i,j) are the "backtracking" edges on
the line graph.
"""
L
=
nx
.
DiGraph
()
for
eid
,
from_node
in
enumerate
(
G
.
edge_list
):
L
.
add_node
(
from_node
)
for
to_node
in
G
.
edges
(
from_node
[
1
]):
if
no_backtracking
and
to_node
[
1
]
==
from_node
[
0
]:
continue
L
.
add_edge
(
from_node
,
to_node
)
relabel_map
=
{}
for
i
,
e
in
enumerate
(
G
.
edge_list
):
relabel_map
[
e
]
=
i
nx
.
relabel
.
relabel_nodes
(
L
,
relabel_map
,
copy
=
False
)
return
DGLGraph
(
L
,
node_frame
=
G
.
_edge_frame
)
python/dgl/graph.py
View file @
9c135fd5
...
@@ -6,7 +6,7 @@ import networkx as nx
...
@@ -6,7 +6,7 @@ import networkx as nx
import
numpy
as
np
import
numpy
as
np
import
dgl
import
dgl
from
.base
import
ALL
,
is_all
,
__MSG__
,
__REPR__
from
.base
import
ALL
,
is_all
,
DGLError
,
dgl_warning
from
.
import
backend
as
F
from
.
import
backend
as
F
from
.backend
import
Tensor
from
.backend
import
Tensor
from
.frame
import
FrameRef
,
merge_frames
from
.frame
import
FrameRef
,
merge_frames
...
@@ -22,7 +22,6 @@ class DGLGraph(object):
...
@@ -22,7 +22,6 @@ class DGLGraph(object):
"""Base graph class specialized for neural networks on graphs.
"""Base graph class specialized for neural networks on graphs.
TODO(minjie): document of batching semantics
TODO(minjie): document of batching semantics
TODO(minjie): document of __REPR__ semantics
Parameters
Parameters
----------
----------
...
@@ -448,7 +447,9 @@ class DGLGraph(object):
...
@@ -448,7 +447,9 @@ class DGLGraph(object):
The nx graph
The nx graph
"""
"""
nx_graph
=
self
.
_graph
.
to_networkx
()
nx_graph
=
self
.
_graph
.
to_networkx
()
#TODO: attributes
#TODO(minjie): attributes
dgl_warning
(
'to_networkx currently does not support converting'
' node/edge features automatically.'
)
return
nx_graph
return
nx_graph
def
from_networkx
(
self
,
nx_graph
,
node_attrs
=
None
,
edge_attrs
=
None
):
def
from_networkx
(
self
,
nx_graph
,
node_attrs
=
None
,
edge_attrs
=
None
):
...
@@ -504,70 +505,95 @@ class DGLGraph(object):
...
@@ -504,70 +505,95 @@ class DGLGraph(object):
self
.
_msg_graph
.
add_nodes
(
self
.
_graph
.
number_of_nodes
())
self
.
_msg_graph
.
add_nodes
(
self
.
_graph
.
number_of_nodes
())
def
node_attr_schemes
(
self
):
def
node_attr_schemes
(
self
):
"""Return the node
attribut
e schemes.
"""Return the node
featur
e schemes.
Returns
Returns
-------
-------
iterable
dict of str to schemes
The s
et of attribute names
The s
chemes of node feature columns.
"""
"""
return
self
.
_node_frame
.
schemes
return
self
.
_node_frame
.
schemes
def
edge_attr_schemes
(
self
):
def
edge_attr_schemes
(
self
):
"""Return the edge
attribut
e schemes.
"""Return the edge
featur
e schemes.
Returns
Returns
-------
-------
iterable
dict of str to schemes
The s
et of attribute names
The s
chemes of edge feature columns.
"""
"""
return
self
.
_edge_frame
.
schemes
return
self
.
_edge_frame
.
schemes
def
set_n_initializer
(
self
,
initializer
):
"""Set the initializer for empty node features.
Initializer is a callable that returns a tensor given the shape and data type.
Parameters
----------
initializer : callable
The initializer.
"""
self
.
_node_frame
.
set_initializer
(
initializer
)
def
set_e_initializer
(
self
,
initializer
):
"""Set the initializer for empty edge features.
Initializer is a callable that returns a tensor given the shape and data type.
Parameters
----------
initializer : callable
The initializer.
"""
self
.
_edge_frame
.
set_initializer
(
initializer
)
def
set_n_repr
(
self
,
hu
,
u
=
ALL
,
inplace
=
False
):
def
set_n_repr
(
self
,
hu
,
u
=
ALL
,
inplace
=
False
):
"""Set node(s) representation.
"""Set node(s) representation.
To set multiple node representations at once, pass `u` with a
tensor
or
`hu` is a dictionary from the feature name to feature tensor. Each
tensor
a supported container of node ids. In this case, `hu` must be a tensor
is of shape (B, D1, D2, ...), where B is the number of nodes to be updated,
of shape (B, D1, D2, ...), where B is the number of the nodes and
and (D1, D2, ...) be the shape of the node representation tensor. The
(D1, D2, ...) is the shape of the node representation tensor
.
length of the given node ids must match B (i.e, len(u) == B)
.
Dictionary type is also supported for `hu`. In this case, each item
All update will be done out-placely to work with autograd unless the inplace
will be treated as separate attribute of the nodes
.
flag is true
.
Parameters
Parameters
----------
----------
hu :
tensor or
dict of tensor
hu : dict of tensor
Node representation.
Node representation.
u : node, container or tensor
u : node, container or tensor
The node(s).
The node(s).
inplace : bool
True if the update is done inplacely
"""
"""
# sanity check
# sanity check
if
not
utils
.
is_dict_like
(
hu
):
raise
DGLError
(
'Expect dictionary type for feature data.'
' Got "%s" instead.'
%
type
(
hu
))
if
is_all
(
u
):
if
is_all
(
u
):
num_nodes
=
self
.
number_of_nodes
()
num_nodes
=
self
.
number_of_nodes
()
else
:
else
:
u
=
utils
.
toindex
(
u
)
u
=
utils
.
toindex
(
u
)
num_nodes
=
len
(
u
)
num_nodes
=
len
(
u
)
if
utils
.
is_dict_like
(
hu
):
for
key
,
val
in
hu
.
items
():
for
key
,
val
in
hu
.
items
():
assert
F
.
shape
(
val
)[
0
]
==
num_nodes
nfeats
=
F
.
shape
(
val
)[
0
]
else
:
if
nfeats
!=
num_nodes
:
assert
F
.
shape
(
hu
)[
0
]
==
num_nodes
raise
DGLError
(
'Expect number of features to match number of nodes (len(u)).'
' Got %d and %d instead.'
%
(
nfeats
,
num_nodes
))
# set
# set
if
is_all
(
u
):
if
is_all
(
u
):
if
utils
.
is_dict_like
(
hu
):
for
key
,
val
in
hu
.
items
():
for
key
,
val
in
hu
.
items
():
self
.
_node_frame
[
key
]
=
val
self
.
_node_frame
[
key
]
=
val
else
:
else
:
self
.
_node_frame
[
__REPR__
]
=
hu
else
:
if
utils
.
is_dict_like
(
hu
):
self
.
_node_frame
.
update_rows
(
u
,
hu
,
inplace
=
inplace
)
self
.
_node_frame
.
update_rows
(
u
,
hu
,
inplace
=
inplace
)
else
:
self
.
_node_frame
.
update_rows
(
u
,
{
__REPR__
:
hu
},
inplace
=
inplace
)
def
get_n_repr
(
self
,
u
=
ALL
):
def
get_n_repr
(
self
,
u
=
ALL
):
"""Get node(s) representation.
"""Get node(s) representation.
The returned feature tensor batches multiple node features on the first dimension.
Parameters
Parameters
----------
----------
u : node, container or tensor
u : node, container or tensor
...
@@ -576,23 +602,17 @@ class DGLGraph(object):
...
@@ -576,23 +602,17 @@ class DGLGraph(object):
Returns
Returns
-------
-------
dict
dict
Representation dict
Representation dict
from feature name to feature tensor.
"""
"""
if
len
(
self
.
node_attr_schemes
())
==
0
:
if
len
(
self
.
node_attr_schemes
())
==
0
:
return
dict
()
return
dict
()
if
is_all
(
u
):
if
is_all
(
u
):
if
len
(
self
.
_node_frame
)
==
1
and
__REPR__
in
self
.
_node_frame
:
return
self
.
_node_frame
[
__REPR__
]
else
:
return
dict
(
self
.
_node_frame
)
return
dict
(
self
.
_node_frame
)
else
:
else
:
u
=
utils
.
toindex
(
u
)
u
=
utils
.
toindex
(
u
)
if
len
(
self
.
_node_frame
)
==
1
and
__REPR__
in
self
.
_node_frame
:
return
self
.
_node_frame
.
select_rows
(
u
)[
__REPR__
]
else
:
return
self
.
_node_frame
.
select_rows
(
u
)
return
self
.
_node_frame
.
select_rows
(
u
)
def
pop_n_repr
(
self
,
key
=
__REPR__
):
def
pop_n_repr
(
self
,
key
):
"""Get and remove the specified node repr.
"""Get and remove the specified node repr.
Parameters
Parameters
...
@@ -607,71 +627,83 @@ class DGLGraph(object):
...
@@ -607,71 +627,83 @@ class DGLGraph(object):
"""
"""
return
self
.
_node_frame
.
pop
(
key
)
return
self
.
_node_frame
.
pop
(
key
)
def
set_e_repr
(
self
,
h
_uv
,
u
=
ALL
,
v
=
ALL
):
def
set_e_repr
(
self
,
h
e
,
u
=
ALL
,
v
=
ALL
,
inplace
=
False
):
"""Set edge(s) representation.
"""Set edge(s) representation.
To set multiple edge representations at once, pass `u` and `v` with tensors or
`he` is a dictionary from the feature name to feature tensor. Each tensor
supported containers of node ids. In this case, `h_uv` must be a tensor
is of shape (B, D1, D2, ...), where B is the number of edges to be updated,
of shape (B, D1, D2, ...), where B is the number of the edges and
and (D1, D2, ...) be the shape of the edge representation tensor.
(D1, D2, ...) is the shape of the edge representation tensor.
Dictionary type is also supported for `h_uv`. In this case, each item
All update will be done out-placely to work with autograd unless the inplace
will be treated as separate attribute of the edges
.
flag is true
.
Parameters
Parameters
----------
----------
h
_uv
: tensor or dict of tensor
h
e
: tensor or dict of tensor
Edge representation.
Edge representation.
u : node, container or tensor
u : node, container or tensor
The source node(s).
The source node(s).
v : node, container or tensor
v : node, container or tensor
The destination node(s).
The destination node(s).
inplace : bool
True if the update is done inplacely
"""
"""
# sanity check
# sanity check
if
not
utils
.
is_dict_like
(
he
):
raise
DGLError
(
'Expect dictionary type for feature data.'
' Got "%s" instead.'
%
type
(
he
))
u_is_all
=
is_all
(
u
)
u_is_all
=
is_all
(
u
)
v_is_all
=
is_all
(
v
)
v_is_all
=
is_all
(
v
)
assert
u_is_all
==
v_is_all
assert
u_is_all
==
v_is_all
if
u_is_all
:
if
u_is_all
:
self
.
set_e_repr_by_id
(
h
_uv
,
eid
=
ALL
)
self
.
set_e_repr_by_id
(
h
e
,
eid
=
ALL
,
inplace
=
inplace
)
else
:
else
:
u
=
utils
.
toindex
(
u
)
u
=
utils
.
toindex
(
u
)
v
=
utils
.
toindex
(
v
)
v
=
utils
.
toindex
(
v
)
_
,
_
,
eid
=
self
.
_graph
.
edge_ids
(
u
,
v
)
_
,
_
,
eid
=
self
.
_graph
.
edge_ids
(
u
,
v
)
self
.
set_e_repr_by_id
(
h
_uv
,
eid
=
eid
)
self
.
set_e_repr_by_id
(
h
e
,
eid
=
eid
,
inplace
=
inplace
)
def
set_e_repr_by_id
(
self
,
h
_uv
,
eid
=
ALL
):
def
set_e_repr_by_id
(
self
,
h
e
,
eid
=
ALL
,
inplace
=
False
):
"""Set edge(s) representation by edge id.
"""Set edge(s) representation by edge id.
`he` is a dictionary from the feature name to feature tensor. Each tensor
is of shape (B, D1, D2, ...), where B is the number of edges to be updated,
and (D1, D2, ...) be the shape of the edge representation tensor.
All update will be done out-placely to work with autograd unless the inplace
flag is true.
Parameters
Parameters
----------
----------
h
_uv
: tensor or dict of tensor
h
e
: tensor or dict of tensor
Edge representation.
Edge representation.
eid : int, container or tensor
eid : int, container or tensor
The edge id(s).
The edge id(s).
inplace : bool
True if the update is done inplacely
"""
"""
# sanity check
# sanity check
if
not
utils
.
is_dict_like
(
he
):
raise
DGLError
(
'Expect dictionary type for feature data.'
' Got "%s" instead.'
%
type
(
he
))
if
is_all
(
eid
):
if
is_all
(
eid
):
num_edges
=
self
.
number_of_edges
()
num_edges
=
self
.
number_of_edges
()
else
:
else
:
eid
=
utils
.
toindex
(
eid
)
eid
=
utils
.
toindex
(
eid
)
num_edges
=
len
(
eid
)
num_edges
=
len
(
eid
)
if
utils
.
is_dict_like
(
h_uv
):
for
key
,
val
in
he
.
items
(
):
for
key
,
val
in
h_uv
.
items
():
nfeats
=
F
.
shape
(
val
)[
0
]
assert
F
.
shape
(
val
)[
0
]
=
=
num_edges
if
nfeats
!
=
num_edges
:
else
:
raise
DGLError
(
'Expect number of features to match number of edges.'
assert
F
.
shape
(
h_uv
)[
0
]
==
num_edges
' Got %d and %d instead.'
%
(
nfeats
,
num_edges
))
# set
# set
if
is_all
(
eid
):
if
is_all
(
eid
):
if
utils
.
is_dict_like
(
h_uv
):
# update column
for
key
,
val
in
h
_uv
.
items
():
for
key
,
val
in
h
e
.
items
():
self
.
_edge_frame
[
key
]
=
val
self
.
_edge_frame
[
key
]
=
val
else
:
else
:
self
.
_edge_frame
[
__REPR__
]
=
h_uv
# update row
else
:
self
.
_edge_frame
.
update_rows
(
eid
,
he
,
inplace
=
inplace
)
if
utils
.
is_dict_like
(
h_uv
):
self
.
_edge_frame
[
eid
]
=
h_uv
else
:
self
.
_edge_frame
[
eid
]
=
{
__REPR__
:
h_uv
}
def
get_e_repr
(
self
,
u
=
ALL
,
v
=
ALL
):
def
get_e_repr
(
self
,
u
=
ALL
,
v
=
ALL
):
"""Get node(s) representation.
"""Get node(s) representation.
...
@@ -701,7 +733,7 @@ class DGLGraph(object):
...
@@ -701,7 +733,7 @@ class DGLGraph(object):
_
,
_
,
eid
=
self
.
_graph
.
edge_ids
(
u
,
v
)
_
,
_
,
eid
=
self
.
_graph
.
edge_ids
(
u
,
v
)
return
self
.
get_e_repr_by_id
(
eid
=
eid
)
return
self
.
get_e_repr_by_id
(
eid
=
eid
)
def
pop_e_repr
(
self
,
key
=
__REPR__
):
def
pop_e_repr
(
self
,
key
):
"""Get and remove the specified edge repr.
"""Get and remove the specified edge repr.
Parameters
Parameters
...
@@ -727,20 +759,14 @@ class DGLGraph(object):
...
@@ -727,20 +759,14 @@ class DGLGraph(object):
Returns
Returns
-------
-------
dict
dict
Representation dict
Representation dict
from feature name to feature tensor.
"""
"""
if
len
(
self
.
edge_attr_schemes
())
==
0
:
if
len
(
self
.
edge_attr_schemes
())
==
0
:
return
dict
()
return
dict
()
if
is_all
(
eid
):
if
is_all
(
eid
):
if
len
(
self
.
_edge_frame
)
==
1
and
__REPR__
in
self
.
_edge_frame
:
return
self
.
_edge_frame
[
__REPR__
]
else
:
return
dict
(
self
.
_edge_frame
)
return
dict
(
self
.
_edge_frame
)
else
:
else
:
eid
=
utils
.
toindex
(
eid
)
eid
=
utils
.
toindex
(
eid
)
if
len
(
self
.
_edge_frame
)
==
1
and
__REPR__
in
self
.
_edge_frame
:
return
self
.
_edge_frame
.
select_rows
(
eid
)[
__REPR__
]
else
:
return
self
.
_edge_frame
.
select_rows
(
eid
)
return
self
.
_edge_frame
.
select_rows
(
eid
)
def
register_edge_func
(
self
,
edge_func
):
def
register_edge_func
(
self
,
edge_func
):
...
@@ -793,12 +819,14 @@ class DGLGraph(object):
...
@@ -793,12 +819,14 @@ class DGLGraph(object):
"""
"""
self
.
_apply_edge_func
=
apply_edge_func
self
.
_apply_edge_func
=
apply_edge_func
def
apply_nodes
(
self
,
v
,
apply_node_func
=
"default"
):
def
apply_nodes
(
self
,
v
=
ALL
,
apply_node_func
=
"default"
):
"""Apply the function on node representations.
"""Apply the function on node representations.
Applying a None function will be ignored.
Parameters
Parameters
----------
----------
v : int, iterable of int, tensor
v : int, iterable of int, tensor
, optional
The node id(s).
The node id(s).
apply_node_func : callable
apply_node_func : callable
The apply node function.
The apply node function.
...
@@ -827,7 +855,7 @@ class DGLGraph(object):
...
@@ -827,7 +855,7 @@ class DGLGraph(object):
# merge current node_repr with reduce output
# merge current node_repr with reduce output
curr_repr
=
utils
.
HybridDict
(
reduce_accum
,
curr_repr
)
curr_repr
=
utils
.
HybridDict
(
reduce_accum
,
curr_repr
)
new_repr
=
apply_node_func
(
curr_repr
)
new_repr
=
apply_node_func
(
curr_repr
)
if
reduce_accum
is
not
None
and
utils
.
is_dict_like
(
new_repr
)
:
if
reduce_accum
is
not
None
:
# merge new node_repr with reduce output
# merge new node_repr with reduce output
reduce_accum
.
update
(
new_repr
)
reduce_accum
.
update
(
new_repr
)
new_repr
=
reduce_accum
new_repr
=
reduce_accum
...
@@ -836,6 +864,8 @@ class DGLGraph(object):
...
@@ -836,6 +864,8 @@ class DGLGraph(object):
def
apply_edges
(
self
,
u
=
None
,
v
=
None
,
apply_edge_func
=
"default"
,
eid
=
None
):
def
apply_edges
(
self
,
u
=
None
,
v
=
None
,
apply_edge_func
=
"default"
,
eid
=
None
):
"""Apply the function on edge representations.
"""Apply the function on edge representations.
Applying a None function will be ignored.
Parameters
Parameters
----------
----------
u : optional, int, iterable of int, tensor
u : optional, int, iterable of int, tensor
...
@@ -852,7 +882,6 @@ class DGLGraph(object):
...
@@ -852,7 +882,6 @@ class DGLGraph(object):
if
not
apply_edge_func
:
if
not
apply_edge_func
:
# Skip none function call.
# Skip none function call.
return
return
if
eid
is
None
:
if
eid
is
None
:
new_repr
=
apply_edge_func
(
self
.
get_e_repr
(
u
,
v
))
new_repr
=
apply_edge_func
(
self
.
get_e_repr
(
u
,
v
))
self
.
set_e_repr
(
new_repr
,
u
,
v
)
self
.
set_e_repr
(
new_repr
,
u
,
v
)
...
@@ -873,9 +902,8 @@ class DGLGraph(object):
...
@@ -873,9 +902,8 @@ class DGLGraph(object):
The message function can be any of the pre-defined functions
The message function can be any of the pre-defined functions
('from_src').
('from_src').
Currently, we require the message functions of consecutive send's and
Currently, we require the message functions of consecutive send's to
send_on's to return the same keys. Otherwise the behavior will be
return the same keys. Otherwise the behavior will be undefined.
undefined.
Parameters
Parameters
----------
----------
...
@@ -922,7 +950,11 @@ class DGLGraph(object):
...
@@ -922,7 +950,11 @@ class DGLGraph(object):
src_reprs
=
self
.
get_n_repr
(
u
)
src_reprs
=
self
.
get_n_repr
(
u
)
edge_reprs
=
self
.
get_e_repr_by_id
(
eid
)
edge_reprs
=
self
.
get_e_repr_by_id
(
eid
)
msgs
=
message_func
(
src_reprs
,
edge_reprs
)
msgs
=
message_func
(
src_reprs
,
edge_reprs
)
self
.
_msg_graph
.
add_edges
(
u
,
v
)
self
.
_msg_frame
.
append
(
msgs
)
# TODO(minjie): Fix these codes in next PR.
"""
new_uv = []
new_uv = []
msg_target_rows = []
msg_target_rows = []
msg_update_rows = []
msg_update_rows = []
...
@@ -945,8 +977,8 @@ class DGLGraph(object):
...
@@ -945,8 +977,8 @@ class DGLGraph(object):
self._msg_frame.update_rows(
self._msg_frame.update_rows(
msg_target_rows,
msg_target_rows,
{k: F.gather_row(msgs[k], msg_update_rows.tousertensor())
{k: F.gather_row(msgs[k], msg_update_rows.tousertensor())
for
k
in
msgs
}
for k in msgs}
,
)
inplace=False
)
if len(msg_append_rows) > 0:
if len(msg_append_rows) > 0:
new_u, new_v = zip(*new_uv)
new_u, new_v = zip(*new_uv)
new_u = utils.toindex(new_u)
new_u = utils.toindex(new_u)
...
@@ -954,14 +986,13 @@ class DGLGraph(object):
...
@@ -954,14 +986,13 @@ class DGLGraph(object):
self._msg_graph.add_edges(new_u, new_v)
self._msg_graph.add_edges(new_u, new_v)
self._msg_frame.append(
self._msg_frame.append(
{k: F.gather_row(msgs[k], msg_append_rows.tousertensor())
{k: F.gather_row(msgs[k], msg_append_rows.tousertensor())
for
k
in
msgs
}
for k in msgs})
)
else:
else:
if len(msg_target_rows) > 0:
if len(msg_target_rows) > 0:
self._msg_frame.update_rows(
self._msg_frame.update_rows(
msg_target_rows,
msg_target_rows,
{
__MSG__
:
F
.
gather_row
(
msgs
,
msg_update_rows
.
tousertensor
())}
{__MSG__: F.gather_row(msgs, msg_update_rows.tousertensor())}
,
)
inplace=False
)
if len(msg_append_rows) > 0:
if len(msg_append_rows) > 0:
new_u, new_v = zip(*new_uv)
new_u, new_v = zip(*new_uv)
new_u = utils.toindex(new_u)
new_u = utils.toindex(new_u)
...
@@ -970,6 +1001,7 @@ class DGLGraph(object):
...
@@ -970,6 +1001,7 @@ class DGLGraph(object):
self._msg_frame.append(
self._msg_frame.append(
{__MSG__: F.gather_row(msgs, msg_append_rows.tousertensor())}
{__MSG__: F.gather_row(msgs, msg_append_rows.tousertensor())}
)
)
"""
def
update_edge
(
self
,
u
=
ALL
,
v
=
ALL
,
edge_func
=
"default"
,
eid
=
None
):
def
update_edge
(
self
,
u
=
ALL
,
v
=
ALL
,
edge_func
=
"default"
,
eid
=
None
):
"""Update representation on edge u->v
"""Update representation on edge u->v
...
@@ -1013,7 +1045,6 @@ class DGLGraph(object):
...
@@ -1013,7 +1045,6 @@ class DGLGraph(object):
v
=
utils
.
toindex
(
v
)
v
=
utils
.
toindex
(
v
)
u
,
v
=
utils
.
edge_broadcasting
(
u
,
v
)
u
,
v
=
utils
.
edge_broadcasting
(
u
,
v
)
_
,
_
,
eid
=
self
.
_graph
.
edge_ids
(
u
,
v
)
_
,
_
,
eid
=
self
.
_graph
.
edge_ids
(
u
,
v
)
# call the UDF
# call the UDF
src_reprs
=
self
.
get_n_repr
(
u
)
src_reprs
=
self
.
get_n_repr
(
u
)
dst_reprs
=
self
.
get_n_repr
(
v
)
dst_reprs
=
self
.
get_n_repr
(
v
)
...
@@ -1100,25 +1131,19 @@ class DGLGraph(object):
...
@@ -1100,25 +1131,19 @@ class DGLGraph(object):
msg_shape
=
F
.
shape
(
msg
)
msg_shape
=
F
.
shape
(
msg
)
new_shape
=
(
bkt_len
,
deg
)
+
msg_shape
[
1
:]
new_shape
=
(
bkt_len
,
deg
)
+
msg_shape
[
1
:]
return
F
.
reshape
(
msg
,
new_shape
)
return
F
.
reshape
(
msg
,
new_shape
)
if
len
(
in_msgs
)
==
1
and
__MSG__
in
in_msgs
:
reshaped_in_msgs
=
_reshape_fn
(
in_msgs
[
__MSG__
])
else
:
reshaped_in_msgs
=
utils
.
LazyDict
(
reshaped_in_msgs
=
utils
.
LazyDict
(
lambda
key
:
_reshape_fn
(
in_msgs
[
key
]),
self
.
_msg_frame
.
schemes
)
lambda
key
:
_reshape_fn
(
in_msgs
[
key
]),
self
.
_msg_frame
.
schemes
)
reordered_v
.
append
(
v_bkt
.
tousertensor
())
reordered_v
.
append
(
v_bkt
.
tousertensor
())
new_reprs
.
append
(
reduce_func
(
dst_reprs
,
reshaped_in_msgs
))
new_reprs
.
append
(
reduce_func
(
dst_reprs
,
reshaped_in_msgs
))
# TODO: clear partial messages
# TODO
(minjie)
: clear partial messages
self
.
reset_messages
()
self
.
reset_messages
()
# Pack all reducer results together
# Pack all reducer results together
reordered_v
=
F
.
pack
(
reordered_v
)
reordered_v
=
F
.
pack
(
reordered_v
)
if
utils
.
is_dict_like
(
new_reprs
[
0
]):
keys
=
new_reprs
[
0
].
keys
()
keys
=
new_reprs
[
0
].
keys
()
new_reprs
=
{
key
:
F
.
pack
([
repr
[
key
]
for
repr
in
new_reprs
])
new_reprs
=
{
key
:
F
.
pack
([
repr
[
key
]
for
repr
in
new_reprs
])
for
key
in
keys
}
for
key
in
keys
}
else
:
new_reprs
=
{
__REPR__
:
F
.
pack
(
new_reprs
)}
if
v_is_all
and
not
has_zero_degree
:
if
v_is_all
and
not
has_zero_degree
:
# First do reorder and then replace the whole column.
# First do reorder and then replace the whole column.
...
@@ -1189,15 +1214,13 @@ class DGLGraph(object):
...
@@ -1189,15 +1214,13 @@ class DGLGraph(object):
if
executor
:
if
executor
:
new_reprs
=
executor
.
run
()
new_reprs
=
executor
.
run
()
if
not
utils
.
is_dict_like
(
new_reprs
):
new_reprs
=
{
__REPR__
:
new_reprs
}
unique_v
=
executor
.
recv_nodes
unique_v
=
executor
.
recv_nodes
self
.
_apply_nodes
(
unique_v
,
apply_node_func
,
reduce_accum
=
new_reprs
)
self
.
_apply_nodes
(
unique_v
,
apply_node_func
,
reduce_accum
=
new_reprs
)
elif
eid
is
not
None
:
elif
eid
is
not
None
:
_
,
v
,
_
=
self
.
_graph
.
find_edges
(
eid
)
_
,
v
,
_
=
self
.
_graph
.
find_edges
(
eid
)
unique_v
=
utils
.
toindex
(
F
.
unique
(
v
.
tousertensor
()))
unique_v
=
utils
.
toindex
(
F
.
unique
(
v
.
tousertensor
()))
# TODO: replace with the new DegreeBucketingScheduler
# TODO
(quan)
: replace with the new DegreeBucketingScheduler
self
.
send
(
eid
=
eid
,
message_func
=
message_func
)
self
.
send
(
eid
=
eid
,
message_func
=
message_func
)
self
.
recv
(
unique_v
,
reduce_func
,
apply_node_func
)
self
.
recv
(
unique_v
,
reduce_func
,
apply_node_func
)
else
:
else
:
...
@@ -1213,10 +1236,7 @@ class DGLGraph(object):
...
@@ -1213,10 +1236,7 @@ class DGLGraph(object):
edge_reprs
=
self
.
get_e_repr
(
u
,
v
)
edge_reprs
=
self
.
get_e_repr
(
u
,
v
)
msgs
=
message_func
(
src_reprs
,
edge_reprs
)
msgs
=
message_func
(
src_reprs
,
edge_reprs
)
msg_frame
=
FrameRef
()
msg_frame
=
FrameRef
()
if
utils
.
is_dict_like
(
msgs
):
msg_frame
.
append
(
msgs
)
msg_frame
.
append
(
msgs
)
else
:
msg_frame
.
append
({
__MSG__
:
msgs
})
# recv with degree bucketing
# recv with degree bucketing
executor
=
scheduler
.
get_recv_executor
(
graph
=
self
,
executor
=
scheduler
.
get_recv_executor
(
graph
=
self
,
...
@@ -1305,8 +1325,6 @@ class DGLGraph(object):
...
@@ -1305,8 +1325,6 @@ class DGLGraph(object):
"update_all"
,
self
,
message_func
=
message_func
,
reduce_func
=
reduce_func
)
"update_all"
,
self
,
message_func
=
message_func
,
reduce_func
=
reduce_func
)
if
executor
:
if
executor
:
new_reprs
=
executor
.
run
()
new_reprs
=
executor
.
run
()
if
not
utils
.
is_dict_like
(
new_reprs
):
new_reprs
=
{
__REPR__
:
new_reprs
}
self
.
_apply_nodes
(
ALL
,
apply_node_func
,
reduce_accum
=
new_reprs
)
self
.
_apply_nodes
(
ALL
,
apply_node_func
,
reduce_accum
=
new_reprs
)
else
:
else
:
self
.
send
(
ALL
,
ALL
,
message_func
)
self
.
send
(
ALL
,
ALL
,
message_func
)
...
@@ -1339,7 +1357,7 @@ class DGLGraph(object):
...
@@ -1339,7 +1357,7 @@ class DGLGraph(object):
Arguments for pre-defined iterators.
Arguments for pre-defined iterators.
"""
"""
if
isinstance
(
traverser
,
str
):
if
isinstance
(
traverser
,
str
):
# TODO Call pre-defined routine to unroll the computation.
# TODO
(minjie):
Call pre-defined routine to unroll the computation.
raise
RuntimeError
(
'Not implemented.'
)
raise
RuntimeError
(
'Not implemented.'
)
else
:
else
:
# NOTE: the iteration can return multiple edges at each step.
# NOTE: the iteration can return multiple edges at each step.
...
...
python/dgl/graph_index.py
View file @
9c135fd5
...
@@ -3,7 +3,7 @@ from __future__ import absolute_import
...
@@ -3,7 +3,7 @@ from __future__ import absolute_import
import
ctypes
import
ctypes
import
numpy
as
np
import
numpy
as
np
import
networkx
as
nx
import
networkx
as
nx
import
scipy
.sparse
as
sp
import
scipy
from
._ffi.base
import
c_array
from
._ffi.base
import
c_array
from
._ffi.function
import
_init_api
from
._ffi.function
import
_init_api
...
@@ -600,30 +600,59 @@ class GraphIndex(object):
...
@@ -600,30 +600,59 @@ class GraphIndex(object):
return
GraphIndex
(
handle
)
return
GraphIndex
(
handle
)
class
SubgraphIndex
(
GraphIndex
):
class
SubgraphIndex
(
GraphIndex
):
def
__init__
(
self
,
handle
,
parent
,
induced_nodes
,
induced_edges
):
"""Graph index for subgraph.
super
().
__init__
(
handle
)
Parameters
----------
handle : GraphIndexHandle
The capi handle.
paranet : GraphIndex
The parent graph index.
induced_nodes : utils.Index
The parent node ids in this subgraph.
induced_edges : utils.Index
The parent edge ids in this subgraph.
"""
def
__init__
(
self
,
handle
,
parent
,
induced_nodes
,
induced_edges
):
super
(
SubgraphIndex
,
self
).
__init__
(
handle
)
self
.
_parent
=
parent
self
.
_parent
=
parent
self
.
_induced_nodes
=
induced_nodes
self
.
_induced_nodes
=
induced_nodes
self
.
_induced_edges
=
induced_edges
self
.
_induced_edges
=
induced_edges
def
add_nodes
(
self
,
num
):
def
add_nodes
(
self
,
num
):
"""Add nodes. Disabled because SubgraphIndex is read-only."""
raise
RuntimeError
(
'Readonly graph. Mutation is not allowed.'
)
raise
RuntimeError
(
'Readonly graph. Mutation is not allowed.'
)
def
add_edge
(
self
,
u
,
v
):
def
add_edge
(
self
,
u
,
v
):
"""Add edges. Disabled because SubgraphIndex is read-only."""
raise
RuntimeError
(
'Readonly graph. Mutation is not allowed.'
)
raise
RuntimeError
(
'Readonly graph. Mutation is not allowed.'
)
def
add_edges
(
self
,
u
,
v
):
def
add_edges
(
self
,
u
,
v
):
"""Add edges. Disabled because SubgraphIndex is read-only."""
raise
RuntimeError
(
'Readonly graph. Mutation is not allowed.'
)
raise
RuntimeError
(
'Readonly graph. Mutation is not allowed.'
)
@
property
def
induced_edges
(
self
):
return
self
.
_induced_edges
@
property
@
property
def
induced_nodes
(
self
):
def
induced_nodes
(
self
):
"""Return parent node ids.
Returns
-------
utils.Index
The parent node ids.
"""
return
self
.
_induced_nodes
return
self
.
_induced_nodes
@
property
def
induced_edges
(
self
):
"""Return parent edge ids.
Returns
-------
utils.Index
The parent edge ids.
"""
return
self
.
_induced_edges
def
disjoint_union
(
graphs
):
def
disjoint_union
(
graphs
):
"""Return a disjoint union of the input graphs.
"""Return a disjoint union of the input graphs.
...
@@ -697,8 +726,25 @@ def create_graph_index(graph_data=None, multigraph=False):
...
@@ -697,8 +726,25 @@ def create_graph_index(graph_data=None, multigraph=False):
handle
=
_CAPI_DGLGraphCreate
(
multigraph
)
handle
=
_CAPI_DGLGraphCreate
(
multigraph
)
gi
=
GraphIndex
(
handle
)
gi
=
GraphIndex
(
handle
)
if
graph_data
is
not
None
:
if
graph_data
is
None
:
return
gi
# scipy format
if
isinstance
(
graph_data
,
scipy
.
sparse
.
spmatrix
):
try
:
gi
.
from_scipy_sparse_matrix
(
graph_data
)
return
gi
except
:
raise
Exception
(
'Graph data is not a valid scipy sparse matrix.'
)
# networkx - any format
try
:
gi
.
from_networkx
(
graph_data
)
gi
.
from_networkx
(
graph_data
)
except
:
raise
Exception
(
'Error while creating graph from input of type "%s".'
%
type
(
graph_data
))
return
gi
return
gi
_init_api
(
"dgl.graph_index"
)
_init_api
(
"dgl.graph_index"
)
python/dgl/scheduler.py
View file @
9c135fd5
...
@@ -3,7 +3,7 @@ from __future__ import absolute_import
...
@@ -3,7 +3,7 @@ from __future__ import absolute_import
import
numpy
as
np
import
numpy
as
np
from
.base
import
ALL
,
__MSG__
,
__REPR__
from
.base
import
ALL
,
DGLError
from
.
import
backend
as
F
from
.
import
backend
as
F
from
.function
import
message
as
fmsg
from
.function
import
message
as
fmsg
from
.function
import
reducer
as
fred
from
.function
import
reducer
as
fred
...
@@ -111,7 +111,15 @@ def light_degree_bucketing_for_graph(graph):
...
@@ -111,7 +111,15 @@ def light_degree_bucketing_for_graph(graph):
class
Executor
(
object
):
class
Executor
(
object
):
"""Base class for executing graph computation."""
def
run
(
self
):
def
run
(
self
):
"""Run this executor.
This should return the new node features.
TODO(minjie): extend this to support computation on edges.
"""
raise
NotImplementedError
raise
NotImplementedError
class
SPMVOperator
(
Executor
):
class
SPMVOperator
(
Executor
):
...
@@ -126,9 +134,6 @@ class SPMVOperator(Executor):
...
@@ -126,9 +134,6 @@ class SPMVOperator(Executor):
def
run
(
self
):
def
run
(
self
):
# get src col
# get src col
if
self
.
src_field
is
None
:
srccol
=
self
.
node_repr
else
:
srccol
=
self
.
node_repr
[
self
.
src_field
]
srccol
=
self
.
node_repr
[
self
.
src_field
]
ctx
=
F
.
get_context
(
srccol
)
ctx
=
F
.
get_context
(
srccol
)
...
@@ -142,9 +147,6 @@ class SPMVOperator(Executor):
...
@@ -142,9 +147,6 @@ class SPMVOperator(Executor):
dstcol
=
F
.
squeeze
(
dstcol
)
dstcol
=
F
.
squeeze
(
dstcol
)
else
:
else
:
dstcol
=
F
.
spmm
(
adjmat
,
srccol
)
dstcol
=
F
.
spmm
(
adjmat
,
srccol
)
if
self
.
dst_field
is
None
:
return
dstcol
else
:
return
{
self
.
dst_field
:
dstcol
}
return
{
self
.
dst_field
:
dstcol
}
...
@@ -180,20 +182,14 @@ class DegreeBucketingExecutor(Executor):
...
@@ -180,20 +182,14 @@ class DegreeBucketingExecutor(Executor):
msg_shape
=
F
.
shape
(
msg
)
msg_shape
=
F
.
shape
(
msg
)
new_shape
=
(
len
(
vv
),
deg
)
+
msg_shape
[
1
:]
new_shape
=
(
len
(
vv
),
deg
)
+
msg_shape
[
1
:]
return
F
.
reshape
(
msg
,
new_shape
)
return
F
.
reshape
(
msg
,
new_shape
)
if
len
(
in_msgs
)
==
1
and
__MSG__
in
in_msgs
:
reshaped_in_msgs
=
_reshape_fn
(
in_msgs
[
__MSG__
])
else
:
reshaped_in_msgs
=
utils
.
LazyDict
(
reshaped_in_msgs
=
utils
.
LazyDict
(
lambda
key
:
_reshape_fn
(
in_msgs
[
key
]),
self
.
msg_frame
.
schemes
)
lambda
key
:
_reshape_fn
(
in_msgs
[
key
]),
self
.
msg_frame
.
schemes
)
new_reprs
.
append
(
self
.
rfunc
(
dst_reprs
,
reshaped_in_msgs
))
new_reprs
.
append
(
self
.
rfunc
(
dst_reprs
,
reshaped_in_msgs
))
# Pack all reducer results together
# Pack all reducer results together
if
utils
.
is_dict_like
(
new_reprs
[
0
]):
keys
=
new_reprs
[
0
].
keys
()
keys
=
new_reprs
[
0
].
keys
()
new_reprs
=
{
key
:
F
.
pack
([
repr
[
key
]
for
repr
in
new_reprs
])
new_reprs
=
{
key
:
F
.
pack
([
repr
[
key
]
for
repr
in
new_reprs
])
for
key
in
keys
}
for
key
in
keys
}
else
:
new_reprs
=
{
__REPR__
:
F
.
pack
(
new_reprs
)}
return
new_reprs
return
new_reprs
...
@@ -249,12 +245,6 @@ class UpdateAllExecutor(BasicExecutor):
...
@@ -249,12 +245,6 @@ class UpdateAllExecutor(BasicExecutor):
self
.
_graph_shape
=
None
self
.
_graph_shape
=
None
self
.
_recv_nodes
=
None
self
.
_recv_nodes
=
None
@
property
def
graph_idx
(
self
):
if
self
.
_graph_idx
is
None
:
self
.
_graph_idx
=
self
.
g
.
_graph
.
adjacency_matrix
()
return
self
.
_graph_idx
@
property
@
property
def
graph_shape
(
self
):
def
graph_shape
(
self
):
if
self
.
_graph_shape
is
None
:
if
self
.
_graph_shape
is
None
:
...
@@ -280,16 +270,13 @@ class UpdateAllExecutor(BasicExecutor):
...
@@ -280,16 +270,13 @@ class UpdateAllExecutor(BasicExecutor):
def
_adj_build_fn
(
self
,
edge_field
,
ctx
,
use_edge_feat
):
def
_adj_build_fn
(
self
,
edge_field
,
ctx
,
use_edge_feat
):
if
use_edge_feat
:
if
use_edge_feat
:
if
edge_field
is
None
:
dat
=
self
.
edge_repr
else
:
dat
=
self
.
edge_repr
[
edge_field
]
dat
=
self
.
edge_repr
[
edge_field
]
dat
=
F
.
squeeze
(
dat
)
dat
=
F
.
squeeze
(
dat
)
# TODO(minjie): should not directly use _indices
# TODO(minjie): should not directly use _indices
idx
=
self
.
g
raph_idx
.
get
(
ctx
).
_indices
()
idx
=
self
.
g
.
adjacency_matrix
(
ctx
).
_indices
()
adjmat
=
F
.
sparse_tensor
(
idx
,
dat
,
self
.
graph_shape
)
adjmat
=
F
.
sparse_tensor
(
idx
,
dat
,
self
.
graph_shape
)
else
:
else
:
adjmat
=
self
.
g
raph_idx
.
get
(
ctx
)
adjmat
=
self
.
g
.
adjacency_matrix
(
ctx
)
return
adjmat
return
adjmat
...
@@ -351,9 +338,6 @@ class SendRecvExecutor(BasicExecutor):
...
@@ -351,9 +338,6 @@ class SendRecvExecutor(BasicExecutor):
def
_adj_build_fn
(
self
,
edge_field
,
ctx
,
use_edge_feat
):
def
_adj_build_fn
(
self
,
edge_field
,
ctx
,
use_edge_feat
):
if
use_edge_feat
:
if
use_edge_feat
:
if
edge_field
is
None
:
dat
=
self
.
edge_repr
else
:
dat
=
self
.
edge_repr
[
edge_field
]
dat
=
self
.
edge_repr
[
edge_field
]
dat
=
F
.
squeeze
(
dat
)
dat
=
F
.
squeeze
(
dat
)
else
:
else
:
...
@@ -386,9 +370,8 @@ class BundledExecutor(BasicExecutor):
...
@@ -386,9 +370,8 @@ class BundledExecutor(BasicExecutor):
func_pairs
=
[]
func_pairs
=
[]
for
rfn
in
rfunc
.
fn_list
:
for
rfn
in
rfunc
.
fn_list
:
mfn
=
out2mfunc
.
get
(
rfn
.
msg_field
,
None
)
mfn
=
out2mfunc
.
get
(
rfn
.
msg_field
,
None
)
# field check
if
mfn
is
None
:
assert
mfn
is
not
None
,
\
raise
DGLError
(
'Cannot find message field "%s".'
%
rfn
.
msg_field
)
"cannot find message func for reduce func in-field {}"
.
format
(
rfn
.
msg_field
)
func_pairs
.
append
((
mfn
,
rfn
))
func_pairs
.
append
((
mfn
,
rfn
))
return
func_pairs
return
func_pairs
...
@@ -409,7 +392,6 @@ class BundledUpdateAllExecutor(BundledExecutor, UpdateAllExecutor):
...
@@ -409,7 +392,6 @@ class BundledUpdateAllExecutor(BundledExecutor, UpdateAllExecutor):
self
.
_init_state
()
self
.
_init_state
()
BundledExecutor
.
__init__
(
self
,
graph
,
mfunc
,
rfunc
)
BundledExecutor
.
__init__
(
self
,
graph
,
mfunc
,
rfunc
)
class
BundledSendRecvExecutor
(
BundledExecutor
,
SendRecvExecutor
):
class
BundledSendRecvExecutor
(
BundledExecutor
,
SendRecvExecutor
):
def
__init__
(
self
,
graph
,
src
,
dst
,
mfunc
,
rfunc
):
def
__init__
(
self
,
graph
,
src
,
dst
,
mfunc
,
rfunc
):
self
.
_init_state
(
src
,
dst
)
self
.
_init_state
(
src
,
dst
)
...
...
src/c_api_common.cc
View file @
9c135fd5
/*!
* Copyright (c) 2018 by Contributors
* \file c_runtime_api.cc
* \brief DGL C API common implementations
*/
#include "c_api_common.h"
#include "c_api_common.h"
using
tvm
::
runtime
::
TVMArgs
;
using
tvm
::
runtime
::
TVMArgs
;
...
...
src/c_api_common.h
View file @
9c135fd5
// DGL C API common util functions
/*!
* Copyright (c) 2018 by Contributors
* \file c_api_common.h
* \brief DGL C API common util functions
*/
#ifndef DGL_C_API_COMMON_H_
#ifndef DGL_C_API_COMMON_H_
#define DGL_C_API_COMMON_H_
#define DGL_C_API_COMMON_H_
...
@@ -12,11 +16,19 @@ namespace dgl {
...
@@ -12,11 +16,19 @@ namespace dgl {
// Graph handler type
// Graph handler type
typedef
void
*
GraphHandle
;
typedef
void
*
GraphHandle
;
// Convert the given DLTensor to a temporary DLManagedTensor that does not own memory.
/*!
DLManagedTensor
*
CreateTmpDLManagedTensor
(
const
tvm
::
runtime
::
TVMArgValue
&
arg
);
* \brief Convert the given DLTensor to DLManagedTensor.
*
* Return a temporary DLManagedTensor that does not own memory.
*/
DLManagedTensor
*
CreateTmpDLManagedTensor
(
const
tvm
::
runtime
::
TVMArgValue
&
arg
);
// Convert a vector of NDArray to PackedFunc
/*!
tvm
::
runtime
::
PackedFunc
ConvertNDArrayVectorToPackedFunc
(
const
std
::
vector
<
tvm
::
runtime
::
NDArray
>&
vec
);
* \brief Convert a vector of NDArray to PackedFunc.
*/
tvm
::
runtime
::
PackedFunc
ConvertNDArrayVectorToPackedFunc
(
const
std
::
vector
<
tvm
::
runtime
::
NDArray
>&
vec
);
}
// namespace dgl
}
// namespace dgl
...
...
src/graph/graph.cc
View file @
9c135fd5
// Graph class implementation
/*!
* Copyright (c) 2018 by Contributors
* \file graph/graph.cc
* \brief DGL graph index implementation
*/
#include <dgl/graph.h>
#include <algorithm>
#include <algorithm>
#include <unordered_map>
#include <unordered_map>
#include <set>
#include <set>
#include <functional>
#include <functional>
#include <dgl/graph.h>
namespace
dgl
{
namespace
dgl
{
namespace
{
namespace
{
...
@@ -461,7 +465,8 @@ Subgraph Graph::EdgeSubgraph(IdArray eids) const {
...
@@ -461,7 +465,8 @@ Subgraph Graph::EdgeSubgraph(IdArray eids) const {
rst
.
graph
.
AddEdge
(
oldv2newv
[
src_id
],
oldv2newv
[
dst_id
]);
rst
.
graph
.
AddEdge
(
oldv2newv
[
src_id
],
oldv2newv
[
dst_id
]);
}
}
rst
.
induced_vertices
=
IdArray
::
Empty
({
static_cast
<
int64_t
>
(
nodes
.
size
())},
eids
->
dtype
,
eids
->
ctx
);
rst
.
induced_vertices
=
IdArray
::
Empty
(
{
static_cast
<
int64_t
>
(
nodes
.
size
())},
eids
->
dtype
,
eids
->
ctx
);
std
::
copy
(
nodes
.
begin
(),
nodes
.
end
(),
static_cast
<
int64_t
*>
(
rst
.
induced_vertices
->
data
));
std
::
copy
(
nodes
.
begin
(),
nodes
.
end
(),
static_cast
<
int64_t
*>
(
rst
.
induced_vertices
->
data
));
return
rst
;
return
rst
;
...
...
src/graph/graph_apis.cc
View file @
9c135fd5
/*!
* Copyright (c) 2018 by Contributors
* \file graph/graph.cc
* \brief DGL graph index APIs
*/
#include <dgl/graph.h>
#include <dgl/graph.h>
#include <dgl/graph_op.h>
#include <dgl/graph_op.h>
#include "../c_api_common.h"
#include "../c_api_common.h"
...
...
src/graph/graph_op.cc
View file @
9c135fd5
// Graph operation implementation
/*!
* Copyright (c) 2018 by Contributors
* \file graph/graph.cc
* \brief Graph operation implementation
*/
#include <dgl/graph_op.h>
#include <dgl/graph_op.h>
#include <algorithm>
#include <algorithm>
namespace
dgl
{
namespace
dgl
{
Graph
GraphOp
::
LineGraph
(
const
Graph
*
g
,
bool
backtracking
){
Graph
GraphOp
::
LineGraph
(
const
Graph
*
g
,
bool
backtracking
)
{
typedef
std
::
pair
<
dgl_id_t
,
dgl_id_t
>
entry
;
typedef
std
::
pair
<
dgl_id_t
,
dgl_id_t
>
entry
;
typedef
std
::
map
<
dgl_id_t
,
std
::
vector
<
entry
>>
csm
;
// Compressed Sparse Matrix
typedef
std
::
map
<
dgl_id_t
,
std
::
vector
<
entry
>>
csm
;
// Compressed Sparse Matrix
...
@@ -117,32 +121,6 @@ std::vector<Graph> GraphOp::DisjointPartitionBySizes(const Graph* graph, IdArray
...
@@ -117,32 +121,6 @@ std::vector<Graph> GraphOp::DisjointPartitionBySizes(const Graph* graph, IdArray
node_offset
+=
sizes_data
[
i
];
node_offset
+=
sizes_data
[
i
];
edge_offset
+=
num_edges
;
edge_offset
+=
num_edges
;
}
}
/*for (int64_t i = 0; i < len; ++i) {
rst[i].AddVertices(sizes_data[i]);
}
for (dgl_id_t eid = 0; eid < graph->num_edges_; ++eid) {
const dgl_id_t src = graph->all_edges_src_[eid];
const dgl_id_t dst = graph->all_edges_dst_[eid];
size_t src_select = 0, dst_select = 0;
for (size_t i = 1; i < cumsum.size(); ++i) { // TODO: replace with binary search
if (cumsum[i] > src) {
src_select = i;
break;
}
}
for (size_t i = 1; i < cumsum.size(); ++i) { // TODO: replace with binary search
if (cumsum[i] > dst) {
dst_select = i;
break;
}
}
if (src_select != dst_select) {
// the edge is ignored if across two partitions
continue;
}
const int64_t offset = cumsum[src_select - 1];
rst[src_select - 1].AddEdge(src - offset, dst - offset);
}*/
return
rst
;
return
rst
;
}
}
...
...
src/runtime/README.md
deleted
100644 → 0
View file @
9d3f299d
# C API and runtime
Borrowed and adapted from TVM project.
src/runtime/file_util.h
View file @
9c135fd5
...
@@ -3,8 +3,8 @@
...
@@ -3,8 +3,8 @@
* \file file_util.h
* \file file_util.h
* \brief Minimum file manipulation util for runtime.
* \brief Minimum file manipulation util for runtime.
*/
*/
#ifndef
TVM
_RUNTIME_FILE_UTIL_H_
#ifndef
DGL
_RUNTIME_FILE_UTIL_H_
#define
TVM
_RUNTIME_FILE_UTIL_H_
#define
DGL
_RUNTIME_FILE_UTIL_H_
#include <string>
#include <string>
#include "meta_data.h"
#include "meta_data.h"
...
@@ -73,4 +73,4 @@ void LoadMetaDataFromFile(
...
@@ -73,4 +73,4 @@ void LoadMetaDataFromFile(
std
::
unordered_map
<
std
::
string
,
FunctionInfo
>*
fmap
);
std
::
unordered_map
<
std
::
string
,
FunctionInfo
>*
fmap
);
}
// namespace runtime
}
// namespace runtime
}
// namespace tvm
}
// namespace tvm
#endif //
TVM
_RUNTIME_FILE_UTIL_H_
#endif //
DGL
_RUNTIME_FILE_UTIL_H_
src/runtime/meta_data.h
View file @
9c135fd5
...
@@ -3,8 +3,8 @@
...
@@ -3,8 +3,8 @@
* \file meta_data.h
* \file meta_data.h
* \brief Meta data related utilities
* \brief Meta data related utilities
*/
*/
#ifndef
TVM
_RUNTIME_META_DATA_H_
#ifndef
DGL
_RUNTIME_META_DATA_H_
#define
TVM
_RUNTIME_META_DATA_H_
#define
DGL
_RUNTIME_META_DATA_H_
#include <dmlc/json.h>
#include <dmlc/json.h>
#include <dmlc/io.h>
#include <dmlc/io.h>
...
@@ -33,4 +33,4 @@ struct FunctionInfo {
...
@@ -33,4 +33,4 @@ struct FunctionInfo {
namespace
dmlc
{
namespace
dmlc
{
DMLC_DECLARE_TRAITS
(
has_saveload
,
::
tvm
::
runtime
::
FunctionInfo
,
true
);
DMLC_DECLARE_TRAITS
(
has_saveload
,
::
tvm
::
runtime
::
FunctionInfo
,
true
);
}
// namespace dmlc
}
// namespace dmlc
#endif //
TVM
_RUNTIME_META_DATA_H_
#endif //
DGL
_RUNTIME_META_DATA_H_
src/runtime/module_util.h
View file @
9c135fd5
...
@@ -3,8 +3,8 @@
...
@@ -3,8 +3,8 @@
* \file module_util.h
* \file module_util.h
* \brief Helper utilities for module building
* \brief Helper utilities for module building
*/
*/
#ifndef
TVM
_RUNTIME_MODULE_UTIL_H_
#ifndef
DGL
_RUNTIME_MODULE_UTIL_H_
#define
TVM
_RUNTIME_MODULE_UTIL_H_
#define
DGL
_RUNTIME_MODULE_UTIL_H_
#include <dgl/runtime/module.h>
#include <dgl/runtime/module.h>
#include <dgl/runtime/c_runtime_api.h>
#include <dgl/runtime/c_runtime_api.h>
...
@@ -58,4 +58,4 @@ void InitContextFunctions(FLookup flookup) {
...
@@ -58,4 +58,4 @@ void InitContextFunctions(FLookup flookup) {
}
}
}
// namespace runtime
}
// namespace runtime
}
// namespace tvm
}
// namespace tvm
#endif //
TVM
_RUNTIME_MODULE_UTIL_H_
#endif //
DGL
_RUNTIME_MODULE_UTIL_H_
src/runtime/pack_args.h
View file @
9c135fd5
...
@@ -10,8 +10,8 @@
...
@@ -10,8 +10,8 @@
* union_32bit args[N], int num_args);
* union_32bit args[N], int num_args);
* - Pack buffer by address, pack rest parameter into 32bit union buffer.
* - Pack buffer by address, pack rest parameter into 32bit union buffer.
*/
*/
#ifndef
TVM
_RUNTIME_PACK_ARGS_H_
#ifndef
DGL
_RUNTIME_PACK_ARGS_H_
#define
TVM
_RUNTIME_PACK_ARGS_H_
#define
DGL
_RUNTIME_PACK_ARGS_H_
#include <dgl/runtime/c_runtime_api.h>
#include <dgl/runtime/c_runtime_api.h>
#include <vector>
#include <vector>
...
@@ -307,4 +307,4 @@ inline PackedFunc PackFuncPackedArg(F f, const std::vector<TVMType>& arg_types)
...
@@ -307,4 +307,4 @@ inline PackedFunc PackFuncPackedArg(F f, const std::vector<TVMType>& arg_types)
}
}
}
// namespace runtime
}
// namespace runtime
}
// namespace tvm
}
// namespace tvm
#endif //
TVM
_RUNTIME_PACK_ARGS_H_
#endif //
DGL
_RUNTIME_PACK_ARGS_H_
src/runtime/runtime_base.h
View file @
9c135fd5
...
@@ -3,8 +3,8 @@
...
@@ -3,8 +3,8 @@
* \file runtime_base.h
* \file runtime_base.h
* \brief Base of all C APIs
* \brief Base of all C APIs
*/
*/
#ifndef
TVM
_RUNTIME_RUNTIME_BASE_H_
#ifndef
DGL
_RUNTIME_RUNTIME_BASE_H_
#define
TVM
_RUNTIME_RUNTIME_BASE_H_
#define
DGL
_RUNTIME_RUNTIME_BASE_H_
#include <dgl/runtime/c_runtime_api.h>
#include <dgl/runtime/c_runtime_api.h>
#include <stdexcept>
#include <stdexcept>
...
@@ -31,4 +31,4 @@ inline int TVMAPIHandleException(const std::runtime_error &e) {
...
@@ -31,4 +31,4 @@ inline int TVMAPIHandleException(const std::runtime_error &e) {
return
-
1
;
return
-
1
;
}
}
#endif //
TVM
_RUNTIME_RUNTIME_BASE_H_
#endif //
DGL
_RUNTIME_RUNTIME_BASE_H_
src/runtime/thread_storage_scope.h
View file @
9c135fd5
...
@@ -3,8 +3,8 @@
...
@@ -3,8 +3,8 @@
* \file thread_storage_scope.h
* \file thread_storage_scope.h
* \brief Extract thread axis configuration from TVMArgs.
* \brief Extract thread axis configuration from TVMArgs.
*/
*/
#ifndef
TVM
_RUNTIME_THREAD_STORAGE_SCOPE_H_
#ifndef
DGL
_RUNTIME_THREAD_STORAGE_SCOPE_H_
#define
TVM
_RUNTIME_THREAD_STORAGE_SCOPE_H_
#define
DGL
_RUNTIME_THREAD_STORAGE_SCOPE_H_
#include <dgl/runtime/packed_func.h>
#include <dgl/runtime/packed_func.h>
#include <string>
#include <string>
...
@@ -204,4 +204,4 @@ struct hash<::tvm::runtime::StorageScope> {
...
@@ -204,4 +204,4 @@ struct hash<::tvm::runtime::StorageScope> {
}
}
};
};
}
// namespace std
}
// namespace std
#endif //
TVM
_RUNTIME_THREAD_STORAGE_SCOPE_H_
#endif //
DGL
_RUNTIME_THREAD_STORAGE_SCOPE_H_
Prev
1
2
3
4
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment