/*!
 *  Copyright (c) 2019 by Contributors
 * \file graph/transform/compact.cc
 * \brief Compact graph implementation
 */

#include <dgl/base_heterograph.h>
#include <dgl/transform.h>
#include <dgl/array.h>
#include <dgl/packed_func_ext.h>
#include <vector>
#include <utility>
#include "../../c_api_common.h"
#include "../unit_graph.h"
// TODO(BarclayII): currently CompactGraphs depend on IdHashMap implementation which
// only works on CPU.  Should fix later to make it device agnostic.
#include "../../array/cpu/array_utils.h"

namespace dgl {

using namespace dgl::runtime;
using namespace dgl::aten;

namespace transform {

namespace {

template<typename IdType>
std::pair<std::vector<HeteroGraphPtr>, std::vector<IdArray>>
CompactGraphs(
    const std::vector<HeteroGraphPtr> &graphs,
    const std::vector<IdArray> &always_preserve) {
  // TODO(BarclayII): check whether the node space and metagraph of each graph is the same.
  // Step 1: Collect the nodes that has connections for each type.
35
36
  const int64_t num_ntypes = graphs[0]->NumVertexTypes();
  std::vector<aten::IdHashMap<IdType>> hashmaps(num_ntypes);
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
  std::vector<std::vector<EdgeArray>> all_edges(graphs.size());   // all_edges[i][etype]

  for (size_t i = 0; i < always_preserve.size(); ++i)
    hashmaps[i].Update(always_preserve[i]);

  for (size_t i = 0; i < graphs.size(); ++i) {
    const HeteroGraphPtr curr_graph = graphs[i];
    const int64_t num_etypes = curr_graph->NumEdgeTypes();

    for (IdType etype = 0; etype < num_etypes; ++etype) {
      IdType srctype, dsttype;
      std::tie(srctype, dsttype) = curr_graph->GetEndpointTypes(etype);

      const EdgeArray edges = curr_graph->Edges(etype, "eid");

      hashmaps[srctype].Update(edges.src);
      hashmaps[dsttype].Update(edges.dst);

      all_edges[i].push_back(edges);
    }
  }

  // Step 2: Relabel the nodes for each type to a smaller ID space and save the mapping.
60
61
62
63
64
65
  std::vector<IdArray> induced_nodes(num_ntypes);
  std::vector<int64_t> num_induced_nodes(num_ntypes);
  for (int64_t i = 0; i < num_ntypes; ++i) {
    induced_nodes[i] = hashmaps[i].Values();
    num_induced_nodes[i] = hashmaps[i].Size();
  }
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90

  // Step 3: Remap the edges of each graph.
  std::vector<HeteroGraphPtr> new_graphs;
  for (size_t i = 0; i < graphs.size(); ++i) {
    std::vector<HeteroGraphPtr> rel_graphs;
    const HeteroGraphPtr curr_graph = graphs[i];
    const auto meta_graph = curr_graph->meta_graph();
    const int64_t num_etypes = curr_graph->NumEdgeTypes();

    for (IdType etype = 0; etype < num_etypes; ++etype) {
      IdType srctype, dsttype;
      std::tie(srctype, dsttype) = curr_graph->GetEndpointTypes(etype);
      const EdgeArray &edges = all_edges[i][etype];

      const IdArray mapped_rows = hashmaps[srctype].Map(edges.src, -1);
      const IdArray mapped_cols = hashmaps[dsttype].Map(edges.dst, -1);

      rel_graphs.push_back(UnitGraph::CreateFromCOO(
          srctype == dsttype ? 1 : 2,
          induced_nodes[srctype]->shape[0],
          induced_nodes[dsttype]->shape[0],
          mapped_rows,
          mapped_cols));
    }

91
    new_graphs.push_back(CreateHeteroGraph(meta_graph, rel_graphs, num_induced_nodes));
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
  }

  return std::make_pair(new_graphs, induced_nodes);
}

};  // namespace

std::pair<std::vector<HeteroGraphPtr>, std::vector<IdArray>>
CompactGraphs(
    const std::vector<HeteroGraphPtr> &graphs,
    const std::vector<IdArray> &always_preserve) {
  std::pair<std::vector<HeteroGraphPtr>, std::vector<IdArray>> result;
  // TODO(BarclayII): check for all IdArrays
  CHECK(graphs[0]->DataType() == always_preserve[0]->dtype) << "data type mismatch.";
  ATEN_ID_TYPE_SWITCH(graphs[0]->DataType(), IdType, {
    result = CompactGraphs<IdType>(graphs, always_preserve);
  });
  return result;
}

// Packed-function entry point for CompactGraphs.
// args[0]: list of graphs; args[1]: list of values wrapping the ID arrays
// that are always preserved.
// Returns a two-element list: [compacted graphs, induced node arrays].
DGL_REGISTER_GLOBAL("transform._CAPI_DGLCompactGraphs")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    List<HeteroGraphRef> graph_list = args[0];
    List<Value> preserve_list = args[1];

    // Unwrap the object references into plain C++ containers.
    std::vector<HeteroGraphPtr> graphs;
    for (HeteroGraphRef ref : graph_list) {
      graphs.push_back(ref.sptr());
    }

    std::vector<IdArray> always_preserve;
    for (Value wrapped : preserve_list) {
      always_preserve.push_back(wrapped->data);
    }

    const auto &compacted = CompactGraphs(graphs, always_preserve);
    const std::vector<HeteroGraphPtr> &out_graphs = compacted.first;
    const std::vector<IdArray> &out_nodes = compacted.second;

    // Wrap the results back into object references.
    List<HeteroGraphRef> compacted_graph_refs;
    for (size_t i = 0; i < out_graphs.size(); ++i) {
      compacted_graph_refs.push_back(HeteroGraphRef(out_graphs[i]));
    }

    List<Value> induced_nodes;
    for (size_t i = 0; i < out_nodes.size(); ++i) {
      induced_nodes.push_back(Value(MakeValue(out_nodes[i])));
    }

    List<ObjectRef> packed;
    packed.push_back(compacted_graph_refs);
    packed.push_back(induced_nodes);

    *rv = packed;
  });

};  // namespace transform

};  // namespace dgl