Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-sparse
Commits
e4cac317
"tests/vscode:/vscode.git/clone" did not exist on "899a3192b6d34f892c35764cda581fb9f7fffd9c"
Commit
e4cac317
authored
Jul 13, 2021
by
rusty1s
Browse files
fix
parent
e2dca775
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
11 additions
and
5 deletions
+11
-5
csrc/cpu/hgt_sample_cpu.cpp
csrc/cpu/hgt_sample_cpu.cpp
+11
-5
No files found.
csrc/cpu/hgt_sample_cpu.cpp
View file @
e4cac317
...
...
@@ -119,19 +119,26 @@ hgt_sample_cpu(const c10::Dict<rel_t, torch::Tensor> &rowptr_dict,
rel_to_edge_type
[
rel_type
]
=
split
(
rel_type
);
}
// Initialize various data structures for the sampling process, and add the
// input nodes to the final sampled output set (line 1):
// Initialize various data structures for the sampling process:
std
::
unordered_map
<
node_t
,
std
::
vector
<
int64_t
>>
sampled_nodes_dict
;
std
::
unordered_map
<
node_t
,
std
::
unordered_map
<
int64_t
,
int64_t
>>
global_to_local_node_dict
;
std
::
unordered_map
<
node_t
,
std
::
unordered_map
<
int64_t
,
float
>>
budget_dict
;
for
(
const
auto
&
kv
:
num_samples_dict
)
{
const
auto
&
node_type
=
kv
.
key
();
sampled_nodes_dict
[
node_type
];
global_to_local_node_dict
[
node_type
];
budget_dict
[
node_type
];
}
// Add all input nodes of every node type to the sampled output set (line 1):
for
(
const
auto
&
kv
:
input_node_dict
)
{
const
auto
&
node_type
=
kv
.
key
();
const
auto
&
input_node
=
kv
.
value
();
const
auto
*
input_node_data
=
input_node
.
data_ptr
<
int64_t
>
();
auto
&
sampled_nodes
=
sampled_nodes_dict
[
node_type
]
;
auto
&
global_to_local_node
=
global_to_local_node_dict
[
node_type
]
;
auto
&
sampled_nodes
=
sampled_nodes_dict
.
at
(
node_type
)
;
auto
&
global_to_local_node
=
global_to_local_node_dict
.
at
(
node_type
)
;
// Add each origin node to the sampled output nodes:
for
(
int64_t
i
=
0
;
i
<
input_node
.
numel
();
i
++
)
{
...
...
@@ -143,7 +150,6 @@ hgt_sample_cpu(const c10::Dict<rel_t, torch::Tensor> &rowptr_dict,
// Update budget after all input nodes have been added to the sampled output
// set (line 2-5):
std
::
unordered_map
<
node_t
,
std
::
unordered_map
<
int64_t
,
float
>>
budget_dict
;
for
(
const
auto
&
kv
:
sampled_nodes_dict
)
{
update_budget
(
&
budget_dict
,
kv
.
first
,
kv
.
second
,
global_to_local_node_dict
,
rel_to_edge_type
,
rowptr_dict
,
col_dict
,
false
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment