Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
14d88497
"vscode:/vscode.git/clone" did not exist on "d873acc2545e5b73be75d0e18cedfe7163febf88"
Commit
14d88497
authored
Sep 13, 2018
by
Minjie Wang
Browse files
impl
parent
2c626b90
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
457 additions
and
42 deletions
+457
-42
include/dgl/graph.h
include/dgl/graph.h
+88
-16
include/dgl/vector_view.h
include/dgl/vector_view.h
+56
-13
src/graph/graph.cc
src/graph/graph.cc
+313
-13
No files found.
include/dgl/graph.h
View file @
14d88497
...
@@ -13,23 +13,50 @@ typedef tvm::runtime::NDArray IdArray;
...
@@ -13,23 +13,50 @@ typedef tvm::runtime::NDArray IdArray;
typedef
tvm
::
runtime
::
NDArray
DegreeArray
;
typedef
tvm
::
runtime
::
NDArray
DegreeArray
;
typedef
tvm
::
runtime
::
NDArray
BoolArray
;
typedef
tvm
::
runtime
::
NDArray
BoolArray
;
class
DGLGraph
;
class
Graph
;
class
DGLSubGraph
;
/*!
/*!
* \brief Base dgl graph class.
* \brief Base dgl graph class.
*
*
* DGL
G
raph is
a
directed
graph
. Vertices are integers enumerated from zero. Edges
* DGL
's g
raph is directed. Vertices are integers enumerated from zero. Edges
* are uniquely identified by the two endpoints. Multi-edge is currently not
* are uniquely identified by the two endpoints. Multi-edge is currently not
* supported.
* supported.
*
*
* Removal of vertices/edges is not allowed. Instead, the graph can only be "cleared"
* Removal of vertices/edges is not allowed. Instead, the graph can only be "cleared"
* by removing all the vertices and edges.
* by removing all the vertices and edges.
*
* When calling functions supporing multiple edges (e.g. AddEdges, HasEdges),
* the input edges are represented by two id arrays for source and destination
* vertex ids. In the general case, the two arrays should have the same length.
* If the length of src id array is one, it represents one-many connections.
* If the length of dst id array is one, it represents many-one connections.
*/
*/
class
DGL
Graph
{
class
Graph
{
public:
public:
/*! \brief default constructor */
/*! \brief default constructor */
DGLGraph
()
{}
Graph
()
{}
/*! \brief default copy constructor */
Graph
(
const
Graph
&
other
)
=
default
;
#ifndef _MSC_VER
/*! \brief default move constructor */
Graph
(
Graph
&&
other
)
=
default
;
#else
Graph
(
Graph
&&
other
)
{
adjlist_
=
other
.
adjlist_
;
read_only_
=
other
.
read_only_
;
num_edges_
=
other
.
num_edges_
;
other
.
clear
();
}
#endif // _MSC_VER
/*! \brief default assign constructor */
Graph
&
operator
=
(
const
Graph
&
other
)
=
default
;
/*! \brief default destructor */
~
Graph
()
=
default
;
/*!
/*!
* \brief Add vertices to the graph.
* \brief Add vertices to the graph.
* \note Since vertices are integers enumerated from zero, only the number of
* \note Since vertices are integers enumerated from zero, only the number of
...
@@ -37,46 +64,68 @@ class DGLGraph {
...
@@ -37,46 +64,68 @@ class DGLGraph {
* \param num_vertices The number of vertices to be added.
* \param num_vertices The number of vertices to be added.
*/
*/
void
AddVertices
(
uint64_t
num_vertices
);
void
AddVertices
(
uint64_t
num_vertices
);
/*!
/*!
* \brief Add one edge to the graph.
* \brief Add one edge to the graph.
* \param src The source vertex.
* \param src The source vertex.
* \param dst The destination vertex.
* \param dst The destination vertex.
*/
*/
void
AddEdge
(
dgl_id_t
src
,
dgl_id_t
dst
);
void
AddEdge
(
dgl_id_t
src
,
dgl_id_t
dst
);
/*!
/*!
* \brief Add edges to the graph.
* \brief Add edges to the graph.
* \param src_ids The source vertex id array.
* \param src_ids The source vertex id array.
* \param dst_ids The destination vertex id array.
* \param dst_ids The destination vertex id array.
*/
*/
void
AddEdges
(
IdArray
src_ids
,
IdArray
dst_ids
);
void
AddEdges
(
IdArray
src_ids
,
IdArray
dst_ids
);
/*!
/*!
* \brief Clear the graph. Remove all vertices/edges.
* \brief Clear the graph. Remove all vertices/edges.
*/
*/
void
Clear
();
void
Clear
()
{
adjlist_
=
vector_view
<
EdgeList
>
();
read_only_
=
false
;
num_edges_
=
0
;
}
/*! \return the number of vertices in the graph.*/
/*! \return the number of vertices in the graph.*/
uint64_t
NumVertices
()
const
;
uint64_t
NumVertices
()
const
{
return
adjlist_
.
size
();
}
/*! \return the number of edges in the graph.*/
/*! \return the number of edges in the graph.*/
uint64_t
NumEdges
()
const
;
uint64_t
NumEdges
()
const
{
return
num_edges_
;
}
/*! \return true if the given vertex is in the graph.*/
/*! \return true if the given vertex is in the graph.*/
bool
HasVertex
(
dgl_id_t
vid
)
const
;
bool
HasVertex
(
dgl_id_t
vid
)
const
{
return
vid
<
NumVertices
();
}
/*! \return a 0-1 array indicating whether the given vertices are in the graph.*/
/*! \return a 0-1 array indicating whether the given vertices are in the graph.*/
BoolArray
HasVertices
(
IdArray
vids
)
const
;
BoolArray
HasVertices
(
IdArray
vids
)
const
;
/*! \return true if the given edge is in the graph.*/
/*! \return true if the given edge is in the graph.*/
bool
HasEdge
(
dgl_id_t
src
,
dgl_id_t
dst
)
const
;
bool
HasEdge
(
dgl_id_t
src
,
dgl_id_t
dst
)
const
;
/*! \return a 0-1 array indicating whether the given edges are in the graph.*/
/*! \return a 0-1 array indicating whether the given edges are in the graph.*/
BoolArray
HasEdges
(
IdArray
src_ids
,
IdArray
dst_ids
)
const
;
BoolArray
HasEdges
(
IdArray
src_ids
,
IdArray
dst_ids
)
const
;
/*!
/*!
* \brief Find the predecessors of a vertex.
* \brief Find the predecessors of a vertex.
* \param vid The vertex id.
* \param vid The vertex id.
* \return the predecessor id array.
* \return the predecessor id array.
*/
*/
IdArray
Predecessors
(
dgl_id_t
vid
)
const
;
IdArray
Predecessors
(
dgl_id_t
vid
)
const
;
/*!
/*!
* \brief Find the successors of a vertex.
* \brief Find the successors of a vertex.
* \param vid The vertex id.
* \param vid The vertex id.
* \return the successor id array.
* \return the successor id array.
*/
*/
IdArray
Successors
(
dgl_id_t
vid
)
const
;
IdArray
Successors
(
dgl_id_t
vid
)
const
;
/*!
/*!
* \brief Get the id of the given edge.
* \brief Get the id of the given edge.
* \note Edges are associated with an integer id start from zero.
* \note Edges are associated with an integer id start from zero.
...
@@ -86,6 +135,7 @@ class DGLGraph {
...
@@ -86,6 +135,7 @@ class DGLGraph {
* \return the edge id.
* \return the edge id.
*/
*/
dgl_id_t
EdgeId
(
dgl_id_t
src
,
dgl_id_t
dst
)
const
;
dgl_id_t
EdgeId
(
dgl_id_t
src
,
dgl_id_t
dst
)
const
;
/*!
/*!
* \brief Get the id of the given edges.
* \brief Get the id of the given edges.
* \note Edges are associated with an integer id start from zero.
* \note Edges are associated with an integer id start from zero.
...
@@ -93,59 +143,77 @@ class DGLGraph {
...
@@ -93,59 +143,77 @@ class DGLGraph {
* \return the edge id array.
* \return the edge id array.
*/
*/
IdArray
EdgeIds
(
IdArray
src
,
IdArray
dst
)
const
;
IdArray
EdgeIds
(
IdArray
src
,
IdArray
dst
)
const
;
/*!
/*!
* \brief Get the in edges of the vertex.
* \brief Get the in edges of the vertex.
* \note The returned dst id array is filled with vid.
* \param vid The vertex id.
* \param vid The vertex id.
* \return the id arrays of the two endpoints of the edges.
* \return the id arrays of the two endpoints of the edges.
*/
*/
std
::
pair
<
IdArray
,
IdArray
>
InEdges
(
dgl_id_t
vid
)
const
;
std
::
pair
<
IdArray
,
IdArray
>
InEdges
(
dgl_id_t
vid
)
const
;
/*!
/*!
* \brief Get the in edges of the vertices.
* \brief Get the in edges of the vertices.
* \param vids The vertex id array.
* \param vids The vertex id array.
* \return the id arrays of the two endpoints of the edges.
* \return the id arrays of the two endpoints of the edges.
*/
*/
std
::
pair
<
IdArray
,
IdArray
>
InEdges
(
IdArray
vids
)
const
;
std
::
pair
<
IdArray
,
IdArray
>
InEdges
(
IdArray
vids
)
const
;
/*!
/*!
* \brief Get the out edges of the vertex.
* \brief Get the out edges of the vertex.
* \note The returned src id array is filled with vid.
* \param vid The vertex id.
* \param vid The vertex id.
* \return the id arrays of the two endpoints of the edges.
* \return the id arrays of the two endpoints of the edges.
*/
*/
std
::
pair
<
IdArray
,
IdArray
>
OutEdges
(
dgl_id_t
vid
)
const
;
std
::
pair
<
IdArray
,
IdArray
>
OutEdges
(
dgl_id_t
vid
)
const
;
/*!
/*!
* \brief Get the out edges of the vertices.
* \brief Get the out edges of the vertices.
* \param vids The vertex id array.
* \param vids The vertex id array.
* \return the id arrays of the two endpoints of the edges.
* \return the id arrays of the two endpoints of the edges.
*/
*/
std
::
pair
<
IdArray
,
IdArray
>
OutEdges
(
IdArray
vids
)
const
;
std
::
pair
<
IdArray
,
IdArray
>
OutEdges
(
IdArray
vids
)
const
;
/*!
/*!
* \brief Get all the edges in the graph.
* \brief Get all the edges in the graph.
* \return the id arrays of the two endpoints of the edges.
* \return the id arrays of the two endpoints of the edges.
*/
*/
std
::
pair
<
IdArray
,
IdArray
>
Edges
()
const
;
std
::
pair
<
IdArray
,
IdArray
>
Edges
()
const
;
/*!
/*!
* \brief Get the in degree of the given vertex.
* \brief Get the in degree of the given vertex.
* \param vid The vertex id.
* \param vid The vertex id.
* \return the in degree
* \return the in degree
*/
*/
uint64_t
InDegree
(
dgl_id_t
vid
)
const
;
uint64_t
InDegree
(
dgl_id_t
vid
)
const
{
CHECK
(
HasVertex
(
vid
))
<<
"invalid vertex: "
<<
vid
;
return
adjlist_
[
vid
].
pred
.
size
();
}
/*!
/*!
* \brief Get the in degrees of the given vertices.
* \brief Get the in degrees of the given vertices.
* \param vid The vertex id array.
* \param vid The vertex id array.
* \return the in degree array
* \return the in degree array
*/
*/
DegreeArray
InDegrees
(
IdArray
vids
)
const
;
DegreeArray
InDegrees
(
IdArray
vids
)
const
;
/*!
/*!
* \brief Get the out degree of the given vertex.
* \brief Get the out degree of the given vertex.
* \param vid The vertex id.
* \param vid The vertex id.
* \return the out degree
* \return the out degree
*/
*/
uint64_t
OutDegree
(
dgl_id_t
vid
)
const
;
uint64_t
OutDegree
(
dgl_id_t
vid
)
const
{
CHECK
(
HasVertex
(
vid
))
<<
"invalid vertex: "
<<
vid
;
return
adjlist_
[
vid
].
succ
.
size
();
}
/*!
/*!
* \brief Get the out degrees of the given vertices.
* \brief Get the out degrees of the given vertices.
* \param vid The vertex id array.
* \param vid The vertex id array.
* \return the out degree array
* \return the out degree array
*/
*/
DegreeArray
OutDegrees
(
IdArray
vids
)
const
;
DegreeArray
OutDegrees
(
IdArray
vids
)
const
;
/*!
/*!
* \brief Construct the induced subgraph of the given vertices.
* \brief Construct the induced subgraph of the given vertices.
*
*
...
@@ -162,7 +230,8 @@ class DGLGraph {
...
@@ -162,7 +230,8 @@ class DGLGraph {
* \param vids The vertices in the subgraph.
* \param vids The vertices in the subgraph.
* \return the induced subgraph
* \return the induced subgraph
*/
*/
DGLGraph
Subgraph
(
IdArray
vids
)
const
;
Graph
Subgraph
(
IdArray
vids
)
const
;
/*!
/*!
* \brief Construct the induced edge subgraph of the given edges.
* \brief Construct the induced edge subgraph of the given edges.
*
*
...
@@ -179,7 +248,8 @@ class DGLGraph {
...
@@ -179,7 +248,8 @@ class DGLGraph {
* \param vids The edges in the subgraph.
* \param vids The edges in the subgraph.
* \return the induced edge subgraph
* \return the induced edge subgraph
*/
*/
DGLGraph
EdgeSubgraph
(
IdArray
src
,
IdArray
dst
)
const
;
Graph
EdgeSubgraph
(
IdArray
src
,
IdArray
dst
)
const
;
/*!
/*!
* \brief Return a new graph with all the edges reversed.
* \brief Return a new graph with all the edges reversed.
*
*
...
@@ -187,7 +257,7 @@ class DGLGraph {
...
@@ -187,7 +257,7 @@ class DGLGraph {
*
*
* \return the reversed graph
* \return the reversed graph
*/
*/
DGL
Graph
Reverse
()
const
;
Graph
Reverse
()
const
;
private:
private:
/*! \brief Internal edge list type */
/*! \brief Internal edge list type */
...
@@ -196,7 +266,7 @@ class DGLGraph {
...
@@ -196,7 +266,7 @@ class DGLGraph {
vector_view
<
dgl_id_t
>
succ
;
vector_view
<
dgl_id_t
>
succ
;
/*! \brief predecessor vertex list */
/*! \brief predecessor vertex list */
vector_view
<
dgl_id_t
>
pred
;
vector_view
<
dgl_id_t
>
pred
;
/*! \brief (local) edge id property */
/*! \brief (local)
succ
edge id property */
std
::
vector
<
dgl_id_t
>
edge_id
;
std
::
vector
<
dgl_id_t
>
edge_id
;
};
};
/*! \brief Adjacency list using vector storage */
/*! \brief Adjacency list using vector storage */
...
@@ -210,6 +280,8 @@ class DGLGraph {
...
@@ -210,6 +280,8 @@ class DGLGraph {
vector_view
<
EdgeList
>
adjlist_
;
vector_view
<
EdgeList
>
adjlist_
;
/*! \brief read only flag */
/*! \brief read only flag */
bool
read_only_
{
false
};
bool
read_only_
{
false
};
/*! \brief number of edges */
uint64_t
num_edges_
=
0
;
};
};
}
// namespace dgl
}
// namespace dgl
...
...
include/dgl/vector_view.h
View file @
14d88497
...
@@ -15,10 +15,42 @@ namespace dgl {
...
@@ -15,10 +15,42 @@ namespace dgl {
*/
*/
template
<
typename
ValueType
>
template
<
typename
ValueType
>
class
vector_view
{
class
vector_view
{
struct
vector_view_iterator
;
public:
public:
typedef
vector_view_iterator
iterator
;
/*! \brief iterator class */
class
iterator
:
public
std
::
iterator
<
std
::
forward_iterator_tag
,
ValueType
>
{
public:
/*! \brief iterator constructor */
iterator
(
const
vector_view
<
ValueType
>*
vec
,
size_t
pos
)
:
vec_
(
vec
),
pos_
(
pos
)
{}
/*! \brief move to next */
iterator
&
operator
++
()
{
++
pos_
;
return
*
this
;
}
/*! \brief move to next */
iterator
operator
++
(
int
)
{
iterator
retval
=
*
this
;
++
(
*
this
);
return
retval
;
}
/*! \brief equal operator */
bool
operator
==
(
iterator
other
)
const
{
return
vec_
==
other
.
vec_
and
pos_
==
other
.
pos_
;
}
/*! \brief not equal operator */
bool
operator
!=
(
iterator
other
)
const
{
return
!
(
*
this
==
other
);
}
/*! \brief dereference operator */
const
ValueType
&
operator
*
()
const
{
return
(
*
vec_
)[
pos_
];
}
private:
/*! \brief vector_view pointer */
const
vector_view
<
ValueType
>*
vec_
;
/*! \brief current position */
size_t
pos_
;
};
/*! \brief Default constructor. Create an empty vector. */
/*! \brief Default constructor. Create an empty vector. */
vector_view
()
vector_view
()
:
data_
(
std
::
make_shared
<
std
::
vector
<
ValueType
>
>
())
{}
:
data_
(
std
::
make_shared
<
std
::
vector
<
ValueType
>
>
())
{}
...
@@ -28,7 +60,7 @@ class vector_view {
...
@@ -28,7 +60,7 @@ class vector_view {
:
data_
(
vec
.
data_
),
index_
(
index
),
is_view_
(
true
)
{}
:
data_
(
vec
.
data_
),
index_
(
index
),
is_view_
(
true
)
{}
/*! \brief constructor from a vector pointer */
/*! \brief constructor from a vector pointer */
vector_view
(
const
std
::
shared_ptr
<
std
::
vector
<
ValueType
>
>&
other
)
explicit
vector_view
(
const
std
::
shared_ptr
<
std
::
vector
<
ValueType
>
>&
other
)
:
data_
(
other
)
{}
:
data_
(
other
)
{}
/*! \brief default copy constructor */
/*! \brief default copy constructor */
...
@@ -53,7 +85,7 @@ class vector_view {
...
@@ -53,7 +85,7 @@ class vector_view {
~
vector_view
()
=
default
;
~
vector_view
()
=
default
;
/*! \brief default assign constructor */
/*! \brief default assign constructor */
vector_view
&
operator
=
(
const
vector_view
<
ValueType
>&
other
)
=
default
;
vector_view
<
ValueType
>
&
operator
=
(
const
vector_view
<
ValueType
>&
other
)
=
default
;
/*! \return the size of the vector */
/*! \return the size of the vector */
size_t
size
()
const
{
size_t
size
()
const
{
...
@@ -87,11 +119,15 @@ class vector_view {
...
@@ -87,11 +119,15 @@ class vector_view {
}
}
}
}
// TODO(minjie)
/*! \return an iterator pointing at the first element */
iterator
begin
()
const
;
iterator
begin
()
const
{
return
iterator
(
this
,
0
);
}
// TODO(minjie)
/*! \return an iterator pointing at the last element */
iterator
end
()
const
;
iterator
end
()
const
{
return
iterator
(
this
,
size
());
}
// Modifiers
// Modifiers
// NOTE: The modifiers are not allowed for view.
// NOTE: The modifiers are not allowed for view.
...
@@ -112,10 +148,17 @@ class vector_view {
...
@@ -112,10 +148,17 @@ class vector_view {
data_
=
std
::
make_shared
<
std
::
vector
<
ValueType
>
>
();
data_
=
std
::
make_shared
<
std
::
vector
<
ValueType
>
>
();
}
}
private:
/*! \brief Resize the vector */
struct
vector_view_iterator
{
void
resize
(
size_t
size
)
{
// TODO
CHECK
(
!
is_view_
);
};
data_
->
resize
(
size
);
}
/*! \brief Resize the vector with init value */
void
resize
(
size_t
size
,
const
ValueType
&
val
)
{
CHECK
(
!
is_view_
);
data_
->
resize
(
size
,
val
);
}
private:
private:
/*! \brief pointer to the underlying vector data */
/*! \brief pointer to the underlying vector data */
...
...
src/graph/graph.cc
View file @
14d88497
#include <dgl/runtime/packed_func.h>
// Graph class implementation
#include <dgl/runtime/registry.h>
#include <algorithm>
#include <dgl/graph.h>
using
namespace
tvm
;
namespace
dgl
{
using
namespace
tvm
::
runtime
;
namespace
{
inline
bool
IsValidIdArray
(
const
IdArray
&
arr
)
{
return
arr
->
ctx
.
device_type
==
kDLCPU
&&
arr
->
ndim
==
1
&&
arr
->
dtype
.
code
==
kDLInt
&&
arr
->
dtype
.
bits
==
64
;
}
}
// namespace
void
Graph
::
AddVertices
(
uint64_t
num_vertices
)
{
CHECK
(
!
read_only_
)
<<
"Graph is read-only. Mutations are not allowed."
;
adjlist_
.
resize
(
adjlist_
.
size
()
+
num_vertices
);
}
void
Graph
::
AddEdge
(
dgl_id_t
src
,
dgl_id_t
dst
)
{
CHECK
(
!
read_only_
)
<<
"Graph is read-only. Mutations are not allowed."
;
CHECK
(
HasVertex
(
src
)
&&
HasVertex
(
dst
))
<<
"In valid vertices: "
<<
src
<<
" "
<<
dst
;
dgl_id_t
eid
=
num_edges_
++
;
adjlist_
[
src
].
succ
.
push_back
(
dst
);
adjlist_
[
src
].
edge_id
.
push_back
(
eid
);
adjlist_
[
dst
].
pred
.
push_back
(
src
);
}
void
Graph
::
AddEdges
(
IdArray
src_ids
,
IdArray
dst_ids
)
{
CHECK
(
!
read_only_
)
<<
"Graph is read-only. Mutations are not allowed."
;
CHECK
(
IsValidIdArray
(
src_ids
))
<<
"Invalid src id array."
;
CHECK
(
IsValidIdArray
(
dst_ids
))
<<
"Invalid dst id array."
;
const
auto
srclen
=
src_ids
->
shape
[
0
];
const
auto
dstlen
=
src_ids
->
shape
[
0
];
const
int64_t
*
src_data
=
static_cast
<
int64_t
*>
(
src_ids
->
data
);
const
int64_t
*
dst_data
=
static_cast
<
int64_t
*>
(
dst_ids
->
data
);
if
(
srclen
==
1
)
{
// one-many
for
(
int64_t
i
=
0
;
i
<
dstlen
;
++
i
)
{
AddEdge
(
src_data
[
0
],
dst_data
[
i
]);
}
}
else
if
(
dstlen
==
1
)
{
// many-one
for
(
int64_t
i
=
0
;
i
<
srclen
;
++
i
)
{
AddEdge
(
src_data
[
i
],
dst_data
[
0
]);
}
}
else
{
// many-many
CHECK
(
srclen
==
dstlen
)
<<
"Invalid src and dst id array."
;
for
(
int64_t
i
=
0
;
i
<
srclen
;
++
i
)
{
AddEdge
(
src_data
[
i
],
dst_data
[
i
]);
}
}
}
BoolArray
Graph
::
HasVertices
(
IdArray
vids
)
const
{
CHECK
(
IsValidIdArray
(
vids
))
<<
"Invalid vertex id array."
;
const
auto
len
=
vids
->
shape
[
0
];
BoolArray
rst
=
BoolArray
::
Empty
({
len
},
vids
->
dtype
,
vids
->
ctx
);
const
int64_t
*
vid_data
=
static_cast
<
int64_t
*>
(
vids
->
data
);
int64_t
*
rst_data
=
static_cast
<
int64_t
*>
(
rst
->
data
);
const
uint64_t
nverts
=
NumVertices
();
for
(
int64_t
i
=
0
;
i
<
len
;
++
i
)
{
rst_data
[
i
]
=
(
vid_data
[
i
]
<
nverts
)
?
1
:
0
;
}
return
rst
;
}
// O(E)
bool
Graph
::
HasEdge
(
dgl_id_t
src
,
dgl_id_t
dst
)
const
{
if
(
!
HasVertex
(
src
)
||
!
HasVertex
(
dst
))
return
false
;
const
auto
&
succ
=
adjlist_
[
src
].
succ
;
return
std
::
find
(
succ
.
begin
(),
succ
.
end
(),
dst
)
!=
succ
.
end
();
}
// O(E*K) pretty slow
BoolArray
Graph
::
HasEdges
(
IdArray
src_ids
,
IdArray
dst_ids
)
const
{
CHECK
(
IsValidIdArray
(
src_ids
))
<<
"Invalid src id array."
;
CHECK
(
IsValidIdArray
(
dst_ids
))
<<
"Invalid dst id array."
;
const
auto
srclen
=
src_ids
->
shape
[
0
];
const
auto
dstlen
=
src_ids
->
shape
[
0
];
const
auto
rstlen
=
std
::
max
(
srclen
,
dstlen
);
BoolArray
rst
=
BoolArray
::
Empty
({
rstlen
},
src_ids
->
dtype
,
src_ids
->
ctx
);
int64_t
*
rst_data
=
static_cast
<
int64_t
*>
(
rst
->
data
);
const
int64_t
*
src_data
=
static_cast
<
int64_t
*>
(
src_ids
->
data
);
const
int64_t
*
dst_data
=
static_cast
<
int64_t
*>
(
dst_ids
->
data
);
if
(
srclen
==
1
)
{
// one-many
for
(
int64_t
i
=
0
;
i
<
dstlen
;
++
i
)
{
rst_data
[
i
]
=
HasEdge
(
src_data
[
0
],
dst_data
[
i
])
?
1
:
0
;
}
}
else
if
(
dstlen
==
1
)
{
// many-one
for
(
int64_t
i
=
0
;
i
<
srclen
;
++
i
)
{
rst_data
[
i
]
=
HasEdge
(
src_data
[
i
],
dst_data
[
0
])
?
1
:
0
;
}
}
else
{
// many-many
CHECK
(
srclen
==
dstlen
)
<<
"Invalid src and dst id array."
;
for
(
int64_t
i
=
0
;
i
<
srclen
;
++
i
)
{
rst_data
[
i
]
=
HasEdge
(
src_data
[
i
],
dst_data
[
i
])
?
1
:
0
;
}
}
return
rst
;
}
// The data is copy-out; support zero-copy?
IdArray
Graph
::
Predecessors
(
dgl_id_t
vid
)
const
{
CHECK
(
HasVertex
(
vid
))
<<
"invalid vertex: "
<<
vid
;
const
auto
&
pred
=
adjlist_
[
vid
].
pred
;
const
int64_t
len
=
pred
.
size
();
IdArray
rst
=
IdArray
::
Empty
({
len
},
DLDataType
{
kDLInt
,
64
,
1
},
DLContext
{
kDLCPU
,
0
});
int64_t
*
rst_data
=
static_cast
<
int64_t
*>
(
rst
->
data
);
for
(
int64_t
i
=
0
;
i
<
len
;
++
i
)
{
rst_data
[
i
]
=
pred
[
i
];
}
return
rst
;
}
// The data is copy-out; support zero-copy?
IdArray
Graph
::
Successors
(
dgl_id_t
vid
)
const
{
CHECK
(
HasVertex
(
vid
))
<<
"invalid vertex: "
<<
vid
;
const
auto
&
succ
=
adjlist_
[
vid
].
succ
;
const
int64_t
len
=
succ
.
size
();
IdArray
rst
=
IdArray
::
Empty
({
len
},
DLDataType
{
kDLInt
,
64
,
1
},
DLContext
{
kDLCPU
,
0
});
int64_t
*
rst_data
=
static_cast
<
int64_t
*>
(
rst
->
data
);
for
(
int64_t
i
=
0
;
i
<
len
;
++
i
)
{
rst_data
[
i
]
=
succ
[
i
];
}
return
rst
;
}
// O(E)
dgl_id_t
Graph
::
EdgeId
(
dgl_id_t
src
,
dgl_id_t
dst
)
const
{
CHECK
(
HasVertex
(
src
))
<<
"invalid edge: "
<<
src
<<
" -> "
<<
dst
;
const
auto
&
succ
=
adjlist_
[
src
].
succ
;
for
(
size_t
i
=
0
;
i
<
succ
.
size
();
++
i
)
{
if
(
succ
[
i
]
==
dst
)
{
return
adjlist_
[
src
].
edge_id
[
i
];
}
}
LOG
(
FATAL
)
<<
"invalid edge: "
<<
src
<<
" -> "
<<
dst
;
return
0
;
}
// O(E*k) pretty slow
IdArray
Graph
::
EdgeIds
(
IdArray
src_ids
,
IdArray
dst_ids
)
const
{
CHECK
(
IsValidIdArray
(
src_ids
))
<<
"Invalid src id array."
;
CHECK
(
IsValidIdArray
(
dst_ids
))
<<
"Invalid dst id array."
;
const
auto
srclen
=
src_ids
->
shape
[
0
];
const
auto
dstlen
=
src_ids
->
shape
[
0
];
const
auto
rstlen
=
std
::
max
(
srclen
,
dstlen
);
IdArray
rst
=
IdArray
::
Empty
({
rstlen
},
src_ids
->
dtype
,
src_ids
->
ctx
);
int64_t
*
rst_data
=
static_cast
<
int64_t
*>
(
rst
->
data
);
const
int64_t
*
src_data
=
static_cast
<
int64_t
*>
(
src_ids
->
data
);
const
int64_t
*
dst_data
=
static_cast
<
int64_t
*>
(
dst_ids
->
data
);
if
(
srclen
==
1
)
{
// one-many
for
(
int64_t
i
=
0
;
i
<
dstlen
;
++
i
)
{
rst_data
[
i
]
=
EdgeId
(
src_data
[
0
],
dst_data
[
i
]);
}
}
else
if
(
dstlen
==
1
)
{
// many-one
for
(
int64_t
i
=
0
;
i
<
srclen
;
++
i
)
{
rst_data
[
i
]
=
EdgeId
(
src_data
[
i
],
dst_data
[
0
]);
}
}
else
{
// many-many
CHECK
(
srclen
==
dstlen
)
<<
"Invalid src and dst id array."
;
for
(
int64_t
i
=
0
;
i
<
srclen
;
++
i
)
{
rst_data
[
i
]
=
EdgeId
(
src_data
[
i
],
dst_data
[
i
]);
}
}
return
rst
;
}
// O(E)
std
::
pair
<
IdArray
,
IdArray
>
Graph
::
InEdges
(
dgl_id_t
vid
)
const
{
const
auto
&
src
=
Predecessors
(
vid
);
const
auto
srclen
=
src
->
shape
[
0
];
IdArray
dst
=
IdArray
::
Empty
({
srclen
},
src
->
dtype
,
src
->
ctx
);
int64_t
*
dst_data
=
static_cast
<
int64_t
*>
(
dst
->
data
);
std
::
fill
(
dst_data
,
dst_data
+
srclen
,
vid
);
return
std
::
make_pair
(
src
,
dst
);
}
// O(E)
std
::
pair
<
IdArray
,
IdArray
>
Graph
::
InEdges
(
IdArray
vids
)
const
{
CHECK
(
IsValidIdArray
(
vids
))
<<
"Invalid vertex id array."
;
const
auto
len
=
vids
->
shape
[
0
];
const
int64_t
*
vid_data
=
static_cast
<
int64_t
*>
(
vids
->
data
);
int64_t
rstlen
=
0
;
for
(
int64_t
i
=
0
;
i
<
len
;
++
i
)
{
CHECK
(
HasVertex
(
vid_data
[
i
]))
<<
"Invalid vertex: "
<<
vid_data
[
i
];
rstlen
+=
adjlist_
[
vid_data
[
i
]].
pred
.
size
();
}
IdArray
src
=
IdArray
::
Empty
({
rstlen
},
vids
->
dtype
,
vids
->
ctx
);
IdArray
dst
=
IdArray
::
Empty
({
rstlen
},
vids
->
dtype
,
vids
->
ctx
);
int64_t
*
src_ptr
=
static_cast
<
int64_t
*>
(
src
->
data
);
int64_t
*
dst_ptr
=
static_cast
<
int64_t
*>
(
dst
->
data
);
for
(
int64_t
i
=
0
;
i
<
len
;
++
i
)
{
const
auto
&
pred
=
adjlist_
[
vid_data
[
i
]].
pred
;
for
(
size_t
j
=
0
;
j
<
pred
.
size
();
++
j
)
{
*
(
src_ptr
++
)
=
pred
[
j
];
*
(
dst_ptr
++
)
=
vid_data
[
i
];
}
}
return
std
::
make_pair
(
src
,
dst
);
}
// O(E)
std
::
pair
<
IdArray
,
IdArray
>
Graph
::
OutEdges
(
dgl_id_t
vid
)
const
{
const
auto
&
dst
=
Successors
(
vid
);
const
auto
dstlen
=
dst
->
shape
[
0
];
IdArray
src
=
IdArray
::
Empty
({
dstlen
},
dst
->
dtype
,
dst
->
ctx
);
int64_t
*
src_data
=
static_cast
<
int64_t
*>
(
src
->
data
);
std
::
fill
(
src_data
,
src_data
+
dstlen
,
vid
);
return
std
::
make_pair
(
src
,
dst
);
}
// O(E)
std
::
pair
<
IdArray
,
IdArray
>
Graph
::
OutEdges
(
IdArray
vids
)
const
{
CHECK
(
IsValidIdArray
(
vids
))
<<
"Invalid vertex id array."
;
const
auto
len
=
vids
->
shape
[
0
];
const
int64_t
*
vid_data
=
static_cast
<
int64_t
*>
(
vids
->
data
);
int64_t
rstlen
=
0
;
for
(
int64_t
i
=
0
;
i
<
len
;
++
i
)
{
CHECK
(
HasVertex
(
vid_data
[
i
]))
<<
"Invalid vertex: "
<<
vid_data
[
i
];
rstlen
+=
adjlist_
[
vid_data
[
i
]].
succ
.
size
();
}
IdArray
src
=
IdArray
::
Empty
({
rstlen
},
vids
->
dtype
,
vids
->
ctx
);
IdArray
dst
=
IdArray
::
Empty
({
rstlen
},
vids
->
dtype
,
vids
->
ctx
);
int64_t
*
src_ptr
=
static_cast
<
int64_t
*>
(
src
->
data
);
int64_t
*
dst_ptr
=
static_cast
<
int64_t
*>
(
dst
->
data
);
for
(
int64_t
i
=
0
;
i
<
len
;
++
i
)
{
const
auto
&
succ
=
adjlist_
[
vid_data
[
i
]].
succ
;
for
(
size_t
j
=
0
;
j
<
succ
.
size
();
++
j
)
{
*
(
src_ptr
++
)
=
vid_data
[
i
];
*
(
dst_ptr
++
)
=
succ
[
j
];
}
}
return
std
::
make_pair
(
src
,
dst
);
}
// O(E*log(E)) due to sorting
std
::
pair
<
IdArray
,
IdArray
>
Graph
::
Edges
()
const
{
const
int64_t
len
=
num_edges_
;
typedef
std
::
tuple
<
int64_t
,
int64_t
,
int64_t
>
Tuple
;
std
::
vector
<
Tuple
>
tuples
;
tuples
.
reserve
(
len
);
for
(
dgl_id_t
u
=
0
;
u
<
NumVertices
();
++
u
)
{
for
(
size_t
i
=
0
;
i
<
adjlist_
[
u
].
succ
.
size
();
++
i
)
{
tuples
.
push_back
(
std
::
make_tuple
(
u
,
adjlist_
[
u
].
succ
[
i
],
adjlist_
[
u
].
edge_id
[
i
]));
}
}
// sort according to edge ids
std
::
sort
(
tuples
.
begin
(),
tuples
.
end
(),
[]
(
const
Tuple
&
t1
,
const
Tuple
&
t2
)
{
return
std
::
get
<
2
>
(
t1
)
<
std
::
get
<
2
>
(
t2
);
});
// make return arrays
IdArray
src
=
IdArray
::
Empty
({
len
},
DLDataType
{
kDLInt
,
64
,
1
},
DLContext
{
kDLCPU
,
0
});
IdArray
dst
=
IdArray
::
Empty
({
len
},
DLDataType
{
kDLInt
,
64
,
1
},
DLContext
{
kDLCPU
,
0
});
int64_t
*
src_ptr
=
static_cast
<
int64_t
*>
(
src
->
data
);
int64_t
*
dst_ptr
=
static_cast
<
int64_t
*>
(
dst
->
data
);
for
(
int64_t
i
=
0
;
i
<
len
;
++
i
)
{
src_ptr
[
i
]
=
std
::
get
<
0
>
(
tuples
[
i
]);
dst_ptr
[
i
]
=
std
::
get
<
1
>
(
tuples
[
i
]);
}
return
std
::
make_pair
(
src
,
dst
);
}
// O(V)
DegreeArray
Graph
::
InDegrees
(
IdArray
vids
)
const
{
CHECK
(
IsValidIdArray
(
vids
))
<<
"Invalid vertex id array."
;
const
auto
len
=
vids
->
shape
[
0
];
const
int64_t
*
vid_data
=
static_cast
<
int64_t
*>
(
vids
->
data
);
DegreeArray
rst
=
DegreeArray
::
Empty
({
len
},
vids
->
dtype
,
vids
->
ctx
);
int64_t
*
rst_data
=
static_cast
<
int64_t
*>
(
rst
->
data
);
for
(
int64_t
i
=
0
;
i
<
len
;
++
i
)
{
const
auto
vid
=
vid_data
[
i
];
CHECK
(
HasVertex
(
vid
))
<<
"Invalid vertex: "
<<
vid
;
rst_data
[
i
]
=
adjlist_
[
vid
].
pred
.
size
();
}
return
rst
;
}
// O(V)
DegreeArray
Graph
::
OutDegrees
(
IdArray
vids
)
const
{
CHECK
(
IsValidIdArray
(
vids
))
<<
"Invalid vertex id array."
;
const
auto
len
=
vids
->
shape
[
0
];
const
int64_t
*
vid_data
=
static_cast
<
int64_t
*>
(
vids
->
data
);
DegreeArray
rst
=
DegreeArray
::
Empty
({
len
},
vids
->
dtype
,
vids
->
ctx
);
int64_t
*
rst_data
=
static_cast
<
int64_t
*>
(
rst
->
data
);
for
(
int64_t
i
=
0
;
i
<
len
;
++
i
)
{
const
auto
vid
=
vid_data
[
i
];
CHECK
(
HasVertex
(
vid
))
<<
"Invalid vertex: "
<<
vid
;
rst_data
[
i
]
=
adjlist_
[
vid
].
succ
.
size
();
}
return
rst
;
}
Graph
Graph
::
Subgraph
(
IdArray
vids
)
const
{
LOG
(
FATAL
)
<<
"not implemented"
;
return
*
this
;
}
void
MyAdd
(
TVMArgs
args
,
TVMRetValue
*
rv
)
{
Graph
Graph
::
EdgeSubgraph
(
IdArray
src
,
IdArray
dst
)
const
{
int
a
=
args
[
0
];
LOG
(
FATAL
)
<<
"not implemented"
;
int
b
=
args
[
1
];
return
*
this
;
*
rv
=
a
+
b
;
}
}
void
CallPacked
()
{
Graph
Graph
::
Reverse
()
const
{
PackedFunc
myadd
=
PackedFunc
(
MyAdd
)
;
LOG
(
FATAL
)
<<
"not implemented"
;
int
c
=
myadd
(
1
,
2
)
;
return
*
this
;
}
}
TVM_REGISTER_GLOBAL
(
"myadd"
)
}
// namespace dgl
.
set_body
(
MyAdd
);
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment