/*! * Copyright (c) 2019 by Contributors * \file graph/transform/to_bipartite.cc * \brief Convert a graph to a bipartite-structured graph. */ #include #include #include #include #include #include #include #include #include // TODO(BarclayII): currently ToBlock depend on IdHashMap implementation which // only works on CPU. Should fix later to make it device agnostic. #include "../../array/cpu/array_utils.h" namespace dgl { using namespace dgl::runtime; using namespace dgl::aten; namespace transform { namespace { template std::tuple, std::vector> ToBlock(HeteroGraphPtr graph, const std::vector &rhs_nodes) { const int64_t num_etypes = graph->NumEdgeTypes(); const int64_t num_ntypes = graph->NumVertexTypes(); std::vector edge_arrays(num_etypes); CHECK(rhs_nodes.size() == static_cast(num_ntypes)) << "rhs_nodes not given for every node type"; const std::vector> rhs_node_mappings(rhs_nodes.begin(), rhs_nodes.end()); std::vector> lhs_node_mappings(rhs_node_mappings); // copy std::vector num_nodes_per_type; num_nodes_per_type.reserve(2 * num_ntypes); for (int64_t etype = 0; etype < num_etypes; ++etype) { const auto src_dst_types = graph->GetEndpointTypes(etype); const dgl_type_t srctype = src_dst_types.first; const dgl_type_t dsttype = src_dst_types.second; const EdgeArray edges = graph->InEdges(etype, rhs_nodes[dsttype]); lhs_node_mappings[srctype].Update(edges.src); edge_arrays[etype] = edges; } const auto meta_graph = graph->meta_graph(); const EdgeArray etypes = meta_graph->Edges("eid"); const IdArray new_dst = Add(etypes.dst, num_ntypes); const auto new_meta_graph = ImmutableGraph::CreateFromCOO( num_ntypes * 2, etypes.src, new_dst); for (int64_t ntype = 0; ntype < num_ntypes; ++ntype) num_nodes_per_type.push_back(lhs_node_mappings[ntype].Size()); for (int64_t ntype = 0; ntype < num_ntypes; ++ntype) num_nodes_per_type.push_back(rhs_node_mappings[ntype].Size()); std::vector rel_graphs; std::vector induced_edges; for (int64_t etype = 0; etype < num_etypes; ++etype) { const auto src_dst_types = graph->GetEndpointTypes(etype); const dgl_type_t srctype = src_dst_types.first; const dgl_type_t dsttype = src_dst_types.second; const IdHashMap &lhs_map = lhs_node_mappings[srctype]; const IdHashMap &rhs_map = rhs_node_mappings[dsttype]; rel_graphs.push_back(CreateFromCOO( 2, lhs_map.Size(), rhs_map.Size(), lhs_map.Map(edge_arrays[etype].src, -1), rhs_map.Map(edge_arrays[etype].dst, -1))); induced_edges.push_back(edge_arrays[etype].id); } const HeteroGraphPtr new_graph = CreateHeteroGraph( new_meta_graph, rel_graphs, num_nodes_per_type); std::vector lhs_nodes; for (const IdHashMap &lhs_map : lhs_node_mappings) lhs_nodes.push_back(lhs_map.Values()); return std::make_tuple(new_graph, lhs_nodes, induced_edges); } }; // namespace std::tuple, std::vector> ToBlock(HeteroGraphPtr graph, const std::vector &rhs_nodes) { std::tuple, std::vector> ret; ATEN_ID_TYPE_SWITCH(graph->DataType(), IdType, { ret = ToBlock(graph, rhs_nodes); }); return ret; } DGL_REGISTER_GLOBAL("transform._CAPI_DGLToBlock") .set_body([] (DGLArgs args, DGLRetValue *rv) { const HeteroGraphRef graph_ref = args[0]; const std::vector &rhs_nodes = ListValueToVector(args[1]); HeteroGraphPtr new_graph; std::vector lhs_nodes; std::vector induced_edges; std::tie(new_graph, lhs_nodes, induced_edges) = ToBlock(graph_ref.sptr(), rhs_nodes); List lhs_nodes_ref; for (IdArray &array : lhs_nodes) lhs_nodes_ref.push_back(Value(MakeValue(array))); List induced_edges_ref; for (IdArray &array : induced_edges) induced_edges_ref.push_back(Value(MakeValue(array))); List ret; ret.push_back(HeteroGraphRef(new_graph)); ret.push_back(lhs_nodes_ref); ret.push_back(induced_edges_ref); *rv = ret; }); }; // namespace transform }; // namespace dgl