"vscode:/vscode.git/clone" did not exist on "9792b9d7e3687df66dcbc31673246dc14911d923"
cuda_hashtable.cuh 8.19 KB
Newer Older
/**
 *  Copyright (c) 2021 by Contributors
 * @file runtime/cuda/cuda_hashtable.cuh
 * @brief GPU hashtables mapping IDs to the first index at which they appear
 * in an input array.
 */

#ifndef DGL_RUNTIME_CUDA_CUDA_HASHTABLE_CUH_
#define DGL_RUNTIME_CUDA_CUDA_HASHTABLE_CUH_

#include <dgl/runtime/c_runtime_api.h>

#include "cuda_common.h"
13
#include "cuda_runtime.h"
14
15
16
17
18

namespace dgl {
namespace runtime {
namespace cuda {

// Forward declaration: DeviceOrderedHashTable names the host-side table as a
// friend (giving it access to the protected constructor) before the host-side
// class is defined further down.
template <typename IdType>
class OrderedHashTable;

/**
 * @brief A device-side handle for a GPU hashtable for mapping items to the
 * first index at which they appear in the provided data array.
 *
 * For any ID array A, one can view it as a mapping from the index `i`
 * (continuous integer range from zero) to its element `A[i]`. This hashtable
 * serves as a reverse mapping, i.e., from element `A[i]` to its index `i`.
 * Collisions are resolved with quadratic (triangular-number) probing: the
 * k-th probe for `id` (k >= 1) lands on slot
 * `(Hash(id) + k * (k + 1) / 2) % size`. The fields stored for each entry are
 * described on the `Mapping` struct below.
 *
 * The hash table should be used in two phases, with the first being populating
 * the hash table with the OrderedHashTable object, and then generating this
 * handle from it. This object can then be used to search the hash table,
 * to find mappings, from within CUDA code.
 *
 * If a device-side handle is created from a hash table with the following
 * entries:
 * [
 *   {key: 0, local: 0, index: 0},
 *   {key: 3, local: 1, index: 1},
 *   {key: 2, local: 2, index: 2},
 *   {key: 8, local: 3, index: 4},
 *   {key: 4, local: 4, index: 5},
 *   {key: 1, local: 5, index: 8}
 * ]
 * The array [0, 3, 2, 0, 8, 4, 3, 2, 1, 8] could have `Search()` called on
 * each id, to be mapped via:
 * ```
 * __global__ void map(int32_t * array,
 *                     size_t size,
 *                     DeviceOrderedHashTable<int32_t> table) {
 *   int idx = threadIdx.x + blockIdx.x*blockDim.x;
 *   if (idx < size) {
 *     array[idx] = table.Search(array[idx])->local;
 *   }
 * }
 * ```
 * to get the remapped array:
 * [0, 1, 2, 0, 3, 4, 1, 2, 5, 3]
 *
 * @tparam IdType The type of the IDs.
 */
template <typename IdType>
class DeviceOrderedHashTable {
 public:
  /**
   * @brief An entry in the hashtable.
   */
  struct Mapping {
    /**
     * @brief The ID of the item inserted.
     */
    IdType key;
    /**
     * @brief The index of the item in the unique list.
     */
    IdType local;
    /**
     * @brief The index of the item when inserted into the hashtable (e.g.,
     * the index within the array passed into FillWithDuplicates()).
     */
    int64_t index;
  };

  using ConstIterator = const Mapping*;

  DeviceOrderedHashTable(const DeviceOrderedHashTable& other) = default;
  DeviceOrderedHashTable& operator=(const DeviceOrderedHashTable& other) =
      default;

  /**
   * @brief Find the non-mutable mapping of a given key within the hash table.
   *
   * WARNING: The key must exist within the hashtable. Searching for a key not
   * in the hashtable is undefined behavior.
   *
   * @param id The key to search for.
   *
   * @return An iterator to the mapping.
   */
  inline __device__ ConstIterator Search(const IdType id) const {
    const IdType pos = SearchForPosition(id);

    return &table_[pos];
  }

  /**
   * @brief Check whether a key exists within the hashtable.
   *
   * The probe sequence stops at the first empty slot, so termination for an
   * absent key relies on the table containing at least one empty slot.
   *
   * @param id The key to check for.
   *
   * @return True if the key exists in the hashtable.
   */
  inline __device__ bool Contains(const IdType id) const {
    // Probe in size_t: keeping the position/step in IdType could overflow
    // (signed overflow is UB) for 32-bit IDs on very large tables.
    size_t pos = Hash(id);
    size_t delta = 1;
    while (table_[pos].key != kEmptyKey) {
      if (table_[pos].key == id) {
        return true;
      }
      pos = NextProbe(pos, delta);
      ++delta;
    }
    return false;
  }

 protected:
  // Must be uniform bytes for memset to work
  static constexpr IdType kEmptyKey = static_cast<IdType>(-1);

  // Slots of the table; not owned by this handle.
  const Mapping* table_;
  // Number of slots in `table_`; every probe position is reduced modulo it.
  size_t size_;

  /**
   * @brief Create a new device-side handle to the hash table.
   *
   * @param table The table stored in GPU memory.
   * @param size The size of the table.
   */
  explicit DeviceOrderedHashTable(const Mapping* table, size_t size);

  /**
   * @brief Search for an item in the hash table which is known to exist.
   *
   * WARNING: If the ID searched for does not exist within the hashtable, the
   * device-side assert fires when an empty slot is reached; with asserts
   * disabled, this function will never return.
   *
   * @param id The ID of the item to search for.
   *
   * @return The position of the item in the hashtable.
   */
  inline __device__ IdType SearchForPosition(const IdType id) const {
    size_t pos = Hash(id);

    // Follow the quadratic (triangular-number) probe sequence until the
    // matching entry is found.
    size_t delta = 1;
    while (table_[pos].key != id) {
      assert(table_[pos].key != kEmptyKey);
      pos = NextProbe(pos, delta);
      ++delta;
    }
    assert(pos < size_);

    return static_cast<IdType>(pos);
  }

  /**
   * @brief Hash an ID to a position in the hash table.
   *
   * @param id The ID to hash.
   *
   * @return The hash.
   */
  inline __device__ size_t Hash(const IdType id) const { return id % size_; }

  /**
   * @brief Compute the next slot of a probe sequence.
   *
   * Equivalent to `Hash(pos + delta)` whenever that sum fits in IdType, but
   * the arithmetic is done in size_t so it cannot overflow for 32-bit IDs.
   *
   * @param pos The current slot (always < size_).
   * @param delta The probe step; callers increment it after every probe.
   *
   * @return The next slot to inspect.
   */
  inline __device__ size_t NextProbe(
      const size_t pos, const size_t delta) const {
    return (pos + delta) % size_;
  }

  friend class OrderedHashTable<IdType>;
};

/**
 * @brief A host-side handle for a GPU hashtable for mapping items to the
 * first index at which they appear in the provided data array. This host-side
 * handle is responsible for allocating and freeing the GPU memory of the
 * hashtable.
 *
 * For any ID array A, one can view it as a mapping from the index `i`
 * (continuous integer range from zero) to its element `A[i]`. This hashtable
 * serves as a reverse mapping, i.e., from element `A[i]` to its index `i`.
 * Quadratic probing is used for collision resolution.
 *
 * The hash table should be used in two phases, the first is filling the hash
 * table via 'FillWithDuplicates()' or 'FillWithUnique()'. Then, the
 * 'DeviceHandle()' method can be called, to get a version suitable for
 * searching from device and kernel functions.
 *
 * If 'FillWithDuplicates()' was called with an array of:
 * [0, 3, 2, 0, 8, 4, 3, 2, 1, 8]
 *
 * The resulting entries in the hash-table would be:
 * [
 *   {key: 0, local: 0, index: 0},
 *   {key: 3, local: 1, index: 1},
 *   {key: 2, local: 2, index: 2},
 *   {key: 8, local: 3, index: 4},
 *   {key: 4, local: 4, index: 5},
 *   {key: 1, local: 5, index: 8}
 * ]
 *
 * @tparam IdType The type of the IDs.
 */
template <typename IdType>
class OrderedHashTable {
 public:
  static constexpr int kDefaultScale = 3;

  using Mapping = typename DeviceOrderedHashTable<IdType>::Mapping;

  /**
   * @brief Create a new ordered hash table. The amount of GPU memory
   * consumed by the resulting hashtable is O(`size` * 2^`scale`).
   *
   * @param size The number of items to insert into the hashtable.
   * @param ctx The device context to store the hashtable on.
   * @param stream The stream to use for initializing the hashtable.
   * @param scale The power of two times larger the number of buckets should
   * be than the number of items.
   */
  OrderedHashTable(
      const size_t size, DGLContext ctx, cudaStream_t stream,
      const int scale = kDefaultScale);

  /**
   * @brief Cleanup after the hashtable.
   */
  ~OrderedHashTable();

  // Disable copying: this object manages the table's GPU memory, so a copy
  // would lead to the same memory being released twice.
  OrderedHashTable(const OrderedHashTable& other) = delete;
  OrderedHashTable& operator=(const OrderedHashTable& other) = delete;

  /**
   * @brief Fill the hashtable with the array containing possibly duplicate
   * IDs.
   *
   * @param input The array of IDs to insert.
   * @param num_input The number of IDs to insert.
   * @param unique Output: the list of unique IDs inserted. NOTE(review):
   * presumably must have room for up to `num_input` IDs — confirm against the
   * implementation.
   * @param num_unique Output: the number of unique IDs inserted.
   * NOTE(review): confirm whether this is a host or device pointer.
   * @param stream The stream to perform operations on.
   */
  void FillWithDuplicates(
      const IdType* const input, const size_t num_input, IdType* const unique,
      int64_t* const num_unique, cudaStream_t stream);

  /**
   * @brief Fill the hashtable with an array of unique keys.
   *
   * @param input The array of unique IDs.
   * @param num_input The number of keys.
   * @param stream The stream to perform operations on.
   */
  void FillWithUnique(
      const IdType* const input, const size_t num_input, cudaStream_t stream);

  /**
   * @brief Get a version of the hashtable usable from device functions.
   *
   * The returned DeviceOrderedHashTable only stores a pointer to a table and
   * does not free it; presumably it refers to this object's memory and must
   * not be used after this object is destroyed — confirm in the definition.
   *
   * @return A device-side handle to this hashtable.
   */
  DeviceOrderedHashTable<IdType> DeviceHandle() const;

 private:
  // The table's slots, stored in GPU memory owned by this object.
  Mapping* table_;
  // Number of slots in `table_` (presumably the bucket count derived from the
  // constructor's `size` and `scale`).
  size_t size_;
  // The device context the table is stored on.
  DGLContext ctx_;
};

280
281
282
}  // namespace cuda
}  // namespace runtime
}  // namespace dgl
283

284
#endif  // DGL_RUNTIME_CUDA_CUDA_HASHTABLE_CUH_