"git@developer.sourcefind.cn:change/sglang.git" did not exist on "16a6b1d83a71bc1b669f3772bdce9e74a54fd404"
Unverified Commit bfb571cb authored by Matthias Fey's avatar Matthias Fey Committed by GitHub
Browse files

Merge pull request #40 from rusty1s/metis

[WIP] Partition
parents e78637ea eee47eee
...@@ -60,6 +60,7 @@ jobs: ...@@ -60,6 +60,7 @@ jobs:
install: install:
- source script/cuda.sh - source script/cuda.sh
- source script/conda.sh - source script/conda.sh
- source script/metis.sh
- conda create --yes -n test python="${PYTHON_VERSION}" - conda create --yes -n test python="${PYTHON_VERSION}"
- source activate test - source activate test
- conda install pytorch=${TORCH_VERSION} ${TOOLKIT} -c pytorch --yes - conda install pytorch=${TORCH_VERSION} ${TOOLKIT} -c pytorch --yes
......
...@@ -59,17 +59,20 @@ $ echo $CPATH ...@@ -59,17 +59,20 @@ $ echo $CPATH
>>> /usr/local/cuda/include:... >>> /usr/local/cuda/include:...
``` ```
If you want to additionally build `torch-sparse` with METIS support, *e.g.* for partitioning, please download and install the [METIS library](http://glaros.dtc.umn.edu/gkhome/metis/metis/download) by following the instructions in the `Install.txt` file.
Afterwards, set the environment variable `WITH_METIS=1`.
Then run: Then run:
``` ```
pip install torch-scatter torch-sparse pip install torch-scatter torch-sparse
``` ```
When running in a docker container without nvidia driver, PyTorch needs to evaluate the compute capabilities and may fail. When running in a docker container without NVIDIA driver, PyTorch needs to evaluate the compute capabilities and may fail.
In this case, ensure that the compute capabilities are set via `TORCH_CUDA_ARCH_LIST`, *e.g.*: In this case, ensure that the compute capabilities are set via `TORCH_CUDA_ARCH_LIST`, *e.g.*:
``` ```
export TORCH_CUDA_ARCH_LIST = "6.0 6.1 7.2+PTX 7.5+PTX" export TORCH_CUDA_ARCH_LIST="6.0 6.1 7.2+PTX 7.5+PTX"
``` ```
## Functions ## Functions
......
#include "metis_cpu.h"
#ifdef WITH_METIS
#include <metis.h>
#endif
#include "utils.h"
// Partitions the graph given in CSR form (`rowptr`, `col`) into `num_parts`
// parts using METIS, either via multilevel recursive bisection or via
// multilevel k-way partitioning. Returns an int64 tensor of length
// `num_vertices` holding the part id assigned to each vertex.
//
// NOTE(review): the tensors' int64_t data pointers are passed directly as
// METIS idx_t* — this is only valid when METIS was compiled with
// IDXTYPEWIDTH 64 (script/metis.sh patches metis.h accordingly); confirm
// for any other METIS build. Assumes `rowptr`/`col` are contiguous —
// data_ptr() would otherwise expose the wrong layout.
torch::Tensor partition_cpu(torch::Tensor rowptr, torch::Tensor col,
                            int64_t num_parts, bool recursive) {
#ifdef WITH_METIS
  CHECK_CPU(rowptr);
  CHECK_CPU(col);
  // CSR row pointer has (number of vertices + 1) entries.
  int64_t nvtxs = rowptr.numel() - 1;
  auto part = torch::empty(nvtxs, rowptr.options());
  auto *xadj = rowptr.data_ptr<int64_t>();
  auto *adjncy = col.data_ptr<int64_t>();
  int64_t ncon = 1;    // Number of balancing constraints (METIS minimum is 1).
  int64_t objval = -1; // Receives the edge-cut objective value; unused here.
  auto part_data = part.data_ptr<int64_t>();
  // The NULL arguments select METIS defaults for vertex/edge weights,
  // target part weights, imbalance tolerance and options.
  if (recursive) {
    METIS_PartGraphRecursive(&nvtxs, &ncon, xadj, adjncy, NULL, NULL, NULL,
                             &num_parts, NULL, NULL, NULL, &objval, part_data);
  } else {
    METIS_PartGraphKway(&nvtxs, &ncon, xadj, adjncy, NULL, NULL, NULL,
                        &num_parts, NULL, NULL, NULL, &objval, part_data);
  }
  return part;
#else
  AT_ERROR("Not compiled with METIS support");
#endif
}
#pragma once
#include <torch/extension.h>
// Partitions the CSR graph (`rowptr`, `col`) into `num_parts` parts via
// METIS; returns the per-vertex part assignment. Implemented in
// metis_cpu.cpp; raises if the extension was built without METIS support.
torch::Tensor partition_cpu(torch::Tensor rowptr, torch::Tensor col,
                            int64_t num_parts, bool recursive);
#include <Python.h>
#include <torch/script.h>
#include "cpu/metis_cpu.h"
#ifdef _WIN32
// Stub module-init symbol so the shared library can be imported on Windows;
// the actual operators are exposed through torch.ops.load_library, not
// through a Python module table. NOTE(review): returning NULL from a real
// PyInit would signal an import error — presumably intentional here since
// the library is never imported as a plain Python module; confirm.
PyMODINIT_FUNC PyInit__metis(void) { return NULL; }
#endif
// Dispatches graph partitioning to the CPU implementation. CUDA tensors are
// rejected: no GPU kernel exists for METIS partitioning.
torch::Tensor partition(torch::Tensor rowptr, torch::Tensor col,
                        int64_t num_parts, bool recursive) {
  if (!rowptr.device().is_cuda())
    return partition_cpu(rowptr, col, num_parts, recursive);
#ifdef WITH_CUDA
  AT_ERROR("No CUDA version supported");
#else
  AT_ERROR("Not compiled with CUDA support");
#endif
}
static auto registry =
torch::RegisterOperators().op("torch_sparse::partition", &partition);
#!/bin/bash
# CI helper: download, build and install METIS 5.1.0, and export
# WITH_METIS=1 so that setup.py compiles torch-sparse with METIS support.
METIS=metis-5.1.0
export WITH_METIS=1
wget -nv http://glaros.dtc.umn.edu/gkhome/fetch/sw/metis/${METIS}.tar.gz
tar -xvzf ${METIS}.tar.gz
cd ${METIS} || exit
# Switch METIS to 64-bit indices so its idx_t matches the int64_t pointers
# the extension passes in (see csrc/cpu/metis_cpu.cpp).
sed -i.bak -e 's/IDXTYPEWIDTH 32/IDXTYPEWIDTH 64/g' include/metis.h
if [ "${TRAVIS_OS_NAME}" != "windows" ]; then
# Linux/macOS: standard METIS Makefile-based build and system-wide install.
make config
make
sudo make install
else
# Fix GKlib on Windows: https://github.com/jlblancoc/suitesparse-metis-for-windows/issues/6
sed -i.bak -e '61,69d' GKlib/gk_arch.h
cd build || exit
cmake .. -A x64 # Ensure we are building with x64
cmake --build . --config "Release" --target ALL_BUILD
# Copy the built library/header into the MSVC toolchain directories so the
# extension build can find them. NOTE(review): the MSVC 14.16.27023 path is
# hard-coded to the Travis image — update if the CI image changes.
cp libmetis/Release/metis.lib /c/Program\ Files\ \(x86\)/Microsoft\ Visual\ Studio/2017/BuildTools/VC/Tools/MSVC/14.16.27023/lib/x64
cp ../include/metis.h /c/Program\ Files\ \(x86\)/Microsoft\ Visual\ Studio/2017/BuildTools/VC/Tools/MSVC/14.16.27023/include
cd ..
fi
cd ..
...@@ -16,10 +16,18 @@ if os.getenv('FORCE_CPU', '0') == '1': ...@@ -16,10 +16,18 @@ if os.getenv('FORCE_CPU', '0') == '1':
BUILD_DOCS = os.getenv('BUILD_DOCS', '0') == '1' BUILD_DOCS = os.getenv('BUILD_DOCS', '0') == '1'
WITH_METIS = False
if os.getenv('WITH_METIS', '0') == '1':
WITH_METIS = True
def get_extensions(): def get_extensions():
Extension = CppExtension Extension = CppExtension
define_macros = [] define_macros = []
libraries = []
if WITH_METIS:
define_macros += [('WITH_METIS', None)]
libraries += ['metis']
extra_compile_args = {'cxx': []} extra_compile_args = {'cxx': []}
extra_link_args = [] extra_link_args = []
...@@ -32,9 +40,9 @@ def get_extensions(): ...@@ -32,9 +40,9 @@ def get_extensions():
extra_compile_args['nvcc'] = nvcc_flags extra_compile_args['nvcc'] = nvcc_flags
if sys.platform == 'win32': if sys.platform == 'win32':
extra_link_args = ['cusparse.lib'] extra_link_args += ['cusparse.lib']
else: else:
extra_link_args = ['-lcusparse', '-l', 'cusparse'] extra_link_args += ['-lcusparse', '-l', 'cusparse']
extensions_dir = osp.join(osp.dirname(osp.abspath(__file__)), 'csrc') extensions_dir = osp.join(osp.dirname(osp.abspath(__file__)), 'csrc')
main_files = glob.glob(osp.join(extensions_dir, '*.cpp')) main_files = glob.glob(osp.join(extensions_dir, '*.cpp'))
...@@ -59,6 +67,7 @@ def get_extensions(): ...@@ -59,6 +67,7 @@ def get_extensions():
define_macros=define_macros, define_macros=define_macros,
extra_compile_args=extra_compile_args, extra_compile_args=extra_compile_args,
extra_link_args=extra_link_args, extra_link_args=extra_link_args,
libraries=libraries,
) )
extensions += [extension] extensions += [extension]
...@@ -71,7 +80,7 @@ tests_require = ['pytest', 'pytest-cov'] ...@@ -71,7 +80,7 @@ tests_require = ['pytest', 'pytest-cov']
setup( setup(
name='torch_sparse', name='torch_sparse',
version='0.5.1', version='0.6.0',
author='Matthias Fey', author='Matthias Fey',
author_email='matthias.fey@tu-dortmund.de', author_email='matthias.fey@tu-dortmund.de',
url='https://github.com/rusty1s/pytorch_sparse', url='https://github.com/rusty1s/pytorch_sparse',
......
import pytest
import torch
from torch_sparse.tensor import SparseTensor
from .utils import devices
@pytest.mark.parametrize('device', devices)
def test_metis(device):
    # Partition a random 6x6 matrix into two parts, first with k-way
    # partitioning and then with recursive bisection on the result.
    mat = SparseTensor.from_dense(torch.randn((6, 6), device=device))
    for use_recursive in (False, True):
        mat, partptr, perm = mat.partition(num_parts=2,
                                           recursive=use_recursive)
        # Two parts -> three part-pointer entries; perm covers all six nodes.
        assert partptr.numel() == 3
        assert perm.numel() == 6
import pytest
import torch
from torch_sparse.tensor import SparseTensor
from .utils import devices, tensor
@pytest.mark.parametrize('device', devices)
def test_permute(device):
    # Swap nodes 0 and 1 of a 3x3 sparse matrix and verify the COO layout.
    rows, cols = tensor([[0, 0, 1, 2, 2], [0, 1, 0, 1, 2]], torch.long,
                        device)
    vals = tensor([1, 2, 3, 4, 5], torch.float, device)
    mat = SparseTensor(row=rows, col=cols, value=vals)
    out_row, out_col, out_val = mat.permute(torch.tensor([1, 0, 2])).coo()
    assert out_row.tolist() == [0, 1, 1, 2, 2]
    assert out_col.tolist() == [1, 0, 1, 0, 2]
    assert out_val.tolist() == [3, 2, 1, 4, 5]
...@@ -93,8 +93,8 @@ def test_utility(dtype, device): ...@@ -93,8 +93,8 @@ def test_utility(dtype, device):
storage = storage.set_value(value, layout='coo') storage = storage.set_value(value, layout='coo')
assert storage.value().tolist() == [1, 2, 3, 4] assert storage.value().tolist() == [1, 2, 3, 4]
storage = storage.sparse_resize([3, 3]) storage = storage.sparse_resize((3, 3))
assert storage.sparse_sizes() == [3, 3] assert storage.sparse_sizes() == (3, 3)
new_storage = storage.copy() new_storage = storage.copy()
assert new_storage != storage assert new_storage != storage
......
...@@ -3,11 +3,13 @@ import os.path as osp ...@@ -3,11 +3,13 @@ import os.path as osp
import torch import torch
__version__ = '0.5.1' __version__ = '0.6.0'
expected_torch_version = (1, 4) expected_torch_version = (1, 4)
try: try:
for library in ['_version', '_convert', '_diag', '_spmm', '_spspmm']: for library in [
'_version', '_convert', '_diag', '_spmm', '_spspmm', '_metis'
]:
torch.ops.load_library(importlib.machinery.PathFinder().find_spec( torch.ops.load_library(importlib.machinery.PathFinder().find_spec(
library, [osp.dirname(__file__)]).origin) library, [osp.dirname(__file__)]).origin)
except OSError as e: except OSError as e:
...@@ -45,12 +47,14 @@ from .narrow import narrow, __narrow_diag__ # noqa ...@@ -45,12 +47,14 @@ from .narrow import narrow, __narrow_diag__ # noqa
from .select import select # noqa from .select import select # noqa
from .index_select import index_select, index_select_nnz # noqa from .index_select import index_select, index_select_nnz # noqa
from .masked_select import masked_select, masked_select_nnz # noqa from .masked_select import masked_select, masked_select_nnz # noqa
from .permute import permute # noqa
from .diag import remove_diag, set_diag, fill_diag # noqa from .diag import remove_diag, set_diag, fill_diag # noqa
from .add import add, add_, add_nnz, add_nnz_ # noqa from .add import add, add_, add_nnz, add_nnz_ # noqa
from .mul import mul, mul_, mul_nnz, mul_nnz_ # noqa from .mul import mul, mul_, mul_nnz, mul_nnz_ # noqa
from .reduce import sum, mean, min, max # noqa from .reduce import sum, mean, min, max # noqa
from .matmul import matmul # noqa from .matmul import matmul # noqa
from .cat import cat, cat_diag # noqa from .cat import cat, cat_diag # noqa
from .metis import partition # noqa
from .convert import to_torch_sparse, from_torch_sparse # noqa from .convert import to_torch_sparse, from_torch_sparse # noqa
from .convert import to_scipy, from_scipy # noqa from .convert import to_scipy, from_scipy # noqa
...@@ -71,6 +75,7 @@ __all__ = [ ...@@ -71,6 +75,7 @@ __all__ = [
'index_select_nnz', 'index_select_nnz',
'masked_select', 'masked_select',
'masked_select_nnz', 'masked_select_nnz',
'permute',
'remove_diag', 'remove_diag',
'set_diag', 'set_diag',
'fill_diag', 'fill_diag',
...@@ -89,6 +94,7 @@ __all__ = [ ...@@ -89,6 +94,7 @@ __all__ = [
'matmul', 'matmul',
'cat', 'cat',
'cat_diag', 'cat_diag',
'partition',
'to_torch_sparse', 'to_torch_sparse',
'from_torch_sparse', 'from_torch_sparse',
'to_scipy', 'to_scipy',
......
...@@ -5,7 +5,6 @@ from torch_scatter import gather_csr ...@@ -5,7 +5,6 @@ from torch_scatter import gather_csr
from torch_sparse.tensor import SparseTensor from torch_sparse.tensor import SparseTensor
@torch.jit.script
def add(src: SparseTensor, other: torch.Tensor) -> SparseTensor: def add(src: SparseTensor, other: torch.Tensor) -> SparseTensor:
rowptr, col, value = src.csr() rowptr, col, value = src.csr()
if other.size(0) == src.size(0) and other.size(1) == 1: # Row-wise... if other.size(0) == src.size(0) and other.size(1) == 1: # Row-wise...
...@@ -24,7 +23,6 @@ def add(src: SparseTensor, other: torch.Tensor) -> SparseTensor: ...@@ -24,7 +23,6 @@ def add(src: SparseTensor, other: torch.Tensor) -> SparseTensor:
return src.set_value(value, layout='coo') return src.set_value(value, layout='coo')
@torch.jit.script
def add_(src: SparseTensor, other: torch.Tensor) -> SparseTensor: def add_(src: SparseTensor, other: torch.Tensor) -> SparseTensor:
rowptr, col, value = src.csr() rowptr, col, value = src.csr()
if other.size(0) == src.size(0) and other.size(1) == 1: # Row-wise... if other.size(0) == src.size(0) and other.size(1) == 1: # Row-wise...
...@@ -44,7 +42,6 @@ def add_(src: SparseTensor, other: torch.Tensor) -> SparseTensor: ...@@ -44,7 +42,6 @@ def add_(src: SparseTensor, other: torch.Tensor) -> SparseTensor:
return src.set_value_(value, layout='coo') return src.set_value_(value, layout='coo')
@torch.jit.script
def add_nnz(src: SparseTensor, other: torch.Tensor, def add_nnz(src: SparseTensor, other: torch.Tensor,
layout: Optional[str] = None) -> SparseTensor: layout: Optional[str] = None) -> SparseTensor:
value = src.storage.value() value = src.storage.value()
...@@ -55,7 +52,6 @@ def add_nnz(src: SparseTensor, other: torch.Tensor, ...@@ -55,7 +52,6 @@ def add_nnz(src: SparseTensor, other: torch.Tensor,
return src.set_value(value, layout=layout) return src.set_value(value, layout=layout)
@torch.jit.script
def add_nnz_(src: SparseTensor, other: torch.Tensor, def add_nnz_(src: SparseTensor, other: torch.Tensor,
layout: Optional[str] = None) -> SparseTensor: layout: Optional[str] = None) -> SparseTensor:
value = src.storage.value() value = src.storage.value()
......
...@@ -5,7 +5,6 @@ from torch_sparse.storage import SparseStorage ...@@ -5,7 +5,6 @@ from torch_sparse.storage import SparseStorage
from torch_sparse.tensor import SparseTensor from torch_sparse.tensor import SparseTensor
@torch.jit.script
def cat(tensors: List[SparseTensor], dim: int) -> SparseTensor: def cat(tensors: List[SparseTensor], dim: int) -> SparseTensor:
assert len(tensors) > 0 assert len(tensors) > 0
if dim < 0: if dim < 0:
...@@ -142,7 +141,6 @@ def cat(tensors: List[SparseTensor], dim: int) -> SparseTensor: ...@@ -142,7 +141,6 @@ def cat(tensors: List[SparseTensor], dim: int) -> SparseTensor:
'[{-tensors[0].dim()}, {tensors[0].dim() - 1}], but got {dim}.') '[{-tensors[0].dim()}, {tensors[0].dim() - 1}], but got {dim}.')
@torch.jit.script
def cat_diag(tensors: List[SparseTensor]) -> SparseTensor: def cat_diag(tensors: List[SparseTensor]) -> SparseTensor:
assert len(tensors) > 0 assert len(tensors) > 0
......
...@@ -20,6 +20,6 @@ def coalesce(index, value, m, n, op="add"): ...@@ -20,6 +20,6 @@ def coalesce(index, value, m, n, op="add"):
""" """
storage = SparseStorage(row=index[0], col=index[1], value=value, storage = SparseStorage(row=index[0], col=index[1], value=value,
sparse_sizes=torch.Size([m, n]), is_sorted=False) sparse_sizes=(m, n), is_sorted=False)
storage = storage.coalesce(reduce=op) storage = storage.coalesce(reduce=op)
return torch.stack([storage.row(), storage.col()], dim=0), storage.value() return torch.stack([storage.row(), storage.col()], dim=0), storage.value()
...@@ -5,7 +5,7 @@ from torch import from_numpy ...@@ -5,7 +5,7 @@ from torch import from_numpy
def to_torch_sparse(index, value, m, n): def to_torch_sparse(index, value, m, n):
return torch.sparse_coo_tensor(index.detach(), value, torch.Size([m, n])) return torch.sparse_coo_tensor(index.detach(), value, (m, n))
def from_torch_sparse(A): def from_torch_sparse(A):
......
...@@ -5,7 +5,6 @@ from torch_sparse.storage import SparseStorage ...@@ -5,7 +5,6 @@ from torch_sparse.storage import SparseStorage
from torch_sparse.tensor import SparseTensor from torch_sparse.tensor import SparseTensor
@torch.jit.script
def remove_diag(src: SparseTensor, k: int = 0) -> SparseTensor: def remove_diag(src: SparseTensor, k: int = 0) -> SparseTensor:
row, col, value = src.coo() row, col, value = src.coo()
inv_mask = row != col if k == 0 else row != (col - k) inv_mask = row != col if k == 0 else row != (col - k)
...@@ -25,24 +24,14 @@ def remove_diag(src: SparseTensor, k: int = 0) -> SparseTensor: ...@@ -25,24 +24,14 @@ def remove_diag(src: SparseTensor, k: int = 0) -> SparseTensor:
colcount = colcount.clone() colcount = colcount.clone()
colcount[col[mask]] -= 1 colcount[col[mask]] -= 1
storage = SparseStorage( storage = SparseStorage(row=new_row, rowptr=None, col=new_col, value=value,
row=new_row, sparse_sizes=src.sparse_sizes(), rowcount=rowcount,
rowptr=None, colptr=None, colcount=colcount, csr2csc=None,
col=new_col, csc2csr=None, is_sorted=True)
value=value,
sparse_sizes=src.sparse_sizes(),
rowcount=rowcount,
colptr=None,
colcount=colcount,
csr2csc=None,
csc2csr=None,
is_sorted=True)
return src.from_storage(storage) return src.from_storage(storage)
@torch.jit.script def set_diag(src: SparseTensor, values: Optional[torch.Tensor] = None,
def set_diag(src: SparseTensor,
values: Optional[torch.Tensor] = None,
k: int = 0) -> SparseTensor: k: int = 0) -> SparseTensor:
src = remove_diag(src, k=k) src = remove_diag(src, k=k)
row, col, value = src.coo() row, col, value = src.coo()
...@@ -69,8 +58,7 @@ def set_diag(src: SparseTensor, ...@@ -69,8 +58,7 @@ def set_diag(src: SparseTensor,
if values is not None: if values is not None:
new_value[inv_mask] = values new_value[inv_mask] = values
else: else:
new_value[inv_mask] = torch.ones((num_diag, ), new_value[inv_mask] = torch.ones((num_diag, ), dtype=value.dtype,
dtype=value.dtype,
device=value.device) device=value.device)
rowcount = src.storage._rowcount rowcount = src.storage._rowcount
...@@ -83,22 +71,13 @@ def set_diag(src: SparseTensor, ...@@ -83,22 +71,13 @@ def set_diag(src: SparseTensor,
colcount = colcount.clone() colcount = colcount.clone()
colcount[start + k:start + num_diag + k] += 1 colcount[start + k:start + num_diag + k] += 1
storage = SparseStorage( storage = SparseStorage(row=new_row, rowptr=None, col=new_col,
row=new_row, value=new_value, sparse_sizes=src.sparse_sizes(),
rowptr=None, rowcount=rowcount, colptr=None, colcount=colcount,
col=new_col, csr2csc=None, csc2csr=None, is_sorted=True)
value=new_value,
sparse_sizes=src.sparse_sizes(),
rowcount=rowcount,
colptr=None,
colcount=colcount,
csr2csc=None,
csc2csr=None,
is_sorted=True)
return src.from_storage(storage) return src.from_storage(storage)
@torch.jit.script
def fill_diag(src: SparseTensor, fill_value: int, k: int = 0) -> SparseTensor: def fill_diag(src: SparseTensor, fill_value: int, k: int = 0) -> SparseTensor:
num_diag = min(src.sparse_size(0), src.sparse_size(1) - k) num_diag = min(src.sparse_size(0), src.sparse_size(1) - k)
if k < 0: if k < 0:
......
...@@ -6,7 +6,6 @@ from torch_sparse.storage import SparseStorage, get_layout ...@@ -6,7 +6,6 @@ from torch_sparse.storage import SparseStorage, get_layout
from torch_sparse.tensor import SparseTensor from torch_sparse.tensor import SparseTensor
@torch.jit.script
def index_select(src: SparseTensor, dim: int, def index_select(src: SparseTensor, dim: int,
idx: torch.Tensor) -> SparseTensor: idx: torch.Tensor) -> SparseTensor:
dim = src.dim() + dim if dim < 0 else dim dim = src.dim() + dim if dim < 0 else dim
...@@ -32,7 +31,7 @@ def index_select(src: SparseTensor, dim: int, ...@@ -32,7 +31,7 @@ def index_select(src: SparseTensor, dim: int,
if value is not None: if value is not None:
value = value[perm] value = value[perm]
sparse_sizes = torch.Size([idx.size(0), src.sparse_size(1)]) sparse_sizes = (idx.size(0), src.sparse_size(1))
storage = SparseStorage(row=row, rowptr=rowptr, col=col, value=value, storage = SparseStorage(row=row, rowptr=rowptr, col=col, value=value,
sparse_sizes=sparse_sizes, rowcount=rowcount, sparse_sizes=sparse_sizes, rowcount=rowcount,
...@@ -62,7 +61,7 @@ def index_select(src: SparseTensor, dim: int, ...@@ -62,7 +61,7 @@ def index_select(src: SparseTensor, dim: int,
if value is not None: if value is not None:
value = value[perm][csc2csr] value = value[perm][csc2csr]
sparse_sizes = torch.Size([src.sparse_size(0), idx.size(0)]) sparse_sizes = (src.sparse_size(0), idx.size(0))
storage = SparseStorage(row=row, rowptr=None, col=col, value=value, storage = SparseStorage(row=row, rowptr=None, col=col, value=value,
sparse_sizes=sparse_sizes, rowcount=None, sparse_sizes=sparse_sizes, rowcount=None,
...@@ -79,7 +78,6 @@ def index_select(src: SparseTensor, dim: int, ...@@ -79,7 +78,6 @@ def index_select(src: SparseTensor, dim: int,
raise ValueError raise ValueError
@torch.jit.script
def index_select_nnz(src: SparseTensor, idx: torch.Tensor, def index_select_nnz(src: SparseTensor, idx: torch.Tensor,
layout: Optional[str] = None) -> SparseTensor: layout: Optional[str] = None) -> SparseTensor:
assert idx.dim() == 1 assert idx.dim() == 1
......
...@@ -5,7 +5,6 @@ from torch_sparse.storage import SparseStorage, get_layout ...@@ -5,7 +5,6 @@ from torch_sparse.storage import SparseStorage, get_layout
from torch_sparse.tensor import SparseTensor from torch_sparse.tensor import SparseTensor
@torch.jit.script
def masked_select(src: SparseTensor, dim: int, def masked_select(src: SparseTensor, dim: int,
mask: torch.Tensor) -> SparseTensor: mask: torch.Tensor) -> SparseTensor:
dim = src.dim() + dim if dim < 0 else dim dim = src.dim() + dim if dim < 0 else dim
...@@ -28,7 +27,7 @@ def masked_select(src: SparseTensor, dim: int, ...@@ -28,7 +27,7 @@ def masked_select(src: SparseTensor, dim: int,
if value is not None: if value is not None:
value = value[mask] value = value[mask]
sparse_sizes = torch.Size([rowcount.size(0), src.sparse_size(1)]) sparse_sizes = (rowcount.size(0), src.sparse_size(1))
storage = SparseStorage(row=row, rowptr=None, col=col, value=value, storage = SparseStorage(row=row, rowptr=None, col=col, value=value,
sparse_sizes=sparse_sizes, rowcount=rowcount, sparse_sizes=sparse_sizes, rowcount=rowcount,
...@@ -55,7 +54,7 @@ def masked_select(src: SparseTensor, dim: int, ...@@ -55,7 +54,7 @@ def masked_select(src: SparseTensor, dim: int,
if value is not None: if value is not None:
value = value[csr2csc][mask][csc2csr] value = value[csr2csc][mask][csc2csr]
sparse_sizes = torch.Size([src.sparse_size(0), colcount.size(0)]) sparse_sizes = (src.sparse_size(0), colcount.size(0))
storage = SparseStorage(row=row, rowptr=None, col=col, value=value, storage = SparseStorage(row=row, rowptr=None, col=col, value=value,
sparse_sizes=sparse_sizes, rowcount=None, sparse_sizes=sparse_sizes, rowcount=None,
...@@ -73,7 +72,6 @@ def masked_select(src: SparseTensor, dim: int, ...@@ -73,7 +72,6 @@ def masked_select(src: SparseTensor, dim: int,
raise ValueError raise ValueError
@torch.jit.script
def masked_select_nnz(src: SparseTensor, mask: torch.Tensor, def masked_select_nnz(src: SparseTensor, mask: torch.Tensor,
layout: Optional[str] = None) -> SparseTensor: layout: Optional[str] = None) -> SparseTensor:
assert mask.dim() == 1 assert mask.dim() == 1
......
...@@ -4,7 +4,6 @@ import torch ...@@ -4,7 +4,6 @@ import torch
from torch_sparse.tensor import SparseTensor from torch_sparse.tensor import SparseTensor
@torch.jit.script
def spmm_sum(src: SparseTensor, other: torch.Tensor) -> torch.Tensor: def spmm_sum(src: SparseTensor, other: torch.Tensor) -> torch.Tensor:
rowptr, col, value = src.csr() rowptr, col, value = src.csr()
...@@ -24,12 +23,10 @@ def spmm_sum(src: SparseTensor, other: torch.Tensor) -> torch.Tensor: ...@@ -24,12 +23,10 @@ def spmm_sum(src: SparseTensor, other: torch.Tensor) -> torch.Tensor:
csr2csc, other) csr2csc, other)
@torch.jit.script
def spmm_add(src: SparseTensor, other: torch.Tensor) -> torch.Tensor: def spmm_add(src: SparseTensor, other: torch.Tensor) -> torch.Tensor:
return spmm_sum(src, other) return spmm_sum(src, other)
@torch.jit.script
def spmm_mean(src: SparseTensor, other: torch.Tensor) -> torch.Tensor: def spmm_mean(src: SparseTensor, other: torch.Tensor) -> torch.Tensor:
rowptr, col, value = src.csr() rowptr, col, value = src.csr()
...@@ -51,21 +48,18 @@ def spmm_mean(src: SparseTensor, other: torch.Tensor) -> torch.Tensor: ...@@ -51,21 +48,18 @@ def spmm_mean(src: SparseTensor, other: torch.Tensor) -> torch.Tensor:
colptr, csr2csc, other) colptr, csr2csc, other)
@torch.jit.script
def spmm_min(src: SparseTensor, def spmm_min(src: SparseTensor,
other: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: other: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
rowptr, col, value = src.csr() rowptr, col, value = src.csr()
return torch.ops.torch_sparse.spmm_min(rowptr, col, value, other) return torch.ops.torch_sparse.spmm_min(rowptr, col, value, other)
@torch.jit.script
def spmm_max(src: SparseTensor, def spmm_max(src: SparseTensor,
other: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: other: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
rowptr, col, value = src.csr() rowptr, col, value = src.csr()
return torch.ops.torch_sparse.spmm_max(rowptr, col, value, other) return torch.ops.torch_sparse.spmm_max(rowptr, col, value, other)
@torch.jit.script
def spmm(src: SparseTensor, other: torch.Tensor, def spmm(src: SparseTensor, other: torch.Tensor,
reduce: str = "sum") -> torch.Tensor: reduce: str = "sum") -> torch.Tensor:
if reduce == 'sum' or reduce == 'add': if reduce == 'sum' or reduce == 'add':
...@@ -80,7 +74,6 @@ def spmm(src: SparseTensor, other: torch.Tensor, ...@@ -80,7 +74,6 @@ def spmm(src: SparseTensor, other: torch.Tensor,
raise ValueError raise ValueError
@torch.jit.script
def spspmm_sum(src: SparseTensor, other: SparseTensor) -> SparseTensor: def spspmm_sum(src: SparseTensor, other: SparseTensor) -> SparseTensor:
assert src.sparse_size(1) == other.sparse_size(0) assert src.sparse_size(1) == other.sparse_size(0)
rowptrA, colA, valueA = src.csr() rowptrA, colA, valueA = src.csr()
...@@ -88,21 +81,14 @@ def spspmm_sum(src: SparseTensor, other: SparseTensor) -> SparseTensor: ...@@ -88,21 +81,14 @@ def spspmm_sum(src: SparseTensor, other: SparseTensor) -> SparseTensor:
M, K = src.sparse_size(0), other.sparse_size(1) M, K = src.sparse_size(0), other.sparse_size(1)
rowptrC, colC, valueC = torch.ops.torch_sparse.spspmm_sum( rowptrC, colC, valueC = torch.ops.torch_sparse.spspmm_sum(
rowptrA, colA, valueA, rowptrB, colB, valueB, K) rowptrA, colA, valueA, rowptrB, colB, valueB, K)
return SparseTensor( return SparseTensor(row=None, rowptr=rowptrC, col=colC, value=valueC,
row=None, sparse_sizes=(M, K), is_sorted=True)
rowptr=rowptrC,
col=colC,
value=valueC,
sparse_sizes=torch.Size([M, K]),
is_sorted=True)
@torch.jit.script
def spspmm_add(src: SparseTensor, other: SparseTensor) -> SparseTensor: def spspmm_add(src: SparseTensor, other: SparseTensor) -> SparseTensor:
return spspmm_sum(src, other) return spspmm_sum(src, other)
@torch.jit.script
def spspmm(src: SparseTensor, other: SparseTensor, def spspmm(src: SparseTensor, other: SparseTensor,
reduce: str = "sum") -> SparseTensor: reduce: str = "sum") -> SparseTensor:
if reduce == 'sum' or reduce == 'add': if reduce == 'sum' or reduce == 'add':
...@@ -113,8 +99,7 @@ def spspmm(src: SparseTensor, other: SparseTensor, ...@@ -113,8 +99,7 @@ def spspmm(src: SparseTensor, other: SparseTensor,
raise ValueError raise ValueError
def matmul(src: SparseTensor, def matmul(src: SparseTensor, other: Union[torch.Tensor, SparseTensor],
other: Union[torch.Tensor, SparseTensor],
reduce: str = "sum"): reduce: str = "sum"):
if torch.is_tensor(other): if torch.is_tensor(other):
return spmm(src, other, reduce) return spmm(src, other, reduce)
......
from typing import Tuple
import torch
from torch_sparse.tensor import SparseTensor
from torch_sparse.permute import permute
def partition(
    src: SparseTensor, num_parts: int, recursive: bool = False
) -> Tuple[SparseTensor, torch.Tensor, torch.Tensor]:
    """Partition :obj:`src` into :obj:`num_parts` parts via METIS.

    Returns the permuted sparse tensor (nodes grouped by part), a pointer
    tensor :obj:`partptr` delimiting the parts, and the node permutation
    :obj:`perm` that was applied.
    """
    # The METIS operator runs on the CPU, so move the CSR structure there.
    rowptr = src.storage.rowptr().cpu()
    col = src.storage.col().cpu()
    cluster = torch.ops.torch_sparse.partition(rowptr, col, num_parts,
                                               recursive)
    cluster = cluster.to(src.device())
    # Sorting the cluster assignment yields the permutation that groups
    # nodes of the same part together.
    cluster, perm = cluster.sort()
    out = permute(src, perm)
    # Compress the sorted cluster ids into a part-pointer vector.
    partptr = torch.ops.torch_sparse.ind2ptr(cluster, num_parts)
    return out, partptr, perm


SparseTensor.partition = lambda self, num_parts, recursive=False: partition(
    self, num_parts, recursive)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment