Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
e3a9a6bb
Unverified
Commit
e3a9a6bb
authored
Mar 30, 2020
by
Quan (Andy) Gan
Committed by
GitHub
Mar 30, 2020
Browse files
add an optional include_dst_in_src argument (#1401)
parent
af61e2fb
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
40 additions
and
18 deletions
+40
-18
include/dgl/transform.h
include/dgl/transform.h
+6
-3
python/dgl/transform.py
python/dgl/transform.py
+6
-3
src/graph/transform/to_bipartite.cc
src/graph/transform/to_bipartite.cc
+13
-5
tests/compute/test_transform.py
tests/compute/test_transform.py
+15
-7
No files found.
include/dgl/transform.h
View file @
e3a9a6bb
...
@@ -66,17 +66,20 @@ CompactGraphs(
...
@@ -66,17 +66,20 @@ CompactGraphs(
*
*
* \param graph The graph.
* \param graph The graph.
* \param rhs_nodes Designated nodes that would appear on the right side.
* \param rhs_nodes Designated nodes that would appear on the right side.
* \param include_rhs_in_lhs If false, do not include the nodes of node type \c ntype_r
* in \c ntype_l.
*
*
* \return A triplet containing
* \return A triplet containing
* * The bipartite-structured graph,
* * The bipartite-structured graph,
* * The induced node from the left side for each graph,
* * The induced node from the left side for each graph,
* * The induced edges.
* * The induced edges.
*
*
* \note For each node type \c ntype, the nodes in rhs_nodes[ntype] would always
* \note If include_rhs_in_lhs is true, then for each node type \c ntype, the nodes
* appear first in the nodes of type \c ntype_l in the new graph.
* in rhs_nodes[ntype] would always appear first in the nodes of type \c ntype_l
* in the new graph.
*/
*/
std
::
tuple
<
HeteroGraphPtr
,
std
::
vector
<
IdArray
>
,
std
::
vector
<
IdArray
>>
std
::
tuple
<
HeteroGraphPtr
,
std
::
vector
<
IdArray
>
,
std
::
vector
<
IdArray
>>
ToBlock
(
HeteroGraphPtr
graph
,
const
std
::
vector
<
IdArray
>
&
rhs_nodes
);
ToBlock
(
HeteroGraphPtr
graph
,
const
std
::
vector
<
IdArray
>
&
rhs_nodes
,
bool
include_rhs_in_lhs
);
/*!
/*!
* \brief Convert a multigraph to a simple graph.
* \brief Convert a multigraph to a simple graph.
...
...
python/dgl/transform.py
View file @
e3a9a6bb
...
@@ -749,7 +749,7 @@ def compact_graphs(graphs, always_preserve=None):
...
@@ -749,7 +749,7 @@ def compact_graphs(graphs, always_preserve=None):
return
new_graphs
return
new_graphs
def
to_block
(
g
,
dst_nodes
=
None
):
def
to_block
(
g
,
dst_nodes
=
None
,
include_dst_in_src
=
True
):
"""Convert a graph into a bipartite-structured "block" for message passing.
"""Convert a graph into a bipartite-structured "block" for message passing.
A block graph is uni-directional bipartite graph consisting of two sets of nodes
A block graph is uni-directional bipartite graph consisting of two sets of nodes
...
@@ -767,7 +767,7 @@ def to_block(g, dst_nodes=None):
...
@@ -767,7 +767,7 @@ def to_block(g, dst_nodes=None):
Moreover, the function also relabels node ids in each type to make the graph more compact.
Moreover, the function also relabels node ids in each type to make the graph more compact.
Specifically, the nodes of type ``vtype`` would contain the nodes that have at least one
Specifically, the nodes of type ``vtype`` would contain the nodes that have at least one
inbound edge of any type, while ``utype`` would contain all the DST nodes of type ``
u
type``,
inbound edge of any type, while ``utype`` would contain all the DST nodes of type ``
v
type``,
as well as the nodes that have at least one outbound edge to any DST node.
as well as the nodes that have at least one outbound edge to any DST node.
Since DST nodes are included in SRC nodes, a common requirement is to fetch
Since DST nodes are included in SRC nodes, a common requirement is to fetch
...
@@ -789,6 +789,8 @@ def to_block(g, dst_nodes=None):
...
@@ -789,6 +789,8 @@ def to_block(g, dst_nodes=None):
The graph.
The graph.
dst_nodes : Tensor or dict[str, Tensor], optional
dst_nodes : Tensor or dict[str, Tensor], optional
Optional DST nodes. If a tensor is given, the graph must have only one node type.
Optional DST nodes. If a tensor is given, the graph must have only one node type.
include_dst_in_src : bool, default True
If False, do not include DST nodes in SRC nodes.
Returns
Returns
-------
-------
...
@@ -882,7 +884,8 @@ def to_block(g, dst_nodes=None):
...
@@ -882,7 +884,8 @@ def to_block(g, dst_nodes=None):
else
:
else
:
dst_nodes_nd
.
append
(
nd
.
null
())
dst_nodes_nd
.
append
(
nd
.
null
())
new_graph_index
,
src_nodes_nd
,
induced_edges_nd
=
_CAPI_DGLToBlock
(
g
.
_graph
,
dst_nodes_nd
)
new_graph_index
,
src_nodes_nd
,
induced_edges_nd
=
_CAPI_DGLToBlock
(
g
.
_graph
,
dst_nodes_nd
,
include_dst_in_src
)
src_nodes
=
[
F
.
zerocopy_from_dgl_ndarray
(
nodes_nd
.
data
)
for
nodes_nd
in
src_nodes_nd
]
src_nodes
=
[
F
.
zerocopy_from_dgl_ndarray
(
nodes_nd
.
data
)
for
nodes_nd
in
src_nodes_nd
]
dst_nodes
=
[
F
.
zerocopy_from_dgl_ndarray
(
nodes_nd
)
for
nodes_nd
in
dst_nodes_nd
]
dst_nodes
=
[
F
.
zerocopy_from_dgl_ndarray
(
nodes_nd
)
for
nodes_nd
in
dst_nodes_nd
]
...
...
src/graph/transform/to_bipartite.cc
View file @
e3a9a6bb
...
@@ -28,7 +28,7 @@ namespace {
...
@@ -28,7 +28,7 @@ namespace {
template
<
typename
IdType
>
template
<
typename
IdType
>
std
::
tuple
<
HeteroGraphPtr
,
std
::
vector
<
IdArray
>
,
std
::
vector
<
IdArray
>>
std
::
tuple
<
HeteroGraphPtr
,
std
::
vector
<
IdArray
>
,
std
::
vector
<
IdArray
>>
ToBlock
(
HeteroGraphPtr
graph
,
const
std
::
vector
<
IdArray
>
&
rhs_nodes
)
{
ToBlock
(
HeteroGraphPtr
graph
,
const
std
::
vector
<
IdArray
>
&
rhs_nodes
,
bool
include_rhs_in_lhs
)
{
const
int64_t
num_etypes
=
graph
->
NumEdgeTypes
();
const
int64_t
num_etypes
=
graph
->
NumEdgeTypes
();
const
int64_t
num_ntypes
=
graph
->
NumVertexTypes
();
const
int64_t
num_ntypes
=
graph
->
NumVertexTypes
();
std
::
vector
<
EdgeArray
>
edge_arrays
(
num_etypes
);
std
::
vector
<
EdgeArray
>
edge_arrays
(
num_etypes
);
...
@@ -37,7 +37,13 @@ ToBlock(HeteroGraphPtr graph, const std::vector<IdArray> &rhs_nodes) {
...
@@ -37,7 +37,13 @@ ToBlock(HeteroGraphPtr graph, const std::vector<IdArray> &rhs_nodes) {
<<
"rhs_nodes not given for every node type"
;
<<
"rhs_nodes not given for every node type"
;
const
std
::
vector
<
IdHashMap
<
IdType
>>
rhs_node_mappings
(
rhs_nodes
.
begin
(),
rhs_nodes
.
end
());
const
std
::
vector
<
IdHashMap
<
IdType
>>
rhs_node_mappings
(
rhs_nodes
.
begin
(),
rhs_nodes
.
end
());
std
::
vector
<
IdHashMap
<
IdType
>>
lhs_node_mappings
(
rhs_node_mappings
);
// copy
std
::
vector
<
IdHashMap
<
IdType
>>
lhs_node_mappings
;
if
(
include_rhs_in_lhs
)
lhs_node_mappings
=
rhs_node_mappings
;
// copy
else
lhs_node_mappings
.
resize
(
num_ntypes
);
std
::
vector
<
int64_t
>
num_nodes_per_type
;
std
::
vector
<
int64_t
>
num_nodes_per_type
;
num_nodes_per_type
.
reserve
(
2
*
num_ntypes
);
num_nodes_per_type
.
reserve
(
2
*
num_ntypes
);
...
@@ -87,10 +93,10 @@ ToBlock(HeteroGraphPtr graph, const std::vector<IdArray> &rhs_nodes) {
...
@@ -87,10 +93,10 @@ ToBlock(HeteroGraphPtr graph, const std::vector<IdArray> &rhs_nodes) {
};
// namespace
};
// namespace
std
::
tuple
<
HeteroGraphPtr
,
std
::
vector
<
IdArray
>
,
std
::
vector
<
IdArray
>>
std
::
tuple
<
HeteroGraphPtr
,
std
::
vector
<
IdArray
>
,
std
::
vector
<
IdArray
>>
ToBlock
(
HeteroGraphPtr
graph
,
const
std
::
vector
<
IdArray
>
&
rhs_nodes
)
{
ToBlock
(
HeteroGraphPtr
graph
,
const
std
::
vector
<
IdArray
>
&
rhs_nodes
,
bool
include_rhs_in_lhs
)
{
std
::
tuple
<
HeteroGraphPtr
,
std
::
vector
<
IdArray
>
,
std
::
vector
<
IdArray
>>
ret
;
std
::
tuple
<
HeteroGraphPtr
,
std
::
vector
<
IdArray
>
,
std
::
vector
<
IdArray
>>
ret
;
ATEN_ID_TYPE_SWITCH
(
graph
->
DataType
(),
IdType
,
{
ATEN_ID_TYPE_SWITCH
(
graph
->
DataType
(),
IdType
,
{
ret
=
ToBlock
<
IdType
>
(
graph
,
rhs_nodes
);
ret
=
ToBlock
<
IdType
>
(
graph
,
rhs_nodes
,
include_rhs_in_lhs
);
});
});
return
ret
;
return
ret
;
}
}
...
@@ -99,11 +105,13 @@ DGL_REGISTER_GLOBAL("transform._CAPI_DGLToBlock")
...
@@ -99,11 +105,13 @@ DGL_REGISTER_GLOBAL("transform._CAPI_DGLToBlock")
.
set_body
([]
(
DGLArgs
args
,
DGLRetValue
*
rv
)
{
.
set_body
([]
(
DGLArgs
args
,
DGLRetValue
*
rv
)
{
const
HeteroGraphRef
graph_ref
=
args
[
0
];
const
HeteroGraphRef
graph_ref
=
args
[
0
];
const
std
::
vector
<
IdArray
>
&
rhs_nodes
=
ListValueToVector
<
IdArray
>
(
args
[
1
]);
const
std
::
vector
<
IdArray
>
&
rhs_nodes
=
ListValueToVector
<
IdArray
>
(
args
[
1
]);
const
bool
include_rhs_in_lhs
=
args
[
2
];
HeteroGraphPtr
new_graph
;
HeteroGraphPtr
new_graph
;
std
::
vector
<
IdArray
>
lhs_nodes
;
std
::
vector
<
IdArray
>
lhs_nodes
;
std
::
vector
<
IdArray
>
induced_edges
;
std
::
vector
<
IdArray
>
induced_edges
;
std
::
tie
(
new_graph
,
lhs_nodes
,
induced_edges
)
=
ToBlock
(
graph_ref
.
sptr
(),
rhs_nodes
);
std
::
tie
(
new_graph
,
lhs_nodes
,
induced_edges
)
=
ToBlock
(
graph_ref
.
sptr
(),
rhs_nodes
,
include_rhs_in_lhs
);
List
<
Value
>
lhs_nodes_ref
;
List
<
Value
>
lhs_nodes_ref
;
for
(
IdArray
&
array
:
lhs_nodes
)
for
(
IdArray
&
array
:
lhs_nodes
)
...
...
tests/compute/test_transform.py
View file @
e3a9a6bb
...
@@ -428,13 +428,14 @@ def test_to_simple():
...
@@ -428,13 +428,14 @@ def test_to_simple():
@
unittest
.
skipIf
(
F
.
_default_context_str
==
'gpu'
,
reason
=
"GPU compaction not implemented"
)
@
unittest
.
skipIf
(
F
.
_default_context_str
==
'gpu'
,
reason
=
"GPU compaction not implemented"
)
def
test_to_block
():
def
test_to_block
():
def
check
(
g
,
bg
,
ntype
,
etype
,
dst_nodes
):
def
check
(
g
,
bg
,
ntype
,
etype
,
dst_nodes
,
include_dst_in_src
=
True
):
if
dst_nodes
is
not
None
:
if
dst_nodes
is
not
None
:
assert
F
.
array_equal
(
bg
.
dstnodes
[
ntype
].
data
[
dgl
.
NID
],
dst_nodes
)
assert
F
.
array_equal
(
bg
.
dstnodes
[
ntype
].
data
[
dgl
.
NID
],
dst_nodes
)
n_dst_nodes
=
bg
.
number_of_nodes
(
'DST/'
+
ntype
)
n_dst_nodes
=
bg
.
number_of_nodes
(
'DST/'
+
ntype
)
assert
F
.
array_equal
(
if
include_dst_in_src
:
bg
.
srcnodes
[
ntype
].
data
[
dgl
.
NID
][:
n_dst_nodes
],
assert
F
.
array_equal
(
bg
.
dstnodes
[
ntype
].
data
[
dgl
.
NID
])
bg
.
srcnodes
[
ntype
].
data
[
dgl
.
NID
][:
n_dst_nodes
],
bg
.
dstnodes
[
ntype
].
data
[
dgl
.
NID
])
g
=
g
[
etype
]
g
=
g
[
etype
]
bg
=
bg
[
etype
]
bg
=
bg
[
etype
]
...
@@ -452,13 +453,13 @@ def test_to_block():
...
@@ -452,13 +453,13 @@ def test_to_block():
assert
F
.
array_equal
(
induced_src_bg
,
induced_src_ans
)
assert
F
.
array_equal
(
induced_src_bg
,
induced_src_ans
)
assert
F
.
array_equal
(
induced_dst_bg
,
induced_dst_ans
)
assert
F
.
array_equal
(
induced_dst_bg
,
induced_dst_ans
)
def
checkall
(
g
,
bg
,
dst_nodes
):
def
checkall
(
g
,
bg
,
dst_nodes
,
include_dst_in_src
=
True
):
for
etype
in
g
.
etypes
:
for
etype
in
g
.
etypes
:
ntype
=
g
.
to_canonical_etype
(
etype
)[
2
]
ntype
=
g
.
to_canonical_etype
(
etype
)[
2
]
if
dst_nodes
is
not
None
and
ntype
in
dst_nodes
:
if
dst_nodes
is
not
None
and
ntype
in
dst_nodes
:
check
(
g
,
bg
,
ntype
,
etype
,
dst_nodes
[
ntype
])
check
(
g
,
bg
,
ntype
,
etype
,
dst_nodes
[
ntype
]
,
include_dst_in_src
)
else
:
else
:
check
(
g
,
bg
,
ntype
,
etype
,
None
)
check
(
g
,
bg
,
ntype
,
etype
,
None
,
include_dst_in_src
)
g
=
dgl
.
heterograph
({
g
=
dgl
.
heterograph
({
(
'A'
,
'AA'
,
'A'
):
[(
0
,
1
),
(
2
,
3
),
(
1
,
2
),
(
3
,
4
)],
(
'A'
,
'AA'
,
'A'
):
[(
0
,
1
),
(
2
,
3
),
(
1
,
2
),
(
3
,
4
)],
...
@@ -468,6 +469,13 @@ def test_to_block():
...
@@ -468,6 +469,13 @@ def test_to_block():
bg
=
dgl
.
to_block
(
g_a
)
bg
=
dgl
.
to_block
(
g_a
)
check
(
g_a
,
bg
,
'A'
,
'AA'
,
None
)
check
(
g_a
,
bg
,
'A'
,
'AA'
,
None
)
assert
bg
.
number_of_src_nodes
()
==
5
assert
bg
.
number_of_dst_nodes
()
==
4
bg
=
dgl
.
to_block
(
g_a
,
include_dst_in_src
=
False
)
check
(
g_a
,
bg
,
'A'
,
'AA'
,
None
,
False
)
assert
bg
.
number_of_src_nodes
()
==
4
assert
bg
.
number_of_dst_nodes
()
==
4
dst_nodes
=
F
.
tensor
([
3
,
4
],
dtype
=
F
.
int64
)
dst_nodes
=
F
.
tensor
([
3
,
4
],
dtype
=
F
.
int64
)
bg
=
dgl
.
to_block
(
g_a
,
dst_nodes
)
bg
=
dgl
.
to_block
(
g_a
,
dst_nodes
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment