// !!! This is a file automatically generated by hipify!!!
#include "hip/hip_runtime.h"
/**
 *  Copyright (c) 2021 by Contributors
 * @file runtime/cuda/cuda_hashtable.cuh
 * @brief Ordered GPU hashtable mapping IDs to the first index at which they
 * appear, with a host-side owner and a device-side search handle.
 */

#ifndef DGL_RUNTIME_CUDA_CUDA_HASHTABLE_CUH_
#define DGL_RUNTIME_CUDA_CUDA_HASHTABLE_CUH_

#include <dgl/runtime/c_runtime_api.h>

#include "cuda_common.h"
#include <hip/hip_runtime.h>

namespace dgl {
namespace runtime {
namespace cuda {

// Forward declaration of the host-side table, so that DeviceOrderedHashTable
// can grant it friendship before its full definition appears below.
template <typename IdType>
class OrderedHashTable;
/**
25
 * @brief A device-side handle for a GPU hashtable for mapping items to the
26
27
28
29
30
31
32
33
34
35
 * first index at which they appear in the provided data array.
 *
 * For any ID array A, one can view it as a mapping from the index `i`
 * (continuous integer range from zero) to its element `A[i]`. This hashtable
 * serves as a reverse mapping, i.e., from element `A[i]` to its index `i`.
 * Quadratic probing is used for collision resolution. See
 * DeviceOrderedHashTable's documentation for how the Mapping structure is
 * used.
 *
 * The hash table should be used in two phases, with the first being populating
36
 * the hash table with the OrderedHashTable object, and then generating this
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
 * handle from it. This object can then be used to search the hash table,
 * to find mappings, from with CUDA code.
 *
 * If a device-side handle is created from a hash table with the following
 * entries:
 * [
 *   {key: 0, local: 0, index: 0},
 *   {key: 3, local: 1, index: 1},
 *   {key: 2, local: 2, index: 2},
 *   {key: 8, local: 3, index: 4},
 *   {key: 4, local: 4, index: 5},
 *   {key: 1, local: 5, index: 8}
 * ]
 * The array [0, 3, 2, 0, 8, 4, 3, 2, 1, 8] could have `Search()` called on
 * each id, to be mapped via:
 * ```
 * __global__ void map(int32_t * array,
 *                     size_t size,
 *                     DeviceOrderedHashTable<int32_t> table) {
 *   int idx = threadIdx.x + blockIdx.x*blockDim.x;
 *   if (idx < size) {
 *     array[idx] = table.Search(array[idx])->local;
 *   }
 * }
 * ```
 * to get the remaped array:
 * [0, 1, 2, 0, 3, 4, 1, 2, 5, 3]
 *
65
 * @tparam IdType The type of the IDs.
66
 */
67
template <typename IdType>
68
class DeviceOrderedHashTable {
69
70
 public:
  /**
71
   * @brief An entry in the hashtable.
72
73
   */
  struct Mapping {
74
    /**
75
     * @brief The ID of the item inserted.
76
     */
77
    IdType key;
78
    /**
79
     * @brief The index of the item in the unique list.
80
81
     */
    IdType local;
82
    /**
83
     * @brief The index of the item when inserted into the hashtable (e.g.,
84
     * the index within the array passed into FillWithDuplicates()).
85
     */
86
87
88
89
90
    int64_t index;
  };

  typedef const Mapping* ConstIterator;

91
92
93
  DeviceOrderedHashTable(const DeviceOrderedHashTable& other) = default;
  DeviceOrderedHashTable& operator=(const DeviceOrderedHashTable& other) =
      default;
94
95

  /**
96
   * @brief Find the non-mutable mapping of a given key within the hash table.
97
98
99
100
   *
   * WARNING: The key must exist within the hashtable. Searching for a key not
   * in the hashtable is undefined behavior.
   *
101
   * @param id The key to search for.
102
   *
103
   * @return An iterator to the mapping.
104
   */
105
  inline __device__ ConstIterator Search(const IdType id) const {
106
107
108
109
110
111
    const IdType pos = SearchForPosition(id);

    return &table_[pos];
  }

  /**
112
   * @brief Check whether a key exists within the hashtable.
113
   *
114
   * @param id The key to check for.
115
   *
116
   * @return True if the key exists in the hashtable.
117
   */
118
  inline __device__ bool Contains(const IdType id) const {
119
120
121
122
123
124
    IdType pos = Hash(id);

    IdType delta = 1;
    while (table_[pos].key != kEmptyKey) {
      if (table_[pos].key == id) {
        return true;
125
      }
126
127
      pos = Hash(pos + delta);
      delta += 1;
128
    }
129
130
131
132
133
134
135
    return false;
  }

 protected:
  // Must be uniform bytes for memset to work
  static constexpr IdType kEmptyKey = static_cast<IdType>(-1);

136
  const Mapping* table_;
137
138
139
  size_t size_;

  /**
140
   * @brief Create a new device-side handle to the hash table.
141
   *
142
143
   * @param table The table stored in GPU memory.
   * @param size The size of the table.
144
   */
145
  explicit DeviceOrderedHashTable(const Mapping* table, size_t size);
146
147

  /**
148
   * @brief Search for an item in the hash table which is known to exist.
149
150
151
152
   *
   * WARNING: If the ID searched for does not exist within the hashtable, this
   * function will never return.
   *
153
   * @param id The ID of the item to search for.
154
   *
155
   * @return The the position of the item in the hashtable.
156
   */
157
  inline __device__ IdType SearchForPosition(const IdType id) const {
158
159
160
161
162
163
    IdType pos = Hash(id);

    // linearly scan for matching entry
    IdType delta = 1;
    while (table_[pos].key != id) {
      assert(table_[pos].key != kEmptyKey);
164
165
      pos = Hash(pos + delta);
      delta += 1;
166
    }
167
168
169
170
171
172
    assert(pos < size_);

    return pos;
  }

  /**
173
   * @brief Hash an ID to a to a position in the hash table.
174
   *
175
   * @param id The ID to hash.
176
   *
177
   * @return The hash.
178
   */
179
  inline __device__ size_t Hash(const IdType id) const { return id % size_; }
180
181

  friend class OrderedHashTable<IdType>;
182
183
};

184
/**
185
 * @brief A host-side handle for a GPU hashtable for mapping items to the
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
 * first index at which they appear in the provided data array. This host-side
 * handle is responsible for allocating and free the GPU memory of the
 * hashtable.
 *
 * For any ID array A, one can view it as a mapping from the index `i`
 * (continuous integer range from zero) to its element `A[i]`. This hashtable
 * serves as a reverse mapping, i.e., from element `A[i]` to its index `i`.
 * Quadratic probing is used for collision resolution.
 *
 * The hash table should be used in two phases, the first is filling the hash
 * table via 'FillWithDuplicates()' or 'FillWithUnique()'. Then, the
 * 'DeviceHandle()' method can be called, to get a version suitable for
 * searching from device and kernel functions.
 *
 * If 'FillWithDuplicates()' was called with an array of:
 * [0, 3, 2, 0, 8, 4, 3, 2, 1, 8]
 *
 * The resulting entries in the hash-table would be:
 * [
 *   {key: 0, local: 0, index: 0},
 *   {key: 3, local: 1, index: 1},
 *   {key: 2, local: 2, index: 2},
 *   {key: 8, local: 3, index: 4},
 *   {key: 4, local: 4, index: 5},
 *   {key: 1, local: 5, index: 8}
 * ]
 *
213
 * @tparam IdType The type of the IDs.
214
 */
215
template <typename IdType>
216
class OrderedHashTable {
217
218
219
220
221
222
 public:
  static constexpr int kDefaultScale = 3;

  using Mapping = typename DeviceOrderedHashTable<IdType>::Mapping;

  /**
223
   * @brief Create a new ordered hash table. The amoutn of GPU memory
224
225
   * consumed by the resulting hashtable is O(`size` * 2^`scale`).
   *
226
227
228
   * @param size The number of items to insert into the hashtable.
   * @param ctx The device context to store the hashtable on.
   * @param scale The power of two times larger the number of buckets should
229
   * be than the number of items.
230
   * @param stream The stream to use for initializing the hashtable.
231
232
   */
  OrderedHashTable(
sangwzh's avatar
sangwzh committed
233
      const size_t size, DGLContext ctx, hipStream_t stream,
234
235
236
      const int scale = kDefaultScale);

  /**
237
   * @brief Cleanup after the hashtable.
238
239
240
241
   */
  ~OrderedHashTable();

  // Disable copying
242
243
  OrderedHashTable(const OrderedHashTable& other) = delete;
  OrderedHashTable& operator=(const OrderedHashTable& other) = delete;
244
245

  /**
246
   * @brief Fill the hashtable with the array containing possibly duplicate
247
248
   * IDs.
   *
249
250
251
252
253
   * @param input The array of IDs to insert.
   * @param num_input The number of IDs to insert.
   * @param unique The list of unique IDs inserted.
   * @param num_unique The number of unique IDs inserted.
   * @param stream The stream to perform operations on.
254
255
   */
  void FillWithDuplicates(
256
      const IdType* const input, const size_t num_input, IdType* const unique,
sangwzh's avatar
sangwzh committed
257
      int64_t* const num_unique, hipStream_t stream);
258
259

  /**
260
   * @brief Fill the hashtable with an array of unique keys.
261
   *
262
263
264
   * @param input The array of unique IDs.
   * @param num_input The number of keys.
   * @param stream The stream to perform operations on.
265
266
   */
  void FillWithUnique(
sangwzh's avatar
sangwzh committed
267
      const IdType* const input, const size_t num_input, hipStream_t stream);
268
269

  /**
270
   * @brief Get a verison of the hashtable usable from device functions.
271
   *
272
   * @return This hashtable.
273
274
275
276
   */
  DeviceOrderedHashTable<IdType> DeviceHandle() const;

 private:
277
  Mapping* table_;
278
279
  size_t size_;
  DGLContext ctx_;
280
281
};

282
283
284
}  // namespace cuda
}  // namespace runtime
}  // namespace dgl

#endif  // DGL_RUNTIME_CUDA_CUDA_HASHTABLE_CUH_