"docs/api/errors.mdx" did not exist on "1188f408dd3c5d9739bd85c4df8250f4acb1b31f"
/*!
 *  Copyright (c) 2019 by Contributors
 * \file graph/transform/to_simple.cc
 * \brief Convert multigraphs to simple graphs
 */

#include <dgl/base_heterograph.h>
#include <dgl/transform.h>
#include <dgl/array.h>
#include <dgl/packed_func_ext.h>
#include <memory>
#include <tuple>
#include <utility>
#include <vector>
#include "../heterograph.h"
#include "../unit_graph.h"
#include "../../c_api_common.h"

namespace dgl {

using namespace dgl::runtime;
using namespace dgl::aten;

namespace transform {

std::tuple<HeteroGraphPtr, std::vector<IdArray>, std::vector<IdArray>>
ToSimpleGraph(const HeteroGraphPtr graph) {
  const int64_t num_etypes = graph->NumEdgeTypes();
  const auto metagraph = graph->meta_graph();
28
  const auto &ugs = std::dynamic_pointer_cast<HeteroGraph>(graph)->relation_graphs();
29
30
31
32
33

  std::vector<IdArray> counts(num_etypes), edge_maps(num_etypes);
  std::vector<HeteroGraphPtr> rel_graphs(num_etypes);

  for (int64_t etype = 0; etype < num_etypes; ++etype) {
34
35
    const auto result = ugs[etype]->ToSimple();
    std::tie(rel_graphs[etype], counts[etype], edge_maps[etype]) = result;
36
37
  }

38
39
  const HeteroGraphPtr result = CreateHeteroGraph(
      metagraph, rel_graphs, graph->NumVerticesPerType());
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66

  return std::make_tuple(result, counts, edge_maps);
}

/*!
 * \brief C API entry point wrapping ToSimpleGraph().
 *
 * Takes the graph reference in args[0] and stores in *rv a three-element list:
 * the converted graph, the list of values built from the second tuple element
 * of ToSimpleGraph(), and the list built from the third tuple element.
 */
DGL_REGISTER_GLOBAL("transform._CAPI_DGLToSimpleHetero")
.set_body([] (DGLArgs args, DGLRetValue *rv) {
    const HeteroGraphRef input = args[0];
    const auto simplified = ToSimpleGraph(input.sptr());

    // Name the tuple elements once so the loops below read clearly.
    const std::vector<IdArray> &count_arrays = std::get<1>(simplified);
    const std::vector<IdArray> &edge_map_arrays = std::get<2>(simplified);

    // Box each array as a Value so it can be stored in a List.
    List<Value> boxed_counts;
    for (size_t i = 0; i < count_arrays.size(); ++i) {
      boxed_counts.push_back(Value(MakeValue(count_arrays[i])));
    }

    List<Value> boxed_edge_maps;
    for (size_t i = 0; i < edge_map_arrays.size(); ++i) {
      boxed_edge_maps.push_back(Value(MakeValue(edge_map_arrays[i])));
    }

    List<ObjectRef> packed;
    packed.push_back(HeteroGraphRef(std::get<0>(simplified)));
    packed.push_back(boxed_counts);
    packed.push_back(boxed_edge_maps);

    *rv = packed;
  });

};  // namespace transform

};  // namespace dgl