Commit eee47eee authored by rusty1s's avatar rusty1s
Browse files

metis support optional

parent 0a659cde
...@@ -59,7 +59,8 @@ $ echo $CPATH ...@@ -59,7 +59,8 @@ $ echo $CPATH
>>> /usr/local/cuda/include:... >>> /usr/local/cuda/include:...
``` ```
Afterwards, download and install the [METIS library](http://glaros.dtc.umn.edu/gkhome/metis/metis/download) by following the instructions in the `Install.txt` file. If you want to additionally build `torch-sparse` with METIS support, *e.g.* for partioning, please download and install the [METIS library](http://glaros.dtc.umn.edu/gkhome/metis/metis/download) by following the instructions in the `Install.txt` file.
Afterwards, set the environment variable `WITH_METIS=1`.
Then run: Then run:
...@@ -67,11 +68,11 @@ Then run: ...@@ -67,11 +68,11 @@ Then run:
pip install torch-scatter torch-sparse pip install torch-scatter torch-sparse
``` ```
When running in a docker container without nvidia driver, PyTorch needs to evaluate the compute capabilities and may fail. When running in a docker container without NVIDIA driver, PyTorch needs to evaluate the compute capabilities and may fail.
In this case, ensure that the compute capabilities are set via `TORCH_CUDA_ARCH_LIST`, *e.g.*: In this case, ensure that the compute capabilities are set via `TORCH_CUDA_ARCH_LIST`, *e.g.*:
``` ```
export TORCH_CUDA_ARCH_LIST = "6.0 6.1 7.2+PTX 7.5+PTX" export TORCH_CUDA_ARCH_LIST="6.0 6.1 7.2+PTX 7.5+PTX"
``` ```
## Functions ## Functions
......
#include "metis_cpu.h" #include "metis_cpu.h"
#ifdef WITH_METIS
#include <metis.h> #include <metis.h>
#endif
#include "utils.h" #include "utils.h"
torch::Tensor partition_cpu(torch::Tensor rowptr, torch::Tensor col, torch::Tensor partition_cpu(torch::Tensor rowptr, torch::Tensor col,
int64_t num_parts, bool recursive) { int64_t num_parts, bool recursive) {
#ifdef WITH_METIS
CHECK_CPU(rowptr); CHECK_CPU(rowptr);
CHECK_CPU(col); CHECK_CPU(col);
int64_t nvtxs = rowptr.numel() - 1; int64_t nvtxs = rowptr.numel() - 1;
auto part = torch::empty(nvtxs, rowptr.options());
auto *xadj = rowptr.data_ptr<int64_t>(); auto *xadj = rowptr.data_ptr<int64_t>();
auto *adjncy = col.data_ptr<int64_t>(); auto *adjncy = col.data_ptr<int64_t>();
int64_t ncon = 1; int64_t ncon = 1;
int64_t objval = -1; int64_t objval = -1;
auto part = torch::empty(nvtxs, rowptr.options());
auto part_data = part.data_ptr<int64_t>(); auto part_data = part.data_ptr<int64_t>();
if (recursive) { if (recursive) {
...@@ -26,4 +30,7 @@ torch::Tensor partition_cpu(torch::Tensor rowptr, torch::Tensor col, ...@@ -26,4 +30,7 @@ torch::Tensor partition_cpu(torch::Tensor rowptr, torch::Tensor col,
} }
return part; return part;
#else
AT_ERROR("Not compiled with METIS support");
#endif
} }
...@@ -3,10 +3,8 @@ ...@@ -3,10 +3,8 @@
#include "cpu/metis_cpu.h" #include "cpu/metis_cpu.h"
#include <metis.h>
#ifdef _WIN32 #ifdef _WIN32
PyMODINIT_FUNC PyInit__metis_wrapper(void) { return NULL; } PyMODINIT_FUNC PyInit__metis(void) { return NULL; }
#endif #endif
torch::Tensor partition(torch::Tensor rowptr, torch::Tensor col, torch::Tensor partition(torch::Tensor rowptr, torch::Tensor col,
......
#!/bin/bash #!/bin/bash
METIS=metis-5.1.0 METIS=metis-5.1.0
export WITH_METIS=1
wget -nv http://glaros.dtc.umn.edu/gkhome/fetch/sw/metis/${METIS}.tar.gz wget -nv http://glaros.dtc.umn.edu/gkhome/fetch/sw/metis/${METIS}.tar.gz
tar -xvzf ${METIS}.tar.gz tar -xvzf ${METIS}.tar.gz
......
...@@ -16,10 +16,18 @@ if os.getenv('FORCE_CPU', '0') == '1': ...@@ -16,10 +16,18 @@ if os.getenv('FORCE_CPU', '0') == '1':
BUILD_DOCS = os.getenv('BUILD_DOCS', '0') == '1' BUILD_DOCS = os.getenv('BUILD_DOCS', '0') == '1'
WITH_METIS = False
if os.getenv('WITH_METIS', '0') == '1':
WITH_METIS = True
def get_extensions(): def get_extensions():
Extension = CppExtension Extension = CppExtension
define_macros = [] define_macros = []
libraries = []
if WITH_METIS:
define_macros += [('WITH_METIS', None)]
libraries += ['metis']
extra_compile_args = {'cxx': []} extra_compile_args = {'cxx': []}
extra_link_args = [] extra_link_args = []
...@@ -59,7 +67,7 @@ def get_extensions(): ...@@ -59,7 +67,7 @@ def get_extensions():
define_macros=define_macros, define_macros=define_macros,
extra_compile_args=extra_compile_args, extra_compile_args=extra_compile_args,
extra_link_args=extra_link_args, extra_link_args=extra_link_args,
libraries=['metis'], libraries=libraries,
) )
extensions += [extension] extensions += [extension]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment