Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-sparse
Commits
4d49d44e
Unverified
Commit
4d49d44e
authored
Mar 26, 2021
by
Matthias Fey
Committed by
GitHub
Mar 26, 2021
Browse files
Merge pull request #123 from Spazierganger/vwgts
add node weights for metis wrapper
parents
54d8418e
cb1e30da
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
73 additions
and
35 deletions
+73
-35
csrc/cpu/metis_cpu.cpp
csrc/cpu/metis_cpu.cpp
+35
-8
csrc/cpu/metis_cpu.h
csrc/cpu/metis_cpu.h
+6
-4
csrc/metis.cpp
csrc/metis.cpp
+6
-3
test/test_metis.py
test/test_metis.py
+16
-16
torch_sparse/metis.py
torch_sparse/metis.py
+10
-4
No files found.
csrc/cpu/metis_cpu.cpp
View file @
4d49d44e
...
...
@@ -12,32 +12,46 @@
torch
::
Tensor
partition_cpu
(
torch
::
Tensor
rowptr
,
torch
::
Tensor
col
,
torch
::
optional
<
torch
::
Tensor
>
optional_value
,
torch
::
optional
<
torch
::
Tensor
>
optional_node_weight
,
int64_t
num_parts
,
bool
recursive
)
{
#ifdef WITH_METIS
CHECK_CPU
(
rowptr
);
CHECK_CPU
(
col
);
if
(
optional_value
.
has_value
())
{
CHECK_CPU
(
optional_value
.
value
());
CHECK_INPUT
(
optional_value
.
value
().
dim
()
==
1
);
CHECK_INPUT
(
optional_value
.
value
().
numel
()
==
col
.
numel
());
}
if
(
optional_node_weight
.
has_value
())
{
CHECK_CPU
(
optional_node_weight
.
value
());
CHECK_INPUT
(
optional_node_weight
.
value
().
dim
()
==
1
);
CHECK_INPUT
(
optional_node_weight
.
value
().
numel
()
==
rowptr
.
numel
()
-
1
);
}
int64_t
nvtxs
=
rowptr
.
numel
()
-
1
;
int64_t
ncon
=
1
;
auto
*
xadj
=
rowptr
.
data_ptr
<
int64_t
>
();
auto
*
adjncy
=
col
.
data_ptr
<
int64_t
>
();
int64_t
*
adjwgt
=
NULL
;
if
(
optional_value
.
has_value
())
adjwgt
=
optional_value
.
value
().
data_ptr
<
int64_t
>
();
int64_t
*
vwgt
=
NULL
;
if
(
optional_node_weight
.
has_value
())
vwgt
=
optional_node_weight
.
value
().
data_ptr
<
int64_t
>
();
int64_t
objval
=
-
1
;
auto
part
=
torch
::
empty
(
nvtxs
,
rowptr
.
options
());
auto
part_data
=
part
.
data_ptr
<
int64_t
>
();
if
(
recursive
)
{
METIS_PartGraphRecursive
(
&
nvtxs
,
&
ncon
,
xadj
,
adjncy
,
NULL
,
NULL
,
adjwgt
,
METIS_PartGraphRecursive
(
&
nvtxs
,
&
ncon
,
xadj
,
adjncy
,
vwgt
,
NULL
,
adjwgt
,
&
num_parts
,
NULL
,
NULL
,
NULL
,
&
objval
,
part_data
);
}
else
{
METIS_PartGraphKway
(
&
nvtxs
,
&
ncon
,
xadj
,
adjncy
,
NULL
,
NULL
,
adjwgt
,
METIS_PartGraphKway
(
&
nvtxs
,
&
ncon
,
xadj
,
adjncy
,
vwgt
,
NULL
,
adjwgt
,
&
num_parts
,
NULL
,
NULL
,
NULL
,
&
objval
,
part_data
);
}
...
...
@@ -50,10 +64,11 @@ torch::Tensor partition_cpu(torch::Tensor rowptr, torch::Tensor col,
// needs mt-metis installed via:
// ./configure --shared --edges64bit --vertices64bit --weights64bit
// --partitions64bit
torch
::
Tensor
mt_partition_cpu
(
torch
::
Tensor
rowptr
,
torch
::
Tensor
col
,
torch
::
optional
<
torch
::
Tensor
>
optional_value
,
int64_t
num_parts
,
bool
recursive
,
int64_t
num_workers
)
{
torch
::
Tensor
mt_partition_cpu
(
torch
::
Tensor
rowptr
,
torch
::
Tensor
col
,
torch
::
optional
<
torch
::
Tensor
>
optional_value
,
torch
::
optional
<
torch
::
Tensor
>
optional_node_weight
,
int64_t
num_parts
,
bool
recursive
,
int64_t
num_workers
)
{
#ifdef WITH_MTMETIS
CHECK_CPU
(
rowptr
);
CHECK_CPU
(
col
);
...
...
@@ -63,13 +78,25 @@ torch::Tensor mt_partition_cpu(torch::Tensor rowptr, torch::Tensor col,
CHECK_INPUT
(
optional_value
.
value
().
numel
()
==
col
.
numel
());
}
if
(
optional_node_weight
.
has_value
())
{
CHECK_CPU
(
optional_node_weight
.
value
());
CHECK_INPUT
(
optional_node_weight
.
value
().
dim
()
==
1
);
CHECK_INPUT
(
optional_node_weight
.
value
().
numel
()
==
rowptr
.
numel
()
-
1
);
}
mtmetis_vtx_type
nvtxs
=
rowptr
.
numel
()
-
1
;
mtmetis_vtx_type
ncon
=
1
;
mtmetis_adj_type
*
xadj
=
(
mtmetis_adj_type
*
)
rowptr
.
data_ptr
<
int64_t
>
();
mtmetis_vtx_type
*
adjncy
=
(
mtmetis_vtx_type
*
)
col
.
data_ptr
<
int64_t
>
();
mtmetis_wgt_type
*
adjwgt
=
NULL
;
if
(
optional_value
.
has_value
())
adjwgt
=
optional_value
.
value
().
data_ptr
<
int64_t
>
();
mtmetis_wgt_type
*
vwgt
=
NULL
;
if
(
optional_node_weight
.
has_value
())
vwgt
=
optional_node_weight
.
value
().
data_ptr
<
int64_t
>
();
mtmetis_pid_type
nparts
=
num_parts
;
mtmetis_wgt_type
objval
=
-
1
;
auto
part
=
torch
::
empty
(
nvtxs
,
rowptr
.
options
());
...
...
@@ -79,10 +106,10 @@ torch::Tensor mt_partition_cpu(torch::Tensor rowptr, torch::Tensor col,
opts
[
MTMETIS_OPTION_NTHREADS
]
=
num_workers
;
if
(
recursive
)
{
MTMETIS_PartGraphRecursive
(
&
nvtxs
,
&
ncon
,
xadj
,
adjncy
,
NULL
,
NULL
,
adjwgt
,
MTMETIS_PartGraphRecursive
(
&
nvtxs
,
&
ncon
,
xadj
,
adjncy
,
vwgt
,
NULL
,
adjwgt
,
&
nparts
,
NULL
,
NULL
,
opts
,
&
objval
,
part_data
);
}
else
{
MTMETIS_PartGraphKway
(
&
nvtxs
,
&
ncon
,
xadj
,
adjncy
,
NULL
,
NULL
,
adjwgt
,
MTMETIS_PartGraphKway
(
&
nvtxs
,
&
ncon
,
xadj
,
adjncy
,
vwgt
,
NULL
,
adjwgt
,
&
nparts
,
NULL
,
NULL
,
opts
,
&
objval
,
part_data
);
}
...
...
csrc/cpu/metis_cpu.h
View file @
4d49d44e
...
...
@@ -4,9 +4,11 @@
torch
::
Tensor
partition_cpu
(
torch
::
Tensor
rowptr
,
torch
::
Tensor
col
,
torch
::
optional
<
torch
::
Tensor
>
optional_value
,
torch
::
optional
<
torch
::
Tensor
>
optional_node_weight
,
int64_t
num_parts
,
bool
recursive
);
torch
::
Tensor
mt_partition_cpu
(
torch
::
Tensor
rowptr
,
torch
::
Tensor
col
,
torch
::
optional
<
torch
::
Tensor
>
optional_value
,
int64_t
num_parts
,
bool
recursive
,
int64_t
num_workers
);
torch
::
Tensor
mt_partition_cpu
(
torch
::
Tensor
rowptr
,
torch
::
Tensor
col
,
torch
::
optional
<
torch
::
Tensor
>
optional_value
,
torch
::
optional
<
torch
::
Tensor
>
optional_node_weight
,
int64_t
num_parts
,
bool
recursive
,
int64_t
num_workers
);
csrc/metis.cpp
View file @
4d49d44e
...
...
@@ -13,6 +13,7 @@ PyMODINIT_FUNC PyInit__metis_cpu(void) { return NULL; }
torch
::
Tensor
partition
(
torch
::
Tensor
rowptr
,
torch
::
Tensor
col
,
torch
::
optional
<
torch
::
Tensor
>
optional_value
,
torch
::
optional
<
torch
::
Tensor
>
optional_node_weight
,
int64_t
num_parts
,
bool
recursive
)
{
if
(
rowptr
.
device
().
is_cuda
())
{
#ifdef WITH_CUDA
...
...
@@ -21,12 +22,14 @@ torch::Tensor partition(torch::Tensor rowptr, torch::Tensor col,
AT_ERROR
(
"Not compiled with CUDA support"
);
#endif
}
else
{
return
partition_cpu
(
rowptr
,
col
,
optional_value
,
num_parts
,
recursive
);
return
partition_cpu
(
rowptr
,
col
,
optional_value
,
optional_node_weight
,
num_parts
,
recursive
);
}
}
torch
::
Tensor
mt_partition
(
torch
::
Tensor
rowptr
,
torch
::
Tensor
col
,
torch
::
optional
<
torch
::
Tensor
>
optional_value
,
torch
::
optional
<
torch
::
Tensor
>
optional_node_weight
,
int64_t
num_parts
,
bool
recursive
,
int64_t
num_workers
)
{
if
(
rowptr
.
device
().
is_cuda
())
{
...
...
@@ -36,8 +39,8 @@ torch::Tensor mt_partition(torch::Tensor rowptr, torch::Tensor col,
AT_ERROR
(
"Not compiled with CUDA support"
);
#endif
}
else
{
return
mt_partition_cpu
(
rowptr
,
col
,
optional_value
,
num_parts
,
recursive
,
num_workers
);
return
mt_partition_cpu
(
rowptr
,
col
,
optional_value
,
optional_node_weight
,
num_parts
,
recursive
,
num_workers
);
}
}
...
...
test/test_metis.py
View file @
4d49d44e
import
pytest
from
itertools
import
product
import
torch
from
torch_sparse.tensor
import
SparseTensor
...
...
@@ -12,26 +14,24 @@ except RuntimeError:
@
pytest
.
mark
.
skipif
(
not
with_metis
,
reason
=
'Not compiled with METIS support'
)
@
pytest
.
mark
.
parametrize
(
'device
'
,
devices
)
def
test_metis
(
device
):
value
1
=
torch
.
randn
(
6
*
6
,
device
=
device
).
view
(
6
,
6
)
value
2
=
torch
.
arange
(
6
*
6
,
dtype
=
torch
.
long
,
device
=
device
).
view
(
6
,
6
)
value
3
=
torch
.
ones
(
6
*
6
,
device
=
device
).
view
(
6
,
6
)
@
pytest
.
mark
.
parametrize
(
'device
,weighted'
,
product
(
devices
,
[
False
,
True
])
)
def
test_metis
(
device
,
weighted
):
mat
1
=
torch
.
randn
(
6
*
6
,
device
=
device
).
view
(
6
,
6
)
mat
2
=
torch
.
arange
(
6
*
6
,
dtype
=
torch
.
long
,
device
=
device
).
view
(
6
,
6
)
mat
3
=
torch
.
ones
(
6
*
6
,
device
=
device
).
view
(
6
,
6
)
for
value
in
[
value1
,
value2
,
value3
]:
mat
=
SparseTensor
.
from_dense
(
valu
e
)
vec1
=
None
vec2
=
torch
.
rand
(
6
,
device
=
devic
e
)
_
,
partptr
,
perm
=
mat
.
partition
(
num_parts
=
2
,
recursive
=
False
,
weighted
=
True
)
assert
partptr
.
numel
()
==
3
for
mat
,
vec
in
product
([
mat1
,
mat2
,
mat3
],
[
vec1
,
vec2
]):
mat
=
SparseTensor
.
from_dense
(
mat
)
_
,
partptr
,
perm
=
mat
.
partition
(
num_parts
=
1
,
recursive
=
False
,
weighted
=
weighted
,
node_weight
=
vec
)
assert
partptr
.
numel
()
==
2
assert
perm
.
numel
()
==
6
_
,
partptr
,
perm
=
mat
.
partition
(
num_parts
=
2
,
recursive
=
False
,
weighted
=
False
)
weighted
=
weighted
,
node_weight
=
vec
)
assert
partptr
.
numel
()
==
3
assert
perm
.
numel
()
==
6
_
,
partptr
,
perm
=
mat
.
partition
(
num_parts
=
1
,
recursive
=
False
,
weighted
=
True
)
assert
partptr
.
numel
()
==
2
assert
perm
.
numel
()
==
6
torch_sparse/metis.py
View file @
4d49d44e
...
...
@@ -20,8 +20,8 @@ def weight2metis(weight: torch.Tensor) -> Optional[torch.Tensor]:
def
partition
(
src
:
SparseTensor
,
num_parts
:
int
,
recursive
:
bool
=
False
,
weighted
:
bool
=
False
src
:
SparseTensor
,
num_parts
:
int
,
recursive
:
bool
=
False
,
weighted
:
bool
=
False
,
node_weight
:
Optional
[
torch
.
Tensor
]
=
None
)
->
Tuple
[
SparseTensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
assert
num_parts
>=
1
...
...
@@ -41,8 +41,14 @@ def partition(
else
:
value
=
None
cluster
=
torch
.
ops
.
torch_sparse
.
partition
(
rowptr
,
col
,
value
,
num_parts
,
recursive
)
if
node_weight
is
not
None
:
assert
node_weight
.
numel
()
==
rowptr
.
numel
()
-
1
node_weight
=
node_weight
.
view
(
-
1
).
detach
().
cpu
()
if
node_weight
.
is_floating_point
():
node_weight
=
weight2metis
(
node_weight
)
cluster
=
torch
.
ops
.
torch_sparse
.
partition
(
rowptr
,
col
,
value
,
node_weight
,
num_parts
,
recursive
)
cluster
=
cluster
.
to
(
src
.
device
())
cluster
,
perm
=
cluster
.
sort
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment