Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-sparse
Commits
a06899bb
Commit
a06899bb
authored
Feb 23, 2020
by
rusty1s
Browse files
recursive
parent
45d29d1a
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
22 additions
and
17 deletions
+22
-17
csrc/cpu/metis_wrapper_cpu.cpp
csrc/cpu/metis_wrapper_cpu.cpp
+9
-4
csrc/cpu/metis_wrapper_cpu.h
csrc/cpu/metis_wrapper_cpu.h
+2
-2
csrc/metis_wrapper.cpp
csrc/metis_wrapper.cpp
+5
-5
torch_sparse/metis.py
torch_sparse/metis.py
+6
-6
No files found.
csrc/cpu/metis_wrapper_cpu.cpp
View file @
a06899bb
...
@@ -4,8 +4,8 @@
...
@@ -4,8 +4,8 @@
#include "utils.h"
#include "utils.h"
torch
::
Tensor
partition_
kway_
cpu
(
torch
::
Tensor
rowptr
,
torch
::
Tensor
col
,
torch
::
Tensor
partition_cpu
(
torch
::
Tensor
rowptr
,
torch
::
Tensor
col
,
int64_t
num_parts
)
{
int64_t
num_parts
,
bool
recursive
)
{
CHECK_CPU
(
rowptr
);
CHECK_CPU
(
rowptr
);
CHECK_CPU
(
col
);
CHECK_CPU
(
col
);
...
@@ -17,8 +17,13 @@ torch::Tensor partition_kway_cpu(torch::Tensor rowptr, torch::Tensor col,
...
@@ -17,8 +17,13 @@ torch::Tensor partition_kway_cpu(torch::Tensor rowptr, torch::Tensor col,
auto
part
=
torch
::
empty
(
nvtxs
,
rowptr
.
options
());
auto
part
=
torch
::
empty
(
nvtxs
,
rowptr
.
options
());
auto
part_data
=
part
.
data_ptr
<
int64_t
>
();
auto
part_data
=
part
.
data_ptr
<
int64_t
>
();
METIS_PartGraphKway
(
&
nvtxs
,
&
ncon
,
xadj
,
adjncy
,
NULL
,
NULL
,
NULL
,
&
num_parts
,
if
(
recursive
)
{
NULL
,
NULL
,
NULL
,
&
objval
,
part_data
);
METIS
PartGraphRecursive
(
&
nvtxs
,
&
ncon
,
xadj
,
adjncy
,
NULL
,
NULL
,
NULL
,
&
num_parts
,
NULL
,
NULL
,
NULL
,
&
objval
,
part_data
);
}
else
{
METIS_PartGraphKway
(
&
nvtxs
,
&
ncon
,
xadj
,
adjncy
,
NULL
,
NULL
,
NULL
,
&
num_parts
,
NULL
,
NULL
,
NULL
,
&
objval
,
part_data
);
}
return
part
;
return
part
;
}
}
csrc/cpu/metis_wrapper_cpu.h
View file @
a06899bb
...
@@ -2,5 +2,5 @@
...
@@ -2,5 +2,5 @@
#include <torch/extension.h>
#include <torch/extension.h>
torch
::
Tensor
partition_
kway_
cpu
(
torch
::
Tensor
rowptr
,
torch
::
Tensor
col
,
torch
::
Tensor
partition_cpu
(
torch
::
Tensor
rowptr
,
torch
::
Tensor
col
,
int64_t
num_parts
);
int64_t
num_parts
,
bool
recursive
);
csrc/metis_wrapper.cpp
View file @
a06899bb
...
@@ -9,8 +9,8 @@
...
@@ -9,8 +9,8 @@
PyMODINIT_FUNC
PyInit__metis_wrapper
(
void
)
{
return
NULL
;
}
PyMODINIT_FUNC
PyInit__metis_wrapper
(
void
)
{
return
NULL
;
}
#endif
#endif
torch
::
Tensor
partition
_kway
(
torch
::
Tensor
rowptr
,
torch
::
Tensor
col
,
torch
::
Tensor
partition
(
torch
::
Tensor
rowptr
,
torch
::
Tensor
col
,
int64_t
num_parts
)
{
int64_t
num_parts
,
bool
recursive
)
{
if
(
rowptr
.
device
().
is_cuda
())
{
if
(
rowptr
.
device
().
is_cuda
())
{
#ifdef WITH_CUDA
#ifdef WITH_CUDA
AT_ERROR
(
"No CUDA version supported"
);
AT_ERROR
(
"No CUDA version supported"
);
...
@@ -18,9 +18,9 @@ torch::Tensor partition_kway(torch::Tensor rowptr, torch::Tensor col,
...
@@ -18,9 +18,9 @@ torch::Tensor partition_kway(torch::Tensor rowptr, torch::Tensor col,
AT_ERROR
(
"Not compiled with CUDA support"
);
AT_ERROR
(
"Not compiled with CUDA support"
);
#endif
#endif
}
else
{
}
else
{
return
partition_kway_cpu
(
rowptr
,
col
,
num_parts
);
return
partition_kway_cpu
(
rowptr
,
col
,
num_parts
,
recursive
);
}
}
}
}
static
auto
registry
=
torch
::
RegisterOperators
().
op
(
static
auto
registry
=
"torch_sparse::partition
_kway
"
,
&
partition
_kway
);
torch
::
RegisterOperators
().
op
(
"torch_sparse::partition"
,
&
partition
);
torch_sparse/metis.py
View file @
a06899bb
...
@@ -5,12 +5,13 @@ from torch_sparse.tensor import SparseTensor
...
@@ -5,12 +5,13 @@ from torch_sparse.tensor import SparseTensor
from
torch_sparse.permute
import
permute
from
torch_sparse.permute
import
permute
def
partition
_kway
(
def
partition
(
src
:
SparseTensor
,
src
:
SparseTensor
,
num_parts
:
int
,
recursive
:
bool
=
False
num_parts
:
int
)
->
Tuple
[
SparseTensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
Tuple
[
SparseTensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
rowptr
,
col
=
src
.
storage
.
rowptr
().
cpu
(),
src
.
storage
.
col
().
cpu
()
rowptr
,
col
=
src
.
storage
.
rowptr
().
cpu
(),
src
.
storage
.
col
().
cpu
()
cluster
=
torch
.
ops
.
torch_sparse
.
partition_kway
(
rowptr
,
col
,
num_parts
)
cluster
=
torch
.
ops
.
torch_sparse
.
partition
(
rowptr
,
col
,
num_parts
,
recursive
)
cluster
=
cluster
.
to
(
src
.
device
())
cluster
=
cluster
.
to
(
src
.
device
())
cluster
,
perm
=
cluster
.
sort
()
cluster
,
perm
=
cluster
.
sort
()
...
@@ -20,5 +21,4 @@ def partition_kway(
...
@@ -20,5 +21,4 @@ def partition_kway(
return
out
,
partptr
,
perm
return
out
,
partptr
,
perm
SparseTensor
.
partition_kway
=
lambda
self
,
num_parts
:
partition_kway
(
SparseTensor
.
partition
=
lambda
self
,
num_parts
:
partition
(
self
,
num_parts
)
self
,
num_parts
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment