Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-sparse
Commits
84b46170
Commit
84b46170
authored
Aug 10, 2021
by
rusty1s
Browse files
bugfix
parent
014c4bae
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
23 additions
and
24 deletions
+23
-24
csrc/cpu/neighbor_sample_cpu.cpp
csrc/cpu/neighbor_sample_cpu.cpp
+23
-24
No files found.
csrc/cpu/neighbor_sample_cpu.cpp
View file @
84b46170
...
...
@@ -39,9 +39,8 @@ sample(const torch::Tensor &colptr, const torch::Tensor &row,
if
(
col_count
==
0
)
continue
;
if
(
replace
)
{
for
(
int64_t
j
=
0
;
j
<
num_samples
;
j
++
)
{
const
int64_t
offset
=
col_start
+
rand
()
%
col_count
;
if
((
num_samples
<
0
)
||
(
!
replace
&&
(
num_samples
>=
col_count
)))
{
for
(
int64_t
offset
=
col_start
;
offset
<
col_end
;
offset
++
)
{
const
int64_t
&
v
=
row_data
[
offset
];
const
auto
res
=
to_local_node
.
insert
({
v
,
samples
.
size
()});
if
(
res
.
second
)
...
...
@@ -52,8 +51,9 @@ sample(const torch::Tensor &colptr, const torch::Tensor &row,
edges
.
push_back
(
offset
);
}
}
}
else
if
(
num_samples
>=
col_count
)
{
for
(
int64_t
offset
=
col_start
;
offset
<
col_end
;
offset
++
)
{
}
else
if
(
replace
)
{
for
(
int64_t
j
=
0
;
j
<
num_samples
;
j
++
)
{
const
int64_t
offset
=
col_start
+
rand
()
%
col_count
;
const
int64_t
&
v
=
row_data
[
offset
];
const
auto
res
=
to_local_node
.
insert
({
v
,
samples
.
size
()});
if
(
res
.
second
)
...
...
@@ -111,14 +111,14 @@ sample(const torch::Tensor &colptr, const torch::Tensor &row,
}
template
<
bool
replace
,
bool
directed
>
std
::
tuple
<
c10
::
Dict
<
node_t
,
torch
::
Tensor
>
,
c10
::
Dict
<
rel_t
,
torch
::
Tensor
>
,
c10
::
Dict
<
rel_t
,
torch
::
Tensor
>
,
c10
::
Dict
<
rel_t
,
torch
::
Tensor
>>
hetero_sample
(
const
std
::
vector
<
node_t
>
&
node_types
,
const
std
::
vector
<
edge_t
>
&
edge_types
,
tuple
<
c10
::
Dict
<
node_t
,
torch
::
Tensor
>
,
c10
::
Dict
<
rel_t
,
torch
::
Tensor
>
,
c10
::
Dict
<
rel_t
,
torch
::
Tensor
>
,
c10
::
Dict
<
rel_t
,
torch
::
Tensor
>>
hetero_sample
(
const
vector
<
node_t
>
&
node_types
,
const
vector
<
edge_t
>
&
edge_types
,
const
c10
::
Dict
<
rel_t
,
torch
::
Tensor
>
&
colptr_dict
,
const
c10
::
Dict
<
rel_t
,
torch
::
Tensor
>
&
row_dict
,
const
c10
::
Dict
<
node_t
,
torch
::
Tensor
>
&
input_node_dict
,
const
c10
::
Dict
<
rel_t
,
std
::
vector
<
int64_t
>>
&
num_neighbors_dict
,
const
c10
::
Dict
<
rel_t
,
vector
<
int64_t
>>
&
num_neighbors_dict
,
const
int64_t
num_hops
)
{
// Create a mapping to convert single string relations to edge type triplets:
...
...
@@ -129,9 +129,9 @@ hetero_sample(const std::vector<node_t> &node_types,
// Initialize some data structures for the sampling process:
unordered_map
<
node_t
,
vector
<
int64_t
>>
samples_dict
;
unordered_map
<
node_t
,
unordered_map
<
int64_t
,
int64_t
>>
to_local_node_dict
;
for
(
const
auto
&
k
:
node_types
)
{
samples_dict
[
k
];
to_local_node_dict
[
k
];
for
(
const
auto
&
node_type
:
node_types
)
{
samples_dict
[
node_type
];
to_local_node_dict
[
node_type
];
}
unordered_map
<
rel_t
,
vector
<
int64_t
>>
rows_dict
,
cols_dict
,
edges_dict
;
...
...
@@ -167,7 +167,7 @@ hetero_sample(const std::vector<node_t> &node_types,
const
auto
&
edge_type
=
to_edge_type
[
rel_type
];
const
auto
&
src_node_type
=
get
<
0
>
(
edge_type
);
const
auto
&
dst_node_type
=
get
<
2
>
(
edge_type
);
const
auto
&
num_samples
=
kv
.
value
()[
ell
];
const
auto
num_samples
=
kv
.
value
()[
ell
];
const
auto
&
dst_samples
=
samples_dict
.
at
(
dst_node_type
);
auto
&
src_samples
=
samples_dict
.
at
(
src_node_type
);
auto
&
to_local_src_node
=
to_local_node_dict
.
at
(
src_node_type
);
...
...
@@ -190,9 +190,8 @@ hetero_sample(const std::vector<node_t> &node_types,
if
(
col_count
==
0
)
continue
;
if
(
replace
)
{
for
(
int64_t
j
=
0
;
j
<
num_samples
;
j
++
)
{
const
int64_t
offset
=
col_start
+
rand
()
%
col_count
;
if
((
num_samples
<
0
)
||
(
!
replace
&&
(
num_samples
>=
col_count
)))
{
for
(
int64_t
offset
=
col_start
;
offset
<
col_end
;
offset
++
)
{
const
int64_t
&
v
=
row_data
[
offset
];
const
auto
res
=
to_local_src_node
.
insert
({
v
,
src_samples
.
size
()});
if
(
res
.
second
)
...
...
@@ -203,8 +202,9 @@ hetero_sample(const std::vector<node_t> &node_types,
edges
.
push_back
(
offset
);
}
}
}
else
if
(
num_samples
>=
col_count
)
{
for
(
int64_t
offset
=
col_start
;
offset
<
col_end
;
offset
++
)
{
}
else
if
(
replace
)
{
for
(
int64_t
j
=
0
;
j
<
num_samples
;
j
++
)
{
const
int64_t
offset
=
col_start
+
rand
()
%
col_count
;
const
int64_t
&
v
=
row_data
[
offset
];
const
auto
res
=
to_local_src_node
.
insert
({
v
,
src_samples
.
size
()});
if
(
res
.
second
)
...
...
@@ -302,15 +302,14 @@ neighbor_sample_cpu(const torch::Tensor &colptr, const torch::Tensor &row,
}
}
std
::
tuple
<
c10
::
Dict
<
node_t
,
torch
::
Tensor
>
,
c10
::
Dict
<
rel_t
,
torch
::
Tensor
>
,
c10
::
Dict
<
rel_t
,
torch
::
Tensor
>
,
c10
::
Dict
<
rel_t
,
torch
::
Tensor
>>
tuple
<
c10
::
Dict
<
node_t
,
torch
::
Tensor
>
,
c10
::
Dict
<
rel_t
,
torch
::
Tensor
>
,
c10
::
Dict
<
rel_t
,
torch
::
Tensor
>
,
c10
::
Dict
<
rel_t
,
torch
::
Tensor
>>
hetero_neighbor_sample_cpu
(
const
std
::
vector
<
node_t
>
&
node_types
,
const
std
::
vector
<
edge_t
>
&
edge_types
,
const
vector
<
node_t
>
&
node_types
,
const
vector
<
edge_t
>
&
edge_types
,
const
c10
::
Dict
<
rel_t
,
torch
::
Tensor
>
&
colptr_dict
,
const
c10
::
Dict
<
rel_t
,
torch
::
Tensor
>
&
row_dict
,
const
c10
::
Dict
<
node_t
,
torch
::
Tensor
>
&
input_node_dict
,
const
c10
::
Dict
<
rel_t
,
std
::
vector
<
int64_t
>>
&
num_neighbors_dict
,
const
c10
::
Dict
<
rel_t
,
vector
<
int64_t
>>
&
num_neighbors_dict
,
const
int64_t
num_hops
,
const
bool
replace
,
const
bool
directed
)
{
if
(
replace
&&
directed
)
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment