Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-sparse
Commits
247154f9
Commit
247154f9
authored
Mar 25, 2021
by
Chendi Qian
Browse files
add node weights for metis wrapper
parent
54d8418e
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
62 additions
and
23 deletions
+62
-23
csrc/cpu/metis_cpu.cpp
csrc/cpu/metis_cpu.cpp
+26
-4
csrc/cpu/metis_cpu.h
csrc/cpu/metis_cpu.h
+2
-0
csrc/metis.cpp
csrc/metis.cpp
+4
-2
test/test_metis.py
test/test_metis.py
+22
-16
torch_sparse/metis.py
torch_sparse/metis.py
+8
-1
No files found.
csrc/cpu/metis_cpu.cpp
View file @
247154f9
...
...
@@ -12,6 +12,7 @@
torch
::
Tensor
partition_cpu
(
torch
::
Tensor
rowptr
,
torch
::
Tensor
col
,
torch
::
optional
<
torch
::
Tensor
>
optional_value
,
torch
::
optional
<
torch
::
Tensor
>
vweights
,
int64_t
num_parts
,
bool
recursive
)
{
#ifdef WITH_METIS
CHECK_CPU
(
rowptr
);
...
...
@@ -22,6 +23,12 @@ torch::Tensor partition_cpu(torch::Tensor rowptr, torch::Tensor col,
CHECK_INPUT
(
optional_value
.
value
().
numel
()
==
col
.
numel
());
}
if
(
vweights
.
has_value
())
{
CHECK_CPU
(
vweights
.
value
());
CHECK_INPUT
(
vweights
.
value
().
dim
()
==
1
);
CHECK_INPUT
(
vweights
.
value
().
numel
()
==
rowptr
.
numel
()
-
1
);
}
int64_t
nvtxs
=
rowptr
.
numel
()
-
1
;
int64_t
ncon
=
1
;
auto
*
xadj
=
rowptr
.
data_ptr
<
int64_t
>
();
...
...
@@ -29,15 +36,20 @@ torch::Tensor partition_cpu(torch::Tensor rowptr, torch::Tensor col,
int64_t
*
adjwgt
=
NULL
;
if
(
optional_value
.
has_value
())
adjwgt
=
optional_value
.
value
().
data_ptr
<
int64_t
>
();
int64_t
*
vwgt
=
NULL
;
if
(
vweights
.
has_value
())
vwgt
=
vweights
.
value
().
data_ptr
<
int64_t
>
();
int64_t
objval
=
-
1
;
auto
part
=
torch
::
empty
(
nvtxs
,
rowptr
.
options
());
auto
part_data
=
part
.
data_ptr
<
int64_t
>
();
if
(
recursive
)
{
METIS_PartGraphRecursive
(
&
nvtxs
,
&
ncon
,
xadj
,
adjncy
,
NULL
,
NULL
,
adjwgt
,
METIS_PartGraphRecursive
(
&
nvtxs
,
&
ncon
,
xadj
,
adjncy
,
vwgt
,
NULL
,
adjwgt
,
&
num_parts
,
NULL
,
NULL
,
NULL
,
&
objval
,
part_data
);
}
else
{
METIS_PartGraphKway
(
&
nvtxs
,
&
ncon
,
xadj
,
adjncy
,
NULL
,
NULL
,
adjwgt
,
METIS_PartGraphKway
(
&
nvtxs
,
&
ncon
,
xadj
,
adjncy
,
vwgt
,
NULL
,
adjwgt
,
&
num_parts
,
NULL
,
NULL
,
NULL
,
&
objval
,
part_data
);
}
...
...
@@ -52,6 +64,7 @@ torch::Tensor partition_cpu(torch::Tensor rowptr, torch::Tensor col,
// --partitions64bit
torch
::
Tensor
mt_partition_cpu
(
torch
::
Tensor
rowptr
,
torch
::
Tensor
col
,
torch
::
optional
<
torch
::
Tensor
>
optional_value
,
torch
::
optional
<
torch
::
Tensor
>
vweights
,
int64_t
num_parts
,
bool
recursive
,
int64_t
num_workers
)
{
#ifdef WITH_MTMETIS
...
...
@@ -63,6 +76,12 @@ torch::Tensor mt_partition_cpu(torch::Tensor rowptr, torch::Tensor col,
CHECK_INPUT
(
optional_value
.
value
().
numel
()
==
col
.
numel
());
}
if
(
vweights
.
has_value
())
{
CHECK_CPU
(
vweights
.
value
());
CHECK_INPUT
(
vweights
.
value
().
dim
()
==
1
);
CHECK_INPUT
(
vweights
.
value
().
numel
()
==
rowptr
.
numel
()
-
1
);
}
mtmetis_vtx_type
nvtxs
=
rowptr
.
numel
()
-
1
;
mtmetis_vtx_type
ncon
=
1
;
mtmetis_adj_type
*
xadj
=
(
mtmetis_adj_type
*
)
rowptr
.
data_ptr
<
int64_t
>
();
...
...
@@ -70,6 +89,9 @@ torch::Tensor mt_partition_cpu(torch::Tensor rowptr, torch::Tensor col,
mtmetis_wgt_type
*
adjwgt
=
NULL
;
if
(
optional_value
.
has_value
())
adjwgt
=
optional_value
.
value
().
data_ptr
<
int64_t
>
();
mtmetis_wgt_type
*
vwgt
=
NULL
;
if
(
vweights
.
has_value
())
vwgt
=
vweights
.
value
().
data_ptr
<
int64_t
>
();
mtmetis_pid_type
nparts
=
num_parts
;
mtmetis_wgt_type
objval
=
-
1
;
auto
part
=
torch
::
empty
(
nvtxs
,
rowptr
.
options
());
...
...
@@ -79,10 +101,10 @@ torch::Tensor mt_partition_cpu(torch::Tensor rowptr, torch::Tensor col,
opts
[
MTMETIS_OPTION_NTHREADS
]
=
num_workers
;
if
(
recursive
)
{
MTMETIS_PartGraphRecursive
(
&
nvtxs
,
&
ncon
,
xadj
,
adjncy
,
NULL
,
NULL
,
adjwgt
,
MTMETIS_PartGraphRecursive
(
&
nvtxs
,
&
ncon
,
xadj
,
adjncy
,
vwgt
,
NULL
,
adjwgt
,
&
nparts
,
NULL
,
NULL
,
opts
,
&
objval
,
part_data
);
}
else
{
MTMETIS_PartGraphKway
(
&
nvtxs
,
&
ncon
,
xadj
,
adjncy
,
NULL
,
NULL
,
adjwgt
,
MTMETIS_PartGraphKway
(
&
nvtxs
,
&
ncon
,
xadj
,
adjncy
,
vwgt
,
NULL
,
adjwgt
,
&
nparts
,
NULL
,
NULL
,
opts
,
&
objval
,
part_data
);
}
...
...
csrc/cpu/metis_cpu.h
View file @
247154f9
...
...
@@ -4,9 +4,11 @@
torch
::
Tensor
partition_cpu
(
torch
::
Tensor
rowptr
,
torch
::
Tensor
col
,
torch
::
optional
<
torch
::
Tensor
>
optional_value
,
torch
::
optional
<
torch
::
Tensor
>
vweights
,
int64_t
num_parts
,
bool
recursive
);
torch
::
Tensor
mt_partition_cpu
(
torch
::
Tensor
rowptr
,
torch
::
Tensor
col
,
torch
::
optional
<
torch
::
Tensor
>
optional_value
,
torch
::
optional
<
torch
::
Tensor
>
vweights
,
int64_t
num_parts
,
bool
recursive
,
int64_t
num_workers
);
csrc/metis.cpp
View file @
247154f9
...
...
@@ -13,6 +13,7 @@ PyMODINIT_FUNC PyInit__metis_cpu(void) { return NULL; }
torch
::
Tensor
partition
(
torch
::
Tensor
rowptr
,
torch
::
Tensor
col
,
torch
::
optional
<
torch
::
Tensor
>
optional_value
,
torch
::
optional
<
torch
::
Tensor
>
vweights
,
int64_t
num_parts
,
bool
recursive
)
{
if
(
rowptr
.
device
().
is_cuda
())
{
#ifdef WITH_CUDA
...
...
@@ -21,12 +22,13 @@ torch::Tensor partition(torch::Tensor rowptr, torch::Tensor col,
AT_ERROR
(
"Not compiled with CUDA support"
);
#endif
}
else
{
return
partition_cpu
(
rowptr
,
col
,
optional_value
,
num_parts
,
recursive
);
return
partition_cpu
(
rowptr
,
col
,
optional_value
,
vweights
,
num_parts
,
recursive
);
}
}
torch
::
Tensor
mt_partition
(
torch
::
Tensor
rowptr
,
torch
::
Tensor
col
,
torch
::
optional
<
torch
::
Tensor
>
optional_value
,
torch
::
optional
<
torch
::
Tensor
>
vweights
,
int64_t
num_parts
,
bool
recursive
,
int64_t
num_workers
)
{
if
(
rowptr
.
device
().
is_cuda
())
{
...
...
@@ -36,7 +38,7 @@ torch::Tensor mt_partition(torch::Tensor rowptr, torch::Tensor col,
AT_ERROR
(
"Not compiled with CUDA support"
);
#endif
}
else
{
return
mt_partition_cpu
(
rowptr
,
col
,
optional_value
,
num_parts
,
recursive
,
return
mt_partition_cpu
(
rowptr
,
col
,
optional_value
,
vweights
,
num_parts
,
recursive
,
num_workers
);
}
}
...
...
test/test_metis.py
View file @
247154f9
...
...
@@ -18,20 +18,26 @@ def test_metis(device):
value2
=
torch
.
arange
(
6
*
6
,
dtype
=
torch
.
long
,
device
=
device
).
view
(
6
,
6
)
value3
=
torch
.
ones
(
6
*
6
,
device
=
device
).
view
(
6
,
6
)
vwgts
=
torch
.
rand
(
6
,
device
=
device
)
for
value
in
[
value1
,
value2
,
value3
]:
mat
=
SparseTensor
.
from_dense
(
value
)
_
,
partptr
,
perm
=
mat
.
partition
(
num_parts
=
2
,
recursive
=
False
,
weighted
=
True
)
assert
partptr
.
numel
()
==
3
assert
perm
.
numel
()
==
6
_
,
partptr
,
perm
=
mat
.
partition
(
num_parts
=
2
,
recursive
=
False
,
weighted
=
False
)
assert
partptr
.
numel
()
==
3
assert
perm
.
numel
()
==
6
_
,
partptr
,
perm
=
mat
.
partition
(
num_parts
=
1
,
recursive
=
False
,
weighted
=
True
)
assert
partptr
.
numel
()
==
2
assert
perm
.
numel
()
==
6
for
vwgt
in
[
None
,
vwgts
]:
mat
=
SparseTensor
.
from_dense
(
value
)
_
,
partptr
,
perm
=
mat
.
partition
(
num_parts
=
2
,
recursive
=
False
,
vweights
=
vwgt
,
weighted
=
True
)
assert
partptr
.
numel
()
==
3
assert
perm
.
numel
()
==
6
_
,
partptr
,
perm
=
mat
.
partition
(
num_parts
=
2
,
recursive
=
False
,
vweights
=
vwgt
,
weighted
=
False
)
assert
partptr
.
numel
()
==
3
assert
perm
.
numel
()
==
6
_
,
partptr
,
perm
=
mat
.
partition
(
num_parts
=
1
,
recursive
=
False
,
vweights
=
vwgt
,
weighted
=
True
)
assert
partptr
.
numel
()
==
2
assert
perm
.
numel
()
==
6
torch_sparse/metis.py
View file @
247154f9
...
...
@@ -21,6 +21,7 @@ def weight2metis(weight: torch.Tensor) -> Optional[torch.Tensor]:
def
partition
(
src
:
SparseTensor
,
num_parts
:
int
,
recursive
:
bool
=
False
,
vweights
:
torch
.
tensor
=
None
,
weighted
:
bool
=
False
)
->
Tuple
[
SparseTensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
...
...
@@ -41,7 +42,13 @@ def partition(
else
:
value
=
None
cluster
=
torch
.
ops
.
torch_sparse
.
partition
(
rowptr
,
col
,
value
,
num_parts
,
if
vweights
is
not
None
:
assert
vweights
.
numel
()
==
rowptr
.
numel
()
-
1
vweights
=
vweights
.
view
(
-
1
).
detach
().
cpu
()
if
vweights
.
is_floating_point
():
vweights
=
weight2metis
(
vweights
)
cluster
=
torch
.
ops
.
torch_sparse
.
partition
(
rowptr
,
col
,
value
,
vweights
,
num_parts
,
recursive
)
cluster
=
cluster
.
to
(
src
.
device
())
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment