Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-sparse
Commits
afbfdc97
Unverified
Commit
afbfdc97
authored
Apr 14, 2020
by
Matthias Fey
Committed by
GitHub
Apr 14, 2020
Browse files
Merge pull request #55 from bwdeng20/master
torch_sparse.partition supports weighted graphs
parents
056c0bab
7f8aac48
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
59 additions
and
16 deletions
+59
-16
.gitignore
.gitignore
+1
-0
csrc/cpu/metis_cpu.cpp
csrc/cpu/metis_cpu.cpp
+11
-3
csrc/cpu/metis_cpu.h
csrc/cpu/metis_cpu.h
+1
-0
csrc/metis.cpp
csrc/metis.cpp
+2
-1
test/test_metis.py
test/test_metis.py
+15
-7
torch_sparse/metis.py
torch_sparse/metis.py
+29
-5
No files found.
.gitignore
View file @
afbfdc97
...
...
@@ -6,3 +6,4 @@ dist/
*.egg-info/
.coverage
*.so
.idea/
csrc/cpu/metis_cpu.cpp
View file @
afbfdc97
...
...
@@ -7,25 +7,33 @@
#include "utils.h"
torch
::
Tensor
partition_cpu
(
torch
::
Tensor
rowptr
,
torch
::
Tensor
col
,
torch
::
optional
<
torch
::
Tensor
>
optional_value
,
int64_t
num_parts
,
bool
recursive
)
{
#ifdef WITH_METIS
CHECK_CPU
(
rowptr
);
CHECK_CPU
(
col
);
if
(
optional_value
.
has_value
())
{
CHECK_CPU
(
optional_value
.
value
());
CHECK_INPUT
(
optional_value
.
value
().
dim
()
==
1
);
CHECK_INPUT
(
optional_value
.
value
().
numel
()
==
col
.
numel
());
}
int64_t
nvtxs
=
rowptr
.
numel
()
-
1
;
auto
part
=
torch
::
empty
(
nvtxs
,
rowptr
.
options
());
auto
*
xadj
=
rowptr
.
data_ptr
<
int64_t
>
();
auto
*
adjncy
=
col
.
data_ptr
<
int64_t
>
();
int64_t
*
adjwgt
=
NULL
;
if
(
optional_value
.
has_value
())
adjwgt
=
optional_value
.
value
().
data_ptr
<
int64_t
>
();
int64_t
ncon
=
1
;
int64_t
objval
=
-
1
;
auto
part_data
=
part
.
data_ptr
<
int64_t
>
();
if
(
recursive
)
{
METIS_PartGraphRecursive
(
&
nvtxs
,
&
ncon
,
xadj
,
adjncy
,
NULL
,
NULL
,
NULL
,
METIS_PartGraphRecursive
(
&
nvtxs
,
&
ncon
,
xadj
,
adjncy
,
NULL
,
NULL
,
adjwgt
,
&
num_parts
,
NULL
,
NULL
,
NULL
,
&
objval
,
part_data
);
}
else
{
METIS_PartGraphKway
(
&
nvtxs
,
&
ncon
,
xadj
,
adjncy
,
NULL
,
NULL
,
NULL
,
METIS_PartGraphKway
(
&
nvtxs
,
&
ncon
,
xadj
,
adjncy
,
NULL
,
NULL
,
adjwgt
,
&
num_parts
,
NULL
,
NULL
,
NULL
,
&
objval
,
part_data
);
}
...
...
csrc/cpu/metis_cpu.h
View file @
afbfdc97
...
...
@@ -3,4 +3,5 @@
#include <torch/extension.h>
torch
::
Tensor
partition_cpu
(
torch
::
Tensor
rowptr
,
torch
::
Tensor
col
,
torch
::
optional
<
torch
::
Tensor
>
optional_value
,
int64_t
num_parts
,
bool
recursive
);
csrc/metis.cpp
View file @
afbfdc97
...
...
@@ -8,6 +8,7 @@ PyMODINIT_FUNC PyInit__metis(void) { return NULL; }
#endif
torch
::
Tensor
partition
(
torch
::
Tensor
rowptr
,
torch
::
Tensor
col
,
torch
::
optional
<
torch
::
Tensor
>
optional_value
,
int64_t
num_parts
,
bool
recursive
)
{
if
(
rowptr
.
device
().
is_cuda
())
{
#ifdef WITH_CUDA
...
...
@@ -16,7 +17,7 @@ torch::Tensor partition(torch::Tensor rowptr, torch::Tensor col,
AT_ERROR
(
"Not compiled with CUDA support"
);
#endif
}
else
{
return
partition_cpu
(
rowptr
,
col
,
num_parts
,
recursive
);
return
partition_cpu
(
rowptr
,
col
,
optional_value
,
num_parts
,
recursive
);
}
}
...
...
test/test_metis.py
View file @
afbfdc97
...
...
@@ -7,11 +7,19 @@ from .utils import devices
@
pytest
.
mark
.
parametrize
(
'device'
,
devices
)
def
test_metis
(
device
):
mat
=
SparseTensor
.
from_dense
(
torch
.
randn
((
6
,
6
),
device
=
device
))
mat
,
partptr
,
perm
=
mat
.
partition
(
num_parts
=
2
,
recursive
=
False
)
value1
=
torch
.
randn
(
6
*
6
,
device
=
device
).
view
(
6
,
6
)
value2
=
torch
.
arange
(
6
*
6
,
dtype
=
torch
.
long
,
device
=
device
).
view
(
6
,
6
)
value3
=
torch
.
ones
(
6
*
6
,
device
=
device
).
view
(
6
,
6
)
for
value
in
[
value1
,
value2
,
value3
]:
mat
=
SparseTensor
.
from_dense
(
value
)
_
,
partptr
,
perm
=
mat
.
partition
(
num_parts
=
2
,
recursive
=
False
,
weighted
=
True
)
assert
partptr
.
numel
()
==
3
assert
perm
.
numel
()
==
6
mat
,
partptr
,
perm
=
mat
.
partition
(
num_parts
=
2
,
recursive
=
True
)
_
,
partptr
,
perm
=
mat
.
partition
(
num_parts
=
2
,
recursive
=
False
,
weighted
=
False
)
assert
partptr
.
numel
()
==
3
assert
perm
.
numel
()
==
6
torch_sparse/metis.py
View file @
afbfdc97
from
typing
import
Tuple
from
typing
import
Tuple
,
Optional
import
torch
from
torch_sparse.tensor
import
SparseTensor
from
torch_sparse.permute
import
permute
def
partition
(
src
:
SparseTensor
,
num_parts
:
int
,
recursive
:
bool
=
False
)
->
Tuple
[
SparseTensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
def
weight2metis
(
weight
:
torch
.
Tensor
)
->
Optional
[
torch
.
Tensor
]:
sorted_weight
=
weight
.
sort
()[
0
]
diff
=
sorted_weight
[
1
:]
-
sorted_weight
[:
-
1
]
if
diff
.
sum
()
==
0
:
return
None
weight_min
,
weight_max
=
sorted_weight
[
0
],
sorted_weight
[
-
1
]
srange
=
weight_max
-
weight_min
min_diff
=
diff
.
min
()
scale
=
(
min_diff
/
srange
).
item
()
tick
,
arange
=
scale
.
as_integer_ratio
()
weight_ratio
=
(
weight
-
weight_min
).
div_
(
srange
).
mul_
(
arange
).
add_
(
tick
)
return
weight_ratio
.
to
(
torch
.
long
)
rowptr
,
col
=
src
.
storage
.
rowptr
().
cpu
(),
src
.
storage
.
col
().
cpu
()
cluster
=
torch
.
ops
.
torch_sparse
.
partition
(
rowptr
,
col
,
num_parts
,
def
partition
(
src
:
SparseTensor
,
num_parts
:
int
,
recursive
:
bool
=
False
,
weighted
=
False
)
->
Tuple
[
SparseTensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
rowptr
,
col
,
value
=
src
.
csr
()
rowptr
,
col
=
rowptr
.
cpu
(),
col
.
cpu
()
if
value
is
not
None
and
weighted
:
assert
value
.
numel
()
==
col
.
numel
()
value
=
value
.
view
(
-
1
).
detach
().
cpu
()
if
value
.
is_floating_point
():
value
=
weight2metis
(
value
)
else
:
value
=
None
cluster
=
torch
.
ops
.
torch_sparse
.
partition
(
rowptr
,
col
,
value
,
num_parts
,
recursive
)
cluster
=
cluster
.
to
(
src
.
device
())
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment