Commit 411dcb76 authored by rusty1s's avatar rusty1s
Browse files

select methods

parent 7f7036cd
......@@ -23,3 +23,5 @@ __all__ = [
from .storage import SparseStorage
from .tensor import SparseTensor
from .transpose import t
from .narrow import narrow
from .select import select
import torch
from torch_sparse.storage import SparseStorage
from torch_sparse.tensor import SparseTensor
def narrow(src, dim, start, length):
@torch.jit.script
def narrow(src: SparseTensor, dim: int, start: int, length: int):
dim = src.dim() + dim if dim < 0 else dim
start = src.size(dim) + start if start < 0 else start
if dim == 0:
rowptr, col, value = src.csr()
# rowptr = src.storage.rowptr
# Maintain `rowcount`...
rowcount = src.storage._rowcount
if rowcount is not None:
rowcount = rowcount.narrow(0, start=start, length=length)
rowptr = rowptr.narrow(0, start=start, length=length + 1)
row_start = rowptr[0]
......@@ -22,46 +19,60 @@ def narrow(src, dim, start, length):
row = src.storage._row
if row is not None:
row = row.narrow(0, row_start, row_length) - start
col = col.narrow(0, row_start, row_length)
if src.has_value():
if value is not None:
value = value.narrow(0, row_start, row_length)
sparse_size = torch.Size([length, src.sparse_size(1)])
sparse_sizes = torch.Size([length, src.sparse_size(1)])
storage = src.storage.__class__(row=row, rowptr=rowptr, col=col,
value=value, sparse_size=sparse_size,
rowcount=rowcount, is_sorted=True)
rowcount = src.storage._rowcount
if rowcount is not None:
rowcount = rowcount.narrow(0, start=start, length=length)
storage = SparseStorage(row=row, rowptr=rowptr, col=col, value=value,
sparse_sizes=sparse_sizes, rowcount=rowcount,
colptr=None, colcount=None, csr2csc=None,
csc2csr=None, is_sorted=True)
return src.from_storage(storage)
elif dim == 1:
# This is faster than accessing `csc()` contrary to the `dim=0` case.
row, col, value = src.coo()
mask = (col >= start) & (col < start + length)
row, col = row[mask], col[mask] - start
row = row[mask]
col = col[mask] - start
# Maintain `colcount`...
colcount = src.storage._colcount
if colcount is not None:
colcount = colcount.narrow(0, start=start, length=length)
if value is not None:
value = value[mask]
sparse_sizes = torch.Size([src.sparse_size(0), length])
# Maintain `colptr`...
colptr = src.storage._colptr
if colptr is not None:
colptr = colptr.narrow(0, start=start, length=length + 1)
colptr = colptr - colptr[0]
if src.has_value():
value = value[mask]
sparse_size = torch.Size([src.sparse_size(0), length])
colcount = src.storage._colcount
if colcount is not None:
colcount = colcount.narrow(0, start=start, length=length)
storage = src.storage.__class__(row=row, col=col, value=value,
sparse_size=sparse_size, colptr=colptr,
colcount=colcount, is_sorted=True)
storage = SparseStorage(row=row, rowptr=None, col=col, value=value,
sparse_sizes=sparse_sizes, rowcount=None,
colptr=colptr, colcount=colcount, csr2csc=None,
csc2csr=None, is_sorted=True)
return src.from_storage(storage)
else:
storage = src.storage.apply_value(
lambda x: x.narrow(dim - 1, start, length))
value = src.storage.value()
if value is not None:
return src.set_value(value.narrow(dim - 1, start, length),
layout='coo')
else:
raise ValueError
return src.from_storage(storage)
SparseTensor.narrow = lambda self, dim, start, length: narrow(
self, dim, start, length)
def select(src, dim, idx):
return src.narrow(dim, start=idx, length=1)
import torch
from torch_sparse.tensor import SparseTensor
from torch_sparse.narrow import narrow
@torch.jit.script
def select(src: SparseTensor, dim: int, idx: int):
return narrow(src, dim, start=idx, length=1)
SparseTensor.select = lambda self, dim, idx: select(self, dim, idx)
......@@ -6,12 +6,10 @@ import scipy.sparse
from torch_sparse.storage import SparseStorage, get_layout
# from torch_sparse.narrow import narrow
# from torch_sparse.select import select
# from torch_sparse.index_select import index_select, index_select_nnz
# from torch_sparse.masked_select import masked_select, masked_select_nnz
# import torch_sparse.reduce
# from torch_sparse.diag import remove_diag, set_diag
# import torch_sparse.reduce
# from torch_sparse.matmul import matmul
# from torch_sparse.add import add, add_, add_nnz, add_nnz_
# from torch_sparse.mul import mul, mul_, mul_nnz, mul_nnz_
......
......@@ -5,32 +5,6 @@ from torch_sparse.storage import SparseStorage
from torch_sparse.tensor import SparseTensor
def transpose(index, value, m, n, coalesced=True):
"""Transposes dimensions 0 and 1 of a sparse tensor.
Args:
index (:class:`LongTensor`): The index tensor of sparse matrix.
value (:class:`Tensor`): The value tensor of sparse matrix.
m (int): The first dimension of corresponding dense matrix.
n (int): The second dimension of corresponding dense matrix.
coalesced (bool, optional): If set to :obj:`False`, will not coalesce
the output. (default: :obj:`True`)
:rtype: (:class:`LongTensor`, :class:`Tensor`)
"""
if value.dim() == 1 and not value.is_cuda:
mat = to_scipy(index, value, m, n).tocsc()
(col, row), value = from_scipy(mat)
index = torch.stack([row, col], dim=0)
return index, value
row, col = index
index = torch.stack([col, row], dim=0)
if coalesced:
index, value = coalesce(index, value, n, m)
return index, value
@torch.jit.script
def t(src: SparseTensor):
csr2csc = src.storage.csr2csc()
......@@ -60,3 +34,31 @@ def t(src: SparseTensor):
SparseTensor.t = lambda self: t(self)
###############################################################################
def transpose(index, value, m, n, coalesced=True):
"""Transposes dimensions 0 and 1 of a sparse tensor.
Args:
index (:class:`LongTensor`): The index tensor of sparse matrix.
value (:class:`Tensor`): The value tensor of sparse matrix.
m (int): The first dimension of corresponding dense matrix.
n (int): The second dimension of corresponding dense matrix.
coalesced (bool, optional): If set to :obj:`False`, will not coalesce
the output. (default: :obj:`True`)
:rtype: (:class:`LongTensor`, :class:`Tensor`)
"""
if value.dim() == 1 and not value.is_cuda:
mat = to_scipy(index, value, m, n).tocsc()
(col, row), value = from_scipy(mat)
index = torch.stack([row, col], dim=0)
return index, value
row, col = index
index = torch.stack([col, row], dim=0)
if coalesced:
index, value = coalesce(index, value, n, m)
return index, value
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment