/**
 *  Copyright (c) 2023 by Contributors
 * @file concurrent_id_hash_map.cc
 * @brief Implementation of the concurrent ID hash map.
 */
#include "concurrent_id_hash_map.h"

#ifdef _MSC_VER
#include <intrin.h>
#endif  // _MSC_VER

#include <cmath>
#include <numeric>
#include <vector>

namespace {
static constexpr int64_t kEmptyKey = -1;
static constexpr int kGrainSize = 256;

// Derive the hash-map capacity (a power of 2) from the input array size. The
// 3x multiplier is an empirically chosen bound on the load factor.
inline size_t GetMapSize(size_t num) {
  size_t capacity = 1;
  return capacity << static_cast<size_t>(1 + std::log2(num * 3));
}
}  // namespace

namespace graphbolt {
namespace sampling {

template <typename IdType>
IdType ConcurrentIdHashMap<IdType>::CompareAndSwap(
    IdType* ptr, IdType old_val, IdType new_val) {
#ifdef _MSC_VER
  if (sizeof(IdType) == 4) {
    return _InterlockedCompareExchange(
        reinterpret_cast<long*>(ptr), new_val, old_val);
  } else if (sizeof(IdType) == 8) {
    return _InterlockedCompareExchange64(
        reinterpret_cast<long long*>(ptr), new_val, old_val);
  } else {
    LOG(FATAL) << "ID can only be int32 or int64";
  }
#elif __GNUC__  // _MSC_VER
  return __sync_val_compare_and_swap(ptr, old_val, new_val);
#else   // _MSC_VER
#error "CompareAndSwap is not supported on this platform."
#endif  // _MSC_VER
}

template <typename IdType>
ConcurrentIdHashMap<IdType>::ConcurrentIdHashMap() : mask_(0) {}

template <typename IdType>
torch::Tensor ConcurrentIdHashMap<IdType>::Init(
    const torch::Tensor& ids, size_t num_seeds) {
  const IdType* ids_data = ids.data_ptr<IdType>();
  const size_t num_ids = static_cast<size_t>(ids.size(0));
  size_t capacity = GetMapSize(num_ids);
  mask_ = static_cast<IdType>(capacity - 1);

  // Each slot stores a (key, value) pair, hence `capacity * 2` elements.
  hash_map_ =
      torch::full({static_cast<int64_t>(capacity * 2)}, -1, ids.options());

  // Fill the ids into hash_map_ and collect the unique ids.
  auto unique_ids = torch::empty_like(ids);
  IdType* unique_ids_data = unique_ids.data_ptr<IdType>();

  // Insert the first `num_seeds` ids, using their positions as values.
  torch::parallel_for(0, num_seeds, kGrainSize, [&](int64_t s, int64_t e) {
    for (int64_t i = s; i < e; i++) {
      InsertAndSet(ids_data[i], static_cast<IdType>(i));
    }
  });
  // Place the first `num_seeds` ids.
  unique_ids.slice(0, 0, num_seeds) = ids.slice(0, 0, num_seeds);

  // An auxiliary array indicating whether the corresponding elements were
  // inserted into the hash map. Use `int16_t` instead of `bool` because
  // `std::vector<bool>` is unsafe when different threads update different
  // elements. See https://en.cppreference.com/w/cpp/container#Thread_safety.
  std::vector<int16_t> valid(num_ids);
  const int64_t num_threads = torch::get_num_threads();
  std::vector<int64_t> block_offset(num_threads + 1, 0);

  // Insert all remaining elements, counting per thread how many were new.
  torch::parallel_for(
      num_seeds, num_ids, kGrainSize, [&](int64_t s, int64_t e) {
        size_t count = 0;
        for (int64_t i = s; i < e; i++) {
          valid[i] = Insert(ids_data[i]);
          count += valid[i];
        }
        auto thread_id = torch::get_thread_num();
        block_offset[thread_id + 1] = count;
      });

  // Turn the per-thread counts into an exclusive prefix sum.
  std::partial_sum(
      block_offset.begin() + 1, block_offset.end(), block_offset.begin() + 1);
  unique_ids = unique_ids.slice(0, 0, num_seeds + block_offset.back());

  // Build the unique array from `ids` and set the final values in the hash
  // map.
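  // This pass relies on `torch::parallel_for` giving each thread the same
  // contiguous sub-range it handled in the counting pass: the exclusive
  // prefix sum in `block_offset[thread_id]` is exactly how many new ids were
  // found by lower-numbered threads, so each thread writes its valid ids into
  // a disjoint slice of `unique_ids` with no further synchronization.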
  torch::parallel_for(
      num_seeds, num_ids, kGrainSize, [&](int64_t s, int64_t e) {
        auto thread_id = torch::get_thread_num();
        auto pos = block_offset[thread_id] + num_seeds;
        for (int64_t i = s; i < e; i++) {
          if (valid[i]) {
            unique_ids_data[pos] = ids_data[i];
            Set(ids_data[i], pos);
            pos = pos + 1;
          }
        }
      });
  return unique_ids;
}

template <typename IdType>
torch::Tensor ConcurrentIdHashMap<IdType>::MapIds(
    const torch::Tensor& ids) const {
  const IdType* ids_data = ids.data_ptr<IdType>();
  torch::Tensor new_ids = torch::empty_like(ids);
  auto num_ids = new_ids.size(0);
  IdType* values_data = new_ids.data_ptr<IdType>();
  torch::parallel_for(0, num_ids, kGrainSize, [&](int64_t s, int64_t e) {
    for (int64_t i = s; i < e; i++) {
      values_data[i] = MapId(ids_data[i]);
    }
  });
  return new_ids;
}

// Keys and values are interleaved in `hash_map_`: slot `pos` stores its key
// at index `2 * pos` and its value at index `2 * pos + 1`.
template <typename IdType>
constexpr IdType getKeyIndex(IdType pos) {
  return 2 * pos;
}

template <typename IdType>
constexpr IdType getValueIndex(IdType pos) {
  return 2 * pos + 1;
}

template <typename IdType>
inline void ConcurrentIdHashMap<IdType>::Next(
    IdType* pos, IdType* delta) const {
  // Use quadratic probing; `mask_` keeps the position within the power-of-2
  // capacity.
  *pos = (*pos + (*delta) * (*delta)) & mask_;
  *delta = *delta + 1;
}

template <typename IdType>
inline IdType ConcurrentIdHashMap<IdType>::MapId(IdType id) const {
  IdType pos = (id & mask_), delta = 1;
  IdType empty_key = static_cast<IdType>(kEmptyKey);
  IdType* hash_map_data = hash_map_.data_ptr<IdType>();
  IdType key = hash_map_data[getKeyIndex(pos)];
  while (key != empty_key && key != id) {
    Next(&pos, &delta);
    key = hash_map_data[getKeyIndex(pos)];
  }
  return hash_map_data[getValueIndex(pos)];
}

template <typename IdType>
bool ConcurrentIdHashMap<IdType>::Insert(IdType id) {
  IdType pos = (id & mask_), delta = 1;
  InsertState state = AttemptInsertAt(pos, id);
  while (state == InsertState::OCCUPIED) {
    Next(&pos, &delta);
    state = AttemptInsertAt(pos, id);
  }
  return state == InsertState::INSERTED;
}

template <typename IdType>
inline void ConcurrentIdHashMap<IdType>::Set(IdType key, IdType value) {
  IdType pos = (key & mask_), delta = 1;
  IdType* hash_map_data = hash_map_.data_ptr<IdType>();
  while (hash_map_data[getKeyIndex(pos)] != key) {
    Next(&pos, &delta);
  }
  hash_map_data[getValueIndex(pos)] = value;
}

template <typename IdType>
inline void ConcurrentIdHashMap<IdType>::InsertAndSet(
    IdType id, IdType value) {
  IdType pos = (id & mask_), delta = 1;
  while (AttemptInsertAt(pos, id) == InsertState::OCCUPIED) {
    Next(&pos, &delta);
  }
  hash_map_.data_ptr<IdType>()[getValueIndex(pos)] = value;
}

template <typename IdType>
inline typename ConcurrentIdHashMap<IdType>::InsertState
ConcurrentIdHashMap<IdType>::AttemptInsertAt(int64_t pos, IdType key) {
  IdType empty_key = static_cast<IdType>(kEmptyKey);
  IdType* hash_map_data = hash_map_.data_ptr<IdType>();
  // Claim the slot by atomically compare-and-swapping its key.
  IdType old_val =
      CompareAndSwap(&(hash_map_data[getKeyIndex(pos)]), empty_key, key);
  if (old_val == empty_key) {
    return InsertState::INSERTED;
  } else if (old_val == key) {
    return InsertState::EXISTED;
  } else {
    return InsertState::OCCUPIED;
  }
}

template class ConcurrentIdHashMap<int8_t>;
template class ConcurrentIdHashMap<int16_t>;
template class ConcurrentIdHashMap<int32_t>;
template class ConcurrentIdHashMap<int64_t>;
template class ConcurrentIdHashMap<uint8_t>;

}  // namespace sampling
}  // namespace graphbolt
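
// A minimal usage sketch (illustrative only; the tensors and values below are
// hypothetical, not part of this file). `Init` builds the map from `ids`,
// keeping the first `num_seeds` entries at positions 0..num_seeds-1, and
// returns the unique ids; `MapIds` then translates ids into those compact
// positions.
//
//   torch::Tensor ids = torch::tensor({7, 3, 7, 9, 3}, torch::kInt64);
//   graphbolt::sampling::ConcurrentIdHashMap<int64_t> map;
//   torch::Tensor unique_ids = map.Init(ids, /*num_seeds=*/2);
//   // unique_ids == [7, 3, 9]
//   torch::Tensor compact = map.MapIds(ids);
//   // compact == [0, 1, 0, 2, 1]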