Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
be936da8
Unverified
Commit
be936da8
authored
Aug 21, 2019
by
Da Zheng
Committed by
GitHub
Aug 21, 2019
Browse files
use FFI for subgraph. (#781)
parent
0f127637
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
90 additions
and
67 deletions
+90
-67
include/dgl/graph_interface.h
include/dgl/graph_interface.h
+7
-1
python/dgl/graph_index.py
python/dgl/graph_index.py
+38
-32
src/graph/graph_apis.cc
src/graph/graph_apis.cc
+25
-23
src/graph/immutable_graph.cc
src/graph/immutable_graph.cc
+20
-11
No files found.
include/dgl/graph_interface.h
View file @
be936da8
...
...
@@ -354,7 +354,7 @@ class GraphInterface : public runtime::Object {
DGL_DEFINE_OBJECT_REF
(
GraphRef
,
GraphInterface
);
/*! \brief Subgraph data structure */
struct
Subgraph
{
struct
Subgraph
:
public
runtime
::
Object
{
/*! \brief The graph. */
GraphPtr
graph
;
/*!
...
...
@@ -367,8 +367,14 @@ struct Subgraph {
* \note This is also a map from the new edge id to the edge id in the parent graph.
*/
IdArray
induced_edges
;
static
constexpr
const
char
*
_type_key
=
"graph.Subgraph"
;
DGL_DECLARE_OBJECT_TYPE_INFO
(
Subgraph
,
runtime
::
Object
);
};
// Define SubgraphRef
DGL_DEFINE_OBJECT_REF
(
SubgraphRef
,
Subgraph
);
}
// namespace dgl
#endif // DGL_GRAPH_INTERFACE_H_
python/dgl/graph_index.py
View file @
be936da8
...
...
@@ -530,10 +530,7 @@ class GraphIndex(ObjectBase):
The subgraph index.
"""
v_array
=
v
.
todgltensor
()
rst
=
_CAPI_DGLGraphVertexSubgraph
(
self
,
v_array
)
induced_edges
=
utils
.
toindex
(
rst
(
2
))
gidx
=
rst
(
0
)
return
SubgraphIndex
(
gidx
,
self
,
v
,
induced_edges
)
return
_CAPI_DGLGraphVertexSubgraph
(
self
,
v_array
)
def
node_subgraphs
(
self
,
vs_arr
):
"""Return the induced node subgraphs.
...
...
@@ -571,10 +568,7 @@ class GraphIndex(ObjectBase):
The subgraph index.
"""
e_array
=
e
.
todgltensor
()
rst
=
_CAPI_DGLGraphEdgeSubgraph
(
self
,
e_array
,
preserve_nodes
)
induced_nodes
=
utils
.
toindex
(
rst
(
1
))
gidx
=
rst
(
0
)
return
SubgraphIndex
(
gidx
,
self
,
induced_nodes
,
e
)
return
_CAPI_DGLGraphEdgeSubgraph
(
self
,
e_array
,
preserve_nodes
)
@
utils
.
cached_member
(
cache
=
'_cache'
,
prefix
=
'scipy_adj'
)
def
adjacency_matrix_scipy
(
self
,
transpose
,
fmt
,
return_edge_ids
=
None
):
...
...
@@ -917,33 +911,45 @@ class GraphIndex(ObjectBase):
"""
return
_CAPI_DGLImmutableGraphAsNumBits
(
self
,
int
(
bits
))
class
SubgraphIndex
(
object
):
"""Internal subgraph data structure.
@
register_object
(
'graph.Subgraph'
)
class
SubgraphIndex
(
ObjectBase
):
"""Subgraph data structure"""
@
property
def
graph
(
self
):
"""The subgraph structure
Parameters
----------
graph : GraphIndex
The graph structure of this subgraph.
parent : GraphIndex
The parent graph index.
induced_nodes : utils.Index
The parent node ids in this subgraph.
induced_edges : utils.Index
The parent edge ids in this subgraph.
"""
def
__init__
(
self
,
graph
,
parent
,
induced_nodes
,
induced_edges
):
self
.
graph
=
graph
self
.
parent
=
parent
self
.
induced_nodes
=
induced_nodes
self
.
induced_edges
=
induced_edges
Returns
-------
GraphIndex
The subgraph
"""
return
_CAPI_DGLSubgraphGetGraph
(
self
)
def
__getstate__
(
self
):
raise
NotImplementedError
(
"SubgraphIndex pickling is not supported yet."
)
@
property
def
induced_nodes
(
self
):
"""Induced nodes for each node type. The return list
length should be equal to the number of node types.
def
__setstate__
(
self
,
state
):
raise
NotImplementedError
(
"SubgraphIndex unpickling is not supported yet."
)
Returns
-------
list of utils.Index
Induced nodes
"""
ret
=
_CAPI_DGLSubgraphGetInducedVertices
(
self
)
return
utils
.
toindex
(
ret
)
@
property
def
induced_edges
(
self
):
"""Induced edges for each edge type. The return list
length should be equal to the number of edge types.
Returns
-------
list of utils.Index
Induced edges
"""
ret
=
_CAPI_DGLSubgraphGetInducedEdges
(
self
)
return
utils
.
toindex
(
ret
)
###############################################################
...
...
src/graph/graph_apis.cc
View file @
be936da8
...
...
@@ -19,27 +19,6 @@ using dgl::runtime::NDArray;
namespace
dgl
{
namespace
{
// Convert Subgraph structure to PackedFunc.
PackedFunc
ConvertSubgraphToPackedFunc
(
const
Subgraph
&
sg
)
{
auto
body
=
[
sg
]
(
DGLArgs
args
,
DGLRetValue
*
rv
)
{
const
int
which
=
args
[
0
];
if
(
which
==
0
)
{
*
rv
=
GraphRef
(
sg
.
graph
);
}
else
if
(
which
==
1
)
{
*
rv
=
std
::
move
(
sg
.
induced_vertices
);
}
else
if
(
which
==
2
)
{
*
rv
=
std
::
move
(
sg
.
induced_edges
);
}
else
{
LOG
(
FATAL
)
<<
"invalid choice"
;
}
};
return
PackedFunc
(
body
);
}
}
// namespace
///////////////////////////// Graph API ///////////////////////////////////
DGL_REGISTER_GLOBAL
(
"graph_index._CAPI_DGLGraphCreateMutable"
)
...
...
@@ -312,7 +291,8 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphVertexSubgraph")
.
set_body
([]
(
DGLArgs
args
,
DGLRetValue
*
rv
)
{
GraphRef
g
=
args
[
0
];
const
IdArray
vids
=
args
[
1
];
*
rv
=
ConvertSubgraphToPackedFunc
(
g
->
VertexSubgraph
(
vids
));
std
::
shared_ptr
<
Subgraph
>
subg
(
new
Subgraph
(
g
->
VertexSubgraph
(
vids
)));
*
rv
=
SubgraphRef
(
subg
);
});
DGL_REGISTER_GLOBAL
(
"graph_index._CAPI_DGLGraphEdgeSubgraph"
)
...
...
@@ -320,7 +300,9 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphEdgeSubgraph")
GraphRef
g
=
args
[
0
];
const
IdArray
eids
=
args
[
1
];
bool
preserve_nodes
=
args
[
2
];
*
rv
=
ConvertSubgraphToPackedFunc
(
g
->
EdgeSubgraph
(
eids
,
preserve_nodes
));
std
::
shared_ptr
<
Subgraph
>
subg
(
new
Subgraph
(
g
->
EdgeSubgraph
(
eids
,
preserve_nodes
)));
*
rv
=
SubgraphRef
(
subg
);
});
DGL_REGISTER_GLOBAL
(
"graph_index._CAPI_DGLGraphGetAdj"
)
...
...
@@ -344,4 +326,24 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphNumBits")
*
rv
=
g
->
NumBits
();
});
// Subgraph C APIs
DGL_REGISTER_GLOBAL
(
"graph_index._CAPI_DGLSubgraphGetGraph"
)
.
set_body
([]
(
DGLArgs
args
,
DGLRetValue
*
rv
)
{
SubgraphRef
subg
=
args
[
0
];
*
rv
=
GraphRef
(
subg
->
graph
);
});
DGL_REGISTER_GLOBAL
(
"graph_index._CAPI_DGLSubgraphGetInducedVertices"
)
.
set_body
([]
(
DGLArgs
args
,
DGLRetValue
*
rv
)
{
SubgraphRef
subg
=
args
[
0
];
*
rv
=
subg
->
induced_vertices
;
});
DGL_REGISTER_GLOBAL
(
"graph_index._CAPI_DGLSubgraphGetInducedEdges"
)
.
set_body
([]
(
DGLArgs
args
,
DGLRetValue
*
rv
)
{
SubgraphRef
subg
=
args
[
0
];
*
rv
=
subg
->
induced_edges
;
});
}
// namespace dgl
src/graph/immutable_graph.cc
View file @
be936da8
...
...
@@ -196,7 +196,11 @@ Subgraph CSR::VertexSubgraph(IdArray vids) const {
const
auto
&
submat
=
aten
::
CSRSliceMatrix
(
adj_
,
vids
,
vids
);
IdArray
sub_eids
=
aten
::
Range
(
0
,
submat
.
data
->
shape
[
0
],
NumBits
(),
Context
());
CSRPtr
subcsr
(
new
CSR
(
submat
.
indptr
,
submat
.
indices
,
sub_eids
));
return
Subgraph
{
subcsr
,
vids
,
submat
.
data
};
Subgraph
subg
;
subg
.
graph
=
subcsr
;
subg
.
induced_vertices
=
vids
;
subg
.
induced_edges
=
submat
.
data
;
return
subg
;
}
CSRPtr
CSR
::
Transpose
()
const
{
...
...
@@ -313,20 +317,25 @@ EdgeArray COO::Edges(const std::string &order) const {
Subgraph
COO
::
EdgeSubgraph
(
IdArray
eids
,
bool
preserve_nodes
)
const
{
CHECK
(
IsValidIdArray
(
eids
))
<<
"Invalid edge id array."
;
COOPtr
subcoo
;
IdArray
induced_nodes
;
if
(
!
preserve_nodes
)
{
IdArray
new_src
=
aten
::
IndexSelect
(
adj_
.
row
,
eids
);
IdArray
new_dst
=
aten
::
IndexSelect
(
adj_
.
col
,
eids
);
IdArray
induced_nodes
=
aten
::
Relabel_
({
new_src
,
new_dst
});
induced_nodes
=
aten
::
Relabel_
({
new_src
,
new_dst
});
const
auto
new_nnodes
=
induced_nodes
->
shape
[
0
];
COOPtr
subcoo
(
new
COO
(
new_nnodes
,
new_src
,
new_dst
));
return
Subgraph
{
subcoo
,
induced_nodes
,
eids
};
subcoo
=
COOPtr
(
new
COO
(
new_nnodes
,
new_src
,
new_dst
));
}
else
{
IdArray
new_src
=
aten
::
IndexSelect
(
adj_
.
row
,
eids
);
IdArray
new_dst
=
aten
::
IndexSelect
(
adj_
.
col
,
eids
);
IdArray
induced_nodes
=
aten
::
Range
(
0
,
NumVertices
(),
NumBits
(),
Context
());
COOPtr
subcoo
(
new
COO
(
NumVertices
(),
new_src
,
new_dst
));
return
Subgraph
{
subcoo
,
induced_nodes
,
eids
};
induced_nodes
=
aten
::
Range
(
0
,
NumVertices
(),
NumBits
(),
Context
());
subcoo
=
COOPtr
(
new
COO
(
NumVertices
(),
new_src
,
new_dst
));
}
Subgraph
subg
;
subg
.
graph
=
subcoo
;
subg
.
induced_vertices
=
induced_nodes
;
subg
.
induced_edges
=
eids
;
return
subg
;
}
CSRPtr
COO
::
ToCSR
()
const
{
...
...
@@ -444,15 +453,15 @@ Subgraph ImmutableGraph::VertexSubgraph(IdArray vids) const {
// We prefer to generate a subgraph from out-csr.
auto
sg
=
GetOutCSR
()
->
VertexSubgraph
(
vids
);
CSRPtr
subcsr
=
std
::
dynamic_pointer_cast
<
CSR
>
(
sg
.
graph
);
return
Sub
graph
{
GraphPtr
(
new
ImmutableGraph
(
subcsr
))
,
sg
.
induced_vertices
,
sg
.
induced_edges
}
;
sg
.
graph
=
GraphPtr
(
new
ImmutableGraph
(
subcsr
))
;
return
sg
;
}
Subgraph
ImmutableGraph
::
EdgeSubgraph
(
IdArray
eids
,
bool
preserve_nodes
)
const
{
auto
sg
=
GetCOO
()
->
EdgeSubgraph
(
eids
,
preserve_nodes
);
COOPtr
subcoo
=
std
::
dynamic_pointer_cast
<
COO
>
(
sg
.
graph
);
return
Sub
graph
{
GraphPtr
(
new
ImmutableGraph
(
subcoo
))
,
sg
.
induced_vertices
,
sg
.
induced_edges
}
;
sg
.
graph
=
GraphPtr
(
new
ImmutableGraph
(
subcoo
))
;
return
sg
;
}
std
::
vector
<
IdArray
>
ImmutableGraph
::
GetAdj
(
bool
transpose
,
const
std
::
string
&
fmt
)
const
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment