Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
be936da8
"..._static/git@developer.sourcefind.cn:OpenDAS/fairscale.git" did not exist on "ebbd5f643d3006c601183e6f5a111611663754c5"
Unverified
Commit
be936da8
authored
Aug 21, 2019
by
Da Zheng
Committed by
GitHub
Aug 21, 2019
Browse files
use FFI for subgraph. (#781)
parent
0f127637
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
90 additions
and
67 deletions
+90
-67
include/dgl/graph_interface.h
include/dgl/graph_interface.h
+7
-1
python/dgl/graph_index.py
python/dgl/graph_index.py
+38
-32
src/graph/graph_apis.cc
src/graph/graph_apis.cc
+25
-23
src/graph/immutable_graph.cc
src/graph/immutable_graph.cc
+20
-11
No files found.
include/dgl/graph_interface.h
View file @
be936da8
...
@@ -354,7 +354,7 @@ class GraphInterface : public runtime::Object {
...
@@ -354,7 +354,7 @@ class GraphInterface : public runtime::Object {
DGL_DEFINE_OBJECT_REF
(
GraphRef
,
GraphInterface
);
DGL_DEFINE_OBJECT_REF
(
GraphRef
,
GraphInterface
);
/*! \brief Subgraph data structure */
/*! \brief Subgraph data structure */
struct
Subgraph
{
struct
Subgraph
:
public
runtime
::
Object
{
/*! \brief The graph. */
/*! \brief The graph. */
GraphPtr
graph
;
GraphPtr
graph
;
/*!
/*!
...
@@ -367,8 +367,14 @@ struct Subgraph {
...
@@ -367,8 +367,14 @@ struct Subgraph {
* \note This is also a map from the new edge id to the edge id in the parent graph.
* \note This is also a map from the new edge id to the edge id in the parent graph.
*/
*/
IdArray
induced_edges
;
IdArray
induced_edges
;
static
constexpr
const
char
*
_type_key
=
"graph.Subgraph"
;
DGL_DECLARE_OBJECT_TYPE_INFO
(
Subgraph
,
runtime
::
Object
);
};
};
// Define SubgraphRef
DGL_DEFINE_OBJECT_REF
(
SubgraphRef
,
Subgraph
);
}
// namespace dgl
}
// namespace dgl
#endif // DGL_GRAPH_INTERFACE_H_
#endif // DGL_GRAPH_INTERFACE_H_
python/dgl/graph_index.py
View file @
be936da8
...
@@ -530,10 +530,7 @@ class GraphIndex(ObjectBase):
...
@@ -530,10 +530,7 @@ class GraphIndex(ObjectBase):
The subgraph index.
The subgraph index.
"""
"""
v_array
=
v
.
todgltensor
()
v_array
=
v
.
todgltensor
()
rst
=
_CAPI_DGLGraphVertexSubgraph
(
self
,
v_array
)
return
_CAPI_DGLGraphVertexSubgraph
(
self
,
v_array
)
induced_edges
=
utils
.
toindex
(
rst
(
2
))
gidx
=
rst
(
0
)
return
SubgraphIndex
(
gidx
,
self
,
v
,
induced_edges
)
def
node_subgraphs
(
self
,
vs_arr
):
def
node_subgraphs
(
self
,
vs_arr
):
"""Return the induced node subgraphs.
"""Return the induced node subgraphs.
...
@@ -571,10 +568,7 @@ class GraphIndex(ObjectBase):
...
@@ -571,10 +568,7 @@ class GraphIndex(ObjectBase):
The subgraph index.
The subgraph index.
"""
"""
e_array
=
e
.
todgltensor
()
e_array
=
e
.
todgltensor
()
rst
=
_CAPI_DGLGraphEdgeSubgraph
(
self
,
e_array
,
preserve_nodes
)
return
_CAPI_DGLGraphEdgeSubgraph
(
self
,
e_array
,
preserve_nodes
)
induced_nodes
=
utils
.
toindex
(
rst
(
1
))
gidx
=
rst
(
0
)
return
SubgraphIndex
(
gidx
,
self
,
induced_nodes
,
e
)
@
utils
.
cached_member
(
cache
=
'_cache'
,
prefix
=
'scipy_adj'
)
@
utils
.
cached_member
(
cache
=
'_cache'
,
prefix
=
'scipy_adj'
)
def
adjacency_matrix_scipy
(
self
,
transpose
,
fmt
,
return_edge_ids
=
None
):
def
adjacency_matrix_scipy
(
self
,
transpose
,
fmt
,
return_edge_ids
=
None
):
...
@@ -917,33 +911,45 @@ class GraphIndex(ObjectBase):
...
@@ -917,33 +911,45 @@ class GraphIndex(ObjectBase):
"""
"""
return
_CAPI_DGLImmutableGraphAsNumBits
(
self
,
int
(
bits
))
return
_CAPI_DGLImmutableGraphAsNumBits
(
self
,
int
(
bits
))
class
SubgraphIndex
(
object
):
@
register_object
(
'graph.Subgraph'
)
"""Internal subgraph data structure.
class
SubgraphIndex
(
ObjectBase
):
"""Subgraph data structure"""
@
property
def
graph
(
self
):
"""The subgraph structure
Parameters
Returns
----------
-------
graph : GraphIndex
GraphIndex
The graph structure of this subgraph.
The subgraph
parent : GraphIndex
"""
The parent graph index.
return
_CAPI_DGLSubgraphGetGraph
(
self
)
induced_nodes : utils.Index
The parent node ids in this subgraph.
induced_edges : utils.Index
The parent edge ids in this subgraph.
"""
def
__init__
(
self
,
graph
,
parent
,
induced_nodes
,
induced_edges
):
self
.
graph
=
graph
self
.
parent
=
parent
self
.
induced_nodes
=
induced_nodes
self
.
induced_edges
=
induced_edges
def
__getstate__
(
self
):
@
property
raise
NotImplementedError
(
def
induced_nodes
(
self
):
"SubgraphIndex pickling is not supported yet."
)
"""Induced nodes for each node type. The return list
length should be equal to the number of node types.
def
__setstate__
(
self
,
state
):
Returns
raise
NotImplementedError
(
-------
"SubgraphIndex unpickling is not supported yet."
)
list of utils.Index
Induced nodes
"""
ret
=
_CAPI_DGLSubgraphGetInducedVertices
(
self
)
return
utils
.
toindex
(
ret
)
@
property
def
induced_edges
(
self
):
"""Induced edges for each edge type. The return list
length should be equal to the number of edge types.
Returns
-------
list of utils.Index
Induced edges
"""
ret
=
_CAPI_DGLSubgraphGetInducedEdges
(
self
)
return
utils
.
toindex
(
ret
)
###############################################################
###############################################################
...
...
src/graph/graph_apis.cc
View file @
be936da8
...
@@ -19,27 +19,6 @@ using dgl::runtime::NDArray;
...
@@ -19,27 +19,6 @@ using dgl::runtime::NDArray;
namespace
dgl
{
namespace
dgl
{
namespace
{
// Convert Subgraph structure to PackedFunc.
PackedFunc
ConvertSubgraphToPackedFunc
(
const
Subgraph
&
sg
)
{
auto
body
=
[
sg
]
(
DGLArgs
args
,
DGLRetValue
*
rv
)
{
const
int
which
=
args
[
0
];
if
(
which
==
0
)
{
*
rv
=
GraphRef
(
sg
.
graph
);
}
else
if
(
which
==
1
)
{
*
rv
=
std
::
move
(
sg
.
induced_vertices
);
}
else
if
(
which
==
2
)
{
*
rv
=
std
::
move
(
sg
.
induced_edges
);
}
else
{
LOG
(
FATAL
)
<<
"invalid choice"
;
}
};
return
PackedFunc
(
body
);
}
}
// namespace
///////////////////////////// Graph API ///////////////////////////////////
///////////////////////////// Graph API ///////////////////////////////////
DGL_REGISTER_GLOBAL
(
"graph_index._CAPI_DGLGraphCreateMutable"
)
DGL_REGISTER_GLOBAL
(
"graph_index._CAPI_DGLGraphCreateMutable"
)
...
@@ -312,7 +291,8 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphVertexSubgraph")
...
@@ -312,7 +291,8 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphVertexSubgraph")
.
set_body
([]
(
DGLArgs
args
,
DGLRetValue
*
rv
)
{
.
set_body
([]
(
DGLArgs
args
,
DGLRetValue
*
rv
)
{
GraphRef
g
=
args
[
0
];
GraphRef
g
=
args
[
0
];
const
IdArray
vids
=
args
[
1
];
const
IdArray
vids
=
args
[
1
];
*
rv
=
ConvertSubgraphToPackedFunc
(
g
->
VertexSubgraph
(
vids
));
std
::
shared_ptr
<
Subgraph
>
subg
(
new
Subgraph
(
g
->
VertexSubgraph
(
vids
)));
*
rv
=
SubgraphRef
(
subg
);
});
});
DGL_REGISTER_GLOBAL
(
"graph_index._CAPI_DGLGraphEdgeSubgraph"
)
DGL_REGISTER_GLOBAL
(
"graph_index._CAPI_DGLGraphEdgeSubgraph"
)
...
@@ -320,7 +300,9 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphEdgeSubgraph")
...
@@ -320,7 +300,9 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphEdgeSubgraph")
GraphRef
g
=
args
[
0
];
GraphRef
g
=
args
[
0
];
const
IdArray
eids
=
args
[
1
];
const
IdArray
eids
=
args
[
1
];
bool
preserve_nodes
=
args
[
2
];
bool
preserve_nodes
=
args
[
2
];
*
rv
=
ConvertSubgraphToPackedFunc
(
g
->
EdgeSubgraph
(
eids
,
preserve_nodes
));
std
::
shared_ptr
<
Subgraph
>
subg
(
new
Subgraph
(
g
->
EdgeSubgraph
(
eids
,
preserve_nodes
)));
*
rv
=
SubgraphRef
(
subg
);
});
});
DGL_REGISTER_GLOBAL
(
"graph_index._CAPI_DGLGraphGetAdj"
)
DGL_REGISTER_GLOBAL
(
"graph_index._CAPI_DGLGraphGetAdj"
)
...
@@ -344,4 +326,24 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphNumBits")
...
@@ -344,4 +326,24 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphNumBits")
*
rv
=
g
->
NumBits
();
*
rv
=
g
->
NumBits
();
});
});
// Subgraph C APIs
DGL_REGISTER_GLOBAL
(
"graph_index._CAPI_DGLSubgraphGetGraph"
)
.
set_body
([]
(
DGLArgs
args
,
DGLRetValue
*
rv
)
{
SubgraphRef
subg
=
args
[
0
];
*
rv
=
GraphRef
(
subg
->
graph
);
});
DGL_REGISTER_GLOBAL
(
"graph_index._CAPI_DGLSubgraphGetInducedVertices"
)
.
set_body
([]
(
DGLArgs
args
,
DGLRetValue
*
rv
)
{
SubgraphRef
subg
=
args
[
0
];
*
rv
=
subg
->
induced_vertices
;
});
DGL_REGISTER_GLOBAL
(
"graph_index._CAPI_DGLSubgraphGetInducedEdges"
)
.
set_body
([]
(
DGLArgs
args
,
DGLRetValue
*
rv
)
{
SubgraphRef
subg
=
args
[
0
];
*
rv
=
subg
->
induced_edges
;
});
}
// namespace dgl
}
// namespace dgl
src/graph/immutable_graph.cc
View file @
be936da8
...
@@ -196,7 +196,11 @@ Subgraph CSR::VertexSubgraph(IdArray vids) const {
...
@@ -196,7 +196,11 @@ Subgraph CSR::VertexSubgraph(IdArray vids) const {
const
auto
&
submat
=
aten
::
CSRSliceMatrix
(
adj_
,
vids
,
vids
);
const
auto
&
submat
=
aten
::
CSRSliceMatrix
(
adj_
,
vids
,
vids
);
IdArray
sub_eids
=
aten
::
Range
(
0
,
submat
.
data
->
shape
[
0
],
NumBits
(),
Context
());
IdArray
sub_eids
=
aten
::
Range
(
0
,
submat
.
data
->
shape
[
0
],
NumBits
(),
Context
());
CSRPtr
subcsr
(
new
CSR
(
submat
.
indptr
,
submat
.
indices
,
sub_eids
));
CSRPtr
subcsr
(
new
CSR
(
submat
.
indptr
,
submat
.
indices
,
sub_eids
));
return
Subgraph
{
subcsr
,
vids
,
submat
.
data
};
Subgraph
subg
;
subg
.
graph
=
subcsr
;
subg
.
induced_vertices
=
vids
;
subg
.
induced_edges
=
submat
.
data
;
return
subg
;
}
}
CSRPtr
CSR
::
Transpose
()
const
{
CSRPtr
CSR
::
Transpose
()
const
{
...
@@ -313,20 +317,25 @@ EdgeArray COO::Edges(const std::string &order) const {
...
@@ -313,20 +317,25 @@ EdgeArray COO::Edges(const std::string &order) const {
Subgraph
COO
::
EdgeSubgraph
(
IdArray
eids
,
bool
preserve_nodes
)
const
{
Subgraph
COO
::
EdgeSubgraph
(
IdArray
eids
,
bool
preserve_nodes
)
const
{
CHECK
(
IsValidIdArray
(
eids
))
<<
"Invalid edge id array."
;
CHECK
(
IsValidIdArray
(
eids
))
<<
"Invalid edge id array."
;
COOPtr
subcoo
;
IdArray
induced_nodes
;
if
(
!
preserve_nodes
)
{
if
(
!
preserve_nodes
)
{
IdArray
new_src
=
aten
::
IndexSelect
(
adj_
.
row
,
eids
);
IdArray
new_src
=
aten
::
IndexSelect
(
adj_
.
row
,
eids
);
IdArray
new_dst
=
aten
::
IndexSelect
(
adj_
.
col
,
eids
);
IdArray
new_dst
=
aten
::
IndexSelect
(
adj_
.
col
,
eids
);
IdArray
induced_nodes
=
aten
::
Relabel_
({
new_src
,
new_dst
});
induced_nodes
=
aten
::
Relabel_
({
new_src
,
new_dst
});
const
auto
new_nnodes
=
induced_nodes
->
shape
[
0
];
const
auto
new_nnodes
=
induced_nodes
->
shape
[
0
];
COOPtr
subcoo
(
new
COO
(
new_nnodes
,
new_src
,
new_dst
));
subcoo
=
COOPtr
(
new
COO
(
new_nnodes
,
new_src
,
new_dst
));
return
Subgraph
{
subcoo
,
induced_nodes
,
eids
};
}
else
{
}
else
{
IdArray
new_src
=
aten
::
IndexSelect
(
adj_
.
row
,
eids
);
IdArray
new_src
=
aten
::
IndexSelect
(
adj_
.
row
,
eids
);
IdArray
new_dst
=
aten
::
IndexSelect
(
adj_
.
col
,
eids
);
IdArray
new_dst
=
aten
::
IndexSelect
(
adj_
.
col
,
eids
);
IdArray
induced_nodes
=
aten
::
Range
(
0
,
NumVertices
(),
NumBits
(),
Context
());
induced_nodes
=
aten
::
Range
(
0
,
NumVertices
(),
NumBits
(),
Context
());
COOPtr
subcoo
(
new
COO
(
NumVertices
(),
new_src
,
new_dst
));
subcoo
=
COOPtr
(
new
COO
(
NumVertices
(),
new_src
,
new_dst
));
return
Subgraph
{
subcoo
,
induced_nodes
,
eids
};
}
}
Subgraph
subg
;
subg
.
graph
=
subcoo
;
subg
.
induced_vertices
=
induced_nodes
;
subg
.
induced_edges
=
eids
;
return
subg
;
}
}
CSRPtr
COO
::
ToCSR
()
const
{
CSRPtr
COO
::
ToCSR
()
const
{
...
@@ -444,15 +453,15 @@ Subgraph ImmutableGraph::VertexSubgraph(IdArray vids) const {
...
@@ -444,15 +453,15 @@ Subgraph ImmutableGraph::VertexSubgraph(IdArray vids) const {
// We prefer to generate a subgraph from out-csr.
// We prefer to generate a subgraph from out-csr.
auto
sg
=
GetOutCSR
()
->
VertexSubgraph
(
vids
);
auto
sg
=
GetOutCSR
()
->
VertexSubgraph
(
vids
);
CSRPtr
subcsr
=
std
::
dynamic_pointer_cast
<
CSR
>
(
sg
.
graph
);
CSRPtr
subcsr
=
std
::
dynamic_pointer_cast
<
CSR
>
(
sg
.
graph
);
return
Sub
graph
{
GraphPtr
(
new
ImmutableGraph
(
subcsr
))
,
sg
.
graph
=
GraphPtr
(
new
ImmutableGraph
(
subcsr
))
;
sg
.
induced_vertices
,
sg
.
induced_edges
}
;
return
sg
;
}
}
Subgraph
ImmutableGraph
::
EdgeSubgraph
(
IdArray
eids
,
bool
preserve_nodes
)
const
{
Subgraph
ImmutableGraph
::
EdgeSubgraph
(
IdArray
eids
,
bool
preserve_nodes
)
const
{
auto
sg
=
GetCOO
()
->
EdgeSubgraph
(
eids
,
preserve_nodes
);
auto
sg
=
GetCOO
()
->
EdgeSubgraph
(
eids
,
preserve_nodes
);
COOPtr
subcoo
=
std
::
dynamic_pointer_cast
<
COO
>
(
sg
.
graph
);
COOPtr
subcoo
=
std
::
dynamic_pointer_cast
<
COO
>
(
sg
.
graph
);
return
Sub
graph
{
GraphPtr
(
new
ImmutableGraph
(
subcoo
))
,
sg
.
graph
=
GraphPtr
(
new
ImmutableGraph
(
subcoo
))
;
sg
.
induced_vertices
,
sg
.
induced_edges
}
;
return
sg
;
}
}
std
::
vector
<
IdArray
>
ImmutableGraph
::
GetAdj
(
bool
transpose
,
const
std
::
string
&
fmt
)
const
{
std
::
vector
<
IdArray
>
ImmutableGraph
::
GetAdj
(
bool
transpose
,
const
std
::
string
&
fmt
)
const
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment