Unverified Commit e3a9a6bb authored by Quan (Andy) Gan's avatar Quan (Andy) Gan Committed by GitHub
Browse files

add an optional include_dst_in_src argument (#1401)

parent af61e2fb
......@@ -66,17 +66,20 @@ CompactGraphs(
*
* \param graph The graph.
* \param rhs_nodes Designated nodes that would appear on the right side.
* \param include_rhs_in_lhs If false, do not include the nodes of node type \c ntype_r
* in \c ntype_l.
*
* \return A triplet containing
* * The bipartite-structured graph,
* * The induced node from the left side for each graph,
* * The induced edges.
*
* \note For each node type \c ntype, the nodes in rhs_nodes[ntype] would always
* appear first in the nodes of type \c ntype_l in the new graph.
* \note If include_rhs_in_lhs is true, then for each node type \c ntype, the nodes
* in rhs_nodes[ntype] would always appear first in the nodes of type \c ntype_l
* in the new graph.
*/
std::tuple<HeteroGraphPtr, std::vector<IdArray>, std::vector<IdArray>>
ToBlock(HeteroGraphPtr graph, const std::vector<IdArray> &rhs_nodes);
ToBlock(HeteroGraphPtr graph, const std::vector<IdArray> &rhs_nodes, bool include_rhs_in_lhs);
/*!
* \brief Convert a multigraph to a simple graph.
......
......@@ -749,7 +749,7 @@ def compact_graphs(graphs, always_preserve=None):
return new_graphs
def to_block(g, dst_nodes=None):
def to_block(g, dst_nodes=None, include_dst_in_src=True):
"""Convert a graph into a bipartite-structured "block" for message passing.
A block graph is uni-directional bipartite graph consisting of two sets of nodes
......@@ -767,7 +767,7 @@ def to_block(g, dst_nodes=None):
Moreover, the function also relabels node ids in each type to make the graph more compact.
Specifically, the nodes of type ``vtype`` would contain the nodes that have at least one
inbound edge of any type, while ``utype`` would contain all the DST nodes of type ``utype``,
inbound edge of any type, while ``utype`` would contain all the DST nodes of type ``vtype``,
as well as the nodes that have at least one outbound edge to any DST node.
Since DST nodes are included in SRC nodes, a common requirement is to fetch
......@@ -789,6 +789,8 @@ def to_block(g, dst_nodes=None):
The graph.
dst_nodes : Tensor or dict[str, Tensor], optional
Optional DST nodes. If a tensor is given, the graph must have only one node type.
include_dst_in_src : bool, default True
If False, do not include DST nodes in SRC nodes.
Returns
-------
......@@ -882,7 +884,8 @@ def to_block(g, dst_nodes=None):
else:
dst_nodes_nd.append(nd.null())
new_graph_index, src_nodes_nd, induced_edges_nd = _CAPI_DGLToBlock(g._graph, dst_nodes_nd)
new_graph_index, src_nodes_nd, induced_edges_nd = _CAPI_DGLToBlock(
g._graph, dst_nodes_nd, include_dst_in_src)
src_nodes = [F.zerocopy_from_dgl_ndarray(nodes_nd.data) for nodes_nd in src_nodes_nd]
dst_nodes = [F.zerocopy_from_dgl_ndarray(nodes_nd) for nodes_nd in dst_nodes_nd]
......
......@@ -28,7 +28,7 @@ namespace {
template<typename IdType>
std::tuple<HeteroGraphPtr, std::vector<IdArray>, std::vector<IdArray>>
ToBlock(HeteroGraphPtr graph, const std::vector<IdArray> &rhs_nodes) {
ToBlock(HeteroGraphPtr graph, const std::vector<IdArray> &rhs_nodes, bool include_rhs_in_lhs) {
const int64_t num_etypes = graph->NumEdgeTypes();
const int64_t num_ntypes = graph->NumVertexTypes();
std::vector<EdgeArray> edge_arrays(num_etypes);
......@@ -37,7 +37,13 @@ ToBlock(HeteroGraphPtr graph, const std::vector<IdArray> &rhs_nodes) {
<< "rhs_nodes not given for every node type";
const std::vector<IdHashMap<IdType>> rhs_node_mappings(rhs_nodes.begin(), rhs_nodes.end());
std::vector<IdHashMap<IdType>> lhs_node_mappings(rhs_node_mappings); // copy
std::vector<IdHashMap<IdType>> lhs_node_mappings;
if (include_rhs_in_lhs)
lhs_node_mappings = rhs_node_mappings; // copy
else
lhs_node_mappings.resize(num_ntypes);
std::vector<int64_t> num_nodes_per_type;
num_nodes_per_type.reserve(2 * num_ntypes);
......@@ -87,10 +93,10 @@ ToBlock(HeteroGraphPtr graph, const std::vector<IdArray> &rhs_nodes) {
}; // namespace
std::tuple<HeteroGraphPtr, std::vector<IdArray>, std::vector<IdArray>>
ToBlock(HeteroGraphPtr graph, const std::vector<IdArray> &rhs_nodes) {
ToBlock(HeteroGraphPtr graph, const std::vector<IdArray> &rhs_nodes, bool include_rhs_in_lhs) {
std::tuple<HeteroGraphPtr, std::vector<IdArray>, std::vector<IdArray>> ret;
ATEN_ID_TYPE_SWITCH(graph->DataType(), IdType, {
ret = ToBlock<IdType>(graph, rhs_nodes);
ret = ToBlock<IdType>(graph, rhs_nodes, include_rhs_in_lhs);
});
return ret;
}
......@@ -99,11 +105,13 @@ DGL_REGISTER_GLOBAL("transform._CAPI_DGLToBlock")
.set_body([] (DGLArgs args, DGLRetValue *rv) {
const HeteroGraphRef graph_ref = args[0];
const std::vector<IdArray> &rhs_nodes = ListValueToVector<IdArray>(args[1]);
const bool include_rhs_in_lhs = args[2];
HeteroGraphPtr new_graph;
std::vector<IdArray> lhs_nodes;
std::vector<IdArray> induced_edges;
std::tie(new_graph, lhs_nodes, induced_edges) = ToBlock(graph_ref.sptr(), rhs_nodes);
std::tie(new_graph, lhs_nodes, induced_edges) = ToBlock(
graph_ref.sptr(), rhs_nodes, include_rhs_in_lhs);
List<Value> lhs_nodes_ref;
for (IdArray &array : lhs_nodes)
......
......@@ -428,13 +428,14 @@ def test_to_simple():
@unittest.skipIf(F._default_context_str == 'gpu', reason="GPU compaction not implemented")
def test_to_block():
def check(g, bg, ntype, etype, dst_nodes):
def check(g, bg, ntype, etype, dst_nodes, include_dst_in_src=True):
if dst_nodes is not None:
assert F.array_equal(bg.dstnodes[ntype].data[dgl.NID], dst_nodes)
n_dst_nodes = bg.number_of_nodes('DST/' + ntype)
assert F.array_equal(
bg.srcnodes[ntype].data[dgl.NID][:n_dst_nodes],
bg.dstnodes[ntype].data[dgl.NID])
if include_dst_in_src:
assert F.array_equal(
bg.srcnodes[ntype].data[dgl.NID][:n_dst_nodes],
bg.dstnodes[ntype].data[dgl.NID])
g = g[etype]
bg = bg[etype]
......@@ -452,13 +453,13 @@ def test_to_block():
assert F.array_equal(induced_src_bg, induced_src_ans)
assert F.array_equal(induced_dst_bg, induced_dst_ans)
def checkall(g, bg, dst_nodes):
def checkall(g, bg, dst_nodes, include_dst_in_src=True):
for etype in g.etypes:
ntype = g.to_canonical_etype(etype)[2]
if dst_nodes is not None and ntype in dst_nodes:
check(g, bg, ntype, etype, dst_nodes[ntype])
check(g, bg, ntype, etype, dst_nodes[ntype], include_dst_in_src)
else:
check(g, bg, ntype, etype, None)
check(g, bg, ntype, etype, None, include_dst_in_src)
g = dgl.heterograph({
('A', 'AA', 'A'): [(0, 1), (2, 3), (1, 2), (3, 4)],
......@@ -468,6 +469,13 @@ def test_to_block():
bg = dgl.to_block(g_a)
check(g_a, bg, 'A', 'AA', None)
assert bg.number_of_src_nodes() == 5
assert bg.number_of_dst_nodes() == 4
bg = dgl.to_block(g_a, include_dst_in_src=False)
check(g_a, bg, 'A', 'AA', None, False)
assert bg.number_of_src_nodes() == 4
assert bg.number_of_dst_nodes() == 4
dst_nodes = F.tensor([3, 4], dtype=F.int64)
bg = dgl.to_block(g_a, dst_nodes)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment