Unverified Commit 3c391533 authored by Ramon Zhou's avatar Ramon Zhou Committed by GitHub
Browse files

[GraphBolt] Make unique_and_compact deterministic (#7217)

parent 1ad78fba
...@@ -65,10 +65,10 @@ torch::Tensor ConcurrentIdHashMap<IdType>::Init( ...@@ -65,10 +65,10 @@ torch::Tensor ConcurrentIdHashMap<IdType>::Init(
// This code block is to fill the ids into hash_map_. // This code block is to fill the ids into hash_map_.
auto unique_ids = torch::empty_like(ids); auto unique_ids = torch::empty_like(ids);
IdType* unique_ids_data = unique_ids.data_ptr<IdType>(); IdType* unique_ids_data = unique_ids.data_ptr<IdType>();
// Fill in the first `num_seeds` ids. // Insert all ids into the hash map.
torch::parallel_for(0, num_seeds, kGrainSize, [&](int64_t s, int64_t e) { torch::parallel_for(0, num_ids, kGrainSize, [&](int64_t s, int64_t e) {
for (int64_t i = s; i < e; i++) { for (int64_t i = s; i < e; i++) {
InsertAndSet(ids_data[i], static_cast<IdType>(i)); InsertAndSetMin(ids_data[i], static_cast<IdType>(i));
} }
}); });
// Place the first `num_seeds` ids. // Place the first `num_seeds` ids.
...@@ -82,13 +82,16 @@ torch::Tensor ConcurrentIdHashMap<IdType>::Init( ...@@ -82,13 +82,16 @@ torch::Tensor ConcurrentIdHashMap<IdType>::Init(
const int64_t num_threads = torch::get_num_threads(); const int64_t num_threads = torch::get_num_threads();
std::vector<size_t> block_offset(num_threads + 1, 0); std::vector<size_t> block_offset(num_threads + 1, 0);
// Insert all elements in this loop.
// Count the valid numbers in each thread.
torch::parallel_for( torch::parallel_for(
num_seeds, num_ids, kGrainSize, [&](int64_t s, int64_t e) { num_seeds, num_ids, kGrainSize, [&](int64_t s, int64_t e) {
size_t count = 0; size_t count = 0;
for (int64_t i = s; i < e; i++) { for (int64_t i = s; i < e; i++) {
valid[i] = Insert(ids_data[i]); if (MapId(ids_data[i]) == i) {
count += valid[i]; count++;
valid[i] = 1;
}
} }
auto thread_id = torch::get_thread_num(); auto thread_id = torch::get_thread_num();
block_offset[thread_id + 1] = count; block_offset[thread_id + 1] = count;
...@@ -199,6 +202,27 @@ inline void ConcurrentIdHashMap<IdType>::InsertAndSet(IdType id, IdType value) { ...@@ -199,6 +202,27 @@ inline void ConcurrentIdHashMap<IdType>::InsertAndSet(IdType id, IdType value) {
hash_map_.data_ptr<IdType>()[getValueIndex(pos)] = value; hash_map_.data_ptr<IdType>()[getValueIndex(pos)] = value;
} }
/**
 * Inserts `id` into the hash map and lowers the value stored for it to
 * `value` whenever `value` is smaller than what is currently stored (or when
 * nothing has been stored yet, i.e. the value slot still holds `kEmptyKey`).
 *
 * The value slot is only ever modified through `CompareAndSwap`, retrying
 * until either the swap succeeds or the stored value is already <= `value`.
 *
 * @param id The key to insert.
 * @param value The candidate value for `id`; the smallest candidate wins.
 */
template <typename IdType>
void ConcurrentIdHashMap<IdType>::InsertAndSetMin(IdType id, IdType value) {
  // Start at the masked slot and keep probing while the attempted slot is
  // reported as OCCUPIED.
  IdType slot = (id & mask_);
  IdType step = 1;
  while (AttemptInsertAt(slot, id) == InsertState::OCCUPIED) {
    Next(&slot, &step);
  }

  const IdType unset = static_cast<IdType>(kEmptyKey);
  const IdType value_index = getValueIndex(slot);
  IdType* const value_slot = &(hash_map_.data_ptr<IdType>()[value_index]);

  // Compare-and-swap loop implementing a "store minimum". The first attempt
  // expects the slot to be unset; every failed attempt retries against the
  // value that was actually observed, as long as that value is still unset
  // or strictly larger than ours.
  IdType expected = unset;
  do {
    const IdType observed = CompareAndSwap(value_slot, expected, value);
    if (observed == expected) {
      return;  // Swap took effect: `value` is now stored.
    }
    expected = observed;
  } while (expected == unset || expected > value);
}
template <typename IdType> template <typename IdType>
inline typename ConcurrentIdHashMap<IdType>::InsertState inline typename ConcurrentIdHashMap<IdType>::InsertState
ConcurrentIdHashMap<IdType>::AttemptInsertAt(int64_t pos, IdType key) { ConcurrentIdHashMap<IdType>::AttemptInsertAt(int64_t pos, IdType key) {
......
...@@ -158,6 +158,16 @@ class ConcurrentIdHashMap { ...@@ -158,6 +158,16 @@ class ConcurrentIdHashMap {
*/ */
inline void InsertAndSet(IdType key, IdType value); inline void InsertAndSet(IdType key, IdType value);
/**
 * @brief Insert a key into the hash map and keep the smallest value seen for
 * it. If the key has no value yet, `value` is stored; if the key already has
 * a value, it is overwritten only when `value` is smaller.
 *
 * @param id The key to be inserted.
 * @param value The candidate value for `id`; the stored value ends up as the
 * minimum over all candidates passed for the same `id`.
 *
 */
inline void InsertAndSetMin(IdType id, IdType value);
/** /**
* @brief Attempt to insert the key into the hash map at the given position. * @brief Attempt to insert the key into the hash map at the given position.
* *
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment