Unverified Commit 7671fcb0 authored by Matthias Fey, committed by GitHub

Merge pull request #33 from rusty1s/adj

[WIP] SparseTensor Format
Parents: 1fb5fa4f 704ad420
from itertools import product

import pytest
import torch
from torch_sparse.matmul import matmul
from torch_sparse.tensor import SparseTensor
import torch_scatter

from .utils import reductions, devices, grad_dtypes


@pytest.mark.parametrize('dtype,device,reduce',
                         product(grad_dtypes, devices, reductions))
def test_spmm(dtype, device, reduce):
    src = torch.randn((10, 8), dtype=dtype, device=device)
    src[2:4, :] = 0  # Remove multiple rows.
    src[:, 2:4] = 0  # Remove multiple columns.
    src = SparseTensor.from_dense(src).requires_grad_()
    row, col, value = src.coo()

    other = torch.randn((2, 8, 2), dtype=dtype, device=device,
                        requires_grad=True)

    src_col = other.index_select(-2, col) * value.unsqueeze(-1)
    expected = torch_scatter.scatter(src_col, row, dim=-2, reduce=reduce)
    if reduce == 'min':
        expected[expected > 1000] = 0
    if reduce == 'max':
        expected[expected < -1000] = 0

    grad_out = torch.randn_like(expected)

    expected.backward(grad_out)
    expected_grad_value = value.grad
    value.grad = None
    expected_grad_other = other.grad
    other.grad = None

    out = matmul(src, other, reduce)
    out.backward(grad_out)

    assert torch.allclose(expected, out, atol=1e-6)
    assert torch.allclose(expected_grad_value, value.grad, atol=1e-6)
    assert torch.allclose(expected_grad_other, other.grad, atol=1e-6)


@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
def test_spspmm(dtype, device):
    src = torch.tensor([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=dtype,
                       device=device)
    src = SparseTensor.from_dense(src)

    out = matmul(src, src)
    assert out.sizes() == [3, 3]
    assert out.has_value()
    rowptr, col, value = out.csr()
    assert rowptr.tolist() == [0, 1, 2, 3]
    assert col.tolist() == [0, 1, 2]
    assert value.tolist() == [1, 1, 1]

    src.set_value_(None)
    out = matmul(src, src)
    assert out.sizes() == [3, 3]
    assert not out.has_value()
    rowptr, col, value = out.csr()
    assert rowptr.tolist() == [0, 1, 2, 3]
    assert col.tolist() == [0, 1, 2]
import torch
from torch_sparse.tensor import SparseTensor


def test_overload():
    row = torch.tensor([0, 1, 1, 2, 2])
    col = torch.tensor([1, 0, 2, 1, 2])
    mat = SparseTensor(row=row, col=col)

    other = torch.tensor([1, 2, 3]).view(3, 1)
    other + mat
    mat + other
    other * mat
    mat * other

    other = torch.tensor([1, 2, 3]).view(1, 3)
    other + mat
    mat + other
    other * mat
    mat * other
@@ -4,32 +4,16 @@ import pytest
 import torch
 from torch_sparse import spspmm
-from .utils import dtypes, devices, tensor
+from .utils import grad_dtypes, devices, tensor
 
 
-@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
+@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
 def test_spspmm(dtype, device):
     indexA = torch.tensor([[0, 0, 1, 2, 2], [1, 2, 0, 0, 1]], device=device)
     valueA = tensor([1, 2, 3, 4, 5], dtype, device)
-    sizeA = torch.Size([3, 3])
     indexB = torch.tensor([[0, 2], [1, 0]], device=device)
     valueB = tensor([2, 4], dtype, device)
-    sizeB = torch.Size([3, 2])
 
     indexC, valueC = spspmm(indexA, valueA, indexB, valueB, 3, 3, 2)
     assert indexC.tolist() == [[0, 1, 2], [0, 1, 1]]
     assert valueC.tolist() == [8, 6, 8]
-
-    A = torch.sparse_coo_tensor(indexA, valueA, sizeA, device=device)
-    A = A.to_dense().requires_grad_()
-    B = torch.sparse_coo_tensor(indexB, valueB, sizeB, device=device)
-    B = B.to_dense().requires_grad_()
-    torch.matmul(A, B).sum().backward()
-
-    valueA = valueA.requires_grad_()
-    valueB = valueB.requires_grad_()
-    indexC, valueC = spspmm(indexA, valueA, indexB, valueB, 3, 3, 2)
-    valueC.sum().backward()
-
-    assert valueA.grad.tolist() == A.grad[indexA[0], indexA[1]].tolist()
-    assert valueB.grad.tolist() == B.grad[indexB[0], indexB[1]].tolist()
from itertools import product

import pytest
import torch
from torch_sparse import spspmm, spmm

from .utils import dtypes, devices, tensor


@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_spmm_spspmm(dtype, device):
    row = torch.tensor([0, 0, 1, 2, 2], device=device)
    col = torch.tensor([0, 2, 1, 0, 1], device=device)
    index = torch.stack([row, col], dim=0)
    value = tensor([1, 2, 4, 1, 3], dtype, device)
    x = tensor([[1, 4], [2, 5], [3, 6]], dtype, device)

    value = value.requires_grad_(True)
    out_index, out_value = spspmm(index, value, index, value, 3, 3, 3)
    out = spmm(out_index, out_value, 3, 3, x)
    assert out.size() == (3, 2)
from itertools import product

import pytest
import torch
from torch_sparse.storage import SparseStorage

from .utils import dtypes, devices, tensor


@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_storage(dtype, device):
    row, col = tensor([[0, 0, 1, 1], [0, 1, 0, 1]], torch.long, device)

    storage = SparseStorage(row=row, col=col)
    assert storage.row().tolist() == [0, 0, 1, 1]
    assert storage.col().tolist() == [0, 1, 0, 1]
    assert storage.value() is None
    assert storage.sparse_sizes() == (2, 2)

    row, col = tensor([[0, 0, 1, 1], [1, 0, 1, 0]], torch.long, device)
    value = tensor([2, 1, 4, 3], dtype, device)
    storage = SparseStorage(row=row, col=col, value=value)
    assert storage.row().tolist() == [0, 0, 1, 1]
    assert storage.col().tolist() == [0, 1, 0, 1]
    assert storage.value().tolist() == [1, 2, 3, 4]
    assert storage.sparse_sizes() == (2, 2)


@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_caching(dtype, device):
    row, col = tensor([[0, 0, 1, 1], [0, 1, 0, 1]], torch.long, device)
    storage = SparseStorage(row=row, col=col)

    assert storage._row.tolist() == row.tolist()
    assert storage._col.tolist() == col.tolist()
    assert storage._value is None

    assert storage._rowcount is None
    assert storage._rowptr is None
    assert storage._colcount is None
    assert storage._colptr is None
    assert storage._csr2csc is None
    assert storage.num_cached_keys() == 0

    storage.fill_cache_()

    assert storage._rowcount.tolist() == [2, 2]
    assert storage._rowptr.tolist() == [0, 2, 4]
    assert storage._colcount.tolist() == [2, 2]
    assert storage._colptr.tolist() == [0, 2, 4]
    assert storage._csr2csc.tolist() == [0, 2, 1, 3]
    assert storage._csc2csr.tolist() == [0, 2, 1, 3]
    assert storage.num_cached_keys() == 5

    storage = SparseStorage(row=row, rowptr=storage._rowptr, col=col,
                            value=storage._value,
                            sparse_sizes=storage._sparse_sizes,
                            rowcount=storage._rowcount, colptr=storage._colptr,
                            colcount=storage._colcount,
                            csr2csc=storage._csr2csc, csc2csr=storage._csc2csr)

    assert storage._rowcount.tolist() == [2, 2]
    assert storage._rowptr.tolist() == [0, 2, 4]
    assert storage._colcount.tolist() == [2, 2]
    assert storage._colptr.tolist() == [0, 2, 4]
    assert storage._csr2csc.tolist() == [0, 2, 1, 3]
    assert storage._csc2csr.tolist() == [0, 2, 1, 3]
    assert storage.num_cached_keys() == 5

    storage.clear_cache_()

    assert storage._rowcount is None
    assert storage._rowptr is not None
    assert storage._colcount is None
    assert storage._colptr is None
    assert storage._csr2csc is None
    assert storage.num_cached_keys() == 0


@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_utility(dtype, device):
    row, col = tensor([[0, 0, 1, 1], [1, 0, 1, 0]], torch.long, device)
    value = tensor([1, 2, 3, 4], dtype, device)
    storage = SparseStorage(row=row, col=col, value=value)

    assert storage.has_value()

    storage.set_value_(value, layout='csc')
    assert storage.value().tolist() == [1, 3, 2, 4]
    storage.set_value_(value, layout='coo')
    assert storage.value().tolist() == [1, 2, 3, 4]

    storage = storage.set_value(value, layout='csc')
    assert storage.value().tolist() == [1, 3, 2, 4]
    storage = storage.set_value(value, layout='coo')
    assert storage.value().tolist() == [1, 2, 3, 4]

    storage = storage.sparse_resize([3, 3])
    assert storage.sparse_sizes() == [3, 3]

    new_storage = storage.copy()
    assert new_storage != storage
    assert new_storage.col().data_ptr() == storage.col().data_ptr()

    new_storage = storage.clone()
    assert new_storage != storage
    assert new_storage.col().data_ptr() != storage.col().data_ptr()


@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_coalesce(dtype, device):
    row, col = tensor([[0, 0, 0, 1, 1], [0, 1, 1, 0, 1]], torch.long, device)
    value = tensor([1, 1, 1, 3, 4], dtype, device)
    storage = SparseStorage(row=row, col=col, value=value)

    assert storage.row().tolist() == row.tolist()
    assert storage.col().tolist() == col.tolist()
    assert storage.value().tolist() == value.tolist()
    assert not storage.is_coalesced()

    storage = storage.coalesce()
    assert storage.is_coalesced()
    assert storage.row().tolist() == [0, 0, 1, 1]
    assert storage.col().tolist() == [0, 1, 0, 1]
    assert storage.value().tolist() == [1, 2, 3, 4]
 import torch
 
-dtypes = [torch.float, torch.double]
+reductions = ['sum', 'add', 'mean', 'min', 'max']
+
+dtypes = [torch.float, torch.double, torch.int, torch.long]
+grad_dtypes = [torch.float, torch.double]
 
 devices = [torch.device('cpu')]
 if torch.cuda.is_available():
@@ -8,4 +11,4 @@ if torch.cuda.is_available():
 
 def tensor(x, dtype, device):
-    return torch.tensor(x, dtype=dtype, device=device)
+    return None if x is None else torch.tensor(x, dtype=dtype, device=device)
+from .storage import SparseStorage
+from .tensor import SparseTensor
+from .transpose import t
+from .narrow import narrow, __narrow_diag__
+from .select import select
+from .index_select import index_select, index_select_nnz
+from .masked_select import masked_select, masked_select_nnz
+from .diag import remove_diag, set_diag, fill_diag
+from .add import add, add_, add_nnz, add_nnz_
+from .mul import mul, mul_, mul_nnz, mul_nnz_
+from .reduce import sum, mean, min, max
+from .matmul import matmul
+from .cat import cat, cat_diag
 from .convert import to_torch_sparse, from_torch_sparse, to_scipy, from_scipy
 from .coalesce import coalesce
 from .transpose import transpose
@@ -5,10 +19,37 @@ from .eye import eye
 from .spmm import spmm
 from .spspmm import spspmm
 
-__version__ = '0.4.4'
+__version__ = '1.0.0'
 
 __all__ = [
-    '__version__',
+    'SparseStorage',
+    'SparseTensor',
+    't',
+    'narrow',
+    '__narrow_diag__',
+    'select',
+    'index_select',
+    'index_select_nnz',
+    'masked_select',
+    'masked_select_nnz',
+    'remove_diag',
+    'set_diag',
+    'fill_diag',
+    'add',
+    'add_',
+    'add_nnz',
+    'add_nnz_',
+    'mul',
+    'mul_',
+    'mul_nnz',
+    'mul_nnz_',
+    'sum',
+    'mean',
+    'min',
+    'max',
+    'matmul',
+    'cat',
+    'cat_diag',
     'to_torch_sparse',
     'from_torch_sparse',
     'to_scipy',
@@ -18,4 +59,5 @@ __all__ = [
     'eye',
    'spmm',
    'spspmm',
+    '__version__',
 ]
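
Taken together, the re-exported names above form the new user-facing API. A minimal usage sketch (illustrative only; it assumes the package is installed with its compiled extensions and uses a small random matrix as stand-in data):

import torch
from torch_sparse import SparseTensor

# Build a SparseTensor from a dense matrix and inspect its layouts.
mat = SparseTensor.from_dense(torch.randn(10, 10).relu_())
row, col, value = mat.coo()      # COO representation.
rowptr, col, value = mat.csr()   # CSR representation (cached after first use).
print(mat.sparse_sizes(), mat.nnz(), mat.has_value())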
from typing import Optional

import torch
from torch_scatter import gather_csr
from torch_sparse.tensor import SparseTensor


@torch.jit.script
def add(src: SparseTensor, other: torch.Tensor) -> SparseTensor:
    rowptr, col, value = src.csr()
    if other.size(0) == src.size(0) and other.size(1) == 1:  # Row-wise...
        other = gather_csr(other.squeeze(1), rowptr)
        pass
    elif other.size(0) == 1 and other.size(1) == src.size(1):  # Col-wise...
        other = other.squeeze(0)[col]
    else:
        raise ValueError(f'Size mismatch: Expected size ({src.size(0)}, 1,'
                         f' ...) or (1, {src.size(1)}, ...), but got size '
                         f'{other.size()}.')

    if value is not None:
        value = other.to(value.dtype).add_(value)
    else:
        value = other.add_(1)

    return src.set_value(value, layout='coo')


@torch.jit.script
def add_(src: SparseTensor, other: torch.Tensor) -> SparseTensor:
    rowptr, col, value = src.csr()
    if other.size(0) == src.size(0) and other.size(1) == 1:  # Row-wise...
        other = gather_csr(other.squeeze(1), rowptr)
        pass
    elif other.size(0) == 1 and other.size(1) == src.size(1):  # Col-wise...
        other = other.squeeze(0)[col]
    else:
        raise ValueError(f'Size mismatch: Expected size ({src.size(0)}, 1,'
                         f' ...) or (1, {src.size(1)}, ...), but got size '
                         f'{other.size()}.')

    if value is not None:
        value = value.add_(other.to(value.dtype))
    else:
        value = other.add_(1)

    return src.set_value_(value, layout='coo')


@torch.jit.script
def add_nnz(src: SparseTensor, other: torch.Tensor,
            layout: Optional[str] = None) -> SparseTensor:
    value = src.storage.value()
    if value is not None:
        value = value.add(other.to(value.dtype))
    else:
        value = other.add(1)
    return src.set_value(value, layout=layout)


@torch.jit.script
def add_nnz_(src: SparseTensor, other: torch.Tensor,
             layout: Optional[str] = None) -> SparseTensor:
    value = src.storage.value()
    if value is not None:
        value = value.add_(other.to(value.dtype))
    else:
        value = other.add(1)
    return src.set_value_(value, layout=layout)


SparseTensor.add = lambda self, other: add(self, other)
SparseTensor.add_ = lambda self, other: add_(self, other)
SparseTensor.add_nnz = lambda self, other, layout=None: add_nnz(
    self, other, layout)
SparseTensor.add_nnz_ = lambda self, other, layout=None: add_nnz_(
    self, other, layout)
SparseTensor.__add__ = SparseTensor.add
SparseTensor.__radd__ = SparseTensor.add
SparseTensor.__iadd__ = SparseTensor.add_
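
A hedged sketch of the broadcasting rules encoded above: a dense tensor of size (N, 1) is added row-wise and one of size (1, N) column-wise, and only the stored non-zero entries are touched (values are assumed to be present):

import torch
from torch_sparse import SparseTensor

mat = SparseTensor.from_dense(torch.eye(3))
row_vec = torch.tensor([[1.], [2.], [3.]])  # Size (3, 1): broadcast over rows.
col_vec = torch.tensor([[1., 2., 3.]])      # Size (1, 3): broadcast over columns.

out = mat + row_vec  # Adds row_vec[i] to every non-zero in row i.
out = mat + col_vec  # Adds col_vec[j] to every non-zero in column j.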
from typing import List, Optional

import torch
from torch_sparse.storage import SparseStorage
from torch_sparse.tensor import SparseTensor


@torch.jit.script
def cat(tensors: List[SparseTensor], dim: int) -> SparseTensor:
    assert len(tensors) > 0
    if dim < 0:
        dim = tensors[0].dim() + dim

    if dim == 0:
        rows: List[torch.Tensor] = []
        rowptrs: List[torch.Tensor] = []
        cols: List[torch.Tensor] = []
        values: List[torch.Tensor] = []
        sparse_sizes: List[int] = [0, 0]
        rowcounts: List[torch.Tensor] = []
        nnz: int = 0

        for tensor in tensors:
            row = tensor.storage._row
            if row is not None:
                rows.append(row + sparse_sizes[0])

            rowptr = tensor.storage._rowptr
            if rowptr is not None:
                if len(rowptrs) > 0:
                    rowptr = rowptr[1:]
                rowptrs.append(rowptr + nnz)

            cols.append(tensor.storage._col)

            value = tensor.storage._value
            if value is not None:
                values.append(value)

            rowcount = tensor.storage._rowcount
            if rowcount is not None:
                rowcounts.append(rowcount)

            sparse_sizes[0] += tensor.sparse_size(0)
            sparse_sizes[1] = max(sparse_sizes[1], tensor.sparse_size(1))
            nnz += tensor.nnz()

        row: Optional[torch.Tensor] = None
        if len(rows) == len(tensors):
            row = torch.cat(rows, dim=0)

        rowptr: Optional[torch.Tensor] = None
        if len(rowptrs) == len(tensors):
            rowptr = torch.cat(rowptrs, dim=0)

        col = torch.cat(cols, dim=0)

        value: Optional[torch.Tensor] = None
        if len(values) == len(tensors):
            value = torch.cat(values, dim=0)

        rowcount: Optional[torch.Tensor] = None
        if len(rowcounts) == len(tensors):
            rowcount = torch.cat(rowcounts, dim=0)

        storage = SparseStorage(row=row, rowptr=rowptr, col=col, value=value,
                                sparse_sizes=sparse_sizes, rowcount=rowcount,
                                colptr=None, colcount=None, csr2csc=None,
                                csc2csr=None, is_sorted=True)
        return tensors[0].from_storage(storage)

    elif dim == 1:
        rows: List[torch.Tensor] = []
        cols: List[torch.Tensor] = []
        values: List[torch.Tensor] = []
        sparse_sizes: List[int] = [0, 0]
        colptrs: List[torch.Tensor] = []
        colcounts: List[torch.Tensor] = []
        nnz: int = 0

        for tensor in tensors:
            row, col, value = tensor.coo()
            rows.append(row)
            cols.append(tensor.storage._col + sparse_sizes[1])
            if value is not None:
                values.append(value)

            colptr = tensor.storage._colptr
            if colptr is not None:
                if len(colptrs) > 0:
                    colptr = colptr[1:]
                colptrs.append(colptr + nnz)

            colcount = tensor.storage._colcount
            if colcount is not None:
                colcounts.append(colcount)

            sparse_sizes[0] = max(sparse_sizes[0], tensor.sparse_size(0))
            sparse_sizes[1] += tensor.sparse_size(1)
            nnz += tensor.nnz()

        row = torch.cat(rows, dim=0)
        col = torch.cat(cols, dim=0)

        value: Optional[torch.Tensor] = None
        if len(values) == len(tensors):
            value = torch.cat(values, dim=0)

        colptr: Optional[torch.Tensor] = None
        if len(colptrs) == len(tensors):
            colptr = torch.cat(colptrs, dim=0)

        colcount: Optional[torch.Tensor] = None
        if len(colcounts) == len(tensors):
            colcount = torch.cat(colcounts, dim=0)

        storage = SparseStorage(row=row, rowptr=None, col=col, value=value,
                                sparse_sizes=sparse_sizes, rowcount=None,
                                colptr=colptr, colcount=colcount, csr2csc=None,
                                csc2csr=None, is_sorted=False)
        return tensors[0].from_storage(storage)

    elif dim > 1 and dim < tensors[0].dim():
        values: List[torch.Tensor] = []
        for tensor in tensors:
            value = tensor.storage.value()
            if value is not None:
                values.append(value)

        value: Optional[torch.Tensor] = None
        if len(values) == len(tensors):
            value = torch.cat(values, dim=dim - 1)

        return tensors[0].set_value(value, layout='coo')

    else:
        raise IndexError(
            f'Dimension out of range: Expected to be in range of '
            f'[{-tensors[0].dim()}, {tensors[0].dim() - 1}], but got {dim}.')


@torch.jit.script
def cat_diag(tensors: List[SparseTensor]) -> SparseTensor:
    assert len(tensors) > 0

    rows: List[torch.Tensor] = []
    rowptrs: List[torch.Tensor] = []
    cols: List[torch.Tensor] = []
    values: List[torch.Tensor] = []
    sparse_sizes: List[int] = [0, 0]
    rowcounts: List[torch.Tensor] = []
    colptrs: List[torch.Tensor] = []
    colcounts: List[torch.Tensor] = []
    csr2cscs: List[torch.Tensor] = []
    csc2csrs: List[torch.Tensor] = []
    nnz: int = 0

    for tensor in tensors:
        row = tensor.storage._row
        if row is not None:
            rows.append(row + sparse_sizes[0])

        rowptr = tensor.storage._rowptr
        if rowptr is not None:
            if len(rowptrs) > 0:
                rowptr = rowptr[1:]
            rowptrs.append(rowptr + nnz)

        cols.append(tensor.storage._col + sparse_sizes[1])

        value = tensor.storage._value
        if value is not None:
            values.append(value)

        rowcount = tensor.storage._rowcount
        if rowcount is not None:
            rowcounts.append(rowcount)

        colptr = tensor.storage._colptr
        if colptr is not None:
            if len(colptrs) > 0:
                colptr = colptr[1:]
            colptrs.append(colptr + nnz)

        colcount = tensor.storage._colcount
        if colcount is not None:
            colcounts.append(colcount)

        csr2csc = tensor.storage._csr2csc
        if csr2csc is not None:
            csr2cscs.append(csr2csc + nnz)

        csc2csr = tensor.storage._csc2csr
        if csc2csr is not None:
            csc2csrs.append(csc2csr + nnz)

        sparse_sizes[0] += tensor.sparse_size(0)
        sparse_sizes[1] += tensor.sparse_size(1)
        nnz += tensor.nnz()

    row: Optional[torch.Tensor] = None
    if len(rows) == len(tensors):
        row = torch.cat(rows, dim=0)

    rowptr: Optional[torch.Tensor] = None
    if len(rowptrs) == len(tensors):
        rowptr = torch.cat(rowptrs, dim=0)

    col = torch.cat(cols, dim=0)

    value: Optional[torch.Tensor] = None
    if len(values) == len(tensors):
        value = torch.cat(values, dim=0)

    rowcount: Optional[torch.Tensor] = None
    if len(rowcounts) == len(tensors):
        rowcount = torch.cat(rowcounts, dim=0)

    colptr: Optional[torch.Tensor] = None
    if len(colptrs) == len(tensors):
        colptr = torch.cat(colptrs, dim=0)

    colcount: Optional[torch.Tensor] = None
    if len(colcounts) == len(tensors):
        colcount = torch.cat(colcounts, dim=0)

    csr2csc: Optional[torch.Tensor] = None
    if len(csr2cscs) == len(tensors):
        csr2csc = torch.cat(csr2cscs, dim=0)

    csc2csr: Optional[torch.Tensor] = None
    if len(csc2csrs) == len(tensors):
        csc2csr = torch.cat(csc2csrs, dim=0)

    storage = SparseStorage(row=row, rowptr=rowptr, col=col, value=value,
                            sparse_sizes=sparse_sizes, rowcount=rowcount,
                            colptr=colptr, colcount=colcount, csr2csc=csr2csc,
                            csc2csr=csc2csr, is_sorted=True)
    return tensors[0].from_storage(storage)
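
A small sketch of the three concatenation modes implemented above (illustrative shapes; `cat` with `dim=0`/`dim=1` grows one sparse dimension, while `cat_diag` grows both):

import torch
from torch_sparse import SparseTensor
from torch_sparse.cat import cat, cat_diag

A = SparseTensor.from_dense(torch.eye(2))
B = SparseTensor.from_dense(torch.ones(2, 2))

out = cat([A, B], dim=0)  # Row-wise stacking: sparse size (4, 2).
out = cat([A, B], dim=1)  # Column-wise stacking: sparse size (2, 4).
out = cat_diag([A, B])    # Diagonal (block) stacking: sparse size (4, 4).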
 import torch
-import torch_scatter
-from .utils.unique import unique
+from torch_sparse.storage import SparseStorage
 
 
-def coalesce(index, value, m, n, op='add'):
+def coalesce(index, value, m, n, op="add"):
     """Row-wise sorts :obj:`value` and removes duplicate entries. Duplicate
     entries are removed by scattering them together. For scattering, any
     operation of `"torch_scatter" <https://github.com/rusty1s/pytorch_scatter>`_
@@ -21,21 +20,7 @@ def coalesce(index, value, m, n, op='add'):
     :rtype: (:class:`LongTensor`, :class:`Tensor`)
     """
-    row, col = index
-
-    if value is None:
-        _, perm = unique(row * n + col)
-        index = torch.stack([row[perm], col[perm]], dim=0)
-        return index, value
-
-    uniq, inv = torch.unique(row * n + col, sorted=True, return_inverse=True)
-
-    perm = torch.arange(inv.size(0), dtype=inv.dtype, device=inv.device)
-    perm = inv.new_empty(uniq.size(0)).scatter_(0, inv, perm)
-    index = torch.stack([row[perm], col[perm]], dim=0)
-
-    op = getattr(torch_scatter, 'scatter_{}'.format(op))
-    value = op(value, inv, 0, None, perm.size(0))
-    value = value[0] if isinstance(value, tuple) else value
-
-    return index, value
+    storage = SparseStorage(row=index[0], col=index[1], value=value,
+                            sparse_sizes=torch.Size([m, n]), is_sorted=False)
+    storage = storage.coalesce(reduce=op)
+    return torch.stack([storage.row(), storage.col()], dim=0), storage.value()
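
The functional interface keeps its old contract while delegating to `SparseStorage`; a minimal sketch of the duplicate-reduction behavior (values chosen for illustration):

import torch
from torch_sparse import coalesce

index = torch.tensor([[0, 0, 1], [1, 1, 0]])  # Entry (0, 1) appears twice.
value = torch.tensor([1., 2., 3.])

index, value = coalesce(index, value, m=2, n=2)
# index == [[0, 1], [1, 0]], value == [3., 3.]: duplicates are summed.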
import warnings
import os.path as osp
from typing import Optional

import torch
from torch_sparse.storage import SparseStorage
from torch_sparse.tensor import SparseTensor

try:
    torch.ops.load_library(
        osp.join(osp.dirname(osp.abspath(__file__)), '_diag.so'))
except OSError:
    warnings.warn('Failed to load `diag` binaries.')

    def non_diag_mask_placeholder(row: torch.Tensor, col: torch.Tensor, M: int,
                                  N: int, k: int) -> torch.Tensor:
        raise ImportError
        return row

    torch.ops.torch_sparse.non_diag_mask = non_diag_mask_placeholder


@torch.jit.script
def remove_diag(src: SparseTensor, k: int = 0) -> SparseTensor:
    row, col, value = src.coo()
    inv_mask = row != col if k == 0 else row != (col - k)
    new_row, new_col = row[inv_mask], col[inv_mask]

    if value is not None:
        value = value[inv_mask]

    rowcount = src.storage._rowcount
    colcount = src.storage._colcount
    if rowcount is not None or colcount is not None:
        mask = ~inv_mask
        if rowcount is not None:
            rowcount = rowcount.clone()
            rowcount[row[mask]] -= 1
        if colcount is not None:
            colcount = colcount.clone()
            colcount[col[mask]] -= 1

    storage = SparseStorage(row=new_row, rowptr=None, col=new_col, value=value,
                            sparse_sizes=src.sparse_sizes(), rowcount=rowcount,
                            colptr=None, colcount=colcount, csr2csc=None,
                            csc2csr=None, is_sorted=True)
    return src.from_storage(storage)


@torch.jit.script
def set_diag(src: SparseTensor, values: Optional[torch.Tensor] = None,
             k: int = 0) -> SparseTensor:
    src = remove_diag(src, k=k)
    row, col, value = src.coo()

    mask = torch.ops.torch_sparse.non_diag_mask(row, col, src.size(0),
                                                src.size(1), k)
    inv_mask = ~mask

    start, num_diag = -k if k < 0 else 0, mask.numel() - row.numel()
    diag = torch.arange(start, start + num_diag, device=row.device)

    new_row = row.new_empty(mask.size(0))
    new_row[mask] = row
    new_row[inv_mask] = diag

    new_col = col.new_empty(mask.size(0))
    new_col[mask] = col
    new_col[inv_mask] = diag.add_(k)

    new_value: Optional[torch.Tensor] = None
    if value is not None:
        new_value = value.new_empty((mask.size(0), ) + value.size()[1:])
        new_value[mask] = value
        if values is not None:
            new_value[inv_mask] = values
        else:
            new_value[inv_mask] = torch.ones((num_diag, ), dtype=value.dtype,
                                             device=value.device)

    rowcount = src.storage._rowcount
    if rowcount is not None:
        rowcount = rowcount.clone()
        rowcount[start:start + num_diag] += 1

    colcount = src.storage._colcount
    if colcount is not None:
        colcount = colcount.clone()
        colcount[start + k:start + num_diag + k] += 1

    storage = SparseStorage(row=new_row, rowptr=None, col=new_col,
                            value=new_value, sparse_sizes=src.sparse_sizes(),
                            rowcount=rowcount, colptr=None, colcount=colcount,
                            csr2csc=None, csc2csr=None, is_sorted=True)
    return src.from_storage(storage)


@torch.jit.script
def fill_diag(src: SparseTensor, fill_value: int, k: int = 0) -> SparseTensor:
    num_diag = min(src.sparse_size(0), src.sparse_size(1) - k)
    if k < 0:
        num_diag = min(src.sparse_size(0) + k, src.sparse_size(1))

    value = src.storage.value()
    if value is not None:
        sizes = [num_diag] + src.sizes()[2:]
        return set_diag(src, value.new_full(sizes, fill_value), k)
    else:
        return set_diag(src, None, k)


SparseTensor.remove_diag = lambda self, k=0: remove_diag(self, k)
SparseTensor.set_diag = lambda self, values=None, k=0: set_diag(
    self, values, k)
SparseTensor.fill_diag = lambda self, fill_value, k=0: fill_diag(
    self, fill_value, k)
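
A usage sketch of the three diagonal operations (illustrative values; `k` selects the diagonal, with `k > 0` above the main diagonal):

import torch
from torch_sparse import SparseTensor

mat = SparseTensor.from_dense(torch.ones(3, 3))

out = mat.remove_diag()                         # Drop the main diagonal.
out = mat.set_diag(torch.tensor([1., 2., 3.]))  # Overwrite the main diagonal.
out = mat.fill_diag(5, k=1)                     # Fill the first super-diagonal.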
from typing import Optional

import torch
from torch_scatter import gather_csr
from torch_sparse.storage import SparseStorage, get_layout
from torch_sparse.tensor import SparseTensor


@torch.jit.script
def index_select(src: SparseTensor, dim: int,
                 idx: torch.Tensor) -> SparseTensor:
    dim = src.dim() + dim if dim < 0 else dim
    assert idx.dim() == 1

    if dim == 0:
        old_rowptr, col, value = src.csr()

        rowcount = src.storage.rowcount()
        rowcount = rowcount[idx]

        rowptr = col.new_zeros(idx.size(0) + 1)
        torch.cumsum(rowcount, dim=0, out=rowptr[1:])

        row = torch.arange(idx.size(0),
                           device=col.device).repeat_interleave(rowcount)

        perm = torch.arange(row.size(0), device=row.device)
        perm += gather_csr(old_rowptr[idx] - rowptr[:-1], rowptr)

        col = col[perm]

        if value is not None:
            value = value[perm]

        sparse_sizes = torch.Size([idx.size(0), src.sparse_size(1)])

        storage = SparseStorage(row=row, rowptr=rowptr, col=col, value=value,
                                sparse_sizes=sparse_sizes, rowcount=rowcount,
                                colptr=None, colcount=None, csr2csc=None,
                                csc2csr=None, is_sorted=True)
        return src.from_storage(storage)

    elif dim == 1:
        old_colptr, row, value = src.csc()

        colcount = src.storage.colcount()
        colcount = colcount[idx]

        colptr = row.new_zeros(idx.size(0) + 1)
        torch.cumsum(colcount, dim=0, out=colptr[1:])

        col = torch.arange(idx.size(0),
                           device=row.device).repeat_interleave(colcount)

        perm = torch.arange(col.size(0), device=col.device)
        perm += gather_csr(old_colptr[idx] - colptr[:-1], colptr)

        row = row[perm]
        csc2csr = (idx.size(0) * row + col).argsort()
        row, col = row[csc2csr], col[csc2csr]

        if value is not None:
            value = value[perm][csc2csr]

        sparse_sizes = torch.Size([src.sparse_size(0), idx.size(0)])

        storage = SparseStorage(row=row, rowptr=None, col=col, value=value,
                                sparse_sizes=sparse_sizes, rowcount=None,
                                colptr=colptr, colcount=colcount, csr2csc=None,
                                csc2csr=csc2csr, is_sorted=True)
        return src.from_storage(storage)

    else:
        value = src.storage.value()
        if value is not None:
            return src.set_value(value.index_select(dim - 1, idx),
                                 layout='coo')
        else:
            raise ValueError


@torch.jit.script
def index_select_nnz(src: SparseTensor, idx: torch.Tensor,
                     layout: Optional[str] = None) -> SparseTensor:
    assert idx.dim() == 1

    if get_layout(layout) == 'csc':
        idx = src.storage.csc2csr()[idx]

    row, col, value = src.coo()
    row, col = row[idx], col[idx]

    if value is not None:
        value = value[idx]

    return SparseTensor(row=row, rowptr=None, col=col, value=value,
                        sparse_sizes=src.sparse_sizes(), is_sorted=True)


SparseTensor.index_select = lambda self, dim, idx: index_select(self, dim, idx)
tmp = lambda self, idx, layout=None: index_select_nnz(  # noqa
    self, idx, layout)
SparseTensor.index_select_nnz = tmp
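
A short sketch of the two selection flavors above: `index_select` indexes a sparse dimension, while `index_select_nnz` indexes directly into the list of non-zero entries:

import torch
from torch_sparse import SparseTensor

mat = SparseTensor.from_dense(torch.eye(4))

out = mat.index_select(0, torch.tensor([0, 2]))   # Keep rows 0 and 2.
out = mat.index_select(1, torch.tensor([1, 3]))   # Keep columns 1 and 3.
out = mat.index_select_nnz(torch.tensor([0, 1]))  # Keep the first two non-zeros.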
from typing import Optional

import torch
from torch_sparse.storage import SparseStorage, get_layout
from torch_sparse.tensor import SparseTensor


@torch.jit.script
def masked_select(src: SparseTensor, dim: int,
                  mask: torch.Tensor) -> SparseTensor:
    dim = src.dim() + dim if dim < 0 else dim
    assert mask.dim() == 1
    storage = src.storage

    if dim == 0:
        row, col, value = src.coo()
        rowcount = src.storage.rowcount()

        rowcount = rowcount[mask]

        mask = mask[row]
        row = torch.arange(rowcount.size(0),
                           device=row.device).repeat_interleave(rowcount)

        col = col[mask]

        if value is not None:
            value = value[mask]

        sparse_sizes = torch.Size([rowcount.size(0), src.sparse_size(1)])

        storage = SparseStorage(row=row, rowptr=None, col=col, value=value,
                                sparse_sizes=sparse_sizes, rowcount=rowcount,
                                colcount=None, colptr=None, csr2csc=None,
                                csc2csr=None, is_sorted=True)
        return src.from_storage(storage)

    elif dim == 1:
        row, col, value = src.coo()
        csr2csc = src.storage.csr2csc()
        row = row[csr2csc]
        col = col[csr2csc]

        colcount = src.storage.colcount()
        colcount = colcount[mask]

        mask = mask[col]
        col = torch.arange(colcount.size(0),
                           device=col.device).repeat_interleave(colcount)
        row = row[mask]
        csc2csr = (colcount.size(0) * row + col).argsort()
        row, col = row[csc2csr], col[csc2csr]

        if value is not None:
            value = value[csr2csc][mask][csc2csr]

        sparse_sizes = torch.Size([src.sparse_size(0), colcount.size(0)])

        storage = SparseStorage(row=row, rowptr=None, col=col, value=value,
                                sparse_sizes=sparse_sizes, rowcount=None,
                                colcount=colcount, colptr=None, csr2csc=None,
                                csc2csr=csc2csr, is_sorted=True)
        return src.from_storage(storage)

    else:
        value = src.storage.value()
        if value is not None:
            idx = mask.nonzero().flatten()
            return src.set_value(value.index_select(dim - 1, idx),
                                 layout='coo')
        else:
            raise ValueError


@torch.jit.script
def masked_select_nnz(src: SparseTensor, mask: torch.Tensor,
                      layout: Optional[str] = None) -> SparseTensor:
    assert mask.dim() == 1

    if get_layout(layout) == 'csc':
        mask = mask[src.storage.csc2csr()]

    row, col, value = src.coo()
    row, col = row[mask], col[mask]

    if value is not None:
        value = value[mask]

    return SparseTensor(row=row, rowptr=None, col=col, value=value,
                        sparse_sizes=src.sparse_sizes(), is_sorted=True)


SparseTensor.masked_select = lambda self, dim, mask: masked_select(
    self, dim, mask)
tmp = lambda src, mask, layout=None: masked_select_nnz(  # noqa
    src, mask, layout)
SparseTensor.masked_select_nnz = tmp
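
Analogously to `index_select`, a hedged sketch of boolean selection: `masked_select` masks a sparse dimension, and `masked_select_nnz` masks the non-zero entries themselves:

import torch
from torch_sparse import SparseTensor

mat = SparseTensor.from_dense(torch.eye(4))

out = mat.masked_select(0, torch.tensor([True, False, True, False]))  # Rows 0, 2.
row, col, value = mat.coo()
out = mat.masked_select_nnz(value > 0)  # Per-entry mask over the non-zeros.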
import warnings
import os.path as osp
from typing import Optional, Union, Tuple

import torch
from torch_sparse.tensor import SparseTensor

try:
    torch.ops.load_library(
        osp.join(osp.dirname(osp.abspath(__file__)), '_spmm.so'))
except OSError:
    warnings.warn('Failed to load `spmm` binaries.')

    def spmm_sum_placeholder(row: Optional[torch.Tensor], rowptr: torch.Tensor,
                             col: torch.Tensor, value: Optional[torch.Tensor],
                             colptr: Optional[torch.Tensor],
                             csr2csc: Optional[torch.Tensor],
                             mat: torch.Tensor) -> torch.Tensor:
        raise ImportError
        return mat

    def spmm_mean_placeholder(row: Optional[torch.Tensor],
                              rowptr: torch.Tensor, col: torch.Tensor,
                              value: Optional[torch.Tensor],
                              rowcount: Optional[torch.Tensor],
                              colptr: Optional[torch.Tensor],
                              csr2csc: Optional[torch.Tensor],
                              mat: torch.Tensor) -> torch.Tensor:
        raise ImportError
        return mat

    def spmm_min_max_placeholder(rowptr: torch.Tensor, col: torch.Tensor,
                                 value: Optional[torch.Tensor],
                                 mat: torch.Tensor
                                 ) -> Tuple[torch.Tensor, torch.Tensor]:
        raise ImportError
        return mat, mat

    torch.ops.torch_sparse.spmm_sum = spmm_sum_placeholder
    torch.ops.torch_sparse.spmm_mean = spmm_mean_placeholder
    torch.ops.torch_sparse.spmm_min = spmm_min_max_placeholder
    torch.ops.torch_sparse.spmm_max = spmm_min_max_placeholder

try:
    torch.ops.load_library(
        osp.join(osp.dirname(osp.abspath(__file__)), '_spspmm.so'))
except OSError:
    warnings.warn('Failed to load `spspmm` binaries.')

    def spspmm_sum_placeholder(
            rowptrA: torch.Tensor, colA: torch.Tensor,
            valueA: Optional[torch.Tensor], rowptrB: torch.Tensor,
            colB: torch.Tensor, valueB: Optional[torch.Tensor], K: int
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        raise ImportError
        return rowptrA, colA, valueA

    torch.ops.torch_sparse.spspmm_sum = spspmm_sum_placeholder


@torch.jit.script
def spmm_sum(src: SparseTensor, other: torch.Tensor) -> torch.Tensor:
    rowptr, col, value = src.csr()

    row = src.storage._row
    csr2csc = src.storage._csr2csc
    colptr = src.storage._colptr

    if value is not None and value.requires_grad:
        row = src.storage.row()

    if other.requires_grad:
        row = src.storage.row()
        csr2csc = src.storage.csr2csc()
        colptr = src.storage.colptr()

    return torch.ops.torch_sparse.spmm_sum(row, rowptr, col, value, colptr,
                                           csr2csc, other)


@torch.jit.script
def spmm_add(src: SparseTensor, other: torch.Tensor) -> torch.Tensor:
    return spmm_sum(src, other)


@torch.jit.script
def spmm_mean(src: SparseTensor, other: torch.Tensor) -> torch.Tensor:
    rowptr, col, value = src.csr()

    row = src.storage._row
    rowcount = src.storage._rowcount
    csr2csc = src.storage._csr2csc
    colptr = src.storage._colptr

    if value is not None and value.requires_grad:
        row = src.storage.row()

    if other.requires_grad:
        row = src.storage.row()
        rowcount = src.storage.rowcount()
        csr2csc = src.storage.csr2csc()
        colptr = src.storage.colptr()

    return torch.ops.torch_sparse.spmm_mean(row, rowptr, col, value, rowcount,
                                            colptr, csr2csc, other)


@torch.jit.script
def spmm_min(src: SparseTensor,
             other: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    rowptr, col, value = src.csr()
    return torch.ops.torch_sparse.spmm_min(rowptr, col, value, other)


@torch.jit.script
def spmm_max(src: SparseTensor,
             other: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    rowptr, col, value = src.csr()
    return torch.ops.torch_sparse.spmm_max(rowptr, col, value, other)


@torch.jit.script
def spmm(src: SparseTensor, other: torch.Tensor,
         reduce: str = "sum") -> torch.Tensor:
    if reduce == 'sum' or reduce == 'add':
        return spmm_sum(src, other)
    elif reduce == 'mean':
        return spmm_mean(src, other)
    elif reduce == 'min':
        return spmm_min(src, other)[0]
    elif reduce == 'max':
        return spmm_max(src, other)[0]
    else:
        raise ValueError


@torch.jit.script
def spspmm_sum(src: SparseTensor, other: SparseTensor) -> SparseTensor:
    assert src.sparse_size(1) == other.sparse_size(0)
    rowptrA, colA, valueA = src.csr()
    rowptrB, colB, valueB = other.csr()
    M, K = src.sparse_size(0), other.sparse_size(1)
    rowptrC, colC, valueC = torch.ops.torch_sparse.spspmm_sum(
        rowptrA, colA, valueA, rowptrB, colB, valueB, K)
    return SparseTensor(row=None, rowptr=rowptrC, col=colC, value=valueC,
                        sparse_sizes=torch.Size([M, K]), is_sorted=True)


@torch.jit.script
def spspmm_add(src: SparseTensor, other: SparseTensor) -> SparseTensor:
    return spspmm_sum(src, other)


@torch.jit.script
def spspmm(src: SparseTensor, other: SparseTensor,
           reduce: str = "sum") -> SparseTensor:
    if reduce == 'sum' or reduce == 'add':
        return spspmm_sum(src, other)
    elif reduce == 'mean' or reduce == 'min' or reduce == 'max':
        raise NotImplementedError
    else:
        raise ValueError


def matmul(src: SparseTensor, other: Union[torch.Tensor, SparseTensor],
           reduce: str = "sum"):
    if torch.is_tensor(other):
        return spmm(src, other, reduce)
    elif isinstance(other, SparseTensor):
        return spspmm(src, other, reduce)
    else:
        raise ValueError


SparseTensor.spmm = lambda self, other, reduce=None: spmm(self, other, reduce)
SparseTensor.spspmm = lambda self, other, reduce=None: spspmm(
    self, other, reduce)
SparseTensor.matmul = lambda self, other, reduce=None: matmul(
    self, other, reduce)
SparseTensor.__matmul__ = lambda self, other: matmul(self, other, 'sum')
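
A minimal sketch of the unified `matmul` entry point (assuming the `_spmm.so`/`_spspmm.so` binaries loaded above are available): dense right-hand sides dispatch to `spmm`, sparse ones to `spspmm`, and `@` defaults to sum aggregation:

import torch
from torch_sparse import SparseTensor
from torch_sparse.matmul import matmul

mat = SparseTensor.from_dense(torch.eye(3))
dense = torch.randn(3, 4)

out = mat @ dense                        # Sparse-dense, reduce='sum'.
out = matmul(mat, dense, reduce='mean')  # Sparse-dense with mean aggregation.
out = matmul(mat, mat)                   # Sparse-sparse; returns a SparseTensor.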
from typing import Optional

import torch
from torch_scatter import gather_csr
from torch_sparse.tensor import SparseTensor


@torch.jit.script
def mul(src: SparseTensor, other: torch.Tensor) -> SparseTensor:
    rowptr, col, value = src.csr()
    if other.size(0) == src.size(0) and other.size(1) == 1:  # Row-wise...
        other = gather_csr(other.squeeze(1), rowptr)
        pass
    elif other.size(0) == 1 and other.size(1) == src.size(1):  # Col-wise...
        other = other.squeeze(0)[col]
    else:
        raise ValueError(f'Size mismatch: Expected size ({src.size(0)}, 1,'
                         f' ...) or (1, {src.size(1)}, ...), but got size '
                         f'{other.size()}.')

    if value is not None:
        value = other.to(value.dtype).mul_(value)
    else:
        value = other

    return src.set_value(value, layout='coo')


@torch.jit.script
def mul_(src: SparseTensor, other: torch.Tensor) -> SparseTensor:
    rowptr, col, value = src.csr()
    if other.size(0) == src.size(0) and other.size(1) == 1:  # Row-wise...
        other = gather_csr(other.squeeze(1), rowptr)
        pass
    elif other.size(0) == 1 and other.size(1) == src.size(1):  # Col-wise...
        other = other.squeeze(0)[col]
    else:
        raise ValueError(f'Size mismatch: Expected size ({src.size(0)}, 1,'
                         f' ...) or (1, {src.size(1)}, ...), but got size '
                         f'{other.size()}.')

    if value is not None:
        value = value.mul_(other.to(value.dtype))
    else:
        value = other

    return src.set_value_(value, layout='coo')


@torch.jit.script
def mul_nnz(src: SparseTensor, other: torch.Tensor,
            layout: Optional[str] = None) -> SparseTensor:
    value = src.storage.value()
    if value is not None:
        value = value.mul(other.to(value.dtype))
    else:
        value = other
    return src.set_value(value, layout=layout)


@torch.jit.script
def mul_nnz_(src: SparseTensor, other: torch.Tensor,
             layout: Optional[str] = None) -> SparseTensor:
    value = src.storage.value()
    if value is not None:
        value = value.mul_(other.to(value.dtype))
    else:
        value = other
    return src.set_value_(value, layout=layout)


SparseTensor.mul = lambda self, other: mul(self, other)
SparseTensor.mul_ = lambda self, other: mul_(self, other)
SparseTensor.mul_nnz = lambda self, other, layout=None: mul_nnz(
    self, other, layout)
SparseTensor.mul_nnz_ = lambda self, other, layout=None: mul_nnz_(
    self, other, layout)
SparseTensor.__mul__ = SparseTensor.mul
SparseTensor.__rmul__ = SparseTensor.mul
SparseTensor.__imul__ = SparseTensor.mul_
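
The same broadcasting protocol as in `add` makes common normalizations one-liners; a sketch of row-normalizing an adjacency matrix (assumes every row has at least one non-zero, so the inverse degree is finite):

import torch
from torch_sparse import SparseTensor

adj = SparseTensor.from_dense(torch.ones(3, 3))
deg_inv = adj.sum(dim=1).pow(-1)  # Inverse row degrees, size (3,).
adj = adj * deg_inv.view(-1, 1)   # Size (3, 1): scales every row's non-zeros.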
from typing import Tuple

import torch
from torch_sparse.storage import SparseStorage
from torch_sparse.tensor import SparseTensor


@torch.jit.script
def narrow(src: SparseTensor, dim: int, start: int,
           length: int) -> SparseTensor:
    if dim < 0:
        dim = src.dim() + dim
    if start < 0:
        start = src.size(dim) + start

    if dim == 0:
        rowptr, col, value = src.csr()

        rowptr = rowptr.narrow(0, start=start, length=length + 1)
        row_start = rowptr[0]
        rowptr = rowptr - row_start
        row_length = rowptr[-1]

        row = src.storage._row
        if row is not None:
            row = row.narrow(0, row_start, row_length) - start

        col = col.narrow(0, row_start, row_length)

        if value is not None:
            value = value.narrow(0, row_start, row_length)

        sparse_sizes = torch.Size([length, src.sparse_size(1)])

        rowcount = src.storage._rowcount
        if rowcount is not None:
            rowcount = rowcount.narrow(0, start=start, length=length)

        storage = SparseStorage(row=row, rowptr=rowptr, col=col, value=value,
                                sparse_sizes=sparse_sizes, rowcount=rowcount,
                                colptr=None, colcount=None, csr2csc=None,
                                csc2csr=None, is_sorted=True)
        return src.from_storage(storage)

    elif dim == 1:
        # This is faster than accessing `csc()`, in contrast to the `dim=0`
        # case.
        row, col, value = src.coo()
        mask = (col >= start) & (col < start + length)

        row = row[mask]
        col = col[mask] - start

        if value is not None:
            value = value[mask]

        sparse_sizes = torch.Size([src.sparse_size(0), length])

        colptr = src.storage._colptr
        if colptr is not None:
            colptr = colptr.narrow(0, start=start, length=length + 1)
            colptr = colptr - colptr[0]

        colcount = src.storage._colcount
        if colcount is not None:
            colcount = colcount.narrow(0, start=start, length=length)

        storage = SparseStorage(row=row, rowptr=None, col=col, value=value,
                                sparse_sizes=sparse_sizes, rowcount=None,
                                colptr=colptr, colcount=colcount, csr2csc=None,
                                csc2csr=None, is_sorted=True)
        return src.from_storage(storage)

    else:
        value = src.storage.value()
        if value is not None:
            return src.set_value(value.narrow(dim - 1, start, length),
                                 layout='coo')
        else:
            raise ValueError


@torch.jit.script
def __narrow_diag__(src: SparseTensor, start: Tuple[int, int],
                    length: Tuple[int, int]) -> SparseTensor:
    # This function builds the inverse operation of `cat_diag` and should
    # hence only be used on *diagonally stacked* sparse matrices.
    # That is the reason why this method is marked as *private*.
    rowptr, col, value = src.csr()

    rowptr = rowptr.narrow(0, start=start[0], length=length[0] + 1)
    row_start = int(rowptr[0])
    rowptr = rowptr - row_start
    row_length = int(rowptr[-1])

    row = src.storage._row
    if row is not None:
        row = row.narrow(0, row_start, row_length) - start[0]

    col = col.narrow(0, row_start, row_length) - start[1]

    if value is not None:
        value = value.narrow(0, row_start, row_length)

    sparse_sizes = length

    rowcount = src.storage._rowcount
    if rowcount is not None:
        rowcount = rowcount.narrow(0, start[0], length[0])

    colptr = src.storage._colptr
    if colptr is not None:
        colptr = colptr.narrow(0, start[1], length[1] + 1)
        colptr = colptr - int(colptr[0])  # i.e. `row_start`

    colcount = src.storage._colcount
    if colcount is not None:
        colcount = colcount.narrow(0, start[1], length[1])

    csr2csc = src.storage._csr2csc
    if csr2csc is not None:
        csr2csc = csr2csc.narrow(0, row_start, row_length) - row_start

    csc2csr = src.storage._csc2csr
    if csc2csr is not None:
        csc2csr = csc2csr.narrow(0, row_start, row_length) - row_start

    storage = SparseStorage(row=row, rowptr=rowptr, col=col, value=value,
                            sparse_sizes=sparse_sizes, rowcount=rowcount,
                            colptr=colptr, colcount=colcount, csr2csc=csr2csc,
                            csc2csr=csc2csr, is_sorted=True)
    return src.from_storage(storage)


SparseTensor.narrow = lambda self, dim, start, length: narrow(
    self, dim, start, length)
SparseTensor.__narrow_diag__ = lambda self, start, length: __narrow_diag__(
    self, start, length)
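
A usage sketch of `narrow` on both sparse dimensions (illustrative sizes):

import torch
from torch_sparse import SparseTensor

mat = SparseTensor.from_dense(torch.eye(5))

out = mat.narrow(0, start=1, length=3)  # Rows 1-3: sparse size (3, 5).
out = mat.narrow(1, start=2, length=2)  # Columns 2-3: sparse size (5, 2).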
from typing import Optional

import torch
from torch_scatter import scatter, segment_csr
from torch_sparse.tensor import SparseTensor


@torch.jit.script
def reduction(src: SparseTensor, dim: Optional[int] = None,
              reduce: str = 'sum') -> torch.Tensor:
    value = src.storage.value()

    if dim is None:
        if value is not None:
            if reduce == 'sum' or reduce == 'add':
                return value.sum()
            elif reduce == 'mean':
                return value.mean()
            elif reduce == 'min':
                return value.min()
            elif reduce == 'max':
                return value.max()
            else:
                raise ValueError
        else:
            if reduce == 'sum' or reduce == 'add':
                return torch.tensor(src.nnz(), dtype=src.dtype(),
                                    device=src.device())
            elif reduce == 'mean' or reduce == 'min' or reduce == 'max':
                return torch.tensor(1, dtype=src.dtype(), device=src.device())
            else:
                raise ValueError
    else:
        if dim < 0:
            dim = src.dim() + dim

        if dim == 0 and value is not None:
            col = src.storage.col()
            return scatter(value, col, dim=0, dim_size=src.size(0))
        elif dim == 0 and value is None:
            if reduce == 'sum' or reduce == 'add':
                return src.storage.colcount().to(src.dtype())
            elif reduce == 'mean' or reduce == 'min' or reduce == 'max':
                return torch.ones(src.size(1), dtype=src.dtype())
            else:
                raise ValueError
        elif dim == 1 and value is not None:
            return segment_csr(value, src.storage.rowptr(), None, reduce)
        elif dim == 1 and value is None:
            if reduce == 'sum' or reduce == 'add':
                return src.storage.rowcount().to(src.dtype())
            elif reduce == 'mean' or reduce == 'min' or reduce == 'max':
                return torch.ones(src.size(0), dtype=src.dtype())
            else:
                raise ValueError
        elif dim > 1 and value is not None:
            if reduce == 'sum' or reduce == 'add':
                return value.sum(dim=dim - 1)
            elif reduce == 'mean':
                return value.mean(dim=dim - 1)
            elif reduce == 'min':
                return value.min(dim=dim - 1)[0]
            elif reduce == 'max':
                return value.max(dim=dim - 1)[0]
            else:
                raise ValueError
        else:
            raise ValueError


@torch.jit.script
def sum(src: SparseTensor, dim: Optional[int] = None) -> torch.Tensor:
    return reduction(src, dim, reduce='sum')


@torch.jit.script
def mean(src: SparseTensor, dim: Optional[int] = None) -> torch.Tensor:
    return reduction(src, dim, reduce='mean')


@torch.jit.script
def min(src: SparseTensor, dim: Optional[int] = None) -> torch.Tensor:
    return reduction(src, dim, reduce='min')


@torch.jit.script
def max(src: SparseTensor, dim: Optional[int] = None) -> torch.Tensor:
    return reduction(src, dim, reduce='max')


SparseTensor.sum = lambda self, dim=None: sum(self, dim)
SparseTensor.mean = lambda self, dim=None: mean(self, dim)
SparseTensor.min = lambda self, dim=None: min(self, dim)
SparseTensor.max = lambda self, dim=None: max(self, dim)
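
A short sketch of the reduction semantics above; note that reductions run over the *stored* entries only, so implicit zeros do not participate:

import torch
from torch_sparse import SparseTensor

mat = SparseTensor.from_dense(torch.tensor([[1., 2.], [0., 3.]]))

out = mat.sum()       # Scalar sum over all non-zeros: 6.
out = mat.sum(dim=1)  # Row-wise sums: [3., 3.]
out = mat.max(dim=1)  # Row-wise maxima over non-zeros: [2., 3.]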
import torch
from torch_sparse.tensor import SparseTensor
from torch_sparse.narrow import narrow


@torch.jit.script
def select(src: SparseTensor, dim: int, idx: int) -> SparseTensor:
    return narrow(src, dim, start=idx, length=1)


SparseTensor.select = lambda self, dim, idx: select(self, dim, idx)
 # import torch
 from torch_scatter import scatter_add
 ...
 import torch
-from torch_sparse import transpose, to_scipy, from_scipy, coalesce
-
-import torch_sparse.spspmm_cpu
-
-if torch.cuda.is_available():
-    import torch_sparse.spspmm_cuda
+from torch_sparse.tensor import SparseTensor
+from torch_sparse.matmul import matmul
 
 
 def spspmm(indexA, valueA, indexB, valueB, m, k, n, coalesced=False):
@@ -25,82 +21,13 @@ def spspmm(indexA, valueA, indexB, valueB, m, k, n, coalesced=False):
     :rtype: (:class:`LongTensor`, :class:`Tensor`)
     """
-    if indexA.is_cuda and coalesced:
-        indexA, valueA = coalesce(indexA, valueA, m, k)
-        indexB, valueB = coalesce(indexB, valueB, k, n)
-
-    index, value = SpSpMM.apply(indexA, valueA, indexB, valueB, m, k, n)
-    return index.detach(), value
-
-
-class SpSpMM(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, indexA, valueA, indexB, valueB, m, k, n):
-        indexC, valueC = mm(indexA, valueA, indexB, valueB, m, k, n)
-        ctx.m, ctx.k, ctx.n = m, k, n
-        ctx.save_for_backward(indexA, valueA, indexB, valueB, indexC)
-        return indexC, valueC
-
-    @staticmethod
-    def backward(ctx, grad_indexC, grad_valueC):
-        m, k = ctx.m, ctx.k
-        n = ctx.n
-        indexA, valueA, indexB, valueB, indexC = ctx.saved_tensors
-        grad_valueA = grad_valueB = None
-
-        if not grad_valueC.is_cuda:
-            if ctx.needs_input_grad[1] or ctx.needs_input_grad[1]:
-                grad_valueC = grad_valueC.clone()
-
-            if ctx.needs_input_grad[1]:
-                grad_valueA = torch_sparse.spspmm_cpu.spspmm_bw(
-                    indexA, indexC.detach(), grad_valueC, indexB.detach(),
-                    valueB, m, k)
-
-            if ctx.needs_input_grad[3]:
-                indexA, valueA = transpose(indexA, valueA, m, k)
-                indexC, grad_valueC = transpose(indexC, grad_valueC, m, n)
-                grad_valueB = torch_sparse.spspmm_cpu.spspmm_bw(
-                    indexB, indexA.detach(), valueA, indexC.detach(),
-                    grad_valueC, k, n)
-        else:
-            if ctx.needs_input_grad[1]:
-                grad_valueA = torch_sparse.spspmm_cuda.spspmm_bw(
-                    indexA, indexC.detach(), grad_valueC.clone(),
-                    indexB.detach(), valueB, m, k)
-
-            if ctx.needs_input_grad[3]:
-                indexA_T, valueA_T = transpose(indexA, valueA, m, k)
-                grad_indexB, grad_valueB = mm(indexA_T, valueA_T, indexC,
-                                              grad_valueC, k, m, n)
-                grad_valueB = lift(grad_indexB, grad_valueB, indexB, n)
-
-        return None, grad_valueA, None, grad_valueB, None, None, None
-
-
-def mm(indexA, valueA, indexB, valueB, m, k, n):
-    assert valueA.dtype == valueB.dtype
-
-    if indexA.is_cuda:
-        return torch_sparse.spspmm_cuda.spspmm(indexA, valueA, indexB, valueB,
-                                               m, k, n)
-
-    A = to_scipy(indexA, valueA, m, k)
-    B = to_scipy(indexB, valueB, k, n)
-    C = A.dot(B).tocoo().tocsr().tocoo()  # Force coalesce.
-    indexC, valueC = from_scipy(C)
-    return indexC, valueC
-
-
-def lift(indexA, valueA, indexB, n):  # pragma: no cover
-    idxA = indexA[0] * n + indexA[1]
-    idxB = indexB[0] * n + indexB[1]
-
-    max_value = max(idxA.max().item(), idxB.max().item()) + 1
-    valueB = valueA.new_zeros(max_value)
-
-    valueB[idxA] = valueA
-    valueB = valueB[idxB]
-
-    return valueB
+    A = SparseTensor(row=indexA[0], col=indexA[1], value=valueA,
+                     sparse_sizes=torch.Size([m, k]), is_sorted=not coalesced)
+    B = SparseTensor(row=indexB[0], col=indexB[1], value=valueB,
+                     sparse_sizes=torch.Size([k, n]), is_sorted=not coalesced)
+    C = matmul(A, B)
+    row, col, value = C.coo()
+    return torch.stack([row, col], dim=0), value
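
The rewritten wrapper keeps the old functional signature; a minimal sketch mirroring the test above (values illustrative):

import torch
from torch_sparse import spspmm

indexA = torch.tensor([[0, 0, 1, 2, 2], [1, 2, 0, 0, 1]])
valueA = torch.tensor([1., 2., 3., 4., 5.])
indexB = torch.tensor([[0, 2], [1, 0]])
valueB = torch.tensor([2., 4.])

# (3 x 3) @ (3 x 2) -> COO output of the (3 x 2) product.
indexC, valueC = spspmm(indexA, valueA, indexB, valueB, 3, 3, 2)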