/**
 *  Copyright (c) 2023 by Contributors
 * @file array/cpu/concurrent_id_hash_map.cc
 * @brief Class about id hash map
 */
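// Overview: an open-addressing hash table that maps ids to consecutive
// local positions. Slots are claimed with an atomic compare-and-swap and
// collisions are resolved by quadratic probing, so the table can be built
// in parallel by multiple threads.
//
// A minimal usage sketch, based on the members defined in this file (see
// concurrent_id_hash_map.h for the full interface):
//
//   ConcurrentIdHashMap<int64_t> map;
//   // `ids` holds the seed ids first, followed by the remaining ids.
//   IdArray unique_ids = map.Init(ids, num_seeds);
//   // Map any ids that were passed to Init() back to their local ids.
//   IdArray local_ids = map.MapIds(ids);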

#include "concurrent_id_hash_map.h"

#ifdef _MSC_VER
#include <intrin.h>
#endif  // _MSC_VER

#include <dgl/array.h>
#include <dgl/runtime/device_api.h>
#include <dgl/runtime/parallel_for.h>

#include <cmath>
#include <numeric>

using namespace dgl::runtime;

namespace {
static constexpr int64_t kEmptyKey = -1;
static constexpr int kGrainSize = 256;

// Heuristic, established from experience, that derives the hash map
// capacity from the size of the input id array.
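// The result is a power of two roughly between 3x and 6x `num`, so the load
// factor stays below ~1/3 and positions can be wrapped with a bit mask
// instead of a modulo. For example, num = 1000 gives
// 1 << static_cast<size_t>(1 + std::log2(3000)) = 1 << 12 = 4096.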
inline size_t GetMapSize(size_t num) {
  size_t capacity = 1;
  return capacity << static_cast<size_t>(1 + std::log2(num * 3));
}
}  // namespace

namespace dgl {
namespace aten {

template <typename IdType>
IdType ConcurrentIdHashMap<IdType>::CompareAndSwap(
    IdType* ptr, IdType old_val, IdType new_val) {
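  // Atomically compare `*ptr` with `old_val` and, only if they are equal,
  // store `new_val`. The value of `*ptr` before the operation is returned,
  // so the caller can tell whether the swap happened and what the slot held.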
#ifdef _MSC_VER
  if (sizeof(IdType) == 4) {
    return _InterlockedCompareExchange(
        reinterpret_cast<LONG*>(ptr), new_val, old_val);
  } else if (sizeof(IdType) == 8) {
    return _InterlockedCompareExchange64(
        reinterpret_cast<LONGLONG*>(ptr), new_val, old_val);
  } else {
    LOG(FATAL) << "ID can only be int32 or int64";
  }
#elif __GNUC__  // _MSC_VER
  return __sync_val_compare_and_swap(ptr, old_val, new_val);
#else           // _MSC_VER
#error "CompareAndSwap is not supported on this platform."
#endif  // _MSC_VER
}

template <typename IdType>
ConcurrentIdHashMap<IdType>::ConcurrentIdHashMap() : mask_(0) {
  // Custom deleter: return the memory backing hash_map_ to the device API
  // when the unique_ptr releases the pointer.
  auto deleter = [](Mapping* mappings) {
    if (mappings != nullptr) {
      DGLContext ctx = DGLContext{kDGLCPU, 0};
      auto device = DeviceAPI::Get(ctx);
      device->FreeWorkspace(ctx, mappings);
    }
  };
  hash_map_ = {nullptr, deleter};
}

template <typename IdType>
IdArray ConcurrentIdHashMap<IdType>::Init(
    const IdArray& ids, size_t num_seeds) {
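  // Build the table from `ids`, whose first `num_seeds` entries are the seed
  // ids, and return the unique ids (seeds first, then the remaining unique
  // ids); each id's value in the map is its position in the returned array.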
  CHECK_EQ(ids.defined(), true);
  const IdType* ids_data = ids.Ptr<IdType>();
  const size_t num_ids = static_cast<size_t>(ids->shape[0]);
  // Sanity-check the sizes: `num_seeds` cannot exceed the number of ids.
  CHECK_GE(num_seeds, 0);
  CHECK_GE(num_ids, num_seeds);
  size_t capacity = GetMapSize(num_ids);
  mask_ = static_cast<IdType>(capacity - 1);

  auto ctx = DGLContext{kDGLCPU, 0};
  auto device = DeviceAPI::Get(ctx);
  hash_map_.reset(static_cast<Mapping*>(
      device->AllocWorkspace(ctx, sizeof(Mapping) * capacity)));
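  // Setting every byte to 0xFF initializes each Mapping's key and value to
  // -1 (kEmptyKey), marking all slots as empty.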
  memset(hash_map_.get(), -1, sizeof(Mapping) * capacity);

  // Fill the ids into hash_map_ and gather the unique ones into `unique_ids`.
  IdArray unique_ids = NewIdArray(num_ids, ctx, sizeof(IdType) * 8);
  IdType* unique_ids_data = unique_ids.Ptr<IdType>();
  // Insert the first `num_seeds` ids and map each of them to its own index.
  parallel_for(0, num_seeds, kGrainSize, [&](int64_t s, int64_t e) {
    for (int64_t i = s; i < e; i++) {
      InsertAndSet(ids_data[i], static_cast<IdType>(i));
    }
  });
  // Copy the first `num_seeds` ids directly into the unique array; they
  // occupy positions [0, num_seeds).
  device->CopyDataFromTo(
      ids_data, 0, unique_ids_data, 0, sizeof(IdType) * num_seeds, ctx, ctx,
      ids->dtype);
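
  // The remaining (non-seed) ids are processed in three phases:
  //   1. Insert them concurrently while each thread counts its new ids.
  //   2. Prefix-sum the per-thread counts to get each thread's write offset.
  //   3. Write the newly inserted ids into `unique_ids` at those offsets and
  //      record each id's final position in the hash map.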

  // An auxiliary array indicates whether the corresponding elements
  // are inserted into hash map or not. Use `int16_t` instead of `bool` as
  // vector<bool> is unsafe when updating different elements from different
  // threads. See https://en.cppreference.com/w/cpp/container#Thread_safety.
  std::vector<int16_t> valid(num_ids);
  auto thread_num = compute_num_threads(0, num_ids, kGrainSize);
  std::vector<size_t> block_offset(thread_num + 1, 0);
  // Insert the remaining ids; each thread counts how many it newly inserted.
  parallel_for(num_seeds, num_ids, kGrainSize, [&](int64_t s, int64_t e) {
    size_t count = 0;
    for (int64_t i = s; i < e; i++) {
      valid[i] = Insert(ids_data[i]);
      count += valid[i];
    }
    block_offset[omp_get_thread_num() + 1] = count;
  });

  // Prefix-sum the per-thread counts so that block_offset[tid] holds the
  // number of new ids inserted by threads before `tid` (an exclusive sum).
  std::partial_sum(
      block_offset.begin() + 1, block_offset.end(), block_offset.begin() + 1);
  unique_ids->shape[0] = num_seeds + block_offset.back();

  // Write the newly inserted ids into `unique_ids` and record each id's
  // position in the hash map.
  parallel_for(num_seeds, num_ids, kGrainSize, [&](int64_t s, int64_t e) {
    auto tid = omp_get_thread_num();
    auto pos = block_offset[tid] + num_seeds;
    for (int64_t i = s; i < e; i++) {
      if (valid[i]) {
        unique_ids_data[pos] = ids_data[i];
        Set(ids_data[i], pos);
        pos = pos + 1;
      }
    }
  });
  return unique_ids;
}

template <typename IdType>
IdArray ConcurrentIdHashMap<IdType>::MapIds(const IdArray& ids) const {
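  // Look up, in parallel, the local id that Init() assigned to each entry.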
  CHECK_EQ(ids.defined(), true);
  const IdType* ids_data = ids.Ptr<IdType>();
  const size_t num_ids = static_cast<size_t>(ids->shape[0]);
  CHECK_GT(num_ids, 0);

  DGLContext ctx = DGLContext{kDGLCPU, 0};
  IdArray new_ids = NewIdArray(num_ids, ctx, sizeof(IdType) * 8);
  IdType* values_data = new_ids.Ptr<IdType>();

  parallel_for(0, num_ids, kGrainSize, [&](int64_t s, int64_t e) {
    for (int64_t i = s; i < e; i++) {
      values_data[i] = MapId(ids_data[i]);
    }
  });
  return new_ids;
}

template <typename IdType>
inline void ConcurrentIdHashMap<IdType>::Next(
    IdType* pos, IdType* delta) const {
  // Quadratic probing: the step grows as delta^2 after every collision, and
  // the power-of-two capacity lets `& mask_` wrap the position.
  *pos = (*pos + (*delta) * (*delta)) & mask_;
  *delta = *delta + 1;
}

template <typename IdType>
inline IdType ConcurrentIdHashMap<IdType>::MapId(IdType id) const {
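  // Probe from the id's home slot until the id (or an empty slot) is found
  // and return the stored value, i.e. the id's position in the unique array.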
  IdType pos = (id & mask_), delta = 1;
  IdType empty_key = static_cast<IdType>(kEmptyKey);
  while (hash_map_[pos].key != empty_key && hash_map_[pos].key != id) {
    Next(&pos, &delta);
  }
  return hash_map_[pos].value;
}

template <typename IdType>
bool ConcurrentIdHashMap<IdType>::Insert(IdType id) {
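  // Probe for a free or matching slot; return true iff this call newly
  // inserted `id`, i.e. the id was not already present in the map.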
  IdType pos = (id & mask_), delta = 1;
  InsertState state = AttemptInsertAt(pos, id);
  while (state == InsertState::OCCUPIED) {
    Next(&pos, &delta);
    state = AttemptInsertAt(pos, id);
  }

  return state == InsertState::INSERTED;
}

template <typename IdType>
inline void ConcurrentIdHashMap<IdType>::Set(IdType key, IdType value) {
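  // `key` must already be present in the map; probe to its slot and
  // overwrite the stored value.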
  IdType pos = (key & mask_), delta = 1;
  while (hash_map_[pos].key != key) {
    Next(&pos, &delta);
  }

  hash_map_[pos].value = value;
}

template <typename IdType>
inline void ConcurrentIdHashMap<IdType>::InsertAndSet(IdType id, IdType value) {
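  // Unlike Set(), first claim a slot for `id` (inserting it if absent), then
  // write `value`. Init() uses this for the seed ids, which have not been
  // inserted beforehand.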
  IdType pos = (id & mask_), delta = 1;
  while (AttemptInsertAt(pos, id) == InsertState::OCCUPIED) {
    Next(&pos, &delta);
  }

  hash_map_[pos].value = value;
}

template <typename IdType>
inline typename ConcurrentIdHashMap<IdType>::InsertState
ConcurrentIdHashMap<IdType>::AttemptInsertAt(int64_t pos, IdType key) {
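  // Try to claim slot `pos` for `key` with a single compare-and-swap and
  // report whether the slot was empty (INSERTED), already held `key`
  // (EXISTED), or holds a different key (OCCUPIED).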
  IdType empty_key = static_cast<IdType>(kEmptyKey);
  IdType old_val = CompareAndSwap(&(hash_map_[pos].key), empty_key, key);
  if (old_val == empty_key) {
    return InsertState::INSERTED;
  } else if (old_val == key) {
    return InsertState::EXISTED;
  } else {
    return InsertState::OCCUPIED;
  }
}

template class ConcurrentIdHashMap<int32_t>;
template class ConcurrentIdHashMap<int64_t>;

}  // namespace aten
}  // namespace dgl