"vscode:/vscode.git/clone" did not exist on "b37c8a3ca4c9c626cdac763c6be697231665b0f8"
Unverified Commit bfb571cb authored by Matthias Fey's avatar Matthias Fey Committed by GitHub
Browse files

Merge pull request #40 from rusty1s/metis

[WIP] Partition
parents e78637ea eee47eee
......@@ -60,6 +60,7 @@ jobs:
install:
- source script/cuda.sh
- source script/conda.sh
- source script/metis.sh
- conda create --yes -n test python="${PYTHON_VERSION}"
- source activate test
- conda install pytorch=${TORCH_VERSION} ${TOOLKIT} -c pytorch --yes
......
......@@ -59,17 +59,20 @@ $ echo $CPATH
>>> /usr/local/cuda/include:...
```
If you want to additionally build `torch-sparse` with METIS support, *e.g.* for partioning, please download and install the [METIS library](http://glaros.dtc.umn.edu/gkhome/metis/metis/download) by following the instructions in the `Install.txt` file.
Afterwards, set the environment variable `WITH_METIS=1`.
Then run:
```
pip install torch-scatter torch-sparse
```
When running in a docker container without nvidia driver, PyTorch needs to evaluate the compute capabilities and may fail.
When running in a docker container without NVIDIA driver, PyTorch needs to evaluate the compute capabilities and may fail.
In this case, ensure that the compute capabilities are set via `TORCH_CUDA_ARCH_LIST`, *e.g.*:
```
export TORCH_CUDA_ARCH_LIST = "6.0 6.1 7.2+PTX 7.5+PTX"
export TORCH_CUDA_ARCH_LIST="6.0 6.1 7.2+PTX 7.5+PTX"
```
## Functions
......
#include "metis_cpu.h"
#ifdef WITH_METIS
#include <metis.h>
#endif
#include "utils.h"
torch::Tensor partition_cpu(torch::Tensor rowptr, torch::Tensor col,
int64_t num_parts, bool recursive) {
#ifdef WITH_METIS
CHECK_CPU(rowptr);
CHECK_CPU(col);
int64_t nvtxs = rowptr.numel() - 1;
auto part = torch::empty(nvtxs, rowptr.options());
auto *xadj = rowptr.data_ptr<int64_t>();
auto *adjncy = col.data_ptr<int64_t>();
int64_t ncon = 1;
int64_t objval = -1;
auto part_data = part.data_ptr<int64_t>();
if (recursive) {
METIS_PartGraphRecursive(&nvtxs, &ncon, xadj, adjncy, NULL, NULL, NULL,
&num_parts, NULL, NULL, NULL, &objval, part_data);
} else {
METIS_PartGraphKway(&nvtxs, &ncon, xadj, adjncy, NULL, NULL, NULL,
&num_parts, NULL, NULL, NULL, &objval, part_data);
}
return part;
#else
AT_ERROR("Not compiled with METIS support");
#endif
}
#pragma once
#include <torch/extension.h>
torch::Tensor partition_cpu(torch::Tensor rowptr, torch::Tensor col,
int64_t num_parts, bool recursive);
#include <Python.h>
#include <torch/script.h>
#include "cpu/metis_cpu.h"
#ifdef _WIN32
PyMODINIT_FUNC PyInit__metis(void) { return NULL; }
#endif
torch::Tensor partition(torch::Tensor rowptr, torch::Tensor col,
int64_t num_parts, bool recursive) {
if (rowptr.device().is_cuda()) {
#ifdef WITH_CUDA
AT_ERROR("No CUDA version supported");
#else
AT_ERROR("Not compiled with CUDA support");
#endif
} else {
return partition_cpu(rowptr, col, num_parts, recursive);
}
}
static auto registry =
torch::RegisterOperators().op("torch_sparse::partition", &partition);
#!/bin/bash
METIS=metis-5.1.0
export WITH_METIS=1
wget -nv http://glaros.dtc.umn.edu/gkhome/fetch/sw/metis/${METIS}.tar.gz
tar -xvzf ${METIS}.tar.gz
cd ${METIS} || exit
sed -i.bak -e 's/IDXTYPEWIDTH 32/IDXTYPEWIDTH 64/g' include/metis.h
if [ "${TRAVIS_OS_NAME}" != "windows" ]; then
make config
make
sudo make install
else
# Fix GKlib on Windows: https://github.com/jlblancoc/suitesparse-metis-for-windows/issues/6
sed -i.bak -e '61,69d' GKlib/gk_arch.h
cd build || exit
cmake .. -A x64 # Ensure we are building with x64
cmake --build . --config "Release" --target ALL_BUILD
cp libmetis/Release/metis.lib /c/Program\ Files\ \(x86\)/Microsoft\ Visual\ Studio/2017/BuildTools/VC/Tools/MSVC/14.16.27023/lib/x64
cp ../include/metis.h /c/Program\ Files\ \(x86\)/Microsoft\ Visual\ Studio/2017/BuildTools/VC/Tools/MSVC/14.16.27023/include
cd ..
fi
cd ..
......@@ -16,10 +16,18 @@ if os.getenv('FORCE_CPU', '0') == '1':
BUILD_DOCS = os.getenv('BUILD_DOCS', '0') == '1'
WITH_METIS = False
if os.getenv('WITH_METIS', '0') == '1':
WITH_METIS = True
def get_extensions():
Extension = CppExtension
define_macros = []
libraries = []
if WITH_METIS:
define_macros += [('WITH_METIS', None)]
libraries += ['metis']
extra_compile_args = {'cxx': []}
extra_link_args = []
......@@ -32,9 +40,9 @@ def get_extensions():
extra_compile_args['nvcc'] = nvcc_flags
if sys.platform == 'win32':
extra_link_args = ['cusparse.lib']
extra_link_args += ['cusparse.lib']
else:
extra_link_args = ['-lcusparse', '-l', 'cusparse']
extra_link_args += ['-lcusparse', '-l', 'cusparse']
extensions_dir = osp.join(osp.dirname(osp.abspath(__file__)), 'csrc')
main_files = glob.glob(osp.join(extensions_dir, '*.cpp'))
......@@ -59,6 +67,7 @@ def get_extensions():
define_macros=define_macros,
extra_compile_args=extra_compile_args,
extra_link_args=extra_link_args,
libraries=libraries,
)
extensions += [extension]
......@@ -71,7 +80,7 @@ tests_require = ['pytest', 'pytest-cov']
setup(
name='torch_sparse',
version='0.5.1',
version='0.6.0',
author='Matthias Fey',
author_email='matthias.fey@tu-dortmund.de',
url='https://github.com/rusty1s/pytorch_sparse',
......
import pytest
import torch
from torch_sparse.tensor import SparseTensor
from .utils import devices
@pytest.mark.parametrize('device', devices)
def test_metis(device):
mat = SparseTensor.from_dense(torch.randn((6, 6), device=device))
mat, partptr, perm = mat.partition(num_parts=2, recursive=False)
assert partptr.numel() == 3
assert perm.numel() == 6
mat, partptr, perm = mat.partition(num_parts=2, recursive=True)
assert partptr.numel() == 3
assert perm.numel() == 6
import pytest
import torch
from torch_sparse.tensor import SparseTensor
from .utils import devices, tensor
@pytest.mark.parametrize('device', devices)
def test_permute(device):
row, col = tensor([[0, 0, 1, 2, 2], [0, 1, 0, 1, 2]], torch.long, device)
value = tensor([1, 2, 3, 4, 5], torch.float, device)
adj = SparseTensor(row=row, col=col, value=value)
row, col, value = adj.permute(torch.tensor([1, 0, 2])).coo()
assert row.tolist() == [0, 1, 1, 2, 2]
assert col.tolist() == [1, 0, 1, 0, 2]
assert value.tolist() == [3, 2, 1, 4, 5]
......@@ -93,8 +93,8 @@ def test_utility(dtype, device):
storage = storage.set_value(value, layout='coo')
assert storage.value().tolist() == [1, 2, 3, 4]
storage = storage.sparse_resize([3, 3])
assert storage.sparse_sizes() == [3, 3]
storage = storage.sparse_resize((3, 3))
assert storage.sparse_sizes() == (3, 3)
new_storage = storage.copy()
assert new_storage != storage
......
......@@ -3,11 +3,13 @@ import os.path as osp
import torch
__version__ = '0.5.1'
__version__ = '0.6.0'
expected_torch_version = (1, 4)
try:
for library in ['_version', '_convert', '_diag', '_spmm', '_spspmm']:
for library in [
'_version', '_convert', '_diag', '_spmm', '_spspmm', '_metis'
]:
torch.ops.load_library(importlib.machinery.PathFinder().find_spec(
library, [osp.dirname(__file__)]).origin)
except OSError as e:
......@@ -45,12 +47,14 @@ from .narrow import narrow, __narrow_diag__ # noqa
from .select import select # noqa
from .index_select import index_select, index_select_nnz # noqa
from .masked_select import masked_select, masked_select_nnz # noqa
from .permute import permute # noqa
from .diag import remove_diag, set_diag, fill_diag # noqa
from .add import add, add_, add_nnz, add_nnz_ # noqa
from .mul import mul, mul_, mul_nnz, mul_nnz_ # noqa
from .reduce import sum, mean, min, max # noqa
from .matmul import matmul # noqa
from .cat import cat, cat_diag # noqa
from .metis import partition # noqa
from .convert import to_torch_sparse, from_torch_sparse # noqa
from .convert import to_scipy, from_scipy # noqa
......@@ -71,6 +75,7 @@ __all__ = [
'index_select_nnz',
'masked_select',
'masked_select_nnz',
'permute',
'remove_diag',
'set_diag',
'fill_diag',
......@@ -89,6 +94,7 @@ __all__ = [
'matmul',
'cat',
'cat_diag',
'partition',
'to_torch_sparse',
'from_torch_sparse',
'to_scipy',
......
......@@ -5,7 +5,6 @@ from torch_scatter import gather_csr
from torch_sparse.tensor import SparseTensor
@torch.jit.script
def add(src: SparseTensor, other: torch.Tensor) -> SparseTensor:
rowptr, col, value = src.csr()
if other.size(0) == src.size(0) and other.size(1) == 1: # Row-wise...
......@@ -24,7 +23,6 @@ def add(src: SparseTensor, other: torch.Tensor) -> SparseTensor:
return src.set_value(value, layout='coo')
@torch.jit.script
def add_(src: SparseTensor, other: torch.Tensor) -> SparseTensor:
rowptr, col, value = src.csr()
if other.size(0) == src.size(0) and other.size(1) == 1: # Row-wise...
......@@ -44,7 +42,6 @@ def add_(src: SparseTensor, other: torch.Tensor) -> SparseTensor:
return src.set_value_(value, layout='coo')
@torch.jit.script
def add_nnz(src: SparseTensor, other: torch.Tensor,
layout: Optional[str] = None) -> SparseTensor:
value = src.storage.value()
......@@ -55,7 +52,6 @@ def add_nnz(src: SparseTensor, other: torch.Tensor,
return src.set_value(value, layout=layout)
@torch.jit.script
def add_nnz_(src: SparseTensor, other: torch.Tensor,
layout: Optional[str] = None) -> SparseTensor:
value = src.storage.value()
......
......@@ -5,7 +5,6 @@ from torch_sparse.storage import SparseStorage
from torch_sparse.tensor import SparseTensor
@torch.jit.script
def cat(tensors: List[SparseTensor], dim: int) -> SparseTensor:
assert len(tensors) > 0
if dim < 0:
......@@ -142,7 +141,6 @@ def cat(tensors: List[SparseTensor], dim: int) -> SparseTensor:
'[{-tensors[0].dim()}, {tensors[0].dim() - 1}], but got {dim}.')
@torch.jit.script
def cat_diag(tensors: List[SparseTensor]) -> SparseTensor:
assert len(tensors) > 0
......
......@@ -20,6 +20,6 @@ def coalesce(index, value, m, n, op="add"):
"""
storage = SparseStorage(row=index[0], col=index[1], value=value,
sparse_sizes=torch.Size([m, n]), is_sorted=False)
sparse_sizes=(m, n), is_sorted=False)
storage = storage.coalesce(reduce=op)
return torch.stack([storage.row(), storage.col()], dim=0), storage.value()
......@@ -5,7 +5,7 @@ from torch import from_numpy
def to_torch_sparse(index, value, m, n):
return torch.sparse_coo_tensor(index.detach(), value, torch.Size([m, n]))
return torch.sparse_coo_tensor(index.detach(), value, (m, n))
def from_torch_sparse(A):
......
......@@ -5,7 +5,6 @@ from torch_sparse.storage import SparseStorage
from torch_sparse.tensor import SparseTensor
@torch.jit.script
def remove_diag(src: SparseTensor, k: int = 0) -> SparseTensor:
row, col, value = src.coo()
inv_mask = row != col if k == 0 else row != (col - k)
......@@ -25,24 +24,14 @@ def remove_diag(src: SparseTensor, k: int = 0) -> SparseTensor:
colcount = colcount.clone()
colcount[col[mask]] -= 1
storage = SparseStorage(
row=new_row,
rowptr=None,
col=new_col,
value=value,
sparse_sizes=src.sparse_sizes(),
rowcount=rowcount,
colptr=None,
colcount=colcount,
csr2csc=None,
csc2csr=None,
is_sorted=True)
storage = SparseStorage(row=new_row, rowptr=None, col=new_col, value=value,
sparse_sizes=src.sparse_sizes(), rowcount=rowcount,
colptr=None, colcount=colcount, csr2csc=None,
csc2csr=None, is_sorted=True)
return src.from_storage(storage)
@torch.jit.script
def set_diag(src: SparseTensor,
values: Optional[torch.Tensor] = None,
def set_diag(src: SparseTensor, values: Optional[torch.Tensor] = None,
k: int = 0) -> SparseTensor:
src = remove_diag(src, k=k)
row, col, value = src.coo()
......@@ -69,8 +58,7 @@ def set_diag(src: SparseTensor,
if values is not None:
new_value[inv_mask] = values
else:
new_value[inv_mask] = torch.ones((num_diag, ),
dtype=value.dtype,
new_value[inv_mask] = torch.ones((num_diag, ), dtype=value.dtype,
device=value.device)
rowcount = src.storage._rowcount
......@@ -83,22 +71,13 @@ def set_diag(src: SparseTensor,
colcount = colcount.clone()
colcount[start + k:start + num_diag + k] += 1
storage = SparseStorage(
row=new_row,
rowptr=None,
col=new_col,
value=new_value,
sparse_sizes=src.sparse_sizes(),
rowcount=rowcount,
colptr=None,
colcount=colcount,
csr2csc=None,
csc2csr=None,
is_sorted=True)
storage = SparseStorage(row=new_row, rowptr=None, col=new_col,
value=new_value, sparse_sizes=src.sparse_sizes(),
rowcount=rowcount, colptr=None, colcount=colcount,
csr2csc=None, csc2csr=None, is_sorted=True)
return src.from_storage(storage)
@torch.jit.script
def fill_diag(src: SparseTensor, fill_value: int, k: int = 0) -> SparseTensor:
num_diag = min(src.sparse_size(0), src.sparse_size(1) - k)
if k < 0:
......
......@@ -6,7 +6,6 @@ from torch_sparse.storage import SparseStorage, get_layout
from torch_sparse.tensor import SparseTensor
@torch.jit.script
def index_select(src: SparseTensor, dim: int,
idx: torch.Tensor) -> SparseTensor:
dim = src.dim() + dim if dim < 0 else dim
......@@ -32,7 +31,7 @@ def index_select(src: SparseTensor, dim: int,
if value is not None:
value = value[perm]
sparse_sizes = torch.Size([idx.size(0), src.sparse_size(1)])
sparse_sizes = (idx.size(0), src.sparse_size(1))
storage = SparseStorage(row=row, rowptr=rowptr, col=col, value=value,
sparse_sizes=sparse_sizes, rowcount=rowcount,
......@@ -62,7 +61,7 @@ def index_select(src: SparseTensor, dim: int,
if value is not None:
value = value[perm][csc2csr]
sparse_sizes = torch.Size([src.sparse_size(0), idx.size(0)])
sparse_sizes = (src.sparse_size(0), idx.size(0))
storage = SparseStorage(row=row, rowptr=None, col=col, value=value,
sparse_sizes=sparse_sizes, rowcount=None,
......@@ -79,7 +78,6 @@ def index_select(src: SparseTensor, dim: int,
raise ValueError
@torch.jit.script
def index_select_nnz(src: SparseTensor, idx: torch.Tensor,
layout: Optional[str] = None) -> SparseTensor:
assert idx.dim() == 1
......
......@@ -5,7 +5,6 @@ from torch_sparse.storage import SparseStorage, get_layout
from torch_sparse.tensor import SparseTensor
@torch.jit.script
def masked_select(src: SparseTensor, dim: int,
mask: torch.Tensor) -> SparseTensor:
dim = src.dim() + dim if dim < 0 else dim
......@@ -28,7 +27,7 @@ def masked_select(src: SparseTensor, dim: int,
if value is not None:
value = value[mask]
sparse_sizes = torch.Size([rowcount.size(0), src.sparse_size(1)])
sparse_sizes = (rowcount.size(0), src.sparse_size(1))
storage = SparseStorage(row=row, rowptr=None, col=col, value=value,
sparse_sizes=sparse_sizes, rowcount=rowcount,
......@@ -55,7 +54,7 @@ def masked_select(src: SparseTensor, dim: int,
if value is not None:
value = value[csr2csc][mask][csc2csr]
sparse_sizes = torch.Size([src.sparse_size(0), colcount.size(0)])
sparse_sizes = (src.sparse_size(0), colcount.size(0))
storage = SparseStorage(row=row, rowptr=None, col=col, value=value,
sparse_sizes=sparse_sizes, rowcount=None,
......@@ -73,7 +72,6 @@ def masked_select(src: SparseTensor, dim: int,
raise ValueError
@torch.jit.script
def masked_select_nnz(src: SparseTensor, mask: torch.Tensor,
layout: Optional[str] = None) -> SparseTensor:
assert mask.dim() == 1
......
......@@ -4,7 +4,6 @@ import torch
from torch_sparse.tensor import SparseTensor
@torch.jit.script
def spmm_sum(src: SparseTensor, other: torch.Tensor) -> torch.Tensor:
rowptr, col, value = src.csr()
......@@ -24,12 +23,10 @@ def spmm_sum(src: SparseTensor, other: torch.Tensor) -> torch.Tensor:
csr2csc, other)
@torch.jit.script
def spmm_add(src: SparseTensor, other: torch.Tensor) -> torch.Tensor:
return spmm_sum(src, other)
@torch.jit.script
def spmm_mean(src: SparseTensor, other: torch.Tensor) -> torch.Tensor:
rowptr, col, value = src.csr()
......@@ -51,21 +48,18 @@ def spmm_mean(src: SparseTensor, other: torch.Tensor) -> torch.Tensor:
colptr, csr2csc, other)
@torch.jit.script
def spmm_min(src: SparseTensor,
other: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
rowptr, col, value = src.csr()
return torch.ops.torch_sparse.spmm_min(rowptr, col, value, other)
@torch.jit.script
def spmm_max(src: SparseTensor,
other: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
rowptr, col, value = src.csr()
return torch.ops.torch_sparse.spmm_max(rowptr, col, value, other)
@torch.jit.script
def spmm(src: SparseTensor, other: torch.Tensor,
reduce: str = "sum") -> torch.Tensor:
if reduce == 'sum' or reduce == 'add':
......@@ -80,7 +74,6 @@ def spmm(src: SparseTensor, other: torch.Tensor,
raise ValueError
@torch.jit.script
def spspmm_sum(src: SparseTensor, other: SparseTensor) -> SparseTensor:
assert src.sparse_size(1) == other.sparse_size(0)
rowptrA, colA, valueA = src.csr()
......@@ -88,21 +81,14 @@ def spspmm_sum(src: SparseTensor, other: SparseTensor) -> SparseTensor:
M, K = src.sparse_size(0), other.sparse_size(1)
rowptrC, colC, valueC = torch.ops.torch_sparse.spspmm_sum(
rowptrA, colA, valueA, rowptrB, colB, valueB, K)
return SparseTensor(
row=None,
rowptr=rowptrC,
col=colC,
value=valueC,
sparse_sizes=torch.Size([M, K]),
is_sorted=True)
return SparseTensor(row=None, rowptr=rowptrC, col=colC, value=valueC,
sparse_sizes=(M, K), is_sorted=True)
@torch.jit.script
def spspmm_add(src: SparseTensor, other: SparseTensor) -> SparseTensor:
return spspmm_sum(src, other)
@torch.jit.script
def spspmm(src: SparseTensor, other: SparseTensor,
reduce: str = "sum") -> SparseTensor:
if reduce == 'sum' or reduce == 'add':
......@@ -113,8 +99,7 @@ def spspmm(src: SparseTensor, other: SparseTensor,
raise ValueError
def matmul(src: SparseTensor,
other: Union[torch.Tensor, SparseTensor],
def matmul(src: SparseTensor, other: Union[torch.Tensor, SparseTensor],
reduce: str = "sum"):
if torch.is_tensor(other):
return spmm(src, other, reduce)
......
from typing import Tuple
import torch
from torch_sparse.tensor import SparseTensor
from torch_sparse.permute import permute
def partition(
src: SparseTensor, num_parts: int, recursive: bool = False
) -> Tuple[SparseTensor, torch.Tensor, torch.Tensor]:
rowptr, col = src.storage.rowptr().cpu(), src.storage.col().cpu()
cluster = torch.ops.torch_sparse.partition(rowptr, col, num_parts,
recursive)
cluster = cluster.to(src.device())
cluster, perm = cluster.sort()
out = permute(src, perm)
partptr = torch.ops.torch_sparse.ind2ptr(cluster, num_parts)
return out, partptr, perm
SparseTensor.partition = lambda self, num_parts, recursive=False: partition(
self, num_parts, recursive)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment