Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-sparse
Commits
a29aabed
Commit
a29aabed
authored
Mar 26, 2021
by
rusty1s
Browse files
rename
parent
247154f9
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
67 additions
and
68 deletions
+67
-68
csrc/cpu/metis_cpu.cpp
csrc/cpu/metis_cpu.cpp
+23
-18
csrc/cpu/metis_cpu.h
csrc/cpu/metis_cpu.h
+6
-6
csrc/metis.cpp
csrc/metis.cpp
+6
-5
test/test_metis.py
test/test_metis.py
+23
-29
torch_sparse/metis.py
torch_sparse/metis.py
+9
-10
No files found.
csrc/cpu/metis_cpu.cpp
View file @
a29aabed
...
@@ -12,34 +12,36 @@
...
@@ -12,34 +12,36 @@
torch
::
Tensor
partition_cpu
(
torch
::
Tensor
rowptr
,
torch
::
Tensor
col
,
torch
::
Tensor
partition_cpu
(
torch
::
Tensor
rowptr
,
torch
::
Tensor
col
,
torch
::
optional
<
torch
::
Tensor
>
optional_value
,
torch
::
optional
<
torch
::
Tensor
>
optional_value
,
torch
::
optional
<
torch
::
Tensor
>
v
weight
s
,
torch
::
optional
<
torch
::
Tensor
>
optional_node_
weight
,
int64_t
num_parts
,
bool
recursive
)
{
int64_t
num_parts
,
bool
recursive
)
{
#ifdef WITH_METIS
#ifdef WITH_METIS
CHECK_CPU
(
rowptr
);
CHECK_CPU
(
rowptr
);
CHECK_CPU
(
col
);
CHECK_CPU
(
col
);
if
(
optional_value
.
has_value
())
{
if
(
optional_value
.
has_value
())
{
CHECK_CPU
(
optional_value
.
value
());
CHECK_CPU
(
optional_value
.
value
());
CHECK_INPUT
(
optional_value
.
value
().
dim
()
==
1
);
CHECK_INPUT
(
optional_value
.
value
().
dim
()
==
1
);
CHECK_INPUT
(
optional_value
.
value
().
numel
()
==
col
.
numel
());
CHECK_INPUT
(
optional_value
.
value
().
numel
()
==
col
.
numel
());
}
}
if
(
v
weight
s
.
has_value
())
{
if
(
optional_node_
weight
.
has_value
())
{
CHECK_CPU
(
v
weight
s
.
value
());
CHECK_CPU
(
optional_node_
weight
.
value
());
CHECK_INPUT
(
v
weight
s
.
value
().
dim
()
==
1
);
CHECK_INPUT
(
optional_node_
weight
.
value
().
dim
()
==
1
);
CHECK_INPUT
(
v
weight
s
.
value
().
numel
()
==
rowptr
.
numel
()
-
1
);
CHECK_INPUT
(
optional_node_
weight
.
value
().
numel
()
==
rowptr
.
numel
()
-
1
);
}
}
int64_t
nvtxs
=
rowptr
.
numel
()
-
1
;
int64_t
nvtxs
=
rowptr
.
numel
()
-
1
;
int64_t
ncon
=
1
;
int64_t
ncon
=
1
;
auto
*
xadj
=
rowptr
.
data_ptr
<
int64_t
>
();
auto
*
xadj
=
rowptr
.
data_ptr
<
int64_t
>
();
auto
*
adjncy
=
col
.
data_ptr
<
int64_t
>
();
auto
*
adjncy
=
col
.
data_ptr
<
int64_t
>
();
int64_t
*
adjwgt
=
NULL
;
int64_t
*
adjwgt
=
NULL
;
if
(
optional_value
.
has_value
())
if
(
optional_value
.
has_value
())
adjwgt
=
optional_value
.
value
().
data_ptr
<
int64_t
>
();
adjwgt
=
optional_value
.
value
().
data_ptr
<
int64_t
>
();
int64_t
*
vwgt
=
NULL
;
int64_t
*
vwgt
=
NULL
;
if
(
v
weight
s
.
has_value
())
if
(
optional_node_
weight
.
has_value
())
vwgt
=
v
weight
s
.
value
().
data_ptr
<
int64_t
>
();
vwgt
=
optional_node_
weight
.
value
().
data_ptr
<
int64_t
>
();
int64_t
objval
=
-
1
;
int64_t
objval
=
-
1
;
auto
part
=
torch
::
empty
(
nvtxs
,
rowptr
.
options
());
auto
part
=
torch
::
empty
(
nvtxs
,
rowptr
.
options
());
...
@@ -62,11 +64,11 @@ torch::Tensor partition_cpu(torch::Tensor rowptr, torch::Tensor col,
...
@@ -62,11 +64,11 @@ torch::Tensor partition_cpu(torch::Tensor rowptr, torch::Tensor col,
// needs mt-metis installed via:
// needs mt-metis installed via:
// ./configure --shared --edges64bit --vertices64bit --weights64bit
// ./configure --shared --edges64bit --vertices64bit --weights64bit
// --partitions64bit
// --partitions64bit
torch
::
Tensor
mt_partition_cpu
(
torch
::
Tensor
rowptr
,
torch
::
Tensor
col
,
torch
::
Tensor
mt_partition_cpu
(
torch
::
Tensor
rowptr
,
torch
::
Tensor
col
,
torch
::
optional
<
torch
::
Tensor
>
optional_value
,
torch
::
optional
<
torch
::
Tensor
>
optional_value
,
torch
::
optional
<
torch
::
Tensor
>
vweights
,
torch
::
optional
<
torch
::
Tensor
>
optional_node_weight
,
int64_t
num_parts
,
bool
recursive
,
int64_t
num_parts
,
bool
recursive
,
int64_t
num_workers
)
{
int64_t
num_workers
)
{
#ifdef WITH_MTMETIS
#ifdef WITH_MTMETIS
CHECK_CPU
(
rowptr
);
CHECK_CPU
(
rowptr
);
CHECK_CPU
(
col
);
CHECK_CPU
(
col
);
...
@@ -76,10 +78,10 @@ torch::Tensor mt_partition_cpu(torch::Tensor rowptr, torch::Tensor col,
...
@@ -76,10 +78,10 @@ torch::Tensor mt_partition_cpu(torch::Tensor rowptr, torch::Tensor col,
CHECK_INPUT
(
optional_value
.
value
().
numel
()
==
col
.
numel
());
CHECK_INPUT
(
optional_value
.
value
().
numel
()
==
col
.
numel
());
}
}
if
(
v
weight
s
.
has_value
())
{
if
(
optional_node_
weight
.
has_value
())
{
CHECK_CPU
(
v
weight
s
.
value
());
CHECK_CPU
(
optional_node_
weight
.
value
());
CHECK_INPUT
(
v
weight
s
.
value
().
dim
()
==
1
);
CHECK_INPUT
(
optional_node_
weight
.
value
().
dim
()
==
1
);
CHECK_INPUT
(
v
weight
s
.
value
().
numel
()
==
rowptr
.
numel
()
-
1
);
CHECK_INPUT
(
optional_node_
weight
.
value
().
numel
()
==
rowptr
.
numel
()
-
1
);
}
}
mtmetis_vtx_type
nvtxs
=
rowptr
.
numel
()
-
1
;
mtmetis_vtx_type
nvtxs
=
rowptr
.
numel
()
-
1
;
...
@@ -87,11 +89,14 @@ torch::Tensor mt_partition_cpu(torch::Tensor rowptr, torch::Tensor col,
...
@@ -87,11 +89,14 @@ torch::Tensor mt_partition_cpu(torch::Tensor rowptr, torch::Tensor col,
mtmetis_adj_type
*
xadj
=
(
mtmetis_adj_type
*
)
rowptr
.
data_ptr
<
int64_t
>
();
mtmetis_adj_type
*
xadj
=
(
mtmetis_adj_type
*
)
rowptr
.
data_ptr
<
int64_t
>
();
mtmetis_vtx_type
*
adjncy
=
(
mtmetis_vtx_type
*
)
col
.
data_ptr
<
int64_t
>
();
mtmetis_vtx_type
*
adjncy
=
(
mtmetis_vtx_type
*
)
col
.
data_ptr
<
int64_t
>
();
mtmetis_wgt_type
*
adjwgt
=
NULL
;
mtmetis_wgt_type
*
adjwgt
=
NULL
;
if
(
optional_value
.
has_value
())
if
(
optional_value
.
has_value
())
adjwgt
=
optional_value
.
value
().
data_ptr
<
int64_t
>
();
adjwgt
=
optional_value
.
value
().
data_ptr
<
int64_t
>
();
mtmetis_wgt_type
*
vwgt
=
NULL
;
mtmetis_wgt_type
*
vwgt
=
NULL
;
if
(
vweights
.
has_value
())
if
(
optional_node_weight
.
has_value
())
vwgt
=
vweights
.
value
().
data_ptr
<
int64_t
>
();
vwgt
=
optional_node_weight
.
value
().
data_ptr
<
int64_t
>
();
mtmetis_pid_type
nparts
=
num_parts
;
mtmetis_pid_type
nparts
=
num_parts
;
mtmetis_wgt_type
objval
=
-
1
;
mtmetis_wgt_type
objval
=
-
1
;
auto
part
=
torch
::
empty
(
nvtxs
,
rowptr
.
options
());
auto
part
=
torch
::
empty
(
nvtxs
,
rowptr
.
options
());
...
...
csrc/cpu/metis_cpu.h
View file @
a29aabed
...
@@ -4,11 +4,11 @@
...
@@ -4,11 +4,11 @@
torch
::
Tensor
partition_cpu
(
torch
::
Tensor
rowptr
,
torch
::
Tensor
col
,
torch
::
Tensor
partition_cpu
(
torch
::
Tensor
rowptr
,
torch
::
Tensor
col
,
torch
::
optional
<
torch
::
Tensor
>
optional_value
,
torch
::
optional
<
torch
::
Tensor
>
optional_value
,
torch
::
optional
<
torch
::
Tensor
>
v
weight
s
,
torch
::
optional
<
torch
::
Tensor
>
optional_node_
weight
,
int64_t
num_parts
,
bool
recursive
);
int64_t
num_parts
,
bool
recursive
);
torch
::
Tensor
mt_partition_cpu
(
torch
::
Tensor
rowptr
,
torch
::
Tensor
col
,
torch
::
Tensor
mt_partition_cpu
(
torch
::
Tensor
rowptr
,
torch
::
Tensor
col
,
torch
::
optional
<
torch
::
Tensor
>
optional_value
,
torch
::
optional
<
torch
::
Tensor
>
optional_value
,
torch
::
optional
<
torch
::
Tensor
>
vweights
,
torch
::
optional
<
torch
::
Tensor
>
optional_node_weight
,
int64_t
num_parts
,
bool
recursive
,
int64_t
num_parts
,
bool
recursive
,
int64_t
num_workers
);
int64_t
num_workers
);
csrc/metis.cpp
View file @
a29aabed
...
@@ -13,7 +13,7 @@ PyMODINIT_FUNC PyInit__metis_cpu(void) { return NULL; }
...
@@ -13,7 +13,7 @@ PyMODINIT_FUNC PyInit__metis_cpu(void) { return NULL; }
torch
::
Tensor
partition
(
torch
::
Tensor
rowptr
,
torch
::
Tensor
col
,
torch
::
Tensor
partition
(
torch
::
Tensor
rowptr
,
torch
::
Tensor
col
,
torch
::
optional
<
torch
::
Tensor
>
optional_value
,
torch
::
optional
<
torch
::
Tensor
>
optional_value
,
torch
::
optional
<
torch
::
Tensor
>
v
weight
s
,
torch
::
optional
<
torch
::
Tensor
>
optional_node_
weight
,
int64_t
num_parts
,
bool
recursive
)
{
int64_t
num_parts
,
bool
recursive
)
{
if
(
rowptr
.
device
().
is_cuda
())
{
if
(
rowptr
.
device
().
is_cuda
())
{
#ifdef WITH_CUDA
#ifdef WITH_CUDA
...
@@ -22,13 +22,14 @@ torch::Tensor partition(torch::Tensor rowptr, torch::Tensor col,
...
@@ -22,13 +22,14 @@ torch::Tensor partition(torch::Tensor rowptr, torch::Tensor col,
AT_ERROR
(
"Not compiled with CUDA support"
);
AT_ERROR
(
"Not compiled with CUDA support"
);
#endif
#endif
}
else
{
}
else
{
return
partition_cpu
(
rowptr
,
col
,
optional_value
,
vweights
,
num_parts
,
recursive
);
return
partition_cpu
(
rowptr
,
col
,
optional_value
,
optional_node_weight
,
num_parts
,
recursive
);
}
}
}
}
torch
::
Tensor
mt_partition
(
torch
::
Tensor
rowptr
,
torch
::
Tensor
col
,
torch
::
Tensor
mt_partition
(
torch
::
Tensor
rowptr
,
torch
::
Tensor
col
,
torch
::
optional
<
torch
::
Tensor
>
optional_value
,
torch
::
optional
<
torch
::
Tensor
>
optional_value
,
torch
::
optional
<
torch
::
Tensor
>
v
weight
s
,
torch
::
optional
<
torch
::
Tensor
>
optional_node_
weight
,
int64_t
num_parts
,
bool
recursive
,
int64_t
num_parts
,
bool
recursive
,
int64_t
num_workers
)
{
int64_t
num_workers
)
{
if
(
rowptr
.
device
().
is_cuda
())
{
if
(
rowptr
.
device
().
is_cuda
())
{
...
@@ -38,8 +39,8 @@ torch::Tensor mt_partition(torch::Tensor rowptr, torch::Tensor col,
...
@@ -38,8 +39,8 @@ torch::Tensor mt_partition(torch::Tensor rowptr, torch::Tensor col,
AT_ERROR
(
"Not compiled with CUDA support"
);
AT_ERROR
(
"Not compiled with CUDA support"
);
#endif
#endif
}
else
{
}
else
{
return
mt_partition_cpu
(
rowptr
,
col
,
optional_value
,
vweights
,
num_parts
,
recursive
,
return
mt_partition_cpu
(
rowptr
,
col
,
optional_value
,
optional_node_weight
,
num_workers
);
num_parts
,
recursive
,
num_workers
);
}
}
}
}
...
...
test/test_metis.py
View file @
a29aabed
import
pytest
import
pytest
from
itertools
import
product
import
torch
import
torch
from
torch_sparse.tensor
import
SparseTensor
from
torch_sparse.tensor
import
SparseTensor
...
@@ -12,32 +14,24 @@ except RuntimeError:
...
@@ -12,32 +14,24 @@ except RuntimeError:
@
pytest
.
mark
.
skipif
(
not
with_metis
,
reason
=
'Not compiled with METIS support'
)
@
pytest
.
mark
.
skipif
(
not
with_metis
,
reason
=
'Not compiled with METIS support'
)
@
pytest
.
mark
.
parametrize
(
'device
'
,
devices
)
@
pytest
.
mark
.
parametrize
(
'device
,weighted'
,
product
(
devices
,
[
False
,
True
])
)
def
test_metis
(
device
):
def
test_metis
(
device
,
weighted
):
value
1
=
torch
.
randn
(
6
*
6
,
device
=
device
).
view
(
6
,
6
)
mat
1
=
torch
.
randn
(
6
*
6
,
device
=
device
).
view
(
6
,
6
)
value
2
=
torch
.
arange
(
6
*
6
,
dtype
=
torch
.
long
,
device
=
device
).
view
(
6
,
6
)
mat
2
=
torch
.
arange
(
6
*
6
,
dtype
=
torch
.
long
,
device
=
device
).
view
(
6
,
6
)
value
3
=
torch
.
ones
(
6
*
6
,
device
=
device
).
view
(
6
,
6
)
mat
3
=
torch
.
ones
(
6
*
6
,
device
=
device
).
view
(
6
,
6
)
vwgts
=
torch
.
rand
(
6
,
device
=
device
)
vec1
=
None
vec2
=
torch
.
rand
(
6
,
device
=
device
)
for
value
in
[
value1
,
value2
,
value3
]:
for
mat
,
vec
in
product
([
mat1
,
mat2
,
mat3
],
[
vec1
,
vec2
]):
for
vwgt
in
[
None
,
vwgts
]:
mat
=
SparseTensor
.
from_dense
(
mat
)
mat
=
SparseTensor
.
from_dense
(
value
)
_
,
partptr
,
perm
=
mat
.
partition
(
num_parts
=
2
,
recursive
=
False
,
_
,
partptr
,
perm
=
mat
.
partition
(
num_parts
=
1
,
recursive
=
False
,
vweights
=
vwgt
,
weighted
=
weighted
,
node_weight
=
vec
)
weighted
=
True
)
assert
partptr
.
numel
()
==
2
assert
partptr
.
numel
()
==
3
assert
perm
.
numel
()
==
6
assert
perm
.
numel
()
==
6
_
,
partptr
,
perm
=
mat
.
partition
(
num_parts
=
2
,
recursive
=
False
,
_
,
partptr
,
perm
=
mat
.
partition
(
num_parts
=
2
,
recursive
=
False
,
vweights
=
vwgt
,
weighted
=
weighted
,
node_weight
=
vec
)
weighted
=
False
)
assert
partptr
.
numel
()
==
3
assert
partptr
.
numel
()
==
3
assert
perm
.
numel
()
==
6
assert
perm
.
numel
()
==
6
_
,
partptr
,
perm
=
mat
.
partition
(
num_parts
=
1
,
recursive
=
False
,
vweights
=
vwgt
,
weighted
=
True
)
assert
partptr
.
numel
()
==
2
assert
perm
.
numel
()
==
6
torch_sparse/metis.py
View file @
a29aabed
...
@@ -21,8 +21,7 @@ def weight2metis(weight: torch.Tensor) -> Optional[torch.Tensor]:
...
@@ -21,8 +21,7 @@ def weight2metis(weight: torch.Tensor) -> Optional[torch.Tensor]:
def
partition
(
def
partition
(
src
:
SparseTensor
,
num_parts
:
int
,
recursive
:
bool
=
False
,
src
:
SparseTensor
,
num_parts
:
int
,
recursive
:
bool
=
False
,
vweights
:
torch
.
tensor
=
None
,
weighted
:
bool
=
False
,
node_weight
:
torch
.
tensor
=
None
weighted
:
bool
=
False
)
->
Tuple
[
SparseTensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
Tuple
[
SparseTensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
assert
num_parts
>=
1
assert
num_parts
>=
1
...
@@ -42,14 +41,14 @@ def partition(
...
@@ -42,14 +41,14 @@ def partition(
else
:
else
:
value
=
None
value
=
None
if
v
weight
s
is
not
None
:
if
node_
weight
is
not
None
:
assert
v
weight
s
.
numel
()
==
rowptr
.
numel
()
-
1
assert
node_
weight
.
numel
()
==
rowptr
.
numel
()
-
1
v
weight
s
=
v
weight
s
.
view
(
-
1
).
detach
().
cpu
()
node_
weight
=
node_
weight
.
view
(
-
1
).
detach
().
cpu
()
if
v
weight
s
.
is_floating_point
():
if
node_
weight
.
is_floating_point
():
v
weight
s
=
weight2metis
(
v
weight
s
)
node_
weight
=
weight2metis
(
node_
weight
)
cluster
=
torch
.
ops
.
torch_sparse
.
partition
(
rowptr
,
col
,
value
,
v
weight
s
,
num_parts
,
cluster
=
torch
.
ops
.
torch_sparse
.
partition
(
rowptr
,
col
,
value
,
node_
weight
,
recursive
)
num_parts
,
recursive
)
cluster
=
cluster
.
to
(
src
.
device
())
cluster
=
cluster
.
to
(
src
.
device
())
cluster
,
perm
=
cluster
.
sort
()
cluster
,
perm
=
cluster
.
sort
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment