// !!! This is a file automatically generated by hipify!!!
/**
 *  Copyright 2020-2022 Contributors
 *
 *  Licensed under the Apache License, Version 2.0 (the "License");
 *  you may not use this file except in compliance with the License.
 *  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 *  Unless required by applicable law or agreed to in writing, software
 *  distributed under the License is distributed on an "AS IS" BASIS,
 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  See the License for the specific language governing permissions and
 *  limitations under the License.
 *
 * @file graph/transform/cuda/cuda_map_edges.cuh
 * @brief Device level functions for mapping edges.
 */

#ifndef DGL_GRAPH_TRANSFORM_CUDA_CUDA_MAP_EDGES_CUH_
#define DGL_GRAPH_TRANSFORM_CUDA_CUDA_MAP_EDGES_CUH_

#include <dgl/runtime/c_runtime_api.h>
#include <dgl/base_heterograph.h>
#include <hip/hip_runtime.h>

#include <algorithm>
#include <memory>
#include <tuple>
#include <utility>
#include <vector>

#include "../../../runtime/cuda/cuda_common.h"
#include "../../../runtime/cuda/cuda_hashtable.cuh"

using namespace dgl::aten;
using namespace dgl::runtime::cuda;

namespace dgl {
namespace transform {

namespace cuda {

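/**
 * @brief Device function to map a tile of vertex ids to their new (local)
 * ids via the given hash table. Each thread block cooperatively processes
 * TILE_SIZE ids, with threads striding by BLOCK_SIZE.
 *
 * @tparam IdType The type of id.
 * @tparam BLOCK_SIZE The size of each thread block.
 * @tparam TILE_SIZE The number of ids to process per thread block.
 * @param global The global vertex ids to map.
 * @param new_global The mapped vertex ids (output).
 * @param num_vertices The number of vertices to map.
 * @param table The hash table mapping global ids to local ids.
 */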
template <typename IdType, int BLOCK_SIZE, IdType TILE_SIZE>
__device__ void map_vertex_ids(
    const IdType* const global, IdType* const new_global,
    const IdType num_vertices, const DeviceOrderedHashTable<IdType>& table) {
  assert(BLOCK_SIZE == blockDim.x);

  using Mapping = typename OrderedHashTable<IdType>::Mapping;

  const IdType tile_start = TILE_SIZE * blockIdx.x;
  const IdType tile_end = min(TILE_SIZE * (blockIdx.x + 1), num_vertices);

  for (IdType idx = threadIdx.x + tile_start; idx < tile_end;
       idx += BLOCK_SIZE) {
    const Mapping& mapping = *table.Search(global[idx]);
    new_global[idx] = mapping.local;
  }
}

/**
 * @brief Generate mapped edge endpoint ids.
 *
 * @tparam IdType The type of id.
 * @tparam BLOCK_SIZE The size of each thread block.
 * @tparam TILE_SIZE The number of edges to process per thread block.
 * @param global_srcs_device The source ids to map.
 * @param new_global_srcs_device The mapped source ids (output).
 * @param global_dsts_device The destination ids to map.
 * @param new_global_dsts_device The mapped destination ids (output).
 * @param num_edges The number of edges to map.
 * @param src_mapping The mapping of source ids.
 * @param dst_mapping The mapping of destination ids.
 */
template <typename IdType, int BLOCK_SIZE, IdType TILE_SIZE>
__global__ void map_edge_ids(
    const IdType* const global_srcs_device,
    IdType* const new_global_srcs_device,
    const IdType* const global_dsts_device,
    IdType* const new_global_dsts_device, const IdType num_edges,
    DeviceOrderedHashTable<IdType> src_mapping,
    DeviceOrderedHashTable<IdType> dst_mapping) {
  assert(BLOCK_SIZE == blockDim.x);
  assert(2 == gridDim.y);

  if (blockIdx.y == 0) {
    map_vertex_ids<IdType, BLOCK_SIZE, TILE_SIZE>(
        global_srcs_device, new_global_srcs_device, num_edges, src_mapping);
  } else {
    map_vertex_ids<IdType, BLOCK_SIZE, TILE_SIZE>(
        global_dsts_device, new_global_dsts_device, num_edges, dst_mapping);
  }
}
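// Note: map_edge_ids must be launched with blockDim.x == BLOCK_SIZE and
// gridDim.y == 2 (asserted above); blockIdx.y selects whether a block maps
// the source endpoints (y == 0) or the destination endpoints (y == 1).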

/**
 * @brief Device level node maps for each node type.
 *
 * @param num_nodes Number of nodes per type.
 * @param offset When offset is set to 0, LhsHashTable is identical to
 *        RhsHashTable. Set it to num_nodes.size()/2 to use separate
 *        LhsHashTable and RhsHashTable.
 * @param ctx The DGL context.
 * @param stream The stream to operate on.
 */
template <typename IdType>
class DeviceNodeMap {
 public:
  using Mapping = typename OrderedHashTable<IdType>::Mapping;

  DeviceNodeMap(
      const std::vector<int64_t>& num_nodes, const int64_t offset,
      DGLContext ctx, hipStream_t stream)
      : num_types_(num_nodes.size()),
        rhs_offset_(offset),
        hash_tables_(),
        ctx_(ctx) {
    auto device = runtime::DeviceAPI::Get(ctx);

    hash_tables_.reserve(num_types_);
    for (int64_t i = 0; i < num_types_; ++i) {
      hash_tables_.emplace_back(
          new OrderedHashTable<IdType>(num_nodes[i], ctx_, stream));
    }
  }

  OrderedHashTable<IdType>& LhsHashTable(const size_t index) {
    return HashData(index);
  }

  OrderedHashTable<IdType>& RhsHashTable(const size_t index) {
    return HashData(index + rhs_offset_);
  }

  const OrderedHashTable<IdType>& LhsHashTable(const size_t index) const {
    return HashData(index);
  }

  const OrderedHashTable<IdType>& RhsHashTable(const size_t index) const {
    return HashData(index + rhs_offset_);
  }

  IdType LhsHashSize(const size_t index) const { return HashSize(index); }

  IdType RhsHashSize(const size_t index) const {
    return HashSize(rhs_offset_ + index);
  }

  size_t Size() const { return hash_tables_.size(); }

 private:
  int64_t num_types_;
  size_t rhs_offset_;
  std::vector<std::unique_ptr<OrderedHashTable<IdType>>> hash_tables_;
  DGLContext ctx_;

  inline OrderedHashTable<IdType>& HashData(const size_t index) {
    CHECK_LT(index, hash_tables_.size());
    return *hash_tables_[index];
  }

  inline const OrderedHashTable<IdType>& HashData(const size_t index) const {
    CHECK_LT(index, hash_tables_.size());
    return *hash_tables_[index];
  }

  inline IdType HashSize(const size_t index) const {
    return HashData(index).size();
  }
};
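// A minimal usage sketch (illustrative only; `num_lhs_nodes`,
// `num_rhs_nodes`, `ctx`, and `stream` are assumed to come from the caller).
// With offset == num_nodes.size() / 2, the first half of the tables backs
// LhsHashTable and the second half backs RhsHashTable; with offset == 0,
// both accessors resolve to the same tables:
//
//   std::vector<int64_t> num_nodes = {num_lhs_nodes, num_rhs_nodes};
//   DeviceNodeMap<int64_t> node_map(
//       num_nodes, num_nodes.size() / 2, ctx, stream);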

template <typename IdType>
inline size_t RoundUpDiv(const IdType num, const size_t divisor) {
  return static_cast<IdType>(num / divisor) + (num % divisor == 0 ? 0 : 1);
}

template <typename IdType>
inline IdType RoundUp(const IdType num, const size_t unit) {
  return RoundUpDiv(num, unit) * unit;
}
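// For example, RoundUpDiv(1025, 1024) == 2 and RoundUp(1025, 1024) == 2048;
// MapEdges below uses RoundUpDiv to size the kernel grid in TILE_SIZE tiles.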

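/**
 * @brief Map the source and destination ids of each edge set to their new
 * (local) ids, with one kernel launch per edge type.
 *
 * @tparam IdType The type of id.
 * @param graph The graph the edges belong to.
 * @param edge_sets The edges to map, one EdgeArray per edge type.
 * @param node_map The node map to translate ids with.
 * @param stream The stream to operate on.
 *
 * @return The mapped source and destination ids, per edge type.
 */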
template <typename IdType>
std::tuple<std::vector<IdArray>, std::vector<IdArray>> MapEdges(
    HeteroGraphPtr graph, const std::vector<EdgeArray>& edge_sets,
    const DeviceNodeMap<IdType>& node_map, hipStream_t stream) {
  constexpr const int BLOCK_SIZE = 128;
  constexpr const size_t TILE_SIZE = 1024;

  const auto& ctx = graph->Context();

  std::vector<IdArray> new_lhs;
  new_lhs.reserve(edge_sets.size());
  std::vector<IdArray> new_rhs;
  new_rhs.reserve(edge_sets.size());

  // The next performance optimization here is to perform the mapping of all
  // edge types in a single kernel launch.
  const int64_t num_edge_sets = static_cast<int64_t>(edge_sets.size());
  for (int64_t etype = 0; etype < num_edge_sets; ++etype) {
    const EdgeArray& edges = edge_sets[etype];
    if (edges.id.defined() && edges.src->shape[0] > 0) {
      const int64_t num_edges = edges.src->shape[0];

      new_lhs.emplace_back(NewIdArray(num_edges, ctx, sizeof(IdType) * 8));
      new_rhs.emplace_back(NewIdArray(num_edges, ctx, sizeof(IdType) * 8));

      const auto src_dst_types = graph->GetEndpointTypes(etype);
      const int src_type = src_dst_types.first;
      const int dst_type = src_dst_types.second;

      const dim3 grid(RoundUpDiv(num_edges, TILE_SIZE), 2);
      const dim3 block(BLOCK_SIZE);

      // Map both the srcs and dsts in a single kernel launch (gridDim.y == 2).
      CUDA_KERNEL_CALL(
          (map_edge_ids<IdType, BLOCK_SIZE, TILE_SIZE>), grid, block, 0, stream,
          edges.src.Ptr<IdType>(), new_lhs.back().Ptr<IdType>(),
          edges.dst.Ptr<IdType>(), new_rhs.back().Ptr<IdType>(), num_edges,
          node_map.LhsHashTable(src_type).DeviceHandle(),
          node_map.RhsHashTable(dst_type).DeviceHandle());
    } else {
      new_lhs.emplace_back(
          aten::NullArray(DGLDataType{kDGLInt, sizeof(IdType) * 8, 1}, ctx));
      new_rhs.emplace_back(
          aten::NullArray(DGLDataType{kDGLInt, sizeof(IdType) * 8, 1}, ctx));
    }
  }

  return std::tuple<std::vector<IdArray>, std::vector<IdArray>>(
      std::move(new_lhs), std::move(new_rhs));
}
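// Illustrative call site (assumes `graph`, its per-edge-type `edge_sets`, a
// populated `node_map`, and `stream` come from the caller):
//
//   std::vector<IdArray> new_lhs, new_rhs;
//   std::tie(new_lhs, new_rhs) =
//       MapEdges<int64_t>(graph, edge_sets, node_map, stream);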

}  // namespace cuda
}  // namespace transform
}  // namespace dgl

#endif  // DGL_GRAPH_TRANSFORM_CUDA_CUDA_MAP_EDGES_CUH_