Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-sparse
Commits
916ba55b
Unverified
Commit
916ba55b
authored
Aug 09, 2022
by
OlhaBabicheva
Committed by
GitHub
Aug 09, 2022
Browse files
Replace unordered_map with phmap in hetero_sample (#266)
parent
ae22058a
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
11 additions
and
12 deletions
+11
-12
csrc/cpu/neighbor_sample_cpu.cpp
csrc/cpu/neighbor_sample_cpu.cpp
+8
-10
csrc/cpu/utils.h
csrc/cpu/utils.h
+3
-2
No files found.
csrc/cpu/neighbor_sample_cpu.cpp
View file @
916ba55b
...
...
@@ -2,8 +2,6 @@
#include "utils.h"
#include "parallel_hashmap/phmap.h"
#ifdef _WIN32
#include <process.h>
#endif
...
...
@@ -142,21 +140,21 @@ hetero_sample(const vector<node_t> &node_types,
const
int64_t
num_hops
)
{
// Create a mapping to convert single string relations to edge type triplets:
unordered
_map
<
rel_t
,
edge_t
>
to_edge_type
;
phmap
::
flat_hash
_map
<
rel_t
,
edge_t
>
to_edge_type
;
for
(
const
auto
&
k
:
edge_types
)
to_edge_type
[
get
<
0
>
(
k
)
+
"__"
+
get
<
1
>
(
k
)
+
"__"
+
get
<
2
>
(
k
)]
=
k
;
// Initialize some data structures for the sampling process:
unordered
_map
<
node_t
,
vector
<
int64_t
>>
samples_dict
;
unordered_map
<
node_t
,
unordered
_map
<
int64_t
,
int64_t
>>
to_local_node_dict
;
unordered
_map
<
node_t
,
vector
<
int64_t
>>
root_time_dict
;
phmap
::
flat_hash
_map
<
node_t
,
vector
<
int64_t
>>
samples_dict
;
phmap
::
flat_hash_map
<
node_t
,
phmap
::
flat_hash
_map
<
int64_t
,
int64_t
>>
to_local_node_dict
;
phmap
::
flat_hash
_map
<
node_t
,
vector
<
int64_t
>>
root_time_dict
;
for
(
const
auto
&
node_type
:
node_types
)
{
samples_dict
[
node_type
];
to_local_node_dict
[
node_type
];
root_time_dict
[
node_type
];
}
unordered
_map
<
rel_t
,
vector
<
int64_t
>>
rows_dict
,
cols_dict
,
edges_dict
;
phmap
::
flat_hash
_map
<
rel_t
,
vector
<
int64_t
>>
rows_dict
,
cols_dict
,
edges_dict
;
for
(
const
auto
&
kv
:
colptr_dict
)
{
const
auto
&
rel_type
=
kv
.
key
();
rows_dict
[
rel_type
];
...
...
@@ -188,7 +186,7 @@ hetero_sample(const vector<node_t> &node_types,
}
}
unordered
_map
<
node_t
,
pair
<
int64_t
,
int64_t
>>
slice_dict
;
phmap
::
flat_hash
_map
<
node_t
,
pair
<
int64_t
,
int64_t
>>
slice_dict
;
for
(
const
auto
&
kv
:
samples_dict
)
slice_dict
[
kv
.
first
]
=
{
0
,
kv
.
second
.
size
()};
...
...
@@ -339,7 +337,7 @@ hetero_sample(const vector<node_t> &node_types,
}
if
(
!
directed
)
{
// Construct the subgraph among the sampled nodes:
unordered
_map
<
int64_t
,
int64_t
>::
iterator
iter
;
phmap
::
flat_hash
_map
<
int64_t
,
int64_t
>::
iterator
iter
;
for
(
const
auto
&
kv
:
colptr_dict
)
{
const
auto
&
rel_type
=
kv
.
key
();
const
auto
&
edge_type
=
to_edge_type
[
rel_type
];
...
...
@@ -455,4 +453,4 @@ hetero_temporal_neighbor_sample_cpu(
node_types
,
edge_types
,
colptr_dict
,
row_dict
,
input_node_dict
,
num_neighbors_dict
,
node_time_dict
,
num_hops
);
}
}
}
\ No newline at end of file
csrc/cpu/utils.h
View file @
916ba55b
#pragma once
#include "../extensions.h"
#include "parallel_hashmap/phmap.h"
#define CHECK_CPU(x) AT_ASSERTM(x.device().is_cpu(), #x " must be CPU tensor")
#define CHECK_INPUT(x) AT_ASSERTM(x, "Input mismatch")
...
...
@@ -27,7 +28,7 @@ inline torch::Tensor from_vector(const std::vector<scalar_t> &vec,
template
<
typename
key_t
,
typename
scalar_t
>
inline
c10
::
Dict
<
key_t
,
torch
::
Tensor
>
from_vector
(
const
std
::
unordered
_map
<
key_t
,
std
::
vector
<
scalar_t
>>
&
vec_dict
,
from_vector
(
const
phmap
::
flat_hash
_map
<
key_t
,
std
::
vector
<
scalar_t
>>
&
vec_dict
,
bool
inplace
=
false
)
{
c10
::
Dict
<
key_t
,
torch
::
Tensor
>
out_dict
;
for
(
const
auto
&
kv
:
vec_dict
)
...
...
@@ -91,7 +92,7 @@ template <bool replace>
inline
void
uniform_choice
(
const
int64_t
population
,
const
int64_t
num_samples
,
const
int64_t
*
idx_data
,
std
::
vector
<
int64_t
>
*
samples
,
std
::
unordered
_map
<
int64_t
,
int64_t
>
*
to_local_node
)
{
phmap
::
flat_hash
_map
<
int64_t
,
int64_t
>
*
to_local_node
)
{
if
(
population
==
0
||
num_samples
==
0
)
return
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment