Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-sparse
Commits
f0609836
Commit
f0609836
authored
Nov 12, 2020
by
rusty1s
Browse files
mt-metis support (experimental)
parent
f577fcee
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
88 additions
and
7 deletions
+88
-7
csrc/cpu/metis_cpu.cpp
csrc/cpu/metis_cpu.cpp
+51
-2
csrc/cpu/metis_cpu.h
csrc/cpu/metis_cpu.h
+5
-0
csrc/metis.cpp
csrc/metis.cpp
+19
-2
csrc/sparse.h
csrc/sparse.h
+4
-0
setup.py
setup.py
+9
-3
No files found.
csrc/cpu/metis_cpu.cpp
View file @
f0609836
...
@@ -4,6 +4,10 @@
...
@@ -4,6 +4,10 @@
#include <metis.h>
#include <metis.h>
#endif
#endif
#ifdef WITH_MTMETIS
#include <mtmetis.h>
#endif
#include "utils.h"
#include "utils.h"
torch
::
Tensor
partition_cpu
(
torch
::
Tensor
rowptr
,
torch
::
Tensor
col
,
torch
::
Tensor
partition_cpu
(
torch
::
Tensor
rowptr
,
torch
::
Tensor
col
,
...
@@ -19,14 +23,14 @@ torch::Tensor partition_cpu(torch::Tensor rowptr, torch::Tensor col,
...
@@ -19,14 +23,14 @@ torch::Tensor partition_cpu(torch::Tensor rowptr, torch::Tensor col,
}
}
int64_t
nvtxs
=
rowptr
.
numel
()
-
1
;
int64_t
nvtxs
=
rowptr
.
numel
()
-
1
;
auto
part
=
torch
::
empty
(
nvtxs
,
rowptr
.
options
())
;
int64_t
ncon
=
1
;
auto
*
xadj
=
rowptr
.
data_ptr
<
int64_t
>
();
auto
*
xadj
=
rowptr
.
data_ptr
<
int64_t
>
();
auto
*
adjncy
=
col
.
data_ptr
<
int64_t
>
();
auto
*
adjncy
=
col
.
data_ptr
<
int64_t
>
();
int64_t
*
adjwgt
=
NULL
;
int64_t
*
adjwgt
=
NULL
;
if
(
optional_value
.
has_value
())
if
(
optional_value
.
has_value
())
adjwgt
=
optional_value
.
value
().
data_ptr
<
int64_t
>
();
adjwgt
=
optional_value
.
value
().
data_ptr
<
int64_t
>
();
int64_t
ncon
=
1
;
int64_t
objval
=
-
1
;
int64_t
objval
=
-
1
;
auto
part
=
torch
::
empty
(
nvtxs
,
rowptr
.
options
());
auto
part_data
=
part
.
data_ptr
<
int64_t
>
();
auto
part_data
=
part
.
data_ptr
<
int64_t
>
();
if
(
recursive
)
{
if
(
recursive
)
{
...
@@ -42,3 +46,48 @@ torch::Tensor partition_cpu(torch::Tensor rowptr, torch::Tensor col,
...
@@ -42,3 +46,48 @@ torch::Tensor partition_cpu(torch::Tensor rowptr, torch::Tensor col,
AT_ERROR
(
"Not compiled with METIS support"
);
AT_ERROR
(
"Not compiled with METIS support"
);
#endif
#endif
}
}
// Partition a CSR graph into `num_parts` parts with the multi-threaded
// mt-metis library, using `num_workers` threads.
//
// needs mt-metis installed via:
// ./configure --shared --edges64bit --vertices64bit --weights64bit
// --partitions64bit
//
// rowptr/col:     CSR structure of the graph (int64 tensors on CPU).
// optional_value: optional 1-D edge-weight tensor, one entry per `col` entry.
// recursive:      choose recursive bisection over k-way partitioning.
// Returns a tensor of `rowptr.numel() - 1` partition ids.
torch::Tensor mt_partition_cpu(torch::Tensor rowptr, torch::Tensor col,
                               torch::optional<torch::Tensor> optional_value,
                               int64_t num_parts, bool recursive,
                               int64_t num_workers) {
#ifdef WITH_MTMETIS
  CHECK_CPU(rowptr);
  CHECK_CPU(col);
  if (optional_value.has_value()) {
    CHECK_CPU(optional_value.value());
    CHECK_INPUT(optional_value.value().dim() == 1);
    CHECK_INPUT(optional_value.value().numel() == col.numel());
  }

  // Reinterpret the int64 tensor buffers as mt-metis' own types. This is
  // only sound with the 64-bit mt-metis build noted above — TODO confirm the
  // build flags match at install time.
  mtmetis_vtx_type nvtxs = rowptr.numel() - 1;
  mtmetis_vtx_type ncon = 1;
  mtmetis_adj_type *xadj = (mtmetis_adj_type *)rowptr.data_ptr<int64_t>();
  mtmetis_vtx_type *adjncy = (mtmetis_vtx_type *)col.data_ptr<int64_t>();

  mtmetis_wgt_type *adjwgt = NULL;
  if (optional_value.has_value())
    adjwgt = optional_value.value().data_ptr<int64_t>();

  mtmetis_pid_type nparts = num_parts;
  mtmetis_wgt_type objval = -1;

  // Output buffer: mt-metis writes partition ids straight into `part`.
  auto part = torch::empty(nvtxs, rowptr.options());
  mtmetis_pid_type *part_data = (mtmetis_pid_type *)part.data_ptr<int64_t>();

  // NOTE(review): `opts` is allocated by mtmetis_init_options() and never
  // released here — looks like a small one-shot leak; verify against the
  // mt-metis API whether the caller owns this buffer.
  double *opts = mtmetis_init_options();
  opts[MTMETIS_OPTION_NTHREADS] = num_workers;

  if (recursive) {
    MTMETIS_PartGraphRecursive(&nvtxs, &ncon, xadj, adjncy, NULL, NULL, adjwgt,
                               &nparts, NULL, NULL, opts, &objval, part_data);
  } else {
    MTMETIS_PartGraphKway(&nvtxs, &ncon, xadj, adjncy, NULL, NULL, adjwgt,
                          &nparts, NULL, NULL, opts, &objval, part_data);
  }

  return part;
#else
  AT_ERROR("Not compiled with MTMETIS support");
#endif
}
csrc/cpu/metis_cpu.h
View file @
f0609836
...
@@ -5,3 +5,8 @@
...
@@ -5,3 +5,8 @@
// METIS-based CPU graph partitioning (definition in metis_cpu.cpp).
torch::Tensor partition_cpu(torch::Tensor rowptr, torch::Tensor col,
                            torch::optional<torch::Tensor> optional_value,
                            int64_t num_parts, bool recursive);

// Multi-threaded mt-metis variant; `num_workers` is the thread count
// (definition in metis_cpu.cpp).
torch::Tensor mt_partition_cpu(torch::Tensor rowptr, torch::Tensor col,
                               torch::optional<torch::Tensor> optional_value,
                               int64_t num_parts, bool recursive,
                               int64_t num_workers);
csrc/metis.cpp
View file @
f0609836
...
@@ -21,5 +21,22 @@ torch::Tensor partition(torch::Tensor rowptr, torch::Tensor col,
...
@@ -21,5 +21,22 @@ torch::Tensor partition(torch::Tensor rowptr, torch::Tensor col,
}
}
}
}
// Device-dispatching entry point for mt-metis partitioning. There is no CUDA
// implementation, so CUDA tensors always raise; CPU tensors forward to
// mt_partition_cpu().
torch::Tensor mt_partition(torch::Tensor rowptr, torch::Tensor col,
                           torch::optional<torch::Tensor> optional_value,
                           int64_t num_parts, bool recursive,
                           int64_t num_workers) {
  if (rowptr.device().is_cuda()) {
#ifdef WITH_CUDA
    AT_ERROR("No CUDA version supported");
#else
    AT_ERROR("Not compiled with CUDA support");
#endif
  } else {
    return mt_partition_cpu(rowptr, col, optional_value, num_parts, recursive,
                            num_workers);
  }
}

// Expose both partitioners as TorchScript-visible custom operators.
static auto registry =
    torch::RegisterOperators()
        .op("torch_sparse::partition", &partition)
        .op("torch_sparse::mt_partition", &mt_partition);
csrc/sparse.h
View file @
f0609836
...
@@ -11,6 +11,10 @@ torch::Tensor partition(torch::Tensor rowptr, torch::Tensor col,
...
@@ -11,6 +11,10 @@ torch::Tensor partition(torch::Tensor rowptr, torch::Tensor col,
torch
::
optional
<
torch
::
Tensor
>
optional_value
,
torch
::
optional
<
torch
::
Tensor
>
optional_value
,
int64_t
num_parts
,
bool
recursive
);
int64_t
num_parts
,
bool
recursive
);
// Declaration of the mt-metis dispatcher defined in csrc/metis.cpp.
// Fix: the declaration was missing the trailing `int64_t num_workers`
// parameter that the definition (and the underlying mt_partition_cpu call)
// takes, so the prototype did not match the defined symbol.
torch::Tensor mt_partition(torch::Tensor rowptr, torch::Tensor col,
                           torch::optional<torch::Tensor> optional_value,
                           int64_t num_parts, bool recursive,
                           int64_t num_workers);
std
::
tuple
<
torch
::
Tensor
,
torch
::
Tensor
>
relabel
(
torch
::
Tensor
col
,
std
::
tuple
<
torch
::
Tensor
,
torch
::
Tensor
>
relabel
(
torch
::
Tensor
col
,
torch
::
Tensor
idx
);
torch
::
Tensor
idx
);
...
...
setup.py
View file @
f0609836
...
@@ -17,9 +17,8 @@ if os.getenv('FORCE_CPU', '0') == '1':
...
@@ -17,9 +17,8 @@ if os.getenv('FORCE_CPU', '0') == '1':
# True when the BUILD_DOCS environment variable is set to '1'.
BUILD_DOCS = os.environ.get('BUILD_DOCS', '0') == '1'
# Opt-in build flags: enable (mt-)metis support only when the corresponding
# environment variable is set to '1'.
# Fix: the `True if ... else False` conditional was redundant — the `==`
# comparison already yields the boolean directly.
WITH_METIS = os.getenv('WITH_METIS', '0') == '1'
WITH_MTMETIS = os.getenv('WITH_MTMETIS', '0') == '1'
def
get_extensions
():
def
get_extensions
():
...
@@ -29,6 +28,13 @@ def get_extensions():
...
@@ -29,6 +28,13 @@ def get_extensions():
if
WITH_METIS
:
if
WITH_METIS
:
define_macros
+=
[(
'WITH_METIS'
,
None
)]
define_macros
+=
[(
'WITH_METIS'
,
None
)]
libraries
+=
[
'metis'
]
libraries
+=
[
'metis'
]
if
WITH_MTMETIS
:
define_macros
+=
[(
'WITH_MTMETIS'
,
None
)]
define_macros
+=
[(
'MTMETIS_64BIT_VERTICES'
,
None
)]
define_macros
+=
[(
'MTMETIS_64BIT_EDGES'
,
None
)]
define_macros
+=
[(
'MTMETIS_64BIT_WEIGHTS'
,
None
)]
define_macros
+=
[(
'MTMETIS_64BIT_PARTITIONS'
,
None
)]
libraries
+=
[
'mtmetis'
,
'wildriver'
]
extra_compile_args
=
{
'cxx'
:
[]}
extra_compile_args
=
{
'cxx'
:
[]}
extra_link_args
=
[]
extra_link_args
=
[]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment