Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
7b5d4c58
Commit
7b5d4c58
authored
Sep 20, 2018
by
Minjie Wang
Browse files
pass specialization test
parent
b2e4bdc0
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
91 additions
and
118 deletions
+91
-118
include/dgl/graph.h
include/dgl/graph.h
+20
-20
python/dgl/graph.py
python/dgl/graph.py
+15
-30
python/dgl/graph_index.py
python/dgl/graph_index.py
+25
-1
python/dgl/scheduler.py
python/dgl/scheduler.py
+3
-3
src/graph/graph.cc
src/graph/graph.cc
+27
-27
tests/pytorch/test_cached_graph.py
tests/pytorch/test_cached_graph.py
+0
-35
tests/pytorch/test_specialization.py
tests/pytorch/test_specialization.py
+1
-2
No files found.
include/dgl/graph.h
View file @
7b5d4c58
...
@@ -4,7 +4,6 @@
...
@@ -4,7 +4,6 @@
#include <stdint.h>
#include <stdint.h>
#include "runtime/ndarray.h"
#include "runtime/ndarray.h"
#include "./vector_view.h"
namespace
dgl
{
namespace
dgl
{
...
@@ -89,7 +88,8 @@ class Graph {
...
@@ -89,7 +88,8 @@ class Graph {
* \brief Clear the graph. Remove all vertices/edges.
* \brief Clear the graph. Remove all vertices/edges.
*/
*/
void
Clear
()
{
void
Clear
()
{
adjlist_
=
vector_view
<
EdgeList
>
();
adjlist_
.
clear
();
reverse_adjlist_
.
clear
();
read_only_
=
false
;
read_only_
=
false
;
num_edges_
=
0
;
num_edges_
=
0
;
}
}
...
@@ -184,8 +184,9 @@ class Graph {
...
@@ -184,8 +184,9 @@ class Graph {
/*!
/*!
* \brief Get all the edges in the graph.
* \brief Get all the edges in the graph.
* \note If sorted is true, the id array is not returned.
* \note If sorted is true, the returned edges list is sorted by their src and
* \param sorted Whether the returned edge list is sorted by their edge ids.
* dst ids. Otherwise, they are in their edge id order.
* \param sorted Whether the returned edge list is sorted by their src and dst ids
* \return the id arrays of the two endpoints of the edges.
* \return the id arrays of the two endpoints of the edges.
*/
*/
EdgeArray
Edges
(
bool
sorted
=
false
)
const
;
EdgeArray
Edges
(
bool
sorted
=
false
)
const
;
...
@@ -197,7 +198,7 @@ class Graph {
...
@@ -197,7 +198,7 @@ class Graph {
*/
*/
uint64_t
InDegree
(
dgl_id_t
vid
)
const
{
uint64_t
InDegree
(
dgl_id_t
vid
)
const
{
CHECK
(
HasVertex
(
vid
))
<<
"invalid vertex: "
<<
vid
;
CHECK
(
HasVertex
(
vid
))
<<
"invalid vertex: "
<<
vid
;
return
adjlist_
[
vid
].
pred
.
size
();
return
reverse_
adjlist_
[
vid
].
succ
.
size
();
}
}
/*!
/*!
...
@@ -277,23 +278,22 @@ class Graph {
...
@@ -277,23 +278,22 @@ class Graph {
/*! \brief Internal edge list type */
/*! \brief Internal edge list type */
struct
EdgeList
{
struct
EdgeList
{
/*! \brief successor vertex list */
/*! \brief successor vertex list */
vector
_view
<
dgl_id_t
>
succ
;
std
::
vector
<
dgl_id_t
>
succ
;
/*! \brief predecessor vertex list */
/*! \brief predecessor vertex list */
vector_view
<
dgl_id_t
>
pred
;
std
::
vector
<
dgl_id_t
>
edge_id
;
/*! \brief (local) succ edge id property */
std
::
vector
<
dgl_id_t
>
succ_edge_id
;
/*! \brief (local) pred edge id property */
std
::
vector
<
dgl_id_t
>
pred_edge_id
;
};
};
/*! \brief Adjacency list using vector storage */
typedef
std
::
vector
<
EdgeList
>
AdjacencyList
;
// TODO(minjie): adjacent list is good for graph mutation and finding pred/succ.
// It is not good for getting all the edges of the graph. If the graph is known
/*! \brief adjacency list using vector storage */
// to be static, how to design a data structure to speed this up? This idea can
AdjacencyList
adjlist_
;
// be further extended. For example, CSC/CSR graph storage is known to be more
/*! \brief reverse adjacency list using vector storage */
// compact than adjlist, but is difficult to be mutated. Shall we switch to a CSR/CSC
AdjacencyList
reverse_adjlist_
;
// graph structure at some point? When shall such conversion happen? Which one
// will more likely to be a bottleneck? memory or computation?
/*! \brief all edges' src endpoints in their edge id order */
vector_view
<
EdgeList
>
adjlist_
;
std
::
vector
<
dgl_id_t
>
all_edges_src_
;
/*! \brief all edges' dst endpoints in their edge id order */
std
::
vector
<
dgl_id_t
>
all_edges_dst_
;
/*! \brief read only flag */
/*! \brief read only flag */
bool
read_only_
{
false
};
bool
read_only_
{
false
};
/*! \brief number of edges */
/*! \brief number of edges */
...
...
python/dgl/graph.py
View file @
7b5d4c58
...
@@ -121,6 +121,10 @@ class DGLGraph(object):
...
@@ -121,6 +121,10 @@ class DGLGraph(object):
"""
"""
return
self
.
_graph
.
number_of_nodes
()
return
self
.
_graph
.
number_of_nodes
()
def
__len__
(
self
):
"""Return the number of nodes."""
return
self
.
number_of_nodes
()
def
number_of_edges
(
self
):
def
number_of_edges
(
self
):
"""Return the number of edges.
"""Return the number of edges.
...
@@ -145,6 +149,10 @@ class DGLGraph(object):
...
@@ -145,6 +149,10 @@ class DGLGraph(object):
True if the node exists
True if the node exists
"""
"""
return
self
.
has_node
(
vid
)
return
self
.
has_node
(
vid
)
def
__contains__
(
self
,
vid
):
"""Same as has_node."""
return
self
.
has_node
(
vid
)
def
has_nodes
(
self
,
vids
):
def
has_nodes
(
self
,
vids
):
"""Return true if the nodes exist.
"""Return true if the nodes exist.
...
@@ -319,7 +327,7 @@ class DGLGraph(object):
...
@@ -319,7 +327,7 @@ class DGLGraph(object):
Parameters
Parameters
----------
----------
sorted : bool
sorted : bool
True if the returned edges are sorted by their ids.
True if the returned edges are sorted by their
src and dst
ids.
Returns
Returns
-------
-------
...
@@ -543,29 +551,12 @@ class DGLGraph(object):
...
@@ -543,29 +551,12 @@ class DGLGraph(object):
v_is_all
=
is_all
(
v
)
v_is_all
=
is_all
(
v
)
assert
u_is_all
==
v_is_all
assert
u_is_all
==
v_is_all
if
u_is_all
:
if
u_is_all
:
num_edges
=
self
.
number_of_edges
(
)
self
.
set_e_repr_by_id
(
h_uv
,
eid
=
ALL
)
else
:
else
:
u
=
utils
.
toindex
(
u
)
u
=
utils
.
toindex
(
u
)
v
=
utils
.
toindex
(
v
)
v
=
utils
.
toindex
(
v
)
num_edges
=
max
(
len
(
u
),
len
(
v
))
if
utils
.
is_dict_like
(
h_uv
):
for
key
,
val
in
h_uv
.
items
():
assert
F
.
shape
(
val
)[
0
]
==
num_edges
else
:
assert
F
.
shape
(
h_uv
)[
0
]
==
num_edges
# set
if
u_is_all
:
if
utils
.
is_dict_like
(
h_uv
):
for
key
,
val
in
h_uv
.
items
():
self
.
_edge_frame
[
key
]
=
val
else
:
self
.
_edge_frame
[
__REPR__
]
=
h_uv
else
:
eid
=
self
.
_graph
.
edge_ids
(
u
,
v
)
eid
=
self
.
_graph
.
edge_ids
(
u
,
v
)
if
utils
.
is_dict_like
(
h_uv
):
self
.
set_e_repr_by_id
(
h_uv
,
eid
=
eid
)
self
.
_edge_frame
[
eid
]
=
h_uv
else
:
self
.
_edge_frame
[
eid
]
=
{
__REPR__
:
h_uv
}
def
set_e_repr_by_id
(
self
,
h_uv
,
eid
=
ALL
):
def
set_e_repr_by_id
(
self
,
h_uv
,
eid
=
ALL
):
"""Set edge(s) representation by edge id.
"""Set edge(s) representation by edge id.
...
@@ -622,18 +613,12 @@ class DGLGraph(object):
...
@@ -622,18 +613,12 @@ class DGLGraph(object):
if
len
(
self
.
edge_attr_schemes
())
==
0
:
if
len
(
self
.
edge_attr_schemes
())
==
0
:
return
dict
()
return
dict
()
if
u_is_all
:
if
u_is_all
:
if
len
(
self
.
_edge_frame
)
==
1
and
__REPR__
in
self
.
_edge_frame
:
return
self
.
get_e_repr_by_id
(
eid
=
ALL
)
return
self
.
_edge_frame
[
__REPR__
]
else
:
return
dict
(
self
.
_edge_frame
)
else
:
else
:
u
=
utils
.
toindex
(
u
)
u
=
utils
.
toindex
(
u
)
v
=
utils
.
toindex
(
v
)
v
=
utils
.
toindex
(
v
)
eid
=
self
.
_graph
.
edge_ids
(
u
,
v
)
eid
=
self
.
_graph
.
edge_ids
(
u
,
v
)
if
len
(
self
.
_edge_frame
)
==
1
and
__REPR__
in
self
.
_edge_frame
:
return
self
.
get_e_repr_by_id
(
eid
=
eid
)
return
self
.
_edge_frame
.
select_rows
(
eid
)[
__REPR__
]
else
:
return
self
.
_edge_frame
.
select_rows
(
eid
)
def
pop_e_repr
(
self
,
key
=
__REPR__
):
def
pop_e_repr
(
self
,
key
=
__REPR__
):
"""Get and remove the specified edge repr.
"""Get and remove the specified edge repr.
...
@@ -855,7 +840,7 @@ class DGLGraph(object):
...
@@ -855,7 +840,7 @@ class DGLGraph(object):
def
_batch_send
(
self
,
u
,
v
,
message_func
):
def
_batch_send
(
self
,
u
,
v
,
message_func
):
if
is_all
(
u
)
and
is_all
(
v
):
if
is_all
(
u
)
and
is_all
(
v
):
u
,
v
,
_
=
self
.
_graph
.
edges
(
sorted
=
True
)
u
,
v
,
_
=
self
.
_graph
.
edges
()
self
.
_msg_graph
.
add_edges
(
u
,
v
)
self
.
_msg_graph
.
add_edges
(
u
,
v
)
# call UDF
# call UDF
src_reprs
=
self
.
get_n_repr
(
u
)
src_reprs
=
self
.
get_n_repr
(
u
)
...
@@ -920,7 +905,7 @@ class DGLGraph(object):
...
@@ -920,7 +905,7 @@ class DGLGraph(object):
def
_batch_update_edge
(
self
,
u
,
v
,
edge_func
):
def
_batch_update_edge
(
self
,
u
,
v
,
edge_func
):
if
is_all
(
u
)
and
is_all
(
v
):
if
is_all
(
u
)
and
is_all
(
v
):
u
,
v
=
self
.
_graph
.
edges
(
sorted
=
True
)
u
,
v
=
self
.
_graph
.
edges
()
# call the UDF
# call the UDF
src_reprs
=
self
.
get_n_repr
(
u
)
src_reprs
=
self
.
get_n_repr
(
u
)
dst_reprs
=
self
.
get_n_repr
(
v
)
dst_reprs
=
self
.
get_n_repr
(
v
)
...
...
python/dgl/graph_index.py
View file @
7b5d4c58
...
@@ -21,6 +21,7 @@ class GraphIndex(object):
...
@@ -21,6 +21,7 @@ class GraphIndex(object):
self
.
from_networkx
(
graph_data
)
self
.
from_networkx
(
graph_data
)
elif
graph_data
is
not
None
:
elif
graph_data
is
not
None
:
self
.
from_networkx
(
nx
.
DiGraph
(
graph_data
))
self
.
from_networkx
(
nx
.
DiGraph
(
graph_data
))
self
.
_cache
=
{}
def
__del__
(
self
):
def
__del__
(
self
):
"""Free this graph index object."""
"""Free this graph index object."""
...
@@ -35,6 +36,7 @@ class GraphIndex(object):
...
@@ -35,6 +36,7 @@ class GraphIndex(object):
Number of nodes to be added.
Number of nodes to be added.
"""
"""
_CAPI_DGLGraphAddVertices
(
self
.
_handle
,
num
);
_CAPI_DGLGraphAddVertices
(
self
.
_handle
,
num
);
self
.
_cache
.
clear
()
def
add_edge
(
self
,
u
,
v
):
def
add_edge
(
self
,
u
,
v
):
"""Add one edge.
"""Add one edge.
...
@@ -47,6 +49,7 @@ class GraphIndex(object):
...
@@ -47,6 +49,7 @@ class GraphIndex(object):
The dst node.
The dst node.
"""
"""
_CAPI_DGLGraphAddEdge
(
self
.
_handle
,
u
,
v
);
_CAPI_DGLGraphAddEdge
(
self
.
_handle
,
u
,
v
);
self
.
_cache
.
clear
()
def
add_edges
(
self
,
u
,
v
):
def
add_edges
(
self
,
u
,
v
):
"""Add many edges.
"""Add many edges.
...
@@ -61,10 +64,12 @@ class GraphIndex(object):
...
@@ -61,10 +64,12 @@ class GraphIndex(object):
u_array
=
u
.
todgltensor
()
u_array
=
u
.
todgltensor
()
v_array
=
v
.
todgltensor
()
v_array
=
v
.
todgltensor
()
_CAPI_DGLGraphAddEdges
(
self
.
_handle
,
u_array
,
v_array
)
_CAPI_DGLGraphAddEdges
(
self
.
_handle
,
u_array
,
v_array
)
self
.
_cache
.
clear
()
def
clear
(
self
):
def
clear
(
self
):
"""Clear the graph."""
"""Clear the graph."""
_CAPI_DGLGraphClear
(
self
.
_handle
)
_CAPI_DGLGraphClear
(
self
.
_handle
)
self
.
_cache
.
clear
()
def
number_of_nodes
(
self
):
def
number_of_nodes
(
self
):
"""Return the number of nodes.
"""Return the number of nodes.
...
@@ -283,7 +288,7 @@ class GraphIndex(object):
...
@@ -283,7 +288,7 @@ class GraphIndex(object):
Parameters
Parameters
----------
----------
sorted : bool
sorted : bool
True if the returned edges are sorted by their ids.
True if the returned edges are sorted by their
src and dst
ids.
Returns
Returns
-------
-------
...
@@ -362,6 +367,25 @@ class GraphIndex(object):
...
@@ -362,6 +367,25 @@ class GraphIndex(object):
v_array
=
v
.
todgltensor
()
v_array
=
v
.
todgltensor
()
return
utils
.
toindex
(
_CAPI_DGLGraphOutDegrees
(
self
.
_handle
,
v_array
))
return
utils
.
toindex
(
_CAPI_DGLGraphOutDegrees
(
self
.
_handle
,
v_array
))
def
adjacency_matrix
(
self
):
"""Return the adjacency matrix representation of this graph.
Returns
-------
utils.CtxCachedObject
An object that returns tensor given context.
"""
if
not
'adj'
in
self
.
_cache
:
src
,
dst
,
_
=
self
.
edges
(
sorted
=
False
)
src
=
F
.
unsqueeze
(
src
.
tousertensor
(),
0
)
dst
=
F
.
unsqueeze
(
dst
.
tousertensor
(),
0
)
idx
=
F
.
pack
([
dst
,
src
])
n
=
self
.
number_of_nodes
()
dat
=
F
.
ones
((
self
.
number_of_edges
(),))
mat
=
F
.
sparse_tensor
(
idx
,
dat
,
[
n
,
n
])
self
.
_cache
[
'adj'
]
=
utils
.
CtxCachedObject
(
lambda
ctx
:
F
.
to_context
(
mat
,
ctx
))
return
self
.
_cache
[
'adj'
]
def
to_networkx
(
self
):
def
to_networkx
(
self
):
"""Convert to networkx graph.
"""Convert to networkx graph.
...
...
python/dgl/scheduler.py
View file @
7b5d4c58
...
@@ -134,7 +134,7 @@ class UpdateAllExecutor(BasicExecutor):
...
@@ -134,7 +134,7 @@ class UpdateAllExecutor(BasicExecutor):
@
property
@
property
def
graph_idx
(
self
):
def
graph_idx
(
self
):
if
self
.
_graph_idx
is
None
:
if
self
.
_graph_idx
is
None
:
self
.
_graph_idx
=
self
.
g
.
cached
_graph
.
adj
mat
()
self
.
_graph_idx
=
self
.
g
.
_graph
.
adj
acency_matrix
()
return
self
.
_graph_idx
return
self
.
_graph_idx
@
property
@
property
...
@@ -221,8 +221,8 @@ class SendRecvExecutor(BasicExecutor):
...
@@ -221,8 +221,8 @@ class SendRecvExecutor(BasicExecutor):
def
_build_adjmat
(
self
):
def
_build_adjmat
(
self
):
# handle graph index
# handle graph index
new2old
,
old2new
=
utils
.
build_relabel_map
(
self
.
v
)
new2old
,
old2new
=
utils
.
build_relabel_map
(
self
.
v
)
u
=
self
.
u
.
totensor
()
u
=
self
.
u
.
to
user
tensor
()
v
=
self
.
v
.
totensor
()
v
=
self
.
v
.
to
user
tensor
()
# TODO(minjie): should not directly use []
# TODO(minjie): should not directly use []
new_v
=
old2new
[
v
]
new_v
=
old2new
[
v
]
n
=
self
.
g
.
number_of_nodes
()
n
=
self
.
g
.
number_of_nodes
()
...
...
src/graph/graph.cc
View file @
7b5d4c58
...
@@ -13,6 +13,7 @@ inline bool IsValidIdArray(const IdArray& arr) {
...
@@ -13,6 +13,7 @@ inline bool IsValidIdArray(const IdArray& arr) {
void
Graph
::
AddVertices
(
uint64_t
num_vertices
)
{
void
Graph
::
AddVertices
(
uint64_t
num_vertices
)
{
CHECK
(
!
read_only_
)
<<
"Graph is read-only. Mutations are not allowed."
;
CHECK
(
!
read_only_
)
<<
"Graph is read-only. Mutations are not allowed."
;
adjlist_
.
resize
(
adjlist_
.
size
()
+
num_vertices
);
adjlist_
.
resize
(
adjlist_
.
size
()
+
num_vertices
);
reverse_adjlist_
.
resize
(
reverse_adjlist_
.
size
()
+
num_vertices
);
}
}
void
Graph
::
AddEdge
(
dgl_id_t
src
,
dgl_id_t
dst
)
{
void
Graph
::
AddEdge
(
dgl_id_t
src
,
dgl_id_t
dst
)
{
...
@@ -21,9 +22,11 @@ void Graph::AddEdge(dgl_id_t src, dgl_id_t dst) {
...
@@ -21,9 +22,11 @@ void Graph::AddEdge(dgl_id_t src, dgl_id_t dst) {
<<
"In valid vertices: "
<<
src
<<
" "
<<
dst
;
<<
"In valid vertices: "
<<
src
<<
" "
<<
dst
;
dgl_id_t
eid
=
num_edges_
++
;
dgl_id_t
eid
=
num_edges_
++
;
adjlist_
[
src
].
succ
.
push_back
(
dst
);
adjlist_
[
src
].
succ
.
push_back
(
dst
);
adjlist_
[
src
].
succ_edge_id
.
push_back
(
eid
);
adjlist_
[
src
].
edge_id
.
push_back
(
eid
);
adjlist_
[
dst
].
pred
.
push_back
(
src
);
reverse_adjlist_
[
dst
].
succ
.
push_back
(
src
);
adjlist_
[
dst
].
pred_edge_id
.
push_back
(
eid
);
reverse_adjlist_
[
dst
].
edge_id
.
push_back
(
eid
);
all_edges_src_
.
push_back
(
src
);
all_edges_dst_
.
push_back
(
dst
);
}
}
void
Graph
::
AddEdges
(
IdArray
src_ids
,
IdArray
dst_ids
)
{
void
Graph
::
AddEdges
(
IdArray
src_ids
,
IdArray
dst_ids
)
{
...
@@ -108,7 +111,7 @@ BoolArray Graph::HasEdges(IdArray src_ids, IdArray dst_ids) const {
...
@@ -108,7 +111,7 @@ BoolArray Graph::HasEdges(IdArray src_ids, IdArray dst_ids) const {
IdArray
Graph
::
Predecessors
(
dgl_id_t
vid
,
uint64_t
radius
)
const
{
IdArray
Graph
::
Predecessors
(
dgl_id_t
vid
,
uint64_t
radius
)
const
{
CHECK
(
HasVertex
(
vid
))
<<
"invalid vertex: "
<<
vid
;
CHECK
(
HasVertex
(
vid
))
<<
"invalid vertex: "
<<
vid
;
CHECK
(
radius
>=
1
)
<<
"invalid radius: "
<<
radius
;
CHECK
(
radius
>=
1
)
<<
"invalid radius: "
<<
radius
;
const
auto
&
pred
=
adjlist_
[
vid
].
pred
;
const
auto
&
pred
=
reverse_
adjlist_
[
vid
].
succ
;
const
int64_t
len
=
pred
.
size
();
const
int64_t
len
=
pred
.
size
();
IdArray
rst
=
IdArray
::
Empty
({
len
},
DLDataType
{
kDLInt
,
64
,
1
},
DLContext
{
kDLCPU
,
0
});
IdArray
rst
=
IdArray
::
Empty
({
len
},
DLDataType
{
kDLInt
,
64
,
1
},
DLContext
{
kDLCPU
,
0
});
int64_t
*
rst_data
=
static_cast
<
int64_t
*>
(
rst
->
data
);
int64_t
*
rst_data
=
static_cast
<
int64_t
*>
(
rst
->
data
);
...
@@ -138,7 +141,7 @@ dgl_id_t Graph::EdgeId(dgl_id_t src, dgl_id_t dst) const {
...
@@ -138,7 +141,7 @@ dgl_id_t Graph::EdgeId(dgl_id_t src, dgl_id_t dst) const {
const
auto
&
succ
=
adjlist_
[
src
].
succ
;
const
auto
&
succ
=
adjlist_
[
src
].
succ
;
for
(
size_t
i
=
0
;
i
<
succ
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
succ
.
size
();
++
i
)
{
if
(
succ
[
i
]
==
dst
)
{
if
(
succ
[
i
]
==
dst
)
{
return
adjlist_
[
src
].
succ_
edge_id
[
i
];
return
adjlist_
[
src
].
edge_id
[
i
];
}
}
}
}
LOG
(
FATAL
)
<<
"invalid edge: "
<<
src
<<
" -> "
<<
dst
;
LOG
(
FATAL
)
<<
"invalid edge: "
<<
src
<<
" -> "
<<
dst
;
...
@@ -179,7 +182,7 @@ IdArray Graph::EdgeIds(IdArray src_ids, IdArray dst_ids) const {
...
@@ -179,7 +182,7 @@ IdArray Graph::EdgeIds(IdArray src_ids, IdArray dst_ids) const {
// O(E)
// O(E)
Graph
::
EdgeArray
Graph
::
InEdges
(
dgl_id_t
vid
)
const
{
Graph
::
EdgeArray
Graph
::
InEdges
(
dgl_id_t
vid
)
const
{
CHECK
(
HasVertex
(
vid
))
<<
"invalid vertex: "
<<
vid
;
CHECK
(
HasVertex
(
vid
))
<<
"invalid vertex: "
<<
vid
;
const
int64_t
len
=
adjlist_
[
vid
].
pred
.
size
();
const
int64_t
len
=
reverse_
adjlist_
[
vid
].
succ
.
size
();
IdArray
src
=
IdArray
::
Empty
({
len
},
DLDataType
{
kDLInt
,
64
,
1
},
DLContext
{
kDLCPU
,
0
});
IdArray
src
=
IdArray
::
Empty
({
len
},
DLDataType
{
kDLInt
,
64
,
1
},
DLContext
{
kDLCPU
,
0
});
IdArray
dst
=
IdArray
::
Empty
({
len
},
DLDataType
{
kDLInt
,
64
,
1
},
DLContext
{
kDLCPU
,
0
});
IdArray
dst
=
IdArray
::
Empty
({
len
},
DLDataType
{
kDLInt
,
64
,
1
},
DLContext
{
kDLCPU
,
0
});
IdArray
eid
=
IdArray
::
Empty
({
len
},
DLDataType
{
kDLInt
,
64
,
1
},
DLContext
{
kDLCPU
,
0
});
IdArray
eid
=
IdArray
::
Empty
({
len
},
DLDataType
{
kDLInt
,
64
,
1
},
DLContext
{
kDLCPU
,
0
});
...
@@ -187,8 +190,8 @@ Graph::EdgeArray Graph::InEdges(dgl_id_t vid) const {
...
@@ -187,8 +190,8 @@ Graph::EdgeArray Graph::InEdges(dgl_id_t vid) const {
int64_t
*
dst_data
=
static_cast
<
int64_t
*>
(
dst
->
data
);
int64_t
*
dst_data
=
static_cast
<
int64_t
*>
(
dst
->
data
);
int64_t
*
eid_data
=
static_cast
<
int64_t
*>
(
eid
->
data
);
int64_t
*
eid_data
=
static_cast
<
int64_t
*>
(
eid
->
data
);
for
(
int64_t
i
=
0
;
i
<
len
;
++
i
)
{
for
(
int64_t
i
=
0
;
i
<
len
;
++
i
)
{
src_data
[
i
]
=
adjlist_
[
vid
].
pred
[
i
];
src_data
[
i
]
=
reverse_
adjlist_
[
vid
].
succ
[
i
];
eid_data
[
i
]
=
adjlist_
[
vid
].
pred_
edge_id
[
i
];
eid_data
[
i
]
=
reverse_
adjlist_
[
vid
].
edge_id
[
i
];
}
}
std
::
fill
(
dst_data
,
dst_data
+
len
,
vid
);
std
::
fill
(
dst_data
,
dst_data
+
len
,
vid
);
return
EdgeArray
{
src
,
dst
,
eid
};
return
EdgeArray
{
src
,
dst
,
eid
};
...
@@ -202,7 +205,7 @@ Graph::EdgeArray Graph::InEdges(IdArray vids) const {
...
@@ -202,7 +205,7 @@ Graph::EdgeArray Graph::InEdges(IdArray vids) const {
int64_t
rstlen
=
0
;
int64_t
rstlen
=
0
;
for
(
int64_t
i
=
0
;
i
<
len
;
++
i
)
{
for
(
int64_t
i
=
0
;
i
<
len
;
++
i
)
{
CHECK
(
HasVertex
(
vid_data
[
i
]))
<<
"Invalid vertex: "
<<
vid_data
[
i
];
CHECK
(
HasVertex
(
vid_data
[
i
]))
<<
"Invalid vertex: "
<<
vid_data
[
i
];
rstlen
+=
adjlist_
[
vid_data
[
i
]].
pred
.
size
();
rstlen
+=
reverse_
adjlist_
[
vid_data
[
i
]].
succ
.
size
();
}
}
IdArray
src
=
IdArray
::
Empty
({
rstlen
},
vids
->
dtype
,
vids
->
ctx
);
IdArray
src
=
IdArray
::
Empty
({
rstlen
},
vids
->
dtype
,
vids
->
ctx
);
IdArray
dst
=
IdArray
::
Empty
({
rstlen
},
vids
->
dtype
,
vids
->
ctx
);
IdArray
dst
=
IdArray
::
Empty
({
rstlen
},
vids
->
dtype
,
vids
->
ctx
);
...
@@ -211,8 +214,8 @@ Graph::EdgeArray Graph::InEdges(IdArray vids) const {
...
@@ -211,8 +214,8 @@ Graph::EdgeArray Graph::InEdges(IdArray vids) const {
int64_t
*
dst_ptr
=
static_cast
<
int64_t
*>
(
dst
->
data
);
int64_t
*
dst_ptr
=
static_cast
<
int64_t
*>
(
dst
->
data
);
int64_t
*
eid_ptr
=
static_cast
<
int64_t
*>
(
eid
->
data
);
int64_t
*
eid_ptr
=
static_cast
<
int64_t
*>
(
eid
->
data
);
for
(
int64_t
i
=
0
;
i
<
len
;
++
i
)
{
for
(
int64_t
i
=
0
;
i
<
len
;
++
i
)
{
const
auto
&
pred
=
adjlist_
[
vid_data
[
i
]].
pred
;
const
auto
&
pred
=
reverse_
adjlist_
[
vid_data
[
i
]].
succ
;
const
auto
&
eids
=
adjlist_
[
vid_data
[
i
]].
pred_
edge_id
;
const
auto
&
eids
=
reverse_
adjlist_
[
vid_data
[
i
]].
edge_id
;
for
(
size_t
j
=
0
;
j
<
pred
.
size
();
++
j
)
{
for
(
size_t
j
=
0
;
j
<
pred
.
size
();
++
j
)
{
*
(
src_ptr
++
)
=
pred
[
j
];
*
(
src_ptr
++
)
=
pred
[
j
];
*
(
dst_ptr
++
)
=
vid_data
[
i
];
*
(
dst_ptr
++
)
=
vid_data
[
i
];
...
@@ -234,7 +237,7 @@ Graph::EdgeArray Graph::OutEdges(dgl_id_t vid) const {
...
@@ -234,7 +237,7 @@ Graph::EdgeArray Graph::OutEdges(dgl_id_t vid) const {
int64_t
*
eid_data
=
static_cast
<
int64_t
*>
(
eid
->
data
);
int64_t
*
eid_data
=
static_cast
<
int64_t
*>
(
eid
->
data
);
for
(
int64_t
i
=
0
;
i
<
len
;
++
i
)
{
for
(
int64_t
i
=
0
;
i
<
len
;
++
i
)
{
dst_data
[
i
]
=
adjlist_
[
vid
].
succ
[
i
];
dst_data
[
i
]
=
adjlist_
[
vid
].
succ
[
i
];
eid_data
[
i
]
=
adjlist_
[
vid
].
succ_
edge_id
[
i
];
eid_data
[
i
]
=
adjlist_
[
vid
].
edge_id
[
i
];
}
}
std
::
fill
(
src_data
,
src_data
+
len
,
vid
);
std
::
fill
(
src_data
,
src_data
+
len
,
vid
);
return
EdgeArray
{
src
,
dst
,
eid
};
return
EdgeArray
{
src
,
dst
,
eid
};
...
@@ -258,7 +261,7 @@ Graph::EdgeArray Graph::OutEdges(IdArray vids) const {
...
@@ -258,7 +261,7 @@ Graph::EdgeArray Graph::OutEdges(IdArray vids) const {
int64_t
*
eid_ptr
=
static_cast
<
int64_t
*>
(
eid
->
data
);
int64_t
*
eid_ptr
=
static_cast
<
int64_t
*>
(
eid
->
data
);
for
(
int64_t
i
=
0
;
i
<
len
;
++
i
)
{
for
(
int64_t
i
=
0
;
i
<
len
;
++
i
)
{
const
auto
&
succ
=
adjlist_
[
vid_data
[
i
]].
succ
;
const
auto
&
succ
=
adjlist_
[
vid_data
[
i
]].
succ
;
const
auto
&
eids
=
adjlist_
[
vid_data
[
i
]].
succ_
edge_id
;
const
auto
&
eids
=
adjlist_
[
vid_data
[
i
]].
edge_id
;
for
(
size_t
j
=
0
;
j
<
succ
.
size
();
++
j
)
{
for
(
size_t
j
=
0
;
j
<
succ
.
size
();
++
j
)
{
*
(
src_ptr
++
)
=
vid_data
[
i
];
*
(
src_ptr
++
)
=
vid_data
[
i
];
*
(
dst_ptr
++
)
=
succ
[
j
];
*
(
dst_ptr
++
)
=
succ
[
j
];
...
@@ -279,22 +282,21 @@ Graph::EdgeArray Graph::Edges(bool sorted) const {
...
@@ -279,22 +282,21 @@ Graph::EdgeArray Graph::Edges(bool sorted) const {
typedef
std
::
tuple
<
int64_t
,
int64_t
,
int64_t
>
Tuple
;
typedef
std
::
tuple
<
int64_t
,
int64_t
,
int64_t
>
Tuple
;
std
::
vector
<
Tuple
>
tuples
;
std
::
vector
<
Tuple
>
tuples
;
tuples
.
reserve
(
len
);
tuples
.
reserve
(
len
);
for
(
dgl_id_t
u
=
0
;
u
<
NumVertices
();
++
u
)
{
for
(
uint64_t
eid
=
0
;
eid
<
num_edges_
;
++
eid
)
{
for
(
size_t
i
=
0
;
i
<
adjlist_
[
u
].
succ
.
size
();
++
i
)
{
tuples
.
emplace_back
(
all_edges_src_
[
eid
],
all_edges_dst_
[
eid
],
eid
);
tuples
.
emplace_back
(
u
,
adjlist_
[
u
].
succ
[
i
],
adjlist_
[
u
].
succ_edge_id
[
i
]);
}
}
}
// sort according to
edge
ids
// sort according to
src and dst
ids
std
::
sort
(
tuples
.
begin
(),
tuples
.
end
(),
std
::
sort
(
tuples
.
begin
(),
tuples
.
end
(),
[]
(
const
Tuple
&
t1
,
const
Tuple
&
t2
)
{
[]
(
const
Tuple
&
t1
,
const
Tuple
&
t2
)
{
return
std
::
get
<
2
>
(
t1
)
<
std
::
get
<
2
>
(
t2
);
return
std
::
get
<
0
>
(
t1
)
<
std
::
get
<
0
>
(
t2
)
||
(
std
::
get
<
0
>
(
t1
)
==
std
::
get
<
0
>
(
t2
)
&&
std
::
get
<
1
>
(
t1
)
<
std
::
get
<
1
>
(
t2
));
});
});
// make return arrays
// make return arrays
int64_t
*
src_ptr
=
static_cast
<
int64_t
*>
(
src
->
data
);
int64_t
*
src_ptr
=
static_cast
<
int64_t
*>
(
src
->
data
);
int64_t
*
dst_ptr
=
static_cast
<
int64_t
*>
(
dst
->
data
);
int64_t
*
dst_ptr
=
static_cast
<
int64_t
*>
(
dst
->
data
);
int64_t
*
eid_ptr
=
static_cast
<
int64_t
*>
(
eid
->
data
);
int64_t
*
eid_ptr
=
static_cast
<
int64_t
*>
(
eid
->
data
);
for
(
int64
_t
i
=
0
;
i
<
len
;
++
i
)
{
for
(
size
_t
i
=
0
;
i
<
tuples
.
size
()
;
++
i
)
{
src_ptr
[
i
]
=
std
::
get
<
0
>
(
tuples
[
i
]);
src_ptr
[
i
]
=
std
::
get
<
0
>
(
tuples
[
i
]);
dst_ptr
[
i
]
=
std
::
get
<
1
>
(
tuples
[
i
]);
dst_ptr
[
i
]
=
std
::
get
<
1
>
(
tuples
[
i
]);
eid_ptr
[
i
]
=
std
::
get
<
2
>
(
tuples
[
i
]);
eid_ptr
[
i
]
=
std
::
get
<
2
>
(
tuples
[
i
]);
...
@@ -303,12 +305,10 @@ Graph::EdgeArray Graph::Edges(bool sorted) const {
...
@@ -303,12 +305,10 @@ Graph::EdgeArray Graph::Edges(bool sorted) const {
int64_t
*
src_ptr
=
static_cast
<
int64_t
*>
(
src
->
data
);
int64_t
*
src_ptr
=
static_cast
<
int64_t
*>
(
src
->
data
);
int64_t
*
dst_ptr
=
static_cast
<
int64_t
*>
(
dst
->
data
);
int64_t
*
dst_ptr
=
static_cast
<
int64_t
*>
(
dst
->
data
);
int64_t
*
eid_ptr
=
static_cast
<
int64_t
*>
(
eid
->
data
);
int64_t
*
eid_ptr
=
static_cast
<
int64_t
*>
(
eid
->
data
);
for
(
dgl_id_t
u
=
0
;
u
<
NumVertices
();
++
u
)
{
std
::
copy
(
all_edges_src_
.
begin
(),
all_edges_src_
.
end
(),
src_ptr
);
for
(
size_t
i
=
0
;
i
<
adjlist_
[
u
].
succ
.
size
();
++
i
)
{
std
::
copy
(
all_edges_dst_
.
begin
(),
all_edges_dst_
.
end
(),
dst_ptr
);
*
(
src_ptr
++
)
=
u
;
for
(
uint64_t
eid
=
0
;
eid
<
num_edges_
;
++
eid
)
{
*
(
dst_ptr
++
)
=
adjlist_
[
u
].
succ
[
i
];
eid_ptr
[
eid
]
=
eid
;
*
(
eid_ptr
++
)
=
adjlist_
[
u
].
succ_edge_id
[
i
];
}
}
}
}
}
...
@@ -325,7 +325,7 @@ DegreeArray Graph::InDegrees(IdArray vids) const {
...
@@ -325,7 +325,7 @@ DegreeArray Graph::InDegrees(IdArray vids) const {
for
(
int64_t
i
=
0
;
i
<
len
;
++
i
)
{
for
(
int64_t
i
=
0
;
i
<
len
;
++
i
)
{
const
auto
vid
=
vid_data
[
i
];
const
auto
vid
=
vid_data
[
i
];
CHECK
(
HasVertex
(
vid
))
<<
"Invalid vertex: "
<<
vid
;
CHECK
(
HasVertex
(
vid
))
<<
"Invalid vertex: "
<<
vid
;
rst_data
[
i
]
=
adjlist_
[
vid
].
pred
.
size
();
rst_data
[
i
]
=
reverse_
adjlist_
[
vid
].
succ
.
size
();
}
}
return
rst
;
return
rst
;
}
}
...
...
tests/pytorch/test_cached_graph.py
deleted
100644 → 0
View file @
b2e4bdc0
import
torch
as
th
import
numpy
as
np
import
networkx
as
nx
from
dgl
import
DGLGraph
from
dgl.cached_graph
import
*
from
dgl.utils
import
Index
def
check_eq
(
a
,
b
):
assert
a
.
shape
==
b
.
shape
assert
th
.
sum
(
a
==
b
)
==
int
(
np
.
prod
(
list
(
a
.
shape
)))
def
test_basics
():
g
=
DGLGraph
()
g
.
add_edge
(
0
,
1
)
g
.
add_edge
(
1
,
2
)
g
.
add_edge
(
1
,
3
)
g
.
add_edge
(
2
,
4
)
g
.
add_edge
(
2
,
5
)
g
.
add_edge
(
0
,
2
)
cg
=
create_cached_graph
(
g
)
u
=
Index
(
th
.
tensor
([
0
,
0
,
1
,
1
,
2
,
2
]))
v
=
Index
(
th
.
tensor
([
1
,
2
,
2
,
3
,
4
,
5
]))
check_eq
(
cg
.
get_edge_id
(
u
,
v
).
totensor
(),
th
.
tensor
([
0
,
5
,
1
,
2
,
3
,
4
]))
query
=
Index
(
th
.
tensor
([
0
,
1
,
2
,
5
]))
s
,
d
,
orphan
=
cg
.
in_edges
(
query
)
check_eq
(
s
.
totensor
(),
th
.
tensor
([
0
,
0
,
1
,
2
]))
check_eq
(
d
.
totensor
(),
th
.
tensor
([
1
,
2
,
2
,
5
]))
assert
orphan
.
tolist
()
==
[
0
]
s
,
d
,
orphan
=
cg
.
out_edges
(
query
)
check_eq
(
s
.
totensor
(),
th
.
tensor
([
0
,
0
,
1
,
1
,
2
,
2
]))
check_eq
(
d
.
totensor
(),
th
.
tensor
([
1
,
2
,
2
,
3
,
4
,
5
]))
assert
orphan
.
tolist
()
==
[
5
]
if
__name__
==
'__main__'
:
test_basics
()
tests/pytorch/test_specialization.py
View file @
7b5d4c58
...
@@ -7,8 +7,7 @@ D = 5
...
@@ -7,8 +7,7 @@ D = 5
def
generate_graph
():
def
generate_graph
():
g
=
dgl
.
DGLGraph
()
g
=
dgl
.
DGLGraph
()
for
i
in
range
(
10
):
g
.
add_nodes
(
10
)
g
.
add_node
(
i
)
# 10 nodes.
# create a graph where 0 is the source and 9 is the sink
# create a graph where 0 is the source and 9 is the sink
for
i
in
range
(
1
,
9
):
for
i
in
range
(
1
,
9
):
g
.
add_edge
(
0
,
i
)
g
.
add_edge
(
0
,
i
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment