Unverified Commit 916ba55b authored by OlhaBabicheva's avatar OlhaBabicheva Committed by GitHub
Browse files

Replace unordered_map with phmap in hetero_sample (#266)

parent ae22058a
......@@ -2,8 +2,6 @@
#include "utils.h"
#include "parallel_hashmap/phmap.h"
#ifdef _WIN32
#include <process.h>
#endif
......@@ -142,21 +140,21 @@ hetero_sample(const vector<node_t> &node_types,
const int64_t num_hops) {
// Create a mapping to convert single string relations to edge type triplets:
unordered_map<rel_t, edge_t> to_edge_type;
phmap::flat_hash_map<rel_t, edge_t> to_edge_type;
for (const auto &k : edge_types)
to_edge_type[get<0>(k) + "__" + get<1>(k) + "__" + get<2>(k)] = k;
// Initialize some data structures for the sampling process:
unordered_map<node_t, vector<int64_t>> samples_dict;
unordered_map<node_t, unordered_map<int64_t, int64_t>> to_local_node_dict;
unordered_map<node_t, vector<int64_t>> root_time_dict;
phmap::flat_hash_map<node_t, vector<int64_t>> samples_dict;
phmap::flat_hash_map<node_t, phmap::flat_hash_map<int64_t, int64_t>> to_local_node_dict;
phmap::flat_hash_map<node_t, vector<int64_t>> root_time_dict;
for (const auto &node_type : node_types) {
samples_dict[node_type];
to_local_node_dict[node_type];
root_time_dict[node_type];
}
unordered_map<rel_t, vector<int64_t>> rows_dict, cols_dict, edges_dict;
phmap::flat_hash_map<rel_t, vector<int64_t>> rows_dict, cols_dict, edges_dict;
for (const auto &kv : colptr_dict) {
const auto &rel_type = kv.key();
rows_dict[rel_type];
......@@ -188,7 +186,7 @@ hetero_sample(const vector<node_t> &node_types,
}
}
unordered_map<node_t, pair<int64_t, int64_t>> slice_dict;
phmap::flat_hash_map<node_t, pair<int64_t, int64_t>> slice_dict;
for (const auto &kv : samples_dict)
slice_dict[kv.first] = {0, kv.second.size()};
......@@ -339,7 +337,7 @@ hetero_sample(const vector<node_t> &node_types,
}
if (!directed) { // Construct the subgraph among the sampled nodes:
unordered_map<int64_t, int64_t>::iterator iter;
phmap::flat_hash_map<int64_t, int64_t>::iterator iter;
for (const auto &kv : colptr_dict) {
const auto &rel_type = kv.key();
const auto &edge_type = to_edge_type[rel_type];
......@@ -455,4 +453,4 @@ hetero_temporal_neighbor_sample_cpu(
node_types, edge_types, colptr_dict, row_dict, input_node_dict,
num_neighbors_dict, node_time_dict, num_hops);
}
}
}
\ No newline at end of file
#pragma once
#include "../extensions.h"
#include "parallel_hashmap/phmap.h"
#define CHECK_CPU(x) AT_ASSERTM(x.device().is_cpu(), #x " must be CPU tensor")
#define CHECK_INPUT(x) AT_ASSERTM(x, "Input mismatch")
......@@ -27,7 +28,7 @@ inline torch::Tensor from_vector(const std::vector<scalar_t> &vec,
template <typename key_t, typename scalar_t>
inline c10::Dict<key_t, torch::Tensor>
from_vector(const std::unordered_map<key_t, std::vector<scalar_t>> &vec_dict,
from_vector(const phmap::flat_hash_map<key_t, std::vector<scalar_t>> &vec_dict,
bool inplace = false) {
c10::Dict<key_t, torch::Tensor> out_dict;
for (const auto &kv : vec_dict)
......@@ -91,7 +92,7 @@ template <bool replace>
inline void
uniform_choice(const int64_t population, const int64_t num_samples,
const int64_t *idx_data, std::vector<int64_t> *samples,
std::unordered_map<int64_t, int64_t> *to_local_node) {
phmap::flat_hash_map<int64_t, int64_t> *to_local_node) {
if (population == 0 || num_samples == 0)
return;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment