cuda_hashtable.cuh 8.5 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
/*!
 *  Copyright (c) 2021 by Contributors
 * \file runtime/cuda/cuda_hashtable.cuh
 * \brief Host-side and device-side handles for an ordered GPU hash table.
 */

#ifndef DGL_RUNTIME_CUDA_CUDA_HASHTABLE_CUH_
#define DGL_RUNTIME_CUDA_CUDA_HASHTABLE_CUH_

#include <dgl/runtime/c_runtime_api.h>

#include "cuda_runtime.h"
#include "cuda_common.h"

namespace dgl {
namespace runtime {
namespace cuda {

template<typename>
class OrderedHashTable;

/*!
 * \brief A device-side handle for a GPU hashtable for mapping items to the
 * first index at which they appear in the provided data array.
 *
 * For any ID array A, one can view it as a mapping from the index `i`
 * (continuous integer range from zero) to its element `A[i]`. This hashtable
 * serves as a reverse mapping, i.e., from element `A[i]` to its index `i`.
 * Quadratic probing is used for collision resolution. See the Mapping
 * structure below for the contents of each hash table entry.
 *
 * The hash table should be used in two phases, with the first being populating
 * the hash table with the OrderedHashTable object, and then generating this
 * handle from it. This object can then be used to search the hash table,
 * to find mappings, from within CUDA code.
 *
 * If a device-side handle is created from a hash table with the following
 * entries:
 * [
 *   {key: 0, local: 0, index: 0},
 *   {key: 3, local: 1, index: 1},
 *   {key: 2, local: 2, index: 2},
 *   {key: 8, local: 3, index: 4},
 *   {key: 4, local: 4, index: 5},
 *   {key: 1, local: 5, index: 8}
 * ]
 * The array [0, 3, 2, 0, 8, 4, 3, 2, 1, 8] could have `Search()` called on
 * each id, to be mapped via:
 * ```
 * __global__ void map(int32_t * array,
 *                     size_t size,
 *                     DeviceOrderedHashTable<int32_t> table) {
 *   int idx = threadIdx.x + blockIdx.x*blockDim.x;
 *   if (idx < size) {
 *     array[idx] = table.Search(array[idx])->local;
 *   }
 * }
 * ```
 * to get the remapped array:
 * [0, 1, 2, 0, 3, 4, 1, 2, 5, 3]
 *
 * \tparam IdType The type of the IDs.
 */
template<typename IdType>
class DeviceOrderedHashTable {
  public:
    /**
    * \brief An entry (slot) in the hashtable.
    *
    * Contains() treats a slot whose `key` equals kEmptyKey as unoccupied.
    */
    struct Mapping {
      /**
      * \brief The ID of the item inserted. This is the value that lookups
      * compare against.
      */
      IdType key;
      /**
      * \brief The index of the item in the unique list.
      */
      IdType local;
      /**
      * \brief The index of the item when inserted into the hashtable (e.g.,
      * the index within the array passed into FillWithDuplicates()).
      */
      int64_t index;
    };

    typedef const Mapping* ConstIterator;

    DeviceOrderedHashTable(
        const DeviceOrderedHashTable& other) = default;
    DeviceOrderedHashTable& operator=(
        const DeviceOrderedHashTable& other) = default;

    /**
    * \brief Look up the read-only entry stored for a given key.
    *
    * WARNING: The key must already be present in the hashtable. Looking up a
    * key that was never inserted is undefined behavior.
    *
    * \param id The key to look up.
    *
    * \return An iterator (a const pointer) to the entry holding `id`.
    */
    inline __device__ ConstIterator Search(
        const IdType id) const {
      // Offsetting the slot array by the found position yields the address of
      // the matching entry.
      return table_ + SearchForPosition(id);
    }

111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
    /**
     * \brief Test whether a key is present in the hashtable.
     *
     * Probing starts at the key's hashed slot and advances by a step that
     * grows by one on every miss. The probe stops at the first empty slot
     * (key absent) or at the first slot holding `id` (key present).
     *
     * \param id The key to check for.
     *
     * \return True if the key exists in the hashtable.
     */
    inline __device__ bool Contains(
        const IdType id) const {
      for (IdType slot = Hash(id), step = 1; ;
           slot = Hash(slot + step), ++step) {
        const IdType stored = table_[slot].key;
        // An empty slot terminates the probe sequence: the key is absent.
        if (stored == kEmptyKey) {
          return false;
        }
        if (stored == id) {
          return true;
        }
      }
    }

133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
  protected:
    // Must be uniform bytes for memset to work
    static constexpr IdType kEmptyKey = static_cast<IdType>(-1);

    const Mapping * table_;
    size_t size_;

    /**
    * \brief Create a new device-side handle to the hash table.
    *
    * \param table The table stored in GPU memory.
    * \param size The size of the table.
    */
    explicit DeviceOrderedHashTable(
        const Mapping * table,
        size_t size);

    /**
    * \brief Locate the slot of an item that is known to be in the hash table.
    *
    * WARNING: The probe loop only exits once a slot holding `id` is found. If
    * the ID is not in the hashtable, the loop does not terminate on its own;
    * when assertions are enabled, reaching an empty slot trips an assert
    * instead.
    *
    * \param id The ID of the item to search for.
    *
    * \return The position of the item in the hashtable.
    */
    inline __device__ IdType SearchForPosition(
        const IdType id) const {
      IdType slot = Hash(id);

      // Probe with a step that grows by one after every miss until the slot
      // holding `id` is reached.
      for (IdType step = 1; table_[slot].key != id; ++step) {
        // Hitting an empty slot means the key was never inserted.
        assert(table_[slot].key != kEmptyKey);
        slot = Hash(slot + step);
      }
      assert(slot < size_);

      return slot;
    }

    /**
    * \brief Hash an ID to a position in the hash table.
    *
    * The position is the remainder of `id` divided by the table size.
    *
    * \param id The ID to hash.
    *
    * \return The hash.
    */
    inline __device__ size_t Hash(
        const IdType id) const {
      const size_t bucket = id % size_;
      return bucket;
    }

    friend class OrderedHashTable<IdType>;
};

/*!
 * \brief A host-side handle for a GPU hashtable for mapping items to the
 * first index at which they appear in the provided data array. This host-side
 * handle is responsible for allocating and freeing the GPU memory of the
 * hashtable.
 *
 * For any ID array A, one can view it as a mapping from the index `i`
 * (continuous integer range from zero) to its element `A[i]`. This hashtable
 * serves as a reverse mapping, i.e., from element `A[i]` to its index `i`.
 * Quadratic probing is used for collision resolution.
 *
 * The hash table should be used in two phases. The first is filling the hash
 * table via 'FillWithDuplicates()' or 'FillWithUnique()'. Then, the
 * 'DeviceHandle()' method can be called to get a version suitable for
 * searching from device and kernel functions.
 *
 * If 'FillWithDuplicates()' was called with an array of:
 * [0, 3, 2, 0, 8, 4, 3, 2, 1, 8]
 *
 * The resulting entries in the hash-table would be:
 * [
 *   {key: 0, local: 0, index: 0},
 *   {key: 3, local: 1, index: 1},
 *   {key: 2, local: 2, index: 2},
 *   {key: 8, local: 3, index: 4},
 *   {key: 4, local: 4, index: 5},
 *   {key: 1, local: 5, index: 8}
 * ]
 *
 * \tparam IdType The type of the IDs.
 */
template<typename IdType>
class OrderedHashTable {
  public:
    // Default value of the constructor's `scale` argument.
    static constexpr int kDefaultScale = 3;

    // Entry type, shared with the device-side handle.
    using Mapping = typename DeviceOrderedHashTable<IdType>::Mapping;

    /**
    * \brief Create a new ordered hash table. The amount of GPU memory
    * consumed by the resulting hashtable is O(`size` * 2^`scale`).
    *
    * \param size The number of items to insert into the hashtable.
    * \param ctx The device context to store the hashtable on.
    * \param stream The stream to use for initializing the hashtable.
    * \param scale The power of two times larger the number of buckets should
    * be than the number of items.
    */
    OrderedHashTable(
        const size_t size,
        DGLContext ctx,
        cudaStream_t stream,
        const int scale = kDefaultScale);

    /**
    * \brief Cleanup after the hashtable.
    */
    ~OrderedHashTable();

    // Disable copying
    OrderedHashTable(
        const OrderedHashTable& other) = delete;
    OrderedHashTable& operator=(
        const OrderedHashTable& other) = delete;

    /**
    * \brief Fill the hashtable with the array containing possibly duplicate
    * IDs.
    *
    * NOTE(review): whether `input`, `unique` and `num_unique` must point to
    * device memory is not visible in this header; confirm against the
    * implementation.
    *
    * \param input The array of IDs to insert.
    * \param num_input The number of IDs to insert.
    * \param unique Output: the list of unique IDs inserted.
    * \param num_unique Output: the number of unique IDs inserted.
    * \param stream The stream to perform operations on.
    */
    void FillWithDuplicates(
        const IdType * const input,
        const size_t num_input,
        IdType * const unique,
        int64_t * const num_unique,
        cudaStream_t stream);

    /**
    * \brief Fill the hashtable with an array of unique keys.
    *
    * \param input The array of unique IDs.
    * \param num_input The number of keys.
    * \param stream The stream to perform operations on.
    */
    void FillWithUnique(
        const IdType * const input,
        const size_t num_input,
        cudaStream_t stream);

    /**
    * \brief Get a version of the hashtable usable from device functions.
    *
    * \return A device-side handle to this hashtable.
    */
    DeviceOrderedHashTable<IdType> DeviceHandle() const;

  private:
    // Storage for the hashtable entries (GPU memory, per the class
    // description above).
    Mapping * table_;
    // Size of the table. Presumably the number of slots rather than the number
    // of inserted items; confirm against the constructor's definition.
    size_t size_;
    // Device context. Presumably the `ctx` passed to the constructor; confirm
    // against the constructor's definition.
    DGLContext ctx_;

};


}  // cuda
}  // runtime
}  // dgl

#endif