Unverified Commit 707f2ae9 authored by peizhou001's avatar peizhou001 Committed by GitHub
Browse files

[Graphbolt]Add concurrent id hash map (#6082)

parent 144a491b
/**
* Copyright (c) 2023 by Contributors
* @file concurrent_id_hash_map.cc
* @brief Class about id hash map.
*/
#include "concurrent_id_hash_map.h"
#ifdef _MSC_VER
#include <intrin.h>
#endif // _MSC_VER
#include <cmath>
#include <numeric>
namespace {
// Sentinel stored in a key slot to mark it as unoccupied. Note that for an
// unsigned id type this casts to the maximum representable value.
static constexpr int64_t kEmptyKey = -1;
// Minimum number of elements per chunk handed to torch::parallel_for.
static constexpr int kGrainSize = 256;

// The formula is established from experience which is used to get the hashmap
// size from the input array size. The result is always a power of two (so a
// table index can be computed with a bit-mask) and holds at least 3x the
// input size, keeping the load factor low for open addressing.
//
// @param num The number of ids that will be inserted.
// @return The number of buckets to allocate; always a power of two.
inline size_t GetMapSize(size_t num) {
  // Guard the empty case: std::log2(0) is -inf, and converting that to an
  // integer type is undefined behavior.
  if (num == 0) return 1;
  size_t capacity = 1;
  return capacity << static_cast<size_t>(1 + std::log2(num * 3));
}
}  // namespace
namespace graphbolt {
namespace sampling {
template <typename IdType>
IdType ConcurrentIdHashMap<IdType>::CompareAndSwap(
    IdType* ptr, IdType old_val, IdType new_val) {
#ifdef _MSC_VER
  // This file instantiates ConcurrentIdHashMap for 1-, 2-, 4- and 8-byte id
  // types, so all four widths must be handled here; previously the 1- and
  // 2-byte instantiations hit LOG(FATAL) at runtime on MSVC.
  if (sizeof(IdType) == 1) {
    return _InterlockedCompareExchange8(
        reinterpret_cast<char*>(ptr), new_val, old_val);
  } else if (sizeof(IdType) == 2) {
    return _InterlockedCompareExchange16(
        reinterpret_cast<short*>(ptr), new_val, old_val);
  } else if (sizeof(IdType) == 4) {
    return _InterlockedCompareExchange(
        reinterpret_cast<long*>(ptr), new_val, old_val);
  } else if (sizeof(IdType) == 8) {
    return _InterlockedCompareExchange64(
        reinterpret_cast<long long*>(ptr), new_val, old_val);
  } else {
    LOG(FATAL) << "ID can only be 1, 2, 4 or 8 bytes";
    return old_val;  // Unreachable; silences missing-return warnings.
  }
#elif __GNUC__  // _MSC_VER
  return __sync_val_compare_and_swap(ptr, old_val, new_val);
#else   // _MSC_VER
#error "CompareAndSwap is not supported on this platform."
#endif  // _MSC_VER
}
// Constructs an empty map. `Init` must be called before any lookup:
// `hash_map_` is left unallocated here and `mask_` is zeroed.
template <typename IdType>
ConcurrentIdHashMap<IdType>::ConcurrentIdHashMap() : mask_(0) {}
// Builds the hash map from `ids` and returns the unique ids. The first
// `num_seeds` entries are assumed unique and are mapped, in order, to
// [0, num_seeds); the remaining entries are deduplicated in parallel and
// mapped to [num_seeds, num_unique) in a non-stable order.
template <typename IdType>
torch::Tensor ConcurrentIdHashMap<IdType>::Init(
    const torch::Tensor& ids, size_t num_seeds) {
  const IdType* ids_data = ids.data_ptr<IdType>();
  const size_t num_ids = static_cast<size_t>(ids.size(0));
  // Capacity is a power of two so a position can be wrapped with `& mask_`.
  size_t capacity = GetMapSize(num_ids);
  mask_ = static_cast<IdType>(capacity - 1);
  // The table interleaves keys and values (2 slots per bucket, hence
  // `capacity * 2`); -1 marks an empty key slot.
  hash_map_ =
      torch::full({static_cast<int64_t>(capacity * 2)}, -1, ids.options());
  // This code block is to fill the ids into hash_map_.
  auto unique_ids = torch::empty_like(ids);
  IdType* unique_ids_data = unique_ids.data_ptr<IdType>();
  // Fill in the first `num_seeds` ids. Seeds are documented as unique, so
  // their values can be written at index i without tracking insert success.
  torch::parallel_for(0, num_seeds, kGrainSize, [&](int64_t s, int64_t e) {
    for (int64_t i = s; i < e; i++) {
      InsertAndSet(ids_data[i], static_cast<IdType>(i));
    }
  });
  // Place the first `num_seeds` ids.
  unique_ids.slice(0, 0, num_seeds) = ids.slice(0, 0, num_seeds);
  // An auxiliary array indicates whether the corresponding elements
  // are inserted into hash map or not. Use `int16_t` instead of `bool` as
  // vector<bool> is unsafe when updating different elements from different
  // threads. See https://en.cppreference.com/w/cpp/container#Thread_safety.
  std::vector<int16_t> valid(num_ids);
  const int64_t num_threads = torch::get_num_threads();
  std::vector<size_t> block_offset(num_threads + 1, 0);
  // Insert all elements in this loop; each thread records how many new
  // (previously unseen) ids it inserted into its slot of `block_offset`.
  // NOTE(review): `block_offset[thread_id + 1] = count` is an assignment,
  // not an accumulation — this assumes torch::parallel_for hands each thread
  // at most one contiguous chunk of this range. Confirm against the ATen
  // threading backend in use.
  torch::parallel_for(
      num_seeds, num_ids, kGrainSize, [&](int64_t s, int64_t e) {
        size_t count = 0;
        for (int64_t i = s; i < e; i++) {
          valid[i] = Insert(ids_data[i]);
          count += valid[i];
        }
        auto thread_id = torch::get_thread_num();
        block_offset[thread_id + 1] = count;
      });
  // Get ExclusiveSum of each block, so block_offset[t] becomes the number of
  // new ids inserted by all threads preceding thread t.
  std::partial_sum(
      block_offset.begin() + 1, block_offset.end(), block_offset.begin() + 1);
  // Shrink the output to seeds + newly inserted ids. `slice` returns a view
  // over the same storage, so `unique_ids_data` captured above stays valid.
  unique_ids = unique_ids.slice(0, 0, num_seeds + block_offset.back());
  // Get unique array from ids and set value for hash map. This relies on the
  // same chunk-to-thread mapping as the insertion loop above, so each thread
  // writes the ids it inserted at its own exclusive offset.
  torch::parallel_for(
      num_seeds, num_ids, kGrainSize, [&](int64_t s, int64_t e) {
        auto thread_id = torch::get_thread_num();
        auto pos = block_offset[thread_id] + num_seeds;
        for (int64_t i = s; i < e; i++) {
          if (valid[i]) {
            unique_ids_data[pos] = ids_data[i];
            Set(ids_data[i], pos);
            pos = pos + 1;
          }
        }
      });
  return unique_ids;
}
// Maps every key in `ids` to its value in the hash map, in parallel.
// Returns a new tensor of the same shape/dtype holding the mapped ids.
template <typename IdType>
torch::Tensor ConcurrentIdHashMap<IdType>::MapIds(
    const torch::Tensor& ids) const {
  torch::Tensor mapped = torch::empty_like(ids);
  const IdType* input = ids.data_ptr<IdType>();
  IdType* output = mapped.data_ptr<IdType>();
  const int64_t total = mapped.size(0);
  torch::parallel_for(0, total, kGrainSize, [&](int64_t begin, int64_t end) {
    for (int64_t idx = begin; idx < end; idx++) {
      output[idx] = MapId(input[idx]);
    }
  });
  return mapped;
}
// Index of the key slot for bucket `bucket`: keys live at even positions of
// the interleaved (key, value) table.
template <typename IdType>
constexpr IdType getKeyIndex(IdType bucket) {
  return bucket + bucket;
}
// Index of the value slot for bucket `bucket`: values live at odd positions
// of the interleaved (key, value) table, right after their key.
template <typename IdType>
constexpr IdType getValueIndex(IdType bucket) {
  return bucket + bucket + 1;
}
template <typename IdType>
inline void ConcurrentIdHashMap<IdType>::Next(
    IdType* pos, IdType* delta) const {
  // Quadratic-style probing: advance by the square of an incrementing probe
  // counter; `& mask_` wraps the position into the power-of-two table.
  // NOTE(review): with cumulative `delta * delta` steps it is not obvious the
  // probe sequence visits every slot of a power-of-two table (unlike
  // triangular probing). The table is sized >= 3x the number of ids, which
  // keeps empty slots plentiful — confirm termination is guaranteed.
  *pos = (*pos + (*delta) * (*delta)) & mask_;
  *delta = *delta + 1;
}
// Looks up `id` and returns its mapped value. Probes from `id & mask_` until
// the key (or an empty slot) is found; reads the value of whichever slot the
// probe stopped at.
template <typename IdType>
inline IdType ConcurrentIdHashMap<IdType>::MapId(IdType id) const {
  const IdType empty = static_cast<IdType>(kEmptyKey);
  const IdType* table = hash_map_.data_ptr<IdType>();
  IdType slot = id & mask_;
  IdType step = 1;
  for (;;) {
    const IdType stored = table[getKeyIndex(slot)];
    if (stored == id || stored == empty) {
      break;
    }
    Next(&slot, &step);
  }
  return table[getValueIndex(slot)];
}
// Inserts `id` into the map, probing past occupied slots.
// Returns true iff this call claimed a new slot (i.e. `id` was not present).
template <typename IdType>
bool ConcurrentIdHashMap<IdType>::Insert(IdType id) {
  IdType slot = id & mask_;
  IdType step = 1;
  InsertState result;
  while ((result = AttemptInsertAt(slot, id)) == InsertState::OCCUPIED) {
    Next(&slot, &step);
  }
  return result == InsertState::INSERTED;
}
// Writes `value` into the slot whose key equals `key`.
// Precondition: `key` was already inserted, otherwise the probe never stops.
template <typename IdType>
inline void ConcurrentIdHashMap<IdType>::Set(IdType key, IdType value) {
  IdType* table = hash_map_.data_ptr<IdType>();
  IdType slot = key & mask_;
  for (IdType step = 1; table[getKeyIndex(slot)] != key;) {
    Next(&slot, &step);
  }
  table[getValueIndex(slot)] = value;
}
// Inserts `id` (or finds its existing slot) and writes `value` alongside it.
template <typename IdType>
inline void ConcurrentIdHashMap<IdType>::InsertAndSet(IdType id, IdType value) {
  IdType slot = id & mask_;
  IdType step = 1;
  // Probe until a slot is claimed for `id` or `id` is found already present.
  InsertState state = AttemptInsertAt(slot, id);
  while (state == InsertState::OCCUPIED) {
    Next(&slot, &step);
    state = AttemptInsertAt(slot, id);
  }
  IdType* table = hash_map_.data_ptr<IdType>();
  table[getValueIndex(slot)] = value;
}
// Atomically tries to claim the key slot at `pos` for `key` via CAS against
// the empty sentinel. Reports whether the slot was newly claimed, already
// held this key, or is occupied by a different key.
template <typename IdType>
inline typename ConcurrentIdHashMap<IdType>::InsertState
ConcurrentIdHashMap<IdType>::AttemptInsertAt(int64_t pos, IdType key) {
  const IdType empty = static_cast<IdType>(kEmptyKey);
  IdType* table = hash_map_.data_ptr<IdType>();
  const IdType previous = CompareAndSwap(&table[getKeyIndex(pos)], empty, key);
  if (previous == empty) {
    return InsertState::INSERTED;
  }
  return previous == key ? InsertState::EXISTED : InsertState::OCCUPIED;
}
// Explicit instantiations for every id type dispatched on by callers.
// NOTE(review): for the unsigned instantiation, kEmptyKey (-1) casts to the
// maximum representable value, so that id cannot be stored as a key —
// confirm callers never pass it.
template class ConcurrentIdHashMap<int32_t>;
template class ConcurrentIdHashMap<int64_t>;
template class ConcurrentIdHashMap<int16_t>;
template class ConcurrentIdHashMap<int8_t>;
template class ConcurrentIdHashMap<uint8_t>;
} // namespace sampling
} // namespace graphbolt
/**
* Copyright (c) 2023 by Contributors
* @file concurrent_id_hash_map.h
* @brief Class about concurrent id hash map.
*/
#ifndef GRAPHBOLT_CONCURRENT_ID_HASH_MAP_H_
#define GRAPHBOLT_CONCURRENT_ID_HASH_MAP_H_
#include <torch/torch.h>
#include <functional>
#include <memory>
#include <vector>
namespace graphbolt {
namespace sampling {
/**
* @brief A CPU targeted hashmap for mapping duplicate and non-consecutive ids
* in the provided array to unique and consecutive ones. It utilizes
* multi-threading to accelerate the insert and search speed. Currently it is
* only designed to be used in `ToBlockCpu` for optimization, so it only
* supports inserting keys once via the Init function, and it does not support
* key deletion.
*
* The hash map should be prepared in two phases before using. With the first
* being creating the hashmap, and then initialize it with an id array which is
* divided into 2 parts: [`seed ids`, `sampled ids`]. `Seed ids` refer to
* a set of ids chosen as the input for the sampling process and `sampled ids`
* are the ids newly sampled from the process (note that the `seed ids` might
* also be sampled in the process and included in the `sampled ids`). As a
* result `seed
* ids` are mapped to [0, num_seed_ids) and `sampled ids` to [num_seed_ids,
* num_unique_ids). Notice that mapping order is stable for `seed ids` while not
* for the `sampled ids`.
*
* For example, for an array `A` having 4 seed ids with following entries:
* [99, 98, 100, 97, 97, 101, 101, 102, 101]
* Create the hashmap `H` with:
* `H = ConcurrentIdHashMap()` (1)
* And Init it with:
* `U = H.Init(A)` (2) (U is an id array used to store the unique
* ids in A).
* Then `U` should be (U is not exclusive as the overall mapping is not stable):
* [99, 98, 100, 97, 102, 101]
* And the hashmap should generate following mappings:
* * [
* {key: 99, value: 0},
* {key: 98, value: 1},
* {key: 100, value: 2},
* {key: 97, value: 3},
* {key: 102, value: 4},
* {key: 101, value: 5}
* ]
* Search the hashmap with array `I`=[98, 99, 102]:
* R = H.MapIds(I) (3)
* R should be:
* [1, 0, 4]
**/
template <typename IdType>
class ConcurrentIdHashMap {
 private:
  /**
   * @brief The result state of an attempt to insert.
   */
  enum class InsertState {
    OCCUPIED,  // Indicates that the space where an insertion is being
               // attempted is already occupied by another element.
    EXISTED,   // Indicates that the element being inserted already exists in
               // the map, and thus no insertion is performed.
    INSERTED   // Indicates that the insertion was successful and a new element
               // was added to the map.
  };

 public:
  /**
   * @brief Cross platform CAS operation.
   * It is an atomic operation that compares the contents of a memory
   * location with a given value and, only if they are the same, modifies
   * the contents of that memory location to a new given value.
   *
   * @param ptr The pointer to the object to test and modify.
   * @param old_val The value expected to be found in `ptr`.
   * @param new_val The value to store in `ptr` if it is as expected.
   *
   * @return Old value pointed by the `ptr`.
   */
  static IdType CompareAndSwap(IdType* ptr, IdType old_val, IdType new_val);

  // Constructs an empty map; `Init` must be called before lookups.
  ConcurrentIdHashMap();

  // Non-copyable: the map owns a backing tensor and a derived mask.
  ConcurrentIdHashMap(const ConcurrentIdHashMap& other) = delete;
  ConcurrentIdHashMap& operator=(const ConcurrentIdHashMap& other) = delete;

  /**
   * @brief Initialize the hashmap with an array of ids. The first `num_seeds`
   * ids are unique and must be mapped to a contiguous array starting
   * from 0. The rest can be duplicated and the mapping result is not stable.
   *
   * @param ids The array of the ids to be inserted.
   * @param num_seeds The number of seed ids.
   *
   * @return Unique ids from the input `ids`.
   */
  torch::Tensor Init(const torch::Tensor& ids, size_t num_seeds);

  /**
   * @brief Find mappings of given keys.
   *
   * @param ids The keys to map for.
   *
   * @return Mapping results corresponding to `ids`.
   */
  torch::Tensor MapIds(const torch::Tensor& ids) const;

 private:
  /**
   * @brief Get the next position and delta for probing.
   *
   * @param[in,out] pos Calculate the next position with quadratic probing.
   * @param[in,out] delta Calculate the next delta by adding 1.
   */
  inline void Next(IdType* pos, IdType* delta) const;

  /**
   * @brief Find the mapping of a given key.
   *
   * @param id The key to map for.
   *
   * @return Mapping result corresponding to `id`.
   */
  inline IdType MapId(const IdType id) const;

  /**
   * @brief Insert an id into the hash map.
   *
   * @param id The id to be inserted.
   *
   * @return Whether the `id` is inserted or not.
   */
  inline bool Insert(IdType id);

  /**
   * @brief Set the value for the key in the hash map.
   *
   * @param key The key to set for.
   * @param value The value to be set for the `key`.
   *
   * @warning Key must exist.
   */
  inline void Set(IdType key, IdType value);

  /**
   * @brief Insert a key into the hash map and set its value.
   *
   * @param key The key to be inserted.
   * @param value The value to be set for the `key`.
   */
  inline void InsertAndSet(IdType key, IdType value);

  /**
   * @brief Attempt to insert the key into the hash map at the given position.
   *
   * @param pos The position in the hash map to be inserted at.
   * @param key The key to be inserted.
   *
   * @return The state of the insertion.
   */
  inline InsertState AttemptInsertAt(int64_t pos, IdType key);

 private:
  /**
   * @brief Flat tensor used to store all elements as interleaved
   * (key, value) pairs: keys at even indices, values at odd indices.
   */
  torch::Tensor hash_map_;

  /**
   * @brief Mask which is assisted to get the position in the table
   * for a key by performing `&` operation with it. Equal to capacity - 1,
   * where capacity is a power of two.
   */
  IdType mask_;
};
} // namespace sampling
} // namespace graphbolt
#endif // GRAPHBOLT_CONCURRENT_ID_HASH_MAP_H_
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment