OpenDAS / torch-sparse, commit c30d2e13
Authored Apr 12, 2020 by bowendeng
Parent: ab0cee58

utils to support METIS with adj_wgt
Showing 5 changed files with 31 additions and 12 deletions.
csrc/cpu/metis_cpu.cpp   +3 -4
csrc/metis.cpp           +3 -3
setup.py                 +1 -1
torch_sparse/metis.py    +17 -4
torch_sparse/utils.py    +7 -0
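Taken together, the changes thread a SparseTensor's stored values through to METIS as adjacency weights. A minimal usage sketch, assuming the extension was compiled with METIS support; the graph data and the names `out`, `ptr`, `perm` are illustrative and not part of the commit:

    import torch
    from torch_sparse import SparseTensor
    from torch_sparse.metis import partition

    # Tiny symmetric weighted graph in COO form (illustrative data).
    row = torch.tensor([0, 1, 1, 2, 2, 3])
    col = torch.tensor([1, 0, 2, 1, 3, 2])
    val = torch.tensor([0.5, 0.5, 1.0, 1.0, 0.25, 0.25])
    adj = SparseTensor(row=row, col=col, value=val, sparse_sizes=(4, 4))

    # With this commit the stored values are rescaled to integer weights
    # (via metis_wgt) and forwarded to METIS as adjwgt.
    out, ptr, perm = partition(adj, num_parts=2, recursive=False)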
csrc/cpu/metis_cpu.cpp
@@ -6,9 +6,8 @@
 #include "utils.h"

-torch::Tensor partition_cpu(torch::Tensor rowptr, torch::Tensor col,
-                            int64_t num_parts, torch::Tensor adjwgt,
-                            bool recursive) {
+torch::Tensor partition_cpu(torch::Tensor rowptr, torch::Tensor col,
+                            int64_t num_parts, torch::Tensor edge_wgt, bool recursive) {
 #ifdef WITH_METIS
   CHECK_CPU(rowptr);
   CHECK_CPU(col);
@@ -18,7 +17,7 @@ torch::Tensor partition_cpu(torch::Tensor rowptr, torch::Tensor col,
   auto *xadj = rowptr.data_ptr<int64_t>();
   auto *adjncy = col.data_ptr<int64_t>();
-  auto *adjwgt = adjwgt.data_ptr<int64_t>();
+  auto *adjwgt = edge_wgt.data_ptr<int64_t>();
   int64_t ncon = 1;
   int64_t objval = -1;
   auto part_data = part.data_ptr<int64_t>();
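The renamed argument also changes what the registered op expects: torch.ops.torch_sparse.partition now takes an int64 weight tensor between num_parts and recursive. A hedged sketch of calling the op directly; the CSR arrays below are illustrative, the op is only available when the extension is built with METIS, and METIS expects positive integer weights:

    import torch
    import torch_sparse  # noqa: loads the compiled extension and registers the op

    # CSR of a small undirected path graph 0-1-2 with symmetric integer weights.
    rowptr = torch.tensor([0, 1, 3, 4], dtype=torch.long)
    col = torch.tensor([1, 0, 2, 1], dtype=torch.long)
    edge_wgt = torch.tensor([2, 2, 5, 5], dtype=torch.long)

    cluster = torch.ops.torch_sparse.partition(rowptr, col, 2, edge_wgt, False)
    # cluster[i] is the METIS part assigned to node i.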
csrc/metis.cpp
@@ -7,8 +7,8 @@
 PyMODINIT_FUNC PyInit__metis(void) { return NULL; }
 #endif

-torch::Tensor partition(torch::Tensor rowptr, torch::Tensor col,
-                        int64_t num_parts, bool recursive) {
+torch::Tensor partition(torch::Tensor rowptr, torch::Tensor col, int64_t num_parts,
+                        torch::Tensor edge_wgt, bool recursive) {
   if (rowptr.device().is_cuda()) {
 #ifdef WITH_CUDA
     AT_ERROR("No CUDA version supported");
@@ -16,7 +16,7 @@ torch::Tensor partition(torch::Tensor rowptr, torch::Tensor col,
     AT_ERROR("Not compiled with CUDA support");
 #endif
   } else {
-    return partition_cpu(rowptr, col, num_parts, recursive);
+    return partition_cpu(rowptr, col, num_parts, edge_wgt, recursive);
   }
 }
setup.py
@@ -80,7 +80,7 @@ tests_require = ['pytest', 'pytest-cov']
 setup(
     name='torch_sparse',
-    version='0.6.0',
+    version='0.6.1',
     author='Matthias Fey',
     author_email='matthias.fey@tu-dortmund.de',
     url='https://github.com/rusty1s/pytorch_sparse',
torch_sparse/metis.py
@@ -3,15 +3,28 @@ from typing import Tuple
 import torch

 from torch_sparse.tensor import SparseTensor
 from torch_sparse.permute import permute
+from torch_sparse.utils import cartesian1d
+
+
+def metis_wgt(x):
+    t1, t2 = cartesian1d(x, x)
+    diff = t1 - t2
+    diff = diff[diff != 0]
+    res = diff.abs().min()
+    bod = x.max() - x.min()
+    scale = (res / bod).item()
+    tick, arange = scale.as_integer_ratio()
+    x_ratio = (x - x.min()) / bod
+    return (x_ratio * arange + tick).long(), tick, arange


 def partition(src: SparseTensor, num_parts: int, recursive: bool = False
               ) -> Tuple[SparseTensor, torch.Tensor, torch.Tensor]:

     rowptr, col = src.storage.rowptr().cpu(), src.storage.col().cpu()
-    adjwgt = src.storage.value().cpu()
-    cluster = torch.ops.torch_sparse.partition(rowptr, col, num_parts, adjwgt,
-                                               recursive)
+    edge_wgt = src.storage.value().cpu()
+    edge_wgt = metis_wgt(edge_wgt)[0]
+    cluster = torch.ops.torch_sparse.partition(rowptr, col, num_parts,
+                                               edge_wgt, recursive)
     cluster = cluster.to(src.device())
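METIS works on integer adjacency weights, so the new metis_wgt helper maps real-valued edge weights onto positive int64 weights while keeping their relative spacing: it takes the smallest non-zero pairwise difference, normalizes it by the weight range, splits that ratio into a numerator/denominator pair with float.as_integer_ratio(), and maps each weight to x_ratio * denominator + numerator, which is positive by construction. A small illustration; the input tensor is made up and the exact integers depend on as_integer_ratio:

    import torch
    from torch_sparse.metis import metis_wgt

    w = torch.tensor([0.1, 0.2, 0.4, 0.8])
    w_int, tick, arange = metis_wgt(w)
    # w_int is an int64 tensor of strictly positive weights whose relative
    # spacing mirrors the original float weights; tick and arange are the
    # numerator/denominator of the smallest normalized gap.
    assert w_int.dtype == torch.int64 and bool((w_int > 0).all())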
torch_sparse/utils.py
 from typing import Any
+import torch

 try:
     from typing_extensions import Final  # noqa
@@ -8,3 +9,9 @@ except ImportError:

 def is_scalar(other: Any) -> bool:
     return isinstance(other, int) or isinstance(other, float)
+
+
+def cartesian1d(x, y):
+    a1, a2 = torch.meshgrid([x, y])
+    coos = torch.stack([a1, a2]).T.reshape(-1, 2)
+    return coos.split(1, dim=1)
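cartesian1d returns the two coordinate columns of the Cartesian product of two 1-D tensors (built via torch.meshgrid); metis_wgt uses it to form all pairwise weight differences. A quick usage example with made-up inputs:

    import torch
    from torch_sparse.utils import cartesian1d

    x = torch.tensor([1., 2.])
    y = torch.tensor([10., 20., 30.])
    a, b = cartesian1d(x, y)
    # a and b are (6, 1) column tensors pairing every element of x with every
    # element of y:
    #   a: [[1.], [2.], [1.], [2.], [1.], [2.]]
    #   b: [[10.], [10.], [20.], [20.], [30.], [30.]]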