Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
be936da8
Unverified
Commit
be936da8
authored
Aug 21, 2019
by
Da Zheng
Committed by
GitHub
Aug 21, 2019
Browse files
use FFI for subgraph. (#781)
parent
0f127637
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
90 additions
and
67 deletions
+90
-67
include/dgl/graph_interface.h
include/dgl/graph_interface.h
+7
-1
python/dgl/graph_index.py
python/dgl/graph_index.py
+38
-32
src/graph/graph_apis.cc
src/graph/graph_apis.cc
+25
-23
src/graph/immutable_graph.cc
src/graph/immutable_graph.cc
+20
-11
No files found.
include/dgl/graph_interface.h
View file @
be936da8
...
@@ -354,7 +354,7 @@ class GraphInterface : public runtime::Object {
...
@@ -354,7 +354,7 @@ class GraphInterface : public runtime::Object {
DGL_DEFINE_OBJECT_REF
(
GraphRef
,
GraphInterface
);
DGL_DEFINE_OBJECT_REF
(
GraphRef
,
GraphInterface
);
/*! \brief Subgraph data structure */
/*! \brief Subgraph data structure */
struct
Subgraph
{
struct
Subgraph
:
public
runtime
::
Object
{
/*! \brief The graph. */
/*! \brief The graph. */
GraphPtr
graph
;
GraphPtr
graph
;
/*!
/*!
...
@@ -367,8 +367,14 @@ struct Subgraph {
...
@@ -367,8 +367,14 @@ struct Subgraph {
* \note This is also a map from the new edge id to the edge id in the parent graph.
* \note This is also a map from the new edge id to the edge id in the parent graph.
*/
*/
IdArray
induced_edges
;
IdArray
induced_edges
;
static
constexpr
const
char
*
_type_key
=
"graph.Subgraph"
;
DGL_DECLARE_OBJECT_TYPE_INFO
(
Subgraph
,
runtime
::
Object
);
};
};
// Define SubgraphRef
DGL_DEFINE_OBJECT_REF
(
SubgraphRef
,
Subgraph
);
}
// namespace dgl
}
// namespace dgl
#endif // DGL_GRAPH_INTERFACE_H_
#endif // DGL_GRAPH_INTERFACE_H_
python/dgl/graph_index.py
View file @
be936da8
...
@@ -530,10 +530,7 @@ class GraphIndex(ObjectBase):
...
@@ -530,10 +530,7 @@ class GraphIndex(ObjectBase):
The subgraph index.
The subgraph index.
"""
"""
v_array
=
v
.
todgltensor
()
v_array
=
v
.
todgltensor
()
rst
=
_CAPI_DGLGraphVertexSubgraph
(
self
,
v_array
)
return
_CAPI_DGLGraphVertexSubgraph
(
self
,
v_array
)
induced_edges
=
utils
.
toindex
(
rst
(
2
))
gidx
=
rst
(
0
)
return
SubgraphIndex
(
gidx
,
self
,
v
,
induced_edges
)
def
node_subgraphs
(
self
,
vs_arr
):
def
node_subgraphs
(
self
,
vs_arr
):
"""Return the induced node subgraphs.
"""Return the induced node subgraphs.
...
@@ -571,10 +568,7 @@ class GraphIndex(ObjectBase):
...
@@ -571,10 +568,7 @@ class GraphIndex(ObjectBase):
The subgraph index.
The subgraph index.
"""
"""
e_array
=
e
.
todgltensor
()
e_array
=
e
.
todgltensor
()
rst
=
_CAPI_DGLGraphEdgeSubgraph
(
self
,
e_array
,
preserve_nodes
)
return
_CAPI_DGLGraphEdgeSubgraph
(
self
,
e_array
,
preserve_nodes
)
induced_nodes
=
utils
.
toindex
(
rst
(
1
))
gidx
=
rst
(
0
)
return
SubgraphIndex
(
gidx
,
self
,
induced_nodes
,
e
)
@
utils
.
cached_member
(
cache
=
'_cache'
,
prefix
=
'scipy_adj'
)
@
utils
.
cached_member
(
cache
=
'_cache'
,
prefix
=
'scipy_adj'
)
def
adjacency_matrix_scipy
(
self
,
transpose
,
fmt
,
return_edge_ids
=
None
):
def
adjacency_matrix_scipy
(
self
,
transpose
,
fmt
,
return_edge_ids
=
None
):
...
@@ -917,33 +911,45 @@ class GraphIndex(ObjectBase):
...
@@ -917,33 +911,45 @@ class GraphIndex(ObjectBase):
"""
"""
return
_CAPI_DGLImmutableGraphAsNumBits
(
self
,
int
(
bits
))
return
_CAPI_DGLImmutableGraphAsNumBits
(
self
,
int
(
bits
))
class
SubgraphIndex
(
object
):
@
register_object
(
'graph.Subgraph'
)
"""Internal subgraph data structure.
class
SubgraphIndex
(
ObjectBase
):
"""Subgraph data structure"""
@
property
def
graph
(
self
):
"""The subgraph structure
Parameters
Returns
----------
-------
graph : GraphIndex
GraphIndex
The graph structure of this subgraph.
The subgraph
parent : GraphIndex
"""
The parent graph index.
return
_CAPI_DGLSubgraphGetGraph
(
self
)
induced_nodes : utils.Index
The parent node ids in this subgraph.
induced_edges : utils.Index
The parent edge ids in this subgraph.
"""
def
__init__
(
self
,
graph
,
parent
,
induced_nodes
,
induced_edges
):
self
.
graph
=
graph
self
.
parent
=
parent
self
.
induced_nodes
=
induced_nodes
self
.
induced_edges
=
induced_edges
def
__getstate__
(
self
):
@
property
raise
NotImplementedError
(
def
induced_nodes
(
self
):
"SubgraphIndex pickling is not supported yet."
)
"""Induced nodes for each node type. The return list
length should be equal to the number of node types.
def
__setstate__
(
self
,
state
):
Returns
raise
NotImplementedError
(
-------
"SubgraphIndex unpickling is not supported yet."
)
list of utils.Index
Induced nodes
"""
ret
=
_CAPI_DGLSubgraphGetInducedVertices
(
self
)
return
utils
.
toindex
(
ret
)
@
property
def
induced_edges
(
self
):
"""Induced edges for each edge type. The return list
length should be equal to the number of edge types.
Returns
-------
list of utils.Index
Induced edges
"""
ret
=
_CAPI_DGLSubgraphGetInducedEdges
(
self
)
return
utils
.
toindex
(
ret
)
###############################################################
###############################################################
...
...
src/graph/graph_apis.cc
View file @
be936da8
...
@@ -19,27 +19,6 @@ using dgl::runtime::NDArray;
...
@@ -19,27 +19,6 @@ using dgl::runtime::NDArray;
namespace
dgl
{
namespace
dgl
{
namespace
{
// Convert Subgraph structure to PackedFunc.
PackedFunc
ConvertSubgraphToPackedFunc
(
const
Subgraph
&
sg
)
{
auto
body
=
[
sg
]
(
DGLArgs
args
,
DGLRetValue
*
rv
)
{
const
int
which
=
args
[
0
];
if
(
which
==
0
)
{
*
rv
=
GraphRef
(
sg
.
graph
);
}
else
if
(
which
==
1
)
{
*
rv
=
std
::
move
(
sg
.
induced_vertices
);
}
else
if
(
which
==
2
)
{
*
rv
=
std
::
move
(
sg
.
induced_edges
);
}
else
{
LOG
(
FATAL
)
<<
"invalid choice"
;
}
};
return
PackedFunc
(
body
);
}
}
// namespace
///////////////////////////// Graph API ///////////////////////////////////
///////////////////////////// Graph API ///////////////////////////////////
DGL_REGISTER_GLOBAL
(
"graph_index._CAPI_DGLGraphCreateMutable"
)
DGL_REGISTER_GLOBAL
(
"graph_index._CAPI_DGLGraphCreateMutable"
)
...
@@ -312,7 +291,8 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphVertexSubgraph")
...
@@ -312,7 +291,8 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphVertexSubgraph")
.
set_body
([]
(
DGLArgs
args
,
DGLRetValue
*
rv
)
{
.
set_body
([]
(
DGLArgs
args
,
DGLRetValue
*
rv
)
{
GraphRef
g
=
args
[
0
];
GraphRef
g
=
args
[
0
];
const
IdArray
vids
=
args
[
1
];
const
IdArray
vids
=
args
[
1
];
*
rv
=
ConvertSubgraphToPackedFunc
(
g
->
VertexSubgraph
(
vids
));
std
::
shared_ptr
<
Subgraph
>
subg
(
new
Subgraph
(
g
->
VertexSubgraph
(
vids
)));
*
rv
=
SubgraphRef
(
subg
);
});
});
DGL_REGISTER_GLOBAL
(
"graph_index._CAPI_DGLGraphEdgeSubgraph"
)
DGL_REGISTER_GLOBAL
(
"graph_index._CAPI_DGLGraphEdgeSubgraph"
)
...
@@ -320,7 +300,9 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphEdgeSubgraph")
...
@@ -320,7 +300,9 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphEdgeSubgraph")
GraphRef
g
=
args
[
0
];
GraphRef
g
=
args
[
0
];
const
IdArray
eids
=
args
[
1
];
const
IdArray
eids
=
args
[
1
];
bool
preserve_nodes
=
args
[
2
];
bool
preserve_nodes
=
args
[
2
];
*
rv
=
ConvertSubgraphToPackedFunc
(
g
->
EdgeSubgraph
(
eids
,
preserve_nodes
));
std
::
shared_ptr
<
Subgraph
>
subg
(
new
Subgraph
(
g
->
EdgeSubgraph
(
eids
,
preserve_nodes
)));
*
rv
=
SubgraphRef
(
subg
);
});
});
DGL_REGISTER_GLOBAL
(
"graph_index._CAPI_DGLGraphGetAdj"
)
DGL_REGISTER_GLOBAL
(
"graph_index._CAPI_DGLGraphGetAdj"
)
...
@@ -344,4 +326,24 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphNumBits")
...
@@ -344,4 +326,24 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphNumBits")
*
rv
=
g
->
NumBits
();
*
rv
=
g
->
NumBits
();
});
});
// Subgraph C APIs
DGL_REGISTER_GLOBAL
(
"graph_index._CAPI_DGLSubgraphGetGraph"
)
.
set_body
([]
(
DGLArgs
args
,
DGLRetValue
*
rv
)
{
SubgraphRef
subg
=
args
[
0
];
*
rv
=
GraphRef
(
subg
->
graph
);
});
DGL_REGISTER_GLOBAL
(
"graph_index._CAPI_DGLSubgraphGetInducedVertices"
)
.
set_body
([]
(
DGLArgs
args
,
DGLRetValue
*
rv
)
{
SubgraphRef
subg
=
args
[
0
];
*
rv
=
subg
->
induced_vertices
;
});
DGL_REGISTER_GLOBAL
(
"graph_index._CAPI_DGLSubgraphGetInducedEdges"
)
.
set_body
([]
(
DGLArgs
args
,
DGLRetValue
*
rv
)
{
SubgraphRef
subg
=
args
[
0
];
*
rv
=
subg
->
induced_edges
;
});
}
// namespace dgl
}
// namespace dgl
src/graph/immutable_graph.cc
View file @
be936da8
...
@@ -196,7 +196,11 @@ Subgraph CSR::VertexSubgraph(IdArray vids) const {
...
@@ -196,7 +196,11 @@ Subgraph CSR::VertexSubgraph(IdArray vids) const {
const
auto
&
submat
=
aten
::
CSRSliceMatrix
(
adj_
,
vids
,
vids
);
const
auto
&
submat
=
aten
::
CSRSliceMatrix
(
adj_
,
vids
,
vids
);
IdArray
sub_eids
=
aten
::
Range
(
0
,
submat
.
data
->
shape
[
0
],
NumBits
(),
Context
());
IdArray
sub_eids
=
aten
::
Range
(
0
,
submat
.
data
->
shape
[
0
],
NumBits
(),
Context
());
CSRPtr
subcsr
(
new
CSR
(
submat
.
indptr
,
submat
.
indices
,
sub_eids
));
CSRPtr
subcsr
(
new
CSR
(
submat
.
indptr
,
submat
.
indices
,
sub_eids
));
return
Subgraph
{
subcsr
,
vids
,
submat
.
data
};
Subgraph
subg
;
subg
.
graph
=
subcsr
;
subg
.
induced_vertices
=
vids
;
subg
.
induced_edges
=
submat
.
data
;
return
subg
;
}
}
CSRPtr
CSR
::
Transpose
()
const
{
CSRPtr
CSR
::
Transpose
()
const
{
...
@@ -313,20 +317,25 @@ EdgeArray COO::Edges(const std::string &order) const {
...
@@ -313,20 +317,25 @@ EdgeArray COO::Edges(const std::string &order) const {
Subgraph
COO
::
EdgeSubgraph
(
IdArray
eids
,
bool
preserve_nodes
)
const
{
Subgraph
COO
::
EdgeSubgraph
(
IdArray
eids
,
bool
preserve_nodes
)
const
{
CHECK
(
IsValidIdArray
(
eids
))
<<
"Invalid edge id array."
;
CHECK
(
IsValidIdArray
(
eids
))
<<
"Invalid edge id array."
;
COOPtr
subcoo
;
IdArray
induced_nodes
;
if
(
!
preserve_nodes
)
{
if
(
!
preserve_nodes
)
{
IdArray
new_src
=
aten
::
IndexSelect
(
adj_
.
row
,
eids
);
IdArray
new_src
=
aten
::
IndexSelect
(
adj_
.
row
,
eids
);
IdArray
new_dst
=
aten
::
IndexSelect
(
adj_
.
col
,
eids
);
IdArray
new_dst
=
aten
::
IndexSelect
(
adj_
.
col
,
eids
);
IdArray
induced_nodes
=
aten
::
Relabel_
({
new_src
,
new_dst
});
induced_nodes
=
aten
::
Relabel_
({
new_src
,
new_dst
});
const
auto
new_nnodes
=
induced_nodes
->
shape
[
0
];
const
auto
new_nnodes
=
induced_nodes
->
shape
[
0
];
COOPtr
subcoo
(
new
COO
(
new_nnodes
,
new_src
,
new_dst
));
subcoo
=
COOPtr
(
new
COO
(
new_nnodes
,
new_src
,
new_dst
));
return
Subgraph
{
subcoo
,
induced_nodes
,
eids
};
}
else
{
}
else
{
IdArray
new_src
=
aten
::
IndexSelect
(
adj_
.
row
,
eids
);
IdArray
new_src
=
aten
::
IndexSelect
(
adj_
.
row
,
eids
);
IdArray
new_dst
=
aten
::
IndexSelect
(
adj_
.
col
,
eids
);
IdArray
new_dst
=
aten
::
IndexSelect
(
adj_
.
col
,
eids
);
IdArray
induced_nodes
=
aten
::
Range
(
0
,
NumVertices
(),
NumBits
(),
Context
());
induced_nodes
=
aten
::
Range
(
0
,
NumVertices
(),
NumBits
(),
Context
());
COOPtr
subcoo
(
new
COO
(
NumVertices
(),
new_src
,
new_dst
));
subcoo
=
COOPtr
(
new
COO
(
NumVertices
(),
new_src
,
new_dst
));
return
Subgraph
{
subcoo
,
induced_nodes
,
eids
};
}
}
Subgraph
subg
;
subg
.
graph
=
subcoo
;
subg
.
induced_vertices
=
induced_nodes
;
subg
.
induced_edges
=
eids
;
return
subg
;
}
}
CSRPtr
COO
::
ToCSR
()
const
{
CSRPtr
COO
::
ToCSR
()
const
{
...
@@ -444,15 +453,15 @@ Subgraph ImmutableGraph::VertexSubgraph(IdArray vids) const {
...
@@ -444,15 +453,15 @@ Subgraph ImmutableGraph::VertexSubgraph(IdArray vids) const {
// We prefer to generate a subgraph from out-csr.
// We prefer to generate a subgraph from out-csr.
auto
sg
=
GetOutCSR
()
->
VertexSubgraph
(
vids
);
auto
sg
=
GetOutCSR
()
->
VertexSubgraph
(
vids
);
CSRPtr
subcsr
=
std
::
dynamic_pointer_cast
<
CSR
>
(
sg
.
graph
);
CSRPtr
subcsr
=
std
::
dynamic_pointer_cast
<
CSR
>
(
sg
.
graph
);
return
Sub
graph
{
GraphPtr
(
new
ImmutableGraph
(
subcsr
))
,
sg
.
graph
=
GraphPtr
(
new
ImmutableGraph
(
subcsr
))
;
sg
.
induced_vertices
,
sg
.
induced_edges
}
;
return
sg
;
}
}
Subgraph
ImmutableGraph
::
EdgeSubgraph
(
IdArray
eids
,
bool
preserve_nodes
)
const
{
Subgraph
ImmutableGraph
::
EdgeSubgraph
(
IdArray
eids
,
bool
preserve_nodes
)
const
{
auto
sg
=
GetCOO
()
->
EdgeSubgraph
(
eids
,
preserve_nodes
);
auto
sg
=
GetCOO
()
->
EdgeSubgraph
(
eids
,
preserve_nodes
);
COOPtr
subcoo
=
std
::
dynamic_pointer_cast
<
COO
>
(
sg
.
graph
);
COOPtr
subcoo
=
std
::
dynamic_pointer_cast
<
COO
>
(
sg
.
graph
);
return
Sub
graph
{
GraphPtr
(
new
ImmutableGraph
(
subcoo
))
,
sg
.
graph
=
GraphPtr
(
new
ImmutableGraph
(
subcoo
))
;
sg
.
induced_vertices
,
sg
.
induced_edges
}
;
return
sg
;
}
}
std
::
vector
<
IdArray
>
ImmutableGraph
::
GetAdj
(
bool
transpose
,
const
std
::
string
&
fmt
)
const
{
std
::
vector
<
IdArray
>
ImmutableGraph
::
GetAdj
(
bool
transpose
,
const
std
::
string
&
fmt
)
const
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment