/*! * Copyright (c) 2020 by Contributors * \file graph/heterograph_capi.cc * \brief Heterograph CAPI bindings. */ #include #include #include #include #include #include #include #include #include "../c_api_common.h" #include "./heterograph.h" #include "unit_graph.h" using namespace dgl::runtime; namespace dgl { ///////////////////////// Unitgraph functions ///////////////////////// // XXX(minjie): Ideally, Unitgraph should be invisible to python side DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroCreateUnitGraphFromCOO") .set_body([] (DGLArgs args, DGLRetValue* rv) { int64_t nvtypes = args[0]; int64_t num_src = args[1]; int64_t num_dst = args[2]; IdArray row = args[3]; IdArray col = args[4]; List formats = args[5]; bool row_sorted = args[6]; bool col_sorted = args[7]; std::vector formats_vec; for (Value val : formats) { std::string fmt = val->data; formats_vec.push_back(ParseSparseFormat(fmt)); } const auto code = SparseFormatsToCode(formats_vec); auto hgptr = CreateFromCOO(nvtypes, num_src, num_dst, row, col, row_sorted, col_sorted, code); *rv = HeteroGraphRef(hgptr); }); DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroCreateUnitGraphFromCSR") .set_body([] (DGLArgs args, DGLRetValue* rv) { int64_t nvtypes = args[0]; int64_t num_src = args[1]; int64_t num_dst = args[2]; IdArray indptr = args[3]; IdArray indices = args[4]; IdArray edge_ids = args[5]; List formats = args[6]; bool transpose = args[7]; std::vector formats_vec; for (Value val : formats) { std::string fmt = val->data; formats_vec.push_back(ParseSparseFormat(fmt)); } const auto code = SparseFormatsToCode(formats_vec); if (!transpose) { auto hgptr = CreateFromCSR(nvtypes, num_src, num_dst, indptr, indices, edge_ids, code); *rv = HeteroGraphRef(hgptr); } else { auto hgptr = CreateFromCSC(nvtypes, num_src, num_dst, indptr, indices, edge_ids, code); *rv = HeteroGraphRef(hgptr); } }); DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroCreateHeteroGraph") .set_body([] (DGLArgs args, DGLRetValue* 
rv) { GraphRef meta_graph = args[0]; List rel_graphs = args[1]; std::vector rel_ptrs; rel_ptrs.reserve(rel_graphs.size()); for (const auto& ref : rel_graphs) { rel_ptrs.push_back(ref.sptr()); } auto hgptr = CreateHeteroGraph(meta_graph.sptr(), rel_ptrs); *rv = HeteroGraphRef(hgptr); }); DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroCreateHeteroGraphWithNumNodes") .set_body([] (DGLArgs args, DGLRetValue* rv) { GraphRef meta_graph = args[0]; List rel_graphs = args[1]; IdArray num_nodes_per_type = args[2]; std::vector rel_ptrs; rel_ptrs.reserve(rel_graphs.size()); for (const auto& ref : rel_graphs) { rel_ptrs.push_back(ref.sptr()); } auto hgptr = CreateHeteroGraph( meta_graph.sptr(), rel_ptrs, num_nodes_per_type.ToVector()); *rv = HeteroGraphRef(hgptr); }); ///////////////////////// HeteroGraph member functions ///////////////////////// DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroGetMetaGraph") .set_body([] (DGLArgs args, DGLRetValue* rv) { HeteroGraphRef hg = args[0]; *rv = hg->meta_graph(); }); DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroIsMetaGraphUniBipartite") .set_body([] (DGLArgs args, DGLRetValue* rv) { HeteroGraphRef hg = args[0]; GraphPtr mg = hg->meta_graph(); *rv = mg->IsUniBipartite(); }); DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroGetRelationGraph") .set_body([] (DGLArgs args, DGLRetValue* rv) { HeteroGraphRef hg = args[0]; dgl_type_t etype = args[1]; CHECK_LE(etype, hg->NumEdgeTypes()) << "invalid edge type " << etype; auto unit_graph = hg->GetRelationGraph(etype); auto meta_graph = unit_graph->meta_graph(); auto hgptr = CreateHeteroGraph( meta_graph, {unit_graph}, unit_graph->NumVerticesPerType()); *rv = HeteroGraphRef(hgptr); }); DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroGetFlattenedGraph") .set_body([] (DGLArgs args, DGLRetValue* rv) { HeteroGraphRef hg = args[0]; List etypes = args[1]; std::vector etypes_vec; for (Value val : etypes) { // (gq) have to decompose it into two statements because of a 
weird MSVC internal error dgl_id_t id = val->data; etypes_vec.push_back(id); } *rv = FlattenedHeteroGraphRef(hg->Flatten(etypes_vec)); }); DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroAddVertices") .set_body([] (DGLArgs args, DGLRetValue* rv) { HeteroGraphRef hg = args[0]; dgl_type_t vtype = args[1]; int64_t num = args[2]; hg->AddVertices(vtype, num); }); DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroAddEdge") .set_body([] (DGLArgs args, DGLRetValue* rv) { HeteroGraphRef hg = args[0]; dgl_type_t etype = args[1]; dgl_id_t src = args[2]; dgl_id_t dst = args[3]; hg->AddEdge(etype, src, dst); }); DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroAddEdges") .set_body([] (DGLArgs args, DGLRetValue* rv) { HeteroGraphRef hg = args[0]; dgl_type_t etype = args[1]; IdArray src = args[2]; IdArray dst = args[3]; hg->AddEdges(etype, src, dst); }); DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroClear") .set_body([] (DGLArgs args, DGLRetValue* rv) { HeteroGraphRef hg = args[0]; hg->Clear(); }); DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroDataType") .set_body([] (DGLArgs args, DGLRetValue* rv) { HeteroGraphRef hg = args[0]; *rv = hg->DataType(); }); DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroContext") .set_body([] (DGLArgs args, DGLRetValue* rv) { HeteroGraphRef hg = args[0]; *rv = hg->Context(); }); DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroIsPinned") .set_body([] (DGLArgs args, DGLRetValue* rv) { HeteroGraphRef hg = args[0]; *rv = hg->IsPinned(); }); DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroNumBits") .set_body([] (DGLArgs args, DGLRetValue* rv) { HeteroGraphRef hg = args[0]; *rv = hg->NumBits(); }); DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroIsMultigraph") .set_body([] (DGLArgs args, DGLRetValue* rv) { HeteroGraphRef hg = args[0]; *rv = hg->IsMultigraph(); }); DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroIsReadonly") .set_body([] (DGLArgs args, DGLRetValue* rv) { HeteroGraphRef hg = 
args[0]; *rv = hg->IsReadonly(); }); DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroNumVertices") .set_body([] (DGLArgs args, DGLRetValue* rv) { HeteroGraphRef hg = args[0]; dgl_type_t vtype = args[1]; *rv = static_cast(hg->NumVertices(vtype)); }); DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroNumEdges") .set_body([] (DGLArgs args, DGLRetValue* rv) { HeteroGraphRef hg = args[0]; dgl_type_t etype = args[1]; *rv = static_cast(hg->NumEdges(etype)); }); DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroHasVertex") .set_body([] (DGLArgs args, DGLRetValue* rv) { HeteroGraphRef hg = args[0]; dgl_type_t vtype = args[1]; dgl_id_t vid = args[2]; *rv = hg->HasVertex(vtype, vid); }); DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroHasVertices") .set_body([] (DGLArgs args, DGLRetValue* rv) { HeteroGraphRef hg = args[0]; dgl_type_t vtype = args[1]; IdArray vids = args[2]; *rv = hg->HasVertices(vtype, vids); }); DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroHasEdgeBetween") .set_body([] (DGLArgs args, DGLRetValue* rv) { HeteroGraphRef hg = args[0]; dgl_type_t etype = args[1]; dgl_id_t src = args[2]; dgl_id_t dst = args[3]; *rv = hg->HasEdgeBetween(etype, src, dst); }); DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroHasEdgesBetween") .set_body([] (DGLArgs args, DGLRetValue* rv) { HeteroGraphRef hg = args[0]; dgl_type_t etype = args[1]; IdArray src = args[2]; IdArray dst = args[3]; *rv = hg->HasEdgesBetween(etype, src, dst); }); DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroPredecessors") .set_body([] (DGLArgs args, DGLRetValue* rv) { HeteroGraphRef hg = args[0]; dgl_type_t etype = args[1]; dgl_id_t dst = args[2]; *rv = hg->Predecessors(etype, dst); }); DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroSuccessors") .set_body([] (DGLArgs args, DGLRetValue* rv) { HeteroGraphRef hg = args[0]; dgl_type_t etype = args[1]; dgl_id_t src = args[2]; *rv = hg->Successors(etype, src); }); 
// Looks up the id of the edge src ([2]) -> dst ([3]) of edge type [1].
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroEdgeId")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    HeteroGraphRef graph = args[0];
    dgl_type_t edge_type = args[1];
    dgl_id_t u = args[2];
    dgl_id_t v = args[3];
    *rv = hg_edge_id_result_unused_guard(0), *rv = graph->EdgeId(edge_type, u, v);
  });

DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroOutEdges_2") .set_body([] (DGLArgs args, DGLRetValue* rv) { HeteroGraphRef hg = args[0]; dgl_type_t etype = args[1]; IdArray vids = args[2]; const auto& ret = hg->OutEdges(etype, vids); *rv = ConvertEdgeArrayToPackedFunc(ret); }); DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroEdges") .set_body([] (DGLArgs args, DGLRetValue* rv) { HeteroGraphRef hg = args[0]; dgl_type_t etype = args[1]; std::string order = args[2]; const auto& ret = hg->Edges(etype, order); *rv = ConvertEdgeArrayToPackedFunc(ret); }); DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroInDegree") .set_body([] (DGLArgs args, DGLRetValue* rv) { HeteroGraphRef hg = args[0]; dgl_type_t etype = args[1]; dgl_id_t vid = args[2]; *rv = static_cast(hg->InDegree(etype, vid)); }); DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroInDegrees") .set_body([] (DGLArgs args, DGLRetValue* rv) { HeteroGraphRef hg = args[0]; dgl_type_t etype = args[1]; IdArray vids = args[2]; *rv = hg->InDegrees(etype, vids); }); DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroOutDegree") .set_body([] (DGLArgs args, DGLRetValue* rv) { HeteroGraphRef hg = args[0]; dgl_type_t etype = args[1]; dgl_id_t vid = args[2]; *rv = static_cast(hg->OutDegree(etype, vid)); }); DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroOutDegrees") .set_body([] (DGLArgs args, DGLRetValue* rv) { HeteroGraphRef hg = args[0]; dgl_type_t etype = args[1]; IdArray vids = args[2]; *rv = hg->OutDegrees(etype, vids); }); DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroGetAdj") .set_body([] (DGLArgs args, DGLRetValue* rv) { HeteroGraphRef hg = args[0]; dgl_type_t etype = args[1]; bool transpose = args[2]; std::string fmt = args[3]; *rv = ConvertNDArrayVectorToPackedFunc( hg->GetAdj(etype, transpose, fmt)); }); DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroVertexSubgraph") .set_body([] (DGLArgs args, DGLRetValue* rv) { HeteroGraphRef hg = args[0]; List vids = args[1]; bool 
relabel_nodes = args[2]; CHECK(relabel_nodes) << "Node subgraph only supports relabel_nodes=True."; std::vector vid_vec; vid_vec.reserve(vids.size()); for (Value val : vids) { vid_vec.push_back(val->data); } std::shared_ptr subg( new HeteroSubgraph(hg->VertexSubgraph(vid_vec))); *rv = HeteroSubgraphRef(subg); }); DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroEdgeSubgraph") .set_body([] (DGLArgs args, DGLRetValue* rv) { HeteroGraphRef hg = args[0]; List eids = args[1]; bool preserve_nodes = args[2]; std::vector eid_vec; eid_vec.reserve(eids.size()); for (Value val : eids) { eid_vec.push_back(val->data); } std::shared_ptr subg( new HeteroSubgraph(hg->EdgeSubgraph(eid_vec, preserve_nodes))); *rv = HeteroSubgraphRef(subg); }); ///////////////////////// HeteroSubgraph members ///////////////////////// DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroSubgraphGetGraph") .set_body([] (DGLArgs args, DGLRetValue* rv) { HeteroSubgraphRef subg = args[0]; *rv = HeteroGraphRef(subg->graph); }); DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroSubgraphGetInducedVertices") .set_body([] (DGLArgs args, DGLRetValue* rv) { HeteroSubgraphRef subg = args[0]; List induced_verts; for (IdArray arr : subg->induced_vertices) { induced_verts.push_back(Value(MakeValue(arr))); } *rv = induced_verts; }); DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroSubgraphGetInducedEdges") .set_body([] (DGLArgs args, DGLRetValue* rv) { HeteroSubgraphRef subg = args[0]; List induced_edges; for (IdArray arr : subg->induced_edges) { induced_edges.push_back(Value(MakeValue(arr))); } *rv = induced_edges; }); ///////////////////////// Global functions and algorithms ///////////////////////// DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroAsNumBits") .set_body([] (DGLArgs args, DGLRetValue* rv) { HeteroGraphRef hg = args[0]; int bits = args[1]; HeteroGraphPtr bhg_ptr = hg.sptr(); auto hg_ptr = std::dynamic_pointer_cast(bhg_ptr); HeteroGraphPtr hg_new; if (hg_ptr) { hg_new = 
HeteroGraph::AsNumBits(hg_ptr, bits); } else { hg_new = UnitGraph::AsNumBits(bhg_ptr, bits); } *rv = HeteroGraphRef(hg_new); }); DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroCopyTo") .set_body([] (DGLArgs args, DGLRetValue* rv) { HeteroGraphRef hg = args[0]; int device_type = args[1]; int device_id = args[2]; DGLContext ctx; ctx.device_type = static_cast(device_type); ctx.device_id = device_id; HeteroGraphPtr hg_new = HeteroGraph::CopyTo(hg.sptr(), ctx); *rv = HeteroGraphRef(hg_new); }); DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroPinMemory_") .set_body([] (DGLArgs args, DGLRetValue* rv) { HeteroGraphRef hg = args[0]; auto hgindex = std::dynamic_pointer_cast(hg.sptr()); hgindex->PinMemory_(); *rv = hg; }); DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroUnpinMemory_") .set_body([] (DGLArgs args, DGLRetValue* rv) { HeteroGraphRef hg = args[0]; auto hgindex = std::dynamic_pointer_cast(hg.sptr()); hgindex->UnpinMemory_(); *rv = hg; }); DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroRecordStream") .set_body([] (DGLArgs args, DGLRetValue* rv) { HeteroGraphRef hg = args[0]; DGLStreamHandle stream = args[1]; auto hgindex = std::dynamic_pointer_cast(hg.sptr()); hgindex->RecordStream(stream); *rv = hg; }); DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroCopyToSharedMem") .set_body([] (DGLArgs args, DGLRetValue* rv) { HeteroGraphRef hg = args[0]; std::string name = args[1]; List ntypes = args[2]; List etypes = args[3]; List fmts = args[4]; auto ntypes_vec = ListValueToVector(ntypes); auto etypes_vec = ListValueToVector(etypes); std::set fmts_set; for (const auto &fmt : fmts) { std::string fmt_data = fmt->data; fmts_set.insert(fmt_data); } auto hg_share = HeteroGraph::CopyToSharedMem( hg.sptr(), name, ntypes_vec, etypes_vec, fmts_set); *rv = HeteroGraphRef(hg_share); }); DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroCreateFromSharedMem") .set_body([] (DGLArgs args, DGLRetValue* rv) { std::string name = args[0]; HeteroGraphPtr hg; 
std::vector ntypes; std::vector etypes; std::tie(hg, ntypes, etypes) = HeteroGraph::CreateFromSharedMem(name); List ntypes_list; List etypes_list; for (const auto &ntype : ntypes) ntypes_list.push_back(Value(MakeValue(ntype))); for (const auto &etype : etypes) etypes_list.push_back(Value(MakeValue(etype))); List ret; ret.push_back(HeteroGraphRef(hg)); ret.push_back(ntypes_list); ret.push_back(etypes_list); *rv = ret; }); DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroJointUnion") .set_body([] (DGLArgs args, DGLRetValue* rv) { GraphRef meta_graph = args[0]; List component_graphs = args[1]; CHECK(component_graphs.size() > 1) << "Expect graph list to have at least two graphs"; std::vector component_ptrs; component_ptrs.reserve(component_graphs.size()); const int64_t bits = component_graphs[0]->NumBits(); const DGLContext ctx = component_graphs[0]->Context(); for (const auto& component : component_graphs) { component_ptrs.push_back(component.sptr()); CHECK_EQ(component->NumBits(), bits) << "Expect graphs to joint union have the same index dtype(int" << bits << "), but got int" << component->NumBits(); CHECK_EQ(component->Context(), ctx) << "Expect graphs to joint union have the same context" << ctx << "), but got " << component->Context(); } auto hgptr = JointUnionHeteroGraph(meta_graph.sptr(), component_ptrs); *rv = HeteroGraphRef(hgptr); }); DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroDisjointUnion_v2") .set_body([] (DGLArgs args, DGLRetValue* rv) { GraphRef meta_graph = args[0]; List component_graphs = args[1]; CHECK(component_graphs.size() > 0) << "Expect graph list has at least one graph"; std::vector component_ptrs; component_ptrs.reserve(component_graphs.size()); const int64_t bits = component_graphs[0]->NumBits(); const DGLContext ctx = component_graphs[0]->Context(); for (const auto& component : component_graphs) { component_ptrs.push_back(component.sptr()); CHECK_EQ(component->NumBits(), bits) << "Expect graphs to batch have the same index 
dtype(int" << bits << "), but got int" << component->NumBits(); CHECK_EQ(component->Context(), ctx) << "Expect graphs to batch have the same context" << ctx << "), but got " << component->Context(); } auto hgptr = DisjointUnionHeteroGraph2(meta_graph.sptr(), component_ptrs); *rv = HeteroGraphRef(hgptr); }); DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroDisjointPartitionBySizes_v2") .set_body([] (DGLArgs args, DGLRetValue* rv) { HeteroGraphRef hg = args[0]; const IdArray vertex_sizes = args[1]; const IdArray edge_sizes = args[2]; std::vector ret; ret = DisjointPartitionHeteroBySizes2(hg->meta_graph(), hg.sptr(), vertex_sizes, edge_sizes); List ret_list; for (HeteroGraphPtr hgptr : ret) { ret_list.push_back(HeteroGraphRef(hgptr)); } *rv = ret_list; }); DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroDisjointPartitionBySizes") .set_body([] (DGLArgs args, DGLRetValue* rv) { HeteroGraphRef hg = args[0]; const IdArray vertex_sizes = args[1]; const IdArray edge_sizes = args[2]; const int64_t bits = hg->NumBits(); std::vector ret; ATEN_ID_BITS_SWITCH(bits, IdType, { ret = DisjointPartitionHeteroBySizes(hg->meta_graph(), hg.sptr(), vertex_sizes, edge_sizes); }); List ret_list; for (HeteroGraphPtr hgptr : ret) { ret_list.push_back(HeteroGraphRef(hgptr)); } *rv = ret_list; }); DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroSlice") .set_body([] (DGLArgs args, DGLRetValue* rv) { HeteroGraphRef hg = args[0]; const IdArray num_nodes_per_type = args[1]; const IdArray start_nid_per_type = args[2]; const IdArray num_edges_per_type = args[3]; const IdArray start_eid_per_type = args[4]; auto hgptr = SliceHeteroGraph(hg->meta_graph(), hg.sptr(), num_nodes_per_type, start_nid_per_type, num_edges_per_type, start_eid_per_type); *rv = HeteroGraphRef(hgptr); }); DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroGetCreatedFormats") .set_body([] (DGLArgs args, DGLRetValue* rv) { HeteroGraphRef hg = args[0]; List format_list; dgl_format_code_t code = 
hg->GetRelationGraph(0)->GetCreatedFormats(); for (auto format : CodeToSparseFormats(code)) { format_list.push_back( Value(MakeValue(ToStringSparseFormat(format)))); } *rv = format_list; }); DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroGetAllowedFormats") .set_body([] (DGLArgs args, DGLRetValue* rv) { HeteroGraphRef hg = args[0]; List format_list; dgl_format_code_t code = hg->GetRelationGraph(0)->GetAllowedFormats(); for (auto format : CodeToSparseFormats(code)) { format_list.push_back( Value(MakeValue(ToStringSparseFormat(format)))); } *rv = format_list; }); DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroCreateFormat") .set_body([] (DGLArgs args, DGLRetValue* rv) { HeteroGraphRef hg = args[0]; dgl_format_code_t code = hg->GetRelationGraph(0)->GetAllowedFormats(); auto get_format_f = [&](size_t etype_b, size_t etype_e) { for (auto etype = etype_b; etype < etype_e; ++etype) { auto bg = std::dynamic_pointer_cast(hg->GetRelationGraph(etype)); for (auto format : CodeToSparseFormats(code)) bg->GetFormat(format); } }; #if !(defined(DGL_USE_CUDA)) runtime::parallel_for(0, hg->NumEdgeTypes(), get_format_f); #else get_format_f(0, hg->NumEdgeTypes()); #endif }); DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroGetFormatGraph") .set_body([] (DGLArgs args, DGLRetValue* rv) { HeteroGraphRef hg = args[0]; List formats = args[1]; std::vector formats_vec; for (Value val : formats) { std::string fmt = val->data; formats_vec.push_back(ParseSparseFormat(fmt)); } auto hgptr = hg->GetGraphInFormat( SparseFormatsToCode(formats_vec)); *rv = HeteroGraphRef(hgptr); }); DGL_REGISTER_GLOBAL("subgraph._CAPI_DGLInSubgraph") .set_body([] (DGLArgs args, DGLRetValue *rv) { HeteroGraphRef hg = args[0]; const auto& nodes = ListValueToVector(args[1]); bool relabel_nodes = args[2]; std::shared_ptr ret(new HeteroSubgraph); *ret = InEdgeGraph(hg.sptr(), nodes, relabel_nodes); *rv = HeteroGraphRef(ret); }); DGL_REGISTER_GLOBAL("subgraph._CAPI_DGLOutSubgraph") .set_body([] (DGLArgs 
args, DGLRetValue *rv) { HeteroGraphRef hg = args[0]; const auto& nodes = ListValueToVector(args[1]); bool relabel_nodes = args[2]; std::shared_ptr ret(new HeteroSubgraph); *ret = OutEdgeGraph(hg.sptr(), nodes, relabel_nodes); *rv = HeteroGraphRef(ret); }); DGL_REGISTER_GLOBAL("transform._CAPI_DGLAsImmutableGraph") .set_body([] (DGLArgs args, DGLRetValue* rv) { HeteroGraphRef hg = args[0]; *rv = GraphRef(hg->AsImmutableGraph()); }); DGL_REGISTER_GLOBAL("transform._CAPI_DGLHeteroSortOutEdges") .set_body([] (DGLArgs args, DGLRetValue* rv) { HeteroGraphRef hg = args[0]; NDArray tag = args[1]; int64_t num_tag = args[2]; CHECK_EQ(hg->Context().device_type, kDGLCPU) << "Only support sorting by tag on cpu"; CHECK(aten::IsValidIdArray(tag)); CHECK_EQ(tag->ctx.device_type, kDGLCPU) << "Only support sorting by tag on cpu"; const auto csr = hg->GetCSRMatrix(0); NDArray tag_pos = aten::NullArray(); aten::CSRMatrix output; std::tie(output, tag_pos) = aten::CSRSortByTag(csr, tag, num_tag); HeteroGraphPtr output_hg = CreateFromCSR(hg->NumVertexTypes(), output, ALL_CODE); List ret; ret.push_back(HeteroGraphRef(output_hg)); ret.push_back(Value(MakeValue(tag_pos))); *rv = ret; }); DGL_REGISTER_GLOBAL("transform._CAPI_DGLHeteroSortInEdges") .set_body([] (DGLArgs args, DGLRetValue* rv) { HeteroGraphRef hg = args[0]; NDArray tag = args[1]; int64_t num_tag = args[2]; CHECK_EQ(hg->Context().device_type, kDGLCPU) << "Only support sorting by tag on cpu"; CHECK(aten::IsValidIdArray(tag)); CHECK_EQ(tag->ctx.device_type, kDGLCPU) << "Only support sorting by tag on cpu"; const auto csc = hg->GetCSCMatrix(0); NDArray tag_pos = aten::NullArray(); aten::CSRMatrix output; std::tie(output, tag_pos) = aten::CSRSortByTag(csc, tag, num_tag); HeteroGraphPtr output_hg = CreateFromCSC(hg->NumVertexTypes(), output, ALL_CODE); List ret; ret.push_back(HeteroGraphRef(output_hg)); ret.push_back(Value(MakeValue(tag_pos))); *rv = ret; }); DGL_REGISTER_GLOBAL("heterograph._CAPI_DGLFindSrcDstNtypes") .set_body([] 
(DGLArgs args, DGLRetValue* rv) { GraphRef metagraph = args[0]; std::unordered_set dst_set; std::unordered_set src_set; for (int64_t eid = 0; eid < metagraph->NumEdges(); ++eid) { auto edge = metagraph->FindEdge(eid); auto src = edge.first; auto dst = edge.second; dst_set.insert(dst); src_set.insert(src); } List srclist, dstlist; List> ret_list; for (int64_t nid = 0; nid < metagraph->NumVertices(); ++nid) { auto is_dst = dst_set.count(nid); auto is_src = src_set.count(nid); if (is_dst && is_src) return; else if (is_dst) dstlist.push_back(Value(MakeValue(nid))); else // If a node type is isolated, put it in srctype as defined in the Python docstring. srclist.push_back(Value(MakeValue(nid))); } ret_list.push_back(srclist); ret_list.push_back(dstlist); *rv = ret_list; }); DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroReverse") .set_body([] (DGLArgs args, DGLRetValue* rv) { HeteroGraphRef hg = args[0]; CHECK_GT(hg->NumEdgeTypes(), 0); auto g = std::dynamic_pointer_cast(hg.sptr()); std::vector rev_ugs; const auto &ugs = g->relation_graphs(); rev_ugs.resize(ugs.size()); for (size_t i = 0; i < ugs.size(); ++i) { const auto &rev_ug = ugs[i]->Reverse(); rev_ugs[i] = rev_ug; } // node types are not changed const auto& num_nodes = g->NumVerticesPerType(); const auto& meta_edges = hg->meta_graph()->Edges("eid"); // reverse the metagraph const auto& rev_meta = ImmutableGraph::CreateFromCOO(hg->meta_graph()->NumVertices(), meta_edges.dst, meta_edges.src); *rv = CreateHeteroGraph(rev_meta, rev_ugs, num_nodes); }); } // namespace dgl