to_bipartite.cc 8.24 KB
Newer Older
1
/**
 *  Copyright 2019-2021 Contributors
 *
 *  Licensed under the Apache License, Version 2.0 (the "License");
 *  you may not use this file except in compliance with the License.
 *  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 *  Unless required by applicable law or agreed to in writing, software
 *  distributed under the License is distributed on an "AS IS" BASIS,
 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  See the License for the specific language governing permissions and
 *  limitations under the License.
 *
 * @file graph/transform/to_bipartite.cc
 * @brief Convert a graph to a bipartite-structured graph.
 */

20
21
#include "to_bipartite.h"

22
#include <dgl/array.h>
23
#include <dgl/base_heterograph.h>
24
#include <dgl/immutable_graph.h>
25
#include <dgl/packed_func_ext.h>
26
#include <dgl/runtime/container.h>
27
28
29
#include <dgl/runtime/registry.h>
#include <dgl/transform.h>

30
#include <tuple>
31
#include <utility>
32
33
#include <vector>

34
35
36
37
38
39
40
41
42
43
44
#include "../../array/cpu/array_utils.h"

namespace dgl {

using namespace dgl::runtime;
using namespace dgl::aten;

namespace transform {

namespace {

45
// Since partial specialization is not allowed for functions, use this as an
46
// intermediate for ToBlock where XPU = kDGLCPU.
47
48
49
50
51
template <typename IdType>
std::tuple<HeteroGraphPtr, std::vector<IdArray>> ToBlockCPU(
    HeteroGraphPtr graph, const std::vector<IdArray> &rhs_nodes,
    bool include_rhs_in_lhs, std::vector<IdArray> *const lhs_nodes_ptr) {
  std::vector<IdArray> &lhs_nodes = *lhs_nodes_ptr;
52
53
  const bool generate_lhs_nodes = lhs_nodes.empty();

54
55
56
57
58
  const int64_t num_etypes = graph->NumEdgeTypes();
  const int64_t num_ntypes = graph->NumVertexTypes();
  std::vector<EdgeArray> edge_arrays(num_etypes);

  CHECK(rhs_nodes.size() == static_cast<size_t>(num_ntypes))
59
      << "rhs_nodes not given for every node type";
60

61
62
  const std::vector<IdHashMap<IdType>> rhs_node_mappings(
      rhs_nodes.begin(), rhs_nodes.end());
63
64
  std::vector<IdHashMap<IdType>> lhs_node_mappings;

65
  if (generate_lhs_nodes) {
66
    // build lhs_node_mappings -- if we don't have them already
67
68
69
70
71
    if (include_rhs_in_lhs)
      lhs_node_mappings = rhs_node_mappings;  // copy
    else
      lhs_node_mappings.resize(num_ntypes);
  } else {
72
73
    lhs_node_mappings =
        std::vector<IdHashMap<IdType>>(lhs_nodes.begin(), lhs_nodes.end());
74
  }
75

76
77
78
79
  for (int64_t etype = 0; etype < num_etypes; ++etype) {
    const auto src_dst_types = graph->GetEndpointTypes(etype);
    const dgl_type_t srctype = src_dst_types.first;
    const dgl_type_t dsttype = src_dst_types.second;
80
    if (!aten::IsNullArray(rhs_nodes[dsttype])) {
81
      const EdgeArray &edges = graph->Edges(etype);
82
83
84
      if (generate_lhs_nodes) {
        lhs_node_mappings[srctype].Update(edges.src);
      }
85
86
      edge_arrays[etype] = edges;
    }
87
88
  }

89
90
91
  std::vector<int64_t> num_nodes_per_type;
  num_nodes_per_type.reserve(2 * num_ntypes);

92
93
94
  const auto meta_graph = graph->meta_graph();
  const EdgeArray etypes = meta_graph->Edges("eid");
  const IdArray new_dst = Add(etypes.dst, num_ntypes);
95
96
  const auto new_meta_graph =
      ImmutableGraph::CreateFromCOO(num_ntypes * 2, etypes.src, new_dst);
97
98
99
100
101
102
103
104
105
106
107
108
109
110

  for (int64_t ntype = 0; ntype < num_ntypes; ++ntype)
    num_nodes_per_type.push_back(lhs_node_mappings[ntype].Size());
  for (int64_t ntype = 0; ntype < num_ntypes; ++ntype)
    num_nodes_per_type.push_back(rhs_node_mappings[ntype].Size());

  std::vector<HeteroGraphPtr> rel_graphs;
  std::vector<IdArray> induced_edges;
  for (int64_t etype = 0; etype < num_etypes; ++etype) {
    const auto src_dst_types = graph->GetEndpointTypes(etype);
    const dgl_type_t srctype = src_dst_types.first;
    const dgl_type_t dsttype = src_dst_types.second;
    const IdHashMap<IdType> &lhs_map = lhs_node_mappings[srctype];
    const IdHashMap<IdType> &rhs_map = rhs_node_mappings[dsttype];
111
112
113
    if (rhs_map.Size() == 0) {
      // No rhs nodes are given for this edge type. Create an empty graph.
      rel_graphs.push_back(CreateFromCOO(
114
115
          2, lhs_map.Size(), rhs_map.Size(), aten::NullArray(),
          aten::NullArray()));
116
117
118
119
120
121
122
      induced_edges.push_back(aten::NullArray());
    } else {
      IdArray new_src = lhs_map.Map(edge_arrays[etype].src, -1);
      IdArray new_dst = rhs_map.Map(edge_arrays[etype].dst, -1);
      // Check whether there are unmapped IDs and raise error.
      for (int64_t i = 0; i < new_dst->shape[0]; ++i)
        CHECK_NE(new_dst.Ptr<IdType>()[i], -1)
123
124
125
126
127
128
            << "Node " << edge_arrays[etype].dst.Ptr<IdType>()[i]
            << " does not exist"
            << " in `rhs_nodes`. Argument `rhs_nodes` must contain all the edge"
            << " destination nodes.";
      rel_graphs.push_back(
          CreateFromCOO(2, lhs_map.Size(), rhs_map.Size(), new_src, new_dst));
129
130
      induced_edges.push_back(edge_arrays[etype].id);
    }
131
132
  }

133
134
  const HeteroGraphPtr new_graph =
      CreateHeteroGraph(new_meta_graph, rel_graphs, num_nodes_per_type);
135
136
137

  if (generate_lhs_nodes) {
    CHECK_EQ(lhs_nodes.size(), 0) << "InteralError: lhs_nodes should be empty "
138
                                     "when generating it.";
139
140
141
142
    for (const IdHashMap<IdType> &lhs_map : lhs_node_mappings)
      lhs_nodes.push_back(lhs_map.Values());
  }
  return std::make_tuple(new_graph, induced_edges);
143
144
}

145
}  // namespace
146

147
148
149
150
// CPU specialization of ToBlock for graphs whose ID type is int32_t; all the
// work is done by ToBlockCPU<int32_t>.
template <>
std::tuple<HeteroGraphPtr, std::vector<IdArray>> ToBlock<kDGLCPU, int32_t>(
    HeteroGraphPtr graph, const std::vector<IdArray> &rhs_nodes,
    bool include_rhs_in_lhs, std::vector<IdArray> *const lhs_nodes) {
  // `graph` is a by-value shared pointer not used again here; moving it avoids
  // a redundant atomic reference-count increment/decrement.
  return ToBlockCPU<int32_t>(
      std::move(graph), rhs_nodes, include_rhs_in_lhs, lhs_nodes);
}

154
155
156
157
// CPU specialization of ToBlock for graphs whose ID type is int64_t; all the
// work is done by ToBlockCPU<int64_t>.
template <>
std::tuple<HeteroGraphPtr, std::vector<IdArray>> ToBlock<kDGLCPU, int64_t>(
    HeteroGraphPtr graph, const std::vector<IdArray> &rhs_nodes,
    bool include_rhs_in_lhs, std::vector<IdArray> *const lhs_nodes) {
  // `graph` is a by-value shared pointer not used again here; moving it avoids
  // a redundant atomic reference-count increment/decrement.
  return ToBlockCPU<int64_t>(
      std::move(graph), rhs_nodes, include_rhs_in_lhs, lhs_nodes);
}

161
162
#ifdef DGL_USE_CUDA

163
164
// Forward declarations of the GPU ToBlock implementations; their definitions
// live in ./cuda/cuda_to_block.cu.
//
// They are exposed as two distinctly named, non-template functions rather
// than as ToBlock<kDGLCUDA, ...> specializations to work around the broken
// name mangling in VS2019 CL 16.5.5 + CUDA 11.3, which complains that the two
// template specializations have the same signature.

// GPU implementation for int32_t IDs.
std::tuple<HeteroGraphPtr, std::vector<IdArray>> ToBlockGPU32(
    HeteroGraphPtr graph, const std::vector<IdArray> &rhs_nodes,
    bool include_rhs_in_lhs, std::vector<IdArray> *const lhs_nodes);

// GPU implementation for int64_t IDs.
std::tuple<HeteroGraphPtr, std::vector<IdArray>> ToBlockGPU64(
    HeteroGraphPtr graph, const std::vector<IdArray> &rhs_nodes,
    bool include_rhs_in_lhs, std::vector<IdArray> *const lhs_nodes);

// CUDA specialization of ToBlock for graphs whose ID type is int32_t; forwards
// to ToBlockGPU32 (see the forward declarations for why it is not a template).
template <>
std::tuple<HeteroGraphPtr, std::vector<IdArray>> ToBlock<kDGLCUDA, int32_t>(
    HeteroGraphPtr graph, const std::vector<IdArray> &rhs_nodes,
    bool include_rhs_in_lhs, std::vector<IdArray> *const lhs_nodes) {
  // `graph` is a by-value shared pointer not used again here; moving it avoids
  // a redundant atomic reference-count increment/decrement.
  return ToBlockGPU32(
      std::move(graph), rhs_nodes, include_rhs_in_lhs, lhs_nodes);
}

183
184
185
186
// CUDA specialization of ToBlock for graphs whose ID type is int64_t; forwards
// to ToBlockGPU64 (see the forward declarations for why it is not a template).
template <>
std::tuple<HeteroGraphPtr, std::vector<IdArray>> ToBlock<kDGLCUDA, int64_t>(
    HeteroGraphPtr graph, const std::vector<IdArray> &rhs_nodes,
    bool include_rhs_in_lhs, std::vector<IdArray> *const lhs_nodes) {
  // `graph` is a by-value shared pointer not used again here; moving it avoids
  // a redundant atomic reference-count increment/decrement.
  return ToBlockGPU64(
      std::move(graph), rhs_nodes, include_rhs_in_lhs, lhs_nodes);
}

#endif  // DGL_USE_CUDA

192
DGL_REGISTER_GLOBAL("transform._CAPI_DGLToBlock")
    .set_body([](DGLArgs args, DGLRetValue *rv) {
      // Packed arguments:
      //   args[0] graph, args[1] rhs node IDs per node type,
      //   args[2] include_rhs_in_lhs flag, args[3] lhs node IDs per node type.
      // When args[3] is empty, ToBlock is expected to fill it in (the CPU
      // implementation in this file does).
      const HeteroGraphRef graph_ref = args[0];
      const std::vector<IdArray> rhs_nodes = ListValueToVector<IdArray>(args[1]);
      const bool include_rhs_in_lhs = args[2];
      std::vector<IdArray> lhs_nodes = ListValueToVector<IdArray>(args[3]);

      std::vector<IdArray> induced_edges;
      HeteroGraphPtr new_graph;

      // Dispatch on the graph's device and ID type.
      ATEN_XPU_SWITCH_CUDA(graph_ref->Context().device_type, XPU, "ToBlock", {
        ATEN_ID_TYPE_SWITCH(graph_ref->DataType(), IdType, {
          std::tie(new_graph, induced_edges) = ToBlock<XPU, IdType>(
              graph_ref.sptr(), rhs_nodes, include_rhs_in_lhs, &lhs_nodes);
        });
      });

      // Wrap every array in a Value so the result can cross the FFI as a List.
      const auto wrap_arrays = [](const std::vector<IdArray> &arrays) {
        List<Value> wrapped;
        for (const IdArray &arr : arrays)
          wrapped.push_back(Value(MakeValue(arr)));
        return wrapped;
      };
      const List<Value> lhs_nodes_list = wrap_arrays(lhs_nodes);
      const List<Value> induced_edges_list = wrap_arrays(induced_edges);

      // Result layout: [block graph, lhs node IDs, induced edge IDs].
      List<ObjectRef> ret;
      ret.push_back(HeteroGraphRef(new_graph));
      ret.push_back(lhs_nodes_list);
      ret.push_back(induced_edges_list);
      *rv = ret;
    });
224
225
226
227

};  // namespace transform

};  // namespace dgl