/*!
 *  Copyright (c) 2019 by Contributors
 * @file graph/transform/to_simple.cc
 * @brief Convert multigraphs to simple graphs
 */

#include <dgl/array.h>
#include <dgl/base_heterograph.h>
#include <dgl/packed_func_ext.h>
#include <dgl/transform.h>

#include <tuple>
#include <utility>
#include <vector>

#include "../../c_api_common.h"
#include "../heterograph.h"
#include "../unit_graph.h"

namespace dgl {

using namespace dgl::runtime;
using namespace dgl::aten;

namespace transform {

std::tuple<HeteroGraphPtr, std::vector<IdArray>, std::vector<IdArray>>
ToSimpleGraph(const HeteroGraphPtr graph) {
  const int64_t num_etypes = graph->NumEdgeTypes();
  const auto metagraph = graph->meta_graph();
30
31
  const auto &ugs =
      std::dynamic_pointer_cast<HeteroGraph>(graph)->relation_graphs();
32
33
34
35
36

  std::vector<IdArray> counts(num_etypes), edge_maps(num_etypes);
  std::vector<HeteroGraphPtr> rel_graphs(num_etypes);

  for (int64_t etype = 0; etype < num_etypes; ++etype) {
37
38
    const auto result = ugs[etype]->ToSimple();
    std::tie(rel_graphs[etype], counts[etype], edge_maps[etype]) = result;
39
40
  }

41
42
  const HeteroGraphPtr result =
      CreateHeteroGraph(metagraph, rel_graphs, graph->NumVerticesPerType());
43
44
45
46
47

  return std::make_tuple(result, counts, edge_maps);
}

DGL_REGISTER_GLOBAL("transform._CAPI_DGLToSimpleHetero")
48
49
    .set_body([](DGLArgs args, DGLRetValue *rv) {
      const HeteroGraphRef graph_ref = args[0];
50

51
      const auto result = ToSimpleGraph(graph_ref.sptr());
52

53
54
55
56
57
      List<Value> counts, edge_maps;
      for (const IdArray &count : std::get<1>(result))
        counts.push_back(Value(MakeValue(count)));
      for (const IdArray &edge_map : std::get<2>(result))
        edge_maps.push_back(Value(MakeValue(edge_map)));
58

59
60
61
62
      List<ObjectRef> ret;
      ret.push_back(HeteroGraphRef(std::get<0>(result)));
      ret.push_back(counts);
      ret.push_back(edge_maps);
63

64
65
      *rv = ret;
    });
66
67
68
69

};  // namespace transform

};  // namespace dgl