Unverified Commit 7dbc51cd authored by kgajdamo, committed by GitHub

Replace unordered_map with a faster version (#254)



* Replace unordered_map with a faster version

* Clone the repo recursively when testing

Co-authored-by: Matthias Fey <matthias.fey@tu-dortmund.de>
parent 63b75d8d
@@ -20,6 +20,9 @@ jobs:
     steps:
       - uses: actions/checkout@v2
+        with:
+          submodules: 'recursive'
 
       - name: Set up Python ${{ matrix.python-version }}
         uses: actions/setup-python@v2
         with:
...
[submodule "third_party/parallel-hashmap"]
path = third_party/parallel-hashmap
url = https://github.com/greg7mdp/parallel-hashmap.git
CMakeLists.txt:
@@ -39,6 +39,9 @@ target_include_directories(${PROJECT_NAME} INTERFACE
 include(GNUInstallDirs)
 include(CMakePackageConfigHelpers)
 
+set(PHMAP_DIR third_party/parallel-hashmap)
+target_include_directories(${PROJECT_NAME} PRIVATE ${PHMAP_DIR})
+
 set(TORCHSPARSE_CMAKECONFIG_INSTALL_DIR "share/cmake/TorchSparse" CACHE STRING "install path for TorchSparseConfig.cmake")
 configure_package_config_file(cmake/TorchSparseConfig.cmake.in
...
@@ -2,6 +2,8 @@
 
 #include "utils.h"
 
+#include "parallel_hashmap/phmap.h"
+
 #ifdef _WIN32
 #include <process.h>
 #endif
@@ -17,7 +19,7 @@ sample(const torch::Tensor &colptr, const torch::Tensor &row,
 
   // Initialize some data structures for the sampling process:
   vector<int64_t> samples;
-  unordered_map<int64_t, int64_t> to_local_node;
+  phmap::flat_hash_map<int64_t, int64_t> to_local_node;
 
   auto *colptr_data = colptr.data_ptr<int64_t>();
   auto *row_data = row.data_ptr<int64_t>();
@@ -93,7 +95,7 @@ sample(const torch::Tensor &colptr, const torch::Tensor &row,
   }
 
   if (!directed) {
-    unordered_map<int64_t, int64_t>::iterator iter;
+    phmap::flat_hash_map<int64_t, int64_t>::iterator iter;
     for (int64_t i = 0; i < (int64_t)samples.size(); i++) {
       const auto &w = samples[i];
       const auto &col_start = colptr_data[w];
...
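Note: phmap::flat_hash_map (from the greg7mdp/parallel-hashmap submodule added above) stores its key/value pairs in one flat array with open addressing rather than in separately allocated nodes, which makes lookups and inserts faster and more cache-friendly than std::unordered_map while exposing the same interface. A minimal standalone sketch of the swap, using toy node IDs rather than the actual sampling code:

// Minimal sketch, assuming only that third_party/parallel-hashmap is on the
// include path. phmap::flat_hash_map mirrors the std::unordered_map API
// (insert, find, end, operator[]), so the change above is a type swap only.
#include <cstdint>
#include <cstdio>

#include "parallel_hashmap/phmap.h"

int main() {
  phmap::flat_hash_map<int64_t, int64_t> to_local_node;

  // Map hypothetical global node IDs to local indices, as the sampler does
  // for each newly visited node.
  to_local_node.insert({1000, 0});
  to_local_node.insert({2048, 1});

  auto iter = to_local_node.find(2048);
  if (iter != to_local_node.end())
    std::printf("global 2048 -> local %lld\n", (long long)iter->second);
  return 0;
}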
setup.py:
@@ -103,11 +103,13 @@ def get_extensions():
         if suffix == 'cuda' and osp.exists(path):
             sources += [path]
 
+        phmap_dir = "third_party/parallel-hashmap"
+
         Extension = CppExtension if suffix == 'cpu' else CUDAExtension
         extension = Extension(
             f'torch_sparse._{name}_{suffix}',
             sources,
-            include_dirs=[extensions_dir],
+            include_dirs=[extensions_dir, phmap_dir],
             define_macros=define_macros,
             extra_compile_args=extra_compile_args,
             extra_link_args=extra_link_args,
...
third_party/parallel-hashmap (new submodule):
+Subproject commit 01ea8093e6d0293ea252e8027c17d7dff26a9c9f