/**
 *  Copyright 2020-2021 Contributors
 *
 *  Licensed under the Apache License, Version 2.0 (the "License");
 *  you may not use this file except in compliance with the License.
 *  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 *  Unless required by applicable law or agreed to in writing, software
 *  distributed under the License is distributed on an "AS IS" BASIS,
 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  See the License for the specific language governing permissions and
 *  limitations under the License.
 *
 * @file graph/transform/cuda/cuda_map_edges.cuh
 * @brief Device level functions for mapping edges.
 */

#ifndef DGL_GRAPH_TRANSFORM_CUDA_CUDA_MAP_EDGES_CUH_
#define DGL_GRAPH_TRANSFORM_CUDA_CUDA_MAP_EDGES_CUH_

#include <cuda_runtime.h>
#include <dgl/runtime/c_runtime_api.h>

#include <algorithm>
#include <memory>
#include <tuple>
#include <utility>
#include <vector>

#include "../../../runtime/cuda/cuda_common.h"
#include "../../../runtime/cuda/cuda_hashtable.cuh"

using namespace dgl::aten;
using namespace dgl::runtime::cuda;

namespace dgl {
namespace transform {

namespace cuda {

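/**
 * @brief Map each vertex id in global to its local id via table, writing the
 * result to new_global.
 *
 * Each thread block processes one tile of TILE_SIZE ids, with threads
 * striding through the tile in steps of BLOCK_SIZE.
 *
 * @tparam IdType The type of id.
 * @tparam BLOCK_SIZE The size of each thread block.
 * @tparam TILE_SIZE The number of ids to process per thread block.
 * @param global The global vertex ids to map.
 * @param new_global The mapped vertex ids (output).
 * @param num_vertices The number of vertex ids to map.
 * @param table The device-side hash table containing the mapping.
 */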
template <typename IdType, int BLOCK_SIZE, IdType TILE_SIZE>
__device__ void map_vertex_ids(
    const IdType* const global, IdType* const new_global,
    const IdType num_vertices, const DeviceOrderedHashTable<IdType>& table) {
  assert(BLOCK_SIZE == blockDim.x);

  using Mapping = typename OrderedHashTable<IdType>::Mapping;

  const IdType tile_start = TILE_SIZE * blockIdx.x;
  const IdType tile_end = min(TILE_SIZE * (blockIdx.x + 1), num_vertices);

  for (IdType idx = threadIdx.x + tile_start; idx < tile_end;
       idx += BLOCK_SIZE) {
    const Mapping& mapping = *table.Search(global[idx]);
    new_global[idx] = mapping.local;
  }
}

/**
 * @brief Generate mapped edge endpoint ids.
 *
 * @tparam IdType The type of id.
 * @tparam BLOCK_SIZE The size of each thread block.
 * @tparam TILE_SIZE The number of edges to process per thread block.
 * @param global_srcs_device The source ids to map.
 * @param new_global_srcs_device The mapped source ids (output).
 * @param global_dsts_device The destination ids to map.
 * @param new_global_dsts_device The mapped destination ids (output).
 * @param num_edges The number of edges to map.
 * @param src_mapping The mapping of source ids.
 * @param dst_mapping The mapping of destination ids.
 *
 * @note The kernel expects a 2D grid with gridDim.y == 2: blocks with
 *       blockIdx.y == 0 map the source ids, and blocks with blockIdx.y == 1
 *       map the destination ids.
 */
template <typename IdType, int BLOCK_SIZE, IdType TILE_SIZE>
__global__ void map_edge_ids(
    const IdType* const global_srcs_device,
    IdType* const new_global_srcs_device,
    const IdType* const global_dsts_device,
    IdType* const new_global_dsts_device, const IdType num_edges,
    DeviceOrderedHashTable<IdType> src_mapping,
    DeviceOrderedHashTable<IdType> dst_mapping) {
  assert(BLOCK_SIZE == blockDim.x);
  assert(2 == gridDim.y);

  if (blockIdx.y == 0) {
    map_vertex_ids<IdType, BLOCK_SIZE, TILE_SIZE>(
        global_srcs_device, new_global_srcs_device, num_edges, src_mapping);
  } else {
    map_vertex_ids<IdType, BLOCK_SIZE, TILE_SIZE>(
        global_dsts_device, new_global_dsts_device, num_edges, dst_mapping);
  }
}
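
// A minimal launch sketch for the kernel above (illustrative only; the names
// srcs, dsts, src_table, etc. are assumptions, and MapEdges() below is the
// actual call site, which uses CUDA_KERNEL_CALL):
//
//   const dim3 grid(RoundUpDiv(num_edges, TILE_SIZE), 2);
//   const dim3 block(BLOCK_SIZE);
//   map_edge_ids<IdType, BLOCK_SIZE, TILE_SIZE><<<grid, block, 0, stream>>>(
//       srcs, new_srcs, dsts, new_dsts, num_edges,
//       src_table.DeviceHandle(), dst_table.DeviceHandle());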

/**
 * @brief Device level node maps for each node type.
 *
 * @param num_nodes Number of nodes per type.
 * @param offset When offset is 0, LhsHashTable and RhsHashTable refer to
 *        the same tables. Set it to num_nodes.size()/2 to use separate
 *        LhsHashTable and RhsHashTable.
 * @param ctx The DGL context.
 * @param stream The stream to operate on.
 */
template <typename IdType>
class DeviceNodeMap {
 public:
  using Mapping = typename OrderedHashTable<IdType>::Mapping;

  DeviceNodeMap(
      const std::vector<int64_t>& num_nodes, const int64_t offset,
      DGLContext ctx, cudaStream_t stream)
      : num_types_(num_nodes.size()),
        rhs_offset_(offset),
        hash_tables_(),
        ctx_(ctx) {
    auto device = runtime::DeviceAPI::Get(ctx);

    hash_tables_.reserve(num_types_);
    for (int64_t i = 0; i < num_types_; ++i) {
      hash_tables_.emplace_back(
          new OrderedHashTable<IdType>(num_nodes[i], ctx_, stream));
    }
  }

  OrderedHashTable<IdType>& LhsHashTable(const size_t index) {
    return HashData(index);
  }

  OrderedHashTable<IdType>& RhsHashTable(const size_t index) {
    return HashData(index + rhs_offset_);
  }

  const OrderedHashTable<IdType>& LhsHashTable(const size_t index) const {
    return HashData(index);
  }

  const OrderedHashTable<IdType>& RhsHashTable(const size_t index) const {
    return HashData(index + rhs_offset_);
  }

  IdType LhsHashSize(const size_t index) const { return HashSize(index); }

  IdType RhsHashSize(const size_t index) const {
    return HashSize(rhs_offset_ + index);
  }

  size_t Size() const { return hash_tables_.size(); }

 private:
  int64_t num_types_;
  size_t rhs_offset_;
  std::vector<std::unique_ptr<OrderedHashTable<IdType>>> hash_tables_;
  DGLContext ctx_;

  inline OrderedHashTable<IdType>& HashData(const size_t index) {
    CHECK_LT(index, hash_tables_.size());
    return *hash_tables_[index];
  }

  inline const OrderedHashTable<IdType>& HashData(const size_t index) const {
    CHECK_LT(index, hash_tables_.size());
    return *hash_tables_[index];
  }

  inline IdType HashSize(const size_t index) const {
    return HashData(index).size();
  }
};
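
// A minimal usage sketch (illustrative; num_src_nodes, num_dst_nodes, ctx,
// and stream are assumptions, and the tables must still be populated, e.g.
// via OrderedHashTable's fill methods in cuda_hashtable.cuh):
//
//   std::vector<int64_t> num_nodes = {num_src_nodes, num_dst_nodes};
//   // offset = num_nodes.size() / 2 keeps separate lhs and rhs tables.
//   DeviceNodeMap<int64_t> node_map(
//       num_nodes, num_nodes.size() / 2, ctx, stream);
//   OrderedHashTable<int64_t>& src_table = node_map.LhsHashTable(0);
//   OrderedHashTable<int64_t>& dst_table = node_map.RhsHashTable(0);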

template <typename IdType>
inline size_t RoundUpDiv(const IdType num, const size_t divisor) {
  return static_cast<IdType>(num / divisor) + (num % divisor == 0 ? 0 : 1);
}

template <typename IdType>
inline IdType RoundUp(const IdType num, const size_t unit) {
  return RoundUpDiv(num, unit) * unit;
}
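
// For example, RoundUpDiv(10, 4) == 3 and RoundUp(10, 4) == 12.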

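/**
 * @brief Map the source and destination ids of each edge set to their new
 * local ids, using the hash tables in node_map.
 *
 * @tparam IdType The type of id.
 * @param graph The heterogeneous graph the edges belong to.
 * @param edge_sets The edges to map, one EdgeArray per edge type.
 * @param node_map The device node maps, with one hash table per node type.
 * @param stream The stream to operate on.
 *
 * @return The mapped source and destination ids, one IdArray per edge type.
 */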
template <typename IdType>
std::tuple<std::vector<IdArray>, std::vector<IdArray>> MapEdges(
    HeteroGraphPtr graph, const std::vector<EdgeArray>& edge_sets,
    const DeviceNodeMap<IdType>& node_map, cudaStream_t stream) {
  constexpr const int BLOCK_SIZE = 128;
  constexpr const size_t TILE_SIZE = 1024;

  const auto& ctx = graph->Context();

  std::vector<IdArray> new_lhs;
  new_lhs.reserve(edge_sets.size());
  std::vector<IdArray> new_rhs;
  new_rhs.reserve(edge_sets.size());

  // The next performance optimization here is to perform the mapping of all
  // edge types in a single kernel launch.
  const int64_t num_edge_sets = static_cast<int64_t>(edge_sets.size());
  for (int64_t etype = 0; etype < num_edge_sets; ++etype) {
    const EdgeArray& edges = edge_sets[etype];
    if (edges.id.defined() && edges.src->shape[0] > 0) {
      const int64_t num_edges = edges.src->shape[0];

      new_lhs.emplace_back(NewIdArray(num_edges, ctx, sizeof(IdType) * 8));
      new_rhs.emplace_back(NewIdArray(num_edges, ctx, sizeof(IdType) * 8));

      const auto src_dst_types = graph->GetEndpointTypes(etype);
      const int src_type = src_dst_types.first;
      const int dst_type = src_dst_types.second;

      const dim3 grid(RoundUpDiv(num_edges, TILE_SIZE), 2);
      const dim3 block(BLOCK_SIZE);

      // Map the srcs and dsts in one launch: the grid's y-dimension selects
      // between the source ids (blockIdx.y == 0) and the destination ids
      // (blockIdx.y == 1).
      CUDA_KERNEL_CALL(
          (map_edge_ids<IdType, BLOCK_SIZE, TILE_SIZE>), grid, block, 0, stream,
          edges.src.Ptr<IdType>(), new_lhs.back().Ptr<IdType>(),
          edges.dst.Ptr<IdType>(), new_rhs.back().Ptr<IdType>(), num_edges,
          node_map.LhsHashTable(src_type).DeviceHandle(),
          node_map.RhsHashTable(dst_type).DeviceHandle());
    } else {
      new_lhs.emplace_back(
          aten::NullArray(DGLDataType{kDGLInt, sizeof(IdType) * 8, 1}, ctx));
      new_rhs.emplace_back(
          aten::NullArray(DGLDataType{kDGLInt, sizeof(IdType) * 8, 1}, ctx));
    }
  }

  return std::tuple<std::vector<IdArray>, std::vector<IdArray>>(
      std::move(new_lhs), std::move(new_rhs));
}
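
// A minimal call-site sketch (illustrative; assumes a graph, its edge_sets,
// a populated node_map, and a valid CUDA stream):
//
//   std::vector<IdArray> new_srcs, new_dsts;
//   std::tie(new_srcs, new_dsts) =
//       MapEdges<int64_t>(graph, edge_sets, node_map, stream);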

}  // namespace cuda
}  // namespace transform
}  // namespace dgl

#endif  // DGL_GRAPH_TRANSFORM_CUDA_CUDA_MAP_EDGES_CUH_