/**
 *  Copyright 2019-2021 Contributors
 *
 *  Licensed under the Apache License, Version 2.0 (the "License");
 *  you may not use this file except in compliance with the License.
 *  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 *  Unless required by applicable law or agreed to in writing, software
 *  distributed under the License is distributed on an "AS IS" BASIS,
 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  See the License for the specific language governing permissions and
 *  limitations under the License.
 *
 * @file graph/transform/compact.cc
 * @brief Compact graph implementation
 */

#include "compact.h"

#include <dgl/array.h>
#include <dgl/base_heterograph.h>
#include <dgl/packed_func_ext.h>
#include <dgl/runtime/container.h>
#include <dgl/runtime/registry.h>
#include <dgl/transform.h>

#include <tuple>
#include <utility>
#include <vector>

#include "../../c_api_common.h"
#include "../unit_graph.h"
// TODO(BarclayII): currently CompactGraphs depend on IdHashMap implementation
// which only works on CPU.  Should fix later to make it device agnostic.
#include "../../array/cpu/array_utils.h"

namespace dgl {

using namespace dgl::runtime;
using namespace dgl::aten;

namespace transform {

namespace {

47
48
template <typename IdType>
std::pair<std::vector<HeteroGraphPtr>, std::vector<IdArray>> CompactGraphsCPU(
49
50
    const std::vector<HeteroGraphPtr> &graphs,
    const std::vector<IdArray> &always_preserve) {
51
52
  // TODO(BarclayII): check whether the node space and metagraph of each graph
  // is the same. Step 1: Collect the nodes that has connections for each type.
53
54
  const int64_t num_ntypes = graphs[0]->NumVertexTypes();
  std::vector<aten::IdHashMap<IdType>> hashmaps(num_ntypes);
55
56
  std::vector<std::vector<EdgeArray>> all_edges(
      graphs.size());  // all_edges[i][etype]
57

58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
  std::vector<int64_t> max_vertex_cnt(num_ntypes, 0);
  for (size_t i = 0; i < graphs.size(); ++i) {
    const HeteroGraphPtr curr_graph = graphs[i];
    const int64_t num_etypes = curr_graph->NumEdgeTypes();

    for (IdType etype = 0; etype < num_etypes; ++etype) {
      IdType srctype, dsttype;
      std::tie(srctype, dsttype) = curr_graph->GetEndpointTypes(etype);

      const int64_t n_edges = curr_graph->NumEdges(etype);
      max_vertex_cnt[srctype] += n_edges;
      max_vertex_cnt[dsttype] += n_edges;
    }
  }

  // Reserve the space for hash maps before ahead to aoivd rehashing
74
  for (size_t i = 0; i < static_cast<size_t>(num_ntypes); ++i) {
75
76
77
78
79
80
81
    if (i < always_preserve.size())
      hashmaps[i].Reserve(always_preserve[i]->shape[0] + max_vertex_cnt[i]);
    else
      hashmaps[i].Reserve(max_vertex_cnt[i]);
  }

  for (size_t i = 0; i < always_preserve.size(); ++i) {
82
    hashmaps[i].Update(always_preserve[i]);
83
  }
84
85
86
87
88

  for (size_t i = 0; i < graphs.size(); ++i) {
    const HeteroGraphPtr curr_graph = graphs[i];
    const int64_t num_etypes = curr_graph->NumEdgeTypes();

89
    all_edges[i].reserve(num_etypes);
90
91
92
93
94
95
96
97
98
99
100
101
102
    for (IdType etype = 0; etype < num_etypes; ++etype) {
      IdType srctype, dsttype;
      std::tie(srctype, dsttype) = curr_graph->GetEndpointTypes(etype);

      const EdgeArray edges = curr_graph->Edges(etype, "eid");

      hashmaps[srctype].Update(edges.src);
      hashmaps[dsttype].Update(edges.dst);

      all_edges[i].push_back(edges);
    }
  }

103
104
  // Step 2: Relabel the nodes for each type to a smaller ID space and save the
  // mapping.
105
106
107
108
109
110
  std::vector<IdArray> induced_nodes(num_ntypes);
  std::vector<int64_t> num_induced_nodes(num_ntypes);
  for (int64_t i = 0; i < num_ntypes; ++i) {
    induced_nodes[i] = hashmaps[i].Values();
    num_induced_nodes[i] = hashmaps[i].Size();
  }
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128

  // Step 3: Remap the edges of each graph.
  std::vector<HeteroGraphPtr> new_graphs;
  for (size_t i = 0; i < graphs.size(); ++i) {
    std::vector<HeteroGraphPtr> rel_graphs;
    const HeteroGraphPtr curr_graph = graphs[i];
    const auto meta_graph = curr_graph->meta_graph();
    const int64_t num_etypes = curr_graph->NumEdgeTypes();

    for (IdType etype = 0; etype < num_etypes; ++etype) {
      IdType srctype, dsttype;
      std::tie(srctype, dsttype) = curr_graph->GetEndpointTypes(etype);
      const EdgeArray &edges = all_edges[i][etype];

      const IdArray mapped_rows = hashmaps[srctype].Map(edges.src, -1);
      const IdArray mapped_cols = hashmaps[dsttype].Map(edges.dst, -1);

      rel_graphs.push_back(UnitGraph::CreateFromCOO(
129
130
          srctype == dsttype ? 1 : 2, induced_nodes[srctype]->shape[0],
          induced_nodes[dsttype]->shape[0], mapped_rows, mapped_cols));
131
132
    }

133
134
    new_graphs.push_back(
        CreateHeteroGraph(meta_graph, rel_graphs, num_induced_nodes));
135
136
137
138
139
140
141
  }

  return std::make_pair(new_graphs, induced_nodes);
}

};  // namespace

142
template <>
143
std::pair<std::vector<HeteroGraphPtr>, std::vector<IdArray>>
144
CompactGraphs<kDGLCPU, int32_t>(
145
146
    const std::vector<HeteroGraphPtr> &graphs,
    const std::vector<IdArray> &always_preserve) {
147
148
149
  return CompactGraphsCPU<int32_t>(graphs, always_preserve);
}

150
template <>
151
std::pair<std::vector<HeteroGraphPtr>, std::vector<IdArray>>
152
CompactGraphs<kDGLCPU, int64_t>(
153
154
155
    const std::vector<HeteroGraphPtr> &graphs,
    const std::vector<IdArray> &always_preserve) {
  return CompactGraphsCPU<int64_t>(graphs, always_preserve);
156
157
158
}

// C API entry point for graph compaction.
//   args[0]: list of graphs to compact.
//   args[1]: list of node-ID arrays to preserve, one entry per node type index.
// Returns a two-element list: [compacted graphs, induced node IDs per type].
DGL_REGISTER_GLOBAL("transform._CAPI_DGLCompactGraphs")
    .set_body([](DGLArgs args, DGLRetValue *rv) {
      List<HeteroGraphRef> graph_refs = args[0];
      List<Value> always_preserve_refs = args[1];

      std::vector<HeteroGraphPtr> graphs;
      std::vector<IdArray> always_preserve;
      for (HeteroGraphRef gref : graph_refs) graphs.push_back(gref.sptr());
      for (Value array : always_preserve_refs)
        always_preserve.push_back(array->data);

      // graphs[0] decides the dispatch below, so it must exist.
      CHECK(!graphs.empty()) << "CompactGraphs requires at least one graph.";
      const auto idtype = graphs[0]->DataType();

      // Every input is processed as IdType of graphs[0]; a mismatched width
      // would make the ID hash map misread the underlying buffers.
      for (const HeteroGraphPtr &g : graphs) {
        CHECK(g->DataType() == idtype)
            << "all graphs must have the same ID data type.";
      }
      for (const IdArray &arr : always_preserve) {
        CHECK(arr->dtype == idtype) << "data type mismatch.";
      }

      std::pair<std::vector<HeteroGraphPtr>, std::vector<IdArray>> result_pair;

      ATEN_XPU_SWITCH_CUDA(
          graphs[0]->Context().device_type, XPU, "CompactGraphs", {
            ATEN_ID_TYPE_SWITCH(idtype, IdType, {
              result_pair = CompactGraphs<XPU, IdType>(graphs, always_preserve);
            });
          });

      List<HeteroGraphRef> compacted_graph_refs;
      for (const HeteroGraphPtr &g : result_pair.first)
        compacted_graph_refs.push_back(HeteroGraphRef(g));

      List<Value> induced_nodes;
      for (const IdArray &ids : result_pair.second)
        induced_nodes.push_back(Value(MakeValue(ids)));

      List<ObjectRef> result;
      result.push_back(compacted_graph_refs);
      result.push_back(induced_nodes);

      *rv = result;
    });

};  // namespace transform

};  // namespace dgl