OpenDAS / torch-sparse / Commits

Commit fff381c5
authored Mar 19, 2020 by rusty1s

added saint extract_adj method

parent 92b1e639
Showing 6 changed files with 115 additions and 1 deletion (+115 -1)
csrc/cpu/saint_cpu.cpp      +49 -0
csrc/cpu/saint_cpu.h         +7 -0
csrc/saint.cpp              +25 -0
test/test_saint.py          +11 -0
torch_sparse/__init__.py     +1 -1
torch_sparse/saint.py       +22 -0
csrc/cpu/saint_cpu.cpp  (new file, 0 → 100644)
#include "saint_cpu.h"
#include "utils.h"
std
::
tuple
<
torch
::
Tensor
,
torch
::
Tensor
,
torch
::
Tensor
>
subgraph_cpu
(
torch
::
Tensor
idx
,
torch
::
Tensor
rowptr
,
torch
::
Tensor
row
,
torch
::
Tensor
col
)
{
CHECK_CPU
(
idx
);
CHECK_CPU
(
rowptr
);
CHECK_CPU
(
col
);
CHECK_INPUT
(
idx
.
dim
()
==
1
);
CHECK_INPUT
(
rowptr
.
dim
()
==
1
);
CHECK_INPUT
(
col
.
dim
()
==
1
);
auto
assoc
=
torch
::
full
({
rowptr
.
size
(
0
)
-
1
},
-
1
,
idx
.
options
());
assoc
.
index_copy_
(
0
,
idx
,
torch
::
arange
(
idx
.
size
(
0
),
idx
.
options
()));
auto
idx_data
=
idx
.
data_ptr
<
int64_t
>
();
auto
rowptr_data
=
rowptr
.
data_ptr
<
int64_t
>
();
auto
col_data
=
col
.
data_ptr
<
int64_t
>
();
auto
assoc_data
=
assoc
.
data_ptr
<
int64_t
>
();
std
::
vector
<
int64_t
>
rows
,
cols
,
indices
;
int64_t
v
,
w
,
w_new
,
row_start
,
row_end
;
for
(
int64_t
v_new
=
0
;
v_new
<
idx
.
size
(
0
);
v_new
++
)
{
v
=
idx_data
[
v_new
];
row_start
=
rowptr_data
[
v
];
row_end
=
rowptr_data
[
v
+
1
];
for
(
int64_t
j
=
row_start
;
j
<
row_end
;
j
++
)
{
w
=
col_data
[
j
];
w_new
=
assoc_data
[
w
];
if
(
w_new
>
-
1
)
{
rows
.
push_back
(
v_new
);
cols
.
push_back
(
w_new
);
indices
.
push_back
(
j
);
}
}
}
int64_t
length
=
rows
.
size
();
row
=
torch
::
from_blob
(
rows
.
data
(),
{
length
},
row
.
options
()).
clone
();
col
=
torch
::
from_blob
(
cols
.
data
(),
{
length
},
row
.
options
()).
clone
();
idx
=
torch
::
from_blob
(
indices
.
data
(),
{
length
},
row
.
options
()).
clone
();
return
std
::
make_tuple
(
row
,
col
,
idx
);
}
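For reference, subgraph_cpu builds a global-to-local node mapping (assoc), walks the CSR rows of the sampled nodes, and keeps only those edges whose target node is also sampled, recording the position of each kept edge in the original adjacency. A minimal Python sketch of the same relabel-and-filter logic (illustrative only, not part of this commit; assumes int64 CSR inputs):

import torch

def subgraph_py(idx, rowptr, col):
    # Hypothetical pure-Python sketch of subgraph_cpu (not part of the commit).
    # Map each kept global node id to its new local id; -1 marks dropped nodes.
    assoc = torch.full((rowptr.numel() - 1,), -1, dtype=torch.long)
    assoc[idx] = torch.arange(idx.numel())

    rows, cols, edge_ids = [], [], []
    for v_new, v in enumerate(idx.tolist()):
        # Scan the CSR row of the original node v.
        for j in range(rowptr[v].item(), rowptr[v + 1].item()):
            w_new = assoc[col[j]].item()
            if w_new > -1:  # neighbour is also kept -> keep the relabelled edge
                rows.append(v_new)
                cols.append(w_new)
                edge_ids.append(j)  # position of this edge in the original graph
    return torch.tensor(rows), torch.tensor(cols), torch.tensor(edge_ids)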
csrc/cpu/saint_cpu.h  (new file, 0 → 100644)
#pragma once

#include <torch/extension.h>

std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>
subgraph_cpu(torch::Tensor idx, torch::Tensor rowptr, torch::Tensor row,
             torch::Tensor col);
csrc/saint.cpp  (new file, 0 → 100644)
#include <Python.h>
#include <torch/script.h>

#include "cpu/saint_cpu.h"

#ifdef _WIN32
PyMODINIT_FUNC PyInit__saint(void) { return NULL; }
#endif

std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>
subgraph(torch::Tensor idx, torch::Tensor rowptr, torch::Tensor row,
         torch::Tensor col) {
  if (idx.device().is_cuda()) {
#ifdef WITH_CUDA
    AT_ERROR("No CUDA version supported");
#else
    AT_ERROR("Not compiled with CUDA support");
#endif
  } else {
    return subgraph_cpu(idx, rowptr, row, col);
  }
}

static auto registry =
    torch::RegisterOperators().op("torch_sparse::saint_subgraph", &subgraph);
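Once the _saint extension library is loaded (see the torch_sparse/__init__.py change below), the registered operator is callable from Python through torch.ops. A minimal sketch, assuming a build in which the compiled extension is importable:

import torch
import torch_sparse  # noqa: loads the compiled extension libraries, incl. _saint

rowptr = torch.tensor([0, 2, 4, 7, 9, 10])
row = torch.tensor([0, 0, 1, 1, 2, 2, 2, 3, 3, 4])
col = torch.tensor([1, 2, 0, 2, 0, 1, 3, 2, 4, 3])
node_idx = torch.tensor([0, 1, 2])

# Dispatches to subgraph_cpu for CPU tensors; CUDA inputs currently raise an error.
sub_row, sub_col, edge_index = torch.ops.torch_sparse.saint_subgraph(
    node_idx, rowptr, row, col)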
test/test_saint.py
import pytest
import torch

from torch_sparse.tensor import SparseTensor
from torch_sparse.saint import subgraph

from .utils import devices


@pytest.mark.parametrize('device', devices)
def test_subgraph(device):
    row = torch.tensor([0, 0, 1, 1, 2, 2, 2, 3, 3, 4])
    col = torch.tensor([1, 2, 0, 2, 0, 1, 3, 2, 4, 3])
    adj = SparseTensor(row=row, col=col).to(device)

    node_idx = torch.tensor([0, 1, 2])
    adj, edge_index = subgraph(adj, node_idx)


@pytest.mark.parametrize('device', devices)
def test_sample_node(device):
    row = torch.tensor([0, 0, 1, 1, 2, 2, 2, 3, 3, 4])
    ...
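The new test_subgraph only checks that the call runs. Based on the relabel-and-filter logic in subgraph_cpu, node_idx = [0, 1, 2] induces the triangle between nodes 0, 1 and 2, so a hedged sketch of assertions that could follow the subgraph call (not part of the commit) would be:

# Hypothetical follow-up assertions (not in the commit); expected values follow
# from the relabelling in subgraph_cpu for node_idx = [0, 1, 2].
assert adj.sparse_sizes() == (3, 3)
sub_row, sub_col, _ = adj.coo()
assert sub_row.tolist() == [0, 0, 1, 1, 2, 2]
assert sub_col.tolist() == [1, 2, 0, 2, 0, 1]
assert edge_index.tolist() == [0, 1, 2, 3, 4, 5]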
torch_sparse/__init__.py

...
@@ -9,7 +9,7 @@ expected_torch_version = (1, 4)
try:
    for library in [
            '_version', '_convert', '_diag', '_spmm', '_spspmm', '_metis',
            '_rw', '_saint'
    ]:
        torch.ops.load_library(importlib.machinery.PathFinder().find_spec(
            library, [osp.dirname(__file__)]).origin)
...
torch_sparse/saint.py

...
@@ -6,6 +6,28 @@ from torch_scatter import scatter_add
from torch_sparse.tensor import SparseTensor


def subgraph(src: SparseTensor, node_idx: torch.Tensor
             ) -> Tuple[SparseTensor, torch.Tensor]:
    row, col, value = src.coo()
    rowptr = src.storage.rowptr()

    data = torch.ops.torch_sparse.saint_subgraph(node_idx, rowptr, row, col)
    row, col, edge_index = data

    if value is not None:
        value = value[edge_index]

    out = SparseTensor(row=row, rowptr=None, col=col, value=value,
                       sparse_sizes=(node_idx.size(0), node_idx.size(0)),
                       is_sorted=True)
    return out, edge_index


def sample_node(src: SparseTensor,
                num_nodes: int) -> Tuple[SparseTensor, torch.Tensor]:
    row, col, _ = src.coo()
    ...
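Putting the pieces together, a hedged end-to-end usage sketch of the new Python API (tensor values are illustrative, not from the commit):

import torch
from torch_sparse import SparseTensor
from torch_sparse.saint import subgraph

row = torch.tensor([0, 0, 1, 1, 2, 2, 2, 3, 3, 4])
col = torch.tensor([1, 2, 0, 2, 0, 1, 3, 2, 4, 3])
value = torch.rand(row.numel())  # per-edge weights
adj = SparseTensor(row=row, col=col, value=value)

node_idx = torch.tensor([0, 1, 2])
sub_adj, edge_index = subgraph(adj, node_idx)

# sub_adj is the induced 3x3 subgraph over node_idx; edge_index maps each kept
# edge back to its position in the original adjacency, so the sliced values
# value[edge_index] are carried over as sub_adj's edge values.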