Unverified Commit e3a9a6bb authored by Quan (Andy) Gan's avatar Quan (Andy) Gan Committed by GitHub
Browse files

add an optional include_dst_in_src argument (#1401)

parent af61e2fb
...@@ -66,17 +66,20 @@ CompactGraphs( ...@@ -66,17 +66,20 @@ CompactGraphs(
* *
* \param graph The graph. * \param graph The graph.
* \param rhs_nodes Designated nodes that would appear on the right side. * \param rhs_nodes Designated nodes that would appear on the right side.
* \param include_rhs_in_lhs If false, do not include the nodes of node type \c ntype_r
* in \c ntype_l.
* *
* \return A triplet containing * \return A triplet containing
* * The bipartite-structured graph, * * The bipartite-structured graph,
* * The induced node from the left side for each graph, * * The induced node from the left side for each graph,
* * The induced edges. * * The induced edges.
* *
* \note For each node type \c ntype, the nodes in rhs_nodes[ntype] would always * \note If include_rhs_in_lhs is true, then for each node type \c ntype, the nodes
* appear first in the nodes of type \c ntype_l in the new graph. * in rhs_nodes[ntype] would always appear first in the nodes of type \c ntype_l
* in the new graph.
*/ */
std::tuple<HeteroGraphPtr, std::vector<IdArray>, std::vector<IdArray>> std::tuple<HeteroGraphPtr, std::vector<IdArray>, std::vector<IdArray>>
ToBlock(HeteroGraphPtr graph, const std::vector<IdArray> &rhs_nodes); ToBlock(HeteroGraphPtr graph, const std::vector<IdArray> &rhs_nodes, bool include_rhs_in_lhs);
/*! /*!
* \brief Convert a multigraph to a simple graph. * \brief Convert a multigraph to a simple graph.
......
...@@ -749,7 +749,7 @@ def compact_graphs(graphs, always_preserve=None): ...@@ -749,7 +749,7 @@ def compact_graphs(graphs, always_preserve=None):
return new_graphs return new_graphs
def to_block(g, dst_nodes=None): def to_block(g, dst_nodes=None, include_dst_in_src=True):
"""Convert a graph into a bipartite-structured "block" for message passing. """Convert a graph into a bipartite-structured "block" for message passing.
A block graph is uni-directional bipartite graph consisting of two sets of nodes A block graph is uni-directional bipartite graph consisting of two sets of nodes
...@@ -767,7 +767,7 @@ def to_block(g, dst_nodes=None): ...@@ -767,7 +767,7 @@ def to_block(g, dst_nodes=None):
Moreover, the function also relabels node ids in each type to make the graph more compact. Moreover, the function also relabels node ids in each type to make the graph more compact.
Specifically, the nodes of type ``vtype`` would contain the nodes that have at least one Specifically, the nodes of type ``vtype`` would contain the nodes that have at least one
inbound edge of any type, while ``utype`` would contain all the DST nodes of type ``utype``, inbound edge of any type, while ``utype`` would contain all the DST nodes of type ``vtype``,
as well as the nodes that have at least one outbound edge to any DST node. as well as the nodes that have at least one outbound edge to any DST node.
Since DST nodes are included in SRC nodes, a common requirement is to fetch Since DST nodes are included in SRC nodes, a common requirement is to fetch
...@@ -789,6 +789,8 @@ def to_block(g, dst_nodes=None): ...@@ -789,6 +789,8 @@ def to_block(g, dst_nodes=None):
The graph. The graph.
dst_nodes : Tensor or dict[str, Tensor], optional dst_nodes : Tensor or dict[str, Tensor], optional
Optional DST nodes. If a tensor is given, the graph must have only one node type. Optional DST nodes. If a tensor is given, the graph must have only one node type.
include_dst_in_src : bool, default True
If False, do not include DST nodes in SRC nodes.
Returns Returns
------- -------
...@@ -882,7 +884,8 @@ def to_block(g, dst_nodes=None): ...@@ -882,7 +884,8 @@ def to_block(g, dst_nodes=None):
else: else:
dst_nodes_nd.append(nd.null()) dst_nodes_nd.append(nd.null())
new_graph_index, src_nodes_nd, induced_edges_nd = _CAPI_DGLToBlock(g._graph, dst_nodes_nd) new_graph_index, src_nodes_nd, induced_edges_nd = _CAPI_DGLToBlock(
g._graph, dst_nodes_nd, include_dst_in_src)
src_nodes = [F.zerocopy_from_dgl_ndarray(nodes_nd.data) for nodes_nd in src_nodes_nd] src_nodes = [F.zerocopy_from_dgl_ndarray(nodes_nd.data) for nodes_nd in src_nodes_nd]
dst_nodes = [F.zerocopy_from_dgl_ndarray(nodes_nd) for nodes_nd in dst_nodes_nd] dst_nodes = [F.zerocopy_from_dgl_ndarray(nodes_nd) for nodes_nd in dst_nodes_nd]
......
...@@ -28,7 +28,7 @@ namespace { ...@@ -28,7 +28,7 @@ namespace {
template<typename IdType> template<typename IdType>
std::tuple<HeteroGraphPtr, std::vector<IdArray>, std::vector<IdArray>> std::tuple<HeteroGraphPtr, std::vector<IdArray>, std::vector<IdArray>>
ToBlock(HeteroGraphPtr graph, const std::vector<IdArray> &rhs_nodes) { ToBlock(HeteroGraphPtr graph, const std::vector<IdArray> &rhs_nodes, bool include_rhs_in_lhs) {
const int64_t num_etypes = graph->NumEdgeTypes(); const int64_t num_etypes = graph->NumEdgeTypes();
const int64_t num_ntypes = graph->NumVertexTypes(); const int64_t num_ntypes = graph->NumVertexTypes();
std::vector<EdgeArray> edge_arrays(num_etypes); std::vector<EdgeArray> edge_arrays(num_etypes);
...@@ -37,7 +37,13 @@ ToBlock(HeteroGraphPtr graph, const std::vector<IdArray> &rhs_nodes) { ...@@ -37,7 +37,13 @@ ToBlock(HeteroGraphPtr graph, const std::vector<IdArray> &rhs_nodes) {
<< "rhs_nodes not given for every node type"; << "rhs_nodes not given for every node type";
const std::vector<IdHashMap<IdType>> rhs_node_mappings(rhs_nodes.begin(), rhs_nodes.end()); const std::vector<IdHashMap<IdType>> rhs_node_mappings(rhs_nodes.begin(), rhs_nodes.end());
std::vector<IdHashMap<IdType>> lhs_node_mappings(rhs_node_mappings); // copy std::vector<IdHashMap<IdType>> lhs_node_mappings;
if (include_rhs_in_lhs)
lhs_node_mappings = rhs_node_mappings; // copy
else
lhs_node_mappings.resize(num_ntypes);
std::vector<int64_t> num_nodes_per_type; std::vector<int64_t> num_nodes_per_type;
num_nodes_per_type.reserve(2 * num_ntypes); num_nodes_per_type.reserve(2 * num_ntypes);
...@@ -87,10 +93,10 @@ ToBlock(HeteroGraphPtr graph, const std::vector<IdArray> &rhs_nodes) { ...@@ -87,10 +93,10 @@ ToBlock(HeteroGraphPtr graph, const std::vector<IdArray> &rhs_nodes) {
}; // namespace }; // namespace
std::tuple<HeteroGraphPtr, std::vector<IdArray>, std::vector<IdArray>> std::tuple<HeteroGraphPtr, std::vector<IdArray>, std::vector<IdArray>>
ToBlock(HeteroGraphPtr graph, const std::vector<IdArray> &rhs_nodes) { ToBlock(HeteroGraphPtr graph, const std::vector<IdArray> &rhs_nodes, bool include_rhs_in_lhs) {
std::tuple<HeteroGraphPtr, std::vector<IdArray>, std::vector<IdArray>> ret; std::tuple<HeteroGraphPtr, std::vector<IdArray>, std::vector<IdArray>> ret;
ATEN_ID_TYPE_SWITCH(graph->DataType(), IdType, { ATEN_ID_TYPE_SWITCH(graph->DataType(), IdType, {
ret = ToBlock<IdType>(graph, rhs_nodes); ret = ToBlock<IdType>(graph, rhs_nodes, include_rhs_in_lhs);
}); });
return ret; return ret;
} }
...@@ -99,11 +105,13 @@ DGL_REGISTER_GLOBAL("transform._CAPI_DGLToBlock") ...@@ -99,11 +105,13 @@ DGL_REGISTER_GLOBAL("transform._CAPI_DGLToBlock")
.set_body([] (DGLArgs args, DGLRetValue *rv) { .set_body([] (DGLArgs args, DGLRetValue *rv) {
const HeteroGraphRef graph_ref = args[0]; const HeteroGraphRef graph_ref = args[0];
const std::vector<IdArray> &rhs_nodes = ListValueToVector<IdArray>(args[1]); const std::vector<IdArray> &rhs_nodes = ListValueToVector<IdArray>(args[1]);
const bool include_rhs_in_lhs = args[2];
HeteroGraphPtr new_graph; HeteroGraphPtr new_graph;
std::vector<IdArray> lhs_nodes; std::vector<IdArray> lhs_nodes;
std::vector<IdArray> induced_edges; std::vector<IdArray> induced_edges;
std::tie(new_graph, lhs_nodes, induced_edges) = ToBlock(graph_ref.sptr(), rhs_nodes); std::tie(new_graph, lhs_nodes, induced_edges) = ToBlock(
graph_ref.sptr(), rhs_nodes, include_rhs_in_lhs);
List<Value> lhs_nodes_ref; List<Value> lhs_nodes_ref;
for (IdArray &array : lhs_nodes) for (IdArray &array : lhs_nodes)
......
...@@ -428,10 +428,11 @@ def test_to_simple(): ...@@ -428,10 +428,11 @@ def test_to_simple():
@unittest.skipIf(F._default_context_str == 'gpu', reason="GPU compaction not implemented") @unittest.skipIf(F._default_context_str == 'gpu', reason="GPU compaction not implemented")
def test_to_block(): def test_to_block():
def check(g, bg, ntype, etype, dst_nodes): def check(g, bg, ntype, etype, dst_nodes, include_dst_in_src=True):
if dst_nodes is not None: if dst_nodes is not None:
assert F.array_equal(bg.dstnodes[ntype].data[dgl.NID], dst_nodes) assert F.array_equal(bg.dstnodes[ntype].data[dgl.NID], dst_nodes)
n_dst_nodes = bg.number_of_nodes('DST/' + ntype) n_dst_nodes = bg.number_of_nodes('DST/' + ntype)
if include_dst_in_src:
assert F.array_equal( assert F.array_equal(
bg.srcnodes[ntype].data[dgl.NID][:n_dst_nodes], bg.srcnodes[ntype].data[dgl.NID][:n_dst_nodes],
bg.dstnodes[ntype].data[dgl.NID]) bg.dstnodes[ntype].data[dgl.NID])
...@@ -452,13 +453,13 @@ def test_to_block(): ...@@ -452,13 +453,13 @@ def test_to_block():
assert F.array_equal(induced_src_bg, induced_src_ans) assert F.array_equal(induced_src_bg, induced_src_ans)
assert F.array_equal(induced_dst_bg, induced_dst_ans) assert F.array_equal(induced_dst_bg, induced_dst_ans)
def checkall(g, bg, dst_nodes): def checkall(g, bg, dst_nodes, include_dst_in_src=True):
for etype in g.etypes: for etype in g.etypes:
ntype = g.to_canonical_etype(etype)[2] ntype = g.to_canonical_etype(etype)[2]
if dst_nodes is not None and ntype in dst_nodes: if dst_nodes is not None and ntype in dst_nodes:
check(g, bg, ntype, etype, dst_nodes[ntype]) check(g, bg, ntype, etype, dst_nodes[ntype], include_dst_in_src)
else: else:
check(g, bg, ntype, etype, None) check(g, bg, ntype, etype, None, include_dst_in_src)
g = dgl.heterograph({ g = dgl.heterograph({
('A', 'AA', 'A'): [(0, 1), (2, 3), (1, 2), (3, 4)], ('A', 'AA', 'A'): [(0, 1), (2, 3), (1, 2), (3, 4)],
...@@ -468,6 +469,13 @@ def test_to_block(): ...@@ -468,6 +469,13 @@ def test_to_block():
bg = dgl.to_block(g_a) bg = dgl.to_block(g_a)
check(g_a, bg, 'A', 'AA', None) check(g_a, bg, 'A', 'AA', None)
assert bg.number_of_src_nodes() == 5
assert bg.number_of_dst_nodes() == 4
bg = dgl.to_block(g_a, include_dst_in_src=False)
check(g_a, bg, 'A', 'AA', None, False)
assert bg.number_of_src_nodes() == 4
assert bg.number_of_dst_nodes() == 4
dst_nodes = F.tensor([3, 4], dtype=F.int64) dst_nodes = F.tensor([3, 4], dtype=F.int64)
bg = dgl.to_block(g_a, dst_nodes) bg = dgl.to_block(g_a, dst_nodes)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment