/*!
 *  Copyright (c) 2021 by Contributors
 * \file runtime/cuda/cuda_hashtable.cu
 * \brief GPU ordered hash table used to map (possibly duplicated) IDs to
 * unique, compactly numbered local IDs.
 */

#include <cassert>

9
#include "../../array/cuda/atomic.cuh"
10
#include "../../array/cuda/dgl_cub.cuh"
11
12
#include "cuda_common.h"
#include "cuda_hashtable.cuh"
13

14
using namespace dgl::aten::cuda;
15
16
17
18
19
20
21
22
23
24
25

namespace dgl {
namespace runtime {
namespace cuda {

namespace {

// Threads per block for every kernel launched from this file.
constexpr int BLOCK_SIZE = 256;
// Number of input items assigned to (processed by) each thread block.
constexpr size_t TILE_SIZE = 1024;

/**
 * @brief This is the mutable version of the DeviceOrderedHashTable, for use in
 * inserting elements into the hashtable.
 *
 * It can only be constructed from a host-side OrderedHashTable, and it views
 * the same bucket array as that table's DeviceHandle().
 *
 * @tparam IdType The type of ID to store in the hashtable.
 */
template <typename IdType>
class MutableDeviceOrderedHashTable : public DeviceOrderedHashTable<IdType> {
 public:
  typedef typename DeviceOrderedHashTable<IdType>::Mapping* Iterator;
  static constexpr IdType kEmptyKey = DeviceOrderedHashTable<IdType>::kEmptyKey;

  /**
   * @brief Create a new mutable hashtable for use on the device.
   *
   * @param hostTable The original hash table on the host.
   */
  explicit MutableDeviceOrderedHashTable(
      OrderedHashTable<IdType>* const hostTable)
      : DeviceOrderedHashTable<IdType>(hostTable->DeviceHandle()) {}

  /**
   * @brief Find the mutable mapping of a given key within the hash table.
   *
   * WARNING: The key must exist within the hashtable. Searching for a key not
   * in the hashtable is undefined behavior.
   *
   * @param id The key to search for.
   *
   * @return The mapping.
   */
  inline __device__ Iterator Search(const IdType id) {
    // SearchForPosition() is inherited from a dependent base class, so it is
    // named through `this->`; unqualified, standard two-phase lookup does not
    // find it.
    const IdType pos = this->SearchForPosition(id);

    return GetMutable(pos);
  }

  /**
   * @brief Attempt to insert into the hash table at a specific location.
   *
   * @param pos The position to insert at.
   * @param id The ID to insert into the hash table.
   * @param index The original index of the item being inserted.
   *
   * @return True, if the bucket at `pos` now holds `id` (it was empty and got
   * claimed, or it already held `id`); false if it holds a different key.
   */
  inline __device__ bool AttemptInsertAt(
      const size_t pos, const IdType id, const size_t index) {
    const IdType key = AtomicCAS(&GetMutable(pos)->key, kEmptyKey, id);
    if (key == kEmptyKey || key == id) {
      // We either claimed an empty bucket or found the bucket already holding
      // `id`; in both cases keep the minimum index seen for this key. The
      // casts match the signature of atomicMin, so ignore linting.
      atomicMin(
          reinterpret_cast<unsigned long long*>(    // NOLINT
              &GetMutable(pos)->index),
          static_cast<unsigned long long>(index));  // NOLINT
      return true;
    }
    // The bucket is occupied by another key; the caller must probe elsewhere.
    return false;
  }

  /**
   * @brief Insert key-index pair into the hashtable.
   *
   * @param id The ID to insert.
   * @param index The index at which the ID occurred.
   *
   * @return An iterator to inserted mapping.
   */
  inline __device__ Iterator Insert(const IdType id, const size_t index) {
    // Hash() is inherited from a dependent base class, hence `this->` (see
    // Search()).
    size_t pos = this->Hash(id);

    // Probe with a step that grows by one each time (1, 2, 3, ...), so the
    // probed offsets from the home bucket are the triangular numbers, i.e.
    // quadratic probing, which TableSize() sizes the table for. Presumably
    // Hash() reduces modulo the table size -- confirm in cuda_hashtable.cuh.
    IdType delta = 1;
    while (!AttemptInsertAt(pos, id, index)) {
      pos = this->Hash(pos + delta);
      delta += 1;
    }

    return GetMutable(pos);
  }

 private:
  /**
   * @brief Get a mutable iterator to the given bucket in the hashtable.
   *
   * @param pos The given bucket.
   *
   * @return The iterator.
   */
  inline __device__ Iterator GetMutable(const size_t pos) {
    assert(pos < this->size_);
    // The parent class Device is read-only, but we ensure this can only be
    // constructed from a mutable version of OrderedHashTable, making this
    // a safe cast to perform.
    return const_cast<Iterator>(this->table_ + pos);
  }
};

/**
 * @brief Calculate the number of buckets in the hashtable. To guarantee we can
 * fill the hashtable in the worst case, we must use a number of buckets which
 * is a power of two.
 * https://en.wikipedia.org/wiki/Quadratic_probing#Limitations
 *
 * The base size is the smallest power of two strictly greater than `num / 2`
 * (integer division), which is never smaller than `num / 2`; for `num <= 1` it
 * is 1. The result is the base size multiplied by `2^scale`.
 *
 * Everything is computed in `size_t` integer arithmetic. Shifting an `int`
 * literal would overflow once 2^31 buckets are needed. Going through
 * `std::log2` would evaluate `log2(0) == -inf` for `num <= 1`, and converting
 * that to `size_t` is undefined.
 *
 * @param num The number of items to insert (should be an upper bound on the
 * number of unique keys).
 * @param scale The power of two larger the number of buckets should be than the
 * unique keys (expected to be non-negative).
 *
 * @return The number of buckets the table should contain.
 */
size_t TableSize(const size_t num, const int scale) {
  const size_t half = num >> 1;
  // Smallest power of two > half. Terminates without overflow because half is
  // at most 2^63 - 1 for a 64-bit size_t.
  size_t next_pow2 = 1;
  while (next_pow2 <= half) {
    next_pow2 <<= 1;
  }
  return next_pow2 << scale;
}

/**
 * @brief Stateful block-prefix callback for cub's block-level scans. It carries
 * a running total across successive scans, so that items processed over
 * several iterations receive one continuous prefix sum.
 *
 * Each call receives the aggregate of the current scan and returns the prefix
 * to apply to that scan (the total of all earlier scans).
 *
 * @tparam IdType The type to perform the prefixsum on.
 */
template <typename IdType>
struct BlockPrefixCallbackOp {
  // Sum of every block aggregate folded in so far.
  IdType running_total_;

  __device__ BlockPrefixCallbackOp(const IdType running_total)
      : running_total_{running_total} {}

  /**
   * @brief Return the total accumulated before this scan, then add this
   * scan's aggregate to it.
   */
  __device__ IdType operator()(const IdType block_aggregate) {
    const IdType prefix = running_total_;
    running_total_ = prefix + block_aggregate;
    return prefix;
  }
};

}  // namespace

/**
 * \brief Insert items into a hash map whose keys are the global item numbers
 * and whose values are indexes. The input may contain duplicates; for a
 * repeated key the table keeps the smallest index it occurs at.
 *
 * Each thread block handles one tile of TILE_SIZE consecutive items, with its
 * threads striding through the tile by BLOCK_SIZE.
 *
 * \tparam IdType The type of id.
 * \tparam BLOCK_SIZE The size of the thread block.
 * \tparam TILE_SIZE The number of entries each thread block will process.
 * \param items The items to insert.
 * \param num_items The number of items to insert.
 * \param table The hash table.
 */
template <typename IdType, int BLOCK_SIZE, size_t TILE_SIZE>
__global__ void generate_hashmap_duplicates(
    const IdType* const items, const int64_t num_items,
    MutableDeviceOrderedHashTable<IdType> table) {
  assert(BLOCK_SIZE == blockDim.x);

  // This block owns the half-open item range [tile_begin, tile_end).
  const size_t tile_begin = TILE_SIZE * blockIdx.x;
  const size_t tile_end = tile_begin + TILE_SIZE;

#pragma unroll
  for (size_t index = tile_begin + threadIdx.x; index < tile_end;
       index += BLOCK_SIZE) {
    // The last tile may reach past the end of the input.
    if (index >= num_items) {
      continue;
    }
    table.Insert(items[index], index);
  }
}

/**
 * \brief Insert items into a hash map whose keys are the global item numbers
 * and whose values are indexes, for input in which every item is unique.
 * Because there are no duplicates, each item's local id is set directly to its
 * index.
 *
 * Each thread block handles one tile of TILE_SIZE consecutive items, with its
 * threads striding through the tile by BLOCK_SIZE.
 *
 * \tparam IdType The type of id.
 * \tparam BLOCK_SIZE The size of the thread block.
 * \tparam TILE_SIZE The number of entries each thread block will process.
 * \param items The unique items to insert.
 * \param num_items The number of items to insert.
 * \param table The hash table.
 */
template <typename IdType, int BLOCK_SIZE, size_t TILE_SIZE>
__global__ void generate_hashmap_unique(
    const IdType* const items, const int64_t num_items,
    MutableDeviceOrderedHashTable<IdType> table) {
  assert(BLOCK_SIZE == blockDim.x);

  using Iterator = typename MutableDeviceOrderedHashTable<IdType>::Iterator;

  // This block owns the half-open item range [tile_begin, tile_end).
  const size_t tile_begin = TILE_SIZE * blockIdx.x;
  const size_t tile_end = tile_begin + TILE_SIZE;

#pragma unroll
  for (size_t index = tile_begin + threadIdx.x; index < tile_end;
       index += BLOCK_SIZE) {
    // The last tile may reach past the end of the input.
    if (index >= num_items) {
      continue;
    }
    const Iterator entry = table.Insert(items[index], index);
    // Unique input: the local id is simply the item's own index.
    entry->local = static_cast<IdType>(index);
  }
}

/**
 * \brief Count, per thread block, the items that are the first occurrence of
 * their key, i.e. whose index equals the index stored in the table for that
 * key.
 *
 * \tparam IdType The type of id.
 * \tparam BLOCK_SIZE The size of the thread block.
 * \tparam TILE_SIZE The number of entries each thread block will process.
 * \param items The items that were inserted into the table.
 * \param num_items The number of items.
 * \param table The hash table (every item must already be present in it).
 * \param num_unique Output with gridDim.x + 1 entries. Entry b receives the
 * count for block b, and block 0 sets entry gridDim.x to zero so that an
 * exclusive scan over all entries leaves the grand total in the last one.
 */
template <typename IdType, int BLOCK_SIZE, size_t TILE_SIZE>
__global__ void count_hashmap(
    const IdType* items, const size_t num_items,
    DeviceOrderedHashTable<IdType> table, IdType* const num_unique) {
  assert(BLOCK_SIZE == blockDim.x);

  using BlockReduce = typename cub::BlockReduce<IdType, BLOCK_SIZE>;
  using Mapping = typename DeviceOrderedHashTable<IdType>::Mapping;

  // This block owns the half-open item range [tile_begin, tile_end).
  const size_t tile_begin = TILE_SIZE * blockIdx.x;
  const size_t tile_end = tile_begin + TILE_SIZE;

  // First occurrences seen by this thread.
  IdType local_count = 0;

#pragma unroll
  for (size_t index = tile_begin + threadIdx.x; index < tile_end;
       index += BLOCK_SIZE) {
    if (index >= num_items) {
      continue;
    }
    const Mapping& mapping = *table.Search(items[index]);
    // Only the occurrence whose index the table retained is counted.
    if (mapping.index == index) {
      ++local_count;
    }
  }

  __shared__ typename BlockReduce::TempStorage reduce_storage;

  // The reduced value is only meaningful in thread 0, which is the only
  // thread that reads it below.
  const IdType block_count = BlockReduce(reduce_storage).Sum(local_count);

  if (threadIdx.x == 0) {
    num_unique[blockIdx.x] = block_count;
    if (blockIdx.x == 0) {
      num_unique[gridDim.x] = 0;
    }
  }
}

/**
 * \brief Update the local numbering of elements in the hashmap, and write the
 * unique items out in order of first occurrence.
 *
 * An item is kept when the index stored in the table for its key equals the
 * item's own index. Kept items receive consecutive local ids: block b starts at
 * num_items_prefix[b] and numbers its kept items in increasing index order.
 * The kernel must be launched with BLOCK_SIZE threads per block and the same
 * grid as the count_hashmap launch whose exclusively scanned output is
 * num_items_prefix.
 *
 * \tparam IdType The type of id.
 * \tparam BLOCK_SIZE The size of the thread blocks.
 * \tparam TILE_SIZE The number of elements each thread block works on; must be
 * a multiple of BLOCK_SIZE and fit in the 16-bit scan type.
 * \param items The set of non-unique items to update from.
 * \param num_items The number of non-unique items.
 * \param table The hash table.
 * \param num_items_prefix The number of unique items preceding each thread
 * block (gridDim.x + 1 entries; the last one is the total).
 * \param unique_items The set of unique items (output).
 * \param num_unique_items The number of unique items (output, written by
 * thread 0 of block 0).
 */
template <typename IdType, int BLOCK_SIZE, size_t TILE_SIZE>
__global__ void compact_hashmap(
    const IdType* const items, const size_t num_items,
    MutableDeviceOrderedHashTable<IdType> table,
    const IdType* const num_items_prefix, IdType* const unique_items,
    int64_t* const num_unique_items) {
  assert(BLOCK_SIZE == blockDim.x);

  using FlagType = uint16_t;
  using BlockScan = typename cub::BlockScan<FlagType, BLOCK_SIZE>;
  using Mapping = typename DeviceOrderedHashTable<IdType>::Mapping;

  // Each thread visits exactly VALS_PER_THREAD items of the tile. If TILE_SIZE
  // were not a multiple of BLOCK_SIZE, the tail of every tile would be skipped
  // here while count_hashmap still counts it, leaving holes in the output.
  static_assert(TILE_SIZE % BLOCK_SIZE == 0,
                "TILE_SIZE must be a multiple of BLOCK_SIZE.");
  // The per-block running count lives in FlagType, so it must be able to hold
  // a whole tile's worth of items.
  static_assert(TILE_SIZE <= static_cast<FlagType>(-1),
                "TILE_SIZE is too large for the scan's FlagType.");

  constexpr const int32_t VALS_PER_THREAD = TILE_SIZE / BLOCK_SIZE;

  __shared__ typename BlockScan::TempStorage temp_space;

  // Number of unique items owned by all preceding blocks.
  const IdType offset = num_items_prefix[blockIdx.x];

  // Carries the number of items kept in earlier iterations of this block.
  BlockPrefixCallbackOp<FlagType> prefix_op(0);

  // Every thread runs every iteration (no early exit): the block scan is a
  // collective operation and __syncthreads() must be reached by all threads.
  for (int32_t i = 0; i < VALS_PER_THREAD; ++i) {
    const IdType index = threadIdx.x + i * BLOCK_SIZE + blockIdx.x * TILE_SIZE;

    // flag is 1 iff this item is the retained occurrence of its key; kv is
    // the matching bucket in that case and nullptr otherwise.
    FlagType flag = 0;
    Mapping* kv = nullptr;
    if (index < num_items) {
      Mapping* const found = table.Search(items[index]);
      flag = found->index == index;
      if (flag) {
        kv = found;
      }
    }

    // Rank of this kept item among the block's kept items so far.
    BlockScan(temp_space).ExclusiveSum(flag, flag, prefix_op);
    // temp_space is reused by the next iteration's scan.
    __syncthreads();

    if (kv) {
      const IdType pos = offset + flag;
      kv->local = pos;
      unique_items[pos] = items[index];
    }
  }

  if (threadIdx.x == 0 && blockIdx.x == 0) {
    *num_unique_items = num_items_prefix[gridDim.x];
  }
}

// DeviceOrderedHashTable implementation

template <typename IdType>
DeviceOrderedHashTable<IdType>::DeviceOrderedHashTable(
    const Mapping* const table, const size_t size)
    // Record the bucket array and its bucket count.
    : table_{table}, size_{size} {}
352

template <typename IdType>
DeviceOrderedHashTable<IdType> OrderedHashTable<IdType>::DeviceHandle() const {
  // The handle only records table_ and size_; the storage stays owned by this
  // OrderedHashTable, which frees it in its destructor.
  DeviceOrderedHashTable<IdType> handle(table_, size_);
  return handle;
}

// OrderedHashTable implementation

template <typename IdType>
OrderedHashTable<IdType>::OrderedHashTable(
    const size_t size, DGLContext ctx, cudaStream_t stream, const int scale)
    : table_(nullptr), size_(TableSize(size, scale)), ctx_(ctx) {
  // Make sure there will be at least as many buckets as items.
  CHECK_GT(scale, 0);

  const size_t table_bytes = sizeof(Mapping) * size_;
  table_ = static_cast<Mapping*>(
      runtime::DeviceAPI::Get(ctx_)->AllocWorkspace(ctx_, table_bytes));

  // cudaMemsetAsync fills bytes, so every byte of every field (key, local and
  // index) is set to the low byte of kEmptyKey. This only marks keys as empty
  // if kEmptyKey is all-ones (-1) -- presumably true, confirm in
  // cuda_hashtable.cuh. Under that assumption, `index` also starts at the
  // maximum unsigned value, which the atomicMin in insertion relies on.
  CUDA_CALL(cudaMemsetAsync(
      table_, DeviceOrderedHashTable<IdType>::kEmptyKey, table_bytes, stream));
}

template <typename IdType>
OrderedHashTable<IdType>::~OrderedHashTable() {
  // Release the bucket array obtained with AllocWorkspace in the constructor.
  runtime::DeviceAPI::Get(ctx_)->FreeWorkspace(ctx_, table_);
}

/**
 * Fill the table from `input`, which may contain duplicates, and write each
 * distinct key to `unique` in order of first occurrence. `*num_unique`
 * receives the number of distinct keys. Kernels, the scan and any memset are
 * enqueued on `stream`.
 */
template <typename IdType>
void OrderedHashTable<IdType>::FillWithDuplicates(
    const IdType* const input, const size_t num_input, IdType* const unique,
    int64_t* const num_unique, cudaStream_t stream) {
  if (num_input == 0) {
    // With no input the grid would have zero blocks, which is not a valid
    // launch configuration, so none of the kernels below would write
    // *num_unique. Report an empty result directly. num_unique is written by
    // device code below, so it is presumably device memory.
    CUDA_CALL(cudaMemsetAsync(num_unique, 0, sizeof(*num_unique), stream));
    return;
  }

  auto device = runtime::DeviceAPI::Get(ctx_);

  const int64_t num_tiles = (num_input + TILE_SIZE - 1) / TILE_SIZE;

  const dim3 grid(num_tiles);
  const dim3 block(BLOCK_SIZE);

  auto device_table = MutableDeviceOrderedHashTable<IdType>(this);

  // 1. Insert every item, keeping the smallest index seen for each key.
  CUDA_KERNEL_CALL(
      (generate_hashmap_duplicates<IdType, BLOCK_SIZE, TILE_SIZE>), grid, block,
      0, stream, input, num_input, device_table);

  // 2. Count first occurrences per tile. count_hashmap writes one entry per
  // block plus a trailing zero, so grid.x + 1 entries are needed (not one per
  // input item).
  const size_t num_prefix = static_cast<size_t>(grid.x) + 1;
  IdType* item_prefix = static_cast<IdType*>(
      device->AllocWorkspace(ctx_, sizeof(IdType) * num_prefix));

  CUDA_KERNEL_CALL(
      (count_hashmap<IdType, BLOCK_SIZE, TILE_SIZE>), grid, block, 0, stream,
      input, num_input, device_table, item_prefix);

  // 3. Exclusive scan in place: entry b becomes the number of unique items in
  // blocks before b, and the last entry becomes the total.
  size_t workspace_bytes = 0;
  CUDA_CALL(cub::DeviceScan::ExclusiveSum(
      nullptr, workspace_bytes, static_cast<IdType*>(nullptr),
      static_cast<IdType*>(nullptr), grid.x + 1, stream));
  void* workspace = device->AllocWorkspace(ctx_, workspace_bytes);

  CUDA_CALL(cub::DeviceScan::ExclusiveSum(
      workspace, workspace_bytes, item_prefix, item_prefix, grid.x + 1,
      stream));
  device->FreeWorkspace(ctx_, workspace);

  // 4. Assign local ids, emit the unique items and their count.
  CUDA_KERNEL_CALL(
      (compact_hashmap<IdType, BLOCK_SIZE, TILE_SIZE>), grid, block, 0, stream,
      input, num_input, device_table, item_prefix, unique, num_unique);

  device->FreeWorkspace(ctx_, item_prefix);
}

/**
 * Fill the table from `input`, whose items must all be distinct; each item's
 * local id is its index in `input`. The kernel is enqueued on `stream`.
 */
template <typename IdType>
void OrderedHashTable<IdType>::FillWithUnique(
    const IdType* const input, const size_t num_input, cudaStream_t stream) {
  // An empty input would yield a zero-block grid, which is not a valid launch
  // configuration; there is nothing to insert, so return early.
  if (num_input == 0) {
    return;
  }

  const int64_t num_tiles = (num_input + TILE_SIZE - 1) / TILE_SIZE;

  const dim3 grid(num_tiles);
  const dim3 block(BLOCK_SIZE);

  auto device_table = MutableDeviceOrderedHashTable<IdType>(this);

  CUDA_KERNEL_CALL(
      (generate_hashmap_unique<IdType, BLOCK_SIZE, TILE_SIZE>), grid, block, 0,
      stream, input, num_input, device_table);
}

// Explicitly instantiate the table for the ID types used with it, so the
// template member definitions in this file are emitted for other translation
// units to link against.
template class OrderedHashTable<int64_t>;
template class OrderedHashTable<int32_t>;

}  // namespace cuda
}  // namespace runtime
}  // namespace dgl