Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-sparse
Commits
ebeb4509
Commit
ebeb4509
authored
Mar 30, 2021
by
rusty1s
Browse files
update with root_n_id
parent
54b0a095
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
22 additions
and
10 deletions
+22
-10
csrc/cpu/ego_sample_cpu.cpp
csrc/cpu/ego_sample_cpu.cpp
+8
-4
csrc/cpu/ego_sample_cpu.h
csrc/cpu/ego_sample_cpu.h
+1
-1
csrc/ego_sample.cpp
csrc/ego_sample.cpp
+2
-2
test/test_ego_sample.py
test/test_ego_sample.py
+11
-3
No files found.
csrc/cpu/ego_sample_cpu.cpp
View file @
ebeb4509
...
...
@@ -8,9 +8,9 @@ inline torch::Tensor vec2tensor(std::vector<int64_t> vec) {
return
torch
::
from_blob
(
vec
.
data
(),
{(
int64_t
)
vec
.
size
()},
at
::
kLong
).
clone
();
}
// Returns `rowptr`, `col`, `n_id`, `e_id`, `ptr`
// Returns `rowptr`, `col`, `n_id`, `e_id`, `ptr`
, `root_n_id`
std
::
tuple
<
torch
::
Tensor
,
torch
::
Tensor
,
torch
::
Tensor
,
torch
::
Tensor
,
torch
::
Tensor
>
torch
::
Tensor
,
torch
::
Tensor
>
ego_k_hop_sample_adj_cpu
(
torch
::
Tensor
rowptr
,
torch
::
Tensor
col
,
torch
::
Tensor
idx
,
int64_t
depth
,
int64_t
num_neighbors
,
bool
replace
)
{
...
...
@@ -19,12 +19,13 @@ ego_k_hop_sample_adj_cpu(torch::Tensor rowptr, torch::Tensor col,
std
::
vector
<
torch
::
Tensor
>
out_cols
(
idx
.
numel
());
std
::
vector
<
torch
::
Tensor
>
out_n_ids
(
idx
.
numel
());
std
::
vector
<
torch
::
Tensor
>
out_e_ids
(
idx
.
numel
());
auto
out_root_n_id
=
torch
::
empty
({
idx
.
numel
()},
at
::
kLong
);
out_rowptrs
[
0
]
=
torch
::
zeros
({
1
},
at
::
kLong
);
auto
rowptr_data
=
rowptr
.
data_ptr
<
int64_t
>
();
auto
col_data
=
col
.
data_ptr
<
int64_t
>
();
auto
idx_data
=
idx
.
data_ptr
<
int64_t
>
();
auto
out_root_n_id_data
=
out_root_n_id
.
data_ptr
<
int64_t
>
();
at
::
parallel_for
(
0
,
idx
.
numel
(),
1
,
[
&
](
int64_t
begin
,
int64_t
end
)
{
int64_t
row_start
,
row_end
,
row_count
,
vec_start
,
vec_end
,
v
,
w
;
...
...
@@ -82,6 +83,8 @@ ego_k_hop_sample_adj_cpu(torch::Tensor rowptr, torch::Tensor col,
i
++
;
}
out_root_n_id_data
[
g
]
=
n_id_map
[
idx_data
[
g
]];
std
::
vector
<
int64_t
>
rowptrs
,
cols
,
e_ids
;
for
(
int64_t
v
:
n_ids
)
{
row_start
=
rowptr_data
[
v
],
row_end
=
rowptr_data
[
v
+
1
];
...
...
@@ -114,11 +117,12 @@ ego_k_hop_sample_adj_cpu(torch::Tensor rowptr, torch::Tensor col,
out_rowptrs
[
g
+
1
].
add_
(
edge_cumsum
);
out_cols
[
g
].
add_
(
node_cumsum
);
out_ptr_data
[
g
]
=
node_cumsum
;
out_root_n_id_data
[
g
]
+=
node_cumsum
;
}
node_cumsum
+=
out_n_ids
[
idx
.
numel
()
-
1
].
numel
();
out_ptr_data
[
idx
.
numel
()]
=
node_cumsum
;
return
std
::
make_tuple
(
torch
::
cat
(
out_rowptrs
,
0
),
torch
::
cat
(
out_cols
,
0
),
torch
::
cat
(
out_n_ids
,
0
),
torch
::
cat
(
out_e_ids
,
0
),
out_ptr
);
out_ptr
,
out_root_n_id
);
}
csrc/cpu/ego_sample_cpu.h
View file @
ebeb4509
...
...
@@ -3,7 +3,7 @@
#include <torch/extension.h>
std
::
tuple
<
torch
::
Tensor
,
torch
::
Tensor
,
torch
::
Tensor
,
torch
::
Tensor
,
torch
::
Tensor
>
torch
::
Tensor
,
torch
::
Tensor
>
ego_k_hop_sample_adj_cpu
(
torch
::
Tensor
rowptr
,
torch
::
Tensor
col
,
torch
::
Tensor
idx
,
int64_t
depth
,
int64_t
num_neighbors
,
bool
replace
);
csrc/ego_sample.cpp
View file @
ebeb4509
...
...
@@ -11,9 +11,9 @@ PyMODINIT_FUNC PyInit__ego_sample_cpu(void) { return NULL; }
#endif
#endif
// Returns `rowptr`, `col`, `n_id`, `e_id`, `ptr`
// Returns `rowptr`, `col`, `n_id`, `e_id`, `ptr`
, `root_n_id`
std
::
tuple
<
torch
::
Tensor
,
torch
::
Tensor
,
torch
::
Tensor
,
torch
::
Tensor
,
torch
::
Tensor
>
torch
::
Tensor
,
torch
::
Tensor
>
ego_k_hop_sample_adj
(
torch
::
Tensor
rowptr
,
torch
::
Tensor
col
,
torch
::
Tensor
idx
,
int64_t
depth
,
int64_t
num_neighbors
,
bool
replace
)
{
if
(
rowptr
.
device
().
is_cuda
())
{
...
...
test/test_ego_sample.py
View file @
ebeb4509
...
...
@@ -8,7 +8,15 @@ def test_ego_k_hop_sample_adj():
col
=
torch
.
tensor
([
1
,
2
,
3
,
0
,
2
,
0
,
1
,
4
,
5
,
0
,
2
,
5
,
2
,
4
])
_
=
SparseTensor
(
row
=
row
,
col
=
col
,
sparse_sizes
=
(
6
,
6
))
idx
=
torch
.
tensor
([
2
])
nid
=
torch
.
tensor
([
0
,
1
])
fn
=
torch
.
ops
.
torch_sparse
.
ego_k_hop_sample_adj
fn
(
rowptr
,
col
,
idx
,
1
,
3
,
False
)
out
=
fn
(
rowptr
,
col
,
nid
,
1
,
3
,
False
)
rowptr
,
col
,
nid
,
eid
,
ptr
,
root_n_id
=
out
assert
nid
.
tolist
()
==
[
0
,
1
,
2
,
3
,
0
,
1
,
2
]
assert
rowptr
.
tolist
()
==
[
0
,
3
,
5
,
7
,
8
,
10
,
12
,
14
]
# row [0, 0, 0, 1, 1, 2, 2, 3, 4, 4, 5, 5, 6, 6]
assert
col
.
tolist
()
==
[
1
,
2
,
3
,
0
,
2
,
0
,
1
,
0
,
5
,
6
,
4
,
6
,
4
,
5
]
assert
eid
.
tolist
()
==
[
0
,
1
,
2
,
3
,
4
,
5
,
6
,
9
,
0
,
1
,
3
,
4
,
5
,
6
]
assert
ptr
.
tolist
()
==
[
0
,
4
,
7
]
assert
root_n_id
.
tolist
()
==
[
0
,
5
]
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment