Unverified Commit 7dbc51cd authored by kgajdamo's avatar kgajdamo Committed by GitHub
Browse files

Replace unordered_map with a faster version (#254)



* Replace unordered_map with a faster version

* clone recursively repo when testing
Co-authored-by: default avatarMatthias Fey <matthias.fey@tu-dortmund.de>
parent 63b75d8d
......@@ -20,6 +20,9 @@ jobs:
steps:
- uses: actions/checkout@v2
with:
submodules: 'recursive'
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
......
[submodule "third_party/parallel-hashmap"]
path = third_party/parallel-hashmap
url = https://github.com/greg7mdp/parallel-hashmap.git
......@@ -39,6 +39,9 @@ target_include_directories(${PROJECT_NAME} INTERFACE
include(GNUInstallDirs)
include(CMakePackageConfigHelpers)
set(PHMAP_DIR third_party/parallel-hashmap)
target_include_directories(${PROJECT_NAME} PRIVATE ${PHMAP_DIR})
set(TORCHSPARSE_CMAKECONFIG_INSTALL_DIR "share/cmake/TorchSparse" CACHE STRING "install path for TorchSparseConfig.cmake")
configure_package_config_file(cmake/TorchSparseConfig.cmake.in
......
......@@ -2,6 +2,8 @@
#include "utils.h"
#include "parallel_hashmap/phmap.h"
#ifdef _WIN32
#include <process.h>
#endif
......@@ -17,7 +19,7 @@ sample(const torch::Tensor &colptr, const torch::Tensor &row,
// Initialize some data structures for the sampling process:
vector<int64_t> samples;
unordered_map<int64_t, int64_t> to_local_node;
phmap::flat_hash_map<int64_t, int64_t> to_local_node;
auto *colptr_data = colptr.data_ptr<int64_t>();
auto *row_data = row.data_ptr<int64_t>();
......@@ -93,7 +95,7 @@ sample(const torch::Tensor &colptr, const torch::Tensor &row,
}
if (!directed) {
unordered_map<int64_t, int64_t>::iterator iter;
phmap::flat_hash_map<int64_t, int64_t>::iterator iter;
for (int64_t i = 0; i < (int64_t)samples.size(); i++) {
const auto &w = samples[i];
const auto &col_start = colptr_data[w];
......
......@@ -103,11 +103,13 @@ def get_extensions():
if suffix == 'cuda' and osp.exists(path):
sources += [path]
phmap_dir = "third_party/parallel-hashmap"
Extension = CppExtension if suffix == 'cpu' else CUDAExtension
extension = Extension(
f'torch_sparse._{name}_{suffix}',
sources,
include_dirs=[extensions_dir],
include_dirs=[extensions_dir, phmap_dir],
define_macros=define_macros,
extra_compile_args=extra_compile_args,
extra_link_args=extra_link_args,
......
Subproject commit 01ea8093e6d0293ea252e8027c17d7dff26a9c9f
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment