// !!! This is a file automatically generated by hipify!!!
#include "hip/hip_runtime.h"
/**
 *  Copyright (c) 2021 by Contributors
 * @file runtime/cuda/cuda_device_common.cuh
 * @brief Device level functions for within cuda kernels.
 */

#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>

#include <hipcub/hipcub.hpp>  // NOLINT

#include "../../array/cuda/atomic.cuh"
#include "cuda_common.h"
#include "cuda_hashtable.cuh"

using namespace dgl::aten::cuda;
17
18
19
20
21
22
23
24
25
26
27

namespace dgl {
namespace runtime {
namespace cuda {

namespace {

constexpr static const int BLOCK_SIZE = 256;
constexpr static const size_t TILE_SIZE = 1024;

/**
28
29
30
31
32
33
 * @brief This is the mutable version of the DeviceOrderedHashTable, for use in
 * inserting elements into the hashtable.
 *
 * @tparam IdType The type of ID to store in the hashtable.
 */
template <typename IdType>
34
35
36
37
38
39
class MutableDeviceOrderedHashTable : public DeviceOrderedHashTable<IdType> {
 public:
  typedef typename DeviceOrderedHashTable<IdType>::Mapping* Iterator;
  static constexpr IdType kEmptyKey = DeviceOrderedHashTable<IdType>::kEmptyKey;

  /**
40
41
42
43
   * @brief Create a new mutable hashtable for use on the device.
   *
   * @param hostTable The original hash table on the host.
   */
44
  explicit MutableDeviceOrderedHashTable(
45
46
      OrderedHashTable<IdType>* const hostTable)
      : DeviceOrderedHashTable<IdType>(hostTable->DeviceHandle()) {}
47
48

  /**
49
50
51
52
53
54
55
56
57
58
   * @brief Find the mutable mapping of a given key within the hash table.
   *
   * WARNING: The key must exist within the hashtable. Searching for a key not
   * in the hashtable is undefined behavior.
   *
   * @param id The key to search for.
   *
   * @return The mapping.
   */
  inline __device__ Iterator Search(const IdType id) {
sangwzh's avatar
sangwzh committed
59
60
61
    // const IdType pos = SearchForPosition(id);
    const IdType pos = DeviceOrderedHashTable<IdType>::SearchForPosition(id);

62
63
64
65
66

    return GetMutable(pos);
  }

  /**
67
   * @brief Attempt to insert into the hash table at a specific location.
68
   *
69
70
71
   * @param pos The position to insert at.
   * @param id The ID to insert into the hash table.
   * @param index The original index of the item being inserted.
72
   *
73
   * @return True, if the insertion was successful.
74
   */
75
  inline __device__ bool AttemptInsertAt(
76
      const size_t pos, const IdType id, const size_t index) {
77
78
79
80
81
    const IdType key = AtomicCAS(&GetMutable(pos)->key, kEmptyKey, id);
    if (key == kEmptyKey || key == id) {
      // we either set a match key, or found a matching key, so then place the
      // minimum index in position. Match the type of atomicMin, so ignore
      // linting
82
      atomicMin(
83
          reinterpret_cast<unsigned long long*>(  // NOLINT
84
85
              &GetMutable(pos)->index),
          static_cast<unsigned long long>(index));  // NOLINT
86
87
88
89
90
91
92
93
      return true;
    } else {
      // we need to search elsewhere
      return false;
    }
  }

  /**
94
95
96
97
98
99
100
101
   * @brief Insert key-index pair into the hashtable.
   *
   * @param id The ID to insert.
   * @param index The index at which the ID occured.
   *
   * @return An iterator to inserted mapping.
   */
  inline __device__ Iterator Insert(const IdType id, const size_t index) {
sangwzh's avatar
sangwzh committed
102
103
104
    // size_t pos = Hash(id);
    size_t pos = DeviceOrderedHashTable<IdType>::Hash(id);

105
106
107
108

    // linearly scan for an empty slot or matching entry
    IdType delta = 1;
    while (!AttemptInsertAt(pos, id, index)) {
sangwzh's avatar
sangwzh committed
109
110
111
      // pos = Hash(pos + delta);
      pos = DeviceOrderedHashTable<IdType>::Hash(pos+delta);

112
      delta += 1;
113
114
115
116
117
118
119
    }

    return GetMutable(pos);
  }

 private:
  /**
120
121
122
123
124
125
   * @brief Get a mutable iterator to the given bucket in the hashtable.
   *
   * @param pos The given bucket.
   *
   * @return The iterator.
   */
126
127
128
129
130
  inline __device__ Iterator GetMutable(const size_t pos) {
    assert(pos < this->size_);
    // The parent class Device is read-only, but we ensure this can only be
    // constructed from a mutable version of OrderedHashTable, making this
    // a safe cast to perform.
131
    return const_cast<Iterator>(this->table_ + pos);
132
133
134
135
  }
};

/**
136
137
138
139
140
141
142
143
144
145
146
147
148
 * @brief Calculate the number of buckets in the hashtable. To guarantee we can
 * fill the hashtable in the worst case, we must use a number of buckets which
 * is a power of two.
 * https://en.wikipedia.org/wiki/Quadratic_probing#Limitations
 *
 * @param num The number of items to insert (should be an upper bound on the
 * number of unique keys).
 * @param scale The power of two larger the number of buckets should be than the
 * unique keys.
 *
 * @return The number of buckets the table should contain.
 */
size_t TableSize(const size_t num, const int scale) {
149
150
151
152
153
  const size_t next_pow2 = 1 << static_cast<size_t>(1 + std::log2(num >> 1));
  return next_pow2 << scale;
}

/**
154
155
156
157
158
159
 * @brief This structure is used with cub's block-level prefixscan in order to
 * keep a running sum as items are iteratively processed.
 *
 * @tparam IdType The type to perform the prefixsum on.
 */
template <typename IdType>
160
161
162
struct BlockPrefixCallbackOp {
  IdType running_total_;

163
164
  __device__ BlockPrefixCallbackOp(const IdType running_total)
      : running_total_(running_total) {}
165
166

  __device__ IdType operator()(const IdType block_aggregate) {
167
168
169
    const IdType old_prefix = running_total_;
    running_total_ += block_aggregate;
    return old_prefix;
170
171
172
173
174
175
  }
};

}  // namespace

/**
176
 * @brief This generates a hash map where the keys are the global item numbers,
177
178
 * and the values are indexes, and inputs may have duplciates.
 *
179
180
181
182
183
184
 * @tparam IdType The type of of id.
 * @tparam BLOCK_SIZE The size of the thread block.
 * @tparam TILE_SIZE The number of entries each thread block will process.
 * @param items The items to insert.
 * @param num_items The number of items to insert.
 * @param table The hash table.
185
186
 */
template <typename IdType, int BLOCK_SIZE, size_t TILE_SIZE>
187
__global__ void generate_hashmap_duplicates(
188
    const IdType* const items, const int64_t num_items,
189
190
191
    MutableDeviceOrderedHashTable<IdType> table) {
  assert(BLOCK_SIZE == blockDim.x);

192
193
  const size_t block_start = TILE_SIZE * blockIdx.x;
  const size_t block_end = TILE_SIZE * (blockIdx.x + 1);
194

195
196
197
#pragma unroll
  for (size_t index = threadIdx.x + block_start; index < block_end;
       index += BLOCK_SIZE) {
198
199
200
201
202
203
204
    if (index < num_items) {
      table.Insert(items[index], index);
    }
  }
}

/**
205
 * @brief This generates a hash map where the keys are the global item numbers,
206
207
 * and the values are indexes, and all inputs are unique.
 *
208
209
210
211
212
213
 * @tparam IdType The type of of id.
 * @tparam BLOCK_SIZE The size of the thread block.
 * @tparam TILE_SIZE The number of entries each thread block will process.
 * @param items The unique items to insert.
 * @param num_items The number of items to insert.
 * @param table The hash table.
214
215
 */
template <typename IdType, int BLOCK_SIZE, size_t TILE_SIZE>
216
__global__ void generate_hashmap_unique(
217
    const IdType* const items, const int64_t num_items,
218
219
220
221
222
    MutableDeviceOrderedHashTable<IdType> table) {
  assert(BLOCK_SIZE == blockDim.x);

  using Iterator = typename MutableDeviceOrderedHashTable<IdType>::Iterator;

223
224
  const size_t block_start = TILE_SIZE * blockIdx.x;
  const size_t block_end = TILE_SIZE * (blockIdx.x + 1);
225

226
227
228
#pragma unroll
  for (size_t index = threadIdx.x + block_start; index < block_end;
       index += BLOCK_SIZE) {
229
230
231
232
233
234
235
236
237
238
239
    if (index < num_items) {
      const Iterator pos = table.Insert(items[index], index);

      // since we are only inserting unique items, we know their local id
      // will be equal to their index
      pos->local = static_cast<IdType>(index);
    }
  }
}

/**
240
 * @brief This counts the number of nodes inserted per thread block.
241
 *
242
243
244
245
246
247
248
 * @tparam IdType The type of of id.
 * @tparam BLOCK_SIZE The size of the thread block.
 * @tparam TILE_SIZE The number of entries each thread block will process.
 * @param input The nodes to insert.
 * @param num_input The number of nodes to insert.
 * @param table The hash table.
 * @param num_unique The number of nodes inserted into the hash table per thread
249
250
251
 * block.
 */
template <typename IdType, int BLOCK_SIZE, size_t TILE_SIZE>
252
__global__ void count_hashmap(
253
254
    const IdType* items, const size_t num_items,
    DeviceOrderedHashTable<IdType> table, IdType* const num_unique) {
255
256
  assert(BLOCK_SIZE == blockDim.x);

sangwzh's avatar
sangwzh committed
257
  using BlockReduce = typename hipcub::BlockReduce<IdType, BLOCK_SIZE>;
258
259
  using Mapping = typename DeviceOrderedHashTable<IdType>::Mapping;

260
261
  const size_t block_start = TILE_SIZE * blockIdx.x;
  const size_t block_end = TILE_SIZE * (blockIdx.x + 1);
262
263
264

  IdType count = 0;

265
266
267
#pragma unroll
  for (size_t index = threadIdx.x + block_start; index < block_end;
       index += BLOCK_SIZE) {
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
    if (index < num_items) {
      const Mapping& mapping = *table.Search(items[index]);
      if (mapping.index == index) {
        ++count;
      }
    }
  }

  __shared__ typename BlockReduce::TempStorage temp_space;

  count = BlockReduce(temp_space).Sum(count);

  if (threadIdx.x == 0) {
    num_unique[blockIdx.x] = count;
    if (blockIdx.x == 0) {
      num_unique[gridDim.x] = 0;
    }
  }
}

/**
289
 * @brief Update the local numbering of elements in the hashmap.
290
 *
291
292
293
294
295
296
297
 * @tparam IdType The type of id.
 * @tparam BLOCK_SIZE The size of the thread blocks.
 * @tparam TILE_SIZE The number of elements each thread block works on.
 * @param items The set of non-unique items to update from.
 * @param num_items The number of non-unique items.
 * @param table The hash table.
 * @param num_items_prefix The number of unique items preceding each thread
298
 * block.
299
300
 * @param unique_items The set of unique items (output).
 * @param num_unique_items The number of unique items (output).
301
302
 */
template <typename IdType, int BLOCK_SIZE, size_t TILE_SIZE>
303
__global__ void compact_hashmap(
304
    const IdType* const items, const size_t num_items,
305
    MutableDeviceOrderedHashTable<IdType> table,
306
307
    const IdType* const num_items_prefix, IdType* const unique_items,
    int64_t* const num_unique_items) {
308
309
310
  assert(BLOCK_SIZE == blockDim.x);

  using FlagType = uint16_t;
sangwzh's avatar
sangwzh committed
311
  using BlockScan = typename hipcub::BlockScan<FlagType, BLOCK_SIZE>;
312
313
314
315
316
317
318
319
320
321
322
323
  using Mapping = typename DeviceOrderedHashTable<IdType>::Mapping;

  constexpr const int32_t VALS_PER_THREAD = TILE_SIZE / BLOCK_SIZE;

  __shared__ typename BlockScan::TempStorage temp_space;

  const IdType offset = num_items_prefix[blockIdx.x];

  BlockPrefixCallbackOp<FlagType> prefix_op(0);

  // count successful placements
  for (int32_t i = 0; i < VALS_PER_THREAD; ++i) {
324
    const IdType index = threadIdx.x + i * BLOCK_SIZE + blockIdx.x * TILE_SIZE;
325
326

    FlagType flag;
327
    Mapping* kv;
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
    if (index < num_items) {
      kv = table.Search(items[index]);
      flag = kv->index == index;
    } else {
      flag = 0;
    }

    if (!flag) {
      kv = nullptr;
    }

    BlockScan(temp_space).ExclusiveSum(flag, flag, prefix_op);
    __syncthreads();

    if (kv) {
343
      const IdType pos = offset + flag;
344
345
346
347
348
349
350
351
352
353
354
355
      kv->local = pos;
      unique_items[pos] = items[index];
    }
  }

  if (threadIdx.x == 0 && blockIdx.x == 0) {
    *num_unique_items = num_items_prefix[gridDim.x];
  }
}

// DeviceOrderedHashTable implementation

356
template <typename IdType>
357
DeviceOrderedHashTable<IdType>::DeviceOrderedHashTable(
358
359
    const Mapping* const table, const size_t size)
    : table_(table), size_(size) {}
360

361
template <typename IdType>
362
363
364
365
366
367
DeviceOrderedHashTable<IdType> OrderedHashTable<IdType>::DeviceHandle() const {
  return DeviceOrderedHashTable<IdType>(table_, size_);
}

// OrderedHashTable implementation

368
template <typename IdType>
369
OrderedHashTable<IdType>::OrderedHashTable(
sangwzh's avatar
sangwzh committed
370
    const size_t size, DGLContext ctx, hipStream_t stream, const int scale)
371
    : table_(nullptr), size_(TableSize(size, scale)), ctx_(ctx) {
372
373
374
375
376
  // make sure we will at least as many buckets as items.
  CHECK_GT(scale, 0);

  auto device = runtime::DeviceAPI::Get(ctx_);
  table_ = static_cast<Mapping*>(
377
      device->AllocWorkspace(ctx_, sizeof(Mapping) * size_));
378

sangwzh's avatar
sangwzh committed
379
  CUDA_CALL(hipMemsetAsync(
380
381
      table_, DeviceOrderedHashTable<IdType>::kEmptyKey,
      sizeof(Mapping) * size_, stream));
382
383
}

384
template <typename IdType>
385
386
387
388
389
OrderedHashTable<IdType>::~OrderedHashTable() {
  auto device = runtime::DeviceAPI::Get(ctx_);
  device->FreeWorkspace(ctx_, table_);
}

390
template <typename IdType>
391
void OrderedHashTable<IdType>::FillWithDuplicates(
392
    const IdType* const input, const size_t num_input, IdType* const unique,
sangwzh's avatar
sangwzh committed
393
    int64_t* const num_unique, hipStream_t stream) {
394
395
  auto device = runtime::DeviceAPI::Get(ctx_);

396
  const int64_t num_tiles = (num_input + TILE_SIZE - 1) / TILE_SIZE;
397
398
399
400
401
402

  const dim3 grid(num_tiles);
  const dim3 block(BLOCK_SIZE);

  auto device_table = MutableDeviceOrderedHashTable<IdType>(this);

403
404
405
  CUDA_KERNEL_CALL(
      (generate_hashmap_duplicates<IdType, BLOCK_SIZE, TILE_SIZE>), grid, block,
      0, stream, input, num_input, device_table);
406

407
408
  IdType* item_prefix = static_cast<IdType*>(
      device->AllocWorkspace(ctx_, sizeof(IdType) * (num_input + 1)));
409

410
411
412
  CUDA_KERNEL_CALL(
      (count_hashmap<IdType, BLOCK_SIZE, TILE_SIZE>), grid, block, 0, stream,
      input, num_input, device_table, item_prefix);
413
414

  size_t workspace_bytes;
sangwzh's avatar
sangwzh committed
415
  CUDA_CALL(hipcub::DeviceScan::ExclusiveSum(
416
417
418
      nullptr, workspace_bytes, static_cast<IdType*>(nullptr),
      static_cast<IdType*>(nullptr), grid.x + 1, stream));
  void* workspace = device->AllocWorkspace(ctx_, workspace_bytes);
419

sangwzh's avatar
sangwzh committed
420
  CUDA_CALL(hipcub::DeviceScan::ExclusiveSum(
421
422
      workspace, workspace_bytes, item_prefix, item_prefix, grid.x + 1,
      stream));
423
424
  device->FreeWorkspace(ctx_, workspace);

425
426
427
  CUDA_KERNEL_CALL(
      (compact_hashmap<IdType, BLOCK_SIZE, TILE_SIZE>), grid, block, 0, stream,
      input, num_input, device_table, item_prefix, unique, num_unique);
428
429
430
  device->FreeWorkspace(ctx_, item_prefix);
}

431
template <typename IdType>
432
void OrderedHashTable<IdType>::FillWithUnique(
sangwzh's avatar
sangwzh committed
433
    const IdType* const input, const size_t num_input, hipStream_t stream) {
434
  const int64_t num_tiles = (num_input + TILE_SIZE - 1) / TILE_SIZE;
435
436
437
438
439
440

  const dim3 grid(num_tiles);
  const dim3 block(BLOCK_SIZE);

  auto device_table = MutableDeviceOrderedHashTable<IdType>(this);

441
442
443
  CUDA_KERNEL_CALL(
      (generate_hashmap_unique<IdType, BLOCK_SIZE, TILE_SIZE>), grid, block, 0,
      stream, input, num_input, device_table);
444
445
446
447
448
449
450
451
}

template class OrderedHashTable<int32_t>;
template class OrderedHashTable<int64_t>;

}  // namespace cuda
}  // namespace runtime
}  // namespace dgl