Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-sparse
Commits
9c3519b4
Commit
9c3519b4
authored
Apr 14, 2020
by
rusty1s
Browse files
update
parent
523e86a3
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
34 additions
and
33 deletions
+34
-33
.gitignore
.gitignore
+1
-1
csrc/cpu/metis_cpu.cpp
csrc/cpu/metis_cpu.cpp
+12
-13
csrc/cpu/metis_cpu.h
csrc/cpu/metis_cpu.h
+2
-2
csrc/metis.cpp
csrc/metis.cpp
+6
-4
torch_sparse/metis.py
torch_sparse/metis.py
+13
-6
torch_sparse/utils.py
torch_sparse/utils.py
+0
-7
No files found.
.gitignore
View file @
9c3519b4
csrc/cpu/metis_cpu.cpp
View file @
9c3519b4
...
@@ -6,26 +6,25 @@
...
@@ -6,26 +6,25 @@
#include "utils.h"
#include "utils.h"
torch
::
Tensor
partition_cpu
(
torch
::
Tensor
rowptr
,
torch
::
Tensor
col
,
int64_t
num_parts
,
torch
::
Tensor
partition_cpu
(
torch
::
Tensor
rowptr
,
torch
::
Tensor
col
,
torch
::
optional
<
torch
::
Tensor
>
edge_wgt
,
bool
recursive
)
{
torch
::
optional
<
torch
::
Tensor
>
optional_value
,
int64_t
num_parts
,
bool
recursive
)
{
#ifdef WITH_METIS
#ifdef WITH_METIS
CHECK_CPU
(
rowptr
);
CHECK_CPU
(
rowptr
);
CHECK_CPU
(
col
);
CHECK_CPU
(
col
);
if
(
optional_value
.
has_value
())
{
CHECK_CPU
(
optional_value
.
value
());
CHECK_INPUT
(
optional_value
.
value
().
dim
()
==
1
);
CHECK_INPUT
(
optional_value
.
value
().
numel
()
==
col
.
numel
());
}
int64_t
nvtxs
=
rowptr
.
numel
()
-
1
;
int64_t
nvtxs
=
rowptr
.
numel
()
-
1
;
auto
part
=
torch
::
empty
(
nvtxs
,
rowptr
.
options
());
auto
part
=
torch
::
empty
(
nvtxs
,
rowptr
.
options
());
auto
*
xadj
=
rowptr
.
data_ptr
<
int64_t
>
();
auto
*
xadj
=
rowptr
.
data_ptr
<
int64_t
>
();
auto
*
adjncy
=
col
.
data_ptr
<
int64_t
>
();
auto
*
adjncy
=
col
.
data_ptr
<
int64_t
>
();
int64_t
*
adjwgt
=
NULL
;
int64_t
*
adjwgt
;
if
(
optional_value
.
has_value
())
if
(
edge_wgt
.
has_value
()){
adjwgt
=
optional_value
.
value
().
data_ptr
<
int64_t
>
();
adjwgt
=
edge_wgt
.
value
().
data_ptr
<
int64_t
>
();
adjwgt
=
(
idx_t
*
)
adjwgt
;
}
else
{
adjwgt
=
nullptr
;
}
int64_t
ncon
=
1
;
int64_t
ncon
=
1
;
int64_t
objval
=
-
1
;
int64_t
objval
=
-
1
;
auto
part_data
=
part
.
data_ptr
<
int64_t
>
();
auto
part_data
=
part
.
data_ptr
<
int64_t
>
();
...
...
csrc/cpu/metis_cpu.h
View file @
9c3519b4
...
@@ -3,5 +3,5 @@
...
@@ -3,5 +3,5 @@
#include <torch/extension.h>
#include <torch/extension.h>
torch
::
Tensor
partition_cpu
(
torch
::
Tensor
rowptr
,
torch
::
Tensor
col
,
torch
::
Tensor
partition_cpu
(
torch
::
Tensor
rowptr
,
torch
::
Tensor
col
,
int64_t
num_parts
,
torch
::
optional
<
torch
::
Tensor
>
edge_wgt
,
torch
::
optional
<
torch
::
Tensor
>
optional_value
,
bool
recursive
);
int64_t
num_parts
,
bool
recursive
);
csrc/metis.cpp
View file @
9c3519b4
#include "cpu/metis_cpu.h"
#include <Python.h>
#include <Python.h>
#include <torch/script.h>
#include <torch/script.h>
#include "cpu/metis_cpu.h"
#ifdef _WIN32
#ifdef _WIN32
PyMODINIT_FUNC
PyInit__metis
(
void
)
{
return
NULL
;
}
PyMODINIT_FUNC
PyInit__metis
(
void
)
{
return
NULL
;
}
#endif
#endif
torch
::
Tensor
partition
(
torch
::
Tensor
rowptr
,
torch
::
Tensor
col
,
int64_t
num_parts
,
torch
::
Tensor
partition
(
torch
::
Tensor
rowptr
,
torch
::
Tensor
col
,
torch
::
optional
<
torch
::
Tensor
>
edge_wgt
,
bool
recursive
)
{
torch
::
optional
<
torch
::
Tensor
>
optional_value
,
int64_t
num_parts
,
bool
recursive
)
{
if
(
rowptr
.
device
().
is_cuda
())
{
if
(
rowptr
.
device
().
is_cuda
())
{
#ifdef WITH_CUDA
#ifdef WITH_CUDA
AT_ERROR
(
"No CUDA version supported"
);
AT_ERROR
(
"No CUDA version supported"
);
...
@@ -15,7 +17,7 @@ torch::Tensor partition(torch::Tensor rowptr, torch::Tensor col,int64_t num_part
...
@@ -15,7 +17,7 @@ torch::Tensor partition(torch::Tensor rowptr, torch::Tensor col,int64_t num_part
AT_ERROR
(
"Not compiled with CUDA support"
);
AT_ERROR
(
"Not compiled with CUDA support"
);
#endif
#endif
}
else
{
}
else
{
return
partition_cpu
(
rowptr
,
col
,
num_parts
,
edge_wgt
,
recursive
);
return
partition_cpu
(
rowptr
,
col
,
optional_value
,
num_parts
,
recursive
);
}
}
}
}
...
...
torch_sparse/metis.py
View file @
9c3519b4
...
@@ -3,10 +3,15 @@ from typing import Tuple
...
@@ -3,10 +3,15 @@ from typing import Tuple
import
torch
import
torch
from
torch_sparse.tensor
import
SparseTensor
from
torch_sparse.tensor
import
SparseTensor
from
torch_sparse.permute
import
permute
from
torch_sparse.permute
import
permute
from
torch_sparse.utils
import
cartesian1d
def
metis_wgt
(
x
):
def
cartesian1d
(
x
,
y
):
a1
,
a2
=
torch
.
meshgrid
([
x
,
y
])
coos
=
torch
.
stack
([
a1
,
a2
]).
T
.
reshape
(
-
1
,
2
)
return
coos
.
split
(
1
,
dim
=
1
)
def
metis_weight
(
x
):
t1
,
t2
=
cartesian1d
(
x
,
x
)
t1
,
t2
=
cartesian1d
(
x
,
x
)
diff
=
t1
-
t2
diff
=
t1
-
t2
diff
=
diff
[
diff
!=
0
]
diff
=
diff
[
diff
!=
0
]
...
@@ -22,10 +27,12 @@ def metis_wgt(x):
...
@@ -22,10 +27,12 @@ def metis_wgt(x):
def
partition
(
src
:
SparseTensor
,
num_parts
:
int
,
recursive
:
bool
=
False
def
partition
(
src
:
SparseTensor
,
num_parts
:
int
,
recursive
:
bool
=
False
)
->
Tuple
[
SparseTensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
Tuple
[
SparseTensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
rowptr
,
col
=
src
.
storage
.
rowptr
().
cpu
(),
src
.
storage
.
col
().
cpu
()
rowptr
,
col
,
value
=
src
.
csr
()
edge_wgt
=
src
.
storage
.
value
().
cpu
()
rowptr
,
col
=
rowptr
.
cpu
(),
col
.
cpu
()
edge_wgt
=
metis_wgt
(
edge_wgt
)
if
value
is
not
None
and
value
.
dim
()
==
1
:
cluster
=
torch
.
ops
.
torch_sparse
.
partition
(
rowptr
,
col
,
num_parts
,
edge_wgt
,
value
=
value
.
detach
().
cpu
()
value
=
metis_weight
(
value
)
cluster
=
torch
.
ops
.
torch_sparse
.
partition
(
rowptr
,
col
,
value
,
num_parts
,
recursive
)
recursive
)
cluster
=
cluster
.
to
(
src
.
device
())
cluster
=
cluster
.
to
(
src
.
device
())
...
...
torch_sparse/utils.py
View file @
9c3519b4
from
typing
import
Any
from
typing
import
Any
import
torch
try
:
try
:
from
typing_extensions
import
Final
# noqa
from
typing_extensions
import
Final
# noqa
...
@@ -9,9 +8,3 @@ except ImportError:
...
@@ -9,9 +8,3 @@ except ImportError:
def
is_scalar
(
other
:
Any
)
->
bool
:
def
is_scalar
(
other
:
Any
)
->
bool
:
return
isinstance
(
other
,
int
)
or
isinstance
(
other
,
float
)
return
isinstance
(
other
,
int
)
or
isinstance
(
other
,
float
)
def
cartesian1d
(
x
,
y
):
a1
,
a2
=
torch
.
meshgrid
([
x
,
y
])
coos
=
torch
.
stack
([
a1
,
a2
]).
T
.
reshape
(
-
1
,
2
)
return
coos
.
split
(
1
,
dim
=
1
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment